├── .gitignore
├── LICENSE
├── README.md
├── bert
│   ├── activations.py
│   ├── configuration_bert.py
│   ├── configuration_utils.py
│   ├── file_utils.py
│   ├── generation_utils.py
│   ├── modeling_bert.py
│   ├── modeling_utils.py
│   ├── tokenization_bert.py
│   ├── tokenization_utils.py
│   └── tokenization_utils_base.py
├── config
│   ├── config.yaml
│   └── open.yaml
├── data
│   ├── READEME.md
│   └── coco_id.json
├── engine
│   ├── __init__.py
│   └── engine.py
├── image
│   └── framework.jpg
├── model
│   ├── __init__.py
│   ├── backbone.py
│   ├── layers.py
│   ├── mmcv_custom
│   │   ├── __init__.py
│   │   └── checkpoint.py
│   └── segmenter.py
├── requirements.txt
├── test.py
├── tools
│   ├── data_process.py
│   ├── folder2lmdb.py
│   ├── latency.py
│   ├── prepare_datasets.md
│   └── refer.py
├── train.py
└── utils
    ├── __init__.py
    ├── box_ops.py
    ├── bpe_simple_vocab_16e6.txt.gz
    ├── config.py
    ├── dataset.py
    ├── dataset_open.py
    ├── misc.py
    └── simple_tokenizer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | /exp
2 | /wandb/**
3 | **/__pycache__
4 | /train_open.py
5 | /.vscode
6 | config/config.yaml
7 | config/open.yaml
8 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Toneyaya
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CGFormer
2 | The official PyTorch implementation of the CVPR 2023 paper "Contrastive Grouping with Transformer for Referring Image Segmentation".
3 |
4 | This paper first introduces learnable query tokens to represent objects and then alternately queries linguistic features and groups visual features into the query tokens for object-aware cross-modal reasoning. CGFormer achieves cross-level interaction by jointly updating the query tokens and decoding masks in every two consecutive layers. In addition, we introduce new splits on datasets for evaluating generalization for referring image segmentation models.
5 |
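As a rough, self-contained illustration of this idea (a toy sketch with made-up layer names and dimensions, not the code in `model/`), learnable query tokens can first query the linguistic features and then group visual features, so that the grouping attention over spatial positions can be decoded into a mask:

```
import torch
import torch.nn as nn

class QueryGrouping(nn.Module):
    """Toy sketch: query tokens attend to text, then group visual features."""
    def __init__(self, dim=512, num_queries=2, num_heads=8):
        super().__init__()
        self.queries = nn.Parameter(torch.randn(num_queries, dim))
        self.text_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.vis_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, text_feats, vis_feats):
        # text_feats: (B, L, C) linguistic features, vis_feats: (B, HW, C) visual features
        q = self.queries.unsqueeze(0).expand(text_feats.size(0), -1, -1)
        q, _ = self.text_attn(q, text_feats, text_feats)        # query linguistic features
        q, group_attn = self.vis_attn(q, vis_feats, vis_feats)  # group visual features into the queries
        return q, group_attn                                    # group_attn can be decoded into a mask

tokens, attn = QueryGrouping()(torch.randn(2, 20, 512), torch.randn(2, 900, 512))
print(tokens.shape, attn.shape)  # (2, 2, 512) and (2, 2, 900); 900 = 30x30 spatial positions
```

In the actual model the query tokens and decoded masks are updated jointly across consecutive layers; this snippet only shows the attend-then-group pattern.
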
6 | ## Framework
 7 |
 8 | ![framework](image/framework.jpg)
 9 |
10 |
11 | ## Preparation
12 |
13 | 1. Environment
14 |      - [PyTorch](https://pytorch.org)
15 |      - Other dependencies in `requirements.txt` (an example environment setup is sketched after this list)
16 | 2. Datasets
17 |      - Detailed instructions are in [prepare_datasets](data/READEME.md)
18 | 3. Pretrained weights
19 | - [Swin-Base-window12](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth)
20 | 4. Our checkpoints: [Hugging Face](https://huggingface.co/Toneyaya/CGFormer/tree/main)
21 |
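A minimal environment setup might look like the following (the environment name, Python version, and PyTorch build here are illustrative, not prescribed by this repository):

```
conda create -n cgformer python=3.8 -y   # example environment name and Python version
conda activate cgformer
pip install torch torchvision            # pick the PyTorch build matching your CUDA version
pip install -r requirements.txt          # remaining dependencies
```
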
22 | ## Train and Test (RIS)
23 |
24 | This implementation only supports **multi-GPU**, **DistributedDataParallel** training, which is faster and simpler; single-GPU and DataParallel training are not supported. Evaluation, by contrast, only supports single-GPU mode.
25 |
26 | To train CGFormer with 8 GPUs, run:
27 |
28 | ```
29 | python -u train.py --config config/config.yaml
30 | ```
31 |
32 | To evaluate CGFormer with 1 GPU, run:
33 | ```
34 | CUDA_VISIBLE_DEVICES=0 python -u test.py --config config/refcoco/config.yaml --opts TEST.test_split val TEST.test_lmdb path/val.lmdb TRAIN.weight path/checkpoint.pth
35 | ```
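As in the command above, the `--opts` flag overrides individual keys of the YAML config from the command line, so other splits can be evaluated without editing the file. For example, for the RefCOCO `testA` split (the lmdb and checkpoint paths are placeholders to replace with your own):

```
CUDA_VISIBLE_DEVICES=0 python -u test.py --config config/refcoco/config.yaml --opts TEST.test_split testA TEST.test_lmdb path/testA.lmdb TRAIN.weight path/checkpoint.pth
```
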
36 | ## License
37 |
38 | This project is under the MIT license. See [LICENSE](LICENSE) for details.
39 |
40 |
41 | ## Citation
42 | If you find our work useful in your research, please consider citing:
43 | ```
44 | @InProceedings{Tang_2023_CVPR,
45 | author = {Tang, Jiajin and Zheng, Ge and Shi, Cheng and Yang, Sibei},
46 | title = {Contrastive Grouping With Transformer for Referring Image Segmentation},
47 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
48 | month = {June},
49 | year = {2023},
50 | pages = {23570-23580}
51 | }
52 | ```
53 |
54 | Many thanks to these excellent open-source projects:
55 | [CRIS](https://github.com/DerrickWang005/CRIS.pytorch/tree/master) and [LAVT](https://github.com/yz93/LAVT-RIS).
--------------------------------------------------------------------------------
/bert/activations.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import math
3 |
4 | import torch
5 | import torch.nn.functional as F
6 |
7 |
8 | logger = logging.getLogger(__name__)
9 |
10 |
11 | def swish(x):
12 | return x * torch.sigmoid(x)
13 |
14 |
15 | def _gelu_python(x):
16 | """ Original Implementation of the gelu activation function in Google Bert repo when initially created.
17 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
18 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
19 | This is now written in C in torch.nn.functional
20 | Also see https://arxiv.org/abs/1606.08415
21 | """
22 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
23 |
24 |
25 | def gelu_new(x):
26 | """ Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
27 | Also see https://arxiv.org/abs/1606.08415
28 | """
29 | return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
30 |
31 |
32 | if tuple(int(v) for v in torch.__version__.split(".")[:2]) < (1, 4):  # numeric compare; a plain string compare misorders versions such as "1.10"
33 | gelu = _gelu_python
34 | else:
35 | gelu = F.gelu
36 |
37 |
38 | def gelu_fast(x):
39 | return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
40 |
41 |
42 | ACT2FN = {
43 | "relu": F.relu,
44 | "swish": swish,
45 | "gelu": gelu,
46 | "tanh": torch.tanh,
47 | "gelu_new": gelu_new,
48 | "gelu_fast": gelu_fast,
49 | }
50 |
51 |
52 | def get_activation(activation_string):
53 | if activation_string in ACT2FN:
54 | return ACT2FN[activation_string]
55 | else:
56 | raise KeyError("function {} not found in ACT2FN mapping {}".format(activation_string, list(ACT2FN.keys())))
57 |
--------------------------------------------------------------------------------
/bert/configuration_bert.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ BERT model configuration """
17 |
18 |
19 | import logging
20 |
21 | from .configuration_utils import PretrainedConfig
22 |
23 |
24 | logger = logging.getLogger(__name__)
25 |
26 | BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
27 | "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
28 | "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
29 | "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
30 | "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json",
31 | "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json",
32 | "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json",
33 | "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json",
34 | "bert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json",
35 | "bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json",
36 | "bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json",
37 | "bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json",
38 | "bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json",
39 | "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json",
40 | "bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-config.json",
41 | "bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-config.json",
42 | "cl-tohoku/bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese/config.json",
43 | "cl-tohoku/bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking/config.json",
44 | "cl-tohoku/bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char/config.json",
45 | "cl-tohoku/bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking/config.json",
46 | "TurkuNLP/bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/config.json",
47 | "TurkuNLP/bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/config.json",
48 | "wietsedv/bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/config.json",
49 | # See all BERT models at https://huggingface.co/models?filter=bert
50 | }
51 |
52 |
53 | class BertConfig(PretrainedConfig):
54 | r"""
55 | This is the configuration class to store the configuration of a :class:`~transformers.BertModel`.
56 |     It is used to instantiate a BERT model according to the specified arguments, defining the model
57 | architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
58 |     the BERT `bert-base-uncased <https://huggingface.co/bert-base-uncased>`__ architecture.
59 |
60 | Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used
61 | to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig`
62 | for more information.
63 |
64 |
65 | Args:
66 | vocab_size (:obj:`int`, optional, defaults to 30522):
67 |             Vocabulary size of the BERT model. Defines the number of different tokens that
68 | can be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.BertModel`.
69 | hidden_size (:obj:`int`, optional, defaults to 768):
70 | Dimensionality of the encoder layers and the pooler layer.
71 | num_hidden_layers (:obj:`int`, optional, defaults to 12):
72 | Number of hidden layers in the Transformer encoder.
73 | num_attention_heads (:obj:`int`, optional, defaults to 12):
74 | Number of attention heads for each attention layer in the Transformer encoder.
75 | intermediate_size (:obj:`int`, optional, defaults to 3072):
76 | Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
77 | hidden_act (:obj:`str` or :obj:`function`, optional, defaults to "gelu"):
78 | The non-linear activation function (function or string) in the encoder and pooler.
79 | If string, "gelu", "relu", "swish" and "gelu_new" are supported.
80 | hidden_dropout_prob (:obj:`float`, optional, defaults to 0.1):
81 |             The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
82 | attention_probs_dropout_prob (:obj:`float`, optional, defaults to 0.1):
83 | The dropout ratio for the attention probabilities.
84 | max_position_embeddings (:obj:`int`, optional, defaults to 512):
85 | The maximum sequence length that this model might ever be used with.
86 | Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
87 | type_vocab_size (:obj:`int`, optional, defaults to 2):
88 | The vocabulary size of the `token_type_ids` passed into :class:`~transformers.BertModel`.
89 | initializer_range (:obj:`float`, optional, defaults to 0.02):
90 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
91 | layer_norm_eps (:obj:`float`, optional, defaults to 1e-12):
92 | The epsilon used by the layer normalization layers.
93 | gradient_checkpointing (:obj:`bool`, optional, defaults to False):
94 | If True, use gradient checkpointing to save memory at the expense of slower backward pass.
95 |
96 | Example::
97 |
98 | >>> from transformers import BertModel, BertConfig
99 |
100 | >>> # Initializing a BERT bert-base-uncased style configuration
101 | >>> configuration = BertConfig()
102 |
103 | >>> # Initializing a model from the bert-base-uncased style configuration
104 | >>> model = BertModel(configuration)
105 |
106 | >>> # Accessing the model configuration
107 | >>> configuration = model.config
108 | """
109 | model_type = "bert"
110 |
111 | def __init__(
112 | self,
113 | vocab_size=30522,
114 | hidden_size=768,
115 | num_hidden_layers=12,
116 | num_attention_heads=12,
117 | intermediate_size=3072,
118 | hidden_act="gelu",
119 | hidden_dropout_prob=0.1,
120 | attention_probs_dropout_prob=0.1,
121 | max_position_embeddings=512,
122 | type_vocab_size=2,
123 | initializer_range=0.02,
124 | layer_norm_eps=1e-12,
125 | pad_token_id=0,
126 | gradient_checkpointing=False,
127 | **kwargs
128 | ):
129 | super().__init__(pad_token_id=pad_token_id, **kwargs)
130 |
131 | self.vocab_size = vocab_size
132 | self.hidden_size = hidden_size
133 | self.num_hidden_layers = num_hidden_layers
134 | self.num_attention_heads = num_attention_heads
135 | self.hidden_act = hidden_act
136 | self.intermediate_size = intermediate_size
137 | self.hidden_dropout_prob = hidden_dropout_prob
138 | self.attention_probs_dropout_prob = attention_probs_dropout_prob
139 | self.max_position_embeddings = max_position_embeddings
140 | self.type_vocab_size = type_vocab_size
141 | self.initializer_range = initializer_range
142 | self.layer_norm_eps = layer_norm_eps
143 | self.gradient_checkpointing = gradient_checkpointing
144 |
--------------------------------------------------------------------------------
/bert/configuration_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ Configuration base class and utilities."""
17 |
18 |
19 | import copy
20 | import json
21 | import logging
22 | import os
23 | from typing import Dict, Tuple
24 |
25 | from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url
26 |
27 |
28 | logger = logging.getLogger(__name__)
29 |
30 |
31 | class PretrainedConfig(object):
32 | r""" Base class for all configuration classes.
33 | Handles a few parameters common to all models' configurations as well as methods for loading/downloading/saving configurations.
34 |
35 | Note:
36 | A configuration file can be loaded and saved to disk. Loading the configuration file and using this file to initialize a model does **not** load the model weights.
37 | It only affects the model's configuration.
38 |
39 | Class attributes (overridden by derived classes):
40 | - ``model_type``: a string that identifies the model type, that we serialize into the JSON file, and that we use to recreate the correct object in :class:`~transformers.AutoConfig`.
41 |
42 | Args:
43 | finetuning_task (:obj:`string` or :obj:`None`, `optional`, defaults to :obj:`None`):
44 | Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow or PyTorch) checkpoint.
45 | num_labels (:obj:`int`, `optional`, defaults to `2`):
46 | Number of classes to use when the model is a classification model (sequences/tokens)
47 | output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`False`):
48 |                 Whether the model should return all hidden states.
49 | output_attentions (:obj:`bool`, `optional`, defaults to :obj:`False`):
50 |                 Whether the model should return all attentions.
51 | torchscript (:obj:`bool`, `optional`, defaults to :obj:`False`):
52 |                 Whether the model is used with TorchScript (for PyTorch models).
53 | """
54 | model_type: str = ""
55 |
56 | def __init__(self, **kwargs):
57 | # Attributes with defaults
58 | self.output_hidden_states = kwargs.pop("output_hidden_states", False)
59 | self.output_attentions = kwargs.pop("output_attentions", False)
60 | self.use_cache = kwargs.pop("use_cache", True) # Not used by all models
61 | self.torchscript = kwargs.pop("torchscript", False) # Only used by PyTorch models
62 | self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
63 | self.pruned_heads = kwargs.pop("pruned_heads", {})
64 |
65 |         # `is_decoder` is used in encoder-decoder models to differentiate the encoder from the decoder
66 | self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
67 | self.is_decoder = kwargs.pop("is_decoder", False)
68 |
69 | # Parameters for sequence generation
70 | self.max_length = kwargs.pop("max_length", 20)
71 | self.min_length = kwargs.pop("min_length", 0)
72 | self.do_sample = kwargs.pop("do_sample", False)
73 | self.early_stopping = kwargs.pop("early_stopping", False)
74 | self.num_beams = kwargs.pop("num_beams", 1)
75 | self.temperature = kwargs.pop("temperature", 1.0)
76 | self.top_k = kwargs.pop("top_k", 50)
77 | self.top_p = kwargs.pop("top_p", 1.0)
78 | self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
79 | self.length_penalty = kwargs.pop("length_penalty", 1.0)
80 | self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
81 | self.bad_words_ids = kwargs.pop("bad_words_ids", None)
82 | self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
83 |
84 | # Fine-tuning task arguments
85 | self.architectures = kwargs.pop("architectures", None)
86 | self.finetuning_task = kwargs.pop("finetuning_task", None)
87 | self.id2label = kwargs.pop("id2label", None)
88 | self.label2id = kwargs.pop("label2id", None)
89 | if self.id2label is not None:
90 | kwargs.pop("num_labels", None)
91 | self.id2label = dict((int(key), value) for key, value in self.id2label.items())
92 | # Keys are always strings in JSON so convert ids to int here.
93 | else:
94 | self.num_labels = kwargs.pop("num_labels", 2)
95 |
96 | # Tokenizer arguments TODO: eventually tokenizer and models should share the same config
97 | self.prefix = kwargs.pop("prefix", None)
98 | self.bos_token_id = kwargs.pop("bos_token_id", None)
99 | self.pad_token_id = kwargs.pop("pad_token_id", None)
100 | self.eos_token_id = kwargs.pop("eos_token_id", None)
101 | self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
102 |
103 | # task specific arguments
104 | self.task_specific_params = kwargs.pop("task_specific_params", None)
105 |
106 | # TPU arguments
107 | self.xla_device = kwargs.pop("xla_device", None)
108 |
109 | # Additional attributes without default values
110 | for key, value in kwargs.items():
111 | try:
112 | setattr(self, key, value)
113 | except AttributeError as err:
114 | logger.error("Can't set {} with value {} for {}".format(key, value, self))
115 | raise err
116 |
117 | @property
118 | def num_labels(self):
119 | return len(self.id2label)
120 |
121 | @num_labels.setter
122 | def num_labels(self, num_labels):
123 | self.id2label = {i: "LABEL_{}".format(i) for i in range(num_labels)}
124 | self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
125 |
126 | def save_pretrained(self, save_directory):
127 | """
128 | Save a configuration object to the directory `save_directory`, so that it
129 | can be re-loaded using the :func:`~transformers.PretrainedConfig.from_pretrained` class method.
130 |
131 | Args:
132 | save_directory (:obj:`string`):
133 | Directory where the configuration JSON file will be saved.
134 | """
135 | if os.path.isfile(save_directory):
136 | raise AssertionError("Provided path ({}) should be a directory, not a file".format(save_directory))
137 | os.makedirs(save_directory, exist_ok=True)
138 | # If we save using the predefined names, we can load using `from_pretrained`
139 | output_config_file = os.path.join(save_directory, CONFIG_NAME)
140 |
141 | self.to_json_file(output_config_file, use_diff=True)
142 | logger.info("Configuration saved in {}".format(output_config_file))
143 |
144 | @classmethod
145 | def from_pretrained(cls, pretrained_model_name_or_path, **kwargs) -> "PretrainedConfig":
146 | r"""
147 |
148 | Instantiate a :class:`~transformers.PretrainedConfig` (or a derived class) from a pre-trained model configuration.
149 |
150 | Args:
151 | pretrained_model_name_or_path (:obj:`string`):
152 | either:
153 | - a string with the `shortcut name` of a pre-trained model configuration to load from cache or
154 | download, e.g.: ``bert-base-uncased``.
155 | - a string with the `identifier name` of a pre-trained model configuration that was user-uploaded to
156 | our S3, e.g.: ``dbmdz/bert-base-german-cased``.
157 | - a path to a `directory` containing a configuration file saved using the
158 | :func:`~transformers.PretrainedConfig.save_pretrained` method, e.g.: ``./my_model_directory/``.
159 | - a path or url to a saved configuration JSON `file`, e.g.:
160 | ``./my_model_directory/configuration.json``.
161 | cache_dir (:obj:`string`, `optional`):
162 | Path to a directory in which a downloaded pre-trained model
163 | configuration should be cached if the standard cache should not be used.
164 | kwargs (:obj:`Dict[str, any]`, `optional`):
165 | The values in kwargs of any keys which are configuration attributes will be used to override the loaded
166 | values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is
167 | controlled by the `return_unused_kwargs` keyword parameter.
168 | force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
169 | Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
170 | resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
171 |                 Do not delete an incompletely received file. Attempt to resume the download if such a file exists.
172 | proxies (:obj:`Dict`, `optional`):
173 | A dictionary of proxy servers to use by protocol or endpoint, e.g.:
174 | :obj:`{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.`
175 | The proxies are used on each request.
176 | return_unused_kwargs: (`optional`) bool:
177 | If False, then this function returns just the final configuration object.
178 | If True, then this functions returns a :obj:`Tuple(config, unused_kwargs)` where `unused_kwargs` is a
179 |                 dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e. the part
180 | of kwargs which has not been used to update `config` and is otherwise ignored.
181 |
182 | Returns:
183 | :class:`PretrainedConfig`: An instance of a configuration object
184 |
185 | Examples::
186 |
187 | # We can't instantiate directly the base class `PretrainedConfig` so let's show the examples on a
188 | # derived class: BertConfig
189 | config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
190 | config = BertConfig.from_pretrained('./test/saved_model/') # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')`
191 | config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json')
192 |             config = BertConfig.from_pretrained('bert-base-uncased', output_attentions=True, foo=False)
193 |             assert config.output_attentions == True
194 |             config, unused_kwargs = BertConfig.from_pretrained('bert-base-uncased', output_attentions=True,
195 |                                                                foo=False, return_unused_kwargs=True)
196 |             assert config.output_attentions == True
197 | assert unused_kwargs == {'foo': False}
198 |
199 | """
200 | config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
201 | return cls.from_dict(config_dict, **kwargs)
202 |
203 | @classmethod
204 | def get_config_dict(cls, pretrained_model_name_or_path: str, **kwargs) -> Tuple[Dict, Dict]:
205 | """
206 | From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used
207 | for instantiating a Config using `from_dict`.
208 |
209 | Parameters:
210 | pretrained_model_name_or_path (:obj:`string`):
211 | The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
212 |
213 | Returns:
214 | :obj:`Tuple[Dict, Dict]`: The dictionary that will be used to instantiate the configuration object.
215 |
216 | """
217 | cache_dir = kwargs.pop("cache_dir", None)
218 | force_download = kwargs.pop("force_download", False)
219 | resume_download = kwargs.pop("resume_download", False)
220 | proxies = kwargs.pop("proxies", None)
221 | local_files_only = kwargs.pop("local_files_only", False)
222 |
223 | if os.path.isdir(pretrained_model_name_or_path):
224 | config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
225 | elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
226 | config_file = pretrained_model_name_or_path
227 | else:
228 | config_file = hf_bucket_url(pretrained_model_name_or_path, filename=CONFIG_NAME, use_cdn=False)
229 |
230 | try:
231 | # Load from URL or cache if already cached
232 | resolved_config_file = cached_path(
233 | config_file,
234 | cache_dir=cache_dir,
235 | force_download=force_download,
236 | proxies=proxies,
237 | resume_download=resume_download,
238 | local_files_only=local_files_only,
239 | )
240 | # Load config dict
241 | if resolved_config_file is None:
242 | raise EnvironmentError
243 | config_dict = cls._dict_from_json_file(resolved_config_file)
244 |
245 | except EnvironmentError:
246 | msg = (
247 | f"Can't load config for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
248 | f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
249 | f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a {CONFIG_NAME} file\n\n"
250 | )
251 | raise EnvironmentError(msg)
252 |
253 | except json.JSONDecodeError:
254 | msg = (
255 | "Couldn't reach server at '{}' to download configuration file or "
256 | "configuration file is not a valid JSON file. "
257 | "Please check network or file content here: {}.".format(config_file, resolved_config_file)
258 | )
259 | raise EnvironmentError(msg)
260 |
261 | if resolved_config_file == config_file:
262 | logger.info("loading configuration file {}".format(config_file))
263 | else:
264 | logger.info("loading configuration file {} from cache at {}".format(config_file, resolved_config_file))
265 |
266 | return config_dict, kwargs
267 |
268 | @classmethod
269 | def from_dict(cls, config_dict: Dict, **kwargs) -> "PretrainedConfig":
270 | """
271 | Constructs a `Config` from a Python dictionary of parameters.
272 |
273 | Args:
274 | config_dict (:obj:`Dict[str, any]`):
275 | Dictionary that will be used to instantiate the configuration object. Such a dictionary can be retrieved
276 | from a pre-trained checkpoint by leveraging the :func:`~transformers.PretrainedConfig.get_config_dict`
277 | method.
278 | kwargs (:obj:`Dict[str, any]`):
279 | Additional parameters from which to initialize the configuration object.
280 |
281 | Returns:
282 | :class:`PretrainedConfig`: An instance of a configuration object
283 | """
284 | return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
285 |
286 | config = cls(**config_dict)
287 |
288 | if hasattr(config, "pruned_heads"):
289 | config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())
290 |
291 | # Update config with kwargs if needed
292 | to_remove = []
293 | for key, value in kwargs.items():
294 | if hasattr(config, key):
295 | setattr(config, key, value)
296 | to_remove.append(key)
297 | for key in to_remove:
298 | kwargs.pop(key, None)
299 |
300 | logger.info("Model config %s", str(config))
301 | if return_unused_kwargs:
302 | return config, kwargs
303 | else:
304 | return config
305 |
306 | @classmethod
307 | def from_json_file(cls, json_file: str) -> "PretrainedConfig":
308 | """
309 | Constructs a `Config` from the path to a json file of parameters.
310 |
311 | Args:
312 | json_file (:obj:`string`):
313 | Path to the JSON file containing the parameters.
314 |
315 | Returns:
316 | :class:`PretrainedConfig`: An instance of a configuration object
317 |
318 | """
319 | config_dict = cls._dict_from_json_file(json_file)
320 | return cls(**config_dict)
321 |
322 | @classmethod
323 | def _dict_from_json_file(cls, json_file: str):
324 | with open(json_file, "r", encoding="utf-8") as reader:
325 | text = reader.read()
326 | return json.loads(text)
327 |
328 | def __eq__(self, other):
329 | return self.__dict__ == other.__dict__
330 |
331 | def __repr__(self):
332 | return "{} {}".format(self.__class__.__name__, self.to_json_string())
333 |
334 | def to_diff_dict(self):
335 | """
336 | Removes all attributes from config which correspond to the default
337 | config attributes for better readability and serializes to a Python
338 | dictionary.
339 |
340 | Returns:
341 | :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
342 | """
343 | config_dict = self.to_dict()
344 |
345 | # get the default config dict
346 | default_config_dict = PretrainedConfig().to_dict()
347 |
348 | serializable_config_dict = {}
349 |
350 | # only serialize values that differ from the default config
351 | for key, value in config_dict.items():
352 | if key not in default_config_dict or value != default_config_dict[key]:
353 | serializable_config_dict[key] = value
354 |
355 | return serializable_config_dict
356 |
357 | def to_dict(self):
358 | """
359 | Serializes this instance to a Python dictionary.
360 |
361 | Returns:
362 | :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
363 | """
364 | output = copy.deepcopy(self.__dict__)
365 | if hasattr(self.__class__, "model_type"):
366 | output["model_type"] = self.__class__.model_type
367 | return output
368 |
369 | def to_json_string(self, use_diff=True):
370 | """
371 | Serializes this instance to a JSON string.
372 |
373 | Args:
374 | use_diff (:obj:`bool`):
375 | If set to True, only the difference between the config instance and the default PretrainedConfig() is serialized to JSON string.
376 |
377 | Returns:
378 | :obj:`string`: String containing all the attributes that make up this configuration instance in JSON format.
379 | """
380 | if use_diff is True:
381 | config_dict = self.to_diff_dict()
382 | else:
383 | config_dict = self.to_dict()
384 | return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
385 |
386 | def to_json_file(self, json_file_path, use_diff=True):
387 | """
388 | Save this instance to a json file.
389 |
390 | Args:
391 | json_file_path (:obj:`string`):
392 | Path to the JSON file in which this configuration instance's parameters will be saved.
393 | use_diff (:obj:`bool`):
394 | If set to True, only the difference between the config instance and the default PretrainedConfig() is serialized to JSON file.
395 | """
396 | with open(json_file_path, "w", encoding="utf-8") as writer:
397 | writer.write(self.to_json_string(use_diff=use_diff))
398 |
399 | def update(self, config_dict: Dict):
400 | """
401 | Updates attributes of this class
402 | with attributes from `config_dict`.
403 |
404 | Args:
405 |             config_dict (:obj:`Dict[str, any]`): Dictionary of attributes that shall be updated for this class.
406 | """
407 | for key, value in config_dict.items():
408 | setattr(self, key, value)
409 |
--------------------------------------------------------------------------------
/bert/tokenization_bert.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes."""
16 |
17 |
18 | import collections
19 | import logging
20 | import os
21 | import unicodedata
22 | from typing import List, Optional
23 |
24 | from .tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
25 |
26 |
27 | logger = logging.getLogger(__name__)
28 |
29 | VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
30 |
31 | PRETRAINED_VOCAB_FILES_MAP = {
32 | "vocab_file": {
33 | "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
34 | "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
35 | "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
36 | "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
37 | "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
38 | "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
39 | "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
40 | "bert-base-german-cased": "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt",
41 | "bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt",
42 | "bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt",
43 | "bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt",
44 | "bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt",
45 | "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt",
46 | "bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-vocab.txt",
47 | "bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-vocab.txt",
48 | "TurkuNLP/bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/vocab.txt",
49 | "TurkuNLP/bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/vocab.txt",
50 | "wietsedv/bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/vocab.txt",
51 | }
52 | }
53 |
54 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
55 | "bert-base-uncased": 512,
56 | "bert-large-uncased": 512,
57 | "bert-base-cased": 512,
58 | "bert-large-cased": 512,
59 | "bert-base-multilingual-uncased": 512,
60 | "bert-base-multilingual-cased": 512,
61 | "bert-base-chinese": 512,
62 | "bert-base-german-cased": 512,
63 | "bert-large-uncased-whole-word-masking": 512,
64 | "bert-large-cased-whole-word-masking": 512,
65 | "bert-large-uncased-whole-word-masking-finetuned-squad": 512,
66 | "bert-large-cased-whole-word-masking-finetuned-squad": 512,
67 | "bert-base-cased-finetuned-mrpc": 512,
68 | "bert-base-german-dbmdz-cased": 512,
69 | "bert-base-german-dbmdz-uncased": 512,
70 | "TurkuNLP/bert-base-finnish-cased-v1": 512,
71 | "TurkuNLP/bert-base-finnish-uncased-v1": 512,
72 | "wietsedv/bert-base-dutch-cased": 512,
73 | }
74 |
75 | PRETRAINED_INIT_CONFIGURATION = {
76 | "bert-base-uncased": {"do_lower_case": True},
77 | "bert-large-uncased": {"do_lower_case": True},
78 | "bert-base-cased": {"do_lower_case": False},
79 | "bert-large-cased": {"do_lower_case": False},
80 | "bert-base-multilingual-uncased": {"do_lower_case": True},
81 | "bert-base-multilingual-cased": {"do_lower_case": False},
82 | "bert-base-chinese": {"do_lower_case": False},
83 | "bert-base-german-cased": {"do_lower_case": False},
84 | "bert-large-uncased-whole-word-masking": {"do_lower_case": True},
85 | "bert-large-cased-whole-word-masking": {"do_lower_case": False},
86 | "bert-large-uncased-whole-word-masking-finetuned-squad": {"do_lower_case": True},
87 | "bert-large-cased-whole-word-masking-finetuned-squad": {"do_lower_case": False},
88 | "bert-base-cased-finetuned-mrpc": {"do_lower_case": False},
89 | "bert-base-german-dbmdz-cased": {"do_lower_case": False},
90 | "bert-base-german-dbmdz-uncased": {"do_lower_case": True},
91 | "TurkuNLP/bert-base-finnish-cased-v1": {"do_lower_case": False},
92 | "TurkuNLP/bert-base-finnish-uncased-v1": {"do_lower_case": True},
93 | "wietsedv/bert-base-dutch-cased": {"do_lower_case": False},
94 | }
95 |
96 |
97 | def load_vocab(vocab_file):
98 | """Loads a vocabulary file into a dictionary."""
99 | vocab = collections.OrderedDict()
100 | with open(vocab_file, "r", encoding="utf-8") as reader:
101 | tokens = reader.readlines()
102 | for index, token in enumerate(tokens):
103 | token = token.rstrip("\n")
104 | vocab[token] = index
105 | return vocab
106 |
107 |
108 | def whitespace_tokenize(text):
109 | """Runs basic whitespace cleaning and splitting on a piece of text."""
110 | text = text.strip()
111 | if not text:
112 | return []
113 | tokens = text.split()
114 | return tokens
115 |
116 |
117 | class BertTokenizer(PreTrainedTokenizer):
118 | r"""
119 | Constructs a BERT tokenizer. Based on WordPiece.
120 |
121 | This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
122 | should refer to the superclass for more information regarding methods.
123 |
124 | Args:
125 | vocab_file (:obj:`string`):
126 | File containing the vocabulary.
127 | do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
128 | Whether to lowercase the input when tokenizing.
129 | do_basic_tokenize (:obj:`bool`, `optional`, defaults to :obj:`True`):
130 | Whether to do basic tokenization before WordPiece.
131 | never_split (:obj:`Iterable`, `optional`, defaults to :obj:`None`):
132 | Collection of tokens which will never be split during tokenization. Only has an effect when
133 | :obj:`do_basic_tokenize=True`
134 | unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
135 | The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
136 | token instead.
137 | sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
138 | The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
139 | for sequence classification or for a text and a question for question answering.
140 | It is also used as the last token of a sequence built with special tokens.
141 | pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
142 | The token used for padding, for example when batching sequences of different lengths.
143 | cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
144 | The classifier token which is used when doing sequence classification (classification of the whole
145 | sequence instead of per-token classification). It is the first token of the sequence when built with
146 | special tokens.
147 | mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
148 | The token used for masking values. This is the token used when training this model with masked language
149 | modeling. This is the token which the model will try to predict.
150 | tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
151 | Whether to tokenize Chinese characters.
152 | This should likely be deactivated for Japanese:
153 | see: https://github.com/huggingface/transformers/issues/328
154 | """
155 |
156 | vocab_files_names = VOCAB_FILES_NAMES
157 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
158 | pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
159 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
160 |
161 | def __init__(
162 | self,
163 | vocab_file,
164 | do_lower_case=True,
165 | do_basic_tokenize=True,
166 | never_split=None,
167 | unk_token="[UNK]",
168 | sep_token="[SEP]",
169 | pad_token="[PAD]",
170 | cls_token="[CLS]",
171 | mask_token="[MASK]",
172 | tokenize_chinese_chars=True,
173 | **kwargs
174 | ):
175 | super().__init__(
176 | unk_token=unk_token,
177 | sep_token=sep_token,
178 | pad_token=pad_token,
179 | cls_token=cls_token,
180 | mask_token=mask_token,
181 | **kwargs,
182 | )
183 |
184 | if not os.path.isfile(vocab_file):
185 | raise ValueError(
186 | "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
187 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)
188 | )
189 | self.vocab = load_vocab(vocab_file)
190 | self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
191 | self.do_basic_tokenize = do_basic_tokenize
192 | if do_basic_tokenize:
193 | self.basic_tokenizer = BasicTokenizer(
194 | do_lower_case=do_lower_case, never_split=never_split, tokenize_chinese_chars=tokenize_chinese_chars
195 | )
196 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)
197 |
198 | @property
199 | def vocab_size(self):
200 | return len(self.vocab)
201 |
202 | def get_vocab(self):
203 | return dict(self.vocab, **self.added_tokens_encoder)
204 |
205 | def _tokenize(self, text):
206 | split_tokens = []
207 | if self.do_basic_tokenize:
208 | for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
209 |
210 | # If the token is part of the never_split set
211 | if token in self.basic_tokenizer.never_split:
212 | split_tokens.append(token)
213 | else:
214 | split_tokens += self.wordpiece_tokenizer.tokenize(token)
215 | else:
216 | split_tokens = self.wordpiece_tokenizer.tokenize(text)
217 | return split_tokens
218 |
219 | def _convert_token_to_id(self, token):
220 | """ Converts a token (str) in an id using the vocab. """
221 | return self.vocab.get(token, self.vocab.get(self.unk_token))
222 |
223 | def _convert_id_to_token(self, index):
224 | """Converts an index (integer) in a token (str) using the vocab."""
225 | return self.ids_to_tokens.get(index, self.unk_token)
226 |
227 | def convert_tokens_to_string(self, tokens):
228 | """ Converts a sequence of tokens (string) in a single string. """
229 | out_string = " ".join(tokens).replace(" ##", "").strip()
230 | return out_string
231 |
232 | def build_inputs_with_special_tokens(
233 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
234 | ) -> List[int]:
235 | """
236 |         Build model inputs from a sequence or a pair of sequences for sequence classification tasks
237 | by concatenating and adding special tokens.
238 | A BERT sequence has the following format:
239 |
240 | - single sequence: ``[CLS] X [SEP]``
241 | - pair of sequences: ``[CLS] A [SEP] B [SEP]``
242 |
243 | Args:
244 | token_ids_0 (:obj:`List[int]`):
245 | List of IDs to which the special tokens will be added
246 | token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
247 | Optional second list of IDs for sequence pairs.
248 |
249 | Returns:
250 | :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
251 | """
252 | if token_ids_1 is None:
253 | return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
254 | cls = [self.cls_token_id]
255 | sep = [self.sep_token_id]
256 | return cls + token_ids_0 + sep + token_ids_1 + sep
257 |
258 | def get_special_tokens_mask(
259 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
260 | ) -> List[int]:
261 | """
262 | Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
263 | special tokens using the tokenizer ``prepare_for_model`` method.
264 |
265 | Args:
266 | token_ids_0 (:obj:`List[int]`):
267 | List of ids.
268 | token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
269 | Optional second list of IDs for sequence pairs.
270 | already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
271 | Set to True if the token list is already formatted with special tokens for the model
272 |
273 | Returns:
274 | :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
275 | """
276 |
277 | if already_has_special_tokens:
278 | if token_ids_1 is not None:
279 | raise ValueError(
280 | "You should not supply a second sequence if the provided sequence of "
281 | "ids is already formated with special tokens for the model."
282 | )
283 | return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
284 |
285 | if token_ids_1 is not None:
286 | return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
287 | return [1] + ([0] * len(token_ids_0)) + [1]
288 |
289 | def create_token_type_ids_from_sequences(
290 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
291 | ) -> List[int]:
292 | """
293 | Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
294 | A BERT sequence pair mask has the following format:
295 |
296 | ::
297 |
298 | 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
299 | | first sequence | second sequence |
300 |
301 | if token_ids_1 is None, only returns the first portion of the mask (0's).
302 |
303 | Args:
304 | token_ids_0 (:obj:`List[int]`):
305 | List of ids.
306 | token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
307 | Optional second list of IDs for sequence pairs.
308 |
309 | Returns:
310 | :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
311 | sequence(s).
312 | """
313 | sep = [self.sep_token_id]
314 | cls = [self.cls_token_id]
315 | if token_ids_1 is None:
316 | return len(cls + token_ids_0 + sep) * [0]
317 | return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
318 |
319 | def save_vocabulary(self, vocab_path):
320 | """
321 |         Save the tokenizer vocabulary (the WordPiece vocab file) to a directory or file path.
322 |
323 | Args:
324 | vocab_path (:obj:`str`):
325 | The directory in which to save the vocabulary.
326 |
327 | Returns:
328 | :obj:`Tuple(str)`: Paths to the files saved.
329 | """
330 | index = 0
331 | if os.path.isdir(vocab_path):
332 | vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["vocab_file"])
333 | else:
334 | vocab_file = vocab_path
335 | with open(vocab_file, "w", encoding="utf-8") as writer:
336 | for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
337 | if index != token_index:
338 | logger.warning(
339 | "Saving vocabulary to {}: vocabulary indices are not consecutive."
340 | " Please check that the vocabulary is not corrupted!".format(vocab_file)
341 | )
342 | index = token_index
343 | writer.write(token + "\n")
344 | index += 1
345 | return (vocab_file,)
346 |
347 |
348 | class BasicTokenizer(object):
349 | """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
350 |
351 | def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True):
352 | """ Constructs a BasicTokenizer.
353 |
354 | Args:
355 | **do_lower_case**: Whether to lower case the input.
356 | **never_split**: (`optional`) list of str
357 | Kept for backward compatibility purposes.
358 | Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`)
359 |                 List of tokens not to split.
360 | **tokenize_chinese_chars**: (`optional`) boolean (default True)
361 | Whether to tokenize Chinese characters.
362 | This should likely be deactivated for Japanese:
363 | see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328
364 | """
365 | if never_split is None:
366 | never_split = []
367 | self.do_lower_case = do_lower_case
368 | self.never_split = set(never_split)
369 | self.tokenize_chinese_chars = tokenize_chinese_chars
370 |
371 | def tokenize(self, text, never_split=None):
372 | """ Basic Tokenization of a piece of text.
373 | Split on "white spaces" only, for sub-word tokenization, see WordPieceTokenizer.
374 |
375 | Args:
376 | **never_split**: (`optional`) list of str
377 | Kept for backward compatibility purposes.
378 | Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`)
379 |                 List of tokens not to split.
380 | """
381 | # union() returns a new set by concatenating the two sets.
382 | never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
383 |
384 | # This was added on November 1st, 2018 for the multilingual and Chinese
385 | # models. This is also applied to the English models now, but it doesn't
386 | # matter since the English models were not trained on any Chinese data
387 | # and generally don't have any Chinese data in them (there are Chinese
388 | # characters in the vocabulary because Wikipedia does have some Chinese
389 | # words in the English Wikipedia.).
390 | if self.tokenize_chinese_chars:
391 | text = self._tokenize_chinese_chars(text)
392 | orig_tokens = whitespace_tokenize(text)
393 | split_tokens = []
394 | for token in orig_tokens:
395 | if self.do_lower_case and token not in never_split:
396 | token = token.lower()
397 | token = self._run_strip_accents(token)
398 | split_tokens.extend(self._run_split_on_punc(token, never_split))
399 |
400 | output_tokens = whitespace_tokenize(" ".join(split_tokens))
401 | return output_tokens
402 |
403 | def _run_strip_accents(self, text):
404 | """Strips accents from a piece of text."""
405 | text = unicodedata.normalize("NFD", text)
406 | output = []
407 | for char in text:
408 | cat = unicodedata.category(char)
409 | if cat == "Mn":
410 | continue
411 | output.append(char)
412 | return "".join(output)
413 |
414 | def _run_split_on_punc(self, text, never_split=None):
415 | """Splits punctuation on a piece of text."""
416 | if never_split is not None and text in never_split:
417 | return [text]
418 | chars = list(text)
419 | i = 0
420 | start_new_word = True
421 | output = []
422 | while i < len(chars):
423 | char = chars[i]
424 | if _is_punctuation(char):
425 | output.append([char])
426 | start_new_word = True
427 | else:
428 | if start_new_word:
429 | output.append([])
430 | start_new_word = False
431 | output[-1].append(char)
432 | i += 1
433 |
434 | return ["".join(x) for x in output]
435 |
436 | def _tokenize_chinese_chars(self, text):
437 | """Adds whitespace around any CJK character."""
438 | output = []
439 | for char in text:
440 | cp = ord(char)
441 | if self._is_chinese_char(cp):
442 | output.append(" ")
443 | output.append(char)
444 | output.append(" ")
445 | else:
446 | output.append(char)
447 | return "".join(output)
448 |
449 | def _is_chinese_char(self, cp):
450 | """Checks whether CP is the codepoint of a CJK character."""
451 | # This defines a "chinese character" as anything in the CJK Unicode block:
452 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
453 | #
454 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
455 | # despite its name. The modern Korean Hangul alphabet is a different block,
456 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write
457 | # space-separated words, so they are not treated specially and handled
458 |         # like all of the other languages.
459 | if (
460 | (cp >= 0x4E00 and cp <= 0x9FFF)
461 | or (cp >= 0x3400 and cp <= 0x4DBF) #
462 | or (cp >= 0x20000 and cp <= 0x2A6DF) #
463 | or (cp >= 0x2A700 and cp <= 0x2B73F) #
464 | or (cp >= 0x2B740 and cp <= 0x2B81F) #
465 | or (cp >= 0x2B820 and cp <= 0x2CEAF) #
466 | or (cp >= 0xF900 and cp <= 0xFAFF)
467 | or (cp >= 0x2F800 and cp <= 0x2FA1F) #
468 | ): #
469 | return True
470 |
471 | return False
472 |
473 | def _clean_text(self, text):
474 | """Performs invalid character removal and whitespace cleanup on text."""
475 | output = []
476 | for char in text:
477 | cp = ord(char)
478 | if cp == 0 or cp == 0xFFFD or _is_control(char):
479 | continue
480 | if _is_whitespace(char):
481 | output.append(" ")
482 | else:
483 | output.append(char)
484 | return "".join(output)
485 |
486 |
487 | class WordpieceTokenizer(object):
488 | """Runs WordPiece tokenization."""
489 |
490 | def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
491 | self.vocab = vocab
492 | self.unk_token = unk_token
493 | self.max_input_chars_per_word = max_input_chars_per_word
494 |
495 | def tokenize(self, text):
496 | """Tokenizes a piece of text into its word pieces.
497 |
498 | This uses a greedy longest-match-first algorithm to perform tokenization
499 | using the given vocabulary.
500 |
501 | For example:
502 | input = "unaffable"
503 | output = ["un", "##aff", "##able"]
504 |
505 | Args:
506 | text: A single token or whitespace separated tokens. This should have
507 | already been passed through `BasicTokenizer`.
508 |
509 | Returns:
510 | A list of wordpiece tokens.
511 | """
512 |
513 | output_tokens = []
514 | for token in whitespace_tokenize(text):
515 | chars = list(token)
516 | if len(chars) > self.max_input_chars_per_word:
517 | output_tokens.append(self.unk_token)
518 | continue
519 |
520 | is_bad = False
521 | start = 0
522 | sub_tokens = []
523 | while start < len(chars):
524 | end = len(chars)
525 | cur_substr = None
526 | while start < end:
527 | substr = "".join(chars[start:end])
528 | if start > 0:
529 | substr = "##" + substr
530 | if substr in self.vocab:
531 | cur_substr = substr
532 | break
533 | end -= 1
534 | if cur_substr is None:
535 | is_bad = True
536 | break
537 | sub_tokens.append(cur_substr)
538 | start = end
539 |
540 | if is_bad:
541 | output_tokens.append(self.unk_token)
542 | else:
543 | output_tokens.extend(sub_tokens)
544 | return output_tokens
545 |
546 |
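547 | # ---------------------------------------------------------------------------
548 | # Illustrative usage sketch (added for clarity, not part of the original file).
549 | # The toy vocab below is a hypothetical stand-in for the bert-base-uncased
550 | # vocab that is normally loaded from disk; it only demonstrates the greedy
551 | # longest-match-first behaviour documented in `WordpieceTokenizer.tokenize`.
552 | # ---------------------------------------------------------------------------
553 | if __name__ == "__main__":
554 |     _toy_vocab = {"un": 0, "##aff": 1, "##able": 2, "[UNK]": 3}
555 |     _wp = WordpieceTokenizer(vocab=_toy_vocab, unk_token="[UNK]")
556 |     print(_wp.tokenize("unaffable"))    # -> ['un', '##aff', '##able']
557 |     print(_wp.tokenize("unaffablezz"))  # -> ['[UNK]'] (no full wordpiece cover)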
--------------------------------------------------------------------------------
/config/config.yaml:
--------------------------------------------------------------------------------
1 | DATA:
2 | dataset: refcoco
3 | train_split: train
4 | train_lmdb: path/refcoco/train.lmdb
5 | val_split: val
6 | val_lmdb: path/val.lmdb
7 | mask_root: path/masks/refcoco
8 | TRAIN:
9 | swin_type: base
10 | swin_pretrain: path/swin_base_window12.pth
11 | bert: bert-base-uncased
12 | mha: '8-8-8-8'
13 | input_size: 480
14 | word_len: 20
15 | word_dim: 768
16 | vis_dim: 512
17 | num_token: 2
18 | token_dim: 512
19 | sync_bn: True
20 | dropout: 0.
21 | fusion_drop: 0.
22 | workers: 32 # data loader workers
23 | workers_val: 8
24 | batch_size: 64 # batch size for training
25 | batch_size_val: 16 # batch size for validation during training, memory and speed tradeoff
26 | start_epoch: 0
27 | epochs: 50
28 | lr_backbone: 5.e-5
29 | lr_text_encoder: 5.e-5
30 | lr: 1.e-4
31 | weight_decay: 1.e-4
32 | amsgrad: True
33 | manual_seed:
34 | print_freq: 100
35 | exp_name: cgformer
36 | output_folder: exp/refcoco/
37 | save_freq: 1
38 | weight:
39 | resume:
40 | evaluate: True
41 | Distributed:
42 | dist_url: tcp://localhost:12345
43 | dist_backend: 'nccl'
44 | multiprocessing_distributed: True
45 | world_size: 1
46 | rank: 0
47 | TEST:
48 |   window12: True # set to True at test time if the window12 pretrained weights were used for training
49 | test_split: val
50 | test_lmdb: path/refcoco/val.lmdb
51 | visualize: False
--------------------------------------------------------------------------------
/config/open.yaml:
--------------------------------------------------------------------------------
1 | DATA:
2 | dataset: refcoco
3 | train_split: train_seen
4 | train_lmdb: path/open_lmdb/refcoco/train_seen.lmdb
5 | val_seen_split: val_seen
6 | val_seen_lmdb: path/open_lmdb/refcoco/val_seen.lmdb
7 | val_unseen_split: val_unseen
8 | val_unseen_lmdb: path/open_lmdb/refcoco/val_unseen.lmdb
9 | mask_root: path/masks/refcoco
10 | TRAIN:
11 | swin_type: base
12 | swin_pretrain: path/swin_base_patch4_window12_384_22k.pth
13 | bert: bert-base-uncased
14 | clip_pretrain: path/pretrain/ViT-L-14-336px.pt
15 | mha: '8-8-8-8'
16 | input_size: 480
17 | clip_dim: 768
18 | word_len: 20
19 | num_token: 2
20 | word_dim: 768
21 | vis_dim: 512
22 | token_dim: 512
23 | sync_bn: True
24 | dropout: 0.
25 | fusion_drop: 0.
26 | workers: 32 # data loader workers
27 | workers_val: 8
28 | batch_size: 64 # batch size for training
29 | batch_size_val: 16 # batch size for validation during training, memory and speed tradeoff
30 | start_epoch: 0
31 | epochs: 1000
32 | lr_backbone: 5.e-5
33 | lr_text_encoder: 5.e-5
34 | lr: 1.e-4
35 | weight_decay: 1.e-4
36 | amsgrad: True
37 | manual_seed: 0
38 | print_freq: 100
39 | exp_name: open
40 | output_folder: exp/refcoco
41 | save_freq: 1
42 | weight:
43 | resume:
44 |   evaluate: True  # evaluate on the validation set; extra GPU memory is needed and a small batch_size_val is recommended
45 | Distributed:
46 | dist_url: tcp://localhost:12345
47 | dist_backend: 'nccl'
48 | multiprocessing_distributed: True
49 | world_size: 1
50 | rank: 0
51 | TEST:
52 |   window12: True # set to True at test time if the window12 pretrained weights were used for training
53 | test_split: test_unseen
54 | test_lmdb: path/refcoco/test_unseen.lmdb
55 | visualize: False
--------------------------------------------------------------------------------
/data/READEME.md:
--------------------------------------------------------------------------------
1 | ## Referring Image Segmentation
2 | Preparing data for RefCOCO, RefCOCO+, and RefCOCOg: [CRIS·GitHub](https://github.com/DerrickWang005/CRIS.pytorch/blob/master/tools/prepare_datasets.md)
3 | ## Generalization Setting for RIS
4 | To prepare data for the generalization experiments, download the prepared LMDB-format datasets from [google drive](https://drive.google.com/file/d/10qIyrslkX50THzZFgGuompF_tWl5TRW0/view?usp=drive_link).
5 | 
6 | You can also customize your own splits via the seen and unseen [id maps](coco_id.json). These ids correspond to the "cat" field in the original annotations.
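7 | 
8 | A minimal sketch of how a custom seen/unseen split can be derived from [coco_id.json](coco_id.json) (illustrative only; the annotation path and output file name below are placeholders):
9 | 
10 | ```python
11 | import json
12 | 
13 | with open("data/coco_id.json") as f:
14 |     coco_id = json.load(f)
15 | seen, unseen = set(coco_id["seen"]), set(coco_id["unseen"])
16 | 
17 | # `anns` is the list of reference dicts written by tools/data_process.py;
18 | # each entry carries its remapped COCO category id in the "cat" field.
19 | with open("path/anns/refcoco/train.json") as f:
20 |     anns = json.load(f)
21 | 
22 | splits = {
23 |     "train_seen": [a for a in anns if a["cat"] in seen],
24 |     "val_unseen": [a for a in anns if a["cat"] in unseen],
25 | }
26 | with open("path/anns/refcoco/open_split.json", "w") as f:
27 |     json.dump(splits, f)
28 | ```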
--------------------------------------------------------------------------------
/data/coco_id.json:
--------------------------------------------------------------------------------
1 | {
2 | "0": "person",
3 | "1": "bicycle",
4 | "2": "car",
5 | "3": "motorcycle",
6 | "4": "airplane",
7 | "5": "bus",
8 | "6": "train",
9 | "7": "truck",
10 | "8": "boat",
11 | "9": "traffic light",
12 | "10": "fire hydrant",
13 | "11": "stop sign",
14 | "12": "parking meter",
15 | "13": "bench",
16 | "14": "bird",
17 | "15": "cat",
18 | "16": "dog",
19 | "17": "horse",
20 | "18": "sheep",
21 | "19": "cow",
22 | "20": "elephant",
23 | "21": "bear",
24 | "22": "zebra",
25 | "23": "giraffe",
26 | "24": "backpack",
27 | "25": "umbrella",
28 | "26": "handbag",
29 | "27": "tie",
30 | "28": "suitcase",
31 | "29": "frisbee",
32 | "30": "skis",
33 | "31": "snowboard",
34 | "32": "sports ball",
35 | "33": "kite",
36 | "34": "baseball bat",
37 | "35": "baseball glove",
38 | "36": "skateboard",
39 | "37": "surfboard",
40 | "38": "tennis racket",
41 | "39": "bottle",
42 | "40": "wine glass",
43 | "41": "cup",
44 | "42": "fork",
45 | "43": "knife",
46 | "44": "spoon",
47 | "45": "bowl",
48 | "46": "banana",
49 | "47": "apple",
50 | "48": "sandwich",
51 | "49": "orange",
52 | "50": "broccoli",
53 | "51": "carrot",
54 | "52": "hot dog",
55 | "53": "pizza",
56 | "54": "donut",
57 | "55": "cake",
58 | "56": "chair",
59 | "57": "couch",
60 | "58": "potted plant",
61 | "59": "bed",
62 | "60": "dining table",
63 | "61": "toilet",
64 | "62": "tv",
65 | "63": "laptop",
66 | "64": "mouse",
67 | "65": "remote",
68 | "66": "keyboard",
69 | "67": "cell phone",
70 | "68": "microwave",
71 | "69": "oven",
72 | "70": "toaster",
73 | "71": "sink",
74 | "72": "refrigerator",
75 | "73": "book",
76 | "74": "clock",
77 | "75": "vase",
78 | "76": "scissors",
79 | "77": "teddy bear",
80 | "78": "hair drier",
81 | "79": "toothbrush",
82 | "seen": [
83 | 0,
84 | 1,
85 | 2,
86 | 3,
87 | 6,
88 | 7,
89 | 8,
90 | 13,
91 | 14,
92 | 17,
93 | 18,
94 | 21,
95 | 22,
96 | 23,
97 | 24,
98 | 26,
99 | 28,
100 | 29,
101 | 30,
102 | 33,
103 | 37,
104 | 39,
105 | 42,
106 | 44,
107 | 45,
108 | 46,
109 | 47,
110 | 48,
111 | 49,
112 | 50,
113 | 51,
114 | 53,
115 | 54,
116 | 56,
117 | 59,
118 | 61,
119 | 62,
120 | 63,
121 | 64,
122 | 65,
123 | 68,
124 | 69,
125 | 70,
126 | 72,
127 | 73,
128 | 74,
129 | 75,
130 | 79
131 | ],
132 | "unseen": [
133 | 4,
134 | 5,
135 | 15,
136 | 16,
137 | 19,
138 | 20,
139 | 25,
140 | 27,
141 | 31,
142 | 36,
143 | 41,
144 | 43,
145 | 55,
146 | 57,
147 | 66,
148 | 71,
149 | 76
150 | ]
151 | }
--------------------------------------------------------------------------------
/engine/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SooLab/CGFormer/766d7fbe0c0101c80806e2499bb4d6960cfc5f4a/engine/__init__.py
--------------------------------------------------------------------------------
/engine/engine.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from tqdm import tqdm
4 | import cv2
5 | import numpy as np
6 | import torch
7 | import torch.cuda.amp as amp
8 | import torch.distributed as dist
9 | import torch.nn.functional as F
10 | import wandb
11 | from PIL import Image
12 | from loguru import logger
13 | from utils.misc import (AverageMeter, ProgressMeter, concat_all_gather, trainMetricGPU)
14 |
15 |
16 | def train(train_loader, model, optimizer, scheduler, scaler, epoch, args):
17 | batch_time = AverageMeter('Batch', ':2.2f')
18 | data_time = AverageMeter('Data', ':2.2f')
19 | lr = AverageMeter('Lr', ':1.6f')
20 | loss_meter = AverageMeter('Loss', ':2.4f')
21 | iou_meter = AverageMeter('IoU', ':2.2f')
22 | pr_meter = AverageMeter('Prec@50', ':2.2f')
23 | progress = ProgressMeter(
24 | len(train_loader),
25 | [batch_time, data_time, lr, loss_meter, iou_meter, pr_meter],
26 | prefix="Training: Epoch=[{}/{}] ".format(epoch, args.epochs))
27 |
28 | model.train()
29 | time.sleep(2)
30 | end = time.time()
31 |
32 | # size_list = [320, 352, 384, 416, 448, 480, 512]
33 | # idx = np.random.choice(len(size_list))
34 | # new_size = size_list[idx]
35 |
36 | for i, (image, text, target, l_mask) in enumerate(train_loader):
37 | data_time.update(time.time() - end)
38 | # data
39 | image = torch.stack(image).cuda(non_blocking=True)
40 | text = torch.stack(text).cuda(non_blocking=True)
41 | target = torch.stack(target).cuda(non_blocking=True)
42 | l_mask = torch.stack(l_mask).cuda(non_blocking=True)
43 | # # multi-scale training
44 | # image = F.interpolate(image, size=(new_size, new_size), mode='bilinear', align_corners=True)
45 | text = text.squeeze(1)
46 | l_mask = l_mask.squeeze(1)
47 | # forward
48 | with amp.autocast():
49 | pred, target, loss = model(image, text, l_mask, target)
50 | # backward
51 | optimizer.zero_grad()
52 | scaler.scale(loss).backward()
53 | # if args.max_norm:
54 | # torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
55 | scaler.step(optimizer)
56 | scaler.update()
57 | scheduler.step()
58 |
59 | # metric
60 | iou, pr5 = trainMetricGPU(pred, target, 0.35)
61 | dist.all_reduce(loss.detach())
62 | dist.all_reduce(iou)
63 | dist.all_reduce(pr5)
64 | loss = loss / dist.get_world_size()
65 | iou = iou / dist.get_world_size()
66 | pr5 = pr5 / dist.get_world_size()
67 |
68 | loss_meter.update(loss.item(), image.size(0))
69 | iou_meter.update(iou.item(), image.size(0))
70 | pr_meter.update(pr5.item(), image.size(0))
71 | lr.update(optimizer.param_groups[0]["lr"])
72 | batch_time.update(time.time() - end)
73 | end = time.time()
74 |
75 | if (i + 1) % args.print_freq == 0:
76 | progress.display(i + 1)
77 | if dist.get_rank() in [-1, 0]:
78 | wandb.log(
79 | {
80 | "time/batch": batch_time.val,
81 | "time/data": data_time.val,
82 | "training/lr": lr.val,
83 | "training/loss": loss_meter.val,
84 | "training/iou": iou_meter.val,
85 | "training/prec@50": pr_meter.val,
86 | },
87 | step=epoch * len(train_loader) + (i + 1))
88 |
89 |
90 | @torch.no_grad()
91 | def validate(val_loader, model, epoch, args):
92 | iou_list = []
93 | I = []
94 | U = []
95 | model.eval()
96 | time.sleep(2)
97 | for imgs, text, masks, l_mask in val_loader:
98 | # data
99 | imgs = torch.stack(imgs).cuda(non_blocking=True)
100 | text = torch.stack(text).cuda(non_blocking=True)
101 | l_mask = torch.stack(l_mask).cuda(non_blocking=True)
102 | text = text.squeeze(1)
103 | l_mask = l_mask.squeeze(1)
104 | # inference
105 | preds, maps = model(imgs, text, l_mask)
106 | preds = torch.sigmoid(preds)
107 | # process one batch
108 | for pred, mask in zip(preds, masks):
109 | # iou
110 | pred = pred.cpu().numpy()
111 | mask = mask.cpu().numpy()
112 | pred = np.array(pred > 0.5)
113 | inter = np.logical_and(pred, mask)
114 | union = np.logical_or(pred, mask)
115 | iou = np.sum(inter) / (np.sum(union) + 1e-6)
116 | iou_list.append(iou)
117 | I.append(np.sum(inter))
118 | U.append(np.sum(union))
119 | iou_list = np.stack(iou_list)
120 | iou_list = torch.from_numpy(iou_list).to(imgs.device)
121 | iou_list = concat_all_gather(iou_list)
122 | I = np.stack(I)
123 | I = torch.from_numpy(I).to(imgs.device)
124 | I = concat_all_gather(I).sum()
125 |
126 | U = np.stack(U)
127 | U = torch.from_numpy(U).to(imgs.device)
128 | U = concat_all_gather(U).sum()
129 | oIoU = I/U
130 | prec_list = []
131 | for thres in torch.arange(0.5, 1.0, 0.1):
132 | tmp = (iou_list > thres).float().mean()
133 | prec_list.append(tmp)
134 | iou = iou_list.mean()
135 | prec = {}
136 | temp = ' '
137 | for i, thres in enumerate(range(5, 10)):
138 | key = 'Pr@{}'.format(thres * 10)
139 | value = prec_list[i].item()
140 | prec[key] = value
141 | temp += "{}: {:.2f} ".format(key, 100. * value)
142 | head = 'Evaluation: Epoch=[{}/{}] mIoU={:.2f} oIoU={:.2f}'.format(
143 | epoch, args.epochs, 100. * iou.item(), 100.*(oIoU))
144 | logger.info(head + temp)
145 | return oIoU, prec
146 |
147 |
148 | @torch.no_grad()
149 | def inference(test_loader, model, args):
150 | iou_list = []
151 | I = 0.
152 | U = 0.
153 | tbar = tqdm(test_loader, desc='Inference:', ncols=100)
154 | model.eval()
155 | time.sleep(2)
156 | for ori_img, img, texts, mask, l_masks, seg_id, sents in tbar:
157 | img = img.cuda(non_blocking=True)
158 | mask = mask.cpu().numpy()
159 | for text, l_mask, sent in zip(texts, l_masks, sents):
160 | text = text.cuda(non_blocking=True)
161 | l_mask = l_mask.cuda(non_blocking=True)
162 |
163 | text = text.squeeze(1)
164 | l_mask = l_mask.squeeze(1)
165 |
166 | # inference
167 | pred, maps = model(img, text, l_mask)
168 | pred = torch.sigmoid(pred)
169 |             if pred.shape[-2:] != ori_img.shape[1:-1]:  # ori_img is (1, H, W, 3); compare spatial sizes only
170 | pred = F.interpolate(pred, size=ori_img.shape[1:-1], mode='bicubic', align_corners=True)
171 | # # process one sentence
172 | pred = pred.cpu().numpy()
173 | pred_ = np.array(pred > 0.5)
174 | inter = np.logical_and(pred_, mask)
175 | union = np.logical_or(pred_, mask)
176 | I += np.sum(inter)
177 | U += np.sum(union)
178 | iou = np.sum(inter) / (np.sum(union) + 1e-6)
179 | iou_list.append(iou)
180 |
181 | logger.info('=> Metric Calculation <=')
182 | iou_list = np.stack(iou_list)
183 | iou_list = torch.from_numpy(iou_list).to(img.device)
184 | prec_list = []
185 | for thres in torch.arange(0.5, 1.0, 0.1):
186 | tmp = (iou_list > thres).float().mean()
187 | prec_list.append(tmp)
188 | iou = iou_list.mean()
189 | prec = {}
190 | for i, thres in enumerate(range(5, 10)):
191 | key = 'Pr@{}'.format(thres*10)
192 | value = prec_list[i].item()
193 | prec[key] = value
194 | logger.info('oIoU={:.2f}'.format(100.*(I/U)))
195 | logger.info('mIoU={:.2f}'.format(100.*iou.item()))
196 | for k, v in prec.items():
197 | logger.info('{}: {:.2f}.'.format(k, 100.*v))
198 |
199 | return iou.item(), prec
200 |
--------------------------------------------------------------------------------
/image/framework.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SooLab/CGFormer/766d7fbe0c0101c80806e2499bb4d6960cfc5f4a/image/framework.jpg
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .segmenter import CGFormer
2 | from loguru import logger
3 | import torch.nn as nn
4 | from .backbone import MultiModalSwinTransformer
5 |
6 |
7 | def build_model(args):
8 | # initialize the SwinTransformer backbone with the specified version
9 | if args.swin_type == 'tiny':
10 | embed_dim = 96
11 | depths = [2, 2, 6, 2]
12 | num_heads = [3, 6, 12, 24]
13 | elif args.swin_type == 'small':
14 | embed_dim = 96
15 | depths = [2, 2, 18, 2]
16 | num_heads = [3, 6, 12, 24]
17 | elif args.swin_type == 'base':
18 | embed_dim = 128
19 | depths = [2, 2, 18, 2]
20 | num_heads = [4, 8, 16, 32]
21 | elif args.swin_type == 'large':
22 | embed_dim = 192
23 | depths = [2, 2, 18, 2]
24 | num_heads = [6, 12, 24, 48]
25 | else:
26 | assert False
27 | # args.window12 added for test.py because state_dict is loaded after model initialization
28 | if 'window12' in args.swin_pretrain or args.window12:
29 | logger.info('Window size 12!')
30 | window_size = 12
31 | else:
32 | window_size = 7
33 |
34 | if args.mha:
35 | mha = args.mha.split('-') # if non-empty, then ['a', 'b', 'c', 'd']
36 | mha = [int(a) for a in mha]
37 | else:
38 | mha = [1, 1, 1, 1]
39 |
40 | out_indices = (0, 1, 2, 3)
41 | backbone = MultiModalSwinTransformer(embed_dim=embed_dim, depths=depths, num_heads=num_heads,
42 | window_size=window_size,
43 | ape=False, drop_path_rate=0.3, patch_norm=True,
44 | out_indices=out_indices,
45 | use_checkpoint=False, num_heads_fusion=mha,
46 | fusion_drop=args.fusion_drop
47 | )
48 | if args.swin_pretrain:
49 | logger.info('Initializing Multi-modal Swin Transformer weights from ' + args.swin_pretrain)
50 | backbone.init_weights(pretrained=args.swin_pretrain)
51 | else:
52 | logger.info('Randomly initialize Multi-modal Swin Transformer weights.')
53 | backbone.init_weights()
54 |
55 | model = CGFormer(backbone, args)
56 |
57 | return model
58 |
59 |
60 | def build_segmenter(args, DDP=True, OPEN=False):
61 | model = build_model(args)
62 | if DDP:
63 | if args.sync_bn:
64 | model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
65 | model = nn.parallel.DistributedDataParallel(model.cuda(),
66 | device_ids=[args.gpu],
67 | find_unused_parameters=True
68 | )
69 |
70 | single_model = model.module
71 | if OPEN:
72 | for p in single_model.backbone.parameters():
73 | p.requires_grad_(False)
74 | param_list = [
75 | {
76 | "params": [
77 | p
78 | for n, p in single_model.named_parameters()
79 | if "backbone" not in n and "text_encoder" not in n and p.requires_grad
80 | ],
81 |
82 | },
83 | {
84 | "params": [
85 | p
86 | for n, p in single_model.named_parameters()
87 | if "pwam" in n and p.requires_grad
88 | ],
89 |
90 | },
91 | {
92 | "params": [p for n, p in single_model.named_parameters() if "backbone" in n and "pwam" not in n and p.requires_grad],
93 | "lr": args.lr_backbone,
94 | },
95 | {
96 | "params": [p for n, p in single_model.named_parameters() if "text_encoder" in n and p.requires_grad],
97 | "lr": args.lr_text_encoder,
98 | },
99 | ]
100 |
101 | return model, param_list
102 | else:
103 | model = nn.DataParallel(model).cuda()
104 | return model
105 |
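106 | # Illustrative sketch (added for clarity, not part of the original file): the
107 | # `param_list` returned above is typically consumed by an AdamW-style optimizer,
108 | # with `args.lr` as the default group learning rate and the backbone /
109 | # text-encoder groups overriding it via their own "lr" entries, e.g.:
110 | #   optimizer = torch.optim.AdamW(param_list, lr=args.lr,
111 | #                                 weight_decay=args.weight_decay,
112 | #                                 amsgrad=args.amsgrad)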
--------------------------------------------------------------------------------
/model/layers.py:
--------------------------------------------------------------------------------
1 | from einops import rearrange
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from timm.models.layers import trunc_normal_
6 |
7 | def l2norm(X, dim=-1, eps=1e-12):
8 | """
9 | L2-normalize columns of X
10 | """
11 | norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps
12 | X = torch.div(X, norm)
13 | return X
14 |
15 |
16 | class Mlp(nn.Module):
17 |
18 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
19 | super().__init__()
20 | out_features = out_features or in_features
21 | hidden_features = hidden_features or in_features
22 | self.fc1 = nn.Linear(in_features, hidden_features)
23 | self.act = act_layer()
24 | self.fc2 = nn.Linear(hidden_features, out_features)
25 | self.drop = nn.Dropout(drop)
26 |
27 | def forward(self, x):
28 | x = self.fc1(x)
29 | x = self.act(x)
30 | x = self.drop(x)
31 | x = self.fc2(x)
32 | x = self.drop(x)
33 | return x
34 |
35 | def conv_layer(in_dim, out_dim, kernel_size=1, padding=0, stride=1):
36 | return nn.Sequential(
37 | nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding, bias=False),
38 | nn.BatchNorm2d(out_dim), nn.ReLU(True))
39 |
40 | def hard_softmax(logits, dim):
41 | y_soft = logits.softmax(dim)
42 | # Straight through.
43 | index = y_soft.max(dim, keepdim=True)[1]
44 | y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
45 | ret = y_hard - y_soft.detach() + y_soft
46 | return ret
47 |
48 | def gumbel_softmax(logits: torch.Tensor, tau: float = 1, dim: int = -2) -> torch.Tensor:
49 | gumbel_dist = torch.distributions.gumbel.Gumbel(
50 | torch.tensor(0., device=logits.device, dtype=logits.dtype),
51 | torch.tensor(1., device=logits.device, dtype=logits.dtype))
52 | gumbels = gumbel_dist.sample(logits.shape)
53 |
54 | gumbels = (logits + gumbels) / tau
55 | y_soft = gumbels.softmax(dim)
56 |
57 | index = y_soft.max(dim, keepdim=True)[1]
58 | y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
59 | ret = y_hard - y_soft.detach() + y_soft
60 |
61 | return ret
62 |
63 | class Fusion(nn.Module):
64 | def __init__(self, in_dim_1, in_dim_2, out_dim, bias=False) -> None:
65 | super().__init__()
66 |
67 | self.fusion = nn.Sequential(
68 | nn.Conv2d(in_dim_1+in_dim_2, out_dim, 3, padding=1, bias=bias),
69 | nn.BatchNorm2d(out_dim),
70 | nn.ReLU(),
71 | nn.Conv2d(out_dim, out_dim, 3, padding=1, bias=bias),
72 | nn.BatchNorm2d(out_dim),
73 | nn.ReLU(),
74 | )
75 |
76 | def forward(self, in_1, in_2):
77 | if in_1.shape[-1] < in_2.shape[-1]:
78 | in_1 = F.interpolate(in_1, size=in_2.shape[-2:], mode='bilinear', align_corners=True)
79 | elif in_1.shape[-1] > in_2.shape[-1]:
80 | in_2 = F.interpolate(in_2, size=in_1.shape[-2:], mode='bilinear', align_corners=True)
81 |
82 | x = torch.cat((in_1, in_2), dim=1)
83 | x = self.fusion(x)
84 | return x
85 |
86 | class DProjector(nn.Module):
87 | def __init__(self, text_dim=512, in_dim=512, kernel_size=1):
88 | super().__init__()
89 | self.in_dim = in_dim
90 | self.kernel_size = kernel_size
91 | # visual projector
92 |
93 | self.vis = nn.Sequential( # os16 -> os4
94 | nn.Upsample(scale_factor=2, mode='bilinear'),
95 | conv_layer(in_dim, in_dim, 3, padding=1),
96 | nn.Upsample(scale_factor=2, mode='bilinear'),
97 | conv_layer(in_dim, in_dim, 3, padding=1),
98 | nn.Conv2d(in_dim, in_dim, 1))
99 |
100 | # textual projector
101 | out_dim = 1 * in_dim * kernel_size * kernel_size + 1
102 | self.txt = nn.Linear(text_dim, out_dim)
103 |
104 | def forward(self, x, text):
105 | '''
106 | x: b, 512, 104, 104
107 | text: b, 512
108 | '''
109 | x = self.vis(x) # Eq. 8
110 |
111 | B, C, H, W = x.size()
112 | # 1, b*256, 104, 104
113 | x = x.reshape(1, B * C, H, W)
114 | # txt: b, 1, (256*3*3 + 1) -> b, 1, 256, 3, 3 / b
115 | text = self.txt(text) # Eq. 8
116 |
117 | weight, bias = text[:, :-1], text[:, -1]
118 | weight = weight.reshape(B, C, self.kernel_size, self.kernel_size)
119 | # Conv2d - 1, b*256, 104, 104 -> 1, b, 104, 104
120 | out = F.conv2d(x,
121 | weight,
122 | padding=1,
123 | groups=B,
124 | bias=bias)
125 |
126 | # b, 1, 104, 104
127 | out = out.transpose(0,1)
128 | return out
129 |
130 |
131 | class CrossAttn(nn.Module):
132 | def __init__(self,
133 | q_dim,
134 | kv_dim,
135 | hidden_dim,
136 | num_heads,
137 | out_dim=None,
138 | qkv_bias=False,
139 | qk_scale=None,
140 | attn_drop=0.,
141 | proj_drop=0.,
142 | qkv_fuse=False):
143 | super().__init__()
144 | if out_dim is None:
145 | out_dim = q_dim
146 | self.num_heads = num_heads
147 | head_dim = hidden_dim // num_heads
148 | self.scale = qk_scale or head_dim**-0.5
149 | self.qkv_fuse = qkv_fuse
150 |
151 | self.q_proj = nn.Linear(q_dim, hidden_dim, bias=qkv_bias)
152 | self.k_proj = nn.Linear(kv_dim, hidden_dim, bias=qkv_bias)
153 | self.v_proj = nn.Linear(kv_dim, hidden_dim, bias=qkv_bias)
154 | self.attn_drop = nn.Dropout(attn_drop)
155 | self.proj = nn.Linear(hidden_dim, out_dim)
156 | self.proj_drop = nn.Dropout(proj_drop)
157 |
158 | def forward(self, query, key, value=None, mask=None):
159 | B, N, C = query.shape
160 | if value is None:
161 | value = key
162 | S = key.size(1)
163 | # [B, nh, N, C//nh]
164 | q = rearrange(self.q_proj(query), 'b n (h c)-> b h n c', h=self.num_heads, b=B, n=N, c=C // self.num_heads)
165 | # [B, nh, S, C//nh]
166 | k = rearrange(self.k_proj(key), 'b n (h c)-> b h n c', h=self.num_heads, b=B, c=C // self.num_heads)
167 | # [B, nh, S, C//nh]
168 | v = rearrange(self.v_proj(value), 'b n (h c)-> b h n c', h=self.num_heads, b=B, c=C // self.num_heads)
169 | # [B, nh, N, S]
170 |
171 | if mask is not None:
172 | mask = mask[:,None,:,None].expand(-1, self.num_heads, -1, -1) # b nh S 1
173 | k = k * mask
174 | v = v * mask
175 | attn = (q @ k.transpose(-2, -1)) * self.scale
176 | attn = attn + (1e4*mask.transpose(-2,-1)-1e4) # b nh 1 S
177 | else:
178 | attn = (q @ k.transpose(-2, -1)) * self.scale
179 | attn = attn.softmax(dim=-1)
180 | attn = self.attn_drop(attn)
181 |
182 | assert attn.shape == (B, self.num_heads, N, S)
183 | # [B, nh, N, C//nh] -> [B, N, C]
184 | out = rearrange(attn @ v, 'b h n c -> b n (h c)', h=self.num_heads, b=B, n=N, c=C // self.num_heads)
185 | out = self.proj(out)
186 | out = self.proj_drop(out)
187 | return out
188 |
189 | class OriLoadToken(nn.Module):
190 | def __init__(self, token_dim, bias, drop) -> None:
191 | super().__init__()
192 | self.cross_attn = CrossAttn(
193 | q_dim=token_dim,
194 | kv_dim=768,
195 | hidden_dim=token_dim,
196 | num_heads=1,
197 | out_dim=token_dim,
198 | qkv_bias=bias,
199 | attn_drop=drop,
200 | proj_drop=drop,
201 | )
202 | self.normq = nn.LayerNorm(token_dim)
203 | self.normk = nn.LayerNorm(768)
204 |
205 | self.normq = nn.LayerNorm(token_dim)
206 | self.normk = nn.LayerNorm(768)
207 |
208 | def forward(self, tokens, text, pad_mask):
209 | tokens = tokens + self.cross_attn(query=self.normq(tokens), key=self.normk(text.permute(0,2,1)), mask=pad_mask[...,0])
210 | return tokens
211 |
212 | # updated version
213 | class LoadToken(nn.Module):
214 | def __init__(self, token_dim, bias, drop) -> None:
215 | super().__init__()
216 | self.cross_attn = CrossAttn(
217 | q_dim=token_dim,
218 | kv_dim=768,
219 | hidden_dim=token_dim,
220 | num_heads=1,
221 | out_dim=token_dim,
222 | qkv_bias=bias,
223 | attn_drop=drop,
224 | proj_drop=drop,
225 | )
226 | self.normq = nn.LayerNorm(token_dim)
227 | self.normk = nn.LayerNorm(768)
228 |
229 | def forward(self, tokens, text, pad_mask):
230 | ltoken, ttoken = torch.split(tokens, [tokens.shape[1]-1,1], dim=1)
231 | ttoken = ttoken + self.cross_attn(query=self.normq(ttoken), key=self.normk(text.permute(0,2,1)), mask=pad_mask[...,0])
232 | tokens = torch.cat((ltoken, ttoken), dim=1)
233 | return tokens
234 |
235 | class LoadLayer(nn.Module):
236 | def __init__(self, token_dim, drop, bias=False, pe_shape=None) -> None:
237 | super().__init__()
238 |         if pe_shape > 30:
239 | self.loadtoken = LoadToken(
240 | token_dim=token_dim,
241 | bias=bias,
242 | drop=drop
243 | )
244 | self.norm = nn.LayerNorm(token_dim)
245 | self.mlp = Mlp(token_dim, token_dim*2, token_dim)
246 | self.positional_embedding = nn.Parameter(torch.randn(pe_shape**2, token_dim) / token_dim ** 0.5)
247 | self.pe_shape = pe_shape
248 |
249 | def forward(self, tokens, text, pad_mask):
250 | if self.pe_shape > 30:
251 | tokens = self.loadtoken(tokens, text, pad_mask)
252 | tokens = self.mlp(self.norm(tokens))
253 | return tokens, self.positional_embedding
254 |
255 |
256 | class CGAttention(nn.Module):
257 | def __init__(self, token_dim, vis_dim, hidden_dim, drop=0., bias=True) -> None:
258 | super().__init__()
259 | self.norm_v = nn.LayerNorm(vis_dim)
260 | self.norm_t = nn.LayerNorm(token_dim)
261 | self.q_proj = nn.Linear(token_dim, hidden_dim, bias=bias)
262 | self.k_proj = nn.Linear(vis_dim, hidden_dim, bias=bias)
263 | self.v_proj = nn.Linear(vis_dim, hidden_dim, bias=bias)
264 | self.proj = nn.Linear(hidden_dim, token_dim)
265 | self.proj_drop = nn.Dropout(drop)
266 | self.norm = nn.LayerNorm(token_dim)
267 | self.mlp = Mlp(token_dim, token_dim*2, token_dim, drop=drop)
268 | self.tau = nn.Parameter(torch.ones(1), requires_grad=True)
269 |
270 | def with_pe(self, vis, pe):
271 | return vis + pe
272 |
273 | def forward(self, tokens, vis, pe=None):
274 | b, c, h , w = vis.shape
275 | vis = rearrange(vis, 'b c h w -> b (h w) c')
276 | if pe is not None:
277 | vis = self.with_pe(vis, pe)
278 | vis = self.norm_v(vis)
279 | q = self.q_proj(self.norm_t(tokens))
280 | k = self.k_proj(vis)
281 | v = self.v_proj(vis)
282 |
283 | q = l2norm(q, dim=-1)
284 | k = l2norm(k, dim=-1)
285 | raw_attn = (q @ k.transpose(-2, -1))
286 | tau = torch.clamp(self.tau, max=0).exp()
287 | attn = gumbel_softmax(raw_attn, dim=-2, tau=tau)
288 | hit_map = attn
289 | attn = attn / (attn.sum(dim=-1, keepdim=True) + 1)
290 | new_tokens = attn @ v
291 | new_tokens = self.proj_drop(self.proj(new_tokens))
292 | new_tokens = self.mlp(self.norm(new_tokens+tokens))
293 | return new_tokens, hit_map.reshape(b, -1, h, w)
294 |
295 | class Decoder(nn.Module):
296 | def __init__(self, args) -> None:
297 | super().__init__()
298 | '''
299 | c1 :128, 120, 120
300 | c2 :256, 60, 60
301 | c3 :512, 30, 30
302 | c4 :1024, 15 ,15
303 | '''
304 | token_dim = args.token_dim
305 | self.tokens = nn.Embedding(args.num_token, token_dim)
306 | trunc_normal_(self.tokens.weight, std=0.02)
307 |
308 | dims = [1024, 512, 256, 128]
309 | pe_shapes = [30, 60, 120]
310 |
311 | self.layers = []
312 | for pe_shape in pe_shapes:
313 | self.layers.append(LoadLayer(token_dim, drop=.1, bias=False, pe_shape=pe_shape))
314 | self.cgattention1 = CGAttention(token_dim=token_dim,
315 | vis_dim=token_dim,
316 | hidden_dim=token_dim,
317 | drop=.1,
318 | bias=True)
319 | self.cgattention2 = CGAttention(token_dim=token_dim,
320 | vis_dim=token_dim,
321 | hidden_dim=token_dim,
322 | drop=.1,
323 | bias=True)
324 | self.layers = nn.ModuleList(self.layers)
325 | self.fuses = []
326 | for dim in [dims[0], dims[2], dims[3]]:
327 | self.fuses.append(Fusion(dim, token_dim, token_dim, bias=True))
328 | self.fuses = nn.ModuleList(self.fuses)
329 | self.proj = DProjector(text_dim=token_dim, in_dim=token_dim)
330 |
331 | def forward(self, vis, text, pad_mask):
332 | x_c4, x_c3, x_c2, x_c1 = vis
333 | tokens = self.tokens.weight[None,...].expand(x_c1.shape[0], -1, -1)
334 | maps = []
335 | v = x_c4
336 | for load, layer, fuse, v_ in zip(self.layers,[self.cgattention1,self.cgattention2,self.cgattention2], self.fuses, [x_c3, x_c2, x_c1]):
337 | v = fuse(v, v_)
338 | tokens, pe = load(tokens, text, pad_mask)
339 | tokens, hitmap = layer(tokens, v, pe=pe)
340 | maps.append(hitmap)
341 | out = self.proj(v, tokens[:,-1])
342 | return out, maps
343 |
--------------------------------------------------------------------------------
/model/mmcv_custom/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from .checkpoint import load_checkpoint
4 |
5 | __all__ = ['load_checkpoint']
6 |
--------------------------------------------------------------------------------
/model/mmcv_custom/checkpoint.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Open-MMLab. All rights reserved.
2 | import io
3 | import os
4 | import os.path as osp
5 | import pkgutil
6 | import time
7 | import warnings
8 | from collections import OrderedDict
9 | from importlib import import_module
10 | from tempfile import TemporaryDirectory
11 |
12 | import torch
13 | import torchvision
14 | from torch.optim import Optimizer
15 | from torch.utils import model_zoo
16 | from torch.nn import functional as F
17 |
18 | import mmcv
19 | from mmcv.fileio import FileClient
20 | from mmcv.fileio import load as load_file
21 | from mmcv.parallel import is_module_wrapper
22 | from mmcv.utils import mkdir_or_exist
23 | from mmcv.runner import get_dist_info
24 |
25 | ENV_MMCV_HOME = 'MMCV_HOME'
26 | ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
27 | DEFAULT_CACHE_DIR = '~/.cache'
28 |
29 |
30 | def _get_mmcv_home():
31 | mmcv_home = os.path.expanduser(
32 | os.getenv(
33 | ENV_MMCV_HOME,
34 | os.path.join(
35 | os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmcv')))
36 |
37 | mkdir_or_exist(mmcv_home)
38 | return mmcv_home
39 |
40 |
41 | def load_state_dict(module, state_dict, strict=False, logger=None):
42 | """Load state_dict to a module.
43 |
44 | This method is modified from :meth:`torch.nn.Module.load_state_dict`.
45 | Default value for ``strict`` is set to ``False`` and the message for
46 | param mismatch will NOT be shown if strict is False.
47 |
48 | Args:
49 | module (Module): Module that receives the state_dict.
50 | state_dict (OrderedDict): Weights.
51 | strict (bool): whether to strictly enforce that the keys
52 | in :attr:`state_dict` match the keys returned by this module's
53 | :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
54 | logger (:obj:`logging.Logger`, optional): Logger to log the error
55 | message. If not specified, print function will be used.
56 | """
57 | unexpected_keys = []
58 | all_missing_keys = []
59 | err_msg = []
60 |
61 | metadata = getattr(state_dict, '_metadata', None)
62 | state_dict = state_dict.copy()
63 | if metadata is not None:
64 | state_dict._metadata = metadata
65 |
66 | # use _load_from_state_dict to enable checkpoint version control
67 | def load(module, prefix=''):
68 | # recursively check parallel module in case that the model has a
69 | # complicated structure, e.g., nn.Module(nn.Module(DDP))
70 | if is_module_wrapper(module):
71 | module = module.module
72 | local_metadata = {} if metadata is None else metadata.get(
73 | prefix[:-1], {})
74 | module._load_from_state_dict(state_dict, prefix, local_metadata, True,
75 | all_missing_keys, unexpected_keys,
76 | err_msg)
77 | for name, child in module._modules.items():
78 | if child is not None:
79 | load(child, prefix + name + '.')
80 |
81 | load(module)
82 | load = None # break load->load reference cycle
83 |
84 | # ignore "num_batches_tracked" of BN layers
85 | missing_keys = [
86 | key for key in all_missing_keys if 'num_batches_tracked' not in key
87 | ]
88 |
89 | if unexpected_keys:
90 | err_msg.append('unexpected key in source '
91 | f'state_dict: {", ".join(unexpected_keys)}\n')
92 | if missing_keys:
93 | err_msg.append(
94 | f'missing keys in source state_dict: {", ".join(missing_keys)}\n')
95 |
96 | if strict:
97 | rank, _ = get_dist_info()
98 | if len(err_msg) > 0 and rank == 0:
99 | err_msg.insert(
100 | 0, 'The model and loaded state dict do not match exactly\n')
101 | err_msg = '\n'.join(err_msg)
102 | if strict:
103 | raise RuntimeError(err_msg)
104 | elif logger is not None:
105 | logger.warning(err_msg)
106 | else:
107 | print(err_msg)
108 |
109 |
110 | def load_url_dist(url, model_dir=None):
111 | """In distributed setting, this function only download checkpoint at local
112 | rank 0."""
113 | rank, world_size = get_dist_info()
114 | rank = int(os.environ.get('LOCAL_RANK', rank))
115 | if rank == 0:
116 | checkpoint = model_zoo.load_url(url, model_dir=model_dir)
117 | if world_size > 1:
118 | torch.distributed.barrier()
119 | if rank > 0:
120 | checkpoint = model_zoo.load_url(url, model_dir=model_dir)
121 | return checkpoint
122 |
123 |
124 | def load_pavimodel_dist(model_path, map_location=None):
125 | """In distributed setting, this function only download checkpoint at local
126 | rank 0."""
127 | try:
128 | from pavi import modelcloud
129 | except ImportError:
130 | raise ImportError(
131 | 'Please install pavi to load checkpoint from modelcloud.')
132 | rank, world_size = get_dist_info()
133 | rank = int(os.environ.get('LOCAL_RANK', rank))
134 | if rank == 0:
135 | model = modelcloud.get(model_path)
136 | with TemporaryDirectory() as tmp_dir:
137 | downloaded_file = osp.join(tmp_dir, model.name)
138 | model.download(downloaded_file)
139 | checkpoint = torch.load(downloaded_file, map_location=map_location)
140 | if world_size > 1:
141 | torch.distributed.barrier()
142 | if rank > 0:
143 | model = modelcloud.get(model_path)
144 | with TemporaryDirectory() as tmp_dir:
145 | downloaded_file = osp.join(tmp_dir, model.name)
146 | model.download(downloaded_file)
147 | checkpoint = torch.load(
148 | downloaded_file, map_location=map_location)
149 | return checkpoint
150 |
151 |
152 | def load_fileclient_dist(filename, backend, map_location):
153 | """In distributed setting, this function only download checkpoint at local
154 | rank 0."""
155 | rank, world_size = get_dist_info()
156 | rank = int(os.environ.get('LOCAL_RANK', rank))
157 | allowed_backends = ['ceph']
158 | if backend not in allowed_backends:
159 | raise ValueError(f'Load from Backend {backend} is not supported.')
160 | if rank == 0:
161 | fileclient = FileClient(backend=backend)
162 | buffer = io.BytesIO(fileclient.get(filename))
163 | checkpoint = torch.load(buffer, map_location=map_location)
164 | if world_size > 1:
165 | torch.distributed.barrier()
166 | if rank > 0:
167 | fileclient = FileClient(backend=backend)
168 | buffer = io.BytesIO(fileclient.get(filename))
169 | checkpoint = torch.load(buffer, map_location=map_location)
170 | return checkpoint
171 |
172 |
173 | def get_torchvision_models():
174 | model_urls = dict()
175 | for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
176 | if ispkg:
177 | continue
178 | _zoo = import_module(f'torchvision.models.{name}')
179 | if hasattr(_zoo, 'model_urls'):
180 | _urls = getattr(_zoo, 'model_urls')
181 | model_urls.update(_urls)
182 | return model_urls
183 |
184 |
185 | def get_external_models():
186 | mmcv_home = _get_mmcv_home()
187 | default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json')
188 | default_urls = load_file(default_json_path)
189 | assert isinstance(default_urls, dict)
190 | external_json_path = osp.join(mmcv_home, 'open_mmlab.json')
191 | if osp.exists(external_json_path):
192 | external_urls = load_file(external_json_path)
193 | assert isinstance(external_urls, dict)
194 | default_urls.update(external_urls)
195 |
196 | return default_urls
197 |
198 |
199 | def get_mmcls_models():
200 | mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json')
201 | mmcls_urls = load_file(mmcls_json_path)
202 |
203 | return mmcls_urls
204 |
205 |
206 | def get_deprecated_model_names():
207 | deprecate_json_path = osp.join(mmcv.__path__[0],
208 | 'model_zoo/deprecated.json')
209 | deprecate_urls = load_file(deprecate_json_path)
210 | assert isinstance(deprecate_urls, dict)
211 |
212 | return deprecate_urls
213 |
214 |
215 | def _process_mmcls_checkpoint(checkpoint):
216 | state_dict = checkpoint['state_dict']
217 | new_state_dict = OrderedDict()
218 | for k, v in state_dict.items():
219 | if k.startswith('backbone.'):
220 | new_state_dict[k[9:]] = v
221 | new_checkpoint = dict(state_dict=new_state_dict)
222 |
223 | return new_checkpoint
224 |
225 |
226 | def _load_checkpoint(filename, map_location=None):
227 | """Load checkpoint from somewhere (modelzoo, file, url).
228 |
229 | Args:
230 | filename (str): Accept local filepath, URL, ``torchvision://xxx``,
231 | ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
232 | details.
233 | map_location (str | None): Same as :func:`torch.load`. Default: None.
234 |
235 | Returns:
236 | dict | OrderedDict: The loaded checkpoint. It can be either an
237 | OrderedDict storing model weights or a dict containing other
238 | information, which depends on the checkpoint.
239 | """
240 | if filename.startswith('modelzoo://'):
241 | warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
242 | 'use "torchvision://" instead')
243 | model_urls = get_torchvision_models()
244 | model_name = filename[11:]
245 | checkpoint = load_url_dist(model_urls[model_name])
246 | elif filename.startswith('torchvision://'):
247 | model_urls = get_torchvision_models()
248 | model_name = filename[14:]
249 | checkpoint = load_url_dist(model_urls[model_name])
250 | elif filename.startswith('open-mmlab://'):
251 | model_urls = get_external_models()
252 | model_name = filename[13:]
253 | deprecated_urls = get_deprecated_model_names()
254 | if model_name in deprecated_urls:
255 | warnings.warn(f'open-mmlab://{model_name} is deprecated in favor '
256 | f'of open-mmlab://{deprecated_urls[model_name]}')
257 | model_name = deprecated_urls[model_name]
258 | model_url = model_urls[model_name]
259 | # check if is url
260 | if model_url.startswith(('http://', 'https://')):
261 | checkpoint = load_url_dist(model_url)
262 | else:
263 | filename = osp.join(_get_mmcv_home(), model_url)
264 | if not osp.isfile(filename):
265 | raise IOError(f'{filename} is not a checkpoint file')
266 | checkpoint = torch.load(filename, map_location=map_location)
267 | elif filename.startswith('mmcls://'):
268 | model_urls = get_mmcls_models()
269 | model_name = filename[8:]
270 | checkpoint = load_url_dist(model_urls[model_name])
271 | checkpoint = _process_mmcls_checkpoint(checkpoint)
272 | elif filename.startswith(('http://', 'https://')):
273 | checkpoint = load_url_dist(filename)
274 | elif filename.startswith('pavi://'):
275 | model_path = filename[7:]
276 | checkpoint = load_pavimodel_dist(model_path, map_location=map_location)
277 | elif filename.startswith('s3://'):
278 | checkpoint = load_fileclient_dist(
279 | filename, backend='ceph', map_location=map_location)
280 | else:
281 | if not osp.isfile(filename):
282 | raise IOError(f'{filename} is not a checkpoint file')
283 | checkpoint = torch.load(filename, map_location=map_location)
284 | return checkpoint
285 |
286 |
287 | def load_checkpoint(model,
288 | filename,
289 | map_location='cpu',
290 | strict=False,
291 | logger=None):
292 | """Load checkpoint from a file or URI.
293 |
294 | Args:
295 | model (Module): Module to load checkpoint.
296 | filename (str): Accept local filepath, URL, ``torchvision://xxx``,
297 | ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
298 | details.
299 | map_location (str): Same as :func:`torch.load`.
300 | strict (bool): Whether to allow different params for the model and
301 | checkpoint.
302 | logger (:mod:`logging.Logger` or None): The logger for error message.
303 |
304 | Returns:
305 | dict or OrderedDict: The loaded checkpoint.
306 | """
307 | checkpoint = _load_checkpoint(filename, map_location)
308 | # OrderedDict is a subclass of dict
309 | if not isinstance(checkpoint, dict):
310 | raise RuntimeError(
311 | f'No state_dict found in checkpoint file {filename}')
312 | # get state_dict from checkpoint
313 | if 'state_dict' in checkpoint:
314 | state_dict = checkpoint['state_dict']
315 | elif 'model' in checkpoint:
316 | state_dict = checkpoint['model']
317 | else:
318 | state_dict = checkpoint
319 | # strip prefix of state_dict
320 | if list(state_dict.keys())[0].startswith('module.'):
321 | state_dict = {k[7:]: v for k, v in state_dict.items()}
322 | # for upper net weights only
323 | if list(state_dict.keys())[0].startswith('backbone.'):
324 |         print('Stripping the UperNet "backbone." prefix and loading backbone weights into our swin encoder')
325 | state_dict = {k.replace('backbone.', ''): v for k, v in state_dict.items() if k.startswith('backbone.')}
326 | # for MoBY, load model of online branch
327 | if sorted(list(state_dict.keys()))[0].startswith('encoder'):
328 | state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')}
329 |
330 | # reshape absolute position embedding
331 | if state_dict.get('absolute_pos_embed') is not None:
332 | absolute_pos_embed = state_dict['absolute_pos_embed']
333 | N1, L, C1 = absolute_pos_embed.size()
334 | N2, C2, H, W = model.absolute_pos_embed.size()
335 | if N1 != N2 or C1 != C2 or L != H*W:
336 | logger.warning("Error in loading absolute_pos_embed, pass")
337 | else:
338 | state_dict['absolute_pos_embed'] = absolute_pos_embed.view(N2, H, W, C2).permute(0, 3, 1, 2)
339 |
340 | # interpolate position bias table if needed
341 | relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
342 | for table_key in relative_position_bias_table_keys:
343 | table_pretrained = state_dict[table_key]
344 | table_current = model.state_dict()[table_key]
345 | L1, nH1 = table_pretrained.size()
346 | L2, nH2 = table_current.size()
347 | if nH1 != nH2:
348 | logger.warning(f"Error in loading {table_key}, pass")
349 | else:
350 | if L1 != L2:
351 | S1 = int(L1 ** 0.5)
352 | S2 = int(L2 ** 0.5)
353 | table_pretrained_resized = F.interpolate(
354 | table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
355 | size=(S2, S2), mode='bicubic')
356 | state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute(1, 0)
357 |
358 | # load state_dict
359 | load_state_dict(model, state_dict, strict, logger)
360 | return checkpoint
361 |
362 |
363 | def weights_to_cpu(state_dict):
364 | """Copy a model state_dict to cpu.
365 |
366 | Args:
367 | state_dict (OrderedDict): Model weights on GPU.
368 |
369 | Returns:
370 |         OrderedDict: Model weights on CPU.
371 | """
372 | state_dict_cpu = OrderedDict()
373 | for key, val in state_dict.items():
374 | state_dict_cpu[key] = val.cpu()
375 | return state_dict_cpu
376 |
377 |
378 | def _save_to_state_dict(module, destination, prefix, keep_vars):
379 | """Saves module state to `destination` dictionary.
380 |
381 | This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.
382 |
383 | Args:
384 | module (nn.Module): The module to generate state_dict.
385 | destination (dict): A dict where state will be stored.
386 | prefix (str): The prefix for parameters and buffers used in this
387 | module.
388 | """
389 | for name, param in module._parameters.items():
390 | if param is not None:
391 | destination[prefix + name] = param if keep_vars else param.detach()
392 | for name, buf in module._buffers.items():
393 | # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d
394 | if buf is not None:
395 | destination[prefix + name] = buf if keep_vars else buf.detach()
396 |
397 |
398 | def get_state_dict(module, destination=None, prefix='', keep_vars=False):
399 | """Returns a dictionary containing a whole state of the module.
400 |
401 | Both parameters and persistent buffers (e.g. running averages) are
402 | included. Keys are corresponding parameter and buffer names.
403 |
404 | This method is modified from :meth:`torch.nn.Module.state_dict` to
405 | recursively check parallel module in case that the model has a complicated
406 | structure, e.g., nn.Module(nn.Module(DDP)).
407 |
408 | Args:
409 | module (nn.Module): The module to generate state_dict.
410 | destination (OrderedDict): Returned dict for the state of the
411 | module.
412 | prefix (str): Prefix of the key.
413 | keep_vars (bool): Whether to keep the variable property of the
414 | parameters. Default: False.
415 |
416 | Returns:
417 | dict: A dictionary containing a whole state of the module.
418 | """
419 | # recursively check parallel module in case that the model has a
420 | # complicated structure, e.g., nn.Module(nn.Module(DDP))
421 | if is_module_wrapper(module):
422 | module = module.module
423 |
424 | # below is the same as torch.nn.Module.state_dict()
425 | if destination is None:
426 | destination = OrderedDict()
427 | destination._metadata = OrderedDict()
428 | destination._metadata[prefix[:-1]] = local_metadata = dict(
429 | version=module._version)
430 | _save_to_state_dict(module, destination, prefix, keep_vars)
431 | for name, child in module._modules.items():
432 | if child is not None:
433 | get_state_dict(
434 | child, destination, prefix + name + '.', keep_vars=keep_vars)
435 | for hook in module._state_dict_hooks.values():
436 | hook_result = hook(module, destination, prefix, local_metadata)
437 | if hook_result is not None:
438 | destination = hook_result
439 | return destination
440 |
441 |
442 | def save_checkpoint(model, filename, optimizer=None, meta=None):
443 | """Save checkpoint to file.
444 |
445 | The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
446 | ``optimizer``. By default ``meta`` will contain version and time info.
447 |
448 | Args:
449 | model (Module): Module whose params are to be saved.
450 | filename (str): Checkpoint filename.
451 | optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
452 | meta (dict, optional): Metadata to be saved in checkpoint.
453 | """
454 | if meta is None:
455 | meta = {}
456 | elif not isinstance(meta, dict):
457 | raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
458 | meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
459 |
460 | if is_module_wrapper(model):
461 | model = model.module
462 |
463 | if hasattr(model, 'CLASSES') and model.CLASSES is not None:
464 | # save class name to the meta
465 | meta.update(CLASSES=model.CLASSES)
466 |
467 | checkpoint = {
468 | 'meta': meta,
469 | 'state_dict': weights_to_cpu(get_state_dict(model))
470 | }
471 | # save optimizer state dict in the checkpoint
472 | if isinstance(optimizer, Optimizer):
473 | checkpoint['optimizer'] = optimizer.state_dict()
474 | elif isinstance(optimizer, dict):
475 | checkpoint['optimizer'] = {}
476 | for name, optim in optimizer.items():
477 | checkpoint['optimizer'][name] = optim.state_dict()
478 |
479 | if filename.startswith('pavi://'):
480 | try:
481 | from pavi import modelcloud
482 | from pavi.exception import NodeNotFoundError
483 | except ImportError:
484 | raise ImportError(
485 | 'Please install pavi to load checkpoint from modelcloud.')
486 | model_path = filename[7:]
487 | root = modelcloud.Folder()
488 | model_dir, model_name = osp.split(model_path)
489 | try:
490 | model = modelcloud.get(model_dir)
491 | except NodeNotFoundError:
492 | model = root.create_training_model(model_dir)
493 | with TemporaryDirectory() as tmp_dir:
494 | checkpoint_file = osp.join(tmp_dir, model_name)
495 | with open(checkpoint_file, 'wb') as f:
496 | torch.save(checkpoint, f)
497 | f.flush()
498 | model.create_file(checkpoint_file, name=model_name)
499 | else:
500 | mmcv.mkdir_or_exist(osp.dirname(filename))
501 | # immediately flush buffer
502 | with open(filename, 'wb') as f:
503 | torch.save(checkpoint, f)
504 | f.flush()
505 |
--------------------------------------------------------------------------------
/model/segmenter.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from .layers import Decoder
4 | import torch.nn.functional as F
5 | from bert.modeling_bert import BertModel
6 |
7 | def dice_loss(inputs, targets):
8 | """
9 | Compute the DICE loss, similar to generalized IOU for masks
10 | Args:
11 | inputs: A float tensor of arbitrary shape.
12 | The predictions for each example.
13 | targets: A float tensor with the same shape as inputs. Stores the binary
14 | classification label for each element in inputs
15 | (0 for the negative class and 1 for the positive class).
16 | """
17 |
18 | inputs = inputs.sigmoid()
19 | inputs = inputs.flatten(1)
20 | targets = targets.flatten(1)
21 | numerator = 2 * (inputs * targets).sum(1)
22 | denominator = inputs.sum(-1) + targets.sum(-1)
23 | loss = 1 - (numerator + 1) / (denominator + 1)
24 | return loss.mean()
25 |
26 | def sigmoid_focal_loss(inputs, targets, alpha: float = 0.25, gamma: float = 2):
27 | """
28 | Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
29 | Args:
30 | inputs: A float tensor of arbitrary shape.
31 | The predictions for each example.
32 | targets: A float tensor with the same shape as inputs. Stores the binary
33 | classification label for each element in inputs
34 | (0 for the negative class and 1 for the positive class).
35 |         alpha: (optional) Weighting factor in range (0, 1) to balance
36 |                positive vs negative examples. Default = 0.25; a negative value disables weighting.
37 | gamma: Exponent of the modulating factor (1 - p_t) to
38 | balance easy vs hard examples.
39 | Returns:
40 | Loss tensor
41 | """
42 |
43 | prob = inputs.sigmoid()
44 | ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
45 | p_t = prob * targets + (1 - prob) * (1 - targets)
46 | loss = ce_loss * ((1 - p_t) ** gamma)
47 |
48 | if alpha >= 0:
49 | alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
50 | loss = alpha_t * loss
51 | return loss.mean()
52 |
53 |
54 | class CGFormer(nn.Module):
55 | def __init__(self, backbone, args):
56 | super(CGFormer, self).__init__()
57 | self.backbone = backbone
58 | self.decoder = Decoder(args)
59 | self.text_encoder = BertModel.from_pretrained(args.bert)
60 | self.text_encoder.pooler = None
61 |
62 | def forward(self, x, text, l_mask, mask=None):
63 | input_shape = x.shape[-2:]
64 | l_feats = self.text_encoder(text, attention_mask=l_mask)[0] # (6, 10, 768)
65 | l_feats = l_feats.permute(0, 2, 1) # (B, 768, N_l) to make Conv1d happy
66 | l_mask = l_mask.unsqueeze(dim=-1) # (batch, N_l, 1)
67 | ##########################
68 | features = self.backbone(x, l_feats, l_mask)
69 | x_c1, x_c2, x_c3, x_c4 = features
70 | pred, maps = self.decoder([x_c4, x_c3, x_c2, x_c1], l_feats, l_mask)
71 | pred = F.interpolate(pred, input_shape, mode='bilinear', align_corners=True)
72 | # loss
73 | if self.training:
74 | loss = 0.
75 | mask = mask.unsqueeze(1).float()
76 | for m, lam in zip(maps, [0.001,0.01,0.1]):
77 | m = m[:,1].unsqueeze(1)
78 | if m.shape[-2:] != mask.shape[-2:]:
79 | mask_ = F.interpolate(mask, m.shape[-2:], mode='nearest').detach()
80 | loss += dice_loss(m, mask_) * lam
81 | loss += dice_loss(pred, mask) + sigmoid_focal_loss(pred, mask, alpha=-1, gamma=0)
82 | return pred.detach(), mask, loss
83 | else:
84 | return pred.detach(), maps
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers
2 | wandb
3 | lmdb
4 | pyarrow
5 | regex
6 | ftfy
7 | loguru
8 | pycocotools
9 | matplotlib
10 | tqdm
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import warnings
4 |
5 | import cv2
6 | import torch
7 | import torch.nn.parallel
8 | import torch.utils.data
9 | from loguru import logger
10 |
11 | import utils.config as config
12 | from engine.engine import inference
13 | from model import build_segmenter
14 | from utils.dataset import RefDataset
15 | from utils.misc import setup_logger
16 |
17 | warnings.filterwarnings("ignore")
18 | cv2.setNumThreads(0)
19 |
20 |
21 | def get_parser():
22 | parser = argparse.ArgumentParser(
23 | description='Pytorch Referring Expression Segmentation')
24 | parser.add_argument('--config',
25 | default='path to xxx.yaml',
26 | type=str,
27 | help='config file')
28 | parser.add_argument('--opts',
29 | default=None,
30 | nargs=argparse.REMAINDER,
31 | help='override some settings in the config.')
32 | args = parser.parse_args()
33 | assert args.config is not None
34 | cfg = config.load_cfg_from_cfg_file(args.config)
35 | if args.opts is not None:
36 | cfg = config.merge_cfg_from_list(cfg, args.opts)
37 | return cfg
38 |
39 |
40 | @logger.catch
41 | def main():
42 | args = get_parser()
43 | args.output_dir = os.path.join(args.output_folder, args.exp_name)
44 | if args.visualize:
45 | args.vis_dir = os.path.join(args.output_dir, "vis")
46 | os.makedirs(args.vis_dir, exist_ok=True)
47 |
48 | # logger
49 | setup_logger(args.output_dir,
50 | distributed_rank=0,
51 | filename="test.log",
52 | mode="a")
53 | logger.info(args.test_split)
54 |
55 | # build dataset & dataloader
56 | test_data = RefDataset(lmdb_dir=args.test_lmdb,
57 | mask_dir=args.mask_root,
58 | dataset=args.dataset,
59 | split=args.test_split,
60 | mode='test',
61 | input_size=args.input_size,
62 | word_length=args.word_len)
63 | test_loader = torch.utils.data.DataLoader(test_data,
64 | batch_size=1,
65 | shuffle=False,
66 | num_workers=4,
67 | pin_memory=True)
68 |
69 | # build model
70 | model = build_segmenter(args, DDP=False)
71 | logger.info(model)
72 |
73 | if os.path.isfile(args.weight):
74 | logger.info("=> loading checkpoint '{}'".format(args.weight))
75 | checkpoint = torch.load(args.weight)
76 | model.module.load_state_dict(checkpoint['model_state_dict'], strict=True)
77 | logger.info("=> loaded checkpoint '{}'".format(args.weight))
78 | else:
79 | raise ValueError(
80 |             "=> loading failed! no checkpoint found at '{}'. Please check args.weight again!"
81 | .format(args.weight))
82 |
83 | # inference
84 | inference(test_loader, model, args)
85 |
86 |
87 | if __name__ == '__main__':
88 | main()
89 |
--------------------------------------------------------------------------------
/tools/data_process.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 |
5 | import cv2
6 | import numpy as np
7 | from tqdm import tqdm
8 |
9 | from refer import REFER
10 |
11 | parser = argparse.ArgumentParser(description='Data preparation')
12 | parser.add_argument('--data_root', type=str)
13 | parser.add_argument('--output_dir', type=str)
14 | parser.add_argument('--dataset',
15 | type=str,
16 | choices=['refcoco', 'refcoco+', 'refcocog', 'refclef'],
17 | default='refcoco')
18 | parser.add_argument('--split', type=str, default='umd')
19 | parser.add_argument('--generate_mask', action='store_true')
20 | args = parser.parse_args()
21 | img_path = os.path.join(args.data_root, 'images', 'train2014')
22 |
23 | h, w = (416, 416)
24 |
25 | refer = REFER(args.data_root, args.dataset, args.split)
26 |
27 | print('dataset [%s_%s] contains: ' % (args.dataset, args.split))
28 | ref_ids = refer.getRefIds()
29 | image_ids = refer.getImgIds()
30 | print('%s expressions for %s refs in %s images.' %
31 | (len(refer.Sents), len(ref_ids), len(image_ids)))
32 |
33 | print('\nAmong them:')
34 | if args.dataset == 'refclef':
35 | if args.split == 'unc':
36 | splits = ['train', 'val', 'testA', 'testB', 'testC']
37 | else:
38 | splits = ['train', 'val', 'test']
39 | elif args.dataset == 'refcoco':
40 | splits = ['train', 'val', 'testA', 'testB']
41 | elif args.dataset == 'refcoco+':
42 | splits = ['train', 'val', 'testA', 'testB']
43 | elif args.dataset == 'refcocog':
44 | splits = ['train', 'val',
45 | 'test'] # we don't have test split for refcocog right now.
46 |
47 | for split in splits:
48 | ref_ids = refer.getRefIds(split=split)
49 | print('%s refs are in split [%s].' % (len(ref_ids), split))
50 |
51 |
52 | def cat_process(cat):
53 | if cat >= 1 and cat <= 11:
54 | cat = cat - 1
55 | elif cat >= 13 and cat <= 25:
56 | cat = cat - 2
57 | elif cat >= 27 and cat <= 28:
58 | cat = cat - 3
59 | elif cat >= 31 and cat <= 44:
60 | cat = cat - 5
61 | elif cat >= 46 and cat <= 65:
62 | cat = cat - 6
63 | elif cat == 67:
64 | cat = cat - 7
65 | elif cat == 70:
66 | cat = cat - 9
67 | elif cat >= 72 and cat <= 82:
68 | cat = cat - 10
69 | elif cat >= 84 and cat <= 90:
70 | cat = cat - 11
71 | return cat
72 |
73 |
74 | def bbox_process(bbox):
75 | x_min = int(bbox[0])
76 | y_min = int(bbox[1])
77 | x_max = x_min + int(bbox[2])
78 | y_max = y_min + int(bbox[3])
79 | return list(map(int, [x_min, y_min, x_max, y_max]))
80 |
81 |
82 | def prepare_dataset(dataset, splits, output_dir, generate_mask=False):
83 | ann_path = os.path.join(output_dir, 'anns', dataset)
84 | mask_path = os.path.join(output_dir, 'masks', dataset)
85 | if not os.path.exists(ann_path):
86 | os.makedirs(ann_path)
87 | if not os.path.exists(mask_path):
88 | os.makedirs(mask_path)
89 |
90 | for split in splits:
91 | dataset_array = []
92 | ref_ids = refer.getRefIds(split=split)
93 |         print('Processing split: {} - Len: {}'.format(split, len(ref_ids)))
94 | for i in tqdm(ref_ids):
95 | ref_dict = {}
96 |
97 | refs = refer.Refs[i]
98 | bboxs = refer.getRefBox(i)
99 | sentences = refs['sentences']
100 | image_urls = refer.loadImgs(image_ids=refs['image_id'])[0]
101 | cat = cat_process(refs['category_id'])
102 | image_urls = image_urls['file_name']
103 | if dataset == 'refclef' and image_urls in [
104 | '19579.jpg', '17975.jpg', '19575.jpg'
105 | ]:
106 | continue
107 | box_info = bbox_process(bboxs)
108 |
109 | ref_dict['bbox'] = box_info
110 | ref_dict['cat'] = cat
111 | ref_dict['segment_id'] = i
112 | ref_dict['img_name'] = image_urls
113 |
114 | if generate_mask:
115 | cv2.imwrite(os.path.join(mask_path,
116 | str(i) + '.png'),
117 | refer.getMask(refs)['mask'] * 255)
118 |
119 | sent_dict = []
120 |             for j, sent in enumerate(sentences):  # new index name to avoid shadowing the ref id `i`
121 | sent_dict.append({
122 |                     'idx': j,
123 | 'sent_id': sent['sent_id'],
124 | 'sent': sent['sent'].strip()
125 | })
126 |
127 | ref_dict['sentences'] = sent_dict
128 | ref_dict['sentences_num'] = len(sent_dict)
129 |
130 | dataset_array.append(ref_dict)
131 | print('Dumping json file...')
132 | with open(os.path.join(output_dir, 'anns', dataset, split + '.json'),
133 | 'w') as f:
134 | json.dump(dataset_array, f)
135 |
136 |
137 | prepare_dataset(args.dataset, splits, args.output_dir, args.generate_mask)
138 |
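A minimal sketch of inspecting one record written by `prepare_dataset` above. The path and printed values are illustrative; the keys are exactly those assembled into `ref_dict` in the loop.

```python
import json

# Hypothetical path: adjust to wherever --output_dir / --dataset / split point.
with open('anns/refcoco/train.json', 'r') as f:
    records = json.load(f)

# Each record carries:
#   'bbox'          -> [x_min, y_min, x_max, y_max] from bbox_process()
#   'cat'           -> contiguous category index from cat_process()
#   'segment_id'    -> the ref id; its mask is saved as masks/<dataset>/<segment_id>.png
#   'img_name'      -> COCO file name under images/train2014
#   'sentences'     -> [{'idx', 'sent_id', 'sent'}, ...]
#   'sentences_num' -> len(sentences)
first = records[0]
print(sorted(first.keys()))
print(first['segment_id'], first['sentences'][0]['sent'])
```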
--------------------------------------------------------------------------------
/tools/folder2lmdb.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import os.path as osp
4 | import lmdb
5 | import pyarrow as pa
6 | import json
7 | from tqdm import tqdm
8 | import warnings
9 | warnings.filterwarnings("ignore")
10 |
11 |
12 | def loads_pyarrow(buf):
13 | """
14 | Args:
15 | buf: the output of `dumps`.
16 | """
17 | return pa.deserialize(buf)
18 |
19 |
20 | def raw_reader(path):
21 | with open(path, 'rb') as f:
22 | bin_data = f.read()
23 | return bin_data
24 |
25 |
26 | def dumps_pyarrow(obj):
27 | """
28 | Serialize an object.
29 | Returns:
30 | Implementation-dependent bytes-like object
31 | """
32 | return pa.serialize(obj).to_buffer()
33 |
34 |
35 | def folder2lmdb(json_data, img_dir, mask_dir, output_dir, split, write_frequency=1000):
36 | lmdb_path = osp.join(output_dir, "%s.lmdb" % split)
37 | isdir = os.path.isdir(lmdb_path)
38 |
39 | print("Generate LMDB to %s" % lmdb_path)
40 | db = lmdb.open(lmdb_path, subdir=isdir,
41 | map_size=1099511627776 * 2, readonly=False,
42 | meminit=False, map_async=True)
43 |
44 | txn = db.begin(write=True)
45 | tbar = tqdm(json_data)
46 | for idx, item in enumerate(tbar):
47 | img = raw_reader(osp.join(img_dir, item['img_name']))
48 | mask = raw_reader(osp.join(mask_dir, f"{item['segment_id']}.png"))
49 | data = {'img': img, 'mask': mask, 'cat': item['cat'],
50 | 'seg_id': item['segment_id'], 'img_name': item['img_name'],
51 | 'num_sents': item['sentences_num'], 'sents': [i['sent'] for i in item['sentences']]}
52 | txn.put(u'{}'.format(idx).encode('ascii'), dumps_pyarrow(data))
53 | if idx % write_frequency == 0:
54 | # print("[%d/%d]" % (idx, len(data_loader)))
55 | txn.commit()
56 | txn = db.begin(write=True)
57 |
58 | # finish iterating through dataset
59 | txn.commit()
60 | keys = [u'{}'.format(k).encode('ascii') for k in range(idx + 1)]
61 | with db.begin(write=True) as txn:
62 | txn.put(b'__keys__', dumps_pyarrow(keys))
63 | txn.put(b'__len__', dumps_pyarrow(len(keys)))
64 |
65 | print("Flushing database ...")
66 | db.sync()
67 | db.close()
68 |
69 |
70 | def parse_args():
71 | parser = argparse.ArgumentParser(description='COCO Folder to LMDB.')
72 | parser.add_argument('-j', '--json-dir', type=str,
73 | default='',
74 | help='the name of json file.')
75 | parser.add_argument('-i', '--img-dir', type=str,
76 | default='refcoco+',
77 | help='the folder of images.')
78 | parser.add_argument('-m', '--mask-dir', type=str,
79 | default='refcoco+',
80 | help='the folder of masks.')
81 | parser.add_argument('-o', '--output-dir', type=str,
82 | default='refcoco+',
83 | help='the folder of output lmdb file.')
84 | parser.add_argument('-s', '--split', type=str,
85 | default='train',
86 | help='the split type.')
87 | args = parser.parse_args()
88 | return args
89 |
90 |
91 | if __name__ == '__main__':
92 | args = parse_args()
93 | args.split = osp.basename(args.json_dir).split(".")[0]
94 | os.makedirs(args.output_dir, exist_ok=True)
95 |
96 | with open(args.json_dir, 'rb') as f:
97 | json_data = json.load(f)
98 |
99 | folder2lmdb(json_data, args.img_dir, args.mask_dir, args.output_dir, args.split)
100 |
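For reference, a small sketch of reading one sample back from the LMDB produced above, mirroring `loads_pyarrow` and the keys written in `folder2lmdb`. The path is an assumption, and `pa.deserialize` requires a pyarrow version that still ships the (deprecated) serialization API used throughout this repo.

```python
import os
import lmdb
import pyarrow as pa

db_path = 'lmdb/refcoco/train.lmdb'  # hypothetical output of folder2lmdb.py
env = lmdb.open(db_path, subdir=os.path.isdir(db_path), readonly=True,
                lock=False, readahead=False, meminit=False)
with env.begin(write=False) as txn:
    length = pa.deserialize(txn.get(b'__len__'))
    keys = pa.deserialize(txn.get(b'__keys__'))   # b'0', b'1', ...
    sample = pa.deserialize(txn.get(keys[0]))
# 'img' and 'mask' hold the raw encoded bytes; the rest mirrors the json record.
print(length, sample['seg_id'], sample['num_sents'], sample['sents'][0])
env.close()
```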
--------------------------------------------------------------------------------
/tools/latency.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import sys
3 | import time
4 | import warnings
5 |
6 | sys.path.append('./')
7 | warnings.filterwarnings("ignore")
8 |
9 | import torch
10 | import torch.backends.cudnn as cudnn
11 | import utils.config as config
12 | from model import build_segmenter
13 |
14 |
15 | def get_parser():
16 | parser = argparse.ArgumentParser(
17 | description='Pytorch Referring Expression Segmentation')
18 | parser.add_argument('--config',
19 | default='path to xxx.yaml',
20 | type=str,
21 | help='config file')
22 | parser.add_argument('--opts',
23 | default=None,
24 | nargs=argparse.REMAINDER,
25 | help='override some settings in the config.')
26 | args = parser.parse_args()
27 | assert args.config is not None
28 | cfg = config.load_cfg_from_cfg_file(args.config)
29 | if args.opts is not None:
30 | cfg = config.merge_cfg_from_list(cfg, args.opts)
31 | return cfg
32 |
33 |
34 | def count_parameters(model):
35 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
36 |
37 |
38 | def main():
39 | # init arguments
40 | args = get_parser()
41 | torch.cuda.set_device(0)
42 | # create model
43 | model, _ = build_segmenter(args)
44 | model = model.cuda()
45 | model.eval()
46 | # set cudnn state
47 | cudnn.benchmark = True
48 | cudnn.deterministic = False
49 | cudnn.enabled = True
50 | # init dummy tensor
51 | image = torch.randn(1, 3, 416, 416).cuda()
52 | text = torch.randint(4096, size=(1, args.word_len)).long().cuda()
53 | # init time & memory
54 | avg_time = 0
55 | avg_mem = 0
56 | # record initial gpu memory
57 | mem = torch.cuda.max_memory_allocated()
58 |
59 | with torch.no_grad():
60 | for i in range(500):
61 | start_time = time.time()
62 | _ = model(image, text)
63 | torch.cuda.synchronize()
64 |             if (i + 1) > 100:  # skip the first 100 iterations as warm-up; 400 timed runs remain
65 | avg_time += (time.time() - start_time)
66 | avg_mem += (torch.cuda.max_memory_allocated() - mem) / 1.073742e9
67 | params = count_parameters(model) * 1e-6
68 | print('#########################################')
69 | print("Average Parameters : {:.2f} M".format(params))
70 | print("Average FPS: {:.2f}".format(400/avg_time))
71 | print("Average GPU Memory: {:.2f} GB".format(avg_mem/400))
72 | print('#########################################')
73 |
74 |
75 | if __name__ == '__main__':
76 | main()
77 |
--------------------------------------------------------------------------------
/tools/prepare_datasets.md:
--------------------------------------------------------------------------------
1 | ## Prepare datasets
2 |
3 | In our paper, we conduct experiments on three commonly used datasets: Ref-COCO, Ref-COCO+ and G-Ref.
4 |
5 | ### 1. COCO 2014
6 |
7 | The data can be found [here](https://cocodataset.org/#download). Please run the following commands to download it.
8 |
9 | ```shell
10 | # download
11 | mkdir datasets && cd datasets
12 | wget http://images.cocodataset.org/zips/train2014.zip
13 |
14 | # unzip
15 | unzip train2014.zip -d images/ && rm train2014.zip
16 |
17 | ```
18 |
19 | ### 2. Ref-COCO
20 |
21 | The data can be found [here](https://github.com/lichengunc/refer). Please run the following commands to download and convert it.
22 |
23 | ```shell
24 | # download
25 | wget https://bvisionweb1.cs.unc.edu/licheng/referit/data/refcoco.zip
26 |
27 | # unzip
28 | unzip refcoco.zip && rm refcoco.zip
29 |
30 | # convert
31 | python ../tools/data_process.py --data_root . --output_dir . --dataset refcoco --split unc --generate_mask
32 |
33 | # lmdb
34 | python ../tools/folder2lmdb.py -j anns/refcoco/train.json -i images/train2014/ -m masks/refcoco -o lmdb/refcoco
35 | python ../tools/folder2lmdb.py -j anns/refcoco/val.json -i images/train2014/ -m masks/refcoco -o lmdb/refcoco
36 | python ../tools/folder2lmdb.py -j anns/refcoco/testA.json -i images/train2014/ -m masks/refcoco -o lmdb/refcoco
37 | python ../tools/folder2lmdb.py -j anns/refcoco/testB.json -i images/train2014/ -m masks/refcoco -o lmdb/refcoco
38 |
39 | # clean
40 | rm -r refcoco
41 |
42 | ```
43 |
44 | ### 3. Ref-COCO+
45 |
46 | The data can be found [here](https://github.com/lichengunc/refer). Please run the following commands to download and convert it.
47 |
48 | ```shell
49 | # download
50 | wget https://bvisionweb1.cs.unc.edu/licheng/referit/data/refcoco+.zip
51 |
52 | # unzip
53 | unzip refcoco+.zip && rm refcoco+.zip
54 |
55 | # convert
56 | python ../tools/data_process.py --data_root . --output_dir . --dataset refcoco+ --split unc --generate_mask
57 |
58 | # lmdb
59 | python ../tools/folder2lmdb.py -j anns/refcoco+/train.json -i images/train2014/ -m masks/refcoco+ -o lmdb/refcoco+
60 | python ../tools/folder2lmdb.py -j anns/refcoco+/val.json -i images/train2014/ -m masks/refcoco+ -o lmdb/refcoco+
61 | python ../tools/folder2lmdb.py -j anns/refcoco+/testA.json -i images/train2014/ -m masks/refcoco+ -o lmdb/refcoco+
62 | python ../tools/folder2lmdb.py -j anns/refcoco+/testB.json -i images/train2014/ -m masks/refcoco+ -o lmdb/refcoco+
63 |
64 | # clean
65 | rm -r refcoco+
66 |
67 | ```
68 |
69 | ### 4. Ref-COCOg
70 |
71 | The data can be found [here](https://github.com/lichengunc/refer). Please run the following commands to download and convert it.
72 | (Note that we adopt two different splits of this dataset, 'umd' and 'google'.)
73 |
74 | ```shell
75 | # download
76 | wget https://bvisionweb1.cs.unc.edu/licheng/referit/data/refcocog.zip
77 |
78 | # unzip
79 | unzip refcocog.zip && rm refcocog.zip
80 |
81 | # convert
82 | python ../tools/data_process.py --data_root . --output_dir . --dataset refcocog --split umd --generate_mask # umd split
83 | mv anns/refcocog anns/refcocog_u
84 | mv masks/refcocog masks/refcocog_u
85 |
86 | python ../tools/data_process.py --data_root . --output_dir . --dataset refcocog --split google --generate_mask # google split
87 | mv anns/refcocog anns/refcocog_g
88 | mv masks/refcocog masks/refcocog_g
89 |
90 | # lmdb
91 | python ../tools/folder2lmdb.py -j anns/refcocog_u/train.json -i images/train2014/ -m masks/refcocog_u -o lmdb/refcocog_u
92 | python ../tools/folder2lmdb.py -j anns/refcocog_u/val.json -i images/train2014/ -m masks/refcocog_u -o lmdb/refcocog_u
93 | python ../tools/folder2lmdb.py -j anns/refcocog_u/test.json -i images/train2014/ -m masks/refcocog_u -o lmdb/refcocog_u
94 |
95 | python ../tools/folder2lmdb.py -j anns/refcocog_g/train.json -i images/train2014/ -m masks/refcocog_g -o lmdb/refcocog_g
96 | python ../tools/folder2lmdb.py -j anns/refcocog_g/val.json -i images/train2014/ -m masks/refcocog_g -o lmdb/refcocog_g
97 |
98 | rm -r refcocog
99 |
100 | ```
101 |
102 | ### 5. Dataset structure
103 |
104 | After running the commands above, the structure of the dataset folder should look like this:
105 |
106 | ```none
107 | datasets
108 | ├── anns
109 | │ ├── refcoco
110 | │ │ ├── xxx.json
111 | │ ├── refcoco+
112 | │ │ ├── xxx.json
113 | │ ├── refcocog_g
114 | │ │ ├── xxx.json
115 | │ ├── refcocog_u
116 | │ │ ├── xxx.json
117 | ├── images
118 | │ ├── train2014
119 | │ │ ├── xxx.jpg
120 | ├── lmdb
121 | │ ├── refcoco
122 | │ │ ├── xxx.lmdb
123 | │ │ ├── xxx.lmdb-lock
124 | │ ├── refcoco+
125 | │ │ ├── xxx.lmdb
126 | │ │ ├── xxx.lmdb-lock
127 | │ ├── refcocog_g
128 | │ │ ├── xxx.lmdb
129 | │ │ ├── xxx.lmdb-lock
130 | │ ├── refcocog_u
131 | │ │ ├── xxx.lmdb
132 | │ │ ├── xxx.lmdb-lock
133 | ├── masks
134 | │ ├── refcoco
135 | │ │ ├── xxx.png
136 | │ ├── refcoco+
137 | │ │ ├── xxx.png
138 | │ ├── refcocog_g
139 | │ │ ├── xxx.png
140 | │ ├── refcocog_u
141 | │ │ ├── xxx.png
142 |
143 | ```
--------------------------------------------------------------------------------
/tools/refer.py:
--------------------------------------------------------------------------------
1 | __author__ = 'licheng'
2 | """
3 | This interface provides access to four datasets:
4 | 1) refclef
5 | 2) refcoco
6 | 3) refcoco+
7 | 4) refcocog
8 | split by unc and google
9 | The following API functions are defined:
10 | REFER - REFER api class
11 | getRefIds - get ref ids that satisfy given filter conditions.
12 | getAnnIds - get ann ids that satisfy given filter conditions.
13 | getImgIds - get image ids that satisfy given filter conditions.
14 | getCatIds - get category ids that satisfy given filter conditions.
15 | loadRefs - load refs with the specified ref ids.
16 | loadAnns - load anns with the specified ann ids.
17 | loadImgs - load images with the specified image ids.
18 | loadCats - load category names with the specified category ids.
19 | getRefBox - get ref's bounding box [x, y, w, h] given the ref_id
20 | showRef - show image, segmentation or box of the referred object with the ref
21 | getMask - get mask and area of the referred object given ref
22 | showMask - show mask of the referred object given ref
23 | """
24 |
25 | import itertools
26 | import json
27 | import os.path as osp
28 | import pickle
29 | import sys
30 | import time
31 | from pprint import pprint
32 |
33 | import matplotlib.pyplot as plt
34 | import numpy as np
35 | import skimage.io as io
36 | from matplotlib.collections import PatchCollection
37 | from matplotlib.patches import Polygon, Rectangle
38 | from pycocotools import mask
39 |
40 |
41 | class REFER:
42 | def __init__(self, data_root, dataset='refcoco', splitBy='unc'):
43 | # provide data_root folder which contains refclef, refcoco, refcoco+ and refcocog
44 | # also provide dataset name and splitBy information
45 | # e.g., dataset = 'refcoco', splitBy = 'unc'
46 | print('loading dataset %s into memory...' % dataset)
47 | self.ROOT_DIR = osp.abspath(osp.dirname(__file__))
48 | self.DATA_DIR = osp.join(data_root, dataset)
49 | if dataset in ['refcoco', 'refcoco+', 'refcocog']:
50 | self.IMAGE_DIR = osp.join(data_root, 'images/train2014')
51 | elif dataset == 'refclef':
52 | self.IMAGE_DIR = osp.join(data_root, 'images/saiapr_tc-12')
53 | else:
54 | print('No refer dataset is called [%s]' % dataset)
55 | sys.exit()
56 |
57 | # load refs from data/dataset/refs(dataset).json
58 | tic = time.time()
59 | ref_file = osp.join(self.DATA_DIR, 'refs(' + splitBy + ').p')
60 | self.data = {}
61 | self.data['dataset'] = dataset
62 |
63 | self.data['refs'] = pickle.load(open(ref_file, 'rb'), fix_imports=True)
64 |
65 | # load annotations from data/dataset/instances.json
66 | instances_file = osp.join(self.DATA_DIR, 'instances.json')
67 | instances = json.load(open(instances_file, 'r'))
68 | self.data['images'] = instances['images']
69 | self.data['annotations'] = instances['annotations']
70 | self.data['categories'] = instances['categories']
71 |
72 | # create index
73 | self.createIndex()
74 | print('DONE (t=%.2fs)' % (time.time() - tic))
75 |
76 | def createIndex(self):
77 | # create sets of mapping
78 | # 1) Refs: {ref_id: ref}
79 | # 2) Anns: {ann_id: ann}
80 | # 3) Imgs: {image_id: image}
81 | # 4) Cats: {category_id: category_name}
82 | # 5) Sents: {sent_id: sent}
83 | # 6) imgToRefs: {image_id: refs}
84 | # 7) imgToAnns: {image_id: anns}
85 | # 8) refToAnn: {ref_id: ann}
86 | # 9) annToRef: {ann_id: ref}
87 | # 10) catToRefs: {category_id: refs}
88 | # 11) sentToRef: {sent_id: ref}
89 | # 12) sentToTokens: {sent_id: tokens}
90 | print('creating index...')
91 | # fetch info from instances
92 | Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {}
93 | for ann in self.data['annotations']:
94 | Anns[ann['id']] = ann
95 | imgToAnns[ann['image_id']] = imgToAnns.get(ann['image_id'],
96 | []) + [ann]
97 | for img in self.data['images']:
98 | Imgs[img['id']] = img
99 | for cat in self.data['categories']:
100 | Cats[cat['id']] = cat['name']
101 |
102 | # fetch info from refs
103 | Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {}
104 | Sents, sentToRef, sentToTokens = {}, {}, {}
105 | for ref in self.data['refs']:
106 | # ids
107 | ref_id = ref['ref_id']
108 | ann_id = ref['ann_id']
109 | category_id = ref['category_id']
110 | image_id = ref['image_id']
111 |
112 | # add mapping related to ref
113 | Refs[ref_id] = ref
114 | imgToRefs[image_id] = imgToRefs.get(image_id, []) + [ref]
115 | catToRefs[category_id] = catToRefs.get(category_id, []) + [ref]
116 | refToAnn[ref_id] = Anns[ann_id]
117 | annToRef[ann_id] = ref
118 |
119 | # add mapping of sent
120 | for sent in ref['sentences']:
121 | Sents[sent['sent_id']] = sent
122 | sentToRef[sent['sent_id']] = ref
123 | sentToTokens[sent['sent_id']] = sent['tokens']
124 |
125 | # create class members
126 | self.Refs = Refs
127 | self.Anns = Anns
128 | self.Imgs = Imgs
129 | self.Cats = Cats
130 | self.Sents = Sents
131 | self.imgToRefs = imgToRefs
132 | self.imgToAnns = imgToAnns
133 | self.refToAnn = refToAnn
134 | self.annToRef = annToRef
135 | self.catToRefs = catToRefs
136 | self.sentToRef = sentToRef
137 | self.sentToTokens = sentToTokens
138 | print('index created.')
139 |
140 | def getRefIds(self, image_ids=[], cat_ids=[], ref_ids=[], split=''):
141 | image_ids = image_ids if type(image_ids) == list else [image_ids]
142 | cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
143 | ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
144 |
145 | if len(image_ids) == len(cat_ids) == len(ref_ids) == len(split) == 0:
146 | refs = self.data['refs']
147 | else:
148 | if not len(image_ids) == 0:
149 | refs = [self.imgToRefs[image_id] for image_id in image_ids]
150 | else:
151 | refs = self.data['refs']
152 | if not len(cat_ids) == 0:
153 | refs = [ref for ref in refs if ref['category_id'] in cat_ids]
154 | if not len(ref_ids) == 0:
155 | refs = [ref for ref in refs if ref['ref_id'] in ref_ids]
156 | if not len(split) == 0:
157 | if split in ['testA', 'testB', 'testC']:
158 | refs = [ref for ref in refs if split[-1] in ref['split']
159 | ] # we also consider testAB, testBC, ...
160 | elif split in ['testAB', 'testBC', 'testAC']:
161 | refs = [ref for ref in refs
162 | if ref['split'] == split] # rarely used I guess...
163 | elif split == 'test':
164 | refs = [ref for ref in refs if 'test' in ref['split']]
165 | elif split == 'train' or split == 'val':
166 | refs = [ref for ref in refs if ref['split'] == split]
167 | else:
168 | print('No such split [%s]' % split)
169 | sys.exit()
170 | ref_ids = [ref['ref_id'] for ref in refs]
171 | return ref_ids
172 |
173 | def getAnnIds(self, image_ids=[], cat_ids=[], ref_ids=[]):
174 | image_ids = image_ids if type(image_ids) == list else [image_ids]
175 | cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
176 | ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
177 |
178 | if len(image_ids) == len(cat_ids) == len(ref_ids) == 0:
179 | ann_ids = [ann['id'] for ann in self.data['annotations']]
180 | else:
181 | if not len(image_ids) == 0:
182 | lists = [
183 | self.imgToAnns[image_id] for image_id in image_ids
184 | if image_id in self.imgToAnns
185 | ] # list of [anns]
186 | anns = list(itertools.chain.from_iterable(lists))
187 | else:
188 | anns = self.data['annotations']
189 | if not len(cat_ids) == 0:
190 | anns = [ann for ann in anns if ann['category_id'] in cat_ids]
191 | ann_ids = [ann['id'] for ann in anns]
192 | if not len(ref_ids) == 0:
193 | ids = set(ann_ids).intersection(
194 | set([self.Refs[ref_id]['ann_id'] for ref_id in ref_ids]))
195 | return ann_ids
196 |
197 | def getImgIds(self, ref_ids=[]):
198 | ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
199 |
200 | if not len(ref_ids) == 0:
201 | image_ids = list(
202 | set([self.Refs[ref_id]['image_id'] for ref_id in ref_ids]))
203 | else:
204 | image_ids = self.Imgs.keys()
205 | return image_ids
206 |
207 | def getCatIds(self):
208 | return self.Cats.keys()
209 |
210 | def loadRefs(self, ref_ids=[]):
211 | if type(ref_ids) == list:
212 | return [self.Refs[ref_id] for ref_id in ref_ids]
213 | elif type(ref_ids) == int:
214 | return [self.Refs[ref_ids]]
215 |
216 | def loadAnns(self, ann_ids=[]):
217 | if type(ann_ids) == list:
218 | return [self.Anns[ann_id] for ann_id in ann_ids]
219 |         elif type(ann_ids) == int or type(ann_ids) == str:  # `unicode` no longer exists in Python 3
220 | return [self.Anns[ann_ids]]
221 |
222 | def loadImgs(self, image_ids=[]):
223 | if type(image_ids) == list:
224 | return [self.Imgs[image_id] for image_id in image_ids]
225 | elif type(image_ids) == int:
226 | return [self.Imgs[image_ids]]
227 |
228 | def loadCats(self, cat_ids=[]):
229 | if type(cat_ids) == list:
230 | return [self.Cats[cat_id] for cat_id in cat_ids]
231 | elif type(cat_ids) == int:
232 | return [self.Cats[cat_ids]]
233 |
234 | def getRefBox(self, ref_id):
235 | ref = self.Refs[ref_id]
236 | ann = self.refToAnn[ref_id]
237 | return ann['bbox'] # [x, y, w, h]
238 |
239 | def showRef(self, ref, seg_box='seg'):
240 | ax = plt.gca()
241 | # show image
242 | image = self.Imgs[ref['image_id']]
243 | I = io.imread(osp.join(self.IMAGE_DIR, image['file_name']))
244 | ax.imshow(I)
245 | # show refer expression
246 | for sid, sent in enumerate(ref['sentences']):
247 | print('%s. %s' % (sid + 1, sent['sent']))
248 | # show segmentations
249 | if seg_box == 'seg':
250 | ann_id = ref['ann_id']
251 | ann = self.Anns[ann_id]
252 | polygons = []
253 | color = []
254 | c = 'none'
255 | if type(ann['segmentation'][0]) == list:
256 | # polygon used for refcoco*
257 | for seg in ann['segmentation']:
258 |                     poly = np.array(seg).reshape((len(seg) // 2, 2))  # integer division for Python 3
259 | polygons.append(Polygon(poly, True, alpha=0.4))
260 | color.append(c)
261 | p = PatchCollection(polygons,
262 | facecolors=color,
263 | edgecolors=(1, 1, 0, 0),
264 | linewidths=3,
265 | alpha=1)
266 | ax.add_collection(p) # thick yellow polygon
267 | p = PatchCollection(polygons,
268 | facecolors=color,
269 | edgecolors=(1, 0, 0, 0),
270 | linewidths=1,
271 | alpha=1)
272 | ax.add_collection(p) # thin red polygon
273 | else:
274 | # mask used for refclef
275 | rle = ann['segmentation']
276 | m = mask.decode(rle)
277 | img = np.ones((m.shape[0], m.shape[1], 3))
278 | color_mask = np.array([2.0, 166.0, 101.0]) / 255
279 | for i in range(3):
280 | img[:, :, i] = color_mask[i]
281 | ax.imshow(np.dstack((img, m * 0.5)))
282 | # show bounding-box
283 | elif seg_box == 'box':
284 | ann_id = ref['ann_id']
285 | ann = self.Anns[ann_id]
286 | bbox = self.getRefBox(ref['ref_id'])
287 | box_plot = Rectangle((bbox[0], bbox[1]),
288 | bbox[2],
289 | bbox[3],
290 | fill=False,
291 | edgecolor='green',
292 | linewidth=3)
293 | ax.add_patch(box_plot)
294 |
295 | def getMask(self, ref):
296 | # return mask, area and mask-center
297 | ann = self.refToAnn[ref['ref_id']]
298 | image = self.Imgs[ref['image_id']]
299 | if type(ann['segmentation'][0]) == list: # polygon
300 | rle = mask.frPyObjects(ann['segmentation'], image['height'],
301 | image['width'])
302 | else:
303 | rle = ann['segmentation']
304 |
305 | # for i in range(len(rle['counts'])):
306 | # print(rle)
307 | m = mask.decode(rle)
308 | m = np.sum(
309 | m, axis=2
310 | ) # sometimes there are multiple binary map (corresponding to multiple segs)
311 | m = m.astype(np.uint8) # convert to np.uint8
312 | # compute area
313 | area = sum(mask.area(rle)) # should be close to ann['area']
314 | return {'mask': m, 'area': area}
315 | # # position
316 | # position_x = np.mean(np.where(m==1)[1]) # [1] means columns (matlab style) -> x (c style)
317 | # position_y = np.mean(np.where(m==1)[0]) # [0] means rows (matlab style) -> y (c style)
318 | # # mass position (if there were multiple regions, we use the largest one.)
319 | # label_m = label(m, connectivity=m.ndim)
320 | # regions = regionprops(label_m)
321 | # if len(regions) > 0:
322 | # largest_id = np.argmax(np.array([props.filled_area for props in regions]))
323 | # largest_props = regions[largest_id]
324 | # mass_y, mass_x = largest_props.centroid
325 | # else:
326 | # mass_x, mass_y = position_x, position_y
327 | # # if centroid is not in mask, we find the closest point to it from mask
328 | # if m[mass_y, mass_x] != 1:
329 | # print 'Finding closes mask point ...'
330 | # kernel = np.ones((10, 10),np.uint8)
331 | # me = cv2.erode(m, kernel, iterations = 1)
332 | # points = zip(np.where(me == 1)[0].tolist(), np.where(me == 1)[1].tolist()) # row, col style
333 | # points = np.array(points)
334 | # dist = np.sum((points - (mass_y, mass_x))**2, axis=1)
335 | # id = np.argsort(dist)[0]
336 | # mass_y, mass_x = points[id]
337 | # # return
338 | # return {'mask': m, 'area': area, 'position_x': position_x, 'position_y': position_y, 'mass_x': mass_x, 'mass_y': mass_y}
339 | # # show image and mask
340 | # I = io.imread(osp.join(self.IMAGE_DIR, image['file_name']))
341 | # plt.figure()
342 | # plt.imshow(I)
343 | # ax = plt.gca()
344 | # img = np.ones( (m.shape[0], m.shape[1], 3) )
345 | # color_mask = np.array([2.0,166.0,101.0])/255
346 | # for i in range(3):
347 | # img[:,:,i] = color_mask[i]
348 | # ax.imshow(np.dstack( (img, m*0.5) ))
349 | # plt.show()
350 |
351 | def showMask(self, ref):
352 | M = self.getMask(ref)
353 | msk = M['mask']
354 | ax = plt.gca()
355 | ax.imshow(msk)
356 |
357 |
358 | if __name__ == '__main__':
359 | refer = REFER(dataset='refcocog', splitBy='google')
360 | ref_ids = refer.getRefIds()
361 | print(len(ref_ids))
362 |
363 | print(len(refer.Imgs))
364 | print(len(refer.imgToRefs))
365 |
366 | ref_ids = refer.getRefIds(split='train')
367 | print('There are %s training referred objects.' % len(ref_ids))
368 |
369 | for ref_id in ref_ids:
370 | ref = refer.loadRefs(ref_id)[0]
371 | if len(ref['sentences']) < 2:
372 | continue
373 |
374 | pprint(ref)
375 | print('The label is %s.' % refer.Cats[ref['category_id']])
376 | plt.figure()
377 | refer.showRef(ref, seg_box='box')
378 | plt.show()
379 |
380 | # plt.figure()
381 | # refer.showMask(ref)
382 | # plt.show()
383 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import os
4 | import shutil
5 | import sys
6 | import time
7 | import warnings
8 | from functools import partial
9 |
10 | import cv2
11 | import torch
12 | import torch.cuda.amp as amp
13 | import torch.distributed as dist
14 | import torch.multiprocessing as mp #https://blog.csdn.net/hxxjxw/article/details/119839548
15 | import torch.nn as nn
16 | import torch.nn.parallel
17 | import torch.optim
18 | import torch.utils.data as data
19 | from loguru import logger # https://hanjunqiang.blog.csdn.net/article/details/124779625
20 | from torch.optim.lr_scheduler import MultiStepLR
21 |
22 | import utils.config as config
23 | import wandb
24 | from utils.dataset import RefDataset
25 | from engine.engine import train, validate
26 | from model import build_segmenter
27 | from utils.misc import (init_random_seed, set_random_seed, setup_logger,
28 | worker_init_fn, build_scheduler, collate_fn)
29 |
30 | warnings.filterwarnings("ignore")
31 | cv2.setNumThreads(0)
32 |
33 |
34 | def get_parser():
35 | parser = argparse.ArgumentParser(
36 | description='Pytorch Referring Expression Segmentation')
37 | parser.add_argument('--config',
38 | default='path to xxx.yaml',
39 | type=str,
40 | help='config file')
41 | parser.add_argument('--opts',
42 | default=None,
43 | nargs=argparse.REMAINDER,
44 | help='override some settings in the config.')
45 |
46 | args = parser.parse_args()
47 | assert args.config is not None
48 | cfg = config.load_cfg_from_cfg_file(args.config)
49 | if args.opts is not None:
50 | cfg = config.merge_cfg_from_list(cfg, args.opts)
51 | return cfg
52 |
53 |
54 | @logger.catch  # catch exceptions raised in the main process or spawned worker processes
55 | def main():
56 | args = get_parser()
57 | args.manual_seed = init_random_seed(args.manual_seed)
58 | set_random_seed(args.manual_seed, deterministic=True)
59 |
60 | args.ngpus_per_node = torch.cuda.device_count()
61 | args.world_size = args.ngpus_per_node * args.world_size
62 | mp.spawn(main_worker, nprocs=args.ngpus_per_node, args=(args, ))
63 |
64 |
65 | def main_worker(gpu, args):
66 | args.output_dir = os.path.join(args.output_folder, args.exp_name)
67 |
68 | # local rank & global rank
69 | args.gpu = gpu
70 | args.rank = args.rank * args.ngpus_per_node + gpu
71 | torch.cuda.set_device(args.gpu)
72 |
73 | # logger
74 | setup_logger(args.output_dir,
75 | distributed_rank=args.gpu,
76 | filename="train.log",
77 | mode="a")
78 | # dist init
79 | dist.init_process_group(backend=args.dist_backend,
80 | init_method=args.dist_url,
81 | world_size=args.world_size,
82 | rank=args.rank)
83 | # wandb
84 | if args.rank == 0:
85 | wandb.init(job_type="training",
86 | mode="offline",
87 | config=args,
88 | project=args.exp_name,
89 | name=args.exp_name,
90 | tags=[args.dataset])
91 | dist.barrier()
92 | # build model
93 | model, param_list = build_segmenter(args)
94 | # logger.info(model)
95 | logger.info(args)
96 |
97 | # build optimizer & lr scheduler
98 | optimizer = torch.optim.AdamW(param_list,
99 | lr=args.lr,
100 | weight_decay=args.weight_decay,
101 | amsgrad=args.amsgrad
102 | )
103 |
104 | scaler = amp.GradScaler()
105 |
106 | # build dataset
107 | args.batch_size = int(args.batch_size / args.ngpus_per_node)
108 | args.batch_size_val = int(args.batch_size_val / args.ngpus_per_node)
109 | args.workers = int(
110 | (args.workers + args.ngpus_per_node - 1) / args.ngpus_per_node)
111 | train_data = RefDataset(lmdb_dir=args.train_lmdb,
112 | mask_dir=args.mask_root,
113 | dataset=args.dataset,
114 | split=args.train_split,
115 | mode='train',
116 | input_size=args.input_size,
117 | word_length=args.word_len
118 | )
119 | val_data = RefDataset(lmdb_dir=args.val_lmdb,
120 | mask_dir=args.mask_root,
121 | dataset=args.dataset,
122 | split=args.val_split,
123 | mode='val',
124 | input_size=args.input_size,
125 | word_length=args.word_len,
126 | )
127 |
128 | # build dataloader
129 | init_fn = partial(worker_init_fn,
130 | num_workers=args.workers,
131 | rank=args.rank,
132 | seed=args.manual_seed)
133 | train_sampler = data.distributed.DistributedSampler(train_data,
134 | shuffle=True)
135 | val_sampler = data.distributed.DistributedSampler(val_data, shuffle=False)
136 |
137 | train_loader = data.DataLoader(train_data,
138 | batch_size=args.batch_size,
139 | shuffle=False,
140 | num_workers=args.workers,
141 | pin_memory=True,
142 | worker_init_fn=init_fn,
143 | sampler=train_sampler,
144 | collate_fn=collate_fn,
145 | drop_last=True)
146 | val_loader = data.DataLoader(val_data,
147 | batch_size=args.batch_size_val,
148 | shuffle=False,
149 | num_workers=args.workers_val,
150 | pin_memory=True,
151 | sampler=val_sampler,
152 | drop_last=False,
153 | collate_fn=collate_fn,
154 | )
155 |
156 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda x: (1 - x / (len(train_loader) * args.epochs)) ** 0.9)
157 |
158 | best_IoU = 0.0
159 | # resume
160 | if args.resume:
161 | if os.path.isfile(args.resume):
162 | logger.info("=> loading checkpoint '{}'".format(args.resume))
163 | checkpoint = torch.load(
164 | args.resume, map_location=lambda storage, loc: storage.cuda())
165 | args.start_epoch = checkpoint['epoch']
166 | best_IoU = checkpoint["best_iou"]
167 | checkpoint['model_state_dict'].pop('decoder.tokens.weight')
168 | optimizer.load_state_dict(checkpoint['optimizer'])
169 | scheduler.load_state_dict(checkpoint['scheduler'])
170 | logger.info("=> loaded checkpoint '{}' (epoch {})".format(
171 | args.resume, checkpoint['epoch']))
172 | else:
173 | raise ValueError(
174 | "=> resume failed! no checkpoint found at '{}'. Please check args.resume again!"
175 | .format(args.resume))
176 |
177 | # start training
178 | start_time = time.time()
179 | for epoch in range(args.start_epoch, args.epochs):
180 | epoch_log = epoch + 1
181 |
182 | # shuffle loader
183 | train_sampler.set_epoch(epoch_log)
184 |
185 | # train
186 | train(train_loader, model, optimizer, scheduler, scaler, epoch_log, args)
187 |
188 | # evaluation
189 | iou, prec_dict = validate(val_loader, model, epoch_log, args)
190 |
191 | # save model
192 | if dist.get_rank() == 0:
193 | lastname = os.path.join(args.output_dir, "last_model.pth")
194 | torch.save(
195 | {
196 | 'epoch': epoch_log,
197 | 'cur_iou': iou,
198 | 'best_iou': best_IoU,
199 | 'prec': prec_dict,
200 | 'model_state_dict': model.module.state_dict(),
201 | 'optimizer': optimizer.state_dict(),
202 | 'scheduler': scheduler.state_dict()
203 | }, lastname)
204 | if iou >= best_IoU and epoch_log<50:
205 | best_IoU = iou
206 | bestname = os.path.join(args.output_dir, "best_model.pth")
207 | shutil.copyfile(lastname, bestname)
208 |
209 | torch.cuda.empty_cache()
210 |
211 | time.sleep(2)
212 | if dist.get_rank() == 0:
213 | wandb.finish()
214 |
215 | logger.info("* Best IoU={} * ".format(best_IoU))
216 | total_time = time.time() - start_time
217 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
218 | logger.info('* Training time {} *'.format(total_time_str))
219 |
220 |
221 | if __name__ == '__main__':
222 | main()
223 | sys.exit(0)
224 |
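The `LambdaLR` scheduler defined in `train.py` above applies a polynomial decay, scaling the base learning rate by `(1 - step / total_steps) ** 0.9` with `total_steps = len(train_loader) * args.epochs` (assuming `scheduler.step()` is called once per iteration inside `engine.train`, which is not shown here). A small, self-contained sketch of the multiplier with illustrative numbers:

```python
# Illustrative only: total_steps would be len(train_loader) * args.epochs.
total_steps = 1000

def poly_decay(step, total=total_steps, power=0.9):
    """Multiplier applied to the base lr at a given optimizer step."""
    return (1 - step / total) ** power

for step in (0, 250, 500, 750, 999):
    print(step, round(poly_decay(step), 4))
# 0 -> 1.0, 500 -> ~0.536, 999 -> ~0.002: the lr decays smoothly to near zero.
```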
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SooLab/CGFormer/766d7fbe0c0101c80806e2499bb4d6960cfc5f4a/utils/__init__.py
--------------------------------------------------------------------------------
/utils/box_ops.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for bounding box manipulation and GIoU.
3 | """
4 | import torch
5 | from torchvision.ops.boxes import box_area
6 |
7 | def clip_iou(boxes1,boxes2):
8 | area1 = box_area(boxes1)
9 | area2 = box_area(boxes2)
10 | lt = torch.max(boxes1[:, :2], boxes2[:, :2])
11 | rb = torch.min(boxes1[:, 2:], boxes2[:, 2:])
12 | wh = (rb - lt).clamp(min=0)
13 | inter = wh[:,0] * wh[:,1]
14 | union = area1 + area2 - inter
15 | iou = (inter + 1e-6) / (union+1e-6)
16 | return iou
17 |
18 | def multi_iou(boxes1, boxes2):
19 | lt = torch.max(boxes1[...,:2], boxes2[...,:2])
20 | rb = torch.min(boxes1[...,2:], boxes2[...,2:])
21 | wh = (rb - lt).clamp(min=0)
22 | wh_1 = boxes1[...,2:] - boxes1[...,:2]
23 | wh_2 = boxes2[...,2:] - boxes2[...,:2]
24 | inter = wh[...,0] * wh[...,1]
25 | union = wh_1[...,0] * wh_1[...,1] + wh_2[...,0] * wh_2[...,1] - inter
26 | iou = (inter + 1e-6) / (union + 1e-6)
27 | return iou
28 |
29 | def box_cxcywh_to_xyxy(x):
30 | x_c, y_c, w, h = x.unbind(-1)
31 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
32 | (x_c + 0.5 * w), (y_c + 0.5 * h)]
33 | return torch.stack(b, dim=-1)
34 |
35 |
36 | def box_xyxy_to_cxcywh(x):
37 | x0, y0, x1, y1 = x.unbind(-1)
38 | b = [(x0 + x1) / 2, (y0 + y1) / 2,
39 | (x1 - x0), (y1 - y0)]
40 | return torch.stack(b, dim=-1)
41 |
42 |
43 | # modified from torchvision to also return the union
44 | def box_iou(boxes1, boxes2):
45 | area1 = box_area(boxes1)
46 | area2 = box_area(boxes2)
47 |
48 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
49 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
50 |
51 | wh = (rb - lt).clamp(min=0) # [N,M,2]
52 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
53 |
54 | union = area1[:, None] + area2 - inter
55 |
56 | iou = (inter+1e-6) / (union+1e-6)
57 | return iou, union
58 |
59 |
60 | def generalized_box_iou(boxes1, boxes2):
61 | """
62 | Generalized IoU from https://giou.stanford.edu/
63 |
64 | The boxes should be in [x0, y0, x1, y1] format
65 |
66 | Returns a [N, M] pairwise matrix, where N = len(boxes1)
67 | and M = len(boxes2)
68 | """
69 | # degenerate boxes gives inf / nan results
70 | # so do an early check
71 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
72 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
73 | iou, union = box_iou(boxes1, boxes2)
74 |
75 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
76 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
77 |
78 | wh = (rb - lt).clamp(min=0) # [N,M,2]
79 | area = wh[:, :, 0] * wh[:, :, 1]
80 |
81 | return iou - ((area - union) + 1e-6) / (area + 1e-6)
82 |
83 |
84 | def masks_to_boxes(masks):
85 | """Compute the bounding boxes around the provided masks
86 |
87 | The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
88 |
89 | Returns a [N, 4] tensors, with the boxes in xyxy format
90 | """
91 | if masks.numel() == 0:
92 | return torch.zeros((0, 4), device=masks.device)
93 |
94 | h, w = masks.shape[-2:]
95 |
96 | y = torch.arange(0, h, dtype=torch.float)
97 | x = torch.arange(0, w, dtype=torch.float)
98 | y, x = torch.meshgrid(y, x)
99 |
100 | x_mask = (masks * x.unsqueeze(0))
101 | x_max = x_mask.flatten(1).max(-1)[0]
102 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
103 |
104 | y_mask = (masks * y.unsqueeze(0))
105 | y_max = y_mask.flatten(1).max(-1)[0]
106 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
107 |
108 | return torch.stack([x_min, y_min, x_max, y_max], 1)
109 |
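A quick usage sketch for the helpers above (run from the repo root so `utils` is importable); the box values are made up:

```python
import torch
from utils.box_ops import box_cxcywh_to_xyxy, generalized_box_iou

# Two boxes in (cx, cy, w, h) format, e.g. a prediction and a target.
pred = torch.tensor([[0.5, 0.5, 0.4, 0.4]])
gt = torch.tensor([[0.6, 0.6, 0.4, 0.4]])

giou = generalized_box_iou(box_cxcywh_to_xyxy(pred), box_cxcywh_to_xyxy(gt))
print(giou)  # [N, M] pairwise matrix; equals 1.0 only for a perfect match
```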
--------------------------------------------------------------------------------
/utils/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SooLab/CGFormer/766d7fbe0c0101c80806e2499bb4d6960cfc5f4a/utils/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/utils/config.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------------------------
2 | # Functions for parsing args
3 | # -----------------------------------------------------------------------------
4 | import copy
5 | import os
6 | from ast import literal_eval
7 |
8 | import yaml
9 |
10 |
11 | class CfgNode(dict):
12 | """
13 | CfgNode represents an internal node in the configuration tree. It's a simple
14 | dict-like container that allows for attribute-based access to keys.
15 | """
16 | def __init__(self, init_dict=None, key_list=None, new_allowed=False):
17 | # Recursively convert nested dictionaries in init_dict into CfgNodes
18 | init_dict = {} if init_dict is None else init_dict
19 | key_list = [] if key_list is None else key_list
20 | for k, v in init_dict.items():
21 | if type(v) is dict:
22 | # Convert dict to CfgNode
23 | init_dict[k] = CfgNode(v, key_list=key_list + [k])
24 | super(CfgNode, self).__init__(init_dict)
25 |
26 | def __getattr__(self, name):
27 | if name in self:
28 | return self[name]
29 | else:
30 | raise AttributeError(name)
31 |
32 | def __setattr__(self, name, value):
33 | self[name] = value
34 |
35 | def __str__(self):
36 | def _indent(s_, num_spaces):
37 | s = s_.split("\n")
38 | if len(s) == 1:
39 | return s_
40 | first = s.pop(0)
41 | s = [(num_spaces * " ") + line for line in s]
42 | s = "\n".join(s)
43 | s = first + "\n" + s
44 | return s
45 |
46 | r = ""
47 | s = []
48 | for k, v in sorted(self.items()):
49 |             separator = "\n" if isinstance(v, CfgNode) else " "
50 |             attr_str = "{}:{}{}".format(str(k), separator, str(v))
51 | attr_str = _indent(attr_str, 2)
52 | s.append(attr_str)
53 | r += "\n".join(s)
54 | return r
55 |
56 | def __repr__(self):
57 | return "{}({})".format(self.__class__.__name__,
58 | super(CfgNode, self).__repr__())
59 |
60 |
61 | def load_cfg_from_cfg_file(file):
62 | cfg = {}
63 | assert os.path.isfile(file) and file.endswith('.yaml'), \
64 | '{} is not a yaml file'.format(file)
65 |
66 | with open(file, 'r') as f:
67 | cfg_from_file = yaml.safe_load(f)
68 |
69 | for key in cfg_from_file:
70 | for k, v in cfg_from_file[key].items():
71 | cfg[k] = v
72 |
73 | cfg = CfgNode(cfg)
74 | return cfg
75 |
76 |
77 | def merge_cfg_from_list(cfg, cfg_list):
78 | new_cfg = copy.deepcopy(cfg)
79 | assert len(cfg_list) % 2 == 0
80 | for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]):
81 | subkey = full_key.split('.')[-1]
82 | assert subkey in cfg, 'Non-existent key: {}'.format(full_key)
83 | value = _decode_cfg_value(v)
84 | value = _check_and_coerce_cfg_value_type(value, cfg[subkey], subkey,
85 | full_key)
86 | setattr(new_cfg, subkey, value)
87 |
88 | return new_cfg
89 |
90 |
91 | def _decode_cfg_value(v):
92 | """Decodes a raw config value (e.g., from a yaml config files or command
93 | line argument) into a Python object.
94 | """
95 | # All remaining processing is only applied to strings
96 | if not isinstance(v, str):
97 | return v
98 | # Try to interpret `v` as a:
99 | # string, number, tuple, list, dict, boolean, or None
100 | try:
101 | v = literal_eval(v)
102 | # The following two excepts allow v to pass through when it represents a
103 | # string.
104 | #
105 | # Longer explanation:
106 | # The type of v is always a string (before calling literal_eval), but
107 | # sometimes it *represents* a string and other times a data structure, like
108 | # a list. In the case that v represents a string, what we got back from the
109 | # yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is
110 | # ok with '"foo"', but will raise a ValueError if given 'foo'. In other
111 | # cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval
112 | # will raise a SyntaxError.
113 | except ValueError:
114 | pass
115 | except SyntaxError:
116 | pass
117 | return v
118 |
119 |
120 | def _check_and_coerce_cfg_value_type(replacement, original, key, full_key):
121 | """Checks that `replacement`, which is intended to replace `original` is of
122 | the right type. The type is correct if it matches exactly or is one of a few
123 | cases in which the type can be easily coerced.
124 | """
125 | original_type = type(original)
126 | replacement_type = type(replacement)
127 |
128 | # The types must match (with some exceptions)
129 | if replacement_type == original_type:
130 | return replacement
131 |
132 | # Cast replacement from from_type to to_type if the replacement and original
133 | # types match from_type and to_type
134 | def conditional_cast(from_type, to_type):
135 | if replacement_type == from_type and original_type == to_type:
136 | return True, to_type(replacement)
137 | else:
138 | return False, None
139 |
140 | # Conditionally casts
141 | # list <-> tuple
142 | casts = [(tuple, list), (list, tuple)]
143 | # For py2: allow converting from str (bytes) to a unicode string
144 | try:
145 | casts.append((str, unicode)) # noqa: F821
146 | except Exception:
147 | pass
148 |
149 | for (from_type, to_type) in casts:
150 | converted, converted_value = conditional_cast(from_type, to_type)
151 | if converted:
152 | return converted_value
153 |
154 | raise ValueError(
155 | "Type mismatch ({} vs. {}) with values ({} vs. {}) for config "
156 | "key: {}".format(original_type, replacement_type, original,
157 | replacement, full_key))
158 |
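A minimal sketch of how these helpers are used by `train.py` / `test.py`; the YAML layout and path are assumptions, based on `load_cfg_from_cfg_file` flattening the second-level keys of the file into one `CfgNode`:

```python
import utils.config as config

# Hypothetical config/config.yaml:
#   DATA:
#     dataset: refcoco
#     input_size: 416
#   TRAIN:
#     lr: 0.0001
cfg = config.load_cfg_from_cfg_file('config/config.yaml')
print(cfg.dataset, cfg.input_size)  # attribute-style access via CfgNode

# --opts style overrides; only the last component of the dotted key is matched.
cfg = config.merge_cfg_from_list(cfg, ['TRAIN.lr', '0.0005'])
print(cfg.lr)
```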
--------------------------------------------------------------------------------
/utils/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import List, Union
3 | import cv2
4 | from PIL import Image
5 | import lmdb
6 | import numpy as np
7 | import pyarrow as pa
8 | import torch
9 | from torch.utils.data import Dataset
10 | from torchvision.transforms import functional as F
11 | from bert.tokenization_bert import BertTokenizer
12 |
13 | info = {
14 | 'refcoco': {
15 | 'train': 42404,
16 | 'val': 3811,
17 | 'val-test': 3811,
18 | 'testA': 1975,
19 | 'testB': 1810
20 | },
21 | 'refcoco+': {
22 | 'train': 42278,
23 | 'val': 3805,
24 | 'val-test': 3805,
25 | 'testA': 1975,
26 | 'testB': 1798
27 | },
28 | 'refcocog_u': {
29 | 'train': 42226,
30 | 'val': 2573,
31 | 'val-test': 2573,
32 | 'test': 5023
33 | },
34 | 'refcocog_g': {
35 | 'train': 44822,
36 | 'val': 5000,
37 | 'val-test': 5000
38 | }
39 | }
40 | _tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
41 |
42 |
43 | def tokenize(texts: Union[str, List[str]],
44 | context_length: int = 77,
45 |              truncate: bool = False):
46 | """
47 | Returns the tokenized representation of given input string(s)
48 |
49 | Parameters
50 | ----------
51 | texts : Union[str, List[str]]
52 | An input string or a list of input strings to tokenize
53 |
54 | context_length : int
55 |         The context length to use; longer encodings are truncated and shorter ones are zero-padded
56 |
57 | truncate: bool
58 | Whether to truncate the text in case its encoding is longer than the context length
59 |
60 | Returns
61 | -------
62 |     A tuple of two tensors of shape [1, context_length]: the BERT token ids and the corresponding padding mask (1 for real tokens, 0 for padding)
63 | """
64 | l_mask = [0] * context_length
65 | result = [0] * context_length
66 |
67 | tokens = _tokenizer.encode(text=texts, add_special_tokens=True)
68 | tokens = tokens[:context_length]
69 | result[:len(tokens)] = tokens
70 | l_mask[:len(tokens)] = [1]*len(tokens)
71 |
72 | result = torch.tensor(result).unsqueeze(0)
73 | l_mask = torch.tensor(l_mask).unsqueeze(0)
74 | return result, l_mask
75 |
76 |
77 | def loads_pyarrow(buf):
78 | """
79 | Args:
80 | buf: the output of `dumps`.
81 | """
82 | return pa.deserialize(buf)
83 |
84 |
85 | class RefDataset(Dataset):
86 | def __init__(self, lmdb_dir, mask_dir, dataset, split, mode, input_size,
87 | word_length):
88 | super(RefDataset, self).__init__()
89 | self.lmdb_dir = lmdb_dir
90 | self.mask_dir = mask_dir
91 | self.dataset = dataset
92 | self.split = split
93 | self.mode = mode
94 | self.input_size = (input_size, input_size)
95 | #self.mask_size = [13, 26, 52]
96 | self.word_length = word_length
97 | self.mean = torch.tensor([0.485, 0.456, 0.406]).reshape(3, 1, 1)
98 | self.std = torch.tensor([0.229, 0.224, 0.225]).reshape(3, 1, 1)
99 | self.length = info[dataset][split]
100 | self.env = None
101 | # self.coco_transforms = make_coco_transforms(mode, cautious=False)
102 |
103 | def _init_db(self):
104 | self.env = lmdb.open(self.lmdb_dir,
105 | subdir=os.path.isdir(self.lmdb_dir),
106 | readonly=True,
107 | lock=False,
108 | readahead=False,
109 | meminit=False)
110 | with self.env.begin(write=False) as txn:
111 | self.length = loads_pyarrow(txn.get(b'__len__'))
112 | self.keys = loads_pyarrow(txn.get(b'__keys__'))
113 |
114 | def __len__(self):
115 | return self.length
116 |
117 | def __getitem__(self, index):
118 | # Delay loading LMDB data until after initialization: https://github.com/chainer/chainermn/issues/129
119 | if self.env is None:
120 | self._init_db()
121 | env = self.env
122 | with env.begin(write=False) as txn:
123 | byteflow = txn.get(self.keys[index])
124 | ref = loads_pyarrow(byteflow)
125 | # img
126 | ori_img = cv2.imdecode(np.frombuffer(ref['img'], np.uint8),
127 | cv2.IMREAD_COLOR)
128 | img = cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB)
129 | img_size = img.shape[:2]
130 | # mask
131 | seg_id = ref['seg_id']
132 | mask_dir = os.path.join(self.mask_dir, str(seg_id) + '.png')
133 | # sentences
134 | idx = np.random.choice(ref['num_sents'])
135 | sents = ref['sents']
136 | # transform
137 | # mask transform
138 | mask = cv2.imdecode(np.frombuffer(ref['mask'], np.uint8),
139 | cv2.IMREAD_GRAYSCALE)
140 | mask = mask / 255.
141 | if self.mode == 'train':
142 | sent = sents[idx]
143 | # sentence -> vector
144 | img, mask, sent = self.convert(img, mask, sent, inference=False)
145 | word_vec, pad_mask = tokenize(sent, self.word_length, True)
146 | return img, word_vec, mask, pad_mask
147 | elif self.mode == 'val':
148 | # sentence -> vector
149 | sent = sents[-1]
150 | word_vec, pad_mask = tokenize(sent, self.word_length, True)
151 | img, mask, sent = self.convert(img, mask, sent, inference=False)
152 | return img, word_vec, mask, pad_mask
153 | else:
154 | # sentence -> vector
155 | word_vecs = []
156 | pad_masks = []
157 | for sent in sents:
158 | word_vec, pad_mask = tokenize(sent, self.word_length, True)
159 | word_vecs.append(word_vec)
160 | pad_masks.append(pad_mask)
161 | img, mask, sent = self.convert(img, mask, sent, inference=True)
162 | return ori_img, img, word_vecs, mask, pad_masks, seg_id, sents,
163 |
164 | def convert(self, img, mask, sent, inference=False):
165 | img = Image.fromarray(np.uint8(img))
166 | mask = Image.fromarray(np.uint8(mask), mode="P")
167 | img = F.resize(img, self.input_size)
168 | if not inference:
169 | mask = F.resize(mask, self.input_size, interpolation=Image.NEAREST)
170 | img = F.to_tensor(img)
171 | mask = torch.as_tensor(np.asarray(mask).copy(), dtype=torch.int64)
172 | img = F.normalize(img, mean=self.mean, std=self.std)
173 | return img, mask, sent
174 |
175 |
176 | def __repr__(self):
177 | return self.__class__.__name__ + "(" + \
178 | f"db_path={self.lmdb_dir}, " + \
179 | f"dataset={self.dataset}, " + \
180 | f"split={self.split}, " + \
181 | f"mode={self.mode}, " + \
182 | f"input_size={self.input_size}, " + \
183 | f"word_length={self.word_length}"
184 |
185 |
186 |
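A short usage sketch for `tokenize` above (assuming the repo root is on `PYTHONPATH` and the `bert-base-uncased` vocabulary can be fetched by the bundled `bert` package):

```python
from utils.dataset import tokenize

word_vec, pad_mask = tokenize('the woman in the red coat', context_length=20)
print(word_vec.shape, pad_mask.shape)  # both torch.Size([1, 20])
print(pad_mask)  # 1 where a real token (incl. [CLS]/[SEP]) sits, 0 for padding
```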
--------------------------------------------------------------------------------
/utils/dataset_open.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import List, Union
3 | import cv2
4 | from PIL import Image
5 | import lmdb
6 | import numpy as np
7 | import pyarrow as pa
8 | import torch
9 | from torch.utils.data import Dataset
10 | from torchvision.transforms import functional as F
11 | from bert.tokenization_bert import BertTokenizer
12 | import random
13 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer
14 | # from nltk import word_tokenize, pos_tag
15 | info = {
16 | 'refcoco': {
17 | 'train_seen': 35473,
18 | 'val_seen': 3175,
19 | 'val_unseen': 445,
20 | 'test_seen':3200,
21 | 'test_unseen':394
22 | },
23 | 'refcoco+': {
24 | 'train_seen': 35375,
25 | 'val_seen': 3171,
26 | 'val_unseen': 444,
27 | 'test_seen':3189,
28 | 'test_unseen':394
29 | },
30 | 'refcoco_u': {
31 | 'train_seen': 33093,
32 | 'val_seen': 2000,
33 | 'val_unseen': 386,
34 | 'test_seen': 3935,
35 | 'test_unseen': 759,
36 | },
37 | 'refcoco_g': {
38 | 'train_seen': 35105,
39 | 'val_seen': 3923,
40 | 'val_unseen': 760
41 | }
42 | }
43 | _tokenizer = _Tokenizer()
44 |
45 |
46 | def tokenize(texts: Union[str, List[str]],
47 | context_length: int = 77,
48 | truncate: bool = False) -> torch.LongTensor:
49 | """
50 | Returns the tokenized representation of given input string(s)
51 |
52 | Parameters
53 | ----------
54 | texts : Union[str, List[str]]
55 | An input string or a list of input strings to tokenize
56 |
57 | context_length : int
58 | The context length to use; all CLIP models use 77 as the context length
59 |
60 | truncate: bool
61 | Whether to truncate the text in case its encoding is longer than the context length
62 |
63 | Returns
64 | -------
65 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
66 | """
67 | if isinstance(texts, str):
68 | texts = [texts]
69 |
70 | sot_token = _tokenizer.encoder["<|startoftext|>"]
71 | eot_token = _tokenizer.encoder["<|endoftext|>"]
72 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token]
73 | for text in texts]
74 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
75 |
76 | for i, tokens in enumerate(all_tokens):
77 | if len(tokens) > context_length:
78 | if truncate:
79 | tokens = tokens[:context_length]
80 | tokens[-1] = eot_token
81 | else:
82 | raise RuntimeError(
83 | f"Input {texts[i]} is too long for context length {context_length}"
84 | )
85 | result[i, :len(tokens)] = torch.tensor(tokens)
86 |
87 | return result
88 |
89 |
90 | def loads_pyarrow(buf):
91 | """
92 | Args:
93 | buf: the output of `dumps`.
94 | """
95 | return pa.deserialize(buf)
96 |
97 |
98 | class RefDataset(Dataset):
99 | def __init__(self, lmdb_dir, mask_dir, dataset, split, mode, input_size,
100 | word_length):
101 | super(RefDataset, self).__init__()
102 | self.lmdb_dir = lmdb_dir
103 | self.mask_dir = mask_dir
104 | self.dataset = dataset
105 | self.split = split
106 | self.mode = mode
107 | self.input_size = (input_size, input_size)
108 | #self.mask_size = [13, 26, 52]
109 | self.word_length = word_length
110 | self.mean = torch.tensor([0.485, 0.456, 0.406]).reshape(3, 1, 1)
111 | self.std = torch.tensor([0.229, 0.224, 0.225]).reshape(3, 1, 1)
112 | self.length = info[dataset][split]
113 | self.env = None
114 | def _init_db(self):
115 | self.env = lmdb.open(self.lmdb_dir,
116 | subdir=os.path.isdir(self.lmdb_dir),
117 | readonly=True,
118 | lock=False,
119 | readahead=False,
120 | meminit=False)
121 | with self.env.begin(write=False) as txn:
122 | self.length = loads_pyarrow(txn.get(b'__len__'))
123 | self.keys = loads_pyarrow(txn.get(b'__keys__'))
124 |
125 | def __len__(self):
126 | return self.length
127 |
128 | def __getitem__(self, index):
129 | # Delay loading LMDB data until after initialization: https://github.com/chainer/chainermn/issues/129
130 | if self.env is None:
131 | self._init_db()
132 | env = self.env
133 | with env.begin(write=False) as txn:
134 | byteflow = txn.get(self.keys[index])
135 | ref = loads_pyarrow(byteflow)
136 | # img
137 | ori_img = cv2.imdecode(np.frombuffer(ref['img'], np.uint8),
138 | cv2.IMREAD_COLOR)
139 | img = cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB)
140 | img_size = img.shape[:2]
141 | # mask
142 | seg_id = ref['seg_id']
143 | mask_dir = os.path.join(self.mask_dir, str(seg_id) + '.png')
144 | # sentences
145 | idx = np.random.choice(ref['num_sents'])
146 | sents = ref['sents']
147 | # transform
148 | # mask transform
149 | mask = cv2.imdecode(np.frombuffer(ref['mask'], np.uint8),
150 | cv2.IMREAD_GRAYSCALE)
151 | mask = mask / 255.
152 | if self.mode == 'train':
153 | sent = sents[idx]
154 | # sentence -> vector
155 | img, mask, sent = self.convert(img, mask, sent, inference=False)
156 | word_vec = tokenize(sent, self.word_length, True).squeeze(0)
157 | pad_mask = (word_vec != 0).float()
158 | return img, word_vec, mask, pad_mask
159 | elif self.mode == 'val':
160 | # sentence -> vector
161 | sent = sents[-1]
162 | word_vec = tokenize(sent, self.word_length, True).squeeze(0)
163 | pad_mask = (word_vec != 0).float()
164 | img, mask, sent = self.convert(img, mask, sent, inference=False)
165 | return img, word_vec, mask, pad_mask
166 | else:
167 | # sentence -> vector
168 | word_vecs = []
169 | pad_masks = []
170 | for sent in sents:
171 | word_vec = tokenize(sent, self.word_length, True).squeeze(0)
172 | word_vecs.append(word_vec)
173 | pad_mask = (word_vec != 0).float()
174 | pad_masks.append(pad_mask)
175 | img, mask, sent = self.convert(img, mask, sent, inference=True)
176 | return ori_img, img, word_vecs, mask, pad_masks, seg_id, sents
177 |
178 | def convert(self, img, mask, sent, inference=False):
179 | img = Image.fromarray(np.uint8(img))
180 | mask = Image.fromarray(np.uint8(mask), mode="P")
181 | img = F.resize(img, self.input_size)
182 | if not inference:
183 | mask = F.resize(mask, self.input_size, interpolation=Image.NEAREST)
184 | img = F.to_tensor(img)
185 | mask = torch.as_tensor(np.asarray(mask).copy(), dtype=torch.int64)
186 | img = F.normalize(img, mean=self.mean, std=self.std)
187 | return img, mask, sent
188 |
189 | def __repr__(self):
190 | return self.__class__.__name__ + "(" + \
191 | f"db_path={self.lmdb_dir}, " + \
192 | f"dataset={self.dataset}, " + \
193 | f"split={self.split}, " + \
194 | f"mode={self.mode}, " + \
195 | f"input_size={self.input_size}, " + \
196 | f"word_length={self.word_length}"
197 |
--------------------------------------------------------------------------------
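A minimal usage sketch for the dataset above (hypothetical paths, split names and hyper-parameters; it assumes an LMDB built by tools/folder2lmdb.py and the `info` split table defined earlier in utils/dataset.py):

    from torch.utils.data import DataLoader

    # Placeholder values, not the official training configuration.
    train_data = RefDataset(lmdb_dir='datasets/lmdb/refcoco/train.lmdb',
                            mask_dir='datasets/masks/refcoco',
                            dataset='refcoco',
                            split='train',
                            mode='train',
                            input_size=480,
                            word_length=17)
    train_loader = DataLoader(train_data, batch_size=32, shuffle=True,
                              num_workers=4, pin_memory=True)
    # In 'train' mode each sample is (img, word_vec, mask, pad_mask).
    img, word_vec, mask, pad_mask = next(iter(train_loader))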
/utils/misc.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import numpy as np
4 | from PIL import Image
5 | from loguru import logger
6 | import sys
7 | import inspect
8 | from timm.scheduler.cosine_lr import CosineLRScheduler
9 | import torch
10 | from torch import nn
11 | import torch.distributed as dist
12 |
13 |
14 | def init_random_seed(seed=None, device='cuda', rank=0, world_size=1):
15 | """Initialize random seed."""
16 | if seed is not None:
17 | return seed
18 |
19 | # Make sure all ranks share the same random seed to prevent
20 | # some potential bugs. Please refer to
21 | # https://github.com/open-mmlab/mmdetection/issues/6339
22 | seed = np.random.randint(2**31)
23 | if world_size == 1:
24 | return seed
25 |
26 | if rank == 0:
27 | random_num = torch.tensor(seed, dtype=torch.int32, device=device)
28 | else:
29 | random_num = torch.tensor(0, dtype=torch.int32, device=device)
30 | dist.broadcast(random_num, src=0)
31 | return random_num.item()
32 |
33 |
34 | def set_random_seed(seed, deterministic=False):
35 | """Set random seed."""
36 | random.seed(seed)
37 | np.random.seed(seed)
38 | torch.manual_seed(seed)
39 | torch.cuda.manual_seed_all(seed)
40 | if deterministic:
41 | torch.backends.cudnn.deterministic = True
42 | torch.backends.cudnn.benchmark = False
43 |
44 |
45 | @torch.no_grad()
46 | def concat_all_gather(tensor):
47 | """
48 | Performs all_gather operation on the provided tensors.
49 | *** Warning ***: torch.distributed.all_gather has no gradient.
50 | """
51 | tensor = tensor.contiguous()
52 | tensors_gather = [
53 | torch.ones_like(tensor)
54 | for _ in range(torch.distributed.get_world_size())
55 | ]
56 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
57 |
58 | output = torch.cat(tensors_gather, dim=0)
59 | return output
60 |
61 |
62 |
63 | def worker_init_fn(worker_id, num_workers, rank, seed):
64 |     # The seed of each worker equals
65 |     # num_workers * rank + worker_id + user_seed
66 | worker_seed = num_workers * rank + worker_id + seed
67 | np.random.seed(worker_seed)
68 | random.seed(worker_seed)
69 |
70 |
71 | class AverageMeter(object):
72 | """Computes and stores the average and current value"""
73 |
74 | def __init__(self, name, fmt=":f"):
75 | self.name = name
76 | self.fmt = fmt
77 | self.reset()
78 |
79 | def reset(self):
80 | self.val = 0
81 | self.avg = 0
82 | self.sum = 0
83 | self.count = 0
84 |
85 | def update(self, val, n=1):
86 | self.val = val
87 | self.sum += val * n
88 | self.count += n
89 | self.avg = self.sum / self.count
90 |
91 | def __str__(self):
92 | if self.name == "Lr":
93 | fmtstr = "{name}={val" + self.fmt + "}"
94 | else:
95 | fmtstr = "{name}={val" + self.fmt + "} ({avg" + self.fmt + "})"
96 | return fmtstr.format(**self.__dict__)
97 |
98 |
99 | class ProgressMeter(object):
100 | def __init__(self, num_batches, meters, prefix=""):
101 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
102 | self.meters = meters
103 | self.prefix = prefix
104 |
105 | def display(self, batch):
106 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
107 | entries += [str(meter) for meter in self.meters]
108 | logger.info(" ".join(entries))
109 |
110 | def _get_batch_fmtstr(self, num_batches):
111 |         num_digits = len(str(num_batches))
112 | fmt = "{:" + str(num_digits) + "d}"
113 | return "[" + fmt + "/" + fmt.format(num_batches) + "]"
114 |
115 |
116 | def trainMetricGPU(output, target, threshold=0.35, pr_iou=0.5):
117 | assert (output.dim() in [2, 3, 4])
118 | assert output.shape == target.shape
119 | output = output.flatten(1)
120 | target = target.flatten(1)
121 | output = torch.sigmoid(output)
122 | output[output < threshold] = 0.
123 | output[output >= threshold] = 1.
124 | # inter & union
125 | inter = (output.bool() & target.bool()).sum(dim=1) # b
126 | union = (output.bool() | target.bool()).sum(dim=1) # b
127 | ious = inter / (union + 1e-6) # 0 ~ 1
128 | # iou & pr@5
129 | iou = ious.mean()
130 | prec = (ious > pr_iou).float().mean()
131 | return 100. * iou, 100. * prec
132 |
133 | def ValMetricGPU(output, target, threshold=0.35):
134 | assert output.size(0) == 1
135 | output = output.flatten(1)
136 | target = target.flatten(1)
137 | output = torch.sigmoid(output)
138 | output[output < threshold] = 0.
139 | output[output >= threshold] = 1.
140 | # inter & union
141 | inter = (output.bool() & target.bool()).sum(dim=1) # b
142 | union = (output.bool() | target.bool()).sum(dim=1) # b
143 | ious = inter / (union + 1e-6) # 0 ~ 1
144 | return ious
145 |
146 |
147 | def intersectionAndUnionGPU(output, target, K, threshold=0.5):
148 | # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1.
149 | assert (output.dim() in [1, 2, 3])
150 | assert output.shape == target.shape
151 | output = output.view(-1)
152 | target = target.view(-1)
153 |
154 | output = torch.sigmoid(output)
155 | output[output < threshold] = 0.
156 | output[output >= threshold] = 1.
157 |
158 | intersection = output[output == target]
159 | area_intersection = torch.histc(intersection.float(),
160 | bins=K,
161 | min=0,
162 | max=K - 1)
163 | area_output = torch.histc(output.float(), bins=K, min=0, max=K - 1)
164 | area_target = torch.histc(target.float(), bins=K, min=0, max=K - 1)
165 | area_union = area_output + area_target - area_intersection
166 |     return area_intersection[1], area_union[1]  # keep the foreground (class 1) bins only
167 |
168 |
169 | def group_weight(weight_group, module, lr):
170 | group_decay = []
171 | group_no_decay = []
172 | for m in module.modules():
173 | if isinstance(m, nn.Linear):
174 | group_decay.append(m.weight)
175 | if m.bias is not None:
176 | group_no_decay.append(m.bias)
177 | elif isinstance(m, nn.modules.conv._ConvNd):
178 | group_decay.append(m.weight)
179 | if m.bias is not None:
180 | group_no_decay.append(m.bias)
181 | elif isinstance(m, nn.modules.batchnorm._BatchNorm):
182 | if m.weight is not None:
183 | group_no_decay.append(m.weight)
184 | if m.bias is not None:
185 | group_no_decay.append(m.bias)
186 | assert len(list(
187 | module.parameters())) == len(group_decay) + len(group_no_decay)
188 | weight_group.append(dict(params=group_decay, lr=lr))
189 | weight_group.append(dict(params=group_no_decay, weight_decay=.0, lr=lr))
190 | return weight_group
191 |
192 |
193 | def colorize(gray, palette):
194 |     # gray: numpy array of label indices; palette: flat list of 3*N RGB values
195 | color = Image.fromarray(gray.astype(np.uint8)).convert('P')
196 | color.putpalette(palette)
197 | return color
198 |
199 |
200 | def find_free_port():
201 | import socket
202 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
203 | # Binding to port 0 will cause the OS to find an available port for us
204 | sock.bind(("", 0))
205 | port = sock.getsockname()[1]
206 | sock.close()
207 | # NOTE: there is still a chance the port could be taken by other processes.
208 | return port
209 |
210 |
211 | def get_caller_name(depth=0):
212 | """
213 | Args:
214 |         depth (int): Depth of caller context, use 0 for caller depth.
215 | Default value: 0.
216 |
217 | Returns:
218 | str: module name of the caller
219 | """
220 | # the following logic is a little bit faster than inspect.stack() logic
221 | frame = inspect.currentframe().f_back
222 | for _ in range(depth):
223 | frame = frame.f_back
224 |
225 | return frame.f_globals["__name__"]
226 |
227 |
228 | class StreamToLoguru:
229 | """
230 | stream object that redirects writes to a logger instance.
231 | """
232 | def __init__(self, level="INFO", caller_names=("apex", "pycocotools")):
233 | """
234 | Args:
235 | level(str): log level string of loguru. Default value: "INFO".
236 | caller_names(tuple): caller names of redirected module.
237 | Default value: (apex, pycocotools).
238 | """
239 | self.level = level
240 | self.linebuf = ""
241 | self.caller_names = caller_names
242 |
243 | def write(self, buf):
244 | full_name = get_caller_name(depth=1)
245 | module_name = full_name.rsplit(".", maxsplit=-1)[0]
246 | if module_name in self.caller_names:
247 | for line in buf.rstrip().splitlines():
248 | # use caller level log
249 | logger.opt(depth=2).log(self.level, line.rstrip())
250 | else:
251 | sys.__stdout__.write(buf)
252 |
253 | def flush(self):
254 | pass
255 |
256 |
257 | def redirect_sys_output(log_level="INFO"):
258 | redirect_logger = StreamToLoguru(log_level)
259 | sys.stderr = redirect_logger
260 | sys.stdout = redirect_logger
261 |
262 |
263 | def setup_logger(save_dir, distributed_rank=0, filename="log.txt", mode="a"):
264 | """setup logger for training and testing.
265 | Args:
266 | save_dir(str): location to save log file
267 | distributed_rank(int): device rank when multi-gpu environment
268 | filename (string): log save name.
269 |         mode(str): log file write mode, "a" (append) or "o" (override). Default is "a".
270 |
271 |     Returns:
272 |         None. The global loguru logger is configured in place.
273 | """
274 | loguru_format = (
275 | "{time:YYYY-MM-DD HH:mm:ss} | "
276 | "{level: <8} | "
277 | "{name}:{line} - {message}")
278 |
279 | logger.remove()
280 | save_file = os.path.join(save_dir, filename)
281 | if mode == "o" and os.path.exists(save_file):
282 | os.remove(save_file)
283 | # only keep logger in rank0 process
284 | if distributed_rank == 0:
285 | logger.add(
286 | sys.stderr,
287 | format=loguru_format,
288 | level="INFO",
289 | enqueue=True,
290 | )
291 | logger.add(save_file)
292 |
293 | # redirect stdout/stderr to loguru
294 | redirect_sys_output("INFO")
295 |
296 |
297 | def build_scheduler(config, optimizer, n_iter_per_epoch):
298 | num_steps = int(config.epochs * n_iter_per_epoch)
299 | warmup_steps = int(config.warmup_epochs * n_iter_per_epoch)
300 |
301 | lr_scheduler = CosineLRScheduler(
302 | optimizer,
303 | t_initial=num_steps,
304 | lr_min=config.min_lr,
305 | warmup_lr_init=config.warmup_lr,
306 | warmup_t=warmup_steps,
307 | cycle_limit=1,
308 | t_in_epochs=False,
309 | )
310 | return lr_scheduler
311 |
312 | def collate_fn(batch):
313 | batch = list(zip(*batch))
314 | return tuple(batch)
--------------------------------------------------------------------------------
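A brief sketch of how the meters and the IoU helper in utils/misc.py fit together in a training loop (tensor shapes, thresholds and meter names are illustrative only):

    import torch

    iou_meter = AverageMeter('IoU', ':2.2f')
    prec_meter = AverageMeter('Prec@50', ':2.2f')
    progress = ProgressMeter(100, [iou_meter, prec_meter], prefix='Training: ')

    # Stand-ins for model logits and binary ground-truth masks of shape (B, H, W).
    pred = torch.randn(8, 480, 480)
    target = (torch.rand(8, 480, 480) > 0.5).float()

    iou, prec = trainMetricGPU(pred, target, threshold=0.35, pr_iou=0.5)
    iou_meter.update(iou.item(), n=pred.size(0))
    prec_meter.update(prec.item(), n=pred.size(0))
    progress.display(10)  # logs e.g. "Training: [ 10/100] IoU=... (...) Prec@50=... (...)"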
/utils/simple_tokenizer.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import html
3 | import os
4 | from functools import lru_cache
5 |
6 | import ftfy
7 | import regex as re
8 |
9 |
10 | @lru_cache()
11 | def default_bpe():
12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13 |
14 |
15 | @lru_cache()
16 | def bytes_to_unicode():
17 | """
18 |     Returns a list of utf-8 bytes and a corresponding list of unicode strings.
19 |     The reversible bpe codes work on unicode strings.
20 |     This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21 |     When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22 |     This is a significant percentage of your normal, say, 32K bpe vocab.
23 |     To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24 |     This also avoids mapping to whitespace/control characters the bpe code barfs on.
25 | """
26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27 | cs = bs[:]
28 | n = 0
29 | for b in range(2**8):
30 | if b not in bs:
31 | bs.append(b)
32 | cs.append(2**8+n)
33 | n += 1
34 | cs = [chr(n) for n in cs]
35 | return dict(zip(bs, cs))
36 |
37 |
38 | def get_pairs(word):
39 | """Return set of symbol pairs in a word.
40 | Word is represented as tuple of symbols (symbols being variable-length strings).
41 | """
42 | pairs = set()
43 | prev_char = word[0]
44 | for char in word[1:]:
45 | pairs.add((prev_char, char))
46 | prev_char = char
47 | return pairs
48 |
49 |
50 | def basic_clean(text):
51 | text = ftfy.fix_text(text)
52 | text = html.unescape(html.unescape(text))
53 | return text.strip()
54 |
55 |
56 | def whitespace_clean(text):
57 | text = re.sub(r'\s+', ' ', text)
58 | text = text.strip()
59 | return text
60 |
61 |
62 | class SimpleTokenizer(object):
63 | def __init__(self, bpe_path: str = default_bpe()):
64 | self.byte_encoder = bytes_to_unicode()
65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67 | merges = merges[1:49152-256-2+1]
68 | merges = [tuple(merge.split()) for merge in merges]
69 | vocab = list(bytes_to_unicode().values())
70 |         vocab = vocab + [v+'</w>' for v in vocab]
71 | for merge in merges:
72 | vocab.append(''.join(merge))
73 | vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74 | self.encoder = dict(zip(vocab, range(len(vocab))))
75 | self.decoder = {v: k for k, v in self.encoder.items()}
76 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79 |
80 | def bpe(self, token):
81 | if token in self.cache:
82 | return self.cache[token]
83 |         word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84 | pairs = get_pairs(word)
85 |
86 | if not pairs:
87 |             return token+'</w>'
88 |
89 | while True:
90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91 | if bigram not in self.bpe_ranks:
92 | break
93 | first, second = bigram
94 | new_word = []
95 | i = 0
96 | while i < len(word):
97 | try:
98 | j = word.index(first, i)
99 | new_word.extend(word[i:j])
100 | i = j
101 |                 except ValueError:
102 | new_word.extend(word[i:])
103 | break
104 |
105 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
106 | new_word.append(first+second)
107 | i += 2
108 | else:
109 | new_word.append(word[i])
110 | i += 1
111 | new_word = tuple(new_word)
112 | word = new_word
113 | if len(word) == 1:
114 | break
115 | else:
116 | pairs = get_pairs(word)
117 | word = ' '.join(word)
118 | self.cache[token] = word
119 | return word
120 |
121 | def encode(self, text):
122 | bpe_tokens = []
123 | text = whitespace_clean(basic_clean(text)).lower()
124 | for token in re.findall(self.pat, text):
125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127 | return bpe_tokens
128 |
129 | def decode(self, tokens):
130 | text = ''.join([self.decoder[token] for token in tokens])
131 |         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132 | return text
133 |
--------------------------------------------------------------------------------
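A quick round trip through SimpleTokenizer (requires ftfy, regex and the bundled utils/bpe_simple_vocab_16e6.txt.gz; the sample sentence is arbitrary):

    tokenizer = SimpleTokenizer()
    ids = tokenizer.encode("the large dog on the left")  # list of BPE token ids
    text = tokenizer.decode(ids)                         # "the large dog on the left " (lower-cased, '</w>' rendered as spaces)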