├── intro.png ├── dataset.png ├── src ├── datasets │ ├── LockableSeedRandomAccess.py │ ├── CenterDirGroundtruthDataset.py │ ├── __init__.py │ └── MuJoCoDataset.py ├── criterions │ ├── __init__.py │ ├── weightings │ │ ├── unbalanced_weight.py │ │ └── instance_weight.py │ ├── per_pixel_losses.py │ ├── loss_weighting │ │ └── weight_methods.py │ ├── center_localization_loss.py │ └── orientation_loss.py ├── utils │ ├── visualize │ │ └── __init__.py │ ├── evaluation │ │ ├── __init__.py │ │ ├── multivariate.py │ │ ├── center.py │ │ └── orientation.py │ ├── utils_depth.py │ ├── overlaps.py │ └── utils.py ├── models │ ├── multitask_model.py │ ├── __init__.py │ ├── center_estimator_fast.py │ ├── center_estimator_with_orientation.py │ └── center_estimator.py ├── config │ ├── __init__.py │ ├── mujoco │ │ └── train.py │ └── vicos_towel │ │ ├── novel_object=bg+cloth │ │ └── test.py │ │ ├── novel_object=bg │ │ └── test.py │ │ ├── novel_object=cloth │ │ └── test.py │ │ ├── test.py │ │ ├── test_on_train │ │ └── test.py │ │ └── train.py ├── inference │ └── processing.py └── infer.py ├── scripts ├── config_user.sh.example ├── utils.sh ├── config.sh ├── run_distributed.sh ├── EXPERIMENTS_MAIN.sh └── EXPERIMENTS_ABLATION.sh ├── datasets └── download.sh ├── models ├── download_localization_model.sh └── download_models.sh ├── requirements.txt ├── .gitignore ├── environment.yml └── tools ├── export_mujoco_to_coco.py ├── export_mujoco_to_coco_keypoints.py ├── export_vicos_towels_to_coco.py ├── export_vicos_towels_to_coco_keypoints.py └── calculate_cloth_angles.py /intro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vicoslab/CeDiRNet-3DoF/HEAD/intro.png -------------------------------------------------------------------------------- /dataset.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vicoslab/CeDiRNet-3DoF/HEAD/dataset.png -------------------------------------------------------------------------------- /src/datasets/LockableSeedRandomAccess.py: -------------------------------------------------------------------------------- 1 | 2 | import sys 3 | import numpy as np 4 | 5 | class LockableSeedRandomAccess(): 6 | 7 | def lock_samples_seed(self, index_list): 8 | raise NotImplementedError() -------------------------------------------------------------------------------- /scripts/config_user.sh.example: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # DO NOT COMMIT THIS FILE !!!! 
3 | 4 | # you may override config.sh variables here - for example 5 | # USE_CONDA_HOME=${USE_CONDA_HOME:-/home/USER/conda} 6 | # USE_CONDA_ENV=${USE_CONDA_ENV:-CeDiRNet} 7 | 8 | # OUTPUT_DIR=${OUTPUT_DIR:-/home/USER/Projects/CeDiRNet/exp} 9 | -------------------------------------------------------------------------------- /datasets/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CURDIR=$(dirname "$(realpath "$0")") 4 | 5 | # download and extract ViCoS Towel Dataset 6 | wget https://go.vicos.si/toweldataset -O - | unzip -d $CURDIR 7 | 8 | # download and extract MuJoCo Dataset 9 | wget https://go.vicos.si/towelmujocodataset -O - | unzip -d $CURDIR 10 | -------------------------------------------------------------------------------- /models/download_localization_model.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Define the URL for the localization model 4 | LOCALIZATION_MODEL=( 5 | "https://box.vicos.si/skokec/rtfm/CeDiRNet-3DoF/localization_checkpoint.pth" 6 | ) 7 | 8 | # Get the directory where the script is located 9 | script_dir=$(dirname "$(realpath "$0")") 10 | 11 | # Download each model 12 | for url in "${LOCALIZATION_MODEL[@]}"; do 13 | wget -P "$script_dir" "$url" 14 | done 15 | 16 | -------------------------------------------------------------------------------- /models/download_models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Define the URLs for the models 4 | MODELS=( 5 | "https://box.vicos.si/skokec/rtfm/CeDiRNet-3DoF/ConvNext-L-RGB.pth" 6 | "https://box.vicos.si/skokec/rtfm/CeDiRNet-3DoF/ConvNext-L-RGB-D.pth" 7 | "https://box.vicos.si/skokec/rtfm/CeDiRNet-3DoF/ConvNext-B-RGB.pth" 8 | "https://box.vicos.si/skokec/rtfm/CeDiRNet-3DoF/ConvNext-B-RGB-D.pth" 9 | ) 10 | 11 | # Get the directory where the script is located 12 | script_dir=$(dirname "$(realpath "$0")") 13 | 14 | # Download each model 15 | for url in "${MODELS[@]}"; do 16 | wget -P "$script_dir" "$url" 17 | done 18 | 19 | -------------------------------------------------------------------------------- /src/criterions/__init__.py: -------------------------------------------------------------------------------- 1 | from criterions.center_direction_loss import CenterDirectionLoss 2 | from criterions.orientation_loss import OrientationLoss 3 | 4 | def get_criterion(type, loss_opts, model, center_model): 5 | 6 | if type in ['CenterDirectionLoss','PolarCenterLossV2']: 7 | criterion = CenterDirectionLoss(center_model, **loss_opts) 8 | elif type in ['CenterDirectionLossOrientation','OrientationLoss']: 9 | criterion = OrientationLoss(center_model, **loss_opts) 10 | else: 11 | raise Exception("Unknown 'loss_type' in config: only 'CenterDirectionLoss' or 'CenterDirectionLossOrientation' are allowed") 12 | 13 | return criterion -------------------------------------------------------------------------------- /src/utils/visualize/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .center import CentersVisualizeTest, CentersVisualizeTrain 3 | from .orientation import OrientationVisualizeTest, OrientationVisualizeTrain 4 | 5 | def get_visualizer(name, opts): 6 | if name == 'CentersVisualizeTest': 7 | return CentersVisualizeTest(**opts) 8 | elif name == 'CentersVisualizeTrain': 9 | return CentersVisualizeTrain(**opts) 10 | elif name == 'OrientationVisualizeTest': 11 | return
OrientationVisualizeTest(**opts) 12 | elif name == 'OrientationVisualizeTrain': 13 | return OrientationVisualizeTrain(**opts) 14 | else: 15 | raise Exception("Unknown visualizer: '%s'" % name) -------------------------------------------------------------------------------- /src/models/multitask_model.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator 2 | 3 | import torch.nn as nn 4 | 5 | class MultiTaskModel: 6 | def shared_parameters(self) -> Iterator[nn.parameter.Parameter]: 7 | """Parameters shared by all tasks. 8 | Returns 9 | ------- 10 | """ 11 | return NotImplemented 12 | 13 | def task_specific_parameters(self) -> Iterator[nn.parameter.Parameter]: 14 | """Parameters specific to each task. 15 | Returns 16 | ------- 17 | """ 18 | return NotImplemented 19 | 20 | def last_shared_parameters(self) -> Iterator[nn.parameter.Parameter]: 21 | """Parameters of the last shared layer. 22 | Returns 23 | ------- 24 | """ 25 | return NotImplemented -------------------------------------------------------------------------------- /scripts/utils.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | function cleanup() { 4 | echo "Terminated. Cleaning up .." 5 | child_ids=$(pgrep -P $$ | xargs echo | tr " " ,) 6 | # kill all child processes 7 | pkill -P $child_ids 8 | pkill $$ 9 | exit 0 10 | } 11 | 12 | function wait_or_interrupt() { 13 | # set to kill any child processes if parent is interrupted 14 | #trap "pkill -P $child_ids && pkill $$ && echo exit && exit 0" SIGINT 15 | trap cleanup SIGINT 16 | # now wait 17 | if [ -z "$1" ] ; then 18 | wait 19 | elif [ -n "$1" ] && [ -n "$2" ] ; then 20 | MAX_CAPACITY=$1 21 | INDEX=$2 22 | # wait only if INDEX mod MAX_CAPACITY == 0 23 | if [ $((INDEX % MAX_CAPACITY)) -eq 0 ] ; then 24 | wait 25 | fi 26 | else 27 | # wait if more child processes exist than allowed ($1 is the number of allowed children) 28 | while test $(jobs -p | wc -w) -ge "$1"; do wait -n; done 29 | fi 30 | } -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | from models.FPN import FPN 2 | from models.center_estimator import CenterEstimator 3 | from models.center_estimator_fast import CenterEstimatorFast 4 | from models.center_estimator_with_orientation import CenterOrientationEstimator, CenterOrientationEstimatorFast 5 | from models.center_augmentator import CenterAugmentator 6 | 7 | def get_model(name, model_opts): 8 | if name == "fpn": 9 | model = FPN(**model_opts) 10 | else: 11 | raise RuntimeError("model \"{}\" not available".format(name)) 12 | 13 | return model 14 | 15 | def get_center_model(name, model_opts, is_learnable, use_fast_estimator=False): 16 | if name in ['CenterEstimatorOrientation','CenterOrientationEstimator']: 17 | if use_fast_estimator: 18 | return CenterOrientationEstimatorFast(model_opts, is_learnable=is_learnable) 19 | else: 20 | return CenterOrientationEstimator(model_opts, is_learnable=is_learnable) 21 | elif name == 'CenterEstimatorFast': 22 | return CenterEstimatorFast(model_opts, is_learnable=is_learnable) 23 | else: # PolarVotingCentersMultiscale or CenterEstimator 24 | if use_fast_estimator: 25 | return CenterEstimatorFast(model_opts, is_learnable=is_learnable) 26 | else: 27 | return CenterEstimator(model_opts, is_learnable=is_learnable) 28 | 29 | def get_center_augmentator(name, model_opts): 30 | return
CenterAugmentator(model_opts) 31 | -------------------------------------------------------------------------------- /src/datasets/CenterDirGroundtruthDataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import numpy as np 3 | 4 | import torch 5 | from torch.utils.data import Dataset 6 | from datasets.LockableSeedRandomAccess import LockableSeedRandomAccess 7 | 8 | class CenterDirGroundtruthDataset(Dataset, LockableSeedRandomAccess): 9 | 10 | def __init__(self, dataset, centerdir_groundtruth_op): 11 | 12 | self.dataset = dataset 13 | self.centerdir_groundtruth_op = centerdir_groundtruth_op 14 | 15 | def get_coco_api(self): 16 | return self.dataset.get_coco_api() 17 | 18 | def lock_samples_seed(self, index_list): 19 | if isinstance(self.dataset,LockableSeedRandomAccess): 20 | self.dataset.lock_samples_seed(index_list) 21 | #else: 22 | # print("Warning: underlying dataset not instance of LockableSeedRandomAccess .. skipping") 23 | 24 | def __len__(self): 25 | return len(self.dataset) 26 | 27 | def __getitem__(self, index): 28 | sample = self.dataset[index] 29 | 30 | instance = sample['instance'] 31 | 32 | # init data structure for groundtruth 33 | centerdir_groundtruth, output = self.centerdir_groundtruth_op._init_datastructure(instance.shape[-2], instance.shape[-1]) 34 | 35 | # store cached data into sample 36 | if output is not None: 37 | sample['output'] = output 38 | sample['centerdir_groundtruth'] = centerdir_groundtruth 39 | 40 | return sample -------------------------------------------------------------------------------- /src/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .CenterDirGroundtruthDataset import CenterDirGroundtruthDataset 2 | from .LockableSeedRandomAccess import LockableSeedRandomAccess 3 | from .MuJoCoDataset import MuJoCoDataset 4 | from .ViCoSTowelDataset import ViCoSTowelDataset 5 | from models.center_groundtruth import CenterDirGroundtruth 6 | 7 | def get_dataset(name, dataset_opts): 8 | if name.lower() == "mujoco": 9 | dataset = MuJoCoDataset(**dataset_opts) 10 | elif name.lower() == "vicos_towel": 11 | dataset = ViCoSTowelDataset(**dataset_opts) 12 | else: 13 | raise RuntimeError("Dataset {} not available".format(name)) 14 | 15 | return dataset 16 | 17 | def get_centerdir_dataset(name, dataset_opts, centerdir_gt_opts=None, centerdir_groundtruth_op=None, no_groundtruth=False): 18 | dataset = get_dataset(name, dataset_opts) 19 | 20 | if no_groundtruth: 21 | dataset.return_gt_heatmaps = False 22 | dataset.return_gt_box_polygon = False 23 | dataset.return_gt_polygon = False 24 | dataset.return_image = True 25 | return dataset, None 26 | 27 | if centerdir_gt_opts is not None and len(centerdir_gt_opts) > 0: 28 | if centerdir_groundtruth_op is None: 29 | centerdir_groundtruth_op = CenterDirGroundtruth(**centerdir_gt_opts) 30 | 31 | dataset = CenterDirGroundtruthDataset(dataset, centerdir_groundtruth_op) 32 | return dataset, centerdir_groundtruth_op 33 | else: 34 | return dataset, None 35 | 36 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.4.0 2 | addict==2.4.0 3 | cachetools==5.3.0 4 | certifi @ file:///croot/certifi_1671487769961/work/certifi 5 | chardet==5.1.0 6 | charset-normalizer==3.0.1 7 | cvxpy==1.3.1 8 | cycler==0.11.0 9 | ecos==2.0.12 10 | efficientnet-pytorch==0.7.1 11 | filelock==3.9.0
12 | google-auth==2.16.0 13 | google-auth-oauthlib==0.4.6 14 | grpcio==1.51.1 15 | huggingface-hub==0.12.0 16 | idna==3.4 17 | imageio==2.9.0 18 | importlib-metadata==6.0.0 19 | joblib==1.2.0 20 | kiwisolver==1.4.4 21 | Markdown==3.4.1 22 | markdown-it-py==2.2.0 23 | MarkupSafe==2.1.2 24 | matplotlib==3.3.4 25 | mdurl==0.1.2 26 | mmcv==2.0.0 27 | mmengine==0.7.3 28 | munch==2.5.0 29 | networkx==3.0 30 | numpy==1.19.5 31 | oauthlib==3.2.2 32 | opencv-python==4.7.0.68 33 | osqp==0.6.2.post8 34 | packaging==23.0 35 | pandas==1.3.5 36 | Pillow==8.3.1 37 | pretrainedmodels==0.7.4 38 | projector-installer==1.8.0 39 | protobuf==3.20.3 40 | pyasn1==0.4.8 41 | pyasn1-modules==0.2.8 42 | Pygments==2.15.1 43 | pyparsing==3.0.9 44 | pypng==0.20220715.0 45 | python-dateutil==2.8.2 46 | pytz==2022.7.1 47 | PyWavelets==1.4.1 48 | PyYAML==6.0 49 | qdldl==0.1.7 50 | requests==2.28.2 51 | requests-oauthlib==1.3.1 52 | rich==13.3.5 53 | rsa==4.9 54 | safetensors==0.3.1 55 | scikit-image==0.17.2 56 | scikit-learn==0.24.2 57 | scipy==1.5.4 58 | scs==3.2.3 59 | segmentation-models-pytorch==0.3.2 60 | six==1.16.0 61 | tensorboard==2.11.2 62 | tensorboard-data-server==0.6.1 63 | tensorboard-plugin-wit==1.8.1 64 | termcolor==2.3.0 65 | threadpoolctl==3.1.0 66 | tifffile==2023.2.3 67 | timm==0.6.12 68 | tomli==2.0.1 69 | torch==1.9.1+cu111 70 | torchaudio==0.9.1 71 | torchvision==0.10.1+cu111 72 | tqdm==4.62.3 73 | typing_extensions==4.4.0 74 | urllib3==1.26.14 75 | Werkzeug==2.2.2 76 | yapf==0.33.0 77 | zipp==3.12.1 78 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # ignore user-specific config for CCC/HPC 107 | config_user.sh -------------------------------------------------------------------------------- /src/utils/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import json 4 | 5 | from sklearn.metrics import precision_recall_curve, roc_curve, auc, average_precision_score 6 | 7 | 8 | class NumpyEncoder(json.JSONEncoder): 9 | """ Special json encoder for numpy types """ 10 | 11 | def default(self, obj): 12 | if isinstance(obj, np.integer): 13 | return int(obj) 14 | elif isinstance(obj, np.floating): 15 | return float(obj) 16 | elif isinstance(obj, np.ndarray): 17 | return obj.tolist() 18 | return json.JSONEncoder.default(self, obj) 19 | 20 | def get_AP_and_F1(_Y, _P, ignore_mask=None): 21 | _Y = np.squeeze(np.array(_Y)) 22 | _P = np.squeeze(np.array(_P)) 23 | 24 | # replace inf values (missing detection) with something that is smaller than any other valid probability in P 25 | missing_det_mask = _P == -np.inf 26 | if any(missing_det_mask): 27 | _P[missing_det_mask] = np.min(_P[~missing_det_mask]) - 1 28 | 29 | if ignore_mask is not None: 30 | ignore_mask = np.squeeze(np.array(ignore_mask)) 31 | 32 | _Y = _Y[ignore_mask == 0] 33 | _P = _P[ignore_mask == 0] 34 | 35 | pr_curve = precision_recall_curve(_Y, _P) 36 | 37 | precision = pr_curve[0] 38 | recall = pr_curve[1] 39 | thrs = pr_curve[2] 40 | 41 | # remove missed detections from precision-recall scores (if there are any) 42 | # this is needed to prevent counting recall=100% when recall was never 100% 43 | if any(missing_det_mask): 44 | precision = precision[1:] 45 | recall = recall[1:] 46 | thrs = thrs[1:] 47 | 48 | # do not call average_precision_score(Y,P) directly since it will not be able to count 49 | # missed detections as properly missed 50 | AP = -np.sum(np.diff(recall) * np.array(precision)[:-1]) 51 | 52 | F1 = 2 * (precision * recall) / (precision + recall) 53 | 54 | return AP, F1, precision, recall, thrs 55 | -------------------------------------------------------------------------------- /scripts/config.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ################################################### 4 | ######## LOAD USER-SPECIFIC CONFIG 5 | ################################################### 6 | 7 |
USER_CONFIG_FILE="$(dirname $0)/config_user.sh" 8 | # create config file if it does not exist 9 | if [ ! -f "$USER_CONFIG_FILE" ]; then 10 | cp "$(dirname $0)/config_user.sh.example" "$USER_CONFIG_FILE" 11 | fi 12 | 13 | # include user-specific settings 14 | # shellcheck source=./config_user.sh 15 | source "$USER_CONFIG_FILE" 16 | 17 | ################################################### 18 | ######## ACTIVATE CONDA ENV 19 | ################################################### 20 | echo "Loading conda env ..." 21 | 22 | USE_CONDA_HOME=${USE_CONDA_HOME:-~/conda} 23 | USE_CONDA_ENV=${USE_CONDA_ENV:-CeDiRNet-dev} 24 | 25 | . $USE_CONDA_HOME/etc/profile.d/conda.sh 26 | 27 | conda activate $USE_CONDA_ENV 28 | echo "... done - using $USE_CONDA_ENV" 29 | 30 | ################################################### 31 | ######## INPUT/OUTPUT PATH 32 | ################################################### 33 | 34 | export ROOT_DIR=${ROOT_DIR:-$(realpath "$(dirname $0)/..")} 35 | export SOURCE_DIR=${SOURCE_DIR:-$(realpath "$(dirname $0)/../src")} 36 | export OUTPUT_DIR=${OUTPUT_DIR:-$(realpath "$(dirname $0)/../exp")} 37 | 38 | ################################################### 39 | ######## DATASET PATHS 40 | ################################################### 41 | 42 | export MUJOCO_DIR=${MUJOCO_DIR:-$(realpath "$(dirname $0)/../datasets/MuJoCo/")} 43 | export VICOS_TOWEL_DATASET_DIR=${VICOS_TOWEL_DATASET_DIR:-$(realpath "$(dirname $0)/../datasets/ViCoSTowelDataset/")} 44 | 45 | ################################################### 46 | ######## DATA PARALLEL SETTINGS 47 | ################################################### 48 | export NCCL_P2P_DISABLE=1 49 | export NCCL_IB_DISABLE=1 50 | export NCCL_BLOCKING_WAIT=1 51 | export NCCL_SHM_DISABLE=1 52 | #export NCCL_SOCKET_IFNAME=${NCCL_SOCKET_IFNAME:-eth1} 53 | -------------------------------------------------------------------------------- /src/criterions/weightings/unbalanced_weight.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class UnbalancedWeighting: 4 | def __init__(self, border_weight=1.0, border_weight_px=0, add_distance_gauss_weight=0, *args, **kwargs): 5 | self.border_weight = border_weight 6 | self.border_weight_px = border_weight_px 7 | 8 | self.add_distance_gauss_weight = add_distance_gauss_weight 9 | 10 | def __call__(self, gt_instances, gt_ignore=None, gt_R=None, w_fg=1, w_bg=1,*args, **kwargs): 11 | 12 | batch_size, height, width = gt_instances.shape 13 | 14 | bg_mask = (gt_instances == 0).unsqueeze(1) 15 | fg_mask = bg_mask == False 16 | 17 | mask_weights = torch.ones_like(bg_mask, dtype=torch.float32, requires_grad=False, device=gt_instances.device) 18 | 19 | mask_weights[fg_mask] = w_fg 20 | mask_weights[bg_mask] = w_bg 21 | 22 | # apply additional weights around borders 23 | if self.border_weight_px > 0: 24 | mask_weights = self._apply_border_weights(mask_weights) 25 | 26 | # treat each pixel equally but do not count ignored pixels 27 | if gt_ignore is not None: 28 | mask_weights *= 1 - gt_ignore.type(mask_weights.type()) 29 | mask_weights /= (~gt_ignore).sum() 30 | else: 31 | mask_weights /= (height * width * batch_size) 32 | 33 | # apply additional weight based on distance to center 34 | if self.add_distance_gauss_weight > 0: 35 | mask_weights = self._apply_gauss_distance_weights(mask_weights, gt_R) 36 | 37 | return mask_weights 38 | 39 | def _apply_border_weights(self, W): 40 | B = self.border_weight_px 41 | 42 | mask_border = torch.ones_like(W, dtype=torch.bool,
requires_grad=False, device=W.device) 43 | 44 | mask_border[:, :, B:W.shape[2] - B, B:W.shape[3] - B] = 0 45 | W[mask_border] *= self.border_weight 46 | 47 | return W 48 | 49 | def _apply_gauss_distance_weights(self, W, R): 50 | return W * torch.exp(-R / (2 * self.add_distance_gauss_weight ** 2)) 51 | -------------------------------------------------------------------------------- /src/utils/evaluation/multivariate.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from functools import partial 3 | class MultivariateEval: 4 | def __init__(self, eval_obj, image2tags_fn): 5 | 6 | # save a reference to the original eval object 7 | self.ref_eval_obj = copy.deepcopy(eval_obj) 8 | 9 | self.eval_obj = eval_obj 10 | self.image2tags_fn = image2tags_fn 11 | 12 | self.eval_obj_per_tag = dict() 13 | 14 | def add_image_prediction(self, im_name, *args, **kwargs): 15 | # call original evaluation class 16 | result = self.eval_obj.add_image_prediction(im_name, *args, **kwargs) 17 | 18 | # parse image name to get the object tags 19 | img_tags = self.image2tags_fn(im_name) 20 | 21 | # add metrics to the corresponding tag (call the same function on the corresponding eval object) 22 | for t in img_tags: 23 | if t not in self.eval_obj_per_tag: 24 | # make a copy of the original eval object 25 | self.eval_obj_per_tag[t] = copy.deepcopy(self.ref_eval_obj) 26 | # and modify its save_str() function to add tag information 27 | self.eval_obj_per_tag[t].save_str = partial(lambda x: self.eval_obj.save_str() + '_' + x, t) 28 | 29 | # this is less efficient since we call the same function multiple times but is independent of the eval class 30 | self.eval_obj_per_tag[t].add_image_prediction(im_name, *args, **kwargs) 31 | 32 | return result 33 | 34 | def save_str(self): 35 | return self.eval_obj.save_str() 36 | 37 | def get_attributes(self): 38 | return self.eval_obj.get_attributes() 39 | 40 | def get_results_timestamp(self, *args, **kwargs): 41 | return self.eval_obj.get_results_timestamp(*args, **kwargs) 42 | 43 | def calc_and_display_final_metrics(self, *args, **kwargs): 44 | # save original results 45 | metrics = self.eval_obj.calc_and_display_final_metrics(*args, **kwargs) 46 | 47 | # save per tag results 48 | for t, eval_obj in self.eval_obj_per_tag.items(): 49 | self.eval_obj_per_tag[t].calc_and_display_final_metrics(*args, **kwargs) 50 | 51 | return metrics -------------------------------------------------------------------------------- /src/models/center_estimator_fast.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import numpy as np 5 | 6 | from models.localization.centers import Conv1dMultiscaleLocalization, Conv2dDilatedLocalization 7 | 8 | class CenterEstimatorFast(nn.Module): 9 | def __init__(self, args=dict(), is_learnable=True): 10 | super().__init__() 11 | 12 | instance_center_estimator_op = Conv1dMultiscaleLocalization 13 | if args.get('use_dilated_nn'): 14 | from functools import partial 15 | instance_center_estimator_op = partial(Conv2dDilatedLocalization, 16 | **args.get('dilated_nn_args',{})) 17 | 18 | self.instance_center_estimator = instance_center_estimator_op( 19 | local_max_thr=args.get('local_max_thr', 0.1), 20 | mask_thr=args.get('mask_thr', 0.01), 21 | exclude_border_px=args.get('exclude_border_px', 5), 22 | learnable=is_learnable, 23 | allow_input_backprop=args.get('allow_input_backprop', True), 24 |
backprop_only_positive=args.get('backprop_only_positive', True), 25 | apply_input_smoothing_for_local_max=0, 26 | use_findcontours_for_local_max=args.get('use_findcontours_for_local_max', False), 27 | local_max_min_dist=1, 28 | return_time=True 29 | ) 30 | 31 | def set_return_backbone_only(self, val): 32 | pass 33 | 34 | def is_return_backbone_only(self): 35 | return False 36 | 37 | def init_output(self, num_vector_fields=1): 38 | self.num_vector_fields = num_vector_fields 39 | 40 | assert self.num_vector_fields >= 3 41 | self.instance_center_estimator.init_output() 42 | 43 | return input 44 | 45 | def forward(self, input): 46 | 47 | assert input.shape[1] >= self.num_vector_fields 48 | 49 | predictions = input[:, 0:self.num_vector_fields] 50 | 51 | S = predictions[:, 0].unsqueeze(1) 52 | C = predictions[:, 1].unsqueeze(1) 53 | 54 | res, _, times = self.instance_center_estimator(C, S, None, None, None, ignore_region=None) 55 | center_pred = res 56 | 57 | return center_pred, times -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: CeDiRNet-py3.8 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _openmp_mutex=5.1=1_gnu 7 | - ca-certificates=2023.01.10=h06a4308_0 8 | - certifi=2022.12.7=py38h06a4308_0 9 | - ld_impl_linux-64=2.38=h1181459_1 10 | - libffi=3.4.2=h6a678d5_6 11 | - libgcc-ng=11.2.0=h1234567_1 12 | - libgomp=11.2.0=h1234567_1 13 | - libstdcxx-ng=11.2.0=h1234567_1 14 | - ncurses=6.4=h6a678d5_0 15 | - openssl=1.1.1s=h7f8727e_0 16 | - pip=22.3.1=py38h06a4308_0 17 | - python=3.8.16=h7a1cb2a_2 18 | - readline=8.2=h5eee18b_0 19 | - setuptools=65.6.3=py38h06a4308_0 20 | - sqlite=3.40.1=h5082296_0 21 | - tk=8.6.12=h1ccaba5_0 22 | - wheel=0.37.1=pyhd3eb1b0_0 23 | - xz=5.2.10=h5eee18b_1 24 | - zlib=1.2.13=h5eee18b_0 25 | - pip: 26 | - absl-py==1.4.0 27 | - addict==2.4.0 28 | - cachetools==5.3.0 29 | - chardet==5.1.0 30 | - charset-normalizer==3.0.1 31 | - cvxpy==1.3.1 32 | - cycler==0.11.0 33 | - ecos==2.0.12 34 | - efficientnet-pytorch==0.7.1 35 | - filelock==3.9.0 36 | - google-auth==2.16.0 37 | - google-auth-oauthlib==0.4.6 38 | - grpcio==1.51.1 39 | - huggingface-hub==0.12.0 40 | - idna==3.4 41 | - imageio==2.9.0 42 | - importlib-metadata==6.0.0 43 | - joblib==1.2.0 44 | - kiwisolver==1.4.4 45 | - markdown==3.4.1 46 | - markdown-it-py==2.2.0 47 | - markupsafe==2.1.2 48 | - matplotlib==3.3.4 49 | - mdurl==0.1.2 50 | - mmcv==2.0.0 51 | - mmengine==0.7.3 52 | - munch==2.5.0 53 | - networkx==3.0 54 | - numpy==1.19.5 55 | - oauthlib==3.2.2 56 | - opencv-python==4.7.0.68 57 | - osqp==0.6.2.post8 58 | - packaging==23.0 59 | - pandas==1.3.5 60 | - pillow==8.3.1 61 | - pretrainedmodels==0.7.4 62 | - protobuf==3.20.3 63 | - pyasn1==0.4.8 64 | - pyasn1-modules==0.2.8 65 | - pygments==2.15.1 66 | - pyparsing==3.0.9 67 | - pypng==0.20220715.0 68 | - python-dateutil==2.8.2 69 | - pytz==2022.7.1 70 | - pywavelets==1.4.1 71 | - pyyaml==6.0 72 | - qdldl==0.1.7 73 | - requests==2.28.2 74 | - requests-oauthlib==1.3.1 75 | - rich==13.3.5 76 | - rsa==4.9 77 | - safetensors==0.3.1 78 | - scikit-image==0.17.2 79 | - scikit-learn==0.24.2 80 | - scipy==1.5.4 81 | - scs==3.2.3 82 | - segmentation-models-pytorch==0.3.2 83 | - six==1.16.0 84 | - tensorboard==2.11.2 85 | - tensorboard-data-server==0.6.1 86 | - tensorboard-plugin-wit==1.8.1 87 | - termcolor==2.3.0 88 | - threadpoolctl==3.1.0 89 | - tifffile==2023.2.3 90 | - timm==0.6.12 91 | - 
tomli==2.0.1 92 | - torch==1.9.1+cu111 93 | - torchaudio==0.9.1 94 | - torchvision==0.10.1+cu111 95 | - tqdm==4.62.3 96 | - typing-extensions==4.4.0 97 | - urllib3==1.26.14 98 | - werkzeug==2.2.2 99 | - yapf==0.33.0 100 | - zipp==3.12.1 101 | -------------------------------------------------------------------------------- /src/criterions/per_pixel_losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from functools import partial 5 | 6 | def get_per_pixel_loss_func(loss_type): 7 | 8 | def abs_jit(X, Y): 9 | return torch.abs(X - Y) 10 | def mse_jit(X, Y): 11 | return torch.pow(X - Y, 2) 12 | def m4e_jit(X, Y): 13 | return torch.pow(X - Y, 4) 14 | 15 | loss_abs_fn = abs_jit 16 | loss_mse_fn = mse_jit 17 | loss_m4e_fn = m4e_jit 18 | 19 | loss_hinge_fn = lambda X, Y, sign_fn, eps=0: (torch.clamp_min(sign_fn(Y - X), eps) - eps) 20 | loss_smoothL1_fn = lambda X, Y, beta, pow: torch.where((X - Y).abs() < beta, 21 | torch.pow(X - Y, pow) / (pow * beta), 22 | (X - Y).abs() - 1 / float(pow) * beta) 23 | loss_inverted_smoothL1_fn = lambda X, Y, beta, pow: torch.where((X - Y).abs() > beta, 24 | torch.pow(X - Y, pow) / (pow * beta), 25 | (X - Y).abs() - 1 / float(pow) * beta) 26 | loss_bce_logits = torch.nn.BCEWithLogitsLoss(reduction='none') 27 | 28 | def binary_hinge_loss(X, Y): 29 | with torch.no_grad(): 30 | valid_neg = (Y <= 0) * (X > 0) 31 | valid_pos = (Y >= 1) * (X < 1) 32 | valid = (valid_neg + valid_pos) > 0 33 | 34 | return torch.abs(X - Y) * valid.float() 35 | 36 | args = {} 37 | if type(loss_type) is dict: 38 | args = loss_type['args'] if 'args' in loss_type else {} 39 | loss_type = loss_type['type'] 40 | 41 | if loss_type.upper() in ['L1', 'MAE']: 42 | return partial(loss_abs_fn, **args) 43 | elif loss_type.upper() in ['L2', 'MSE']: 44 | return partial(loss_mse_fn, **args) 45 | elif loss_type.lower() in ['hinge']: 46 | return partial(loss_hinge_fn, **args) 47 | elif loss_type.lower() in ['smoothl1']: 48 | return partial(loss_smoothL1_fn, **args) 49 | elif loss_type.lower() in ['inverted-smoothl1']: 50 | return partial(loss_inverted_smoothL1_fn, **args) 51 | elif loss_type.lower() in ['cross-entropy', 'bce']: 52 | return partial(loss_bce_logits, **args) 53 | elif loss_type.lower() in ['focal']: 54 | return lambda X, Y: sigmoid_focal_loss(X, Y, reduction="none", **args) 55 | else: 56 | raise Exception('Unsuported loss type: \'%s\'' % loss_type) 57 | 58 | 59 | def sigmoid_focal_loss( 60 | inputs, 61 | targets, 62 | alpha = 0.25, 63 | delta = 1, 64 | gamma = 2, 65 | A = 1, 66 | reduction = "none"): 67 | 68 | p = torch.sigmoid(inputs) 69 | ce_loss = F.binary_cross_entropy_with_logits( 70 | inputs, targets, reduction="none" 71 | ) 72 | 73 | loss = ce_loss * torch.where(targets == 1, 74 | (1-p)**gamma, # foreground 75 | A*(1-targets)**delta * p**gamma) 76 | 77 | if reduction == "mean": 78 | loss = loss.mean() 79 | elif reduction == "sum": 80 | loss = loss.sum() 81 | return loss 82 | -------------------------------------------------------------------------------- /src/config/__init__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import ast 3 | import importlib 4 | import copy 5 | 6 | dataset_to_import = {'mujoco':'mujoco.{subname}.{type}', 7 | 'vicos_towel':'vicos_towel.{subname}.{type}'} 8 | 9 | def get_cmd_args(): 10 | class ParseConfigCMDArgs(argparse.Action): 11 | def __call__(self, parser, namespace, values, option_string=None): 12 
| setattr(namespace, self.dest, dict()) 13 | for value in values: 14 | value_eq = value.split('=') 15 | key, value = value_eq[0], value_eq[1:] 16 | getattr(namespace, self.dest)[key] = "=".join(value) 17 | 18 | # get any config values from CMD arg that override the config file ones 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('-s', '--cfg_subname', type=str, default='') 21 | parser.add_argument('-c', '--configs', nargs='*', action=ParseConfigCMDArgs, default=dict()) 22 | 23 | cmd_args = parser.parse_args() 24 | 25 | return cmd_args 26 | 27 | 28 | def get_config_args(dataset, type, merge_from_cmd_args=True): 29 | # parse command line arguments 30 | cmd_args = get_cmd_args() 31 | 32 | # parse dataset and type 33 | dataset = dataset.lower() 34 | 35 | if dataset not in dataset_to_import.keys(): 36 | raise Exception('Unknown or missing dataset value') 37 | 38 | if type not in ['train','test']: 39 | raise Exception('Invalid type of arguments request: supported only train or test') 40 | 41 | config_module = 'config.' + dataset_to_import[dataset].format(type=type,subname=cmd_args.cfg_subname) 42 | 43 | # remove any double dots 44 | config_module = config_module.replace('..','.') 45 | 46 | module = importlib.import_module(config_module) 47 | 48 | print('Loading config for dataset=%s and type=%s' % (dataset, type)) 49 | args = module.get_args() 50 | 51 | ######################################################## 52 | # Merge from CMD args if any 53 | if merge_from_cmd_args: 54 | args = merge_args_from_config(args, cmd_args) 55 | 56 | ######################################################## 57 | # update any string with format substitution based on other arg values 58 | return replace_args(args) 59 | 60 | def merge_args_from_config(args, cmd_args): 61 | def set_config_val_recursive(config, k, v): 62 | k0 = k[0] 63 | if isinstance(config, list): 64 | k0 = int(k0) 65 | if isinstance(k, list) and len(k) > 1: 66 | config[k0] = set_config_val_recursive(config[k0], k[1:], v) 67 | else: 68 | config[k0] = v 69 | return config 70 | 71 | for k,v in cmd_args.configs.items(): 72 | try: 73 | v = ast.literal_eval(v) 74 | except: 75 | print('WARNING: cfg %s=%s interpreting as string' % (k,v)) 76 | args = set_config_val_recursive(args, k.split("."), v) 77 | print("Overriding config with cmd %s=%s" % (k,v)) 78 | 79 | return args 80 | 81 | def replace_args(_args, full_args=None): 82 | if full_args is None: 83 | full_args = copy.deepcopy(_args) 84 | 85 | if isinstance(_args, str): 86 | _args = _args.format(args=full_args) 87 | elif isinstance(_args, dict): 88 | _args = {k: replace_args(a, full_args) for k,a in _args.items()} 89 | elif isinstance(_args, list) or isinstance(_args, tuple): 90 | _args = [replace_args(a,full_args) for a in _args] 91 | return _args 92 | -------------------------------------------------------------------------------- /scripts/run_distributed.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CMD_ARGS=$@ 4 | 5 | job_main() { 6 | # initial call that will delegate the job to servers 7 | # SERVERS env var should be a list of servers and gpu ids per server (e.g.
"donbot:0,1,2,3 morbo:2,3 calculon:1,0" 8 | # first count number of servers and GPUs to get the world size 9 | num_gpus="${SERVERS//[^,]}" 10 | num_gpus="${#num_gpus}" 11 | 12 | num_servers="${SERVERS//[^ ]}" 13 | num_servers="${#num_servers}" 14 | 15 | WORLD_SIZE=$((num_gpus+num_servers+1)) 16 | RANK_OFFSET=0 17 | MASTER_PORT=$((RANDOM+24000)) 18 | 19 | IFS=' ' read -ra ADDR_LIST <<< "$SERVERS" 20 | for ADDR in "${ADDR_LIST[@]}"; do 21 | # address is in format: : (e.g. donbot:0,1,2,3) 22 | IFS=':' read -ra NAME_ID <<< "$ADDR" 23 | SERVER_NAME=${NAME_ID[0]} 24 | CUDA_VISIBLE_DEVICES=${NAME_ID[1]} 25 | 26 | # set master to first server 27 | if [ -z "$MASTER_ADDR" ]; then 28 | MASTER_ADDR=$SERVER_NAME 29 | fi 30 | 31 | # pass to ssh all needed env vars 32 | ENVS="" 33 | ENVS="$ENVS CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" 34 | ENVS="$ENVS DATASET=$DATASET" 35 | ENVS="$ENVS USE_DEPTH=$USE_DEPTH" 36 | if [ -n "$TRAIN_SIZE" ]; then 37 | ENVS="$ENVS TRAIN_SIZE=$TRAIN_SIZE" 38 | fi 39 | if [ -n "$TRAIN_SIZE_WIDTH" ]; then 40 | ENVS="$ENVS TRAIN_SIZE_WIDTH=$TRAIN_SIZE_WIDTH" 41 | fi 42 | if [ -n "$TRAIN_SIZE_HEIGHT" ]; then 43 | ENVS="$ENVS TRAIN_SIZE_HEIGHT=$TRAIN_SIZE_HEIGHT" 44 | fi 45 | if [ -n "$TEST_SIZE" ]; then 46 | ENVS="$ENVS TEST_SIZE=$TEST_SIZE" 47 | fi 48 | if [ -n "$TEST_SIZE_WIDTH" ]; then 49 | ENVS="$ENVS TEST_SIZE_WIDTH=$TEST_SIZE_WIDTH" 50 | fi 51 | if [ -n "$TEST_SIZE_HEIGHT" ]; then 52 | ENVS="$ENVS TEST_SIZE_HEIGHT=$TEST_SIZE_HEIGHT" 53 | fi 54 | ENVS="$ENVS ENABLE_6DOF=$ENABLE_6DOF" 55 | ENVS="$ENVS ENABLE_EULER=$ENABLE_EULER" 56 | 57 | ENVS="$ENVS MASTER_PORT=$MASTER_PORT" 58 | ENVS="$ENVS MASTER_ADDR=$MASTER_ADDR" 59 | ENVS="$ENVS WORLD_SIZE=$WORLD_SIZE" 60 | ENVS="$ENVS RANK_OFFSET=$RANK_OFFSET" 61 | if [ -n "$USE_CONDA_ENV" ]; then 62 | ENVS="$ENVS USE_CONDA_ENV=$USE_CONDA_ENV" 63 | fi 64 | 65 | export ENVS 66 | export SERVER_NAME 67 | # run ssh connection in child background process 68 | RUN_TASK=2 $(realpath $0) $CMD_ARGS & 69 | 70 | # increase world rank offset by the number of gpus 71 | num_gpus="${CUDA_VISIBLE_DEVICES//[^,]}" 72 | num_gpus="${#num_gpus}" 73 | RANK_OFFSET=$((RANK_OFFSET + num_gpus + 1)) 74 | done 75 | 76 | wait 77 | } 78 | 79 | ssh_main() { 80 | SSH_ARGS="-t -t -o StrictHostKeyChecking=no" 81 | if [ "${DISABLE_X11}" != "1" ]; then 82 | SSH_ARGS="-Y $SSH_ARGS" 83 | fi 84 | # call main task function on server (use -t -t to allow exiting remote process in interuption) 85 | exec ssh $SSH_ARGS $SERVER_NAME RUN_TASK=1 $ENVS $(realpath $0) $(printf "%q " "$CMD_ARGS") 86 | } 87 | 88 | task_main() { 89 | # Set up signal trap to catch Ctrl+C 90 | trap "exit" SIGINT 91 | 92 | # set up env vars 93 | source "$(dirname $0)/config.sh" 94 | 95 | echo "NODE=$HOSTNAME" 96 | echo "WORLD_SIZE=$WORLD_SIZE" 97 | echo "RANK_OFFSET=$RANK_OFFSET" 98 | 99 | echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" 100 | 101 | echo "master=$MASTER_ADDR" 102 | ################################################### 103 | 104 | cd $SOURCE_DIR 105 | $CMD_ARGS 106 | } 107 | 108 | if [ "${RUN_TASK}" = "2" ]; then 109 | ssh_main 110 | elif [ "${RUN_TASK}" = "1" ]; then 111 | task_main 112 | else 113 | job_main 114 | fi -------------------------------------------------------------------------------- /tools/export_mujoco_to_coco.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import numpy as np 4 | from PIL import Image 5 | from tqdm import tqdm 6 | from datetime import datetime 7 | import importlib 8 | 9 | import sys 10 | 
sys.path.append(os.path.join(os.path.dirname(__file__), '../src')) 11 | 12 | # Define a function to convert ViCoSClothDataset annotations to COCO format 13 | def convert_to_coco_format(subfolders, output_file, R = 25): 14 | CAT_CENTER_ID = 1 15 | CAT_CENTER_NAME = "corner" 16 | 17 | # Create COCO format dictionary 18 | coco_format = { 19 | "info": { 20 | "version": "v1", 21 | "description": "RTFM MUJOCO Dataset", 22 | "contributor": "domen.tabernik@fri.uni-lj.si", 23 | "url": "", 24 | "year": datetime.now().year, 25 | "date_created": datetime.now().strftime("%Y-%m-%d"), 26 | }, 27 | "categories": [{"supercategory": CAT_CENTER_NAME, "id": CAT_CENTER_ID, "name": CAT_CENTER_NAME}], 28 | "licenses": [], 29 | "images": [], 30 | "annotations": [], 31 | } 32 | 33 | # Load annotations 34 | from datasets.MuJoCoDataset import RTFMDataset 35 | db = RTFMDataset(root_dir=ROOT_DIR, subfolder=subfolders, transform=None, use_depth=False, correct_depth_rotation=False, use_normals=False) 36 | 37 | annt_id = 0 38 | # Loop through images in dataset 39 | for img_id,item in enumerate(tqdm(db)): 40 | filename = item['im_name'] 41 | 42 | # change filename relative to output_file 43 | filename = filename.replace(os.path.dirname(output_file),".") 44 | 45 | # Add image information to COCO format dictionary 46 | coco_format["images"].append({ 47 | "id": img_id, 48 | "file_name": filename, 49 | "height": item['im_size'][1], 50 | "width": item['im_size'][0] 51 | }) 52 | 53 | # Loop through centers for current image 54 | gt_centers = item['center'] 55 | gt_centers = gt_centers[(gt_centers[:, 0] > 0) | (gt_centers[:, 1] > 0), :] 56 | 57 | for center in gt_centers: 58 | # calculate bounding box in XYWH format by adding R to center 59 | bbox = [center[0]-R, center[1]-R, 2*R, 2*R] 60 | 61 | # Add annotation information to COCO format dictionary 62 | coco_format["annotations"].append({ 63 | "id": annt_id, 64 | "image_id": img_id, 65 | "category_id": CAT_CENTER_ID, 66 | "bbox": bbox, 67 | "segmentation": [[bbox[0], bbox[1], 68 | bbox[0]+bbox[2], bbox[1], 69 | bbox[0]+bbox[2], bbox[1]+bbox[3], 70 | bbox[0], bbox[1]+bbox[3], 71 | bbox[0], bbox[1]]], 72 | "area": int(np.prod(bbox[2:])), 73 | "iscrowd": 0 74 | }) 75 | annt_id+=1 76 | 77 | # Save COCO format dictionary to output directory 78 | with open(output_file, "w") as f: 79 | json.dump(coco_format, f) 80 | 81 | 82 | if __name__ == "__main__": 83 | ROOT_DIR = '/storage/datasets/ClothDataset/' 84 | 85 | mujoco_subfolders = ['mujoco', 86 | 'mujoco_all_combinations_normal_color_temp', 87 | 'mujoco_all_combinations_rgb_light', 88 | 'mujoco_white_desk_HS_extreme_color_temp', 89 | 'mujoco_white_desk_HS_normal_color_temp'] 90 | 91 | convert_to_coco_format(mujoco_subfolders, output_file=os.path.join(ROOT_DIR, 'coco_train.json')) 92 | -------------------------------------------------------------------------------- /tools/export_mujoco_to_coco_keypoints.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import numpy as np 4 | from PIL import Image 5 | from tqdm import tqdm 6 | from datetime import datetime 7 | import importlib 8 | 9 | import sys 10 | sys.path.append(os.path.join(os.path.dirname(__file__), '../../src')) 11 | 12 | # Define a function to convert ViCoSClothDataset annotations to COCO format 13 | def convert_to_coco_format(subfolders, output_file, R = 25): 14 | CAT_CENTER_ID = 1 15 | CAT_CENTER_NAME = "corner" 16 | 17 | # Create COCO format dictionary 18 | coco_format = { 19 | "info": { 20 | "version": 
"v1", 21 | "description": "RTFM MUJOCO Dataset", 22 | "contributor": "domen.tabernik@fri.uni-lj.si", 23 | "url": "", 24 | "year": datetime.now().year, 25 | "date_created": datetime.now().strftime("%Y-%m-%d"), 26 | }, 27 | "categories": [{"supercategory": CAT_CENTER_NAME, "id": CAT_CENTER_ID, "name": CAT_CENTER_NAME, "keypoints": ["center"],}], 28 | "licenses": [], 29 | "images": [], 30 | "annotations": [], 31 | } 32 | 33 | # Load annotations 34 | from datasets.MuJoCoDataset import RTFMDataset 35 | db = RTFMDataset(root_dir=ROOT_DIR, subfolder=subfolders, transform=None, use_depth=False, correct_depth_rotation=False, use_normals=False) 36 | 37 | annt_id = 0 38 | # Loop through images in dataset 39 | for img_id,item in enumerate(tqdm(db)): 40 | filename = item['im_name'] 41 | 42 | # change filename relative to output_file 43 | filename = filename.replace(os.path.dirname(output_file),".") 44 | 45 | # Add image information to COCO format dictionary 46 | coco_format["images"].append({ 47 | "id": img_id, 48 | "file_name": filename, 49 | "height": item['im_size'][1], 50 | "width": item['im_size'][0] 51 | }) 52 | 53 | # Loop through centers for current image 54 | gt_centers = item['center'] 55 | gt_centers = gt_centers[(gt_centers[:, 0] > 0) | (gt_centers[:, 1] > 0), :] 56 | 57 | for center in gt_centers: 58 | # calculate bounding box in XYWH format by adding R to center 59 | bbox = [center[0]-R, center[1]-R, 2*R, 2*R] 60 | 61 | # Add annotation information to COCO format dictionary 62 | coco_format["annotations"].append({ 63 | "id": annt_id, 64 | "image_id": img_id, 65 | "category_id": CAT_CENTER_ID, 66 | "bbox": bbox, 67 | "segmentation": [[bbox[0], bbox[1], 68 | bbox[0]+bbox[2], bbox[1], 69 | bbox[0]+bbox[2], bbox[1]+bbox[3], 70 | bbox[0], bbox[1]+bbox[3], 71 | bbox[0], bbox[1]]], 72 | "keypoints": [center[0], center[1], 2], 73 | "num_keypoints": 1, 74 | "area": int(np.prod(bbox[2:])), 75 | "iscrowd": 0 76 | }) 77 | annt_id+=1 78 | 79 | # Save COCO format dictionary to output directory 80 | with open(output_file, "w") as f: 81 | json.dump(coco_format, f) 82 | 83 | 84 | if __name__ == "__main__": 85 | ROOT_DIR = '/storage/datasets/ClothDataset/' 86 | 87 | mujoco_subfolders = ['mujoco', 88 | 'mujoco_all_combinations_normal_color_temp', 89 | 'mujoco_all_combinations_rgb_light', 90 | 'mujoco_white_desk_HS_extreme_color_temp', 91 | 'mujoco_white_desk_HS_normal_color_temp'] 92 | 93 | convert_to_coco_format(mujoco_subfolders, output_file=os.path.join(ROOT_DIR, 'mujoco_all_train_coco_format_with_keypoints.json')) 94 | -------------------------------------------------------------------------------- /src/criterions/weightings/instance_weight.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import numpy as np 4 | 5 | from criterions.weightings.unbalanced_weight import UnbalancedWeighting 6 | 7 | class InstanceGroupWeighting(UnbalancedWeighting): 8 | IGNORE_FLAG=-9999 9 | 10 | def __init__(self, *args, **kwargs): 11 | super().__init__(*args, **kwargs) 12 | 13 | def __call__(self, gt_instances, gt_ignore=None, gt_R=None, w_fg=1, w_bg=1, *args, **kwargs): 14 | 15 | batch_size, height, width = gt_instances.shape 16 | 17 | bg_mask = (gt_instances == 0).unsqueeze(1) 18 | fg_mask = bg_mask == False 19 | 20 | mask_weights = torch.ones_like(bg_mask, dtype=torch.float32, requires_grad=False, device=gt_instances.device) 21 | 22 | mask_weights[fg_mask] = w_fg 23 | mask_weights[bg_mask] = w_bg 24 | 25 | # apply additional weights around borders 26 | if 
self.border_weight_px > 0: 27 | mask_weights = self._apply_border_weights(mask_weights) 28 | 29 | if gt_ignore is not None: 30 | mask_weights *= 1 - gt_ignore.type(mask_weights.type()) 31 | 32 | # ensure each instance (and background) is weighted equally regardless of pixels size 33 | # count number of pixels per instance 34 | instance_ids, instance_sizes = gt_instances.reshape(gt_instances.shape[0], -1).unique(return_counts=True, dim=-1) 35 | 36 | # count number of instance for each batch element (without background and ignored regions) 37 | num_bg_pixels = instance_sizes.repeat(batch_size, 1)[instance_ids == 0].sum().float() 38 | 39 | mask_weights = self._init_grouped_weights(mask_weights, gt_instances, instance_ids, instance_sizes, num_bg_pixels) 40 | 41 | # apply additional weight based on distance to center 42 | if self.add_distance_gauss_weight > 0: 43 | mask_weights = self._apply_gauss_distance_weights(mask_weights, gt_R) 44 | 45 | return mask_weights 46 | 47 | def _init_grouped_weights(self, W, group_instance, group_instance_ids, group_instance_sizes, num_bg_pixels, 48 | num_hard_negative_pixels=torch.tensor(0.0)): 49 | # ensure each instance (and background) is weighted equally regardless of pixels size 50 | num_instances = sum([len(set(ids.unique().cpu().numpy()) - set([0, self.IGNORE_FLAG])) for ids in group_instance]) 51 | for b in range(len(group_instance)): 52 | for id in group_instance_ids[b].unique(): 53 | mask_id = group_instance[b].eq(id).unsqueeze(0) 54 | if id == 0: 55 | # for BG instance we normalize based on the number of all bg pixels over the whole batch 56 | instance_normalization = num_bg_pixels * 1 57 | instance_normalization = instance_normalization * ( 58 | 3 / 1.0 if num_hard_negative_pixels > 0 else 2) 59 | elif id < 0: 60 | if num_hard_negative_pixels > 0: 61 | # for hard-negative instances we normalized based on number of them (in pixels) 62 | instance_normalization = num_hard_negative_pixels * torch.log(num_hard_negative_pixels + 1) 63 | instance_normalization = instance_normalization * 3 / 1.0 64 | else: 65 | instance_normalization = 1.0 66 | else: 67 | # for FG instances we normalized based on the size of instance (in pixel) and the number of 68 | # instances over the whole batch 69 | instance_pixels = group_instance_sizes[group_instance_ids[b] == id].sum().float() 70 | instance_normalization = instance_pixels * num_instances * 1 71 | instance_normalization = instance_normalization * ( 3 / 1.0 if num_hard_negative_pixels > 0 else 2) 72 | 73 | # BG and FG are treated as equal so add multiplication by 2 (or 3 if we also have hard-negatives) 74 | # instance_normalization = instance_normalization * _N 75 | W[b][mask_id] *= 1.0 / instance_normalization 76 | return W 77 | 78 | -------------------------------------------------------------------------------- /src/criterions/loss_weighting/weight_methods.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import List, Tuple, Union 3 | 4 | import torch 5 | 6 | 7 | class WeightMethod: 8 | def __init__(self, n_tasks: int, device: torch.device): 9 | super().__init__() 10 | self.n_tasks = n_tasks 11 | self.device = device 12 | 13 | @abstractmethod 14 | def get_weighted_loss( 15 | self, 16 | losses: torch.Tensor, 17 | shared_parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor], 18 | task_specific_parameters: Union[ 19 | List[torch.nn.parameter.Parameter], torch.Tensor 20 | ], 21 | last_shared_parameters: 
Union[List[torch.nn.parameter.Parameter], torch.Tensor], 22 | representation: Union[torch.nn.parameter.Parameter, torch.Tensor], 23 | **kwargs, 24 | ): 25 | pass 26 | 27 | def backward( 28 | self, 29 | losses: torch.Tensor, 30 | shared_parameters: Union[ 31 | List[torch.nn.parameter.Parameter], torch.Tensor 32 | ] = None, 33 | task_specific_parameters: Union[ 34 | List[torch.nn.parameter.Parameter], torch.Tensor 35 | ] = None, 36 | last_shared_parameters: Union[ 37 | List[torch.nn.parameter.Parameter], torch.Tensor 38 | ] = None, 39 | representation: Union[List[torch.nn.parameter.Parameter], torch.Tensor] = None, 40 | **kwargs, 41 | ) -> Tuple[Union[torch.Tensor, None], Union[dict, None]]: 42 | """ 43 | Parameters 44 | ---------- 45 | losses : 46 | shared_parameters : 47 | task_specific_parameters : 48 | last_shared_parameters : parameters of last shared layer/block 49 | representation : shared representation 50 | kwargs : 51 | Returns 52 | ------- 53 | Loss, extra outputs 54 | """ 55 | loss, extra_outputs = self.get_weighted_loss( 56 | losses=losses, 57 | shared_parameters=shared_parameters, 58 | task_specific_parameters=task_specific_parameters, 59 | last_shared_parameters=last_shared_parameters, 60 | representation=representation, 61 | **kwargs, 62 | ) 63 | loss.backward() 64 | return loss, extra_outputs 65 | 66 | def __call__( 67 | self, 68 | losses: torch.Tensor, 69 | shared_parameters: Union[ 70 | List[torch.nn.parameter.Parameter], torch.Tensor 71 | ] = None, 72 | task_specific_parameters: Union[ 73 | List[torch.nn.parameter.Parameter], torch.Tensor 74 | ] = None, 75 | **kwargs, 76 | ): 77 | return self.backward( 78 | losses=losses, 79 | shared_parameters=shared_parameters, 80 | task_specific_parameters=task_specific_parameters, 81 | **kwargs, 82 | ) 83 | 84 | def parameters(self) -> List[torch.Tensor]: 85 | """return learnable parameters""" 86 | return [] 87 | 88 | class Uncertainty(WeightMethod): 89 | """Implementation of `Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics` 90 | Source: https://github.com/yaringal/multi-task-learning-example/blob/master/multi-task-learning-example-pytorch.ipynb 91 | """ 92 | 93 | def __init__(self, n_tasks, device: torch.device): 94 | super().__init__(n_tasks, device=device) 95 | self.logsigma = torch.tensor([0.0] * n_tasks, device=device, requires_grad=True) 96 | 97 | def get_weighted_loss(self, losses: torch.Tensor, **kwargs): 98 | loss = sum( 99 | [ 100 | 0.5 * (torch.exp(-logs) * loss + logs) 101 | for loss, logs in zip(losses, self.logsigma) 102 | ] 103 | ) 104 | 105 | return loss, dict( 106 | weights=torch.exp(-self.logsigma) 107 | ) # NOTE: not exactly task weights 108 | 109 | def parameters(self) -> List[torch.Tensor]: 110 | return [self.logsigma] 111 | 112 | 113 | def get_weight_method(method: str, n_tasks: int, device: torch.device, **kwargs): 114 | assert method in list(METHODS.keys()), f"unknown method {method}." 
115 | 116 | return METHODS[method](n_tasks=n_tasks, device=device, **kwargs) 117 | 118 | METHODS = dict( 119 | uw=Uncertainty, 120 | ) -------------------------------------------------------------------------------- /src/criterions/center_localization_loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from criterions.per_pixel_losses import get_per_pixel_loss_func 7 | 8 | class CenterLocalizationLoss(nn.Module): 9 | 10 | def __init__(self, loss_type='l1', ignore_negative_gradient=False, fp_threshold=0.1, 11 | use_per_instance_normalization=False, positive_area_radius=1, **kargs): 12 | super().__init__() 13 | 14 | self.ignore_negative_gradient = ignore_negative_gradient 15 | self.fp_threshold = fp_threshold # 0.1 16 | self.use_per_instance_normalization = use_per_instance_normalization 17 | self.positive_area_radius = positive_area_radius 18 | 19 | self.loss_fn = get_per_pixel_loss_func(loss_type) 20 | 21 | def forward(self, prediction_centers, prediction_prob, gt_centers, gt_center_mask, ignore_mask=None, 22 | w_fg=1, w_bg=1, reduction_dims=(1,2,3), **kwargs): 23 | 24 | batch_size, height, width = prediction_prob.size(0), prediction_prob.size(2), prediction_prob.size(3) 25 | 26 | loss_centers = torch.zeros(size=[d for i,d in enumerate(prediction_prob.shape) if i not in reduction_dims], 27 | device=prediction_prob.device) 28 | 29 | with torch.no_grad(): 30 | if self.use_per_instance_normalization: 31 | mask_weights = self._calc_per_instance_weight_mask(prediction_centers, prediction_prob, gt_centers, 32 | thr=self.fp_threshold, N=100, R=self.positive_area_radius) 33 | else: 34 | mask_weights = torch.ones_like(prediction_prob, requires_grad=False, device=prediction_prob.device, dtype=torch.float32) 35 | mask_weights *= 1.0 / (height * width * batch_size) 36 | 37 | if ignore_mask is not None: 38 | mask_weights *= 1 - ignore_mask.type(mask_weights.type()) 39 | 40 | if w_fg != 1: 41 | mask_weights[gt_center_mask > 0] *= w_fg 42 | if w_bg != 1: 43 | mask_weights[gt_center_mask <= 0] *= w_bg 44 | 45 | if self.ignore_negative_gradient: 46 | loss_centers += torch.sum(mask_weights * torch.where((prediction_prob <= 0) * (gt_center_mask <= 0), 47 | torch.zeros_like(prediction_prob, requires_grad=False), 48 | self.loss_fn(prediction_prob, gt_center_mask)), 49 | dim=reduction_dims) 50 | else: 51 | loss_centers += torch.sum(mask_weights * self.loss_fn(prediction_prob, gt_center_mask), 52 | dim=reduction_dims) 53 | 54 | return loss_centers 55 | 56 | @staticmethod 57 | def _calc_per_instance_weight_mask(prediction_centers, prediction_prob, gt_centers, thr, N, R): 58 | batch_size, height, width = prediction_prob.size(0), prediction_prob.size(2), prediction_prob.size(3) 59 | 60 | mask_weights = torch.zeros_like(prediction_prob, requires_grad=False, device=prediction_prob.device, dtype=torch.float32) 61 | 62 | centers_pred = np.array(prediction_centers) 63 | for b in range(batch_size): 64 | # mark positive groundtruth areas 65 | for x, y in gt_centers[b, 1:]: 66 | if x == 0 and y == 0: break 67 | x, y = int(x), int(y) 68 | mask_weights[b, :, x - R:x + R, y - R:y + R] = 1 69 | 70 | # find hard-negative centers if exist any 71 | if len(centers_pred) <= 0: 72 | continue 73 | 74 | batch_centers = centers_pred[centers_pred[:, 0] == b, :] 75 | 76 | if len(batch_centers) <= 0: 77 | continue 78 | 79 | ids = np.argsort(batch_centers[:, -1])[::-1] 80 | 81 | hard_neg_centers = batch_centers[ids, :] 82 | 
hard_neg_centers = hard_neg_centers[hard_neg_centers[:, -1] > thr] 83 | if len(hard_neg_centers) > N: 84 | hard_neg_centers = hard_neg_centers[:N, :] 85 | 86 | # mark hard-negative areas 87 | for x, y in hard_neg_centers[:, 1:3]: 88 | x, y = int(x), int(y) 89 | mask_weights[b, :, y - R:y + R, x - R:x + R] = 1 90 | 91 | # original version 92 | if True: 93 | hard_neg_pixels = max(mask_weights.sum().item(), 1.0) 94 | 95 | mask_weights *= (1.0 / hard_neg_pixels - 1 / ( 96 | height * width * batch_size - hard_neg_pixels)) 97 | mask_weights += 1 / (height * width * batch_size - hard_neg_pixels) 98 | 99 | # using only FP pixels 100 | if False: 101 | hard_neg_pixels = mask_weights.sum().item() 102 | if hard_neg_pixels > 0: 103 | mask_weights *= (1.0 / hard_neg_pixels) 104 | 105 | return mask_weights 106 | -------------------------------------------------------------------------------- /tools/export_vicos_towels_to_coco.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import numpy as np 4 | from PIL import Image 5 | from tqdm import tqdm 6 | from datetime import datetime 7 | import importlib 8 | 9 | import sys 10 | sys.path.append(os.path.join(os.path.dirname(__file__), '../src')) 11 | 12 | # Define a function to convert ViCoSClothDataset annotations to COCO format 13 | def convert_to_coco_format(subfolders, output_file, R = 25): 14 | CAT_CENTER_ID = 1 15 | CAT_CENTER_NAME = "corner" 16 | 17 | # Create COCO format dictionary 18 | coco_format = { 19 | "info": { 20 | "version": "v2", 21 | "description": "Vicos Towels Dataset", 22 | "contributor": "domen.tabernik@fri.uni-lj.si", 23 | "url": "", 24 | "year": datetime.now().year, 25 | "date_created": datetime.now().strftime("%Y-%m-%d"), 26 | }, 27 | "categories": [{"supercategory": CAT_CENTER_NAME, "id": CAT_CENTER_ID, "name": CAT_CENTER_NAME}], 28 | "licenses": [], 29 | "images": [], 30 | "annotations": [], 31 | } 32 | 33 | # Load annotations 34 | db = ViCoSTowelDataset(root_dir=ROOT_DIR, subfolders=subfolders, transform=None, use_depth=False, correct_depth_rotation=False, use_normals=False) 35 | db.return_image = False 36 | db.check_consistency = False 37 | 38 | annt_id = 0 39 | # Loop through images in dataset 40 | for img_id,item in enumerate(tqdm(db)): 41 | filename = item['im_name'] 42 | 43 | # change filename relative to output_file 44 | filename = filename.replace(os.path.dirname(output_file),".") 45 | 46 | # Add image information to COCO format dictionary 47 | coco_format["images"].append({ 48 | "id": img_id, 49 | "file_name": filename, 50 | "height": item['im_size'][1], 51 | "width": item['im_size'][0] 52 | }) 53 | 54 | # Loop through centers for current image 55 | gt_centers = item['center'] 56 | gt_centers = gt_centers[(gt_centers[:, 0] > 0) | (gt_centers[:, 1] > 0), :] 57 | 58 | for center in gt_centers: 59 | # calculate bounding box in XYWH format by adding R to center 60 | bbox = [center[0]-R, center[1]-R, 2*R, 2*R] 61 | 62 | # Add annotation information to COCO format dictionary 63 | coco_format["annotations"].append({ 64 | "id": annt_id, 65 | "image_id": img_id, 66 | "category_id": CAT_CENTER_ID, 67 | "bbox": bbox, 68 | "segmentation": [[bbox[0], bbox[1], 69 | bbox[0]+bbox[2], bbox[1], 70 | bbox[0]+bbox[2], bbox[1]+bbox[3], 71 | bbox[0], bbox[1]+bbox[3], 72 | bbox[0], bbox[1]]], 73 | "area": int(np.prod(bbox[2:])), 74 | "iscrowd": 0 75 | }) 76 | annt_id+=1 77 | 78 | # Save COCO format dictionary to output directory 79 | with open(output_file, "w") as f: 80 | 
json.dump(coco_format, f) 81 | 82 | 83 | if __name__ == "__main__": 84 | ROOT_DIR = '/storage/datasets/ClothDataset/ClothDatasetVICOS/' 85 | 86 | from datasets.ViCoSTowelDataset import ViCoSTowelDataset 87 | 88 | os.environ['RTFM_CLOTH_DATASET_VICOS_DIR'] = ROOT_DIR 89 | from config.vicos_towel.train import get_args as get_train_args 90 | from config.vicos_towel.test import get_args as get_test_args 91 | 92 | get_novel_bg_test_args = importlib.import_module("config.vicos_towel.novel_object=bg.test").get_args 93 | get_novel_cloth_test_args = importlib.import_module("config.vicos_towel.novel_object=cloth.test").get_args 94 | get_novel_bg_and_cloth_test_args = importlib.import_module("config.vicos_towel.novel_object=bg+cloth.test").get_args 95 | 96 | train_subfolders = get_train_args()['train_dataset']['kwargs']['subfolders'] 97 | test_subfolders = get_test_args()['dataset']['kwargs']['subfolders'] 98 | 99 | test_novel_bg_subfolders = get_novel_bg_test_args()['dataset']['kwargs']['subfolders'] 100 | test_novel_cloth_subfolders = get_novel_cloth_test_args()['dataset']['kwargs']['subfolders'] 101 | test_novel_cloth_and_bg_subfolders = get_novel_bg_and_cloth_test_args()['dataset']['kwargs']['subfolders'] 102 | 103 | convert_to_coco_format(train_subfolders, output_file=os.path.join(ROOT_DIR, 'coco_train.json')) 104 | convert_to_coco_format(test_subfolders, output_file=os.path.join(ROOT_DIR, 'coco_test.json')) 105 | 106 | convert_to_coco_format(test_novel_bg_subfolders, output_file=os.path.join(ROOT_DIR, 'coco_test_novel=bg.json')) 107 | convert_to_coco_format(test_novel_cloth_subfolders, output_file=os.path.join(ROOT_DIR, 'coco_test_novel=cloth.json')) 108 | convert_to_coco_format(test_novel_cloth_and_bg_subfolders, output_file=os.path.join(ROOT_DIR, 'coco_test_novel=bg+cloth.json')) 109 | -------------------------------------------------------------------------------- /tools/export_vicos_towels_to_coco_keypoints.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import numpy as np 4 | from PIL import Image 5 | from tqdm import tqdm 6 | from datetime import datetime 7 | import importlib 8 | 9 | import sys 10 | sys.path.append(os.path.join(os.path.dirname(__file__), '../../src')) 11 | 12 | # Define a function to convert ViCoSClothDataset annotations to COCO format 13 | def convert_to_coco_format(subfolders, output_file, R = 25): 14 | CAT_CENTER_ID = 1 15 | CAT_CENTER_NAME = "corner" 16 | 17 | # Create COCO format dictionary 18 | coco_format = { 19 | "info": { 20 | "version": "v2", 21 | "description": "Vicos Cloth Dataset", 22 | "contributor": "domen.tabernik@fri.uni-lj.si", 23 | "url": "", 24 | "year": datetime.now().year, 25 | "date_created": datetime.now().strftime("%Y-%m-%d"), 26 | }, 27 | "categories": [{"supercategory": CAT_CENTER_NAME, "id": CAT_CENTER_ID, "name": CAT_CENTER_NAME, "keypoints": ["center"],}], 28 | "licenses": [], 29 | "images": [], 30 | "annotations": [], 31 | } 32 | 33 | # Load annotations 34 | db = ViCoSTowelDataset(root_dir=ROOT_DIR, subfolders=subfolders, transform=None, use_depth=False, correct_depth_rotation=False, use_normals=False) 35 | db.return_image = False 36 | db.check_consistency = False 37 | db.remove_out_of_bounds_centers = False 38 | 39 | annt_id = 0 40 | # Loop through images in dataset 41 | for img_id,item in enumerate(tqdm(db)): 42 | filename = item['im_name'] 43 | 44 | # change filename relative to output_file 45 | filename = filename.replace(os.path.dirname(output_file),".") 46 | 47 | # 
Add image information to COCO format dictionary 48 | coco_format["images"].append({ 49 | "id": img_id, 50 | "file_name": filename, 51 | "height": item['im_size'][1], 52 | "width": item['im_size'][0] 53 | }) 54 | 55 | # Loop through centers for current image 56 | gt_centers = item['center'] 57 | gt_centers = gt_centers[(gt_centers[:, 0] > 0) | (gt_centers[:, 1] > 0), :] 58 | 59 | for center in gt_centers: 60 | # calculate bounding box in XYWH format by adding R to center 61 | bbox = [center[0]-R, center[1]-R, 2*R, 2*R] 62 | 63 | # Add annotation information to COCO format dictionary 64 | coco_format["annotations"].append({ 65 | "id": annt_id, 66 | "image_id": img_id, 67 | "category_id": CAT_CENTER_ID, 68 | "bbox": bbox, 69 | "segmentation": [[bbox[0], bbox[1], 70 | bbox[0]+bbox[2], bbox[1], 71 | bbox[0]+bbox[2], bbox[1]+bbox[3], 72 | bbox[0], bbox[1]+bbox[3], 73 | bbox[0], bbox[1]]], 74 | "keypoints": [center[0], center[1], 2], 75 | "num_keypoints": 1, 76 | "area": int(np.prod(bbox[2:])), 77 | "iscrowd": 0 78 | }) 79 | annt_id+=1 80 | 81 | # Save COCO format dictionary to output directory 82 | with open(output_file, "w") as f: 83 | json.dump(coco_format, f) 84 | 85 | 86 | if __name__ == "__main__": 87 | ROOT_DIR = '/storage/datasets/ClothDataset/ClothDatasetVICOS/' 88 | 89 | from datasets.ViCoSTowelDataset import ViCoSTowelDataset 90 | 91 | os.environ['RTFM_CLOTH_DATASET_VICOS_DIR'] = ROOT_DIR 92 | from config.vicos_towel.train import get_args as get_train_args 93 | from config.vicos_towel.test import get_args as get_test_args 94 | 95 | get_novel_bg_test_args = importlib.import_module("config.vicos_towel.novel_object=bg.test").get_args 96 | get_novel_cloth_test_args = importlib.import_module("config.vicos_towel.novel_object=cloth.test").get_args 97 | get_novel_bg_and_cloth_test_args = importlib.import_module("config.vicos_towel.novel_object=bg+cloth.test").get_args 98 | 99 | train_subfolders = get_train_args()['train_dataset']['kwargs']['subfolders'] 100 | test_subfolders = get_test_args()['dataset']['kwargs']['subfolders'] 101 | 102 | test_novel_bg_subfolders = get_novel_bg_test_args()['dataset']['kwargs']['subfolders'] 103 | test_novel_cloth_subfolders = get_novel_cloth_test_args()['dataset']['kwargs']['subfolders'] 104 | test_novel_cloth_and_bg_subfolders = get_novel_bg_and_cloth_test_args()['dataset']['kwargs']['subfolders'] 105 | 106 | convert_to_coco_format(train_subfolders, output_file=os.path.join(ROOT_DIR, 'coco_train_with_keypoints.json')) 107 | convert_to_coco_format(test_subfolders, output_file=os.path.join(ROOT_DIR, 'coco_test_with_keypoints.json')) 108 | 109 | #convert_to_coco_format(test_novel_bg_subfolders, output_file=os.path.join(ROOT_DIR, 'coco_test_novel=bg_with_keypoints.json')) 110 | #convert_to_coco_format(test_novel_cloth_subfolders, output_file=os.path.join(ROOT_DIR, 'coco_test_novel=cloth_with_keypoints.json')) 111 | #convert_to_coco_format(test_novel_cloth_and_bg_subfolders, output_file=os.path.join(ROOT_DIR, 'coco_test_novel=bg+cloth_with_keypoints.json')) 112 | -------------------------------------------------------------------------------- /scripts/EXPERIMENTS_MAIN.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # include config and some utils 4 | source ./config.sh 5 | source ./utils.sh 6 | 7 | export USE_CONDA_ENV=CeDiRNet-py3.8 8 | export DISABLE_X11=0 9 | 10 | centernet_filename="${ROOT_DIR}/models/localization_checkpoint.pth" 11 | 12 | DO_SYNT_TRAINING=True # step 0: pretraining on syntetic data 
(MuJoCo) 13 | DO_REAL_TRAINING=True # step 1: training on real-world data (ViCoS Towel Dataset) 14 | DO_EVALUATION=True # step 3: evaluate 15 | 16 | # assuming 4 GPUs available on localhost 17 | GPU_LIST=("localhost:0" "localhost:1" "localhost:2" "localhost:4") 18 | GPU_COUNT=${#GPU_LIST[@]} 19 | 20 | 21 | ######################################## 22 | # PRETRAINING on synthetic data only 23 | ######################################## 24 | 25 | if [[ "$DO_SYNT_TRAINING" == True ]] ; then 26 | s=0 27 | for db in "mujoco"; do 28 | export DATASET=$db 29 | for backbone in "tu-convnext_base" "tu-convnext_large" ; do 30 | for epoch in 10; do 31 | for depth in off on; do 32 | if [[ "$depth" == off ]] ; then 33 | export USE_DEPTH=False 34 | elif [[ "$depth" == on ]] ; then 35 | export USE_DEPTH=True 36 | fi 37 | SERVERS=${GPU_LIST[$((s % GPU_COUNT))]} ./run_distributed.sh python train.py --config \ 38 | model.kwargs.backbone=$backbone \ 39 | n_epochs=$epoch \ 40 | train_dataset.batch_size=4 \ 41 | train_dataset.workers=16 \ 42 | "pretrained_center_model_path=$centernet_filename" \ 43 | display=False save_interval=1 skip_if_exists=True & 44 | s=$((s+1)) 45 | wait_or_interrupt $GPU_COUNT $s 46 | done 47 | done 48 | done 49 | done 50 | fi 51 | wait_or_interrupt 52 | 53 | ######################################## 54 | # Training on real data 55 | ######################################## 56 | 57 | if [[ "$DO_REAL_TRAINING" == True ]] ; then 58 | s=0 59 | 60 | for db in "vicos_towel"; do 61 | export DATASET=$db 62 | export TRAIN_SIZE=768 63 | for backbone in "tu-convnext_base" "tu-convnext_large" ; do 64 | for epoch in 10; do 65 | for depth in off on; do 66 | if [[ "$depth" == off ]] ; then 67 | export USE_DEPTH=False 68 | depth_str=False 69 | elif [[ "$depth" == on ]] ; then 70 | export USE_DEPTH=True 71 | depth_str=True 72 | fi 73 | PRETRAINED_CHECKPOINT="${OUTPUT_DIR}/mujoco/backbone=${backbone}/num_train_epoch=10/depth=$depth_str/multitask_weight=uw/checkpoint.pth" 74 | SERVERS=${GPU_LIST[$((s % GPU_COUNT))]} ./run_distributed.sh python train.py --config \ 75 | model.kwargs.backbone=$backbone \ 76 | n_epochs=$epoch \ 77 | "pretrained_center_model_path=$centernet_filename" \ 78 | "pretrained_model_path=$PRETRAINED_CHECKPOINT" \ 79 | display=True skip_if_exists=True "save_interval=1" & 80 | s=$((s+1)) 81 | wait_or_interrupt $GPU_COUNT $s 82 | done 83 | done 84 | done 85 | done 86 | fi 87 | wait_or_interrupt 88 | 89 | ######################################## 90 | # Evaluating on test data 91 | ######################################## 92 | 93 | if [[ "$DO_EVALUATION" == True ]] ; then 94 | 95 | # FOR FULL EVAL 96 | DISPLAY_ARGS="display=False display_to_file_only=True skip_if_exists=True" 97 | 98 | s=0 99 | export DISABLE_X11=0 100 | for db in "vicos_towel"; do 101 | for cfg_subname in ""; do # for exclusively unseen objects set to "novel_object=bg+cloth" (or to "novel_object=cloth" "novel_object=bg") 102 | export DATASET=$db 103 | export TRAIN_SIZE=768 104 | export TEST_SIZE=768 105 | for backbone in "tu-convnext_base" "tu-convnext_large" ; do 106 | for epoch_train in 10; do 107 | ALL_EPOCH=("") # set to ALL_EPOCH=("" _002 _004 _006 _008) to evaluate every second epoch 108 | for epoch_eval in "${ALL_EPOCH[@]}"; do 109 | for depth in off on; do 110 | if [[ "$depth" == off ]] ; then 111 | export USE_DEPTH=False 112 | elif [[ "$depth" == on ]] ; then 113 | export USE_DEPTH=True 114 | fi 115 | # run center model pre-trained on weakly-supervised 116 | SERVERS=${GPU_LIST[$((s % GPU_COUNT))]} 
./run_distributed.sh python test.py --cfg_subname="$cfg_subname" --config \ 117 | eval_epoch=$epoch_eval \ 118 | model.kwargs.backbone=$backbone \ 119 | train_settings.n_epochs=$epoch_train \ 120 | "center_checkpoint_path=$centernet_filename" \ 121 | center_checkpoint_name_list=None \ 122 | $DISPLAY_ARGS & 123 | 124 | s=$((s+1)) 125 | wait_or_interrupt $GPU_COUNT $s 126 | done 127 | done 128 | done 129 | done 130 | done 131 | done 132 | fi 133 | wait -------------------------------------------------------------------------------- /tools/calculate_cloth_angles.py: -------------------------------------------------------------------------------- 1 | import glob, os, json, cv2 2 | from matplotlib import pyplot as plt 3 | import numpy as np 4 | import pyransac3d as pyrsc 5 | import matplotlib.pyplot as plt 6 | 7 | from numpy.linalg import inv 8 | 9 | from utils.utils_depth import * 10 | 11 | fx = 1081.3720703125 12 | cx = 959.5 13 | cy = 539.5 14 | 15 | K = np.array([[fx, 0.0, cx], [0, fx, cy], [0,0,1]]) 16 | 17 | def calculate_pitch(img_fn, depth_fn, downsample_factor=0.1): 18 | global K 19 | pitch = 0 20 | 21 | K_ = K.copy() 22 | 23 | img = cv2.imread(img_fn).astype(float)/255 24 | depth = np.load(depth_fn) 25 | 26 | # preprocess depth 27 | depth[np.isnan(depth)]=0 28 | depth[np.isinf(depth)]=0 29 | depth[depth>1e6]=0 30 | depth*=1e-3 31 | 32 | img = cv2.resize(img, None, fx=downsample_factor, fy=downsample_factor) 33 | depth = cv2.resize(depth, None, fx=downsample_factor, fy=downsample_factor, interpolation=cv2.INTER_NEAREST) 34 | 35 | h, w, _ = img.shape 36 | 37 | # build point cloud 38 | xx, yy = np.meshgrid(range(w), range(h)) 39 | xx = np.ravel(xx) 40 | yy = np.ravel(yy) 41 | dd = np.ravel(depth) 42 | pc = np.vstack((xx,yy,dd)) 43 | 44 | # un-project points 45 | K_*=downsample_factor 46 | K_[-1,-1]=1 47 | pc = inv(K_)@pc 48 | 49 | # estimate plane 50 | plane1 = pyrsc.Plane() 51 | best_eq, best_inliers = plane1.fit(pc.T, 0.01) 52 | 53 | # extract pitch 54 | ca = best_eq[2]/np.sqrt(best_eq[0]**2+best_eq[1]**2+best_eq[2]**2) 55 | pitch = np.degrees(np.arccos(ca)) 56 | pitch = 180-pitch if pitch>90 else pitch 57 | 58 | display = False 59 | if display: 60 | 61 | R = eul2rot((np.radians(pitch),0,0)) 62 | pc = K_@R@pc 63 | d_rotated = pc[-1,:].reshape(h,w) 64 | 65 | plt.clf() 66 | plt.subplot(2,2,1) 67 | plt.imshow(img) 68 | plt.subplot(2,2,2) 69 | plt.imshow(depth) 70 | plt.subplot(2,2,3) 71 | plt.imshow(d_rotated) 72 | plt.subplot(2,2,4) 73 | plt.imshow(np.abs(depth-d_rotated)) 74 | plt.draw(); plt.pause(0.01) 75 | plt.waitforbuttonpress() 76 | 77 | return pitch 78 | 79 | def main(): 80 | 81 | dataset_path = '/storage/datasets/ClothDataset/ClothDatasetVICOS/' 82 | dataset_setups = glob.glob(f'{dataset_path}bg=*') 83 | 84 | for setup in dataset_setups: 85 | 86 | cloths = glob.glob(f'{setup}/cloth=*') 87 | 88 | for subset in cloths: 89 | print(subset) 90 | 91 | images = glob.glob(f'{subset}/rgb/*') 92 | print(len(images)) 93 | 94 | data = {} 95 | 96 | for fn in images: 97 | name = fn.split('/')[-1] 98 | 99 | depth_fn = f'{subset}/depth/{name[:-4]}.npy' 100 | 101 | if not os.path.exists(depth_fn): 102 | depth_fn = depth_fn.replace('camera0', 'camera1') 103 | 104 | pitch = calculate_pitch(fn, depth_fn) 105 | print("pitch", pitch) 106 | 107 | data[name]={'pitch': pitch} 108 | 109 | with open(f'{subset}/plane_angles.json', 'w', encoding='utf-8') as f: 110 | json.dump(data, f, ensure_ascii=False, indent=4) 111 | 112 | def check(): 113 | 114 | dataset_path = 
'/storage/datasets/ClothDataset/ClothDatasetVICOS/' 115 | dataset_setups = glob.glob(f'{dataset_path}bg=*') 116 | 117 | for setup in dataset_setups: 118 | 119 | cloths = glob.glob(f'{setup}/cloth=*') 120 | 121 | for subset in cloths: 122 | print(subset) 123 | 124 | images = glob.glob(f'{subset}/rgb/*') 125 | 126 | with open(f'{subset}/plane_angles.json') as f: 127 | data = json.load(f) 128 | 129 | for fn in images: 130 | name = fn.split('/')[-1] 131 | 132 | pitch = data[name]['pitch'] 133 | 134 | depth_fn = f'{subset}/depth/{name[:-4]}.npy' 135 | 136 | if not os.path.exists(depth_fn): 137 | depth_fn = depth_fn.replace('camera0', 'camera1') 138 | 139 | # display result 140 | img = cv2.imread(fn).astype(float)/255 141 | depth = np.load(depth_fn) 142 | 143 | # preprocess depth 144 | depth[np.isnan(depth)]=0 145 | depth[np.isinf(depth)]=0 146 | depth[depth>1e6]=0 147 | depth*=1e-3 148 | 149 | R = eul2rot((np.radians(pitch),0,0)) 150 | 151 | depth_rotated = rotate_depth(depth, R, K) 152 | 153 | plt.clf() 154 | plt.subplot(2,2,1) 155 | plt.imshow(img) 156 | plt.subplot(2,2,2) 157 | plt.imshow(depth) 158 | plt.subplot(2,2,3) 159 | plt.imshow(depth_rotated) 160 | plt.subplot(2,2,4) 161 | plt.imshow(np.abs(depth-depth_rotated)) 162 | plt.draw(); plt.pause(0.01) 163 | plt.waitforbuttonpress() 164 | 165 | break 166 | 167 | # data.append((name, {'pitch': 20})) 168 | 169 | # TODO calculate pitch 170 | 171 | # break 172 | 173 | def check_single(): 174 | 175 | pth = '/storage/datasets/ClothDataset/ClothDatasetVICOS/bg=festive_tablecloth/cloth=linen_rag/' 176 | name = 'image_0000_view19_ls2_camera1' 177 | 178 | im_fn = f'{pth}rgb/{name}.jpg' 179 | im_fn = im_fn.replace('camera1', 'camera0') 180 | print(im_fn) 181 | depth_fn = f'{pth}depth/{name}.npy' 182 | print(depth_fn) 183 | 184 | img = cv2.imread(im_fn).astype(float)/255 185 | depth = np.load(depth_fn) 186 | 187 | depth[np.isnan(depth)]=0 188 | depth[np.isinf(depth)]=0 189 | depth[depth>1e6]=0 190 | depth*=1e-3 191 | 192 | pitch = 20 193 | 194 | R = eul2rot((np.radians(pitch),0,0)) 195 | 196 | depth_rotated = rotate_depth(depth, R, K) 197 | 198 | plt.clf() 199 | plt.subplot(2,2,1) 200 | plt.imshow(img) 201 | plt.subplot(2,2,2) 202 | plt.imshow(depth) 203 | plt.subplot(2,2,3) 204 | plt.imshow(depth_rotated) 205 | plt.subplot(2,2,4) 206 | plt.imshow(np.abs(depth-depth_rotated)) 207 | plt.draw(); plt.pause(0.01) 208 | plt.waitforbuttonpress() 209 | 210 | if __name__=='__main__': 211 | main() 212 | # check() 213 | # check_single() -------------------------------------------------------------------------------- /src/utils/utils_depth.py: -------------------------------------------------------------------------------- 1 | import numpy as np, cv2, glob, os 2 | from matplotlib import pyplot as plt 3 | from numpy.linalg import norm, inv 4 | 5 | def get_surface_normal_by_depth(depth, K=None): 6 | """ 7 | depth: (h, w) of float, the unit of depth is meter 8 | K: (3, 3) of float, the depth camere's intrinsic 9 | """ 10 | K = [[1, 0], [0, 1]] if K is None else K 11 | fx, fy = K[0][0], K[1][1] 12 | 13 | dz_dv, dz_du = np.gradient(depth) # u, v mean the pixel coordinate in the image 14 | du_dx = fx / depth # x is xyz of camera coordinate 15 | dv_dy = fy / depth 16 | 17 | dz_dx = dz_du * du_dx 18 | dz_dy = dz_dv * dv_dy 19 | normal_cross = np.dstack((-dz_dx, -dz_dy, np.ones_like(depth))) 20 | normal_unit = normal_cross / np.linalg.norm(normal_cross, axis=2, keepdims=True) 21 | normal_unit[~np.isfinite(normal_unit).all(2)] = [0, 0, 1] 22 | return normal_unit 23 | 24 | 
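# Minimal usage sketch for the helpers in this file (depth in metres, K the 3x3 pinhole intrinsics):
#   normals = get_surface_normal_by_depth(depth, K)   # (h, w, 3) unit normals; invalid pixels become [0, 0, 1]
#   tilt_deg = np.degrees(np.arccos(np.clip(normals[..., 2], -1, 1)))  # per-pixel angle w.r.t. the optical axis
# The functions below wrap this for the Kinect ('household') and the MuJoCo camera intrinsics.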
def get_angle_from_depth(depth, household=False, k=15): 25 | # https://math.stackexchange.com/questions/3433645/how-can-i-find-the-angle-of-the-surface-3d-plane 26 | 27 | depth = cv2.GaussianBlur(depth,(k,k),0) 28 | 29 | # get normals from depth 30 | if household: # for kinect images 31 | fx = 1081.3720703125 32 | cx = 959.5 33 | cy = 539.5 34 | else: 35 | fx = 256 36 | cx = 256.0 37 | cy = 256.0 38 | 39 | K = np.array([[fx,0,cx],[0,fx,cy],[0,0,1]]) 40 | normals = get_surface_normal_by_depth(depth, K=K) 41 | 42 | return normals 43 | 44 | def eul2rot(theta) : 45 | 46 | R = np.array( 47 | [ 48 | [np.cos(theta[1])*np.cos(theta[2]), np.sin(theta[0])*np.sin(theta[1])*np.cos(theta[2]) - np.sin(theta[2])*np.cos(theta[0]),np.sin(theta[1])*np.cos(theta[0])*np.cos(theta[2]) + np.sin(theta[0])*np.sin(theta[2])], 49 | [np.sin(theta[2])*np.cos(theta[1]), np.sin(theta[0])*np.sin(theta[1])*np.sin(theta[2]) + np.cos(theta[0])*np.cos(theta[2]),np.sin(theta[1])*np.sin(theta[2])*np.cos(theta[0]) - np.sin(theta[0])*np.cos(theta[2])], 50 | [-np.sin(theta[1]),np.sin(theta[0])*np.cos(theta[1]),np.cos(theta[0])*np.cos(theta[1])] 51 | ] 52 | ) 53 | 54 | return R 55 | 56 | def get_normals(depth, normals_mode=1, household=False, k=15): 57 | 58 | normals = get_angle_from_depth(depth, household=household, k=k) 59 | 60 | # 1 is normal vector as 3 channels 61 | # 2 is angle between normal vector and each of the axes 62 | # 3 is Sobel operator x and y 63 | 64 | if normals_mode==1: 65 | normals = get_angle_from_depth(depth, household=household, k=k) 66 | elif normals_mode==2: 67 | dx = cv2.Sobel(depth, cv2.CV_64F, 1, 0, ksize=k) 68 | dy = cv2.Sobel(depth, cv2.CV_64F, 0, 1, ksize=k) 69 | 70 | dx[np.isnan(dx)]=0 71 | dy[np.isnan(dy)]=0 72 | dx[np.isinf(dx)]=0 73 | dy[np.isinf(dy)]=0 74 | 75 | dx/=np.max(dx) 76 | dy/=np.max(dy) 77 | normals = np.dstack((dx,dy)) 78 | 79 | return normals 80 | 81 | normals = get_angle_from_depth(depth, household=household, k=k) 82 | 83 | # 1 is normal vector as 3 channels 84 | # 2 is dot product of normal vector with reference normal vector in 1 channel 85 | # 3 is angle between normal plane and reference vector in 1 channel 86 | # 4 is same as 3, except expressed as sin and cos 87 | 88 | n = np.zeros_like(normals) 89 | n[...,0]=reference_normal[0] 90 | n[...,1]=reference_normal[1] 91 | n[...,2]=reference_normal[2] 92 | 93 | if normals_mode==1: 94 | depth = normals.copy() 95 | elif normals_mode==2: 96 | depth = np.sum(normals*n, axis=-1) 97 | elif normals_mode==3: 98 | depth = np.sum(normals*n, axis=-1) 99 | 100 | # print("dot 0", np.unique((normals*n)[...,0])) 101 | # print("dot 1", np.unique((normals*n)[...,1])) 102 | # print("dot 2", np.unique((normals*n)[...,2])) 103 | 104 | depth = np.arccos(depth) 105 | # TODO normalize 106 | elif normals_mode==4: 107 | # print("dot 0", np.unique((normals*n)[...,0])) 108 | # print("dot 1", np.unique((normals*n)[...,1])) 109 | # print("dot 2", np.unique((normals*n)[...,2])) 110 | depth = np.sum(normals*n, axis=-1) 111 | angle = np.arccos(depth) 112 | depth = np.dstack((np.sin(angle), np.cos(angle))) 113 | # print("depth", depth.shape) 114 | elif normals_mode==5: 115 | # depth = normals.copy() 116 | 117 | # x1 = 220 118 | # x2 = 1051 119 | # y1 = 400 120 | # y2 = 1530 121 | 122 | # mask = np.zeros_like(depth) 123 | # mask[313:1534, 286:1051]=1 124 | # mask[x1:x2, y1:y2]=1 125 | # mask = mask.astype(bool) 126 | # depth[~mask]=0 127 | 128 | 129 | dx = cv2.Sobel(depth, cv2.CV_64F, 1, 0, ksize=k) 130 | dy = cv2.Sobel(depth, cv2.CV_64F, 0, 1, ksize=k) 131 | 132 
| dx[np.isnan(dx)]=0 133 | dy[np.isnan(dy)]=0 134 | dx[np.isinf(dx)]=0 135 | dy[np.isinf(dy)]=0 136 | 137 | dx/=np.max(dx) 138 | dy/=np.max(dy) 139 | 140 | # print("dx", dx.shape, dx.dtype) 141 | # print("dy", dy.shape, dy.dtype) 142 | # print("min max dx", np.min(dx), np.max(dx)) 143 | # print("min max dy", np.min(dy), np.max(dy)) 144 | 145 | # dx[x1-k:x1+k, :] = 0 146 | # dx[x2-k:x2+k, :] = 0 147 | # dx[:, y1-k:y1+k] = 0 148 | # dx[:, y2-k:y2+k] = 0 149 | 150 | # dy[x1-k:x1+k, :] = 0 151 | # dy[x2-k:x2+k, :] = 0 152 | # dy[:, y1-k:y1+k] = 0 153 | # dy[:, y2-k:y2+k] = 0 154 | 155 | 156 | depth = np.dstack((dx,dy)) 157 | 158 | return depth 159 | 160 | def rotate_depth(depth, R, K): 161 | 162 | h, w = depth.shape 163 | 164 | xx, yy = np.meshgrid(range(w), range(h)) 165 | xx = np.ravel(xx) 166 | yy = np.ravel(yy) 167 | dd = np.ravel(depth) 168 | 169 | pc = np.vstack((xx,yy,dd)) 170 | pc = inv(K)@pc 171 | pc = K@R@pc 172 | 173 | d_rotated = pc[-1,:].reshape(h,w) 174 | 175 | return d_rotated -------------------------------------------------------------------------------- /src/utils/overlaps.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import numpy as np 4 | import cv2 5 | import skimage.draw 6 | 7 | def mask_to_rotated_bbox(mask, mask_ids, im_shape, center, estimate_type='minmax', device=None): 8 | if mask_ids is not None: 9 | mask_ids = torch.from_numpy(np.array(np.unravel_index(list(mask_ids),im_shape)).T).to(device) 10 | if estimate_type in ['minmax','3sigma-center','3sigma',None]: 11 | bbox_polygon, rect = _mask_to_rotated_bbox_using_pca(mask, mask_ids, center, estimate_type) 12 | elif estimate_type in ['cv','cv2','opencv2','opencv','findContours']: 13 | bbox_polygon, rect = _mask_to_rotated_bbox_using_cv2(mask) 14 | else: 15 | raise Exception('Invalid estimate_type in mask_to_rotated_bbox, allowed only: opencv, minmax, 3sigma and 3sigma-center') 16 | 17 | xx, yy = skimage.draw.polygon(bbox_polygon[:, 0], bbox_polygon[:, 1], shape=im_shape) 18 | 19 | bbox_ids = np.ravel_multi_index((xx, yy), dims=im_shape) 20 | if mask is not None: 21 | bbox_mask = torch.zeros_like(mask) 22 | bbox_mask[xx, yy] = 1 23 | else: 24 | bbox_mask = None 25 | 26 | return bbox_polygon, bbox_mask, bbox_ids 27 | 28 | def _mask_to_rotated_bbox_using_cv2(mask): 29 | _, contours, _ = cv2.findContours(mask.cpu().numpy(), cv2.RETR_EXTERNAL, 1) 30 | rect = cv2.minAreaRect(np.concatenate(contours,axis=0).squeeze()) 31 | (x, y), (w, h), a = rect 32 | 33 | bbox_polygon = cv2.boxPoints(rect) 34 | bbox_polygon = np.concatenate((bbox_polygon,bbox_polygon[:1,:]),axis=0) 35 | 36 | return bbox_polygon[:,::-1], rect 37 | 38 | def _mask_to_rotated_bbox_using_pca(mask_, mask_idx, center, estimate_type='minimax'): 39 | if mask_idx is None: 40 | mask_idx = torch.nonzero(mask_) 41 | mask_idx = mask_idx.float() 42 | mask_idx_center = mask_idx.mean(dim=0) 43 | A = torch.svd(mask_idx - mask_idx_center) 44 | 45 | if estimate_type == 'minmax': 46 | US_min = (A.U * A.S).min(dim=0)[0].cpu().numpy() 47 | US_max = (A.U * A.S).max(dim=0)[0].cpu().numpy() 48 | elif estimate_type == '3sigma-center': 49 | 50 | pred_center = torch.from_numpy(np.array(center)).float().to(mask_idx.device) 51 | pred_center_uv = np.matmul((pred_center - mask_idx_center).cpu().reshape(1, -1), 52 | A.V.t().inverse().cpu().numpy()) 53 | 54 | stdiv_range = 3 * (A.U * A.S - pred_center_uv.to(A.V.device)).abs().std(dim=0) 55 | US_max = stdiv_range + pred_center_uv[0].to(A.V.device) 56 | US_min = -stdiv_range + 
pred_center_uv[0].to(A.V.device) 57 | elif estimate_type == '3sigma' or estimate_type == None: 58 | US_max = 3 * (A.U * A.S).abs().std(dim=0).cpu().numpy() 59 | US_min = -US_max 60 | else: 61 | raise Exception('Invalid estimate_type in _mask_to_rotated_bbox_using_pca, allowed only: minmax, 3sigma and 3sigma-center') 62 | 63 | bbox = np.array([[US_min[0], US_min[1]], 64 | [US_min[0], US_max[1]], 65 | [US_max[0], US_max[1]], 66 | [US_max[0], US_min[1]], 67 | [US_min[0], US_min[1]]]) 68 | 69 | bbox_polygon = np.matmul(bbox, A.V.cpu().numpy().T) + mask_idx_center.cpu().numpy() 70 | 71 | rect = cv2.minAreaRect(bbox_polygon) 72 | 73 | return bbox_polygon, rect 74 | 75 | def overlap_pixels(instance_mask, gt_mask): 76 | return (torch.sum(instance_mask & gt_mask).type(torch.float32) / torch.sum(instance_mask | gt_mask).type(torch.float32)).item() 77 | 78 | def overlap_pixels_px_missing(instance_mask, gt_mask): 79 | return (torch.sum(instance_mask | gt_mask).type(torch.float32) - torch.sum(instance_mask & gt_mask).type(torch.float32)).item() 80 | 81 | def overlap_pixels_ids(instance_ids, gt_mask_ids): 82 | instance_ids,gt_mask_ids = set(instance_ids), set(gt_mask_ids) 83 | inter = len(instance_ids.intersection(gt_mask_ids)) 84 | return float(inter) / float(len(instance_ids) + len(gt_mask_ids) - inter) 85 | 86 | def overlap_rot_bbox(instances_points, gts_points): 87 | def _to_rot_bbox_array(pts): 88 | rb = [cv2.minAreaRect(p) for p in pts] 89 | return np.array([[cx,cy,h,w,-a] for (cx,cy),(w,h),a in rb]) 90 | 91 | return rbbx_overlaps(_to_rot_bbox_array(instances_points), 92 | _to_rot_bbox_array(gts_points)) 93 | 94 | def rbbx_overlaps(boxes, query_boxes): 95 | ''' 96 | Parameters 97 | ---------------- 98 | boxes: (N, 5) --- x_ctr, y_ctr, height, width, angle 99 | query: (K, 5) --- x_ctr, y_ctr, height, width, angle 100 | ---------------- 101 | Returns 102 | ---------------- 103 | Overlaps (N, K) IoU 104 | ''' 105 | 106 | N = boxes.shape[0] 107 | K = query_boxes.shape[0] 108 | overlaps = np.zeros((N, K), dtype=np.float32) 109 | 110 | for k in range(K): 111 | query_area = query_boxes[k, 2] * query_boxes[k, 3] 112 | for n in range(N): 113 | box_area = boxes[n, 2] * boxes[n, 3] 114 | # IoU of rotated rectangle 115 | # loading data anti to clock-wise 116 | rn = ((boxes[n, 0], boxes[n, 1]), (boxes[n, 3], boxes[n, 2]), -boxes[n, 4]) 117 | rk = ( 118 | (query_boxes[k, 0], query_boxes[k, 1]), (query_boxes[k, 3], query_boxes[k, 2]), -query_boxes[k, 4]) 119 | int_pts = cv2.rotatedRectangleIntersection(rk, rn)[1] 120 | # print type(int_pts) 121 | if None is not int_pts: 122 | order_pts = cv2.convexHull(int_pts, returnPoints=True) 123 | int_area = cv2.contourArea(order_pts) 124 | overlaps[n, k] = int_area * 1.0 / (query_area + box_area - int_area) 125 | return overlaps 126 | -------------------------------------------------------------------------------- /src/utils/evaluation/center.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import numpy as np 5 | import pylab as plt 6 | import torch 7 | 8 | from utils.evaluation import get_AP_and_F1, NumpyEncoder 9 | 10 | class CenterEvaluation: 11 | def __init__(self, exp_name='', exp_attributes=None): 12 | self.Y = [] 13 | self.P = [] 14 | self.is_difficult = [] 15 | self.exp_name = exp_name 16 | self.exp_attributes = exp_attributes 17 | 18 | def save_str(self): 19 | return "" 20 | def save_attributes(self): 21 | return dict(tau=self.tau_thr, d_alpha=self.merge_threshold_px, 
score_thr=self.score_thr) 22 | 23 | def add_image_prediction(self, im_name, im_index, im_shape, predictions, predictions_score, 24 | gt_instances_ids, gt_centers_dict, gt_difficult, centerdir_gt, **kwargs): 25 | gt_centers = [gt_centers_dict[k] for k in sorted(gt_centers_dict.keys())] 26 | gt_centers = np.array(gt_centers) 27 | 28 | gt_missed = False 29 | pred_missed = False 30 | 31 | pred_gt_match = [] 32 | 33 | if len(predictions) > 0: 34 | 35 | pred_gt_center = [] 36 | # TODO: this should be tested for any errors !!! 37 | # convert predictions from 2d to 1d index and use it to find corresponding GT ids 38 | for pred_id in np.ravel_multi_index(np.array([[int(np.round(pred[1])), int(np.round(pred[0]))] for pred in predictions]).T, dims=im_shape): 39 | gt_cent = np.array([0,0]) # by default 40 | for index,ids in gt_instances_ids.items(): 41 | if np.any(ids == pred_id): 42 | gt_cent = np.array(np.unravel_index(np.asarray(ids, dtype=np.int64), im_shape)).mean(axis=1) # centroid (y, x) of the matched GT instance 43 | break 44 | pred_gt_center.append(gt_cent) 45 | # Original implementation using gt_instances map 46 | #pred_gt_center = [(gt_instances == gt_instances[int(np.round(pred[1])), 47 | # int(np.round(pred[0]))]).nonzero().type(torch.float32).mean(0).cpu().numpy() 48 | # for pred in predictions] 49 | 50 | pred_gt_center = np.array(pred_gt_center) 51 | 52 | pred_list = np.concatenate((pred_gt_center[:, 1:2], 53 | pred_gt_center[:, 0:1], 54 | predictions_score.reshape(-1,1)), axis=1) 55 | 56 | # match predictions with groundtruth 57 | pred_gt_match = np.zeros(shape=(pred_list.shape[0], 1)) 58 | remaining_pred = np.ones(len(predictions), dtype=np.float32) 59 | 60 | for gt_center in gt_centers: 61 | center_distance = [np.sqrt(np.sum(np.power(np.array([pred[1], pred[0]]) - gt_center, 2))) for pred in 62 | pred_list] 63 | # ignore detection that have already been matched to GT 64 | center_distance = center_distance * remaining_pred 65 | if (center_distance <= 5).any(): 66 | best_match_id = np.argmin(center_distance) 67 | 68 | # mark only the detection with best match with the GT as TRUE POSITIVE 69 | pred_gt_match[best_match_id] = 1 70 | remaining_pred[best_match_id] = np.inf 71 | else: 72 | self.Y.append(np.array([1.0])) 73 | self.P.append(-np.inf) 74 | self.is_difficult.append(gt_difficult[int(gt_center[0]), int(gt_center[1])].item()) 75 | gt_missed = True 76 | print('best distance: %f' % np.min(center_distance)) 77 | 78 | self.Y.extend(pred_gt_match[:]) 79 | self.P.extend(pred_list[:, -1]) 80 | self.is_difficult.extend([gt_difficult[int(gt[1]), int(gt[0])].item() 81 | for p, gt in zip(predictions, pred_list)]) 82 | 83 | if not pred_gt_match.all(): 84 | pred_missed = True 85 | 86 | elif len(gt_centers) > 0: 87 | gt_missed = True 88 | for gt_center in gt_centers: 89 | self.Y.append(np.array([1.0])) 90 | self.P.append(-np.inf) 91 | self.is_difficult.append(gt_difficult[int(gt_center[0]), int(gt_center[1])].item()) 92 | 93 | return gt_missed, pred_missed, pred_gt_match, None 94 | 95 | def calc_and_display_final_metrics(self, dataset, print_result=True, plot_result=True, save_dir=None, **kwargs): 96 | AP, F1, precision, recall, thrs = get_AP_and_F1(self.Y, self.P, self.is_difficult) 97 | 98 | if print_result: 99 | print('AP=%f' % AP) 100 | print('best F-measure=%f at thr=%f' % (np.max(F1),thrs[np.argmax(F1)-1])) 101 | 102 | fig = None 103 | if plot_result: 104 | fig = plt.figure() 105 | plt.plot(recall,precision) 106 | plt.title('Precision-recall (AP=%f, F1=%f)' % (AP, np.max(F1))) 107 | plt.xlabel('Recall') 108 | plt.ylabel('Precision') 109 | plt.xlim(0,1) 110 | plt.ylim(0, 1) 111 | 112 |
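        # Bundle the summary scores and the full precision-recall curve; this dict is what gets
        # serialized to results.json below.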
metrics = dict(AP=AP, F1=F1, precision=precision, recall=recall, thrs=thrs) 113 | 114 | ######################################################################################################## 115 | # SAVE EVAL RESULTS TO JSON FILE 116 | if save_dir is not None: 117 | out_dir = os.path.join(save_dir, self.exp_name, self.save_str()) 118 | os.makedirs(out_dir, exist_ok=True) 119 | 120 | if fig is not None: 121 | fig.savefig(os.path.join(out_dir, 'AP.png')) 122 | 123 | if metrics is not None: 124 | with open(os.path.join(out_dir, 'results.json'), 'w') as file: 125 | file.write(json.dumps(metrics, cls=NumpyEncoder)) 126 | 127 | return metrics 128 | 129 | def get_results_timestamp(self, save_dir): 130 | res_filename = os.path.join(save_dir, self.exp_name, self.save_str(), 'results.json') 131 | 132 | return os.path.getmtime(res_filename) if os.path.exists(res_filename) else 0 -------------------------------------------------------------------------------- /src/models/center_estimator_with_orientation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from models.center_estimator import CenterEstimator 4 | from models.center_estimator_fast import CenterEstimatorFast 5 | 6 | class CenterOrientationEstimator(CenterEstimator): 7 | def __init__(self, args=dict(), is_learnable=True): 8 | super(CenterOrientationEstimator, self).__init__(args, is_learnable=is_learnable) 9 | 10 | self.enable_6dof = args.get('enable_6dof') 11 | self.use_orientation_confidence_score = args.get('use_orientation_confidence_score') 12 | 13 | def init_output(self, num_vector_fields=1): 14 | super(CenterOrientationEstimator, self).init_output(num_vector_fields) 15 | 16 | REQUIRED_VECTOR_FIELDS = 5 17 | if self.enable_6dof: 18 | REQUIRED_VECTOR_FIELDS += 4 19 | if self.use_orientation_confidence_score: 20 | REQUIRED_VECTOR_FIELDS += 1 21 | 22 | assert self.num_vector_fields >= REQUIRED_VECTOR_FIELDS 23 | 24 | def forward(self, input, ignore_gt=False, **gt): 25 | ret = super(CenterOrientationEstimator, self).forward(input, ignore_gt, **gt) 26 | 27 | # use input from parent forward (in case this is modified) 28 | input = ret['output'] 29 | center_pred = ret['center_pred'] 30 | 31 | predictions = input[:, 0:self.num_vector_fields] 32 | 33 | batch_size = center_pred.shape[0] 34 | num_pred = center_pred.shape[1] 35 | 36 | # WARNING: this assumes CenterEstimator is used as parent WITHOUT fourier (!!) 
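        # Channel layout assumed here (see the WARNING above): the first channels carry the
        # center-direction fields consumed by the parent CenterEstimator; with enable_6dof the
        # next 3+3 channels are per-axis sin/cos pairs, otherwise channel 3 is sin and channel 4
        # is cos of the single in-plane angle, with an optional trailing orientation-confidence
        # channel. The sin/cos pair is decoded to degrees in [0, 360) further below.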
37 | if self.enable_6dof: 38 | sin_orientation = predictions[:, 3:6] 39 | cos_orientation = predictions[:, 6:9] 40 | else: 41 | sin_orientation = predictions[:, 3:4] 42 | cos_orientation = predictions[:, 4:5] 43 | 44 | prediction_angles = torch.zeros((batch_size,num_pred,sin_orientation.shape[1])) 45 | 46 | if self.use_orientation_confidence_score: 47 | prediction_confidence_score = predictions[:, 9:10] if self.enable_6dof else predictions[:, 5:6] 48 | orientation_confidence_score = torch.zeros((batch_size,num_pred,1)).to(center_pred.device) 49 | 50 | for b in range(batch_size): 51 | # for every predicted center point find its center 52 | for i, pred in enumerate(center_pred[b]): 53 | if pred[0] != 0: 54 | x,y = pred[1:3] 55 | s = sin_orientation[b,:,int(y), int(x)] 56 | c = cos_orientation[b,:, int(y), int(x)] 57 | pred_angle = torch.atan2(c, s) 58 | 59 | pred_angle = torch.rad2deg(pred_angle) 60 | pred_angle += 360 * (pred_angle < 0).int() 61 | 62 | prediction_angles[b,i,:] = pred_angle 63 | 64 | if self.use_orientation_confidence_score: 65 | orientation_confidence_score[b,i] = prediction_confidence_score[b,0,int(y), int(x)] 66 | 67 | # add orientations to list of returned values 68 | ret['pred_angle'] = prediction_angles 69 | 70 | # add orientation confidence score to predictions if needed 71 | if self.use_orientation_confidence_score: 72 | ret['center_pred'] = torch.cat((center_pred,orientation_confidence_score),axis=2) 73 | 74 | return ret 75 | 76 | 77 | import time 78 | 79 | class CenterOrientationEstimatorFast(CenterEstimatorFast): 80 | def __init__(self, args=dict(), is_learnable=True): 81 | super(CenterOrientationEstimatorFast, self).__init__(args, is_learnable=is_learnable) 82 | 83 | self.enable_6dof = args.get('enable_6dof') 84 | self.use_orientation_confidence_score = args.get('use_orientation_confidence_score') 85 | 86 | def init_output(self, num_vector_fields=1): 87 | super(CenterOrientationEstimatorFast, self).init_output(num_vector_fields) 88 | 89 | REQUIRED_VECTOR_FIELDS = 5 90 | if self.enable_6dof: 91 | REQUIRED_VECTOR_FIELDS += 4 92 | if self.use_orientation_confidence_score: 93 | REQUIRED_VECTOR_FIELDS += 1 94 | 95 | assert self.num_vector_fields >= REQUIRED_VECTOR_FIELDS 96 | 97 | 98 | def forward(self, input, ignore_gt=False, **gt): 99 | center_pred, times = super(CenterOrientationEstimatorFast, self).forward(input) 100 | 101 | start_orient = time.time() 102 | predictions = input[:, 0:self.num_vector_fields] 103 | 104 | num_pred = center_pred.shape[0] 105 | 106 | if self.enable_6dof: 107 | sin_orientation = predictions[:, 3:6] 108 | cos_orientation = predictions[:, 6:9] 109 | else: 110 | sin_orientation = predictions[:, 3:4] 111 | cos_orientation = predictions[:, 4:5] 112 | 113 | if self.use_orientation_confidence_score: 114 | prediction_confidence_score = predictions[:, 9:10] if self.enable_6dof else predictions[:, 5:6] 115 | orientation_confidence_score = torch.zeros((num_pred,1)).to(center_pred.device) 116 | 117 | prediction_angles = torch.zeros((num_pred,sin_orientation.shape[1]), device=predictions.device) 118 | 119 | # for every predicted center point find its center 120 | for i, pred in enumerate(center_pred): 121 | batch = int(pred[0]) 122 | x,y = pred[1:3] 123 | 124 | score = pred[-1] 125 | 126 | if self.use_orientation_confidence_score: 127 | confidence_score = prediction_confidence_score[batch,0, int(y), int(x)] 128 | orientation_confidence_score[i,:] = confidence_score 129 | 130 | score *= confidence_score 131 | 132 | if score <= 0.9: 133 | continue 134 | 
135 | s = sin_orientation[batch,:, int(y), int(x)] 136 | c = cos_orientation[batch,:, int(y), int(x)] 137 | 138 | pred_angle = torch.atan2(c, s) 139 | 140 | pred_angle = torch.rad2deg(pred_angle) 141 | pred_angle += 360 * (pred_angle < 0).int() 142 | 143 | prediction_angles[i,:] = pred_angle 144 | 145 | if self.use_orientation_confidence_score: 146 | center_pred = torch.cat((center_pred, orientation_confidence_score.to(predictions.device)),axis=1) 147 | 148 | # add orientations to list of returned values 149 | center_pred_with_rot = torch.cat((center_pred, prediction_angles.to(predictions.device)),axis=1) 150 | 151 | end_orient = time.time() 152 | times = list(times) 153 | times[-1] = times[-1] + (end_orient-start_orient) 154 | 155 | return center_pred_with_rot, tuple(times) -------------------------------------------------------------------------------- /src/config/mujoco/train.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | 4 | import torch 5 | from utils import transforms as my_transforms 6 | from torchvision.transforms import InterpolationMode 7 | 8 | DATA_DIR=os.environ.get('MUJOCO_DIR') 9 | 10 | OUTPUT_DIR=os.environ.get('OUTPUT_DIR',default='../exp') 11 | 12 | NUM_FIELDS = 5 13 | SIZE = 512 14 | USE_DEPTH = os.environ.get('USE_DEPTH', default='False').lower() == 'true' 15 | 16 | IN_CHANNELS = 3 17 | 18 | if USE_DEPTH: 19 | IN_CHANNELS = 4 20 | 21 | args = dict( 22 | 23 | cuda=True, 24 | display=False, 25 | display_it=5, 26 | 27 | tf_logging=['loss'], 28 | tf_logging_iter=2, 29 | 30 | visualizer=dict(name='OrientationVisualizeTrain'), 31 | 32 | save=True, 33 | save_interval=10, 34 | 35 | # -------- 36 | n_epochs=10, 37 | ablation_str="", 38 | 39 | save_dir=os.path.join(OUTPUT_DIR, 'mujoco','{args[ablation_str]}', 40 | 'backbone={args[model][kwargs][backbone]}', 41 | 'num_train_epoch={args[n_epochs]}', 42 | 'depth={args[model][kwargs][use_depth]}', 43 | 'multitask_weight={args[multitask_weighting][name]}', 44 | ), 45 | 46 | pretrained_model_path = None, 47 | resume_path = None, 48 | 49 | pretrained_center_model_path = None, 50 | 51 | train_dataset = { 52 | 'name': 'mujoco', 53 | 'kwargs': { 54 | 'normalize': False, 55 | 'root_dir': DATA_DIR, 56 | 'subfolder': ['mujoco', 57 | 'mujoco_all_combinations_normal_color_temp', 58 | 'mujoco_all_combinations_rgb_light', 59 | 'mujoco_white_desk_HS_extreme_color_temp', 60 | 'mujoco_white_desk_HS_normal_color_temp'], 61 | 'use_depth': USE_DEPTH, 62 | 'fixed_bbox_size': 15, 63 | 'transform': my_transforms.get_transform([ 64 | # for training without augmentation (same as testing) 65 | { 66 | 'name': 'ToTensor', 67 | 'opts': { 68 | 'keys': ('image', 'instance', 'label', 'ignore', 'orientation', 'mask') + (('depth',) if USE_DEPTH else ()), 69 | 'type': ( 70 | torch.FloatTensor, torch.ShortTensor, torch.ByteTensor, torch.ByteTensor, torch.FloatTensor, torch.ByteTensor) + ((torch.FloatTensor, ) if USE_DEPTH else ()), 71 | } 72 | }, 73 | { 74 | 'name': 'Resize', 75 | 'opts': { 76 | 'keys': ('image', 'instance', 'label', 'ignore', 'orientation', 'mask') + (('depth',) if USE_DEPTH else ()), 77 | 'interpolation': (InterpolationMode.BILINEAR, InterpolationMode.NEAREST, InterpolationMode.NEAREST, InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.NEAREST) + ((InterpolationMode.BILINEAR, ) if USE_DEPTH else ()), 78 | 'keys_bbox': ('center',), 79 | 'size': (SIZE, SIZE), 80 | } 81 | }, 82 | # for training with random augmentation 83 | { 84 | 'name': 'RandomGaussianBlur', 
85 | 'opts': { 86 | 'keys': ('image',), 87 | 'rate': 0.5, 'sigma': [0.5, 2] 88 | } 89 | }, 90 | { 91 | 'name': 'ColorJitter', 92 | 'opts': { 93 | 'keys': ('image',), 'p': 0.5, 94 | 'saturation': 0.3, 'hue': 0.3, 'brightness': 0.3, 'contrast':0.3 95 | } 96 | } 97 | 98 | ]), 99 | 'MAX_NUM_CENTERS':16*128, 100 | }, 101 | 102 | 'centerdir_gt_opts': dict( 103 | ignore_instance_mask_and_use_closest_center=True, 104 | center_ignore_px=3, 105 | skip_gt_center_mask_generate=True, # gt_center_mask is not needed since we are not training localization network 106 | 107 | MAX_NUM_CENTERS=16*128, 108 | ), 109 | 110 | 'batch_size': 4, 111 | 112 | # hard example disabled 113 | 'hard_samples_size': 0, 114 | 'hard_samples_selected_min_percent':0.1, 115 | 116 | 'workers': 4, 117 | 'shuffle': True, 118 | }, 119 | 120 | model = dict( 121 | name='fpn', 122 | kwargs= { 123 | 'backbone': 'tu-convnext_base', 124 | 'use_depth': USE_DEPTH, 125 | 'num_classes': [NUM_FIELDS, 1], 126 | 'use_custom_fpn':True, 127 | 'add_output_exp': False, 128 | 'in_channels': IN_CHANNELS, 129 | 'fpn_args': { 130 | 'decoder_segmentation_head_channels':64, 131 | 'upsampling':4, # required for ConvNext architectures 132 | 'classes_grouping': [(0, 1, 2, 5), (3, 4)], 133 | 'depth_mean': 0, 134 | 'depth_std':1, 135 | }, 136 | 'init_decoder_gain': 0.1 137 | }, 138 | optimizer='Adam', 139 | lr=1e-4, 140 | weight_decay=0, 141 | 142 | ), 143 | center_model=dict( 144 | name='CenterEstimator', 145 | kwargs=dict( 146 | # use vector magnitude as mask instead of regressed mask 147 | use_magnitude_as_mask=False, 148 | # thresholds for conv2d processing 149 | local_max_thr=0.1, mask_thr=0.01, exclude_border_px=0, 150 | use_dilated_nn=True, 151 | dilated_nn_args=dict( 152 | # single scale version (nn6) 153 | inner_ch=16, 154 | inner_kernel=3, 155 | dilations=[1, 4, 8, 12], 156 | freeze_learning=True, 157 | gradpass_relu=False, 158 | # version with leaky relu 159 | leaky_relu=False, 160 | # input check 161 | # use_polar_radii=False, 162 | use_centerdir_radii = False, 163 | use_centerdir_magnitude = False, 164 | use_cls_mask = False 165 | ), 166 | allow_input_backprop=False, 167 | backprop_only_positive=False, 168 | augmentation=False, 169 | scale_r=1.0, # 1024 170 | scale_r_gt=1024, # 1 171 | use_log_r=True, 172 | use_log_r_base='10', 173 | enable_6dof=False, 174 | ), 175 | # DISABLE TRAINING 176 | optimizer='Adam', 177 | lr=0, 178 | weight_decay=0, 179 | ), 180 | 181 | 182 | # loss options 183 | loss_type='OrientationLoss', 184 | loss_opts={ 185 | 'num_vector_fields': NUM_FIELDS, 186 | 'foreground_weight': 1, 187 | 188 | 'enable_centerdir_loss': True, 189 | 'no_instance_loss': True, # MUST be True to ignore instance mask 190 | 'centerdir_instance_weighted': True, 191 | 'regression_loss': 'l1', 192 | 193 | 'use_log_r': True, 194 | 'use_log_r_base': '10', 195 | 196 | 'orientation_args': dict( 197 | enable=True, 198 | no_instance_loss=False, 199 | regression_loss='l1', 200 | enable_6dof=False, 201 | symmetries=None, 202 | ) 203 | }, 204 | multitask_weighting=dict( 205 | name='uw', 206 | kwargs=dict( 207 | n_tasks=2 208 | ) 209 | ), 210 | loss_w={ 211 | 'w_r': 1, 212 | 'w_cos': 1, 213 | 'w_sin': 1, 214 | 'w_cent': 0.1, 215 | 'w_orientation': 1, 216 | }, 217 | 218 | ) 219 | 220 | args['lambda_scheduler_fn']=lambda _args: (lambda epoch: pow((1-((epoch)/_args['n_epochs'])), 0.9)) 221 | #args['lambda_scheduler_fn']=lambda _args: (lambda epoch: 1.0) # disabled 222 | 223 | args['model']['lambda_scheduler_fn'] = args['lambda_scheduler_fn'] 224 | 
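# The center model keeps a zero LR multiplier for its first two epochs and then follows the same
# polynomial decay (1 - epoch/n_epochs)^0.9; since its base lr is set to 0 above ("DISABLE TRAINING"),
# this scheduler only takes effect if center-model training is re-enabled.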
args['center_model']['lambda_scheduler_fn'] = lambda _args: (lambda epoch: pow((1-((epoch)/_args['n_epochs'])), 0.9) if epoch > 1 else 0) 225 | 226 | 227 | def get_args(): 228 | return copy.deepcopy(args) 229 | -------------------------------------------------------------------------------- /src/config/vicos_towel/novel_object=bg+cloth/test.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | 4 | import torchvision 5 | if 'InterpolationMode' in dir(torchvision.transforms): 6 | from torchvision.transforms import InterpolationMode 7 | else: 8 | from PIL import Image as InterpolationMode 9 | 10 | import torch 11 | from utils import transforms as my_transforms 12 | 13 | VICOS_TOWEL_DATASET_DIR = os.environ.get('VICOS_TOWEL_DATASET_DIR') 14 | 15 | OUTPUT_DIR=os.environ.get('OUTPUT_DIR',default='../exp') 16 | 17 | NUM_FIELDS = 5 18 | TRAIN_SIZE = os.environ.get('TRAIN_SIZE', default=256*3) 19 | TEST_SIZE_WIDTH = int(os.environ.get('TEST_SIZE', default=256*3)) 20 | TEST_SIZE_HEIGHT = int(os.environ.get('TEST_SIZE', default=256*3)) 21 | USE_DEPTH = os.environ.get('USE_DEPTH', default='False').lower() == 'true' 22 | 23 | IN_CHANNELS = 3 24 | 25 | if USE_DEPTH: 26 | IN_CHANNELS = 4 27 | 28 | def img2tags_fn(x): 29 | # from x=/storage/datasets/ClothDataset/bg=/cloth=/rgb/image_0000_viewXX_lsYY_camera0.jpg 30 | # extract bg, cloth, view (configuration and clutter) and ls (lightning 31 | 32 | # extract bg 33 | bg = x.split('/')[-4] 34 | # extract cloth 35 | cloth = x.split('/')[-3] 36 | # extract ls 37 | lightning = int(x.split('/')[-1].split('_')[3].replace('ls', '')) 38 | # extract view 39 | view = int(x.split('/')[-1].split('_')[2].replace('view', '')) 40 | # configuration and clutter are encoded in view, clutter is on if view is odd 41 | clutter = 'on' if view % 2 == 1 else 'off' 42 | configuration = view // 2 43 | 44 | return [bg, cloth, f'lightning={lightning}', f'configuration={configuration}', f'clutter={clutter}'] 45 | 46 | 47 | 48 | model_dir = os.path.join(OUTPUT_DIR, 'vicos_towel', '{args[ablation_str]}', 49 | 'backbone={args[model][kwargs][backbone]}' + f'_size={TRAIN_SIZE}x{TRAIN_SIZE}', 50 | 'num_train_epoch={args[train_settings][n_epochs]}', 51 | 'depth={args[model][kwargs][use_depth]}', 52 | 'multitask_weight={args[train_settings][multitask_weighting][name]}') 53 | 54 | args = dict( 55 | 56 | cuda=True, 57 | display=True, 58 | autoadjust_figure_size=True, 59 | 60 | groundtruth_loading = True, 61 | 62 | save=True, 63 | 64 | save_dir=os.path.join(model_dir,'{args[dataset][kwargs][type]}_results{args[eval_epoch]}',f'test_size={TEST_SIZE_HEIGHT}x{TEST_SIZE_WIDTH}',), 65 | checkpoint_path=os.path.join(model_dir,'checkpoint{args[eval_epoch]}.pth'), 66 | 67 | eval_epoch='', 68 | ablation_str='', 69 | 70 | eval=dict( 71 | # available score types ['mask', 'center', 'hough_energy', 'edge_to_area_ratio_of_mask', 'avg(mask_pix)', 'avg(hough_pix)', 'avg(projected_dist_pix)'] 72 | score_combination_and_thr=[ 73 | { 74 | 'center': [0.1,0.01,0.05,0.15,0.2,0.25,0.3,0.35,0.40,0.45,0.5,0.55,0.60,0.65,0.7,0.75,0.8,0.85,0.9,0.94,0.99], 75 | }, 76 | ], 77 | score_thr_final=[0.01], 78 | skip_center_eval=True, 79 | orientation=dict( 80 | display_best_threshold=False, 81 | tau_thr=[20], 82 | ), 83 | enable_multivariate_eval=dict( 84 | image2tags_fn=img2tags_fn, 85 | ) 86 | ), 87 | visualizer=dict(name='OrientationVisualizeTest', 88 | opts=dict(show_rot_axis=(True,), 89 | impath2name_fn=lambda x: ".".join(x.split('/')[-4:]).replace('.jpg',''))), 
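    # Split for the novel_object=bg+cloth protocol: the festive_tablecloth background and the two
    # cloth types listed below are the ones held out from training, so each test image pairs a
    # novel background with a novel cloth.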
90 | 91 | 92 | dataset={ 93 | 'name': 'vicos_towel', 94 | 'kwargs': { 95 | 'normalize': False, 96 | 'root_dir': os.path.abspath(VICOS_TOWEL_DATASET_DIR), 97 | 98 | 'type': 'test_novel_object=bg+cloth', 99 | 'subfolders': [dict(folder='bg=festive_tablecloth', data_subfolders=['cloth=checkered_rag_small', 'cloth=cotton_napkin']), 100 | ], 101 | 102 | 'fixed_bbox_size': 5, 103 | 'resize_factor': 1, 104 | 'use_depth': USE_DEPTH, 105 | 'use_mean_for_depth_nan': True, 106 | 'transform': my_transforms.get_transform([ 107 | { 108 | 'name': 'ToTensor', 109 | 'opts': { 110 | 'keys': ('image', 'instance', 'label', 'ignore', 'orientation', 'mask') + (('depth',) if USE_DEPTH else ()), 111 | 'type': ( 112 | torch.FloatTensor, torch.ShortTensor, torch.ByteTensor, torch.ByteTensor, torch.FloatTensor, 113 | torch.ByteTensor) + ((torch.FloatTensor, ) if USE_DEPTH else ()), 114 | } 115 | }, 116 | { 117 | 'name': 'Resize', 118 | 'opts': { 119 | 'keys': ('image', 'instance', 'label', 'ignore', 'orientation', 'mask') + (('depth',) if USE_DEPTH else ()), 120 | 'interpolation': (InterpolationMode.BILINEAR, InterpolationMode.NEAREST, InterpolationMode.NEAREST, InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.NEAREST) + ((InterpolationMode.BILINEAR, ) if USE_DEPTH else ()), 121 | 'keys_bbox': ('center',), 122 | 'size': (TEST_SIZE_HEIGHT, TEST_SIZE_WIDTH), 123 | } 124 | }, 125 | 126 | ]), 127 | 'MAX_NUM_CENTERS':16*128, 128 | }, 129 | 'centerdir_gt_opts': dict( 130 | ignore_instance_mask_and_use_closest_center=True, 131 | center_ignore_px=3, 132 | 133 | MAX_NUM_CENTERS=16*128, 134 | ), 135 | 136 | 'batch_size': 1, 137 | 'workers': 0, 138 | }, 139 | 140 | model=dict( 141 | name='fpn', 142 | kwargs={ 143 | 'backbone': 'tu-convnext_base', 144 | 'use_depth': USE_DEPTH, 145 | 'num_classes': [NUM_FIELDS, 1], 146 | 'use_custom_fpn': True, 147 | 'add_output_exp': False, 148 | 'in_channels': IN_CHANNELS, 149 | 'fpn_args': { 150 | 'decoder_segmentation_head_channels': 64, 151 | 'upsampling':4, # required for ConvNext architectures 152 | 'classes_grouping': [(0, 1, 2, 5), (3, 4)], 153 | 'depth_mean': 0.96, 'depth_std': 0.075, 154 | }, 155 | 'init_decoder_gain': 0.1 156 | }, 157 | ), 158 | center_model=dict( 159 | name='CenterOrientationEstimator', 160 | use_learnable_center_estimation=True, 161 | 162 | kwargs=dict( 163 | use_centerdir_radii = False, 164 | 165 | # use vector magnitude as mask instead of regressed mask 166 | use_magnitude_as_mask=True, 167 | # thresholds for conv2d processing 168 | local_max_thr=0.01, local_max_thr_use_abs=True, 169 | 170 | ### dilated neural net as head for center detection 171 | use_dilated_nn=True, 172 | dilated_nn_args=dict( 173 | return_sigmoid=False, 174 | # single scale version (nn6) 175 | inner_ch=16, 176 | inner_kernel=3, 177 | dilations=[1, 4, 8, 12], 178 | use_centerdir_radii=False, 179 | use_centerdir_magnitude=False, 180 | use_cls_mask=False 181 | ), 182 | augmentation=False, 183 | scale_r=1.0, # 1024 184 | scale_r_gt=1, # 1 185 | use_log_r=False, 186 | use_log_r_base='10', 187 | enable_6dof=False, 188 | ), 189 | 190 | ), 191 | num_vector_fields=NUM_FIELDS, 192 | 193 | # settings from train config needed for automated path construction 194 | train_settings=dict( 195 | n_epochs=10, 196 | multitask_weighting=dict(name='uw'), 197 | ) 198 | ) 199 | 200 | def get_args(): 201 | return copy.deepcopy(args) 202 | -------------------------------------------------------------------------------- /src/datasets/MuJoCoDataset.py: 
-------------------------------------------------------------------------------- 1 | import glob 2 | import os, cv2, sys 3 | 4 | import numpy as np 5 | from PIL import Image 6 | from matplotlib import pyplot as plt 7 | 8 | import json 9 | import torch 10 | from torch.utils.data import Dataset 11 | 12 | sys.path.insert(1, os.path.join(sys.path[0], '..')) 13 | 14 | from utils.utils_depth import get_normals 15 | 16 | 17 | def angle_to_rad(A): 18 | A = (A + 360) if A < 0 else A 19 | return np.deg2rad(A - 180) 20 | 21 | class MuJoCoDataset(Dataset): 22 | 23 | def __init__(self, root_dir='./', subfolder="", MAX_NUM_CENTERS=1024, transform=None, use_depth=False, segment_cloth=False, use_normals=False, fixed_bbox_size=15, num_cpu_threads=1, normals_mode=1, reference_normal = [0,0,1], **kwargs): 24 | print('MuJoCoDataset created') 25 | 26 | if num_cpu_threads: 27 | torch.set_num_threads(num_cpu_threads) 28 | 29 | self.root_dir = root_dir 30 | self.MAX_NUM_CENTERS = MAX_NUM_CENTERS 31 | self.transform = transform 32 | self.use_depth = use_depth 33 | self.use_normals = use_normals 34 | self.fixed_bbox_size = fixed_bbox_size 35 | self.normals_mode = normals_mode 36 | self.reference_normal = reference_normal 37 | self.segment_cloth = segment_cloth 38 | 39 | if type(subfolder) not in [list, tuple]: 40 | subfolder = [subfolder] 41 | 42 | image_list = [] 43 | for sub in subfolder: 44 | image_list += sorted(glob.glob(f"{self.root_dir}/{sub}/rgb/*")) 45 | 46 | #image_list = sorted(glob.glob(f"{self.root_dir}/rgb/*")) 47 | 48 | self.image_list = image_list 49 | print(f'MuJoCoDataset of size {len(image_list)}') 50 | 51 | self.size = len(self.image_list) 52 | 53 | def __len__(self): 54 | return self.size 55 | 56 | def __getitem__(self, index): 57 | im_fn = self.image_list[index] 58 | 59 | fn = os.path.splitext(os.path.split(im_fn)[-1])[0] 60 | 61 | image = Image.open(im_fn) 62 | im_size = image.size 63 | 64 | root_dir = os.path.abspath(os.path.join(os.path.dirname(im_fn),'..')) 65 | depth_fn = os.path.join(root_dir, 'depth', f'{fn}.npy') 66 | 67 | gt_fn = os.path.join(root_dir, 'gt_points_vectors', f'{fn}.npy') 68 | if os.path.exists(gt_fn): 69 | gt_data = np.load(gt_fn) 70 | else: 71 | print(gt_fn, "not found") 72 | gt_data = [] 73 | 74 | sample = dict( 75 | image=image, 76 | im_name=im_fn, 77 | im_size=im_size, 78 | index=index, 79 | ) 80 | 81 | if self.segment_cloth: 82 | gt_seg_fn = os.path.join(root_dir, "gt_cloth", f"{fn}.png") 83 | 84 | segmentation_mask = Image.open(gt_seg_fn) 85 | 86 | if self.resize_factor is not None and self.resize_factor != 1.0: 87 | segmentation_mask = segmentation_mask.resize(im_size, Image.BILINEAR) 88 | 89 | sample["segmentation_mask"] = segmentation_mask 90 | 91 | 92 | if self.use_depth: 93 | depth = np.load(depth_fn) 94 | 95 | if self.use_normals: 96 | depth = get_normals(depth, normals_mode=self.normals_mode, household=False) 97 | else: 98 | depth/=np.max(depth) 99 | 100 | sample['depth']=depth 101 | 102 | # create instances image 103 | instances = torch.zeros((1, im_size[1], im_size[0]), dtype=torch.int16) 104 | orientation = torch.zeros((1, im_size[1], im_size[0]), dtype=torch.float32) 105 | label = torch.zeros((1, im_size[1], im_size[0]), dtype=torch.uint8) 106 | 107 | centers = [] 108 | 109 | m = self.fixed_bbox_size 110 | for n, (i,j,s,c) in enumerate(gt_data): 111 | i = int(i) 112 | j = int(j) 113 | centers.append((i,j)) 114 | 115 | angle = np.arctan2(s,c) 116 | angle = np.degrees(angle) 117 | 118 | instances[0, j-m:j+m,i-m:i+m] = n+1 119 | label[0, j-m:j+m,i-m:i+m] = 
1 120 | orientation[0, j-m:j+m,i-m:i+m]=angle_to_rad(angle) 121 | 122 | 123 | sample['orientation'] = orientation 124 | sample['instance'] = instances 125 | sample['label'] = label 126 | sample['mask'] = (label > 0) 127 | sample['ignore'] = torch.zeros((1, im_size[1], im_size[0]), dtype=torch.uint8) 128 | 129 | centers = np.array(centers) 130 | sample['center'] = np.zeros((self.MAX_NUM_CENTERS, 2)) 131 | try: 132 | sample['center'][:centers.shape[0], :] = centers 133 | except ValueError: # raised when centers is empty, i.e., there are no objects in this image 134 | print("no objects in image") 135 | 136 | if self.transform is not None: 137 | rng = np.random.default_rng(seed=1234) 138 | sample = self.transform(sample, rng) 139 | 140 | if self.use_depth: 141 | sample['image'] = torch.cat((sample['image'], sample['depth'])) 142 | 143 | 144 | return sample 145 | 146 | 147 | 148 | if __name__ == "__main__": 149 | import pylab as plt 150 | import matplotlib 151 | 152 | matplotlib.use('TkAgg') 153 | from tqdm import tqdm 154 | import torch 155 | 156 | USE_DEPTH = True 157 | from utils import transforms as my_transforms 158 | 159 | transform = my_transforms.get_transform([ 160 | { 161 | 'name': 'ToTensor', 162 | 'opts': { 163 | 'keys': ('image', 'instance', 'label', 'ignore', 'orientation', 'mask') + (('depth',) if USE_DEPTH else ()), 164 | 'type': (torch.FloatTensor, torch.ShortTensor, torch.ByteTensor, torch.ByteTensor, torch.FloatTensor, 165 | torch.ByteTensor)+ ((torch.FloatTensor, ) if USE_DEPTH else ()), 166 | }, 167 | } 168 | ]) 169 | subfolders = ['mujoco', 'mujoco_all_combinations_normal_color_temp', 'mujoco_all_combinations_rgb_light', 'mujoco_white_desk_HS_extreme_color_temp', 'mujoco_white_desk_HS_normal_color_temp'] 170 | 171 | db = MuJoCoDataset(root_dir='/storage/datasets/ClothDataset/', resize_factor=1, transform_only_valid_centers=1.0, transform=transform, use_depth=USE_DEPTH, correct_depth_rotation=False, subfolder=subfolders) 172 | shapes = [] 173 | for item in tqdm(db): 174 | if item['index'] % 50 == 0: 175 | print('loaded index %d' % item['index']) 176 | shapes.append(item['image'].shape) 177 | # if True or np.array(item['ignore']).sum() > 0: 178 | # if True: 179 | if item['index'] % 1 == 0: 180 | center = item['center'] 181 | gt_centers = center[(center[:, 0] > 0) | (center[:, 1] > 0), :] 182 | # print(gt_centers) 183 | plt.clf() 184 | 185 | im = item['image'].permute([1, 2, 0]).numpy() 186 | # print(im.shape) 187 | 188 | plt.subplot(2, 2, 1) 189 | plt.imshow(im[...,:3]) 190 | plt.plot(gt_centers[:, 0], gt_centers[:, 1], 'r.') 191 | 192 | x = gt_centers[:,0] 193 | y = gt_centers[:,1] 194 | 195 | r = 100 196 | 197 | for i,j in zip(x,y): 198 | i = int(i) 199 | j = int(j) 200 | if i < 0 or i >= item['orientation'].shape[2] or \ 201 | j < 0 or j >= item['orientation'].shape[1]: 202 | continue 203 | angle = item['orientation'][0][j,i].numpy() 204 | # print(angle) 205 | # s = item['orientation'][1][j,i] 206 | s = -np.sin(angle) 207 | c = -np.cos(angle) 208 | # print(i,j,c,s) 209 | plt.plot([i,i+r*s],[j,j+r*c], 'r-') 210 | 211 | 212 | 213 | plt.draw(); plt.pause(0.01) 214 | plt.waitforbuttonpress() 215 | # plt.show() 216 | 217 | print("end") 218 | -------------------------------------------------------------------------------- /src/config/vicos_towel/novel_object=bg/test.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | 4 | import torchvision 5 | if 'InterpolationMode' in dir(torchvision.transforms): 6 | from torchvision.transforms import InterpolationMode 7 | else: 8 | from PIL import Image as 
InterpolationMode 9 | 10 | import torch 11 | from utils import transforms as my_transforms 12 | 13 | VICOS_TOWEL_DATASET_DIR = os.environ.get('VICOS_TOWEL_DATASET_DIR') 14 | 15 | OUTPUT_DIR=os.environ.get('OUTPUT_DIR',default='../exp') 16 | 17 | NUM_FIELDS = 5 18 | TRAIN_SIZE = os.environ.get('TRAIN_SIZE', default=256*3) 19 | TEST_SIZE_WIDTH = int(os.environ.get('TEST_SIZE', default=256*3)) 20 | TEST_SIZE_HEIGHT = int(os.environ.get('TEST_SIZE', default=256*3)) 21 | USE_DEPTH = os.environ.get('USE_DEPTH', default='False').lower() == 'true' 22 | 23 | IN_CHANNELS = 3 24 | 25 | if USE_DEPTH: 26 | IN_CHANNELS = 4 27 | 28 | def img2tags_fn(x): 29 | # from x=/storage/datasets/ClothDataset/bg=/cloth=/rgb/image_0000_viewXX_lsYY_camera0.jpg 30 | # extract bg, cloth, view (configuration and clutter) and ls (lightning 31 | 32 | # extract bg 33 | bg = x.split('/')[-4] 34 | # extract cloth 35 | cloth = x.split('/')[-3] 36 | # extract ls 37 | lightning = int(x.split('/')[-1].split('_')[3].replace('ls', '')) 38 | # extract view 39 | view = int(x.split('/')[-1].split('_')[2].replace('view', '')) 40 | # configuration and clutter are encoded in view, clutter is on if view is odd 41 | clutter = 'on' if view % 2 == 1 else 'off' 42 | configuration = view // 2 43 | 44 | return [bg, cloth, f'lightning={lightning}', f'configuration={configuration}', f'clutter={clutter}'] 45 | 46 | 47 | 48 | model_dir = os.path.join(OUTPUT_DIR, 'vicos_towel', '{args[ablation_str]}', 49 | 'backbone={args[model][kwargs][backbone]}' + f'_size={TRAIN_SIZE}x{TRAIN_SIZE}', 50 | 'num_train_epoch={args[train_settings][n_epochs]}', 51 | 'depth={args[model][kwargs][use_depth]}', 52 | 'multitask_weight={args[train_settings][multitask_weighting][name]}') 53 | 54 | args = dict( 55 | 56 | cuda=True, 57 | display=True, 58 | autoadjust_figure_size=True, 59 | 60 | groundtruth_loading = True, 61 | 62 | save=True, 63 | 64 | save_dir=os.path.join(model_dir,'{args[dataset][kwargs][type]}_results{args[eval_epoch]}',f'test_size={TEST_SIZE_HEIGHT}x{TEST_SIZE_WIDTH}',), 65 | checkpoint_path=os.path.join(model_dir,'checkpoint{args[eval_epoch]}.pth'), 66 | 67 | eval_epoch='', 68 | ablation_str='', 69 | 70 | eval=dict( 71 | # available score types ['mask', 'center', 'hough_energy', 'edge_to_area_ratio_of_mask', 'avg(mask_pix)', 'avg(hough_pix)', 'avg(projected_dist_pix)'] 72 | score_combination_and_thr=[ 73 | { 74 | 'center': [0.1,0.01,0.05,0.15,0.2,0.25,0.3,0.35,0.40,0.45,0.5,0.55,0.60,0.65,0.7,0.75,0.8,0.85,0.9,0.94,0.99], 75 | }, 76 | ], 77 | score_thr_final=[0.01], 78 | skip_center_eval=True, 79 | orientation=dict( 80 | display_best_threshold=False, 81 | tau_thr=[20], # [5,10,20] # try with 5px, 10px and 20px limit 82 | ), 83 | enable_multivariate_eval=dict( 84 | image2tags_fn=img2tags_fn, 85 | ) 86 | ), 87 | visualizer=dict(name='OrientationVisualizeTest', 88 | opts=dict(show_rot_axis=(True,), 89 | impath2name_fn=lambda x: ".".join(x.split('/')[-4:]).replace('.jpg',''))), 90 | 91 | 92 | dataset={ 93 | 'name': 'vicos_towel', 94 | 'kwargs': { 95 | 'normalize': False, 96 | 'root_dir': os.path.abspath(VICOS_TOWEL_DATASET_DIR), 97 | 98 | 'type': 'test_novel_object=bg', 99 | 'subfolders': [dict(folder='bg=festive_tablecloth', data_subfolders=['cloth=big_towel','cloth=checkered_rag_big','cloth=checkered_rag_medium', 'cloth=linen_rag','cloth=small_towel','cloth=towel_rag','cloth=waffle_rag','cloth=waffle_rag_stripes']), 100 | ], 101 | 102 | 'fixed_bbox_size': 5, 103 | 'resize_factor': 1, 104 | 'use_depth': USE_DEPTH, 105 | 'use_mean_for_depth_nan': True, 106 | 
'transform': my_transforms.get_transform([ 107 | { 108 | 'name': 'ToTensor', 109 | 'opts': { 110 | 'keys': ('image', 'instance', 'label', 'ignore', 'orientation', 'mask') + (('depth',) if USE_DEPTH else ()), 111 | 'type': ( 112 | torch.FloatTensor, torch.ShortTensor, torch.ByteTensor, torch.ByteTensor, torch.FloatTensor, 113 | torch.ByteTensor) + ((torch.FloatTensor, ) if USE_DEPTH else ()), 114 | } 115 | }, 116 | { 117 | 'name': 'Resize', 118 | 'opts': { 119 | 'keys': ('image', 'instance', 'label', 'ignore', 'orientation', 'mask') + (('depth',) if USE_DEPTH else ()), 120 | 'interpolation': (InterpolationMode.BILINEAR, InterpolationMode.NEAREST, InterpolationMode.NEAREST, InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.NEAREST) + ((InterpolationMode.BILINEAR, ) if USE_DEPTH else ()), 121 | 'keys_bbox': ('center',), 122 | 'size': (TEST_SIZE_HEIGHT, TEST_SIZE_WIDTH), 123 | } 124 | }, 125 | 126 | ]), 127 | 'MAX_NUM_CENTERS':16*128, 128 | }, 129 | 'centerdir_gt_opts': dict( 130 | ignore_instance_mask_and_use_closest_center=True, 131 | center_ignore_px=3, 132 | 133 | MAX_NUM_CENTERS=16*128, 134 | ), 135 | 136 | 'batch_size': 1, 137 | 'workers': 0, 138 | }, 139 | 140 | model=dict( 141 | name='fpn', 142 | kwargs={ 143 | 'backbone': 'tu-convnext_base', 144 | 'use_depth': USE_DEPTH, 145 | 'num_classes': [NUM_FIELDS, 1], 146 | 'use_custom_fpn': True, 147 | 'add_output_exp': False, 148 | 'in_channels': IN_CHANNELS, 149 | 'fpn_args': { 150 | 'decoder_segmentation_head_channels': 64, 151 | 'upsampling':4, # required for ConvNext architectures 152 | 'classes_grouping': [(0, 1, 2, 5), (3, 4)], 153 | 'depth_mean': 0.96, 'depth_std': 0.075, 154 | }, 155 | 'init_decoder_gain': 0.1 156 | }, 157 | ), 158 | center_model=dict( 159 | name='CenterOrientationEstimator', 160 | use_learnable_center_estimation=True, 161 | 162 | kwargs=dict( 163 | use_centerdir_radii = False, 164 | 165 | # use vector magnitude as mask instead of regressed mask 166 | use_magnitude_as_mask=True, 167 | # thresholds for conv2d processing 168 | local_max_thr=0.01, local_max_thr_use_abs=True, 169 | 170 | ### dilated neural net as head for center detection 171 | use_dilated_nn=True, 172 | dilated_nn_args=dict( 173 | return_sigmoid=False, 174 | # single scale version (nn6) 175 | inner_ch=16, 176 | inner_kernel=3, 177 | dilations=[1, 4, 8, 12], 178 | use_centerdir_radii=False, 179 | use_centerdir_magnitude=False, 180 | use_cls_mask=False 181 | ), 182 | augmentation=False, 183 | scale_r=1.0, # 1024 184 | scale_r_gt=1, # 1 185 | use_log_r=False, 186 | use_log_r_base='10', 187 | enable_6dof=False, 188 | ), 189 | 190 | ), 191 | num_vector_fields=NUM_FIELDS, 192 | 193 | # settings from train config needed for automated path construction 194 | train_settings=dict( 195 | n_epochs=10, 196 | multitask_weighting=dict(name='uw'), 197 | ) 198 | ) 199 | 200 | def get_args(): 201 | return copy.deepcopy(args) 202 | -------------------------------------------------------------------------------- /src/config/vicos_towel/novel_object=cloth/test.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | 4 | import torchvision 5 | if 'InterpolationMode' in dir(torchvision.transforms): 6 | from torchvision.transforms import InterpolationMode 7 | else: 8 | from PIL import Image as InterpolationMode 9 | 10 | import torch 11 | from utils import transforms as my_transforms 12 | 13 | VICOS_TOWEL_DATASET_DIR = os.environ.get('VICOS_TOWEL_DATASET_DIR') 14 | 15 | 
OUTPUT_DIR=os.environ.get('OUTPUT_DIR',default='../exp') 16 | 17 | NUM_FIELDS = 5 18 | TRAIN_SIZE = os.environ.get('TRAIN_SIZE', default=256*3) 19 | TEST_SIZE_WIDTH = int(os.environ.get('TEST_SIZE', default=256*3)) 20 | TEST_SIZE_HEIGHT = int(os.environ.get('TEST_SIZE', default=256*3)) 21 | USE_DEPTH = os.environ.get('USE_DEPTH', default='False').lower() == 'true' 22 | 23 | IN_CHANNELS = 3 24 | 25 | if USE_DEPTH: 26 | IN_CHANNELS = 4 27 | 28 | def img2tags_fn(x): 29 | # from x=/storage/datasets/ClothDataset/bg=/cloth=/rgb/image_0000_viewXX_lsYY_camera0.jpg 30 | # extract bg, cloth, view (configuration and clutter) and ls (lightning 31 | 32 | # extract bg 33 | bg = x.split('/')[-4] 34 | # extract cloth 35 | cloth = x.split('/')[-3] 36 | # extract ls 37 | lightning = int(x.split('/')[-1].split('_')[3].replace('ls', '')) 38 | # extract view 39 | view = int(x.split('/')[-1].split('_')[2].replace('view', '')) 40 | # configuration and clutter are encoded in view, clutter is on if view is odd 41 | clutter = 'on' if view % 2 == 1 else 'off' 42 | configuration = view // 2 43 | 44 | return [bg, cloth, f'lightning={lightning}', f'configuration={configuration}', f'clutter={clutter}'] 45 | 46 | 47 | 48 | model_dir = os.path.join(OUTPUT_DIR, 'vicos_towel', '{args[ablation_str]}', 49 | 'backbone={args[model][kwargs][backbone]}' + f'_size={TRAIN_SIZE}x{TRAIN_SIZE}', 50 | 'num_train_epoch={args[train_settings][n_epochs]}', 51 | 'depth={args[model][kwargs][use_depth]}', 52 | 'multitask_weight={args[train_settings][multitask_weighting][name]}') 53 | 54 | args = dict( 55 | 56 | cuda=True, 57 | display=True, 58 | autoadjust_figure_size=True, 59 | 60 | groundtruth_loading = True, 61 | 62 | save=True, 63 | 64 | save_dir=os.path.join(model_dir,'{args[dataset][kwargs][type]}_results{args[eval_epoch]}',f'test_size={TEST_SIZE_HEIGHT}x{TEST_SIZE_WIDTH}',), 65 | checkpoint_path=os.path.join(model_dir,'checkpoint{args[eval_epoch]}.pth'), 66 | 67 | eval_epoch='', 68 | ablation_str='', 69 | 70 | eval=dict( 71 | # available score types ['mask', 'center', 'hough_energy', 'edge_to_area_ratio_of_mask', 'avg(mask_pix)', 'avg(hough_pix)', 'avg(projected_dist_pix)'] 72 | score_combination_and_thr=[ 73 | { 74 | 'center': [0.1,0.01,0.05,0.15,0.2,0.25,0.3,0.35,0.40,0.45,0.5,0.55,0.60,0.65,0.7,0.75,0.8,0.85,0.9,0.94,0.99], 75 | }, 76 | ], 77 | score_thr_final=[0.01], 78 | skip_center_eval=True, 79 | orientation=dict( 80 | display_best_threshold=False, 81 | tau_thr=[20], 82 | ), 83 | enable_multivariate_eval=dict( 84 | image2tags_fn=img2tags_fn, 85 | ) 86 | ), 87 | visualizer=dict(name='OrientationVisualizeTest', 88 | opts=dict(show_rot_axis=(True,), 89 | impath2name_fn=lambda x: ".".join(x.split('/')[-4:]).replace('.jpg',''))), 90 | 91 | 92 | dataset={ 93 | 'name': 'vicos_towel', 94 | 'kwargs': { 95 | 'normalize': False, 96 | 'root_dir': os.path.abspath(VICOS_TOWEL_DATASET_DIR), 97 | 98 | 'type': 'test_novel_object=cloth', 99 | 'subfolders': [dict(folder='bg=red_tablecloth', data_subfolders=['cloth=checkered_rag_small', 'cloth=cotton_napkin']), 100 | dict(folder='bg=white_desk', data_subfolders=['cloth=checkered_rag_small', 'cloth=cotton_napkin']), 101 | dict(folder='bg=green_checkered', data_subfolders=['cloth=checkered_rag_small', 'cloth=cotton_napkin']), 102 | dict(folder='bg=poster', data_subfolders=['cloth=checkered_rag_small', 'cloth=cotton_napkin']), 103 | ], 104 | 105 | 'fixed_bbox_size': 5, 106 | 'resize_factor': 1, 107 | 'use_depth': USE_DEPTH, 108 | 'use_mean_for_depth_nan': True, 109 | 'transform': 
my_transforms.get_transform([ 110 | { 111 | 'name': 'ToTensor', 112 | 'opts': { 113 | 'keys': ('image', 'instance', 'label', 'ignore', 'orientation', 'mask') + (('depth',) if USE_DEPTH else ()), 114 | 'type': ( 115 | torch.FloatTensor, torch.ShortTensor, torch.ByteTensor, torch.ByteTensor, torch.FloatTensor, 116 | torch.ByteTensor) + ((torch.FloatTensor, ) if USE_DEPTH else ()), 117 | } 118 | }, 119 | { 120 | 'name': 'Resize', 121 | 'opts': { 122 | 'keys': ('image', 'instance', 'label', 'ignore', 'orientation', 'mask') + (('depth',) if USE_DEPTH else ()), 123 | 'interpolation': (InterpolationMode.BILINEAR, InterpolationMode.NEAREST, InterpolationMode.NEAREST, InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.NEAREST) + ((InterpolationMode.BILINEAR, ) if USE_DEPTH else ()), 124 | 'keys_bbox': ('center',), 125 | 'size': (TEST_SIZE_HEIGHT, TEST_SIZE_WIDTH), 126 | } 127 | }, 128 | 129 | ]), 130 | 'MAX_NUM_CENTERS':16*128, 131 | }, 132 | 'centerdir_gt_opts': dict( 133 | ignore_instance_mask_and_use_closest_center=True, 134 | center_ignore_px=3, 135 | 136 | MAX_NUM_CENTERS=16*128, 137 | ), 138 | 139 | 'batch_size': 1, 140 | 'workers': 0, 141 | }, 142 | 143 | model=dict( 144 | name='fpn', 145 | kwargs={ 146 | 'backbone': 'tu-convnext_base', 147 | 'use_depth': USE_DEPTH, 148 | 'num_classes': [NUM_FIELDS, 1], 149 | 'use_custom_fpn': True, 150 | 'add_output_exp': False, 151 | 'in_channels': IN_CHANNELS, 152 | 'fpn_args': { 153 | 'decoder_segmentation_head_channels': 64, 154 | 'upsampling':4, # required for ConvNext architectures 155 | 'classes_grouping': [(0, 1, 2, 5), (3, 4)], 156 | 'depth_mean': 0.96, 'depth_std': 0.075, 157 | }, 158 | 'init_decoder_gain': 0.1 159 | }, 160 | ), 161 | center_model=dict( 162 | name='CenterOrientationEstimator', 163 | use_learnable_center_estimation=True, 164 | 165 | kwargs=dict( 166 | use_centerdir_radii = False, 167 | 168 | # use vector magnitude as mask instead of regressed mask 169 | use_magnitude_as_mask=True, 170 | # thresholds for conv2d processing 171 | local_max_thr=0.01, local_max_thr_use_abs=True, 172 | 173 | ### dilated neural net as head for center detection 174 | use_dilated_nn=True, 175 | dilated_nn_args=dict( 176 | return_sigmoid=False, 177 | # single scale version (nn6) 178 | inner_ch=16, 179 | inner_kernel=3, 180 | dilations=[1, 4, 8, 12], 181 | use_centerdir_radii=False, 182 | use_centerdir_magnitude=False, 183 | use_cls_mask=False 184 | ), 185 | augmentation=False, 186 | scale_r=1.0, # 1024 187 | scale_r_gt=1, # 1 188 | use_log_r=False, 189 | use_log_r_base='10', 190 | enable_6dof=False, 191 | ), 192 | 193 | ), 194 | num_vector_fields=NUM_FIELDS, 195 | 196 | # settings from train config needed for automated path construction 197 | train_settings=dict( 198 | n_epochs=10, 199 | multitask_weighting=dict(name='uw'), 200 | ) 201 | ) 202 | 203 | def get_args(): 204 | return copy.deepcopy(args) 205 | -------------------------------------------------------------------------------- /src/config/vicos_towel/test.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | 4 | import torchvision 5 | if 'InterpolationMode' in dir(torchvision.transforms): 6 | from torchvision.transforms import InterpolationMode 7 | else: 8 | from PIL import Image as InterpolationMode 9 | 10 | import torch 11 | from utils import transforms as my_transforms 12 | 13 | VICOS_TOWEL_DATASET_DIR = os.environ.get('VICOS_TOWEL_DATASET_DIR') 14 | 15 | OUTPUT_DIR=os.environ.get('OUTPUT_DIR',default='../exp') 
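# NOTE: this test config (like the other vicos_towel configs) is driven entirely by environment variables
# (VICOS_TOWEL_DATASET_DIR, OUTPUT_DIR, TRAIN_SIZE, TEST_SIZE, USE_DEPTH) read above and just below.
# A purely illustrative shell setup, with placeholder paths and values that are not part of the repository:
#   export VICOS_TOWEL_DATASET_DIR=/path/to/ViCoS_Towel_Dataset
#   export OUTPUT_DIR=../exp TRAIN_SIZE=768 TEST_SIZE=768 USE_DEPTH=True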
16 | 17 | NUM_FIELDS = 5 18 | TRAIN_SIZE = os.environ.get('TRAIN_SIZE', default=256*3) 19 | TEST_SIZE_WIDTH = int(os.environ.get('TEST_SIZE', default=256*3)) 20 | TEST_SIZE_HEIGHT = int(os.environ.get('TEST_SIZE', default=256*3)) 21 | USE_DEPTH = os.environ.get('USE_DEPTH', default='False').lower() == 'true' 22 | 23 | IN_CHANNELS = 3 24 | 25 | if USE_DEPTH: 26 | IN_CHANNELS = 4 27 | 28 | def img2tags_fn(x): 29 | # from x=/storage/datasets/ClothDataset/bg=/cloth=/rgb/image_0000_viewXX_lsYY_camera0.jpg 30 | # extract bg, cloth, view (configuration and clutter) and ls (lightning 31 | 32 | # extract bg 33 | bg = x.split('/')[-4] 34 | # extract cloth 35 | cloth = x.split('/')[-3] 36 | # extract ls 37 | lightning = int(x.split('/')[-1].split('_')[3].replace('ls', '')) 38 | # extract view 39 | view = int(x.split('/')[-1].split('_')[2].replace('view', '')) 40 | # configuration and clutter are encoded in view, clutter is on if view is odd 41 | clutter = 'on' if view % 2 == 1 else 'off' 42 | configuration = view // 2 43 | 44 | return [bg, cloth, f'lightning={lightning}', f'configuration={configuration}', f'clutter={clutter}'] 45 | 46 | 47 | 48 | model_dir = os.path.join(OUTPUT_DIR, 'vicos_towel', '{args[ablation_str]}', 49 | 'backbone={args[model][kwargs][backbone]}' + f'_size={TRAIN_SIZE}x{TRAIN_SIZE}', 50 | 'num_train_epoch={args[train_settings][n_epochs]}', 51 | 'depth={args[model][kwargs][use_depth]}', 52 | 'multitask_weight={args[train_settings][multitask_weighting][name]}') 53 | 54 | args = dict( 55 | 56 | cuda=True, 57 | display=True, 58 | autoadjust_figure_size=True, 59 | 60 | groundtruth_loading = True, 61 | 62 | save=True, 63 | 64 | save_dir=os.path.join(model_dir,'{args[dataset][kwargs][type]}_results{args[eval_epoch]}',f'test_size={TEST_SIZE_HEIGHT}x{TEST_SIZE_WIDTH}',), 65 | checkpoint_path=os.path.join(model_dir,'checkpoint{args[eval_epoch]}.pth'), 66 | 67 | eval_epoch='', 68 | ablation_str='', 69 | 70 | eval=dict( 71 | # available score types ['mask', 'center', 'hough_energy', 'edge_to_area_ratio_of_mask', 'avg(mask_pix)', 'avg(hough_pix)', 'avg(projected_dist_pix)'] 72 | score_combination_and_thr=[ 73 | { 74 | 'center': [0.1,0.01,0.05,0.15,0.2,0.25,0.3,0.35,0.40,0.45,0.5,0.55,0.60,0.65,0.7,0.75,0.8,0.85,0.9,0.94,0.99], 75 | }, 76 | ], 77 | score_thr_final=[0.01], 78 | skip_center_eval=True, 79 | orientation=dict( 80 | display_best_threshold=False, 81 | tau_thr=[20], # [5,10,20] # try with 5px, 10px and 20px limit 82 | ), 83 | enable_multivariate_eval=dict( 84 | image2tags_fn=img2tags_fn, 85 | ) 86 | ), 87 | visualizer=dict(name='OrientationVisualizeTest', 88 | opts=dict(show_rot_axis=(True,), 89 | impath2name_fn=lambda x: ".".join(x.split('/')[-4:]).replace('.jpg',''))), 90 | 91 | 92 | dataset={ 93 | 'name': 'vicos_towel', 94 | 'kwargs': { 95 | 'normalize': False, 96 | 'root_dir': os.path.abspath(VICOS_TOWEL_DATASET_DIR), 97 | 98 | 'type': 'test', 99 | 'subfolders': [dict(folder='bg=red_tablecloth', data_subfolders=['cloth=checkered_rag_small', 'cloth=cotton_napkin']), 100 | dict(folder='bg=white_desk', data_subfolders=['cloth=checkered_rag_small', 'cloth=cotton_napkin']), 101 | dict(folder='bg=green_checkered', data_subfolders=['cloth=checkered_rag_small', 'cloth=cotton_napkin']), 102 | dict(folder='bg=poster', data_subfolders=['cloth=checkered_rag_small', 'cloth=cotton_napkin']), 103 | dict(folder='bg=festive_tablecloth', data_subfolders=['cloth=big_towel','cloth=checkered_rag_big','cloth=checkered_rag_medium', 'cloth=checkered_rag_small', 'cloth=cotton_napkin', 
'cloth=linen_rag','cloth=small_towel','cloth=towel_rag','cloth=waffle_rag','cloth=waffle_rag_stripes']) 104 | ], 105 | 106 | 'fixed_bbox_size': 5, 107 | 'resize_factor': 1, 108 | 'use_depth': USE_DEPTH, 109 | 'use_mean_for_depth_nan': True, 110 | 'transform': my_transforms.get_transform([ 111 | { 112 | 'name': 'ToTensor', 113 | 'opts': { 114 | 'keys': ('image', 'instance', 'label', 'ignore', 'orientation', 'mask') + (('depth',) if USE_DEPTH else ()), 115 | 'type': ( 116 | torch.FloatTensor, torch.ShortTensor, torch.ByteTensor, torch.ByteTensor, torch.FloatTensor, 117 | torch.ByteTensor) + ((torch.FloatTensor, ) if USE_DEPTH else ()), 118 | } 119 | }, 120 | { 121 | 'name': 'Resize', 122 | 'opts': { 123 | 'keys': ('image', 'instance', 'label', 'ignore', 'orientation', 'mask') + (('depth',) if USE_DEPTH else ()), 124 | 'interpolation': (InterpolationMode.BILINEAR, InterpolationMode.NEAREST, InterpolationMode.NEAREST, InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.NEAREST) + ((InterpolationMode.BILINEAR, ) if USE_DEPTH else ()), 125 | 'keys_bbox': ('center',), 126 | 'size': (TEST_SIZE_HEIGHT, TEST_SIZE_WIDTH), 127 | } 128 | }, 129 | 130 | ]), 131 | 'MAX_NUM_CENTERS':16*128, 132 | }, 133 | 'centerdir_gt_opts': dict( 134 | ignore_instance_mask_and_use_closest_center=True, 135 | center_ignore_px=3, 136 | 137 | MAX_NUM_CENTERS=16*128, 138 | ), 139 | 140 | 'batch_size': 1, 141 | 'workers': 0, 142 | }, 143 | 144 | model=dict( 145 | name='fpn', 146 | kwargs={ 147 | 'backbone': 'tu-convnext_base', 148 | 'use_depth': USE_DEPTH, 149 | 'num_classes': [NUM_FIELDS, 1], 150 | 'use_custom_fpn': True, 151 | 'add_output_exp': False, 152 | 'in_channels': IN_CHANNELS, 153 | 'fpn_args': { 154 | 'decoder_segmentation_head_channels': 64, 155 | 'upsampling':4, # required for ConvNext architectures 156 | 'classes_grouping': [(0, 1, 2, 5), (3, 4)], 157 | 'depth_mean': 0.96, 'depth_std': 0.075, 158 | }, 159 | 'init_decoder_gain': 0.1 160 | }, 161 | ), 162 | center_model=dict( 163 | name='CenterOrientationEstimator', 164 | use_learnable_center_estimation=True, 165 | 166 | kwargs=dict( 167 | use_centerdir_radii = False, 168 | 169 | # use vector magnitude as mask instead of regressed mask 170 | use_magnitude_as_mask=True, 171 | # thresholds for conv2d processing 172 | local_max_thr=0.01, local_max_thr_use_abs=True, 173 | 174 | ### dilated neural net as head for center detection 175 | use_dilated_nn=True, 176 | dilated_nn_args=dict( 177 | return_sigmoid=False, 178 | # single scale version (nn6) 179 | inner_ch=16, 180 | inner_kernel=3, 181 | dilations=[1, 4, 8, 12], 182 | use_centerdir_radii=False, 183 | use_centerdir_magnitude=False, 184 | use_cls_mask=False 185 | ), 186 | augmentation=False, 187 | scale_r=1.0, # 1024 188 | scale_r_gt=1, # 1 189 | use_log_r=False, 190 | use_log_r_base='10', 191 | enable_6dof=False, 192 | ), 193 | 194 | ), 195 | num_vector_fields=NUM_FIELDS, 196 | 197 | # settings from train config needed for automated path construction 198 | train_settings=dict( 199 | n_epochs=10, 200 | multitask_weighting=dict(name='uw'), 201 | ) 202 | ) 203 | 204 | def get_args(): 205 | return copy.deepcopy(args) 206 | -------------------------------------------------------------------------------- /src/utils/evaluation/orientation.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from models.center_groundtruth import CenterDirGroundtruth 8 | 9 | from utils.evaluation import 
NumpyEncoder 10 | from utils.evaluation.center_global_min import CenterGlobalMinimizationEval 11 | 12 | class OrientationEval(CenterGlobalMinimizationEval): 13 | 14 | def __init__(self, *args, use_gt_centers=False, append_orientation_to_display_name=True, **kwargs): 15 | super(OrientationEval, self).__init__(*args, **kwargs) 16 | 17 | self.use_gt_centers = use_gt_centers 18 | self.metrics.update(dict(rotation=[],translation=[], rot_x=[], rot_y=[], rot_z=[])) 19 | self.append_orientation_to_display_name = append_orientation_to_display_name 20 | 21 | 22 | def add_image_prediction(self, im_name, im_index, im_shape, predictions, predictions_score, pred_angles, 23 | gt_instances_ids, gt_centers_dict, gt_difficult, centerdir_gt, return_matched_gt_idx=False, 24 | **kwargs): 25 | 26 | # use parent class for center prediction matching 27 | ret = super(OrientationEval, self).add_image_prediction( 28 | im_name, im_index, im_shape, predictions, predictions_score, 29 | gt_instances_ids, gt_centers_dict, gt_difficult, centerdir_gt, return_matched_gt_idx=True) 30 | 31 | gt_missed, pred_missed, pred_gt_match_by_center, filename_suffix, pred_gt_match_by_center_idx = ret 32 | 33 | gt_maps_keys = ['gt_orientation_sin', 'gt_orientation_cos'] 34 | gt_sin_orientation, gt_cos_orientation = CenterDirGroundtruth.parse_single_batch_groundtruth_map(centerdir_gt, 35 | keys=gt_maps_keys) 36 | 37 | if filename_suffix is None: 38 | filename_suffix = '' 39 | 40 | if pred_gt_match_by_center_idx.shape[0] != 0: 41 | gt_selected = np.array([gt_centers_dict[np.int16(i)][::-1] for i in pred_gt_match_by_center_idx[:,0] if i >= 0]) 42 | 43 | if len(gt_selected) > 0: 44 | assert len(pred_angles) == len(pred_gt_match_by_center) 45 | assert len(predictions) == len(pred_gt_match_by_center) 46 | assert len(gt_selected) == sum(pred_gt_match_by_center[:,0] != 0) 47 | 48 | trans_err = np.abs(predictions[pred_gt_match_by_center[:,0] != 0, :2] - gt_selected) 49 | 50 | angle_err = [] 51 | 52 | num_orientation_dim = pred_angles.shape[1] 53 | pred_angles = pred_angles[pred_gt_match_by_center[:,0] != 0,:] 54 | 55 | for i, c_gt in enumerate(gt_selected): 56 | if pred_gt_match_by_center[i] == 0: 57 | continue 58 | 59 | s = gt_sin_orientation[0:num_orientation_dim,0,int(c_gt[1]), int(c_gt[0])] 60 | c = gt_cos_orientation[0:num_orientation_dim,0,int(c_gt[1]), int(c_gt[0])] 61 | 62 | gt_angle_i = torch.atan2(c, s) 63 | gt_angle_i = torch.rad2deg(gt_angle_i) 64 | gt_angle_i += 360 * (gt_angle_i < 0).int() 65 | 66 | pred_angle_i = pred_angles[i] 67 | 68 | e = np.abs(gt_angle_i.cpu().numpy() - pred_angle_i) 69 | 70 | is_e_over_180 = (e > 180).astype(np.int32) 71 | e = is_e_over_180 * 360 - (is_e_over_180*2-1) * e # the same as: e = 360 - e if e > 180 else e 72 | 73 | angle_err.append(e) 74 | 75 | if len(angle_err) > 0: 76 | angle_err = np.array(angle_err) 77 | overall_rot_err = np.mean(angle_err,axis=1) 78 | 79 | self.metrics['rotation'].extend(overall_rot_err) 80 | self.metrics['translation'].extend(trans_err) 81 | 82 | if self.append_orientation_to_display_name: 83 | filename_suffix = f're_{np.mean(overall_rot_err):05.2f}_te_{np.mean(trans_err):05.2f}_{filename_suffix}' 84 | 85 | if len(angle_err.shape) > 1 and angle_err.shape[1] == 3: 86 | axis_rot_err = [angle_err[:,i] for i in range(angle_err.shape[1])] 87 | 88 | for rot_err,rot_axis in zip(axis_rot_err,['rot_y', 'rot_z', 'rot_x']): 89 | self.metrics[rot_axis].extend(rot_err) 90 | 91 | if self.append_orientation_to_display_name: 92 | axis_rot_dict = dict(zip(['ry', 'rz', 'rx'], 
np.mean(angle_err,axis=0))) 93 | filename_suffix = "_".join([f'{a}_{axis_rot_dict[a]:05.2f}' for a in ['rx', 'ry', 'rz']] + [filename_suffix]) 94 | else: 95 | print(f"No matching predictions found for {im_name}") 96 | if return_matched_gt_idx: 97 | return gt_missed, pred_missed, pred_gt_match_by_center, filename_suffix, pred_gt_match_by_center_idx 98 | else: 99 | return gt_missed, pred_missed, pred_gt_match_by_center, filename_suffix 100 | 101 | def calc_and_display_final_metrics(self, dataset, print_result=True, plot_result=True, save_dir=None, **kwargs): 102 | Re = np.array(self.metrics['Re']).mean() 103 | mae = np.array(self.metrics['mae']).mean() 104 | rmse = np.array(self.metrics['rmse']).mean() 105 | ratio = np.array(self.metrics['ratio']).mean() 106 | AP = np.array(self.metrics['precision']).mean() 107 | AR = np.array(self.metrics['recall']).mean() 108 | F1 = np.array(self.metrics['F1']).mean() 109 | TE = np.array(self.metrics['translation']).mean() 110 | RE = np.array(self.metrics['rotation']).mean() 111 | RE_Y = np.array(self.metrics['rot_y']).mean() if len(self.metrics['rot_y']) > 0 else None 112 | RE_Z = np.array(self.metrics['rot_z']).mean() if len(self.metrics['rot_z']) > 0 else None 113 | RE_X = np.array(self.metrics['rot_x']).mean() if len(self.metrics['rot_x']) > 0 else None 114 | 115 | if print_result: 116 | RES = 'Re=%.4f, mae=%.4f, rmse=%.4f, ratio=%.4f, AP=%.4f, AR=%.4f, F1=%.4f, translation=%.4f, rotation=%.4f' % (Re, mae, rmse, ratio, AP, AR, F1, TE, RE) 117 | if RE_X: 118 | RES += ", rot_x=%.4f" % RE_X 119 | if RE_Y: 120 | RES += ", rot_y=%.4f" % RE_Y 121 | if RE_Z: 122 | RES += ", rot_z=%.4f" % RE_Z 123 | print(RES) 124 | 125 | if self.center_ap_eval is not None: 126 | metrics_mAP = self.center_ap_eval.calc_and_display_final_metrics(print_result, plot_result) 127 | else: 128 | metrics_mAP = None, None 129 | 130 | metrics = dict(AP=AP, AR=AR, F1=F1, ratio=ratio, Re=Re, mae=mae, rmse=rmse, all_images=self.metrics, 131 | metrics_mAP=metrics_mAP, translation=TE, rotation=RE, rot_x=RE_X, rot_y=RE_Y, rot_z=RE_Z) 132 | 133 | ######################################################################################################## 134 | # SAVE EVAL RESULTS TO JSON FILE 135 | if metrics is not None: 136 | out_dir = os.path.join(save_dir, self.exp_name, self.save_str()) 137 | os.makedirs(out_dir, exist_ok=True) 138 | 139 | with open(os.path.join(out_dir, 'results.json'), 'w') as file: 140 | file.write(json.dumps(metrics, cls=NumpyEncoder)) 141 | 142 | return metrics -------------------------------------------------------------------------------- /src/config/vicos_towel/test_on_train/test.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | 4 | import torchvision 5 | if 'InterpolationMode' in dir(torchvision.transforms): 6 | from torchvision.transforms import InterpolationMode 7 | else: 8 | from PIL import Image as InterpolationMode 9 | 10 | import torch 11 | from utils import transforms as my_transforms 12 | 13 | VICOS_TOWEL_DATASET_DIR = os.environ.get('VICOS_TOWEL_DATASET_DIR') 14 | 15 | OUTPUT_DIR=os.environ.get('OUTPUT_DIR',default='../exp') 16 | 17 | NUM_FIELDS = 5 18 | TRAIN_SIZE = os.environ.get('TRAIN_SIZE', default=256*3) 19 | TEST_SIZE_WIDTH = int(os.environ.get('TEST_SIZE', default=256*3)) 20 | TEST_SIZE_HEIGHT = int(os.environ.get('TEST_SIZE', default=256*3)) 21 | USE_DEPTH = os.environ.get('USE_DEPTH', default='False').lower() == 'true' 22 | 23 | IN_CHANNELS = 3 24 | 25 | if USE_DEPTH: 26 | 
IN_CHANNELS = 4 27 | 28 | def img2tags_fn(x): 29 | # from x=/storage/datasets/ClothDataset/bg=/cloth=/rgb/image_0000_viewXX_lsYY_camera0.jpg 30 | # extract bg, cloth, view (configuration and clutter) and ls (lightning 31 | 32 | # extract bg 33 | bg = x.split('/')[-4] 34 | # extract cloth 35 | cloth = x.split('/')[-3] 36 | # extract ls 37 | lightning = int(x.split('/')[-1].split('_')[3].replace('ls', '')) 38 | # extract view 39 | view = int(x.split('/')[-1].split('_')[2].replace('view', '')) 40 | # configuration and clutter are encoded in view, clutter is on if view is odd 41 | clutter = 'on' if view % 2 == 1 else 'off' 42 | configuration = view // 2 43 | 44 | return [bg, cloth, f'lightning={lightning}', f'configuration={configuration}', f'clutter={clutter}'] 45 | 46 | 47 | 48 | model_dir = os.path.join(OUTPUT_DIR, 'vicos_towel', '{args[ablation_str]}', 49 | 'backbone={args[model][kwargs][backbone]}' + f'_size={TRAIN_SIZE}x{TRAIN_SIZE}', 50 | 'num_train_epoch={args[train_settings][n_epochs]}', 51 | 'depth={args[model][kwargs][use_depth]}', 52 | 'multitask_weight={args[train_settings][multitask_weighting][name]}') 53 | 54 | args = dict( 55 | 56 | cuda=True, 57 | display=True, 58 | autoadjust_figure_size=True, 59 | 60 | groundtruth_loading = True, 61 | 62 | save=True, 63 | 64 | save_dir=os.path.join(model_dir,'{args[dataset][kwargs][type]}_results{args[eval_epoch]}',f'test_size={TEST_SIZE_HEIGHT}x{TEST_SIZE_WIDTH}',), 65 | checkpoint_path=os.path.join(model_dir,'checkpoint{args[eval_epoch]}.pth'), 66 | 67 | eval_epoch='', 68 | ablation_str='', 69 | 70 | eval=dict( 71 | # available score types ['mask', 'center', 'hough_energy', 'edge_to_area_ratio_of_mask', 'avg(mask_pix)', 'avg(hough_pix)', 'avg(projected_dist_pix)'] 72 | score_combination_and_thr=[ 73 | { 74 | 'center': [0.1,0.01,0.05,0.15,0.2,0.25,0.3,0.35,0.40,0.45,0.5,0.55,0.60,0.65,0.7,0.75,0.8,0.85,0.9,0.94,0.99], 75 | }, 76 | ], 77 | score_thr_final=[0.01], 78 | skip_center_eval=True, 79 | orientation=dict( 80 | display_best_threshold=False, 81 | tau_thr=[20], # [5,10,20] # try with 5px, 10px and 20px limit 82 | ), 83 | enable_multivariate_eval=dict( 84 | image2tags_fn=img2tags_fn, 85 | ) 86 | ), 87 | visualizer=dict(name='OrientationVisualizeTest', 88 | opts=dict(show_rot_axis=(True,), 89 | impath2name_fn=lambda x: ".".join(x.split('/')[-4:]).replace('.jpg',''))), 90 | 91 | 92 | dataset={ 93 | 'name': 'vicos_towel', 94 | 'kwargs': { 95 | 'normalize': False, 96 | 'root_dir': os.path.abspath(VICOS_TOWEL_DATASET_DIR), 97 | 98 | 'type': 'train', 99 | 'subfolders': [dict(folder='bg=white_desk', data_subfolders=['cloth=big_towel','cloth=checkered_rag_big','cloth=checkered_rag_medium', 'cloth=linen_rag','cloth=small_towel','cloth=towel_rag','cloth=waffle_rag','cloth=waffle_rag_stripes']), 100 | dict(folder='bg=green_checkered', data_subfolders=['cloth=big_towel','cloth=checkered_rag_big','cloth=checkered_rag_medium', 'cloth=linen_rag','cloth=small_towel','cloth=towel_rag','cloth=waffle_rag','cloth=waffle_rag_stripes']), 101 | dict(folder='bg=poster', data_subfolders=['cloth=big_towel','cloth=checkered_rag_big','cloth=checkered_rag_medium', 'cloth=linen_rag','cloth=small_towel','cloth=towel_rag','cloth=waffle_rag','cloth=waffle_rag_stripes']), 102 | dict(folder='bg=red_tablecloth', data_subfolders=['cloth=big_towel','cloth=checkered_rag_big','cloth=checkered_rag_medium', 'cloth=linen_rag','cloth=small_towel','cloth=towel_rag','cloth=waffle_rag','cloth=waffle_rag_stripes']) 103 | ], 104 | 105 | 'fixed_bbox_size': 5, 106 | 'resize_factor': 1, 
107 | 'use_depth': USE_DEPTH, 108 | 'use_mean_for_depth_nan': True, 109 | 'transform': my_transforms.get_transform([ 110 | { 111 | 'name': 'ToTensor', 112 | 'opts': { 113 | 'keys': ('image', 'instance', 'label', 'ignore', 'orientation', 'mask') + (('depth',) if USE_DEPTH else ()), 114 | 'type': ( 115 | torch.FloatTensor, torch.ShortTensor, torch.ByteTensor, torch.ByteTensor, torch.FloatTensor, 116 | torch.ByteTensor) + ((torch.FloatTensor, ) if USE_DEPTH else ()), 117 | } 118 | }, 119 | { 120 | 'name': 'Resize', 121 | 'opts': { 122 | 'keys': ('image', 'instance', 'label', 'ignore', 'orientation', 'mask') + (('depth',) if USE_DEPTH else ()), 123 | 'interpolation': (InterpolationMode.BILINEAR, InterpolationMode.NEAREST, InterpolationMode.NEAREST, InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.NEAREST) + ((InterpolationMode.BILINEAR, ) if USE_DEPTH else ()), 124 | 'keys_bbox': ('center',), 125 | 'size': (TEST_SIZE_HEIGHT, TEST_SIZE_WIDTH), 126 | } 127 | }, 128 | 129 | ]), 130 | 'MAX_NUM_CENTERS':16*128, 131 | }, 132 | 'centerdir_gt_opts': dict( 133 | ignore_instance_mask_and_use_closest_center=True, 134 | center_ignore_px=3, 135 | 136 | MAX_NUM_CENTERS=16*128, 137 | ), 138 | 139 | 'batch_size': 1, 140 | 'workers': 0, 141 | }, 142 | 143 | model=dict( 144 | name='fpn', 145 | kwargs={ 146 | 'backbone': 'tu-convnext_base', 147 | 'use_depth': USE_DEPTH, 148 | 'num_classes': [NUM_FIELDS, 1], 149 | 'use_custom_fpn': True, 150 | 'add_output_exp': False, 151 | 'in_channels': IN_CHANNELS, 152 | 'fpn_args': { 153 | 'decoder_segmentation_head_channels': 64, 154 | 'upsampling':4, # required for ConvNext architectures 155 | 'classes_grouping': [(0, 1, 2, 5), (3, 4)], 156 | 'depth_mean': 0.96, 'depth_std': 0.075, 157 | }, 158 | 'init_decoder_gain': 0.1 159 | }, 160 | ), 161 | center_model=dict( 162 | name='CenterOrientationEstimator', 163 | use_learnable_center_estimation=True, 164 | 165 | kwargs=dict( 166 | use_centerdir_radii = False, 167 | 168 | # use vector magnitude as mask instead of regressed mask 169 | use_magnitude_as_mask=True, 170 | # thresholds for conv2d processing 171 | local_max_thr=0.01, local_max_thr_use_abs=True, 172 | 173 | ### dilated neural net as head for center detection 174 | use_dilated_nn=True, 175 | dilated_nn_args=dict( 176 | return_sigmoid=False, 177 | # single scale version (nn6) 178 | inner_ch=16, 179 | inner_kernel=3, 180 | dilations=[1, 4, 8, 12], 181 | use_centerdir_radii=False, 182 | use_centerdir_magnitude=False, 183 | use_cls_mask=False 184 | ), 185 | augmentation=False, 186 | scale_r=1.0, # 1024 187 | scale_r_gt=1, # 1 188 | use_log_r=False, 189 | use_log_r_base='10', 190 | enable_6dof=False, 191 | ), 192 | 193 | ), 194 | num_vector_fields=NUM_FIELDS, 195 | 196 | # settings from train config needed for automated path construction 197 | train_settings=dict( 198 | n_epochs=10, 199 | multitask_weighting=dict(name='uw'), 200 | ) 201 | ) 202 | 203 | def get_args(): 204 | return copy.deepcopy(args) 205 | -------------------------------------------------------------------------------- /src/inference/processing.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from tqdm import tqdm 4 | 5 | import numpy as np 6 | 7 | import torch 8 | 9 | from models.center_groundtruth import CenterDirGroundtruth 10 | 11 | class CenterDirProcesser: 12 | ''' 13 | Main inference processing class for centerdir models (with center detections). 
14 | 15 | When the main function is called, it iterates over each image from the provided dataset iterator and processes it with 16 | the provided self.model and self.center_model. 17 | 18 | The class handles the following items in the input data (a sample from the dataset): 19 | - required keys: image, im_name, instance 20 | - optional keys: centerdir_groundtruth, ignore, center 21 | 22 | The following keys are emitted for each processed/merged image: 23 | - as reference input data: + center_dict 24 | - as processed output data: output, predictions, pred_heatmap 25 | ''' 26 | def __init__(self, model, center_model_list, device=None): 27 | self.model = model 28 | 29 | self.center_model_list = center_model_list 30 | 31 | self.device = device 32 | 33 | def get_center_model_list(self): 34 | return self.center_model_list 35 | 36 | def clean_memory(self): 37 | self.model.cpu() 38 | self.model = None 39 | 40 | if self.center_model_list is not None: 41 | for center_model_desc in self.get_center_model_list(): 42 | center_model_desc['model'].cpu() 43 | center_model_desc['model'] = None 44 | 45 | 46 | def __call__(self, dataset_it, centerdir_groundtruth_op=None, tqdm_kwargs={}): 47 | 48 | assert self.model is not None 49 | self.model.eval() 50 | 51 | for center_model_desc in self.get_center_model_list(): 52 | assert center_model_desc['model'] is not None 53 | center_model_desc['model'].eval() 54 | 55 | im_image = 0 56 | 57 | for sample_ in tqdm(dataset_it, **tqdm_kwargs): 58 | 59 | # call centerdir_groundtruth_op first which will create any missing centerdir_groundtruth (using GPU) and add synthetic output 60 | if centerdir_groundtruth_op is not None: 61 | sample_ = centerdir_groundtruth_op(sample_, torch.arange(0, dataset_it.batch_size).int()) 62 | model = self.model 63 | 64 | im_batch = sample_['image'] 65 | 66 | output_batch_ = model(im_batch) 67 | 68 | for center_model_desc in self.get_center_model_list(): 69 | center_model_name = center_model_desc['name'] 70 | center_model = center_model_desc['model'] 71 | 72 | # run center detection model 73 | center_output = center_model(output_batch_, **sample_) 74 | 75 | output_batch, center_pred, center_heatmap = [center_output[k] for k in ['output', 76 | 'center_pred', 77 | 'center_heatmap']] 78 | # optional output 79 | pred_angle = center_output.get('pred_angle') 80 | 81 | # extract centers either from 'centerdir_groundtruth' or from 'center' in sample 82 | gt_centers = None 83 | if 'centerdir_groundtruth' in sample_: 84 | gt_centers = CenterDirGroundtruth.parse_groundtruth_map(sample_['centerdir_groundtruth'],keys=['gt_centers']) 85 | elif 'center' in sample_: 86 | gt_centers = sample_['center'][:,:,[1,0]] 87 | 88 | if gt_centers is not None: 89 | # get gt_centers from centerdir_gt and convert them to a dictionary (filtering out non-visible and ignored examples) 90 | # if ignore flags are present, remove all groundtruths where ONLY the ignore flag (encoded as 1) is set and no other flags 91 | # (do not remove other types such as truncated, overlap border, difficult) 92 | instances = sample_['instance'].squeeze(dim=1) 93 | center_ignore = sample_['ignore'] == 1 if 'ignore' in sample_ else None 94 | 95 | gt_centers_dict = CenterDirGroundtruth.convert_gt_centers_to_dictionary(gt_centers, 96 | instances=instances, 97 | ignore=center_ignore) 98 | else: 99 | gt_centers_dict = [] 100 | 101 | sample_keys = sample_.keys() 102 | 103 | for batch_i in range(min(dataset_it.batch_size, len(sample_['im_name']))): 104 | 105 | im_image += 1 106 | output = 
output_batch[batch_i:batch_i + 1] 107 | 108 | sample = {k: sample_[k][batch_i:batch_i + 1] for k in sample_keys} 109 | 110 | im = sample['image'] 111 | im_name = sample['im_name'][0] 112 | base, _ = os.path.splitext(os.path.basename(im_name)) 113 | 114 | instance = sample['instance'].squeeze() 115 | ignore = sample.get('ignore') 116 | 117 | if 'centerdir_groundtruth' in sample_: 118 | sample['centerdir_groundtruth'] = sample_['centerdir_groundtruth'][0][batch_i] 119 | 120 | if len(gt_centers_dict) > 0: 121 | center_dict = gt_centers_dict[batch_i] 122 | 123 | # manually remove instance that have been ignored 124 | if ignore is not None: 125 | for id in instance.unique(): 126 | id = id.item() 127 | if id > 0 and id not in center_dict.keys(): 128 | instance[instance == id] = 0 129 | else: 130 | center_dict = None 131 | 132 | # extract prediction heatmap and sorted prediction list 133 | pred_heatmap = torch.relu(center_heatmap[batch_i].unsqueeze(0)) 134 | predictions = center_pred[batch_i][center_pred[batch_i,:,0] == 1][:,1:].cpu().numpy() 135 | 136 | idx = np.argsort(predictions[:, -1]) 137 | idx = idx[::-1] 138 | predictions = predictions[idx, :] 139 | 140 | # sort predicted angle if present 141 | pred_angle_b = pred_angle[batch_i].cpu().numpy()[idx, :] if pred_angle is not None else None 142 | 143 | assert len(pred_angle_b) == len(predictions) 144 | 145 | # simply return all input and output data 146 | output_dict = dict(output=output, 147 | predictions=predictions, 148 | pred_heatmap=pred_heatmap, 149 | pred_angle=pred_angle_b, 150 | center_model_name=center_model_name) 151 | 152 | # update sample data with center_dict and override im_name and instance 153 | sample.update(dict(im_name=im_name, # override to remove dimension 154 | instance=instance, # override to remove dimension 155 | center_dict=center_dict,)) 156 | 157 | yield sample, output_dict 158 | 159 | -------------------------------------------------------------------------------- /src/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import threading 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import pandas as pd 6 | 7 | import torch 8 | 9 | from utils.visualize.vis import Visualizer 10 | 11 | import scipy 12 | import torch.nn as nn 13 | 14 | class AverageMeter(object): 15 | 16 | def __init__(self, num_classes=1): 17 | self.num_classes = num_classes 18 | self.reset() 19 | self.lock = threading.Lock() 20 | 21 | def reset(self): 22 | self.sum = [0] * self.num_classes 23 | self.count = [0] * self.num_classes 24 | self.avg_per_class = [0] * self.num_classes 25 | self.avg = 0 26 | 27 | def update(self, val, cl=0): 28 | with self.lock: 29 | self.sum[cl] += val 30 | self.count[cl] += 1 31 | self.avg_per_class = [ 32 | x/y if x > 0 else 0 for x, y in zip(self.sum, self.count)] 33 | self.avg = sum(self.avg_per_class)/len(self.avg_per_class) 34 | 35 | 36 | class Logger: 37 | 38 | def __init__(self, keys, title=""): 39 | 40 | self.data = {k: [] for k in keys} 41 | self.title = title 42 | self.win = None 43 | 44 | print('created logger with keys: {}'.format(keys)) 45 | 46 | def plot(self, save=False, save_dir=""): 47 | 48 | if self.win is None: 49 | self.win = plt.subplots() 50 | fig, ax = self.win 51 | ax.cla() 52 | 53 | keys = [] 54 | for key in self.data: 55 | keys.append(key) 56 | data = self.data[key] 57 | ax.plot(range(len(data)), data, marker='.') 58 | 59 | ax.legend(keys, loc='upper right') 60 | ax.set_title(self.title) 61 | 62 | plt.draw() 63 | 
Visualizer.mypause(0.001) 64 | 65 | if save: 66 | # save figure 67 | fig.savefig(os.path.join(save_dir, self.title + '.png')) 68 | 69 | # save data as csv 70 | df = pd.DataFrame.from_dict(self.data) 71 | df.to_csv(os.path.join(save_dir, self.title + '.csv')) 72 | 73 | def add(self, key, value): 74 | assert key in self.data, "Key not in data" 75 | self.data[key].append(value) 76 | 77 | class GaussianLayer(nn.Module): 78 | def __init__(self, num_channels=1, sigma=3): 79 | super(GaussianLayer, self).__init__() 80 | 81 | self.sigma = sigma 82 | self.kernel_size = int(2 * np.ceil(3*self.sigma - 0.5) + 1) 83 | 84 | self.conv = nn.Conv2d(num_channels, num_channels, self.kernel_size, stride=1, 85 | padding=self.kernel_size//2, bias=None, groups=num_channels) 86 | 87 | 88 | self.weights_init() 89 | def forward(self, x): 90 | return self.conv(x) 91 | 92 | def weights_init(self): 93 | n = np.zeros((self.kernel_size,self.kernel_size)) 94 | n[self.kernel_size//2,self.kernel_size//2] = 1 95 | k = scipy.ndimage.gaussian_filter(n,sigma=self.sigma) 96 | for name, f in self.named_parameters(): 97 | f.data.copy_(torch.from_numpy(k)) 98 | 99 | 100 | def tensor_mask_to_ids(mask): 101 | ids = {i.item(): (mask == i).nonzero().cpu().numpy() 102 | for i in mask.unique() if i > 0} 103 | ids = {i: set(np.ravel_multi_index((np.array(i_mask)[:, 0], np.array(i_mask)[:, 1]), dims=mask.shape[-2:])) 104 | for i, i_mask in ids.items()} 105 | 106 | return ids 107 | 108 | def ids_to_tensor_maks(ids, out_shape): 109 | out = np.zeros(out_shape) 110 | for i, indices in ids.items(): 111 | indices = np.unravel_index(indices,out_shape) 112 | out[(indices[0], indices[1])] = i 113 | 114 | return out 115 | 116 | def instance_poly_to_variable_array(polygon_list): 117 | import numpy as np 118 | 119 | if polygon_list is not None and len(polygon_list) > 0: 120 | if type(polygon_list) in [list,tuple]: 121 | # convert from list of [Nx2] to [Nx3] where first value in second axis defines ID of instance 122 | polygon_list = [np.concatenate(((i + 1) * np.ones((len(p), 1)), p), axis=1) for i, p in enumerate(polygon_list)] 123 | polygon_list = np.concatenate(polygon_list, axis=0) 124 | elif len(polygon_list.shape) == 3: 125 | idx = np.expand_dims(np.repeat(np.expand_dims(np.arange(len(polygon_list)),1),(polygon_list.shape[1]),axis=1),2) + 1 126 | polygon_list = np.concatenate((idx,polygon_list),axis=2) 127 | elif len(polygon_list.shape) != 2 or polygon_list.shape[1] != 3: 128 | raise Exception("Invalid input polygon_list: should be list of Nx2, array of Nx2 or array of Nx3") 129 | else: 130 | polygon_list = np.zeros((0,3)) 131 | 132 | return polygon_list 133 | 134 | import torch 135 | import re 136 | import collections 137 | #from torch._six import string_classes # torch._six does not exist in latest pytorch !! 
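# The variable_len_collate() helper defined below follows the standard PyTorch collate_fn contract
# (it receives a list of samples and returns a batched structure), padding fields whose tensors differ
# in length. A typical use, assumed here rather than shown in this repository snippet, would be:
#   loader = torch.utils.data.DataLoader(dataset, batch_size=4, collate_fn=variable_len_collate)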
138 | string_classes = (str,) 139 | 140 | np_str_obj_array_pattern = re.compile(r'[SaUO]') 141 | 142 | default_collate_err_msg_format = ( 143 | "default_collate: batch must contain tensors, numpy arrays, numbers, " 144 | "dicts or lists; found {}") 145 | 146 | from torch.nn.utils.rnn import pad_sequence 147 | 148 | def variable_len_collate(batch, batch_first=True, padding_value=0): 149 | r"""Puts each data field into a tensor with outer dimension batch size""" 150 | 151 | elem = batch[0] 152 | elem_type = type(elem) 153 | if isinstance(elem, torch.Tensor): 154 | out = None 155 | numel = [x.numel() for x in batch] 156 | if torch.utils.data.get_worker_info() is not None: 157 | # If we're in a background process, concatenate directly into a 158 | # shared memory tensor to avoid an extra copy 159 | storage = elem.storage()._new_shared(sum(numel)) 160 | out = elem.new(storage) 161 | return torch.stack(batch, 0, out=out) if np.all(numel[0] == numel) else pad_sequence(batch, 162 | batch_first=batch_first, 163 | padding_value=padding_value) 164 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ 165 | and elem_type.__name__ != 'string_': 166 | if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap': 167 | # array of string classes and object 168 | if np_str_obj_array_pattern.search(elem.dtype.str) is not None: 169 | raise TypeError(default_collate_err_msg_format.format(elem.dtype)) 170 | 171 | return variable_len_collate([torch.as_tensor(b) for b in batch]) 172 | elif elem.shape == (): # scalars 173 | return torch.as_tensor(batch) 174 | elif isinstance(elem, float): 175 | return torch.tensor(batch, dtype=torch.float64) 176 | elif isinstance(elem, int): 177 | return torch.tensor(batch) 178 | elif isinstance(elem, string_classes): 179 | return batch 180 | elif isinstance(elem, collections.abc.Mapping): 181 | return {key: variable_len_collate([d[key] for d in batch]) for key in elem} 182 | elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple 183 | return elem_type(*(variable_len_collate(samples) for samples in zip(*batch))) 184 | elif isinstance(elem, collections.abc.Sequence): 185 | # check to make sure that the elements in batch have consistent size 186 | it = iter(batch) 187 | elem_size = len(next(it)) 188 | if not all(len(elem) == elem_size for elem in it): 189 | raise RuntimeError('each element in list of batch should be of equal size') 190 | transposed = zip(*batch) 191 | return [variable_len_collate(samples) for samples in transposed] 192 | 193 | raise TypeError(default_collate_err_msg_format.format(elem_type)) 194 | -------------------------------------------------------------------------------- /src/config/vicos_towel/train.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | 4 | import torch 5 | from utils import transforms as my_transforms 6 | from torchvision.transforms import InterpolationMode 7 | 8 | VICOS_TOWEL_DATASET_DIR = os.environ.get('VICOS_TOWEL_DATASET_DIR') 9 | 10 | OUTPUT_DIR=os.environ.get('OUTPUT_DIR',default='../exp') 11 | 12 | NUM_FIELDS = 5 13 | SIZE = int(os.environ.get('TRAIN_SIZE', default=256*3)) 14 | USE_DEPTH = os.environ.get('USE_DEPTH', default='False').lower() == 'true' 15 | 16 | IN_CHANNELS = 3 17 | if USE_DEPTH: 18 | IN_CHANNELS = 4 19 | 20 | args = dict( 21 | 22 | cuda=True, 23 | display=False, 24 | display_it=20, 25 | 26 | tf_logging=['loss'], 27 | tf_logging_iter=2, 28 | 29 | visualizer=dict(name='OrientationVisualizeTrain'), 30 | 31 | 
save=True, 32 | save_interval=2, 33 | 34 | # -------- 35 | n_epochs=10, 36 | ablation_str="", 37 | 38 | save_dir=os.path.join(OUTPUT_DIR, 'vicos_towel', '{args[ablation_str]}', 39 | 'backbone={args[model][kwargs][backbone]}' + f'_size={SIZE}x{SIZE}', 40 | 'num_train_epoch={args[n_epochs]}', 41 | 'depth={args[model][kwargs][use_depth]}', 42 | 'multitask_weight={args[multitask_weighting][name]}', 43 | ), 44 | 45 | 46 | pretrained_model_path = None, 47 | resume_path = None, 48 | 49 | pretrained_center_model_path = None, 50 | 51 | 52 | train_dataset = { 53 | 'name': 'vicos_towel', 54 | 'kwargs': { 55 | 'normalize': False, 56 | 'root_dir': os.path.abspath(VICOS_TOWEL_DATASET_DIR), 57 | 'subfolders': [dict(folder='bg=white_desk', data_subfolders=['cloth=big_towel','cloth=checkered_rag_big','cloth=checkered_rag_medium', 'cloth=linen_rag','cloth=small_towel','cloth=towel_rag','cloth=waffle_rag','cloth=waffle_rag_stripes']), 58 | dict(folder='bg=green_checkered', data_subfolders=['cloth=big_towel','cloth=checkered_rag_big','cloth=checkered_rag_medium', 'cloth=linen_rag','cloth=small_towel','cloth=towel_rag','cloth=waffle_rag','cloth=waffle_rag_stripes']), 59 | dict(folder='bg=poster', data_subfolders=['cloth=big_towel','cloth=checkered_rag_big','cloth=checkered_rag_medium', 'cloth=linen_rag','cloth=small_towel','cloth=towel_rag','cloth=waffle_rag','cloth=waffle_rag_stripes']), 60 | dict(folder='bg=red_tablecloth', data_subfolders=['cloth=big_towel','cloth=checkered_rag_big','cloth=checkered_rag_medium', 'cloth=linen_rag','cloth=small_towel','cloth=towel_rag','cloth=waffle_rag','cloth=waffle_rag_stripes']) 61 | ], 62 | 'fixed_bbox_size': 15, 63 | 'resize_factor': 1, 64 | 'use_depth': USE_DEPTH, 65 | 'correct_depth_rotation': False, 66 | 'use_mean_for_depth_nan': True, 67 | 'use_normals': False, 68 | 'transform_per_sample_rng': True, # TRUE == RA-L version with fixed RNG for each sample (which may not be random as intended but gets good results !!) 
69 | 'transform': my_transforms.get_transform([ 70 | # for training without augmentation (same as testing) 71 | { 72 | 'name': 'ToTensor', 73 | 'opts': { 74 | 'keys': ('image', 'instance', 'label', 'ignore', 'orientation', 'mask') + (('depth',) if USE_DEPTH else ()), 75 | 'type': ( 76 | torch.FloatTensor, torch.ShortTensor, torch.ByteTensor, torch.ByteTensor, torch.FloatTensor, torch.ByteTensor) + ((torch.FloatTensor, ) if USE_DEPTH else ()), 77 | } 78 | }, 79 | { 80 | 'name': 'Resize', 81 | 'opts': { 82 | 'keys': ('image', 'instance', 'label', 'ignore', 'orientation', 'mask') + (('depth',) if USE_DEPTH else ()), 83 | 'interpolation': (InterpolationMode.BILINEAR, InterpolationMode.NEAREST, InterpolationMode.NEAREST, InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.NEAREST) + ((InterpolationMode.BILINEAR, ) if USE_DEPTH else ()), 84 | 'keys_bbox': ('center',), 85 | 'size': (SIZE, SIZE), 86 | } 87 | }, 88 | # for training with random augmentation 89 | { 90 | 'name': 'RandomGaussianBlur', 91 | 'opts': { 92 | 'keys': ('image',), 93 | 'rate': 0.5, 'sigma': [0.5, 2] 94 | } 95 | }, 96 | 97 | { 98 | 'name': 'ColorJitter', 99 | 'opts': { 100 | 'keys': ('image',), 'p': 0.5, 101 | 'saturation': 0.3, 'hue': 0.3, 'brightness': 0.3, 'contrast':0.3 102 | } 103 | } 104 | 105 | ]), 106 | 'MAX_NUM_CENTERS':16*128, 107 | }, 108 | 109 | 'centerdir_gt_opts': dict( 110 | ignore_instance_mask_and_use_closest_center=True, # by default 111 | center_ignore_px=3, 112 | 113 | skip_gt_center_mask_generate=True, # gt_center_mask is not needed since we are not training localization network 114 | 115 | MAX_NUM_CENTERS=16*128, 116 | ), 117 | 118 | 'batch_size': 4, 119 | 120 | # hard example disabled 121 | 'hard_samples_size': 0, 122 | 'hard_samples_selected_min_percent':0.1, 123 | 'workers': 4, 124 | 'shuffle': True, 125 | }, 126 | 127 | model = dict( 128 | name='fpn', 129 | kwargs= { 130 | 'backbone': 'tu-convnext_base', 131 | 'use_depth': USE_DEPTH, 132 | 'num_classes': [NUM_FIELDS, 1], 133 | 'use_custom_fpn':True, 134 | 'add_output_exp': False, 135 | 'in_channels': IN_CHANNELS, 136 | 137 | 'fpn_args': { 138 | 'decoder_segmentation_head_channels':64, 139 | 'upsampling':4, # required for ConvNext architectures 140 | 'classes_grouping': [(0, 1, 2, 5), (3, 4)], 141 | 'depth_mean': 0.96, 'depth_std':0.075, 142 | }, 143 | 'init_decoder_gain': 0.1 144 | }, 145 | optimizer='Adam', 146 | lr=1e-4, 147 | weight_decay=0, 148 | 149 | ), 150 | center_model=dict( 151 | name='CenterEstimator', 152 | kwargs=dict( 153 | # use vector magnitude as mask instead of regressed mask 154 | use_magnitude_as_mask=False, 155 | # thresholds for conv2d processing 156 | local_max_thr=0.1, mask_thr=0.01, exclude_border_px=0, 157 | use_dilated_nn=True, 158 | dilated_nn_args=dict( 159 | # single scale version (nn6) 160 | inner_ch=16, 161 | inner_kernel=3, 162 | dilations=[1, 4, 8, 12], 163 | freeze_learning=True, 164 | gradpass_relu=False, 165 | # version with leaky relu 166 | leaky_relu=False, 167 | # input check 168 | # use_polar_radii=False, 169 | use_centerdir_radii = False, 170 | use_centerdir_magnitude = False, 171 | use_cls_mask = False 172 | ), 173 | allow_input_backprop=False, 174 | backprop_only_positive=False, 175 | augmentation=False, 176 | scale_r=1.0, # 1024 177 | scale_r_gt=1024, # 1 178 | use_log_r=True, 179 | use_log_r_base='10', 180 | enable_6dof=False, 181 | ), 182 | # DISABLE TRAINING 183 | optimizer='Adam', 184 | lr=0, 185 | weight_decay=0, 186 | ), 187 | 188 | 189 | # loss options 190 | 
loss_type='OrientationLoss', 191 | loss_opts={ 192 | 'num_vector_fields': NUM_FIELDS, 193 | 'foreground_weight': 1, 194 | 195 | 'enable_centerdir_loss': True, 196 | 'no_instance_loss': True, # MUST be True to ignore instance mask 197 | 'centerdir_instance_weighted': True, 198 | 'regression_loss': 'l1', 199 | 200 | 'use_log_r': True, 201 | 'use_log_r_base': '10', 202 | 203 | 'orientation_args': dict( 204 | enable=True, 205 | no_instance_loss=False, 206 | regression_loss='l1', 207 | enable_6dof=False, 208 | symmetries=None, 209 | ) 210 | }, 211 | multitask_weighting=dict( 212 | name='uw', 213 | kwargs=dict( 214 | n_tasks=2 215 | ) 216 | ), 217 | 218 | loss_w={ 219 | 'w_r': 1, 220 | 'w_cos': 1, 221 | 'w_sin': 1, 222 | 'w_cent': 0.1, 223 | 'w_orientation': 1, 224 | }, 225 | 226 | ) 227 | 228 | args['lambda_scheduler_fn']=lambda _args: (lambda epoch: pow((1-((epoch)/_args['n_epochs'])), 0.9)) 229 | #args['lambda_scheduler_fn']=lambda _args: (lambda epoch: 1.0) # disabled 230 | 231 | args['model']['lambda_scheduler_fn'] = args['lambda_scheduler_fn'] 232 | args['center_model']['lambda_scheduler_fn'] = lambda _args: (lambda epoch: pow((1-((epoch)/_args['n_epochs'])), 0.9) if epoch > 1 else 0) 233 | 234 | 235 | def get_args(): 236 | return copy.deepcopy(args) 237 | -------------------------------------------------------------------------------- /scripts/EXPERIMENTS_ABLATION.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # include config and some utils 4 | source ./config.sh 5 | source ./utils.sh 6 | 7 | export USE_CONDA_ENV=CeDiRNet-py3.8 8 | export DISABLE_X11=0 9 | 10 | centernet_filename="${ROOT_DIR}/models/localization_checkpoint.pth" 11 | 12 | DO_SYNT_TRAINING=True # step 0: pretraining on syntetic data (MuJoCo) 13 | DO_REAL_TRAINING=True # step 1: training on real-world data (ViCoS Towel Dataset) 14 | DO_EVALUATION=True # step 3: evaluate 15 | 16 | # assuming 4 GPUs available on localhost 17 | GPU_LIST=("localhost:0" "localhost:1" "localhost:2" "localhost:4") 18 | GPU_COUNT=${#GPU_LIST[@]} 19 | 20 | ######################################## 21 | # PRETRAINING on synthetic data only 22 | ######################################## 23 | 24 | if [[ "$DO_SYNT_TRAINING" == True ]] ; then 25 | s=0 26 | for db in "mujoco"; do 27 | export DATASET=$db 28 | for cfg_subname in ""; do 29 | for backbone in "tu-convnext_large"; do # "tu-convnext_base" 30 | for epoch in 10; do 31 | for abl in no_orient no_seg_head no_uncert_weight; do 32 | ABLATION_SETTINGS="" 33 | abl_str="" 34 | if [[ "$abl" == no_orient ]] ; then 35 | ABLATION_SETTINGS="multitask_weighting.name=off model.kwargs.fpn_args.classes_grouping=None loss_opts.orientation_args.enable=False" 36 | abl_str="ablation=$abl" 37 | elif [[ "$abl" == no_seg_head ]] ; then 38 | ABLATION_SETTINGS="multitask_weighting.name=off model.kwargs.fpn_args.classes_grouping=None" 39 | abl_str="ablation=$abl" 40 | elif [[ "$abl" == no_uncert_weight ]] ; then 41 | ABLATION_SETTINGS="multitask_weighting.name=off" 42 | abl_str="ablation=$abl" 43 | fi 44 | 45 | for depth in on; do # on off 46 | if [[ "$depth" == off ]] ; then 47 | export USE_DEPTH=False 48 | elif [[ "$depth" == on ]] ; then 49 | export USE_DEPTH=True 50 | fi 51 | SERVERS=${GPU_LIST[$((s % GPU_COUNT))]} ./run_distributed.sh python train.py --cfg_subname=$cfg_subname --config \ 52 | model.kwargs.backbone=$backbone \ 53 | ablation_str=$abl_str \ 54 | n_epochs=$epoch \ 55 | train_dataset.batch_size=8 \ 56 | train_dataset.workers=16 \ 57 | 
"pretrained_center_model_path=$centernet_filename" \ 58 | display=False save_interval=1 skip_if_exists=True $ABLATION_SETTINGS & 59 | s=$((s+1)) 60 | wait_or_interrupt $GPU_COUNT $s 61 | done 62 | done 63 | done 64 | done 65 | done 66 | done 67 | fi 68 | wait_or_interrupt 69 | 70 | ######################################## 71 | # Training on real data 72 | ######################################## 73 | 74 | if [[ "$DO_REAL_TRAINING" == True ]] ; then 75 | s=0 76 | 77 | for db in "vicos_cloth"; do 78 | export DATASET=$db 79 | for cfg_subname in "" ; do 80 | export TRAIN_SIZE=768 81 | for backbone in "tu-convnext_large"; do # "tu-convnext_base" 82 | for epoch in 10; do 83 | for depth in on; do # on off 84 | if [[ "$depth" == off ]] ; then 85 | export USE_DEPTH=False 86 | depth_str=False 87 | elif [[ "$depth" == on ]] ; then 88 | export USE_DEPTH=True 89 | depth_str=True 90 | fi 91 | for abl in no_orient no_seg_head no_uncert_weight; do # off 92 | multitask_w="uw" 93 | ABLATION_SETTINGS="" 94 | abl_str="" 95 | if [[ "$abl" == no_orient ]] ; then 96 | ABLATION_SETTINGS="multitask_weighting.name=off model.kwargs.fpn_args.classes_grouping=None loss_opts.orientation_args.enable=False" 97 | abl_str="ablation=$abl" 98 | multitask_w="off" 99 | elif [[ "$abl" == no_seg_head ]] ; then 100 | ABLATION_SETTINGS="multitask_weighting.name=off model.kwargs.fpn_args.classes_grouping=None" 101 | abl_str="ablation=$abl" 102 | multitask_w="off" 103 | elif [[ "$abl" == no_uncert_weight ]] ; then 104 | ABLATION_SETTINGS="multitask_weighting.name=off" 105 | abl_str="ablation=$abl" 106 | multitask_w="off" 107 | fi 108 | PRETRAINED_CHECKPOINT="${OUTPUT_DIR}/mujoco/${abl_str}/backbone=${backbone}/num_train_epoch=10/depth=$depth_str/multitask_weight=${multitask_w}/checkpoint.pth" 109 | SERVERS=${GPU_LIST[$((s % GPU_COUNT))]} ./run_distributed.sh python train.py --cfg_subname $cfg_subname --config \ 110 | model.kwargs.backbone=$backbone \ 111 | ablation_str=$abl_str \ 112 | n_epochs=$epoch \ 113 | train_dataset.batch_size=4 \ 114 | multitask_weighting.name=$multitask_w \ 115 | "pretrained_center_model_path=$centernet_filename" \ 116 | "pretrained_model_path=$PRETRAINED_CHECKPOINT" \ 117 | display=True skip_if_exists=True save_interval=2 $ABLATION_SETTINGS & 118 | s=$((s+1)) 119 | wait_or_interrupt $GPU_COUNT $s 120 | done 121 | done 122 | done 123 | done 124 | done 125 | done 126 | fi 127 | wait_or_interrupt 128 | 129 | ######################################## 130 | # Evaluating on test data 131 | ######################################## 132 | 133 | if [[ "$DO_EVALUATION" == True ]] ; then 134 | 135 | # FOR FULL EVAL 136 | DISPLAY_ARGS="display=False display_to_file_only=True skip_if_exists=True" 137 | 138 | s=0 139 | export DISABLE_X11=0 140 | for db in "vicos_towel"; do 141 | for cfg_subname in ""; do # for exclusively unseen objects set to "novel_object=bg+cloth" (or to "novel_object=cloth" "novel_object=bg") 142 | export DATASET=$db 143 | export TRAIN_SIZE=768 144 | export TEST_SIZE=768 145 | for backbone in "tu-convnext_large"; do # "tu-convnext_base" 146 | for epoch_train in 10; do 147 | ALL_EPOCH=("") # set to ALL_EPOCH=("" _002 _004 _006 _008) to evaluate every second epoch 148 | for epoch_eval in "${ALL_EPOCH[@]}"; do 149 | for depth in on; do # off 150 | if [[ "$depth" == off ]] ; then 151 | export USE_DEPTH=False USE_NORMALS=False NORMALS_MODE=1 152 | elif [[ "$depth" == on ]] ; then 153 | export USE_DEPTH=True USE_NORMALS=False NORMALS_MODE=1 154 | fi 155 | 156 | for abl in no_orient no_seg_head 
no_uncert_weight; do # off 157 | ABLATION_SETTINGS="" 158 | 159 | abl_str="" 160 | if [[ "$abl" == no_orient ]] ; then 161 | ABLATION_SETTINGS="model.kwargs.fpn_args.classes_grouping=None" 162 | multitask_w=off 163 | abl_str="ablation=$abl" 164 | elif [[ "$abl" == no_seg_head ]] ; then 165 | ABLATION_SETTINGS="model.kwargs.fpn_args.classes_grouping=None" 166 | multitask_w=off 167 | abl_str="ablation=$abl" 168 | elif [[ "$abl" == no_uncert_weight ]] ; then 169 | multitask_w=off 170 | abl_str="ablation=$abl" 171 | fi 172 | # run center model pre-trained on weakly-supervised 173 | SERVERS=${GPU_LIST[$((s % GPU_COUNT))]} ./run_distributed.sh python train.py --cfg_subname $cfg_subname --config \ 174 | eval_epoch=$epoch_eval \ 175 | ablation_str=$abl_str \ 176 | model.kwargs.backbone=$backbone \ 177 | train_settings.n_epochs=$epoch_train \ 178 | train_settings.multitask_weighting.name=$multitask_w \ 179 | $DISPLAY_ARGS $ABLATION_SETTINGS & 180 | 181 | s=$((s+1)) 182 | wait_or_interrupt $GPU_COUNT $s 183 | done 184 | done 185 | done 186 | done 187 | done 188 | done 189 | done 190 | done 191 | done 192 | done 193 | fi 194 | wait -------------------------------------------------------------------------------- /src/models/center_estimator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import numpy as np 5 | 6 | from models.center_augmentator import CenterAugmentator 7 | 8 | from models.localization.centers import Conv1dMultiscaleLocalization, Conv2dDilatedLocalization 9 | 10 | import torch.nn.functional as F 11 | 12 | class CenterEstimator(nn.Module): 13 | def __init__(self, args=dict(), is_learnable=True): 14 | super().__init__() 15 | 16 | self.return_backbone_only = False 17 | 18 | self.use_magnitude_as_mask = args.get('use_magnitude_as_mask') 19 | 20 | instance_center_estimator_op = Conv1dMultiscaleLocalization 21 | if args.get('use_dilated_nn'): 22 | from functools import partial 23 | instance_center_estimator_op = partial(Conv2dDilatedLocalization, 24 | **args.get('dilated_nn_args',{})) 25 | 26 | self.instance_center_estimator = instance_center_estimator_op( 27 | local_max_thr=args.get('local_max_thr', 0.1), 28 | mask_thr=args.get('mask_thr', 0.01), 29 | exclude_border_px=args.get('exclude_border_px', 5), 30 | learnable=is_learnable, 31 | allow_input_backprop=args.get('allow_input_backprop', True), 32 | backprop_only_positive=args.get('backprop_only_positive', True), 33 | apply_input_smoothing_for_local_max=args.get('apply_input_smoothing_for_local_max', 1), 34 | use_findcontours_for_local_max=args.get('use_findcontours_for_local_max', False), 35 | ) 36 | 37 | if args.get('augmentation'): 38 | self.center_augmentator = CenterAugmentator( 39 | **args.get('augmentation_kwargs'), 40 | ) 41 | else: 42 | self.center_augmentator = None 43 | 44 | scale_r = 1024 if 'scale_r' not in args else args['scale_r'] 45 | scale_r_gt = 1 if 'scale_r_gt' not in args else args['scale_r_gt'] 46 | 47 | self.scale_r_fn = lambda x: x * scale_r 48 | self.scale_r_gt_fn = lambda x: x * scale_r_gt 49 | 50 | self.inverse_scale_r_fn = lambda x: x / scale_r 51 | 52 | self.use_log_r = args['use_log_r'] if 'use_log_r' in args else True 53 | use_log_r_base = args['use_log_r_base'] if 'use_log_r_base' in args else 'exp' 54 | 55 | if use_log_r_base.lower() in ['exp', 'e']: 56 | self.log_r_fn = lambda x: torch.log(x+1) 57 | self.inverse_log_r_fn = lambda x: torch.exp(x)-1 58 | elif use_log_r_base.lower() in ['decimal', '10']: 59 | self.log_r_fn 
= lambda x: torch.log10(x+1) 60 | self.inverse_log_r_fn = lambda x: torch.pow(10, x)-1 61 | elif use_log_r_base.lower() in ['pow10']: 62 | self.log_r_fn = lambda x: torch.log10(x+1) 63 | self.inverse_log_r_fn = lambda x: torch.pow(x, 10)-1 64 | else: 65 | raise Exception('Only "exp" and "10" are allowed logarithms for R') 66 | 67 | self.MAX_NUM_CENTERS = 16*128 68 | 69 | def set_return_backbone_only(self, val): 70 | self.return_backbone_only = val 71 | 72 | def is_return_backbone_only(self): 73 | return self.return_backbone_only 74 | 75 | def init_output(self, num_vector_fields=1): 76 | self.num_vector_fields = num_vector_fields 77 | 78 | assert self.num_vector_fields >= 3 79 | self.instance_center_estimator.init_output() 80 | 81 | return input 82 | 83 | def forward(self, input, ignore_gt=False, **gt): 84 | if self.center_augmentator is not None: 85 | input = self.center_augmentator(input, **gt) 86 | 87 | ignore = gt.get('ignore') 88 | 89 | assert input.shape[1] >= self.num_vector_fields 90 | 91 | predictions = input[:, 0:self.num_vector_fields] 92 | 93 | S = predictions[:, 0].unsqueeze(1) 94 | C = predictions[:, 1].unsqueeze(1) 95 | R = predictions[:, 2].unsqueeze(1) 96 | 97 | R = self.inverse_log_r_fn(self.scale_r_fn(R)) 98 | 99 | cls_mask = torch.zeros_like(S, requires_grad=False) 100 | M = torch.zeros_like(S, requires_grad=False) 101 | 102 | pred_mask = None 103 | 104 | if self.training: 105 | # during training only detect centers but do not do any filtering 106 | center_pred, conv_resp = self.instance_center_estimator(C, S, R, M, cls_mask) 107 | 108 | else: 109 | # during inference detect centers and do additional filtering and scoring if requested 110 | 111 | # apply R adjustment for log directly to input and GT data only during testing 112 | if self.use_log_r and 'centerdir_groundtruth' in gt and not ignore_gt: 113 | input[:, 2:3] = R 114 | gt['centerdir_groundtruth'][0][:, 0] = self.scale_r_gt_fn(gt['centerdir_groundtruth'][0][:, 0]) 115 | 116 | mask = M if self.use_magnitude_as_mask else (cls_mask * (cls_mask > 0).type(torch.float32)) 117 | 118 | # use only ignore flag == 1 here 119 | res, conv_resp = self.instance_center_estimator(C, S, R, M, mask, 120 | ignore_region=ignore & 1 if ignore is not None else None) 121 | 122 | # need to apply relu to ignore raw negative values returned by net 123 | conv_resp = torch.relu(conv_resp) 124 | 125 | res = torch.cat((res, torch.ones((len(res),1),device=res.device)),dim=1) 126 | 127 | res = res.cpu().numpy() 128 | if len(res) > 0: 129 | idx = np.lexsort((res[:, 3],res[:, 4])) 130 | res = res[idx[::-1], :] 131 | 132 | # take only 2000 examples if too may 133 | if res.shape[0] > 2000: 134 | res = res[:2000, :] 135 | 136 | if len(res) > 0: 137 | selected_centers = np.ones(len(res),dtype=bool) 138 | for b in range(len(input)): 139 | batch_idx = res[:, 0] == b 140 | centers_b = res[batch_idx][:, [2, 1, 4]] 141 | 142 | if ignore is not None and len(res) > 0: 143 | # consider all ignore flags except DIFFICULT and padding (8==DIFFICULT; 64,128=PADDING) one which will be handled by evaluation 144 | ignored_pred = np.array([ignore.clone().cpu().numpy()[b, 0, int(r[0]), int(r[1])] & (255 - 8 - 64 - 128) == 0 for r in centers_b]).astype(bool) 145 | 146 | if not np.all(ignored_pred): 147 | selected_centers[batch_idx] *= ignored_pred 148 | 149 | 150 | center_pred = res[selected_centers, :] 151 | else: 152 | center_pred = res 153 | 154 | # res = (batch, x,y, mask_score, center_score) 155 | # voted_mask = 2D array of integers that match (1-based) index of 
centers in res 156 | # conv_resp_out = list of 2D array with various respones (conv2d for center response, voted_mask, M, etc) 157 | 158 | center_pred = torch.tensor(center_pred).to(input.device) 159 | 160 | # convert center prediction list to tensor of fixed size so that it can be merger from parallel GPU processings 161 | center_pred = self._pack_center_predictions(center_pred, batch_size=len(input)) 162 | 163 | return dict(output=input, center_pred=center_pred, center_heatmap=conv_resp) 164 | 165 | @staticmethod 166 | def _get_edges_to_area_score(res, voted_mask): 167 | 168 | # # calculate mask score based on number of edges to area ratio 169 | # #is_edge = cv2.filter2D(voted_mask.astype(np.float32), -1, np.ones((3, 3)) / 9.0) != voted_mask 170 | voted_mask = voted_mask.type(torch.float32) 171 | is_edge = F.conv2d(voted_mask.unsqueeze(0).unsqueeze(0), 172 | torch.ones((3, 3), dtype=torch.float32, device=voted_mask.device).reshape(1, 1, 3, 3) / 9.0, 173 | padding=1) 174 | is_edge = (torch.abs(is_edge - voted_mask) > 1 / 9.0).squeeze().type(torch.float32) 175 | is_edge_score = [(1 - is_edge[voted_mask == i + 1].sum() / (voted_mask == i + 1).sum()).item() for i, _ in 176 | enumerate(res)] 177 | return is_edge_score 178 | #is_edge_score = np.expand_dims(is_edge_score, axis=1) 179 | 180 | #return np.concatenate((res, is_edge_score), axis=1) 181 | 182 | def _pack_center_predictions(self, center_pred, batch_size): 183 | center_pred_all = torch.zeros((batch_size, self.MAX_NUM_CENTERS, center_pred.shape[1] if len(center_pred) > 0 else 5), 184 | dtype=torch.float, device=center_pred.device) 185 | if len(center_pred) > 0: 186 | for b in center_pred[:, 0].unique().long(): 187 | valid_centers_idx = torch.nonzero(center_pred[:, 0] == b.float()).squeeze(dim=1).long() 188 | 189 | if len(valid_centers_idx) > self.MAX_NUM_CENTERS: 190 | valid_centers_idx = valid_centers_idx[:self.MAX_NUM_CENTERS] 191 | print('WARNING: got more centers (%d) than allowed (%d) - removing last centers to meet criteria' % (len(valid_centers_idx), self.MAX_NUM_CENTERS)) 192 | 193 | center_pred_all[b, :len(valid_centers_idx), 0] = 1 194 | center_pred_all[b, :len(valid_centers_idx), 1:] = center_pred[valid_centers_idx, 1:] 195 | 196 | return center_pred_all 197 | -------------------------------------------------------------------------------- /src/infer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | import os, time 3 | 4 | from matplotlib import pyplot as plt 5 | from tqdm import tqdm 6 | 7 | import numpy as np 8 | import scipy 9 | 10 | import torch 11 | 12 | from datasets import get_centerdir_dataset 13 | from models import get_model, get_center_model 14 | from utils.utils import variable_len_collate 15 | 16 | class Inferencce: 17 | def __init__(self, args): 18 | # if args['display'] and not args.get('display_to_file_only'): 19 | if True: 20 | # plt.switch_backend('TkAgg') 21 | plt.ion() 22 | else: 23 | plt.ioff() 24 | plt.switch_backend("agg") 25 | 26 | if args.get('cudnn_benchmark'): 27 | torch.backends.cudnn.benchmark = True 28 | 29 | self.args = args 30 | 31 | # set device 32 | self.device = torch.device("cuda:0" if args['cuda'] else "cpu") 33 | 34 | def initialize(self): 35 | args = self.args 36 | 37 | ################################################################################################### 38 | # set dataset and model 39 | self.dataset_it, self.model, self.center_model = self._construct_dataset_and_processing(args, self.device) 40 | 41 | 42 | 
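# NOTE: _construct_dataset_and_processing below builds the dataloader with ground truth
# disabled (no_groundtruth=True) and a transform that only pads the image to a multiple
# of 32 and converts it to a tensor, so infer.py is intended for raw forward-pass/timing
# runs rather than evaluation against annotations.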
#@classmethod 43 | def _construct_dataset_and_processing(self, args, device): 44 | 45 | ################################################################################################### 46 | # dataloader 47 | dataset_workers = args['dataset']['workers'] if 'workers' in args['dataset'] else 0 48 | dataset_batch = args['dataset']['batch_size'] if 'batch_size' in args['dataset'] else 1 49 | 50 | from utils import transforms as my_transforms 51 | args['dataset']['kwargs']['transform'] = my_transforms.get_transform([ 52 | { 'name': 'Padding', 'opts': { 'keys': ('image',), 'pad_to_size_factor': 32 } }, 53 | { 'name': 'ToTensor', 'opts': { 'keys': ('image',), 'type': (torch.FloatTensor) } }, 54 | ]) 55 | 56 | dataset, _ = get_centerdir_dataset(args['dataset']['name'], args['dataset']['kwargs'], no_groundtruth=True) 57 | 58 | dataset_it = torch.utils.data.DataLoader(dataset, batch_size=dataset_batch, shuffle=False, drop_last=False, 59 | num_workers=dataset_workers, pin_memory=True if args['cuda'] else False, 60 | collate_fn=variable_len_collate) 61 | 62 | ################################################################################################### 63 | # load model 64 | model = get_model(args['model']['name'], args['model']['kwargs']) 65 | model.init_output(args['num_vector_fields']) 66 | model = torch.nn.DataParallel(model).to(device) 67 | 68 | # prepare center_model and center_estimator based on number of center_checkpoint_path that will need to be processed 69 | 70 | center_checkpoint_name = args.get('center_checkpoint_name') if 'center_checkpoint_name' in args else '' 71 | center_checkpoint_path = args.get('center_checkpoint_path') 72 | 73 | center_model = get_center_model(args['center_model']['name'], args['center_model']['kwargs'], 74 | is_learnable=args['center_model'].get('use_learnable_center_estimation'), 75 | use_fast_estimator=True) 76 | 77 | center_model.init_output(args['num_vector_fields']) 78 | center_model = torch.nn.DataParallel(center_model).to(device) 79 | 80 | ################################################################################################### 81 | # load snapshot 82 | if os.path.exists(args['checkpoint_path']): 83 | print('Loading from "%s"' % args['checkpoint_path']) 84 | state = torch.load(args['checkpoint_path']) 85 | if 'model_state_dict' in state: model.load_state_dict(state['model_state_dict'], strict=True) 86 | if not args.get('center_checkpoint_path') and 'center_model_state_dict' in state and args['center_model'].get('use_learnable_center_estimation'): 87 | center_model.load_state_dict(state['center_model_state_dict'], strict=False) 88 | else: 89 | raise Exception('checkpoint_path {} does not exist!'.format(args['checkpoint_path'])) 90 | 91 | if args['center_model'].get('use_learnable_center_estimation') and len(center_checkpoint_name) > 0: 92 | if os.path.exists(center_checkpoint_path): 93 | print('Loading center model from "%s"' % center_checkpoint_path) 94 | state = torch.load(center_checkpoint_path) 95 | if 'center_model_state_dict' in state: 96 | if 'module.instance_center_estimator.conv_start.0.weight' in state['center_model_state_dict']: 97 | checkpoint_input_weights = state['center_model_state_dict']['module.instance_center_estimator.conv_start.0.weight'] 98 | center_input_weights = center_model.module.instance_center_estimator.conv_start[0].weight 99 | if checkpoint_input_weights.shape != center_input_weights.shape: 100 | state['center_model_state_dict']['module.instance_center_estimator.conv_start.0.weight'] = 
checkpoint_input_weights[:,:2,:,:] 101 | 102 | print('WARNING: #####################################################################################################') 103 | print('WARNING: center input shape mismatch - will load weights for only the first two channels, is this correct ?!!!') 104 | print('WARNING: #####################################################################################################') 105 | 106 | center_model.load_state_dict(state['center_model_state_dict'], strict=False) 107 | else: 108 | raise Exception('checkpoint_path {} does not exist!'.format(center_checkpoint_path)) 109 | 110 | return dataset_it, model, center_model 111 | 112 | ######################################################################################################### 113 | ## MAIN RUN FUNCTION 114 | def run(self): 115 | args = self.args 116 | 117 | time_array = dict(model=[],center=[],post=[],total=[]) 118 | with torch.no_grad(): 119 | model = self.model 120 | center_model = self.center_model 121 | dataset_it = self.dataset_it 122 | 123 | assert dataset_it.batch_size == 1 124 | 125 | model.eval() 126 | center_model.eval() 127 | 128 | im_image = 0 129 | while im_image < 1000: 130 | 131 | for sample in self.dataset_it: 132 | im_image += 1 133 | 134 | torch.cuda.synchronize() 135 | start_model = time.time() 136 | # run main model 137 | output_batch_ = model(sample['image']) 138 | 139 | torch.cuda.synchronize() 140 | start_center = time.time() 141 | 142 | # run center detection model 143 | center_pred, times = center_model(output_batch_) 144 | 145 | #predictions = center_pred[0] 146 | 147 | # make sure data is copied 148 | torch.cuda.synchronize() 149 | end = time.time() 150 | 151 | time_model = start_center - start_model 152 | time_center_total = end - start_center 153 | time_center_preprocess = times[0] 154 | time_center_only = times[1] 155 | time_center_postprocess = times[2] 156 | time_total = end - start_model 157 | 158 | time_array['model'].append(time_model) 159 | time_array['center'].append(time_center_only+time_center_preprocess) 160 | time_array['post'].append(time_center_postprocess) 161 | time_array['total'].append(time_total) 162 | 163 | 164 | print('time total: %.1f ms with model=%.1f ms and center=%.1f ms (pre=%.1f ms, cent=%.1f ms, post=%.1f ms)' % 165 | (time_total*1000, time_model*1000, time_center_total*1000, 166 | time_center_preprocess*1000, time_center_only*1000, time_center_postprocess*1000,)) 167 | 168 | if im_image > 1000: 169 | break 170 | 171 | times_model = np.array(time_array['model'])[100::100] 172 | times_center = np.array(time_array['center'])[100::100] 173 | times_post = np.array(time_array['post'])[100::100] 174 | times_total = np.array(time_array['total'])[100::100] 175 | print('-------------------------------------------------------------') 176 | print('TIMES:') 177 | print('model: avg %.1f ms, std %.1f ms' % (times_model.mean()*1000,times_model.std()*1000 )) 178 | print('center: avg %.1f ms, std %.1f ms' % (times_center.mean() * 1000, times_center.std() * 1000)) 179 | print('post: avg %.1f ms, std %.1f ms' % (times_post.mean() * 1000, times_post.std() * 1000)) 180 | print('total: avg %.1f ms, std %.1f ms' % (times_total.mean() * 1000, times_total.std() * 1000)) 181 | 182 | def main(): 183 | from config import get_config_args 184 | 185 | args = get_config_args(dataset=os.environ.get('DATASET'), type='test') 186 | 187 | infer = Inferencce(args) 188 | infer.initialize() 189 | infer.run() 190 | 191 | if __name__ == "__main__": 192 | main() 
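A minimal sketch of the configuration dict that infer.py expects from get_config_args(), inferred from the keys read in the code above; every value is a placeholder and this snippet is not a file in the repository:

# hypothetical example configuration for infer.py; adjust names and paths to a real experiment
args = dict(
    cuda=True,
    cudnn_benchmark=True,
    num_vector_fields=6,                            # placeholder; must match the trained model
    checkpoint_path='/path/to/checkpoint.pth',      # main model weights (required)
    center_checkpoint_name='',                      # leave empty to take center weights from checkpoint_path
    center_checkpoint_path=None,
    dataset=dict(name='vicos_towel',
                 kwargs=dict(),                     # dataset-specific options (e.g. root_dir) go here
                 batch_size=1, workers=4),          # run() asserts batch_size == 1
    model=dict(name='fpn', kwargs=dict()),
    center_model=dict(name='CenterEstimator', kwargs=dict(),
                      use_learnable_center_estimation=True),
)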
-------------------------------------------------------------------------------- /src/criterions/orientation_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from models.center_groundtruth import CenterDirGroundtruth 5 | 6 | from criterions.weightings.instance_weight import InstanceGroupWeighting 7 | from criterions.center_direction_loss import CenterDirectionLoss 8 | 9 | from criterions.per_pixel_losses import get_per_pixel_loss_func 10 | 11 | class OrientationLoss(nn.Module): 12 | def __init__(self, model, orientation_args=dict(), **center_dir_args): 13 | super().__init__() 14 | 15 | self.enable_orientation_loss = orientation_args.get('enable') 16 | self.regress_confidence_score = orientation_args.get('regress_confidence_score') 17 | self.no_instance_loss = orientation_args.get('no_instance_loss') 18 | self.enable_6dof = orientation_args.get('enable_6dof') 19 | self.loss_weighting = orientation_args.get('loss_weighting') 20 | 21 | self.centerdir_loss_op = CenterDirectionLoss(model, **center_dir_args) 22 | 23 | self.use_log_r = self.centerdir_loss_op.use_log_r 24 | self.log_r_fn = self.centerdir_loss_op.log_r_fn 25 | self.num_vector_fields = self.centerdir_loss_op.num_vector_fields 26 | 27 | REQUIRED_VECTOR_FIELDS = 5 28 | if self.enable_6dof: 29 | REQUIRED_VECTOR_FIELDS += 4 30 | if self.regress_confidence_score: 31 | REQUIRED_VECTOR_FIELDS += 1 32 | 33 | assert self.num_vector_fields >= REQUIRED_VECTOR_FIELDS 34 | 35 | self.regression_loss_fn = get_per_pixel_loss_func(orientation_args.get('regression_loss')) 36 | 37 | self.orientation_weighting = InstanceGroupWeighting(border_weight=center_dir_args.get('border_weight',1.0), 38 | border_weight_px=center_dir_args.get('border_weight_px',0), 39 | add_distance_gauss_weight=False) 40 | self.tmp = nn.Conv2d(8,8,3) 41 | 42 | def forward(self, prediction, sample, centerdir_responses=None, centerdir_gt=None, ignore_mask=None, 43 | difficult_mask=None, w_orientation=1, w_fg_orientation=1, w_bg_orientation=1, w_confidence_score=1, reduction_dims=(1, 2, 3), **kwargs): 44 | 45 | loss_output_shape = [d for i, d in enumerate(prediction.shape) if i not in reduction_dims] 46 | loss_zero_init = lambda: torch.zeros(size=loss_output_shape, device=prediction.device) 47 | 48 | loss_sin_orientation, loss_cos_orientation, loss_confidence_score = map(torch.clone, [loss_zero_init()] * 3) 49 | 50 | instances = sample["instance"] 51 | instances = instances.squeeze(1) 52 | 53 | # batch computation --- 54 | labels = sample["label"] 55 | bg_mask = labels == 0 56 | fg_mask = bg_mask == False 57 | 58 | centerdir_vectors = prediction[:, 0:self.num_vector_fields] 59 | 60 | # WARNING: this assumes CenterDirectionLoss is used as parent (not compatible with other) 61 | if self.enable_6dof: 62 | prediction_sin_orientation = centerdir_vectors[:, 3:6].unsqueeze(2) 63 | prediction_cos_orientation = centerdir_vectors[:, 6:9].unsqueeze(2) 64 | else: 65 | prediction_sin_orientation = centerdir_vectors[:, 3:4].unsqueeze(2) 66 | prediction_cos_orientation = centerdir_vectors[:, 4:5].unsqueeze(2) 67 | 68 | 69 | if instances.dtype != torch.int16: 70 | instances = instances.type(torch.int16) 71 | 72 | # mark ignore regions as -9999 in instances so that size can be correctly calculated in InstanceGroupWeighting 73 | if ignore_mask is not None: 74 | instances = instances.clone() # do not destroy original 75 | instances[ignore_mask.squeeze(dim=1) == 1] = InstanceGroupWeighting.IGNORE_FLAG 76 | 77 | 
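# NOTE: the head regresses a (sin, cos) pair per rotation axis rather than the raw angle;
# at inference the angle itself would typically be recovered as torch.atan2(pred_sin, pred_cos),
# and regressing (sin, cos) keeps the per-pixel regression loss below free of the 2*pi
# wrap-around discontinuity that a raw-angle target would have.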
# retrieve groundtruth values 78 | key = ['gt_orientation_sin', 'gt_orientation_cos'] 79 | gt_sin_orientation, gt_cos_orientation = CenterDirGroundtruth.parse_groundtruth_map(centerdir_gt,keys=key) 80 | 81 | assert prediction_sin_orientation.shape[1] == prediction_cos_orientation.shape[1] 82 | 83 | ROT_DIM = prediction_sin_orientation.shape[1] 84 | 85 | assert ROT_DIM <= gt_sin_orientation.shape[1] 86 | assert ROT_DIM <= gt_cos_orientation.shape[1] 87 | 88 | gt_sin_orientation = gt_sin_orientation[:, :ROT_DIM] 89 | gt_cos_orientation = gt_cos_orientation[:, :ROT_DIM] 90 | 91 | assert prediction_sin_orientation.shape[1] == gt_sin_orientation.shape[1] 92 | assert prediction_cos_orientation.shape[1] == gt_cos_orientation.shape[1] 93 | 94 | if self.regress_confidence_score: 95 | gt_confidence_score = CenterDirGroundtruth.parse_groundtruth_map(centerdir_gt,keys=['gt_confidence_score']) 96 | 97 | prediction_confidence_score = centerdir_vectors[:, 9:10] if self.enable_6dof else centerdir_vectors[:, 5:6] 98 | prediction_confidence_score = prediction_confidence_score.unsqueeze(2) 99 | 100 | 101 | # prepare all arguments that are needed for calculating weighting mask 102 | weighting_args = dict(gt_instances=instances, gt_ignore=ignore_mask, gt_difficult=difficult_mask, 103 | w_fg=w_fg_orientation, w_bg=w_bg_orientation) 104 | 105 | ###################################################### 106 | ### centerdir_vectors losses (cos, sin) 107 | loss_sin_orientation, loss_cos_orientation = zip(*[(loss_sin_orientation.clone(), loss_cos_orientation.clone()) for _ in range(ROT_DIM)]) 108 | 109 | if self.enable_orientation_loss: 110 | 111 | with torch.no_grad(): 112 | mask_weights = self.orientation_weighting(**weighting_args) 113 | 114 | # add regression loss for sin(orientation), cos(orientation) 115 | if self.no_instance_loss: 116 | 117 | loss_sin_orientation += torch.sum( 118 | mask_weights * self.regression_loss_fn(prediction_sin_orientation, gt_sin_orientation), 119 | dim=reduction_dims) 120 | loss_cos_orientation += torch.sum( 121 | mask_weights * self.regression_loss_fn(prediction_cos_orientation, gt_cos_orientation), 122 | dim=reduction_dims) 123 | 124 | if self.regress_confidence_score: 125 | loss_confidence_score += torch.sum( 126 | mask_weights * self.regression_loss_fn(prediction_confidence_score, gt_confidence_score), 127 | dim=reduction_dims) 128 | 129 | else: 130 | for b in range(mask_weights.shape[0]): 131 | fg_mask_weights = mask_weights[b][fg_mask[b]] 132 | for rot_dim in range(ROT_DIM): 133 | loss_sin_orientation[rot_dim][b] += torch.sum( 134 | fg_mask_weights * self.regression_loss_fn(prediction_sin_orientation[b][rot_dim,fg_mask[b]], 135 | gt_sin_orientation[b][rot_dim,fg_mask[b]])) 136 | loss_cos_orientation[rot_dim][b] += torch.sum( 137 | fg_mask_weights * self.regression_loss_fn(prediction_cos_orientation[b][rot_dim,fg_mask[b]], 138 | gt_cos_orientation[b][rot_dim,fg_mask[b]])) 139 | if self.regress_confidence_score: 140 | loss_confidence_score[b] += torch.sum( 141 | fg_mask_weights * self.regression_loss_fn(prediction_confidence_score[b][0,fg_mask[b]], 142 | gt_confidence_score[b][0,fg_mask[b]])) 143 | 144 | loss_sin_orientation = [w_orientation*l for l in loss_sin_orientation] 145 | loss_cos_orientation = [w_orientation*l for l in loss_cos_orientation] 146 | 147 | loss_confidence_score = [w_confidence_score*l for l in loss_confidence_score] 148 | 149 | loss_orientation = sum(loss_sin_orientation) + sum(loss_cos_orientation) + sum(loss_confidence_score) 150 | loss_orientation += 
prediction.sum() * 0 151 | 152 | # call base loss function for center direction 153 | all_centerdir_losses = self.centerdir_loss_op.forward(prediction, sample, centerdir_responses, 154 | centerdir_gt, ignore_mask, difficult_mask, 155 | reduction_dims=reduction_dims, **kwargs) 156 | 157 | losses_main = [all_centerdir_losses[0] + loss_orientation] 158 | 159 | return tuple(losses_main + list(all_centerdir_losses[1:]) + [loss_orientation] + list(loss_sin_orientation) + list(loss_cos_orientation)) 160 | 161 | 162 | def get_loss_dict(self, loss_tensor): 163 | 164 | loss, loss_cls, loss_centerdir_total, loss_centers, loss_sin, \ 165 | loss_cos, loss_r, loss_magnitude_reg, loss_orientation_total = [l.sum() for l in loss_tensor[:9]] 166 | 167 | loss_orientation = [l.sum() for l in loss_tensor[9:]] 168 | loss_orientation_sin = loss_orientation[:len(loss_orientation) // 2] # first half is sin 169 | loss_orientation_cos = loss_orientation[len(loss_orientation) // 2:] # second half is cos 170 | orientation_dims = ['rot_y','rot_z','rot_x'] 171 | 172 | return dict( # main loss for backprop: 173 | loss=loss, 174 | # losses for visualization: 175 | losses_groups=dict(cls=loss_cls, centerdir_total=loss_centerdir_total, centers=loss_centers, orientation_total=loss_orientation_total), 176 | losses_centerdir_total=dict(sin=loss_sin, cos=loss_cos, r=loss_r, magnitude_reg=loss_magnitude_reg), 177 | losses_orientation_total={**{f'{orientation_dims[i]}_sin':l for i,l in enumerate(loss_orientation_sin)}, 178 | **{f'{orientation_dims[i]}_cos':l for i,l in enumerate(loss_orientation_cos)}}, 179 | losses_main=dict(cls=loss_cls, sin=loss_sin, cos=loss_cos, r=loss_r, cent=loss_centers, ), 180 | # losses for task weighting: 181 | losses_tasks=dict(centerdir=loss_centerdir_total, 182 | **{f'orientation_{orientation_dims[i]}':(l_sin+l_cos) 183 | for i,(l_sin,l_cos) in enumerate(zip(loss_orientation_sin,loss_orientation_cos))}), 184 | ) 185 | --------------------------------------------------------------------------------
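A minimal usage sketch of OrientationLoss above (an illustration under assumed names: criterion, prediction, sample, centerdir_gt and ignore_mask are placeholders created by the training code, not objects defined in this file):

# criterion = OrientationLoss(model, orientation_args=dict(enable=True), ...)
loss_tensor = criterion(prediction, sample, centerdir_gt=centerdir_gt, ignore_mask=ignore_mask)
loss_dict = criterion.get_loss_dict(loss_tensor)    # tuple of loss tensors -> named scalars
loss_dict['loss'].backward()                         # marked above as the main loss for backprop
per_task = loss_dict['losses_tasks']                 # one scalar per task (center directions plus each
                                                     # orientation axis), the granularity a multitask
                                                     # weighting scheme can operate on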