├── model └── ostrack │ ├── layers │ ├── __init__.py │ ├── __pycache__ │ │ ├── attn.cpython-311.pyc │ │ ├── head.cpython-311.pyc │ │ ├── rpe.cpython-311.pyc │ │ ├── __init__.cpython-311.pyc │ │ ├── frozen_bn.cpython-311.pyc │ │ ├── attn_blocks.cpython-311.pyc │ │ └── patch_embed.cpython-311.pyc │ ├── patch_embed.py │ ├── frozen_bn.py │ ├── rpe.py │ ├── attn.py │ ├── attn_blocks.py │ └── head.py │ ├── utils │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-311.pyc │ │ ├── box_ops.cpython-311.pyc │ │ └── tensor.cpython-311.pyc │ ├── merge.py │ ├── variable_hook.py │ ├── lmdb_utils.py │ ├── focal_loss.py │ ├── box_ops.py │ ├── ce_utils.py │ ├── heapmap_utils.py │ ├── tensor.py │ └── misc.py │ ├── __pycache__ │ ├── util.cpython-311.pyc │ ├── vit.cpython-311.pyc │ ├── vit_ce.cpython-311.pyc │ ├── ostrack.cpython-311.pyc │ └── base_backbone.cpython-311.pyc │ ├── util.py │ ├── ostrack.py │ ├── base_backbone.py │ ├── vit_ce.py │ └── vit.py ├── __init__.py ├── assets ├── uav_1.jpg ├── uav_2.jpg ├── uav_3.jpg ├── OSTrack.jpg ├── bandicam.gif ├── infrared_5.gif ├── cover_image.jpg ├── uav_1_result.jpg ├── results │ ├── P_curve.jpg │ ├── R_curve.jpg │ ├── labels.jpg │ ├── paths.jpg │ ├── results.jpg │ ├── F1_curve.jpg │ ├── PR_curve.jpg │ ├── train_batch0.jpg │ ├── train_batch1.jpg │ ├── train_batch2.jpg │ ├── confusion_matrix.jpg │ ├── val_batch0_labels.jpg │ ├── val_batch0_pred.jpg │ ├── val_batch1_labels.jpg │ ├── val_batch1_pred.jpg │ ├── val_batch2_labels.jpg │ ├── val_batch2_pred.jpg │ └── labels_correlogram.jpg ├── ostrack_test0320.jpg ├── exception │ ├── exception1.gif │ └── exception2.gif ├── processed_infrared_5.gif └── architecture │ ├── ostrack_1.jpg │ ├── ostrack_2.jpg │ ├── ostrack_3.jpg │ ├── ostrack_4.jpg │ ├── ostrack_5.jpg │ └── ostrack_6.jpg ├── frames ├── shot_0.jpg ├── shot_1.jpg ├── shot_2.jpg ├── shot_3.jpg ├── shot_4.jpg └── shot_5.jpg ├── images └── main_window_img.png ├── .gitmodules ├── __pycache__ ├── bbox.cpython-311.pyc ├── util.cpython-311.pyc ├── utils.cpython-311.pyc ├── yolo.cpython-311.pyc ├── yolov5.cpython-311.pyc └── yolo_model.cpython-311.pyc ├── config ├── got_10k.yaml ├── vitb_384_mae_32x4_ep300.yaml ├── vitb_256_mae_32x4_ep300.yaml ├── vitb_256_mae_ce_32x4_got10k_ep100.yaml ├── vitb_384_mae_ce_32x4_ep300.yaml ├── vitb_384_mae_ce_32x4_got10k_ep100.yaml └── vitb_256_mae_ce_32x4_ep300.yaml ├── .gitignore ├── LICENSE ├── bbox.py ├── yolo_model.py ├── util.py ├── README.md └── app_osyo.py /model/ostrack/layers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tkinter as tk 3 | import numpy as np 4 | -------------------------------------------------------------------------------- /model/ostrack/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .tensor import TensorDict, TensorList 2 | -------------------------------------------------------------------------------- /assets/uav_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/uav_1.jpg -------------------------------------------------------------------------------- /assets/uav_2.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/uav_2.jpg -------------------------------------------------------------------------------- /assets/uav_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/uav_3.jpg -------------------------------------------------------------------------------- /assets/OSTrack.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/OSTrack.jpg -------------------------------------------------------------------------------- /frames/shot_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/frames/shot_0.jpg -------------------------------------------------------------------------------- /frames/shot_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/frames/shot_1.jpg -------------------------------------------------------------------------------- /frames/shot_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/frames/shot_2.jpg -------------------------------------------------------------------------------- /frames/shot_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/frames/shot_3.jpg -------------------------------------------------------------------------------- /frames/shot_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/frames/shot_4.jpg -------------------------------------------------------------------------------- /frames/shot_5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/frames/shot_5.jpg -------------------------------------------------------------------------------- /assets/bandicam.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/bandicam.gif -------------------------------------------------------------------------------- /assets/infrared_5.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/infrared_5.gif -------------------------------------------------------------------------------- /assets/cover_image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/cover_image.jpg -------------------------------------------------------------------------------- /assets/uav_1_result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/uav_1_result.jpg -------------------------------------------------------------------------------- /assets/results/P_curve.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/results/P_curve.jpg -------------------------------------------------------------------------------- /assets/results/R_curve.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/results/R_curve.jpg -------------------------------------------------------------------------------- /assets/results/labels.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/results/labels.jpg -------------------------------------------------------------------------------- /assets/results/paths.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/results/paths.jpg -------------------------------------------------------------------------------- /assets/results/results.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/results/results.jpg -------------------------------------------------------------------------------- /images/main_window_img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/images/main_window_img.png -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "yolov5"] 2 | path = yolov5 3 | url = https://github.com/ultralytics/yolov5.git 4 | -------------------------------------------------------------------------------- /assets/ostrack_test0320.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/ostrack_test0320.jpg -------------------------------------------------------------------------------- /assets/results/F1_curve.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/results/F1_curve.jpg -------------------------------------------------------------------------------- /assets/results/PR_curve.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/results/PR_curve.jpg -------------------------------------------------------------------------------- /assets/exception/exception1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/exception/exception1.gif -------------------------------------------------------------------------------- /assets/exception/exception2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/exception/exception2.gif -------------------------------------------------------------------------------- /assets/processed_infrared_5.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/processed_infrared_5.gif -------------------------------------------------------------------------------- /assets/results/train_batch0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/results/train_batch0.jpg 
-------------------------------------------------------------------------------- /assets/results/train_batch1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/results/train_batch1.jpg -------------------------------------------------------------------------------- /assets/results/train_batch2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/results/train_batch2.jpg -------------------------------------------------------------------------------- /__pycache__/bbox.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/__pycache__/bbox.cpython-311.pyc -------------------------------------------------------------------------------- /__pycache__/util.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/__pycache__/util.cpython-311.pyc -------------------------------------------------------------------------------- /__pycache__/utils.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/__pycache__/utils.cpython-311.pyc -------------------------------------------------------------------------------- /__pycache__/yolo.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/__pycache__/yolo.cpython-311.pyc -------------------------------------------------------------------------------- /assets/architecture/ostrack_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/architecture/ostrack_1.jpg -------------------------------------------------------------------------------- /assets/architecture/ostrack_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/architecture/ostrack_2.jpg -------------------------------------------------------------------------------- /assets/architecture/ostrack_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/architecture/ostrack_3.jpg -------------------------------------------------------------------------------- /assets/architecture/ostrack_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/architecture/ostrack_4.jpg -------------------------------------------------------------------------------- /assets/architecture/ostrack_5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/architecture/ostrack_5.jpg -------------------------------------------------------------------------------- /assets/architecture/ostrack_6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/architecture/ostrack_6.jpg -------------------------------------------------------------------------------- /config/got_10k.yaml: 
-------------------------------------------------------------------------------- 1 | train: ./datasets/images/train 2 | val: ./datasets/images/val 3 | 4 | nc: 1 5 | names: ['drone'] -------------------------------------------------------------------------------- /__pycache__/yolov5.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/__pycache__/yolov5.cpython-311.pyc -------------------------------------------------------------------------------- /assets/results/confusion_matrix.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/results/confusion_matrix.jpg -------------------------------------------------------------------------------- /assets/results/val_batch0_labels.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/results/val_batch0_labels.jpg -------------------------------------------------------------------------------- /assets/results/val_batch0_pred.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/results/val_batch0_pred.jpg -------------------------------------------------------------------------------- /assets/results/val_batch1_labels.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/results/val_batch1_labels.jpg -------------------------------------------------------------------------------- /assets/results/val_batch1_pred.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/results/val_batch1_pred.jpg -------------------------------------------------------------------------------- /assets/results/val_batch2_labels.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/results/val_batch2_labels.jpg -------------------------------------------------------------------------------- /assets/results/val_batch2_pred.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/results/val_batch2_pred.jpg -------------------------------------------------------------------------------- /__pycache__/yolo_model.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/__pycache__/yolo_model.cpython-311.pyc -------------------------------------------------------------------------------- /assets/results/labels_correlogram.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/results/labels_correlogram.jpg -------------------------------------------------------------------------------- /model/ostrack/__pycache__/util.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/model/ostrack/__pycache__/util.cpython-311.pyc -------------------------------------------------------------------------------- /model/ostrack/__pycache__/vit.cpython-311.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/model/ostrack/__pycache__/vit.cpython-311.pyc -------------------------------------------------------------------------------- /model/ostrack/__pycache__/vit_ce.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/model/ostrack/__pycache__/vit_ce.cpython-311.pyc -------------------------------------------------------------------------------- /model/ostrack/__pycache__/ostrack.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/model/ostrack/__pycache__/ostrack.cpython-311.pyc -------------------------------------------------------------------------------- /model/ostrack/layers/__pycache__/attn.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/model/ostrack/layers/__pycache__/attn.cpython-311.pyc -------------------------------------------------------------------------------- /model/ostrack/layers/__pycache__/head.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/model/ostrack/layers/__pycache__/head.cpython-311.pyc -------------------------------------------------------------------------------- /model/ostrack/layers/__pycache__/rpe.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/model/ostrack/layers/__pycache__/rpe.cpython-311.pyc -------------------------------------------------------------------------------- /model/ostrack/__pycache__/base_backbone.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/model/ostrack/__pycache__/base_backbone.cpython-311.pyc -------------------------------------------------------------------------------- /model/ostrack/utils/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/model/ostrack/utils/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /model/ostrack/utils/__pycache__/box_ops.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/model/ostrack/utils/__pycache__/box_ops.cpython-311.pyc -------------------------------------------------------------------------------- /model/ostrack/utils/__pycache__/tensor.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/model/ostrack/utils/__pycache__/tensor.cpython-311.pyc -------------------------------------------------------------------------------- /model/ostrack/layers/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/model/ostrack/layers/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- 
/model/ostrack/layers/__pycache__/frozen_bn.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/model/ostrack/layers/__pycache__/frozen_bn.cpython-311.pyc -------------------------------------------------------------------------------- /model/ostrack/layers/__pycache__/attn_blocks.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/model/ostrack/layers/__pycache__/attn_blocks.cpython-311.pyc -------------------------------------------------------------------------------- /model/ostrack/layers/__pycache__/patch_embed.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/model/ostrack/layers/__pycache__/patch_embed.cpython-311.pyc -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Use git reset to update .gitignore 2 | 3 | # Do not upload the original project files 4 | MCJT/ 5 | 6 | # Ignore weight files 7 | weights/ 8 | 9 | # Do not upload datasets 10 | datasets/ 11 | 12 | # Ignore log files 13 | logs/ 14 | 15 | # Ignore test videos 16 | video/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | Copyright (c) 2025 Yolo Redal 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /model/ostrack/utils/merge.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def merge_template_search(inp_list, return_search=False, return_template=False): 5 | """NOTICE: search region related features must be in the last place""" 6 | seq_dict = {"feat": torch.cat([x["feat"] for x in inp_list], dim=0), 7 | "mask": torch.cat([x["mask"] for x in inp_list], dim=1), 8 | "pos": torch.cat([x["pos"] for x in inp_list], dim=0)} 9 | if return_search: 10 | x = inp_list[-1] 11 | seq_dict.update({"feat_x": x["feat"], "mask_x": x["mask"], "pos_x": x["pos"]}) 12 | if return_template: 13 | z = inp_list[0] 14 | seq_dict.update({"feat_z": z["feat"], "mask_z": z["mask"], "pos_z": z["pos"]}) 15 | return seq_dict 16 | 17 | 18 | def get_qkv(inp_list): 19 | """The 1st element of the inp_list is about the template, 20 | the 2nd (the last) element is about the search region""" 21 | dict_x = inp_list[-1] 22 | dict_c = {"feat": torch.cat([x["feat"] for x in inp_list], dim=0), 23 | "mask": torch.cat([x["mask"] for x in inp_list], dim=1), 24 | "pos": torch.cat([x["pos"] for x in inp_list], dim=0)} # concatenated dict 25 | q = dict_x["feat"] + dict_x["pos"] 26 | k = dict_c["feat"] + dict_c["pos"] 27 | v = dict_c["feat"] 28 | key_padding_mask = dict_c["mask"] 29 | return q, k, v, key_padding_mask 30 | -------------------------------------------------------------------------------- /model/ostrack/layers/patch_embed.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from timm.models.layers import to_2tuple 4 | 5 | 6 | class PatchEmbed(nn.Module): 7 | """ 2D Image to Patch Embedding 8 | """ 9 | 10 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): 11 | super().__init__() 12 | img_size = to_2tuple(img_size) 13 | patch_size = to_2tuple(patch_size) 14 | self.img_size = img_size 15 | self.patch_size = patch_size 16 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 17 | self.num_patches = self.grid_size[0] * self.grid_size[1] 18 | self.flatten = flatten 19 | 20 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 21 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 22 | 23 | def forward(self, x): 24 | # allow different input size 25 | # B, C, H, W = x.shape 26 | # _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).") 27 | # _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).") 28 | x = self.proj(x) 29 | if self.flatten: 30 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 31 | x = self.norm(x) 32 | return x 33 | -------------------------------------------------------------------------------- /config/vitb_384_mae_32x4_ep300.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | MAX_SAMPLE_INTERVAL: 200 3 | MEAN: 4 | - 0.485 5 | - 0.456 6 | - 0.406 7 | SEARCH: 8 | CENTER_JITTER: 4.5 9 | FACTOR: 5.0 10 | SCALE_JITTER: 0.5 11 | SIZE: 384 12 | STD: 13 | - 0.229 14 | - 0.224 15 | - 0.225 16 | TEMPLATE: 17 | CENTER_JITTER: 0 18 | FACTOR: 2.0 19 | SCALE_JITTER: 0 20 | SIZE: 192 21 | # TRAIN: 22 | # DATASETS_NAME: 23 | # - GOT10K_train_full 24 | # DATASETS_RATIO: 25 | # - 1 26 | # SAMPLE_PER_EPOCH: 60000 27 | 28 | TRAIN: 29 | DATASETS_NAME: 30 
| - LASOT 31 | - GOT10K_vottrain 32 | - COCO17 33 | - TRACKINGNET 34 | DATASETS_RATIO: 35 | - 1 36 | - 1 37 | - 1 38 | - 1 39 | SAMPLE_PER_EPOCH: 60000 40 | VAL: 41 | DATASETS_NAME: 42 | - GOT10K_votval 43 | DATASETS_RATIO: 44 | - 1 45 | SAMPLE_PER_EPOCH: 10000 46 | MODEL: 47 | PRETRAIN_FILE: "mae_pretrain_vit_base.pth" 48 | EXTRA_MERGER: False 49 | RETURN_INTER: False 50 | BACKBONE: 51 | TYPE: vit_base_patch16_224 52 | STRIDE: 16 53 | HEAD: 54 | TYPE: CENTER 55 | NUM_CHANNELS: 256 56 | TRAIN: 57 | BACKBONE_MULTIPLIER: 0.1 58 | DROP_PATH_RATE: 0.1 59 | BATCH_SIZE: 32 60 | EPOCH: 300 61 | GIOU_WEIGHT: 2.0 62 | L1_WEIGHT: 5.0 63 | GRAD_CLIP_NORM: 0.1 64 | LR: 0.0004 65 | LR_DROP_EPOCH: 240 66 | NUM_WORKER: 10 67 | OPTIMIZER: ADAMW 68 | PRINT_INTERVAL: 50 69 | SCHEDULER: 70 | TYPE: step 71 | DECAY_RATE: 0.1 72 | VAL_EPOCH_INTERVAL: 20 73 | WEIGHT_DECAY: 0.0001 74 | AMP: False 75 | TEST: 76 | EPOCH: 300 77 | SEARCH_FACTOR: 5.0 78 | SEARCH_SIZE: 384 79 | TEMPLATE_FACTOR: 2.0 80 | TEMPLATE_SIZE: 192 -------------------------------------------------------------------------------- /config/vitb_256_mae_32x4_ep300.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | MAX_SAMPLE_INTERVAL: 200 3 | MEAN: 4 | - 0.485 5 | - 0.456 6 | - 0.406 7 | SEARCH: 8 | CENTER_JITTER: 3 9 | FACTOR: 4.0 10 | SCALE_JITTER: 0.25 11 | SIZE: 256 12 | NUMBER: 1 13 | STD: 14 | - 0.229 15 | - 0.224 16 | - 0.225 17 | TEMPLATE: 18 | CENTER_JITTER: 0 19 | FACTOR: 2.0 20 | SCALE_JITTER: 0 21 | SIZE: 128 22 | # TRAIN: 23 | # DATASETS_NAME: 24 | # - GOT10K_train_full 25 | # DATASETS_RATIO: 26 | # - 1 27 | # SAMPLE_PER_EPOCH: 60000 28 | 29 | TRAIN: 30 | DATASETS_NAME: 31 | - LASOT 32 | - GOT10K_vottrain 33 | - COCO17 34 | - TRACKINGNET 35 | DATASETS_RATIO: 36 | - 1 37 | - 1 38 | - 1 39 | - 1 40 | SAMPLE_PER_EPOCH: 60000 41 | VAL: 42 | DATASETS_NAME: 43 | - GOT10K_votval 44 | DATASETS_RATIO: 45 | - 1 46 | SAMPLE_PER_EPOCH: 10000 47 | MODEL: 48 | PRETRAIN_FILE: "mae_pretrain_vit_base.pth" 49 | EXTRA_MERGER: False 50 | RETURN_INTER: False 51 | BACKBONE: 52 | TYPE: vit_base_patch16_224 53 | STRIDE: 16 54 | HEAD: 55 | TYPE: CENTER 56 | NUM_CHANNELS: 256 57 | TRAIN: 58 | BACKBONE_MULTIPLIER: 0.1 59 | DROP_PATH_RATE: 0.1 60 | BATCH_SIZE: 32 61 | EPOCH: 300 62 | GIOU_WEIGHT: 2.0 63 | L1_WEIGHT: 5.0 64 | GRAD_CLIP_NORM: 0.1 65 | LR: 0.0004 66 | LR_DROP_EPOCH: 240 67 | NUM_WORKER: 10 68 | OPTIMIZER: ADAMW 69 | PRINT_INTERVAL: 50 70 | SCHEDULER: 71 | TYPE: step 72 | DECAY_RATE: 0.1 73 | VAL_EPOCH_INTERVAL: 20 74 | WEIGHT_DECAY: 0.0001 75 | AMP: False 76 | TEST: 77 | EPOCH: 300 78 | SEARCH_FACTOR: 4.0 79 | SEARCH_SIZE: 256 80 | TEMPLATE_FACTOR: 2.0 81 | TEMPLATE_SIZE: 128 -------------------------------------------------------------------------------- /model/ostrack/utils/variable_hook.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from bytecode import Bytecode, Instr 3 | 4 | 5 | class get_local(object): 6 | cache = {} 7 | is_activate = False 8 | 9 | def __init__(self, varname): 10 | self.varname = varname 11 | 12 | def __call__(self, func): 13 | if not type(self).is_activate: 14 | return func 15 | 16 | type(self).cache[func.__qualname__] = [] 17 | c = Bytecode.from_code(func.__code__) 18 | extra_code = [ 19 | Instr('STORE_FAST', '_res'), 20 | Instr('LOAD_FAST', self.varname), 21 | Instr('STORE_FAST', '_value'), 22 | Instr('LOAD_FAST', '_res'), 23 | Instr('LOAD_FAST', '_value'), 24 | Instr('BUILD_TUPLE', 2), 25 | Instr('STORE_FAST', 
'_result_tuple'), 26 | Instr('LOAD_FAST', '_result_tuple'), 27 | ] 28 | c[-1:-1] = extra_code 29 | func.__code__ = c.to_code() 30 | 31 | def wrapper(*args, **kwargs): 32 | res, values = func(*args, **kwargs) 33 | if isinstance(values, torch.Tensor): 34 | type(self).cache[func.__qualname__].append(values.detach().cpu().numpy()) 35 | elif isinstance(values, list): # list of Tensor 36 | type(self).cache[func.__qualname__].append([value.detach().cpu().numpy() for value in values]) 37 | else: 38 | raise NotImplementedError 39 | return res 40 | 41 | return wrapper 42 | 43 | @classmethod 44 | def clear(cls): 45 | for key in cls.cache.keys(): 46 | cls.cache[key] = [] 47 | 48 | @classmethod 49 | def activate(cls): 50 | cls.is_activate = True 51 | -------------------------------------------------------------------------------- /model/ostrack/layers/frozen_bn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class FrozenBatchNorm2d(torch.nn.Module): 5 | """ 6 | BatchNorm2d where the batch statistics and the affine parameters are fixed. 7 | 8 | Copy-paste from torchvision.misc.ops with added eps before rqsrt, 9 | without which any other models than torchvision.models.resnet[18,34,50,101] 10 | produce nans. 11 | """ 12 | 13 | def __init__(self, n): 14 | super(FrozenBatchNorm2d, self).__init__() 15 | self.register_buffer("weight", torch.ones(n)) 16 | self.register_buffer("bias", torch.zeros(n)) 17 | self.register_buffer("running_mean", torch.zeros(n)) 18 | self.register_buffer("running_var", torch.ones(n)) 19 | 20 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, 21 | missing_keys, unexpected_keys, error_msgs): 22 | num_batches_tracked_key = prefix + 'num_batches_tracked' 23 | if num_batches_tracked_key in state_dict: 24 | del state_dict[num_batches_tracked_key] 25 | 26 | super(FrozenBatchNorm2d, self)._load_from_state_dict( 27 | state_dict, prefix, local_metadata, strict, 28 | missing_keys, unexpected_keys, error_msgs) 29 | 30 | def forward(self, x): 31 | # move reshapes to the beginning 32 | # to make it fuser-friendly 33 | w = self.weight.reshape(1, -1, 1, 1) 34 | b = self.bias.reshape(1, -1, 1, 1) 35 | rv = self.running_var.reshape(1, -1, 1, 1) 36 | rm = self.running_mean.reshape(1, -1, 1, 1) 37 | eps = 1e-5 38 | scale = w * (rv + eps).rsqrt() # rsqrt(x): 1/sqrt(x), r: reciprocal 39 | bias = b - rm * scale 40 | return x * scale + bias 41 | -------------------------------------------------------------------------------- /config/vitb_256_mae_ce_32x4_got10k_ep100.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | MAX_SAMPLE_INTERVAL: 200 3 | MEAN: 4 | - 0.485 5 | - 0.456 6 | - 0.406 7 | SEARCH: 8 | CENTER_JITTER: 3 9 | FACTOR: 4.0 10 | SCALE_JITTER: 0.25 11 | SIZE: 256 12 | NUMBER: 1 13 | STD: 14 | - 0.229 15 | - 0.224 16 | - 0.225 17 | TEMPLATE: 18 | CENTER_JITTER: 0 19 | FACTOR: 2.0 20 | SCALE_JITTER: 0 21 | SIZE: 128 22 | TRAIN: 23 | DATASETS_NAME: 24 | - GOT10K_train_full 25 | DATASETS_RATIO: 26 | - 1 27 | SAMPLE_PER_EPOCH: 60000 28 | VAL: 29 | DATASETS_NAME: 30 | - GOT10K_official_val 31 | DATASETS_RATIO: 32 | - 1 33 | SAMPLE_PER_EPOCH: 10000 34 | MODEL: 35 | PRETRAIN_FILE: "mae_pretrain_vit_base.pth" 36 | EXTRA_MERGER: False 37 | RETURN_INTER: False 38 | BACKBONE: 39 | TYPE: vit_base_patch16_224_ce 40 | STRIDE: 16 41 | CE_LOC: [3, 6, 9] 42 | CE_KEEP_RATIO: [0.7, 0.7, 0.7] 43 | CE_TEMPLATE_RANGE: 'CTR_POINT' # choose between ALL, CTR_POINT, CTR_REC, GT_BOX 
44 | HEAD: 45 | TYPE: CENTER 46 | NUM_CHANNELS: 256 47 | TRAIN: 48 | BACKBONE_MULTIPLIER: 0.1 49 | DROP_PATH_RATE: 0.1 50 | CE_START_EPOCH: 20 # candidate elimination start epoch 51 | CE_WARM_EPOCH: 50 # candidate elimination warm up epoch 52 | BATCH_SIZE: 16 53 | EPOCH: 100 54 | GIOU_WEIGHT: 2.0 55 | L1_WEIGHT: 5.0 56 | GRAD_CLIP_NORM: 0.1 57 | LR: 0.0004 58 | LR_DROP_EPOCH: 80 59 | NUM_WORKER: 10 60 | OPTIMIZER: ADAMW 61 | PRINT_INTERVAL: 50 62 | SCHEDULER: 63 | TYPE: step 64 | DECAY_RATE: 0.1 65 | VAL_EPOCH_INTERVAL: 20 66 | WEIGHT_DECAY: 0.0001 67 | AMP: False 68 | TEST: 69 | EPOCH: 100 70 | SEARCH_FACTOR: 4.0 71 | SEARCH_SIZE: 256 72 | TEMPLATE_FACTOR: 2.0 73 | TEMPLATE_SIZE: 128 -------------------------------------------------------------------------------- /model/ostrack/utils/lmdb_utils.py: -------------------------------------------------------------------------------- 1 | import lmdb 2 | import numpy as np 3 | import cv2 4 | import json 5 | 6 | LMDB_ENVS = dict() 7 | LMDB_HANDLES = dict() 8 | LMDB_FILELISTS = dict() 9 | 10 | 11 | def get_lmdb_handle(name): 12 | global LMDB_HANDLES, LMDB_FILELISTS 13 | item = LMDB_HANDLES.get(name, None) 14 | if item is None: 15 | env = lmdb.open(name, readonly=True, lock=False, readahead=False, meminit=False) 16 | LMDB_ENVS[name] = env 17 | item = env.begin(write=False) 18 | LMDB_HANDLES[name] = item 19 | 20 | return item 21 | 22 | 23 | def decode_img(lmdb_fname, key_name): 24 | handle = get_lmdb_handle(lmdb_fname) 25 | binfile = handle.get(key_name.encode()) 26 | if binfile is None: 27 | print("Illegal data detected. %s %s" % (lmdb_fname, key_name)) 28 | s = np.frombuffer(binfile, np.uint8) 29 | x = cv2.cvtColor(cv2.imdecode(s, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB) 30 | return x 31 | 32 | 33 | def decode_str(lmdb_fname, key_name): 34 | handle = get_lmdb_handle(lmdb_fname) 35 | binfile = handle.get(key_name.encode()) 36 | string = binfile.decode() 37 | return string 38 | 39 | 40 | def decode_json(lmdb_fname, key_name): 41 | return json.loads(decode_str(lmdb_fname, key_name)) 42 | 43 | 44 | if __name__ == "__main__": 45 | lmdb_fname = "/data/sda/v-yanbi/iccv21/LittleBoy_clean/data/got10k_lmdb" 46 | '''Decode image''' 47 | # key_name = "test/GOT-10k_Test_000001/00000001.jpg" 48 | # img = decode_img(lmdb_fname, key_name) 49 | # cv2.imwrite("001.jpg", img) 50 | '''Decode str''' 51 | # key_name = "test/list.txt" 52 | # key_name = "train/GOT-10k_Train_000001/groundtruth.txt" 53 | key_name = "train/GOT-10k_Train_000001/absence.label" 54 | str_ = decode_str(lmdb_fname, key_name) 55 | print(str_) 56 | -------------------------------------------------------------------------------- /config/vitb_384_mae_ce_32x4_ep300.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | MAX_SAMPLE_INTERVAL: 200 3 | MEAN: 4 | - 0.485 5 | - 0.456 6 | - 0.406 7 | SEARCH: 8 | CENTER_JITTER: 4.5 9 | FACTOR: 5.0 10 | SCALE_JITTER: 0.5 11 | SIZE: 384 12 | STD: 13 | - 0.229 14 | - 0.224 15 | - 0.225 16 | TEMPLATE: 17 | CENTER_JITTER: 0 18 | FACTOR: 2.0 19 | SCALE_JITTER: 0 20 | SIZE: 192 21 | # TRAIN: 22 | # DATASETS_NAME: 23 | # - GOT10K_train_full 24 | # DATASETS_RATIO: 25 | # - 1 26 | # SAMPLE_PER_EPOCH: 60000 27 | 28 | TRAIN: 29 | DATASETS_NAME: 30 | - LASOT 31 | - GOT10K_vottrain 32 | - COCO17 33 | - TRACKINGNET 34 | DATASETS_RATIO: 35 | - 1 36 | - 1 37 | - 1 38 | - 1 39 | SAMPLE_PER_EPOCH: 60000 40 | VAL: 41 | DATASETS_NAME: 42 | - GOT10K_votval 43 | DATASETS_RATIO: 44 | - 1 45 | SAMPLE_PER_EPOCH: 10000 46 | MODEL: 47 | 
PRETRAIN_FILE: "mae_pretrain_vit_base.pth" 48 | EXTRA_MERGER: False 49 | RETURN_INTER: False 50 | BACKBONE: 51 | TYPE: vit_base_patch16_224_ce 52 | STRIDE: 16 53 | CE_LOC: [3, 6, 9] 54 | CE_KEEP_RATIO: [0.7, 0.7, 0.7] 55 | CE_TEMPLATE_RANGE: 'CTR_POINT' # choose between ALL, CTR_POINT, CTR_REC, GT_BOX 56 | HEAD: 57 | TYPE: CENTER 58 | NUM_CHANNELS: 256 59 | TRAIN: 60 | BACKBONE_MULTIPLIER: 0.1 61 | DROP_PATH_RATE: 0.1 62 | BATCH_SIZE: 32 63 | EPOCH: 300 64 | GIOU_WEIGHT: 2.0 65 | L1_WEIGHT: 5.0 66 | GRAD_CLIP_NORM: 0.1 67 | LR: 0.0004 68 | LR_DROP_EPOCH: 240 69 | NUM_WORKER: 10 70 | OPTIMIZER: ADAMW 71 | PRINT_INTERVAL: 50 72 | SCHEDULER: 73 | TYPE: step 74 | DECAY_RATE: 0.1 75 | VAL_EPOCH_INTERVAL: 20 76 | WEIGHT_DECAY: 0.0001 77 | AMP: False 78 | TEST: 79 | EPOCH: 300 80 | SEARCH_FACTOR: 5.0 81 | SEARCH_SIZE: 384 82 | TEMPLATE_FACTOR: 2.0 83 | TEMPLATE_SIZE: 192 -------------------------------------------------------------------------------- /config/vitb_384_mae_ce_32x4_got10k_ep100.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | MAX_SAMPLE_INTERVAL: 200 3 | MEAN: 4 | - 0.485 5 | - 0.456 6 | - 0.406 7 | SEARCH: 8 | CENTER_JITTER: 4.5 9 | FACTOR: 5.0 10 | SCALE_JITTER: 0.5 11 | SIZE: 384 12 | STD: 13 | - 0.229 14 | - 0.224 15 | - 0.225 16 | TEMPLATE: 17 | CENTER_JITTER: 0 18 | FACTOR: 2.0 19 | SCALE_JITTER: 0 20 | SIZE: 192 21 | TRAIN: 22 | DATASETS_NAME: 23 | - GOT10K_train_full 24 | DATASETS_RATIO: 25 | - 1 26 | SAMPLE_PER_EPOCH: 60000 27 | VAL: 28 | DATASETS_NAME: 29 | - GOT10K_official_val 30 | DATASETS_RATIO: 31 | - 1 32 | SAMPLE_PER_EPOCH: 10000 33 | MODEL: 34 | PRETRAIN_FILE: "mae_pretrain_vit_base.pth" 35 | EXTRA_MERGER: False 36 | RETURN_INTER: False 37 | BACKBONE: 38 | CAT_MODE: 'direct' # direct or concat 39 | SEP_SEG: True 40 | TYPE: vit_base_patch16_224_ce 41 | STRIDE: 16 42 | CE_LOC: [3, 6, 9] 43 | CE_KEEP_RATIO: [0.7, 0.7, 0.7] 44 | # CE_KEEP_RATIO: [0.5, 0.5, 0.5] 45 | CE_TEMPLATE_RANGE: 'CTR_POINT' # choose between ALL, CTR_POINT, CTR_REC, GT_BOX 46 | HEAD: 47 | TYPE: CENTER 48 | NUM_CHANNELS: 256 49 | RETURN_STAGES: [2, 5, 8, 11] 50 | TRAIN: 51 | BACKBONE_MULTIPLIER: 0.1 52 | DROP_PATH_RATE: 0.1 53 | CE_START_EPOCH: 20 # candidate elimination start epoch 54 | CE_WARM_EPOCH: 50 # candidate elimination warm up epoch 55 | BATCH_SIZE: 16 # 2022.12.15 32->16 56 | EPOCH: 100 57 | GIOU_WEIGHT: 2.0 58 | L1_WEIGHT: 5.0 59 | GRAD_CLIP_NORM: 0.1 60 | LR: 0.0002 # 2022.12.20 0.0004 -> 0.0002 61 | LR_DROP_EPOCH: 80 62 | NUM_WORKER: 10 63 | OPTIMIZER: ADAMW 64 | PRINT_INTERVAL: 50 65 | SCHEDULER: 66 | TYPE: step 67 | DECAY_RATE: 0.1 68 | VAL_EPOCH_INTERVAL: 20 69 | WEIGHT_DECAY: 0.0001 70 | AMP: False 71 | TEST: 72 | EPOCH: 100 73 | SEARCH_FACTOR: 5.0 74 | SEARCH_SIZE: 384 75 | TEMPLATE_FACTOR: 2.0 76 | TEMPLATE_SIZE: 192 -------------------------------------------------------------------------------- /config/vitb_256_mae_ce_32x4_ep300.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | MAX_SAMPLE_INTERVAL: 200 3 | MEAN: 4 | - 0.485 5 | - 0.456 6 | - 0.406 7 | SEARCH: 8 | CENTER_JITTER: 3 9 | FACTOR: 4.0 10 | SCALE_JITTER: 0.25 11 | SIZE: 256 12 | NUMBER: 1 13 | STD: 14 | - 0.229 15 | - 0.224 16 | - 0.225 17 | TEMPLATE: 18 | CENTER_JITTER: 0 19 | FACTOR: 2.0 20 | SCALE_JITTER: 0 21 | SIZE: 128 22 | # TRAIN: 23 | # DATASETS_NAME: 24 | # - GOT10K_train_full 25 | # DATASETS_RATIO: 26 | # - 1 27 | # SAMPLE_PER_EPOCH: 60000 28 | 29 | TRAIN: 30 | DATASETS_NAME: 31 | - LASOT 32 | - 
GOT10K_vottrain 33 | - COCO17 34 | - TRACKINGNET 35 | DATASETS_RATIO: 36 | - 1 37 | - 1 38 | - 1 39 | - 1 40 | SAMPLE_PER_EPOCH: 60000 41 | VAL: 42 | DATASETS_NAME: 43 | - GOT10K_votval 44 | DATASETS_RATIO: 45 | - 1 46 | SAMPLE_PER_EPOCH: 10000 47 | MODEL: 48 | PRETRAIN_FILE: "mae_pretrain_vit_base.pth" 49 | EXTRA_MERGER: False 50 | RETURN_INTER: False 51 | BACKBONE: 52 | TYPE: vit_base_patch16_224_ce 53 | STRIDE: 16 54 | CE_LOC: [3, 6, 9] 55 | CE_KEEP_RATIO: [0.7, 0.7, 0.7] 56 | CE_TEMPLATE_RANGE: 'CTR_POINT' # choose between ALL, CTR_POINT, CTR_REC, GT_BOX 57 | HEAD: 58 | TYPE: CENTER 59 | NUM_CHANNELS: 256 60 | TRAIN: 61 | BACKBONE_MULTIPLIER: 0.1 62 | DROP_PATH_RATE: 0.1 63 | CE_START_EPOCH: 20 # candidate elimination start epoch 64 | CE_WARM_EPOCH: 80 # candidate elimination warm up epoch 65 | BATCH_SIZE: 32 66 | EPOCH: 300 67 | GIOU_WEIGHT: 2.0 68 | L1_WEIGHT: 5.0 69 | GRAD_CLIP_NORM: 0.1 70 | LR: 0.0004 71 | LR_DROP_EPOCH: 240 72 | NUM_WORKER: 10 73 | OPTIMIZER: ADAMW 74 | PRINT_INTERVAL: 50 75 | SCHEDULER: 76 | TYPE: step 77 | DECAY_RATE: 0.1 78 | VAL_EPOCH_INTERVAL: 20 79 | WEIGHT_DECAY: 0.0001 80 | AMP: False 81 | TEST: 82 | EPOCH: 300 83 | SEARCH_FACTOR: 4.0 84 | SEARCH_SIZE: 256 85 | TEMPLATE_FACTOR: 2.0 86 | TEMPLATE_SIZE: 128 -------------------------------------------------------------------------------- /model/ostrack/utils/focal_loss.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class FocalLoss(nn.Module, ABC): 9 | def __init__(self, alpha=2, beta=4): 10 | super(FocalLoss, self).__init__() 11 | self.alpha = alpha 12 | self.beta = beta 13 | 14 | def forward(self, prediction, target): 15 | positive_index = target.eq(1).float() 16 | negative_index = target.lt(1).float() 17 | 18 | negative_weights = torch.pow(1 - target, self.beta) 19 | # clamp min value is set to 1e-12 to maintain the numerical stability 20 | prediction = torch.clamp(prediction, 1e-12) 21 | 22 | positive_loss = torch.log(prediction) * torch.pow(1 - prediction, self.alpha) * positive_index 23 | negative_loss = torch.log(1 - prediction) * torch.pow(prediction, 24 | self.alpha) * negative_weights * negative_index 25 | 26 | num_positive = positive_index.float().sum() 27 | positive_loss = positive_loss.sum() 28 | negative_loss = negative_loss.sum() 29 | 30 | if num_positive == 0: 31 | loss = -negative_loss 32 | else: 33 | loss = -(positive_loss + negative_loss) / num_positive 34 | 35 | return loss 36 | 37 | 38 | class LBHinge(nn.Module): 39 | """Loss that uses a 'hinge' on the lower bound. 40 | This means that for samples with a label value smaller than the threshold, the loss is zero if the prediction is 41 | also smaller than that threshold. 42 | args: 43 | error_matric: What base loss to use (MSE by default). 44 | threshold: Threshold to use for the hinge. 45 | clip: Clip the loss if it is above this value. 
46 | """ 47 | def __init__(self, error_metric=nn.MSELoss(), threshold=None, clip=None): 48 | super().__init__() 49 | self.error_metric = error_metric 50 | self.threshold = threshold if threshold is not None else -100 51 | self.clip = clip 52 | 53 | def forward(self, prediction, label, target_bb=None): 54 | negative_mask = (label < self.threshold).float() 55 | positive_mask = (1.0 - negative_mask) 56 | 57 | prediction = negative_mask * F.relu(prediction) + positive_mask * prediction 58 | 59 | loss = self.error_metric(prediction, positive_mask * label) 60 | 61 | if self.clip is not None: 62 | loss = torch.min(loss, torch.tensor([self.clip], device=loss.device)) 63 | return loss -------------------------------------------------------------------------------- /model/ostrack/utils/box_ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.ops.boxes import box_area 3 | import numpy as np 4 | 5 | 6 | def box_cxcywh_to_xyxy(x): 7 | x_c, y_c, w, h = x.unbind(-1) 8 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 9 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 10 | return torch.stack(b, dim=-1) 11 | 12 | 13 | def box_xywh_to_xyxy(x): 14 | x1, y1, w, h = x.unbind(-1) 15 | b = [x1, y1, x1 + w, y1 + h] 16 | return torch.stack(b, dim=-1) 17 | 18 | 19 | def box_xyxy_to_xywh(x): 20 | x1, y1, x2, y2 = x.unbind(-1) 21 | b = [x1, y1, x2 - x1, y2 - y1] 22 | return torch.stack(b, dim=-1) 23 | 24 | 25 | def box_xyxy_to_cxcywh(x): 26 | x0, y0, x1, y1 = x.unbind(-1) 27 | b = [(x0 + x1) / 2, (y0 + y1) / 2, 28 | (x1 - x0), (y1 - y0)] 29 | return torch.stack(b, dim=-1) 30 | 31 | 32 | # modified from torchvision to also return the union 33 | '''Note that this function only supports shape (N,4)''' 34 | 35 | 36 | def box_iou(boxes1, boxes2): 37 | """ 38 | 39 | :param boxes1: (N, 4) (x1,y1,x2,y2) 40 | :param boxes2: (N, 4) (x1,y1,x2,y2) 41 | :return: 42 | """ 43 | area1 = box_area(boxes1) # (N,) 44 | area2 = box_area(boxes2) # (N,) 45 | 46 | lt = torch.max(boxes1[:, :2], boxes2[:, :2]) # (N,2) 47 | rb = torch.min(boxes1[:, 2:], boxes2[:, 2:]) # (N,2) 48 | 49 | wh = (rb - lt).clamp(min=0) # (N,2) 50 | inter = wh[:, 0] * wh[:, 1] # (N,) 51 | 52 | union = area1 + area2 - inter 53 | 54 | iou = inter / union 55 | return iou, union 56 | 57 | 58 | '''Note that this implementation is different from DETR's''' 59 | 60 | 61 | def generalized_box_iou(boxes1, boxes2): 62 | """ 63 | Generalized IoU from https://giou.stanford.edu/ 64 | 65 | The boxes should be in [x0, y0, x1, y1] format 66 | 67 | boxes1: (N, 4) 68 | boxes2: (N, 4) 69 | """ 70 | # degenerate boxes gives inf / nan results 71 | # so do an early check 72 | # try: 73 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 74 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 75 | iou, union = box_iou(boxes1, boxes2) # (N,) 76 | 77 | lt = torch.min(boxes1[:, :2], boxes2[:, :2]) 78 | rb = torch.max(boxes1[:, 2:], boxes2[:, 2:]) 79 | 80 | wh = (rb - lt).clamp(min=0) # (N,2) 81 | area = wh[:, 0] * wh[:, 1] # (N,) 82 | 83 | return iou - (area - union) / area, iou 84 | 85 | 86 | def giou_loss(boxes1, boxes2): 87 | """ 88 | 89 | :param boxes1: (N, 4) (x1,y1,x2,y2) 90 | :param boxes2: (N, 4) (x1,y1,x2,y2) 91 | :return: 92 | """ 93 | giou, iou = generalized_box_iou(boxes1, boxes2) 94 | return (1 - giou).mean(), iou 95 | 96 | 97 | def clip_box(box: list, H, W, margin=0): 98 | x1, y1, w, h = box 99 | x2, y2 = x1 + w, y1 + h 100 | x1 = min(max(0, x1), W-margin) 101 | x2 = min(max(margin, x2), W) 102 | y1 = min(max(0, y1), H-margin) 103 | y2 = 
min(max(margin, y2), H) 104 | w = max(margin, x2-x1) 105 | h = max(margin, y2-y1) 106 | return [x1, y1, w, h] 107 | -------------------------------------------------------------------------------- /model/ostrack/utils/ce_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | def generate_bbox_mask(bbox_mask, bbox): 8 | b, h, w = bbox_mask.shape 9 | for i in range(b): 10 | bbox_i = bbox[i].cpu().tolist() 11 | bbox_mask[i, int(bbox_i[1]):int(bbox_i[1] + bbox_i[3] - 1), int(bbox_i[0]):int(bbox_i[0] + bbox_i[2] - 1)] = 1 12 | return bbox_mask 13 | 14 | 15 | def generate_mask_cond(cfg, bs, device, gt_bbox): 16 | template_size = cfg.DATA.TEMPLATE.SIZE 17 | stride = cfg.MODEL.BACKBONE.STRIDE 18 | template_feat_size = template_size // stride 19 | 20 | if cfg.MODEL.BACKBONE.CE_TEMPLATE_RANGE == 'ALL': 21 | box_mask_z = None 22 | elif cfg.MODEL.BACKBONE.CE_TEMPLATE_RANGE == 'CTR_POINT': 23 | if template_feat_size == 8: 24 | index = slice(3, 4) 25 | elif template_feat_size == 12: 26 | index = slice(5, 6) 27 | elif template_feat_size == 7: 28 | index = slice(3, 4) 29 | elif template_feat_size == 14: 30 | index = slice(6, 7) 31 | else: 32 | raise NotImplementedError 33 | box_mask_z = torch.zeros([bs, template_feat_size, template_feat_size], device=device) 34 | box_mask_z[:, index, index] = 1 35 | box_mask_z = box_mask_z.flatten(1).to(torch.bool) 36 | elif cfg.MODEL.BACKBONE.CE_TEMPLATE_RANGE == 'CTR_REC': 37 | # use fixed 4x4 region, 3:5 for 8x8 38 | # use fixed 4x4 region 5:6 for 12x12 39 | if template_feat_size == 8: 40 | index = slice(3, 5) 41 | elif template_feat_size == 12: 42 | index = slice(5, 7) 43 | elif template_feat_size == 7: 44 | index = slice(3, 4) 45 | else: 46 | raise NotImplementedError 47 | box_mask_z = torch.zeros([bs, template_feat_size, template_feat_size], device=device) 48 | box_mask_z[:, index, index] = 1 49 | box_mask_z = box_mask_z.flatten(1).to(torch.bool) 50 | 51 | elif cfg.MODEL.BACKBONE.CE_TEMPLATE_RANGE == 'GT_BOX': 52 | box_mask_z = torch.zeros([bs, template_size, template_size], device=device) 53 | # box_mask_z_ori = data['template_seg'][0].view(-1, 1, *data['template_seg'].shape[2:]) # (batch, 1, 128, 128) 54 | box_mask_z = generate_bbox_mask(box_mask_z, gt_bbox * template_size).unsqueeze(1).to( 55 | torch.float) # (batch, 1, 128, 128) 56 | # box_mask_z_vis = box_mask_z.cpu().numpy() 57 | box_mask_z = F.interpolate(box_mask_z, scale_factor=1. 
/ cfg.MODEL.BACKBONE.STRIDE, mode='bilinear', 58 | align_corners=False) 59 | box_mask_z = box_mask_z.flatten(1).to(torch.bool) 60 | # box_mask_z_vis = box_mask_z[:, 0, ...].cpu().numpy() 61 | # gaussian_maps_vis = generate_heatmap(data['template_anno'], self.cfg.DATA.TEMPLATE.SIZE, self.cfg.MODEL.STRIDE)[0].cpu().numpy() 62 | else: 63 | raise NotImplementedError 64 | 65 | return box_mask_z 66 | 67 | 68 | def adjust_keep_rate(epoch, warmup_epochs, total_epochs, ITERS_PER_EPOCH, base_keep_rate=0.5, max_keep_rate=1, iters=-1): 69 | if epoch < warmup_epochs: 70 | return 1 71 | if epoch >= total_epochs: 72 | return base_keep_rate 73 | if iters == -1: 74 | iters = epoch * ITERS_PER_EPOCH 75 | total_iters = ITERS_PER_EPOCH * (total_epochs - warmup_epochs) 76 | iters = iters - ITERS_PER_EPOCH * warmup_epochs 77 | keep_rate = base_keep_rate + (max_keep_rate - base_keep_rate) \ 78 | * (math.cos(iters / total_iters * math.pi) + 1) * 0.5 79 | 80 | return keep_rate 81 | -------------------------------------------------------------------------------- /model/ostrack/util.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | def combine_tokens(template_tokens, search_tokens, mode='direct', return_res=False): 8 | # [B, HW, C] 9 | len_t = template_tokens.shape[1] 10 | len_s = search_tokens.shape[1] 11 | 12 | if mode == 'direct': 13 | merged_feature = torch.cat((template_tokens, search_tokens), dim=1) 14 | elif mode == 'template_central': 15 | central_pivot = len_s // 2 16 | first_half = search_tokens[:, :central_pivot, :] 17 | second_half = search_tokens[:, central_pivot:, :] 18 | merged_feature = torch.cat((first_half, template_tokens, second_half), dim=1) 19 | elif mode == 'partition': 20 | feat_size_s = int(math.sqrt(len_s)) 21 | feat_size_t = int(math.sqrt(len_t)) 22 | window_size = math.ceil(feat_size_t / 2.) 
23 | # pad feature maps to multiples of window size 24 | B, _, C = template_tokens.shape 25 | H = W = feat_size_t 26 | template_tokens = template_tokens.view(B, H, W, C) 27 | pad_l = pad_b = pad_r = 0 28 | # pad_r = (window_size - W % window_size) % window_size 29 | pad_t = (window_size - H % window_size) % window_size 30 | template_tokens = F.pad(template_tokens, (0, 0, pad_l, pad_r, pad_t, pad_b)) 31 | _, Hp, Wp, _ = template_tokens.shape 32 | template_tokens = template_tokens.view(B, Hp // window_size, window_size, W, C) 33 | template_tokens = torch.cat([template_tokens[:, 0, ...], template_tokens[:, 1, ...]], dim=2) 34 | _, Hc, Wc, _ = template_tokens.shape 35 | template_tokens = template_tokens.view(B, -1, C) 36 | merged_feature = torch.cat([template_tokens, search_tokens], dim=1) 37 | 38 | # calculate new h and w, which may be useful for SwinT or others 39 | merged_h, merged_w = feat_size_s + Hc, feat_size_s 40 | if return_res: 41 | return merged_feature, merged_h, merged_w 42 | 43 | else: 44 | raise NotImplementedError 45 | 46 | return merged_feature 47 | 48 | 49 | def recover_tokens(merged_tokens, len_template_token, len_search_token, mode='direct'): 50 | if mode == 'direct': 51 | recovered_tokens = merged_tokens 52 | elif mode == 'template_central': 53 | central_pivot = len_search_token // 2 54 | len_remain = len_search_token - central_pivot 55 | len_half_and_t = central_pivot + len_template_token 56 | 57 | first_half = merged_tokens[:, :central_pivot, :] 58 | second_half = merged_tokens[:, -len_remain:, :] 59 | template_tokens = merged_tokens[:, central_pivot:len_half_and_t, :] 60 | 61 | recovered_tokens = torch.cat((template_tokens, first_half, second_half), dim=1) 62 | elif mode == 'partition': 63 | recovered_tokens = merged_tokens 64 | else: 65 | raise NotImplementedError 66 | 67 | return recovered_tokens 68 | 69 | 70 | def window_partition(x, window_size: int): 71 | """ 72 | Args: 73 | x: (B, H, W, C) 74 | window_size (int): window size 75 | 76 | Returns: 77 | windows: (num_windows*B, window_size, window_size, C) 78 | """ 79 | B, H, W, C = x.shape 80 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 81 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 82 | return windows 83 | 84 | 85 | def window_reverse(windows, window_size: int, H: int, W: int): 86 | """ 87 | Args: 88 | windows: (num_windows*B, window_size, window_size, C) 89 | window_size (int): Window size 90 | H (int): Height of image 91 | W (int): Width of image 92 | 93 | Returns: 94 | x: (B, H, W, C) 95 | """ 96 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 97 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 98 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 99 | return x 100 | -------------------------------------------------------------------------------- /model/ostrack/layers/rpe.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from timm.models.layers import trunc_normal_ 4 | 5 | 6 | def generate_2d_relative_positional_encoding_index(z_shape, x_shape): 7 | ''' 8 | z_shape: (z_h, z_w) 9 | x_shape: (x_h, x_w) 10 | ''' 11 | z_2d_index_h, z_2d_index_w = torch.meshgrid(torch.arange(z_shape[0]), torch.arange(z_shape[1])) 12 | x_2d_index_h, x_2d_index_w = torch.meshgrid(torch.arange(x_shape[0]), torch.arange(x_shape[1])) 13 | 14 | z_2d_index_h = z_2d_index_h.flatten(0) 15 | z_2d_index_w = 
z_2d_index_w.flatten(0) 16 | x_2d_index_h = x_2d_index_h.flatten(0) 17 | x_2d_index_w = x_2d_index_w.flatten(0) 18 | 19 | diff_h = z_2d_index_h[:, None] - x_2d_index_h[None, :] 20 | diff_w = z_2d_index_w[:, None] - x_2d_index_w[None, :] 21 | 22 | diff = torch.stack((diff_h, diff_w), dim=-1) 23 | _, indices = torch.unique(diff.view(-1, 2), return_inverse=True, dim=0) 24 | return indices.view(z_shape[0] * z_shape[1], x_shape[0] * x_shape[1]) 25 | 26 | 27 | def generate_2d_concatenated_self_attention_relative_positional_encoding_index(z_shape, x_shape): 28 | ''' 29 | z_shape: (z_h, z_w) 30 | x_shape: (x_h, x_w) 31 | ''' 32 | z_2d_index_h, z_2d_index_w = torch.meshgrid(torch.arange(z_shape[0]), torch.arange(z_shape[1])) 33 | x_2d_index_h, x_2d_index_w = torch.meshgrid(torch.arange(x_shape[0]), torch.arange(x_shape[1])) 34 | 35 | z_2d_index_h = z_2d_index_h.flatten(0) 36 | z_2d_index_w = z_2d_index_w.flatten(0) 37 | x_2d_index_h = x_2d_index_h.flatten(0) 38 | x_2d_index_w = x_2d_index_w.flatten(0) 39 | 40 | concatenated_2d_index_h = torch.cat((z_2d_index_h, x_2d_index_h)) 41 | concatenated_2d_index_w = torch.cat((z_2d_index_w, x_2d_index_w)) 42 | 43 | diff_h = concatenated_2d_index_h[:, None] - concatenated_2d_index_h[None, :] 44 | diff_w = concatenated_2d_index_w[:, None] - concatenated_2d_index_w[None, :] 45 | 46 | z_len = z_shape[0] * z_shape[1] 47 | x_len = x_shape[0] * x_shape[1] 48 | a = torch.empty((z_len + x_len), dtype=torch.int64) 49 | a[:z_len] = 0 50 | a[z_len:] = 1 51 | b=a[:, None].repeat(1, z_len + x_len) 52 | c=a[None, :].repeat(z_len + x_len, 1) 53 | 54 | diff = torch.stack((diff_h, diff_w, b, c), dim=-1) 55 | _, indices = torch.unique(diff.view((z_len + x_len) * (z_len + x_len), 4), return_inverse=True, dim=0) 56 | return indices.view((z_len + x_len), (z_len + x_len)) 57 | 58 | 59 | def generate_2d_concatenated_cross_attention_relative_positional_encoding_index(z_shape, x_shape): 60 | ''' 61 | z_shape: (z_h, z_w) 62 | x_shape: (x_h, x_w) 63 | ''' 64 | z_2d_index_h, z_2d_index_w = torch.meshgrid(torch.arange(z_shape[0]), torch.arange(z_shape[1])) 65 | x_2d_index_h, x_2d_index_w = torch.meshgrid(torch.arange(x_shape[0]), torch.arange(x_shape[1])) 66 | 67 | z_2d_index_h = z_2d_index_h.flatten(0) 68 | z_2d_index_w = z_2d_index_w.flatten(0) 69 | x_2d_index_h = x_2d_index_h.flatten(0) 70 | x_2d_index_w = x_2d_index_w.flatten(0) 71 | 72 | concatenated_2d_index_h = torch.cat((z_2d_index_h, x_2d_index_h)) 73 | concatenated_2d_index_w = torch.cat((z_2d_index_w, x_2d_index_w)) 74 | 75 | diff_h = x_2d_index_h[:, None] - concatenated_2d_index_h[None, :] 76 | diff_w = x_2d_index_w[:, None] - concatenated_2d_index_w[None, :] 77 | 78 | z_len = z_shape[0] * z_shape[1] 79 | x_len = x_shape[0] * x_shape[1] 80 | 81 | a = torch.empty(z_len + x_len, dtype=torch.int64) 82 | a[: z_len] = 0 83 | a[z_len:] = 1 84 | c = a[None, :].repeat(x_len, 1) 85 | 86 | diff = torch.stack((diff_h, diff_w, c), dim=-1) 87 | _, indices = torch.unique(diff.view(x_len * (z_len + x_len), 3), return_inverse=True, dim=0) 88 | return indices.view(x_len, (z_len + x_len)) 89 | 90 | 91 | class RelativePosition2DEncoder(nn.Module): 92 | def __init__(self, num_heads, embed_size): 93 | super(RelativePosition2DEncoder, self).__init__() 94 | self.relative_position_bias_table = nn.Parameter(torch.empty((num_heads, embed_size))) 95 | trunc_normal_(self.relative_position_bias_table, std=0.02) 96 | 97 | def forward(self, attn_rpe_index): 98 | ''' 99 | Args: 100 | attn_rpe_index (torch.Tensor): (*), any shape containing indices, 
max(attn_rpe_index) < embed_size 101 | Returns: 102 | torch.Tensor: (1, num_heads, *) 103 | ''' 104 | return self.relative_position_bias_table[:, attn_rpe_index].unsqueeze(0) 105 | -------------------------------------------------------------------------------- /bbox.py: -------------------------------------------------------------------------------- 1 | """ 2 | TODO: 构建yolov5模型的解码输出,包括bbox的解码和置信度的解码 3 | 以及无人机的位置pixels信息 4 | 时间: 2025/03/11-Redal 5 | """ 6 | import os 7 | import cv2 8 | import random 9 | import torch 10 | import numpy as np 11 | from yolov5.utils.general import non_max_suppression, scale_boxes 12 | from yolov5.utils.augmentations import letterbox 13 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 14 | 15 | 16 | 17 | ########################### 用于YOLOv5模型解析输出 ############################################# 18 | def plot_one_box(x, img, color=None, label=None, line_thickness=None): 19 | """Plots one bounding box on image img""" 20 | tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 21 | color = color or [random.randint(0, 255) for _ in range(3)] 22 | c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3])) 23 | cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA) 24 | if label: 25 | tf = max(tl - 1, 1) 26 | t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0] 27 | c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3 28 | cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA) 29 | cv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA) 30 | return img 31 | 32 | 33 | def decoder(model, img0): 34 | """decode the yolo model output and plot the bounding box 35 | :param model: the trained yolo model consisting of s/m/l/x version 36 | :param img0: the uav frame image got from computer camera""" 37 | img = letterbox(img0, new_shape=640)[0] 38 | img = img[:, :, ::-1].transpose(2, 0, 1) 39 | img = np.ascontiguousarray(img) 40 | # yolo model inference and postprocess 41 | img = torch.from_numpy(img).to(device) 42 | img = img.float() / 255.0 43 | if img.ndimension() == 3: 44 | img = img.unsqueeze(0) 45 | with torch.no_grad(): 46 | pred = model(img)[0] 47 | # use NMS to remove redundant boxes 48 | conf_thres = 0.25 49 | iou_thres = 0.45 50 | pred = non_max_suppression(pred, conf_thres, iou_thres) 51 | for det in pred: 52 | if len(det): 53 | det[:, :4] = scale_boxes(img.shape[2:], det[:, :4], img0.shape).round() 54 | for *xyxy, conf, cls in reversed(det): 55 | # Xyxy contains the coordinates of the upper left and lower right corners 56 | # of the bounding box, conf is the confidence level, and cls is the category number. 
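# The lines below build a '<class> <confidence>' label, log the detection, and draw
# the box directly on img0; note that xyxy is already in original-frame pixel
# coordinates because scale_boxes() above mapped it back from the letterboxed input.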
57 | label = f'{model.names[int(cls)]} {conf:.2f}' 58 | print(f"Detected object: {label} at {xyxy}") 59 | plot_one_box(xyxy, img0, label=label, color=(0, 255, 0), line_thickness=3) 60 | try: return img0, xyxy 61 | except: return img0, None 62 | 63 | 64 | 65 | ########################### 用于OSTrack模型规范输入 ############################################# 66 | def ScaleClip(img, xyxy, mode=None): 67 | """ScaleClip is used to clip frame for template and search area 68 | :param img: the frame image must consists of UAV pixels 69 | :param xyxy: the up-left and down-right coordinates of the UAV bounding box""" 70 | img_array = np.array(img) 71 | width, height = xyxy[2] - xyxy[0], xyxy[3] - xyxy[1] 72 | center = np.array([xyxy[0] + width / 2, xyxy[1] + height / 2]) 73 | scale_factor = {'template': 2, 'search': 5}.get(mode, 0) 74 | scaled_width = int(scale_factor * width) 75 | scaled_height = int(scale_factor * height) 76 | # Calculate the cropping rectangle ensuring it does not exceed image boundaries. 77 | top_left_x = max(int(center[0] - scaled_width / 2), 0) 78 | top_left_y = max(int(center[1] - scaled_height / 2), 0) 79 | bottom_right_x = min(int(center[0] + scaled_width / 2), img_array.shape[1]) 80 | bottom_right_y = min(int(center[1] + scaled_height / 2), img_array.shape[0]) 81 | # Clip the image 82 | img_clipped = img_array[top_left_y:bottom_right_y, top_left_x:bottom_right_x, :] 83 | return img_clipped 84 | 85 | 86 | 87 | ############################# 主函数测试分析 ############################################# 88 | if __name__ == '__main__': 89 | img0 = cv2.imread('assets/uav_2.jpg') 90 | # img_templete = ScaleClip(img0, [150, 150, 250, 250], mode='template') 91 | # img_search = ScaleClip(img0, [150, 150, 250, 250], mode='search') 92 | # cv2.imshow('template', img_templete) 93 | # cv2.imshow('search', img_search) 94 | cv2.waitKey(0) 95 | cv2.destroyAllWindows() 96 | -------------------------------------------------------------------------------- /model/ostrack/layers/attn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from timm.models.layers import trunc_normal_ 5 | 6 | from .rpe import generate_2d_concatenated_self_attention_relative_positional_encoding_index 7 | 8 | 9 | class Attention(nn.Module): 10 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., 11 | rpe=False, z_size=7, x_size=14): 12 | super().__init__() 13 | self.num_heads = num_heads 14 | head_dim = dim // num_heads 15 | self.scale = head_dim ** -0.5 16 | 17 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 18 | self.attn_drop = nn.Dropout(attn_drop) 19 | self.proj = nn.Linear(dim, dim) 20 | self.proj_drop = nn.Dropout(proj_drop) 21 | 22 | self.rpe =rpe 23 | if self.rpe: 24 | relative_position_index = \ 25 | generate_2d_concatenated_self_attention_relative_positional_encoding_index([z_size, z_size], 26 | [x_size, x_size]) 27 | self.register_buffer("relative_position_index", relative_position_index) 28 | # define a parameter table of relative position bias 29 | self.relative_position_bias_table = nn.Parameter(torch.empty((num_heads, 30 | relative_position_index.max() + 1))) 31 | trunc_normal_(self.relative_position_bias_table, std=0.02) 32 | 33 | def forward(self, x, mask=None, return_attention=False): 34 | # x: B, N, C 35 | # mask: [B, N, ] torch.bool 36 | B, N, C = x.shape 37 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 38 | q, k, v = 
qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) 39 | 40 | attn = (q @ k.transpose(-2, -1)) * self.scale 41 | 42 | if self.rpe: 43 | relative_position_bias = self.relative_position_bias_table[:, self.relative_position_index].unsqueeze(0) 44 | attn += relative_position_bias 45 | 46 | if mask is not None: 47 | attn = attn.masked_fill(mask.unsqueeze(1).unsqueeze(2), float('-inf'),) 48 | 49 | attn = attn.softmax(dim=-1) 50 | attn = self.attn_drop(attn) 51 | 52 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 53 | x = self.proj(x) 54 | x = self.proj_drop(x) 55 | 56 | if return_attention: 57 | return x, attn 58 | else: 59 | return x 60 | 61 | 62 | class Attention_talking_head(nn.Module): 63 | # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 64 | # with slight modifications to add Talking Heads Attention (https://arxiv.org/pdf/2003.02436v1.pdf) 65 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., 66 | rpe=True, z_size=7, x_size=14): 67 | super().__init__() 68 | 69 | self.num_heads = num_heads 70 | 71 | head_dim = dim // num_heads 72 | 73 | self.scale = qk_scale or head_dim ** -0.5 74 | 75 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 76 | self.attn_drop = nn.Dropout(attn_drop) 77 | 78 | self.proj = nn.Linear(dim, dim) 79 | 80 | self.proj_l = nn.Linear(num_heads, num_heads) 81 | self.proj_w = nn.Linear(num_heads, num_heads) 82 | 83 | self.proj_drop = nn.Dropout(proj_drop) 84 | 85 | self.rpe = rpe 86 | if self.rpe: 87 | relative_position_index = \ 88 | generate_2d_concatenated_self_attention_relative_positional_encoding_index([z_size, z_size], 89 | [x_size, x_size]) 90 | self.register_buffer("relative_position_index", relative_position_index) 91 | # define a parameter table of relative position bias 92 | self.relative_position_bias_table = nn.Parameter(torch.empty((num_heads, 93 | relative_position_index.max() + 1))) 94 | trunc_normal_(self.relative_position_bias_table, std=0.02) 95 | 96 | def forward(self, x, mask=None): 97 | B, N, C = x.shape 98 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 99 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] 100 | 101 | attn = (q @ k.transpose(-2, -1)) 102 | 103 | if self.rpe: 104 | relative_position_bias = self.relative_position_bias_table[:, self.relative_position_index].unsqueeze(0) 105 | attn += relative_position_bias 106 | 107 | if mask is not None: 108 | attn = attn.masked_fill(mask.unsqueeze(1).unsqueeze(2), 109 | float('-inf'),) 110 | 111 | attn = self.proj_l(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) 112 | 113 | attn = attn.softmax(dim=-1) 114 | 115 | attn = self.proj_w(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) 116 | attn = self.attn_drop(attn) 117 | 118 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 119 | x = self.proj(x) 120 | x = self.proj_drop(x) 121 | return x -------------------------------------------------------------------------------- /model/ostrack/ostrack.py: -------------------------------------------------------------------------------- 1 | """ 2 | Basic OSTrack model. 
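A minimal usage sketch (illustrative only: the names come from build_ostrack() and
OSTrack.forward() defined below, and the tensor sizes follow the 192/384 transforms
used in util.py of this repository):

    model = build_ostrack(cfg, training=False).eval()
    out = model(template, search)   # template: [B, 3, 192, 192], search: [B, 3, 384, 384]
    boxes = out['pred_boxes']       # [B, Nq, 4], normalized (cx, cy, w, h)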
3 | """ 4 | import math 5 | import os 6 | from typing import List 7 | 8 | import torch 9 | from torch import nn 10 | from torch.nn.modules.transformer import _get_clones 11 | 12 | from .layers.head import build_box_head 13 | from .vit import vit_base_patch16_224 14 | from .vit_ce import vit_large_patch16_224_ce, vit_base_patch16_224_ce 15 | from .utils.box_ops import box_xyxy_to_cxcywh 16 | 17 | 18 | class OSTrack(nn.Module): 19 | """ This is the base class for OSTrack """ 20 | 21 | def __init__(self, transformer, box_head, aux_loss=False, head_type="CORNER"): 22 | """ Initializes the model. 23 | Parameters: 24 | transformer: torch module of the transformer architecture. 25 | aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. 26 | """ 27 | super().__init__() 28 | self.backbone = transformer 29 | self.box_head = box_head 30 | 31 | self.aux_loss = aux_loss 32 | self.head_type = head_type 33 | if head_type == "CORNER" or head_type == "CENTER": 34 | self.feat_sz_s = int(box_head.feat_sz) 35 | self.feat_len_s = int(box_head.feat_sz ** 2) 36 | 37 | if self.aux_loss: 38 | self.box_head = _get_clones(self.box_head, 6) 39 | 40 | def forward(self, template: torch.Tensor, 41 | search: torch.Tensor, 42 | ce_template_mask=None, 43 | ce_keep_rate=None, 44 | return_last_attn=False, 45 | ): 46 | x, aux_dict = self.backbone(z=template, x=search, 47 | ce_template_mask=ce_template_mask, 48 | ce_keep_rate=ce_keep_rate, 49 | return_last_attn=return_last_attn, ) 50 | 51 | # Forward head 52 | feat_last = x 53 | if isinstance(x, list): 54 | feat_last = x[-1] 55 | out = self.forward_head(feat_last, None) 56 | 57 | out.update(aux_dict) 58 | out['backbone_feat'] = x 59 | return out 60 | 61 | def forward_head(self, cat_feature, gt_score_map=None): 62 | """ 63 | cat_feature: output embeddings of the backbone, it can be (HW1+HW2, B, C) or (HW2, B, C) 64 | """ 65 | enc_opt = cat_feature[:, -self.feat_len_s:] # encoder output for the search region (B, HW, C) 66 | opt = (enc_opt.unsqueeze(-1)).permute((0, 3, 2, 1)).contiguous() 67 | bs, Nq, C, HW = opt.size() 68 | opt_feat = opt.view(-1, C, self.feat_sz_s, self.feat_sz_s) 69 | 70 | if self.head_type == "CORNER": 71 | # run the corner head 72 | pred_box, score_map = self.box_head(opt_feat, True) 73 | outputs_coord = box_xyxy_to_cxcywh(pred_box) 74 | outputs_coord_new = outputs_coord.view(bs, Nq, 4) 75 | out = {'pred_boxes': outputs_coord_new, 76 | 'score_map': score_map, 77 | } 78 | return out 79 | 80 | elif self.head_type == "CENTER": 81 | # run the center head 82 | score_map_ctr, bbox, size_map, offset_map = self.box_head(opt_feat, gt_score_map) 83 | # outputs_coord = box_xyxy_to_cxcywh(bbox) 84 | outputs_coord = bbox 85 | outputs_coord_new = outputs_coord.view(bs, Nq, 4) 86 | out = {'pred_boxes': outputs_coord_new, 87 | 'score_map': score_map_ctr, 88 | 'size_map': size_map, 89 | 'offset_map': offset_map} 90 | return out 91 | else: 92 | raise NotImplementedError 93 | 94 | 95 | def build_ostrack(cfg, training=True): 96 | current_dir = os.path.dirname(os.path.abspath(__file__)) # This is your Project Root 97 | pretrained_path = os.path.join(current_dir, '../../../pretrained_models') 98 | if cfg['MODEL']['PRETRAIN_FILE'] and ('OSTrack' not in cfg['MODEL']['PRETRAIN_FILE']) and training: 99 | pretrained = os.path.join(pretrained_path, cfg['MODEL']['PRETRAIN_FILE']) 100 | else: 101 | pretrained = '' 102 | if cfg['MODEL']['BACKBONE']['TYPE'] == 'vit_base_patch16_224': 103 | backbone = vit_base_patch16_224(pretrained, 
drop_path_rate=cfg['TRAIN']['DROP_PATH_RATE']) 104 | hidden_dim = backbone.embed_dim 105 | patch_start_index = 1 106 | 107 | elif cfg['MODEL']['BACKBONE']['TYPE'] == 'vit_base_patch16_224_ce': 108 | backbone = vit_base_patch16_224_ce(pretrained, drop_path_rate=cfg['TRAIN']['DROP_PATH_RATE'], 109 | ce_loc=cfg['MODEL']['BACKBONE']['CE_LOC'], 110 | ce_keep_ratio=cfg['MODEL']['BACKBONE']['CE_KEEP_RATIO'], 111 | ) 112 | hidden_dim = backbone.embed_dim 113 | patch_start_index = 1 114 | 115 | elif cfg['MODEL']['BACKBONE']['TYPE'] == 'vit_large_patch16_224_ce': 116 | backbone = vit_large_patch16_224_ce(pretrained, drop_path_rate=cfg['TRAIN']['DROP_PATH_RATE'], 117 | ce_loc=cfg['MODEL']['BACKBONE']['CE_LOC'], 118 | ce_keep_ratio=cfg['MODEL']['BACKBONE']['CE_KEEP_RATIO'], 119 | ) 120 | 121 | hidden_dim = backbone.embed_dim 122 | patch_start_index = 1 123 | 124 | else: 125 | raise NotImplementedError 126 | 127 | backbone.finetune_track(cfg=cfg, patch_start_index=patch_start_index) 128 | 129 | box_head = build_box_head(cfg, hidden_dim) 130 | 131 | model = OSTrack( 132 | backbone, 133 | box_head, 134 | aux_loss=False, 135 | head_type=cfg['MODEL']['HEAD']['TYPE'], 136 | ) 137 | 138 | if 'OSTrack' in cfg['MODEL']['PRETRAIN_FILE'] and training: 139 | checkpoint = torch.load(cfg['MODEL']['PRETRAIN_FILE'], map_location="cpu") 140 | missing_keys, unexpected_keys = model.load_state_dict(checkpoint["net"], strict=False) 141 | print('Load pretrained model from: ' + cfg['MODEL']['PRETRAIN_FILE']) 142 | 143 | return model 144 | -------------------------------------------------------------------------------- /model/ostrack/layers/attn_blocks.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from timm.models.layers import Mlp, DropPath, trunc_normal_, lecun_normal_ 5 | 6 | from .attn import Attention 7 | 8 | 9 | def candidate_elimination(attn: torch.Tensor, tokens: torch.Tensor, lens_t: int, keep_ratio: float, global_index: torch.Tensor, box_mask_z: torch.Tensor): 10 | """ 11 | Eliminate potential background candidates for computation reduction and noise cancellation. 
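    In short: the template-to-search block of the attention map, attn[:, :, :lens_t, lens_t:],
    is averaged over heads and template tokens (optionally restricted to the template box via
    box_mask_z) to score each search token, and only the top ceil(keep_ratio * L_s) search
    tokens are kept and re-appended after the template tokens.

    Illustrative shapes (an assumption for this sketch: a 128x128 template and a 256x256 search
    region with stride-16 patches, i.e. lens_t = 64 and L_s = 256): with keep_ratio = 0.7,
    ceil(0.7 * 256) = 180 search tokens survive, so tokens_new has shape [B, 64 + 180, C].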
12 | Args: 13 | attn (torch.Tensor): [B, num_heads, L_t + L_s, L_t + L_s], attention weights 14 | tokens (torch.Tensor): [B, L_t + L_s, C], template and search region tokens 15 | lens_t (int): length of template 16 | keep_ratio (float): keep ratio of search region tokens (candidates) 17 | global_index (torch.Tensor): global index of search region tokens 18 | box_mask_z (torch.Tensor): template mask used to accumulate attention weights 19 | 20 | Returns: 21 | tokens_new (torch.Tensor): tokens after candidate elimination 22 | keep_index (torch.Tensor): indices of kept search region tokens 23 | removed_index (torch.Tensor): indices of removed search region tokens 24 | """ 25 | lens_s = attn.shape[-1] - lens_t 26 | bs, hn, _, _ = attn.shape 27 | 28 | lens_keep = math.ceil(keep_ratio * lens_s) 29 | if lens_keep == lens_s: 30 | return tokens, global_index, None 31 | 32 | attn_t = attn[:, :, :lens_t, lens_t:] 33 | 34 | if box_mask_z is not None: 35 | box_mask_z = box_mask_z.unsqueeze(1).unsqueeze(-1).expand(-1, attn_t.shape[1], -1, attn_t.shape[-1]) 36 | # attn_t = attn_t[:, :, box_mask_z, :] 37 | attn_t = attn_t[box_mask_z] 38 | attn_t = attn_t.view(bs, hn, -1, lens_s) 39 | attn_t = attn_t.mean(dim=2).mean(dim=1) # B, H, L-T, L_s --> B, L_s 40 | 41 | # attn_t = [attn_t[i, :, box_mask_z[i, :], :] for i in range(attn_t.size(0))] 42 | # attn_t = [attn_t[i].mean(dim=1).mean(dim=0) for i in range(len(attn_t))] 43 | # attn_t = torch.stack(attn_t, dim=0) 44 | else: 45 | attn_t = attn_t.mean(dim=2).mean(dim=1) # B, H, L-T, L_s --> B, L_s 46 | 47 | # use sort instead of topk, due to the speed issue 48 | # https://github.com/pytorch/pytorch/issues/22812 49 | sorted_attn, indices = torch.sort(attn_t, dim=1, descending=True) 50 | 51 | topk_attn, topk_idx = sorted_attn[:, :lens_keep], indices[:, :lens_keep] 52 | non_topk_attn, non_topk_idx = sorted_attn[:, lens_keep:], indices[:, lens_keep:] 53 | 54 | keep_index = global_index.gather(dim=1, index=topk_idx) 55 | removed_index = global_index.gather(dim=1, index=non_topk_idx) 56 | 57 | # separate template and search tokens 58 | tokens_t = tokens[:, :lens_t] 59 | tokens_s = tokens[:, lens_t:] 60 | 61 | # obtain the attentive and inattentive tokens 62 | B, L, C = tokens_s.shape 63 | # topk_idx_ = topk_idx.unsqueeze(-1).expand(B, lens_keep, C) 64 | attentive_tokens = tokens_s.gather(dim=1, index=topk_idx.unsqueeze(-1).expand(B, -1, C)) 65 | # inattentive_tokens = tokens_s.gather(dim=1, index=non_topk_idx.unsqueeze(-1).expand(B, -1, C)) 66 | 67 | # compute the weighted combination of inattentive tokens 68 | # fused_token = non_topk_attn @ inattentive_tokens 69 | 70 | # concatenate these tokens 71 | # tokens_new = torch.cat([tokens_t, attentive_tokens, fused_token], dim=0) 72 | tokens_new = torch.cat([tokens_t, attentive_tokens], dim=1) 73 | 74 | return tokens_new, keep_index, removed_index 75 | 76 | 77 | class CEBlock(nn.Module): 78 | 79 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., 80 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, keep_ratio_search=1.0,): 81 | super().__init__() 82 | self.norm1 = norm_layer(dim) 83 | self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) 84 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 85 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 86 | self.norm2 = norm_layer(dim) 87 | mlp_hidden_dim = int(dim * mlp_ratio) 88 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 89 | 90 | self.keep_ratio_search = keep_ratio_search 91 | 92 | def forward(self, x, global_index_template, global_index_search, mask=None, ce_template_mask=None, keep_ratio_search=None): 93 | x_attn, attn = self.attn(self.norm1(x), mask, True) 94 | x = x + self.drop_path(x_attn) 95 | lens_t = global_index_template.shape[1] 96 | 97 | removed_index_search = None 98 | if self.keep_ratio_search < 1 and (keep_ratio_search is None or keep_ratio_search < 1): 99 | keep_ratio_search = self.keep_ratio_search if keep_ratio_search is None else keep_ratio_search 100 | x, global_index_search, removed_index_search = candidate_elimination(attn, x, lens_t, keep_ratio_search, global_index_search, ce_template_mask) 101 | 102 | x = x + self.drop_path(self.mlp(self.norm2(x))) 103 | return x, global_index_template, global_index_search, removed_index_search, attn 104 | 105 | 106 | class Block(nn.Module): 107 | 108 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., 109 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 110 | super().__init__() 111 | self.norm1 = norm_layer(dim) 112 | self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) 113 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 114 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 115 | self.norm2 = norm_layer(dim) 116 | mlp_hidden_dim = int(dim * mlp_ratio) 117 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 118 | 119 | def forward(self, x, mask=None): 120 | x = x + self.drop_path(self.attn(self.norm1(x), mask)) 121 | x = x + self.drop_path(self.mlp(self.norm2(x))) 122 | return x 123 | -------------------------------------------------------------------------------- /yolo_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | TODO: 使用YOLOv5进行无人机图像检测定位测试, 3 | 数据集采用红外影像数据./datasets/test1 4 | 进行分割重构数据集 5 | Time: 2025/03/07-Redal 6 | """ 7 | import pathlib 8 | temp = pathlib.PosixPath 9 | pathlib.PosixPath = pathlib.WindowsPath 10 | 11 | import os 12 | import cv2 13 | import argparse 14 | import sys 15 | import torch 16 | import numpy as np 17 | import matplotlib.pyplot as plt 18 | from yolov5.models.experimental import attempt_load 19 | from bbox import decoder 20 | 21 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 22 | proj_path = os.getcwd() 23 | sys.path.append(os.path.join(proj_path, "yolov5")) 24 | sys.path.append(os.path.abspath(os.path.join(os.getcwd(), ".."))) 25 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 26 | 27 | 28 | 29 | ############################## 配置解析文件变量 ############################### 30 | def parser(): 31 | parser = argparse.ArgumentParser(description='YOLOv5 inference config', 32 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 33 | parser.add_argument_group('YOLOv5 Dataset Reconstruction') 34 | parser.add_argument('--gt_dir', type=str, default=r'datasets\images', help='ground truth file directory path(s)') 35 | parser.add_argument('--gt_mode', type=str, default=r'train', help='ground truth file mode [train, val]') 36 | parser.add_argument('--gt_file', type=str, default=r'groundtruth.txt', help='ground truth file name') 37 | 
parser.add_argument('--label_dir', type=str, default=r'datasets\labels', help='labels file directory path(s)') 38 | parser.add_argument('--label_mode', type=str, default=r'train', help='labels file mode [train, val]') 39 | parser.add_argument('--origin_dir', type=str, default=r'datasets\test', help='images file directory path(s)') 40 | # related to YOLOv5 model and weights 41 | parser.add_argument_group('YOLOv5 Model Configuration') 42 | parser.add_argument('--weights_dir', type=str, default=r'weights', help='weights file directory path(s)') 43 | parser.add_argument('--weights_file', type=str, default=r'yolov5l.pt', help='weights file name yolov5s/m/l/x.pt') 44 | parser.add_argument('--yolo_dir', type=str, default=r'yolov5', help='yolov5 file directory path(s)') 45 | args = parser.parse_args() 46 | return args 47 | 48 | 49 | 50 | ############################## 重构Got_10k数据集 ############################### 51 | def split_groundtruth_to_individual_labels(args): 52 | """Split the label information in the groundtruth.txt file into the txt file for each image 53 | :param input image is 640x512 pixels width x height""" 54 | gt_filepath = os.path.join(args.gt_dir, args.gt_mode, args.gt_file) 55 | output_folder = os.path.join(args.label_dir, args.label_mode) 56 | img_w = 640; img_h = 512 57 | with open(gt_filepath, 'r') as f: 58 | lines = f.readlines() 59 | for idx,line in enumerate(lines): 60 | parts = line.strip().split(',') 61 | lr_w, lr_h = float(parts[0]), float(parts[1]) 62 | box_w, box_h = float(parts[2]), float(parts[3]) 63 | # normalization 64 | cneter_w, center_h = (lr_w + box_w / 2) / img_w , (lr_h + box_h / 2) / img_h 65 | bounding_w, bounding_h = box_w / img_w, box_h / img_h 66 | txt_file_path = os.path.join(output_folder, f'{idx:08d}.txt') 67 | with open(txt_file_path, 'w') as txt_file: 68 | txt_file.write(f'{0} {cneter_w} {center_h} {bounding_w} {bounding_h}\n') 69 | print(f'{txt_file_path} has been written', end='\r', flush=True) 70 | 71 | 72 | def rebuild_data(args): 73 | """Rebuild the images/labels dataset form test directory""" 74 | img_w = 640; img_h = 512 75 | origin_dir = args.origin_dir 76 | original_dirspath = [os.path.join(origin_dir, dirname) for dirname in os.listdir(origin_dir)] 77 | for idx, dirpath in enumerate(original_dirspath): 78 | with open(os.path.join(dirpath, 'groundtruth.txt')) as gt: 79 | gt_lines = gt.readlines() 80 | if idx % 2 == 0: 81 | args.gt_mode = 'train' 82 | args.label_mode = 'train' 83 | else: 84 | args.gt_mode = 'val' 85 | args.label_mode = 'val' 86 | for id,line in enumerate(gt_lines): 87 | num_files = len(os.listdir(os.path.join(args.gt_dir, args.gt_mode))) 88 | # read image and save it to train images folder 89 | img_path = os.path.join(dirpath, f'{id:08d}.jpg') 90 | img = cv2.imread(img_path) 91 | img_save_path = os.path.join(args.gt_dir, args.gt_mode, f'{num_files:08d}.jpg') 92 | cv2.imwrite(img_save_path, img) 93 | # read the groundtruth.txt file into train labels folder 94 | parts = line.strip().split(',') 95 | lr_w, lr_h = float(parts[0]), float(parts[1]) 96 | box_w, box_h = float(parts[2]), float(parts[3]) 97 | cneter_w, center_h = (lr_w + box_w / 2) / img_w , (lr_h + box_h / 2) / img_h 98 | bounding_w, bounding_h = box_w / img_w, box_h / img_h 99 | txt_filepath = os.path.join(args.label_dir, args.label_mode, f'{num_files:08d}.txt') 100 | with open(txt_filepath, 'w') as txtf: 101 | txtf.write(f'{0} {cneter_w} {center_h} {bounding_w} {bounding_h}\n') 102 | print(f'{dirpath}: {txt_filepath} and {img_save_path} has been written', end='\r', 
flush=True) 103 | 104 | 105 | 106 | ############################## Deploy the YOLOv5 model ############################### 107 | def load_yolo(args): 108 | """Load the yolov5 model weights and return the model 109 | :param args.weights_dir: weights file directory path(s)""" 110 | weights_fp = os.path.join(args.weights_dir, args.weights_file) 111 | print(f'Loading weights from {weights_fp}; yolov5 repo directory is {args.yolo_dir}') 112 | yolo_model = attempt_load(weights=weights_fp, device=device) 113 | return yolo_model.eval() 114 | 115 | 116 | 117 | ############################## Main test and analysis ############################### 118 | if __name__ == '__main__': 119 | args = parser() 120 | model = load_yolo(args).to(device) 121 | # print(model) 122 | 123 | img0 = cv2.imread(r'assets\uav_1.jpg') 124 | img_de, _ = decoder(model, img0) # decoder returns (annotated frame, last xyxy box) 125 | cv2.imshow('img', img_de) 126 | cv2.waitKey(0) 127 | cv2.destroyAllWindows() -------------------------------------------------------------------------------- /model/ostrack/utils/heapmap_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def generate_heatmap(bboxes, patch_size=320, stride=16): 6 | """ 7 | Generate ground truth heatmap same as CenterNet 8 | Args: 9 | bboxes (torch.Tensor): shape of [num_search, bs, 4] 10 | 11 | Returns: 12 | gaussian_maps: list of generated heatmap 13 | 14 | """ 15 | gaussian_maps = [] 16 | heatmap_size = patch_size // stride 17 | for single_patch_bboxes in bboxes: 18 | bs = single_patch_bboxes.shape[0] 19 | gt_scoremap = torch.zeros(bs, heatmap_size, heatmap_size) 20 | classes = torch.arange(bs).to(torch.long) 21 | bbox = single_patch_bboxes * heatmap_size 22 | wh = bbox[:, 2:] 23 | centers_int = (bbox[:, :2] + wh / 2).round() 24 | CenterNetHeatMap.generate_score_map(gt_scoremap, classes, wh, centers_int, 0.7) # write the generated Gaussian heatmap into gt_scoremap 25 | gaussian_maps.append(gt_scoremap.to(bbox.device)) 26 | return gaussian_maps 27 | 28 | 29 | class CenterNetHeatMap(object): 30 | @staticmethod 31 | def generate_score_map(fmap, gt_class, gt_wh, centers_int, min_overlap): 32 | radius = CenterNetHeatMap.get_gaussian_radius(gt_wh, min_overlap) # radius of the Gaussian kernel 33 | radius = torch.clamp_min(radius, 0) 34 | radius = radius.type(torch.int).cpu().numpy() 35 | for i in range(gt_class.shape[0]): 36 | channel_index = gt_class[i] 37 | CenterNetHeatMap.draw_gaussian(fmap[channel_index], centers_int[i], radius[i]) # draw the Gaussian around the center with the given radius and write it into fmap 38 | 39 | @staticmethod 40 | def get_gaussian_radius(box_size, min_overlap): 41 | """ 42 | copied from CornerNet 43 | box_size (w, h), it could be a torch.Tensor, numpy.ndarray, list or tuple 44 | notice: we are using the buggy version; please refer to the fixed version in CornerNet 45 | """ 46 | # box_tensor = torch.Tensor(box_size) 47 | box_tensor = box_size 48 | width, height = box_tensor[..., 0], box_tensor[..., 1] 49 | 50 | a1 = 1 51 | b1 = height + width 52 | c1 = width * height * (1 - min_overlap) / (1 + min_overlap) 53 | sq1 = torch.sqrt(b1 ** 2 - 4 * a1 * c1) 54 | r1 = (b1 + sq1) / 2 55 | 56 | a2 = 4 57 | b2 = 2 * (height + width) 58 | c2 = (1 - min_overlap) * width * height 59 | sq2 = torch.sqrt(b2 ** 2 - 4 * a2 * c2) 60 | r2 = (b2 + sq2) / 2 61 | 62 | a3 = 4 * min_overlap 63 | b3 = -2 * min_overlap * (height + width) 64 | c3 = (min_overlap - 1) * width * height 65 | sq3 = torch.sqrt(b3 ** 2 - 4 * a3 * c3) 66 | r3 = (b3 + sq3) / 2 67 | 68 | return torch.min(r1, torch.min(r2, r3)) 69 | 70 | @staticmethod 71 | def gaussian2D(radius, sigma=1): 72 | # m, n = [(s - 
1.) / 2. for s in shape] 73 | m, n = radius 74 | y, x = np.ogrid[-m: m + 1, -n: n + 1] # 生成网格,y和x代表两个方向 75 | 76 | gauss = np.exp(-(x * x + y * y) / (2 * sigma * sigma)) # 对应原点 77 | gauss[gauss < np.finfo(gauss.dtype).eps * gauss.max()] = 0 # 将高斯核中值非常接近于零的元素设置为零 78 | return gauss 79 | 80 | @staticmethod 81 | def draw_gaussian(fmap, center, radius, k=1): 82 | diameter = 2 * radius + 1 83 | gaussian = CenterNetHeatMap.gaussian2D((radius, radius), sigma=diameter / 6) # 生成二维高斯核,没有做归一化处理,只有中间的那个数为1 84 | gaussian = torch.Tensor(gaussian) 85 | x, y = int(center[0]), int(center[1]) 86 | height, width = fmap.shape[:2] 87 | 88 | left, right = min(x, radius), min(width - x, radius + 1) 89 | top, bottom = min(y, radius), min(height - y, radius + 1) 90 | 91 | masked_fmap = fmap[y - top: y + bottom, x - left: x + right] 92 | masked_gaussian = gaussian[radius - top: radius + bottom, radius - left: radius + right] 93 | if min(masked_gaussian.shape) > 0 and min(masked_fmap.shape) > 0: 94 | masked_fmap = torch.max(masked_fmap, masked_gaussian * k) 95 | fmap[y - top: y + bottom, x - left: x + right] = masked_fmap 96 | # return fmap 97 | 98 | 99 | def compute_grids(features, strides): 100 | """ 101 | grids regret to the input image size 102 | """ 103 | grids = [] 104 | for level, feature in enumerate(features): 105 | h, w = feature.size()[-2:] 106 | shifts_x = torch.arange( 107 | 0, w * strides[level], 108 | step=strides[level], 109 | dtype=torch.float32, device=feature.device) 110 | shifts_y = torch.arange( 111 | 0, h * strides[level], 112 | step=strides[level], 113 | dtype=torch.float32, device=feature.device) 114 | shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x) 115 | shift_x = shift_x.reshape(-1) 116 | shift_y = shift_y.reshape(-1) 117 | grids_per_level = torch.stack((shift_x, shift_y), dim=1) + \ 118 | strides[level] // 2 119 | grids.append(grids_per_level) 120 | return grids 121 | 122 | 123 | def get_center3x3(locations, centers, strides, range=3): 124 | ''' 125 | Inputs: 126 | locations: M x 2 127 | centers: N x 2 128 | strides: M 129 | ''' 130 | range = (range - 1) / 2 131 | M, N = locations.shape[0], centers.shape[0] 132 | locations_expanded = locations.view(M, 1, 2).expand(M, N, 2) # M x N x 2 133 | centers_expanded = centers.view(1, N, 2).expand(M, N, 2) # M x N x 2 134 | strides_expanded = strides.view(M, 1, 1).expand(M, N, 2) # M x N 135 | centers_discret = ((centers_expanded / strides_expanded).int() * strides_expanded).float() + \ 136 | strides_expanded / 2 # M x N x 2 137 | dist_x = (locations_expanded[:, :, 0] - centers_discret[:, :, 0]).abs() 138 | dist_y = (locations_expanded[:, :, 1] - centers_discret[:, :, 1]).abs() 139 | return (dist_x <= strides_expanded[:, :, 0] * range) & \ 140 | (dist_y <= strides_expanded[:, :, 0] * range) 141 | 142 | 143 | def get_pred(score_map_ctr, size_map, offset_map, feat_size): 144 | max_score, idx = torch.max(score_map_ctr.flatten(1), dim=1, keepdim=True) 145 | 146 | idx = idx.unsqueeze(1).expand(idx.shape[0], 2, 1) 147 | size = size_map.flatten(2).gather(dim=2, index=idx).squeeze(-1) 148 | offset = offset_map.flatten(2).gather(dim=2, index=idx).squeeze(-1) 149 | 150 | return size * feat_size, offset 151 | -------------------------------------------------------------------------------- /model/ostrack/base_backbone.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from timm.models.vision_transformer 
import resize_pos_embed 6 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 7 | 8 | from .layers.patch_embed import PatchEmbed 9 | from .util import combine_tokens, recover_tokens 10 | 11 | 12 | class BaseBackbone(nn.Module): 13 | def __init__(self): 14 | super().__init__() 15 | 16 | # for original ViT 17 | self.pos_embed = None 18 | self.img_size = [224, 224] 19 | self.patch_size = 16 20 | self.embed_dim = 384 21 | 22 | self.cat_mode = 'direct' 23 | 24 | self.pos_embed_z = None 25 | self.pos_embed_x = None 26 | 27 | self.template_segment_pos_embed = None 28 | self.search_segment_pos_embed = None 29 | 30 | self.return_inter = False 31 | self.return_stage = [2, 5, 8, 11] 32 | 33 | self.add_cls_token = False 34 | self.add_sep_seg = False 35 | 36 | def finetune_track(self, cfg, patch_start_index=1): 37 | 38 | search_size = to_2tuple(cfg['DATA']['SEARCH']['SIZE']) 39 | template_size = to_2tuple(cfg['DATA']['TEMPLATE']['SIZE']) 40 | new_patch_size = cfg['MODEL']['BACKBONE']['STRIDE'] 41 | 42 | self.cat_mode = cfg['MODEL']['BACKBONE']['CAT_MODE'] 43 | self.return_inter = cfg['MODEL']['RETURN_INTER'] 44 | self.return_stage = cfg['MODEL']['RETURN_STAGES'] 45 | self.add_sep_seg = cfg['MODEL']['BACKBONE']['SEP_SEG'] 46 | 47 | # resize patch embedding 48 | if new_patch_size != self.patch_size: 49 | print('Inconsistent Patch Size With The Pretrained Weights, Interpolate The Weight!') 50 | old_patch_embed = {} 51 | for name, param in self.patch_embed.named_parameters(): 52 | if 'weight' in name: 53 | param = nn.functional.interpolate(param, size=(new_patch_size, new_patch_size), 54 | mode='bicubic', align_corners=False) 55 | param = nn.Parameter(param) 56 | old_patch_embed[name] = param 57 | self.patch_embed = PatchEmbed(img_size=self.img_size, patch_size=new_patch_size, in_chans=3, 58 | embed_dim=self.embed_dim) 59 | self.patch_embed.proj.bias = old_patch_embed['proj.bias'] 60 | self.patch_embed.proj.weight = old_patch_embed['proj.weight'] 61 | 62 | # for patch embedding 63 | patch_pos_embed = self.pos_embed[:, patch_start_index:, :] 64 | patch_pos_embed = patch_pos_embed.transpose(1, 2) 65 | B, E, Q = patch_pos_embed.shape 66 | P_H, P_W = self.img_size[0] // self.patch_size, self.img_size[1] // self.patch_size 67 | patch_pos_embed = patch_pos_embed.view(B, E, P_H, P_W) 68 | 69 | # for search region 70 | H, W = search_size 71 | new_P_H, new_P_W = H // new_patch_size, W // new_patch_size 72 | search_patch_pos_embed = nn.functional.interpolate(patch_pos_embed, size=(new_P_H, new_P_W), mode='bicubic', 73 | align_corners=False) 74 | search_patch_pos_embed = search_patch_pos_embed.flatten(2).transpose(1, 2) 75 | 76 | # for template region 77 | H, W = template_size 78 | new_P_H, new_P_W = H // new_patch_size, W // new_patch_size 79 | template_patch_pos_embed = nn.functional.interpolate(patch_pos_embed, size=(new_P_H, new_P_W), mode='bicubic', 80 | align_corners=False) 81 | template_patch_pos_embed = template_patch_pos_embed.flatten(2).transpose(1, 2) 82 | 83 | self.pos_embed_z = nn.Parameter(template_patch_pos_embed) 84 | self.pos_embed_x = nn.Parameter(search_patch_pos_embed) 85 | 86 | # for cls token (keep it but not used) 87 | if self.add_cls_token and patch_start_index > 0: 88 | cls_pos_embed = self.pos_embed[:, 0:1, :] 89 | self.cls_pos_embed = nn.Parameter(cls_pos_embed) 90 | 91 | # separate token and segment token 92 | if self.add_sep_seg: 93 | self.template_segment_pos_embed = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) 94 | self.template_segment_pos_embed = 
trunc_normal_(self.template_segment_pos_embed, std=.02) 95 | self.search_segment_pos_embed = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) 96 | self.search_segment_pos_embed = trunc_normal_(self.search_segment_pos_embed, std=.02) 97 | 98 | # self.cls_token = None 99 | # self.pos_embed = None 100 | 101 | if self.return_inter: 102 | for i_layer in self.return_stage: 103 | if i_layer != 11: 104 | norm_layer = partial(nn.LayerNorm, eps=1e-6) 105 | layer = norm_layer(self.embed_dim) 106 | layer_name = f'norm{i_layer}' 107 | self.add_module(layer_name, layer) 108 | 109 | def forward_features(self, z, x): 110 | B, H, W = x.shape[0], x.shape[2], x.shape[3] 111 | 112 | x = self.patch_embed(x) 113 | z = self.patch_embed(z) 114 | 115 | if self.add_cls_token: 116 | cls_tokens = self.cls_token.expand(B, -1, -1) 117 | cls_tokens = cls_tokens + self.cls_pos_embed 118 | 119 | z += self.pos_embed_z 120 | x += self.pos_embed_x 121 | 122 | if self.add_sep_seg: 123 | x += self.search_segment_pos_embed 124 | z += self.template_segment_pos_embed 125 | 126 | x = combine_tokens(z, x, mode=self.cat_mode) 127 | if self.add_cls_token: 128 | x = torch.cat([cls_tokens, x], dim=1) 129 | 130 | x = self.pos_drop(x) 131 | 132 | for i, blk in enumerate(self.blocks): 133 | x = blk(x) 134 | 135 | lens_z = self.pos_embed_z.shape[1] 136 | lens_x = self.pos_embed_x.shape[1] 137 | x = recover_tokens(x, lens_z, lens_x, mode=self.cat_mode) 138 | 139 | aux_dict = {"attn": None} 140 | return self.norm(x), aux_dict 141 | 142 | def forward(self, z, x, **kwargs): 143 | """ 144 | Joint feature extraction and relation modeling for the basic ViT backbone. 145 | Args: 146 | z (torch.Tensor): template feature, [B, C, H_z, W_z] 147 | x (torch.Tensor): search region feature, [B, C, H_x, W_x] 148 | 149 | Returns: 150 | x (torch.Tensor): merged template and search region feature, [B, L_z+L_x, C] 151 | attn : None 152 | """ 153 | x, aux_dict = self.forward_features(z, x,) 154 | 155 | return x, aux_dict 156 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | """ 2 | 任务: 创建基本的OSTrack模型所需的实例已经辅助函数, 3 | 同时包括相关的机器视觉处理函数 4 | 时间: 2025/01/13-Redal 5 | """ 6 | import os 7 | import sys 8 | import cv2 9 | import yaml 10 | import torch 11 | import argparse 12 | import numpy as np 13 | from PIL import Image, ImageDraw 14 | from torchvision.transforms import transforms 15 | from model.ostrack.ostrack import build_ostrack 16 | import matplotlib.pyplot as plt 17 | 18 | current_path = os.path.abspath(os.path.dirname(__file__)) 19 | sys.path.append(current_path) 20 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 21 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 22 | template_transform = transforms.Compose([ 23 | transforms.ToTensor(), 24 | transforms.Resize((192, 192)), 25 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),]) 26 | search_transform = transforms.Compose([ 27 | transforms.ToTensor(), 28 | transforms.Resize((384, 384)), 29 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),]) 30 | 31 | 32 | 33 | ############################### tkinter GUI相关绘图函数 ############################ 34 | def ComputeHistogramImage(frame, hist_height=200, hist_width=300): 35 | """使用 OpenCV 绘制 RGB 直方图 36 | :param frame: 输入帧(BGR 格式) 37 | :param hist_height: 直方图图像的高度 38 | :param hist_width: 直方图图像的宽度 39 | :return: 直方图图像 40 | """ 41 | # 创建灰色背景图像 42 | hist_image = np.full((hist_height, 
hist_width, 3), fill_value=0, dtype=np.uint8) 43 | colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255)] 44 | 45 | # 计算每个通道的直方图 46 | for i, color in enumerate(colors): 47 | hist = cv2.calcHist([frame], [i], None, [256], [0, 256]) 48 | cv2.normalize(hist, hist, 0, hist_height - 20, cv2.NORM_MINMAX) # 留出空间给坐标轴 49 | for j in range(1, 256): 50 | # 约束直方图在图像范围内 51 | x1 = (j - 1) * (hist_width // 256) 52 | y1 = hist_height - 10 - int(hist[j - 1]) # 留出空间给横轴 53 | x2 = j * (hist_width // 256) 54 | y2 = hist_height - 10 - int(hist[j]) # 留出空间给横轴 55 | cv2.line(hist_image, (x1, y1), (x2, y2), color, thickness=2) 56 | cv2.line(hist_image, (0, hist_height - 10), (hist_width, hist_height - 10), (0, 0, 0), thickness=2) # 横轴 57 | cv2.line(hist_image, (10, 0), (10, hist_height - 10), (0, 0, 0), thickness=2) 58 | 59 | # 添加横轴刻度线和标签 60 | for i in range(0, 256, 32): 61 | x = i * (hist_width // 256) 62 | cv2.line(hist_image, (x, hist_height - 10), (x, hist_height - 5), (255, 255, 255), thickness=1) 63 | cv2.putText(hist_image, str(i), (x - 10, hist_height - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1) 64 | # 添加纵轴刻度线和标签 65 | for i in range(0, hist_height - 10, 50): 66 | y = hist_height - 10 - i 67 | cv2.line(hist_image, (10, y), (15, y), (255, 255, 255), thickness=1) 68 | cv2.putText(hist_image, str(i), (20, y + 5), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1) 69 | return hist_image 70 | 71 | 72 | def CalculateSpectrogramImage(frame, spec_height=200, spec_width=200): 73 | """使用 OpenCV 绘制频谱图 74 | :param frame: 输入帧(BGR 格式) 75 | :param spec_height: 频谱图图像的高度 76 | :param spec_width: 频谱图图像的宽度 77 | :return: 频谱图图像 78 | """ 79 | gray_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) 80 | fft = np.fft.fft2(gray_frame) 81 | fft_shift = np.fft.fftshift(fft) # 将低频部分移到中心 82 | magnitude_spectrum = np.log(np.abs(fft_shift) + 1) # 计算幅度谱并取对数 83 | 84 | magnitude_spectrum = cv2.normalize(magnitude_spectrum, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U) 85 | spectrogram_image = cv2.resize(magnitude_spectrum, (spec_width, spec_height)) 86 | 87 | # 频谱图美化:应用伪彩色映射 88 | spectrogram_color = cv2.applyColorMap(spectrogram_image, cv2.COLORMAP_JET) 89 | grid_color = (255, 255, 255) 90 | grid_spacing = 50 91 | for x in range(0, spec_width, grid_spacing): 92 | cv2.line(spectrogram_color, (x, 0), (x, spec_height), grid_color, 1) 93 | for y in range(0, spec_height, grid_spacing): 94 | cv2.line(spectrogram_color, (0, y), (spec_width, y), grid_color, 1) 95 | # 添加标题和坐标轴标签 96 | title = "Spectrogram" 97 | cv2.putText(spectrogram_color, title, (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1) 98 | cv2.putText(spectrogram_color, "Frequency", (10, spec_height - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 0), 1) 99 | cv2.putText(spectrogram_color, "Time", (spec_width - 50, spec_height - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 0), 1) 100 | # 调整对比度和亮度 101 | alpha = 1.2 # 对比度系数 102 | beta = 30 # 亮度系数 103 | spectrogram_color = cv2.convertScaleAbs(spectrogram_color, alpha=alpha, beta=beta) 104 | return spectrogram_color 105 | 106 | 107 | ############################### 配置ostrack模型解析文件 ############################ 108 | def load_config(args): 109 | """read the configuration file""" 110 | config_path = os.path.join(args.config_dir, args.config_file) 111 | with open(config_path, 'r') as file: 112 | config = yaml.safe_load(file) 113 | return config 114 | 115 | class Params: 116 | """Load retrained model parameters""" 117 | def __init__(self, args): 118 | self.checkpoint = os.path.join(args.weight_dir, args.weight_file) 119 | self.debug = False 120 | 
self.save_all_boxes = False 121 | 122 | def config(): 123 | """build the OSTrack model from the config file and load its pretrained weights""" 124 | parser = argparse.ArgumentParser(description='OSTrack model configuration', 125 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 126 | parser.add_argument('--config_dir', default='./config', type=str, 127 | help='The directory of the configuration file') 128 | parser.add_argument('--config_file', default='vitb_384_mae_ce_32x4_got10k_ep100.yaml', type=str, 129 | help='The name of the configuration file') 130 | parser.add_argument('--weight_dir', default='./weights', type=str, 131 | help='the directory of the weight file') 132 | parser.add_argument('--weight_file', default='vit_384_mae_ce.pth', type=str, 133 | help='the name of the weight file OSTrack_ep0061.pth / vit_384_mae_ce.pth') 134 | args = parser.parse_args() 135 | # initialize the config and model weight 136 | cfg = load_config(args) 137 | ostrack_model = build_ostrack(cfg, training=False) 138 | weight_path = os.path.join(args.weight_dir, args.weight_file) 139 | ostrack_model.load_state_dict(torch.load(weight_path, map_location=device)) 140 | 141 | # params = Params(args) 142 | # ostrack_model = build_ostrack(cfg, training=False) 143 | # ostrack_model.load_state_dict(torch.load(params.checkpoint, map_location='cpu')['net'], strict=True) 144 | return ostrack_model 145 | 146 | 147 | 148 | ############################### Main test and analysis ################################ 149 | if __name__ == '__main__': 150 | # quick sanity check: run the tracker once on a template/search image pair; 151 | # in practice the template crop comes from the previous frame's bounding box 152 | ostrack_model = config().eval().to(device) 153 | print(ostrack_model) 154 | 155 | 156 | template_img = Image.open('assets/uav_1.jpg') 157 | search_img = Image.open('assets/uav_3.jpg') 158 | template_img = template_transform(template_img).unsqueeze(0).to(device) 159 | search_img = search_transform(search_img).unsqueeze(0).to(device) 160 | results = ostrack_model(template_img, search_img) 161 | answer = results['pred_boxes'][0] 162 | bbox = answer.detach().cpu().numpy()[0] 163 | # depict the result: pred_boxes is normalized (cx, cy, w, h), so convert it to pixel corners 164 | search_img = cv2.imread('assets/uav_3.jpg') 165 | print(results, bbox) 166 | height, width, _ = search_img.shape 167 | x_min = int((bbox[0] - bbox[2] / 2) * width) 168 | y_min = int((bbox[1] - bbox[3] / 2) * height) 169 | x_max = int((bbox[0] + bbox[2] / 2) * width) 170 | y_max = int((bbox[1] + bbox[3] / 2) * height) 171 | 172 | color = (255, 0, 0) # blue in BGR 173 | thickness = 2 174 | cv2.rectangle(search_img, (x_min, y_min), (x_max, y_max), color, thickness) 175 | # display the image with matplotlib 176 | plt.imshow(cv2.cvtColor(search_img, cv2.COLOR_BGR2RGB)) 177 | plt.axis('off') # hide the axes 178 | plt.show() 179 | cv2.imwrite('assets/uav_1_result.jpg', search_img) -------------------------------------------------------------------------------- /model/ostrack/utils/tensor.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch 3 | import copy 4 | from collections import OrderedDict 5 | 6 | 7 | class TensorDict(OrderedDict): 8 | """Container mainly used for dicts of torch tensors. 
Extends OrderedDict with pytorch functionality.""" 9 | 10 | def concat(self, other): 11 | """Concatenates two dicts without copying internal data.""" 12 | return TensorDict(self, **other) 13 | 14 | def copy(self): 15 | return TensorDict(super(TensorDict, self).copy()) 16 | 17 | def __deepcopy__(self, memodict={}): 18 | return TensorDict(copy.deepcopy(list(self), memodict)) 19 | 20 | def __getattr__(self, name): 21 | if not hasattr(torch.Tensor, name): 22 | raise AttributeError('\'TensorDict\' object has not attribute \'{}\''.format(name)) 23 | 24 | def apply_attr(*args, **kwargs): 25 | return TensorDict({n: getattr(e, name)(*args, **kwargs) if hasattr(e, name) else e for n, e in self.items()}) 26 | return apply_attr 27 | 28 | def attribute(self, attr: str, *args): 29 | return TensorDict({n: getattr(e, attr, *args) for n, e in self.items()}) 30 | 31 | def apply(self, fn, *args, **kwargs): 32 | return TensorDict({n: fn(e, *args, **kwargs) for n, e in self.items()}) 33 | 34 | @staticmethod 35 | def _iterable(a): 36 | return isinstance(a, (TensorDict, list)) 37 | 38 | 39 | class TensorList(list): 40 | """Container mainly used for lists of torch tensors. Extends lists with pytorch functionality.""" 41 | 42 | def __init__(self, list_of_tensors = None): 43 | if list_of_tensors is None: 44 | list_of_tensors = list() 45 | super(TensorList, self).__init__(list_of_tensors) 46 | 47 | def __deepcopy__(self, memodict={}): 48 | return TensorList(copy.deepcopy(list(self), memodict)) 49 | 50 | def __getitem__(self, item): 51 | if isinstance(item, int): 52 | return super(TensorList, self).__getitem__(item) 53 | elif isinstance(item, (tuple, list)): 54 | return TensorList([super(TensorList, self).__getitem__(i) for i in item]) 55 | else: 56 | return TensorList(super(TensorList, self).__getitem__(item)) 57 | 58 | def __add__(self, other): 59 | if TensorList._iterable(other): 60 | return TensorList([e1 + e2 for e1, e2 in zip(self, other)]) 61 | return TensorList([e + other for e in self]) 62 | 63 | def __radd__(self, other): 64 | if TensorList._iterable(other): 65 | return TensorList([e2 + e1 for e1, e2 in zip(self, other)]) 66 | return TensorList([other + e for e in self]) 67 | 68 | def __iadd__(self, other): 69 | if TensorList._iterable(other): 70 | for i, e2 in enumerate(other): 71 | self[i] += e2 72 | else: 73 | for i in range(len(self)): 74 | self[i] += other 75 | return self 76 | 77 | def __sub__(self, other): 78 | if TensorList._iterable(other): 79 | return TensorList([e1 - e2 for e1, e2 in zip(self, other)]) 80 | return TensorList([e - other for e in self]) 81 | 82 | def __rsub__(self, other): 83 | if TensorList._iterable(other): 84 | return TensorList([e2 - e1 for e1, e2 in zip(self, other)]) 85 | return TensorList([other - e for e in self]) 86 | 87 | def __isub__(self, other): 88 | if TensorList._iterable(other): 89 | for i, e2 in enumerate(other): 90 | self[i] -= e2 91 | else: 92 | for i in range(len(self)): 93 | self[i] -= other 94 | return self 95 | 96 | def __mul__(self, other): 97 | if TensorList._iterable(other): 98 | return TensorList([e1 * e2 for e1, e2 in zip(self, other)]) 99 | return TensorList([e * other for e in self]) 100 | 101 | def __rmul__(self, other): 102 | if TensorList._iterable(other): 103 | return TensorList([e2 * e1 for e1, e2 in zip(self, other)]) 104 | return TensorList([other * e for e in self]) 105 | 106 | def __imul__(self, other): 107 | if TensorList._iterable(other): 108 | for i, e2 in enumerate(other): 109 | self[i] *= e2 110 | else: 111 | for i in range(len(self)): 
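# operand is not a TensorList/list: apply the same factor to every element in place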
112 | self[i] *= other 113 | return self 114 | 115 | def __truediv__(self, other): 116 | if TensorList._iterable(other): 117 | return TensorList([e1 / e2 for e1, e2 in zip(self, other)]) 118 | return TensorList([e / other for e in self]) 119 | 120 | def __rtruediv__(self, other): 121 | if TensorList._iterable(other): 122 | return TensorList([e2 / e1 for e1, e2 in zip(self, other)]) 123 | return TensorList([other / e for e in self]) 124 | 125 | def __itruediv__(self, other): 126 | if TensorList._iterable(other): 127 | for i, e2 in enumerate(other): 128 | self[i] /= e2 129 | else: 130 | for i in range(len(self)): 131 | self[i] /= other 132 | return self 133 | 134 | def __matmul__(self, other): 135 | if TensorList._iterable(other): 136 | return TensorList([e1 @ e2 for e1, e2 in zip(self, other)]) 137 | return TensorList([e @ other for e in self]) 138 | 139 | def __rmatmul__(self, other): 140 | if TensorList._iterable(other): 141 | return TensorList([e2 @ e1 for e1, e2 in zip(self, other)]) 142 | return TensorList([other @ e for e in self]) 143 | 144 | def __imatmul__(self, other): 145 | if TensorList._iterable(other): 146 | for i, e2 in enumerate(other): 147 | self[i] @= e2 148 | else: 149 | for i in range(len(self)): 150 | self[i] @= other 151 | return self 152 | 153 | def __mod__(self, other): 154 | if TensorList._iterable(other): 155 | return TensorList([e1 % e2 for e1, e2 in zip(self, other)]) 156 | return TensorList([e % other for e in self]) 157 | 158 | def __rmod__(self, other): 159 | if TensorList._iterable(other): 160 | return TensorList([e2 % e1 for e1, e2 in zip(self, other)]) 161 | return TensorList([other % e for e in self]) 162 | 163 | def __pos__(self): 164 | return TensorList([+e for e in self]) 165 | 166 | def __neg__(self): 167 | return TensorList([-e for e in self]) 168 | 169 | def __le__(self, other): 170 | if TensorList._iterable(other): 171 | return TensorList([e1 <= e2 for e1, e2 in zip(self, other)]) 172 | return TensorList([e <= other for e in self]) 173 | 174 | def __ge__(self, other): 175 | if TensorList._iterable(other): 176 | return TensorList([e1 >= e2 for e1, e2 in zip(self, other)]) 177 | return TensorList([e >= other for e in self]) 178 | 179 | def concat(self, other): 180 | return TensorList(super(TensorList, self).__add__(other)) 181 | 182 | def copy(self): 183 | return TensorList(super(TensorList, self).copy()) 184 | 185 | def unroll(self): 186 | if not any(isinstance(t, TensorList) for t in self): 187 | return self 188 | 189 | new_list = TensorList() 190 | for t in self: 191 | if isinstance(t, TensorList): 192 | new_list.extend(t.unroll()) 193 | else: 194 | new_list.append(t) 195 | return new_list 196 | 197 | def list(self): 198 | return list(self) 199 | 200 | def attribute(self, attr: str, *args): 201 | return TensorList([getattr(e, attr, *args) for e in self]) 202 | 203 | def apply(self, fn): 204 | return TensorList([fn(e) for e in self]) 205 | 206 | def __getattr__(self, name): 207 | if not hasattr(torch.Tensor, name): 208 | raise AttributeError('\'TensorList\' object has not attribute \'{}\''.format(name)) 209 | 210 | def apply_attr(*args, **kwargs): 211 | return TensorList([getattr(e, name)(*args, **kwargs) for e in self]) 212 | 213 | return apply_attr 214 | 215 | @staticmethod 216 | def _iterable(a): 217 | return isinstance(a, (TensorList, list)) 218 | 219 | 220 | def tensor_operation(op): 221 | def islist(a): 222 | return isinstance(a, TensorList) 223 | 224 | @functools.wraps(op) 225 | def oplist(*args, **kwargs): 226 | if len(args) == 0: 227 | 
raise ValueError('Must be at least one argument without keyword (i.e. operand).') 228 | 229 | if len(args) == 1: 230 | if islist(args[0]): 231 | return TensorList([op(a, **kwargs) for a in args[0]]) 232 | else: 233 | # Multiple operands, assume max two 234 | if islist(args[0]) and islist(args[1]): 235 | return TensorList([op(a, b, *args[2:], **kwargs) for a, b in zip(*args[:2])]) 236 | if islist(args[0]): 237 | return TensorList([op(a, *args[1:], **kwargs) for a in args[0]]) 238 | if islist(args[1]): 239 | return TensorList([op(args[0], b, *args[2:], **kwargs) for b in args[1]]) 240 | 241 | # None of the operands are lists 242 | return op(*args, **kwargs) 243 | 244 | return oplist 245 | -------------------------------------------------------------------------------- /model/ostrack/vit_ce.py: -------------------------------------------------------------------------------- 1 | import math 2 | import logging 3 | from functools import partial 4 | from collections import OrderedDict 5 | from copy import deepcopy 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from timm.models.layers import to_2tuple 12 | 13 | from .layers.patch_embed import PatchEmbed 14 | from .util import combine_tokens, recover_tokens 15 | from .vit import VisionTransformer 16 | from .layers.attn_blocks import CEBlock 17 | 18 | _logger = logging.getLogger(__name__) 19 | 20 | 21 | class VisionTransformerCE(VisionTransformer): 22 | """ Vision Transformer with candidate elimination (CE) module 23 | 24 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` 25 | - https://arxiv.org/abs/2010.11929 26 | 27 | Includes distillation token & head support for `DeiT: Data-efficient Image Transformers` 28 | - https://arxiv.org/abs/2012.12877 29 | """ 30 | 31 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 32 | num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False, 33 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None, 34 | act_layer=None, weight_init='', 35 | ce_loc=None, ce_keep_ratio=None): 36 | """ 37 | Args: 38 | img_size (int, tuple): input image size 39 | patch_size (int, tuple): patch size 40 | in_chans (int): number of input channels 41 | num_classes (int): number of classes for classification head 42 | embed_dim (int): embedding dimension 43 | depth (int): depth of transformer 44 | num_heads (int): number of attention heads 45 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 46 | qkv_bias (bool): enable bias for qkv if True 47 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 48 | distilled (bool): model includes a distillation token and head as in DeiT models 49 | drop_rate (float): dropout rate 50 | attn_drop_rate (float): attention dropout rate 51 | drop_path_rate (float): stochastic depth rate 52 | embed_layer (nn.Module): patch embedding layer 53 | norm_layer: (nn.Module): normalization layer 54 | weight_init: (str): weight init scheme 55 | """ 56 | # super().__init__() 57 | super().__init__() 58 | if isinstance(img_size, tuple): 59 | self.img_size = img_size 60 | else: 61 | self.img_size = to_2tuple(img_size) 62 | self.patch_size = patch_size 63 | self.in_chans = in_chans 64 | 65 | self.num_classes = num_classes 66 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 67 | self.num_tokens = 2 if 
distilled else 1 68 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 69 | act_layer = act_layer or nn.GELU 70 | 71 | self.patch_embed = embed_layer( 72 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 73 | num_patches = self.patch_embed.num_patches 74 | 75 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 76 | self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None 77 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) 78 | self.pos_drop = nn.Dropout(p=drop_rate) 79 | 80 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 81 | blocks = [] 82 | ce_index = 0 83 | self.ce_loc = ce_loc 84 | for i in range(depth): 85 | ce_keep_ratio_i = 1.0 86 | if ce_loc is not None and i in ce_loc: 87 | ce_keep_ratio_i = ce_keep_ratio[ce_index] 88 | ce_index += 1 89 | 90 | blocks.append( 91 | CEBlock( 92 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, 93 | attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer, 94 | keep_ratio_search=ce_keep_ratio_i) 95 | ) 96 | 97 | self.blocks = nn.Sequential(*blocks) 98 | self.norm = norm_layer(embed_dim) 99 | 100 | self.init_weights(weight_init) 101 | 102 | def forward_features(self, z, x, mask_z=None, mask_x=None, 103 | ce_template_mask=None, ce_keep_rate=None, 104 | return_last_attn=False 105 | ): 106 | B, H, W = x.shape[0], x.shape[2], x.shape[3] 107 | 108 | x = self.patch_embed(x) 109 | z = self.patch_embed(z) 110 | self.cat_mode = 'direct' 111 | 112 | # attention mask handling 113 | # B, H, W 114 | if mask_z is not None and mask_x is not None: 115 | mask_z = F.interpolate(mask_z[None].float(), scale_factor=1. / self.patch_size).to(torch.bool)[0] 116 | mask_z = mask_z.flatten(1).unsqueeze(-1) 117 | 118 | mask_x = F.interpolate(mask_x[None].float(), scale_factor=1. 
/ self.patch_size).to(torch.bool)[0] 119 | mask_x = mask_x.flatten(1).unsqueeze(-1) 120 | 121 | mask_x = combine_tokens(mask_z, mask_x, mode=self.cat_mode) 122 | mask_x = mask_x.squeeze(-1) 123 | 124 | if self.add_cls_token: 125 | cls_tokens = self.cls_token.expand(B, -1, -1) 126 | cls_tokens = cls_tokens + self.cls_pos_embed 127 | 128 | z += self.pos_embed_z 129 | x += self.pos_embed_x 130 | 131 | if self.add_sep_seg: 132 | x += self.search_segment_pos_embed 133 | z += self.template_segment_pos_embed 134 | 135 | x = combine_tokens(z, x, mode=self.cat_mode) 136 | if self.add_cls_token: 137 | x = torch.cat([cls_tokens, x], dim=1) 138 | 139 | x = self.pos_drop(x) 140 | 141 | lens_z = self.pos_embed_z.shape[1] 142 | lens_x = self.pos_embed_x.shape[1] 143 | 144 | global_index_t = torch.linspace(0, lens_z - 1, lens_z).to(x.device) 145 | global_index_t = global_index_t.repeat(B, 1) 146 | 147 | global_index_s = torch.linspace(0, lens_x - 1, lens_x).to(x.device) 148 | global_index_s = global_index_s.repeat(B, 1) 149 | removed_indexes_s = [] 150 | for i, blk in enumerate(self.blocks): 151 | x, global_index_t, global_index_s, removed_index_s, attn = \ 152 | blk(x, global_index_t, global_index_s, mask_x, ce_template_mask, ce_keep_rate) 153 | 154 | if self.ce_loc is not None and i in self.ce_loc: 155 | removed_indexes_s.append(removed_index_s) 156 | 157 | x = self.norm(x) 158 | lens_x_new = global_index_s.shape[1] 159 | lens_z_new = global_index_t.shape[1] 160 | 161 | z = x[:, :lens_z_new] 162 | x = x[:, lens_z_new:] 163 | 164 | if removed_indexes_s and removed_indexes_s[0] is not None: 165 | removed_indexes_cat = torch.cat(removed_indexes_s, dim=1) 166 | 167 | pruned_lens_x = lens_x - lens_x_new 168 | pad_x = torch.zeros([B, pruned_lens_x, x.shape[2]], device=x.device) 169 | x = torch.cat([x, pad_x], dim=1) 170 | index_all = torch.cat([global_index_s, removed_indexes_cat], dim=1) 171 | # recover original token order 172 | C = x.shape[-1] 173 | # x = x.gather(1, index_all.unsqueeze(-1).expand(B, -1, C).argsort(1)) 174 | x = torch.zeros_like(x).scatter_(dim=1, index=index_all.unsqueeze(-1).expand(B, -1, C).to(torch.int64), src=x) 175 | 176 | x = recover_tokens(x, lens_z_new, lens_x, mode=self.cat_mode) 177 | 178 | # re-concatenate with the template, which may be further used by other modules 179 | x = torch.cat([z, x], dim=1) 180 | 181 | aux_dict = { 182 | "attn": attn, 183 | "removed_indexes_s": removed_indexes_s, # used for visualization 184 | } 185 | 186 | return x, aux_dict 187 | 188 | def forward(self, z, x, ce_template_mask=None, ce_keep_rate=None, 189 | tnc_keep_rate=None, 190 | return_last_attn=False): 191 | 192 | x, aux_dict = self.forward_features(z, x, ce_template_mask=ce_template_mask, ce_keep_rate=ce_keep_rate,) 193 | 194 | return x, aux_dict 195 | 196 | 197 | def _create_vision_transformer(pretrained=False, **kwargs): 198 | model = VisionTransformerCE(**kwargs) 199 | 200 | if pretrained: 201 | if 'npz' in pretrained: 202 | model.load_pretrained(pretrained, prefix='') 203 | else: 204 | checkpoint = torch.load(pretrained, map_location="cpu") 205 | missing_keys, unexpected_keys = model.load_state_dict(checkpoint["model"], strict=False) 206 | print('Load pretrained model from: ' + pretrained) 207 | 208 | return model 209 | 210 | 211 | def vit_base_patch16_224_ce(pretrained=False, **kwargs): 212 | """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 
213 | """ 214 | model_kwargs = dict( 215 | patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 216 | model = _create_vision_transformer(pretrained=pretrained, **model_kwargs) 217 | return model 218 | 219 | 220 | def vit_large_patch16_224_ce(pretrained=False, **kwargs): 221 | """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). 222 | """ 223 | model_kwargs = dict( 224 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) 225 | model = _create_vision_transformer(pretrained=pretrained, **model_kwargs) 226 | return model 227 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | OSTrack 3 |

4 | 
5 | # :rocket:__OSTrack__
6 | 
7 | ![result](assets/bandicam.gif)
8 | 
9 | OSTrack is an artificial-intelligence technology for tracking and locking onto unmanned aerial vehicles (UAVs), built on the Vision Transformer (ViT) deep learning model. For UAVs moving at high speed in both the near and far field, it uses visual tracking to lock the position of the UAV in real-time video frames. The model starts from multiple initial anchor bounding boxes, obtains feature maps through feature extraction on the network input, and determines the UAV position from the votes that the network assigns to the anchor boxes. The trained network has a certain robustness to near/far-field changes, partial occlusion, and lighting changes.
10 | 
11 | ## 1.Environment :bulb:
12 | 
13 | | Name | Version | Name | Version |
14 | |------|---------|--------|---------|
15 | | Python | 3.8.10 | PyTorch | 2.4.1 |
16 | | opencv-python | 4.9.0.80 | Tkinter | 8.6 |
17 | | pillow | 10.2.0 | torchvision | 0.19.1 |
18 | 
19 | ## 2.Usage :computer:
20 | 
21 | You can use utils.py to get the OSTrack model and use it for training and testing. The model architecture is shown below: the CEBlock backbone has 12 layers, followed by a five-layer detection head. In practice, OSTrack needs two test inputs, a template image and a search image. The location of the small-scale drone in the first frame of the video sequence, i.e. the template image, has to be determined by manual or automatic annotation, while the search image needs no special processing and can be any ordinary frame from the sequence.
22 | 
23 | ```bash
24 | # if you want to use the ostrack model
25 | # you can use the following command
26 | python app_osyo.py
27 | ```
28 | 
29 | | CEBlock | Detection Head |
30 | | ------ | ------------- |
31 | | ![1](assets/architecture/ostrack_1.jpg) | ![2](assets/architecture/ostrack_5.jpg) |
32 | 
33 | ## 3.TODO :book:
34 | 
35 | - [x] Finish the model configuration code and import the vit_base_384_model_ce model.
36 | - [x] Train an initial-frame localization model using YOLOv5 for the automatic annotation of templates.
37 | - [x] Complete the function of drone tracking for imported videos in OSTrack.
38 | - [x] Finish the GUI interface for deploying the OSTrack and YOLOv5 models.
39 | 
40 | | search | template |
41 | | ------ | ------------- |
42 | | ![template_image](assets/uav_1.jpg) | ![search_image](assets/uav_2.jpg) |
43 | | ![orig_video](assets/infrared_5.gif) | ![result](assets/processed_infrared_5.gif) |
44 | 
45 | ## 4.YOLO Results :football:
46 | 
47 | This section shows the YOLOv5 model's training results, consisting of the confusion matrix, labels correlogram, F1 curve, labels plot, and PR/P/R curves. The training results of YOLOv5 are not included in this project. Next, you can deploy the s/m/l/x variants of YOLOv5. If you encounter the following error on a Windows 10/11 system, raise NotImplementedError("cannot instantiate %r on your system"), add the code below to the first line of the ./yolov5/utils/general.py file.
48 | 
49 | ```bash
50 | # the error is as follows:
51 | raise NotImplementedError("cannot instantiate %r on your system")
52 | NotImplementedError: cannot instantiate 'PosixPath' on your system
53 | 
54 | # you can add the following code to the first line of the ./yolov5/utils/general.py file.
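# (note) the snippet below remaps pathlib.PosixPath to pathlib.WindowsPath so that
# YOLOv5 checkpoints pickled on Linux, which store PosixPath objects, can be
# unpickled on Windows; app_osyo.py applies the same patch at the top of the file.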
55 | import pathlib
56 | temp = pathlib.PosixPath
57 | pathlib.PosixPath = pathlib.WindowsPath
58 | ```
59 | 
60 | If you want to compress drone IR video, you can use the command below to compress it to .gif format. The YOLOv5 training dataset is rebuilt from the GOT-10k drone infrared dataset used by OSTrack. YOLOv5 is used to locate the drone in the first frame of a sequence, so that it can provide the coordinates needed for the template image and the search image that the OSTrack model later uses for tracking.
61 | 
62 | ```bash
63 | ffmpeg -ss 00:00:05 -t 00:00:05 -i video/infrared.mp4 -vf "fps=1,scale=640:\
64 | -1:flags=lanczos,split[s0][s1];[s0]palettegen=stats_mode=single:max_colors=16[p];\
65 | [s1][p]paletteuse=dither=floyd_steinberg" -gifflags +transdiff -loop 0 \
66 | -final_delay 20 -y output_3mb.gif
67 | ```
68 | 
69 | ![confusion_matrix](assets/results/confusion_matrix.jpg)
70 | 
71 | | labels | labels_correlogram |
72 | | ------------- | ------------- |
73 | | ![labels](assets/results/labels.jpg) | ![labels_correlogram](assets/results/labels_correlogram.jpg) |
74 | 
75 | | PR curve | P curve | R curve | F1 score |
76 | | ------ | ------------- | ------------- | ------------- |
77 | | ![PR_curve](assets/results/PR_curve.jpg) | ![P_curve](assets/results/P_curve.jpg) | ![R_curve](assets/results/R_curve.jpg) | ![F1_curve](assets/results/F1_curve.jpg) |
78 | 
79 | Before training, the dataset is organized as shown below, and the training results of the YOLOv5 s/m/l/x variants are also shown below. The dataset consists of 39965 training infrared images and 40355 validation infrared images. The images directory holds the original infrared images, and the labels directory holds, for each image, the coordinates of the upper-left corner of the bounding box (x, y) together with its width and height (w, h).
80 | 
81 | | train_batch | val_batch | val preds |
82 | | ------------- | ------------- | ------------- |
83 | |![train_batch](assets/results/train_batch2.jpg)| ![val_batch](assets/results/val_batch2_labels.jpg)|![val_preds](assets/results/val_batch2_pred.jpg)|
84 | 
85 | ![loss curve](assets/results/results.jpg)
86 | 
87 | ```bash
88 | # the yolov5 training dataset architecture is as follows.
89 | datasets
90 | |
91 | |____images
92 | |    |
93 | |    |____train
94 | |    |    |____00000001.jpg
95 | |    |    |____00000002.jpg
96 | |    |    |____...
97 | |    |____val
98 | |         |____00000001.jpg
99 | |         |____00000002.jpg
100 | |         |____...
101 | |____labels
102 |      |
103 |      |____train
104 |      |    |____00000001.txt
105 |      |    |____00000002.txt
106 |      |    |____...
107 |      |____val
108 |           |____00000001.txt
109 |           |____00000002.txt
110 |           |____...
111 | ```
112 | 
113 | ## 5.OSTrack Results :bulb:
114 | 
115 | OSTrack needs a template image and a search image, which are produced by automatic annotation: OpenCV crops regions centered on the detected drone at 2 times and 5 times the size of its bounding box to obtain the template image and the search image, respectively. The main purpose of using OSTrack is to suppress interference from the environmental background in the infrared drone images and, through the template/search scheme, to narrow the localization range within a single frame, which improves positioning accuracy, reduces the chance of environmental interference, and thus greatly improves tracking performance. The results of the OSTrack model are shown below.
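As a minimal sketch of this crop-and-track step (assuming the `ScaleClip` helper shown below, the 192x192/384x384 transforms copied from app_osyo.py, a loaded `ostrack` model such as the one built by `config()` in util.py, and a hypothetical first-frame YOLO box `xyxy`), the tracker inputs can be prepared like this:

```python
import cv2
import torch
from torchvision.transforms import transforms

# preprocessing copied from app_osyo.py: 192x192 template, 384x384 search region
template_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((192, 192)),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
search_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((384, 384)),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
frame = cv2.cvtColor(cv2.imread('assets/uav_1.jpg'), cv2.COLOR_BGR2RGB)
xyxy = [310, 220, 330, 240]  # hypothetical YOLOv5 detection (x1, y1, x2, y2) in pixels

# ScaleClip (defined below) crops 2x / 5x the box size around its center
template_img = template_transform(ScaleClip(frame, xyxy, mode='template')).unsqueeze(0).to(device)
search_img = search_transform(ScaleClip(frame, xyxy, mode='search')).unsqueeze(0).to(device)

# `ostrack` stands for the loaded tracker; app_osyo.py obtains it via config() from util.py
pred_boxes = ostrack(template_img, search_img)['pred_boxes'][0]
pred_box = pred_boxes.detach().cpu().numpy()[0]  # normalized coordinates in [0, 1]
```

How the predicted box is rescaled to pixel coordinates and drawn on the frame is shown in the app_osyo.py test snippet later in this section.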
116 | 
117 | ```bash
118 | # this function is used to clip the frame into the template and search areas
119 | def ScaleClip(img, xyxy, mode=None):
120 |     """ScaleClip is used to clip the frame into the template and search areas
121 |     :param img: the frame image, which must contain the UAV pixels
122 |     :param xyxy: the top-left and bottom-right coordinates of the UAV bounding box"""
123 |     img_array = np.array(img)
124 |     width, height = xyxy[2] - xyxy[0], xyxy[3] - xyxy[1]
125 |     center = np.array([xyxy[0] + width / 2, xyxy[1] + height / 2])
126 |     scale_factor = {'template': 2, 'search': 5}.get(mode, 0)
127 |     scaled_width = int(scale_factor * width)
128 |     scaled_height = int(scale_factor * height)
129 |     # Calculate the cropping rectangle ensuring it does not exceed image boundaries.
130 |     top_left_x = max(int(center[0] - scaled_width / 2), 0)
131 |     top_left_y = max(int(center[1] - scaled_height / 2), 0)
132 |     bottom_right_x = min(int(center[0] + scaled_width / 2), img_array.shape[1])
133 |     bottom_right_y = min(int(center[1] + scaled_height / 2), img_array.shape[0])
134 |     # Clip the image
135 |     img_clipped = img_array[top_left_y:bottom_right_y, top_left_x:bottom_right_x, :]
136 |     return img_clipped
137 | ```
138 | 
139 | After testing, the OSTrack pretrained model offered by the [original author](https://github.com/LY-1/MCJT) has no effect on the tracking performance. The testing code in app_osyo.py is as follows.
140 | 
141 | ![pretrained model image](assets/ostrack_test0320.jpg)
142 | 
143 | ```bash
144 | # run the OSTrack model for testing
145 | self.xyxy = [xy.detach().cpu().numpy() for xy in self.xyxy]
146 | # crop the template/search regions for OSTrack and move them to the GPU
147 | template_img = template_transform( ScaleClip(self.frame, self.xyxy, mode='template') ).unsqueeze(0).to(device)
148 | search_img = search_transform( ScaleClip(self.frame, self.xyxy, mode='search') ).unsqueeze(0).to(device)
149 | ostrack_results = self.ostrack(template_img, search_img)
150 | ostrack_results = ostrack_results['pred_boxes'][0]
151 | ostrack_results = ostrack_results.detach().cpu().numpy()[0]
152 | whwh = [int(ostrack_results[0]*self.frame_width), int(ostrack_results[1]*self.frame_height),
153 |         int(ostrack_results[2]*self.frame_width), int(ostrack_results[3]*self.frame_height)]
154 | cv2.rectangle(self.frame, (whwh[0], whwh[1]), (whwh[2], whwh[3]), (0, 0, 255), 2)
155 | print(f'the OSTrack results: {ostrack_results}')
156 | print('the yolov5 preds: ', type(self.xyxy), '\t', self.xyxy)
157 | ```
158 | 
159 | During software testing, we found that when the drone appeared against a background with strong infrared light, normal location tracking failed. In the following two cases, for example, the drone was lost for a short period. We therefore consider combining YOLOv5 with OSTrack, pairing the excellent localization ability of the YOLOv5 model with the strong frame-to-frame tracking ability of OSTrack, to achieve better location tracking.
160 | 
161 | | infrared uav case 1 | infrared uav case 2 |
162 | | ------------------- | ------------------- |
163 | | ![case1](assets/exception/exception1.gif) | ![case2](assets/exception/exception2.gif) |
164 | 
165 | ## 6.Thanks :heart:
166 | 
167 | ```bash
168 | # If you are interested in the original project, you can click on the link below.
169 | https://github.com/LY-1/MCJT
170 | ```
171 | 
172 | If you need the pre-trained YOLOv5 infrared drone positioning model and the OSTrack model weights, you can download them from Baidu Netdisk.
The relevant download links are as follows:
173 | 
174 | ```bash
175 | # Files shared via online disk: ostrack
176 | link: https://pan.baidu.com/s/1lPM_ACRkc-g8WDkB0tw7EA?pwd=92fk
177 | extraction code: 92fk
178 | ```
179 | 
180 | Meanwhile, we declare that this project is a reproduction and improvement of the work of the [original author](https://github.com/LY-1/MCJT). We trained on a new self-made dataset and designed the GUI interface for deploying the trained model. Subsequently, we will also run practical tests to evaluate the real-world performance of the model. We are very grateful to the original author for his work. Of course, if you find our work built on top of it interesting, please give it a little star.
181 | 
--------------------------------------------------------------------------------
/model/ostrack/layers/head.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torch.nn.functional as F
4 | 
5 | from .frozen_bn import FrozenBatchNorm2d
6 | 
7 | 
8 | def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1,
9 |          freeze_bn=False):
10 |     if freeze_bn:
11 |         return nn.Sequential(
12 |             nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
13 |                       padding=padding, dilation=dilation, bias=True),
14 |             FrozenBatchNorm2d(out_planes),
15 |             nn.ReLU(inplace=True))
16 |     else:
17 |         return nn.Sequential(
18 |             nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
19 |                       padding=padding, dilation=dilation, bias=True),
20 |             nn.BatchNorm2d(out_planes),
21 |             nn.ReLU(inplace=True))
22 | 
23 | 
24 | class Corner_Predictor(nn.Module):
25 |     """ Corner Predictor module"""
26 | 
27 |     def __init__(self, inplanes=64, channel=256, feat_sz=20, stride=16, freeze_bn=False):
28 |         super(Corner_Predictor, self).__init__()
29 |         self.feat_sz = feat_sz
30 |         self.stride = stride
31 |         self.img_sz = self.feat_sz * self.stride
32 |         '''top-left corner'''
33 |         self.conv1_tl = conv(inplanes, channel, freeze_bn=freeze_bn)
34 |         self.conv2_tl = conv(channel, channel // 2, freeze_bn=freeze_bn)
35 |         self.conv3_tl = conv(channel // 2, channel // 4, freeze_bn=freeze_bn)
36 |         self.conv4_tl = conv(channel // 4, channel // 8, freeze_bn=freeze_bn)
37 |         self.conv5_tl = nn.Conv2d(channel // 8, 1, kernel_size=1)
38 | 
39 |         '''bottom-right corner'''
40 |         self.conv1_br = conv(inplanes, channel, freeze_bn=freeze_bn)
41 |         self.conv2_br = conv(channel, channel // 2, freeze_bn=freeze_bn)
42 |         self.conv3_br = conv(channel // 2, channel // 4, freeze_bn=freeze_bn)
43 |         self.conv4_br = conv(channel // 4, channel // 8, freeze_bn=freeze_bn)
44 |         self.conv5_br = nn.Conv2d(channel // 8, 1, kernel_size=1)
45 | 
46 |         '''about coordinates and indexes'''
47 |         with torch.no_grad():
48 |             self.indice = torch.arange(0, self.feat_sz).view(-1, 1) * self.stride
49 |             # generate mesh-grid
50 |             self.coord_x = self.indice.repeat((self.feat_sz, 1)) \
51 |                 .view((self.feat_sz * self.feat_sz,)).float().cuda()
52 |             self.coord_y = self.indice.repeat((1, self.feat_sz)) \
53 |                 .view((self.feat_sz * self.feat_sz,)).float().cuda()
54 | 
55 |     def forward(self, x, return_dist=False, softmax=True):
56 |         """ Forward pass with input x.
""" 57 | score_map_tl, score_map_br = self.get_score_map(x) 58 | if return_dist: 59 | coorx_tl, coory_tl, prob_vec_tl = self.soft_argmax(score_map_tl, return_dist=True, softmax=softmax) 60 | coorx_br, coory_br, prob_vec_br = self.soft_argmax(score_map_br, return_dist=True, softmax=softmax) 61 | return torch.stack((coorx_tl, coory_tl, coorx_br, coory_br), dim=1) / self.img_sz, prob_vec_tl, prob_vec_br 62 | else: 63 | coorx_tl, coory_tl = self.soft_argmax(score_map_tl) 64 | coorx_br, coory_br = self.soft_argmax(score_map_br) 65 | return torch.stack((coorx_tl, coory_tl, coorx_br, coory_br), dim=1) / self.img_sz 66 | 67 | def get_score_map(self, x): 68 | # top-left branch 69 | x_tl1 = self.conv1_tl(x) 70 | x_tl2 = self.conv2_tl(x_tl1) 71 | x_tl3 = self.conv3_tl(x_tl2) 72 | x_tl4 = self.conv4_tl(x_tl3) 73 | score_map_tl = self.conv5_tl(x_tl4) 74 | 75 | # bottom-right branch 76 | x_br1 = self.conv1_br(x) 77 | x_br2 = self.conv2_br(x_br1) 78 | x_br3 = self.conv3_br(x_br2) 79 | x_br4 = self.conv4_br(x_br3) 80 | score_map_br = self.conv5_br(x_br4) 81 | return score_map_tl, score_map_br 82 | 83 | def soft_argmax(self, score_map, return_dist=False, softmax=True): 84 | """ get soft-argmax coordinate for a given heatmap """ 85 | score_vec = score_map.view((-1, self.feat_sz * self.feat_sz)) # (batch, feat_sz * feat_sz) 86 | prob_vec = nn.functional.softmax(score_vec, dim=1) 87 | exp_x = torch.sum((self.coord_x * prob_vec), dim=1) 88 | exp_y = torch.sum((self.coord_y * prob_vec), dim=1) 89 | if return_dist: 90 | if softmax: 91 | return exp_x, exp_y, prob_vec 92 | else: 93 | return exp_x, exp_y, score_vec 94 | else: 95 | return exp_x, exp_y 96 | 97 | 98 | class CenterPredictor(nn.Module, ): 99 | def __init__(self, inplanes=64, channel=256, feat_sz=20, stride=16, freeze_bn=False): 100 | super(CenterPredictor, self).__init__() 101 | self.feat_sz = feat_sz 102 | self.stride = stride 103 | self.img_sz = self.feat_sz * self.stride 104 | 105 | # corner predict 106 | self.conv1_ctr = conv(inplanes, channel, freeze_bn=freeze_bn) 107 | self.conv2_ctr = conv(channel, channel // 2, freeze_bn=freeze_bn) 108 | self.conv3_ctr = conv(channel // 2, channel // 4, freeze_bn=freeze_bn) 109 | self.conv4_ctr = conv(channel // 4, channel // 8, freeze_bn=freeze_bn) 110 | self.conv5_ctr = nn.Conv2d(channel // 8, 1, kernel_size=1) 111 | 112 | # size regress 113 | self.conv1_offset = conv(inplanes, channel, freeze_bn=freeze_bn) 114 | self.conv2_offset = conv(channel, channel // 2, freeze_bn=freeze_bn) 115 | self.conv3_offset = conv(channel // 2, channel // 4, freeze_bn=freeze_bn) 116 | self.conv4_offset = conv(channel // 4, channel // 8, freeze_bn=freeze_bn) 117 | self.conv5_offset = nn.Conv2d(channel // 8, 2, kernel_size=1) 118 | 119 | # size regress 120 | self.conv1_size = conv(inplanes, channel, freeze_bn=freeze_bn) 121 | self.conv2_size = conv(channel, channel // 2, freeze_bn=freeze_bn) 122 | self.conv3_size = conv(channel // 2, channel // 4, freeze_bn=freeze_bn) 123 | self.conv4_size = conv(channel // 4, channel // 8, freeze_bn=freeze_bn) 124 | self.conv5_size = nn.Conv2d(channel // 8, 2, kernel_size=1) 125 | 126 | for p in self.parameters(): 127 | if p.dim() > 1: 128 | nn.init.xavier_uniform_(p) 129 | 130 | def forward(self, x, gt_score_map=None): 131 | """ Forward pass with input x. 
""" 132 | score_map_ctr, size_map, offset_map = self.get_score_map(x) 133 | 134 | # assert gt_score_map is None 135 | if gt_score_map is None: 136 | bbox = self.cal_bbox(score_map_ctr, size_map, offset_map) 137 | else: 138 | bbox = self.cal_bbox(gt_score_map.unsqueeze(1), size_map, offset_map) 139 | 140 | return score_map_ctr, bbox, size_map, offset_map 141 | 142 | def cal_bbox(self, score_map_ctr, size_map, offset_map, return_score=False): 143 | max_score, idx = torch.max(score_map_ctr.flatten(1), dim=1, keepdim=True) 144 | idx_y = idx // self.feat_sz # 取整 145 | idx_x = idx % self.feat_sz # 取余 146 | 147 | idx = idx.unsqueeze(1).expand(idx.shape[0], 2, 1) 148 | size = size_map.flatten(2).gather(dim=2, index=idx) 149 | offset = offset_map.flatten(2).gather(dim=2, index=idx).squeeze(-1) 150 | 151 | # bbox = torch.cat([idx_x - size[:, 0] / 2, idx_y - size[:, 1] / 2, 152 | # idx_x + size[:, 0] / 2, idx_y + size[:, 1] / 2], dim=1) / self.feat_sz 153 | # cx, cy, w, h 154 | bbox = torch.cat([(idx_x.to(torch.float) + offset[:, :1]) / self.feat_sz, 155 | (idx_y.to(torch.float) + offset[:, 1:]) / self.feat_sz, 156 | size.squeeze(-1)], dim=1) 157 | 158 | if return_score: 159 | return bbox, max_score 160 | return bbox 161 | 162 | def get_pred(self, score_map_ctr, size_map, offset_map): 163 | max_score, idx = torch.max(score_map_ctr.flatten(1), dim=1, keepdim=True) 164 | idx_y = idx // self.feat_sz 165 | idx_x = idx % self.feat_sz 166 | 167 | idx = idx.unsqueeze(1).expand(idx.shape[0], 2, 1) 168 | size = size_map.flatten(2).gather(dim=2, index=idx) 169 | offset = offset_map.flatten(2).gather(dim=2, index=idx).squeeze(-1) 170 | 171 | # bbox = torch.cat([idx_x - size[:, 0] / 2, idx_y - size[:, 1] / 2, 172 | # idx_x + size[:, 0] / 2, idx_y + size[:, 1] / 2], dim=1) / self.feat_sz 173 | return size * self.feat_sz, offset 174 | 175 | def get_score_map(self, x): 176 | 177 | def _sigmoid(x): 178 | y = torch.clamp(x.sigmoid_(), min=1e-4, max=1 - 1e-4) 179 | return y 180 | 181 | # ctr branch 182 | x_ctr1 = self.conv1_ctr(x) 183 | x_ctr2 = self.conv2_ctr(x_ctr1) 184 | x_ctr3 = self.conv3_ctr(x_ctr2) 185 | x_ctr4 = self.conv4_ctr(x_ctr3) 186 | score_map_ctr = self.conv5_ctr(x_ctr4) 187 | 188 | # offset branch 189 | x_offset1 = self.conv1_offset(x) 190 | x_offset2 = self.conv2_offset(x_offset1) 191 | x_offset3 = self.conv3_offset(x_offset2) 192 | x_offset4 = self.conv4_offset(x_offset3) 193 | score_map_offset = self.conv5_offset(x_offset4) 194 | 195 | # size branch 196 | x_size1 = self.conv1_size(x) 197 | x_size2 = self.conv2_size(x_size1) 198 | x_size3 = self.conv3_size(x_size2) 199 | x_size4 = self.conv4_size(x_size3) 200 | score_map_size = self.conv5_size(x_size4) 201 | return _sigmoid(score_map_ctr), _sigmoid(score_map_size), score_map_offset 202 | 203 | 204 | class MLP(nn.Module): 205 | """ Very simple multi-layer perceptron (also called FFN)""" 206 | 207 | def __init__(self, input_dim, hidden_dim, output_dim, num_layers, BN=False): 208 | super().__init__() 209 | self.num_layers = num_layers 210 | h = [hidden_dim] * (num_layers - 1) 211 | if BN: 212 | self.layers = nn.ModuleList(nn.Sequential(nn.Linear(n, k), nn.BatchNorm1d(k)) 213 | for n, k in zip([input_dim] + h, h + [output_dim])) 214 | else: 215 | self.layers = nn.ModuleList(nn.Linear(n, k) 216 | for n, k in zip([input_dim] + h, h + [output_dim])) 217 | 218 | def forward(self, x): 219 | for i, layer in enumerate(self.layers): 220 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 221 | return x 222 | 223 | 224 | def build_box_head(cfg, 
hidden_dim): 225 | stride = cfg['MODEL']['BACKBONE']['STRIDE'] 226 | 227 | if cfg['MODEL']['HEAD']['TYPE'] == "MLP": 228 | mlp_head = MLP(hidden_dim, hidden_dim, 4, 3) # dim_in, dim_hidden, dim_out, 3 layers 229 | return mlp_head 230 | elif "CORNER" in cfg['MODEL']['HEAD']['TYPE']: 231 | feat_sz = int(cfg['DATA']['SEARCH']['SIZE'] / stride) 232 | channel = getattr(cfg['MODEL'], "NUM_CHANNELS", 256) 233 | print("head channel: %d" % channel) 234 | if cfg['MODEL']['HEAD']['TYPE'] == "CORNER": 235 | corner_head = Corner_Predictor(inplanes=cfg['MODEL']['HIDDEN_DIM'], channel=channel, 236 | feat_sz=feat_sz, stride=stride) 237 | else: 238 | raise ValueError() 239 | return corner_head 240 | elif cfg['MODEL']['HEAD']['TYPE'] == "CENTER": 241 | in_channel = hidden_dim 242 | out_channel = cfg['MODEL']['HEAD']['NUM_CHANNELS'] 243 | feat_sz = int(cfg['DATA']['SEARCH']['SIZE'] / stride) 244 | center_head = CenterPredictor(inplanes=in_channel, channel=out_channel, 245 | feat_sz=feat_sz, stride=stride) 246 | return center_head 247 | else: 248 | raise ValueError("HEAD TYPE %s is not supported." % cfg['MODEL']['HEAD']['TYPE']) 249 | -------------------------------------------------------------------------------- /app_osyo.py: -------------------------------------------------------------------------------- 1 | """ 2 | 任务: 导入模型,创建GUI界面,实现对实时视频 3 | 捕获或者视频导入的无人机定位 4 | 时间: 2025/03/13-Redal 5 | """ 6 | import pathlib 7 | temp = pathlib.PosixPath 8 | pathlib.PosixPath = pathlib.WindowsPath 9 | 10 | import os 11 | import sys 12 | import cv2 13 | import threading 14 | import argparse 15 | import torch 16 | import tkinter as tk 17 | from tkinter import filedialog 18 | from PIL import Image, ImageTk 19 | from util import config 20 | from util import ComputeHistogramImage 21 | from util import CalculateSpectrogramImage 22 | from torchvision.transforms import transforms 23 | from yolov5.models.experimental import attempt_load 24 | from yolo_model import parser 25 | from yolo_model import load_yolo 26 | from bbox import decoder 27 | 28 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 29 | current_path = os.path.dirname(os.path.abspath(__file__)) 30 | sys.path.append(current_path) 31 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 32 | template_transform = transforms.Compose([ 33 | transforms.ToTensor(), 34 | transforms.Resize((192, 192)), 35 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),]) 36 | search_transform = transforms.Compose([ 37 | transforms.ToTensor(), 38 | transforms.Resize((384, 384)), 39 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),]) 40 | 41 | 42 | 43 | ######################## 定义GUI界面类 ####################### 44 | class OSTrackGUI(tk.Frame): 45 | """设计OSTrack主界面,用于完成多功能的介绍 46 | 以及相关功能的选择使用""" 47 | def __init__(self, root=None): 48 | super().__init__() 49 | self.root = root 50 | self.__set_widgets() 51 | self.frame = None 52 | self.last_xyxy = None 53 | self.dx, self.dy = 0, 0 54 | self.lost_frame_num = 0 55 | self.video_cap = cv2.VideoCapture(0) 56 | 57 | self.is_running = False 58 | self.video_cap = None 59 | self.video_thread = None 60 | self.live_video_flag = False 61 | self.import_video_flag = False 62 | self.track_video_flag = False 63 | self.export_video_flag = False 64 | # 初始化模型 65 | # self.ostrack = config() 66 | self.template_transform = template_transform 67 | self.sreach_transform = search_transform 68 | self.args = parser() 69 | self.yolo_model = load_yolo(self.args).to(device) 70 | # 定义缓存变量 71 | self.img_cached = 
[] 72 | self.fps = 0 73 | self.frame_width, self.frame_height = 0, 0 74 | self.fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v') 75 | self.video_processed_dir = './video' 76 | 77 | def __set_widgets(self): 78 | self.root.title("OSTrack GUI-Redal") 79 | self.root.geometry("800x600") 80 | self.video_label = tk.Label(self.root, text="视频显示区域", width=500, height=400); self.video_label.place(x=0, y=0) 81 | self.title_label = tk.Label(self.root, text='无人机目标追踪', font=("仿宋", 15), fg="black", width=30, height=2); self.title_label.place(x=505, y=0) 82 | self.histogram_label = tk.Label(self.root, text="直方图显示区域", font=("仿宋", 10), fg="black", width=300, height=200); self.histogram_label.place(x=0, y=400) 83 | self.spectrogram_label = tk.Label(self.root, text='频谱图显示区域', font=("仿宋", 10), fg="black", width=200, height=200); self.spectrogram_label.place(x=300, y=400) 84 | self.message_title_label = tk.Label(self.root, text='主要功能信息提示', font=("仿宋", 15), fg="black", width=30, height=2 ); self.message_title_label.place(x=505, y=400) 85 | self.message_label = tk.Label(self.root,text="很感激您能使用我们的软件\n请选择您需要的功能......", 86 | font=("仿宋", 12),width=40,height=5, wraplength=300, justify="left"); self.message_label.place(x=505, y=440) 87 | 88 | # 软件主要功能的选择按钮,按钮大小(80, 30) 89 | self.live_video_button = tk.Button(self.root, text='实时视频', height=1, width=10, command=self.__start_live_video__); self.live_video_button.place(x=550, y=50) 90 | self.import_video_button = tk.Button(self.root, text='导入视频', height=1, width=10, command=self.__start_import_video__); self.import_video_button.place(x=550, y=80) 91 | self.exit_button = tk.Button(self.root, text='退出程序', height=1, width=10, command=self.root.quit); self.exit_button.place(x=550, y=110) 92 | self.live_video_change_button = tk.Button(self.root, text='退出实时', height=1, width=10, command=self.__end_live_video__); self.live_video_change_button.place(x=670, y=50) 93 | self.import_video_change_button = tk.Button(self.root, text='退出导入', height=1, width=10, command=self.__end_import_video__); self.import_video_change_button.place(x=670, y=80) 94 | self.frame_shot_button = tk.Button(self.root, text='截取图像', height=1, width=10, command=self.__frame_shot__); self.frame_shot_button.place(x=670, y=110) 95 | 96 | # 界面初始显示图像 97 | self.main_window_img = cv2.resize( cv2.imread("images/main_window_img.png"), (500, 400) ) 98 | self.main_window_img = ImageTk.PhotoImage(image = Image.fromarray(self.main_window_img)) 99 | self.video_label.config(image=self.main_window_img) 100 | self.video_label.image = self.main_window_img 101 | 102 | # 应用模型进行跟踪处理按钮 103 | self.track_model_button = tk.Button(self.root, text='视频跟踪', height=1, width=10, command=self.__track_model__); self.track_model_button.place(x=550, y=170) 104 | self.video_export_button = tk.Button(self.root, text='视频导出', height=1, width=10, command=self.__video_export__); self.video_export_button.place(x=670, y=170) 105 | 106 | def __start_live_video__(self): 107 | """功能: 实时视频捕获""" 108 | if self.video_cap is not None: 109 | self.video_cap.release() 110 | self.live_video_flag = True 111 | self.import_video_flag = False 112 | text = "实时视频: 调用电脑或外置设备相机\n进行实时视频捕捉,再进行视频跟踪。\n注意: 退出此模式,请双击'退出实时'\n实时视频捕获中......" 
113 | self.message_label.config(text=text) 114 | 115 | self.video_cap = cv2.VideoCapture(0) # 使用默认电脑相机 116 | self.is_running = True 117 | self.video_thread = threading.Thread(target=self.__video_loop__) 118 | self.video_thread.daemon = True 119 | self.video_thread.start() 120 | 121 | def __end_live_video__(self): 122 | """功能: 退出实时视频捕获""" 123 | self.is_running = False 124 | if self.video_cap is not None: 125 | self.video_cap.release() 126 | self.live_video_flag = False 127 | self.export_video_flag = False 128 | self.video_cap = None 129 | self.video_thread = None 130 | # 恢复为初始状态 131 | self.video_label.config(image=self.main_window_img) 132 | self.video_label.image = self.main_window_img 133 | 134 | def __start_import_video__(self): 135 | """功能: 导入用户视频""" 136 | file_path = filedialog.askopenfilename(filetypes=[("视频文件", "*.mp4 *.avi *.mov")]) 137 | if file_path: 138 | if self.video_cap is not None: 139 | self.video_cap.release() 140 | self.import_video_flag = True 141 | self.live_video_flag = False 142 | self.track_video_flag = False 143 | text = "导入视频: 导入用户自定义视频文件\n进行视频帧捕捉,再进行跟踪处理。\n注意: 退出此模式,请双击'退出导入'\n用户视频导入中......" 144 | self.message_label.config(text=text) 145 | 146 | self.video_cap = cv2.VideoCapture(file_path) # 打开视频文件 147 | self.is_running = True 148 | self.video_thread = threading.Thread(target=self.__video_loop__) 149 | self.video_thread.daemon = True 150 | self.video_thread.start() 151 | self.img_cached = [] # 每次导入清除缓存 152 | 153 | file_name = os.path.basename(file_path) 154 | self.fps = int(self.video_cap.get(cv2.CAP_PROP_FPS)) 155 | self.frame_height = int(self.video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 156 | self.frame_width = int(self.video_cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 157 | self.video_processed_dir = os.path.join(self.video_processed_dir, f'processed_{file_name}') 158 | 159 | def __end_import_video__(self): 160 | """功能: 退出导入视频捕获""" 161 | self.is_running = False 162 | if self.video_cap is not None: 163 | self.video_cap.release() 164 | self.import_video_flag = False 165 | self.video_cap = None 166 | self.video_thread = None 167 | # 恢复为初始状态 168 | self.video_label.config(image=self.main_window_img) 169 | self.video_label.image = self.main_window_img 170 | 171 | def __frame_shot__(self): 172 | """功能: 截取当前视频帧""" 173 | frame_save_dir = 'frames' 174 | if not os.path.exists(frame_save_dir): 175 | os.makedirs(frame_save_dir) 176 | frame_number = len(os.listdir(frame_save_dir)) 177 | img_rgb = cv2.cvtColor(self.frame, cv2.COLOR_BGR2RGB) 178 | cv2.imwrite(f"frames/shot_{frame_number}.jpg", img_rgb) 179 | text = "截取图像: 用户自定义截取主页面图像\n截取图像已存放在frames文件夹。\n图像截取中......" 180 | self.message_label.config(text=text) 181 | 182 | def __track_model__(self): 183 | """功能: 应用模型进行跟踪处理""" 184 | text = "视频跟踪: 对视频进行无人机跟踪\n并将无人机以边框的形势展示。\n无人机视频跟踪中......" 185 | self.message_label.config(text=text) 186 | self.track_video_flag = not self.track_video_flag 187 | 188 | def __video_export__(self): 189 | """功能: 导出当前处理好的用户自定义的视频""" 190 | text = "视频导出: 对用户导入视频跟踪处理\n处理视频存放在processed文件夹。\n无人机视频导出中......" 
191 | self.message_label.config(text=text) 192 | self.export_video_flag = not self.export_video_flag 193 | if self.export_video_flag: 194 | # 如果开启导出,初始化视频类 195 | self.video_processed = cv2.VideoWriter(self.video_processed_dir, self.fourcc, 196 | self.fps, (self.frame_width, self.frame_height)) 197 | if not self.export_video_flag: 198 | self.video_processed.release() 199 | 200 | 201 | def __video_loop__(self): 202 | """程序主界面播放视频""" 203 | while self.is_running and self.video_cap.isOpened(): 204 | success, frame = self.video_cap.read() 205 | if success: 206 | # 实时调用电脑摄像头捕获 207 | if self.live_video_flag: 208 | self.frame = cv2.flip(cv2.resize( cv2.cvtColor(frame, 209 | cv2.COLOR_BGR2RGB), (500, 400)) ,1) 210 | # 跟踪无人机 211 | if self.track_video_flag: 212 | self.frame, _ = decoder(self.yolo_model, self.frame) 213 | # 导出无人机视频画面 214 | if self.export_video_flag: 215 | self.frame = cv2.resize( cv2.cvtColor(self.frame, cv2.COLOR_BGR2RGB), 216 | (self.frame_width, self.frame_height)) 217 | self.video_processed.write(self.frame) 218 | text ="视频跟踪: 对视频进行无人机跟踪\n并将无人机以边框的形势展示。\n无人机视频跟踪中......\n视频导出中......" 219 | self.message_label.config(text=text) 220 | 221 | # 调用用户自定义视频捕获 222 | elif self.import_video_flag: 223 | self.frame = cv2.resize( cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), 224 | (self.frame_width, self.frame_height)) 225 | # 跟踪无人机 226 | if self.track_video_flag: 227 | self.frame, self.xyxy = decoder(self.yolo_model, self.frame) 228 | # if self.xyxy is not None: 229 | # self.lost_frame_num = 0 230 | # if self.last_xyxy is not None: 231 | # # 计算差值 232 | # self.dx = self.last_xyxy[0] - self.xyxy[0] 233 | # self.dy = self.last_xyxy[1] - self.xyxy[1] 234 | # self.last_xyxy = self.xyxy 235 | # if self.xyxy is None: 236 | # self.lost_frame_num += 1 237 | # # 确保坐标为整数 238 | # pt1 = (int(self.last_xyxy[0] + self.dx*self.lost_frame_num), int(self.last_xyxy[1] + self.dy*self.lost_frame_num)) 239 | # pt2 = (int(self.last_xyxy[2] + self.dx*self.lost_frame_num), int(self.last_xyxy[3] + self.dy*self.lost_frame_num)) 240 | # cv2.rectangle(self.frame, pt1, pt2, (0, 255, 0), -1, cv2.LINE_AA) 241 | # # 动态计算tl的值 242 | # tl = round(0.002 * (self.frame.shape[0] + self.frame.shape[1]) / 2) + 1 if round(0.002 * (self.frame.shape[0] + self.frame.shape[1]) / 2) + 1 > 3 else 3 243 | # tf = max(tl - 1, 1) 244 | # cv2.putText(self.frame, f'drone: 0.9', (int(self.last_xyxy[0]), int(self.last_xyxy[1]) - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA) 245 | # 导出无人机视频画面 246 | if self.export_video_flag: 247 | self.frame = cv2.resize( cv2.cvtColor(self.frame, cv2.COLOR_BGR2RGB), 248 | (self.frame_width, self.frame_height)) 249 | self.video_processed.write(self.frame) 250 | text ="视频跟踪: 对视频进行无人机跟踪\n并将无人机以边框的形势展示。\n无人机视频跟踪中......\n视频导出中......" 
251 | self.message_label.config(text=text) 252 | 253 | # 计算直方图以及频谱图并显示 254 | hist_image = ComputeHistogramImage(cv2.cvtColor(self.frame, cv2.COLOR_BGR2RGB)) 255 | spec_image = CalculateSpectrogramImage(self.frame) 256 | self.__show_frame__(hist_image=hist_image, spec_image=spec_image) 257 | else: break 258 | self.root.update_idletasks() 259 | 260 | def __show_frame__(self, hist_image=None, spec_image=None): 261 | """固定在self.video_label上显示视频""" 262 | img_fromarray = Image.fromarray(self.frame) 263 | imgtk = ImageTk.PhotoImage(image=img_fromarray) 264 | self.video_label.config(image=imgtk) 265 | self.video_label.image = imgtk 266 | # 绘制直方图和频谱图 267 | if hist_image is not None: 268 | hist_image = ImageTk.PhotoImage(image=Image.fromarray(hist_image)) 269 | self.histogram_label.config(image=hist_image) 270 | self.histogram_label.image = hist_image 271 | if spec_image is not None: 272 | spec_image = ImageTk.PhotoImage(image=Image.fromarray(spec_image)) 273 | self.spectrogram_label.config(image=spec_image) 274 | self.spectrogram_label.image = spec_image 275 | self.root.after(20) 276 | 277 | 278 | 279 | ################################ 主控测试函数 ############################# 280 | if __name__=='__main__': 281 | root = tk.Tk() 282 | app = OSTrackGUI(root=root) 283 | app.mainloop() -------------------------------------------------------------------------------- /model/ostrack/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Misc functions, including distributed helpers. 4 | 5 | Mostly copy-paste from torchvision references. 6 | """ 7 | import os 8 | import subprocess 9 | import time 10 | from collections import defaultdict, deque 11 | import datetime 12 | import pickle 13 | from typing import Optional, List 14 | 15 | import torch 16 | import torch.distributed as dist 17 | from torch import Tensor 18 | 19 | # needed due to empty tensor bug in pytorch and torchvision 0.5 20 | import torchvision 21 | vers = torchvision.__version__.split('.') 22 | if int(vers[0]) <= 0 and int(vers[1]) < 7: 23 | from torchvision.ops import _new_empty_tensor 24 | from torchvision.ops.misc import _output_size 25 | 26 | 27 | class SmoothedValue(object): 28 | """Track a series of values and provide access to smoothed values over a 29 | window or the global series average. 30 | """ 31 | 32 | def __init__(self, window_size=20, fmt=None): 33 | if fmt is None: 34 | fmt = "{median:.4f} ({global_avg:.4f})" 35 | self.deque = deque(maxlen=window_size) 36 | self.total = 0.0 37 | self.count = 0 38 | self.fmt = fmt 39 | 40 | def update(self, value, n=1): 41 | self.deque.append(value) 42 | self.count += n 43 | self.total += value * n 44 | 45 | def synchronize_between_processes(self): 46 | """ 47 | Warning: does not synchronize the deque! 
48 | """ 49 | if not is_dist_avail_and_initialized(): 50 | return 51 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 52 | dist.barrier() 53 | dist.all_reduce(t) 54 | t = t.tolist() 55 | self.count = int(t[0]) 56 | self.total = t[1] 57 | 58 | @property 59 | def median(self): 60 | d = torch.tensor(list(self.deque)) 61 | return d.median().item() 62 | 63 | @property 64 | def avg(self): 65 | d = torch.tensor(list(self.deque), dtype=torch.float32) 66 | return d.mean().item() 67 | 68 | @property 69 | def global_avg(self): 70 | return self.total / self.count 71 | 72 | @property 73 | def max(self): 74 | return max(self.deque) 75 | 76 | @property 77 | def value(self): 78 | return self.deque[-1] 79 | 80 | def __str__(self): 81 | return self.fmt.format( 82 | median=self.median, 83 | avg=self.avg, 84 | global_avg=self.global_avg, 85 | max=self.max, 86 | value=self.value) 87 | 88 | 89 | def all_gather(data): 90 | """ 91 | Run all_gather on arbitrary picklable data (not necessarily tensors) 92 | Args: 93 | data: any picklable object 94 | Returns: 95 | list[data]: list of data gathered from each rank 96 | """ 97 | world_size = get_world_size() 98 | if world_size == 1: 99 | return [data] 100 | 101 | # serialized to a Tensor 102 | buffer = pickle.dumps(data) 103 | storage = torch.ByteStorage.from_buffer(buffer) 104 | tensor = torch.ByteTensor(storage).to("cuda") 105 | 106 | # obtain Tensor size of each rank 107 | local_size = torch.tensor([tensor.numel()], device="cuda") 108 | size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] 109 | dist.all_gather(size_list, local_size) 110 | size_list = [int(size.item()) for size in size_list] 111 | max_size = max(size_list) 112 | 113 | # receiving Tensor from all ranks 114 | # we pad the tensor because torch all_gather does not support 115 | # gathering tensors of different shapes 116 | tensor_list = [] 117 | for _ in size_list: 118 | tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) 119 | if local_size != max_size: 120 | padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") 121 | tensor = torch.cat((tensor, padding), dim=0) 122 | dist.all_gather(tensor_list, tensor) 123 | 124 | data_list = [] 125 | for size, tensor in zip(size_list, tensor_list): 126 | buffer = tensor.cpu().numpy().tobytes()[:size] 127 | data_list.append(pickle.loads(buffer)) 128 | 129 | return data_list 130 | 131 | 132 | def reduce_dict(input_dict, average=True): 133 | """ 134 | Args: 135 | input_dict (dict): all the values will be reduced 136 | average (bool): whether to do average or sum 137 | Reduce the values in the dictionary from all processes so that all processes 138 | have the averaged results. Returns a dict with the same fields as 139 | input_dict, after reduction. 
140 | """ 141 | world_size = get_world_size() 142 | if world_size < 2: 143 | return input_dict 144 | with torch.no_grad(): 145 | names = [] 146 | values = [] 147 | # sort the keys so that they are consistent across processes 148 | for k in sorted(input_dict.keys()): 149 | names.append(k) 150 | values.append(input_dict[k]) 151 | values = torch.stack(values, dim=0) 152 | dist.all_reduce(values) 153 | if average: 154 | values /= world_size 155 | reduced_dict = {k: v for k, v in zip(names, values)} 156 | return reduced_dict 157 | 158 | 159 | class MetricLogger(object): 160 | def __init__(self, delimiter="\t"): 161 | self.meters = defaultdict(SmoothedValue) 162 | self.delimiter = delimiter 163 | 164 | def update(self, **kwargs): 165 | for k, v in kwargs.items(): 166 | if isinstance(v, torch.Tensor): 167 | v = v.item() 168 | assert isinstance(v, (float, int)) 169 | self.meters[k].update(v) 170 | 171 | def __getattr__(self, attr): 172 | if attr in self.meters: 173 | return self.meters[attr] 174 | if attr in self.__dict__: 175 | return self.__dict__[attr] 176 | raise AttributeError("'{}' object has no attribute '{}'".format( 177 | type(self).__name__, attr)) 178 | 179 | def __str__(self): 180 | loss_str = [] 181 | for name, meter in self.meters.items(): 182 | loss_str.append( 183 | "{}: {}".format(name, str(meter)) 184 | ) 185 | return self.delimiter.join(loss_str) 186 | 187 | def synchronize_between_processes(self): 188 | for meter in self.meters.values(): 189 | meter.synchronize_between_processes() 190 | 191 | def add_meter(self, name, meter): 192 | self.meters[name] = meter 193 | 194 | def log_every(self, iterable, print_freq, header=None): 195 | i = 0 196 | if not header: 197 | header = '' 198 | start_time = time.time() 199 | end = time.time() 200 | iter_time = SmoothedValue(fmt='{avg:.4f}') 201 | data_time = SmoothedValue(fmt='{avg:.4f}') 202 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 203 | if torch.cuda.is_available(): 204 | log_msg = self.delimiter.join([ 205 | header, 206 | '[{0' + space_fmt + '}/{1}]', 207 | 'eta: {eta}', 208 | '{meters}', 209 | 'time: {time}', 210 | 'data: {data}', 211 | 'max mem: {memory:.0f}' 212 | ]) 213 | else: 214 | log_msg = self.delimiter.join([ 215 | header, 216 | '[{0' + space_fmt + '}/{1}]', 217 | 'eta: {eta}', 218 | '{meters}', 219 | 'time: {time}', 220 | 'data: {data}' 221 | ]) 222 | MB = 1024.0 * 1024.0 223 | for obj in iterable: 224 | data_time.update(time.time() - end) 225 | yield obj 226 | iter_time.update(time.time() - end) 227 | if i % print_freq == 0 or i == len(iterable) - 1: 228 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 229 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 230 | if torch.cuda.is_available(): 231 | print(log_msg.format( 232 | i, len(iterable), eta=eta_string, 233 | meters=str(self), 234 | time=str(iter_time), data=str(data_time), 235 | memory=torch.cuda.max_memory_allocated() / MB)) 236 | else: 237 | print(log_msg.format( 238 | i, len(iterable), eta=eta_string, 239 | meters=str(self), 240 | time=str(iter_time), data=str(data_time))) 241 | i += 1 242 | end = time.time() 243 | total_time = time.time() - start_time 244 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 245 | print('{} Total time: {} ({:.4f} s / it)'.format( 246 | header, total_time_str, total_time / len(iterable))) 247 | 248 | 249 | def get_sha(): 250 | cwd = os.path.dirname(os.path.abspath(__file__)) 251 | 252 | def _run(command): 253 | return subprocess.check_output(command, cwd=cwd).decode('ascii').strip() 
254 | sha = 'N/A' 255 | diff = "clean" 256 | branch = 'N/A' 257 | try: 258 | sha = _run(['git', 'rev-parse', 'HEAD']) 259 | subprocess.check_output(['git', 'diff'], cwd=cwd) 260 | diff = _run(['git', 'diff-index', 'HEAD']) 261 | diff = "has uncommited changes" if diff else "clean" 262 | branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) 263 | except Exception: 264 | pass 265 | message = f"sha: {sha}, status: {diff}, branch: {branch}" 266 | return message 267 | 268 | 269 | def collate_fn(batch): 270 | batch = list(zip(*batch)) 271 | batch[0] = nested_tensor_from_tensor_list(batch[0]) 272 | return tuple(batch) 273 | 274 | 275 | def _max_by_axis(the_list): 276 | # type: (List[List[int]]) -> List[int] 277 | maxes = the_list[0] # get the first one 278 | for sublist in the_list[1:]: # [h,w,3] 279 | for index, item in enumerate(sublist): # index: 0,1,2 280 | maxes[index] = max(maxes[index], item) # compare current max with the other elements in the whole 281 | return maxes 282 | 283 | 284 | class NestedTensor(object): 285 | def __init__(self, tensors, mask: Optional[Tensor]): 286 | self.tensors = tensors 287 | self.mask = mask 288 | 289 | def to(self, device): 290 | cast_tensor = self.tensors.to(device) 291 | mask = self.mask 292 | if mask is not None: 293 | assert mask is not None 294 | cast_mask = mask.to(device) 295 | else: 296 | cast_mask = None 297 | return NestedTensor(cast_tensor, cast_mask) 298 | 299 | def decompose(self): 300 | return self.tensors, self.mask 301 | 302 | def __repr__(self): 303 | return str(self.tensors) 304 | 305 | 306 | def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): 307 | # TODO make this more general 308 | if tensor_list[0].ndim == 3: 309 | if torchvision._is_tracing(): 310 | # nested_tensor_from_tensor_list() does not export well to ONNX 311 | # call _onnx_nested_tensor_from_tensor_list() instead 312 | return _onnx_nested_tensor_from_tensor_list(tensor_list) 313 | 314 | # TODO make it support different-sized images 315 | max_size = _max_by_axis([list(img.shape) for img in tensor_list]) # [[3,h1,w1], [3,h2,w2], [3,h3,w3], ...] 316 | # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) 317 | batch_shape = [len(tensor_list)] + max_size # () 318 | b, c, h, w = batch_shape 319 | dtype = tensor_list[0].dtype 320 | device = tensor_list[0].device 321 | tensor = torch.zeros(batch_shape, dtype=dtype, device=device) 322 | mask = torch.ones((b, h, w), dtype=torch.bool, device=device) 323 | for img, pad_img, m in zip(tensor_list, tensor, mask): 324 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) # copy valid regions of the images to the largest padded base. 325 | m[: img.shape[1], :img.shape[2]] = False 326 | else: 327 | raise ValueError('not supported') 328 | return NestedTensor(tensor, mask) 329 | 330 | 331 | # _onnx_nested_tensor_from_tensor_list() is an implementation of 332 | # nested_tensor_from_tensor_list() that is supported by ONNX tracing. 
333 | @torch.jit.unused 334 | def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor: 335 | max_size = [] 336 | for i in range(tensor_list[0].dim()): 337 | max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64) 338 | max_size.append(max_size_i) 339 | max_size = tuple(max_size) 340 | 341 | # work around for 342 | # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 343 | # m[: img.shape[1], :img.shape[2]] = False 344 | # which is not yet supported in onnx 345 | padded_imgs = [] 346 | padded_masks = [] 347 | for img in tensor_list: 348 | padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] 349 | padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) 350 | padded_imgs.append(padded_img) 351 | 352 | m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) 353 | padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1) 354 | padded_masks.append(padded_mask.to(torch.bool)) 355 | 356 | tensor = torch.stack(padded_imgs) 357 | mask = torch.stack(padded_masks) 358 | 359 | return NestedTensor(tensor, mask=mask) 360 | 361 | 362 | def setup_for_distributed(is_master): 363 | """ 364 | This function disables printing when not in master process 365 | """ 366 | import builtins as __builtin__ 367 | builtin_print = __builtin__.print 368 | 369 | def print(*args, **kwargs): 370 | force = kwargs.pop('force', False) 371 | if is_master or force: 372 | builtin_print(*args, **kwargs) 373 | 374 | __builtin__.print = print 375 | 376 | 377 | def is_dist_avail_and_initialized(): 378 | if not dist.is_available(): 379 | return False 380 | if not dist.is_initialized(): 381 | return False 382 | return True 383 | 384 | 385 | def get_world_size(): 386 | if not is_dist_avail_and_initialized(): 387 | return 1 388 | return dist.get_world_size() 389 | 390 | 391 | def get_rank(): 392 | if not is_dist_avail_and_initialized(): 393 | return 0 394 | return dist.get_rank() 395 | 396 | 397 | def is_main_process(): 398 | return get_rank() == 0 399 | 400 | 401 | def save_on_master(*args, **kwargs): 402 | if is_main_process(): 403 | torch.save(*args, **kwargs) 404 | 405 | 406 | def init_distributed_mode(args): 407 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 408 | args.rank = int(os.environ["RANK"]) 409 | args.world_size = int(os.environ['WORLD_SIZE']) 410 | args.gpu = int(os.environ['LOCAL_RANK']) 411 | elif 'SLURM_PROCID' in os.environ: 412 | args.rank = int(os.environ['SLURM_PROCID']) 413 | args.gpu = args.rank % torch.cuda.device_count() 414 | else: 415 | print('Not using distributed mode') 416 | args.distributed = False 417 | return 418 | 419 | args.distributed = True 420 | 421 | torch.cuda.set_device(args.gpu) 422 | args.dist_backend = 'nccl' 423 | print('| distributed init (rank {}): {}'.format( 424 | args.rank, args.dist_url), flush=True) 425 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 426 | world_size=args.world_size, rank=args.rank) 427 | torch.distributed.barrier() 428 | setup_for_distributed(args.rank == 0) 429 | 430 | 431 | @torch.no_grad() 432 | def accuracy(output, target, topk=(1,)): 433 | """Computes the precision@k for the specified values of k""" 434 | if target.numel() == 0: 435 | return [torch.zeros([], device=output.device)] 436 | maxk = max(topk) 437 | batch_size = target.size(0) 438 | 439 | _, pred = output.topk(maxk, 1, True, True) 440 | pred = pred.t() 441 | correct = 
pred.eq(target.view(1, -1).expand_as(pred)) 442 | 443 | res = [] 444 | for k in topk: 445 | correct_k = correct[:k].view(-1).float().sum(0) 446 | res.append(correct_k.mul_(100.0 / batch_size)) 447 | return res 448 | 449 | 450 | def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): 451 | # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor 452 | """ 453 | Equivalent to nn.functional.interpolate, but with support for empty batch sizes. 454 | This will eventually be supported natively by PyTorch, and this 455 | class can go away. 456 | """ 457 | if float(torchvision.__version__[:3]) < 0.7: 458 | if input.numel() > 0: 459 | return torch.nn.functional.interpolate( 460 | input, size, scale_factor, mode, align_corners 461 | ) 462 | 463 | output_shape = _output_size(2, input, size, scale_factor) 464 | output_shape = list(input.shape[:-2]) + list(output_shape) 465 | return _new_empty_tensor(input, output_shape) 466 | else: 467 | return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners) 468 | -------------------------------------------------------------------------------- /model/ostrack/vit.py: -------------------------------------------------------------------------------- 1 | import math 2 | import logging 3 | from functools import partial 4 | from collections import OrderedDict 5 | from copy import deepcopy 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD 12 | from timm.models.helpers import build_model_with_cfg, named_apply, adapt_input_conv 13 | from timm.models.layers import Mlp, DropPath, trunc_normal_, lecun_normal_ 14 | from timm.models.registry import register_model 15 | 16 | from .layers.patch_embed import PatchEmbed 17 | from .base_backbone import BaseBackbone 18 | 19 | 20 | class Attention(nn.Module): 21 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): 22 | super().__init__() 23 | self.num_heads = num_heads 24 | head_dim = dim // num_heads 25 | self.scale = head_dim ** -0.5 26 | 27 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 28 | self.attn_drop = nn.Dropout(attn_drop) 29 | self.proj = nn.Linear(dim, dim) 30 | self.proj_drop = nn.Dropout(proj_drop) 31 | 32 | def forward(self, x, return_attention=False): 33 | B, N, C = x.shape 34 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 35 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 36 | 37 | attn = (q @ k.transpose(-2, -1)) * self.scale 38 | attn = attn.softmax(dim=-1) 39 | attn = self.attn_drop(attn) 40 | 41 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 42 | x = self.proj(x) 43 | x = self.proj_drop(x) 44 | 45 | if return_attention: 46 | return x, attn 47 | return x 48 | 49 | 50 | class Block(nn.Module): 51 | 52 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., 53 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 54 | super().__init__() 55 | self.norm1 = norm_layer(dim) 56 | self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) 57 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 58 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 59 | self.norm2 = norm_layer(dim) 60 | mlp_hidden_dim = int(dim * mlp_ratio) 61 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 62 | 63 | def forward(self, x, return_attention=False): 64 | if return_attention: 65 | feat, attn = self.attn(self.norm1(x), True) 66 | x = x + self.drop_path(feat) 67 | x = x + self.drop_path(self.mlp(self.norm2(x))) 68 | return x, attn 69 | else: 70 | x = x + self.drop_path(self.attn(self.norm1(x))) 71 | x = x + self.drop_path(self.mlp(self.norm2(x))) 72 | return x 73 | 74 | 75 | class VisionTransformer(BaseBackbone): 76 | """ Vision Transformer 77 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` 78 | - https://arxiv.org/abs/2010.11929 79 | Includes distillation token & head support for `DeiT: Data-efficient Image Transformers` 80 | - https://arxiv.org/abs/2012.12877 81 | """ 82 | 83 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 84 | num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False, 85 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None, 86 | act_layer=None, weight_init=''): 87 | """ 88 | Args: 89 | img_size (int, tuple): input image size 90 | patch_size (int, tuple): patch size 91 | in_chans (int): number of input channels 92 | num_classes (int): number of classes for classification head 93 | embed_dim (int): embedding dimension 94 | depth (int): depth of transformer 95 | num_heads (int): number of attention heads 96 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 97 | qkv_bias (bool): enable bias for qkv if True 98 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 99 | distilled (bool): model includes a distillation token and head as in DeiT models 100 | drop_rate (float): dropout rate 101 | attn_drop_rate (float): attention dropout rate 102 | drop_path_rate (float): stochastic depth rate 103 | embed_layer (nn.Module): patch embedding layer 104 | norm_layer: (nn.Module): normalization layer 105 | weight_init: (str): weight init scheme 106 | """ 107 | super().__init__() 108 | self.num_classes = num_classes 109 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 110 | self.num_tokens = 2 if distilled else 1 111 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 112 | act_layer = act_layer or nn.GELU 113 | 114 | self.patch_embed = embed_layer( 115 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 116 | num_patches = self.patch_embed.num_patches 117 | 118 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 119 | self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None 120 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) 121 | self.pos_drop = nn.Dropout(p=drop_rate) 122 | 123 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 124 | self.blocks = nn.Sequential(*[ 125 | Block( 126 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, 127 | attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer) 128 | for i in range(depth)]) 129 | self.norm = norm_layer(embed_dim) 130 | 131 | # # Representation layer 132 | # if representation_size and not distilled: 133 | # 
self.num_features = representation_size 134 | # self.pre_logits = nn.Sequential(OrderedDict([ 135 | # ('fc', nn.Linear(embed_dim, representation_size)), 136 | # ('act', nn.Tanh()) 137 | # ])) 138 | # else: 139 | # self.pre_logits = nn.Identity() 140 | # 141 | # # Classifier head(s) 142 | # self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 143 | # self.head_dist = None 144 | # if distilled: 145 | # self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() 146 | 147 | self.init_weights(weight_init) 148 | 149 | def init_weights(self, mode=''): 150 | assert mode in ('jax', 'jax_nlhb', 'nlhb', '') 151 | head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. 152 | trunc_normal_(self.pos_embed, std=.02) 153 | if self.dist_token is not None: 154 | trunc_normal_(self.dist_token, std=.02) 155 | if mode.startswith('jax'): 156 | # leave cls token as zeros to match jax impl 157 | named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl=True), self) 158 | else: 159 | trunc_normal_(self.cls_token, std=.02) 160 | self.apply(_init_vit_weights) 161 | 162 | def _init_weights(self, m): 163 | # this fn left here for compat with downstream users 164 | _init_vit_weights(m) 165 | 166 | @torch.jit.ignore() 167 | def load_pretrained(self, checkpoint_path, prefix=''): 168 | _load_weights(self, checkpoint_path, prefix) 169 | 170 | @torch.jit.ignore 171 | def no_weight_decay(self): 172 | return {'pos_embed', 'cls_token', 'dist_token'} 173 | 174 | def get_classifier(self): 175 | if self.dist_token is None: 176 | return self.head 177 | else: 178 | return self.head, self.head_dist 179 | 180 | def reset_classifier(self, num_classes, global_pool=''): 181 | self.num_classes = num_classes 182 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 183 | if self.num_tokens == 2: 184 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() 185 | 186 | 187 | def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False): 188 | """ ViT weight initialization 189 | * When called without n, head_bias, jax_impl args it will behave exactly the same 190 | as my original init for compatibility with prev hparam / downstream use cases (ie DeiT). 
191 | * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl 192 | """ 193 | if isinstance(module, nn.Linear): 194 | if name.startswith('head'): 195 | nn.init.zeros_(module.weight) 196 | nn.init.constant_(module.bias, head_bias) 197 | elif name.startswith('pre_logits'): 198 | lecun_normal_(module.weight) 199 | nn.init.zeros_(module.bias) 200 | else: 201 | if jax_impl: 202 | nn.init.xavier_uniform_(module.weight) 203 | if module.bias is not None: 204 | if 'mlp' in name: 205 | nn.init.normal_(module.bias, std=1e-6) 206 | else: 207 | nn.init.zeros_(module.bias) 208 | else: 209 | trunc_normal_(module.weight, std=.02) 210 | if module.bias is not None: 211 | nn.init.zeros_(module.bias) 212 | elif jax_impl and isinstance(module, nn.Conv2d): 213 | # NOTE conv was left to pytorch default in my original init 214 | lecun_normal_(module.weight) 215 | if module.bias is not None: 216 | nn.init.zeros_(module.bias) 217 | elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)): 218 | nn.init.zeros_(module.bias) 219 | nn.init.ones_(module.weight) 220 | 221 | 222 | @torch.no_grad() 223 | def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''): 224 | """ Load weights from .npz checkpoints for official Google Brain Flax implementation 225 | """ 226 | import numpy as np 227 | 228 | def _n2p(w, t=True): 229 | if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: 230 | w = w.flatten() 231 | if t: 232 | if w.ndim == 4: 233 | w = w.transpose([3, 2, 0, 1]) 234 | elif w.ndim == 3: 235 | w = w.transpose([2, 0, 1]) 236 | elif w.ndim == 2: 237 | w = w.transpose([1, 0]) 238 | return torch.from_numpy(w) 239 | 240 | w = np.load(checkpoint_path) 241 | if not prefix and 'opt/target/embedding/kernel' in w: 242 | prefix = 'opt/target/' 243 | 244 | if hasattr(model.patch_embed, 'backbone'): 245 | # hybrid 246 | backbone = model.patch_embed.backbone 247 | stem_only = not hasattr(backbone, 'stem') 248 | stem = backbone if stem_only else backbone.stem 249 | stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) 250 | stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) 251 | stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) 252 | if not stem_only: 253 | for i, stage in enumerate(backbone.stages): 254 | for j, block in enumerate(stage.blocks): 255 | bp = f'{prefix}block{i + 1}/unit{j + 1}/' 256 | for r in range(3): 257 | getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) 258 | getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) 259 | getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) 260 | if block.downsample is not None: 261 | block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) 262 | block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) 263 | block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) 264 | embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) 265 | else: 266 | embed_conv_w = adapt_input_conv( 267 | model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) 268 | model.patch_embed.proj.weight.copy_(embed_conv_w) 269 | model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) 270 | model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) 271 | pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) 272 | if pos_embed_w.shape != model.pos_embed.shape: 273 | pos_embed_w = resize_pos_embed( # resize pos 
embedding when different size from pretrained weights 274 | pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) 275 | model.pos_embed.copy_(pos_embed_w) 276 | model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) 277 | model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) 278 | if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: 279 | model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) 280 | model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) 281 | if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: 282 | model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) 283 | model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) 284 | for i, block in enumerate(model.blocks.children()): 285 | block_prefix = f'{prefix}Transformer/encoderblock_{i}/' 286 | mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' 287 | block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) 288 | block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) 289 | block.attn.qkv.weight.copy_(torch.cat([ 290 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) 291 | block.attn.qkv.bias.copy_(torch.cat([ 292 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) 293 | block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) 294 | block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) 295 | for r in range(2): 296 | getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) 297 | getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) 298 | block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) 299 | block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) 300 | 301 | 302 | def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()): 303 | # Rescale the grid of position embeddings when loading from state_dict. 
Adapted from 304 | # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 305 | print('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) 306 | ntok_new = posemb_new.shape[1] 307 | if num_tokens: 308 | posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:] 309 | ntok_new -= num_tokens 310 | else: 311 | posemb_tok, posemb_grid = posemb[:, :0], posemb[0] 312 | gs_old = int(math.sqrt(len(posemb_grid))) 313 | if not len(gs_new): # backwards compatibility 314 | gs_new = [int(math.sqrt(ntok_new))] * 2 315 | assert len(gs_new) >= 2 316 | print('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new) 317 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) 318 | posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bilinear') 319 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1) 320 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1) 321 | return posemb 322 | 323 | 324 | def checkpoint_filter_fn(state_dict, model): 325 | """ convert patch embedding weight from manual patchify + linear proj to conv""" 326 | out_dict = {} 327 | if 'model' in state_dict: 328 | # For deit models 329 | state_dict = state_dict['model'] 330 | for k, v in state_dict.items(): 331 | if 'patch_embed.proj.weight' in k and len(v.shape) < 4: 332 | # For old models that I trained prior to conv based patchification 333 | O, I, H, W = model.patch_embed.proj.weight.shape 334 | v = v.reshape(O, -1, H, W) 335 | elif k == 'pos_embed' and v.shape != model.pos_embed.shape: 336 | # To resize pos embedding when using model at different size from pretrained weights 337 | v = resize_pos_embed( 338 | v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) 339 | out_dict[k] = v 340 | return out_dict 341 | 342 | 343 | def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs): 344 | if kwargs.get('features_only', None): 345 | raise RuntimeError('features_only not implemented for Vision Transformer models.') 346 | 347 | model = VisionTransformer(**kwargs) 348 | 349 | if pretrained: 350 | if 'npz' in pretrained: 351 | model.load_pretrained(pretrained, prefix='') 352 | else: 353 | checkpoint = torch.load(pretrained, map_location="cpu") 354 | missing_keys, unexpected_keys = model.load_state_dict(checkpoint["model"], strict=False) 355 | print('Load pretrained model from: ' + pretrained) 356 | 357 | return model 358 | 359 | 360 | def vit_base_patch16_224(pretrained=False, **kwargs): 361 | """ 362 | ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 363 | """ 364 | model_kwargs = dict( 365 | patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 366 | model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs) 367 | return model 368 | --------------------------------------------------------------------------------
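A minimal usage sketch of the attention blocks defined in vit.py above. The import path model.ostrack.vit is inferred from the file location; it assumes the repository root is on PYTHONPATH and a timm version whose timm.models.helpers / timm.models.registry imports still resolve (e.g. timm 0.5.x/0.6.x). The shapes follow directly from the Attention and Block code.

import torch
from model.ostrack.vit import Block  # assumed import path

# One transformer block at ViT-B dimensions (768-d embeddings, 12 heads).
blk = Block(dim=768, num_heads=12, mlp_ratio=4., qkv_bias=True)
tokens = torch.randn(2, 196, 768)               # (batch, tokens, embed_dim), e.g. a 14x14 patch grid
out, attn = blk(tokens, return_attention=True)  # in this mode Block returns (features, attention)
print(out.shape)    # torch.Size([2, 196, 768])      residual stream, shape preserved
print(attn.shape)   # torch.Size([2, 12, 196, 196])  per-head attention maps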
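resize_pos_embed splits off the class/distillation tokens, reshapes the remaining embeddings into a 2-D grid, bilinearly interpolates that grid to the new resolution, and re-concatenates the tokens (the %-style placeholders in its print calls are not interpolated, so the log lines look odd but are harmless). A small worked sketch under the same import assumption as above, going from a 224-input checkpoint (14x14 grid) to a 384 input (24x24 grid):

import torch
from model.ostrack.vit import resize_pos_embed  # assumed import path

posemb_224 = torch.randn(1, 1 + 14 * 14, 768)   # cls token + 14x14 grid, as in a ViT-B/16 @ 224 checkpoint
posemb_384 = torch.zeros(1, 1 + 24 * 24, 768)   # target layout for a 384x384 input
resized = resize_pos_embed(posemb_224, posemb_384, num_tokens=1)
print(resized.shape)                            # torch.Size([1, 577, 768])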
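The accuracy helper in the misc utilities computes top-k precision from raw logits. A tiny worked example; the module path is an assumption, and the sketch sticks to top-1 because on recent PyTorch releases the correct[:k].view(-1) line can raise a stride error for k > 1 (where .reshape(-1) is the usual fix).

import torch
from model.ostrack.utils.misc import accuracy  # assumed module path for the helpers above

logits = torch.tensor([[0.1, 0.7, 0.2],
                       [0.5, 0.1, 0.4]])   # two samples, three classes
target = torch.tensor([1, 2])              # sample 0 is predicted correctly, sample 1 is not
(top1,) = accuracy(logits, target, topk=(1,))
print(top1.item())                         # 50.0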
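The distributed helpers (init_distributed_mode, setup_for_distributed, save_on_master, ...) are driven entirely by environment variables. A hedged sketch of how a training entry point might use them: the module path, the torchrun launch line, and the checkpoint name are assumptions, and the distributed branch additionally needs one CUDA device per process.

# Hypothetical launch: torchrun --nproc_per_node=2 train.py
import argparse
from model.ostrack.utils.misc import init_distributed_mode, get_rank, get_world_size, save_on_master  # assumed path

args = argparse.Namespace(dist_url='env://')   # init_process_group reads RANK/WORLD_SIZE/MASTER_* from the env
init_distributed_mode(args)                    # sets args.distributed, args.rank, args.gpu; falls back to single-process
if args.distributed:
    # setup_for_distributed() has already silenced print() on non-master ranks at this point
    print(f'initialized rank {get_rank()} of {get_world_size()}')
save_on_master({'step': 0}, 'checkpoint.pth')  # only rank 0 (or the single process) writes the file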