├── .gitignore ├── LICENSE ├── README.md ├── data ├── coco_to_lvis_synset.json └── configs │ ├── Base-RCNN-FPN.yaml │ └── LVIS-InstanceSegmentation │ └── mask_rcnn_R_101_FPN_1x.yaml ├── docs ├── challenge.md ├── detector_train.md ├── download.md ├── download_hacs_alt.md ├── evaluation.md ├── faqs.md ├── manual_download.md └── trackers.md ├── scripts ├── detectors │ ├── detectron2_infer.py │ ├── detectron2_train_net.py │ └── merge_coco_with_lvis.py ├── download │ ├── download_annotations.py │ ├── download_ava.py │ ├── download_cfg.yaml │ ├── download_hacs.py │ ├── download_helper.py │ ├── extract_frames.py │ ├── gen_checksums.py │ ├── meta │ │ ├── ava_file_names_test_v2.1.txt │ │ └── ava_file_names_trainval_v2.1.txt │ └── verify.py ├── evaluation │ ├── configs │ │ └── default.yaml │ └── evaluate.py └── trackers │ ├── single_obj │ ├── pysot_create_json_for_eval.py │ ├── pysot_trackers.py │ └── visualize.py │ └── sort │ ├── LICENSE │ ├── README.md │ ├── __init__.py │ ├── create_json_for_eval.py │ ├── requirements.txt │ ├── sort.py │ ├── sort_with_detection_id.py │ ├── track.py │ └── visualize.py ├── setup.py └── tao ├── __init__.py ├── toolkit ├── __init__.py └── tao │ ├── __init__.py │ ├── eval.py │ ├── results.py │ └── tao.py ├── trackers └── sot │ ├── base.py │ ├── pysot.py │ ├── pytracking.py │ ├── srdcf.py │ └── staple.py └── utils ├── __init__.py ├── colormap.py ├── cv2_util.py ├── detectron2 └── datasets.py ├── download.py ├── evaluation.py ├── evaluation_mota.py ├── fs.py ├── misc.py ├── parallel ├── __init__.py ├── fixed_gpu_pool.py └── pool_context.py ├── s3.py ├── video.py ├── vis.py ├── yacs_util.py └── ytdl.py /.gitignore: -------------------------------------------------------------------------------- 1 | venv* 2 | tao.egg-info 3 | .ipynb_checkpoints 4 | cache 5 | .vscode 6 | tao/data/s3_cache 7 | .mypy_cache 8 | debug/ 9 | _internal_links.yaml 10 | _pull_internal_changes.py 11 | __pycache__ 12 | .venv 13 | data/detectron_datasets 14 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | NOTE: This license applies to the code in this repository. 2 | 3 | MIT License 4 | 5 | Copyright (c) 2020 TAO Dataset 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in all 15 | copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | SOFTWARE. 
24 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TAO: A Large-Scale Benchmark for Tracking Any Object 2 | 3 | [[Paper](https://arxiv.org/abs/2005.10356)] [[Website](http://taodataset.org)] 4 | 5 | [Achal Dave](http://www.achaldave.com/), [Tarasha Khurana](http://www.cs.cmu.edu/~tkhurana/), [Pavel Tokmakov](https://pvtokmakov.github.io/home/), [Cordelia Schmid](https://thoth.inrialpes.fr/~schmid/), [Deva Ramanan](http://www.cs.cmu.edu/~deva/) 6 | 7 | ## Latest updates 8 | 9 | - \[2024.11.18\]: Updated [downloading docs](https://github.com/TAO-Dataset/tao/blob/master/docs/download.md) to point to the HuggingFace copy of the data. 10 | - \[2020.10.25\]: Added [docs](https://github.com/TAO-Dataset/tao/blob/master/docs/trackers.md#single-object-trackers) for running single-object trackers. See [here](https://github.com/TAO-Dataset/tao/issues/23#issuecomment-716068985) for how to add your own SOT tracker! 11 | - \[2020.07.10\]: The ECCV challenge is now live at the 12 | [MOTChallenge website](https://motchallenge.net/results/TAO_Challenge/)! 13 | See [here](docs/challenge.md) for more details. 14 | - \[2020.07.02\]: TAO was accepted to ECCV '20 as a spotlight presentation! 15 | - \[2020.02.20\]: We will be hosting a workshop and challenge at ECCV '20. See [here](http://taodataset.org/workshop/) for details. 16 | 17 | ## Setup 18 | 19 | 1. Clone this repo: 20 | ``` 21 | git clone https://github.com/TAO-Dataset/tao 22 | ``` 23 | 1. Install the TAO toolkit: 24 | ``` 25 | pip install git+https://github.com/TAO-Dataset/tao 26 | ``` 27 | 28 | ## Download dataset 29 | 30 | See [download instructions](./docs/download.md). 31 | 32 | ## Challenge 33 | 34 | We will be hosting a challenge at our 35 | [ECCV '20 workshop](http://taodataset.org/workshop/). See [here](docs/challenge.md) for details. 36 | 37 | ## Evaluation 38 | 39 | See [evaluation information](./docs/evaluation.md), which also covers submitting to the challenge server. 40 | 41 | ## Run baseline trackers 42 | 43 | See [tracker instructions](./docs/trackers.md). 44 | 45 | ## Questions? 46 | 47 | Please see the [FAQs](./docs/faqs.md) to check if we've anticipated your 48 | question. If not, for questions about TAO usage or the challenge, please use 49 | this Google Group: https://groups.google.com/forum/#!forum/tao-dataset/ 50 | 51 | For bug reports regarding the toolkit, annotations, or image download, please 52 | file an issue in this repository.
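After installing the toolkit and downloading the annotations, a quick sanity check needs nothing beyond the standard library, since the annotation files are COCO-style JSON. Below is a minimal sketch (the path is a placeholder; the `videos`/`images`/`annotations` keys and the per-video `metadata` are the ones used by the download scripts in this repo):

```python
import json
from collections import Counter

with open('annotations/train.json', 'r') as f:  # placeholder path
    tao = json.load(f)

print(f"{len(tao['videos'])} videos, "
      f"{len(tao['images'])} annotated frames, "
      f"{len(tao['annotations'])} annotations")

# Each video records its source dataset (AVA, BDD, Charades, HACS, ...).
print(Counter(video['metadata']['dataset'] for video in tao['videos']))
```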
53 | 54 | -------------------------------------------------------------------------------- /data/coco_to_lvis_synset.json: -------------------------------------------------------------------------------- 1 | {"bench": {"coco_cat_id": 15, "meaning": "a long seat for more than one person", "synset": "bench.n.01"}, "baseball bat": {"coco_cat_id": 39, "meaning": "an implement used in baseball by the batter", "synset": "baseball_bat.n.01"}, "kite": {"coco_cat_id": 38, "meaning": "plaything consisting of a light frame covered with tissue paper; flown in wind at end of a string", "synset": "kite.n.03"}, "orange": {"coco_cat_id": 55, "meaning": "orange (FRUIT of an orange tree)", "synset": "orange.n.01"}, "boat": {"coco_cat_id": 9, "meaning": "a vessel for travel on water", "synset": "boat.n.01"}, "carrot": {"coco_cat_id": 57, "meaning": "deep orange edible root of the cultivated carrot plant", "synset": "carrot.n.01"}, "bicycle": {"coco_cat_id": 2, "meaning": "a wheeled vehicle that has two wheels and is moved by foot pedals", "synset": "bicycle.n.01"}, "book": {"coco_cat_id": 84, "meaning": "a written work or composition that has been published", "synset": "book.n.01"}, "toothbrush": {"coco_cat_id": 90, "meaning": "small brush; has long handle; used to clean teeth", "synset": "toothbrush.n.01"}, "tie": {"coco_cat_id": 32, "meaning": "neckwear consisting of a long narrow piece of material worn under a collar and tied in knot at the front", "synset": "necktie.n.01"}, "sandwich": {"coco_cat_id": 54, "meaning": "two (or more) slices of bread with a filling between them", "synset": "sandwich.n.01"}, "toilet": {"coco_cat_id": 70, "meaning": "a plumbing fixture for defecation and urination", "synset": "toilet.n.02"}, "stop sign": {"coco_cat_id": 13, "meaning": "a traffic sign to notify drivers that they must come to a complete stop", "synset": "stop_sign.n.01"}, "wine glass": {"coco_cat_id": 46, "meaning": "a glass that has a stem and in which wine is served", "synset": "wineglass.n.01"}, "clock": {"coco_cat_id": 85, "meaning": "a timepiece that shows the time of day", "synset": "clock.n.01"}, "bear": {"coco_cat_id": 23, "meaning": "large carnivorous or omnivorous mammals with shaggy coats and claws", "synset": "bear.n.01"}, "vase": {"coco_cat_id": 86, "meaning": "an open jar of glass or porcelain used as an ornament or to hold flowers", "synset": "vase.n.01"}, "microwave": {"coco_cat_id": 78, "meaning": "kitchen appliance that cooks food by passing an electromagnetic wave through it", "synset": "microwave.n.02"}, "oven": {"coco_cat_id": 79, "meaning": "kitchen appliance used for baking or roasting", "synset": "oven.n.01"}, "cake": {"coco_cat_id": 61, "meaning": "baked goods made from or based on a mixture of flour, sugar, eggs, and fat", "synset": "cake.n.03"}, "apple": {"coco_cat_id": 53, "meaning": "fruit with red or yellow or green skin and sweet to tart crisp whitish flesh", "synset": "apple.n.01"}, "bed": {"coco_cat_id": 65, "meaning": "a piece of furniture that provides a place to sleep", "synset": "bed.n.01"}, "skis": {"coco_cat_id": 35, "meaning": "sports equipment for skiing on snow", "synset": "ski.n.01"}, "dining table": {"coco_cat_id": 67, "meaning": "a table at which meals are served", "synset": "dining_table.n.01"}, "remote": {"coco_cat_id": 75, "meaning": "a device that can be used to control a machine or apparatus from a distance", "synset": "remote_control.n.01"}, "bird": {"coco_cat_id": 16, "meaning": "animal characterized by feathers and wings", "synset": "bird.n.01"}, "laptop": 
{"coco_cat_id": 73, "meaning": "a portable computer small enough to use in your lap", "synset": "laptop.n.01"}, "train": {"coco_cat_id": 7, "meaning": "public or private transport provided by a line of railway cars coupled together and drawn by a locomotive", "synset": "train.n.01"}, "mouse": {"coco_cat_id": 74, "meaning": "a computer input device that controls an on-screen pointer", "synset": "mouse.n.04"}, "pizza": {"coco_cat_id": 59, "meaning": "Italian open pie made of thin bread dough spread with a spiced mixture of e.g. tomato sauce and cheese", "synset": "pizza.n.01"}, "toaster": {"coco_cat_id": 80, "meaning": "a kitchen appliance (usually electric) for toasting bread", "synset": "toaster.n.02"}, "cell phone": {"coco_cat_id": 77, "meaning": "a hand-held mobile telephone", "synset": "cellular_telephone.n.01"}, "person": {"coco_cat_id": 1, "meaning": "a human being", "synset": "person.n.01"}, "sports ball": {"coco_cat_id": 37, "meaning": "a spherical object used as a plaything", "synset": "ball.n.06"}, "fire hydrant": {"coco_cat_id": 11, "meaning": "an upright hydrant for drawing water to use in fighting a fire", "synset": "fireplug.n.01"}, "umbrella": {"coco_cat_id": 28, "meaning": "a lightweight handheld collapsible canopy", "synset": "umbrella.n.01"}, "truck": {"coco_cat_id": 8, "meaning": "an automotive vehicle suitable for hauling", "synset": "truck.n.01"}, "knife": {"coco_cat_id": 49, "meaning": "tool with a blade and point used as a cutting instrument", "synset": "knife.n.01"}, "baseball glove": {"coco_cat_id": 40, "meaning": "the handwear used by fielders in playing baseball", "synset": "baseball_glove.n.01"}, "giraffe": {"coco_cat_id": 25, "meaning": "tall animal having a spotted coat and small horns and very long neck and legs", "synset": "giraffe.n.01"}, "airplane": {"coco_cat_id": 5, "meaning": "an aircraft that has a fixed wing and is powered by propellers or jets", "synset": "airplane.n.01"}, "parking meter": {"coco_cat_id": 14, "meaning": "a coin-operated timer located next to a parking space", "synset": "parking_meter.n.01"}, "couch": {"coco_cat_id": 63, "meaning": "an upholstered seat for more than one person", "synset": "sofa.n.01"}, "tennis racket": {"coco_cat_id": 43, "meaning": "a racket used to play tennis", "synset": "tennis_racket.n.01"}, "backpack": {"coco_cat_id": 27, "meaning": "a bag carried by a strap on your back or shoulder", "synset": "backpack.n.01"}, "hot dog": {"coco_cat_id": 58, "meaning": "a smooth-textured sausage, usually smoked, often served on a bread roll", "synset": "frank.n.02"}, "banana": {"coco_cat_id": 52, "meaning": "elongated crescent-shaped yellow fruit with soft sweet flesh", "synset": "banana.n.02"}, "bowl": {"coco_cat_id": 51, "meaning": "a dish that is round and open at the top for serving foods", "synset": "bowl.n.03"}, "skateboard": {"coco_cat_id": 41, "meaning": "a board with wheels that is ridden in a standing or crouching position and propelled by foot", "synset": "skateboard.n.01"}, "bottle": {"coco_cat_id": 44, "meaning": "a glass or plastic vessel used for storing drinks or other liquids", "synset": "bottle.n.01"}, "dog": {"coco_cat_id": 18, "meaning": "a common domesticated dog", "synset": "dog.n.01"}, "frisbee": {"coco_cat_id": 34, "meaning": "a light, plastic disk propelled with a flip of the wrist for recreation or competition", "synset": "frisbee.n.01"}, "broccoli": {"coco_cat_id": 56, "meaning": "plant with dense clusters of tight green flower buds", "synset": "broccoli.n.01"}, "elephant": {"coco_cat_id": 22, 
"meaning": "a common elephant", "synset": "elephant.n.01"}, "car": {"coco_cat_id": 3, "meaning": "a motor vehicle with four wheels", "synset": "car.n.01"}, "donut": {"coco_cat_id": 60, "meaning": "a small ring-shaped friedcake", "synset": "doughnut.n.02"}, "suitcase": {"coco_cat_id": 33, "meaning": "cases used to carry belongings when traveling", "synset": "bag.n.06"}, "cup": {"coco_cat_id": 47, "meaning": "a small open container usually used for drinking; usually has a handle", "synset": "cup.n.01"}, "hair drier": {"coco_cat_id": 89, "meaning": "a hand-held electric blower that can blow warm air onto the hair", "synset": "hand_blower.n.01"}, "surfboard": {"coco_cat_id": 42, "meaning": "a narrow buoyant board for riding surf", "synset": "surfboard.n.01"}, "traffic light": {"coco_cat_id": 10, "meaning": "a device to control vehicle traffic often consisting of three or more lights", "synset": "traffic_light.n.01"}, "tv": {"coco_cat_id": 72, "meaning": "an electronic device that receives television signals and displays them on a screen", "synset": "television_receiver.n.01"}, "spoon": {"coco_cat_id": 50, "meaning": "a piece of cutlery with a shallow bowl-shaped container and a handle", "synset": "spoon.n.01"}, "horse": {"coco_cat_id": 19, "meaning": "a common horse", "synset": "horse.n.01"}, "motorcycle": {"coco_cat_id": 4, "meaning": "a motor vehicle with two wheels and a strong frame", "synset": "motorcycle.n.01"}, "zebra": {"coco_cat_id": 24, "meaning": "any of several fleet black-and-white striped African equines", "synset": "zebra.n.01"}, "cat": {"coco_cat_id": 17, "meaning": "a domestic house cat", "synset": "cat.n.01"}, "teddy bear": {"coco_cat_id": 88, "meaning": "plaything consisting of a child's toy bear (usually plush and stuffed with soft materials)", "synset": "teddy.n.01"}, "handbag": {"coco_cat_id": 31, "meaning": "a container used for carrying money and small personal items or accessories", "synset": "bag.n.04"}, "sink": {"coco_cat_id": 81, "meaning": "plumbing fixture consisting of a water basin fixed to a wall or floor and having a drainpipe", "synset": "sink.n.01"}, "keyboard": {"coco_cat_id": 76, "meaning": "a keyboard that is a data input device for computers", "synset": "computer_keyboard.n.01"}, "bus": {"coco_cat_id": 6, "meaning": "a vehicle carrying many passengers; used for public transport", "synset": "bus.n.01"}, "fork": {"coco_cat_id": 48, "meaning": "cutlery used for serving and eating food", "synset": "fork.n.01"}, "chair": {"coco_cat_id": 62, "meaning": "a seat for one person, with a support for the back", "synset": "chair.n.01"}, "refrigerator": {"coco_cat_id": 82, "meaning": "a refrigerator in which the coolant is pumped around by an electric motor", "synset": "electric_refrigerator.n.01"}, "scissors": {"coco_cat_id": 87, "meaning": "a tool having two crossed pivoting blades with looped handles", "synset": "scissors.n.01"}, "sheep": {"coco_cat_id": 20, "meaning": "woolly usually horned ruminant mammal related to the goat", "synset": "sheep.n.01"}, "potted plant": {"coco_cat_id": 64, "meaning": "a container in which plants are cultivated", "synset": "pot.n.04"}, "snowboard": {"coco_cat_id": 36, "meaning": "a board that resembles a broad ski or a small surfboard; used in a standing position to slide down snow-covered slopes", "synset": "snowboard.n.01"}, "cow": {"coco_cat_id": 21, "meaning": "cattle that are reared for their meat", "synset": "beef.n.01"}} -------------------------------------------------------------------------------- 
/data/configs/Base-RCNN-FPN.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "GeneralizedRCNN" 3 | BACKBONE: 4 | NAME: "build_resnet_fpn_backbone" 5 | RESNETS: 6 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 7 | FPN: 8 | IN_FEATURES: ["res2", "res3", "res4", "res5"] 9 | ANCHOR_GENERATOR: 10 | SIZES: [[32], [64], [128], [256], [512]] # One size for each in feature map 11 | ASPECT_RATIOS: [[0.5, 1.0, 2.0]] # Three aspect ratios (same for all in feature maps) 12 | RPN: 13 | IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"] 14 | PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level 15 | PRE_NMS_TOPK_TEST: 1000 # Per FPN level 16 | # Detectron1 uses 2000 proposals per-batch, 17 | # (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue) 18 | # which is approximately 1000 proposals per-image since the default batch size for FPN is 2. 19 | POST_NMS_TOPK_TRAIN: 1000 20 | POST_NMS_TOPK_TEST: 1000 21 | ROI_HEADS: 22 | NAME: "StandardROIHeads" 23 | IN_FEATURES: ["p2", "p3", "p4", "p5"] 24 | ROI_BOX_HEAD: 25 | NAME: "FastRCNNConvFCHead" 26 | NUM_FC: 2 27 | POOLER_RESOLUTION: 7 28 | ROI_MASK_HEAD: 29 | NAME: "MaskRCNNConvUpsampleHead" 30 | NUM_CONV: 4 31 | POOLER_RESOLUTION: 14 32 | DATASETS: 33 | TRAIN: ("coco_2017_train",) 34 | TEST: ("coco_2017_val",) 35 | SOLVER: 36 | IMS_PER_BATCH: 16 37 | BASE_LR: 0.02 38 | STEPS: (60000, 80000) 39 | MAX_ITER: 90000 40 | INPUT: 41 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) 42 | -------------------------------------------------------------------------------- /data/configs/LVIS-InstanceSegmentation/mask_rcnn_R_101_FPN_1x.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | MODEL: 3 | WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" 4 | MASK_ON: True 5 | RESNETS: 6 | DEPTH: 101 7 | ROI_HEADS: 8 | NUM_CLASSES: 1230 9 | SCORE_THRESH_TEST: 0.0001 10 | INPUT: 11 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) 12 | DATASETS: 13 | TRAIN: ("lvis_v0.5_train",) 14 | TEST: ("lvis_v0.5_val",) 15 | TEST: 16 | DETECTIONS_PER_IMAGE: 300 # LVIS allows up to 300 17 | DATALOADER: 18 | SAMPLER_TRAIN: "RepeatFactorTrainingSampler" 19 | REPEAT_THRESHOLD: 0.001 20 | -------------------------------------------------------------------------------- /docs/challenge.md: -------------------------------------------------------------------------------- 1 | # TAO ECCV'20 Multi-Object Tracking Challenge 2 | 3 | We are excited to host a challenge on TAO as part of our 4 | [ECCV workshop](http://taodataset.org/workshop/). 5 | The challenge is hosted on the [motchallenge.net](https://motchallenge.net/) website: 6 | [link](https://motchallenge.net/results/ECCV_2020_TAO_Challenge/). 7 | 8 | ## Important Dates 9 | 10 | - July 10: Challenge released! 11 | - August 16: Challenge closes, winners contacted to prepare presentation for ECCV workshop. 12 | - August 23: ECCV workshop date. Challenge results announced, along with 13 | presentations by challenge submission authors. 14 | 15 | ## Prizes 16 | 17 | We will have the following prizes for the winning entries! 18 | 19 | - First place: $1,500 cash prize, presentation at ECCV workshop. 20 | - Second place: $500 cash prize, presentation at ECCV workshop. 21 | - Honorable mention(s): $250 cash prize, presentation at ECCV workshop. 22 | 23 | ## Protocol 24 | 25 | - **Evaluation data**: The ECCV '20 challenge evaluates multi-object tracking 26 | on the TAO test set. 
27 | 28 | - **Training data**: We do not impose any restrictions on the training data used for 29 | submissions, except that the TAO test videos may not be used for training in any way. 30 | This explicitly precludes, for example, unsupervised training on the TAO test set. 31 | However, the TAO validation videos may be used for training in a supervised or 32 | unsupervised manner. 33 | We encourage training on the LVIS v0.5 dataset, which provides 34 | ample detection training data for categories evaluated in TAO. 35 | 36 | - **WARNING**: The TAO test set contains sequences from existing datasets, which 37 | must be excluded from training. These sequences are listed in the test set 38 | json. In particular, a number of LaSOT training sequences are present in the TAO 39 | test set. 40 | 41 | - For submission instructions, see [evaluation.md](evaluation.md). 42 | 43 | 44 | ## FAQs 45 | 46 | Please see [faqs.md](./faqs.md). 47 | -------------------------------------------------------------------------------- /docs/detector_train.md: -------------------------------------------------------------------------------- 1 | # Training your own detectors 2 | 3 | To train your own detectors, follow the steps below: 4 | 5 | 1. Download the LVIS v0.5 annotations and (LVIS v0.5 + COCO) training 6 | annotations from 7 | [here](https://drive.google.com/file/d/1rPSSIVSer7pweyJS-uqAfIF59uZVJ0Nx/view), 8 | and extract them to `./data/detectron_datasets/lvis-coco`. 9 | 10 | 1. Set up [detectron2](https://github.com/facebookresearch/detectron2). 11 | 12 | 1. Download the COCO `train2017` and `val2017` datasets, and link them to: 13 | 14 | ``` 15 | ./data/detectron_datasets/lvis-coco/train2017 16 | ./data/detectron_datasets/lvis-coco/val2017 17 | ``` 18 | 19 | 1. Use the provided `./scripts/detectors/detectron2_train_net.py` script to 20 | train your detector: 21 | 22 | ``` 23 | python scripts/detectors/detectron2_train_net.py \ 24 | --num-gpus 8 \ 25 | --config-file ./data/configs/LVIS-InstanceSegmentation/mask_rcnn_R_101_FPN_1x.yaml \ 26 | DATASETS.TRAIN "('lvis_v0.5_coco_2017_train', )" \ 27 | OUTPUT_DIR /path/to/output-dir 28 | ``` 29 | 30 | This script was tested with detectron2 commit 31 | fd87af71eebc660dde2f50e4693869bb04f66015. 32 | 33 | -------------------------------------------------------------------------------- /docs/download.md: -------------------------------------------------------------------------------- 1 | # Download TAO 2 | Follow the instructions at our [HuggingFace repo](https://huggingface.co/datasets/chengyenhsieh/TAO-Amodal) to download the TAO videos. 3 | You can download the annotations for the original TAO dataset from [this link](https://motchallenge.net/data/TAOLabels.zip). 4 | 5 | ## Request video deletion 6 | 7 | If you would like to request a video be deleted from TAO (e.g., because you are 8 | featured in the video or you own the rights), please email me at 9 | achalddave@gmail.com. 10 | 11 | -------------------------------------------------------------------------------- /docs/download_hacs_alt.md: -------------------------------------------------------------------------------- 1 | Download and extract from YouTube. 2 | 3 | ``` 4 | python scripts/download/download_hacs.py $TAO_ROOT --split train 5 | ``` 6 | 7 | You can ignore YoutubeDL errors that are printed by this script (e.g., Video not 8 | available). Videos that could not be downloaded will be collected in 9 | `$TAO_ROOT/hacs_missing/missing.txt`.
You can request the original HACS videos 10 | by filling out this form: https://forms.gle/hZD612H5TXDQDozv9 11 | -------------------------------------------------------------------------------- /docs/evaluation.md: -------------------------------------------------------------------------------- 1 | # Evaluating Trackers 2 | 3 | ## Results format 4 | 5 | The TAO toolkit expects results in the same format as COCO, but with additional 6 | `track_id` and `video_id` fields. Specifically, `results.json` should have the 7 | following format: 8 | 9 | ``` 10 | [{ 11 | "image_id" : int, 12 | "category_id" : int, 13 | "bbox" : [x,y,width,height], 14 | "score" : float, 15 | "track_id": int, 16 | "video_id": int 17 | }] 18 | ``` 19 | 20 | 21 | ## Evaluation (toolkit) 22 | 23 | The TAO toolkit provides code for evaluating tracker results. 24 | 25 | ```python 26 | import logging 27 | from tao.toolkit.tao import TaoEval 28 | 29 | # TAO uses logging to print results. Make sure logging is set to show INFO 30 | # messages, or you won't see any evaluation results. 31 | logging.getLogger().setLevel(logging.INFO) 32 | tao_eval = TaoEval('/path/to/annotations.json', '/path/to/results.json') 33 | tao_eval.run() 34 | tao_eval.print_results() 35 | ``` 36 | 37 | ## Evaluation (command-line) 38 | 39 | TAO also comes with a higher-level `evaluate.py` script which incorporates 40 | various additional features for evaluation. 41 | 42 | In all the examples below, let: 43 | - `$ANNOTATIONS` be the `/path/to/annotations.json` 44 | - `$RESULTS` be the `/path/to/results.json` 45 | - `$OUTPUT_DIR` be the `/path/to/output/logdir`. 46 | 47 | We demonstrate some features below; for more, take a look at the config 48 | description in [`./tao/utils/evaluation.py`](/tao/utils/evaluation.py). 49 | 50 | - Simple evaluation, with logging to an output directory 51 | 52 | ```bash 53 | python scripts/evaluation/evaluate.py \ 54 | $ANNOTATIONS $RESULTS --output-dir $OUTPUT_DIR 55 | ``` 56 | 57 | - Classification oracle
58 | 59 | ```bash 60 | python scripts/evaluation/evaluate.py \ 61 | $ANNOTATIONS $RESULTS --output-dir $OUTPUT_DIR \ 62 | --config-updates ORACLE.TYPE class 63 | ``` 64 |

65 | 66 | -
Track oracle (for linking detections)

67 | 68 | ```bash 69 | python scripts/evaluation/evaluate.py \ 70 | $ANNOTATIONS $RESULTS --output-dir $OUTPUT_DIR \ 71 | --config-updates ORACLE.TYPE track 72 | ``` 73 |

74 | 75 | -
Evaluate MOTA

76 | 77 | ```bash 78 | python scripts/evaluation/evaluate.py \ 79 | $ANNOTATIONS $RESULTS --output-dir $OUTPUT_DIR \ 80 | --config-updates MOTA.ENABLED True 81 | ``` 82 |

83 | 84 | -
Evaluate at (3D) IoU threshold of 0.9

85 | 86 | ```bash 87 | python scripts/evaluation/evaluate.py \ 88 | $ANNOTATIONS $RESULTS --output-dir $OUTPUT_DIR \ 89 | --config-updates EVAL_IOUS "[0.9]" 90 | ``` 91 |

92 | 93 | -
Evaluate at multiple (3D) IoU thresholds

94 | 95 | ```bash 96 | python scripts/evaluation/evaluate.py \ 97 | $ANNOTATIONS $RESULTS --output-dir $OUTPUT_DIR \ 98 | --config-updates \ 99 | EVAL_IOUS "[0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95]" 100 | ``` 101 |

102 | 103 | -
Category agnostic evaluation

104 | 105 | ```bash 106 | python scripts/evaluation/evaluate.py \ 107 | $ANNOTATIONS $RESULTS --output-dir $OUTPUT_DIR \ 108 | --config-updates CATEGORY_AGNOSTIC True 109 | ``` 110 |

111 | 112 | -
Report evaluation by source dataset

113 | 114 | ```bash 115 | python scripts/evaluation/evaluate.py \ 116 | $ANNOTATIONS $RESULTS --output-dir $OUTPUT_DIR \ 117 | --config-updates EVAL_BY_DATASET True 118 | ``` 119 |

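The `--config-updates` flag takes flat `KEY VALUE` pairs that are merged into the evaluation config (the repo ships a yacs-style setup: `scripts/evaluation/configs/default.yaml` and `tao/utils/yacs_util.py`). Below is a minimal sketch of the underlying mechanism, assuming a yacs `CfgNode`; the default values shown are illustrative placeholders, not the real defaults:

```python
from yacs.config import CfgNode as CN

# Illustrative defaults; the real ones live in
# scripts/evaluation/configs/default.yaml.
cfg = CN()
cfg.ORACLE = CN()
cfg.ORACLE.TYPE = 'none'
cfg.MOTA = CN()
cfg.MOTA.ENABLED = False
cfg.CATEGORY_AGNOSTIC = False
cfg.EVAL_IOUS = [0.5, 0.75]

# `--config-updates ORACLE.TYPE class MOTA.ENABLED True` amounts to:
cfg.merge_from_list(['ORACLE.TYPE', 'class', 'MOTA.ENABLED', 'True'])
assert cfg.MOTA.ENABLED is True  # yacs coerces 'True' to a bool
```

Note that yacs only merges keys that already exist in the default config, so a typo in a key name fails loudly rather than being silently ignored.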
120 | 121 | ## Evaluation (challenge server) 122 | 123 | For local evaluation, evaluate with steps above on the released validation 124 | set. When submitting test set results to the 125 | [challenge server](https://motchallenge.net/login/), follow same format for 126 | json files as mentioned above. 127 | 128 | The server requires you to submit train, validation and test set results. 129 | We request you to submit these three json files for facilitating progress in 130 | the tracking community. However, if absolutely necessary, submit empty json 131 | files for train and validation. Create a .zip archive that deflates into the 132 | following files 133 | 134 | ```bash 135 | ./TAO_test.json 136 | ./TAO_train.json 137 | ./TAO_val.json 138 | ``` 139 | 140 | ## More details 141 | 142 | ### Merged classes 143 | A few classes from LVIS (v0.5) are merged in TAO, as they are nearly-synonymous. 144 | As such, methods are given credit for predicting any one of the merged classes. 145 | These classes are marked in the annotations, with a `merged` key in the `category` dictionary. 146 | The following code constructs the list of merged classes from the annotations file: 147 | https://github.com/TAO-Dataset/tao/blob/63d04f3b62bd0656614902206bd5e0d1e801dc26/tao/toolkit/tao/tao.py#L96-L104 148 | -------------------------------------------------------------------------------- /docs/faqs.md: -------------------------------------------------------------------------------- 1 | # Frequently asked questions 2 | 3 | 1. Why does the training set only contain 216 LVIS categories? 4 | 5 | TAO contains a total of 482 LVIS categories. However, not all categories 6 | are present in the train, val, and test sets. Instead, we encourage researchers to 7 | train detectors on the LVIS v0.5 dataset, which contains a superset of 8 | the 482 categories, and trackers on existing single-object tracking datasets. 9 | TAO is primarily a benchmark dataset, but we provide a small set of training videos 10 | for tuning trackers. 11 | 12 | 1. Why do the LVIS v1 dataset categories not match with the TAO categories? 13 | 14 | Tao was constructed to be aligned with the LVIS v0.5 dataset. The LVIS v1 update 15 | changes the category names and ids in the LVIS dataset. We are looking into updating 16 | TAO to use the LVIS v1 categories. For now, you may either train on the LVIS v0.5 17 | dataset, or construct your own mapping from LVIS v1 categories to TAO categories 18 | using the 'synset' field. 19 | 20 | 1. Is there any restriction on which data I can train on? 21 | 22 | The only restriction is that you may not train on videos in the TAO test set. 23 | You can see a list of videos in the TAO test set from the test set json file 24 | shared with the annotations. In particular, a number of LaSOT training videos 25 | are in the TAO test set, and must not be used for training. 26 | 27 | Apart from this, there are currently no restrictions on training datasets. 28 | 29 | 1. Are only LVIS categories evaluated in TAO? 30 | 31 | Currently (as of July 2020), we are focusing on the LVIS categories within TAO. 32 | The ECCV challenge will only evaluate on these categories. We intend to formalize 33 | a protocol for evaluation on the non-LVIS categories later this year. 34 | 35 | 1. Is there a single-object tracking track in the ECCV '20 challenge? 36 | 37 | Currently, there is no single-object / user-initialized tracking track in 38 | the challenge. 
-------------------------------------------------------------------------------- /docs/faqs.md: -------------------------------------------------------------------------------- 1 | # Frequently asked questions 2 | 3 | 1. Why does the training set only contain 216 LVIS categories? 4 | 5 | TAO contains a total of 482 LVIS categories. However, not all categories 6 | are present in the train, val, and test sets. Instead, we encourage researchers to 7 | train detectors on the LVIS v0.5 dataset, which contains a superset of 8 | the 482 categories, and trackers on existing single-object tracking datasets. 9 | TAO is primarily a benchmark dataset, but we provide a small set of training videos 10 | for tuning trackers. 11 | 12 | 1. Why do the LVIS v1 categories not match the TAO categories? 13 | 14 | TAO was constructed to be aligned with the LVIS v0.5 dataset. The LVIS v1 update 15 | changed the category names and ids in the LVIS dataset. We are looking into updating 16 | TAO to use the LVIS v1 categories. For now, you may either train on the LVIS v0.5 17 | dataset, or construct your own mapping from LVIS v1 categories to TAO categories 18 | using the `synset` field. 19 | 20 | 1. Is there any restriction on which data I can train on? 21 | 22 | The only restriction is that you may not train on videos in the TAO test set. 23 | The test set json file shared with the annotations lists these videos. 24 | In particular, a number of LaSOT training videos 25 | are in the TAO test set, and must not be used for training. 26 | 27 | Apart from this, there are currently no restrictions on training datasets. 28 | 29 | 1. Are only LVIS categories evaluated in TAO? 30 | 31 | Currently (as of July 2020), we are focusing on the LVIS categories within TAO. 32 | The ECCV challenge will only evaluate on these categories. We intend to formalize 33 | a protocol for evaluation on the non-LVIS categories later this year. 34 | 35 | 1. Is there a single-object tracking track in the ECCV '20 challenge? 36 | 37 | Currently, there is no single-object / user-initialized tracking track in 38 | the challenge. We are looking into ways to host a challenge for user-initialized 39 | tracking on held-out data (e.g., by asking researchers to submit code which we run 40 | locally on the held-out test set). If you have any suggestions or 41 | feedback, please contact us! 42 | -------------------------------------------------------------------------------- /docs/manual_download.md: -------------------------------------------------------------------------------- 1 | These are alternative instructions that mimic the helper script in 2 | [scripts/download/download_helper.py](/scripts/download/download_helper.py), 3 | in case the helper script causes issues. Please read 4 | [./download.md](./download.md) first. 5 | 6 | 1. Download TAO annotations to $TAO_DIR 7 | 8 | ``` 9 | wget 'https://github.com/TAO-Dataset/annotations/archive/v1.2.tar.gz' 10 | tar xzvf v1.2.tar.gz 11 | mv annotations-1.2 annotations 12 | ``` 13 | 14 | 1. Extract frames from BDD, Charades, HACS and YFCC-100M. 15 | 16 | ``` 17 | python scripts/download/extract_frames.py $TAO_ROOT --split train 18 | ``` 19 |
After this, your directory should have the following structure:

20 | 21 | ``` 22 | ├── frames 23 | │ └── train 24 | │ ├── ArgoVerse 25 | │ ├── BDD 26 | │ ├── Charades 27 | │ ├── HACS 28 | │ ├── LaSOT 29 | │ └── YFCC100M 30 | └── videos 31 | └── train 32 | ├── BDD 33 | ├── Charades 34 | ├── HACS 35 | └── YFCC100M 36 | ``` 37 |

38 | 39 | 1. Download and extract frames from AVA: 40 | 41 | ``` 42 | python scripts/download/download_ava.py $TAO_ROOT --split train 43 | ``` 44 | 45 | 1. Finally, you can verify that you have downloaded TAO. 46 | 47 |
The expected directory structure is:

48 | 49 | ``` 50 | ├── frames 51 | │ └── train 52 | │ ├── ArgoVerse 53 | │ ├── AVA 54 | │ ├── BDD 55 | │ ├── Charades 56 | │ ├── HACS 57 | │ ├── LaSOT 58 | │ └── YFCC100M 59 | └── videos 60 | └── train 61 | ├── BDD 62 | ├── Charades 63 | └── YFCC100M 64 | ``` 65 |

66 | 67 | You can run the following command to check that TAO was properly extracted: 68 | 69 | ``` 70 | python scripts/download/verify.py $TAO_ROOT --split train 71 | ``` 72 | -------------------------------------------------------------------------------- /docs/trackers.md: -------------------------------------------------------------------------------- 1 | # Running trackers on TAO 2 | 3 | ## SORT 4 | 5 | Here, we will reproduce a simpler variant of the SORT result presented in TAO. 6 | Specifically, we will reproduce the following row from Table 13 in our 7 | supplementary material. 8 | 9 | | NMS Thresh | Det / image | Det score | `max_age` | `min_hits` | `min_iou` | Track mAP | 10 | | ---------- | ----------- | --------- | --------- | ---------- | --------- | --------- | 11 | | 0.5 | 300 | 0.0005 | 100 | 1 | 0.1 | 11.3 | 12 | 13 | ### Run detectors 14 | 15 | 1. Download and decompress the detection model and config from [here](https://drive.google.com/file/d/13BdXSQDqK0t-LrF2CrwJtT9lFc48u83H/view?usp=sharing) or [here](https://cdn3.vision.in.tum.de/~tao/baselines/detector-r101-fpn-1x-lvis-coco.zip) to 16 | `$DETECTRON_MODEL`. 17 | 18 | If you would like to re-train the detector, please see [this doc](./detector_train.md). 19 | 20 | 1. Set up and install 21 | [detectron2](https://github.com/facebookresearch/detectron2). 22 | 1. Run the detector on TAO: 23 | 24 | ``` 25 | python scripts/detectors/detectron2_infer.py \ 26 | --gpus 0 1 2 3 \ 27 | --root $TAO_ROOT/train \ 28 | --output /path/to/detectron2/output/train \ 29 | --config-file $DETECTRON_MODEL/config.yaml \ 30 | --opts MODEL.WEIGHTS $DETECTRON_MODEL/model_final.pth 31 | ``` 32 | 33 | On a machine with 4 2080Tis, the above took about 8 hours to run on the 34 | train set. 35 | 36 | ### Run [SORT](https://github.com/abewley/sort) 37 | 38 | ``` 39 | python scripts/trackers/sort/track.py \ 40 | --detections-dir /path/to/detectron2/output/train \ 41 | --annotations $TAO_ROOT/annotations/train.json \ 42 | --output-dir /path/to/sort/output/train \ 43 | --workers 8 44 | ``` 45 | 46 | On our machine, the above took about 11 hours to run on the train set. 47 | 48 | ### Evaluate 49 | 50 | ``` 51 | python scripts/evaluation/evaluate.py \ 52 | $TAO_ROOT/annotations/train.json \ 53 | /path/to/sort/output/train/results.json 54 | ``` 55 | 56 | This should report an AP of 11.3. 57 |
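As an aside, the per-frame detection files written by `detectron2_infer.py` (and consumed by the SORT script above) are plain pickles whose layout matches the `save()` function in that script. A small sketch of reading one back, with a placeholder path:

```python
import pickle

# Placeholder path: one .pkl is written per frame, mirroring the
# directory layout of the input frames.
with open('/path/to/detectron2/output/train/some_video/frame0001.pkl',
          'rb') as f:
    instances = pickle.load(f)['instances']

# Parallel lists, one entry per detection; boxes are [x1, y1, x2, y2].
for box, score, label in zip(instances['pred_boxes'],
                             instances['scores'],
                             instances['pred_classes']):
    if score > 0.5:
        print(label, round(score, 3), box)
```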
58 | ## Single-object trackers 59 | 60 | Here we show how to run single-object trackers from the excellent PySOT tracking 61 | repository. 62 | 63 | ### Setup 64 | 65 | 1. Download and set up the PySOT repository. This code was tested with PySOT at 66 | commit 67 | [052b96](https://github.com/STVIR/pysot/tree/052b9678a7ed336752f74dc6af31cc00eb004551). 68 | Please follow instructions from the PySOT repository for installation. 69 | 2. Add `pysot` to your `PYTHONPATH`. You can check that the following import 70 | works: 71 | 72 | ```bash 73 | python -c 'from pysot.core.config import cfg' 74 | ``` 75 | 76 | ### Download model 77 | 78 | Download configs and models from the PySOT [model 79 | zoo](https://github.com/STVIR/pysot/blob/052b9678a7ed336752f74dc6af31cc00eb004551/MODEL_ZOO.md). 80 | 81 | ### Run tracker 82 | 83 | 1. Run a single-object tracker using the first frame of a track as the init: 84 | 85 | ``` 86 | python scripts/trackers/single_obj/pysot_trackers.py \ 87 | --annotations ${TAO_ROOT}/annotations/train.json \ 88 | --frames-dir ${TAO_ROOT}/train/ \ 89 | --output-dir /path/to/pysot/output \ 90 | --config-file /path/to/pysot/repo/experiments/siamrpn_r50_l234_dwxcorr/config.yaml \ 91 | --model-path /path/to/pysot/model/siamrpn_r50_l234_dwxcorr_model.pth \ 92 | --gpus 0 1 2 3 \ 93 | --tasks-per-gpu 2 94 | ``` 95 | 96 | 2. Run the tracker with the "biggest" init strategy, as in Table 5 of [our 97 | paper](https://arxiv.org/pdf/2005.10356.pdf). To do this, you can add the 98 | `--init biggest` flag, as shown below: 99 | 100 | ``` 101 | python scripts/trackers/single_obj/pysot_trackers.py \ 102 | --annotations ${TAO_ROOT}/annotations/train.json \ 103 | --frames-dir ${TAO_ROOT}/ \ 104 | --output-dir /path/to/pysot/output \ 105 | --config-file /path/to/pysot/repo/experiments/siamrpn_r50_l234_dwxcorr/config.yaml \ 106 | --model-path /path/to/pysot/model/siamrpn_r50_l234_dwxcorr_model.pth \ 107 | --gpus 0 1 2 3 \ 108 | --tasks-per-gpu 2 \ 109 | --init biggest 110 | ``` 111 | 112 | ### Evaluate 113 | 114 | ``` 115 | python scripts/evaluation/evaluate.py \ 116 | $TAO_ROOT/annotations/train.json \ 117 | /path/to/pysot/output/train/results.json \ 118 | --config-updates SINGLE_OBJECT.ENABLED True \ 119 | THRESHOLD 0.7 120 | ``` 121 | 122 | Note that 0.7 is the tuned threshold for the `siamrpn_r50_l234_dwxcorr` model. 123 | These thresholds are tuned on the training set, as described in Appendix C.2 of our 124 | paper, with results shown in Table 16. 125 | Below are the thresholds for a few PySOT models. 126 | 127 | | Model | Threshold | 128 | | ---- | ---- | 129 | | siamrpn_r50_l234_dwxcorr | 0.7 | 130 | | siamrpn_r50_l234_dwxcorr_lt | 0.9 | 131 | | siammask_r50_l3 | 0.8 | 132 | 133 | 134 | The above command, for `siamrpn_r50_l234_dwxcorr`, should produce an AP of 31.5. 135 | 136 | -------------------------------------------------------------------------------- /scripts/detectors/detectron2_infer.py: -------------------------------------------------------------------------------- 1 | # Modified from detectron2/demo/demo.py 2 | # Copyright (c) Facebook, Inc. and its affiliates.
All Rights Reserved 3 | 4 | import argparse 5 | import logging 6 | import os 7 | import pickle 8 | from pathlib import Path 9 | 10 | import numpy as np 11 | import torch 12 | from detectron2.config import get_cfg 13 | from detectron2.data.detection_utils import read_image 14 | from detectron2.engine.defaults import DefaultPredictor 15 | from pycocotools import mask 16 | from script_utils.common import common_setup 17 | from tqdm import tqdm 18 | 19 | from tao.utils.parallel.fixed_gpu_pool import FixedGpuPool 20 | 21 | 22 | def init_model(init_args, context): 23 | os.environ['CUDA_VISIBLE_DEVICES'] = str(context['gpu']) 24 | context['predictor'] = DefaultPredictor(init_args['config']) 25 | 26 | 27 | def infer(kwargs, context): 28 | predictor = context['predictor'] 29 | image_path = kwargs['image_path'] 30 | output_path = kwargs['output_path'] 31 | img = read_image(str(image_path), format="BGR") 32 | 33 | predictions = predictor(img) 34 | predictions = predictions["instances"].get_fields() 35 | boxes_decoded = predictions["pred_boxes"].tensor.cpu().numpy().tolist() 36 | scores_decoded = predictions["scores"].cpu().numpy().tolist() 37 | classes_decoded = predictions["pred_classes"].cpu().numpy().tolist() 38 | masks_decoded = None 39 | if args.save_masks: 40 | masks_decoded = predictions["pred_masks"].cpu().numpy().astype(np.bool) 41 | save(boxes_decoded, scores_decoded, classes_decoded, masks_decoded, 42 | output_path) 43 | 44 | 45 | def save(boxes_decoded, scores_decoded, classes_decoded, masks_decoded, 46 | results_path): 47 | predictions_decoded = {} 48 | predictions_decoded["instances"] = { 49 | "pred_boxes": boxes_decoded, 50 | "scores": scores_decoded, 51 | "pred_classes": classes_decoded, 52 | } 53 | if masks_decoded is not None: 54 | rles = mask.encode( 55 | np.array(masks_decoded.transpose((1, 2, 0)), 56 | order='F', 57 | dtype=np.uint8)) 58 | for rle in rles: 59 | rle["counts"] = rle["counts"].decode("utf-8") 60 | predictions_decoded['instances']['pred_masks'] = rles 61 | with open(results_path, 'wb') as f: 62 | pickle.dump(predictions_decoded, f) 63 | 64 | 65 | def setup_cfg(args): 66 | # load config from file and command-line arguments 67 | cfg = get_cfg() 68 | cfg.merge_from_file(args.config_file) 69 | cfg.merge_from_list(args.opts) 70 | if not args.save_masks: 71 | cfg.MODEL.MASK_ON = False 72 | cfg.freeze() 73 | return cfg 74 | 75 | 76 | def get_parser(): 77 | parser = argparse.ArgumentParser(description="Detectron2 Demo") 78 | parser.add_argument("--root", required=True, type=Path) 79 | parser.add_argument("--output", 80 | required=True, 81 | type=Path, 82 | help="Directory to save output pickles.") 83 | parser.add_argument("--config-file", 84 | required=True, 85 | type=Path, 86 | help="path to config file") 87 | parser.add_argument('--gpus', default=[0], nargs='+', type=int) 88 | parser.add_argument( 89 | "--opts", 90 | help="Modify model config options using the command-line", 91 | default=[], 92 | nargs=argparse.REMAINDER) 93 | parser.add_argument( 94 | '--save-masks', default=False, action='store_true') 95 | return parser 96 | 97 | 98 | if __name__ == "__main__": 99 | args = get_parser().parse_args() 100 | Path(args.output).mkdir(exist_ok=True, parents=True) 101 | common_setup(__file__, args.output, args) 102 | # Prevent detectron from flooding terminal with messages. 
103 | logging.getLogger('detectron2.checkpoint.c2_model_loading').setLevel( 104 | logging.WARNING) 105 | logging.getLogger('fvcore.common.checkpoint').setLevel( 106 | logging.WARNING) 107 | logger = logging.root 108 | 109 | cfg = setup_cfg(args) 110 | 111 | threads_per_worker = 4 112 | torch.set_num_threads(threads_per_worker) 113 | os.environ['OMP_NUM_THREADS'] = str(threads_per_worker) 114 | 115 | all_files = args.root.rglob('*.jpg') 116 | 117 | # Arguments to init_model() 118 | init_args = {'config': cfg} 119 | 120 | # Tasks to pass to infer() 121 | infer_tasks = [] 122 | for path in tqdm(all_files, 123 | mininterval=1, 124 | dynamic_ncols=True, 125 | desc='Collecting frames'): 126 | relative = path.relative_to(args.root) 127 | output_pkl = (args.output / relative).with_suffix('.pkl') 128 | if output_pkl.exists(): 129 | continue 130 | output_pkl.parent.mkdir(exist_ok=True, parents=True) 131 | infer_tasks.append({'image_path': path, 'output_path': output_pkl}) 132 | 133 | if len(args.gpus) == 1: 134 | context = {'gpu': args.gpus[0]} 135 | init_model(init_args, context) 136 | for task in tqdm(infer_tasks, 137 | mininterval=1, 138 | desc='Running detector', 139 | dynamic_ncols=True): 140 | infer(task, context) 141 | else: 142 | pool = FixedGpuPool( 143 | args.gpus, initializer=init_model, initargs=init_args) 144 | list( 145 | tqdm(pool.imap_unordered(infer, infer_tasks), 146 | total=len(infer_tasks), 147 | mininterval=10, 148 | desc='Running detector', 149 | dynamic_ncols=True)) 150 | -------------------------------------------------------------------------------- /scripts/detectors/detectron2_train_net.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Detection Training Script. 4 | 5 | This scripts reads a given config file and runs the training or evaluation. 6 | It is an entry point that is made to train standard models in detectron2. 7 | 8 | In order to let one script support training of many models, 9 | this script contains logic that are specific to these built-in models and therefore 10 | may not be suitable for your own project. 11 | For example, your research project perhaps only needs a single "evaluator". 12 | 13 | Therefore, we recommend you to use detectron2 as an library and take 14 | this file as an example of how to use the library. 15 | You may want to write your own script with your datasets and other customizations. 16 | """ 17 | 18 | import logging 19 | import os 20 | from collections import OrderedDict 21 | import torch 22 | 23 | import detectron2.utils.comm as comm 24 | from detectron2.checkpoint import DetectionCheckpointer 25 | from detectron2.config import get_cfg 26 | from detectron2.data import MetadataCatalog 27 | from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, hooks, launch 28 | from detectron2.evaluation import ( 29 | CityscapesEvaluator, 30 | COCOEvaluator, 31 | COCOPanopticEvaluator, 32 | DatasetEvaluators, 33 | LVISEvaluator, 34 | PascalVOCDetectionEvaluator, 35 | SemSegEvaluator, 36 | verify_results, 37 | ) 38 | from detectron2.modeling import GeneralizedRCNNWithTTA 39 | 40 | import tao.utils.detectron2.datasets 41 | 42 | 43 | class Trainer(DefaultTrainer): 44 | """ 45 | We use the "DefaultTrainer" which contains a number pre-defined logic for 46 | standard training workflow. They may not work for you, especially if you 47 | are working on a new research project. 
In that case you can use the cleaner 48 | "SimpleTrainer", or write your own training loop. 49 | """ 50 | 51 | @classmethod 52 | def build_evaluator(cls, cfg, dataset_name, output_folder=None): 53 | """ 54 | Create evaluator(s) for a given dataset. 55 | This uses the special metadata "evaluator_type" associated with each builtin dataset. 56 | For your own dataset, you can simply create an evaluator manually in your 57 | script and do not have to worry about the hacky if-else logic here. 58 | """ 59 | if output_folder is None: 60 | output_folder = os.path.join(cfg.OUTPUT_DIR, "inference") 61 | evaluator_list = [] 62 | evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type 63 | if evaluator_type in ["sem_seg", "coco_panoptic_seg"]: 64 | evaluator_list.append( 65 | SemSegEvaluator( 66 | dataset_name, 67 | distributed=True, 68 | num_classes=cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES, 69 | ignore_label=cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE, 70 | output_dir=output_folder, 71 | ) 72 | ) 73 | if evaluator_type in ["coco", "coco_panoptic_seg"]: 74 | evaluator_list.append(COCOEvaluator(dataset_name, cfg, True, output_folder)) 75 | if evaluator_type == "coco_panoptic_seg": 76 | evaluator_list.append(COCOPanopticEvaluator(dataset_name, output_folder)) 77 | if evaluator_type == "cityscapes": 78 | assert ( 79 | torch.cuda.device_count() >= comm.get_rank() 80 | ), "CityscapesEvaluator currently do not work with multiple machines." 81 | return CityscapesEvaluator(dataset_name) 82 | if evaluator_type == "pascal_voc": 83 | return PascalVOCDetectionEvaluator(dataset_name) 84 | if evaluator_type == "lvis": 85 | return LVISEvaluator(dataset_name, cfg, True, output_folder) 86 | if len(evaluator_list) == 0: 87 | raise NotImplementedError( 88 | "no Evaluator for the dataset {} with the type {}".format( 89 | dataset_name, evaluator_type 90 | ) 91 | ) 92 | if len(evaluator_list) == 1: 93 | return evaluator_list[0] 94 | return DatasetEvaluators(evaluator_list) 95 | 96 | @classmethod 97 | def test_with_TTA(cls, cfg, model): 98 | logger = logging.getLogger("detectron2.trainer") 99 | # In the end of training, run an evaluation with TTA 100 | # Only support some R-CNN models. 101 | logger.info("Running inference with test-time augmentation ...") 102 | model = GeneralizedRCNNWithTTA(cfg, model) 103 | evaluators = [ 104 | cls.build_evaluator( 105 | cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA") 106 | ) 107 | for name in cfg.DATASETS.TEST 108 | ] 109 | res = cls.test(cfg, model, evaluators) 110 | res = OrderedDict({k + "_TTA": v for k, v in res.items()}) 111 | return res 112 | 113 | 114 | def setup(args): 115 | """ 116 | Create configs and perform basic setups. 
117 | """ 118 | cfg = get_cfg() 119 | cfg.merge_from_file(args.config_file) 120 | cfg.merge_from_list(args.opts) 121 | cfg.freeze() 122 | default_setup(cfg, args) 123 | return cfg 124 | 125 | 126 | def main(args): 127 | cfg = setup(args) 128 | tao.utils.detectron2.datasets.register_datasets() 129 | 130 | if args.eval_only: 131 | model = Trainer.build_model(cfg) 132 | DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( 133 | cfg.MODEL.WEIGHTS, resume=args.resume 134 | ) 135 | res = Trainer.test(cfg, model) 136 | if comm.is_main_process(): 137 | verify_results(cfg, res) 138 | if cfg.TEST.AUG.ENABLED: 139 | res.update(Trainer.test_with_TTA(cfg, model)) 140 | return res 141 | 142 | """ 143 | If you'd like to do anything fancier than the standard training logic, 144 | consider writing your own training loop or subclassing the trainer. 145 | """ 146 | trainer = Trainer(cfg) 147 | trainer.resume_or_load(resume=args.resume) 148 | if cfg.TEST.AUG.ENABLED: 149 | trainer.register_hooks( 150 | [hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))] 151 | ) 152 | return trainer.train() 153 | 154 | 155 | if __name__ == "__main__": 156 | parser = default_argument_parser() 157 | args = parser.parse_args() 158 | print("Command Line Args:", args) 159 | launch( 160 | main, 161 | args.num_gpus, 162 | num_machines=args.num_machines, 163 | machine_rank=args.machine_rank, 164 | dist_url=args.dist_url, 165 | args=(args,), 166 | ) 167 | -------------------------------------------------------------------------------- /scripts/detectors/merge_coco_with_lvis.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import itertools 3 | import json 4 | import logging 5 | from pathlib import Path 6 | 7 | import numpy as np 8 | from pycocotools.coco import COCO 9 | import pycocotools.mask as mask_util 10 | from script_utils.common import common_setup 11 | from tqdm import tqdm 12 | 13 | 14 | ROOT = Path(__file__).resolve().parent.parent.parent 15 | 16 | 17 | def main(): 18 | # Use first line of file docstring as description if it exists. 
19 | parser = argparse.ArgumentParser( 20 | description=__doc__.split('\n')[0] if __doc__ else '', 21 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 22 | parser.add_argument('--lvis', type=Path, required=True) 23 | parser.add_argument('--coco', type=Path, required=True) 24 | parser.add_argument('--mapping', 25 | type=Path, 26 | default=ROOT / 'data/coco_to_lvis_synset.json') 27 | parser.add_argument('--output-json', 28 | type=Path, 29 | required=True) 30 | parser.add_argument( 31 | '--iou-thresh', 32 | default=0.7, 33 | type=float, 34 | help=('If a COCO annotation overlaps with an LVIS annotation with ' 35 | 'IoU over this threshold, we use only the LVIS annotation.')) 36 | 37 | args = parser.parse_args() 38 | args.output_json.parent.mkdir(exist_ok=True, parents=True) 39 | common_setup(args.output_json.name + '.log', args.output_json.parent, args) 40 | 41 | coco = COCO(args.coco) 42 | lvis = COCO(args.lvis) 43 | 44 | synset_to_lvis_id = {x['synset']: x['id'] for x in lvis.cats.values()} 45 | coco_to_lvis_category = {} 46 | with open(args.mapping, 'r') as f: 47 | name_mapping = json.load(f) 48 | for category in coco.cats.values(): 49 | mapped = name_mapping[category['name']] 50 | assert mapped['coco_cat_id'] == category['id'] 51 | synset = mapped['synset'] 52 | if synset not in synset_to_lvis_id: 53 | logging.debug( 54 | f'Found no LVIS category for "{category["name"]}" from COCO') 55 | continue 56 | coco_to_lvis_category[category['id']] = synset_to_lvis_id[synset] 57 | 58 | for image_id, image in coco.imgs.items(): 59 | if image_id in lvis.imgs: 60 | coco_name = coco.imgs[image_id]['file_name'] 61 | lvis_name = lvis.imgs[image_id]['file_name'] 62 | assert coco_name in lvis_name 63 | else: 64 | logging.info( 65 | f'Image {image_id} in COCO, but not annotated in LVIS') 66 | 67 | lvis_highest_id = max(x['id'] for x in lvis.anns.values()) 68 | ann_id_generator = itertools.count(lvis_highest_id + 1) 69 | new_annotations = [] 70 | for image_id, lvis_anns in tqdm(lvis.imgToAnns.items()): 71 | if image_id not in coco.imgToAnns: 72 | logging.info( 73 | f'Image {image_id} in LVIS, but not annotated in COCO') 74 | continue 75 | 76 | coco_anns = coco.imgToAnns[image_id] 77 | # Compute IoU between coco_anns and lvis_anns 78 | # Shape (num_coco_anns, num_lvis_anns) 79 | mask_iou = mask_util.iou([coco.annToRLE(x) for x in coco_anns], 80 | [lvis.annToRLE(x) for x in lvis_anns], 81 | pyiscrowd=np.zeros(len(lvis_anns))) 82 | does_overlap = mask_iou.max(axis=1) > args.iou_thresh 83 | to_add = [] 84 | for i, ann in enumerate(coco_anns): 85 | if does_overlap[i]: 86 | continue 87 | if ann['category_id'] not in coco_to_lvis_category: 88 | continue 89 | ann['category_id'] = coco_to_lvis_category[ann['category_id']] 90 | ann['id'] = next(ann_id_generator) 91 | to_add.append(ann) 92 | new_annotations.extend(to_add) 93 | 94 | with open(args.lvis, 'r') as f: 95 | merged = json.load(f) 96 | merged['annotations'].extend(new_annotations) 97 | with open(args.output_json, 'w') as f: 98 | json.dump(merged, f) 99 | 100 | 101 | if __name__ == "__main__": 102 | main() 103 | -------------------------------------------------------------------------------- /scripts/download/download_annotations.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import urllib.error 3 | import urllib.request 4 | from pathlib import Path 5 | 6 | import subprocess 7 | 8 | ANNOTATIONS_TAR_GZ = 'https://github.com/TAO-Dataset/annotations/archive/v1.2.tar.gz' 9 | 10 | 11 | def
banner_log(msg): 12 | banner = '#' * len(msg) 13 | print(f'\n{banner}\n{msg}\n{banner}') 14 | 15 | 16 | def log_and_run(cmd, *args, **kwargs): 17 | print(f'Running command:\n{" ".join(cmd)}') 18 | subprocess.run(cmd, *args, **kwargs) 19 | 20 | 21 | def main(): 22 | # Use first line of file docstring as description if it exists. 23 | parser = argparse.ArgumentParser( 24 | description=__doc__.split('\n')[0] if __doc__ else '', 25 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 26 | parser.add_argument('tao_root', type=Path) 27 | parser.add_argument('--split', 28 | required=True, 29 | choices=['train', 'val', 'test']) 30 | 31 | args = parser.parse_args() 32 | 33 | assert args.tao_root.exists(), ( 34 | f'TAO_ROOT does not exist at {args.tao_root}') 35 | 36 | annotations_dir = args.tao_root / 'annotations' 37 | if annotations_dir.exists(): 38 | print(f'Annotations directory already exists; skipping.') 39 | else: 40 | annotations_compressed = args.tao_root / 'annotations.tar.gz' 41 | if not annotations_compressed.exists(): 42 | banner_log('Downloading annotations') 43 | try: 44 | urllib.request.urlretrieve(ANNOTATIONS_TAR_GZ, 45 | annotations_compressed) 46 | except urllib.error.HTTPError as e: 47 | if e.code == 404: 48 | print(f'Unable to download annotations.tar.gz. Please ' 49 | f'download it manually from\n' 50 | f'{ANNOTATIONS_TAR_GZ}\n' 51 | f'and save it to {args.tao_root}.') 52 | return 53 | raise 54 | banner_log('Extracting annotations') 55 | log_and_run([ 56 | 'tar', 'xzvf', 57 | str(annotations_compressed), '-C', 58 | str(args.tao_root) 59 | ]) 60 | (args.tao_root / 'annotations-1.2').rename(annotations_dir) 61 | 62 | 63 | if __name__ == "__main__": 64 | main() 65 | -------------------------------------------------------------------------------- /scripts/download/download_ava.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import urllib.request 5 | from collections import defaultdict 6 | from pathlib import Path 7 | 8 | from moviepy.video.io.VideoFileClip import VideoFileClip 9 | from script_utils.common import common_setup 10 | from tqdm import tqdm 11 | 12 | from tao.utils.download import ( 13 | are_tao_frames_dumped, dump_tao_frames, remove_non_tao_frames) 14 | from tao.utils.fs import dir_path, file_path 15 | 16 | META_DIR = Path(__file__).resolve().parent / 'meta' 17 | 18 | AVA_URL = 'https://s3.amazonaws.com/ava-dataset' 19 | 20 | 21 | def close_clip(video): 22 | video.reader.close() 23 | if video.audio and video.audio.reader: 24 | video.audio.reader.close_proc() 25 | 26 | 27 | def ava_load_meta(): 28 | info = {} 29 | for split in ('trainval', 'test'): 30 | with open(META_DIR / f'ava_file_names_{split}_v2.1.txt', 'r') as f: 31 | for line in f: 32 | stem, ext = line.strip().rsplit('.', 1) 33 | info[stem] = {'ext': ext, 'split': split} 34 | return info 35 | 36 | 37 | def download_ava(root, 38 | annotations, 39 | checksums, 40 | workers=8, 41 | movies_dir=None): 42 | if movies_dir is None: 43 | movies_dir = root / 'cache' / 'ava_movies' 44 | movies_dir.mkdir(exist_ok=True, parents=True) 45 | 46 | logging.info(f'Downloading AVA videos.') 47 | videos = [ 48 | v for v in annotations['videos'] if v['metadata']['dataset'] == 'AVA' 49 | ] 50 | 51 | movie_clips = defaultdict(list) 52 | for v in videos: 53 | movie_clips[v['metadata']['movie']].append(v) 54 | 55 | movie_info = ava_load_meta() 56 | 57 | videos_dir = root / 'videos' 58 | frames_root = root / 'frames' 59 | for movie_stem, 
clips in tqdm(movie_clips.items(), 60 | desc='Processing AVA movies'): 61 | movie = f"{movie_stem}.{movie_info[movie_stem]['ext']}" 62 | 63 | # List of (clip, output clip path, output frames directory) for clips 64 | # whose frames have not already been extracted. 65 | to_process = [] 66 | for clip in clips: 67 | name = clip['name'] 68 | output_clip = file_path(videos_dir / f"{name}.mp4") 69 | output_frames = dir_path(frames_root / name) 70 | if are_tao_frames_dumped(output_frames, 71 | checksums[name], 72 | warn=False): 73 | logging.debug(f'Skipping extracted clip: {name}') 74 | continue 75 | to_process.append((clip, output_clip, output_frames)) 76 | 77 | # Download movie if necessary. 78 | if all(x[1].exists() for x in to_process): 79 | movie_vfc = None 80 | else: 81 | if movies_dir and (movies_dir / movie).exists(): 82 | downloaded_movie_this_run = False 83 | movie_path = movies_dir / movie 84 | logging.debug(f'Found AVA movie {movie} at {movie_path}') 85 | else: 86 | downloaded_movie_this_run = True 87 | movie_path = movies_dir / movie 88 | if not movie_path.exists(): 89 | logging.debug(f'Downloading AVA movie: {movie}.') 90 | url = ( 91 | f"{AVA_URL}/{movie_info[movie_stem]['split']}/{movie}") 92 | urllib.request.urlretrieve(url, movie_path) 93 | movie_vfc = VideoFileClip(str(movie_path)) 94 | 95 | for clip_info, clip_path, frames_dir in tqdm(to_process, 96 | desc='Extracting shots', 97 | leave=False): 98 | if clip_path.exists(): 99 | continue 100 | shot_endpoints = clip_info['metadata']['scene'].rsplit('_', 1)[1] 101 | start, end = shot_endpoints.split('-') 102 | subclip = movie_vfc.subclip( 103 | int(start) / movie_vfc.fps, 104 | int(end) / movie_vfc.fps) 105 | subclip.write_videofile(str(clip_path), 106 | audio=False, 107 | verbose=False, 108 | progress_bar=False) 109 | close_clip(subclip) 110 | 111 | if movie_vfc: 112 | close_clip(movie_vfc) 113 | if downloaded_movie_this_run: 114 | movie_path.unlink() 115 | 116 | logging.debug( 117 | f'AVA: Dumping TAO frames:\n{[x[1:] for x in to_process]}') 118 | dump_tao_frames([x[1] for x in to_process], [x[2] for x in to_process], 119 | workers) 120 | for clip, clip_path, frame_dir in to_process: 121 | if not are_tao_frames_dumped(frame_dir, checksums[clip['name']]): 122 | raise ValueError( 123 | f'Not all TAO frames for {clip["name"]} were extracted. ' 124 | f'Try deleting the clip at {clip_path} and running this ' 125 | f'script again.') 126 | remove_non_tao_frames(frame_dir, 127 | set(checksums[clip['name']].keys())) 128 | assert are_tao_frames_dumped(frame_dir, checksums[clip['name']]), ( 129 | f'ERROR: TAO frames were dumped properly for {clip["name"]}, ' 130 | f'but were deleted by `remove_non_tao_frames`! This is a bug, ' 131 | f'please report it.') 132 | 133 | 134 | def main(): 135 | # Use first line of file docstring as description if it exists. 136 | parser = argparse.ArgumentParser( 137 | description=__doc__.split('\n')[0] if __doc__ else '', 138 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 139 | parser.add_argument('root', type=Path) 140 | parser.add_argument('--split', 141 | required=True, 142 | choices=['train', 'val', 'test']) 143 | 144 | optional = parser.add_argument_group('Optional') 145 | optional.add_argument('--workers', default=8, type=int) 146 | optional.add_argument( 147 | '--movies-dir', 148 | type=Path, 149 | help=('Directory to save AVA movies to. If you have a copy ' 150 | 'AVA locally, you can point to that directory to skip ' 151 | 'downloading. 
NOTE: Any movies downloaded by this script will ' 152 | 'be deleted after the script completes. Any movies that already ' 153 | 'existed on disk will not be deleted.')) 154 | 155 | args = parser.parse_args() 156 | log_dir = args.root / 'logs' 157 | log_dir.mkdir(exist_ok=True, parents=True) 158 | common_setup(__file__, log_dir, args) 159 | 160 | ann_path = args.root / f'annotations/{args.split}.json' 161 | with open(ann_path, 'r') as f: 162 | tao = json.load(f) 163 | 164 | checksums_path = ( 165 | args.root / f'annotations/checksums/{args.split}_checksums.json') 166 | with open(checksums_path, 'r') as f: 167 | checksums = json.load(f) 168 | # checksums = {} 169 | # for image in tao['images']: 170 | # video = image['video'] 171 | # if video not in checksums: 172 | # checksums[video] = {} 173 | # name = image['file_name'].split('/')[-1].replace('.jpeg', '.jpg') 174 | # checksums[video][name] = '' 175 | 176 | download_ava(args.root, 177 | tao, 178 | checksums, 179 | workers=args.workers, 180 | movies_dir=args.movies_dir) 181 | 182 | 183 | if __name__ == "__main__": 184 | main() 185 | -------------------------------------------------------------------------------- /scripts/download/download_cfg.yaml: -------------------------------------------------------------------------------- 1 | TAO_ANNOTATIONS: 2 | TRAIN: /data/achald/track_dataset/annotations/scale/4-18/tao-format/train_federated_lvis.json 3 | VAL: /data/achald/track_dataset/annotations/scale/4-18/tao-format/validation_federated_lvis.json 4 | CHECKSUMS: 5 | VERIFY: True 6 | PATH: /data/achald/track_dataset/annotations/scale/4-18/tao-format/with_test_unfederated/checksums.json 7 | AVA: 8 | MOVIES: 9 | # Contains symlinks to /data/all/AVA/data 10 | DIR: /scratch/achald/tao/release/ava/ 11 | LASOT: 12 | DATASET_ROOT: /ssd1/achald/lasot 13 | CREATE_SYMLINKS: True 14 | CHARADES: 15 | VIDEOS_DIR: /data/all/Charades/Charades_v1/videos 16 | BDD: 17 | VIDEOS_DIR: /data/achald/track_dataset/bdd/val/videos/val_00/ -------------------------------------------------------------------------------- /scripts/download/download_hacs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import json 4 | import logging 5 | import shutil 6 | from pathlib import Path 7 | from textwrap import fill 8 | 9 | from moviepy.video.io.VideoFileClip import VideoFileClip 10 | from script_utils.common import common_setup 11 | from tqdm import tqdm 12 | 13 | from tao.utils.download import ( 14 | are_tao_frames_dumped, dump_tao_frames, remove_non_tao_frames) 15 | from tao.utils.fs import dir_path, file_path 16 | from tao.utils.ytdl import download_to_bytes 17 | 18 | META_DIR = Path(__file__).resolve().parent / 'meta' 19 | 20 | 21 | def close_clip(video): 22 | video.reader.close() 23 | if video.audio and video.audio.reader: 24 | video.audio.reader.close_proc() 25 | 26 | 27 | def download_hacs(root, annotations, checksums, workers=8, debug=False): 28 | logging.info(f'Downloading HACS videos.') 29 | videos = [ 30 | v for v in annotations['videos'] if v['metadata']['dataset'] == 'HACS' 31 | ] 32 | 33 | if debug: 34 | # Take 5 of each type of video. 
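# (The split below separates videos cut from a longer source, which have
# metadata['scene'] set, from videos used whole, which do not. For scene
# videos, this script, like download_ava.py above, recovers clip boundaries
# from frame endpoints encoded at the end of the scene string. A minimal
# sketch, assuming a scene string of the form '<prefix>_<start>-<end>':
#
#     endpoints = 'shot_970-1040'.rsplit('_', 1)[1]   # -> '970-1040'
#     start, end = endpoints.split('-')
#     subclip = clip.subclip(int(start) / clip.fps, int(end) / clip.fps)
#
# where clip is a moviepy VideoFileClip of the full source video.)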
35 | _scene_videos = [ 36 | v for v in videos if v['metadata']['scene'] is not None 37 | ] 38 | _noscene_videos = [v for v in videos if v['metadata']['scene'] is None] 39 | videos = _scene_videos[:5] + _noscene_videos[:5] 40 | 41 | videos_dir = root / 'videos' 42 | frames_dir = root / 'frames' 43 | tmp_dir = dir_path(root / 'cache' / 'hacs_videos') 44 | missing_dir = Path(root / 'hacs_missing') 45 | 46 | # List of (video, video_path, frame_path) 47 | videos_to_dump = [] 48 | unavailable_videos = [] 49 | for video in tqdm(videos, desc='Downloading HACS'): 50 | video_path = file_path(videos_dir / f"{video['name']}.mp4") 51 | frame_output = dir_path(frames_dir / video['name']) 52 | if are_tao_frames_dumped(frame_output, 53 | checksums[video['name']], 54 | warn=False): 55 | continue 56 | if not video_path.exists(): 57 | ytid = video['metadata']['youtube_id'] 58 | full_video = tmp_dir / f"v_{ytid}.mp4" 59 | missing_downloaded = missing_dir / f"{ytid}.mp4" 60 | if missing_downloaded.exists(): 61 | logging.info( 62 | f'Found video downloaded by user at {missing_downloaded}.') 63 | shutil.copy2(missing_downloaded, full_video) 64 | if not full_video.exists(): 65 | url = 'http://youtu.be/' + ytid 66 | try: 67 | vid_bytes = download_to_bytes(url) 68 | except BaseException: 69 | vid_bytes = None 70 | if isinstance(vid_bytes, int) or vid_bytes is None: 71 | unavailable_videos.append( 72 | (ytid, video['metadata']['action'])) 73 | continue 74 | else: 75 | vid_bytes = vid_bytes.getvalue() 76 | if len(vid_bytes) == 0: 77 | unavailable_videos.append( 78 | (ytid, video['metadata']['action'])) 79 | continue 80 | with open(full_video, 'wb') as f: 81 | f.write(vid_bytes) 82 | 83 | if video['metadata']['scene'] is not None: 84 | shot_endpoints = video['metadata']['scene'].rsplit('_', 1)[1] 85 | start, end = shot_endpoints.split('-') 86 | clip = VideoFileClip(str(full_video)) 87 | subclip = clip.subclip( 88 | int(start) / clip.fps, 89 | int(end) / clip.fps) 90 | subclip.write_videofile(str(video_path), 91 | audio=False, 92 | verbose=False, 93 | progress_bar=False) 94 | else: 95 | shutil.copy2(full_video, video_path) 96 | videos_to_dump.append((video['name'], video_path, frame_output)) 97 | 98 | dump_tao_frames([x[1] for x in videos_to_dump], 99 | [x[2] for x in videos_to_dump], workers) 100 | for video, video_path, frame_dir in videos_to_dump: 101 | remove_non_tao_frames(frame_dir, set(checksums[video].keys())) 102 | assert are_tao_frames_dumped(frame_dir, checksums[video]), ( 103 | f'Not all TAO frames for {video} were extracted.') 104 | 105 | if unavailable_videos: 106 | missing_path = file_path(missing_dir / 'missing.txt') 107 | logging.error('\n'.join([ 108 | '', 109 | f'{len(unavailable_videos)} video(s) could not be downloaded; ' 110 | 'please request them from the HACS website by uploading ', 111 | f'\t{missing_path}', 112 | 'to the following form', 113 | '\thttps://goo.gl/forms/0STStcLndI32oke22', 114 | 'See the following README for details:', 115 | '\thttps://github.com/hangzhaomit/HACS-dataset#request-testing-videos-and-missing-videos-new', 116 | ])) 117 | 118 | with open(missing_path, 'w') as f: 119 | csv.writer(f).writerows(unavailable_videos) 120 | 121 | if len(unavailable_videos) > 20: 122 | logging.error( 123 | fill('NOTE: Over 20 HACS videos were unavailable. 
This may mean ' 124 | 'that YouTube is rate-limiting your download; please try ' 125 | 'running this script again after a few hours, or on a ' 126 | 'different machine.')) 127 | 128 | 129 | def main(): 130 | # Use first line of file docstring as description if it exists. 131 | parser = argparse.ArgumentParser( 132 | description=__doc__.split('\n')[0] if __doc__ else '', 133 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 134 | parser.add_argument('root', type=Path) 135 | parser.add_argument('--split', 136 | required=True, 137 | choices=['train', 'val', 'test']) 138 | 139 | optional = parser.add_argument_group('Optional') 140 | optional.add_argument('--workers', default=8, type=int) 141 | 142 | args = parser.parse_args() 143 | log_dir = args.root / 'logs' 144 | log_dir.mkdir(exist_ok=True, parents=True) 145 | common_setup(__file__, log_dir, args) 146 | 147 | ann_path = args.root / f'annotations/{args.split}.json' 148 | with open(ann_path, 'r') as f: 149 | tao = json.load(f) 150 | 151 | checksums_path = ( 152 | args.root / f'annotations/checksums/{args.split}_checksums.json') 153 | with open(checksums_path, 'r') as f: 154 | checksums = json.load(f) 155 | 156 | download_hacs(args.root, tao, checksums, workers=args.workers) 157 | 158 | 159 | if __name__ == "__main__": 160 | main() 161 | -------------------------------------------------------------------------------- /scripts/download/download_helper.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import urllib.error 3 | import urllib.request 4 | from pathlib import Path 5 | 6 | import subprocess 7 | 8 | # ANNOTATIONS_TAR_GZ = 'https://github.com/TAO-Dataset/annotations/archive/v1.0.tar.gz' 9 | # Temporary URL while in beta. 10 | ANNOTATIONS_TAR_GZ = 'https://achal-public.s3.amazonaws.com/release-beta/annotations/annotations.tar.gz' 11 | 12 | 13 | def banner_log(msg): 14 | banner = '#' * len(msg) 15 | print(f'\n{banner}\n{msg}\n{banner}') 16 | 17 | 18 | def log_and_run(cmd, *args, **kwargs): 19 | print(f'Running command:\n{" ".join(cmd)}') 20 | subprocess.run(cmd, *args, **kwargs) 21 | 22 | 23 | def main(): 24 | # Use first line of file docstring as description if it exists.
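# A typical invocation of this helper (a sketch; TAO_ROOT stands in for your
# dataset root, which must already exist):
#
#     python scripts/download/download_helper.py $TAO_ROOT --split train
#
# The helper downloads and extracts annotations.tar.gz into TAO_ROOT, then
# shells out to extract_frames.py, download_ava.py, and verify.py below.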
25 | parser = argparse.ArgumentParser( 26 | description=__doc__.split('\n')[0] if __doc__ else '', 27 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 28 | parser.add_argument('tao_root', type=Path) 29 | parser.add_argument('--split', 30 | required=True, 31 | choices=['train', 'val', 'test']) 32 | 33 | args = parser.parse_args() 34 | 35 | assert args.tao_root.exists(), ( 36 | f'TAO_ROOT does not exist at {args.tao_root}') 37 | 38 | annotations_dir = args.tao_root / 'annotations' 39 | if annotations_dir.exists(): 40 | print(f'Annotations directory already exists; skipping.') 41 | else: 42 | annotations_compressed = args.tao_root / 'annotations.tar.gz' 43 | if not annotations_compressed.exists(): 44 | banner_log('Downloading annotations') 45 | try: 46 | urllib.request.urlretrieve(ANNOTATIONS_TAR_GZ, 47 | annotations_compressed) 48 | except urllib.error.HTTPError as e: 49 | if e.code == 404: 50 | print(f'Unable to download annotations.tar.gz. Please ' 51 | f'download it manually from\n' 52 | f'{ANNOTATIONS_TAR_GZ}\n' 53 | f'and save it to {args.tao_root}.') 54 | return 55 | raise 56 | banner_log('Extracting annotations') 57 | log_and_run([ 58 | 'tar', 'xzvf', 59 | str(annotations_compressed), '-C', 60 | str(args.tao_root) 61 | ]) 62 | (args.tao_root / 'annotations-1.0').rename(annotations_dir) 63 | 64 | banner_log("Extracting BDD, Charades, HACS, and YFCC frames") 65 | log_and_run([ 66 | 'python', 'scripts/download/extract_frames.py', 67 | str(args.tao_root), '--split', args.split 68 | ]) 69 | 70 | banner_log("Downloading AVA videos") 71 | log_and_run([ 72 | 'python', 'scripts/download/download_ava.py', 73 | str(args.tao_root), '--split', args.split 74 | ]) 75 | 76 | banner_log("Verifying TAO frames") 77 | log_and_run([ 78 | 'python', 'scripts/download/verify.py', 79 | str(args.tao_root), '--split', args.split 80 | ]) 81 | 82 | 83 | if __name__ == "__main__": 84 | main() 85 | -------------------------------------------------------------------------------- /scripts/download/extract_frames.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | from collections import defaultdict 5 | from pathlib import Path 6 | 7 | from script_utils.common import common_setup 8 | 9 | from tao.utils.download import ( 10 | are_tao_frames_dumped, dump_tao_frames, remove_non_tao_frames) 11 | 12 | 13 | def main(): 14 | # Use first line of file docstring as description if it exists. 
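# dump_tao_frames(videos, frame_dirs, workers) extracts frames from each
# video into the matching directory, and remove_non_tao_frames then deletes
# any frame that is not part of TAO's annotated set. Per video, the
# extraction is conceptually similar to (a sketch, assuming ffmpeg-style
# dumping; the exact frame-naming scheme is defined in tao.utils.download):
#
#     ffmpeg -i videos/<name>.mp4 -q:v 2 frames/<name>/%05d.jpg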
15 | parser = argparse.ArgumentParser( 16 | description=__doc__.split('\n')[0] if __doc__ else '', 17 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 18 | parser.add_argument('root', type=Path) 19 | parser.add_argument('--split', 20 | required=True, 21 | choices=['train', 'val', 'test']) 22 | parser.add_argument('--sources', 23 | default=['BDD', 'HACS', 'Charades', 'YFCC100M'], 24 | choices=['BDD', 'HACS', 'Charades', 'YFCC100M']) 25 | parser.add_argument('--workers', default=8, type=int) 26 | 27 | args = parser.parse_args() 28 | log_dir = args.root / 'logs' 29 | log_dir.mkdir(exist_ok=True, parents=True) 30 | common_setup(__file__, log_dir, args) 31 | 32 | ann_path = args.root / f'annotations/{args.split}.json' 33 | with open(ann_path, 'r') as f: 34 | tao = json.load(f) 35 | 36 | checksums_path = ( 37 | args.root / f'annotations/checksums/{args.split}_checksums.json') 38 | with open(checksums_path, 'r') as f: 39 | checksums = json.load(f) 40 | 41 | videos_by_dataset = defaultdict(list) 42 | for video in tao['videos']: 43 | videos_by_dataset[video['metadata']['dataset']].append(video) 44 | 45 | videos_dir = args.root / 'videos' 46 | frames_dir = args.root / 'frames' 47 | for dataset in args.sources: 48 | # Collect list of videos 49 | ext = '.mov' if dataset == 'BDD' else '.mp4' 50 | videos = videos_by_dataset[dataset] 51 | video_paths = [ 52 | videos_dir / f"{video['name']}{ext}" for video in videos 53 | ] 54 | output_frame_dirs = [frames_dir / video['name'] for video in videos] 55 | 56 | # List of (video, video path, frame directory) tuples 57 | to_dump = [] 58 | for video, video_path, frame_dir in zip(videos, video_paths, 59 | output_frame_dirs): 60 | if not video_path.exists(): 61 | raise ValueError(f'Could not find video at {video_path}') 62 | video_checksums = checksums[video['name']] 63 | if frame_dir.exists() and are_tao_frames_dumped( 64 | frame_dir, video_checksums, warn=False): 65 | continue 66 | to_dump.append((video, video_path, frame_dir)) 67 | 68 | # Dump frames from each video 69 | logging.info(f'{dataset}: Extracting frames') 70 | dump_tao_frames([x[1] for x in to_dump], [x[2] for x in to_dump], 71 | workers=args.workers) 72 | 73 | to_dump = [] 74 | for video, video_path, frame_dir in zip(videos, video_paths, 75 | output_frame_dirs): 76 | video_checksums = checksums[video['name']] 77 | # Remove frames not used for TAO. 78 | remove_non_tao_frames(frame_dir, set(video_checksums.keys())) 79 | # Compare checksums for frames 80 | assert are_tao_frames_dumped(frame_dir, video_checksums), ( 81 | f'Not all TAO frames for {video["name"]} were extracted.') 82 | 83 | logging.info( 84 | f'{dataset}: Removing non-TAO frames, verifying extraction') 85 | logging.info(f'{dataset}: Successfully extracted!') 86 | 87 | 88 | if __name__ == "__main__": 89 | main() 90 | -------------------------------------------------------------------------------- /scripts/download/gen_checksums.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from collections import defaultdict 4 | from hashlib import md5 5 | from pathlib import Path 6 | 7 | from tqdm import tqdm 8 | from script_utils.common import common_setup 9 | 10 | from tao.utils import fs 11 | 12 | 13 | def main(): 14 | # Use first line of file docstring as description if it exists. 
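# The output JSON maps video name -> frame file name -> md5 hex digest, with
# an empty string for frames that carry no labels. A sketch of the structure
# (the names and hash here are made up):
#
#     {"val/YFCC100M/v_abc": {"frame0001.jpg": "9e107d9d372bb6826bd8...",
#                             "frame0002.jpg": ""}}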
15 | parser = argparse.ArgumentParser( 16 | description=__doc__.split('\n')[0] if __doc__ else '', 17 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 18 | parser.add_argument('--frames-dir', type=Path, required=True) 19 | parser.add_argument('--output-json', type=Path, required=True) 20 | parser.add_argument('--tao-annotations', type=Path, required=True) 21 | 22 | args = parser.parse_args() 23 | output_dir = args.output_json.parent 24 | output_dir.mkdir(exist_ok=True, parents=True) 25 | common_setup(args.output_json.name, output_dir, args) 26 | 27 | with open(args.tao_annotations, 'r') as f: 28 | tao = json.load(f) 29 | videos = [x['name'] for x in tao['videos']] 30 | 31 | labeled_frames = defaultdict(set) 32 | for frame in tao['images']: 33 | video, frame_name = frame['file_name'].rsplit('/', 1) 34 | labeled_frames[video].add(frame_name) 35 | 36 | # videos = videos[:10] 37 | hashes = {} 38 | for video in tqdm(videos): 39 | frames = fs.glob_ext(args.frames_dir / video, ('.jpg', '.jpeg')) 40 | hashes[video] = {} 41 | for i, frame in tqdm(enumerate(frames)): 42 | if frame.name in labeled_frames[video]: 43 | with open(frame, 'rb') as f: 44 | hashes[video][frame.name] = md5(f.read()).hexdigest() 45 | else: 46 | hashes[video][frame.name] = '' 47 | if all(x == '' for x in hashes[video].values()): 48 | raise ValueError(f'Did not find any labeled frames for {video}') 49 | 50 | with open(args.output_json, 'w') as f: 51 | json.dump(hashes, f) 52 | 53 | 54 | if __name__ == "__main__": 55 | main() 56 | -------------------------------------------------------------------------------- /scripts/download/meta/ava_file_names_test_v2.1.txt: -------------------------------------------------------------------------------- 1 | --205wugM18.mkv 2 | -APF0-L14kw.mkv 3 | -FLn0aeA6EU.mkv 4 | 0OLtK6SeTwo.mp4 5 | 1R7n8B8KkZE.mkv 6 | 1XZZnWMP4CU.mkv 7 | 2eTGj8zPykM.mkv 8 | 3-ivkPTSTSw.mp4 9 | 30Qkf0pq-PY.mkv 10 | 55R6Ng9w65o.mkv 11 | 6IebItD0ETQ.mkv 12 | 72MzYjWz_7g.mkv 13 | 7QstV153hbA.mkv 14 | 7SGCpWCNN84.mp4 15 | 7oY-kE-goOA.mkv 16 | 8FYx0LtfPTE.mkv 17 | 8oL0i5WorkE.mp4 18 | 9f8r96-it6c.mkv 19 | A8SUe2Yqn60.mkv 20 | A9WSiEDeu0I.mkv 21 | AwlY-zteegM.mkv 22 | BD3zaLKhkV4.mkv 23 | BLDTynQwGRI.mkv 24 | BU98nWUtT5E.mkv 25 | BV1VreCWZ64.mkv 26 | BnIFkfDhJ2w.mkv 27 | DaUzhc9_6io.mp4 28 | E-6ruyZFfZs.mkv 29 | E-fqjlYMFhE.mp4 30 | EHMP5-9KUdI.mp4 31 | EO1gLAoEZRA.mp4 32 | FONjBIXaM-0.mp4 33 | G0gDuIVKiXg.mkv 34 | GElolK2jG50.mkv 35 | GQxKfbvL3mg.mkv 36 | Gsm_ZBStr0s.mp4 37 | HPd4eMvs1Kg.mp4 38 | HeKz7BELAQc.mkv 39 | HtXWX0LnifY.mp4 40 | IC5M1EhJNfI.webm 41 | IIyYHprTP58.webm 42 | Ic0LMbDyc9Y.mkv 43 | JiBiCiK9HjY.mp4 44 | K-tICG1ek-E.mp4 45 | Ke8b1_yiUVQ.mkv 46 | KkAf75yOKqs.mkv 47 | KrMSZUQJlNM.mkv 48 | LO964EmiVfo.mkv 49 | Mz0FKktvMLY.mkv 50 | NUwem2aZa0Y.mp4 51 | O5y8zKl9X2E.mp4 52 | O8xkUcUJPNo.mkv 53 | OEUMcSba9t0.mp4 54 | OL_Wwo5W1Zs.mp4 55 | OQxN4ksema0.mkv 56 | P5EhajqkqPw.mkv 57 | QTf_v67C5KI.mp4 58 | Qes4a8HuyEc.mkv 59 | RCNuAys0Hsg.mkv 60 | RW-H3fN_79I.mp4 61 | Scg5LeZszCc.mkv 62 | Sntyb4omSfU.mkv 63 | SoNhz0WJZsI.mkv 64 | Uw7387tc9PU.mp4 65 | V6RX59GT-3k.mkv 66 | VNZ8JDb8sks.mkv 67 | ViY7CR2TSO8.mkv 68 | W8TFzEy0gp0.mkv 69 | WMFTBgYWJS8.mkv 70 | Wgytpy6TeUA.mp4 71 | WhkON_S-pQc.mp4 72 | XOe9GeojzCs.mp4 73 | YAAUPjq-L-Q.mp4 74 | Z0FEElATNjk.mkv 75 | Z42lnoj2n08.mkv 76 | ZS2C28fDC9U.mp4 77 | ZbeMNLwASVo.mkv 78 | ZsgPK0XGYoM.mp4 79 | Zu4iQJrlpo0.mkv 80 | _kbrVsCaaPo.mp4 81 | _vy57h5Oeys.mkv 82 | aDfOtlsdoWw.mkv 83 | bNP8Q_8u89A.webm 84 | bUVls-bf0jM.mkv 85 | bzGQK5lH-RA.mkv 86 | c5mlhcFYYZs.mp4 87 | 
cYt6NaQgcEk.mp4 88 | cqkChR44vkA.mkv 89 | fT_WjgJ_-r0.mkv 90 | gEI9qBdVt5I.mp4 91 | h7Atb503JwY.webm 92 | hgmK4Epb02E.mkv 93 | i9cuy3teV0w.mkv 94 | ipBRBABLSAk.mkv 95 | jKKXDh4lYd0.mkv 96 | kW5WyJ1QNpM.mkv 97 | keUOiCcHtoQ.mkv 98 | kvFlbTK812w.mkv 99 | l8_Mk3-sZsQ.mkv 100 | nAg_NVzLoAY.mkv 101 | nRzhjXMIXt4.mkv 102 | o-ZcbjLBtls.mkv 103 | ohn_RxyaCy4.mp4 104 | pSE4Dlork1Y.mp4 105 | pSdPmmJ3-ng.mp4 106 | rJibAAUEMDY.mkv 107 | rRL0Ce8e-RY.mkv 108 | rTCch_5JlkA.mp4 109 | s2z5UASlrP8.mkv 110 | sV3zZROy0uc.mkv 111 | tDF-BqFfF78.mkv 112 | tj-VmrMYtUI.mp4 113 | u97DLHpcw7c.mkv 114 | vL7N_xRJKJU.mp4 115 | vsMgg4snZzM.mkv 116 | w-jIrlwuv2Y.mkv 117 | wamBSoyRtbs.mkv 118 | woC9Vfbn74I.mkv 119 | xH1WLtZ8csM.mp4 120 | xJpDPrwLJh4.mkv 121 | xT2ogY6xEsI.mp4 122 | xYUx0drhUNk.mkv 123 | xauSNGP5yA0.mkv 124 | xdDTWBRWPLQ.mkv 125 | y4lBI_gFnqI.mkv 126 | y5o8w0FRj98.mkv 127 | yQdi5Ke4dNY.mkv 128 | yRRZkwtJCwU.mkv 129 | z5lg_3abT-s.mkv 130 | zm78XnWN7MU.mkv 131 | zvxnOrzTg0M.mp4 132 | -------------------------------------------------------------------------------- /scripts/download/meta/ava_file_names_trainval_v2.1.txt: -------------------------------------------------------------------------------- 1 | _-Z6wFjXtGQ.mkv 2 | _145Aa_xkuE.mp4 3 | _7oWZq_s_Sk.mkv 4 | _a9SWtcaNj8.mkv 5 | _Ca3gOdOHxU.mp4 6 | _dBTTYDRdRQ.webm 7 | _eBah6c5kyA.mkv 8 | _ithRWANKB0.mp4 9 | _mAfwH6i90E.mkv 10 | -5KQ66BBWC4.mkv 11 | -FaXLcSFjUI.mp4 12 | -IELREHX_js.mp4 13 | -OyDO1g74vc.mp4 14 | -XpUuIgyUHE.mp4 15 | -ZFgsrolSxo.mkv 16 | 053oq2xB3oU.mkv 17 | 0f39OWEqJ24.mp4 18 | 0wBYFahr3uI.mp4 19 | 1j20qq1JyX4.mp4 20 | 1ReZIMmD_8E.mp4 21 | 26V9UzqSguo.mp4 22 | 2bxKkUgcqpk.mp4 23 | 2DUITARAsWQ.mp4 24 | 2E_e8JlvTlg.mkv 25 | 2FIHxnZKg6A.webm 26 | 2fwni_Kjf2M.mkv 27 | 2KpThOF_QmE.mkv 28 | 2PpxiG0WU18.mkv 29 | 2qQs3Y9OJX0.mkv 30 | 3_VjIRdXVdM.mkv 31 | 32HR3MnDZ8g.mp4 32 | 3IOE-Q3UWdA.mp4 33 | 4gVsDd8PV9U.mp4 34 | 4k-rTF3oZKw.mp4 35 | 4Y5qi1gD2Sw.mkv 36 | 4ZpjKfu6Cl8.mkv 37 | 55Ihr6uVIDA.mkv 38 | 5BDj0ow5hnA.mp4 39 | 5LrOQEt_XVM.mp4 40 | 5milLu-6bWI.mp4 41 | 5MxjqHfkWFI.mkv 42 | 5YPjcdLbs5g.mkv 43 | 6d5u6FHvz7Q.mkv 44 | 7g37N3eoQ9s.mkv 45 | 7nHkh4sP5Ks.mkv 46 | 7T5G0CmwTPo.mkv 47 | 7YpF6DntOYw.mkv 48 | 8aMv-ZGD4ic.mkv 49 | 8JSxLhDMGtE.mkv 50 | 8nO5FFbIAog.webm 51 | 8VZEwOCQ8bc.mkv 52 | 914yZXz-iRs.mkv 53 | 9bK05eBt1GM.mp4 54 | 9eAOr_ttXp0.mkv 55 | 9F2voT6QWvQ.mkv 56 | 9HOMUW7QNFc.mkv 57 | 9IF8uTRrWAM.mkv 58 | 9mLYmkonWZQ.mkv 59 | 9QbzS8bZXFE.mkv 60 | 9Rcxr3IEX4E.mkv 61 | 9tyiDEYiWiA.mkv 62 | 9Y_l9NsnYE0.mp4 63 | aDEYi1OG0vU.mkv 64 | Ag-pXiLrd48.mp4 65 | aMYcLyh9OhU.mkv 66 | AN07xQokfiE.mp4 67 | aRbLw-dU2XY.mp4 68 | ax3q-RkVIt4.mp4 69 | ayAMdYfJJLk.mkv 70 | AYebXQ8eUkM.mkv 71 | b-YoBU0XT90.mp4 72 | B1MAUxpKaV8.mkv 73 | b50s4AlOOKY.mkv 74 | b5pRYl_djbs.mp4 75 | bAVXp1oGjHA.mkv 76 | BCiuXAuCKAU.mp4 77 | bePts02nIY8.mkv 78 | bhlFavrh7WU.mkv 79 | bSZiZ4rOC7c.mkv 80 | BXCh3r-pPAM.mkv 81 | BY3sZmvUp-0.mp4 82 | C25wkwAMB-w.mkv 83 | C3qk4yAMANk.mkv 84 | c9pEMjPT16M.webm 85 | cc4y-yYm5Ao.mkv 86 | CG98XdYsgrA.mkv 87 | cKA-qeZuH_w.mkv 88 | cLiJgvrDlWw.mp4 89 | CMCPhm2L400.mkv 90 | covMYDBa5dk.mp4 91 | CrlfWnsS7ac.mkv 92 | cWYJHb25EVs.mp4 93 | CZ2NP8UsPuE.mkv 94 | D-BJTU6NxZ8.mkv 95 | D8Vhxbho1fY.mp4 96 | Db19rWN5BGo.mkv 97 | dgLApPvmfBE.mkv 98 | Di1MG6auDYo.mkv 99 | dMH8L7mqCNI.mkv 100 | E2jecoyAx1M.mkv 101 | E7JcKooKVsM.mp4 102 | eA55_shhKko.mkv 103 | Ecivp8t3MdY.mkv 104 | Ekwy7wzLfjc.mkv 105 | er7eeiJB6dI.mkv 106 | F3dPH6Xqf5M.mp4 107 | fD6VkIRlIRI.mkv 108 | Feu1_8NazPE.mp4 109 | fGgnNCbXZ20.mp4 110 | fNcxxBjEOgw.mkv 111 | fpprSy6AzKk.mkv 112 | 
fZs-yXm-uUs.mp4 113 | g1wyIcLPbq0.mp4 114 | G4qq1MRXCiY.mkv 115 | G5Yr20A5z_Q.mkv 116 | GBXK_SyfisM.mkv 117 | Gfdg_GcaNe8.mkv 118 | gjasEUDkbuc.mkv 119 | gjdgj04FzR0.mp4 120 | GozLjpMNADg.mkv 121 | gqmmpoO1JrY.mkv 122 | Gt61_Yekkgc.mp4 123 | Gvp-cj3bmIY.webm 124 | hbYvDvJrpNk.mp4 125 | hHgg9WI8dTk.mkv 126 | Hi8QeP_VPu0.mkv 127 | HJzgJ9ZjvJk.mkv 128 | HKjR70GCRPE.mp4 129 | Hscyg0vLKc8.mp4 130 | HTYT2vF-j_w.mkv 131 | HV0H6oc4Kvs.mkv 132 | HVAmkvLrthQ.mkv 133 | HymKCzQJbB8.mkv 134 | I8j6Xq2B5ys.mp4 135 | Ie35yEssHko.mkv 136 | IKdBLciu_-A.mp4 137 | iSlDMboCSao.mkv 138 | IuPC-z-M9u8.mkv 139 | IzvOYVMltkI.mp4 140 | J1jDc2rTJlg.mkv 141 | j35JnR0Q7Es.mp4 142 | J4bt4y9ShTA.mkv 143 | j5jmjhGBW44.mkv 144 | jBs_XYHI7gM.mkv 145 | jE0S8gYWftE.webm 146 | jgAwJ0RqmYg.mp4 147 | jI0HIlSsa3s.mkv 148 | JNb4nWexD0I.mkv 149 | jqZpiHlJUig.mkv 150 | K_SpqDJnlps.mkv 151 | kAsz-76DTDE.mkv 152 | Kb1fduj-jdY.mp4 153 | KHHgQ_Pe4cI.mkv 154 | KIy2a-nejxg.mp4 155 | kLDpP9QEVBs.mp4 156 | kMy-6RtoOVU.mkv 157 | kplbKz3_fZk.mkv 158 | Ksd1JQFHYWA.mp4 159 | KVq6If6ozMY.mkv 160 | KWoSGtglCms.mkv 161 | l-jxh8gpxuY.mkv 162 | l2XO3tQk8lI.mkv 163 | lDmLcWWBp1E.mkv 164 | Lg1jOu8cUBM.mkv 165 | LIavUJVrXaI.mkv 166 | LrDT25hmApw.mkv 167 | lT1zdTL-3SM.mkv 168 | lWXhqIAvarw.mkv 169 | M6cgEs9JgDo.mkv 170 | Ma2hgTmveKQ.mkv 171 | mfsbYdLx9wE.mkv 172 | miB-wo2PfLI.mkv 173 | mkcDANJjDcM.mkv 174 | N0Dt9i9IUNg.mkv 175 | N1K2bEZLL_A.mkv 176 | N5UD8FGzDek.mkv 177 | N7baJsMszJ0.mkv 178 | NEQ7Wpf-EtI.mkv 179 | nlinqZPgvVk.mkv 180 | NO2esmws190.mkv 181 | O_NYCUhZ9zw.mp4 182 | o4xQ-BEa3Ss.mkv 183 | O5m_0Yay4EU.mkv 184 | oD_wxyTHJ2I.mp4 185 | OfMdakd4bHI.mkv 186 | OGNnUvJq9RI.mkv 187 | oifTDWZvOhY.mkv 188 | oITFHwzfw_k.mkv 189 | om_83F5VwTQ.mp4 190 | oq_bufAhyl8.mkv 191 | Ov0za6Xb1LM.mkv 192 | oWhvucAskhk.mkv 193 | P60OxWahxBQ.mkv 194 | P90hF2S1JzA.mkv 195 | PcFEhUKhN6g.mkv 196 | pGP_oIdKmRY.mkv 197 | phrYEKv0rmw.mkv 198 | phVLLTMzmKk.mkv 199 | pieVIsGmLsc.mkv 200 | piYxcrMxVPw.mkv 201 | plkJ45_-pMk.mp4 202 | PmElx9ZVByw.mp4 203 | PNZQ2UJfyQE.mp4 204 | QaIMUi-elFo.mkv 205 | qBUu7cy-5Iw.mp4 206 | QCLQYnt3aMo.webm 207 | QD3L10bUnBo.mkv 208 | QJzocCGLdHU.mp4 209 | QMwT7DFA5O4.mkv 210 | QotkBTEePI8.mkv 211 | qpoWHELxL-4.mp4 212 | qrkff49p4E4.mp4 213 | qsTqtWVVSLM.mkv 214 | QTmwhrVal1g.mkv 215 | qx2vAO5ofmo.mp4 216 | r2llOyS-BmE.mkv 217 | rCb9-U4TArw.mp4 218 | rFgb2ECMcrY.mkv 219 | ri4P2enZT9o.mkv 220 | Riu4ZKk4YdQ.webm 221 | rJKeqfTlAeY.mkv 222 | rk8Xm0EAOWs.mkv 223 | Rm518TUhbRY.mkv 224 | rUYsoIIE37A.mp4 225 | rXFlJbXyZyc.mkv 226 | S0tkhGJjwLA.mkv 227 | sADELCyj10I.mkv 228 | SCh-ZImnyyk.mp4 229 | SHBMiL5f_3Q.mkv 230 | skiZueh4lfY.mkv 231 | sNQJfYvhcPk.mp4 232 | sUVhd0YTKgw.mkv 233 | T-Fc9ctuNVI.mkv 234 | t0V4drbYDnc.mkv 235 | t1LXrJOvPDg.mkv 236 | T26G6_AjJZ4.mkv 237 | TcB0IFBwk-k.mkv 238 | TCmNvNLRWrc.mkv 239 | tEoJW9ycmSY.mkv 240 | TEQ9sAj-DPo.mp4 241 | tghXjom3120.mkv 242 | tjqCzVjojCo.mkv 243 | TM5MPJIq1Is.mkv 244 | tNpZtigMc4g.mkv 245 | tt0t_a1EDCE.mkv 246 | TzaVHtLXOzY.mkv 247 | U_WzY2k8IBM.mkv 248 | u1ltv6r14KQ.mkv 249 | UgZFdrNT6W0.mkv 250 | uNT6HrrnqPU.webm 251 | UOfuzrwkclM.mkv 252 | UOyyTUX5Vo4.mkv 253 | uq_HBsvP548.mkv 254 | UrsCy6qIGoo.mkv 255 | UsLnxI_zGpY.mkv 256 | uwW0ejeosmk.mkv 257 | uzPI7FcF79U.mkv 258 | v0L-WkMO3s4.mp4 259 | vBbjA4tWCPg.mp4 260 | vfjywN5CN0Y.mkv 261 | Vmef_8MY46w.mkv 262 | VRlpH1MbWUw.mp4 263 | VsYPP2I0aUQ.mkv 264 | wEAeql4z1O0.mp4 265 | wfEOx36N4jA.mp4 266 | WKqbLbU68wU.mkv 267 | WlgxRNCHQzw.mkv 268 | wogRuPNBUi8.mp4 269 | wONG7Vh87B4.mkv 270 | WSPvfxtqisg.mkv 271 | WVde9pyaHg4.mkv 272 | x-6CtPWVi6E.mkv 273 | 
X5wWhZ2r9kc.mp4 274 | xeGWXqSvC-8.webm 275 | XF87VL5T0aA.mkv 276 | XglAvHaEtHY.mp4 277 | xJmRNZVDDCY.mkv 278 | xmqSaQPzL1E.mkv 279 | xO4ABy2iOQA.mp4 280 | xp67EC-Hvwk.mkv 281 | XpGRS72ghag.mkv 282 | XV_FF3WC7kA.mkv 283 | y7ncweROe9U.mkv 284 | yMtGmGa8KZ0.mkv 285 | yn9WN9lsHRE.mkv 286 | yo-Kg2YxlZs.mkv 287 | yqImJuC5UzI.mp4 288 | Ytga8ciKWJc.mkv 289 | yvgCGJ6vfkY.mkv 290 | YYWdB7h1INo.mkv 291 | z-fsLpGHq6o.mkv 292 | Z1YV6wB037M.mkv 293 | z3kgrh0L_80.mkv 294 | zC5Fh2tTS1U.mp4 295 | zG7mx8KiavA.mp4 296 | zlVkeKC6Ha8.mp4 297 | ZosVdkY76FU.mkv 298 | zR725veL-DI.mkv 299 | ZxQn8HVmXsY.mkv 300 | -------------------------------------------------------------------------------- /scripts/download/verify.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | from collections import defaultdict 5 | from pathlib import Path 6 | 7 | from script_utils.common import common_setup 8 | from tqdm import tqdm 9 | 10 | from tao.utils.download import are_tao_frames_dumped 11 | 12 | 13 | def main(): 14 | # Use first line of file docstring as description if it exists. 15 | parser = argparse.ArgumentParser( 16 | description=__doc__.split('\n')[0] if __doc__ else '', 17 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 18 | parser.add_argument('root', type=Path) 19 | parser.add_argument('--split', 20 | required=True, 21 | choices=['train', 'validation']) 22 | 23 | args = parser.parse_args() 24 | log_dir = args.root / 'logs' 25 | log_dir.mkdir(exist_ok=True, parents=True) 26 | common_setup(__file__, log_dir, args) 27 | 28 | ann_path = args.root / f'annotations/{args.split}.json' 29 | with open(ann_path, 'r') as f: 30 | tao = json.load(f) 31 | 32 | checksums_path = ( 33 | args.root / f'annotations/checksums/{args.split}_checksums.json') 34 | with open(checksums_path, 'r') as f: 35 | checksums = json.load(f) 36 | 37 | videos_by_dataset = defaultdict(list) 38 | for video in tao['videos']: 39 | videos_by_dataset[video['metadata']['dataset']].append(video) 40 | 41 | status = {} 42 | for dataset, videos in sorted(videos_by_dataset.items()): 43 | status[dataset] = True 44 | for video in tqdm(videos, desc=f'Verifying {dataset}'): 45 | name = video['name'] 46 | frame_dir = args.root / 'frames' / name 47 | if not are_tao_frames_dumped( 48 | frame_dir, checksums[name], warn=True, allow_extra=False): 49 | logging.warning( 50 | f'Frames for {name} are not extracted properly. 
' 51 | f'Skipping rest of dataset.') 52 | status[dataset] = False 53 | break 54 | 55 | success = [] 56 | for dataset in sorted([d for d, v in status.items() if v]): 57 | success.append(f'{dataset: <12}: Verified ✓✓✓') 58 | 59 | failure = [] 60 | for dataset in sorted([d for d, v in status.items() if not v]): 61 | failure.append(f'{dataset: <12}: FAILED 𐄂𐄂𐄂') 62 | 63 | if success: 64 | logging.info('Success!\n' + ('\n'.join(success))) 65 | if failure: 66 | logging.warning('Some datasets were not properly extracted!\n' + 67 | ('\n'.join(failure))) 68 | 69 | 70 | if __name__ == "__main__": 71 | main() 72 | -------------------------------------------------------------------------------- /scripts/evaluation/configs/default.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TAO-Dataset/tao/a3a713c51e2fdeb3a106c34b06d889ea581150a7/scripts/evaluation/configs/default.yaml -------------------------------------------------------------------------------- /scripts/evaluation/evaluate.py: -------------------------------------------------------------------------------- 1 | """Evaluate tao results (helper script).""" 2 | 3 | import argparse 4 | import logging 5 | from pathlib import Path 6 | 7 | from script_utils.common import common_setup 8 | from tao.utils.evaluation import get_cfg_defaults, evaluate, log_eval 9 | from tao.utils.yacs_util import merge_from_file_with_base 10 | 11 | 12 | CONFIG_DIR = Path(__file__).resolve().parent / 'configs' 13 | 14 | 15 | def main(): 16 | # Use first line of file docstring as description if it exists. 17 | parser = argparse.ArgumentParser( 18 | description=__doc__.split('\n')[0] if __doc__ else '', 19 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 20 | parser.add_argument('annotations', type=Path) 21 | parser.add_argument('predictions', type=Path) 22 | parser.add_argument('--output-dir', type=Path) 23 | parser.add_argument('--config', 24 | type=Path, 25 | default=CONFIG_DIR / 'default.yaml') 26 | parser.add_argument('--config-updates', nargs='*') 27 | 28 | args = parser.parse_args() 29 | 30 | if args.output_dir: 31 | tensorboard_dir = args.output_dir / 'tensorboard' 32 | if tensorboard_dir.exists(): 33 | raise ValueError( 34 | f'Tensorboard dir already exists, not evaluating.') 35 | args.output_dir.mkdir(exist_ok=True, parents=True) 36 | log_path = common_setup(__file__, args.output_dir, args).name 37 | else: 38 | logging.getLogger().setLevel(logging.INFO) 39 | logging.basicConfig(format='%(asctime)s.%(msecs).03d: %(message)s', 40 | datefmt='%H:%M:%S') 41 | logging.info('Args:\n%s', vars(args)) 42 | log_path = None 43 | 44 | cfg = get_cfg_defaults() 45 | merge_from_file_with_base(cfg, args.config) 46 | if args.config_updates: 47 | cfg.merge_from_list(args.config_updates) 48 | cfg.freeze() 49 | 50 | if args.output_dir: 51 | with open(args.output_dir / 'config.yaml', 'w') as f: 52 | f.write(cfg.dump()) 53 | 54 | tao_eval = evaluate(args.annotations, args.predictions, cfg) 55 | log_eval(tao_eval, cfg, output_dir=args.output_dir, log_path=log_path) 56 | 57 | 58 | if __name__ == "__main__": 59 | main() 60 | -------------------------------------------------------------------------------- /scripts/trackers/single_obj/pysot_create_json_for_eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import pickle 5 | import numpy as np 6 | from collections import defaultdict 7 | from pathlib import Path 8 | 
from natsort import natsorted 9 | 10 | from script_utils.common import common_setup 11 | from tqdm import tqdm 12 | 13 | from tao.toolkit.tao import Tao 14 | from tao.utils import fs 15 | 16 | 17 | def create_json(pickle_dir, tao, frames_dir, output_dir, oracle_category): 18 | if not isinstance(pickle_dir, (list, tuple)): 19 | pickle_dir = [pickle_dir] 20 | image_to_id = {x['file_name']: x['id'] for x in tao.imgs.values()} 21 | 22 | paths = [ 23 | (root, root / f'{video["name"]}.pkl') 24 | for video in tao.vids.values() 25 | for root in pickle_dir 26 | ] 27 | 28 | annotations = [] 29 | # Map video to list of track ids seen so far. Used to check for duplicates 30 | # when multiple pickle directories are specified. 31 | seen_track_ids = defaultdict(set) 32 | for root, p in tqdm(paths): 33 | if not p.exists(): 34 | logging.warning(f'Could not find tracks for video {p}') 35 | continue 36 | video_name = str(p.relative_to(root)).split('.pkl')[0] 37 | video_frames_dir = frames_dir / video_name 38 | frames = [ 39 | str(x.relative_to(frames_dir)) 40 | for x in natsorted(fs.glob_ext(video_frames_dir, fs.IMG_EXTENSIONS)) 41 | ] 42 | if not frames: 43 | raise ValueError(f'Found no frames at {video_frames_dir}') 44 | frame_indices = {x: i for i, x in enumerate(frames)} 45 | with open(p, 'rb') as f: 46 | # Map object_id to {'boxes': np.array} 47 | tracks = pickle.load(f) 48 | 49 | for object_id, outputs in tracks.items(): 50 | if object_id in seen_track_ids[video_name]: 51 | raise ValueError( 52 | f'Object id {object_id} in video {video_name} seen ' 53 | f'multiple times!') 54 | if object_id not in tao.tracks: 55 | logging.warning( 56 | f'Object id {object_id} for video {video_name} not found ' 57 | f'in annotations, skipping.') 58 | continue 59 | seen_track_ids[video_name].add(object_id) 60 | init = tao.get_kth_annotation(object_id, 0) 61 | init_frame = frame_indices[tao.imgs[init['image_id']]['file_name']] 62 | boxes = outputs['boxes'] 63 | for i, frame in enumerate(frames[init_frame:]): 64 | if frame not in image_to_id: 65 | continue 66 | x0, y0, x1, y1, score = boxes[i] 67 | w, h = x1 - x0 + 1, y1 - y0 + 1 68 | is_init = np.isinf(score) 69 | annotations.append({ 70 | 'id': len(annotations), 71 | 'image_id': image_to_id[frame], 72 | 'track_id': object_id, 73 | 'bbox': [x0, y0, w, h], 74 | 'video_id': tao.imgs[image_to_id[frame]]['video_id'], 75 | 'category_id': (tao.tracks[object_id]['category_id'] 76 | if oracle_category else 1), 77 | 'score': score, 78 | # Numpy -> python boolean for serialization 79 | '_single_object_init': bool(is_init) 80 | }) 81 | 82 | with open(output_dir / 'results.json', 'w') as f: 83 | json.dump(annotations, f) 84 | 85 | 86 | def main(): 87 | # Use first line of file docstring as description if it exists. 88 | parser = argparse.ArgumentParser( 89 | description=__doc__.split('\n')[0] if __doc__ else '', 90 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 91 | parser.add_argument('--annotations', type=Path, required=True) 92 | # We need the frames dir because the pickles index boxes into the ordered 93 | # list of frames.
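# Each per-video pickle maps object_id -> {'boxes': np.ndarray}, where row i
# holds [x0, y0, x1, y1, score] for the i-th frame at or after the track's
# initialization frame; an infinite score marks the init frame itself. The
# conversion to the COCO-style [x, y, w, h] boxes written to results.json is
# (mirroring create_json above):
#
#     x0, y0, x1, y1, score = boxes[i]
#     bbox = [x0, y0, x1 - x0 + 1, y1 - y0 + 1]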
94 | parser.add_argument('--frames-dir', type=Path, required=True) 95 | parser.add_argument('--pickle-dir', type=Path, nargs='+', required=True) 96 | parser.add_argument('--oracle-category', action='store_true') 97 | parser.add_argument('--output-dir', 98 | type=Path, 99 | required=True) 100 | 101 | args = parser.parse_args() 102 | args.output_dir.mkdir(exist_ok=True, parents=True) 103 | common_setup(__file__, args.output_dir, args) 104 | 105 | tao = Tao(args.annotations) 106 | create_json(args.pickle_dir, tao, args.frames_dir, args.output_dir, 107 | args.oracle_category) 108 | 109 | 110 | if __name__ == "__main__": 111 | main() 112 | -------------------------------------------------------------------------------- /scripts/trackers/single_obj/visualize.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import itertools 3 | import logging 4 | import pickle 5 | from collections import defaultdict 6 | from multiprocessing import Pool 7 | from pathlib import Path 8 | 9 | import numpy as np 10 | from natsort import natsorted 11 | from PIL import Image 12 | from script_utils.common import common_setup 13 | from tqdm import tqdm 14 | 15 | from tao.toolkit.tao import Tao 16 | from tao.utils import fs 17 | from tao.utils import video as video_utils 18 | from tao.utils import vis 19 | from tao.utils.colormap import colormap 20 | 21 | 22 | def visualize(pickle_path, video_name, frames_root, cats, vis_cats, 23 | annotations_json, threshold, output_video): 24 | logging.getLogger('tao.toolkit.tao.tao').setLevel(logging.WARN) 25 | tao = Tao(annotations_json) 26 | frames_dir = frames_root / video_name 27 | frame_paths = natsorted(fs.glob_ext(frames_dir, fs.IMG_EXTENSIONS)) 28 | frames = [str(x.relative_to(frames_root)) for x in frame_paths] 29 | frame_indices = {x: i for i, x in enumerate(frames)} 30 | with open(pickle_path, 'rb') as f: 31 | # Map object_id to {'boxes': np.array} 32 | tracks = pickle.load(f) 33 | init_type = tracks.pop('_init_type', 'first') 34 | if init_type != 'first': 35 | raise NotImplementedError( 36 | f'init type "{init_type}" not yet implemented.') 37 | 38 | frame_annotations = defaultdict(list) 39 | init_frames = {} 40 | annotation_id_generator = itertools.count() 41 | for object_id, outputs in tracks.items(): 42 | init = tao.get_kth_annotation(object_id, k=0) 43 | init_frame = frame_indices[tao.imgs[init['image_id']]['file_name']] 44 | init_frames[object_id] = init_frame 45 | boxes = outputs['boxes'] 46 | for i, frame in enumerate(frames[init_frame:]): 47 | if len(boxes) <= i: 48 | logging.warning( 49 | f'Could not find box for object {object_id} for ' 50 | f'frame (index: {i}, {frame})') 51 | continue 52 | box = boxes[i].tolist() 53 | if len(box) == 4: 54 | box.append(1) 55 | x0, y0, x1, y1, score = box 56 | if score < threshold: 57 | continue 58 | w, h = x1 - x0 + 1, y1 - y0 + 1 59 | category = tao.tracks[object_id]['category_id'] 60 | if (vis_cats is not None 61 | and tao.cats[category]['name'] not in vis_cats): 62 | continue 63 | frame_annotations[frame].append({ 64 | 'id': next(annotation_id_generator), 65 | 'track_id': object_id, 66 | 'bbox': [x0, y0, w, h], 67 | 'category_id': category, 68 | 'score': score 69 | }) 70 | size = Image.open(frame_paths[0]).size 71 | output_video.parent.mkdir(exist_ok=True, parents=True) 72 | with video_utils.video_writer(output_video, size=size) as writer: 73 | color_generator = itertools.cycle(colormap(as_int=True).tolist()) 74 | colors = defaultdict(lambda: next(color_generator)) 75 | for
frame in frame_paths: 76 | image = np.array(Image.open(frame)) 77 | frame_key = str(frame.relative_to(frames_root)) 78 | tracks = frame_annotations[frame_key] 79 | image = vis.overlay_boxes_coco( 80 | image, 81 | tracks, 82 | colors=[colors[x['track_id']] for x in tracks]) 83 | image = vis.overlay_class_coco( 84 | image, 85 | tracks, 86 | categories=cats, 87 | font_scale=1, 88 | font_thickness=2) 89 | writer.write_frame(image) 90 | 91 | 92 | def visualize_star(kwargs): 93 | return visualize(**kwargs) 94 | 95 | 96 | def main(): 97 | # Use first line of file docstring as description if it exists. 98 | parser = argparse.ArgumentParser( 99 | description=__doc__.split('\n')[0] if __doc__ else '', 100 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 101 | parser.add_argument('--annotations', type=Path, required=True) 102 | # We need the frames dir because the pickles contain boxes into the ordered 103 | # list frames. 104 | parser.add_argument('--frames-dir', type=Path, required=True) 105 | parser.add_argument('--pickle-dir', type=Path, required=True) 106 | parser.add_argument('--oracle-category', action='store_true') 107 | parser.add_argument('--workers', type=int, default=8) 108 | parser.add_argument('--threshold', default=0.5, type=float) 109 | parser.add_argument('--vis-cats', nargs='*', type=str) 110 | parser.add_argument('--videos', nargs='*') 111 | parser.add_argument('--output-dir', 112 | type=Path, 113 | required=True) 114 | 115 | args = parser.parse_args() 116 | args.output_dir.mkdir(exist_ok=True, parents=True) 117 | common_setup(__file__, args.output_dir, args) 118 | 119 | paths = list(args.pickle_dir.rglob('*.pkl')) 120 | 121 | tao = Tao(args.annotations) 122 | cats = tao.cats.copy() 123 | for cat in cats.values(): 124 | if cat['name'] == 'baby': 125 | cat['name'] = 'person' 126 | 127 | tasks = [] 128 | for p in paths: 129 | video_name = str(p.relative_to(args.pickle_dir)).split('.pkl')[0] 130 | if args.videos is not None and video_name not in args.videos: 131 | continue 132 | output_video = args.output_dir / f'{video_name}.mp4' 133 | if output_video.exists(): 134 | continue 135 | tasks.append({ 136 | 'pickle_path': p, 137 | 'video_name': video_name, 138 | 'frames_root': args.frames_dir, 139 | 'cats': cats, 140 | 'vis_cats': args.vis_cats, 141 | 'annotations_json': args.annotations, 142 | 'threshold': args.threshold, 143 | 'output_video': output_video 144 | }) 145 | 146 | if args.workers == 0: 147 | for task in tqdm(tasks): 148 | visualize(**task) 149 | else: 150 | pool = Pool(args.workers) 151 | list(tqdm(pool.imap_unordered(visualize_star, tasks), 152 | total=len(tasks))) 153 | 154 | 155 | if __name__ == "__main__": 156 | main() 157 | -------------------------------------------------------------------------------- /scripts/trackers/sort/README.md: -------------------------------------------------------------------------------- 1 | SORT 2 | ===== 3 | 4 | A simple online and realtime tracking algorithm for 2D multiple object tracking in video sequences. 5 | See an example [video here](https://motchallenge.net/movies/ETH-Linthescher-SORT.mp4). 6 | 7 | By Alex Bewley 8 | 9 | ### Introduction 10 | 11 | SORT is a barebones implementation of a visual multiple object tracking framework based on rudimentary data association and state estimation techniques. It is designed for online tracking applications where only past and current frames are available and the method produces object identities on the fly. 
While this minimalistic tracker doesn't handle occlusion or re-entering objects its purpose is to serve as a baseline and testbed for the development of future trackers. 12 | 13 | SORT was initially described in an [arXiv tech report](http://arxiv.org/abs/1602.00763). At the time of the initial publication, SORT was ranked the best *open source* multiple object tracker on the [MOT benchmark](https://motchallenge.net/results/2D_MOT_2015/). 14 | 15 | This code has been tested on Mac OSX 10.10, and Ubuntu 14.04, with Python 2.7 (anaconda). 16 | 17 | **Note:** A significant proportion of SORT's accuracy is attributed to the detections. 18 | For your convenience, this repo also contains *Faster* RCNN detections for the MOT benchmark sequences in the [benchmark format](https://motchallenge.net/instructions/). To run the detector yourself please see the original [*Faster* RCNN project](https://github.com/ShaoqingRen/faster_rcnn) or the python reimplementation of [py-faster-rcnn](https://github.com/rbgirshick/py-faster-rcnn) by Ross Girshick. 19 | 20 | **Also see:** 21 | A new and improved version of SORT with a Deep Association Metric implemented in tensorflow is available at [https://github.com/nwojke/deep_sort](https://github.com/nwojke/deep_sort) . 22 | 23 | ### License 24 | 25 | SORT is released under the GPL License (refer to the LICENSE file for details) to promote the open use of the tracker and future improvements. If you require a permissive license contact Alex (alex@bewley.ai). 26 | 27 | ### Citing SORT 28 | 29 | If you find this repo useful in your research, please consider citing: 30 | 31 | @inproceedings{Bewley2016_sort, 32 | author={Bewley, Alex and Ge, Zongyuan and Ott, Lionel and Ramos, Fabio and Upcroft, Ben}, 33 | booktitle={2016 IEEE International Conference on Image Processing (ICIP)}, 34 | title={Simple online and realtime tracking}, 35 | year={2016}, 36 | pages={3464-3468}, 37 | keywords={Benchmark testing;Complexity theory;Detectors;Kalman filters;Target tracking;Visualization;Computer Vision;Data Association;Detection;Multiple Object Tracking}, 38 | doi={10.1109/ICIP.2016.7533003} 39 | } 40 | 41 | 42 | ### Dependencies: 43 | 44 | This code makes use of the following packages: 45 | 1. [`scikit-learn`](http://scikit-learn.org/stable/) 46 | 0. [`scikit-image`](http://scikit-image.org/download) 47 | 0. [`FilterPy`](https://github.com/rlabbe/filterpy) 48 | 49 | To install required dependencies run: 50 | ``` 51 | $ pip install -r requirements.txt 52 | ``` 53 | 54 | 55 | ### Demo: 56 | 57 | To run the tracker with the provided detections: 58 | 59 | ``` 60 | $ cd path/to/sort 61 | $ python sort.py 62 | ``` 63 | 64 | To display the results you need to: 65 | 66 | 0. Download the [2D MOT 2015 benchmark dataset](https://motchallenge.net/data/2D_MOT_2015/#download) 67 | 0. Create a symbolic link to the dataset 68 | ``` 69 | $ ln -s /path/to/MOT2015_challenge/data/2DMOT2015 mot_benchmark 70 | ``` 71 | 0. Run the demo with the ```--display``` flag 72 | ``` 73 | $ python sort.py --display 74 | ``` 75 | 76 | 77 | ### Main Results 78 | 79 | Using the [MOT challenge devkit](https://motchallenge.net/devkit/) the method produces the following results (as described in the paper). 
80 | 81 | Sequence | Rcll | Prcn | FAR | GT MT PT ML| FP FN IDs FM| MOTA MOTP MOTAL 82 | --------------- |:----:|:----:|:----:|:-------------:|:-------------------:|:------------------: 83 | TUD-Campus | 68.5 | 94.3 | 0.21 | 8 6 2 0| 15 113 6 9| 62.7 73.7 64.1 84 | ETH-Sunnyday | 77.5 | 81.9 | 0.90 | 30 11 16 3| 319 418 22 54| 59.1 74.4 60.3 85 | ETH-Pedcross2 | 51.9 | 90.8 | 0.39 | 133 17 60 56| 330 3014 77 103| 45.4 74.8 46.6 86 | ADL-Rundle-8 | 44.3 | 75.8 | 1.47 | 28 6 16 6| 959 3781 103 211| 28.6 71.1 30.1 87 | Venice-2 | 42.5 | 64.8 | 2.75 | 26 7 9 10| 1650 4109 57 106| 18.6 73.4 19.3 88 | KITTI-17 | 67.1 | 92.3 | 0.26 | 9 1 8 0| 38 225 9 16| 60.2 72.3 61.3 89 | *Overall* | 49.5 | 77.5 | 1.24 | 234 48 111 75| 3311 11660 274 499| 34.0 73.3 35.1 90 | 91 | 92 | ### Using SORT in your own project 93 | 94 | Below is the gist of how to instantiate and update SORT. See the ['__main__'](https://github.com/abewley/sort/blob/master/sort.py#L239) section of [sort.py](https://github.com/abewley/sort/blob/master/sort.py#L239) for a complete example. 95 | 96 | from sort import * 97 | 98 | #create instance of SORT 99 | mot_tracker = Sort() 100 | 101 | # get detections 102 | ... 103 | 104 | # update SORT 105 | track_bbs_ids = mot_tracker.update(detections) 106 | 107 | # track_bbs_ids is a np array where each row contains a valid bounding box and track_id (last column) 108 | ... 109 | 110 | 111 | -------------------------------------------------------------------------------- /scripts/trackers/sort/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TAO-Dataset/tao/a3a713c51e2fdeb3a106c34b06d889ea581150a7/scripts/trackers/sort/__init__.py -------------------------------------------------------------------------------- /scripts/trackers/sort/create_json_for_eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import random 5 | from pathlib import Path 6 | 7 | import numpy as np 8 | from natsort import natsorted 9 | from script_utils.common import common_setup 10 | from tqdm import tqdm 11 | 12 | 13 | def create_json(track_result, groundtruth, output_dir): 14 | # Image without extension -> image id 15 | image_stem_to_info = { 16 | x['file_name'].rsplit('.', 1)[0]: x for x in groundtruth['images'] 17 | } 18 | valid_videos = {x['name'] for x in groundtruth['videos']} 19 | 20 | all_annotations = [] 21 | found_predictions = {} 22 | for video in tqdm(valid_videos): 23 | video_npz = track_result / f'{video}.npz' 24 | if not video_npz.exists(): 25 | logging.error(f'Could not find video {video} at {video_npz}') 26 | continue 27 | video_result = np.load(video_npz) 28 | frame_names = [x for x in video_result.keys() if x != 'field_order'] 29 | video_found = {} 30 | for frame in natsorted(frame_names): 31 | # (x0, y0, x1, y1, class, score, box_index, track_id) 32 | frame_name = f'{video}/{frame}' 33 | if frame_name not in image_stem_to_info: 34 | continue 35 | video_found[frame_name] = True 36 | image_info = image_stem_to_info[frame_name] 37 | all_annotations.extend([{ 38 | # (x1, y1) -> (w, h) 39 | 'image_id': image_info['id'], 40 | 'video_id': image_info['video_id'], 41 | 'track_id': int(x[7]), 42 | 'bbox': [x[0], x[1], x[2] - x[0], x[3] - x[1]], 43 | 'category_id': x[4], 44 | 'score': x[5], 45 | } for x in video_result[frame]]) 46 | if not video_found: 47 | raise ValueError(f'Found no valid predictions for video {video}') 48 | 
found_predictions.update(video_found) 49 | if not found_predictions: 50 | raise ValueError('Found no valid predictions!') 51 | 52 | with_predictions = set(found_predictions.keys()) 53 | with_labels = set(image_stem_to_info.keys()) 54 | if with_predictions != with_labels: 55 | missing_videos = { 56 | x.rsplit('/', 1)[0] 57 | for x in with_labels - with_predictions 58 | } 59 | logging.warning( 60 | f'{len(with_labels - with_predictions)} images from ' 61 | f'{len(missing_videos)} videos did not have predictions!') 62 | 63 | with open(output_dir / 'results.json', 'w') as f: 64 | json.dump(all_annotations, f) 65 | 66 | 67 | def main(): 68 | # Use first line of file docstring as description if it exists. 69 | parser = argparse.ArgumentParser( 70 | description=__doc__.split('\n')[0] if __doc__ else '', 71 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 72 | parser.add_argument('--track-result', required=True, type=Path) 73 | parser.add_argument('--annotations-json', 74 | type=Path, 75 | help='Annotations json') 76 | parser.add_argument('--output-dir', required=True, type=Path) 77 | 78 | args = parser.parse_args() 79 | args.output_dir.mkdir(exist_ok=True, parents=True) 80 | common_setup(__file__, args.output_dir, args) 81 | 82 | with open(args.annotations_json, 'r') as f: 83 | groundtruth = json.load(f) 84 | 85 | create_json(args.track_result, groundtruth, args.output_dir) 86 | 87 | 88 | if __name__ == "__main__": 89 | main() 90 | -------------------------------------------------------------------------------- /scripts/trackers/sort/requirements.txt: -------------------------------------------------------------------------------- 1 | scipy 2 | filterpy~=1.4.1 3 | numba~=0.38.1 4 | scikit-image~=0.14.0 5 | scikit-learn~=0.19.1 6 | lap~=0.4.0 7 | -------------------------------------------------------------------------------- /scripts/trackers/sort/sort_with_detection_id.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from sort import associate_detections_to_trackers, KalmanBoxTracker 4 | 5 | 6 | class SortWithDetectionId(object): 7 | def __init__(self, max_age=1, min_hits=3, iou_threshold=0.3): 8 | """ 9 | Sets key parameters for SORT 10 | """ 11 | self.max_age = max_age 12 | self.min_hits = min_hits 13 | self.trackers = [] 14 | self.frame_count = 0 15 | self.iou_threshold = iou_threshold 16 | 17 | def update(self, dets): 18 | """ 19 | Args: 20 | dets (np.array): Shape (num_boxes, 5), where each row contains 21 | [x1, y1, x2, y2, score] 22 | 23 | Returns: 24 | tracks (np.array): Shape (num_boxes, 6), where each row contains 25 | [x1, y1, x2, y2, detection_index, track_id] 26 | """ 27 | self.frame_count += 1 28 | # get predicted locations from existing trackers.
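# (Each row of trks below holds the Kalman-predicted [x1, y1, x2, y2, 0] for
# one existing tracker; predictions containing NaNs mark trackers whose
# state has diverged, and those trackers are dropped before association.)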
29 | trks = np.zeros((len(self.trackers), 5)) 30 | to_del = [] 31 | ret = [] 32 | for t, trk in enumerate(trks): 33 | pos = self.trackers[t].predict()[0] 34 | trk[:] = [pos[0], pos[1], pos[2], pos[3], 0] 35 | if (np.any(np.isnan(pos))): 36 | to_del.append(t) 37 | 38 | trks = np.ma.compress_rows(np.ma.masked_invalid(trks)) 39 | for t in reversed(to_del): 40 | self.trackers.pop(t) 41 | matched, unmatched_dets, unmatched_trks = ( 42 | associate_detections_to_trackers(dets, trks, self.iou_threshold)) 43 | 44 | # update matched trackers with assigned detections 45 | track_to_det_index = {t: d for d, t in matched} 46 | # matched[i, 0] is matched to matched[i, 1] 47 | for t, trk in enumerate(self.trackers): 48 | if (t not in unmatched_trks): 49 | d = track_to_det_index[t] 50 | trk.update(dets[d, :]) 51 | 52 | # create and initialise new trackers for unmatched detections 53 | for i in unmatched_dets: 54 | trk = KalmanBoxTracker(dets[i, :]) 55 | self.trackers.append(trk) 56 | track_to_det_index[len(self.trackers) - 1] = i 57 | i = len(self.trackers) 58 | for t, trk in reversed(list(enumerate(self.trackers))): 59 | d = trk.get_state()[0] 60 | det_id = track_to_det_index.get(t, -1) 61 | if ((trk.time_since_update < 1) 62 | and (trk.hit_streak >= self.min_hits 63 | or self.frame_count <= self.min_hits)): 64 | ret.append( 65 | np.concatenate((d, [det_id, trk.id + 1])).reshape( 66 | 1, -1)) # +1 as MOT benchmark requires positive 67 | i -= 1 68 | # remove dead tracklet 69 | if (trk.time_since_update > self.max_age): 70 | self.trackers.pop(i) 71 | if (len(ret) > 0): 72 | return np.concatenate(ret) 73 | return np.empty((0, 6)) 74 | -------------------------------------------------------------------------------- /scripts/trackers/sort/track.py: -------------------------------------------------------------------------------- 1 | """Link detections using the SORT tracker.""" 2 | 3 | import argparse 4 | import itertools 5 | import json 6 | import logging 7 | import pickle 8 | import sys 9 | from collections import defaultdict 10 | from multiprocessing import Pool 11 | from pathlib import Path 12 | 13 | import numpy as np 14 | import torch 15 | from natsort import natsorted 16 | from script_utils.common import common_setup 17 | from torchvision.ops import nms 18 | from tqdm import tqdm 19 | 20 | from tao.utils.fs import dir_path 21 | 22 | # Add current directory to path 23 | sys.path.insert(0, str(Path(__file__).resolve().parent)) 24 | from sort_with_detection_id import SortWithDetectionId 25 | from create_json_for_eval import create_json 26 | 27 | 28 | def sort_track_vid(video_predictions, sort_kwargs): 29 | """ 30 | Args: 31 | video_predictions (List): video_predictions[i] is a numpy array of 32 | shape (n, 4), containing [x0, y0, x1, y1]. 33 | 34 | Returns: 35 | tracked_predictions (List): tracked_predictions[i] is a numpy array 36 | of shape (n', 6), containing 37 | [x0, y0, x1, y1, box_index, track_id], 38 | where box_index is -1 or indexes into video_predictions[i]. 
39 | """ 40 | tracker = SortWithDetectionId(**sort_kwargs) 41 | return [ 42 | tracker.update(boxes) for boxes in tqdm( 43 | video_predictions, desc='Running tracker', disable=True) 44 | ] 45 | 46 | 47 | def track_and_save(pickle_paths, output, score_threshold, 48 | nms_thresh, sort_kwargs): 49 | paths = natsorted(pickle_paths) 50 | all_instances = [] 51 | for path in paths: 52 | with open(path, 'rb') as f: 53 | data = pickle.load(f)['instances'] 54 | all_instances.append({ 55 | 'scores': np.array(data['scores']), 56 | 'pred_classes': np.array(data['pred_classes']), 57 | 'pred_boxes': np.array(data['pred_boxes']) 58 | }) 59 | 60 | if score_threshold > -float('inf'): 61 | for i, data in enumerate(all_instances): 62 | valid = data['scores'] > score_threshold 63 | for x in ('scores', 'pred_boxes', 'pred_classes'): 64 | data[x] = data[x][valid] 65 | 66 | categories = sorted({ 67 | x 68 | for data in all_instances for x in data['pred_classes'] 69 | }) 70 | 71 | frame_infos = defaultdict(list) 72 | id_gen = itertools.count(1) 73 | unique_track_ids = defaultdict(lambda: next(id_gen)) 74 | for category in categories: 75 | class_instances = [] 76 | for data in all_instances: 77 | in_class = data['pred_classes'] == category 78 | class_instances.append({ 79 | k: data[k][in_class] 80 | for k in ('scores', 'pred_boxes', 'pred_classes') 81 | }) 82 | 83 | if nms_thresh >= 0: 84 | for i, instances in enumerate(class_instances): 85 | nms_keep = nms(torch.from_numpy(instances['pred_boxes']), 86 | torch.from_numpy(instances['scores']), 87 | iou_threshold=nms_thresh).numpy() 88 | class_instances[i] = { 89 | 'scores': instances['scores'][nms_keep], 90 | 'pred_boxes': instances['pred_boxes'][nms_keep], 91 | 'pred_classes': instances['pred_classes'][nms_keep] 92 | } 93 | 94 | tracked_boxes = sort_track_vid( 95 | [x['pred_boxes'] for x in class_instances], sort_kwargs) 96 | 97 | for frame, frame_tracks in enumerate(tracked_boxes): 98 | # Each row is of the form (x0, y0, x1, y1, box_index, track_id) 99 | frame_boxes = frame_tracks[:, :4] 100 | box_indices = frame_tracks[:, 4].astype(int) 101 | track_ids = np.array([ 102 | unique_track_ids[(x, category)] for x in frame_tracks[:, 5] 103 | ]) 104 | 105 | frame_instances = class_instances[frame] 106 | frame_scores = np.zeros((len(box_indices), 1)) 107 | frame_classes = np.zeros((len(box_indices), 1)) 108 | for i, idx in enumerate(box_indices.astype(int)): 109 | if idx == -1: 110 | frame_classes[i] = -1 111 | frame_scores[i] = -1 112 | else: 113 | frame_classes[i] = frame_instances['pred_classes'][idx] 114 | frame_scores[i] = frame_instances['scores'][idx] 115 | frame_infos[paths[frame].stem].append( 116 | np.hstack( 117 | (frame_boxes, frame_classes + 1, frame_scores, 118 | box_indices[:, np.newaxis], track_ids[:, np.newaxis]))) 119 | 120 | frame_infos = {k: np.vstack(lst) for k, lst in frame_infos.items()} 121 | output.parent.mkdir(exist_ok=True, parents=True) 122 | np.savez_compressed(output, 123 | **frame_infos, 124 | field_order=[ 125 | 'x0', 'y0', 'x1', 'y1', 'class', 'score', 126 | 'box_index', 'track_id' 127 | ]) 128 | 129 | 130 | def track_and_save_star(args): 131 | track_and_save(*args) 132 | 133 | 134 | def main(): 135 | # Use first line of file docstring as description if it exists. 
136 |     parser = argparse.ArgumentParser(
137 |         description=__doc__.split('\n')[0] if __doc__ else '',
138 |         formatter_class=argparse.ArgumentDefaultsHelpFormatter)
139 |     parser.add_argument('--detections-dir',
140 |                         type=Path,
141 |                         required=True,
142 |                         help='Results directory with pickle or mat files')
143 |     parser.add_argument('--annotations',
144 |                         type=Path,
145 |                         required=True,
146 |                         help='Annotations json')
147 |     parser.add_argument(
148 |         '--output-dir',
149 |         type=Path,
150 |         required=True,
151 |         help=('Output directory, where a results.json will be output, as well '
152 |               'as a .npz file for each video, containing a boxes array of '
153 |               'size (num_boxes, 8), of the format [x0, y0, x1, y1, class, '
154 |               'score, box_index, track_id], where box_index maps into the '
155 |               'pickle files'))
156 |     parser.add_argument('--max-age', default=100, type=int)
157 |     parser.add_argument('--min-hits', default=1, type=int)
158 |     parser.add_argument('--min-iou', default=0.1, type=float)
159 |     parser.add_argument('--score-threshold',
160 |                         default=0.0005,
161 |                         help='Float or "none".')
162 |     parser.add_argument('--nms-thresh', type=float, default=-1)
163 |     parser.add_argument('--workers', default=8, type=int)
164 | 
165 |     args = parser.parse_args()
166 |     args.score_threshold = (-float('inf') if args.score_threshold == 'none'
167 |                             else float(args.score_threshold))
168 | 
169 |     args.output_dir.mkdir(exist_ok=True, parents=True)
170 |     common_setup(__file__, args.output_dir, args)
171 | 
172 |     npz_dir = dir_path(args.output_dir / 'npz_files')
173 | 
174 |     def get_output_path(video):
175 |         return npz_dir / (video + '.npz')
176 | 
177 |     with open(args.annotations, 'r') as f:
178 |         groundtruth = json.load(f)
179 |     videos = [x['name'] for x in groundtruth['videos']]
180 |     video_paths = {}
181 |     for video in tqdm(videos, desc='Collecting paths'):
182 |         output = get_output_path(video)
183 |         if output.exists():
184 |             logging.debug(f'{output} already exists, skipping...')
185 |             continue
186 |         vid_detections = args.detections_dir / video
187 |         assert vid_detections.exists(), (
188 |             f'No detections dir at {vid_detections}!')
189 |         detection_paths = natsorted(
190 |             (args.detections_dir / video).rglob('*.pkl'))
191 |         assert detection_paths, (
192 |             f'No detections pickles at {vid_detections}!')
193 |         video_paths[video] = detection_paths
194 | 
195 |     if not video_paths:
196 |         logging.info('Nothing to do! Exiting.')
197 |         return
198 |     logging.info(f'Found {len(video_paths)} videos to track.')
199 | 
200 |     tasks = []
201 |     for video, paths in tqdm(video_paths.items()):
202 |         output = get_output_path(video)
203 |         tasks.append((paths, output, args.score_threshold,
204 |                       args.nms_thresh, {
205 |                           'iou_threshold': args.min_iou,
206 |                           'min_hits': args.min_hits,
207 |                           'max_age': args.max_age
208 |                       }))
209 | 
210 |     if args.workers > 0:
211 |         pool = Pool(args.workers)
212 |         list(
213 |             tqdm(pool.imap_unordered(track_and_save_star, tasks),
214 |                  total=len(tasks),
215 |                  desc='Tracking'))
216 |     else:
217 |         for task in tqdm(tasks):
218 |             track_and_save(*task)
219 |     logging.info('Finished')
220 | 
221 |     create_json(npz_dir, groundtruth, args.output_dir)
222 | 
223 | 
224 | if __name__ == "__main__":
225 |     main()
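
# Example invocation (a sketch; all paths below are illustrative):
#
#   python track.py \
#       --detections-dir detections/ \
#       --annotations annotations.json \
#       --output-dir outputs/sort/
#
# For each video, this writes npz_files/<video>.npz with per-frame rows of
# [x0, y0, x1, y1, class, score, box_index, track_id], then aggregates them
# into a results.json via create_json.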
226 | -------------------------------------------------------------------------------- /scripts/trackers/sort/visualize.py: --------------------------------------------------------------------------------
1 | import argparse
2 | import itertools
3 | import json
4 | import random
5 | from collections import defaultdict
6 | from pathlib import Path
7 | 
8 | import numpy as np
9 | from multiprocessing import Pool
10 | from natsort import natsorted
11 | from PIL import Image
12 | from script_utils.common import common_setup
13 | from tqdm import tqdm
14 | 
15 | from tao.utils import fs
16 | from tao.utils import vis
17 | from tao.utils.colormap import colormap
18 | from tao.utils.video import video_writer
19 | 
20 | 
21 | def visualize_star(kwargs):
22 |     return visualize(**kwargs)
23 | 
24 | 
25 | def visualize(video_npz,
26 |               frames_dir,
27 |               categories,
28 |               threshold,
29 |               track_threshold,
30 |               output,
31 |               vis_categories=None,
32 |               progress=False):
33 |     try:
34 |         video_result = np.load(video_npz)
35 |     except ValueError:
36 |         print(video_npz)
37 |         raise
38 |     frame_names = [
39 |         x for x in video_result.keys()
40 |         if x != 'field_order' and not x.startswith('__')
41 |     ]
42 |     output.parent.mkdir(exist_ok=True, parents=True)
43 |     first_frame = fs.find_file_extensions(frames_dir, frame_names[0],
44 |                                           fs.IMG_EXTENSIONS)
45 |     if first_frame is None:
46 |         raise ValueError(f'Could not find frame with name {frame_names[0]} in '
47 |                          f'{frames_dir}')
48 |     ext = first_frame.suffix
49 |     w, h = Image.open(first_frame).size
50 |     color_generator = itertools.cycle(colormap(rgb=True).tolist())
51 |     colors = defaultdict(lambda: next(color_generator))
52 | 
53 |     track_scores = defaultdict(list)
54 |     for frame in frame_names:
55 |         for x in video_result[frame]:
56 |             track_scores[x[-1]].append(x[5])
57 |     track_scores = {t: np.mean(vs) for t, vs in track_scores.items()}
58 | 
59 |     with video_writer(str(output), (w, h)) as writer:
60 |         for frame in tqdm(natsorted(frame_names), disable=not progress):
61 |             # Format is one of
62 |             # (x0, y0, x1, y1, class, score, box_index, track_id)
63 |             # (x0, y0, x1, y1, class, score, track_id)
64 |             frame_result = video_result[frame]
65 |             annotations = [
66 |                 {
67 |                     # (x1, y1) -> (w, h)
68 |                     'bbox': [x[0], x[1], x[2] - x[0], x[3] - x[1]],
69 |                     'category_id': x[4],
70 |                     'score': x[5],
71 |                     'track_id': int(x[-1])
72 |                 } for x in frame_result
73 |                 if x[5] > threshold and track_scores[x[-1]] > track_threshold
74 |             ]
75 |             if vis_categories is not None:
76 |                 annotations = [
77 |                     x for x in annotations
78 |                     if categories[x['category_id']]['name'] in vis_categories
79 |                 ]
80 |             annotations = sorted(annotations, key=lambda x: x['score'])
81 |             box_colors = [colors[x['track_id']] for x in annotations]
82 |             image
= np.array(Image.open(frames_dir / (frame + ext))) 83 | image = vis.overlay_boxes_coco(image, 84 | annotations, 85 | colors=box_colors) 86 | image = vis.overlay_class_coco(image, 87 | annotations, 88 | categories, 89 | font_scale=1, 90 | font_thickness=2, 91 | show_track_id=True) 92 | writer.write_frame(image) 93 | 94 | 95 | def main(): 96 | # Use first line of file docstring as description if it exists. 97 | parser = argparse.ArgumentParser( 98 | description=__doc__.split('\n')[0] if __doc__ else '', 99 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 100 | parser.add_argument('--track-result', required=True, type=Path) 101 | parser.add_argument('--threshold', default=0.3, type=float) 102 | parser.add_argument('--track-threshold', default=0, type=float) 103 | parser.add_argument( 104 | '--annotations-json', 105 | type=Path, 106 | help='Annotations json; we only care about the "categories" field.') 107 | parser.add_argument('--frames-dir', required=True, type=Path) 108 | parser.add_argument('--output-dir', required=True, type=Path) 109 | parser.add_argument('--videos', nargs='*') 110 | parser.add_argument('--vis-cats', nargs='*', type=str) 111 | parser.add_argument('--num-videos', default=-1, type=int) 112 | parser.add_argument('--seed', default=0, type=int) 113 | parser.add_argument('--workers', default=8, type=int) 114 | 115 | args = parser.parse_args() 116 | args.output_dir.mkdir(exist_ok=True, parents=True) 117 | common_setup(__file__, args.output_dir, args) 118 | 119 | with open(args.annotations_json, 'r') as f: 120 | categories = {x['id']: x for x in json.load(f)['categories']} 121 | 122 | all_npz = args.track_result.rglob('*.npz') 123 | if args.num_videos > 0: 124 | all_npz = list(all_npz) 125 | random.seed(args.seed) 126 | random.shuffle(all_npz) 127 | all_npz = all_npz[:args.num_videos] 128 | 129 | tasks = [] 130 | for video_npz in all_npz: 131 | video = video_npz.relative_to(args.track_result).with_suffix('') 132 | if args.videos and str(video) not in args.videos: 133 | continue 134 | if args.videos: 135 | print(video) 136 | frames_dir = args.frames_dir / video 137 | output = args.output_dir / (str(video) + '.mp4') 138 | if output.exists(): 139 | continue 140 | tasks.append({ 141 | 'video_npz': video_npz, 142 | 'frames_dir': frames_dir, 143 | 'categories': categories, 144 | 'vis_categories': args.vis_cats, 145 | 'threshold': args.threshold, 146 | 'track_threshold': args.track_threshold, 147 | 'output': output 148 | }) 149 | 150 | if args.workers > 0: 151 | pool = Pool(args.workers) 152 | list(tqdm(pool.imap_unordered(visualize_star, tasks), 153 | total=len(tasks))) 154 | else: 155 | for task in tqdm(tasks): 156 | task['progress'] = True 157 | visualize(**task) 158 | 159 | 160 | if __name__ == "__main__": 161 | main() 162 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Note: To use the 'upload' functionality of this file, you must: 5 | # $ pipenv install twine --dev 6 | 7 | import io 8 | import os 9 | import sys 10 | from shutil import rmtree 11 | 12 | from setuptools import find_packages, setup, Command 13 | 14 | # Package meta-data. 
15 | NAME = 'tao' 16 | DESCRIPTION = 'Track Any Object' 17 | URL = 'http://taodataset.org' 18 | EMAIL = 'achald@cs.cmu.edu' 19 | AUTHOR = 'Achal Dave' 20 | REQUIRES_PYTHON = '>=3.6.0' 21 | VERSION = '0.1.0' 22 | 23 | # What packages are required for this module to be executed? 24 | REQUIRED = [ 25 | 'script_utils @ git+https://github.com/achalddave/python-script-utils.git@v0.0.2#egg=script_utils', 26 | 'moviepy~=0.2', 'scipy', 'natsort', 'tqdm', 'yacs', 'boto3', 'youtube_dl', 27 | 'numba', 'motmetrics' 28 | # 'requests', 'maya', 'records', 29 | ] 30 | 31 | # What packages are optional? 32 | EXTRAS = { 33 | # 'fancy feature': ['django'], 34 | } 35 | 36 | # The rest you shouldn't have to touch too much :) 37 | # ------------------------------------------------ 38 | # Except, perhaps the License and Trove Classifiers! 39 | # If you do change the License, remember to change the Trove Classifier for that! 40 | 41 | here = os.path.abspath(os.path.dirname(__file__)) 42 | 43 | # Import the README and use it as the long-description. 44 | # Note: this will only work if 'README.md' is present in your MANIFEST.in file! 45 | try: 46 | with io.open(os.path.join(here, 'README.md'), encoding='utf-8') as f: 47 | long_description = '\n' + f.read() 48 | except FileNotFoundError: 49 | long_description = DESCRIPTION 50 | 51 | # Load the package's __version__.py module as a dictionary. 52 | about = {} 53 | if not VERSION: 54 | project_slug = NAME.lower().replace("-", "_").replace(" ", "_") 55 | with open(os.path.join(here, project_slug, '__version__.py')) as f: 56 | exec(f.read(), about) 57 | else: 58 | about['__version__'] = VERSION 59 | 60 | 61 | class UploadCommand(Command): 62 | """Support setup.py upload.""" 63 | 64 | description = 'Build and publish the package.' 65 | user_options = [] 66 | 67 | @staticmethod 68 | def status(s): 69 | """Prints things in bold.""" 70 | print('\033[1m{0}\033[0m'.format(s)) 71 | 72 | def initialize_options(self): 73 | pass 74 | 75 | def finalize_options(self): 76 | pass 77 | 78 | def run(self): 79 | try: 80 | self.status('Removing previous builds…') 81 | rmtree(os.path.join(here, 'dist')) 82 | except OSError: 83 | pass 84 | 85 | self.status('Building Source and Wheel (universal) distribution…') 86 | os.system('{0} setup.py sdist bdist_wheel --universal'.format(sys.executable)) 87 | 88 | self.status('Uploading the package to PyPI via Twine…') 89 | os.system('twine upload dist/*') 90 | 91 | self.status('Pushing git tags…') 92 | os.system('git tag v{0}'.format(about['__version__'])) 93 | os.system('git push --tags') 94 | 95 | sys.exit() 96 | 97 | 98 | # Where the magic happens: 99 | setup( 100 | name=NAME, 101 | version=about['__version__'], 102 | description=DESCRIPTION, 103 | long_description=long_description, 104 | long_description_content_type='text/markdown', 105 | author=AUTHOR, 106 | author_email=EMAIL, 107 | python_requires=REQUIRES_PYTHON, 108 | url=URL, 109 | packages=find_packages(exclude=["tests", "*.tests", "*.tests.*", "tests.*"]), 110 | # If your package is a single module, use this instead of 'packages': 111 | # py_modules=['tao'], 112 | 113 | # entry_points={ 114 | # 'console_scripts': ['mycli=mymodule:cli'], 115 | # }, 116 | install_requires=REQUIRED, 117 | extras_require=EXTRAS, 118 | include_package_data=True, 119 | license='MIT', 120 | classifiers=[ 121 | # Trove classifiers 122 | # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers 123 | 'License :: OSI Approved :: MIT License', 124 | 'Programming Language :: Python', 125 | 'Programming 
Language :: Python :: 3', 126 | 'Programming Language :: Python :: 3.6', 127 | 'Programming Language :: Python :: Implementation :: CPython', 128 | 'Programming Language :: Python :: Implementation :: PyPy' 129 | ], 130 | # $ setup.py publish support. 131 | cmdclass={ 132 | 'upload': UploadCommand, 133 | }, 134 | ) 135 | -------------------------------------------------------------------------------- /tao/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TAO-Dataset/tao/a3a713c51e2fdeb3a106c34b06d889ea581150a7/tao/__init__.py -------------------------------------------------------------------------------- /tao/toolkit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TAO-Dataset/tao/a3a713c51e2fdeb3a106c34b06d889ea581150a7/tao/toolkit/__init__.py -------------------------------------------------------------------------------- /tao/toolkit/tao/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from .tao import Tao 3 | from .results import TaoResults 4 | from .eval import TaoEval 5 | 6 | logging.basicConfig( 7 | format="[%(asctime)s] %(name)s %(levelname)s: %(message)s", 8 | datefmt="%m/%d %H:%M:%S", 9 | level=logging.WARN, 10 | ) 11 | 12 | __all__ = ["Tao", "TaoResults", "TaoEval"] 13 | -------------------------------------------------------------------------------- /tao/toolkit/tao/results.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import logging 3 | from collections import defaultdict 4 | from copy import deepcopy 5 | 6 | import numpy as np 7 | 8 | from .tao import Tao 9 | 10 | 11 | class TaoResults(Tao): 12 | def __init__(self, tao_gt, results, max_dets=300): 13 | """Constructor for Tao results. 14 | Args: 15 | tao_gt (Tao or str): Tao class instance, or str containing path 16 | of annotation file) 17 | results (str or List[dict]): Contains path of result file or a 18 | list of dicts. Each dict should be an annotation containing 19 | a 'track_id.' The score for a track will be set as the average 20 | of the score of all annotations in the track. To use an 21 | alternative mechanism, the caller should pre-compute the track 22 | score and ensure all annotations in the track have the same 23 | score. 24 | max_dets (int): max number of detections per image. The official 25 | value of max_dets for Tao is 300. 26 | """ 27 | if isinstance(tao_gt, Tao): 28 | self.dataset = deepcopy(tao_gt.dataset) 29 | elif isinstance(tao_gt, str): 30 | self.dataset = self._load_json(tao_gt) 31 | else: 32 | raise TypeError("Unsupported type {} of tao_gt.".format( 33 | type(tao_gt))) 34 | 35 | self.logger = logging.getLogger('tao.results') 36 | self.logger.info("Loading and preparing results.") 37 | 38 | if isinstance(results, str): 39 | result_anns = self._load_json(results) 40 | else: 41 | # This path way is provided for efficiency, in case the JSON was 42 | # already loaded by the caller. 43 | self.logger.warn( 44 | "Assuming user provided the results in correct format.") 45 | result_anns = results 46 | 47 | merge_map = Tao._construct_merge_map(self.dataset) 48 | for x in result_anns: 49 | if x['category_id'] in merge_map: 50 | x['category_id'] = merge_map[x['category_id']] 51 | 52 | assert isinstance(result_anns, list), "results is not a list." 
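
        # For reference, each result annotation is expected to look like the
        # following sketch (illustrative values; the fields are the ones
        # consumed below):
        #   {'image_id': 42, 'video_id': 3, 'category_id': 7,
        #    'bbox': [x0, y0, width, height], 'score': 0.87, 'track_id': 12}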
53 | 54 | self.ensure_unique_track_ids(result_anns) 55 | 56 | if max_dets >= 0: 57 | # NOTE: We limit detections per _frame_, not per video. 58 | result_anns = self.limit_dets_per_image(result_anns, max_dets) 59 | 60 | tracks = {} # Map track_id to track object 61 | if "bbox" in result_anns[0]: 62 | for id, ann in enumerate(result_anns): 63 | x1, y1, w, h = ann["bbox"] 64 | x2 = x1 + w 65 | y2 = y1 + h 66 | 67 | if "segmentation" not in ann: 68 | ann["segmentation"] = [[x1, y1, x1, y2, x2, y2, x2, y1]] 69 | 70 | track_id = ann['track_id'] 71 | if track_id not in tracks: 72 | tracks[track_id] = { 73 | 'id': track_id, 74 | 'video_id': ann['video_id'], 75 | 'category_id': ann['category_id'] 76 | } 77 | assert tracks[track_id]['category_id'] == ann['category_id'], ( 78 | f'Annotations for track {track_id} have multiple ' 79 | f'categories') 80 | ann["area"] = w * h 81 | ann["id"] = id + 1 82 | 83 | self.dataset["annotations"] = result_anns 84 | self.dataset["tracks"] = list(tracks.values()) 85 | self._create_index() 86 | 87 | _required_average = False 88 | for track_id, track_anns in self.track_ann_map.items(): 89 | scores = [float(x['score']) for x in track_anns] 90 | unique_scores = set(scores) 91 | if len(unique_scores) > 1: 92 | _required_average = True 93 | avg = np.mean(scores) 94 | self.tracks[track_id]['score'] = avg 95 | for x in track_anns: 96 | x['score'] = avg 97 | elif len(unique_scores) == 1: 98 | self.tracks[track_id]['score'] = unique_scores.pop() 99 | if _required_average: 100 | self.logger.warn( 101 | 'At least one track had annotations with different scores; ' 102 | 'using average of individual annotation scores as track ' 103 | 'scores.') 104 | 105 | img_ids_in_result = [ann["image_id"] for ann in result_anns] 106 | 107 | assert set(img_ids_in_result) == ( 108 | set(img_ids_in_result) & set(self.get_img_ids()) 109 | ), "Results do not correspond to current Tao set." 
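
    # Typical usage (a sketch; the file names are placeholders):
    #
    #   gt = Tao('annotations/validation.json')
    #   dt = TaoResults(gt, 'results.json')
    #
    # After construction, every annotation in a track carries the track's
    # (averaged) score, and self.dataset['tracks'] holds one entry per
    # predicted track.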
110 | 111 | def ensure_unique_track_ids(self, result_anns): 112 | track_id_videos = {} 113 | for ann in result_anns: 114 | t = ann['track_id'] 115 | if t not in track_id_videos: 116 | track_id_videos[t] = ann['video_id'] 117 | assert ann['video_id'] == track_id_videos[t], ( 118 | f'Track id {t} appears in more than one video: ' 119 | f'{track_id_videos[t]} and {ann["video_id"]}') 120 | 121 | def limit_dets_per_image(self, anns, max_dets): 122 | img_ann = defaultdict(list) 123 | for ann in anns: 124 | img_ann[ann["image_id"]].append(ann) 125 | 126 | for img_id, _anns in img_ann.items(): 127 | if len(_anns) <= max_dets: 128 | continue 129 | _anns = sorted(_anns, key=lambda ann: ann["score"], reverse=True) 130 | img_ann[img_id] = _anns[:max_dets] 131 | 132 | return [ann for anns in img_ann.values() for ann in anns] 133 | 134 | def get_top_results(self, img_id, score_thrs): 135 | raise NotImplementedError( 136 | 'Unclear if this should be per image or per video') 137 | # LVIS implementation below: 138 | # ann_ids = self.get_ann_ids(img_ids=[img_id]) 139 | # anns = self.load_anns(ann_ids) 140 | # return list(filter(lambda ann: ann["score"] > score_thrs, anns)) 141 | -------------------------------------------------------------------------------- /tao/trackers/sot/base.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | from abc import ABC, abstractmethod 4 | from contextlib import contextmanager 5 | from functools import partial 6 | from pathlib import Path 7 | 8 | import numpy as np 9 | from moviepy.video.io.ffmpeg_writer import FFMPEG_VideoWriter 10 | from tqdm import tqdm 11 | from PIL import Image 12 | 13 | from tao.utils import vis 14 | 15 | _GREEN = (18, 127, 15) 16 | _GRAY = (218, 227, 218) 17 | _BLACK = (0, 0, 0) 18 | COLOR_BOX = COLOR_MASK = [255*x for x in (0.000, 0.447, 0.741)] 19 | COLOR_TEXT = _GRAY 20 | COLOR_TEXT_INACTIVE = _BLACK 21 | COLOR_MASK_INACTIVE = COLOR_BOX_INACTIVE = _GRAY 22 | 23 | WIDTH_BOX = 10 24 | WIDTH_BOX_INACTIVE = 1 25 | 26 | WIDTH_MASK = 2 27 | BORDER_ALPHA_MASK = 0.9 28 | WIDTH_MASK_INACTIVE = 1 29 | 30 | 31 | class Tracker(ABC): 32 | @property 33 | def stateless(self): 34 | return False 35 | 36 | @abstractmethod 37 | def init(self, image, box): 38 | """ 39 | Args: 40 | image (np.array): Shape (height, width, num_channels). RGB image. 41 | box (list of int): (x0, y0, x1, y1). 0-indexed coordinates from 42 | top-left. 43 | """ 44 | pass 45 | 46 | @abstractmethod 47 | def update(self, image): 48 | """ 49 | Args: 50 | image (np.array): Shape (height, width, num_channels). RGB image. 51 | 52 | Returns: 53 | box (list of int): (x0, y0, x1, y1). 0-indexed coordinates from 54 | top-left. 55 | score (float) 56 | """ 57 | pass 58 | 59 | def track_yield(self, 60 | img_files, 61 | box, 62 | yield_image=False, 63 | **unused_extra_args): 64 | """ 65 | Args: 66 | img_files (list of str/Path): Ordered list of image paths 67 | box (list of int): (x0, y0, x1, y1). 0-indexed coordinates from 68 | top-left. 69 | yield_image (bool): Whether to yield the original image. Useful 70 | if the caller wants to operate on images without having to 71 | re-read them from disk. 72 | 73 | Yields: 74 | box (np.array): Shape (5, ), containing (x0, y0, x1, y1, score). 75 | 0-indexed coordinates from top-left. 76 | tracker_time (float): Time elapsed in tracker. 77 | image (optional, np.array): Image loaded from img_files; see 78 | yield_image. 
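
        Example (a sketch; assumes a concrete Tracker subclass and a list of
        frame paths):

            for box, tracker_time, extra in tracker.track_yield(
                    frame_paths, initial_box):
                x0, y0, x1, y1, score = box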
79 | """ 80 | for f, img_file in enumerate(img_files): 81 | image = Image.open(img_file) 82 | if not image.mode == 'RGB': 83 | image = image.convert('RGB') 84 | image = np.array(image) 85 | 86 | start_time = time.time() 87 | if f == 0: 88 | self.init(image, box) 89 | elapsed_time = time.time() - start_time 90 | box = np.array([box[0], box[1], box[2], box[3], float('inf')]) 91 | extra_output = {} 92 | else: 93 | output = self.update(image) 94 | assert len(output) in (2, 3) 95 | box, score = output[:2] 96 | extra_output = output[2] if len(output) == 3 else {} 97 | elapsed_time = time.time() - start_time 98 | box = np.array([box[0], box[1], box[2], box[3], score]) 99 | if yield_image: 100 | yield box, elapsed_time, extra_output, image 101 | else: 102 | yield box, elapsed_time, extra_output 103 | 104 | @contextmanager 105 | def videowriter(self, 106 | output_video, 107 | width, 108 | height, 109 | fps=30, 110 | ffmpeg_params=None): 111 | if isinstance(output_video, Path): 112 | output_video = str(output_video) 113 | if ffmpeg_params is None: 114 | ffmpeg_params = [ 115 | '-vf', "scale=trunc(iw/2)*2:trunc(ih/2)*2", '-pix_fmt', 116 | 'yuv420p' 117 | ] 118 | with FFMPEG_VideoWriter( 119 | output_video, 120 | size=(width, height), 121 | fps=fps, 122 | ffmpeg_params=ffmpeg_params) as writer: 123 | yield writer 124 | 125 | def vis_single_prediction(self, 126 | image, 127 | box, 128 | mask=None, 129 | label=None, 130 | mask_border_width=WIDTH_MASK, 131 | mask_border_alpha=BORDER_ALPHA_MASK, 132 | box_color=COLOR_BOX, 133 | text_color=COLOR_TEXT, 134 | mask_color=COLOR_MASK): 135 | """ 136 | Args: 137 | image (np.array) 138 | box (list-like): x0, y0, x1, y1, score 139 | mask (np.array): Shape (height, width) 140 | """ 141 | if mask is None: 142 | image = vis.vis_bbox( 143 | image, (box[0], box[1], box[2] - box[0], box[3] - box[1]), 144 | fill_color=box_color) 145 | if label is None: 146 | text = f'Object: {box[4]:.02f}' 147 | else: 148 | # text = f'{label}: {box[4]:.02f}' 149 | text = f'{label}' 150 | image = vis.vis_class(image, (box[0], box[1] - 2), 151 | text, 152 | font_scale=0.75, 153 | text_color=text_color) 154 | # if box[4] < 0.8: # Draw gray masks when below threshold. 155 | # mask_color = [100, 100, 100] 156 | if mask is not None: 157 | image = vis.vis_mask( 158 | image, 159 | mask, 160 | mask_color, 161 | border_thick=mask_border_width, 162 | border_alpha=mask_border_alpha) 163 | return image 164 | 165 | def vis_image(self, 166 | image, 167 | box, 168 | mask=None, 169 | label=None, 170 | other_boxes=[], 171 | other_masks=[], 172 | other_labels=[], 173 | vis_threshold=0.1): 174 | """ 175 | Args: 176 | image (np.array) 177 | box (list-like): x0, y0, x1, y1, score 178 | mask (np.array): Shape (height, width) 179 | other_boxes (list[list-like]): Contains alternative boxes that 180 | were not selected. 181 | other_masks (list[list-like]): Contains masks for alternative 182 | boxes that were not selected. 183 | """ 184 | return self.vis_single_prediction(image, box, mask, label=label) 185 | 186 | def track(self, 187 | img_files, 188 | box, 189 | show_progress=False, 190 | output_video=None, 191 | output_video_fps=30, 192 | visualize_subsample=1, 193 | visualize_threshold=0.1, 194 | return_masks=False, 195 | **tracker_args): 196 | """ 197 | Like self.track, but collect all tracking results in numpy arrays. 198 | 199 | Args: 200 | img_files (list of str/Path): Ordered list of image paths 201 | box (list of int): (x0, y0, x1, y1). 0-indexed coordinates from 202 | top-left. 
203 | output_vis 204 | return_masks (bool): If false, don't return masks. This is helpful 205 | for OxUvA, where collecting all the masks may use too much 206 | memory. 207 | 208 | Returns: 209 | boxes (np.array): Shape (num_frames, 5), contains 210 | (x0, y0, x1, y1, score) for each frame. 0-indexed coordinates 211 | from top-left. 212 | times (np.array): Shape (num_frames,), contains timings for each 213 | frame. 214 | """ 215 | frame_num = len(img_files) 216 | boxes = np.zeros((frame_num, 5)) 217 | if return_masks: 218 | masks = [None] * frame_num 219 | times = np.zeros(frame_num) 220 | 221 | pbar = partial(tqdm, total=len(img_files), disable=not show_progress) 222 | if output_video is None: 223 | for f, (box, elapsed_time, extra) in enumerate( 224 | pbar(self.track_yield(img_files, box, **tracker_args))): 225 | boxes[f] = box 226 | times[f] = elapsed_time 227 | if return_masks: 228 | masks[f] = extra.get('mask', None) 229 | else: 230 | output_video = Path(output_video) 231 | output_video.parent.mkdir(exist_ok=True, parents=True) 232 | # Some videos don't play in Firefox and QuickTime if '-pix_fmt 233 | # yuv420p' is not specified, and '-pix_fmt yuv420p' requires that 234 | # the dimensions be even, so we need the '-vf scale=...' filter. 235 | width, height = Image.open(img_files[0]).size 236 | with self.videowriter( 237 | output_video, width=width, height=height, 238 | fps=output_video_fps) as writer: 239 | track_outputs = self.track_yield( 240 | img_files, box, yield_image=True, **tracker_args) 241 | for f, (box, elapsed_time, extra, image) in enumerate( 242 | pbar(track_outputs)): 243 | mask = extra.get('mask', None) 244 | if mask is not None and mask.shape != image.shape[:2]: 245 | logging.warn( 246 | f'Resizing mask (shape {mask.shape}) to match ' 247 | f'image (shape {image.shape[:2]})') 248 | new_h, new_w = image.shape[:2] 249 | mask = np.asarray( 250 | Image.fromarray(mask).resize( 251 | (new_w, new_h), resample=Image.NEAREST)) 252 | other_boxes = extra.get('other_boxes', []) 253 | other_masks = extra.get('other_masks', []) 254 | label = extra.get('label', None) 255 | other_labels = extra.get('other_labels', []) 256 | if (f % visualize_subsample) == 0: 257 | writer.write_frame( 258 | self.vis_image(image, 259 | box, 260 | mask, 261 | label=label, 262 | other_boxes=other_boxes, 263 | other_masks=other_masks, 264 | other_labels=other_labels, 265 | vis_threshold=visualize_threshold)) 266 | boxes[f] = box 267 | times[f] = elapsed_time 268 | if return_masks: 269 | masks[f] = mask 270 | if return_masks: 271 | return boxes, masks, times 272 | else: 273 | return boxes, None, times 274 | -------------------------------------------------------------------------------- /tao/trackers/sot/pysot.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from pysot.core.config import cfg 5 | from pysot.models.model_builder import ModelBuilder 6 | from pysot.tracker.tracker_builder import build_tracker 7 | from .base import Tracker 8 | 9 | 10 | class PysotTracker(Tracker): 11 | def __init__(self, config_file, model_path): 12 | super().__init__() 13 | cfg.merge_from_file(config_file) 14 | model = ModelBuilder() 15 | model.load_state_dict( 16 | torch.load(model_path, 17 | map_location=lambda storage, loc: storage.cpu())) 18 | model.eval().cuda() 19 | self.tracker = build_tracker(model) 20 | 21 | def init(self, image, box): 22 | x0, y0, x1, y1 = box 23 | w = x1 - x0 24 | h = y1 - y0 25 | image = np.array(image)[:, :, [2, 1, 
0]] # RGB -> BGR 26 | self.tracker.init(image, (x0, y0, w, h)) 27 | 28 | def update(self, image): 29 | image = np.array(image)[:, :, [2, 1, 0]] # RGB -> BGR 30 | output = self.tracker.track(image) 31 | x0, y0, w, h = output['bbox'] 32 | box = (x0, y0, x0 + w, y0 + h) 33 | return box, output['best_score'], {} 34 | -------------------------------------------------------------------------------- /tao/trackers/sot/pytracking.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import numpy as np 4 | 5 | from pytracking.features.net_wrappers import NetWithBackbone 6 | 7 | from .base import Tracker 8 | 9 | 10 | class PytrackingTracker(Tracker): 11 | """Wrapper around PyTracking Trackers. 12 | 13 | This re-implements some code from pytracking.evaluation.Tracker, removing 14 | parts of the code from that class that write to disk.""" 15 | def __init__(self, tracker_name, tracker_param, model_path=None): 16 | super().__init__() 17 | 18 | tracker_module = importlib.import_module( 19 | 'pytracking.tracker.{}'.format(tracker_name)) 20 | self.tracker_class = tracker_module.get_tracker_class() 21 | self.params = self.get_parameters(tracker_name, 22 | tracker_param, 23 | model_path=model_path) 24 | 25 | def get_parameters(self, tracker_name, tracker_param, model_path=None): 26 | """Get parameters.""" 27 | param_module = importlib.import_module( 28 | 'pytracking.parameter.{}.{}'.format(tracker_name, 29 | tracker_param)) 30 | params = param_module.parameters() 31 | if model_path is not None: 32 | params.net = NetWithBackbone(net_path=model_path, 33 | use_gpu=params.use_gpu) 34 | return params 35 | 36 | def init(self, image, box): 37 | self.tracker = self.tracker_class(self.params) 38 | x0, y0, x1, y1 = box 39 | w = x1 - x0 40 | h = y1 - y0 41 | image = np.array(image)[:, :, [2, 1, 0]] # RGB -> BGR 42 | self.tracker.initialize(image, {'init_bbox': [x0, y0, w, h]}) 43 | 44 | def update(self, image): 45 | image = np.array(image)[:, :, [2, 1, 0]] # RGB -> BGR 46 | output = self.tracker.track(image) 47 | x0, y0, w, h = output['target_bbox'] 48 | box = (x0, y0, x0 + w, y0 + h) 49 | return box, self.tracker.debug_info['max_score'], {} 50 | -------------------------------------------------------------------------------- /tao/trackers/sot/srdcf.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import logging 3 | import os 4 | from tempfile import NamedTemporaryFile 5 | 6 | import numpy as np 7 | from scipy.io import loadmat 8 | 9 | from tao.utils.paths import ROOT_DIR 10 | from .base import Tracker 11 | 12 | 13 | SRDCF_ROOT = ROOT_DIR / 'third_party/srdcf' 14 | 15 | 16 | class SrdcfTracker(Tracker): 17 | def init(self, *args, **kwargs): 18 | raise NotImplementedError 19 | 20 | def update(self, *args, **kwargs): 21 | raise NotImplementedError 22 | 23 | def track(self, 24 | img_files, 25 | box, 26 | show_progress=False, 27 | output_video=None, 28 | output_video_fps=30, 29 | visualize_subsample=1, 30 | visualize_threshold=0.1, 31 | return_masks=False, 32 | **tracker_args): 33 | x0, y0, x1, y1 = box 34 | w = x1 - x0 35 | h = y1 - y0 36 | region = [x0, y0, w, h] 37 | images_list = NamedTemporaryFile('w') 38 | images_list.writelines([f'{x}\n' for x in img_files]) 39 | images_list.seek(0) 40 | # print(images_list.name) 41 | # print('hi') 42 | # subprocess.run(['cat', images_list.name], stderr=subprocess.STDOUT) 43 | # print('hi') 44 | # print([f'{x}\n' for x in img_files][:5]) 45 | # print('hi') 46 | # 
print(img_files) 47 | # print('hi') 48 | 49 | output = NamedTemporaryFile('w', suffix='.mat') 50 | command = [ 51 | 'matlab', '-r', 52 | f"run_SRDCF_TAO('{images_list.name}', {region}, '{output.name}'); " 53 | f"quit" 54 | ] 55 | # Conda is clashing with MATLAB here, causing an error in C++ ABIs. 56 | # Unsetting LD_LIBRARY_PATH fixes this. 57 | env = os.environ.copy() 58 | env['LD_LIBRARY_PATH'] = '' 59 | try: 60 | subprocess.check_output(command, 61 | stderr=subprocess.STDOUT, 62 | cwd=str(SRDCF_ROOT), 63 | env=env) 64 | except subprocess.CalledProcessError as e: 65 | logging.fatal('Failed command.\nException: %s\nOutput %s', 66 | e.returncode, e.output.decode('utf-8')) 67 | raise 68 | 69 | result = loadmat(output.name)['results'].squeeze() 70 | images_list.close() 71 | output.close() 72 | 73 | boxes = result['res'].item() 74 | # width, height -> x1, y1 75 | boxes[:, 2] += boxes[:, 0] 76 | boxes[:, 3] += boxes[:, 1] 77 | # scores = result['scores'].item() 78 | scores = np.ones((boxes.shape[0], 1)) 79 | scores[0] = float('inf') 80 | boxes = np.hstack((boxes, scores)) 81 | return boxes, None, None 82 | -------------------------------------------------------------------------------- /tao/trackers/sot/staple.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import logging 3 | import os 4 | from tempfile import NamedTemporaryFile 5 | 6 | import numpy as np 7 | from scipy.io import loadmat 8 | 9 | from tao.utils.paths import ROOT_DIR 10 | from .base import Tracker 11 | 12 | 13 | STAPLE_ROOT = ROOT_DIR / 'third_party/staple' 14 | 15 | 16 | class StapleTracker(Tracker): 17 | def init(self, *args, **kwargs): 18 | raise NotImplementedError 19 | 20 | def update(self, *args, **kwargs): 21 | raise NotImplementedError 22 | 23 | def track(self, 24 | img_files, 25 | box, 26 | show_progress=False, 27 | output_video=None, 28 | output_video_fps=30, 29 | visualize_subsample=1, 30 | visualize_threshold=0.1, 31 | return_masks=False, 32 | **tracker_args): 33 | x0, y0, x1, y1 = box 34 | w = x1 - x0 35 | h = y1 - y0 36 | region = [x0, y0, w, h] 37 | images_list = NamedTemporaryFile('w') 38 | images_list.writelines([f'{x}\n' for x in img_files]) 39 | images_list.seek(0) 40 | # print(images_list.name) 41 | # print('hi') 42 | # subprocess.run(['cat', images_list.name], stderr=subprocess.STDOUT) 43 | # print('hi') 44 | # print([f'{x}\n' for x in img_files][:5]) 45 | # print('hi') 46 | # print(img_files) 47 | # print('hi') 48 | 49 | output = NamedTemporaryFile('w', suffix='.mat') 50 | command = [ 51 | 'matlab', '-r', 52 | f"runTrackerTao('{images_list.name}', {region}, '{output.name}'); " 53 | f"quit" 54 | ] 55 | # Conda is clashing with MATLAB here, causing an error in C++ ABIs. 56 | # Unsetting LD_LIBRARY_PATH fixes this. 
57 | env = os.environ.copy() 58 | env['LD_LIBRARY_PATH'] = '' 59 | try: 60 | subprocess.check_output(command, 61 | stderr=subprocess.STDOUT, 62 | cwd=str(STAPLE_ROOT), 63 | env=env) 64 | except subprocess.CalledProcessError as e: 65 | logging.fatal('Failed command.\nException: %s\nOutput %s', 66 | e.returncode, e.output.decode('utf-8')) 67 | raise 68 | 69 | result = loadmat(output.name)['results'].squeeze() 70 | images_list.close() 71 | output.close() 72 | 73 | boxes = result['res'].item() 74 | # width, height -> x1, y1 75 | boxes[:, 2] += boxes[:, 0] 76 | boxes[:, 3] += boxes[:, 1] 77 | scores = result['scores'].item() 78 | boxes = np.hstack((boxes, scores)) 79 | return boxes, None, None 80 | -------------------------------------------------------------------------------- /tao/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TAO-Dataset/tao/a3a713c51e2fdeb3a106c34b06d889ea581150a7/tao/utils/__init__.py -------------------------------------------------------------------------------- /tao/utils/colormap.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | ############################################################################## 15 | 16 | """An awesome colormap for really neat visualizations.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | from __future__ import unicode_literals 22 | 23 | import numpy as np 24 | 25 | 26 | def colormap(rgb=False, as_int=False): 27 | color_list = np.array( 28 | [ 29 | 0.000, 0.447, 0.741, 30 | 0.850, 0.325, 0.098, 31 | 0.929, 0.694, 0.125, 32 | 0.494, 0.184, 0.556, 33 | 0.466, 0.674, 0.188, 34 | 0.301, 0.745, 0.933, 35 | 0.635, 0.078, 0.184, 36 | 0.300, 0.300, 0.300, 37 | 0.600, 0.600, 0.600, 38 | 1.000, 0.000, 0.000, 39 | 1.000, 0.500, 0.000, 40 | 0.749, 0.749, 0.000, 41 | 0.000, 1.000, 0.000, 42 | 0.000, 0.000, 1.000, 43 | 0.667, 0.000, 1.000, 44 | 0.333, 0.333, 0.000, 45 | 0.333, 0.667, 0.000, 46 | 0.333, 1.000, 0.000, 47 | 0.667, 0.333, 0.000, 48 | 0.667, 0.667, 0.000, 49 | 0.667, 1.000, 0.000, 50 | 1.000, 0.333, 0.000, 51 | 1.000, 0.667, 0.000, 52 | 1.000, 1.000, 0.000, 53 | 0.000, 0.333, 0.500, 54 | 0.000, 0.667, 0.500, 55 | 0.000, 1.000, 0.500, 56 | 0.333, 0.000, 0.500, 57 | 0.333, 0.333, 0.500, 58 | 0.333, 0.667, 0.500, 59 | 0.333, 1.000, 0.500, 60 | 0.667, 0.000, 0.500, 61 | 0.667, 0.333, 0.500, 62 | 0.667, 0.667, 0.500, 63 | 0.667, 1.000, 0.500, 64 | 1.000, 0.000, 0.500, 65 | 1.000, 0.333, 0.500, 66 | 1.000, 0.667, 0.500, 67 | 1.000, 1.000, 0.500, 68 | 0.000, 0.333, 1.000, 69 | 0.000, 0.667, 1.000, 70 | 0.000, 1.000, 1.000, 71 | 0.333, 0.000, 1.000, 72 | 0.333, 0.333, 1.000, 73 | 0.333, 0.667, 1.000, 74 | 0.333, 1.000, 1.000, 75 | 0.667, 0.000, 1.000, 76 | 0.667, 0.333, 1.000, 77 | 0.667, 0.667, 1.000, 78 | 0.667, 1.000, 1.000, 79 | 1.000, 0.000, 1.000, 80 | 1.000, 0.333, 1.000, 81 | 1.000, 0.667, 1.000, 82 | 0.167, 0.000, 0.000, 83 | 0.333, 0.000, 0.000, 84 | 0.500, 0.000, 0.000, 85 | 0.667, 0.000, 0.000, 86 | 0.833, 0.000, 0.000, 87 | 1.000, 0.000, 0.000, 88 | 0.000, 0.167, 0.000, 89 | 0.000, 0.333, 0.000, 90 | 0.000, 0.500, 0.000, 91 | 0.000, 0.667, 0.000, 92 | 0.000, 0.833, 0.000, 93 | 0.000, 1.000, 0.000, 94 | 0.000, 0.000, 0.167, 95 | 0.000, 0.000, 0.333, 96 | 0.000, 0.000, 0.500, 97 | 0.000, 0.000, 0.667, 98 | 0.000, 0.000, 0.833, 99 | 0.000, 0.000, 1.000, 100 | 0.000, 0.000, 0.000, 101 | 0.143, 0.143, 0.143, 102 | 0.286, 0.286, 0.286, 103 | 0.429, 0.429, 0.429, 104 | 0.571, 0.571, 0.571, 105 | 0.714, 0.714, 0.714, 106 | 0.857, 0.857, 0.857, 107 | 1.000, 1.000, 1.000 108 | ] 109 | ).astype(np.float32) 110 | color_list = color_list.reshape((-1, 3)) * 255 111 | if not rgb: 112 | color_list = color_list[:, ::-1] 113 | if as_int: 114 | color_list = color_list.astype(np.uint8) 115 | return color_list 116 | -------------------------------------------------------------------------------- /tao/utils/cv2_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Module for cv2 utility functions and maintaining version compatibility 3 | between 3.x and 4.x 4 | """ 5 | import cv2 6 | 7 | 8 | def findContours(*args, **kwargs): 9 | """ 10 | Wraps cv2.findContours to maintain compatiblity between versions 11 | 3 and 4 12 | 13 | Returns: 14 | contours, hierarchy 15 | """ 16 | if cv2.__version__.startswith('4'): 17 | contours, hierarchy = cv2.findContours(*args, **kwargs) 18 | elif cv2.__version__.startswith('3'): 19 | _, contours, hierarchy = cv2.findContours(*args, **kwargs) 20 | else: 21 | raise AssertionError( 22 | 'cv2 must be either version 3 or 4 to call this 
method')
23 | 
24 |     return contours, hierarchy
25 | -------------------------------------------------------------------------------- /tao/utils/detectron2/datasets.py: --------------------------------------------------------------------------------
1 | from detectron2.data.datasets import register_coco_instances
2 | 
3 | 
4 | def register_datasets():
5 |     register_coco_instances(
6 |         "lvis_v0.5_coco_2017_train", {},
7 |         "data/detectron_datasets/lvis-coco/lvis-0.5_coco2017_train.json",
8 |         "data/detectron_datasets/lvis-coco/train2017")
9 | -------------------------------------------------------------------------------- /tao/utils/download.py: --------------------------------------------------------------------------------
1 | import logging
2 | from hashlib import md5
3 | from multiprocessing import Pool
4 | from pathlib import Path
5 | 
6 | from tqdm import tqdm
7 | 
8 | from tao.utils.video import dump_frames
9 | 
10 | 
11 | def dump_frames_star(task):
12 |     return dump_frames(*task)
13 | 
14 | 
15 | def dump_tao_frames(videos,
16 |                     output_dirs,
17 |                     workers,
18 |                     tqdm_desc='Converting to frames'):
19 |     fps = None
20 |     extension = '.jpg'
21 |     jpeg_qscale = 2
22 | 
23 |     for output_dir in output_dirs:
24 |         Path(output_dir).mkdir(exist_ok=True, parents=True)
25 | 
26 |     dump_frames_tasks = []
27 |     for video_path, output_dir in zip(videos, output_dirs):
28 |         dump_frames_tasks.append(
29 |             (video_path, output_dir, fps, extension, jpeg_qscale))
30 | 
31 |     # dump_frames code logs when, e.g., the expected number of frames does not
32 |     # match the number of dumped frames. But these logs can have false
33 |     # positives that are confusing, so we check that frames are correctly
34 |     # dumped ourselves separately based on frames in TAO annotations.
35 |     _log_level = logging.root.level
36 |     logging.root.setLevel(logging.ERROR)
37 |     if workers > 1:
38 |         pool = Pool(workers)
39 |         try:
40 |             list(
41 |                 tqdm(pool.imap_unordered(dump_frames_star, dump_frames_tasks),
42 |                      total=len(dump_frames_tasks),
43 |                      leave=False,
44 |                      desc=tqdm_desc))
45 |         except KeyboardInterrupt:
46 |             print('Parent received control-c, exiting.')
47 |             pool.terminate()
48 |     else:
49 |         for task in tqdm(dump_frames_tasks):
50 |             dump_frames_star(task)
51 |     logging.root.setLevel(_log_level)
52 | 
53 | 
54 | def frame_checksums_diff(frames_dir, checksums, early_exit=False):
55 |     missing = []
56 |     mismatch = []
57 | 
58 |     checksums = {k.replace('.jpeg', '.jpg'): v for k, v in checksums.items()}
59 |     extra = [x for x in frames_dir.rglob('*.jpg') if x.name not in checksums]
60 | 
61 |     for frame, cksum in checksums.items():
62 |         path = frames_dir / frame
63 |         if not path.exists():
64 |             missing.append(path)
65 |             if early_exit:
66 |                 break
67 |         elif cksum:
68 |             with open(path, 'rb') as f:
69 |                 md5_digest = md5(f.read()).hexdigest()
70 |             if md5_digest != cksum:
71 |                 # path, seen, expected
72 |                 mismatch.append((path, md5_digest, cksum))
73 |                 if early_exit:
74 |                     break
75 |     return missing, mismatch, extra
76 | 
77 | 
78 | def are_tao_frames_dumped(frames_dir, checksums, warn=True, allow_extra=True):
79 |     missing, mismatch, extra = frame_checksums_diff(frames_dir,
80 |                                                     checksums,
81 |                                                     early_exit=True)
82 |     if allow_extra:
83 |         extra = []
84 |     if warn and extra:
85 |         logging.warning(f'Unexpected frame at {extra[0]}!')
86 |     if warn and missing:
87 |         logging.warning(f'Could not find frame at {missing[0]}!')
88 |     if warn and mismatch:
89 |         path, seen, expected = mismatch[0]
90 |         logging.warning(
91 |             f'Checksum for {path} did not match! '
92 |             f'Expected: {expected}, saw: {seen}')
93 |     return not mismatch and not missing and not extra
94 | 
95 | 
96 | def remove_non_tao_frames(frames_dir, keep_frames):
97 |     frames = {x.split('.')[0] for x in keep_frames}
98 |     extracted_frames = list(frames_dir.glob('*.jpg'))
99 |     to_remove = [x for x in extracted_frames if x.stem not in frames]
100 |     assert len(to_remove) != len(extracted_frames)
101 |     for frame in to_remove:
102 |         frame.unlink()
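
# Example check (a sketch; the video directory and checksum value below are
# illustrative):
#
#   from pathlib import Path
#   checksums = {'frame0001.jpg': '5d41402abc4b2a76b9719d911017c592'}
#   frames_dir = Path('frames/val/video1')
#   if not are_tao_frames_dumped(frames_dir, checksums):
#       missing, mismatch, extra = frame_checksums_diff(frames_dir, checksums)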
103 | -------------------------------------------------------------------------------- /tao/utils/evaluation_mota.py: --------------------------------------------------------------------------------
1 | import csv
2 | import logging
3 | import math
4 | from collections import defaultdict
5 | 
6 | import motmetrics as mm
7 | from tqdm import tqdm
8 | 
9 | MOTA_COUNT_FIELDS = {
10 |     'num_unique_objects', 'mostly_tracked', 'partially_tracked', 'mostly_lost',
11 |     'num_false_positives', 'num_misses', 'num_switches', 'num_fragmentations',
12 |     'num_transfer', 'num_ascend', 'num_migrate'
13 | }
14 | MOTA_PERCENT_FIELDS = {
15 |     'idf1', 'idp', 'idr', 'recall', 'precision', 'mota', 'motp'
16 | }
17 | 
18 | 
19 | def merge_values(mota_key, mota_values):
20 |     avg = sum(mota_values) / max(len(mota_values), 1e-9)
21 |     if mota_key in MOTA_PERCENT_FIELDS:
22 |         return (avg, -1)
23 |     elif mota_key in MOTA_COUNT_FIELDS:
24 |         return (avg, sum(mota_values))
25 |     else:
26 |         raise ValueError(f"Unknown key: {mota_key}")
27 | 
28 | 
29 | def merged_mota(mota_metrics):
30 |     output = {}
31 |     for key, values in mota_metrics.items():
32 |         values = [x for x in values if not math.isnan(x)]
33 |         if key in MOTA_PERCENT_FIELDS:
34 |             output[key] = sum(values) / max(len(values), 1e-9)
35 |         elif key in MOTA_COUNT_FIELDS:
36 |             output[key] = sum(values)
37 |         else:
38 |             raise ValueError(f"Unknown key: {key}")
39 |     return output
40 | 
41 | 
42 | def summarize_mota(group_videos, group_accumulators):
43 |     summaries = {}
44 |     for group, videos in tqdm(group_videos.items(), desc='Summarizing'):
45 |         if not videos:
46 |             continue
47 |         metrics_host = mm.metrics.create()
48 |         # MOTA code makes a bunch of unnecessary logs; disable them for now.
49 | old_level = logging.root.level 50 | logging.root.setLevel(logging.WARN) 51 | summaries[group] = metrics_host.compute_many( 52 | group_accumulators[group], 53 | metrics=mm.metrics.motchallenge_metrics, 54 | names=videos, 55 | generate_overall=True) 56 | logging.root.setLevel(old_level) 57 | return summaries 58 | 59 | 60 | def evaluate_mota(tao_eval, cfg, logger=logging.root): 61 | track_threshold = cfg.MOTA.TRACK_THRESHOLD 62 | tao = tao_eval.tao_gt 63 | results = tao_eval.tao_dt 64 | 65 | seen_categories = {x['category_id'] for x in tao.anns.values()} 66 | if not cfg.CATEGORIES: 67 | categories = [ 68 | x['id'] for x in tao.cats.values() if x['id'] in seen_categories 69 | ] 70 | else: 71 | categories = [ 72 | x['id'] for x in tao.cats.values() if x['synset'] in cfg.CATEGORIES 73 | ] 74 | # Map category to list of accumulators 75 | mota_accumulators = defaultdict(list) 76 | video_ids = sorted(tao.vids.keys()) 77 | valid_videos = defaultdict(list) 78 | for vid_id in tqdm(video_ids): 79 | video = tao.vids[vid_id] 80 | for category in categories: 81 | acc = mm.MOTAccumulator(auto_id=True) 82 | has_groundtruth = False 83 | has_predictions = False 84 | for image in tao.vid_img_map[video['id']]: 85 | groundtruth = [ 86 | x for x in tao.img_ann_map[image['id']] 87 | if x['category_id'] == category 88 | ] 89 | predictions = [ 90 | x for x in results.img_ann_map[image['id']] 91 | if x['category_id'] == category 92 | and float(x['score']) > track_threshold 93 | ] 94 | if not groundtruth and not predictions: 95 | continue 96 | if groundtruth: 97 | has_groundtruth = True 98 | if predictions: 99 | has_predictions = True 100 | # IoU is 1 - IoU here. MOT threshold here is IoU 0.5 101 | distances = mm.distances.iou_matrix( 102 | [x['bbox'] for x in groundtruth], 103 | [x['bbox'] for x in predictions], 104 | max_iou=0.5) 105 | acc.update([x['track_id'] for x in groundtruth], 106 | [x['track_id'] for x in predictions], distances) 107 | if not has_groundtruth: 108 | # MOTA is not defined for sequences without a groundtruth. 109 | if not cfg.MOTA.INCLUDE_NEGATIVE_VIDEOS: 110 | continue 111 | elif not (has_predictions 112 | and category in video['neg_category_ids']): 113 | continue 114 | if category in video['not_exhaustive_category_ids']: 115 | # Remove false positives. 116 | if isinstance(acc._indices, list): # motmetrics <=v1.1.3 117 | inds = [ 118 | i for i, event in enumerate(acc._events) 119 | if event[0] != 'FP' 120 | ] 121 | acc._indices = [acc._indices[i] for i in inds] 122 | acc._events = [acc._events[i] for i in inds] 123 | elif isinstance(acc._indices, dict): # motmetrics v1.2.0 124 | inds = [ 125 | i for i, event in enumerate(acc._events['Type']) 126 | if event != 'FP' 127 | ] 128 | acc._indices = { 129 | k: [v[i] for i in inds] 130 | for k, v in acc._indices.items() 131 | } 132 | acc._events = { 133 | k: [v[i] for i in inds] 134 | for k, v in acc._events.items() 135 | } 136 | else: 137 | raise ValueError( 138 | "Unknown _indices format in motmetrics. 
Please file "
139 |                         "an issue on the TAO repository, with your "
140 |                         "motmetrics version.")
141 |                 acc.cached_events_df = (
142 |                     mm.MOTAccumulator.new_event_dataframe_with_data(
143 |                         acc._indices, acc._events))
144 |             valid_videos[category].append(video['name'])
145 |             mota_accumulators[category].append(acc)
146 | 
147 |     summaries = []
148 |     raw_summaries = {}
149 |     headers = None
150 |     category_summaries = summarize_mota(valid_videos, mota_accumulators)
151 |     for category, summary in category_summaries.items():
152 |         if headers is None:
153 |             headers = summary.columns.values.tolist()
154 |         summaries.append([tao.cats[category]['synset']] +
155 |                          summary.loc['OVERALL'].values.tolist())
156 |         raw_summaries[category] = summary
157 | 
158 |     merged = merged_mota({
159 |         key: [x[i+1] for x in summaries]
160 |         for i, key in enumerate(headers)
161 |     })
162 | 
163 |     videos_by_dataset = defaultdict(list)
164 |     for video in tao.vids.values():
165 |         videos_by_dataset[video['metadata']['dataset']].append(video)
166 | 
167 |     if cfg.MOTA.EVAL_BY_DATASET:
168 |         dataset_overall = {}
169 |         for dataset, videos in tqdm(videos_by_dataset.items(),
170 |                                     desc='Summarizing by dataset'):
171 |             video_names = {v['name'] for v in videos}
172 |             dataset_videos = defaultdict(list)
173 |             dataset_accums = defaultdict(list)
174 |             for c in valid_videos:
175 |                 for v, accum in zip(valid_videos[c], mota_accumulators[c]):
176 |                     if v in video_names:
177 |                         dataset_videos[c].append(v)
178 |                         dataset_accums[c].append(accum)
179 |             dataset_summaries = summarize_mota(dataset_videos, dataset_accums)
180 |             mota_metrics_raw = {
181 |                 key: [x.loc['OVERALL'].values[i]
182 |                       for x in dataset_summaries.values()]
183 |                 for i, key in enumerate(headers)
184 |             }
185 |             dataset_overall[dataset] = merged_mota(mota_metrics_raw)
186 |     else:
187 |         dataset_overall = {}
188 | 
189 |     raw_summaries = {
190 |         tao.cats[c]['synset']: v
191 |         for c, v in raw_summaries.items()
192 |     }
193 |     metrics_headers = []
194 |     for x in headers:
195 |         metrics_headers.append(x)
196 |     return {
197 |         'summary_headers': headers,
198 |         'summaries': summaries,
199 |         'mota_headers': metrics_headers,
200 |         'overall': merged,
201 |         'track_threshold': track_threshold,
202 |         'raw_summaries': raw_summaries,
203 |         'dataset_overall': dataset_overall
204 |     }
205 | 
206 | 
207 | def log_mota(eval_info, logger=logging.root, output_dir=None, log_path=None):
208 |     track_threshold = eval_info['mota_eval']['track_threshold']
209 |     headers = eval_info['mota_eval']['summary_headers']
210 |     mota_headers = eval_info['mota_eval']['mota_headers']
211 |     summaries = eval_info['mota_eval']['summaries']
212 |     dataset_overall = eval_info['mota_eval']['dataset_overall']
213 |     # Overall metrics
214 |     overall = eval_info['mota_eval']['overall']
215 | 
216 |     if output_dir:
217 |         category_headers = ['category'] + mota_headers
218 |         with open(output_dir / 'summary.csv', 'w') as f:
219 |             writer = csv.DictWriter(f, fieldnames=category_headers, restval=-1)
220 |             writer.writeheader()
221 |             for summary in summaries:
222 |                 writer.writerow(dict(zip(category_headers, summary)))
223 |             overall_row = {'category': 'OVERALL'}
224 |             overall_row.update(overall)
225 |             writer.writerow(overall_row)
226 | 
227 |         if dataset_overall:
228 |             dataset_headers = ['dataset'] + mota_headers
229 |             with open(output_dir / 'dataset_summaries.csv', 'w') as f:
230 |                 writer = csv.DictWriter(f,
231 |                                         fieldnames=dataset_headers,
232 |                                         restval=-1)
233 |                 writer.writeheader()
234 |                 for dataset, overall in dataset_overall.items():
235 |                     overall_row =
{'dataset': dataset} 236 | overall_row.update(overall) 237 | writer.writerow(overall_row) 238 | 239 | logger.info('Overall MOTA: %s', overall['mota']) 240 | first_keys = ['mota', 'idf1'] 241 | ordered_keys = first_keys + [ 242 | x for x in mota_headers[1:] if x not in first_keys 243 | ] 244 | log_keys = ['threshold'] + ordered_keys 245 | str_values = [str(track_threshold)] 246 | for k in ordered_keys: 247 | v = overall[k] 248 | if k in MOTA_PERCENT_FIELDS: 249 | v_str = f'{100*v:.2f}' 250 | elif k in MOTA_COUNT_FIELDS: 251 | v_str = str(int(v)) 252 | str_values.append(v_str) 253 | if output_dir: 254 | log_keys += ['path'] 255 | str_values += [log_path if log_path else output_dir] 256 | logger.info('Copy paste:\n%s\n%s', ','.join(log_keys), 257 | ','.join(str_values)) 258 | -------------------------------------------------------------------------------- /tao/utils/fs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm'] 5 | VIDEO_EXTENSIONS = ['.mp4', '.avi', '.mkv', '.mov'] 6 | 7 | 8 | def dir_path(path): 9 | """Wrapper around Path that ensures this directory is created.""" 10 | if not isinstance(path, Path): 11 | path = Path(path) 12 | path.mkdir(exist_ok=True, parents=True) 13 | return path 14 | 15 | 16 | def file_path(path): 17 | """Wrapper around Path that ensures parent directories are created. 18 | 19 | x = mkdir_parents(dir / video_with_dir_prefix) 20 | is short-hand for 21 | x = Path(dir / video_with_dir_prefix) 22 | x.parent.mkdir(exist_ok=True, parents=True) 23 | """ 24 | if not isinstance(path, Path): 25 | path = Path(path) 26 | path.resolve().parent.mkdir(exist_ok=True, parents=True) 27 | return path 28 | 29 | 30 | def glob_ext(path, extensions, recursive=False): 31 | if not isinstance(path, Path): 32 | path = Path(path) 33 | if recursive: 34 | # Handle one level of symlinks. 35 | path_children = list(path.glob('*')) 36 | all_files = list(path_children) 37 | for x in path_children: 38 | if x.is_dir(): 39 | all_files += x.rglob('*') 40 | else: 41 | all_files = path.glob('*') 42 | return [ 43 | x for x in all_files if any(x.name.endswith(y) for y in extensions) 44 | ] 45 | 46 | 47 | def find_file_extensions(folder, stem, possible_extensions): 48 | if not isinstance(folder, Path): 49 | folder = Path(folder) 50 | for ext in possible_extensions: 51 | if ext[0] != '.': 52 | ext = f'.{ext}' 53 | path = folder / f'{stem}{ext}' 54 | if path.exists(): 55 | return path 56 | return None 57 | 58 | 59 | def is_image_file(filename): 60 | """Checks if a file is an image. 61 | 62 | Args: 63 | filename (string): path to a file 64 | Returns: 65 | bool: True if the filename ends with a known image extension 66 | """ 67 | filename_lower = filename.lower() 68 | return any(filename_lower.endswith(ext) for ext in IMG_EXTENSIONS) 69 | 70 | 71 | def simple_table(rows): 72 | lengths = [ 73 | max(len(row[i]) for row in rows) + 1 for i in range(len(rows[0])) 74 | ] 75 | row_format = ' '.join(('{:<%s}' % length) for length in lengths[:-1]) 76 | row_format += ' {}' # The last column can maintain its length. 77 | 78 | output = '' 79 | for i, row in enumerate(rows): 80 | if i > 0: 81 | output += '\n' 82 | output += row_format.format(*row) 83 | return output 84 | 85 | 86 | def parse_bool(arg): 87 | """Parse string to boolean. 88 | Using type=bool in argparse does not do the right thing. E.g. 89 | '--bool_flag False' will parse as True. 
See 90 | 91 | 92 | Usage: 93 | parser.add_argument( '--choice', type=parse_bool) 94 | """ 95 | if arg == 'True': 96 | return True 97 | elif arg == 'False': 98 | return False 99 | else: 100 | raise argparse.ArgumentTypeError("Expected 'True' or 'False'.") 101 | -------------------------------------------------------------------------------- /tao/utils/misc.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import pickle 4 | 5 | from pathlib import Path 6 | from scipy.io import loadmat 7 | from tqdm import tqdm 8 | 9 | from tao.utils import misc 10 | 11 | 12 | def parse_bool(arg): 13 | """Parse string to boolean. 14 | 15 | Using type=bool in argparse does not do the right thing. E.g. 16 | '--bool_flag False' will parse as True. See 17 | 18 | """ 19 | if arg == 'True': 20 | return True 21 | elif arg == 'False': 22 | return False 23 | else: 24 | raise argparse.ArgumentTypeError("Expected 'True' or 'False'.") 25 | 26 | 27 | def load_detection_mat(mat): 28 | dictionary = {} 29 | f = loadmat(mat)['x'] 30 | result = {} 31 | # Assume mat files are of the format (x0, y0, x1, y1, label, score) 32 | if f.shape[1] == 6: 33 | result['pred_boxes'] = [[x[0], x[1], x[2], x[3]] for x in f[:, :4]] 34 | result['scores'] = [x for x in f[:, 5]] 35 | result['pred_classes'] = [x for x in f[:, 4]] 36 | elif f.shape[1] > 6: 37 | # Assume mat files are of the format 38 | # (x0, y0, x1, y1, label1_score, label2_score, ..., labeln_score) 39 | result['pred_boxes'] = [[x[0], x[1], x[2], x[3]] for x in f[:, :4]] 40 | result['scores'] = [] 41 | result['pred_classes'] = [] 42 | for box in f: 43 | label = box[4:].argmax() 44 | result['pred_classes'].append(label) 45 | result['scores'].append(box[label+4]) 46 | dictionary['instances'] = result 47 | return dictionary 48 | 49 | 50 | def load_detection_dir_as_results(root, 51 | annotations, 52 | detections_format='pickle', 53 | include_masks=False, 54 | score_threshold=None, 55 | max_dets_per_image=None, 56 | show_progress=False): 57 | """Load detections from dir as a results.json dict.""" 58 | if not isinstance(root, Path): 59 | root = Path(root) 60 | ext = { 61 | 'pickle': '.pickle', 62 | 'pkl': '.pkl', 63 | 'mat': '.mat' 64 | }[detections_format] 65 | bbox_annotations = [] 66 | if include_masks: 67 | segmentation_annotations = [] 68 | 69 | for image in tqdm(annotations['images'], 70 | desc='Collecting annotations', 71 | disable=not show_progress): 72 | path = (root / f'{image["file_name"]}').with_suffix(ext) 73 | if not path.exists(): 74 | logging.warn(f'Could not find detections for image ' 75 | f'{image["file_name"]} at {path}; skipping...') 76 | continue 77 | if detections_format in ('pickle', 'pkl'): 78 | with open(path, 'rb') as f: 79 | detections = pickle.load(f) 80 | else: 81 | detections = misc.load_detection_mat(path) 82 | 83 | num_detections = len(detections['instances']['scores']) 84 | indices = sorted(range(num_detections), 85 | key=lambda i: detections['instances']['scores'][i], 86 | reverse=True) 87 | 88 | if max_dets_per_image is not None: 89 | indices = indices[:max_dets_per_image] 90 | 91 | for idx in indices: 92 | entry = detections['instances']['pred_boxes'][idx] 93 | x1 = entry[0] 94 | y1 = entry[1] 95 | x2 = entry[2] 96 | y2 = entry[3] 97 | bbox = [int(x1), int(y1), int(x2-x1), int(y2-y1)] 98 | 99 | category = int(detections['instances']['pred_classes'][idx] + 1) 100 | score = detections['instances']['scores'][idx] 101 | if score_threshold is not None and score < 
score_threshold: 102 | continue 103 | 104 | try: 105 | score = score.item() 106 | except AttributeError: 107 | pass 108 | 109 | bbox_annotations.append({ 110 | 'image_id': image['id'], 111 | 'category_id': category, 112 | 'bbox': bbox, 113 | 'score': score, 114 | }) 115 | if include_masks: 116 | segmentation_annotations.append({ 117 | 'image_id': image['id'], 118 | 'category_id': category, 119 | 'segmentation': detections['instances']['pred_masks'][idx], 120 | 'score': score 121 | }) 122 | if include_masks: 123 | return bbox_annotations, segmentation_annotations 124 | else: 125 | return bbox_annotations 126 | -------------------------------------------------------------------------------- /tao/utils/parallel/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TAO-Dataset/tao/a3a713c51e2fdeb3a106c34b06d889ea581150a7/tao/utils/parallel/__init__.py -------------------------------------------------------------------------------- /tao/utils/parallel/fixed_gpu_pool.py: -------------------------------------------------------------------------------- 1 | import multiprocessing as mp 2 | from tao.utils.parallel.pool_context import PoolWithContext 3 | 4 | 5 | class FixedGpuPool: 6 | """Pool where each process is attached to a specific GPU. 7 | 8 | Usage: 9 | def init(args, context): 10 | context['init_return'] = 'init' 11 | def run(args, context): 12 | return (context['gpu'], context['init_return'], args) 13 | p = FixedGpuPool([0, 1, 2, 3], init, None) 14 | print(p.map(run, ['task1', 'task2', 'task3'])) 15 | # [(0, 'init', 'task1'), (1, 'init', 'task2'), (2, 'init', 'task3')] 16 | # NOTE: GPUs may be in different order 17 | """ 18 | 19 | def __init__(self, gpus, initializer=None, initargs=None): 20 | gpu_queue = mp.Manager().Queue() 21 | for gpu in gpus: 22 | gpu_queue.put(gpu) 23 | self.pool = PoolWithContext( 24 | len(gpus), _FixedGpuPool_init, (gpu_queue, initializer, initargs)) 25 | 26 | def map(self, task_fn, tasks): 27 | return self.pool.map(_FixedGpuPool_run, 28 | ((task_fn, task) for task in tasks)) 29 | 30 | def imap_unordered(self, task_fn, tasks): 31 | return self.pool.imap_unordered(_FixedGpuPool_run, 32 | ((task_fn, task) for task in tasks)) 33 | 34 | def close(self): 35 | self.pool.close() 36 | 37 | 38 | def _FixedGpuPool_init(args, context): 39 | gpu_queue, initializer, initargs = args 40 | context['gpu'] = gpu_queue.get() 41 | initializer(initargs, context=context) 42 | 43 | 44 | def _FixedGpuPool_run(args, context): 45 | task_fn, task_args = args 46 | return task_fn(task_args, context=context) 47 | 48 | 49 | if __name__ == "__main__": 50 | def _test_gpu_init(args, context): 51 | context['init_return'] = 'init' 52 | 53 | def _test_gpu_run(args, context): 54 | return (context['gpu'], context['init_return'], args) 55 | 56 | p = FixedGpuPool([0, 1, 2, 3], _test_gpu_init, 'init arg') 57 | print(p.map(_test_gpu_run, ['task1', 'task2', 'task3'])) 58 | -------------------------------------------------------------------------------- /tao/utils/parallel/pool_context.py: -------------------------------------------------------------------------------- 1 | import multiprocessing as mp 2 | from collections.abc import Iterable 3 | 4 | 5 | _PoolWithContext_context = None 6 | 7 | 8 | def _PoolWithContext_init(initializer, init_args): 9 | global _PoolWithContext_context 10 | _PoolWithContext_context = {} 11 | if init_args is None: 12 | initializer(context=_PoolWithContext_context) 13 | else: 14 | initializer(init_args, 
context=_PoolWithContext_context) 15 | 16 | 17 | def _PoolWithContext_run(args): 18 | task_fn, task_args = args 19 | return task_fn(task_args, context=_PoolWithContext_context) 20 | 21 | 22 | class PoolWithContext: 23 | """Like multiprocessing.Pool, but pass a per-process context dict, populated by the initializer, to the map fn. 24 | 25 | Usage: 26 | def init(context): 27 | context['init_return'] = 'init' 28 | def run(args, context): 29 | return (context['init_return'], args) 30 | p = PoolWithContext(4, init) 31 | print(p.map(run, ['task1', 'task2', 'task3'])) 32 | # [('init', 'task1'), ('init', 'task2'), ('init', 'task3')] 33 | # NOTE: with imap_unordered, results may come back in any order 34 | """ 35 | def __init__(self, num_workers, initializer, initargs=None): 36 | self.pool = mp.Pool( 37 | num_workers, 38 | initializer=_PoolWithContext_init, 39 | initargs=(initializer, initargs)) 40 | 41 | def map(self, task_fn, tasks): 42 | return self.pool.map(_PoolWithContext_run, 43 | ((task_fn, task) for task in tasks)) 44 | 45 | def close(self): 46 | self.pool.close() 47 | 48 | def imap_unordered(self, task_fn, tasks): 49 | return self.pool.imap_unordered(_PoolWithContext_run, 50 | ((task_fn, task) for task in tasks)) 51 | 52 | 53 | if __name__ == "__main__": 54 | def _test_init(context): 55 | context['init_return'] = 'hi' 56 | 57 | def _test_init_2(context): 58 | context['hello'] = 2 59 | 60 | def _test_run(args, context): 61 | return (args, context['init_return']) 62 | 63 | def _test_run_2(args, context): 64 | return (args, context) 65 | 66 | p = PoolWithContext(4, _test_init) 67 | p2 = PoolWithContext(4, _test_init_2) 68 | print(p.map(_test_run, ['task1', 'task2', 'task3'])) 69 | print(p2.map(_test_run_2, ['task1', 'task2', 'task3'])) 70 | -------------------------------------------------------------------------------- /tao/utils/video.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import json 3 | import logging 4 | import os 5 | import subprocess 6 | from contextlib import contextmanager 7 | 8 | from pathlib import Path 9 | 10 | from moviepy.video.io.ffmpeg_reader import ffmpeg_parse_infos 11 | from moviepy.video.io.ffmpeg_writer import FFMPEG_VideoWriter 12 | from moviepy.tools import extensions_dict 13 | 14 | 15 | @contextmanager 16 | def video_writer(output, 17 | size, 18 | fps=30, 19 | codec=None, 20 | ffmpeg_params=None): 21 | """ 22 | Args: 23 | size (tuple): (width, height) tuple 24 | """ 25 | if isinstance(output, Path): 26 | output = str(output) 27 | 28 | if codec is None: 29 | extension = Path(output).suffix[1:] 30 | try: 31 | codec = extensions_dict[extension]['codec'][0] 32 | except KeyError: 33 | raise ValueError(f"Couldn't find the codec associated with the " 34 | f"filename ({output}). 
Please specify codec") 35 | 36 | if ffmpeg_params is None: 37 | ffmpeg_params = [ 38 | '-vf', "scale=trunc(iw/2)*2:trunc(ih/2)*2", '-pix_fmt', 'yuv420p' 39 | ] 40 | with FFMPEG_VideoWriter(output, 41 | size=size, 42 | fps=fps, 43 | codec=codec, 44 | ffmpeg_params=ffmpeg_params) as writer: 45 | yield writer 46 | 47 | 48 | def video_info(video): 49 | from moviepy.video.io.ffmpeg_reader import ffmpeg_parse_infos 50 | if isinstance(video, Path): 51 | video = str(video) 52 | 53 | info = ffmpeg_parse_infos(video) 54 | return { 55 | 'duration': info['duration'], 56 | 'fps': info['video_fps'], 57 | 'size': info['video_size'] # (width, height) 58 | } 59 | 60 | 61 | def are_frames_dumped(video_path, 62 | output_dir, 63 | expected_fps, 64 | expected_info_path, 65 | expected_name_format, 66 | log_reason=False): 67 | """Check if the output directory exists and has already been processed. 68 | 69 | 1) Check the info.json file to see if the parameters match. 70 | 2) Ensure that all the frames exist. 71 | 72 | Params: 73 | video_path (str) 74 | output_dir (str) 75 | expected_fps (num) 76 | expected_info_path (str) 77 | expected_name_format (str) 78 | """ 79 | # Ensure that info file exists. 80 | if not os.path.isfile(expected_info_path): 81 | if log_reason: 82 | logging.info("Info path doesn't exist at %s" % expected_info_path) 83 | return False 84 | 85 | # Ensure that info file is valid. 86 | with open(expected_info_path, 'r') as info_file: 87 | info = json.load(info_file) 88 | info_valid = info['frames_per_second'] == expected_fps \ 89 | and info['input_video_path'] == os.path.abspath(video_path) 90 | if not info_valid: 91 | if log_reason: 92 | logging.info("Info file (%s) is invalid" % expected_info_path) 93 | return False 94 | 95 | # Check that all frame paths exist. 96 | offset_if_one_indexed = 0 97 | if not os.path.exists(expected_name_format % 0): 98 | # If the 0th frame doesn't exist, either we haven't dumped the frames, 99 | # or the frames start with index 1 (this changed between versions of 100 | # moviepy, so we have to explicitly check). We can assume they start 101 | # with index 1, and continue. 102 | offset_if_one_indexed = 1 103 | 104 | # https://stackoverflow.com/a/28376817/1291812 105 | num_frames_cmd = [ 106 | 'ffprobe', '-v', 'error', '-select_streams', 'v:0', '-show_entries', 107 | 'stream=nb_frames', '-of', 'default=nokey=1:noprint_wrappers=1', 108 | video_path 109 | ] 110 | expected_num_frames = subprocess.check_output(num_frames_cmd, 111 | stderr=subprocess.STDOUT) 112 | expected_num_frames = int(expected_num_frames.decode().strip()) 113 | expected_frame_paths = [ 114 | expected_name_format % (i + offset_if_one_indexed) 115 | for i in range(expected_num_frames) 116 | ] 117 | missing_frames = [x for x in expected_frame_paths if not os.path.exists(x)] 118 | if missing_frames: 119 | if log_reason: 120 | logging.info("Missing frames:\n%s" % ('\n'.join(missing_frames))) 121 | return False 122 | 123 | # All checks passed 124 | return True 125 | 126 | 127 | def dump_frames(video_path, 128 | output_dir, 129 | fps, 130 | extension='.jpg', 131 | jpeg_qscale=2): 132 | """Dump frames at frames_per_second from a video to output_dir. 
133 | 134 | If frames_per_second is None, the clip's fps attribute is used instead.""" 135 | output_dir.mkdir(exist_ok=True, parents=True) 136 | 137 | if extension[0] != '.': 138 | extension = f'.{extension}' 139 | 140 | try: 141 | video_info = ffmpeg_parse_infos(str(video_path)) 142 | video_fps = video_info['video_fps'] 143 | except OSError: 144 | logging.exception('Unable to open video (%s), skipping.' % video_path) 145 | raise 146 | except KeyError: 147 | logging.error('Unable to extract metadata about video (%s), skipping.' 148 | % video_path) 149 | logging.exception('Exception:') 150 | return 151 | info_path = '{}/info.json'.format(output_dir) 152 | name_format = '{}/frame%04d{}'.format(output_dir, extension) 153 | 154 | if fps is None or fps == 0: 155 | fps = video_fps # Extract all frames 156 | 157 | are_frames_dumped_wrapper = functools.partial( 158 | are_frames_dumped, 159 | video_path=video_path, 160 | output_dir=output_dir, 161 | expected_fps=fps, 162 | expected_info_path=info_path, 163 | expected_name_format=name_format) 164 | 165 | if extension.lower() in ('.jpg', '.jpeg'): 166 | qscale = ['-qscale:v', str(jpeg_qscale)] 167 | else: 168 | qscale = [] 169 | 170 | if are_frames_dumped_wrapper(log_reason=False): 171 | return 172 | 173 | successfully_wrote_images = False 174 | try: 175 | if fps == video_fps: 176 | cmd = ['ffmpeg', '-i', str(video_path)] + qscale + [name_format] 177 | else: 178 | cmd = ['ffmpeg', '-i', str(video_path) 179 | ] + qscale + ['-vf', 'fps={}'.format(fps), name_format] 180 | subprocess.check_output(cmd, stderr=subprocess.STDOUT) 181 | successfully_wrote_images = True 182 | except subprocess.CalledProcessError as e: 183 | logging.exception("Failed to dump images for %s", video_path) 184 | logging.error(e.output.decode('utf-8')) 185 | raise 186 | 187 | if successfully_wrote_images: 188 | info = {'frames_per_second': fps, 189 | 'input_video_path': os.path.abspath(video_path)} 190 | with open(info_path, 'w') as info_file: 191 | json.dump(info, info_file) 192 | 193 | if not are_frames_dumped_wrapper(log_reason=True): 194 | logging.warning( 195 | "Images for {} don't seem to be dumped properly!".format( 196 | video_path)) 197 | -------------------------------------------------------------------------------- /tao/utils/vis.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | 3 | import numpy as np 4 | import pycocotools.mask as mask_util 5 | 6 | from tao.utils import cv2_util 7 | from tao.utils.colormap import colormap 8 | 9 | 10 | _BLACK = (0, 0, 0) 11 | _GRAY = (218, 227, 218) 12 | _GREEN = (18, 127, 15) 13 | _WHITE = (255, 255, 255) 14 | 15 | _COLOR1 = tuple(255*x for x in (0.000, 0.447, 0.741)) 16 | 17 | 18 | def rect_with_opacity(image, top_left, bottom_right, fill_color, fill_opacity): 19 | with_fill = image.copy() 20 | with_fill = cv2.rectangle(with_fill, top_left, bottom_right, fill_color, 21 | cv2.FILLED) 22 | return cv2.addWeighted(with_fill, fill_opacity, image, 1 - fill_opacity, 0, 23 | image) 24 | 25 | 26 | def get_annotation_colors(annotations): 27 | # Sort boxes by area, this will ensure, e.g., that the largest box 28 | # has the same color in all frames of a video. 
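    # A minimal sketch of the intent (hypothetical annotations; a COCO
    # 'bbox' is [x, y, w, h], so area = w * h):
    #   get_annotation_colors([{'bbox': [0, 0, 5, 5]},      # small box
    #                          {'bbox': [0, 0, 50, 50]}])   # large box
    # Boxes are ranked by area before colors are assigned, so the largest
    # box keeps a stable color even if annotation order changes per frame.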
29 | areas = [x['bbox'][2] * x['bbox'][3] for x in annotations] 30 | box_order = sorted(range(len(areas)), key=lambda i: areas[i]) 31 | colors = colormap(rgb=True)[:len(annotations)].tolist() 32 | return [colors[i % len(colors)] for i in box_order] 33 | 34 | 35 | def vis_class(image, 36 | pos, 37 | class_str, 38 | font_scale=0.35, 39 | bg_color=_BLACK, 40 | bg_opacity=1.0, 41 | text_color=_GRAY, 42 | thickness=1): 43 | """Visualizes the class.""" 44 | x, y = int(pos[0]), int(pos[1]) 45 | # Compute text size. 46 | txt = class_str 47 | font = cv2.FONT_HERSHEY_SIMPLEX 48 | ((txt_w, txt_h), _) = cv2.getTextSize(txt, font, font_scale, 1) 49 | # Place text background. 50 | back_tl = x, y 51 | back_br = x + txt_w, y + int(1.3 * txt_h) 52 | # Show text. 53 | txt_tl = x, y + int(1 * txt_h) 54 | image = rect_with_opacity(image, back_tl, back_br, bg_color, bg_opacity) 55 | cv2.putText(image, 56 | txt, 57 | txt_tl, 58 | font, 59 | font_scale, 60 | text_color, 61 | thickness=thickness, 62 | lineType=cv2.LINE_AA) 63 | return image 64 | 65 | 66 | def overlay_class_coco(image, 67 | annotations, 68 | categories, 69 | background_colors=None, 70 | font_scale=0.5, 71 | font_thickness=1, 72 | bg_opacity=1.0, 73 | text_color=_GRAY, 74 | show_track_id=False): 75 | """ 76 | Adds class names in the positions defined by the top-left corner of the 77 | COCO annotation bounding box 78 | 79 | Arguments: 80 | image (np.ndarray): an image as returned by OpenCV 81 | annotations (List[dict]): List of COCO annotations. 82 | categories (dict): coco.cats 83 | """ 84 | assert not isinstance(categories, list), ( 85 | 'categories should be a dict with category ids as keys.') 86 | labels = [] 87 | for a in annotations: 88 | label = categories[a['category_id']]['name'] 89 | if label == 'baby': 90 | label = 'person' 91 | if show_track_id and 'track_id' in a: 92 | label = f'{label} ({a["track_id"]})' 93 | labels.append(label) 94 | # labels = [categories[i['category_id']]['name'] for i in annotations] 95 | # labels = predictions.get_field("labels").tolist() 96 | # labels = [categories[i] for i in labels] 97 | boxes = [[int(round(y)) for y in x['bbox']] for x in annotations] 98 | if background_colors is None: 99 | # colors = get_annotation_colors(annotations) 100 | colors = [_BLACK for _ in annotations] 101 | else: 102 | colors = background_colors 103 | 104 | for box, label, color in zip(boxes, labels, colors): 105 | vis_class(image, 106 | box, 107 | label, 108 | font_scale=font_scale, 109 | bg_color=color, 110 | bg_opacity=bg_opacity, 111 | text_color=text_color, 112 | thickness=font_thickness) 113 | 114 | return image 115 | 116 | 117 | def vis_bbox(image, 118 | box, 119 | border_color=_BLACK, 120 | fill_color=_COLOR1, 121 | fill_opacity=0.65, 122 | thickness=1): 123 | """Visualizes a bounding box.""" 124 | x0, y0, w, h = box 125 | x1, y1 = int(x0 + w), int(y0 + h) 126 | x0, y0 = int(x0), int(y0) 127 | # Draw border 128 | if fill_opacity > 0 and fill_color is not None: 129 | image = rect_with_opacity(image, (x0, y0), (x1, y1), tuple(fill_color), 130 | fill_opacity) 131 | image = cv2.rectangle(image, (x0, y0), (x1, y1), tuple(border_color), 132 | thickness) 133 | return image 134 | 135 | 136 | def overlay_boxes_coco(image, 137 | annotations, 138 | colors=None, 139 | border_color=None, 140 | fill_opacity=None, 141 | thickness=1): 142 | """ 143 | Adds the predicted boxes on top of the image 144 | 145 | Arguments: 146 | image (np.ndarray): an image as returned by OpenCV 147 | annotations (List[dict]): List of COCO annotations. 
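
    Example (a sketch with hypothetical values; 'bbox' is COCO
    [x, y, w, h]):
        frame = cv2.imread('frame.jpg')
        frame = overlay_boxes_coco(
            frame, [{'bbox': [10, 20, 30, 40]}], fill_opacity=0.5)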
148 | """ 149 | boxes = [[int(round(y)) for y in x['bbox']] for x in annotations] 150 | 151 | sorted_inds = sorted(range(len(boxes)), 152 | key=lambda i: boxes[i][2] * boxes[i][3], 153 | reverse=True) 154 | 155 | if colors is None: 156 | colors = get_annotation_colors(annotations) 157 | 158 | for i in sorted_inds: 159 | box = boxes[i] 160 | color = colors[i] 161 | kwargs = {} 162 | if fill_opacity: 163 | kwargs['fill_opacity'] = fill_opacity 164 | if border_color is not None: 165 | kwargs['border_color'] = border_color 166 | image = vis_bbox(image, 167 | box, 168 | fill_color=color, 169 | thickness=thickness, 170 | **kwargs) 171 | return image 172 | 173 | 174 | def vis_mask(image, 175 | mask, 176 | color, 177 | alpha=0.4, 178 | show_border=True, 179 | border_alpha=0.5, 180 | border_thick=1, 181 | border_color=None): 182 | """Visualizes a single binary mask.""" 183 | image = image.astype(np.float32) 184 | mask = mask[0, :, :, None] 185 | idx = np.nonzero(mask) 186 | 187 | image[idx[0], idx[1], :] *= 1.0 - alpha 188 | image[idx[0], idx[1], :] += [alpha * x for x in color] 189 | 190 | if border_alpha == 0: 191 | return image.astype(np.uint8)  # Previously returned None here. 192 | 193 | if border_color is None: 194 | border_color = [x * 0.5 for x in color] 195 | if isinstance(border_color, np.ndarray): 196 | border_color = border_color.tolist() 197 | contours, _ = cv2_util.findContours(mask, cv2.RETR_TREE, 198 | cv2.CHAIN_APPROX_SIMPLE) 199 | if border_alpha < 1: 200 | with_border = image.copy() 201 | cv2.drawContours(with_border, contours, -1, border_color, border_thick, 202 | cv2.LINE_AA) 203 | image = ((1 - border_alpha) * image + border_alpha * with_border) 204 | else: 205 | cv2.drawContours(image, contours, -1, border_color, border_thick, 206 | cv2.LINE_AA) 207 | return image.astype(np.uint8) 208 | 209 | 210 | def overlay_mask_coco(image, 211 | annotations, 212 | alpha=0.3, 213 | border_alpha=1.0, 214 | border_thick=2): 215 | """ 216 | Adds the instances contours for each predicted object. 217 | Each label has a different color. 218 | 219 | Arguments: 220 | image (np.ndarray): an image as returned by OpenCV 221 | annotations (List[dict]): List of COCO annotations. 222 | """ 223 | colors = colormap(rgb=True)[:len(annotations)] 224 | 225 | for annotation, color in zip(annotations, colors): 226 | mask = mask_util.decode(annotation['segmentation']) 227 | image = vis_mask(image, 228 | mask, 229 | color, 230 | alpha=alpha, 231 | border_alpha=border_alpha, 232 | border_thick=border_thick) 233 | return image 234 | -------------------------------------------------------------------------------- /tao/utils/yacs_util.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import yaml 4 | from typing import Any, Dict 5 | 6 | from yacs.config import CfgNode 7 | from yacs.config import _valid_type, _VALID_TYPES 8 | 9 | 10 | BASE_KEY = "_BASE_" 11 | 12 | 13 | def _load_yaml_with_base(filename: str, allow_unsafe: bool = False) -> CfgNode: 14 | """ 15 | Just like `yaml.load(open(filename))`, but inherit attributes from its 16 | `_BASE_`. 17 | 18 | Modified from 19 | https://github.com/facebookresearch/fvcore/blob/99cb965c67e675dc3259cd490c1dd78ab03a55ff/fvcore/common/config.py 20 | 21 | Args: 22 | filename (str): the file name of the current config. Will be used to 23 | find the base config file. 24 | allow_unsafe (bool): whether to allow loading the config file with 25 | `yaml.unsafe_load`. 
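
        For example (hypothetical file names), a 'child.yaml' containing
            _BASE_: base.yaml
            SOLVER:
              LR: 0.01
        first loads 'base.yaml' (resolved relative to 'child.yaml' unless
        absolute or a URL), then overwrites SOLVER.LR on top of it.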
26 | Returns: 27 | (dict): the loaded yaml 28 | """ 29 | with open(filename, "r") as f: 30 | try: 31 | cfg = yaml.safe_load(f) 32 | except yaml.constructor.ConstructorError: 33 | if not allow_unsafe: 34 | raise 35 | logger = logging.getLogger(__name__) 36 | logger.warning( 37 | "Loading config {} with yaml.unsafe_load. Your machine may " 38 | "be at risk if the file contains malicious content.".format( 39 | filename 40 | ) 41 | ) 42 | f.close() 43 | with open(filename, "r") as f: 44 | cfg = yaml.unsafe_load(f) # pyre-ignore 45 | 46 | if cfg is None: 47 | return cfg 48 | 49 | # pyre-ignore 50 | def merge_a_into_b(a: Dict[Any, Any], b: Dict[Any, Any]) -> None: 51 | # merge dict a into dict b. values in a will overwrite b. 52 | for k, v in a.items(): 53 | if isinstance(v, dict) and k in b: 54 | assert isinstance( 55 | b[k], dict), "Cannot inherit key '{}' from base!".format(k) 56 | merge_a_into_b(v, b[k]) 57 | else: 58 | b[k] = v 59 | 60 | if BASE_KEY in cfg: 61 | base_cfg_file = cfg[BASE_KEY] 62 | if base_cfg_file.startswith("~"): 63 | base_cfg_file = os.path.expanduser(base_cfg_file) 64 | if not any(map(base_cfg_file.startswith, 65 | ["/", "https://", "http://"])): 66 | # the path to base cfg is relative to the config file itself. 67 | base_cfg_file = os.path.join(os.path.dirname(filename), 68 | base_cfg_file) 69 | base_cfg = _load_yaml_with_base(base_cfg_file, 70 | allow_unsafe=allow_unsafe) 71 | del cfg[BASE_KEY] 72 | if base_cfg is None: 73 | return cfg 74 | 75 | merge_a_into_b(cfg, base_cfg) # pyre-ignore 76 | return base_cfg 77 | return cfg 78 | 79 | 80 | def merge_from_file_with_base(cfg, 81 | cfg_filename: str, 82 | allow_unsafe: bool = False) -> None: 83 | """ 84 | Merge configs from a given yaml file. 85 | 86 | Modified from 87 | https://github.com/facebookresearch/fvcore/blob/99cb965c67e675dc3259cd490c1dd78ab03a55ff/fvcore/common/config.py 88 | 89 | Args: 90 | cfg_filename: the file name of the yaml config. 91 | allow_unsafe: whether to allow loading the config file with 92 | `yaml.unsafe_load`. 
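
    Usage (a sketch; assumes 'configs/my_cfg.yaml' exists and cfg is a
    yacs CfgNode of defaults):
        merge_from_file_with_base(cfg, 'configs/my_cfg.yaml')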
93 | """ 94 | loaded_cfg = _load_yaml_with_base(cfg_filename, allow_unsafe=allow_unsafe) 95 | loaded_cfg = type(cfg)(loaded_cfg) 96 | cfg.merge_from_other_cfg(loaded_cfg) 97 | 98 | 99 | def cfg_to_dict(cfg_node, key_list=[]): 100 | if not isinstance(cfg_node, CfgNode): 101 | assert _valid_type(cfg_node), ( 102 | "Key {} with value {} is not a valid type; valid types: {}".format( 103 | ".".join(key_list), type(cfg_node), _VALID_TYPES)) 104 | return cfg_node 105 | else: 106 | cfg_dict = dict(cfg_node) 107 | for k, v in cfg_dict.items(): 108 | cfg_dict[k] = cfg_to_dict(v, key_list + [k]) 109 | return cfg_dict 110 | 111 | 112 | def cfg_to_flat_dict(cfg_node, key_list=[]): 113 | if not isinstance(cfg_node, CfgNode): 114 | assert _valid_type(cfg_node), ( 115 | "Key {} with value {} is not a valid type; valid types: {}".format( 116 | ".".join(key_list), type(cfg_node), _VALID_TYPES)) 117 | return cfg_node 118 | else: 119 | cfg_dict_flat = {} 120 | for k, v in dict(cfg_node).items(): 121 | updated = cfg_to_dict(v, key_list + [k]) 122 | if isinstance(updated, dict): 123 | for k1, v1 in updated.items(): 124 | cfg_dict_flat['.'.join(key_list + [k, k1])] = v1 125 | else: 126 | cfg_dict_flat['.'.join(key_list + [k])] = updated 127 | return cfg_dict_flat 128 | -------------------------------------------------------------------------------- /tao/utils/ytdl.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import sys 4 | from contextlib import redirect_stdout 5 | from io import BytesIO, StringIO 6 | from pathlib import Path 7 | 8 | import boto3 9 | import youtube_dl 10 | from tqdm import tqdm 11 | 12 | # In case we use pywren, we can't import from the tao module directly here 13 | # for some reason. Just import what we need. 
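# (The sys.path tweak below lets `import s3` resolve to the sibling
# module tao/utils/s3.py, so this file still works when shipped to a
# pywren worker on its own.)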
14 | sys.path.insert(0, str(Path(__file__).parent)) 15 | import s3 as s3_utils 16 | 17 | 18 | class VideoUnavailableError(youtube_dl.DownloadError): 19 | pass 20 | 21 | 22 | def get_metadata(url): 23 | ydl_opts = { 24 | 'outtmpl': '-', 25 | 'skip_download': True, 26 | 'forcejson': True, 27 | 'quiet': True, 28 | 'nocheckcertificate': True, 29 | 'cachedir': False 30 | } 31 | info = StringIO() 32 | try: 33 | with redirect_stdout(info): 34 | with youtube_dl.YoutubeDL(ydl_opts) as ydl: 35 | ydl.download([url]) 36 | except youtube_dl.DownloadError as e: 37 | message = str(e) 38 | if 'This video is no longer available' in message: 39 | raise VideoUnavailableError(message) 40 | else: 41 | raise e 42 | except BaseException as e: 43 | print('Exception', e) 44 | import traceback 45 | traceback.print_exc() 46 | raise e 47 | return json.loads(info.getvalue()) 48 | 49 | 50 | def download_to_bytes(url, extra_opts={}): 51 | ydl_opts = { 52 | 'format': 'best[ext=mp4]', 53 | 'outtmpl': '-', 54 | 'logger': logging.getLogger('youtube-dl'), 55 | 'nocheckcertificate': True, 56 | 'cachedir': False 57 | } 58 | ydl_opts.update(extra_opts) 59 | 60 | video = BytesIO() 61 | try: 62 | with redirect_stdout(video): 63 | with youtube_dl.YoutubeDL(ydl_opts) as ydl: 64 | ydl.download([url]) 65 | except youtube_dl.DownloadError as e: 66 | message = str(e) 67 | if 'This video is no longer available' in message: 68 | raise VideoUnavailableError(message) 69 | else: 70 | raise e 71 | except BaseException as e: 72 | print('Exception', e) 73 | import traceback 74 | traceback.print_exc() 75 | raise e 76 | return video 77 | 78 | 79 | def chunks(l, n): 80 | """Yield successive n-sized chunks from l.""" 81 | for i in range(0, len(l), n): 82 | yield l[i:i + n] 83 | 84 | 85 | def pytube_download_bytes(url, extra_opts={}): 86 | buffer_obj = None 87 | try: 88 | buffer_obj = download_to_bytes(url, extra_opts=extra_opts) 89 | except VideoUnavailableError: 90 | return -1 91 | except Exception: 92 | return -2 93 | return buffer_obj.getvalue() 94 | 95 | 96 | def store_bytes_to_s3(object_bytes, key, bucket): 97 | client = boto3.client('s3') 98 | client.put_object(Body=object_bytes, 99 | Key=key, 100 | Bucket=bucket) 101 | return len(object_bytes) 102 | 103 | 104 | def vid_id_to_name(vid_id): 105 | return f'v_{vid_id}' 106 | 107 | 108 | def download_and_store_vids(urls, 109 | ids, 110 | keys, 111 | bucket, 112 | ytdl_params={}, 113 | progress=False): 114 | logging.getLogger().setLevel(logging.INFO) 115 | logging.basicConfig(format='%(asctime)s.%(msecs).03d: %(message)s', 116 | datefmt='%H:%M:%S') 117 | logging.info('Downloading videos: %s', urls) 118 | downloaded = 0 119 | unavail = 0 120 | other_error = 0 121 | empty_bytes = 0 122 | blacklist = [] 123 | for url, vid_id, key in zip(tqdm(urls, disable=not progress), ids, keys): 124 | logging.info(f"Downloading Video {vid_id}") 125 | vid_bytes = pytube_download_bytes(url, ytdl_params) 126 | 127 | problem = None 128 | if isinstance(vid_bytes, int): 129 | if vid_bytes == -1: 130 | logging.info(f'unavailable: {vid_id}') 131 | problem = 'unavailable' 132 | unavail += 1 133 | blacklist.append(vid_id) 134 | elif vid_bytes == -2: 135 | problem = 'error' 136 | other_error += 1 137 | elif vid_bytes is None: 138 | problem = 'other_download_error' 139 | elif len(vid_bytes) == 0: 140 | logging.info(f'empty bytes: {vid_id}') 141 | problem = 'empty_bytes' 142 | empty_bytes += 1 143 | 144 | if problem is not None: 145 | problem_key = f'{key}.{problem}' 146 | store_bytes_to_s3('', problem_key, bucket) 147 | continue 148 | 
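        # Download succeeded: upload the raw bytes under the video's key.
        # Failures above instead leave an empty marker object (e.g. a key
        # ending in '.unavailable') that download_vids' skip_exists check
        # looks for on later runs.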
149 | logging.debug(f"Uploading Video {vid_id} to S3") 150 | store_bytes_to_s3(vid_bytes, key, bucket) 151 | downloaded += 1 152 | import time 153 | import random 154 | time.sleep(random.random() * 2) 155 | return downloaded, unavail, other_error, blacklist 156 | 157 | 158 | def download_vids(videos, 159 | s3_bucket, 160 | s3_prefix, 161 | cache_dir=None, 162 | parallel=False, 163 | video_keys=None, 164 | ytdl_params={}, 165 | skip_exists=True, 166 | chunk_size=2, 167 | subset=None, 168 | verbose=False): 169 | urls = ['http://youtu.be/'+vid for vid in videos] 170 | ids = videos 171 | client = boto3.client('s3') 172 | exist_ids = set(s3_utils.list_all_keys(client, s3_bucket, s3_prefix)) 173 | 174 | if verbose: 175 | log = logging.info 176 | else: 177 | def log(*args, **kwargs): 178 | return 179 | 180 | if video_keys is None: 181 | video_keys = [ 182 | f"{vid_id_to_name(vid_id)}.mp4" for vid_id in videos 183 | ] 184 | if s3_prefix[-1] != '/': 185 | s3_prefix = s3_prefix + '/' 186 | video_keys = [f'{s3_prefix}{key}' for key in video_keys] 187 | 188 | # if processed_prefix is not None: 189 | # processed_key = f'{s3_prefix}{processed_prefix}' 190 | # if s3_utils.key_exists(s3_bucket, processed_key): 191 | # obj = client.Object(s3_bucket, processed_key) 192 | # data = obj.get()['Body'].read().decode('utf-8') 193 | # processed = set(data.split('\n')) 194 | # valid_videos = [(url, vid, k) 195 | # for (url, vid, k) in zip(urls, ids, video_keys) 196 | # if k not in processed] 197 | # if not valid_videos: 198 | # return {'num_downloaded': 0, 'unavailable': 0, 'num_errors': 0} 199 | # urls, ids, video_keys = zip(*valid_videos) 200 | 201 | if skip_exists: 202 | if len(exist_ids) > 0: 203 | valid_videos = [] 204 | for url, vid, key in zip(urls, ids, video_keys): 205 | error_keys = [ 206 | f'{key}.{x}' for x in ('unavailable', 'empty_bytes') 207 | ] 208 | check_keys = error_keys + [key] 209 | if not any(x in exist_ids for x in check_keys): 210 | valid_videos.append((url, vid, key)) 211 | if not valid_videos: 212 | return {'num_downloaded': 0, 'unavailable': 0, 'num_errors': 0} 213 | urls, ids, video_keys = zip(*valid_videos) 214 | log(f'{len(urls)}/{len(videos)} to download.') 215 | 216 | if subset is not None: 217 | videos = videos[:subset] 218 | urls = urls[:subset] 219 | if not parallel: 220 | results = [ 221 | download_and_store_vids(urls, ids, video_keys, s3_bucket, 222 | ytdl_params, progress=True) 223 | ] 224 | else: 225 | import pywren 226 | chunked_lst = list(chunks(list(zip(urls, ids, video_keys)), chunk_size)) 227 | log(f"{len(chunked_lst)} Pywren jobs total") 228 | pwex = pywren.default_executor() 229 | 230 | def pywren_f(elem): 231 | urls, ids, keys = zip(*elem) 232 | return download_and_store_vids(urls, ids, keys, s3_bucket, 233 | ytdl_params) 234 | 235 | try: 236 | log("Mapping...") 237 | futures = pwex.map(pywren_f, chunked_lst) 238 | pywren.wait(futures) 239 | results = [f.result() for f in futures] 240 | print('len(results)', len(results)) 241 | finally: 242 | for f in futures: 243 | f.cancel() 244 | 245 | unavailable = [y for x in results for y in x[3]] 246 | return { 247 | 'num_downloaded': sum([x[0] for x in results]), 248 | 'unavailable': unavailable, 249 | 'num_errors': sum([x[2] for x in results]), 250 | } 251 | --------------------------------------------------------------------------------
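A minimal usage sketch for the downloader above (hypothetical bucket name
and video ids; boto3 must be configured with AWS credentials):

    from tao.utils import ytdl
    stats = ytdl.download_vids(videos=['abc123xyz00'],
                               s3_bucket='my-tao-bucket',
                               s3_prefix='videos/',
                               parallel=False)
    print(stats['num_downloaded'], stats['unavailable'], stats['num_errors'])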