├── model └── ostrack │ ├── layers │ ├── __init__.py │ ├── __pycache__ │ │ ├── attn.cpython-311.pyc │ │ ├── head.cpython-311.pyc │ │ ├── rpe.cpython-311.pyc │ │ ├── __init__.cpython-311.pyc │ │ ├── frozen_bn.cpython-311.pyc │ │ ├── attn_blocks.cpython-311.pyc │ │ └── patch_embed.cpython-311.pyc │ ├── patch_embed.py │ ├── frozen_bn.py │ ├── rpe.py │ ├── attn.py │ ├── attn_blocks.py │ └── head.py │ ├── utils │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-311.pyc │ │ ├── box_ops.cpython-311.pyc │ │ └── tensor.cpython-311.pyc │ ├── merge.py │ ├── variable_hook.py │ ├── lmdb_utils.py │ ├── focal_loss.py │ ├── box_ops.py │ ├── ce_utils.py │ ├── heapmap_utils.py │ ├── tensor.py │ └── misc.py │ ├── __pycache__ │ ├── util.cpython-311.pyc │ ├── vit.cpython-311.pyc │ ├── vit_ce.cpython-311.pyc │ ├── ostrack.cpython-311.pyc │ └── base_backbone.cpython-311.pyc │ ├── util.py │ ├── ostrack.py │ ├── base_backbone.py │ ├── vit_ce.py │ └── vit.py ├── __init__.py ├── assets ├── uav_1.jpg ├── uav_2.jpg ├── uav_3.jpg ├── OSTrack.jpg ├── bandicam.gif ├── infrared_5.gif ├── cover_image.jpg ├── uav_1_result.jpg ├── results │ ├── P_curve.jpg │ ├── R_curve.jpg │ ├── labels.jpg │ ├── paths.jpg │ ├── results.jpg │ ├── F1_curve.jpg │ ├── PR_curve.jpg │ ├── train_batch0.jpg │ ├── train_batch1.jpg │ ├── train_batch2.jpg │ ├── confusion_matrix.jpg │ ├── val_batch0_labels.jpg │ ├── val_batch0_pred.jpg │ ├── val_batch1_labels.jpg │ ├── val_batch1_pred.jpg │ ├── val_batch2_labels.jpg │ ├── val_batch2_pred.jpg │ └── labels_correlogram.jpg ├── ostrack_test0320.jpg ├── exception │ ├── exception1.gif │ └── exception2.gif ├── processed_infrared_5.gif └── architecture │ ├── ostrack_1.jpg │ ├── ostrack_2.jpg │ ├── ostrack_3.jpg │ ├── ostrack_4.jpg │ ├── ostrack_5.jpg │ └── ostrack_6.jpg ├── frames ├── shot_0.jpg ├── shot_1.jpg ├── shot_2.jpg ├── shot_3.jpg ├── shot_4.jpg └── shot_5.jpg ├── images └── main_window_img.png ├── .gitmodules ├── __pycache__ ├── bbox.cpython-311.pyc ├── util.cpython-311.pyc ├── utils.cpython-311.pyc ├── yolo.cpython-311.pyc ├── yolov5.cpython-311.pyc └── yolo_model.cpython-311.pyc ├── config ├── got_10k.yaml ├── vitb_384_mae_32x4_ep300.yaml ├── vitb_256_mae_32x4_ep300.yaml ├── vitb_256_mae_ce_32x4_got10k_ep100.yaml ├── vitb_384_mae_ce_32x4_ep300.yaml ├── vitb_384_mae_ce_32x4_got10k_ep100.yaml └── vitb_256_mae_ce_32x4_ep300.yaml ├── .gitignore ├── LICENSE ├── bbox.py ├── yolo_model.py ├── util.py ├── README.md └── app_osyo.py /model/ostrack/layers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tkinter as tk 3 | import numpy as np 4 | -------------------------------------------------------------------------------- /model/ostrack/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .tensor import TensorDict, TensorList 2 | -------------------------------------------------------------------------------- /assets/uav_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/uav_1.jpg -------------------------------------------------------------------------------- /assets/uav_2.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/uav_2.jpg -------------------------------------------------------------------------------- /assets/uav_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/uav_3.jpg -------------------------------------------------------------------------------- /assets/OSTrack.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/OSTrack.jpg -------------------------------------------------------------------------------- /frames/shot_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/frames/shot_0.jpg -------------------------------------------------------------------------------- /frames/shot_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/frames/shot_1.jpg -------------------------------------------------------------------------------- /frames/shot_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/frames/shot_2.jpg -------------------------------------------------------------------------------- /frames/shot_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/frames/shot_3.jpg -------------------------------------------------------------------------------- /frames/shot_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/frames/shot_4.jpg -------------------------------------------------------------------------------- /frames/shot_5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/frames/shot_5.jpg -------------------------------------------------------------------------------- /assets/bandicam.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/bandicam.gif -------------------------------------------------------------------------------- /assets/infrared_5.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/infrared_5.gif -------------------------------------------------------------------------------- /assets/cover_image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/cover_image.jpg -------------------------------------------------------------------------------- /assets/uav_1_result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/uav_1_result.jpg -------------------------------------------------------------------------------- /assets/results/P_curve.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/results/P_curve.jpg -------------------------------------------------------------------------------- /assets/results/R_curve.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/results/R_curve.jpg -------------------------------------------------------------------------------- /assets/results/labels.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/results/labels.jpg -------------------------------------------------------------------------------- /assets/results/paths.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/results/paths.jpg -------------------------------------------------------------------------------- /assets/results/results.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/results/results.jpg -------------------------------------------------------------------------------- /images/main_window_img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/images/main_window_img.png -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "yolov5"] 2 | path = yolov5 3 | url = https://github.com/ultralytics/yolov5.git 4 | -------------------------------------------------------------------------------- /assets/ostrack_test0320.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/ostrack_test0320.jpg -------------------------------------------------------------------------------- /assets/results/F1_curve.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/results/F1_curve.jpg -------------------------------------------------------------------------------- /assets/results/PR_curve.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/results/PR_curve.jpg -------------------------------------------------------------------------------- /assets/exception/exception1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/exception/exception1.gif -------------------------------------------------------------------------------- /assets/exception/exception2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/exception/exception2.gif -------------------------------------------------------------------------------- /assets/processed_infrared_5.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/processed_infrared_5.gif -------------------------------------------------------------------------------- /assets/results/train_batch0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/results/train_batch0.jpg 
-------------------------------------------------------------------------------- /assets/results/train_batch1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/results/train_batch1.jpg -------------------------------------------------------------------------------- /assets/results/train_batch2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/results/train_batch2.jpg -------------------------------------------------------------------------------- /__pycache__/bbox.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/__pycache__/bbox.cpython-311.pyc -------------------------------------------------------------------------------- /__pycache__/util.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/__pycache__/util.cpython-311.pyc -------------------------------------------------------------------------------- /__pycache__/utils.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/__pycache__/utils.cpython-311.pyc -------------------------------------------------------------------------------- /__pycache__/yolo.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/__pycache__/yolo.cpython-311.pyc -------------------------------------------------------------------------------- /assets/architecture/ostrack_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/architecture/ostrack_1.jpg -------------------------------------------------------------------------------- /assets/architecture/ostrack_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/architecture/ostrack_2.jpg -------------------------------------------------------------------------------- /assets/architecture/ostrack_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/architecture/ostrack_3.jpg -------------------------------------------------------------------------------- /assets/architecture/ostrack_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/architecture/ostrack_4.jpg -------------------------------------------------------------------------------- /assets/architecture/ostrack_5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/architecture/ostrack_5.jpg -------------------------------------------------------------------------------- /assets/architecture/ostrack_6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/architecture/ostrack_6.jpg -------------------------------------------------------------------------------- /config/got_10k.yaml: 
-------------------------------------------------------------------------------- 1 | train: ./datasets/images/train 2 | val: ./datasets/images/val 3 | 4 | nc: 1 5 | names: ['drone'] -------------------------------------------------------------------------------- /__pycache__/yolov5.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/__pycache__/yolov5.cpython-311.pyc -------------------------------------------------------------------------------- /assets/results/confusion_matrix.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/results/confusion_matrix.jpg -------------------------------------------------------------------------------- /assets/results/val_batch0_labels.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/results/val_batch0_labels.jpg -------------------------------------------------------------------------------- /assets/results/val_batch0_pred.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/results/val_batch0_pred.jpg -------------------------------------------------------------------------------- /assets/results/val_batch1_labels.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/results/val_batch1_labels.jpg -------------------------------------------------------------------------------- /assets/results/val_batch1_pred.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/results/val_batch1_pred.jpg -------------------------------------------------------------------------------- /assets/results/val_batch2_labels.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/results/val_batch2_labels.jpg -------------------------------------------------------------------------------- /assets/results/val_batch2_pred.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/results/val_batch2_pred.jpg -------------------------------------------------------------------------------- /__pycache__/yolo_model.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/__pycache__/yolo_model.cpython-311.pyc -------------------------------------------------------------------------------- /assets/results/labels_correlogram.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/assets/results/labels_correlogram.jpg -------------------------------------------------------------------------------- /model/ostrack/__pycache__/util.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/model/ostrack/__pycache__/util.cpython-311.pyc -------------------------------------------------------------------------------- /model/ostrack/__pycache__/vit.cpython-311.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/model/ostrack/__pycache__/vit.cpython-311.pyc -------------------------------------------------------------------------------- /model/ostrack/__pycache__/vit_ce.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/model/ostrack/__pycache__/vit_ce.cpython-311.pyc -------------------------------------------------------------------------------- /model/ostrack/__pycache__/ostrack.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/model/ostrack/__pycache__/ostrack.cpython-311.pyc -------------------------------------------------------------------------------- /model/ostrack/layers/__pycache__/attn.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/model/ostrack/layers/__pycache__/attn.cpython-311.pyc -------------------------------------------------------------------------------- /model/ostrack/layers/__pycache__/head.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/model/ostrack/layers/__pycache__/head.cpython-311.pyc -------------------------------------------------------------------------------- /model/ostrack/layers/__pycache__/rpe.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/model/ostrack/layers/__pycache__/rpe.cpython-311.pyc -------------------------------------------------------------------------------- /model/ostrack/__pycache__/base_backbone.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/model/ostrack/__pycache__/base_backbone.cpython-311.pyc -------------------------------------------------------------------------------- /model/ostrack/utils/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/model/ostrack/utils/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /model/ostrack/utils/__pycache__/box_ops.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/model/ostrack/utils/__pycache__/box_ops.cpython-311.pyc -------------------------------------------------------------------------------- /model/ostrack/utils/__pycache__/tensor.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/model/ostrack/utils/__pycache__/tensor.cpython-311.pyc -------------------------------------------------------------------------------- /model/ostrack/layers/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/model/ostrack/layers/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- 
/model/ostrack/layers/__pycache__/frozen_bn.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/model/ostrack/layers/__pycache__/frozen_bn.cpython-311.pyc -------------------------------------------------------------------------------- /model/ostrack/layers/__pycache__/attn_blocks.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/model/ostrack/layers/__pycache__/attn_blocks.cpython-311.pyc -------------------------------------------------------------------------------- /model/ostrack/layers/__pycache__/patch_embed.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rtwotwo/OSTrack/HEAD/model/ostrack/layers/__pycache__/patch_embed.cpython-311.pyc -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Use git reset to update .gitignore 2 | 3 | # Do not upload the original project files 4 | MCJT/ 5 | 6 | # Ignore weight files 7 | weights/ 8 | 9 | # Do not upload datasets 10 | datasets/ 11 | 12 | # Ignore log files 13 | logs/ 14 | 15 | # Ignore test videos 16 | video/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | Copyright (c) 2025 Yolo Redal 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /model/ostrack/utils/merge.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def merge_template_search(inp_list, return_search=False, return_template=False): 5 | """NOTICE: search region related features must be in the last place""" 6 | seq_dict = {"feat": torch.cat([x["feat"] for x in inp_list], dim=0), 7 | "mask": torch.cat([x["mask"] for x in inp_list], dim=1), 8 | "pos": torch.cat([x["pos"] for x in inp_list], dim=0)} 9 | if return_search: 10 | x = inp_list[-1] 11 | seq_dict.update({"feat_x": x["feat"], "mask_x": x["mask"], "pos_x": x["pos"]}) 12 | if return_template: 13 | z = inp_list[0] 14 | seq_dict.update({"feat_z": z["feat"], "mask_z": z["mask"], "pos_z": z["pos"]}) 15 | return seq_dict 16 | 17 | 18 | def get_qkv(inp_list): 19 | """The 1st element of the inp_list is about the template, 20 | the 2nd (the last) element is about the search region""" 21 | dict_x = inp_list[-1] 22 | dict_c = {"feat": torch.cat([x["feat"] for x in inp_list], dim=0), 23 | "mask": torch.cat([x["mask"] for x in inp_list], dim=1), 24 | "pos": torch.cat([x["pos"] for x in inp_list], dim=0)} # concatenated dict 25 | q = dict_x["feat"] + dict_x["pos"] 26 | k = dict_c["feat"] + dict_c["pos"] 27 | v = dict_c["feat"] 28 | key_padding_mask = dict_c["mask"] 29 | return q, k, v, key_padding_mask 30 | -------------------------------------------------------------------------------- /model/ostrack/layers/patch_embed.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from timm.models.layers import to_2tuple 4 | 5 | 6 | class PatchEmbed(nn.Module): 7 | """ 2D Image to Patch Embedding 8 | """ 9 | 10 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): 11 | super().__init__() 12 | img_size = to_2tuple(img_size) 13 | patch_size = to_2tuple(patch_size) 14 | self.img_size = img_size 15 | self.patch_size = patch_size 16 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 17 | self.num_patches = self.grid_size[0] * self.grid_size[1] 18 | self.flatten = flatten 19 | 20 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 21 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 22 | 23 | def forward(self, x): 24 | # allow different input size 25 | # B, C, H, W = x.shape 26 | # _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).") 27 | # _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).") 28 | x = self.proj(x) 29 | if self.flatten: 30 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 31 | x = self.norm(x) 32 | return x 33 | -------------------------------------------------------------------------------- /config/vitb_384_mae_32x4_ep300.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | MAX_SAMPLE_INTERVAL: 200 3 | MEAN: 4 | - 0.485 5 | - 0.456 6 | - 0.406 7 | SEARCH: 8 | CENTER_JITTER: 4.5 9 | FACTOR: 5.0 10 | SCALE_JITTER: 0.5 11 | SIZE: 384 12 | STD: 13 | - 0.229 14 | - 0.224 15 | - 0.225 16 | TEMPLATE: 17 | CENTER_JITTER: 0 18 | FACTOR: 2.0 19 | SCALE_JITTER: 0 20 | SIZE: 192 21 | # TRAIN: 22 | # DATASETS_NAME: 23 | # - GOT10K_train_full 24 | # DATASETS_RATIO: 25 | # - 1 26 | # SAMPLE_PER_EPOCH: 60000 27 | 28 | TRAIN: 29 | DATASETS_NAME: 30 
| - LASOT 31 | - GOT10K_vottrain 32 | - COCO17 33 | - TRACKINGNET 34 | DATASETS_RATIO: 35 | - 1 36 | - 1 37 | - 1 38 | - 1 39 | SAMPLE_PER_EPOCH: 60000 40 | VAL: 41 | DATASETS_NAME: 42 | - GOT10K_votval 43 | DATASETS_RATIO: 44 | - 1 45 | SAMPLE_PER_EPOCH: 10000 46 | MODEL: 47 | PRETRAIN_FILE: "mae_pretrain_vit_base.pth" 48 | EXTRA_MERGER: False 49 | RETURN_INTER: False 50 | BACKBONE: 51 | TYPE: vit_base_patch16_224 52 | STRIDE: 16 53 | HEAD: 54 | TYPE: CENTER 55 | NUM_CHANNELS: 256 56 | TRAIN: 57 | BACKBONE_MULTIPLIER: 0.1 58 | DROP_PATH_RATE: 0.1 59 | BATCH_SIZE: 32 60 | EPOCH: 300 61 | GIOU_WEIGHT: 2.0 62 | L1_WEIGHT: 5.0 63 | GRAD_CLIP_NORM: 0.1 64 | LR: 0.0004 65 | LR_DROP_EPOCH: 240 66 | NUM_WORKER: 10 67 | OPTIMIZER: ADAMW 68 | PRINT_INTERVAL: 50 69 | SCHEDULER: 70 | TYPE: step 71 | DECAY_RATE: 0.1 72 | VAL_EPOCH_INTERVAL: 20 73 | WEIGHT_DECAY: 0.0001 74 | AMP: False 75 | TEST: 76 | EPOCH: 300 77 | SEARCH_FACTOR: 5.0 78 | SEARCH_SIZE: 384 79 | TEMPLATE_FACTOR: 2.0 80 | TEMPLATE_SIZE: 192 -------------------------------------------------------------------------------- /config/vitb_256_mae_32x4_ep300.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | MAX_SAMPLE_INTERVAL: 200 3 | MEAN: 4 | - 0.485 5 | - 0.456 6 | - 0.406 7 | SEARCH: 8 | CENTER_JITTER: 3 9 | FACTOR: 4.0 10 | SCALE_JITTER: 0.25 11 | SIZE: 256 12 | NUMBER: 1 13 | STD: 14 | - 0.229 15 | - 0.224 16 | - 0.225 17 | TEMPLATE: 18 | CENTER_JITTER: 0 19 | FACTOR: 2.0 20 | SCALE_JITTER: 0 21 | SIZE: 128 22 | # TRAIN: 23 | # DATASETS_NAME: 24 | # - GOT10K_train_full 25 | # DATASETS_RATIO: 26 | # - 1 27 | # SAMPLE_PER_EPOCH: 60000 28 | 29 | TRAIN: 30 | DATASETS_NAME: 31 | - LASOT 32 | - GOT10K_vottrain 33 | - COCO17 34 | - TRACKINGNET 35 | DATASETS_RATIO: 36 | - 1 37 | - 1 38 | - 1 39 | - 1 40 | SAMPLE_PER_EPOCH: 60000 41 | VAL: 42 | DATASETS_NAME: 43 | - GOT10K_votval 44 | DATASETS_RATIO: 45 | - 1 46 | SAMPLE_PER_EPOCH: 10000 47 | MODEL: 48 | PRETRAIN_FILE: "mae_pretrain_vit_base.pth" 49 | EXTRA_MERGER: False 50 | RETURN_INTER: False 51 | BACKBONE: 52 | TYPE: vit_base_patch16_224 53 | STRIDE: 16 54 | HEAD: 55 | TYPE: CENTER 56 | NUM_CHANNELS: 256 57 | TRAIN: 58 | BACKBONE_MULTIPLIER: 0.1 59 | DROP_PATH_RATE: 0.1 60 | BATCH_SIZE: 32 61 | EPOCH: 300 62 | GIOU_WEIGHT: 2.0 63 | L1_WEIGHT: 5.0 64 | GRAD_CLIP_NORM: 0.1 65 | LR: 0.0004 66 | LR_DROP_EPOCH: 240 67 | NUM_WORKER: 10 68 | OPTIMIZER: ADAMW 69 | PRINT_INTERVAL: 50 70 | SCHEDULER: 71 | TYPE: step 72 | DECAY_RATE: 0.1 73 | VAL_EPOCH_INTERVAL: 20 74 | WEIGHT_DECAY: 0.0001 75 | AMP: False 76 | TEST: 77 | EPOCH: 300 78 | SEARCH_FACTOR: 4.0 79 | SEARCH_SIZE: 256 80 | TEMPLATE_FACTOR: 2.0 81 | TEMPLATE_SIZE: 128 -------------------------------------------------------------------------------- /model/ostrack/utils/variable_hook.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from bytecode import Bytecode, Instr 3 | 4 | 5 | class get_local(object): 6 | cache = {} 7 | is_activate = False 8 | 9 | def __init__(self, varname): 10 | self.varname = varname 11 | 12 | def __call__(self, func): 13 | if not type(self).is_activate: 14 | return func 15 | 16 | type(self).cache[func.__qualname__] = [] 17 | c = Bytecode.from_code(func.__code__) 18 | extra_code = [ 19 | Instr('STORE_FAST', '_res'), 20 | Instr('LOAD_FAST', self.varname), 21 | Instr('STORE_FAST', '_value'), 22 | Instr('LOAD_FAST', '_res'), 23 | Instr('LOAD_FAST', '_value'), 24 | Instr('BUILD_TUPLE', 2), 25 | Instr('STORE_FAST', 
'_result_tuple'), 26 | Instr('LOAD_FAST', '_result_tuple'), 27 | ] 28 | c[-1:-1] = extra_code 29 | func.__code__ = c.to_code() 30 | 31 | def wrapper(*args, **kwargs): 32 | res, values = func(*args, **kwargs) 33 | if isinstance(values, torch.Tensor): 34 | type(self).cache[func.__qualname__].append(values.detach().cpu().numpy()) 35 | elif isinstance(values, list): # list of Tensor 36 | type(self).cache[func.__qualname__].append([value.detach().cpu().numpy() for value in values]) 37 | else: 38 | raise NotImplementedError 39 | return res 40 | 41 | return wrapper 42 | 43 | @classmethod 44 | def clear(cls): 45 | for key in cls.cache.keys(): 46 | cls.cache[key] = [] 47 | 48 | @classmethod 49 | def activate(cls): 50 | cls.is_activate = True 51 | -------------------------------------------------------------------------------- /model/ostrack/layers/frozen_bn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class FrozenBatchNorm2d(torch.nn.Module): 5 | """ 6 | BatchNorm2d where the batch statistics and the affine parameters are fixed. 7 | 8 | Copy-paste from torchvision.misc.ops with added eps before rqsrt, 9 | without which any other models than torchvision.models.resnet[18,34,50,101] 10 | produce nans. 11 | """ 12 | 13 | def __init__(self, n): 14 | super(FrozenBatchNorm2d, self).__init__() 15 | self.register_buffer("weight", torch.ones(n)) 16 | self.register_buffer("bias", torch.zeros(n)) 17 | self.register_buffer("running_mean", torch.zeros(n)) 18 | self.register_buffer("running_var", torch.ones(n)) 19 | 20 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, 21 | missing_keys, unexpected_keys, error_msgs): 22 | num_batches_tracked_key = prefix + 'num_batches_tracked' 23 | if num_batches_tracked_key in state_dict: 24 | del state_dict[num_batches_tracked_key] 25 | 26 | super(FrozenBatchNorm2d, self)._load_from_state_dict( 27 | state_dict, prefix, local_metadata, strict, 28 | missing_keys, unexpected_keys, error_msgs) 29 | 30 | def forward(self, x): 31 | # move reshapes to the beginning 32 | # to make it fuser-friendly 33 | w = self.weight.reshape(1, -1, 1, 1) 34 | b = self.bias.reshape(1, -1, 1, 1) 35 | rv = self.running_var.reshape(1, -1, 1, 1) 36 | rm = self.running_mean.reshape(1, -1, 1, 1) 37 | eps = 1e-5 38 | scale = w * (rv + eps).rsqrt() # rsqrt(x): 1/sqrt(x), r: reciprocal 39 | bias = b - rm * scale 40 | return x * scale + bias 41 | -------------------------------------------------------------------------------- /config/vitb_256_mae_ce_32x4_got10k_ep100.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | MAX_SAMPLE_INTERVAL: 200 3 | MEAN: 4 | - 0.485 5 | - 0.456 6 | - 0.406 7 | SEARCH: 8 | CENTER_JITTER: 3 9 | FACTOR: 4.0 10 | SCALE_JITTER: 0.25 11 | SIZE: 256 12 | NUMBER: 1 13 | STD: 14 | - 0.229 15 | - 0.224 16 | - 0.225 17 | TEMPLATE: 18 | CENTER_JITTER: 0 19 | FACTOR: 2.0 20 | SCALE_JITTER: 0 21 | SIZE: 128 22 | TRAIN: 23 | DATASETS_NAME: 24 | - GOT10K_train_full 25 | DATASETS_RATIO: 26 | - 1 27 | SAMPLE_PER_EPOCH: 60000 28 | VAL: 29 | DATASETS_NAME: 30 | - GOT10K_official_val 31 | DATASETS_RATIO: 32 | - 1 33 | SAMPLE_PER_EPOCH: 10000 34 | MODEL: 35 | PRETRAIN_FILE: "mae_pretrain_vit_base.pth" 36 | EXTRA_MERGER: False 37 | RETURN_INTER: False 38 | BACKBONE: 39 | TYPE: vit_base_patch16_224_ce 40 | STRIDE: 16 41 | CE_LOC: [3, 6, 9] 42 | CE_KEEP_RATIO: [0.7, 0.7, 0.7] 43 | CE_TEMPLATE_RANGE: 'CTR_POINT' # choose between ALL, CTR_POINT, CTR_REC, GT_BOX 
44 | HEAD: 45 | TYPE: CENTER 46 | NUM_CHANNELS: 256 47 | TRAIN: 48 | BACKBONE_MULTIPLIER: 0.1 49 | DROP_PATH_RATE: 0.1 50 | CE_START_EPOCH: 20 # candidate elimination start epoch 51 | CE_WARM_EPOCH: 50 # candidate elimination warm up epoch 52 | BATCH_SIZE: 16 53 | EPOCH: 100 54 | GIOU_WEIGHT: 2.0 55 | L1_WEIGHT: 5.0 56 | GRAD_CLIP_NORM: 0.1 57 | LR: 0.0004 58 | LR_DROP_EPOCH: 80 59 | NUM_WORKER: 10 60 | OPTIMIZER: ADAMW 61 | PRINT_INTERVAL: 50 62 | SCHEDULER: 63 | TYPE: step 64 | DECAY_RATE: 0.1 65 | VAL_EPOCH_INTERVAL: 20 66 | WEIGHT_DECAY: 0.0001 67 | AMP: False 68 | TEST: 69 | EPOCH: 100 70 | SEARCH_FACTOR: 4.0 71 | SEARCH_SIZE: 256 72 | TEMPLATE_FACTOR: 2.0 73 | TEMPLATE_SIZE: 128 -------------------------------------------------------------------------------- /model/ostrack/utils/lmdb_utils.py: -------------------------------------------------------------------------------- 1 | import lmdb 2 | import numpy as np 3 | import cv2 4 | import json 5 | 6 | LMDB_ENVS = dict() 7 | LMDB_HANDLES = dict() 8 | LMDB_FILELISTS = dict() 9 | 10 | 11 | def get_lmdb_handle(name): 12 | global LMDB_HANDLES, LMDB_FILELISTS 13 | item = LMDB_HANDLES.get(name, None) 14 | if item is None: 15 | env = lmdb.open(name, readonly=True, lock=False, readahead=False, meminit=False) 16 | LMDB_ENVS[name] = env 17 | item = env.begin(write=False) 18 | LMDB_HANDLES[name] = item 19 | 20 | return item 21 | 22 | 23 | def decode_img(lmdb_fname, key_name): 24 | handle = get_lmdb_handle(lmdb_fname) 25 | binfile = handle.get(key_name.encode()) 26 | if binfile is None: 27 | print("Illegal data detected. %s %s" % (lmdb_fname, key_name)) 28 | s = np.frombuffer(binfile, np.uint8) 29 | x = cv2.cvtColor(cv2.imdecode(s, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB) 30 | return x 31 | 32 | 33 | def decode_str(lmdb_fname, key_name): 34 | handle = get_lmdb_handle(lmdb_fname) 35 | binfile = handle.get(key_name.encode()) 36 | string = binfile.decode() 37 | return string 38 | 39 | 40 | def decode_json(lmdb_fname, key_name): 41 | return json.loads(decode_str(lmdb_fname, key_name)) 42 | 43 | 44 | if __name__ == "__main__": 45 | lmdb_fname = "/data/sda/v-yanbi/iccv21/LittleBoy_clean/data/got10k_lmdb" 46 | '''Decode image''' 47 | # key_name = "test/GOT-10k_Test_000001/00000001.jpg" 48 | # img = decode_img(lmdb_fname, key_name) 49 | # cv2.imwrite("001.jpg", img) 50 | '''Decode str''' 51 | # key_name = "test/list.txt" 52 | # key_name = "train/GOT-10k_Train_000001/groundtruth.txt" 53 | key_name = "train/GOT-10k_Train_000001/absence.label" 54 | str_ = decode_str(lmdb_fname, key_name) 55 | print(str_) 56 | -------------------------------------------------------------------------------- /config/vitb_384_mae_ce_32x4_ep300.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | MAX_SAMPLE_INTERVAL: 200 3 | MEAN: 4 | - 0.485 5 | - 0.456 6 | - 0.406 7 | SEARCH: 8 | CENTER_JITTER: 4.5 9 | FACTOR: 5.0 10 | SCALE_JITTER: 0.5 11 | SIZE: 384 12 | STD: 13 | - 0.229 14 | - 0.224 15 | - 0.225 16 | TEMPLATE: 17 | CENTER_JITTER: 0 18 | FACTOR: 2.0 19 | SCALE_JITTER: 0 20 | SIZE: 192 21 | # TRAIN: 22 | # DATASETS_NAME: 23 | # - GOT10K_train_full 24 | # DATASETS_RATIO: 25 | # - 1 26 | # SAMPLE_PER_EPOCH: 60000 27 | 28 | TRAIN: 29 | DATASETS_NAME: 30 | - LASOT 31 | - GOT10K_vottrain 32 | - COCO17 33 | - TRACKINGNET 34 | DATASETS_RATIO: 35 | - 1 36 | - 1 37 | - 1 38 | - 1 39 | SAMPLE_PER_EPOCH: 60000 40 | VAL: 41 | DATASETS_NAME: 42 | - GOT10K_votval 43 | DATASETS_RATIO: 44 | - 1 45 | SAMPLE_PER_EPOCH: 10000 46 | MODEL: 47 | 
PRETRAIN_FILE: "mae_pretrain_vit_base.pth" 48 | EXTRA_MERGER: False 49 | RETURN_INTER: False 50 | BACKBONE: 51 | TYPE: vit_base_patch16_224_ce 52 | STRIDE: 16 53 | CE_LOC: [3, 6, 9] 54 | CE_KEEP_RATIO: [0.7, 0.7, 0.7] 55 | CE_TEMPLATE_RANGE: 'CTR_POINT' # choose between ALL, CTR_POINT, CTR_REC, GT_BOX 56 | HEAD: 57 | TYPE: CENTER 58 | NUM_CHANNELS: 256 59 | TRAIN: 60 | BACKBONE_MULTIPLIER: 0.1 61 | DROP_PATH_RATE: 0.1 62 | BATCH_SIZE: 32 63 | EPOCH: 300 64 | GIOU_WEIGHT: 2.0 65 | L1_WEIGHT: 5.0 66 | GRAD_CLIP_NORM: 0.1 67 | LR: 0.0004 68 | LR_DROP_EPOCH: 240 69 | NUM_WORKER: 10 70 | OPTIMIZER: ADAMW 71 | PRINT_INTERVAL: 50 72 | SCHEDULER: 73 | TYPE: step 74 | DECAY_RATE: 0.1 75 | VAL_EPOCH_INTERVAL: 20 76 | WEIGHT_DECAY: 0.0001 77 | AMP: False 78 | TEST: 79 | EPOCH: 300 80 | SEARCH_FACTOR: 5.0 81 | SEARCH_SIZE: 384 82 | TEMPLATE_FACTOR: 2.0 83 | TEMPLATE_SIZE: 192 -------------------------------------------------------------------------------- /config/vitb_384_mae_ce_32x4_got10k_ep100.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | MAX_SAMPLE_INTERVAL: 200 3 | MEAN: 4 | - 0.485 5 | - 0.456 6 | - 0.406 7 | SEARCH: 8 | CENTER_JITTER: 4.5 9 | FACTOR: 5.0 10 | SCALE_JITTER: 0.5 11 | SIZE: 384 12 | STD: 13 | - 0.229 14 | - 0.224 15 | - 0.225 16 | TEMPLATE: 17 | CENTER_JITTER: 0 18 | FACTOR: 2.0 19 | SCALE_JITTER: 0 20 | SIZE: 192 21 | TRAIN: 22 | DATASETS_NAME: 23 | - GOT10K_train_full 24 | DATASETS_RATIO: 25 | - 1 26 | SAMPLE_PER_EPOCH: 60000 27 | VAL: 28 | DATASETS_NAME: 29 | - GOT10K_official_val 30 | DATASETS_RATIO: 31 | - 1 32 | SAMPLE_PER_EPOCH: 10000 33 | MODEL: 34 | PRETRAIN_FILE: "mae_pretrain_vit_base.pth" 35 | EXTRA_MERGER: False 36 | RETURN_INTER: False 37 | BACKBONE: 38 | CAT_MODE: 'direct' # direct or concat 39 | SEP_SEG: True 40 | TYPE: vit_base_patch16_224_ce 41 | STRIDE: 16 42 | CE_LOC: [3, 6, 9] 43 | CE_KEEP_RATIO: [0.7, 0.7, 0.7] 44 | # CE_KEEP_RATIO: [0.5, 0.5, 0.5] 45 | CE_TEMPLATE_RANGE: 'CTR_POINT' # choose between ALL, CTR_POINT, CTR_REC, GT_BOX 46 | HEAD: 47 | TYPE: CENTER 48 | NUM_CHANNELS: 256 49 | RETURN_STAGES: [2, 5, 8, 11] 50 | TRAIN: 51 | BACKBONE_MULTIPLIER: 0.1 52 | DROP_PATH_RATE: 0.1 53 | CE_START_EPOCH: 20 # candidate elimination start epoch 54 | CE_WARM_EPOCH: 50 # candidate elimination warm up epoch 55 | BATCH_SIZE: 16 # 2022.12.15 32->16 56 | EPOCH: 100 57 | GIOU_WEIGHT: 2.0 58 | L1_WEIGHT: 5.0 59 | GRAD_CLIP_NORM: 0.1 60 | LR: 0.0002 # 2022.12.20 0.0004 -> 0.0002 61 | LR_DROP_EPOCH: 80 62 | NUM_WORKER: 10 63 | OPTIMIZER: ADAMW 64 | PRINT_INTERVAL: 50 65 | SCHEDULER: 66 | TYPE: step 67 | DECAY_RATE: 0.1 68 | VAL_EPOCH_INTERVAL: 20 69 | WEIGHT_DECAY: 0.0001 70 | AMP: False 71 | TEST: 72 | EPOCH: 100 73 | SEARCH_FACTOR: 5.0 74 | SEARCH_SIZE: 384 75 | TEMPLATE_FACTOR: 2.0 76 | TEMPLATE_SIZE: 192 -------------------------------------------------------------------------------- /config/vitb_256_mae_ce_32x4_ep300.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | MAX_SAMPLE_INTERVAL: 200 3 | MEAN: 4 | - 0.485 5 | - 0.456 6 | - 0.406 7 | SEARCH: 8 | CENTER_JITTER: 3 9 | FACTOR: 4.0 10 | SCALE_JITTER: 0.25 11 | SIZE: 256 12 | NUMBER: 1 13 | STD: 14 | - 0.229 15 | - 0.224 16 | - 0.225 17 | TEMPLATE: 18 | CENTER_JITTER: 0 19 | FACTOR: 2.0 20 | SCALE_JITTER: 0 21 | SIZE: 128 22 | # TRAIN: 23 | # DATASETS_NAME: 24 | # - GOT10K_train_full 25 | # DATASETS_RATIO: 26 | # - 1 27 | # SAMPLE_PER_EPOCH: 60000 28 | 29 | TRAIN: 30 | DATASETS_NAME: 31 | - LASOT 32 | - 
GOT10K_vottrain 33 | - COCO17 34 | - TRACKINGNET 35 | DATASETS_RATIO: 36 | - 1 37 | - 1 38 | - 1 39 | - 1 40 | SAMPLE_PER_EPOCH: 60000 41 | VAL: 42 | DATASETS_NAME: 43 | - GOT10K_votval 44 | DATASETS_RATIO: 45 | - 1 46 | SAMPLE_PER_EPOCH: 10000 47 | MODEL: 48 | PRETRAIN_FILE: "mae_pretrain_vit_base.pth" 49 | EXTRA_MERGER: False 50 | RETURN_INTER: False 51 | BACKBONE: 52 | TYPE: vit_base_patch16_224_ce 53 | STRIDE: 16 54 | CE_LOC: [3, 6, 9] 55 | CE_KEEP_RATIO: [0.7, 0.7, 0.7] 56 | CE_TEMPLATE_RANGE: 'CTR_POINT' # choose between ALL, CTR_POINT, CTR_REC, GT_BOX 57 | HEAD: 58 | TYPE: CENTER 59 | NUM_CHANNELS: 256 60 | TRAIN: 61 | BACKBONE_MULTIPLIER: 0.1 62 | DROP_PATH_RATE: 0.1 63 | CE_START_EPOCH: 20 # candidate elimination start epoch 64 | CE_WARM_EPOCH: 80 # candidate elimination warm up epoch 65 | BATCH_SIZE: 32 66 | EPOCH: 300 67 | GIOU_WEIGHT: 2.0 68 | L1_WEIGHT: 5.0 69 | GRAD_CLIP_NORM: 0.1 70 | LR: 0.0004 71 | LR_DROP_EPOCH: 240 72 | NUM_WORKER: 10 73 | OPTIMIZER: ADAMW 74 | PRINT_INTERVAL: 50 75 | SCHEDULER: 76 | TYPE: step 77 | DECAY_RATE: 0.1 78 | VAL_EPOCH_INTERVAL: 20 79 | WEIGHT_DECAY: 0.0001 80 | AMP: False 81 | TEST: 82 | EPOCH: 300 83 | SEARCH_FACTOR: 4.0 84 | SEARCH_SIZE: 256 85 | TEMPLATE_FACTOR: 2.0 86 | TEMPLATE_SIZE: 128 -------------------------------------------------------------------------------- /model/ostrack/utils/focal_loss.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class FocalLoss(nn.Module, ABC): 9 | def __init__(self, alpha=2, beta=4): 10 | super(FocalLoss, self).__init__() 11 | self.alpha = alpha 12 | self.beta = beta 13 | 14 | def forward(self, prediction, target): 15 | positive_index = target.eq(1).float() 16 | negative_index = target.lt(1).float() 17 | 18 | negative_weights = torch.pow(1 - target, self.beta) 19 | # clamp min value is set to 1e-12 to maintain the numerical stability 20 | prediction = torch.clamp(prediction, 1e-12) 21 | 22 | positive_loss = torch.log(prediction) * torch.pow(1 - prediction, self.alpha) * positive_index 23 | negative_loss = torch.log(1 - prediction) * torch.pow(prediction, 24 | self.alpha) * negative_weights * negative_index 25 | 26 | num_positive = positive_index.float().sum() 27 | positive_loss = positive_loss.sum() 28 | negative_loss = negative_loss.sum() 29 | 30 | if num_positive == 0: 31 | loss = -negative_loss 32 | else: 33 | loss = -(positive_loss + negative_loss) / num_positive 34 | 35 | return loss 36 | 37 | 38 | class LBHinge(nn.Module): 39 | """Loss that uses a 'hinge' on the lower bound. 40 | This means that for samples with a label value smaller than the threshold, the loss is zero if the prediction is 41 | also smaller than that threshold. 42 | args: 43 | error_metric: What base loss to use (MSE by default). 44 | threshold: Threshold to use for the hinge. 45 | clip: Clip the loss if it is above this value.
46 | """ 47 | def __init__(self, error_metric=nn.MSELoss(), threshold=None, clip=None): 48 | super().__init__() 49 | self.error_metric = error_metric 50 | self.threshold = threshold if threshold is not None else -100 51 | self.clip = clip 52 | 53 | def forward(self, prediction, label, target_bb=None): 54 | negative_mask = (label < self.threshold).float() 55 | positive_mask = (1.0 - negative_mask) 56 | 57 | prediction = negative_mask * F.relu(prediction) + positive_mask * prediction 58 | 59 | loss = self.error_metric(prediction, positive_mask * label) 60 | 61 | if self.clip is not None: 62 | loss = torch.min(loss, torch.tensor([self.clip], device=loss.device)) 63 | return loss -------------------------------------------------------------------------------- /model/ostrack/utils/box_ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.ops.boxes import box_area 3 | import numpy as np 4 | 5 | 6 | def box_cxcywh_to_xyxy(x): 7 | x_c, y_c, w, h = x.unbind(-1) 8 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 9 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 10 | return torch.stack(b, dim=-1) 11 | 12 | 13 | def box_xywh_to_xyxy(x): 14 | x1, y1, w, h = x.unbind(-1) 15 | b = [x1, y1, x1 + w, y1 + h] 16 | return torch.stack(b, dim=-1) 17 | 18 | 19 | def box_xyxy_to_xywh(x): 20 | x1, y1, x2, y2 = x.unbind(-1) 21 | b = [x1, y1, x2 - x1, y2 - y1] 22 | return torch.stack(b, dim=-1) 23 | 24 | 25 | def box_xyxy_to_cxcywh(x): 26 | x0, y0, x1, y1 = x.unbind(-1) 27 | b = [(x0 + x1) / 2, (y0 + y1) / 2, 28 | (x1 - x0), (y1 - y0)] 29 | return torch.stack(b, dim=-1) 30 | 31 | 32 | # modified from torchvision to also return the union 33 | '''Note that this function only supports shape (N,4)''' 34 | 35 | 36 | def box_iou(boxes1, boxes2): 37 | """ 38 | 39 | :param boxes1: (N, 4) (x1,y1,x2,y2) 40 | :param boxes2: (N, 4) (x1,y1,x2,y2) 41 | :return: 42 | """ 43 | area1 = box_area(boxes1) # (N,) 44 | area2 = box_area(boxes2) # (N,) 45 | 46 | lt = torch.max(boxes1[:, :2], boxes2[:, :2]) # (N,2) 47 | rb = torch.min(boxes1[:, 2:], boxes2[:, 2:]) # (N,2) 48 | 49 | wh = (rb - lt).clamp(min=0) # (N,2) 50 | inter = wh[:, 0] * wh[:, 1] # (N,) 51 | 52 | union = area1 + area2 - inter 53 | 54 | iou = inter / union 55 | return iou, union 56 | 57 | 58 | '''Note that this implementation is different from DETR's''' 59 | 60 | 61 | def generalized_box_iou(boxes1, boxes2): 62 | """ 63 | Generalized IoU from https://giou.stanford.edu/ 64 | 65 | The boxes should be in [x0, y0, x1, y1] format 66 | 67 | boxes1: (N, 4) 68 | boxes2: (N, 4) 69 | """ 70 | # degenerate boxes gives inf / nan results 71 | # so do an early check 72 | # try: 73 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 74 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 75 | iou, union = box_iou(boxes1, boxes2) # (N,) 76 | 77 | lt = torch.min(boxes1[:, :2], boxes2[:, :2]) 78 | rb = torch.max(boxes1[:, 2:], boxes2[:, 2:]) 79 | 80 | wh = (rb - lt).clamp(min=0) # (N,2) 81 | area = wh[:, 0] * wh[:, 1] # (N,) 82 | 83 | return iou - (area - union) / area, iou 84 | 85 | 86 | def giou_loss(boxes1, boxes2): 87 | """ 88 | 89 | :param boxes1: (N, 4) (x1,y1,x2,y2) 90 | :param boxes2: (N, 4) (x1,y1,x2,y2) 91 | :return: 92 | """ 93 | giou, iou = generalized_box_iou(boxes1, boxes2) 94 | return (1 - giou).mean(), iou 95 | 96 | 97 | def clip_box(box: list, H, W, margin=0): 98 | x1, y1, w, h = box 99 | x2, y2 = x1 + w, y1 + h 100 | x1 = min(max(0, x1), W-margin) 101 | x2 = min(max(margin, x2), W) 102 | y1 = min(max(0, y1), H-margin) 103 | y2 = 
min(max(margin, y2), H) 104 | w = max(margin, x2-x1) 105 | h = max(margin, y2-y1) 106 | return [x1, y1, w, h] 107 | -------------------------------------------------------------------------------- /model/ostrack/utils/ce_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | def generate_bbox_mask(bbox_mask, bbox): 8 | b, h, w = bbox_mask.shape 9 | for i in range(b): 10 | bbox_i = bbox[i].cpu().tolist() 11 | bbox_mask[i, int(bbox_i[1]):int(bbox_i[1] + bbox_i[3] - 1), int(bbox_i[0]):int(bbox_i[0] + bbox_i[2] - 1)] = 1 12 | return bbox_mask 13 | 14 | 15 | def generate_mask_cond(cfg, bs, device, gt_bbox): 16 | template_size = cfg.DATA.TEMPLATE.SIZE 17 | stride = cfg.MODEL.BACKBONE.STRIDE 18 | template_feat_size = template_size // stride 19 | 20 | if cfg.MODEL.BACKBONE.CE_TEMPLATE_RANGE == 'ALL': 21 | box_mask_z = None 22 | elif cfg.MODEL.BACKBONE.CE_TEMPLATE_RANGE == 'CTR_POINT': 23 | if template_feat_size == 8: 24 | index = slice(3, 4) 25 | elif template_feat_size == 12: 26 | index = slice(5, 6) 27 | elif template_feat_size == 7: 28 | index = slice(3, 4) 29 | elif template_feat_size == 14: 30 | index = slice(6, 7) 31 | else: 32 | raise NotImplementedError 33 | box_mask_z = torch.zeros([bs, template_feat_size, template_feat_size], device=device) 34 | box_mask_z[:, index, index] = 1 35 | box_mask_z = box_mask_z.flatten(1).to(torch.bool) 36 | elif cfg.MODEL.BACKBONE.CE_TEMPLATE_RANGE == 'CTR_REC': 37 | # use fixed 4x4 region, 3:5 for 8x8 38 | # use fixed 4x4 region 5:6 for 12x12 39 | if template_feat_size == 8: 40 | index = slice(3, 5) 41 | elif template_feat_size == 12: 42 | index = slice(5, 7) 43 | elif template_feat_size == 7: 44 | index = slice(3, 4) 45 | else: 46 | raise NotImplementedError 47 | box_mask_z = torch.zeros([bs, template_feat_size, template_feat_size], device=device) 48 | box_mask_z[:, index, index] = 1 49 | box_mask_z = box_mask_z.flatten(1).to(torch.bool) 50 | 51 | elif cfg.MODEL.BACKBONE.CE_TEMPLATE_RANGE == 'GT_BOX': 52 | box_mask_z = torch.zeros([bs, template_size, template_size], device=device) 53 | # box_mask_z_ori = data['template_seg'][0].view(-1, 1, *data['template_seg'].shape[2:]) # (batch, 1, 128, 128) 54 | box_mask_z = generate_bbox_mask(box_mask_z, gt_bbox * template_size).unsqueeze(1).to( 55 | torch.float) # (batch, 1, 128, 128) 56 | # box_mask_z_vis = box_mask_z.cpu().numpy() 57 | box_mask_z = F.interpolate(box_mask_z, scale_factor=1. 
/ cfg.MODEL.BACKBONE.STRIDE, mode='bilinear', 58 | align_corners=False) 59 | box_mask_z = box_mask_z.flatten(1).to(torch.bool) 60 | # box_mask_z_vis = box_mask_z[:, 0, ...].cpu().numpy() 61 | # gaussian_maps_vis = generate_heatmap(data['template_anno'], self.cfg.DATA.TEMPLATE.SIZE, self.cfg.MODEL.STRIDE)[0].cpu().numpy() 62 | else: 63 | raise NotImplementedError 64 | 65 | return box_mask_z 66 | 67 | 68 | def adjust_keep_rate(epoch, warmup_epochs, total_epochs, ITERS_PER_EPOCH, base_keep_rate=0.5, max_keep_rate=1, iters=-1): 69 | if epoch < warmup_epochs: 70 | return 1 71 | if epoch >= total_epochs: 72 | return base_keep_rate 73 | if iters == -1: 74 | iters = epoch * ITERS_PER_EPOCH 75 | total_iters = ITERS_PER_EPOCH * (total_epochs - warmup_epochs) 76 | iters = iters - ITERS_PER_EPOCH * warmup_epochs 77 | keep_rate = base_keep_rate + (max_keep_rate - base_keep_rate) \ 78 | * (math.cos(iters / total_iters * math.pi) + 1) * 0.5 79 | 80 | return keep_rate 81 | -------------------------------------------------------------------------------- /model/ostrack/util.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | def combine_tokens(template_tokens, search_tokens, mode='direct', return_res=False): 8 | # [B, HW, C] 9 | len_t = template_tokens.shape[1] 10 | len_s = search_tokens.shape[1] 11 | 12 | if mode == 'direct': 13 | merged_feature = torch.cat((template_tokens, search_tokens), dim=1) 14 | elif mode == 'template_central': 15 | central_pivot = len_s // 2 16 | first_half = search_tokens[:, :central_pivot, :] 17 | second_half = search_tokens[:, central_pivot:, :] 18 | merged_feature = torch.cat((first_half, template_tokens, second_half), dim=1) 19 | elif mode == 'partition': 20 | feat_size_s = int(math.sqrt(len_s)) 21 | feat_size_t = int(math.sqrt(len_t)) 22 | window_size = math.ceil(feat_size_t / 2.) 
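# 'partition' mode overview (descriptive comment): the template feature map is padded to a multiple of window_size, split into two horizontal bands of height window_size, and the bands are laid side by side along the width before being prepended to the search-region tokens.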
23 | # pad feature maps to multiples of window size 24 | B, _, C = template_tokens.shape 25 | H = W = feat_size_t 26 | template_tokens = template_tokens.view(B, H, W, C) 27 | pad_l = pad_b = pad_r = 0 28 | # pad_r = (window_size - W % window_size) % window_size 29 | pad_t = (window_size - H % window_size) % window_size 30 | template_tokens = F.pad(template_tokens, (0, 0, pad_l, pad_r, pad_t, pad_b)) 31 | _, Hp, Wp, _ = template_tokens.shape 32 | template_tokens = template_tokens.view(B, Hp // window_size, window_size, W, C) 33 | template_tokens = torch.cat([template_tokens[:, 0, ...], template_tokens[:, 1, ...]], dim=2) 34 | _, Hc, Wc, _ = template_tokens.shape 35 | template_tokens = template_tokens.view(B, -1, C) 36 | merged_feature = torch.cat([template_tokens, search_tokens], dim=1) 37 | 38 | # calculate new h and w, which may be useful for SwinT or others 39 | merged_h, merged_w = feat_size_s + Hc, feat_size_s 40 | if return_res: 41 | return merged_feature, merged_h, merged_w 42 | 43 | else: 44 | raise NotImplementedError 45 | 46 | return merged_feature 47 | 48 | 49 | def recover_tokens(merged_tokens, len_template_token, len_search_token, mode='direct'): 50 | if mode == 'direct': 51 | recovered_tokens = merged_tokens 52 | elif mode == 'template_central': 53 | central_pivot = len_search_token // 2 54 | len_remain = len_search_token - central_pivot 55 | len_half_and_t = central_pivot + len_template_token 56 | 57 | first_half = merged_tokens[:, :central_pivot, :] 58 | second_half = merged_tokens[:, -len_remain:, :] 59 | template_tokens = merged_tokens[:, central_pivot:len_half_and_t, :] 60 | 61 | recovered_tokens = torch.cat((template_tokens, first_half, second_half), dim=1) 62 | elif mode == 'partition': 63 | recovered_tokens = merged_tokens 64 | else: 65 | raise NotImplementedError 66 | 67 | return recovered_tokens 68 | 69 | 70 | def window_partition(x, window_size: int): 71 | """ 72 | Args: 73 | x: (B, H, W, C) 74 | window_size (int): window size 75 | 76 | Returns: 77 | windows: (num_windows*B, window_size, window_size, C) 78 | """ 79 | B, H, W, C = x.shape 80 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 81 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 82 | return windows 83 | 84 | 85 | def window_reverse(windows, window_size: int, H: int, W: int): 86 | """ 87 | Args: 88 | windows: (num_windows*B, window_size, window_size, C) 89 | window_size (int): Window size 90 | H (int): Height of image 91 | W (int): Width of image 92 | 93 | Returns: 94 | x: (B, H, W, C) 95 | """ 96 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 97 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 98 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 99 | return x 100 | -------------------------------------------------------------------------------- /model/ostrack/layers/rpe.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from timm.models.layers import trunc_normal_ 4 | 5 | 6 | def generate_2d_relative_positional_encoding_index(z_shape, x_shape): 7 | ''' 8 | z_shape: (z_h, z_w) 9 | x_shape: (x_h, x_w) 10 | ''' 11 | z_2d_index_h, z_2d_index_w = torch.meshgrid(torch.arange(z_shape[0]), torch.arange(z_shape[1])) 12 | x_2d_index_h, x_2d_index_w = torch.meshgrid(torch.arange(x_shape[0]), torch.arange(x_shape[1])) 13 | 14 | z_2d_index_h = z_2d_index_h.flatten(0) 15 | z_2d_index_w = 
z_2d_index_w.flatten(0) 16 | x_2d_index_h = x_2d_index_h.flatten(0) 17 | x_2d_index_w = x_2d_index_w.flatten(0) 18 | 19 | diff_h = z_2d_index_h[:, None] - x_2d_index_h[None, :] 20 | diff_w = z_2d_index_w[:, None] - x_2d_index_w[None, :] 21 | 22 | diff = torch.stack((diff_h, diff_w), dim=-1) 23 | _, indices = torch.unique(diff.view(-1, 2), return_inverse=True, dim=0) 24 | return indices.view(z_shape[0] * z_shape[1], x_shape[0] * x_shape[1]) 25 | 26 | 27 | def generate_2d_concatenated_self_attention_relative_positional_encoding_index(z_shape, x_shape): 28 | ''' 29 | z_shape: (z_h, z_w) 30 | x_shape: (x_h, x_w) 31 | ''' 32 | z_2d_index_h, z_2d_index_w = torch.meshgrid(torch.arange(z_shape[0]), torch.arange(z_shape[1])) 33 | x_2d_index_h, x_2d_index_w = torch.meshgrid(torch.arange(x_shape[0]), torch.arange(x_shape[1])) 34 | 35 | z_2d_index_h = z_2d_index_h.flatten(0) 36 | z_2d_index_w = z_2d_index_w.flatten(0) 37 | x_2d_index_h = x_2d_index_h.flatten(0) 38 | x_2d_index_w = x_2d_index_w.flatten(0) 39 | 40 | concatenated_2d_index_h = torch.cat((z_2d_index_h, x_2d_index_h)) 41 | concatenated_2d_index_w = torch.cat((z_2d_index_w, x_2d_index_w)) 42 | 43 | diff_h = concatenated_2d_index_h[:, None] - concatenated_2d_index_h[None, :] 44 | diff_w = concatenated_2d_index_w[:, None] - concatenated_2d_index_w[None, :] 45 | 46 | z_len = z_shape[0] * z_shape[1] 47 | x_len = x_shape[0] * x_shape[1] 48 | a = torch.empty((z_len + x_len), dtype=torch.int64) 49 | a[:z_len] = 0 50 | a[z_len:] = 1 51 | b=a[:, None].repeat(1, z_len + x_len) 52 | c=a[None, :].repeat(z_len + x_len, 1) 53 | 54 | diff = torch.stack((diff_h, diff_w, b, c), dim=-1) 55 | _, indices = torch.unique(diff.view((z_len + x_len) * (z_len + x_len), 4), return_inverse=True, dim=0) 56 | return indices.view((z_len + x_len), (z_len + x_len)) 57 | 58 | 59 | def generate_2d_concatenated_cross_attention_relative_positional_encoding_index(z_shape, x_shape): 60 | ''' 61 | z_shape: (z_h, z_w) 62 | x_shape: (x_h, x_w) 63 | ''' 64 | z_2d_index_h, z_2d_index_w = torch.meshgrid(torch.arange(z_shape[0]), torch.arange(z_shape[1])) 65 | x_2d_index_h, x_2d_index_w = torch.meshgrid(torch.arange(x_shape[0]), torch.arange(x_shape[1])) 66 | 67 | z_2d_index_h = z_2d_index_h.flatten(0) 68 | z_2d_index_w = z_2d_index_w.flatten(0) 69 | x_2d_index_h = x_2d_index_h.flatten(0) 70 | x_2d_index_w = x_2d_index_w.flatten(0) 71 | 72 | concatenated_2d_index_h = torch.cat((z_2d_index_h, x_2d_index_h)) 73 | concatenated_2d_index_w = torch.cat((z_2d_index_w, x_2d_index_w)) 74 | 75 | diff_h = x_2d_index_h[:, None] - concatenated_2d_index_h[None, :] 76 | diff_w = x_2d_index_w[:, None] - concatenated_2d_index_w[None, :] 77 | 78 | z_len = z_shape[0] * z_shape[1] 79 | x_len = x_shape[0] * x_shape[1] 80 | 81 | a = torch.empty(z_len + x_len, dtype=torch.int64) 82 | a[: z_len] = 0 83 | a[z_len:] = 1 84 | c = a[None, :].repeat(x_len, 1) 85 | 86 | diff = torch.stack((diff_h, diff_w, c), dim=-1) 87 | _, indices = torch.unique(diff.view(x_len * (z_len + x_len), 3), return_inverse=True, dim=0) 88 | return indices.view(x_len, (z_len + x_len)) 89 | 90 | 91 | class RelativePosition2DEncoder(nn.Module): 92 | def __init__(self, num_heads, embed_size): 93 | super(RelativePosition2DEncoder, self).__init__() 94 | self.relative_position_bias_table = nn.Parameter(torch.empty((num_heads, embed_size))) 95 | trunc_normal_(self.relative_position_bias_table, std=0.02) 96 | 97 | def forward(self, attn_rpe_index): 98 | ''' 99 | Args: 100 | attn_rpe_index (torch.Tensor): (*), any shape containing indices, 
max(attn_rpe_index) < embed_size
101 |         Returns:
102 |             torch.Tensor: (1, num_heads, *)
103 |         '''
104 |         return self.relative_position_bias_table[:, attn_rpe_index].unsqueeze(0)
105 |
--------------------------------------------------------------------------------
/bbox.py:
--------------------------------------------------------------------------------
1 | """
2 | TODO: build the decoding of the YOLOv5 model output, including bbox decoding and confidence decoding,
3 |       as well as the pixel position of the UAV
4 | Time: 2025/03/11-Redal
5 | """
6 | import os
7 | import cv2
8 | import random
9 | import torch
10 | import numpy as np
11 | from yolov5.utils.general import non_max_suppression, scale_boxes
12 | from yolov5.utils.augmentations import letterbox
13 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14 |
15 |
16 |
17 | ########################### Parse the YOLOv5 model output #############################################
18 | def plot_one_box(x, img, color=None, label=None, line_thickness=None):
19 |     """Plot one bounding box on image img"""
20 |     tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1
21 |     color = color or [random.randint(0, 255) for _ in range(3)]
22 |     c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
23 |     cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
24 |     if label:
25 |         tf = max(tl - 1, 1)
26 |         t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
27 |         c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
28 |         cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA)
29 |         cv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
30 |     return img
31 |
32 |
33 | def decoder(model, img0):
34 |     """Decode the YOLO model output and plot the bounding box
35 |     :param model: the trained YOLO model (s/m/l/x variant)
36 |     :param img0: the UAV frame image captured from the camera"""
37 |     img = letterbox(img0, new_shape=640)[0]
38 |     img = img[:, :, ::-1].transpose(2, 0, 1)
39 |     img = np.ascontiguousarray(img)
40 |     # yolo model inference and postprocess
41 |     img = torch.from_numpy(img).to(device)
42 |     img = img.float() / 255.0
43 |     if img.ndimension() == 3:
44 |         img = img.unsqueeze(0)
45 |     with torch.no_grad():
46 |         pred = model(img)[0]
47 |     # use NMS to remove redundant boxes
48 |     conf_thres = 0.25
49 |     iou_thres = 0.45
50 |     pred = non_max_suppression(pred, conf_thres, iou_thres)
51 |     for det in pred:
52 |         if len(det):
53 |             det[:, :4] = scale_boxes(img.shape[2:], det[:, :4], img0.shape).round()
54 |             for *xyxy, conf, cls in reversed(det):
55 |                 # xyxy contains the coordinates of the upper-left and lower-right corners
56 |                 # of the bounding box, conf is the confidence score, and cls is the class index.
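                # Each det row after NMS is [x1, y1, x2, y2, conf, cls] in img0 pixel coordinates,
                # e.g. (illustrative values only) xyxy = [212., 147., 298., 210.], conf = 0.87, cls = 0.
                # Note that xyxy is rebound on every iteration, so the tuple returned below holds only
                # the last detection of the frame.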
57 | label = f'{model.names[int(cls)]} {conf:.2f}' 58 | print(f"Detected object: {label} at {xyxy}") 59 | plot_one_box(xyxy, img0, label=label, color=(0, 255, 0), line_thickness=3) 60 | try: return img0, xyxy 61 | except: return img0, None 62 | 63 | 64 | 65 | ########################### 用于OSTrack模型规范输入 ############################################# 66 | def ScaleClip(img, xyxy, mode=None): 67 | """ScaleClip is used to clip frame for template and search area 68 | :param img: the frame image must consists of UAV pixels 69 | :param xyxy: the up-left and down-right coordinates of the UAV bounding box""" 70 | img_array = np.array(img) 71 | width, height = xyxy[2] - xyxy[0], xyxy[3] - xyxy[1] 72 | center = np.array([xyxy[0] + width / 2, xyxy[1] + height / 2]) 73 | scale_factor = {'template': 2, 'search': 5}.get(mode, 0) 74 | scaled_width = int(scale_factor * width) 75 | scaled_height = int(scale_factor * height) 76 | # Calculate the cropping rectangle ensuring it does not exceed image boundaries. 77 | top_left_x = max(int(center[0] - scaled_width / 2), 0) 78 | top_left_y = max(int(center[1] - scaled_height / 2), 0) 79 | bottom_right_x = min(int(center[0] + scaled_width / 2), img_array.shape[1]) 80 | bottom_right_y = min(int(center[1] + scaled_height / 2), img_array.shape[0]) 81 | # Clip the image 82 | img_clipped = img_array[top_left_y:bottom_right_y, top_left_x:bottom_right_x, :] 83 | return img_clipped 84 | 85 | 86 | 87 | ############################# 主函数测试分析 ############################################# 88 | if __name__ == '__main__': 89 | img0 = cv2.imread('assets/uav_2.jpg') 90 | # img_templete = ScaleClip(img0, [150, 150, 250, 250], mode='template') 91 | # img_search = ScaleClip(img0, [150, 150, 250, 250], mode='search') 92 | # cv2.imshow('template', img_templete) 93 | # cv2.imshow('search', img_search) 94 | cv2.waitKey(0) 95 | cv2.destroyAllWindows() 96 | -------------------------------------------------------------------------------- /model/ostrack/layers/attn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from timm.models.layers import trunc_normal_ 5 | 6 | from .rpe import generate_2d_concatenated_self_attention_relative_positional_encoding_index 7 | 8 | 9 | class Attention(nn.Module): 10 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., 11 | rpe=False, z_size=7, x_size=14): 12 | super().__init__() 13 | self.num_heads = num_heads 14 | head_dim = dim // num_heads 15 | self.scale = head_dim ** -0.5 16 | 17 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 18 | self.attn_drop = nn.Dropout(attn_drop) 19 | self.proj = nn.Linear(dim, dim) 20 | self.proj_drop = nn.Dropout(proj_drop) 21 | 22 | self.rpe =rpe 23 | if self.rpe: 24 | relative_position_index = \ 25 | generate_2d_concatenated_self_attention_relative_positional_encoding_index([z_size, z_size], 26 | [x_size, x_size]) 27 | self.register_buffer("relative_position_index", relative_position_index) 28 | # define a parameter table of relative position bias 29 | self.relative_position_bias_table = nn.Parameter(torch.empty((num_heads, 30 | relative_position_index.max() + 1))) 31 | trunc_normal_(self.relative_position_bias_table, std=0.02) 32 | 33 | def forward(self, x, mask=None, return_attention=False): 34 | # x: B, N, C 35 | # mask: [B, N, ] torch.bool 36 | B, N, C = x.shape 37 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 38 | q, k, v = 
qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) 39 | 40 | attn = (q @ k.transpose(-2, -1)) * self.scale 41 | 42 | if self.rpe: 43 | relative_position_bias = self.relative_position_bias_table[:, self.relative_position_index].unsqueeze(0) 44 | attn += relative_position_bias 45 | 46 | if mask is not None: 47 | attn = attn.masked_fill(mask.unsqueeze(1).unsqueeze(2), float('-inf'),) 48 | 49 | attn = attn.softmax(dim=-1) 50 | attn = self.attn_drop(attn) 51 | 52 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 53 | x = self.proj(x) 54 | x = self.proj_drop(x) 55 | 56 | if return_attention: 57 | return x, attn 58 | else: 59 | return x 60 | 61 | 62 | class Attention_talking_head(nn.Module): 63 | # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 64 | # with slight modifications to add Talking Heads Attention (https://arxiv.org/pdf/2003.02436v1.pdf) 65 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., 66 | rpe=True, z_size=7, x_size=14): 67 | super().__init__() 68 | 69 | self.num_heads = num_heads 70 | 71 | head_dim = dim // num_heads 72 | 73 | self.scale = qk_scale or head_dim ** -0.5 74 | 75 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 76 | self.attn_drop = nn.Dropout(attn_drop) 77 | 78 | self.proj = nn.Linear(dim, dim) 79 | 80 | self.proj_l = nn.Linear(num_heads, num_heads) 81 | self.proj_w = nn.Linear(num_heads, num_heads) 82 | 83 | self.proj_drop = nn.Dropout(proj_drop) 84 | 85 | self.rpe = rpe 86 | if self.rpe: 87 | relative_position_index = \ 88 | generate_2d_concatenated_self_attention_relative_positional_encoding_index([z_size, z_size], 89 | [x_size, x_size]) 90 | self.register_buffer("relative_position_index", relative_position_index) 91 | # define a parameter table of relative position bias 92 | self.relative_position_bias_table = nn.Parameter(torch.empty((num_heads, 93 | relative_position_index.max() + 1))) 94 | trunc_normal_(self.relative_position_bias_table, std=0.02) 95 | 96 | def forward(self, x, mask=None): 97 | B, N, C = x.shape 98 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 99 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] 100 | 101 | attn = (q @ k.transpose(-2, -1)) 102 | 103 | if self.rpe: 104 | relative_position_bias = self.relative_position_bias_table[:, self.relative_position_index].unsqueeze(0) 105 | attn += relative_position_bias 106 | 107 | if mask is not None: 108 | attn = attn.masked_fill(mask.unsqueeze(1).unsqueeze(2), 109 | float('-inf'),) 110 | 111 | attn = self.proj_l(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) 112 | 113 | attn = attn.softmax(dim=-1) 114 | 115 | attn = self.proj_w(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) 116 | attn = self.attn_drop(attn) 117 | 118 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 119 | x = self.proj(x) 120 | x = self.proj_drop(x) 121 | return x -------------------------------------------------------------------------------- /model/ostrack/ostrack.py: -------------------------------------------------------------------------------- 1 | """ 2 | Basic OSTrack model. 
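The OSTrack class defined here wraps a ViT backbone (plain or with candidate elimination) and a
CORNER / CENTER box head: forward() takes a template crop and a search crop, the backbone jointly
encodes their patch tokens, and forward_head() reshapes the search-region tokens back into a
feature map from which the head predicts a normalized box. build_ostrack() assembles the model
from the YAML config and optionally loads pretrained weights.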
3 | """ 4 | import math 5 | import os 6 | from typing import List 7 | 8 | import torch 9 | from torch import nn 10 | from torch.nn.modules.transformer import _get_clones 11 | 12 | from .layers.head import build_box_head 13 | from .vit import vit_base_patch16_224 14 | from .vit_ce import vit_large_patch16_224_ce, vit_base_patch16_224_ce 15 | from .utils.box_ops import box_xyxy_to_cxcywh 16 | 17 | 18 | class OSTrack(nn.Module): 19 | """ This is the base class for OSTrack """ 20 | 21 | def __init__(self, transformer, box_head, aux_loss=False, head_type="CORNER"): 22 | """ Initializes the model. 23 | Parameters: 24 | transformer: torch module of the transformer architecture. 25 | aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. 26 | """ 27 | super().__init__() 28 | self.backbone = transformer 29 | self.box_head = box_head 30 | 31 | self.aux_loss = aux_loss 32 | self.head_type = head_type 33 | if head_type == "CORNER" or head_type == "CENTER": 34 | self.feat_sz_s = int(box_head.feat_sz) 35 | self.feat_len_s = int(box_head.feat_sz ** 2) 36 | 37 | if self.aux_loss: 38 | self.box_head = _get_clones(self.box_head, 6) 39 | 40 | def forward(self, template: torch.Tensor, 41 | search: torch.Tensor, 42 | ce_template_mask=None, 43 | ce_keep_rate=None, 44 | return_last_attn=False, 45 | ): 46 | x, aux_dict = self.backbone(z=template, x=search, 47 | ce_template_mask=ce_template_mask, 48 | ce_keep_rate=ce_keep_rate, 49 | return_last_attn=return_last_attn, ) 50 | 51 | # Forward head 52 | feat_last = x 53 | if isinstance(x, list): 54 | feat_last = x[-1] 55 | out = self.forward_head(feat_last, None) 56 | 57 | out.update(aux_dict) 58 | out['backbone_feat'] = x 59 | return out 60 | 61 | def forward_head(self, cat_feature, gt_score_map=None): 62 | """ 63 | cat_feature: output embeddings of the backbone, it can be (HW1+HW2, B, C) or (HW2, B, C) 64 | """ 65 | enc_opt = cat_feature[:, -self.feat_len_s:] # encoder output for the search region (B, HW, C) 66 | opt = (enc_opt.unsqueeze(-1)).permute((0, 3, 2, 1)).contiguous() 67 | bs, Nq, C, HW = opt.size() 68 | opt_feat = opt.view(-1, C, self.feat_sz_s, self.feat_sz_s) 69 | 70 | if self.head_type == "CORNER": 71 | # run the corner head 72 | pred_box, score_map = self.box_head(opt_feat, True) 73 | outputs_coord = box_xyxy_to_cxcywh(pred_box) 74 | outputs_coord_new = outputs_coord.view(bs, Nq, 4) 75 | out = {'pred_boxes': outputs_coord_new, 76 | 'score_map': score_map, 77 | } 78 | return out 79 | 80 | elif self.head_type == "CENTER": 81 | # run the center head 82 | score_map_ctr, bbox, size_map, offset_map = self.box_head(opt_feat, gt_score_map) 83 | # outputs_coord = box_xyxy_to_cxcywh(bbox) 84 | outputs_coord = bbox 85 | outputs_coord_new = outputs_coord.view(bs, Nq, 4) 86 | out = {'pred_boxes': outputs_coord_new, 87 | 'score_map': score_map_ctr, 88 | 'size_map': size_map, 89 | 'offset_map': offset_map} 90 | return out 91 | else: 92 | raise NotImplementedError 93 | 94 | 95 | def build_ostrack(cfg, training=True): 96 | current_dir = os.path.dirname(os.path.abspath(__file__)) # This is your Project Root 97 | pretrained_path = os.path.join(current_dir, '../../../pretrained_models') 98 | if cfg['MODEL']['PRETRAIN_FILE'] and ('OSTrack' not in cfg['MODEL']['PRETRAIN_FILE']) and training: 99 | pretrained = os.path.join(pretrained_path, cfg['MODEL']['PRETRAIN_FILE']) 100 | else: 101 | pretrained = '' 102 | if cfg['MODEL']['BACKBONE']['TYPE'] == 'vit_base_patch16_224': 103 | backbone = vit_base_patch16_224(pretrained, 
drop_path_rate=cfg['TRAIN']['DROP_PATH_RATE']) 104 | hidden_dim = backbone.embed_dim 105 | patch_start_index = 1 106 | 107 | elif cfg['MODEL']['BACKBONE']['TYPE'] == 'vit_base_patch16_224_ce': 108 | backbone = vit_base_patch16_224_ce(pretrained, drop_path_rate=cfg['TRAIN']['DROP_PATH_RATE'], 109 | ce_loc=cfg['MODEL']['BACKBONE']['CE_LOC'], 110 | ce_keep_ratio=cfg['MODEL']['BACKBONE']['CE_KEEP_RATIO'], 111 | ) 112 | hidden_dim = backbone.embed_dim 113 | patch_start_index = 1 114 | 115 | elif cfg['MODEL']['BACKBONE']['TYPE'] == 'vit_large_patch16_224_ce': 116 | backbone = vit_large_patch16_224_ce(pretrained, drop_path_rate=cfg['TRAIN']['DROP_PATH_RATE'], 117 | ce_loc=cfg['MODEL']['BACKBONE']['CE_LOC'], 118 | ce_keep_ratio=cfg['MODEL']['BACKBONE']['CE_KEEP_RATIO'], 119 | ) 120 | 121 | hidden_dim = backbone.embed_dim 122 | patch_start_index = 1 123 | 124 | else: 125 | raise NotImplementedError 126 | 127 | backbone.finetune_track(cfg=cfg, patch_start_index=patch_start_index) 128 | 129 | box_head = build_box_head(cfg, hidden_dim) 130 | 131 | model = OSTrack( 132 | backbone, 133 | box_head, 134 | aux_loss=False, 135 | head_type=cfg['MODEL']['HEAD']['TYPE'], 136 | ) 137 | 138 | if 'OSTrack' in cfg['MODEL']['PRETRAIN_FILE'] and training: 139 | checkpoint = torch.load(cfg['MODEL']['PRETRAIN_FILE'], map_location="cpu") 140 | missing_keys, unexpected_keys = model.load_state_dict(checkpoint["net"], strict=False) 141 | print('Load pretrained model from: ' + cfg['MODEL']['PRETRAIN_FILE']) 142 | 143 | return model 144 | -------------------------------------------------------------------------------- /model/ostrack/layers/attn_blocks.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from timm.models.layers import Mlp, DropPath, trunc_normal_, lecun_normal_ 5 | 6 | from .attn import Attention 7 | 8 | 9 | def candidate_elimination(attn: torch.Tensor, tokens: torch.Tensor, lens_t: int, keep_ratio: float, global_index: torch.Tensor, box_mask_z: torch.Tensor): 10 | """ 11 | Eliminate potential background candidates for computation reduction and noise cancellation. 
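    The ranking signal is the template-to-search attention attn[:, :, :lens_t, lens_t:], averaged over
    template tokens and heads (optionally restricted to the template box region via box_mask_z); search
    tokens are sorted by this score and only the top ceil(keep_ratio * lens_s) survive. As a rough
    example, with lens_t = 64, lens_s = 256 and keep_ratio = 0.7, ceil(0.7 * 256) = 180 search tokens
    are kept, and keep_index / removed_index record which original token positions they correspond to.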
12 | Args: 13 | attn (torch.Tensor): [B, num_heads, L_t + L_s, L_t + L_s], attention weights 14 | tokens (torch.Tensor): [B, L_t + L_s, C], template and search region tokens 15 | lens_t (int): length of template 16 | keep_ratio (float): keep ratio of search region tokens (candidates) 17 | global_index (torch.Tensor): global index of search region tokens 18 | box_mask_z (torch.Tensor): template mask used to accumulate attention weights 19 | 20 | Returns: 21 | tokens_new (torch.Tensor): tokens after candidate elimination 22 | keep_index (torch.Tensor): indices of kept search region tokens 23 | removed_index (torch.Tensor): indices of removed search region tokens 24 | """ 25 | lens_s = attn.shape[-1] - lens_t 26 | bs, hn, _, _ = attn.shape 27 | 28 | lens_keep = math.ceil(keep_ratio * lens_s) 29 | if lens_keep == lens_s: 30 | return tokens, global_index, None 31 | 32 | attn_t = attn[:, :, :lens_t, lens_t:] 33 | 34 | if box_mask_z is not None: 35 | box_mask_z = box_mask_z.unsqueeze(1).unsqueeze(-1).expand(-1, attn_t.shape[1], -1, attn_t.shape[-1]) 36 | # attn_t = attn_t[:, :, box_mask_z, :] 37 | attn_t = attn_t[box_mask_z] 38 | attn_t = attn_t.view(bs, hn, -1, lens_s) 39 | attn_t = attn_t.mean(dim=2).mean(dim=1) # B, H, L-T, L_s --> B, L_s 40 | 41 | # attn_t = [attn_t[i, :, box_mask_z[i, :], :] for i in range(attn_t.size(0))] 42 | # attn_t = [attn_t[i].mean(dim=1).mean(dim=0) for i in range(len(attn_t))] 43 | # attn_t = torch.stack(attn_t, dim=0) 44 | else: 45 | attn_t = attn_t.mean(dim=2).mean(dim=1) # B, H, L-T, L_s --> B, L_s 46 | 47 | # use sort instead of topk, due to the speed issue 48 | # https://github.com/pytorch/pytorch/issues/22812 49 | sorted_attn, indices = torch.sort(attn_t, dim=1, descending=True) 50 | 51 | topk_attn, topk_idx = sorted_attn[:, :lens_keep], indices[:, :lens_keep] 52 | non_topk_attn, non_topk_idx = sorted_attn[:, lens_keep:], indices[:, lens_keep:] 53 | 54 | keep_index = global_index.gather(dim=1, index=topk_idx) 55 | removed_index = global_index.gather(dim=1, index=non_topk_idx) 56 | 57 | # separate template and search tokens 58 | tokens_t = tokens[:, :lens_t] 59 | tokens_s = tokens[:, lens_t:] 60 | 61 | # obtain the attentive and inattentive tokens 62 | B, L, C = tokens_s.shape 63 | # topk_idx_ = topk_idx.unsqueeze(-1).expand(B, lens_keep, C) 64 | attentive_tokens = tokens_s.gather(dim=1, index=topk_idx.unsqueeze(-1).expand(B, -1, C)) 65 | # inattentive_tokens = tokens_s.gather(dim=1, index=non_topk_idx.unsqueeze(-1).expand(B, -1, C)) 66 | 67 | # compute the weighted combination of inattentive tokens 68 | # fused_token = non_topk_attn @ inattentive_tokens 69 | 70 | # concatenate these tokens 71 | # tokens_new = torch.cat([tokens_t, attentive_tokens, fused_token], dim=0) 72 | tokens_new = torch.cat([tokens_t, attentive_tokens], dim=1) 73 | 74 | return tokens_new, keep_index, removed_index 75 | 76 | 77 | class CEBlock(nn.Module): 78 | 79 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., 80 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, keep_ratio_search=1.0,): 81 | super().__init__() 82 | self.norm1 = norm_layer(dim) 83 | self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) 84 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 85 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 86 | self.norm2 = norm_layer(dim) 87 | mlp_hidden_dim = int(dim * mlp_ratio) 88 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 89 | 90 | self.keep_ratio_search = keep_ratio_search 91 | 92 | def forward(self, x, global_index_template, global_index_search, mask=None, ce_template_mask=None, keep_ratio_search=None): 93 | x_attn, attn = self.attn(self.norm1(x), mask, True) 94 | x = x + self.drop_path(x_attn) 95 | lens_t = global_index_template.shape[1] 96 | 97 | removed_index_search = None 98 | if self.keep_ratio_search < 1 and (keep_ratio_search is None or keep_ratio_search < 1): 99 | keep_ratio_search = self.keep_ratio_search if keep_ratio_search is None else keep_ratio_search 100 | x, global_index_search, removed_index_search = candidate_elimination(attn, x, lens_t, keep_ratio_search, global_index_search, ce_template_mask) 101 | 102 | x = x + self.drop_path(self.mlp(self.norm2(x))) 103 | return x, global_index_template, global_index_search, removed_index_search, attn 104 | 105 | 106 | class Block(nn.Module): 107 | 108 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., 109 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 110 | super().__init__() 111 | self.norm1 = norm_layer(dim) 112 | self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) 113 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 114 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 115 | self.norm2 = norm_layer(dim) 116 | mlp_hidden_dim = int(dim * mlp_ratio) 117 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 118 | 119 | def forward(self, x, mask=None): 120 | x = x + self.drop_path(self.attn(self.norm1(x), mask)) 121 | x = x + self.drop_path(self.mlp(self.norm2(x))) 122 | return x 123 | -------------------------------------------------------------------------------- /yolo_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | TODO: 使用YOLOv5进行无人机图像检测定位测试, 3 | 数据集采用红外影像数据./datasets/test1 4 | 进行分割重构数据集 5 | Time: 2025/03/07-Redal 6 | """ 7 | import pathlib 8 | temp = pathlib.PosixPath 9 | pathlib.PosixPath = pathlib.WindowsPath 10 | 11 | import os 12 | import cv2 13 | import argparse 14 | import sys 15 | import torch 16 | import numpy as np 17 | import matplotlib.pyplot as plt 18 | from yolov5.models.experimental import attempt_load 19 | from bbox import decoder 20 | 21 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 22 | proj_path = os.getcwd() 23 | sys.path.append(os.path.join(proj_path, "yolov5")) 24 | sys.path.append(os.path.abspath(os.path.join(os.getcwd(), ".."))) 25 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 26 | 27 | 28 | 29 | ############################## 配置解析文件变量 ############################### 30 | def parser(): 31 | parser = argparse.ArgumentParser(description='YOLOv5 inference config', 32 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 33 | parser.add_argument_group('YOLOv5 Dataset Reconstruction') 34 | parser.add_argument('--gt_dir', type=str, default=r'datasets\images', help='ground truth file directory path(s)') 35 | parser.add_argument('--gt_mode', type=str, default=r'train', help='ground truth file mode [train, val]') 36 | parser.add_argument('--gt_file', type=str, default=r'groundtruth.txt', help='ground truth file name') 37 | 
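    # Expected layout, inferred from the defaults above and the rebuild functions below:
    #   datasets/test/<sequence>/{00000000.jpg, ..., groundtruth.txt}            raw GOT-10k-style sequences
    #   datasets/images/{train,val}/*.jpg and datasets/labels/{train,val}/*.txt  YOLO-format image/label pairs
    # Each label line is "<class> <cx> <cy> <w> <h>" normalized by the 640x512 frame, so a groundtruth
    # row "300,200,40,30" (illustrative numbers) becomes "0 0.5 0.419921875 0.0625 0.05859375".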
parser.add_argument('--label_dir', type=str, default=r'datasets\labels', help='labels file directory path(s)') 38 | parser.add_argument('--label_mode', type=str, default=r'train', help='labels file mode [train, val]') 39 | parser.add_argument('--origin_dir', type=str, default=r'datasets\test', help='images file directory path(s)') 40 | # related to YOLOv5 model and weights 41 | parser.add_argument_group('YOLOv5 Model Configuration') 42 | parser.add_argument('--weights_dir', type=str, default=r'weights', help='weights file directory path(s)') 43 | parser.add_argument('--weights_file', type=str, default=r'yolov5l.pt', help='weights file name yolov5s/m/l/x.pt') 44 | parser.add_argument('--yolo_dir', type=str, default=r'yolov5', help='yolov5 file directory path(s)') 45 | args = parser.parse_args() 46 | return args 47 | 48 | 49 | 50 | ############################## 重构Got_10k数据集 ############################### 51 | def split_groundtruth_to_individual_labels(args): 52 | """Split the label information in the groundtruth.txt file into the txt file for each image 53 | :param input image is 640x512 pixels width x height""" 54 | gt_filepath = os.path.join(args.gt_dir, args.gt_mode, args.gt_file) 55 | output_folder = os.path.join(args.label_dir, args.label_mode) 56 | img_w = 640; img_h = 512 57 | with open(gt_filepath, 'r') as f: 58 | lines = f.readlines() 59 | for idx,line in enumerate(lines): 60 | parts = line.strip().split(',') 61 | lr_w, lr_h = float(parts[0]), float(parts[1]) 62 | box_w, box_h = float(parts[2]), float(parts[3]) 63 | # normalization 64 | cneter_w, center_h = (lr_w + box_w / 2) / img_w , (lr_h + box_h / 2) / img_h 65 | bounding_w, bounding_h = box_w / img_w, box_h / img_h 66 | txt_file_path = os.path.join(output_folder, f'{idx:08d}.txt') 67 | with open(txt_file_path, 'w') as txt_file: 68 | txt_file.write(f'{0} {cneter_w} {center_h} {bounding_w} {bounding_h}\n') 69 | print(f'{txt_file_path} has been written', end='\r', flush=True) 70 | 71 | 72 | def rebuild_data(args): 73 | """Rebuild the images/labels dataset form test directory""" 74 | img_w = 640; img_h = 512 75 | origin_dir = args.origin_dir 76 | original_dirspath = [os.path.join(origin_dir, dirname) for dirname in os.listdir(origin_dir)] 77 | for idx, dirpath in enumerate(original_dirspath): 78 | with open(os.path.join(dirpath, 'groundtruth.txt')) as gt: 79 | gt_lines = gt.readlines() 80 | if idx % 2 == 0: 81 | args.gt_mode = 'train' 82 | args.label_mode = 'train' 83 | else: 84 | args.gt_mode = 'val' 85 | args.label_mode = 'val' 86 | for id,line in enumerate(gt_lines): 87 | num_files = len(os.listdir(os.path.join(args.gt_dir, args.gt_mode))) 88 | # read image and save it to train images folder 89 | img_path = os.path.join(dirpath, f'{id:08d}.jpg') 90 | img = cv2.imread(img_path) 91 | img_save_path = os.path.join(args.gt_dir, args.gt_mode, f'{num_files:08d}.jpg') 92 | cv2.imwrite(img_save_path, img) 93 | # read the groundtruth.txt file into train labels folder 94 | parts = line.strip().split(',') 95 | lr_w, lr_h = float(parts[0]), float(parts[1]) 96 | box_w, box_h = float(parts[2]), float(parts[3]) 97 | cneter_w, center_h = (lr_w + box_w / 2) / img_w , (lr_h + box_h / 2) / img_h 98 | bounding_w, bounding_h = box_w / img_w, box_h / img_h 99 | txt_filepath = os.path.join(args.label_dir, args.label_mode, f'{num_files:08d}.txt') 100 | with open(txt_filepath, 'w') as txtf: 101 | txtf.write(f'{0} {cneter_w} {center_h} {bounding_w} {bounding_h}\n') 102 | print(f'{dirpath}: {txt_filepath} and {img_save_path} has been written', end='\r', 
flush=True) 103 | 104 | 105 | 106 | ############################## 部署Yolov5模型 ############################### 107 | def load_yolo(args): 108 | """Load the yolov5 model weights and return the model 109 | :param args.weights_dir: weights file directory path(s)""" 110 | weights_fp = os.path.join(args.weights_dir, args.weights_file) 111 | print(f'Loading weights from {weights_fp} and yolov5 is {args.yolo_dir}') 112 | yolo_model = attempt_load(weights=weights_fp, device=device) 113 | return yolo_model.eval() 114 | 115 | 116 | 117 | ############################## 主函数测试分析 ############################### 118 | if __name__ == '__main__': 119 | args = parser() 120 | model = load_yolo(args).to(device) 121 | # print(model) 122 | 123 | img0 = cv2.imread(r'assets\uav_1.jpg') 124 | img_de = decoder(model, img0) 125 | cv2.imshow('img', img_de) 126 | cv2.waitKey(0) 127 | cv2.destroyAllWindows() -------------------------------------------------------------------------------- /model/ostrack/utils/heapmap_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def generate_heatmap(bboxes, patch_size=320, stride=16): 6 | """ 7 | Generate ground truth heatmap same as CenterNet 8 | Args: 9 | bboxes (torch.Tensor): shape of [num_search, bs, 4] 10 | 11 | Returns: 12 | gaussian_maps: list of generated heatmap 13 | 14 | """ 15 | gaussian_maps = [] 16 | heatmap_size = patch_size // stride 17 | for single_patch_bboxes in bboxes: 18 | bs = single_patch_bboxes.shape[0] 19 | gt_scoremap = torch.zeros(bs, heatmap_size, heatmap_size) 20 | classes = torch.arange(bs).to(torch.long) 21 | bbox = single_patch_bboxes * heatmap_size 22 | wh = bbox[:, 2:] 23 | centers_int = (bbox[:, :2] + wh / 2).round() 24 | CenterNetHeatMap.generate_score_map(gt_scoremap, classes, wh, centers_int, 0.7) # 生成的高斯热图放到gt_scoremap中 25 | gaussian_maps.append(gt_scoremap.to(bbox.device)) 26 | return gaussian_maps 27 | 28 | 29 | class CenterNetHeatMap(object): 30 | @staticmethod 31 | def generate_score_map(fmap, gt_class, gt_wh, centers_int, min_overlap): 32 | radius = CenterNetHeatMap.get_gaussian_radius(gt_wh, min_overlap) # 生成高斯核的半径 33 | radius = torch.clamp_min(radius, 0) 34 | radius = radius.type(torch.int).cpu().numpy() 35 | for i in range(gt_class.shape[0]): 36 | channel_index = gt_class[i] 37 | CenterNetHeatMap.draw_gaussian(fmap[channel_index], centers_int[i], radius[i]) # 根据中心点坐标和半径画出高斯热图,并放到fmap中 38 | 39 | @staticmethod 40 | def get_gaussian_radius(box_size, min_overlap): 41 | """ 42 | copyed from CornerNet 43 | box_size (w, h), it could be a torch.Tensor, numpy.ndarray, list or tuple 44 | notice: we are using a bug-version, please refer to fix bug version in CornerNet 45 | """ 46 | # box_tensor = torch.Tensor(box_size) 47 | box_tensor = box_size 48 | width, height = box_tensor[..., 0], box_tensor[..., 1] 49 | 50 | a1 = 1 51 | b1 = height + width 52 | c1 = width * height * (1 - min_overlap) / (1 + min_overlap) 53 | sq1 = torch.sqrt(b1 ** 2 - 4 * a1 * c1) 54 | r1 = (b1 + sq1) / 2 55 | 56 | a2 = 4 57 | b2 = 2 * (height + width) 58 | c2 = (1 - min_overlap) * width * height 59 | sq2 = torch.sqrt(b2 ** 2 - 4 * a2 * c2) 60 | r2 = (b2 + sq2) / 2 61 | 62 | a3 = 4 * min_overlap 63 | b3 = -2 * min_overlap * (height + width) 64 | c3 = (min_overlap - 1) * width * height 65 | sq3 = torch.sqrt(b3 ** 2 - 4 * a3 * c3) 66 | r3 = (b3 + sq3) / 2 67 | 68 | return torch.min(r1, torch.min(r2, r3)) 69 | 70 | @staticmethod 71 | def gaussian2D(radius, sigma=1): 72 | # m, n = [(s - 
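        # gaussian2D builds an unnormalized (2*radius+1) x (2*radius+1) kernel: np.ogrid spans
        # [-radius, radius] in y and x, each entry is exp(-(x^2 + y^2) / (2*sigma^2)) so the centre
        # equals 1, and draw_gaussian below pastes the (possibly clipped) kernel into the heatmap
        # with an element-wise max; e.g. radius = 2 gives a 5x5 kernel with sigma = diameter / 6 = 5/6.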
1.) / 2. for s in shape] 73 | m, n = radius 74 | y, x = np.ogrid[-m: m + 1, -n: n + 1] # 生成网格,y和x代表两个方向 75 | 76 | gauss = np.exp(-(x * x + y * y) / (2 * sigma * sigma)) # 对应原点 77 | gauss[gauss < np.finfo(gauss.dtype).eps * gauss.max()] = 0 # 将高斯核中值非常接近于零的元素设置为零 78 | return gauss 79 | 80 | @staticmethod 81 | def draw_gaussian(fmap, center, radius, k=1): 82 | diameter = 2 * radius + 1 83 | gaussian = CenterNetHeatMap.gaussian2D((radius, radius), sigma=diameter / 6) # 生成二维高斯核,没有做归一化处理,只有中间的那个数为1 84 | gaussian = torch.Tensor(gaussian) 85 | x, y = int(center[0]), int(center[1]) 86 | height, width = fmap.shape[:2] 87 | 88 | left, right = min(x, radius), min(width - x, radius + 1) 89 | top, bottom = min(y, radius), min(height - y, radius + 1) 90 | 91 | masked_fmap = fmap[y - top: y + bottom, x - left: x + right] 92 | masked_gaussian = gaussian[radius - top: radius + bottom, radius - left: radius + right] 93 | if min(masked_gaussian.shape) > 0 and min(masked_fmap.shape) > 0: 94 | masked_fmap = torch.max(masked_fmap, masked_gaussian * k) 95 | fmap[y - top: y + bottom, x - left: x + right] = masked_fmap 96 | # return fmap 97 | 98 | 99 | def compute_grids(features, strides): 100 | """ 101 | grids regret to the input image size 102 | """ 103 | grids = [] 104 | for level, feature in enumerate(features): 105 | h, w = feature.size()[-2:] 106 | shifts_x = torch.arange( 107 | 0, w * strides[level], 108 | step=strides[level], 109 | dtype=torch.float32, device=feature.device) 110 | shifts_y = torch.arange( 111 | 0, h * strides[level], 112 | step=strides[level], 113 | dtype=torch.float32, device=feature.device) 114 | shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x) 115 | shift_x = shift_x.reshape(-1) 116 | shift_y = shift_y.reshape(-1) 117 | grids_per_level = torch.stack((shift_x, shift_y), dim=1) + \ 118 | strides[level] // 2 119 | grids.append(grids_per_level) 120 | return grids 121 | 122 | 123 | def get_center3x3(locations, centers, strides, range=3): 124 | ''' 125 | Inputs: 126 | locations: M x 2 127 | centers: N x 2 128 | strides: M 129 | ''' 130 | range = (range - 1) / 2 131 | M, N = locations.shape[0], centers.shape[0] 132 | locations_expanded = locations.view(M, 1, 2).expand(M, N, 2) # M x N x 2 133 | centers_expanded = centers.view(1, N, 2).expand(M, N, 2) # M x N x 2 134 | strides_expanded = strides.view(M, 1, 1).expand(M, N, 2) # M x N 135 | centers_discret = ((centers_expanded / strides_expanded).int() * strides_expanded).float() + \ 136 | strides_expanded / 2 # M x N x 2 137 | dist_x = (locations_expanded[:, :, 0] - centers_discret[:, :, 0]).abs() 138 | dist_y = (locations_expanded[:, :, 1] - centers_discret[:, :, 1]).abs() 139 | return (dist_x <= strides_expanded[:, :, 0] * range) & \ 140 | (dist_y <= strides_expanded[:, :, 0] * range) 141 | 142 | 143 | def get_pred(score_map_ctr, size_map, offset_map, feat_size): 144 | max_score, idx = torch.max(score_map_ctr.flatten(1), dim=1, keepdim=True) 145 | 146 | idx = idx.unsqueeze(1).expand(idx.shape[0], 2, 1) 147 | size = size_map.flatten(2).gather(dim=2, index=idx).squeeze(-1) 148 | offset = offset_map.flatten(2).gather(dim=2, index=idx).squeeze(-1) 149 | 150 | return size * feat_size, offset 151 | -------------------------------------------------------------------------------- /model/ostrack/base_backbone.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from timm.models.vision_transformer 
import resize_pos_embed 6 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 7 | 8 | from .layers.patch_embed import PatchEmbed 9 | from .util import combine_tokens, recover_tokens 10 | 11 | 12 | class BaseBackbone(nn.Module): 13 | def __init__(self): 14 | super().__init__() 15 | 16 | # for original ViT 17 | self.pos_embed = None 18 | self.img_size = [224, 224] 19 | self.patch_size = 16 20 | self.embed_dim = 384 21 | 22 | self.cat_mode = 'direct' 23 | 24 | self.pos_embed_z = None 25 | self.pos_embed_x = None 26 | 27 | self.template_segment_pos_embed = None 28 | self.search_segment_pos_embed = None 29 | 30 | self.return_inter = False 31 | self.return_stage = [2, 5, 8, 11] 32 | 33 | self.add_cls_token = False 34 | self.add_sep_seg = False 35 | 36 | def finetune_track(self, cfg, patch_start_index=1): 37 | 38 | search_size = to_2tuple(cfg['DATA']['SEARCH']['SIZE']) 39 | template_size = to_2tuple(cfg['DATA']['TEMPLATE']['SIZE']) 40 | new_patch_size = cfg['MODEL']['BACKBONE']['STRIDE'] 41 | 42 | self.cat_mode = cfg['MODEL']['BACKBONE']['CAT_MODE'] 43 | self.return_inter = cfg['MODEL']['RETURN_INTER'] 44 | self.return_stage = cfg['MODEL']['RETURN_STAGES'] 45 | self.add_sep_seg = cfg['MODEL']['BACKBONE']['SEP_SEG'] 46 | 47 | # resize patch embedding 48 | if new_patch_size != self.patch_size: 49 | print('Inconsistent Patch Size With The Pretrained Weights, Interpolate The Weight!') 50 | old_patch_embed = {} 51 | for name, param in self.patch_embed.named_parameters(): 52 | if 'weight' in name: 53 | param = nn.functional.interpolate(param, size=(new_patch_size, new_patch_size), 54 | mode='bicubic', align_corners=False) 55 | param = nn.Parameter(param) 56 | old_patch_embed[name] = param 57 | self.patch_embed = PatchEmbed(img_size=self.img_size, patch_size=new_patch_size, in_chans=3, 58 | embed_dim=self.embed_dim) 59 | self.patch_embed.proj.bias = old_patch_embed['proj.bias'] 60 | self.patch_embed.proj.weight = old_patch_embed['proj.weight'] 61 | 62 | # for patch embedding 63 | patch_pos_embed = self.pos_embed[:, patch_start_index:, :] 64 | patch_pos_embed = patch_pos_embed.transpose(1, 2) 65 | B, E, Q = patch_pos_embed.shape 66 | P_H, P_W = self.img_size[0] // self.patch_size, self.img_size[1] // self.patch_size 67 | patch_pos_embed = patch_pos_embed.view(B, E, P_H, P_W) 68 | 69 | # for search region 70 | H, W = search_size 71 | new_P_H, new_P_W = H // new_patch_size, W // new_patch_size 72 | search_patch_pos_embed = nn.functional.interpolate(patch_pos_embed, size=(new_P_H, new_P_W), mode='bicubic', 73 | align_corners=False) 74 | search_patch_pos_embed = search_patch_pos_embed.flatten(2).transpose(1, 2) 75 | 76 | # for template region 77 | H, W = template_size 78 | new_P_H, new_P_W = H // new_patch_size, W // new_patch_size 79 | template_patch_pos_embed = nn.functional.interpolate(patch_pos_embed, size=(new_P_H, new_P_W), mode='bicubic', 80 | align_corners=False) 81 | template_patch_pos_embed = template_patch_pos_embed.flatten(2).transpose(1, 2) 82 | 83 | self.pos_embed_z = nn.Parameter(template_patch_pos_embed) 84 | self.pos_embed_x = nn.Parameter(search_patch_pos_embed) 85 | 86 | # for cls token (keep it but not used) 87 | if self.add_cls_token and patch_start_index > 0: 88 | cls_pos_embed = self.pos_embed[:, 0:1, :] 89 | self.cls_pos_embed = nn.Parameter(cls_pos_embed) 90 | 91 | # separate token and segment token 92 | if self.add_sep_seg: 93 | self.template_segment_pos_embed = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) 94 | self.template_segment_pos_embed = 
trunc_normal_(self.template_segment_pos_embed, std=.02) 95 | self.search_segment_pos_embed = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) 96 | self.search_segment_pos_embed = trunc_normal_(self.search_segment_pos_embed, std=.02) 97 | 98 | # self.cls_token = None 99 | # self.pos_embed = None 100 | 101 | if self.return_inter: 102 | for i_layer in self.return_stage: 103 | if i_layer != 11: 104 | norm_layer = partial(nn.LayerNorm, eps=1e-6) 105 | layer = norm_layer(self.embed_dim) 106 | layer_name = f'norm{i_layer}' 107 | self.add_module(layer_name, layer) 108 | 109 | def forward_features(self, z, x): 110 | B, H, W = x.shape[0], x.shape[2], x.shape[3] 111 | 112 | x = self.patch_embed(x) 113 | z = self.patch_embed(z) 114 | 115 | if self.add_cls_token: 116 | cls_tokens = self.cls_token.expand(B, -1, -1) 117 | cls_tokens = cls_tokens + self.cls_pos_embed 118 | 119 | z += self.pos_embed_z 120 | x += self.pos_embed_x 121 | 122 | if self.add_sep_seg: 123 | x += self.search_segment_pos_embed 124 | z += self.template_segment_pos_embed 125 | 126 | x = combine_tokens(z, x, mode=self.cat_mode) 127 | if self.add_cls_token: 128 | x = torch.cat([cls_tokens, x], dim=1) 129 | 130 | x = self.pos_drop(x) 131 | 132 | for i, blk in enumerate(self.blocks): 133 | x = blk(x) 134 | 135 | lens_z = self.pos_embed_z.shape[1] 136 | lens_x = self.pos_embed_x.shape[1] 137 | x = recover_tokens(x, lens_z, lens_x, mode=self.cat_mode) 138 | 139 | aux_dict = {"attn": None} 140 | return self.norm(x), aux_dict 141 | 142 | def forward(self, z, x, **kwargs): 143 | """ 144 | Joint feature extraction and relation modeling for the basic ViT backbone. 145 | Args: 146 | z (torch.Tensor): template feature, [B, C, H_z, W_z] 147 | x (torch.Tensor): search region feature, [B, C, H_x, W_x] 148 | 149 | Returns: 150 | x (torch.Tensor): merged template and search region feature, [B, L_z+L_x, C] 151 | attn : None 152 | """ 153 | x, aux_dict = self.forward_features(z, x,) 154 | 155 | return x, aux_dict 156 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | """ 2 | 任务: 创建基本的OSTrack模型所需的实例已经辅助函数, 3 | 同时包括相关的机器视觉处理函数 4 | 时间: 2025/01/13-Redal 5 | """ 6 | import os 7 | import sys 8 | import cv2 9 | import yaml 10 | import torch 11 | import argparse 12 | import numpy as np 13 | from PIL import Image, ImageDraw 14 | from torchvision.transforms import transforms 15 | from model.ostrack.ostrack import build_ostrack 16 | import matplotlib.pyplot as plt 17 | 18 | current_path = os.path.abspath(os.path.dirname(__file__)) 19 | sys.path.append(current_path) 20 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 21 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 22 | template_transform = transforms.Compose([ 23 | transforms.ToTensor(), 24 | transforms.Resize((192, 192)), 25 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),]) 26 | search_transform = transforms.Compose([ 27 | transforms.ToTensor(), 28 | transforms.Resize((384, 384)), 29 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),]) 30 | 31 | 32 | 33 | ############################### tkinter GUI相关绘图函数 ############################ 34 | def ComputeHistogramImage(frame, hist_height=200, hist_width=300): 35 | """使用 OpenCV 绘制 RGB 直方图 36 | :param frame: 输入帧(BGR 格式) 37 | :param hist_height: 直方图图像的高度 38 | :param hist_width: 直方图图像的宽度 39 | :return: 直方图图像 40 | """ 41 | # 创建灰色背景图像 42 | hist_image = np.full((hist_height, 
hist_width, 3), fill_value=0, dtype=np.uint8) 43 | colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255)] 44 | 45 | # 计算每个通道的直方图 46 | for i, color in enumerate(colors): 47 | hist = cv2.calcHist([frame], [i], None, [256], [0, 256]) 48 | cv2.normalize(hist, hist, 0, hist_height - 20, cv2.NORM_MINMAX) # 留出空间给坐标轴 49 | for j in range(1, 256): 50 | # 约束直方图在图像范围内 51 | x1 = (j - 1) * (hist_width // 256) 52 | y1 = hist_height - 10 - int(hist[j - 1]) # 留出空间给横轴 53 | x2 = j * (hist_width // 256) 54 | y2 = hist_height - 10 - int(hist[j]) # 留出空间给横轴 55 | cv2.line(hist_image, (x1, y1), (x2, y2), color, thickness=2) 56 | cv2.line(hist_image, (0, hist_height - 10), (hist_width, hist_height - 10), (0, 0, 0), thickness=2) # 横轴 57 | cv2.line(hist_image, (10, 0), (10, hist_height - 10), (0, 0, 0), thickness=2) 58 | 59 | # 添加横轴刻度线和标签 60 | for i in range(0, 256, 32): 61 | x = i * (hist_width // 256) 62 | cv2.line(hist_image, (x, hist_height - 10), (x, hist_height - 5), (255, 255, 255), thickness=1) 63 | cv2.putText(hist_image, str(i), (x - 10, hist_height - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1) 64 | # 添加纵轴刻度线和标签 65 | for i in range(0, hist_height - 10, 50): 66 | y = hist_height - 10 - i 67 | cv2.line(hist_image, (10, y), (15, y), (255, 255, 255), thickness=1) 68 | cv2.putText(hist_image, str(i), (20, y + 5), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1) 69 | return hist_image 70 | 71 | 72 | def CalculateSpectrogramImage(frame, spec_height=200, spec_width=200): 73 | """使用 OpenCV 绘制频谱图 74 | :param frame: 输入帧(BGR 格式) 75 | :param spec_height: 频谱图图像的高度 76 | :param spec_width: 频谱图图像的宽度 77 | :return: 频谱图图像 78 | """ 79 | gray_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) 80 | fft = np.fft.fft2(gray_frame) 81 | fft_shift = np.fft.fftshift(fft) # 将低频部分移到中心 82 | magnitude_spectrum = np.log(np.abs(fft_shift) + 1) # 计算幅度谱并取对数 83 | 84 | magnitude_spectrum = cv2.normalize(magnitude_spectrum, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U) 85 | spectrogram_image = cv2.resize(magnitude_spectrum, (spec_width, spec_height)) 86 | 87 | # 频谱图美化:应用伪彩色映射 88 | spectrogram_color = cv2.applyColorMap(spectrogram_image, cv2.COLORMAP_JET) 89 | grid_color = (255, 255, 255) 90 | grid_spacing = 50 91 | for x in range(0, spec_width, grid_spacing): 92 | cv2.line(spectrogram_color, (x, 0), (x, spec_height), grid_color, 1) 93 | for y in range(0, spec_height, grid_spacing): 94 | cv2.line(spectrogram_color, (0, y), (spec_width, y), grid_color, 1) 95 | # 添加标题和坐标轴标签 96 | title = "Spectrogram" 97 | cv2.putText(spectrogram_color, title, (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1) 98 | cv2.putText(spectrogram_color, "Frequency", (10, spec_height - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 0), 1) 99 | cv2.putText(spectrogram_color, "Time", (spec_width - 50, spec_height - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 0), 1) 100 | # 调整对比度和亮度 101 | alpha = 1.2 # 对比度系数 102 | beta = 30 # 亮度系数 103 | spectrogram_color = cv2.convertScaleAbs(spectrogram_color, alpha=alpha, beta=beta) 104 | return spectrogram_color 105 | 106 | 107 | ############################### 配置ostrack模型解析文件 ############################ 108 | def load_config(args): 109 | """read the configuration file""" 110 | config_path = os.path.join(args.config_dir, args.config_file) 111 | with open(config_path, 'r') as file: 112 | config = yaml.safe_load(file) 113 | return config 114 | 115 | class Params: 116 | """Load retrained model parameters""" 117 | def __init__(self, args): 118 | self.checkpoint = os.path.join(args.weight_dir, args.weight_file) 119 | self.debug = False 120 | 
self.save_all_boxes = False 121 | 122 | def config(): 123 | """make configurations about the vit model""" 124 | parser = argparse.ArgumentParser(description='OSTrack model configuration', 125 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 126 | parser.add_argument('--config_dir', default='./config', type=str, 127 | help='The directory of the configuration file') 128 | parser.add_argument('--config_file', default='vitb_384_mae_ce_32x4_got10k_ep100.yaml', type=str, 129 | help='The name of the configuration file') 130 | parser.add_argument('--weight_dir', default='./weights', type=str, 131 | help='the directory of the weight file') 132 | parser.add_argument('--weight_file', default='vit_384_mae_ce.pth', type=str, 133 | help='the name of the weight file OSTrack_ep0061.pth / vit_384_mae_ce.pth') 134 | args = parser.parse_args() 135 | # initialize the config and model weight 136 | cfg = load_config(args) 137 | ostrack_model = build_ostrack(cfg, training=False) 138 | weight_path = os.path.join(args.weight_dir, args.weight_file) 139 | ostrack_model.load_state_dict(torch.load(weight_path, map_location=device)) 140 | 141 | # params = Params(args) 142 | # ostrack_model = build_ostrack(cfg, training=False) 143 | # ostrack_model.load_state_dict(torch.load(params.checkpoint, map_location='cpu')['net'], strict=True) 144 | return ostrack_model 145 | 146 | 147 | 148 | ############################### 主函数测试分析 ################################ 149 | if __name__ == '__main__': 150 | # test the model function, and the model is processing the neighbor frames 151 | # called the template and search image with boundding box 2~5 times 152 | ostrack_model = config().eval().to(device) 153 | print(ostrack_model) 154 | 155 | 156 | template_img = Image.open('assets/uav_1.jpg') 157 | search_img = Image.open('assets/uav_3.jpg') 158 | template_img = template_transform(template_img).unsqueeze(0).to(device) 159 | search_img = search_transform(search_img).unsqueeze(0).to(device) 160 | results = ostrack_model(template_img, search_img) 161 | answer = results['pred_boxes'][0] 162 | bbox = answer.detach().cpu().numpy()[0] 163 | # depict the result 164 | search_img = cv2.imread('assets/uav_3.jpg') 165 | print(results, bbox) 166 | height, width, _ = search_img.shape 167 | x_min = int(bbox[0] * width) 168 | y_min = int(bbox[1] * height) 169 | x_max = int(bbox[2] * width) 170 | y_max = int(bbox[3] * height) 171 | 172 | color = (255, 0, 0) # 绿色 173 | thickness = 2 174 | cv2.rectangle(search_img, (x_min, y_min), (x_max, y_max), color, thickness) 175 | # 使用 matplotlib 显示图像 176 | plt.imshow(cv2.cvtColor(search_img, cv2.COLOR_BGR2RGB)) 177 | plt.axis('off') # 关闭坐标轴 178 | plt.show() 179 | cv2.imwrite('assets/uav_1_result.jpg', search_img) -------------------------------------------------------------------------------- /model/ostrack/utils/tensor.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch 3 | import copy 4 | from collections import OrderedDict 5 | 6 | 7 | class TensorDict(OrderedDict): 8 | """Container mainly used for dicts of torch tensors. 
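    Any torch.Tensor method name looked up on the dict is applied to every value, e.g. (a minimal
    sketch) TensorDict({'z': torch.zeros(2), 'x': torch.ones(3)}).float() returns a new TensorDict
    whose values are float tensors.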
Extends OrderedDict with pytorch functionality.""" 9 | 10 | def concat(self, other): 11 | """Concatenates two dicts without copying internal data.""" 12 | return TensorDict(self, **other) 13 | 14 | def copy(self): 15 | return TensorDict(super(TensorDict, self).copy()) 16 | 17 | def __deepcopy__(self, memodict={}): 18 | return TensorDict(copy.deepcopy(list(self), memodict)) 19 | 20 | def __getattr__(self, name): 21 | if not hasattr(torch.Tensor, name): 22 | raise AttributeError('\'TensorDict\' object has not attribute \'{}\''.format(name)) 23 | 24 | def apply_attr(*args, **kwargs): 25 | return TensorDict({n: getattr(e, name)(*args, **kwargs) if hasattr(e, name) else e for n, e in self.items()}) 26 | return apply_attr 27 | 28 | def attribute(self, attr: str, *args): 29 | return TensorDict({n: getattr(e, attr, *args) for n, e in self.items()}) 30 | 31 | def apply(self, fn, *args, **kwargs): 32 | return TensorDict({n: fn(e, *args, **kwargs) for n, e in self.items()}) 33 | 34 | @staticmethod 35 | def _iterable(a): 36 | return isinstance(a, (TensorDict, list)) 37 | 38 | 39 | class TensorList(list): 40 | """Container mainly used for lists of torch tensors. Extends lists with pytorch functionality.""" 41 | 42 | def __init__(self, list_of_tensors = None): 43 | if list_of_tensors is None: 44 | list_of_tensors = list() 45 | super(TensorList, self).__init__(list_of_tensors) 46 | 47 | def __deepcopy__(self, memodict={}): 48 | return TensorList(copy.deepcopy(list(self), memodict)) 49 | 50 | def __getitem__(self, item): 51 | if isinstance(item, int): 52 | return super(TensorList, self).__getitem__(item) 53 | elif isinstance(item, (tuple, list)): 54 | return TensorList([super(TensorList, self).__getitem__(i) for i in item]) 55 | else: 56 | return TensorList(super(TensorList, self).__getitem__(item)) 57 | 58 | def __add__(self, other): 59 | if TensorList._iterable(other): 60 | return TensorList([e1 + e2 for e1, e2 in zip(self, other)]) 61 | return TensorList([e + other for e in self]) 62 | 63 | def __radd__(self, other): 64 | if TensorList._iterable(other): 65 | return TensorList([e2 + e1 for e1, e2 in zip(self, other)]) 66 | return TensorList([other + e for e in self]) 67 | 68 | def __iadd__(self, other): 69 | if TensorList._iterable(other): 70 | for i, e2 in enumerate(other): 71 | self[i] += e2 72 | else: 73 | for i in range(len(self)): 74 | self[i] += other 75 | return self 76 | 77 | def __sub__(self, other): 78 | if TensorList._iterable(other): 79 | return TensorList([e1 - e2 for e1, e2 in zip(self, other)]) 80 | return TensorList([e - other for e in self]) 81 | 82 | def __rsub__(self, other): 83 | if TensorList._iterable(other): 84 | return TensorList([e2 - e1 for e1, e2 in zip(self, other)]) 85 | return TensorList([other - e for e in self]) 86 | 87 | def __isub__(self, other): 88 | if TensorList._iterable(other): 89 | for i, e2 in enumerate(other): 90 | self[i] -= e2 91 | else: 92 | for i in range(len(self)): 93 | self[i] -= other 94 | return self 95 | 96 | def __mul__(self, other): 97 | if TensorList._iterable(other): 98 | return TensorList([e1 * e2 for e1, e2 in zip(self, other)]) 99 | return TensorList([e * other for e in self]) 100 | 101 | def __rmul__(self, other): 102 | if TensorList._iterable(other): 103 | return TensorList([e2 * e1 for e1, e2 in zip(self, other)]) 104 | return TensorList([other * e for e in self]) 105 | 106 | def __imul__(self, other): 107 | if TensorList._iterable(other): 108 | for i, e2 in enumerate(other): 109 | self[i] *= e2 110 | else: 111 | for i in range(len(self)): 
112 | self[i] *= other 113 | return self 114 | 115 | def __truediv__(self, other): 116 | if TensorList._iterable(other): 117 | return TensorList([e1 / e2 for e1, e2 in zip(self, other)]) 118 | return TensorList([e / other for e in self]) 119 | 120 | def __rtruediv__(self, other): 121 | if TensorList._iterable(other): 122 | return TensorList([e2 / e1 for e1, e2 in zip(self, other)]) 123 | return TensorList([other / e for e in self]) 124 | 125 | def __itruediv__(self, other): 126 | if TensorList._iterable(other): 127 | for i, e2 in enumerate(other): 128 | self[i] /= e2 129 | else: 130 | for i in range(len(self)): 131 | self[i] /= other 132 | return self 133 | 134 | def __matmul__(self, other): 135 | if TensorList._iterable(other): 136 | return TensorList([e1 @ e2 for e1, e2 in zip(self, other)]) 137 | return TensorList([e @ other for e in self]) 138 | 139 | def __rmatmul__(self, other): 140 | if TensorList._iterable(other): 141 | return TensorList([e2 @ e1 for e1, e2 in zip(self, other)]) 142 | return TensorList([other @ e for e in self]) 143 | 144 | def __imatmul__(self, other): 145 | if TensorList._iterable(other): 146 | for i, e2 in enumerate(other): 147 | self[i] @= e2 148 | else: 149 | for i in range(len(self)): 150 | self[i] @= other 151 | return self 152 | 153 | def __mod__(self, other): 154 | if TensorList._iterable(other): 155 | return TensorList([e1 % e2 for e1, e2 in zip(self, other)]) 156 | return TensorList([e % other for e in self]) 157 | 158 | def __rmod__(self, other): 159 | if TensorList._iterable(other): 160 | return TensorList([e2 % e1 for e1, e2 in zip(self, other)]) 161 | return TensorList([other % e for e in self]) 162 | 163 | def __pos__(self): 164 | return TensorList([+e for e in self]) 165 | 166 | def __neg__(self): 167 | return TensorList([-e for e in self]) 168 | 169 | def __le__(self, other): 170 | if TensorList._iterable(other): 171 | return TensorList([e1 <= e2 for e1, e2 in zip(self, other)]) 172 | return TensorList([e <= other for e in self]) 173 | 174 | def __ge__(self, other): 175 | if TensorList._iterable(other): 176 | return TensorList([e1 >= e2 for e1, e2 in zip(self, other)]) 177 | return TensorList([e >= other for e in self]) 178 | 179 | def concat(self, other): 180 | return TensorList(super(TensorList, self).__add__(other)) 181 | 182 | def copy(self): 183 | return TensorList(super(TensorList, self).copy()) 184 | 185 | def unroll(self): 186 | if not any(isinstance(t, TensorList) for t in self): 187 | return self 188 | 189 | new_list = TensorList() 190 | for t in self: 191 | if isinstance(t, TensorList): 192 | new_list.extend(t.unroll()) 193 | else: 194 | new_list.append(t) 195 | return new_list 196 | 197 | def list(self): 198 | return list(self) 199 | 200 | def attribute(self, attr: str, *args): 201 | return TensorList([getattr(e, attr, *args) for e in self]) 202 | 203 | def apply(self, fn): 204 | return TensorList([fn(e) for e in self]) 205 | 206 | def __getattr__(self, name): 207 | if not hasattr(torch.Tensor, name): 208 | raise AttributeError('\'TensorList\' object has not attribute \'{}\''.format(name)) 209 | 210 | def apply_attr(*args, **kwargs): 211 | return TensorList([getattr(e, name)(*args, **kwargs) for e in self]) 212 | 213 | return apply_attr 214 | 215 | @staticmethod 216 | def _iterable(a): 217 | return isinstance(a, (TensorList, list)) 218 | 219 | 220 | def tensor_operation(op): 221 | def islist(a): 222 | return isinstance(a, TensorList) 223 | 224 | @functools.wraps(op) 225 | def oplist(*args, **kwargs): 226 | if len(args) == 0: 227 | 
raise ValueError('Must be at least one argument without keyword (i.e. operand).') 228 | 229 | if len(args) == 1: 230 | if islist(args[0]): 231 | return TensorList([op(a, **kwargs) for a in args[0]]) 232 | else: 233 | # Multiple operands, assume max two 234 | if islist(args[0]) and islist(args[1]): 235 | return TensorList([op(a, b, *args[2:], **kwargs) for a, b in zip(*args[:2])]) 236 | if islist(args[0]): 237 | return TensorList([op(a, *args[1:], **kwargs) for a in args[0]]) 238 | if islist(args[1]): 239 | return TensorList([op(args[0], b, *args[2:], **kwargs) for b in args[1]]) 240 | 241 | # None of the operands are lists 242 | return op(*args, **kwargs) 243 | 244 | return oplist 245 | -------------------------------------------------------------------------------- /model/ostrack/vit_ce.py: -------------------------------------------------------------------------------- 1 | import math 2 | import logging 3 | from functools import partial 4 | from collections import OrderedDict 5 | from copy import deepcopy 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from timm.models.layers import to_2tuple 12 | 13 | from .layers.patch_embed import PatchEmbed 14 | from .util import combine_tokens, recover_tokens 15 | from .vit import VisionTransformer 16 | from .layers.attn_blocks import CEBlock 17 | 18 | _logger = logging.getLogger(__name__) 19 | 20 | 21 | class VisionTransformerCE(VisionTransformer): 22 | """ Vision Transformer with candidate elimination (CE) module 23 | 24 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` 25 | - https://arxiv.org/abs/2010.11929 26 | 27 | Includes distillation token & head support for `DeiT: Data-efficient Image Transformers` 28 | - https://arxiv.org/abs/2012.12877 29 | """ 30 | 31 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 32 | num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False, 33 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None, 34 | act_layer=None, weight_init='', 35 | ce_loc=None, ce_keep_ratio=None): 36 | """ 37 | Args: 38 | img_size (int, tuple): input image size 39 | patch_size (int, tuple): patch size 40 | in_chans (int): number of input channels 41 | num_classes (int): number of classes for classification head 42 | embed_dim (int): embedding dimension 43 | depth (int): depth of transformer 44 | num_heads (int): number of attention heads 45 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 46 | qkv_bias (bool): enable bias for qkv if True 47 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 48 | distilled (bool): model includes a distillation token and head as in DeiT models 49 | drop_rate (float): dropout rate 50 | attn_drop_rate (float): attention dropout rate 51 | drop_path_rate (float): stochastic depth rate 52 | embed_layer (nn.Module): patch embedding layer 53 | norm_layer: (nn.Module): normalization layer 54 | weight_init: (str): weight init scheme 55 | """ 56 | # super().__init__() 57 | super().__init__() 58 | if isinstance(img_size, tuple): 59 | self.img_size = img_size 60 | else: 61 | self.img_size = to_2tuple(img_size) 62 | self.patch_size = patch_size 63 | self.in_chans = in_chans 64 | 65 | self.num_classes = num_classes 66 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 67 | self.num_tokens = 2 if 
distilled else 1 68 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 69 | act_layer = act_layer or nn.GELU 70 | 71 | self.patch_embed = embed_layer( 72 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 73 | num_patches = self.patch_embed.num_patches 74 | 75 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 76 | self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None 77 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) 78 | self.pos_drop = nn.Dropout(p=drop_rate) 79 | 80 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 81 | blocks = [] 82 | ce_index = 0 83 | self.ce_loc = ce_loc 84 | for i in range(depth): 85 | ce_keep_ratio_i = 1.0 86 | if ce_loc is not None and i in ce_loc: 87 | ce_keep_ratio_i = ce_keep_ratio[ce_index] 88 | ce_index += 1 89 | 90 | blocks.append( 91 | CEBlock( 92 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, 93 | attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer, 94 | keep_ratio_search=ce_keep_ratio_i) 95 | ) 96 | 97 | self.blocks = nn.Sequential(*blocks) 98 | self.norm = norm_layer(embed_dim) 99 | 100 | self.init_weights(weight_init) 101 | 102 | def forward_features(self, z, x, mask_z=None, mask_x=None, 103 | ce_template_mask=None, ce_keep_rate=None, 104 | return_last_attn=False 105 | ): 106 | B, H, W = x.shape[0], x.shape[2], x.shape[3] 107 | 108 | x = self.patch_embed(x) 109 | z = self.patch_embed(z) 110 | self.cat_mode = 'direct' 111 | 112 | # attention mask handling 113 | # B, H, W 114 | if mask_z is not None and mask_x is not None: 115 | mask_z = F.interpolate(mask_z[None].float(), scale_factor=1. / self.patch_size).to(torch.bool)[0] 116 | mask_z = mask_z.flatten(1).unsqueeze(-1) 117 | 118 | mask_x = F.interpolate(mask_x[None].float(), scale_factor=1. 
/ self.patch_size).to(torch.bool)[0] 119 | mask_x = mask_x.flatten(1).unsqueeze(-1) 120 | 121 | mask_x = combine_tokens(mask_z, mask_x, mode=self.cat_mode) 122 | mask_x = mask_x.squeeze(-1) 123 | 124 | if self.add_cls_token: 125 | cls_tokens = self.cls_token.expand(B, -1, -1) 126 | cls_tokens = cls_tokens + self.cls_pos_embed 127 | 128 | z += self.pos_embed_z 129 | x += self.pos_embed_x 130 | 131 | if self.add_sep_seg: 132 | x += self.search_segment_pos_embed 133 | z += self.template_segment_pos_embed 134 | 135 | x = combine_tokens(z, x, mode=self.cat_mode) 136 | if self.add_cls_token: 137 | x = torch.cat([cls_tokens, x], dim=1) 138 | 139 | x = self.pos_drop(x) 140 | 141 | lens_z = self.pos_embed_z.shape[1] 142 | lens_x = self.pos_embed_x.shape[1] 143 | 144 | global_index_t = torch.linspace(0, lens_z - 1, lens_z).to(x.device) 145 | global_index_t = global_index_t.repeat(B, 1) 146 | 147 | global_index_s = torch.linspace(0, lens_x - 1, lens_x).to(x.device) 148 | global_index_s = global_index_s.repeat(B, 1) 149 | removed_indexes_s = [] 150 | for i, blk in enumerate(self.blocks): 151 | x, global_index_t, global_index_s, removed_index_s, attn = \ 152 | blk(x, global_index_t, global_index_s, mask_x, ce_template_mask, ce_keep_rate) 153 | 154 | if self.ce_loc is not None and i in self.ce_loc: 155 | removed_indexes_s.append(removed_index_s) 156 | 157 | x = self.norm(x) 158 | lens_x_new = global_index_s.shape[1] 159 | lens_z_new = global_index_t.shape[1] 160 | 161 | z = x[:, :lens_z_new] 162 | x = x[:, lens_z_new:] 163 | 164 | if removed_indexes_s and removed_indexes_s[0] is not None: 165 | removed_indexes_cat = torch.cat(removed_indexes_s, dim=1) 166 | 167 | pruned_lens_x = lens_x - lens_x_new 168 | pad_x = torch.zeros([B, pruned_lens_x, x.shape[2]], device=x.device) 169 | x = torch.cat([x, pad_x], dim=1) 170 | index_all = torch.cat([global_index_s, removed_indexes_cat], dim=1) 171 | # recover original token order 172 | C = x.shape[-1] 173 | # x = x.gather(1, index_all.unsqueeze(-1).expand(B, -1, C).argsort(1)) 174 | x = torch.zeros_like(x).scatter_(dim=1, index=index_all.unsqueeze(-1).expand(B, -1, C).to(torch.int64), src=x) 175 | 176 | x = recover_tokens(x, lens_z_new, lens_x, mode=self.cat_mode) 177 | 178 | # re-concatenate with the template, which may be further used by other modules 179 | x = torch.cat([z, x], dim=1) 180 | 181 | aux_dict = { 182 | "attn": attn, 183 | "removed_indexes_s": removed_indexes_s, # used for visualization 184 | } 185 | 186 | return x, aux_dict 187 | 188 | def forward(self, z, x, ce_template_mask=None, ce_keep_rate=None, 189 | tnc_keep_rate=None, 190 | return_last_attn=False): 191 | 192 | x, aux_dict = self.forward_features(z, x, ce_template_mask=ce_template_mask, ce_keep_rate=ce_keep_rate,) 193 | 194 | return x, aux_dict 195 | 196 | 197 | def _create_vision_transformer(pretrained=False, **kwargs): 198 | model = VisionTransformerCE(**kwargs) 199 | 200 | if pretrained: 201 | if 'npz' in pretrained: 202 | model.load_pretrained(pretrained, prefix='') 203 | else: 204 | checkpoint = torch.load(pretrained, map_location="cpu") 205 | missing_keys, unexpected_keys = model.load_state_dict(checkpoint["model"], strict=False) 206 | print('Load pretrained model from: ' + pretrained) 207 | 208 | return model 209 | 210 | 211 | def vit_base_patch16_224_ce(pretrained=False, **kwargs): 212 | """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 
213 | """ 214 | model_kwargs = dict( 215 | patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 216 | model = _create_vision_transformer(pretrained=pretrained, **model_kwargs) 217 | return model 218 | 219 | 220 | def vit_large_patch16_224_ce(pretrained=False, **kwargs): 221 | """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). 222 | """ 223 | model_kwargs = dict( 224 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) 225 | model = _create_vision_transformer(pretrained=pretrained, **model_kwargs) 226 | return model 227 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
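To make the backbone factories in `vit_ce.py` above easier to try out, here is a minimal, hypothetical construction sketch. It is not the project's official entry point: the import path is inferred from the repository layout, the `ce_loc` / `ce_keep_ratio` values are illustrative rather than taken from a shipped config, and running `forward(z, x)` additionally requires the template/search positional embeddings that are set up outside this file (see `base_backbone.py`), so only construction and inspection are shown.

```python
import torch

# Hypothetical construction sketch for the CE backbone factory in model/ostrack/vit_ce.py.
# The import path below assumes the repository root is on PYTHONPATH; adjust it to your layout.
from model.ostrack.vit_ce import vit_base_patch16_224_ce

# Illustrative CE settings (not from a shipped config): run candidate elimination after
# blocks 3, 6 and 9, keeping 70% of the search-region tokens at each of those stages.
backbone = vit_base_patch16_224_ce(
    pretrained=False,
    ce_loc=[3, 6, 9],
    ce_keep_ratio=[0.7, 0.7, 0.7],
)

n_params = sum(p.numel() for p in backbone.parameters())
print(f"ViT-B/16 CE backbone: {len(backbone.blocks)} blocks, {n_params / 1e6:.1f}M parameters")
```

Passing a checkpoint path as `pretrained` would instead go through `_create_vision_transformer`, which loads `.npz` weights via `load_pretrained` and any other checkpoint via `torch.load` with `strict=False`.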
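The least obvious step in `forward_features` above is the token-order recovery: after candidate elimination, the surviving search tokens are padded with zeros and scattered back to their original positions using the concatenated kept/removed index lists. The toy example below (not project code; all tensors are made up) reproduces just that scatter step so the bookkeeping is easier to follow.

```python
import torch

# Toy re-enactment of the token-order recovery in VisionTransformerCE.forward_features.
# Six original search tokens; candidate elimination kept three of them.
B, N, C = 1, 6, 4                          # batch, original token count, channels
kept_idx = torch.tensor([[0, 2, 5]])       # plays the role of global_index_s (made-up values)
removed_idx = torch.tensor([[1, 3, 4]])    # plays the role of the concatenated removed_indexes_s

# Dummy features for the kept tokens (value i+1 on every channel) and zero padding for pruned ones.
kept_tokens = torch.arange(1, 4, dtype=torch.float32).view(1, 3, 1).expand(B, 3, C)
pad = torch.zeros(B, removed_idx.shape[1], C)

tokens = torch.cat([kept_tokens, pad], dim=1)        # [B, N, C]: kept tokens first, zeros last
index_all = torch.cat([kept_idx, removed_idx], dim=1)

# Scatter each token back to its original position, mirroring the scatter_ call in the model code.
recovered = torch.zeros_like(tokens).scatter_(
    dim=1, index=index_all.unsqueeze(-1).expand(B, N, C).to(torch.int64), src=tokens)

print(recovered[0, :, 0])   # tensor([1., 0., 2., 0., 0., 3.]) -> kept tokens restored, zeros where pruned
```

In the model itself this happens right before `recover_tokens` and the final re-concatenation with the template tokens, which is why downstream modules receive the search region in its original spatial order.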
2 |
3 |