├── COMFUSE ├── __pycache__ │ ├── config.cpython-36.pyc │ ├── config.cpython-37.pyc │ ├── learner.cpython-36.pyc │ ├── learner.cpython-37.pyc │ ├── learn_emb.cpython-36.pyc │ ├── learn_emb.cpython-37.pyc │ ├── agent_base.cpython-36.pyc │ ├── agent_base.cpython-37.pyc │ ├── dataloader.cpython-36.pyc │ ├── dataloader.cpython-37.pyc │ ├── selfdropout.cpython-36.pyc │ ├── selfdropout.cpython-37.pyc │ ├── rumor_dataset.cpython-36.pyc │ └── rumor_dataset.cpython-37.pyc ├── run.txt ├── bertmodels │ ├── __pycache__ │ │ ├── bert.cpython-36.pyc │ │ └── bert.cpython-37.pyc │ └── bert.py ├── pytorch_pretrained │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── __main__.cpython-37.pyc │ │ ├── modeling.cpython-36.pyc │ │ ├── modeling.cpython-37.pyc │ │ ├── file_utils.cpython-36.pyc │ │ ├── file_utils.cpython-37.pyc │ │ ├── modeling_gpt2.cpython-36.pyc │ │ ├── modeling_gpt2.cpython-37.pyc │ │ ├── optimization.cpython-36.pyc │ │ ├── optimization.cpython-37.pyc │ │ ├── tokenization.cpython-36.pyc │ │ ├── tokenization.cpython-37.pyc │ │ ├── modeling_openai.cpython-36.pyc │ │ ├── modeling_openai.cpython-37.pyc │ │ ├── tokenization_gpt2.cpython-36.pyc │ │ ├── tokenization_gpt2.cpython-37.pyc │ │ ├── modeling_transfo_xl.cpython-36.pyc │ │ ├── modeling_transfo_xl.cpython-37.pyc │ │ ├── optimization_openai.cpython-36.pyc │ │ ├── optimization_openai.cpython-37.pyc │ │ ├── tokenization_openai.cpython-36.pyc │ │ ├── tokenization_openai.cpython-37.pyc │ │ ├── tokenization_transfo_xl.cpython-36.pyc │ │ ├── tokenization_transfo_xl.cpython-37.pyc │ │ ├── modeling_transfo_xl_utilities.cpython-36.pyc │ │ ├── modeling_transfo_xl_utilities.cpython-37.pyc │ │ ├── convert_tf_checkpoint_to_pytorch.cpython-37.pyc │ │ ├── convert_gpt2_checkpoint_to_pytorch.cpython-37.pyc │ │ ├── convert_openai_checkpoint_to_pytorch.cpython-37.pyc │ │ └── convert_transfo_xl_checkpoint_to_pytorch.cpython-37.pyc │ ├── __init__.py │ ├── convert_tf_checkpoint_to_pytorch.py │ ├── convert_gpt2_checkpoint_to_pytorch.py │ ├── convert_openai_checkpoint_to_pytorch.py │ ├── __main__.py │ ├── optimization_openai.py │ ├── convert_transfo_xl_checkpoint_to_pytorch.py │ ├── file_utils.py │ ├── optimization.py │ ├── tokenization_gpt2.py │ ├── tokenization_openai.py │ ├── modeling_transfo_xl_utilities.py │ ├── tokenization.py │ └── tokenization_transfo_xl.py ├── Pheme_DataSet_Pair_comments │ └── data │ │ └── class.txt ├── bert_pretrain │ └── bert_config.json ├── DataSet_pair_comments │ └── data │ │ └── class.txt ├── config.py ├── learn_emb.py ├── dataloader.py ├── utils.py ├── selfdropout.py ├── rumor_dataset.py └── main.py └── README.md /COMFUSE/__pycache__/config.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jncsnlp/FSL-Multimodal-Rumor-Detection/HEAD/COMFUSE/__pycache__/config.cpython-36.pyc -------------------------------------------------------------------------------- /COMFUSE/__pycache__/config.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jncsnlp/FSL-Multimodal-Rumor-Detection/HEAD/COMFUSE/__pycache__/config.cpython-37.pyc -------------------------------------------------------------------------------- /COMFUSE/__pycache__/learner.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jncsnlp/FSL-Multimodal-Rumor-Detection/HEAD/COMFUSE/__pycache__/learner.cpython-36.pyc -------------------------------------------------------------------------------- /COMFUSE/__pycache__/learner.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jncsnlp/FSL-Multimodal-Rumor-Detection/HEAD/COMFUSE/__pycache__/learner.cpython-37.pyc -------------------------------------------------------------------------------- /COMFUSE/__pycache__/learn_emb.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jncsnlp/FSL-Multimodal-Rumor-Detection/HEAD/COMFUSE/__pycache__/learn_emb.cpython-36.pyc -------------------------------------------------------------------------------- /COMFUSE/__pycache__/learn_emb.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jncsnlp/FSL-Multimodal-Rumor-Detection/HEAD/COMFUSE/__pycache__/learn_emb.cpython-37.pyc -------------------------------------------------------------------------------- /COMFUSE/run.txt: -------------------------------------------------------------------------------- 1 | python main.py --shot=5 --gru_type=gru_bi_mult --gru_num_layer=2 --droptype=1 --droprate=0.3 --dataset=weibo --topic=3 --split_number=0 --ph=0 2 | -------------------------------------------------------------------------------- /COMFUSE/__pycache__/agent_base.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jncsnlp/FSL-Multimodal-Rumor-Detection/HEAD/COMFUSE/__pycache__/agent_base.cpython-36.pyc -------------------------------------------------------------------------------- /COMFUSE/__pycache__/agent_base.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jncsnlp/FSL-Multimodal-Rumor-Detection/HEAD/COMFUSE/__pycache__/agent_base.cpython-37.pyc -------------------------------------------------------------------------------- /COMFUSE/__pycache__/dataloader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jncsnlp/FSL-Multimodal-Rumor-Detection/HEAD/COMFUSE/__pycache__/dataloader.cpython-36.pyc -------------------------------------------------------------------------------- /COMFUSE/__pycache__/dataloader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jncsnlp/FSL-Multimodal-Rumor-Detection/HEAD/COMFUSE/__pycache__/dataloader.cpython-37.pyc -------------------------------------------------------------------------------- /COMFUSE/__pycache__/selfdropout.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jncsnlp/FSL-Multimodal-Rumor-Detection/HEAD/COMFUSE/__pycache__/selfdropout.cpython-36.pyc -------------------------------------------------------------------------------- /COMFUSE/__pycache__/selfdropout.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jncsnlp/FSL-Multimodal-Rumor-Detection/HEAD/COMFUSE/__pycache__/selfdropout.cpython-37.pyc -------------------------------------------------------------------------------- /COMFUSE/__pycache__/rumor_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jncsnlp/FSL-Multimodal-Rumor-Detection/HEAD/COMFUSE/__pycache__/rumor_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /COMFUSE/__pycache__/rumor_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jncsnlp/FSL-Multimodal-Rumor-Detection/HEAD/COMFUSE/__pycache__/rumor_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /COMFUSE/bertmodels/__pycache__/bert.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jncsnlp/FSL-Multimodal-Rumor-Detection/HEAD/COMFUSE/bertmodels/__pycache__/bert.cpython-36.pyc -------------------------------------------------------------------------------- /COMFUSE/bertmodels/__pycache__/bert.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jncsnlp/FSL-Multimodal-Rumor-Detection/HEAD/COMFUSE/bertmodels/__pycache__/bert.cpython-37.pyc -------------------------------------------------------------------------------- /COMFUSE/pytorch_pretrained/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jncsnlp/FSL-Multimodal-Rumor-Detection/HEAD/COMFUSE/pytorch_pretrained/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /COMFUSE/pytorch_pretrained/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jncsnlp/FSL-Multimodal-Rumor-Detection/HEAD/COMFUSE/pytorch_pretrained/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /COMFUSE/pytorch_pretrained/__pycache__/__main__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jncsnlp/FSL-Multimodal-Rumor-Detection/HEAD/COMFUSE/pytorch_pretrained/__pycache__/__main__.cpython-37.pyc -------------------------------------------------------------------------------- /COMFUSE/pytorch_pretrained/__pycache__/modeling.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jncsnlp/FSL-Multimodal-Rumor-Detection/HEAD/COMFUSE/pytorch_pretrained/__pycache__/modeling.cpython-36.pyc -------------------------------------------------------------------------------- /COMFUSE/pytorch_pretrained/__pycache__/modeling.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jncsnlp/FSL-Multimodal-Rumor-Detection/HEAD/COMFUSE/pytorch_pretrained/__pycache__/modeling.cpython-37.pyc -------------------------------------------------------------------------------- /COMFUSE/pytorch_pretrained/__pycache__/file_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jncsnlp/FSL-Multimodal-Rumor-Detection/HEAD/COMFUSE/pytorch_pretrained/__pycache__/file_utils.cpython-36.pyc -------------------------------------------------------------------------------- /COMFUSE/pytorch_pretrained/__pycache__/file_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jncsnlp/FSL-Multimodal-Rumor-Detection/HEAD/COMFUSE/pytorch_pretrained/__pycache__/file_utils.cpython-37.pyc -------------------------------------------------------------------------------- /COMFUSE/pytorch_pretrained/__pycache__/modeling_gpt2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jncsnlp/FSL-Multimodal-Rumor-Detection/HEAD/COMFUSE/pytorch_pretrained/__pycache__/modeling_gpt2.cpython-36.pyc -------------------------------------------------------------------------------- /COMFUSE/pytorch_pretrained/__pycache__/modeling_gpt2.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jncsnlp/FSL-Multimodal-Rumor-Detection/HEAD/COMFUSE/pytorch_pretrained/__pycache__/modeling_gpt2.cpython-37.pyc -------------------------------------------------------------------------------- /COMFUSE/pytorch_pretrained/__pycache__/optimization.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jncsnlp/FSL-Multimodal-Rumor-Detection/HEAD/COMFUSE/pytorch_pretrained/__pycache__/optimization.cpython-36.pyc -------------------------------------------------------------------------------- /COMFUSE/pytorch_pretrained/__pycache__/optimization.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jncsnlp/FSL-Multimodal-Rumor-Detection/HEAD/COMFUSE/pytorch_pretrained/__pycache__/optimization.cpython-37.pyc -------------------------------------------------------------------------------- /COMFUSE/pytorch_pretrained/__pycache__/tokenization.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jncsnlp/FSL-Multimodal-Rumor-Detection/HEAD/COMFUSE/pytorch_pretrained/__pycache__/tokenization.cpython-36.pyc -------------------------------------------------------------------------------- /COMFUSE/pytorch_pretrained/__pycache__/tokenization.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jncsnlp/FSL-Multimodal-Rumor-Detection/HEAD/COMFUSE/pytorch_pretrained/__pycache__/tokenization.cpython-37.pyc -------------------------------------------------------------------------------- /COMFUSE/pytorch_pretrained/__pycache__/modeling_openai.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jncsnlp/FSL-Multimodal-Rumor-Detection/HEAD/COMFUSE/pytorch_pretrained/__pycache__/modeling_openai.cpython-36.pyc -------------------------------------------------------------------------------- /COMFUSE/pytorch_pretrained/__pycache__/modeling_openai.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jncsnlp/FSL-Multimodal-Rumor-Detection/HEAD/COMFUSE/pytorch_pretrained/__pycache__/modeling_openai.cpython-37.pyc -------------------------------------------------------------------------------- /COMFUSE/pytorch_pretrained/__pycache__/tokenization_gpt2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jncsnlp/FSL-Multimodal-Rumor-Detection/HEAD/COMFUSE/pytorch_pretrained/__pycache__/tokenization_gpt2.cpython-36.pyc -------------------------------------------------------------------------------- /COMFUSE/pytorch_pretrained/__pycache__/tokenization_gpt2.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jncsnlp/FSL-Multimodal-Rumor-Detection/HEAD/COMFUSE/pytorch_pretrained/__pycache__/tokenization_gpt2.cpython-37.pyc -------------------------------------------------------------------------------- /COMFUSE/pytorch_pretrained/__pycache__/modeling_transfo_xl.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jncsnlp/FSL-Multimodal-Rumor-Detection/HEAD/COMFUSE/pytorch_pretrained/__pycache__/modeling_transfo_xl.cpython-36.pyc -------------------------------------------------------------------------------- /COMFUSE/pytorch_pretrained/__pycache__/modeling_transfo_xl.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jncsnlp/FSL-Multimodal-Rumor-Detection/HEAD/COMFUSE/pytorch_pretrained/__pycache__/modeling_transfo_xl.cpython-37.pyc -------------------------------------------------------------------------------- /COMFUSE/pytorch_pretrained/__pycache__/optimization_openai.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jncsnlp/FSL-Multimodal-Rumor-Detection/HEAD/COMFUSE/pytorch_pretrained/__pycache__/optimization_openai.cpython-36.pyc -------------------------------------------------------------------------------- /COMFUSE/pytorch_pretrained/__pycache__/optimization_openai.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jncsnlp/FSL-Multimodal-Rumor-Detection/HEAD/COMFUSE/pytorch_pretrained/__pycache__/optimization_openai.cpython-37.pyc -------------------------------------------------------------------------------- /COMFUSE/pytorch_pretrained/__pycache__/tokenization_openai.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jncsnlp/FSL-Multimodal-Rumor-Detection/HEAD/COMFUSE/pytorch_pretrained/__pycache__/tokenization_openai.cpython-36.pyc -------------------------------------------------------------------------------- /COMFUSE/pytorch_pretrained/__pycache__/tokenization_openai.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jncsnlp/FSL-Multimodal-Rumor-Detection/HEAD/COMFUSE/pytorch_pretrained/__pycache__/tokenization_openai.cpython-37.pyc -------------------------------------------------------------------------------- /COMFUSE/pytorch_pretrained/__pycache__/tokenization_transfo_xl.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jncsnlp/FSL-Multimodal-Rumor-Detection/HEAD/COMFUSE/pytorch_pretrained/__pycache__/tokenization_transfo_xl.cpython-36.pyc -------------------------------------------------------------------------------- /COMFUSE/pytorch_pretrained/__pycache__/tokenization_transfo_xl.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jncsnlp/FSL-Multimodal-Rumor-Detection/HEAD/COMFUSE/pytorch_pretrained/__pycache__/tokenization_transfo_xl.cpython-37.pyc -------------------------------------------------------------------------------- /COMFUSE/pytorch_pretrained/__pycache__/modeling_transfo_xl_utilities.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jncsnlp/FSL-Multimodal-Rumor-Detection/HEAD/COMFUSE/pytorch_pretrained/__pycache__/modeling_transfo_xl_utilities.cpython-36.pyc -------------------------------------------------------------------------------- /COMFUSE/pytorch_pretrained/__pycache__/modeling_transfo_xl_utilities.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jncsnlp/FSL-Multimodal-Rumor-Detection/HEAD/COMFUSE/pytorch_pretrained/__pycache__/modeling_transfo_xl_utilities.cpython-37.pyc -------------------------------------------------------------------------------- /COMFUSE/pytorch_pretrained/__pycache__/convert_tf_checkpoint_to_pytorch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jncsnlp/FSL-Multimodal-Rumor-Detection/HEAD/COMFUSE/pytorch_pretrained/__pycache__/convert_tf_checkpoint_to_pytorch.cpython-37.pyc -------------------------------------------------------------------------------- /COMFUSE/pytorch_pretrained/__pycache__/convert_gpt2_checkpoint_to_pytorch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jncsnlp/FSL-Multimodal-Rumor-Detection/HEAD/COMFUSE/pytorch_pretrained/__pycache__/convert_gpt2_checkpoint_to_pytorch.cpython-37.pyc -------------------------------------------------------------------------------- /COMFUSE/pytorch_pretrained/__pycache__/convert_openai_checkpoint_to_pytorch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jncsnlp/FSL-Multimodal-Rumor-Detection/HEAD/COMFUSE/pytorch_pretrained/__pycache__/convert_openai_checkpoint_to_pytorch.cpython-37.pyc -------------------------------------------------------------------------------- /COMFUSE/pytorch_pretrained/__pycache__/convert_transfo_xl_checkpoint_to_pytorch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jncsnlp/FSL-Multimodal-Rumor-Detection/HEAD/COMFUSE/pytorch_pretrained/__pycache__/convert_transfo_xl_checkpoint_to_pytorch.cpython-37.pyc -------------------------------------------------------------------------------- /COMFUSE/Pheme_DataSet_Pair_comments/data/class.txt: -------------------------------------------------------------------------------- 1 | ottawashooting-rumor 2 | ottawashooting-nonrumor 3 | sydneysiege-rumor 4 | sydneysiege-nonrumor 5 | ferguson-rumor 6 | ferguson-nonrumor 7 | charliehebdo-rumor 8 | charliehebdo-nonrumor 9 | germanwings-rumor 10 | germanwings-nonrumor 11 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FSL-Multimodal-Rumor-Detection 2 | 3 | Code and Data for paper submitted to peerj-computer-science, Accepted. 4 | 5 | "A novel few-shot learning based multi-modality fusion model for COVID-19 rumor detection from online social media" 6 | 7 | ## Cite Information 8 | 9 | Lu H, Fan C, Song X, Fang W. 2021. A novel few-shot learning based multi-modality fusion model for COVID-19 rumor detection from online social media. PeerJ Computer Science 7:e688 https://doi.org/10.7717/peerj-cs.688 10 | -------------------------------------------------------------------------------- /COMFUSE/bert_pretrain/bert_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "directionality": "bidi", 4 | "hidden_act": "gelu", 5 | "hidden_dropout_prob": 0.1, 6 | "hidden_size": 768, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 3072, 9 | "max_position_embeddings": 512, 10 | "num_attention_heads": 12, 11 | "num_hidden_layers": 12, 12 | "pooler_fc_size": 768, 13 | "pooler_num_attention_heads": 12, 14 | "pooler_num_fc_layers": 3, 15 | "pooler_size_per_head": 128, 16 | "pooler_type": "first_token_transform", 17 | "type_vocab_size": 2, 18 | "vocab_size": 21128 19 | } 20 | -------------------------------------------------------------------------------- /COMFUSE/DataSet_pair_comments/data/class.txt: -------------------------------------------------------------------------------- 1 | MH370-rumor 2 | MH370-nonrumor 3 | CollegeEntranceExams-rumor 4 | CollegeEntranceExams-nonrumor 5 | Olympic-rumor 6 | Olympic-nonrumor 7 | UrbanManagers-rumor 8 | UrbanManagers-nonrumor 9 | Cola-rumor 10 | Cola-nonrumor 11 | ChildTrafficking-rumor 12 | ChildTrafficking-nonrumor 13 | WasteOil-rumor 14 | WasteOil-nonrumor 15 | Accident-rumor 16 | Accident-nonrumor 17 | Earthquake-rumor 18 | Earthquake-nonrumor 19 | Typhoon-rumor 20 | Typhoon-nonrumor 21 | Rabies-rumor 22 | Rabies-nonrumor 23 | COVID-Zhongnanshan-rumor 24 | COVID-Zhongnanshan-nonrumor 25 | COVID-Wuhan-rumor 26 | COVID-Wuhan-nonrumor 27 | COVID-BlockCity-rumor 28 | COVID-BlockCity-nonrumor 29 | -------------------------------------------------------------------------------- /COMFUSE/pytorch_pretrained/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.6.2" 2 | from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer 3 | from .tokenization_openai import OpenAIGPTTokenizer 4 | from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus) 5 | from .tokenization_gpt2 import GPT2Tokenizer 6 | 7 | from .modeling import (BertConfig, BertModel, BertForPreTraining, 8 | BertForMaskedLM, BertForNextSentencePrediction, 9 | BertForSequenceClassification, BertForMultipleChoice, 10 | BertForTokenClassification, BertForQuestionAnswering, 11 | load_tf_weights_in_bert) 12 | from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTModel, 13 | OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel, 14 | load_tf_weights_in_openai_gpt) 15 | from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel, 16 | load_tf_weights_in_transfo_xl) 17 | from .modeling_gpt2 import (GPT2Config, GPT2Model, 18 | GPT2LMHeadModel, GPT2DoubleHeadsModel, 19 | load_tf_weights_in_gpt2) 20 | 21 | from .optimization import BertAdam 22 | from .optimization_openai import OpenAIAdam 23 | 24 | from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE, cached_path, WEIGHTS_NAME, CONFIG_NAME 25 | -------------------------------------------------------------------------------- /COMFUSE/pytorch_pretrained/convert_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert BERT checkpoint.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import re 23 | import argparse 24 | import tensorflow as tf 25 | import torch 26 | import numpy as np 27 | 28 | from pytorch_pretrained_bert.modeling import BertConfig, BertForPreTraining, load_tf_weights_in_bert 29 | 30 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): 31 | # Initialise PyTorch model 32 | config = BertConfig.from_json_file(bert_config_file) 33 | print("Building PyTorch model from configuration: {}".format(str(config))) 34 | model = BertForPreTraining(config) 35 | 36 | # Load weights from tf checkpoint 37 | load_tf_weights_in_bert(model, tf_checkpoint_path) 38 | 39 | # Save pytorch-model 40 | print("Save PyTorch model to {}".format(pytorch_dump_path)) 41 | torch.save(model.state_dict(), pytorch_dump_path) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser() 46 | ## Required parameters 47 | parser.add_argument("--tf_checkpoint_path", 48 | default = None, 49 | type = str, 50 | required = True, 51 | help = "Path the TensorFlow checkpoint path.") 52 | parser.add_argument("--bert_config_file", 53 | default = None, 54 | type = str, 55 | required = True, 56 | help = "The config json file corresponding to the pre-trained BERT model. \n" 57 | "This specifies the model architecture.") 58 | parser.add_argument("--pytorch_dump_path", 59 | default = None, 60 | type = str, 61 | required = True, 62 | help = "Path to the output PyTorch model.") 63 | args = parser.parse_args() 64 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, 65 | args.bert_config_file, 66 | args.pytorch_dump_path) 67 | -------------------------------------------------------------------------------- /COMFUSE/config.py: -------------------------------------------------------------------------------- 1 | # n_way classification 2 | # hidden size 3 | def make_lstm_config(n_way, h_size): 4 | print('make lstm config') 5 | 6 | config = [ 7 | ('linear', [256, 768]), # character encoder 8 | ('gru', [h_size, 256]), # GRU encoder-decoder 9 | #### classifier 10 | ('linear', [n_way, h_size]), # classifier 11 | ] 12 | return config 13 | 14 | 15 | def make_bilstm_config(n_way, h_size): 16 | print('make bilstm config') 17 | 18 | config = [ 19 | ('linear', [256, 768]), # character encoder 20 | ('gru_bi', [h_size, 256]), # GRU encoder-decoder 21 | #### classifier 22 | ('linear', [n_way, h_size]), # classifier 23 | ] 24 | return config 25 | 26 | 27 | def make_lstm_multi_task_config(n_way, h_size, n_topic): 28 | print('make lstm multitask config') 29 | 30 | config = [ 31 | ('linear', [256, 768]), # character encoder 32 | ('gru', [h_size, 256]), # GRU encoder-decoder 33 | ('linear', [h_size, h_size]), 34 | ('relu', [True]), 35 | #### classifier 36 | ('linear', [n_way, h_size]), # classifier for rumor classification 37 | ('linear', [n_topic, h_size]), # classifier for topic classification 38 | ] 39 | return config 40 | 41 | 42 | def make_bilstm_multi_task_config(n_way, h_size, n_topic): 43 | print('make lstm multitask config') 44 | 45 | config = [ 46 | ('linear', [256, 768]), # character encoder 47 | ('gru_bi', [h_size, 256]), # GRU encoder-decoder 48 | ('linear', [h_size, h_size]), 49 | ('relu', [True]), 50 | #### classifier 51 | ('linear', [n_way, h_size]), # classifier for rumor classification 52 | ('linear', [n_topic, h_size]), # classifier for topic classification 53 | ] 54 | return config 55 | 56 | 57 | def make_bilstm_multi_layer_multi_task_config(n_way, h_size, n_layer, n_topic): 58 | print('make lstm multitask config') 59 | 60 | config = [ 61 | ('linear', [256, 768]), # character encoder 62 | ('gru_bi_multilayer', [h_size, 256, n_layer]), # GRU encoder-decoder 63 | ('linear', [h_size, h_size]), 64 | ('relu', [True]), 65 | #### classifier 66 | ('linear', [n_way, h_size]), # classifier for rumor classification 67 | ('linear', [n_topic, h_size]), # classifier for topic classification 68 | ] 69 | return config 70 | 71 | 72 | ########################################################### 73 | ## baselines 74 | 75 | def make_fcn_baseline_config(n_way, h_size): 76 | print('make lstm config') 77 | 78 | config = [ 79 | ('linear', [256, 768]), # character encoder 80 | #### classifier 81 | ('linear', [n_way, 256]), # classifier 82 | ] 83 | return config 84 | 85 | 86 | def make_rnn_baseline_config(n_way, h_size): 87 | print('make lstm config') 88 | 89 | config = [ 90 | ('linear', [256, 768]), # character encoder 91 | ('gru_bi_multilayer', [h_size, 256, 2]), # GRU encoder-decoder 92 | #### classifier 93 | ('linear', [n_way, h_size]), # classifier 94 | ] 95 | return config 96 | -------------------------------------------------------------------------------- /COMFUSE/pytorch_pretrained/convert_gpt2_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | from io import open 21 | 22 | import torch 23 | 24 | from pytorch_pretrained_bert.modeling_gpt2 import (CONFIG_NAME, WEIGHTS_NAME, 25 | GPT2Config, 26 | GPT2Model, 27 | load_tf_weights_in_gpt2) 28 | 29 | 30 | def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path): 31 | # Construct model 32 | if gpt2_config_file == "": 33 | config = GPT2Config() 34 | else: 35 | config = GPT2Config(gpt2_config_file) 36 | model = GPT2Model(config) 37 | 38 | # Load weights from numpy 39 | load_tf_weights_in_gpt2(model, gpt2_checkpoint_path) 40 | 41 | # Save pytorch-model 42 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME 43 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME 44 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 45 | torch.save(model.state_dict(), pytorch_weights_dump_path) 46 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 47 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 48 | f.write(config.to_json_string()) 49 | 50 | 51 | if __name__ == "__main__": 52 | parser = argparse.ArgumentParser() 53 | ## Required parameters 54 | parser.add_argument("--gpt2_checkpoint_path", 55 | default = None, 56 | type = str, 57 | required = True, 58 | help = "Path the TensorFlow checkpoint path.") 59 | parser.add_argument("--pytorch_dump_folder_path", 60 | default = None, 61 | type = str, 62 | required = True, 63 | help = "Path to the output PyTorch model.") 64 | parser.add_argument("--gpt2_config_file", 65 | default = "", 66 | type = str, 67 | help = "An optional config json file corresponding to the pre-trained OpenAI model. \n" 68 | "This specifies the model architecture.") 69 | args = parser.parse_args() 70 | convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path, 71 | args.gpt2_config_file, 72 | args.pytorch_dump_folder_path) 73 | -------------------------------------------------------------------------------- /COMFUSE/pytorch_pretrained/convert_openai_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | from io import open 21 | 22 | import torch 23 | 24 | from pytorch_pretrained_bert.modeling_openai import (CONFIG_NAME, WEIGHTS_NAME, 25 | OpenAIGPTConfig, 26 | OpenAIGPTModel, 27 | load_tf_weights_in_openai_gpt) 28 | 29 | 30 | def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path): 31 | # Construct model 32 | if openai_config_file == "": 33 | config = OpenAIGPTConfig() 34 | else: 35 | config = OpenAIGPTConfig(openai_config_file) 36 | model = OpenAIGPTModel(config) 37 | 38 | # Load weights from numpy 39 | load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path) 40 | 41 | # Save pytorch-model 42 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME 43 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME 44 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 45 | torch.save(model.state_dict(), pytorch_weights_dump_path) 46 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 47 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 48 | f.write(config.to_json_string()) 49 | 50 | 51 | if __name__ == "__main__": 52 | parser = argparse.ArgumentParser() 53 | ## Required parameters 54 | parser.add_argument("--openai_checkpoint_folder_path", 55 | default = None, 56 | type = str, 57 | required = True, 58 | help = "Path the TensorFlow checkpoint path.") 59 | parser.add_argument("--pytorch_dump_folder_path", 60 | default = None, 61 | type = str, 62 | required = True, 63 | help = "Path to the output PyTorch model.") 64 | parser.add_argument("--openai_config_file", 65 | default = "", 66 | type = str, 67 | help = "An optional config json file corresponding to the pre-trained OpenAI model. \n" 68 | "This specifies the model architecture.") 69 | args = parser.parse_args() 70 | convert_openai_checkpoint_to_pytorch(args.openai_checkpoint_folder_path, 71 | args.openai_config_file, 72 | args.pytorch_dump_folder_path) 73 | -------------------------------------------------------------------------------- /COMFUSE/pytorch_pretrained/__main__.py: -------------------------------------------------------------------------------- 1 | # coding: utf8 2 | def main(): 3 | import sys 4 | if (len(sys.argv) != 4 and len(sys.argv) != 5) or sys.argv[1] not in [ 5 | "convert_tf_checkpoint_to_pytorch", 6 | "convert_openai_checkpoint", 7 | "convert_transfo_xl_checkpoint", 8 | "convert_gpt2_checkpoint", 9 | ]: 10 | print( 11 | "Should be used as one of: \n" 12 | ">> `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`, \n" 13 | ">> `pytorch_pretrained_bert convert_openai_checkpoint OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`, \n" 14 | ">> `pytorch_pretrained_bert convert_transfo_xl_checkpoint TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG]` or \n" 15 | ">> `pytorch_pretrained_bert convert_gpt2_checkpoint TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG]`") 16 | else: 17 | if sys.argv[1] == "convert_tf_checkpoint_to_pytorch": 18 | try: 19 | from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch 20 | except ImportError: 21 | print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, " 22 | "In that case, it requires TensorFlow to be installed. Please see " 23 | "https://www.tensorflow.org/install/ for installation instructions.") 24 | raise 25 | 26 | if len(sys.argv) != 5: 27 | # pylint: disable=line-too-long 28 | print("Should be used as `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`") 29 | else: 30 | PYTORCH_DUMP_OUTPUT = sys.argv.pop() 31 | TF_CONFIG = sys.argv.pop() 32 | TF_CHECKPOINT = sys.argv.pop() 33 | convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) 34 | elif sys.argv[1] == "convert_openai_checkpoint": 35 | from .convert_openai_checkpoint_to_pytorch import convert_openai_checkpoint_to_pytorch 36 | OPENAI_GPT_CHECKPOINT_FOLDER_PATH = sys.argv[2] 37 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 38 | if len(sys.argv) == 5: 39 | OPENAI_GPT_CONFIG = sys.argv[4] 40 | else: 41 | OPENAI_GPT_CONFIG = "" 42 | convert_openai_checkpoint_to_pytorch(OPENAI_GPT_CHECKPOINT_FOLDER_PATH, 43 | OPENAI_GPT_CONFIG, 44 | PYTORCH_DUMP_OUTPUT) 45 | elif sys.argv[1] == "convert_transfo_xl_checkpoint": 46 | try: 47 | from .convert_transfo_xl_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch 48 | except ImportError: 49 | print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, " 50 | "In that case, it requires TensorFlow to be installed. Please see " 51 | "https://www.tensorflow.org/install/ for installation instructions.") 52 | raise 53 | 54 | if 'ckpt' in sys.argv[2].lower(): 55 | TF_CHECKPOINT = sys.argv[2] 56 | TF_DATASET_FILE = "" 57 | else: 58 | TF_DATASET_FILE = sys.argv[2] 59 | TF_CHECKPOINT = "" 60 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 61 | if len(sys.argv) == 5: 62 | TF_CONFIG = sys.argv[4] 63 | else: 64 | TF_CONFIG = "" 65 | convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT, TF_DATASET_FILE) 66 | else: 67 | try: 68 | from .convert_gpt2_checkpoint_to_pytorch import convert_gpt2_checkpoint_to_pytorch 69 | except ImportError: 70 | print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, " 71 | "In that case, it requires TensorFlow to be installed. Please see " 72 | "https://www.tensorflow.org/install/ for installation instructions.") 73 | raise 74 | 75 | TF_CHECKPOINT = sys.argv[2] 76 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 77 | if len(sys.argv) == 5: 78 | TF_CONFIG = sys.argv[4] 79 | else: 80 | TF_CONFIG = "" 81 | convert_gpt2_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) 82 | if __name__ == '__main__': 83 | main() 84 | -------------------------------------------------------------------------------- /COMFUSE/learn_emb.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from sklearn import metrics 7 | import time 8 | from dataloader import get_time_dif 9 | from pytorch_pretrained.optimization import BertAdam 10 | 11 | 12 | # 权重初始化,默认xavier 13 | def init_network(model, method='xavier', exclude='embedding', seed=123): 14 | for name, w in model.named_parameters(): 15 | if exclude not in name: 16 | if len(w.size()) < 2: 17 | continue 18 | if 'weight' in name: 19 | if method == 'xavier': 20 | nn.init.xavier_normal_(w) 21 | elif method == 'kaiming': 22 | nn.init.kaiming_normal_(w) 23 | else: 24 | nn.init.normal_(w) 25 | elif 'bias' in name: 26 | nn.init.constant_(w, 0) 27 | else: 28 | pass 29 | 30 | 31 | def extract_emb(config, model, data_iter, featstype): 32 | start_time = time.time() 33 | model.train() 34 | param_optimizer = list(model.named_parameters()) 35 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 36 | optimizer_grouped_parameters = [ 37 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, 38 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}] 39 | # optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate) 40 | optimizer = BertAdam(optimizer_grouped_parameters, 41 | lr=config.learning_rate, 42 | warmup=0.05, 43 | t_total=len(data_iter) * config.num_epochs) 44 | total_batch = 0 # 记录进行到多少batch 45 | dev_best_loss = float('inf') 46 | last_improve = 0 # 记录上次验证集loss下降的batch数 47 | flag = False # 记录是否很久没有效果提升 48 | model.train() 49 | 50 | # 存储训练集/验证集/测试集的embedding 51 | data_embs_con = [] 52 | data_embs_com = [] 53 | 54 | 55 | for i, (docs_comments, labels) in enumerate(data_iter): 56 | # print(docs) 57 | if featstype == "pooled": 58 | outputs_con, outputs_com = model.get_pooled(docs_comments) 59 | elif featstype == "emb_outs": 60 | outputs_con, outputs_com = model.get_emb(docs_comments) 61 | elif featstype == "enc_layer": 62 | outputs_con, outputs_com = model.get_enc(docs_comments) 63 | 64 | # 输出bert embedding 65 | # print(outputs.size()) 66 | np_outputs_con = outputs_con.cpu().detach().numpy() 67 | np_outputs_com = outputs_com.cpu().detach().numpy() 68 | data_embs_con.append(np_outputs_con) 69 | data_embs_com.append(np_outputs_com) 70 | 71 | # print(len(data_embs)) 72 | return data_embs_con, data_embs_com 73 | 74 | 75 | 76 | 77 | # def test(config, model, test_iter): 78 | # # test 79 | # model.load_state_dict(torch.load(config.save_path)) 80 | # model.eval() 81 | # start_time = time.time() 82 | # test_acc, test_loss, test_report, test_confusion = evaluate(config, model, test_iter, test=True) 83 | # msg = 'Test Loss: {0:>5.2}, Test Acc: {1:>6.2%}' 84 | # print(msg.format(test_loss, test_acc)) 85 | # print("Precision, Recall and F1-Score...") 86 | # print(test_report) 87 | # print("Confusion Matrix...") 88 | # print(test_confusion) 89 | # time_dif = get_time_dif(start_time) 90 | # print("Time usage:", time_dif) 91 | 92 | 93 | # def evaluate(config, model, data_iter, test=False): 94 | # model.eval() 95 | # loss_total = 0 96 | # predict_all = np.array([], dtype=int) 97 | # labels_all = np.array([], dtype=int) 98 | # with torch.no_grad(): 99 | # for texts, labels in data_iter: 100 | # outputs = model(texts) 101 | # loss = F.cross_entropy(outputs, labels) 102 | # loss_total += loss 103 | # labels = labels.data.cpu().numpy() 104 | # predic = torch.max(outputs.data, 1)[1].cpu().numpy() 105 | # labels_all = np.append(labels_all, labels) 106 | # predict_all = np.append(predict_all, predic) 107 | 108 | # acc = metrics.accuracy_score(labels_all, predict_all) 109 | # if test: 110 | # labels = [i for i in range(0,28)] 111 | # report = metrics.classification_report(labels_all, predict_all, labels = labels, target_names=config.class_list, digits=4) 112 | # confusion = metrics.confusion_matrix(labels_all, predict_all) 113 | # return acc, loss_total / len(data_iter), report, confusion 114 | # return acc, loss_total / len(data_iter) -------------------------------------------------------------------------------- /COMFUSE/bertmodels/bert.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | import torch 3 | import torch.nn as nn 4 | # from pytorch_pretrained_bert import BertModel, BertTokenizer 5 | from pytorch_pretrained import BertModel, BertTokenizer 6 | 7 | 8 | class Config(object): 9 | 10 | """配置参数""" 11 | def __init__(self, data_path, doclabel_file, bert_path=None): 12 | self.model_name = 'bert' 13 | self.doc_path = doclabel_file # 文本类标数据路径 14 | # self.dev_path = dataset + '/data/dev.txt' # 验证集 15 | # self.test_path = dataset + '/data/test.txt' # 测试集 16 | self.class_list = [x.strip() for x in open( 17 | data_path + '/data/class.txt').readlines()] # 类别名单 18 | self.save_path = data_path + '/saved_dict/' + self.model_name + '.ckpt' # 模型训练结果 19 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备 20 | 21 | # self.require_improvement = 1000 # 若超过1000batch效果还没提升,则提前结束训练 22 | self.require_improvement = 1000 # 若超过1000batch效果还没提升,则提前结束训练 23 | self.num_classes = len(self.class_list) # 类别数 24 | self.num_epochs = 1 # epoch数 25 | self.batch_size = 1 # mini-batch大小 26 | # self.batch_size = 128 # mini-batch大小 27 | self.pad_size = 32 # 每句话处理成的长度(短填长切) 28 | self.learning_rate = 5e-6 # 学习率 29 | # self.learning_rate = 5e-5 # 学习率\ 30 | if bert_path is None: 31 | self.bert_path = './bert_pretrain' 32 | else: 33 | self.bert_path = bert_path 34 | print('Use bert path', self.bert_path) 35 | 36 | self.tokenizer = BertTokenizer.from_pretrained(self.bert_path) 37 | self.hidden_size = 768 38 | 39 | 40 | class Model(nn.Module): 41 | 42 | def __init__(self, config): 43 | super(Model, self).__init__() 44 | self.bert = BertModel.from_pretrained(config.bert_path) 45 | for param in self.bert.parameters(): 46 | param.requires_grad = True 47 | self.fc = nn.Linear(config.hidden_size, config.num_classes) 48 | 49 | def forward(self, x): 50 | context = x[0] # 输入的句子 51 | comment = x[1] # 输入的评论 52 | # print(context,comment) 53 | mask_con = x[3] # 对padding部分进行mask,和句子一个size,padding部分用0表示,如:[1, 1, 1, 1, 0, 0] 54 | mask_com = x[5] # 对padding部分进行mask,和句子一个size,padding部分用0表示,如:[1, 1, 1, 1, 0, 0] 55 | embedding_output, encoder_layer, pooled = self.bert(context, attention_mask=mask, output_all_encoded_layers=False) 56 | # print(embedding_output.size()) 57 | out = self.fc(pooled) 58 | return out 59 | 60 | def get_emb(self, x): 61 | context = x[0] # 输入的句子 62 | comment = x[1] # 输入的评论 63 | # print(context,comment) 64 | # print(context) 65 | mask_con = x[3] # 对padding部分进行mask,和句子一个size,padding部分用0表示,如:[1, 1, 1, 1, 0, 0] 66 | mask_com = x[5] # 对padding部分进行mask,和句子一个size,padding部分用0表示,如:[1, 1, 1, 1, 0, 0] 67 | # print(mask_com) 68 | embedding_output_con, encoder_layer_con, pooled_con = self.bert(context, attention_mask=mask_con, output_all_encoded_layers=False) 69 | embedding_output_com, encoder_layer_com, pooled_com = self.bert(comment, attention_mask=mask_com, output_all_encoded_layers=False) 70 | return embedding_output_con, embedding_output_com 71 | 72 | def get_enc(self, x): 73 | context = x[0] # 输入的句子 74 | comment = x[1] # 输入的评论 75 | # print(context) 76 | mask_con = x[3] # 对padding部分进行mask,和句子一个size,padding部分用0表示,如:[1, 1, 1, 1, 0, 0] 77 | mask_com = x[5] # 对padding部分进行mask,和句子一个size,padding部分用0表示,如:[1, 1, 1, 1, 0, 0] 78 | embedding_output_con, encoder_layer_con, pooled_con = self.bert(context, attention_mask=mask_con, output_all_encoded_layers=False) 79 | embedding_output_com, encoder_layer_com, pooled_com = self.bert(comment, attention_mask=mask_com, output_all_encoded_layers=False) 80 | return encoder_layer_con, embedding_output_com 81 | 82 | def get_pooled(self, x): 83 | context = x[0] # 输入的句子 84 | comment = x[1] # 输入的评论 85 | # print(context) 86 | mask_con = x[3] # 对padding部分进行mask,和句子一个size,padding部分用0表示,如:[1, 1, 1, 1, 0, 0] 87 | mask_com = x[5] # 对padding部分进行mask,和句子一个size,padding部分用0表示,如:[1, 1, 1, 1, 0, 0] 88 | embedding_output_con, encoder_layer_con, pooled_con = self.bert(context, attention_mask=mask_con, output_all_encoded_layers=False) 89 | embedding_output_com, encoder_layer_com, pooled_com = self.bert(comment, attention_mask=mask_com, output_all_encoded_layers=False) 90 | return pooled_con, pooled_com 91 | -------------------------------------------------------------------------------- /COMFUSE/pytorch_pretrained/optimization_openai.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for OpenAI GPT model.""" 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.optim.optimizer import required 21 | from torch.nn.utils import clip_grad_norm_ 22 | import logging 23 | from .optimization import SCHEDULES, _LRSchedule, WarmupCosineWithWarmupRestartsSchedule, \ 24 | WarmupCosineWithHardRestartsSchedule, WarmupCosineSchedule, WarmupLinearSchedule, WarmupConstantSchedule 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | 29 | class OpenAIAdam(Optimizer): 30 | """Implements Open AI version of Adam algorithm with weight decay fix. 31 | """ 32 | def __init__(self, params, lr=required, schedule='warmup_linear', warmup=-1, t_total=-1, 33 | b1=0.9, b2=0.999, e=1e-8, weight_decay=0, 34 | vector_l2=False, max_grad_norm=-1, **kwargs): 35 | if lr is not required and lr < 0.0: 36 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 37 | if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES: 38 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 39 | if not 0.0 <= b1 < 1.0: 40 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 41 | if not 0.0 <= b2 < 1.0: 42 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 43 | if not e >= 0.0: 44 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 45 | # initialize schedule object 46 | if not isinstance(schedule, _LRSchedule): 47 | schedule_type = SCHEDULES[schedule] 48 | schedule = schedule_type(warmup=warmup, t_total=t_total) 49 | else: 50 | if warmup != -1 or t_total != -1: 51 | logger.warning("warmup and t_total on the optimizer are ineffective when _LRSchedule object is provided as schedule. " 52 | "Please specify custom warmup and t_total in _LRSchedule object.") 53 | defaults = dict(lr=lr, schedule=schedule, 54 | b1=b1, b2=b2, e=e, weight_decay=weight_decay, vector_l2=vector_l2, 55 | max_grad_norm=max_grad_norm) 56 | super(OpenAIAdam, self).__init__(params, defaults) 57 | 58 | def get_lr(self): 59 | lr = [] 60 | for group in self.param_groups: 61 | for p in group['params']: 62 | state = self.state[p] 63 | if len(state) == 0: 64 | return [0] 65 | lr_scheduled = group['lr'] 66 | lr_scheduled *= group['schedule'].get_lr(state['step']) 67 | lr.append(lr_scheduled) 68 | return lr 69 | 70 | def step(self, closure=None): 71 | """Performs a single optimization step. 72 | 73 | Arguments: 74 | closure (callable, optional): A closure that reevaluates the model 75 | and returns the loss. 76 | """ 77 | loss = None 78 | if closure is not None: 79 | loss = closure() 80 | 81 | for group in self.param_groups: 82 | for p in group['params']: 83 | if p.grad is None: 84 | continue 85 | grad = p.grad.data 86 | if grad.is_sparse: 87 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 88 | 89 | state = self.state[p] 90 | 91 | # State initialization 92 | if len(state) == 0: 93 | state['step'] = 0 94 | # Exponential moving average of gradient values 95 | state['exp_avg'] = torch.zeros_like(p.data) 96 | # Exponential moving average of squared gradient values 97 | state['exp_avg_sq'] = torch.zeros_like(p.data) 98 | 99 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 100 | beta1, beta2 = group['b1'], group['b2'] 101 | 102 | state['step'] += 1 103 | 104 | # Add grad clipping 105 | if group['max_grad_norm'] > 0: 106 | clip_grad_norm_(p, group['max_grad_norm']) 107 | 108 | # Decay the first and second moment running average coefficient 109 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 110 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 111 | denom = exp_avg_sq.sqrt().add_(group['e']) 112 | 113 | bias_correction1 = 1 - beta1 ** state['step'] 114 | bias_correction2 = 1 - beta2 ** state['step'] 115 | 116 | lr_scheduled = group['lr'] 117 | lr_scheduled *= group['schedule'].get_lr(state['step']) 118 | 119 | step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 120 | 121 | p.data.addcdiv_(-step_size, exp_avg, denom) 122 | 123 | # Add weight decay at the end (fixed version) 124 | if (len(p.size()) > 1 or group['vector_l2']) and group['weight_decay'] > 0: 125 | p.data.add_(-lr_scheduled * group['weight_decay'], p.data) 126 | 127 | return loss 128 | -------------------------------------------------------------------------------- /COMFUSE/pytorch_pretrained/convert_transfo_xl_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert Transformer XL checkpoint and datasets.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | import os 21 | import sys 22 | from io import open 23 | 24 | import torch 25 | 26 | import pytorch_pretrained_bert.tokenization_transfo_xl as data_utils 27 | from pytorch_pretrained_bert.modeling_transfo_xl import (CONFIG_NAME, 28 | WEIGHTS_NAME, 29 | TransfoXLConfig, 30 | TransfoXLLMHeadModel, 31 | load_tf_weights_in_transfo_xl) 32 | from pytorch_pretrained_bert.tokenization_transfo_xl import (CORPUS_NAME, 33 | VOCAB_NAME) 34 | 35 | if sys.version_info[0] == 2: 36 | import cPickle as pickle 37 | else: 38 | import pickle 39 | 40 | # We do this to be able to load python 2 datasets pickles 41 | # See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918 42 | data_utils.Vocab = data_utils.TransfoXLTokenizer 43 | data_utils.Corpus = data_utils.TransfoXLCorpus 44 | sys.modules['data_utils'] = data_utils 45 | sys.modules['vocabulary'] = data_utils 46 | 47 | def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path, 48 | transfo_xl_config_file, 49 | pytorch_dump_folder_path, 50 | transfo_xl_dataset_file): 51 | if transfo_xl_dataset_file: 52 | # Convert a pre-processed corpus (see original TensorFlow repo) 53 | with open(transfo_xl_dataset_file, "rb") as fp: 54 | corpus = pickle.load(fp, encoding="latin1") 55 | # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term) 56 | pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_NAME 57 | print("Save vocabulary to {}".format(pytorch_vocab_dump_path)) 58 | corpus_vocab_dict = corpus.vocab.__dict__ 59 | torch.save(corpus_vocab_dict, pytorch_vocab_dump_path) 60 | 61 | corpus_dict_no_vocab = corpus.__dict__ 62 | corpus_dict_no_vocab.pop('vocab', None) 63 | pytorch_dataset_dump_path = pytorch_dump_folder_path + '/' + CORPUS_NAME 64 | print("Save dataset to {}".format(pytorch_dataset_dump_path)) 65 | torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path) 66 | 67 | if tf_checkpoint_path: 68 | # Convert a pre-trained TensorFlow model 69 | config_path = os.path.abspath(transfo_xl_config_file) 70 | tf_path = os.path.abspath(tf_checkpoint_path) 71 | 72 | print("Converting Transformer XL checkpoint from {} with config at {}".format(tf_path, config_path)) 73 | # Initialise PyTorch model 74 | if transfo_xl_config_file == "": 75 | config = TransfoXLConfig() 76 | else: 77 | config = TransfoXLConfig(transfo_xl_config_file) 78 | print("Building PyTorch model from configuration: {}".format(str(config))) 79 | model = TransfoXLLMHeadModel(config) 80 | 81 | model = load_tf_weights_in_transfo_xl(model, config, tf_path) 82 | # Save pytorch-model 83 | pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME) 84 | pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME) 85 | print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path))) 86 | torch.save(model.state_dict(), pytorch_weights_dump_path) 87 | print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path))) 88 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 89 | f.write(config.to_json_string()) 90 | 91 | 92 | if __name__ == "__main__": 93 | parser = argparse.ArgumentParser() 94 | parser.add_argument("--pytorch_dump_folder_path", 95 | default = None, 96 | type = str, 97 | required = True, 98 | help = "Path to the folder to store the PyTorch model or dataset/vocab.") 99 | parser.add_argument("--tf_checkpoint_path", 100 | default = "", 101 | type = str, 102 | help = "An optional path to a TensorFlow checkpoint path to be converted.") 103 | parser.add_argument("--transfo_xl_config_file", 104 | default = "", 105 | type = str, 106 | help = "An optional config json file corresponding to the pre-trained BERT model. \n" 107 | "This specifies the model architecture.") 108 | parser.add_argument("--transfo_xl_dataset_file", 109 | default = "", 110 | type = str, 111 | help = "An optional dataset file to be converted in a vocabulary.") 112 | args = parser.parse_args() 113 | convert_transfo_xl_checkpoint_to_pytorch(args.tf_checkpoint_path, 114 | args.transfo_xl_config_file, 115 | args.pytorch_dump_folder_path, 116 | args.transfo_xl_dataset_file) 117 | -------------------------------------------------------------------------------- /COMFUSE/dataloader.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | import torch 3 | from tqdm import tqdm 4 | import time 5 | from datetime import timedelta 6 | 7 | PAD, CLS = '[PAD]', '[CLS]' # padding符号, bert中综合信息符号 8 | 9 | 10 | def build_dataset(config): 11 | 12 | def load_dataset(path, pad_size=32, com_pad_size=32): 13 | print(path) 14 | contents_comments = [] 15 | with open(path, 'r', encoding='UTF-8') as f: 16 | for line in tqdm(f): 17 | # print(line) 18 | lin = line.strip() 19 | if not lin: 20 | continue 21 | content, comments, label = lin.split('\t') 22 | # tokenize content 23 | token_con = config.tokenizer.tokenize(content) 24 | token_con = [CLS] + token_con 25 | #print(token) 26 | seq_len = len(token_con) 27 | mask = [] 28 | token_con_ids = config.tokenizer.convert_tokens_to_ids(token_con) 29 | 30 | if pad_size: 31 | if len(token_con) < pad_size: 32 | mask = [1] * len(token_con_ids) + [0] * (pad_size - len(token_con)) 33 | token_con_ids += ([0] * (pad_size - len(token_con))) 34 | else: 35 | mask = [1] * pad_size 36 | token_con_ids = token_con_ids[:pad_size] 37 | seq_len = pad_size 38 | 39 | # tokenize comment 40 | # com_pad_size = 32 41 | all_comments = comments.split(';') 42 | all_token_com_ids = [] 43 | all_seq_len_com = [] 44 | all_mask_com = [] 45 | # print("=================",all_comments) 46 | for idx in range(3): 47 | # print(idx, all_comments) 48 | 49 | if (idx+1 > len(all_comments)): 50 | comment = "" 51 | # print(idx, comment) 52 | elif (len(all_comments[idx]) > 0): 53 | comment = all_comments[idx] 54 | # print(idx, comment) 55 | else: 56 | comment = "" 57 | # print(idx, comment) 58 | 59 | token_com = config.tokenizer.tokenize(comment) 60 | token_com = [CLS] + token_com 61 | #print(token) 62 | seq_len_com = len(token_com) 63 | mask_com = [] 64 | token_com_ids = config.tokenizer.convert_tokens_to_ids(token_com) 65 | # print(token_com_ids) 66 | 67 | if com_pad_size: 68 | if len(token_com) < com_pad_size: 69 | mask_com = [1] * len(token_com_ids) + [0] * (com_pad_size - len(token_com)) 70 | token_com_ids += ([0] * (com_pad_size - len(token_com))) 71 | else: 72 | mask_com = [1] * com_pad_size 73 | token_com_ids = token_com_ids[:com_pad_size] 74 | seq_len_com = com_pad_size 75 | # print(seq_len_com) 76 | # print(mask_com) 77 | all_token_com_ids.extend(token_com_ids) 78 | all_seq_len_com.append(seq_len_com) 79 | all_mask_com.extend(mask_com) 80 | # print(all_token_com_ids, all_seq_len_com, all_mask_com) 81 | 82 | contents_comments.append((token_con_ids, all_token_com_ids, int(label), seq_len, mask, all_seq_len_com, all_mask_com)) 83 | 84 | return contents_comments 85 | data = load_dataset(config.doc_path, config.pad_size, config.com_pad_size) 86 | return data 87 | 88 | 89 | class DatasetIterater(object): 90 | def __init__(self, batches, batch_size, device): 91 | self.batch_size = batch_size 92 | self.batches = batches 93 | self.n_batches = len(batches) // batch_size 94 | self.residue = False # 记录batch数量是否为整数 95 | if len(batches) % self.n_batches != 0: 96 | self.residue = True 97 | self.index = 0 98 | self.device = device 99 | 100 | def _to_tensor(self, datas): 101 | x_con = torch.LongTensor([_[0] for _ in datas]).to(self.device) 102 | x_com = torch.LongTensor([_[1] for _ in datas]).to(self.device) 103 | y = torch.LongTensor([_[2] for _ in datas]).to(self.device) 104 | 105 | # pad前的长度(超过pad_size的设为pad_size) 106 | seq_len_con = torch.LongTensor([_[3] for _ in datas]).to(self.device) 107 | seq_len_com = torch.LongTensor([_[5] for _ in datas]).to(self.device) 108 | mask_con = torch.LongTensor([_[4] for _ in datas]).to(self.device) 109 | mask_com = torch.LongTensor([_[6] for _ in datas]).to(self.device) 110 | return (x_con, x_com, seq_len_con, mask_con, seq_len_com, mask_com), y 111 | 112 | def __next__(self): 113 | if self.residue and self.index == self.n_batches: 114 | batches = self.batches[self.index * self.batch_size: len(self.batches)] 115 | self.index += 1 116 | batches = self._to_tensor(batches) 117 | return batches 118 | 119 | elif self.index >= self.n_batches: 120 | self.index = 0 121 | raise StopIteration 122 | else: 123 | batches = self.batches[self.index * self.batch_size: (self.index + 1) * self.batch_size] 124 | self.index += 1 125 | batches = self._to_tensor(batches) 126 | return batches 127 | 128 | def __iter__(self): 129 | return self 130 | 131 | def __len__(self): 132 | if self.residue: 133 | return self.n_batches + 1 134 | else: 135 | return self.n_batches 136 | 137 | 138 | def build_iterator(dataset, config): 139 | iter = DatasetIterater(dataset, config.batch_size, config.device) 140 | return iter 141 | 142 | 143 | def get_time_dif(start_time): 144 | """获取已使用时间""" 145 | end_time = time.time() 146 | time_dif = end_time - start_time 147 | return timedelta(seconds=int(round(time_dif))) 148 | -------------------------------------------------------------------------------- /COMFUSE/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Misc Utility functions 3 | """ 4 | import os 5 | import logging 6 | import datetime 7 | import numpy as np 8 | import torch 9 | from collections import OrderedDict 10 | from torch import nn 11 | import torch.nn.init as init 12 | 13 | 14 | def init_weights(m): 15 | if isinstance(m, nn.Conv1d): 16 | init.normal_(m.weight.data) 17 | if m.bias is not None: 18 | init.normal_(m.bias.data) 19 | elif isinstance(m, nn.Conv2d): 20 | init.xavier_normal_(m.weight.data) 21 | if m.bias is not None: 22 | init.normal_(m.bias.data) 23 | elif isinstance(m, nn.Conv3d): 24 | init.xavier_normal_(m.weight.data) 25 | if m.bias is not None: 26 | init.normal_(m.bias.data) 27 | elif isinstance(m, nn.ConvTranspose1d): 28 | init.normal_(m.weight.data) 29 | if m.bias is not None: 30 | init.normal_(m.bias.data) 31 | elif isinstance(m, nn.ConvTranspose2d): 32 | init.xavier_normal_(m.weight.data) 33 | if m.bias is not None: 34 | init.normal_(m.bias.data) 35 | elif isinstance(m, nn.ConvTranspose3d): 36 | init.xavier_normal_(m.weight.data) 37 | if m.bias is not None: 38 | init.normal_(m.bias.data) 39 | elif isinstance(m, nn.BatchNorm1d): 40 | init.normal_(m.weight.data, mean=1, std=0.02) 41 | init.constant_(m.bias.data, 0) 42 | elif isinstance(m, nn.BatchNorm2d): 43 | init.normal_(m.weight.data, mean=1, std=0.02) 44 | init.constant_(m.bias.data, 0) 45 | elif isinstance(m, nn.BatchNorm3d): 46 | init.normal_(m.weight.data, mean=1, std=0.02) 47 | init.constant_(m.bias.data, 0) 48 | elif isinstance(m, nn.Linear): 49 | init.xavier_normal_(m.weight.data) 50 | init.normal_(m.bias.data) 51 | elif isinstance(m, nn.LSTM): 52 | for param in m.parameters(): 53 | if len(param.shape) >= 2: 54 | init.orthogonal_(param.data) 55 | else: 56 | init.normal_(param.data) 57 | elif isinstance(m, nn.LSTMCell): 58 | for param in m.parameters(): 59 | if len(param.shape) >= 2: 60 | init.orthogonal_(param.data) 61 | else: 62 | init.normal_(param.data) 63 | elif isinstance(m, nn.GRU): 64 | for param in m.parameters(): 65 | if len(param.shape) >= 2: 66 | init.orthogonal_(param.data) 67 | else: 68 | init.normal_(param.data) 69 | elif isinstance(m, nn.GRUCell): 70 | for param in m.parameters(): 71 | if len(param.shape) >= 2: 72 | init.orthogonal_(param.data) 73 | else: 74 | init.normal_(param.data) 75 | 76 | 77 | def recursive_glob(rootdir=".", suffix=""): 78 | """Performs recursive glob with given suffix and rootdir 79 | :param rootdir is the root directory 80 | :param suffix is the suffix to be searched 81 | """ 82 | return [ 83 | os.path.join(looproot, filename) 84 | for looproot, _, filenames in os.walk(rootdir) 85 | for filename in filenames 86 | if filename.endswith(suffix) 87 | ] 88 | 89 | 90 | def alpha_blend(input_image, segmentation_mask, alpha=0.5): 91 | """Alpha Blending utility to overlay RGB masks on RBG images 92 | :param input_image is a np.ndarray with 3 channels 93 | :param segmentation_mask is a np.ndarray with 3 channels 94 | :param alpha is a float value 95 | """ 96 | # blended = np.zeros(input_image.size, dtype=np.float32) 97 | blended = input_image * alpha + segmentation_mask * (1 - alpha) 98 | return blended 99 | 100 | 101 | def convert_state_dict(state_dict): 102 | """Converts a state dict saved from a dataParallel module to normal 103 | module state_dict inplace 104 | :param state_dict is the loaded DataParallel model_state 105 | """ 106 | new_state_dict = OrderedDict() 107 | for k, v in state_dict.items(): 108 | name = k[7:] # remove `module.` 109 | new_state_dict[name] = v 110 | return new_state_dict 111 | 112 | 113 | def get_logger(logdir): 114 | logger = logging.getLogger("ptsemseg") 115 | ts = str(datetime.datetime.now()).split(".")[0].replace(" ", "_") 116 | ts = ts.replace(":", "_").replace("-", "_") 117 | file_path = os.path.join(logdir, "run_{}.log".format(ts)) 118 | hdlr = logging.FileHandler(file_path) 119 | formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s") 120 | hdlr.setFormatter(formatter) 121 | logger.addHandler(hdlr) 122 | logger.setLevel(logging.INFO) 123 | return logger 124 | 125 | 126 | def load_model(model, model_path, strict=True): 127 | model.load_state_dict(torch.load(model_path)["model_state"], strict=strict) 128 | 129 | 130 | def load_model_parallel(model, model_path): 131 | state = convert_state_dict(torch.load(model_path)["model_state"]) 132 | model.load_state_dict(state, strict=False) # for parallel 133 | 134 | 135 | def load_model_both(model, dir): 136 | model_dict = model.state_dict() 137 | print('loading model from :', dir) 138 | pretrained_dict = torch.load(dir)['params'] 139 | if 'encoder' in list(pretrained_dict.keys())[0]: # load from a parallel meta-trained model 140 | if 'module' in list(pretrained_dict.keys())[0]: 141 | pretrained_dict = {k[7:]: v for k, v in pretrained_dict.items()} 142 | else: 143 | pretrained_dict = {k: v for k, v in pretrained_dict.items()} 144 | else: 145 | pretrained_dict = {'encoder.' + k: v for k, v in pretrained_dict.items()} # load from a pretrained model 146 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 147 | model_dict.update(pretrained_dict) # update the param in encoder, remain others still 148 | model.load_state_dict(model_dict) 149 | 150 | return model 151 | 152 | 153 | class Averager(): 154 | 155 | def __init__(self): 156 | self.n = 0 157 | self.v = 0 158 | 159 | def add(self, x): 160 | self.v = (self.v * self.n + x) / (self.n + 1) 161 | self.n += 1 162 | 163 | def item(self): 164 | return self.v 165 | 166 | 167 | 168 | def count_acc(logits, label): 169 | pred = torch.argmax(logits, dim=1) 170 | return (pred == label).float().mean().item() 171 | 172 | 173 | 174 | def count_acc_mask(logits, label, n_way): 175 | pred = torch.argmax(logits, dim=1) 176 | idx = (label>=0) & (label= self.n_per 50 | pos = torch.randperm(len(l))[:self.n_per] # sample n_per data index of this class 51 | batch.append(l[pos]) 52 | batch = torch.stack(batch).t().reshape(-1) 53 | # .t() transpose, 54 | # due to it, the label is in the sequence of abcdabcdabcd form after reshape, 55 | # instead of aaaabbbbccccdddd 56 | yield batch 57 | 58 | elif self.mode == 'probe': 59 | print('Probe to fix val set') 60 | for i_batch in range(self.n_batch): 61 | batch = [] 62 | classes = torch.randperm(self.num_labels // 2)[ 63 | :self.n_cls // 2] # random sample num_class indices,e.g. 5 64 | for c in classes: 65 | c1 = c * 2 66 | c2 = c1 + 1 67 | for ci in [c1, c2]: 68 | l = self.m_ind[ci] # all data indexs of this class 69 | assert len(l) >= self.n_per 70 | pos = torch.randperm(len(l))[:self.n_per] # sample n_per data index of this class 71 | batch.append(l[pos]) 72 | 73 | self.fixed_batches.append(batch) 74 | batch_t = torch.stack(batch).t().reshape(-1) 75 | # .t() transpose, 76 | # due to it, the label is in the sequence of abcdabcdabcd form after reshape, 77 | # instead of aaaabbbbccccdddd 78 | yield batch_t 79 | 80 | else: 81 | assert self.mode == 'fix' 82 | assert len(self.fixed_batches) == self.n_batch 83 | # print(self.fixed_batches) 84 | for ix, batch in enumerate(self.fixed_batches): 85 | batch_t = torch.stack(batch).t().reshape(-1) 86 | yield batch_t 87 | 88 | 89 | class wb_rumor_fsl_dataset(Dataset): 90 | 91 | def __init__(self, split_name='train', split_no=0, model='bert', featstype='emb_outs', data_path='./DataSet_pair_comments/', 92 | pad_size=None, com_pad_size=None, bert_path=None): 93 | 94 | assert split_name in ['train', 'dev', 'test'] 95 | assert split_no in [0, 1, 2] 96 | 97 | self.split_name = split_name 98 | self.split_no = split_no 99 | self.data_path = data_path + "data" 100 | 101 | # doclabel_path:原始文件路径 102 | if split_name in ['train', 'dev']: 103 | # train or dev file: doc + "\t" + label 104 | # print(data_path) 105 | doclabel_file = os.path.join(self.data_path, '%s_%d.txt' % (split_name, split_no)) 106 | else: 107 | # test file: doc + "\t" + label 108 | doclabel_file = os.path.join(self.data_path, '%s.txt' % (split_name,)) 109 | 110 | # print(doclabel_file) 111 | assert os.path.isfile(doclabel_file) 112 | assert pad_size, 'Please set a valid pad size' 113 | 114 | print("Loading data...") 115 | 116 | model_name = model # bert 117 | print("model name: %s" % (model_name)) 118 | X = import_module('bertmodels.' + model_name) 119 | # print(X) 120 | config = X.Config(data_path, doclabel_file, bert_path) 121 | 122 | # load label 123 | # for x in open(doclabel_file, 'r').readlines(): 124 | # print(x.strip().split("\t")) 125 | # 126 | # print() 127 | # for x in open(doclabel_file, 'r', encoding='utf-8').readlines(): 128 | # print(x.strip().split("\t")[0]) 129 | # print(x.strip().split("\t")[2]) 130 | lines = [int(x.strip().split("\t")[2]) for x in open(doclabel_file, 'r', encoding='utf-8').readlines()] 131 | # print(lines) 132 | print('%s split %d has %d samples' % (split_name, split_no, len(lines))) 133 | # print(lines,'---- dataset Line 112') 134 | 135 | # load doc and label 136 | # doc_com_label: tokenIDs_con, tokenIDs_com, labels, seq_len_con, masks_con, seq_len_com(list), masks_com 137 | config.pad_size = pad_size 138 | config.com_pad_size = com_pad_size 139 | doc_com_label = build_dataset(config) 140 | # print(doc_label[0][2]) 141 | data_iter = build_iterator(doc_com_label, config) 142 | 143 | # learn embeddings 144 | 145 | # train 146 | bertmodel = X.Model(config).to(config.device) 147 | emb_con, emb_com = extract_emb(config, bertmodel, data_iter, featstype) 148 | # print(len(emb_con), len(emb_com)) 149 | # print(len(emb_con[0]), len(emb_com[0])) 150 | # print(emb_con, emb_com) 151 | feat = np.array(emb_con) 152 | feat_com = np.array(emb_com) 153 | # print(feat.shape) 154 | # print(feat_com.shape) 155 | 156 | # load feature 157 | # feat = np.load(emd_file) 158 | print('Load numpy of shape', feat.shape) # (N, 1, PAD, D) 159 | assert len(feat.shape) == 4 160 | assert feat.shape[0] == len(lines) 161 | # self.feat = feat[:, 0, :, :] 162 | self.feat = feat[:, :, :, :] 163 | print(self.feat.shape) 164 | self.feat_com = feat_com[:, :, :, :] 165 | print(self.feat_com.shape) 166 | 167 | self.num_data = feat.shape[0] 168 | 169 | # normalize labels to start from zero 170 | self.raw_label_2_real_label = {} 171 | self.real_label_2_raw_label = {} 172 | raw_class = sorted(np.unique(lines)) 173 | print(raw_class) 174 | 175 | self.num_class = len(raw_class) # should ensure it's pairwise 176 | # check the raw_class should contain pairwise label, e.g., class 0,1 class 5,6 177 | assert self.num_class % 2 == 0, 'Should be pairwise rumor-nonrumor, so that should be even' 178 | for i in range(self.num_class // 2): 179 | assert raw_class[i * 2] + 1 == raw_class[ 180 | i * 2 + 1], 'Should be pairwise rumor-nonrumor, so that raw_class number should be pair' 181 | 182 | self.labels = [] 183 | for i in range(self.num_class): 184 | j = raw_class[i] 185 | self.raw_label_2_real_label[j] = i 186 | self.real_label_2_raw_label[i] = j 187 | 188 | print('raw class -> real class', self.raw_label_2_real_label) 189 | print('real class -> raw class', self.real_label_2_raw_label) 190 | print("Make sure the raw class to real class is continuous") 191 | 192 | labels = [] 193 | for i in range(len(lines)): 194 | cls = self.raw_label_2_real_label[lines[i]] 195 | labels.append(cls) 196 | # print('Real labels used in training', labels) 197 | self.all_labels = labels 198 | 199 | # 每个doc的seq len 200 | lens = [] 201 | # 每个comment的seq len 202 | lens_com = [] 203 | for doc_label_len_mask in doc_com_label: 204 | lens.append(doc_label_len_mask[3]) 205 | lens_com.append(doc_label_len_mask[5]) 206 | self.all_lens = lens 207 | self.all_lens_com = lens_com 208 | self.mode = 'norm' 209 | 210 | def __len__(self): 211 | return self.num_data 212 | 213 | def __getitem__(self, index): 214 | # for gaining fixed validation sets 215 | if self.mode == 'dummy': 216 | return 1 217 | 218 | # a single data in a batch 219 | if len(index) == 1: 220 | ft, ft_com, lb, ln, ln_com = self.feat[index], self.feat_com[index], self.all_labels[index], self.all_lens[index], self.all_lens_com[index] 221 | return ft, ft_com, lb, ln, ln_com 222 | 223 | feats = [] 224 | feats_com = [] 225 | lbs = [] 226 | lns = [] 227 | lns_com = [] 228 | # print(index) 229 | for ind in index: 230 | ft, ft_com, lb, ln, ln_com = self.feat[ind:ind + 1, :], self.feat_com[ind:ind + 1, :], self.all_labels[ind], self.all_lens[ind], self.all_lens_com[ind] 231 | feats.append(ft) 232 | feats_com.append(ft_com) 233 | lbs.append(lb) 234 | lns.append(ln) 235 | lns_com.append(ln_com) 236 | feats = np.concatenate(feats, axis=0) 237 | feats_com = np.concatenate(feats_com, axis=0) 238 | # print(feats.shape) 239 | lbs = np.array(lbs, dtype=np.int) # this lbs are the raw class labels, which will not be used in training 240 | lns = np.array(lns, dtype=np.int) 241 | lns_com = np.array(lns_com, dtype=np.int) 242 | # print(lbs) 243 | return feats, lbs, lns, feats_com, lns_com 244 | 245 | 246 | if __name__ == '__main__': 247 | 248 | if 1 == 2: 249 | for sp in ['train', 'dev', 'test']: 250 | for split_no in [0, 1, 2]: 251 | ds = wb_rumor_fsl_dataset(sp, split_no) 252 | # print(len(ds), np.unique(ds.label), ds.accumulate) 253 | 254 | pass 255 | elif 2 == 2: 256 | for sp in ['test']: 257 | for split_no in [0, 1, 2]: 258 | ds = wb_rumor_fsl_dataset(sp, split_no) 259 | # print(len(ds), np.unique(ds.label), ds.accumulate) 260 | 261 | pass 262 | -------------------------------------------------------------------------------- /COMFUSE/pytorch_pretrained/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for BERT model.""" 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.optim.optimizer import required 21 | from torch.nn.utils import clip_grad_norm_ 22 | import logging 23 | import abc 24 | import sys 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | 29 | if sys.version_info >= (3, 4): 30 | ABC = abc.ABC 31 | else: 32 | ABC = abc.ABCMeta('ABC', (), {}) 33 | 34 | 35 | class _LRSchedule(ABC): 36 | """ Parent of all LRSchedules here. """ 37 | warn_t_total = False # is set to True for schedules where progressing beyond t_total steps doesn't make sense 38 | def __init__(self, warmup=0.002, t_total=-1, **kw): 39 | """ 40 | :param warmup: what fraction of t_total steps will be used for linear warmup 41 | :param t_total: how many training steps (updates) are planned 42 | :param kw: 43 | """ 44 | super(_LRSchedule, self).__init__(**kw) 45 | if t_total < 0: 46 | logger.warning("t_total value of {} results in schedule not being applied".format(t_total)) 47 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 48 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 49 | warmup = max(warmup, 0.) 50 | self.warmup, self.t_total = float(warmup), float(t_total) 51 | self.warned_for_t_total_at_progress = -1 52 | 53 | def get_lr(self, step, nowarn=False): 54 | """ 55 | :param step: which of t_total steps we're on 56 | :param nowarn: set to True to suppress warning regarding training beyond specified 't_total' steps 57 | :return: learning rate multiplier for current update 58 | """ 59 | if self.t_total < 0: 60 | return 1. 61 | progress = float(step) / self.t_total 62 | ret = self.get_lr_(progress) 63 | # warning for exceeding t_total (only active with warmup_linear 64 | if not nowarn and self.warn_t_total and progress > 1. and progress > self.warned_for_t_total_at_progress: 65 | logger.warning( 66 | "Training beyond specified 't_total'. Learning rate multiplier set to {}. Please set 't_total' of {} correctly." 67 | .format(ret, self.__class__.__name__)) 68 | self.warned_for_t_total_at_progress = progress 69 | # end warning 70 | return ret 71 | 72 | @abc.abstractmethod 73 | def get_lr_(self, progress): 74 | """ 75 | :param progress: value between 0 and 1 (unless going beyond t_total steps) specifying training progress 76 | :return: learning rate multiplier for current update 77 | """ 78 | return 1. 79 | 80 | 81 | class ConstantLR(_LRSchedule): 82 | def get_lr_(self, progress): 83 | return 1. 84 | 85 | 86 | class WarmupCosineSchedule(_LRSchedule): 87 | """ 88 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. 89 | Decreases learning rate from 1. to 0. over remaining `1 - warmup` steps following a cosine curve. 90 | If `cycles` (default=0.5) is different from default, learning rate follows cosine function after warmup. 91 | """ 92 | warn_t_total = True 93 | def __init__(self, warmup=0.002, t_total=-1, cycles=.5, **kw): 94 | """ 95 | :param warmup: see LRSchedule 96 | :param t_total: see LRSchedule 97 | :param cycles: number of cycles. Default: 0.5, corresponding to cosine decay from 1. at progress==warmup and 0 at progress==1. 98 | :param kw: 99 | """ 100 | super(WarmupCosineSchedule, self).__init__(warmup=warmup, t_total=t_total, **kw) 101 | self.cycles = cycles 102 | 103 | def get_lr_(self, progress): 104 | if progress < self.warmup: 105 | return progress / self.warmup 106 | else: 107 | progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup 108 | return 0.5 * (1. + math.cos(math.pi * self.cycles * 2 * progress)) 109 | 110 | 111 | class WarmupCosineWithHardRestartsSchedule(WarmupCosineSchedule): 112 | """ 113 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. 114 | If `cycles` (default=1.) is different from default, learning rate follows `cycles` times a cosine decaying 115 | learning rate (with hard restarts). 116 | """ 117 | def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw): 118 | super(WarmupCosineWithHardRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw) 119 | assert(cycles >= 1.) 120 | 121 | def get_lr_(self, progress): 122 | if progress < self.warmup: 123 | return progress / self.warmup 124 | else: 125 | progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup 126 | ret = 0.5 * (1. + math.cos(math.pi * ((self.cycles * progress) % 1))) 127 | return ret 128 | 129 | 130 | class WarmupCosineWithWarmupRestartsSchedule(WarmupCosineWithHardRestartsSchedule): 131 | """ 132 | All training progress is divided in `cycles` (default=1.) parts of equal length. 133 | Every part follows a schedule with the first `warmup` fraction of the training steps linearly increasing from 0. to 1., 134 | followed by a learning rate decreasing from 1. to 0. following a cosine curve. 135 | """ 136 | def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw): 137 | assert(warmup * cycles < 1.) 138 | warmup = warmup * cycles if warmup >= 0 else warmup 139 | super(WarmupCosineWithWarmupRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw) 140 | 141 | def get_lr_(self, progress): 142 | progress = progress * self.cycles % 1. 143 | if progress < self.warmup: 144 | return progress / self.warmup 145 | else: 146 | progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup 147 | ret = 0.5 * (1. + math.cos(math.pi * progress)) 148 | return ret 149 | 150 | 151 | class WarmupConstantSchedule(_LRSchedule): 152 | """ 153 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. 154 | Keeps learning rate equal to 1. after warmup. 155 | """ 156 | def get_lr_(self, progress): 157 | if progress < self.warmup: 158 | return progress / self.warmup 159 | return 1. 160 | 161 | 162 | class WarmupLinearSchedule(_LRSchedule): 163 | """ 164 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. 165 | Linearly decreases learning rate from 1. to 0. over remaining `1 - warmup` steps. 166 | """ 167 | warn_t_total = True 168 | def get_lr_(self, progress): 169 | if progress < self.warmup: 170 | return progress / self.warmup 171 | return max((progress - 1.) / (self.warmup - 1.), 0.) 172 | 173 | 174 | SCHEDULES = { 175 | None: ConstantLR, 176 | "none": ConstantLR, 177 | "warmup_cosine": WarmupCosineSchedule, 178 | "warmup_constant": WarmupConstantSchedule, 179 | "warmup_linear": WarmupLinearSchedule 180 | } 181 | 182 | 183 | class BertAdam(Optimizer): 184 | """Implements BERT version of Adam algorithm with weight decay fix. 185 | Params: 186 | lr: learning rate 187 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 188 | t_total: total number of training steps for the learning 189 | rate schedule, -1 means constant learning rate of 1. (no warmup regardless of warmup setting). Default: -1 190 | schedule: schedule to use for the warmup (see above). 191 | Can be `'warmup_linear'`, `'warmup_constant'`, `'warmup_cosine'`, `'none'`, `None` or a `_LRSchedule` object (see below). 192 | If `None` or `'none'`, learning rate is always kept constant. 193 | Default : `'warmup_linear'` 194 | b1: Adams b1. Default: 0.9 195 | b2: Adams b2. Default: 0.999 196 | e: Adams epsilon. Default: 1e-6 197 | weight_decay: Weight decay. Default: 0.01 198 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0 199 | """ 200 | def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear', 201 | b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, max_grad_norm=1.0, **kwargs): 202 | if lr is not required and lr < 0.0: 203 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 204 | if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES: 205 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 206 | if not 0.0 <= b1 < 1.0: 207 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 208 | if not 0.0 <= b2 < 1.0: 209 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 210 | if not e >= 0.0: 211 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 212 | # initialize schedule object 213 | if not isinstance(schedule, _LRSchedule): 214 | schedule_type = SCHEDULES[schedule] 215 | schedule = schedule_type(warmup=warmup, t_total=t_total) 216 | else: 217 | if warmup != -1 or t_total != -1: 218 | logger.warning("warmup and t_total on the optimizer are ineffective when _LRSchedule object is provided as schedule. " 219 | "Please specify custom warmup and t_total in _LRSchedule object.") 220 | defaults = dict(lr=lr, schedule=schedule, 221 | b1=b1, b2=b2, e=e, weight_decay=weight_decay, 222 | max_grad_norm=max_grad_norm) 223 | super(BertAdam, self).__init__(params, defaults) 224 | 225 | def get_lr(self): 226 | lr = [] 227 | for group in self.param_groups: 228 | for p in group['params']: 229 | state = self.state[p] 230 | if len(state) == 0: 231 | return [0] 232 | lr_scheduled = group['lr'] 233 | lr_scheduled *= group['schedule'].get_lr(state['step']) 234 | lr.append(lr_scheduled) 235 | return lr 236 | 237 | def step(self, closure=None): 238 | """Performs a single optimization step. 239 | 240 | Arguments: 241 | closure (callable, optional): A closure that reevaluates the model 242 | and returns the loss. 243 | """ 244 | loss = None 245 | if closure is not None: 246 | loss = closure() 247 | 248 | for group in self.param_groups: 249 | for p in group['params']: 250 | if p.grad is None: 251 | continue 252 | grad = p.grad.data 253 | if grad.is_sparse: 254 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 255 | 256 | state = self.state[p] 257 | 258 | # State initialization 259 | if len(state) == 0: 260 | state['step'] = 0 261 | # Exponential moving average of gradient values 262 | state['next_m'] = torch.zeros_like(p.data) 263 | # Exponential moving average of squared gradient values 264 | state['next_v'] = torch.zeros_like(p.data) 265 | 266 | next_m, next_v = state['next_m'], state['next_v'] 267 | beta1, beta2 = group['b1'], group['b2'] 268 | 269 | # Add grad clipping 270 | if group['max_grad_norm'] > 0: 271 | clip_grad_norm_(p, group['max_grad_norm']) 272 | 273 | # Decay the first and second moment running average coefficient 274 | # In-place operations to update the averages at the same time 275 | next_m.mul_(beta1).add_(1 - beta1, grad) 276 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) 277 | update = next_m / (next_v.sqrt() + group['e']) 278 | 279 | # Just adding the square of the weights to the loss function is *not* 280 | # the correct way of using L2 regularization/weight decay with Adam, 281 | # since that will interact with the m and v parameters in strange ways. 282 | # 283 | # Instead we want to decay the weights in a manner that doesn't interact 284 | # with the m/v parameters. This is equivalent to adding the square 285 | # of the weights to the loss with plain (non-momentum) SGD. 286 | if group['weight_decay'] > 0.0: 287 | update += group['weight_decay'] * p.data 288 | 289 | lr_scheduled = group['lr'] 290 | lr_scheduled *= group['schedule'].get_lr(state['step']) 291 | 292 | update_with_lr = lr_scheduled * update 293 | p.data.add_(-update_with_lr) 294 | 295 | state['step'] += 1 296 | 297 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 298 | # No bias correction 299 | # bias_correction1 = 1 - beta1 ** state['step'] 300 | # bias_correction2 = 1 - beta2 ** state['step'] 301 | 302 | return loss 303 | -------------------------------------------------------------------------------- /COMFUSE/pytorch_pretrained/tokenization_gpt2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes for OpenAI GPT.""" 16 | from __future__ import (absolute_import, division, print_function, 17 | unicode_literals) 18 | 19 | import sys 20 | import json 21 | import logging 22 | import os 23 | import regex as re 24 | from io import open 25 | 26 | try: 27 | from functools import lru_cache 28 | except ImportError: 29 | # Just a dummy decorator to get the checks to run on python2 30 | # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now. 31 | def lru_cache(): 32 | return lambda func: func 33 | 34 | from .file_utils import cached_path 35 | 36 | logger = logging.getLogger(__name__) 37 | 38 | PRETRAINED_VOCAB_ARCHIVE_MAP = { 39 | 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json", 40 | } 41 | PRETRAINED_MERGES_ARCHIVE_MAP = { 42 | 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt", 43 | } 44 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { 45 | 'gpt2': 1024, 46 | } 47 | VOCAB_NAME = 'vocab.json' 48 | MERGES_NAME = 'merges.txt' 49 | SPECIAL_TOKENS_NAME = 'special_tokens.txt' 50 | 51 | @lru_cache() 52 | def bytes_to_unicode(): 53 | """ 54 | Returns list of utf-8 byte and a corresponding list of unicode strings. 55 | The reversible bpe codes work on unicode strings. 56 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 57 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 58 | This is a signficant percentage of your normal, say, 32K bpe vocab. 59 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 60 | And avoids mapping to whitespace/control characters the bpe code barfs on. 61 | """ 62 | _chr = unichr if sys.version_info[0] == 2 else chr 63 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 64 | cs = bs[:] 65 | n = 0 66 | for b in range(2**8): 67 | if b not in bs: 68 | bs.append(b) 69 | cs.append(2**8+n) 70 | n += 1 71 | cs = [_chr(n) for n in cs] 72 | return dict(zip(bs, cs)) 73 | 74 | def get_pairs(word): 75 | """Return set of symbol pairs in a word. 76 | 77 | Word is represented as tuple of symbols (symbols being variable-length strings). 78 | """ 79 | pairs = set() 80 | prev_char = word[0] 81 | for char in word[1:]: 82 | pairs.add((prev_char, char)) 83 | prev_char = char 84 | return pairs 85 | 86 | class GPT2Tokenizer(object): 87 | """ 88 | GPT-2 BPE tokenizer. Peculiarities: 89 | - Byte-level BPE 90 | """ 91 | @classmethod 92 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): 93 | """ 94 | Instantiate a PreTrainedBertModel from a pre-trained model file. 95 | Download and cache the pre-trained model file if needed. 96 | """ 97 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: 98 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] 99 | merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path] 100 | special_tokens_file = None 101 | else: 102 | vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME) 103 | merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME) 104 | special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME) 105 | if not os.path.exists(special_tokens_file): 106 | special_tokens_file = None 107 | else: 108 | logger.info("loading special tokens file {}".format(special_tokens_file)) 109 | # redirect to the cache, if necessary 110 | try: 111 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) 112 | resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir) 113 | except EnvironmentError: 114 | logger.error( 115 | "Model name '{}' was not found in model name list ({}). " 116 | "We assumed '{}' was a path or url but couldn't find files {} and {} " 117 | "at this path or url.".format( 118 | pretrained_model_name_or_path, 119 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), 120 | pretrained_model_name_or_path, 121 | vocab_file, merges_file)) 122 | return None 123 | if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file: 124 | logger.info("loading vocabulary file {}".format(vocab_file)) 125 | logger.info("loading merges file {}".format(merges_file)) 126 | else: 127 | logger.info("loading vocabulary file {} from cache at {}".format( 128 | vocab_file, resolved_vocab_file)) 129 | logger.info("loading merges file {} from cache at {}".format( 130 | merges_file, resolved_merges_file)) 131 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: 132 | # if we're using a pretrained model, ensure the tokenizer wont index sequences longer 133 | # than the number of positional embeddings 134 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] 135 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) 136 | # Instantiate tokenizer. 137 | if special_tokens_file and 'special_tokens' not in kwargs: 138 | special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1] 139 | else: 140 | special_tokens = kwargs.pop('special_tokens', []) 141 | tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs) 142 | return tokenizer 143 | 144 | def __init__(self, vocab_file, merges_file, errors='replace', special_tokens=None, max_len=None): 145 | self.max_len = max_len if max_len is not None else int(1e12) 146 | self.encoder = json.load(open(vocab_file)) 147 | self.decoder = {v:k for k,v in self.encoder.items()} 148 | self.errors = errors # how to handle errors in decoding 149 | self.byte_encoder = bytes_to_unicode() 150 | self.byte_decoder = {v:k for k, v in self.byte_encoder.items()} 151 | bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1] 152 | bpe_merges = [tuple(merge.split()) for merge in bpe_data] 153 | self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) 154 | self.cache = {} 155 | 156 | # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions 157 | self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") 158 | 159 | self.special_tokens = {} 160 | self.special_tokens_decoder = {} 161 | self.set_special_tokens(special_tokens) 162 | 163 | def __len__(self): 164 | return len(self.encoder) + len(self.special_tokens) 165 | 166 | def set_special_tokens(self, special_tokens): 167 | """ Add a list of additional tokens to the encoder. 168 | The additional tokens are indexed starting from the last index of the 169 | current vocabulary in the order of the `special_tokens` list. 170 | """ 171 | if not special_tokens: 172 | self.special_tokens = {} 173 | self.special_tokens_decoder = {} 174 | return 175 | self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens)) 176 | self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()} 177 | logger.info("Special tokens {}".format(self.special_tokens)) 178 | 179 | def bpe(self, token): 180 | if token in self.cache: 181 | return self.cache[token] 182 | word = tuple(token) 183 | pairs = get_pairs(word) 184 | 185 | if not pairs: 186 | return token 187 | 188 | while True: 189 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 190 | if bigram not in self.bpe_ranks: 191 | break 192 | first, second = bigram 193 | new_word = [] 194 | i = 0 195 | while i < len(word): 196 | try: 197 | j = word.index(first, i) 198 | new_word.extend(word[i:j]) 199 | i = j 200 | except: 201 | new_word.extend(word[i:]) 202 | break 203 | 204 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 205 | new_word.append(first+second) 206 | i += 2 207 | else: 208 | new_word.append(word[i]) 209 | i += 1 210 | new_word = tuple(new_word) 211 | word = new_word 212 | if len(word) == 1: 213 | break 214 | else: 215 | pairs = get_pairs(word) 216 | word = ' '.join(word) 217 | self.cache[token] = word 218 | return word 219 | 220 | def tokenize(self, text): 221 | """ Tokenize a string. """ 222 | bpe_tokens = [] 223 | for token in re.findall(self.pat, text): 224 | token = ''.join(self.byte_encoder[ord(b)] for b in token) 225 | bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' ')) 226 | return bpe_tokens 227 | 228 | def convert_tokens_to_ids(self, tokens): 229 | """ Converts a sequence of tokens into ids using the vocab. """ 230 | ids = [] 231 | if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)): 232 | if tokens in self.special_tokens: 233 | return self.special_tokens[tokens] 234 | else: 235 | return self.encoder.get(tokens, 0) 236 | for token in tokens: 237 | if token in self.special_tokens: 238 | ids.append(self.special_tokens[token]) 239 | else: 240 | ids.append(self.encoder.get(token, 0)) 241 | if len(ids) > self.max_len: 242 | logger.warning( 243 | "Token indices sequence length is longer than the specified maximum " 244 | " sequence length for this OpenAI GPT model ({} > {}). Running this" 245 | " sequence through the model will result in indexing errors".format(len(ids), self.max_len) 246 | ) 247 | return ids 248 | 249 | def convert_ids_to_tokens(self, ids, skip_special_tokens=False): 250 | """Converts a sequence of ids in BPE tokens using the vocab.""" 251 | tokens = [] 252 | for i in ids: 253 | if i in self.special_tokens_decoder: 254 | if not skip_special_tokens: 255 | tokens.append(self.special_tokens_decoder[i]) 256 | else: 257 | tokens.append(self.decoder[i]) 258 | return tokens 259 | 260 | def encode(self, text): 261 | return self.convert_tokens_to_ids(self.tokenize(text)) 262 | 263 | def decode(self, tokens): 264 | text = ''.join([self.decoder[token] for token in tokens]) 265 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) 266 | return text 267 | 268 | def save_vocabulary(self, vocab_path): 269 | """Save the tokenizer vocabulary and merge files to a directory.""" 270 | if not os.path.isdir(vocab_path): 271 | logger.error("Vocabulary path ({}) should be a directory".format(vocab_path)) 272 | return 273 | vocab_file = os.path.join(vocab_path, VOCAB_NAME) 274 | merge_file = os.path.join(vocab_path, MERGES_NAME) 275 | special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME) 276 | 277 | with open(vocab_file, 'w', encoding='utf-8') as f: 278 | f.write(json.dumps(self.encoder, ensure_ascii=False)) 279 | 280 | index = 0 281 | with open(merge_file, "w", encoding="utf-8") as writer: 282 | writer.write(u'#version: 0.2\n') 283 | for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): 284 | if index != token_index: 285 | logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive." 286 | " Please check that the tokenizer is not corrupted!".format(merge_file)) 287 | index = token_index 288 | writer.write(' '.join(bpe_tokens) + u'\n') 289 | index += 1 290 | 291 | index = len(self.encoder) 292 | with open(special_tokens_file, 'w', encoding='utf-8') as writer: 293 | for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]): 294 | if index != token_index: 295 | logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive." 296 | " Please check that the tokenizer is not corrupted!".format(special_tokens_file)) 297 | index = token_index 298 | writer.write(token + u'\n') 299 | index += 1 300 | 301 | return vocab_file, merge_file, special_tokens_file 302 | -------------------------------------------------------------------------------- /COMFUSE/pytorch_pretrained/tokenization_openai.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes for OpenAI GPT.""" 16 | from __future__ import (absolute_import, division, print_function, 17 | unicode_literals) 18 | 19 | import json 20 | import logging 21 | import os 22 | import re 23 | import sys 24 | from io import open 25 | 26 | from tqdm import tqdm 27 | 28 | from .file_utils import cached_path 29 | from .tokenization import BasicTokenizer 30 | 31 | logger = logging.getLogger(__name__) 32 | 33 | PRETRAINED_VOCAB_ARCHIVE_MAP = { 34 | 'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json", 35 | } 36 | PRETRAINED_MERGES_ARCHIVE_MAP = { 37 | 'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt", 38 | } 39 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { 40 | 'openai-gpt': 512, 41 | } 42 | VOCAB_NAME = 'vocab.json' 43 | MERGES_NAME = 'merges.txt' 44 | SPECIAL_TOKENS_NAME = 'special_tokens.txt' 45 | 46 | def get_pairs(word): 47 | """ 48 | Return set of symbol pairs in a word. 49 | word is represented as tuple of symbols (symbols being variable-length strings) 50 | """ 51 | pairs = set() 52 | prev_char = word[0] 53 | for char in word[1:]: 54 | pairs.add((prev_char, char)) 55 | prev_char = char 56 | return pairs 57 | 58 | def text_standardize(text): 59 | """ 60 | fixes some issues the spacy tokenizer had on books corpus 61 | also does some whitespace standardization 62 | """ 63 | text = text.replace('—', '-') 64 | text = text.replace('–', '-') 65 | text = text.replace('―', '-') 66 | text = text.replace('…', '...') 67 | text = text.replace('´', "'") 68 | text = re.sub(r'''(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)''', r' \1 ', text) 69 | text = re.sub(r'\s*\n\s*', ' \n ', text) 70 | text = re.sub(r'[^\S\n]+', ' ', text) 71 | return text.strip() 72 | 73 | class OpenAIGPTTokenizer(object): 74 | """ 75 | BPE tokenizer. Peculiarities: 76 | - lower case all inputs 77 | - uses SpaCy tokenizer and ftfy for pre-BPE tokenization if they are installed, fallback to BERT's BasicTokenizer if not. 78 | - argument special_tokens and function set_special_tokens: 79 | can be used to add additional symbols (ex: "__classify__") to a vocabulary. 80 | """ 81 | @classmethod 82 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): 83 | """ 84 | Instantiate a PreTrainedBertModel from a pre-trained model file. 85 | Download and cache the pre-trained model file if needed. 86 | """ 87 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: 88 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] 89 | merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path] 90 | special_tokens_file = None 91 | else: 92 | vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME) 93 | merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME) 94 | special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME) 95 | if not os.path.exists(special_tokens_file): 96 | special_tokens_file = None 97 | else: 98 | logger.info("loading special tokens file {}".format(special_tokens_file)) 99 | # redirect to the cache, if necessary 100 | try: 101 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) 102 | resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir) 103 | except EnvironmentError: 104 | logger.error( 105 | "Model name '{}' was not found in model name list ({}). " 106 | "We assumed '{}' was a path or url but couldn't find files {} and {} " 107 | "at this path or url.".format( 108 | pretrained_model_name_or_path, 109 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), 110 | pretrained_model_name_or_path, 111 | vocab_file, merges_file)) 112 | return None 113 | if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file: 114 | logger.info("loading vocabulary file {}".format(vocab_file)) 115 | logger.info("loading merges file {}".format(merges_file)) 116 | else: 117 | logger.info("loading vocabulary file {} from cache at {}".format( 118 | vocab_file, resolved_vocab_file)) 119 | logger.info("loading merges file {} from cache at {}".format( 120 | merges_file, resolved_merges_file)) 121 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: 122 | # if we're using a pretrained model, ensure the tokenizer wont index sequences longer 123 | # than the number of positional embeddings 124 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] 125 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) 126 | # Instantiate tokenizer. 127 | if special_tokens_file and 'special_tokens' not in kwargs: 128 | special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1] 129 | else: 130 | special_tokens = kwargs.pop('special_tokens', []) 131 | tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs) 132 | return tokenizer 133 | 134 | def __init__(self, vocab_file, merges_file, special_tokens=None, max_len=None): 135 | try: 136 | import ftfy 137 | import spacy 138 | self.nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat']) 139 | self.fix_text = ftfy.fix_text 140 | except ImportError: 141 | logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.") 142 | self.nlp = BasicTokenizer(do_lower_case=True, 143 | never_split=special_tokens if special_tokens is not None else []) 144 | self.fix_text = None 145 | 146 | self.max_len = max_len if max_len is not None else int(1e12) 147 | self.encoder = json.load(open(vocab_file, encoding="utf-8")) 148 | self.decoder = {v:k for k,v in self.encoder.items()} 149 | merges = open(merges_file, encoding='utf-8').read().split('\n')[1:-1] 150 | merges = [tuple(merge.split()) for merge in merges] 151 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 152 | self.cache = {} 153 | self.special_tokens = {} 154 | self.special_tokens_decoder = {} 155 | self.set_special_tokens(special_tokens) 156 | 157 | def __len__(self): 158 | return len(self.encoder) + len(self.special_tokens) 159 | 160 | def set_special_tokens(self, special_tokens): 161 | """ Add a list of additional tokens to the encoder. 162 | The additional tokens are indexed starting from the last index of the 163 | current vocabulary in the order of the `special_tokens` list. 164 | """ 165 | if not special_tokens: 166 | self.special_tokens = {} 167 | self.special_tokens_decoder = {} 168 | return 169 | self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens)) 170 | self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()} 171 | if self.fix_text is None: 172 | # Using BERT's BasicTokenizer: we can update the tokenizer 173 | self.nlp.never_split = special_tokens 174 | logger.info("Special tokens {}".format(self.special_tokens)) 175 | 176 | def bpe(self, token): 177 | word = tuple(token[:-1]) + (token[-1] + '',) 178 | if token in self.cache: 179 | return self.cache[token] 180 | pairs = get_pairs(word) 181 | 182 | if not pairs: 183 | return token+'' 184 | 185 | while True: 186 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) 187 | if bigram not in self.bpe_ranks: 188 | break 189 | first, second = bigram 190 | new_word = [] 191 | i = 0 192 | while i < len(word): 193 | try: 194 | j = word.index(first, i) 195 | new_word.extend(word[i:j]) 196 | i = j 197 | except: 198 | new_word.extend(word[i:]) 199 | break 200 | 201 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 202 | new_word.append(first+second) 203 | i += 2 204 | else: 205 | new_word.append(word[i]) 206 | i += 1 207 | new_word = tuple(new_word) 208 | word = new_word 209 | if len(word) == 1: 210 | break 211 | else: 212 | pairs = get_pairs(word) 213 | word = ' '.join(word) 214 | if word == '\n ': 215 | word = '\n' 216 | self.cache[token] = word 217 | return word 218 | 219 | def tokenize(self, text): 220 | """ Tokenize a string. """ 221 | split_tokens = [] 222 | if self.fix_text is None: 223 | # Using BERT's BasicTokenizer 224 | text = self.nlp.tokenize(text) 225 | for token in text: 226 | split_tokens.extend([t for t in self.bpe(token).split(' ')]) 227 | else: 228 | # Using SpaCy & ftfy (original tokenization process of OpenAI GPT) 229 | text = self.nlp(text_standardize(self.fix_text(text))) 230 | for token in text: 231 | split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')]) 232 | return split_tokens 233 | 234 | def convert_tokens_to_ids(self, tokens): 235 | """ Converts a sequence of tokens into ids using the vocab. """ 236 | ids = [] 237 | if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)): 238 | if tokens in self.special_tokens: 239 | return self.special_tokens[tokens] 240 | else: 241 | return self.encoder.get(tokens, 0) 242 | for token in tokens: 243 | if token in self.special_tokens: 244 | ids.append(self.special_tokens[token]) 245 | else: 246 | ids.append(self.encoder.get(token, 0)) 247 | if len(ids) > self.max_len: 248 | logger.warning( 249 | "Token indices sequence length is longer than the specified maximum " 250 | " sequence length for this OpenAI GPT model ({} > {}). Running this" 251 | " sequence through the model will result in indexing errors".format(len(ids), self.max_len) 252 | ) 253 | return ids 254 | 255 | def convert_ids_to_tokens(self, ids, skip_special_tokens=False): 256 | """Converts a sequence of ids in BPE tokens using the vocab.""" 257 | tokens = [] 258 | for i in ids: 259 | if i in self.special_tokens_decoder: 260 | if not skip_special_tokens: 261 | tokens.append(self.special_tokens_decoder[i]) 262 | else: 263 | tokens.append(self.decoder[i]) 264 | return tokens 265 | 266 | def encode(self, text): 267 | return self.convert_tokens_to_ids(self.tokenize(text)) 268 | 269 | def decode(self, ids, skip_special_tokens=False, clean_up_tokenization_spaces=True): 270 | """Converts a sequence of ids in a string.""" 271 | tokens = self.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens) 272 | out_string = ''.join(tokens).replace('', ' ').strip() 273 | if clean_up_tokenization_spaces: 274 | out_string = out_string.replace('', '') 275 | out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ',').replace(' ,', ',' 276 | ).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't" 277 | ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re") 278 | return out_string 279 | 280 | def save_vocabulary(self, vocab_path): 281 | """Save the tokenizer vocabulary and merge files to a directory.""" 282 | if not os.path.isdir(vocab_path): 283 | logger.error("Vocabulary path ({}) should be a directory".format(vocab_path)) 284 | return 285 | vocab_file = os.path.join(vocab_path, VOCAB_NAME) 286 | merge_file = os.path.join(vocab_path, MERGES_NAME) 287 | special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME) 288 | 289 | with open(vocab_file, 'w', encoding='utf-8') as f: 290 | f.write(json.dumps(self.encoder, ensure_ascii=False)) 291 | 292 | index = 0 293 | with open(merge_file, "w", encoding="utf-8") as writer: 294 | writer.write(u'#version: 0.2\n') 295 | for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): 296 | if index != token_index: 297 | logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive." 298 | " Please check that the tokenizer is not corrupted!".format(merge_file)) 299 | index = token_index 300 | writer.write(' '.join(bpe_tokens) + u'\n') 301 | index += 1 302 | 303 | index = len(self.encoder) 304 | with open(special_tokens_file, 'w', encoding='utf-8') as writer: 305 | for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]): 306 | if index != token_index: 307 | logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive." 308 | " Please check that the tokenizer is not corrupted!".format(special_tokens_file)) 309 | index = token_index 310 | writer.write(token + u'\n') 311 | index += 1 312 | 313 | return vocab_file, merge_file, special_tokens_file 314 | -------------------------------------------------------------------------------- /COMFUSE/pytorch_pretrained/modeling_transfo_xl_utilities.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ Utilities for PyTorch Transformer XL model. 17 | Directly adapted from https://github.com/kimiyoung/transformer-xl. 18 | """ 19 | 20 | from collections import defaultdict 21 | 22 | import numpy as np 23 | 24 | import torch 25 | import torch.nn as nn 26 | import torch.nn.functional as F 27 | 28 | # CUDA_MAJOR = int(torch.version.cuda.split('.')[0]) 29 | # CUDA_MINOR = int(torch.version.cuda.split('.')[1]) 30 | 31 | class ProjectedAdaptiveLogSoftmax(nn.Module): 32 | def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, 33 | keep_order=False): 34 | super(ProjectedAdaptiveLogSoftmax, self).__init__() 35 | 36 | self.n_token = n_token 37 | self.d_embed = d_embed 38 | self.d_proj = d_proj 39 | 40 | self.cutoffs = cutoffs + [n_token] 41 | self.cutoff_ends = [0] + self.cutoffs 42 | self.div_val = div_val 43 | 44 | self.shortlist_size = self.cutoffs[0] 45 | self.n_clusters = len(self.cutoffs) - 1 46 | self.head_size = self.shortlist_size + self.n_clusters 47 | 48 | if self.n_clusters > 0: 49 | self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.d_embed)) 50 | self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters)) 51 | 52 | self.out_layers = nn.ModuleList() 53 | self.out_projs = nn.ParameterList() 54 | 55 | if div_val == 1: 56 | for i in range(len(self.cutoffs)): 57 | if d_proj != d_embed: 58 | self.out_projs.append( 59 | nn.Parameter(torch.Tensor(d_proj, d_embed)) 60 | ) 61 | else: 62 | self.out_projs.append(None) 63 | 64 | self.out_layers.append(nn.Linear(d_embed, n_token)) 65 | else: 66 | for i in range(len(self.cutoffs)): 67 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1] 68 | d_emb_i = d_embed // (div_val ** i) 69 | 70 | self.out_projs.append( 71 | nn.Parameter(torch.Tensor(d_proj, d_emb_i)) 72 | ) 73 | 74 | self.out_layers.append(nn.Linear(d_emb_i, r_idx-l_idx)) 75 | 76 | self.keep_order = keep_order 77 | 78 | def _compute_logit(self, hidden, weight, bias, proj): 79 | if proj is None: 80 | logit = F.linear(hidden, weight, bias=bias) 81 | else: 82 | # if CUDA_MAJOR <= 9 and CUDA_MINOR <= 1: 83 | proj_hid = F.linear(hidden, proj.t().contiguous()) 84 | logit = F.linear(proj_hid, weight, bias=bias) 85 | # else: 86 | # logit = torch.einsum('bd,de,ev->bv', (hidden, proj, weight.t())) 87 | # if bias is not None: 88 | # logit = logit + bias 89 | 90 | return logit 91 | 92 | def forward(self, hidden, target=None, keep_order=False): 93 | ''' 94 | Params: 95 | hidden :: [len*bsz x d_proj] 96 | target :: [len*bsz] 97 | Return: 98 | if target is None: 99 | out :: [len*bsz] Negative log likelihood 100 | else: 101 | out :: [len*bsz x n_tokens] log probabilities of tokens over the vocabulary 102 | We could replace this implementation by the native PyTorch one 103 | if their's had an option to set bias on all clusters in the native one. 104 | here: https://github.com/pytorch/pytorch/blob/dbe6a7a9ff1a364a8706bf5df58a1ca96d2fd9da/torch/nn/modules/adaptive.py#L138 105 | ''' 106 | 107 | if target is not None: 108 | target = target.view(-1) 109 | if hidden.size(0) != target.size(0): 110 | raise RuntimeError('Input and target should have the same size ' 111 | 'in the batch dimension.') 112 | 113 | if self.n_clusters == 0: 114 | logit = self._compute_logit(hidden, self.out_layers[0].weight, 115 | self.out_layers[0].bias, self.out_projs[0]) 116 | if target is not None: 117 | output = -F.log_softmax(logit, dim=-1) \ 118 | .gather(1, target.unsqueeze(1)).squeeze(1) 119 | else: 120 | output = F.log_softmax(logit, dim=-1) 121 | else: 122 | # construct weights and biases 123 | weights, biases = [], [] 124 | for i in range(len(self.cutoffs)): 125 | if self.div_val == 1: 126 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] 127 | weight_i = self.out_layers[0].weight[l_idx:r_idx] 128 | bias_i = self.out_layers[0].bias[l_idx:r_idx] 129 | else: 130 | weight_i = self.out_layers[i].weight 131 | bias_i = self.out_layers[i].bias 132 | 133 | if i == 0: 134 | weight_i = torch.cat( 135 | [weight_i, self.cluster_weight], dim=0) 136 | bias_i = torch.cat( 137 | [bias_i, self.cluster_bias], dim=0) 138 | 139 | weights.append(weight_i) 140 | biases.append(bias_i) 141 | 142 | head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0] 143 | 144 | head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj) 145 | head_logprob = F.log_softmax(head_logit, dim=1) 146 | 147 | if target is None: 148 | out = hidden.new_empty((head_logit.size(0), self.n_token)) 149 | else: 150 | out = torch.zeros_like(target, dtype=hidden.dtype, device=hidden.device) 151 | 152 | offset = 0 153 | cutoff_values = [0] + self.cutoffs 154 | for i in range(len(cutoff_values) - 1): 155 | l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1] 156 | 157 | if target is not None: 158 | mask_i = (target >= l_idx) & (target < r_idx) 159 | indices_i = mask_i.nonzero().squeeze() 160 | 161 | if indices_i.numel() == 0: 162 | continue 163 | 164 | target_i = target.index_select(0, indices_i) - l_idx 165 | head_logprob_i = head_logprob.index_select(0, indices_i) 166 | hidden_i = hidden.index_select(0, indices_i) 167 | else: 168 | hidden_i = hidden 169 | 170 | if i == 0: 171 | if target is not None: 172 | logprob_i = head_logprob_i.gather(1, target_i[:, None]).squeeze(1) 173 | else: 174 | out[:, :self.cutoffs[0]] = head_logprob[:, :self.cutoffs[0]] 175 | else: 176 | weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i] 177 | 178 | tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i) 179 | tail_logprob_i = F.log_softmax(tail_logit_i, dim=1) 180 | cluster_prob_idx = self.cutoffs[0] + i - 1 # No probability for the head cluster 181 | if target is not None: 182 | logprob_i = head_logprob_i[:, cluster_prob_idx] \ 183 | + tail_logprob_i.gather(1, target_i[:, None]).squeeze(1) 184 | else: 185 | logprob_i = head_logprob[:, cluster_prob_idx, None] + tail_logprob_i 186 | out[:, l_idx:r_idx] = logprob_i 187 | 188 | if target is not None: 189 | if (hasattr(self, 'keep_order') and self.keep_order) or keep_order: 190 | out.index_copy_(0, indices_i, -logprob_i) 191 | else: 192 | out[offset:offset+logprob_i.size(0)].copy_(-logprob_i) 193 | offset += logprob_i.size(0) 194 | 195 | return out 196 | 197 | 198 | def log_prob(self, hidden): 199 | r""" Computes log probabilities for all :math:`n\_classes` 200 | From: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/adaptive.py 201 | Args: 202 | hidden (Tensor): a minibatch of examples 203 | Returns: 204 | log-probabilities of for each class :math:`c` 205 | in range :math:`0 <= c <= n\_classes`, where :math:`n\_classes` is a 206 | parameter passed to ``AdaptiveLogSoftmaxWithLoss`` constructor. 207 | Shape: 208 | - Input: :math:`(N, in\_features)` 209 | - Output: :math:`(N, n\_classes)` 210 | """ 211 | if self.n_clusters == 0: 212 | logit = self._compute_logit(hidden, self.out_layers[0].weight, 213 | self.out_layers[0].bias, self.out_projs[0]) 214 | return F.log_softmax(logit, dim=-1) 215 | else: 216 | # construct weights and biases 217 | weights, biases = [], [] 218 | for i in range(len(self.cutoffs)): 219 | if self.div_val == 1: 220 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] 221 | weight_i = self.out_layers[0].weight[l_idx:r_idx] 222 | bias_i = self.out_layers[0].bias[l_idx:r_idx] 223 | else: 224 | weight_i = self.out_layers[i].weight 225 | bias_i = self.out_layers[i].bias 226 | 227 | if i == 0: 228 | weight_i = torch.cat( 229 | [weight_i, self.cluster_weight], dim=0) 230 | bias_i = torch.cat( 231 | [bias_i, self.cluster_bias], dim=0) 232 | 233 | weights.append(weight_i) 234 | biases.append(bias_i) 235 | 236 | head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0] 237 | head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj) 238 | 239 | out = hidden.new_empty((head_logit.size(0), self.n_token)) 240 | head_logprob = F.log_softmax(head_logit, dim=1) 241 | 242 | cutoff_values = [0] + self.cutoffs 243 | for i in range(len(cutoff_values) - 1): 244 | start_idx, stop_idx = cutoff_values[i], cutoff_values[i + 1] 245 | 246 | if i == 0: 247 | out[:, :self.cutoffs[0]] = head_logprob[:, :self.cutoffs[0]] 248 | else: 249 | weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i] 250 | 251 | tail_logit_i = self._compute_logit(hidden, weight_i, bias_i, proj_i) 252 | tail_logprob_i = F.log_softmax(tail_logit_i, dim=1) 253 | 254 | logprob_i = head_logprob[:, -i] + tail_logprob_i 255 | out[:, start_idx, stop_idx] = logprob_i 256 | 257 | return out 258 | 259 | 260 | class LogUniformSampler(object): 261 | def __init__(self, range_max, n_sample): 262 | """ 263 | Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py 264 | `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)` 265 | 266 | expected count can be approximated by 1 - (1 - p)^n 267 | and we use a numerically stable version -expm1(num_tries * log1p(-p)) 268 | 269 | Our implementation fixes num_tries at 2 * n_sample, and the actual #samples will vary from run to run 270 | """ 271 | with torch.no_grad(): 272 | self.range_max = range_max 273 | log_indices = torch.arange(1., range_max+2., 1.).log_() 274 | self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1] 275 | # print('P', self.dist.numpy().tolist()[-30:]) 276 | 277 | self.log_q = (- (-self.dist.double().log1p_() * 2 * n_sample).expm1_()).log_().float() 278 | 279 | self.n_sample = n_sample 280 | 281 | def sample(self, labels): 282 | """ 283 | labels: [b1, b2] 284 | Return 285 | true_log_probs: [b1, b2] 286 | samp_log_probs: [n_sample] 287 | neg_samples: [n_sample] 288 | """ 289 | 290 | # neg_samples = torch.empty(0).long() 291 | n_sample = self.n_sample 292 | n_tries = 2 * n_sample 293 | 294 | with torch.no_grad(): 295 | neg_samples = torch.multinomial(self.dist, n_tries, replacement=True).unique() 296 | device = labels.device 297 | neg_samples = neg_samples.to(device) 298 | true_log_probs = self.log_q[labels].to(device) 299 | samp_log_probs = self.log_q[neg_samples].to(device) 300 | return true_log_probs, samp_log_probs, neg_samples 301 | 302 | def sample_logits(embedding, bias, labels, inputs, sampler): 303 | """ 304 | embedding: an nn.Embedding layer 305 | bias: [n_vocab] 306 | labels: [b1, b2] 307 | inputs: [b1, b2, n_emb] 308 | sampler: you may use a LogUniformSampler 309 | Return 310 | logits: [b1, b2, 1 + n_sample] 311 | """ 312 | true_log_probs, samp_log_probs, neg_samples = sampler.sample(labels) 313 | n_sample = neg_samples.size(0) 314 | b1, b2 = labels.size(0), labels.size(1) 315 | all_ids = torch.cat([labels.view(-1), neg_samples]) 316 | all_w = embedding(all_ids) 317 | true_w = all_w[: -n_sample].view(b1, b2, -1) 318 | sample_w = all_w[- n_sample:].view(n_sample, -1) 319 | 320 | all_b = bias[all_ids] 321 | true_b = all_b[: -n_sample].view(b1, b2) 322 | sample_b = all_b[- n_sample:] 323 | 324 | hit = (labels[:, :, None] == neg_samples).detach() 325 | 326 | true_logits = torch.einsum('ijk,ijk->ij', 327 | [true_w, inputs]) + true_b - true_log_probs 328 | sample_logits = torch.einsum('lk,ijk->ijl', 329 | [sample_w, inputs]) + sample_b - samp_log_probs 330 | sample_logits.masked_fill_(hit, -1e30) 331 | logits = torch.cat([true_logits[:, :, None], sample_logits], -1) 332 | 333 | return logits 334 | 335 | 336 | # class LogUniformSampler(object): 337 | # def __init__(self, range_max, unique=False): 338 | # """ 339 | # Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py 340 | # `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)` 341 | # """ 342 | # self.range_max = range_max 343 | # log_indices = torch.arange(1., range_max+2., 1.).log_() 344 | # self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1] 345 | 346 | # self.unique = unique 347 | 348 | # if self.unique: 349 | # self.exclude_mask = torch.ByteTensor(range_max).fill_(0) 350 | 351 | # def sample(self, n_sample, labels): 352 | # pos_sample, new_labels = labels.unique(return_inverse=True) 353 | # n_pos_sample = pos_sample.size(0) 354 | # n_neg_sample = n_sample - n_pos_sample 355 | 356 | # if self.unique: 357 | # self.exclude_mask.index_fill_(0, pos_sample, 1) 358 | # sample_dist = self.dist.clone().masked_fill_(self.exclude_mask, 0) 359 | # self.exclude_mask.index_fill_(0, pos_sample, 0) 360 | # else: 361 | # sample_dist = self.dist 362 | 363 | # neg_sample = torch.multinomial(sample_dist, n_neg_sample) 364 | 365 | # sample = torch.cat([pos_sample, neg_sample]) 366 | # sample_prob = self.dist[sample] 367 | 368 | # return new_labels, sample, sample_prob 369 | 370 | 371 | if __name__ == '__main__': 372 | S, B = 3, 4 373 | n_vocab = 10000 374 | n_sample = 5 375 | H = 32 376 | 377 | labels = torch.LongTensor(S, B).random_(0, n_vocab) 378 | 379 | # sampler = LogUniformSampler(n_vocab, unique=False) 380 | # new_labels, sample, sample_prob = sampler.sample(n_sample, labels) 381 | 382 | sampler = LogUniformSampler(n_vocab, n_sample)#, unique=True) 383 | # true_probs, samp_probs, neg_samples = sampler.sample(n_sample, labels) 384 | 385 | # print('true_probs', true_probs.numpy().tolist()) 386 | # print('samp_probs', samp_probs.numpy().tolist()) 387 | # print('neg_samples', neg_samples.numpy().tolist()) 388 | 389 | # print('sum', torch.sum(sampler.dist).item()) 390 | 391 | # assert torch.all(torch.sort(sample.unique())[0].eq(torch.sort(sample)[0])).item() 392 | 393 | embedding = nn.Embedding(n_vocab, H) 394 | bias = torch.zeros(n_vocab) 395 | inputs = torch.Tensor(S, B, H).normal_() 396 | 397 | logits, out_labels = sample_logits(embedding, bias, labels, inputs, sampler, n_sample) 398 | print('logits', logits.detach().numpy().tolist()) 399 | print('logits shape', logits.size()) 400 | print('out_labels', out_labels.detach().numpy().tolist()) 401 | print('out_labels shape', out_labels.size()) 402 | 403 | -------------------------------------------------------------------------------- /COMFUSE/pytorch_pretrained/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import, division, print_function, unicode_literals 18 | 19 | import collections 20 | import logging 21 | import os 22 | import unicodedata 23 | from io import open 24 | 25 | from .file_utils import cached_path 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | PRETRAINED_VOCAB_ARCHIVE_MAP = { 30 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", 31 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", 32 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", 33 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", 34 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", 35 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", 36 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", 37 | } 38 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { 39 | 'bert-base-uncased': 512, 40 | 'bert-large-uncased': 512, 41 | 'bert-base-cased': 512, 42 | 'bert-large-cased': 512, 43 | 'bert-base-multilingual-uncased': 512, 44 | 'bert-base-multilingual-cased': 512, 45 | 'bert-base-chinese': 512, 46 | } 47 | VOCAB_NAME = 'vocab.txt' 48 | 49 | 50 | def load_vocab(vocab_file): 51 | """Loads a vocabulary file into a dictionary.""" 52 | vocab = collections.OrderedDict() 53 | index = 0 54 | with open(vocab_file, "r", encoding="utf-8") as reader: 55 | while True: 56 | token = reader.readline() 57 | if not token: 58 | break 59 | token = token.strip() 60 | vocab[token] = index 61 | index += 1 62 | return vocab 63 | 64 | 65 | def whitespace_tokenize(text): 66 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 67 | text = text.strip() 68 | if not text: 69 | return [] 70 | tokens = text.split() 71 | return tokens 72 | 73 | 74 | class BertTokenizer(object): 75 | """Runs end-to-end tokenization: punctuation splitting + wordpiece""" 76 | 77 | def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True, 78 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): 79 | """Constructs a BertTokenizer. 80 | 81 | Args: 82 | vocab_file: Path to a one-wordpiece-per-line vocabulary file 83 | do_lower_case: Whether to lower case the input 84 | Only has an effect when do_wordpiece_only=False 85 | do_basic_tokenize: Whether to do basic tokenization before wordpiece. 86 | max_len: An artificial maximum length to truncate tokenized sequences to; 87 | Effective maximum length is always the minimum of this 88 | value (if specified) and the underlying BERT model's 89 | sequence length. 90 | never_split: List of tokens which will never be split during tokenization. 91 | Only has an effect when do_wordpiece_only=False 92 | """ 93 | if not os.path.isfile(vocab_file): 94 | raise ValueError( 95 | "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " 96 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)) 97 | self.vocab = load_vocab(vocab_file) 98 | self.ids_to_tokens = collections.OrderedDict( 99 | [(ids, tok) for tok, ids in self.vocab.items()]) 100 | self.do_basic_tokenize = do_basic_tokenize 101 | if do_basic_tokenize: 102 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, 103 | never_split=never_split) 104 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 105 | self.max_len = max_len if max_len is not None else int(1e12) 106 | 107 | def tokenize(self, text): 108 | split_tokens = [] 109 | if self.do_basic_tokenize: 110 | for token in self.basic_tokenizer.tokenize(text): 111 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 112 | split_tokens.append(sub_token) 113 | else: 114 | split_tokens = self.wordpiece_tokenizer.tokenize(text) 115 | return split_tokens 116 | 117 | def convert_tokens_to_ids(self, tokens): 118 | """Converts a sequence of tokens into ids using the vocab.""" 119 | ids = [] 120 | for token in tokens: 121 | ids.append(self.vocab[token]) 122 | if len(ids) > self.max_len: 123 | logger.warning( 124 | "Token indices sequence length is longer than the specified maximum " 125 | " sequence length for this BERT model ({} > {}). Running this" 126 | " sequence through BERT will result in indexing errors".format(len(ids), self.max_len) 127 | ) 128 | return ids 129 | 130 | def convert_ids_to_tokens(self, ids): 131 | """Converts a sequence of ids in wordpiece tokens using the vocab.""" 132 | tokens = [] 133 | for i in ids: 134 | tokens.append(self.ids_to_tokens[i]) 135 | return tokens 136 | 137 | def save_vocabulary(self, vocab_path): 138 | """Save the tokenizer vocabulary to a directory or file.""" 139 | index = 0 140 | if os.path.isdir(vocab_path): 141 | vocab_file = os.path.join(vocab_path, VOCAB_NAME) 142 | with open(vocab_file, "w", encoding="utf-8") as writer: 143 | for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): 144 | if index != token_index: 145 | logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive." 146 | " Please check that the vocabulary is not corrupted!".format(vocab_file)) 147 | index = token_index 148 | writer.write(token + u'\n') 149 | index += 1 150 | return vocab_file 151 | 152 | @classmethod 153 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): 154 | """ 155 | Instantiate a PreTrainedBertModel from a pre-trained model file. 156 | Download and cache the pre-trained model file if needed. 157 | """ 158 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: 159 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] 160 | if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True): 161 | logger.warning("The pre-trained model you are loading is a cased model but you have not set " 162 | "`do_lower_case` to False. We are setting `do_lower_case=False` for you but " 163 | "you may want to check this behavior.") 164 | kwargs['do_lower_case'] = False 165 | elif '-cased' not in pretrained_model_name_or_path and not kwargs.get('do_lower_case', True): 166 | logger.warning("The pre-trained model you are loading is an uncased model but you have set " 167 | "`do_lower_case` to False. We are setting `do_lower_case=True` for you " 168 | "but you may want to check this behavior.") 169 | kwargs['do_lower_case'] = True 170 | else: 171 | vocab_file = pretrained_model_name_or_path 172 | if os.path.isdir(vocab_file): 173 | vocab_file = os.path.join(vocab_file, VOCAB_NAME) 174 | # redirect to the cache, if necessary 175 | try: 176 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) 177 | except EnvironmentError: 178 | logger.error( 179 | "Model name '{}' was not found in model name list ({}). " 180 | "We assumed '{}' was a path or url but couldn't find any file " 181 | "associated to this path or url.".format( 182 | pretrained_model_name_or_path, 183 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), 184 | vocab_file)) 185 | return None 186 | if resolved_vocab_file == vocab_file: 187 | logger.info("loading vocabulary file {}".format(vocab_file)) 188 | else: 189 | logger.info("loading vocabulary file {} from cache at {}".format( 190 | vocab_file, resolved_vocab_file)) 191 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: 192 | # if we're using a pretrained model, ensure the tokenizer wont index sequences longer 193 | # than the number of positional embeddings 194 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] 195 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) 196 | # Instantiate tokenizer. 197 | tokenizer = cls(resolved_vocab_file, *inputs, **kwargs) 198 | return tokenizer 199 | 200 | 201 | class BasicTokenizer(object): 202 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 203 | 204 | def __init__(self, 205 | do_lower_case=True, 206 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): 207 | """Constructs a BasicTokenizer. 208 | 209 | Args: 210 | do_lower_case: Whether to lower case the input. 211 | """ 212 | self.do_lower_case = do_lower_case 213 | self.never_split = never_split 214 | 215 | def tokenize(self, text): 216 | """Tokenizes a piece of text.""" 217 | text = self._clean_text(text) 218 | # This was added on November 1st, 2018 for the multilingual and Chinese 219 | # models. This is also applied to the English models now, but it doesn't 220 | # matter since the English models were not trained on any Chinese data 221 | # and generally don't have any Chinese data in them (there are Chinese 222 | # characters in the vocabulary because Wikipedia does have some Chinese 223 | # words in the English Wikipedia.). 224 | text = self._tokenize_chinese_chars(text) 225 | orig_tokens = whitespace_tokenize(text) 226 | split_tokens = [] 227 | for token in orig_tokens: 228 | if self.do_lower_case and token not in self.never_split: 229 | token = token.lower() 230 | token = self._run_strip_accents(token) 231 | split_tokens.extend(self._run_split_on_punc(token)) 232 | 233 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 234 | return output_tokens 235 | 236 | def _run_strip_accents(self, text): 237 | """Strips accents from a piece of text.""" 238 | text = unicodedata.normalize("NFD", text) 239 | output = [] 240 | for char in text: 241 | cat = unicodedata.category(char) 242 | if cat == "Mn": 243 | continue 244 | output.append(char) 245 | return "".join(output) 246 | 247 | def _run_split_on_punc(self, text): 248 | """Splits punctuation on a piece of text.""" 249 | if text in self.never_split: 250 | return [text] 251 | chars = list(text) 252 | i = 0 253 | start_new_word = True 254 | output = [] 255 | while i < len(chars): 256 | char = chars[i] 257 | if _is_punctuation(char): 258 | output.append([char]) 259 | start_new_word = True 260 | else: 261 | if start_new_word: 262 | output.append([]) 263 | start_new_word = False 264 | output[-1].append(char) 265 | i += 1 266 | 267 | return ["".join(x) for x in output] 268 | 269 | def _tokenize_chinese_chars(self, text): 270 | """Adds whitespace around any CJK character.""" 271 | output = [] 272 | for char in text: 273 | cp = ord(char) 274 | if self._is_chinese_char(cp): 275 | output.append(" ") 276 | output.append(char) 277 | output.append(" ") 278 | else: 279 | output.append(char) 280 | return "".join(output) 281 | 282 | def _is_chinese_char(self, cp): 283 | """Checks whether CP is the codepoint of a CJK character.""" 284 | # This defines a "chinese character" as anything in the CJK Unicode block: 285 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 286 | # 287 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 288 | # despite its name. The modern Korean Hangul alphabet is a different block, 289 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 290 | # space-separated words, so they are not treated specially and handled 291 | # like the all of the other languages. 292 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 293 | (cp >= 0x3400 and cp <= 0x4DBF) or # 294 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 295 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 296 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 297 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 298 | (cp >= 0xF900 and cp <= 0xFAFF) or # 299 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 300 | return True 301 | 302 | return False 303 | 304 | def _clean_text(self, text): 305 | """Performs invalid character removal and whitespace cleanup on text.""" 306 | output = [] 307 | for char in text: 308 | cp = ord(char) 309 | if cp == 0 or cp == 0xfffd or _is_control(char): 310 | continue 311 | if _is_whitespace(char): 312 | output.append(" ") 313 | else: 314 | output.append(char) 315 | return "".join(output) 316 | 317 | 318 | class WordpieceTokenizer(object): 319 | """Runs WordPiece tokenization.""" 320 | 321 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): 322 | self.vocab = vocab 323 | self.unk_token = unk_token 324 | self.max_input_chars_per_word = max_input_chars_per_word 325 | 326 | def tokenize(self, text): 327 | """Tokenizes a piece of text into its word pieces. 328 | 329 | This uses a greedy longest-match-first algorithm to perform tokenization 330 | using the given vocabulary. 331 | 332 | For example: 333 | input = "unaffable" 334 | output = ["un", "##aff", "##able"] 335 | 336 | Args: 337 | text: A single token or whitespace separated tokens. This should have 338 | already been passed through `BasicTokenizer`. 339 | 340 | Returns: 341 | A list of wordpiece tokens. 342 | """ 343 | 344 | output_tokens = [] 345 | for token in whitespace_tokenize(text): 346 | chars = list(token) 347 | if len(chars) > self.max_input_chars_per_word: 348 | output_tokens.append(self.unk_token) 349 | continue 350 | 351 | is_bad = False 352 | start = 0 353 | sub_tokens = [] 354 | while start < len(chars): 355 | end = len(chars) 356 | cur_substr = None 357 | while start < end: 358 | substr = "".join(chars[start:end]) 359 | if start > 0: 360 | substr = "##" + substr 361 | if substr in self.vocab: 362 | cur_substr = substr 363 | break 364 | end -= 1 365 | if cur_substr is None: 366 | is_bad = True 367 | break 368 | sub_tokens.append(cur_substr) 369 | start = end 370 | 371 | if is_bad: 372 | output_tokens.append(self.unk_token) 373 | else: 374 | output_tokens.extend(sub_tokens) 375 | return output_tokens 376 | 377 | 378 | def _is_whitespace(char): 379 | """Checks whether `chars` is a whitespace character.""" 380 | # \t, \n, and \r are technically contorl characters but we treat them 381 | # as whitespace since they are generally considered as such. 382 | if char == " " or char == "\t" or char == "\n" or char == "\r": 383 | return True 384 | cat = unicodedata.category(char) 385 | if cat == "Zs": 386 | return True 387 | return False 388 | 389 | 390 | def _is_control(char): 391 | """Checks whether `chars` is a control character.""" 392 | # These are technically control characters but we count them as whitespace 393 | # characters. 394 | if char == "\t" or char == "\n" or char == "\r": 395 | return False 396 | cat = unicodedata.category(char) 397 | if cat.startswith("C"): 398 | return True 399 | return False 400 | 401 | 402 | def _is_punctuation(char): 403 | """Checks whether `chars` is a punctuation character.""" 404 | cp = ord(char) 405 | # We treat all non-letter/number ASCII as punctuation. 406 | # Characters such as "^", "$", and "`" are not in the Unicode 407 | # Punctuation class but we treat them as punctuation anyways, for 408 | # consistency. 409 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 410 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 411 | return True 412 | cat = unicodedata.category(char) 413 | if cat.startswith("P"): 414 | return True 415 | return False 416 | -------------------------------------------------------------------------------- /COMFUSE/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import random 4 | import argparse 5 | import numpy as np 6 | import datetime 7 | from torch.utils.data import DataLoader 8 | from selfdropout import rnddrop_2 9 | from config import make_lstm_multi_task_config, make_bilstm_multi_task_config, make_bilstm_multi_layer_multi_task_config 10 | from agent_base import MetaLSTMMultiTask 11 | from rumor_dataset import CategoriesSampler, wb_rumor_fsl_dataset 12 | 13 | 14 | def train(args): 15 | run_id = datetime.datetime.now().strftime('%m-%d-%H-%M') 16 | logdir = os.path.join("models", run_id) 17 | 18 | # avoid overriding 19 | os.makedirs(logdir, exist_ok=False) 20 | print("Model dir {}".format(logdir)) 21 | 22 | # ============= Training ============= 23 | 24 | # Setup device 25 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 26 | 27 | # print(split_data_files) 28 | b_size = args.batch_size 29 | n_worker = 4 if torch.cuda.is_available() else 1 30 | task_num = args.batch_size 31 | n_topic = args.topic 32 | n_way = args.way 33 | 34 | # how many episodes to train/test 35 | total_epoch = 20 36 | num_batch_train = 200 37 | num_batch_test = 100 38 | 39 | # datasets 40 | print(n_way * n_topic, args.shot + args.query) 41 | # return: feats, lb, lns, feats_com, lns_com 42 | train_set = wb_rumor_fsl_dataset('train', args.split_number, args.model, args.featstype, args.data_path, 43 | args.pad_size, args.com_pad_size, args.bert_path) 44 | train_sampler = CategoriesSampler(train_set.all_labels, 45 | num_batch_train * b_size, n_way * n_topic, args.shot + args.query) 46 | # print(train_sampler) 47 | train_batchsampler = torch.utils.data.BatchSampler(train_sampler, task_num, drop_last=True) 48 | 49 | val_set = wb_rumor_fsl_dataset('dev', args.split_number, args.model, args.featstype, args.data_path, args.pad_size, args.com_pad_size, 50 | args.bert_path) 51 | val_sampler = CategoriesSampler(val_set.all_labels, num_batch_test, n_way * n_topic, args.shot + args.query) 52 | val_batchsampler = torch.utils.data.BatchSampler(val_sampler, 1, drop_last=True) 53 | 54 | # n_worker = 1 # --------------------- remove this after debugging 55 | # trainloader = data.DataLoader(train_set, batch_sampler=train_sampler, num_workers=n_worker, pin_memory=True) 56 | # valloader = data.DataLoader(val_set, batch_sampler=val_sampler, num_workers=n_worker, pin_memory=True) 57 | 58 | trainloader = DataLoader(train_set, batch_sampler=train_batchsampler, num_workers=n_worker, pin_memory=True) 59 | valloader = DataLoader(val_set, batch_sampler=val_batchsampler, num_workers=n_worker, pin_memory=True) 60 | 61 | print('fix val set for all epochs') 62 | val_sampler.mode = 'probe' 63 | val_set.mode = 'dummy' 64 | # make one pass of dataset, and internally keep the indices 65 | # see rumor_dataset.py -> CategoriesSampler -> __iter__() -> probe 66 | for x in valloader: 67 | pass 68 | # val_sampler will use this fixed set to evaluate model 69 | val_sampler.mode = 'fix' 70 | val_set.mode = 'norm' 71 | print('fixed val set has %d batches' % (len(val_sampler.fixed_batches),)) 72 | 73 | ######################################## 74 | # Setup Model in agent_base.py 75 | ######################################## 76 | n_way = args.way 77 | k_shot = args.shot 78 | n_query = args.query 79 | 80 | if args.gru_type == "gru": 81 | config = make_lstm_multi_task_config(args.way, args.hidden_size, n_topic) # input dim=784 82 | elif args.gru_type == "gru_bi": 83 | config = make_bilstm_multi_task_config(args.way, args.hidden_size, n_topic) 84 | elif args.gru_type == "gru_bi_mult": 85 | assert args.gru_num_layer > 1, "GRU_BI_MULT is a multilayer (>1) GRU, set num_layer to 2" 86 | config = make_bilstm_multi_layer_multi_task_config(args.way, args.hidden_size, args.gru_num_layer, n_topic) 87 | else: 88 | raise Exception("Not Implemented Error") 89 | 90 | model = MetaLSTMMultiTask(args, config) 91 | model = model.to(device) 92 | 93 | # check dropout arguments 94 | drop_type = args.droptype 95 | drop_rate = args.droprate 96 | assert drop_type in [0, 1, 2, 3], 'invalid dropout type' 97 | if drop_type > 0: 98 | assert 0 < drop_rate < 1, 'invalid dropout rate' 99 | 100 | ################## 101 | # resume training 102 | ################## 103 | 104 | if len(args.pretrain_dir) > 0: 105 | model_path = os.path.join(args.pretrain_dir, 'best_model.pt') 106 | print('--------------------------') 107 | print('Load pre-trained model %s' % (model_path,)) 108 | print('--------------------------') 109 | model.load_state_dict(torch.load(model_path), strict=True) 110 | 111 | # Generate the labels for train set of the episodes 112 | # label_shot = torch.arange(n_way).repeat(n_topic).repeat(k_shot) 113 | 114 | t1 = [] 115 | t2 = [] 116 | 117 | for i in range(n_topic): 118 | for j in range(n_way): 119 | t1.append(j) 120 | 121 | for i in range(n_topic): 122 | for j in range(n_way): 123 | t2.append(i) 124 | 125 | # print(t1, t2) 126 | 127 | label_shot = torch.tensor(t1).repeat(k_shot) # [(0,1),(0,1),(0,1)]... there are k_shots 128 | label_shot = label_shot.to(device).long() 129 | label_shot_topic = torch.tensor(t2).repeat(k_shot) # [(0,0), (1,1), (2,2)] .... repeat k_shots 130 | label_shot_topic = label_shot_topic.to(device).long() 131 | p = n_topic * n_way * k_shot 132 | 133 | label_query = torch.tensor(t1).repeat(n_query) 134 | label_query = label_query.to(device).long() 135 | label_query_topic = torch.tensor(t2).repeat(n_query) 136 | label_query_topic = label_query_topic.to(device).long() 137 | 138 | print(label_shot, label_shot_topic) 139 | print(label_query, label_query_topic) 140 | 141 | # print('label_shot size', label_shot.size()) 142 | # print('label_query size', label_query.size()) 143 | 144 | ##################### 145 | ### main function ### 146 | ##################### 147 | best_val_acc = 0.0 148 | best_val_epoch = -1 149 | for epoch in range(total_epoch): 150 | 151 | print('Epoch %d/%d' % (epoch, total_epoch)) 152 | 153 | acc_clients = [] 154 | acc_topic_clients = [] 155 | for ix, data in enumerate(trainloader): 156 | feat, raw_label, seq_len, feat_com, seq_len_com = data 157 | # print(feat.size()) 158 | if drop_type > 0: 159 | feat = rnddrop_2(feat, drop_rate, drop_type, False) 160 | feat_com = rnddrop_2(feat_com, drop_rate, drop_type, False) 161 | feat = feat.to(device) 162 | feat_com = feat_com.to(device) 163 | x_shot, x_qry = feat[:, :p, 0, ...], feat[:, p:, 0, ...] # split to adaptation data and meta-learning data 164 | x_com_shot, x_com_qry = feat_com[:, :p, 0, ...], feat_com[:, p:, 0, ...] # split to adaptation data and meta-learning data 165 | # print(p) 166 | # print("x_shot", x_shot.shape) 167 | # print("x_qry", x_qry.shape) 168 | # Generate the labels for test set of the episodes during meta-train updates 169 | y_shot = label_shot.detach() 170 | y_shot_topic = label_shot_topic.detach() 171 | y_qry = label_query.detach() 172 | y_qry_topic = label_query_topic.detach() 173 | # print(y_qry.size()) 174 | 175 | # split to get len of data and meta-learning data 176 | # print(seq_len.shape) 177 | l_spt = seq_len[:, :p] 178 | l_qry = seq_len[:, p:] 179 | l_com_spt = seq_len_com[:, :p] 180 | l_com_qry = seq_len_com[:, p:] 181 | # print(l_spt) 182 | # print(l_qry) 183 | # print("l_spt", l_spt.shape) 184 | # print("l_qry", l_qry.shape) 185 | 186 | accs, accs_topic = model(x_shot, x_com_shot, y_shot, y_shot_topic, l_spt, l_com_spt, x_qry, x_com_qry, y_qry, y_qry_topic, l_qry, l_com_qry) 187 | acc_clients.append(accs[-1]) 188 | acc_topic_clients.append(accs_topic[-1]) 189 | 190 | if ix % 20 == 0: 191 | print('Train step %d/%d' % (ix, len(trainloader)), 192 | ' training acc: %.4f, topic acc: %.4f' % (accs[-1], accs_topic[-1])) 193 | 194 | print('Training avg acc: %.4f, topic acc: %.4f' % (np.mean(acc_clients), np.mean(acc_topic_clients))) 195 | 196 | # evaluation 197 | if epoch % args.val_frequency == 1: 198 | all_accs = [] 199 | all_topic_accs = [] 200 | accs_all_test = [] 201 | 202 | # print('Test dataloader', len(db_test), len(db_test.dataset)) # 100,100 203 | 204 | for ix, data in enumerate(valloader): 205 | feat, raw_label, seq_len, feat_com, seq_len_com = data 206 | feat = feat.to(device) 207 | feat_com = feat_com.to(device) 208 | 209 | # in validation and testing, batch size is one 210 | assert feat.size(0) == 1 211 | assert feat_com.size(0) == 1 212 | # label = label.to(device) # ignore raw_label, see rumor_dataset.py, __getitem__ last line 213 | x_shot, x_qry = feat[0, :p, 0, ...], feat[0, p:, 0, ...] 214 | x_com_shot, x_com_qry = feat_com[0, :p, 0, ...], feat_com[0, p:, 0, ...] 215 | # Generate the labels for test set of the episodes during meta-train updates 216 | y_shot = label_shot.detach() 217 | y_shot_topic = label_shot_topic.detach() 218 | y_qry = label_query.detach() 219 | y_qry_topic = label_query_topic.detach() 220 | 221 | l_spt = seq_len[0, :p] 222 | l_qry = seq_len[0, p:] 223 | l_com_spt = seq_len_com[0, :p] 224 | l_com_qry = seq_len_com[0, p:] 225 | 226 | ########################### 227 | # finetuning on query set 228 | ########################### 229 | # print("l_spt",l_spt.shape) 230 | accs, accs_topic = model.finetuning(x_shot, x_com_shot, y_shot, y_shot_topic, l_spt, l_com_spt, x_qry, x_com_qry, y_qry, y_qry_topic, 231 | l_qry, l_com_qry) 232 | accs_all_test.append(accs) 233 | all_accs.append(accs[-1]) 234 | all_topic_accs.append(accs_topic[-1]) 235 | 236 | if ix % 40 == 0: 237 | print(' [Ep %d/%d] Test acc %.4f, topic acc %.4f' % 238 | (ix, len(valloader), np.mean(all_accs), np.mean(all_topic_accs)), 239 | np.array(accs_all_test).mean(axis=0).astype(np.float16)[::4]) 240 | 241 | avg_test_acc = np.mean(all_accs) 242 | avg_test_topic_acc = np.mean(all_topic_accs) 243 | print('Testing avg acc: %.4f, topic acc: %.4f' % (avg_test_acc, avg_test_topic_acc)) 244 | if best_val_acc < avg_test_acc: 245 | best_val_acc = avg_test_acc 246 | best_val_epoch = epoch 247 | print('Best testing acc: %.4f at epoch %d' % (best_val_acc, best_val_epoch)) 248 | 249 | # Save model every 5 epochs 250 | model_path = os.path.join(logdir, 'model_%02d.pt' % (epoch,)) 251 | torch.save(model.state_dict(), model_path) 252 | print('Save model to %s' % (model_path,)) 253 | 254 | # save best model 255 | if best_val_epoch == epoch: 256 | model_path = os.path.join(logdir, 'best_model.pt') 257 | torch.save(model.state_dict(), model_path) 258 | 259 | 260 | def test(args): 261 | model_dir = args.model_dir 262 | assert os.path.isdir(model_dir) 263 | model_path = os.path.join(model_dir, 'best_model.pt') 264 | assert os.path.isfile(model_path) 265 | print('Load model %s' % (model_path)) 266 | 267 | # Setup device 268 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 269 | 270 | # print(split_data_files) 271 | # b_size = args.batch_size 272 | n_worker = 4 if torch.cuda.is_available() else 1 273 | num_batch_test = 100 274 | n_way = args.way 275 | k_shot = args.shot 276 | n_query = args.query 277 | n_topic = args.topic 278 | 279 | # datasets 280 | test_set = wb_rumor_fsl_dataset('test', args.split_number, args.model, args.featstype, args.data_path, 281 | args.pad_size, args.com_pad_size, args.bert_path) 282 | test_sampler = CategoriesSampler(test_set.all_labels, num_batch_test, n_way * n_topic, args.shot + args.query) 283 | test_batchsampler = torch.utils.data.BatchSampler(test_sampler, 1, drop_last=True) 284 | testloader = DataLoader(test_set, batch_sampler=test_batchsampler, num_workers=n_worker, pin_memory=True) 285 | 286 | print('Test set should be fixed by setting the same seed.') 287 | 288 | ######################################## 289 | # Setup Model in agent_base.py 290 | ######################################## 291 | 292 | if args.gru_type == "gru": 293 | config = make_lstm_multi_task_config(args.way, args.hidden_size, n_topic) # input dim=784 294 | elif args.gru_type == "gru_bi": 295 | config = make_bilstm_multi_task_config(args.way, args.hidden_size, n_topic) 296 | elif args.gru_type == "gru_bi_mult": 297 | assert args.gru_num_layer > 1, "GRU_BI_MULT is a multilayer (>1) GRU, set num_layer to 2" 298 | config = make_bilstm_multi_layer_multi_task_config(args.way, args.hidden_size, args.gru_num_layer, n_topic) 299 | else: 300 | raise Exception("Not Implemented Error") 301 | 302 | model = MetaLSTMMultiTask(args, config) 303 | model = model.to(device) 304 | 305 | ################## 306 | # resume training 307 | ################## 308 | 309 | print('--------------------------') 310 | print('Load trained model %s' % (model_path,)) 311 | print('--------------------------') 312 | model.load_state_dict(torch.load(model_path), strict=True) 313 | 314 | # Generate the labels for train set of the episodes 315 | t1 = [] 316 | t2 = [] 317 | 318 | for i in range(n_topic): 319 | for j in range(n_way): 320 | t1.append(j) 321 | 322 | for i in range(n_topic): 323 | for j in range(n_way): 324 | t2.append(i) 325 | 326 | label_shot = torch.tensor(t1).repeat(k_shot) # [(0,1),(0,1),(0,1)]... there are k_shots 327 | label_shot = label_shot.to(device).long() 328 | label_shot_topic = torch.tensor(t2).repeat(k_shot) # [(0,0), (1,1), (2,2)] .... repeat k_shots 329 | label_shot_topic = label_shot_topic.to(device).long() 330 | p = n_topic * n_way * k_shot 331 | 332 | label_query = torch.tensor(t1).repeat(n_query) 333 | label_query = label_query.to(device).long() 334 | label_query_topic = torch.tensor(t2).repeat(n_query) 335 | label_query_topic = label_query_topic.to(device).long() 336 | 337 | print('label_shot size', label_shot.size()) 338 | print('label_query size', label_query.size()) 339 | 340 | ##################### 341 | ### main function ### 342 | ##################### 343 | 344 | all_accs = [] 345 | all_topic_accs = [] 346 | accs_all_test = [] 347 | 348 | # print('Test dataloader', len(db_test), len(db_test.dataset)) # 100,100 349 | 350 | for ix, data in enumerate(testloader): 351 | feat, raw_label, seq_len, feat_com, seq_len_com = data 352 | feat = feat.to(device) 353 | feat_com = feat_com.to(device) 354 | 355 | # in validation and testing, batch size is one 356 | assert feat.size(0) == 1 357 | # label = label.to(device) # ignore raw_label, see rumor_dataset.py, __getitem__ last line 358 | x_shot, x_qry = feat[0, :p, 0, ...], feat[0, p:, 0, ...] 359 | x_com_shot, x_com_qry = feat_com[0, :p, 0, ...], feat_com[0, p:, 0, ...] 360 | # x_shot, x_qry = feat[0, :p, ...], feat[0, p:, ...] 361 | # print("x_shot:", x_shot.shape) 362 | # Generate the labels for test set of the episodes during meta-train updates 363 | y_shot = label_shot.detach() 364 | y_shot_topic = label_shot_topic.detach() 365 | y_qry = label_query.detach() 366 | y_qry_topic = label_query_topic.detach() 367 | # print(y_qry.size()) 368 | 369 | l_spt = seq_len[0, :p] 370 | l_qry = seq_len[0, p:] 371 | l_com_spt = seq_len_com[0, :p] 372 | l_com_qry = seq_len_com[0, p:] 373 | 374 | ########################### 375 | # finetuning on query set 376 | ########################### 377 | acc, acc_topic = model.finetuning(x_shot, x_com_shot, y_shot, y_shot_topic, l_spt, l_com_spt, x_qry, x_com_qry, y_qry, y_qry_topic, l_qry, l_com_qry) 378 | accs_all_test.append(acc) 379 | all_accs.append(acc[-1]) 380 | all_topic_accs.append(acc_topic[-1]) 381 | 382 | if ix % 40 == 0: 383 | print(' [Ep %d/%d] Test acc %.4f, topic acc: %.4f' % 384 | (ix, len(testloader), np.mean(all_accs), np.mean(all_topic_accs)), 385 | np.array(accs_all_test).mean(axis=0).astype(np.float16)[::4]) 386 | 387 | avg_test_acc = np.mean(all_accs) 388 | avg_topic_test_acc = np.mean(all_topic_accs) 389 | 390 | print('Testing avg acc: %.4f, topic acc: %.4f' % (avg_test_acc, avg_topic_test_acc)) 391 | 392 | 393 | if __name__ == "__main__": 394 | 395 | parser = argparse.ArgumentParser(description="config") 396 | # about model 397 | parser.add_argument('--model', type=str, default="bert", help="bert") 398 | parser.add_argument('--dataset', type=str, default="weibo", choices="weibo|pheme") 399 | 400 | parser.add_argument('--gru_type', type=str, default="gru_bi_mult", choices=["gru", "gru_bi", "gru_bi_mult"], 401 | help="gru|gru_bidirection") 402 | parser.add_argument('--gru_num_layer', type=int, default=2, help="num of gru hidden layers") 403 | 404 | # about feature type 405 | parser.add_argument('--featstype', type=str, default="emb_outs", help="emb_outs") 406 | parser.add_argument('--droptype', type=int, default=1, help="0-nodrop|1-drop word|2-drop dim|3-drop both") 407 | parser.add_argument('--droprate', type=float, default=0.3, help="dropout rate") 408 | 409 | # about path 410 | parser.add_argument('--ph', type=int, default=0, choices=[0, 1], help='train|test') 411 | parser.add_argument('--pretrain_dir', type=str, default='', help='path of models') 412 | parser.add_argument('--is_seg', type=int, default=0, choices=[0, 1], help='classification|segmentation') 413 | parser.add_argument('--model_dir', type=str, default='', help='path of models') 414 | 415 | # about training 416 | # parser.add_argument("--config", type=str, default="configs/mrms_fsl.yml", help="Configuration file to use", ) 417 | parser.add_argument("--gpu", type=str, default="0", help="Used GPUs") 418 | parser.add_argument('--batch_size', type=int, default=2, help='batch size of tasks') 419 | parser.add_argument('--val_frequency', type=int, default=5, help="Validate every 50 episodes") 420 | parser.add_argument('--split_number', type=int, default=0, help='Cross-validation split number 0,1,2') 421 | parser.add_argument('--meta_lr', type=float, help='meta-level outer learning rate', default=1e-3) 422 | parser.add_argument('--update_lr', type=float, help='task-level inner update learning rate', default=0.01) 423 | parser.add_argument('--update_step', type=int, help='task-level inner update steps', default=20) 424 | parser.add_argument('--update_step_test', type=int, help='update steps for finetuning', default=30) 425 | parser.add_argument('--hidden_size', type=int, help='hidden size', default=128) 426 | 427 | # about task 428 | parser.add_argument('--way', type=int, default=2) 429 | parser.add_argument('--topic', type=int, default=3) # how many topics to sample at a same time as multi-tasking 430 | parser.add_argument('--shot', type=int, default=1) 431 | parser.add_argument('--query', type=int, default=9, help='number of query per class') 432 | 433 | args = parser.parse_args() 434 | 435 | # Set the gpu 436 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 437 | 438 | # assert args.way == 6 439 | 440 | # Setup seeds 441 | seed = 1337 442 | torch.manual_seed(seed) 443 | torch.cuda.manual_seed(seed) 444 | np.random.seed(seed) 445 | random.seed(seed) 446 | 447 | if args.dataset == 'weibo': 448 | args.data_path = './DataSet_pair_comments/' 449 | args.pad_size = 100 450 | args.com_pad_size = 32 451 | args.eta = 0.1 452 | assert args.way == 2 and args.topic == 3 453 | args.bert_path = './bert_pretrain' 454 | elif args.dataset == 'pheme': 455 | args.data_path = './Pheme_DataSet_Pair_comments/' 456 | args.pad_size = 48 457 | args.com_pad_size = 48 458 | args.eta = 0.1 459 | assert args.way == 2 and args.topic == 2 460 | args.bert_path = 'bert-base-uncased' 461 | args.update_step = 10 462 | args.update_step_test = 10 463 | else: 464 | assert 1 == 2 465 | 466 | if args.ph == 0: 467 | train(args) 468 | else: 469 | # need to provide existing model path 470 | assert len(args.model_dir) > 0 471 | test(args) 472 | -------------------------------------------------------------------------------- /COMFUSE/pytorch_pretrained/tokenization_transfo_xl.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ Tokenization classes for Transformer XL model. 17 | Adapted from https://github.com/kimiyoung/transformer-xl. 18 | """ 19 | from __future__ import (absolute_import, division, print_function, 20 | unicode_literals) 21 | 22 | import glob 23 | import logging 24 | import os 25 | import sys 26 | from collections import Counter, OrderedDict 27 | from io import open 28 | import unicodedata 29 | 30 | import torch 31 | import numpy as np 32 | 33 | from .file_utils import cached_path 34 | 35 | if sys.version_info[0] == 2: 36 | import cPickle as pickle 37 | else: 38 | import pickle 39 | 40 | 41 | logger = logging.getLogger(__name__) 42 | 43 | PRETRAINED_VOCAB_ARCHIVE_MAP = { 44 | 'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.bin", 45 | } 46 | VOCAB_NAME = 'vocab.bin' 47 | 48 | PRETRAINED_CORPUS_ARCHIVE_MAP = { 49 | 'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-corpus.bin", 50 | } 51 | CORPUS_NAME = 'corpus.bin' 52 | 53 | class TransfoXLTokenizer(object): 54 | """ 55 | Transformer-XL tokenizer adapted from Vocab class in https://github.com/kimiyoung/transformer-xl 56 | """ 57 | @classmethod 58 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): 59 | """ 60 | Instantiate a TransfoXLTokenizer. 61 | The TransfoXLTokenizer. 62 | """ 63 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: 64 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] 65 | else: 66 | if os.path.isdir(pretrained_model_name_or_path): 67 | vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME) 68 | else: 69 | vocab_file = pretrained_model_name_or_path 70 | # redirect to the cache, if necessary 71 | try: 72 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) 73 | except EnvironmentError: 74 | logger.error( 75 | "Model name '{}' was not found in model name list ({}). " 76 | "We assumed '{}' was a path or url but couldn't find files {} " 77 | "at this path or url.".format( 78 | pretrained_model_name_or_path, 79 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), 80 | pretrained_model_name_or_path, 81 | vocab_file)) 82 | return None 83 | if resolved_vocab_file == vocab_file: 84 | logger.info("loading vocabulary file {}".format(vocab_file)) 85 | else: 86 | logger.info("loading vocabulary file {} from cache at {}".format( 87 | vocab_file, resolved_vocab_file)) 88 | 89 | # Instantiate tokenizer. 90 | tokenizer = cls(*inputs, **kwargs) 91 | vocab_dict = torch.load(resolved_vocab_file) 92 | for key, value in vocab_dict.items(): 93 | tokenizer.__dict__[key] = value 94 | return tokenizer 95 | 96 | def __init__(self, special=[], min_freq=0, max_size=None, lower_case=False, 97 | delimiter=None, vocab_file=None, never_split=("", "", "")): 98 | self.counter = Counter() 99 | self.special = special 100 | self.min_freq = min_freq 101 | self.max_size = max_size 102 | self.lower_case = lower_case 103 | self.delimiter = delimiter 104 | self.vocab_file = vocab_file 105 | self.never_split = never_split 106 | 107 | def count_file(self, path, verbose=False, add_eos=False): 108 | if verbose: print('counting file {} ...'.format(path)) 109 | assert os.path.exists(path) 110 | 111 | sents = [] 112 | with open(path, 'r', encoding='utf-8') as f: 113 | for idx, line in enumerate(f): 114 | if verbose and idx > 0 and idx % 500000 == 0: 115 | print(' line {}'.format(idx)) 116 | symbols = self.tokenize(line, add_eos=add_eos) 117 | self.counter.update(symbols) 118 | sents.append(symbols) 119 | 120 | return sents 121 | 122 | def count_sents(self, sents, verbose=False): 123 | """ 124 | sents : a list of sentences, each a list of tokenized symbols 125 | """ 126 | if verbose: print('counting {} sents ...'.format(len(sents))) 127 | for idx, symbols in enumerate(sents): 128 | if verbose and idx > 0 and idx % 500000 == 0: 129 | print(' line {}'.format(idx)) 130 | self.counter.update(symbols) 131 | 132 | def _build_from_file(self, vocab_file): 133 | self.idx2sym = [] 134 | self.sym2idx = OrderedDict() 135 | 136 | with open(vocab_file, 'r', encoding='utf-8') as f: 137 | for line in f: 138 | symb = line.strip().split()[0] 139 | self.add_symbol(symb) 140 | if '' in self.sym2idx: 141 | self.unk_idx = self.sym2idx[''] 142 | elif '' in self.sym2idx: 143 | self.unk_idx = self.sym2idx[''] 144 | else: 145 | raise ValueError('No token in vocabulary') 146 | 147 | def save_vocabulary(self, vocab_path): 148 | """Save the tokenizer vocabulary to a directory or file.""" 149 | index = 0 150 | if os.path.isdir(vocab_path): 151 | vocab_file = os.path.join(vocab_path, VOCAB_NAME) 152 | torch.save(self.__dict__, vocab_file) 153 | return vocab_file 154 | 155 | def build_vocab(self): 156 | if self.vocab_file: 157 | print('building vocab from {}'.format(self.vocab_file)) 158 | self._build_from_file(self.vocab_file) 159 | print('final vocab size {}'.format(len(self))) 160 | else: 161 | print('building vocab with min_freq={}, max_size={}'.format( 162 | self.min_freq, self.max_size)) 163 | self.idx2sym = [] 164 | self.sym2idx = OrderedDict() 165 | 166 | for sym in self.special: 167 | self.add_special(sym) 168 | 169 | for sym, cnt in self.counter.most_common(self.max_size): 170 | if cnt < self.min_freq: break 171 | self.add_symbol(sym) 172 | 173 | print('final vocab size {} from {} unique tokens'.format( 174 | len(self), len(self.counter))) 175 | 176 | def encode_file(self, path, ordered=False, verbose=False, add_eos=True, 177 | add_double_eos=False): 178 | if verbose: print('encoding file {} ...'.format(path)) 179 | assert os.path.exists(path) 180 | encoded = [] 181 | with open(path, 'r', encoding='utf-8') as f: 182 | for idx, line in enumerate(f): 183 | if verbose and idx > 0 and idx % 500000 == 0: 184 | print(' line {}'.format(idx)) 185 | symbols = self.tokenize(line, add_eos=add_eos, 186 | add_double_eos=add_double_eos) 187 | encoded.append(self.convert_to_tensor(symbols)) 188 | 189 | if ordered: 190 | encoded = torch.cat(encoded) 191 | 192 | return encoded 193 | 194 | def encode_sents(self, sents, ordered=False, verbose=False): 195 | if verbose: print('encoding {} sents ...'.format(len(sents))) 196 | encoded = [] 197 | for idx, symbols in enumerate(sents): 198 | if verbose and idx > 0 and idx % 500000 == 0: 199 | print(' line {}'.format(idx)) 200 | encoded.append(self.convert_to_tensor(symbols)) 201 | 202 | if ordered: 203 | encoded = torch.cat(encoded) 204 | 205 | return encoded 206 | 207 | def add_special(self, sym): 208 | if sym not in self.sym2idx: 209 | self.idx2sym.append(sym) 210 | self.sym2idx[sym] = len(self.idx2sym) - 1 211 | setattr(self, '{}_idx'.format(sym.strip('<>')), self.sym2idx[sym]) 212 | 213 | def add_symbol(self, sym): 214 | if sym not in self.sym2idx: 215 | self.idx2sym.append(sym) 216 | self.sym2idx[sym] = len(self.idx2sym) - 1 217 | 218 | def get_sym(self, idx): 219 | assert 0 <= idx < len(self), 'Index {} out of vocabulary range'.format(idx) 220 | return self.idx2sym[idx] 221 | 222 | def get_idx(self, sym): 223 | if sym in self.sym2idx: 224 | return self.sym2idx[sym] 225 | else: 226 | # print('encounter unk {}'.format(sym)) 227 | # assert '' not in sym 228 | if hasattr(self, 'unk_idx'): 229 | return self.sym2idx.get(sym, self.unk_idx) 230 | # Backward compatibility with pre-trained models 231 | elif '' in self.sym2idx: 232 | return self.sym2idx[''] 233 | elif '' in self.sym2idx: 234 | return self.sym2idx[''] 235 | else: 236 | raise ValueError('Token not in vocabulary and no token in vocabulary for replacement') 237 | 238 | def convert_ids_to_tokens(self, indices): 239 | """Converts a sequence of indices in symbols using the vocab.""" 240 | return [self.get_sym(idx) for idx in indices] 241 | 242 | def convert_tokens_to_ids(self, symbols): 243 | """Converts a sequence of symbols into ids using the vocab.""" 244 | return [self.get_idx(sym) for sym in symbols] 245 | 246 | def convert_to_tensor(self, symbols): 247 | return torch.LongTensor(self.convert_tokens_to_ids(symbols)) 248 | 249 | def decode(self, indices, exclude=None): 250 | """Converts a sequence of indices in a string.""" 251 | if exclude is None: 252 | return ' '.join([self.get_sym(idx) for idx in indices]) 253 | else: 254 | return ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude]) 255 | 256 | def __len__(self): 257 | return len(self.idx2sym) 258 | 259 | def tokenize(self, line, add_eos=False, add_double_eos=False): 260 | line = line.strip() 261 | # convert to lower case 262 | if self.lower_case: 263 | line = line.lower() 264 | 265 | # empty delimiter '' will evaluate False 266 | if self.delimiter == '': 267 | symbols = line 268 | else: 269 | symbols = line.split(self.delimiter) 270 | 271 | if add_double_eos: # lm1b 272 | return [''] + symbols + [''] 273 | elif add_eos: 274 | return symbols + [''] 275 | else: 276 | return symbols 277 | 278 | 279 | class LMOrderedIterator(object): 280 | def __init__(self, data, bsz, bptt, device='cpu', ext_len=None): 281 | """ 282 | data -- LongTensor -- the LongTensor is strictly ordered 283 | """ 284 | self.bsz = bsz 285 | self.bptt = bptt 286 | self.ext_len = ext_len if ext_len is not None else 0 287 | 288 | self.device = device 289 | 290 | # Work out how cleanly we can divide the dataset into bsz parts. 291 | self.n_step = data.size(0) // bsz 292 | 293 | # Trim off any extra elements that wouldn't cleanly fit (remainders). 294 | data = data.narrow(0, 0, self.n_step * bsz) 295 | 296 | # Evenly divide the data across the bsz batches. 297 | self.data = data.view(bsz, -1).t().contiguous().to(device) 298 | 299 | # Number of mini-batches 300 | self.n_batch = (self.n_step + self.bptt - 1) // self.bptt 301 | 302 | def get_batch(self, i, bptt=None): 303 | if bptt is None: bptt = self.bptt 304 | seq_len = min(bptt, self.data.size(0) - 1 - i) 305 | 306 | end_idx = i + seq_len 307 | beg_idx = max(0, i - self.ext_len) 308 | 309 | data = self.data[beg_idx:end_idx] 310 | target = self.data[i+1:i+1+seq_len] 311 | 312 | data_out = data.transpose(0, 1).contiguous().to(self.device) 313 | target_out = target.transpose(0, 1).contiguous().to(self.device) 314 | 315 | return data_out, target_out, seq_len 316 | 317 | def get_fixlen_iter(self, start=0): 318 | for i in range(start, self.data.size(0) - 1, self.bptt): 319 | yield self.get_batch(i) 320 | 321 | def get_varlen_iter(self, start=0, std=5, min_len=5, max_deviation=3): 322 | max_len = self.bptt + max_deviation * std 323 | i = start 324 | while True: 325 | bptt = self.bptt if np.random.random() < 0.95 else self.bptt / 2. 326 | bptt = min(max_len, max(min_len, int(np.random.normal(bptt, std)))) 327 | data, target, seq_len = self.get_batch(i, bptt) 328 | i += seq_len 329 | yield data, target, seq_len 330 | if i >= self.data.size(0) - 2: 331 | break 332 | 333 | def __iter__(self): 334 | return self.get_fixlen_iter() 335 | 336 | 337 | class LMShuffledIterator(object): 338 | def __init__(self, data, bsz, bptt, device='cpu', ext_len=None, shuffle=False): 339 | """ 340 | data -- list[LongTensor] -- there is no order among the LongTensors 341 | """ 342 | self.data = data 343 | 344 | self.bsz = bsz 345 | self.bptt = bptt 346 | self.ext_len = ext_len if ext_len is not None else 0 347 | 348 | self.device = device 349 | self.shuffle = shuffle 350 | 351 | def get_sent_stream(self): 352 | # index iterator 353 | epoch_indices = np.random.permutation(len(self.data)) if self.shuffle \ 354 | else np.array(range(len(self.data))) 355 | 356 | # sentence iterator 357 | for idx in epoch_indices: 358 | yield self.data[idx] 359 | 360 | def stream_iterator(self, sent_stream): 361 | # streams for each data in the batch 362 | streams = [None] * self.bsz 363 | 364 | data = torch.LongTensor(self.bptt, self.bsz) 365 | target = torch.LongTensor(self.bptt, self.bsz) 366 | 367 | n_retain = 0 368 | 369 | while True: 370 | # data : [n_retain+bptt x bsz] 371 | # target : [bptt x bsz] 372 | data[n_retain:].fill_(-1) 373 | target.fill_(-1) 374 | 375 | valid_batch = True 376 | 377 | for i in range(self.bsz): 378 | n_filled = 0 379 | try: 380 | while n_filled < self.bptt: 381 | if streams[i] is None or len(streams[i]) <= 1: 382 | streams[i] = next(sent_stream) 383 | # number of new tokens to fill in 384 | n_new = min(len(streams[i]) - 1, self.bptt - n_filled) 385 | # first n_retain tokens are retained from last batch 386 | data[n_retain+n_filled:n_retain+n_filled+n_new, i] = \ 387 | streams[i][:n_new] 388 | target[n_filled:n_filled+n_new, i] = \ 389 | streams[i][1:n_new+1] 390 | streams[i] = streams[i][n_new:] 391 | n_filled += n_new 392 | except StopIteration: 393 | valid_batch = False 394 | break 395 | 396 | if not valid_batch: 397 | return 398 | 399 | data_out = data.transpose(0, 1).contiguous().to(self.device) 400 | target_out = target.transpose(0, 1).contiguous().to(self.device) 401 | 402 | yield data_out, target_out, self.bptt 403 | 404 | n_retain = min(data.size(0), self.ext_len) 405 | if n_retain > 0: 406 | data[:n_retain] = data[-n_retain:] 407 | data.resize_(n_retain + self.bptt, data.size(1)) 408 | 409 | def __iter__(self): 410 | # sent_stream is an iterator 411 | sent_stream = self.get_sent_stream() 412 | 413 | for batch in self.stream_iterator(sent_stream): 414 | yield batch 415 | 416 | 417 | class LMMultiFileIterator(LMShuffledIterator): 418 | def __init__(self, paths, vocab, bsz, bptt, device='cpu', ext_len=None, 419 | shuffle=False): 420 | 421 | self.paths = paths 422 | self.vocab = vocab 423 | 424 | self.bsz = bsz 425 | self.bptt = bptt 426 | self.ext_len = ext_len if ext_len is not None else 0 427 | 428 | self.device = device 429 | self.shuffle = shuffle 430 | 431 | def get_sent_stream(self, path): 432 | sents = self.vocab.encode_file(path, add_double_eos=True) 433 | if self.shuffle: 434 | np.random.shuffle(sents) 435 | sent_stream = iter(sents) 436 | 437 | return sent_stream 438 | 439 | def __iter__(self): 440 | if self.shuffle: 441 | np.random.shuffle(self.paths) 442 | 443 | for path in self.paths: 444 | # sent_stream is an iterator 445 | sent_stream = self.get_sent_stream(path) 446 | for batch in self.stream_iterator(sent_stream): 447 | yield batch 448 | 449 | 450 | class TransfoXLCorpus(object): 451 | @classmethod 452 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): 453 | """ 454 | Instantiate a pre-processed corpus. 455 | """ 456 | vocab = TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) 457 | if pretrained_model_name_or_path in PRETRAINED_CORPUS_ARCHIVE_MAP: 458 | corpus_file = PRETRAINED_CORPUS_ARCHIVE_MAP[pretrained_model_name_or_path] 459 | else: 460 | corpus_file = os.path.join(pretrained_model_name_or_path, CORPUS_NAME) 461 | # redirect to the cache, if necessary 462 | try: 463 | resolved_corpus_file = cached_path(corpus_file, cache_dir=cache_dir) 464 | except EnvironmentError: 465 | logger.error( 466 | "Corpus '{}' was not found in corpus list ({}). " 467 | "We assumed '{}' was a path or url but couldn't find files {} " 468 | "at this path or url.".format( 469 | pretrained_model_name_or_path, 470 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), 471 | pretrained_model_name_or_path, 472 | corpus_file)) 473 | return None 474 | if resolved_corpus_file == corpus_file: 475 | logger.info("loading corpus file {}".format(corpus_file)) 476 | else: 477 | logger.info("loading corpus file {} from cache at {}".format( 478 | corpus_file, resolved_corpus_file)) 479 | 480 | # Instantiate tokenizer. 481 | corpus = cls(*inputs, **kwargs) 482 | corpus_dict = torch.load(resolved_corpus_file) 483 | for key, value in corpus_dict.items(): 484 | corpus.__dict__[key] = value 485 | corpus.vocab = vocab 486 | if corpus.train is not None: 487 | corpus.train = torch.tensor(corpus.train, dtype=torch.long) 488 | if corpus.valid is not None: 489 | corpus.valid = torch.tensor(corpus.valid, dtype=torch.long) 490 | if corpus.test is not None: 491 | corpus.test = torch.tensor(corpus.test, dtype=torch.long) 492 | return corpus 493 | 494 | def __init__(self, *args, **kwargs): 495 | self.vocab = TransfoXLTokenizer(*args, **kwargs) 496 | self.dataset = None 497 | self.train = None 498 | self.valid = None 499 | self.test = None 500 | 501 | def build_corpus(self, path, dataset): 502 | self.dataset = dataset 503 | 504 | if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8']: 505 | self.vocab.count_file(os.path.join(path, 'train.txt')) 506 | self.vocab.count_file(os.path.join(path, 'valid.txt')) 507 | self.vocab.count_file(os.path.join(path, 'test.txt')) 508 | elif self.dataset == 'wt103': 509 | self.vocab.count_file(os.path.join(path, 'train.txt')) 510 | elif self.dataset == 'lm1b': 511 | train_path_pattern = os.path.join( 512 | path, '1-billion-word-language-modeling-benchmark-r13output', 513 | 'training-monolingual.tokenized.shuffled', 'news.en-*') 514 | train_paths = glob.glob(train_path_pattern) 515 | # the vocab will load from file when build_vocab() is called 516 | 517 | self.vocab.build_vocab() 518 | 519 | if self.dataset in ['ptb', 'wt2', 'wt103']: 520 | self.train = self.vocab.encode_file( 521 | os.path.join(path, 'train.txt'), ordered=True) 522 | self.valid = self.vocab.encode_file( 523 | os.path.join(path, 'valid.txt'), ordered=True) 524 | self.test = self.vocab.encode_file( 525 | os.path.join(path, 'test.txt'), ordered=True) 526 | elif self.dataset in ['enwik8', 'text8']: 527 | self.train = self.vocab.encode_file( 528 | os.path.join(path, 'train.txt'), ordered=True, add_eos=False) 529 | self.valid = self.vocab.encode_file( 530 | os.path.join(path, 'valid.txt'), ordered=True, add_eos=False) 531 | self.test = self.vocab.encode_file( 532 | os.path.join(path, 'test.txt'), ordered=True, add_eos=False) 533 | elif self.dataset == 'lm1b': 534 | self.train = train_paths 535 | self.valid = self.vocab.encode_file( 536 | os.path.join(path, 'valid.txt'), ordered=False, add_double_eos=True) 537 | self.test = self.vocab.encode_file( 538 | os.path.join(path, 'test.txt'), ordered=False, add_double_eos=True) 539 | 540 | def get_iterator(self, split, *args, **kwargs): 541 | if split == 'train': 542 | if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']: 543 | data_iter = LMOrderedIterator(self.train, *args, **kwargs) 544 | elif self.dataset == 'lm1b': 545 | kwargs['shuffle'] = True 546 | data_iter = LMMultiFileIterator(self.train, self.vocab, *args, **kwargs) 547 | elif split in ['valid', 'test']: 548 | data = self.valid if split == 'valid' else self.test 549 | if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']: 550 | data_iter = LMOrderedIterator(data, *args, **kwargs) 551 | elif self.dataset == 'lm1b': 552 | data_iter = LMShuffledIterator(data, *args, **kwargs) 553 | 554 | return data_iter 555 | 556 | 557 | def get_lm_corpus(datadir, dataset): 558 | fn = os.path.join(datadir, 'cache.pt') 559 | fn_pickle = os.path.join(datadir, 'cache.pkl') 560 | if os.path.exists(fn): 561 | print('Loading cached dataset...') 562 | corpus = torch.load(fn_pickle) 563 | elif os.path.exists(fn): 564 | print('Loading cached dataset from pickle...') 565 | with open(fn, "rb") as fp: 566 | corpus = pickle.load(fp) 567 | else: 568 | print('Producing dataset {}...'.format(dataset)) 569 | kwargs = {} 570 | if dataset in ['wt103', 'wt2']: 571 | kwargs['special'] = [''] 572 | kwargs['lower_case'] = False 573 | elif dataset == 'ptb': 574 | kwargs['special'] = [''] 575 | kwargs['lower_case'] = True 576 | elif dataset == 'lm1b': 577 | kwargs['special'] = [] 578 | kwargs['lower_case'] = False 579 | kwargs['vocab_file'] = os.path.join(datadir, '1b_word_vocab.txt') 580 | elif dataset in ['enwik8', 'text8']: 581 | pass 582 | 583 | corpus = TransfoXLCorpus(datadir, dataset, **kwargs) 584 | torch.save(corpus, fn) 585 | 586 | return corpus 587 | --------------------------------------------------------------------------------