├── imgs └── teaser.png ├── checkpoint └── .gitignore ├── assets └── worldcup_field_model.png ├── dataset ├── WorldCup_2014_2018 │ └── .gitignore ├── soccer_worldcup_2014 │ └── .gitignore └── DATASET.md ├── LICENSE ├── opt_ours.txt ├── exp_ours.txt ├── robust ├── opt_robust.txt ├── exp_robust.txt ├── opt_robust_finetuned.txt ├── exp_robust_finetuned.txt ├── README.md ├── worldcup_loader.py ├── ts_worldcup_loader.py ├── models │ └── model.py └── test.py ├── opt_ours_finetuned.txt ├── exp_ours_finetuned.txt ├── exp_dice_bce.txt ├── exp_dice_wce.txt ├── inference.txt ├── loss.py ├── models ├── eval_network.py ├── inference_core.py ├── modules.py ├── non_local.py ├── network.py └── mod_resnet.py ├── .gitignore ├── options.py ├── README.md ├── environment.yml ├── worldcup_train_loader.py ├── ts_worldcup_train_loader.py ├── metrics.py ├── worldcup_test_loader.py ├── ts_worldcup_test_loader.py └── utils.py /imgs/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ericsujw/KpSFR/HEAD/imgs/teaser.png -------------------------------------------------------------------------------- /checkpoint/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore -------------------------------------------------------------------------------- /assets/worldcup_field_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ericsujw/KpSFR/HEAD/assets/worldcup_field_model.png -------------------------------------------------------------------------------- /dataset/WorldCup_2014_2018/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore -------------------------------------------------------------------------------- /dataset/soccer_worldcup_2014/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore -------------------------------------------------------------------------------- /dataset/DATASET.md: -------------------------------------------------------------------------------- 1 | Public Soccer WorldCup ([Download](https://nhoma.github.io/)) 2 | - Train Set Path: dataset/soccer_worldcup_2014/soccer_data/train_val 3 | - Test Set Path: dataset/soccer_worldcup_2014/soccer_data/test 4 | 5 | For details, refer to [source](dataset/soccer_worldcup_2014/soccer_data/source.txt). 6 | 7 | TS-WorldCup ([Download](https://cgv.cs.nthu.edu.tw/KpSFR_data/TS-WorldCup.zip)) 8 | - Train Set Path: 9 | - RGB: dataset/WorldCup_2014_2018/Dataset/80_95/ + train.txt 10 | - GT: dataset/WorldCup_2014_2018/Annotations/80_95/ + train.txt 11 | 12 | - Test Set Path: 13 | - RGB: dataset/WorldCup_2014_2018/Dataset/80_95/ + test.txt 14 | - GT: dataset/WorldCup_2014_2018/Annotations/80_95/ + test.txt 15 | 16 | - Misc 17 | - SingleFramePredict_with_normalized/SingleFramePredict_finetuned_with_normalized: the predictions of Nie et al. 18 | 19 | Keypoint labels are predicted from the heatmap results of Nie et al. 
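20 | 21 | As a reference for the "path + split file" notation above, here is a minimal path-resolution sketch (following the conventions of `robust/ts_worldcup_loader.py`; the helper name `list_sequence_frames` is ours): 22 | ```python 23 | import glob 24 | import os.path as osp 25 | 26 | root = 'dataset/WorldCup_2014_2018' 27 | 28 | def list_sequence_frames(split='train', interval='80_95'): 29 | # Each line of <split>.txt names one video sequence, e.g. 'left/2014_...' 30 | with open(osp.join(root, split + '.txt')) as f: 31 | videos = [line.rstrip('\n') for line in f] 32 | frames, homographies = [], [] 33 | for video in videos: 34 | # RGB frames and per-frame ground-truth homographies share the sequence name 35 | frames += sorted(glob.glob(osp.join(root, 'Dataset', interval, video, '*.jpg'))) 36 | homographies += sorted(glob.glob(osp.join(root, 'Annotations', interval, video, '*_homography.npy'))) 37 | return frames, homographies 38 | ```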
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Yen-Jui Chu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /opt_ours.txt: -------------------------------------------------------------------------------- 1 | --batch_size 4 2 | --checkpoints_dir checkpoint 3 | --ckpt_path 4 | --custom_worldcup_root dataset/WorldCup_2014_2018 5 | --sfp_finetuned False 6 | --gpu_ids 0 7 | --isTrain True 8 | --loss_mode all 9 | --model_archi KC 10 | --name example/opt_ours 11 | --nms_thres 0.995 12 | --noise_rotate 0.0084 13 | --noise_trans 5.0 14 | --num_objects 4 15 | --pr_thres 5.0 16 | --public_worldcup_root dataset/soccer_worldcup_2014/soccer_data 17 | --resume False 18 | --step_size 300 19 | --template_path assets 20 | --testset test 21 | --train_epochs 1500 22 | --train_lr 0.0001 23 | --train_stage 0 24 | --trainset train_val 25 | --use_non_local 1 26 | --weight_decay 0.0 27 | --target_video 28 | --target_image -------------------------------------------------------------------------------- /exp_ours.txt: -------------------------------------------------------------------------------- 1 | --batch_size 4 2 | --checkpoints_dir checkpoint 3 | --ckpt_path checkpoint/kpsfr.pth 4 | --custom_worldcup_root dataset/WorldCup_2014_2018 5 | --sfp_finetuned False 6 | --gpu_ids 0 7 | --isTrain False 8 | --loss_mode all 9 | --model_archi KC 10 | --name example/exp_ours 11 | --nms_thres 0.995 12 | --noise_rotate 0.0084 13 | --noise_trans 5.0 14 | --num_objects 4 15 | --pr_thres 5.0 16 | --public_worldcup_root dataset/soccer_worldcup_2014/soccer_data 17 | --resume False 18 | --step_size 200 19 | --template_path assets 20 | --testset test 21 | --train_epochs 1500 22 | --train_lr 0.0001 23 | --train_stage 1 24 | --trainset train 25 | --use_non_local 1 26 | --weight_decay 0.0 27 | --target_video 28 | --target_image -------------------------------------------------------------------------------- /robust/opt_robust.txt: -------------------------------------------------------------------------------- 1 | --batch_size 4 2 | --checkpoints_dir ../checkpoint 3 | --ckpt_path 4 | --custom_worldcup_root ../dataset/WorldCup_2014_2018 5 | --sfp_finetuned False 6 | --gpu_ids 0 7 | --isTrain True 8 | --loss_mode 9 | --model_archi KC 10 | --name 
example/opt_robust 11 | --nms_thres 0.995 12 | --noise_rotate 0.0084 13 | --noise_trans 5.0 14 | --num_objects 4 15 | --pr_thres 5.0 16 | --public_worldcup_root ../dataset/soccer_worldcup_2014/soccer_data 17 | --resume False 18 | --step_size 200 19 | --template_path ../assets 20 | --testset test 21 | --train_epochs 600 22 | --train_lr 0.0001 23 | --train_stage 0 24 | --trainset train_val 25 | --use_non_local 1 26 | --weight_decay 0.0 27 | --target_video 28 | --target_image -------------------------------------------------------------------------------- /opt_ours_finetuned.txt: -------------------------------------------------------------------------------- 1 | --batch_size 4 2 | --checkpoints_dir checkpoint 3 | --ckpt_path checkpoint/kpsfr.pth 4 | --custom_worldcup_root dataset/WorldCup_2014_2018 5 | --sfp_finetuned False 6 | --gpu_ids 0 7 | --isTrain True 8 | --loss_mode all 9 | --model_archi KC 10 | --name example/opt_ours_finetuned 11 | --nms_thres 0.995 12 | --noise_rotate 0.0084 13 | --noise_trans 5.0 14 | --num_objects 4 15 | --pr_thres 5.0 16 | --public_worldcup_root dataset/soccer_worldcup_2014/soccer_data 17 | --resume False 18 | --step_size 30 19 | --template_path assets 20 | --testset test 21 | --train_epochs 100 22 | --train_lr 0.0001 23 | --train_stage 1 24 | --trainset train 25 | --use_non_local 1 26 | --weight_decay 0.0 27 | --target_video 28 | --target_image -------------------------------------------------------------------------------- /exp_ours_finetuned.txt: -------------------------------------------------------------------------------- 1 | --batch_size 4 2 | --checkpoints_dir checkpoint 3 | --ckpt_path checkpoint/kpsfr_finetuned.pth 4 | --custom_worldcup_root dataset/WorldCup_2014_2018 5 | --sfp_finetuned True 6 | --gpu_ids 0 7 | --isTrain False 8 | --loss_mode all 9 | --model_archi KC 10 | --name example/exp_ours_finetuned 11 | --nms_thres 0.995 12 | --noise_rotate 0.0084 13 | --noise_trans 5.0 14 | --num_objects 4 15 | --pr_thres 5.0 16 | --public_worldcup_root dataset/soccer_worldcup_2014/soccer_data 17 | --resume False 18 | --step_size 200 19 | --template_path assets 20 | --testset test 21 | --train_epochs 1500 22 | --train_lr 0.0001 23 | --train_stage 1 24 | --trainset train 25 | --use_non_local 1 26 | --weight_decay 0.0 27 | --target_video 28 | --target_image -------------------------------------------------------------------------------- /robust/exp_robust.txt: -------------------------------------------------------------------------------- 1 | --batch_size 4 2 | --checkpoints_dir ../checkpoint 3 | --ckpt_path ../checkpoint/robust.pth 4 | --custom_worldcup_root ../dataset/WorldCup_2014_2018 5 | --sfp_finetuned False 6 | --gpu_ids 0 7 | --isTrain False 8 | --loss_mode 9 | --model_archi KC 10 | --name example/exp_robust 11 | --nms_thres 0.995 12 | --noise_rotate 0.0084 13 | --noise_trans 5.0 14 | --num_objects 4 15 | --pr_thres 5.0 16 | --public_worldcup_root ../dataset/soccer_worldcup_2014/soccer_data 17 | --resume False 18 | --step_size 200 19 | --template_path ../assets 20 | --testset test 21 | --train_epochs 600 22 | --train_lr 0.0001 23 | --train_stage 1 24 | --trainset train 25 | --use_non_local 1 26 | --weight_decay 0.0 27 | --target_video 28 | --target_image -------------------------------------------------------------------------------- /exp_dice_bce.txt: -------------------------------------------------------------------------------- 1 | --batch_size 4 2 | --checkpoints_dir checkpoint 3 | --ckpt_path checkpoint/kpsfr_dicebceloss.pth 4 | 
--custom_worldcup_root dataset/WorldCup_2014_2018 5 | --sfp_finetuned False 6 | --gpu_ids 0 7 | --isTrain False 8 | --loss_mode all 9 | --model_archi KC 10 | --name example/exp_dice_bce 11 | --nms_thres 0.995 12 | --noise_rotate 0.0084 13 | --noise_trans 5.0 14 | --num_objects 4 15 | --pr_thres 5.0 16 | --public_worldcup_root dataset/soccer_worldcup_2014/soccer_data 17 | --resume False 18 | --step_size 200 19 | --template_path assets 20 | --testset test 21 | --train_epochs 1500 22 | --train_lr 0.0001 23 | --train_stage 1 24 | --trainset train 25 | --use_non_local 1 26 | --weight_decay 0.0 27 | --target_video 28 | --target_image -------------------------------------------------------------------------------- /exp_dice_wce.txt: -------------------------------------------------------------------------------- 1 | --batch_size 4 2 | --checkpoints_dir checkpoint 3 | --ckpt_path checkpoint/kpsfr_dicewceloss.pth 4 | --custom_worldcup_root dataset/WorldCup_2014_2018 5 | --sfp_finetuned False 6 | --gpu_ids 0 7 | --isTrain False 8 | --loss_mode all 9 | --model_archi KC 10 | --name example/exp_dice_wce 11 | --nms_thres 0.995 12 | --noise_rotate 0.0084 13 | --noise_trans 5.0 14 | --num_objects 4 15 | --pr_thres 5.0 16 | --public_worldcup_root dataset/soccer_worldcup_2014/soccer_data 17 | --resume False 18 | --step_size 200 19 | --template_path assets 20 | --testset test 21 | --train_epochs 1500 22 | --train_lr 0.0001 23 | --train_stage 1 24 | --trainset train 25 | --use_non_local 1 26 | --weight_decay 0.0 27 | --target_video 28 | --target_image -------------------------------------------------------------------------------- /robust/opt_robust_finetuned.txt: -------------------------------------------------------------------------------- 1 | --batch_size 4 2 | --checkpoints_dir ../checkpoint 3 | --ckpt_path ../checkpoint/robust.pth 4 | --custom_worldcup_root ../dataset/WorldCup_2014_2018 5 | --sfp_finetuned False 6 | --gpu_ids 0 7 | --isTrain True 8 | --loss_mode 9 | --model_archi KC 10 | --name example/opt_robust_finetuned 11 | --nms_thres 0.995 12 | --noise_rotate 0.0084 13 | --noise_trans 5.0 14 | --num_objects 4 15 | --pr_thres 5.0 16 | --public_worldcup_root ../dataset/soccer_worldcup_2014/soccer_data 17 | --resume False 18 | --step_size 200 19 | --template_path ../assets 20 | --testset test 21 | --train_epochs 600 22 | --train_lr 0.0001 23 | --train_stage 1 24 | --trainset train 25 | --use_non_local 1 26 | --weight_decay 0.0 27 | --target_video 28 | --target_image -------------------------------------------------------------------------------- /robust/exp_robust_finetuned.txt: -------------------------------------------------------------------------------- 1 | --batch_size 4 2 | --checkpoints_dir ../checkpoint 3 | --ckpt_path ../checkpoint/robust_finetuned.pth 4 | --custom_worldcup_root ../dataset/WorldCup_2014_2018 5 | --sfp_finetuned True 6 | --gpu_ids 0 7 | --isTrain False 8 | --loss_mode 9 | --model_archi KC 10 | --name example/exp_robust_finetuned 11 | --nms_thres 0.995 12 | --noise_rotate 0.0084 13 | --noise_trans 5.0 14 | --num_objects 4 15 | --pr_thres 5.0 16 | --public_worldcup_root ../dataset/soccer_worldcup_2014/soccer_data 17 | --resume False 18 | --step_size 200 19 | --template_path ../assets 20 | --testset test 21 | --train_epochs 600 22 | --train_lr 0.0001 23 | --train_stage 1 24 | --trainset train 25 | --use_non_local 1 26 | --weight_decay 0.0 27 | --target_video 28 | --target_image -------------------------------------------------------------------------------- /inference.txt: 
-------------------------------------------------------------------------------- 1 | --batch_size 4 2 | --checkpoints_dir checkpoint 3 | --ckpt_path checkpoint/kpsfr_finetuned.pth 4 | --custom_worldcup_root dataset/WorldCup_2014_2018 5 | --sfp_finetuned True 6 | --gpu_ids 0 7 | --isTrain False 8 | --loss_mode all 9 | --model_archi KC 10 | --name example/inference 11 | --nms_thres 0.995 12 | --noise_rotate 0.0084 13 | --noise_trans 5.0 14 | --num_objects 4 15 | --pr_thres 5.0 16 | --public_worldcup_root dataset/soccer_worldcup_2014/soccer_data 17 | --resume False 18 | --step_size 200 19 | --template_path assets 20 | --testset test 21 | --train_epochs 1500 22 | --train_lr 0.0001 23 | --train_stage 1 24 | --trainset train 25 | --use_non_local 1 26 | --weight_decay 0.0 27 | --target_video right/2018_Match_Highlights5_clip_00016-1 left/2014_Match_Highlights1_clip_00007-1 28 | --target_image 128 113 -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class BinaryDiceLoss(nn.Module): 7 | """Dice loss for binary classification 8 | Args: 9 | smooth: A float number to smooth loss, and avoid NaN error, default: 1 10 | p: Denominator value: \sum{x^p} + \sum{y^p}, default: 2 11 | predict: A tensor of shape [N, *] 12 | target: A tensor of the same shape as predict 13 | reduction: Reduction method to apply, return mean over batch if 'mean', 14 | return sum if 'sum', return a tensor of shape [N,] if 'none' 15 | Returns: 16 | Loss tensor according to arg reduction 17 | Raises: 18 | Exception if unexpected reduction 19 | """ 20 | 21 | def __init__(self, smooth=1, p=2, reduction='mean'): 22 | super(BinaryDiceLoss, self).__init__() 23 | self.smooth = smooth 24 | self.p = p 25 | self.reduction = reduction 26 | 27 | def forward(self, predict, target): 28 | assert predict.shape[0] == target.shape[0], "predict & target batch size don't match" 29 | predict = predict.contiguous().view(predict.shape[0], -1) 30 | target = target.contiguous().view(target.shape[0], -1) 31 | 32 | num = torch.sum(torch.mul(predict, target), dim=1) + self.smooth 33 | den = torch.sum(predict.pow(self.p) + 34 | target.pow(self.p), dim=1) + self.smooth 35 | 36 | loss = 1 - num / den 37 | 38 | if self.reduction == 'mean': 39 | return loss.mean() 40 | elif self.reduction == 'sum': 41 | return loss.sum() 42 | elif self.reduction == 'none': 43 | return loss 44 | else: 45 | raise Exception('Unexpected reduction {}'.format(self.reduction)) 46 | -------------------------------------------------------------------------------- /models/eval_network.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from STCN: https://github.com/hkchengrex/STCN 3 | eval_network.py - Evaluation version of the network 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from models.modules import * 11 | from models.network import Decoder 12 | 13 | import numpy as np 14 | import matplotlib.pyplot as plt 15 | import os.path as osp 16 | import utils 17 | 18 | 19 | class EvalKpSFR(nn.Module): 20 | 21 | def __init__(self, model_archi, num_objects, non_local): 22 | super().__init__() 23 | 24 | self.k = num_objects 25 | self.model_archi = model_archi 26 | 27 | self.key_encoder = KeyEncoder( 28 | num_objects=0, non_local=non_local) 29 | 30 | # TODO: random pick 4 31 | 
self.value_encoder = ValueEncoder( 32 | num_objects=num_objects, non_local=non_local) 33 | 34 | # Projection from f32 feature space to key space 35 | self.key_proj = KeyProjection(512, keydim=64) 36 | 37 | # Compress f32 a bit to use in decoding later on 38 | self.key_comp = nn.Conv2d(512, 256, kernel_size=3, padding=1) 39 | 40 | self.decoder = Decoder(model_archi=self.model_archi) 41 | 42 | def encode_key(self, frame): 43 | # frame: b*c*h*w 44 | 45 | f32, f16, f8, f4 = self.key_encoder(frame, None) 46 | 47 | return f32, f16, f8, f4 48 | 49 | def segment_with_query(self, k, qf32, qf16, qf8, qf4, lookup): 50 | 51 | print('model archi: ', self.model_archi) 52 | qf32 = qf32.expand(k, -1, -1, -1) 53 | qf16 = qf16.expand(k, -1, -1, -1) 54 | qf8 = qf8.expand(k, -1, -1, -1) 55 | qf4 = qf4.expand(k, -1, -1, -1) 56 | 57 | x_origin = self.decoder( 58 | qf32[0:1], qf16[0:1], qf8[0:1], qf4[0:1], lookup[0:1]) 59 | for idx in range(1, k): 60 | x1_origin = self.decoder( 61 | qf32[idx:idx+1], qf16[idx:idx+1], qf8[idx:idx+1], qf4[idx:idx+1], lookup[idx:idx+1]) 62 | x_origin = torch.cat([x_origin, x1_origin], dim=0) 63 | 64 | return torch.sigmoid(x_origin) 65 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /models/inference_core.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from STCN: https://github.com/hkchengrex/STCN 3 | """ 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | from models.eval_network import EvalKpSFR 9 | 10 | import matplotlib.pyplot as plt 11 | import os.path as osp 12 | import numpy as np 13 | import utils 14 | 15 | 16 | class InferenceCore: 17 | 18 | def __init__(self, prop_net: EvalKpSFR, images, device, num_objects, lookup=None): 19 | 20 | self.prop_net = prop_net 21 | 22 | # True dimensions 23 | t = images.shape[1] 24 | h, w = images.shape[-2:] 25 | 26 | nh = images.shape[-2] // 4 27 | nw = images.shape[-1] // 4 28 | 29 | self.images = images 30 | self.device = device 31 | 32 | self.k = num_objects 33 | self.lookup = lookup 34 | 35 | # Background included; not always consistent (i.e. may not sum to 1) 36 | self.prob = torch.zeros( 37 | (self.k + 1, t, 1, nh, nw), dtype=torch.float32, device=self.device) 38 | self.prob[0] = 1e-7 39 | 40 | self.t, self.h, self.w = t, h, w 41 | self.nh, self.nw = nh, nw 42 | 43 | def aggregate(self, prob, keep_bg=False): 44 | # values of prob are in (0, 1) 45 | assert prob.max() <= 1, 'out of value' 46 | assert prob.min() >= 0, 'out of value' 47 | k = prob.shape 48 | new_prob = torch.cat([ 49 | torch.prod(1 - prob, dim=0, keepdim=True), 50 | prob 51 | ], dim=0).clamp(1e-7, 1 - 1e-7) 52 | logits = torch.log((new_prob / (1 - new_prob))) 53 | 54 | if keep_bg: 55 | return logits, F.softmax(logits, dim=0) 56 | else: 57 | return logits, F.softmax(logits, dim=0)[1:] 58 | 59 | def encode_key(self, idx): 60 | 61 | result = self.prop_net.encode_key(self.images[:, idx]) 62 | return result 63 | 64 | def do_pass(self, idx, end_idx, selector): 65 | 66 | print(f'Current frame is {idx}') 67 | 68 | closest_ti = end_idx 69 | 70 | # Note that we never reach closest_ti, just the frame before it 71 | print('Start frame: ', idx) 72 | this_range = range(idx, closest_ti) 73 | end = closest_ti - 1 74 | 75 | for ti in this_range: 76 | print(f'Current frame is {ti}') 77 | qf32, qf16, qf8, qf4 = self.encode_key(ti) 78 | out_mask_origin = self.prop_net.segment_with_query( 79 | self.k, qf32, qf16, qf8, qf4, self.lookup[ti]) 80 | 81 | _, out_prob_origin = self.aggregate( 82 | out_mask_origin, keep_bg=True) 83 | out_prob = F.interpolate( 84 | out_prob_origin, (self.nh, self.nw), mode='bilinear', align_corners=False) 85 | self.prob[:, ti] = out_prob 86 | 87 | return closest_ti 88 | 89 | def interact(self, frame_idx, end_idx, selector=None): 90 | 91 | # Propagate 92 | self.do_pass(frame_idx, end_idx, selector) 93 | -------------------------------------------------------------------------------- /robust/README.md: -------------------------------------------------------------------------------- 1 | # A Robust and Efficient 
Framework for Sports-Field Registration 2 | 3 | This is a re-implementation of [A Robust and Efficient Framework for Sports-Field Registration](https://openaccess.thecvf.com/content/WACV2021/papers/Nie_A_Robust_and_Efficient_Framework_for_Sports-Field_Registration_WACV_2021_paper.pdf). 4 | 5 | 6 | ## Pretrained Models 7 | 1. Download the [pretrained weight on the WorldCup dataset](https://cgv.cs.nthu.edu.tw/KpSFR_data/model/robust.pth). 8 | 2. Place the pretrained weight in [checkpoint](../checkpoint). 9 | 3. Download the public [WorldCup](https://nhoma.github.io/) dataset. 10 | 4. Place the WorldCup dataset in [dataset/soccer_worldcup_2014](../dataset/soccer_worldcup_2014). 11 | 12 | 13 | ## Evaluation 14 | ### Evaluation command 15 | ```python 16 | python test.py <param_text_file> 17 | ``` 18 | The param_text_file options are as follows: 19 | - `exp_robust.txt`: download the [pretrained weight](https://cgv.cs.nthu.edu.tw/KpSFR_data/model/robust.pth) first and place it in [checkpoint](../checkpoint). Set `--train_stage` to **0** to test on the WorldCup test set, or to **1** to test on the TS-WorldCup test set, and set `--sfp_finetuned` to **False**. 20 | - `exp_robust_finetuned.txt`: download the [pretrained weight](https://cgv.cs.nthu.edu.tw/KpSFR_data/model/robust_finetuned.pth) first and place it in [checkpoint](../checkpoint). Set `--train_stage` to **1** and `--sfp_finetuned` to **True** to test the finetuned results on the TS-WorldCup test set. 21 | 22 | Heatmap results and the corresponding homography matrices are saved under **checkpoint/<experimental name>**, where the experimental name is set by `--name` in the param_text_file. 23 | 24 | Note: 25 | - `robust_worldcup_testset_dilated`: the preprocessed results for prediction on the WorldCup dataset; place them in [dataset/soccer_worldcup_2014/soccer_data](../dataset/soccer_worldcup_2014/soccer_data). 26 | - `SingleFramePredict_with_normalized`: the preprocessed results for prediction on the TS-WorldCup dataset; place them in [dataset/WorldCup_2014_2018](../dataset/WorldCup_2014_2018). 27 | - `SingleFramePredict_finetuned_with_normalized`: the preprocessed results for finetuning on the TS-WorldCup dataset; place them in [dataset/WorldCup_2014_2018](../dataset/WorldCup_2014_2018). 28 | 29 | 30 | ## Train model 31 | ### Train command 32 | ```python 33 | python train.py <param_text_file> 34 | ``` 35 | The param_text_file options are as follows: 36 | - `opt_robust.txt`: download the [pretrained weight](https://cgv.cs.nthu.edu.tw/KpSFR_data/model/robust.pth) first and place it in [checkpoint](../checkpoint). Set `--train_stage` to **0** and `--trainset` to **train_val** to train on the WorldCup train set. 37 | - `opt_robust_finetuned.txt`: download the [pretrained weight](https://cgv.cs.nthu.edu.tw/KpSFR_data/model/robust_finetuned.pth) first and place it in [checkpoint](../checkpoint). Set `--train_stage` to **1** and `--trainset` to **train** to finetune on the TS-WorldCup train set. 38 | 39 | Visualization results and weights are saved under **checkpoint/<experimental name>**, where the experimental name is set by `--name` in the param_text_file. 40 | 41 | Note: please check that the following arguments are set correctly before every training run. 42 | - `--gpu_ids` 43 | - `--name` 44 | - `--train_stage` and `--trainset` 45 | - `--ckpt_path` 46 | - `--train_epochs` and `--step_size` 47 | 48 | For details, refer to [options.py](../options.py). 
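49 | 50 | As a usage sketch of how the param_text_file is consumed (hypothetical driver code; the real `test.py` also wires in the dataset and model): [options.py](../options.py) forwards the file to `argparse` via `fromfile_prefix_chars='@'`, so each `--flag value` line in the file becomes an ordinary command-line argument. 51 | ```python 52 | # e.g. invoked as: python test.py exp_robust.txt 53 | from options import CustomOptions 54 | 55 | # train=False marks a test-time configuration; parse() reads sys.argv[1] 56 | # as the param_text_file and returns the populated options namespace. 57 | opt = CustomOptions(train=False).parse() 58 | print(opt.ckpt_path, opt.train_stage, opt.sfp_finetuned) 59 | ```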
-------------------------------------------------------------------------------- /models/modules.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from STCN: https://github.com/hkchengrex/STCN 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torchvision import models 9 | 10 | from models import mod_resnet 11 | 12 | 13 | class ResBlock(nn.Module): 14 | 15 | def __init__(self, indim, outdim=None): 16 | super(ResBlock, self).__init__() 17 | 18 | if outdim is None: 19 | outdim = indim 20 | if indim == outdim: 21 | self.downsample = None 22 | else: 23 | self.downsample = nn.Conv2d( 24 | indim, outdim, kernel_size=3, padding=1) 25 | 26 | self.conv1 = nn.Conv2d(indim, outdim, kernel_size=3, padding=1) 27 | self.conv2 = nn.Conv2d(outdim, outdim, kernel_size=3, padding=1) 28 | 29 | def forward(self, x): 30 | r = self.conv1(F.relu(x)) 31 | r = self.conv2(F.relu(r)) 32 | 33 | if self.downsample is not None: 34 | x = self.downsample(x) 35 | 36 | return x + r 37 | 38 | 39 | class FeatureFusionBlock(nn.Module): 40 | 41 | def __init__(self, indim, outdim): 42 | super().__init__() 43 | 44 | self.block1 = ResBlock(indim, outdim) 45 | self.block2 = ResBlock(outdim, outdim) 46 | 47 | def forward(self, x, f32): 48 | 49 | x = torch.cat([x, f32], dim=1) 50 | x = self.block1(x) 51 | x = self.block2(x) 52 | 53 | return x 54 | 55 | 56 | # Multiple objects version, used at other times 57 | class ValueEncoder(nn.Module): 58 | 59 | def __init__(self, num_objects, non_local): 60 | super().__init__() 61 | 62 | resnet = mod_resnet.resnet18( 63 | pretrained=True, extra_chan=1, non_local=non_local) 64 | 65 | self.conv1 = resnet.conv1 66 | self.bn1 = resnet.bn1 67 | self.relu = resnet.relu # 1/2, 64 68 | self.maxpool = resnet.maxpool 69 | 70 | self.layer1 = resnet.layer1 # 1/4, 64 71 | self.layer2 = resnet.layer2 # 1/8, 128 72 | self.layer3 = resnet.layer3 # 1/16, 256 73 | self.layer4 = resnet.layer4 # 1/32, 512 74 | 75 | self.fuser = FeatureFusionBlock(512 + 512, 256) 76 | 77 | def forward(self, image, key_f32, mask, other_masks, cls, isFirst): 78 | # input mask: b*1*h*w 79 | # key_f32 is the feature from the key encoder 80 | assert mask.shape[1] == 1, 'channel inconsistent' 81 | if isFirst: 82 | mask = F.interpolate(mask, scale_factor=4, mode='nearest') 83 | 84 | assert image.shape[-2:] == mask.shape[-2:], 'shape inconsistent' 85 | f = torch.cat([image, mask], dim=1) 86 | 87 | x = self.conv1(f) 88 | x = self.bn1(x) 89 | x = self.relu(x) # 1/2, 64 90 | x = self.maxpool(x) # 1/4, 64 91 | 92 | x = self.layer1(x) # 1/4, 64 93 | x = self.layer2(x) # 1/8, 128 94 | x = self.layer3(x) # 1/16, 256 95 | x = self.layer4(x) # 1/32, 512 96 | 97 | x = self.fuser(x, key_f32) 98 | 99 | return x 100 | 101 | 102 | class KeyEncoder(nn.Module): 103 | 104 | def __init__(self, num_objects, non_local): 105 | super().__init__() 106 | 107 | resnet = mod_resnet.resnet34(pretrained=True, non_local=non_local) 108 | 109 | self.conv1 = resnet.conv1 110 | self.bn1 = resnet.bn1 111 | self.relu = resnet.relu # 1/2, 64 112 | self.maxpool = resnet.maxpool 113 | 114 | self.layer1 = resnet.layer1 # 1/4, 64 115 | self.layer2 = resnet.layer2 # 1/8, 128 116 | self.layer3 = resnet.layer3 # 1/16, 256 117 | self.layer4 = resnet.layer4 # 1/32, 512 118 | 119 | def forward(self, f, cls): 120 | 121 | x = self.conv1(f) 122 | x = self.bn1(x) 123 | x = self.relu(x) # 1/2, 64 124 | x = self.maxpool(x) # 1/4, 64 125 | 126 | f4 = self.layer1(x) # 1/4, 64 127 | f8 = 
self.layer2(f4) # 1/8, 128 128 | f16 = self.layer3(f8) # 1/16, 256 129 | f32 = self.layer4(f16) # 1/32, 512 130 | 131 | return f32, f16, f8, f4 132 | 133 | 134 | class KeyProjection(nn.Module): 135 | 136 | def __init__(self, indim, keydim): 137 | super().__init__() 138 | 139 | self.key_proj = nn.Conv2d( 140 | indim, keydim, kernel_size=3, padding=1) 141 | 142 | nn.init.orthogonal_(self.key_proj.weight.data) 143 | if self.key_proj.bias is not None: 144 | nn.init.zeros_(self.key_proj.bias.data) 145 | 146 | def forward(self, x): 147 | 148 | return self.key_proj(x) 149 | -------------------------------------------------------------------------------- /robust/worldcup_loader.py: -------------------------------------------------------------------------------- 1 | ''' 2 | train and test set on WorldCup for Nie et al. (A robust and efficient framework for sports-field registration) 3 | ''' 4 | import os.path as osp 5 | import numpy as np 6 | from PIL import Image 7 | from torch.utils import data 8 | from torchvision import transforms 9 | import utils 10 | import glob 11 | import random 12 | import skimage.segmentation as ss 13 | from typing import Optional 14 | 15 | 16 | class PublicWorldCupDataset(data.Dataset): 17 | 18 | def __init__(self, root, data_type, mode, noise_trans: Optional[float] = None, noise_rotate: Optional[float] = None): 19 | 20 | self.frame_h = 720 21 | self.frame_w = 1280 22 | self.root = root 23 | self.data_type = data_type 24 | self.mode = mode 25 | self.noise_trans = noise_trans 26 | self.noise_rotate = noise_rotate 27 | 28 | frame_list = [osp.basename(name) for name in glob.glob( 29 | osp.join(self.root, self.data_type, '*.jpg'))] 30 | self.frames = [img for img in sorted( 31 | frame_list, key=lambda x: int(x[:-4]))] 32 | 33 | homographies_list = [osp.basename(name) for name in glob.glob( 34 | osp.join(self.root, self.data_type, '*.homographyMatrix'))] 35 | self.homographies = [mat for mat in sorted( 36 | homographies_list, key=lambda x: int(x[:-17]))] 37 | 38 | self.preprocess = transforms.Compose([ 39 | transforms.ToTensor(), 40 | transforms.Normalize([0.485, 0.456, 0.406], 41 | [0.229, 0.224, 0.225]) # ImageNet 42 | ]) 43 | 44 | def __len__(self): 45 | 46 | return len(self.frames) 47 | 48 | def __getitem__(self, index): 49 | 50 | image = np.array(Image.open( 51 | osp.join(self.root, self.data_type, self.frames[index]))) 52 | 53 | gt_h = np.loadtxt( 54 | osp.join(self.root, self.data_type, self.homographies[index])) 55 | 56 | template_grid = utils.gen_template_grid() # template grid shape (91, 3) 57 | 58 | # Warp grid shape (N, 3), N is based on number of points in image view 59 | warp_image, warp_grid, homo_mat = utils.gen_im_partial_grid( 60 | self.mode, image, gt_h, template_grid, self.noise_trans, self.noise_rotate, index) 61 | num_pts = warp_grid.shape[0] 62 | pil_image = Image.fromarray(warp_image) 63 | 64 | # TODO apply random horizontal flip to the image and grid points 65 | if self.mode == 'train' and random.random() < 0.5: 66 | pil_image, warp_grid = utils.put_lrflip_augmentation( 67 | pil_image, warp_grid) 68 | 69 | # Default all keypoints belong to background 70 | heatmaps = np.zeros( 71 | (self.frame_h // 4, self.frame_w // 4), dtype=np.float32) 72 | 73 | # Dilate pixels 74 | for keypts_label in range(num_pts): 75 | px = np.rint(warp_grid[keypts_label, 0] / 4).astype(np.int32) 76 | py = np.rint(warp_grid[keypts_label, 1] / 4).astype(np.int32) 77 | if 0 <= px < (self.frame_w // 4) and 0 <= py < (self.frame_h // 4): 78 | heatmaps[py, px] = 
warp_grid[keypts_label, 2] 79 | 80 | dilated_heatmaps = ss.expand_labels(heatmaps, distance=5) 81 | image_tensor = self.preprocess(pil_image) 82 | 83 | return image_tensor, utils.to_torch(dilated_heatmaps), utils.to_torch(heatmaps), homo_mat 84 | 85 | 86 | if __name__ == "__main__": 87 | 88 | worldcup_loader = PublicWorldCupDataset( 89 | root='./dataset/soccer_worldcup_2014/soccer_data', data_type='train_val', mode='train', noise_trans=5.0, noise_rotate=0.0084) 90 | # worldcup_loader = PublicWorldCupDataset( 91 | # root='./dataset/soccer_worldcup_2014/soccer_data', data_type='test', mode='test') 92 | 93 | denorm = utils.UnNormalize( 94 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 95 | cnt = 0 96 | for image, gt_heatmap, target, gt_homo in worldcup_loader: 97 | pass 98 | # print(gt_homo.dtype) 99 | # plt.imsave('./target.png', target, vmin=0, vmax=91) 100 | # print(image.shape, gt_heatmap.shape) 101 | # print(image.dtype, gt_heatmap.dtype) 102 | # gt_heatmap = gt_heatmap.squeeze(0) 103 | # image = np.transpose(image.cpu().numpy(), (1, 2, 0)) 104 | # plt.imsave(osp.join('./0723', '%03d_out_warp_rgb.png' % cnt), image) 105 | # plt.imsave('./out_heatmaps.png', heatmaps, vmin=0, vmax=91, cmap='tab20b') 106 | # plt.imsave(osp.join('./0723', '%03d_out_heatmaps.png' % 107 | # cnt), gt_heatmap, vmin=0, vmax=91) 108 | # print(gt_heatmap.shape, gt_heatmap.dtype) 109 | # plt.imsave('./next_dilated_circle.png', heatmaps, vmin=0, vmax=91, cmap='tab20b') 110 | cnt += 1 111 | # assert False 112 | # utils.visualize_heatmaps(image, gt_heatmap) 113 | -------------------------------------------------------------------------------- /robust/ts_worldcup_loader.py: -------------------------------------------------------------------------------- 1 | ''' 2 | train and test set on TS-WorldCup for Nie et al. 
(A robust and efficient framework for sports-field registration) 3 | ''' 4 | import random 5 | import glob 6 | import os 7 | import os.path as osp 8 | import numpy as np 9 | from PIL import Image 10 | from torch.utils import data 11 | from torchvision import transforms 12 | import matplotlib.pyplot as plt 13 | import skimage.segmentation as ss 14 | from typing import Optional 15 | import utils 16 | 17 | 18 | class MainTestSVDataset(data.Dataset): 19 | 20 | def __init__(self, root, data_type, mode, noise_trans: Optional[float] = None, noise_rotate: Optional[float] = None): 21 | 22 | self.frame_h = 720 23 | self.frame_w = 1280 24 | self.root = root 25 | self.data_type = data_type 26 | self.mode = mode 27 | self.noise_trans = noise_trans 28 | self.noise_rotate = noise_rotate 29 | 30 | sequence_interval = '80_95' 31 | self.image_path = osp.join( 32 | self.root, 'Dataset', sequence_interval) 33 | self.anno_path = osp.join( 34 | self.root, 'Annotations', sequence_interval) 35 | imgset_path = osp.join(self.root, self.data_type) 36 | 37 | self.videos = [] 38 | self.num_frames = {} 39 | self.num_homographies = {} 40 | self.frames = [] 41 | self.homographies = [] 42 | 43 | with open(imgset_path + '.txt', 'r') as lines: 44 | for line in lines: 45 | _video = line.rstrip('\n') 46 | self.videos.append(_video) 47 | self.num_frames[_video] = len( 48 | glob.glob(osp.join(self.image_path, _video, '*.jpg'))) 49 | self.num_homographies[_video] = len( 50 | glob.glob(osp.join(self.anno_path, _video, '*_homography.npy'))) 51 | 52 | frames = sorted(os.listdir( 53 | osp.join(self.image_path, _video))) 54 | for img in frames: 55 | self.frames.append( 56 | osp.join(self.image_path, _video, img)) 57 | 58 | homographies = sorted(os.listdir( 59 | osp.join(self.anno_path, _video))) 60 | for mat in homographies: 61 | self.homographies.append( 62 | osp.join(self.anno_path, _video, mat)) 63 | 64 | self.preprocess = transforms.Compose([ 65 | transforms.ToTensor(), 66 | transforms.Normalize([0.485, 0.456, 0.406], 67 | [0.229, 0.224, 0.225]) # ImageNet 68 | ]) 69 | 70 | def __len__(self): 71 | 72 | return len(self.frames) 73 | 74 | def __getitem__(self, index): 75 | 76 | image = np.array(Image.open(self.frames[index])) 77 | 78 | gt_h = np.load(self.homographies[index]) 79 | 80 | template_grid = utils.gen_template_grid() # template grid shape (91, 3) 81 | 82 | # Warp grid shape (N, 3), N is based on number of points in image view 83 | warp_image, warp_grid, homo_mat = utils.gen_im_partial_grid( 84 | self.mode, image, gt_h, template_grid, self.noise_trans, self.noise_rotate, index) 85 | num_pts = warp_grid.shape[0] 86 | pil_image = Image.fromarray(warp_image) 87 | 88 | # TODO apply random horizontal flip to the image and grid points 89 | if self.mode == 'train' and random.random() < 0.5: 90 | pil_image, warp_grid = utils.put_lrflip_augmentation( 91 | pil_image, warp_grid) 92 | 93 | # Default all keypoints belong to background 94 | heatmaps = np.zeros( 95 | (self.frame_h // 4, self.frame_w // 4), dtype=np.float32) 96 | 97 | # Dilate pixels 98 | for keypts_label in range(num_pts): 99 | px = np.rint(warp_grid[keypts_label, 0] / 4).astype(np.int32) 100 | py = np.rint(warp_grid[keypts_label, 1] / 4).astype(np.int32) 101 | if 0 <= px < (self.frame_w // 4) and 0 <= py < (self.frame_h // 4): 102 | heatmaps[py, px] = warp_grid[keypts_label, 2] 103 | 104 | dilated_heatmaps = ss.expand_labels(heatmaps, distance=5) 105 | image_tensor = self.preprocess(pil_image) 106 | 107 | return image_tensor, utils.to_torch(dilated_heatmaps), 
utils.to_torch(heatmaps), homo_mat 108 | 109 | 110 | if __name__ == "__main__": 111 | 112 | # worldcup_loader = MainTestSVDataset( 113 | # root='dataset/WorldCup_2014_2018', data_type='train', mode='train', noise_trans=5.0, noise_rotate=0.0084) 114 | worldcup_loader = MainTestSVDataset( 115 | root='dataset/WorldCup_2014_2018', data_type='test', mode='test') 116 | 117 | import shutil 118 | cnt = 1 119 | visual_dir = osp.join('visual', 'main_test_sv') 120 | if osp.exists(visual_dir): 121 | print(f'Remove directory: {visual_dir}') 122 | shutil.rmtree(visual_dir) 123 | print(f'Create directory: {visual_dir}') 124 | os.makedirs(visual_dir, exist_ok=True) 125 | 126 | for image, gt_heatmap, target, gt_homo in worldcup_loader: 127 | plt.imsave(osp.join(visual_dir, 'Rgb%03d.jpg' % ( 128 | cnt)), utils.im_to_numpy(image)) 129 | plt.imsave(osp.join(visual_dir, 'Seg%03d.jpg' % ( 130 | cnt)), utils.to_numpy(gt_heatmap), vmin=0, vmax=91) 131 | cnt += 1 132 | if cnt == 10: 133 | break 134 | pass 135 | -------------------------------------------------------------------------------- /models/non_local.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from: https://github.com/tea1528/Non-Local-NN-Pytorch 3 | We modified non-local neural networks from the above example. 4 | """ 5 | 6 | import torch 7 | from torch import nn 8 | from torch.nn import functional as F 9 | 10 | 11 | class NLBlockND(nn.Module): 12 | def __init__(self, in_channels, inter_channels=None, mode='embedded', 13 | dimension=3, sub_sample=True, bn_layer=True): 14 | """Implementation of Non-Local Block with 4 different pairwise functions, including the subsampling trick 15 | args: 16 | in_channels: original channel size (1024 in the paper) 17 | inter_channels: channel size inside the block; if not specified, reduced to half (512 in the paper) 18 | mode: supports Gaussian, Embedded Gaussian, Dot Product, and Concatenation 19 | dimension: can be 1 (temporal), 2 (spatial), 3 (spatiotemporal) 20 | bn_layer: whether to add batch norm 21 | """ 22 | super(NLBlockND, self).__init__() 23 | 24 | assert dimension in [1, 2, 3] 25 | 26 | if mode not in ['gaussian', 'embedded', 'dot', 'concatenate']: 27 | raise ValueError( 28 | '`mode` must be one of `gaussian`, `embedded`, `dot` or `concatenate`') 29 | 30 | self.mode = mode 31 | self.dimension = dimension 32 | self.sub_sample = sub_sample 33 | 34 | self.in_channels = in_channels 35 | self.inter_channels = inter_channels 36 | 37 | # the channel size is reduced to half inside the block 38 | if self.inter_channels is None: 39 | self.inter_channels = in_channels // 2 40 | if self.inter_channels == 0: 41 | self.inter_channels = 1 42 | 43 | # assign appropriate convolutional, max pool, and batch norm layers for different dimensions 44 | if dimension == 3: 45 | conv_nd = nn.Conv3d 46 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2)) 47 | bn = nn.BatchNorm3d 48 | elif dimension == 2: 49 | conv_nd = nn.Conv2d 50 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2)) 51 | bn = nn.BatchNorm2d 52 | else: 53 | conv_nd = nn.Conv1d 54 | max_pool_layer = nn.MaxPool1d(kernel_size=(2)) 55 | bn = nn.BatchNorm1d 56 | 57 | # function g in the paper which goes through conv. 
with kernel size 1 58 | self.g = conv_nd(in_channels=self.in_channels, 59 | out_channels=self.inter_channels, kernel_size=1) 60 | 61 | # add BatchNorm layer after the last conv layer 62 | if bn_layer: 63 | self.W_z = nn.Sequential( 64 | conv_nd(in_channels=self.inter_channels, 65 | out_channels=self.in_channels, kernel_size=1), 66 | bn(self.in_channels) 67 | ) 68 | # from section 4.1 of the paper, initializing params of BN ensures that the initial state of non-local block is identity mapping 69 | nn.init.constant_(self.W_z[1].weight, 0) 70 | nn.init.constant_(self.W_z[1].bias, 0) 71 | else: 72 | self.W_z = conv_nd(in_channels=self.inter_channels, 73 | out_channels=self.in_channels, kernel_size=1) 74 | 75 | # from section 3.3 of the paper by initializing Wz to 0, this block can be inserted to any existing architecture 76 | nn.init.constant_(self.W_z.weight, 0) 77 | nn.init.constant_(self.W_z.bias, 0) 78 | 79 | # define theta and phi for all operations except gaussian 80 | if self.mode == "embedded" or self.mode == "dot" or self.mode == "concatenate": 81 | self.theta = conv_nd(in_channels=self.in_channels, 82 | out_channels=self.inter_channels, kernel_size=1) 83 | self.phi = conv_nd(in_channels=self.in_channels, 84 | out_channels=self.inter_channels, kernel_size=1) 85 | 86 | if self.mode == "concatenate": 87 | self.W_f = nn.Sequential( 88 | nn.Conv2d(in_channels=self.inter_channels * 89 | 2, out_channels=1, kernel_size=1), 90 | nn.ReLU() 91 | ) 92 | 93 | # from section 3.3 of the paper by reducing computation using subsampling trick, added max pooling layer after phi and g 94 | if sub_sample: 95 | self.g = nn.Sequential(self.g, max_pool_layer) 96 | self.phi = nn.Sequential(self.phi, max_pool_layer) 97 | 98 | def forward(self, x): 99 | """ 100 | args 101 | x: (N, C, T, H, W) for dimension=3; (N, C, H, W) for dimension 2; (N, C, T) for dimension 1 102 | """ 103 | 104 | batch_size = x.size(0) 105 | 106 | # (N, C, THW) 107 | # this reshaping and permutation is from the spacetime_nonlocal function in the original Caffe2 implementation 108 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 109 | g_x = g_x.permute(0, 2, 1) 110 | 111 | if self.mode == "gaussian": 112 | theta_x = x.view(batch_size, self.in_channels, -1) 113 | phi_x = x.view(batch_size, self.in_channels, -1) 114 | theta_x = theta_x.permute(0, 2, 1) 115 | f = torch.matmul(theta_x, phi_x) 116 | 117 | elif self.mode == "embedded" or self.mode == "dot": 118 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 119 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 120 | theta_x = theta_x.permute(0, 2, 1) 121 | f = torch.matmul(theta_x, phi_x) 122 | 123 | elif self.mode == "concatenate": 124 | theta_x = self.theta(x).view( 125 | batch_size, self.inter_channels, -1, 1) 126 | phi_x = self.phi(x).view(batch_size, self.inter_channels, 1, -1) 127 | 128 | h = theta_x.size(2) 129 | w = phi_x.size(3) 130 | theta_x = theta_x.repeat(1, 1, 1, w) 131 | phi_x = phi_x.repeat(1, 1, h, 1) 132 | 133 | concat = torch.cat([theta_x, phi_x], dim=1) 134 | f = self.W_f(concat) 135 | f = f.view(f.size(0), f.size(2), f.size(3)) 136 | 137 | if self.mode == "gaussian" or self.mode == "embedded": 138 | f_div_C = F.softmax(f, dim=-1) # performed on each row 139 | elif self.mode == "dot" or self.mode == "concatenate": 140 | N = f.size(-1) # number of position in x 141 | f_div_C = f / N 142 | 143 | y = torch.matmul(f_div_C, g_x) 144 | 145 | # contiguous here just allocates contiguous chunk of memory 146 | y = y.permute(0, 2, 
1).contiguous() 147 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 148 | 149 | W_y = self.W_z(y) 150 | # residual connection 151 | z = W_y + x 152 | 153 | return z 154 | 155 | 156 | if __name__ == '__main__': 157 | 158 | for bn_layer in [True, False]: 159 | img = torch.zeros(2, 3, 20) 160 | net = NLBlockND(in_channels=3, mode='concatenate', 161 | dimension=1, bn_layer=bn_layer) 162 | out = net(img) 163 | print(out.size()) 164 | 165 | img = torch.zeros(2, 3, 20, 20) 166 | net = NLBlockND(in_channels=3, mode='concatenate', 167 | dimension=2, bn_layer=bn_layer) 168 | out = net(img) 169 | print(out.size()) 170 | 171 | img = torch.randn(2, 3, 8, 20, 20) 172 | net = NLBlockND(in_channels=3, mode='concatenate', 173 | dimension=3, bn_layer=bn_layer) 174 | out = net(img) 175 | print(out.size()) 176 | -------------------------------------------------------------------------------- /options.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import sys 7 | import argparse 8 | import os 9 | 10 | 11 | def convert_arg_line_to_args(arg_line): 12 | line_all = arg_line.split() 13 | for idx, arg in enumerate(line_all): 14 | if len(line_all) > 1 and line_all[1] == 'False': 15 | return 16 | # ignore only resume_weight 17 | if len(line_all) == 1: 18 | return 19 | if arg[0] == '[' or arg == 'False' or arg == 'True': 20 | continue 21 | if arg == '--isTrain': 22 | continue 23 | if not arg.strip(): 24 | continue 25 | 26 | yield arg 27 | 28 | 29 | class CustomOptions(): 30 | 31 | def __init__(self, train): 32 | self.initialized = False 33 | self.isTrain = train 34 | 35 | def initialize(self, parser): 36 | # Experiment specifics 37 | parser.add_argument('--checkpoints_dir', type=str, 38 | default='checkpoint') 39 | parser.add_argument('--name', type=str, default='debug/check_full_model', 40 | help='name of the experiment. It decides where to store samples and models') 41 | 42 | # Data related 43 | parser.add_argument('--public_worldcup_root', type=str, 44 | default='./dataset/soccer_worldcup_2014/soccer_data', help='data root of public worldcup dataset') 45 | parser.add_argument('--custom_worldcup_root', type=str, 46 | default='./dataset/WorldCup_2014_2018', help='data root of custom worldcup dataset') 47 | parser.add_argument('--template_path', type=str, 48 | default='./assets', help='path of worldcup template') 49 | parser.add_argument('--trainset', type=str, 50 | default='train_val', help='path of training set') 51 | parser.add_argument('--testset', type=str, 52 | default='test', help='path of testing set') 53 | 54 | # Training related 55 | parser.add_argument('--gpu_ids', type=str, default='1', 56 | help='gpu ids: e.g. 0 0,1,2, 0,2. 
use -1 for CPU') 57 | parser.add_argument('--model_archi', type=str, 58 | default='KC_STCN', help='KC_STCN or KC or STCN') 59 | parser.add_argument('--loss_mode', type=str, 60 | default='', help='all or dice_bce or dice_wce') 61 | parser.add_argument('--use_non_local', type=int, 62 | default=1, help='if use non local block layer') 63 | parser.add_argument('--train_epochs', type=int, default=300) 64 | parser.add_argument('--train_stage', type=int, default=0, 65 | help='training stage (0-public, 1-custom)') 66 | parser.add_argument('--num_objects', type=int, 67 | default=4, help='the number of objects to be segmented') 68 | parser.add_argument('--resume', action='store_true', 69 | default=False, help='if resume training') 70 | parser.add_argument('--ckpt_path', type=str, 71 | default='', help='path of pretrained or resumed weight') 72 | parser.add_argument('--sfp_finetuned', action='store_true', 73 | default=False, help='if use finetuned results of single frame prediction on testing') 74 | 75 | # Hyperparameters 76 | parser.add_argument('--batch_size', type=int, 77 | default=4, help='input batch size') 78 | parser.add_argument('--train_lr', type=float, 79 | default=1e-4, help='base learning rate') 80 | parser.add_argument('--step_size', type=int, 81 | default=200, help='learning rate scheduling') 82 | parser.add_argument('--weight_decay', type=float, 83 | default=0.0, help='if the need for regularization') 84 | parser.add_argument('--nms_thres', type=float, 85 | default=0.25, help='threshold when calculating nms') # 0.995 86 | parser.add_argument('--pr_thres', type=float, 87 | default=5.0, help='threshold when calculating precision and recall of keypoints') 88 | parser.add_argument('--noise_trans', type=float, 89 | default=5.0, help='noise parameter translation') 90 | parser.add_argument('--noise_rotate', type=float, 91 | default=0.014, help='noise parameter rotation') # 0.0084 92 | 93 | # Inference settings 94 | parser.add_argument('--target_video', type=str, nargs='+', 95 | default=[], help='Inference single video only, default is None') 96 | parser.add_argument('--target_image', type=str, nargs='+', 97 | default=[], help='Inference single image only, default is None') 98 | 99 | self.initialized = True 100 | return parser 101 | 102 | def gather_options(self): 103 | 104 | # initialize parser with basic options 105 | if not self.initialized: 106 | parser = argparse.ArgumentParser( 107 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, fromfile_prefix_chars='@') 108 | parser.convert_arg_line_to_args = convert_arg_line_to_args 109 | parser = self.initialize(parser) 110 | 111 | if sys.argv.__len__() == 2: 112 | arg_filename_with_prefix = '@' + sys.argv[1] 113 | opt, unknown = parser.parse_known_args([arg_filename_with_prefix]) 114 | opt = parser.parse_args([arg_filename_with_prefix]) 115 | else: 116 | # get the basic options 117 | opt, unknown = parser.parse_known_args() 118 | opt = parser.parse_args() 119 | 120 | self.parser = parser 121 | return opt 122 | 123 | def print_options(self, opt): 124 | message = '' 125 | message += '----------------- Options ---------------\n' 126 | for k, v in sorted(vars(opt).items()): 127 | comment = '' 128 | default = self.parser.get_default(k) 129 | if v != default: 130 | comment = '\t[default: %s]' % str(default) 131 | message += '{:<25}: {:<30}{}\n'.format(str(k), str(v), comment) 132 | message += '----------------- End -------------------' 133 | print(message) 134 | 135 | def option_file_path(self, opt, makedir=False): 136 | expr_dir = 
os.path.join(opt.checkpoints_dir, opt.name) 137 | if makedir: 138 | os.makedirs(expr_dir, exist_ok=True) 139 | file_name = os.path.join(expr_dir, 'opt') 140 | 141 | return file_name 142 | 143 | def save_options(self, opt): 144 | file_name = self.option_file_path(opt, makedir=True) 145 | with open(file_name + '.txt', 'wt') as opt_file: 146 | for k, v in sorted(vars(opt).items()): 147 | comment = '' 148 | default = self.parser.get_default(k) 149 | if v != default: 150 | comment = '\t[default:%s]' % str(default) 151 | opt_file.write( 152 | '--{:<25} {:<30} {}\n'.format(str(k), str(v), comment)) 153 | 154 | def parse(self, save=False): 155 | 156 | opt = self.gather_options() 157 | opt.isTrain = self.isTrain # train or test 158 | 159 | self.print_options(opt) 160 | if opt.isTrain: 161 | self.save_options(opt) 162 | 163 | # os._exit(0) 164 | self.opt = opt 165 | return self.opt 166 | -------------------------------------------------------------------------------- /robust/models/model.py: -------------------------------------------------------------------------------- 1 | ''' 2 | network for Nie et al. (A robust and efficient framework for sports-field registration) 3 | ''' 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torchvision.models 8 | import math 9 | 10 | from models.non_local import NLBlockND 11 | 12 | 13 | def weights_init(m): 14 | # Initialize filters with Gaussian random weights 15 | if isinstance(m, nn.Conv2d): 16 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 17 | m.weight.data.normal_(0, math.sqrt(2. / n)) 18 | if m.bias is not None: 19 | m.bias.data.zero_() 20 | elif isinstance(m, nn.BatchNorm2d): 21 | m.weight.data.fill_(1) 22 | m.bias.data.zero_() 23 | 24 | 25 | def weights_init2(m): 26 | 27 | if isinstance(m, nn.Conv2d): 28 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 29 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 30 | if m.bias is not None: 31 | m.bias.data.zero_() 32 | 33 | 34 | class BasicBlock(nn.Module): 35 | 36 | expansion = 1 37 | 38 | def __init__(self, in_planes, out_planes, stride=1, dilation=1): 39 | super(BasicBlock, self).__init__() 40 | 41 | self.conv1 = nn.Conv2d( 42 | in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, dilation=dilation, bias=False) 43 | self.bn1 = nn.BatchNorm2d(out_planes) 44 | self.relu = nn.ReLU(inplace=True) 45 | self.conv2 = nn.Conv2d( 46 | out_planes, out_planes, kernel_size=3, stride=1, padding=dilation, dilation=dilation, bias=False) 47 | self.bn2 = nn.BatchNorm2d(out_planes) 48 | self.downsample = None 49 | if stride > 1: 50 | self.downsample = nn.Sequential( 51 | nn.Conv2d(in_planes, out_planes * self.expansion, kernel_size=1, 52 | stride=stride, bias=False), 53 | nn.BatchNorm2d(out_planes * self.expansion), 54 | ) 55 | 56 | def forward(self, x): 57 | residual = x 58 | 59 | out = self.conv1(x) 60 | out = self.bn1(out) 61 | out = self.relu(out) 62 | 63 | out = self.conv2(out) 64 | out = self.bn2(out) 65 | 66 | if self.downsample is not None: 67 | residual = self.downsample(x) 68 | 69 | out += residual 70 | out = self.relu(out) 71 | 72 | return out 73 | 74 | 75 | class EncDec(nn.Module): 76 | 77 | def up_conv(self, in_channels, out_channels, size): 78 | return nn.Sequential( 79 | nn.Conv2d(in_channels, in_channels // 2, 80 | kernel_size=3, padding=1, bias=False), 81 | nn.BatchNorm2d(in_channels // 2), 82 | nn.ReLU(inplace=True), 83 | nn.Conv2d(in_channels // 2, in_channels // 84 | 2, kernel_size=3, padding=1, bias=False), 85 | nn.BatchNorm2d(in_channels//2), 86 | nn.ReLU(inplace=True), 87 | 88 | nn.Upsample(size=size, mode='bilinear', align_corners=True), 89 | 90 | nn.Conv2d(in_channels // 2, out_channels, 91 | kernel_size=3, padding=1, bias=False), 92 | nn.BatchNorm2d(out_channels), 93 | nn.ReLU(inplace=True), 94 | ) 95 | 96 | def __init__(self, layers, n_classes, non_local, pretrained=True): 97 | 98 | if layers not in [18, 34, 50, 101, 152]: 99 | raise RuntimeError( 100 | 'Only 18, 34, 50, 101, and 152 layer model are defined for ResNet. 
Got {}'.format(layers)) 101 | 102 | # Define the flags up front so they exist even when non_local is False 103 | sub_sample = bn_layer = bool(non_local) 104 | 105 | super(EncDec, self).__init__() 106 | pretrained_model = torchvision.models.__dict__[ 107 | 'resnet{}'.format(layers)](pretrained=pretrained) 108 | 109 | self.in_planes = 256 110 | self.conv1 = pretrained_model._modules['conv1'] 111 | self.bn1 = pretrained_model._modules['bn1'] 112 | self.relu = pretrained_model._modules['relu'] 113 | self.maxpool = pretrained_model._modules['maxpool'] 114 | 115 | self.layer1 = pretrained_model._modules['layer1'] 116 | self.layer2 = pretrained_model._modules['layer2'] 117 | self.layer3 = pretrained_model._modules['layer3'] 118 | # self.layer4 = pretrained_model._modules['layer4'] 119 | 120 | # TODO add a few dilated convolution and non-local block layers 121 | self.layer4 = self._make_layer( 122 | BasicBlock, 512, 2, stride=2, dilation=2, non_local=non_local, sub_sample=sub_sample, bn_layer=bn_layer) 123 | self.layer4[0].apply(weights_init) 124 | self.layer4[2].apply(weights_init) 125 | self.layer4[1].apply(weights_init2) 126 | self.layer4[3].apply(weights_init2) 127 | 128 | # Clear memory 129 | del pretrained_model 130 | 131 | self.midconv = nn.Sequential( 132 | nn.Conv2d(512, 1024, kernel_size=3, 133 | stride=2, padding=1, bias=False), 134 | nn.BatchNorm2d(1024), 135 | nn.ReLU(inplace=True), 136 | nn.Conv2d(1024, 1024, kernel_size=3, 137 | stride=1, padding=1, bias=False), 138 | nn.BatchNorm2d(1024), 139 | nn.ReLU(inplace=True), 140 | 141 | nn.Upsample(size=(23, 40), mode='bilinear', align_corners=True), 142 | 143 | nn.Conv2d(1024, 512, kernel_size=3, 144 | stride=1, padding=1, bias=False), 145 | nn.BatchNorm2d(512), 146 | nn.ReLU(inplace=True), 147 | ) 148 | 149 | self.up3 = self.up_conv(1024, 256, (45, 80)) 150 | self.up2 = self.up_conv(512, 128, (90, 160)) 151 | self.up1 = self.up_conv(256, 64, (180, 320)) 152 | 153 | self.conv_last = nn.Conv2d(128, n_classes, kernel_size=1) 154 | 155 | # Weights initialize 156 | self.midconv.apply(weights_init) 157 | self.up3.apply(weights_init) 158 | self.up2.apply(weights_init) 159 | self.up1.apply(weights_init) 160 | self.conv_last.apply(weights_init) 161 | 162 | def _make_layer(self, block, out_planes, num_blocks, stride, dilation=1, non_local=False, sub_sample=False, bn_layer=False): 163 | strides = [stride] + [1] * (num_blocks - 1) 164 | layers = [] 165 | last_idx = len(strides) 166 | 167 | if non_local: 168 | last_idx = len(strides) - 1 169 | 170 | for i in range(last_idx): 171 | layers.append( 172 | block(self.in_planes, out_planes, strides[i], dilation)) 173 | self.in_planes = out_planes * block.expansion 174 | 175 | if non_local: 176 | layers.append(NLBlockND(in_channels=out_planes, mode='embedded', 177 | dimension=2, sub_sample=sub_sample, bn_layer=bn_layer)) 178 | layers.append( 179 | block(self.in_planes, out_planes, strides[-1], dilation)) 180 | layers.append(NLBlockND(in_channels=out_planes, mode='embedded', 181 | dimension=2, sub_sample=sub_sample, bn_layer=bn_layer)) 182 | 183 | return nn.Sequential(*layers) 184 | 185 | def forward(self, x): 186 | # === resnet === 187 | [bs, c, h, w] = x.size() 188 | 189 | x_original = self.conv1(x) 190 | x_original = self.bn1(x_original) 191 | x_original = self.relu(x_original) 192 | 193 | conv0 = self.maxpool(x_original) 194 | conv1 = self.layer1(conv0) # size=(N, 64, x.H/4, x.W/4) 195 | conv2 = self.layer2(conv1) # size=(N, 128, x.H/8, x.W/8) 196 | conv3 = self.layer3(conv2) # size=(N, 256, x.H/16, x.W/16) 197 | conv4 = self.layer4(conv3) # size=(N, 512, x.H/32, x.W/32)
198 | 199 | # === decoder === 200 | x = self.midconv(conv4) # 512 201 | x = torch.cat([x, conv4], dim=1) 202 | x = self.up3(x) 203 | x = torch.cat([x, conv3], dim=1) 204 | x = self.up2(x) 205 | x = torch.cat([x, conv2], dim=1) 206 | x = self.up1(x) 207 | x = torch.cat([x, conv1], dim=1) 208 | 209 | out = self.conv_last(x) 210 | 211 | return out 212 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Sports Field Registration via Keypoints-aware Label Condition 2 | 3 | This is an official implementation of Sports Field Registration via Keypoints-aware Label Condition. 4 | 5 | Yen-Jui Chu, Jheng-Wei Su, Kai-Wen Hsiao, Chi-Yu Lien, Shu-Ho Fan, 6 | Min-Chun Hu, Ruen-Rone Lee, Chih-Yuan Yao, Hung-Kuo Chu 7 | 8 | 8th International Workshop on Computer Vision in Sports (CVsports) at CVPR 2022 9 | 10 | ### [[Paper](https://cgv.cs.nthu.edu.tw/KpSFR_data/KpSFR_paper.pdf)] [[Webpage](https://ericsujw.github.io/KpSFR/)] 11 | 12 |
13 | 14 |
15 | 16 | We propose a novel deep learning framework for sports field registration. The typical algorithmic flow for sports field registration involves extracting field-specific features (e.g., corners, lines, etc.) from the field image and estimating the homography matrix between a 2D field template and the field image using the extracted features. 17 | Unlike previous methods that strive to extract sparse field features from field images with uniform appearance, we tackle the problem differently. 18 | First, we use a grid of uniformly distributed keypoints as our field-specific features to increase the likelihood of having sufficient field features under various camera poses. Then we formulate the keypoint detection problem as instance segmentation with dynamic filter learning. 19 | In our model, the convolution filters are generated dynamically, conditioned on the field image and the associated keypoint identity, thus improving the robustness of the prediction results. 20 | To extensively evaluate our method, we introduce a new soccer dataset, called TS-WorldCup, with detailed field markings on 3812 time-sequence images from 43 videos of the 2014 and 2018 Soccer World Cups. The experimental results demonstrate that our method outperforms the state of the art on the TS-WorldCup dataset in both quantitative and qualitative evaluations. 21 | 22 | 23 | ## Requirements 24 | - CUDA 11 25 | - Python >= 3.8 26 | - PyTorch == 1.9.0 27 | - torchvision == 0.9.0 28 | - NumPy 29 | - OpenCV-Python == 4.5.1.48 30 | - Matplotlib 31 | - Pillow/scikit-image 32 | - Shapely == 1.7.1 33 | - tqdm 34 | 35 | ## Getting Started 36 | 1. Clone this repo: 37 | ```sh 38 | git clone https://github.com/ericsujw/KpSFR 39 | cd KpSFR/ 40 | ``` 41 | 2. Install [miniconda](https://docs.conda.io/en/latest/miniconda.html). 42 | 3. Install all the dependencies: 43 | ```sh 44 | conda env create -f environment.yml 45 | ``` 46 | 4. Switch to the conda environment: 47 | ```sh 48 | conda activate kpsfr 49 | ``` 50 | 51 | ## Pretrained Models 52 | 1. Download the [pretrained weights on the WorldCup dataset](https://cgv.cs.nthu.edu.tw/KpSFR_data/model/kpsfr.pth). 53 | 2. Place the pretrained weights in [checkpoint](checkpoint). 54 | 3. Download the public [WorldCup](https://nhoma.github.io/) dataset. 55 | 4. Download the [TS-WorldCup](https://cgv.cs.nthu.edu.tw/KpSFR_data/TS-WorldCup.zip) dataset. 56 | 5. Place the WorldCup dataset in [dataset/soccer_worldcup_2014](dataset/soccer_worldcup_2014) and TS-WorldCup in [dataset/WorldCup_2014_2018](dataset/WorldCup_2014_2018). 57 | 58 | ## Inference 59 | 60 | Please run the [robust](robust) model first to generate the preprocessing results before running the inference command below. 61 | 62 | ### Inference command 63 | ```sh 64 | python inference.py 65 | ``` 66 | Set up the param text file as follows: 67 | - `inference.txt`: download the [pretrained weights on the WorldCup dataset](https://cgv.cs.nthu.edu.tw/KpSFR_data/model/kpsfr.pth) or the [pretrained weights on the TS-WorldCup dataset](https://cgv.cs.nthu.edu.tw/KpSFR_data/model/kpsfr_finetuned.pth) first and place them in [checkpoint](checkpoint), then configure the entries below to run inference on your own data. 68 | - `target_video`: specify the target video(s), separated by spaces, and set `--train_stage` to **1**, `--sfp_finetuned` to **True**, and `--ckpt_path` to the corresponding finetuned model weights. 69 | - `target_image`: specify the target image(s), separated by spaces, and set `--train_stage` to **0**, `--sfp_finetuned` to **False**, and `--ckpt_path` to the corresponding model weights (a minimal sketch follows this list).
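For illustration, a minimal sketch of the video-mode entries in `inference.txt` might look like the following. The video name is a placeholder, not a real entry; take actual names from [test.txt](dataset/WorldCup_2014_2018/test.txt), and leave `--target_image` empty in video mode:

```
--ckpt_path checkpoint/kpsfr_finetuned.pth
--sfp_finetuned True
--train_stage 1
--target_video <video name from test.txt>
--target_image
```

Image mode mirrors this: set `--train_stage 0`, `--sfp_finetuned False`, `--ckpt_path checkpoint/kpsfr.pth`, and list the image indices after `--target_image` instead.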
70 | 71 | Note: 72 | - In the current implementation, inference is only supported on the WorldCup or TS-WorldCup test sets. 73 | - When specifying `target_image`, input the index of the corresponding image. 74 | - For the input format of `target_video`, refer to the [text file](dataset/WorldCup_2014_2018/test.txt). 75 | - If no target image or video is specified, all testing data are processed and output. 76 | 77 | ## Evaluation 78 | 79 | Please run the [robust](robust) model first to generate the preprocessing results before running the evaluation command below. 80 | 81 | ### Evaluation command 82 | ```sh 83 | python eval_testset.py 84 | ``` 85 | Set up the param text file as follows: 86 | - `exp_ours.txt`: download the [pretrained weights on the WorldCup dataset](https://cgv.cs.nthu.edu.tw/KpSFR_data/model/kpsfr.pth) first and place them in [checkpoint](checkpoint). Set `--train_stage` to **0** for testing on the WorldCup test set, or to **1** for the TS-WorldCup test set, and set `--sfp_finetuned` to **False**. 87 | - `exp_ours_finetuned.txt`: download the [pretrained weights on the TS-WorldCup dataset](https://cgv.cs.nthu.edu.tw/KpSFR_data/model/kpsfr_finetuned.pth) first and place them in [checkpoint](checkpoint). Set `--train_stage` to **1** and `--sfp_finetuned` to **True** to test the finetuned results on the TS-WorldCup test set. 88 | - `exp_dice_bce.txt`: download the [pretrained weights](https://cgv.cs.nthu.edu.tw/KpSFR_data/model/kpsfr_dicebceloss.pth) first and place them in [checkpoint](checkpoint). Set `--train_stage` to **1** for the loss-function ablation study (binary dice loss with binary cross-entropy loss) on the TS-WorldCup test set. 89 | - `exp_dice_wce.txt`: download the [pretrained weights](https://cgv.cs.nthu.edu.tw/KpSFR_data/model/kpsfr_dicewceloss.pth) first and place them in [checkpoint](checkpoint). Set `--train_stage` to **1** for the loss-function ablation study (binary dice loss with weighted cross-entropy loss) on the TS-WorldCup test set. 90 | 91 | Heatmap results and the corresponding homography matrices are saved under **checkpoint/<experiment name>**, where the experiment name is set via `--name` in the param text file. 92 | 93 | ## Train model 94 | ### Train command 95 | ```sh 96 | python train_nn.py 97 | ``` 98 | Set up the param text file as follows: 99 | - `opt_ours.txt`: download the [pretrained weights on the WorldCup dataset](https://cgv.cs.nthu.edu.tw/KpSFR_data/model/kpsfr.pth) first and place them in [checkpoint](checkpoint). Set `--train_stage` to **0**, `--trainset` to **train_val**, and `--loss_mode` to **all** for training on the WorldCup train set. For the ablation studies, download the [pretrained weights on the TS-WorldCup dataset](https://cgv.cs.nthu.edu.tw/KpSFR_data/model/kpsfr_finetuned.pth) first, place them in [checkpoint](checkpoint), and set `--loss_mode` to **dice_bce** or **dice_wce**. 100 | - `opt_ours_finetuned.txt`: download the [pretrained weights on the TS-WorldCup dataset](https://cgv.cs.nthu.edu.tw/KpSFR_data/model/kpsfr_finetuned.pth) first and place them in [checkpoint](checkpoint). Set `--train_stage` to **1**, `--trainset` to **train**, and `--loss_mode` to **all** for finetuning on the TS-WorldCup train set. 101 | 102 | Visualization results and weights are saved under **checkpoint/<experiment name>**, where the experiment name is set via `--name` in the param text file. 103 | 104 | Note: please make sure the following arguments are set correctly before every training run. 105 | - `--gpu_ids` 106 | - `--name` 107 | - `--train_stage` and `--trainset` 108 | - `--ckpt_path` 109 | - `--loss_mode` 110 | - `--train_epochs` and `--step_size` 111 | 112 | For details, refer to [options.py](options.py); a sample configuration sketch follows.
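As an illustration only, a finetuning run on the TS-WorldCup train set would edit entries of this form in `opt_ours_finetuned.txt`. Each option is a single `--key value` line; the experiment name `example/my_finetune` is made up, and the epoch/step values are hypothetical, not recommendations:

```
--gpu_ids 0
--name example/my_finetune
--train_stage 1
--trainset train
--ckpt_path checkpoint/kpsfr_finetuned.pth
--loss_mode all
--train_epochs 1500
--step_size 200
```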
113 | 114 | ## License 115 | This work is licensed under MIT License. See [LICENSE](LICENSE) for details. 116 | 117 | ## Citation 118 | If you find our code/models useful, please consider citing our paper: 119 | ``` 120 | @InProceedings{Chu_2022_CVPR, 121 | author = {Chu, Yen-Jui and Su, Jheng-Wei and Hsiao, Kai-Wen and Lien, Chi-Yu and Fan, Shu-Ho and Hu, Min-Chun and Lee, Ruen-Rone and Yao, Chih-Yuan and Chu, Hung-Kuo}, 122 | title = {Sports Field Registration via Keypoints-Aware Label Condition}, 123 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) Workshops}, 124 | year = {2022} 125 | } 126 | ``` 127 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: base 2 | channels: 3 | - local 4 | - conda-forge 5 | - anaconda 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - _openmp_mutex=4.5=1_gnu 10 | - attrs=20.3.0=pyhd3eb1b0_0 11 | - backcall=0.2.0=pyhd3eb1b0_0 12 | - beautifulsoup4=4.9.3=pyha847dfd_0 13 | - blas=1.0=openblas 14 | - blosc=1.20.1=hd408876_0 15 | - brotlipy=0.7.0=py38h27cfd23_1003 16 | - bzip2=1.0.8=h7b6447c_0 17 | - ca-certificates=2020.10.14=0 18 | - catalogue=1.0.0=py38_1 19 | - certifi=2020.6.20=py38_0 20 | - cffi=1.14.3=py38h261ae71_2 21 | - chardet=3.0.4=py38h06a4308_1003 22 | - cloudpickle=1.6.0=py_0 23 | - cmake=3.19.6=h973ab73_0 24 | - conda=4.11.0=py38h06a4308_0 25 | - conda-build=3.21.4=py38h06a4308_0 26 | - conda-package-handling=1.7.2=py38h03888b9_0 27 | - cryptography=3.2.1=py38h3c74f83_1 28 | - cymem=2.0.5=py38h2531618_0 29 | - cython-blis=0.7.4=py38h27cfd23_1 30 | - cytoolz=0.11.0=py38h7b6447c_0 31 | - dask-core=2021.7.2=pyhd3eb1b0_0 32 | - dbus=1.13.18=hb2f20db_0 33 | - decorator=5.0.6=pyhd3eb1b0_0 34 | - expat=2.3.0=h2531618_2 35 | - filelock=3.0.12=pyhd3eb1b0_1 36 | - fontconfig=2.13.1=hba837de_1005 37 | - freetype=2.10.4=h5ab3b9f_0 38 | - fsspec=2021.7.0=pyhd3eb1b0_0 39 | - gettext=0.19.8.1=h0b5b191_1005 40 | - glib=2.68.3=h9c3ff4c_0 41 | - glib-tools=2.68.3=h9c3ff4c_0 42 | - glob2=0.7=pyhd3eb1b0_0 43 | - gst-plugins-base=1.14.0=hbbd80ab_1 44 | - gstreamer=1.14.0=h28cd5cc_2 45 | - hdf5=1.10.4=hb1b8bf9_0 46 | - icu=58.2=he6710b0_3 47 | - idna=2.10=py_0 48 | - imageio=2.9.0=pyhd3eb1b0_0 49 | - importlib-metadata=3.10.0=py38h06a4308_0 50 | - importlib_metadata=3.10.0=hd3eb1b0_0 51 | - intel-openmp=2021.2.0=h06a4308_610 52 | - ipython=7.22.0=py38hb070fc8_0 53 | - ipython_genutils=0.2.0=pyhd3eb1b0_1 54 | - jedi=0.17.0=py38_0 55 | - jinja2=2.11.3=pyhd3eb1b0_0 56 | - jpeg=9b=h024ee3a_2 57 | - jsonschema=3.0.2=py38_0 58 | - kiwisolver=1.3.1=py38h2531618_0 59 | - krb5=1.18.2=h173b8e3_0 60 | - lcms2=2.12=h3be6417_0 61 | - ld_impl_linux-64=2.33.1=h53a641e_7 62 | - libarchive=3.4.2=h62408e4_0 63 | - libcurl=7.71.1=h20c2e04_1 64 | - libedit=3.1.20191231=h14c3975_1 65 | - libffi=3.3=he6710b0_2 66 | - libgcc-ng=9.3.0=h5101ec6_17 67 | - libgfortran-ng=7.3.0=hdf63c60_0 68 | - libglib=2.68.3=h3e27bee_0 69 | - libgomp=9.3.0=h5101ec6_17 70 | - libiconv=1.16=h516909a_0 71 | - liblief=0.10.1=he6710b0_0 72 | - libllvm10=10.0.1=hbcb73fb_5 73 | - libopenblas=0.3.10=h5a2b251_0 74 | - libpng=1.6.37=hbc83047_0 75 | - libssh2=1.9.0=h1ba5d50_1 76 | - libstdcxx-ng=11.2.0=he4da1e4_11 77 | - libtiff=4.2.0=h85742a9_0 78 | - libuuid=2.32.1=h7f98852_1000 79 | - libuv=1.40.0=h7b6447c_0 80 | - libwebp-base=1.2.0=h27cfd23_0 81 | - libxcb=1.13=h7f98852_1003 82 | - libxml2=2.9.10=hb55368b_3 
83 | - llvmlite=0.35.0=py38h612dafd_4 84 | - locket=0.2.1=py38h06a4308_1 85 | - lz4-c=1.9.3=h2531618_0 86 | - lzo=2.10=h7b6447c_2 87 | - magma-cuda110=2.5.2=5 88 | - markupsafe=1.1.1=py38h7b6447c_0 89 | - matplotlib=3.4.1=py38h578d9bd_0 90 | - matplotlib-base=3.4.1=py38hcc49a3a_0 91 | - mkl=2019.4=243 92 | - mkl-include=2019.4=243 93 | - mock=4.0.3=pyhd3eb1b0_0 94 | - murmurhash=1.0.5=py38h2531618_0 95 | - ncurses=6.2=he6710b0_1 96 | - ninja=1.10.2=hff7bd54_1 97 | - nomkl=3.0=0 98 | - numba=0.52.0=py38ha9443f7_0 99 | - numexpr=2.7.1=py38h7ea95a0_0 100 | - numpy=1.20.1=py38h5a90a98_0 101 | - numpy-base=1.20.1=py38h34387ca_0 102 | - olefile=0.46=py_0 103 | - openblas=0.3.10=0 104 | - openblas-devel=0.3.10=0 105 | - openjpeg=2.3.0=h05c96fa_1 106 | - openssl=1.1.1m=h7f8727e_0 107 | - parso=0.8.2=pyhd3eb1b0_0 108 | - partd=1.2.0=pyhd3eb1b0_0 109 | - patchelf=0.12=h2531618_1 110 | - pcre=8.45=h9c3ff4c_0 111 | - pexpect=4.8.0=pyhd3eb1b0_3 112 | - pickleshare=0.7.5=pyhd3eb1b0_1003 113 | - pillow=8.3.1=py38h2c7a002_0 114 | - pip=20.2.4=py38h06a4308_0 115 | - pkginfo=1.7.0=py38h06a4308_0 116 | - plac=1.1.0=py38_1 117 | - preshed=3.0.2=py38he6710b0_1 118 | - prompt-toolkit=3.0.17=pyh06a4308_0 119 | - psutil=5.8.0=py38h27cfd23_1 120 | - pthread-stubs=0.4=h36c2ea0_1001 121 | - ptyprocess=0.7.0=pyhd3eb1b0_2 122 | - py-lief=0.10.1=py38h403a769_0 123 | - pycosat=0.6.3=py38h7b6447c_1 124 | - pycparser=2.20=py_2 125 | - pygments=2.8.1=pyhd3eb1b0_0 126 | - pyopenssl=19.1.0=pyhd3eb1b0_1 127 | - pyparsing=2.4.7=pyhd3eb1b0_0 128 | - pyqt=5.9.2=py38h05f1152_4 129 | - pyrsistent=0.17.3=py38h7b6447c_0 130 | - pysocks=1.7.1=py38h06a4308_0 131 | - pytables=3.6.1=py38h9fd0a39_0 132 | - python=3.8.8=hdb3f193_4 133 | - python-libarchive-c=2.9=pyhd3eb1b0_1 134 | - python_abi=3.8=2_cp38 135 | - pytz=2021.1=pyhd3eb1b0_0 136 | - pywavelets=1.1.1=py38h7b6447c_2 137 | - pyyaml=5.4.1=py38h27cfd23_1 138 | - qt=5.9.7=h5867ecd_1 139 | - readline=8.0=h7b6447c_0 140 | - requests=2.24.0=py_0 141 | - rhash=1.4.1=h3c74f83_1 142 | - ripgrep=12.1.1=0 143 | - ruamel_yaml=0.15.87=py38h7b6447c_1 144 | - scipy=1.6.2=py38hf56f3a7_1 145 | - setuptools=50.3.1=py38h06a4308_1 146 | - sip=4.19.13=py38he6710b0_0 147 | - six=1.15.0=py38h06a4308_0 148 | - soupsieve=2.2.1=pyhd3eb1b0_0 149 | - spacy=2.3.5=py38hff7bd54_0 150 | - sqlite=3.33.0=h62c20be_0 151 | - srsly=1.0.5=py38h2531618_0 152 | - tbb=2020.3=hfd86e86_0 153 | - thinc=7.4.5=py38h9a67853_0 154 | - tifffile=2020.10.1=py38hdd07704_2 155 | - tk=8.6.10=hbc83047_0 156 | - toolz=0.11.1=pyhd3eb1b0_0 157 | - tornado=6.1=py38h27cfd23_0 158 | - traitlets=5.0.5=pyhd3eb1b0_0 159 | - urllib3=1.25.11=py_0 160 | - wasabi=0.8.2=pyhd3eb1b0_0 161 | - wcwidth=0.2.5=py_0 162 | - wheel=0.35.1=pyhd3eb1b0_0 163 | - xorg-libxau=1.0.9=h7f98852_0 164 | - xorg-libxdmcp=1.1.3=h7f98852_0 165 | - xz=5.2.5=h7b6447c_0 166 | - yaml=0.2.5=h7b6447c_0 167 | - zipp=3.4.1=pyhd3eb1b0_0 168 | - zlib=1.2.11=h7b6447c_3 169 | - zstd=1.4.5=h9ceee32_0 170 | - pip: 171 | - absl-py==0.12.0 172 | - alabaster==0.7.12 173 | - apex==0.1 174 | - appdirs==1.4.4 175 | - argon2-cffi==20.1.0 176 | - ascii-graph==1.5.1 177 | - async-generator==1.10 178 | - audioread==2.1.9 179 | - autopep8==1.5.7 180 | - babel==2.9.1 181 | - bleach==3.3.0 182 | - boto3==1.17.60 183 | - botocore==1.20.60 184 | - cachetools==4.2.2 185 | - click==7.1.2 186 | - codecov==2.1.11 187 | - coverage==5.5 188 | - cxxfilt==0.2.2 189 | - cycler==0.10.0 190 | - cython==0.28.4 191 | - dataproperty==0.50.1 192 | - defusedxml==0.7.1 193 | - dllogger==0.1.0 194 | - docutils==0.16 
195 | - entrypoints==0.3 196 | - flake8==3.7.9 197 | - flask==1.1.2 198 | - future==0.18.2 199 | - google-auth==1.35.0 200 | - google-auth-oauthlib==0.4.5 201 | - graphsurgeon==0.4.5 202 | - grpcio==1.37.0 203 | - h5py==3.2.1 204 | - html2text==2020.1.16 205 | - hypothesis==4.50.8 206 | - imagesize==1.2.0 207 | - inflect==5.3.0 208 | - iniconfig==1.1.1 209 | - ipdb==0.13.7 210 | - ipykernel==5.5.3 211 | - itsdangerous==1.1.0 212 | - jmespath==0.10.0 213 | - joblib==1.0.1 214 | - json5==0.9.5 215 | - jupyter-client==6.1.12 216 | - jupyter-core==4.7.1 217 | - jupyter-tensorboard==0.2.0 218 | - jupyterlab==2.3.1 219 | - jupyterlab-pygments==0.1.2 220 | - jupyterlab-server==1.2.0 221 | - jupytext==1.11.1 222 | - librosa==0.8.0 223 | - lmdb==1.2.1 224 | - mako==1.1.4 225 | - markdown==3.3.4 226 | - markdown-it-py==0.6.2 227 | - maskrcnn-benchmark==0.1 228 | - mbstrdecoder==1.0.1 229 | - mccabe==0.6.1 230 | - mdit-py-plugins==0.2.6 231 | - mistune==0.8.4 232 | - mlperf-compliance==0.0.10 233 | - msgfy==0.1.0 234 | - nbclient==0.5.3 235 | - nbconvert==6.0.7 236 | - nbformat==5.1.3 237 | - nest-asyncio==1.5.1 238 | - networkx==2.0 239 | - nltk==3.6.2 240 | - notebook==6.2.0 241 | - nvidia-dali-cuda110==1.0.0 242 | - nvidia-dlprof-pytorch-nvtx==1.1.0 243 | - nvidia-pyprof==3.10.0 244 | - oauthlib==3.1.1 245 | - onnx==1.8.0 246 | - onnxruntime==1.7.0 247 | - packaging==20.9 248 | - pandas==1.1.4 249 | - pandocfilters==1.4.3 250 | - pathvalidate==2.4.1 251 | - pillow-simd==7.0.0.post3 252 | - pluggy==0.13.1 253 | - polygraphy==0.29.1 254 | - pooch==1.3.0 255 | - prettytable==2.1.0 256 | - progressbar==2.5 257 | - prometheus-client==0.10.1 258 | - protobuf==3.15.8 259 | - py==1.10.0 260 | - pyasn1==0.4.8 261 | - pyasn1-modules==0.2.8 262 | - pybind11==2.6.2 263 | - pycocotools==2.0+nv0.5.1 264 | - pycodestyle==2.7.0 265 | - pycuda==2020.1 266 | - pydot==1.4.2 267 | - pyflakes==2.1.1 268 | - pynvml==8.0.4 269 | - pytablewriter==0.47.0 270 | - pytest==6.2.3 271 | - pytest-cov==2.11.1 272 | - pytest-pythonpath==0.7.3 273 | - python-dateutil==2.8.1 274 | - python-hostlist==1.21 275 | - python-nvd3==0.15.0 276 | - python-slugify==4.0.1 277 | - pytools==2021.2.6 278 | - pytorch-quantization==2.1.0 279 | - pytorch-transformers==1.1.0 280 | - pyzmq==22.0.3 281 | - regex==2021.4.4 282 | - requests-oauthlib==1.3.0 283 | - resampy==0.2.2 284 | - revtok==0.0.3 285 | - rsa==4.7.2 286 | - s3transfer==0.4.2 287 | - sacrebleu==1.2.10 288 | - sacremoses==0.0.35 289 | - scikit-image==0.15.0 290 | - scikit-learn==0.24.2 291 | - send2trash==1.5.0 292 | - sentencepiece==0.1.95 293 | - shapely==1.7.1 294 | - snowballstemmer==2.1.0 295 | - soundfile==0.10.3.post1 296 | - sox==1.4.1 297 | - sphinx==3.5.4 298 | - sphinx-glpi-theme==0.3 299 | - sphinx-rtd-theme==0.5.2 300 | - sphinxcontrib-applehelp==1.0.2 301 | - sphinxcontrib-devhelp==1.0.2 302 | - sphinxcontrib-htmlhelp==1.0.3 303 | - sphinxcontrib-jsmath==1.0.1 304 | - sphinxcontrib-qthelp==1.0.3 305 | - sphinxcontrib-serializinghtml==1.1.4 306 | - subword-nmt==0.3.3 307 | - tabledata==1.1.3 308 | - tabulate==0.8.9 309 | - tensorboard==2.6.0 310 | - tensorboard-data-server==0.6.1 311 | - tensorboard-plugin-wit==1.8.0 312 | - tensorrt==7.2.3.4 313 | - terminado==0.9.4 314 | - testpath==0.4.4 315 | - text-unidecode==1.3 316 | - threadpoolctl==2.1.0 317 | - toml==0.10.2 318 | - torch==1.9.0a0+2ecb2c7 319 | - torchtext==0.10.0a0 320 | - torchvision==0.9.0a0 321 | - tqdm==4.53.0 322 | - typepy==1.1.5 323 | - typing-extensions==3.7.4.3 324 | - uff==0.6.9 325 | - unidecode==1.2.0 
326 | - webencodings==0.5.1 327 | - werkzeug==1.0.1 328 | - wrapt==1.10.11 329 | - yacs==0.1.8 330 | -------------------------------------------------------------------------------- /worldcup_train_loader.py: -------------------------------------------------------------------------------- 1 | ''' 2 | random pick 4 from all keypoints 3 | ''' 4 | 5 | import random 6 | import glob 7 | import os 8 | import os.path as osp 9 | import numpy as np 10 | from PIL import Image 11 | import torch 12 | from torch.utils import data 13 | from torchvision import transforms 14 | import matplotlib.pyplot as plt 15 | import skimage.segmentation as ss 16 | from typing import Optional 17 | import utils 18 | 19 | 20 | class StaticTransformDataset(data.Dataset): 21 | 22 | def __init__(self, root, data_type, mode, num_objects, noise_trans: Optional[float] = None, noise_rotate: Optional[float] = None): 23 | 24 | self.frame_h = 720 25 | self.frame_w = 1280 26 | self.root = root 27 | self.data_type = data_type 28 | self.mode = mode 29 | self.num_objects = num_objects 30 | self.noise_trans = noise_trans 31 | self.noise_rotate = noise_rotate 32 | 33 | frame_list = [osp.basename(name) for name in glob.glob( 34 | osp.join(self.root, self.data_type, '*.jpg'))] 35 | self.frames = [img for img in sorted( 36 | frame_list, key=lambda x: int(x[:-4]))] 37 | 38 | homographies_list = [osp.basename(name) for name in glob.glob( 39 | osp.join(self.root, self.data_type, '*.homographyMatrix'))] 40 | self.homographies = [mat for mat in sorted( 41 | homographies_list, key=lambda x: int(x[:-17]))] 42 | 43 | self.preprocess = transforms.Compose([ 44 | transforms.ToTensor(), 45 | transforms.Normalize([0.485, 0.456, 0.406], 46 | [0.229, 0.224, 0.225]), # ImageNet 47 | ]) 48 | 49 | def __len__(self): 50 | 51 | return len(self.frames) 52 | 53 | def __getitem__(self, index): 54 | 55 | image = np.array(Image.open( 56 | osp.join(self.root, self.data_type, self.frames[index]))) 57 | 58 | gt_h = np.loadtxt( 59 | osp.join(self.root, self.data_type, self.homographies[index])) 60 | 61 | template_grid = utils.gen_template_grid() # template grid shape (91, 3) 62 | 63 | image_list = [] 64 | homo_mat_list = [] 65 | pairwise_seed = random.randint(0, 2147483647) 66 | f1_seed = random.randint(0, 2147483647) 67 | f2_seed = random.randint(0, 2147483647) 68 | f3_seed = random.randint(0, 2147483647) 69 | choice1_cls_seed = random.randint(0, 2147483647) 70 | choice2_cls_seed = random.randint(0, 2147483647) 71 | choice3_cls_seed = random.randint(0, 2147483647) 72 | obj_seed = random.randint(0, 2147483647) 73 | dilated_hm_list = [] 74 | 75 | # TODO: augmentation to get warp_image, warp_grid, heatmap, pert_homo of each training sample 76 | for f in range(3): 77 | if f == 0: 78 | random.seed(f1_seed) 79 | elif f == 1: 80 | random.seed(f2_seed) 81 | elif f == 2: 82 | random.seed(f3_seed) 83 | 84 | # warp grid shape (91, 3) 85 | warp_image, warp_grid, homo_mat = utils.gen_im_whole_grid( 86 | self.mode, image, f, gt_h, template_grid, self.noise_trans, self.noise_rotate, index) 87 | 88 | # Each keypoints is considered as an object 89 | num_pts = warp_grid.shape[0] 90 | pil_image = Image.fromarray(warp_image) 91 | 92 | # TODO: apply random horizontal flip to all the image and grid points 93 | random.seed(pairwise_seed) 94 | if self.mode == 'train' and random.random() < 0.5: 95 | pil_image, warp_grid = utils.put_lrflip_augmentation( 96 | pil_image, warp_grid) 97 | 98 | image_tensor = self.preprocess(pil_image) 99 | image_list.append(image_tensor) 100 | 
homo_mat_list.append(homo_mat) 101 | 102 | # By default, all keypoints belong to background 103 | # C*H*W, C:91, exclude background class 104 | heatmaps = np.zeros( 105 | (num_pts, self.frame_h // 4, self.frame_w // 4), dtype=np.float32) 106 | dilated_heatmaps = np.zeros_like(heatmaps) 107 | 108 | for keypts_label in range(num_pts): 109 | if np.isnan(warp_grid[keypts_label, 0]) and np.isnan(warp_grid[keypts_label, 1]): 110 | continue 111 | px = np.rint(warp_grid[keypts_label, 0] / 4).astype(np.int32) 112 | py = np.rint(warp_grid[keypts_label, 1] / 4).astype(np.int32) 113 | cls = int(warp_grid[keypts_label, 2]) - 1 114 | if 0 <= px < (self.frame_w // 4) and 0 <= py < (self.frame_h // 4): 115 | heatmaps[cls][py, px] = warp_grid[keypts_label, 2] 116 | dilated_heatmaps[cls] = ss.expand_labels( 117 | heatmaps[cls], distance=5) 118 | dilated_hm_list.append(dilated_heatmaps) 119 | 120 | # Those keypoints appears on the first frame 121 | labels = np.unique(dilated_hm_list[0]) 122 | labels = labels[labels != 0] # Remove background class 123 | 124 | dilated_hm_list = np.stack(dilated_hm_list, axis=0) # 3*91*H*W 125 | T, _, H, W = dilated_hm_list.shape 126 | 127 | # TODO: keypoints appear/disappear augmentation 128 | target_dilated_hm_list = torch.zeros((self.num_objects, T, H, W)) 129 | lookup_list = [] 130 | for f in range(3): 131 | labels = np.unique(dilated_hm_list[0]) 132 | labels = labels[labels != 0] # remove background class 133 | lookup = np.ones(self.num_objects, dtype=np.float32) * -1 134 | 135 | # hard level 136 | if f == 0: 137 | random.seed(choice1_cls_seed) 138 | elif f == 1: 139 | random.seed(choice2_cls_seed) 140 | elif f == 2: 141 | random.seed(choice3_cls_seed) 142 | 143 | if len(labels) < 4: 144 | print('b', labels.tolist()) 145 | for idx, obj in enumerate(labels): 146 | lookup[idx] = obj 147 | else: 148 | for idx in range(self.num_objects): 149 | if len(labels) > 0: 150 | target_object = random.choice(labels) 151 | labels = labels[labels != target_object] 152 | lookup[idx] = target_object 153 | else: 154 | print('Less than four classes') 155 | 156 | lookup_list.append(lookup) 157 | 158 | lookup_list = np.stack(lookup_list, axis=0) # T*CK:4 159 | 160 | # Label reorder 161 | new_lookup_list = torch.ones((3, self.num_objects)) * -1 162 | new_selector_list = torch.ones_like(new_lookup_list) 163 | 164 | inter01 = np.intersect1d(lookup_list[0], lookup_list[1]) 165 | non_inter01 = np.setdiff1d(lookup_list[0], lookup_list[1]) 166 | non_inter10 = np.setdiff1d(lookup_list[1], lookup_list[0]) 167 | new0 = np.concatenate((inter01, non_inter01), axis=0) 168 | new1 = np.concatenate((inter01, non_inter10), axis=0) 169 | inter12, inter1_ind, _ = np.intersect1d( 170 | new1, lookup_list[2], return_indices=True) 171 | non_inter21 = np.setdiff1d(lookup_list[2], new1) 172 | new_lookup_list[0, :] = utils.to_torch(new0) 173 | new_lookup_list[1, :] = utils.to_torch(new1) 174 | new_lookup_list[2, inter1_ind] = utils.to_torch(inter12) 175 | remain_ind = torch.where(new_lookup_list[2] == -1)[0] 176 | new_lookup_list[2, remain_ind] = utils.to_torch(non_inter21) 177 | 178 | new_selector_list[new_lookup_list == -1] = 0 179 | 180 | dilated_hm_list = utils.to_torch(dilated_hm_list) 181 | for f in range(3): 182 | for idx, obj in enumerate(new_lookup_list[f]): 183 | if obj != -1: 184 | target_dilated_hm = dilated_hm_list[f, int( 185 | obj)-1].clone() # H*W 186 | target_dilated_hm[target_dilated_hm == obj] = 1 187 | target_dilated_hm_list[idx, f] = target_dilated_hm 188 | 189 | # TODO: union of ground truth 
segmentation of all objects 190 | cls_gt = torch.zeros((3, H, W)) 191 | for f in range(3): 192 | for idx in range(self.num_objects): 193 | cls_gt[f][target_dilated_hm_list[idx, f] == 1] = idx + 1 194 | 195 | image_list = torch.stack(image_list, dim=0) # (3, 3, 720, 1280) 196 | homo_mat_list = np.stack(homo_mat_list, axis=0) 197 | 198 | data = {} 199 | data['rgb'] = image_list 200 | data['target_dilated_hm'] = target_dilated_hm_list 201 | data['cls_gt'] = cls_gt 202 | data['gt_homo'] = homo_mat_list 203 | data['selector'] = new_selector_list 204 | data['lookup'] = new_lookup_list 205 | 206 | return data 207 | 208 | 209 | if __name__ == "__main__": 210 | 211 | static_loader = StaticTransformDataset( 212 | root='./dataset/soccer_worldcup_2014/soccer_data', data_type='train_val', mode='train', num_objects=4, noise_trans=5.0, noise_rotate=0.0084) 213 | 214 | import shutil 215 | cnt = 1 216 | visual_dir = osp.join('visual', 'static') 217 | if osp.exists(visual_dir): 218 | print(f'Remove directory: {visual_dir}') 219 | shutil.rmtree(visual_dir) 220 | print(f'Create directory: {visual_dir}') 221 | os.makedirs(visual_dir, exist_ok=True) 222 | 223 | denorm = utils.UnNormalize( 224 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 225 | 226 | for idx, data in enumerate(static_loader): 227 | image = data['rgb'] 228 | mask = data['target_dilated_hm'] 229 | cls_gt = data['cls_gt'] 230 | lookup = data['lookup'] 231 | # === debug === 232 | print(f'number of frames: {cls_gt.shape[0]}') 233 | print(image.shape, mask.shape, cls_gt.shape) 234 | print('lookup:\n', lookup) 235 | for j in range(cls_gt.shape[0]): 236 | print(torch.unique(cls_gt[j])) 237 | plt.imsave(osp.join(visual_dir, 'seg%03d.jpg' % 238 | (j + 1)), utils.to_numpy(cls_gt[j]), vmin=0, vmax=4) 239 | plt.imsave(osp.join(visual_dir, 'rgb%03d.jpg' % 240 | (j + 1)), utils.im_to_numpy(denorm(image[j]))) 241 | for i in range(4): 242 | plt.imsave(osp.join(visual_dir, '%d_dilated_mask_obj%d.jpg' % ( 243 | j + 1, i + 1)), utils.to_numpy(mask[i, j])) 244 | cnt += 1 245 | assert False 246 | pass 247 | -------------------------------------------------------------------------------- /models/network.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from STCN: https://github.com/hkchengrex/STCN 3 | network.py - The core of the neural network 4 | """ 5 | 6 | import math 7 | 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from models.modules import * 14 | 15 | 16 | def weights_init(m): 17 | # Initialize filters with Gaussian random weights 18 | if isinstance(m, nn.Conv2d): 19 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 20 | m.weight.data.normal_(0, math.sqrt(2.
/ n)) 21 | if m.bias is not None: 22 | m.bias.data.zero_() 23 | elif isinstance(m, nn.BatchNorm2d): 24 | m.weight.data.fill_(1) 25 | m.bias.data.zero_() 26 | 27 | 28 | class Decoder(nn.Module): 29 | 30 | def __init__(self, model_archi): 31 | super().__init__() 32 | 33 | self.model_archi = model_archi 34 | 35 | self.n_classes = 1 36 | # self.n_classes = 92 37 | 38 | # part of encoder 39 | self.midconv_1 = nn.Sequential( 40 | nn.Conv2d(512, 1024, kernel_size=3, 41 | stride=2, padding=1, bias=False), 42 | nn.BatchNorm2d(1024), 43 | nn.ReLU(inplace=True), 44 | nn.Conv2d(1024, 1024, kernel_size=3, 45 | stride=1, padding=1, bias=False), 46 | nn.BatchNorm2d(1024), 47 | nn.ReLU(inplace=True) 48 | ) 49 | 50 | self.midconv_2 = nn.Sequential( 51 | nn.Upsample(size=(23, 40), mode='bilinear', 52 | align_corners=True), 53 | 54 | nn.Conv2d(1024, 512, kernel_size=3, 55 | stride=1, padding=1, bias=False), 56 | nn.BatchNorm2d(512), 57 | nn.ReLU(inplace=True) 58 | ) 59 | 60 | self.up_32_16 = self.up_conv(1024, 256, (45, 80)) 61 | self.up_16_8 = self.up_conv(512, 128, (90, 160)) 62 | self.up_8_4 = self.up_conv(256, 64, (180, 320)) 63 | 64 | self.deconv1 = nn.Sequential( 65 | nn.Conv2d(128, 64, kernel_size=3, padding=1, bias=False), 66 | nn.BatchNorm2d(64), 67 | nn.ReLU(inplace=True), 68 | nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False), 69 | nn.BatchNorm2d(64), 70 | nn.ReLU(inplace=True) 71 | ) 72 | self.conv_last = nn.Conv2d(64, 16, kernel_size=1) 73 | 74 | # Weights initialize 75 | self.midconv_1.apply(weights_init) 76 | self.midconv_2.apply(weights_init) 77 | self.up_32_16.apply(weights_init) 78 | self.up_16_8.apply(weights_init) 79 | self.up_8_4.apply(weights_init) 80 | self.deconv1.apply(weights_init) 81 | self.conv_last.apply(weights_init) 82 | 83 | # task-awareness of DoDNet, generate conv filters for classification layer 84 | self.GAP = nn.AdaptiveAvgPool2d((1, 1)) 85 | self.controller = nn.Conv2d( 86 | 1024+91, 561, kernel_size=1, stride=1, padding=0) 87 | 88 | def up_conv(self, in_channels, out_channels, size): 89 | 90 | return nn.Sequential( 91 | nn.Conv2d(in_channels, in_channels // 2, 92 | kernel_size=3, padding=1, bias=False), 93 | nn.BatchNorm2d(in_channels // 2), 94 | nn.ReLU(inplace=True), 95 | nn.Conv2d(in_channels // 2, in_channels // 96 | 2, kernel_size=3, padding=1, bias=False), 97 | nn.BatchNorm2d(in_channels//2), 98 | nn.ReLU(inplace=True), 99 | 100 | nn.Upsample(size=size, mode='bilinear', align_corners=True), 101 | 102 | nn.Conv2d(in_channels // 2, out_channels, 103 | kernel_size=3, padding=1, bias=False), 104 | nn.BatchNorm2d(out_channels), 105 | nn.ReLU(inplace=True) 106 | ) 107 | 108 | def encoding_task(self, task_id): 109 | 110 | N = task_id.shape[0] 111 | task_encoding = torch.zeros(size=(N, 91), device=task_id.device) 112 | for i in range(N): 113 | if task_id[i] != -1: 114 | task_encoding[i, int(task_id[i]) - 1] = 1 115 | return task_encoding 116 | 117 | def parse_dynamic_params(self, params, channels, weight_nums, bias_nums): 118 | 119 | assert params.dim() == 2 120 | assert len(weight_nums) == len(bias_nums) 121 | assert params.size(1) == sum(weight_nums) + sum(bias_nums) 122 | 123 | num_insts = params.size(0) 124 | num_layers = len(weight_nums) 125 | 126 | params_splits = list(torch.split_with_sizes( 127 | params, weight_nums + bias_nums, dim=1 128 | )) 129 | 130 | weight_splits = params_splits[:num_layers] 131 | bias_splits = params_splits[num_layers:] 132 | 133 | for l in range(num_layers): 134 | if l < num_layers - 1: 135 | weight_splits[l] = 
weight_splits[l].reshape( 136 | num_insts * channels, -1, 1, 1) 137 | bias_splits[l] = bias_splits[l].reshape(num_insts * channels) 138 | else: 139 | weight_splits[l] = weight_splits[l].reshape( 140 | num_insts * 1, -1, 1, 1) 141 | bias_splits[l] = bias_splits[l].reshape(num_insts * 1) 142 | 143 | return weight_splits, bias_splits 144 | 145 | def heads_forward(self, features, weights, biases, num_insts): 146 | 147 | assert features.dim() == 4 148 | n_layers = len(weights) 149 | x = features 150 | for i, (w, b) in enumerate(zip(weights, biases)): 151 | x = F.conv2d( 152 | x, w, bias=b, 153 | stride=1, padding=0, 154 | groups=num_insts 155 | ) 156 | if i < n_layers - 1: 157 | x = F.relu(x) 158 | return x 159 | 160 | def forward(self, f32, f16, f8, f4, cls): 161 | 162 | x = self.midconv_1(f32) 163 | 164 | task_encoding = self.encoding_task(cls) 165 | task_encoding.unsqueeze_(2).unsqueeze_(2) 166 | x_feat = self.GAP(x) 167 | x_cond = torch.cat([x_feat, task_encoding], dim=1) 168 | params = self.controller(x_cond) 169 | params.squeeze_(-1).squeeze_(-1) 170 | 171 | x = self.midconv_2(x) 172 | 173 | x = torch.cat([x, f32], dim=1) 174 | 175 | x = self.up_32_16(x) 176 | x = torch.cat([x, f16], dim=1) 177 | 178 | x = self.up_16_8(x) 179 | x = torch.cat([x, f8], dim=1) 180 | 181 | x = self.up_8_4(x) 182 | x = torch.cat([x, f4], dim=1) 183 | 184 | x = self.deconv1(x) 185 | x = self.conv_last(x) 186 | 187 | N, _, H, W = x.shape 188 | head_inputs = x.reshape(1, -1, H, W) 189 | 190 | weight_nums, bias_nums = [], [] 191 | weight_nums.append(16*16) 192 | weight_nums.append(16*16) 193 | weight_nums.append(16*1) 194 | bias_nums.append(16) 195 | bias_nums.append(16) 196 | bias_nums.append(1) 197 | weights, biases = self.parse_dynamic_params( 198 | params, 16, weight_nums, bias_nums) 199 | 200 | logits = self.heads_forward(head_inputs, weights, biases, N) 201 | logits = logits.reshape(-1, 1, H, W) 202 | logits_origin = F.interpolate( 203 | logits, scale_factor=4, mode='bilinear', align_corners=False) 204 | 205 | return logits_origin 206 | 207 | 208 | class KpSFR(nn.Module): 209 | 210 | def __init__(self, model_archi, num_objects, non_local): 211 | super(KpSFR, self).__init__() 212 | 213 | self.model_archi = model_archi 214 | 215 | self.key_encoder = KeyEncoder( 216 | num_objects=num_objects, non_local=non_local) 217 | 218 | # TODO: random pick 4 219 | self.value_encoder = ValueEncoder( 220 | num_objects=num_objects, non_local=non_local) 221 | 222 | # Projection from f32 feature space to key space 223 | self.key_proj = KeyProjection(512, keydim=64) 224 | 225 | # Compress f32 a bit to use in decoding later on 226 | self.key_comp = nn.Conv2d(512, 256, kernel_size=3, padding=1) 227 | 228 | self.decoder = Decoder(model_archi=self.model_archi) 229 | 230 | def aggregate(self, prob): 231 | # values of prob is from (0, 1) 232 | assert prob.max() <= 1, print('out of value') 233 | assert prob.min() >= 0, print('out of value') 234 | new_prob = torch.cat([ 235 | torch.prod(1 - prob, dim=1, keepdim=True), 236 | prob 237 | ], dim=1).clamp(1e-7, 1 - 1e-7) 238 | logits = torch.log((new_prob / (1 - new_prob))) 239 | return logits # the inverse of sigmoid function 240 | 241 | def encode_key(self, frame, qcls=None): 242 | # input: b*t*c*h*w 243 | b, t = frame.shape[:2] 244 | 245 | f32, f16, f8, f4 = self.key_encoder( 246 | frame.flatten(start_dim=0, end_dim=1), qcls) 247 | assert torch.isnan(f32).sum() == 0, print('qf32: ', f32) 248 | assert torch.isnan(f16).sum() == 0, print('qf16: ', f16) 249 | assert torch.isnan(f8).sum() == 
0, print('qf8: ', f8) 250 | assert torch.isnan(f4).sum() == 0, print('qf4: ', f4) 251 | 252 | # B*T*C*H*W 253 | f32 = f32.view(b, t, *f32.shape[-3:]) 254 | f16 = f16.view(b, t, *f16.shape[-3:]) 255 | f8 = f8.view(b, t, *f8.shape[-3:]) 256 | f4 = f4.view(b, t, *f4.shape[-3:]) 257 | 258 | return f32, f16, f8, f4 259 | 260 | def encode_value(self, frame, kf32, mask, other_masks=None, lookup=None, isFirst=False): 261 | # TODO: random pick 4 262 | # Extract memory key/value for a frame 263 | # input mask/other_masks: b*1*h*w, b*(objs-1)*h*w 264 | f32 = self.value_encoder( 265 | frame, kf32, mask, other_masks, lookup, isFirst) 266 | 267 | return f32.unsqueeze(2) # B*256*T*H*W 268 | 269 | def segment(self, qf32, qf16, qf8, qf4, k, qcls, selector=None): 270 | # q - query, m - memory 271 | # qv32 is f32_thin above 272 | 273 | assert torch.isnan(qf32).sum() == 0, print('qf32: ', qf32) 274 | assert torch.isnan(qf16).sum() == 0, print('qf16: ', qf16) 275 | assert torch.isnan(qf8).sum() == 0, print('qf8: ', qf8) 276 | assert torch.isnan(qf4).sum() == 0, print('qf4: ', qf4) 277 | 278 | print('model archi: ', self.model_archi) 279 | x_origin = self.decoder(qf32, qf16, qf8, qf4, qcls[:, 0]) 280 | 281 | for obj in range(1, k): 282 | x1_origin = self.decoder(qf32, qf16, qf8, qf4, qcls[:, obj]) 283 | x_origin = torch.cat([x_origin, x1_origin], dim=1) 284 | 285 | assert torch.isnan(x_origin).sum() == 0, print('x_origin: ', x_origin) 286 | 287 | prob_origin = torch.sigmoid(x_origin) 288 | prob_origin = prob_origin * selector.unsqueeze(2).unsqueeze(2) 289 | logits_origin = self.aggregate(prob_origin) # B*(obj+1)*720*1280 290 | prob_origin = F.softmax(logits_origin, dim=1)[:, 1:] # B*obj*720*1280 291 | 292 | return x_origin, logits_origin, prob_origin 293 | 294 | def forward(self, mode, *args, **kwargs): 295 | if mode == 'encode_key': 296 | return self.encode_key(*args, **kwargs) 297 | elif mode == 'encode_value': 298 | return self.encode_value(*args, **kwargs) 299 | elif mode == 'segment': 300 | return self.segment(*args, **kwargs) 301 | else: 302 | raise NotImplementedError 303 | -------------------------------------------------------------------------------- /ts_worldcup_train_loader.py: -------------------------------------------------------------------------------- 1 | ''' 2 | random pick 4 from all keypoints 3 | ''' 4 | 5 | import random 6 | import glob 7 | import os 8 | import os.path as osp 9 | import numpy as np 10 | from PIL import Image 11 | import torch 12 | from torch.utils import data 13 | from torchvision import transforms 14 | import matplotlib.pyplot as plt 15 | import skimage.segmentation as ss 16 | from typing import Optional 17 | import utils 18 | 19 | 20 | class CustomWorldCupDataset(data.Dataset): 21 | 22 | def __init__(self, root, data_type, mode, num_objects, noise_trans: Optional[float] = None, noise_rotate: Optional[float] = None): 23 | 24 | self.frame_h = 720 25 | self.frame_w = 1280 26 | self.root = root 27 | self.data_type = data_type 28 | self.mode = mode 29 | self.num_objects = num_objects 30 | self.noise_trans = noise_trans 31 | self.noise_rotate = noise_rotate 32 | 33 | sequence_interval = '80_95' 34 | self.image_path = osp.join( 35 | self.root, 'Dataset', sequence_interval) 36 | self.anno_path = osp.join( 37 | self.root, 'Annotations', sequence_interval) 38 | imgset_path = osp.join(self.root, self.data_type) 39 | 40 | self.videos = [] 41 | self.num_frames = {} 42 | self.num_homographies = {} 43 | self.frames = [] 44 | self.homographies = [] 45 | 46 | with open(imgset_path + 
'.txt', 'r') as lines: 47 | for line in lines: 48 | _video = line.rstrip('\n') 49 | self.videos.append(_video) 50 | self.num_frames[_video] = len( 51 | glob.glob(osp.join(self.image_path, _video, '*.jpg'))) 52 | self.num_homographies[_video] = len( 53 | glob.glob(osp.join(self.anno_path, _video, '*_homography.npy'))) 54 | 55 | frames = sorted(os.listdir( 56 | osp.join(self.image_path, _video))) 57 | for img in frames: 58 | self.frames.append( 59 | osp.join(self.image_path, _video, img)) 60 | 61 | homographies = sorted(os.listdir( 62 | osp.join(self.anno_path, _video))) 63 | for mat in homographies: 64 | self.homographies.append( 65 | osp.join(self.anno_path, _video, mat)) 66 | 67 | self.preprocess = transforms.Compose([ 68 | transforms.ToTensor(), 69 | transforms.Normalize([0.485, 0.456, 0.406], 70 | [0.229, 0.224, 0.225]), # ImageNet 71 | ]) 72 | 73 | def __len__(self): 74 | 75 | return len(self.frames) 76 | 77 | def __getitem__(self, index): 78 | 79 | image = np.array(Image.open(self.frames[index])) 80 | 81 | gt_h = np.load(self.homographies[index]) 82 | 83 | template_grid = utils.gen_template_grid() # template grid shape (91, 3) 84 | 85 | image_list = [] 86 | homo_mat_list = [] 87 | pairwise_seed = random.randint(0, 2147483647) 88 | f1_seed = random.randint(0, 2147483647) 89 | f2_seed = random.randint(0, 2147483647) 90 | f3_seed = random.randint(0, 2147483647) 91 | choice1_cls_seed = random.randint(0, 2147483647) 92 | choice2_cls_seed = random.randint(0, 2147483647) 93 | choice3_cls_seed = random.randint(0, 2147483647) 94 | obj_seed = random.randint(0, 2147483647) 95 | dilated_hm_list = [] 96 | 97 | # TODO: augmentation to get warp_image, warp_grid, heatmap, pert_homo of each training sample 98 | for f in range(3): 99 | if f == 0: 100 | random.seed(f1_seed) 101 | elif f == 1: 102 | random.seed(f2_seed) 103 | elif f == 2: 104 | random.seed(f3_seed) 105 | 106 | # warp grid shape (91, 3) 107 | warp_image, warp_grid, homo_mat = utils.gen_im_whole_grid( 108 | self.mode, image, f, gt_h, template_grid, self.noise_trans, self.noise_rotate, index) 109 | 110 | # Each keypoints is considered as an object 111 | num_pts = warp_grid.shape[0] 112 | pil_image = Image.fromarray(warp_image) 113 | 114 | # TODO: apply random horizontal flip to all the image and grid points 115 | random.seed(pairwise_seed) 116 | if self.mode == 'train' and random.random() < 0.5: 117 | pil_image, warp_grid = utils.put_lrflip_augmentation( 118 | pil_image, warp_grid) 119 | image_tensor = self.preprocess(pil_image) 120 | image_list.append(image_tensor) 121 | homo_mat_list.append(homo_mat) 122 | 123 | # By default, all keypoints belong to background 124 | # C*H*W, C:91, exclude background class 125 | heatmaps = np.zeros( 126 | (num_pts, self.frame_h // 4, self.frame_w // 4), dtype=np.float32) 127 | dilated_heatmaps = np.zeros_like(heatmaps) 128 | 129 | for keypts_label in range(num_pts): 130 | if np.isnan(warp_grid[keypts_label, 0]) and np.isnan(warp_grid[keypts_label, 1]): 131 | continue 132 | px = np.rint(warp_grid[keypts_label, 0] / 4).astype(np.int32) 133 | py = np.rint(warp_grid[keypts_label, 1] / 4).astype(np.int32) 134 | cls = int(warp_grid[keypts_label, 2]) - 1 135 | if 0 <= px < (self.frame_w // 4) and 0 <= py < (self.frame_h // 4): 136 | heatmaps[cls][py, px] = warp_grid[keypts_label, 2] 137 | dilated_heatmaps[cls] = ss.expand_labels( 138 | heatmaps[cls], distance=5) 139 | dilated_hm_list.append(dilated_heatmaps) 140 | 141 | # Those keypoints appears on the first frame 142 | labels = np.unique(dilated_hm_list[0]) 143 
| labels = labels[labels != 0] # Remove background class 144 | 145 | dilated_hm_list = np.stack(dilated_hm_list, axis=0) # 3*91*H*W 146 | T, _, H, W = dilated_hm_list.shape 147 | 148 | # TODO: keypoints appear/disappear augmentation 149 | target_dilated_hm_list = torch.zeros((self.num_objects, T, H, W)) 150 | lookup_list = [] 151 | for f in range(3): 152 | labels = np.unique(dilated_hm_list[0]) 153 | labels = labels[labels != 0] # remove background class 154 | lookup = np.ones(self.num_objects, dtype=np.float32) * -1 155 | 156 | # hard level 157 | if f == 0: 158 | random.seed(choice1_cls_seed) 159 | elif f == 1: 160 | random.seed(choice2_cls_seed) 161 | elif f == 2: 162 | random.seed(choice3_cls_seed) 163 | 164 | if len(labels) < 4: 165 | print('b', labels.tolist()) 166 | for idx, obj in enumerate(labels): 167 | lookup[idx] = obj 168 | else: 169 | for idx in range(self.num_objects): 170 | if len(labels) > 0: 171 | target_object = random.choice(labels) 172 | labels = labels[labels != target_object] 173 | lookup[idx] = target_object 174 | else: 175 | print('Less than four classes') 176 | 177 | lookup_list.append(lookup) 178 | 179 | lookup_list = np.stack(lookup_list, axis=0) # T*CK:4 180 | 181 | # Label reorder 182 | new_lookup_list = torch.ones((3, self.num_objects)) * -1 183 | new_selector_list = torch.ones_like(new_lookup_list) 184 | 185 | inter01 = np.intersect1d(lookup_list[0], lookup_list[1]) 186 | non_inter01 = np.setdiff1d(lookup_list[0], lookup_list[1]) 187 | non_inter10 = np.setdiff1d(lookup_list[1], lookup_list[0]) 188 | new0 = np.concatenate((inter01, non_inter01), axis=0) 189 | new1 = np.concatenate((inter01, non_inter10), axis=0) 190 | inter12, inter1_ind, _ = np.intersect1d( 191 | new1, lookup_list[2], return_indices=True) 192 | non_inter21 = np.setdiff1d(lookup_list[2], new1) 193 | new_lookup_list[0, :] = utils.to_torch(new0) 194 | new_lookup_list[1, :] = utils.to_torch(new1) 195 | new_lookup_list[2, inter1_ind] = utils.to_torch(inter12) 196 | remain_ind = torch.where(new_lookup_list[2] == -1)[0] 197 | new_lookup_list[2, remain_ind] = utils.to_torch(non_inter21) 198 | 199 | new_selector_list[new_lookup_list == -1] = 0 200 | 201 | dilated_hm_list = utils.to_torch(dilated_hm_list) 202 | for f in range(3): 203 | for idx, obj in enumerate(new_lookup_list[f]): 204 | if obj != -1: 205 | target_dilated_hm = dilated_hm_list[f, int( 206 | obj)-1].clone() # H*W 207 | target_dilated_hm[target_dilated_hm == obj] = 1 208 | target_dilated_hm_list[idx, f] = target_dilated_hm 209 | 210 | # TODO: union of ground truth segmentation of all objects 211 | cls_gt = torch.zeros((3, H, W)) 212 | for f in range(3): 213 | for idx in range(self.num_objects): 214 | cls_gt[f][target_dilated_hm_list[idx, f] == 1] = idx + 1 215 | 216 | image_list = torch.stack(image_list, dim=0) # (3, 3, 720, 1280) 217 | homo_mat_list = np.stack(homo_mat_list, axis=0) 218 | data = {} 219 | data['rgb'] = image_list 220 | data['target_dilated_hm'] = target_dilated_hm_list 221 | data['cls_gt'] = cls_gt 222 | data['gt_homo'] = homo_mat_list 223 | data['selector'] = new_selector_list 224 | data['lookup'] = new_lookup_list 225 | 226 | return data 227 | 228 | 229 | if __name__ == "__main__": 230 | 231 | custom_worldcup_loader = CustomWorldCupDataset( 232 | root='dataset/WorldCup_2014_2018', data_type='train', mode='train', num_objects=4, noise_trans=5.0, noise_rotate=0.0084) 233 | 234 | import shutil 235 | cnt = 1 236 | visual_dir = osp.join('visual', 'custom') 237 | if osp.exists(visual_dir): 238 | print(f'Remove directory: 
{visual_dir}') 239 | shutil.rmtree(visual_dir) 240 | print(f'Create directory: {visual_dir}') 241 | os.makedirs(visual_dir, exist_ok=True) 242 | 243 | denorm = utils.UnNormalize( 244 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 245 | 246 | for idx, data in enumerate(custom_worldcup_loader): 247 | image = data['rgb'] 248 | mask = data['target_dilated_hm'] 249 | cls_gt = data['cls_gt'] 250 | lookup = data['lookup'] 251 | # === debug === 252 | print(f'number of frames: {cls_gt.shape[0]}') 253 | print(image.shape, mask.shape, cls_gt.shape) 254 | print('lookup:', lookup) 255 | for j in range(cls_gt.shape[0]): 256 | print(torch.unique(cls_gt[j])) 257 | plt.imsave(osp.join(visual_dir, 'seg%03d.jpg' % 258 | (j + 1)), utils.to_numpy(cls_gt[j]), vmin=0, vmax=4) 259 | plt.imsave(osp.join(visual_dir, 'rgb%03d.jpg' % 260 | (j + 1)), utils.im_to_numpy(denorm(image[j]))) 261 | for i in range(4): 262 | plt.imsave(osp.join(visual_dir, '%d_dilated_mask_obj%d.jpg' % ( 263 | j + 1, i + 1)), utils.to_numpy(mask[i, j])) 264 | cnt += 1 265 | assert False 266 | if cnt >= 11: 267 | break 268 | pass 269 | -------------------------------------------------------------------------------- /models/mod_resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modifed from STCN: https://github.com/hkchengrex/STCN 3 | mod_resnet.py - A modified ResNet structure 4 | We append extra channels to the first conv by some network surgery 5 | """ 6 | 7 | from collections import OrderedDict 8 | import math 9 | 10 | import torch 11 | import torch.nn as nn 12 | from torch.utils import model_zoo 13 | 14 | from models.non_local import NLBlockND 15 | 16 | 17 | def weights_init(m): 18 | 19 | if isinstance(m, nn.Conv2d): 20 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 21 | m.weight.data.normal_(0, math.sqrt(2. / n)) 22 | if m.bias is not None: 23 | m.bias.data.zero_() 24 | elif isinstance(m, nn.BatchNorm2d): 25 | m.weight.data.fill_(1) 26 | m.bias.data.zero_() 27 | 28 | 29 | def weights_init2(m): 30 | 31 | if isinstance(m, nn.Conv2d): 32 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 33 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 34 | if m.bias is not None: 35 | m.bias.data.zero_() 36 | 37 | 38 | def load_weights_sequential(target, source_state, extra_chan=1): 39 | 40 | new_dict = OrderedDict() 41 | 42 | for k1, v1 in target.state_dict().items(): 43 | if not 'num_batches_tracked' in k1: 44 | if k1 in source_state: 45 | tar_v = source_state[k1] 46 | 47 | if v1.shape != tar_v.shape: 48 | # Init the new segmentation channel with zeros 49 | c, _, w, h = v1.shape 50 | pads = torch.zeros((c, extra_chan, w, h), 51 | device=tar_v.device) 52 | nn.init.orthogonal_(pads) 53 | tar_v = torch.cat([tar_v, pads], 1) 54 | 55 | new_dict[k1] = tar_v 56 | 57 | target.load_state_dict(new_dict, strict=False) 58 | 59 | 60 | model_urls = { 61 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 62 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 63 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 64 | } 65 | 66 | 67 | def conv3x3(in_planes, out_planes, stride=1, dilation=1): 68 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 69 | padding=dilation, dilation=dilation, bias=False) 70 | 71 | 72 | class BasicBlock(nn.Module): 73 | 74 | expansion = 1 75 | 76 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): 77 | super(BasicBlock, self).__init__() 78 | 79 | self.conv1 = conv3x3( 80 | inplanes, planes, stride=stride, dilation=dilation) 81 | self.bn1 = nn.BatchNorm2d(planes) 82 | self.relu = nn.ReLU(inplace=True) 83 | self.conv2 = conv3x3(planes, planes, stride=1, dilation=dilation) 84 | self.bn2 = nn.BatchNorm2d(planes) 85 | self.downsample = downsample 86 | self.stride = stride 87 | 88 | def forward(self, x): 89 | residual = x 90 | 91 | out = self.conv1(x) 92 | out = self.bn1(out) 93 | out = self.relu(out) 94 | 95 | out = self.conv2(out) 96 | out = self.bn2(out) 97 | 98 | if self.downsample is not None: 99 | residual = self.downsample(x) 100 | 101 | out += residual 102 | out = self.relu(out) 103 | 104 | return out 105 | 106 | 107 | class Bottleneck(nn.Module): 108 | 109 | expansion = 4 110 | 111 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): 112 | super(Bottleneck, self).__init__() 113 | 114 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 115 | self.bn1 = nn.BatchNorm2d(planes) 116 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, dilation=dilation, 117 | padding=dilation, bias=False) 118 | self.bn2 = nn.BatchNorm2d(planes) 119 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 120 | self.bn3 = nn.BatchNorm2d(planes * 4) 121 | self.relu = nn.ReLU(inplace=True) 122 | self.downsample = downsample 123 | self.stride = stride 124 | 125 | def forward(self, x): 126 | residual = x 127 | 128 | out = self.conv1(x) 129 | out = self.bn1(out) 130 | out = self.relu(out) 131 | 132 | out = self.conv2(out) 133 | out = self.bn2(out) 134 | out = self.relu(out) 135 | 136 | out = self.conv3(out) 137 | out = self.bn3(out) 138 | 139 | if self.downsample is not None: 140 | residual = self.downsample(x) 141 | 142 | out += residual 143 | out = self.relu(out) 144 | 145 | return out 146 | 147 | 148 | class ResNet(nn.Module): 149 | 150 | def __init__(self, block, layers=(2, 2, 2, 2), extra_chan=1, non_local=False, sub_sample=False, bn_layer=False): 151 | self.inplanes = 64 152 | super(ResNet, self).__init__() 153 | 154 | self.conv1 = nn.Conv2d( 155 | 3 + extra_chan, 64, kernel_size=7, stride=2, padding=3, bias=False) 156 | self.bn1 = 
nn.BatchNorm2d(64) 157 | self.relu = nn.ReLU(inplace=True) 158 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 159 | self.layer1 = self._make_layer(block, 64, layers[0]) 160 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 161 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 162 | 163 | # TODO: add a few dilated convolution and spatial-only non-local block layers 164 | self.layer4 = self._make_layer( 165 | block, 512, layers[3], stride=2, dilation=2, non_local=non_local, sub_sample=sub_sample, bn_layer=bn_layer) 166 | self.layer4[0].apply(weights_init) 167 | self.layer4[2].apply(weights_init) 168 | self.layer4[1].apply(weights_init2) 169 | self.layer4[3].apply(weights_init2) 170 | 171 | def _make_layer(self, block, planes, num_blocks, stride=1, dilation=1, non_local=False, sub_sample=False, bn_layer=False): 172 | downsample = None 173 | if stride != 1 or self.inplanes != planes * block.expansion: 174 | downsample = nn.Sequential( 175 | nn.Conv2d(self.inplanes, planes * block.expansion, 176 | kernel_size=1, stride=stride, bias=False), 177 | nn.BatchNorm2d(planes * block.expansion), 178 | ) 179 | 180 | strides = [stride] + [1] * (num_blocks - 1) 181 | layers = [] 182 | last_idx = len(strides) 183 | 184 | if non_local: 185 | last_idx = len(strides) - 1 186 | 187 | for i in range(last_idx): 188 | layers.append( 189 | block(self.inplanes, planes, strides[i], downsample, dilation)) 190 | self.inplanes = planes * block.expansion 191 | downsample = None 192 | 193 | if non_local: 194 | layers.append(NLBlockND(in_channels=planes, mode='embedded', 195 | dimension=2, sub_sample=sub_sample, bn_layer=bn_layer)) 196 | layers.append( 197 | block(self.inplanes, planes, strides[-1], dilation=dilation)) 198 | layers.append(NLBlockND(in_channels=planes, mode='embedded', 199 | dimension=2, sub_sample=sub_sample, bn_layer=bn_layer)) 200 | 201 | return nn.Sequential(*layers) 202 | 203 | 204 | class ResNet34(nn.Module): 205 | 206 | def __init__(self, block, layers=(2, 2, 2, 2), non_local=False, sub_sample=False, bn_layer=False): 207 | self.inplanes = 64 208 | super(ResNet34, self).__init__() 209 | 210 | self.conv1 = nn.Conv2d( 211 | 3, 64, kernel_size=7, stride=2, padding=3, bias=False) 212 | self.bn1 = nn.BatchNorm2d(64) 213 | self.relu = nn.ReLU(inplace=True) 214 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 215 | self.layer1 = self._make_layer(block, 64, layers[0]) 216 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 217 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 218 | 219 | # TODO: add a few dilated convolution and spatial-only non-local block layers 220 | self.layer4 = self._make_layer( 221 | block, 512, layers[3], stride=2, dilation=2, non_local=non_local, sub_sample=sub_sample, bn_layer=bn_layer) 222 | self.layer4[0].apply(weights_init) 223 | self.layer4[1].apply(weights_init) 224 | self.layer4[3].apply(weights_init) 225 | self.layer4[2].apply(weights_init2) 226 | self.layer4[4].apply(weights_init2) 227 | 228 | def _make_layer(self, block, planes, num_blocks, stride=1, dilation=1, non_local=False, sub_sample=False, bn_layer=False): 229 | downsample = None 230 | if stride != 1 or self.inplanes != planes * block.expansion: 231 | downsample = nn.Sequential( 232 | nn.Conv2d(self.inplanes, planes * block.expansion, 233 | kernel_size=1, stride=stride, bias=False), 234 | nn.BatchNorm2d(planes * block.expansion), 235 | ) 236 | 237 | strides = [stride] + [1] * (num_blocks - 1) 238 | layers = [] 239 | 
last_idx = len(strides) 240 | 241 | if non_local: 242 | last_idx = len(strides) - 1 243 | 244 | for i in range(last_idx): 245 | layers.append( 246 | block(self.inplanes, planes, strides[i], downsample, dilation)) 247 | self.inplanes = planes * block.expansion 248 | downsample = None 249 | 250 | if non_local: 251 | layers.append(NLBlockND(in_channels=planes, mode='embedded', 252 | dimension=2, sub_sample=sub_sample, bn_layer=bn_layer)) 253 | layers.append( 254 | block(self.inplanes, planes, strides[-1], dilation=dilation)) 255 | layers.append(NLBlockND(in_channels=planes, mode='embedded', 256 | dimension=2, sub_sample=sub_sample, bn_layer=bn_layer)) 257 | 258 | return nn.Sequential(*layers) 259 | 260 | 261 | def resnet18(pretrained=True, extra_chan=0, non_local=False, sub_sample=False, bn_layer=False): 262 | if non_local: 263 | sub_sample = bn_layer = True 264 | model = ResNet(BasicBlock, [2, 2, 2, 2], 265 | extra_chan, non_local, sub_sample, bn_layer) 266 | if pretrained: 267 | load_weights_sequential(model, model_zoo.load_url( 268 | model_urls['resnet18']), extra_chan) 269 | return model 270 | 271 | 272 | def resnet34(pretrained=True, extra_chan=0, non_local=False, sub_sample=False, bn_layer=False): 273 | if non_local: 274 | sub_sample = bn_layer = True 275 | 276 | model = ResNet34(BasicBlock, [3, 4, 6, 3], non_local, sub_sample, bn_layer) 277 | 278 | if pretrained: 279 | load_weights_sequential(model, model_zoo.load_url( 280 | model_urls['resnet34']), extra_chan) 281 | 282 | return model 283 | 284 | 285 | def resnet50(pretrained=True, extra_chan=0, non_local=False, sub_sample=False, bn_layer=False): 286 | if non_local: 287 | sub_sample = bn_layer = True 288 | model = ResNet(Bottleneck, [3, 4, 6, 3], 289 | extra_chan, non_local, sub_sample, bn_layer) 290 | if pretrained: 291 | load_weights_sequential(model, model_zoo.load_url( 292 | model_urls['resnet50']), extra_chan) 293 | return model 294 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import cv2 4 | import matplotlib.pyplot as plt 5 | from matplotlib.patches import Circle 6 | from PIL import Image 7 | from shapely.geometry import Point, Polygon, MultiPoint 8 | import utils 9 | import torch 10 | import os.path as osp 11 | 12 | 13 | def calc_euclidean_distance(a, b, _norm=np.linalg.norm, axis=None): 14 | return _norm(a - b, axis=axis) 15 | 16 | 17 | def calc_iou_part(pred_h, gt_h, frame, template, frame_w=1280, frame_h=720, template_w=115, template_h=74): 18 | 19 | # TODO: calculate iou part 20 | # === render === 21 | render_w, render_h = template.size # (1050, 680) 22 | dst = np.array(template) 23 | 24 | # Create three channels (680, 1050, 3) 25 | dst = np.stack((dst, ) * 3, axis=-1) 26 | 27 | scaling_mat = np.eye(3) 28 | scaling_mat[0, 0] = render_w / template_w 29 | scaling_mat[1, 1] = render_h / template_h 30 | 31 | frame = np.uint8(frame * 255) # 0-1 map to 0-255 32 | gt_mask_render = cv2.warpPerspective( 33 | frame, scaling_mat @ gt_h, (render_w, render_h), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_CONSTANT, borderValue=(0)) 34 | pred_mask_render = cv2.warpPerspective( 35 | frame, scaling_mat @ pred_h, (render_w, render_h), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_CONSTANT, borderValue=(0)) 36 | 37 | # === blending === 38 | dstf = dst.astype(float) / 255 39 | gt_mask_renderf = gt_mask_render.astype(float) / 255 40 | gt_resultf = cv2.addWeighted(dstf, 0.3, 
gt_mask_renderf, 0.7, 0.0) 41 | gt_result = np.uint8(gt_resultf * 255) 42 | pred_mask_renderf = pred_mask_render.astype(float) / 255 43 | pred_resultf = cv2.addWeighted(dstf, 0.3, pred_mask_renderf, 0.7, 0.0) 44 | pred_result = np.uint8(pred_resultf * 255) 45 | 46 | # field template binary mask 47 | field_mask = np.ones((frame_h, frame_w, 3), dtype=np.uint8) * 255 48 | gt_mask = cv2.warpPerspective(field_mask, gt_h, (template_w, template_h), 49 | flags=cv2.INTER_AREA, borderMode=cv2.BORDER_CONSTANT, borderValue=(0)) 50 | pred_mask = cv2.warpPerspective(field_mask, pred_h, (template_w, template_h), 51 | flags=cv2.INTER_AREA, borderMode=cv2.BORDER_CONSTANT, borderValue=(0)) 52 | gt_mask[gt_mask > 0] = 255 53 | pred_mask[pred_mask > 0] = 255 54 | 55 | intersection = ((gt_mask > 0) * (pred_mask > 0)).sum() 56 | union = (gt_mask > 0).sum() + (pred_mask > 0).sum() - intersection 57 | 58 | if union <= 0: 59 | print('part union', union) 60 | # iou = float('nan') 61 | iou = 0. 62 | else: 63 | iou = float(intersection) / float(union) 64 | 65 | # === blending === 66 | gt_white_area = (gt_mask[:, :, 0] == 255) & ( 67 | gt_mask[:, :, 1] == 255) & (gt_mask[:, :, 2] == 255) 68 | gt_fill = gt_mask.copy() 69 | gt_fill[gt_white_area, 0] = 255 70 | gt_fill[gt_white_area, 1] = 0 71 | gt_fill[gt_white_area, 2] = 0 72 | pred_white_area = (pred_mask[:, :, 0] == 255) & ( 73 | pred_mask[:, :, 1] == 255) & (pred_mask[:, :, 2] == 255) 74 | pred_fill = pred_mask.copy() 75 | pred_fill[pred_white_area, 0] = 0 76 | pred_fill[pred_white_area, 1] = 255 77 | pred_fill[pred_white_area, 2] = 0 78 | gt_maskf = gt_fill.astype(float) / 255 79 | pred_maskf = pred_fill.astype(float) / 255 80 | fill_resultf = cv2.addWeighted(gt_maskf, 0.5, 81 | pred_maskf, 0.5, 0.0) 82 | fill_result = np.uint8(fill_resultf * 255) 83 | 84 | return iou, gt_result, pred_result, fill_result 85 | 86 | 87 | def calc_iou_whole_with_poly(pred_h, gt_h, frame, template, frame_w=1280, frame_h=720, template_w=115, template_h=74): 88 | 89 | corners = np.array([[0, 0], 90 | [frame_w - 1, 0], 91 | [frame_w - 1, frame_h - 1], 92 | [0, frame_h - 1]], dtype=np.float64) 93 | 94 | mapping_mat = np.linalg.inv(gt_h) 95 | mapping_mat /= mapping_mat[2, 2] 96 | 97 | gt_corners = cv2.perspectiveTransform( 98 | corners[:, None, :], gt_h) # inv_gt_mat * (gt_mat * [x, y, 1]) 99 | gt_corners = cv2.perspectiveTransform( 100 | gt_corners, np.linalg.inv(gt_h)) 101 | gt_corners = gt_corners[:, 0, :] 102 | 103 | pred_corners = cv2.perspectiveTransform( 104 | corners[:, None, :], gt_h) # inv_pred_mat * (gt_mat * [x, y, 1]) 105 | pred_corners = cv2.perspectiveTransform( 106 | pred_corners, np.linalg.inv(pred_h)) 107 | pred_corners = pred_corners[:, 0, :] 108 | 109 | gt_poly = Polygon(gt_corners.tolist()) 110 | pred_poly = Polygon(pred_corners.tolist()) 111 | 112 | # f, axarr = plt.subplots(1, 2, figsize=(16, 12)) 113 | # axarr[0].plot(*gt_poly.exterior.coords.xy) 114 | # axarr[1].plot(*pred_poly.exterior.coords.xy) 115 | # plt.show() 116 | 117 | if not pred_poly.is_valid: 118 | return 0., None, None 119 | 120 | if not gt_poly.intersects(pred_poly): 121 | print('not intersects') 122 | iou = 0. 123 | else: 124 | intersection = gt_poly.intersection(pred_poly).area 125 | union = gt_poly.area + pred_poly.area - intersection 126 | if union <= 0.: 127 | print('whole union', union) 128 | iou = 0. 
129 | else: 130 | iou = intersection / union 131 | 132 | return iou, None, None 133 | 134 | 135 | def calc_proj_error(pred_h, gt_h, frame, template, frame_w=1280, frame_h=720, template_w=115, template_h=74): 136 | 137 | # TODO get visible field area of the camera image 138 | dst = np.array(template) 139 | 140 | # Create three channels (680, 1050, 3) 141 | dst = np.stack((dst, ) * 3, axis=-1) 142 | 143 | field_mask = np.ones((template_h, template_w, 3), dtype=np.uint8) * 255 144 | gt_mask = cv2.warpPerspective(field_mask, np.linalg.inv( 145 | gt_h), (frame_w, frame_h), borderMode=cv2.BORDER_CONSTANT, borderValue=(0)) 146 | gt_gray = cv2.cvtColor(gt_mask, cv2.COLOR_BGR2GRAY) 147 | contours, hierarchy = cv2.findContours( 148 | gt_gray, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 149 | contour = np.squeeze(contours[0]) 150 | poly = Polygon(contour) 151 | sample_pts = [] 152 | num_pts = 2500 153 | while len(sample_pts) <= num_pts: 154 | x = random.sample(range(0, frame_w), 1) 155 | y = random.sample(range(0, frame_h), 1) 156 | p = Point(x[0], y[0]) 157 | if p.within(poly): 158 | sample_pts.append([x[0], y[0]]) 159 | sample_pts = np.array(sample_pts, dtype=np.float32) 160 | 161 | field_dim_x, field_dim_y = 100, 60 162 | x_scale = field_dim_x / template_w 163 | y_scale = field_dim_y / template_h 164 | scaling_mat = np.eye(3) 165 | scaling_mat[0, 0] = x_scale 166 | scaling_mat[1, 1] = y_scale 167 | gt_temp_grid = cv2.perspectiveTransform( 168 | sample_pts.reshape(-1, 1, 2), scaling_mat @ gt_h) 169 | gt_temp_grid = gt_temp_grid.reshape(-1, 2) 170 | pred_temp_grid = cv2.perspectiveTransform( 171 | sample_pts.reshape(-1, 1, 2), scaling_mat @ pred_h) 172 | pred_temp_grid = pred_temp_grid.reshape(-1, 2) 173 | 174 | # TODO compute distance in top view 175 | gt_grid_list = [] 176 | pred_grid_list = [] 177 | for gt_pts, pred_pts in zip(gt_temp_grid, pred_temp_grid): 178 | if 0 <= gt_pts[0] < field_dim_x and 0 <= gt_pts[1] < field_dim_y and \ 179 | 0 <= pred_pts[0] < field_dim_x and 0 <= pred_pts[1] < field_dim_y: 180 | gt_grid_list.append(gt_pts) 181 | pred_grid_list.append(pred_pts) 182 | gt_grid_list = np.array(gt_grid_list) 183 | pred_grid_list = np.array(pred_grid_list) 184 | 185 | if gt_grid_list.shape != pred_grid_list.shape: 186 | print('proj error:', gt_grid_list.shape, pred_grid_list.shape) 187 | assert gt_grid_list.shape == pred_grid_list.shape, 'shape mismatch' 188 | 189 | if gt_grid_list.size != 0 and pred_grid_list.size != 0: 190 | distance_list = calc_euclidean_distance( 191 | gt_grid_list, pred_grid_list, axis=1) 192 | return distance_list.mean() # average all keypoints 193 | else: 194 | print(gt_grid_list) 195 | print(pred_grid_list) 196 | return float('nan') 197 | 198 | 199 | def calc_reproj_error(pred_h, gt_h, frame, template, frame_w=1280, frame_h=720, template_w=115, template_h=74): 200 | 201 | uniform_grid = utils.gen_template_grid() # grid shape (91, 3), (x, y, label) 202 | template_grid = uniform_grid[:, :2].copy() 203 | template_grid = template_grid.reshape(-1, 1, 2) 204 | 205 | gt_warp_grid = cv2.perspectiveTransform(template_grid, np.linalg.inv(gt_h)) 206 | gt_warp_grid = gt_warp_grid.reshape(-1, 2) 207 | pred_warp_grid = cv2.perspectiveTransform( 208 | template_grid, np.linalg.inv(pred_h)) 209 | pred_warp_grid = pred_warp_grid.reshape(-1, 2) 210 | 211 | # TODO compute distance in camera view 212 | gt_grid_list = [] 213 | pred_grid_list = [] 214 | for gt_pts, pred_pts in zip(gt_warp_grid, pred_warp_grid): 215 | if 0 <= gt_pts[0] < frame_w and 0 <= gt_pts[1] < frame_h and \ 216 | 0 <= 
pred_pts[0] < frame_w and 0 <= pred_pts[1] < frame_h: 217 | gt_grid_list.append(gt_pts) 218 | pred_grid_list.append(pred_pts) 219 | gt_grid_list = np.array(gt_grid_list) 220 | pred_grid_list = np.array(pred_grid_list) 221 | 222 | if gt_grid_list.shape != pred_grid_list.shape: 223 | print('reproj error:', gt_grid_list.shape, pred_grid_list.shape) 224 | assert gt_grid_list.shape == pred_grid_list.shape, 'shape mismatch' 225 | 226 | if gt_grid_list.size != 0 and pred_grid_list.size != 0: 227 | distance_list = calc_euclidean_distance( 228 | gt_grid_list, pred_grid_list, axis=1) 229 | distance_list /= frame_h # normalize by image height 230 | return distance_list.mean() # average all keypoints 231 | else: 232 | print(gt_grid_list) 233 | print(pred_grid_list) 234 | return float('nan') 235 | 236 | 237 | if __name__ == "__main__": 238 | 239 | # image = np.array(Image.open( 240 | # osp.join('./dataset/soccer_worldcup_2014/soccer_data/test', '1.jpg'))) 241 | image = np.array(Image.open( 242 | osp.join('./assets', 'IMG_001.jpg'))) 243 | # print(image.shape) 244 | # gt_homo = np.loadtxt( 245 | # osp.join('./dataset/soccer_worldcup_2014/soccer_data/test', '1.homographyMatrix')) 246 | gt_homo = np.load( 247 | osp.join('./assets', 'IMG_001_homography.npy')) 248 | pred_homo = np.load( 249 | osp.join('./assets', 'IMG_002_homography.npy')) 250 | # gt_homo = np.load( 251 | # osp.join('./nms/debug/test_visual', 'test_00010_gt_homography.npy')) 252 | # pred_homo = np.load( 253 | # osp.join('./nms/debug/test_visual', 'test_00010_pred_homography.npy')) 254 | field_model = Image.open( 255 | osp.join('./assets', 'worldcup_field_model.png')) 256 | m = np.array([[0.02, 0, 0], 257 | [0, 0.02, 0], 258 | [0, 0, 0]], dtype=np.float32) 259 | print(gt_homo) 260 | # pred_homo = gt_homo 261 | # pred_homo[0, 0] *= -1 262 | # pred_homo = gt_homo - m 263 | # pred_homo = np.eye(3, dtype=np.float32) 264 | print(pred_homo) 265 | 266 | iou_part, gt_part_mask, pred_part_mask, part_merge_result = calc_iou_part( 267 | pred_homo, gt_homo, image, field_model) 268 | print(f'{iou_part * 100.:.1f}') 269 | print(iou_part) 270 | 271 | # iou_whole, whole_line_merge_result, whole_fill_merge_result = calc_iou_whole( 272 | # pred_homo, gt_homo, image, field_model) 273 | # print(f'{iou_whole * 100.:.1f}') 274 | # print(iou_whole) 275 | -------------------------------------------------------------------------------- /worldcup_test_loader.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import os.path as osp 4 | import numpy as np 5 | from PIL import Image 6 | import torch 7 | from torch.utils import data 8 | from torchvision import transforms 9 | import matplotlib.pyplot as plt 10 | import skimage.segmentation as ss 11 | from typing import Optional 12 | import utils 13 | 14 | 15 | class WorldcupTestDataset(data.Dataset): 16 | 17 | def __init__(self, root, data_type, mode, num_objects, noise_trans: Optional[float] = None, noise_rotate: Optional[float] = None, target_image: list = None): 18 | 19 | self.frame_h = 720 20 | self.frame_w = 1280 21 | self.root = root 22 | self.data_type = data_type 23 | self.mode = mode 24 | self.num_objects = num_objects 25 | # self.num_objects = 91 26 | self.noise_trans = noise_trans 27 | self.noise_rotate = noise_rotate 28 | 29 | sfp_out_path = 'robust_worldcup_testset_dilated' 30 | self.sfp_path = osp.join(self.root, sfp_out_path) 31 | 32 | self.imgset_path = osp.join(self.root, self.data_type) 33 | self.videos = [] 34 | self.num_frames = {} 35 | 
self.num_homographies = {} 36 | self.frames = {} 37 | self.homographies = {} 38 | self.segs = {} 39 | 40 | _video = 'worldcup_2014' 41 | self.videos.append(_video) 42 | self.num_frames[_video] = len( 43 | glob.glob(osp.join(self.imgset_path, '*.jpg'))) 44 | self.num_homographies[_video] = len( 45 | glob.glob(osp.join(self.imgset_path, '*.homographyMatrix'))) 46 | 47 | frame_list = [osp.basename(name) for name in glob.glob( 48 | osp.join(self.imgset_path, '*.jpg'))] 49 | frames = [img for img in sorted( 50 | frame_list, key=lambda x: int(x[:-4]))] 51 | self.frames[_video] = frames 52 | 53 | homographies_list = [osp.basename(name) for name in glob.glob( 54 | osp.join(self.imgset_path, '*.homographyMatrix'))] 55 | homographies = [mat for mat in sorted( 56 | homographies_list, key=lambda x: int(x[:-17]))] 57 | self.homographies[_video] = homographies 58 | 59 | gt_segs = sorted(os.listdir( 60 | osp.join(self.sfp_path, _video))) 61 | self.segs[_video] = gt_segs 62 | 63 | self.preprocess = transforms.Compose([ 64 | transforms.ToTensor(), 65 | transforms.Normalize([0.485, 0.456, 0.406], 66 | [0.229, 0.224, 0.225]), # ImageNet 67 | ]) 68 | 69 | self.target_image = target_image 70 | 71 | def __len__(self): 72 | return len(self.videos) 73 | 74 | def __getitem__(self, index): 75 | 76 | _video_name = self.videos[index] 77 | _frames = self.frames[_video_name] 78 | _homographies = self.homographies[_video_name] 79 | _segs = self.segs[_video_name] 80 | info = {} 81 | info['name'] = _video_name 82 | info['frames'] = [] 83 | info['num_frames'] = self.num_frames[_video_name] 84 | info['single_frame_path'] = self.sfp_path 85 | 86 | template_grid = utils.gen_template_grid() # template grid shape (91, 3) 87 | 88 | image_list = [] 89 | homo_mat_list = [] 90 | dilated_hm_list = [] 91 | hm_list = [] 92 | gt_seg_list = [] 93 | 94 | for f_idx in range(self.num_frames[_video_name]): 95 | 96 | jpg_image = _frames[f_idx] 97 | # NNN.jpg 98 | if self.target_image and jpg_image.split('.')[0] not in self.target_image: 99 | continue 100 | 101 | npy_matrix = _homographies[f_idx] 102 | png_seg = _segs[f_idx] 103 | info['frames'].append( 104 | osp.join(self.imgset_path, jpg_image)) 105 | image = np.array(Image.open( 106 | osp.join(self.imgset_path, jpg_image))) 107 | gt_h = np.loadtxt(osp.join(self.imgset_path, npy_matrix)) 108 | 109 | sfp_seg = np.array(Image.open( 110 | osp.join(self.sfp_path, _video_name, png_seg)).convert('P')) 111 | gt_seg_list.append(sfp_seg) 112 | 113 | # warp grid shape (91, 3) 114 | warp_image, warp_grid, homo_mat = utils.gen_im_whole_grid( 115 | self.mode, image, f_idx, gt_h, template_grid, self.noise_trans, self.noise_rotate, index) 116 | 117 | # Each keypoints is considered as an object 118 | num_pts = warp_grid.shape[0] 119 | pil_image = Image.fromarray(warp_image) 120 | 121 | image_tensor = self.preprocess(pil_image) 122 | image_list.append(image_tensor) 123 | homo_mat_list.append(homo_mat) 124 | 125 | # By default, all keypoints belong to background 126 | # C*H*W, C:91, exclude background class 127 | heatmaps = np.zeros( 128 | (num_pts, self.frame_h // 4, self.frame_w // 4), dtype=np.float32) 129 | dilated_heatmaps = np.zeros_like(heatmaps) 130 | for keypts_label in range(num_pts): 131 | if np.isnan(warp_grid[keypts_label, 0]) and np.isnan(warp_grid[keypts_label, 1]): 132 | continue 133 | px = np.rint(warp_grid[keypts_label, 0] / 4).astype(np.int32) 134 | py = np.rint(warp_grid[keypts_label, 1] / 4).astype(np.int32) 135 | cls = int(warp_grid[keypts_label, 2]) - 1 136 | if 0 <= px < (self.frame_w 
// 4) and 0 <= py < (self.frame_h // 4): 137 | heatmaps[cls][py, px] = warp_grid[keypts_label, 2] 138 | dilated_heatmaps[cls] = ss.expand_labels( 139 | heatmaps[cls], distance=5) 140 | 141 | dilated_hm_list.append(dilated_heatmaps) 142 | hm_list.append(heatmaps) 143 | 144 | # TODO: use full gt segmentation info, only previous for memory management 145 | info['num_objects'] = self.num_objects 146 | 147 | dilated_hm_list = np.stack( 148 | dilated_hm_list, axis=0) # num_frames*91*H*W 149 | T, CK, H, W = dilated_hm_list.shape 150 | hm_list = np.stack(hm_list, axis=0) 151 | gt_seg_list = np.stack(gt_seg_list, axis=0) # num_frames*H*W 152 | # (CK:num_objects, T:num_frames, H:180, W:320) 153 | target_dilated_hm_list = torch.zeros((CK, T, H, W)) 154 | target_hm_list = torch.zeros_like(target_dilated_hm_list) 155 | 156 | if self.target_image: 157 | cls_gt = torch.zeros((len(self.target_image), H, W)) 158 | else: 159 | cls_gt = torch.zeros((self.num_frames[_video_name], H, W)) 160 | 161 | lookup_list = [] 162 | 163 | if self.target_image: 164 | hm_range = range(len(self.target_image)) 165 | else: 166 | hm_range = range(self.num_frames[_video_name]) 167 | 168 | for f in hm_range: 169 | class_labels = np.ones(num_pts, dtype=np.float32) * -1 170 | # Keypoints that appear in each frame 171 | labels = np.unique(dilated_hm_list[f]) 172 | labels = labels[labels != 0] # Remove background class 173 | for obj in labels: 174 | class_labels[int(obj) - 1] = obj 175 | 176 | for idx, obj in enumerate(class_labels): 177 | if obj != -1: 178 | target_dilated_hm = dilated_hm_list[f, int(obj) - 1].copy() 179 | target_dilated_hm[target_dilated_hm == obj] = 1 180 | target_dilated_hm_tensor = utils.to_torch( 181 | target_dilated_hm) 182 | target_dilated_hm_list[int( 183 | obj) - 1, f] = target_dilated_hm_tensor 184 | 185 | target_hm = hm_list[f, int(obj) - 1].copy() 186 | target_hm[target_hm == obj] = 1 187 | target_hm_tensor = utils.to_torch(target_hm) 188 | target_hm_list[int(obj) - 1, f] = target_hm_tensor 189 | 190 | # TODO: union of all target objects of ground truth segmentation 191 | for idx, obj in enumerate(class_labels): 192 | if obj != -1: 193 | cls_gt[target_hm_list[idx] == 194 | 1] = torch.tensor(obj).float() 195 | 196 | # TODO: use full single frame predict segmentation info, only previous for memory management 197 | for f in hm_range: 198 | class_labels = np.ones(num_pts, dtype=np.float32) * -1 199 | # Keypoints that appear in each single-frame prediction 200 | labels = np.unique(gt_seg_list[f]) 201 | labels = labels[labels != 0] # Remove background class 202 | for obj in labels: 203 | class_labels[int(obj) - 1] = obj 204 | 205 | sfp_lookup = utils.to_torch(class_labels) 206 | 207 | # TODO: choose the range of classes for class conditioning 208 | sfp_interval = torch.ones_like(sfp_lookup) * -1 209 | cls_id = torch.unique(sfp_lookup) 210 | cls_id = cls_id[cls_id != -1] 211 | cls_list = torch.arange(cls_id.min(), cls_id.max() + 1) 212 | if cls_list.min() > 10: 213 | min_cls = cls_list.min() 214 | l1 = torch.arange(min_cls - 10, min_cls) 215 | cls_list = torch.cat([l1, cls_list], dim=0) 216 | 217 | if cls_list.max() < 81: 218 | max_cls = cls_list.max() + 1 219 | l2 = torch.arange(max_cls, max_cls + 10) 220 | cls_list = torch.cat([cls_list, l2], dim=0) 221 | 222 | for obj in cls_list: 223 | sfp_interval[int(obj) - 1] = obj 224 | 225 | lookup_list.append(sfp_interval) 226 | 227 | lookup_list = torch.stack(lookup_list, dim=0) # T*CK:91 228 | selector_list = torch.ones_like(lookup_list) # T*CK:91 229 | 
selector_list[lookup_list == -1] = 0 230 | 231 | # (num_frames, 3, 720, 1280) 232 | image_list = torch.stack(image_list, dim=0) 233 | homo_mat_list = np.stack(homo_mat_list, axis=0) 234 | # (K:num_objects, T:num_frames, C:1, H:180, W:320) 235 | target_dilated_hm_list = target_dilated_hm_list.unsqueeze(2) 236 | 237 | data = {} 238 | data['rgb'] = image_list 239 | data['target_dilated_hm'] = target_dilated_hm_list 240 | data['cls_gt'] = cls_gt 241 | data['gt_homo'] = homo_mat_list 242 | data['selector'] = selector_list 243 | data['lookup'] = lookup_list 244 | data['info'] = info 245 | 246 | return data 247 | 248 | 249 | if __name__ == "__main__": 250 | 251 | worldcup_test_loader = WorldcupTestDataset( 252 | root='dataset/soccer_worldcup_2014/soccer_data', data_type='test', mode='test', num_objects=4) 253 | 254 | import shutil 255 | cnt = 1 256 | visual_dir = osp.join('visual', 'worldcup_test') 257 | if osp.exists(visual_dir): 258 | print(f'Remove directory: {visual_dir}') 259 | shutil.rmtree(visual_dir) 260 | print(f'Create directory: {visual_dir}') 261 | os.makedirs(visual_dir, exist_ok=True) 262 | 263 | denorm = utils.UnNormalize( 264 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 265 | 266 | for data in worldcup_test_loader: 267 | image = data['rgb'] 268 | mask = data['target_dilated_hm'] 269 | cls_gt = data['cls_gt'] 270 | # === debug === 271 | print(f'number of frames: {cls_gt.shape[0]}') 272 | for j in range(cls_gt.shape[0]): 273 | print(torch.unique(cls_gt[j])) 274 | plt.imsave(osp.join(visual_dir, 'Video%d_Seg%03d.jpg' % 275 | (cnt, j + 1)), utils.to_numpy(cls_gt[j]), vmin=0, vmax=91) 276 | plt.imsave(osp.join(visual_dir, 'Frame_%03d.jpg' % 277 | (j + 1)), utils.im_to_numpy(denorm(image[j]))) 278 | for i in range(91): 279 | if np.any(utils.to_numpy(mask[i, j, 0])): 280 | plt.imsave(osp.join(visual_dir, '%d_dilated_mask_obj%d.jpg' % ( 281 | j + 1, i + 1)), utils.to_numpy(mask[i, j, 0])) 282 | if j == 3: 283 | assert False 284 | cnt += 1 285 | pass 286 | -------------------------------------------------------------------------------- /ts_worldcup_test_loader.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import os.path as osp 4 | import numpy as np 5 | from PIL import Image 6 | import torch 7 | from torch.utils import data 8 | from torchvision import transforms 9 | import matplotlib.pyplot as plt 10 | import skimage.segmentation as ss 11 | from typing import Optional 12 | import utils 13 | 14 | 15 | class MainTestDataset(data.Dataset): 16 | 17 | def __init__(self, root, data_type, mode, num_objects, noise_trans: Optional[float] = None, noise_rotate: Optional[float] = None, target_video: list = [], sfp_finetuned: Optional[bool] = False): 18 | 19 | self.frame_h = 720 20 | self.frame_w = 1280 21 | self.root = root 22 | self.data_type = data_type 23 | self.mode = mode 24 | self.num_objects = num_objects 25 | # self.num_objects = 91 26 | self.noise_trans = noise_trans 27 | self.noise_rotate = noise_rotate 28 | self.sfp_finetuned = sfp_finetuned 29 | 30 | sequence_interval = '80_95' 31 | 32 | if self.sfp_finetuned: 33 | sfp_out_path = 'SingleFramePredict_finetuned_with_normalized' 34 | else: 35 | sfp_out_path = 'SingleFramePredict_with_normalized' 36 | 37 | self.image_path = osp.join( 38 | self.root, 'Dataset', sequence_interval) 39 | self.anno_path = osp.join( 40 | self.root, 'Annotations', sequence_interval) 41 | self.sfp_path = osp.join( 42 | self.root, sfp_out_path, sequence_interval) 43 | 44 | imgset_path 
= osp.join(self.root, self.data_type) 45 | self.videos = [] 46 | self.num_frames = {} 47 | self.num_homographies = {} 48 | self.frames = {} 49 | self.homographies = {} 50 | self.segs = {} 51 | 52 | with open(imgset_path + '.txt', 'r') as lines: 53 | for line in lines: 54 | _video = line.rstrip('\n') 55 | 56 | if target_video and not _video in target_video: 57 | continue 58 | 59 | self.videos.append(_video) 60 | self.num_frames[_video] = len( 61 | glob.glob(osp.join(self.image_path, _video, '*.jpg'))) 62 | self.num_homographies[_video] = len( 63 | glob.glob(osp.join(self.anno_path, _video, '*_homography.npy'))) 64 | 65 | frames = sorted(os.listdir( 66 | osp.join(self.image_path, _video))) 67 | self.frames[_video] = frames 68 | 69 | homographies = sorted(os.listdir( 70 | osp.join(self.anno_path, _video))) 71 | self.homographies[_video] = homographies 72 | 73 | gt_segs = sorted(os.listdir( 74 | osp.join(self.sfp_path, _video))) 75 | self.segs[_video] = gt_segs 76 | 77 | self.preprocess = transforms.Compose([ 78 | transforms.ToTensor(), 79 | transforms.Normalize([0.485, 0.456, 0.406], 80 | [0.229, 0.224, 0.225]), # ImageNet 81 | ]) 82 | 83 | def __len__(self): 84 | 85 | return len(self.videos) 86 | 87 | def __getitem__(self, index): 88 | 89 | _video_name = self.videos[index] 90 | _frames = self.frames[_video_name] 91 | _homographies = self.homographies[_video_name] 92 | _segs = self.segs[_video_name] 93 | info = {} 94 | info['num_objects'] = self.num_objects 95 | info['name'] = _video_name 96 | info['frames'] = [] 97 | info['num_frames'] = self.num_frames[_video_name] 98 | info['single_frame_path'] = self.sfp_path 99 | 100 | template_grid = utils.gen_template_grid() # template grid shape (91, 3) 101 | 102 | image_list = [] 103 | homo_mat_list = [] 104 | dilated_hm_list = [] 105 | hm_list = [] 106 | gt_seg_list = [] 107 | 108 | for f_idx in range(self.num_frames[_video_name]): 109 | jpg_image = _frames[f_idx] 110 | npy_matrix = _homographies[f_idx] 111 | png_seg = _segs[f_idx] 112 | info['frames'].append( 113 | osp.join(self.image_path, _video_name, jpg_image)) 114 | image = np.array(Image.open( 115 | osp.join(self.image_path, _video_name, jpg_image))) 116 | gt_h = np.load(osp.join(self.anno_path, _video_name, npy_matrix)) 117 | 118 | sfp_seg = np.array(Image.open( 119 | osp.join(self.sfp_path, _video_name, png_seg)).convert('P')) 120 | gt_seg_list.append(sfp_seg) 121 | 122 | # warp grid shape (91, 3) 123 | warp_image, warp_grid, homo_mat = utils.gen_im_whole_grid( 124 | self.mode, image, f_idx, gt_h, template_grid, self.noise_trans, self.noise_rotate, index) 125 | 126 | # Each keypoints is considered as an object 127 | num_pts = warp_grid.shape[0] 128 | pil_image = Image.fromarray(warp_image) 129 | 130 | image_tensor = self.preprocess(pil_image) 131 | image_list.append(image_tensor) 132 | homo_mat_list.append(homo_mat) 133 | 134 | # By default, all keypoints belong to background 135 | # C*H*W, C:91, exclude background class 136 | heatmaps = np.zeros( 137 | (num_pts, self.frame_h // 4, self.frame_w // 4), dtype=np.float32) 138 | dilated_heatmaps = np.zeros_like(heatmaps) 139 | for keypts_label in range(num_pts): 140 | if np.isnan(warp_grid[keypts_label, 0]) and np.isnan(warp_grid[keypts_label, 1]): 141 | continue 142 | px = np.rint(warp_grid[keypts_label, 0] / 4).astype(np.int32) 143 | py = np.rint(warp_grid[keypts_label, 1] / 4).astype(np.int32) 144 | cls = int(warp_grid[keypts_label, 2]) - 1 145 | if 0 <= px < (self.frame_w // 4) and 0 <= py < (self.frame_h // 4): 146 | heatmaps[cls][py, px] 
= warp_grid[keypts_label, 2] 147 | dilated_heatmaps[cls] = ss.expand_labels( 148 | heatmaps[cls], distance=5) 149 | 150 | dilated_hm_list.append(dilated_heatmaps) 151 | hm_list.append(heatmaps) 152 | 153 | # TODO: use full gt segmentation info, only previous for memory management 154 | dilated_hm_list = np.stack( 155 | dilated_hm_list, axis=0) # num_frames*91*H*W 156 | T, CK, H, W = dilated_hm_list.shape 157 | hm_list = np.stack(hm_list, axis=0) 158 | 159 | # (CK:num_objects, T:num_frames, H:180, W:320) 160 | target_dilated_hm_list = torch.zeros((CK, T, H, W)) 161 | target_hm_list = torch.zeros_like(target_dilated_hm_list) 162 | cls_gt = torch.zeros((self.num_frames[_video_name], H, W)) 163 | lookup_list = [] 164 | for f in range(self.num_frames[_video_name]): 165 | class_labels = np.ones(num_pts, dtype=np.float32) * -1 166 | # Keypoints that appear in each frame 167 | labels = np.unique(dilated_hm_list[f]) 168 | labels = labels[labels != 0] # Remove background class 169 | for obj in labels: 170 | class_labels[int(obj) - 1] = obj 171 | 172 | for idx, obj in enumerate(class_labels): 173 | if obj != -1: 174 | target_dilated_hm = dilated_hm_list[f, int(obj) - 1].copy() 175 | target_dilated_hm[target_dilated_hm == obj] = 1 176 | target_dilated_hm_tensor = utils.to_torch( 177 | target_dilated_hm) 178 | target_dilated_hm_list[int( 179 | obj) - 1, f] = target_dilated_hm_tensor 180 | 181 | target_hm = hm_list[f, int(obj) - 1].copy() 182 | target_hm[target_hm == obj] = 1 183 | target_hm_tensor = utils.to_torch(target_hm) 184 | target_hm_list[int(obj) - 1, f] = target_hm_tensor 185 | 186 | # TODO: union of all target objects of ground truth segmentation 187 | for idx, obj in enumerate(class_labels): 188 | if obj != -1: 189 | cls_gt[target_hm_list[idx] == 190 | 1] = torch.tensor(obj).float() 191 | 192 | # TODO: use full single frame predict segmentation info, only previous for memory management 193 | gt_seg_list = np.stack(gt_seg_list, axis=0) # num_frames*H*W 194 | for f in range(self.num_frames[_video_name]): 195 | class_labels = np.ones(num_pts, dtype=np.float32) * -1 196 | # Keypoints that appear in each single-frame prediction 197 | labels = np.unique(gt_seg_list[f]) 198 | labels = labels[labels != 0] # Remove background class 199 | for obj in labels: 200 | class_labels[int(obj) - 1] = obj 201 | 202 | sfp_lookup = utils.to_torch(class_labels) 203 | 204 | # TODO: choose the range of classes for class conditioning 205 | sfp_interval = torch.ones_like(sfp_lookup) * -1 206 | cls_id = torch.unique(sfp_lookup) 207 | cls_id = cls_id[cls_id != -1] 208 | cls_list = torch.arange(cls_id.min(), cls_id.max() + 1) 209 | 210 | if cls_list.min() > 10: 211 | min_cls = cls_list.min() 212 | l1 = torch.arange(min_cls - 10, min_cls) 213 | cls_list = torch.cat([l1, cls_list], dim=0) 214 | 215 | if cls_list.max() < 81: 216 | max_cls = cls_list.max() + 1 217 | l2 = torch.arange(max_cls, max_cls + 10) 218 | cls_list = torch.cat([cls_list, l2], dim=0) 219 | 220 | for obj in cls_list: 221 | sfp_interval[int(obj) - 1] = obj 222 | 223 | lookup_list.append(sfp_interval) 224 | 225 | lookup_list = torch.stack(lookup_list, dim=0) # T*CK:91 226 | selector_list = torch.ones_like(lookup_list) # T*CK:91 227 | selector_list[lookup_list == -1] = 0 228 | 229 | # (num_frames, 3, 720, 1280) 230 | image_list = torch.stack(image_list, dim=0) 231 | homo_mat_list = np.stack(homo_mat_list, axis=0) 232 | # (K:num_objects, T:num_frames, C:1, H:180, W:320) 233 | target_dilated_hm_list = target_dilated_hm_list.unsqueeze(2) 234 | 235 
| data = {} 236 | data['rgb'] = image_list 237 | data['target_dilated_hm'] = target_dilated_hm_list 238 | data['cls_gt'] = cls_gt 239 | data['gt_homo'] = homo_mat_list 240 | data['selector'] = selector_list 241 | data['lookup'] = lookup_list 242 | data['info'] = info 243 | 244 | return data 245 | 246 | 247 | if __name__ == "__main__": 248 | 249 | main_test_loader = MainTestDataset( 250 | root='dataset/WorldCup_2014_2018', data_type='test', mode='test', num_objects=4) 251 | 252 | import shutil 253 | cnt = 1 254 | visual_dir = osp.join('visual', 'main_test') 255 | if osp.exists(visual_dir): 256 | print(f'Remove directory: {visual_dir}') 257 | shutil.rmtree(visual_dir) 258 | print(f'Create directory: {visual_dir}') 259 | os.makedirs(visual_dir, exist_ok=True) 260 | 261 | denorm = utils.UnNormalize( 262 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 263 | 264 | for data in main_test_loader: 265 | image = data['rgb'] 266 | mask = data['target_dilated_hm'] 267 | cls_gt = data['cls_gt'] 268 | # === debug === 269 | print(f'number of frames: {cls_gt.shape[0]}') 270 | for j in range(cls_gt.shape[0]): 271 | print(torch.unique(cls_gt[j])) 272 | plt.imsave(osp.join(visual_dir, 'Video%d_Seg%03d.jpg' % 273 | (cnt, j + 1)), utils.to_numpy(cls_gt[j]), vmin=0, vmax=91) 274 | plt.imsave(osp.join(visual_dir, 'Frame_%03d.jpg' % 275 | (j + 1)), utils.im_to_numpy(denorm(image[j]))) 276 | for i in range(91): 277 | if np.any(utils.to_numpy(mask[i, j, 0])): 278 | plt.imsave(osp.join(visual_dir, '%d_dilated_mask_obj%d.jpg' % ( 279 | j + 1, i + 1)), utils.to_numpy(mask[i, j, 0])) 280 | cnt += 1 281 | assert False 282 | pass 283 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import cv2 4 | import math 5 | import matplotlib.pyplot as plt 6 | from matplotlib.patches import Circle 7 | import random 8 | from PIL import Image 9 | import os.path as osp 10 | 11 | 12 | def get_mean_std(loader): 13 | # Calc the new mean and standard deviation if want to train from scratch on your own dataset 14 | # Var[x] = E[x**2] - E[x]**2 15 | channels_sum, channels_squared_sum, num_batches = 0, 0, 0 16 | 17 | for data, _ in loader: 18 | channels_sum += torch.mean(data, dim=[0, 2, 3]) 19 | channels_squared_sum += torch.mean(data ** 2, dim=[0, 2, 3]) 20 | num_batches += 1 21 | 22 | mean = channels_sum / num_batches 23 | std = (channels_squared_sum / num_batches - mean ** 2) ** 0.5 24 | 25 | return mean, std 26 | 27 | 28 | class UnNormalize(object): 29 | def __init__(self, mean, std): 30 | self.mean = mean 31 | self.std = std 32 | 33 | def __call__(self, tensor): 34 | """ 35 | Args: 36 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 37 | Returns: 38 | Tensor: Normalized image. 
39 | """ 40 | for t, m, s in zip(tensor, self.mean, self.std): 41 | t.mul_(s).add_(m) 42 | # The normalize code -> t.sub_(m).div_(s) 43 | return tensor 44 | 45 | 46 | def isnan(x): 47 | return x != x 48 | 49 | 50 | def hasnan(x): 51 | return isnan(x).any() 52 | 53 | 54 | def to_numpy(tensor): 55 | if torch.is_tensor(tensor): 56 | return tensor.detach().cpu().numpy() 57 | elif type(tensor).__module__ != 'numpy': 58 | raise ValueError("Cannot convert {} to numpy array" 59 | .format(type(tensor))) 60 | return tensor 61 | 62 | 63 | def to_torch(ndarray): 64 | if type(ndarray).__module__ == 'numpy': 65 | return torch.from_numpy(ndarray.copy()) 66 | elif not torch.is_tensor(ndarray): 67 | raise ValueError("Cannot convert {} to torch tensor" 68 | .format(type(ndarray))) 69 | return ndarray 70 | 71 | 72 | def im_to_numpy(img): 73 | img = to_numpy(img) 74 | img = np.transpose(img, (1, 2, 0)) # H*W*C 75 | return img 76 | 77 | 78 | def im_to_torch(img): 79 | img = np.transpose(img, (2, 0, 1)) # C*H*W 80 | img = to_torch(img).float() 81 | if img.max() > 1: 82 | img /= 255 83 | return img 84 | 85 | 86 | def reseed(seed): 87 | 88 | # TODO: Set seed to ensure the same initialization 89 | torch.manual_seed(seed) 90 | np.random.seed(seed) 91 | random.seed(seed) 92 | 93 | # if using cuda 94 | # torch.cuda.manual_seed_all(seed) 95 | # torch.backends.cudnn.enabled = False 96 | # torch.backends.cudnn.deterministic = True 97 | # torch.backends.cudnn.benchmark = False 98 | 99 | 100 | def gen_template_grid(): 101 | # === set uniform grid === 102 | # field_dim_x, field_dim_y = 105.000552, 68.003928 # in meter 103 | field_dim_x, field_dim_y = 114.83, 74.37 # in yard 104 | # field_dim_x, field_dim_y = 115, 74 # in yard 105 | nx, ny = (13, 7) 106 | x = np.linspace(0, field_dim_x, nx) 107 | y = np.linspace(0, field_dim_y, ny) 108 | xv, yv = np.meshgrid(x, y, indexing='ij') 109 | uniform_grid = np.stack((xv, yv), axis=2).reshape(-1, 2) 110 | uniform_grid = np.concatenate((uniform_grid, np.ones( 111 | (uniform_grid.shape[0], 1))), axis=1) # top2bottom, left2right 112 | # TODO: class labels in template, each keypoint is (x, y, c), where c is a label starting from 1 113 | for idx, pts in enumerate(uniform_grid): 114 | pts[2] = idx + 1 # keypoints label 115 | return uniform_grid 116 | 117 | 118 | def gen_im_partial_grid(mode, frame, gt_homo, template, noise_trans, noise_rotate, index): 119 | # === Warping image and grid for single-frame method === 120 | frame_w, frame_h = frame.shape[1], frame.shape[0] 121 | 122 | unigrid_copy = template.copy() # (91, 3) 123 | unigrid_copy[:, 2] = 1 124 | 125 | gt_warp_grid = unigrid_copy @ np.linalg.inv(gt_homo.T) 126 | gt_warp_grid /= gt_warp_grid[:, 2, np.newaxis] 127 | 128 | # assign pixels class label, 1-91 129 | for idx, pts in enumerate(gt_warp_grid): 130 | pts[2] = idx + 1 # keypoints label 131 | 132 | # TODO: apply small random noise to the gt homography and warp the image accordingly 133 | if mode == 'train' and random.random() < 0.5: 134 | # if False: 135 | # if True: 136 | # only store those points in image view 137 | l1, l2, label = [], [], [] 138 | for pts, t_pts, sub_pts in zip(gt_warp_grid, unigrid_copy, template): 139 | if 0 <= pts[0] < frame_w and 0 <= pts[1] < frame_h: 140 | l1.append(pts) 141 | l2.append(t_pts) 142 | label.append(sub_pts[2]) # has labels 143 | src_grid = np.array(l1) 144 | tmp_grid = np.array(l2) 145 | class_labels = np.array(label) 146 | 147 | # TODO: do homography augmentation (around the image center?) 
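# The perturbation below composes a random scale in [0.8, 1.05] (occasionally with a fixed
# frame_h/10 or frame_h/6 offset folded into the scaling matrix), a random translation within
# +/- noise_trans pixels, and a random rotation within +/- noise_rotate (scaled by 2*pi) as
# pert_homo = R @ gt_homo @ S @ T @ R^T, normalized so that pert_homo[2, 2] == 1; the keypoint
# grid is then regenerated from pert_homo, and the image is re-warped with a homography
# re-fitted to the surviving point pairs by RANSAC.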
148 | center_x, center_y = frame_w / 2, frame_h / 2 149 | noise_scale = random.uniform(0.8, 1.05) 150 | scaling_mat = np.eye(3).astype(np.float32) 151 | scaling_mat[0, 0] = noise_scale 152 | scaling_mat[1, 1] = noise_scale 153 | if random.random() < 0.5: 154 | if random.random() < 0.5: 155 | scaling_mat[0, 2] = frame_h // 10 156 | scaling_mat[1, 2] = frame_h // 10 157 | else: 158 | scaling_mat[0, 2] = frame_h // 6 159 | scaling_mat[1, 2] = frame_h // 6 160 | tx = random.uniform(-noise_trans, noise_trans) 161 | ty = random.uniform(-noise_trans, noise_trans) 162 | translate_mat = np.array([[1, 0, tx], 163 | [0, 1, ty], 164 | [0, 0, 1]], dtype=np.float32) 165 | 166 | theta = random.uniform(-noise_rotate, noise_rotate) 167 | deflection = 1.0 168 | theta = theta * 2.0 * deflection * np.pi # in radians 169 | c, s = np.cos(theta), np.sin(theta) 170 | rotate_mat = np.array([[c, -s, 0], 171 | [s, c, 0], 172 | [0, 0, 1]], dtype=np.float32) 173 | 174 | pert_homo = rotate_mat @ gt_homo @ scaling_mat @ translate_mat @ rotate_mat.T 175 | pert_homo /= pert_homo[2, 2] 176 | 177 | # shape is (?, 3) 178 | pert_src_grid = tmp_grid @ np.linalg.inv(pert_homo.T) 179 | pert_src_grid /= pert_src_grid[:, 2, np.newaxis] 180 | 181 | for pts, cls in zip(pert_src_grid, class_labels): 182 | pts[2] = cls # assign keypoints label 183 | 184 | src_list, dst_list = [], [] 185 | for _src, _dst in zip(src_grid, pert_src_grid): 186 | # warped points may fall outside the image after perturbation 187 | if 0 <= _dst[0] < frame_w and 0 <= _dst[1] < frame_h: 188 | src_list.append(_src) 189 | dst_list.append(_dst) 190 | src_pts = np.array(src_list) 191 | dst_pts = np.array(dst_list) 192 | if src_pts.shape[0] >= 4 and dst_pts.shape[0] >= 4: 193 | new_homo_mat, mask = cv2.findHomography( 194 | src_pts[:, :2].reshape(-1, 1, 2), dst_pts[:, :2].reshape(-1, 1, 2), cv2.RANSAC, 5) 195 | if new_homo_mat is not None: 196 | warp_image = cv2.warpPerspective( 197 | frame, new_homo_mat, (frame_w, frame_h), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_CONSTANT, borderValue=(0)) 198 | warp_grid = dst_pts.copy() 199 | homo_mat = pert_homo 200 | else: 201 | warp_image = None 202 | warp_grid = None 203 | else: 204 | warp_image = None 205 | warp_grid = None 206 | else: 207 | warp_image = frame.copy() 208 | grid_list = [] 209 | for ind, pts in enumerate(gt_warp_grid): 210 | if 0 <= pts[0] < frame_w and 0 <= pts[1] < frame_h: 211 | grid_list.append(pts) 212 | warp_grid = np.array(grid_list) 213 | homo_mat = gt_homo 214 | 215 | if warp_image is None and warp_grid is None: 216 | warp_image = frame.copy() 217 | grid_list = [] 218 | for ind, pts in enumerate(gt_warp_grid): 219 | if 0 <= pts[0] < frame_w and 0 <= pts[1] < frame_h: 220 | grid_list.append(pts) 221 | warp_grid = np.array(grid_list) 222 | homo_mat = gt_homo 223 | 224 | return warp_image, warp_grid, homo_mat 225 | 226 | 227 | def gen_im_whole_grid(mode, frame, f_idx, gt_homo, template, noise_trans, noise_rotate, index, vid_name=None): 228 | # === Warping image and grid for multi-frame method or coordinate regression === 229 | frame_w, frame_h = frame.shape[1], frame.shape[0] 230 | 231 | unigrid_copy = template.copy() # (91, 3) 232 | unigrid_copy[:, 2] = 1 233 | 234 | gt_warp_grid = unigrid_copy @ np.linalg.inv(gt_homo.T) 235 | gt_warp_grid /= gt_warp_grid[:, 2, np.newaxis] 236 | 237 | # assign pixels class label, 1-91 238 | for idx, pts in enumerate(gt_warp_grid): 239 | pts[2] = idx + 1 # keypoints label 240 | 241 | # TODO: apply small random noise to the gt homography and warp the image 
accordingly 242 | if mode == 'train' and (f_idx == 1 or f_idx == 2): # hard level 243 | # only store those points in image view 244 | l1, l2, label = [], [], [] 245 | for pts, t_pts, sub_pts in zip(gt_warp_grid, unigrid_copy, template): 246 | if 0 <= pts[0] < frame_w and 0 <= pts[1] < frame_h: 247 | l1.append(pts) 248 | l2.append(t_pts) 249 | label.append(sub_pts[2]) # has labels 250 | else: 251 | l1.append([float('nan'), float('nan'), -1.]) 252 | l2.append([float('nan'), float('nan'), 1.]) 253 | label.append(-1.) 254 | 255 | src_grid = np.array(l1) 256 | tmp_grid = np.array(l2) 257 | class_labels = np.array(label) 258 | 259 | # TODO: do homography augmentation (around the image center?) 260 | center_x, center_y = frame_w / 2, frame_h / 2 261 | noise_scale = random.uniform(0.8, 1.05) 262 | scaling_mat = np.eye(3).astype(np.float32) 263 | scaling_mat[0, 0] = noise_scale 264 | scaling_mat[1, 1] = noise_scale 265 | if random.random() < 0.5: 266 | if random.random() < 0.5: 267 | scaling_mat[0, 2] = frame_h // 10 268 | scaling_mat[1, 2] = frame_h // 10 269 | else: 270 | scaling_mat[0, 2] = frame_h // 6 271 | scaling_mat[1, 2] = frame_h // 6 272 | tx = random.uniform(-noise_trans, noise_trans) 273 | ty = random.uniform(-noise_trans, noise_trans) 274 | translate_mat = np.array([[1, 0, tx], 275 | [0, 1, ty], 276 | [0, 0, 1]], dtype=np.float32) 277 | 278 | theta = random.uniform(-noise_rotate, noise_rotate) 279 | deflection = 1.0 280 | theta = theta * 2.0 * deflection * np.pi # in radians 281 | c, s = np.cos(theta), np.sin(theta) 282 | rotate_mat = np.array([[c, -s, 0], 283 | [s, c, 0], 284 | [0, 0, 1]], dtype=np.float32) 285 | 286 | pert_homo = rotate_mat @ gt_homo @ scaling_mat @ translate_mat @ rotate_mat.T 287 | pert_homo /= pert_homo[2, 2] 288 | 289 | # shape is (?, 3) 290 | pert_src_grid = tmp_grid @ np.linalg.inv(pert_homo.T) 291 | pert_src_grid /= pert_src_grid[:, 2, np.newaxis] 292 | 293 | for pts, cls in zip(pert_src_grid, class_labels): 294 | pts[2] = cls # assign keypoints label 295 | 296 | src_list, dst_list = [], [] 297 | for _src, _dst in zip(src_grid, pert_src_grid): 298 | if np.isnan(_dst).any(): 299 | continue 300 | # warped points may fall outside the image after perturbation 301 | if 0 <= _dst[0] < frame_w and 0 <= _dst[1] < frame_h: 302 | src_list.append(_src) 303 | dst_list.append(_dst) 304 | else: 305 | _dst[0] = float('nan') 306 | _dst[1] = float('nan') 307 | _dst[2] = -1. 
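# With at least four surviving correspondences, re-fit the image-warping homography via
# cv2.findHomography with RANSAC; if too few pairs remain or the fit fails, the code
# below falls back to the unperturbed frame and grid.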
308 | src_pts = np.array(src_list) 309 | dst_pts = np.array(dst_list) 310 | if src_pts.shape[0] >= 4 and dst_pts.shape[0] >= 4: 311 | new_homo_mat, mask = cv2.findHomography( 312 | src_pts[:, :2].reshape(-1, 1, 2), dst_pts[:, :2].reshape(-1, 1, 2), cv2.RANSAC, 5) 313 | if new_homo_mat is not None: 314 | warp_image = cv2.warpPerspective( 315 | frame, new_homo_mat, (frame_w, frame_h), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_CONSTANT, borderValue=(0)) 316 | warp_grid = pert_src_grid.copy() 317 | homo_mat = pert_homo 318 | else: 319 | warp_image = None 320 | warp_grid = None 321 | else: 322 | warp_image = None 323 | warp_grid = None 324 | else: 325 | warp_image = frame.copy() 326 | grid_list = [] 327 | for ind, pts in enumerate(gt_warp_grid): 328 | if 0 <= pts[0] < frame_w and 0 <= pts[1] < frame_h: 329 | grid_list.append(pts) 330 | else: 331 | grid_list.append([float('nan'), float('nan'), -1.]) 332 | warp_grid = np.array(grid_list) 333 | homo_mat = gt_homo 334 | 335 | if warp_image is None and warp_grid is None: 336 | warp_image = frame.copy() 337 | grid_list = [] 338 | for ind, pts in enumerate(gt_warp_grid): 339 | if 0 <= pts[0] < frame_w and 0 <= pts[1] < frame_h: 340 | grid_list.append(pts) 341 | else: 342 | grid_list.append([float('nan'), float('nan'), -1.]) 343 | warp_grid = np.array(grid_list) 344 | homo_mat = gt_homo 345 | 346 | return warp_image, warp_grid, homo_mat 347 | 348 | 349 | def put_lrflip_augmentation(frame, unigrid): 350 | 351 | frame_w, frame_h = frame.size 352 | npy_image = np.array(frame) 353 | 354 | flipped_img = np.fliplr(npy_image) 355 | 356 | # TODO: grid flipping and re-assign pixel class labels, 1-91 357 | for ind, pts in enumerate(unigrid): 358 | pts[0] = frame_w - pts[0] 359 | col = (pts[2] - 1) // 7 # get each column of uniform grid 360 | pts[2] = pts[2] - (col - 6) * 2 * 7 # keypoints label 361 | 362 | return Image.fromarray(flipped_img), unigrid 363 | -------------------------------------------------------------------------------- /robust/test.py: -------------------------------------------------------------------------------- 1 | ''' 2 | testing file for Nie et al. 
(A robust and efficient framework for sports-field registration) 3 | ''' 4 | import sys 5 | sys.path.append('..') 6 | from options import CustomOptions 7 | from models.model import EncDec 8 | from worldcup_loader import PublicWorldCupDataset 9 | from ts_worldcup_loader import MainTestSVDataset 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import torch.optim as optim 14 | from torch.utils.data import DataLoader 15 | import numpy as np 16 | import os 17 | import os.path as osp 18 | import time 19 | from PIL import Image 20 | from tqdm import tqdm 21 | import cv2 22 | import matplotlib.pyplot as plt 23 | import utils 24 | import metrics 25 | import skimage.segmentation as ss 26 | 27 | 28 | # Get input arguments 29 | opt = CustomOptions(train=False) 30 | opt = opt.parse() 31 | 32 | # Setup GPU 33 | os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_ids 34 | print('CUDA_VISIBLE_DEVICES: %s' % opt.gpu_ids) 35 | device = torch.device('cuda:0') 36 | print('device: %s' % device) 37 | 38 | 39 | def calc_euclidean_distance(a, b, _norm=np.linalg.norm, axis=None): 40 | return _norm(a - b, axis=axis) 41 | 42 | 43 | def my_mseloss(gt, pred): 44 | return torch.mean(torch.square(pred - gt)) 45 | 46 | 47 | def postprocessing(scores, pred, target, num_classes, nms_thres): 48 | 49 | # TODO: decode the heatmaps into keypoint sets using non-maximum suppression 50 | pred_cls_dict = {k: [] for k in range(1, num_classes)} 51 | 52 | for cls in range(1, num_classes): 53 | pred_inds = pred == cls 54 | 55 | # the current class does not appear in this heatmap 56 | if not np.any(pred_inds): 57 | continue 58 | 59 | values = scores[pred_inds] 60 | max_score = values.max() 61 | max_index = values.argmax() 62 | 63 | indices = np.where(pred_inds) 64 | coords = list(zip(indices[0], indices[1])) 65 | 66 | # keep the keypoint only if its max confidence exceeds the threshold 67 | if max_score >= nms_thres: 68 | pred_cls_dict[cls].append(max_score) 69 | pred_cls_dict[cls].append(coords[max_index]) 70 | 71 | gt_cls_dict = {k: [] for k in range(1, num_classes)} 72 | for cls in range(1, num_classes): 73 | gt_inds = target == cls 74 | 75 | # the current class does not appear in this heatmap 76 | if not np.any(gt_inds): 77 | continue 78 | coords = np.argwhere(gt_inds)[0] 79 | 80 | # coordinate order is (y, x) 81 | gt_cls_dict[cls].append((coords[0], coords[1])) 82 | 83 | return gt_cls_dict, pred_cls_dict 84 | 85 | 86 | def calc_keypts_metrics(gt_cls_dict, pred_cls_dict, pr_thres): 87 | 88 | num_gt_pos = 0 89 | num_pred_pos = 0 90 | num_both_keypts_appear = 0 91 | tp = 0 92 | mse_loss = 0.0 93 | 94 | for (gk, gv), (pk, pv) in zip(gt_cls_dict.items(), pred_cls_dict.items()): 95 | if gv: 96 | num_gt_pos += 1 97 | 98 | if pv: 99 | num_pred_pos += 1 100 | 101 | if gv and pv: 102 | num_both_keypts_appear += 1 103 | mse_loss += my_mseloss(torch.FloatTensor(gv[0]), 104 | torch.FloatTensor(pv[1])) 105 | 106 | if calc_euclidean_distance(np.array(gv[0]), np.array(pv[1])) <= pr_thres: 107 | tp += 1 108 | 109 | if num_both_keypts_appear == 0: 110 | return 0.0, 0.0, torch.tensor(0.0) # return a tensor so that .detach() works downstream 111 | return tp / num_pred_pos, tp / num_gt_pos, mse_loss / num_both_keypts_appear 112 | 113 | 114 | def class_mapping(rgb): 115 | 116 | # TODO: class mapping 117 | template = utils.gen_template_grid() # grid shape (91, 3), (x, y, label) 118 | src_pts = rgb.copy() 119 | cls_map_pts = [] 120 | 121 | for ind, elem in enumerate(src_pts): 122 | coords = np.where(elem[2] == template[:, 2])[0] # find correspondence 123 | 
cls_map_pts.append(template[coords[0]]) 124 | dst_pts = np.array(cls_map_pts, dtype=np.float32) 125 | 126 | return src_pts[:, :2], dst_pts[:, :2] 127 | 128 | 129 | def test(): 130 | 131 | num_classes = 92 132 | non_local = bool(opt.use_non_local) 133 | layers = 18 134 | 135 | # Initialize models 136 | model = EncDec(layers, num_classes, non_local).to(device) 137 | 138 | # Setup dataset 139 | if opt.train_stage == 0: 140 | # Load testing data 141 | print('Loading data from public world cup dataset...') 142 | test_dataset = PublicWorldCupDataset( 143 | root=opt.public_worldcup_root, 144 | data_type=opt.testset, 145 | mode='test' 146 | ) 147 | test_loader = DataLoader( 148 | dataset=test_dataset, 149 | batch_size=1, 150 | shuffle=False, 151 | num_workers=4, 152 | ) 153 | elif opt.train_stage == 1: 154 | # Load testing data 155 | print('Loading data from time sequence world cup dataset...') 156 | test_dataset = MainTestSVDataset( 157 | root=opt.custom_worldcup_root, 158 | data_type=opt.testset, 159 | mode='test' 160 | ) 161 | test_loader = DataLoader( 162 | dataset=test_dataset, 163 | batch_size=1, 164 | shuffle=False, 165 | num_workers=4, 166 | ) 167 | 168 | # Loss function 169 | class_weights = torch.ones(num_classes) * 100 170 | class_weights[0] = 1 171 | class_weights = class_weights.to(device) 172 | criterion = nn.CrossEntropyLoss( 173 | weight=class_weights) # TODO: put class weight 174 | 175 | # Set data path 176 | denorm = utils.UnNormalize( 177 | mean=[0.485, 0.456, 0.406], 178 | std=[0.229, 0.224, 0.225] 179 | ) 180 | exp_name_path = osp.join(opt.checkpoints_dir, opt.name) 181 | test_visual_dir = osp.join(exp_name_path, 'imgs', 'test_visual') 182 | os.makedirs(test_visual_dir, exist_ok=True) 183 | 184 | iou_visual_dir = osp.join(test_visual_dir, 'iou') 185 | os.makedirs(iou_visual_dir, exist_ok=True) 186 | 187 | homo_visual_dir = osp.join(test_visual_dir, 'homography') 188 | os.makedirs(homo_visual_dir, exist_ok=True) 189 | 190 | field_model = Image.open( 191 | osp.join(opt.template_path, 'worldcup_field_model.png')) 192 | 193 | # TODO:: Load pretrained model or resume training 194 | if len(opt.ckpt_path) > 0: 195 | load_weights_path = opt.ckpt_path 196 | print('Loading weights: ', load_weights_path) 197 | assert osp.isfile(load_weights_path), 'Error: no checkpoints found' 198 | checkpoint = torch.load(load_weights_path) 199 | model.load_state_dict(checkpoint['model_state_dict']) 200 | epoch = checkpoint['epoch'] 201 | print('Checkpoint Epoch: ', epoch) 202 | 203 | if opt.sfp_finetuned and opt.train_stage == 1: 204 | sfp_out_path = 'SingleFramePredict_finetuned_with_normalized' 205 | 206 | elif not opt.sfp_finetuned and opt.train_stage == 1: 207 | sfp_out_path = 'SingleFramePredict_with_normalized' 208 | 209 | elif not opt.sfp_finetuned and opt.train_stage == 0: 210 | # for test on worldcup dataset 211 | sfp_out_path = 'robust_worldcup_testset_dilated' 212 | 213 | print("Testing...") 214 | model.eval() 215 | batch_celoss = 0.0 216 | batch_l2loss = 0.0 217 | precision_list = [] 218 | recall_list = [] 219 | iou_part_list = [] 220 | iou_whole_list = [] 221 | proj_error_list = [] 222 | reproj_error_list = [] 223 | test_progress_bar = tqdm( 224 | enumerate(test_loader), total=len(test_loader), leave=False) 225 | test_progress_bar.set_description( 226 | f'Epoch: {epoch}/{opt.train_epochs}') 227 | 228 | cnt1 = 0 229 | cnt2 = 0 230 | cnt3 = 0 231 | cnt4 = 0 232 | cnt5 = 0 233 | cnt6 = 0 234 | cnt7 = 0 235 | cnt8 = 0 236 | cnt9 = 0 237 | cnt10 = 0 238 | 239 | with torch.no_grad(): 240 | for 
step, (image, gt_heatmap, target, gt_homo) in test_progress_bar: 241 | image = image.to(device) 242 | gt_heatmap = gt_heatmap.to(device).long() 243 | 244 | pred_heatmap = model(image) 245 | 246 | # (B, 92, 180, 320), (B, 180, 320) 247 | loss = criterion(pred_heatmap, gt_heatmap) 248 | 249 | pred_heatmap = torch.softmax(pred_heatmap, dim=1) 250 | scores, pred_heatmap = torch.max(pred_heatmap, dim=1) 251 | scores = scores[0].detach().cpu().numpy() 252 | pred_heatmap = pred_heatmap[0].detach().cpu().numpy() 253 | gt_heatmap = gt_heatmap[0].cpu().numpy() 254 | target = target[0].cpu().numpy() 255 | gt_homo = gt_homo[0].cpu().numpy() 256 | 257 | gt_cls_dict, pred_cls_dict = postprocessing( 258 | scores, pred_heatmap, target, num_classes, opt.nms_thres) 259 | 260 | p, r, loss2 = calc_keypts_metrics( 261 | gt_cls_dict, pred_cls_dict, opt.pr_thres) 262 | 263 | precision_list.append(p) 264 | recall_list.append(r) 265 | 266 | batch_celoss += loss.detach() 267 | batch_l2loss += loss2.detach() 268 | 269 | # TODO: log loss 270 | if step % 10 == 9: 271 | batch_celoss /= 10 272 | print('Step: {}/{}\tTesting CE Loss: {:.4f}'.format(step, 273 | len(test_loader), batch_celoss)) 274 | batch_celoss = 0.0 275 | 276 | batch_l2loss /= 10 277 | print('Step: {}/{}\tTesting MSE Loss: {:.4f}'.format(step, 278 | len(test_loader), batch_l2loss)) 279 | batch_l2loss = 0.0 280 | 281 | image = utils.im_to_numpy(denorm(image[0])) 282 | 283 | # TODO: show keypoints visual result after postprocessing 284 | pred_keypoints = np.zeros_like( 285 | pred_heatmap, dtype=np.uint8) 286 | pred_rgb = [] 287 | for ind, (pk, pv) in enumerate(pred_cls_dict.items()): 288 | if pv: 289 | pred_keypoints[pv[1][0], pv[1][1]] = pk # (H, W) 290 | # camera view point sets (x, y, label) in rgb domain not heatmap domain 291 | pred_rgb.append([pv[1][1] * 4, pv[1][0] * 4, pk]) 292 | pred_rgb = np.asarray(pred_rgb, dtype=np.float32) # (?, 3) 293 | pred_homo = None 294 | if pred_rgb.shape[0] >= 4: # at least four points 295 | src_pts, dst_pts = class_mapping(pred_rgb) 296 | pred_homo, _ = cv2.findHomography( 297 | src_pts.reshape(-1, 1, 2), dst_pts.reshape(-1, 1, 2), cv2.RANSAC, 10) 298 | if pred_homo is not None: 299 | iou_part, gt_part_mask, pred_part_mask, part_merge_result = metrics.calc_iou_part( 300 | pred_homo, gt_homo, image, field_model) 301 | iou_part_list.append(iou_part) 302 | 303 | iou_whole, whole_line_merge_result, whole_fill_merge_result = metrics.calc_iou_whole_with_poly( 304 | pred_homo, gt_homo, image, field_model) 305 | iou_whole_list.append(iou_whole) 306 | 307 | proj_error = metrics.calc_proj_error( 308 | pred_homo, gt_homo, image, field_model) 309 | proj_error_list.append(proj_error) 310 | 311 | reproj_error = metrics.calc_reproj_error( 312 | pred_homo, gt_homo, image, field_model) 313 | reproj_error_list.append(reproj_error) 314 | else: 315 | print(f'pred homo is None at {step + 1}') 316 | iou_part_list.append(float('nan')) 317 | iou_whole_list.append(float('nan')) 318 | proj_error_list.append(float('nan')) 319 | reproj_error_list.append(float('nan')) 320 | else: 321 | print(f'less than four points at {step + 1}') 322 | iou_part_list.append(float('nan')) 323 | iou_whole_list.append(float('nan')) 324 | proj_error_list.append(float('nan')) 325 | reproj_error_list.append(float('nan')) 326 | 327 | # # TODO: save pic 328 | if True: 329 | # if False: 330 | pred_keypoints = ss.expand_labels( 331 | pred_keypoints, distance=5) 332 | 333 | # TODO: save undilated heatmap for each testing video 334 | if step < 89: 335 | if opt.train_stage == 
            # TODO: save undilated heatmap for each testing video
            if step < 89:
                if opt.train_stage == 1:
                    vid1_path = osp.join(
                        exp_name_path, sfp_out_path,
                        '80_95/left/2014_Match_Highlights1_clip_00007-1')
                elif opt.train_stage == 0:  # for test on worldcup dataset
                    vid1_path = osp.join(
                        exp_name_path, sfp_out_path, 'worldcup_2014')
                os.makedirs(vid1_path, exist_ok=True)
                cv2.imwrite(osp.join(vid1_path, '%05d.png' % cnt1),
                            pred_keypoints)
                cnt1 += 1

            elif 89 <= step < 172:
                if opt.train_stage == 1:
                    vid2_path = osp.join(
                        exp_name_path, sfp_out_path,
                        '80_95/left/2014_Match_Highlights2_clip_00006-1')
                elif opt.train_stage == 0:  # for test on worldcup dataset
                    vid2_path = osp.join(
                        exp_name_path, sfp_out_path, 'worldcup_2014')
                os.makedirs(vid2_path, exist_ok=True)
                if opt.train_stage == 0:
                    # the worldcup test set shares one folder and one counter
                    cv2.imwrite(osp.join(vid2_path, '%05d.png' % cnt1),
                                pred_keypoints)
                    cnt1 += 1
                elif opt.train_stage == 1:
                    cv2.imwrite(osp.join(vid2_path, '%05d.png' % cnt2),
                                pred_keypoints)
                    cnt2 += 1

            elif 172 <= step < 253:
                if opt.train_stage == 1:
                    vid3_path = osp.join(
                        exp_name_path, sfp_out_path,
                        '80_95/left/2014_Match_Highlights3_clip_00004-2')
                elif opt.train_stage == 0:
                    vid3_path = osp.join(
                        exp_name_path, sfp_out_path, 'worldcup_2014')
                os.makedirs(vid3_path, exist_ok=True)
                if opt.train_stage == 0:
                    cv2.imwrite(osp.join(vid3_path, '%05d.png' % cnt1),
                                pred_keypoints)
                    cnt1 += 1
                elif opt.train_stage == 1:
                    cv2.imwrite(osp.join(vid3_path, '%05d.png' % cnt3),
                                pred_keypoints)
                    cnt3 += 1

            # the remaining branches skip the stage check, presumably because
            # the public WorldCup test set is exhausted before step 253
            elif 253 <= step < 340:
                vid4_path = osp.join(
                    exp_name_path, sfp_out_path,
                    '80_95/left/2014_Match_Highlights3_clip_00009-1')
                os.makedirs(vid4_path, exist_ok=True)
                cv2.imwrite(osp.join(vid4_path, '%05d.png' % cnt4),
                            pred_keypoints)
                cnt4 += 1

            elif 340 <= step < 426:
                vid5_path = osp.join(
                    exp_name_path, sfp_out_path,
                    '80_95/left/2014_Match_Highlights6_clip_00013-1')
                os.makedirs(vid5_path, exist_ok=True)
                cv2.imwrite(osp.join(vid5_path, '%05d.png' % cnt5),
                            pred_keypoints)
                cnt5 += 1

            elif 426 <= step < 514:
                vid6_path = osp.join(
                    exp_name_path, sfp_out_path,
                    '80_95/right/2014_Match_Highlights3_clip_00013-1')
                os.makedirs(vid6_path, exist_ok=True)
                cv2.imwrite(osp.join(vid6_path, '%05d.png' % cnt6),
                            pred_keypoints)
                cnt6 += 1

            elif 514 <= step < 609:
                vid7_path = osp.join(
                    exp_name_path, sfp_out_path,
                    '80_95/right/2014_Match_Highlights5_clip_00010-2')
                os.makedirs(vid7_path, exist_ok=True)
                cv2.imwrite(osp.join(vid7_path, '%05d.png' % cnt7),
                            pred_keypoints)
                cnt7 += 1

            elif 609 <= step < 702:
                vid8_path = osp.join(
                    exp_name_path, sfp_out_path,
                    '80_95/right/2018_Match_Highlights5_clip_00016-1')
                os.makedirs(vid8_path, exist_ok=True)
                cv2.imwrite(osp.join(vid8_path, '%05d.png' % cnt8),
                            pred_keypoints)
                cnt8 += 1

            elif 702 <= step < 793:
                vid9_path = osp.join(
                    exp_name_path, sfp_out_path,
                    '80_95/right/2018_Match_Highlights6_clip_00015-2')
                os.makedirs(vid9_path, exist_ok=True)
                cv2.imwrite(osp.join(vid9_path, '%05d.png' % cnt9),
                            pred_keypoints)
                cnt9 += 1

            elif 793 <= step < 887:
                vid10_path = osp.join(
                    exp_name_path, sfp_out_path,
                    '80_95/right/2018_Match_Highlights6_clip_00023-3')
                os.makedirs(vid10_path, exist_ok=True)
                cv2.imwrite(osp.join(vid10_path, '%05d.png' % cnt10),
                            pred_keypoints)
                cnt10 += 1
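            # Colorized visualizations: vmin/vmax pin the colormap to the full
            # label range (the network outputs 92 channels, presumably label 0
            # as background plus keypoints 1-91), so colors stay comparable
            # across frames.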
            # plt.imsave(osp.join(test_visual_dir,
            #            'test_%05d_%05d_rgb.jpg' % (epoch, step)), image)
            plt.imsave(osp.join(test_visual_dir, 'test_%05d_%05d_pred.png' %
                                (epoch, step)), pred_heatmap, vmin=0, vmax=91)
            plt.imsave(osp.join(test_visual_dir, 'test_%05d_%05d_gt.png' %
                                (epoch, step)), gt_heatmap, vmin=0, vmax=91)
            plt.imsave(osp.join(test_visual_dir, 'test_%05d_%05d_pred_keypts.png' %
                                (epoch, step)), pred_keypoints, vmin=0, vmax=91)

        if True:  # set to False to skip saving homographies
            if pred_rgb.shape[0] >= 4 and pred_homo is not None:
                # plt.imsave(osp.join(iou_visual_dir, 'test_%05d_%05d_gt_iou_part.png' %
                #                     (epoch, step)), gt_part_mask)
                # plt.imsave(osp.join(iou_visual_dir, 'test_%05d_%05d_pred_iou_part.png' %
                #                     (epoch, step)), pred_part_mask)
                # plt.imsave(osp.join(iou_visual_dir, 'test_%05d_%05d_merge_iou_part.png' %
                #                     (epoch, step)), part_merge_result)
                # plt.imsave(osp.join(iou_visual_dir, 'test_%05d_%05d_line_iou_whole.png' %
                #                     (epoch, step)), whole_line_merge_result)
                # plt.imsave(osp.join(iou_visual_dir, 'test_%05d_%05d_fill_iou_whole.png' %
                #                     (epoch, step)), whole_fill_merge_result)
                np.save(osp.join(homo_visual_dir, 'test_%05d_%05d_gt_homography.npy' %
                                 (epoch, step)), gt_homo)
                np.save(osp.join(homo_visual_dir, 'test_%05d_%05d_pred_homography.npy' %
                                 (epoch, step)), pred_homo)

    average_precision = np.array(precision_list).mean()
    average_recall = np.array(recall_list).mean()
    print(f'Average Precision: {average_precision:.2f}, Recall: {average_recall:.2f}')

    # nanmean/nanmedian skip frames where homography estimation failed
    # (fewer than four keypoints, or RANSAC returned None)
    iou_part_list = np.array(iou_part_list)
    iou_whole_list = np.array(iou_whole_list)
    mean_iou_part = np.nanmean(iou_part_list)
    mean_iou_whole = np.nanmean(iou_whole_list)
    print(f'Mean IOU part: {mean_iou_part * 100.:.1f}, IOU whole: {mean_iou_whole * 100.:.1f}')

    median_iou_part = np.nanmedian(iou_part_list)
    median_iou_whole = np.nanmedian(iou_whole_list)
    print(f'Median IOU part: {median_iou_part * 100.:.1f}, IOU whole: {median_iou_whole * 100.:.1f}')

    proj_error_list = np.array(proj_error_list)
    reproj_error_list = np.array(reproj_error_list)
    mean_proj_error = np.nanmean(proj_error_list)
    mean_reproj_error = np.nanmean(reproj_error_list)
    print(f'Mean Projection Error: {mean_proj_error:.2f}, Reprojection Error: {mean_reproj_error:.3f}')

    median_proj_error = np.nanmedian(proj_error_list)
    median_reproj_error = np.nanmedian(reproj_error_list)
    print(f'Median Projection Error: {median_proj_error:.2f}, Reprojection Error: {median_reproj_error:.3f}')
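    # Persist the same summary (plus which checkpoint produced it) to a
    # per-epoch text file so different runs can be compared later.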
    with open(osp.join(exp_name_path, 'metrics_%03d.txt' % epoch), 'w') as out_file:
        out_file.write(f'Loading weights: {load_weights_path}\n')
        out_file.write(f'Average Precision: {average_precision:.2f}, Recall: {average_recall:.2f}\n')
        out_file.write(f'Mean IOU part: {mean_iou_part * 100.:.1f}, IOU whole: {mean_iou_whole * 100.:.1f}\n')
        out_file.write(f'Median IOU part: {median_iou_part * 100.:.1f}, IOU whole: {median_iou_whole * 100.:.1f}\n')
        out_file.write(f'Mean Projection Error: {mean_proj_error:.2f}, Reprojection Error: {mean_reproj_error:.3f}\n')
        out_file.write(f'Median Projection Error: {median_proj_error:.2f}, Reprojection Error: {median_reproj_error:.3f}\n')


def main():
    test()


if __name__ == '__main__':
    start_time = time.time()
    main()
    print(f'Done... took {time.time() - start_time:.4f} sec')
--------------------------------------------------------------------------------