├── .gitignore ├── LICENSE ├── README.md ├── bert ├── activations.py ├── configuration_bert.py ├── configuration_utils.py ├── file_utils.py ├── generation_utils.py ├── modeling_bert.py ├── modeling_utils.py ├── tokenization_bert.py ├── tokenization_utils.py └── tokenization_utils_base.py ├── config ├── config.yaml └── open.yaml ├── data ├── READEME.md └── coco_id.json ├── engine ├── __init__.py └── engine.py ├── image └── framework.jpg ├── model ├── __init__.py ├── backbone.py ├── layers.py ├── mmcv_custom │ ├── __init__.py │ └── checkpoint.py └── segmenter.py ├── requirements.txt ├── test.py ├── tools ├── data_process.py ├── folder2lmdb.py ├── latency.py ├── prepare_datasets.md └── refer.py ├── train.py └── utils ├── __init__.py ├── box_ops.py ├── bpe_simple_vocab_16e6.txt.gz ├── config.py ├── dataset.py ├── dataset_open.py ├── misc.py └── simple_tokenizer.py /.gitignore: -------------------------------------------------------------------------------- 1 | /exp 2 | /wandb/** 3 | **/__pycache__ 4 | /train_open.py 5 | /.vscode 6 | config/config.yaml 7 | config/open.yaml 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Toneyaya 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CGFormer 2 | The official PyTorch implementation of the CVPR 2023 paper "Contrastive Grouping with Transformer for Referring Image Segmentation". 3 | 4 | This paper first introduces learnable query tokens to represent objects and then alternately queries linguistic features and groups visual features into the query tokens for object-aware cross-modal reasoning. CGFormer achieves cross-level interaction by jointly updating the query tokens and decoding masks in every two consecutive layers. In addition, we introduce new splits on datasets for evaluating generalization for referring image segmentation models. 5 | 6 | ## Framework 7 |

8 | ![CGFormer framework](image/framework.jpg) 9 | 
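For readers who want a quick feel for the mechanism described above, the snippet below is a minimal, illustrative sketch of how learnable query tokens can first query linguistic features and then group visual features before masks are decoded from query–pixel affinities. It is **not** the CGFormer implementation (see `model/segmenter.py` and `model/layers.py` for that); the module name, default sizes, and single-layer structure are simplifications chosen purely for exposition.

```
# Illustrative only -- a simplified stand-in for the query-token grouping idea,
# not the code used in this repository.
import torch
import torch.nn as nn


class ToyGroupingLayer(nn.Module):
    def __init__(self, dim=512, num_queries=2, num_heads=8):
        super().__init__()
        # learnable query tokens representing candidate objects
        self.queries = nn.Parameter(torch.randn(num_queries, dim))
        self.text_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.vis_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, vis_feats, word_feats):
        # vis_feats: (B, HW, C) flattened visual features; word_feats: (B, L, C)
        B = vis_feats.size(0)
        q = self.queries.unsqueeze(0).expand(B, -1, -1)
        # 1) query tokens attend to linguistic features (language grounding)
        q = self.norm1(q + self.text_attn(q, word_feats, word_feats)[0])
        # 2) visual features are grouped into the query tokens
        q = self.norm2(q + self.vis_attn(q, vis_feats, vis_feats)[0])
        # 3) one mask logit map per query token from query-pixel affinity
        mask_logits = torch.einsum("bqc,bnc->bqn", q, vis_feats)
        return q, mask_logits


if __name__ == "__main__":
    layer = ToyGroupingLayer()
    vis = torch.randn(2, 30 * 30, 512)   # e.g. a 30x30 feature map with C=512
    words = torch.randn(2, 20, 512)      # e.g. 20 word tokens projected to C=512
    q, masks = layer(vis, words)
    print(q.shape, masks.shape)          # (2, 2, 512) and (2, 2, 900)
```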

10 | 11 | ## Preparation 12 | 13 | 1. Environment 14 | - [PyTorch](https://pytorch.org) 15 | - Other dependencies in `requirements.txt` 16 | 2. Datasets 17 | - Detailed instructions are in [prepare_datasets](data/READEME.md) 18 | 3. Pretrained weights 19 | - [Swin-Base-window12](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth) 20 | 4. Our checkpoints: [Hugging Face](https://huggingface.co/Toneyaya/CGFormer/tree/main) 21 | 22 | ## Train and Test (RIS) 23 | 24 | This implementation only supports **multi-GPU**, **DistributedDataParallel** training, which is faster and simpler; single-GPU and DataParallel training are not supported. In addition, evaluation only supports single-GPU mode. 25 | 26 | To train CGFormer with 8 GPUs, run: 27 | 28 | ``` 29 | python -u train.py --config config/config.yaml 30 | ``` 31 | 32 | To evaluate CGFormer with 1 GPU, run: 33 | ``` 34 | CUDA_VISIBLE_DEVICES=0 python -u test.py --config config/refcoco/config.yaml --opts TEST.test_split val TEST.test_lmdb path/val.lmdb TRAIN.weight path/checkpoint.pth 35 | ``` 36 | ## License 37 | 38 | This project is under the MIT license. See [LICENSE](LICENSE) for details. 39 | 40 | 41 | ## Citation 42 | If you find our work useful in your research, please consider citing: 43 | ``` 44 | @InProceedings{Tang_2023_CVPR, 45 | author = {Tang, Jiajin and Zheng, Ge and Shi, Cheng and Yang, Sibei}, 46 | title = {Contrastive Grouping With Transformer for Referring Image Segmentation}, 47 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 48 | month = {June}, 49 | year = {2023}, 50 | pages = {23570-23580} 51 | } 52 | ``` 53 | 54 | Many thanks to these excellent open-source projects 55 | [CRIS](https://github.com/DerrickWang005/CRIS.pytorch/tree/master) and [LAVT](https://github.com/yz93/LAVT-RIS). -------------------------------------------------------------------------------- /bert/activations.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | def swish(x): 12 | return x * torch.sigmoid(x) 13 | 14 | 15 | def _gelu_python(x): 16 | """ Original Implementation of the gelu activation function in Google Bert repo when initially created. 17 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 18 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 19 | This is now written in C in torch.nn.functional 20 | Also see https://arxiv.org/abs/1606.08415 21 | """ 22 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 23 | 24 | 25 | def gelu_new(x): 26 | """ Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT). 
27 | Also see https://arxiv.org/abs/1606.08415 28 | """ 29 | return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) 30 | 31 | 32 | if torch.__version__ < "1.4.0": 33 | gelu = _gelu_python 34 | else: 35 | gelu = F.gelu 36 | 37 | 38 | def gelu_fast(x): 39 | return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x))) 40 | 41 | 42 | ACT2FN = { 43 | "relu": F.relu, 44 | "swish": swish, 45 | "gelu": gelu, 46 | "tanh": torch.tanh, 47 | "gelu_new": gelu_new, 48 | "gelu_fast": gelu_fast, 49 | } 50 | 51 | 52 | def get_activation(activation_string): 53 | if activation_string in ACT2FN: 54 | return ACT2FN[activation_string] 55 | else: 56 | raise KeyError("function {} not found in ACT2FN mapping {}".format(activation_string, list(ACT2FN.keys()))) 57 | -------------------------------------------------------------------------------- /bert/configuration_bert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ BERT model configuration """ 17 | 18 | 19 | import logging 20 | 21 | from .configuration_utils import PretrainedConfig 22 | 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { 27 | "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json", 28 | "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json", 29 | "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json", 30 | "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json", 31 | "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json", 32 | "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json", 33 | "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json", 34 | "bert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json", 35 | "bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json", 36 | "bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json", 37 | "bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json", 38 | "bert-large-cased-whole-word-masking-finetuned-squad": 
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json", 39 | "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json", 40 | "bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-config.json", 41 | "bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-config.json", 42 | "cl-tohoku/bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese/config.json", 43 | "cl-tohoku/bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking/config.json", 44 | "cl-tohoku/bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char/config.json", 45 | "cl-tohoku/bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking/config.json", 46 | "TurkuNLP/bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/config.json", 47 | "TurkuNLP/bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/config.json", 48 | "wietsedv/bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/config.json", 49 | # See all BERT models at https://huggingface.co/models?filter=bert 50 | } 51 | 52 | 53 | class BertConfig(PretrainedConfig): 54 | r""" 55 | This is the configuration class to store the configuration of a :class:`~transformers.BertModel`. 56 | It is used to instantiate an BERT model according to the specified arguments, defining the model 57 | architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of 58 | the BERT `bert-base-uncased `__ architecture. 59 | 60 | Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used 61 | to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig` 62 | for more information. 63 | 64 | 65 | Args: 66 | vocab_size (:obj:`int`, optional, defaults to 30522): 67 | Vocabulary size of the BERT model. Defines the different tokens that 68 | can be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.BertModel`. 69 | hidden_size (:obj:`int`, optional, defaults to 768): 70 | Dimensionality of the encoder layers and the pooler layer. 71 | num_hidden_layers (:obj:`int`, optional, defaults to 12): 72 | Number of hidden layers in the Transformer encoder. 73 | num_attention_heads (:obj:`int`, optional, defaults to 12): 74 | Number of attention heads for each attention layer in the Transformer encoder. 75 | intermediate_size (:obj:`int`, optional, defaults to 3072): 76 | Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. 77 | hidden_act (:obj:`str` or :obj:`function`, optional, defaults to "gelu"): 78 | The non-linear activation function (function or string) in the encoder and pooler. 79 | If string, "gelu", "relu", "swish" and "gelu_new" are supported. 80 | hidden_dropout_prob (:obj:`float`, optional, defaults to 0.1): 81 | The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler. 
82 | attention_probs_dropout_prob (:obj:`float`, optional, defaults to 0.1): 83 | The dropout ratio for the attention probabilities. 84 | max_position_embeddings (:obj:`int`, optional, defaults to 512): 85 | The maximum sequence length that this model might ever be used with. 86 | Typically set this to something large just in case (e.g., 512 or 1024 or 2048). 87 | type_vocab_size (:obj:`int`, optional, defaults to 2): 88 | The vocabulary size of the `token_type_ids` passed into :class:`~transformers.BertModel`. 89 | initializer_range (:obj:`float`, optional, defaults to 0.02): 90 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 91 | layer_norm_eps (:obj:`float`, optional, defaults to 1e-12): 92 | The epsilon used by the layer normalization layers. 93 | gradient_checkpointing (:obj:`bool`, optional, defaults to False): 94 | If True, use gradient checkpointing to save memory at the expense of slower backward pass. 95 | 96 | Example:: 97 | 98 | >>> from transformers import BertModel, BertConfig 99 | 100 | >>> # Initializing a BERT bert-base-uncased style configuration 101 | >>> configuration = BertConfig() 102 | 103 | >>> # Initializing a model from the bert-base-uncased style configuration 104 | >>> model = BertModel(configuration) 105 | 106 | >>> # Accessing the model configuration 107 | >>> configuration = model.config 108 | """ 109 | model_type = "bert" 110 | 111 | def __init__( 112 | self, 113 | vocab_size=30522, 114 | hidden_size=768, 115 | num_hidden_layers=12, 116 | num_attention_heads=12, 117 | intermediate_size=3072, 118 | hidden_act="gelu", 119 | hidden_dropout_prob=0.1, 120 | attention_probs_dropout_prob=0.1, 121 | max_position_embeddings=512, 122 | type_vocab_size=2, 123 | initializer_range=0.02, 124 | layer_norm_eps=1e-12, 125 | pad_token_id=0, 126 | gradient_checkpointing=False, 127 | **kwargs 128 | ): 129 | super().__init__(pad_token_id=pad_token_id, **kwargs) 130 | 131 | self.vocab_size = vocab_size 132 | self.hidden_size = hidden_size 133 | self.num_hidden_layers = num_hidden_layers 134 | self.num_attention_heads = num_attention_heads 135 | self.hidden_act = hidden_act 136 | self.intermediate_size = intermediate_size 137 | self.hidden_dropout_prob = hidden_dropout_prob 138 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 139 | self.max_position_embeddings = max_position_embeddings 140 | self.type_vocab_size = type_vocab_size 141 | self.initializer_range = initializer_range 142 | self.layer_norm_eps = layer_norm_eps 143 | self.gradient_checkpointing = gradient_checkpointing 144 | -------------------------------------------------------------------------------- /bert/configuration_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ Configuration base class and utilities.""" 17 | 18 | 19 | import copy 20 | import json 21 | import logging 22 | import os 23 | from typing import Dict, Tuple 24 | 25 | from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url 26 | 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | 31 | class PretrainedConfig(object): 32 | r""" Base class for all configuration classes. 33 | Handles a few parameters common to all models' configurations as well as methods for loading/downloading/saving configurations. 34 | 35 | Note: 36 | A configuration file can be loaded and saved to disk. Loading the configuration file and using this file to initialize a model does **not** load the model weights. 37 | It only affects the model's configuration. 38 | 39 | Class attributes (overridden by derived classes): 40 | - ``model_type``: a string that identifies the model type, that we serialize into the JSON file, and that we use to recreate the correct object in :class:`~transformers.AutoConfig`. 41 | 42 | Args: 43 | finetuning_task (:obj:`string` or :obj:`None`, `optional`, defaults to :obj:`None`): 44 | Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow or PyTorch) checkpoint. 45 | num_labels (:obj:`int`, `optional`, defaults to `2`): 46 | Number of classes to use when the model is a classification model (sequences/tokens) 47 | output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`False`): 48 | Should the model returns all hidden-states. 49 | output_attentions (:obj:`bool`, `optional`, defaults to :obj:`False`): 50 | Should the model returns all attentions. 51 | torchscript (:obj:`bool`, `optional`, defaults to :obj:`False`): 52 | Is the model used with Torchscript (for PyTorch models). 
53 | """ 54 | model_type: str = "" 55 | 56 | def __init__(self, **kwargs): 57 | # Attributes with defaults 58 | self.output_hidden_states = kwargs.pop("output_hidden_states", False) 59 | self.output_attentions = kwargs.pop("output_attentions", False) 60 | self.use_cache = kwargs.pop("use_cache", True) # Not used by all models 61 | self.torchscript = kwargs.pop("torchscript", False) # Only used by PyTorch models 62 | self.use_bfloat16 = kwargs.pop("use_bfloat16", False) 63 | self.pruned_heads = kwargs.pop("pruned_heads", {}) 64 | 65 | # Is decoder is used in encoder-decoder models to differentiate encoder from decoder 66 | self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False) 67 | self.is_decoder = kwargs.pop("is_decoder", False) 68 | 69 | # Parameters for sequence generation 70 | self.max_length = kwargs.pop("max_length", 20) 71 | self.min_length = kwargs.pop("min_length", 0) 72 | self.do_sample = kwargs.pop("do_sample", False) 73 | self.early_stopping = kwargs.pop("early_stopping", False) 74 | self.num_beams = kwargs.pop("num_beams", 1) 75 | self.temperature = kwargs.pop("temperature", 1.0) 76 | self.top_k = kwargs.pop("top_k", 50) 77 | self.top_p = kwargs.pop("top_p", 1.0) 78 | self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0) 79 | self.length_penalty = kwargs.pop("length_penalty", 1.0) 80 | self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0) 81 | self.bad_words_ids = kwargs.pop("bad_words_ids", None) 82 | self.num_return_sequences = kwargs.pop("num_return_sequences", 1) 83 | 84 | # Fine-tuning task arguments 85 | self.architectures = kwargs.pop("architectures", None) 86 | self.finetuning_task = kwargs.pop("finetuning_task", None) 87 | self.id2label = kwargs.pop("id2label", None) 88 | self.label2id = kwargs.pop("label2id", None) 89 | if self.id2label is not None: 90 | kwargs.pop("num_labels", None) 91 | self.id2label = dict((int(key), value) for key, value in self.id2label.items()) 92 | # Keys are always strings in JSON so convert ids to int here. 93 | else: 94 | self.num_labels = kwargs.pop("num_labels", 2) 95 | 96 | # Tokenizer arguments TODO: eventually tokenizer and models should share the same config 97 | self.prefix = kwargs.pop("prefix", None) 98 | self.bos_token_id = kwargs.pop("bos_token_id", None) 99 | self.pad_token_id = kwargs.pop("pad_token_id", None) 100 | self.eos_token_id = kwargs.pop("eos_token_id", None) 101 | self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None) 102 | 103 | # task specific arguments 104 | self.task_specific_params = kwargs.pop("task_specific_params", None) 105 | 106 | # TPU arguments 107 | self.xla_device = kwargs.pop("xla_device", None) 108 | 109 | # Additional attributes without default values 110 | for key, value in kwargs.items(): 111 | try: 112 | setattr(self, key, value) 113 | except AttributeError as err: 114 | logger.error("Can't set {} with value {} for {}".format(key, value, self)) 115 | raise err 116 | 117 | @property 118 | def num_labels(self): 119 | return len(self.id2label) 120 | 121 | @num_labels.setter 122 | def num_labels(self, num_labels): 123 | self.id2label = {i: "LABEL_{}".format(i) for i in range(num_labels)} 124 | self.label2id = dict(zip(self.id2label.values(), self.id2label.keys())) 125 | 126 | def save_pretrained(self, save_directory): 127 | """ 128 | Save a configuration object to the directory `save_directory`, so that it 129 | can be re-loaded using the :func:`~transformers.PretrainedConfig.from_pretrained` class method. 
130 | 131 | Args: 132 | save_directory (:obj:`string`): 133 | Directory where the configuration JSON file will be saved. 134 | """ 135 | if os.path.isfile(save_directory): 136 | raise AssertionError("Provided path ({}) should be a directory, not a file".format(save_directory)) 137 | os.makedirs(save_directory, exist_ok=True) 138 | # If we save using the predefined names, we can load using `from_pretrained` 139 | output_config_file = os.path.join(save_directory, CONFIG_NAME) 140 | 141 | self.to_json_file(output_config_file, use_diff=True) 142 | logger.info("Configuration saved in {}".format(output_config_file)) 143 | 144 | @classmethod 145 | def from_pretrained(cls, pretrained_model_name_or_path, **kwargs) -> "PretrainedConfig": 146 | r""" 147 | 148 | Instantiate a :class:`~transformers.PretrainedConfig` (or a derived class) from a pre-trained model configuration. 149 | 150 | Args: 151 | pretrained_model_name_or_path (:obj:`string`): 152 | either: 153 | - a string with the `shortcut name` of a pre-trained model configuration to load from cache or 154 | download, e.g.: ``bert-base-uncased``. 155 | - a string with the `identifier name` of a pre-trained model configuration that was user-uploaded to 156 | our S3, e.g.: ``dbmdz/bert-base-german-cased``. 157 | - a path to a `directory` containing a configuration file saved using the 158 | :func:`~transformers.PretrainedConfig.save_pretrained` method, e.g.: ``./my_model_directory/``. 159 | - a path or url to a saved configuration JSON `file`, e.g.: 160 | ``./my_model_directory/configuration.json``. 161 | cache_dir (:obj:`string`, `optional`): 162 | Path to a directory in which a downloaded pre-trained model 163 | configuration should be cached if the standard cache should not be used. 164 | kwargs (:obj:`Dict[str, any]`, `optional`): 165 | The values in kwargs of any keys which are configuration attributes will be used to override the loaded 166 | values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is 167 | controlled by the `return_unused_kwargs` keyword parameter. 168 | force_download (:obj:`bool`, `optional`, defaults to :obj:`False`): 169 | Force to (re-)download the model weights and configuration files and override the cached versions if they exist. 170 | resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`): 171 | Do not delete incompletely recieved file. Attempt to resume the download if such a file exists. 172 | proxies (:obj:`Dict`, `optional`): 173 | A dictionary of proxy servers to use by protocol or endpoint, e.g.: 174 | :obj:`{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.` 175 | The proxies are used on each request. 176 | return_unused_kwargs: (`optional`) bool: 177 | If False, then this function returns just the final configuration object. 178 | If True, then this functions returns a :obj:`Tuple(config, unused_kwargs)` where `unused_kwargs` is a 179 | dictionary consisting of the key/value pairs whose keys are not configuration attributes: ie the part 180 | of kwargs which has not been used to update `config` and is otherwise ignored. 181 | 182 | Returns: 183 | :class:`PretrainedConfig`: An instance of a configuration object 184 | 185 | Examples:: 186 | 187 | # We can't instantiate directly the base class `PretrainedConfig` so let's show the examples on a 188 | # derived class: BertConfig 189 | config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache. 190 | config = BertConfig.from_pretrained('./test/saved_model/') # E.g. 
config (or model) was saved using `save_pretrained('./test/saved_model/')` 191 | config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json') 192 | config = BertConfig.from_pretrained('bert-base-uncased', output_attention=True, foo=False) 193 | assert config.output_attention == True 194 | config, unused_kwargs = BertConfig.from_pretrained('bert-base-uncased', output_attention=True, 195 | foo=False, return_unused_kwargs=True) 196 | assert config.output_attention == True 197 | assert unused_kwargs == {'foo': False} 198 | 199 | """ 200 | config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) 201 | return cls.from_dict(config_dict, **kwargs) 202 | 203 | @classmethod 204 | def get_config_dict(cls, pretrained_model_name_or_path: str, **kwargs) -> Tuple[Dict, Dict]: 205 | """ 206 | From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used 207 | for instantiating a Config using `from_dict`. 208 | 209 | Parameters: 210 | pretrained_model_name_or_path (:obj:`string`): 211 | The identifier of the pre-trained checkpoint from which we want the dictionary of parameters. 212 | 213 | Returns: 214 | :obj:`Tuple[Dict, Dict]`: The dictionary that will be used to instantiate the configuration object. 215 | 216 | """ 217 | cache_dir = kwargs.pop("cache_dir", None) 218 | force_download = kwargs.pop("force_download", False) 219 | resume_download = kwargs.pop("resume_download", False) 220 | proxies = kwargs.pop("proxies", None) 221 | local_files_only = kwargs.pop("local_files_only", False) 222 | 223 | if os.path.isdir(pretrained_model_name_or_path): 224 | config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME) 225 | elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): 226 | config_file = pretrained_model_name_or_path 227 | else: 228 | config_file = hf_bucket_url(pretrained_model_name_or_path, filename=CONFIG_NAME, use_cdn=False) 229 | 230 | try: 231 | # Load from URL or cache if already cached 232 | resolved_config_file = cached_path( 233 | config_file, 234 | cache_dir=cache_dir, 235 | force_download=force_download, 236 | proxies=proxies, 237 | resume_download=resume_download, 238 | local_files_only=local_files_only, 239 | ) 240 | # Load config dict 241 | if resolved_config_file is None: 242 | raise EnvironmentError 243 | config_dict = cls._dict_from_json_file(resolved_config_file) 244 | 245 | except EnvironmentError: 246 | msg = ( 247 | f"Can't load config for '{pretrained_model_name_or_path}'. Make sure that:\n\n" 248 | f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n" 249 | f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a {CONFIG_NAME} file\n\n" 250 | ) 251 | raise EnvironmentError(msg) 252 | 253 | except json.JSONDecodeError: 254 | msg = ( 255 | "Couldn't reach server at '{}' to download configuration file or " 256 | "configuration file is not a valid JSON file. 
" 257 | "Please check network or file content here: {}.".format(config_file, resolved_config_file) 258 | ) 259 | raise EnvironmentError(msg) 260 | 261 | if resolved_config_file == config_file: 262 | logger.info("loading configuration file {}".format(config_file)) 263 | else: 264 | logger.info("loading configuration file {} from cache at {}".format(config_file, resolved_config_file)) 265 | 266 | return config_dict, kwargs 267 | 268 | @classmethod 269 | def from_dict(cls, config_dict: Dict, **kwargs) -> "PretrainedConfig": 270 | """ 271 | Constructs a `Config` from a Python dictionary of parameters. 272 | 273 | Args: 274 | config_dict (:obj:`Dict[str, any]`): 275 | Dictionary that will be used to instantiate the configuration object. Such a dictionary can be retrieved 276 | from a pre-trained checkpoint by leveraging the :func:`~transformers.PretrainedConfig.get_config_dict` 277 | method. 278 | kwargs (:obj:`Dict[str, any]`): 279 | Additional parameters from which to initialize the configuration object. 280 | 281 | Returns: 282 | :class:`PretrainedConfig`: An instance of a configuration object 283 | """ 284 | return_unused_kwargs = kwargs.pop("return_unused_kwargs", False) 285 | 286 | config = cls(**config_dict) 287 | 288 | if hasattr(config, "pruned_heads"): 289 | config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items()) 290 | 291 | # Update config with kwargs if needed 292 | to_remove = [] 293 | for key, value in kwargs.items(): 294 | if hasattr(config, key): 295 | setattr(config, key, value) 296 | to_remove.append(key) 297 | for key in to_remove: 298 | kwargs.pop(key, None) 299 | 300 | logger.info("Model config %s", str(config)) 301 | if return_unused_kwargs: 302 | return config, kwargs 303 | else: 304 | return config 305 | 306 | @classmethod 307 | def from_json_file(cls, json_file: str) -> "PretrainedConfig": 308 | """ 309 | Constructs a `Config` from the path to a json file of parameters. 310 | 311 | Args: 312 | json_file (:obj:`string`): 313 | Path to the JSON file containing the parameters. 314 | 315 | Returns: 316 | :class:`PretrainedConfig`: An instance of a configuration object 317 | 318 | """ 319 | config_dict = cls._dict_from_json_file(json_file) 320 | return cls(**config_dict) 321 | 322 | @classmethod 323 | def _dict_from_json_file(cls, json_file: str): 324 | with open(json_file, "r", encoding="utf-8") as reader: 325 | text = reader.read() 326 | return json.loads(text) 327 | 328 | def __eq__(self, other): 329 | return self.__dict__ == other.__dict__ 330 | 331 | def __repr__(self): 332 | return "{} {}".format(self.__class__.__name__, self.to_json_string()) 333 | 334 | def to_diff_dict(self): 335 | """ 336 | Removes all attributes from config which correspond to the default 337 | config attributes for better readability and serializes to a Python 338 | dictionary. 339 | 340 | Returns: 341 | :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, 342 | """ 343 | config_dict = self.to_dict() 344 | 345 | # get the default config dict 346 | default_config_dict = PretrainedConfig().to_dict() 347 | 348 | serializable_config_dict = {} 349 | 350 | # only serialize values that differ from the default config 351 | for key, value in config_dict.items(): 352 | if key not in default_config_dict or value != default_config_dict[key]: 353 | serializable_config_dict[key] = value 354 | 355 | return serializable_config_dict 356 | 357 | def to_dict(self): 358 | """ 359 | Serializes this instance to a Python dictionary. 
360 | 361 | Returns: 362 | :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, 363 | """ 364 | output = copy.deepcopy(self.__dict__) 365 | if hasattr(self.__class__, "model_type"): 366 | output["model_type"] = self.__class__.model_type 367 | return output 368 | 369 | def to_json_string(self, use_diff=True): 370 | """ 371 | Serializes this instance to a JSON string. 372 | 373 | Args: 374 | use_diff (:obj:`bool`): 375 | If set to True, only the difference between the config instance and the default PretrainedConfig() is serialized to JSON string. 376 | 377 | Returns: 378 | :obj:`string`: String containing all the attributes that make up this configuration instance in JSON format. 379 | """ 380 | if use_diff is True: 381 | config_dict = self.to_diff_dict() 382 | else: 383 | config_dict = self.to_dict() 384 | return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" 385 | 386 | def to_json_file(self, json_file_path, use_diff=True): 387 | """ 388 | Save this instance to a json file. 389 | 390 | Args: 391 | json_file_path (:obj:`string`): 392 | Path to the JSON file in which this configuration instance's parameters will be saved. 393 | use_diff (:obj:`bool`): 394 | If set to True, only the difference between the config instance and the default PretrainedConfig() is serialized to JSON file. 395 | """ 396 | with open(json_file_path, "w", encoding="utf-8") as writer: 397 | writer.write(self.to_json_string(use_diff=use_diff)) 398 | 399 | def update(self, config_dict: Dict): 400 | """ 401 | Updates attributes of this class 402 | with attributes from `config_dict`. 403 | 404 | Args: 405 | :obj:`Dict[str, any]`: Dictionary of attributes that shall be updated for this class. 406 | """ 407 | for key, value in config_dict.items(): 408 | setattr(self, key, value) 409 | -------------------------------------------------------------------------------- /bert/tokenization_bert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Tokenization classes.""" 16 | 17 | 18 | import collections 19 | import logging 20 | import os 21 | import unicodedata 22 | from typing import List, Optional 23 | 24 | from .tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace 25 | 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} 30 | 31 | PRETRAINED_VOCAB_FILES_MAP = { 32 | "vocab_file": { 33 | "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", 34 | "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", 35 | "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", 36 | "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", 37 | "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", 38 | "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", 39 | "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", 40 | "bert-base-german-cased": "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt", 41 | "bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt", 42 | "bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt", 43 | "bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt", 44 | "bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt", 45 | "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt", 46 | "bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-vocab.txt", 47 | "bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-vocab.txt", 48 | "TurkuNLP/bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/vocab.txt", 49 | "TurkuNLP/bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/vocab.txt", 50 | "wietsedv/bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/vocab.txt", 51 | } 52 | } 53 | 54 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 55 | "bert-base-uncased": 512, 56 | "bert-large-uncased": 512, 57 | "bert-base-cased": 512, 58 | "bert-large-cased": 512, 59 | "bert-base-multilingual-uncased": 512, 60 | "bert-base-multilingual-cased": 512, 61 | "bert-base-chinese": 512, 62 | "bert-base-german-cased": 512, 63 | "bert-large-uncased-whole-word-masking": 512, 64 | "bert-large-cased-whole-word-masking": 512, 65 | "bert-large-uncased-whole-word-masking-finetuned-squad": 512, 66 | "bert-large-cased-whole-word-masking-finetuned-squad": 512, 67 | "bert-base-cased-finetuned-mrpc": 512, 68 | "bert-base-german-dbmdz-cased": 512, 69 | 
"bert-base-german-dbmdz-uncased": 512, 70 | "TurkuNLP/bert-base-finnish-cased-v1": 512, 71 | "TurkuNLP/bert-base-finnish-uncased-v1": 512, 72 | "wietsedv/bert-base-dutch-cased": 512, 73 | } 74 | 75 | PRETRAINED_INIT_CONFIGURATION = { 76 | "bert-base-uncased": {"do_lower_case": True}, 77 | "bert-large-uncased": {"do_lower_case": True}, 78 | "bert-base-cased": {"do_lower_case": False}, 79 | "bert-large-cased": {"do_lower_case": False}, 80 | "bert-base-multilingual-uncased": {"do_lower_case": True}, 81 | "bert-base-multilingual-cased": {"do_lower_case": False}, 82 | "bert-base-chinese": {"do_lower_case": False}, 83 | "bert-base-german-cased": {"do_lower_case": False}, 84 | "bert-large-uncased-whole-word-masking": {"do_lower_case": True}, 85 | "bert-large-cased-whole-word-masking": {"do_lower_case": False}, 86 | "bert-large-uncased-whole-word-masking-finetuned-squad": {"do_lower_case": True}, 87 | "bert-large-cased-whole-word-masking-finetuned-squad": {"do_lower_case": False}, 88 | "bert-base-cased-finetuned-mrpc": {"do_lower_case": False}, 89 | "bert-base-german-dbmdz-cased": {"do_lower_case": False}, 90 | "bert-base-german-dbmdz-uncased": {"do_lower_case": True}, 91 | "TurkuNLP/bert-base-finnish-cased-v1": {"do_lower_case": False}, 92 | "TurkuNLP/bert-base-finnish-uncased-v1": {"do_lower_case": True}, 93 | "wietsedv/bert-base-dutch-cased": {"do_lower_case": False}, 94 | } 95 | 96 | 97 | def load_vocab(vocab_file): 98 | """Loads a vocabulary file into a dictionary.""" 99 | vocab = collections.OrderedDict() 100 | with open(vocab_file, "r", encoding="utf-8") as reader: 101 | tokens = reader.readlines() 102 | for index, token in enumerate(tokens): 103 | token = token.rstrip("\n") 104 | vocab[token] = index 105 | return vocab 106 | 107 | 108 | def whitespace_tokenize(text): 109 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 110 | text = text.strip() 111 | if not text: 112 | return [] 113 | tokens = text.split() 114 | return tokens 115 | 116 | 117 | class BertTokenizer(PreTrainedTokenizer): 118 | r""" 119 | Constructs a BERT tokenizer. Based on WordPiece. 120 | 121 | This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users 122 | should refer to the superclass for more information regarding methods. 123 | 124 | Args: 125 | vocab_file (:obj:`string`): 126 | File containing the vocabulary. 127 | do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`): 128 | Whether to lowercase the input when tokenizing. 129 | do_basic_tokenize (:obj:`bool`, `optional`, defaults to :obj:`True`): 130 | Whether to do basic tokenization before WordPiece. 131 | never_split (:obj:`Iterable`, `optional`, defaults to :obj:`None`): 132 | Collection of tokens which will never be split during tokenization. Only has an effect when 133 | :obj:`do_basic_tokenize=True` 134 | unk_token (:obj:`string`, `optional`, defaults to "[UNK]"): 135 | The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this 136 | token instead. 137 | sep_token (:obj:`string`, `optional`, defaults to "[SEP]"): 138 | The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences 139 | for sequence classification or for a text and a question for question answering. 140 | It is also used as the last token of a sequence built with special tokens. 
141 | pad_token (:obj:`string`, `optional`, defaults to "[PAD]"): 142 | The token used for padding, for example when batching sequences of different lengths. 143 | cls_token (:obj:`string`, `optional`, defaults to "[CLS]"): 144 | The classifier token which is used when doing sequence classification (classification of the whole 145 | sequence instead of per-token classification). It is the first token of the sequence when built with 146 | special tokens. 147 | mask_token (:obj:`string`, `optional`, defaults to "[MASK]"): 148 | The token used for masking values. This is the token used when training this model with masked language 149 | modeling. This is the token which the model will try to predict. 150 | tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`): 151 | Whether to tokenize Chinese characters. 152 | This should likely be deactivated for Japanese: 153 | see: https://github.com/huggingface/transformers/issues/328 154 | """ 155 | 156 | vocab_files_names = VOCAB_FILES_NAMES 157 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 158 | pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION 159 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 160 | 161 | def __init__( 162 | self, 163 | vocab_file, 164 | do_lower_case=True, 165 | do_basic_tokenize=True, 166 | never_split=None, 167 | unk_token="[UNK]", 168 | sep_token="[SEP]", 169 | pad_token="[PAD]", 170 | cls_token="[CLS]", 171 | mask_token="[MASK]", 172 | tokenize_chinese_chars=True, 173 | **kwargs 174 | ): 175 | super().__init__( 176 | unk_token=unk_token, 177 | sep_token=sep_token, 178 | pad_token=pad_token, 179 | cls_token=cls_token, 180 | mask_token=mask_token, 181 | **kwargs, 182 | ) 183 | 184 | if not os.path.isfile(vocab_file): 185 | raise ValueError( 186 | "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " 187 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file) 188 | ) 189 | self.vocab = load_vocab(vocab_file) 190 | self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()]) 191 | self.do_basic_tokenize = do_basic_tokenize 192 | if do_basic_tokenize: 193 | self.basic_tokenizer = BasicTokenizer( 194 | do_lower_case=do_lower_case, never_split=never_split, tokenize_chinese_chars=tokenize_chinese_chars 195 | ) 196 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token) 197 | 198 | @property 199 | def vocab_size(self): 200 | return len(self.vocab) 201 | 202 | def get_vocab(self): 203 | return dict(self.vocab, **self.added_tokens_encoder) 204 | 205 | def _tokenize(self, text): 206 | split_tokens = [] 207 | if self.do_basic_tokenize: 208 | for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens): 209 | 210 | # If the token is part of the never_split set 211 | if token in self.basic_tokenizer.never_split: 212 | split_tokens.append(token) 213 | else: 214 | split_tokens += self.wordpiece_tokenizer.tokenize(token) 215 | else: 216 | split_tokens = self.wordpiece_tokenizer.tokenize(text) 217 | return split_tokens 218 | 219 | def _convert_token_to_id(self, token): 220 | """ Converts a token (str) in an id using the vocab. 
""" 221 | return self.vocab.get(token, self.vocab.get(self.unk_token)) 222 | 223 | def _convert_id_to_token(self, index): 224 | """Converts an index (integer) in a token (str) using the vocab.""" 225 | return self.ids_to_tokens.get(index, self.unk_token) 226 | 227 | def convert_tokens_to_string(self, tokens): 228 | """ Converts a sequence of tokens (string) in a single string. """ 229 | out_string = " ".join(tokens).replace(" ##", "").strip() 230 | return out_string 231 | 232 | def build_inputs_with_special_tokens( 233 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 234 | ) -> List[int]: 235 | """ 236 | Build model inputs from a sequence or a pair of sequence for sequence classification tasks 237 | by concatenating and adding special tokens. 238 | A BERT sequence has the following format: 239 | 240 | - single sequence: ``[CLS] X [SEP]`` 241 | - pair of sequences: ``[CLS] A [SEP] B [SEP]`` 242 | 243 | Args: 244 | token_ids_0 (:obj:`List[int]`): 245 | List of IDs to which the special tokens will be added 246 | token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`): 247 | Optional second list of IDs for sequence pairs. 248 | 249 | Returns: 250 | :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens. 251 | """ 252 | if token_ids_1 is None: 253 | return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] 254 | cls = [self.cls_token_id] 255 | sep = [self.sep_token_id] 256 | return cls + token_ids_0 + sep + token_ids_1 + sep 257 | 258 | def get_special_tokens_mask( 259 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False 260 | ) -> List[int]: 261 | """ 262 | Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding 263 | special tokens using the tokenizer ``prepare_for_model`` method. 264 | 265 | Args: 266 | token_ids_0 (:obj:`List[int]`): 267 | List of ids. 268 | token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`): 269 | Optional second list of IDs for sequence pairs. 270 | already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): 271 | Set to True if the token list is already formatted with special tokens for the model 272 | 273 | Returns: 274 | :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. 275 | """ 276 | 277 | if already_has_special_tokens: 278 | if token_ids_1 is not None: 279 | raise ValueError( 280 | "You should not supply a second sequence if the provided sequence of " 281 | "ids is already formated with special tokens for the model." 282 | ) 283 | return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) 284 | 285 | if token_ids_1 is not None: 286 | return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] 287 | return [1] + ([0] * len(token_ids_0)) + [1] 288 | 289 | def create_token_type_ids_from_sequences( 290 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 291 | ) -> List[int]: 292 | """ 293 | Creates a mask from the two sequences passed to be used in a sequence-pair classification task. 294 | A BERT sequence pair mask has the following format: 295 | 296 | :: 297 | 298 | 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 299 | | first sequence | second sequence | 300 | 301 | if token_ids_1 is None, only returns the first portion of the mask (0's). 
302 | 303 | Args: 304 | token_ids_0 (:obj:`List[int]`): 305 | List of ids. 306 | token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`): 307 | Optional second list of IDs for sequence pairs. 308 | 309 | Returns: 310 | :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given 311 | sequence(s). 312 | """ 313 | sep = [self.sep_token_id] 314 | cls = [self.cls_token_id] 315 | if token_ids_1 is None: 316 | return len(cls + token_ids_0 + sep) * [0] 317 | return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] 318 | 319 | def save_vocabulary(self, vocab_path): 320 | """ 321 | Save the sentencepiece vocabulary (copy original file) and special tokens file to a directory. 322 | 323 | Args: 324 | vocab_path (:obj:`str`): 325 | The directory in which to save the vocabulary. 326 | 327 | Returns: 328 | :obj:`Tuple(str)`: Paths to the files saved. 329 | """ 330 | index = 0 331 | if os.path.isdir(vocab_path): 332 | vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["vocab_file"]) 333 | else: 334 | vocab_file = vocab_path 335 | with open(vocab_file, "w", encoding="utf-8") as writer: 336 | for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): 337 | if index != token_index: 338 | logger.warning( 339 | "Saving vocabulary to {}: vocabulary indices are not consecutive." 340 | " Please check that the vocabulary is not corrupted!".format(vocab_file) 341 | ) 342 | index = token_index 343 | writer.write(token + "\n") 344 | index += 1 345 | return (vocab_file,) 346 | 347 | 348 | class BasicTokenizer(object): 349 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 350 | 351 | def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True): 352 | """ Constructs a BasicTokenizer. 353 | 354 | Args: 355 | **do_lower_case**: Whether to lower case the input. 356 | **never_split**: (`optional`) list of str 357 | Kept for backward compatibility purposes. 358 | Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`) 359 | List of token not to split. 360 | **tokenize_chinese_chars**: (`optional`) boolean (default True) 361 | Whether to tokenize Chinese characters. 362 | This should likely be deactivated for Japanese: 363 | see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328 364 | """ 365 | if never_split is None: 366 | never_split = [] 367 | self.do_lower_case = do_lower_case 368 | self.never_split = set(never_split) 369 | self.tokenize_chinese_chars = tokenize_chinese_chars 370 | 371 | def tokenize(self, text, never_split=None): 372 | """ Basic Tokenization of a piece of text. 373 | Split on "white spaces" only, for sub-word tokenization, see WordPieceTokenizer. 374 | 375 | Args: 376 | **never_split**: (`optional`) list of str 377 | Kept for backward compatibility purposes. 378 | Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`) 379 | List of token not to split. 380 | """ 381 | # union() returns a new set by concatenating the two sets. 382 | never_split = self.never_split.union(set(never_split)) if never_split else self.never_split 383 | 384 | # This was added on November 1st, 2018 for the multilingual and Chinese 385 | # models. 
This is also applied to the English models now, but it doesn't 386 | # matter since the English models were not trained on any Chinese data 387 | # and generally don't have any Chinese data in them (there are Chinese 388 | # characters in the vocabulary because Wikipedia does have some Chinese 389 | # words in the English Wikipedia.). 390 | if self.tokenize_chinese_chars: 391 | text = self._tokenize_chinese_chars(text) 392 | orig_tokens = whitespace_tokenize(text) 393 | split_tokens = [] 394 | for token in orig_tokens: 395 | if self.do_lower_case and token not in never_split: 396 | token = token.lower() 397 | token = self._run_strip_accents(token) 398 | split_tokens.extend(self._run_split_on_punc(token, never_split)) 399 | 400 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 401 | return output_tokens 402 | 403 | def _run_strip_accents(self, text): 404 | """Strips accents from a piece of text.""" 405 | text = unicodedata.normalize("NFD", text) 406 | output = [] 407 | for char in text: 408 | cat = unicodedata.category(char) 409 | if cat == "Mn": 410 | continue 411 | output.append(char) 412 | return "".join(output) 413 | 414 | def _run_split_on_punc(self, text, never_split=None): 415 | """Splits punctuation on a piece of text.""" 416 | if never_split is not None and text in never_split: 417 | return [text] 418 | chars = list(text) 419 | i = 0 420 | start_new_word = True 421 | output = [] 422 | while i < len(chars): 423 | char = chars[i] 424 | if _is_punctuation(char): 425 | output.append([char]) 426 | start_new_word = True 427 | else: 428 | if start_new_word: 429 | output.append([]) 430 | start_new_word = False 431 | output[-1].append(char) 432 | i += 1 433 | 434 | return ["".join(x) for x in output] 435 | 436 | def _tokenize_chinese_chars(self, text): 437 | """Adds whitespace around any CJK character.""" 438 | output = [] 439 | for char in text: 440 | cp = ord(char) 441 | if self._is_chinese_char(cp): 442 | output.append(" ") 443 | output.append(char) 444 | output.append(" ") 445 | else: 446 | output.append(char) 447 | return "".join(output) 448 | 449 | def _is_chinese_char(self, cp): 450 | """Checks whether CP is the codepoint of a CJK character.""" 451 | # This defines a "chinese character" as anything in the CJK Unicode block: 452 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 453 | # 454 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 455 | # despite its name. The modern Korean Hangul alphabet is a different block, 456 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 457 | # space-separated words, so they are not treated specially and handled 458 | # like the all of the other languages. 
459 | if ( 460 | (cp >= 0x4E00 and cp <= 0x9FFF) 461 | or (cp >= 0x3400 and cp <= 0x4DBF) # 462 | or (cp >= 0x20000 and cp <= 0x2A6DF) # 463 | or (cp >= 0x2A700 and cp <= 0x2B73F) # 464 | or (cp >= 0x2B740 and cp <= 0x2B81F) # 465 | or (cp >= 0x2B820 and cp <= 0x2CEAF) # 466 | or (cp >= 0xF900 and cp <= 0xFAFF) 467 | or (cp >= 0x2F800 and cp <= 0x2FA1F) # 468 | ): # 469 | return True 470 | 471 | return False 472 | 473 | def _clean_text(self, text): 474 | """Performs invalid character removal and whitespace cleanup on text.""" 475 | output = [] 476 | for char in text: 477 | cp = ord(char) 478 | if cp == 0 or cp == 0xFFFD or _is_control(char): 479 | continue 480 | if _is_whitespace(char): 481 | output.append(" ") 482 | else: 483 | output.append(char) 484 | return "".join(output) 485 | 486 | 487 | class WordpieceTokenizer(object): 488 | """Runs WordPiece tokenization.""" 489 | 490 | def __init__(self, vocab, unk_token, max_input_chars_per_word=100): 491 | self.vocab = vocab 492 | self.unk_token = unk_token 493 | self.max_input_chars_per_word = max_input_chars_per_word 494 | 495 | def tokenize(self, text): 496 | """Tokenizes a piece of text into its word pieces. 497 | 498 | This uses a greedy longest-match-first algorithm to perform tokenization 499 | using the given vocabulary. 500 | 501 | For example: 502 | input = "unaffable" 503 | output = ["un", "##aff", "##able"] 504 | 505 | Args: 506 | text: A single token or whitespace separated tokens. This should have 507 | already been passed through `BasicTokenizer`. 508 | 509 | Returns: 510 | A list of wordpiece tokens. 511 | """ 512 | 513 | output_tokens = [] 514 | for token in whitespace_tokenize(text): 515 | chars = list(token) 516 | if len(chars) > self.max_input_chars_per_word: 517 | output_tokens.append(self.unk_token) 518 | continue 519 | 520 | is_bad = False 521 | start = 0 522 | sub_tokens = [] 523 | while start < len(chars): 524 | end = len(chars) 525 | cur_substr = None 526 | while start < end: 527 | substr = "".join(chars[start:end]) 528 | if start > 0: 529 | substr = "##" + substr 530 | if substr in self.vocab: 531 | cur_substr = substr 532 | break 533 | end -= 1 534 | if cur_substr is None: 535 | is_bad = True 536 | break 537 | sub_tokens.append(cur_substr) 538 | start = end 539 | 540 | if is_bad: 541 | output_tokens.append(self.unk_token) 542 | else: 543 | output_tokens.extend(sub_tokens) 544 | return output_tokens 545 | 546 | -------------------------------------------------------------------------------- /config/config.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | dataset: refcoco 3 | train_split: train 4 | train_lmdb: path/refcoco/train.lmdb 5 | val_split: val 6 | val_lmdb: path/val.lmdb 7 | mask_root: path/masks/refcoco 8 | TRAIN: 9 | swin_type: base 10 | swin_pretrain: path/swin_base_window12.pth 11 | bert: bert-base-uncased 12 | mha: '8-8-8-8' 13 | input_size: 480 14 | word_len: 20 15 | word_dim: 768 16 | vis_dim: 512 17 | num_token: 2 18 | token_dim: 512 19 | sync_bn: True 20 | dropout: 0. 21 | fusion_drop: 0. 
22 | workers: 32 # data loader workers 23 | workers_val: 8 24 | batch_size: 64 # batch size for training 25 | batch_size_val: 16 # batch size for validation during training, memory and speed tradeoff 26 | start_epoch: 0 27 | epochs: 50 28 | lr_backbone: 5.e-5 29 | lr_text_encoder: 5.e-5 30 | lr: 1.e-4 31 | weight_decay: 1.e-4 32 | amsgrad: True 33 | manual_seed: 34 | print_freq: 100 35 | exp_name: cgformer 36 | output_folder: exp/refcoco/ 37 | save_freq: 1 38 | weight: 39 | resume: 40 | evaluate: True 41 | Distributed: 42 | dist_url: tcp://localhost:12345 43 | dist_backend: 'nccl' 44 | multiprocessing_distributed: True 45 | world_size: 1 46 | rank: 0 47 | TEST: 48 | window12: True # if use window12 pretrained for training, testing set true 49 | test_split: val 50 | test_lmdb: path/refcoco/val.lmdb 51 | visualize: False -------------------------------------------------------------------------------- /config/open.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | dataset: refcoco 3 | train_split: train_seen 4 | train_lmdb: path/open_lmdb/refcoco/train_seen.lmdb 5 | val_seen_split: val_seen 6 | val_seen_lmdb: path/open_lmdb/refcoco/val_seen.lmdb 7 | val_unseen_split: val_unseen 8 | val_unseen_lmdb: path/open_lmdb/refcoco/val_unseen.lmdb 9 | mask_root: path/masks/refcoco 10 | TRAIN: 11 | swin_type: base 12 | swin_pretrain: path/swin_base_patch4_window12_384_22k.pth 13 | bert: bert-base-uncased 14 | clip_pretrain: path/pretrain/ViT-L-14-336px.pt 15 | mha: '8-8-8-8' 16 | input_size: 480 17 | clip_dim: 768 18 | word_len: 20 19 | num_token: 2 20 | word_dim: 768 21 | vis_dim: 512 22 | token_dim: 512 23 | sync_bn: True 24 | dropout: 0. 25 | fusion_drop: 0. 26 | workers: 32 # data loader workers 27 | workers_val: 8 28 | batch_size: 64 # batch size for training 29 | batch_size_val: 16 # batch size for validation during training, memory and speed tradeoff 30 | start_epoch: 0 31 | epochs: 1000 32 | lr_backbone: 5.e-5 33 | lr_text_encoder: 5.e-5 34 | lr: 1.e-4 35 | weight_decay: 1.e-4 36 | amsgrad: True 37 | manual_seed: 0 38 | print_freq: 100 39 | exp_name: open 40 | output_folder: exp/refcoco 41 | save_freq: 1 42 | weight: 43 | resume: 44 | evaluate: True # evaluate on validation set, extra gpu memory needed and small batch_size_val is recommend 45 | Distributed: 46 | dist_url: tcp://localhost:12345 47 | dist_backend: 'nccl' 48 | multiprocessing_distributed: True 49 | world_size: 1 50 | rank: 0 51 | TEST: 52 | window12: True # if use window12 pretrained for training, testing set true 53 | test_split: test_unseen 54 | test_lmdb: path/refcoco/test_unseen.lmdb 55 | visualize: False -------------------------------------------------------------------------------- /data/READEME.md: -------------------------------------------------------------------------------- 1 | ## Referrring Image Segmentation 2 | Preparing data for RefCOCO, RefCOCO+, and RefCOCOg: [CRIS·GitHub](https://github.com/DerrickWang005/CRIS.pytorch/blob/master/tools/prepare_datasets.md) 3 | ## Generalization Setting for RIS 4 | Preparing data for generalization experiments. Download the prepared lmdb format datasets from [google drive](https://drive.google.com/file/d/10qIyrslkX50THzZFgGuompF_tWl5TRW0/view?usp=drive_link). 5 | 6 | You can also customize your data format via seen and unseen [id maps](coco_id.json) This id corresponds to the "cat" field in the original annotations. 
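
As a minimal sketch of that customization (the annotation path below is hypothetical, and each entry is assumed to carry its COCO category id in the `cat` field, following the layout written by `tools/data_process.py`):

```
import json

# Load the seen/unseen category id lists shipped with the repo.
with open("data/coco_id.json") as f:
    id_map = json.load(f)
seen_ids, unseen_ids = set(id_map["seen"]), set(id_map["unseen"])

# Hypothetical annotation file produced by tools/data_process.py.
with open("anns/refcoco/train.json") as f:
    anns = json.load(f)

# Split referring expressions by the category of their target object.
seen_anns = [a for a in anns if a["cat"] in seen_ids]
unseen_anns = [a for a in anns if a["cat"] in unseen_ids]
print(len(seen_anns), len(unseen_anns))
```
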
-------------------------------------------------------------------------------- /data/coco_id.json: -------------------------------------------------------------------------------- 1 | { 2 | "0": "person", 3 | "1": "bicycle", 4 | "2": "car", 5 | "3": "motorcycle", 6 | "4": "airplane", 7 | "5": "bus", 8 | "6": "train", 9 | "7": "truck", 10 | "8": "boat", 11 | "9": "traffic light", 12 | "10": "fire hydrant", 13 | "11": "stop sign", 14 | "12": "parking meter", 15 | "13": "bench", 16 | "14": "bird", 17 | "15": "cat", 18 | "16": "dog", 19 | "17": "horse", 20 | "18": "sheep", 21 | "19": "cow", 22 | "20": "elephant", 23 | "21": "bear", 24 | "22": "zebra", 25 | "23": "giraffe", 26 | "24": "backpack", 27 | "25": "umbrella", 28 | "26": "handbag", 29 | "27": "tie", 30 | "28": "suitcase", 31 | "29": "frisbee", 32 | "30": "skis", 33 | "31": "snowboard", 34 | "32": "sports ball", 35 | "33": "kite", 36 | "34": "baseball bat", 37 | "35": "baseball glove", 38 | "36": "skateboard", 39 | "37": "surfboard", 40 | "38": "tennis racket", 41 | "39": "bottle", 42 | "40": "wine glass", 43 | "41": "cup", 44 | "42": "fork", 45 | "43": "knife", 46 | "44": "spoon", 47 | "45": "bowl", 48 | "46": "banana", 49 | "47": "apple", 50 | "48": "sandwich", 51 | "49": "orange", 52 | "50": "broccoli", 53 | "51": "carrot", 54 | "52": "hot dog", 55 | "53": "pizza", 56 | "54": "donut", 57 | "55": "cake", 58 | "56": "chair", 59 | "57": "couch", 60 | "58": "potted plant", 61 | "59": "bed", 62 | "60": "dining table", 63 | "61": "toilet", 64 | "62": "tv", 65 | "63": "laptop", 66 | "64": "mouse", 67 | "65": "remote", 68 | "66": "keyboard", 69 | "67": "cell phone", 70 | "68": "microwave", 71 | "69": "oven", 72 | "70": "toaster", 73 | "71": "sink", 74 | "72": "refrigerator", 75 | "73": "book", 76 | "74": "clock", 77 | "75": "vase", 78 | "76": "scissors", 79 | "77": "teddy bear", 80 | "78": "hair drier", 81 | "79": "toothbrush", 82 | "seen": [ 83 | 0, 84 | 1, 85 | 2, 86 | 3, 87 | 6, 88 | 7, 89 | 8, 90 | 13, 91 | 14, 92 | 17, 93 | 18, 94 | 21, 95 | 22, 96 | 23, 97 | 24, 98 | 26, 99 | 28, 100 | 29, 101 | 30, 102 | 33, 103 | 37, 104 | 39, 105 | 42, 106 | 44, 107 | 45, 108 | 46, 109 | 47, 110 | 48, 111 | 49, 112 | 50, 113 | 51, 114 | 53, 115 | 54, 116 | 56, 117 | 59, 118 | 61, 119 | 62, 120 | 63, 121 | 64, 122 | 65, 123 | 68, 124 | 69, 125 | 70, 126 | 72, 127 | 73, 128 | 74, 129 | 75, 130 | 79 131 | ], 132 | "unseen": [ 133 | 4, 134 | 5, 135 | 15, 136 | 16, 137 | 19, 138 | 20, 139 | 25, 140 | 27, 141 | 31, 142 | 36, 143 | 41, 144 | 43, 145 | 55, 146 | 57, 147 | 66, 148 | 71, 149 | 76 150 | ] 151 | } -------------------------------------------------------------------------------- /engine/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SooLab/CGFormer/766d7fbe0c0101c80806e2499bb4d6960cfc5f4a/engine/__init__.py -------------------------------------------------------------------------------- /engine/engine.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from tqdm import tqdm 4 | import cv2 5 | import numpy as np 6 | import torch 7 | import torch.cuda.amp as amp 8 | import torch.distributed as dist 9 | import torch.nn.functional as F 10 | import wandb 11 | from PIL import Image 12 | from loguru import logger 13 | from utils.misc import (AverageMeter, ProgressMeter, concat_all_gather, trainMetricGPU) 14 | 15 | 16 | def train(train_loader, model, optimizer, scheduler, scaler, epoch, args): 17 | 
batch_time = AverageMeter('Batch', ':2.2f') 18 | data_time = AverageMeter('Data', ':2.2f') 19 | lr = AverageMeter('Lr', ':1.6f') 20 | loss_meter = AverageMeter('Loss', ':2.4f') 21 | iou_meter = AverageMeter('IoU', ':2.2f') 22 | pr_meter = AverageMeter('Prec@50', ':2.2f') 23 | progress = ProgressMeter( 24 | len(train_loader), 25 | [batch_time, data_time, lr, loss_meter, iou_meter, pr_meter], 26 | prefix="Training: Epoch=[{}/{}] ".format(epoch, args.epochs)) 27 | 28 | model.train() 29 | time.sleep(2) 30 | end = time.time() 31 | 32 | # size_list = [320, 352, 384, 416, 448, 480, 512] 33 | # idx = np.random.choice(len(size_list)) 34 | # new_size = size_list[idx] 35 | 36 | for i, (image, text, target, l_mask) in enumerate(train_loader): 37 | data_time.update(time.time() - end) 38 | # data 39 | image = torch.stack(image).cuda(non_blocking=True) 40 | text = torch.stack(text).cuda(non_blocking=True) 41 | target = torch.stack(target).cuda(non_blocking=True) 42 | l_mask = torch.stack(l_mask).cuda(non_blocking=True) 43 | # # multi-scale training 44 | # image = F.interpolate(image, size=(new_size, new_size), mode='bilinear', align_corners=True) 45 | text = text.squeeze(1) 46 | l_mask = l_mask.squeeze(1) 47 | # forward 48 | with amp.autocast(): 49 | pred, target, loss = model(image, text, l_mask, target) 50 | # backward 51 | optimizer.zero_grad() 52 | scaler.scale(loss).backward() 53 | # if args.max_norm: 54 | # torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm) 55 | scaler.step(optimizer) 56 | scaler.update() 57 | scheduler.step() 58 | 59 | # metric 60 | iou, pr5 = trainMetricGPU(pred, target, 0.35) 61 | dist.all_reduce(loss.detach()) 62 | dist.all_reduce(iou) 63 | dist.all_reduce(pr5) 64 | loss = loss / dist.get_world_size() 65 | iou = iou / dist.get_world_size() 66 | pr5 = pr5 / dist.get_world_size() 67 | 68 | loss_meter.update(loss.item(), image.size(0)) 69 | iou_meter.update(iou.item(), image.size(0)) 70 | pr_meter.update(pr5.item(), image.size(0)) 71 | lr.update(optimizer.param_groups[0]["lr"]) 72 | batch_time.update(time.time() - end) 73 | end = time.time() 74 | 75 | if (i + 1) % args.print_freq == 0: 76 | progress.display(i + 1) 77 | if dist.get_rank() in [-1, 0]: 78 | wandb.log( 79 | { 80 | "time/batch": batch_time.val, 81 | "time/data": data_time.val, 82 | "training/lr": lr.val, 83 | "training/loss": loss_meter.val, 84 | "training/iou": iou_meter.val, 85 | "training/prec@50": pr_meter.val, 86 | }, 87 | step=epoch * len(train_loader) + (i + 1)) 88 | 89 | 90 | @torch.no_grad() 91 | def validate(val_loader, model, epoch, args): 92 | iou_list = [] 93 | I = [] 94 | U = [] 95 | model.eval() 96 | time.sleep(2) 97 | for imgs, text, masks, l_mask in val_loader: 98 | # data 99 | imgs = torch.stack(imgs).cuda(non_blocking=True) 100 | text = torch.stack(text).cuda(non_blocking=True) 101 | l_mask = torch.stack(l_mask).cuda(non_blocking=True) 102 | text = text.squeeze(1) 103 | l_mask = l_mask.squeeze(1) 104 | # inference 105 | preds, maps = model(imgs, text, l_mask) 106 | preds = torch.sigmoid(preds) 107 | # process one batch 108 | for pred, mask in zip(preds, masks): 109 | # iou 110 | pred = pred.cpu().numpy() 111 | mask = mask.cpu().numpy() 112 | pred = np.array(pred > 0.5) 113 | inter = np.logical_and(pred, mask) 114 | union = np.logical_or(pred, mask) 115 | iou = np.sum(inter) / (np.sum(union) + 1e-6) 116 | iou_list.append(iou) 117 | I.append(np.sum(inter)) 118 | U.append(np.sum(union)) 119 | iou_list = np.stack(iou_list) 120 | iou_list = torch.from_numpy(iou_list).to(imgs.device) 121 | 
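
At this point `validate` has gathered one IoU value plus intersection and union pixel counts per image; the numbers reported next (mIoU, oIoU, Pr@50 through Pr@90) are a few lines of arithmetic on top of them. A self-contained sketch with toy values (illustrative only, assuming three images):

```
import numpy as np

inter = np.array([40.0, 10.0, 0.0])   # per-image intersection pixel counts
union = np.array([50.0, 40.0, 20.0])  # per-image union pixel counts

iou = inter / (union + 1e-6)          # per-image IoU, here ~[0.8, 0.25, 0.0]
miou = iou.mean()                     # mIoU: mean of per-image IoUs
oiou = inter.sum() / union.sum()      # oIoU: total intersection over total union
prec = {f"Pr@{t * 10}": float((iou > t / 10).mean()) for t in range(5, 10)}
print(miou, oiou, prec)               # Pr@50 is 1/3: only one image exceeds IoU 0.5
```
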
iou_list = concat_all_gather(iou_list) 122 | I = np.stack(I) 123 | I = torch.from_numpy(I).to(imgs.device) 124 | I = concat_all_gather(I).sum() 125 | 126 | U = np.stack(U) 127 | U = torch.from_numpy(U).to(imgs.device) 128 | U = concat_all_gather(U).sum() 129 | oIoU = I/U 130 | prec_list = [] 131 | for thres in torch.arange(0.5, 1.0, 0.1): 132 | tmp = (iou_list > thres).float().mean() 133 | prec_list.append(tmp) 134 | iou = iou_list.mean() 135 | prec = {} 136 | temp = ' ' 137 | for i, thres in enumerate(range(5, 10)): 138 | key = 'Pr@{}'.format(thres * 10) 139 | value = prec_list[i].item() 140 | prec[key] = value 141 | temp += "{}: {:.2f} ".format(key, 100. * value) 142 | head = 'Evaluation: Epoch=[{}/{}] mIoU={:.2f} oIoU={:.2f}'.format( 143 | epoch, args.epochs, 100. * iou.item(), 100.*(oIoU)) 144 | logger.info(head + temp) 145 | return oIoU, prec 146 | 147 | 148 | @torch.no_grad() 149 | def inference(test_loader, model, args): 150 | iou_list = [] 151 | I = 0. 152 | U = 0. 153 | tbar = tqdm(test_loader, desc='Inference:', ncols=100) 154 | model.eval() 155 | time.sleep(2) 156 | for ori_img, img, texts, mask, l_masks, seg_id, sents in tbar: 157 | img = img.cuda(non_blocking=True) 158 | mask = mask.cpu().numpy() 159 | for text, l_mask, sent in zip(texts, l_masks, sents): 160 | text = text.cuda(non_blocking=True) 161 | l_mask = l_mask.cuda(non_blocking=True) 162 | 163 | text = text.squeeze(1) 164 | l_mask = l_mask.squeeze(1) 165 | 166 | # inference 167 | pred, maps = model(img, text, l_mask) 168 | pred = torch.sigmoid(pred) 169 | if pred.shape[-2:] != ori_img.shape[:-1]: 170 | pred = F.interpolate(pred, size=ori_img.shape[1:-1], mode='bicubic', align_corners=True) 171 | # # process one sentence 172 | pred = pred.cpu().numpy() 173 | pred_ = np.array(pred > 0.5) 174 | inter = np.logical_and(pred_, mask) 175 | union = np.logical_or(pred_, mask) 176 | I += np.sum(inter) 177 | U += np.sum(union) 178 | iou = np.sum(inter) / (np.sum(union) + 1e-6) 179 | iou_list.append(iou) 180 | 181 | logger.info('=> Metric Calculation <=') 182 | iou_list = np.stack(iou_list) 183 | iou_list = torch.from_numpy(iou_list).to(img.device) 184 | prec_list = [] 185 | for thres in torch.arange(0.5, 1.0, 0.1): 186 | tmp = (iou_list > thres).float().mean() 187 | prec_list.append(tmp) 188 | iou = iou_list.mean() 189 | prec = {} 190 | for i, thres in enumerate(range(5, 10)): 191 | key = 'Pr@{}'.format(thres*10) 192 | value = prec_list[i].item() 193 | prec[key] = value 194 | logger.info('oIoU={:.2f}'.format(100.*(I/U))) 195 | logger.info('mIoU={:.2f}'.format(100.*iou.item())) 196 | for k, v in prec.items(): 197 | logger.info('{}: {:.2f}.'.format(k, 100.*v)) 198 | 199 | return iou.item(), prec 200 | -------------------------------------------------------------------------------- /image/framework.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SooLab/CGFormer/766d7fbe0c0101c80806e2499bb4d6960cfc5f4a/image/framework.jpg -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .segmenter import CGFormer 2 | from loguru import logger 3 | import torch.nn as nn 4 | from .backbone import MultiModalSwinTransformer 5 | 6 | 7 | def build_model(args): 8 | # initialize the SwinTransformer backbone with the specified version 9 | if args.swin_type == 'tiny': 10 | embed_dim = 96 11 | depths = [2, 2, 6, 2] 12 | num_heads = [3, 6, 12, 24] 13 | 
elif args.swin_type == 'small': 14 | embed_dim = 96 15 | depths = [2, 2, 18, 2] 16 | num_heads = [3, 6, 12, 24] 17 | elif args.swin_type == 'base': 18 | embed_dim = 128 19 | depths = [2, 2, 18, 2] 20 | num_heads = [4, 8, 16, 32] 21 | elif args.swin_type == 'large': 22 | embed_dim = 192 23 | depths = [2, 2, 18, 2] 24 | num_heads = [6, 12, 24, 48] 25 | else: 26 | assert False 27 | # args.window12 added for test.py because state_dict is loaded after model initialization 28 | if 'window12' in args.swin_pretrain or args.window12: 29 | logger.info('Window size 12!') 30 | window_size = 12 31 | else: 32 | window_size = 7 33 | 34 | if args.mha: 35 | mha = args.mha.split('-') # if non-empty, then ['a', 'b', 'c', 'd'] 36 | mha = [int(a) for a in mha] 37 | else: 38 | mha = [1, 1, 1, 1] 39 | 40 | out_indices = (0, 1, 2, 3) 41 | backbone = MultiModalSwinTransformer(embed_dim=embed_dim, depths=depths, num_heads=num_heads, 42 | window_size=window_size, 43 | ape=False, drop_path_rate=0.3, patch_norm=True, 44 | out_indices=out_indices, 45 | use_checkpoint=False, num_heads_fusion=mha, 46 | fusion_drop=args.fusion_drop 47 | ) 48 | if args.swin_pretrain: 49 | logger.info('Initializing Multi-modal Swin Transformer weights from ' + args.swin_pretrain) 50 | backbone.init_weights(pretrained=args.swin_pretrain) 51 | else: 52 | logger.info('Randomly initialize Multi-modal Swin Transformer weights.') 53 | backbone.init_weights() 54 | 55 | model = CGFormer(backbone, args) 56 | 57 | return model 58 | 59 | 60 | def build_segmenter(args, DDP=True, OPEN=False): 61 | model = build_model(args) 62 | if DDP: 63 | if args.sync_bn: 64 | model = nn.SyncBatchNorm.convert_sync_batchnorm(model) 65 | model = nn.parallel.DistributedDataParallel(model.cuda(), 66 | device_ids=[args.gpu], 67 | find_unused_parameters=True 68 | ) 69 | 70 | single_model = model.module 71 | if OPEN: 72 | for p in single_model.backbone.parameters(): 73 | p.requires_grad_(False) 74 | param_list = [ 75 | { 76 | "params": [ 77 | p 78 | for n, p in single_model.named_parameters() 79 | if "backbone" not in n and "text_encoder" not in n and p.requires_grad 80 | ], 81 | 82 | }, 83 | { 84 | "params": [ 85 | p 86 | for n, p in single_model.named_parameters() 87 | if "pwam" in n and p.requires_grad 88 | ], 89 | 90 | }, 91 | { 92 | "params": [p for n, p in single_model.named_parameters() if "backbone" in n and "pwam" not in n and p.requires_grad], 93 | "lr": args.lr_backbone, 94 | }, 95 | { 96 | "params": [p for n, p in single_model.named_parameters() if "text_encoder" in n and p.requires_grad], 97 | "lr": args.lr_text_encoder, 98 | }, 99 | ] 100 | 101 | return model, param_list 102 | else: 103 | model = nn.DataParallel(model).cuda() 104 | return model 105 | -------------------------------------------------------------------------------- /model/layers.py: -------------------------------------------------------------------------------- 1 | from einops import rearrange 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from timm.models.layers import trunc_normal_ 6 | 7 | def l2norm(X, dim=-1, eps=1e-12): 8 | """ 9 | L2-normalize columns of X 10 | """ 11 | norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps 12 | X = torch.div(X, norm) 13 | return X 14 | 15 | 16 | class Mlp(nn.Module): 17 | 18 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 19 | super().__init__() 20 | out_features = out_features or in_features 21 | hidden_features = hidden_features or in_features 22 | 
self.fc1 = nn.Linear(in_features, hidden_features) 23 | self.act = act_layer() 24 | self.fc2 = nn.Linear(hidden_features, out_features) 25 | self.drop = nn.Dropout(drop) 26 | 27 | def forward(self, x): 28 | x = self.fc1(x) 29 | x = self.act(x) 30 | x = self.drop(x) 31 | x = self.fc2(x) 32 | x = self.drop(x) 33 | return x 34 | 35 | def conv_layer(in_dim, out_dim, kernel_size=1, padding=0, stride=1): 36 | return nn.Sequential( 37 | nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding, bias=False), 38 | nn.BatchNorm2d(out_dim), nn.ReLU(True)) 39 | 40 | def hard_softmax(logits, dim): 41 | y_soft = logits.softmax(dim) 42 | # Straight through. 43 | index = y_soft.max(dim, keepdim=True)[1] 44 | y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0) 45 | ret = y_hard - y_soft.detach() + y_soft 46 | return ret 47 | 48 | def gumbel_softmax(logits: torch.Tensor, tau: float = 1, dim: int = -2) -> torch.Tensor: 49 | gumbel_dist = torch.distributions.gumbel.Gumbel( 50 | torch.tensor(0., device=logits.device, dtype=logits.dtype), 51 | torch.tensor(1., device=logits.device, dtype=logits.dtype)) 52 | gumbels = gumbel_dist.sample(logits.shape) 53 | 54 | gumbels = (logits + gumbels) / tau 55 | y_soft = gumbels.softmax(dim) 56 | 57 | index = y_soft.max(dim, keepdim=True)[1] 58 | y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0) 59 | ret = y_hard - y_soft.detach() + y_soft 60 | 61 | return ret 62 | 63 | class Fusion(nn.Module): 64 | def __init__(self, in_dim_1, in_dim_2, out_dim, bias=False) -> None: 65 | super().__init__() 66 | 67 | self.fusion = nn.Sequential( 68 | nn.Conv2d(in_dim_1+in_dim_2, out_dim, 3, padding=1, bias=bias), 69 | nn.BatchNorm2d(out_dim), 70 | nn.ReLU(), 71 | nn.Conv2d(out_dim, out_dim, 3, padding=1, bias=bias), 72 | nn.BatchNorm2d(out_dim), 73 | nn.ReLU(), 74 | ) 75 | 76 | def forward(self, in_1, in_2): 77 | if in_1.shape[-1] < in_2.shape[-1]: 78 | in_1 = F.interpolate(in_1, size=in_2.shape[-2:], mode='bilinear', align_corners=True) 79 | elif in_1.shape[-1] > in_2.shape[-1]: 80 | in_2 = F.interpolate(in_2, size=in_1.shape[-2:], mode='bilinear', align_corners=True) 81 | 82 | x = torch.cat((in_1, in_2), dim=1) 83 | x = self.fusion(x) 84 | return x 85 | 86 | class DProjector(nn.Module): 87 | def __init__(self, text_dim=512, in_dim=512, kernel_size=1): 88 | super().__init__() 89 | self.in_dim = in_dim 90 | self.kernel_size = kernel_size 91 | # visual projector 92 | 93 | self.vis = nn.Sequential( # os16 -> os4 94 | nn.Upsample(scale_factor=2, mode='bilinear'), 95 | conv_layer(in_dim, in_dim, 3, padding=1), 96 | nn.Upsample(scale_factor=2, mode='bilinear'), 97 | conv_layer(in_dim, in_dim, 3, padding=1), 98 | nn.Conv2d(in_dim, in_dim, 1)) 99 | 100 | # textual projector 101 | out_dim = 1 * in_dim * kernel_size * kernel_size + 1 102 | self.txt = nn.Linear(text_dim, out_dim) 103 | 104 | def forward(self, x, text): 105 | ''' 106 | x: b, 512, 104, 104 107 | text: b, 512 108 | ''' 109 | x = self.vis(x) # Eq. 8 110 | 111 | B, C, H, W = x.size() 112 | # 1, b*256, 104, 104 113 | x = x.reshape(1, B * C, H, W) 114 | # txt: b, 1, (256*3*3 + 1) -> b, 1, 256, 3, 3 / b 115 | text = self.txt(text) # Eq. 
8 116 | 117 | weight, bias = text[:, :-1], text[:, -1] 118 | weight = weight.reshape(B, C, self.kernel_size, self.kernel_size) 119 | # Conv2d - 1, b*256, 104, 104 -> 1, b, 104, 104 120 | out = F.conv2d(x, 121 | weight, 122 | padding=1, 123 | groups=B, 124 | bias=bias) 125 | 126 | # b, 1, 104, 104 127 | out = out.transpose(0,1) 128 | return out 129 | 130 | 131 | class CrossAttn(nn.Module): 132 | def __init__(self, 133 | q_dim, 134 | kv_dim, 135 | hidden_dim, 136 | num_heads, 137 | out_dim=None, 138 | qkv_bias=False, 139 | qk_scale=None, 140 | attn_drop=0., 141 | proj_drop=0., 142 | qkv_fuse=False): 143 | super().__init__() 144 | if out_dim is None: 145 | out_dim = q_dim 146 | self.num_heads = num_heads 147 | head_dim = hidden_dim // num_heads 148 | self.scale = qk_scale or head_dim**-0.5 149 | self.qkv_fuse = qkv_fuse 150 | 151 | self.q_proj = nn.Linear(q_dim, hidden_dim, bias=qkv_bias) 152 | self.k_proj = nn.Linear(kv_dim, hidden_dim, bias=qkv_bias) 153 | self.v_proj = nn.Linear(kv_dim, hidden_dim, bias=qkv_bias) 154 | self.attn_drop = nn.Dropout(attn_drop) 155 | self.proj = nn.Linear(hidden_dim, out_dim) 156 | self.proj_drop = nn.Dropout(proj_drop) 157 | 158 | def forward(self, query, key, value=None, mask=None): 159 | B, N, C = query.shape 160 | if value is None: 161 | value = key 162 | S = key.size(1) 163 | # [B, nh, N, C//nh] 164 | q = rearrange(self.q_proj(query), 'b n (h c)-> b h n c', h=self.num_heads, b=B, n=N, c=C // self.num_heads) 165 | # [B, nh, S, C//nh] 166 | k = rearrange(self.k_proj(key), 'b n (h c)-> b h n c', h=self.num_heads, b=B, c=C // self.num_heads) 167 | # [B, nh, S, C//nh] 168 | v = rearrange(self.v_proj(value), 'b n (h c)-> b h n c', h=self.num_heads, b=B, c=C // self.num_heads) 169 | # [B, nh, N, S] 170 | 171 | if mask is not None: 172 | mask = mask[:,None,:,None].expand(-1, self.num_heads, -1, -1) # b nh S 1 173 | k = k * mask 174 | v = v * mask 175 | attn = (q @ k.transpose(-2, -1)) * self.scale 176 | attn = attn + (1e4*mask.transpose(-2,-1)-1e4) # b nh 1 S 177 | else: 178 | attn = (q @ k.transpose(-2, -1)) * self.scale 179 | attn = attn.softmax(dim=-1) 180 | attn = self.attn_drop(attn) 181 | 182 | assert attn.shape == (B, self.num_heads, N, S) 183 | # [B, nh, N, C//nh] -> [B, N, C] 184 | out = rearrange(attn @ v, 'b h n c -> b n (h c)', h=self.num_heads, b=B, n=N, c=C // self.num_heads) 185 | out = self.proj(out) 186 | out = self.proj_drop(out) 187 | return out 188 | 189 | class OriLoadToken(nn.Module): 190 | def __init__(self, token_dim, bias, drop) -> None: 191 | super().__init__() 192 | self.cross_attn = CrossAttn( 193 | q_dim=token_dim, 194 | kv_dim=768, 195 | hidden_dim=token_dim, 196 | num_heads=1, 197 | out_dim=token_dim, 198 | qkv_bias=bias, 199 | attn_drop=drop, 200 | proj_drop=drop, 201 | ) 202 | self.normq = nn.LayerNorm(token_dim) 203 | self.normk = nn.LayerNorm(768) 204 | 205 | self.normq = nn.LayerNorm(token_dim) 206 | self.normk = nn.LayerNorm(768) 207 | 208 | def forward(self, tokens, text, pad_mask): 209 | tokens = tokens + self.cross_attn(query=self.normq(tokens), key=self.normk(text.permute(0,2,1)), mask=pad_mask[...,0]) 210 | return tokens 211 | 212 | # updated version 213 | class LoadToken(nn.Module): 214 | def __init__(self, token_dim, bias, drop) -> None: 215 | super().__init__() 216 | self.cross_attn = CrossAttn( 217 | q_dim=token_dim, 218 | kv_dim=768, 219 | hidden_dim=token_dim, 220 | num_heads=1, 221 | out_dim=token_dim, 222 | qkv_bias=bias, 223 | attn_drop=drop, 224 | proj_drop=drop, 225 | ) 226 | self.normq = 
nn.LayerNorm(token_dim) 227 | self.normk = nn.LayerNorm(768) 228 | 229 | def forward(self, tokens, text, pad_mask): 230 | ltoken, ttoken = torch.split(tokens, [tokens.shape[1]-1,1], dim=1) 231 | ttoken = ttoken + self.cross_attn(query=self.normq(ttoken), key=self.normk(text.permute(0,2,1)), mask=pad_mask[...,0]) 232 | tokens = torch.cat((ltoken, ttoken), dim=1) 233 | return tokens 234 | 235 | class LoadLayer(nn.Module): 236 | def __init__(self, token_dim, drop, bias=False, pe_shape=None) -> None: 237 | super().__init__() 238 | if pe_shape >30: 239 | self.loadtoken = LoadToken( 240 | token_dim=token_dim, 241 | bias=bias, 242 | drop=drop 243 | ) 244 | self.norm = nn.LayerNorm(token_dim) 245 | self.mlp = Mlp(token_dim, token_dim*2, token_dim) 246 | self.positional_embedding = nn.Parameter(torch.randn(pe_shape**2, token_dim) / token_dim ** 0.5) 247 | self.pe_shape = pe_shape 248 | 249 | def forward(self, tokens, text, pad_mask): 250 | if self.pe_shape > 30: 251 | tokens = self.loadtoken(tokens, text, pad_mask) 252 | tokens = self.mlp(self.norm(tokens)) 253 | return tokens, self.positional_embedding 254 | 255 | 256 | class CGAttention(nn.Module): 257 | def __init__(self, token_dim, vis_dim, hidden_dim, drop=0., bias=True) -> None: 258 | super().__init__() 259 | self.norm_v = nn.LayerNorm(vis_dim) 260 | self.norm_t = nn.LayerNorm(token_dim) 261 | self.q_proj = nn.Linear(token_dim, hidden_dim, bias=bias) 262 | self.k_proj = nn.Linear(vis_dim, hidden_dim, bias=bias) 263 | self.v_proj = nn.Linear(vis_dim, hidden_dim, bias=bias) 264 | self.proj = nn.Linear(hidden_dim, token_dim) 265 | self.proj_drop = nn.Dropout(drop) 266 | self.norm = nn.LayerNorm(token_dim) 267 | self.mlp = Mlp(token_dim, token_dim*2, token_dim, drop=drop) 268 | self.tau = nn.Parameter(torch.ones(1), requires_grad=True) 269 | 270 | def with_pe(self, vis, pe): 271 | return vis + pe 272 | 273 | def forward(self, tokens, vis, pe=None): 274 | b, c, h , w = vis.shape 275 | vis = rearrange(vis, 'b c h w -> b (h w) c') 276 | if pe is not None: 277 | vis = self.with_pe(vis, pe) 278 | vis = self.norm_v(vis) 279 | q = self.q_proj(self.norm_t(tokens)) 280 | k = self.k_proj(vis) 281 | v = self.v_proj(vis) 282 | 283 | q = l2norm(q, dim=-1) 284 | k = l2norm(k, dim=-1) 285 | raw_attn = (q @ k.transpose(-2, -1)) 286 | tau = torch.clamp(self.tau, max=0).exp() 287 | attn = gumbel_softmax(raw_attn, dim=-2, tau=tau) 288 | hit_map = attn 289 | attn = attn / (attn.sum(dim=-1, keepdim=True) + 1) 290 | new_tokens = attn @ v 291 | new_tokens = self.proj_drop(self.proj(new_tokens)) 292 | new_tokens = self.mlp(self.norm(new_tokens+tokens)) 293 | return new_tokens, hit_map.reshape(b, -1, h, w) 294 | 295 | class Decoder(nn.Module): 296 | def __init__(self, args) -> None: 297 | super().__init__() 298 | ''' 299 | c1 :128, 120, 120 300 | c2 :256, 60, 60 301 | c3 :512, 30, 30 302 | c4 :1024, 15 ,15 303 | ''' 304 | token_dim = args.token_dim 305 | self.tokens = nn.Embedding(args.num_token, token_dim) 306 | trunc_normal_(self.tokens.weight, std=0.02) 307 | 308 | dims = [1024, 512, 256, 128] 309 | pe_shapes = [30, 60, 120] 310 | 311 | self.layers = [] 312 | for pe_shape in pe_shapes: 313 | self.layers.append(LoadLayer(token_dim, drop=.1, bias=False, pe_shape=pe_shape)) 314 | self.cgattention1 = CGAttention(token_dim=token_dim, 315 | vis_dim=token_dim, 316 | hidden_dim=token_dim, 317 | drop=.1, 318 | bias=True) 319 | self.cgattention2 = CGAttention(token_dim=token_dim, 320 | vis_dim=token_dim, 321 | hidden_dim=token_dim, 322 | drop=.1, 323 | bias=True) 324 | 
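
One detail worth noting in `CGAttention` above: `gumbel_softmax` (defined earlier in this file) is applied over the token dimension (`dim=-2`), so each visual position is hard-assigned to exactly one query token in the forward pass, while the `y_hard - y_soft.detach() + y_soft` return value lets gradients flow through the soft softmax (the straight-through estimator). A minimal sketch of that trick in isolation, using `dim=-1` on random logits for brevity:

```
import torch

logits = torch.randn(2, 5, requires_grad=True)
y_soft = logits.softmax(dim=-1)
index = y_soft.max(dim=-1, keepdim=True)[1]
y_hard = torch.zeros_like(y_soft).scatter_(-1, index, 1.0)
ret = y_hard - y_soft.detach() + y_soft

print(torch.allclose(ret, y_hard))           # True: the forward value is one-hot
(ret * torch.arange(5.0)).sum().backward()
print(logits.grad.abs().sum() > 0)           # True: gradients reach logits via y_soft
```
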
self.layers = nn.ModuleList(self.layers) 325 | self.fuses = [] 326 | for dim in [dims[0], dims[2], dims[3]]: 327 | self.fuses.append(Fusion(dim, token_dim, token_dim, bias=True)) 328 | self.fuses = nn.ModuleList(self.fuses) 329 | self.proj = DProjector(text_dim=token_dim, in_dim=token_dim) 330 | 331 | def forward(self, vis, text, pad_mask): 332 | x_c4, x_c3, x_c2, x_c1 = vis 333 | tokens = self.tokens.weight[None,...].expand(x_c1.shape[0], -1, -1) 334 | maps = [] 335 | v = x_c4 336 | for load, layer, fuse, v_ in zip(self.layers,[self.cgattention1,self.cgattention2,self.cgattention2], self.fuses, [x_c3, x_c2, x_c1]): 337 | v = fuse(v, v_) 338 | tokens, pe = load(tokens, text, pad_mask) 339 | tokens, hitmap = layer(tokens, v, pe=pe) 340 | maps.append(hitmap) 341 | out = self.proj(v, tokens[:,-1]) 342 | return out, maps 343 | -------------------------------------------------------------------------------- /model/mmcv_custom/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .checkpoint import load_checkpoint 4 | 5 | __all__ = ['load_checkpoint'] 6 | -------------------------------------------------------------------------------- /model/mmcv_custom/checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Open-MMLab. All rights reserved. 2 | import io 3 | import os 4 | import os.path as osp 5 | import pkgutil 6 | import time 7 | import warnings 8 | from collections import OrderedDict 9 | from importlib import import_module 10 | from tempfile import TemporaryDirectory 11 | 12 | import torch 13 | import torchvision 14 | from torch.optim import Optimizer 15 | from torch.utils import model_zoo 16 | from torch.nn import functional as F 17 | 18 | import mmcv 19 | from mmcv.fileio import FileClient 20 | from mmcv.fileio import load as load_file 21 | from mmcv.parallel import is_module_wrapper 22 | from mmcv.utils import mkdir_or_exist 23 | from mmcv.runner import get_dist_info 24 | 25 | ENV_MMCV_HOME = 'MMCV_HOME' 26 | ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME' 27 | DEFAULT_CACHE_DIR = '~/.cache' 28 | 29 | 30 | def _get_mmcv_home(): 31 | mmcv_home = os.path.expanduser( 32 | os.getenv( 33 | ENV_MMCV_HOME, 34 | os.path.join( 35 | os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmcv'))) 36 | 37 | mkdir_or_exist(mmcv_home) 38 | return mmcv_home 39 | 40 | 41 | def load_state_dict(module, state_dict, strict=False, logger=None): 42 | """Load state_dict to a module. 43 | 44 | This method is modified from :meth:`torch.nn.Module.load_state_dict`. 45 | Default value for ``strict`` is set to ``False`` and the message for 46 | param mismatch will NOT be shown if strict is False. 47 | 48 | Args: 49 | module (Module): Module that receives the state_dict. 50 | state_dict (OrderedDict): Weights. 51 | strict (bool): whether to strictly enforce that the keys 52 | in :attr:`state_dict` match the keys returned by this module's 53 | :meth:`~torch.nn.Module.state_dict` function. Default: ``False``. 54 | logger (:obj:`logging.Logger`, optional): Logger to log the error 55 | message. If not specified, print function will be used. 
56 | """ 57 | unexpected_keys = [] 58 | all_missing_keys = [] 59 | err_msg = [] 60 | 61 | metadata = getattr(state_dict, '_metadata', None) 62 | state_dict = state_dict.copy() 63 | if metadata is not None: 64 | state_dict._metadata = metadata 65 | 66 | # use _load_from_state_dict to enable checkpoint version control 67 | def load(module, prefix=''): 68 | # recursively check parallel module in case that the model has a 69 | # complicated structure, e.g., nn.Module(nn.Module(DDP)) 70 | if is_module_wrapper(module): 71 | module = module.module 72 | local_metadata = {} if metadata is None else metadata.get( 73 | prefix[:-1], {}) 74 | module._load_from_state_dict(state_dict, prefix, local_metadata, True, 75 | all_missing_keys, unexpected_keys, 76 | err_msg) 77 | for name, child in module._modules.items(): 78 | if child is not None: 79 | load(child, prefix + name + '.') 80 | 81 | load(module) 82 | load = None # break load->load reference cycle 83 | 84 | # ignore "num_batches_tracked" of BN layers 85 | missing_keys = [ 86 | key for key in all_missing_keys if 'num_batches_tracked' not in key 87 | ] 88 | 89 | if unexpected_keys: 90 | err_msg.append('unexpected key in source ' 91 | f'state_dict: {", ".join(unexpected_keys)}\n') 92 | if missing_keys: 93 | err_msg.append( 94 | f'missing keys in source state_dict: {", ".join(missing_keys)}\n') 95 | 96 | if strict: 97 | rank, _ = get_dist_info() 98 | if len(err_msg) > 0 and rank == 0: 99 | err_msg.insert( 100 | 0, 'The model and loaded state dict do not match exactly\n') 101 | err_msg = '\n'.join(err_msg) 102 | if strict: 103 | raise RuntimeError(err_msg) 104 | elif logger is not None: 105 | logger.warning(err_msg) 106 | else: 107 | print(err_msg) 108 | 109 | 110 | def load_url_dist(url, model_dir=None): 111 | """In distributed setting, this function only download checkpoint at local 112 | rank 0.""" 113 | rank, world_size = get_dist_info() 114 | rank = int(os.environ.get('LOCAL_RANK', rank)) 115 | if rank == 0: 116 | checkpoint = model_zoo.load_url(url, model_dir=model_dir) 117 | if world_size > 1: 118 | torch.distributed.barrier() 119 | if rank > 0: 120 | checkpoint = model_zoo.load_url(url, model_dir=model_dir) 121 | return checkpoint 122 | 123 | 124 | def load_pavimodel_dist(model_path, map_location=None): 125 | """In distributed setting, this function only download checkpoint at local 126 | rank 0.""" 127 | try: 128 | from pavi import modelcloud 129 | except ImportError: 130 | raise ImportError( 131 | 'Please install pavi to load checkpoint from modelcloud.') 132 | rank, world_size = get_dist_info() 133 | rank = int(os.environ.get('LOCAL_RANK', rank)) 134 | if rank == 0: 135 | model = modelcloud.get(model_path) 136 | with TemporaryDirectory() as tmp_dir: 137 | downloaded_file = osp.join(tmp_dir, model.name) 138 | model.download(downloaded_file) 139 | checkpoint = torch.load(downloaded_file, map_location=map_location) 140 | if world_size > 1: 141 | torch.distributed.barrier() 142 | if rank > 0: 143 | model = modelcloud.get(model_path) 144 | with TemporaryDirectory() as tmp_dir: 145 | downloaded_file = osp.join(tmp_dir, model.name) 146 | model.download(downloaded_file) 147 | checkpoint = torch.load( 148 | downloaded_file, map_location=map_location) 149 | return checkpoint 150 | 151 | 152 | def load_fileclient_dist(filename, backend, map_location): 153 | """In distributed setting, this function only download checkpoint at local 154 | rank 0.""" 155 | rank, world_size = get_dist_info() 156 | rank = int(os.environ.get('LOCAL_RANK', rank)) 157 | 
allowed_backends = ['ceph'] 158 | if backend not in allowed_backends: 159 | raise ValueError(f'Load from Backend {backend} is not supported.') 160 | if rank == 0: 161 | fileclient = FileClient(backend=backend) 162 | buffer = io.BytesIO(fileclient.get(filename)) 163 | checkpoint = torch.load(buffer, map_location=map_location) 164 | if world_size > 1: 165 | torch.distributed.barrier() 166 | if rank > 0: 167 | fileclient = FileClient(backend=backend) 168 | buffer = io.BytesIO(fileclient.get(filename)) 169 | checkpoint = torch.load(buffer, map_location=map_location) 170 | return checkpoint 171 | 172 | 173 | def get_torchvision_models(): 174 | model_urls = dict() 175 | for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__): 176 | if ispkg: 177 | continue 178 | _zoo = import_module(f'torchvision.models.{name}') 179 | if hasattr(_zoo, 'model_urls'): 180 | _urls = getattr(_zoo, 'model_urls') 181 | model_urls.update(_urls) 182 | return model_urls 183 | 184 | 185 | def get_external_models(): 186 | mmcv_home = _get_mmcv_home() 187 | default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json') 188 | default_urls = load_file(default_json_path) 189 | assert isinstance(default_urls, dict) 190 | external_json_path = osp.join(mmcv_home, 'open_mmlab.json') 191 | if osp.exists(external_json_path): 192 | external_urls = load_file(external_json_path) 193 | assert isinstance(external_urls, dict) 194 | default_urls.update(external_urls) 195 | 196 | return default_urls 197 | 198 | 199 | def get_mmcls_models(): 200 | mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json') 201 | mmcls_urls = load_file(mmcls_json_path) 202 | 203 | return mmcls_urls 204 | 205 | 206 | def get_deprecated_model_names(): 207 | deprecate_json_path = osp.join(mmcv.__path__[0], 208 | 'model_zoo/deprecated.json') 209 | deprecate_urls = load_file(deprecate_json_path) 210 | assert isinstance(deprecate_urls, dict) 211 | 212 | return deprecate_urls 213 | 214 | 215 | def _process_mmcls_checkpoint(checkpoint): 216 | state_dict = checkpoint['state_dict'] 217 | new_state_dict = OrderedDict() 218 | for k, v in state_dict.items(): 219 | if k.startswith('backbone.'): 220 | new_state_dict[k[9:]] = v 221 | new_checkpoint = dict(state_dict=new_state_dict) 222 | 223 | return new_checkpoint 224 | 225 | 226 | def _load_checkpoint(filename, map_location=None): 227 | """Load checkpoint from somewhere (modelzoo, file, url). 228 | 229 | Args: 230 | filename (str): Accept local filepath, URL, ``torchvision://xxx``, 231 | ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for 232 | details. 233 | map_location (str | None): Same as :func:`torch.load`. Default: None. 234 | 235 | Returns: 236 | dict | OrderedDict: The loaded checkpoint. It can be either an 237 | OrderedDict storing model weights or a dict containing other 238 | information, which depends on the checkpoint. 
239 | """ 240 | if filename.startswith('modelzoo://'): 241 | warnings.warn('The URL scheme of "modelzoo://" is deprecated, please ' 242 | 'use "torchvision://" instead') 243 | model_urls = get_torchvision_models() 244 | model_name = filename[11:] 245 | checkpoint = load_url_dist(model_urls[model_name]) 246 | elif filename.startswith('torchvision://'): 247 | model_urls = get_torchvision_models() 248 | model_name = filename[14:] 249 | checkpoint = load_url_dist(model_urls[model_name]) 250 | elif filename.startswith('open-mmlab://'): 251 | model_urls = get_external_models() 252 | model_name = filename[13:] 253 | deprecated_urls = get_deprecated_model_names() 254 | if model_name in deprecated_urls: 255 | warnings.warn(f'open-mmlab://{model_name} is deprecated in favor ' 256 | f'of open-mmlab://{deprecated_urls[model_name]}') 257 | model_name = deprecated_urls[model_name] 258 | model_url = model_urls[model_name] 259 | # check if is url 260 | if model_url.startswith(('http://', 'https://')): 261 | checkpoint = load_url_dist(model_url) 262 | else: 263 | filename = osp.join(_get_mmcv_home(), model_url) 264 | if not osp.isfile(filename): 265 | raise IOError(f'{filename} is not a checkpoint file') 266 | checkpoint = torch.load(filename, map_location=map_location) 267 | elif filename.startswith('mmcls://'): 268 | model_urls = get_mmcls_models() 269 | model_name = filename[8:] 270 | checkpoint = load_url_dist(model_urls[model_name]) 271 | checkpoint = _process_mmcls_checkpoint(checkpoint) 272 | elif filename.startswith(('http://', 'https://')): 273 | checkpoint = load_url_dist(filename) 274 | elif filename.startswith('pavi://'): 275 | model_path = filename[7:] 276 | checkpoint = load_pavimodel_dist(model_path, map_location=map_location) 277 | elif filename.startswith('s3://'): 278 | checkpoint = load_fileclient_dist( 279 | filename, backend='ceph', map_location=map_location) 280 | else: 281 | if not osp.isfile(filename): 282 | raise IOError(f'{filename} is not a checkpoint file') 283 | checkpoint = torch.load(filename, map_location=map_location) 284 | return checkpoint 285 | 286 | 287 | def load_checkpoint(model, 288 | filename, 289 | map_location='cpu', 290 | strict=False, 291 | logger=None): 292 | """Load checkpoint from a file or URI. 293 | 294 | Args: 295 | model (Module): Module to load checkpoint. 296 | filename (str): Accept local filepath, URL, ``torchvision://xxx``, 297 | ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for 298 | details. 299 | map_location (str): Same as :func:`torch.load`. 300 | strict (bool): Whether to allow different params for the model and 301 | checkpoint. 302 | logger (:mod:`logging.Logger` or None): The logger for error message. 303 | 304 | Returns: 305 | dict or OrderedDict: The loaded checkpoint. 
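
A hedged, runnable sketch of the plain local-file branch of this helper (the toy model and temporary path are assumptions for illustration); with the actual Swin weights the same call additionally strips `module.`/`backbone.` prefixes and resizes relative position bias tables, as the function body shows:

```
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 2))
torch.save({"state_dict": model.state_dict()}, "/tmp/toy_ckpt.pth")
ckpt = load_checkpoint(model, "/tmp/toy_ckpt.pth", map_location="cpu", strict=True)
print(sorted(ckpt["state_dict"].keys()))  # ['0.bias', '0.weight']
```
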
306 | """ 307 | checkpoint = _load_checkpoint(filename, map_location) 308 | # OrderedDict is a subclass of dict 309 | if not isinstance(checkpoint, dict): 310 | raise RuntimeError( 311 | f'No state_dict found in checkpoint file {filename}') 312 | # get state_dict from checkpoint 313 | if 'state_dict' in checkpoint: 314 | state_dict = checkpoint['state_dict'] 315 | elif 'model' in checkpoint: 316 | state_dict = checkpoint['model'] 317 | else: 318 | state_dict = checkpoint 319 | # strip prefix of state_dict 320 | if list(state_dict.keys())[0].startswith('module.'): 321 | state_dict = {k[7:]: v for k, v in state_dict.items()} 322 | # for upper net weights only 323 | if list(state_dict.keys())[0].startswith('backbone.'): 324 | print('Start stripping upper net pre-fix and loading backbone weights to our swin encoder') 325 | state_dict = {k.replace('backbone.', ''): v for k, v in state_dict.items() if k.startswith('backbone.')} 326 | # for MoBY, load model of online branch 327 | if sorted(list(state_dict.keys()))[0].startswith('encoder'): 328 | state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')} 329 | 330 | # reshape absolute position embedding 331 | if state_dict.get('absolute_pos_embed') is not None: 332 | absolute_pos_embed = state_dict['absolute_pos_embed'] 333 | N1, L, C1 = absolute_pos_embed.size() 334 | N2, C2, H, W = model.absolute_pos_embed.size() 335 | if N1 != N2 or C1 != C2 or L != H*W: 336 | logger.warning("Error in loading absolute_pos_embed, pass") 337 | else: 338 | state_dict['absolute_pos_embed'] = absolute_pos_embed.view(N2, H, W, C2).permute(0, 3, 1, 2) 339 | 340 | # interpolate position bias table if needed 341 | relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k] 342 | for table_key in relative_position_bias_table_keys: 343 | table_pretrained = state_dict[table_key] 344 | table_current = model.state_dict()[table_key] 345 | L1, nH1 = table_pretrained.size() 346 | L2, nH2 = table_current.size() 347 | if nH1 != nH2: 348 | logger.warning(f"Error in loading {table_key}, pass") 349 | else: 350 | if L1 != L2: 351 | S1 = int(L1 ** 0.5) 352 | S2 = int(L2 ** 0.5) 353 | table_pretrained_resized = F.interpolate( 354 | table_pretrained.permute(1, 0).view(1, nH1, S1, S1), 355 | size=(S2, S2), mode='bicubic') 356 | state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute(1, 0) 357 | 358 | # load state_dict 359 | load_state_dict(model, state_dict, strict, logger) 360 | return checkpoint 361 | 362 | 363 | def weights_to_cpu(state_dict): 364 | """Copy a model state_dict to cpu. 365 | 366 | Args: 367 | state_dict (OrderedDict): Model weights on GPU. 368 | 369 | Returns: 370 | OrderedDict: Model weights on GPU. 371 | """ 372 | state_dict_cpu = OrderedDict() 373 | for key, val in state_dict.items(): 374 | state_dict_cpu[key] = val.cpu() 375 | return state_dict_cpu 376 | 377 | 378 | def _save_to_state_dict(module, destination, prefix, keep_vars): 379 | """Saves module state to `destination` dictionary. 380 | 381 | This method is modified from :meth:`torch.nn.Module._save_to_state_dict`. 382 | 383 | Args: 384 | module (nn.Module): The module to generate state_dict. 385 | destination (dict): A dict where state will be stored. 386 | prefix (str): The prefix for parameters and buffers used in this 387 | module. 
388 | """ 389 | for name, param in module._parameters.items(): 390 | if param is not None: 391 | destination[prefix + name] = param if keep_vars else param.detach() 392 | for name, buf in module._buffers.items(): 393 | # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d 394 | if buf is not None: 395 | destination[prefix + name] = buf if keep_vars else buf.detach() 396 | 397 | 398 | def get_state_dict(module, destination=None, prefix='', keep_vars=False): 399 | """Returns a dictionary containing a whole state of the module. 400 | 401 | Both parameters and persistent buffers (e.g. running averages) are 402 | included. Keys are corresponding parameter and buffer names. 403 | 404 | This method is modified from :meth:`torch.nn.Module.state_dict` to 405 | recursively check parallel module in case that the model has a complicated 406 | structure, e.g., nn.Module(nn.Module(DDP)). 407 | 408 | Args: 409 | module (nn.Module): The module to generate state_dict. 410 | destination (OrderedDict): Returned dict for the state of the 411 | module. 412 | prefix (str): Prefix of the key. 413 | keep_vars (bool): Whether to keep the variable property of the 414 | parameters. Default: False. 415 | 416 | Returns: 417 | dict: A dictionary containing a whole state of the module. 418 | """ 419 | # recursively check parallel module in case that the model has a 420 | # complicated structure, e.g., nn.Module(nn.Module(DDP)) 421 | if is_module_wrapper(module): 422 | module = module.module 423 | 424 | # below is the same as torch.nn.Module.state_dict() 425 | if destination is None: 426 | destination = OrderedDict() 427 | destination._metadata = OrderedDict() 428 | destination._metadata[prefix[:-1]] = local_metadata = dict( 429 | version=module._version) 430 | _save_to_state_dict(module, destination, prefix, keep_vars) 431 | for name, child in module._modules.items(): 432 | if child is not None: 433 | get_state_dict( 434 | child, destination, prefix + name + '.', keep_vars=keep_vars) 435 | for hook in module._state_dict_hooks.values(): 436 | hook_result = hook(module, destination, prefix, local_metadata) 437 | if hook_result is not None: 438 | destination = hook_result 439 | return destination 440 | 441 | 442 | def save_checkpoint(model, filename, optimizer=None, meta=None): 443 | """Save checkpoint to file. 444 | 445 | The checkpoint will have 3 fields: ``meta``, ``state_dict`` and 446 | ``optimizer``. By default ``meta`` will contain version and time info. 447 | 448 | Args: 449 | model (Module): Module whose params are to be saved. 450 | filename (str): Checkpoint filename. 451 | optimizer (:obj:`Optimizer`, optional): Optimizer to be saved. 452 | meta (dict, optional): Metadata to be saved in checkpoint. 
453 | """ 454 | if meta is None: 455 | meta = {} 456 | elif not isinstance(meta, dict): 457 | raise TypeError(f'meta must be a dict or None, but got {type(meta)}') 458 | meta.update(mmcv_version=mmcv.__version__, time=time.asctime()) 459 | 460 | if is_module_wrapper(model): 461 | model = model.module 462 | 463 | if hasattr(model, 'CLASSES') and model.CLASSES is not None: 464 | # save class name to the meta 465 | meta.update(CLASSES=model.CLASSES) 466 | 467 | checkpoint = { 468 | 'meta': meta, 469 | 'state_dict': weights_to_cpu(get_state_dict(model)) 470 | } 471 | # save optimizer state dict in the checkpoint 472 | if isinstance(optimizer, Optimizer): 473 | checkpoint['optimizer'] = optimizer.state_dict() 474 | elif isinstance(optimizer, dict): 475 | checkpoint['optimizer'] = {} 476 | for name, optim in optimizer.items(): 477 | checkpoint['optimizer'][name] = optim.state_dict() 478 | 479 | if filename.startswith('pavi://'): 480 | try: 481 | from pavi import modelcloud 482 | from pavi.exception import NodeNotFoundError 483 | except ImportError: 484 | raise ImportError( 485 | 'Please install pavi to load checkpoint from modelcloud.') 486 | model_path = filename[7:] 487 | root = modelcloud.Folder() 488 | model_dir, model_name = osp.split(model_path) 489 | try: 490 | model = modelcloud.get(model_dir) 491 | except NodeNotFoundError: 492 | model = root.create_training_model(model_dir) 493 | with TemporaryDirectory() as tmp_dir: 494 | checkpoint_file = osp.join(tmp_dir, model_name) 495 | with open(checkpoint_file, 'wb') as f: 496 | torch.save(checkpoint, f) 497 | f.flush() 498 | model.create_file(checkpoint_file, name=model_name) 499 | else: 500 | mmcv.mkdir_or_exist(osp.dirname(filename)) 501 | # immediately flush buffer 502 | with open(filename, 'wb') as f: 503 | torch.save(checkpoint, f) 504 | f.flush() 505 | -------------------------------------------------------------------------------- /model/segmenter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .layers import Decoder 4 | import torch.nn.functional as F 5 | from bert.modeling_bert import BertModel 6 | 7 | def dice_loss(inputs, targets): 8 | """ 9 | Compute the DICE loss, similar to generalized IOU for masks 10 | Args: 11 | inputs: A float tensor of arbitrary shape. 12 | The predictions for each example. 13 | targets: A float tensor with the same shape as inputs. Stores the binary 14 | classification label for each element in inputs 15 | (0 for the negative class and 1 for the positive class). 16 | """ 17 | 18 | inputs = inputs.sigmoid() 19 | inputs = inputs.flatten(1) 20 | targets = targets.flatten(1) 21 | numerator = 2 * (inputs * targets).sum(1) 22 | denominator = inputs.sum(-1) + targets.sum(-1) 23 | loss = 1 - (numerator + 1) / (denominator + 1) 24 | return loss.mean() 25 | 26 | def sigmoid_focal_loss(inputs, targets, alpha: float = 0.25, gamma: float = 2): 27 | """ 28 | Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. 29 | Args: 30 | inputs: A float tensor of arbitrary shape. 31 | The predictions for each example. 32 | targets: A float tensor with the same shape as inputs. Stores the binary 33 | classification label for each element in inputs 34 | (0 for the negative class and 1 for the positive class). 35 | alpha: (optional) Weighting factor in range (0,1) to balance 36 | positive vs negative examples. Default = -1 (no weighting). 
37 | gamma: Exponent of the modulating factor (1 - p_t) to 38 | balance easy vs hard examples. 39 | Returns: 40 | Loss tensor 41 | """ 42 | 43 | prob = inputs.sigmoid() 44 | ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") 45 | p_t = prob * targets + (1 - prob) * (1 - targets) 46 | loss = ce_loss * ((1 - p_t) ** gamma) 47 | 48 | if alpha >= 0: 49 | alpha_t = alpha * targets + (1 - alpha) * (1 - targets) 50 | loss = alpha_t * loss 51 | return loss.mean() 52 | 53 | 54 | class CGFormer(nn.Module): 55 | def __init__(self, backbone, args): 56 | super(CGFormer, self).__init__() 57 | self.backbone = backbone 58 | self.decoder = Decoder(args) 59 | self.text_encoder = BertModel.from_pretrained(args.bert) 60 | self.text_encoder.pooler = None 61 | 62 | def forward(self, x, text, l_mask, mask=None): 63 | input_shape = x.shape[-2:] 64 | l_feats = self.text_encoder(text, attention_mask=l_mask)[0] # (6, 10, 768) 65 | l_feats = l_feats.permute(0, 2, 1) # (B, 768, N_l) to make Conv1d happy 66 | l_mask = l_mask.unsqueeze(dim=-1) # (batch, N_l, 1) 67 | ########################## 68 | features = self.backbone(x, l_feats, l_mask) 69 | x_c1, x_c2, x_c3, x_c4 = features 70 | pred, maps = self.decoder([x_c4, x_c3, x_c2, x_c1], l_feats, l_mask) 71 | pred = F.interpolate(pred, input_shape, mode='bilinear', align_corners=True) 72 | # loss 73 | if self.training: 74 | loss = 0. 75 | mask = mask.unsqueeze(1).float() 76 | for m, lam in zip(maps, [0.001,0.01,0.1]): 77 | m = m[:,1].unsqueeze(1) 78 | if m.shape[-2:] != mask.shape[-2:]: 79 | mask_ = F.interpolate(mask, m.shape[-2:], mode='nearest').detach() 80 | loss += dice_loss(m, mask_) * lam 81 | loss += dice_loss(pred, mask) + sigmoid_focal_loss(pred, mask, alpha=-1, gamma=0) 82 | return pred.detach(), mask, loss 83 | else: 84 | return pred.detach(), maps -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers 2 | wandb 3 | lmdb 4 | pyarrow 5 | regex 6 | ftfy 7 | loguru 8 | pycocotools 9 | matplotlib 10 | tqdm -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import warnings 4 | 5 | import cv2 6 | import torch 7 | import torch.nn.parallel 8 | import torch.utils.data 9 | from loguru import logger 10 | 11 | import utils.config as config 12 | from engine.engine import inference 13 | from model import build_segmenter 14 | from utils.dataset import RefDataset 15 | from utils.misc import setup_logger 16 | 17 | warnings.filterwarnings("ignore") 18 | cv2.setNumThreads(0) 19 | 20 | 21 | def get_parser(): 22 | parser = argparse.ArgumentParser( 23 | description='Pytorch Referring Expression Segmentation') 24 | parser.add_argument('--config', 25 | default='path to xxx.yaml', 26 | type=str, 27 | help='config file') 28 | parser.add_argument('--opts', 29 | default=None, 30 | nargs=argparse.REMAINDER, 31 | help='override some settings in the config.') 32 | args = parser.parse_args() 33 | assert args.config is not None 34 | cfg = config.load_cfg_from_cfg_file(args.config) 35 | if args.opts is not None: 36 | cfg = config.merge_cfg_from_list(cfg, args.opts) 37 | return cfg 38 | 39 | 40 | @logger.catch 41 | def main(): 42 | args = get_parser() 43 | args.output_dir = os.path.join(args.output_folder, args.exp_name) 44 | if args.visualize: 45 | args.vis_dir = 
os.path.join(args.output_dir, "vis") 46 | os.makedirs(args.vis_dir, exist_ok=True) 47 | 48 | # logger 49 | setup_logger(args.output_dir, 50 | distributed_rank=0, 51 | filename="test.log", 52 | mode="a") 53 | logger.info(args.test_split) 54 | 55 | # build dataset & dataloader 56 | test_data = RefDataset(lmdb_dir=args.test_lmdb, 57 | mask_dir=args.mask_root, 58 | dataset=args.dataset, 59 | split=args.test_split, 60 | mode='test', 61 | input_size=args.input_size, 62 | word_length=args.word_len) 63 | test_loader = torch.utils.data.DataLoader(test_data, 64 | batch_size=1, 65 | shuffle=False, 66 | num_workers=4, 67 | pin_memory=True) 68 | 69 | # build model 70 | model = build_segmenter(args, DDP=False) 71 | logger.info(model) 72 | 73 | if os.path.isfile(args.weight): 74 | logger.info("=> loading checkpoint '{}'".format(args.weight)) 75 | checkpoint = torch.load(args.weight) 76 | model.module.load_state_dict(checkpoint['model_state_dict'], strict=True) 77 | logger.info("=> loaded checkpoint '{}'".format(args.weight)) 78 | else: 79 | raise ValueError( 80 | "=> resume failed! no checkpoint found at '{}'. Please check args.resume again!" 81 | .format(args.weight)) 82 | 83 | # inference 84 | inference(test_loader, model, args) 85 | 86 | 87 | if __name__ == '__main__': 88 | main() 89 | -------------------------------------------------------------------------------- /tools/data_process.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import cv2 6 | import numpy as np 7 | from tqdm import tqdm 8 | 9 | from refer import REFER 10 | 11 | parser = argparse.ArgumentParser(description='Data preparation') 12 | parser.add_argument('--data_root', type=str) 13 | parser.add_argument('--output_dir', type=str) 14 | parser.add_argument('--dataset', 15 | type=str, 16 | choices=['refcoco', 'refcoco+', 'refcocog', 'refclef'], 17 | default='refcoco') 18 | parser.add_argument('--split', type=str, default='umd') 19 | parser.add_argument('--generate_mask', action='store_true') 20 | args = parser.parse_args() 21 | img_path = os.path.join(args.data_root, 'images', 'train2014') 22 | 23 | h, w = (416, 416) 24 | 25 | refer = REFER(args.data_root, args.dataset, args.split) 26 | 27 | print('dataset [%s_%s] contains: ' % (args.dataset, args.split)) 28 | ref_ids = refer.getRefIds() 29 | image_ids = refer.getImgIds() 30 | print('%s expressions for %s refs in %s images.' % 31 | (len(refer.Sents), len(ref_ids), len(image_ids))) 32 | 33 | print('\nAmong them:') 34 | if args.dataset == 'refclef': 35 | if args.split == 'unc': 36 | splits = ['train', 'val', 'testA', 'testB', 'testC'] 37 | else: 38 | splits = ['train', 'val', 'test'] 39 | elif args.dataset == 'refcoco': 40 | splits = ['train', 'val', 'testA', 'testB'] 41 | elif args.dataset == 'refcoco+': 42 | splits = ['train', 'val', 'testA', 'testB'] 43 | elif args.dataset == 'refcocog': 44 | splits = ['train', 'val', 45 | 'test'] # we don't have test split for refcocog right now. 46 | 47 | for split in splits: 48 | ref_ids = refer.getRefIds(split=split) 49 | print('%s refs are in split [%s].' 
% (len(ref_ids), split)) 50 | 51 | 52 | def cat_process(cat): 53 | if cat >= 1 and cat <= 11: 54 | cat = cat - 1 55 | elif cat >= 13 and cat <= 25: 56 | cat = cat - 2 57 | elif cat >= 27 and cat <= 28: 58 | cat = cat - 3 59 | elif cat >= 31 and cat <= 44: 60 | cat = cat - 5 61 | elif cat >= 46 and cat <= 65: 62 | cat = cat - 6 63 | elif cat == 67: 64 | cat = cat - 7 65 | elif cat == 70: 66 | cat = cat - 9 67 | elif cat >= 72 and cat <= 82: 68 | cat = cat - 10 69 | elif cat >= 84 and cat <= 90: 70 | cat = cat - 11 71 | return cat 72 | 73 | 74 | def bbox_process(bbox): 75 | x_min = int(bbox[0]) 76 | y_min = int(bbox[1]) 77 | x_max = x_min + int(bbox[2]) 78 | y_max = y_min + int(bbox[3]) 79 | return list(map(int, [x_min, y_min, x_max, y_max])) 80 | 81 | 82 | def prepare_dataset(dataset, splits, output_dir, generate_mask=False): 83 | ann_path = os.path.join(output_dir, 'anns', dataset) 84 | mask_path = os.path.join(output_dir, 'masks', dataset) 85 | if not os.path.exists(ann_path): 86 | os.makedirs(ann_path) 87 | if not os.path.exists(mask_path): 88 | os.makedirs(mask_path) 89 | 90 | for split in splits: 91 | dataset_array = [] 92 | ref_ids = refer.getRefIds(split=split) 93 | print('Processing split:{} - Len: {}'.format(split, np.alen(ref_ids))) 94 | for i in tqdm(ref_ids): 95 | ref_dict = {} 96 | 97 | refs = refer.Refs[i] 98 | bboxs = refer.getRefBox(i) 99 | sentences = refs['sentences'] 100 | image_urls = refer.loadImgs(image_ids=refs['image_id'])[0] 101 | cat = cat_process(refs['category_id']) 102 | image_urls = image_urls['file_name'] 103 | if dataset == 'refclef' and image_urls in [ 104 | '19579.jpg', '17975.jpg', '19575.jpg' 105 | ]: 106 | continue 107 | box_info = bbox_process(bboxs) 108 | 109 | ref_dict['bbox'] = box_info 110 | ref_dict['cat'] = cat 111 | ref_dict['segment_id'] = i 112 | ref_dict['img_name'] = image_urls 113 | 114 | if generate_mask: 115 | cv2.imwrite(os.path.join(mask_path, 116 | str(i) + '.png'), 117 | refer.getMask(refs)['mask'] * 255) 118 | 119 | sent_dict = [] 120 | for i, sent in enumerate(sentences): 121 | sent_dict.append({ 122 | 'idx': i, 123 | 'sent_id': sent['sent_id'], 124 | 'sent': sent['sent'].strip() 125 | }) 126 | 127 | ref_dict['sentences'] = sent_dict 128 | ref_dict['sentences_num'] = len(sent_dict) 129 | 130 | dataset_array.append(ref_dict) 131 | print('Dumping json file...') 132 | with open(os.path.join(output_dir, 'anns', dataset, split + '.json'), 133 | 'w') as f: 134 | json.dump(dataset_array, f) 135 | 136 | 137 | prepare_dataset(args.dataset, splits, args.output_dir, args.generate_mask) 138 | -------------------------------------------------------------------------------- /tools/folder2lmdb.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | import lmdb 5 | import pyarrow as pa 6 | import json 7 | from tqdm import tqdm 8 | import warnings 9 | warnings.filterwarnings("ignore") 10 | 11 | 12 | def loads_pyarrow(buf): 13 | """ 14 | Args: 15 | buf: the output of `dumps`. 16 | """ 17 | return pa.deserialize(buf) 18 | 19 | 20 | def raw_reader(path): 21 | with open(path, 'rb') as f: 22 | bin_data = f.read() 23 | return bin_data 24 | 25 | 26 | def dumps_pyarrow(obj): 27 | """ 28 | Serialize an object. 
29 | Returns: 30 | Implementation-dependent bytes-like object 31 | """ 32 | return pa.serialize(obj).to_buffer() 33 | 34 | 35 | def folder2lmdb(json_data, img_dir, mask_dir, output_dir, split, write_frequency=1000): 36 | lmdb_path = osp.join(output_dir, "%s.lmdb" % split) 37 | isdir = os.path.isdir(lmdb_path) 38 | 39 | print("Generate LMDB to %s" % lmdb_path) 40 | db = lmdb.open(lmdb_path, subdir=isdir, 41 | map_size=1099511627776 * 2, readonly=False, 42 | meminit=False, map_async=True) 43 | 44 | txn = db.begin(write=True) 45 | tbar = tqdm(json_data) 46 | for idx, item in enumerate(tbar): 47 | img = raw_reader(osp.join(img_dir, item['img_name'])) 48 | mask = raw_reader(osp.join(mask_dir, f"{item['segment_id']}.png")) 49 | data = {'img': img, 'mask': mask, 'cat': item['cat'], 50 | 'seg_id': item['segment_id'], 'img_name': item['img_name'], 51 | 'num_sents': item['sentences_num'], 'sents': [i['sent'] for i in item['sentences']]} 52 | txn.put(u'{}'.format(idx).encode('ascii'), dumps_pyarrow(data)) 53 | if idx % write_frequency == 0: 54 | # print("[%d/%d]" % (idx, len(data_loader))) 55 | txn.commit() 56 | txn = db.begin(write=True) 57 | 58 | # finish iterating through dataset 59 | txn.commit() 60 | keys = [u'{}'.format(k).encode('ascii') for k in range(idx + 1)] 61 | with db.begin(write=True) as txn: 62 | txn.put(b'__keys__', dumps_pyarrow(keys)) 63 | txn.put(b'__len__', dumps_pyarrow(len(keys))) 64 | 65 | print("Flushing database ...") 66 | db.sync() 67 | db.close() 68 | 69 | 70 | def parse_args(): 71 | parser = argparse.ArgumentParser(description='COCO Folder to LMDB.') 72 | parser.add_argument('-j', '--json-dir', type=str, 73 | default='', 74 | help='the name of json file.') 75 | parser.add_argument('-i', '--img-dir', type=str, 76 | default='refcoco+', 77 | help='the folder of images.') 78 | parser.add_argument('-m', '--mask-dir', type=str, 79 | default='refcoco+', 80 | help='the folder of masks.') 81 | parser.add_argument('-o', '--output-dir', type=str, 82 | default='refcoco+', 83 | help='the folder of output lmdb file.') 84 | parser.add_argument('-s', '--split', type=str, 85 | default='train', 86 | help='the split type.') 87 | args = parser.parse_args() 88 | return args 89 | 90 | 91 | if __name__ == '__main__': 92 | args = parse_args() 93 | args.split = osp.basename(args.json_dir).split(".")[0] 94 | os.makedirs(args.output_dir, exist_ok=True) 95 | 96 | with open(args.json_dir, 'rb') as f: 97 | json_data = json.load(f) 98 | 99 | folder2lmdb(json_data, args.img_dir, args.mask_dir, args.output_dir, args.split) 100 | -------------------------------------------------------------------------------- /tools/latency.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | import time 4 | import warnings 5 | 6 | sys.path.append('./') 7 | warnings.filterwarnings("ignore") 8 | 9 | import torch 10 | import torch.backends.cudnn as cudnn 11 | import utils.config as config 12 | from model import build_segmenter 13 | 14 | 15 | def get_parser(): 16 | parser = argparse.ArgumentParser( 17 | description='Pytorch Referring Expression Segmentation') 18 | parser.add_argument('--config', 19 | default='path to xxx.yaml', 20 | type=str, 21 | help='config file') 22 | parser.add_argument('--opts', 23 | default=None, 24 | nargs=argparse.REMAINDER, 25 | help='override some settings in the config.') 26 | args = parser.parse_args() 27 | assert args.config is not None 28 | cfg = config.load_cfg_from_cfg_file(args.config) 29 | if args.opts is not None: 30 | 
cfg = config.merge_cfg_from_list(cfg, args.opts) 31 | return cfg 32 | 33 | 34 | def count_parameters(model): 35 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 36 | 37 | 38 | def main(): 39 | # init arguments 40 | args = get_parser() 41 | torch.cuda.set_device(0) 42 | # create model 43 | model, _ = build_segmenter(args) 44 | model = model.cuda() 45 | model.eval() 46 | # set cudnn state 47 | cudnn.benchmark = True 48 | cudnn.deterministic = False 49 | cudnn.enabled = True 50 | # init dummy tensor 51 | image = torch.randn(1, 3, 416, 416).cuda() 52 | text = torch.randint(4096, size=(1, args.word_len)).long().cuda() 53 | # init time & memory 54 | avg_time = 0 55 | avg_mem = 0 56 | # record initial gpu memory 57 | mem = torch.cuda.max_memory_allocated() 58 | 59 | with torch.no_grad(): 60 | for i in range(500): 61 | start_time = time.time() 62 | _ = model(image, text) 63 | torch.cuda.synchronize() 64 | if (i+1) >= 100: 65 | avg_time += (time.time() - start_time) 66 | avg_mem += (torch.cuda.max_memory_allocated() - mem) / 1.073742e9 67 | params = count_parameters(model) * 1e-6 68 | print('#########################################') 69 | print("Average Parameters : {:.2f} M".format(params)) 70 | print("Average FPS: {:.2f}".format(400/avg_time)) 71 | print("Average GPU Memory: {:.2f} GB".format(avg_mem/400)) 72 | print('#########################################') 73 | 74 | 75 | if __name__ == '__main__': 76 | main() 77 | -------------------------------------------------------------------------------- /tools/prepare_datasets.md: -------------------------------------------------------------------------------- 1 | ## Prepare datasets 2 | 3 | In our paper, we conduct experiments on three commonly used datasets: Ref-COCO, Ref-COCO+ and G-Ref. 4 | 5 | ### 1. COCO 2014 6 | 7 | The data can be found [here](https://cocodataset.org/#download). Please run the following commands to download. 8 | 9 | ```shell 10 | # download 11 | mkdir datasets && cd datasets 12 | wget http://images.cocodataset.org/zips/train2014.zip 13 | 14 | # unzip 15 | unzip train2014.zip -d images/ && rm train2014.zip 16 | 17 | ``` 18 | 19 | ### 2. Ref-COCO 20 | 21 | The data can be found [here](https://github.com/lichengunc/refer). Please run the following commands to download and convert. 22 | 23 | ```shell 24 | # download 25 | wget https://bvisionweb1.cs.unc.edu/licheng/referit/data/refcoco.zip 26 | 27 | # unzip 28 | unzip refcoco.zip && rm refcoco.zip 29 | 30 | # convert 31 | python ../tools/data_process.py --data_root . --output_dir . --dataset refcoco --split unc --generate_mask 32 | 33 | # lmdb 34 | python ../tools/folder2lmdb.py -j anns/refcoco/train.json -i images/train2014/ -m masks/refcoco -o lmdb/refcoco 35 | python ../tools/folder2lmdb.py -j anns/refcoco/val.json -i images/train2014/ -m masks/refcoco -o lmdb/refcoco 36 | python ../tools/folder2lmdb.py -j anns/refcoco/testA.json -i images/train2014/ -m masks/refcoco -o lmdb/refcoco 37 | python ../tools/folder2lmdb.py -j anns/refcoco/testB.json -i images/train2014/ -m masks/refcoco -o lmdb/refcoco 38 | 39 | # clean 40 | rm -r refcoco 41 | 42 | ``` 43 | 44 | ### 3. Ref-COCO+ 45 | 46 | The data can be found [here](https://github.com/lichengunc/refer). Please run the following commands to download and convert.
47 | 48 | ```shell 49 | # download 50 | wget https://bvisionweb1.cs.unc.edu/licheng/referit/data/refcoco+.zip 51 | 52 | # unzip 53 | unzip refcoco+.zip && rm refcoco+.zip 54 | 55 | # convert 56 | python ../tools/data_process.py --data_root . --output_dir . --dataset refcoco+ --split unc --generate_mask 57 | 58 | # lmdb 59 | python ../tools/folder2lmdb.py -j anns/refcoco+/train.json -i images/train2014/ -m masks/refcoco+ -o lmdb/refcoco+ 60 | python ../tools/folder2lmdb.py -j anns/refcoco+/val.json -i images/train2014/ -m masks/refcoco+ -o lmdb/refcoco+ 61 | python ../tools/folder2lmdb.py -j anns/refcoco+/testA.json -i images/train2014/ -m masks/refcoco+ -o lmdb/refcoco+ 62 | python ../tools/folder2lmdb.py -j anns/refcoco+/testB.json -i images/train2014/ -m masks/refcoco+ -o lmdb/refcoco+ 63 | 64 | # clean 65 | rm -r refcoco+ 66 | 67 | ``` 68 | 69 | ### 4. Ref-COCOg 70 | 71 | The data can be found [here](https://github.com/lichengunc/refer). Please run the following commands to download and convert. 72 | (Note that we adopt two different splits of this dataset, 'umd' and 'google'.) 73 | 74 | ```shell 75 | # download 76 | wget https://bvisionweb1.cs.unc.edu/licheng/referit/data/refcocog.zip 77 | 78 | # unzip 79 | unzip refcocog.zip && rm refcocog.zip 80 | 81 | # convert 82 | python ../tools/data_process.py --data_root . --output_dir . --dataset refcocog --split umd --generate_mask # umd split 83 | mv anns/refcocog anns/refcocog_u 84 | mv masks/refcocog masks/refcocog_u 85 | 86 | python ../tools/data_process.py --data_root . --output_dir . --dataset refcocog --split google --generate_mask # google split 87 | mv anns/refcocog anns/refcocog_g 88 | mv masks/refcocog masks/refcocog_g 89 | 90 | # lmdb 91 | python ../tools/folder2lmdb.py -j anns/refcocog_u/train.json -i images/train2014/ -m masks/refcocog_u -o lmdb/refcocog_u 92 | python ../tools/folder2lmdb.py -j anns/refcocog_u/val.json -i images/train2014/ -m masks/refcocog_u -o lmdb/refcocog_u 93 | python ../tools/folder2lmdb.py -j anns/refcocog_u/test.json -i images/train2014/ -m masks/refcocog_u -o lmdb/refcocog_u 94 | 95 | python ../tools/folder2lmdb.py -j anns/refcocog_g/train.json -i images/train2014/ -m masks/refcocog_g -o lmdb/refcocog_g 96 | python ../tools/folder2lmdb.py -j anns/refcocog_g/val.json -i images/train2014/ -m masks/refcocog_g -o lmdb/refcocog_g 97 | 98 | rm -r refcocog 99 | 100 | ``` 101 | 102 | ### 5. Dataset structure
103 | 104 | After running the above commands, the structure of the dataset folder should look like: 105 | 106 | ```none 107 | datasets 108 | ├── anns 109 | │ ├── refcoco 110 | │ │ ├── xxx.json 111 | │ ├── refcoco+ 112 | │ │ ├── xxx.json 113 | │ ├── refcocog_g 114 | │ │ ├── xxx.json 115 | │ ├── refcocog_u 116 | │ │ ├── xxx.json 117 | ├── images 118 | │ ├── train2014 119 | │ │ ├── xxx.jpg 120 | ├── lmdb 121 | │ ├── refcoco 122 | │ │ ├── xxx.lmdb 123 | │ │ ├── xxx.lmdb-lock 124 | │ ├── refcoco+ 125 | │ │ ├── xxx.lmdb 126 | │ │ ├── xxx.lmdb-lock 127 | │ ├── refcocog_g 128 | │ │ ├── xxx.lmdb 129 | │ │ ├── xxx.lmdb-lock 130 | │ ├── refcocog_u 131 | │ │ ├── xxx.lmdb 132 | │ │ ├── xxx.lmdb-lock 133 | ├── masks 134 | │ ├── refcoco 135 | │ │ ├── xxx.png 136 | │ ├── refcoco+ 137 | │ │ ├── xxx.png 138 | │ ├── refcocog_g 139 | │ │ ├── xxx.png 140 | │ ├── refcocog_u 141 | │ │ ├── xxx.png 142 | 143 | ``` -------------------------------------------------------------------------------- /tools/refer.py: -------------------------------------------------------------------------------- 1 | __author__ = 'licheng' 2 | """ 3 | This interface provides access to four datasets: 4 | 1) refclef 5 | 2) refcoco 6 | 3) refcoco+ 7 | 4) refcocog 8 | split by unc and google 9 | The following API functions are defined: 10 | REFER - REFER api class 11 | getRefIds - get ref ids that satisfy given filter conditions. 12 | getAnnIds - get ann ids that satisfy given filter conditions. 13 | getImgIds - get image ids that satisfy given filter conditions. 14 | getCatIds - get category ids that satisfy given filter conditions. 15 | loadRefs - load refs with the specified ref ids. 16 | loadAnns - load anns with the specified ann ids. 17 | loadImgs - load images with the specified image ids. 18 | loadCats - load category names with the specified category ids. 19 | getRefBox - get ref's bounding box [x, y, w, h] given the ref_id 20 | showRef - show image, segmentation or box of the referred object with the ref 21 | getMask - get mask and area of the referred object given ref 22 | showMask - show mask of the referred object given ref 23 | """ 24 | 25 | import itertools 26 | import json 27 | import os.path as osp 28 | import pickle 29 | import sys 30 | import time 31 | from pprint import pprint 32 | 33 | import matplotlib.pyplot as plt 34 | import numpy as np 35 | import skimage.io as io 36 | from matplotlib.collections import PatchCollection 37 | from matplotlib.patches import Polygon, Rectangle 38 | from pycocotools import mask 39 | 40 | 41 | class REFER: 42 | def __init__(self, data_root, dataset='refcoco', splitBy='unc'): 43 | # provide data_root folder which contains refclef, refcoco, refcoco+ and refcocog 44 | # also provide dataset name and splitBy information 45 | # e.g., dataset = 'refcoco', splitBy = 'unc' 46 | print('loading dataset %s into memory...'
% dataset) 47 | self.ROOT_DIR = osp.abspath(osp.dirname(__file__)) 48 | self.DATA_DIR = osp.join(data_root, dataset) 49 | if dataset in ['refcoco', 'refcoco+', 'refcocog']: 50 | self.IMAGE_DIR = osp.join(data_root, 'images/train2014') 51 | elif dataset == 'refclef': 52 | self.IMAGE_DIR = osp.join(data_root, 'images/saiapr_tc-12') 53 | else: 54 | print('No refer dataset is called [%s]' % dataset) 55 | sys.exit() 56 | 57 | # load refs from data/dataset/refs(dataset).json 58 | tic = time.time() 59 | ref_file = osp.join(self.DATA_DIR, 'refs(' + splitBy + ').p') 60 | self.data = {} 61 | self.data['dataset'] = dataset 62 | 63 | self.data['refs'] = pickle.load(open(ref_file, 'rb'), fix_imports=True) 64 | 65 | # load annotations from data/dataset/instances.json 66 | instances_file = osp.join(self.DATA_DIR, 'instances.json') 67 | instances = json.load(open(instances_file, 'r')) 68 | self.data['images'] = instances['images'] 69 | self.data['annotations'] = instances['annotations'] 70 | self.data['categories'] = instances['categories'] 71 | 72 | # create index 73 | self.createIndex() 74 | print('DONE (t=%.2fs)' % (time.time() - tic)) 75 | 76 | def createIndex(self): 77 | # create sets of mapping 78 | # 1) Refs: {ref_id: ref} 79 | # 2) Anns: {ann_id: ann} 80 | # 3) Imgs: {image_id: image} 81 | # 4) Cats: {category_id: category_name} 82 | # 5) Sents: {sent_id: sent} 83 | # 6) imgToRefs: {image_id: refs} 84 | # 7) imgToAnns: {image_id: anns} 85 | # 8) refToAnn: {ref_id: ann} 86 | # 9) annToRef: {ann_id: ref} 87 | # 10) catToRefs: {category_id: refs} 88 | # 11) sentToRef: {sent_id: ref} 89 | # 12) sentToTokens: {sent_id: tokens} 90 | print('creating index...') 91 | # fetch info from instances 92 | Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {} 93 | for ann in self.data['annotations']: 94 | Anns[ann['id']] = ann 95 | imgToAnns[ann['image_id']] = imgToAnns.get(ann['image_id'], 96 | []) + [ann] 97 | for img in self.data['images']: 98 | Imgs[img['id']] = img 99 | for cat in self.data['categories']: 100 | Cats[cat['id']] = cat['name'] 101 | 102 | # fetch info from refs 103 | Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {} 104 | Sents, sentToRef, sentToTokens = {}, {}, {} 105 | for ref in self.data['refs']: 106 | # ids 107 | ref_id = ref['ref_id'] 108 | ann_id = ref['ann_id'] 109 | category_id = ref['category_id'] 110 | image_id = ref['image_id'] 111 | 112 | # add mapping related to ref 113 | Refs[ref_id] = ref 114 | imgToRefs[image_id] = imgToRefs.get(image_id, []) + [ref] 115 | catToRefs[category_id] = catToRefs.get(category_id, []) + [ref] 116 | refToAnn[ref_id] = Anns[ann_id] 117 | annToRef[ann_id] = ref 118 | 119 | # add mapping of sent 120 | for sent in ref['sentences']: 121 | Sents[sent['sent_id']] = sent 122 | sentToRef[sent['sent_id']] = ref 123 | sentToTokens[sent['sent_id']] = sent['tokens'] 124 | 125 | # create class members 126 | self.Refs = Refs 127 | self.Anns = Anns 128 | self.Imgs = Imgs 129 | self.Cats = Cats 130 | self.Sents = Sents 131 | self.imgToRefs = imgToRefs 132 | self.imgToAnns = imgToAnns 133 | self.refToAnn = refToAnn 134 | self.annToRef = annToRef 135 | self.catToRefs = catToRefs 136 | self.sentToRef = sentToRef 137 | self.sentToTokens = sentToTokens 138 | print('index created.') 139 | 140 | def getRefIds(self, image_ids=[], cat_ids=[], ref_ids=[], split=''): 141 | image_ids = image_ids if type(image_ids) == list else [image_ids] 142 | cat_ids = cat_ids if type(cat_ids) == list else [cat_ids] 143 | ref_ids = ref_ids if type(ref_ids) == list else [ref_ids] 144 | 
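# The scalar arguments were normalized to lists above; the branches below filter refs by image id, category id, ref id and finally by split. Note that asking for 'testA' matches every ref whose split string contains 'A', so combined splits such as 'testAB' are also included.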
145 | if len(image_ids) == len(cat_ids) == len(ref_ids) == len(split) == 0: 146 | refs = self.data['refs'] 147 | else: 148 | if not len(image_ids) == 0: 149 | refs = [self.imgToRefs[image_id] for image_id in image_ids] 150 | else: 151 | refs = self.data['refs'] 152 | if not len(cat_ids) == 0: 153 | refs = [ref for ref in refs if ref['category_id'] in cat_ids] 154 | if not len(ref_ids) == 0: 155 | refs = [ref for ref in refs if ref['ref_id'] in ref_ids] 156 | if not len(split) == 0: 157 | if split in ['testA', 'testB', 'testC']: 158 | refs = [ref for ref in refs if split[-1] in ref['split'] 159 | ] # we also consider testAB, testBC, ... 160 | elif split in ['testAB', 'testBC', 'testAC']: 161 | refs = [ref for ref in refs 162 | if ref['split'] == split] # rarely used I guess... 163 | elif split == 'test': 164 | refs = [ref for ref in refs if 'test' in ref['split']] 165 | elif split == 'train' or split == 'val': 166 | refs = [ref for ref in refs if ref['split'] == split] 167 | else: 168 | print('No such split [%s]' % split) 169 | sys.exit() 170 | ref_ids = [ref['ref_id'] for ref in refs] 171 | return ref_ids 172 | 173 | def getAnnIds(self, image_ids=[], cat_ids=[], ref_ids=[]): 174 | image_ids = image_ids if type(image_ids) == list else [image_ids] 175 | cat_ids = cat_ids if type(cat_ids) == list else [cat_ids] 176 | ref_ids = ref_ids if type(ref_ids) == list else [ref_ids] 177 | 178 | if len(image_ids) == len(cat_ids) == len(ref_ids) == 0: 179 | ann_ids = [ann['id'] for ann in self.data['annotations']] 180 | else: 181 | if not len(image_ids) == 0: 182 | lists = [ 183 | self.imgToAnns[image_id] for image_id in image_ids 184 | if image_id in self.imgToAnns 185 | ] # list of [anns] 186 | anns = list(itertools.chain.from_iterable(lists)) 187 | else: 188 | anns = self.data['annotations'] 189 | if not len(cat_ids) == 0: 190 | anns = [ann for ann in anns if ann['category_id'] in cat_ids] 191 | ann_ids = [ann['id'] for ann in anns] 192 | if not len(ref_ids) == 0: 193 | ids = set(ann_ids).intersection( 194 | set([self.Refs[ref_id]['ann_id'] for ref_id in ref_ids])) 195 | return ann_ids 196 | 197 | def getImgIds(self, ref_ids=[]): 198 | ref_ids = ref_ids if type(ref_ids) == list else [ref_ids] 199 | 200 | if not len(ref_ids) == 0: 201 | image_ids = list( 202 | set([self.Refs[ref_id]['image_id'] for ref_id in ref_ids])) 203 | else: 204 | image_ids = self.Imgs.keys() 205 | return image_ids 206 | 207 | def getCatIds(self): 208 | return self.Cats.keys() 209 | 210 | def loadRefs(self, ref_ids=[]): 211 | if type(ref_ids) == list: 212 | return [self.Refs[ref_id] for ref_id in ref_ids] 213 | elif type(ref_ids) == int: 214 | return [self.Refs[ref_ids]] 215 | 216 | def loadAnns(self, ann_ids=[]): 217 | if type(ann_ids) == list: 218 | return [self.Anns[ann_id] for ann_id in ann_ids] 219 | elif type(ann_ids) == int or type(ann_ids) == unicode: 220 | return [self.Anns[ann_ids]] 221 | 222 | def loadImgs(self, image_ids=[]): 223 | if type(image_ids) == list: 224 | return [self.Imgs[image_id] for image_id in image_ids] 225 | elif type(image_ids) == int: 226 | return [self.Imgs[image_ids]] 227 | 228 | def loadCats(self, cat_ids=[]): 229 | if type(cat_ids) == list: 230 | return [self.Cats[cat_id] for cat_id in cat_ids] 231 | elif type(cat_ids) == int: 232 | return [self.Cats[cat_ids]] 233 | 234 | def getRefBox(self, ref_id): 235 | ref = self.Refs[ref_id] 236 | ann = self.refToAnn[ref_id] 237 | return ann['bbox'] # [x, y, w, h] 238 | 239 | def showRef(self, ref, seg_box='seg'): 240 | ax = plt.gca() 241 | # show image 
242 | image = self.Imgs[ref['image_id']] 243 | I = io.imread(osp.join(self.IMAGE_DIR, image['file_name'])) 244 | ax.imshow(I) 245 | # show refer expression 246 | for sid, sent in enumerate(ref['sentences']): 247 | print('%s. %s' % (sid + 1, sent['sent'])) 248 | # show segmentations 249 | if seg_box == 'seg': 250 | ann_id = ref['ann_id'] 251 | ann = self.Anns[ann_id] 252 | polygons = [] 253 | color = [] 254 | c = 'none' 255 | if type(ann['segmentation'][0]) == list: 256 | # polygon used for refcoco* 257 | for seg in ann['segmentation']: 258 | poly = np.array(seg).reshape((len(seg) / 2, 2)) 259 | polygons.append(Polygon(poly, True, alpha=0.4)) 260 | color.append(c) 261 | p = PatchCollection(polygons, 262 | facecolors=color, 263 | edgecolors=(1, 1, 0, 0), 264 | linewidths=3, 265 | alpha=1) 266 | ax.add_collection(p) # thick yellow polygon 267 | p = PatchCollection(polygons, 268 | facecolors=color, 269 | edgecolors=(1, 0, 0, 0), 270 | linewidths=1, 271 | alpha=1) 272 | ax.add_collection(p) # thin red polygon 273 | else: 274 | # mask used for refclef 275 | rle = ann['segmentation'] 276 | m = mask.decode(rle) 277 | img = np.ones((m.shape[0], m.shape[1], 3)) 278 | color_mask = np.array([2.0, 166.0, 101.0]) / 255 279 | for i in range(3): 280 | img[:, :, i] = color_mask[i] 281 | ax.imshow(np.dstack((img, m * 0.5))) 282 | # show bounding-box 283 | elif seg_box == 'box': 284 | ann_id = ref['ann_id'] 285 | ann = self.Anns[ann_id] 286 | bbox = self.getRefBox(ref['ref_id']) 287 | box_plot = Rectangle((bbox[0], bbox[1]), 288 | bbox[2], 289 | bbox[3], 290 | fill=False, 291 | edgecolor='green', 292 | linewidth=3) 293 | ax.add_patch(box_plot) 294 | 295 | def getMask(self, ref): 296 | # return mask, area and mask-center 297 | ann = self.refToAnn[ref['ref_id']] 298 | image = self.Imgs[ref['image_id']] 299 | if type(ann['segmentation'][0]) == list: # polygon 300 | rle = mask.frPyObjects(ann['segmentation'], image['height'], 301 | image['width']) 302 | else: 303 | rle = ann['segmentation'] 304 | 305 | # for i in range(len(rle['counts'])): 306 | # print(rle) 307 | m = mask.decode(rle) 308 | m = np.sum( 309 | m, axis=2 310 | ) # sometimes there are multiple binary map (corresponding to multiple segs) 311 | m = m.astype(np.uint8) # convert to np.uint8 312 | # compute area 313 | area = sum(mask.area(rle)) # should be close to ann['area'] 314 | return {'mask': m, 'area': area} 315 | # # position 316 | # position_x = np.mean(np.where(m==1)[1]) # [1] means columns (matlab style) -> x (c style) 317 | # position_y = np.mean(np.where(m==1)[0]) # [0] means rows (matlab style) -> y (c style) 318 | # # mass position (if there were multiple regions, we use the largest one.) 319 | # label_m = label(m, connectivity=m.ndim) 320 | # regions = regionprops(label_m) 321 | # if len(regions) > 0: 322 | # largest_id = np.argmax(np.array([props.filled_area for props in regions])) 323 | # largest_props = regions[largest_id] 324 | # mass_y, mass_x = largest_props.centroid 325 | # else: 326 | # mass_x, mass_y = position_x, position_y 327 | # # if centroid is not in mask, we find the closest point to it from mask 328 | # if m[mass_y, mass_x] != 1: 329 | # print 'Finding closes mask point ...' 
330 | # kernel = np.ones((10, 10),np.uint8) 331 | # me = cv2.erode(m, kernel, iterations = 1) 332 | # points = zip(np.where(me == 1)[0].tolist(), np.where(me == 1)[1].tolist()) # row, col style 333 | # points = np.array(points) 334 | # dist = np.sum((points - (mass_y, mass_x))**2, axis=1) 335 | # id = np.argsort(dist)[0] 336 | # mass_y, mass_x = points[id] 337 | # # return 338 | # return {'mask': m, 'area': area, 'position_x': position_x, 'position_y': position_y, 'mass_x': mass_x, 'mass_y': mass_y} 339 | # # show image and mask 340 | # I = io.imread(osp.join(self.IMAGE_DIR, image['file_name'])) 341 | # plt.figure() 342 | # plt.imshow(I) 343 | # ax = plt.gca() 344 | # img = np.ones( (m.shape[0], m.shape[1], 3) ) 345 | # color_mask = np.array([2.0,166.0,101.0])/255 346 | # for i in range(3): 347 | # img[:,:,i] = color_mask[i] 348 | # ax.imshow(np.dstack( (img, m*0.5) )) 349 | # plt.show() 350 | 351 | def showMask(self, ref): 352 | M = self.getMask(ref) 353 | msk = M['mask'] 354 | ax = plt.gca() 355 | ax.imshow(msk) 356 | 357 | 358 | if __name__ == '__main__': 359 | refer = REFER(dataset='refcocog', splitBy='google') 360 | ref_ids = refer.getRefIds() 361 | print(len(ref_ids)) 362 | 363 | print(len(refer.Imgs)) 364 | print(len(refer.imgToRefs)) 365 | 366 | ref_ids = refer.getRefIds(split='train') 367 | print('There are %s training referred objects.' % len(ref_ids)) 368 | 369 | for ref_id in ref_ids: 370 | ref = refer.loadRefs(ref_id)[0] 371 | if len(ref['sentences']) < 2: 372 | continue 373 | 374 | pprint(ref) 375 | print('The label is %s.' % refer.Cats[ref['category_id']]) 376 | plt.figure() 377 | refer.showRef(ref, seg_box='box') 378 | plt.show() 379 | 380 | # plt.figure() 381 | # refer.showMask(ref) 382 | # plt.show() 383 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import os 4 | import shutil 5 | import sys 6 | import time 7 | import warnings 8 | from functools import partial 9 | 10 | import cv2 11 | import torch 12 | import torch.cuda.amp as amp 13 | import torch.distributed as dist 14 | import torch.multiprocessing as mp #https://blog.csdn.net/hxxjxw/article/details/119839548 15 | import torch.nn as nn 16 | import torch.nn.parallel 17 | import torch.optim 18 | import torch.utils.data as data 19 | from loguru import logger # https://hanjunqiang.blog.csdn.net/article/details/124779625 20 | from torch.optim.lr_scheduler import MultiStepLR 21 | 22 | import utils.config as config 23 | import wandb 24 | from utils.dataset import RefDataset 25 | from engine.engine import train, validate 26 | from model import build_segmenter 27 | from utils.misc import (init_random_seed, set_random_seed, setup_logger, 28 | worker_init_fn, build_scheduler, collate_fn) 29 | 30 | warnings.filterwarnings("ignore") 31 | cv2.setNumThreads(0) 32 | 33 | 34 | def get_parser(): 35 | parser = argparse.ArgumentParser( 36 | description='Pytorch Referring Expression Segmentation') 37 | parser.add_argument('--config', 38 | default='path to xxx.yaml', 39 | type=str, 40 | help='config file') 41 | parser.add_argument('--opts', 42 | default=None, 43 | nargs=argparse.REMAINDER, 44 | help='override some settings in the config.') 45 | 46 | args = parser.parse_args() 47 | assert args.config is not None 48 | cfg = config.load_cfg_from_cfg_file(args.config) 49 | if args.opts is not None: 50 | cfg = config.merge_cfg_from_list(cfg, args.opts) 51 | return cfg 52 | 53 
| 54 | @logger.catch # catch exceptions raised in the main process or spawned worker processes 55 | def main(): 56 | args = get_parser() 57 | args.manual_seed = init_random_seed(args.manual_seed) 58 | set_random_seed(args.manual_seed, deterministic=True) 59 | 60 | args.ngpus_per_node = torch.cuda.device_count() 61 | args.world_size = args.ngpus_per_node * args.world_size 62 | mp.spawn(main_worker, nprocs=args.ngpus_per_node, args=(args, )) 63 | 64 | 65 | def main_worker(gpu, args): 66 | args.output_dir = os.path.join(args.output_folder, args.exp_name) 67 | 68 | # local rank & global rank 69 | args.gpu = gpu 70 | args.rank = args.rank * args.ngpus_per_node + gpu 71 | torch.cuda.set_device(args.gpu) 72 | 73 | # logger 74 | setup_logger(args.output_dir, 75 | distributed_rank=args.gpu, 76 | filename="train.log", 77 | mode="a") 78 | # dist init 79 | dist.init_process_group(backend=args.dist_backend, 80 | init_method=args.dist_url, 81 | world_size=args.world_size, 82 | rank=args.rank) 83 | # wandb 84 | if args.rank == 0: 85 | wandb.init(job_type="training", 86 | mode="offline", 87 | config=args, 88 | project=args.exp_name, 89 | name=args.exp_name, 90 | tags=[args.dataset]) 91 | dist.barrier() 92 | # build model 93 | model, param_list = build_segmenter(args) 94 | # logger.info(model) 95 | logger.info(args) 96 | 97 | # build optimizer & lr scheduler 98 | optimizer = torch.optim.AdamW(param_list, 99 | lr=args.lr, 100 | weight_decay=args.weight_decay, 101 | amsgrad=args.amsgrad 102 | ) 103 | 104 | scaler = amp.GradScaler() 105 | 106 | # build dataset 107 | args.batch_size = int(args.batch_size / args.ngpus_per_node) 108 | args.batch_size_val = int(args.batch_size_val / args.ngpus_per_node) 109 | args.workers = int( 110 | (args.workers + args.ngpus_per_node - 1) / args.ngpus_per_node) 111 | train_data = RefDataset(lmdb_dir=args.train_lmdb, 112 | mask_dir=args.mask_root, 113 | dataset=args.dataset, 114 | split=args.train_split, 115 | mode='train', 116 | input_size=args.input_size, 117 | word_length=args.word_len 118 | ) 119 | val_data = RefDataset(lmdb_dir=args.val_lmdb, 120 | mask_dir=args.mask_root, 121 | dataset=args.dataset, 122 | split=args.val_split, 123 | mode='val', 124 | input_size=args.input_size, 125 | word_length=args.word_len, 126 | ) 127 | 128 | # build dataloader 129 | init_fn = partial(worker_init_fn, 130 | num_workers=args.workers, 131 | rank=args.rank, 132 | seed=args.manual_seed) 133 | train_sampler = data.distributed.DistributedSampler(train_data, 134 | shuffle=True) 135 | val_sampler = data.distributed.DistributedSampler(val_data, shuffle=False) 136 | 137 | train_loader = data.DataLoader(train_data, 138 | batch_size=args.batch_size, 139 | shuffle=False, 140 | num_workers=args.workers, 141 | pin_memory=True, 142 | worker_init_fn=init_fn, 143 | sampler=train_sampler, 144 | collate_fn=collate_fn, 145 | drop_last=True) 146 | val_loader = data.DataLoader(val_data, 147 | batch_size=args.batch_size_val, 148 | shuffle=False, 149 | num_workers=args.workers_val, 150 | pin_memory=True, 151 | sampler=val_sampler, 152 | drop_last=False, 153 | collate_fn=collate_fn, 154 | ) 155 | 156 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda x: (1 - x / (len(train_loader) * args.epochs)) ** 0.9) 157 | 158 | best_IoU = 0.0 159 | # resume 160 | if args.resume: 161 | if os.path.isfile(args.resume): 162 | logger.info("=> loading checkpoint '{}'".format(args.resume)) 163 | checkpoint = torch.load( 164 | args.resume, map_location=lambda storage, loc: storage.cuda()) 165 | args.start_epoch = checkpoint['epoch'] 166 | best_IoU = checkpoint["best_iou"]
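# Note: the resume branch below restores the epoch counter, best IoU, and optimizer/scheduler state; the checkpoint's model weights are not loaded back into the model here, and 'decoder.tokens.weight' is only removed from the loaded state dict.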
167 | checkpoint['model_state_dict'].pop('decoder.tokens.weight') 168 | optimizer.load_state_dict(checkpoint['optimizer']) 169 | scheduler.load_state_dict(checkpoint['scheduler']) 170 | logger.info("=> loaded checkpoint '{}' (epoch {})".format( 171 | args.resume, checkpoint['epoch'])) 172 | else: 173 | raise ValueError( 174 | "=> resume failed! no checkpoint found at '{}'. Please check args.resume again!" 175 | .format(args.resume)) 176 | 177 | # start training 178 | start_time = time.time() 179 | for epoch in range(args.start_epoch, args.epochs): 180 | epoch_log = epoch + 1 181 | 182 | # shuffle loader 183 | train_sampler.set_epoch(epoch_log) 184 | 185 | # train 186 | train(train_loader, model, optimizer, scheduler, scaler, epoch_log, args) 187 | 188 | # evaluation 189 | iou, prec_dict = validate(val_loader, model, epoch_log, args) 190 | 191 | # save model 192 | if dist.get_rank() == 0: 193 | lastname = os.path.join(args.output_dir, "last_model.pth") 194 | torch.save( 195 | { 196 | 'epoch': epoch_log, 197 | 'cur_iou': iou, 198 | 'best_iou': best_IoU, 199 | 'prec': prec_dict, 200 | 'model_state_dict': model.module.state_dict(), 201 | 'optimizer': optimizer.state_dict(), 202 | 'scheduler': scheduler.state_dict() 203 | }, lastname) 204 | if iou >= best_IoU and epoch_log<50: 205 | best_IoU = iou 206 | bestname = os.path.join(args.output_dir, "best_model.pth") 207 | shutil.copyfile(lastname, bestname) 208 | 209 | torch.cuda.empty_cache() 210 | 211 | time.sleep(2) 212 | if dist.get_rank() == 0: 213 | wandb.finish() 214 | 215 | logger.info("* Best IoU={} * ".format(best_IoU)) 216 | total_time = time.time() - start_time 217 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 218 | logger.info('* Training time {} *'.format(total_time_str)) 219 | 220 | 221 | if __name__ == '__main__': 222 | main() 223 | sys.exit(0) 224 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SooLab/CGFormer/766d7fbe0c0101c80806e2499bb4d6960cfc5f4a/utils/__init__.py -------------------------------------------------------------------------------- /utils/box_ops.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for bounding box manipulation and GIoU. 
3 | """ 4 | import torch 5 | from torchvision.ops.boxes import box_area 6 | 7 | def clip_iou(boxes1,boxes2): 8 | area1 = box_area(boxes1) 9 | area2 = box_area(boxes2) 10 | lt = torch.max(boxes1[:, :2], boxes2[:, :2]) 11 | rb = torch.min(boxes1[:, 2:], boxes2[:, 2:]) 12 | wh = (rb - lt).clamp(min=0) 13 | inter = wh[:,0] * wh[:,1] 14 | union = area1 + area2 - inter 15 | iou = (inter + 1e-6) / (union+1e-6) 16 | return iou 17 | 18 | def multi_iou(boxes1, boxes2): 19 | lt = torch.max(boxes1[...,:2], boxes2[...,:2]) 20 | rb = torch.min(boxes1[...,2:], boxes2[...,2:]) 21 | wh = (rb - lt).clamp(min=0) 22 | wh_1 = boxes1[...,2:] - boxes1[...,:2] 23 | wh_2 = boxes2[...,2:] - boxes2[...,:2] 24 | inter = wh[...,0] * wh[...,1] 25 | union = wh_1[...,0] * wh_1[...,1] + wh_2[...,0] * wh_2[...,1] - inter 26 | iou = (inter + 1e-6) / (union + 1e-6) 27 | return iou 28 | 29 | def box_cxcywh_to_xyxy(x): 30 | x_c, y_c, w, h = x.unbind(-1) 31 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 32 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 33 | return torch.stack(b, dim=-1) 34 | 35 | 36 | def box_xyxy_to_cxcywh(x): 37 | x0, y0, x1, y1 = x.unbind(-1) 38 | b = [(x0 + x1) / 2, (y0 + y1) / 2, 39 | (x1 - x0), (y1 - y0)] 40 | return torch.stack(b, dim=-1) 41 | 42 | 43 | # modified from torchvision to also return the union 44 | def box_iou(boxes1, boxes2): 45 | area1 = box_area(boxes1) 46 | area2 = box_area(boxes2) 47 | 48 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 49 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 50 | 51 | wh = (rb - lt).clamp(min=0) # [N,M,2] 52 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 53 | 54 | union = area1[:, None] + area2 - inter 55 | 56 | iou = (inter+1e-6) / (union+1e-6) 57 | return iou, union 58 | 59 | 60 | def generalized_box_iou(boxes1, boxes2): 61 | """ 62 | Generalized IoU from https://giou.stanford.edu/ 63 | 64 | The boxes should be in [x0, y0, x1, y1] format 65 | 66 | Returns a [N, M] pairwise matrix, where N = len(boxes1) 67 | and M = len(boxes2) 68 | """ 69 | # degenerate boxes gives inf / nan results 70 | # so do an early check 71 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 72 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 73 | iou, union = box_iou(boxes1, boxes2) 74 | 75 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 76 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 77 | 78 | wh = (rb - lt).clamp(min=0) # [N,M,2] 79 | area = wh[:, :, 0] * wh[:, :, 1] 80 | 81 | return iou - ((area - union) + 1e-6) / (area + 1e-6) 82 | 83 | 84 | def masks_to_boxes(masks): 85 | """Compute the bounding boxes around the provided masks 86 | 87 | The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. 
88 | 89 | Returns a [N, 4] tensors, with the boxes in xyxy format 90 | """ 91 | if masks.numel() == 0: 92 | return torch.zeros((0, 4), device=masks.device) 93 | 94 | h, w = masks.shape[-2:] 95 | 96 | y = torch.arange(0, h, dtype=torch.float) 97 | x = torch.arange(0, w, dtype=torch.float) 98 | y, x = torch.meshgrid(y, x) 99 | 100 | x_mask = (masks * x.unsqueeze(0)) 101 | x_max = x_mask.flatten(1).max(-1)[0] 102 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 103 | 104 | y_mask = (masks * y.unsqueeze(0)) 105 | y_max = y_mask.flatten(1).max(-1)[0] 106 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 107 | 108 | return torch.stack([x_min, y_min, x_max, y_max], 1) 109 | -------------------------------------------------------------------------------- /utils/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SooLab/CGFormer/766d7fbe0c0101c80806e2499bb4d6960cfc5f4a/utils/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # Functions for parsing args 3 | # ----------------------------------------------------------------------------- 4 | import copy 5 | import os 6 | from ast import literal_eval 7 | 8 | import yaml 9 | 10 | 11 | class CfgNode(dict): 12 | """ 13 | CfgNode represents an internal node in the configuration tree. It's a simple 14 | dict-like container that allows for attribute-based access to keys. 15 | """ 16 | def __init__(self, init_dict=None, key_list=None, new_allowed=False): 17 | # Recursively convert nested dictionaries in init_dict into CfgNodes 18 | init_dict = {} if init_dict is None else init_dict 19 | key_list = [] if key_list is None else key_list 20 | for k, v in init_dict.items(): 21 | if type(v) is dict: 22 | # Convert dict to CfgNode 23 | init_dict[k] = CfgNode(v, key_list=key_list + [k]) 24 | super(CfgNode, self).__init__(init_dict) 25 | 26 | def __getattr__(self, name): 27 | if name in self: 28 | return self[name] 29 | else: 30 | raise AttributeError(name) 31 | 32 | def __setattr__(self, name, value): 33 | self[name] = value 34 | 35 | def __str__(self): 36 | def _indent(s_, num_spaces): 37 | s = s_.split("\n") 38 | if len(s) == 1: 39 | return s_ 40 | first = s.pop(0) 41 | s = [(num_spaces * " ") + line for line in s] 42 | s = "\n".join(s) 43 | s = first + "\n" + s 44 | return s 45 | 46 | r = "" 47 | s = [] 48 | for k, v in sorted(self.items()): 49 | seperator = "\n" if isinstance(v, CfgNode) else " " 50 | attr_str = "{}:{}{}".format(str(k), seperator, str(v)) 51 | attr_str = _indent(attr_str, 2) 52 | s.append(attr_str) 53 | r += "\n".join(s) 54 | return r 55 | 56 | def __repr__(self): 57 | return "{}({})".format(self.__class__.__name__, 58 | super(CfgNode, self).__repr__()) 59 | 60 | 61 | def load_cfg_from_cfg_file(file): 62 | cfg = {} 63 | assert os.path.isfile(file) and file.endswith('.yaml'), \ 64 | '{} is not a yaml file'.format(file) 65 | 66 | with open(file, 'r') as f: 67 | cfg_from_file = yaml.safe_load(f) 68 | 69 | for key in cfg_from_file: 70 | for k, v in cfg_from_file[key].items(): 71 | cfg[k] = v 72 | 73 | cfg = CfgNode(cfg) 74 | return cfg 75 | 76 | 77 | def merge_cfg_from_list(cfg, cfg_list): 78 | new_cfg = copy.deepcopy(cfg) 79 | assert len(cfg_list) % 2 == 0 80 | for 
full_key, v in zip(cfg_list[0::2], cfg_list[1::2]): 81 | subkey = full_key.split('.')[-1] 82 | assert subkey in cfg, 'Non-existent key: {}'.format(full_key) 83 | value = _decode_cfg_value(v) 84 | value = _check_and_coerce_cfg_value_type(value, cfg[subkey], subkey, 85 | full_key) 86 | setattr(new_cfg, subkey, value) 87 | 88 | return new_cfg 89 | 90 | 91 | def _decode_cfg_value(v): 92 | """Decodes a raw config value (e.g., from a yaml config files or command 93 | line argument) into a Python object. 94 | """ 95 | # All remaining processing is only applied to strings 96 | if not isinstance(v, str): 97 | return v 98 | # Try to interpret `v` as a: 99 | # string, number, tuple, list, dict, boolean, or None 100 | try: 101 | v = literal_eval(v) 102 | # The following two excepts allow v to pass through when it represents a 103 | # string. 104 | # 105 | # Longer explanation: 106 | # The type of v is always a string (before calling literal_eval), but 107 | # sometimes it *represents* a string and other times a data structure, like 108 | # a list. In the case that v represents a string, what we got back from the 109 | # yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is 110 | # ok with '"foo"', but will raise a ValueError if given 'foo'. In other 111 | # cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval 112 | # will raise a SyntaxError. 113 | except ValueError: 114 | pass 115 | except SyntaxError: 116 | pass 117 | return v 118 | 119 | 120 | def _check_and_coerce_cfg_value_type(replacement, original, key, full_key): 121 | """Checks that `replacement`, which is intended to replace `original` is of 122 | the right type. The type is correct if it matches exactly or is one of a few 123 | cases in which the type can be easily coerced. 124 | """ 125 | original_type = type(original) 126 | replacement_type = type(replacement) 127 | 128 | # The types must match (with some exceptions) 129 | if replacement_type == original_type: 130 | return replacement 131 | 132 | # Cast replacement from from_type to to_type if the replacement and original 133 | # types match from_type and to_type 134 | def conditional_cast(from_type, to_type): 135 | if replacement_type == from_type and original_type == to_type: 136 | return True, to_type(replacement) 137 | else: 138 | return False, None 139 | 140 | # Conditionally casts 141 | # list <-> tuple 142 | casts = [(tuple, list), (list, tuple)] 143 | # For py2: allow converting from str (bytes) to a unicode string 144 | try: 145 | casts.append((str, unicode)) # noqa: F821 146 | except Exception: 147 | pass 148 | 149 | for (from_type, to_type) in casts: 150 | converted, converted_value = conditional_cast(from_type, to_type) 151 | if converted: 152 | return converted_value 153 | 154 | raise ValueError( 155 | "Type mismatch ({} vs. {}) with values ({} vs. 
{}) for config " 156 | "key: {}".format(original_type, replacement_type, original, 157 | replacement, full_key)) 158 | -------------------------------------------------------------------------------- /utils/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Union 3 | import cv2 4 | from PIL import Image 5 | import lmdb 6 | import numpy as np 7 | import pyarrow as pa 8 | import torch 9 | from torch.utils.data import Dataset 10 | from torchvision.transforms import functional as F 11 | from bert.tokenization_bert import BertTokenizer 12 | 13 | info = { 14 | 'refcoco': { 15 | 'train': 42404, 16 | 'val': 3811, 17 | 'val-test': 3811, 18 | 'testA': 1975, 19 | 'testB': 1810 20 | }, 21 | 'refcoco+': { 22 | 'train': 42278, 23 | 'val': 3805, 24 | 'val-test': 3805, 25 | 'testA': 1975, 26 | 'testB': 1798 27 | }, 28 | 'refcocog_u': { 29 | 'train': 42226, 30 | 'val': 2573, 31 | 'val-test': 2573, 32 | 'test': 5023 33 | }, 34 | 'refcocog_g': { 35 | 'train': 44822, 36 | 'val': 5000, 37 | 'val-test': 5000 38 | } 39 | } 40 | _tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 41 | 42 | 43 | def tokenize(texts: Union[str, List[str]], 44 | context_length: int = 77, 45 | truncate: bool = False) -> torch.LongTensor: 46 | """ 47 | Returns the tokenized representation of given input string(s) 48 | 49 | Parameters 50 | ---------- 51 | texts : Union[str, List[str]] 52 | An input string or a list of input strings to tokenize 53 | 54 | context_length : int 55 | The context length to use; all CLIP models use 77 as the context length 56 | 57 | truncate: bool 58 | Whether to truncate the text in case its encoding is longer than the context length 59 | 60 | Returns 61 | ------- 62 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 63 | """ 64 | l_mask = [0] * context_length 65 | result = [0] * context_length 66 | 67 | tokens = _tokenizer.encode(text=texts, add_special_tokens=True) 68 | tokens = tokens[:context_length] 69 | result[:len(tokens)] = tokens 70 | l_mask[:len(tokens)] = [1]*len(tokens) 71 | 72 | result = torch.tensor(result).unsqueeze(0) 73 | l_mask = torch.tensor(l_mask).unsqueeze(0) 74 | return result, l_mask 75 | 76 | 77 | def loads_pyarrow(buf): 78 | """ 79 | Args: 80 | buf: the output of `dumps`. 
81 | """ 82 | return pa.deserialize(buf) 83 | 84 | 85 | class RefDataset(Dataset): 86 | def __init__(self, lmdb_dir, mask_dir, dataset, split, mode, input_size, 87 | word_length): 88 | super(RefDataset, self).__init__() 89 | self.lmdb_dir = lmdb_dir 90 | self.mask_dir = mask_dir 91 | self.dataset = dataset 92 | self.split = split 93 | self.mode = mode 94 | self.input_size = (input_size, input_size) 95 | #self.mask_size = [13, 26, 52] 96 | self.word_length = word_length 97 | self.mean = torch.tensor([0.485, 0.456, 0.406]).reshape(3, 1, 1) 98 | self.std = torch.tensor([0.229, 0.224, 0.225]).reshape(3, 1, 1) 99 | self.length = info[dataset][split] 100 | self.env = None 101 | # self.coco_transforms = make_coco_transforms(mode, cautious=False) 102 | 103 | def _init_db(self): 104 | self.env = lmdb.open(self.lmdb_dir, 105 | subdir=os.path.isdir(self.lmdb_dir), 106 | readonly=True, 107 | lock=False, 108 | readahead=False, 109 | meminit=False) 110 | with self.env.begin(write=False) as txn: 111 | self.length = loads_pyarrow(txn.get(b'__len__')) 112 | self.keys = loads_pyarrow(txn.get(b'__keys__')) 113 | 114 | def __len__(self): 115 | return self.length 116 | 117 | def __getitem__(self, index): 118 | # Delay loading LMDB data until after initialization: https://github.com/chainer/chainermn/issues/129 119 | if self.env is None: 120 | self._init_db() 121 | env = self.env 122 | with env.begin(write=False) as txn: 123 | byteflow = txn.get(self.keys[index]) 124 | ref = loads_pyarrow(byteflow) 125 | # img 126 | ori_img = cv2.imdecode(np.frombuffer(ref['img'], np.uint8), 127 | cv2.IMREAD_COLOR) 128 | img = cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB) 129 | img_size = img.shape[:2] 130 | # mask 131 | seg_id = ref['seg_id'] 132 | mask_dir = os.path.join(self.mask_dir, str(seg_id) + '.png') 133 | # sentences 134 | idx = np.random.choice(ref['num_sents']) 135 | sents = ref['sents'] 136 | # transform 137 | # mask transform 138 | mask = cv2.imdecode(np.frombuffer(ref['mask'], np.uint8), 139 | cv2.IMREAD_GRAYSCALE) 140 | mask = mask / 255. 
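# Sentence selection depends on the mode: 'train' uses one randomly sampled sentence, 'val' uses the last sentence, and 'test' tokenizes every sentence and additionally returns the original image, segment id and raw sentences for evaluation.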
141 | if self.mode == 'train': 142 | sent = sents[idx] 143 | # sentence -> vector 144 | img, mask, sent = self.convert(img, mask, sent, inference=False) 145 | word_vec, pad_mask = tokenize(sent, self.word_length, True) 146 | return img, word_vec, mask, pad_mask 147 | elif self.mode == 'val': 148 | # sentence -> vector 149 | sent = sents[-1] 150 | word_vec, pad_mask = tokenize(sent, self.word_length, True) 151 | img, mask, sent = self.convert(img, mask, sent, inference=False) 152 | return img, word_vec, mask, pad_mask 153 | else: 154 | # sentence -> vector 155 | word_vecs = [] 156 | pad_masks = [] 157 | for sent in sents: 158 | word_vec, pad_mask = tokenize(sent, self.word_length, True) 159 | word_vecs.append(word_vec) 160 | pad_masks.append(pad_mask) 161 | img, mask, sent = self.convert(img, mask, sent, inference=True) 162 | return ori_img, img, word_vecs, mask, pad_masks, seg_id, sents, 163 | 164 | def convert(self, img, mask, sent, inference=False): 165 | img = Image.fromarray(np.uint8(img)) 166 | mask = Image.fromarray(np.uint8(mask), mode="P") 167 | img = F.resize(img, self.input_size) 168 | if not inference: 169 | mask = F.resize(mask, self.input_size, interpolation=Image.NEAREST) 170 | img = F.to_tensor(img) 171 | mask = torch.as_tensor(np.asarray(mask).copy(), dtype=torch.int64) 172 | img = F.normalize(img, mean=self.mean, std=self.std) 173 | return img, mask, sent 174 | 175 | 176 | def __repr__(self): 177 | return self.__class__.__name__ + "(" + \ 178 | f"db_path={self.lmdb_dir}, " + \ 179 | f"dataset={self.dataset}, " + \ 180 | f"split={self.split}, " + \ 181 | f"mode={self.mode}, " + \ 182 | f"input_size={self.input_size}, " + \ 183 | f"word_length={self.word_length}" 184 | 185 | 186 | -------------------------------------------------------------------------------- /utils/dataset_open.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Union 3 | import cv2 4 | from PIL import Image 5 | import lmdb 6 | import numpy as np 7 | import pyarrow as pa 8 | import torch 9 | from torch.utils.data import Dataset 10 | from torchvision.transforms import functional as F 11 | from bert.tokenization_bert import BertTokenizer 12 | import random 13 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 14 | # from nltk import word_tokenize, pos_tag 15 | info = { 16 | 'refcoco': { 17 | 'train_seen': 35473, 18 | 'val_seen': 3175, 19 | 'val_unseen': 445, 20 | 'test_seen':3200, 21 | 'test_unseen':394 22 | }, 23 | 'refcoco+': { 24 | 'train_seen': 35375, 25 | 'val_seen': 3171, 26 | 'val_unseen': 444, 27 | 'test_seen':3189, 28 | 'test_unseen':394 29 | }, 30 | 'refcoco_u': { 31 | 'train_seen': 33093, 32 | 'val_seen': 2000, 33 | 'val_unseen': 386, 34 | 'test_seen': 3935, 35 | 'test_unseen': 759, 36 | }, 37 | 'refcoco_g': { 38 | 'train_seen': 35105, 39 | 'val_seen': 3923, 40 | 'val_unseen': 760 41 | } 42 | } 43 | _tokenizer = _Tokenizer() 44 | 45 | 46 | def tokenize(texts: Union[str, List[str]], 47 | context_length: int = 77, 48 | truncate: bool = False) -> torch.LongTensor: 49 | """ 50 | Returns the tokenized representation of given input string(s) 51 | 52 | Parameters 53 | ---------- 54 | texts : Union[str, List[str]] 55 | An input string or a list of input strings to tokenize 56 | 57 | context_length : int 58 | The context length to use; all CLIP models use 77 as the context length 59 | 60 | truncate: bool 61 | Whether to truncate the text in case its encoding is longer than the context length 62 | 63 | Returns 64 | 
------- 65 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 66 | """ 67 | if isinstance(texts, str): 68 | texts = [texts] 69 | 70 | sot_token = _tokenizer.encoder["<|startoftext|>"] 71 | eot_token = _tokenizer.encoder["<|endoftext|>"] 72 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] 73 | for text in texts] 74 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 75 | 76 | for i, tokens in enumerate(all_tokens): 77 | if len(tokens) > context_length: 78 | if truncate: 79 | tokens = tokens[:context_length] 80 | tokens[-1] = eot_token 81 | else: 82 | raise RuntimeError( 83 | f"Input {texts[i]} is too long for context length {context_length}" 84 | ) 85 | result[i, :len(tokens)] = torch.tensor(tokens) 86 | 87 | return result 88 | 89 | 90 | def loads_pyarrow(buf): 91 | """ 92 | Args: 93 | buf: the output of `dumps`. 94 | """ 95 | return pa.deserialize(buf) 96 | 97 | 98 | class RefDataset(Dataset): 99 | def __init__(self, lmdb_dir, mask_dir, dataset, split, mode, input_size, 100 | word_length): 101 | super(RefDataset, self).__init__() 102 | self.lmdb_dir = lmdb_dir 103 | self.mask_dir = mask_dir 104 | self.dataset = dataset 105 | self.split = split 106 | self.mode = mode 107 | self.input_size = (input_size, input_size) 108 | #self.mask_size = [13, 26, 52] 109 | self.word_length = word_length 110 | self.mean = torch.tensor([0.485, 0.456, 0.406]).reshape(3, 1, 1) 111 | self.std = torch.tensor([0.229, 0.224, 0.225]).reshape(3, 1, 1) 112 | self.length = info[dataset][split] 113 | self.env = None 114 | def _init_db(self): 115 | self.env = lmdb.open(self.lmdb_dir, 116 | subdir=os.path.isdir(self.lmdb_dir), 117 | readonly=True, 118 | lock=False, 119 | readahead=False, 120 | meminit=False) 121 | with self.env.begin(write=False) as txn: 122 | self.length = loads_pyarrow(txn.get(b'__len__')) 123 | self.keys = loads_pyarrow(txn.get(b'__keys__')) 124 | 125 | def __len__(self): 126 | return self.length 127 | 128 | def __getitem__(self, index): 129 | # Delay loading LMDB data until after initialization: https://github.com/chainer/chainermn/issues/129 130 | if self.env is None: 131 | self._init_db() 132 | env = self.env 133 | with env.begin(write=False) as txn: 134 | byteflow = txn.get(self.keys[index]) 135 | ref = loads_pyarrow(byteflow) 136 | # img 137 | ori_img = cv2.imdecode(np.frombuffer(ref['img'], np.uint8), 138 | cv2.IMREAD_COLOR) 139 | img = cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB) 140 | img_size = img.shape[:2] 141 | # mask 142 | seg_id = ref['seg_id'] 143 | mask_dir = os.path.join(self.mask_dir, str(seg_id) + '.png') 144 | # sentences 145 | idx = np.random.choice(ref['num_sents']) 146 | sents = ref['sents'] 147 | # transform 148 | # mask transform 149 | mask = cv2.imdecode(np.frombuffer(ref['mask'], np.uint8), 150 | cv2.IMREAD_GRAYSCALE) 151 | mask = mask / 255. 
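# Open-vocabulary variant: sentences are encoded with the CLIP BPE tokenizer above and the padding mask is taken from the non-zero token positions; mode handling mirrors utils/dataset.py (train: random sentence, val: last sentence, test: all sentences).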
152 | if self.mode == 'train': 153 | sent = sents[idx] 154 | # sentence -> vector 155 | img, mask, sent = self.convert(img, mask, sent, inference=False) 156 | word_vec = tokenize(sent, self.word_length, True).squeeze(0) 157 | pad_mask = (word_vec != 0).float() 158 | return img, word_vec, mask, pad_mask 159 | elif self.mode == 'val': 160 | # sentence -> vector 161 | sent = sents[-1] 162 | word_vec = tokenize(sent, self.word_length, True).squeeze(0) 163 | pad_mask = (word_vec != 0).float() 164 | img, mask, sent = self.convert(img, mask, sent, inference=False) 165 | return img, word_vec, mask, pad_mask 166 | else: 167 | # sentence -> vector 168 | word_vecs = [] 169 | pad_masks = [] 170 | for sent in sents: 171 | word_vec = tokenize(sent, self.word_length, True).squeeze(0) 172 | word_vecs.append(word_vec) 173 | pad_mask = (word_vec != 0).float() 174 | pad_masks.append(pad_mask) 175 | img, mask, sent = self.convert(img, mask, sent, inference=True) 176 | return ori_img, img, word_vecs, mask, pad_masks, seg_id, sents 177 | 178 | def convert(self, img, mask, sent, inference=False): 179 | img = Image.fromarray(np.uint8(img)) 180 | mask = Image.fromarray(np.uint8(mask), mode="P") 181 | img = F.resize(img, self.input_size) 182 | if not inference: 183 | mask = F.resize(mask, self.input_size, interpolation=Image.NEAREST) 184 | img = F.to_tensor(img) 185 | mask = torch.as_tensor(np.asarray(mask).copy(), dtype=torch.int64) 186 | img = F.normalize(img, mean=self.mean, std=self.std) 187 | return img, mask, sent 188 | 189 | def __repr__(self): 190 | return self.__class__.__name__ + "(" + \ 191 | f"db_path={self.lmdb_dir}, " + \ 192 | f"dataset={self.dataset}, " + \ 193 | f"split={self.split}, " + \ 194 | f"mode={self.mode}, " + \ 195 | f"input_size={self.input_size}, " + \ 196 | f"word_length={self.word_length}" 197 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | from PIL import Image 5 | from loguru import logger 6 | import sys 7 | import inspect 8 | from timm.scheduler.cosine_lr import CosineLRScheduler 9 | import torch 10 | from torch import nn 11 | import torch.distributed as dist 12 | 13 | 14 | def init_random_seed(seed=None, device='cuda', rank=0, world_size=1): 15 | """Initialize random seed.""" 16 | if seed is not None: 17 | return seed 18 | 19 | # Make sure all ranks share the same random seed to prevent 20 | # some potential bugs. Please refer to 21 | # https://github.com/open-mmlab/mmdetection/issues/6339 22 | seed = np.random.randint(2**31) 23 | if world_size == 1: 24 | return seed 25 | 26 | if rank == 0: 27 | random_num = torch.tensor(seed, dtype=torch.int32, device=device) 28 | else: 29 | random_num = torch.tensor(0, dtype=torch.int32, device=device) 30 | dist.broadcast(random_num, src=0) 31 | return random_num.item() 32 | 33 | 34 | def set_random_seed(seed, deterministic=False): 35 | """Set random seed.""" 36 | random.seed(seed) 37 | np.random.seed(seed) 38 | torch.manual_seed(seed) 39 | torch.cuda.manual_seed_all(seed) 40 | if deterministic: 41 | torch.backends.cudnn.deterministic = True 42 | torch.backends.cudnn.benchmark = False 43 | 44 | 45 | @torch.no_grad() 46 | def concat_all_gather(tensor): 47 | """ 48 | Performs all_gather operation on the provided tensors. 49 | *** Warning ***: torch.distributed.all_gather has no gradient. 
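All participating ranks must pass tensors of identical shape; the gathered copies are concatenated along dim 0.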
50 | """ 51 | tensor = tensor.contiguous() 52 | tensors_gather = [ 53 | torch.ones_like(tensor) 54 | for _ in range(torch.distributed.get_world_size()) 55 | ] 56 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 57 | 58 | output = torch.cat(tensors_gather, dim=0) 59 | return output 60 | 61 | 62 | 63 | def worker_init_fn(worker_id, num_workers, rank, seed): 64 | # The seed of each worker equals to 65 | # num_worker * rank + worker_id + user_seed 66 | worker_seed = num_workers * rank + worker_id + seed 67 | np.random.seed(worker_seed) 68 | random.seed(worker_seed) 69 | 70 | 71 | class AverageMeter(object): 72 | """Computes and stores the average and current value""" 73 | 74 | def __init__(self, name, fmt=":f"): 75 | self.name = name 76 | self.fmt = fmt 77 | self.reset() 78 | 79 | def reset(self): 80 | self.val = 0 81 | self.avg = 0 82 | self.sum = 0 83 | self.count = 0 84 | 85 | def update(self, val, n=1): 86 | self.val = val 87 | self.sum += val * n 88 | self.count += n 89 | self.avg = self.sum / self.count 90 | 91 | def __str__(self): 92 | if self.name == "Lr": 93 | fmtstr = "{name}={val" + self.fmt + "}" 94 | else: 95 | fmtstr = "{name}={val" + self.fmt + "} ({avg" + self.fmt + "})" 96 | return fmtstr.format(**self.__dict__) 97 | 98 | 99 | class ProgressMeter(object): 100 | def __init__(self, num_batches, meters, prefix=""): 101 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 102 | self.meters = meters 103 | self.prefix = prefix 104 | 105 | def display(self, batch): 106 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 107 | entries += [str(meter) for meter in self.meters] 108 | logger.info(" ".join(entries)) 109 | 110 | def _get_batch_fmtstr(self, num_batches): 111 | num_digits = len(str(num_batches // 1)) 112 | fmt = "{:" + str(num_digits) + "d}" 113 | return "[" + fmt + "/" + fmt.format(num_batches) + "]" 114 | 115 | 116 | def trainMetricGPU(output, target, threshold=0.35, pr_iou=0.5): 117 | assert (output.dim() in [2, 3, 4]) 118 | assert output.shape == target.shape 119 | output = output.flatten(1) 120 | target = target.flatten(1) 121 | output = torch.sigmoid(output) 122 | output[output < threshold] = 0. 123 | output[output >= threshold] = 1. 124 | # inter & union 125 | inter = (output.bool() & target.bool()).sum(dim=1) # b 126 | union = (output.bool() | target.bool()).sum(dim=1) # b 127 | ious = inter / (union + 1e-6) # 0 ~ 1 128 | # iou & pr@5 129 | iou = ious.mean() 130 | prec = (ious > pr_iou).float().mean() 131 | return 100. * iou, 100. * prec 132 | 133 | def ValMetricGPU(output, target, threshold=0.35): 134 | assert output.size(0) == 1 135 | output = output.flatten(1) 136 | target = target.flatten(1) 137 | output = torch.sigmoid(output) 138 | output[output < threshold] = 0. 139 | output[output >= threshold] = 1. 140 | # inter & union 141 | inter = (output.bool() & target.bool()).sum(dim=1) # b 142 | union = (output.bool() | target.bool()).sum(dim=1) # b 143 | ious = inter / (union + 1e-6) # 0 ~ 1 144 | return ious 145 | 146 | 147 | def intersectionAndUnionGPU(output, target, K, threshold=0.5): 148 | # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1. 149 | assert (output.dim() in [1, 2, 3]) 150 | assert output.shape == target.shape 151 | output = output.view(-1) 152 | target = target.view(-1) 153 | 154 | output = torch.sigmoid(output) 155 | output[output < threshold] = 0. 156 | output[output >= threshold] = 1. 
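# With binarized predictions, output[output == target] keeps the correctly classified pixels; histogramming them over K bins gives the per-class intersection, and only the foreground entry (index 1) of the intersection and union is returned.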
157 | 
158 |     intersection = output[output == target]
159 |     area_intersection = torch.histc(intersection.float(),
160 |                                     bins=K,
161 |                                     min=0,
162 |                                     max=K - 1)
163 |     area_output = torch.histc(output.float(), bins=K, min=0, max=K - 1)
164 |     area_target = torch.histc(target.float(), bins=K, min=0, max=K - 1)
165 |     area_union = area_output + area_target - area_intersection
166 |     return area_intersection[1], area_union[1]
167 | 
168 | 
169 | def group_weight(weight_group, module, lr):
170 |     group_decay = []
171 |     group_no_decay = []
172 |     for m in module.modules():
173 |         if isinstance(m, nn.Linear):
174 |             group_decay.append(m.weight)
175 |             if m.bias is not None:
176 |                 group_no_decay.append(m.bias)
177 |         elif isinstance(m, nn.modules.conv._ConvNd):
178 |             group_decay.append(m.weight)
179 |             if m.bias is not None:
180 |                 group_no_decay.append(m.bias)
181 |         elif isinstance(m, nn.modules.batchnorm._BatchNorm):
182 |             if m.weight is not None:
183 |                 group_no_decay.append(m.weight)
184 |             if m.bias is not None:
185 |                 group_no_decay.append(m.bias)
186 |     assert len(list(
187 |         module.parameters())) == len(group_decay) + len(group_no_decay)
188 |     weight_group.append(dict(params=group_decay, lr=lr))
189 |     weight_group.append(dict(params=group_no_decay, weight_decay=.0, lr=lr))
190 |     return weight_group
191 | 
192 | 
193 | def colorize(gray, palette):
194 |     # gray: numpy array of the label map; palette: flat list of 3*N RGB values
195 |     color = Image.fromarray(gray.astype(np.uint8)).convert('P')
196 |     color.putpalette(palette)
197 |     return color
198 | 
199 | 
200 | def find_free_port():
201 |     import socket
202 |     sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
203 |     # Binding to port 0 will cause the OS to find an available port for us
204 |     sock.bind(("", 0))
205 |     port = sock.getsockname()[1]
206 |     sock.close()
207 |     # NOTE: there is still a chance the port could be taken by other processes.
208 |     return port
209 | 
210 | 
211 | def get_caller_name(depth=0):
212 |     """
213 |     Args:
214 |         depth (int): Depth of the caller context; use 0 for the immediate caller.
215 |             Default value: 0.
216 | 
217 |     Returns:
218 |         str: module name of the caller
219 |     """
220 |     # the following logic is a little bit faster than inspect.stack() logic
221 |     frame = inspect.currentframe().f_back
222 |     for _ in range(depth):
223 |         frame = frame.f_back
224 | 
225 |     return frame.f_globals["__name__"]
226 | 
227 | 
228 | class StreamToLoguru:
229 |     """
230 |     Stream object that redirects writes to a logger instance.
231 |     """
232 |     def __init__(self, level="INFO", caller_names=("apex", "pycocotools")):
233 |         """
234 |         Args:
235 |             level(str): log level string of loguru. Default value: "INFO".
236 |             caller_names(tuple): caller names of redirected modules.
237 |                 Default value: (apex, pycocotools).
238 |         """
239 |         self.level = level
240 |         self.linebuf = ""
241 |         self.caller_names = caller_names
242 | 
243 |     def write(self, buf):
244 |         full_name = get_caller_name(depth=1)
245 |         module_name = full_name.rsplit(".", maxsplit=-1)[0]
246 |         if module_name in self.caller_names:
247 |             for line in buf.rstrip().splitlines():
248 |                 # use caller level log
249 |                 logger.opt(depth=2).log(self.level, line.rstrip())
250 |         else:
251 |             sys.__stdout__.write(buf)
252 | 
253 |     def flush(self):
254 |         pass
255 | 
256 | 
257 | def redirect_sys_output(log_level="INFO"):
258 |     redirect_logger = StreamToLoguru(log_level)
259 |     sys.stderr = redirect_logger
260 |     sys.stdout = redirect_logger
261 | 
262 | 
263 | def setup_logger(save_dir, distributed_rank=0, filename="log.txt", mode="a"):
264 |     """setup logger for training and testing.
265 |     Args:
266 |         save_dir(str): location to save log file
267 |         distributed_rank(int): device rank in a multi-gpu environment
268 |         filename (string): log save name.
269 |         mode(str): log file write mode, `a` (append) or `o` (override). Default is `a`.
270 | 
271 |     Return:
272 |         None; the global loguru logger is configured in place.
273 |     """
274 |     loguru_format = (
275 |         "{time:YYYY-MM-DD HH:mm:ss} | "
276 |         "{level: <8} | "
277 |         "{name}:{line} - {message}")
278 | 
279 |     logger.remove()
280 |     save_file = os.path.join(save_dir, filename)
281 |     if mode == "o" and os.path.exists(save_file):
282 |         os.remove(save_file)
283 |     # only keep logger in rank0 process
284 |     if distributed_rank == 0:
285 |         logger.add(
286 |             sys.stderr,
287 |             format=loguru_format,
288 |             level="INFO",
289 |             enqueue=True,
290 |         )
291 |         logger.add(save_file)
292 | 
293 |     # redirect stdout/stderr to loguru
294 |     redirect_sys_output("INFO")
295 | 
296 | 
297 | def build_scheduler(config, optimizer, n_iter_per_epoch):
298 |     num_steps = int(config.epochs * n_iter_per_epoch)
299 |     warmup_steps = int(config.warmup_epochs * n_iter_per_epoch)
300 | 
301 |     lr_scheduler = CosineLRScheduler(
302 |         optimizer,
303 |         t_initial=num_steps,
304 |         lr_min=config.min_lr,
305 |         warmup_lr_init=config.warmup_lr,
306 |         warmup_t=warmup_steps,
307 |         cycle_limit=1,
308 |         t_in_epochs=False,
309 |     )
310 |     return lr_scheduler
311 | 
312 | def collate_fn(batch):
313 |     batch = list(zip(*batch))
314 |     return tuple(batch)
--------------------------------------------------------------------------------
/utils/simple_tokenizer.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import html
3 | import os
4 | from functools import lru_cache
5 | 
6 | import ftfy
7 | import regex as re
8 | 
9 | 
10 | @lru_cache()
11 | def default_bpe():
12 |     return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13 | 
14 | 
15 | @lru_cache()
16 | def bytes_to_unicode():
17 |     """
18 |     Returns list of utf-8 byte and a corresponding list of unicode strings.
19 |     The reversible bpe codes work on unicode strings.
20 |     This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21 |     When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22 |     This is a significant percentage of your normal, say, 32K bpe vocab.
23 |     To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24 |     This also avoids mapping to whitespace/control characters the bpe code barfs on.
25 |     """
26 |     bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27 |     cs = bs[:]
28 |     n = 0
29 |     for b in range(2**8):
30 |         if b not in bs:
31 |             bs.append(b)
32 |             cs.append(2**8+n)
33 |             n += 1
34 |     cs = [chr(n) for n in cs]
35 |     return dict(zip(bs, cs))
36 | 
37 | 
38 | def get_pairs(word):
39 |     """Return set of symbol pairs in a word.
40 |     Word is represented as tuple of symbols (symbols being variable-length strings).
41 |     """
42 |     pairs = set()
43 |     prev_char = word[0]
44 |     for char in word[1:]:
45 |         pairs.add((prev_char, char))
46 |         prev_char = char
47 |     return pairs
48 | 
49 | 
50 | def basic_clean(text):
51 |     text = ftfy.fix_text(text)
52 |     text = html.unescape(html.unescape(text))
53 |     return text.strip()
54 | 
55 | 
56 | def whitespace_clean(text):
57 |     text = re.sub(r'\s+', ' ', text)
58 |     text = text.strip()
59 |     return text
60 | 
61 | 
62 | class SimpleTokenizer(object):
63 |     def __init__(self, bpe_path: str = default_bpe()):
64 |         self.byte_encoder = bytes_to_unicode()
65 |         self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66 |         merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67 |         merges = merges[1:49152-256-2+1]
68 |         merges = [tuple(merge.split()) for merge in merges]
69 |         vocab = list(bytes_to_unicode().values())
70 |         vocab = vocab + [v+'</w>' for v in vocab]  # '</w>' marks the end of a word
71 |         for merge in merges:
72 |             vocab.append(''.join(merge))
73 |         vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74 |         self.encoder = dict(zip(vocab, range(len(vocab))))
75 |         self.decoder = {v: k for k, v in self.encoder.items()}
76 |         self.bpe_ranks = dict(zip(merges, range(len(merges))))
77 |         self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78 |         self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79 | 
80 |     def bpe(self, token):
81 |         if token in self.cache:
82 |             return self.cache[token]
83 |         word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84 |         pairs = get_pairs(word)
85 | 
86 |         if not pairs:
87 |             return token+'</w>'
88 | 
89 |         while True:
90 |             bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91 |             if bigram not in self.bpe_ranks:
92 |                 break
93 |             first, second = bigram
94 |             new_word = []
95 |             i = 0
96 |             while i < len(word):
97 |                 try:
98 |                     j = word.index(first, i)
99 |                     new_word.extend(word[i:j])
100 |                     i = j
101 |                 except ValueError:
102 |                     new_word.extend(word[i:])
103 |                     break
104 | 
105 |                 if word[i] == first and i < len(word)-1 and word[i+1] == second:
106 |                     new_word.append(first+second)
107 |                     i += 2
108 |                 else:
109 |                     new_word.append(word[i])
110 |                     i += 1
111 |             new_word = tuple(new_word)
112 |             word = new_word
113 |             if len(word) == 1:
114 |                 break
115 |             else:
116 |                 pairs = get_pairs(word)
117 |         word = ' '.join(word)
118 |         self.cache[token] = word
119 |         return word
120 | 
121 |     def encode(self, text):
122 |         bpe_tokens = []
123 |         text = whitespace_clean(basic_clean(text)).lower()
124 |         for token in re.findall(self.pat, text):
125 |             token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126 |             bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127 |         return bpe_tokens
128 | 
129 |     def decode(self, tokens):
130 |         text = ''.join([self.decoder[token] for token in tokens])
131 |         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132 |         return text
133 | 
--------------------------------------------------------------------------------
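
Note: the dataset code in `utils/dataset.py` above calls a `tokenize(sent, word_length, truncate)` helper built on this `SimpleTokenizer`; that helper is defined elsewhere in the repository. The snippet below is only a minimal sketch of how such a helper could pad and truncate token ids so that `(word_vec != 0)` works as the padding mask. The function name `tokenize_sketch`, the context length of 17, and the example sentence are illustrative assumptions, not the repository's actual defaults.

```python
# Minimal usage sketch (not part of the repository). Assumes it is run from the
# repo root so that utils/simple_tokenizer.py and the BPE vocab file can be found.
import torch
from utils.simple_tokenizer import SimpleTokenizer

_tokenizer = SimpleTokenizer()

def tokenize_sketch(text: str, context_length: int = 17) -> torch.Tensor:
    """Encode one referring expression into a fixed-length id tensor (0 = padding)."""
    sot = _tokenizer.encoder["<|startoftext|>"]
    eot = _tokenizer.encoder["<|endoftext|>"]
    tokens = [sot] + _tokenizer.encode(text) + [eot]
    tokens = tokens[:context_length]        # naive truncation of overly long expressions
    ids = torch.zeros(context_length, dtype=torch.long)
    ids[:len(tokens)] = torch.tensor(tokens)
    return ids

word_vec = tokenize_sketch("the man in the red jacket on the left")
pad_mask = (word_vec != 0).float()          # 1 for real tokens, 0 for padding
```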