├── README.md
├── cpp
│   ├── Tokenizer.hpp
│   ├── images
│   │   ├── 0.jpg
│   │   ├── 1.jpg
│   │   ├── 2.jpg
│   │   ├── 3.jpg
│   │   ├── 4.jpeg
│   │   ├── 5.jpg
│   │   ├── 6.jpg
│   │   ├── 7.jpg
│   │   ├── cat_dog.jpeg
│   │   └── demo7.jpg
│   ├── main.cpp
│   ├── string_utility.hpp
│   └── vocab.txt
├── export_onnx.py
└── python
    ├── clip_tokenizer.py
    ├── images
    │   ├── 0.jpg
    │   ├── 1.jpg
    │   ├── 2.jpg
    │   ├── 3.jpg
    │   ├── 4.jpeg
    │   ├── 5.jpg
    │   ├── 6.jpg
    │   ├── 7.jpg
    │   ├── cat_dog.jpeg
    │   └── demo7.jpg
    ├── main.py
    └── vocab.txt

/README.md:
--------------------------------------------------------------------------------
1 | When running the program, mind the format of the text prompt: category names are separated by " . " (for example "chair . person . dog ."), and every category name must exist in the vocabulary file
2 | vocab.txt. The category names in the prompt are the target classes you want to detect; otherwise the targets may not be detected.
3 | 
4 | To export the ONNX file, put export_onnx.py into https://github.com/wenyi5608/GroundingDINO
5 | and run it there to generate the ONNX file; note that PyTorch 2.0 or later is required. The code in that repository differs from the official repository https://github.com/IDEA-Research/GroundingDINO
6 | only in
7 | the input arguments of the forward function in groundingdino\models\GroundingDINO\groundingdino.py.
8 | 
9 | The exported ONNX files are on Baidu Netdisk, download link: https://pan.baidu.com/s/1_dDxaSMG2vbw47FJ7FdUUg
10 | Extraction code: u6lr
11 | 
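12 | For example, with export_onnx.py copied into the root of the wenyi5608/GroundingDINO repository, the SwinT model can be exported as follows (the weights/ paths come from the example command at the bottom of export_onnx.py; adjust them to your own layout):
13 | 
14 |     python export_onnx.py -c groundingdino/config/GroundingDINO_SwinT_OGC.py -p weights/groundingdino_swint_ogc.pth -o weights/
15 | 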
--------------------------------------------------------------------------------
/cpp/Tokenizer.hpp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpc203/GroundingDINO-onnxrun/1197ebcc652013871efd634488850a7e42b82b85/cpp/Tokenizer.hpp
--------------------------------------------------------------------------------
/cpp/images/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpc203/GroundingDINO-onnxrun/1197ebcc652013871efd634488850a7e42b82b85/cpp/images/0.jpg
--------------------------------------------------------------------------------
/cpp/images/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpc203/GroundingDINO-onnxrun/1197ebcc652013871efd634488850a7e42b82b85/cpp/images/1.jpg
--------------------------------------------------------------------------------
/cpp/images/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpc203/GroundingDINO-onnxrun/1197ebcc652013871efd634488850a7e42b82b85/cpp/images/2.jpg
--------------------------------------------------------------------------------
/cpp/images/3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpc203/GroundingDINO-onnxrun/1197ebcc652013871efd634488850a7e42b82b85/cpp/images/3.jpg
--------------------------------------------------------------------------------
/cpp/images/4.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpc203/GroundingDINO-onnxrun/1197ebcc652013871efd634488850a7e42b82b85/cpp/images/4.jpeg
--------------------------------------------------------------------------------
/cpp/images/5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpc203/GroundingDINO-onnxrun/1197ebcc652013871efd634488850a7e42b82b85/cpp/images/5.jpg
--------------------------------------------------------------------------------
/cpp/images/6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpc203/GroundingDINO-onnxrun/1197ebcc652013871efd634488850a7e42b82b85/cpp/images/6.jpg
--------------------------------------------------------------------------------
/cpp/images/7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpc203/GroundingDINO-onnxrun/1197ebcc652013871efd634488850a7e42b82b85/cpp/images/7.jpg
--------------------------------------------------------------------------------
/cpp/images/cat_dog.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpc203/GroundingDINO-onnxrun/1197ebcc652013871efd634488850a7e42b82b85/cpp/images/cat_dog.jpeg
--------------------------------------------------------------------------------
/cpp/images/demo7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpc203/GroundingDINO-onnxrun/1197ebcc652013871efd634488850a7e42b82b85/cpp/images/demo7.jpg
--------------------------------------------------------------------------------
/cpp/main.cpp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpc203/GroundingDINO-onnxrun/1197ebcc652013871efd634488850a7e42b82b85/cpp/main.cpp
--------------------------------------------------------------------------------
/cpp/string_utility.hpp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpc203/GroundingDINO-onnxrun/1197ebcc652013871efd634488850a7e42b82b85/cpp/string_utility.hpp
--------------------------------------------------------------------------------
/export_onnx.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import torch
4 | os.environ["KMP_DUPLICATE_LIB_OK"] = 'TRUE'
5 | from groundingdino.models import build_model
6 | from groundingdino.util.slconfig import SLConfig
7 | from groundingdino.util.utils import clean_state_dict
8 | 
9 | def load_model(model_config_path, model_checkpoint_path, cpu_only=False):
10 |     args = SLConfig.fromfile(model_config_path)
11 |     args.device = "cuda" if not cpu_only else "cpu"
12 | 
13 |     # modified config
14 |     args.use_checkpoint = False
15 |     args.use_transformer_ckpt = False
16 | 
17 |     model = build_model(args)
18 |     checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
19 |     model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
20 |     _ = model.eval()
21 |     return model
22 | 
23 | 
24 | def export_onnx(model, output_dir):
25 |     caption = "the running dog ."  # ". ".join(input_text)
26 |     input_ids = model.tokenizer([caption], return_tensors="pt")["input_ids"]
27 |     position_ids = torch.tensor([[0, 0, 1, 2, 3, 0]])
28 |     token_type_ids = torch.tensor([[0, 0, 0, 0, 0, 0]])
29 |     attention_mask = torch.tensor([[True, True, True, True, True, True]])
30 |     text_token_mask = torch.tensor([[[True, False, False, False, False, False],
31 |                                      [False, True, True, True, True, False],
32 |                                      [False, True, True, True, True, False],
33 |                                      [False, True, True, True, True, False],
34 |                                      [False, True, True, True, True, False],
35 |                                      [False, False, False, False, False, True]]])
36 | 
37 |     img = torch.randn(1, 3, 800, 1200)
38 | 
39 |     dynamic_axes = {
40 |         "input_ids": {0: "batch_size", 1: "seq_len"},
41 |         "attention_mask": {0: "batch_size", 1: "seq_len"},
42 |         "position_ids": {0: "batch_size", 1: "seq_len"},
43 |         "token_type_ids": {0: "batch_size", 1: "seq_len"},
44 |         "text_token_mask": {0: "batch_size", 1: "seq_len", 2: "seq_len"},
45 |         "img": {0: "batch_size", 2: "height", 3: "width"},
46 |         "logits": {0: "batch_size"},
47 |         "boxes": {0: "batch_size"}
48 |     }
49 | 
50 |     # export onnx model
51 |     torch.onnx.export(
52 |         model,
53 |         f=os.path.join(output_dir, "groundingdino.onnx"),
54 |         args=(img, input_ids, attention_mask, position_ids, token_type_ids, text_token_mask),  # , zeros, ones),
55 |         input_names=["img", "input_ids", "attention_mask", "position_ids", "token_type_ids", "text_token_mask"],
56 |         output_names=["logits", "boxes"],
57 |         dynamic_axes=dynamic_axes,
58 |         opset_version=16)
59 | 
60 | if __name__ == "__main__":
61 |     parser = argparse.ArgumentParser("Export Grounding DINO Model to ONNX", add_help=True)
62 |     parser.add_argument("--config_file", "-c", type=str, required=True, help="path to config file")
63 |     parser.add_argument(
64 |         "--checkpoint_path", "-p", type=str, required=True, help="path to checkpoint file"
65 |     )
66 |     parser.add_argument(
67 |         "--output_dir", "-o", type=str, default="outputs", required=True, help="output directory"
68 |     )
69 | 
70 |     args = parser.parse_args()
71 | 
72 |     # cfg
73 |     config_file = args.config_file  # change the path of the model config file
74 |     checkpoint_path = args.checkpoint_path  # change the path of the model
75 |     output_dir = args.output_dir
76 | 
77 |     # make dir
78 |     os.makedirs(output_dir, exist_ok=True)
79 | 
80 |     # load model
81 |     model = load_model(config_file, checkpoint_path, cpu_only=True)
82 | 
83 |     # export onnx
84 |     export_onnx(model, output_dir)
85 | 
86 | ### python export_onnx.py -c groundingdino/config/GroundingDINO_SwinT_OGC.py -p weights/groundingdino_swint_ogc.pth -o weights/
87 | ### python export_onnx.py -c groundingdino/config/GroundingDINO_SwinB_cfg.py -p weights/groundingdino_swinb_cogcoor.pth -o weights/
--------------------------------------------------------------------------------
/python/clip_tokenizer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | """Tokenization classes.""" 18 | import collections 19 | import re 20 | import unicodedata 21 | import six 22 | from functools import lru_cache 23 | import os 24 | import numpy as np 25 | @lru_cache() 26 | def default_vocab(): 27 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "vocab.txt") 28 | 29 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): 30 | """Checks whether the casing config is consistent with the checkpoint name.""" 31 | 32 | # The casing has to be passed in by the user and there is no explicit check 33 | # as to whether it matches the checkpoint. The casing information probably 34 | # should have been stored in the bert_config.json file, but it's not, so 35 | # we have to heuristically detect it to validate. 36 | 37 | if not init_checkpoint: 38 | return 39 | 40 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint) 41 | if m is None: 42 | return 43 | 44 | model_name = m.group(1) 45 | 46 | lower_models = [ 47 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12", 48 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12" 49 | ] 50 | 51 | cased_models = [ 52 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", 53 | "multi_cased_L-12_H-768_A-12" 54 | ] 55 | 56 | is_bad_config = False 57 | if model_name in lower_models and not do_lower_case: 58 | is_bad_config = True 59 | actual_flag = "False" 60 | case_name = "lowercased" 61 | opposite_flag = "True" 62 | 63 | if model_name in cased_models and do_lower_case: 64 | is_bad_config = True 65 | actual_flag = "True" 66 | case_name = "cased" 67 | opposite_flag = "False" 68 | 69 | if is_bad_config: 70 | raise ValueError( 71 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. " 72 | "However, `%s` seems to be a %s model, so you " 73 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches " 74 | "how the model was pre-training. If this error is wrong, please " 75 | "just comment out this check." % (actual_flag, init_checkpoint, 76 | model_name, case_name, opposite_flag)) 77 | 78 | 79 | def convert_to_unicode(text): 80 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 81 | if six.PY3: 82 | if isinstance(text, str): 83 | return text 84 | elif isinstance(text, bytes): 85 | return text.decode("utf-8", "ignore") 86 | else: 87 | raise ValueError("Unsupported string type: %s" % (type(text))) 88 | elif six.PY2: 89 | if isinstance(text, str): 90 | return text.decode("utf-8", "ignore") 91 | elif isinstance(text, unicode): 92 | return text 93 | else: 94 | raise ValueError("Unsupported string type: %s" % (type(text))) 95 | else: 96 | raise ValueError("Not running on Python2 or Python 3?") 97 | 98 | 99 | def printable_text(text): 100 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 101 | 102 | # These functions want `str` for both Python2 and Python3, but in one case 103 | # it's a Unicode string and in the other it's a byte string. 
104 | if six.PY3: 105 | if isinstance(text, str): 106 | return text 107 | elif isinstance(text, bytes): 108 | return text.decode("utf-8", "ignore") 109 | else: 110 | raise ValueError("Unsupported string type: %s" % (type(text))) 111 | elif six.PY2: 112 | if isinstance(text, str): 113 | return text 114 | elif isinstance(text, unicode): 115 | return text.encode("utf-8") 116 | else: 117 | raise ValueError("Unsupported string type: %s" % (type(text))) 118 | else: 119 | raise ValueError("Not running on Python2 or Python 3?") 120 | 121 | 122 | def load_vocab(vocab_file): 123 | """Loads a vocabulary file into a dictionary.""" 124 | vocab = collections.OrderedDict() 125 | index = 0 126 | with open(vocab_file, "r", encoding="utf-8") as reader: 127 | while True: 128 | token = convert_to_unicode(reader.readline()) 129 | if not token: 130 | break 131 | token = token.strip() 132 | vocab[token] = index 133 | index += 1 134 | return vocab 135 | 136 | 137 | def convert_by_vocab(vocab, items): 138 | """Converts a sequence of [tokens|ids] using the vocab.""" 139 | output = [] 140 | for item in items: 141 | output.append(vocab[item]) 142 | return output 143 | 144 | 145 | def convert_tokens_to_ids(vocab, tokens): 146 | return convert_by_vocab(vocab, tokens) 147 | 148 | 149 | def convert_ids_to_tokens(inv_vocab, ids): 150 | return convert_by_vocab(inv_vocab, ids) 151 | 152 | 153 | def whitespace_tokenize(text): 154 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 155 | text = text.strip() 156 | if not text: 157 | return [] 158 | tokens = text.split() 159 | return tokens 160 | 161 | 162 | class FullTokenizer(object): 163 | """Runs end-to-end tokenziation.""" 164 | 165 | def __init__(self, vocab_file=default_vocab(), do_lower_case=True): 166 | self.vocab = load_vocab(vocab_file) 167 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 168 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 169 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 170 | 171 | def tokenize(self, text): 172 | split_tokens = [] 173 | for token in self.basic_tokenizer.tokenize(text): 174 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 175 | split_tokens.append(sub_token) 176 | 177 | return split_tokens 178 | 179 | def convert_tokens_to_ids(self, tokens): 180 | return convert_by_vocab(self.vocab, tokens) 181 | 182 | def convert_ids_to_tokens(self, ids): 183 | return convert_by_vocab(self.inv_vocab, ids) 184 | 185 | @staticmethod 186 | def convert_tokens_to_string(tokens, clean_up_tokenization_spaces=True): 187 | """ Converts a sequence of tokens (string) in a single string. """ 188 | 189 | def clean_up_tokenization(out_string): 190 | """ Clean up a list of simple English tokenization artifacts 191 | like spaces before punctuations and abreviated forms. 
192 | """ 193 | out_string = ( 194 | out_string.replace(" .", ".") 195 | .replace(" ?", "?") 196 | .replace(" !", "!") 197 | .replace(" ,", ",") 198 | .replace(" ' ", "'") 199 | .replace(" n't", "n't") 200 | .replace(" 'm", "'m") 201 | .replace(" 's", "'s") 202 | .replace(" 've", "'ve") 203 | .replace(" 're", "'re") 204 | ) 205 | return out_string 206 | 207 | text = ' '.join(tokens).replace(' ##', '').strip() 208 | if clean_up_tokenization_spaces: 209 | clean_text = clean_up_tokenization(text) 210 | return clean_text 211 | else: 212 | return text 213 | 214 | def vocab_size(self): 215 | return len(self.vocab) 216 | 217 | 218 | class BasicTokenizer(object): 219 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 220 | 221 | def __init__(self, do_lower_case=True): 222 | """Constructs a BasicTokenizer. 223 | 224 | Args: 225 | do_lower_case: Whether to lower case the input. 226 | """ 227 | self.do_lower_case = do_lower_case 228 | 229 | def tokenize(self, text): 230 | """Tokenizes a piece of text.""" 231 | text = convert_to_unicode(text) 232 | text = self._clean_text(text) 233 | 234 | # This was added on November 1st, 2018 for the multilingual and Chinese 235 | # models. This is also applied to the English models now, but it doesn't 236 | # matter since the English models were not trained on any Chinese data 237 | # and generally don't have any Chinese data in them (there are Chinese 238 | # characters in the vocabulary because Wikipedia does have some Chinese 239 | # words in the English Wikipedia.). 240 | text = self._tokenize_chinese_chars(text) 241 | 242 | orig_tokens = whitespace_tokenize(text) 243 | split_tokens = [] 244 | for token in orig_tokens: 245 | if self.do_lower_case: 246 | token = token.lower() 247 | token = self._run_strip_accents(token) 248 | split_tokens.extend(self._run_split_on_punc(token)) 249 | 250 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 251 | return output_tokens 252 | 253 | def _run_strip_accents(self, text): 254 | """Strips accents from a piece of text.""" 255 | text = unicodedata.normalize("NFD", text) 256 | output = [] 257 | for char in text: 258 | cat = unicodedata.category(char) 259 | if cat == "Mn": 260 | continue 261 | output.append(char) 262 | return "".join(output) 263 | 264 | def _run_split_on_punc(self, text): 265 | """Splits punctuation on a piece of text.""" 266 | chars = list(text) 267 | i = 0 268 | start_new_word = True 269 | output = [] 270 | while i < len(chars): 271 | char = chars[i] 272 | if _is_punctuation(char): 273 | output.append([char]) 274 | start_new_word = True 275 | else: 276 | if start_new_word: 277 | output.append([]) 278 | start_new_word = False 279 | output[-1].append(char) 280 | i += 1 281 | 282 | return ["".join(x) for x in output] 283 | 284 | def _tokenize_chinese_chars(self, text): 285 | """Adds whitespace around any CJK character.""" 286 | output = [] 287 | for char in text: 288 | cp = ord(char) 289 | if self._is_chinese_char(cp): 290 | output.append(" ") 291 | output.append(char) 292 | output.append(" ") 293 | else: 294 | output.append(char) 295 | return "".join(output) 296 | 297 | def _is_chinese_char(self, cp): 298 | """Checks whether CP is the codepoint of a CJK character.""" 299 | # This defines a "chinese character" as anything in the CJK Unicode block: 300 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 301 | # 302 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 303 | # despite its name. 
The modern Korean Hangul alphabet is a different block, 304 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 305 | # space-separated words, so they are not treated specially and handled 306 | # like the all of the other languages. 307 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 308 | (cp >= 0x3400 and cp <= 0x4DBF) or # 309 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 310 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 311 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 312 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 313 | (cp >= 0xF900 and cp <= 0xFAFF) or # 314 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 315 | return True 316 | 317 | return False 318 | 319 | def _clean_text(self, text): 320 | """Performs invalid character removal and whitespace cleanup on text.""" 321 | output = [] 322 | for char in text: 323 | cp = ord(char) 324 | if cp == 0 or cp == 0xfffd or _is_control(char): 325 | continue 326 | if _is_whitespace(char): 327 | output.append(" ") 328 | else: 329 | output.append(char) 330 | return "".join(output) 331 | 332 | 333 | class WordpieceTokenizer(object): 334 | """Runs WordPiece tokenziation.""" 335 | 336 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): 337 | self.vocab = vocab 338 | self.unk_token = unk_token 339 | self.max_input_chars_per_word = max_input_chars_per_word 340 | 341 | def tokenize(self, text): 342 | """Tokenizes a piece of text into its word pieces. 343 | 344 | This uses a greedy longest-match-first algorithm to perform tokenization 345 | using the given vocabulary. 346 | 347 | For example: 348 | input = "unaffable" 349 | output = ["un", "##aff", "##able"] 350 | 351 | Args: 352 | text: A single token or whitespace separated tokens. This should have 353 | already been passed through `BasicTokenizer. 354 | 355 | Returns: 356 | A list of wordpiece tokens. 357 | """ 358 | 359 | text = convert_to_unicode(text) 360 | 361 | output_tokens = [] 362 | for token in whitespace_tokenize(text): 363 | chars = list(token) 364 | if len(chars) > self.max_input_chars_per_word: 365 | output_tokens.append(self.unk_token) 366 | continue 367 | 368 | is_bad = False 369 | start = 0 370 | sub_tokens = [] 371 | while start < len(chars): 372 | end = len(chars) 373 | cur_substr = None 374 | while start < end: 375 | substr = "".join(chars[start:end]) 376 | if start > 0: 377 | substr = "##" + substr 378 | if substr in self.vocab: 379 | cur_substr = substr 380 | break 381 | end -= 1 382 | if cur_substr is None: 383 | is_bad = True 384 | break 385 | sub_tokens.append(cur_substr) 386 | start = end 387 | 388 | if is_bad: 389 | output_tokens.append(self.unk_token) 390 | else: 391 | output_tokens.extend(sub_tokens) 392 | return output_tokens 393 | 394 | 395 | def _is_whitespace(char): 396 | """Checks whether `chars` is a whitespace character.""" 397 | # \t, \n, and \r are technically contorl characters but we treat them 398 | # as whitespace since they are generally considered as such. 399 | if char == " " or char == "\t" or char == "\n" or char == "\r": 400 | return True 401 | cat = unicodedata.category(char) 402 | if cat == "Zs": 403 | return True 404 | return False 405 | 406 | 407 | def _is_control(char): 408 | """Checks whether `chars` is a control character.""" 409 | # These are technically control characters but we count them as whitespace 410 | # characters. 
411 |     if char == "\t" or char == "\n" or char == "\r":
412 |         return False
413 |     cat = unicodedata.category(char)
414 |     if cat in ("Cc", "Cf"):
415 |         return True
416 |     return False
417 | 
418 | 
419 | def _is_punctuation(char):
420 |     """Checks whether `chars` is a punctuation character."""
421 |     cp = ord(char)
422 |     # We treat all non-letter/number ASCII as punctuation.
423 |     # Characters such as "^", "$", and "`" are not in the Unicode
424 |     # Punctuation class but we treat them as punctuation anyways, for
425 |     # consistency.
426 |     if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
427 |             (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
428 |         return True
429 |     cat = unicodedata.category(char)
430 |     if cat.startswith("P"):
431 |         return True
432 |     return False
433 | 
434 | ### Use the tokenizer to convert the text into input_ids, token_type_ids and attention_mask. token_type_ids tells the model which sentence each token belongs to: for a single sentence it is all 0, while in batched conversion different sentences get different values.
435 | ### https://juejin.cn/post/7266344040760590376
436 | ### https://zhuanlan.zhihu.com/p/341994096
437 | ### padding controls how sequences are padded. Its value can be a bool or a string: (1) True or "longest" pads to the longest sequence in the batch (a single sequence is not padded); "max_length" pads to the length given by max_length, or, if max_length is not set, to the maximum length the model accepts (so even a single sequence is padded to that length). (2) False or "do_not_pad" disables padding, which is the default behaviour.
438 | def tokenize(_tokenizer, texts, specical_texts, context_length: int = 52):
439 |     if isinstance(texts, str):
440 |         texts = [texts]
441 | 
442 |     all_tokens = []
443 |     for text in texts:
444 |         all_tokens.append([_tokenizer.vocab['[CLS]']] + _tokenizer.convert_tokens_to_ids(_tokenizer.tokenize(text))[
445 |                           :context_length - 2] + [_tokenizer.vocab['[SEP]']])
446 | 
447 |     input_ids = np.array(all_tokens, dtype=np.int64)
448 |     token_type_ids = np.zeros_like(input_ids)  ### all zeros for a single sentence
449 |     attention_mask = (input_ids > 0).astype(np.bool_)
450 | 
451 |     specical_tokens = []
452 |     for text in specical_texts:
453 |         specical_tokens.append(_tokenizer.vocab[text])
454 | 
455 |     return input_ids, token_type_ids, attention_mask, specical_tokens
456 | 
457 | def generate_masks_with_special_tokens_and_transfer_map(input_ids, special_tokens_list):
458 |     bs, num_token = input_ids.shape
459 |     # special_tokens_mask: bs, num_token. 1 for special tokens, 0 for normal tokens
460 |     special_tokens_mask = np.zeros((bs, num_token)).astype(np.bool_)  ### newer numpy versions require np.bool_
461 |     for special_token in special_tokens_list:
462 |         special_tokens_mask |= input_ids == special_token
463 | 
464 |     # idxs: each row is a list of indices of special tokens
465 |     idxs = np.array(np.nonzero(special_tokens_mask)).T
466 | 
467 |     # generate attention mask and positional ids
468 |     text_self_attention_masks = np.tile(np.expand_dims(np.eye(num_token, dtype=np.bool_), axis=0), (bs, 1, 1))
469 |     position_ids = np.zeros((bs, num_token))
470 |     # cate_to_token_mask_list = [[] for _ in range(bs)]
471 |     previous_col = 0
472 |     for i in range(idxs.shape[0]):
473 |         row, col = idxs[i]
474 |         if (col == 0) or (col == num_token - 1):
475 |             text_self_attention_masks[row, col, col] = True
476 |             position_ids[row, col] = 0
477 |         else:
478 |             text_self_attention_masks[row, previous_col + 1 : col + 1, previous_col + 1 : col + 1] = True
479 |             position_ids[row, previous_col + 1 : col + 1] = np.arange(0, col - previous_col)
480 |             # c2t_maski = np.zeros((num_token)).astype(np.bool_)
481 |             # c2t_maski[previous_col + 1 : col] = True
482 |             # cate_to_token_mask_list[row].append(c2t_maski)
483 |         previous_col = col
484 |     return text_self_attention_masks, position_ids.astype(np.int64)
--------------------------------------------------------------------------------
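Note: python/main.py below wires the two helpers above together. The following is only a rough usage sketch (the caption is just an illustration; any prompt in the "cat . dog ." format behaves the same way):

    from clip_tokenizer import tokenize, FullTokenizer, generate_masks_with_special_tokens_and_transfer_map

    tokenizer = FullTokenizer(vocab_file="vocab.txt")
    caption = "chair . person . dog ."
    input_ids, token_type_ids, attention_mask, special_ids = tokenize(
        tokenizer, caption, ["[CLS]", "[SEP]", ".", "?"], context_length=256)
    # input_ids has shape (1, num_token); token_type_ids is all zeros for a single sentence.
    text_token_mask, position_ids = generate_masks_with_special_tokens_and_transfer_map(input_ids, special_ids)
    # text_token_mask has shape (1, num_token, num_token): each token only attends within its own
    # category phrase, and position_ids restart from 0 after every special token.
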
/python/images/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpc203/GroundingDINO-onnxrun/1197ebcc652013871efd634488850a7e42b82b85/python/images/0.jpg
--------------------------------------------------------------------------------
/python/images/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpc203/GroundingDINO-onnxrun/1197ebcc652013871efd634488850a7e42b82b85/python/images/1.jpg
--------------------------------------------------------------------------------
/python/images/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpc203/GroundingDINO-onnxrun/1197ebcc652013871efd634488850a7e42b82b85/python/images/2.jpg
--------------------------------------------------------------------------------
/python/images/3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpc203/GroundingDINO-onnxrun/1197ebcc652013871efd634488850a7e42b82b85/python/images/3.jpg
--------------------------------------------------------------------------------
/python/images/4.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpc203/GroundingDINO-onnxrun/1197ebcc652013871efd634488850a7e42b82b85/python/images/4.jpeg
--------------------------------------------------------------------------------
/python/images/5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpc203/GroundingDINO-onnxrun/1197ebcc652013871efd634488850a7e42b82b85/python/images/5.jpg
--------------------------------------------------------------------------------
/python/images/6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpc203/GroundingDINO-onnxrun/1197ebcc652013871efd634488850a7e42b82b85/python/images/6.jpg
--------------------------------------------------------------------------------
/python/images/7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpc203/GroundingDINO-onnxrun/1197ebcc652013871efd634488850a7e42b82b85/python/images/7.jpg
--------------------------------------------------------------------------------
/python/images/cat_dog.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpc203/GroundingDINO-onnxrun/1197ebcc652013871efd634488850a7e42b82b85/python/images/cat_dog.jpeg
--------------------------------------------------------------------------------
/python/images/demo7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpc203/GroundingDINO-onnxrun/1197ebcc652013871efd634488850a7e42b82b85/python/images/demo7.jpg
--------------------------------------------------------------------------------
/python/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import cv2
3 | import numpy as np
4 | import onnxruntime
5 | from clip_tokenizer import tokenize, FullTokenizer, generate_masks_with_special_tokens_and_transfer_map
6 | 
7 | def resize_image(srcimg, size, max_size=None):
8 |     def get_size_with_aspect_ratio(image_size, size, max_size=None):
9 |         w, h = image_size
10 |         if max_size is not None:
11 |             min_original_size = float(min((w, h)))
12 |             max_original_size = float(max((w, h)))
13 |             if max_original_size / min_original_size * size > max_size:
14 |                 size = int(round(max_size * min_original_size / max_original_size))
15 | 
16 |         if (w <= h and w == size) or (h <= w and h == size):
17 |             return (h, w)
18 | 
19 |         if w < h:
20 |             ow = size
21 |             oh = int(size * h / w)
22 |         else:
23 |             oh = size
24 |             ow = int(size * w / h)
25 | 
26 |         return (oh, ow)
27 | 
28 |     def get_size(image_size, size, max_size=None):  ### returns (height, width)
29 |         if isinstance(size, (list, tuple)):
30 |             return size[::-1]
31 |         else:
32 |             return get_size_with_aspect_ratio(image_size, size, max_size)
33 |     img = cv2.cvtColor(srcimg, cv2.COLOR_BGR2RGB)
34 |     size = get_size((img.shape[1], img.shape[0]), size, max_size)
35 |     rescaled_image = cv2.resize(img, (size[1], size[0]))
36 |     return rescaled_image
37 | 
38 | def get_phrases_from_posmap(posmap, input_ids, tokenizer, left_idx=0, right_idx=255):
39 |     if posmap.ndim == 1:
40 |         posmap[0: left_idx + 1] = False
41 |         posmap[right_idx:] = False
42 |         non_zero_idx = np.nonzero(posmap)[0].tolist()
43 |         if len(non_zero_idx) > 0:
44 |             token_ids = [input_ids[i] for i in non_zero_idx]
45 |             return tokenizer.convert_ids_to_tokens(token_ids)[0]
46 |             ## return ' '.join(tokenizer.convert_ids_to_tokens(token_ids))  ### the original returns a whole phrase; if you only want to detect objects, a single word is enough
47 |         else:
48 |             return None
49 |     else:
50 |         raise NotImplementedError("posmap must be 1-dim")
51 | 
52 | class GroundingDINO():
53 |     def __init__(self, modelpath, box_threshold, vocab_path, text_threshold=None, with_logits=True):
54 |         so = onnxruntime.SessionOptions()
55 |         so.log_severity_level = 3
56 |         self.net = onnxruntime.InferenceSession(modelpath, so)  ### opencv-dnn fails to load this model
57 |         # for inp in self.net.get_inputs():
58 |         #     print(inp)
59 |         # for oup in self.net.get_outputs():
60 |         #     print(oup)
61 | 
62 |         self.input_names = ["img", "input_ids", "attention_mask", "position_ids", "token_type_ids", "text_token_mask"]
63 |         self.output_names = ["logits", "boxes"]
64 |         self.box_threshold = box_threshold
65 |         self.text_threshold = text_threshold
66 |         self.with_logits = with_logits
67 | 
68 |         self.size = [1200, 800]  ### (width, height)
69 |         self.max_size = None
70 |         # self.size = 800
71 |         # self.max_size = 1333
72 |         self.mean_ = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape((1, 1, 3))
73 |         self.std_ = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape((1, 1, 3))
74 |         self.max_text_len = 256
75 |         self.specical_texts = ["[CLS]", "[SEP]", ".", "?"]
76 |         self.tokenizer = FullTokenizer(vocab_file=vocab_path)
77 | 
78 |     def detect(self, srcimg, text_prompt):
79 |         rescaled_image = resize_image(srcimg, self.size, max_size=self.max_size)
80 |         img = (rescaled_image.astype(np.float32) / 255.0 - self.mean_) / self.std_
81 |         inputs = {"img": np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0).astype(np.float32)}
82 | 
83 |         caption = text_prompt.lower()
84 |         caption = caption.strip()
85 |         if not caption.endswith("."):
86 |             caption = caption + " ."
87 | 
88 |         input_ids, token_type_ids, attention_mask, specical_tokens = tokenize(self.tokenizer, caption, self.specical_texts, context_length=self.max_text_len)
89 |         text_self_attention_masks, position_ids = generate_masks_with_special_tokens_and_transfer_map(input_ids, specical_tokens)
90 |         if text_self_attention_masks.shape[1] > self.max_text_len:
91 |             text_self_attention_masks = text_self_attention_masks[:, : self.max_text_len, : self.max_text_len]
92 | 
93 |             position_ids = position_ids[:, : self.max_text_len]
94 |             input_ids = input_ids[:, : self.max_text_len]
95 |             attention_mask = attention_mask[:, : self.max_text_len]
96 |             token_type_ids = token_type_ids[:, : self.max_text_len]
97 | 
98 |         inputs["input_ids"] = input_ids
99 |         inputs["attention_mask"] = attention_mask
100 |         inputs["token_type_ids"] = token_type_ids
101 |         inputs["position_ids"] = position_ids
102 |         inputs["text_token_mask"] = text_self_attention_masks
103 | 
104 |         outputs = self.net.run(self.output_names, inputs)
105 | 
106 |         prediction_logits_ = np.squeeze(outputs[0], axis=0)  # [0]  # prediction_logits.shape = (nq, 256)
107 |         prediction_logits_ = 1 / (1 + np.exp(-prediction_logits_))
108 | 
109 |         prediction_boxes_ = np.squeeze(outputs[1], axis=0)  # [0]  # prediction_boxes.shape = (nq, 4)
110 | 
111 |         filt_mask = np.max(prediction_logits_, axis=1) > self.box_threshold
112 |         logits_filt = prediction_logits_[filt_mask]  # num_filt, 256
113 |         boxes_filt = prediction_boxes_[filt_mask]  # num_filt, 4
114 | 
115 |         pred_phrases = []
116 |         for logit, box in zip(logits_filt, boxes_filt):
117 |             pred_phrase = get_phrases_from_posmap(logit > self.text_threshold, input_ids[0, :], self.tokenizer)
118 |             if pred_phrase is None:
119 |                 continue
120 |             if self.with_logits:
121 |                 pred_phrases.append(pred_phrase + f"({str(logit.max())[:4]})")
122 |             else:
123 |                 pred_phrases.append(pred_phrase)
124 | 
125 |         return boxes_filt, pred_phrases
126 | 
127 | def draw_boxes_to_image(image, boxes, labels):
128 |     h, w = image.shape[:2]
129 |     for box, label in zip(boxes, labels):
130 |         # from 0..1 to 0..W, 0..H
131 |         box = box * np.array([w, h, w, h])
132 |         # from xywh to xyxy
133 |         box[:2] -= box[2:] * 0.5
134 |         box[2:] += box[:2]
135 | 
136 |         xmin, ymin, xmax, ymax = int(box[0]), int(box[1]), int(box[2]), int(box[3])
137 |         cv2.rectangle(image, (xmin, ymin), (xmax, ymax), (0, 0, 255), thickness=2)
138 |         # txt_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.4, 1)[0]
139 |         # cv2.rectangle(image, (xmin, ymin + 1), (xmin + txt_size[0] + 1, ymin + int(1.5 * txt_size[1])), (255, 255, 255), -1)
140 |         cv2.putText(image, label, (xmin, ymin - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), thickness=2)
141 |     return image
142 | 
143 | 
144 | if __name__ == '__main__':
145 |     parser = argparse.ArgumentParser()
146 |     parser.add_argument("--model_path", type=str, default="weights/groundingdino_swint_ogc.onnx", help="onnx model path")
147 |     parser.add_argument("--image_path", type=str, default="images/cat_dog.jpeg", help="path to image file")
148 |     parser.add_argument("--text_prompt", type=str, default="chair . person . dog .", help="text prompt, category names separated by ' . '")  ### prompt for cat_dog.jpeg: "chair . person . dog ."; prompt for demo7.jpg: "Horse . Clouds . Grasses . Sky . Hill ."
149 |     parser.add_argument("--box_threshold", type=float, default=0.3, help="box threshold")
150 |     parser.add_argument("--text_threshold", type=float, default=0.25, help="text threshold")
151 |     args = parser.parse_args()
152 | 
153 |     mynet = GroundingDINO(args.model_path, args.box_threshold, "vocab.txt", text_threshold=args.text_threshold)
154 |     srcimg = cv2.imread(args.image_path)
155 | 
156 |     boxes_filt, pred_phrases = mynet.detect(srcimg, args.text_prompt)
157 |     drawimg = draw_boxes_to_image(srcimg, boxes_filt, pred_phrases)
158 | 
159 |     # cv2.imwrite('result.jpg', drawimg)
160 |     winName = 'GroundingDINO use OnnxRuntime'
161 |     cv2.namedWindow(winName, 0)
162 |     cv2.imshow(winName, drawimg)
163 |     cv2.waitKey(0)
164 |     cv2.destroyAllWindows()
165 | 
--------------------------------------------------------------------------------
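Note: a typical run of the Python demo, assuming the exported model from the Baidu Netdisk link in the README has been saved as weights/groundingdino_swint_ogc.onnx (the script's default path) and the command is started from the python/ directory so that images/ and vocab.txt are found:

    python main.py --model_path weights/groundingdino_swint_ogc.onnx --image_path images/cat_dog.jpeg --text_prompt "chair . person . dog ."

Category names in the prompt are separated by " . " and must exist in vocab.txt, as described in the README.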