├── LICENSE.txt ├── NAT.py ├── README.md ├── assets ├── ai_dist.png ├── confidence_map.png ├── framework.png ├── one-to-many_generation.png ├── suc_data.png └── visualization.png ├── install.sh ├── label.py ├── list.py ├── par_crop.py ├── requirement.txt ├── sam ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── assets │ ├── masks1.png │ ├── masks2.jpg │ ├── minidemo.gif │ ├── model_diagram.png │ ├── notebook1.png │ └── notebook2.png ├── demo │ ├── README.md │ ├── configs │ │ └── webpack │ │ │ ├── common.js │ │ │ ├── dev.js │ │ │ └── prod.js │ ├── package.json │ ├── postcss.config.js │ ├── src │ │ ├── App.tsx │ │ ├── assets │ │ │ ├── data │ │ │ │ └── dogs.jpg │ │ │ ├── index.html │ │ │ └── scss │ │ │ │ └── App.scss │ │ ├── components │ │ │ ├── Stage.tsx │ │ │ ├── Tool.tsx │ │ │ ├── helpers │ │ │ │ ├── Interfaces.tsx │ │ │ │ ├── maskUtils.tsx │ │ │ │ ├── onnxModelAPI.tsx │ │ │ │ └── scaleHelper.tsx │ │ │ └── hooks │ │ │ │ ├── context.tsx │ │ │ │ └── createContext.tsx │ │ └── index.tsx │ ├── tailwind.config.js │ └── tsconfig.json ├── linter.sh ├── notebooks │ ├── automatic_mask_generator_example.ipynb │ ├── images │ │ ├── dog.jpg │ │ ├── groceries.jpg │ │ └── truck.jpg │ ├── onnx_model_example.ipynb │ └── predictor_example.ipynb ├── scripts │ ├── amg.py │ └── export_onnx_model.py ├── segment_anything │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── automatic_mask_generator.cpython-38.pyc │ │ ├── build_sam.cpython-38.pyc │ │ └── predictor.cpython-38.pyc │ ├── automatic_mask_generator.py │ ├── build_sam.py │ ├── modeling │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── common.cpython-38.pyc │ │ │ ├── image_encoder.cpython-38.pyc │ │ │ ├── mask_decoder.cpython-38.pyc │ │ │ ├── prompt_encoder.cpython-38.pyc │ │ │ ├── sam.cpython-38.pyc │ │ │ └── transformer.cpython-38.pyc │ │ ├── common.py │ │ ├── image_encoder.py │ │ ├── mask_decoder.py │ │ ├── prompt_encoder.py │ │ ├── sam.py │ │ └── transformer.py │ ├── predictor.py │ └── utils │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── amg.cpython-38.pyc │ │ └── transforms.cpython-38.pyc │ │ ├── amg.py │ │ ├── onnx.py │ │ └── transforms.py ├── setup.cfg └── setup.py ├── swell.sh └── tracker └── BAN ├── experiments └── udatban_r50_l234 │ └── config.yaml ├── siamban ├── core │ ├── __init__.py │ ├── config.py │ └── xcorr.py ├── datasets │ ├── __init__.py │ ├── augmentation.py │ ├── dataset.py │ └── point_target.py ├── models │ ├── GRL.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── GRL.cpython-38.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── iou_loss.cpython-38.pyc │ │ ├── loss.cpython-38.pyc │ │ ├── model_builder.cpython-38.pyc │ │ ├── model_builder_tsne.cpython-38.pyc │ │ ├── model_builder_v.cpython-38.pyc │ │ └── trans_discriminator.cpython-38.pyc │ ├── backbone │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── alexnet.cpython-36.pyc │ │ │ ├── alexnet.cpython-38.pyc │ │ │ ├── mobile_v2.cpython-36.pyc │ │ │ ├── mobile_v2.cpython-38.pyc │ │ │ ├── resnet_atrous.cpython-36.pyc │ │ │ └── resnet_atrous.cpython-38.pyc │ │ ├── alexnet.py │ │ ├── mobile_v2.py │ │ └── resnet_atrous.py │ ├── head │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ └── ban.cpython-38.pyc │ │ └── ban.py │ ├── init_weight.py │ ├── iou_loss.py │ ├── loss.py │ ├── model_builder.py │ ├── model_builder_tsne.py │ ├── model_builder_v.py │ ├── neck │ │ ├── __init__.py │ │ ├── 
__pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── neck.cpython-38.pyc │ │ │ └── trans.cpython-38.pyc │ │ ├── neck.py │ │ └── trans.py │ └── trans_discriminator.py ├── tracker │ ├── __init__.py │ ├── base_tracker.py │ ├── siamban_tracker.py │ └── tracker_builder.py └── utils │ ├── __init__.py │ ├── average_meter.py │ ├── bbox.py │ ├── distributed.py │ ├── log_helper.py │ ├── lr_scheduler.py │ ├── misc.py │ ├── model_load.py │ └── point.py ├── snapshot └── README.md ├── toolkit ├── __init__.py ├── datasets │ ├── DarkTrack2021.py │ ├── UAVDark135.py │ ├── UAVDark70.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── DarkTrack2021.cpython-38.pyc │ │ ├── UAVDark135.cpython-38.pyc │ │ ├── UAVDark70.cpython-38.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── dataset.cpython-38.pyc │ │ ├── nat.cpython-38.pyc │ │ ├── nat_l.cpython-38.pyc │ │ ├── nut.cpython-38.pyc │ │ ├── nut_l.cpython-38.pyc │ │ ├── nut_l_s.cpython-38.pyc │ │ ├── nut_l_t.cpython-38.pyc │ │ ├── uav.cpython-38.pyc │ │ └── video.cpython-38.pyc │ ├── dataset.py │ ├── nat.py │ ├── nat_l.py │ ├── nut.py │ ├── nut_l.py │ ├── nut_l_s.py │ ├── nut_l_t.py │ ├── uav.py │ └── video.py ├── evaluation │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── ar_benchmark.cpython-38.pyc │ │ ├── eao_benchmark.cpython-38.pyc │ │ ├── f1_benchmark.cpython-38.pyc │ │ └── ope_benchmark.cpython-38.pyc │ ├── ar_benchmark.py │ ├── eao_benchmark.py │ ├── f1_benchmark.py │ └── ope_benchmark.py ├── utils │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ └── statistics.cpython-38.pyc │ ├── c_region.pxd │ ├── misc.py │ ├── region.c │ ├── region.cpython-36m-x86_64-linux-gnu.so │ ├── region.cpython-38-x86_64-linux-gnu.so │ ├── region.pyx │ ├── src │ │ ├── buffer.h │ │ ├── region.c │ │ └── region.h │ └── statistics.py └── visualization │ ├── __init__.py │ ├── draw_eao.py │ ├── draw_f1.py │ ├── draw_success_precision.py │ └── draw_utils.py ├── tools ├── demo.py ├── eval.py ├── test.py └── train.py └── train_dataset ├── got10k ├── gen_json.py ├── par_crop.py └── readme.md └── vid ├── gen_json.py ├── par_crop.py ├── parse_vid.py ├── readme.md └── visual.py /NAT.py: -------------------------------------------------------------------------------- 1 | import json 2 | import time 3 | from collections import defaultdict 4 | import itertools 5 | def _isArrayLike(obj): 6 | return hasattr(obj, '__iter__') and hasattr(obj, '__len__') 7 | 8 | class NAT2021: 9 | def __init__(self, annotation_file=None): 10 | """ 11 | Thanks coco 12 | :param annotation_file (str): location of annotation file 13 | :param image_folder (str): location to the folder that hosts images. 
14 | :return: 15 | """ 16 | # load dataset 17 | self.dataset,self.anns,self.imgs = dict(),dict(),dict() 18 | # self.imgToAnns = defaultdict(list) 19 | if not annotation_file == None: 20 | print('loading annotations into memory...') 21 | tic = time.time() 22 | dataset = json.load(open(annotation_file, 'r')) 23 | assert type(dataset)==dict, 'annotation file format {} not supported'.format(type(dataset)) 24 | print('Done (t={:0.2f}s)'.format(time.time()- tic)) 25 | self.dataset = dataset 26 | self.createIndex() 27 | 28 | def createIndex(self): 29 | # create index 30 | print('creating index...') 31 | anns, imgs = {}, {} 32 | imgToAnns = defaultdict(list) 33 | 34 | 35 | dataType = list(self.dataset.keys())[0] 36 | json_dict = self.dataset[dataType] 37 | for ann in json_dict: 38 | id = list(ann.keys())[0] 39 | masks = ann[id] 40 | imgToAnns[id] = masks 41 | 42 | print('index created!') 43 | self.imgToAnns = imgToAnns 44 | 45 | 46 | if __name__ == '__main__': 47 | jsonFile = './seg_result/annotations/0421truck2_10.json' 48 | coco = NAT2021(jsonFile) 49 | print(len(list(coco.imgToAnns.keys()))) 50 | for id in list(coco.imgToAnns.keys()): 51 | print(coco.loadImgs(id)[0]) -------------------------------------------------------------------------------- /assets/ai_dist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/assets/ai_dist.png -------------------------------------------------------------------------------- /assets/confidence_map.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/assets/confidence_map.png -------------------------------------------------------------------------------- /assets/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/assets/framework.png -------------------------------------------------------------------------------- /assets/one-to-many_generation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/assets/one-to-many_generation.png -------------------------------------------------------------------------------- /assets/suc_data.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/assets/suc_data.png -------------------------------------------------------------------------------- /assets/visualization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/assets/visualization.png -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | # Install SAM 2 | cd sam; pip install -e . 
3 | cd - 4 | # Install BAN 5 | -------------------------------------------------------------------------------- /list.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | 4 | char_at_position = [] 5 | for i in range(4): 6 | with open('./result/'+str(i)+'.json') as file: 7 | # read the json file 8 | content = file.read() 9 | pattern = r'\}(\r?\n)\{' 10 | content = re.sub(pattern,'};\n{',content) 11 | dictionaries = content.split(';') 12 | for dictionary in dictionaries: 13 | data = json.loads(dictionary) 14 | with open('./result/list.json', 'a') as file: 15 | json.dump(data, file, indent=4, sort_keys=True) 16 | file.write('\n') 17 | -------------------------------------------------------------------------------- /requirement.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | opencv-python 3 | pyyaml 4 | yacs 5 | tqdm 6 | colorama 7 | matplotlib 8 | cython 9 | tensorboardX 10 | -------------------------------------------------------------------------------- /sam/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 
50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /sam/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to segment-anything 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints, using the `linter.sh` script in the project's root directory. Linting requires `black==23.*`, `isort==5.12.0`, `flake8`, and `mypy`. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to segment-anything, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. 
32 | -------------------------------------------------------------------------------- /sam/assets/masks1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/sam/assets/masks1.png -------------------------------------------------------------------------------- /sam/assets/masks2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/sam/assets/masks2.jpg -------------------------------------------------------------------------------- /sam/assets/minidemo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/sam/assets/minidemo.gif -------------------------------------------------------------------------------- /sam/assets/model_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/sam/assets/model_diagram.png -------------------------------------------------------------------------------- /sam/assets/notebook1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/sam/assets/notebook1.png -------------------------------------------------------------------------------- /sam/assets/notebook2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/sam/assets/notebook2.png -------------------------------------------------------------------------------- /sam/demo/README.md: -------------------------------------------------------------------------------- 1 | ## Segment Anything Simple Web demo 2 | 3 | This **front-end only** React based web demo shows how to load a fixed image and corresponding `.npy` file of the SAM image embedding, and run the SAM ONNX model in the browser using Web Assembly with mulithreading enabled by `SharedArrayBuffer`, Web Worker, and SIMD128. 4 | 5 | 6 | 7 | ## Run the app 8 | 9 | Install Yarn 10 | 11 | ``` 12 | npm install --g yarn 13 | ``` 14 | 15 | Build and run: 16 | 17 | ``` 18 | yarn && yarn start 19 | ``` 20 | 21 | Navigate to [`http://localhost:8081/`](http://localhost:8081/) 22 | 23 | Move your cursor around to see the mask prediction update in real time. 24 | 25 | ## Export the image embedding 26 | 27 | In the [ONNX Model Example notebook](https://github.com/facebookresearch/segment-anything/blob/main/notebooks/onnx_model_example.ipynb) upload the image of your choice and generate and save corresponding embedding. 
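Both code snippets in this section assume the standard imports are already in place; a minimal sketch:

```python
import cv2
import numpy as np
from segment_anything import sam_model_registry, SamPredictor
```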
28 | 29 | Initialize the predictor: 30 | 31 | ```python 32 | checkpoint = "sam_vit_h_4b8939.pth" 33 | model_type = "vit_h" 34 | sam = sam_model_registry[model_type](checkpoint=checkpoint) 35 | sam.to(device='cuda') 36 | predictor = SamPredictor(sam) 37 | ``` 38 | 39 | Set the new image and export the embedding: 40 | 41 | ``` 42 | image = cv2.imread('src/assets/dogs.jpg') 43 | predictor.set_image(image) 44 | image_embedding = predictor.get_image_embedding().cpu().numpy() 45 | np.save("dogs_embedding.npy", image_embedding) 46 | ``` 47 | 48 | Save the new image and embedding in `src/assets/data`. 49 | 50 | ## Export the ONNX model 51 | 52 | You also need to export the quantized ONNX model from the [ONNX Model Example notebook](https://github.com/facebookresearch/segment-anything/blob/main/notebooks/onnx_model_example.ipynb). 53 | 54 | Run the cell in the notebook which saves the `sam_onnx_quantized_example.onnx` file, download it and copy it to the path `/model/sam_onnx_quantized_example.onnx`. 55 | 56 | Here is a snippet of the export/quantization code: 57 | 58 | ``` 59 | onnx_model_path = "sam_onnx_example.onnx" 60 | onnx_model_quantized_path = "sam_onnx_quantized_example.onnx" 61 | quantize_dynamic( 62 | model_input=onnx_model_path, 63 | model_output=onnx_model_quantized_path, 64 | optimize_model=True, 65 | per_channel=False, 66 | reduce_range=False, 67 | weight_type=QuantType.QUInt8, 68 | ) 69 | ``` 70 | 71 | **NOTE: if you change the ONNX model by using a new checkpoint you need to also re-export the embedding.** 72 | 73 | ## Update the image, embedding, model in the app 74 | 75 | Update the following file paths at the top of`App.tsx`: 76 | 77 | ```py 78 | const IMAGE_PATH = "/assets/data/dogs.jpg"; 79 | const IMAGE_EMBEDDING = "/assets/data/dogs_embedding.npy"; 80 | const MODEL_DIR = "/model/sam_onnx_quantized_example.onnx"; 81 | ``` 82 | 83 | ## ONNX multithreading with SharedArrayBuffer 84 | 85 | To use multithreading, the appropriate headers need to be set to create a cross origin isolation state which will enable use of `SharedArrayBuffer` (see this [blog post](https://cloudblogs.microsoft.com/opensource/2021/09/02/onnx-runtime-web-running-your-machine-learning-model-in-browser/) for more details) 86 | 87 | The headers below are set in `configs/webpack/dev.js`: 88 | 89 | ```js 90 | headers: { 91 | "Cross-Origin-Opener-Policy": "same-origin", 92 | "Cross-Origin-Embedder-Policy": "credentialless", 93 | } 94 | ``` 95 | 96 | ## Structure of the app 97 | 98 | **`App.tsx`** 99 | 100 | - Initializes ONNX model 101 | - Loads image embedding and image 102 | - Runs the ONNX model based on input prompts 103 | 104 | **`Stage.tsx`** 105 | 106 | - Handles mouse move interaction to update the ONNX model prompt 107 | 108 | **`Tool.tsx`** 109 | 110 | - Renders the image and the mask prediction 111 | 112 | **`helpers/maskUtils.tsx`** 113 | 114 | - Conversion of ONNX model output from array to an HTMLImageElement 115 | 116 | **`helpers/onnxModelAPI.tsx`** 117 | 118 | - Formats the inputs for the ONNX model 119 | 120 | **`helpers/scaleHelper.tsx`** 121 | 122 | - Handles image scaling logic for SAM (longest size 1024) 123 | 124 | **`hooks/`** 125 | 126 | - Handle shared state for the app 127 | -------------------------------------------------------------------------------- /sam/demo/configs/webpack/common.js: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 
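// Overview of this shared webpack config (see the module body below): ts-loader
// compiles the TypeScript/JSX sources, the style/css/postcss loaders handle CSS and
// SCSS, and CopyPlugin copies the onnxruntime-web .wasm binaries, the model/
// directory, and src/assets into the build output. HtmlWebpackPlugin uses
// src/assets/index.html as the page template.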
3 | 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | const { resolve } = require("path"); 8 | const HtmlWebpackPlugin = require("html-webpack-plugin"); 9 | const FriendlyErrorsWebpackPlugin = require("friendly-errors-webpack-plugin"); 10 | const CopyPlugin = require("copy-webpack-plugin"); 11 | const webpack = require("webpack"); 12 | 13 | module.exports = { 14 | entry: "./src/index.tsx", 15 | resolve: { 16 | extensions: [".js", ".jsx", ".ts", ".tsx"], 17 | }, 18 | output: { 19 | path: resolve(__dirname, "dist"), 20 | }, 21 | module: { 22 | rules: [ 23 | { 24 | test: /\.mjs$/, 25 | include: /node_modules/, 26 | type: "javascript/auto", 27 | resolve: { 28 | fullySpecified: false, 29 | }, 30 | }, 31 | { 32 | test: [/\.jsx?$/, /\.tsx?$/], 33 | use: ["ts-loader"], 34 | exclude: /node_modules/, 35 | }, 36 | { 37 | test: /\.css$/, 38 | use: ["style-loader", "css-loader"], 39 | }, 40 | { 41 | test: /\.(scss|sass)$/, 42 | use: ["style-loader", "css-loader", "postcss-loader"], 43 | }, 44 | { 45 | test: /\.(jpe?g|png|gif|svg)$/i, 46 | use: [ 47 | "file-loader?hash=sha512&digest=hex&name=img/[contenthash].[ext]", 48 | "image-webpack-loader?bypassOnDebug&optipng.optimizationLevel=7&gifsicle.interlaced=false", 49 | ], 50 | }, 51 | { 52 | test: /\.(woff|woff2|ttf)$/, 53 | use: { 54 | loader: "url-loader", 55 | }, 56 | }, 57 | ], 58 | }, 59 | plugins: [ 60 | new CopyPlugin({ 61 | patterns: [ 62 | { 63 | from: "node_modules/onnxruntime-web/dist/*.wasm", 64 | to: "[name][ext]", 65 | }, 66 | { 67 | from: "model", 68 | to: "model", 69 | }, 70 | { 71 | from: "src/assets", 72 | to: "assets", 73 | }, 74 | ], 75 | }), 76 | new HtmlWebpackPlugin({ 77 | template: "./src/assets/index.html", 78 | }), 79 | new FriendlyErrorsWebpackPlugin(), 80 | new webpack.ProvidePlugin({ 81 | process: "process/browser", 82 | }), 83 | ], 84 | }; 85 | -------------------------------------------------------------------------------- /sam/demo/configs/webpack/dev.js: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | // development config 8 | const { merge } = require("webpack-merge"); 9 | const commonConfig = require("./common"); 10 | 11 | module.exports = merge(commonConfig, { 12 | mode: "development", 13 | devServer: { 14 | hot: true, // enable HMR on the server 15 | open: true, 16 | // These headers enable the cross origin isolation state 17 | // needed to enable use of SharedArrayBuffer for ONNX 18 | // multithreading. 19 | headers: { 20 | "Cross-Origin-Opener-Policy": "same-origin", 21 | "Cross-Origin-Embedder-Policy": "credentialless", 22 | }, 23 | }, 24 | devtool: "cheap-module-source-map", 25 | }); 26 | -------------------------------------------------------------------------------- /sam/demo/configs/webpack/prod.js: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | // production config 8 | const { merge } = require("webpack-merge"); 9 | const { resolve } = require("path"); 10 | const Dotenv = require("dotenv-webpack"); 11 | const commonConfig = require("./common"); 12 | 13 | module.exports = merge(commonConfig, { 14 | mode: "production", 15 | output: { 16 | filename: "js/bundle.[contenthash].min.js", 17 | path: resolve(__dirname, "../../dist"), 18 | publicPath: "/", 19 | }, 20 | devtool: "source-map", 21 | plugins: [new Dotenv()], 22 | }); 23 | -------------------------------------------------------------------------------- /sam/demo/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "segment-anything-mini-demo", 3 | "version": "0.1.0", 4 | "license": "MIT", 5 | "scripts": { 6 | "build": "yarn run clean-dist && webpack --config=configs/webpack/prod.js && mv dist/*.wasm dist/js", 7 | "clean-dist": "rimraf dist/*", 8 | "lint": "eslint './src/**/*.{js,ts,tsx}' --quiet", 9 | "start": "yarn run start-dev", 10 | "test": "yarn run start-model-test", 11 | "start-dev": "webpack serve --config=configs/webpack/dev.js" 12 | }, 13 | "devDependencies": { 14 | "@babel/core": "^7.18.13", 15 | "@babel/preset-env": "^7.18.10", 16 | "@babel/preset-react": "^7.18.6", 17 | "@babel/preset-typescript": "^7.18.6", 18 | "@pmmmwh/react-refresh-webpack-plugin": "^0.5.7", 19 | "@testing-library/react": "^13.3.0", 20 | "@types/node": "^18.7.13", 21 | "@types/react": "^18.0.17", 22 | "@types/react-dom": "^18.0.6", 23 | "@types/underscore": "^1.11.4", 24 | "@typescript-eslint/eslint-plugin": "^5.35.1", 25 | "@typescript-eslint/parser": "^5.35.1", 26 | "babel-loader": "^8.2.5", 27 | "copy-webpack-plugin": "^11.0.0", 28 | "css-loader": "^6.7.1", 29 | "dotenv": "^16.0.2", 30 | "dotenv-webpack": "^8.0.1", 31 | "eslint": "^8.22.0", 32 | "eslint-plugin-react": "^7.31.0", 33 | "file-loader": "^6.2.0", 34 | "fork-ts-checker-webpack-plugin": "^7.2.13", 35 | "friendly-errors-webpack-plugin": "^1.7.0", 36 | "html-webpack-plugin": "^5.5.0", 37 | "image-webpack-loader": "^8.1.0", 38 | "postcss-loader": "^7.0.1", 39 | "postcss-preset-env": "^7.8.0", 40 | "process": "^0.11.10", 41 | "rimraf": "^3.0.2", 42 | "sass": "^1.54.5", 43 | "sass-loader": "^13.0.2", 44 | "style-loader": "^3.3.1", 45 | "tailwindcss": "^3.1.8", 46 | "ts-loader": "^9.3.1", 47 | "typescript": "^4.8.2", 48 | "webpack": "^5.74.0", 49 | "webpack-cli": "^4.10.0", 50 | "webpack-dev-server": "^4.10.0", 51 | "webpack-dotenv-plugin": "^2.1.0", 52 | "webpack-merge": "^5.8.0" 53 | }, 54 | "dependencies": { 55 | "npyjs": "^0.4.0", 56 | "onnxruntime-web": "^1.14.0", 57 | "react": "^18.2.0", 58 | "react-dom": "^18.2.0", 59 | "underscore": "^1.13.6", 60 | "react-refresh": "^0.14.0" 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /sam/demo/postcss.config.js: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | const tailwindcss = require("tailwindcss"); 8 | module.exports = { 9 | plugins: ["postcss-preset-env", 'tailwindcss/nesting', tailwindcss], 10 | }; 11 | -------------------------------------------------------------------------------- /sam/demo/src/App.tsx: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | import { InferenceSession, Tensor } from "onnxruntime-web"; 8 | import React, { useContext, useEffect, useState } from "react"; 9 | import "./assets/scss/App.scss"; 10 | import { handleImageScale } from "./components/helpers/scaleHelper"; 11 | import { modelScaleProps } from "./components/helpers/Interfaces"; 12 | import { onnxMaskToImage } from "./components/helpers/maskUtils"; 13 | import { modelData } from "./components/helpers/onnxModelAPI"; 14 | import Stage from "./components/Stage"; 15 | import AppContext from "./components/hooks/createContext"; 16 | const ort = require("onnxruntime-web"); 17 | /* @ts-ignore */ 18 | import npyjs from "npyjs"; 19 | 20 | // Define image, embedding and model paths 21 | const IMAGE_PATH = "/assets/data/dogs.jpg"; 22 | const IMAGE_EMBEDDING = "/assets/data/dogs_embedding.npy"; 23 | const MODEL_DIR = "/model/sam_onnx_quantized_example.onnx"; 24 | 25 | const App = () => { 26 | const { 27 | clicks: [clicks], 28 | image: [, setImage], 29 | maskImg: [, setMaskImg], 30 | } = useContext(AppContext)!; 31 | const [model, setModel] = useState(null); // ONNX model 32 | const [tensor, setTensor] = useState(null); // Image embedding tensor 33 | 34 | // The ONNX model expects the input to be rescaled to 1024. 35 | // The modelScale state variable keeps track of the scale values. 36 | const [modelScale, setModelScale] = useState(null); 37 | 38 | // Initialize the ONNX model. load the image, and load the SAM 39 | // pre-computed image embedding 40 | useEffect(() => { 41 | // Initialize the ONNX model 42 | const initModel = async () => { 43 | try { 44 | if (MODEL_DIR === undefined) return; 45 | const URL: string = MODEL_DIR; 46 | const model = await InferenceSession.create(URL); 47 | setModel(model); 48 | } catch (e) { 49 | console.log(e); 50 | } 51 | }; 52 | initModel(); 53 | 54 | // Load the image 55 | const url = new URL(IMAGE_PATH, location.origin); 56 | loadImage(url); 57 | 58 | // Load the Segment Anything pre-computed embedding 59 | Promise.resolve(loadNpyTensor(IMAGE_EMBEDDING, "float32")).then( 60 | (embedding) => setTensor(embedding) 61 | ); 62 | }, []); 63 | 64 | const loadImage = async (url: URL) => { 65 | try { 66 | const img = new Image(); 67 | img.src = url.href; 68 | img.onload = () => { 69 | const { height, width, samScale } = handleImageScale(img); 70 | setModelScale({ 71 | height: height, // original image height 72 | width: width, // original image width 73 | samScale: samScale, // scaling factor for image which has been resized to longest side 1024 74 | }); 75 | img.width = width; 76 | img.height = height; 77 | setImage(img); 78 | }; 79 | } catch (error) { 80 | console.log(error); 81 | } 82 | }; 83 | 84 | // Decode a Numpy file into a tensor. 
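// npyjs parses the .npy file into a typed array plus its shape, and the result is
// wrapped in an onnxruntime-web Tensor so the precomputed SAM image embedding can
// be passed straight to the ONNX session in runONNX().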
85 | const loadNpyTensor = async (tensorFile: string, dType: string) => { 86 | let npLoader = new npyjs(); 87 | const npArray = await npLoader.load(tensorFile); 88 | const tensor = new ort.Tensor(dType, npArray.data, npArray.shape); 89 | return tensor; 90 | }; 91 | 92 | // Run the ONNX model every time clicks has changed 93 | useEffect(() => { 94 | runONNX(); 95 | }, [clicks]); 96 | 97 | const runONNX = async () => { 98 | try { 99 | if ( 100 | model === null || 101 | clicks === null || 102 | tensor === null || 103 | modelScale === null 104 | ) 105 | return; 106 | else { 107 | // Preapre the model input in the correct format for SAM. 108 | // The modelData function is from onnxModelAPI.tsx. 109 | const feeds = modelData({ 110 | clicks, 111 | tensor, 112 | modelScale, 113 | }); 114 | if (feeds === undefined) return; 115 | // Run the SAM ONNX model with the feeds returned from modelData() 116 | const results = await model.run(feeds); 117 | const output = results[model.outputNames[0]]; 118 | // The predicted mask returned from the ONNX model is an array which is 119 | // rendered as an HTML image using onnxMaskToImage() from maskUtils.tsx. 120 | setMaskImg(onnxMaskToImage(output.data, output.dims[2], output.dims[3])); 121 | } 122 | } catch (e) { 123 | console.log(e); 124 | } 125 | }; 126 | 127 | return ; 128 | }; 129 | 130 | export default App; 131 | -------------------------------------------------------------------------------- /sam/demo/src/assets/data/dogs.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/sam/demo/src/assets/data/dogs.jpg -------------------------------------------------------------------------------- /sam/demo/src/assets/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 9 | Segment Anything Demo 10 | 11 | 12 | 13 | 14 | 15 | 16 |
17 | 18 | 19 | -------------------------------------------------------------------------------- /sam/demo/src/assets/scss/App.scss: -------------------------------------------------------------------------------- 1 | @tailwind base; 2 | @tailwind components; 3 | @tailwind utilities; 4 | -------------------------------------------------------------------------------- /sam/demo/src/components/Stage.tsx: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | import React, { useContext } from "react"; 8 | import * as _ from "underscore"; 9 | import Tool from "./Tool"; 10 | import { modelInputProps } from "./helpers/Interfaces"; 11 | import AppContext from "./hooks/createContext"; 12 | 13 | const Stage = () => { 14 | const { 15 | clicks: [, setClicks], 16 | image: [image], 17 | } = useContext(AppContext)!; 18 | 19 | const getClick = (x: number, y: number): modelInputProps => { 20 | const clickType = 1; 21 | return { x, y, clickType }; 22 | }; 23 | 24 | // Get mouse position and scale the (x, y) coordinates back to the natural 25 | // scale of the image. Update the state of clicks with setClicks to trigger 26 | // the ONNX model to run and generate a new mask via a useEffect in App.tsx 27 | const handleMouseMove = _.throttle((e: any) => { 28 | let el = e.nativeEvent.target; 29 | const rect = el.getBoundingClientRect(); 30 | let x = e.clientX - rect.left; 31 | let y = e.clientY - rect.top; 32 | const imageScale = image ? image.width / el.offsetWidth : 1; 33 | x *= imageScale; 34 | y *= imageScale; 35 | const click = getClick(x, y); 36 | if (click) setClicks([click]); 37 | }, 15); 38 | 39 | const flexCenterClasses = "flex items-center justify-center"; 40 | return ( 41 |
42 |       <div className={`${flexCenterClasses} relative w-[90%] h-[90%]`}>
43 |         <Tool handleMouseMove={handleMouseMove} />
45 |
46 | ); 47 | }; 48 | 49 | export default Stage; 50 | -------------------------------------------------------------------------------- /sam/demo/src/components/Tool.tsx: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | import React, { useContext, useEffect, useState } from "react"; 8 | import AppContext from "./hooks/createContext"; 9 | import { ToolProps } from "./helpers/Interfaces"; 10 | import * as _ from "underscore"; 11 | 12 | const Tool = ({ handleMouseMove }: ToolProps) => { 13 | const { 14 | image: [image], 15 | maskImg: [maskImg, setMaskImg], 16 | } = useContext(AppContext)!; 17 | 18 | // Determine if we should shrink or grow the images to match the 19 | // width or the height of the page and setup a ResizeObserver to 20 | // monitor changes in the size of the page 21 | const [shouldFitToWidth, setShouldFitToWidth] = useState(true); 22 | const bodyEl = document.body; 23 | const fitToPage = () => { 24 | if (!image) return; 25 | const imageAspectRatio = image.width / image.height; 26 | const screenAspectRatio = window.innerWidth / window.innerHeight; 27 | setShouldFitToWidth(imageAspectRatio > screenAspectRatio); 28 | }; 29 | const resizeObserver = new ResizeObserver((entries) => { 30 | for (const entry of entries) { 31 | if (entry.target === bodyEl) { 32 | fitToPage(); 33 | } 34 | } 35 | }); 36 | useEffect(() => { 37 | fitToPage(); 38 | resizeObserver.observe(bodyEl); 39 | return () => { 40 | resizeObserver.unobserve(bodyEl); 41 | }; 42 | }, [image]); 43 | 44 | const imageClasses = ""; 45 | const maskImageClasses = `absolute opacity-40 pointer-events-none`; 46 | 47 | // Render the image and the predicted mask image on top 48 | return ( 49 | <> 50 | {image && ( 51 | _.defer(() => setMaskImg(null))} 54 | onTouchStart={handleMouseMove} 55 | src={image.src} 56 | className={`${ 57 | shouldFitToWidth ? "w-full" : "h-full" 58 | } ${imageClasses}`} 59 | > 60 | )} 61 | {maskImg && ( 62 | 68 | )} 69 | 70 | ); 71 | }; 72 | 73 | export default Tool; 74 | -------------------------------------------------------------------------------- /sam/demo/src/components/helpers/Interfaces.tsx: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | import { Tensor } from "onnxruntime-web"; 8 | 9 | export interface modelScaleProps { 10 | samScale: number; 11 | height: number; 12 | width: number; 13 | } 14 | 15 | export interface modelInputProps { 16 | x: number; 17 | y: number; 18 | clickType: number; 19 | } 20 | 21 | export interface modeDataProps { 22 | clicks?: Array; 23 | tensor: Tensor; 24 | modelScale: modelScaleProps; 25 | } 26 | 27 | export interface ToolProps { 28 | handleMouseMove: (e: any) => void; 29 | } 30 | -------------------------------------------------------------------------------- /sam/demo/src/components/helpers/maskUtils.tsx: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 
3 | 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | // Convert the onnx model mask prediction to ImageData 8 | function arrayToImageData(input: any, width: number, height: number) { 9 | const [r, g, b, a] = [0, 114, 189, 255]; // the masks's blue color 10 | const arr = new Uint8ClampedArray(4 * width * height).fill(0); 11 | for (let i = 0; i < input.length; i++) { 12 | 13 | // Threshold the onnx model mask prediction at 0.0 14 | // This is equivalent to thresholding the mask using predictor.model.mask_threshold 15 | // in python 16 | if (input[i] > 0.0) { 17 | arr[4 * i + 0] = r; 18 | arr[4 * i + 1] = g; 19 | arr[4 * i + 2] = b; 20 | arr[4 * i + 3] = a; 21 | } 22 | } 23 | return new ImageData(arr, height, width); 24 | } 25 | 26 | // Use a Canvas element to produce an image from ImageData 27 | function imageDataToImage(imageData: ImageData) { 28 | const canvas = imageDataToCanvas(imageData); 29 | const image = new Image(); 30 | image.src = canvas.toDataURL(); 31 | return image; 32 | } 33 | 34 | // Canvas elements can be created from ImageData 35 | function imageDataToCanvas(imageData: ImageData) { 36 | const canvas = document.createElement("canvas"); 37 | const ctx = canvas.getContext("2d"); 38 | canvas.width = imageData.width; 39 | canvas.height = imageData.height; 40 | ctx?.putImageData(imageData, 0, 0); 41 | return canvas; 42 | } 43 | 44 | // Convert the onnx model mask output to an HTMLImageElement 45 | export function onnxMaskToImage(input: any, width: number, height: number) { 46 | return imageDataToImage(arrayToImageData(input, width, height)); 47 | } 48 | -------------------------------------------------------------------------------- /sam/demo/src/components/helpers/onnxModelAPI.tsx: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | import { Tensor } from "onnxruntime-web"; 8 | import { modeDataProps } from "./Interfaces"; 9 | 10 | const modelData = ({ clicks, tensor, modelScale }: modeDataProps) => { 11 | const imageEmbedding = tensor; 12 | let pointCoords; 13 | let pointLabels; 14 | let pointCoordsTensor; 15 | let pointLabelsTensor; 16 | 17 | // Check there are input click prompts 18 | if (clicks) { 19 | let n = clicks.length; 20 | 21 | // If there is no box input, a single padding point with 22 | // label -1 and coordinates (0.0, 0.0) should be concatenated 23 | // so initialize the array to support (n + 1) points. 
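// Example: for a single positive click at (x, y) the model receives two points,
// [x * samScale, y * samScale] with label 1 and the padding point (0.0, 0.0)
// with label -1.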
24 | pointCoords = new Float32Array(2 * (n + 1)); 25 | pointLabels = new Float32Array(n + 1); 26 | 27 | // Add clicks and scale to what SAM expects 28 | for (let i = 0; i < n; i++) { 29 | pointCoords[2 * i] = clicks[i].x * modelScale.samScale; 30 | pointCoords[2 * i + 1] = clicks[i].y * modelScale.samScale; 31 | pointLabels[i] = clicks[i].clickType; 32 | } 33 | 34 | // Add in the extra point/label when only clicks and no box 35 | // The extra point is at (0, 0) with label -1 36 | pointCoords[2 * n] = 0.0; 37 | pointCoords[2 * n + 1] = 0.0; 38 | pointLabels[n] = -1.0; 39 | 40 | // Create the tensor 41 | pointCoordsTensor = new Tensor("float32", pointCoords, [1, n + 1, 2]); 42 | pointLabelsTensor = new Tensor("float32", pointLabels, [1, n + 1]); 43 | } 44 | const imageSizeTensor = new Tensor("float32", [ 45 | modelScale.height, 46 | modelScale.width, 47 | ]); 48 | 49 | if (pointCoordsTensor === undefined || pointLabelsTensor === undefined) 50 | return; 51 | 52 | // There is no previous mask, so default to an empty tensor 53 | const maskInput = new Tensor( 54 | "float32", 55 | new Float32Array(256 * 256), 56 | [1, 1, 256, 256] 57 | ); 58 | // There is no previous mask, so default to 0 59 | const hasMaskInput = new Tensor("float32", [0]); 60 | 61 | return { 62 | image_embeddings: imageEmbedding, 63 | point_coords: pointCoordsTensor, 64 | point_labels: pointLabelsTensor, 65 | orig_im_size: imageSizeTensor, 66 | mask_input: maskInput, 67 | has_mask_input: hasMaskInput, 68 | }; 69 | }; 70 | 71 | export { modelData }; 72 | -------------------------------------------------------------------------------- /sam/demo/src/components/helpers/scaleHelper.tsx: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | 8 | // Helper function for handling image scaling needed for SAM 9 | const handleImageScale = (image: HTMLImageElement) => { 10 | // Input images to SAM must be resized so the longest side is 1024 11 | const LONG_SIDE_LENGTH = 1024; 12 | let w = image.naturalWidth; 13 | let h = image.naturalHeight; 14 | const samScale = LONG_SIDE_LENGTH / Math.max(h, w); 15 | return { height: h, width: w, samScale }; 16 | }; 17 | 18 | export { handleImageScale }; 19 | -------------------------------------------------------------------------------- /sam/demo/src/components/hooks/context.tsx: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 
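// AppContextProvider below owns the three pieces of state shared across the demo
// (clicks, image, maskImg), exposing each as a [value, setter] pair on AppContext
// to match the contextProps interface declared in createContext.tsx.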
6 | 7 | import React, { useState } from "react"; 8 | import { modelInputProps } from "../helpers/Interfaces"; 9 | import AppContext from "./createContext"; 10 | 11 | const AppContextProvider = (props: { 12 | children: React.ReactElement>; 13 | }) => { 14 | const [clicks, setClicks] = useState | null>(null); 15 | const [image, setImage] = useState(null); 16 | const [maskImg, setMaskImg] = useState(null); 17 | 18 | return ( 19 | 26 | {props.children} 27 | 28 | ); 29 | }; 30 | 31 | export default AppContextProvider; 32 | -------------------------------------------------------------------------------- /sam/demo/src/components/hooks/createContext.tsx: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | import { createContext } from "react"; 8 | import { modelInputProps } from "../helpers/Interfaces"; 9 | 10 | interface contextProps { 11 | clicks: [ 12 | clicks: modelInputProps[] | null, 13 | setClicks: (e: modelInputProps[] | null) => void 14 | ]; 15 | image: [ 16 | image: HTMLImageElement | null, 17 | setImage: (e: HTMLImageElement | null) => void 18 | ]; 19 | maskImg: [ 20 | maskImg: HTMLImageElement | null, 21 | setMaskImg: (e: HTMLImageElement | null) => void 22 | ]; 23 | } 24 | 25 | const AppContext = createContext(null); 26 | 27 | export default AppContext; 28 | -------------------------------------------------------------------------------- /sam/demo/src/index.tsx: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | import * as React from "react"; 8 | import { createRoot } from "react-dom/client"; 9 | import AppContextProvider from "./components/hooks/context"; 10 | import App from "./App"; 11 | const container = document.getElementById("root"); 12 | const root = createRoot(container!); 13 | root.render( 14 | 15 | 16 | 17 | ); 18 | -------------------------------------------------------------------------------- /sam/demo/tailwind.config.js: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | /** @type {import('tailwindcss').Config} */ 8 | module.exports = { 9 | content: ["./src/**/*.{html,js,tsx}"], 10 | theme: {}, 11 | plugins: [], 12 | }; 13 | -------------------------------------------------------------------------------- /sam/demo/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "lib": ["dom", "dom.iterable", "esnext"], 4 | "allowJs": true, 5 | "skipLibCheck": true, 6 | "strict": true, 7 | "forceConsistentCasingInFileNames": true, 8 | "noEmit": false, 9 | "esModuleInterop": true, 10 | "module": "esnext", 11 | "moduleResolution": "node", 12 | "resolveJsonModule": true, 13 | "isolatedModules": true, 14 | "jsx": "react", 15 | "incremental": true, 16 | "target": "ESNext", 17 | "useDefineForClassFields": true, 18 | "allowSyntheticDefaultImports": true, 19 | "outDir": "./dist/", 20 | "sourceMap": true 21 | }, 22 | "include": ["next-env.d.ts", "**/*.ts", "**/*.tsx", "src"], 23 | "exclude": ["node_modules"] 24 | } 25 | -------------------------------------------------------------------------------- /sam/linter.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -e 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | { 5 | black --version | grep -E "23\." > /dev/null 6 | } || { 7 | echo "Linter requires 'black==23.*' !" 8 | exit 1 9 | } 10 | 11 | ISORT_VERSION=$(isort --version-number) 12 | if [[ "$ISORT_VERSION" != 5.12* ]]; then 13 | echo "Linter requires isort==5.12.0 !" 14 | exit 1 15 | fi 16 | 17 | echo "Running isort ..." 18 | isort . --atomic 19 | 20 | echo "Running black ..." 21 | black -l 100 . 22 | 23 | echo "Running flake8 ..." 24 | if [ -x "$(command -v flake8)" ]; then 25 | flake8 . 26 | else 27 | python3 -m flake8 . 28 | fi 29 | 30 | echo "Running mypy..." 31 | 32 | mypy --exclude 'setup.py|notebooks' . 33 | -------------------------------------------------------------------------------- /sam/notebooks/images/dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/sam/notebooks/images/dog.jpg -------------------------------------------------------------------------------- /sam/notebooks/images/groceries.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/sam/notebooks/images/groceries.jpg -------------------------------------------------------------------------------- /sam/notebooks/images/truck.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/sam/notebooks/images/truck.jpg -------------------------------------------------------------------------------- /sam/segment_anything/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
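# Minimal usage sketch (assumes a locally downloaded checkpoint, e.g. the
# "sam_vit_h_4b8939.pth" file referenced in the demo README):
#
#   from segment_anything import sam_model_registry, SamPredictor
#   sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
#   predictor = SamPredictor(sam)
#   predictor.set_image(image)  # HxWxC uint8 image as a numpy array
#   embedding = predictor.get_image_embedding()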
6 | 7 | from .build_sam import ( 8 | build_sam, 9 | build_sam_vit_h, 10 | build_sam_vit_l, 11 | build_sam_vit_b, 12 | sam_model_registry, 13 | ) 14 | from .predictor import SamPredictor 15 | from .automatic_mask_generator import SamAutomaticMaskGenerator 16 | -------------------------------------------------------------------------------- /sam/segment_anything/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/sam/segment_anything/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /sam/segment_anything/__pycache__/automatic_mask_generator.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/sam/segment_anything/__pycache__/automatic_mask_generator.cpython-38.pyc -------------------------------------------------------------------------------- /sam/segment_anything/__pycache__/build_sam.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/sam/segment_anything/__pycache__/build_sam.cpython-38.pyc -------------------------------------------------------------------------------- /sam/segment_anything/__pycache__/predictor.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/sam/segment_anything/__pycache__/predictor.cpython-38.pyc -------------------------------------------------------------------------------- /sam/segment_anything/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
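# The three builders below differ only in the ViT image encoder configuration:
#   vit_h: embed_dim 1280, depth 32, heads 16, global attention at blocks 7/15/23/31
#   vit_l: embed_dim 1024, depth 24, heads 16, global attention at blocks 5/11/17/23
#   vit_b: embed_dim  768, depth 12, heads 12, global attention at blocks 2/5/8/11
# All variants share a 1024x1024 input resolution, 16x16 patches, and a 256-dim
# prompt embedding.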
6 | 7 | import torch 8 | 9 | from functools import partial 10 | 11 | from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer 12 | 13 | 14 | def build_sam_vit_h(checkpoint=None): 15 | return _build_sam( 16 | encoder_embed_dim=1280, 17 | encoder_depth=32, 18 | encoder_num_heads=16, 19 | encoder_global_attn_indexes=[7, 15, 23, 31], 20 | checkpoint=checkpoint, 21 | ) 22 | 23 | 24 | build_sam = build_sam_vit_h 25 | 26 | 27 | def build_sam_vit_l(checkpoint=None): 28 | return _build_sam( 29 | encoder_embed_dim=1024, 30 | encoder_depth=24, 31 | encoder_num_heads=16, 32 | encoder_global_attn_indexes=[5, 11, 17, 23], 33 | checkpoint=checkpoint, 34 | ) 35 | 36 | 37 | def build_sam_vit_b(checkpoint=None): 38 | return _build_sam( 39 | encoder_embed_dim=768, 40 | encoder_depth=12, 41 | encoder_num_heads=12, 42 | encoder_global_attn_indexes=[2, 5, 8, 11], 43 | checkpoint=checkpoint, 44 | ) 45 | 46 | 47 | sam_model_registry = { 48 | "default": build_sam_vit_h, 49 | "vit_h": build_sam_vit_h, 50 | "vit_l": build_sam_vit_l, 51 | "vit_b": build_sam_vit_b, 52 | } 53 | 54 | 55 | def _build_sam( 56 | encoder_embed_dim, 57 | encoder_depth, 58 | encoder_num_heads, 59 | encoder_global_attn_indexes, 60 | checkpoint=None, 61 | ): 62 | prompt_embed_dim = 256 63 | image_size = 1024 64 | vit_patch_size = 16 65 | image_embedding_size = image_size // vit_patch_size 66 | sam = Sam( 67 | image_encoder=ImageEncoderViT( 68 | depth=encoder_depth, 69 | embed_dim=encoder_embed_dim, 70 | img_size=image_size, 71 | mlp_ratio=4, 72 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 73 | num_heads=encoder_num_heads, 74 | patch_size=vit_patch_size, 75 | qkv_bias=True, 76 | use_rel_pos=True, 77 | global_attn_indexes=encoder_global_attn_indexes, 78 | window_size=14, 79 | out_chans=prompt_embed_dim, 80 | ), 81 | prompt_encoder=PromptEncoder( 82 | embed_dim=prompt_embed_dim, 83 | image_embedding_size=(image_embedding_size, image_embedding_size), 84 | input_image_size=(image_size, image_size), 85 | mask_in_chans=16, 86 | ), 87 | mask_decoder=MaskDecoder( 88 | num_multimask_outputs=3, 89 | transformer=TwoWayTransformer( 90 | depth=2, 91 | embedding_dim=prompt_embed_dim, 92 | mlp_dim=2048, 93 | num_heads=8, 94 | ), 95 | transformer_dim=prompt_embed_dim, 96 | iou_head_depth=3, 97 | iou_head_hidden_dim=256, 98 | ), 99 | pixel_mean=[123.675, 116.28, 103.53], 100 | pixel_std=[58.395, 57.12, 57.375], 101 | ) 102 | sam.eval() 103 | if checkpoint is not None: 104 | with open(checkpoint, "rb") as f: 105 | state_dict = torch.load(f) 106 | sam.load_state_dict(state_dict) 107 | return sam 108 | -------------------------------------------------------------------------------- /sam/segment_anything/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
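# These components are assembled into a full model by build_sam._build_sam:
# ImageEncoderViT produces the image embedding, PromptEncoder embeds point, box,
# and mask prompts, and MaskDecoder (built around TwoWayTransformer) predicts the
# output masks together with their IoU scores.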
6 | 7 | from .sam import Sam 8 | from .image_encoder import ImageEncoderViT 9 | from .mask_decoder import MaskDecoder 10 | from .prompt_encoder import PromptEncoder 11 | from .transformer import TwoWayTransformer 12 | -------------------------------------------------------------------------------- /sam/segment_anything/modeling/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/sam/segment_anything/modeling/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /sam/segment_anything/modeling/__pycache__/common.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/sam/segment_anything/modeling/__pycache__/common.cpython-38.pyc -------------------------------------------------------------------------------- /sam/segment_anything/modeling/__pycache__/image_encoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/sam/segment_anything/modeling/__pycache__/image_encoder.cpython-38.pyc -------------------------------------------------------------------------------- /sam/segment_anything/modeling/__pycache__/mask_decoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/sam/segment_anything/modeling/__pycache__/mask_decoder.cpython-38.pyc -------------------------------------------------------------------------------- /sam/segment_anything/modeling/__pycache__/prompt_encoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/sam/segment_anything/modeling/__pycache__/prompt_encoder.cpython-38.pyc -------------------------------------------------------------------------------- /sam/segment_anything/modeling/__pycache__/sam.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/sam/segment_anything/modeling/__pycache__/sam.cpython-38.pyc -------------------------------------------------------------------------------- /sam/segment_anything/modeling/__pycache__/transformer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/sam/segment_anything/modeling/__pycache__/transformer.cpython-38.pyc -------------------------------------------------------------------------------- /sam/segment_anything/modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
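# Two small building blocks shared across the SAM modules: MLPBlock, a two-layer
# feed-forward block (Linear -> activation -> Linear, GELU by default), and
# LayerNorm2d, a channel-wise LayerNorm for NCHW feature maps with learnable
# per-channel weight and bias.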
6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(1, keepdim=True) 40 | s = (x - u).pow(2).mean(1, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 43 | return x 44 | -------------------------------------------------------------------------------- /sam/segment_anything/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /sam/segment_anything/utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/sam/segment_anything/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /sam/segment_anything/utils/__pycache__/amg.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/sam/segment_anything/utils/__pycache__/amg.cpython-38.pyc -------------------------------------------------------------------------------- /sam/segment_anything/utils/__pycache__/transforms.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/sam/segment_anything/utils/__pycache__/transforms.cpython-38.pyc -------------------------------------------------------------------------------- /sam/segment_anything/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn import functional as F 10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 11 | 12 | from copy import deepcopy 13 | from typing import Tuple 14 | 15 | 16 | class ResizeLongestSide: 17 | """ 18 | Resizes images to the longest side 'target_length', as well as provides 19 | methods for resizing coordinates and boxes. Provides methods for 20 | transforming both numpy array and batched torch tensors. 21 | """ 22 | 23 | def __init__(self, target_length: int) -> None: 24 | self.target_length = target_length 25 | 26 | def apply_image(self, image: np.ndarray) -> np.ndarray: 27 | """ 28 | Expects a numpy array with shape HxWxC in uint8 format. 29 | """ 30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 31 | return np.array(resize(to_pil_image(image), target_size)) 32 | 33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 34 | """ 35 | Expects a numpy array of length 2 in the final dimension. Requires the 36 | original image size in (H, W) format. 37 | """ 38 | old_h, old_w = original_size 39 | new_h, new_w = self.get_preprocess_shape( 40 | original_size[0], original_size[1], self.target_length 41 | ) 42 | coords = deepcopy(coords).astype(float) 43 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 44 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 45 | return coords 46 | 47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 48 | """ 49 | Expects a numpy array shape Bx4. Requires the original image size 50 | in (H, W) format. 51 | """ 52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 53 | return boxes.reshape(-1, 4) 54 | 55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 56 | """ 57 | Expects batched images with shape BxCxHxW and float format. This 58 | transformation may not exactly match apply_image. apply_image is 59 | the transformation expected by the model. 60 | """ 61 | # Expects an image in BCHW format. May not exactly match apply_image. 62 | target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) 63 | return F.interpolate( 64 | image, target_size, mode="bilinear", align_corners=False, antialias=True 65 | ) 66 | 67 | def apply_coords_torch( 68 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 69 | ) -> torch.Tensor: 70 | """ 71 | Expects a torch tensor with length 2 in the last dimension. Requires the 72 | original image size in (H, W) format. 73 | """ 74 | old_h, old_w = original_size 75 | new_h, new_w = self.get_preprocess_shape( 76 | original_size[0], original_size[1], self.target_length 77 | ) 78 | coords = deepcopy(coords).to(torch.float) 79 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 80 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 81 | return coords 82 | 83 | def apply_boxes_torch( 84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 85 | ) -> torch.Tensor: 86 | """ 87 | Expects a torch tensor with shape Bx4. Requires the original image 88 | size in (H, W) format. 89 | """ 90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 91 | return boxes.reshape(-1, 4) 92 | 93 | @staticmethod 94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 95 | """ 96 | Compute the output size given input size and target long side length. 
97 | """ 98 | scale = long_side_length * 1.0 / max(oldh, oldw) 99 | newh, neww = oldh * scale, oldw * scale 100 | neww = int(neww + 0.5) 101 | newh = int(newh + 0.5) 102 | return (newh, neww) 103 | -------------------------------------------------------------------------------- /sam/setup.cfg: -------------------------------------------------------------------------------- 1 | [isort] 2 | line_length=100 3 | multi_line_output=3 4 | include_trailing_comma=True 5 | known_standard_library=numpy,setuptools 6 | skip_glob=*/__init__.py 7 | known_myself=segment_anything 8 | known_third_party=matplotlib,cv2,torch,torchvision,pycocotools,onnx,black,isort 9 | no_lines_before=STDLIB,THIRDPARTY 10 | sections=FUTURE,STDLIB,THIRDPARTY,MYSELF,FIRSTPARTY,LOCALFOLDER 11 | default_section=FIRSTPARTY 12 | -------------------------------------------------------------------------------- /sam/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from setuptools import find_packages, setup 8 | 9 | setup( 10 | name="segment_anything", 11 | version="1.0", 12 | install_requires=[], 13 | packages=find_packages(exclude="notebooks"), 14 | extras_require={ 15 | "all": ["matplotlib", "pycocotools", "opencv-python", "onnx", "onnxruntime"], 16 | "dev": ["flake8", "isort", "black", "mypy"], 17 | }, 18 | ) 19 | -------------------------------------------------------------------------------- /swell.sh: -------------------------------------------------------------------------------- 1 | python label.py # Acquire the json file 2 | python list.py # Merge the json file 3 | python par_crop.py # crop the patches 4 | -------------------------------------------------------------------------------- /tracker/BAN/experiments/udatban_r50_l234/config.yaml: -------------------------------------------------------------------------------- 1 | META_ARC: "udatban_r50_l234" 2 | 3 | BACKBONE: 4 | TYPE: "resnet50" 5 | KWARGS: 6 | used_layers: [2, 3, 4] 7 | PRETRAINED: '' 8 | TRAIN_LAYERS: ['layer2', 'layer3', 'layer4'] 9 | TRAIN_EPOCH: 0 10 | LAYERS_LR: 0.1 11 | 12 | ADJUST: 13 | ADJUST: True 14 | TYPE: "AdjustAllLayer" 15 | KWARGS: 16 | in_channels: [512, 1024, 2048] 17 | out_channels: [256, 256, 256] 18 | 19 | ALIGN: 20 | ALIGN: False 21 | TYPE: "Adjust_Transformer" 22 | KWARGS: 23 | channels: 256 24 | 25 | BAN: 26 | BAN: True 27 | TYPE: 'MultiBAN' 28 | KWARGS: 29 | in_channels: [256, 256, 256] 30 | cls_out_channels: 2 # if use sigmoid cls, cls_out_channel = 1 else 2 31 | weighted: True 32 | 33 | POINT: 34 | STRIDE: 8 35 | 36 | TRACK: 37 | TYPE: 'SiamBANTracker' 38 | WINDOW_INFLUENCE: 0.385 39 | PENALTY_K: 0.02 40 | LR: 0.473 41 | EXEMPLAR_SIZE: 127 42 | INSTANCE_SIZE: 255 43 | BASE_SIZE: 8 44 | CONTEXT_AMOUNT: 0.5 45 | 46 | TRAIN: 47 | EPOCH: 20 48 | START_EPOCH: 0 49 | BATCH_SIZE: 28 50 | BASE_LR: 0.0015 51 | BASE_LR_d: 0.005 52 | CLS_WEIGHT: 1.0 53 | LOC_WEIGHT: 1.0 54 | PRETRAINED: './checkpoint/siamban.pth' 55 | RESUME: '' 56 | 57 | LR: 58 | TYPE: 'log' 59 | KWARGS: 60 | start_lr: 0.0015 61 | end_lr: 0.000015 62 | LR_WARMUP: 63 | TYPE: 'step' 64 | EPOCH: 5 65 | KWARGS: 66 | start_lr: 0.0003 67 | end_lr: 0.0015 68 | step: 1 69 | 70 | DATASET: 71 | NAMES: 72 | - 'VID' 73 | - 'YOUTUBEBB' 74 | - 'COCO' 75 | - 'DET' 76 | - 'GOT10K' 77 | - 'LASOT' 78 | 79 | VIDEOS_PER_EPOCH: 20000 80 | 81 
| TEMPLATE: 82 | SHIFT: 4 83 | SCALE: 0.05 84 | BLUR: 0.0 85 | FLIP: 0.0 86 | COLOR: 1.0 87 | 88 | SEARCH: 89 | SHIFT: 64 90 | SCALE: 0.18 91 | BLUR: 0.2 92 | FLIP: 0.0 93 | COLOR: 1.0 94 | 95 | NEG: 0.2 96 | GRAY: 0.2 97 | -------------------------------------------------------------------------------- /tracker/BAN/siamban/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/tracker/BAN/siamban/core/__init__.py -------------------------------------------------------------------------------- /tracker/BAN/siamban/core/xcorr.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) SenseTime. All Rights Reserved. 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | 12 | def xcorr_slow(x, kernel): 13 | """for loop to calculate cross correlation, slow version 14 | """ 15 | batch = x.size()[0] 16 | out = [] 17 | for i in range(batch): 18 | px = x[i] 19 | pk = kernel[i] 20 | px = px.view(1, -1, px.size()[1], px.size()[2]) 21 | pk = pk.view(1, -1, pk.size()[1], pk.size()[2]) 22 | po = F.conv2d(px, pk) 23 | out.append(po) 24 | out = torch.cat(out, 0) 25 | return out 26 | 27 | 28 | def xcorr_fast(x, kernel): 29 | """group conv2d to calculate cross correlation, fast version 30 | """ 31 | batch = kernel.size()[0] 32 | pk = kernel.view(-1, x.size()[1], kernel.size()[2], kernel.size()[3]) 33 | px = x.view(1, -1, x.size()[2], x.size()[3]) 34 | po = F.conv2d(px, pk, groups=batch) 35 | po = po.view(batch, -1, po.size()[2], po.size()[3]) 36 | return po 37 | 38 | 39 | def xcorr_depthwise(x, kernel): 40 | """depthwise cross correlation 41 | """ 42 | batch = kernel.size(0) 43 | channel = kernel.size(1) 44 | x = x.view(1, batch*channel, x.size(2), x.size(3)) 45 | kernel = kernel.view(batch*channel, 1, kernel.size(2), kernel.size(3)) 46 | out = F.conv2d(x, kernel, groups=batch*channel) 47 | out = out.view(batch, channel, out.size(2), out.size(3)) 48 | return out 49 | -------------------------------------------------------------------------------- /tracker/BAN/siamban/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/tracker/BAN/siamban/datasets/__init__.py -------------------------------------------------------------------------------- /tracker/BAN/siamban/datasets/point_target.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import numpy as np 7 | 8 | from siamban.core.config import cfg 9 | from siamban.utils.bbox import corner2center 10 | from siamban.utils.point import Point 11 | 12 | 13 | class PointTarget: 14 | def __init__(self,): 15 | self.points = Point(cfg.POINT.STRIDE, cfg.TRAIN.OUTPUT_SIZE, cfg.TRAIN.SEARCH_SIZE//2) 16 | 17 | def __call__(self, target, size, neg=False): 18 | 19 | # -1 ignore 0 negative 1 positive 20 | cls = -1 * np.ones((size, size), dtype=np.int64) 21 | delta = np.zeros((4, size, size), dtype=np.float32) 22 | 23 | def select(position, keep_num=16): 24 | num = 
position[0].shape[0] 25 | if num <= keep_num: 26 | return position, num 27 | slt = np.arange(num) 28 | np.random.shuffle(slt) 29 | slt = slt[:keep_num] 30 | return tuple(p[slt] for p in position), keep_num 31 | 32 | tcx, tcy, tw, th = corner2center(target) 33 | points = self.points.points 34 | 35 | if neg: 36 | neg = np.where(np.square(tcx - points[0]) / np.square(tw / 4) + 37 | np.square(tcy - points[1]) / np.square(th / 4) < 1) 38 | neg, neg_num = select(neg, cfg.TRAIN.NEG_NUM) 39 | cls[neg] = 0 40 | 41 | return cls, delta 42 | 43 | delta[0] = points[0] - target[0] 44 | delta[1] = points[1] - target[1] 45 | delta[2] = target[2] - points[0] 46 | delta[3] = target[3] - points[1] 47 | 48 | # ellipse label 49 | pos = np.where(np.square(tcx - points[0]) / np.square(tw / 4) + 50 | np.square(tcy - points[1]) / np.square(th / 4) < 1) 51 | neg = np.where(np.square(tcx - points[0]) / np.square(tw / 2) + 52 | np.square(tcy - points[1]) / np.square(th / 2) > 1) 53 | 54 | # sampling 55 | pos, pos_num = select(pos, cfg.TRAIN.POS_NUM) 56 | neg, neg_num = select(neg, cfg.TRAIN.TOTAL_NUM - cfg.TRAIN.POS_NUM) 57 | 58 | cls[pos] = 1 59 | cls[neg] = 0 60 | 61 | return cls, delta 62 | -------------------------------------------------------------------------------- /tracker/BAN/siamban/models/GRL.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class _GradientScalarLayer(torch.autograd.Function): 5 | @staticmethod 6 | def forward(ctx, input, weight): 7 | ctx.weight = weight 8 | return input.view_as(input) 9 | 10 | @staticmethod 11 | def backward(ctx, grad_output): 12 | grad_input = grad_output.clone() 13 | return ctx.weight*grad_input, None 14 | 15 | gradient_scalar = _GradientScalarLayer.apply 16 | 17 | 18 | class GradientScalarLayer(torch.nn.Module): 19 | def __init__(self, weight): 20 | super(GradientScalarLayer, self).__init__() 21 | self.weight = weight 22 | 23 | def forward(self, input): 24 | return gradient_scalar(input, self.weight) 25 | 26 | def __repr__(self): 27 | tmpstr = self.__class__.__name__ + "(" 28 | tmpstr += "weight=" + str(self.weight) 29 | tmpstr += ")" 30 | return tmpstr -------------------------------------------------------------------------------- /tracker/BAN/siamban/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/tracker/BAN/siamban/models/__init__.py -------------------------------------------------------------------------------- /tracker/BAN/siamban/models/__pycache__/GRL.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/tracker/BAN/siamban/models/__pycache__/GRL.cpython-38.pyc -------------------------------------------------------------------------------- /tracker/BAN/siamban/models/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/tracker/BAN/siamban/models/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /tracker/BAN/siamban/models/__pycache__/iou_loss.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/tracker/BAN/siamban/models/__pycache__/iou_loss.cpython-38.pyc -------------------------------------------------------------------------------- /tracker/BAN/siamban/models/__pycache__/loss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/tracker/BAN/siamban/models/__pycache__/loss.cpython-38.pyc -------------------------------------------------------------------------------- /tracker/BAN/siamban/models/__pycache__/model_builder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/tracker/BAN/siamban/models/__pycache__/model_builder.cpython-38.pyc -------------------------------------------------------------------------------- /tracker/BAN/siamban/models/__pycache__/model_builder_tsne.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/tracker/BAN/siamban/models/__pycache__/model_builder_tsne.cpython-38.pyc -------------------------------------------------------------------------------- /tracker/BAN/siamban/models/__pycache__/model_builder_v.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/tracker/BAN/siamban/models/__pycache__/model_builder_v.cpython-38.pyc -------------------------------------------------------------------------------- /tracker/BAN/siamban/models/__pycache__/trans_discriminator.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/tracker/BAN/siamban/models/__pycache__/trans_discriminator.cpython-38.pyc -------------------------------------------------------------------------------- /tracker/BAN/siamban/models/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) SenseTime. All Rights Reserved. 
2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | from siamban.models.backbone.alexnet import alexnetlegacy, alexnet 9 | from siamban.models.backbone.mobile_v2 import mobilenetv2 10 | from siamban.models.backbone.resnet_atrous import resnet18, resnet34, resnet50 11 | 12 | BACKBONES = { 13 | 'alexnetlegacy': alexnetlegacy, 14 | 'mobilenetv2': mobilenetv2, 15 | 'resnet18': resnet18, 16 | 'resnet34': resnet34, 17 | 'resnet50': resnet50, 18 | 'alexnet': alexnet, 19 | } 20 | 21 | 22 | def get_backbone(name, **kwargs): 23 | return BACKBONES[name](**kwargs) 24 | -------------------------------------------------------------------------------- /tracker/BAN/siamban/models/backbone/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/tracker/BAN/siamban/models/backbone/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /tracker/BAN/siamban/models/backbone/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/tracker/BAN/siamban/models/backbone/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /tracker/BAN/siamban/models/backbone/__pycache__/alexnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/tracker/BAN/siamban/models/backbone/__pycache__/alexnet.cpython-36.pyc -------------------------------------------------------------------------------- /tracker/BAN/siamban/models/backbone/__pycache__/alexnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/tracker/BAN/siamban/models/backbone/__pycache__/alexnet.cpython-38.pyc -------------------------------------------------------------------------------- /tracker/BAN/siamban/models/backbone/__pycache__/mobile_v2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/tracker/BAN/siamban/models/backbone/__pycache__/mobile_v2.cpython-36.pyc -------------------------------------------------------------------------------- /tracker/BAN/siamban/models/backbone/__pycache__/mobile_v2.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/tracker/BAN/siamban/models/backbone/__pycache__/mobile_v2.cpython-38.pyc -------------------------------------------------------------------------------- /tracker/BAN/siamban/models/backbone/__pycache__/resnet_atrous.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/tracker/BAN/siamban/models/backbone/__pycache__/resnet_atrous.cpython-36.pyc 
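# ----------------------------------------------------------------------------
# Usage sketch: building a tracker backbone through the BACKBONES registry defined in
# siamban/models/backbone/__init__.py above, mirroring the resnet50 / used_layers settings
# from experiments/udatban_r50_l234/config.yaml. A minimal illustrative example, assuming
# tracker/BAN is on PYTHONPATH; the 255x255 dummy input follows the tracker's search size.
import torch

from siamban.models.backbone import get_backbone

backbone = get_backbone("resnet50", used_layers=[2, 3, 4])
backbone.eval()

with torch.no_grad():
    # one feature map per entry of used_layers (layer2, layer3, layer4)
    feats = backbone(torch.randn(1, 3, 255, 255))
print([tuple(f.shape) for f in feats])
# ----------------------------------------------------------------------------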
-------------------------------------------------------------------------------- /tracker/BAN/siamban/models/backbone/__pycache__/resnet_atrous.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/tracker/BAN/siamban/models/backbone/__pycache__/resnet_atrous.cpython-38.pyc -------------------------------------------------------------------------------- /tracker/BAN/siamban/models/backbone/alexnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import torch.nn as nn 7 | 8 | 9 | class AlexNetLegacy(nn.Module): 10 | configs = [3, 96, 256, 384, 384, 256] 11 | 12 | def __init__(self, width_mult=1): 13 | configs = list(map(lambda x: 3 if x == 3 else 14 | int(x*width_mult), AlexNet.configs)) 15 | super(AlexNetLegacy, self).__init__() 16 | self.features = nn.Sequential( 17 | nn.Conv2d(configs[0], configs[1], kernel_size=11, stride=2), 18 | nn.BatchNorm2d(configs[1]), 19 | nn.MaxPool2d(kernel_size=3, stride=2), 20 | nn.ReLU(inplace=True), 21 | nn.Conv2d(configs[1], configs[2], kernel_size=5), 22 | nn.BatchNorm2d(configs[2]), 23 | nn.MaxPool2d(kernel_size=3, stride=2), 24 | nn.ReLU(inplace=True), 25 | nn.Conv2d(configs[2], configs[3], kernel_size=3), 26 | nn.BatchNorm2d(configs[3]), 27 | nn.ReLU(inplace=True), 28 | nn.Conv2d(configs[3], configs[4], kernel_size=3), 29 | nn.BatchNorm2d(configs[4]), 30 | nn.ReLU(inplace=True), 31 | nn.Conv2d(configs[4], configs[5], kernel_size=3), 32 | nn.BatchNorm2d(configs[5]), 33 | ) 34 | self.feature_size = configs[5] 35 | 36 | def forward(self, x): 37 | x = self.features(x) 38 | return x 39 | 40 | 41 | class AlexNet(nn.Module): 42 | configs = [3, 96, 256, 384, 384, 256] 43 | 44 | def __init__(self, width_mult=1): 45 | configs = list(map(lambda x: 3 if x == 3 else 46 | int(x*width_mult), AlexNet.configs)) 47 | super(AlexNet, self).__init__() 48 | self.layer1 = nn.Sequential( 49 | nn.Conv2d(configs[0], configs[1], kernel_size=11, stride=2), 50 | nn.BatchNorm2d(configs[1]), 51 | nn.MaxPool2d(kernel_size=3, stride=2), 52 | nn.ReLU(inplace=True), 53 | ) 54 | self.layer2 = nn.Sequential( 55 | nn.Conv2d(configs[1], configs[2], kernel_size=5), 56 | nn.BatchNorm2d(configs[2]), 57 | nn.MaxPool2d(kernel_size=3, stride=2), 58 | nn.ReLU(inplace=True), 59 | ) 60 | self.layer3 = nn.Sequential( 61 | nn.Conv2d(configs[2], configs[3], kernel_size=3), 62 | nn.BatchNorm2d(configs[3]), 63 | nn.ReLU(inplace=True), 64 | ) 65 | self.layer4 = nn.Sequential( 66 | nn.Conv2d(configs[3], configs[4], kernel_size=3), 67 | nn.BatchNorm2d(configs[4]), 68 | nn.ReLU(inplace=True), 69 | ) 70 | 71 | self.layer5 = nn.Sequential( 72 | nn.Conv2d(configs[4], configs[5], kernel_size=3), 73 | nn.BatchNorm2d(configs[5]), 74 | ) 75 | self.feature_size = configs[5] 76 | 77 | def forward(self, x): 78 | x = self.layer1(x) 79 | x = self.layer2(x) 80 | x = self.layer3(x) 81 | x = self.layer4(x) 82 | x = self.layer5(x) 83 | return x 84 | 85 | 86 | def alexnetlegacy(**kwargs): 87 | return AlexNetLegacy(**kwargs) 88 | 89 | 90 | def alexnet(**kwargs): 91 | return AlexNet(**kwargs) 92 | -------------------------------------------------------------------------------- /tracker/BAN/siamban/models/backbone/mobile_v2.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | def conv_bn(inp, oup, stride, padding=1): 11 | return nn.Sequential( 12 | nn.Conv2d(inp, oup, 3, stride, padding, bias=False), 13 | nn.BatchNorm2d(oup), 14 | nn.ReLU6(inplace=True) 15 | ) 16 | 17 | 18 | def conv_1x1_bn(inp, oup): 19 | return nn.Sequential( 20 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 21 | nn.BatchNorm2d(oup), 22 | nn.ReLU6(inplace=True) 23 | ) 24 | 25 | 26 | class InvertedResidual(nn.Module): 27 | def __init__(self, inp, oup, stride, expand_ratio, dilation=1): 28 | super(InvertedResidual, self).__init__() 29 | self.stride = stride 30 | 31 | self.use_res_connect = self.stride == 1 and inp == oup 32 | 33 | padding = 2 - stride 34 | if dilation > 1: 35 | padding = dilation 36 | 37 | self.conv = nn.Sequential( 38 | # pw 39 | nn.Conv2d(inp, inp * expand_ratio, 1, 1, 0, bias=False), 40 | nn.BatchNorm2d(inp * expand_ratio), 41 | nn.ReLU6(inplace=True), 42 | # dw 43 | nn.Conv2d(inp * expand_ratio, inp * expand_ratio, 3, 44 | stride, padding, dilation=dilation, 45 | groups=inp * expand_ratio, bias=False), 46 | nn.BatchNorm2d(inp * expand_ratio), 47 | nn.ReLU6(inplace=True), 48 | # pw-linear 49 | nn.Conv2d(inp * expand_ratio, oup, 1, 1, 0, bias=False), 50 | nn.BatchNorm2d(oup), 51 | ) 52 | 53 | def forward(self, x): 54 | if self.use_res_connect: 55 | return x + self.conv(x) 56 | else: 57 | return self.conv(x) 58 | 59 | 60 | class MobileNetV2(nn.Sequential): 61 | def __init__(self, width_mult=1.0, used_layers=[3, 5, 7]): 62 | super(MobileNetV2, self).__init__() 63 | 64 | self.interverted_residual_setting = [ 65 | # t, c, n, s 66 | [1, 16, 1, 1, 1], 67 | [6, 24, 2, 2, 1], 68 | [6, 32, 3, 2, 1], 69 | [6, 64, 4, 2, 1], 70 | [6, 96, 3, 1, 1], 71 | [6, 160, 3, 2, 1], 72 | [6, 320, 1, 1, 1], 73 | ] 74 | # 0,2,3,4,6 75 | 76 | self.interverted_residual_setting = [ 77 | # t, c, n, s 78 | [1, 16, 1, 1, 1], 79 | [6, 24, 2, 2, 1], 80 | [6, 32, 3, 2, 1], 81 | [6, 64, 4, 1, 2], 82 | [6, 96, 3, 1, 2], 83 | [6, 160, 3, 1, 4], 84 | [6, 320, 1, 1, 4], 85 | ] 86 | 87 | self.channels = [24, 32, 96, 320] 88 | self.channels = [int(c * width_mult) for c in self.channels] 89 | 90 | input_channel = int(32 * width_mult) 91 | self.last_channel = int(1280 * width_mult) \ 92 | if width_mult > 1.0 else 1280 93 | 94 | self.add_module('layer0', conv_bn(3, input_channel, 2, 0)) 95 | 96 | last_dilation = 1 97 | 98 | self.used_layers = used_layers 99 | 100 | for idx, (t, c, n, s, d) in \ 101 | enumerate(self.interverted_residual_setting, start=1): 102 | output_channel = int(c * width_mult) 103 | 104 | layers = [] 105 | 106 | for i in range(n): 107 | if i == 0: 108 | if d == last_dilation: 109 | dd = d 110 | else: 111 | dd = max(d // 2, 1) 112 | layers.append(InvertedResidual(input_channel, 113 | output_channel, s, t, dd)) 114 | else: 115 | layers.append(InvertedResidual(input_channel, 116 | output_channel, 1, t, d)) 117 | input_channel = output_channel 118 | 119 | last_dilation = d 120 | 121 | self.add_module('layer%d' % (idx), nn.Sequential(*layers)) 122 | 123 | def forward(self, x): 124 | outputs = [] 125 | for idx in range(8): 126 | name = "layer%d" % idx 127 | x = getattr(self, name)(x) 128 | outputs.append(x) 129 | p0, p1, p2, p3, p4 = [outputs[i] for i in [1, 2, 3, 5, 7]] 130 | out = [outputs[i] for i in 
self.used_layers] 131 | return out 132 | 133 | 134 | def mobilenetv2(**kwargs): 135 | model = MobileNetV2(**kwargs) 136 | return model 137 | 138 | 139 | if __name__ == '__main__': 140 | net = mobilenetv2() 141 | 142 | print(net) 143 | 144 | from torch.autograd import Variable 145 | tensor = Variable(torch.Tensor(1, 3, 255, 255)).cuda() 146 | 147 | net = net.cuda() 148 | 149 | out = net(tensor) 150 | 151 | for i, p in enumerate(out): 152 | print(i, p.size()) 153 | -------------------------------------------------------------------------------- /tracker/BAN/siamban/models/head/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | from siamban.models.head.ban import UPChannelBAN, DepthwiseBAN, MultiBAN 7 | 8 | 9 | BANS = { 10 | 'UPChannelBAN': UPChannelBAN, 11 | 'DepthwiseBAN': DepthwiseBAN, 12 | 'MultiBAN': MultiBAN 13 | } 14 | 15 | 16 | def get_ban_head(name, **kwargs): 17 | return BANS[name](**kwargs) 18 | 19 | -------------------------------------------------------------------------------- /tracker/BAN/siamban/models/head/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/tracker/BAN/siamban/models/head/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /tracker/BAN/siamban/models/head/__pycache__/ban.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/tracker/BAN/siamban/models/head/__pycache__/ban.cpython-38.pyc -------------------------------------------------------------------------------- /tracker/BAN/siamban/models/head/ban.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from siamban.core.xcorr import xcorr_fast, xcorr_depthwise 11 | 12 | class BAN(nn.Module): 13 | def __init__(self): 14 | super(BAN, self).__init__() 15 | 16 | def forward(self, z_f, x_f): 17 | raise NotImplementedError 18 | 19 | class UPChannelBAN(BAN): 20 | def __init__(self, feature_in=256, cls_out_channels=2): 21 | super(UPChannelBAN, self).__init__() 22 | 23 | cls_output = cls_out_channels 24 | loc_output = 4 25 | 26 | self.template_cls_conv = nn.Conv2d(feature_in, 27 | feature_in * cls_output, kernel_size=3) 28 | self.template_loc_conv = nn.Conv2d(feature_in, 29 | feature_in * loc_output, kernel_size=3) 30 | 31 | self.search_cls_conv = nn.Conv2d(feature_in, 32 | feature_in, kernel_size=3) 33 | self.search_loc_conv = nn.Conv2d(feature_in, 34 | feature_in, kernel_size=3) 35 | 36 | self.loc_adjust = nn.Conv2d(loc_output, loc_output, kernel_size=1) 37 | 38 | 39 | def forward(self, z_f, x_f): 40 | cls_kernel = self.template_cls_conv(z_f) 41 | loc_kernel = self.template_loc_conv(z_f) 42 | 43 | cls_feature = self.search_cls_conv(x_f) 44 | loc_feature = self.search_loc_conv(x_f) 45 | 46 | cls = xcorr_fast(cls_feature, cls_kernel) 47 | loc = 
self.loc_adjust(xcorr_fast(loc_feature, loc_kernel)) 48 | return cls, loc 49 | 50 | 51 | class DepthwiseXCorr(nn.Module): 52 | def __init__(self, in_channels, hidden, out_channels, kernel_size=3): 53 | super(DepthwiseXCorr, self).__init__() 54 | self.conv_kernel = nn.Sequential( 55 | nn.Conv2d(in_channels, hidden, kernel_size=kernel_size, bias=False), 56 | nn.BatchNorm2d(hidden), 57 | nn.ReLU(inplace=True), 58 | ) 59 | self.conv_search = nn.Sequential( 60 | nn.Conv2d(in_channels, hidden, kernel_size=kernel_size, bias=False), 61 | nn.BatchNorm2d(hidden), 62 | nn.ReLU(inplace=True), 63 | ) 64 | self.head = nn.Sequential( 65 | nn.Conv2d(hidden, hidden, kernel_size=1, bias=False), 66 | nn.BatchNorm2d(hidden), 67 | nn.ReLU(inplace=True), 68 | nn.Conv2d(hidden, out_channels, kernel_size=1) 69 | ) 70 | 71 | 72 | def forward(self, kernel, search): 73 | kernel = self.conv_kernel(kernel) 74 | search = self.conv_search(search) 75 | feature = xcorr_depthwise(search, kernel) 76 | out = self.head(feature) 77 | return out 78 | 79 | 80 | class DepthwiseBAN(BAN): 81 | def __init__(self, in_channels=256, out_channels=256, cls_out_channels=2, weighted=False): 82 | super(DepthwiseBAN, self).__init__() 83 | self.cls = DepthwiseXCorr(in_channels, out_channels, cls_out_channels) 84 | self.loc = DepthwiseXCorr(in_channels, out_channels, 4) 85 | 86 | def forward(self, z_f, x_f): 87 | cls = self.cls(z_f, x_f) 88 | loc = self.loc(z_f, x_f) 89 | return cls, loc 90 | 91 | 92 | class MultiBAN(BAN): 93 | def __init__(self, in_channels, cls_out_channels, weighted=False): 94 | super(MultiBAN, self).__init__() 95 | self.weighted = weighted 96 | for i in range(len(in_channels)): 97 | self.add_module('box'+str(i+2), DepthwiseBAN(in_channels[i], in_channels[i], cls_out_channels)) 98 | if self.weighted: 99 | self.cls_weight = nn.Parameter(torch.ones(len(in_channels))) 100 | self.loc_weight = nn.Parameter(torch.ones(len(in_channels))) 101 | self.loc_scale = nn.Parameter(torch.ones(len(in_channels))) 102 | 103 | def forward(self, z_fs, x_fs): 104 | cls = [] 105 | loc = [] 106 | for idx, (z_f, x_f) in enumerate(zip(z_fs, x_fs), start=2): 107 | box = getattr(self, 'box'+str(idx)) 108 | c, l = box(z_f, x_f) 109 | cls.append(c) 110 | loc.append(torch.exp(l*self.loc_scale[idx-2])) 111 | 112 | if self.weighted: 113 | cls_weight = F.softmax(self.cls_weight, 0) 114 | loc_weight = F.softmax(self.loc_weight, 0) 115 | 116 | def avg(lst): 117 | return sum(lst) / len(lst) 118 | 119 | def weighted_avg(lst, weight): 120 | s = 0 121 | for i in range(len(weight)): 122 | s += lst[i] * weight[i] 123 | return s 124 | 125 | if self.weighted: 126 | return weighted_avg(cls, cls_weight), weighted_avg(loc, loc_weight) 127 | else: 128 | return avg(cls), avg(loc) 129 | -------------------------------------------------------------------------------- /tracker/BAN/siamban/models/init_weight.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def init_weights(model): 5 | for m in model.modules(): 6 | if isinstance(m, nn.Conv2d): 7 | nn.init.kaiming_normal_(m.weight.data, 8 | mode='fan_out', 9 | nonlinearity='relu') 10 | elif isinstance(m, nn.BatchNorm2d): 11 | m.weight.data.fill_(1) 12 | m.bias.data.zero_() 13 | -------------------------------------------------------------------------------- /tracker/BAN/siamban/models/iou_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class IOULoss(nn.Module): 
6 | def __init__(self, loc_loss_type): 7 | super(IOULoss, self).__init__() 8 | self.loc_loss_type = loc_loss_type 9 | 10 | def forward(self, pred, target, weight=None): 11 | pred_left = pred[:, 0] 12 | pred_top = pred[:, 1] 13 | pred_right = pred[:, 2] 14 | pred_bottom = pred[:, 3] 15 | 16 | target_left = target[:, 0] 17 | target_top = target[:, 1] 18 | target_right = target[:, 2] 19 | target_bottom = target[:, 3] 20 | 21 | pred_area = (pred_left + pred_right) * (pred_top + pred_bottom) 22 | target_area = (target_left + target_right) * (target_top + target_bottom) 23 | 24 | w_intersect = torch.min(pred_left, target_left) + torch.min(pred_right, target_right) 25 | g_w_intersect = torch.max(pred_left, target_left) + torch.max(pred_right, target_right) 26 | h_intersect = torch.min(pred_bottom, target_bottom) + torch.min(pred_top, target_top) 27 | g_h_intersect = torch.max(pred_bottom, target_bottom) + torch.max(pred_top, target_top) 28 | ac_uion = g_w_intersect * g_h_intersect + 1e-7 29 | area_intersect = w_intersect * h_intersect 30 | area_union = target_area + pred_area - area_intersect 31 | ious = (area_intersect + 1.0) / (area_union + 1.0) 32 | gious = ious - (ac_uion - area_union) / ac_uion 33 | 34 | if self.loc_loss_type == 'iou': 35 | losses = -torch.log(ious) 36 | elif self.loc_loss_type == 'linear_iou': 37 | losses = 1 - ious 38 | elif self.loc_loss_type == 'giou': 39 | losses = 1 - gious 40 | else: 41 | raise NotImplementedError 42 | 43 | if weight is not None and weight.sum() > 0: 44 | return (losses * weight).sum() / weight.sum() 45 | else: 46 | assert losses.numel() != 0 47 | return losses.mean() 48 | 49 | 50 | linear_iou = IOULoss(loc_loss_type='linear_iou') 51 | -------------------------------------------------------------------------------- /tracker/BAN/siamban/models/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) SenseTime. All Rights Reserved. 
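# ----------------------------------------------------------------------------
# Usage sketch: the IOULoss above consumes per-sample (left, top, right, bottom) distances
# from an anchor point to the box edges, not corner coordinates. A minimal illustrative
# example using the linear_iou instance defined at the bottom of siamban/models/iou_loss.py:
# identical rows give an IoU of 1 and zero loss, so only the perturbed first row contributes.
import torch

from siamban.models.iou_loss import linear_iou

target = torch.tensor([[10.0, 12.0, 14.0, 16.0],
                       [ 8.0,  8.0,  8.0,  8.0]])
pred = target.clone()
pred[0] += 2.0  # shift the first box prediction by 2 px in every direction

print(float(linear_iou(pred, target)))  # ~0.125: mean of 0 and (1 - IoU) for the shifted row
# ----------------------------------------------------------------------------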
2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import numpy as np 12 | 13 | from siamban.core.config import cfg 14 | from siamban.models.iou_loss import linear_iou 15 | 16 | 17 | def get_cls_loss(pred, label, select): 18 | if len(select.size()) == 0 or \ 19 | select.size() == torch.Size([0]): 20 | return 0 21 | pred = torch.index_select(pred, 0, select) 22 | label = torch.index_select(label, 0, select) 23 | return F.nll_loss(pred, label) 24 | 25 | 26 | def select_cross_entropy_loss(pred, label): 27 | pred = pred.view(-1, 2) 28 | label = label.view(-1) 29 | pos = label.data.eq(1).nonzero().squeeze().cuda() 30 | neg = label.data.eq(0).nonzero().squeeze().cuda() 31 | loss_pos = get_cls_loss(pred, label, pos) 32 | loss_neg = get_cls_loss(pred, label, neg) 33 | return loss_pos * 0.5 + loss_neg * 0.5 34 | 35 | 36 | def weight_l1_loss(pred_loc, label_loc, loss_weight): 37 | if cfg.BAN.BAN: 38 | diff = (pred_loc - label_loc).abs() 39 | diff = diff.sum(dim=1) 40 | else: 41 | diff = None 42 | loss = diff * loss_weight 43 | return loss.sum().div(pred_loc.size()[0]) 44 | 45 | 46 | def select_iou_loss(pred_loc, label_loc, label_cls): 47 | label_cls = label_cls.reshape(-1) 48 | pos = label_cls.data.eq(1).nonzero().squeeze().cuda() 49 | 50 | pred_loc = pred_loc.permute(0, 2, 3, 1).reshape(-1, 4) 51 | pred_loc = torch.index_select(pred_loc, 0, pos) 52 | 53 | label_loc = label_loc.permute(0, 2, 3, 1).reshape(-1, 4) 54 | label_loc = torch.index_select(label_loc, 0, pos) 55 | 56 | return linear_iou(pred_loc, label_loc) 57 | -------------------------------------------------------------------------------- /tracker/BAN/siamban/models/model_builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) SenseTime. All Rights Reserved. 
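# ----------------------------------------------------------------------------
# Usage sketch: select_cross_entropy_loss above expects log-softmax scores with a final
# dimension of 2 and labels in {-1, 0, 1}, where -1 entries are ignored and the positive and
# negative losses are averaged with equal weight. A minimal illustrative example, assuming a
# CUDA device (the helpers call .cuda() internally) and that the project's dependencies
# (e.g. yacs for siamban.core.config) are installed.
import torch
import torch.nn.functional as F

from siamban.models.loss import select_cross_entropy_loss

scores = F.log_softmax(torch.randn(2, 25, 25, 2), dim=3).cuda()  # (batch, H, W, 2) log-probabilities
labels = torch.randint(-1, 2, (2, 25, 25)).cuda()                # -1 ignore, 0 negative, 1 positive
print(select_cross_entropy_loss(scores, labels))
# ----------------------------------------------------------------------------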
2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from siamban.core.config import cfg 12 | from siamban.models.loss import select_cross_entropy_loss, select_iou_loss 13 | from siamban.models.backbone import get_backbone 14 | from siamban.models.head import get_ban_head 15 | from siamban.models.neck import get_neck 16 | 17 | 18 | class ModelBuilder(nn.Module): 19 | def __init__(self): 20 | super(ModelBuilder, self).__init__() 21 | 22 | # build backbone 23 | self.backbone = get_backbone(cfg.BACKBONE.TYPE, 24 | **cfg.BACKBONE.KWARGS) 25 | 26 | # build adjust layer 27 | if cfg.ADJUST.ADJUST: 28 | self.neck = get_neck(cfg.ADJUST.TYPE, 29 | **cfg.ADJUST.KWARGS) 30 | 31 | if cfg.ALIGN.ALIGN: 32 | self.align = get_neck(cfg.ALIGN.TYPE, 33 | **cfg.ALIGN.KWARGS) 34 | 35 | # build ban head 36 | if cfg.BAN.BAN: 37 | self.head = get_ban_head(cfg.BAN.TYPE, 38 | **cfg.BAN.KWARGS) 39 | 40 | def template(self, z): 41 | zf = self.backbone(z) 42 | if cfg.ADJUST.ADJUST: 43 | zf = self.neck(zf) 44 | if cfg.ALIGN.ALIGN: 45 | zf = [self.align(zf[i]) for i in range(len(zf))] 46 | 47 | self.zf = zf 48 | 49 | def track(self, x): 50 | xf = self.backbone(x) 51 | if cfg.ADJUST.ADJUST: 52 | xf = self.neck(xf) 53 | if cfg.ALIGN.ALIGN: 54 | xf = [self.align(xf[i]) for i in range(len(xf))] 55 | 56 | cls, loc = self.head(self.zf, xf) 57 | return { 58 | 'cls': cls, 59 | 'loc': loc 60 | } 61 | 62 | 63 | def log_softmax(self, cls): 64 | if cfg.BAN.BAN: 65 | cls = cls.permute(0, 2, 3, 1).contiguous() 66 | cls = F.log_softmax(cls, dim=3) 67 | return cls 68 | 69 | def forward(self, data): 70 | """ only used in training 71 | """ 72 | template = data['template'].cuda() 73 | search = data['search'].cuda() 74 | label_cls = data['label_cls'].cuda() 75 | label_loc = data['label_loc'].cuda() 76 | 77 | # get feature 78 | zf = self.backbone(template) 79 | xf = self.backbone(search) 80 | if cfg.ADJUST.ADJUST: 81 | zf = self.neck(zf) 82 | xf = self.neck(xf) 83 | if cfg.ALIGN.ALIGN: 84 | zf = [self.align(_zf) for _zf in zf] 85 | xf = [self.align(_xf) for _xf in xf] 86 | 87 | cls, loc = self.head(zf, xf) 88 | 89 | # get loss 90 | 91 | # cls loss with cross entropy loss 92 | cls = self.log_softmax(cls) 93 | cls_loss = select_cross_entropy_loss(cls, label_cls) 94 | 95 | # loc loss with iou loss 96 | loc_loss = select_iou_loss(loc, label_loc, label_cls) 97 | 98 | outputs = {} 99 | outputs['total_loss'] = cfg.TRAIN.CLS_WEIGHT * cls_loss + \ 100 | cfg.TRAIN.LOC_WEIGHT * loc_loss 101 | outputs['cls_loss'] = cls_loss 102 | outputs['loc_loss'] = loc_loss 103 | 104 | return outputs, zf, xf 105 | -------------------------------------------------------------------------------- /tracker/BAN/siamban/models/model_builder_tsne.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) SenseTime. All Rights Reserved. 
2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from siamban.core.config import cfg 12 | from siamban.models.loss import select_cross_entropy_loss, select_iou_loss 13 | from siamban.models.backbone import get_backbone 14 | from siamban.models.head import get_ban_head 15 | from siamban.models.neck import get_neck 16 | import torch 17 | import numpy 18 | 19 | 20 | class ModelBuilder(nn.Module): 21 | def __init__(self): 22 | super(ModelBuilder, self).__init__() 23 | 24 | # build backbone 25 | self.backbone = get_backbone(cfg.BACKBONE.TYPE, 26 | **cfg.BACKBONE.KWARGS) 27 | 28 | # build adjust layer 29 | if cfg.ADJUST.ADJUST: 30 | self.neck = get_neck(cfg.ADJUST.TYPE, 31 | **cfg.ADJUST.KWARGS) 32 | 33 | if cfg.ALIGN.ALIGN: 34 | self.align = get_neck(cfg.ALIGN.TYPE, 35 | **cfg.ALIGN.KWARGS) 36 | 37 | # build ban head 38 | if cfg.BAN.BAN: 39 | self.head = get_ban_head(cfg.BAN.TYPE, 40 | **cfg.BAN.KWARGS) 41 | 42 | def template(self, z): 43 | zf = self.backbone(z) 44 | if cfg.ADJUST.ADJUST: 45 | zf = self.neck(zf) 46 | if cfg.ALIGN.ALIGN: 47 | zf = [self.align(zf[i]) for i in range(len(zf))] 48 | 49 | self.zf = zf 50 | 51 | def track(self, x): 52 | xf = self.backbone(x) 53 | xf_b = xf 54 | if cfg.ADJUST.ADJUST: 55 | xf = self.neck(xf) 56 | if cfg.ALIGN.ALIGN: 57 | xf = [self.align(xf[i]) for i in range(len(xf))] 58 | 59 | cls, loc = self.head(self.zf, xf) 60 | return { 61 | 'cls': cls, 62 | 'loc': loc 63 | }, xf_b[1] 64 | 65 | def setToArray(self,setInput, dtype=' self.num: 61 | pop_num = len(self.history[k]) - self.num 62 | for _ in range(pop_num): 63 | self.sum[k] -= self.history[k][0] 64 | del self.history[k][0] 65 | self.count[k] -= 1 66 | 67 | def __repr__(self): 68 | s = '' 69 | for k in self.sum: 70 | s += self.format_str(k) 71 | return s 72 | 73 | def format_str(self, attr): 74 | return "{name}: {val:.6f} ({avg:.6f}) ".format( 75 | name=attr, 76 | val=float(self.val[attr]), 77 | avg=float(self.sum[attr]) / self.count[attr]) 78 | 79 | def __getattr__(self, attr): 80 | if attr in self.__dict__: 81 | return super(AverageMeter, self).__getattr__(attr) 82 | if attr not in self.sum: 83 | print("invalid key '{}'".format(attr)) 84 | return Meter(attr, 0, 0) 85 | return Meter(attr, self.val[attr], self.avg(attr)) 86 | 87 | def avg(self, attr): 88 | return float(self.sum[attr]) / self.count[attr] 89 | 90 | 91 | if __name__ == '__main__': 92 | avg1 = AverageMeter(10) 93 | avg2 = AverageMeter(0) 94 | avg3 = AverageMeter(-1) 95 | 96 | for i in range(20): 97 | avg1.update(s=i) 98 | avg2.update(s=i) 99 | avg3.update(s=i) 100 | 101 | print('iter {}'.format(i)) 102 | print(avg1.s) 103 | print(avg2.s) 104 | print(avg3.s) 105 | -------------------------------------------------------------------------------- /tracker/BAN/siamban/utils/bbox.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) SenseTime. All Rights Reserved. 
2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | from collections import namedtuple 9 | 10 | import numpy as np 11 | 12 | 13 | Corner = namedtuple('Corner', 'x1 y1 x2 y2') 14 | # alias 15 | BBox = Corner 16 | Center = namedtuple('Center', 'x y w h') 17 | 18 | 19 | def corner2center(corner): 20 | """ convert (x1, y1, x2, y2) to (cx, cy, w, h) 21 | Args: 22 | conrner: Corner or np.array (4*N) 23 | Return: 24 | Center or np.array (4 * N) 25 | """ 26 | if isinstance(corner, Corner): 27 | x1, y1, x2, y2 = corner 28 | return Center((x1 + x2) * 0.5, (y1 + y2) * 0.5, (x2 - x1), (y2 - y1)) 29 | else: 30 | x1, y1, x2, y2 = corner[0], corner[1], corner[2], corner[3] 31 | x = (x1 + x2) * 0.5 32 | y = (y1 + y2) * 0.5 33 | w = x2 - x1 34 | h = y2 - y1 35 | return x, y, w, h 36 | 37 | 38 | def center2corner(center): 39 | """ convert (cx, cy, w, h) to (x1, y1, x2, y2) 40 | Args: 41 | center: Center or np.array (4 * N) 42 | Return: 43 | center or np.array (4 * N) 44 | """ 45 | if isinstance(center, Center): 46 | x, y, w, h = center 47 | return Corner(x - w * 0.5, y - h * 0.5, x + w * 0.5, y + h * 0.5) 48 | else: 49 | x, y, w, h = center[0], center[1], center[2], center[3] 50 | x1 = x - w * 0.5 51 | y1 = y - h * 0.5 52 | x2 = x + w * 0.5 53 | y2 = y + h * 0.5 54 | return x1, y1, x2, y2 55 | 56 | 57 | def IoU(rect1, rect2): 58 | """ caculate interection over union 59 | Args: 60 | rect1: (x1, y1, x2, y2) 61 | rect2: (x1, y1, x2, y2) 62 | Returns: 63 | iou 64 | """ 65 | # overlap 66 | x1, y1, x2, y2 = rect1[0], rect1[1], rect1[2], rect1[3] 67 | tx1, ty1, tx2, ty2 = rect2[0], rect2[1], rect2[2], rect2[3] 68 | 69 | xx1 = np.maximum(tx1, x1) 70 | yy1 = np.maximum(ty1, y1) 71 | xx2 = np.minimum(tx2, x2) 72 | yy2 = np.minimum(ty2, y2) 73 | 74 | ww = np.maximum(0, xx2 - xx1) 75 | hh = np.maximum(0, yy2 - yy1) 76 | 77 | area = (x2 - x1) * (y2 - y1) 78 | target_a = (tx2 - tx1) * (ty2 - ty1) 79 | inter = ww * hh 80 | iou = inter / (area + target_a - inter) 81 | return iou 82 | 83 | 84 | def cxy_wh_2_rect(pos, sz): 85 | """ convert (cx, cy, w, h) to (x1, y1, w, h), 0-index 86 | """ 87 | return np.array([pos[0] - sz[0] / 2, pos[1] - sz[1] / 2, sz[0], sz[1]]) 88 | 89 | 90 | def rect_2_cxy_wh(rect): 91 | """ convert (x1, y1, w, h) to (cx, cy, w, h), 0-index 92 | """ 93 | return np.array([rect[0] + rect[2] / 2, rect[1] + rect[3] / 2]), \ 94 | np.array([rect[2], rect[3]]) 95 | 96 | 97 | def cxy_wh_2_rect1(pos, sz): 98 | """ convert (cx, cy, w, h) to (x1, y1, w, h), 1-index 99 | """ 100 | return np.array([pos[0] - sz[0] / 2 + 1, pos[1] - sz[1] / 2 + 1, sz[0], sz[1]]) 101 | 102 | 103 | def rect1_2_cxy_wh(rect): 104 | """ convert (x1, y1, w, h) to (cx, cy, w, h), 1-index 105 | """ 106 | return np.array([rect[0] + rect[2] / 2 - 1, rect[1] + rect[3] / 2 - 1]), \ 107 | np.array([rect[2], rect[3]]) 108 | 109 | 110 | def get_axis_aligned_bbox(region): 111 | """ convert region to (cx, cy, w, h) that represent by axis aligned box 112 | """ 113 | nv = region.size 114 | if nv == 8: 115 | cx = np.mean(region[0::2]) 116 | cy = np.mean(region[1::2]) 117 | x1 = min(region[0::2]) 118 | x2 = max(region[0::2]) 119 | y1 = min(region[1::2]) 120 | y2 = max(region[1::2]) 121 | A1 = np.linalg.norm(region[0:2] - region[2:4]) * \ 122 | np.linalg.norm(region[2:4] - region[4:6]) 123 | A2 = (x2 - x1) * (y2 - y1) 124 | s = np.sqrt(A1 / A2) 125 | w = s * (x2 - x1) + 1 126 | h = s * (y2 - y1) + 1 127 | else: 128 | x = region[0] 
129 | y = region[1] 130 | w = region[2] 131 | h = region[3] 132 | cx = x + w / 2 133 | cy = y + h / 2 134 | return cx, cy, w, h 135 | 136 | 137 | 138 | def get_min_max_bbox(region): 139 | """ convert region to (cx, cy, w, h) that represent by mim-max box 140 | """ 141 | nv = region.size 142 | if nv == 8: 143 | cx = np.mean(region[0::2]) 144 | cy = np.mean(region[1::2]) 145 | x1 = min(region[0::2]) 146 | x2 = max(region[0::2]) 147 | y1 = min(region[1::2]) 148 | y2 = max(region[1::2]) 149 | w = x2 - x1 150 | h = y2 - y1 151 | else: 152 | x = region[0] 153 | y = region[1] 154 | w = region[2] 155 | h = region[3] 156 | cx = x + w / 2 157 | cy = y + h / 2 158 | return cx, cy, w, h 159 | -------------------------------------------------------------------------------- /tracker/BAN/siamban/utils/distributed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) SenseTime. All Rights Reserved. 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | import os 9 | import socket 10 | import logging 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.distributed as dist 15 | 16 | from siamban.utils.log_helper import log_once 17 | 18 | logger = logging.getLogger('global') 19 | 20 | 21 | def average_reduce(v): 22 | if get_world_size() == 1: 23 | return v 24 | tensor = torch.cuda.FloatTensor(1) 25 | tensor[0] = v 26 | dist.all_reduce(tensor) 27 | v = tensor[0] / get_world_size() 28 | return v 29 | 30 | 31 | class DistModule(nn.Module): 32 | def __init__(self, module, bn_method=0): 33 | super(DistModule, self).__init__() 34 | self.module = module 35 | self.bn_method = bn_method 36 | if get_world_size() > 1: 37 | broadcast_params(self.module) 38 | else: 39 | self.bn_method = 0 # single proccess 40 | 41 | def forward(self, *args, **kwargs): 42 | broadcast_buffers(self.module, self.bn_method) 43 | return self.module(*args, **kwargs) 44 | 45 | def train(self, mode=True): 46 | super(DistModule, self).train(mode) 47 | self.module.train(mode) 48 | return self 49 | 50 | 51 | def broadcast_params(model): 52 | """ broadcast model parameters """ 53 | for p in model.state_dict().values(): 54 | dist.broadcast(p, 0) 55 | 56 | 57 | def broadcast_buffers(model, method=0): 58 | """ broadcast model buffers """ 59 | if method == 0: 60 | return 61 | 62 | world_size = get_world_size() 63 | 64 | for b in model._all_buffers(): 65 | if method == 1: # broadcast from main proccess 66 | dist.broadcast(b, 0) 67 | elif method == 2: # average 68 | dist.all_reduce(b) 69 | b /= world_size 70 | else: 71 | raise Exception('Invalid buffer broadcast code {}'.format(method)) 72 | 73 | 74 | inited = False 75 | 76 | 77 | def _dist_init(): 78 | ''' 79 | if guess right: 80 | ntasks: world_size (process num) 81 | proc_id: rank 82 | ''' 83 | # rank = int(os.environ['RANK']) 84 | rank = 0 85 | num_gpus = torch.cuda.device_count() 86 | torch.cuda.set_device(rank % num_gpus) 87 | dist.init_process_group(backend='nccl') 88 | world_size = dist.get_world_size() 89 | return rank, world_size 90 | 91 | 92 | def _get_local_ip(): 93 | try: 94 | s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) 95 | s.connect(('8.8.8.8', 80)) 96 | ip = s.getsockname()[0] 97 | finally: 98 | s.close() 99 | return ip 100 | 101 | 102 | def dist_init(): 103 | global rank, world_size, inited 104 | # try: 105 | # rank, world_size = _dist_init() 106 | # except RuntimeError as e: 107 | # if 'public' 
in e.args[0]: 108 | # logger.info(e) 109 | # logger.info('Warning: use single process') 110 | # rank, world_size = 0, 1 111 | # else: 112 | # raise RuntimeError(*e.args) 113 | rank, world_size = 0, 1 114 | inited = True 115 | return rank, world_size 116 | 117 | 118 | def get_rank(): 119 | if not inited: 120 | raise(Exception('dist not inited')) 121 | return rank 122 | 123 | 124 | def get_world_size(): 125 | if not inited: 126 | raise(Exception('dist not inited')) 127 | return world_size 128 | 129 | 130 | def reduce_gradients(model, _type='sum'): 131 | types = ['sum', 'avg'] 132 | assert _type in types, 'gradients method must be in "{}"'.format(types) 133 | log_once("gradients method is {}".format(_type)) 134 | if get_world_size() > 1: 135 | for param in model.parameters(): 136 | if param.requires_grad: 137 | dist.all_reduce(param.grad.data) 138 | if _type == 'avg': 139 | param.grad.data /= get_world_size() 140 | else: 141 | return None 142 | -------------------------------------------------------------------------------- /tracker/BAN/siamban/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) SenseTime. All Rights Reserved. 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | import os 9 | 10 | from colorama import Fore, Style 11 | 12 | 13 | __all__ = ['commit', 'describe'] 14 | 15 | 16 | def _exec(cmd): 17 | f = os.popen(cmd, 'r', 1) 18 | return f.read().strip() 19 | 20 | 21 | def _bold(s): 22 | return "\033[1m%s\033[0m" % s 23 | 24 | 25 | def _color(s): 26 | return f'{Fore.RED}{s}{Style.RESET_ALL}' 27 | 28 | 29 | def _describe(model, lines=None, spaces=0): 30 | head = " " * spaces 31 | for name, p in model.named_parameters(): 32 | if '.' in name: 33 | continue 34 | if p.requires_grad: 35 | name = _color(name) 36 | line = "{head}- {name}".format(head=head, name=name) 37 | lines.append(line) 38 | 39 | for name, m in model.named_children(): 40 | space_num = len(name) + spaces + 1 41 | if m.training: 42 | name = _color(name) 43 | line = "{head}.{name} ({type})".format( 44 | head=head, 45 | name=name, 46 | type=m.__class__.__name__) 47 | lines.append(line) 48 | _describe(m, lines, space_num) 49 | 50 | 51 | def commit(): 52 | root = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')) 53 | cmd = "cd {}; git log | head -n1 | awk '{{print $2}}'".format(root) 54 | commit = _exec(cmd) 55 | cmd = "cd {}; git log --oneline | head -n1".format(root) 56 | commit_log = _exec(cmd) 57 | return "commit : {}\n log : {}".format(commit, commit_log) 58 | 59 | 60 | def describe(net, name=None): 61 | num = 0 62 | lines = [] 63 | if name is not None: 64 | lines.append(name) 65 | num = len(name) 66 | _describe(net, lines, num) 67 | return "\n".join(lines) 68 | -------------------------------------------------------------------------------- /tracker/BAN/siamban/utils/model_load.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) SenseTime. All Rights Reserved. 
2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | import logging 9 | 10 | import torch 11 | 12 | 13 | logger = logging.getLogger('global') 14 | 15 | 16 | def check_keys(model, pretrained_state_dict): 17 | ckpt_keys = set(pretrained_state_dict.keys()) 18 | model_keys = set(model.state_dict().keys()) 19 | used_pretrained_keys = model_keys & ckpt_keys 20 | unused_pretrained_keys = ckpt_keys - model_keys 21 | missing_keys = model_keys - ckpt_keys 22 | # filter 'num_batches_tracked' 23 | missing_keys = [x for x in missing_keys 24 | if not x.endswith('num_batches_tracked')] 25 | if len(missing_keys) > 0: 26 | logger.info('[Warning] missing keys: {}'.format(missing_keys)) 27 | logger.info('missing keys:{}'.format(len(missing_keys))) 28 | if len(unused_pretrained_keys) > 0: 29 | logger.info('[Warning] unused_pretrained_keys: {}'.format( 30 | unused_pretrained_keys)) 31 | logger.info('unused checkpoint keys:{}'.format( 32 | len(unused_pretrained_keys))) 33 | logger.info('used keys:{}'.format(len(used_pretrained_keys))) 34 | assert len(used_pretrained_keys) > 0, \ 35 | 'load NONE from pretrained checkpoint' 36 | return True 37 | 38 | 39 | def remove_prefix(state_dict, prefix): 40 | ''' Old style model is stored with all names of parameters 41 | share common prefix 'module.' ''' 42 | logger.info('remove prefix \'{}\''.format(prefix)) 43 | f = lambda x: x.split(prefix, 1)[-1] if x.startswith(prefix) else x 44 | return {f(key): value for key, value in state_dict.items()} 45 | 46 | 47 | def load_pretrain(model, pretrained_path): 48 | logger.info('load pretrained model from {}'.format(pretrained_path)) 49 | device = torch.cuda.current_device() 50 | pretrained_dict = torch.load(pretrained_path, 51 | map_location=lambda storage, loc: storage.cuda(device)) 52 | if "state_dict" in pretrained_dict.keys(): 53 | pretrained_dict = remove_prefix(pretrained_dict['state_dict'], 54 | 'module.') 55 | else: 56 | pretrained_dict = remove_prefix(pretrained_dict, 'module.') 57 | 58 | try: 59 | check_keys(model, pretrained_dict) 60 | except: 61 | logger.info('[Warning]: using pretrain as features.\ 62 | Adding "features." as prefix') 63 | new_dict = {} 64 | for k, v in pretrained_dict.items(): 65 | k = 'features.' + k 66 | new_dict[k] = v 67 | pretrained_dict = new_dict 68 | check_keys(model, pretrained_dict) 69 | model.load_state_dict(pretrained_dict, strict=False) 70 | return model 71 | 72 | 73 | def restore_from(model, optimizer, ckpt_path): 74 | device = torch.cuda.current_device() 75 | ckpt = torch.load(ckpt_path, 76 | map_location=lambda storage, loc: storage.cuda(device)) 77 | epoch = ckpt['epoch'] 78 | 79 | ckpt_model_dict = remove_prefix(ckpt['state_dict'], 'module.') 80 | check_keys(model, ckpt_model_dict) 81 | model.load_state_dict(ckpt_model_dict, strict=False) 82 | 83 | check_keys(optimizer, ckpt['optimizer']) 84 | optimizer.load_state_dict(ckpt['optimizer']) 85 | return model, optimizer, epoch 86 | -------------------------------------------------------------------------------- /tracker/BAN/siamban/utils/point.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import numpy as np 7 | 8 | 9 | class Point: 10 | """ 11 | This class generate points. 
12 | """ 13 | def __init__(self, stride, size, image_center): 14 | self.stride = stride 15 | self.size = size 16 | self.image_center = image_center 17 | 18 | self.points = self.generate_points(self.stride, self.size, self.image_center) 19 | 20 | def generate_points(self, stride, size, im_c): 21 | ori = im_c - size // 2 * stride 22 | x, y = np.meshgrid([ori + stride * dx for dx in np.arange(0, size)], 23 | [ori + stride * dy for dy in np.arange(0, size)]) 24 | points = np.zeros((2, size, size), dtype=np.float32) 25 | points[0, :, :], points[1, :, :] = x.astype(np.float32), y.astype(np.float32) 26 | 27 | return points 28 | -------------------------------------------------------------------------------- /tracker/BAN/snapshot/README.md: -------------------------------------------------------------------------------- 1 | ### Put snapshot here 2 | 3 | | Training data | Model | Source 1 | Source 2 | Source 3 | 4 | | ---- | ---- | ---- | ---- | ---- | 5 | | SAM-NAT-B (base) | `sam-da-track-b` | [Baidu](https://pan.baidu.com/s/1c_hlOxnyv-4bGyHzymlpRA?pwd=6prk) | [Google](https://drive.google.com/file/d/1yiUTYQty52cAacmtGuqdgb53CnIe2l1W/view?usp=sharing) | [Hugging face](https://huggingface.co/George-Zhuang/SAM-DA/resolve/main/sam-da-track-b.pth) | 6 | | SAM-NAT-S (small) | `sam-da-track-s` | [Baidu](https://pan.baidu.com/s/1kUCZMXgRZs1HgD6gtx9hrQ?pwd=a48s) | [Google](https://drive.google.com/file/d/1fxShaJ67XB1nMnE9ioQg7_LXYQBd6snI/view?usp=sharing) | [Hugging face](https://huggingface.co/George-Zhuang/SAM-DA/resolve/main/sam-da-track-s.pth) | 7 | | SAM-NAT-T (tiny) | `sam-da-track-t` | [Baidu](https://pan.baidu.com/s/11LrJwoz--AO3UzXavwa_GA?pwd=5qkj) | [Google](https://drive.google.com/file/d/10Y9td4CJt4DqbcvCCLVUkCEx67MyilYC/view?usp=sharing) | [Hugging face](https://huggingface.co/George-Zhuang/SAM-DA/resolve/main/sam-da-track-t.pth) | 8 | | SAM-NAT-N (nano) | `sam-da-track-n` | [Baidu](https://pan.baidu.com/s/1h1OROv17qINJmGU7zR4LTA?pwd=ujag) | [Google](https://drive.google.com/file/d/1xR5i2XqHoDRoBEXH7O4ko5JZok0EPHTF/view?usp=sharing) | [Hugging face](https://huggingface.co/George-Zhuang/SAM-DA/resolve/main/sam-da-track-n.pth) | -------------------------------------------------------------------------------- /tracker/BAN/toolkit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/tracker/BAN/toolkit/__init__.py -------------------------------------------------------------------------------- /tracker/BAN/toolkit/datasets/DarkTrack2021.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import numpy as np 4 | 5 | from PIL import Image 6 | from tqdm import tqdm 7 | from glob import glob 8 | 9 | from .dataset import Dataset 10 | from .video import Video 11 | 12 | 13 | def ca(dataset_root): 14 | 15 | path=dataset_root 16 | 17 | name_list=os.listdir(path+'/data_seq') 18 | name_list.sort() 19 | 20 | b=[] 21 | for i in range(len(name_list)): 22 | b.append(name_list[i]) 23 | c=[] 24 | 25 | for jj in range(len(name_list)): 26 | imgs=path+'/data_seq/'+str(name_list[jj]) 27 | txt=path+'/anno/'+str(name_list[jj])+'.txt' 28 | bbox=[] 29 | f = open(txt) # 返回一个文件对象 30 | file= f.readlines() 31 | li=os.listdir(imgs) 32 | li.sort() 33 | for ii in range(len(file)): 34 | try: 35 | li[ii]=name_list[jj]+'/'+li[ii] 36 | except: 37 | a=1 38 | 39 | line = file[ii].strip('\n').split(',') 40 | 41 | try: 42 | 
line[0]=int(line[0]) 43 | except: 44 | line[0]=float(line[0]) 45 | try: 46 | line[1]=int(line[1]) 47 | except: 48 | line[1]=float(line[1]) 49 | try: 50 | line[2]=int(line[2]) 51 | except: 52 | line[2]=float(line[2]) 53 | try: 54 | line[3]=int(line[3]) 55 | except: 56 | line[3]=float(line[3]) 57 | bbox.append(line) 58 | 59 | if len(bbox)!=len(li): 60 | print (jj) 61 | f.close() 62 | c.append({'attr':[],'gt_rect':bbox,'img_names':li,'init_rect':bbox[0],'video_dir':name_list[jj]}) 63 | 64 | d=dict(zip(b,c)) 65 | 66 | return d 67 | 68 | class UAVVideo(Video): 69 | """ 70 | Args: 71 | name: video name 72 | root: dataset root 73 | video_dir: video directory 74 | init_rect: init rectangle 75 | img_names: image names 76 | gt_rect: groundtruth rectangle 77 | attr: attribute of video 78 | """ 79 | def __init__(self, name, root, video_dir, init_rect, img_names, 80 | gt_rect, attr, load_img=False): 81 | super(UAVVideo, self).__init__(name, root, video_dir, 82 | init_rect, img_names, gt_rect, attr, load_img) 83 | 84 | 85 | class DarkTrack2021Dataset(Dataset): 86 | """ 87 | Args: 88 | name: dataset name 89 | dataset_root: dataset root 90 | load_img: wether to load all imgs 91 | """ 92 | def __init__(self, name, dataset_root, load_img=False): 93 | super(DarkTrack2021Dataset, self).__init__(name, dataset_root) 94 | dataset_root = dataset_root + '/DarkTrack2021' 95 | meta_data = ca(dataset_root) 96 | dataset_root = dataset_root + '/data_seq' 97 | 98 | # load videos 99 | pbar = tqdm(meta_data.keys(), desc='loading '+name, ncols=100) 100 | self.videos = {} 101 | for video in pbar: 102 | pbar.set_postfix_str(video) 103 | self.videos[video] = UAVVideo(video, 104 | dataset_root, 105 | meta_data[video]['video_dir'], 106 | meta_data[video]['init_rect'], 107 | meta_data[video]['img_names'], 108 | meta_data[video]['gt_rect'], 109 | meta_data[video]['attr']) 110 | 111 | # set attr 112 | attr = [] 113 | for x in self.videos.values(): 114 | attr += x.attr 115 | attr = set(attr) 116 | self.attr = {} 117 | self.attr['ALL'] = list(self.videos.keys()) 118 | for x in attr: 119 | self.attr[x] = [] 120 | for k, v in self.videos.items(): 121 | for attr_ in v.attr: 122 | self.attr[attr_].append(k) 123 | 124 | -------------------------------------------------------------------------------- /tracker/BAN/toolkit/datasets/UAVDark135.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import numpy as np 4 | 5 | from PIL import Image 6 | from tqdm import tqdm 7 | from glob import glob 8 | 9 | from .dataset import Dataset 10 | from .video import Video 11 | 12 | 13 | def ca(dataset_root): 14 | 15 | path=dataset_root 16 | 17 | name_list=os.listdir(path+'/data_seq') 18 | name_list.sort() 19 | 20 | b=[] 21 | for i in range(len(name_list)): 22 | b.append(name_list[i]) 23 | c=[] 24 | 25 | for jj in range(len(name_list)): 26 | imgs=path+'/data_seq/'+str(name_list[jj]) 27 | txt=path+'/anno/'+str(name_list[jj])+'.txt' 28 | bbox=[] 29 | f = open(txt) # 返回一个文件对象 30 | file= f.readlines() 31 | li=os.listdir(imgs) 32 | li.sort() 33 | for ii in range(len(file)): 34 | try: 35 | li[ii]=name_list[jj]+'/'+li[ii] 36 | except: 37 | a=1 38 | 39 | line = file[ii].strip('\n').split(',') 40 | 41 | try: 42 | line[0]=int(line[0]) 43 | except: 44 | line[0]=float(line[0]) 45 | try: 46 | line[1]=int(line[1]) 47 | except: 48 | line[1]=float(line[1]) 49 | try: 50 | line[2]=int(line[2]) 51 | except: 52 | line[2]=float(line[2]) 53 | try: 54 | line[3]=int(line[3]) 55 | except: 56 | 
line[3]=float(line[3]) 57 | bbox.append(line) 58 | 59 | if len(bbox)!=len(li): 60 | print (jj) 61 | f.close() 62 | c.append({'attr':[],'gt_rect':bbox,'img_names':li,'init_rect':bbox[0],'video_dir':name_list[jj]}) 63 | 64 | d=dict(zip(b,c)) 65 | 66 | return d 67 | 68 | class UAVVideo(Video): 69 | """ 70 | Args: 71 | name: video name 72 | root: dataset root 73 | video_dir: video directory 74 | init_rect: init rectangle 75 | img_names: image names 76 | gt_rect: groundtruth rectangle 77 | attr: attribute of video 78 | """ 79 | def __init__(self, name, root, video_dir, init_rect, img_names, 80 | gt_rect, attr, load_img=False): 81 | super(UAVVideo, self).__init__(name, root, video_dir, 82 | init_rect, img_names, gt_rect, attr, load_img) 83 | 84 | 85 | class UAVDark135Dataset(Dataset): 86 | """ 87 | Args: 88 | name: dataset name 89 | dataset_root: dataset root 90 | load_img: wether to load all imgs 91 | """ 92 | def __init__(self, name, dataset_root, load_img=False): 93 | super(UAVDark135Dataset, self).__init__(name, dataset_root) 94 | dataset_root = dataset_root + '/UAVDark135' 95 | meta_data = ca(dataset_root) 96 | dataset_root = dataset_root + '/data_seq' 97 | 98 | # load videos 99 | pbar = tqdm(meta_data.keys(), desc='loading '+name, ncols=100) 100 | self.videos = {} 101 | for video in pbar: 102 | pbar.set_postfix_str(video) 103 | self.videos[video] = UAVVideo(video, 104 | dataset_root, 105 | meta_data[video]['video_dir'], 106 | meta_data[video]['init_rect'], 107 | meta_data[video]['img_names'], 108 | meta_data[video]['gt_rect'], 109 | meta_data[video]['attr']) 110 | 111 | # set attr 112 | attr = [] 113 | for x in self.videos.values(): 114 | attr += x.attr 115 | attr = set(attr) 116 | self.attr = {} 117 | self.attr['ALL'] = list(self.videos.keys()) 118 | for x in attr: 119 | self.attr[x] = [] 120 | for k, v in self.videos.items(): 121 | for attr_ in v.attr: 122 | self.attr[attr_].append(k) 123 | 124 | -------------------------------------------------------------------------------- /tracker/BAN/toolkit/datasets/UAVDark70.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import numpy as np 4 | 5 | from PIL import Image 6 | from tqdm import tqdm 7 | from glob import glob 8 | 9 | from .dataset import Dataset 10 | from .video import Video 11 | 12 | 13 | def ca(dataset_root): 14 | 15 | path=dataset_root 16 | 17 | name_list=os.listdir(path+'/data_seq') 18 | name_list.sort() 19 | 20 | b=[] 21 | for i in range(len(name_list)): 22 | b.append(name_list[i]) 23 | c=[] 24 | 25 | for jj in range(len(name_list)): 26 | imgs=path+'/data_seq/'+str(name_list[jj]) 27 | txt=path+'/anno/'+str(name_list[jj])+'.txt' 28 | bbox=[] 29 | f = open(txt) # 返回一个文件对象 30 | file= f.readlines() 31 | li=os.listdir(imgs) 32 | li.sort() 33 | for ii in range(len(file)): 34 | try: 35 | li[ii]=name_list[jj]+'/'+li[ii] 36 | except: 37 | a=1 38 | 39 | line = file[ii].strip('\n').split(',') 40 | 41 | try: 42 | line[0]=int(line[0]) 43 | except: 44 | line[0]=float(line[0]) 45 | try: 46 | line[1]=int(line[1]) 47 | except: 48 | line[1]=float(line[1]) 49 | try: 50 | line[2]=int(line[2]) 51 | except: 52 | line[2]=float(line[2]) 53 | try: 54 | line[3]=int(line[3]) 55 | except: 56 | line[3]=float(line[3]) 57 | bbox.append(line) 58 | 59 | if len(bbox)!=len(li): 60 | print (jj) 61 | f.close() 62 | c.append({'attr':[],'gt_rect':bbox,'img_names':li,'init_rect':bbox[0],'video_dir':name_list[jj]}) 63 | 64 | d=dict(zip(b,c)) 65 | 66 | return d 67 | 68 | class UAVVideo(Video): 69 | """ 
70 | Args: 71 | name: video name 72 | root: dataset root 73 | video_dir: video directory 74 | init_rect: init rectangle 75 | img_names: image names 76 | gt_rect: groundtruth rectangle 77 | attr: attribute of video 78 | """ 79 | def __init__(self, name, root, video_dir, init_rect, img_names, 80 | gt_rect, attr, load_img=False): 81 | super(UAVVideo, self).__init__(name, root, video_dir, 82 | init_rect, img_names, gt_rect, attr, load_img) 83 | 84 | 85 | class UAVDark70Dataset(Dataset): 86 | """ 87 | Args: 88 | name: dataset name 89 | dataset_root: dataset root 90 | load_img: whether to load all images 91 | """ 92 | def __init__(self, name, dataset_root, load_img=False): 93 | super(UAVDark70Dataset, self).__init__(name, dataset_root) 94 | dataset_root = dataset_root + '/UAVDark70' 95 | meta_data = ca(dataset_root) 96 | dataset_root = dataset_root + '/data_seq' 97 | 98 | # load videos 99 | pbar = tqdm(meta_data.keys(), desc='loading '+name, ncols=100) 100 | self.videos = {} 101 | for video in pbar: 102 | pbar.set_postfix_str(video) 103 | self.videos[video] = UAVVideo(video, 104 | dataset_root, 105 | meta_data[video]['video_dir'], 106 | meta_data[video]['init_rect'], 107 | meta_data[video]['img_names'], 108 | meta_data[video]['gt_rect'], 109 | meta_data[video]['attr']) 110 | 111 | # set attr 112 | attr = [] 113 | for x in self.videos.values(): 114 | attr += x.attr 115 | attr = set(attr) 116 | self.attr = {} 117 | self.attr['ALL'] = list(self.videos.keys()) 118 | for x in attr: 119 | self.attr[x] = [] 120 | for k, v in self.videos.items(): 121 | for attr_ in v.attr: 122 | self.attr[attr_].append(k) 123 | 124 | -------------------------------------------------------------------------------- /tracker/BAN/toolkit/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .uav import UAVDataset 2 | from .UAVDark70 import UAVDark70Dataset 3 | from .UAVDark135 import UAVDark135Dataset 4 | from .DarkTrack2021 import DarkTrack2021Dataset 5 | from .nat import NATDataset 6 | from .nat_l import NAT_LDataset 7 | from .nut import NUTDataset 8 | from .nut_l import NUT_LDataset 9 | from .nut_l_t import NUT_L_tDataset 10 | from .nut_l_s import NUT_L_sDataset 11 | class DatasetFactory(object): 12 | @staticmethod 13 | def create_dataset(**kwargs): 14 | """ 15 | Args: 16 | name: dataset name, e.g. 'UAVDark70', 'UAVDark135', 'DarkTrack', 'UAV', 'NAT', 'NAT_L', 'NUT', 'NUT_L', 'NUT_L_target', 'NUT_L_source' 17 | dataset_root: dataset root 18 | load_img: whether to load images 19 | Return: 20 | dataset 21 | """ 22 | assert 'name' in kwargs, "should provide dataset name" 23 | name = kwargs['name'] 24 | if 'UAVDark70' == name: 25 | dataset = UAVDark70Dataset(**kwargs) 26 | elif 'UAVDark135' == name: 27 | dataset = UAVDark135Dataset(**kwargs) 28 | elif 'DarkTrack' in name: 29 | dataset = DarkTrack2021Dataset(**kwargs) 30 | elif 'UAV' in name: 31 | dataset = UAVDataset(**kwargs) 32 | elif 'NAT' == name: 33 | dataset = NATDataset(**kwargs) 34 | elif 'NAT_L' == name: 35 | dataset = NAT_LDataset(**kwargs) 36 | elif 'NUT' == name: 37 | dataset = NUTDataset(**kwargs) 38 | elif 'NUT_L' == name: 39 | dataset = NUT_LDataset(**kwargs) 40 | elif 'NUT_L_target' == name: 41 | dataset = NUT_L_tDataset(**kwargs) 42 | elif 'NUT_L_source' == name: 43 | dataset = NUT_L_sDataset(**kwargs) 44 | else: 45 | raise Exception("unknown dataset {}".format(kwargs['name'])) 46 | return dataset 47 | 48 |
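For context, the DatasetFactory above is how a benchmark is selected by name at evaluation time. A minimal, illustrative call is sketched below; the './test_dataset' root is an assumption (nothing in the repo creates it), and tracker/BAN is assumed to be on PYTHONPATH so that the toolkit package imports.

from toolkit.datasets import DatasetFactory

# 'NAT' dispatches to NATDataset above; the assumed root must contain
# NAT2021/data_seq and NAT2021/anno in the layout the ca() parsers expect.
dataset = DatasetFactory.create_dataset(name='NAT',
                                        dataset_root='./test_dataset',
                                        load_img=False)
print(len(dataset), 'videos, e.g.', dataset.attr['ALL'][:3])
for video in dataset:   # Dataset.__iter__ yields the per-sequence Video objects
    pass                # a tracker would be run over each sequence here

-------------------------------------------------------------------------------- /tracker/BAN/toolkit/datasets/__pycache__/DarkTrack2021.cpython-38.pyc: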
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/tracker/BAN/toolkit/datasets/__pycache__/DarkTrack2021.cpython-38.pyc -------------------------------------------------------------------------------- /tracker/BAN/toolkit/datasets/__pycache__/UAVDark135.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/tracker/BAN/toolkit/datasets/__pycache__/UAVDark135.cpython-38.pyc -------------------------------------------------------------------------------- /tracker/BAN/toolkit/datasets/__pycache__/UAVDark70.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/tracker/BAN/toolkit/datasets/__pycache__/UAVDark70.cpython-38.pyc -------------------------------------------------------------------------------- /tracker/BAN/toolkit/datasets/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/tracker/BAN/toolkit/datasets/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /tracker/BAN/toolkit/datasets/__pycache__/dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/tracker/BAN/toolkit/datasets/__pycache__/dataset.cpython-38.pyc -------------------------------------------------------------------------------- /tracker/BAN/toolkit/datasets/__pycache__/nat.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/tracker/BAN/toolkit/datasets/__pycache__/nat.cpython-38.pyc -------------------------------------------------------------------------------- /tracker/BAN/toolkit/datasets/__pycache__/nat_l.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/tracker/BAN/toolkit/datasets/__pycache__/nat_l.cpython-38.pyc -------------------------------------------------------------------------------- /tracker/BAN/toolkit/datasets/__pycache__/nut.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/tracker/BAN/toolkit/datasets/__pycache__/nut.cpython-38.pyc -------------------------------------------------------------------------------- /tracker/BAN/toolkit/datasets/__pycache__/nut_l.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/tracker/BAN/toolkit/datasets/__pycache__/nut_l.cpython-38.pyc -------------------------------------------------------------------------------- /tracker/BAN/toolkit/datasets/__pycache__/nut_l_s.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/tracker/BAN/toolkit/datasets/__pycache__/nut_l_s.cpython-38.pyc -------------------------------------------------------------------------------- /tracker/BAN/toolkit/datasets/__pycache__/nut_l_t.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/tracker/BAN/toolkit/datasets/__pycache__/nut_l_t.cpython-38.pyc -------------------------------------------------------------------------------- /tracker/BAN/toolkit/datasets/__pycache__/uav.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/tracker/BAN/toolkit/datasets/__pycache__/uav.cpython-38.pyc -------------------------------------------------------------------------------- /tracker/BAN/toolkit/datasets/__pycache__/video.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/tracker/BAN/toolkit/datasets/__pycache__/video.cpython-38.pyc -------------------------------------------------------------------------------- /tracker/BAN/toolkit/datasets/dataset.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | 3 | class Dataset(object): 4 | def __init__(self, name, dataset_root): 5 | self.name = name 6 | self.dataset_root = dataset_root 7 | self.videos = None 8 | 9 | def __getitem__(self, idx): 10 | if isinstance(idx, str): 11 | return self.videos[idx] 12 | elif isinstance(idx, int): 13 | return self.videos[sorted(list(self.videos.keys()))[idx]] 14 | 15 | def __len__(self): 16 | return len(self.videos) 17 | 18 | def __iter__(self): 19 | keys = sorted(list(self.videos.keys())) 20 | for key in keys: 21 | yield self.videos[key] 22 | 23 | def set_tracker(self, path, tracker_names): 24 | """ 25 | Args: 26 | path: path to tracker results, 27 | tracker_names: list of tracker name 28 | """ 29 | self.tracker_path = path 30 | self.tracker_names = tracker_names 31 | # for video in tqdm(self.videos.values(), 32 | # desc='loading tacker result', ncols=100): 33 | # video.load_tracker(path, tracker_names) 34 | -------------------------------------------------------------------------------- /tracker/BAN/toolkit/datasets/nat.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import numpy as np 4 | 5 | from PIL import Image 6 | from tqdm import tqdm 7 | from glob import glob 8 | 9 | from .dataset import Dataset 10 | from .video import Video 11 | 12 | 13 | def ca(dataset_root): 14 | 15 | path=dataset_root 16 | 17 | name_list=os.listdir(path+'/data_seq') 18 | name_list.sort() 19 | 20 | b=[] 21 | for i in range(len(name_list)): 22 | b.append(name_list[i]) 23 | c=[] 24 | 25 | for jj in range(len(name_list)): 26 | imgs=path+'/data_seq/'+str(name_list[jj]) 27 | txt=path+'/anno/'+str(name_list[jj])+'.txt' 28 | bbox=[] 29 | f = open(txt) # 返回一个文件对象 30 | file= f.readlines() 31 | li=os.listdir(imgs) 32 | li.sort() 33 | for ii in range(len(file)): 34 | try: 35 | li[ii]=name_list[jj]+'/'+li[ii] 36 | except: 37 | a=1 38 | 39 | line = 
file[ii].strip('\n').split(',') 40 | 41 | try: 42 | line[0]=int(line[0]) 43 | except: 44 | line[0]=float(line[0]) 45 | try: 46 | line[1]=int(line[1]) 47 | except: 48 | line[1]=float(line[1]) 49 | try: 50 | line[2]=int(line[2]) 51 | except: 52 | line[2]=float(line[2]) 53 | try: 54 | line[3]=int(line[3]) 55 | except: 56 | line[3]=float(line[3]) 57 | bbox.append(line) 58 | 59 | if len(bbox)!=len(li): 60 | print (jj) 61 | f.close() 62 | c.append({'attr':[],'gt_rect':bbox,'img_names':li,'init_rect':bbox[0],'video_dir':name_list[jj]}) 63 | 64 | d=dict(zip(b,c)) 65 | 66 | return d 67 | 68 | class UAVVideo(Video): 69 | """ 70 | Args: 71 | name: video name 72 | root: dataset root 73 | video_dir: video directory 74 | init_rect: init rectangle 75 | img_names: image names 76 | gt_rect: groundtruth rectangle 77 | attr: attribute of video 78 | """ 79 | def __init__(self, name, root, video_dir, init_rect, img_names, 80 | gt_rect, attr, load_img=False): 81 | super(UAVVideo, self).__init__(name, root, video_dir, 82 | init_rect, img_names, gt_rect, attr, load_img) 83 | 84 | 85 | class NATDataset(Dataset): 86 | """ 87 | Args: 88 | name: dataset name 89 | dataset_root: dataset root 90 | load_img: wether to load all imgs 91 | """ 92 | def __init__(self, name, dataset_root, load_img=False): 93 | super(NATDataset, self).__init__(name, dataset_root) 94 | dataset_root = dataset_root + '/NAT2021' 95 | meta_data = ca(dataset_root) 96 | dataset_root = dataset_root + '/data_seq' 97 | # load videos 98 | pbar = tqdm(meta_data.keys(), desc='loading '+name, ncols=100) 99 | self.videos = {} 100 | for video in pbar: 101 | pbar.set_postfix_str(video) 102 | self.videos[video] = UAVVideo(video, 103 | dataset_root, 104 | meta_data[video]['video_dir'], 105 | meta_data[video]['init_rect'], 106 | meta_data[video]['img_names'], 107 | meta_data[video]['gt_rect'], 108 | meta_data[video]['attr']) 109 | 110 | # set attr 111 | attr = [] 112 | for x in self.videos.values(): 113 | attr += x.attr 114 | attr = set(attr) 115 | self.attr = {} 116 | self.attr['ALL'] = list(self.videos.keys()) 117 | for x in attr: 118 | self.attr[x] = [] 119 | for k, v in self.videos.items(): 120 | for attr_ in v.attr: 121 | self.attr[attr_].append(k) 122 | 123 | -------------------------------------------------------------------------------- /tracker/BAN/toolkit/datasets/nat_l.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import numpy as np 4 | 5 | from PIL import Image 6 | from tqdm import tqdm 7 | from glob import glob 8 | 9 | from .dataset import Dataset 10 | from .video import Video 11 | 12 | 13 | def ca(dataset_root): 14 | 15 | path=dataset_root 16 | 17 | name_list=os.listdir(path+'/data_seq') 18 | name_list.sort() 19 | 20 | b=[] 21 | for i in range(len(name_list)): 22 | b.append(name_list[i]) 23 | c=[] 24 | 25 | for jj in range(len(name_list)): 26 | imgs=path+'/data_seq/'+str(name_list[jj]) 27 | txt=path+'/anno/'+str(name_list[jj])+'.txt' 28 | bbox=[] 29 | f = open(txt) # 返回一个文件对象 30 | file= f.readlines() 31 | li=os.listdir(imgs) 32 | li.sort() 33 | for ii in range(len(file)): 34 | try: 35 | li[ii]=name_list[jj]+'/'+li[ii] 36 | except: 37 | a=1 38 | 39 | line = file[ii].strip('\n').split(',') 40 | 41 | try: 42 | line[0]=int(line[0]) 43 | except: 44 | line[0]=float(line[0]) 45 | try: 46 | line[1]=int(line[1]) 47 | except: 48 | line[1]=float(line[1]) 49 | try: 50 | line[2]=int(line[2]) 51 | except: 52 | line[2]=float(line[2]) 53 | try: 54 | line[3]=int(line[3]) 55 | except: 56 | 
line[3]=float(line[3]) 57 | bbox.append(line) 58 | 59 | if len(bbox)!=len(li): 60 | print (jj) 61 | f.close() 62 | c.append({'attr':[],'gt_rect':bbox,'img_names':li,'init_rect':bbox[0],'video_dir':name_list[jj]}) 63 | 64 | d=dict(zip(b,c)) 65 | 66 | return d 67 | 68 | class UAVVideo(Video): 69 | """ 70 | Args: 71 | name: video name 72 | root: dataset root 73 | video_dir: video directory 74 | init_rect: init rectangle 75 | img_names: image names 76 | gt_rect: groundtruth rectangle 77 | attr: attribute of video 78 | """ 79 | def __init__(self, name, root, video_dir, init_rect, img_names, 80 | gt_rect, attr, load_img=False): 81 | super(UAVVideo, self).__init__(name, root, video_dir, 82 | init_rect, img_names, gt_rect, attr, load_img) 83 | 84 | 85 | class NAT_LDataset(Dataset): 86 | """ 87 | Args: 88 | name: dataset name 89 | dataset_root: dataset root 90 | load_img: wether to load all imgs 91 | """ 92 | def __init__(self, name, dataset_root, load_img=False): 93 | super(NAT_LDataset, self).__init__(name, dataset_root) 94 | dataset_root = dataset_root + '/NAT2021L' 95 | meta_data = ca(dataset_root) 96 | dataset_root = dataset_root + '/data_seq' 97 | # load videos 98 | pbar = tqdm(meta_data.keys(), desc='loading '+name, ncols=100) 99 | self.videos = {} 100 | for video in pbar: 101 | pbar.set_postfix_str(video) 102 | self.videos[video] = UAVVideo(video, 103 | dataset_root, 104 | meta_data[video]['video_dir'], 105 | meta_data[video]['init_rect'], 106 | meta_data[video]['img_names'], 107 | meta_data[video]['gt_rect'], 108 | meta_data[video]['attr']) 109 | 110 | # set attr 111 | attr = [] 112 | for x in self.videos.values(): 113 | attr += x.attr 114 | attr = set(attr) 115 | self.attr = {} 116 | self.attr['ALL'] = list(self.videos.keys()) 117 | for x in attr: 118 | self.attr[x] = [] 119 | for k, v in self.videos.items(): 120 | for attr_ in v.attr: 121 | self.attr[attr_].append(k) 122 | 123 | -------------------------------------------------------------------------------- /tracker/BAN/toolkit/datasets/nut.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import numpy as np 4 | 5 | from PIL import Image 6 | from tqdm import tqdm 7 | from glob import glob 8 | 9 | from .dataset import Dataset 10 | from .video import Video 11 | 12 | 13 | def ca(dataset_root): 14 | 15 | path=dataset_root 16 | 17 | name_list=os.listdir(path+'/data_seq') 18 | name_list.sort() 19 | 20 | b=[] 21 | for i in range(len(name_list)): 22 | b.append(name_list[i]) 23 | c=[] 24 | 25 | for jj in range(len(name_list)): 26 | imgs=path+'/data_seq/'+str(name_list[jj]) 27 | txt=path+'/anno/'+str(name_list[jj])+'.txt' 28 | bbox=[] 29 | f = open(txt) # 返回一个文件对象 30 | file= f.readlines() 31 | li=os.listdir(imgs) 32 | li.sort() 33 | for ii in range(len(file)): 34 | try: 35 | li[ii]=name_list[jj]+'/'+li[ii] 36 | except: 37 | a=1 38 | 39 | line = file[ii].strip('\n').split(',') 40 | 41 | try: 42 | line[0]=int(line[0]) 43 | except: 44 | line[0]=float(line[0]) 45 | try: 46 | line[1]=int(line[1]) 47 | except: 48 | line[1]=float(line[1]) 49 | try: 50 | line[2]=int(line[2]) 51 | except: 52 | line[2]=float(line[2]) 53 | try: 54 | line[3]=int(line[3]) 55 | except: 56 | line[3]=float(line[3]) 57 | bbox.append(line) 58 | 59 | if len(bbox)!=len(li): 60 | print (jj) 61 | f.close() 62 | c.append({'attr':[],'gt_rect':bbox,'img_names':li,'init_rect':bbox[0],'video_dir':name_list[jj]}) 63 | 64 | d=dict(zip(b,c)) 65 | 66 | return d 67 | 68 | class UAVVideo(Video): 69 | """ 70 | Args: 71 | name: 
video name 72 | root: dataset root 73 | video_dir: video directory 74 | init_rect: init rectangle 75 | img_names: image names 76 | gt_rect: groundtruth rectangle 77 | attr: attribute of video 78 | """ 79 | def __init__(self, name, root, video_dir, init_rect, img_names, 80 | gt_rect, attr, load_img=False): 81 | super(UAVVideo, self).__init__(name, root, video_dir, 82 | init_rect, img_names, gt_rect, attr, load_img) 83 | 84 | 85 | class NUTDataset(Dataset): 86 | """ 87 | Args: 88 | name: dataset name 89 | dataset_root: dataset root 90 | load_img: wether to load all imgs 91 | """ 92 | def __init__(self, name, dataset_root, load_img=False): 93 | super(NUTDataset, self).__init__(name, dataset_root) 94 | dataset_root = dataset_root + '/NUT' 95 | meta_data = ca(dataset_root) 96 | dataset_root = dataset_root + '/data_seq' 97 | # load videos 98 | pbar = tqdm(meta_data.keys(), desc='loading '+name, ncols=100) 99 | self.videos = {} 100 | for video in pbar: 101 | pbar.set_postfix_str(video) 102 | self.videos[video] = UAVVideo(video, 103 | dataset_root, 104 | meta_data[video]['video_dir'], 105 | meta_data[video]['init_rect'], 106 | meta_data[video]['img_names'], 107 | meta_data[video]['gt_rect'], 108 | meta_data[video]['attr']) 109 | 110 | # set attr 111 | attr = [] 112 | for x in self.videos.values(): 113 | attr += x.attr 114 | attr = set(attr) 115 | self.attr = {} 116 | self.attr['ALL'] = list(self.videos.keys()) 117 | for x in attr: 118 | self.attr[x] = [] 119 | for k, v in self.videos.items(): 120 | for attr_ in v.attr: 121 | self.attr[attr_].append(k) 122 | 123 | -------------------------------------------------------------------------------- /tracker/BAN/toolkit/datasets/nut_l.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import numpy as np 4 | 5 | from PIL import Image 6 | from tqdm import tqdm 7 | from glob import glob 8 | 9 | from .dataset import Dataset 10 | from .video import Video 11 | 12 | 13 | def ca(dataset_root): 14 | 15 | path=dataset_root 16 | 17 | name_list=os.listdir(path+'/data_seq') 18 | name_list.sort() 19 | 20 | b=[] 21 | for i in range(len(name_list)): 22 | b.append(name_list[i]) 23 | c=[] 24 | 25 | for jj in range(len(name_list)): 26 | imgs=path+'/data_seq/'+str(name_list[jj]) 27 | txt=path+'/anno/'+str(name_list[jj])+'.txt' 28 | bbox=[] 29 | f = open(txt) # 返回一个文件对象 30 | file= f.readlines() 31 | li=os.listdir(imgs) 32 | li.sort() 33 | for ii in range(len(file)): 34 | try: 35 | li[ii]=name_list[jj]+'/'+li[ii] 36 | except: 37 | a=1 38 | 39 | line = file[ii].strip('\n').split(',') 40 | 41 | try: 42 | line[0]=int(line[0]) 43 | except: 44 | line[0]=float(line[0]) 45 | try: 46 | line[1]=int(line[1]) 47 | except: 48 | line[1]=float(line[1]) 49 | try: 50 | line[2]=int(line[2]) 51 | except: 52 | line[2]=float(line[2]) 53 | try: 54 | line[3]=int(line[3]) 55 | except: 56 | line[3]=float(line[3]) 57 | bbox.append(line) 58 | 59 | if len(bbox)!=len(li): 60 | print (jj) 61 | f.close() 62 | c.append({'attr':[],'gt_rect':bbox,'img_names':li,'init_rect':bbox[0],'video_dir':name_list[jj]}) 63 | 64 | d=dict(zip(b,c)) 65 | 66 | return d 67 | 68 | class UAVVideo(Video): 69 | """ 70 | Args: 71 | name: video name 72 | root: dataset root 73 | video_dir: video directory 74 | init_rect: init rectangle 75 | img_names: image names 76 | gt_rect: groundtruth rectangle 77 | attr: attribute of video 78 | """ 79 | def __init__(self, name, root, video_dir, init_rect, img_names, 80 | gt_rect, attr, load_img=False): 81 | super(UAVVideo, 
self).__init__(name, root, video_dir, 82 | init_rect, img_names, gt_rect, attr, load_img) 83 | 84 | 85 | class NUT_LDataset(Dataset): 86 | """ 87 | Args: 88 | name: dataset name 89 | dataset_root: dataset root 90 | load_img: wether to load all imgs 91 | """ 92 | def __init__(self, name, dataset_root, load_img=False): 93 | super(NUT_LDataset, self).__init__(name, dataset_root) 94 | dataset_root = dataset_root + '/NUT-L' 95 | meta_data = ca(dataset_root) 96 | dataset_root = dataset_root + '/data_seq' 97 | # load videos 98 | pbar = tqdm(meta_data.keys(), desc='loading '+name, ncols=100) 99 | self.videos = {} 100 | for video in pbar: 101 | pbar.set_postfix_str(video) 102 | self.videos[video] = UAVVideo(video, 103 | dataset_root, 104 | meta_data[video]['video_dir'], 105 | meta_data[video]['init_rect'], 106 | meta_data[video]['img_names'], 107 | meta_data[video]['gt_rect'], 108 | meta_data[video]['attr']) 109 | 110 | # set attr 111 | attr = [] 112 | for x in self.videos.values(): 113 | attr += x.attr 114 | attr = set(attr) 115 | self.attr = {} 116 | self.attr['ALL'] = list(self.videos.keys()) 117 | for x in attr: 118 | self.attr[x] = [] 119 | for k, v in self.videos.items(): 120 | for attr_ in v.attr: 121 | self.attr[attr_].append(k) 122 | 123 | -------------------------------------------------------------------------------- /tracker/BAN/toolkit/datasets/nut_l_s.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import numpy as np 4 | 5 | from PIL import Image 6 | from tqdm import tqdm 7 | from glob import glob 8 | 9 | from .dataset import Dataset 10 | from .video import Video 11 | 12 | 13 | def ca(dataset_root): 14 | 15 | path=dataset_root 16 | 17 | name_list=os.listdir(path+'/data_seq') 18 | name_list.sort() 19 | 20 | b=[] 21 | for i in range(len(name_list)): 22 | b.append(name_list[i]) 23 | c=[] 24 | 25 | for jj in range(len(name_list)): 26 | imgs=path+'/data_seq/'+str(name_list[jj]) 27 | txt=path+'/anno/'+str(name_list[jj])+'.txt' 28 | bbox=[] 29 | f = open(txt) # 返回一个文件对象 30 | file= f.readlines() 31 | li=os.listdir(imgs) 32 | li.sort() 33 | for ii in range(len(file)): 34 | try: 35 | li[ii]=name_list[jj]+'/'+li[ii] 36 | except: 37 | a=1 38 | 39 | line = file[ii].strip('\n').split(',') 40 | 41 | try: 42 | line[0]=int(line[0]) 43 | except: 44 | line[0]=float(line[0]) 45 | try: 46 | line[1]=int(line[1]) 47 | except: 48 | line[1]=float(line[1]) 49 | try: 50 | line[2]=int(line[2]) 51 | except: 52 | line[2]=float(line[2]) 53 | try: 54 | line[3]=int(line[3]) 55 | except: 56 | line[3]=float(line[3]) 57 | bbox.append(line) 58 | 59 | if len(bbox)!=len(li): 60 | print (jj) 61 | f.close() 62 | c.append({'attr':[],'gt_rect':bbox,'img_names':li,'init_rect':bbox[0],'video_dir':name_list[jj]}) 63 | 64 | d=dict(zip(b,c)) 65 | 66 | return d 67 | 68 | class UAVVideo(Video): 69 | """ 70 | Args: 71 | name: video name 72 | root: dataset root 73 | video_dir: video directory 74 | init_rect: init rectangle 75 | img_names: image names 76 | gt_rect: groundtruth rectangle 77 | attr: attribute of video 78 | """ 79 | def __init__(self, name, root, video_dir, init_rect, img_names, 80 | gt_rect, attr, load_img=False): 81 | super(UAVVideo, self).__init__(name, root, video_dir, 82 | init_rect, img_names, gt_rect, attr, load_img) 83 | 84 | 85 | class NUT_L_sDataset(Dataset): 86 | """ 87 | Args: 88 | name: dataset name 89 | dataset_root: dataset root 90 | load_img: wether to load all imgs 91 | """ 92 | def __init__(self, name, dataset_root, load_img=False): 93 
| super(NUT_L_sDataset, self).__init__(name, dataset_root) 94 | dataset_root = dataset_root + '/NUT_L_source' 95 | meta_data = ca(dataset_root) 96 | dataset_root = dataset_root + '/data_seq' 97 | # load videos 98 | pbar = tqdm(meta_data.keys(), desc='loading '+name, ncols=100) 99 | self.videos = {} 100 | for video in pbar: 101 | pbar.set_postfix_str(video) 102 | self.videos[video] = UAVVideo(video, 103 | dataset_root, 104 | meta_data[video]['video_dir'], 105 | meta_data[video]['init_rect'], 106 | meta_data[video]['img_names'], 107 | meta_data[video]['gt_rect'], 108 | meta_data[video]['attr']) 109 | 110 | # set attr 111 | attr = [] 112 | for x in self.videos.values(): 113 | attr += x.attr 114 | attr = set(attr) 115 | self.attr = {} 116 | self.attr['ALL'] = list(self.videos.keys()) 117 | for x in attr: 118 | self.attr[x] = [] 119 | for k, v in self.videos.items(): 120 | for attr_ in v.attr: 121 | self.attr[attr_].append(k) 122 | 123 | -------------------------------------------------------------------------------- /tracker/BAN/toolkit/datasets/nut_l_t.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import numpy as np 4 | 5 | from PIL import Image 6 | from tqdm import tqdm 7 | from glob import glob 8 | 9 | from .dataset import Dataset 10 | from .video import Video 11 | 12 | 13 | def ca(dataset_root): 14 | 15 | path=dataset_root 16 | 17 | name_list=os.listdir(path+'/data_seq') 18 | name_list.sort() 19 | 20 | b=[] 21 | for i in range(len(name_list)): 22 | b.append(name_list[i]) 23 | c=[] 24 | 25 | for jj in range(len(name_list)): 26 | imgs=path+'/data_seq/'+str(name_list[jj]) 27 | txt=path+'/anno/'+str(name_list[jj])+'.txt' 28 | bbox=[] 29 | f = open(txt) # 返回一个文件对象 30 | file= f.readlines() 31 | li=os.listdir(imgs) 32 | li.sort() 33 | for ii in range(len(file)): 34 | try: 35 | li[ii]=name_list[jj]+'/'+li[ii] 36 | except: 37 | a=1 38 | 39 | line = file[ii].strip('\n').split(',') 40 | 41 | try: 42 | line[0]=int(line[0]) 43 | except: 44 | line[0]=float(line[0]) 45 | try: 46 | line[1]=int(line[1]) 47 | except: 48 | line[1]=float(line[1]) 49 | try: 50 | line[2]=int(line[2]) 51 | except: 52 | line[2]=float(line[2]) 53 | try: 54 | line[3]=int(line[3]) 55 | except: 56 | line[3]=float(line[3]) 57 | bbox.append(line) 58 | 59 | if len(bbox)!=len(li): 60 | print (jj) 61 | f.close() 62 | c.append({'attr':[],'gt_rect':bbox,'img_names':li,'init_rect':bbox[0],'video_dir':name_list[jj]}) 63 | 64 | d=dict(zip(b,c)) 65 | 66 | return d 67 | 68 | class UAVVideo(Video): 69 | """ 70 | Args: 71 | name: video name 72 | root: dataset root 73 | video_dir: video directory 74 | init_rect: init rectangle 75 | img_names: image names 76 | gt_rect: groundtruth rectangle 77 | attr: attribute of video 78 | """ 79 | def __init__(self, name, root, video_dir, init_rect, img_names, 80 | gt_rect, attr, load_img=False): 81 | super(UAVVideo, self).__init__(name, root, video_dir, 82 | init_rect, img_names, gt_rect, attr, load_img) 83 | 84 | 85 | class NUT_L_tDataset(Dataset): 86 | """ 87 | Args: 88 | name: dataset name 89 | dataset_root: dataset root 90 | load_img: wether to load all imgs 91 | """ 92 | def __init__(self, name, dataset_root, load_img=False): 93 | super(NUT_L_tDataset, self).__init__(name, dataset_root) 94 | dataset_root = dataset_root + '/NUT_L_target' 95 | meta_data = ca(dataset_root) 96 | dataset_root = dataset_root + '/data_seq' 97 | # load videos 98 | pbar = tqdm(meta_data.keys(), desc='loading '+name, ncols=100) 99 | self.videos = {} 100 | for 
video in pbar: 101 | pbar.set_postfix_str(video) 102 | self.videos[video] = UAVVideo(video, 103 | dataset_root, 104 | meta_data[video]['video_dir'], 105 | meta_data[video]['init_rect'], 106 | meta_data[video]['img_names'], 107 | meta_data[video]['gt_rect'], 108 | meta_data[video]['attr']) 109 | 110 | # set attr 111 | attr = [] 112 | for x in self.videos.values(): 113 | attr += x.attr 114 | attr = set(attr) 115 | self.attr = {} 116 | self.attr['ALL'] = list(self.videos.keys()) 117 | for x in attr: 118 | self.attr[x] = [] 119 | for k, v in self.videos.items(): 120 | for attr_ in v.attr: 121 | self.attr[attr_].append(k) 122 | 123 | -------------------------------------------------------------------------------- /tracker/BAN/toolkit/datasets/uav.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from tqdm import tqdm 5 | from glob import glob 6 | 7 | from .dataset import Dataset 8 | from .video import Video 9 | 10 | class UAVVideo(Video): 11 | """ 12 | Args: 13 | name: video name 14 | root: dataset root 15 | video_dir: video directory 16 | init_rect: init rectangle 17 | img_names: image names 18 | gt_rect: groundtruth rectangle 19 | attr: attribute of video 20 | """ 21 | def __init__(self, name, root, video_dir, init_rect, img_names, 22 | gt_rect, attr, load_img=False): 23 | super(UAVVideo, self).__init__(name, root, video_dir, 24 | init_rect, img_names, gt_rect, attr, load_img) 25 | 26 | 27 | class UAVDataset(Dataset): 28 | """ 29 | Args: 30 | name: dataset name, should be 'UAV123', 'UAV20L' 31 | dataset_root: dataset root 32 | load_img: wether to load all imgs 33 | """ 34 | def __init__(self, name, dataset_root, load_img=False): 35 | super(UAVDataset, self).__init__(name, dataset_root) 36 | dataset_root = dataset_root + '/UAV123/data_seq/UAV123/' 37 | with open(os.path.join(dataset_root, '{}.json'.format(name)), 'r') as f: 38 | meta_data = json.load(f) 39 | 40 | # load videos 41 | pbar = tqdm(meta_data.keys(), desc='loading '+name, ncols=100) 42 | self.videos = {} 43 | for video in pbar: 44 | pbar.set_postfix_str(video) 45 | self.videos[video] = UAVVideo(video, 46 | dataset_root, 47 | meta_data[video]['video_dir'], 48 | meta_data[video]['init_rect'], 49 | meta_data[video]['img_names'], 50 | meta_data[video]['gt_rect'], 51 | meta_data[video]['attr']) 52 | 53 | # set attr 54 | attr = [] 55 | for x in self.videos.values(): 56 | attr += x.attr 57 | attr = set(attr) 58 | self.attr = {} 59 | self.attr['ALL'] = list(self.videos.keys()) 60 | for x in attr: 61 | self.attr[x] = [] 62 | for k, v in self.videos.items(): 63 | for attr_ in v.attr: 64 | self.attr[attr_].append(k) 65 | 66 | -------------------------------------------------------------------------------- /tracker/BAN/toolkit/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .ar_benchmark import AccuracyRobustnessBenchmark 2 | from .eao_benchmark import EAOBenchmark 3 | from .ope_benchmark import OPEBenchmark 4 | from .f1_benchmark import F1Benchmark 5 | -------------------------------------------------------------------------------- /tracker/BAN/toolkit/evaluation/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/tracker/BAN/toolkit/evaluation/__pycache__/__init__.cpython-38.pyc 
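All of the ca() helpers in the dataset files above assume the same on-disk layout: one folder of frames per sequence under data_seq, plus a matching anno/<sequence>.txt with one comma-separated x,y,w,h ground-truth box per frame. The sketch below lays out such a structure for a toy sequence; the paths and box values are illustrative assumptions, not data shipped with the repo.

import os

root = './test_dataset/NAT2021'   # assumed location; NATDataset itself appends '/NAT2021'
video = 'example_sequence'
os.makedirs(os.path.join(root, 'data_seq', video), exist_ok=True)
os.makedirs(os.path.join(root, 'anno'), exist_ok=True)
boxes = [(100, 120, 40, 30), (102, 121, 41, 30)]   # one x,y,w,h box per frame
with open(os.path.join(root, 'anno', video + '.txt'), 'w') as f:
    for i, box in enumerate(boxes):
        # placeholder frame files; real data would be the JPEG frames of the sequence
        open(os.path.join(root, 'data_seq', video, '%06d.jpg' % (i + 1)), 'w').close()
        f.write('%d,%d,%d,%d\n' % box)

ca() pairs the sorted frame list with the annotation lines, so the two counts must match; it only prints the sequence index when they do not.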
-------------------------------------------------------------------------------- /tracker/BAN/toolkit/evaluation/__pycache__/ar_benchmark.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/tracker/BAN/toolkit/evaluation/__pycache__/ar_benchmark.cpython-38.pyc -------------------------------------------------------------------------------- /tracker/BAN/toolkit/evaluation/__pycache__/eao_benchmark.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/tracker/BAN/toolkit/evaluation/__pycache__/eao_benchmark.cpython-38.pyc -------------------------------------------------------------------------------- /tracker/BAN/toolkit/evaluation/__pycache__/f1_benchmark.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/tracker/BAN/toolkit/evaluation/__pycache__/f1_benchmark.cpython-38.pyc -------------------------------------------------------------------------------- /tracker/BAN/toolkit/evaluation/__pycache__/ope_benchmark.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/tracker/BAN/toolkit/evaluation/__pycache__/ope_benchmark.cpython-38.pyc -------------------------------------------------------------------------------- /tracker/BAN/toolkit/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import region 2 | from .statistics import * 3 | -------------------------------------------------------------------------------- /tracker/BAN/toolkit/utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/tracker/BAN/toolkit/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /tracker/BAN/toolkit/utils/__pycache__/statistics.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/tracker/BAN/toolkit/utils/__pycache__/statistics.cpython-38.pyc -------------------------------------------------------------------------------- /tracker/BAN/toolkit/utils/c_region.pxd: -------------------------------------------------------------------------------- 1 | cdef extern from "src/region.h": 2 | ctypedef enum region_type "RegionType": 3 | EMTPY 4 | SPECIAL 5 | RECTANGEL 6 | POLYGON 7 | MASK 8 | 9 | ctypedef struct region_bounds: 10 | float top 11 | float bottom 12 | float left 13 | float right 14 | 15 | ctypedef struct region_rectangle: 16 | float x 17 | float y 18 | float width 19 | float height 20 | 21 | # ctypedef struct region_mask: 22 | # int x 23 | # int y 24 | # int width 25 | # int height 26 | # char *data 27 | 28 | ctypedef struct region_polygon: 29 | int count 30 | float *x 31 | float *y 32 | 33 | ctypedef union region_container_data: 34 | region_rectangle rectangle 35 | region_polygon polygon 36 | # region_mask mask 37 | int special 38 | 39 | ctypedef struct region_container: 40 | region_type type 41 | region_container_data data 42 | 43 | # ctypedef struct region_overlap: 44 | # float overlap 45 | # float only1 46 | # float only2 47 | 48 | # region_overlap region_compute_overlap(const region_container* ra, const region_container* rb, region_bounds bounds) 49 | 50 | float compute_polygon_overlap(const region_polygon* p1, const region_polygon* p2, float *only1, float *only2, region_bounds bounds) 51 | -------------------------------------------------------------------------------- /tracker/BAN/toolkit/utils/misc.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author fangyi.zhang@vipl.ict.ac.cn 3 | """ 4 | import numpy as np 5 | 6 | def determine_thresholds(confidence, resolution=100): 7 | """choose thresholds according to confidence 8 | 9 | Args: 10 | confidence: list or numpy array 11 | resolution: number of thresholds to choose 12 | 13 | Returns: 14 | threshold: numpy array 15 | """ 16 | if isinstance(confidence, list): 17 | confidence = np.array(confidence) 18 | confidence = confidence.flatten() 19 | confidence = confidence[~np.isnan(confidence)] 20 | confidence.sort() 21 | 22 | assert len(confidence) > resolution and resolution > 2 23 | 24 | thresholds = np.ones((resolution)) 25 | thresholds[0] = - np.inf 26 | thresholds[-1] = np.inf 27 | delta = np.floor(len(confidence) / (resolution - 2)) 28 | idxs = np.linspace(delta, len(confidence)-delta, resolution-2, dtype=np.int32) 29 | thresholds[1:-1] = confidence[idxs] 30 | return thresholds 31 |
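As a quick, self-contained check of determine_thresholds above: the scores below are made up, and the import assumes tracker/BAN is on PYTHONPATH.

import numpy as np
from toolkit.utils.misc import determine_thresholds

# 1000 fake confidence scores; resolution=10 keeps -inf/+inf sentinels at the ends
# and fills the 8 interior slots with evenly spaced order statistics of the scores.
scores = np.random.rand(1000)
th = determine_thresholds(scores, resolution=10)
assert th.shape == (10,) and th[0] == -np.inf and th[-1] == np.inf

-------------------------------------------------------------------------------- /tracker/BAN/toolkit/utils/region.cpython-36m-x86_64-linux-gnu.so: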
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/tracker/BAN/toolkit/utils/region.cpython-36m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /tracker/BAN/toolkit/utils/region.cpython-38-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/SAM-DA/0b4ee43ccd32cd6fba8ca93c5bfbee90a527fa8e/tracker/BAN/toolkit/utils/region.cpython-38-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /tracker/BAN/toolkit/utils/src/region.h: -------------------------------------------------------------------------------- 1 | /* -*- Mode: C; indent-tabs-mode: nil; c-basic-offset: 4; tab-width: 4 -*- */ 2 | 3 | #ifndef _REGION_H_ 4 | #define _REGION_H_ 5 | 6 | #ifdef TRAX_STATIC_DEFINE 7 | # define __TRAX_EXPORT 8 | #else 9 | # ifndef __TRAX_EXPORT 10 | # if defined(_MSC_VER) 11 | # ifdef trax_EXPORTS 12 | /* We are building this library */ 13 | # define __TRAX_EXPORT __declspec(dllexport) 14 | # else 15 | /* We are using this library */ 16 | # define __TRAX_EXPORT __declspec(dllimport) 17 | # endif 18 | # elif defined(__GNUC__) 19 | # ifdef trax_EXPORTS 20 | /* We are building this library */ 21 | # define __TRAX_EXPORT __attribute__((visibility("default"))) 22 | # else 23 | /* We are using this library */ 24 | # define __TRAX_EXPORT __attribute__((visibility("default"))) 25 | # endif 26 | # endif 27 | # endif 28 | #endif 29 | 30 | #ifndef MAX 31 | #define MAX(a,b) (((a) > (b)) ? (a) : (b)) 32 | #endif 33 | 34 | #ifndef MIN 35 | #define MIN(a,b) (((a) < (b)) ? 
(a) : (b)) 36 | #endif 37 | 38 | #define TRAX_DEFAULT_CODE 0 39 | 40 | #define REGION_LEGACY_RASTERIZATION 1 41 | 42 | #ifdef __cplusplus 43 | extern "C" { 44 | #endif 45 | 46 | typedef enum region_type {EMPTY, SPECIAL, RECTANGLE, POLYGON, MASK} region_type; 47 | 48 | typedef struct region_bounds { 49 | 50 | float top; 51 | float bottom; 52 | float left; 53 | float right; 54 | 55 | } region_bounds; 56 | 57 | typedef struct region_polygon { 58 | 59 | int count; 60 | 61 | float* x; 62 | float* y; 63 | 64 | } region_polygon; 65 | 66 | typedef struct region_mask { 67 | 68 | int x; 69 | int y; 70 | 71 | int width; 72 | int height; 73 | 74 | char* data; 75 | 76 | } region_mask; 77 | 78 | typedef struct region_rectangle { 79 | 80 | float x; 81 | float y; 82 | float width; 83 | float height; 84 | 85 | } region_rectangle; 86 | 87 | typedef struct region_container { 88 | enum region_type type; 89 | union { 90 | region_rectangle rectangle; 91 | region_polygon polygon; 92 | region_mask mask; 93 | int special; 94 | } data; 95 | } region_container; 96 | 97 | typedef struct region_overlap { 98 | 99 | float overlap; 100 | float only1; 101 | float only2; 102 | 103 | } region_overlap; 104 | 105 | extern const region_bounds region_no_bounds; 106 | 107 | __TRAX_EXPORT int region_set_flags(int mask); 108 | 109 | __TRAX_EXPORT int region_clear_flags(int mask); 110 | 111 | __TRAX_EXPORT region_overlap region_compute_overlap(const region_container* ra, const region_container* rb, region_bounds bounds); 112 | 113 | __TRAX_EXPORT float compute_polygon_overlap(const region_polygon* p1, const region_polygon* p2, float *only1, float *only2, region_bounds bounds); 114 | 115 | __TRAX_EXPORT region_bounds region_create_bounds(float left, float top, float right, float bottom); 116 | 117 | __TRAX_EXPORT region_bounds region_compute_bounds(const region_container* region); 118 | 119 | __TRAX_EXPORT int region_parse(const char* buffer, region_container** region); 120 | 121 | __TRAX_EXPORT char* region_string(region_container* region); 122 | 123 | __TRAX_EXPORT void region_print(FILE* out, region_container* region); 124 | 125 | __TRAX_EXPORT region_container* region_convert(const region_container* region, region_type type); 126 | 127 | __TRAX_EXPORT void region_release(region_container** region); 128 | 129 | __TRAX_EXPORT region_container* region_create_special(int code); 130 | 131 | __TRAX_EXPORT region_container* region_create_rectangle(float x, float y, float width, float height); 132 | 133 | __TRAX_EXPORT region_container* region_create_polygon(int count); 134 | 135 | __TRAX_EXPORT int region_contains_point(region_container* r, float x, float y); 136 | 137 | __TRAX_EXPORT void region_get_mask(region_container* r, char* mask, int width, int height); 138 | 139 | __TRAX_EXPORT void region_get_mask_offset(region_container* r, char* mask, int x, int y, int width, int height); 140 | 141 | #ifdef __cplusplus 142 | } 143 | #endif 144 | 145 | #endif 146 | -------------------------------------------------------------------------------- /tracker/BAN/toolkit/visualization/__init__.py: -------------------------------------------------------------------------------- 1 | from .draw_f1 import draw_f1 2 | from .draw_success_precision import draw_success_precision 3 | from .draw_eao import draw_eao 4 | -------------------------------------------------------------------------------- /tracker/BAN/toolkit/visualization/draw_eao.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as 
plt 2 | import numpy as np 3 | import pickle 4 | 5 | from matplotlib import rc 6 | from .draw_utils import COLOR, MARKER_STYLE 7 | 8 | rc('font',**{'family':'sans-serif','sans-serif':['Helvetica']}) 9 | rc('text', usetex=True) 10 | 11 | def draw_eao(result): 12 | fig = plt.figure() 13 | ax = fig.add_subplot(111, projection='polar') 14 | angles = np.linspace(0, 2*np.pi, 8, endpoint=True) 15 | 16 | attr2value = [] 17 | for i, (tracker_name, ret) in enumerate(result.items()): 18 | value = list(ret.values()) 19 | attr2value.append(value) 20 | value.append(value[0]) 21 | attr2value = np.array(attr2value) 22 | max_value = np.max(attr2value, axis=0) 23 | min_value = np.min(attr2value, axis=0) 24 | for i, (tracker_name, ret) in enumerate(result.items()): 25 | value = list(ret.values()) 26 | value.append(value[0]) 27 | value = np.array(value) 28 | value *= (1 / max_value) 29 | plt.plot(angles, value, linestyle='-', color=COLOR[i], marker=MARKER_STYLE[i], 30 | label=tracker_name, linewidth=1.5, markersize=6) 31 | 32 | attrs = ["Overall", "Camera motion", 33 | "Illumination change","Motion Change", 34 | "Size change","Occlusion", 35 | "Unassigned"] 36 | attr_value = [] 37 | for attr, maxv, minv in zip(attrs, max_value, min_value): 38 | attr_value.append(attr + "\n({:.3f},{:.3f})".format(minv, maxv)) 39 | ax.set_thetagrids(angles[:-1] * 180/np.pi, attr_value) 40 | ax.spines['polar'].set_visible(False) 41 | ax.legend(loc='upper center', bbox_to_anchor=(0.5,-0.07), frameon=False, ncol=5) 42 | ax.grid(b=False) 43 | ax.set_ylim(0, 1.18) 44 | ax.set_yticks([]) 45 | plt.show() 46 | 47 | if __name__ == '__main__': 48 | result = pickle.load(open("../../result.pkl", 'rb')) 49 | draw_eao(result) 50 | -------------------------------------------------------------------------------- /tracker/BAN/toolkit/visualization/draw_f1.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | 4 | from matplotlib import rc 5 | from .draw_utils import COLOR, LINE_STYLE 6 | 7 | rc('font',**{'family':'sans-serif','sans-serif':['Helvetica']}) 8 | rc('text', usetex=True) 9 | 10 | def draw_f1(result, bold_name=None): 11 | # drawing f1 contour 12 | fig, ax = plt.subplots() 13 | for f1 in np.arange(0.1, 1, 0.1): 14 | recall = np.arange(f1, 1+0.01, 0.01) 15 | precision = f1 * recall / (2 * recall - f1) 16 | ax.plot(recall, precision, color=[0,1,0], linestyle='-', linewidth=0.5) 17 | ax.plot(precision, recall, color=[0,1,0], linestyle='-', linewidth=0.5) 18 | ax.grid(b=True) 19 | ax.set_aspect(1) 20 | plt.xlabel('Recall') 21 | plt.ylabel('Precision') 22 | plt.axis([0, 1, 0, 1]) 23 | plt.title(r'\textbf{VOT2018-LT Precision vs Recall}') 24 | 25 | # draw result line 26 | all_precision = {} 27 | all_recall = {} 28 | best_f1 = {} 29 | best_idx = {} 30 | for tracker_name, ret in result.items(): 31 | precision = np.mean(list(ret['precision'].values()), axis=0) 32 | recall = np.mean(list(ret['recall'].values()), axis=0) 33 | f1 = 2 * precision * recall / (precision + recall) 34 | max_idx = np.argmax(f1) 35 | all_precision[tracker_name] = precision 36 | all_recall[tracker_name] = recall 37 | best_f1[tracker_name] = f1[max_idx] 38 | best_idx[tracker_name] = max_idx 39 | 40 | for idx, (tracker_name, best_f1) in \ 41 | enumerate(sorted(best_f1.items(), key=lambda x:x[1], reverse=True)): 42 | if tracker_name == bold_name: 43 | label = r"\textbf{[%.3f] Ours}" % (best_f1) 44 | else: 45 | label = "[%.3f] " % (best_f1) + tracker_name 46 | recall = 
all_recall[tracker_name][:-1] 47 | precision = all_precision[tracker_name][:-1] 48 | ax.plot(recall, precision, color=COLOR[idx], linestyle='-', 49 | label=label) 50 | f1_idx = best_idx[tracker_name] 51 | ax.plot(recall[f1_idx], precision[f1_idx], color=[0,0,0], marker='o', 52 | markerfacecolor=COLOR[idx], markersize=5) 53 | ax.legend(loc='lower right', labelspacing=0.2) 54 | plt.xticks(np.arange(0, 1+0.1, 0.1)) 55 | plt.yticks(np.arange(0, 1+0.1, 0.1)) 56 | plt.show() 57 | 58 | if __name__ == '__main__': 59 | draw_f1(None) 60 | -------------------------------------------------------------------------------- /tracker/BAN/toolkit/visualization/draw_utils.py: -------------------------------------------------------------------------------- 1 | 2 | COLOR = ((1, 0, 0), 3 | (0, 1, 0), 4 | (1, 0, 1), 5 | (1, 1, 0), 6 | (0 , 162/255, 232/255), 7 | (0.5, 0.5, 0.5), 8 | (0, 0, 1), 9 | (0, 1, 1), 10 | (136/255, 0 , 21/255), 11 | (255/255, 127/255, 39/255), 12 | (0, 0, 0)) 13 | 14 | LINE_STYLE = ['-', '--', ':', '-', '--', ':', '-', '--', ':', '-'] 15 | 16 | MARKER_STYLE = ['o', 'v', '<', '*', 'D', 'x', '.', 'x', '<', '.'] 17 | -------------------------------------------------------------------------------- /tracker/BAN/train_dataset/got10k/gen_json.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from __future__ import unicode_literals 4 | import json 5 | from os.path import join, exists 6 | import os 7 | import pandas as pd 8 | 9 | dataset_path = 'data' 10 | train_sets = ['GOT-10k_Train_split_01','GOT-10k_Train_split_02','GOT-10k_Train_split_03','GOT-10k_Train_split_04', 11 | 'GOT-10k_Train_split_05','GOT-10k_Train_split_06','GOT-10k_Train_split_07','GOT-10k_Train_split_08', 12 | 'GOT-10k_Train_split_09','GOT-10k_Train_split_10','GOT-10k_Train_split_11','GOT-10k_Train_split_12', 13 | 'GOT-10k_Train_split_13','GOT-10k_Train_split_14','GOT-10k_Train_split_15','GOT-10k_Train_split_16', 14 | 'GOT-10k_Train_split_17','GOT-10k_Train_split_18','GOT-10k_Train_split_19'] 15 | val_set = ['val'] 16 | d_sets = {'videos_val':val_set,'videos_train':train_sets} 17 | 18 | 19 | def parse_and_sched(dl_dir='.'): 20 | js = {} 21 | for d_set in d_sets: 22 | for dataset in d_sets[d_set]: 23 | videos = os.listdir(os.path.join(dataset_path,dataset)) 24 | for video in videos: 25 | if video == 'list.txt': 26 | continue 27 | video = dataset+'/'+video 28 | gt_path = join(dataset_path, video, 'groundtruth.txt') 29 | f = open(gt_path, 'r') 30 | groundtruth = f.readlines() 31 | f.close() 32 | for idx, gt_line in enumerate(groundtruth): 33 | gt_image = gt_line.strip().split(',') 34 | frame = '%06d' % (int(idx)) 35 | obj = '%02d' % (int(0)) 36 | bbox = [int(float(gt_image[0])), int(float(gt_image[1])), 37 | int(float(gt_image[0])) + int(float(gt_image[2])), 38 | int(float(gt_image[1])) + int(float(gt_image[3]))] # xmin,ymin,xmax,ymax 39 | 40 | if video not in js: 41 | js[video] = {} 42 | if obj not in js[video]: 43 | js[video][obj] = {} 44 | js[video][obj][frame] = bbox 45 | if 'videos_val' == d_set: 46 | json.dump(js, open('val.json', 'w'), indent=4, sort_keys=True) 47 | else: 48 | json.dump(js, open('train.json', 'w'), indent=4, sort_keys=True) 49 | js = {} 50 | 51 | print(d_set+': All videos downloaded' ) 52 | 53 | 54 | if __name__ == '__main__': 55 | parse_and_sched() 56 | -------------------------------------------------------------------------------- /tracker/BAN/train_dataset/got10k/par_crop.py: -------------------------------------------------------------------------------- 1 
| from os.path import join, isdir 2 | from os import listdir, mkdir, makedirs 3 | import cv2 4 | import numpy as np 5 | import glob 6 | import xml.etree.ElementTree as ET 7 | from concurrent import futures 8 | import sys 9 | import time 10 | 11 | dataset_path = './data' 12 | sub_sets = ['GOT-10k_Train_split_01','GOT-10k_Train_split_02','GOT-10k_Train_split_03','GOT-10k_Train_split_04', 13 | 'GOT-10k_Train_split_05','GOT-10k_Train_split_06','GOT-10k_Train_split_07','GOT-10k_Train_split_08', 14 | 'GOT-10k_Train_split_09','GOT-10k_Train_split_10','GOT-10k_Train_split_11','GOT-10k_Train_split_12', 15 | 'GOT-10k_Train_split_13','GOT-10k_Train_split_14','GOT-10k_Train_split_15','GOT-10k_Train_split_16', 16 | 'GOT-10k_Train_split_17','GOT-10k_Train_split_18','GOT-10k_Train_split_19','val'] 17 | 18 | # Print iterations progress (thanks StackOverflow) 19 | def printProgress(iteration, total, prefix='', suffix='', decimals=1, barLength=100): 20 | 21 | formatStr = "{0:." + str(decimals) + "f}" 22 | percents = formatStr.format(100 * (iteration / float(total))) 23 | filledLength = int(round(barLength * iteration / float(total))) 24 | bar = '' * filledLength + '-' * (barLength - filledLength) 25 | sys.stdout.write('\r%s |%s| %s%s %s' % (prefix, bar, percents, '%', suffix)), 26 | if iteration == total: 27 | sys.stdout.write('\x1b[2K\r') 28 | sys.stdout.flush() 29 | 30 | 31 | def crop_hwc(image, bbox, out_sz, padding=(0, 0, 0)): 32 | a = (out_sz-1) / (bbox[2]-bbox[0]) 33 | b = (out_sz-1) / (bbox[3]-bbox[1]) 34 | c = -a * bbox[0] 35 | d = -b * bbox[1] 36 | mapping = np.array([[a, 0, c], 37 | [0, b, d]]).astype(float) 38 | crop = cv2.warpAffine(image, mapping, (out_sz, out_sz), borderMode=cv2.BORDER_CONSTANT, borderValue=padding) 39 | return crop 40 | 41 | 42 | def pos_s_2_bbox(pos, s): 43 | return [pos[0]-s/2, pos[1]-s/2, pos[0]+s/2, pos[1]+s/2] 44 | 45 | 46 | def crop_like_SiamFC(image, bbox, context_amount=0.5, exemplar_size=127, instanc_size=255, padding=(0, 0, 0)): 47 | target_pos = [(bbox[2]+bbox[0])/2., (bbox[3]+bbox[1])/2.] 
48 | target_size = [bbox[2]-bbox[0], bbox[3]-bbox[1]] 49 | wc_z = target_size[1] + context_amount * sum(target_size) 50 | hc_z = target_size[0] + context_amount * sum(target_size) 51 | s_z = np.sqrt(wc_z * hc_z) 52 | scale_z = exemplar_size / s_z 53 | d_search = (instanc_size - exemplar_size) / 2 54 | pad = d_search / scale_z 55 | s_x = s_z + 2 * pad 56 | 57 | z = crop_hwc(image, pos_s_2_bbox(target_pos, s_z), exemplar_size, padding) 58 | x = crop_hwc(image, pos_s_2_bbox(target_pos, s_x), instanc_size, padding) 59 | return z, x 60 | 61 | 62 | def crop_video(video, d_set, crop_path, instanc_size): 63 | if video != 'list.txt': 64 | video_crop_base_path = join(crop_path, video) 65 | if not isdir(video_crop_base_path): makedirs(video_crop_base_path) 66 | gt_path = join(dataset_path, d_set, video, 'groundtruth.txt') 67 | images_path = join(dataset_path, d_set, video) 68 | f = open(gt_path, 'r') 69 | groundtruth = f.readlines() 70 | f.close() 71 | for idx, gt_line in enumerate(groundtruth): 72 | gt_image = gt_line.strip().split(',') 73 | bbox = [int(float(gt_image[0])),int(float(gt_image[1])),int(float(gt_image[0]))+int(float(gt_image[2])),int(float(gt_image[1]))+int(float(gt_image[3]))]#xmin,ymin,xmax,ymax 74 | 75 | im = cv2.imread(join(images_path,str(idx+1).zfill(8)+'.jpg')) 76 | avg_chans = np.mean(im, axis=(0, 1)) 77 | 78 | z, x = crop_like_SiamFC(im, bbox, instanc_size=instanc_size, padding=avg_chans) 79 | cv2.imwrite(join(video_crop_base_path, '{:06d}.{:02d}.z.jpg'.format(int(idx), int(0))), z) 80 | cv2.imwrite(join(video_crop_base_path, '{:06d}.{:02d}.x.jpg'.format(int(idx), int(0))), x) 81 | 82 | 83 | def main(instanc_size=511, num_threads=24): 84 | crop_path = './crop{:d}'.format(instanc_size) 85 | 86 | if not isdir(crop_path): mkdir(crop_path) 87 | for d_set in sub_sets: 88 | save_path = join(crop_path, d_set) 89 | videos = listdir(join(dataset_path,d_set)) 90 | if not isdir(save_path): mkdir(save_path) 91 | 92 | 93 | n_videos = len(videos) 94 | with futures.ProcessPoolExecutor(max_workers=num_threads) as executor: 95 | fs = [executor.submit(crop_video, video, d_set, save_path, instanc_size) for video in videos] 96 | for i, f in enumerate(futures.as_completed(fs)): 97 | # Write progress to error so that it can be seen 98 | printProgress(i, n_videos, prefix='train', suffix='Done ', barLength=40) 99 | 100 | 101 | if __name__ == '__main__': 102 | since = time.time() 103 | main() 104 | time_elapsed = time.time() - since 105 | print('Total complete in {:.0f}m {:.0f}s'.format( 106 | time_elapsed // 60, time_elapsed % 60)) 107 | -------------------------------------------------------------------------------- /tracker/BAN/train_dataset/got10k/readme.md: -------------------------------------------------------------------------------- 1 | # Preprocessing GOT-10K 2 | A Large High-Diversity Benchmark for Generic Object Tracking in the Wild 3 | 4 | ### Prepare dataset 5 | 6 | After downloading the dataset, please unzip it in the *train_dataset/got10k* directory: 7 | ````shell 8 | mkdir data 9 | unzip full_data/train_data/*.zip -d ./data 10 | ```` 11 | 12 | ### Crop & Generate data info 13 | 14 | ````shell 15 | #python par_crop.py [crop_size] [num_threads] 16 | python par_crop.py 511 12 17 | python gen_json.py 18 | ```` 19 | -------------------------------------------------------------------------------- /tracker/BAN/train_dataset/vid/gen_json.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | from os import listdir 3 | import json 4 | import
numpy as np 5 | 6 | print('load json (raw vid info), please wait 20 seconds~') 7 | vid = json.load(open('vid.json', 'r')) 8 | 9 | 10 | def check_size(frame_sz, bbox): 11 | min_ratio = 0.1 12 | max_ratio = 0.75 13 | area_ratio = np.sqrt((bbox[2]-bbox[0])*(bbox[3]-bbox[1])/float(np.prod(frame_sz))) 14 | ok = (area_ratio > min_ratio) and (area_ratio < max_ratio) 15 | return ok 16 | 17 | 18 | def check_borders(frame_sz, bbox): 19 | dist_from_border = 0.05 * (bbox[2] - bbox[0] + bbox[3] - bbox[1])/2 20 | ok = (bbox[0] > dist_from_border) and (bbox[1] > dist_from_border) and \ 21 | ((frame_sz[0] - bbox[2]) > dist_from_border) and \ 22 | ((frame_sz[1] - bbox[3]) > dist_from_border) 23 | return ok 24 | 25 | 26 | snippets = dict() 27 | n_snippets = 0 28 | n_videos = 0 29 | for subset in vid: 30 | for video in subset: 31 | n_videos += 1 32 | frames = video['frame'] 33 | id_set = [] 34 | id_frames = [[]] * 60 # at most 60 objects 35 | for f, frame in enumerate(frames): 36 | objs = frame['objs'] 37 | frame_sz = frame['frame_sz'] 38 | for obj in objs: 39 | trackid = obj['trackid'] 40 | occluded = obj['occ'] 41 | bbox = obj['bbox'] 42 | 43 | if trackid not in id_set: 44 | id_set.append(trackid) 45 | id_frames[trackid] = [] 46 | id_frames[trackid].append(f) 47 | if len(id_set) > 0: 48 | snippets[video['base_path']] = dict() 49 | for selected in id_set: 50 | frame_ids = sorted(id_frames[selected]) 51 | sequences = np.split(frame_ids, np.array(np.where(np.diff(frame_ids) > 1)[0]) + 1) 52 | sequences = [s for s in sequences if len(s) > 1] # remove isolated frame. 53 | for seq in sequences: 54 | snippet = dict() 55 | for frame_id in seq: 56 | frame = frames[frame_id] 57 | for obj in frame['objs']: 58 | if obj['trackid'] == selected: 59 | o = obj 60 | continue 61 | snippet[frame['img_path'].split('.')[0]] = o['bbox'] 62 | snippets[video['base_path']]['{:02d}'.format(selected)] = snippet 63 | n_snippets += 1 64 | print('video: {:d} snippets_num: {:d}'.format(n_videos, n_snippets)) 65 | 66 | train = {k:v for (k,v) in snippets.items() if 'train' in k} 67 | val = {k:v for (k,v) in snippets.items() if 'val' in k} 68 | 69 | json.dump(train, open('train.json', 'w'), indent=4, sort_keys=True) 70 | json.dump(val, open('val.json', 'w'), indent=4, sort_keys=True) 71 | print('done!') 72 | -------------------------------------------------------------------------------- /tracker/BAN/train_dataset/vid/par_crop.py: -------------------------------------------------------------------------------- 1 | from os.path import join, isdir 2 | from os import listdir, mkdir, makedirs 3 | import cv2 4 | import numpy as np 5 | import glob 6 | import xml.etree.ElementTree as ET 7 | from concurrent import futures 8 | import sys 9 | import time 10 | 11 | VID_base_path = './ILSVRC2015' 12 | ann_base_path = join(VID_base_path, 'Annotations/VID/train/') 13 | sub_sets = sorted({'a', 'b', 'c', 'd', 'e'}) 14 | 15 | 16 | # Print iterations progress (thanks StackOverflow) 17 | def printProgress(iteration, total, prefix='', suffix='', decimals=1, barLength=100): 18 | 19 | formatStr = "{0:." 
+ str(decimals) + "f}" 20 | percents = formatStr.format(100 * (iteration / float(total))) 21 | filledLength = int(round(barLength * iteration / float(total))) 22 | bar = '' * filledLength + '-' * (barLength - filledLength) 23 | sys.stdout.write('\r%s |%s| %s%s %s' % (prefix, bar, percents, '%', suffix)), 24 | if iteration == total: 25 | sys.stdout.write('\x1b[2K\r') 26 | sys.stdout.flush() 27 | 28 | 29 | def crop_hwc(image, bbox, out_sz, padding=(0, 0, 0)): 30 | a = (out_sz-1) / (bbox[2]-bbox[0]) 31 | b = (out_sz-1) / (bbox[3]-bbox[1]) 32 | c = -a * bbox[0] 33 | d = -b * bbox[1] 34 | mapping = np.array([[a, 0, c], 35 | [0, b, d]]).astype(float) 36 | crop = cv2.warpAffine(image, mapping, (out_sz, out_sz), borderMode=cv2.BORDER_CONSTANT, borderValue=padding) 37 | return crop 38 | 39 | 40 | def pos_s_2_bbox(pos, s): 41 | return [pos[0]-s/2, pos[1]-s/2, pos[0]+s/2, pos[1]+s/2] 42 | 43 | 44 | def crop_like_SiamFC(image, bbox, context_amount=0.5, exemplar_size=127, instanc_size=255, padding=(0, 0, 0)): 45 | target_pos = [(bbox[2]+bbox[0])/2., (bbox[3]+bbox[1])/2.] 46 | target_size = [bbox[2]-bbox[0], bbox[3]-bbox[1]] 47 | wc_z = target_size[1] + context_amount * sum(target_size) 48 | hc_z = target_size[0] + context_amount * sum(target_size) 49 | s_z = np.sqrt(wc_z * hc_z) 50 | scale_z = exemplar_size / s_z 51 | d_search = (instanc_size - exemplar_size) / 2 52 | pad = d_search / scale_z 53 | s_x = s_z + 2 * pad 54 | 55 | z = crop_hwc(image, pos_s_2_bbox(target_pos, s_z), exemplar_size, padding) 56 | x = crop_hwc(image, pos_s_2_bbox(target_pos, s_x), instanc_size, padding) 57 | return z, x 58 | 59 | 60 | def crop_video(sub_set, video, crop_path, instanc_size): 61 | video_crop_base_path = join(crop_path, sub_set, video) 62 | if not isdir(video_crop_base_path): makedirs(video_crop_base_path) 63 | 64 | sub_set_base_path = join(ann_base_path, sub_set) 65 | xmls = sorted(glob.glob(join(sub_set_base_path, video, '*.xml'))) 66 | for xml in xmls: 67 | xmltree = ET.parse(xml) 68 | # size = xmltree.findall('size')[0] 69 | # frame_sz = [int(it.text) for it in size] 70 | objects = xmltree.findall('object') 71 | objs = [] 72 | filename = xmltree.findall('filename')[0].text 73 | 74 | im = cv2.imread(xml.replace('xml', 'JPEG').replace('Annotations', 'Data')) 75 | avg_chans = np.mean(im, axis=(0, 1)) 76 | for object_iter in objects: 77 | trackid = int(object_iter.find('trackid').text) 78 | # name = (object_iter.find('name')).text 79 | bndbox = object_iter.find('bndbox') 80 | # occluded = int(object_iter.find('occluded').text) 81 | 82 | bbox = [int(bndbox.find('xmin').text), int(bndbox.find('ymin').text), 83 | int(bndbox.find('xmax').text), int(bndbox.find('ymax').text)] 84 | z, x = crop_like_SiamFC(im, bbox, instanc_size=instanc_size, padding=avg_chans) 85 | cv2.imwrite(join(video_crop_base_path, '{:06d}.{:02d}.z.jpg'.format(int(filename), trackid)), z) 86 | cv2.imwrite(join(video_crop_base_path, '{:06d}.{:02d}.x.jpg'.format(int(filename), trackid)), x) 87 | 88 | 89 | def main(instanc_size=511, num_threads=24): 90 | crop_path = './crop{:d}'.format(instanc_size) 91 | if not isdir(crop_path): mkdir(crop_path) 92 | 93 | for sub_set in sub_sets: 94 | sub_set_base_path = join(ann_base_path, sub_set) 95 | videos = sorted(listdir(sub_set_base_path)) 96 | n_videos = len(videos) 97 | with futures.ProcessPoolExecutor(max_workers=num_threads) as executor: 98 | fs = [executor.submit(crop_video, sub_set, video, crop_path, instanc_size) for video in videos] 99 | for i, f in enumerate(futures.as_completed(fs)): 100 | # Write 
progress to error so that it can be seen 101 | printProgress(i, n_videos, prefix=sub_set, suffix='Done ', barLength=40) 102 | 103 | 104 | if __name__ == '__main__': 105 | since = time.time() 106 | main(int(sys.argv[1]), int(sys.argv[2])) 107 | time_elapsed = time.time() - since 108 | print('Total complete in {:.0f}m {:.0f}s'.format( 109 | time_elapsed // 60, time_elapsed % 60)) 110 | -------------------------------------------------------------------------------- /tracker/BAN/train_dataset/vid/parse_vid.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | from os import listdir 3 | import json 4 | import glob 5 | import xml.etree.ElementTree as ET 6 | 7 | VID_base_path = './ILSVRC2015' 8 | ann_base_path = join(VID_base_path, 'Annotations/VID/train/') 9 | img_base_path = join(VID_base_path, 'Data/VID/train/') 10 | sub_sets = sorted({'a', 'b', 'c', 'd', 'e'}) 11 | 12 | vid = [] 13 | for sub_set in sub_sets: 14 | sub_set_base_path = join(ann_base_path, sub_set) 15 | videos = sorted(listdir(sub_set_base_path)) 16 | s = [] 17 | for vi, video in enumerate(videos): 18 | print('subset: {} video id: {:04d} / {:04d}'.format(sub_set, vi, len(videos))) 19 | v = dict() 20 | v['base_path'] = join(sub_set, video) 21 | v['frame'] = [] 22 | video_base_path = join(sub_set_base_path, video) 23 | xmls = sorted(glob.glob(join(video_base_path, '*.xml'))) 24 | for xml in xmls: 25 | f = dict() 26 | xmltree = ET.parse(xml) 27 | size = xmltree.findall('size')[0] 28 | frame_sz = [int(it.text) for it in size] 29 | objects = xmltree.findall('object') 30 | objs = [] 31 | for object_iter in objects: 32 | trackid = int(object_iter.find('trackid').text) 33 | name = (object_iter.find('name')).text 34 | bndbox = object_iter.find('bndbox') 35 | occluded = int(object_iter.find('occluded').text) 36 | o = dict() 37 | o['c'] = name 38 | o['bbox'] = [int(bndbox.find('xmin').text), int(bndbox.find('ymin').text), 39 | int(bndbox.find('xmax').text), int(bndbox.find('ymax').text)] 40 | o['trackid'] = trackid 41 | o['occ'] = occluded 42 | objs.append(o) 43 | f['frame_sz'] = frame_sz 44 | f['img_path'] = xml.split('/')[-1].replace('xml', 'JPEG') 45 | f['objs'] = objs 46 | v['frame'].append(f) 47 | s.append(v) 48 | vid.append(s) 49 | print('save json (raw vid info), please wait 1 min~') 50 | json.dump(vid, open('vid.json', 'w'), indent=4, sort_keys=True) 51 | print('done!') 52 | -------------------------------------------------------------------------------- /tracker/BAN/train_dataset/vid/readme.md: -------------------------------------------------------------------------------- 1 | # Preprocessing VID(Object detection from video) 2 | Large Scale Visual Recognition Challenge 2015 (ILSVRC2015) 3 | 4 | ### Download dataset (86GB) 5 | 6 | ````shell 7 | wget http://bvisionweb1.cs.unc.edu/ilsvrc2015/ILSVRC2015_VID.tar.gz 8 | tar -xzvf ./ILSVRC2015_VID.tar.gz 9 | ln -sfb $PWD/ILSVRC2015/Annotations/VID/train/ILSVRC2015_VID_train_0000 ILSVRC2015/Annotations/VID/train/a 10 | ln -sfb $PWD/ILSVRC2015/Annotations/VID/train/ILSVRC2015_VID_train_0001 ILSVRC2015/Annotations/VID/train/b 11 | ln -sfb $PWD/ILSVRC2015/Annotations/VID/train/ILSVRC2015_VID_train_0002 ILSVRC2015/Annotations/VID/train/c 12 | ln -sfb $PWD/ILSVRC2015/Annotations/VID/train/ILSVRC2015_VID_train_0003 ILSVRC2015/Annotations/VID/train/d 13 | ln -sfb $PWD/ILSVRC2015/Annotations/VID/val ILSVRC2015/Annotations/VID/train/e 14 | 15 | ln -sfb $PWD/ILSVRC2015/Data/VID/train/ILSVRC2015_VID_train_0000 
ILSVRC2015/Data/VID/train/a 16 | ln -sfb $PWD/ILSVRC2015/Data/VID/train/ILSVRC2015_VID_train_0001 ILSVRC2015/Data/VID/train/b 17 | ln -sfb $PWD/ILSVRC2015/Data/VID/train/ILSVRC2015_VID_train_0002 ILSVRC2015/Data/VID/train/c 18 | ln -sfb $PWD/ILSVRC2015/Data/VID/train/ILSVRC2015_VID_train_0003 ILSVRC2015/Data/VID/train/d 19 | ln -sfb $PWD/ILSVRC2015/Data/VID/val ILSVRC2015/Data/VID/train/e 20 | ```` 21 | 22 | ### Crop & Generate data info (20 min) 23 | 24 | ````shell 25 | python parse_vid.py 26 | 27 | #python par_crop.py [crop_size] [num_threads] 28 | python par_crop.py 511 12 29 | python gen_json.py 30 | ```` 31 | -------------------------------------------------------------------------------- /tracker/BAN/train_dataset/vid/visual.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | from os import listdir 3 | import cv2 4 | import numpy as np 5 | import glob 6 | import xml.etree.ElementTree as ET 7 | 8 | visual = False 9 | color_bar = np.random.randint(0, 255, (90, 3)) 10 | 11 | VID_base_path = './ILSVRC2015' 12 | ann_base_path = join(VID_base_path, 'Annotations/VID/train/') 13 | img_base_path = join(VID_base_path, 'Data/VID/train/') 14 | sub_sets = sorted({'a', 'b', 'c', 'd', 'e'}) 15 | for sub_set in sub_sets: 16 | sub_set_base_path = join(ann_base_path, sub_set) 17 | videos = sorted(listdir(sub_set_base_path)) 18 | for vi, video in enumerate(videos): 19 | print('subset: {} video id: {:04d} / {:04d}'.format(sub_set, vi, len(videos))) 20 | 21 | video_base_path = join(sub_set_base_path, video) 22 | xmls = sorted(glob.glob(join(video_base_path, '*.xml'))) 23 | for xml in xmls: 24 | f = dict() 25 | xmltree = ET.parse(xml) 26 | size = xmltree.findall('size')[0] 27 | frame_sz = [int(it.text) for it in size] 28 | objects = xmltree.findall('object') 29 | if visual: 30 | im = cv2.imread(xml.replace('xml', 'JPEG').replace('Annotations', 'Data')) 31 | for object_iter in objects: 32 | trackid = int(object_iter.find('trackid').text) 33 | bndbox = object_iter.find('bndbox') 34 | bbox = [int(bndbox.find('xmin').text), int(bndbox.find('ymin').text), 35 | int(bndbox.find('xmax').text), int(bndbox.find('ymax').text)] 36 | if visual: 37 | pt1 = (int(bbox[0]), int(bbox[1])) 38 | pt2 = (int(bbox[2]), int(bbox[3])) 39 | cv2.rectangle(im, pt1, pt2, color_bar[trackid], 3) 40 | if visual: 41 | cv2.imshow('img', im) 42 | cv2.waitKey(1) 43 | 44 | print('done!') 45 | --------------------------------------------------------------------------------
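The two `par_crop.py` scripts above (got10k and vid) share the same SiamFC-style cropping math. As a quick sanity check, the sketch below reruns that exemplar/search-region computation on a synthetic frame with a made-up bounding box, so the 127x127 `z` and 511x511 `x` patch sizes can be verified without downloading either dataset. The helper names (`crop_like_siamfc`, `centered`), the synthetic frame, and the example box are illustrative assumptions, not part of the repository.

````python
# Illustrative only: mirrors the crop_like_SiamFC / crop_hwc logic from the
# par_crop.py scripts on a synthetic frame, so the exemplar (z) and search (x)
# patches can be inspected without GOT-10k or VID on disk.
import cv2
import numpy as np


def crop_hwc(image, bbox, out_sz, padding=(0, 0, 0)):
    # Affine map sending [xmin, xmax] x [ymin, ymax] onto an out_sz x out_sz patch.
    a = (out_sz - 1) / (bbox[2] - bbox[0])
    b = (out_sz - 1) / (bbox[3] - bbox[1])
    mapping = np.array([[a, 0, -a * bbox[0]],
                        [0, b, -b * bbox[1]]], dtype=float)
    return cv2.warpAffine(image, mapping, (out_sz, out_sz),
                          borderMode=cv2.BORDER_CONSTANT, borderValue=padding)


def crop_like_siamfc(image, bbox, context_amount=0.5, exemplar_size=127, instance_size=511):
    # bbox is [xmin, ymin, xmax, ymax]; pad with the mean colour, as in the scripts.
    avg_chans = tuple(float(c) for c in np.mean(image, axis=(0, 1)))
    cx, cy = (bbox[0] + bbox[2]) / 2.0, (bbox[1] + bbox[3]) / 2.0
    w, h = bbox[2] - bbox[0], bbox[3] - bbox[1]
    # Context-padded target size; only the product matters, so the w/h roles match the scripts.
    s_z = np.sqrt((w + context_amount * (w + h)) * (h + context_amount * (w + h)))
    scale_z = exemplar_size / s_z
    s_x = s_z + (instance_size - exemplar_size) / scale_z  # extra search context around s_z

    def centered(s):
        return [cx - s / 2, cy - s / 2, cx + s / 2, cy + s / 2]

    z = crop_hwc(image, centered(s_z), exemplar_size, avg_chans)
    x = crop_hwc(image, centered(s_x), instance_size, avg_chans)
    return z, x


if __name__ == '__main__':
    frame = np.full((480, 640, 3), 128, dtype=np.uint8)   # synthetic grey frame
    z, x = crop_like_siamfc(frame, [300, 200, 380, 260])  # hypothetical target box
    print(z.shape, x.shape)                               # (127, 127, 3) (511, 511, 3)
````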
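Both `gen_json.py` scripts write the same nested annotation layout: video path -> zero-padded track id -> zero-padded frame id -> `[xmin, ymin, xmax, ymax]`. A minimal sketch of that structure is shown below; the video name and box values are hypothetical and only serve to make the keys visible.

````python
import json

# Illustrative only: the nested dict that got10k/gen_json.py and vid/gen_json.py
# dump into train.json / val.json, built here for one hypothetical video.
js = {}
video = 'GOT-10k_Train_split_01/GOT-10k_Train_000001'  # made-up entry
for idx, (x, y, w, h) in enumerate([(100, 120, 50, 40), (103, 122, 51, 41)]):
    frame = '%06d' % idx   # zero-padded frame index
    obj = '%02d' % 0       # GOT-10k has a single object track per video
    # groundtruth.txt stores x,y,w,h; the json stores xmin,ymin,xmax,ymax
    js.setdefault(video, {}).setdefault(obj, {})[frame] = [x, y, x + w, y + h]

print(json.dumps(js, indent=4, sort_keys=True))
````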