├── larimar_base
├── pytorch_transformers
│   ├── tests
│   │   ├── __init__.py
│   │   ├── fixtures
│   │   │   ├── input.txt
│   │   │   ├── test_sentencepiece.model
│   │   │   └── sample_text.txt
│   │   ├── conftest.py
│   │   ├── tokenization_dilbert_test.py
│   │   ├── tokenization_auto_test.py
│   │   ├── tokenization_utils_test.py
│   │   ├── configuration_common_test.py
│   │   ├── tokenization_openai_test.py
│   │   ├── tokenization_transfo_xl_test.py
│   │   ├── tokenization_gpt2_test.py
│   │   ├── tokenization_xlm_test.py
│   │   ├── modeling_auto_test.py
│   │   ├── tokenization_roberta_test.py
│   │   ├── tokenization_xlnet_test.py
│   │   ├── tokenization_bert_test.py
│   │   └── optimization_test.py
│   ├── utils
│   │   ├── constants.py
│   │   ├── dummy_keras_nlp_objects.py
│   │   ├── dummy_sentencepiece_and_tokenizers_objects.py
│   │   ├── dummy_tensorflow_text_objects.py
│   │   ├── dummy_detectron2_objects.py
│   │   ├── dummy_music_objects.py
│   │   ├── dummy_speech_objects.py
│   │   ├── dummy_essentia_and_librosa_and_pretty_midi_and_scipy_and_torch_objects.py
│   │   ├── bitsandbytes.py
│   │   ├── model_parallel_utils.py
│   │   ├── versions.py
│   │   ├── peft_utils.py
│   │   ├── hp_naming.py
│   │   ├── sentencepiece_model_pb2_new.py
│   │   └── dummy_sentencepiece_objects.py
│   ├── configuration_roberta.py
│   ├── dependency_versions_check.py
│   ├── model_parallel_utils.py
│   ├── tokenization_distilbert.py
│   ├── convert_tf_checkpoint_to_pytorch.py
│   ├── convert_xlm_checkpoint_to_pytorch.py
│   ├── convert_gpt2_checkpoint_to_pytorch.py
│   ├── convert_openai_checkpoint_to_pytorch.py
│   ├── dependency_versions_table.py
│   ├── configuration_distilbert.py
│   ├── tokenization_roberta.py
│   ├── convert_xlnet_checkpoint_to_pytorch.py
│   ├── __init__.py
│   ├── integrations
│   │   └── __init__.py
│   ├── convert_pytorch_checkpoint_to_tf.py
│   ├── configuration_openai.py
│   ├── convert_transfo_xl_checkpoint_to_pytorch.py
│   ├── configuration_bert.py
│   └── generation
│   │   └── stopping_criteria.py
├── modules
│   ├── encoders
│   │   ├── __init__.py
│   │   ├── encoder.py
│   │   ├── enc_lstm.py
│   │   └── gaussian_encoder.py
│   ├── __init__.py
│   ├── decoders
│   │   └── decoder.py
│   ├── utils.py
│   └── spacefusion.py
├── configs
│   ├── ds_config.json
│   ├── default_accelerate_config.yaml
│   └── config_train_larimar.yaml
├── eval.sh
├── eval_rephrase.sh
├── main_pl.py
├── train_larimar.sh
├── lightning_data.py
└── ddp.py
├── models
└── model_locations.json
├── larimar_architecture.png
├── data
├── wikipedia_data_locations.json
└── counterfactual_data_locations.json
├── requirements.txt
└── README.md
/larimar_base/pytorch_transformers/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/larimar_base/modules/encoders/__init__.py:
--------------------------------------------------------------------------------
1 | from .enc_lstm import *
--------------------------------------------------------------------------------
/models/model_locations.json:
--------------------------------------------------------------------------------
1 | {
2 |     "larimar-1.3b-c3": "./larimar-1.3b-c3.ckpt"
3 | }
4 |
--------------------------------------------------------------------------------
/larimar_architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IBM/larimar/HEAD/larimar_architecture.png
--------------------------------------------------------------------------------
/larimar_base/pytorch_transformers/tests/fixtures/input.txt:
-------------------------------------------------------------------------------- 1 | Who was Jim Henson ? ||| Jim Henson was a puppeteer 2 | -------------------------------------------------------------------------------- /larimar_base/pytorch_transformers/tests/fixtures/test_sentencepiece.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/larimar/HEAD/larimar_base/pytorch_transformers/tests/fixtures/test_sentencepiece.model -------------------------------------------------------------------------------- /data/wikipedia_data_locations.json: -------------------------------------------------------------------------------- 1 | { 2 | "wikipedia-64": "./wikipedia/blocksize_64", 3 | "wikipedia-128": "./wikipedia/blocksize_128", 4 | "wikipedia-256": "./wikipedia/blocksize_256" 5 | } 6 | -------------------------------------------------------------------------------- /larimar_base/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .encoders import * 2 | from .decoders import * 3 | from .vae import * 4 | from .utils import * 5 | from .spacefusion import * 6 | from .cara import * 7 | from .arae import * 8 | from .mem_vae import * 9 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | sentence-transformers==3.3.1 2 | lightning 3 | deepspeed 4 | nltk 5 | boto3 6 | sacremoses 7 | tensorboard 8 | jupyterlab 9 | scipy 10 | scikit-learn 11 | jsonargparse[signatures] 12 | spacy 13 | pandas 14 | 15 | -------------------------------------------------------------------------------- /larimar_base/pytorch_transformers/utils/constants.py: -------------------------------------------------------------------------------- 1 | IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406] 2 | IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225] 3 | IMAGENET_STANDARD_MEAN = [0.5, 0.5, 0.5] 4 | IMAGENET_STANDARD_STD = [0.5, 0.5, 0.5] 5 | OPENAI_CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073] 6 | OPENAI_CLIP_STD = [0.26862954, 0.26130258, 0.27577711] 7 | -------------------------------------------------------------------------------- /data/counterfactual_data_locations.json: -------------------------------------------------------------------------------- 1 | { 2 | "attribute_snippets": "https://rome.baulab.info/data/dsets/attribute_snippets.json", 3 | "counterfactual": "https://rome.baulab.info/data/dsets/counterfact.json", 4 | "idf": "https://rome.baulab.info/data/dsets/idf.npy", 5 | "tfidf_vocab": "https://rome.baulab.info/data/dsets/tfidf_vocab.json" 6 | } -------------------------------------------------------------------------------- /larimar_base/pytorch_transformers/utils/dummy_keras_nlp_objects.py: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by the command `make fix-copies`, do not edit. 
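# DummyObject placeholders keep these names importable even when the optional backend is absent:
# instantiating TFGPT2Tokenizer without `keras_nlp` installed raises an informative ImportError via requires_backends.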
2 | from ..utils import DummyObject, requires_backends 3 | 4 | 5 | class TFGPT2Tokenizer(metaclass=DummyObject): 6 | _backends = ["keras_nlp"] 7 | 8 | def __init__(self, *args, **kwargs): 9 | requires_backends(self, ["keras_nlp"]) 10 | -------------------------------------------------------------------------------- /larimar_base/pytorch_transformers/utils/dummy_sentencepiece_and_tokenizers_objects.py: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by the command `make fix-copies`, do not edit. 2 | from ..utils import DummyObject, requires_backends 3 | 4 | 5 | SLOW_TO_FAST_CONVERTERS = None 6 | 7 | 8 | def convert_slow_tokenizer(*args, **kwargs): 9 | requires_backends(convert_slow_tokenizer, ["sentencepiece", "tokenizers"]) 10 | -------------------------------------------------------------------------------- /larimar_base/pytorch_transformers/utils/dummy_tensorflow_text_objects.py: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by the command `make fix-copies`, do not edit. 2 | from ..utils import DummyObject, requires_backends 3 | 4 | 5 | class TFBertTokenizer(metaclass=DummyObject): 6 | _backends = ["tensorflow_text"] 7 | 8 | def __init__(self, *args, **kwargs): 9 | requires_backends(self, ["tensorflow_text"]) 10 | -------------------------------------------------------------------------------- /larimar_base/pytorch_transformers/utils/dummy_detectron2_objects.py: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by the command `make fix-copies`, do not edit. 2 | from ..utils import requires_backends 3 | 4 | 5 | LAYOUTLM_V2_PRETRAINED_MODEL_ARCHIVE_LIST = None 6 | 7 | 8 | class LayoutLMv2Model: 9 | def __init__(self, *args, **kwargs): 10 | requires_backends(self, ["detectron2"]) 11 | 12 | @classmethod 13 | def from_pretrained(cls, *args, **kwargs): 14 | requires_backends(cls, ["detectron2"]) 15 | -------------------------------------------------------------------------------- /larimar_base/pytorch_transformers/utils/dummy_music_objects.py: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by the command `make fix-copies`, do not edit. 2 | from ..utils import DummyObject, requires_backends 3 | 4 | 5 | class Pop2PianoFeatureExtractor(metaclass=DummyObject): 6 | _backends = ["music"] 7 | 8 | def __init__(self, *args, **kwargs): 9 | requires_backends(self, ["music"]) 10 | 11 | 12 | class Pop2PianoTokenizer(metaclass=DummyObject): 13 | _backends = ["music"] 14 | 15 | def __init__(self, *args, **kwargs): 16 | requires_backends(self, ["music"]) 17 | -------------------------------------------------------------------------------- /larimar_base/pytorch_transformers/utils/dummy_speech_objects.py: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by the command `make fix-copies`, do not edit. 
2 | from ..utils import DummyObject, requires_backends 3 | 4 | 5 | class ASTFeatureExtractor(metaclass=DummyObject): 6 | _backends = ["speech"] 7 | 8 | def __init__(self, *args, **kwargs): 9 | requires_backends(self, ["speech"]) 10 | 11 | 12 | class Speech2TextFeatureExtractor(metaclass=DummyObject): 13 | _backends = ["speech"] 14 | 15 | def __init__(self, *args, **kwargs): 16 | requires_backends(self, ["speech"]) 17 | -------------------------------------------------------------------------------- /larimar_base/configs/ds_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_micro_batch_size_per_gpu": 2, 3 | "gradient_accumulation_steps": 4, 4 | "fp16": { 5 | "enabled": false, 6 | "min_loss_scale": 0.5, 7 | "fp16_scale_tolerance": 0.25, 8 | "opt_level": "O2" 9 | }, 10 | "bf16": { "enabled": true }, 11 | "zero_optimization": { 12 | "stage": 2, 13 | "offload_param": { 14 | "device": "cpu" 15 | }, 16 | "offload_optimizer": { 17 | "device": "cpu" 18 | }, 19 | "allgather_partitions": true, 20 | "allgather_bucket_size": 5e8, 21 | "contiguous_gradients": true 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /larimar_base/pytorch_transformers/tests/conftest.py: -------------------------------------------------------------------------------- 1 | # content of conftest.py 2 | 3 | import pytest 4 | 5 | 6 | def pytest_addoption(parser): 7 | parser.addoption( 8 | "--runslow", action="store_true", default=False, help="run slow tests" 9 | ) 10 | 11 | 12 | def pytest_collection_modifyitems(config, items): 13 | if config.getoption("--runslow"): 14 | # --runslow given in cli: do not skip slow tests 15 | return 16 | skip_slow = pytest.mark.skip(reason="need --runslow option to run") 17 | for item in items: 18 | if "slow" in item.keywords: 19 | item.add_marker(skip_slow) 20 | -------------------------------------------------------------------------------- /larimar_base/configs/default_accelerate_config.yaml: -------------------------------------------------------------------------------- 1 | command_file: null 2 | commands: null 3 | compute_environment: LOCAL_MACHINE 4 | deepspeed_config: 5 | deepspeed_config_file: configs/ds_config.json 6 | zero3_init_flag: false 7 | distributed_type: DEEPSPEED 8 | downcast_bf16: 'yes' 9 | dynamo_backend: 'NO' 10 | fsdp_config: {} 11 | gpu_ids: null 12 | machine_rank: 0 13 | main_process_ip: null 14 | main_process_port: null 15 | main_training_function: main 16 | megatron_lm_config: {} 17 | num_machines: 1 18 | num_processes: 1 19 | rdzv_backend: static 20 | same_network: true 21 | tpu_name: null 22 | tpu_zone: null 23 | use_cpu: false 24 | -------------------------------------------------------------------------------- /larimar_base/eval.sh: -------------------------------------------------------------------------------- 1 | mode=pyrite 2 | dataset=counterfact 3 | cache_dir=../cache 4 | checkpoint=../models/larimar-1.3b-c3.ckpt 5 | data_dir=../data/counterfact 6 | res_dir_name=../eval/results 7 | num_eval_cases=2000 8 | scope_detect_threshold=0.3 9 | 10 | # scope detection 11 | python counterfact_eval.py \ 12 | --mode ${mode} \ 13 | --dataset ${dataset} \ 14 | --cache_dir ${cache_dir} \ 15 | --checkpoint ${checkpoint} \ 16 | --data_dir ${data_dir} \ 17 | --res_dir_name ${res_dir_name} \ 18 | --num_eval_cases ${num_eval_cases} \ 19 | --scope_detect_threshold ${scope_detect_threshold} 20 | 21 | # no scope 22 | python counterfact_eval.py \ 23 | --mode ${mode} \ 24 | 
--dataset ${dataset} \ 25 | --cache_dir ${cache_dir} \ 26 | --checkpoint ${checkpoint} \ 27 | --data_dir ${data_dir} \ 28 | --res_dir_name ${res_dir_name} \ 29 | --num_eval_cases ${num_eval_cases} \ 30 | 31 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /larimar_base/pytorch_transformers/utils/dummy_essentia_and_librosa_and_pretty_midi_and_scipy_and_torch_objects.py: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by the command `make fix-copies`, do not edit. 2 | from ..utils import DummyObject, requires_backends 3 | 4 | 5 | class Pop2PianoFeatureExtractor(metaclass=DummyObject): 6 | _backends = ["essentia", "librosa", "pretty_midi", "scipy", "torch"] 7 | 8 | def __init__(self, *args, **kwargs): 9 | requires_backends(self, ["essentia", "librosa", "pretty_midi", "scipy", "torch"]) 10 | 11 | 12 | class Pop2PianoTokenizer(metaclass=DummyObject): 13 | _backends = ["essentia", "librosa", "pretty_midi", "scipy", "torch"] 14 | 15 | def __init__(self, *args, **kwargs): 16 | requires_backends(self, ["essentia", "librosa", "pretty_midi", "scipy", "torch"]) 17 | 18 | 19 | class Pop2PianoProcessor(metaclass=DummyObject): 20 | _backends = ["essentia", "librosa", "pretty_midi", "scipy", "torch"] 21 | 22 | def __init__(self, *args, **kwargs): 23 | requires_backends(self, ["essentia", "librosa", "pretty_midi", "scipy", "torch"]) 24 | -------------------------------------------------------------------------------- /larimar_base/pytorch_transformers/utils/bitsandbytes.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import warnings 15 | 16 | 17 | warnings.warn( 18 | "transformers.utils.bitsandbytes module is deprecated and will be removed in a future version. 
Please import bitsandbytes modules directly from transformers.integrations", 19 | FutureWarning, 20 | ) 21 | 22 | from ..integrations import ( # noqa 23 | get_keys_to_not_convert, 24 | replace_8bit_linear, 25 | replace_with_bnb_linear, 26 | set_module_8bit_tensor_to_device, 27 | set_module_quantized_tensor_to_device, 28 | ) 29 | -------------------------------------------------------------------------------- /larimar_base/eval_rephrase.sh: -------------------------------------------------------------------------------- 1 | mode=pyrite 2 | dataset=counterfact 3 | cache_dir=../cache 4 | checkpoint=../models/larimar-1.3b-c3.ckpt 5 | data_dir=../data/counterfact 6 | res_dir_name=../eval/results 7 | num_eval_cases=2000 8 | scope_detect_threshold=0.3 9 | 10 | # scope detection 11 | for num_rephrases in 0 1 2 12 | do 13 | python counterfact_eval_rephrase.py \ 14 | --mode ${mode} \ 15 | --dataset ${dataset} \ 16 | --cache_dir ${cache_dir} \ 17 | --checkpoint ${checkpoint} \ 18 | --data_dir ${data_dir} \ 19 | --res_dir_name ${res_dir_name} \ 20 | --num_eval_cases ${num_eval_cases} \ 21 | --num_rephrases ${num_rephrases} \ 22 | --remove_distraction \ 23 | --scope_detect_threshold ${scope_detect_threshold} 24 | done 25 | 26 | 27 | 28 | # no scope 29 | for num_rephrases in 0 1 2 30 | do 31 | python counterfact_eval_rephrase.py \ 32 | --mode ${mode} \ 33 | --dataset ${dataset} \ 34 | --cache_dir ${cache_dir} \ 35 | --checkpoint ${checkpoint} \ 36 | --data_dir ${data_dir} \ 37 | --res_dir_name ${res_dir_name} \ 38 | --num_eval_cases ${num_eval_cases} \ 39 | --num_rephrases ${num_rephrases} \ 40 | --remove_distraction 41 | done 42 | 43 | 44 | 45 | -------------------------------------------------------------------------------- /larimar_base/pytorch_transformers/configuration_roberta.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
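# RoBERTa keeps BERT's architecture, so RobertaConfig below subclasses BertConfig and overrides only the map of pretrained config archive URLs.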
16 | """ RoBERTa configuration """ 17 | 18 | from __future__ import (absolute_import, division, print_function, 19 | unicode_literals) 20 | 21 | import logging 22 | 23 | from .configuration_bert import BertConfig 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = { 28 | 'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-config.json", 29 | 'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-config.json", 30 | 'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-config.json", 31 | } 32 | 33 | 34 | class RobertaConfig(BertConfig): 35 | pretrained_config_archive_map = ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP 36 | -------------------------------------------------------------------------------- /larimar_base/modules/encoders/encoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | from ..utils import log_sum_exp 6 | 7 | class EncoderBase(nn.Module): 8 | """docstring for EncoderBase""" 9 | def __init__(self): 10 | super(EncoderBase, self).__init__() 11 | 12 | def forward(self, x): 13 | """ 14 | Args: 15 | x: (batch_size, *) 16 | Returns: the tensors required to parameterize a distribution. 17 | E.g. for Gaussian encoder it returns the mean and variance tensors 18 | """ 19 | 20 | raise NotImplementedError 21 | 22 | def sample(self, input, nsamples): 23 | """sampling from the encoder 24 | Returns: Tensor1 25 | Tensor1: the tensor latent z with shape [batch, nsamples, nz] 26 | """ 27 | 28 | raise NotImplementedError 29 | 30 | def encode(self, input, nsamples): 31 | """perform the encoding and compute the KL term 32 | Returns: Tensor1, Tensor2 33 | Tensor1: the tensor latent z with shape [batch, nsamples, nz] 34 | Tensor2: the tenor of KL for each x with shape [batch] 35 | """ 36 | 37 | raise NotImplementedError 38 | 39 | 40 | def eval_inference_dist(self, x, z, param=None): 41 | """this function computes log q(z | x) 42 | Args: 43 | z: tensor 44 | different z points that will be evaluated, with 45 | shape [batch, nsamples, nz] 46 | Returns: Tensor1 47 | Tensor1: log q(z|x) with shape [batch, nsamples] 48 | """ 49 | 50 | raise NotImplementedError 51 | 52 | def calc_mi(self, x): 53 | """Approximate the mutual information between x and z 54 | I(x, z) = E_xE_{q(z|x)}log(q(z|x)) - E_xE_{q(z|x)}log(q(z)) 55 | Returns: Float 56 | """ 57 | 58 | raise NotImplementedError -------------------------------------------------------------------------------- /larimar_base/pytorch_transformers/tests/tokenization_dilbert_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | from io import open 20 | 21 | from pytorch_transformers.tokenization_distilbert import (DistilBertTokenizer) 22 | 23 | from .tokenization_tests_commons import CommonTestCases 24 | from .tokenization_bert_test import BertTokenizationTest 25 | 26 | class DistilBertTokenizationTest(BertTokenizationTest): 27 | 28 | tokenizer_class = DistilBertTokenizer 29 | 30 | def get_tokenizer(self, **kwargs): 31 | return DistilBertTokenizer.from_pretrained(self.tmpdirname, **kwargs) 32 | 33 | def test_sequence_builders(self): 34 | tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased") 35 | 36 | text = tokenizer.encode("sequence builders") 37 | text_2 = tokenizer.encode("multi-sequence build") 38 | 39 | encoded_sentence = tokenizer.add_special_tokens_single_sentence(text) 40 | encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2) 41 | 42 | assert encoded_sentence == [101] + text + [102] 43 | assert encoded_pair == [101] + text + [102] + text_2 + [102] 44 | 45 | if __name__ == '__main__': 46 | unittest.main() 47 | -------------------------------------------------------------------------------- /larimar_base/pytorch_transformers/tests/tokenization_auto_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
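# Smoke test for AutoTokenizer dispatch: a bert-* checkpoint name must yield a BertTokenizer, a gpt2-* name a GPT2Tokenizer.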
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import shutil 21 | import pytest 22 | import logging 23 | 24 | from pytorch_transformers import AutoTokenizer, BertTokenizer, AutoTokenizer, GPT2Tokenizer 25 | from pytorch_transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP 26 | from pytorch_transformers.modeling_gpt2 import GPT2_PRETRAINED_MODEL_ARCHIVE_MAP 27 | 28 | 29 | class AutoTokenizerTest(unittest.TestCase): 30 | def test_tokenizer_from_pretrained(self): 31 | logging.basicConfig(level=logging.INFO) 32 | for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 33 | tokenizer = AutoTokenizer.from_pretrained(model_name) 34 | self.assertIsNotNone(tokenizer) 35 | self.assertIsInstance(tokenizer, BertTokenizer) 36 | self.assertGreater(len(tokenizer), 0) 37 | 38 | for model_name in list(GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 39 | tokenizer = AutoTokenizer.from_pretrained(model_name) 40 | self.assertIsNotNone(tokenizer) 41 | self.assertIsInstance(tokenizer, GPT2Tokenizer) 42 | self.assertGreater(len(tokenizer), 0) 43 | 44 | 45 | if __name__ == "__main__": 46 | unittest.main() 47 | -------------------------------------------------------------------------------- /larimar_base/pytorch_transformers/tests/tokenization_utils_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 HuggingFace Inc.. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import six 21 | 22 | from pytorch_transformers import PreTrainedTokenizer 23 | from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer 24 | 25 | class TokenizerUtilsTest(unittest.TestCase): 26 | def check_tokenizer_from_pretrained(self, tokenizer_class): 27 | s3_models = list(tokenizer_class.max_model_input_sizes.keys()) 28 | for model_name in s3_models[:1]: 29 | tokenizer = tokenizer_class.from_pretrained(model_name) 30 | self.assertIsNotNone(tokenizer) 31 | self.assertIsInstance(tokenizer, tokenizer_class) 32 | self.assertIsInstance(tokenizer, PreTrainedTokenizer) 33 | 34 | for special_tok in tokenizer.all_special_tokens: 35 | if six.PY2: 36 | self.assertIsInstance(special_tok, unicode) 37 | else: 38 | self.assertIsInstance(special_tok, str) 39 | special_tok_id = tokenizer.convert_tokens_to_ids(special_tok) 40 | self.assertIsInstance(special_tok_id, int) 41 | 42 | def test_pretrained_tokenizers(self): 43 | self.check_tokenizer_from_pretrained(GPT2Tokenizer) 44 | 45 | if __name__ == "__main__": 46 | unittest.main() 47 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Larimar 2 | This repo contains a reference implementation of the paper 3 | [Larimar: Large Language Models with Episodic Memory Control](https://research.ibm.com/publications/larimar-large-language-models-with-episodic-memory-control). 4 | 5 |

6 | ![Larimar architecture](larimar_architecture.png)
7 |

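The memory behaves like a linear key-value store over the latent space: an episode of encoded sentences is written into a memory matrix as a least-squares solution, and queries are read back through inferred addressing weights. The toy below is an editor's sketch of that idea, not the implementation in `larimar_base/modules`; only `latent_size: 768` and `memory_size: 512` come from `configs/config_train_larimar.yaml`, while the random addressing keys and single-shot read are illustrative simplifications.

```
import torch

latent_size, memory_size = 768, 512  # from configs/config_train_larimar.yaml

def write_episode(Z: torch.Tensor) -> torch.Tensor:
    """Write episode encodings Z (n_episodes, latent_size) into memory.

    Toy scheme: draw addressing keys W, then store the least-squares
    solution M of W @ M = Z using the Moore-Penrose pseudoinverse.
    """
    W = torch.randn(Z.shape[0], memory_size)  # hypothetical addressing keys
    return torch.linalg.pinv(W) @ Z           # M: (memory_size, latent_size)

def read(z_query: torch.Tensor, M: torch.Tensor) -> torch.Tensor:
    """Read: infer addressing weights for the query, then decode w @ M."""
    w = z_query @ torch.linalg.pinv(M)        # least-squares addressing weights
    return w @ M

Z = torch.randn(16, latent_size)              # one episode of 16 encodings
M = write_episode(Z)
print(torch.allclose(read(Z[0], M), Z[0], atol=1e-3))  # reads back ~exactly
```

Writing by pseudoinverse is what makes memory updates cheap in this family of models: editing is a least-squares solve at inference time, not gradient-based retraining.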
8 |
9 |
10 | ## Install
11 |
12 | ```
13 | conda create --prefix envs/larimar python=3.10 -y
14 | conda activate envs/larimar
15 | pip install -r requirements.txt
16 | python -m nltk.downloader punkt_tab
17 | python -m spacy download en_core_web_sm
18 | ```
19 |
20 | ## Use
21 | Single-fact editing demo notebook: `larimar_base/single_fact_editing_demo.ipynb`.
22 |
23 | It assumes a trained `larimar-1.3b` model checkpoint is available as `../models/larimar-1.3b-c3.ckpt`; see the training instructions below.
24 |
25 |
26 | ## Train
27 | To train a `larimar-1.3b` model, first download and extract under `../data` [this dataset tarball](https://ibm.box.com/shared/static/d90td7ycpv3u9jw4i1mecv1mt24heq3t.gz), then:
28 |
29 | ```
30 | cd larimar_base/
31 | bash train_larimar.sh
32 | ```
33 | This trains the model with configuration C3 as in the paper. Adjust to your environment by editing the related entries in `train_larimar.sh` and `configs/config_train_larimar.yaml` before launching.
34 |
35 |
36 | ## Evaluate
37 | Choose the Larimar model to evaluate in `eval.sh` and `eval_rephrase.sh`, then run:
38 |
39 | ```
40 | cd larimar_base/
41 | bash eval.sh
42 | bash eval_rephrase.sh
43 | ```
44 |
45 |
46 | ## Citation
47 | ```
48 | @misc{das2024larimarlargelanguagemodels,
49 |       title={Larimar: Large Language Models with Episodic Memory Control},
50 |       author={Payel Das and Subhajit Chaudhury and Elliot Nelson and Igor Melnyk and Sarath Swaminathan and Sihui Dai and Aurélie Lozano and Georgios Kollias and Vijil Chenthamarakshan and Jiří Navrátil and Soham Dan and Pin-Yu Chen},
51 |       year={2024},
52 |       eprint={2403.11901},
53 |       archivePrefix={arXiv},
54 |       primaryClass={cs.LG},
55 |       url={https://arxiv.org/abs/2403.11901},
56 | }
57 | ```
58 |
--------------------------------------------------------------------------------
/larimar_base/modules/decoders/decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class DecoderBase(nn.Module):
6 |     """docstring for Decoder"""
7 |     def __init__(self):
8 |         super(DecoderBase, self).__init__()
9 |
10 |
11 |     def freeze(self):
12 |         for param in self.parameters():
13 |             param.requires_grad = False
14 |
15 |     def decode(self, x, z):
16 |         """
17 |         Args:
18 |             x: (batch_size, seq_len)
19 |             z: (batch_size, n_sample, nz)
20 |         Returns: Tensor1
21 |             Tensor1: the output logits with size (batch_size * n_sample, seq_len, vocab_size)
22 |         """
23 |
24 |         raise NotImplementedError
25 |
26 |     def reconstruct_error(self, x, z):
27 |         """reconstruction loss
28 |         Args:
29 |             x: (batch_size, *)
30 |             z: (batch_size, n_sample, nz)
31 |         Returns:
32 |             loss: (batch_size, n_sample). Loss
33 |             across different sentences and z
34 |         """
35 |
36 |         raise NotImplementedError
37 |
38 |     def beam_search_decode(self, z, K):
39 |         """beam search decoding
40 |         Args:
41 |             z: (batch_size, nz)
42 |             K: the beam size
43 |         Returns: List1
44 |             List1: the decoded word sentence list
45 |         """
46 |
47 |         raise NotImplementedError
48 |
49 |     def sample_decode(self, z):
50 |         """sampling from z
51 |         Args:
52 |             z: (batch_size, nz)
53 |         Returns: List1
54 |             List1: the decoded word sentence list
55 |         """
56 |
57 |         raise NotImplementedError
58 |
59 |     def greedy_decode(self, z):
60 |         """greedy decoding from z
61 |         Args:
62 |             z: (batch_size, nz)
63 |         Returns: List1
64 |             List1: the decoded word sentence list
65 |         """
66 |
67 |         raise NotImplementedError
68 |
69 |     def log_probability(self, x, z):
70 |         """
71 |         Args:
72 |             x: (batch_size, *)
73 |             z: (batch_size, n_sample, nz)
74 |         Returns:
75 |             log_p: (batch_size, n_sample).
76 |                 log_p(x|z) across different x and z
77 |         """
78 |
79 |         raise NotImplementedError
--------------------------------------------------------------------------------
/larimar_base/pytorch_transformers/dependency_versions_check.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
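# Import-time guard: each package in pkgs_to_check_at_runtime is version-checked against the pins in dependency_versions_table; tokenizers and accelerate are only checked when actually installed.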
14 | 15 | from .dependency_versions_table import deps 16 | from .utils.versions import require_version, require_version_core 17 | 18 | 19 | # define which module versions we always want to check at run time 20 | # (usually the ones defined in `install_requires` in setup.py) 21 | # 22 | # order specific notes: 23 | # - tqdm must be checked before tokenizers 24 | 25 | pkgs_to_check_at_runtime = [ 26 | "python", 27 | "tqdm", 28 | "regex", 29 | "requests", 30 | "packaging", 31 | "filelock", 32 | "numpy", 33 | "tokenizers", 34 | "huggingface-hub", 35 | "safetensors", 36 | "accelerate", 37 | "pyyaml", 38 | ] 39 | 40 | for pkg in pkgs_to_check_at_runtime: 41 | if pkg in deps: 42 | if pkg == "tokenizers": 43 | # must be loaded here, or else tqdm check may fail 44 | from .utils import is_tokenizers_available 45 | 46 | if not is_tokenizers_available(): 47 | continue # not required, check version only if installed 48 | elif pkg == "accelerate": 49 | # must be loaded here, or else tqdm check may fail 50 | from .utils import is_accelerate_available 51 | 52 | # Maybe switch to is_torch_available in the future here so that Accelerate is hard dep of 53 | # Transformers with PyTorch 54 | if not is_accelerate_available(): 55 | continue # not required, check version only if installed 56 | 57 | require_version_core(deps[pkg]) 58 | else: 59 | raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py") 60 | 61 | 62 | def dep_version_check(pkg, hint=None): 63 | require_version(deps[pkg], hint) 64 | -------------------------------------------------------------------------------- /larimar_base/configs/config_train_larimar.yaml: -------------------------------------------------------------------------------- 1 | ## MODEL 2 | model: 3 | # Encoder 4 | encoder_model_type: bert 5 | encoder_model_name_or_path: bert-base-cased 6 | cache_dir: '../cache' 7 | load_pretrained: false 8 | 9 | # Decoder 10 | decoder_model_type: gpt2 11 | decoder_model_name_or_path: gpt2 12 | 13 | # Auto-encoder 14 | latent_size: 768 15 | do_lower_case: false 16 | block_size: 64 17 | 18 | # Memory 19 | memory_size: 512 20 | direct_writing: true 21 | ordering: false 22 | pseudoinverse_approx_step: 15 23 | episode_sizes: [16] 24 | observation_noise_std: 0.000001 25 | identity: true 26 | w_logvar_setting: 3 27 | deterministic_w: false 28 | 29 | # Training 30 | learning_rate: 5e-5 31 | adam_epsilon: 1e-8 32 | warmup_steps: 0 33 | weight_decay: 0.0 34 | mlm: false 35 | mlm_probability: 0.15 36 | dim_target_kl: 0 37 | length_weighted_loss: false 38 | rec_strength: 1.0 39 | ae_strength: 1.0 40 | l2_strength: 0 41 | decode_rec_strength: 0.0 42 | beta: 0.5 43 | use_beta_schedule: true 44 | ratio_increase: 0.25 45 | ratio_zero: 0.5 46 | fb_mode: 1 47 | optimizer: adamw # or fusedadam or deepspeed 48 | 49 | # Evaluation 50 | bleu: false 51 | ae_only: true 52 | ae_read_write: true 53 | num_samples: 100 54 | read_iters: 1 55 | perturb: "" 56 | 57 | # Sampling 58 | temperature: 1 59 | top_k: 0 60 | top_p: 1 61 | 62 | ## DATA 63 | data: 64 | train_data_file: '../data/wikipedia/train.txt' 65 | eval_data_file: '../data/wikipedia/test.txt' 66 | num_data_workers: 4 67 | train_batch_size: 64 68 | eval_batch_size: 64 69 | max_seq_length: 512 70 | batches_per_bucket: 100 71 | use_labels: 0 72 | dataset: 'Wikipedia' 73 | use_philly: false # action='store_true' 74 | 75 | trainer: 76 | max_epochs: 4 77 | limit_val_batches: 0 78 | reload_dataloaders_every_n_epochs: 1 # to ensure reshuffling of data buckets 79 | limit_val_batches: 0 # don't 
run eval during training 80 | num_sanity_val_steps : 0 # don't run eval at the beginning 81 | default_root_dir: '../train/larimar/checkpoints/bert-base-cased-gpt2-wiki' 82 | callbacks: 83 | class_path: 'lightning.pytorch.callbacks.ModelCheckpoint' 84 | init_args: 85 | # every_n_epochs: 1 86 | # save_top_k: 3 87 | monitor: train/LOSS 88 | logger: 89 | class_path: 'lightning.pytorch.loggers.TensorBoardLogger' 90 | init_args: 91 | save_dir: '../train/larimar' 92 | name: '' 93 | -------------------------------------------------------------------------------- /larimar_base/pytorch_transformers/model_parallel_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The HuggingFace Team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from math import ceil 17 | 18 | 19 | def assert_device_map(device_map, num_blocks): 20 | blocks = list(range(0, num_blocks)) 21 | 22 | device_map_blocks = [item for sublist in list(device_map.values()) for item in sublist] 23 | 24 | # Duplicate check 25 | duplicate_blocks = [] 26 | for i in device_map_blocks: 27 | if device_map_blocks.count(i) > 1 and i not in duplicate_blocks: 28 | duplicate_blocks.append(i) 29 | # Missing blocks 30 | missing_blocks = [i for i in blocks if i not in device_map_blocks] 31 | extra_blocks = [i for i in device_map_blocks if i not in blocks] 32 | 33 | if len(duplicate_blocks) != 0: 34 | raise ValueError( 35 | "Duplicate attention blocks specified in device_map. Attention blocks must be specified to one device." 36 | " These attention blocks were specified more than once: " + str(duplicate_blocks) 37 | ) 38 | if len(missing_blocks) != 0: 39 | raise ValueError( 40 | "There are attention blocks for this model that are not specified in the device_map. Add these attention " 41 | "blocks to a device on the device_map: " + str(missing_blocks) 42 | ) 43 | if len(extra_blocks) != 0: 44 | raise ValueError( 45 | "The device_map contains more attention blocks than this model has. Remove these from the device_map:" 46 | + str(extra_blocks) 47 | ) 48 | 49 | 50 | def get_device_map(n_layers, devices): 51 | """Returns a dictionary of layers distributed evenly across all devices.""" 52 | layers = list(range(n_layers)) 53 | n_blocks = int(ceil(n_layers / len(devices))) 54 | layers_list = [layers[i : i + n_blocks] for i in range(0, n_layers, n_blocks)] 55 | 56 | return dict(zip(devices, layers_list)) 57 | -------------------------------------------------------------------------------- /larimar_base/pytorch_transformers/utils/model_parallel_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The HuggingFace Team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from math import ceil 17 | 18 | 19 | def assert_device_map(device_map, num_blocks): 20 | blocks = list(range(0, num_blocks)) 21 | 22 | device_map_blocks = [item for sublist in list(device_map.values()) for item in sublist] 23 | 24 | # Duplicate check 25 | duplicate_blocks = [] 26 | for i in device_map_blocks: 27 | if device_map_blocks.count(i) > 1 and i not in duplicate_blocks: 28 | duplicate_blocks.append(i) 29 | # Missing blocks 30 | missing_blocks = [i for i in blocks if i not in device_map_blocks] 31 | extra_blocks = [i for i in device_map_blocks if i not in blocks] 32 | 33 | if len(duplicate_blocks) != 0: 34 | raise ValueError( 35 | "Duplicate attention blocks specified in device_map. Attention blocks must be specified to one device." 36 | " These attention blocks were specified more than once: " + str(duplicate_blocks) 37 | ) 38 | if len(missing_blocks) != 0: 39 | raise ValueError( 40 | "There are attention blocks for this model that are not specified in the device_map. Add these attention " 41 | "blocks to a device on the device_map: " + str(missing_blocks) 42 | ) 43 | if len(extra_blocks) != 0: 44 | raise ValueError( 45 | "The device_map contains more attention blocks than this model has. Remove these from the device_map:" 46 | + str(extra_blocks) 47 | ) 48 | 49 | 50 | def get_device_map(n_layers, devices): 51 | """Returns a dictionary of layers distributed evenly across all devices.""" 52 | layers = list(range(n_layers)) 53 | n_blocks = int(ceil(n_layers / len(devices))) 54 | layers_list = [layers[i : i + n_blocks] for i in range(0, n_layers, n_blocks)] 55 | 56 | return dict(zip(devices, layers_list)) 57 | -------------------------------------------------------------------------------- /larimar_base/pytorch_transformers/tests/configuration_common_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 HuggingFace Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import copy 20 | import os 21 | import shutil 22 | import json 23 | import random 24 | import uuid 25 | 26 | import unittest 27 | import logging 28 | 29 | 30 | class ConfigTester(object): 31 | def __init__(self, parent, config_class=None, **kwargs): 32 | self.parent = parent 33 | self.config_class = config_class 34 | self.inputs_dict = kwargs 35 | 36 | def create_and_test_config_common_properties(self): 37 | config = self.config_class(**self.inputs_dict) 38 | self.parent.assertTrue(hasattr(config, 'vocab_size')) 39 | self.parent.assertTrue(hasattr(config, 'hidden_size')) 40 | self.parent.assertTrue(hasattr(config, 'num_attention_heads')) 41 | self.parent.assertTrue(hasattr(config, 'num_hidden_layers')) 42 | 43 | def create_and_test_config_to_json_string(self): 44 | config = self.config_class(**self.inputs_dict) 45 | obj = json.loads(config.to_json_string()) 46 | for key, value in self.inputs_dict.items(): 47 | self.parent.assertEqual(obj[key], value) 48 | 49 | def create_and_test_config_to_json_file(self): 50 | config_first = self.config_class(**self.inputs_dict) 51 | json_file_path = os.path.join(os.getcwd(), "config_" + str(uuid.uuid4()) + ".json") 52 | config_first.to_json_file(json_file_path) 53 | config_second = self.config_class.from_json_file(json_file_path) 54 | os.remove(json_file_path) 55 | self.parent.assertEqual(config_second.to_dict(), config_first.to_dict()) 56 | 57 | def run_common_tests(self): 58 | self.create_and_test_config_common_properties() 59 | self.create_and_test_config_to_json_string() 60 | self.create_and_test_config_to_json_file() 61 | 62 | if __name__ == "__main__": 63 | unittest.main() -------------------------------------------------------------------------------- /larimar_base/pytorch_transformers/tokenization_distilbert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
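# DistilBERT reuses BERT's WordPiece vocabulary, so DistilBertTokenizer is a thin BertTokenizer subclass that only redefines the pretrained vocab URLs and max input sizes.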
15 | """Tokenization classes for DistilBERT.""" 16 | 17 | from __future__ import absolute_import, division, print_function, unicode_literals 18 | 19 | import collections 20 | import logging 21 | import os 22 | import unicodedata 23 | from io import open 24 | 25 | from .tokenization_bert import BertTokenizer 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'} 30 | 31 | PRETRAINED_VOCAB_FILES_MAP = { 32 | 'vocab_file': 33 | { 34 | 'distilbert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", 35 | 'distilbert-base-uncased-distilled-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", 36 | } 37 | } 38 | 39 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 40 | 'distilbert-base-uncased': 512, 41 | 'distilbert-base-uncased-distilled-squad': 512, 42 | } 43 | 44 | 45 | class DistilBertTokenizer(BertTokenizer): 46 | r""" 47 | Constructs a DistilBertTokenizer. 48 | :class:`~pytorch_transformers.DistilBertTokenizer` is identical to BertTokenizer and runs end-to-end tokenization: punctuation splitting + wordpiece 49 | 50 | Args: 51 | vocab_file: Path to a one-wordpiece-per-line vocabulary file 52 | do_lower_case: Whether to lower case the input. Only has an effect when do_wordpiece_only=False 53 | do_basic_tokenize: Whether to do basic tokenization before wordpiece. 54 | max_len: An artificial maximum length to truncate tokenized sequences to; Effective maximum length is always the 55 | minimum of this value (if specified) and the underlying BERT model's sequence length. 56 | never_split: List of tokens which will never be split during tokenization. Only has an effect when 57 | do_wordpiece_only=False 58 | """ 59 | 60 | vocab_files_names = VOCAB_FILES_NAMES 61 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 62 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 63 | -------------------------------------------------------------------------------- /larimar_base/pytorch_transformers/convert_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Convert BERT checkpoint.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import argparse 22 | import torch 23 | 24 | from pytorch_transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert 25 | 26 | import logging 27 | logging.basicConfig(level=logging.INFO) 28 | 29 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): 30 | # Initialise PyTorch model 31 | config = BertConfig.from_json_file(bert_config_file) 32 | print("Building PyTorch model from configuration: {}".format(str(config))) 33 | model = BertForPreTraining(config) 34 | 35 | # Load weights from tf checkpoint 36 | load_tf_weights_in_bert(model, config, tf_checkpoint_path) 37 | 38 | # Save pytorch-model 39 | print("Save PyTorch model to {}".format(pytorch_dump_path)) 40 | torch.save(model.state_dict(), pytorch_dump_path) 41 | 42 | 43 | if __name__ == "__main__": 44 | parser = argparse.ArgumentParser() 45 | ## Required parameters 46 | parser.add_argument("--tf_checkpoint_path", 47 | default = None, 48 | type = str, 49 | required = True, 50 | help = "Path to the TensorFlow checkpoint path.") 51 | parser.add_argument("--bert_config_file", 52 | default = None, 53 | type = str, 54 | required = True, 55 | help = "The config json file corresponding to the pre-trained BERT model. \n" 56 | "This specifies the model architecture.") 57 | parser.add_argument("--pytorch_dump_path", 58 | default = None, 59 | type = str, 60 | required = True, 61 | help = "Path to the output PyTorch model.") 62 | args = parser.parse_args() 63 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, 64 | args.bert_config_file, 65 | args.pytorch_dump_path) 66 | -------------------------------------------------------------------------------- /larimar_base/main_pl.py: -------------------------------------------------------------------------------- 1 | import lightning 2 | from lightning.pytorch.cli import LightningCLI 3 | from lightning_model import MemNetLight 4 | from lightning_data import DataModule 5 | import os 6 | import subprocess 7 | 8 | 9 | def fix_infiniband(): 10 | ibv = subprocess.run('ibv_devinfo', stdout=subprocess.PIPE, stderr=subprocess.PIPE) 11 | lines = ibv.stdout.decode('utf-8').split('\n') 12 | exclude = '' 13 | for line in lines: 14 | if 'hca_id:' in line: 15 | name = line.split(':')[1].strip() 16 | if '\tport:' in line: 17 | port = line.split(':')[1].strip() 18 | if 'link_layer:' in line and 'Ethernet' in line: 19 | exclude = exclude + f'{name}:{port},' 20 | 21 | if exclude: 22 | exclude = '^' + exclude[:-1] 23 | print(exclude) 24 | os.environ['NCCL_IB_HCA'] = exclude 25 | 26 | 27 | def set_env(master_port): 28 | LSB_MCPU_HOSTS = os.environ["LSB_MCPU_HOSTS"].split(' ') # Parses Node list set by LSF, in format hostname proceeded by number of cores requested 29 | HOST_LIST = LSB_MCPU_HOSTS[::2] # Strips the cores per node items in the list 30 | os.environ["MASTER_ADDR"] = HOST_LIST[0] # Sets the MasterNode to thefirst node on the list of hosts 31 | os.environ["MASTER_PORT"] = master_port 32 | os.environ["NODE_RANK"] = str(HOST_LIST.index(os.environ["HOSTNAME"])) # Uses the list index for node rank, master node rank must be 0 33 | os.environ["NCCL_SOCKET_IFNAME"] = 'ib,bond' #"^docker0,lo" # avoids using docker of loopback interface 34 | os.environ["NCCL_DEBUG"] = "INFO" # sets NCCL debug to info, during distributed training, bugs in code show up as nccl errors 35 | 
os.environ["NCCL_IB_CUDA_SUPPORT"] = '1' # Force use of infiniband 36 | 37 | 38 | class MyLightningCLI(LightningCLI): 39 | def add_arguments_to_parser(self, parser): 40 | parser.link_arguments("model.block_size", "data.block_size") 41 | parser.link_arguments("model.perturb", "data.perturb") 42 | parser.link_arguments("model.encoder_model_type", "data.encoder_model_type") 43 | parser.link_arguments("model.encoder_model_name_or_path", "data.encoder_model_name_or_path") 44 | parser.link_arguments("model.decoder_model_type", "data.decoder_model_type") 45 | parser.link_arguments("model.decoder_model_name_or_path", "data.decoder_model_name_or_path") 46 | parser.link_arguments("model.cache_dir", "data.cache_dir") 47 | parser.link_arguments("model.do_lower_case", "data.do_lower_case") 48 | 49 | 50 | def cli_main(): 51 | MyLightningCLI(model_class=MemNetLight, datamodule_class=DataModule, save_config_kwargs={"overwrite": True}) 52 | #lightning.Trainer 53 | #lightning.pytorch.callbacks.ModelCheckpoint 54 | 55 | 56 | if __name__ == "__main__": 57 | fix_infiniband() 58 | set_env('53108') 59 | cli_main() 60 | -------------------------------------------------------------------------------- /larimar_base/pytorch_transformers/tests/tokenization_openai_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | import json 20 | 21 | from pytorch_transformers.tokenization_openai import OpenAIGPTTokenizer, VOCAB_FILES_NAMES 22 | 23 | from .tokenization_tests_commons import CommonTestCases 24 | 25 | 26 | class OpenAIGPTTokenizationTest(CommonTestCases.CommonTokenizerTester): 27 | 28 | tokenizer_class = OpenAIGPTTokenizer 29 | 30 | def setUp(self): 31 | super(OpenAIGPTTokenizationTest, self).setUp() 32 | 33 | # Adapted from Sennrich et al. 
2015 and https://github.com/rsennrich/subword-nmt
34 |         vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
35 |                  "w</w>", "r</w>", "t</w>",
36 |                  "lo", "low", "er</w>",
37 |                  "low</w>", "lowest</w>", "newer</w>", "wider</w>", "<unk>"]
38 |         vocab_tokens = dict(zip(vocab, range(len(vocab))))
39 |         merges = ["#version: 0.2", "l o", "lo w", "e r</w>", ""]
40 |
41 |         self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
42 |         self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file'])
43 |         with open(self.vocab_file, "w") as fp:
44 |             fp.write(json.dumps(vocab_tokens))
45 |         with open(self.merges_file, "w") as fp:
46 |             fp.write("\n".join(merges))
47 |
48 |     def get_tokenizer(self, **kwargs):
49 |         return OpenAIGPTTokenizer.from_pretrained(self.tmpdirname, **kwargs)
50 |
51 |     def get_input_output_texts(self):
52 |         input_text = u"lower newer"
53 |         output_text = u"lower newer"
54 |         return input_text, output_text
55 |
56 |
57 |     def test_full_tokenizer(self):
58 |         tokenizer = OpenAIGPTTokenizer(self.vocab_file, self.merges_file)
59 |
60 |         text = "lower"
61 |         bpe_tokens = ["low", "er</w>"]
62 |         tokens = tokenizer.tokenize(text)
63 |         self.assertListEqual(tokens, bpe_tokens)
64 |
65 |         input_tokens = tokens + ["<unk>"]
66 |         input_bpe_tokens = [14, 15, 20]
67 |         self.assertListEqual(
68 |             tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
69 |
70 |
71 | if __name__ == '__main__':
72 |     unittest.main()
73 |
--------------------------------------------------------------------------------
/larimar_base/pytorch_transformers/tests/tokenization_transfo_xl_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
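# TransfoXLTokenizer is word-level rather than subword: the tests below cover vocab lookups with <unk> fallback and the lower_case option.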
15 | from __future__ import absolute_import, division, print_function, unicode_literals
16 |
17 | import os
18 | import unittest
19 | from io import open
20 |
21 | from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer, VOCAB_FILES_NAMES
22 |
23 | from .tokenization_tests_commons import CommonTestCases
24 |
25 | class TransfoXLTokenizationTest(CommonTestCases.CommonTokenizerTester):
26 |
27 |     tokenizer_class = TransfoXLTokenizer
28 |
29 |     def setUp(self):
30 |         super(TransfoXLTokenizationTest, self).setUp()
31 |
32 |         vocab_tokens = [
33 |             "<unk>", "[CLS]", "[SEP]", "want", "unwanted", "wa", "un",
34 |             "running", ",", "low", "l",
35 |         ]
36 |         self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
37 |         with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer:
38 |             vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
39 |
40 |     def get_tokenizer(self, **kwargs):
41 |         kwargs['lower_case'] = True
42 |         return TransfoXLTokenizer.from_pretrained(self.tmpdirname, **kwargs)
43 |
44 |     def get_input_output_texts(self):
45 |         input_text = u"<unk> UNwanted , running"
46 |         output_text = u"<unk> unwanted, running"
47 |         return input_text, output_text
48 |
49 |     def test_full_tokenizer(self):
50 |         tokenizer = TransfoXLTokenizer(vocab_file=self.vocab_file, lower_case=True)
51 |
52 |         tokens = tokenizer.tokenize(u"<unk> UNwanted , running")
53 |         self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])
54 |
55 |         self.assertListEqual(
56 |             tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])
57 |
58 |     def test_full_tokenizer_lower(self):
59 |         tokenizer = TransfoXLTokenizer(lower_case=True)
60 |
61 |         self.assertListEqual(
62 |             tokenizer.tokenize(u" \tHeLLo ! how \n Are yoU ? "),
63 |             ["hello", "!", "how", "are", "you", "?"])
64 |
65 |     def test_full_tokenizer_no_lower(self):
66 |         tokenizer = TransfoXLTokenizer(lower_case=False)
67 |
68 |         self.assertListEqual(
69 |             tokenizer.tokenize(u" \tHeLLo ! how \n Are yoU ? "),
70 |             ["HeLLo", "!", "how", "Are", "yoU", "?"])
71 |
72 |
73 | if __name__ == '__main__':
74 |     unittest.main()
75 |
--------------------------------------------------------------------------------
/larimar_base/pytorch_transformers/tests/tokenization_gpt2_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import, division, print_function, unicode_literals
16 |
17 | import os
18 | import unittest
19 | import json
20 | from io import open
21 |
22 | from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer, VOCAB_FILES_NAMES
23 |
24 | from .tokenization_tests_commons import CommonTestCases
25 |
26 | class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester):
27 |
28 |     tokenizer_class = GPT2Tokenizer
29 |
30 |     def setUp(self):
31 |         super(GPT2TokenizationTest, self).setUp()
32 |
33 |         # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
2015 and https://github.com/rsennrich/subword-nmt 34 | vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n", 35 | "\u0120", "\u0120l", "\u0120n", 36 | "\u0120lo", "\u0120low", "er", 37 | "\u0120lowest", "\u0120newer", "\u0120wider", "<unk>"] 38 | vocab_tokens = dict(zip(vocab, range(len(vocab)))) 39 | merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""] 40 | self.special_tokens_map = {"unk_token": "<unk>"} 41 | 42 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) 43 | self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file']) 44 | with open(self.vocab_file, "w", encoding="utf-8") as fp: 45 | fp.write(json.dumps(vocab_tokens) + "\n") 46 | with open(self.merges_file, "w", encoding="utf-8") as fp: 47 | fp.write("\n".join(merges)) 48 | 49 | def get_tokenizer(self, **kwargs): 50 | kwargs.update(self.special_tokens_map) 51 | return GPT2Tokenizer.from_pretrained(self.tmpdirname, **kwargs) 52 | 53 | def get_input_output_texts(self): 54 | input_text = u"lower newer" 55 | output_text = u" lower newer" 56 | return input_text, output_text 57 | 58 | def test_full_tokenizer(self): 59 | tokenizer = GPT2Tokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map) 60 | text = "lower newer" 61 | bpe_tokens = ["\u0120low", "er", "\u0120", "n", "e", "w", "er"] 62 | tokens = tokenizer.tokenize(text) 63 | self.assertListEqual(tokens, bpe_tokens) 64 | 65 | input_tokens = tokens + [tokenizer.unk_token] 66 | input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19] 67 | self.assertListEqual( 68 | tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) 69 | 70 | 71 | if __name__ == '__main__': 72 | unittest.main() 73 | -------------------------------------------------------------------------------- /larimar_base/pytorch_transformers/convert_xlm_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
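# Example invocation (illustrative paths; the input is a checkpoint saved by the original facebookresearch/XLM code):
#   python convert_xlm_checkpoint_to_pytorch.py \
#       --xlm_checkpoint_path ./mlm_en_2048.pth \
#       --pytorch_dump_folder_path ./xlm-pytorch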
15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | import json 21 | from io import open 22 | 23 | import torch 24 | import numpy 25 | 26 | from pytorch_transformers import CONFIG_NAME, WEIGHTS_NAME 27 | from pytorch_transformers.tokenization_xlm import VOCAB_FILES_NAMES 28 | 29 | import logging 30 | logging.basicConfig(level=logging.INFO) 31 | 32 | def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_path): 33 | # Load checkpoint 34 | chkpt = torch.load(xlm_checkpoint_path, map_location='cpu') 35 | 36 | model = chkpt['model'] 37 | 38 | config = chkpt['params'] 39 | config = dict((n, v) for n, v in config.items() if not isinstance(v, (torch.FloatTensor, numpy.ndarray))) 40 | 41 | vocab = chkpt['dico_word2id'] 42 | vocab = dict((s + '' if s.find('@@') == -1 and i > 13 else s.replace('@@', ''), i) for s, i in vocab.items()) 43 | 44 | # Save pytorch-model 45 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME 46 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME 47 | pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_FILES_NAMES['vocab_file'] 48 | 49 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 50 | torch.save(model, pytorch_weights_dump_path) 51 | 52 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 53 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 54 | f.write(json.dumps(config, indent=2) + "\n") 55 | 56 | print("Save vocab file to {}".format(pytorch_config_dump_path)) 57 | with open(pytorch_vocab_dump_path, "w", encoding="utf-8") as f: 58 | f.write(json.dumps(vocab, indent=2) + "\n") 59 | 60 | 61 | if __name__ == "__main__": 62 | parser = argparse.ArgumentParser() 63 | ## Required parameters 64 | parser.add_argument("--xlm_checkpoint_path", 65 | default = None, 66 | type = str, 67 | required = True, 68 | help = "Path the official PyTorch dump.") 69 | parser.add_argument("--pytorch_dump_folder_path", 70 | default = None, 71 | type = str, 72 | required = True, 73 | help = "Path to the output PyTorch model.") 74 | args = parser.parse_args() 75 | convert_xlm_checkpoint_to_pytorch(args.xlm_checkpoint_path, args.pytorch_dump_folder_path) 76 | -------------------------------------------------------------------------------- /larimar_base/pytorch_transformers/convert_gpt2_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | from io import open 21 | 22 | import torch 23 | 24 | from pytorch_transformers import (CONFIG_NAME, WEIGHTS_NAME, 25 | GPT2Config, 26 | GPT2Model, 27 | load_tf_weights_in_gpt2) 28 | 29 | import logging 30 | logging.basicConfig(level=logging.INFO) 31 | 32 | 33 | def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path): 34 | # Construct model 35 | if gpt2_config_file == "": 36 | config = GPT2Config() 37 | else: 38 | config = GPT2Config.from_json_file(gpt2_config_file) 39 | model = GPT2Model(config) 40 | 41 | # Load weights from numpy 42 | load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path) 43 | 44 | # Save pytorch-model 45 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME 46 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME 47 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 48 | torch.save(model.state_dict(), pytorch_weights_dump_path) 49 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 50 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 51 | f.write(config.to_json_string()) 52 | 53 | 54 | if __name__ == "__main__": 55 | parser = argparse.ArgumentParser() 56 | ## Required parameters 57 | parser.add_argument("--gpt2_checkpoint_path", 58 | default = None, 59 | type = str, 60 | required = True, 61 | help = "Path to the TensorFlow checkpoint path.") 62 | parser.add_argument("--pytorch_dump_folder_path", 63 | default = None, 64 | type = str, 65 | required = True, 66 | help = "Path to the output PyTorch model.") 67 | parser.add_argument("--gpt2_config_file", 68 | default = "", 69 | type = str, 70 | help = "An optional config json file corresponding to the pre-trained OpenAI model. \n" 71 | "This specifies the model architecture.") 72 | args = parser.parse_args() 73 | convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path, 74 | args.gpt2_config_file, 75 | args.pytorch_dump_folder_path) 76 | -------------------------------------------------------------------------------- /larimar_base/pytorch_transformers/convert_openai_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | from io import open 21 | 22 | import torch 23 | 24 | from pytorch_transformers import (CONFIG_NAME, WEIGHTS_NAME, 25 | OpenAIGPTConfig, 26 | OpenAIGPTModel, 27 | load_tf_weights_in_openai_gpt) 28 | 29 | import logging 30 | logging.basicConfig(level=logging.INFO) 31 | 32 | 33 | def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path): 34 | # Construct model 35 | if openai_config_file == "": 36 | config = OpenAIGPTConfig() 37 | else: 38 | config = OpenAIGPTConfig.from_json_file(openai_config_file) 39 | model = OpenAIGPTModel(config) 40 | 41 | # Load weights from numpy 42 | load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path) 43 | 44 | # Save pytorch-model 45 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME 46 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME 47 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 48 | torch.save(model.state_dict(), pytorch_weights_dump_path) 49 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 50 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 51 | f.write(config.to_json_string()) 52 | 53 | 54 | if __name__ == "__main__": 55 | parser = argparse.ArgumentParser() 56 | ## Required parameters 57 | parser.add_argument("--openai_checkpoint_folder_path", 58 | default = None, 59 | type = str, 60 | required = True, 61 | help = "Path to the TensorFlow checkpoint path.") 62 | parser.add_argument("--pytorch_dump_folder_path", 63 | default = None, 64 | type = str, 65 | required = True, 66 | help = "Path to the output PyTorch model.") 67 | parser.add_argument("--openai_config_file", 68 | default = "", 69 | type = str, 70 | help = "An optional config json file corresponding to the pre-trained OpenAI model. \n" 71 | "This specifies the model architecture.") 72 | args = parser.parse_args() 73 | convert_openai_checkpoint_to_pytorch(args.openai_checkpoint_folder_path, 74 | args.openai_config_file, 75 | args.pytorch_dump_folder_path) 76 | -------------------------------------------------------------------------------- /larimar_base/pytorch_transformers/tests/tokenization_xlm_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | import json 20 | 21 | from pytorch_transformers.tokenization_xlm import XLMTokenizer, VOCAB_FILES_NAMES 22 | 23 | from .tokenization_tests_commons import CommonTestCases 24 | 25 | class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester): 26 | 27 | tokenizer_class = XLMTokenizer 28 | 29 | def setUp(self): 30 | super(XLMTokenizationTest, self).setUp() 31 | 32 | # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt 33 | vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n", 34 | "w</w>", "r</w>", "t</w>", 35 | "lo", "low", "er</w>", 36 | "low</w>", "lowest</w>", "newer</w>", "wider</w>", "<unk>"] 37 | vocab_tokens = dict(zip(vocab, range(len(vocab)))) 38 | merges = ["l o 123", "lo w 1456", "e r</w> 1789", ""] 39 | 40 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) 41 | self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file']) 42 | with open(self.vocab_file, "w") as fp: 43 | fp.write(json.dumps(vocab_tokens)) 44 | with open(self.merges_file, "w") as fp: 45 | fp.write("\n".join(merges)) 46 | 47 | def get_tokenizer(self, **kwargs): 48 | return XLMTokenizer.from_pretrained(self.tmpdirname, **kwargs) 49 | 50 | def get_input_output_texts(self): 51 | input_text = u"lower newer" 52 | output_text = u"lower newer" 53 | return input_text, output_text 54 | 55 | def test_full_tokenizer(self): 56 | """ Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """ 57 | tokenizer = XLMTokenizer(self.vocab_file, self.merges_file) 58 | 59 | text = "lower" 60 | bpe_tokens = ["low", "er</w>"] 61 | tokens = tokenizer.tokenize(text) 62 | self.assertListEqual(tokens, bpe_tokens) 63 | 64 | input_tokens = tokens + ["<unk>"] 65 | input_bpe_tokens = [14, 15, 20] 66 | self.assertListEqual( 67 | tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) 68 | 69 | def test_sequence_builders(self): 70 | tokenizer = XLMTokenizer.from_pretrained("xlm-mlm-en-2048") 71 | 72 | text = tokenizer.encode("sequence builders") 73 | text_2 = tokenizer.encode("multi-sequence build") 74 | 75 | encoded_sentence = tokenizer.add_special_tokens_single_sentence(text) 76 | encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2) 77 | 78 | assert encoded_sentence == [1] + text + [1] 79 | assert encoded_pair == [1] + text + [1] + text_2 + [1] 80 | 81 | if __name__ == '__main__': 82 | unittest.main() 83 | -------------------------------------------------------------------------------- /larimar_base/pytorch_transformers/dependency_versions_table.py: -------------------------------------------------------------------------------- 1 | # THIS FILE HAS BEEN AUTOGENERATED. To update: 2 | # 1. modify the `_deps` dict in setup.py 3 | # 2.
run `make deps_table_update` 4 | deps = { 5 | "Pillow": "Pillow<10.0.0", 6 | "accelerate": "accelerate>=0.20.3", 7 | "av": "av==9.2.0", 8 | "beautifulsoup4": "beautifulsoup4", 9 | "black": "black~=23.1", 10 | "codecarbon": "codecarbon==1.2.0", 11 | "cookiecutter": "cookiecutter==1.7.3", 12 | "dataclasses": "dataclasses", 13 | "datasets": "datasets!=2.5.0", 14 | "decord": "decord==0.6.0", 15 | "deepspeed": "deepspeed>=0.9.3", 16 | "diffusers": "diffusers", 17 | "dill": "dill<0.3.5", 18 | "evaluate": "evaluate>=0.2.0", 19 | "fairscale": "fairscale>0.3", 20 | "faiss-cpu": "faiss-cpu", 21 | "fastapi": "fastapi", 22 | "filelock": "filelock", 23 | "flax": "flax>=0.4.1,<=0.7.0", 24 | "ftfy": "ftfy", 25 | "fugashi": "fugashi>=1.0", 26 | "GitPython": "GitPython<3.1.19", 27 | "hf-doc-builder": "hf-doc-builder>=0.3.0", 28 | "huggingface-hub": "huggingface-hub>=0.16.4,<1.0", 29 | "importlib_metadata": "importlib_metadata", 30 | "ipadic": "ipadic>=1.0.0,<2.0", 31 | "isort": "isort>=5.5.4", 32 | "jax": "jax>=0.4.1,<=0.4.13", 33 | "jaxlib": "jaxlib>=0.4.1,<=0.4.13", 34 | "jieba": "jieba", 35 | "kenlm": "kenlm", 36 | "keras-nlp": "keras-nlp>=0.3.1", 37 | "librosa": "librosa", 38 | "nltk": "nltk", 39 | "natten": "natten>=0.14.6", 40 | "numpy": "numpy>=1.17", 41 | "onnxconverter-common": "onnxconverter-common", 42 | "onnxruntime-tools": "onnxruntime-tools>=1.4.2", 43 | "onnxruntime": "onnxruntime>=1.4.0", 44 | "opencv-python": "opencv-python", 45 | "optuna": "optuna", 46 | "optax": "optax>=0.0.8,<=0.1.4", 47 | "packaging": "packaging>=20.0", 48 | "parameterized": "parameterized", 49 | "phonemizer": "phonemizer", 50 | "protobuf": "protobuf", 51 | "psutil": "psutil", 52 | "pyyaml": "pyyaml>=5.1", 53 | "pydantic": "pydantic<2", 54 | "pytest": "pytest>=7.2.0", 55 | "pytest-timeout": "pytest-timeout", 56 | "pytest-xdist": "pytest-xdist", 57 | "python": "python>=3.8.0", 58 | "ray[tune]": "ray[tune]", 59 | "regex": "regex!=2019.12.17", 60 | "requests": "requests", 61 | "rhoknp": "rhoknp>=1.1.0,<1.3.1", 62 | "rjieba": "rjieba", 63 | "rouge-score": "rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1", 64 | "ruff": "ruff>=0.0.241,<=0.0.259", 65 | "sacrebleu": "sacrebleu>=1.4.12,<2.0.0", 66 | "sacremoses": "sacremoses", 67 | "safetensors": "safetensors>=0.3.1", 68 | "sagemaker": "sagemaker>=2.31.0", 69 | "scikit-learn": "scikit-learn", 70 | "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92", 71 | "sigopt": "sigopt", 72 | "starlette": "starlette", 73 | "sudachipy": "sudachipy>=0.6.6", 74 | "sudachidict_core": "sudachidict_core>=20220729", 75 | "tensorflow-cpu": "tensorflow-cpu>=2.6,<2.15", 76 | "tensorflow": "tensorflow>=2.6,<2.15", 77 | "tensorflow-text": "tensorflow-text<2.15", 78 | "tf2onnx": "tf2onnx", 79 | "timeout-decorator": "timeout-decorator", 80 | "timm": "timm", 81 | "tokenizers": "tokenizers>=0.21,<0.22", 82 | # "tokenizers": "tokenizers>=0.13,<0.15", 83 | # "tokenizers": "tokenizers>=0.14,<0.15", 84 | "torch": "torch>=1.10,!=1.12.0", 85 | "torchaudio": "torchaudio", 86 | "torchvision": "torchvision", 87 | "pyctcdecode": "pyctcdecode>=0.4.0", 88 | "tqdm": "tqdm>=4.27", 89 | "unidic": "unidic>=1.0.2", 90 | "unidic_lite": "unidic_lite>=1.0.7", 91 | "urllib3": "urllib3<2.0.0", 92 | "uvicorn": "uvicorn", 93 | } 94 | -------------------------------------------------------------------------------- /larimar_base/pytorch_transformers/configuration_distilbert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019-present, the HuggingFace Inc.
team, The Google AI Language Team and Facebook, Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ DistilBERT model configuration """ 16 | from __future__ import (absolute_import, division, print_function, 17 | unicode_literals) 18 | 19 | import sys 20 | import json 21 | import logging 22 | from io import open 23 | 24 | from .configuration_utils import PretrainedConfig 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { 29 | 'distilbert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-config.json", 30 | 'distilbert-base-uncased-distilled-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-distilled-squad-config.json" 31 | } 32 | 33 | 34 | class DistilBertConfig(PretrainedConfig): 35 | pretrained_config_archive_map = DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP 36 | 37 | def __init__(self, 38 | vocab_size_or_config_json_file=30522, 39 | max_position_embeddings=512, 40 | sinusoidal_pos_embds=True, 41 | n_layers=6, 42 | n_heads=12, 43 | dim=768, 44 | hidden_dim=4*768, 45 | dropout=0.1, 46 | attention_dropout=0.1, 47 | activation='gelu', 48 | initializer_range=0.02, 49 | tie_weights_=True, 50 | qa_dropout=0.1, 51 | seq_classif_dropout=0.2, 52 | **kwargs): 53 | super(DistilBertConfig, self).__init__(**kwargs) 54 | 55 | if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 56 | and isinstance(vocab_size_or_config_json_file, unicode)): 57 | with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: 58 | json_config = json.loads(reader.read()) 59 | for key, value in json_config.items(): 60 | self.__dict__[key] = value 61 | elif isinstance(vocab_size_or_config_json_file, int): 62 | self.vocab_size = vocab_size_or_config_json_file 63 | self.max_position_embeddings = max_position_embeddings 64 | self.sinusoidal_pos_embds = sinusoidal_pos_embds 65 | self.n_layers = n_layers 66 | self.n_heads = n_heads 67 | self.dim = dim 68 | self.hidden_dim = hidden_dim 69 | self.dropout = dropout 70 | self.attention_dropout = attention_dropout 71 | self.activation = activation 72 | self.initializer_range = initializer_range 73 | self.tie_weights_ = tie_weights_ 74 | self.qa_dropout = qa_dropout 75 | self.seq_classif_dropout = seq_classif_dropout 76 | else: 77 | raise ValueError("First argument must be either a vocabulary size (int)" 78 | " or the path to a pretrained model config file (str)") 79 | @property 80 | def hidden_size(self): 81 | return self.dim 82 | 83 | @property 84 | def num_attention_heads(self): 85 | return self.n_heads 86 | 87 | @property 88 | def num_hidden_layers(self): 89 | return self.n_layers 90 | -------------------------------------------------------------------------------- /larimar_base/train_larimar.sh: -------------------------------------------------------------------------------- 1 | 2 | 
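# Launches Larimar training through main_pl.py (Lightning CLI); every flag passed
# below overrides the corresponding entry in configs/config_train_larimar.yaml.
# Assumes the Wikipedia block data has been prepared under ../data/wikipedia/blocksize_64
# (see data/wikipedia_data_locations.json).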
#################################################################################################### 3 | # config yaml file 4 | #################################################################################################### 5 | config_file=configs/config_train_larimar.yaml 6 | 7 | 8 | #################################################################################################### 9 | # model_* 10 | #################################################################################################### 11 | model_encoder_model_name_or_path="bert-large-cased" 12 | model_decoder_model_name_or_path="gpt2-large" 13 | model_decode_rec_strength=1.0 14 | model_optimizer=adamw 15 | model_learning_rate=5e-5 16 | model_observation_noise_std=0.000001 17 | model_beta=0.5 18 | 19 | episode_length=16 20 | model_episode_sizes=[${episode_length}] 21 | 22 | 23 | 24 | #################################################################################################### 25 | # trainer_* 26 | #################################################################################################### 27 | trainer_devices=8 28 | trainer_max_epochs=5 29 | trainer_precision=32-true 30 | trainer_strategy=ddp 31 | trainer_callbacks_init_args_every_n_train_steps=20000 32 | trainer_callbacks_init_args_save_top_k=3 33 | 34 | 35 | 36 | #################################################################################################### 37 | # data_* 38 | #################################################################################################### 39 | data_train_batch_size=16 40 | data_num_chunks=false 41 | 42 | 43 | 44 | #################################################################################################### 45 | # directories and files 46 | #################################################################################################### 47 | 48 | # cache directory 49 | model_cache_dir=../cache 50 | 51 | # trained model directory 52 | decoder_name_save=$(echo "gpt2" | sed 's/\//-/g') 53 | loss_type=decoder_loss 54 | top_larimar_model_dir=../train/larimar/checkpoints 55 | larimar_model_description=${model_encoder_model_name_or_path}-${decoder_name_save}-large-wiki-ep-${episode_length}_${loss_type}_${model_observation_noise_std} 56 | trainer_default_root_dir=${top_larimar_model_dir}/${larimar_model_description} 57 | trainer_logger_init_args_save_dir=${trainer_default_root_dir} 58 | 59 | 60 | # training data directory and files 61 | block_size=64 62 | top_training_data_dir=../data 63 | training_data_dir=${top_training_data_dir}/wikipedia/blocksize_${block_size} 64 | data_train_data_file=${training_data_dir}/train.txt 65 | data_eval_data_file=${training_data_dir}/test.txt 66 | 67 | 68 | 69 | 70 | 71 | #################################################################################################### 72 | # train 73 | #################################################################################################### 74 | python main_pl.py fit \ 75 | --config ${config_file} \ 76 | --model.cache_dir=${model_cache_dir} \ 77 | --model.encoder_model_name_or_path ${model_encoder_model_name_or_path} \ 78 | --model.decoder_model_name_or_path ${model_decoder_model_name_or_path} \ 79 | --model.optimizer ${model_optimizer} \ 80 | --model.learning_rate ${model_learning_rate} \ 81 | --model.episode_sizes ${model_episode_sizes} \ 82 | --model.decode_rec_strength ${model_decode_rec_strength} \ 83 | --model.observation_noise_std ${model_observation_noise_std} \ 84 | --model.beta ${model_beta} \ 85 | --trainer.devices 
${trainer_devices} \ 86 | --trainer.max_epochs ${trainer_max_epochs} \ 87 | --trainer.precision ${trainer_precision} \ 88 | --trainer.strategy ${trainer_strategy} \ 89 | --trainer.default_root_dir ${trainer_default_root_dir} \ 90 | --trainer.logger.init_args.save_dir ${trainer_logger_init_args_save_dir} \ 91 | --trainer.callbacks.init_args.every_n_train_steps ${trainer_callbacks_init_args_every_n_train_steps} \ 92 | --trainer.callbacks.init_args.save_top_k ${trainer_callbacks_init_args_save_top_k} \ 93 | --data.train_batch_size ${data_train_batch_size} \ 94 | --data.train_data_file ${data_train_data_file} \ 95 | --data.eval_data_file ${data_eval_data_file} \ 96 | --data.num_chunks ${data_num_chunks} 97 | -------------------------------------------------------------------------------- /larimar_base/pytorch_transformers/tests/modeling_auto_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import shutil 21 | import pytest 22 | import logging 23 | 24 | from pytorch_transformers import (AutoConfig, BertConfig, 25 | AutoModel, BertModel, 26 | AutoModelWithLMHead, BertForMaskedLM, 27 | AutoModelForSequenceClassification, BertForSequenceClassification, 28 | AutoModelForQuestionAnswering, BertForQuestionAnswering) 29 | from pytorch_transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP 30 | 31 | from .modeling_common_test import (CommonTestCases, ids_tensor) 32 | from .configuration_common_test import ConfigTester 33 | 34 | 35 | class AutoModelTest(unittest.TestCase): 36 | def test_model_from_pretrained(self): 37 | logging.basicConfig(level=logging.INFO) 38 | for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 39 | config = AutoConfig.from_pretrained(model_name) 40 | self.assertIsNotNone(config) 41 | self.assertIsInstance(config, BertConfig) 42 | 43 | model = AutoModel.from_pretrained(model_name) 44 | model, loading_info = AutoModel.from_pretrained(model_name, output_loading_info=True) 45 | self.assertIsNotNone(model) 46 | self.assertIsInstance(model, BertModel) 47 | for value in loading_info.values(): 48 | self.assertEqual(len(value), 0) 49 | 50 | def test_lmhead_model_from_pretrained(self): 51 | logging.basicConfig(level=logging.INFO) 52 | for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 53 | config = AutoConfig.from_pretrained(model_name) 54 | self.assertIsNotNone(config) 55 | self.assertIsInstance(config, BertConfig) 56 | 57 | model = AutoModelWithLMHead.from_pretrained(model_name) 58 | model, loading_info = AutoModelWithLMHead.from_pretrained(model_name, output_loading_info=True) 59 | self.assertIsNotNone(model) 60 | self.assertIsInstance(model, BertForMaskedLM) 61 | 62 | def 
test_sequence_classification_model_from_pretrained(self): 63 | logging.basicConfig(level=logging.INFO) 64 | for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 65 | config = AutoConfig.from_pretrained(model_name) 66 | self.assertIsNotNone(config) 67 | self.assertIsInstance(config, BertConfig) 68 | 69 | model = AutoModelForSequenceClassification.from_pretrained(model_name) 70 | model, loading_info = AutoModelForSequenceClassification.from_pretrained(model_name, output_loading_info=True) 71 | self.assertIsNotNone(model) 72 | self.assertIsInstance(model, BertForSequenceClassification) 73 | 74 | def test_question_answering_model_from_pretrained(self): 75 | logging.basicConfig(level=logging.INFO) 76 | for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 77 | config = AutoConfig.from_pretrained(model_name) 78 | self.assertIsNotNone(config) 79 | self.assertIsInstance(config, BertConfig) 80 | 81 | model = AutoModelForQuestionAnswering.from_pretrained(model_name) 82 | model, loading_info = AutoModelForQuestionAnswering.from_pretrained(model_name, output_loading_info=True) 83 | self.assertIsNotNone(model) 84 | self.assertIsInstance(model, BertForQuestionAnswering) 85 | 86 | 87 | if __name__ == "__main__": 88 | unittest.main() 89 | -------------------------------------------------------------------------------- /larimar_base/lightning_data.py: -------------------------------------------------------------------------------- 1 | import lightning as pl 2 | from utils import BucketingDataLoaderPL 3 | from lightning_model import prepare_enc_dec_tokenizer 4 | 5 | 6 | class DataModule(pl.LightningDataModule): 7 | 8 | def __init__(self, 9 | train_data_file, 10 | train_batch_size, 11 | eval_data_file, 12 | eval_batch_size, 13 | max_seq_length, 14 | perturb, 15 | use_labels, 16 | dataset, 17 | use_philly, 18 | num_data_workers, 19 | batches_per_bucket, 20 | block_size, 21 | encoder_model_type, 22 | encoder_model_name_or_path, 23 | decoder_model_type, 24 | decoder_model_name_or_path, 25 | cache_dir, 26 | do_lower_case, 27 | num_chunks): 28 | 29 | super().__init__() 30 | 31 | self.train_data_file = train_data_file 32 | self.train_batch_size = train_batch_size 33 | self.max_seq_length = max_seq_length 34 | self.eval_data_file = eval_data_file 35 | self.eval_batch_size = eval_batch_size 36 | self.perturb = perturb 37 | self.use_labels = use_labels 38 | self.dataset = dataset 39 | self.use_philly = use_philly 40 | self.block_size = block_size 41 | self.num_data_workers = num_data_workers 42 | self.batches_per_bucket = batches_per_bucket 43 | self.num_chunks = num_chunks 44 | 45 | tokenizer_encoder, tokenizer_decoder = prepare_enc_dec_tokenizer(encoder_model_type, 46 | encoder_model_name_or_path, 47 | decoder_model_type, 48 | decoder_model_name_or_path, 49 | cache_dir, 50 | do_lower_case, 51 | block_size) 52 | 53 | self.tokenizer = [tokenizer_encoder, tokenizer_decoder] 54 | 55 | def setup(self, stage=None): 56 | 57 | if stage == 'fit': 58 | self.traindl = BucketingDataLoaderPL(self.train_data_file, 59 | self.train_batch_size, 60 | self.max_seq_length, 61 | self.tokenizer, 62 | self.block_size, 63 | self.use_labels, 64 | self.dataset, 65 | self.use_philly, 66 | self.num_chunks, 67 | self.num_data_workers, 68 | batches_per_bucket=self.batches_per_bucket, 69 | perturb=self.perturb, 70 | shuffle=True) 71 | if stage in ("fit", "validate"): 72 | self.valdl = BucketingDataLoaderPL(self.eval_data_file, 73 | self.eval_batch_size, 74 | self.max_seq_length, 75 | self.tokenizer, 76 | self.block_size, 77 |
self.use_labels, 78 | self.dataset, 79 | self.use_philly, 80 | self.num_chunks, 81 | self.num_data_workers, 82 | batches_per_bucket=self.batches_per_bucket, 83 | perturb=self.perturb, 84 | shuffle=False) 85 | else: 86 | return 87 | 88 | def train_dataloader(self): 89 | return self.traindl.get() 90 | 91 | def val_dataloader(self): 92 | return self.valdl.get() 93 | -------------------------------------------------------------------------------- /larimar_base/pytorch_transformers/tests/fixtures/sample_text.txt: -------------------------------------------------------------------------------- 1 | This text is included to make sure Unicode is handled properly: 力加勝北区ᴵᴺᵀᵃছজটডণত 2 | Text should be one-sentence-per-line, with empty lines between documents. 3 | This sample text is public domain and was randomly selected from Project Gutenberg. 4 | 5 | The rain had only ceased with the gray streaks of morning at Blazing Star, and the settlement awoke to a moral sense of cleanliness, and the finding of forgotten knives, tin cups, and smaller camp utensils, where the heavy showers had washed away the debris and dust heaps before the cabin doors. 6 | Indeed, it was recorded in Blazing Star that a fortunate early riser had once picked up on the highway a solid chunk of gold quartz which the rain had freed from its incumbering soil, and washed into immediate and glittering popularity. 7 | Possibly this may have been the reason why early risers in that locality, during the rainy season, adopted a thoughtful habit of body, and seldom lifted their eyes to the rifted or india-ink washed skies above them. 8 | "Cass" Beard had risen early that morning, but not with a view to discovery. 9 | A leak in his cabin roof,--quite consistent with his careless, improvident habits,--had roused him at 4 A. M., with a flooded "bunk" and wet blankets. 10 | The chips from his wood pile refused to kindle a fire to dry his bed-clothes, and he had recourse to a more provident neighbor's to supply the deficiency. 11 | This was nearly opposite. 12 | Mr. Cassius crossed the highway, and stopped suddenly. 13 | Something glittered in the nearest red pool before him. 14 | Gold, surely! 15 | But, wonderful to relate, not an irregular, shapeless fragment of crude ore, fresh from Nature's crucible, but a bit of jeweler's handicraft in the form of a plain gold ring. 16 | Looking at it more attentively, he saw that it bore the inscription, "May to Cass." 17 | Like most of his fellow gold-seekers, Cass was superstitious. 18 | 19 | The fountain of classic wisdom, Hypatia herself. 20 | As the ancient sage--the name is unimportant to a monk--pumped water nightly that he might study by day, so I, the guardian of cloaks and parasols, at the sacred doors of her lecture-room, imbibe celestial knowledge. 21 | From my youth I felt in me a soul above the matter-entangled herd. 22 | She revealed to me the glorious fact, that I am a spark of Divinity itself. 23 | A fallen star, I am, sir!' continued he, pensively, stroking his lean stomach--'a fallen star!--fallen, if the dignity of philosophy will allow of the simile, among the hogs of the lower world--indeed, even into the hog-bucket itself. Well, after all, I will show you the way to the Archbishop's. 24 | There is a philosophic pleasure in opening one's treasures to the modest young. 25 | Perhaps you will assist me by carrying this basket of fruit?' And the little man jumped up, put his basket on Philammon's head, and trotted off up a neighbouring street.
26 | Philammon followed, half contemptuous, half wondering at what this philosophy might be, which could feed the self-conceit of anything so abject as his ragged little apish guide; 27 | but the novel roar and whirl of the street, the perpetual stream of busy faces, the line of curricles, palanquins, laden asses, camels, elephants, which met and passed him, and squeezed him up steps and into doorways, as they threaded their way through the great Moon-gate into the ample street beyond, drove everything from his mind but wondering curiosity, and a vague, helpless dread of that great living wilderness, more terrible than any dead wilderness of sand which he had left behind. 28 | Already he longed for the repose, the silence of the Laura--for faces which knew him and smiled upon him; but it was too late to turn back now. 29 | His guide held on for more than a mile up the great main street, crossed in the centre of the city, at right angles, by one equally magnificent, at each end of which, miles away, appeared, dim and distant over the heads of the living stream of passengers, the yellow sand-hills of the desert; 30 | while at the end of the vista in front of them gleamed the blue harbour, through a network of countless masts. 31 | At last they reached the quay at the opposite end of the street; 32 | and there burst on Philammon's astonished eyes a vast semicircle of blue sea, ringed with palaces and towers. 33 | He stopped involuntarily; and his little guide stopped also, and looked askance at the young monk, to watch the effect which that grand panorama should produce on him. 34 | -------------------------------------------------------------------------------- /larimar_base/pytorch_transformers/tests/tokenization_roberta_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import json 19 | import unittest 20 | from io import open 21 | 22 | from pytorch_transformers.tokenization_roberta import RobertaTokenizer, VOCAB_FILES_NAMES 23 | from .tokenization_tests_commons import CommonTestCases 24 | 25 | 26 | class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester): 27 | tokenizer_class = RobertaTokenizer 28 | 29 | def setUp(self): 30 | super(RobertaTokenizationTest, self).setUp() 31 | 32 | # Adapted from Sennrich et al. 
2015 and https://github.com/rsennrich/subword-nmt 33 | vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n", 34 | "\u0120", "\u0120l", "\u0120n", 35 | "\u0120lo", "\u0120low", "er", 36 | "\u0120lowest", "\u0120newer", "\u0120wider", "<unk>"] 37 | vocab_tokens = dict(zip(vocab, range(len(vocab)))) 38 | merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""] 39 | self.special_tokens_map = {"unk_token": "<unk>"} 40 | 41 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) 42 | self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file']) 43 | with open(self.vocab_file, "w", encoding="utf-8") as fp: 44 | fp.write(json.dumps(vocab_tokens) + "\n") 45 | with open(self.merges_file, "w", encoding="utf-8") as fp: 46 | fp.write("\n".join(merges)) 47 | 48 | def get_tokenizer(self, **kwargs): 49 | kwargs.update(self.special_tokens_map) 50 | return RobertaTokenizer.from_pretrained(self.tmpdirname, **kwargs) 51 | 52 | def get_input_output_texts(self): 53 | input_text = u"lower newer" 54 | output_text = u" lower newer" 55 | return input_text, output_text 56 | 57 | def test_full_tokenizer(self): 58 | tokenizer = RobertaTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map) 59 | text = "lower newer" 60 | bpe_tokens = ["\u0120low", "er", "\u0120", "n", "e", "w", "er"] 61 | tokens = tokenizer.tokenize(text) 62 | self.assertListEqual(tokens, bpe_tokens) 63 | 64 | input_tokens = tokens + [tokenizer.unk_token] 65 | input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19] 66 | self.assertListEqual( 67 | tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) 68 | 69 | def roberta_dict_integration_testing(self): 70 | tokenizer = self.get_tokenizer() 71 | 72 | self.assertListEqual( 73 | tokenizer.encode('Hello world!'), 74 | [0, 31414, 232, 328, 2] 75 | ) 76 | self.assertListEqual( 77 | tokenizer.encode('Hello world! cécé herlolip 418'), 78 | [0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2] 79 | ) 80 | 81 | def test_sequence_builders(self): 82 | tokenizer = RobertaTokenizer.from_pretrained("roberta-base") 83 | 84 | text = tokenizer.encode("sequence builders") 85 | text_2 = tokenizer.encode("multi-sequence build") 86 | 87 | encoded_text_from_decode = tokenizer.encode("sequence builders", add_special_tokens=True) 88 | encoded_pair_from_decode = tokenizer.encode("sequence builders", "multi-sequence build", add_special_tokens=True) 89 | 90 | encoded_sentence = tokenizer.add_special_tokens_single_sentence(text) 91 | encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2) 92 | 93 | assert encoded_sentence == encoded_text_from_decode 94 | assert encoded_pair == encoded_pair_from_decode 95 | 96 | 97 | if __name__ == '__main__': 98 | unittest.main() 99 | -------------------------------------------------------------------------------- /larimar_base/pytorch_transformers/tokenization_roberta.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes for RoBERTa.""" 16 | from __future__ import (absolute_import, division, print_function, 17 | unicode_literals) 18 | 19 | import sys 20 | import json 21 | import logging 22 | import os 23 | import regex as re 24 | from io import open 25 | 26 | from .tokenization_gpt2 import GPT2Tokenizer 27 | 28 | try: 29 | from functools import lru_cache 30 | except ImportError: 31 | # Just a dummy decorator to get the checks to run on python2 32 | # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now. 33 | def lru_cache(): 34 | return lambda func: func 35 | 36 | logger = logging.getLogger(__name__) 37 | 38 | VOCAB_FILES_NAMES = { 39 | 'vocab_file': 'vocab.json', 40 | 'merges_file': 'merges.txt', 41 | } 42 | 43 | PRETRAINED_VOCAB_FILES_MAP = { 44 | 'vocab_file': 45 | { 46 | 'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-vocab.json", 47 | 'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json", 48 | 'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-vocab.json", 49 | }, 50 | 'merges_file': 51 | { 52 | 'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-merges.txt", 53 | 'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt", 54 | 'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-merges.txt", 55 | }, 56 | } 57 | 58 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 59 | 'roberta-base': 512, 60 | 'roberta-large': 512, 61 | 'roberta-large-mnli': 512, 62 | } 63 | 64 | 65 | class RobertaTokenizer(GPT2Tokenizer): 66 | """ 67 | RoBERTa BPE tokenizer, derived from the GPT-2 tokenizer. Peculiarities: 68 | - Byte-level Byte-Pair-Encoding 69 | - Requires a space to start the input string => will add a space if there isn't one. 70 | As a consequence, this tokenizer's `encode` and `decode` methods will not preserve 71 | the absence of a space at the beginning of a string: `tokenizer.decode(tokenizer.encode("Hello")) == " Hello"` 72 | """ 73 | vocab_files_names = VOCAB_FILES_NAMES 74 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 75 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 76 | 77 | def __init__(self, vocab_file, merges_file, errors='replace', bos_token="<s>", eos_token="</s>", sep_token="</s>", 78 | cls_token="<s>", unk_token="<unk>", pad_token='<pad>', mask_token='<mask>', **kwargs): 79 | super(RobertaTokenizer, self).__init__(vocab_file=vocab_file, merges_file=merges_file, errors=errors, 80 | bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, 81 | sep_token=sep_token, cls_token=cls_token, pad_token=pad_token, 82 | mask_token=mask_token, **kwargs) 83 | 84 | def add_special_tokens_single_sentence(self, token_ids): 85 | """ 86 | Adds special tokens to a sequence for sequence classification tasks.
87 | A RoBERTa sequence has the following format: <s> X </s> 88 | """ 89 | return [self.cls_token_id] + token_ids + [self.sep_token_id] 90 | 91 | def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1): 92 | """ 93 | Adds special tokens to a sequence pair for sequence classification tasks. 94 | A RoBERTa sequence pair has the following format: <s> A </s></s> B </s> 95 | """ 96 | sep = [self.sep_token_id] 97 | cls = [self.cls_token_id] 98 | return cls + token_ids_0 + sep + sep + token_ids_1 + sep 99 | -------------------------------------------------------------------------------- /larimar_base/modules/encoders/enc_lstm.py: -------------------------------------------------------------------------------- 1 | from itertools import chain 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | 6 | from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence 7 | from .gaussian_encoder import GaussianEncoderBase 8 | from ..utils import log_sum_exp 9 | 10 | class GaussianLSTMEncoder(GaussianEncoderBase): 11 | """Gaussian LSTM Encoder with constant-length input""" 12 | def __init__(self, args, vocab_size, model_init, emb_init): 13 | super(GaussianLSTMEncoder, self).__init__() 14 | self.ni = args.ni 15 | self.nh = args.enc_nh 16 | self.nz = args.nz 17 | self.args = args 18 | 19 | self.embed = nn.Embedding(vocab_size, args.ni) 20 | 21 | self.lstm = nn.LSTM(input_size=args.ni, 22 | hidden_size=args.enc_nh, 23 | num_layers=1, 24 | batch_first=True, 25 | dropout=0) 26 | 27 | self.linear = nn.Linear(args.enc_nh, 2 * args.nz, bias=False) 28 | 29 | self.reset_parameters(model_init, emb_init) 30 | 31 | def reset_parameters(self, model_init, emb_init): 32 | # for name, param in self.lstm.named_parameters(): 33 | # # self.initializer(param) 34 | # if 'bias' in name: 35 | # nn.init.constant_(param, 0.0) 36 | # # model_init(param) 37 | # elif 'weight' in name: 38 | # model_init(param) 39 | 40 | # model_init(self.linear.weight) 41 | # emb_init(self.embed.weight) 42 | for param in self.parameters(): 43 | model_init(param) 44 | emb_init(self.embed.weight) 45 | 46 | 47 | def forward(self, input): 48 | """ 49 | Args: 50 | x: (batch_size, seq_len) 51 | Returns: Tensor1, Tensor2 52 | Tensor1: the mean tensor, shape (batch, nz) 53 | Tensor2: the logvar tensor, shape (batch, nz) 54 | """ 55 | 56 | # (batch_size, seq_len-1, args.ni) 57 | word_embed = self.embed(input) 58 | 59 | _, (last_state, last_cell) = self.lstm(word_embed) 60 | 61 | mean, logvar = self.linear(last_state).chunk(2, -1) 62 | 63 | # fix variance as a pre-defined value 64 | if self.args.fix_var > 0: 65 | logvar = mean.new_tensor([[[math.log(self.args.fix_var)]]]).expand_as(mean) 66 | 67 | return mean.squeeze(0), logvar.squeeze(0) 68 | 69 | # def eval_inference_mode(self, x): 70 | # """compute the mode points in the inference distribution 71 | # (in Gaussian case) 72 | # Returns: Tensor 73 | # Tensor: the posterior mode points with shape (*, nz) 74 | # """ 75 | 76 | # # (batch_size, nz) 77 | # mu, logvar = self.forward(x) 78 | 79 | 80 | class VarLSTMEncoder(GaussianLSTMEncoder): 81 | """Gaussian LSTM Encoder with variable-length input""" 82 | def __init__(self, args, vocab_size, model_init, emb_init): 83 | super(VarLSTMEncoder, self).__init__(args, vocab_size, model_init, emb_init) 84 | 85 | 86 | def forward(self, input): 87 | """ 88 | Args: 89 | input: tuple which contains x and sents_len 90 | x: (batch_size, seq_len) 91 | sents_len: long tensor of sentence lengths 92 | Returns: Tensor1, Tensor2 93 | Tensor1: the mean tensor, shape (batch, nz) 94 |
Tensor2: the logvar tensor, shape (batch, nz) 95 | """ 96 | 97 | input, sents_len = input 98 | # (batch_size, seq_len, args.ni) 99 | word_embed = self.embed(input) 100 | 101 | packed_embed = pack_padded_sequence(word_embed, sents_len.tolist(), batch_first=True) 102 | 103 | _, (last_state, last_cell) = self.lstm(packed_embed) 104 | 105 | mean, logvar = self.linear(last_state).chunk(2, -1) 106 | 107 | return mean.squeeze(0), logvar.squeeze(0) 108 | 109 | def encode(self, input, nsamples): 110 | """perform the encoding and compute the KL term 111 | Args: 112 | input: tuple which contains x and sents_len 113 | Returns: Tensor1, Tensor2 114 | Tensor1: the tensor latent z with shape [batch, nsamples, nz] 115 | Tensor2: the tensor of KL for each x with shape [batch] 116 | """ 117 | 118 | # (batch_size, nz) 119 | mu, logvar = self.forward(input) 120 | 121 | # (batch, nsamples, nz) 122 | z = self.reparameterize(mu, logvar, nsamples) 123 | 124 | KL = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1).sum(dim=1) 125 | 126 | return z, KL 127 | -------------------------------------------------------------------------------- /larimar_base/pytorch_transformers/convert_xlnet_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
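# Example invocation (illustrative paths; the inputs are the files from the released XLNet TF checkpoint):
#   python convert_xlnet_checkpoint_to_pytorch.py \
#       --tf_checkpoint_path ./xlnet_cased_L-24_H-1024_A-16/xlnet_model.ckpt \
#       --xlnet_config_file ./xlnet_cased_L-24_H-1024_A-16/xlnet_config.json \
#       --pytorch_dump_folder_path ./xlnet-pytorch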
15 | """Convert BERT checkpoint.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import argparse 23 | import torch 24 | 25 | from pytorch_transformers import (CONFIG_NAME, WEIGHTS_NAME, 26 | XLNetConfig, 27 | XLNetLMHeadModel, XLNetForQuestionAnswering, 28 | XLNetForSequenceClassification, 29 | load_tf_weights_in_xlnet) 30 | 31 | GLUE_TASKS_NUM_LABELS = { 32 | "cola": 2, 33 | "mnli": 3, 34 | "mrpc": 2, 35 | "sst-2": 2, 36 | "sts-b": 1, 37 | "qqp": 2, 38 | "qnli": 2, 39 | "rte": 2, 40 | "wnli": 2, 41 | } 42 | 43 | import logging 44 | logging.basicConfig(level=logging.INFO) 45 | 46 | def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_folder_path, finetuning_task=None): 47 | # Initialise PyTorch model 48 | config = XLNetConfig.from_json_file(bert_config_file) 49 | 50 | finetuning_task = finetuning_task.lower() if finetuning_task is not None else "" 51 | if finetuning_task in GLUE_TASKS_NUM_LABELS: 52 | print("Building PyTorch XLNetForSequenceClassification model from configuration: {}".format(str(config))) 53 | config.finetuning_task = finetuning_task 54 | config.num_labels = GLUE_TASKS_NUM_LABELS[finetuning_task] 55 | model = XLNetForSequenceClassification(config) 56 | elif 'squad' in finetuning_task: 57 | config.finetuning_task = finetuning_task 58 | model = XLNetForQuestionAnswering(config) 59 | else: 60 | model = XLNetLMHeadModel(config) 61 | 62 | # Load weights from tf checkpoint 63 | load_tf_weights_in_xlnet(model, config, tf_checkpoint_path) 64 | 65 | # Save pytorch-model 66 | pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME) 67 | pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME) 68 | print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path))) 69 | torch.save(model.state_dict(), pytorch_weights_dump_path) 70 | print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path))) 71 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 72 | f.write(config.to_json_string()) 73 | 74 | 75 | if __name__ == "__main__": 76 | parser = argparse.ArgumentParser() 77 | ## Required parameters 78 | parser.add_argument("--tf_checkpoint_path", 79 | default = None, 80 | type = str, 81 | required = True, 82 | help = "Path to the TensorFlow checkpoint path.") 83 | parser.add_argument("--xlnet_config_file", 84 | default = None, 85 | type = str, 86 | required = True, 87 | help = "The config json file corresponding to the pre-trained XLNet model. \n" 88 | "This specifies the model architecture.") 89 | parser.add_argument("--pytorch_dump_folder_path", 90 | default = None, 91 | type = str, 92 | required = True, 93 | help = "Path to the folder to store the PyTorch model or dataset/vocab.") 94 | parser.add_argument("--finetuning_task", 95 | default = None, 96 | type = str, 97 | help = "Name of a task on which the XLNet TensorFloaw model was fine-tuned") 98 | args = parser.parse_args() 99 | print(args) 100 | 101 | convert_xlnet_checkpoint_to_pytorch(args.tf_checkpoint_path, 102 | args.xlnet_config_file, 103 | args.pytorch_dump_folder_path, 104 | args.finetuning_task) 105 | -------------------------------------------------------------------------------- /larimar_base/pytorch_transformers/utils/versions.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Utilities for working with package versions 16 | """ 17 | 18 | import importlib.metadata 19 | import operator 20 | import re 21 | import sys 22 | from typing import Optional 23 | 24 | from packaging import version 25 | 26 | 27 | ops = { 28 | "<": operator.lt, 29 | "<=": operator.le, 30 | "==": operator.eq, 31 | "!=": operator.ne, 32 | ">=": operator.ge, 33 | ">": operator.gt, 34 | } 35 | 36 | 37 | def _compare_versions(op, got_ver, want_ver, requirement, pkg, hint): 38 | if got_ver is None or want_ver is None: 39 | raise ValueError( 40 | f"Unable to compare versions for {requirement}: need={want_ver} found={got_ver}. This is unusual. Consider" 41 | f" reinstalling {pkg}." 42 | ) 43 | if not ops[op](version.parse(got_ver), version.parse(want_ver)): 44 | raise ImportError( 45 | f"{requirement} is required for a normal functioning of this module, but found {pkg}=={got_ver}.{hint}" 46 | ) 47 | 48 | 49 | def require_version(requirement: str, hint: Optional[str] = None) -> None: 50 | """ 51 | Perform a runtime check of the dependency versions, using the exact same syntax used by pip. 52 | 53 | The installed module version comes from the *site-packages* dir via *importlib.metadata*. 
54 | 55 | Args: 56 | requirement (`str`): pip style definition, e.g., "tokenizers==0.9.4", "tqdm>=4.27", "numpy" 57 | hint (`str`, *optional*): what suggestion to print in case of requirements not being met 58 | 59 | Example: 60 | 61 | ```python 62 | require_version("pandas>1.1.2") 63 | require_version("numpy>1.18.5", "this is important to have for whatever reason") 64 | ```""" 65 | 66 | hint = f"\n{hint}" if hint is not None else "" 67 | 68 | # non-versioned check 69 | if re.match(r"^[\w_\-\d]+$", requirement): 70 | pkg, op, want_ver = requirement, None, None 71 | else: 72 | match = re.findall(r"^([^!=<>\s]+)([\s!=<>]{1,2}.+)", requirement) 73 | if not match: 74 | raise ValueError( 75 | "requirement needs to be in the pip package format, e.g., package_a==1.23, or package_b>=1.23, but" 76 | f" got {requirement}" 77 | ) 78 | pkg, want_full = match[0] 79 | want_range = want_full.split(",") # there could be multiple requirements 80 | wanted = {} 81 | for w in want_range: 82 | match = re.findall(r"^([\s!=<>]{1,2})(.+)", w) 83 | if not match: 84 | raise ValueError( 85 | "requirement needs to be in the pip package format, e.g., package_a==1.23, or package_b>=1.23," 86 | f" but got {requirement}" 87 | ) 88 | op, want_ver = match[0] 89 | wanted[op] = want_ver 90 | if op not in ops: 91 | raise ValueError(f"{requirement}: need one of {list(ops.keys())}, but got {op}") 92 | 93 | # special case 94 | if pkg == "python": 95 | got_ver = ".".join([str(x) for x in sys.version_info[:3]]) 96 | for op, want_ver in wanted.items(): 97 | _compare_versions(op, got_ver, want_ver, requirement, pkg, hint) 98 | return 99 | 100 | # check if any version is installed 101 | try: 102 | got_ver = importlib.metadata.version(pkg) 103 | except importlib.metadata.PackageNotFoundError: 104 | raise importlib.metadata.PackageNotFoundError( 105 | f"The '{requirement}' distribution was not found and is required by this application.
{hint}" 106 | ) 107 | 108 | # check that the right version is installed if version number or a range was provided 109 | if want_ver is not None: 110 | for op, want_ver in wanted.items(): 111 | _compare_versions(op, got_ver, want_ver, requirement, pkg, hint) 112 | 113 | 114 | def require_version_core(requirement): 115 | """require_version wrapper which emits a core-specific hint on failure""" 116 | hint = "Try: pip install transformers -U or pip install -e '.[dev]' if you're working with git main" 117 | return require_version(requirement, hint) 118 | -------------------------------------------------------------------------------- /larimar_base/modules/encoders/gaussian_encoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | from .encoder import EncoderBase 6 | from ..utils import log_sum_exp 7 | 8 | class GaussianEncoderBase(EncoderBase): 9 | """Base class for encoders that parameterize a diagonal Gaussian posterior.""" 10 | def __init__(self): 11 | super(GaussianEncoderBase, self).__init__() 12 | 13 | def freeze(self): 14 | for param in self.parameters(): 15 | param.requires_grad = False 16 | 17 | def forward(self, x): 18 | """ 19 | Args: 20 | x: (batch_size, *) 21 | Returns: Tensor1, Tensor2 22 | Tensor1: the mean tensor, shape (batch, nz) 23 | Tensor2: the logvar tensor, shape (batch, nz) 24 | """ 25 | 26 | raise NotImplementedError 27 | 28 | def encode_stats(self, x): 29 | 30 | return self.forward(x) 31 | 32 | def sample(self, input, nsamples): 33 | """sampling from the encoder 34 | Returns: Tensor1 35 | Tensor1: the tensor latent z with shape [batch, nsamples, nz] 36 | """ 37 | 38 | # (batch_size, nz) 39 | mu, logvar = self.forward(input) 40 | 41 | # (batch, nsamples, nz) 42 | z = self.reparameterize(mu, logvar, nsamples) 43 | 44 | return z, (mu, logvar) 45 | 46 | def encode(self, input, nsamples): 47 | """perform the encoding and compute the KL term 48 | Returns: Tensor1, Tensor2 49 | Tensor1: the tensor latent z with shape [batch, nsamples, nz] 50 | Tensor2: the tensor of KL for each x with shape [batch] 51 | """ 52 | 53 | # (batch_size, nz) 54 | mu, logvar = self.forward(input) 55 | 56 | # (batch, nsamples, nz) 57 | z = self.reparameterize(mu, logvar, nsamples) 58 | 59 | KL = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1).sum(dim=1) 60 | 61 | return z, KL 62 | 63 | def reparameterize(self, mu, logvar, nsamples=1): 64 | """sample from posterior Gaussian family 65 | Args: 66 | mu: Tensor 67 | Mean of Gaussian distribution with shape (batch, nz) 68 | logvar: Tensor 69 | logvar of Gaussian distribution with shape (batch, nz) 70 | Returns: Tensor 71 | Sampled z with shape (batch, nsamples, nz) 72 | """ 73 | batch_size, nz = mu.size() 74 | std = logvar.mul(0.5).exp() 75 | 76 | mu_expd = mu.unsqueeze(1).expand(batch_size, nsamples, nz) 77 | std_expd = std.unsqueeze(1).expand(batch_size, nsamples, nz) 78 | 79 | eps = torch.zeros_like(std_expd).normal_() 80 | 81 | return mu_expd + torch.mul(eps, std_expd) 82 | 83 | def eval_inference_dist(self, x, z, param=None): 84 | """this function computes log q(z | x) 85 | Args: 86 | z: tensor 87 | different z points that will be evaluated, with 88 | shape [batch, nsamples, nz] 89 | Returns: Tensor1 90 | Tensor1: log q(z|x) with shape [batch, nsamples] 91 | """ 92 | 93 | nz = z.size(2) 94 | 95 | if not param: 96 | mu, logvar = self.forward(x) 97 | else: 98 | mu, logvar = param 99 | 100 | # (batch_size, 1, nz) 101 | mu, logvar = mu.unsqueeze(1), logvar.unsqueeze(1) 102 | var = logvar.exp() 103 | 104 | 
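# The computation below evaluates the diagonal-Gaussian density in closed form,
#   log q(z|x) = -0.5 * sum_d (z_d - mu_d)^2 / var_d - 0.5 * (nz * log(2*pi) + sum_d logvar_d),
# independently for each of the nsamples z points per batch element.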
# (batch_size, nsamples, nz) 105 | dev = z - mu 106 | 107 | # (batch_size, nsamples) 108 | log_density = -0.5 * ((dev ** 2) / var).sum(dim=-1) - \ 109 | 0.5 * (nz * math.log(2 * math.pi) + logvar.sum(-1)) 110 | 111 | return log_density 112 | 113 | 114 | 115 | def calc_mi(self, x): 116 | """Approximate the mutual information between x and z 117 | I(x, z) = E_xE_{q(z|x)}log(q(z|x)) - E_xE_{q(z|x)}log(q(z)) 118 | Returns: Float 119 | """ 120 | 121 | # [x_batch, nz] 122 | mu, logvar = self.forward(x) 123 | 124 | x_batch, nz = mu.size() 125 | 126 | # E_{q(z|x)}log(q(z|x)) = -0.5*nz*log(2*\pi) - 0.5*(1+logvar).sum(-1) 127 | neg_entropy = (-0.5 * nz * math.log(2 * math.pi)- 0.5 * (1 + logvar).sum(-1)).mean() 128 | 129 | # [z_batch, 1, nz] 130 | z_samples = self.reparameterize(mu, logvar, 1) 131 | 132 | # [1, x_batch, nz] 133 | mu, logvar = mu.unsqueeze(0), logvar.unsqueeze(0) 134 | var = logvar.exp() 135 | 136 | # (z_batch, x_batch, nz) 137 | dev = z_samples - mu 138 | 139 | # (z_batch, x_batch) 140 | log_density = -0.5 * ((dev ** 2) / var).sum(dim=-1) - \ 141 | 0.5 * (nz * math.log(2 * math.pi) + logvar.sum(-1)) 142 | 143 | # log q(z): aggregate posterior 144 | # [z_batch] 145 | log_qz = log_sum_exp(log_density, dim=1) - math.log(x_batch) 146 | 147 | return (neg_entropy - log_qz.mean(-1)).item() -------------------------------------------------------------------------------- /larimar_base/pytorch_transformers/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.2.0" 2 | # Work around to update TensorFlow's absl.logging threshold which alters the 3 | # default Python logging output behavior when present. 4 | # see: https://github.com/abseil/abseil-py/issues/99 5 | # and: https://github.com/tensorflow/tensorflow/issues/26691#issuecomment-500369493 6 | try: 7 | import absl.logging 8 | absl.logging.set_verbosity('info') 9 | absl.logging.set_stderrthreshold('info') 10 | absl.logging._warn_preinit_stderr = False 11 | except: 12 | pass 13 | 14 | # Tokenizer 15 | from .tokenization_utils import (PreTrainedTokenizer) 16 | from .tokenization_auto import AutoTokenizer 17 | from .tokenization_bert import BertTokenizer, BasicTokenizer, WordpieceTokenizer 18 | from .tokenization_openai import OpenAIGPTTokenizer 19 | from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus) 20 | from .tokenization_gpt2 import GPT2Tokenizer 21 | # from .tokenization_gptj import GPTJTokenizer 22 | from .tokenization_xlnet import XLNetTokenizer, SPIECE_UNDERLINE 23 | from .tokenization_xlm import XLMTokenizer 24 | from .tokenization_roberta import RobertaTokenizer 25 | from .tokenization_distilbert import DistilBertTokenizer 26 | 27 | # Configurations 28 | from .configuration_utils import PretrainedConfig 29 | from .configuration_auto import AutoConfig 30 | from .configuration_bert import BertConfig, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP 31 | from .configuration_openai import OpenAIGPTConfig, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP 32 | from .configuration_transfo_xl import TransfoXLConfig, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP 33 | from .configuration_gpt2 import GPT2Config, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP 34 | from .configuration_gptj import GPTJConfig, GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP 35 | from .configuration_xlnet import XLNetConfig, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP 36 | from .configuration_xlm import XLMConfig, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP 37 | from .configuration_roberta import RobertaConfig, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP 38 
| from .configuration_distilbert import DistilBertConfig, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP 39 | 40 | # Modeling 41 | from .modeling_utils import (PreTrainedModel, prune_layer, Conv1D) 42 | from .modeling_auto import (AutoModel, AutoModelForSequenceClassification, AutoModelForQuestionAnswering, 43 | AutoModelWithLMHead) 44 | 45 | from .modeling_bert import (BertPreTrainedModel, BertModel, BertForLatentConnector, BertForPreTraining,BertForSequenceClassificationLatentConnector, 46 | BertForMaskedLM, BertForNextSentencePrediction, 47 | BertForSequenceClassification, BertForMultipleChoice, 48 | BertForTokenClassification, BertForQuestionAnswering, 49 | load_tf_weights_in_bert)#, BERT_PRETRAINED_MODEL_ARCHIVE_MAP) 50 | from .modeling_openai import (OpenAIGPTPreTrainedModel, OpenAIGPTModel, 51 | OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel, 52 | load_tf_weights_in_openai_gpt, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP) 53 | from .modeling_transfo_xl import (TransfoXLPreTrainedModel, TransfoXLModel, TransfoXLLMHeadModel, 54 | load_tf_weights_in_transfo_xl, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP) 55 | from .modeling_gpt2 import (GPT2PreTrainedModel, GPT2Model, GPT2ForLatentConnector, 56 | GPT2LMHeadModel, GPT2DoubleHeadsModel, 57 | load_tf_weights_in_gpt2, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP) 58 | from .modeling_gptj import (GPTJPreTrainedModel, GPTJModel, GPTJForLatentConnector) 59 | from .modeling_xlnet import (XLNetPreTrainedModel, XLNetModel, XLNetLMHeadModel, 60 | XLNetForSequenceClassification, XLNetForQuestionAnswering, XLNetForMultipleChoice, 61 | load_tf_weights_in_xlnet, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP) 62 | from .modeling_xlm import (XLMPreTrainedModel , XLMModel, 63 | XLMWithLMHeadModel, XLMForSequenceClassification, 64 | XLMForQuestionAnswering, XLM_PRETRAINED_MODEL_ARCHIVE_MAP) 65 | from .modeling_roberta import (RobertaForMaskedLM, RobertaModel, RobertaForSequenceClassification, 66 | RobertaForMultipleChoice, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP) 67 | from .modeling_distilbert import (DistilBertForMaskedLM, DistilBertModel, 68 | DistilBertForSequenceClassification, DistilBertForQuestionAnswering, 69 | DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP) 70 | 71 | # Optimization 72 | from .optimization import (AdamW, WarmupLinearSchedule, ) 73 | 74 | # Files and general utilities 75 | from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE, 76 | cached_path, add_start_docstrings, add_end_docstrings, 77 | WEIGHTS_NAME, TF_WEIGHTS_NAME, CONFIG_NAME) 78 | -------------------------------------------------------------------------------- /larimar_base/pytorch_transformers/integrations/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
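The `_import_structure` table below drives a lazy-module pattern: importing `pytorch_transformers.integrations` is cheap, and each heavy backend (wandb, deepspeed, bitsandbytes, ...) is only imported when one of its names is first accessed. A minimal sketch of the idea (simplified; the real `_LazyModule` referenced from `utils` also handles pickling and module specs):

import importlib
import types

class LazyModule(types.ModuleType):
    """Sketch: resolve exported names to their defining submodules on first access."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        # invert {submodule: [names]} into {name: submodule}
        self._name_to_module = {
            attr: mod for mod, attrs in import_structure.items() for attr in attrs
        }

    def __getattr__(self, attr):
        # only called for attributes not found the normal way
        if attr not in self._name_to_module:
            raise AttributeError(f"module {self.__name__!r} has no attribute {attr!r}")
        submodule = importlib.import_module("." + self._name_to_module[attr], self.__name__)
        value = getattr(submodule, attr)
        setattr(self, attr, value)  # cache so __getattr__ fires only once per name
        return value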
14 | from typing import TYPE_CHECKING 15 | 16 | from ..utils import _LazyModule 17 | 18 | 19 | _import_structure = { 20 | "bitsandbytes": [ 21 | "get_keys_to_not_convert", 22 | "replace_8bit_linear", 23 | "replace_with_bnb_linear", 24 | "set_module_8bit_tensor_to_device", 25 | "set_module_quantized_tensor_to_device", 26 | ], 27 | "deepspeed": [ 28 | "HfDeepSpeedConfig", 29 | "HfTrainerDeepSpeedConfig", 30 | "deepspeed_config", 31 | "deepspeed_init", 32 | "deepspeed_load_checkpoint", 33 | "deepspeed_optim_sched", 34 | "is_deepspeed_available", 35 | "is_deepspeed_zero3_enabled", 36 | "set_hf_deepspeed_config", 37 | "unset_hf_deepspeed_config", 38 | ], 39 | "integration_utils": [ 40 | "INTEGRATION_TO_CALLBACK", 41 | "AzureMLCallback", 42 | "ClearMLCallback", 43 | "CodeCarbonCallback", 44 | "CometCallback", 45 | "DagsHubCallback", 46 | "FlyteCallback", 47 | "MLflowCallback", 48 | "NeptuneCallback", 49 | "NeptuneMissingConfiguration", 50 | "TensorBoardCallback", 51 | "WandbCallback", 52 | "get_available_reporting_integrations", 53 | "get_reporting_integration_callbacks", 54 | "hp_params", 55 | "is_azureml_available", 56 | "is_clearml_available", 57 | "is_codecarbon_available", 58 | "is_comet_available", 59 | "is_dagshub_available", 60 | "is_fairscale_available", 61 | "is_flyte_deck_standard_available", 62 | "is_flytekit_available", 63 | "is_mlflow_available", 64 | "is_neptune_available", 65 | "is_optuna_available", 66 | "is_ray_available", 67 | "is_ray_tune_available", 68 | "is_sigopt_available", 69 | "is_tensorboard_available", 70 | "is_wandb_available", 71 | "rewrite_logs", 72 | "run_hp_search_optuna", 73 | "run_hp_search_ray", 74 | "run_hp_search_sigopt", 75 | "run_hp_search_wandb", 76 | ], 77 | "peft": ["PeftAdapterMixin"], 78 | } 79 | 80 | if TYPE_CHECKING: 81 | from .bitsandbytes import ( 82 | get_keys_to_not_convert, 83 | replace_8bit_linear, 84 | replace_with_bnb_linear, 85 | set_module_8bit_tensor_to_device, 86 | set_module_quantized_tensor_to_device, 87 | ) 88 | from .deepspeed import ( 89 | HfDeepSpeedConfig, 90 | HfTrainerDeepSpeedConfig, 91 | deepspeed_config, 92 | deepspeed_init, 93 | deepspeed_load_checkpoint, 94 | deepspeed_optim_sched, 95 | is_deepspeed_available, 96 | is_deepspeed_zero3_enabled, 97 | set_hf_deepspeed_config, 98 | unset_hf_deepspeed_config, 99 | ) 100 | from .integration_utils import ( 101 | INTEGRATION_TO_CALLBACK, 102 | AzureMLCallback, 103 | ClearMLCallback, 104 | CodeCarbonCallback, 105 | CometCallback, 106 | DagsHubCallback, 107 | FlyteCallback, 108 | MLflowCallback, 109 | NeptuneCallback, 110 | NeptuneMissingConfiguration, 111 | TensorBoardCallback, 112 | WandbCallback, 113 | get_available_reporting_integrations, 114 | get_reporting_integration_callbacks, 115 | hp_params, 116 | is_azureml_available, 117 | is_clearml_available, 118 | is_codecarbon_available, 119 | is_comet_available, 120 | is_dagshub_available, 121 | is_fairscale_available, 122 | is_flyte_deck_standard_available, 123 | is_flytekit_available, 124 | is_mlflow_available, 125 | is_neptune_available, 126 | is_optuna_available, 127 | is_ray_available, 128 | is_ray_tune_available, 129 | is_sigopt_available, 130 | is_tensorboard_available, 131 | is_wandb_available, 132 | rewrite_logs, 133 | run_hp_search_optuna, 134 | run_hp_search_ray, 135 | run_hp_search_sigopt, 136 | run_hp_search_wandb, 137 | ) 138 | from .peft import PeftAdapterMixin 139 | else: 140 | import sys 141 | 142 | sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) 143 | 
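In use, only the first attribute access pays the import cost; a hedged example against the structure above (assuming the vendored `utils` module actually provides `_LazyModule` as in stock transformers, which this tree may not fully wire up):

from pytorch_transformers import integrations

# importing the package above pulled in none of the backends yet
print(integrations.is_wandb_available())   # first access imports integration_utils
callback_cls = integrations.WandbCallback  # later accesses hit the cached attribute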
-------------------------------------------------------------------------------- /larimar_base/pytorch_transformers/convert_pytorch_checkpoint_to_tf.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Convert Huggingface Pytorch checkpoint to Tensorflow checkpoint.""" 17 | 18 | import os 19 | import argparse 20 | import torch 21 | import numpy as np 22 | import tensorflow as tf 23 | from pytorch_transformers import BertModel 24 | 25 | 26 | def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:str): 27 | 28 | """ 29 | :param model:BertModel Pytorch model instance to be converted 30 | :param ckpt_dir: Tensorflow model directory 31 | :param model_name: model name 32 | :return: 33 | 34 | Currently supported HF models: 35 | Y BertModel 36 | N BertForMaskedLM 37 | N BertForPreTraining 38 | N BertForMultipleChoice 39 | N BertForNextSentencePrediction 40 | N BertForSequenceClassification 41 | N BertForQuestionAnswering 42 | """ 43 | 44 | tensors_to_transpose = ( 45 | "dense.weight", 46 | "attention.self.query", 47 | "attention.self.key", 48 | "attention.self.value" 49 | ) 50 | 51 | var_map = ( 52 | ('layer.', 'layer_'), 53 | ('word_embeddings.weight', 'word_embeddings'), 54 | ('position_embeddings.weight', 'position_embeddings'), 55 | ('token_type_embeddings.weight', 'token_type_embeddings'), 56 | ('.', '/'), 57 | ('LayerNorm/weight', 'LayerNorm/gamma'), 58 | ('LayerNorm/bias', 'LayerNorm/beta'), 59 | ('weight', 'kernel') 60 | ) 61 | 62 | if not os.path.isdir(ckpt_dir): 63 | os.makedirs(ckpt_dir) 64 | 65 | state_dict = model.state_dict() 66 | 67 | def to_tf_var_name(name:str): 68 | for patt, repl in iter(var_map): 69 | name = name.replace(patt, repl) 70 | return 'bert/{}'.format(name) 71 | 72 | def create_tf_var(tensor:np.ndarray, name:str, session:tf.Session): 73 | tf_dtype = tf.dtypes.as_dtype(tensor.dtype) 74 | tf_var = tf.get_variable(dtype=tf_dtype, shape=tensor.shape, name=name, initializer=tf.zeros_initializer()) 75 | session.run(tf.variables_initializer([tf_var])) 76 | session.run(tf_var) 77 | return tf_var 78 | 79 | tf.reset_default_graph() 80 | with tf.Session() as session: 81 | for var_name in state_dict: 82 | tf_name = to_tf_var_name(var_name) 83 | torch_tensor = state_dict[var_name].numpy() 84 | if any([x in var_name for x in tensors_to_transpose]): 85 | torch_tensor = torch_tensor.T 86 | tf_var = create_tf_var(tensor=torch_tensor, name=tf_name, session=session) 87 | tf.keras.backend.set_value(tf_var, torch_tensor) 88 | tf_weight = session.run(tf_var) 89 | print("Successfully created {}: {}".format(tf_name, np.allclose(tf_weight, torch_tensor))) 90 | 91 | saver = tf.train.Saver(tf.trainable_variables()) 92 | saver.save(session, os.path.join(ckpt_dir, model_name.replace("-", "_") + ".ckpt")) 93 | 94 | 95 | def main(raw_args=None): 96 | parser = 
argparse.ArgumentParser() 97 | parser.add_argument("--model_name", 98 | type=str, 99 | required=True, 100 | help="model name e.g. bert-base-uncased") 101 | parser.add_argument("--cache_dir", 102 | type=str, 103 | default=None, 104 | required=False, 105 | help="Directory containing pytorch model") 106 | parser.add_argument("--pytorch_model_path", 107 | type=str, 108 | required=True, 109 | help="/path/to/.bin") 110 | parser.add_argument("--tf_cache_dir", 111 | type=str, 112 | required=True, 113 | help="Directory in which to save tensorflow model") 114 | args = parser.parse_args(raw_args) 115 | 116 | model = BertModel.from_pretrained( 117 | pretrained_model_name_or_path=args.model_name, 118 | state_dict=torch.load(args.pytorch_model_path), 119 | cache_dir=args.cache_dir 120 | ) 121 | 122 | convert_pytorch_checkpoint_to_tf( 123 | model=model, 124 | ckpt_dir=args.tf_cache_dir, 125 | model_name=args.model_name 126 | ) 127 | 128 | 129 | if __name__ == "__main__": 130 | main() 131 | -------------------------------------------------------------------------------- /larimar_base/modules/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")): 5 | """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering 6 | Args: 7 | logits: logits distribution shape (vocabulary size) 8 | top_k >0: keep only top k tokens with highest probability (top-k filtering). 9 | top_p >0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). 10 | Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) 11 | """ 12 | assert ( 13 | logits.dim() == 1 14 | ) # batch size 1 for now - could be updated for more but the code would be less clear 15 | top_k = min(top_k, logits.size(-1)) # Safety check 16 | if top_k > 0: 17 | # Remove all tokens with a probability less than the last token of the top-k 18 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] 19 | logits[indices_to_remove] = filter_value 20 | 21 | if top_p > 0.0: 22 | sorted_logits, sorted_indices = torch.sort(logits, descending=True) 23 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) 24 | 25 | # Remove tokens with cumulative probability above the threshold 26 | sorted_indices_to_remove = cumulative_probs > top_p 27 | # Shift the indices to the right to keep also the first token above the threshold 28 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 29 | sorted_indices_to_remove[..., 0] = 0 30 | 31 | indices_to_remove = sorted_indices[sorted_indices_to_remove] 32 | logits[indices_to_remove] = filter_value 33 | return logits 34 | 35 | def top_k_top_p_filtering_batch(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')): 36 | """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering 37 | Args: 38 | logits: logits distribution shape (vocabulary size) 39 | top_k > 0: keep only top k tokens with highest probability (top-k filtering). 40 | top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). 41 | Nucleus filtering is described in Holtzman et al. 
(http://arxiv.org/abs/1904.09751) 42 | From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 43 | """ 44 | # assert logits.dim() == 1 # batch size 1 for now - could be updated for more but the code would be less clear 45 | 46 | top_k = min(top_k, logits.size(-1)) # Safety check 47 | 48 | if top_k > 0: 49 | # Remove all tokens with a probability less than the last token of the top-k 50 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] 51 | # logits.masked_fill_(logits < threshold, filter_value) # (B, vocab_size) 52 | logits[indices_to_remove] = filter_value 53 | 54 | if top_p > 0.0: 55 | sorted_logits, sorted_indices = torch.sort(logits, descending=True) # (B, vocab_size) 56 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) # (B, vocab_size) 57 | 58 | # Remove tokens with cumulative probability above the threshold 59 | sorted_indices_to_remove = cumulative_probs > top_p 60 | 61 | # Shift the indices to the right to keep also the first token above the threshold 62 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 63 | sorted_indices_to_remove[..., 0] = 0 64 | 65 | # indices_to_remove = sorted_indices[sorted_indices_to_remove] 66 | 67 | # logits.masked_fill_(indices_to_remove, filter_value) 68 | indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove) 69 | logits[indices_to_remove] = filter_value 70 | 71 | return logits 72 | 73 | def safe_log(z): 74 | return torch.log(z + 1e-7) 75 | 76 | def log_sum_exp(value, dim=None, keepdim=False): 77 | """Numerically stable implementation of the operation 78 | value.exp().sum(dim, keepdim).log() 79 | """ 80 | if dim is not None: 81 | m, _ = torch.max(value, dim=dim, keepdim=True) 82 | value0 = value - m 83 | if keepdim is False: 84 | m = m.squeeze(dim) 85 | return m + torch.log(torch.sum(torch.exp(value0), dim=dim, keepdim=keepdim)) 86 | else: 87 | m = torch.max(value) 88 | sum_exp = torch.sum(torch.exp(value - m)) 89 | return m + torch.log(sum_exp) 90 | 91 | 92 | def generate_grid(zmin, zmax, dz, device, ndim=2): 93 | """generate a 1- or 2-dimensional grid 94 | Returns: Tensor, int 95 | Tensor: The grid tensor with shape (k^2, 2), 96 | where k=(zmax - zmin)/dz 97 | int: k 98 | """ 99 | 100 | if ndim == 2: 101 | x = torch.arange(zmin, zmax, dz) 102 | k = x.size(0) 103 | 104 | x1 = x.unsqueeze(1).repeat(1, k).view(-1) 105 | x2 = x.repeat(k) 106 | 107 | return torch.cat((x1.unsqueeze(-1), x2.unsqueeze(-1)), dim=-1).to(device), k 108 | 109 | elif ndim == 1: 110 | return torch.arange(zmin, zmax, dz).unsqueeze(1).to(device) 111 | -------------------------------------------------------------------------------- /larimar_base/pytorch_transformers/utils/peft_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
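Looking back at `modules/utils.py` above, a short sampling sketch using `top_k_top_p_filtering_batch` (the logits here are random stand-ins, and the import path assumes `larimar_base` is on `sys.path`):

import torch
import torch.nn.functional as F

from modules.utils import top_k_top_p_filtering_batch

logits = torch.randn(2, 50257)  # (batch, vocab_size), e.g. a GPT-2 sized vocab
filtered = top_k_top_p_filtering_batch(logits.clone(), top_k=50, top_p=0.95)
probs = F.softmax(filtered, dim=-1)  # filtered entries were set to -inf, i.e. probability 0
next_token = torch.multinomial(probs, num_samples=1)  # (batch, 1)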
14 | import importlib 15 | import os 16 | from typing import Dict, Optional, Union 17 | 18 | from packaging import version 19 | 20 | from .hub import cached_file 21 | from .import_utils import is_peft_available 22 | 23 | 24 | ADAPTER_CONFIG_NAME = "adapter_config.json" 25 | ADAPTER_WEIGHTS_NAME = "adapter_model.bin" 26 | ADAPTER_SAFE_WEIGHTS_NAME = "adapter_model.safetensors" 27 | 28 | 29 | def find_adapter_config_file( 30 | model_id: str, 31 | cache_dir: Optional[Union[str, os.PathLike]] = None, 32 | force_download: bool = False, 33 | resume_download: bool = False, 34 | proxies: Optional[Dict[str, str]] = None, 35 | token: Optional[Union[bool, str]] = None, 36 | revision: Optional[str] = None, 37 | local_files_only: bool = False, 38 | subfolder: str = "", 39 | _commit_hash: Optional[str] = None, 40 | ) -> Optional[str]: 41 | r""" 42 | Checks whether the model stored on the Hub or locally is an adapter model, returning the path to the adapter 43 | config file if it is, and None otherwise. 44 | 45 | Args: 46 | model_id (`str`): 47 | The identifier of the model to look for, can be either a local path or an id to the repository on the Hub. 48 | cache_dir (`str` or `os.PathLike`, *optional*): 49 | Path to a directory in which a downloaded pretrained model configuration should be cached if the standard 50 | cache should not be used. 51 | force_download (`bool`, *optional*, defaults to `False`): 52 | Whether or not to force to (re-)download the configuration files and override the cached versions if they 53 | exist. 54 | resume_download (`bool`, *optional*, defaults to `False`): 55 | Whether or not to delete an incompletely received file. Attempts to resume the download if such a file exists. 56 | proxies (`Dict[str, str]`, *optional*): 57 | A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 58 | 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. 59 | token (`str` or *bool*, *optional*): 60 | The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated 61 | when running `huggingface-cli login` (stored in `~/.huggingface`). 62 | revision (`str`, *optional*, defaults to `"main"`): 63 | The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a 64 | git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any 65 | identifier allowed by git. 66 | 67 | 68 | 69 | To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`. 70 | 71 | 72 | 73 | local_files_only (`bool`, *optional*, defaults to `False`): 74 | If `True`, will only try to load the tokenizer configuration from local files. 75 | subfolder (`str`, *optional*, defaults to `""`): 76 | In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can 77 | specify the folder name here. 
78 | """ 79 | adapter_cached_filename = None 80 | if model_id is None: 81 | return None 82 | elif os.path.isdir(model_id): 83 | list_remote_files = os.listdir(model_id) 84 | if ADAPTER_CONFIG_NAME in list_remote_files: 85 | adapter_cached_filename = os.path.join(model_id, ADAPTER_CONFIG_NAME) 86 | else: 87 | adapter_cached_filename = cached_file( 88 | model_id, 89 | ADAPTER_CONFIG_NAME, 90 | cache_dir=cache_dir, 91 | force_download=force_download, 92 | resume_download=resume_download, 93 | proxies=proxies, 94 | token=token, 95 | revision=revision, 96 | local_files_only=local_files_only, 97 | subfolder=subfolder, 98 | _commit_hash=_commit_hash, 99 | _raise_exceptions_for_missing_entries=False, 100 | _raise_exceptions_for_connection_errors=False, 101 | ) 102 | 103 | return adapter_cached_filename 104 | 105 | 106 | def check_peft_version(min_version: str) -> None: 107 | r""" 108 | Checks if the version of PEFT is compatible. 109 | 110 | Args: 111 | version (`str`): 112 | The version of PEFT to check against. 113 | """ 114 | if not is_peft_available(): 115 | raise ValueError("PEFT is not installed. Please install it with `pip install peft`") 116 | 117 | is_peft_version_compatible = version.parse(importlib.metadata.version("peft")) >= version.parse(min_version) 118 | 119 | if not is_peft_version_compatible: 120 | raise ValueError( 121 | f"The version of PEFT you are using is not compatible, please use a version that is greater" 122 | f" than {min_version}" 123 | ) 124 | -------------------------------------------------------------------------------- /larimar_base/pytorch_transformers/tests/tokenization_xlnet_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | 20 | from pytorch_transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE) 21 | 22 | from .tokenization_tests_commons import CommonTestCases 23 | 24 | SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), 25 | 'fixtures/test_sentencepiece.model') 26 | 27 | class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester): 28 | 29 | tokenizer_class = XLNetTokenizer 30 | 31 | def setUp(self): 32 | super(XLNetTokenizationTest, self).setUp() 33 | 34 | # We have a SentencePiece fixture for testing 35 | tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True) 36 | tokenizer.save_pretrained(self.tmpdirname) 37 | 38 | def get_tokenizer(self, **kwargs): 39 | return XLNetTokenizer.from_pretrained(self.tmpdirname, **kwargs) 40 | 41 | def get_input_output_texts(self): 42 | input_text = u"This is a test" 43 | output_text = u"This is a test" 44 | return input_text, output_text 45 | 46 | 47 | def test_full_tokenizer(self): 48 | tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True) 49 | 50 | tokens = tokenizer.tokenize(u'This is a test') 51 | self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est']) 52 | 53 | self.assertListEqual( 54 | tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382]) 55 | 56 | tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.") 57 | self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', 58 | u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'', 59 | u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this', 60 | SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', u'é', u'.']) 61 | ids = tokenizer.convert_tokens_to_ids(tokens) 62 | self.assertListEqual( 63 | ids, [8, 21, 84, 55, 24, 19, 7, 0, 64 | 602, 347, 347, 347, 3, 12, 66, 65 | 46, 72, 80, 6, 0, 4]) 66 | 67 | back_tokens = tokenizer.convert_ids_to_tokens(ids) 68 | self.assertListEqual(back_tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', 69 | u'or', u'n', SPIECE_UNDERLINE + u'in', 70 | SPIECE_UNDERLINE + u'', u'', u'2', u'0', u'0', u'0', u',', 71 | SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this', 72 | SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', 73 | u'', u'.']) 74 | 75 | def test_tokenizer_lower(self): 76 | tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=True) 77 | tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.") 78 | self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'', u'i', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', 79 | u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'', 80 | u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this', 81 | SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u'se', u'.']) 82 | self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), [u"▁he", u"ll", u"o"]) 83 | 84 | def test_tokenizer_no_lower(self): 85 | tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=False) 86 | tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.") 87 | self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', u'or', 88 | u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'', 89 | u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this', 90 | 
SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u'se', u'.']) 91 | 92 | def test_sequence_builders(self): 93 | tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased") 94 | 95 | text = tokenizer.encode("sequence builders") 96 | text_2 = tokenizer.encode("multi-sequence build") 97 | 98 | encoded_sentence = tokenizer.add_special_tokens_single_sentence(text) 99 | encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2) 100 | 101 | assert encoded_sentence == text + [4, 3] 102 | assert encoded_pair == text + [4] + text_2 + [4, 3] 103 | 104 | 105 | if __name__ == '__main__': 106 | unittest.main() 107 | -------------------------------------------------------------------------------- /larimar_base/pytorch_transformers/utils/hp_naming.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import copy 16 | import re 17 | 18 | 19 | class TrialShortNamer: 20 | PREFIX = "hp" 21 | DEFAULTS = {} 22 | NAMING_INFO = None 23 | 24 | @classmethod 25 | def set_defaults(cls, prefix, defaults): 26 | cls.PREFIX = prefix 27 | cls.DEFAULTS = defaults 28 | cls.build_naming_info() 29 | 30 | @staticmethod 31 | def shortname_for_word(info, word): 32 | if len(word) == 0: 33 | return "" 34 | short_word = None 35 | if any(char.isdigit() for char in word): 36 | raise Exception(f"Parameters should not contain numbers: '{word}' contains a number") 37 | if word in info["short_word"]: 38 | return info["short_word"][word] 39 | for prefix_len in range(1, len(word) + 1): 40 | prefix = word[:prefix_len] 41 | if prefix in info["reverse_short_word"]: 42 | continue 43 | else: 44 | short_word = prefix 45 | break 46 | 47 | if short_word is None: 48 | # Paranoid fallback 49 | def int_to_alphabetic(integer): 50 | s = "" 51 | while integer != 0: 52 | s = chr(ord("A") + integer % 10) + s 53 | integer //= 10 54 | return s 55 | 56 | i = 0 57 | while True: 58 | sword = word + "#" + int_to_alphabetic(i) 59 | if sword in info["reverse_short_word"]: 60 | continue 61 | else: 62 | short_word = sword 63 | break 64 | 65 | info["short_word"][word] = short_word 66 | info["reverse_short_word"][short_word] = word 67 | return short_word 68 | 69 | @staticmethod 70 | def shortname_for_key(info, param_name): 71 | words = param_name.split("_") 72 | 73 | shortname_parts = [TrialShortNamer.shortname_for_word(info, word) for word in words] 74 | 75 | # We try to create a separatorless short name, but if there is a collision we have to fallback 76 | # to a separated short name 77 | separators = ["", "_"] 78 | 79 | for separator in separators: 80 | shortname = separator.join(shortname_parts) 81 | if shortname not in info["reverse_short_param"]: 82 | info["short_param"][param_name] = shortname 83 | info["reverse_short_param"][shortname] = param_name 84 | return shortname 85 | 86 | return param_name 87 | 88 | @staticmethod 89 | def 
add_new_param_name(info, param_name): 90 | short_name = TrialShortNamer.shortname_for_key(info, param_name) 91 | info["short_param"][param_name] = short_name 92 | info["reverse_short_param"][short_name] = param_name 93 | 94 | @classmethod 95 | def build_naming_info(cls): 96 | if cls.NAMING_INFO is not None: 97 | return 98 | 99 | info = { 100 | "short_word": {}, 101 | "reverse_short_word": {}, 102 | "short_param": {}, 103 | "reverse_short_param": {}, 104 | } 105 | 106 | field_keys = list(cls.DEFAULTS.keys()) 107 | 108 | for k in field_keys: 109 | cls.add_new_param_name(info, k) 110 | 111 | cls.NAMING_INFO = info 112 | 113 | @classmethod 114 | def shortname(cls, params): 115 | cls.build_naming_info() 116 | assert cls.PREFIX is not None 117 | name = [copy.copy(cls.PREFIX)] 118 | 119 | for k, v in params.items(): 120 | if k not in cls.DEFAULTS: 121 | raise Exception(f"You should provide a default value for the param name {k} with value {v}") 122 | if v == cls.DEFAULTS[k]: 123 | # The default value is not added to the name 124 | continue 125 | 126 | key = cls.NAMING_INFO["short_param"][k] 127 | 128 | if isinstance(v, bool): 129 | v = 1 if v else 0 130 | 131 | sep = "" if isinstance(v, (int, float)) else "-" 132 | e = f"{key}{sep}{v}" 133 | name.append(e) 134 | 135 | return "_".join(name) 136 | 137 | @classmethod 138 | def parse_repr(cls, repr): 139 | repr = repr[len(cls.PREFIX) + 1 :] 140 | if repr == "": 141 | values = [] 142 | else: 143 | values = repr.split("_") 144 | 145 | parameters = {} 146 | 147 | for value in values: 148 | if "-" in value: 149 | p_k, p_v = value.split("-") 150 | else: 151 | p_k = re.sub("[0-9.]", "", value) 152 | p_v = float(re.sub("[^0-9.]", "", value)) 153 | 154 | key = cls.NAMING_INFO["reverse_short_param"][p_k] 155 | 156 | parameters[key] = p_v 157 | 158 | for k in cls.DEFAULTS: 159 | if k not in parameters: 160 | parameters[k] = cls.DEFAULTS[k] 161 | 162 | return parameters 163 | -------------------------------------------------------------------------------- /larimar_base/pytorch_transformers/configuration_openai.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ OpenAI GPT configuration """ 17 | 18 | from __future__ import absolute_import, division, print_function, unicode_literals 19 | 20 | import json 21 | import logging 22 | import sys 23 | from io import open 24 | 25 | from .configuration_utils import PretrainedConfig 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP = { 30 | "openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json" 31 | } 32 | 33 | class OpenAIGPTConfig(PretrainedConfig): 34 | """ 35 | Configuration class to store the configuration of a `OpenAIGPTModel`. 
36 | 37 | Args: 38 | vocab_size_or_config_json_file: Vocabulary size of `input_ids` in `OpenAIGPTModel` or a configuration json file. 39 | n_special: The number of special tokens to learn during fine-tuning ('[SEP]', '[CLF]', ...) 40 | n_positions: Number of positional embeddings. 41 | n_ctx: Size of the causal mask (usually same as n_positions). 42 | n_embd: Dimensionality of the embeddings and hidden states. 43 | n_layer: Number of hidden layers in the Transformer encoder. 44 | n_head: Number of attention heads for each attention layer in 45 | the Transformer encoder. 46 | afn: The non-linear activation function (function or string) in the 47 | encoder and pooler. If string, "gelu", "relu" and "swish" are supported. 48 | resid_pdrop: The dropout probability for all fully connected 49 | layers in the embeddings, encoder, and pooler. 50 | attn_pdrop: The dropout ratio for the attention 51 | probabilities. 52 | embd_pdrop: The dropout ratio for the embeddings. 53 | layer_norm_epsilon: epsilon to use in the layer norm layers 54 | initializer_range: The stddev of the truncated_normal_initializer for 55 | initializing all weight matrices. 56 | predict_special_tokens: should we predict special tokens (when the model has an LM head) 57 | """ 58 | pretrained_config_archive_map = OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP 59 | 60 | def __init__( 61 | self, 62 | vocab_size_or_config_json_file=40478, 63 | n_positions=512, 64 | n_ctx=512, 65 | n_embd=768, 66 | n_layer=12, 67 | n_head=12, 68 | afn="gelu", 69 | resid_pdrop=0.1, 70 | embd_pdrop=0.1, 71 | attn_pdrop=0.1, 72 | layer_norm_epsilon=1e-5, 73 | initializer_range=0.02, 74 | predict_special_tokens=True, 75 | 76 | num_labels=1, 77 | summary_type='cls_index', 78 | summary_use_proj=True, 79 | summary_activation=None, 80 | summary_proj_to_labels=True, 81 | summary_first_dropout=0.1, 82 | **kwargs 83 | ): 84 | """Constructs OpenAIGPTConfig. 
85 | """ 86 | super(OpenAIGPTConfig, self).__init__(**kwargs) 87 | 88 | if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 89 | and isinstance(vocab_size_or_config_json_file, unicode)): 90 | with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader: 91 | json_config = json.loads(reader.read()) 92 | for key, value in json_config.items(): 93 | self.__dict__[key] = value 94 | elif isinstance(vocab_size_or_config_json_file, int): 95 | self.vocab_size = vocab_size_or_config_json_file 96 | self.n_ctx = n_ctx 97 | self.n_positions = n_positions 98 | self.n_embd = n_embd 99 | self.n_layer = n_layer 100 | self.n_head = n_head 101 | self.afn = afn 102 | self.resid_pdrop = resid_pdrop 103 | self.embd_pdrop = embd_pdrop 104 | self.attn_pdrop = attn_pdrop 105 | self.layer_norm_epsilon = layer_norm_epsilon 106 | self.initializer_range = initializer_range 107 | self.predict_special_tokens = predict_special_tokens 108 | 109 | self.num_labels = num_labels 110 | self.summary_type = summary_type 111 | self.summary_use_proj = summary_use_proj 112 | self.summary_activation = summary_activation 113 | self.summary_first_dropout = summary_first_dropout 114 | self.summary_proj_to_labels = summary_proj_to_labels 115 | else: 116 | raise ValueError( 117 | "First argument must be either a vocabulary size (int)" 118 | "or the path to a pretrained model config file (str)" 119 | ) 120 | 121 | @property 122 | def max_position_embeddings(self): 123 | return self.n_positions 124 | 125 | @property 126 | def hidden_size(self): 127 | return self.n_embd 128 | 129 | @property 130 | def num_attention_heads(self): 131 | return self.n_head 132 | 133 | @property 134 | def num_hidden_layers(self): 135 | return self.n_layer 136 | -------------------------------------------------------------------------------- /larimar_base/pytorch_transformers/tests/tokenization_bert_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | from io import open 20 | 21 | from pytorch_transformers.tokenization_bert import (BasicTokenizer, 22 | BertTokenizer, 23 | WordpieceTokenizer, 24 | _is_control, _is_punctuation, 25 | _is_whitespace, VOCAB_FILES_NAMES) 26 | 27 | from .tokenization_tests_commons import CommonTestCases 28 | 29 | class BertTokenizationTest(CommonTestCases.CommonTokenizerTester): 30 | 31 | tokenizer_class = BertTokenizer 32 | 33 | def setUp(self): 34 | super(BertTokenizationTest, self).setUp() 35 | 36 | vocab_tokens = [ 37 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 38 | "##ing", ",", "low", "lowest", 39 | ] 40 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) 41 | with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer: 42 | vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) 43 | 44 | def get_tokenizer(self, **kwargs): 45 | return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs) 46 | 47 | def get_input_output_texts(self): 48 | input_text = u"UNwant\u00E9d,running" 49 | output_text = u"unwanted, running" 50 | return input_text, output_text 51 | 52 | def test_full_tokenizer(self): 53 | tokenizer = self.tokenizer_class(self.vocab_file) 54 | 55 | tokens = tokenizer.tokenize(u"UNwant\u00E9d,running") 56 | self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"]) 57 | self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9]) 58 | 59 | def test_chinese(self): 60 | tokenizer = BasicTokenizer() 61 | 62 | self.assertListEqual( 63 | tokenizer.tokenize(u"ah\u535A\u63A8zz"), 64 | [u"ah", u"\u535A", u"\u63A8", u"zz"]) 65 | 66 | def test_basic_tokenizer_lower(self): 67 | tokenizer = BasicTokenizer(do_lower_case=True) 68 | 69 | self.assertListEqual( 70 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "), 71 | ["hello", "!", "how", "are", "you", "?"]) 72 | self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"]) 73 | 74 | def test_basic_tokenizer_no_lower(self): 75 | tokenizer = BasicTokenizer(do_lower_case=False) 76 | 77 | self.assertListEqual( 78 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? 
"), 79 | ["HeLLo", "!", "how", "Are", "yoU", "?"]) 80 | 81 | def test_wordpiece_tokenizer(self): 82 | vocab_tokens = [ 83 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 84 | "##ing" 85 | ] 86 | 87 | vocab = {} 88 | for (i, token) in enumerate(vocab_tokens): 89 | vocab[token] = i 90 | tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]") 91 | 92 | self.assertListEqual(tokenizer.tokenize(""), []) 93 | 94 | self.assertListEqual( 95 | tokenizer.tokenize("unwanted running"), 96 | ["un", "##want", "##ed", "runn", "##ing"]) 97 | 98 | self.assertListEqual( 99 | tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"]) 100 | 101 | def test_is_whitespace(self): 102 | self.assertTrue(_is_whitespace(u" ")) 103 | self.assertTrue(_is_whitespace(u"\t")) 104 | self.assertTrue(_is_whitespace(u"\r")) 105 | self.assertTrue(_is_whitespace(u"\n")) 106 | self.assertTrue(_is_whitespace(u"\u00A0")) 107 | 108 | self.assertFalse(_is_whitespace(u"A")) 109 | self.assertFalse(_is_whitespace(u"-")) 110 | 111 | def test_is_control(self): 112 | self.assertTrue(_is_control(u"\u0005")) 113 | 114 | self.assertFalse(_is_control(u"A")) 115 | self.assertFalse(_is_control(u" ")) 116 | self.assertFalse(_is_control(u"\t")) 117 | self.assertFalse(_is_control(u"\r")) 118 | 119 | def test_is_punctuation(self): 120 | self.assertTrue(_is_punctuation(u"-")) 121 | self.assertTrue(_is_punctuation(u"$")) 122 | self.assertTrue(_is_punctuation(u"`")) 123 | self.assertTrue(_is_punctuation(u".")) 124 | 125 | self.assertFalse(_is_punctuation(u"A")) 126 | self.assertFalse(_is_punctuation(u" ")) 127 | 128 | def test_sequence_builders(self): 129 | tokenizer = self.tokenizer_class.from_pretrained("bert-base-uncased") 130 | 131 | text = tokenizer.encode("sequence builders") 132 | text_2 = tokenizer.encode("multi-sequence build") 133 | 134 | encoded_sentence = tokenizer.add_special_tokens_single_sentence(text) 135 | encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2) 136 | 137 | assert encoded_sentence == [101] + text + [102] 138 | assert encoded_pair == [101] + text + [102] + text_2 + [102] 139 | 140 | if __name__ == '__main__': 141 | unittest.main() 142 | -------------------------------------------------------------------------------- /larimar_base/pytorch_transformers/convert_transfo_xl_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Convert Transformer XL checkpoint and datasets.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | import os 21 | import sys 22 | from io import open 23 | 24 | import torch 25 | 26 | import pytorch_transformers.tokenization_transfo_xl as data_utils 27 | 28 | from pytorch_transformers import CONFIG_NAME, WEIGHTS_NAME 29 | from pytorch_transformers import (TransfoXLConfig, TransfoXLLMHeadModel, 30 | load_tf_weights_in_transfo_xl) 31 | from pytorch_transformers.tokenization_transfo_xl import (CORPUS_NAME, VOCAB_FILES_NAMES) 32 | 33 | if sys.version_info[0] == 2: 34 | import cPickle as pickle 35 | else: 36 | import pickle 37 | 38 | import logging 39 | logging.basicConfig(level=logging.INFO) 40 | 41 | # We do this to be able to load python 2 datasets pickles 42 | # See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918 43 | data_utils.Vocab = data_utils.TransfoXLTokenizer 44 | data_utils.Corpus = data_utils.TransfoXLCorpus 45 | sys.modules['data_utils'] = data_utils 46 | sys.modules['vocabulary'] = data_utils 47 | 48 | def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path, 49 | transfo_xl_config_file, 50 | pytorch_dump_folder_path, 51 | transfo_xl_dataset_file): 52 | if transfo_xl_dataset_file: 53 | # Convert a pre-processed corpus (see original TensorFlow repo) 54 | with open(transfo_xl_dataset_file, "rb") as fp: 55 | corpus = pickle.load(fp, encoding="latin1") 56 | # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term) 57 | pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_FILES_NAMES['pretrained_vocab_file'] 58 | print("Save vocabulary to {}".format(pytorch_vocab_dump_path)) 59 | corpus_vocab_dict = corpus.vocab.__dict__ 60 | torch.save(corpus_vocab_dict, pytorch_vocab_dump_path) 61 | 62 | corpus_dict_no_vocab = corpus.__dict__ 63 | corpus_dict_no_vocab.pop('vocab', None) 64 | pytorch_dataset_dump_path = pytorch_dump_folder_path + '/' + CORPUS_NAME 65 | print("Save dataset to {}".format(pytorch_dataset_dump_path)) 66 | torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path) 67 | 68 | if tf_checkpoint_path: 69 | # Convert a pre-trained TensorFlow model 70 | config_path = os.path.abspath(transfo_xl_config_file) 71 | tf_path = os.path.abspath(tf_checkpoint_path) 72 | 73 | print("Converting Transformer XL checkpoint from {} with config at {}".format(tf_path, config_path)) 74 | # Initialise PyTorch model 75 | if transfo_xl_config_file == "": 76 | config = TransfoXLConfig() 77 | else: 78 | config = TransfoXLConfig.from_json_file(transfo_xl_config_file) 79 | print("Building PyTorch model from configuration: {}".format(str(config))) 80 | model = TransfoXLLMHeadModel(config) 81 | 82 | model = load_tf_weights_in_transfo_xl(model, config, tf_path) 83 | # Save pytorch-model 84 | pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME) 85 | pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME) 86 | print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path))) 87 | torch.save(model.state_dict(), pytorch_weights_dump_path) 88 | print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path))) 89 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 90 | f.write(config.to_json_string()) 91 | 92 | 93 | if __name__ == "__main__": 94 | parser = argparse.ArgumentParser() 95 | 
parser.add_argument("--pytorch_dump_folder_path", 96 | default = None, 97 | type = str, 98 | required = True, 99 | help = "Path to the folder to store the PyTorch model or dataset/vocab.") 100 | parser.add_argument("--tf_checkpoint_path", 101 | default = "", 102 | type = str, 103 | help = "An optional path to a TensorFlow checkpoint path to be converted.") 104 | parser.add_argument("--transfo_xl_config_file", 105 | default = "", 106 | type = str, 107 | help = "An optional config json file corresponding to the pre-trained Transformer XL model. \n" 108 | "This specifies the model architecture.") 109 | parser.add_argument("--transfo_xl_dataset_file", 110 | default = "", 111 | type = str, 112 | help = "An optional dataset file to be converted into a vocabulary.") 113 | args = parser.parse_args() 114 | convert_transfo_xl_checkpoint_to_pytorch(args.tf_checkpoint_path, 115 | args.transfo_xl_config_file, 116 | args.pytorch_dump_folder_path, 117 | args.transfo_xl_dataset_file) 118 | -------------------------------------------------------------------------------- /larimar_base/modules/spacefusion.py: -------------------------------------------------------------------------------- 1 | from .vae import VAE 2 | import numpy as np 3 | import torch, copy, pdb 4 | import torch.nn.functional as F 5 | 6 | from torch import nn 7 | 8 | import pdb 9 | 10 | 11 | def set_trainable(module, value): 12 | for param in module.parameters(): 13 | param.requires_grad = value 14 | 15 | class SpaceFusion(VAE): 16 | def __init__(self, encoder, decoder, tokenizer_encoder, tokenizer_decoder, args): 17 | super(SpaceFusion, self).__init__(encoder, decoder, tokenizer_encoder, tokenizer_decoder, args) 18 | children = [v for v in encoder.encoder.layer.children()] # list of 12 BertLayer 19 | 20 | self.num_s2s_bert_layer = args.num_s2s_bert_layer 21 | self.S2S_layers = nn.ModuleList([copy.deepcopy(c) for c in children[-args.num_s2s_bert_layer:] ]) # copies of the last num_s2s_bert_layer encoder layers 22 | self.S2S_pooler = copy.deepcopy(encoder.pooler) 23 | self.ix_turn_sep = tokenizer_encoder.convert_tokens_to_ids('[SEP]') 24 | if args.freeze_bert: 25 | print('@'*20 + f' freezing BERT {args.num_frozen_bert_layer} layers') 26 | for child in children[:args.num_frozen_bert_layer]: 27 | set_trainable(child, False) 28 | 29 | 30 | 31 | def ids2speaker(self, ids): 32 | # 0 for speaker A, 1 for speaker B 33 | N, T = ids.shape 34 | speaker = np.zeros((N, T)) 35 | sep = ids == self.ix_turn_sep 36 | for i in range(N): 37 | is_B = False # start with speaker A 38 | for t in range(T): 39 | speaker[i,t] = int(is_B) 40 | if sep[i,t].item(): 41 | is_B = not is_B 42 | 43 | # make sure the final speaker is speaker B (so response is always speaker A) 44 | if not is_B: 45 | speaker = 1 - speaker 46 | 47 | return torch.LongTensor(speaker).to(ids.device) 48 | 49 | def forward(self, inputs_src, inputs_tgt, labels_tgt, return_vec=False): # [batch, time] 50 | # toggle config to get desired encoder output 51 | self.encoder.encoder.output_attentions = False 52 | self.encoder.encoder.output_hidden_states = True 53 | 54 | 55 | # AE encoder 56 | mask = (inputs_tgt > 0).float().to(inputs_src.device) 57 | outputs = self.encoder(inputs_tgt, attention_mask=mask) 58 | z_AE, _ = self.connect(outputs[1]) 59 | z_AE = z_AE.squeeze(1) 60 | 61 | # S2S encoder 62 | mask = (inputs_src > 0).float() 63 | speaker = self.ids2speaker(inputs_src) 64 | outputs = self.encoder(inputs_src, attention_mask=mask, token_type_ids=speaker) 65 | _, _, all_layer_attn = outputs # last_layer_attn, pooled, all_layer_attn = outputs
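# Note: despite its name, all_layer_attn holds the per-layer hidden states
# (output_hidden_states=True was set above). The S2S path restarts from the
# hidden state just below the copied top blocks, so the AE and S2S encoders
# share the (optionally frozen) lower BERT layers and specialize only at the top.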
66 |         seq_z_prev = all_layer_attn[-self.num_s2s_bert_layer-1]  # hidden states just below the S2S layers
67 | 
68 |         for s2s in self.S2S_layers:
69 |             layer_outputs = s2s(seq_z_prev, attention_mask=mask.unsqueeze(1).unsqueeze(1))
70 |             seq_z_prev = layer_outputs[0]
71 | 
72 |         z_S2S = self.encoder.pooler(layer_outputs[0])
73 |         z_S2S, _ = self.connect(z_S2S)
74 |         z_S2S = z_S2S.squeeze(1)
75 | 
76 |         if return_vec:
77 |             return z_AE, z_S2S
78 | 
79 |         # interpolation/smoothness
80 |         u = torch.FloatTensor(np.random.random((z_AE.shape[0], 1))).to(inputs_tgt.device)
81 |         z_interp = u * z_AE + (1 - u) * z_S2S
82 |         std = 0.1
83 |         noise = torch.FloatTensor(np.random.normal(size=z_interp.shape) * std).to(z_interp.device)
84 |         z_interp = z_interp + noise
85 | 
86 |         loss_rec = 0
87 |         z_idx = 0
88 |         for z in [z_AE, z_S2S, z_interp]:
89 | 
90 |             past = z  # past = self.decoder.linear(z)
91 |             outputs = self.decoder(input_ids=labels_tgt, past=past, labels=labels_tgt, label_ignore=self.pad_token_id)
92 |             if z_idx == 1:
93 |                 loss_rec = loss_rec + 1.0 * outputs[0]  # weight for the S2S reconstruction term
94 |             else:
95 |                 loss_rec = loss_rec + outputs[0]
96 |             z_idx += 1
97 |         loss_rec = loss_rec / 3
98 | 
99 |         # fusion/regularization
100 |         L_pull = self.dist_pair(z_AE, z_S2S)
101 |         L_push = torch.stack([self.dist_batch(z) for z in [z_AE, z_S2S]]).min()
102 |         loss_reg = (L_pull - L_push * 2) / np.sqrt(z_AE.shape[-1])
103 | 
104 |         loss = loss_rec + self.args.beta * loss_reg
105 |         return loss_rec, loss_reg, loss
106 | 
107 |     def sent2latent(self, inputs_src):
108 |         # toggle config to get desired encoder output
109 |         self.encoder.encoder.output_attentions = False
110 |         self.encoder.encoder.output_hidden_states = True
111 | 
112 |         # S2S encoder
113 |         mask = (inputs_src > 0).float()
114 |         speaker = self.ids2speaker(inputs_src)
115 |         outputs = self.encoder(inputs_src, attention_mask=mask, token_type_ids=speaker)
116 | 
117 |         _, _, all_layer_attn = outputs  # last_layer_attn, pooled, all_layer_attn = outputs
118 | 
119 | 
120 | 
121 |         seq_z_prev = all_layer_attn[-self.num_s2s_bert_layer-1]  # hidden states just below the S2S layers
122 |         for s2s in self.S2S_layers:
123 |             layer_outputs = s2s(seq_z_prev, attention_mask=mask.unsqueeze(1).unsqueeze(1))
124 |             seq_z_prev = layer_outputs[0]
125 | 
126 |         z_S2S = self.encoder.pooler(layer_outputs[0])
127 |         z_S2S, _ = self.connect(z_S2S)
128 |         z_S2S = z_S2S.squeeze(1)
129 | 
130 |         return z_S2S
131 | 
132 | 
133 |     def dist_pair(self, a, b):
134 |         return F.pairwise_distance(a, b).mean()
135 | 
136 | 
137 |     def dist_batch(self, vec):
138 |         n = vec.shape[0]
139 |         dmin = []
140 |         for i in range(n):
141 |             dd = F.pairwise_distance(vec[i:i+1, :].repeat(n, 1), vec)
142 |             dmin.append(dd.min())
143 |         return torch.stack(dmin).mean()
--------------------------------------------------------------------------------
/larimar_base/ddp.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import subprocess
4 | import numpy as np
5 | 
6 | import torch.distributed as dist
7 | 
8 | 
9 | def get_nccl_socket_ifname():
10 |     ipa = subprocess.run(['ip', 'a'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
11 |     lines = ipa.stdout.decode('utf-8').split('\n')
12 |     all_names = []
13 |     name = None
14 |     for line in lines:
15 |         if line and not line[0] == ' ':
16 |             name = line.split(':')[1].strip()
17 |             continue
18 |         if 'link/infiniband' in line:
19 |             all_names.append(name)
20 |     os.environ['NCCL_SOCKET_IFNAME'] = ','.join(all_names)
21 | 
22 | 
23 | def fix_infiniband():
24 |     # os.environ['NCCL_SOCKET_IFNAME'] = "^lo,docker,virbr,vmnet,vboxnet,wl,ww,ppp,bond"
25 | 
26 |     # ifname = os.environ.get('NCCL_SOCKET_IFNAME', None)
27 |     # if ifname is None:
28 |     #     os.environ['NCCL_SOCKET_IFNAME'] = '^lo,docker0'
29 |     get_nccl_socket_ifname()
30 |     os.environ['NCCL_IB_CUDA_SUPPORT'] = '1'
31 |     ibv = subprocess.run('ibv_devinfo', stdout=subprocess.PIPE, stderr=subprocess.PIPE)
32 |     lines = ibv.stdout.decode('utf-8').split('\n')
33 |     exclude = ''
34 |     include = ''
35 |     for line in lines:
36 |         if 'hca_id:' in line:
37 |             name = line.split(':')[1].strip()
38 |         if '\tport:' in line:
39 |             port = line.split(':')[1].strip()
40 |         if 'link_layer:' in line and 'Ethernet' in line:
41 |             exclude = exclude + f'{name}:{port},'
42 |         if 'link_layer:' in line and 'infiniband' in line.lower():
43 |             include = include + f'{name}:{port},'
44 |     if exclude:
45 |         exclude = '^' + exclude[:-1]
46 |         # print(exclude)
47 |         os.environ['NCCL_IB_HCA'] = exclude
48 |     else:
49 |         os.environ['NCCL_IB_HCA'] = include[:-1]
50 | 
51 | 
52 | 
53 | fix_inifiniband = fix_infiniband  # For backwards compatibility (old misspelled name)
54 | 
55 | def init_ddp_process_group(local_rank: int = None, port: int = None, world_size: int = None, dist_rank: int = None,
56 |                            overwrite_env_vars=True):
57 |     logger = logging.getLogger('InitDDP')
58 |     if os.environ.get('LSB_JOBID', False):
59 |         local_rank = int(os.environ.get('LSF_PM_XPROCID', 1)) - 1 if local_rank is None else local_rank
60 | 
61 |         hostname = os.environ.get('HOSTNAME', 'localhost')
62 |         num_gpus = len(os.environ.get('CUDA_VISIBLE_DEVICES', '').split(','))
63 |         node_rank = int(os.environ.get('LSF_PM_XMACHID', 1)) - 1
64 |         dist_rank = node_rank * num_gpus + local_rank if dist_rank is None else dist_rank
65 |         num_hosts = len(os.environ.get('LSB_MCPU_HOSTS', 'localhost cpus').split()) // 2
66 |         rng = np.random.RandomState(seed=int(os.environ.get('LSB_JOBID', 0)))
67 |         master_host = os.environ.get('LSF_FROM_HOST', 'localhost')
68 |         port = rng.randint(10000, 20000) if port is None else port
69 |         if num_hosts > 1:
70 |             fix_infiniband()
71 |         prefix = f'{hostname}, Local Rank {local_rank}/{num_gpus}, Global Rank {dist_rank}/{world_size}:'
72 | 
73 |         logger.info(f'{prefix} Trying to init process group')
74 |         logger.debug(f"{prefix} CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES', '')}")
75 |         logger.debug(f"{prefix} LSF_PM_XMACHID={os.environ.get('LSF_PM_XMACHID', '')}")
76 |         logger.debug(f"{prefix} LSF_PM_XPROCID={os.environ.get('LSF_PM_XPROCID', '')}")
77 |         logger.debug(f"{prefix} LSB_MCPU_HOSTS={os.environ.get('LSB_MCPU_HOSTS', '')}")
78 |         logger.debug(f"{prefix} MASTER_ADDR={master_host}")
79 |         logger.debug(f"{prefix} MASTER_PORT={port}")
80 |     elif os.environ.get('SLURM_JOB_ID', False):
81 | 
82 |         num_gpus = len(os.environ.get('CUDA_VISIBLE_DEVICES', '0').split(','))
83 |         local_rank = int(os.environ.get('SLURM_PROCID', 0)) % num_gpus if local_rank is None else local_rank
84 |         node_rank = int(os.environ.get('SLURM_NODEID', 0))
85 | 
86 |         hostlist = subprocess.run(['scontrol', 'show', 'hostnames', os.environ.get('SLURM_JOB_NODELIST', 'localhost')],
87 |                                   stdout=subprocess.PIPE, stderr=subprocess.PIPE)
88 |         hostlist = hostlist.stdout.decode('utf8').strip().split('\n')
89 |         num_hosts = len(hostlist)
90 |         master_host = hostlist[0]
91 |         hostname = os.environ.get('HOSTNAME', 'localhost')
92 |         dist_rank = node_rank * num_gpus + local_rank if dist_rank is None else dist_rank
93 |         rng = np.random.RandomState(seed=int(os.environ.get('SLURM_JOB_ID', 0)))
94 |         port = rng.randint(10000, 20000) if port is None else port
95 | 
96 |         prefix = f'{hostname}, Local Rank {local_rank}/{num_gpus}, Global Rank {dist_rank}/{world_size}:'
97 | 
98 |         logger.info(f'{prefix} Trying to init process group')
99 |         logger.debug(f"{prefix} CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES', '')}")
100 |         logger.debug(f"{prefix} SLURM_NODEID={os.environ.get('SLURM_NODEID', '')}")
101 |         logger.debug(f"{prefix} SLURM_PROCID={os.environ.get('SLURM_PROCID', '')}")
102 |         logger.debug(f"{prefix} SLURM_JOB_NODELIST={os.environ.get('SLURM_JOB_NODELIST', '')}")
103 |         logger.debug(f"{prefix} MASTER_ADDR={master_host}")
104 |         logger.debug(f"{prefix} MASTER_PORT={port}")
105 | 
106 |     else:
107 |         return dist.init_process_group(backend='nccl', init_method='env://')
108 |     world_size = num_gpus * num_hosts if world_size is None else world_size
109 | 
110 |     if 'RANK' not in os.environ.keys() or overwrite_env_vars:
111 |         os.environ['RANK'] = str(dist_rank)
112 |     if 'LOCAL_RANK' not in os.environ.keys() or overwrite_env_vars:
113 |         os.environ['LOCAL_RANK'] = str(local_rank)
114 |     if 'NODE_RANK' not in os.environ.keys() or overwrite_env_vars:
115 |         os.environ['NODE_RANK'] = str(node_rank)
116 |     if 'MASTER_ADDR' not in os.environ.keys() or overwrite_env_vars:
117 |         os.environ['MASTER_ADDR'] = master_host
118 |     if 'WORLD_SIZE' not in os.environ.keys() or overwrite_env_vars:
119 |         os.environ['WORLD_SIZE'] = str(world_size)
120 | 
121 |     if port is not None:  # port is always set by this point (passed in or drawn from the job-seeded RNG)
122 |         os.environ['MASTER_PORT'] = str(port)
123 |     elif 'MASTER_PORT' not in os.environ.keys() or overwrite_env_vars:
124 | 
125 |         os.environ['MASTER_PORT'] = str(port)
126 | 
127 | 
128 |     group = dist.init_process_group(backend='nccl', init_method='env://')
129 | 
130 |     logger.info(f'{prefix} Done init process group')
131 |     return group
132 | 
--------------------------------------------------------------------------------
/larimar_base/pytorch_transformers/utils/sentencepiece_model_pb2_new.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Generated by the protocol buffer compiler. DO NOT EDIT!
3 | # source: sentencepiece_model.proto 4 | """Generated protocol buffer code.""" 5 | from google.protobuf import descriptor as _descriptor 6 | from google.protobuf import descriptor_pool as _descriptor_pool 7 | from google.protobuf import symbol_database as _symbol_database 8 | from google.protobuf.internal import builder as _builder 9 | 10 | 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( 17 | b'\n\x19sentencepiece_model.proto\x12\rsentencepiece"\x80\x0c\n\x0bTrainerSpec\x12\r\n\x05input\x18\x01 \x03(\t\x12\x14\n\x0cinput_format\x18\x07 \x01(\t\x12\x14\n\x0cmodel_prefix\x18\x02 \x01(\t\x12\x41\n\nmodel_type\x18\x03 \x01(\x0e\x32$.sentencepiece.TrainerSpec.ModelType:\x07UNIGRAM\x12\x18\n\nvocab_size\x18\x04 \x01(\x05:\x04\x38\x30\x30\x30\x12\x17\n\x0f\x61\x63\x63\x65pt_language\x18\x05 \x03(\t\x12 \n\x15self_test_sample_size\x18\x06 \x01(\x05:\x01\x30\x12*\n\x1b\x65nable_differential_privacy\x18\x32 \x01(\x08:\x05\x66\x61lse\x12+\n differential_privacy_noise_level\x18\x33 \x01(\x02:\x01\x30\x12\x32\n\'differential_privacy_clipping_threshold\x18\x34 \x01(\x04:\x01\x30\x12"\n\x12\x63haracter_coverage\x18\n \x01(\x02:\x06\x30.9995\x12\x1e\n\x13input_sentence_size\x18\x0b \x01(\x04:\x01\x30\x12$\n\x16shuffle_input_sentence\x18\x13 \x01(\x08:\x04true\x12 \n\x14mining_sentence_size\x18\x0c \x01(\x05\x42\x02\x18\x01\x12"\n\x16training_sentence_size\x18\r \x01(\x05\x42\x02\x18\x01\x12(\n\x17seed_sentencepiece_size\x18\x0e \x01(\x05:\x07\x31\x30\x30\x30\x30\x30\x30\x12\x1e\n\x10shrinking_factor\x18\x0f \x01(\x02:\x04\x30.75\x12!\n\x13max_sentence_length\x18\x12 \x01(\x05:\x04\x34\x31\x39\x32\x12\x17\n\x0bnum_threads\x18\x10 \x01(\x05:\x02\x31\x36\x12\x1d\n\x12num_sub_iterations\x18\x11 \x01(\x05:\x01\x32\x12$\n\x18max_sentencepiece_length\x18\x14 \x01(\x05:\x02\x31\x36\x12%\n\x17split_by_unicode_script\x18\x15 \x01(\x08:\x04true\x12\x1d\n\x0fsplit_by_number\x18\x17 \x01(\x08:\x04true\x12!\n\x13split_by_whitespace\x18\x16 \x01(\x08:\x04true\x12)\n\x1atreat_whitespace_as_suffix\x18\x18 \x01(\x08:\x05\x66\x61lse\x12+\n\x1c\x61llow_whitespace_only_pieces\x18\x1a \x01(\x08:\x05\x66\x61lse\x12\x1b\n\x0csplit_digits\x18\x19 \x01(\x08:\x05\x66\x61lse\x12#\n\x19pretokenization_delimiter\x18\x35 \x01(\t:\x00\x12\x17\n\x0f\x63ontrol_symbols\x18\x1e \x03(\t\x12\x1c\n\x14user_defined_symbols\x18\x1f \x03(\t\x12\x16\n\x0erequired_chars\x18$ \x01(\t\x12\x1c\n\rbyte_fallback\x18# \x01(\x08:\x05\x66\x61lse\x12+\n\x1dvocabulary_output_piece_score\x18 \x01(\x08:\x04true\x12\x1e\n\x10hard_vocab_limit\x18! \x01(\x08:\x04true\x12\x1c\n\ruse_all_vocab\x18" \x01(\x08:\x05\x66\x61lse\x12\x11\n\x06unk_id\x18( \x01(\x05:\x01\x30\x12\x11\n\x06\x62os_id\x18) \x01(\x05:\x01\x31\x12\x11\n\x06\x65os_id\x18* \x01(\x05:\x01\x32\x12\x12\n\x06pad_id\x18+ \x01(\x05:\x02-1\x12\x18\n\tunk_piece\x18- \x01(\t:\x05\x12\x16\n\tbos_piece\x18. 
\x01(\t:\x03\x12\x17\n\teos_piece\x18/ \x01(\t:\x04\x12\x18\n\tpad_piece\x18\x30 \x01(\t:\x05\x12\x1a\n\x0bunk_surface\x18, \x01(\t:\x05 \xe2\x81\x87 \x12+\n\x1ctrain_extremely_large_corpus\x18\x31 \x01(\x08:\x05\x66\x61lse"5\n\tModelType\x12\x0b\n\x07UNIGRAM\x10\x01\x12\x07\n\x03\x42PE\x10\x02\x12\x08\n\x04WORD\x10\x03\x12\x08\n\x04\x43HAR\x10\x04*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"\xd1\x01\n\x0eNormalizerSpec\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x1c\n\x14precompiled_charsmap\x18\x02 \x01(\x0c\x12\x1e\n\x10\x61\x64\x64_dummy_prefix\x18\x03 \x01(\x08:\x04true\x12&\n\x18remove_extra_whitespaces\x18\x04 \x01(\x08:\x04true\x12 \n\x12\x65scape_whitespaces\x18\x05 \x01(\x08:\x04true\x12\x1e\n\x16normalization_rule_tsv\x18\x06 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"y\n\x0cSelfTestData\x12\x33\n\x07samples\x18\x01 \x03(\x0b\x32".sentencepiece.SelfTestData.Sample\x1a)\n\x06Sample\x12\r\n\x05input\x18\x01 \x01(\t\x12\x10\n\x08\x65xpected\x18\x02 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"\xfe\x03\n\nModelProto\x12\x37\n\x06pieces\x18\x01 \x03(\x0b\x32\'.sentencepiece.ModelProto.SentencePiece\x12\x30\n\x0ctrainer_spec\x18\x02 \x01(\x0b\x32\x1a.sentencepiece.TrainerSpec\x12\x36\n\x0fnormalizer_spec\x18\x03 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x12\x33\n\x0eself_test_data\x18\x04 \x01(\x0b\x32\x1b.sentencepiece.SelfTestData\x12\x38\n\x11\x64\x65normalizer_spec\x18\x05 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x1a\xd2\x01\n\rSentencePiece\x12\r\n\x05piece\x18\x01 \x01(\t\x12\r\n\x05score\x18\x02 \x01(\x02\x12\x42\n\x04type\x18\x03 \x01(\x0e\x32,.sentencepiece.ModelProto.SentencePiece.Type:\x06NORMAL"T\n\x04Type\x12\n\n\x06NORMAL\x10\x01\x12\x0b\n\x07UNKNOWN\x10\x02\x12\x0b\n\x07\x43ONTROL\x10\x03\x12\x10\n\x0cUSER_DEFINED\x10\x04\x12\x08\n\x04\x42YTE\x10\x06\x12\n\n\x06UNUSED\x10\x05*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\x42\x02H\x03' 18 | ) 19 | 20 | _globals = globals() 21 | _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) 22 | _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "sentencepiece_model_pb2", _globals) 23 | if _descriptor._USE_C_DESCRIPTORS is False: 24 | DESCRIPTOR._options = None 25 | DESCRIPTOR._serialized_options = b"H\003" 26 | # (generated by protobuf compiler, but `_TRAINERSPEC` is not defined) 27 | # _TRAINERSPEC.fields_by_name["mining_sentence_size"]._options = None 28 | # _TRAINERSPEC.fields_by_name["mining_sentence_size"]._serialized_options = b"\030\001" 29 | # _TRAINERSPEC.fields_by_name["training_sentence_size"]._options = None 30 | # _TRAINERSPEC.fields_by_name["training_sentence_size"]._serialized_options = b"\030\001" 31 | _globals["_TRAINERSPEC"]._serialized_start = 45 32 | _globals["_TRAINERSPEC"]._serialized_end = 1581 33 | _globals["_TRAINERSPEC_MODELTYPE"]._serialized_start = 1517 34 | _globals["_TRAINERSPEC_MODELTYPE"]._serialized_end = 1570 35 | _globals["_NORMALIZERSPEC"]._serialized_start = 1584 36 | _globals["_NORMALIZERSPEC"]._serialized_end = 1793 37 | _globals["_SELFTESTDATA"]._serialized_start = 1795 38 | _globals["_SELFTESTDATA"]._serialized_end = 1916 39 | _globals["_SELFTESTDATA_SAMPLE"]._serialized_start = 1864 40 | _globals["_SELFTESTDATA_SAMPLE"]._serialized_end = 1905 41 | _globals["_MODELPROTO"]._serialized_start = 1919 42 | _globals["_MODELPROTO"]._serialized_end = 2429 43 | _globals["_MODELPROTO_SENTENCEPIECE"]._serialized_start = 2208 44 | _globals["_MODELPROTO_SENTENCEPIECE"]._serialized_end = 2418 45 | _globals["_MODELPROTO_SENTENCEPIECE_TYPE"]._serialized_start 
= 2323 46 | _globals["_MODELPROTO_SENTENCEPIECE_TYPE"]._serialized_end = 2407 47 | # @@protoc_insertion_point(module_scope) 48 | -------------------------------------------------------------------------------- /larimar_base/pytorch_transformers/tests/optimization_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import os 21 | 22 | import torch 23 | 24 | from pytorch_transformers import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, 25 | WarmupCosineSchedule, WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule) 26 | 27 | from .tokenization_tests_commons import TemporaryDirectory 28 | 29 | 30 | def unwrap_schedule(scheduler, num_steps=10): 31 | lrs = [] 32 | for _ in range(num_steps): 33 | scheduler.step() 34 | lrs.append(scheduler.get_lr()) 35 | return lrs 36 | 37 | def unwrap_and_save_reload_schedule(scheduler, num_steps=10): 38 | lrs = [] 39 | for step in range(num_steps): 40 | scheduler.step() 41 | lrs.append(scheduler.get_lr()) 42 | if step == num_steps // 2: 43 | with TemporaryDirectory() as tmpdirname: 44 | file_name = os.path.join(tmpdirname, 'schedule.bin') 45 | torch.save(scheduler.state_dict(), file_name) 46 | 47 | state_dict = torch.load(file_name) 48 | scheduler.load_state_dict(state_dict) 49 | return lrs 50 | 51 | class OptimizationTest(unittest.TestCase): 52 | 53 | def assertListAlmostEqual(self, list1, list2, tol): 54 | self.assertEqual(len(list1), len(list2)) 55 | for a, b in zip(list1, list2): 56 | self.assertAlmostEqual(a, b, delta=tol) 57 | 58 | def test_adam_w(self): 59 | w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True) 60 | target = torch.tensor([0.4, 0.2, -0.5]) 61 | criterion = torch.nn.MSELoss() 62 | # No warmup, constant schedule, no gradient clipping 63 | optimizer = AdamW(params=[w], lr=2e-1, weight_decay=0.0) 64 | for _ in range(100): 65 | loss = criterion(w, target) 66 | loss.backward() 67 | optimizer.step() 68 | w.grad.detach_() # No zero_grad() function on simple tensors. we do it ourselves. 69 | w.grad.zero_() 70 | self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2) 71 | 72 | 73 | class ScheduleInitTest(unittest.TestCase): 74 | m = torch.nn.Linear(50, 50) 75 | optimizer = AdamW(m.parameters(), lr=10.) 76 | num_steps = 10 77 | 78 | def assertListAlmostEqual(self, list1, list2, tol): 79 | self.assertEqual(len(list1), len(list2)) 80 | for a, b in zip(list1, list2): 81 | self.assertAlmostEqual(a, b, delta=tol) 82 | 83 | def test_constant_scheduler(self): 84 | scheduler = ConstantLRSchedule(self.optimizer) 85 | lrs = unwrap_schedule(scheduler, self.num_steps) 86 | expected_learning_rates = [10.] 
* self.num_steps 87 | self.assertEqual(len(lrs[0]), 1) 88 | self.assertListEqual([l[0] for l in lrs], expected_learning_rates) 89 | 90 | scheduler = ConstantLRSchedule(self.optimizer) 91 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) 92 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) 93 | 94 | def test_warmup_constant_scheduler(self): 95 | scheduler = WarmupConstantSchedule(self.optimizer, warmup_steps=4) 96 | lrs = unwrap_schedule(scheduler, self.num_steps) 97 | expected_learning_rates = [2.5, 5.0, 7.5, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0] 98 | self.assertEqual(len(lrs[0]), 1) 99 | self.assertListEqual([l[0] for l in lrs], expected_learning_rates) 100 | 101 | scheduler = WarmupConstantSchedule(self.optimizer, warmup_steps=4) 102 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) 103 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) 104 | 105 | def test_warmup_linear_scheduler(self): 106 | scheduler = WarmupLinearSchedule(self.optimizer, warmup_steps=2, t_total=10) 107 | lrs = unwrap_schedule(scheduler, self.num_steps) 108 | expected_learning_rates = [5.0, 10.0, 8.75, 7.5, 6.25, 5.0, 3.75, 2.5, 1.25, 0.0] 109 | self.assertEqual(len(lrs[0]), 1) 110 | self.assertListEqual([l[0] for l in lrs], expected_learning_rates) 111 | 112 | scheduler = WarmupLinearSchedule(self.optimizer, warmup_steps=2, t_total=10) 113 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) 114 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) 115 | 116 | def test_warmup_cosine_scheduler(self): 117 | scheduler = WarmupCosineSchedule(self.optimizer, warmup_steps=2, t_total=10) 118 | lrs = unwrap_schedule(scheduler, self.num_steps) 119 | expected_learning_rates = [5.0, 10.0, 9.61, 8.53, 6.91, 5.0, 3.08, 1.46, 0.38, 0.0] 120 | self.assertEqual(len(lrs[0]), 1) 121 | self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2) 122 | 123 | scheduler = WarmupCosineSchedule(self.optimizer, warmup_steps=2, t_total=10) 124 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) 125 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) 126 | 127 | def test_warmup_cosine_hard_restart_scheduler(self): 128 | scheduler = WarmupCosineWithHardRestartsSchedule(self.optimizer, warmup_steps=2, cycles=2, t_total=10) 129 | lrs = unwrap_schedule(scheduler, self.num_steps) 130 | expected_learning_rates = [5.0, 10.0, 8.53, 5.0, 1.46, 10.0, 8.53, 5.0, 1.46, 0.0] 131 | self.assertEqual(len(lrs[0]), 1) 132 | self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2) 133 | 134 | scheduler = WarmupCosineWithHardRestartsSchedule(self.optimizer, warmup_steps=2, cycles=2, t_total=10) 135 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) 136 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) 137 | 138 | if __name__ == "__main__": 139 | unittest.main() 140 | -------------------------------------------------------------------------------- /larimar_base/pytorch_transformers/configuration_bert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at
8 | #
9 | #     http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ BERT model configuration """
17 | 
18 | from __future__ import absolute_import, division, print_function, unicode_literals
19 | 
20 | import json
21 | import logging
22 | import sys
23 | from io import open
24 | 
25 | from .configuration_utils import PretrainedConfig
26 | 
27 | logger = logging.getLogger(__name__)
28 | 
29 | BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
30 |     'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
31 |     'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
32 |     'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
33 |     'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json",
34 |     'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json",
35 |     'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json",
36 |     'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json",
37 |     'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json",
38 |     'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json",
39 |     'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json",
40 |     'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json",
41 |     'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json",
42 |     'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json",
43 | }
44 | 
45 | 
46 | class BertConfig(PretrainedConfig):
47 |     r"""
48 |     :class:`~pytorch_transformers.BertConfig` is the configuration class to store the configuration of a
49 |     `BertModel`.
50 | 
51 | 
52 |     Arguments:
53 |         vocab_size_or_config_json_file: Vocabulary size of `input_ids` in `BertModel`.
54 |         hidden_size: Size of the encoder layers and the pooler layer.
55 |         num_hidden_layers: Number of hidden layers in the Transformer encoder.
56 |         num_attention_heads: Number of attention heads for each attention layer in
57 |             the Transformer encoder.
58 |         intermediate_size: The size of the "intermediate" (i.e., feed-forward)
59 |             layer in the Transformer encoder.
60 |         hidden_act: The non-linear activation function (function or string) in the
61 |             encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
62 |         hidden_dropout_prob: The dropout probability for all fully connected
63 |             layers in the embeddings, encoder, and pooler.
64 |         attention_probs_dropout_prob: The dropout ratio for the attention
65 |             probabilities.
66 |         max_position_embeddings: The maximum sequence length that this model might
67 |             ever be used with. Typically set this to something large just in case
68 |             (e.g., 512 or 1024 or 2048).
69 |         type_vocab_size: The vocabulary size of the `token_type_ids` passed into
70 |             `BertModel`.
71 |         initializer_range: The stddev of the truncated_normal_initializer for
72 |             initializing all weight matrices.
73 |         layer_norm_eps: The epsilon used by LayerNorm.
74 |     """
75 |     pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
76 | 
77 |     def __init__(self,
78 |                  vocab_size_or_config_json_file=28996,
79 |                  hidden_size=768,
80 |                  num_hidden_layers=12,
81 |                  num_attention_heads=12,
82 |                  intermediate_size=3072,
83 |                  hidden_act="gelu",
84 |                  hidden_dropout_prob=0.1,
85 |                  attention_probs_dropout_prob=0.1,
86 |                  max_position_embeddings=512,
87 |                  type_vocab_size=2,
88 |                  initializer_range=0.02,
89 |                  layer_norm_eps=1e-12,
90 |                  **kwargs):
91 |         super(BertConfig, self).__init__(**kwargs)
92 |         if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
93 |                         and isinstance(vocab_size_or_config_json_file, unicode)):
94 |             with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
95 |                 json_config = json.loads(reader.read())
96 |             for key, value in json_config.items():
97 |                 self.__dict__[key] = value
98 |         elif isinstance(vocab_size_or_config_json_file, int):
99 |             self.vocab_size = vocab_size_or_config_json_file
100 |             self.hidden_size = hidden_size
101 |             self.num_hidden_layers = num_hidden_layers
102 |             self.num_attention_heads = num_attention_heads
103 |             self.hidden_act = hidden_act
104 |             self.intermediate_size = intermediate_size
105 |             self.hidden_dropout_prob = hidden_dropout_prob
106 |             self.attention_probs_dropout_prob = attention_probs_dropout_prob
107 |             self.max_position_embeddings = max_position_embeddings
108 |             self.type_vocab_size = type_vocab_size
109 |             self.initializer_range = initializer_range
110 |             self.layer_norm_eps = layer_norm_eps
111 |         else:
112 |             raise ValueError("First argument must be either a vocabulary size (int)"
113 |                              " or the path to a pretrained model config file (str)")
114 | 
--------------------------------------------------------------------------------
/larimar_base/pytorch_transformers/utils/dummy_sentencepiece_objects.py:
--------------------------------------------------------------------------------
1 | # This file is autogenerated by the command `make fix-copies`, do not edit.
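# These classes are lazy placeholders: importing them always succeeds, and the
# dependency check is deferred to instantiation, where `requires_backends`
# raises an informative ImportError telling the user to install the
# "sentencepiece" backend.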
2 | from ..utils import DummyObject, requires_backends 3 | 4 | 5 | class AlbertTokenizer(metaclass=DummyObject): 6 | _backends = ["sentencepiece"] 7 | 8 | def __init__(self, *args, **kwargs): 9 | requires_backends(self, ["sentencepiece"]) 10 | 11 | 12 | class BarthezTokenizer(metaclass=DummyObject): 13 | _backends = ["sentencepiece"] 14 | 15 | def __init__(self, *args, **kwargs): 16 | requires_backends(self, ["sentencepiece"]) 17 | 18 | 19 | class BartphoTokenizer(metaclass=DummyObject): 20 | _backends = ["sentencepiece"] 21 | 22 | def __init__(self, *args, **kwargs): 23 | requires_backends(self, ["sentencepiece"]) 24 | 25 | 26 | class BertGenerationTokenizer(metaclass=DummyObject): 27 | _backends = ["sentencepiece"] 28 | 29 | def __init__(self, *args, **kwargs): 30 | requires_backends(self, ["sentencepiece"]) 31 | 32 | 33 | class BigBirdTokenizer(metaclass=DummyObject): 34 | _backends = ["sentencepiece"] 35 | 36 | def __init__(self, *args, **kwargs): 37 | requires_backends(self, ["sentencepiece"]) 38 | 39 | 40 | class CamembertTokenizer(metaclass=DummyObject): 41 | _backends = ["sentencepiece"] 42 | 43 | def __init__(self, *args, **kwargs): 44 | requires_backends(self, ["sentencepiece"]) 45 | 46 | 47 | class CodeLlamaTokenizer(metaclass=DummyObject): 48 | _backends = ["sentencepiece"] 49 | 50 | def __init__(self, *args, **kwargs): 51 | requires_backends(self, ["sentencepiece"]) 52 | 53 | 54 | class CpmTokenizer(metaclass=DummyObject): 55 | _backends = ["sentencepiece"] 56 | 57 | def __init__(self, *args, **kwargs): 58 | requires_backends(self, ["sentencepiece"]) 59 | 60 | 61 | class DebertaV2Tokenizer(metaclass=DummyObject): 62 | _backends = ["sentencepiece"] 63 | 64 | def __init__(self, *args, **kwargs): 65 | requires_backends(self, ["sentencepiece"]) 66 | 67 | 68 | class ErnieMTokenizer(metaclass=DummyObject): 69 | _backends = ["sentencepiece"] 70 | 71 | def __init__(self, *args, **kwargs): 72 | requires_backends(self, ["sentencepiece"]) 73 | 74 | 75 | class FNetTokenizer(metaclass=DummyObject): 76 | _backends = ["sentencepiece"] 77 | 78 | def __init__(self, *args, **kwargs): 79 | requires_backends(self, ["sentencepiece"]) 80 | 81 | 82 | class GPTSw3Tokenizer(metaclass=DummyObject): 83 | _backends = ["sentencepiece"] 84 | 85 | def __init__(self, *args, **kwargs): 86 | requires_backends(self, ["sentencepiece"]) 87 | 88 | 89 | class LayoutXLMTokenizer(metaclass=DummyObject): 90 | _backends = ["sentencepiece"] 91 | 92 | def __init__(self, *args, **kwargs): 93 | requires_backends(self, ["sentencepiece"]) 94 | 95 | 96 | class LlamaTokenizer(metaclass=DummyObject): 97 | _backends = ["sentencepiece"] 98 | 99 | def __init__(self, *args, **kwargs): 100 | requires_backends(self, ["sentencepiece"]) 101 | 102 | 103 | class M2M100Tokenizer(metaclass=DummyObject): 104 | _backends = ["sentencepiece"] 105 | 106 | def __init__(self, *args, **kwargs): 107 | requires_backends(self, ["sentencepiece"]) 108 | 109 | 110 | class MarianTokenizer(metaclass=DummyObject): 111 | _backends = ["sentencepiece"] 112 | 113 | def __init__(self, *args, **kwargs): 114 | requires_backends(self, ["sentencepiece"]) 115 | 116 | 117 | class MBart50Tokenizer(metaclass=DummyObject): 118 | _backends = ["sentencepiece"] 119 | 120 | def __init__(self, *args, **kwargs): 121 | requires_backends(self, ["sentencepiece"]) 122 | 123 | 124 | class MBartTokenizer(metaclass=DummyObject): 125 | _backends = ["sentencepiece"] 126 | 127 | def __init__(self, *args, **kwargs): 128 | requires_backends(self, ["sentencepiece"]) 129 | 130 | 131 | 
class MLukeTokenizer(metaclass=DummyObject): 132 | _backends = ["sentencepiece"] 133 | 134 | def __init__(self, *args, **kwargs): 135 | requires_backends(self, ["sentencepiece"]) 136 | 137 | 138 | class MT5Tokenizer(metaclass=DummyObject): 139 | _backends = ["sentencepiece"] 140 | 141 | def __init__(self, *args, **kwargs): 142 | requires_backends(self, ["sentencepiece"]) 143 | 144 | 145 | class NllbTokenizer(metaclass=DummyObject): 146 | _backends = ["sentencepiece"] 147 | 148 | def __init__(self, *args, **kwargs): 149 | requires_backends(self, ["sentencepiece"]) 150 | 151 | 152 | class PegasusTokenizer(metaclass=DummyObject): 153 | _backends = ["sentencepiece"] 154 | 155 | def __init__(self, *args, **kwargs): 156 | requires_backends(self, ["sentencepiece"]) 157 | 158 | 159 | class PLBartTokenizer(metaclass=DummyObject): 160 | _backends = ["sentencepiece"] 161 | 162 | def __init__(self, *args, **kwargs): 163 | requires_backends(self, ["sentencepiece"]) 164 | 165 | 166 | class ReformerTokenizer(metaclass=DummyObject): 167 | _backends = ["sentencepiece"] 168 | 169 | def __init__(self, *args, **kwargs): 170 | requires_backends(self, ["sentencepiece"]) 171 | 172 | 173 | class RemBertTokenizer(metaclass=DummyObject): 174 | _backends = ["sentencepiece"] 175 | 176 | def __init__(self, *args, **kwargs): 177 | requires_backends(self, ["sentencepiece"]) 178 | 179 | 180 | class Speech2TextTokenizer(metaclass=DummyObject): 181 | _backends = ["sentencepiece"] 182 | 183 | def __init__(self, *args, **kwargs): 184 | requires_backends(self, ["sentencepiece"]) 185 | 186 | 187 | class SpeechT5Tokenizer(metaclass=DummyObject): 188 | _backends = ["sentencepiece"] 189 | 190 | def __init__(self, *args, **kwargs): 191 | requires_backends(self, ["sentencepiece"]) 192 | 193 | 194 | class T5Tokenizer(metaclass=DummyObject): 195 | _backends = ["sentencepiece"] 196 | 197 | def __init__(self, *args, **kwargs): 198 | requires_backends(self, ["sentencepiece"]) 199 | 200 | 201 | class XGLMTokenizer(metaclass=DummyObject): 202 | _backends = ["sentencepiece"] 203 | 204 | def __init__(self, *args, **kwargs): 205 | requires_backends(self, ["sentencepiece"]) 206 | 207 | 208 | class XLMProphetNetTokenizer(metaclass=DummyObject): 209 | _backends = ["sentencepiece"] 210 | 211 | def __init__(self, *args, **kwargs): 212 | requires_backends(self, ["sentencepiece"]) 213 | 214 | 215 | class XLMRobertaTokenizer(metaclass=DummyObject): 216 | _backends = ["sentencepiece"] 217 | 218 | def __init__(self, *args, **kwargs): 219 | requires_backends(self, ["sentencepiece"]) 220 | 221 | 222 | class XLNetTokenizer(metaclass=DummyObject): 223 | _backends = ["sentencepiece"] 224 | 225 | def __init__(self, *args, **kwargs): 226 | requires_backends(self, ["sentencepiece"]) 227 | -------------------------------------------------------------------------------- /larimar_base/pytorch_transformers/generation/stopping_criteria.py: -------------------------------------------------------------------------------- 1 | import time 2 | import warnings 3 | from abc import ABC 4 | from copy import deepcopy 5 | from typing import Optional 6 | 7 | import torch 8 | 9 | from ..utils import add_start_docstrings, logging 10 | 11 | 12 | logger = logging.get_logger(__name__) 13 | 14 | 15 | STOPPING_CRITERIA_INPUTS_DOCSTRING = r""" 16 | Args: 17 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 18 | Indices of input sequence tokens in the vocabulary. 19 | 20 | Indices can be obtained using [`AutoTokenizer`]. 
See [`PreTrainedTokenizer.encode`] and
21 |             [`PreTrainedTokenizer.__call__`] for details.
22 | 
23 |             [What are input IDs?](../glossary#input-ids)
24 |         scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
25 |             Prediction scores of a language modeling head. These can be scores for each vocabulary token before SoftMax
26 |             or scores for each vocabulary token after SoftMax.
27 |         kwargs (`Dict[str, Any]`, *optional*):
28 |             Additional stopping-criteria-specific kwargs.
29 | 
30 |     Return:
31 |         `bool`. `False` indicates we should continue, `True` indicates we should stop.
32 | 
33 | """
34 | 
35 | 
36 | class StoppingCriteria(ABC):
37 |     """Abstract base class for all stopping criteria that can be applied during generation."""
38 | 
39 |     @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
40 |     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
41 |         raise NotImplementedError("StoppingCriteria needs to be subclassed")
42 | 
43 | 
44 | class MaxLengthCriteria(StoppingCriteria):
45 |     """
46 |     This class can be used to stop generation whenever the full generated number of tokens exceeds `max_length`. Keep
47 |     in mind that for decoder-only transformers this count includes the initial prompt tokens.
48 | 
49 |     Args:
50 |         max_length (`int`):
51 |             The maximum length that the output sequence can have in number of tokens.
52 |         max_position_embeddings (`int`, *optional*):
53 |             The maximum model length, as defined by the model's `config.max_position_embeddings` attribute.
54 |     """
55 | 
56 |     def __init__(self, max_length: int, max_position_embeddings: Optional[int] = None):
57 |         self.max_length = max_length
58 |         self.max_position_embeddings = max_position_embeddings
59 | 
60 |     @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
61 |     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
62 |         cur_len = input_ids.shape[-1]
63 |         is_done = cur_len >= self.max_length
64 |         if self.max_position_embeddings is not None and not is_done and cur_len >= self.max_position_embeddings:
65 |             logger.warning_once(
66 |                 "This is a friendly reminder - the current text generation call will exceed the model's predefined "
67 |                 f"maximum length ({self.max_position_embeddings}). Depending on the model, you may observe "
68 |                 "exceptions, performance degradation, or nothing at all."
69 |             )
70 |         return is_done
71 | 
72 | 
73 | class MaxNewTokensCriteria(StoppingCriteria):
74 |     """
75 |     This class can be used to stop generation whenever the generated number of tokens exceeds `max_new_tokens`. Keep in
76 |     mind that for decoder-only transformers this count does **not** include the initial prompt tokens. This is very
77 |     close to `MaxLengthCriteria` but ignores the number of initial tokens.
78 | 
79 |     Args:
80 |         start_length (`int`):
81 |             The number of initial tokens.
82 |         max_new_tokens (`int`):
83 |             The maximum number of tokens to generate.
84 |     """
85 | 
86 |     def __init__(self, start_length: int, max_new_tokens: int):
87 |         warnings.warn(
88 |             "The class `MaxNewTokensCriteria` is deprecated. "
89 |             f"Please use `MaxLengthCriteria(max_length={start_length + max_new_tokens})` "
90 |             "with `max_length = start_length + max_new_tokens` instead.",
91 |             FutureWarning,
92 |         )
93 |         self.start_length = start_length
94 |         self.max_new_tokens = max_new_tokens
95 |         self.max_length = start_length + max_new_tokens
96 | 
97 |     @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
98 |     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
99 |         return input_ids.shape[-1] >= self.max_length
100 | 
101 | 
102 | class MaxTimeCriteria(StoppingCriteria):
103 |     """
104 |     This class can be used to stop generation whenever the full generation exceeds some amount of time. By default, the
105 |     time will start being counted when you initialize this object. You can override this by passing an
106 |     `initial_timestamp`.
107 | 
108 |     Args:
109 |         max_time (`float`):
110 |             The maximum allowed time in seconds for the generation.
111 |         initial_timestamp (`float`, *optional*, defaults to `time.time()`):
112 |             The start of the generation allowed time.
113 |     """
114 | 
115 |     def __init__(self, max_time: float, initial_timestamp: Optional[float] = None):
116 |         self.max_time = max_time
117 |         self.initial_timestamp = time.time() if initial_timestamp is None else initial_timestamp
118 | 
119 |     @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
120 |     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
121 |         return time.time() - self.initial_timestamp > self.max_time
122 | 
123 | 
124 | class StoppingCriteriaList(list):
125 |     @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
126 |     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
127 |         return any(criteria(input_ids, scores, **kwargs) for criteria in self)
128 | 
129 |     @property
130 |     def max_length(self) -> Optional[int]:
131 |         for stopping_criterium in self:
132 |             if isinstance(stopping_criterium, MaxLengthCriteria):
133 |                 return stopping_criterium.max_length
134 |             elif isinstance(stopping_criterium, MaxNewTokensCriteria):
135 |                 return stopping_criterium.max_length
136 |         return None
137 | 
138 | 
139 | def validate_stopping_criteria(stopping_criteria: StoppingCriteriaList, max_length: int) -> StoppingCriteriaList:
140 |     stopping_max_length = stopping_criteria.max_length
141 |     new_stopping_criteria = deepcopy(stopping_criteria)
142 |     if stopping_max_length is not None and stopping_max_length != max_length:
143 |         warnings.warn("You set different `max_length` for stopping criteria and `max_length` parameter", UserWarning)
144 |     elif stopping_max_length is None:
145 |         new_stopping_criteria.append(MaxLengthCriteria(max_length=max_length))
146 |     return new_stopping_criteria
147 | 
--------------------------------------------------------------------------------