├── PmNet.yml ├── README.md ├── assets ├── Pipeline.png └── photo ├── datasets ├── __pycache__ │ ├── args.cpython-311.pyc │ ├── functional.cpython-310.pyc │ └── functional.cpython-311.pyc ├── args.py ├── convert_results │ ├── convert_autolaparo.py │ └── convert_cholec80.py ├── data_preprosses │ ├── generate_labels_autolaparo.py │ ├── generate_labels_ch80.py │ ├── generate_labels_lungseg.py │ └── generate_labels_pmlr.py ├── extract_frames │ ├── extract_frames_autolaparo.py │ └── extract_frames_ch80.py ├── functional.py ├── phase │ ├── AutoLaparo_phase.py │ ├── Cholec80_phase.py │ ├── PmLR50_phase.py │ └── __pycache__ │ │ ├── AutoLaparo_phase.cpython-310.pyc │ │ ├── AutoLaparo_phase.cpython-311.pyc │ │ ├── Cholec80_phase.cpython-310.pyc │ │ ├── Cholec80_phase.cpython-311.pyc │ │ └── PmLR50_phase.cpython-311.pyc ├── tools │ ├── frame_cutmargin.py │ ├── resize_frame.py │ ├── transfer_csv_txt.py │ └── transfer_json_txt.py └── transforms │ ├── __pycache__ │ ├── mixup.cpython-310.pyc │ ├── mixup.cpython-311.pyc │ ├── optim_factory.cpython-310.pyc │ ├── optim_factory.cpython-311.pyc │ ├── rand_augment.cpython-310.pyc │ ├── rand_augment.cpython-311.pyc │ ├── random_erasing.cpython-310.pyc │ ├── random_erasing.cpython-311.pyc │ ├── surg_transforms.cpython-310.pyc │ ├── surg_transforms.cpython-311.pyc │ ├── video_transforms.cpython-310.pyc │ ├── video_transforms.cpython-311.pyc │ ├── volume_transforms.cpython-310.pyc │ └── volume_transforms.cpython-311.pyc │ ├── image_transforms.py │ ├── mixup.py │ ├── optim_factory.py │ ├── rand_augment.py │ ├── random_erasing.py │ ├── surg_transforms.py │ ├── transforms.py │ ├── video_transforms.py │ └── volume_transforms.py ├── downstream_phase ├── __pycache__ │ ├── datasets_phase.cpython-310.pyc │ ├── datasets_phase.cpython-311.pyc │ ├── engine_for_phase.cpython-310.pyc │ └── engine_for_phase.cpython-311.pyc ├── datasets_phase.py └── engine_for_phase.py ├── evaluation_matlab ├── Evaluate.m ├── Evaluate_Cataract101.m ├── Evaluate_m2cai.m ├── Main.m ├── Main_AutoLaparo.m ├── Main_Cataract101.m ├── Main_m2cai.m ├── README.md ├── ReadPhaseLabel.m ├── matlab.mat ├── octave-workspace └── test.m ├── model ├── __pycache__ │ ├── mambapy.cpython-311.pyc │ ├── pmnet.cpython-311.pyc │ └── pscan.cpython-311.pyc ├── mambapy.py ├── pmnet.py └── pscan.py ├── run_phase_training.py ├── scripts ├── test_phase.sh └── train_phase.sh ├── test.sh ├── train.sh └── utils.py /PmNet.yml: -------------------------------------------------------------------------------- 1 | name: Pmnet 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - _openmp_mutex=5.1=1_gnu 9 | - beautifulsoup4=4.12.2=py38h06a4308_0 10 | - blas=1.0=mkl 11 | - boltons=23.0.0=py38h06a4308_0 12 | - brotlipy=0.7.0=py38h27cfd23_1003 13 | - bzip2=1.0.8=h7f98852_4 14 | - ca-certificates=2023.08.22=h06a4308_0 15 | - certifi=2023.7.22=py38h06a4308_0 16 | - cffi=1.15.1=py38h5eee18b_3 17 | - chardet=4.0.0=py38h06a4308_1003 18 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 19 | - click=8.0.4=py38h06a4308_0 20 | - conda=23.7.4=py38h06a4308_0 21 | - conda-build=3.25.0=py38h06a4308_0 22 | - conda-index=0.2.3=py38h06a4308_0 23 | - conda-package-handling=2.2.0=py38h06a4308_0 24 | - conda-package-streaming=0.9.0=py38h06a4308_0 25 | - cryptography=41.0.3=py38h130f0dd_0 26 | - cudatoolkit=11.1.1=ha002fc5_10 27 | - ffmpeg=4.3=hf484d3e_0 28 | - filelock=3.9.0=py38h06a4308_0 29 | - freetype=2.10.4=h0708190_1 30 | - glob2=0.7=pyhd3eb1b0_0 31 | - gmp=6.2.1=h58526e2_0 32 | - 
gnutls=3.6.13=h85f3911_1 33 | - icu=58.2=he6710b0_3 34 | - idna=3.4=py38h06a4308_0 35 | - intel-openmp=2021.4.0=h06a4308_3561 36 | - jinja2=3.1.2=py38h06a4308_0 37 | - jpeg=9b=h024ee3a_2 38 | - jsonpatch=1.32=pyhd3eb1b0_0 39 | - jsonpointer=2.1=pyhd3eb1b0_0 40 | - lame=3.100=h7f98852_1001 41 | - ld_impl_linux-64=2.38=h1181459_1 42 | - libarchive=3.4.2=h62408e4_0 43 | - libffi=3.4.4=h6a678d5_0 44 | - libgcc-ng=11.2.0=h1234567_1 45 | - libgomp=11.2.0=h1234567_1 46 | - libiconv=1.17=h166bdaf_0 47 | - liblief=0.12.3=h6a678d5_0 48 | - libpng=1.6.37=h21135ba_2 49 | - libstdcxx-ng=11.2.0=h1234567_1 50 | - libtiff=4.1.0=h2733197_1 51 | - libuv=1.43.0=h7f98852_0 52 | - libxml2=2.9.14=h74e7548_0 53 | - lz4-c=1.9.3=h9c3ff4c_1 54 | - markupsafe=2.1.1=py38h7f8727e_0 55 | - mkl=2021.4.0=h06a4308_640 56 | - mkl-service=2.4.0=py38h95df7f1_0 57 | - mkl_fft=1.3.1=py38h8666266_1 58 | - mkl_random=1.2.2=py38h1abd341_0 59 | - more-itertools=8.12.0=pyhd3eb1b0_0 60 | - ncurses=6.4=h6a678d5_0 61 | - nettle=3.6=he412f7d_0 62 | - numpy=1.24.3=py38h14f4228_0 63 | - numpy-base=1.24.3=py38h31eccc5_0 64 | - olefile=0.46=pyh9f0ad1d_1 65 | - openh264=2.1.1=h780b84a_0 66 | - openssl=1.1.1w=h7f8727e_0 67 | - packaging=23.1=py38h06a4308_0 68 | - patch=2.7.6=h7b6447c_1001 69 | - patchelf=0.17.2=h6a678d5_0 70 | - pip=23.2.1=py38h06a4308_0 71 | - pkginfo=1.9.6=py38h06a4308_0 72 | - pluggy=1.0.0=py38h06a4308_1 73 | - psutil=5.9.0=py38h5eee18b_0 74 | - py-lief=0.12.3=py38h6a678d5_0 75 | - pycosat=0.6.4=py38h5eee18b_0 76 | - pycparser=2.21=pyhd3eb1b0_0 77 | - pyopenssl=23.2.0=py38h06a4308_0 78 | - pysocks=1.7.1=py38h06a4308_0 79 | - python=3.8.18=h7a1cb2a_0 80 | - python-libarchive-c=2.9=pyhd3eb1b0_1 81 | - python_abi=3.8=2_cp38 82 | - pytz=2022.7=py38h06a4308_0 83 | - pyyaml=6.0=py38h5eee18b_1 84 | - readline=8.2=h5eee18b_0 85 | - requests=2.31.0=py38h06a4308_0 86 | - ruamel.yaml=0.17.21=py38h5eee18b_0 87 | - ruamel.yaml.clib=0.2.6=py38h5eee18b_1 88 | - setuptools=68.0.0=py38h06a4308_0 89 | - six=1.16.0=pyh6c4a22f_0 90 | - soupsieve=2.4=py38h06a4308_0 91 | - sqlite=3.41.2=h5eee18b_0 92 | - tk=8.6.12=h1ccaba5_0 93 | - tomli=2.0.1=py38h06a4308_0 94 | - toolz=0.12.0=py38h06a4308_0 95 | - tqdm=4.65.0=py38hb070fc8_0 96 | - urllib3=1.26.16=py38h06a4308_0 97 | - wheel=0.38.4=py38h06a4308_0 98 | - xz=5.4.2=h5eee18b_0 99 | - yaml=0.2.5=h7b6447c_0 100 | - zlib=1.2.13=h5eee18b_0 101 | - zstandard=0.19.0=py38h5eee18b_0 102 | - zstd=1.4.9=ha95c52a_0 103 | - pip: 104 | - absl-py==2.0.0 105 | - appdirs==1.4.4 106 | - astunparse==1.6.3 107 | - cachetools==5.3.1 108 | - contourpy==1.1.1 109 | - cycler==0.11.0 110 | - decord==0.6.0 111 | - defusedxml==0.7.1 112 | - docker-pycreds==0.4.0 113 | - einops==0.6.1 114 | - flatbuffers==23.5.26 115 | - fonttools==4.42.1 116 | - fsspec==2023.9.1 117 | - ftfy==6.1.1 118 | - gast==0.4.0 119 | - gitdb==4.0.11 120 | - gitpython==3.1.40 121 | - glances==3.4.0.3 122 | - google-auth==2.23.2 123 | - google-auth-oauthlib==1.0.0 124 | - google-pasta==0.2.0 125 | - grpcio==1.59.0 126 | - h5py==3.10.0 127 | - hjson==3.1.0 128 | - huggingface-hub==0.17.1 129 | - imageio==2.31.3 130 | - imgaug==0.4.0 131 | - imgcat==0.5.0 132 | - importlib-metadata==6.8.0 133 | - importlib-resources==6.0.1 134 | - joblib==1.3.2 135 | - keras==2.13.1 136 | - kiwisolver==1.4.5 137 | - lazy-loader==0.3 138 | - libclang==16.0.6 139 | - markdown==3.4.4 140 | - matplotlib==3.7.3 141 | - networkx==3.1 142 | - ninja==1.11.1 143 | - nltk==3.7 144 | - nvidia-ml-py==12.535.108 145 | - nvitop==1.3.0 146 | - oauthlib==3.2.2 147 | - 
openai-clip==1.0.1 148 | - opencv-python==4.8.0.76 149 | - opt-einsum==3.3.0 150 | - pandas==2.0.3 151 | - pillow==10.0.1 152 | - protobuf==4.24.3 153 | - pyasn1==0.5.0 154 | - pyasn1-modules==0.3.0 155 | - pyparsing==3.1.1 156 | - python-dateutil==2.8.2 157 | - pywavelets==1.4.1 158 | - regex==2023.10.3 159 | - requests-oauthlib==1.3.1 160 | - rsa==4.9 161 | - sacremoses==0.1.1 162 | - safetensors==0.3.3 163 | - scikit-image==0.21.0 164 | - scikit-learn==1.3.0 165 | - scipy==1.10.1 166 | - sentry-sdk==1.34.0 167 | - seqeval==1.2.2 168 | - setproctitle==1.3.3 169 | - shapely==2.0.1 170 | - smmap==5.0.1 171 | - tensorboard==2.13.0 172 | - tensorboard-data-server==0.7.1 173 | - tensorboardx==2.6.2.2 174 | - threadpoolctl==3.2.0 175 | - tifffile==2023.7.10 176 | - timm==0.4.12 177 | - tokenizers==0.14.1 178 | - transformers==4.16.2 179 | - triton==2.1.0 180 | - typing-extensions==4.6.1 181 | - tzdata==2023.3 182 | - ujson==5.8.0 183 | - wandb==0.16.0 184 | - wcwidth==0.2.8 185 | - werkzeug==3.0.0 186 | - wrapt==1.16.0 187 | - zipp==3.16.2 188 | - deepspeed==0.16.4 189 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [AAAI 2025] Surgical Workflow Recognition and Blocking Effectiveness Detection in Laparoscopic Liver Resections with Pringle Maneuver 2 | 3 |
4 | <img src="assets/Pipeline.png"> 5 |
6 | 7 | [![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](https://opensource.org/licenses/MIT) 8 | 9 | ## Abstract 10 | 11 | > Pringle maneuver (PM) in laparoscopic liver resection aims to reduce blood loss and provide a clear surgical view by intermittently blocking blood inflow of the liver, whereas prolonged PM may cause ischemic injury. To comprehensively monitor this surgical procedure and provide timely warnings of ineffective and prolonged blocking, we suggest two complementary AI-assisted surgical monitoring tasks: workflow recognition and blocking effectiveness detection in liver resections. The former presents challenges in real-time capturing of short-term PM, while the latter involves the intraoperative discrimination of long-term liver ischemia states. To address these challenges, we meticulously collect a novel dataset, called PmLR50, consisting of 25,037 video frames covering various surgical phases from 50 laparoscopic liver resection procedures. Additionally, we develop an online baseline for PmLR50, termed PmNet. This model embraces Masked Temporal Encoding (MTE) and Compressed Sequence Modeling (CSM) for efficient short-term and long-term temporal information modeling, and embeds Contrastive Prototype Separation (CPS) to enhance action discrimination between similar intraoperative operations. Experimental results demonstrate that PmNet outperforms existing state-of-the-art surgical workflow recognition methods on the PmLR50 benchmark. Our research offers potential clinical applications for the laparoscopic liver surgery community. 12 | 13 | 14 | ## 🔥🔥🔥 News!! 15 | * Dec 12, 2024: 🤗 Our work has been accepted by AAAI 2025! Congratulations! 16 | * Feb 21, 2025: 🚀 Code and dataset have been released! [Dataset Link](https://docs.google.com/forms/d/e/1FAIpQLSf33G5mdwXeqwabfbXnEboMpj48iCNlQBAY_up4kLuZiqCPUQ/viewform?usp=dialog) 17 | 18 | 19 | ## PmLR50 Dataset and PmNet 20 | ### Installation 21 | * Environment: CUDA 12.5 / Python 3.8 22 | * Device: Two NVIDIA GeForce RTX 4090 GPUs 23 | * Create a virtual environment 24 | ```shell 25 | git clone https://github.com/RascalGdd/PmNet.git 26 | cd PmNet 27 | conda env create -f PmNet.yml 28 | conda activate Pmnet 29 | ``` 30 | ### Prepare your data 31 | Download the processed data from [PmLR50](https://docs.google.com/forms/d/e/1FAIpQLSf33G5mdwXeqwabfbXnEboMpj48iCNlQBAY_up4kLuZiqCPUQ/viewform?usp=dialog). 32 | The final dataset structure should be as follows: 33 | 34 | ```bash 35 | data/ 36 | └──PmLR50/ 37 | └──frames/ 38 | └──01 39 | ├──00000000.jpg 40 | ├──00000001.jpg 41 | └──... 42 | ├──... 43 | └──50 44 | └──phase_annotations/ 45 | └──01.txt 46 | ├──02.txt 47 | ├──... 48 | └──50.txt 49 | └──blocking_annotations/ 50 | └──01.txt 51 | ├──02.txt 52 | ├──... 53 | └──50.txt 54 | └──bbox_annotations/ 55 | └──01.json 56 | ├──02.json 57 | ├──... 58 | └──50.json 59 | ``` 60 | Then, process the data with [generate_labels_pmlr.py](https://github.com/RascalGdd/PmNet/blob/main/datasets/data_preprosses/generate_labels_pmlr.py) to generate labels for training and testing. 61 | 62 | ### Training 63 | We provide scripts for training ([train.sh](https://github.com/RascalGdd/PmNet/blob/main/train.sh)) and testing ([test.sh](https://github.com/RascalGdd/PmNet/blob/main/test.sh)).
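For orientation, `train.sh` wraps a distributed launch of `run_phase_training.py`. The block below is a hypothetical sketch of such a launch: the flag names come from the "More Configurations" table further down, while the launcher invocation and the data path are assumptions, so treat the shipped `train.sh` as authoritative.

```shell
# Hypothetical sketch, not the shipped train.sh: flag names follow the
# "More Configurations" table below; data/PmLR50 is a placeholder path.
python -m torch.distributed.launch --nproc_per_node=2 run_phase_training.py \
    --batch_size 8 \
    --epochs 50 \
    --save_ckpt_freq 10 \
    --nb_classes 5 \
    --data_strategy online \
    --num_frames 20 \
    --sampling_rate 8 \
    --data_path data/PmLR50 \
    --enable_deepspeed
```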
64 | 65 | Run the following command for training: 66 | 67 | ```shell 68 | sh train.sh 69 | ``` 70 | and run the following command for testing: 71 | 72 | ```shell 73 | sh test.sh 74 | ``` 75 | The checkpoint of our model is provided [here](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155229775_link_cuhk_edu_hk/EZVHcTmQBY1Mv1zTSLEtu0cBKTA7zTNURaG65gWWloqFmg?e=Zudo2X). 76 | 77 | ### More Configurations 78 | 79 | We list some more useful configurations for easy usage: 80 | 81 | | Argument | Default | Description | 82 | |:----------------------:|:---------:|:-----------------------------------------:| 83 | | `--nproc_per_node` | 2 | Number of processes (GPUs) per node for training and testing | 84 | | `--batch_size` | 8 | The batch size for training and inference | 85 | | `--epochs` | 50 | The maximum number of training epochs | 86 | | `--save_ckpt_freq` | 10 | The frequency for saving checkpoints during training | 87 | | `--nb_classes` | 5 | The number of classes for surgical workflows | 88 | | `--data_strategy` | online | Online/offline mode | 89 | | `--num_frames` | 20 | The number of consecutive frames used | 90 | | `--sampling_rate` | 8 | The sampling interval for consecutive frames | 91 | | `--enable_deepspeed` | True | Use DeepSpeed to accelerate training | 92 | | `--dist_eval` | False | Use distributed evaluation to accelerate | 93 | | `--load_ckpt` | -- | Load a given checkpoint for testing | 94 | 95 | ## Acknowledgements 96 | Huge thanks to the authors of the following open-source projects: 97 | - [TMRNet](https://github.com/YuemingJin/TMRNet) 98 | - [Surgformer](https://github.com/isyangshu/Surgformer/) 99 | - [TimeSformer](https://github.com/facebookresearch/TimeSformer) 100 | 101 | ## Citation 102 | If you find our work useful in your research, please consider citing our paper: 103 | 104 | @inproceedings{guo2025surgical, 105 | title={Surgical Workflow Recognition and Blocking Effectiveness Detection in Laparoscopic Liver Resection with Pringle Maneuver}, 106 | author={Guo, Diandian and Si, Weixin and Li, Zhixi and Pei, Jialun and Heng, Pheng-Ann}, 107 | booktitle={Proceedings of the AAAI Conference on Artificial Intelligence}, 108 | volume={39}, 109 | number={3}, 110 | pages={3220--3228}, 111 | year={2025} 112 | } 113 | -------------------------------------------------------------------------------- /assets/Pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/assets/Pipeline.png -------------------------------------------------------------------------------- /assets/photo: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /datasets/__pycache__/args.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/datasets/__pycache__/args.cpython-311.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/functional.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/datasets/__pycache__/functional.cpython-310.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/functional.cpython-311.pyc:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/datasets/__pycache__/functional.cpython-311.pyc -------------------------------------------------------------------------------- /datasets/args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def get_args(): 4 | parser = argparse.ArgumentParser('SurgVideoMAE pre-training script', add_help=False) 5 | parser.add_argument('--batch_size', default=64, type=int) 6 | parser.add_argument('--epochs', default=801, type=int) 7 | parser.add_argument('--save_ckpt_freq', default=20, type=int) 8 | 9 | # Model parameters 10 | parser.add_argument('--model', default='pretrain_videomae_base_patch16_224', type=str, metavar='MODEL', 11 | help='Name of model to train') 12 | 13 | parser.add_argument('--decoder_depth', default=4, type=int, 14 | help='depth of decoder') 15 | 16 | parser.add_argument('--mask_type', default='tube', choices=['random', 'tube'], 17 | type=str, help='masked strategy of video tokens/patches') 18 | 19 | parser.add_argument('--mask_ratio', default=0.9, type=float, 20 | help='ratio of the visual tokens/patches need be masked') 21 | 22 | parser.add_argument('--input_size', default=224, type=int, 23 | help='videos input size for backbone') 24 | 25 | parser.add_argument('--drop_path', type=float, default=0.0, metavar='PCT', 26 | help='Drop path rate (default: 0.1)') 27 | 28 | parser.add_argument('--normlize_target', default=True, type=bool, 29 | help='normalized the target patch pixels') 30 | 31 | # Optimizer parameters 32 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', 33 | help='Optimizer (default: "adamw"') 34 | parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON', 35 | help='Optimizer Epsilon (default: 1e-8)') 36 | parser.add_argument('--opt_betas', default=(0.9, 0.95), type=float, nargs='+', metavar='BETA', 37 | help='Optimizer Betas (default: None, use opt default)') 38 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', 39 | help='Clip gradient norm (default: None, no clipping)') 40 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 41 | help='SGD momentum (default: 0.9)') 42 | parser.add_argument('--weight_decay', type=float, default=0.05, 43 | help='weight decay (default: 0.05)') 44 | parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the 45 | weight decay. We use a cosine schedule for WD. 
46 | (Set the same value with args.weight_decay to keep weight decay no change)""") 47 | 48 | parser.add_argument('--lr', type=float, default=1.5e-4, metavar='LR', 49 | help='learning rate (default: 1.5e-4)') 50 | parser.add_argument('--warmup_lr', type=float, default=1e-6, metavar='LR', 51 | help='warmup learning rate (default: 1e-6)') 52 | parser.add_argument('--min_lr', type=float, default=1e-5, metavar='LR', 53 | help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') 54 | 55 | parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N', 56 | help='epochs to warmup LR, if scheduler supports') 57 | parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N', 58 | help='epochs to warmup LR, if scheduler supports') 59 | parser.add_argument('--use_checkpoint', action='store_true') 60 | parser.set_defaults(use_checkpoint=False) 61 | 62 | # Augmentation parameters 63 | parser.add_argument('--color_jitter', type=float, default=0.0, metavar='PCT', 64 | help='Color jitter factor (default: 0.4)') 65 | parser.add_argument('--train_interpolation', type=str, default='bicubic', 66 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")') 67 | 68 | # Dataset parameters 69 | parser.add_argument('--data_path', default='/path/to/list_kinetics-400', type=str, 70 | help='dataset path') 71 | parser.add_argument('--imagenet_default_mean_and_std', default=True, action='store_true') 72 | parser.add_argument('--num_frames', type=int, default= 16) 73 | parser.add_argument('--sampling_rate', type=int, default=2) 74 | parser.add_argument('--output_dir', default='Cholec80', 75 | help='path where to save, empty for no saving') 76 | parser.add_argument('--log_dir', default=None, 77 | help='path where to tensorboard log') 78 | parser.add_argument('--device', default='cuda', 79 | help='device to use for training / testing') 80 | parser.add_argument('--seed', default=0, type=int) 81 | parser.add_argument('--resume', default='', help='resume from checkpoint') 82 | parser.add_argument('--auto_resume', action='store_true') 83 | parser.add_argument('--no_auto_resume', action='store_false', dest='auto_resume') 84 | parser.set_defaults(auto_resume=True) 85 | 86 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 87 | help='start epoch') 88 | parser.add_argument('--num_workers', default=10, type=int) 89 | parser.add_argument('--pin_mem', action='store_true', 90 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 91 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem', 92 | help='') 93 | parser.set_defaults(pin_mem=True) 94 | 95 | # distributed training parameters 96 | parser.add_argument('--world_size', default=1, type=int, 97 | help='number of distributed processes') 98 | parser.add_argument('--local_rank', default=-1, type=int) 99 | parser.add_argument('--dist_on_itp', action='store_true') 100 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 101 | 102 | parser.add_argument('--window_size', default=(8,14,14), type=tuple, 103 | help='number of distributed processes') 104 | 105 | return parser.parse_args() 106 | 107 | 108 | def get_args_finetuning(): 109 | parser = argparse.ArgumentParser( 110 | "SurgVideoMAE fine-tuning and evaluation script for video phase recognition", 111 | add_help=False, 112 | ) 113 | parser.add_argument("--batch_size", default=10, type=int) 114 | parser.add_argument("--epochs", default=100, type=int) 115 | 
parser.add_argument("--update_freq", default=1, type=int) 116 | parser.add_argument("--save_ckpt_freq", default=10, type=int) 117 | 118 | 119 | # Model parameters 120 | parser.add_argument( 121 | "--model", 122 | default="vit_base_patch16_224", 123 | type=str, 124 | metavar="MODEL", 125 | help="Name of model to train", 126 | ) 127 | parser.add_argument("--tubelet_size", type=int, default=2) 128 | parser.add_argument("--full_finetune", action="store_true", default=False) 129 | parser.add_argument("--input_size", default=224, type=int, help="videos input size") 130 | 131 | parser.add_argument( 132 | "--fc_drop_rate", 133 | type=float, 134 | default=0.0, 135 | metavar="PCT", 136 | help="Dropout rate (default: 0.)", 137 | ) 138 | parser.add_argument( 139 | "--drop", 140 | type=float, 141 | default=0.0, 142 | metavar="PCT", 143 | help="Dropout rate (default: 0.)", 144 | ) 145 | parser.add_argument( 146 | "--attn_drop_rate", 147 | type=float, 148 | default=0.0, 149 | metavar="PCT", 150 | help="Attention dropout rate (default: 0.)", 151 | ) 152 | parser.add_argument( 153 | "--drop_path", 154 | type=float, 155 | default=0, 156 | metavar="PCT", 157 | help="Drop path rate (default: 0.1)", 158 | ) 159 | 160 | parser.add_argument( 161 | "--disable_eval_during_finetuning", action="store_true", default=False 162 | ) 163 | parser.add_argument("--model_ema", action="store_true", default=False) 164 | parser.add_argument("--model_ema_decay", type=float, default=0.9999, help="") 165 | parser.add_argument( 166 | "--model_ema_force_cpu", action="store_true", default=False, help="" 167 | ) 168 | 169 | # Optimizer parameters 170 | parser.add_argument( 171 | "--opt", 172 | default="adamw", 173 | type=str, 174 | metavar="OPTIMIZER", 175 | help='Optimizer (default: "adamw"', 176 | ) 177 | parser.add_argument( 178 | "--opt_eps", 179 | default=1e-8, 180 | type=float, 181 | metavar="EPSILON", 182 | help="Optimizer Epsilon (default: 1e-8)", 183 | ) 184 | parser.add_argument( 185 | "--opt_betas", 186 | default=(0.9, 0.999), 187 | type=float, 188 | nargs="+", 189 | metavar="BETA", 190 | help="Optimizer Betas (default: None, use opt default)", 191 | ) 192 | parser.add_argument( 193 | "--clip_grad", 194 | type=float, 195 | default=None, 196 | metavar="NORM", 197 | help="Clip gradient norm (default: None, no clipping)", 198 | ) 199 | parser.add_argument( 200 | "--momentum", 201 | type=float, 202 | default=0.9, 203 | metavar="M", 204 | help="SGD momentum (default: 0.9)", 205 | ) 206 | parser.add_argument( 207 | "--weight_decay", type=float, default=0.05, help="weight decay (default: 0.05)" 208 | ) 209 | parser.add_argument( 210 | "--weight_decay_end", 211 | type=float, 212 | default=None, 213 | help="""Final value of the 214 | weight decay. 
We use a cosine schedule for WD and using a larger decay by 215 | the end of training improves performance for ViTs.""", 216 | ) 217 | 218 | parser.add_argument( 219 | "--lr", 220 | type=float, 221 | default=5e-4, 222 | metavar="LR", 223 | help="learning rate (default: 1e-3)", 224 | ) 225 | parser.add_argument("--layer_decay", type=float, default=0.75) 226 | 227 | parser.add_argument( 228 | "--warmup_lr", 229 | type=float, 230 | default=1e-6, 231 | metavar="LR", 232 | help="warmup learning rate (default: 1e-6)", 233 | ) 234 | parser.add_argument( 235 | "--min_lr", 236 | type=float, 237 | default=1e-6, 238 | metavar="LR", 239 | help="lower lr bound for cyclic schedulers that hit 0 (1e-5)", 240 | ) 241 | 242 | parser.add_argument( 243 | "--warmup_epochs", 244 | type=int, 245 | default=5, 246 | metavar="N", 247 | help="epochs to warmup LR, if scheduler supports", 248 | ) 249 | parser.add_argument( 250 | "--warmup_steps", 251 | type=int, 252 | default=-1, 253 | metavar="N", 254 | help="num of steps to warmup LR, will overload warmup_epochs if set > 0", 255 | ) 256 | 257 | # Augmentation parameters 258 | parser.add_argument( 259 | "--color_jitter", 260 | type=float, 261 | default=0.4, 262 | metavar="PCT", 263 | help="Color jitter factor (default: 0.4)", 264 | ) 265 | parser.add_argument( 266 | "--num_sample", type=int, default=1, help="Repeated_aug (default: 2)" 267 | ) 268 | parser.add_argument( 269 | "--aa", 270 | type=str, 271 | default="rand-m7-n4-mstd0.5-inc1", 272 | # default="rand-m8-n2-mstd0.5-inc1", 273 | metavar="NAME", 274 | help='Use AutoAugment policy. "v0" or "original". " + "(default: rand-m7-n4-mstd0.5-inc1)', 275 | ), 276 | parser.add_argument( 277 | "--smoothing", type=float, default=0.1, help="Label smoothing (default: 0.1)" 278 | ) 279 | parser.add_argument( 280 | "--train_interpolation", 281 | type=str, 282 | default="bicubic", 283 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")', 284 | ) 285 | 286 | # Evaluation parameters 287 | parser.add_argument("--crop_pct", type=float, default=None) 288 | parser.add_argument("--short_side_size", type=int, default=224) 289 | parser.add_argument("--test_num_segment", type=int, default=5) 290 | parser.add_argument("--test_num_crop", type=int, default=3) 291 | 292 | # Random Erase params 293 | parser.add_argument( 294 | "--reprob", 295 | type=float, 296 | default=0.25, 297 | metavar="PCT", 298 | help="Random erase prob (default: 0.25)", 299 | ) 300 | parser.add_argument( 301 | "--remode", 302 | type=str, 303 | default="pixel", 304 | help='Random erase mode (default: "pixel")', 305 | ) 306 | parser.add_argument( 307 | "--recount", type=int, default=1, help="Random erase count (default: 1)" 308 | ) 309 | parser.add_argument( 310 | "--resplit", 311 | action="store_true", 312 | default=False, 313 | help="Do not random erase first (clean) augmentation split", 314 | ) 315 | 316 | # Mixup params 317 | parser.add_argument( 318 | "--mixup", 319 | type=float, 320 | default=0, 321 | help="mixup alpha, mixup enabled if > 0, default 0.8.", 322 | ) 323 | parser.add_argument( 324 | "--cutmix", 325 | type=float, 326 | default=0, 327 | help="cutmix alpha, cutmix enabled if > 0, default 1.0.", 328 | ) 329 | parser.add_argument( 330 | "--cutmix_minmax", 331 | type=float, 332 | nargs="+", 333 | default=None, 334 | help="cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)", 335 | ) 336 | parser.add_argument( 337 | "--mixup_prob", 338 | type=float, 339 | default=1.0, 340 | help="Probability of 
performing mixup or cutmix when either/both is enabled", 341 | ) 342 | parser.add_argument( 343 | "--mixup_switch_prob", 344 | type=float, 345 | default=0.5, 346 | help="Probability of switching to cutmix when both mixup and cutmix enabled", 347 | ) 348 | parser.add_argument( 349 | "--mixup_mode", 350 | type=str, 351 | default="batch", 352 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"', 353 | ) 354 | 355 | # Finetuning params 356 | parser.add_argument("--finetune", default="", help="finetune from checkpoint") 357 | parser.add_argument("--model_key", default="model|module", type=str) 358 | parser.add_argument("--model_prefix", default="", type=str) 359 | parser.add_argument("--init_scale", default=0.001, type=float) 360 | parser.add_argument("--use_checkpoint", action="store_true") 361 | parser.set_defaults(use_checkpoint=False) 362 | parser.add_argument("--use_mean_pooling", action="store_true") 363 | parser.set_defaults(use_mean_pooling=True) 364 | parser.add_argument("--use_cls", action="store_false", dest="use_mean_pooling") 365 | 366 | # Dataset parameters 367 | parser.add_argument( 368 | "--data_path", 369 | default="/Users/yangshu/Downloads/Cataract101", 370 | # default="/Users/yangshu/Documents/SurgVideoMAE/data/cholec80", 371 | type=str, 372 | help="dataset path", 373 | ) 374 | parser.add_argument( 375 | "--eval_data_path", 376 | default="/Users/yangshu/Downloads/Cataract101", 377 | type=str, 378 | help="dataset path for evaluation", 379 | ) 380 | parser.add_argument( 381 | "--nb_classes", default=10, type=int, help="number of the classification types" 382 | ) 383 | parser.add_argument( 384 | "--imagenet_default_mean_and_std", default=True, action="store_true" 385 | ) 386 | parser.add_argument("--num_segments", type=int, default=1) 387 | parser.add_argument("--num_frames", type=int, default=8) 388 | parser.add_argument("--sampling_rate", type=int, default=2) 389 | parser.add_argument( 390 | "--data_set", 391 | default="Cataract101", 392 | choices=["Cholec80", "AutoLaparo", "Cataract101"], 393 | type=str, 394 | help="dataset", 395 | ) 396 | parser.add_argument( 397 | "--data_fps", 398 | default="1fps", 399 | choices=["", "5fps", "1fps"], 400 | type=str, 401 | help="dataset", 402 | ) 403 | parser.add_argument( 404 | "--output_dir", 405 | default="/home/yangshu/SurgVideoMAE/Cholec80/ImageNet/phase/1fps_loss", 406 | help="path where to save, empty for no saving", 407 | ) 408 | parser.add_argument( 409 | "--log_dir", 410 | default="/home/yangshu/SurgVideoMAE/Cholec80/ImageNet/phase/1fps_loss/log", 411 | help="path where to tensorboard log", 412 | ) 413 | parser.add_argument( 414 | "--device", default="cuda", help="device to use for training / testing" 415 | ) 416 | parser.add_argument("--seed", default=0, type=int) 417 | parser.add_argument("--resume", default="", help="resume from checkpoint") 418 | parser.add_argument("--auto_resume", action="store_true") 419 | parser.add_argument("--no_auto_resume", action="store_false", dest="auto_resume") 420 | parser.set_defaults(auto_resume=True) 421 | 422 | parser.add_argument("--save_ckpt", action="store_true") 423 | parser.add_argument("--no_save_ckpt", action="store_false", dest="save_ckpt") 424 | parser.set_defaults(save_ckpt=True) 425 | 426 | parser.add_argument( 427 | "--start_epoch", default=0, type=int, metavar="N", help="start epoch" 428 | ) 429 | parser.add_argument( 430 | "--eval", action="store_true", default=False, help="Perform evaluation only" 431 | ) 432 | parser.add_argument( 433 | "--dist_eval", 434 | 
action="store_true", 435 | default=False, 436 | help="Enabling distributed evaluation", 437 | ) 438 | parser.add_argument("--num_workers", default=10, type=int) 439 | parser.add_argument( 440 | "--pin_mem", 441 | action="store_true", 442 | help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.", 443 | ) 444 | parser.add_argument("--no_pin_mem", action="store_false", dest="pin_mem") 445 | parser.set_defaults(pin_mem=True) 446 | 447 | # distributed training parameters 448 | parser.add_argument( 449 | "--world_size", default=1, type=int, help="number of distributed processes" 450 | ) 451 | parser.add_argument("--local_rank", default=-1, type=int) 452 | parser.add_argument("--dist_on_itp", action="store_true") 453 | parser.add_argument( 454 | "--dist_url", default="env://", help="url used to set up distributed training" 455 | ) 456 | 457 | parser.add_argument("--enable_deepspeed", action="store_true", default=False) 458 | 459 | known_args, _ = parser.parse_known_args() 460 | 461 | if known_args.enable_deepspeed: 462 | try: 463 | import deepspeed 464 | from deepspeed import DeepSpeedConfig 465 | 466 | parser = deepspeed.add_config_arguments(parser) 467 | ds_init = deepspeed.initialize 468 | except: 469 | print("Please 'pip install deepspeed'") 470 | exit(0) 471 | else: 472 | ds_init = None 473 | 474 | return parser.parse_args(), ds_init -------------------------------------------------------------------------------- /datasets/convert_results/convert_autolaparo.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | phases = [ 4 | "Preparation", 5 | "CalotTriangleDissection", 6 | "ClippingCutting", 7 | "GallbladderDissection", 8 | "GallbladderPackaging", 9 | "CleaningCoagulation", 10 | "GallbladderRetraction", 11 | ] 12 | 13 | def create_folder_if_not_exists(folder_path): 14 | if not os.path.exists(folder_path): 15 | os.makedirs(folder_path) 16 | print("Folder created:", folder_path) 17 | else: 18 | print("Folder already exists:", folder_path) 19 | 20 | main_path = "/Users/yangshu/Documents/PETL4SurgVideo/result_save/Ours/AutoLaparo/16-4/LGA/WIT400M/" 21 | file_path_0 = os.path.join(main_path, "0.txt") 22 | file_path_1 = os.path.join(main_path, "1.txt") 23 | anns_path = "/Users/yangshu/Documents/PETL4SurgVideo/result_save/Ours/AutoLaparo/16-4/LGA/WIT400M" + "/phase_annotations" 24 | pred_path = "/Users/yangshu/Documents/PETL4SurgVideo/result_save/Ours/AutoLaparo/16-4/LGA/WIT400M" + "/prediction" 25 | 26 | 27 | create_folder_if_not_exists(anns_path) 28 | create_folder_if_not_exists(pred_path) 29 | 30 | with open(file_path_0) as f: 31 | lines0 = f.readlines() 32 | 33 | with open(file_path_1) as f: 34 | lines1 = f.readlines() 35 | 36 | for i in range(15, 22): 37 | with open( 38 | anns_path + "/video-{}.txt".format(str(i)), "w" 39 | ) as f: 40 | f.write("Frame") 41 | f.write("\t") 42 | f.write("Phase") 43 | f.write("\n") 44 | assert len(lines0) == len(lines1) 45 | for j in range(1, len(lines0)): 46 | temp0 = lines0[j].split() 47 | temp1 = lines1[j].split() 48 | if temp0[1] == "{}".format(str(i)): 49 | f.write(str(temp0[2])) # phase_annotations 50 | f.write("\t") # phase_annotations 51 | f.write(str(temp0[-1])) # phase_annotations 52 | f.write("\n") # phase_annotations 53 | if temp1[1] == "{}".format(str(i)): 54 | f.write(str(temp1[2])) # phase_annotations 55 | f.write("\t") # phase_annotations 56 | f.write(str(temp1[-1])) # phase_annotations 57 | f.write("\n") # phase_annotations 58 | 59 | with open(file_path_0) as f: 60 |
lines0 = f.readlines() 61 | 62 | with open(file_path_1) as f: 63 | lines1 = f.readlines() 64 | for i in range(15, 22): 65 | print(i) 66 | with open( 67 | pred_path + "/video-{}.txt".format(str(i)), "w" 68 | ) as f: # phase_annotations 69 | f.write("Frame") 70 | f.write("\t") 71 | f.write("Phase") 72 | f.write("\n") 73 | assert len(lines0) == len(lines1) 74 | for j in range(1, len(lines0)): 75 | temp0 = lines0[j].strip() # prediction 76 | temp1 = lines1[j].strip() # prediction 77 | data0 = np.fromstring( 78 | temp0.split("[")[1].split("]")[0], dtype=np.float32, sep="," 79 | ) # prediction 80 | data1 = np.fromstring( 81 | temp1.split("[")[1].split("]")[0], dtype=np.float32, sep="," 82 | ) # prediction 83 | data0 = data0.argmax() # prediction 84 | data1 = data1.argmax() # prediction 85 | temp0 = lines0[j].split() 86 | temp1 = lines1[j].split() 87 | if temp0[1] == "{}".format(str(i)): 88 | f.write(str(temp0[2])) # prediction 89 | f.write('\t') # prediction 90 | f.write(str(data0)) # prediction 91 | f.write('\n') # prediction 92 | if temp1[1] == "{}".format(str(i)): 93 | f.write(str(temp1[2])) # prediction 94 | f.write('\t') # prediction 95 | f.write(str(data1)) # prediction 96 | f.write('\n') # prediction 97 | -------------------------------------------------------------------------------- /datasets/convert_results/convert_cholec80.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | phases = [ 4 | "Preparation", 5 | "CalotTriangleDissection", 6 | "ClippingCutting", 7 | "GallbladderDissection", 8 | "GallbladderPackaging", 9 | "CleaningCoagulation", 10 | "GallbladderRetraction", 11 | ] 12 | 13 | def create_folder_if_not_exists(folder_path): 14 | if not os.path.exists(folder_path): 15 | os.makedirs(folder_path) 16 | print("Folder created:", folder_path) 17 | else: 18 | print("Folder already exists:", folder_path) 19 | 20 | 21 | main_path = "/Users/yangshu/Documents/PETL4SurgVideo/result_save/Dualpath/Large/" 22 | file_path_0 = os.path.join(main_path, "0.txt") 23 | file_path_1 = os.path.join(main_path, "1.txt") 24 | anns_path = "/Users/yangshu/Documents/PETL4SurgVideo/result_save/Dualpath/Large" + "/phase_annotations" 25 | pred_path = "/Users/yangshu/Documents/PETL4SurgVideo/result_save/Dualpath/Large" + "/prediction" 26 | 27 | 28 | create_folder_if_not_exists(anns_path) 29 | create_folder_if_not_exists(pred_path) 30 | 31 | with open(file_path_0) as f: 32 | lines0 = f.readlines() 33 | 34 | with open(file_path_1) as f: 35 | lines1 = f.readlines() 36 | 37 | for i in range(41, 81): 38 | with open( 39 | anns_path + "/video-{}.txt".format(str(i)), "w" 40 | ) as f: 41 | f.write("Frame") 42 | f.write("\t") 43 | f.write("Phase") 44 | f.write("\n") 45 | assert len(lines0) == len(lines1) 46 | for j in range(1, len(lines0)): 47 | temp0 = lines0[j].split() 48 | temp1 = lines1[j].split() 49 | if temp0[1] == "video{}".format(str(i)): 50 | f.write(str(temp0[2])) # phase_annotations 51 | f.write("\t") # phase_annotations 52 | f.write(str(temp0[-1])) # phase_annotations 53 | f.write("\n") # phase_annotations 54 | if temp1[1] == "video{}".format(str(i)): 55 | f.write(str(temp1[2])) # phase_annotations 56 | f.write("\t") # phase_annotations 57 | f.write(str(temp1[-1])) # phase_annotations 58 | f.write("\n") # phase_annotations 59 | 60 | with open(file_path_0) as f: 61 | lines0 = f.readlines() 62 | 63 | with open(file_path_1) as f: 64 | lines1 = f.readlines() 65 | for i in range(41, 81): 66 | print(i) 67 | with open( 68 | pred_path + "/video-{}.txt".format(str(i)),
"w" 69 | ) as f: # phase_annotations 70 | f.write("Frame") 71 | f.write("\t") 72 | f.write("Phase") 73 | f.write("\n") 74 | assert len(lines0) == len(lines1) 75 | for j in range(1, len(lines0)): 76 | temp0 = lines0[j].strip() # prediction 77 | temp1 = lines1[j].strip() # prediction 78 | data0 = np.fromstring( 79 | temp0.split("[")[1].split("]")[0], dtype=np.float32, sep="," 80 | ) # prediction 81 | data1 = np.fromstring( 82 | temp1.split("[")[1].split("]")[0], dtype=np.float32, sep="," 83 | ) # prediction 84 | data0 = data0.argmax() # prediction 85 | data1 = data1.argmax() # prediction 86 | temp0 = lines0[j].split() 87 | temp1 = lines1[j].split() 88 | if temp0[1] == "video{}".format(str(i)): 89 | f.write(str(temp0[2])) # prediction 90 | f.write('\t') # prediction 91 | f.write(str(data0)) # prediction 92 | f.write('\n') # prediction 93 | if temp1[1] == "video{}".format(str(i)): 94 | f.write(str(temp1[2])) # prediction 95 | f.write('\t') # prediction 96 | f.write(str(data1)) # prediction 97 | f.write('\n') # prediction 98 | -------------------------------------------------------------------------------- /datasets/data_preprosses/generate_labels_autolaparo.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import cv2 4 | import pickle 5 | from tqdm import tqdm 6 | 7 | def main(): 8 | ROOT_DIR = "/xxxxxxx/AutoLaparo/" 9 | VIDEO_NAMES = os.listdir(os.path.join(ROOT_DIR, 'frames')) 10 | VIDEO_NAMES = sorted([x for x in VIDEO_NAMES if "DS" not in x]) 11 | 12 | TRAIN_NUMBERS = np.arange(1,11).tolist() 13 | VAL_NUMBERS = np.arange(11,15).tolist() 14 | TEST_NUMBERS = np.arange(15,22).tolist() 15 | 16 | TRAIN_FRAME_NUMBERS = 0 17 | VAL_FRAME_NUMBERS = 0 18 | TEST_FRAME_NUMBERS = 0 19 | 20 | train_pkl = dict() 21 | val_pkl = dict() 22 | test_pkl = dict() 23 | 24 | unique_id = 0 25 | unique_id_train = 0 26 | unique_id_val = 0 27 | unique_id_test = 0 28 | 29 | id2phase = {1: "Preparation", 2: "Dividing Ligament and Peritoneum", 3: "Dividing Uterine Vessels and Ligament", 30 | 4: "Transecting the Vagina", 5: "Specimen Removal", 6: "Suturing", 7: "Washing"} 31 | 32 | for video_id in VIDEO_NAMES: 33 | vid_id = int(video_id) 34 | if vid_id in TRAIN_NUMBERS: 35 | unique_id = unique_id_train 36 | elif vid_id in VAL_NUMBERS: 37 | unique_id = unique_id_val 38 | elif vid_id in TEST_NUMBERS: 39 | unique_id = unique_id_test 40 | 41 | # 总帧数(frames) 42 | video_path = os.path.join(ROOT_DIR, "frames", video_id) 43 | frames_list = os.listdir(video_path) 44 | 45 | # 打开Label文件 46 | phase_path = os.path.join(ROOT_DIR, 'labels', "label_" + video_id + '.txt') 47 | phase_file = open(phase_path, 'r') 48 | phase_results = phase_file.readlines()[1:] 49 | 50 | frame_infos = list() 51 | frame_id_ = 0 52 | for frame_id in tqdm(range(0, len(frames_list))): 53 | info = dict() 54 | info['unique_id'] = unique_id 55 | info['frame_id'] = frame_id_ 56 | info['original_frame_id'] = frame_id 57 | info['video_id'] = video_id 58 | info['tool_gt'] = None 59 | info['frames'] = len(frames_list) 60 | phase = phase_results[frame_id].strip().split() 61 | assert int(phase[0]) == frame_id + 1 62 | phase_id = int(phase[1]) 63 | info['phase_gt'] = phase_id - 1 64 | info['phase_name'] = id2phase[int(phase[1])] 65 | info['fps'] = 1 66 | frame_infos.append(info) 67 | unique_id += 1 68 | frame_id_ += 1 69 | 70 | if vid_id in TRAIN_NUMBERS: 71 | train_pkl[video_id] = frame_infos 72 | TRAIN_FRAME_NUMBERS += len(frames_list) 73 | unique_id_train = unique_id 74 | elif vid_id in 
VAL_NUMBERS: 75 | val_pkl[video_id] = frame_infos 76 | VAL_FRAME_NUMBERS += len(frames_list) 77 | unique_id_val = unique_id 78 | elif vid_id in TEST_NUMBERS: 79 | test_pkl[video_id] = frame_infos 80 | TEST_FRAME_NUMBERS += len(frames_list) 81 | unique_id_test = unique_id 82 | 83 | train_save_dir = os.path.join(ROOT_DIR, 'labels_pkl', 'train') 84 | os.makedirs(train_save_dir, exist_ok=True) 85 | with open(os.path.join(train_save_dir, '1fpstrain.pickle'), 'wb') as file: 86 | pickle.dump(train_pkl, file) 87 | 88 | val_save_dir = os.path.join(ROOT_DIR, 'labels_pkl', 'val') 89 | os.makedirs(val_save_dir, exist_ok=True) 90 | with open(os.path.join(val_save_dir, '1fpsval.pickle'), 'wb') as file: 91 | pickle.dump(val_pkl, file) 92 | 93 | test_save_dir = os.path.join(ROOT_DIR, 'labels_pkl', 'test') 94 | os.makedirs(test_save_dir, exist_ok=True) 95 | with open(os.path.join(test_save_dir, '1fpstest.pickle'), 'wb') as file: 96 | pickle.dump(test_pkl, file) 97 | 98 | 99 | print('TRAIN Frames', TRAIN_FRAME_NUMBERS, unique_id_train) 100 | print('VAL Frames', VAL_FRAME_NUMBERS, unique_id_val) 101 | print('TEST Frames', TEST_FRAME_NUMBERS, unique_id_test) 102 | 103 | if __name__ == '__main__': 104 | main() -------------------------------------------------------------------------------- /datasets/data_preprosses/generate_labels_ch80.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import cv2 4 | import pickle 5 | from tqdm import tqdm 6 | 7 | def main(): 8 | ROOT_DIR = "/xxxxxxx/cholec80" 9 | VIDEO_NAMES = os.listdir(os.path.join(ROOT_DIR, 'videos')) 10 | VIDEO_NAMES = sorted([x for x in VIDEO_NAMES if 'mp4' in x]) 11 | TRAIN_NUMBERS = np.arange(1,41).tolist() 12 | VAL_NUMBERS = np.arange(41,49).tolist() 13 | TEST_NUMBERS = np.arange(49,81).tolist() 14 | 15 | TRAIN_FRAME_NUMBERS = 0 16 | VAL_FRAME_NUMBERS = 0 17 | TEST_FRAME_NUMBERS = 0 18 | 19 | train_pkl = dict() 20 | val_pkl = dict() 21 | test_pkl = dict() 22 | val_test_pkl = dict() 23 | 24 | unique_id = 0 25 | unique_id_train = 0 26 | unique_id_val = 0 27 | unique_id_test = 0 28 | 29 | phase2id = {'Preparation': 0, 'CalotTriangleDissection': 1, 'ClippingCutting': 2, 'GallbladderDissection': 3, 30 | 'GallbladderPackaging': 4, 'CleaningCoagulation': 5, 'GallbladderRetraction': 6} 31 | 32 | for video_name in VIDEO_NAMES: 33 | video_id = video_name.replace('.mp4', '') 34 | vid_id = int(video_name.replace('.mp4', '').replace("video", "")) 35 | if vid_id in TRAIN_NUMBERS: 36 | unique_id = unique_id_train 37 | elif vid_id in VAL_NUMBERS: 38 | unique_id = unique_id_val 39 | elif vid_id in TEST_NUMBERS: 40 | unique_id = unique_id_test 41 | 42 | # Open the video file 43 | vidcap = cv2.VideoCapture(os.path.join(ROOT_DIR, './videos/' + video_name)) 44 | # Frame rate (frames per second) 45 | fps = vidcap.get(cv2.CAP_PROP_FPS) 46 | if fps != 25: 47 | print(video_name, 'not at 25fps', fps) 48 | # Total number of frames 49 | frames = vidcap.get(cv2.CAP_PROP_FRAME_COUNT) 50 | 51 | # Open the label files 52 | tool_path = os.path.join(ROOT_DIR, 'tool_annotations', video_name.replace('.mp4', '-tool.txt')) 53 | tool_file = open(tool_path, 'r') 54 | tool = tool_file.readline().strip().split() 55 | tool_name = tool[1:] 56 | tool_dict = dict() 57 | while tool: 58 | tool = tool_file.readline().strip().split() 59 | if len(tool) > 0: 60 | tool = list(map(int, tool)) 61 | tool_dict[str(tool[0])] = tool[1:] 62 | 63 | phase_path = os.path.join(ROOT_DIR, 'phase_annotations', video_name.replace('.mp4', '-phase.txt')) 64 | phase_file = open(phase_path,
'r') 65 | phase_results = phase_file.readlines()[1:] 66 | 67 | frame_infos = list() 68 | frame_id_ = 0 69 | for frame_id in tqdm(range(0, int(frames), 25)): 70 | info = dict() 71 | info['unique_id'] = unique_id 72 | info['frame_id'] = frame_id_ 73 | info['original_frame_id'] = frame_id 74 | info['video_id'] = video_id 75 | 76 | if str(frame_id) in tool_dict: 77 | info['tool_gt'] = tool_dict[str(frame_id)] 78 | else: 79 | info['tool_gt'] = None 80 | 81 | phase = phase_results[frame_id].strip().split() 82 | assert int(phase[0]) == frame_id 83 | phase_id = phase2id[phase[1]] 84 | info['phase_gt'] = phase_id 85 | info['phase_name'] = phase[1] 86 | info['fps'] = 1 87 | info['original_frames'] = int(frames) 88 | info['frames'] = int(frames) // 25 89 | # info['tool_names'] = tool_name 90 | info['phase_name'] = phase[1] 91 | frame_infos.append(info) 92 | unique_id += 1 93 | frame_id_ += 1 94 | 95 | vid_id = int(video_name.replace('.mp4', '').replace("video", "")) 96 | if vid_id in TRAIN_NUMBERS: 97 | train_pkl[video_id] = frame_infos 98 | TRAIN_FRAME_NUMBERS += frames 99 | unique_id_train = unique_id 100 | elif vid_id in VAL_NUMBERS: 101 | val_pkl[video_id] = frame_infos 102 | VAL_FRAME_NUMBERS += frames 103 | unique_id_val = unique_id 104 | elif vid_id in TEST_NUMBERS: 105 | test_pkl[video_id] = frame_infos 106 | TEST_FRAME_NUMBERS += frames 107 | unique_id_test = unique_id 108 | 109 | val_test_pkl = {**val_pkl, **test_pkl} 110 | 111 | train_save_dir = os.path.join(ROOT_DIR, 'labels', 'train') 112 | os.makedirs(train_save_dir, exist_ok=True) 113 | with open(os.path.join(train_save_dir, '1fpstrain.pickle'), 'wb') as file: 114 | pickle.dump(train_pkl, file) 115 | 116 | val_save_dir = os.path.join(ROOT_DIR, 'labels', 'val') 117 | os.makedirs(val_save_dir, exist_ok=True) 118 | with open(os.path.join(val_save_dir, '1fpsval.pickle'), 'wb') as file: 119 | pickle.dump(val_pkl, file) 120 | 121 | test_save_dir = os.path.join(ROOT_DIR, 'labels', 'test') 122 | os.makedirs(test_save_dir, exist_ok=True) 123 | with open(os.path.join(test_save_dir, '1fpstest.pickle'), 'wb') as file: 124 | pickle.dump(test_pkl, file) 125 | with open(os.path.join(test_save_dir, '1fpsval_test.pickle'), 'wb') as file: 126 | pickle.dump(val_test_pkl, file) 127 | 128 | 129 | print('TRAIN Frames', TRAIN_FRAME_NUMBERS, unique_id_train) 130 | print('VAL Frames', VAL_FRAME_NUMBERS, unique_id_val) 131 | print('TEST Frames', TEST_FRAME_NUMBERS, unique_id_test) 132 | print('VAL TEST Frames', len(val_test_pkl)) 133 | 134 | if __name__ == '__main__': 135 | main() 136 | -------------------------------------------------------------------------------- /datasets/data_preprosses/generate_labels_lungseg.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import pickle 4 | from tqdm import tqdm 5 | import random 6 | 7 | def main(): 8 | ROOT_DIR = "/home/diandian/Diandian/DD/PmNet/data/PmLR50" 9 | VIDEO_NAMES = os.listdir(os.path.join(ROOT_DIR, "frames")) 10 | VIDEO_NAMES = sorted([x for x in VIDEO_NAMES if "DS" not in x]) 11 | VIDEO_NUMBERS = len(VIDEO_NAMES) 12 | TRAIN_NUMBERS = list(range(1, 34)) 13 | TRAIN_NUMBERS.append(49) 14 | TRAIN_NUMBERS.append(50) 15 | # TRAIN_NUMBERS = sorted(random.sample(range(VIDEO_NUMBERS), 30)) 16 | print(TRAIN_NUMBERS) 17 | # TEST_NUMBERS = [x for x in range(VIDEO_NUMBERS) if x not in TRAIN_NUMBERS] 18 | TEST_NUMBERS = list(range(34, 38)) 19 | TEST_NUMBERS.append(48) 20 | print(TEST_NUMBERS) 21 | INFER_NUMBERS = list(range(38, 48)) 22 | 23 |
TRAIN_FRAME_NUMBERS = 0 24 | TEST_FRAME_NUMBERS = 0 25 | INFER_FRAME_NUMBERS = 0 26 | 27 | train_pkl = dict() 28 | test_pkl = dict() 29 | infer_pkl = dict() 30 | 31 | unique_id = 0 32 | unique_id_train = 0 33 | unique_id_test = 0 34 | unique_id_infer = 0 35 | 36 | for video_id in VIDEO_NAMES: 37 | vid_id = int(video_id) 38 | 39 | if vid_id in TRAIN_NUMBERS: 40 | unique_id = unique_id_train 41 | elif vid_id in TEST_NUMBERS: 42 | unique_id = unique_id_test 43 | elif vid_id in INFER_NUMBERS: 44 | unique_id = unique_id_infer 45 | 46 | # Total number of frames 47 | video_path = os.path.join(ROOT_DIR, "frames", video_id) 48 | frames_list = os.listdir(video_path) 49 | frames_list = sorted([x for x in frames_list if "jpg" in x]) 50 | 51 | # Open the label file 52 | anno_path = os.path.join(ROOT_DIR, 'phase_annotations', video_id + '.txt') 53 | anno_file = open(anno_path, 'r') 54 | anno_results = anno_file.readlines()[1:] 55 | print(len(frames_list)) 56 | print(len(anno_results)) 57 | assert len(frames_list) == len(anno_results) 58 | frame_infos = list() 59 | for frame_id in tqdm(range(0, len(frames_list))): 60 | info = dict() 61 | info['unique_id'] = unique_id 62 | info['frame_id'] = frame_id 63 | info['video_id'] = video_id 64 | info['frames'] = len(frames_list) 65 | anno_info = anno_results[frame_id] 66 | anno_frame = anno_info.split()[0] 67 | assert int(anno_frame) == frame_id 68 | anno_id = anno_info.split()[1] 69 | info['phase_gt'] = int(anno_id) 70 | info['fps'] = 1 71 | frame_infos.append(info) 72 | unique_id += 1 73 | 74 | if vid_id in TRAIN_NUMBERS: 75 | train_pkl[video_id] = frame_infos 76 | TRAIN_FRAME_NUMBERS += len(frames_list) 77 | unique_id_train = unique_id 78 | elif vid_id in TEST_NUMBERS: 79 | test_pkl[video_id] = frame_infos 80 | TEST_FRAME_NUMBERS += len(frames_list) 81 | unique_id_test = unique_id 82 | elif vid_id in INFER_NUMBERS: 83 | infer_pkl[video_id] = frame_infos 84 | INFER_FRAME_NUMBERS += len(frames_list) 85 | unique_id_infer = unique_id 86 | 87 | train_save_dir = os.path.join(ROOT_DIR, 'labels', 'train') 88 | os.makedirs(train_save_dir, exist_ok=True) 89 | with open(os.path.join(train_save_dir, '1fpstrain.pickle'), 'wb') as file: 90 | pickle.dump(train_pkl, file) 91 | 92 | test_save_dir = os.path.join(ROOT_DIR, 'labels', 'test') 93 | os.makedirs(test_save_dir, exist_ok=True) 94 | with open(os.path.join(test_save_dir, '1fpstest.pickle'), 'wb') as file: 95 | pickle.dump(test_pkl, file) 96 | 97 | infer_save_dir = os.path.join(ROOT_DIR, 'labels', 'infer') 98 | os.makedirs(infer_save_dir, exist_ok=True) 99 | with open(os.path.join(infer_save_dir, '1fpsinfer.pickle'), 'wb') as file: 100 | pickle.dump(infer_pkl, file) 101 | 102 | print('TRAIN Frames', TRAIN_FRAME_NUMBERS, unique_id_train) 103 | print('TEST Frames', TEST_FRAME_NUMBERS, unique_id_test) 104 | print('INFER Frames', INFER_FRAME_NUMBERS, unique_id_infer) 105 | 106 | if __name__ == "__main__": 107 | main() 108 | -------------------------------------------------------------------------------- /datasets/data_preprosses/generate_labels_pmlr.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import pickle 4 | from tqdm import tqdm 5 | import random 6 | import json 7 | 8 | def main(): 9 | ROOT_DIR = "/home/diandian/Diandian/DD/PmNet/data/PmLR50" 10 | VIDEO_NAMES = os.listdir(os.path.join(ROOT_DIR, "frames")) 11 | VIDEO_NAMES = sorted([x for x in VIDEO_NAMES if "DS" not in x]) 12 | VIDEO_NUMBERS = len(VIDEO_NAMES) 13 | TRAIN_NUMBERS = list(range(1, 34)) 14 |
TRAIN_NUMBERS.append(49) 15 | TRAIN_NUMBERS.append(50) 16 | print('For training:', TRAIN_NUMBERS) 17 | TEST_NUMBERS = list(range(34, 38)) 18 | TEST_NUMBERS.append(48) 19 | print('For validation:', TEST_NUMBERS) 20 | INFER_NUMBERS = list(range(38, 48)) 21 | print('For testing:', INFER_NUMBERS) 22 | 23 | total_frame_count = 0 24 | for video_id in VIDEO_NAMES: 25 | if int(video_id) in TRAIN_NUMBERS: 26 | video_path = os.path.join(ROOT_DIR, "frames", video_id) 27 | total_frame_count += len(os.listdir(video_path)) 28 | average_frame_number = total_frame_count/len(TRAIN_NUMBERS) 29 | 30 | TRAIN_FRAME_NUMBERS = 0 31 | TEST_FRAME_NUMBERS = 0 32 | INFER_FRAME_NUMBERS = 0 33 | 34 | train_pkl = dict() 35 | test_pkl = dict() 36 | infer_pkl = dict() 37 | 38 | unique_id = 0 39 | unique_id_train = 0 40 | unique_id_test = 0 41 | unique_id_infer = 0 42 | 43 | for video_id in VIDEO_NAMES: 44 | vid_id = int(video_id) 45 | 46 | if vid_id in TRAIN_NUMBERS: 47 | unique_id = unique_id_train 48 | elif vid_id in TEST_NUMBERS: 49 | unique_id = unique_id_test 50 | elif vid_id in INFER_NUMBERS: 51 | unique_id = unique_id_infer 52 | 53 | # Total number of frames 54 | video_path = os.path.join(ROOT_DIR, "frames", video_id) 55 | frames_list = os.listdir(video_path) 56 | frames_list = sorted([x for x in frames_list if "jpg" in x]) 57 | 58 | # Open the label files 59 | anno_path = os.path.join(ROOT_DIR, 'phase_annotations', video_id + '.txt') 60 | anno_file = open(anno_path, 'r') 61 | anno_results = anno_file.readlines()[1:] 62 | blocking_path = os.path.join(ROOT_DIR, 'blocking_annotations', video_id + '.txt') 63 | blocking_file = open(blocking_path, 'r') 64 | blocking_results = blocking_file.readlines()[1:] 65 | bbox_path = os.path.join(ROOT_DIR, 'bbox_annotations', video_id + '.json') 66 | with open(bbox_path, 'r', encoding='utf-8') as f: 67 | bbox_data = json.load(f) 68 | assert len(frames_list) == len(anno_results) 69 | frame_infos = list() 70 | for frame_id in tqdm(range(0, len(frames_list))): 71 | info = dict() 72 | info['unique_id'] = unique_id 73 | info['frame_id'] = frame_id 74 | info['video_id'] = video_id 75 | info['frames'] = len(frames_list) 76 | anno_info = anno_results[frame_id] 77 | blocking_info = blocking_results[frame_id] 78 | anno_frame = anno_info.split()[0] 79 | assert int(anno_frame) == frame_id 80 | anno_id = anno_info.split()[1] 81 | blocking_id = blocking_info.split()[1] 82 | info['phase_gt'] = int(anno_id) 83 | info['blocking_gt'] = int(blocking_id) 84 | origin_bbox = bbox_data[frames_list[frame_id].split('.')[0]] 85 | x1, x2, y1, y2 = origin_bbox[0][0], origin_bbox[1][0], origin_bbox[0][1], origin_bbox[1][1] 86 | if vid_id not in [16, 18, 46, 47, 48, 49]: 87 | x1, x2 = x1 * 224 / 1280, x2 * 224 / 1280 88 | y1, y2 = y1 * 224 / 720, y2 * 224 / 720 89 | else: 90 | x1, x2 = x1 * 224 / 1920, x2 * 224 / 1920 91 | y1, y2 = y1 * 224 / 1080, y2 * 224 / 1080 92 | 93 | x1, x2 = min(x1, x2), max(x1, x2) 94 | y1, y2 = min(y1, y2), max(y1, y2) 95 | 96 | x1, y1 = max(0, int(x1)), max(0, int(y1)) 97 | x2, y2 = min(224, int(x2)), min(224, int(y2)) 98 | info['bbox'] = [[x1, y1], [x2, y2]] 99 | info['fps'] = 1 100 | info['avg_length'] = average_frame_number 101 | frame_infos.append(info) 102 | unique_id += 1 103 | 104 | if vid_id in TRAIN_NUMBERS: 105 | train_pkl[video_id] = frame_infos 106 | TRAIN_FRAME_NUMBERS += len(frames_list) 107 | unique_id_train = unique_id 108 | elif vid_id in TEST_NUMBERS: 109 | test_pkl[video_id] = frame_infos 110 | TEST_FRAME_NUMBERS += len(frames_list) 111 | unique_id_test = unique_id 112 | elif vid_id in
INFER_NUMBERS: 113 | infer_pkl[video_id] = frame_infos 114 | INFER_FRAME_NUMBERS += len(frames_list) 115 | unique_id_infer = unique_id 116 | 117 | train_save_dir = os.path.join(ROOT_DIR, 'labels', 'train') 118 | os.makedirs(train_save_dir, exist_ok=True) 119 | with open(os.path.join(train_save_dir, '1fpstrain.pickle'), 'wb') as file: 120 | pickle.dump(train_pkl, file) 121 | 122 | test_save_dir = os.path.join(ROOT_DIR, 'labels', 'test') 123 | os.makedirs(test_save_dir, exist_ok=True) 124 | with open(os.path.join(test_save_dir, '1fpstest.pickle'), 'wb') as file: 125 | pickle.dump(test_pkl, file) 126 | 127 | infer_save_dir = os.path.join(ROOT_DIR, 'labels', 'infer') 128 | os.makedirs(infer_save_dir, exist_ok=True) 129 | with open(os.path.join(infer_save_dir, '1fpsinfer.pickle'), 'wb') as file: 130 | pickle.dump(infer_pkl, file) 131 | 132 | print('TRAIN Frames', TRAIN_FRAME_NUMBERS, unique_id_train) 133 | print('TEST Frames', TEST_FRAME_NUMBERS, unique_id_test) 134 | print('INFER Frames', INFER_FRAME_NUMBERS, unique_id_infer) 135 | print('Average Length', (TRAIN_FRAME_NUMBERS)/len(TRAIN_NUMBERS)) 136 | 137 | if __name__ == "__main__": 138 | main() 139 | -------------------------------------------------------------------------------- /datasets/extract_frames/extract_frames_autolaparo.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import cv2 4 | from tqdm import tqdm 5 | 6 | ROOT_DIR = "/xxxxxx/AutoLaparo" 7 | VIDEO_NAMES = os.listdir(os.path.join(ROOT_DIR, "videos")) 8 | VIDEO_NAMES = sorted([x for x in VIDEO_NAMES if 'mp4' in x]) 9 | 10 | FRAME_NUMBERS = 0 11 | 12 | for video_name in VIDEO_NAMES: 13 | print(video_name) 14 | vidcap = cv2.VideoCapture(os.path.join(ROOT_DIR, "videos", video_name)) 15 | fps = vidcap.get(cv2.CAP_PROP_FPS) 16 | print("fps", fps) 17 | success=True 18 | count=0 19 | save_dir = './frames/' + video_name.replace('.mp4', '') +'/' 20 | save_dir = os.path.join(ROOT_DIR, save_dir) 21 | os.makedirs(save_dir, exist_ok=True) 22 | while success is True: 23 | success,image = vidcap.read() 24 | if success: 25 | if count % fps == 0: 26 | cv2.imwrite(save_dir + str(int(count//fps)).zfill(5) + '.png', image) 27 | count+=1 28 | vidcap.release() 29 | cv2.destroyAllWindows() 30 | print(count) 31 | FRAME_NUMBERS += count 32 | 33 | print('Total Frames', FRAME_NUMBERS) 34 | -------------------------------------------------------------------------------- /datasets/extract_frames/extract_frames_ch80.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import cv2 4 | from tqdm import tqdm 5 | 6 | ROOT_DIR = "/xxxxxxx/cholec80" 7 | VIDEO_NAMES = os.listdir(os.path.join(ROOT_DIR, 'videos')) 8 | VIDEO_NAMES = [x for x in VIDEO_NAMES if 'mp4' in x] 9 | TRAIN_NUMBERS = np.arange(1,41).tolist() 10 | VAL_NUMBERS = np.arange(41,49).tolist() 11 | TEST_NUMBERS = np.arange(49,81).tolist() 12 | 13 | TRAIN_FRAME_NUMBERS = 0 14 | VAL_FRAME_NUMBERS = 0 15 | TEST_FRAME_NUMBERS = 0 16 | 17 | for video_name in VIDEO_NAMES: 18 | print(video_name) 19 | vidcap = cv2.VideoCapture(os.path.join(ROOT_DIR, 'videos', video_name)) 20 | fps = vidcap.get(cv2.CAP_PROP_FPS) 21 | print("fps", fps) 22 | if fps != 25: 23 | print(video_name, 'not at 25fps', fps) 24 | success=True 25 | count=0 26 | vid_id = int(video_name.replace('.mp4', '').replace("video", "")) 27 | if vid_id in TRAIN_NUMBERS: 28 | save_dir = './frames/train/' + video_name.replace('.mp4', '') +'/' 29 | elif
vid_id in VAL_NUMBERS: 30 | save_dir = './frames/val/' + video_name.replace('.mp4', '') +'/' 31 | elif vid_id in TEST_NUMBERS: 32 | save_dir = './frames/test/' + video_name.replace('.mp4', '') +'/' 33 | save_dir = os.path.join(ROOT_DIR, save_dir) 34 | os.makedirs(save_dir, exist_ok=True) 35 | while success is True: 36 | success,image = vidcap.read() 37 | if success: 38 | cv2.imwrite(save_dir + str(count) + '.jpg', image) 39 | count+=1 40 | vidcap.release() 41 | cv2.destroyAllWindows() 42 | if vid_id in TRAIN_NUMBERS: 43 | TRAIN_FRAME_NUMBERS += count 44 | elif vid_id in VAL_NUMBERS: 45 | VAL_FRAME_NUMBERS += count 46 | elif vid_id in TEST_NUMBERS: 47 | TEST_FRAME_NUMBERS += count 48 | 49 | print('TRAIN Frames', TRAIN_FRAME_NUMBERS) 50 | print('VAL Frames', VAL_FRAME_NUMBERS) 51 | print('TEST Frames', TEST_FRAME_NUMBERS) -------------------------------------------------------------------------------- /datasets/functional.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | import cv2 3 | import numpy as np 4 | import PIL 5 | import torch 6 | 7 | 8 | def _is_tensor_clip(clip): 9 | return torch.is_tensor(clip) and clip.ndimension() == 4 10 | 11 | 12 | def crop_clip(clip, min_h, min_w, h, w): 13 | if isinstance(clip[0], np.ndarray): 14 | cropped = [img[min_h : min_h + h, min_w : min_w + w, :] for img in clip] 15 | 16 | elif isinstance(clip[0], PIL.Image.Image): 17 | cropped = [img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip] 18 | else: 19 | raise TypeError( 20 | "Expected numpy.ndarray or PIL.Image" 21 | + "but got list of {0}".format(type(clip[0])) 22 | ) 23 | return cropped 24 | 25 | 26 | def resize_clip(clip, size, interpolation="bilinear"): 27 | if isinstance(clip[0], np.ndarray): 28 | if isinstance(size, numbers.Number): 29 | im_h, im_w, im_c = clip[0].shape 30 | # Resize by the shorter side; if the shorter side already equals the target size, return the input unchanged 31 | if (im_w <= im_h and im_w == size) or (im_h <= im_w and im_h == size): 32 | return clip 33 | new_h, new_w = get_resize_sizes(im_h, im_w, size) 34 | size = (new_w, new_h) 35 | else: 36 | size = size[0], size[1] 37 | if interpolation == "bilinear": 38 | np_inter = cv2.INTER_LINEAR 39 | else: 40 | np_inter = cv2.INTER_NEAREST 41 | scaled = [cv2.resize(img, size, interpolation=np_inter) for img in clip] 42 | elif isinstance(clip[0], PIL.Image.Image): 43 | if isinstance(size, numbers.Number): 44 | im_w, im_h = clip[0].size 45 | # Min spatial dim already matches minimal size 46 | if (im_w <= im_h and im_w == size) or (im_h <= im_w and im_h == size): 47 | return clip 48 | new_h, new_w = get_resize_sizes(im_h, im_w, size) 49 | size = (new_w, new_h) 50 | else: 51 | size = size[1], size[0] 52 | if interpolation == "bilinear": 53 | pil_inter = PIL.Image.BILINEAR 54 | else: 55 | pil_inter = PIL.Image.NEAREST 56 | scaled = [img.resize(size, pil_inter) for img in clip] 57 | else: 58 | raise TypeError( 59 | "Expected numpy.ndarray or PIL.Image" 60 | + "but got list of {0}".format(type(clip[0])) 61 | ) 62 | return scaled 63 | 64 | 65 | def get_resize_sizes(im_h, im_w, size): 66 | if im_w < im_h: 67 | ow = size 68 | oh = int(size * im_h / im_w) 69 | else: 70 | oh = size 71 | ow = int(size * im_w / im_h) 72 | return oh, ow 73 | 74 | 75 | def normalize(clip, mean, std, inplace=False): 76 | if not _is_tensor_clip(clip): 77 | raise TypeError("tensor is not a torch clip.") 78 | 79 | if not inplace: 80 | clip = clip.clone() 81 | 82 | dtype = clip.dtype 83 | mean = torch.as_tensor(mean, dtype=dtype, 
device=clip.device) 84 | std = torch.as_tensor(std, dtype=dtype, device=clip.device) 85 | clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None]) 86 | 87 | return clip 88 | -------------------------------------------------------------------------------- /datasets/phase/__pycache__/AutoLaparo_phase.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/datasets/phase/__pycache__/AutoLaparo_phase.cpython-310.pyc -------------------------------------------------------------------------------- /datasets/phase/__pycache__/AutoLaparo_phase.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/datasets/phase/__pycache__/AutoLaparo_phase.cpython-311.pyc -------------------------------------------------------------------------------- /datasets/phase/__pycache__/Cholec80_phase.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/datasets/phase/__pycache__/Cholec80_phase.cpython-310.pyc -------------------------------------------------------------------------------- /datasets/phase/__pycache__/Cholec80_phase.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/datasets/phase/__pycache__/Cholec80_phase.cpython-311.pyc -------------------------------------------------------------------------------- /datasets/phase/__pycache__/PmLR50_phase.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/datasets/phase/__pycache__/PmLR50_phase.cpython-311.pyc -------------------------------------------------------------------------------- /datasets/tools/frame_cutmargin.py: -------------------------------------------------------------------------------- 1 | # ----------------------------- 2 | # Cut black margin for surgical video 3 | # Copyright (c) CUHK 2021. 
4 | # IEEE TMI 'Temporal Relation Network for Workflow Recognition from Surgical Video' 5 | # ----------------------------- 6 | 7 | import cv2 8 | import os 9 | import numpy as np 10 | import multiprocessing 11 | from tqdm import tqdm 12 | 13 | 14 | def create_directory_if_not_exists(path): 15 | if not os.path.exists(path): 16 | os.makedirs(path) 17 | 18 | 19 | def filter_black(image): 20 | binary_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) 21 | _, binary_image2 = cv2.threshold(binary_image, 15, 255, cv2.THRESH_BINARY) 22 | binary_image2 = cv2.medianBlur( 23 | binary_image2, 19 24 | ) # filter the noise, need to adjust the parameter based on the dataset 25 | x = binary_image2.shape[0] 26 | y = binary_image2.shape[1] 27 | 28 | edges_x = [] # row indices of non-black pixels 29 | edges_y = [] # column indices of non-black pixels 30 | for i in range(x): 31 | for j in range(10, y - 10): 32 | if binary_image2.item(i, j) != 0: 33 | edges_x.append(i) 34 | edges_y.append(j) 35 | 36 | if not edges_x: 37 | return image 38 | 39 | top = min(edges_x) # first non-black row 40 | bottom = max(edges_x) # last non-black row 41 | height = bottom - top 42 | left = min(edges_y) # first non-black column 43 | right = max(edges_y) # last non-black column 44 | width = right - left 45 | 46 | pre1_picture = image[top : top + height, left : left + width] 47 | 48 | return pre1_picture 49 | 50 | 51 | def process_image(image_source, image_save): 52 | frame = cv2.imread(image_source) 53 | dim = (int(frame.shape[1] / frame.shape[0] * 300), 300) 54 | frame = cv2.resize(frame, dim) 55 | frame = filter_black(frame) 56 | 57 | img_result = cv2.resize(frame, (250, 250)) 58 | cv2.imwrite(image_save, img_result) 59 | 60 | 61 | def process_video(video_id, video_source, video_save): 62 | create_directory_if_not_exists(video_save) 63 | 64 | for image_id in sorted(os.listdir(video_source)): 65 | if image_id == ".DS_Store": 66 | continue 67 | image_source = os.path.join(video_source, image_id) 68 | image_save = os.path.join(video_save, image_id) 69 | 70 | process_image(image_source, image_save) 71 | 72 | 73 | if __name__ == "__main__": 74 | # source_path = "/jhcnas1/yangshu/data/cholec80/frames" # original path 75 | # save_path = "/jhcnas1/yangshu/data/cholec80/frames_cutmargin" # save path 76 | 77 | source_path = "data/cholec80/frames" # original path 78 | save_path = "data/cholec80/frames_cutmargin" # save path 79 | 80 | create_directory_if_not_exists(save_path) 81 | 82 | processes = [] 83 | 84 | for data_split in os.listdir(source_path): 85 | if data_split == ".DS_Store": 86 | continue 87 | data_source = os.path.join(source_path, data_split) 88 | data_save = os.path.join(save_path, data_split) 89 | 90 | for video_id in tqdm(os.listdir(data_source)): 91 | if video_id == ".DS_Store": 92 | continue 93 | video_source = os.path.join(data_source, video_id) 94 | video_save = os.path.join(data_save, video_id) 95 | 96 | process = multiprocessing.Process( 97 | target=process_video, args=(video_id, video_source, video_save) 98 | ) 99 | process.start() 100 | processes.append(process) 101 | 102 | for process in processes: 103 | process.join() 104 | 105 | print("Cut Done") 106 | -------------------------------------------------------------------------------- /datasets/tools/resize_frame.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from tqdm import tqdm 4 | def resize_dataset(input_folder, output_folder, target_size=320): 5 | # Ensure the output folder exists 6 | if not os.path.exists(output_folder): 7 | os.makedirs(output_folder) 8 | 9 | # Walk through every subfolder of the input folder 10 | for subdir, dirs, files in 
os.walk(input_folder): 11 | # Get the path relative to the input folder 12 | rel_path = os.path.relpath(subdir, input_folder) 13 | # Create the matching output subfolder 14 | output_subdir = os.path.join(output_folder, rel_path) 15 | if not os.path.exists(output_subdir): 16 | os.makedirs(output_subdir) 17 | # Iterate over all image files in the current subfolder 18 | for file in tqdm(files): 19 | # Only process image files 20 | if file.endswith('.jpg') or file.endswith('.png'): 21 | # Build input and output file paths 22 | input_path = os.path.join(subdir, file) 23 | output_path = os.path.join(output_subdir, file) 24 | # Open the image and resize it so the shorter side equals target_size 25 | img = Image.open(input_path) 26 | width, height = img.size 27 | if width < height: 28 | new_width = target_size 29 | new_height = int(height * target_size / width) 30 | else: 31 | new_height = target_size 32 | new_width = int(width * target_size / height) 33 | size = (new_width, new_height) 34 | img = img.resize(size, resample=Image.BILINEAR) 35 | 36 | # Save the image to the output folder 37 | img.save(output_path) 38 | print(f'Processed: {subdir} -> {output_subdir}') 39 | print(len(os.listdir(subdir)), len(os.listdir(output_subdir))) 40 | 41 | if __name__ == "__main__": 42 | # Input and output folders 43 | input_folder = '/jhcnas4/syangcw/AutoLaparo/frames' 44 | output_folder = '/jhcnas4/syangcw/AutoLaparo/frames_resized' 45 | resize_dataset(input_folder, output_folder, 250) -------------------------------------------------------------------------------- /datasets/tools/transfer_csv_txt.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os 3 | from collections import defaultdict 4 | import cv2 5 | csv_phase = "/Users/yangshu/Downloads/cataract-101/annotations.csv" 6 | save_txt_folder = "/Users/yangshu/Downloads/cataract-101/phase_annotations" 7 | # Read the CSV file 8 | videos = defaultdict(list) 9 | with open(csv_phase, "r") as csvfile: 10 | reader = csv.reader(csvfile) 11 | next(reader) # skip the header row 12 | for row in reader: 13 | video_id, frame_number, phase_id = row[0].split(";") 14 | # Group the annotation rows by video; the per-frame txt files are written below 15 | videos[video_id].append([video_id, frame_number, phase_id]) 16 | 17 | count_1 = 0 18 | count_2 = 0 19 | for video_id in videos.keys(): 20 | video_detail = videos[video_id] 21 | 22 | frames_position = [int(k[1]) for k in video_detail] 23 | phase_id = [int(k[2]) for k in video_detail] 24 | if phase_id[-1] == phase_id[-2]: 25 | count_1 += 1 26 | txt_filename = os.path.join(save_txt_folder, f"{video_id}.txt") 27 | 28 | # Open the txt file and write the per-frame annotations 29 | with open(txt_filename, "a") as txtfile: 30 | txtfile.write("Frame\tPhase\n") # write the header row 31 | phase = phase_id[0] 32 | for frame in range(frames_position[0], frames_position[-1]+1): 33 | if frame in frames_position: 34 | position = frames_position.index(frame) 35 | phase = phase_id[position] 36 | txtfile.write(f"{frame}\t{phase}\n") 37 | else: 38 | count_2 += 1 39 | txt_filename = os.path.join(save_txt_folder, f"{video_id}.txt") 40 | 41 | frame_folder = os.path.join("/Users/yangshu/Downloads/cataract-101/videos/", "case_" + str(video_id)+".mp4") 42 | vidcap = cv2.VideoCapture(frame_folder) 43 | length = int(vidcap.get(7)) # property 7 = CAP_PROP_FRAME_COUNT 44 | # Open the txt file and write the per-frame annotations 45 | with open(txt_filename, "a") as txtfile: 46 | txtfile.write("Frame\tPhase\n") # write the header row 47 | phase = phase_id[0] 48 | for frame in range(frames_position[0], length): 49 | if frame in frames_position: 50 | position = frames_position.index(frame) 51 | phase = phase_id[position] 52 | txtfile.write(f"{frame}\t{phase}\n") 53 | print(count_1) 54 | print(count_2) 55 | -------------------------------------------------------------------------------- /datasets/tools/transfer_json_txt.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | def read_json_file(file_path): 5 | with open(file_path, 'r') as file: 6 | data = json.load(file) 7 | return data 8 | 9 | dataset_path = "/Users/yangshu/Downloads/CholecT50" 10 | labels_path = "/Users/yangshu/Downloads/CholecT50/labels" 11 | save_path = "/Users/yangshu/Downloads/CholecT50/verb" 12 | 13 | if not os.path.exists(save_path): 14 | os.makedirs(save_path) 15 | 16 | video_labels = sorted(os.listdir(labels_path)) 17 | 18 | total_num_triplet = 0 19 | 20 | for video_label in video_labels: 21 | if "DS" in video_label: 22 | continue 23 | data_labels = dict() 24 | video_name = video_label.split('.')[0] 25 | txt_filename = os.path.join(save_path, video_name + ".txt") 26 | # Read the JSON file 27 | label_path = os.path.join(labels_path, video_label) 28 | file_path = os.path.join(save_path, video_label) 29 | data = read_json_file(label_path) 30 | categories = data['categories'] 31 | instrument = categories['instrument'] 32 | verb = categories['verb'] 33 | target = categories['target'] 34 | triplet = categories['triplet'] 35 | phase = categories['phase'] 36 | 37 | rs = dict() 38 | anns = data['annotations'] 39 | sorted_anns = dict(sorted(anns.items(), key=lambda x: int(x[0]))) 40 | for k, v in sorted_anns.items(): 41 | r = list() 42 | for instance in v: 43 | if instance[0] == -1: 44 | continue 45 | # triplet = instance[0] 46 | # r.append(str(triplet)) 47 | # instrument = instance[1] 48 | # r.append(str(instrument)) 49 | verb = instance[7] 50 | r.append(str(verb)) 51 | # target = instance[8] 52 | # r.append(str(target)) 53 | # phase = instance[14] 54 | # r.append(str(phase)) 55 | 56 | # if len(r) >1: 57 | # print(k,video_name) 58 | # if len(set(r)) != len(r): 59 | # print(r) 60 | # print('======') 61 | # rs[k] = set(r) 62 | rs[k] = r 63 | with open(txt_filename, "a") as txtfile: 64 | txtfile.write("Frame\tVerb\n") # write the header row 65 | for frame, triplet in rs.items(): 66 | if len(triplet) == 0: 67 | target = -1 68 | else: 69 | target = ",".join(triplet) 70 | txtfile.write(f"{frame}\t{target}\n") -------------------------------------------------------------------------------- /datasets/transforms/__pycache__/mixup.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/datasets/transforms/__pycache__/optim_factory.cpython-311.pyc -------------------------------------------------------------------------------- /datasets/transforms/__pycache__/rand_augment.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/datasets/transforms/__pycache__/rand_augment.cpython-310.pyc -------------------------------------------------------------------------------- /datasets/transforms/__pycache__/rand_augment.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/datasets/transforms/__pycache__/rand_augment.cpython-311.pyc -------------------------------------------------------------------------------- /datasets/transforms/__pycache__/random_erasing.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/datasets/transforms/__pycache__/random_erasing.cpython-310.pyc -------------------------------------------------------------------------------- /datasets/transforms/__pycache__/random_erasing.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/datasets/transforms/__pycache__/random_erasing.cpython-311.pyc -------------------------------------------------------------------------------- /datasets/transforms/__pycache__/surg_transforms.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/datasets/transforms/__pycache__/surg_transforms.cpython-310.pyc -------------------------------------------------------------------------------- /datasets/transforms/__pycache__/surg_transforms.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/datasets/transforms/__pycache__/surg_transforms.cpython-311.pyc -------------------------------------------------------------------------------- /datasets/transforms/__pycache__/video_transforms.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/datasets/transforms/__pycache__/video_transforms.cpython-310.pyc -------------------------------------------------------------------------------- /datasets/transforms/__pycache__/video_transforms.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/datasets/transforms/__pycache__/video_transforms.cpython-311.pyc -------------------------------------------------------------------------------- /datasets/transforms/__pycache__/volume_transforms.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/datasets/transforms/__pycache__/volume_transforms.cpython-310.pyc -------------------------------------------------------------------------------- /datasets/transforms/__pycache__/volume_transforms.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/datasets/transforms/__pycache__/volume_transforms.cpython-311.pyc -------------------------------------------------------------------------------- /datasets/transforms/image_transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image, ImageOps 3 | import torch 4 | import numbers 5 | import random 6 | import torchvision.transforms.functional as TF 7 | 8 | 9 | class RandomCrop(object): 10 | 11 | def __init__(self, size, padding=0): 12 | if isinstance(size, numbers.Number): 13 | self.size = (int(size), int(size)) 14 | else: 15 | self.size = size 16 | self.padding = padding 17 | self.count = 0 18 | 19 | def __call__(self, img): 20 | 21 | if self.padding > 0: 22 | img = ImageOps.expand(img, border=self.padding, fill=0) 23 | 24 | w, h = img.size 25 | th, tw = self.size 26 | if w == tw and h == th: 27 | return img 28 | 29 | random.seed(self.count) 30 | x1 = random.randint(0, w - tw) 31 | y1 = random.randint(0, h - th) 32 | # print(self.count, x1, y1) 33 | self.count += 1 34 | return img.crop((x1, y1, x1 + tw, y1 + th)) 35 | 36 | 37 | class RandomHorizontalFlip(object): 38 | def __init__(self): 39 | self.count = 0 40 | 41 | def __call__(self, img): 42 | seed = self.count 43 | random.seed(seed) 44 | prob = random.random() 45 | self.count += 1 46 | # print(self.count, seed, prob) 47 | if prob < 0.5: 48 | return img.transpose(Image.FLIP_LEFT_RIGHT) 49 | return img 50 | 51 | 52 | class RandomRotation(object): 53 | def __init__(self,degrees): 54 | self.degrees = degrees 55 | self.count = 0 56 | 57 | def __call__(self, img): 58 | seed = self.count 59 | random.seed(seed) 60 | self.count += 1 61 | angle = random.randint(-self.degrees,self.degrees) 62 | return TF.rotate(img, angle) 63 | 64 | 65 | class ColorJitter(object): 66 | def __init__(self,brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1): 67 | self.brightness = brightness 68 | self.contrast = contrast 69 | self.saturation = saturation 70 | self.hue = hue 71 | self.count = 0 72 | 73 | def __call__(self, img): 74 | seed = self.count 75 | random.seed(seed) 76 | self.count += 1 77 | brightness_factor = random.uniform(1 - self.brightness, 1 + self.brightness) 78 | contrast_factor = random.uniform(1 - self.contrast, 1 + self.contrast) 79 | saturation_factor = random.uniform(1 - self.saturation, 1 + self.saturation) 80 | hue_factor = random.uniform(- self.hue, self.hue) 81 | 82 | img_ = TF.adjust_brightness(img,brightness_factor) 83 | img_ = TF.adjust_contrast(img_,contrast_factor) 84 | img_ = TF.adjust_saturation(img_,saturation_factor) 85 | img_ = TF.adjust_hue(img_,hue_factor) 86 | 87 | return img_ -------------------------------------------------------------------------------- /datasets/transforms/mixup.py: -------------------------------------------------------------------------------- 1 | """ Mixup and Cutmix 2 | 3 | Papers: 4 | mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412) 5 | 6 | CutMix: Regularization Strategy to Train Strong Classifiers with Localizable 
Features (https://arxiv.org/abs/1905.04899) 7 | 8 | Code Reference: 9 | CutMix: https://github.com/clovaai/CutMix-PyTorch 10 | 11 | Hacked together by / Copyright 2019, Ross Wightman 12 | """ 13 | import numpy as np 14 | import torch 15 | 16 | 17 | def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'): 18 | x = x.long().view(-1, 1) 19 | return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value) 20 | 21 | 22 | def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'): 23 | off_value = smoothing / num_classes 24 | on_value = 1. - smoothing + off_value 25 | y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device) 26 | y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device) 27 | return y1 * lam + y2 * (1. - lam) 28 | 29 | 30 | def rand_bbox(img_shape, lam, margin=0., count=None): 31 | """ Standard CutMix bounding-box 32 | Generates a random square bbox based on lambda value. This impl includes 33 | support for enforcing a border margin as percent of bbox dimensions. 34 | 35 | Args: 36 | img_shape (tuple): Image shape as tuple 37 | lam (float): Cutmix lambda value 38 | margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image) 39 | count (int): Number of bbox to generate 40 | """ 41 | ratio = np.sqrt(1 - lam) 42 | img_h, img_w = img_shape[-2:] 43 | cut_h, cut_w = int(img_h * ratio), int(img_w * ratio) 44 | margin_y, margin_x = int(margin * cut_h), int(margin * cut_w) 45 | cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count) 46 | cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count) 47 | yl = np.clip(cy - cut_h // 2, 0, img_h) 48 | yh = np.clip(cy + cut_h // 2, 0, img_h) 49 | xl = np.clip(cx - cut_w // 2, 0, img_w) 50 | xh = np.clip(cx + cut_w // 2, 0, img_w) 51 | return yl, yh, xl, xh 52 | 53 | 54 | def rand_bbox_minmax(img_shape, minmax, count=None): 55 | """ Min-Max CutMix bounding-box 56 | Inspired by Darknet cutmix impl, generates a random rectangular bbox 57 | based on min/max percent values applied to each dimension of the input image. 58 | 59 | Typical defaults for minmax are usually in the .2-.3 for min and .8-.9 range for max. 60 | 61 | Args: 62 | img_shape (tuple): Image shape as tuple 63 | minmax (tuple or list): Min and max bbox ratios (as percent of image size) 64 | count (int): Number of bbox to generate 65 | """ 66 | assert len(minmax) == 2 67 | img_h, img_w = img_shape[-2:] 68 | cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count) 69 | cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count) 70 | yl = np.random.randint(0, img_h - cut_h, size=count) 71 | xl = np.random.randint(0, img_w - cut_w, size=count) 72 | yu = yl + cut_h 73 | xu = xl + cut_w 74 | return yl, yu, xl, xu 75 | 76 | 77 | def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None): 78 | """ Generate bbox and apply lambda correction. 79 | """ 80 | if ratio_minmax is not None: 81 | yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count) 82 | else: 83 | yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count) 84 | if correct_lam or ratio_minmax is not None: 85 | bbox_area = (yu - yl) * (xu - xl) 86 | lam = 1. 
- bbox_area / float(img_shape[-2] * img_shape[-1]) 87 | return (yl, yu, xl, xu), lam 88 | 89 | 90 | class Mixup: 91 | """ Mixup/Cutmix that applies different params to each element or whole batch 92 | 93 | Args: 94 | mixup_alpha (float): mixup alpha value, mixup is active if > 0. 95 | cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0. 96 | cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None. 97 | prob (float): probability of applying mixup or cutmix per batch or element 98 | switch_prob (float): probability of switching to cutmix instead of mixup when both are active 99 | mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), 'elem' (element)) 100 | correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders 101 | label_smoothing (float): apply label smoothing to the mixed target tensor 102 | num_classes (int): number of classes for target 103 | """ 104 | def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5, 105 | mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000): 106 | self.mixup_alpha = mixup_alpha 107 | self.cutmix_alpha = cutmix_alpha 108 | self.cutmix_minmax = cutmix_minmax 109 | if self.cutmix_minmax is not None: 110 | assert len(self.cutmix_minmax) == 2 111 | # force cutmix alpha == 1.0 when minmax active to keep logic simple & safe 112 | self.cutmix_alpha = 1.0 113 | self.mix_prob = prob 114 | self.switch_prob = switch_prob 115 | self.label_smoothing = label_smoothing 116 | self.num_classes = num_classes 117 | self.mode = mode 118 | self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix 119 | self.mixup_enabled = True # set to false to disable mixing (intended to be set by train loop) 120 | 121 | def _params_per_elem(self, batch_size): 122 | lam = np.ones(batch_size, dtype=np.float32) 123 | use_cutmix = np.zeros(batch_size, dtype=bool) # np.bool was removed in NumPy 1.24; use the builtin bool dtype 124 | if self.mixup_enabled: 125 | if self.mixup_alpha > 0. and self.cutmix_alpha > 0.: 126 | use_cutmix = np.random.rand(batch_size) < self.switch_prob 127 | lam_mix = np.where( 128 | use_cutmix, 129 | np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size), 130 | np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size)) 131 | elif self.mixup_alpha > 0.: 132 | lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size) 133 | elif self.cutmix_alpha > 0.: 134 | use_cutmix = np.ones(batch_size, dtype=bool) 135 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size) 136 | else: 137 | assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true." 138 | lam = np.where(np.random.rand(batch_size) < self.mix_prob, lam_mix.astype(np.float32), lam) 139 | return lam, use_cutmix 140 | 141 | def _params_per_batch(self): 142 | lam = 1. 143 | use_cutmix = False 144 | if self.mixup_enabled and np.random.rand() < self.mix_prob: 145 | if self.mixup_alpha > 0. 
and self.cutmix_alpha > 0.: 146 | use_cutmix = np.random.rand() < self.switch_prob 147 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \ 148 | np.random.beta(self.mixup_alpha, self.mixup_alpha) 149 | elif self.mixup_alpha > 0.: 150 | lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha) 151 | elif self.cutmix_alpha > 0.: 152 | use_cutmix = True 153 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) 154 | else: 155 | assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true." 156 | lam = float(lam_mix) 157 | return lam, use_cutmix 158 | 159 | def _mix_elem(self, x): 160 | batch_size = len(x) 161 | lam_batch, use_cutmix = self._params_per_elem(batch_size) 162 | x_orig = x.clone() # need to keep an unmodified original for mixing source 163 | for i in range(batch_size): 164 | j = batch_size - i - 1 165 | lam = lam_batch[i] 166 | if lam != 1.: 167 | if use_cutmix[i]: 168 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( 169 | x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) 170 | x[i][..., yl:yh, xl:xh] = x_orig[j][..., yl:yh, xl:xh] 171 | lam_batch[i] = lam 172 | else: 173 | x[i] = x[i] * lam + x_orig[j] * (1 - lam) 174 | return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1) 175 | 176 | def _mix_pair(self, x): 177 | batch_size = len(x) 178 | lam_batch, use_cutmix = self._params_per_elem(batch_size // 2) 179 | x_orig = x.clone() # need to keep an unmodified original for mixing source 180 | for i in range(batch_size // 2): 181 | j = batch_size - i - 1 182 | lam = lam_batch[i] 183 | if lam != 1.: 184 | if use_cutmix[i]: 185 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( 186 | x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) 187 | x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh] 188 | x[j][:, yl:yh, xl:xh] = x_orig[i][:, yl:yh, xl:xh] 189 | lam_batch[i] = lam 190 | else: 191 | x[i] = x[i] * lam + x_orig[j] * (1 - lam) 192 | x[j] = x[j] * lam + x_orig[i] * (1 - lam) 193 | lam_batch = np.concatenate((lam_batch, lam_batch[::-1])) 194 | return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1) 195 | 196 | def _mix_batch(self, x): 197 | lam, use_cutmix = self._params_per_batch() 198 | if lam == 1.: 199 | return 1. 200 | if use_cutmix: 201 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( 202 | x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) 203 | x[..., yl:yh, xl:xh] = x.flip(0)[..., yl:yh, xl:xh] 204 | else: 205 | x_flipped = x.flip(0).mul_(1. - lam) 206 | x.mul_(lam).add_(x_flipped) 207 | return lam 208 | 209 | def __call__(self, x, target): 210 | assert len(x) % 2 == 0, 'Batch size should be even when using this' 211 | if self.mode == 'elem': 212 | lam = self._mix_elem(x) 213 | elif self.mode == 'pair': 214 | lam = self._mix_pair(x) 215 | else: 216 | lam = self._mix_batch(x) 217 | target = mixup_target(target, self.num_classes, lam, self.label_smoothing, x.device) 218 | return x, target 219 | 220 | 221 | class FastCollateMixup(Mixup): 222 | """ Fast Collate w/ Mixup/Cutmix that applies different params to each element or whole batch 223 | 224 | A Mixup impl that's performed while collating the batches. 
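Typically used as the collate_fn of a torch.utils.data.DataLoader, so mixing runs on the raw uint8 numpy frames before normalization. A minimal sketch (the dataset, batch size and class count here are illustrative, not from this repo): loader = DataLoader(train_set, batch_size=32, collate_fn=FastCollateMixup(mixup_alpha=0.8, num_classes=7)).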
225 | """ 226 | 227 | def _mix_elem_collate(self, output, batch, half=False): 228 | batch_size = len(batch) 229 | num_elem = batch_size // 2 if half else batch_size 230 | assert len(output) == num_elem 231 | lam_batch, use_cutmix = self._params_per_elem(num_elem) 232 | for i in range(num_elem): 233 | j = batch_size - i - 1 234 | lam = lam_batch[i] 235 | mixed = batch[i][0] 236 | if lam != 1.: 237 | if use_cutmix[i]: 238 | if not half: 239 | mixed = mixed.copy() 240 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( 241 | output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) 242 | mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh] 243 | lam_batch[i] = lam 244 | else: 245 | mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam) 246 | np.rint(mixed, out=mixed) 247 | output[i] += torch.from_numpy(mixed.astype(np.uint8)) 248 | if half: 249 | lam_batch = np.concatenate((lam_batch, np.ones(num_elem))) 250 | return torch.tensor(lam_batch).unsqueeze(1) 251 | 252 | def _mix_pair_collate(self, output, batch): 253 | batch_size = len(batch) 254 | lam_batch, use_cutmix = self._params_per_elem(batch_size // 2) 255 | for i in range(batch_size // 2): 256 | j = batch_size - i - 1 257 | lam = lam_batch[i] 258 | mixed_i = batch[i][0] 259 | mixed_j = batch[j][0] 260 | assert 0 <= lam <= 1.0 261 | if lam < 1.: 262 | if use_cutmix[i]: 263 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( 264 | output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) 265 | patch_i = mixed_i[:, yl:yh, xl:xh].copy() 266 | mixed_i[:, yl:yh, xl:xh] = mixed_j[:, yl:yh, xl:xh] 267 | mixed_j[:, yl:yh, xl:xh] = patch_i 268 | lam_batch[i] = lam 269 | else: 270 | mixed_temp = mixed_i.astype(np.float32) * lam + mixed_j.astype(np.float32) * (1 - lam) 271 | mixed_j = mixed_j.astype(np.float32) * lam + mixed_i.astype(np.float32) * (1 - lam) 272 | mixed_i = mixed_temp 273 | np.rint(mixed_j, out=mixed_j) 274 | np.rint(mixed_i, out=mixed_i) 275 | output[i] += torch.from_numpy(mixed_i.astype(np.uint8)) 276 | output[j] += torch.from_numpy(mixed_j.astype(np.uint8)) 277 | lam_batch = np.concatenate((lam_batch, lam_batch[::-1])) 278 | return torch.tensor(lam_batch).unsqueeze(1) 279 | 280 | def _mix_batch_collate(self, output, batch): 281 | batch_size = len(batch) 282 | lam, use_cutmix = self._params_per_batch() 283 | if use_cutmix: 284 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( 285 | output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) 286 | for i in range(batch_size): 287 | j = batch_size - i - 1 288 | mixed = batch[i][0] 289 | if lam != 1.: 290 | if use_cutmix: 291 | mixed = mixed.copy() # don't want to modify the original while iterating 292 | mixed[..., yl:yh, xl:xh] = batch[j][0][..., yl:yh, xl:xh] 293 | else: 294 | mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam) 295 | np.rint(mixed, out=mixed) 296 | output[i] += torch.from_numpy(mixed.astype(np.uint8)) 297 | return lam 298 | 299 | def __call__(self, batch, _=None): 300 | batch_size = len(batch) 301 | assert batch_size % 2 == 0, 'Batch size should be even when using this' 302 | half = 'half' in self.mode 303 | if half: 304 | batch_size //= 2 305 | output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) 306 | if self.mode == 'elem' or self.mode == 'half': 307 | lam = self._mix_elem_collate(output, batch, half=half) 308 | elif self.mode == 'pair': 309 | lam = self._mix_pair_collate(output, batch) 310 | else: 311 | lam = 
self._mix_batch_collate(output, batch) 312 | target = torch.tensor([b[1] for b in batch], dtype=torch.int64) 313 | target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu') 314 | target = target[:batch_size] 315 | return output, target 316 | 317 | -------------------------------------------------------------------------------- /datasets/transforms/optim_factory.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import optim as optim 3 | 4 | from timm.optim.adafactor import Adafactor 5 | from timm.optim.adahessian import Adahessian 6 | from timm.optim.adamp import AdamP 7 | from timm.optim.lookahead import Lookahead 8 | from timm.optim.nadam import Nadam 9 | from timm.optim.novograd import NovoGrad 10 | from timm.optim.nvnovograd import NvNovoGrad 11 | from timm.optim.radam import RAdam 12 | from timm.optim.rmsprop_tf import RMSpropTF 13 | from timm.optim.sgdp import SGDP 14 | 15 | import json 16 | 17 | try: 18 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD 19 | has_apex = True 20 | except ImportError: 21 | has_apex = False 22 | 23 | 24 | def get_num_layer_for_vit(var_name, num_max_layer): 25 | if var_name in ("cls_token", "mask_token", "pos_embed"): 26 | return 0 27 | elif var_name.startswith("patch_embed"): 28 | return 0 29 | elif var_name.startswith("temporal_embedding"): 30 | return 0 31 | elif var_name.startswith("time_embed"): 32 | return 0 33 | elif var_name in ("class_embedding", "positional_embedding", "temporal_positional_embedding"): 34 | return 0 35 | elif var_name.startswith("conv1"): 36 | return 0 37 | elif var_name.startswith("rel_pos_bias"): 38 | return num_max_layer - 1 39 | elif var_name.startswith("blocks"): 40 | layer_id = int(var_name.split('.')[1]) 41 | return layer_id + 1 42 | elif var_name.startswith("transformer.resblocks"): 43 | layer_id = int(var_name.split('.')[2]) 44 | return layer_id + 1 45 | else: 46 | return num_max_layer - 1 47 | 48 | 49 | class LayerDecayValueAssigner(object): 50 | def __init__(self, values): 51 | self.values = values 52 | 53 | def get_scale(self, layer_id): 54 | return self.values[layer_id] 55 | 56 | def get_layer_id(self, var_name): 57 | return get_num_layer_for_vit(var_name, len(self.values)) 58 | 59 | 60 | def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None): 61 | parameter_group_names = {} 62 | parameter_group_vars = {} 63 | 64 | for name, param in model.named_parameters(): 65 | if not param.requires_grad: 66 | continue # frozen weights 67 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list: 68 | group_name = "no_decay" 69 | this_weight_decay = 0. 70 | else: 71 | group_name = "decay" 72 | this_weight_decay = weight_decay 73 | if get_num_layer is not None: 74 | layer_id = get_num_layer(name) 75 | group_name = "layer_%d_%s" % (layer_id, group_name) 76 | else: 77 | layer_id = None 78 | 79 | if group_name not in parameter_group_names: 80 | if get_layer_scale is not None: 81 | scale = get_layer_scale(layer_id) 82 | else: 83 | scale = 1. 
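# lr_scale is meant to be multiplied into this group's learning rate (layer-wise LR decay); without an assigner every group keeps the base LR (scale 1.0)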
84 | 85 | parameter_group_names[group_name] = { 86 | "weight_decay": this_weight_decay, 87 | "params": [], 88 | "lr_scale": scale 89 | } 90 | parameter_group_vars[group_name] = { 91 | "weight_decay": this_weight_decay, 92 | "params": [], 93 | "lr_scale": scale 94 | } 95 | 96 | parameter_group_vars[group_name]["params"].append(param) 97 | parameter_group_names[group_name]["params"].append(name) 98 | print("Param groups = %s" % json.dumps(parameter_group_names, indent=2)) 99 | return list(parameter_group_vars.values()) 100 | 101 | 102 | def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None): 103 | opt_lower = args.opt.lower() 104 | weight_decay = args.weight_decay 105 | if weight_decay and filter_bias_and_bn: 106 | skip = {} 107 | if skip_list is not None: 108 | skip = skip_list 109 | elif hasattr(model, 'no_weight_decay'): 110 | skip = model.no_weight_decay() 111 | parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale) 112 | weight_decay = 0. 113 | else: 114 | parameters = model.parameters() 115 | 116 | if 'fused' in opt_lower: 117 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers' 118 | 119 | opt_args = dict(lr=args.lr, weight_decay=weight_decay) 120 | if hasattr(args, 'opt_eps') and args.opt_eps is not None: 121 | opt_args['eps'] = args.opt_eps 122 | if hasattr(args, 'opt_betas') and args.opt_betas is not None: 123 | opt_args['betas'] = args.opt_betas 124 | 125 | print("optimizer settings:", opt_args) 126 | 127 | opt_split = opt_lower.split('_') 128 | opt_lower = opt_split[-1] 129 | if opt_lower == 'sgd' or opt_lower == 'nesterov': 130 | opt_args.pop('eps', None) 131 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) 132 | elif opt_lower == 'momentum': 133 | opt_args.pop('eps', None) 134 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) 135 | elif opt_lower == 'adam': 136 | optimizer = optim.Adam(parameters, **opt_args) 137 | elif opt_lower == 'adamw': 138 | optimizer = optim.AdamW(parameters, **opt_args) 139 | elif opt_lower == 'nadam': 140 | optimizer = Nadam(parameters, **opt_args) 141 | elif opt_lower == 'radam': 142 | optimizer = RAdam(parameters, **opt_args) 143 | elif opt_lower == 'adamp': 144 | optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args) 145 | elif opt_lower == 'sgdp': 146 | optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args) 147 | elif opt_lower == 'adadelta': 148 | optimizer = optim.Adadelta(parameters, **opt_args) 149 | elif opt_lower == 'adafactor': 150 | if not args.lr: 151 | opt_args['lr'] = None 152 | optimizer = Adafactor(parameters, **opt_args) 153 | elif opt_lower == 'adahessian': 154 | optimizer = Adahessian(parameters, **opt_args) 155 | elif opt_lower == 'rmsprop': 156 | optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args) 157 | elif opt_lower == 'rmsproptf': 158 | optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args) 159 | elif opt_lower == 'novograd': 160 | optimizer = NovoGrad(parameters, **opt_args) 161 | elif opt_lower == 'nvnovograd': 162 | optimizer = NvNovoGrad(parameters, **opt_args) 163 | elif opt_lower == 'fusedsgd': 164 | opt_args.pop('eps', None) 165 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) 166 | elif opt_lower == 'fusedmomentum': 167 | opt_args.pop('eps', None) 168 | optimizer = 
FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) 169 | elif opt_lower == 'fusedadam': 170 | optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args) 171 | elif opt_lower == 'fusedadamw': 172 | optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args) 173 | elif opt_lower == 'fusedlamb': 174 | optimizer = FusedLAMB(parameters, **opt_args) 175 | elif opt_lower == 'fusednovograd': 176 | opt_args.setdefault('betas', (0.95, 0.98)) 177 | optimizer = FusedNovoGrad(parameters, **opt_args) 178 | else: 179 | # unreachable unless an unknown optimizer name is passed 180 | raise ValueError("Invalid optimizer: %s" % opt_lower) 181 | 182 | if len(opt_split) > 1: 183 | if opt_split[0] == 'lookahead': 184 | optimizer = Lookahead(optimizer) 185 | 186 | return optimizer 187 | -------------------------------------------------------------------------------- /datasets/transforms/rand_augment.py: -------------------------------------------------------------------------------- 1 | """ 2 | This implementation is based on 3 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/auto_augment.py 4 | published under an Apache License 2.0. 5 | 6 | COMMENT FROM ORIGINAL: 7 | AutoAugment, RandAugment, and AugMix for PyTorch 8 | This code implements the searched ImageNet policies with various tweaks and 9 | improvements and does not include any of the search code. AA and RA 10 | Implementation adapted from: 11 | https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py 12 | AugMix adapted from: 13 | https://github.com/google-research/augmix 14 | Papers: 15 | AutoAugment: Learning Augmentation Policies from Data 16 | https://arxiv.org/abs/1805.09501 17 | Learning Data Augmentation Strategies for Object Detection 18 | https://arxiv.org/abs/1906.11172 19 | RandAugment: Practical automated data augmentation... 20 | https://arxiv.org/abs/1909.13719 21 | AugMix: A Simple Data Processing Method to Improve Robustness and 22 | Uncertainty https://arxiv.org/abs/1912.02781 23 | 24 | Hacked together by / Copyright 2020 Ross Wightman 25 | """ 26 | 27 | import math 28 | import numpy as np 29 | import random 30 | import re 31 | import PIL 32 | from PIL import Image, ImageEnhance, ImageOps 33 | 34 | _PIL_VER = tuple([int(x) for x in PIL.__version__.split(".")[:2]]) 35 | 36 | _FILL = (128, 128, 128) # gray fill 37 | # _FILL = (0, 0, 0) # black fill, consistent with the black cut-off regions of normal endoscopic frames 38 | 39 | # This signifies the max integer that the controller RNN could predict for the 40 | # augmentation scheme. 
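# (All the _level_to_arg helpers below divide by _MAX_LEVEL, so op magnitudes are expressed on a fixed 0-10 scale.)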
41 | _MAX_LEVEL = 10.0 42 | 43 | _HPARAMS_DEFAULT = { 44 | "translate_const": 250, 45 | "img_mean": _FILL, 46 | } 47 | 48 | _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) 49 | 50 | 51 | def _interpolation(kwargs): 52 | interpolation = kwargs.pop("resample", Image.BILINEAR) 53 | if isinstance(interpolation, (list, tuple)): 54 | return random.choice(interpolation) 55 | else: 56 | return interpolation 57 | 58 | 59 | def _check_args_tf(kwargs): 60 | if "fillcolor" in kwargs and _PIL_VER < (5, 0): 61 | kwargs.pop("fillcolor") 62 | kwargs["resample"] = _interpolation(kwargs) 63 | 64 | 65 | def shear_x(img, factor, **kwargs): 66 | _check_args_tf(kwargs) 67 | return img.transform( 68 | img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs 69 | ) 70 | 71 | 72 | def shear_y(img, factor, **kwargs): 73 | _check_args_tf(kwargs) 74 | return img.transform( 75 | img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs 76 | ) 77 | 78 | 79 | def translate_x_rel(img, pct, **kwargs): 80 | pixels = pct * img.size[0] 81 | _check_args_tf(kwargs) 82 | return img.transform( 83 | img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs 84 | ) 85 | 86 | 87 | def translate_y_rel(img, pct, **kwargs): 88 | pixels = pct * img.size[1] 89 | _check_args_tf(kwargs) 90 | return img.transform( 91 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs 92 | ) 93 | 94 | 95 | def translate_x_abs(img, pixels, **kwargs): 96 | _check_args_tf(kwargs) 97 | return img.transform( 98 | img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs 99 | ) 100 | 101 | 102 | def translate_y_abs(img, pixels, **kwargs): 103 | _check_args_tf(kwargs) 104 | return img.transform( 105 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs 106 | ) 107 | 108 | 109 | def rotate(img, degrees, **kwargs): 110 | _check_args_tf(kwargs) 111 | if _PIL_VER >= (5, 2): 112 | return img.rotate(degrees, **kwargs) 113 | elif _PIL_VER >= (5, 0): 114 | w, h = img.size 115 | post_trans = (0, 0) 116 | rotn_center = (w / 2.0, h / 2.0) 117 | angle = -math.radians(degrees) 118 | matrix = [ 119 | round(math.cos(angle), 15), 120 | round(math.sin(angle), 15), 121 | 0.0, 122 | round(-math.sin(angle), 15), 123 | round(math.cos(angle), 15), 124 | 0.0, 125 | ] 126 | 127 | def transform(x, y, matrix): 128 | (a, b, c, d, e, f) = matrix 129 | return a * x + b * y + c, d * x + e * y + f 130 | 131 | matrix[2], matrix[5] = transform( 132 | -rotn_center[0] - post_trans[0], 133 | -rotn_center[1] - post_trans[1], 134 | matrix, 135 | ) 136 | matrix[2] += rotn_center[0] 137 | matrix[5] += rotn_center[1] 138 | return img.transform(img.size, Image.AFFINE, matrix, **kwargs) 139 | else: 140 | return img.rotate(degrees, resample=kwargs["resample"]) 141 | 142 | 143 | def auto_contrast(img, **__): 144 | return ImageOps.autocontrast(img) 145 | 146 | 147 | def invert(img, **__): 148 | return ImageOps.invert(img) 149 | 150 | 151 | def equalize(img, **__): 152 | return ImageOps.equalize(img) 153 | 154 | 155 | def solarize(img, thresh, **__): 156 | return ImageOps.solarize(img, thresh) 157 | 158 | 159 | def solarize_add(img, add, thresh=128, **__): 160 | lut = [] 161 | for i in range(256): 162 | if i < thresh: 163 | lut.append(min(255, i + add)) 164 | else: 165 | lut.append(i) 166 | if img.mode in ("L", "RGB"): 167 | if img.mode == "RGB" and len(lut) == 256: 168 | lut = lut + lut + lut 169 | return img.point(lut) 170 | else: 171 | return img 172 | 173 | 174 | def posterize(img, bits_to_keep, **__): 175 | if bits_to_keep >= 8: 176 | return img 177 | return ImageOps.posterize(img, 
bits_to_keep) 178 | 179 | 180 | def contrast(img, factor, **__): 181 | return ImageEnhance.Contrast(img).enhance(factor) 182 | 183 | 184 | def color(img, factor, **__): 185 | return ImageEnhance.Color(img).enhance(factor) 186 | 187 | 188 | def brightness(img, factor, **__): 189 | return ImageEnhance.Brightness(img).enhance(factor) 190 | 191 | 192 | def sharpness(img, factor, **__): 193 | return ImageEnhance.Sharpness(img).enhance(factor) 194 | 195 | 196 | def _randomly_negate(v): 197 | """With 50% prob, negate the value""" 198 | return -v if random.random() > 0.5 else v 199 | 200 | 201 | def _rotate_level_to_arg(level, _hparams): 202 | # range [-30, 30] 203 | level = (level / _MAX_LEVEL) * 30.0 204 | level = _randomly_negate(level) 205 | return (level,) 206 | 207 | 208 | def _enhance_level_to_arg(level, _hparams): 209 | # range [0.1, 1.9] 210 | return ((level / _MAX_LEVEL) * 1.8 + 0.1,) 211 | 212 | 213 | def _enhance_increasing_level_to_arg(level, _hparams): 214 | # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend 215 | # range [0.1, 1.9] 216 | level = (level / _MAX_LEVEL) * 0.9 217 | level = 1.0 + _randomly_negate(level) 218 | return (level,) 219 | 220 | 221 | def _shear_level_to_arg(level, _hparams): 222 | # range [-0.3, 0.3] 223 | level = (level / _MAX_LEVEL) * 0.3 224 | level = _randomly_negate(level) 225 | return (level,) 226 | 227 | 228 | def _translate_abs_level_to_arg(level, hparams): 229 | translate_const = hparams["translate_const"] 230 | level = (level / _MAX_LEVEL) * float(translate_const) 231 | level = _randomly_negate(level) 232 | return (level,) 233 | 234 | 235 | def _translate_rel_level_to_arg(level, hparams): 236 | # default range [-0.45, 0.45] 237 | translate_pct = hparams.get("translate_pct", 0.45) 238 | level = (level / _MAX_LEVEL) * translate_pct 239 | level = _randomly_negate(level) 240 | return (level,) 241 | 242 | 243 | def _posterize_level_to_arg(level, _hparams): 244 | # As per Tensorflow TPU EfficientNet impl 245 | # range [0, 4], 'keep 0 up to 4 MSB of original image' 246 | # intensity/severity of augmentation decreases with level 247 | return (int((level / _MAX_LEVEL) * 4),) 248 | 249 | 250 | def _posterize_increasing_level_to_arg(level, hparams): 251 | # As per Tensorflow models research and UDA impl 252 | # range [4, 0], 'keep 4 down to 0 MSB of original image', 253 | # intensity/severity of augmentation increases with level 254 | return (4 - _posterize_level_to_arg(level, hparams)[0],) 255 | 256 | 257 | def _posterize_original_level_to_arg(level, _hparams): 258 | # As per original AutoAugment paper description 259 | # range [4, 8], 'keep 4 up to 8 MSB of image' 260 | # intensity/severity of augmentation decreases with level 261 | return (int((level / _MAX_LEVEL) * 4) + 4,) 262 | 263 | 264 | def _solarize_level_to_arg(level, _hparams): 265 | # range [0, 256] 266 | # intensity/severity of augmentation decreases with level 267 | return (int((level / _MAX_LEVEL) * 256),) 268 | 269 | 270 | def _solarize_increasing_level_to_arg(level, _hparams): 271 | # range [0, 256] 272 | # intensity/severity of augmentation increases with level 273 | return (256 - _solarize_level_to_arg(level, _hparams)[0],) 274 | 275 | 276 | def _solarize_add_level_to_arg(level, _hparams): 277 | # range [0, 110] 278 | return (int((level / _MAX_LEVEL) * 110),) 279 | 280 | 281 | LEVEL_TO_ARG = { 282 | "AutoContrast": None, 283 | "Equalize": None, 284 | "Invert": None, 285 | "Rotate": _rotate_level_to_arg, 286 | # There are several variations 
of the posterize level scaling in various Tensorflow/Google repositories/papers 287 | "Posterize": _posterize_level_to_arg, 288 | "PosterizeIncreasing": _posterize_increasing_level_to_arg, 289 | "PosterizeOriginal": _posterize_original_level_to_arg, 290 | "Solarize": _solarize_level_to_arg, 291 | "SolarizeIncreasing": _solarize_increasing_level_to_arg, 292 | "SolarizeAdd": _solarize_add_level_to_arg, 293 | "Color": _enhance_level_to_arg, 294 | "ColorIncreasing": _enhance_increasing_level_to_arg, 295 | "Contrast": _enhance_level_to_arg, 296 | "ContrastIncreasing": _enhance_increasing_level_to_arg, 297 | "Brightness": _enhance_level_to_arg, 298 | "BrightnessIncreasing": _enhance_increasing_level_to_arg, 299 | "Sharpness": _enhance_level_to_arg, 300 | "SharpnessIncreasing": _enhance_increasing_level_to_arg, 301 | "ShearX": _shear_level_to_arg, 302 | "ShearY": _shear_level_to_arg, 303 | "TranslateX": _translate_abs_level_to_arg, 304 | "TranslateY": _translate_abs_level_to_arg, 305 | "TranslateXRel": _translate_rel_level_to_arg, 306 | "TranslateYRel": _translate_rel_level_to_arg, 307 | } 308 | 309 | 310 | NAME_TO_OP = { 311 | "AutoContrast": auto_contrast, 312 | "Equalize": equalize, 313 | "Invert": invert, 314 | "Rotate": rotate, 315 | "Posterize": posterize, 316 | "PosterizeIncreasing": posterize, 317 | "PosterizeOriginal": posterize, 318 | "Solarize": solarize, 319 | "SolarizeIncreasing": solarize, 320 | "SolarizeAdd": solarize_add, 321 | "Color": color, 322 | "ColorIncreasing": color, 323 | "Contrast": contrast, 324 | "ContrastIncreasing": contrast, 325 | "Brightness": brightness, 326 | "BrightnessIncreasing": brightness, 327 | "Sharpness": sharpness, 328 | "SharpnessIncreasing": sharpness, 329 | "ShearX": shear_x, 330 | "ShearY": shear_y, 331 | "TranslateX": translate_x_abs, 332 | "TranslateY": translate_y_abs, 333 | "TranslateXRel": translate_x_rel, 334 | "TranslateYRel": translate_y_rel, 335 | } 336 | 337 | 338 | class AugmentOp: 339 | """ 340 | Apply for video. 341 | """ 342 | 343 | def __init__(self, name, prob=0.5, magnitude=10, hparams=None): 344 | hparams = hparams or _HPARAMS_DEFAULT 345 | self.aug_fn = NAME_TO_OP[name] 346 | self.level_fn = LEVEL_TO_ARG[name] 347 | self.prob = prob 348 | self.magnitude = magnitude 349 | self.hparams = hparams.copy() 350 | self.kwargs = { 351 | "fillcolor": hparams["img_mean"] 352 | if "img_mean" in hparams 353 | else _FILL, 354 | "resample": hparams["interpolation"] 355 | if "interpolation" in hparams 356 | else _RANDOM_INTERPOLATION, 357 | } 358 | 359 | # If magnitude_std is > 0, we introduce some randomness 360 | # in the usually fixed policy and sample magnitude from a normal distribution 361 | # with mean `magnitude` and std-dev of `magnitude_std`. 362 | # NOTE This is my own hack, being tested, not in papers or reference impls. 
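# e.g. magnitude=9 with magnitude_std=0.5 draws magnitude ~ N(9, 0.5) on each call; __call__ below clips the sample back to [0, _MAX_LEVEL].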
363 | self.magnitude_std = self.hparams.get("magnitude_std", 0) 364 | 365 | def __call__(self, img_list): 366 | if self.prob < 1.0 and random.random() > self.prob: 367 | return img_list 368 | magnitude = self.magnitude 369 | if self.magnitude_std and self.magnitude_std > 0: 370 | magnitude = random.gauss(magnitude, self.magnitude_std) 371 | magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range 372 | level_args = ( 373 | self.level_fn(magnitude, self.hparams) 374 | if self.level_fn is not None 375 | else () 376 | ) 377 | 378 | if isinstance(img_list, list): 379 | return [ 380 | self.aug_fn(img, *level_args, **self.kwargs) for img in img_list 381 | ] 382 | else: 383 | return self.aug_fn(img_list, *level_args, **self.kwargs) 384 | 385 | 386 | _RAND_TRANSFORMS = [ 387 | "AutoContrast", 388 | "Equalize", 389 | "Invert", 390 | "Rotate", 391 | "Posterize", 392 | "Solarize", 393 | "SolarizeAdd", 394 | "Color", 395 | "Contrast", 396 | "Brightness", 397 | "Sharpness", 398 | "ShearX", 399 | "ShearY", 400 | "TranslateXRel", 401 | "TranslateYRel", 402 | ] 403 | 404 | 405 | _RAND_INCREASING_TRANSFORMS = [ 406 | "AutoContrast", 407 | "Equalize", 408 | "Invert", 409 | "Rotate", 410 | "PosterizeIncreasing", 411 | "SolarizeIncreasing", 412 | "SolarizeAdd", 413 | "ColorIncreasing", 414 | "ContrastIncreasing", 415 | "BrightnessIncreasing", 416 | "SharpnessIncreasing", 417 | "ShearX", 418 | "ShearY", 419 | "TranslateXRel", 420 | "TranslateYRel", 421 | ] 422 | 423 | 424 | # These experimental weights are based loosely on the relative improvements mentioned in the paper. 425 | # They may not result in increased performance, but could likely be tuned to do so. 426 | _RAND_CHOICE_WEIGHTS_0 = { 427 | "Rotate": 0.3, 428 | "ShearX": 0.2, 429 | "ShearY": 0.2, 430 | "TranslateXRel": 0.1, 431 | "TranslateYRel": 0.1, 432 | "Color": 0.025, 433 | "Sharpness": 0.025, 434 | "AutoContrast": 0.025, 435 | "Solarize": 0.005, 436 | "SolarizeAdd": 0.005, 437 | "Contrast": 0.005, 438 | "Brightness": 0.005, 439 | "Equalize": 0.005, 440 | "Posterize": 0, 441 | "Invert": 0, 442 | } 443 | 444 | 445 | def _select_rand_weights(weight_idx=0, transforms=None): 446 | transforms = transforms or _RAND_TRANSFORMS 447 | assert weight_idx == 0 # only one set of weights currently 448 | rand_weights = _RAND_CHOICE_WEIGHTS_0 449 | probs = [rand_weights[k] for k in transforms] 450 | probs /= np.sum(probs) # numpy division turns the list into a normalized ndarray 451 | return probs 452 | 453 | 454 | def rand_augment_ops(magnitude=10, hparams=None, transforms=None): 455 | hparams = hparams or _HPARAMS_DEFAULT 456 | transforms = transforms or _RAND_TRANSFORMS 457 | return [ 458 | AugmentOp(name, prob=0.5, magnitude=magnitude, hparams=hparams) 459 | for name in transforms 460 | ] 461 | 462 | 463 | class RandAugment: 464 | def __init__(self, ops, num_layers=2, choice_weights=None): 465 | self.ops = ops 466 | self.num_layers = num_layers 467 | self.choice_weights = choice_weights 468 | 469 | def __call__(self, img): 470 | # no replacement when using weighted choice 471 | ops = np.random.choice( 472 | self.ops, 473 | self.num_layers, 474 | replace=self.choice_weights is None, 475 | p=self.choice_weights, 476 | ) 477 | for op in ops: 478 | img = op(img) 479 | return img 480 | 481 | 482 | def rand_augment_transform(config_str, hparams): 483 | """ 484 | RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719 485 | 486 | Create a RandAugment transform 487 | :param config_str: String defining configuration of random augmentation. 
482 | def rand_augment_transform(config_str, hparams):
483 | """
484 | RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719
485 |
486 | Create a RandAugment transform
487 | :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
488 | dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
489 | sections, which are not order-specific, determine
490 | 'm' - integer magnitude of rand augment
491 | 'n' - integer num layers (number of transform ops selected per image)
492 | 'w' - integer probability weight index (index of a set of weights to influence choice of op)
493 | 'mstd' - float std deviation of magnitude noise applied
494 | 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
495 | Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
496 | 'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2
497 | :param hparams: Other hparams (kwargs) for the RandAugment scheme
498 | :return: A PyTorch compatible Transform
499 | """
500 | # rand-m7-n4-mstd0.5-inc1
501 | # rand-m8-n2-mstd0.5-inc1
502 | # {'translate_const': 100, 'interpolation': 3}
503 | magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10)
504 | num_layers = 2 # default to 2 ops per image
505 | weight_idx = None # default to no probability weights for op choice
506 | transforms = _RAND_TRANSFORMS
507 | config = config_str.split("-")
508 | assert config[0] == "rand"
509 | config = config[1:]
510 | for c in config:
511 | cs = re.split(r"(\d.*)", c)
512 | if len(cs) < 2:
513 | continue
514 | key, val = cs[:2]
515 | if key == "mstd":
516 | # noise param injected via hparams for now
517 | hparams.setdefault("magnitude_std", float(val))
518 | elif key == "inc":
519 | if int(val): # int() so that 'inc0' is treated as False
520 | transforms = _RAND_INCREASING_TRANSFORMS
521 | elif key == "m":
522 | magnitude = int(val)
523 | elif key == "n":
524 | num_layers = int(val)
525 | elif key == "w":
526 | weight_idx = int(val)
527 | else:
528 | assert False, "Unknown RandAugment config section"
529 | ra_ops = rand_augment_ops(
530 | magnitude=magnitude, hparams=hparams, transforms=transforms
531 | )
532 | choice_weights = (
533 | None if weight_idx is None else _select_rand_weights(weight_idx)
534 | )
535 | return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
536 |
-------------------------------------------------------------------------------- /datasets/transforms/random_erasing.py: --------------------------------------------------------------------------------
1 | """
2 | This implementation is based on
3 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/random_erasing.py
4 | published under an Apache License 2.0.
5 | """
6 | import math
7 | import random
8 | import torch
9 |
10 |
11 | def _get_pixels(
12 | per_pixel, rand_color, patch_size, dtype=torch.float32, device="cuda"
13 | ):
14 | # NOTE I've seen CUDA illegal memory access errors being caused by the normal_()
15 | # paths, flip the order so normal is run on CPU if this becomes a problem
16 | # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508
17 | if per_pixel:
18 | return torch.empty(patch_size, dtype=dtype, device=device).normal_()
19 | elif rand_color:
20 | return torch.empty(
21 | (patch_size[0], 1, 1), dtype=dtype, device=device
22 | ).normal_()
23 | else:
24 | return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device)
25 |
26 |
27 | class RandomErasing:
28 | """Randomly selects a rectangle region in an image and erases its pixels.
29 | 'Random Erasing Data Augmentation' by Zhong et al.
30 | See https://arxiv.org/pdf/1708.04896.pdf 31 | This variant of RandomErasing is intended to be applied to either a batch 32 | or single image tensor after it has been normalized by dataset mean and std. 33 | Args: 34 | probability: Probability that the Random Erasing operation will be performed. 35 | min_area: Minimum percentage of erased area wrt input image area. 36 | max_area: Maximum percentage of erased area wrt input image area. 37 | min_aspect: Minimum aspect ratio of erased area. 38 | mode: pixel color mode, one of 'const', 'rand', or 'pixel' 39 | 'const' - erase block is constant color of 0 for all channels 40 | 'rand' - erase block is same per-channel random (normal) color 41 | 'pixel' - erase block is per-pixel random (normal) color 42 | max_count: maximum number of erasing blocks per image, area per box is scaled by count. 43 | per-image count is randomly chosen between 1 and this value. 44 | """ 45 | 46 | def __init__( 47 | self, 48 | probability=0.5, 49 | min_area=0.02, 50 | max_area=1 / 3, 51 | min_aspect=0.3, 52 | max_aspect=None, 53 | mode="const", 54 | min_count=1, 55 | max_count=None, 56 | num_splits=0, 57 | device="cuda", 58 | cube=True, 59 | ): 60 | self.probability = probability 61 | self.min_area = min_area 62 | self.max_area = max_area 63 | max_aspect = max_aspect or 1 / min_aspect 64 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) 65 | self.min_count = min_count 66 | self.max_count = max_count or min_count 67 | self.num_splits = num_splits 68 | mode = mode.lower() 69 | self.rand_color = False 70 | self.per_pixel = False 71 | self.cube = cube 72 | if mode == "rand": 73 | self.rand_color = True # per block random normal 74 | elif mode == "pixel": 75 | self.per_pixel = True # per pixel random normal 76 | else: 77 | assert not mode or mode == "const" 78 | self.device = device 79 | 80 | def _erase(self, img, chan, img_h, img_w, dtype): 81 | if random.random() > self.probability: 82 | return 83 | area = img_h * img_w 84 | count = ( 85 | self.min_count 86 | if self.min_count == self.max_count 87 | else random.randint(self.min_count, self.max_count) 88 | ) 89 | for _ in range(count): 90 | for _ in range(10): 91 | target_area = ( 92 | random.uniform(self.min_area, self.max_area) * area / count 93 | ) 94 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) 95 | h = int(round(math.sqrt(target_area * aspect_ratio))) 96 | w = int(round(math.sqrt(target_area / aspect_ratio))) 97 | if w < img_w and h < img_h: 98 | top = random.randint(0, img_h - h) 99 | left = random.randint(0, img_w - w) 100 | img[:, top : top + h, left : left + w] = _get_pixels( 101 | self.per_pixel, 102 | self.rand_color, 103 | (chan, h, w), 104 | dtype=dtype, 105 | device=self.device, 106 | ) 107 | break 108 | 109 | def _erase_cube( 110 | self, 111 | img, 112 | batch_start, 113 | batch_size, 114 | chan, 115 | img_h, 116 | img_w, 117 | dtype, 118 | ): 119 | if random.random() > self.probability: 120 | return 121 | area = img_h * img_w 122 | count = ( 123 | self.min_count 124 | if self.min_count == self.max_count 125 | else random.randint(self.min_count, self.max_count) 126 | ) 127 | for _ in range(count): 128 | for _ in range(100): 129 | target_area = ( 130 | random.uniform(self.min_area, self.max_area) * area / count 131 | ) 132 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) 133 | h = int(round(math.sqrt(target_area * aspect_ratio))) 134 | w = int(round(math.sqrt(target_area / aspect_ratio))) 135 | if w < img_w and h < img_h: 136 | top = 
random.randint(0, img_h - h)
137 | left = random.randint(0, img_w - w)
138 | for i in range(batch_start, batch_size):
139 | img_instance = img[i]
140 | img_instance[
141 | :, top : top + h, left : left + w
142 | ] = _get_pixels(
143 | self.per_pixel,
144 | self.rand_color,
145 | (chan, h, w),
146 | dtype=dtype,
147 | device=self.device,
148 | )
149 | break
150 |
151 | def __call__(self, input):
152 | if len(input.size()) == 3:
153 | self._erase(input, *input.size(), input.dtype)
154 | else:
155 | batch_size, chan, img_h, img_w = input.size()
156 | # skip first slice of batch if num_splits is set (for clean portion of samples)
157 | batch_start = (
158 | batch_size // self.num_splits if self.num_splits > 1 else 0
159 | )
160 | if self.cube:
161 | self._erase_cube(
162 | input,
163 | batch_start,
164 | batch_size,
165 | chan,
166 | img_h,
167 | img_w,
168 | input.dtype,
169 | )
170 | else:
171 | for i in range(batch_start, batch_size):
172 | self._erase(input[i], chan, img_h, img_w, input.dtype)
173 | return input
174 |
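A short usage sketch for the RandomErasing class above, assuming a clip tensor laid out as (frames, channels, height, width) so that cube=True erases one identical rectangle in every frame; the shapes and device="cpu" are illustrative choices, not values taken from this repo's training scripts.

import torch
from datasets.transforms.random_erasing import RandomErasing

eraser = RandomErasing(probability=1.0, mode="pixel", device="cpu", cube=True)
clip = torch.randn(16, 3, 224, 224)  # 16 already-normalized frames; dim 0 acts as the batch
clip = eraser(clip)                  # cube mode: the same region is erased in all 16 frames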
-------------------------------------------------------------------------------- /datasets/transforms/surg_transforms.py: --------------------------------------------------------------------------------
1 | import cv2
2 | import torch
3 | import torchvision.transforms.functional as F
4 | import warnings
5 | import random
6 | import numpy as np
7 | import torchvision
8 | from PIL import Image, ImageOps
9 | import numbers
10 | from imgaug import augmenters as iaa
11 |
12 | class SurgTransforms(object):
13 |
14 | def __init__(self, input_size=224, scales=(0.0, 0.3)):
15 | self.scales = scales
16 | self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size]
17 | self.aug = iaa.Sequential([
18 | # Resize to (252, 448)
19 | iaa.Resize({"height": 252, "width": 448}),
20 | # Random crop: each side removes 0-30% (per `scales`), without keeping size
21 | iaa.Crop(percent=scales, keep_size=False),
22 | # Resize to (input_size, input_size)
23 | iaa.Resize({"height": input_size, "width": input_size}),
24 | # Random Augment Surgery
25 | iaa.SomeOf((0, 2), [
26 | iaa.pillike.EnhanceSharpness(),
27 | iaa.pillike.Autocontrast(),
28 | iaa.pillike.Equalize(),
29 | iaa.pillike.EnhanceContrast(),
30 | iaa.pillike.EnhanceColor(),
31 | iaa.pillike.EnhanceBrightness(),
32 | iaa.Rotate((-30, 30)),
33 | iaa.ShearX((-20, 20)),
34 | iaa.ShearY((-20, 20)),
35 | iaa.TranslateX(percent=(-0.1, 0.1)),
36 | iaa.TranslateY(percent=(-0.1, 0.1))
37 | ]),
38 | iaa.Sometimes(0.3, iaa.AddToHueAndSaturation((-50, 50), per_channel=True)),
39 | # Horizontally flip 50% of all images
40 | iaa.Fliplr(0.5)])
41 |
42 |
43 | def __call__(self, img_tuple):
44 | images, label = img_tuple
45 |
46 | # Crop/resize with the sampled parameters; to_deterministic() keeps the same transform for every frame of the video sequence
47 | augDet = self.aug.to_deterministic()
48 | augment_images = []
49 | for _, img in enumerate(images):
50 | img_aug = augDet.augment_image(np.array(img))
51 | augment_images.append(img_aug)
52 |
53 | # for index, img in enumerate(augment_images):
54 | # img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
55 | # cv2.imshow(str(index), img)
56 | # cv2.waitKey()
57 | return (augment_images, label)
58 |
59 |
60 | class SurgStack(object):
61 |
62 | def __init__(self, roll=False):
63 | self.roll = roll
64 |
65 | def __call__(self, img_tuple):
66 | img_group, label = img_tuple
67 | if img_group[0].shape[2] == 1:
68 | return (np.concatenate([np.expand_dims(x, 2) for x in img_group], axis=2), label)
69 | elif img_group[0].shape[2] == 3:
70 | if self.roll:
71 | return (np.concatenate([np.array(x)[:, :, ::-1] for x in img_group], axis=2), label)
72 | else:
73 | return (np.concatenate(img_group, axis=2), label)
74 |
75 |
76 | if __name__ == '__main__':
77 | # Standalone demo: a local copy of SurgTransforms that takes a raw frame array instead of an (images, label) tuple
78 | class SurgTransforms(object):
79 |
80 | def __init__(self, input_size=224, scales=(0.0, 0.3)):
81 | self.scales = scales
82 | self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size]
83 | self.aug = iaa.Sequential([
84 | # Resize to (252, 448)
85 | iaa.Resize({"height": 252, "width": 448}),
86 | # Random crop: each side removes 0-30% (per `scales`), without keeping size
87 | iaa.Crop(percent=scales, keep_size=False),
88 | # Resize to (input_size, input_size)
89 | iaa.Resize({"height": input_size, "width": input_size}),
90 | # Random Augment Surgery
91 | iaa.SomeOf((0, 2), [
92 | iaa.pillike.EnhanceSharpness(),
93 | iaa.pillike.Autocontrast(),
94 | iaa.pillike.Equalize(),
95 | iaa.pillike.EnhanceContrast(),
96 | iaa.pillike.EnhanceColor(),
97 | iaa.pillike.EnhanceBrightness(),
98 | iaa.Rotate((-30, 30)),
99 | iaa.ShearX((-20, 20)),
100 | iaa.ShearY((-20, 20)),
101 | iaa.TranslateX(percent=(-0.1, 0.1)),
102 | iaa.TranslateY(percent=(-0.1, 0.1))
103 | ]),
104 | iaa.Sometimes(0.3, iaa.AddToHueAndSaturation((-50, 50), per_channel=True)),
105 | # Horizontally flip 50% of all images
106 | iaa.Fliplr(0.5)
107 | ])
108 |
109 |
110 | def __call__(self, images):
111 |
112 | # Crop/resize with the sampled parameters; to_deterministic() keeps the same transform for every frame of the video sequence
113 | augDet = self.aug.to_deterministic()
114 | augment_images = []
115 | for _, img in enumerate(images):
116 | img_aug = augDet.augment_image(img)
117 | augment_images.append(img_aug)
118 |
119 | return augment_images
120 |
121 | A = SurgTransforms()
122 | origin_images = cv2.imread("data/cholec80/frames/train/video01/0.jpg")
123 | origin_images = cv2.cvtColor(origin_images, cv2.COLOR_BGR2RGB)
124 |
125 | images = np.array([origin_images for _ in range(4)], dtype=np.uint8)
126 | img_1 = A(images)
127 | for index, img in enumerate(img_1):
128 | img = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
129 | print(img.shape)
130 | cv2.imshow(str(index), img)
131 | cv2.waitKey()
132 | img_2 = A(images)
133 | for index, img in enumerate(img_2):
134 | img = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
135 | print(img.shape)
136 | cv2.imshow(str(index)+'2', img)
137 | cv2.waitKey()
138 | img_3 = A(images)
139 | for index, img in enumerate(img_3):
140 | img = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
141 | print(img.shape)
142 | cv2.imshow(str(index)+'3', img)
143 | cv2.waitKey()
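A brief sketch of the (images, label) tuple interface the two classes above expose to the training pipeline (the __main__ demo above uses a simplified local copy that takes raw arrays instead); the zero-filled frames and the dummy label are illustrative.

import numpy as np
from datasets.transforms.surg_transforms import SurgTransforms, SurgStack

aug = SurgTransforms(input_size=224, scales=(0.0, 0.3))
frames = [np.zeros((480, 854, 3), dtype=np.uint8) for _ in range(4)]
frames_aug, label = aug((frames, 0))       # one sampled augmentation, applied to all 4 frames
stacked, _ = SurgStack()((frames_aug, 0))  # channel-stacked ndarray of shape (224, 224, 12)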
-------------------------------------------------------------------------------- /datasets/transforms/transforms.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.transforms.functional as F
3 | import warnings
4 | import random
5 | import numpy as np
6 | import torchvision
7 | from PIL import Image, ImageOps
8 | import numbers
9 |
10 |
11 | class GroupRandomCrop(object):
12 | def __init__(self, size):
13 | if isinstance(size, numbers.Number):
14 | self.size = (int(size), int(size))
15 | else:
16 | self.size = size
17 |
18 | def __call__(self, img_tuple):
19 | img_group, label = img_tuple
20 |
21 | w, h = img_group[0].size
22 | th, tw = self.size
23 |
24 | out_images = list()
25 |
26 | x1 = random.randint(0, w - tw)
27 | y1 = random.randint(0, h - th)
28 |
29 | for img in img_group:
30 | assert(img.size[0] == w and img.size[1] == h)
31 | if w == tw and h == th:
32 | out_images.append(img)
33 | else:
34 | out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))
35 |
36 | return (out_images, label)
37 |
38 |
39 | class GroupCenterCrop(object):
40 | def __init__(self, size):
41 | self.worker = torchvision.transforms.CenterCrop(size)
42 |
43 | def __call__(self, img_tuple):
44 | img_group, label = img_tuple
45 | return ([self.worker(img) for img in img_group], label)
46 |
47 |
48 | class GroupNormalize(object):
49 | def __init__(self, mean, std):
50 | self.mean = mean
51 | self.std = std
52 |
53 | def __call__(self, tensor_tuple):
54 | tensor, label = tensor_tuple
55 | rep_mean = self.mean * (tensor.size()[0]//len(self.mean))
56 | rep_std = self.std * (tensor.size()[0]//len(self.std))
57 |
58 | # TODO: make efficient
59 | for t, m, s in zip(tensor, rep_mean, rep_std):
60 | t.sub_(m).div_(s)
61 |
62 | return (tensor,label)
63 |
64 |
65 | class GroupGrayScale(object):
66 | def __init__(self, size):
67 | self.worker = torchvision.transforms.Grayscale(size)
68 |
69 | def __call__(self, img_tuple):
70 | img_group, label = img_tuple
71 | return ([self.worker(img) for img in img_group], label)
72 |
73 |
74 | class GroupScale(object):
75 | """ Rescales the input PIL.Image to the given 'size'.
76 | 'size' will be the size of the smaller edge.
77 | For example, if height > width, then image will be
78 | rescaled to (size * height / width, size)
79 | size: size of the smaller edge
80 | interpolation: Default: PIL.Image.BILINEAR
81 | """
82 |
83 | def __init__(self, size, interpolation=Image.BILINEAR):
84 | self.worker = torchvision.transforms.Resize(size, interpolation)
85 |
86 | def __call__(self, img_tuple):
87 | img_group, label = img_tuple
88 | return ([self.worker(img) for img in img_group], label)
89 |
90 |
91 | # Group-wise multi-scale cropping
92 | class GroupMultiScaleCrop(object):
93 |
94 | def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True):
95 | self.scales = scales if scales is not None else [1, .875, .75, .66]
96 | self.max_distort = max_distort
97 | self.fix_crop = fix_crop
98 | self.more_fix_crop = more_fix_crop
99 | self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size]
100 | self.interpolation = Image.BILINEAR
101 |
102 | def __call__(self, img_tuple):
103 | img_group, label = img_tuple
104 |
105 | im_size = img_group[0].size # [854, 480]
106 |
107 | crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size)
108 | # Crop with the sampled offset/size and resize; the same crop is applied to every frame of the video sequence
109 | crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group]
110 | ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation) for img in crop_img_group]
111 |
112 | # import cv2
113 | # for index, img in enumerate(ret_img_group):
114 | # img = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
115 | # cv2.imshow(str(index), img)
116 | # cv2.waitKey()
117 | return (ret_img_group, label)
118 |
119 | # Compute candidate crop sizes from the shorter edge, snap sizes close to input_size to input_size, then pair crop widths/heights to generate multi-scale aspect ratios
120 | def _sample_crop_size(self, im_size):
121 | image_w, image_h = im_size[0], im_size[1]
122 | # find a crop size
123 | base_size = min(image_w, image_h)
124 | crop_sizes = [int(base_size * x) for x in self.scales]
125 |
126 | crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes]
127 | crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes]
128 |
129 | pairs = []
130 | for i, h in enumerate(crop_h):
131 | for j, w in enumerate(crop_w):
132 | if abs(i - j) <= self.max_distort:
133 | pairs.append((w, h))
134 |
135 | crop_pair = random.choice(pairs)
136 | if not self.fix_crop:
137 | w_offset = random.randint(0, image_w - crop_pair[0])
138 | h_offset = random.randint(0, image_h - crop_pair[1])
139 | else:
140 | w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1])
141 |
142 | return crop_pair[0], crop_pair[1], w_offset, h_offset
143 |
144 | def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h):
145 | offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h)
146 | return random.choice(offsets)
147 |
148 | @staticmethod
149 | def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h):
150 | w_step = (image_w - crop_w) // 4
151 | h_step = (image_h - crop_h) // 4
152 |
153 | ret = list()
154 | ret.append((0, 0)) # upper left
155 | ret.append((4 * w_step, 0)) # upper right
156 | ret.append((0, 4 * h_step)) # lower left
157 | ret.append((4 * w_step, 4 * h_step)) # lower right
158 | ret.append((2 * w_step, 2 * h_step)) # center
159 |
160 | if more_fix_crop:
161 | ret.append((0, 2 * h_step)) # center left
162 | ret.append((4 * w_step, 2 * h_step)) # center right
163 | ret.append((2 * w_step, 4 * h_step)) # lower center
164 | ret.append((2 * w_step, 0 * h_step)) # upper center
165 |
166 | ret.append((1 * w_step, 1 * h_step)) # upper left quarter
167 | ret.append((3 * w_step, 1 * h_step)) # upper right quarter
168 | ret.append((1 * w_step, 3 * h_step)) # lower left quarter
169 | ret.append((3 * w_step, 3 * h_step)) # lower right quarter
170 | return ret
171 |
172 |
173 | class Stack(object):
174 |
175 | def __init__(self, roll=False):
176 | self.roll = roll
177 |
178 | def __call__(self, img_tuple):
179 | img_group, label = img_tuple
180 | if img_group[0].mode == 'L':
181 | return (np.concatenate([np.expand_dims(x, 2) for x in img_group], axis=2), label)
182 | elif img_group[0].mode == 'RGB':
183 | if self.roll:
184 | return (np.concatenate([np.array(x)[:, :, ::-1] for x in img_group], axis=2), label)
185 | else:
186 | return (np.concatenate(img_group, axis=2), label)
187 |
188 |
189 | class ToTorchFormatTensor(object):
190 | """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
191 | to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """
192 | def __init__(self, div=True):
193 | self.div = div
194 |
195 | def __call__(self, pic_tuple):
196 | pic, label = pic_tuple
197 | if isinstance(pic, np.ndarray):
198 | # handle numpy array
199 | img = torch.from_numpy(pic).permute(2, 0, 1).contiguous()
200 | else:
201 | # handle PIL Image
202 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
203 | img = img.view(pic.size[1], pic.size[0], len(pic.mode))
204 | # put it from HWC to CHW format
205 | # yikes, this transpose takes 80% of the loading time/CPU
206 | img = img.transpose(0, 1).transpose(0, 2).contiguous()
207 | return (img.float().div(255.) if self.div else img.float(), label)
208 |
209 |
210 | class IdentityTransform(object):
211 |
212 | def __call__(self, data):
213 | return data
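A minimal sketch of how the group transforms above can be chained for a video clip; composing them with torchvision's Compose and the ImageNet mean/std shown here is an assumption for illustration, not a pipeline lifted from this repo's training scripts.

import torchvision
from PIL import Image
from datasets.transforms.transforms import (
    GroupMultiScaleCrop, GroupNormalize, Stack, ToTorchFormatTensor,
)

pipeline = torchvision.transforms.Compose([
    GroupMultiScaleCrop(input_size=224),  # same multi-scale crop for every frame
    Stack(roll=False),                    # list of T RGB frames -> (H, W, 3*T) ndarray
    ToTorchFormatTensor(div=True),        # -> (3*T, H, W) float tensor in [0, 1]
    GroupNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
frames = [Image.new("RGB", (854, 480)) for _ in range(8)]
clip, label = pipeline((frames, 0))  # clip: (24, 224, 224), normalized per channel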
-------------------------------------------------------------------------------- /datasets/transforms/volume_transforms.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image, ImageOps
3 | import torch
4 | import numbers
5 | import random
6 | import torchvision.transforms.functional as TF
7 |
8 | def convert_img(img):
9 | """Converts (H, W, C) numpy.ndarray to (C, H, W) format"""
10 | if len(img.shape) == 3:
11 | img = img.transpose(2, 0, 1)
12 | if len(img.shape) == 2:
13 | img = np.expand_dims(img, 0)
14 | return img
15 |
16 |
17 | class ClipToTensor(object):
18 | """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255]
19 | to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0]
20 | """
21 |
22 | def __init__(self, channel_nb=3, div_255=True, numpy=False):
23 | self.channel_nb = channel_nb
24 | self.div_255 = div_255
25 | self.numpy = numpy
26 |
27 | def __call__(self, clip):
28 | """
29 | Args: clip (list of numpy.ndarray): clip (list of images)
30 | to be converted to tensor.
31 | """
32 | # Retrieve shape
33 | if isinstance(clip[0], np.ndarray):
34 | h, w, ch = clip[0].shape
35 | assert ch == self.channel_nb, "Got {0} instead of 3 channels".format(ch)
36 | elif isinstance(clip[0], Image.Image):
37 | w, h = clip[0].size
38 | else:
39 | raise TypeError(
40 | "Expected numpy.ndarray or PIL.Image\
41 | but got list of {0}".format(
42 | type(clip[0])
43 | )
44 | )
45 |
46 | np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)])
47 |
48 | # Convert
49 | for img_idx, img in enumerate(clip):
50 | if isinstance(img, np.ndarray):
51 | pass
52 | elif isinstance(img, Image.Image):
53 | img = np.array(img, copy=False)
54 | else:
55 | raise TypeError(
56 | "Expected numpy.ndarray or PIL.Image\
57 | but got list of {0}".format(
58 | type(clip[0])
59 | )
60 | )
61 | img = convert_img(img)
62 | np_clip[:, img_idx, :, :] = img
63 | if self.numpy:
64 | if self.div_255:
65 | np_clip = np_clip / 255.0
66 | return np_clip
67 |
68 | else:
69 | tensor_clip = torch.from_numpy(np_clip)
70 |
71 | if not isinstance(tensor_clip, torch.FloatTensor):
72 | tensor_clip = tensor_clip.float()
73 | if self.div_255:
74 | tensor_clip = torch.div(tensor_clip, 255)
75 | return tensor_clip
76 |
77 |
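# An illustrative sketch (not part of the original file): ClipToTensor turns a
# list of T (H x W x 3) uint8 frames into a (3, T, H, W) float tensor in [0, 1].
def _example_clip_to_tensor():
    frames = [np.zeros((224, 224, 3), dtype=np.uint8) for _ in range(16)]
    clip = ClipToTensor(channel_nb=3, div_255=True)(frames)
    return clip.shape  # torch.Size([3, 16, 224, 224])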
93 | """ 94 | # Retrieve shape 95 | if isinstance(clip[0], np.ndarray): 96 | h, w, ch = clip[0].shape 97 | assert ch == self.channel_nb, "Got {0} instead of 3 channels".format(ch) 98 | elif isinstance(clip[0], Image.Image): 99 | w, h = clip[0].size 100 | else: 101 | raise TypeError( 102 | "Expected numpy.ndarray or PIL.Image\ 103 | but got list of {0}".format( 104 | type(clip[0]) 105 | ) 106 | ) 107 | 108 | np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)]) 109 | 110 | # Convert 111 | for img_idx, img in enumerate(clip): 112 | if isinstance(img, np.ndarray): 113 | pass 114 | elif isinstance(img, Image.Image): 115 | img = np.array(img, copy=False) 116 | else: 117 | raise TypeError( 118 | "Expected numpy.ndarray or PIL.Image\ 119 | but got list of {0}".format( 120 | type(clip[0]) 121 | ) 122 | ) 123 | img = convert_img(img) 124 | np_clip[:, img_idx, :, :] = img 125 | if self.numpy: 126 | if self.div_255: 127 | np_clip = (np_clip - 127.5) / 127.5 128 | return np_clip 129 | 130 | else: 131 | tensor_clip = torch.from_numpy(np_clip) 132 | 133 | if not isinstance(tensor_clip, torch.FloatTensor): 134 | tensor_clip = tensor_clip.float() 135 | if self.div_255: 136 | tensor_clip = torch.div(torch.sub(tensor_clip, 127.5), 127.5) 137 | return tensor_clip 138 | 139 | 140 | class ToTensor(object): 141 | """Converts numpy array to tensor""" 142 | 143 | def __call__(self, array): 144 | tensor = torch.from_numpy(array) 145 | return tensor 146 | 147 | 148 | class RandomCrop(object): 149 | 150 | def __init__(self, size, padding=0, sequence_length=16): 151 | if isinstance(size, numbers.Number): 152 | self.size = (int(size), int(size)) 153 | else: 154 | self.size = size 155 | self.sequence_length = sequence_length 156 | self.padding = padding 157 | self.count = 0 158 | 159 | def __call__(self, img): 160 | 161 | if self.padding > 0: 162 | img = ImageOps.expand(img, border=self.padding, fill=0) 163 | 164 | w, h = img.size 165 | th, tw = self.size 166 | if w == tw and h == th: 167 | return img 168 | 169 | random.seed(self.count // self.sequence_length) 170 | x1 = random.randint(0, w - tw) 171 | y1 = random.randint(0, h - th) 172 | # print(self.count, x1, y1) 173 | self.count += 1 174 | return img.crop((x1, y1, x1 + tw, y1 + th)) 175 | 176 | 177 | class RandomHorizontalFlip(object): 178 | def __init__(self, sequence_length=16): 179 | self.count = 0 180 | self.sequence_length = sequence_length 181 | 182 | def __call__(self, img): 183 | seed = self.count // self.sequence_length 184 | random.seed(seed) 185 | prob = random.random() 186 | self.count += 1 187 | # print(self.count, seed, prob) 188 | if prob < 0.5: 189 | return img.transpose(Image.FLIP_LEFT_RIGHT) 190 | return img 191 | 192 | 193 | class RandomRotation(object): 194 | def __init__(self,degrees, sequence_length=16): 195 | self.degrees = degrees 196 | self.count = 0 197 | self.sequence_length = sequence_length 198 | 199 | def __call__(self, img): 200 | seed = self.count // self.sequence_length 201 | random.seed(seed) 202 | self.count += 1 203 | angle = random.randint(-self.degrees,self.degrees) 204 | return TF.rotate(img, angle) 205 | 206 | 207 | 208 | class ColorJitter(object): 209 | def __init__(self,brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1, sequence_length=16): 210 | self.brightness = brightness 211 | self.contrast = contrast 212 | self.saturation = saturation 213 | self.hue = hue 214 | self.count = 0 215 | self.sequence_length = sequence_length 216 | 217 | def __call__(self, img): 218 | seed = self.count // self.sequence_length 
219 | random.seed(seed) 220 | self.count += 1 221 | brightness_factor = random.uniform(1 - self.brightness, 1 + self.brightness) 222 | contrast_factor = random.uniform(1 - self.contrast, 1 + self.contrast) 223 | saturation_factor = random.uniform(1 - self.saturation, 1 + self.saturation) 224 | hue_factor = random.uniform(- self.hue, self.hue) 225 | 226 | img_ = TF.adjust_brightness(img,brightness_factor) 227 | img_ = TF.adjust_contrast(img_,contrast_factor) 228 | img_ = TF.adjust_saturation(img_,saturation_factor) 229 | img_ = TF.adjust_hue(img_,hue_factor) 230 | 231 | return img_ -------------------------------------------------------------------------------- /downstream_phase/__pycache__/datasets_phase.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/downstream_phase/__pycache__/datasets_phase.cpython-310.pyc -------------------------------------------------------------------------------- /downstream_phase/__pycache__/datasets_phase.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/downstream_phase/__pycache__/datasets_phase.cpython-311.pyc -------------------------------------------------------------------------------- /downstream_phase/__pycache__/engine_for_phase.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/downstream_phase/__pycache__/engine_for_phase.cpython-310.pyc -------------------------------------------------------------------------------- /downstream_phase/__pycache__/engine_for_phase.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/downstream_phase/__pycache__/engine_for_phase.cpython-311.pyc -------------------------------------------------------------------------------- /downstream_phase/datasets_phase.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datasets.transforms import * 3 | from datasets.transforms.surg_transforms import * 4 | 5 | from datasets.phase.Cholec80_phase import PhaseDataset_Cholec80 6 | from datasets.phase.AutoLaparo_phase import PhaseDataset_AutoLaparo 7 | from datasets.phase.PmLR50_phase import PhaseDataset_PmLR50 8 | 9 | def build_dataset(is_train, test_mode, fps, args): 10 | """Load video phase recognition dataset.""" 11 | 12 | if args.data_set == "Cholec80": 13 | mode = None 14 | anno_path = None 15 | if is_train is True: 16 | mode = "train" 17 | anno_path = os.path.join( 18 | args.data_path, "labels", mode, fps + "train.pickle" 19 | ) 20 | elif test_mode is True: 21 | mode = "test" 22 | anno_path = os.path.join( 23 | args.data_path, "labels", mode, fps + "val_test.pickle" 24 | ) 25 | else: 26 | mode = "test" 27 | anno_path = os.path.join(args.data_path, "labels", mode, fps + "val_test.pickle") 28 | 29 | dataset = PhaseDataset_Cholec80( 30 | anno_path=anno_path, 31 | data_path=args.data_path, 32 | mode=mode, 33 | data_strategy=args.data_strategy, 34 | output_mode=args.output_mode, 35 | cut_black=args.cut_black, 36 | clip_len=args.num_frames, 37 | frame_sample_rate=args.sampling_rate, 38 | keep_aspect_ratio=True, 39 | 
crop_size=args.input_size,
40 | short_side_size=args.short_side_size,
41 | new_height=256,
42 | new_width=320,
43 | args=args,
44 | )
45 | nb_classes = 7
46 |
47 | elif args.data_set == "AutoLaparo":
48 | mode = None
49 | anno_path = None
50 | if is_train is True:
51 | mode = "train"
52 | anno_path = os.path.join(
53 | args.data_path, "labels_pkl", mode, fps + "train.pickle"
54 | )
55 | elif test_mode is True:
56 | mode = "test"
57 | anno_path = os.path.join(
58 | args.data_path, "labels_pkl", mode, fps + "test.pickle"
59 | )
60 | else:
61 | mode = "val"
62 | anno_path = os.path.join(args.data_path, "labels_pkl", mode, fps + "val.pickle")
63 |
64 | dataset = PhaseDataset_AutoLaparo(
65 | anno_path=anno_path,
66 | data_path=args.data_path,
67 | mode=mode,
68 | data_strategy=args.data_strategy,
69 | output_mode=args.output_mode,
70 | cut_black=args.cut_black,
71 | clip_len=args.num_frames,
72 | frame_sample_rate=args.sampling_rate,
73 | keep_aspect_ratio=True,
74 | crop_size=args.input_size,
75 | short_side_size=args.short_side_size,
76 | new_height=256,
77 | new_width=320,
78 | args=args,
79 | )
80 | nb_classes = 7
81 |
82 | elif args.data_set == "PmLR50":
83 | mode = None
84 | anno_path = None
85 | if is_train is True:
86 | mode = "train"
87 | anno_path = os.path.join(
88 | args.data_path, "labels", mode, fps + "train.pickle"
89 | )
90 | elif test_mode is True:
91 | mode = "test"
92 | anno_path = os.path.join(
93 | args.data_path, "labels", 'infer', fps + "infer.pickle"
94 | )
95 | else:
96 | mode = "val"
97 | anno_path = os.path.join(args.data_path, "labels", "test", fps + "test.pickle")
98 |
99 | dataset = PhaseDataset_PmLR50(
100 | anno_path=anno_path,
101 | data_path=args.data_path,
102 | mode=mode,
103 | data_strategy=args.data_strategy,
104 | output_mode=args.output_mode,
105 | cut_black=args.cut_black,
106 | clip_len=args.num_frames,
107 | frame_sample_rate=args.sampling_rate,
108 | keep_aspect_ratio=True,
109 | crop_size=args.input_size,
110 | short_side_size=args.short_side_size,
111 | new_height=256,
112 | new_width=320,
113 | args=args,
114 | )
115 | nb_classes = 5
116 | else:
117 | raise NotImplementedError("Unsupported dataset: %s" % args.data_set)
118 |
119 | assert nb_classes == args.nb_classes
120 | print("%s - %s : Number of classes = %d" % (mode, fps, args.nb_classes))
121 | print("Data Strategy: %s" % args.data_strategy)
122 | print("Output Mode: %s" % args.output_mode)
123 | print("Cut Black: %s" % args.cut_black)
124 | if args.sampling_rate == 0:
125 | print(
126 | "%s Frames with Temporal sample Rate %s (%s)"
127 | % (str(args.num_frames), str(args.sampling_rate), "Exponential Stride")
128 | )
129 | elif args.sampling_rate == -1:
130 | print(
131 | "%s Frames with Temporal sample Rate %s (%s)"
132 | % (str(args.num_frames), str(args.sampling_rate), "Random Stride (1-5)")
133 | )
134 | elif args.sampling_rate == -2:
135 | print(
136 | "%s Frames with Temporal sample Rate %s (%s)"
137 | % (str(args.num_frames), str(args.sampling_rate), "Incremental Stride")
138 | )
139 | else:
140 | print(
141 | "%s Frames with Temporal sample Rate %s"
142 | % (str(args.num_frames), str(args.sampling_rate))
143 | )
144 |
145 | return dataset, nb_classes
146 |
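A hedged sketch of calling build_dataset above. The args fields are inferred from the attributes the function reads; the concrete values (data_strategy, output_mode, the fps filename prefix, paths) are placeholders, and the real ones come from datasets/args.py and the training scripts.

from types import SimpleNamespace
from downstream_phase.datasets_phase import build_dataset

args = SimpleNamespace(
    data_set="Cholec80", data_path="/path/to/cholec80",   # placeholder path
    data_strategy="online", output_mode="key_frame",      # hypothetical values
    cut_black=False, num_frames=16, sampling_rate=4,
    input_size=224, short_side_size=256, nb_classes=7,
)
# fps is used as a filename prefix for the label pickles (fps + "train.pickle");
# the exact string depends on how the labels were generated.
train_set, nb_classes = build_dataset(is_train=True, test_mode=False, fps="1fps", args=args)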
-------------------------------------------------------------------------------- /evaluation_matlab/Evaluate.m: --------------------------------------------------------------------------------
1 | function [ res, prec, rec, acc, f1 ] = Evaluate( gtLabelID, predLabelID, fps )
2 | %EVALUATE
3 | % A function to evaluate the performance of the phase recognition method
4 | % providing jaccard index, precision, and recall for each phase
5 | % and accuracy over the surgery. All metrics are computed in a relaxed
6 | % boundary mode.
7 | % OUTPUT:
8 | % res: the jaccard index per phase (relaxed) - NaN for non existing phase in GT
9 | % prec: precision per phase (relaxed) - NaN for non existing phase in GT
10 | % rec: recall per phase (relaxed) - NaN for non existing phase in GT
11 | % acc: the accuracy over the video (relaxed)
12 |
13 | oriT = 10 * fps; % 10 seconds relaxed boundary
14 |
15 | res = []; prec = []; rec = [];
16 | diff = predLabelID - gtLabelID;
17 | updatedDiff = [];
18 |
19 | % obtain the true positive with relaxed boundary
20 | for iPhase = 1:7 % nPhases
21 | gtConn = bwconncomp(gtLabelID == iPhase);
22 |
23 | for iConn = 1:gtConn.NumObjects
24 | startIdx = min(gtConn.PixelIdxList{iConn});
25 | endIdx = max(gtConn.PixelIdxList{iConn});
26 |
27 | curDiff = diff(startIdx:endIdx);
28 |
29 | % in the case where the phase is shorter than the relaxed boundary
30 | t = oriT;
31 | if(t > length(curDiff))
32 | t = length(curDiff);
33 | disp(['Very short phase ' num2str(iPhase)]);
34 | end
35 |
36 | % relaxed boundary
37 | % revised for cholec80 dataset !!!!!!!!!!!
38 | if(iPhase == 4 || iPhase == 5) % Gallbladder dissection and packaging might jump between two phases
39 | curDiff(curDiff(1:t)==-1) = 0; % late transition
40 | curDiff(curDiff(end-t+1:end)==1 | curDiff(end-t+1:end)==2) = 0; % early transition % 5 can be predicted as 6/7 at the end > 5 followed by 6/7
41 | elseif(iPhase == 6 || iPhase == 7) % Gallbladder dissection might jump between two phases
42 | curDiff(curDiff(1:t)==-1 | curDiff(1:t)==-2) = 0; % late transition
43 | curDiff(curDiff(end-t+1:end)==1 | curDiff(end-t+1:end)==2) = 0; % early transition
44 | else
45 | % general situation
46 | curDiff(curDiff(1:t)==-1) = 0; % late transition
47 | curDiff(curDiff(end-t+1:end)==1) = 0; % early transition
48 | end
49 |
50 | updatedDiff(startIdx:endIdx) = curDiff;
51 | end
52 | end
53 |
54 | % compute jaccard index, prec, and rec per phase
55 | for iPhase = 1:7
56 | gtConn = bwconncomp(gtLabelID == iPhase);
57 | predConn = bwconncomp(predLabelID == iPhase);
58 |
59 | if(gtConn.NumObjects == 0)
60 | % no iPhase in current ground truth, assigned NaN values
61 | % SHOULD be excluded in the computation of mean (use nanmean)
62 | res(end+1) = NaN;
63 | prec(end+1) = NaN;
64 | rec(end+1) = NaN;
65 | continue;
66 | end
67 |
68 | iPUnion = union(vertcat(predConn.PixelIdxList{:}), vertcat(gtConn.PixelIdxList{:}));
69 | tp = sum(updatedDiff(iPUnion) == 0);
70 | jaccard = tp/length(iPUnion);
71 | jaccard = jaccard * 100;
72 |
73 | % res(end+1, :) = [iPhase jaccard];
74 | res(end+1) = jaccard;
75 |
76 | % Compute prec and rec
77 | indx = (gtLabelID == iPhase);
78 |
79 | sumTP = tp; % sum(predLabelID(indx) == iPhase);
80 | sumPred = sum(predLabelID == iPhase);
81 | sumGT = sum(indx);
82 |
83 | prec(end+1) = sumTP * 100 / sumPred;
84 | rec(end+1) = sumTP * 100 / sumGT;
85 | end
86 |
87 | % compute accuracy
88 | acc = sum(updatedDiff==0) / length(gtLabelID);
89 | acc = acc * 100;
90 | f1 = zeros(1,7);
91 | for i=1:7
92 | % f1(i) = 2*(prec(i)*rec(i))/((prec(i)+rec(i))+0.0000001)
93 | f1(i) = 2*(prec(i)*rec(i))/((prec(i)+rec(i)));
94 | end
95 | end
96 |
97 |
-------------------------------------------------------------------------------- /evaluation_matlab/Evaluate_Cataract101.m: --------------------------------------------------------------------------------
1 | function [ res, prec, rec, acc, f1 ] = Evaluate_Cataract101( gtLabelID,
predLabelID, fps ) 2 | %EVALUATE 3 | % A function to evaluate the performance of the phase recognition method 4 | % providing jaccard index, precision, and recall for each phase 5 | % and accuracy over the surgery. All metrics are computed in a relaxed 6 | % boundary mode. 7 | % OUTPUT: 8 | % res: the jaccard index per phase (relaxed) - NaN for non existing phase in GT 9 | % prec: precision per phase (relaxed) - NaN for non existing phase in GT 10 | % rec: recall per phase (relaxed) - NaN for non existing phase in GT 11 | % acc: the accuracy over the video (relaxed) 12 | 13 | oriT = 10 * fps; % 10 seconds relaxed boundary 14 | 15 | res = []; prec = []; rec = []; 16 | diff = predLabelID - gtLabelID; 17 | updatedDiff = []; 18 | 19 | % obtain the true positive with relaxed boundary 20 | for iPhase = 1:10 % nPhases 21 | gtConn = bwconncomp(gtLabelID == iPhase); 22 | 23 | for iConn = 1:gtConn.NumObjects 24 | startIdx = min(gtConn.PixelIdxList{iConn}); 25 | endIdx = max(gtConn.PixelIdxList{iConn}); 26 | 27 | curDiff = diff(startIdx:endIdx); 28 | 29 | updatedDiff(startIdx:endIdx) = curDiff; 30 | end 31 | end 32 | 33 | % compute jaccard index, prec, and rec per phase 34 | for iPhase = 1:10 35 | gtConn = bwconncomp(gtLabelID == iPhase); 36 | predConn = bwconncomp(predLabelID == iPhase); 37 | 38 | if(gtConn.NumObjects == 0) 39 | % no iPhase in current ground truth, assigned NaN values 40 | % SHOULD be excluded in the computation of mean (use nanmean) 41 | res(end+1) = NaN; 42 | prec(end+1) = NaN; 43 | rec(end+1) = NaN; 44 | continue; 45 | end 46 | 47 | iPUnion = union(vertcat(predConn.PixelIdxList{:}), vertcat(gtConn.PixelIdxList{:})); 48 | tp = sum(updatedDiff(iPUnion) == 0); 49 | jaccard = tp/length(iPUnion); 50 | jaccard = jaccard * 100; 51 | 52 | % res(end+1, :) = [iPhase jaccard]; 53 | res(end+1) = jaccard; 54 | 55 | % Compute prec and rec 56 | indx = (gtLabelID == iPhase); 57 | 58 | sumTP = tp; % sum(predLabelID(indx) == iPhase); 59 | sumPred = sum(predLabelID == iPhase); 60 | sumGT = sum(indx); 61 | 62 | prec(end+1) = sumTP * 100 / sumPred; 63 | rec(end+1) = sumTP * 100 / sumGT; 64 | end 65 | 66 | % compute accuracy 67 | acc = sum(updatedDiff==0) / length(gtLabelID); 68 | acc = acc * 100; 69 | f1 = zeros(1,10); 70 | for i=1:10 71 | % f1(i) = 2*(prec(i)*rec(i))/((prec(i)+rec(i))+0.0000001) 72 | f1(i) = 2*(prec(i)*rec(i))/((prec(i)+rec(i))); 73 | end 74 | end 75 | 76 | -------------------------------------------------------------------------------- /evaluation_matlab/Evaluate_m2cai.m: -------------------------------------------------------------------------------- 1 | function [ res, prec, rec, acc ] = Evaluate( gtLabelID, predLabelID, fps ) 2 | %EVALUATE 3 | % A function to evaluate the performance of the phase recognition method 4 | % providing jaccard index, precision, and recall for each phase 5 | % and accuracy over the surgery. All metrics are computed in a relaxed 6 | % boundary mode. 
7 | % OUTPUT: 8 | % res: the jaccard index per phase (relaxed) - NaN for non existing phase in GT 9 | % prec: precision per phase (relaxed) - NaN for non existing phase in GT 10 | % rec: recall per phase (relaxed) - NaN for non existing phase in GT 11 | % acc: the accuracy over the video (relaxed) 12 | 13 | oriT = 10 * fps; % 10 seconds relaxed boundary 14 | 15 | res = []; prec = []; rec = []; 16 | diff = predLabelID - gtLabelID; 17 | updatedDiff = []; 18 | 19 | % obtain the true positive with relaxed boundary 20 | for iPhase = 1:8 % nPhases 21 | gtConn = bwconncomp(gtLabelID == iPhase); 22 | 23 | for iConn = 1:gtConn.NumObjects 24 | startIdx = min(gtConn.PixelIdxList{iConn}); 25 | endIdx = max(gtConn.PixelIdxList{iConn}); 26 | 27 | curDiff = diff(startIdx:endIdx); 28 | 29 | % in the case where the phase is shorter than the relaxed boundary 30 | t = oriT; 31 | if(t > length(curDiff)) 32 | t = length(curDiff); 33 | disp(['Very short phase ' num2str(iPhase)]); 34 | end 35 | 36 | % relaxed boundary 37 | if(iPhase == 5 || iPhase == 6) % Gallbladder dissection and packaging might jump between two phases 38 | curDiff(curDiff(1:t)==-1) = 0; % late transition 39 | curDiff(curDiff(end-t+1:end)==1 | curDiff(end-t+1:end)==2) = 0; % early transition % 5 can be predicted as 6/7 at the end > 5 followed by 6/7 40 | elseif(iPhase == 7 || iPhase == 8) % Gallbladder dissection might jump between two phases 41 | curDiff(curDiff(1:t)==-1 | curDiff(1:t)==-2) = 0; % late transition 42 | curDiff(curDiff(end-t+1:end)==1 | curDiff(end-t+1:end)==2) = 0; % early transition 43 | else 44 | % general situation 45 | curDiff(curDiff(1:t)==-1) = 0; % late transition 46 | curDiff(curDiff(end-t+1:end)==1) = 0; % early transition 47 | end 48 | 49 | updatedDiff(startIdx:endIdx) = curDiff; 50 | end 51 | end 52 | 53 | % compute jaccard index, prec, and rec per phase 54 | for iPhase = 1:8 55 | gtConn = bwconncomp(gtLabelID == iPhase); 56 | predConn = bwconncomp(predLabelID == iPhase); 57 | 58 | if(gtConn.NumObjects == 0) 59 | % no iPhase in current ground truth, assigned NaN values 60 | % SHOULD be excluded in the computation of mean (use nanmean) 61 | res(end+1) = NaN; 62 | prec(end+1) = NaN; 63 | rec(end+1) = NaN; 64 | continue; 65 | end 66 | 67 | iPUnion = union(vertcat(predConn.PixelIdxList{:}), vertcat(gtConn.PixelIdxList{:})); 68 | tp = sum(updatedDiff(iPUnion) == 0); 69 | jaccard = tp/length(iPUnion); 70 | jaccard = jaccard * 100; 71 | 72 | % res(end+1, :) = [iPhase jaccard]; 73 | res(end+1) = jaccard; 74 | 75 | % Compute prec and rec 76 | indx = (gtLabelID == iPhase); 77 | 78 | sumTP = tp; % sum(predLabelID(indx) == iPhase); 79 | sumPred = sum(predLabelID == iPhase); 80 | sumGT = sum(indx); 81 | 82 | prec(end+1) = sumTP * 100 / sumPred; 83 | rec(end+1) = sumTP * 100 / sumGT; 84 | end 85 | 86 | % compute accuracy 87 | acc = sum(updatedDiff==0) / length(gtLabelID); 88 | acc = acc * 100; 89 | 90 | end 91 | 92 | -------------------------------------------------------------------------------- /evaluation_matlab/Main.m: -------------------------------------------------------------------------------- 1 | close all; clear all; 2 | 3 | phaseGroundTruths = {}; 4 | gt_root_folder = 'phase_annotations/'; 5 | for k = 41:80 6 | num = num2str(k); 7 | to_add = ['video-' num]; 8 | video_name = [gt_root_folder to_add '.txt']; 9 | phaseGroundTruths = [phaseGroundTruths video_name]; 10 | end 11 | % phaseGroundTruths = {'video41-phase.txt', ... 
12 | % 'video42-phase.txt'};
13 | % phaseGroundTruths
14 |
15 | phases = {'Preparation', 'CalotTriangleDissection', ...
16 | 'ClippingCutting', 'GallbladderDissection', 'GallbladderPackaging', 'CleaningCoagulation', ...
17 | 'GallbladderRetraction'};
18 |
19 | fps = 1;
20 | jaccard1 = zeros(7, 40);
21 | prec1 = zeros(7, 40);
22 | rec1 = zeros(7, 40);
23 | acc1 = zeros(1, 40);
24 | f11 = zeros(7, 40);
25 |
26 |
27 | for i = 1:length(phaseGroundTruths)
28 | predroot = 'prediction/';
29 | phaseGroundTruth = phaseGroundTruths{i};
30 | predFile = [predroot 'video-' phaseGroundTruth(end-5:end-4) '.txt'];
31 | [gt] = ReadPhaseLabel(phaseGroundTruth);
32 | [pred] = ReadPhaseLabel(predFile);
33 |
34 | if(size(gt{1}, 1) ~= size(pred{1},1) || size(gt{2}, 1) ~= size(pred{2},1))
35 | error(['ERROR:' phaseGroundTruth '\nGround truth and prediction have different sizes']);
36 | end
37 |
38 | if(~isempty(find(gt{1} ~= pred{1})))
39 | error(['ERROR: ' phaseGroundTruth '\nThe frame index in ground truth and prediction is not equal']);
40 | end
41 |
42 | t = length(gt{2});
43 | for z=1:t
44 | gt{2}{z} = gt{2}{z}(1);
45 | pred{2}{z} = pred{2}{z}(1);
46 | end
47 | % reassigning the phase labels to numbers
48 | gtLabelID = [];
49 | predLabelID = [];
50 | for j = 1:7
51 | gtLabelID(find(strcmp(num2str(j-1), gt{2}))) = j;
52 | predLabelID(find(strcmp(num2str(j-1), pred{2}))) = j;
53 | end
54 |
55 | % compute jaccard index, precision, recall, and the accuracy
56 | [jaccard, prec, rec, acc, f1] = Evaluate(gtLabelID, predLabelID, fps);
57 | jaccard1(:, i) = jaccard;
58 | prec1(:, i) = prec;
59 | rec1(:, i) = rec;
60 | acc1(i) = acc;
61 | f11(:, i) = f1;
62 | end
63 |
64 | acc = acc1;
65 | rec = rec1;
66 | prec = prec1;
67 | jaccard = jaccard1;
68 | f1 = f11;
69 |
70 | accPerVideo= acc;
71 |
72 | % Compute means and stds
73 | index = find(jaccard>100);
74 | jaccard(index)=100;
75 | meanJaccPerPhase = nanmean(jaccard, 2);
76 | meanJaccPerVideo = nanmean(jaccard, 1);
77 | meanJacc = mean(meanJaccPerPhase);
78 | stdJacc = std(meanJaccPerPhase);
79 | for h = 1:7
80 | jaccphase = jaccard(h,:);
81 | meanjaccphase(h) = nanmean(jaccphase);
82 | stdjaccphase(h) = nanstd(jaccphase);
83 | end
84 |
85 | index = find(f1>100);
86 | f1(index)=100;
87 | meanF1PerPhase = nanmean(f1, 2);
88 | meanF1PerVideo = nanmean(f1, 1);
89 | meanF1 = mean(meanF1PerPhase);
90 | stdF1 = std(meanF1PerVideo);
91 | for h = 1:7
92 | f1phase = f1(h,:);
93 | meanf1phase(h) = nanmean(f1phase);
94 | stdf1phase(h) = nanstd(f1phase);
95 | end
96 |
97 | index = find(prec>100);
98 | prec(index)=100;
99 | meanPrecPerPhase = nanmean(prec, 2);
100 | meanPrecPerVideo = nanmean(prec, 1);
101 | meanPrec = nanmean(meanPrecPerPhase);
102 | stdPrec = nanstd(meanPrecPerPhase);
103 | for h = 1:7
104 | precphase = prec(h,:);
105 | meanprecphase(h) = nanmean(precphase);
106 | stdprecphase(h) = nanstd(precphase);
107 | end
108 |
109 | index = find(rec>100);
110 | rec(index)=100;
111 | meanRecPerPhase = nanmean(rec, 2);
112 | meanRecPerVideo = nanmean(rec, 1);
113 | meanRec = mean(meanRecPerPhase);
114 | stdRec = std(meanRecPerPhase);
115 | for h = 1:7
116 | recphase = rec(h,:);
117 | meanrecphase(h) = nanmean(recphase);
118 | stdrecphase(h) = nanstd(recphase);
119 | end
120 |
121 |
122 | meanAcc = mean(acc);
123 | stdAcc = std(acc);
124 |
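For readers following along in Python, a simplified sketch of the relaxed-boundary scoring that Evaluate.m applies before these statistics (general case only; the phase-jump special cases for phases 4-7 are omitted). The function name and the integer-label inputs are assumptions for illustration.

import numpy as np

def relaxed_accuracy(gt, pred, fps=1, n_phases=7, t_sec=10):
    diff = (pred - gt).astype(int)
    t = t_sec * fps  # 10-second relaxed boundary, as in Evaluate.m
    for phase in range(1, n_phases + 1):
        idx = np.flatnonzero(gt == phase)
        if idx.size == 0:
            continue
        # contiguous segments of this phase (the bwconncomp equivalent)
        for seg in np.split(idx, np.flatnonzero(np.diff(idx) > 1) + 1):
            tt = min(t, seg.size)
            head, tail = seg[:tt], seg[-tt:]
            diff[head[diff[head] == -1]] = 0  # tolerate late transitions
            diff[tail[diff[tail] == 1]] = 0   # tolerate early transitions
    return 100.0 * np.mean(diff == 0)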
125 | % Display results
126 | % fprintf('model is :%s\n', model_rootfolder);
127 | disp('================================================');
128 | disp([sprintf('%25s', 'Phase') '|' sprintf('%6s', 'Jacc') '|'...
129 | sprintf('%6s', 'Prec') '|' sprintf('%6s', 'Rec') '|']);
130 | disp('================================================');
131 | for iPhase = 1:length(phases)
132 | disp([sprintf('%25s', phases{iPhase}) '|' sprintf('%6.2f', meanJaccPerPhase(iPhase)) '|' ...
133 | sprintf('%6.2f', meanPrecPerPhase(iPhase)) '|' sprintf('%6.2f', meanRecPerPhase(iPhase)) '|']);
134 | disp('---------------------------------------------');
135 | end
136 | disp('================================================');
137 |
138 | disp(['Mean jaccard: ' sprintf('%5.2f', meanJacc) '+-' sprintf('%5.2f', stdJacc)]);
139 | disp(['Mean f1-score: ' sprintf('%5.2f', meanF1) '+-' sprintf('%5.2f', stdF1)]);
140 | disp(['Mean accuracy: ' sprintf('%5.2f', meanAcc) '+-' sprintf('%5.2f', stdAcc)]);
141 | disp(['Mean precision: ' sprintf('%5.2f', meanPrec) '+-' sprintf('%5.2f', stdPrec)]);
142 | disp(['Mean recall: ' sprintf('%5.2f', meanRec) '+-' sprintf('%5.2f', stdRec)]);
-------------------------------------------------------------------------------- /evaluation_matlab/Main_AutoLaparo.m: --------------------------------------------------------------------------------
1 | close all; clear all;
2 |
3 | phaseGroundTruths = {};
4 | gt_root_folder = 'phase_annotations/';
5 | for k = 15:21
6 | num = num2str(k);
7 | to_add = ['video-' num];
8 | video_name = [gt_root_folder to_add '.txt'];
9 | phaseGroundTruths = [phaseGroundTruths video_name];
10 | end
11 | % phaseGroundTruths = {'video41-phase.txt', ...
12 | % 'video42-phase.txt'};
13 | % phaseGroundTruths
14 |
15 | % NOTE: the phase names below are carried over from the Cholec80 script and
16 | % serve only as row labels in the printed table.
17 | phases = {'Preparation', 'CalotTriangleDissection', ...
18 | 'ClippingCutting', 'GallbladderDissection', 'GallbladderPackaging', 'CleaningCoagulation', ...
19 | 'GallbladderRetraction'};
20 |
21 | fps = 1;
22 | jaccard1 = zeros(7, 7);
23 | prec1 = zeros(7, 7);
24 | rec1 = zeros(7, 7);
25 | acc1 = zeros(1, 7);
26 | f11 = zeros(7, 7);
27 |
28 |
29 | for i = 1:length(phaseGroundTruths)
30 | predroot = 'prediction/';
31 | %predroot = '../../Results/multi/phase';
32 | %predroot = '../../Results/multi_kl_best_890_882/phase_post';
33 | phaseGroundTruth = phaseGroundTruths{i};
34 | predFile = [predroot 'video-' phaseGroundTruth(end-5:end-4) '.txt'];
35 | [gt] = ReadPhaseLabel(phaseGroundTruth);
36 | [pred] = ReadPhaseLabel(predFile);
37 |
38 | if(size(gt{1}, 1) ~= size(pred{1},1) || size(gt{2}, 1) ~= size(pred{2},1))
39 | error(['ERROR:' phaseGroundTruth '\nGround truth and prediction have different sizes']);
40 | end
41 |
42 | if(~isempty(find(gt{1} ~= pred{1})))
43 | error(['ERROR: ' phaseGroundTruth '\nThe frame index in ground truth and prediction is not equal']);
44 | end
45 |
46 | t = length(gt{2});
47 | for z=1:t
48 | gt{2}{z} = gt{2}{z}(1);
49 | pred{2}{z} = pred{2}{z}(1);
50 | end
51 | % reassigning the phase labels to numbers
52 | gtLabelID = [];
53 | predLabelID = [];
54 | for j = 1:7
55 | gtLabelID(find(strcmp(num2str(j-1), gt{2}))) = j;
56 | predLabelID(find(strcmp(num2str(j-1), pred{2}))) = j;
57 | end
58 |
59 | % compute jaccard index, precision, recall, and the accuracy
60 | [jaccard, prec, rec, acc, f1] = Evaluate(gtLabelID, predLabelID, fps);
61 | jaccard1(:, i) = jaccard;
62 | prec1(:, i) = prec;
63 | rec1(:, i) = rec;
64 | acc1(i) = acc;
65 | f11(:, i) = f1;
66 | end
67 |
68 | acc = acc1;
69 | rec = rec1;
70 | prec = prec1;
71 | jaccard = jaccard1;
72 | f1 = f11;
73 |
74 | accPerVideo= acc;
75 |
76 | % Compute means and stds
77 | index = find(jaccard>100);
78 | jaccard(index)=100;
79 | meanJaccPerPhase = nanmean(jaccard, 2);
80 | meanJaccPerVideo =
nanmean(jaccard, 1); 79 | meanJacc = mean(meanJaccPerPhase); 80 | stdJacc = std(meanJaccPerPhase); 81 | for h = 1:7 82 | jaccphase = jaccard(h,:); 83 | meanjaccphase(h) = nanmean(jaccphase); 84 | stdjaccphase(h) = nanstd(jaccphase); 85 | end 86 | 87 | index = find(f1>100); 88 | f1(index)=100; 89 | meanF1PerPhase = nanmean(f1, 2); 90 | meanF1PerVideo = nanmean(f1, 1); 91 | meanF1 = mean(meanF1PerPhase); 92 | stdF1 = std(meanF1PerVideo); 93 | for h = 1:7 94 | f1phase = f1(h,:); 95 | meanf1phase(h) = nanmean(f1phase); 96 | stdf1phase(h) = nanstd(f1phase); 97 | end 98 | 99 | index = find(prec>100); 100 | prec(index)=100; 101 | meanPrecPerPhase = nanmean(prec, 2); 102 | meanPrecPerVideo = nanmean(prec, 1); 103 | meanPrec = nanmean(meanPrecPerPhase); 104 | stdPrec = nanstd(meanPrecPerPhase); 105 | for h = 1:7 106 | precphase = prec(h,:); 107 | meanprecphase(h) = nanmean(precphase); 108 | stdprecphase(h) = nanstd(precphase); 109 | end 110 | 111 | index = find(rec>100); 112 | rec(index)=100; 113 | meanRecPerPhase = nanmean(rec, 2); 114 | meanRecPerVideo = nanmean(rec, 1); 115 | meanRec = mean(meanRecPerPhase); 116 | stdRec = std(meanRecPerPhase); 117 | for h = 1:7 118 | recphase = rec(h,:); 119 | meanrecphase(h) = nanmean(recphase); 120 | stdrecphase(h) = nanstd(recphase); 121 | end 122 | 123 | 124 | meanAcc = mean(acc); 125 | stdAcc = std(acc); 126 | 127 | % Display results 128 | % fprintf('model is :%s\n', model_rootfolder); 129 | disp('================================================'); 130 | disp([sprintf('%25s', 'Phase') '|' sprintf('%6s', 'Jacc') '|'... 131 | sprintf('%6s', 'Prec') '|' sprintf('%6s', 'Rec') '|']); 132 | disp('================================================'); 133 | for iPhase = 1:length(phases) 134 | disp([sprintf('%25s', phases{iPhase}) '|' sprintf('%6.2f', meanJaccPerPhase(iPhase)) '|' ... 135 | sprintf('%6.2f', meanPrecPerPhase(iPhase)) '|' sprintf('%6.2f', meanRecPerPhase(iPhase)) '|']); 136 | disp('---------------------------------------------'); 137 | end 138 | disp('================================================'); 139 | 140 | disp(['Mean jaccard: ' sprintf('%5.2f', meanJacc) '+-' sprintf('%5.2f', stdJacc)]); 141 | disp(['Mean f1-score: ' sprintf('%5.2f', meanF1) '+-' sprintf('%5.2f', stdF1)]); 142 | disp(['Mean accuracy: ' sprintf('%5.2f', meanAcc) '+-' sprintf('%5.2f', stdAcc)]); 143 | disp(['Mean precision: ' sprintf('%5.2f', meanPrec) '+-' sprintf('%5.2f', stdPrec)]); 144 | disp(['Mean recall: ' sprintf('%5.2f', meanRec) '+-' sprintf('%5.2f', stdRec)]); -------------------------------------------------------------------------------- /evaluation_matlab/Main_Cataract101.m: -------------------------------------------------------------------------------- 1 | close all; clear all; 2 | 3 | phaseGroundTruths = {}; 4 | gt_root_folder = 'phase_annotations/'; 5 | for k = 110:149 6 | num = num2str(k); 7 | to_add = ['video-' num]; 8 | video_name = [gt_root_folder to_add '.txt']; 9 | phaseGroundTruths = [phaseGroundTruths video_name]; 10 | end 11 | % phaseGroundTruths = {'video41-phase.txt', ... 
12 | % 'video42-phase.txt'};
13 | % phaseGroundTruths
14 |
15 | phases = {"Incision",
16 | "Viscous agent injection",
17 | "Rhexis",
18 | "Hydrodissection",
19 | "Phacoemulsification",
20 | "Irrigation and aspiration",
21 | "Capsule polishing",
22 | "Lens implant setting-up",
23 | "Viscous agent removal",
24 | "Tonifying and antibiotics"};
25 |
26 | fps = 1;
27 | jaccard1 = zeros(10, 40);
28 | prec1 = zeros(10, 40);
29 | rec1 = zeros(10, 40);
30 | acc1 = zeros(1, 40);
31 | f11 = zeros(10, 40);
32 |
33 |
34 | for i = 1:length(phaseGroundTruths)
35 | predroot = 'prediction/';
36 | phaseGroundTruth = phaseGroundTruths{i};
37 | predFile = [predroot 'video-' phaseGroundTruth(end-6:end-4) '.txt'];
38 | [gt] = ReadPhaseLabel(phaseGroundTruth);
39 | [pred] = ReadPhaseLabel(predFile);
40 |
41 | if(size(gt{1}, 1) ~= size(pred{1},1) || size(gt{2}, 1) ~= size(pred{2},1))
42 | error(['ERROR:' phaseGroundTruth '\nGround truth and prediction have different sizes']);
43 | end
44 |
45 | if(~isempty(find(gt{1} ~= pred{1})))
46 | error(['ERROR: ' phaseGroundTruth '\nThe frame index in ground truth and prediction is not equal']);
47 | end
48 |
49 | t = length(gt{2});
50 | for z=1:t
51 | gt{2}{z} = gt{2}{z}(1);
52 | pred{2}{z} = pred{2}{z}(1);
53 | end
54 | % reassigning the phase labels to numbers
55 | gtLabelID = [];
56 | predLabelID = [];
57 | for j = 1:10
58 | gtLabelID(find(strcmp(num2str(j-1), gt{2}))) = j;
59 | predLabelID(find(strcmp(num2str(j-1), pred{2}))) = j;
60 | end
61 |
62 | % compute jaccard index, precision, recall, and the accuracy
63 | [jaccard, prec, rec, acc, f1] = Evaluate_Cataract101(gtLabelID, predLabelID, fps);
64 | jaccard1(:, i) = jaccard;
65 | prec1(:, i) = prec;
66 | rec1(:, i) = rec;
67 | acc1(i) = acc;
68 | f11(:, i) = f1;
69 | end
70 |
71 | acc = acc1;
72 | rec = rec1;
73 | prec = prec1;
74 | jaccard = jaccard1;
75 | f1 = f11;
76 |
77 | accPerVideo= acc;
78 |
79 | % Compute means and stds
80 | index = find(jaccard>100);
81 | jaccard(index)=100;
82 | meanJaccPerPhase = nanmean(jaccard, 2);
83 | meanJaccPerVideo = nanmean(jaccard, 1);
84 | meanJacc = mean(meanJaccPerPhase);
85 | stdJacc = std(meanJaccPerPhase);
86 | for h = 1:10
87 | jaccphase = jaccard(h,:);
88 | meanjaccphase(h) = nanmean(jaccphase);
89 | stdjaccphase(h) = nanstd(jaccphase);
90 | end
91 |
92 | index = find(f1>100);
93 | f1(index)=100;
94 | meanF1PerPhase = nanmean(f1, 2);
95 | meanF1PerVideo = nanmean(f1, 1);
96 | meanF1 = mean(meanF1PerPhase);
97 | stdF1 = std(meanF1PerVideo);
98 | for h = 1:10
99 | f1phase = f1(h,:);
100 | meanf1phase(h) = nanmean(f1phase);
101 | stdf1phase(h) = nanstd(f1phase);
102 | end
103 |
104 | index = find(prec>100);
105 | prec(index)=100;
106 | meanPrecPerPhase = nanmean(prec, 2);
107 | meanPrecPerVideo = nanmean(prec, 1);
108 | meanPrec = nanmean(meanPrecPerPhase);
109 | stdPrec = nanstd(meanPrecPerPhase);
110 | for h = 1:10
111 | precphase = prec(h,:);
112 | meanprecphase(h) = nanmean(precphase);
113 | stdprecphase(h) = nanstd(precphase);
114 | end
115 |
116 | index = find(rec>100);
117 | rec(index)=100;
118 | meanRecPerPhase = nanmean(rec, 2);
119 | meanRecPerVideo = nanmean(rec, 1);
120 | meanRec = mean(meanRecPerPhase);
121 | stdRec = std(meanRecPerPhase);
122 | for h = 1:10
123 | recphase = rec(h,:);
124 | meanrecphase(h) = nanmean(recphase);
125 | stdrecphase(h) = nanstd(recphase);
126 | end
127 |
128 |
129 | meanAcc = mean(acc);
130 | stdAcc = std(acc);
131 |
132 | % Display results
133 | % fprintf('model is :%s\n', model_rootfolder);
134 |
disp('================================================'); 135 | disp([sprintf('%25s', 'Phase') '|' sprintf('%6s', 'Jacc') '|'... 136 | sprintf('%6s', 'Prec') '|' sprintf('%6s', 'Rec') '|']); 137 | disp('================================================'); 138 | for iPhase = 1:length(phases) 139 | disp([sprintf('%25s', phases{iPhase}) '|' sprintf('%6.2f', meanJaccPerPhase(iPhase)) '|' ... 140 | sprintf('%6.2f', meanPrecPerPhase(iPhase)) '|' sprintf('%6.2f', meanRecPerPhase(iPhase)) '|']); 141 | disp('---------------------------------------------'); 142 | end 143 | disp('================================================'); 144 | 145 | disp(['Mean jaccard: ' sprintf('%5.2f', meanJacc) '+-' sprintf('%5.2f', stdJacc)]); 146 | disp(['Mean f1-score: ' sprintf('%5.2f', meanF1) '+-' sprintf('%5.2f', stdF1)]); 147 | disp(['Mean accuracy: ' sprintf('%5.2f', meanAcc) '+-' sprintf('%5.2f', stdAcc)]); 148 | disp(['Mean precision: ' sprintf('%5.2f', meanPrec) '+-' sprintf('%5.2f', stdPrec)]); 149 | disp(['Mean recall: ' sprintf('%5.2f', meanRec) '+-' sprintf('%5.2f', stdRec)]); 150 | -------------------------------------------------------------------------------- /evaluation_matlab/Main_m2cai.m: -------------------------------------------------------------------------------- 1 | close all; clear all; 2 | 3 | phaseGroundTruths = {}; 4 | gt_root_folder = '../gt-phase/'; 5 | for k = 1:14 6 | num = num2str(k); 7 | to_add = ['video' num]; 8 | video_name = [gt_root_folder to_add '-phase.txt']; 9 | phaseGroundTruths = [phaseGroundTruths video_name]; 10 | end 11 | % phaseGroundTruths = {'video41-phase.txt', ... 12 | % 'video42-phase.txt'}; 13 | % phaseGroundTruths 14 | 15 | phases = {'TrocarPlacement', 'Preparation', 'CalotTriangleDissection', ... 16 | 'ClippingCutting', 'GallbladderDissection', 'GallbladderPackaging', 'CleaningCoagulation', ... 
17 | 'GallbladderRetraction'}; 18 | 19 | fps = 25; 20 | 21 | for i = 1:length(phaseGroundTruths) 22 | predroot = '../phase/'; 23 | %predroot = '../../Results/multi/phase'; 24 | %predroot = '../../Results/multi_kl_best_890_882/phase_post'; 25 | phaseGroundTruth = phaseGroundTruths{i}; 26 | predFile = [predroot phaseGroundTruth(13:end-10) '-phase.txt']; 27 | 28 | [gt] = ReadPhaseLabel(phaseGroundTruth); 29 | [pred] = ReadPhaseLabel(predFile); 30 | 31 | if(size(gt{1}, 1) ~= size(pred{1},1) || size(gt{2}, 1) ~= size(pred{2},1)) 32 | error(['ERROR: ' phaseGroundTruth '\nGround truth and prediction have different sizes']); 33 | end 34 | 35 | if(~isempty(find(gt{1} ~= pred{1}))) 36 | error(['ERROR: ' phaseGroundTruth '\nThe frame index in ground truth and prediction is not equal']); 37 | end 38 | 39 | % reassigning the phase labels to numbers 40 | gtLabelID = []; 41 | predLabelID = []; 42 | for j = 1:8 43 | gtLabelID(find(strcmp(num2str(j-1), gt{2}))) = j; 44 | predLabelID(find(strcmp(num2str(j-1), pred{2}))) = j; 45 | end 46 | 47 | % compute jaccard index, precision, recall, and the accuracy 48 | [jaccard(:,i), prec(:,i), rec(:,i), acc(i)] = Evaluate_m2cai(gtLabelID, predLabelID, fps); 49 | 50 | end 51 | 52 | % Compute means and stds 53 | index = find(jaccard>100); 54 | jaccard(index)=100; 55 | meanJaccPerPhase = nanmean(jaccard, 2); 56 | meanJacc = mean(meanJaccPerPhase); 57 | stdJacc = std(meanJaccPerPhase); 58 | for h = 1:8 59 | jaccphase = jaccard(h,:); 60 | meanjaccphase(h) = nanmean(jaccphase); 61 | stdjaccphase(h) = nanstd(jaccphase); 62 | end 63 | 64 | index = find(prec>100); 65 | prec(index)=100; 66 | meanPrecPerPhase = nanmean(prec, 2); 67 | meanPrec = nanmean(meanPrecPerPhase); 68 | stdPrec = nanstd(meanPrecPerPhase); 69 | for h = 1:8 70 | precphase = prec(h,:); 71 | meanprecphase(h) = nanmean(precphase); 72 | stdprecphase(h) = nanstd(precphase); 73 | end 74 | 75 | index = find(rec>100); 76 | rec(index)=100; 77 | meanRecPerPhase = nanmean(rec, 2); 78 | meanRec = mean(meanRecPerPhase); 79 | stdRec = std(meanRecPerPhase); 80 | for h = 1:8 81 | recphase = rec(h,:); 82 | meanrecphase(h) = nanmean(recphase); 83 | stdrecphase(h) = nanstd(recphase); 84 | end 85 | 86 | 87 | meanAcc = mean(acc); 88 | stdAcc = std(acc); 89 | 90 | % Display results 91 | % fprintf('model is :%s\n', model_rootfolder); 92 | disp('================================================'); 93 | disp([sprintf('%25s', 'Phase') '|' sprintf('%6s', 'Jacc') '|'... 94 | sprintf('%6s', 'Prec') '|' sprintf('%6s', 'Rec') '|']); 95 | disp('================================================'); 96 | for iPhase = 1:length(phases) 97 | disp([sprintf('%25s', phases{iPhase}) '|' sprintf('%6.2f', meanJaccPerPhase(iPhase)) '|' ...
98 | sprintf('%6.2f', meanPrecPerPhase(iPhase)) '|' sprintf('%6.2f', meanRecPerPhase(iPhase)) '|']); 99 | disp('---------------------------------------------'); 100 | end 101 | disp('================================================'); 102 | 103 | disp(['Mean jaccard: ' sprintf('%5.2f', meanJacc) ' +- ' sprintf('%5.2f', stdJacc)]); 104 | disp(['Mean accuracy: ' sprintf('%5.2f', meanAcc) ' +- ' sprintf('%5.2f', stdAcc)]); 105 | disp(['Mean precision: ' sprintf('%5.2f', meanPrec) ' +- ' sprintf('%5.2f', stdPrec)]); 106 | disp(['Mean recall: ' sprintf('%5.2f', meanRec) ' +- ' sprintf('%5.2f', stdRec)]); 107 | -------------------------------------------------------------------------------- /evaluation_matlab/README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ### MATLAB scripts to perform the evaluation 4 | 5 | ### Acknowledgement: 6 | MICCAI M2CAI challenge; the official webpage of the challenge can be found here: http://camma.u-strasbg.fr/m2cai2016 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /evaluation_matlab/ReadPhaseLabel.m: -------------------------------------------------------------------------------- 1 | function [ outp ] = ReadPhaseLabel( file ) 2 | %READPHASELABEL 3 | % Read the phase label (annotation and prediction) 4 | 5 | fid = fopen(file, 'r'); 6 | 7 | % read the header first 8 | tline = fgets(fid); 9 | 10 | % read the labels 11 | [outp] = textscan(fid, '%d\t%s\n', 'EndOfLine','\n' ); 12 | % [outp] = textscan(fid, '%d\t%s\n'); 13 | end -------------------------------------------------------------------------------- /evaluation_matlab/matlab.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/evaluation_matlab/matlab.mat -------------------------------------------------------------------------------- /evaluation_matlab/octave-workspace: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/evaluation_matlab/octave-workspace -------------------------------------------------------------------------------- /evaluation_matlab/test.m: -------------------------------------------------------------------------------- 1 | fid = fopen(file, 'r'); 2 | tline = fgets(fid); 3 | [outp] = textscan(fid, '%d %s\n'); 4 | -------------------------------------------------------------------------------- /model/__pycache__/mambapy.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/model/__pycache__/mambapy.cpython-311.pyc -------------------------------------------------------------------------------- /model/__pycache__/pmnet.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/model/__pycache__/pmnet.cpython-311.pyc -------------------------------------------------------------------------------- /model/__pycache__/pscan.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/model/__pycache__/pscan.cpython-311.pyc 
-------------------------------------------------------------------------------- /model/pmnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from functools import partial 4 | import utils 5 | from timm.models.layers import drop_path, to_2tuple, trunc_normal_ 6 | from timm.models.registry import register_model 7 | from einops import rearrange 8 | from collections import OrderedDict 9 | import torch.nn.functional as F 10 | import torchvision 11 | from model.mambapy import Mamba_CSM, MambaConfig 12 | 13 | def _cfg(url="", **kwargs): 14 | return { 15 | "url": url, 16 | "num_classes": 7, 17 | "input_size": (3, 224, 224), 18 | "pool_size": None, 19 | "crop_pct": 0.9, 20 | "interpolation": "bicubic", 21 | "mean": (0.5, 0.5, 0.5), 22 | "std": (0.5, 0.5, 0.5), 23 | **kwargs, 24 | } 25 | 26 | 27 | def crop_and_pool(images, bboxes): 28 | """ 29 | Crop each frame with its bounding box and apply average pooling. 30 | 31 | Args: 32 | images: tensor of shape (B, 3, T, 224, 224). 33 | bboxes: tensor of shape (B, T, 2, 2), the bbox for each time step. 34 | 35 | Returns: 36 | A tensor of shape (B, 3, T, 1, 1). 37 | """ 38 | B, C, T, H, W = images.shape 39 | output = torch.zeros(B, C, T, 1, 1, device=images.device) # initialize the output tensor 40 | 41 | for b in range(B): # iterate over the batch 42 | for t in range(T): # iterate over the time steps 43 | # get the bbox for the current time step 44 | x1, y1 = bboxes[b, t, 0, 0], bboxes[b, t, 0, 1] 45 | x2, y2 = bboxes[b, t, 1, 0], bboxes[b, t, 1, 1] 46 | 47 | # crop the frame 48 | cropped = images[b, :, t, y1:y2, x1:x2] # shape (3, h, w) 49 | # global average pooling over the crop 50 | pooled = F.avg_pool2d(cropped.unsqueeze(0), (cropped.shape[1], cropped.shape[2])) # shape (1, 3, 1, 1) 51 | 52 | # store the result in the output tensor 53 | output[b, :, t] = pooled.squeeze(0) 54 | 55 | return output 56 | 57 | class DropPath(nn.Module): 58 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" 59 | 60 | def __init__(self, drop_prob=None): 61 | super(DropPath, self).__init__() 62 | self.drop_prob = drop_prob 63 | 64 | def forward(self, x): 65 | return drop_path(x, self.drop_prob, self.training) 66 | 67 | def extra_repr(self) -> str: 68 | return "p={}".format(self.drop_prob) 69 | 70 | 71 | class Mlp(nn.Module): 72 | def __init__( 73 | self, 74 | in_features, 75 | hidden_features=None, 76 | out_features=None, 77 | act_layer=nn.GELU, 78 | drop=0.0, 79 | ): 80 | super().__init__() 81 | out_features = out_features or in_features 82 | hidden_features = hidden_features or in_features 83 | self.fc1 = nn.Linear(in_features, hidden_features) 84 | self.act = act_layer() 85 | self.fc2 = nn.Linear(hidden_features, out_features) 86 | self.drop = nn.Dropout(drop) 87 | 88 | def forward(self, x): 89 | x = self.fc1(x) 90 | x = self.act(x) 91 | x = self.drop(x) 92 | x = self.fc2(x) 93 | x = self.drop(x) 94 | return x 95 | 96 | class CrossAttention(nn.Module): 97 | def __init__( 98 | self, 99 | dim, 100 | num_heads=8, 101 | qkv_bias=False, 102 | qk_scale=None, 103 | attn_drop=0.0, 104 | proj_drop=0.0, 105 | ): 106 | super().__init__() 107 | self.num_heads = num_heads 108 | head_dim = dim // num_heads 109 | self.scale = qk_scale or head_dim**-0.5 110 | 111 | # Separate Linear layers for Query, Key, and Value 112 | self.q_proj = nn.Linear(dim, dim, bias=qkv_bias) 113 | self.k_proj = nn.Linear(dim, dim, bias=qkv_bias) 114 | self.v_proj = nn.Linear(dim, dim, bias=qkv_bias) 115 | 116 | self.proj = nn.Linear(dim, dim) 117 | self.proj_drop = nn.Dropout(proj_drop) 118 | self.attn_drop = nn.Dropout(attn_drop) 119 | 120 | def forward(self, query, key_value, B): 121 | """ 122 | Args: 123 | query:
Tensor of shape (B * K, T_query, C) 124 | key_value: Tensor of shape (B * K, T_key, C) 125 | B: Batch size 126 | """ 127 | BK_query, T_query, C = query.shape 128 | BK_key, T_key, _ = key_value.shape 129 | K_query = BK_query // B 130 | K_key = BK_key // B 131 | 132 | # Generate Query, Key, Value 133 | q = self.q_proj(query) # (B * K_query, T_query, C) 134 | k = self.k_proj(key_value) # (B * K_key, T_key, C) 135 | v = self.v_proj(key_value) # (B * K_key, T_key, C) 136 | 137 | # Reshape for multi-head attention 138 | q = rearrange(q, "(b k) t (num_heads c) -> (b k) num_heads t c", k=K_query, num_heads=self.num_heads) 139 | k = rearrange(k, "(b k) t (num_heads c) -> (b k) num_heads t c", k=K_key, num_heads=self.num_heads) 140 | v = rearrange(v, "(b k) t (num_heads c) -> (b k) num_heads t c", k=K_key, num_heads=self.num_heads) 141 | 142 | # Compute attention scores 143 | attn = (q @ k.transpose(-2, -1)) * self.scale # (B * K_query, num_heads, T_query, T_key) 144 | attn = attn.softmax(dim=-1) 145 | attn = self.attn_drop(attn) 146 | 147 | # Compute attention output 148 | x = attn @ v # (B * K_query, num_heads, T_query, C_head) 149 | x = rearrange(x, "(b k) num_heads t c -> (b k) t (num_heads c)", b=B) 150 | 151 | # Project back to original dimension 152 | x = self.proj(x) 153 | return self.proj_drop(x) 154 | 155 | class Attention_Spatial(nn.Module): 156 | def __init__( 157 | self, 158 | dim, 159 | num_heads=8, 160 | qkv_bias=False, 161 | qk_scale=None, 162 | attn_drop=0.0, 163 | proj_drop=0.0, 164 | with_qkv=True, 165 | ): 166 | super().__init__() 167 | self.num_heads = num_heads 168 | head_dim = dim // num_heads 169 | self.scale = qk_scale or head_dim**-0.5 170 | self.with_qkv = with_qkv 171 | if self.with_qkv: 172 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 173 | self.proj = nn.Linear(dim, dim) 174 | self.proj_drop = nn.Dropout(proj_drop) 175 | self.attn_drop = nn.Dropout(attn_drop) 176 | 177 | def forward(self, x, B): 178 | BT, K, C = x.shape 179 | T = BT // B 180 | qkv = self.qkv(x) 181 | # For Intra-Spatial: (BT, heads, K, C) 182 | # Atten: K*K, Values: K*C 183 | qkv = rearrange( 184 | qkv, 185 | "(b t) k (qkv num_heads c) -> qkv (b t) num_heads k c", 186 | t=T, 187 | qkv=3, 188 | num_heads=self.num_heads, 189 | ) 190 | q, k, v = ( 191 | qkv[0], 192 | qkv[1], 193 | qkv[2], 194 | ) # make torchscript happy (cannot use tensor as tuple) 195 | 196 | attn = (q @ k.transpose(-2, -1)) * self.scale 197 | attn = attn.softmax(dim=-1) 198 | attn = self.attn_drop(attn) 199 | 200 | x = attn @ v 201 | x = rearrange( 202 | x, 203 | "(b t) num_heads k c -> (b t) k (num_heads c)", 204 | b=B, 205 | ) 206 | x = self.proj(x) 207 | return self.proj_drop(x) 208 | 209 | 210 | class Attention_Temporal(nn.Module): 211 | def __init__( 212 | self, 213 | dim, 214 | num_heads=8, 215 | qkv_bias=False, 216 | qk_scale=None, 217 | attn_drop=0.0, 218 | proj_drop=0.0, 219 | with_qkv=True, 220 | ): 221 | super().__init__() 222 | self.num_heads = num_heads 223 | head_dim = dim // num_heads 224 | self.scale = qk_scale or head_dim**-0.5 225 | self.with_qkv = with_qkv 226 | if self.with_qkv: 227 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 228 | self.proj = nn.Linear(dim, dim) 229 | self.proj_drop = nn.Dropout(proj_drop) 230 | self.attn_drop = nn.Dropout(attn_drop) 231 | 232 | def forward(self, x, B): 233 | BK, T, C = x.shape 234 | K = BK // B 235 | qkv = self.qkv(x) 236 | 237 | # For Intra-Temporal: (BK, heads, T, C) 238 | # Atten: T*T, Values: T*C 239 | qkv = rearrange( 240 | qkv, 241 | "(b k) t (qkv num_heads c)
-> qkv (b k) num_heads t c", 242 | k=K, 243 | qkv=3, 244 | num_heads=self.num_heads, 245 | ) 246 | q, k, v = ( 247 | qkv[0], 248 | qkv[1], 249 | qkv[2], 250 | ) # make torchscript happy (cannot use tensor as tuple) 251 | 252 | attn = (q @ k.transpose(-2, -1)) * self.scale 253 | attn = attn.softmax(dim=-1) 254 | attn = self.attn_drop(attn) 255 | 256 | x = attn @ v 257 | x = rearrange( 258 | x, 259 | "(b k) num_heads t c -> (b k) t (num_heads c)", 260 | b=B, 261 | ) 262 | 263 | x = self.proj(x) 264 | return self.proj_drop(x) 265 | 266 | class VisionTransformer(nn.Module): 267 | """Vision Transformer""" 268 | 269 | def __init__( 270 | self, 271 | img_size=224, 272 | patch_size=16, 273 | in_chans=3, 274 | num_classes=7, 275 | num_heads=12, 276 | qkv_bias=False, 277 | qk_scale=None, 278 | fc_drop_rate=0.0, 279 | drop_rate=0.0, 280 | attn_drop_rate=0.0, 281 | drop_path_rate=0.0, 282 | norm_layer=nn.LayerNorm, 283 | all_frames=16, 284 | ): 285 | super().__init__() 286 | embed_dim = 1536 287 | self.num_classes = num_classes 288 | self.num_features = ( 289 | self.embed_dim 290 | ) = embed_dim # num_features for consistency with other models 291 | backbone = torchvision.models.efficientnet_b3(weights=torchvision.models.EfficientNet_B3_Weights.DEFAULT) # pass a weights enum member, not the enum class 292 | self.backbone = nn.Sequential(*list(backbone.children())[:-1]) 293 | self.reduce_dim = nn.Linear(1536, 1536) 294 | self.NumberToVector = nn.Linear(1, 1536) 295 | self.norm1 = norm_layer(embed_dim) 296 | 297 | # knot and release feats 298 | self.act_feats = torch.zeros(1536).cuda() 299 | self.release_feats = torch.zeros(1536).cuda() 300 | self.knot_feats = torch.zeros(1536).cuda() 301 | self.act_cnt = torch.tensor(0, device='cuda') 302 | self.release_cnt = torch.tensor(0, device='cuda') 303 | self.knot_cnt = torch.tensor(0, device='cuda') 304 | self.alpha = 0.9 305 | 306 | # Temporal Attention Parameters 307 | self.temporal_norm1 = norm_layer(embed_dim) 308 | self.temporal_attn = Attention_Temporal( 309 | embed_dim, 310 | num_heads=num_heads, 311 | qkv_bias=qkv_bias, 312 | qk_scale=qk_scale, 313 | attn_drop=0.0, 314 | proj_drop=0.0, 315 | ) 316 | self.temporal_fc = nn.Linear(embed_dim, embed_dim) 317 | 318 | # SSM 319 | self.ca_norm = norm_layer(embed_dim) 320 | self.ca_attn = CrossAttention( 321 | embed_dim, 322 | num_heads=num_heads, 323 | qkv_bias=qkv_bias, 324 | qk_scale=qk_scale, 325 | attn_drop=0.0, 326 | proj_drop=0.0, 327 | ) 328 | config = MambaConfig(d_model=self.embed_dim, n_layers=2) 329 | self.ssm = Mamba_CSM(config) 330 | self.ca_fc = nn.Linear(embed_dim, embed_dim) 331 | 332 | ## Positional Embeddings 333 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 334 | self.cls_token_swap = nn.Parameter(torch.zeros(1, 5, 1, embed_dim)) 335 | self.time_embed = nn.Parameter(torch.zeros(1, all_frames, embed_dim)) 336 | self.time_drop = nn.Dropout(p=drop_rate) 337 | self.mask = nn.Parameter(torch.zeros(embed_dim)) 338 | 339 | 340 | self.norm = norm_layer(embed_dim) 341 | 342 | # Classifier head 343 | self.fc_dropout = ( 344 | nn.Dropout(p=fc_drop_rate) if fc_drop_rate > 0 else nn.Identity() 345 | ) 346 | self.head = ( 347 | nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 348 | ) 349 | self.fc_dropout_blocking = ( 350 | nn.Dropout(p=fc_drop_rate) if fc_drop_rate > 0 else nn.Identity() 351 | ) 352 | self.head_blocking = ( 353 | nn.Linear(embed_dim, 2) if num_classes > 0 else nn.Identity() 354 | ) 355 | trunc_normal_(self.cls_token, std=0.02) 356 | self.apply(self._init_weights) 357 | 358 | def _init_weights(self, m):
359 | if isinstance(m, nn.Linear): 360 | trunc_normal_(m.weight, std=0.02) 361 | if isinstance(m, nn.Linear) and m.bias is not None: 362 | nn.init.constant_(m.bias, 0) 363 | elif isinstance(m, nn.LayerNorm): 364 | nn.init.constant_(m.bias, 0) 365 | nn.init.constant_(m.weight, 1.0) 366 | 367 | @torch.jit.ignore 368 | def no_weight_decay(self): 369 | return {"pos_embed", "cls_token", "time_embed"} 370 | 371 | def get_classifier(self): 372 | return self.head 373 | 374 | def reset_classifier(self, num_classes, global_pool=""): 375 | self.num_classes = num_classes 376 | self.head = ( 377 | nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 378 | ) 379 | 380 | def forward(self, x, timestamp, bboxes, target=None): 381 | B, _, T, H, W = x.size() 382 | 383 | # pooling for RGB img/crop 384 | pooled_q = F.adaptive_avg_pool2d(x, (1, 1)) 385 | pooled_q += crop_and_pool(x, bboxes) 386 | ssm_q = pooled_q.squeeze(-1).squeeze(-1).permute(0, 2, 1) 387 | 388 | x = rearrange(x, "b c t h w -> (b t) c h w") 389 | x = self.backbone(x) 390 | x = torch.squeeze(x) 391 | x = self.reduce_dim(x) 392 | x = rearrange(x, "(b t) c -> b t c", b=B) 393 | xt = x + self.time_embed # B, T, C 394 | timestamp_embed = self.NumberToVector(timestamp.unsqueeze(1)) 395 | xt = xt + timestamp_embed.unsqueeze(1) 396 | xt = self.time_drop(xt) 397 | 398 | # temporal pooling 399 | kernel = torch.tensor([1 / 3, 1 / 3, 1 / 3], dtype=torch.float16).view(1, 1, -1) 400 | kernel = kernel.repeat(1536, 1, 1).to(xt.device) 401 | output = F.conv1d(xt.permute(0, 2, 1), kernel, padding=1, groups=1536) 402 | pooled = output.permute(0, 2, 1) # (B, N, C) 403 | pooled_last_group_0 = pooled[:, -1, :] 404 | 405 | # add masks 406 | if self.act_cnt != 0: 407 | dot_product_b = torch.matmul(pooled_last_group_0, self.act_feats.to(torch.float16)) 408 | norm_a = torch.norm(pooled_last_group_0, p=2, dim=1) 409 | norm_act = torch.norm(self.act_feats, p=2) 410 | 411 | similarity_act = dot_product_b / (norm_a * norm_act + 1e-8) 412 | similarity_idx = (similarity_act < 0) 413 | time_idx = (timestamp > 0.5) 414 | mask_idx = similarity_idx * time_idx 415 | if True in mask_idx: 416 | xt[mask_idx, -1, :] += self.mask 417 | 418 | groups = rearrange(xt, "b (n k) c -> b n k c", k=4) 419 | cls_token = self.cls_token_swap.expand(xt.size(0), -1, -1, -1) 420 | groups_with_cls = torch.cat([cls_token, groups], dim=2) 421 | 422 | num_interactions = 4 # 4 interactions for temporal integration 423 | for _ in range(num_interactions): 424 | groups_reshaped = rearrange(groups_with_cls, "b n k c -> (b n) k c") 425 | updated_groups = self.temporal_attn(self.temporal_norm1(groups_reshaped), B) 426 | updated_groups = self.temporal_fc(updated_groups) + groups_reshaped 427 | groups_with_cls = rearrange(updated_groups, "(b n) k c -> b n k c", b=B) 428 | 429 | # CLS token exchange 430 | cls_tokens = groups_with_cls[:, :, 0, :] 431 | cls_tokens = cls_tokens.roll(shifts=1, dims=1) 432 | groups_with_cls[:, :, 0, :] = cls_tokens 433 | 434 | groups_without_cls = groups_with_cls[:, :, 1:, :] 435 | 436 | last_group_feats = groups_without_cls[:, -1, :, :] 437 | pooled_last_group = last_group_feats[:, -1, :] 438 | 439 | # contrastive learning update prototype 440 | if self.training: 441 | act_idx = ((target == 1) + (target == 3)) 442 | knot_idx = (target == 1) 443 | release_idx = (target == 3) 444 | if True in act_idx: 445 | act_cnt = sum(act_idx) 446 | act_feat = pooled_last_group[act_idx] 447 | summed_act = act_feat.sum(dim=0) 448 | # self.act_feats = self.act_feats * 
(self.act_cnt/(self.act_cnt+act_cnt)) + summed_act/(self.act_cnt+act_cnt) 449 | self.act_feats = self.act_feats * self.alpha + (1 - self.alpha) * (summed_act / act_cnt) 450 | self.act_cnt += act_cnt 451 | 452 | if True in knot_idx: 453 | knot_cnt = sum(knot_idx) 454 | knot_feat = pooled_last_group[knot_idx] 455 | summed_knot = knot_feat.sum(dim=0) 456 | self.knot_feats = self.knot_feats * self.alpha + (1 - self.alpha) * (summed_knot / knot_cnt) 457 | self.knot_cnt += knot_cnt 458 | 459 | if True in release_idx: 460 | release_cnt = sum(release_idx) 461 | release_feat = pooled_last_group[release_idx] 462 | summed_release = release_feat.sum(dim=0) 463 | self.release_feats = self.release_feats * self.alpha + (1 - self.alpha) * (summed_release / release_cnt) 464 | self.release_cnt += release_cnt 465 | 466 | xt = rearrange(groups_without_cls, "b n k c -> b (n k) c") # (B, 20, C), flatten the groups back into a sequence 467 | 468 | ssm_xt = xt 469 | ssm_out = self.ssm(ssm_xt, ssm_q) 470 | 471 | xt = xt[:, -1, :] + x[:, -1, :] 472 | 473 | fused_xt = self.ca_attn(self.ca_norm(xt.unsqueeze(1)), self.ca_norm(ssm_out), B) # output: (B*4, 5, C) 474 | fused_xt = self.ca_fc(fused_xt)[:, -1, :] + xt 475 | 476 | output = self.head(self.fc_dropout(fused_xt)) 477 | output_blocking = self.head_blocking(self.fc_dropout_blocking(fused_xt)) 478 | 479 | output_idx = output.argmax(-1) 480 | knot_FP = ((target == 3) * (output_idx == 1)) 481 | release_FP = ((target == 1) * (output_idx == 3)) 482 | contrastive_dict = {} 483 | if True in knot_FP: 484 | knot_FP_feats = pooled_last_group[knot_FP] 485 | contrastive_dict['knot_FP_feats'] = knot_FP_feats 486 | if True in release_FP: 487 | release_FP_feats = pooled_last_group[release_FP] 488 | contrastive_dict['release_FP_feats'] = release_FP_feats 489 | 490 | return output, output_blocking, contrastive_dict 491 | 492 | @register_model 493 | def pmnet(pretrained=False, pretrain_path=None, **kwargs): 494 | model = VisionTransformer( 495 | img_size=224, 496 | patch_size=16, 497 | num_heads=12, 498 | qkv_bias=True, 499 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 500 | **kwargs, 501 | ) 502 | model.default_cfg = _cfg() 503 | 504 | return model -------------------------------------------------------------------------------- /model/pscan.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | """ 7 | 8 | An implementation of the parallel scan operation in PyTorch (Blelloch version). 9 | Please see docs/pscan.ipynb for a detailed explanation of what happens here. 10 | 11 | """ 12 | 13 | 14 | def npo2(len): 15 | """ 16 | Returns the next power of 2 above len 17 | """ 18 | 19 | return 2 ** math.ceil(math.log2(len)) 20 | 21 | 22 | def pad_npo2(X): 23 | """ 24 | Pads input length dim to the next power of 2 25 | 26 | Args: 27 | X : (B, L, D, N) 28 | 29 | Returns: 30 | Y : (B, npo2(L), D, N) 31 | """ 32 | 33 | len_npo2 = npo2(X.size(1)) 34 | pad_tuple = (0, 0, 0, 0, 0, len_npo2 - X.size(1)) 35 | return F.pad(X, pad_tuple, "constant", 0) 36 | 37 | 38 | class PScan(torch.autograd.Function): 39 | @staticmethod 40 | def pscan(A, X): 41 | #  A : (B, D, L, N) 42 | #  X : (B, D, L, N) 43 | 44 | #  modifies X in place by doing a parallel scan.
45 | #  more formally, X will be populated by these values : 46 | #  H[t] = A[t] * H[t-1] + X[t] with H[0] = 0 47 | #  which are computed in parallel (ideally in 2*log2(T) sequential steps, instead of T sequential steps) 48 | 49 | #  only supports L that is a power of two (mainly to keep the code clearer) 50 | 51 | B, D, L, _ = A.size() 52 | num_steps = int(math.log2(L)) 53 | 54 | #  up sweep (last 2 steps unfolded) 55 | Aa = A 56 | Xa = X 57 | for _ in range(num_steps - 2): 58 | T = Xa.size(2) 59 | Aa = Aa.view(B, D, T // 2, 2, -1) 60 | Xa = Xa.view(B, D, T // 2, 2, -1) 61 | 62 | Xa[:, :, :, 1].add_(Aa[:, :, :, 1].mul(Xa[:, :, :, 0])) 63 | Aa[:, :, :, 1].mul_(Aa[:, :, :, 0]) 64 | 65 | Aa = Aa[:, :, :, 1] 66 | Xa = Xa[:, :, :, 1] 67 | 68 | #  we have only 4, 2 or 1 nodes left 69 | if Xa.size(2) == 4: 70 | Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0])) 71 | Aa[:, :, 1].mul_(Aa[:, :, 0]) 72 | 73 | Xa[:, :, 3].add_(Aa[:, :, 3].mul(Xa[:, :, 2] + Aa[:, :, 2].mul(Xa[:, :, 1]))) 74 | elif Xa.size(2) == 2: 75 | Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0])) 76 | return 77 | else: 78 | return 79 | 80 | #  down sweep (first 2 steps unfolded) 81 | Aa = A[:, :, 2 ** (num_steps - 2) - 1:L:2 ** (num_steps - 2)] 82 | Xa = X[:, :, 2 ** (num_steps - 2) - 1:L:2 ** (num_steps - 2)] 83 | Xa[:, :, 2].add_(Aa[:, :, 2].mul(Xa[:, :, 1])) 84 | Aa[:, :, 2].mul_(Aa[:, :, 1]) 85 | 86 | for k in range(num_steps - 3, -1, -1): 87 | Aa = A[:, :, 2 ** k - 1:L:2 ** k] 88 | Xa = X[:, :, 2 ** k - 1:L:2 ** k] 89 | 90 | T = Xa.size(2) 91 | Aa = Aa.view(B, D, T // 2, 2, -1) 92 | Xa = Xa.view(B, D, T // 2, 2, -1) 93 | 94 | Xa[:, :, 1:, 0].add_(Aa[:, :, 1:, 0].mul(Xa[:, :, :-1, 1])) 95 | Aa[:, :, 1:, 0].mul_(Aa[:, :, :-1, 1]) 96 | 97 | @staticmethod 98 | def pscan_rev(A, X): 99 | #  A : (B, D, L, N) 100 | #  X : (B, D, L, N) 101 | 102 | #  the same function as above, but in reverse 103 | # (if you flip the input, call pscan, then flip the output, you get what this function outputs) 104 | #  it is used in the backward pass 105 | 106 | #  only supports L that is a power of two (mainly to keep the code clearer) 107 | 108 | B, D, L, _ = A.size() 109 | num_steps = int(math.log2(L)) 110 | 111 | #  up sweep (last 2 steps unfolded) 112 | Aa = A 113 | Xa = X 114 | for _ in range(num_steps - 2): 115 | T = Xa.size(2) 116 | Aa = Aa.view(B, D, T // 2, 2, -1) 117 | Xa = Xa.view(B, D, T // 2, 2, -1) 118 | 119 | Xa[:, :, :, 0].add_(Aa[:, :, :, 0].mul(Xa[:, :, :, 1])) 120 | Aa[:, :, :, 0].mul_(Aa[:, :, :, 1]) 121 | 122 | Aa = Aa[:, :, :, 0] 123 | Xa = Xa[:, :, :, 0] 124 | 125 | #  we have only 4, 2 or 1 nodes left 126 | if Xa.size(2) == 4: 127 | Xa[:, :, 2].add_(Aa[:, :, 2].mul(Xa[:, :, 3])) 128 | Aa[:, :, 2].mul_(Aa[:, :, 3]) 129 | 130 | Xa[:, :, 0].add_(Aa[:, :, 0].mul(Xa[:, :, 1].add(Aa[:, :, 1].mul(Xa[:, :, 2])))) 131 | elif Xa.size(2) == 2: 132 | Xa[:, :, 0].add_(Aa[:, :, 0].mul(Xa[:, :, 1])) 133 | return 134 | else: 135 | return 136 | 137 | #  down sweep (first 2 steps unfolded) 138 | Aa = A[:, :, 0:L:2 ** (num_steps - 2)] 139 | Xa = X[:, :, 0:L:2 ** (num_steps - 2)] 140 | Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 2])) 141 | Aa[:, :, 1].mul_(Aa[:, :, 2]) 142 | 143 | for k in range(num_steps - 3, -1, -1): 144 | Aa = A[:, :, 0:L:2 ** k] 145 | Xa = X[:, :, 0:L:2 ** k] 146 | 147 | T = Xa.size(2) 148 | Aa = Aa.view(B, D, T // 2, 2, -1) 149 | Xa = Xa.view(B, D, T // 2, 2, -1) 150 | 151 | Xa[:, :, :-1, 1].add_(Aa[:, :, :-1, 1].mul(Xa[:, :, 1:, 0])) 152 | Aa[:, :, :-1, 1].mul_(Aa[:, :, 1:, 0]) 153 | 154 | @staticmethod 155 | def forward(ctx, A_in, X_in): 156 | 
""" 157 | Applies the parallel scan operation, as defined above. Returns a new tensor. 158 | If you can, privilege sequence lengths that are powers of two. 159 | 160 | Args: 161 | A_in : (B, L, D, N) 162 | X_in : (B, L, D, N) 163 | 164 | Returns: 165 | H : (B, L, D, N) 166 | """ 167 | 168 | L = X_in.size(1) 169 | 170 | #  cloning is requiered because of the in-place ops 171 | if L == npo2(L): 172 | A = A_in.clone() 173 | X = X_in.clone() 174 | else: 175 | #  pad tensors (and clone btw) 176 | A = pad_npo2(A_in) #  (B, npo2(L), D, N) 177 | X = pad_npo2(X_in) #  (B, npo2(L), D, N) 178 | 179 | # prepare tensors 180 | A = A.transpose(2, 1) #  (B, D, npo2(L), N) 181 | X = X.transpose(2, 1) #  (B, D, npo2(L), N) 182 | 183 | #  parallel scan (modifies X in-place) 184 | PScan.pscan(A, X) 185 | 186 | ctx.save_for_backward(A_in, X) 187 | 188 | #  slice [:, :L] (cut if there was padding) 189 | return X.transpose(2, 1)[:, :L] 190 | 191 | @staticmethod 192 | def backward(ctx, grad_output_in): 193 | """ 194 | Flows the gradient from the output to the input. Returns two new tensors. 195 | 196 | Args: 197 | ctx : A_in : (B, L, D, N), X : (B, D, L, N) 198 | grad_output_in : (B, L, D, N) 199 | 200 | Returns: 201 | gradA : (B, L, D, N), gradX : (B, L, D, N) 202 | """ 203 | 204 | A_in, X = ctx.saved_tensors 205 | 206 | L = grad_output_in.size(1) 207 | 208 | # cloning is requiered because of the in-place ops 209 | if L == npo2(L): 210 | grad_output = grad_output_in.clone() 211 | #  the next padding will clone A_in 212 | else: 213 | grad_output = pad_npo2(grad_output_in) #  (B, npo2(L), D, N) 214 | A_in = pad_npo2(A_in) #  (B, npo2(L), D, N) 215 | 216 | # prepare tensors 217 | grad_output = grad_output.transpose(2, 1) 218 | A_in = A_in.transpose(2, 1) #  (B, D, npo2(L), N) 219 | A = torch.nn.functional.pad(A_in[:, :, 1:], 220 | (0, 0, 0, 1)) #  (B, D, npo2(L), N) shift 1 to the left (see hand derivation) 221 | 222 | #  reverse parallel scan (modifies grad_output in-place) 223 | PScan.pscan_rev(A, grad_output) 224 | 225 | Q = torch.zeros_like(X) 226 | Q[:, :, 1:].add_(X[:, :, :-1] * grad_output[:, :, 1:]) 227 | 228 | return Q.transpose(2, 1)[:, :L], grad_output.transpose(2, 1)[:, :L] 229 | 230 | 231 | pscan = PScan.apply 232 | -------------------------------------------------------------------------------- /scripts/test_phase.sh: -------------------------------------------------------------------------------- 1 | export CC=gcc-11 2 | export CXX=g++-11 3 | 4 | CUDA_VISIBLE_DEVICES=6,7 python -m torch.distributed.launch \ 5 | --nproc_per_node=2 \ 6 | --master_port 12324 \ 7 | downstream_phase/run_phase_training.py \ 8 | --batch_size 8 \ 9 | --epochs 50 \ 10 | --save_ckpt_freq 10 \ 11 | --model surgformer_HTA \ 12 | --pretrained_path pretrain_params/timesformer_base_patch16_224_K400.pyth \ 13 | --mixup 0.8 \ 14 | --cutmix 1.0 \ 15 | --smoothing 0.1 \ 16 | --lr 5e-4 \ 17 | --layer_decay 0.75 \ 18 | --warmup_epochs 5 \ 19 | --data_path /jhcnas4/syangcw/M2CAI16-workflow \ 20 | --eval_data_path /jhcnas4/syangcw/M2CAI16-workflow \ 21 | --nb_classes 8 \ 22 | --data_strategy online \ 23 | --output_mode key_frame \ 24 | --num_frames 16 \ 25 | --sampling_rate 4 \ 26 | --eval \ 27 | --finetune /jhcnas4/syangcw/Surgformerv2/M2CAI16/surgformer_HTA_M2CAI16_0.0005_0.75_online_key_frame_frame16_Fixed_Stride_4/checkpoint-best/mp_rank_00_model_states.pt \ 28 | --data_set M2CAI16 \ 29 | --data_fps 1fps \ 30 | --output_dir /jhcnas4/syangcw/Surgformerv2/M2CAI16 \ 31 | --log_dir /jhcnas4/syangcw/Surgformerv2/M2CAI16 \ 32 | --num_workers 
10 \ 33 | --dist_eval \ 34 | --enable_deepspeed \ 35 | --no_auto_resume -------------------------------------------------------------------------------- /scripts/train_phase.sh: -------------------------------------------------------------------------------- 1 | export CC=gcc-11 2 | export CXX=g++-11 3 | 4 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch \ 5 | --nproc_per_node=4 \ 6 | --master_port 12326 \ 7 | downstream_phase/run_phase_training.py \ 8 | --batch_size 8 \ 9 | --epochs 50 \ 10 | --save_ckpt_freq 10 \ 11 | --model timesformer \ 12 | --pretrained_path pretrain_params/timesformer_base_patch16_224_K400.pyth \ 13 | --mixup 0.8 \ 14 | --cutmix 1.0 \ 15 | --smoothing 0.1 \ 16 | --lr 5e-4 \ 17 | --layer_decay 0.75 \ 18 | --warmup_epochs 5 \ 19 | --data_path /home/syangcw/LungSeg/LungSeg \ 20 | --eval_data_path /home/syangcw/LungSeg/LungSeg \ 21 | --nb_classes 7 \ 22 | --data_strategy online \ 23 | --output_mode key_frame \ 24 | --num_frames 16 \ 25 | --sampling_rate 4 \ 26 | --data_set LungSeg \ 27 | --data_fps 1fps \ 28 | --output_dir /results/LungSeg \ 29 | --log_dir /results/LungSeg \ 30 | --num_workers 10 \ 31 | --dist_eval \ 32 | --enable_deepspeed \ 33 | --no_auto_resume -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | python -m torch.distributed.launch \ 2 | --nproc_per_node=2 \ 3 | run_phase_training.py \ 4 | --batch_size 8 \ 5 | --epochs 50 \ 6 | --save_ckpt_freq 10 \ 7 | --model pmnet \ 8 | --lr 3e-5 \ 9 | --layer_decay 0.75 \ 10 | --warmup_epochs 5 \ 11 | --data_path data_path \ 12 | --eval_data_path data_path \ 13 | --nb_classes 5 \ 14 | --data_strategy online \ 15 | --output_mode key_frame \ 16 | --num_frames 20 \ 17 | --sampling_rate 8 \ 18 | --data_set PmLR50 \ 19 | --data_fps 1fps \ 20 | --output_dir output_path \ 21 | --log_dir output_path \ 22 | --num_workers 10 \ 23 | --enable_deepspeed \ 24 | --no_auto_resume \ 25 | --eval \ 26 | --load_ckpt checkpoint_path 27 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | python -m torch.distributed.launch \ 2 | --nproc_per_node=2 \ 3 | run_phase_training.py \ 4 | --batch_size 8 \ 5 | --epochs 50 \ 6 | --save_ckpt_freq 10 \ 7 | --model pmnet \ 8 | --lr 3e-5 \ 9 | --layer_decay 0.75 \ 10 | --warmup_epochs 5 \ 11 | --data_path data_path \ 12 | --eval_data_path data_path \ 13 | --nb_classes 5 \ 14 | --data_strategy online \ 15 | --output_mode key_frame \ 16 | --num_frames 20 \ 17 | --sampling_rate 8 \ 18 | --data_set PmLR50 \ 19 | --data_fps 1fps \ 20 | --output_dir output_path \ 21 | --log_dir output_path \ 22 | --num_workers 10 \ 23 | --enable_deepspeed \ 24 | --no_auto_resume 25 | --------------------------------------------------------------------------------
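A note on the evaluation inputs: ReadPhaseLabel.m skips one header line and then parses tab-separated (frame index, phase label) rows; the Main_*.m scripts subsequently map the label strings '0'..'K-1' to 1-based phase IDs before calling the Evaluate*.m functions. A Python sketch of the same reader, for illustration only — the header text is an assumption, and the example path simply follows the pattern constructed in Main_Cataract101.m:

def read_phase_label(path):
    # mirrors evaluation_matlab/ReadPhaseLabel.m: skip the header line,
    # then collect (frame_index, phase_label) pairs
    frames, labels = [], []
    with open(path) as f:
        next(f)  # header line, e.g. "Frame\tPhase" (assumed)
        for line in f:
            idx, label = line.rstrip("\n").split("\t")
            frames.append(int(idx))
            labels.append(label)
    return frames, labels

frames, labels = read_phase_label("phase_annotations/video-110.txt")
label_ids = [int(lab) + 1 for lab in labels]  # "0".."9" -> 1..10, as in Main_Cataract101.m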
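A note on model/pscan.py: the forward pass computes the linear recurrence H[t] = A[t] * H[t-1] + X[t] (with H[0] = 0) over the length dimension in roughly 2*log2(L) parallel sweeps instead of L sequential steps. A minimal sanity-check sketch against a naive sequential loop — the shapes and tolerance below are arbitrary illustration choices, not values taken from the repo:

import torch
from model.pscan import pscan  # pscan = PScan.apply, inputs of shape (B, L, D, N)

def sequential_scan(A, X):
    # reference O(L) loop: H[t] = A[t] * H[t-1] + X[t], starting from H = 0
    H = torch.zeros_like(X)
    h = torch.zeros_like(X[:, 0])
    for t in range(X.size(1)):
        h = A[:, t] * h + X[:, t]
        H[:, t] = h
    return H

B, L, D, N = 2, 20, 8, 4  # L = 20 is not a power of two; pscan pads to 32 internally
A = torch.rand(B, L, D, N)
X = torch.rand(B, L, D, N)
print(torch.allclose(pscan(A, X), sequential_scan(A, X), atol=1e-5))  # expected: True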