├── PmNet.yml
├── README.md
├── assets
│   ├── Pipeline.png
│   └── photo
├── datasets
│   ├── __pycache__
│   │   ├── args.cpython-311.pyc
│   │   ├── functional.cpython-310.pyc
│   │   └── functional.cpython-311.pyc
│   ├── args.py
│   ├── convert_results
│   │   ├── convert_autolaparo.py
│   │   └── convert_cholec80.py
│   ├── data_preprosses
│   │   ├── generate_labels_autolaparo.py
│   │   ├── generate_labels_ch80.py
│   │   ├── generate_labels_lungseg.py
│   │   └── generate_labels_pmlr.py
│   ├── extract_frames
│   │   ├── extract_frames_autolaparo.py
│   │   └── extract_frames_ch80.py
│   ├── functional.py
│   ├── phase
│   │   ├── AutoLaparo_phase.py
│   │   ├── Cholec80_phase.py
│   │   ├── PmLR50_phase.py
│   │   └── __pycache__
│   │       ├── AutoLaparo_phase.cpython-310.pyc
│   │       ├── AutoLaparo_phase.cpython-311.pyc
│   │       ├── Cholec80_phase.cpython-310.pyc
│   │       ├── Cholec80_phase.cpython-311.pyc
│   │       └── PmLR50_phase.cpython-311.pyc
│   ├── tools
│   │   ├── frame_cutmargin.py
│   │   ├── resize_frame.py
│   │   ├── transfer_csv_txt.py
│   │   └── transfer_json_txt.py
│   └── transforms
│       ├── __pycache__
│       │   ├── mixup.cpython-310.pyc
│       │   ├── mixup.cpython-311.pyc
│       │   ├── optim_factory.cpython-310.pyc
│       │   ├── optim_factory.cpython-311.pyc
│       │   ├── rand_augment.cpython-310.pyc
│       │   ├── rand_augment.cpython-311.pyc
│       │   ├── random_erasing.cpython-310.pyc
│       │   ├── random_erasing.cpython-311.pyc
│       │   ├── surg_transforms.cpython-310.pyc
│       │   ├── surg_transforms.cpython-311.pyc
│       │   ├── video_transforms.cpython-310.pyc
│       │   ├── video_transforms.cpython-311.pyc
│       │   ├── volume_transforms.cpython-310.pyc
│       │   └── volume_transforms.cpython-311.pyc
│       ├── image_transforms.py
│       ├── mixup.py
│       ├── optim_factory.py
│       ├── rand_augment.py
│       ├── random_erasing.py
│       ├── surg_transforms.py
│       ├── transforms.py
│       ├── video_transforms.py
│       └── volume_transforms.py
├── downstream_phase
│   ├── __pycache__
│   │   ├── datasets_phase.cpython-310.pyc
│   │   ├── datasets_phase.cpython-311.pyc
│   │   ├── engine_for_phase.cpython-310.pyc
│   │   └── engine_for_phase.cpython-311.pyc
│   ├── datasets_phase.py
│   └── engine_for_phase.py
├── evaluation_matlab
│   ├── Evaluate.m
│   ├── Evaluate_Cataract101.m
│   ├── Evaluate_m2cai.m
│   ├── Main.m
│   ├── Main_AutoLaparo.m
│   ├── Main_Cataract101.m
│   ├── Main_m2cai.m
│   ├── README.md
│   ├── ReadPhaseLabel.m
│   ├── matlab.mat
│   ├── octave-workspace
│   └── test.m
├── model
│   ├── __pycache__
│   │   ├── mambapy.cpython-311.pyc
│   │   ├── pmnet.cpython-311.pyc
│   │   └── pscan.cpython-311.pyc
│   ├── mambapy.py
│   ├── pmnet.py
│   └── pscan.py
├── run_phase_training.py
├── scripts
│   ├── test_phase.sh
│   └── train_phase.sh
├── test.sh
├── train.sh
└── utils.py
/PmNet.yml:
--------------------------------------------------------------------------------
1 | name: Pmnet
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - defaults
6 | dependencies:
7 | - _libgcc_mutex=0.1=main
8 | - _openmp_mutex=5.1=1_gnu
9 | - beautifulsoup4=4.12.2=py38h06a4308_0
10 | - blas=1.0=mkl
11 | - boltons=23.0.0=py38h06a4308_0
12 | - brotlipy=0.7.0=py38h27cfd23_1003
13 | - bzip2=1.0.8=h7f98852_4
14 | - ca-certificates=2023.08.22=h06a4308_0
15 | - certifi=2023.7.22=py38h06a4308_0
16 | - cffi=1.15.1=py38h5eee18b_3
17 | - chardet=4.0.0=py38h06a4308_1003
18 | - charset-normalizer=2.0.4=pyhd3eb1b0_0
19 | - click=8.0.4=py38h06a4308_0
20 | - conda=23.7.4=py38h06a4308_0
21 | - conda-build=3.25.0=py38h06a4308_0
22 | - conda-index=0.2.3=py38h06a4308_0
23 | - conda-package-handling=2.2.0=py38h06a4308_0
24 | - conda-package-streaming=0.9.0=py38h06a4308_0
25 | - cryptography=41.0.3=py38h130f0dd_0
26 | - cudatoolkit=11.1.1=ha002fc5_10
27 | - ffmpeg=4.3=hf484d3e_0
28 | - filelock=3.9.0=py38h06a4308_0
29 | - freetype=2.10.4=h0708190_1
30 | - glob2=0.7=pyhd3eb1b0_0
31 | - gmp=6.2.1=h58526e2_0
32 | - gnutls=3.6.13=h85f3911_1
33 | - icu=58.2=he6710b0_3
34 | - idna=3.4=py38h06a4308_0
35 | - intel-openmp=2021.4.0=h06a4308_3561
36 | - jinja2=3.1.2=py38h06a4308_0
37 | - jpeg=9b=h024ee3a_2
38 | - jsonpatch=1.32=pyhd3eb1b0_0
39 | - jsonpointer=2.1=pyhd3eb1b0_0
40 | - lame=3.100=h7f98852_1001
41 | - ld_impl_linux-64=2.38=h1181459_1
42 | - libarchive=3.4.2=h62408e4_0
43 | - libffi=3.4.4=h6a678d5_0
44 | - libgcc-ng=11.2.0=h1234567_1
45 | - libgomp=11.2.0=h1234567_1
46 | - libiconv=1.17=h166bdaf_0
47 | - liblief=0.12.3=h6a678d5_0
48 | - libpng=1.6.37=h21135ba_2
49 | - libstdcxx-ng=11.2.0=h1234567_1
50 | - libtiff=4.1.0=h2733197_1
51 | - libuv=1.43.0=h7f98852_0
52 | - libxml2=2.9.14=h74e7548_0
53 | - lz4-c=1.9.3=h9c3ff4c_1
54 | - markupsafe=2.1.1=py38h7f8727e_0
55 | - mkl=2021.4.0=h06a4308_640
56 | - mkl-service=2.4.0=py38h95df7f1_0
57 | - mkl_fft=1.3.1=py38h8666266_1
58 | - mkl_random=1.2.2=py38h1abd341_0
59 | - more-itertools=8.12.0=pyhd3eb1b0_0
60 | - ncurses=6.4=h6a678d5_0
61 | - nettle=3.6=he412f7d_0
62 | - numpy=1.24.3=py38h14f4228_0
63 | - numpy-base=1.24.3=py38h31eccc5_0
64 | - olefile=0.46=pyh9f0ad1d_1
65 | - openh264=2.1.1=h780b84a_0
66 | - openssl=1.1.1w=h7f8727e_0
67 | - packaging=23.1=py38h06a4308_0
68 | - patch=2.7.6=h7b6447c_1001
69 | - patchelf=0.17.2=h6a678d5_0
70 | - pip=23.2.1=py38h06a4308_0
71 | - pkginfo=1.9.6=py38h06a4308_0
72 | - pluggy=1.0.0=py38h06a4308_1
73 | - psutil=5.9.0=py38h5eee18b_0
74 | - py-lief=0.12.3=py38h6a678d5_0
75 | - pycosat=0.6.4=py38h5eee18b_0
76 | - pycparser=2.21=pyhd3eb1b0_0
77 | - pyopenssl=23.2.0=py38h06a4308_0
78 | - pysocks=1.7.1=py38h06a4308_0
79 | - python=3.8.18=h7a1cb2a_0
80 | - python-libarchive-c=2.9=pyhd3eb1b0_1
81 | - python_abi=3.8=2_cp38
82 | - pytz=2022.7=py38h06a4308_0
83 | - pyyaml=6.0=py38h5eee18b_1
84 | - readline=8.2=h5eee18b_0
85 | - requests=2.31.0=py38h06a4308_0
86 | - ruamel.yaml=0.17.21=py38h5eee18b_0
87 | - ruamel.yaml.clib=0.2.6=py38h5eee18b_1
88 | - setuptools=68.0.0=py38h06a4308_0
89 | - six=1.16.0=pyh6c4a22f_0
90 | - soupsieve=2.4=py38h06a4308_0
91 | - sqlite=3.41.2=h5eee18b_0
92 | - tk=8.6.12=h1ccaba5_0
93 | - tomli=2.0.1=py38h06a4308_0
94 | - toolz=0.12.0=py38h06a4308_0
95 | - tqdm=4.65.0=py38hb070fc8_0
96 | - urllib3=1.26.16=py38h06a4308_0
97 | - wheel=0.38.4=py38h06a4308_0
98 | - xz=5.4.2=h5eee18b_0
99 | - yaml=0.2.5=h7b6447c_0
100 | - zlib=1.2.13=h5eee18b_0
101 | - zstandard=0.19.0=py38h5eee18b_0
102 | - zstd=1.4.9=ha95c52a_0
103 | - pip:
104 | - absl-py==2.0.0
105 | - appdirs==1.4.4
106 | - astunparse==1.6.3
107 | - cachetools==5.3.1
108 | - contourpy==1.1.1
109 | - cycler==0.11.0
110 | - decord==0.6.0
111 | - defusedxml==0.7.1
112 | - docker-pycreds==0.4.0
113 | - einops==0.6.1
114 | - flatbuffers==23.5.26
115 | - fonttools==4.42.1
116 | - fsspec==2023.9.1
117 | - ftfy==6.1.1
118 | - gast==0.4.0
119 | - gitdb==4.0.11
120 | - gitpython==3.1.40
121 | - glances==3.4.0.3
122 | - google-auth==2.23.2
123 | - google-auth-oauthlib==1.0.0
124 | - google-pasta==0.2.0
125 | - grpcio==1.59.0
126 | - h5py==3.10.0
127 | - hjson==3.1.0
128 | - huggingface-hub==0.17.1
129 | - imageio==2.31.3
130 | - imgaug==0.4.0
131 | - imgcat==0.5.0
132 | - importlib-metadata==6.8.0
133 | - importlib-resources==6.0.1
134 | - joblib==1.3.2
135 | - keras==2.13.1
136 | - kiwisolver==1.4.5
137 | - lazy-loader==0.3
138 | - libclang==16.0.6
139 | - markdown==3.4.4
140 | - matplotlib==3.7.3
141 | - networkx==3.1
142 | - ninja==1.11.1
143 | - nltk==3.7
144 | - nvidia-ml-py==12.535.108
145 | - nvitop==1.3.0
146 | - oauthlib==3.2.2
147 | - openai-clip==1.0.1
148 | - opencv-python==4.8.0.76
149 | - opt-einsum==3.3.0
150 | - pandas==2.0.3
151 | - pillow==10.0.1
152 | - protobuf==4.24.3
153 | - pyasn1==0.5.0
154 | - pyasn1-modules==0.3.0
155 | - pyparsing==3.1.1
156 | - python-dateutil==2.8.2
157 | - pywavelets==1.4.1
158 | - regex==2023.10.3
159 | - requests-oauthlib==1.3.1
160 | - rsa==4.9
161 | - sacremoses==0.1.1
162 | - safetensors==0.3.3
163 | - scikit-image==0.21.0
164 | - scikit-learn==1.3.0
165 | - scipy==1.10.1
166 | - sentry-sdk==1.34.0
167 | - seqeval==1.2.2
168 | - setproctitle==1.3.3
169 | - shapely==2.0.1
170 | - smmap==5.0.1
171 | - tensorboard==2.13.0
172 | - tensorboard-data-server==0.7.1
173 | - tensorboardx==2.6.2.2
174 | - threadpoolctl==3.2.0
175 | - tifffile==2023.7.10
176 | - timm==0.4.12
177 | - tokenizers==0.14.1
178 | - transformers==4.16.2
179 | - triton==2.1.0
180 | - typing-extensions==4.6.1
181 | - tzdata==2023.3
182 | - ujson==5.8.0
183 | - wandb==0.16.0
184 | - wcwidth==0.2.8
185 | - werkzeug==3.0.0
186 | - wrapt==1.16.0
187 | - zipp==3.16.2
188 | - deepspeed==0.16.4
189 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [AAAI 2025] Surgical Workflow Recognition and Blocking Effectiveness Detection in Laparoscopic Liver Resections with Pringle Maneuver
2 |
3 |
4 |
5 |
6 |
7 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
8 |
9 | ## Abstract
10 |
11 | > Pringle maneuver (PM) in laparoscopic liver resection aims to reduce blood loss and provide a clear surgical view by intermittently blocking blood inflow of the liver, whereas prolonged PM may cause ischemic injury. To comprehensively monitor this surgical procedure and provide timely warnings of ineffective and prolonged blocking, we suggest two complementary AI-assisted surgical monitoring tasks: workflow recognition and blocking effectiveness detection in liver resections. The former presents challenges in real-time capturing of short-term PM, while the latter involves the intraoperative discrimination of long-term liver ischemia states. To address these challenges, we meticulously collect a novel dataset, called PmLR50, consisting of 25,037 video frames covering various surgical phases from 50 laparoscopic liver resection procedures. Additionally, we develop an online baseline for PmLR50, termed PmNet. This model embraces Masked Temporal Encoding (MTE) and Compressed Sequence Modeling (CSM) for efficient short-term and long-term temporal information modeling, and embeds Contrastive Prototype Separation (CPS) to enhance action discrimination between similar intraoperative operations. Experimental results demonstrate that PmNet outperforms existing state-of-the-art surgical workflow recognition methods on the PmLR50 benchmark. Our research offers potential clinical applications for the laparoscopic liver surgery community.
12 |
13 |
14 | ## 🔥🔥🔥 News!!
15 | * Dec 12, 2024: 🤗 Our work has been accepted by AAAI 2025! Congratulations!
16 | * Feb 21, 2025: 🚀 Code and dataset have been released! [Dataset Link](https://docs.google.com/forms/d/e/1FAIpQLSf33G5mdwXeqwabfbXnEboMpj48iCNlQBAY_up4kLuZiqCPUQ/viewform?usp=dialog)
17 |
18 |
19 | ## PmLR50 Dataset and PmNet
20 | ### Installation
21 | * Environment: CUDA 12.5 / Python 3.8
22 | * Device: Two NVIDIA GeForce RTX 4090 GPUs
23 | * Create a virtual environment
24 | ```shell
25 | git clone https://github.com/RascalGdd/PmNet.git
26 | cd PmNet
27 | conda env create -f PmNet.yml
28 | conda activate Pmnet
29 | ```
30 | ### Prepare your data
31 | Download the processed data from [PmLR50](https://docs.google.com/forms/d/e/1FAIpQLSf33G5mdwXeqwabfbXnEboMpj48iCNlQBAY_up4kLuZiqCPUQ/viewform?usp=dialog).
32 | The final dataset structure should be as follows:
33 |
34 | ```bash
35 | data/
36 | └── PmLR50/
37 |     ├── frames/
38 |     │   ├── 01/
39 |     │   │   ├── 00000000.jpg
40 |     │   │   ├── 00000001.jpg
41 |     │   │   └── ...
42 |     │   ├── ...
43 |     │   └── 50/
44 |     ├── phase_annotations/
45 |     │   ├── 01.txt
46 |     │   ├── 02.txt
47 |     │   ├── ...
48 |     │   └── 50.txt
49 |     ├── blocking_annotations/
50 |     │   ├── 01.txt
51 |     │   ├── 02.txt
52 |     │   ├── ...
53 |     │   └── 50.txt
54 |     └── bbox_annotations/
55 |         ├── 01.json
56 |         ├── 02.json
57 |         ├── ...
58 |         └── 50.json
59 | ```
60 | Then, process the data with [generate_labels_pmlr.py](https://github.com/RascalGdd/PmNet/blob/main/datasets/data_preprosses/generate_labels_pmlr.py) to generate the labels used for training and testing, as sketched below.
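
A minimal sketch of this step (assuming the layout above; `ROOT_DIR` is hard-coded near the top of the script and must first be edited to point at your own `data/PmLR50` directory):

```shell
python datasets/data_preprosses/generate_labels_pmlr.py
```

This writes `1fpstrain.pickle`, `1fpstest.pickle`, and `1fpsinfer.pickle` under `data/PmLR50/labels/`.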
61 |
62 | ### Training
63 | We provide scripts for training ([train.sh](https://github.com/RascalGdd/PmNet/blob/main/train.sh)) and testing ([test.sh](https://github.com/RascalGdd/PmNet/blob/main/test.sh)).
64 |
65 | Run the following command for training:
66 |
67 | ```shell
68 | sh train.sh
69 | ```
71 | and run the following command for testing:
71 |
72 | ```shell
73 | sh test.sh
74 | ```
75 | The checkpoint of our model is provided [here](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155229775_link_cuhk_edu_hk/EZVHcTmQBY1Mv1zTSLEtu0cBKTA7zTNURaG65gWWloqFmg?e=Zudo2X).
76 |
77 | ### More Configurations
78 |
79 | We list some additional useful configuration options (an example invocation follows the table):
80 |
81 | | Argument | Default | Description |
82 | |:----------------------:|:---------:|:-----------------------------------------:|
83 | | `--nproc_per_node` | 2 | Number of processes (GPUs) per node for training and testing |
84 | | `--batch_size` | 8 | The batch size for training and inference |
85 | | `--epochs` | 50 | The max epoch for training |
86 | | `--save_ckpt_freq` | 10 | The frequency for saving checkpoints during training |
87 | | `--nb_classes` | 5 | The number of classes for surgical workflows |
88 | | `--data_strategy` | online | Online/offline mode |
89 | | `--num_frames` | 20 | The number of consecutive frames used |
90 | | `--sampling_rate` | 8 | The sampling interval for consecutive frames |
91 | | `--enable_deepspeed` | True | Use deepspeed to accelerate |
92 | | `--dist_eval` | False | Use distributed evaluation to accelerate |
93 | | `--load_ckpt` | -- | Load a given checkpoint for testing |
94 |
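For reference, a hypothetical invocation combining these flags (the authoritative commands live in `train.sh`/`test.sh`; the launcher and the `run_phase_training.py` entry point shown here are assumptions):

```shell
python -m torch.distributed.launch --nproc_per_node 2 run_phase_training.py \
    --batch_size 8 --epochs 50 --save_ckpt_freq 10 --nb_classes 5 \
    --data_strategy online --num_frames 20 --sampling_rate 8 \
    --enable_deepspeed --dist_eval
```
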
95 | ## Acknowledgements
96 | Huge thanks to the authors of the following open-source projects:
97 | - [TMRNet](https://github.com/YuemingJin/TMRNet)
98 | - [Surgformer](https://github.com/isyangshu/Surgformer/)
99 | - [TimeSformer](https://github.com/facebookresearch/TimeSformer)
100 |
101 | ## Citation
102 | If you find our work useful in your research, please consider citing our paper:
103 |
104 |     @inproceedings{guo2025surgical,
105 |       title={Surgical Workflow Recognition and Blocking Effectiveness Detection in Laparoscopic Liver Resection with Pringle Maneuver},
106 |       author={Guo, Diandian and Si, Weixin and Li, Zhixi and Pei, Jialun and Heng, Pheng-Ann},
107 |       booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
108 |       volume={39},
109 |       number={3},
110 |       pages={3220--3228},
111 |       year={2025}
112 |     }
113 |
--------------------------------------------------------------------------------
/assets/Pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/assets/Pipeline.png
--------------------------------------------------------------------------------
/assets/photo:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/datasets/__pycache__/args.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/datasets/__pycache__/args.cpython-311.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/functional.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/datasets/__pycache__/functional.cpython-310.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/functional.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/datasets/__pycache__/functional.cpython-311.pyc
--------------------------------------------------------------------------------
/datasets/args.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def get_args():
4 | parser = argparse.ArgumentParser('SurgVideoMAE pre-training script', add_help=False)
5 | parser.add_argument('--batch_size', default=64, type=int)
6 | parser.add_argument('--epochs', default=801, type=int)
7 | parser.add_argument('--save_ckpt_freq', default=20, type=int)
8 |
9 | # Model parameters
10 | parser.add_argument('--model', default='pretrain_videomae_base_patch16_224', type=str, metavar='MODEL',
11 | help='Name of model to train')
12 |
13 | parser.add_argument('--decoder_depth', default=4, type=int,
14 | help='depth of decoder')
15 |
16 | parser.add_argument('--mask_type', default='tube', choices=['random', 'tube'],
17 | type=str, help='masked strategy of video tokens/patches')
18 |
19 | parser.add_argument('--mask_ratio', default=0.9, type=float,
20 | help='ratio of the visual tokens/patches need be masked')
21 |
22 | parser.add_argument('--input_size', default=224, type=int,
23 | help='videos input size for backbone')
24 |
25 | parser.add_argument('--drop_path', type=float, default=0.0, metavar='PCT',
26 |                         help='Drop path rate (default: 0.0)')
27 |
28 | parser.add_argument('--normlize_target', default=True, type=bool,
29 |                         help='normalize the target patch pixels')
30 |
31 | # Optimizer parameters
32 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
33 |                         help='Optimizer (default: "adamw")')
34 | parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON',
35 | help='Optimizer Epsilon (default: 1e-8)')
36 | parser.add_argument('--opt_betas', default=(0.9, 0.95), type=float, nargs='+', metavar='BETA',
37 | help='Optimizer Betas (default: None, use opt default)')
38 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
39 | help='Clip gradient norm (default: None, no clipping)')
40 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
41 | help='SGD momentum (default: 0.9)')
42 | parser.add_argument('--weight_decay', type=float, default=0.05,
43 | help='weight decay (default: 0.05)')
44 | parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the
45 | weight decay. We use a cosine schedule for WD.
46 | (Set the same value with args.weight_decay to keep weight decay no change)""")
47 |
48 | parser.add_argument('--lr', type=float, default=1.5e-4, metavar='LR',
49 | help='learning rate (default: 1.5e-4)')
50 | parser.add_argument('--warmup_lr', type=float, default=1e-6, metavar='LR',
51 | help='warmup learning rate (default: 1e-6)')
52 | parser.add_argument('--min_lr', type=float, default=1e-5, metavar='LR',
53 | help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
54 |
55 | parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N',
56 | help='epochs to warmup LR, if scheduler supports')
57 | parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N',
58 |                         help='num of steps to warmup LR, will overload warmup_epochs if set > 0')
59 | parser.add_argument('--use_checkpoint', action='store_true')
60 | parser.set_defaults(use_checkpoint=False)
61 |
62 | # Augmentation parameters
63 | parser.add_argument('--color_jitter', type=float, default=0.0, metavar='PCT',
64 |                         help='Color jitter factor (default: 0.0)')
65 | parser.add_argument('--train_interpolation', type=str, default='bicubic',
66 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")')
67 |
68 | # Dataset parameters
69 | parser.add_argument('--data_path', default='/path/to/list_kinetics-400', type=str,
70 | help='dataset path')
71 | parser.add_argument('--imagenet_default_mean_and_std', default=True, action='store_true')
72 |     parser.add_argument('--num_frames', type=int, default=16)
73 | parser.add_argument('--sampling_rate', type=int, default=2)
74 | parser.add_argument('--output_dir', default='Cholec80',
75 | help='path where to save, empty for no saving')
76 | parser.add_argument('--log_dir', default=None,
77 | help='path where to tensorboard log')
78 | parser.add_argument('--device', default='cuda',
79 | help='device to use for training / testing')
80 | parser.add_argument('--seed', default=0, type=int)
81 | parser.add_argument('--resume', default='', help='resume from checkpoint')
82 | parser.add_argument('--auto_resume', action='store_true')
83 | parser.add_argument('--no_auto_resume', action='store_false', dest='auto_resume')
84 | parser.set_defaults(auto_resume=True)
85 |
86 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
87 | help='start epoch')
88 | parser.add_argument('--num_workers', default=10, type=int)
89 | parser.add_argument('--pin_mem', action='store_true',
90 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
91 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem',
92 | help='')
93 | parser.set_defaults(pin_mem=True)
94 |
95 | # distributed training parameters
96 | parser.add_argument('--world_size', default=1, type=int,
97 | help='number of distributed processes')
98 | parser.add_argument('--local_rank', default=-1, type=int)
99 | parser.add_argument('--dist_on_itp', action='store_true')
100 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
101 |
102 | parser.add_argument('--window_size', default=(8,14,14), type=tuple,
103 |                         help='window size (temporal, height, width) of the backbone features')
104 |
105 | return parser.parse_args()
106 |
107 |
108 | def get_args_finetuning():
109 | parser = argparse.ArgumentParser(
110 | "SurgVideoMAE fine-tuning and evaluation script for video phase recognition",
111 | add_help=False,
112 | )
113 | parser.add_argument("--batch_size", default=10, type=int)
114 | parser.add_argument("--epochs", default=100, type=int)
115 | parser.add_argument("--update_freq", default=1, type=int)
116 | parser.add_argument("--save_ckpt_freq", default=10, type=int)
117 |
118 |
119 | # Model parameters
120 | parser.add_argument(
121 | "--model",
122 | default="vit_base_patch16_224",
123 | type=str,
124 | metavar="MODEL",
125 | help="Name of model to train",
126 | )
127 | parser.add_argument("--tubelet_size", type=int, default=2)
128 | parser.add_argument("--full_finetune", action="store_true", default=False)
129 | parser.add_argument("--input_size", default=224, type=int, help="videos input size")
130 |
131 | parser.add_argument(
132 | "--fc_drop_rate",
133 | type=float,
134 | default=0.0,
135 | metavar="PCT",
136 | help="Dropout rate (default: 0.)",
137 | )
138 | parser.add_argument(
139 | "--drop",
140 | type=float,
141 | default=0.0,
142 | metavar="PCT",
143 | help="Dropout rate (default: 0.)",
144 | )
145 | parser.add_argument(
146 | "--attn_drop_rate",
147 | type=float,
148 | default=0.0,
149 | metavar="PCT",
150 | help="Attention dropout rate (default: 0.)",
151 | )
152 | parser.add_argument(
153 | "--drop_path",
154 | type=float,
155 | default=0,
156 | metavar="PCT",
157 |         help="Drop path rate (default: 0)",
158 | )
159 |
160 | parser.add_argument(
161 | "--disable_eval_during_finetuning", action="store_true", default=False
162 | )
163 | parser.add_argument("--model_ema", action="store_true", default=False)
164 | parser.add_argument("--model_ema_decay", type=float, default=0.9999, help="")
165 | parser.add_argument(
166 | "--model_ema_force_cpu", action="store_true", default=False, help=""
167 | )
168 |
169 | # Optimizer parameters
170 | parser.add_argument(
171 | "--opt",
172 | default="adamw",
173 | type=str,
174 | metavar="OPTIMIZER",
175 |         help='Optimizer (default: "adamw")',
176 | )
177 | parser.add_argument(
178 | "--opt_eps",
179 | default=1e-8,
180 | type=float,
181 | metavar="EPSILON",
182 | help="Optimizer Epsilon (default: 1e-8)",
183 | )
184 | parser.add_argument(
185 | "--opt_betas",
186 | default=(0.9, 0.999),
187 | type=float,
188 | nargs="+",
189 | metavar="BETA",
190 | help="Optimizer Betas (default: None, use opt default)",
191 | )
192 | parser.add_argument(
193 | "--clip_grad",
194 | type=float,
195 | default=None,
196 | metavar="NORM",
197 | help="Clip gradient norm (default: None, no clipping)",
198 | )
199 | parser.add_argument(
200 | "--momentum",
201 | type=float,
202 | default=0.9,
203 | metavar="M",
204 | help="SGD momentum (default: 0.9)",
205 | )
206 | parser.add_argument(
207 | "--weight_decay", type=float, default=0.05, help="weight decay (default: 0.05)"
208 | )
209 | parser.add_argument(
210 | "--weight_decay_end",
211 | type=float,
212 | default=None,
213 | help="""Final value of the
214 | weight decay. We use a cosine schedule for WD and using a larger decay by
215 | the end of training improves performance for ViTs.""",
216 | )
217 |
218 | parser.add_argument(
219 | "--lr",
220 | type=float,
221 | default=5e-4,
222 | metavar="LR",
223 |         help="learning rate (default: 5e-4)",
224 | )
225 | parser.add_argument("--layer_decay", type=float, default=0.75)
226 |
227 | parser.add_argument(
228 | "--warmup_lr",
229 | type=float,
230 | default=1e-6,
231 | metavar="LR",
232 | help="warmup learning rate (default: 1e-6)",
233 | )
234 | parser.add_argument(
235 | "--min_lr",
236 | type=float,
237 | default=1e-6,
238 | metavar="LR",
239 |         help="lower lr bound for cyclic schedulers that hit 0 (1e-6)",
240 | )
241 |
242 | parser.add_argument(
243 | "--warmup_epochs",
244 | type=int,
245 | default=5,
246 | metavar="N",
247 | help="epochs to warmup LR, if scheduler supports",
248 | )
249 | parser.add_argument(
250 | "--warmup_steps",
251 | type=int,
252 | default=-1,
253 | metavar="N",
254 | help="num of steps to warmup LR, will overload warmup_epochs if set > 0",
255 | )
256 |
257 | # Augmentation parameters
258 | parser.add_argument(
259 | "--color_jitter",
260 | type=float,
261 | default=0.4,
262 | metavar="PCT",
263 | help="Color jitter factor (default: 0.4)",
264 | )
265 | parser.add_argument(
266 |         "--num_sample", type=int, default=1, help="Repeated augmentation (default: 1)"
267 | )
268 | parser.add_argument(
269 | "--aa",
270 | type=str,
271 | default="rand-m7-n4-mstd0.5-inc1",
272 | # default="rand-m8-n2-mstd0.5-inc1",
273 | metavar="NAME",
274 |         help='Use AutoAugment policy. "v0" or "original" (default: rand-m7-n4-mstd0.5-inc1)',
275 |     )
276 | parser.add_argument(
277 | "--smoothing", type=float, default=0.1, help="Label smoothing (default: 0.1)"
278 | )
279 | parser.add_argument(
280 | "--train_interpolation",
281 | type=str,
282 | default="bicubic",
283 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")',
284 | )
285 |
286 | # Evaluation parameters
287 | parser.add_argument("--crop_pct", type=float, default=None)
288 | parser.add_argument("--short_side_size", type=int, default=224)
289 | parser.add_argument("--test_num_segment", type=int, default=5)
290 | parser.add_argument("--test_num_crop", type=int, default=3)
291 |
292 | # Random Erase params
293 | parser.add_argument(
294 | "--reprob",
295 | type=float,
296 | default=0.25,
297 | metavar="PCT",
298 | help="Random erase prob (default: 0.25)",
299 | )
300 | parser.add_argument(
301 | "--remode",
302 | type=str,
303 | default="pixel",
304 | help='Random erase mode (default: "pixel")',
305 | )
306 | parser.add_argument(
307 | "--recount", type=int, default=1, help="Random erase count (default: 1)"
308 | )
309 | parser.add_argument(
310 | "--resplit",
311 | action="store_true",
312 | default=False,
313 | help="Do not random erase first (clean) augmentation split",
314 | )
315 |
316 | # Mixup params
317 | parser.add_argument(
318 | "--mixup",
319 | type=float,
320 | default=0,
321 |         help="mixup alpha, mixup enabled if > 0 (default: 0)",
322 | )
323 | parser.add_argument(
324 | "--cutmix",
325 | type=float,
326 | default=0,
327 |         help="cutmix alpha, cutmix enabled if > 0 (default: 0)",
328 | )
329 | parser.add_argument(
330 | "--cutmix_minmax",
331 | type=float,
332 | nargs="+",
333 | default=None,
334 | help="cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)",
335 | )
336 | parser.add_argument(
337 | "--mixup_prob",
338 | type=float,
339 | default=1.0,
340 | help="Probability of performing mixup or cutmix when either/both is enabled",
341 | )
342 | parser.add_argument(
343 | "--mixup_switch_prob",
344 | type=float,
345 | default=0.5,
346 | help="Probability of switching to cutmix when both mixup and cutmix enabled",
347 | )
348 | parser.add_argument(
349 | "--mixup_mode",
350 | type=str,
351 | default="batch",
352 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"',
353 | )
354 |
355 | # Finetuning params
356 | parser.add_argument("--finetune", default="", help="finetune from checkpoint")
357 | parser.add_argument("--model_key", default="model|module", type=str)
358 | parser.add_argument("--model_prefix", default="", type=str)
359 | parser.add_argument("--init_scale", default=0.001, type=float)
360 | parser.add_argument("--use_checkpoint", action="store_true")
361 | parser.set_defaults(use_checkpoint=False)
362 | parser.add_argument("--use_mean_pooling", action="store_true")
363 | parser.set_defaults(use_mean_pooling=True)
364 | parser.add_argument("--use_cls", action="store_false", dest="use_mean_pooling")
365 |
366 | # Dataset parameters
367 | parser.add_argument(
368 | "--data_path",
369 | default="/Users/yangshu/Downloads/Cataract101",
370 | # default="/Users/yangshu/Documents/SurgVideoMAE/data/cholec80",
371 | type=str,
372 | help="dataset path",
373 | )
374 | parser.add_argument(
375 | "--eval_data_path",
376 | default="/Users/yangshu/Downloads/Cataract101",
377 | type=str,
378 | help="dataset path for evaluation",
379 | )
380 | parser.add_argument(
381 | "--nb_classes", default=10, type=int, help="number of the classification types"
382 | )
383 | parser.add_argument(
384 | "--imagenet_default_mean_and_std", default=True, action="store_true"
385 | )
386 | parser.add_argument("--num_segments", type=int, default=1)
387 | parser.add_argument("--num_frames", type=int, default=8)
388 | parser.add_argument("--sampling_rate", type=int, default=2)
389 | parser.add_argument(
390 | "--data_set",
391 | default="Cataract101",
392 | choices=["Cholec80", "AutoLaparo", "Cataract101"],
393 | type=str,
394 | help="dataset",
395 | )
396 | parser.add_argument(
397 | "--data_fps",
398 | default="1fps",
399 | choices=["", "5fps", "1fps"],
400 | type=str,
401 | help="dataset",
402 | )
403 | parser.add_argument(
404 | "--output_dir",
405 | default="/home/yangshu/SurgVideoMAE/Cholec80/ImageNet/phase/1fps_loss",
406 | help="path where to save, empty for no saving",
407 | )
408 | parser.add_argument(
409 | "--log_dir",
410 | default="/home/yangshu/SurgVideoMAE/Cholec80/ImageNet/phase/1fps_loss/log",
411 | help="path where to tensorboard log",
412 | )
413 | parser.add_argument(
414 | "--device", default="cuda", help="device to use for training / testing"
415 | )
416 | parser.add_argument("--seed", default=0, type=int)
417 | parser.add_argument("--resume", default="", help="resume from checkpoint")
418 | parser.add_argument("--auto_resume", action="store_true")
419 | parser.add_argument("--no_auto_resume", action="store_false", dest="auto_resume")
420 | parser.set_defaults(auto_resume=True)
421 |
422 | parser.add_argument("--save_ckpt", action="store_true")
423 | parser.add_argument("--no_save_ckpt", action="store_false", dest="save_ckpt")
424 | parser.set_defaults(save_ckpt=True)
425 |
426 | parser.add_argument(
427 | "--start_epoch", default=0, type=int, metavar="N", help="start epoch"
428 | )
429 | parser.add_argument(
430 | "--eval", action="store_true", default=False, help="Perform evaluation only"
431 | )
432 | parser.add_argument(
433 | "--dist_eval",
434 | action="store_true",
435 | default=False,
436 | help="Enabling distributed evaluation",
437 | )
438 | parser.add_argument("--num_workers", default=10, type=int)
439 | parser.add_argument(
440 | "--pin_mem",
441 | action="store_true",
442 | help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.",
443 | )
444 | parser.add_argument("--no_pin_mem", action="store_false", dest="pin_mem")
445 | parser.set_defaults(pin_mem=True)
446 |
447 | # distributed training parameters
448 | parser.add_argument(
449 | "--world_size", default=1, type=int, help="number of distributed processes"
450 | )
451 | parser.add_argument("--local_rank", default=-1, type=int)
452 | parser.add_argument("--dist_on_itp", action="store_true")
453 | parser.add_argument(
454 | "--dist_url", default="env://", help="url used to set up distributed training"
455 | )
456 |
457 | parser.add_argument("--enable_deepspeed", action="store_true", default=False)
458 |
459 | known_args, _ = parser.parse_known_args()
460 |
461 | if known_args.enable_deepspeed:
462 | try:
463 | import deepspeed
464 | from deepspeed import DeepSpeedConfig
465 |
466 | parser = deepspeed.add_config_arguments(parser)
467 | ds_init = deepspeed.initialize
468 |         except ImportError:
469 |             print("Please run 'pip install deepspeed'")
470 | exit(0)
471 | else:
472 | ds_init = None
473 |
474 | return parser.parse_args(), ds_init
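475 |
476 |
477 | # Minimal usage sketch: both helpers parse sys.argv; get_args_finetuning()
478 | # additionally returns deepspeed.initialize when --enable_deepspeed is passed,
479 | # and None otherwise.
480 | if __name__ == "__main__":
481 |     args, ds_init = get_args_finetuning()
482 |     print(args.model, args.nb_classes, "deepspeed enabled:", ds_init is not None)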
--------------------------------------------------------------------------------
/datasets/convert_results/convert_autolaparo.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | phases = [
4 | "Preparation",
5 | "CalotTriangleDissection",
6 | "ClippingCutting",
7 | "GallbladderDissection",
8 | "GallbladderPackaging",
9 | "CleaningCoagulation",
10 | "GallbladderRetraction",
11 | ]
12 |
13 | def create_folder_if_not_exists(folder_path):
14 | if not os.path.exists(folder_path):
15 | os.makedirs(folder_path)
16 |         print("Folder created:", folder_path)
17 |     else:
18 |         print("Folder already exists:", folder_path)
19 |
20 | main_path = "/Users/yangshu/Documents/PETL4SurgVideo/result_save/Ours/AutoLaparo/16-4/LGA/WIT400M/"
21 | file_path_0 = os.path.join(main_path, "0.txt")
22 | file_path_1 = os.path.join(main_path, "1.txt")
23 | anns_path = "/Users/yangshu/Documents/PETL4SurgVideo/result_save/Ours/AutoLaparo/16-4/LGA/WIT400M" + "/phase_annotations"
24 | pred_path = "/Users/yangshu/Documents/PETL4SurgVideo/result_save/Ours/AutoLaparo/16-4/LGA/WIT400M" + "/prediction"
25 |
26 |
27 | create_folder_if_not_exists(anns_path)
28 | create_folder_if_not_exists(pred_path)
29 |
30 | with open(file_path_0) as f:
31 | lines0 = f.readlines()
32 |
33 | with open(file_path_1) as f:
34 | lines1 = f.readlines()
35 |
36 | for i in range(15, 22):
37 | with open(
38 | anns_path + "/video-{}.txt".format(str(i)), "w"
39 | ) as f:
40 | f.write("Frame")
41 | f.write("\t")
42 | f.write("Phase")
43 | f.write("\n")
44 | assert len(lines0) == len(lines1)
45 | for j in range(1, len(lines0)):
46 | temp0 = lines0[j].split()
47 | temp1 = lines1[j].split()
48 | if temp0[1] == "{}".format(str(i)):
49 | f.write(str(temp0[2])) # phase_annotations
50 | f.write("\t") # phase_annotations
51 | f.write(str(temp0[-1])) # phase_annotations
52 | f.write("\n") # phase_annotations
53 | if temp1[1] == "{}".format(str(i)):
54 | f.write(str(temp1[2])) # phase_annotations
55 | f.write("\t") # phase_annotations
56 | f.write(str(temp1[-1])) # phase_annotations
57 | f.write("\n") # phase_annotations
58 |
59 | with open(file_path_0) as f:
60 | lines0 = f.readlines()
61 |
62 | with open(file_path_1) as f:
63 | lines1 = f.readlines()
64 | for i in range(15, 22):
65 | print(i)
66 | with open(
67 | pred_path + "/video-{}.txt".format(str(i)), "w"
68 |     ) as f:  # prediction
69 | f.write("Frame")
70 | f.write("\t")
71 | f.write("Phase")
72 | f.write("\n")
73 | assert len(lines0) == len(lines1)
74 | for j in range(1, len(lines0)):
75 | temp0 = lines0[j].strip() # prediction
76 | temp1 = lines1[j].strip() # prediction
77 | data0 = np.fromstring(
78 | temp0.split("[")[1].split("]")[0], dtype=np.float32, sep=","
79 | ) # prediction
80 | data1 = np.fromstring(
81 | temp1.split("[")[1].split("]")[0], dtype=np.float32, sep=","
82 | ) # prediction
83 | data0 = data0.argmax() # prediction
84 | data1 = data1.argmax() # prediction
85 | temp0 = lines0[j].split()
86 | temp1 = lines1[j].split()
87 | if temp0[1] == "{}".format(str(i)):
88 | f.write(str(temp0[2])) # prediction
89 | f.write('\t') # prediction
90 | f.write(str(data0)) # prediction
91 | f.write('\n') # prediction
92 | if temp1[1] == "{}".format(str(i)):
93 | f.write(str(temp1[2])) # prediction
94 | f.write('\t') # prediction
95 | f.write(str(data1)) # prediction
96 | f.write('\n') # prediction
97 |
--------------------------------------------------------------------------------
/datasets/convert_results/convert_cholec80.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | phases = [
4 | "Preparation",
5 | "CalotTriangleDissection",
6 | "ClippingCutting",
7 | "GallbladderDissection",
8 | "GallbladderPackaging",
9 | "CleaningCoagulation",
10 | "GallbladderRetraction",
11 | ]
12 |
13 | def create_folder_if_not_exists(folder_path):
14 | if not os.path.exists(folder_path):
15 | os.makedirs(folder_path)
16 |         print("Folder created:", folder_path)
17 |     else:
18 |         print("Folder already exists:", folder_path)
19 |
20 |
21 | main_path = "/Users/yangshu/Documents/PETL4SurgVideo/result_save/Dualpath/Large/"
22 | file_path_0 = os.path.join(main_path, "0.txt")
23 | file_path_1 = os.path.join(main_path, "1.txt")
24 | anns_path = "/Users/yangshu/Documents/PETL4SurgVideo/result_save/Dualpath/Large" + "/phase_annotations"
25 | pred_path = "/Users/yangshu/Documents/PETL4SurgVideo/result_save/Dualpath/Large" + "/prediction"
26 |
27 |
28 | create_folder_if_not_exists(anns_path)
29 | create_folder_if_not_exists(pred_path)
30 |
31 | with open(file_path_0) as f:
32 | lines0 = f.readlines()
33 |
34 | with open(file_path_1) as f:
35 | lines1 = f.readlines()
36 |
37 | for i in range(41, 81):
38 | with open(
39 | anns_path + "/video-{}.txt".format(str(i)), "w"
40 | ) as f:
41 | f.write("Frame")
42 | f.write("\t")
43 | f.write("Phase")
44 | f.write("\n")
45 | assert len(lines0) == len(lines1)
46 | for j in range(1, len(lines0)):
47 | temp0 = lines0[j].split()
48 | temp1 = lines1[j].split()
49 | if temp0[1] == "video{}".format(str(i)):
50 | f.write(str(temp0[2])) # phase_annotations
51 | f.write("\t") # phase_annotations
52 | f.write(str(temp0[-1])) # phase_annotations
53 | f.write("\n") # phase_annotations
54 | if temp1[1] == "video{}".format(str(i)):
55 | f.write(str(temp1[2])) # phase_annotations
56 | f.write("\t") # phase_annotations
57 | f.write(str(temp1[-1])) # phase_annotations
58 | f.write("\n") # phase_annotations
59 |
60 | with open(file_path_0) as f:
61 | lines0 = f.readlines()
62 |
63 | with open(file_path_1) as f:
64 | lines1 = f.readlines()
65 | for i in range(41, 81):
66 | print(i)
67 | with open(
68 | pred_path + "/video-{}.txt".format(str(i)), "w"
69 |     ) as f:  # prediction
70 | f.write("Frame")
71 | f.write("\t")
72 | f.write("Phase")
73 | f.write("\n")
74 | assert len(lines0) == len(lines1)
75 | for j in range(1, len(lines0)):
76 | temp0 = lines0[j].strip() # prediction
77 | temp1 = lines1[j].strip() # prediction
78 | data0 = np.fromstring(
79 | temp0.split("[")[1].split("]")[0], dtype=np.float32, sep=","
80 | ) # prediction
81 | data1 = np.fromstring(
82 | temp1.split("[")[1].split("]")[0], dtype=np.float32, sep=","
83 | ) # prediction
84 | data0 = data0.argmax() # prediction
85 | data1 = data1.argmax() # prediction
86 | temp0 = lines0[j].split()
87 | temp1 = lines1[j].split()
88 | if temp0[1] == "video{}".format(str(i)):
89 | f.write(str(temp0[2])) # prediction
90 | f.write('\t') # prediction
91 | f.write(str(data0)) # prediction
92 | f.write('\n') # prediction
93 | if temp1[1] == "video{}".format(str(i)):
94 | f.write(str(temp1[2])) # prediction
95 | f.write('\t') # prediction
96 | f.write(str(data1)) # prediction
97 | f.write('\n') # prediction
98 |
--------------------------------------------------------------------------------
/datasets/data_preprosses/generate_labels_autolaparo.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import cv2
4 | import pickle
5 | from tqdm import tqdm
6 |
7 | def main():
8 | ROOT_DIR = "/xxxxxxx/AutoLaparo/"
9 | VIDEO_NAMES = os.listdir(os.path.join(ROOT_DIR, 'frames'))
10 | VIDEO_NAMES = sorted([x for x in VIDEO_NAMES if "DS" not in x])
11 |
12 | TRAIN_NUMBERS = np.arange(1,11).tolist()
13 | VAL_NUMBERS = np.arange(11,15).tolist()
14 | TEST_NUMBERS = np.arange(15,22).tolist()
15 |
16 | TRAIN_FRAME_NUMBERS = 0
17 | VAL_FRAME_NUMBERS = 0
18 | TEST_FRAME_NUMBERS = 0
19 |
20 | train_pkl = dict()
21 | val_pkl = dict()
22 | test_pkl = dict()
23 |
24 | unique_id = 0
25 | unique_id_train = 0
26 | unique_id_val = 0
27 | unique_id_test = 0
28 |
29 | id2phase = {1: "Preparation", 2: "Dividing Ligament and Peritoneum", 3: "Dividing Uterine Vessels and Ligament",
30 | 4: "Transecting the Vagina", 5: "Specimen Removal", 6: "Suturing", 7: "Washing"}
31 |
32 | for video_id in VIDEO_NAMES:
33 | vid_id = int(video_id)
34 | if vid_id in TRAIN_NUMBERS:
35 | unique_id = unique_id_train
36 | elif vid_id in VAL_NUMBERS:
37 | unique_id = unique_id_val
38 | elif vid_id in TEST_NUMBERS:
39 | unique_id = unique_id_test
40 |
41 |         # Total number of frames
42 | video_path = os.path.join(ROOT_DIR, "frames", video_id)
43 | frames_list = os.listdir(video_path)
44 |
45 |         # Open the phase label file
46 | phase_path = os.path.join(ROOT_DIR, 'labels', "label_" + video_id + '.txt')
47 | phase_file = open(phase_path, 'r')
48 | phase_results = phase_file.readlines()[1:]
49 |
50 | frame_infos = list()
51 | frame_id_ = 0
52 | for frame_id in tqdm(range(0, len(frames_list))):
53 | info = dict()
54 | info['unique_id'] = unique_id
55 | info['frame_id'] = frame_id_
56 | info['original_frame_id'] = frame_id
57 | info['video_id'] = video_id
58 | info['tool_gt'] = None
59 | info['frames'] = len(frames_list)
60 | phase = phase_results[frame_id].strip().split()
61 | assert int(phase[0]) == frame_id + 1
62 | phase_id = int(phase[1])
63 | info['phase_gt'] = phase_id - 1
64 | info['phase_name'] = id2phase[int(phase[1])]
65 | info['fps'] = 1
66 | frame_infos.append(info)
67 | unique_id += 1
68 | frame_id_ += 1
69 |
70 | if vid_id in TRAIN_NUMBERS:
71 | train_pkl[video_id] = frame_infos
72 | TRAIN_FRAME_NUMBERS += len(frames_list)
73 | unique_id_train = unique_id
74 | elif vid_id in VAL_NUMBERS:
75 | val_pkl[video_id] = frame_infos
76 | VAL_FRAME_NUMBERS += len(frames_list)
77 | unique_id_val = unique_id
78 | elif vid_id in TEST_NUMBERS:
79 | test_pkl[video_id] = frame_infos
80 | TEST_FRAME_NUMBERS += len(frames_list)
81 | unique_id_test = unique_id
82 |
83 | train_save_dir = os.path.join(ROOT_DIR, 'labels_pkl', 'train')
84 | os.makedirs(train_save_dir, exist_ok=True)
85 | with open(os.path.join(train_save_dir, '1fpstrain.pickle'), 'wb') as file:
86 | pickle.dump(train_pkl, file)
87 |
88 | val_save_dir = os.path.join(ROOT_DIR, 'labels_pkl', 'val')
89 | os.makedirs(val_save_dir, exist_ok=True)
90 | with open(os.path.join(val_save_dir, '1fpsval.pickle'), 'wb') as file:
91 | pickle.dump(val_pkl, file)
92 |
93 | test_save_dir = os.path.join(ROOT_DIR, 'labels_pkl', 'test')
94 | os.makedirs(test_save_dir, exist_ok=True)
95 | with open(os.path.join(test_save_dir, '1fpstest.pickle'), 'wb') as file:
96 | pickle.dump(test_pkl, file)
97 |
98 |
99 |     print('TRAIN Frames', TRAIN_FRAME_NUMBERS, unique_id_train)
100 |     print('VAL Frames', VAL_FRAME_NUMBERS, unique_id_val)
101 |     print('TEST Frames', TEST_FRAME_NUMBERS, unique_id_test)
102 |
103 | if __name__ == '__main__':
104 | main()
--------------------------------------------------------------------------------
/datasets/data_preprosses/generate_labels_ch80.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import cv2
4 | import pickle
5 | from tqdm import tqdm
6 |
7 | def main():
8 | ROOT_DIR = "/xxxxxxx/cholec80"
9 | VIDEO_NAMES = os.listdir(os.path.join(ROOT_DIR, 'videos'))
10 | VIDEO_NAMES = sorted([x for x in VIDEO_NAMES if 'mp4' in x])
11 | TRAIN_NUMBERS = np.arange(1,41).tolist()
12 | VAL_NUMBERS = np.arange(41,49).tolist()
13 | TEST_NUMBERS = np.arange(49,81).tolist()
14 |
15 | TRAIN_FRAME_NUMBERS = 0
16 | VAL_FRAME_NUMBERS = 0
17 | TEST_FRAME_NUMBERS = 0
18 |
19 | train_pkl = dict()
20 | val_pkl = dict()
21 | test_pkl = dict()
22 | val_test_pkl = dict()
23 |
24 | unique_id = 0
25 | unique_id_train = 0
26 | unique_id_val = 0
27 | unique_id_test = 0
28 |
29 | phase2id = {'Preparation': 0, 'CalotTriangleDissection': 1, 'ClippingCutting': 2, 'GallbladderDissection': 3,
30 | 'GallbladderPackaging': 4, 'CleaningCoagulation': 5, 'GallbladderRetraction': 6}
31 |
32 | for video_name in VIDEO_NAMES:
33 | video_id = video_name.replace('.mp4', '')
34 | vid_id = int(video_name.replace('.mp4', '').replace("video", ""))
35 | if vid_id in TRAIN_NUMBERS:
36 | unique_id = unique_id_train
37 | elif vid_id in VAL_NUMBERS:
38 | unique_id = unique_id_val
39 | elif vid_id in TEST_NUMBERS:
40 | unique_id = unique_id_test
41 |
42 |         # Open the video file
43 |         vidcap = cv2.VideoCapture(os.path.join(ROOT_DIR, './videos/' + video_name))
44 |         # Frame rate (frames per second)
45 | fps = vidcap.get(cv2.CAP_PROP_FPS)
46 | if fps != 25:
47 | print(video_name, 'not at 25fps', fps)
48 |         # Total number of frames
49 | frames = vidcap.get(cv2.CAP_PROP_FRAME_COUNT)
50 |
51 |         # Open the tool annotation file
52 | tool_path = os.path.join(ROOT_DIR, 'tool_annotations', video_name.replace('.mp4', '-tool.txt'))
53 | tool_file = open(tool_path, 'r')
54 | tool = tool_file.readline().strip().split()
55 | tool_name = tool[1:]
56 | tool_dict = dict()
57 | while tool:
58 | tool = tool_file.readline().strip().split()
59 | if len(tool) > 0:
60 | tool = list(map(int, tool))
61 | tool_dict[str(tool[0])] = tool[1:]
62 |
63 | phase_path = os.path.join(ROOT_DIR, 'phase_annotations', video_name.replace('.mp4', '-phase.txt'))
64 | phase_file = open(phase_path, 'r')
65 | phase_results = phase_file.readlines()[1:]
66 |
67 | frame_infos = list()
68 | frame_id_ = 0
69 | for frame_id in tqdm(range(0, int(frames), 25)):
70 | info = dict()
71 | info['unique_id'] = unique_id
72 | info['frame_id'] = frame_id_
73 | info['original_frame_id'] = frame_id
74 | info['video_id'] = video_id
75 |
76 | if str(frame_id) in tool_dict:
77 | info['tool_gt'] = tool_dict[str(frame_id)]
78 | else:
79 | info['tool_gt'] = None
80 |
81 | phase = phase_results[frame_id].strip().split()
82 | assert int(phase[0]) == frame_id
83 | phase_id = phase2id[phase[1]]
84 | info['phase_gt'] = phase_id
85 | info['phase_name'] = phase[1]
86 | info['fps'] = 1
87 | info['original_frames'] = int(frames)
88 | info['frames'] = int(frames) // 25
89 | # info['tool_names'] = tool_name
90 | info['phase_name'] = phase[1]
91 | frame_infos.append(info)
92 | unique_id += 1
93 | frame_id_ += 1
94 |
95 | vid_id = int(video_name.replace('.mp4', '').replace("video", ""))
96 | if vid_id in TRAIN_NUMBERS:
97 | train_pkl[video_id] = frame_infos
98 | TRAIN_FRAME_NUMBERS += frames
99 | unique_id_train = unique_id
100 | elif vid_id in VAL_NUMBERS:
101 | val_pkl[video_id] = frame_infos
102 | VAL_FRAME_NUMBERS += frames
103 | unique_id_val = unique_id
104 | elif vid_id in TEST_NUMBERS:
105 | test_pkl[video_id] = frame_infos
106 | TEST_FRAME_NUMBERS += frames
107 | unique_id_test = unique_id
108 |
109 | val_test_pkl = {**val_pkl, **test_pkl}
110 |
111 | train_save_dir = os.path.join(ROOT_DIR, 'labels', 'train')
112 | os.makedirs(train_save_dir, exist_ok=True)
113 | with open(os.path.join(train_save_dir, '1fpstrain.pickle'), 'wb') as file:
114 | pickle.dump(train_pkl, file)
115 |
116 | val_save_dir = os.path.join(ROOT_DIR, 'labels', 'val')
117 | os.makedirs(val_save_dir, exist_ok=True)
118 | with open(os.path.join(val_save_dir, '1fpsval.pickle'), 'wb') as file:
119 | pickle.dump(val_pkl, file)
120 |
121 | test_save_dir = os.path.join(ROOT_DIR, 'labels', 'test')
122 | os.makedirs(test_save_dir, exist_ok=True)
123 | with open(os.path.join(test_save_dir, '1fpstest.pickle'), 'wb') as file:
124 | pickle.dump(test_pkl, file)
125 | with open(os.path.join(test_save_dir, '1fpsval_test.pickle'), 'wb') as file:
126 | pickle.dump(val_test_pkl, file)
127 |
128 |
129 |     print('TRAIN Frames', TRAIN_FRAME_NUMBERS, unique_id_train)
130 |     print('VAL Frames', VAL_FRAME_NUMBERS, unique_id_val)
131 |     print('TEST Frames', TEST_FRAME_NUMBERS, unique_id_test)
132 |     print('VAL TEST Videos', len(val_test_pkl))
133 |
134 | if __name__ == '__main__':
135 | main()
136 |
--------------------------------------------------------------------------------
/datasets/data_preprosses/generate_labels_lungseg.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import pickle
4 | from tqdm import tqdm
5 | import random
6 |
7 | def main():
8 | ROOT_DIR = "/home/diandian/Diandian/DD/PmNet/data/PmLR50"
9 | VIDEO_NAMES = os.listdir(os.path.join(ROOT_DIR, "frames"))
10 | VIDEO_NAMES = sorted([x for x in VIDEO_NAMES if "DS" not in x])
11 | VIDEO_NUMBERS = len(VIDEO_NAMES)
12 | TRAIN_NUMBERS = list(range(1, 34))
13 | TRAIN_NUMBERS.append(49)
14 | TRAIN_NUMBERS.append(50)
15 | # TRAIN_NUMBERS = sorted(random.sample(range(VIDEO_NUMBERS), 30))
16 | print(TRAIN_NUMBERS)
17 | # TEST_NUMBERS = [x for x in range(VIDEO_NUMBERS) if x not in TRAIN_NUMBERS]
18 | TEST_NUMBERS = list(range(34, 38))
19 | TEST_NUMBERS.append(48)
20 | print(TEST_NUMBERS)
21 | INFER_NUMBERS = list(range(38, 48))
22 |
23 | TRAIN_FRAME_NUMBERS = 0
24 | TEST_FRAME_NUMBERS = 0
25 | INFER_FRAME_NUMBERS = 0
26 |
27 | train_pkl = dict()
28 | test_pkl = dict()
29 | infer_pkl = dict()
30 |
31 | unique_id = 0
32 | unique_id_train = 0
33 | unique_id_test = 0
34 | unique_id_infer = 0
35 |
36 | for video_id in VIDEO_NAMES:
37 | vid_id = int(video_id)
38 |
39 | if vid_id in TRAIN_NUMBERS:
40 | unique_id = unique_id_train
41 | elif vid_id in TEST_NUMBERS:
42 | unique_id = unique_id_test
43 | elif vid_id in INFER_NUMBERS:
44 | unique_id = unique_id_infer
45 |
46 |         # Total number of frames
47 | video_path = os.path.join(ROOT_DIR, "frames", video_id)
48 | frames_list = os.listdir(video_path)
49 | frames_list = sorted([x for x in frames_list if "jpg" in x])
50 |
51 |         # Open the phase annotation file
52 | anno_path = os.path.join(ROOT_DIR, 'phase_annotations', video_id + '.txt')
53 | anno_file = open(anno_path, 'r')
54 | anno_results = anno_file.readlines()[1:]
55 | print(len(frames_list))
56 | print(len(anno_results))
57 | assert len(frames_list) == len(anno_results)
58 | frame_infos = list()
59 | for frame_id in tqdm(range(0, len(frames_list))):
60 | info = dict()
61 | info['unique_id'] = unique_id
62 | info['frame_id'] = frame_id
63 | info['video_id'] = video_id
64 | info['frames'] = len(frames_list)
65 | anno_info = anno_results[frame_id]
66 | anno_frame = anno_info.split()[0]
67 | assert int(anno_frame) == frame_id
68 | anno_id = anno_info.split()[1]
69 | info['phase_gt'] = int(anno_id)
70 | info['fps'] = 1
71 | frame_infos.append(info)
72 | unique_id += 1
73 |
74 | if vid_id in TRAIN_NUMBERS:
75 | train_pkl[video_id] = frame_infos
76 | TRAIN_FRAME_NUMBERS += len(frames_list)
77 | unique_id_train = unique_id
78 | elif vid_id in TEST_NUMBERS:
79 | test_pkl[video_id] = frame_infos
80 | TEST_FRAME_NUMBERS += len(frames_list)
81 | unique_id_test = unique_id
82 | elif vid_id in INFER_NUMBERS:
83 | infer_pkl[video_id] = frame_infos
84 | INFER_FRAME_NUMBERS += len(frames_list)
85 | unique_id_infer = unique_id
86 |
87 | train_save_dir = os.path.join(ROOT_DIR, 'labels', 'train')
88 | os.makedirs(train_save_dir, exist_ok=True)
89 | with open(os.path.join(train_save_dir, '1fpstrain.pickle'), 'wb') as file:
90 | pickle.dump(train_pkl, file)
91 |
92 | test_save_dir = os.path.join(ROOT_DIR, 'labels', 'test')
93 | os.makedirs(test_save_dir, exist_ok=True)
94 | with open(os.path.join(test_save_dir, '1fpstest.pickle'), 'wb') as file:
95 | pickle.dump(test_pkl, file)
96 |
97 | infer_save_dir = os.path.join(ROOT_DIR, 'labels', 'infer')
98 | os.makedirs(infer_save_dir, exist_ok=True)
99 | with open(os.path.join(infer_save_dir, '1fpsinfer.pickle'), 'wb') as file:
100 | pickle.dump(infer_pkl, file)
101 |
102 |     print('TRAIN Frames', TRAIN_FRAME_NUMBERS, unique_id_train)
103 |     print('TEST Frames', TEST_FRAME_NUMBERS, unique_id_test)
104 |     print('INFER Frames', INFER_FRAME_NUMBERS, unique_id_infer)
105 |
106 | if __name__ == "__main__":
107 | main()
108 |
--------------------------------------------------------------------------------
/datasets/data_preprosses/generate_labels_pmlr.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import pickle
4 | from tqdm import tqdm
5 | import random
6 | import json
7 |
8 | def main():
9 | ROOT_DIR = "/home/diandian/Diandian/DD/PmNet/data/PmLR50"
10 | VIDEO_NAMES = os.listdir(os.path.join(ROOT_DIR, "frames"))
11 | VIDEO_NAMES = sorted([x for x in VIDEO_NAMES if "DS" not in x])
12 | VIDEO_NUMBERS = len(VIDEO_NAMES)
13 | TRAIN_NUMBERS = list(range(1, 34))
14 | TRAIN_NUMBERS.append(49)
15 | TRAIN_NUMBERS.append(50)
16 | print('For training:', TRAIN_NUMBERS)
17 | TEST_NUMBERS = list(range(34, 38))
18 | TEST_NUMBERS.append(48)
19 | print('For validation:', TEST_NUMBERS)
20 | INFER_NUMBERS = list(range(38, 48))
21 | print('For testing:', INFER_NUMBERS)
22 |
23 | total_frame_count = 0
24 | for video_id in VIDEO_NAMES:
25 |         if int(video_id) in TRAIN_NUMBERS:
26 | video_path = os.path.join(ROOT_DIR, "frames", video_id)
27 | total_frame_count += len(os.listdir(video_path))
28 | average_frame_number = total_frame_count/len(TRAIN_NUMBERS)
29 |
30 | TRAIN_FRAME_NUMBERS = 0
31 | TEST_FRAME_NUMBERS = 0
32 | INFER_FRAME_NUMBERS = 0
33 |
34 | train_pkl = dict()
35 | test_pkl = dict()
36 | infer_pkl = dict()
37 |
38 | unique_id = 0
39 | unique_id_train = 0
40 | unique_id_test = 0
41 | unique_id_infer = 0
42 |
43 | for video_id in VIDEO_NAMES:
44 | vid_id = int(video_id)
45 |
46 | if vid_id in TRAIN_NUMBERS:
47 | unique_id = unique_id_train
48 | elif vid_id in TEST_NUMBERS:
49 | unique_id = unique_id_test
50 | elif vid_id in INFER_NUMBERS:
51 | unique_id = unique_id_infer
52 |
53 |         # Total number of frames
54 | video_path = os.path.join(ROOT_DIR, "frames", video_id)
55 | frames_list = os.listdir(video_path)
56 | frames_list = sorted([x for x in frames_list if "jpg" in x])
57 |
58 |         # Open the annotation files (phase, blocking, bbox)
59 | anno_path = os.path.join(ROOT_DIR, 'phase_annotations', video_id + '.txt')
60 | anno_file = open(anno_path, 'r')
61 | anno_results = anno_file.readlines()[1:]
62 | blocking_path = os.path.join(ROOT_DIR, 'blocking_annotations', video_id + '.txt')
63 | blocking_file = open(blocking_path, 'r')
64 | blocking_results = blocking_file.readlines()[1:]
65 | bbox_path = os.path.join(ROOT_DIR, 'bbox_annotations', video_id + '.json')
66 | with open(bbox_path, 'r', encoding='utf-8') as f:
67 | bbox_data = json.load(f)
68 | assert len(frames_list) == len(anno_results)
69 | frame_infos = list()
70 | for frame_id in tqdm(range(0, len(frames_list))):
71 | info = dict()
72 | info['unique_id'] = unique_id
73 | info['frame_id'] = frame_id
74 | info['video_id'] = video_id
75 | info['frames'] = len(frames_list)
76 | anno_info = anno_results[frame_id]
77 | blocking_info = blocking_results[frame_id]
78 | anno_frame = anno_info.split()[0]
79 | assert int(anno_frame) == frame_id
80 | anno_id = anno_info.split()[1]
81 | blocking_id = blocking_info.split()[1]
82 | info['phase_gt'] = int(anno_id)
83 | info['blocking_gt'] = int(blocking_id)
84 | origin_bbox = bbox_data[frames_list[frame_id].split('.')[0]]
85 | x1, x2, y1, y2 = origin_bbox[0][0], origin_bbox[1][0], origin_bbox[0][1], origin_bbox[1][1]
86 |             # Rescale bboxes into the 224x224 frame space
87 |             if vid_id not in [16, 18, 46, 47, 48, 49]:  # 1280x720 source videos
88 |                 x1, x2 = x1 * 224 / 1280, x2 * 224 / 1280
89 |                 y1, y2 = y1 * 224 / 720, y2 * 224 / 720
90 |             else:  # 1920x1080 source videos
90 | x1, x2 = x1 * 224 / 1920, x2 * 224 / 1920
91 | y1, y2 = y1 * 224 / 1080, y2 * 224 / 1080
92 |
93 | x1, x2 = min(x1, x2), max(x1, x2)
94 | y1, y2 = min(y1, y2), max(y1, y2)
95 |
96 | x1, y1 = max(0, int(x1)), max(0, int(y1))
97 | x2, y2 = min(224, int(x2)), min(224, int(y2))
98 | info['bbox'] = [[x1, y1], [x2, y2]]
99 | info['fps'] = 1
100 | info['avg_length'] = average_frame_number
101 | frame_infos.append(info)
102 | unique_id += 1
103 |
104 | if vid_id in TRAIN_NUMBERS:
105 | train_pkl[video_id] = frame_infos
106 | TRAIN_FRAME_NUMBERS += len(frames_list)
107 | unique_id_train = unique_id
108 | elif vid_id in TEST_NUMBERS:
109 | test_pkl[video_id] = frame_infos
110 | TEST_FRAME_NUMBERS += len(frames_list)
111 | unique_id_test = unique_id
112 | elif vid_id in INFER_NUMBERS:
113 | infer_pkl[video_id] = frame_infos
114 | INFER_FRAME_NUMBERS += len(frames_list)
115 | unique_id_infer = unique_id
116 |
117 | train_save_dir = os.path.join(ROOT_DIR, 'labels', 'train')
118 | os.makedirs(train_save_dir, exist_ok=True)
119 | with open(os.path.join(train_save_dir, '1fpstrain.pickle'), 'wb') as file:
120 | pickle.dump(train_pkl, file)
121 |
122 | test_save_dir = os.path.join(ROOT_DIR, 'labels', 'test')
123 | os.makedirs(test_save_dir, exist_ok=True)
124 | with open(os.path.join(test_save_dir, '1fpstest.pickle'), 'wb') as file:
125 | pickle.dump(test_pkl, file)
126 |
127 | infer_save_dir = os.path.join(ROOT_DIR, 'labels', 'infer')
128 | os.makedirs(infer_save_dir, exist_ok=True)
129 | with open(os.path.join(infer_save_dir, '1fpsinfer.pickle'), 'wb') as file:
130 | pickle.dump(infer_pkl, file)
131 |
132 | print('TRAIN Frames', TRAIN_FRAME_NUMBERS, unique_id_train)
133 | print('TEST Frames', TEST_FRAME_NUMBERS, unique_id_test)
134 | print('INFER Frames', INFER_FRAME_NUMBERS, unique_id_infer)
135 | print('Average Length', (TRAIN_FRAME_NUMBERS)/len(TRAIN_NUMBERS))
136 |
137 | if __name__ == "__main__":
138 | main()
139 |
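A note on the bbox handling above: annotation corners are mapped from the source resolution (1920x1080 for videos 16, 18 and 46-49, 1280x720 for the rest) onto the 224x224 network input, then ordered and clamped. A standalone sketch of the same arithmetic (helper name hypothetical, not part of the repository):

def rescale_bbox(corners, src_w, src_h, dst=224):
    (x1, y1), (x2, y2) = corners
    x1, x2 = x1 * dst / src_w, x2 * dst / src_w  # scale x onto the dst grid
    y1, y2 = y1 * dst / src_h, y2 * dst / src_h  # scale y onto the dst grid
    x1, x2 = min(x1, x2), max(x1, x2)            # enforce corner ordering
    y1, y2 = min(y1, y2), max(y1, y2)
    x1, y1 = max(0, int(x1)), max(0, int(y1))    # clamp into the frame
    x2, y2 = min(dst, int(x2)), min(dst, int(y2))
    return [[x1, y1], [x2, y2]]

# e.g. a 1280x720 box [[640, 360], [960, 540]] maps to [[112, 112], [168, 168]]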
--------------------------------------------------------------------------------
/datasets/extract_frames/extract_frames_autolaparo.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import cv2
4 | from tqdm import tqdm
5 |
6 | ROOT_DIR = "/xxxxxx/AutoLaparo"
7 | VIDEO_NAMES = os.listdir(os.path.join(ROOT_DIR, "videos"))
8 | VIDEO_NAMES = sorted([x for x in VIDEO_NAMES if 'mp4' in x])
9 |
10 | FRAME_NUMBERS = 0
11 |
12 | for video_name in VIDEO_NAMES:
13 | print(video_name)
14 | vidcap = cv2.VideoCapture(os.path.join(ROOT_DIR, "videos", video_name))
15 | fps = vidcap.get(cv2.CAP_PROP_FPS)
16 | print("fps", fps)
17 |     success = True
18 |     count = 0
19 | save_dir = './frames/' + video_name.replace('.mp4', '') +'/'
20 | save_dir = os.path.join(ROOT_DIR, save_dir)
21 | os.makedirs(save_dir, exist_ok=True)
22 |     while success:
23 |         success, image = vidcap.read()
24 |         if success:
25 |             if count % fps == 0:
26 |                 cv2.imwrite(save_dir + str(int(count // fps)).zfill(5) + '.png', image)
27 |             count += 1
28 | vidcap.release()
29 | cv2.destroyAllWindows()
30 | print(count)
31 | FRAME_NUMBERS += count
32 |
33 | print('Total Frames', FRAME_NUMBERS)
34 |
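The `count % fps == 0` test yields one frame per second only because these videos report an integer rate (25.0); with a non-integer rate such as 29.97 a float modulo would almost never be exactly zero. A minimal sketch of a rate-robust variant (function and callback names hypothetical):

import cv2

def sample_1fps(video_path, on_frame):
    # Invoke on_frame(second_index, image) once per second of video.
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS)  # may be non-integer, e.g. 29.97
    count, saved = 0, 0
    while True:
        ok, image = cap.read()
        if not ok:
            break
        if count >= round(saved * fps):  # first frame of each wall-clock second
            on_frame(saved, image)
            saved += 1
        count += 1
    cap.release()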
--------------------------------------------------------------------------------
/datasets/extract_frames/extract_frames_ch80.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import cv2
4 | from tqdm import tqdm
5 |
6 | ROOT_DIR = "/xxxxxxx/cholec80"
7 | VIDEO_NAMES = os.listdir(os.path.join(ROOT_DIR, 'videos'))
8 | VIDEO_NAMES = [x for x in VIDEO_NAMES if 'mp4' in x]
9 | TRAIN_NUMBERS = np.arange(1,41).tolist()
10 | VAL_NUMBERS = np.arange(41,49).tolist()
11 | TEST_NUMBERS = np.arange(49,81).tolist()
12 |
13 | TRAIN_FRAME_NUMBERS = 0
14 | VAL_FRAME_NUMBERS = 0
15 | TEST_FRAME_NUMBERS = 0
16 |
17 | for video_name in VIDEO_NAMES:
18 | print(video_name)
19 | vidcap = cv2.VideoCapture(os.path.join(ROOT_DIR, 'videos', video_name))
20 | fps = vidcap.get(cv2.CAP_PROP_FPS)
21 | print("fps", fps)
22 | if fps != 25:
23 | print(video_name, 'not at 25fps', fps)
24 |     success = True
25 |     count = 0
26 | vid_id = int(video_name.replace('.mp4', '').replace("video", ""))
27 | if vid_id in TRAIN_NUMBERS:
28 | save_dir = './frames/train/' + video_name.replace('.mp4', '') +'/'
29 | elif vid_id in VAL_NUMBERS:
30 | save_dir = './frames/val/' + video_name.replace('.mp4', '') +'/'
31 | elif vid_id in TEST_NUMBERS:
32 | save_dir = './frames/test/' + video_name.replace('.mp4', '') +'/'
33 | save_dir = os.path.join(ROOT_DIR, save_dir)
34 | os.makedirs(save_dir, exist_ok=True)
35 |     while success:
36 |         success, image = vidcap.read()
37 |         if success:
38 |             cv2.imwrite(save_dir + str(count) + '.jpg', image)
39 |             count += 1
40 | vidcap.release()
41 | cv2.destroyAllWindows()
42 | if vid_id in TRAIN_NUMBERS:
43 | TRAIN_FRAME_NUMBERS += count
44 | elif vid_id in VAL_NUMBERS:
45 | VAL_FRAME_NUMBERS += count
46 | elif vid_id in TEST_NUMBERS:
47 | TEST_FRAME_NUMBERS += count
48 |
49 | print('TRAIN Frames', TRAIN_FRAME_NUMBERS)
50 | print('VAL Frames', VAL_FRAME_NUMBERS)
51 | print('TEST Frames', TEST_FRAME_NUMBERS)
--------------------------------------------------------------------------------
/datasets/functional.py:
--------------------------------------------------------------------------------
1 | import numbers
2 | import cv2
3 | import numpy as np
4 | import PIL
5 | import torch
6 |
7 |
8 | def _is_tensor_clip(clip):
9 | return torch.is_tensor(clip) and clip.ndimension() == 4
10 |
11 |
12 | def crop_clip(clip, min_h, min_w, h, w):
13 | if isinstance(clip[0], np.ndarray):
14 | cropped = [img[min_h : min_h + h, min_w : min_w + w, :] for img in clip]
15 |
16 | elif isinstance(clip[0], PIL.Image.Image):
17 | cropped = [img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip]
18 | else:
19 | raise TypeError(
20 | "Expected numpy.ndarray or PIL.Image"
21 |             + " but got list of {0}".format(type(clip[0]))
22 | )
23 | return cropped
24 |
25 |
26 | def resize_clip(clip, size, interpolation="bilinear"):
27 | if isinstance(clip[0], np.ndarray):
28 | if isinstance(size, numbers.Number):
29 | im_h, im_w, im_c = clip[0].shape
30 |             # Resize by the shorter side; if it already equals the target size, return the input unchanged
31 | if (im_w <= im_h and im_w == size) or (im_h <= im_w and im_h == size):
32 | return clip
33 | new_h, new_w = get_resize_sizes(im_h, im_w, size)
34 | size = (new_w, new_h)
35 | else:
36 | size = size[0], size[1]
37 | if interpolation == "bilinear":
38 | np_inter = cv2.INTER_LINEAR
39 | else:
40 | np_inter = cv2.INTER_NEAREST
41 | scaled = [cv2.resize(img, size, interpolation=np_inter) for img in clip]
42 | elif isinstance(clip[0], PIL.Image.Image):
43 | if isinstance(size, numbers.Number):
44 | im_w, im_h = clip[0].size
45 | # Min spatial dim already matches minimal size
46 | if (im_w <= im_h and im_w == size) or (im_h <= im_w and im_h == size):
47 | return clip
48 | new_h, new_w = get_resize_sizes(im_h, im_w, size)
49 | size = (new_w, new_h)
50 | else:
51 | size = size[1], size[0]
52 | if interpolation == "bilinear":
53 | pil_inter = PIL.Image.BILINEAR
54 | else:
55 | pil_inter = PIL.Image.NEAREST
56 | scaled = [img.resize(size, pil_inter) for img in clip]
57 | else:
58 | raise TypeError(
59 | "Expected numpy.ndarray or PIL.Image"
60 |             + " but got list of {0}".format(type(clip[0]))
61 | )
62 | return scaled
63 |
64 |
65 | def get_resize_sizes(im_h, im_w, size):
66 | if im_w < im_h:
67 | ow = size
68 | oh = int(size * im_h / im_w)
69 | else:
70 | oh = size
71 | ow = int(size * im_w / im_h)
72 | return oh, ow
73 |
74 |
75 | def normalize(clip, mean, std, inplace=False):
76 | if not _is_tensor_clip(clip):
77 | raise TypeError("tensor is not a torch clip.")
78 |
79 | if not inplace:
80 | clip = clip.clone()
81 |
82 | dtype = clip.dtype
83 | mean = torch.as_tensor(mean, dtype=dtype, device=clip.device)
84 | std = torch.as_tensor(std, dtype=dtype, device=clip.device)
85 | clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
86 |
87 | return clip
88 |
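`normalize` broadcasts `mean` and `std` over the first dimension, so it expects a channel-first 4D clip tensor such as (C, T, H, W). A minimal usage sketch, assuming ImageNet statistics:

import torch

clip = torch.rand(3, 16, 224, 224)  # (C, T, H, W), values in [0, 1]
mean = [0.485, 0.456, 0.406]        # ImageNet statistics, assumed here
std = [0.229, 0.224, 0.225]
clip = normalize(clip, mean, std)   # per-channel (x - mean) / std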
--------------------------------------------------------------------------------
/datasets/phase/__pycache__/AutoLaparo_phase.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/datasets/phase/__pycache__/AutoLaparo_phase.cpython-310.pyc
--------------------------------------------------------------------------------
/datasets/phase/__pycache__/AutoLaparo_phase.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/datasets/phase/__pycache__/AutoLaparo_phase.cpython-311.pyc
--------------------------------------------------------------------------------
/datasets/phase/__pycache__/Cholec80_phase.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/datasets/phase/__pycache__/Cholec80_phase.cpython-310.pyc
--------------------------------------------------------------------------------
/datasets/phase/__pycache__/Cholec80_phase.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/datasets/phase/__pycache__/Cholec80_phase.cpython-311.pyc
--------------------------------------------------------------------------------
/datasets/phase/__pycache__/PmLR50_phase.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/datasets/phase/__pycache__/PmLR50_phase.cpython-311.pyc
--------------------------------------------------------------------------------
/datasets/tools/frame_cutmargin.py:
--------------------------------------------------------------------------------
1 | # -----------------------------
2 | # Cut black margin for surgical video
3 | # Copyright (c) CUHK 2021.
4 | # IEEE TMI 'Temporal Relation Network for Workflow Recognition from Surgical Video'
5 | # -----------------------------
6 |
7 | import cv2
8 | import os
9 | import numpy as np
10 | import multiprocessing
11 | from tqdm import tqdm
12 |
13 |
14 | def create_directory_if_not_exists(path):
15 | if not os.path.exists(path):
16 | os.makedirs(path)
17 |
18 |
19 | def filter_black(image):
20 | binary_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
21 | _, binary_image2 = cv2.threshold(binary_image, 15, 255, cv2.THRESH_BINARY)
22 | binary_image2 = cv2.medianBlur(
23 | binary_image2, 19
24 | ) # filter the noise, need to adjust the parameter based on the dataset
25 | x = binary_image2.shape[0]
26 | y = binary_image2.shape[1]
27 |
28 | edges_x = []
29 | edges_y = []
30 | for i in range(x):
31 | for j in range(10, y - 10):
32 | if binary_image2.item(i, j) != 0:
33 | edges_x.append(i)
34 | edges_y.append(j)
35 |
36 | if not edges_x:
37 | return image
38 |
39 |     left = min(edges_x)  # first non-black row (edges_x indexes rows)
40 |     right = max(edges_x)  # last non-black row
41 |     width = right - left  # row extent
42 |     bottom = min(edges_y)  # first non-black column (edges_y indexes columns)
43 |     top = max(edges_y)  # last non-black column
44 |     height = top - bottom  # column extent
45 |
46 | pre1_picture = image[left : left + width, bottom : bottom + height]
47 |
48 | return pre1_picture
49 |
50 |
51 | def process_image(image_source, image_save):
52 | frame = cv2.imread(image_source)
53 | dim = (int(frame.shape[1] / frame.shape[0] * 300), 300)
54 | frame = cv2.resize(frame, dim)
55 | frame = filter_black(frame)
56 |
57 | img_result = cv2.resize(frame, (250, 250))
58 | cv2.imwrite(image_save, img_result)
59 |
60 |
61 | def process_video(video_id, video_source, video_save):
62 | create_directory_if_not_exists(video_save)
63 |
64 | for image_id in sorted(os.listdir(video_source)):
65 | if image_id == ".DS_Store":
66 | continue
67 | image_source = os.path.join(video_source, image_id)
68 | image_save = os.path.join(video_save, image_id)
69 |
70 | process_image(image_source, image_save)
71 |
72 |
73 | if __name__ == "__main__":
74 | # source_path = "/jhcnas1/yangshu/data/cholec80/frames" # original path
75 | # save_path = "/jhcnas1/yangshu/data/cholec80/frames_cutmargin" # save path
76 |
77 | source_path = "data/cholec80/frames" # original path
78 | save_path = "data/cholec80/frames_cutmargin" # save path
79 |
80 | create_directory_if_not_exists(save_path)
81 |
82 | processes = []
83 |
84 | for data_split in os.listdir(source_path):
85 | if data_split == ".DS_Store":
86 | continue
87 | data_source = os.path.join(source_path, data_split)
88 | data_save = os.path.join(save_path, data_split)
89 |
90 | for video_id in tqdm(os.listdir(data_source)):
91 | if video_id == ".DS_Store":
92 | continue
93 | video_source = os.path.join(data_source, video_id)
94 | video_save = os.path.join(data_save, video_id)
95 |
96 | process = multiprocessing.Process(
97 | target=process_video, args=(video_id, video_source, video_save)
98 | )
99 | process.start()
100 | processes.append(process)
101 |
102 | for process in processes:
103 | process.join()
104 |
105 | print("Cut Done")
106 |
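The per-pixel scan in `filter_black` is quadratic in pure Python; the same bounds can be computed vectorized. A sketch matching the 10-pixel column margin above (not the repository's code):

import numpy as np

def content_bounds(binary_image2):
    # Rows/columns containing any bright pixel, ignoring 10 border columns.
    mask = binary_image2[:, 10:-10] != 0
    rows = np.flatnonzero(mask.any(axis=1))
    cols = np.flatnonzero(mask.any(axis=0)) + 10  # undo the column offset
    if rows.size == 0:
        return None  # all black: the caller keeps the original image
    return rows.min(), rows.max(), cols.min(), cols.max()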
--------------------------------------------------------------------------------
/datasets/tools/resize_frame.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | from tqdm import tqdm
4 | def resize_dataset(input_folder, output_folder, target_size=320):
5 |     # Make sure the output folder exists
6 | if not os.path.exists(output_folder):
7 | os.makedirs(output_folder)
8 |
9 |     # Walk all subfolders of the input folder
10 |     for subdir, dirs, files in os.walk(input_folder):
11 |         # Path relative to the input folder
12 |         rel_path = os.path.relpath(subdir, input_folder)
13 |         # Create the matching output subfolder
14 | output_subdir = os.path.join(output_folder, rel_path)
15 | if not os.path.exists(output_subdir):
16 | os.makedirs(output_subdir)
17 |         # Iterate over all files in the current subfolder
18 |         for file in tqdm(files):
19 |             # Only process image files
20 |             if file.endswith('.jpg') or file.endswith('.png'):
21 |                 # Build the input and output file paths
22 |                 input_path = os.path.join(subdir, file)
23 |                 output_path = os.path.join(output_subdir, file)
24 |                 # Open the image and resize its shorter side to target_size
25 | img = Image.open(input_path)
26 | width, height = img.size
27 | if width < height:
28 | new_width = target_size
29 | new_height = int(height * target_size / width)
30 | else:
31 | new_height = target_size
32 | new_width = int(width * target_size / height)
33 | size = (new_width, new_height)
34 | img = img.resize(size, resample=Image.BILINEAR)
35 |
36 |                 # Save the image to the output folder
37 | img.save(output_path)
38 | print(f'Processed: {subdir} -> {output_subdir}')
39 | print(len(os.listdir(subdir)), len(os.listdir(output_subdir)))
40 |
41 | if __name__ == "__main__":
42 |     # Input and output folders
43 | input_folder = '/jhcnas4/syangcw/AutoLaparo/frames'
44 | output_folder = '/jhcnas4/syangcw/AutoLaparo/frames_resized'
45 | resize_dataset(input_folder, output_folder, 250)
--------------------------------------------------------------------------------
/datasets/tools/transfer_csv_txt.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import os
3 | from collections import defaultdict
4 | import cv2
5 | csv_phase = "/Users/yangshu/Downloads/cataract-101/annotations.csv"
6 | save_txt_folder = "/Users/yangshu/Downloads/cataract-101/phase_annotations"
7 | # Read the CSV annotation file
8 | videos = defaultdict(list)
9 | with open(csv_phase, "r") as csvfile:
10 | reader = csv.reader(csvfile)
11 |     next(reader)  # skip the header row
12 | for row in reader:
13 | video_id, frame_number, phase_id = row[0].split(";")
14 |         # Group rows by video; per-frame txt annotation files are written below
15 | videos[video_id].append([video_id, frame_number, phase_id])
16 |
17 | count_1 = 0
18 | count_2 = 0
19 | for video_id in videos.keys():
20 | video_detail = videos[video_id]
21 |
22 | frames_position = [int(k[1]) for k in video_detail]
23 | phase_id = [int(k[2]) for k in video_detail]
24 | if phase_id[-1] == phase_id[-2]:
25 | count_1 += 1
26 | txt_filename = os.path.join(save_txt_folder, f"{video_id}.txt")
27 |
28 |         # Open the txt file and write per-frame annotations
29 |         with open(txt_filename, "a") as txtfile:
30 |             txtfile.write("Frame\tPhase\n") # write the header row
31 | phase = phase_id[0]
32 | for frame in range(frames_position[0], frames_position[-1]+1):
33 | if frame in frames_position:
34 | position = frames_position.index(frame)
35 | phase = phase_id[position]
36 | txtfile.write(f"{frame}\t{phase}\n")
37 | else:
38 | count_2 += 1
39 | txt_filename = os.path.join(save_txt_folder, f"{video_id}.txt")
40 |
41 | frame_folder = os.path.join("/Users/yangshu/Downloads/cataract-101/videos/", "case_" + str(video_id)+".mp4")
42 | vidcap = cv2.VideoCapture(frame_folder)
43 |         length = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))  # total frame count (property 7)
44 |         # Open the txt file and write per-frame annotations
45 |         with open(txt_filename, "a") as txtfile:
46 |             txtfile.write("Frame\tPhase\n") # write the header row
47 | phase = phase_id[0]
48 | for frame in range(frames_position[0], length):
49 | if frame in frames_position:
50 | position = frames_position.index(frame)
51 | phase = phase_id[position]
52 | txtfile.write(f"{frame}\t{phase}\n")
53 | print(count_1)
54 | print(count_2)
55 |
--------------------------------------------------------------------------------
/datasets/tools/transfer_json_txt.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 |
4 | def read_json_file(file_path):
5 | with open(file_path, 'r') as file:
6 | data = json.load(file)
7 | return data
8 |
9 | dataset_path = "/Users/yangshu/Downloads/CholecT50"
10 | labels_path = "/Users/yangshu/Downloads/CholecT50/labels"
11 | save_path = "/Users/yangshu/Downloads/CholecT50/verb"
12 |
13 | if not os.path.exists(save_path):
14 | os.makedirs(save_path)
15 |
16 | video_labels = sorted(os.listdir(labels_path))
17 |
18 | total_num_triplet = 0
19 |
20 | for video_label in video_labels:
21 | if "DS" in video_label:
22 | continue
23 | data_labels = dict()
24 | video_name = video_label.split('.')[0]
25 | txt_filename = os.path.join(save_path, video_name + ".txt")
26 |     # Read the JSON label file
27 | label_path = os.path.join(labels_path, video_label)
28 | file_path = os.path.join(save_path, video_label)
29 | data = read_json_file(label_path)
30 |     categories = data['categories']
31 |     instrument = categories['instrument']
32 |     verb = categories['verb']
33 |     target = categories['target']
34 |     triplet = categories['triplet']
35 |     phase = categories['phase']
36 |
37 | rs = dict()
38 | anns = data['annotations']
39 | sorted_anns = dict(sorted(anns.items(), key=lambda x: int(x[0])))
40 | for k, v in sorted_anns.items():
41 | r = list()
42 | for instance in v:
43 | if instance[0] == -1:
44 | continue
45 | # triplet = instance[0]
46 | # r.append(str(triplet))
47 | # instrument = instance[1]
48 | # r.append(str(instrument))
49 | verb = instance[7]
50 | r.append(str(verb))
51 | # target = instance[8]
52 | # r.append(str(target))
53 | # phase = instance[14]
54 | # r.append(str(phase))
55 |
56 | # if len(r) >1:
57 | # print(k,video_name)
58 | # if len(set(r)) != len(r):
59 | # print(r)
60 | # print('======')
61 | # rs[k] = set(r)
62 | rs[k] = r
63 | with open(txt_filename, "a") as txtfile:
64 |         txtfile.write("Frame\tVerb\n") # write the header row
65 | for frame, triplet in rs.items():
66 | if len(triplet) == 0:
67 | target = -1
68 | else:
69 | target = ",".join(triplet)
70 | txtfile.write(f"{frame}\t{target}\n")
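
Per the commented index positions above, each annotation instance is addressed positionally: index 0 is the triplet id, 1 the instrument, 7 the verb, 8 the target, and 14 the phase. The emitted txt then holds one line per frame with comma-separated verb ids, or -1 when no instance is present (values illustrative):

Frame	Verb
0	-1
1	2
2	2,5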
--------------------------------------------------------------------------------
/datasets/transforms/__pycache__/mixup.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/datasets/transforms/__pycache__/mixup.cpython-310.pyc
--------------------------------------------------------------------------------
/datasets/transforms/__pycache__/mixup.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/datasets/transforms/__pycache__/mixup.cpython-311.pyc
--------------------------------------------------------------------------------
/datasets/transforms/__pycache__/optim_factory.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/datasets/transforms/__pycache__/optim_factory.cpython-310.pyc
--------------------------------------------------------------------------------
/datasets/transforms/__pycache__/optim_factory.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/datasets/transforms/__pycache__/optim_factory.cpython-311.pyc
--------------------------------------------------------------------------------
/datasets/transforms/__pycache__/rand_augment.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/datasets/transforms/__pycache__/rand_augment.cpython-310.pyc
--------------------------------------------------------------------------------
/datasets/transforms/__pycache__/rand_augment.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/datasets/transforms/__pycache__/rand_augment.cpython-311.pyc
--------------------------------------------------------------------------------
/datasets/transforms/__pycache__/random_erasing.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/datasets/transforms/__pycache__/random_erasing.cpython-310.pyc
--------------------------------------------------------------------------------
/datasets/transforms/__pycache__/random_erasing.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/datasets/transforms/__pycache__/random_erasing.cpython-311.pyc
--------------------------------------------------------------------------------
/datasets/transforms/__pycache__/surg_transforms.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/datasets/transforms/__pycache__/surg_transforms.cpython-310.pyc
--------------------------------------------------------------------------------
/datasets/transforms/__pycache__/surg_transforms.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/datasets/transforms/__pycache__/surg_transforms.cpython-311.pyc
--------------------------------------------------------------------------------
/datasets/transforms/__pycache__/video_transforms.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/datasets/transforms/__pycache__/video_transforms.cpython-310.pyc
--------------------------------------------------------------------------------
/datasets/transforms/__pycache__/video_transforms.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/datasets/transforms/__pycache__/video_transforms.cpython-311.pyc
--------------------------------------------------------------------------------
/datasets/transforms/__pycache__/volume_transforms.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/datasets/transforms/__pycache__/volume_transforms.cpython-310.pyc
--------------------------------------------------------------------------------
/datasets/transforms/__pycache__/volume_transforms.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/datasets/transforms/__pycache__/volume_transforms.cpython-311.pyc
--------------------------------------------------------------------------------
/datasets/transforms/image_transforms.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image, ImageOps
3 | import torch
4 | import numbers
5 | import random
6 | import torchvision.transforms.functional as TF
7 |
8 |
9 | class RandomCrop(object):
10 |
11 | def __init__(self, size, padding=0):
12 | if isinstance(size, numbers.Number):
13 | self.size = (int(size), int(size))
14 | else:
15 | self.size = size
16 | self.padding = padding
17 | self.count = 0
18 |
19 | def __call__(self, img):
20 |
21 | if self.padding > 0:
22 | img = ImageOps.expand(img, border=self.padding, fill=0)
23 |
24 | w, h = img.size
25 | th, tw = self.size
26 | if w == tw and h == th:
27 | return img
28 |
29 | random.seed(self.count)
30 | x1 = random.randint(0, w - tw)
31 | y1 = random.randint(0, h - th)
32 | # print(self.count, x1, y1)
33 | self.count += 1
34 | return img.crop((x1, y1, x1 + tw, y1 + th))
35 |
36 |
37 | class RandomHorizontalFlip(object):
38 | def __init__(self):
39 | self.count = 0
40 |
41 | def __call__(self, img):
42 | seed = self.count
43 | random.seed(seed)
44 | prob = random.random()
45 | self.count += 1
46 | # print(self.count, seed, prob)
47 | if prob < 0.5:
48 | return img.transpose(Image.FLIP_LEFT_RIGHT)
49 | return img
50 |
51 |
52 | class RandomRotation(object):
53 |     def __init__(self, degrees):
54 | self.degrees = degrees
55 | self.count = 0
56 |
57 | def __call__(self, img):
58 | seed = self.count
59 | random.seed(seed)
60 | self.count += 1
61 |         angle = random.randint(-self.degrees, self.degrees)
62 | return TF.rotate(img, angle)
63 |
64 |
65 | class ColorJitter(object):
66 |     def __init__(self, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1):
67 | self.brightness = brightness
68 | self.contrast = contrast
69 | self.saturation = saturation
70 | self.hue = hue
71 | self.count = 0
72 |
73 | def __call__(self, img):
74 | seed = self.count
75 | random.seed(seed)
76 | self.count += 1
77 | brightness_factor = random.uniform(1 - self.brightness, 1 + self.brightness)
78 | contrast_factor = random.uniform(1 - self.contrast, 1 + self.contrast)
79 | saturation_factor = random.uniform(1 - self.saturation, 1 + self.saturation)
80 |         hue_factor = random.uniform(-self.hue, self.hue)
81 |
82 |         img_ = TF.adjust_brightness(img, brightness_factor)
83 |         img_ = TF.adjust_contrast(img_, contrast_factor)
84 |         img_ = TF.adjust_saturation(img_, saturation_factor)
85 |         img_ = TF.adjust_hue(img_, hue_factor)
86 |
87 | return img_
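
Each of these transforms seeds Python's `random` with an internal call counter, so the parameters drawn on the i-th call are reproduced exactly across runs; it also means consecutive frames receive independently drawn parameters, since every call advances the counter. A small sketch:

from PIL import Image

flip = RandomHorizontalFlip()
frames = [Image.new("RGB", (224, 224)) for _ in range(4)]
out = [flip(f) for f in frames]  # call i seeds random with i, so the same
                                 # flip decisions recur on every rerun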
--------------------------------------------------------------------------------
/datasets/transforms/mixup.py:
--------------------------------------------------------------------------------
1 | """ Mixup and Cutmix
2 |
3 | Papers:
4 | mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412)
5 |
6 | CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899)
7 |
8 | Code Reference:
9 | CutMix: https://github.com/clovaai/CutMix-PyTorch
10 |
11 | Hacked together by / Copyright 2019, Ross Wightman
12 | """
13 | import numpy as np
14 | import torch
15 |
16 |
17 | def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
18 | x = x.long().view(-1, 1)
19 | return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)
20 |
21 |
22 | def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'):
23 | off_value = smoothing / num_classes
24 | on_value = 1. - smoothing + off_value
25 | y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device)
26 | y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device)
27 | return y1 * lam + y2 * (1. - lam)
28 |
29 |
30 | def rand_bbox(img_shape, lam, margin=0., count=None):
31 | """ Standard CutMix bounding-box
32 | Generates a random square bbox based on lambda value. This impl includes
33 | support for enforcing a border margin as percent of bbox dimensions.
34 |
35 | Args:
36 | img_shape (tuple): Image shape as tuple
37 | lam (float): Cutmix lambda value
38 | margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image)
39 | count (int): Number of bbox to generate
40 | """
41 | ratio = np.sqrt(1 - lam)
42 | img_h, img_w = img_shape[-2:]
43 | cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
44 | margin_y, margin_x = int(margin * cut_h), int(margin * cut_w)
45 | cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count)
46 | cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count)
47 | yl = np.clip(cy - cut_h // 2, 0, img_h)
48 | yh = np.clip(cy + cut_h // 2, 0, img_h)
49 | xl = np.clip(cx - cut_w // 2, 0, img_w)
50 | xh = np.clip(cx + cut_w // 2, 0, img_w)
51 | return yl, yh, xl, xh
52 |
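For intuition: with lam = 0.75 on a 224x224 image, ratio = sqrt(1 - 0.75) = 0.5, so the cut region is 112x112, i.e. one quarter of the area, which matches 1 - lam.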
53 |
54 | def rand_bbox_minmax(img_shape, minmax, count=None):
55 | """ Min-Max CutMix bounding-box
56 | Inspired by Darknet cutmix impl, generates a random rectangular bbox
57 | based on min/max percent values applied to each dimension of the input image.
58 |
59 | Typical defaults for minmax are usually in the .2-.3 for min and .8-.9 range for max.
60 |
61 | Args:
62 | img_shape (tuple): Image shape as tuple
63 | minmax (tuple or list): Min and max bbox ratios (as percent of image size)
64 | count (int): Number of bbox to generate
65 | """
66 | assert len(minmax) == 2
67 | img_h, img_w = img_shape[-2:]
68 | cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count)
69 | cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count)
70 | yl = np.random.randint(0, img_h - cut_h, size=count)
71 | xl = np.random.randint(0, img_w - cut_w, size=count)
72 | yu = yl + cut_h
73 | xu = xl + cut_w
74 | return yl, yu, xl, xu
75 |
76 |
77 | def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None):
78 | """ Generate bbox and apply lambda correction.
79 | """
80 | if ratio_minmax is not None:
81 | yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count)
82 | else:
83 | yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count)
84 | if correct_lam or ratio_minmax is not None:
85 | bbox_area = (yu - yl) * (xu - xl)
86 | lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1])
87 | return (yl, yu, xl, xu), lam
88 |
89 |
90 | class Mixup:
91 | """ Mixup/Cutmix that applies different params to each element or whole batch
92 |
93 | Args:
94 | mixup_alpha (float): mixup alpha value, mixup is active if > 0.
95 | cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
96 | cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.
97 | prob (float): probability of applying mixup or cutmix per batch or element
98 | switch_prob (float): probability of switching to cutmix instead of mixup when both are active
99 | mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), 'elem' (element)
100 | correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders
101 | label_smoothing (float): apply label smoothing to the mixed target tensor
102 | num_classes (int): number of classes for target
103 | """
104 | def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
105 | mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000):
106 | self.mixup_alpha = mixup_alpha
107 | self.cutmix_alpha = cutmix_alpha
108 | self.cutmix_minmax = cutmix_minmax
109 | if self.cutmix_minmax is not None:
110 | assert len(self.cutmix_minmax) == 2
111 | # force cutmix alpha == 1.0 when minmax active to keep logic simple & safe
112 | self.cutmix_alpha = 1.0
113 | self.mix_prob = prob
114 | self.switch_prob = switch_prob
115 | self.label_smoothing = label_smoothing
116 | self.num_classes = num_classes
117 | self.mode = mode
118 | self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix
119 |         self.mixup_enabled = True # set to false to disable mixing (intended to be set by train loop)
120 |
121 | def _params_per_elem(self, batch_size):
122 | lam = np.ones(batch_size, dtype=np.float32)
123 |         use_cutmix = np.zeros(batch_size, dtype=bool)  # np.bool alias was removed in NumPy 1.24
124 | if self.mixup_enabled:
125 | if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
126 | use_cutmix = np.random.rand(batch_size) < self.switch_prob
127 | lam_mix = np.where(
128 | use_cutmix,
129 | np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size),
130 | np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size))
131 | elif self.mixup_alpha > 0.:
132 | lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size)
133 | elif self.cutmix_alpha > 0.:
134 |                 use_cutmix = np.ones(batch_size, dtype=bool)
135 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size)
136 | else:
137 | assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
138 | lam = np.where(np.random.rand(batch_size) < self.mix_prob, lam_mix.astype(np.float32), lam)
139 | return lam, use_cutmix
140 |
141 | def _params_per_batch(self):
142 | lam = 1.
143 | use_cutmix = False
144 | if self.mixup_enabled and np.random.rand() < self.mix_prob:
145 | if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
146 | use_cutmix = np.random.rand() < self.switch_prob
147 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \
148 | np.random.beta(self.mixup_alpha, self.mixup_alpha)
149 | elif self.mixup_alpha > 0.:
150 | lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha)
151 | elif self.cutmix_alpha > 0.:
152 | use_cutmix = True
153 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
154 | else:
155 | assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
156 | lam = float(lam_mix)
157 | return lam, use_cutmix
158 |
159 | def _mix_elem(self, x):
160 | batch_size = len(x)
161 | lam_batch, use_cutmix = self._params_per_elem(batch_size)
162 | x_orig = x.clone() # need to keep an unmodified original for mixing source
163 | for i in range(batch_size):
164 | j = batch_size - i - 1
165 | lam = lam_batch[i]
166 | if lam != 1.:
167 | if use_cutmix[i]:
168 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
169 | x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
170 | x[i][..., yl:yh, xl:xh] = x_orig[j][..., yl:yh, xl:xh]
171 | lam_batch[i] = lam
172 | else:
173 | x[i] = x[i] * lam + x_orig[j] * (1 - lam)
174 | return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
175 |
176 | def _mix_pair(self, x):
177 | batch_size = len(x)
178 | lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
179 | x_orig = x.clone() # need to keep an unmodified original for mixing source
180 | for i in range(batch_size // 2):
181 | j = batch_size - i - 1
182 | lam = lam_batch[i]
183 | if lam != 1.:
184 | if use_cutmix[i]:
185 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
186 | x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
187 | x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
188 | x[j][:, yl:yh, xl:xh] = x_orig[i][:, yl:yh, xl:xh]
189 | lam_batch[i] = lam
190 | else:
191 | x[i] = x[i] * lam + x_orig[j] * (1 - lam)
192 | x[j] = x[j] * lam + x_orig[i] * (1 - lam)
193 | lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
194 | return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
195 |
196 | def _mix_batch(self, x):
197 | lam, use_cutmix = self._params_per_batch()
198 | if lam == 1.:
199 | return 1.
200 | if use_cutmix:
201 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
202 | x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
203 | x[..., yl:yh, xl:xh] = x.flip(0)[..., yl:yh, xl:xh]
204 | else:
205 | x_flipped = x.flip(0).mul_(1. - lam)
206 | x.mul_(lam).add_(x_flipped)
207 | return lam
208 |
209 | def __call__(self, x, target):
210 | assert len(x) % 2 == 0, 'Batch size should be even when using this'
211 | if self.mode == 'elem':
212 | lam = self._mix_elem(x)
213 | elif self.mode == 'pair':
214 | lam = self._mix_pair(x)
215 | else:
216 | lam = self._mix_batch(x)
217 | target = mixup_target(target, self.num_classes, lam, self.label_smoothing, x.device)
218 | return x, target
219 |
220 |
221 | class FastCollateMixup(Mixup):
222 | """ Fast Collate w/ Mixup/Cutmix that applies different params to each element or whole batch
223 |
224 | A Mixup impl that's performed while collating the batches.
225 | """
226 |
227 | def _mix_elem_collate(self, output, batch, half=False):
228 | batch_size = len(batch)
229 | num_elem = batch_size // 2 if half else batch_size
230 | assert len(output) == num_elem
231 | lam_batch, use_cutmix = self._params_per_elem(num_elem)
232 | for i in range(num_elem):
233 | j = batch_size - i - 1
234 | lam = lam_batch[i]
235 | mixed = batch[i][0]
236 | if lam != 1.:
237 | if use_cutmix[i]:
238 | if not half:
239 | mixed = mixed.copy()
240 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
241 | output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
242 | mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
243 | lam_batch[i] = lam
244 | else:
245 | mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
246 | np.rint(mixed, out=mixed)
247 | output[i] += torch.from_numpy(mixed.astype(np.uint8))
248 | if half:
249 | lam_batch = np.concatenate((lam_batch, np.ones(num_elem)))
250 | return torch.tensor(lam_batch).unsqueeze(1)
251 |
252 | def _mix_pair_collate(self, output, batch):
253 | batch_size = len(batch)
254 | lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
255 | for i in range(batch_size // 2):
256 | j = batch_size - i - 1
257 | lam = lam_batch[i]
258 | mixed_i = batch[i][0]
259 | mixed_j = batch[j][0]
260 | assert 0 <= lam <= 1.0
261 | if lam < 1.:
262 | if use_cutmix[i]:
263 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
264 | output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
265 | patch_i = mixed_i[:, yl:yh, xl:xh].copy()
266 | mixed_i[:, yl:yh, xl:xh] = mixed_j[:, yl:yh, xl:xh]
267 | mixed_j[:, yl:yh, xl:xh] = patch_i
268 | lam_batch[i] = lam
269 | else:
270 | mixed_temp = mixed_i.astype(np.float32) * lam + mixed_j.astype(np.float32) * (1 - lam)
271 | mixed_j = mixed_j.astype(np.float32) * lam + mixed_i.astype(np.float32) * (1 - lam)
272 | mixed_i = mixed_temp
273 | np.rint(mixed_j, out=mixed_j)
274 | np.rint(mixed_i, out=mixed_i)
275 | output[i] += torch.from_numpy(mixed_i.astype(np.uint8))
276 | output[j] += torch.from_numpy(mixed_j.astype(np.uint8))
277 | lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
278 | return torch.tensor(lam_batch).unsqueeze(1)
279 |
280 | def _mix_batch_collate(self, output, batch):
281 | batch_size = len(batch)
282 | lam, use_cutmix = self._params_per_batch()
283 | if use_cutmix:
284 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
285 | output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
286 | for i in range(batch_size):
287 | j = batch_size - i - 1
288 | mixed = batch[i][0]
289 | if lam != 1.:
290 | if use_cutmix:
291 | mixed = mixed.copy() # don't want to modify the original while iterating
292 | mixed[..., yl:yh, xl:xh] = batch[j][0][..., yl:yh, xl:xh]
293 | else:
294 | mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
295 | np.rint(mixed, out=mixed)
296 | output[i] += torch.from_numpy(mixed.astype(np.uint8))
297 | return lam
298 |
299 | def __call__(self, batch, _=None):
300 | batch_size = len(batch)
301 | assert batch_size % 2 == 0, 'Batch size should be even when using this'
302 | half = 'half' in self.mode
303 | if half:
304 | batch_size //= 2
305 | output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
306 | if self.mode == 'elem' or self.mode == 'half':
307 | lam = self._mix_elem_collate(output, batch, half=half)
308 | elif self.mode == 'pair':
309 | lam = self._mix_pair_collate(output, batch)
310 | else:
311 | lam = self._mix_batch_collate(output, batch)
312 | target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
313 | target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
314 | target = target[:batch_size]
315 | return output, target
316 |
317 |
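A minimal usage sketch of `Mixup` on a clip batch (shapes and the 7-class phase setting are illustrative assumptions; since the bbox helpers only use the last two dimensions, 5D video tensors work):

import torch

mixup_fn = Mixup(mixup_alpha=0.8, cutmix_alpha=1.0, mode='batch',
                 label_smoothing=0.1, num_classes=7)

clips = torch.randn(4, 3, 16, 224, 224)         # (B, C, T, H, W); B must be even
targets = torch.randint(0, 7, (4,))
clips, soft_targets = mixup_fn(clips, targets)  # soft_targets: (4, 7) mixed one-hot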
--------------------------------------------------------------------------------
/datasets/transforms/optim_factory.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import optim as optim
3 |
4 | from timm.optim.adafactor import Adafactor
5 | from timm.optim.adahessian import Adahessian
6 | from timm.optim.adamp import AdamP
7 | from timm.optim.lookahead import Lookahead
8 | from timm.optim.nadam import Nadam
9 | from timm.optim.novograd import NovoGrad
10 | from timm.optim.nvnovograd import NvNovoGrad
11 | from timm.optim.radam import RAdam
12 | from timm.optim.rmsprop_tf import RMSpropTF
13 | from timm.optim.sgdp import SGDP
14 |
15 | import json
16 |
17 | try:
18 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
19 | has_apex = True
20 | except ImportError:
21 | has_apex = False
22 |
23 |
24 | def get_num_layer_for_vit(var_name, num_max_layer):
25 | if var_name in ("cls_token", "mask_token", "pos_embed"):
26 | return 0
27 | elif var_name.startswith("patch_embed"):
28 | return 0
29 | elif var_name.startswith("temporal_embedding"):
30 | return 0
31 | elif var_name.startswith("time_embed"):
32 | return 0
33 | elif var_name in ("class_embedding", "positional_embedding", "temporal_positional_embedding"):
34 | return 0
35 | elif var_name.startswith("conv1"):
36 | return 0
37 | elif var_name.startswith("rel_pos_bias"):
38 | return num_max_layer - 1
39 | elif var_name.startswith("blocks"):
40 | layer_id = int(var_name.split('.')[1])
41 | return layer_id + 1
42 | elif var_name.startswith("transformer.resblocks"):
43 | layer_id = int(var_name.split('.')[2])
44 | return layer_id + 1
45 | else:
46 | return num_max_layer - 1
47 |
48 |
49 | class LayerDecayValueAssigner(object):
50 | def __init__(self, values):
51 | self.values = values
52 |
53 | def get_scale(self, layer_id):
54 | return self.values[layer_id]
55 |
56 | def get_layer_id(self, var_name):
57 | return get_num_layer_for_vit(var_name, len(self.values))
58 |
59 |
60 | def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None):
61 | parameter_group_names = {}
62 | parameter_group_vars = {}
63 |
64 | for name, param in model.named_parameters():
65 | if not param.requires_grad:
66 | continue # frozen weights
67 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
68 | group_name = "no_decay"
69 | this_weight_decay = 0.
70 | else:
71 | group_name = "decay"
72 | this_weight_decay = weight_decay
73 | if get_num_layer is not None:
74 | layer_id = get_num_layer(name)
75 | group_name = "layer_%d_%s" % (layer_id, group_name)
76 | else:
77 | layer_id = None
78 |
79 | if group_name not in parameter_group_names:
80 | if get_layer_scale is not None:
81 | scale = get_layer_scale(layer_id)
82 | else:
83 | scale = 1.
84 |
85 | parameter_group_names[group_name] = {
86 | "weight_decay": this_weight_decay,
87 | "params": [],
88 | "lr_scale": scale
89 | }
90 | parameter_group_vars[group_name] = {
91 | "weight_decay": this_weight_decay,
92 | "params": [],
93 | "lr_scale": scale
94 | }
95 |
96 | parameter_group_vars[group_name]["params"].append(param)
97 | parameter_group_names[group_name]["params"].append(name)
98 | print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
99 | return list(parameter_group_vars.values())
100 |
101 |
102 | def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None):
103 | opt_lower = args.opt.lower()
104 | weight_decay = args.weight_decay
105 | if weight_decay and filter_bias_and_bn:
106 | skip = {}
107 | if skip_list is not None:
108 | skip = skip_list
109 | elif hasattr(model, 'no_weight_decay'):
110 | skip = model.no_weight_decay()
111 | parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale)
112 | weight_decay = 0.
113 | else:
114 | parameters = model.parameters()
115 |
116 | if 'fused' in opt_lower:
117 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
118 |
119 | opt_args = dict(lr=args.lr, weight_decay=weight_decay)
120 | if hasattr(args, 'opt_eps') and args.opt_eps is not None:
121 | opt_args['eps'] = args.opt_eps
122 | if hasattr(args, 'opt_betas') and args.opt_betas is not None:
123 | opt_args['betas'] = args.opt_betas
124 |
125 | print("optimizer settings:", opt_args)
126 |
127 | opt_split = opt_lower.split('_')
128 | opt_lower = opt_split[-1]
129 | if opt_lower == 'sgd' or opt_lower == 'nesterov':
130 | opt_args.pop('eps', None)
131 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
132 | elif opt_lower == 'momentum':
133 | opt_args.pop('eps', None)
134 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
135 | elif opt_lower == 'adam':
136 | optimizer = optim.Adam(parameters, **opt_args)
137 | elif opt_lower == 'adamw':
138 | optimizer = optim.AdamW(parameters, **opt_args)
139 | elif opt_lower == 'nadam':
140 | optimizer = Nadam(parameters, **opt_args)
141 | elif opt_lower == 'radam':
142 | optimizer = RAdam(parameters, **opt_args)
143 | elif opt_lower == 'adamp':
144 | optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
145 | elif opt_lower == 'sgdp':
146 | optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args)
147 | elif opt_lower == 'adadelta':
148 | optimizer = optim.Adadelta(parameters, **opt_args)
149 | elif opt_lower == 'adafactor':
150 | if not args.lr:
151 | opt_args['lr'] = None
152 | optimizer = Adafactor(parameters, **opt_args)
153 | elif opt_lower == 'adahessian':
154 | optimizer = Adahessian(parameters, **opt_args)
155 | elif opt_lower == 'rmsprop':
156 | optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
157 | elif opt_lower == 'rmsproptf':
158 | optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
159 | elif opt_lower == 'novograd':
160 | optimizer = NovoGrad(parameters, **opt_args)
161 | elif opt_lower == 'nvnovograd':
162 | optimizer = NvNovoGrad(parameters, **opt_args)
163 | elif opt_lower == 'fusedsgd':
164 | opt_args.pop('eps', None)
165 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
166 | elif opt_lower == 'fusedmomentum':
167 | opt_args.pop('eps', None)
168 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
169 | elif opt_lower == 'fusedadam':
170 | optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
171 | elif opt_lower == 'fusedadamw':
172 | optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
173 | elif opt_lower == 'fusedlamb':
174 | optimizer = FusedLAMB(parameters, **opt_args)
175 | elif opt_lower == 'fusednovograd':
176 | opt_args.setdefault('betas', (0.95, 0.98))
177 | optimizer = FusedNovoGrad(parameters, **opt_args)
178 | else:
179 |         # `assert False and "Invalid optimizer"` always failed with an empty message; raise directly
180 |         raise ValueError("Invalid optimizer: %s" % opt_lower)
181 |
182 | if len(opt_split) > 1:
183 | if opt_split[0] == 'lookahead':
184 | optimizer = Lookahead(optimizer)
185 |
186 | return optimizer
187 |
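A usage sketch for `create_optimizer` with layer-wise lr decay (the args object, its values, and the stand-in model are illustrative; field names mirror what the function reads):

from types import SimpleNamespace
import torch.nn as nn

args = SimpleNamespace(opt='adamw', lr=5e-4, weight_decay=0.05,
                       opt_eps=1e-8, opt_betas=(0.9, 0.999), momentum=0.9)

num_layers = 12
assigner = LayerDecayValueAssigner(
    [0.75 ** (num_layers + 1 - i) for i in range(num_layers + 2)])

model = nn.Linear(10, 10)  # stand-in for the real network
optimizer = create_optimizer(args, model,
                             get_num_layer=assigner.get_layer_id,
                             get_layer_scale=assigner.get_scale)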
--------------------------------------------------------------------------------
/datasets/transforms/rand_augment.py:
--------------------------------------------------------------------------------
1 | """
2 | This implementation is based on
3 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/auto_augment.py
4 | published under an Apache License 2.0.
5 |
6 | COMMENT FROM ORIGINAL:
7 | AutoAugment, RandAugment, and AugMix for PyTorch
8 | This code implements the searched ImageNet policies with various tweaks and
9 | improvements and does not include any of the search code. AA and RA
10 | Implementation adapted from:
11 | https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
12 | AugMix adapted from:
13 | https://github.com/google-research/augmix
14 | Papers:
15 | AutoAugment: Learning Augmentation Policies from Data
16 | https://arxiv.org/abs/1805.09501
17 | Learning Data Augmentation Strategies for Object Detection
18 | https://arxiv.org/abs/1906.11172
19 | RandAugment: Practical automated data augmentation...
20 | https://arxiv.org/abs/1909.13719
21 | AugMix: A Simple Data Processing Method to Improve Robustness and
22 | Uncertainty https://arxiv.org/abs/1912.02781
23 |
24 | Hacked together by / Copyright 2020 Ross Wightman
25 | """
26 |
27 | import math
28 | import numpy as np
29 | import random
30 | import re
31 | import PIL
32 | from PIL import Image, ImageEnhance, ImageOps
33 |
34 | _PIL_VER = tuple([int(x) for x in PIL.__version__.split(".")[:2]])
35 |
36 | _FILL = (128, 128, 128)  # gray fill
37 | # _FILL = (0, 0, 0)  # black fill, matching the black borders of a typical endoscopic view
38 |
39 | # This signifies the max integer that the controller RNN could predict for the
40 | # augmentation scheme.
41 | _MAX_LEVEL = 10.0
42 |
43 | _HPARAMS_DEFAULT = {
44 | "translate_const": 250,
45 | "img_mean": _FILL,
46 | }
47 |
48 | _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
49 |
50 |
51 | def _interpolation(kwargs):
52 | interpolation = kwargs.pop("resample", Image.BILINEAR)
53 | if isinstance(interpolation, (list, tuple)):
54 | return random.choice(interpolation)
55 | else:
56 | return interpolation
57 |
58 |
59 | def _check_args_tf(kwargs):
60 | if "fillcolor" in kwargs and _PIL_VER < (5, 0):
61 | kwargs.pop("fillcolor")
62 | kwargs["resample"] = _interpolation(kwargs)
63 |
64 |
65 | def shear_x(img, factor, **kwargs):
66 | _check_args_tf(kwargs)
67 | return img.transform(
68 | img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs
69 | )
70 |
71 |
72 | def shear_y(img, factor, **kwargs):
73 | _check_args_tf(kwargs)
74 | return img.transform(
75 | img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs
76 | )
77 |
78 |
79 | def translate_x_rel(img, pct, **kwargs):
80 | pixels = pct * img.size[0]
81 | _check_args_tf(kwargs)
82 | return img.transform(
83 | img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs
84 | )
85 |
86 |
87 | def translate_y_rel(img, pct, **kwargs):
88 | pixels = pct * img.size[1]
89 | _check_args_tf(kwargs)
90 | return img.transform(
91 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs
92 | )
93 |
94 |
95 | def translate_x_abs(img, pixels, **kwargs):
96 | _check_args_tf(kwargs)
97 | return img.transform(
98 | img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs
99 | )
100 |
101 |
102 | def translate_y_abs(img, pixels, **kwargs):
103 | _check_args_tf(kwargs)
104 | return img.transform(
105 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs
106 | )
107 |
108 |
109 | def rotate(img, degrees, **kwargs):
110 | _check_args_tf(kwargs)
111 | if _PIL_VER >= (5, 2):
112 | return img.rotate(degrees, **kwargs)
113 | elif _PIL_VER >= (5, 0):
114 | w, h = img.size
115 | post_trans = (0, 0)
116 | rotn_center = (w / 2.0, h / 2.0)
117 | angle = -math.radians(degrees)
118 | matrix = [
119 | round(math.cos(angle), 15),
120 | round(math.sin(angle), 15),
121 | 0.0,
122 | round(-math.sin(angle), 15),
123 | round(math.cos(angle), 15),
124 | 0.0,
125 | ]
126 |
127 | def transform(x, y, matrix):
128 | (a, b, c, d, e, f) = matrix
129 | return a * x + b * y + c, d * x + e * y + f
130 |
131 | matrix[2], matrix[5] = transform(
132 | -rotn_center[0] - post_trans[0],
133 | -rotn_center[1] - post_trans[1],
134 | matrix,
135 | )
136 | matrix[2] += rotn_center[0]
137 | matrix[5] += rotn_center[1]
138 | return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
139 | else:
140 | return img.rotate(degrees, resample=kwargs["resample"])
141 |
142 |
143 | def auto_contrast(img, **__):
144 | return ImageOps.autocontrast(img)
145 |
146 |
147 | def invert(img, **__):
148 | return ImageOps.invert(img)
149 |
150 |
151 | def equalize(img, **__):
152 | return ImageOps.equalize(img)
153 |
154 |
155 | def solarize(img, thresh, **__):
156 | return ImageOps.solarize(img, thresh)
157 |
158 |
159 | def solarize_add(img, add, thresh=128, **__):
160 | lut = []
161 | for i in range(256):
162 | if i < thresh:
163 | lut.append(min(255, i + add))
164 | else:
165 | lut.append(i)
166 | if img.mode in ("L", "RGB"):
167 | if img.mode == "RGB" and len(lut) == 256:
168 | lut = lut + lut + lut
169 | return img.point(lut)
170 | else:
171 | return img
172 |
173 |
174 | def posterize(img, bits_to_keep, **__):
175 | if bits_to_keep >= 8:
176 | return img
177 | return ImageOps.posterize(img, bits_to_keep)
178 |
179 |
180 | def contrast(img, factor, **__):
181 | return ImageEnhance.Contrast(img).enhance(factor)
182 |
183 |
184 | def color(img, factor, **__):
185 | return ImageEnhance.Color(img).enhance(factor)
186 |
187 |
188 | def brightness(img, factor, **__):
189 | return ImageEnhance.Brightness(img).enhance(factor)
190 |
191 |
192 | def sharpness(img, factor, **__):
193 | return ImageEnhance.Sharpness(img).enhance(factor)
194 |
195 |
196 | def _randomly_negate(v):
197 | """With 50% prob, negate the value"""
198 | return -v if random.random() > 0.5 else v
199 |
200 |
201 | def _rotate_level_to_arg(level, _hparams):
202 | # range [-30, 30]
203 | level = (level / _MAX_LEVEL) * 30.0
204 | level = _randomly_negate(level)
205 | return (level,)
206 |
207 |
208 | def _enhance_level_to_arg(level, _hparams):
209 | # range [0.1, 1.9]
210 | return ((level / _MAX_LEVEL) * 1.8 + 0.1,)
211 |
212 |
213 | def _enhance_increasing_level_to_arg(level, _hparams):
214 | # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend
215 | # range [0.1, 1.9]
216 | level = (level / _MAX_LEVEL) * 0.9
217 | level = 1.0 + _randomly_negate(level)
218 | return (level,)
219 |
220 |
221 | def _shear_level_to_arg(level, _hparams):
222 | # range [-0.3, 0.3]
223 | level = (level / _MAX_LEVEL) * 0.3
224 | level = _randomly_negate(level)
225 | return (level,)
226 |
227 |
228 | def _translate_abs_level_to_arg(level, hparams):
229 | translate_const = hparams["translate_const"]
230 | level = (level / _MAX_LEVEL) * float(translate_const)
231 | level = _randomly_negate(level)
232 | return (level,)
233 |
234 |
235 | def _translate_rel_level_to_arg(level, hparams):
236 | # default range [-0.45, 0.45]
237 | translate_pct = hparams.get("translate_pct", 0.45)
238 | level = (level / _MAX_LEVEL) * translate_pct
239 | level = _randomly_negate(level)
240 | return (level,)
241 |
242 |
243 | def _posterize_level_to_arg(level, _hparams):
244 | # As per Tensorflow TPU EfficientNet impl
245 | # range [0, 4], 'keep 0 up to 4 MSB of original image'
246 | # intensity/severity of augmentation decreases with level
247 | return (int((level / _MAX_LEVEL) * 4),)
248 |
249 |
250 | def _posterize_increasing_level_to_arg(level, hparams):
251 | # As per Tensorflow models research and UDA impl
252 | # range [4, 0], 'keep 4 down to 0 MSB of original image',
253 | # intensity/severity of augmentation increases with level
254 | return (4 - _posterize_level_to_arg(level, hparams)[0],)
255 |
256 |
257 | def _posterize_original_level_to_arg(level, _hparams):
258 | # As per original AutoAugment paper description
259 | # range [4, 8], 'keep 4 up to 8 MSB of image'
260 | # intensity/severity of augmentation decreases with level
261 | return (int((level / _MAX_LEVEL) * 4) + 4,)
262 |
263 |
264 | def _solarize_level_to_arg(level, _hparams):
265 | # range [0, 256]
266 | # intensity/severity of augmentation decreases with level
267 | return (int((level / _MAX_LEVEL) * 256),)
268 |
269 |
270 | def _solarize_increasing_level_to_arg(level, _hparams):
271 | # range [0, 256]
272 | # intensity/severity of augmentation increases with level
273 | return (256 - _solarize_level_to_arg(level, _hparams)[0],)
274 |
275 |
276 | def _solarize_add_level_to_arg(level, _hparams):
277 | # range [0, 110]
278 | return (int((level / _MAX_LEVEL) * 110),)
279 |
280 |
281 | LEVEL_TO_ARG = {
282 | "AutoContrast": None,
283 | "Equalize": None,
284 | "Invert": None,
285 | "Rotate": _rotate_level_to_arg,
286 | # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers
287 | "Posterize": _posterize_level_to_arg,
288 | "PosterizeIncreasing": _posterize_increasing_level_to_arg,
289 | "PosterizeOriginal": _posterize_original_level_to_arg,
290 | "Solarize": _solarize_level_to_arg,
291 | "SolarizeIncreasing": _solarize_increasing_level_to_arg,
292 | "SolarizeAdd": _solarize_add_level_to_arg,
293 | "Color": _enhance_level_to_arg,
294 | "ColorIncreasing": _enhance_increasing_level_to_arg,
295 | "Contrast": _enhance_level_to_arg,
296 | "ContrastIncreasing": _enhance_increasing_level_to_arg,
297 | "Brightness": _enhance_level_to_arg,
298 | "BrightnessIncreasing": _enhance_increasing_level_to_arg,
299 | "Sharpness": _enhance_level_to_arg,
300 | "SharpnessIncreasing": _enhance_increasing_level_to_arg,
301 | "ShearX": _shear_level_to_arg,
302 | "ShearY": _shear_level_to_arg,
303 | "TranslateX": _translate_abs_level_to_arg,
304 | "TranslateY": _translate_abs_level_to_arg,
305 | "TranslateXRel": _translate_rel_level_to_arg,
306 | "TranslateYRel": _translate_rel_level_to_arg,
307 | }
308 |
309 |
310 | NAME_TO_OP = {
311 | "AutoContrast": auto_contrast,
312 | "Equalize": equalize,
313 | "Invert": invert,
314 | "Rotate": rotate,
315 | "Posterize": posterize,
316 | "PosterizeIncreasing": posterize,
317 | "PosterizeOriginal": posterize,
318 | "Solarize": solarize,
319 | "SolarizeIncreasing": solarize,
320 | "SolarizeAdd": solarize_add,
321 | "Color": color,
322 | "ColorIncreasing": color,
323 | "Contrast": contrast,
324 | "ContrastIncreasing": contrast,
325 | "Brightness": brightness,
326 | "BrightnessIncreasing": brightness,
327 | "Sharpness": sharpness,
328 | "SharpnessIncreasing": sharpness,
329 | "ShearX": shear_x,
330 | "ShearY": shear_y,
331 | "TranslateX": translate_x_abs,
332 | "TranslateY": translate_y_abs,
333 | "TranslateXRel": translate_x_rel,
334 | "TranslateYRel": translate_y_rel,
335 | }
336 |
337 |
338 | class AugmentOp:
339 | """
340 | Apply for video.
341 | """
342 |
343 | def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
344 | hparams = hparams or _HPARAMS_DEFAULT
345 | self.aug_fn = NAME_TO_OP[name]
346 | self.level_fn = LEVEL_TO_ARG[name]
347 | self.prob = prob
348 | self.magnitude = magnitude
349 | self.hparams = hparams.copy()
350 | self.kwargs = {
351 | "fillcolor": hparams["img_mean"]
352 | if "img_mean" in hparams
353 | else _FILL,
354 | "resample": hparams["interpolation"]
355 | if "interpolation" in hparams
356 | else _RANDOM_INTERPOLATION,
357 | }
358 |
359 | # If magnitude_std is > 0, we introduce some randomness
360 | # in the usually fixed policy and sample magnitude from a normal distribution
361 | # with mean `magnitude` and std-dev of `magnitude_std`.
362 | # NOTE This is my own hack, being tested, not in papers or reference impls.
363 | self.magnitude_std = self.hparams.get("magnitude_std", 0)
364 |
365 | def __call__(self, img_list):
366 | if self.prob < 1.0 and random.random() > self.prob:
367 | return img_list
368 | magnitude = self.magnitude
369 | if self.magnitude_std and self.magnitude_std > 0:
370 | magnitude = random.gauss(magnitude, self.magnitude_std)
371 | magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range
372 | level_args = (
373 | self.level_fn(magnitude, self.hparams)
374 | if self.level_fn is not None
375 | else ()
376 | )
377 |
378 | if isinstance(img_list, list):
379 | return [
380 | self.aug_fn(img, *level_args, **self.kwargs) for img in img_list
381 | ]
382 | else:
383 | return self.aug_fn(img_list, *level_args, **self.kwargs)
384 |
385 |
386 | _RAND_TRANSFORMS = [
387 | "AutoContrast",
388 | "Equalize",
389 | "Invert",
390 | "Rotate",
391 | "Posterize",
392 | "Solarize",
393 | "SolarizeAdd",
394 | "Color",
395 | "Contrast",
396 | "Brightness",
397 | "Sharpness",
398 | "ShearX",
399 | "ShearY",
400 | "TranslateXRel",
401 | "TranslateYRel",
402 | ]
403 |
404 |
405 | _RAND_INCREASING_TRANSFORMS = [
406 | "AutoContrast",
407 | "Equalize",
408 | "Invert",
409 | "Rotate",
410 | "PosterizeIncreasing",
411 | "SolarizeIncreasing",
412 | "SolarizeAdd",
413 | "ColorIncreasing",
414 | "ContrastIncreasing",
415 | "BrightnessIncreasing",
416 | "SharpnessIncreasing",
417 | "ShearX",
418 | "ShearY",
419 | "TranslateXRel",
420 | "TranslateYRel",
421 | ]
422 |
423 |
424 | # These experimental weights are based loosely on the relative improvements mentioned in the paper.
425 | # They may not result in increased performance, but they could likely be tuned to do so.
426 | _RAND_CHOICE_WEIGHTS_0 = {
427 | "Rotate": 0.3,
428 | "ShearX": 0.2,
429 | "ShearY": 0.2,
430 | "TranslateXRel": 0.1,
431 | "TranslateYRel": 0.1,
432 | "Color": 0.025,
433 | "Sharpness": 0.025,
434 | "AutoContrast": 0.025,
435 | "Solarize": 0.005,
436 | "SolarizeAdd": 0.005,
437 | "Contrast": 0.005,
438 | "Brightness": 0.005,
439 | "Equalize": 0.005,
440 | "Posterize": 0,
441 | "Invert": 0,
442 | }
443 |
444 |
445 | def _select_rand_weights(weight_idx=0, transforms=None):
446 | transforms = transforms or _RAND_TRANSFORMS
447 | assert weight_idx == 0 # only one set of weights currently
448 | rand_weights = _RAND_CHOICE_WEIGHTS_0
449 | probs = [rand_weights[k] for k in transforms]
450 | probs /= np.sum(probs)
451 | return probs
452 |
453 |
454 | def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
455 | hparams = hparams or _HPARAMS_DEFAULT
456 | transforms = transforms or _RAND_TRANSFORMS
457 | return [
458 | AugmentOp(name, prob=0.5, magnitude=magnitude, hparams=hparams)
459 | for name in transforms
460 | ]
461 |
462 |
463 | class RandAugment:
464 | def __init__(self, ops, num_layers=2, choice_weights=None):
465 | self.ops = ops
466 | self.num_layers = num_layers
467 | self.choice_weights = choice_weights
468 |
469 | def __call__(self, img):
470 | # no replacement when using weighted choice
471 | ops = np.random.choice(
472 | self.ops,
473 | self.num_layers,
474 | replace=self.choice_weights is None,
475 | p=self.choice_weights,
476 | )
477 | for op in ops:
478 | img = op(img)
479 | return img
480 |
481 |
482 | def rand_augment_transform(config_str, hparams):
483 | """
484 | RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719
485 |
486 | Create a RandAugment transform
487 | :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
488 | dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
489 |     sections, which are not order-specific, determine
490 | 'm' - integer magnitude of rand augment
491 | 'n' - integer num layers (number of transform ops selected per image)
492 |     'w' - integer probability weight index (index of a set of weights to influence choice of op)
493 | 'mstd' - float std deviation of magnitude noise applied
494 | 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
495 | Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
496 | 'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2
497 | :param hparams: Other hparams (kwargs) for the RandAugmentation scheme
498 | :return: A PyTorch compatible Transform
499 | """
500 | # rand-m7-n4-mstd0.5-inc1
501 | # rand-m8-n2-mstd0.5-inc1
502 | # {'translate_const': 100, 'interpolation': 3}
503 | magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10)
504 | num_layers = 2 # default to 2 ops per image
505 | weight_idx = None # default to no probability weights for op choice
506 | transforms = _RAND_TRANSFORMS
507 | config = config_str.split("-")
508 | assert config[0] == "rand"
509 | config = config[1:]
510 | for c in config:
511 | cs = re.split(r"(\d.*)", c)
512 | if len(cs) < 2:
513 | continue
514 | key, val = cs[:2]
515 | if key == "mstd":
516 | # noise param injected via hparams for now
517 | hparams.setdefault("magnitude_std", float(val))
518 | elif key == "inc":
519 |             if int(val):  # parse the digit: bool(val) would be True even for "0"
520 | transforms = _RAND_INCREASING_TRANSFORMS
521 | elif key == "m":
522 | magnitude = int(val)
523 | elif key == "n":
524 | num_layers = int(val)
525 | elif key == "w":
526 | weight_idx = int(val)
527 | else:
528 |             assert False, "Unknown RandAugment config section: " + key
529 | ra_ops = rand_augment_ops(
530 | magnitude=magnitude, hparams=hparams, transforms=transforms
531 | )
532 | choice_weights = (
533 | None if weight_idx is None else _select_rand_weights(weight_idx)
534 | )
535 | return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
536 |
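
A minimal usage sketch (not part of the source file), assuming `frames` is a list of PIL.Image frames from one clip; the hparams follow the example configs in the comments above:

    from PIL import Image

    frames = [Image.new("RGB", (224, 224)) for _ in range(8)]   # placeholder clip
    ra = rand_augment_transform(
        "rand-m7-n4-mstd0.5-inc1",
        {"translate_const": 100, "interpolation": 3},           # 3 == PIL bicubic
    )
    aug_frames = ra(frames)  # each sampled op is applied identically to every frame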
--------------------------------------------------------------------------------
/datasets/transforms/random_erasing.py:
--------------------------------------------------------------------------------
1 | """
2 | This implementation is based on
3 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/random_erasing.py
4 | published under an Apache License 2.0.
5 | """
6 | import math
7 | import random
8 | import torch
9 |
10 |
11 | def _get_pixels(
12 | per_pixel, rand_color, patch_size, dtype=torch.float32, device="cuda"
13 | ):
14 | # NOTE I've seen CUDA illegal memory access errors being caused by the normal_()
15 | # paths, flip the order so normal is run on CPU if this becomes a problem
16 | # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508
17 | if per_pixel:
18 | return torch.empty(patch_size, dtype=dtype, device=device).normal_()
19 | elif rand_color:
20 | return torch.empty(
21 | (patch_size[0], 1, 1), dtype=dtype, device=device
22 | ).normal_()
23 | else:
24 | return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device)
25 |
26 |
27 | class RandomErasing:
28 | """Randomly selects a rectangle region in an image and erases its pixels.
29 | 'Random Erasing Data Augmentation' by Zhong et al.
30 | See https://arxiv.org/pdf/1708.04896.pdf
31 | This variant of RandomErasing is intended to be applied to either a batch
32 | or single image tensor after it has been normalized by dataset mean and std.
33 | Args:
34 | probability: Probability that the Random Erasing operation will be performed.
35 | min_area: Minimum percentage of erased area wrt input image area.
36 | max_area: Maximum percentage of erased area wrt input image area.
37 | min_aspect: Minimum aspect ratio of erased area.
38 | mode: pixel color mode, one of 'const', 'rand', or 'pixel'
39 | 'const' - erase block is constant color of 0 for all channels
40 | 'rand' - erase block is same per-channel random (normal) color
41 | 'pixel' - erase block is per-pixel random (normal) color
42 |         max_count: maximum number of erasing blocks per image, area per box is scaled by count.
43 |             per-image count is randomly chosen between min_count and this value.
44 | """
45 |
46 | def __init__(
47 | self,
48 | probability=0.5,
49 | min_area=0.02,
50 | max_area=1 / 3,
51 | min_aspect=0.3,
52 | max_aspect=None,
53 | mode="const",
54 | min_count=1,
55 | max_count=None,
56 | num_splits=0,
57 | device="cuda",
58 | cube=True,
59 | ):
60 | self.probability = probability
61 | self.min_area = min_area
62 | self.max_area = max_area
63 | max_aspect = max_aspect or 1 / min_aspect
64 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
65 | self.min_count = min_count
66 | self.max_count = max_count or min_count
67 | self.num_splits = num_splits
68 | mode = mode.lower()
69 | self.rand_color = False
70 | self.per_pixel = False
71 | self.cube = cube
72 | if mode == "rand":
73 | self.rand_color = True # per block random normal
74 | elif mode == "pixel":
75 | self.per_pixel = True # per pixel random normal
76 | else:
77 | assert not mode or mode == "const"
78 | self.device = device
79 |
80 | def _erase(self, img, chan, img_h, img_w, dtype):
81 | if random.random() > self.probability:
82 | return
83 | area = img_h * img_w
84 | count = (
85 | self.min_count
86 | if self.min_count == self.max_count
87 | else random.randint(self.min_count, self.max_count)
88 | )
89 | for _ in range(count):
90 | for _ in range(10):
91 | target_area = (
92 | random.uniform(self.min_area, self.max_area) * area / count
93 | )
94 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
95 | h = int(round(math.sqrt(target_area * aspect_ratio)))
96 | w = int(round(math.sqrt(target_area / aspect_ratio)))
97 | if w < img_w and h < img_h:
98 | top = random.randint(0, img_h - h)
99 | left = random.randint(0, img_w - w)
100 | img[:, top : top + h, left : left + w] = _get_pixels(
101 | self.per_pixel,
102 | self.rand_color,
103 | (chan, h, w),
104 | dtype=dtype,
105 | device=self.device,
106 | )
107 | break
108 |
109 | def _erase_cube(
110 | self,
111 | img,
112 | batch_start,
113 | batch_size,
114 | chan,
115 | img_h,
116 | img_w,
117 | dtype,
118 | ):
119 | if random.random() > self.probability:
120 | return
121 | area = img_h * img_w
122 | count = (
123 | self.min_count
124 | if self.min_count == self.max_count
125 | else random.randint(self.min_count, self.max_count)
126 | )
127 | for _ in range(count):
128 | for _ in range(100):
129 | target_area = (
130 | random.uniform(self.min_area, self.max_area) * area / count
131 | )
132 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
133 | h = int(round(math.sqrt(target_area * aspect_ratio)))
134 | w = int(round(math.sqrt(target_area / aspect_ratio)))
135 | if w < img_w and h < img_h:
136 | top = random.randint(0, img_h - h)
137 | left = random.randint(0, img_w - w)
138 | for i in range(batch_start, batch_size):
139 | img_instance = img[i]
140 | img_instance[
141 | :, top : top + h, left : left + w
142 | ] = _get_pixels(
143 | self.per_pixel,
144 | self.rand_color,
145 | (chan, h, w),
146 | dtype=dtype,
147 | device=self.device,
148 | )
149 | break
150 |
151 | def __call__(self, input):
152 | if len(input.size()) == 3:
153 | self._erase(input, *input.size(), input.dtype)
154 | else:
155 | batch_size, chan, img_h, img_w = input.size()
156 | # skip first slice of batch if num_splits is set (for clean portion of samples)
157 | batch_start = (
158 | batch_size // self.num_splits if self.num_splits > 1 else 0
159 | )
160 | if self.cube:
161 | self._erase_cube(
162 | input,
163 | batch_start,
164 | batch_size,
165 | chan,
166 | img_h,
167 | img_w,
168 | input.dtype,
169 | )
170 | else:
171 | for i in range(batch_start, batch_size):
172 | self._erase(input[i], chan, img_h, img_w, input.dtype)
173 | return input
174 |
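
A minimal usage sketch (assumed shapes; not part of the source file): erasing a consistent region across all frames of one normalized clip:

    import torch

    erase = RandomErasing(probability=0.25, mode="pixel", device="cpu", cube=True)
    clip = torch.randn(16, 3, 224, 224)   # (T, C, H, W), already mean/std normalized
    clip = erase(clip)                    # cube=True erases the same region in every frame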
--------------------------------------------------------------------------------
/datasets/transforms/surg_transforms.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import torch
3 | import torchvision.transforms.functional as F
4 | import warnings
5 | import random
6 | import numpy as np
7 | import torchvision
8 | from PIL import Image, ImageOps
9 | import numbers
10 | from imgaug import augmenters as iaa
11 |
12 | class SurgTransforms(object):
13 |
14 | def __init__(self, input_size=224, scales=(0.0, 0.3)):
15 | self.scales = scales
16 | self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size]
17 | self.aug = iaa.Sequential([
18 | # Resize to (252, 448)
19 | iaa.Resize({"height": 252, "width": 448}),
20 |             # Random crop: remove `scales` percent (here 0-30%) from each side
21 |             iaa.Crop(percent=scales, keep_size=False),
22 |             # Resize to (input_size, input_size)
23 |             iaa.Resize({"height": input_size, "width": input_size}),
24 | # Random Augment Surgery
25 | iaa.SomeOf((0, 2), [
26 | iaa.pillike.EnhanceSharpness(),
27 | iaa.pillike.Autocontrast(),
28 | iaa.pillike.Equalize(),
29 | iaa.pillike.EnhanceContrast(),
30 | iaa.pillike.EnhanceColor(),
31 | iaa.pillike.EnhanceBrightness(),
32 | iaa.Rotate((-30, 30)),
33 | iaa.ShearX((-20, 20)),
34 | iaa.ShearY((-20, 20)),
35 | iaa.TranslateX(percent=(-0.1, 0.1)),
36 | iaa.TranslateY(percent=(-0.1, 0.1))
37 | ]),
38 | iaa.Sometimes(0.3, iaa.AddToHueAndSaturation((-50, 50), per_channel=True)),
39 | # Horizontally flip 50% of all images
40 | iaa.Fliplr(0.5)])
41 |
42 |
43 | def __call__(self, img_tuple):
44 | images, label = img_tuple
45 |
46 |         # Crop and resize with one deterministic transform so every frame in the video sequence receives the same treatment
47 | augDet = self.aug.to_deterministic()
48 | augment_images = []
49 | for _, img in enumerate(images):
50 | img_aug = augDet.augment_image(np.array(img))
51 | augment_images.append(img_aug)
52 |
53 | # for index, img in enumerate(augment_images):
54 | # img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
55 | # cv2.imshow(str(index), img)
56 | # cv2.waitKey()
57 | return (augment_images, label)
58 |
59 |
60 | class SurgStack(object):
61 |
62 | def __init__(self, roll=False):
63 | self.roll = roll
64 |
65 | def __call__(self, img_tuple):
66 | img_group, label = img_tuple
67 | if img_group[0].shape[2] == 1:
68 | return (np.concatenate([np.expand_dims(x, 2) for x in img_group], axis=2), label)
69 | elif img_group[0].shape[2] == 3:
70 | if self.roll:
71 | return (np.concatenate([np.array(x)[:, :, ::-1] for x in img_group], axis=2), label)
72 | else:
73 | return (np.concatenate(img_group, axis=2), label)
74 |
75 |
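
# A hypothetical call site (not part of the source file): the tuple interface
# carries (frames, label) through torchvision.transforms.Compose, using the
# torchvision import at the top of this file.
#
#     train_aug = torchvision.transforms.Compose([
#         SurgTransforms(input_size=224, scales=(0.0, 0.3)),
#         SurgStack(roll=False),
#     ])
#     stacked, label = train_aug((pil_frames, label))
#     # `stacked` is an (H, W, 3*T) uint8 array: T augmented frames stacked on channels.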
76 | if __name__ == '__main__':
77 |
78 | class SurgTransforms(object):
79 |
80 | def __init__(self, input_size=224, scales=(0.0, 0.3)):
81 | self.scales = scales
82 | self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size]
83 | self.aug = iaa.Sequential([
84 | # Resize to (252, 448)
85 | iaa.Resize({"height": 252, "width": 448}),
86 |             # Random crop: remove `scales` percent (here 0-30%) from each side
87 |             iaa.Crop(percent=scales, keep_size=False),
88 |             # Resize to (input_size, input_size)
89 |             iaa.Resize({"height": input_size, "width": input_size}),
90 | # Random Augment Surgery
91 | iaa.SomeOf((0, 2), [
92 | iaa.pillike.EnhanceSharpness(),
93 | iaa.pillike.Autocontrast(),
94 | iaa.pillike.Equalize(),
95 | iaa.pillike.EnhanceContrast(),
96 | iaa.pillike.EnhanceColor(),
97 | iaa.pillike.EnhanceBrightness(),
98 | iaa.Rotate((-30, 30)),
99 | iaa.ShearX((-20, 20)),
100 | iaa.ShearY((-20, 20)),
101 | iaa.TranslateX(percent=(-0.1, 0.1)),
102 | iaa.TranslateY(percent=(-0.1, 0.1))
103 | ]),
104 | iaa.Sometimes(0.3, iaa.AddToHueAndSaturation((-50, 50), per_channel=True)),
105 | # Horizontally flip 50% of all images
106 | iaa.Fliplr(0.5)
107 | ])
108 |
109 |
110 | def __call__(self, images):
111 |
112 |             # Crop and resize with one deterministic transform so every frame in the video sequence receives the same treatment
113 | augDet = self.aug.to_deterministic()
114 | augment_images = []
115 | for _, img in enumerate(images):
116 | img_aug = augDet.augment_image(img)
117 | augment_images.append(img_aug)
118 |
119 | return augment_images
120 |
121 | A = SurgTransforms()
122 | origin_images = cv2.imread("data/cholec80/frames/train/video01/0.jpg")
123 | origin_images = cv2.cvtColor(origin_images, cv2.COLOR_BGR2RGB)
124 |
125 | images = np.array([origin_images for _ in range(4)], dtype=np.uint8)
126 | img_1 = A(images)
127 | for index, img in enumerate(img_1):
128 | img = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
129 | print(img.shape)
130 | cv2.imshow(str(index), img)
131 | cv2.waitKey()
132 | img_2 = A(images)
133 | for index, img in enumerate(img_2):
134 | img = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
135 | print(img.shape)
136 | cv2.imshow(str(index)+'2', img)
137 | cv2.waitKey()
138 | img_3 = A(images)
139 | for index, img in enumerate(img_3):
140 | img = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
141 | print(img.shape)
142 | cv2.imshow(str(index)+'3', img)
143 | cv2.waitKey()
144 |
--------------------------------------------------------------------------------
/datasets/transforms/transforms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.transforms.functional as F
3 | import warnings
4 | import random
5 | import numpy as np
6 | import torchvision
7 | from PIL import Image, ImageOps
8 | import numbers
9 |
10 |
11 | class GroupRandomCrop(object):
12 | def __init__(self, size):
13 | if isinstance(size, numbers.Number):
14 | self.size = (int(size), int(size))
15 | else:
16 | self.size = size
17 |
18 | def __call__(self, img_tuple):
19 | img_group, label = img_tuple
20 |
21 | w, h = img_group[0].size
22 | th, tw = self.size
23 |
24 | out_images = list()
25 |
26 | x1 = random.randint(0, w - tw)
27 | y1 = random.randint(0, h - th)
28 |
29 | for img in img_group:
30 | assert(img.size[0] == w and img.size[1] == h)
31 | if w == tw and h == th:
32 | out_images.append(img)
33 | else:
34 | out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))
35 |
36 | return (out_images, label)
37 |
38 |
39 | class GroupCenterCrop(object):
40 | def __init__(self, size):
41 | self.worker = torchvision.transforms.CenterCrop(size)
42 |
43 | def __call__(self, img_tuple):
44 | img_group, label = img_tuple
45 | return ([self.worker(img) for img in img_group], label)
46 |
47 |
48 | class GroupNormalize(object):
49 | def __init__(self, mean, std):
50 | self.mean = mean
51 | self.std = std
52 |
53 | def __call__(self, tensor_tuple):
54 | tensor, label = tensor_tuple
55 | rep_mean = self.mean * (tensor.size()[0]//len(self.mean))
56 | rep_std = self.std * (tensor.size()[0]//len(self.std))
57 |
58 | # TODO: make efficient
59 | for t, m, s in zip(tensor, rep_mean, rep_std):
60 | t.sub_(m).div_(s)
61 |
62 |         return (tensor, label)
63 |
64 |
65 | class GroupGrayScale(object):
66 | def __init__(self, size):
67 |         self.worker = torchvision.transforms.Grayscale(size)  # note: Grayscale's argument is num_output_channels
68 |
69 | def __call__(self, img_tuple):
70 | img_group, label = img_tuple
71 | return ([self.worker(img) for img in img_group], label)
72 |
73 |
74 | class GroupScale(object):
75 | """ Rescales the input PIL.Image to the given 'size'.
76 | 'size' will be the size of the smaller edge.
77 | For example, if height > width, then image will be
78 | rescaled to (size * height / width, size)
79 | size: size of the smaller edge
80 | interpolation: Default: PIL.Image.BILINEAR
81 | """
82 |
83 | def __init__(self, size, interpolation=Image.BILINEAR):
84 | self.worker = torchvision.transforms.Resize(size, interpolation)
85 |
86 | def __call__(self, img_tuple):
87 | img_group, label = img_tuple
88 | return ([self.worker(img) for img in img_group], label)
89 |
90 |
91 | # Group multi-scale crop
92 | class GroupMultiScaleCrop(object):
93 |
94 | def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True):
95 | self.scales = scales if scales is not None else [1, .875, .75, .66]
96 | self.max_distort = max_distort
97 | self.fix_crop = fix_crop
98 | self.more_fix_crop = more_fix_crop
99 | self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size]
100 | self.interpolation = Image.BILINEAR
101 |
102 | def __call__(self, img_tuple):
103 | img_group, label = img_tuple
104 |
105 | im_size = img_group[0].size # [854, 480]
106 |
107 | crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size)
108 |         # Crop and resize with one shared window so every frame in the video sequence receives the same treatment
109 | crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group]
110 | ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation) for img in crop_img_group]
111 |
112 | # import cv2
113 | # for index, img in enumerate(ret_img_group):
114 | # img = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
115 | # cv2.imshow(str(index), img)
116 | # cv2.waitKey()
117 | return (ret_img_group, label)
118 |
119 |     # Derive candidate crop sizes from the shorter edge, snap sizes within 3 px of input_size to input_size, then pair widths and heights to produce multi-scale aspect ratios
120 | def _sample_crop_size(self, im_size):
121 | image_w, image_h = im_size[0], im_size[1]
122 | # find a crop size
123 | base_size = min(image_w, image_h)
124 | crop_sizes = [int(base_size * x) for x in self.scales]
125 |
126 | crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes]
127 | crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes]
128 |
129 | pairs = []
130 | for i, h in enumerate(crop_h):
131 | for j, w in enumerate(crop_w):
132 | if abs(i - j) <= self.max_distort:
133 | pairs.append((w, h))
134 |
135 | crop_pair = random.choice(pairs)
136 | if not self.fix_crop:
137 | w_offset = random.randint(0, image_w - crop_pair[0])
138 | h_offset = random.randint(0, image_h - crop_pair[1])
139 | else:
140 | w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1])
141 |
142 | return crop_pair[0], crop_pair[1], w_offset, h_offset
143 |
144 | def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h):
145 | offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h)
146 | return random.choice(offsets)
147 |
148 | @staticmethod
149 | def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h):
150 | w_step = (image_w - crop_w) // 4
151 | h_step = (image_h - crop_h) // 4
152 |
153 | ret = list()
154 | ret.append((0, 0)) # upper left
155 | ret.append((4 * w_step, 0)) # upper right
156 | ret.append((0, 4 * h_step)) # lower left
157 | ret.append((4 * w_step, 4 * h_step)) # lower right
158 | ret.append((2 * w_step, 2 * h_step)) # center
159 |
160 | if more_fix_crop:
161 | ret.append((0, 2 * h_step)) # center left
162 | ret.append((4 * w_step, 2 * h_step)) # center right
163 | ret.append((2 * w_step, 4 * h_step)) # lower center
164 | ret.append((2 * w_step, 0 * h_step)) # upper center
165 |
166 | ret.append((1 * w_step, 1 * h_step)) # upper left quarter
167 | ret.append((3 * w_step, 1 * h_step)) # upper right quarter
168 | ret.append((1 * w_step, 3 * h_step)) # lower left quarter
169 |             ret.append((3 * w_step, 3 * h_step))  # lower right quarter
170 | return ret
171 |
172 |
173 | class Stack(object):
174 |
175 | def __init__(self, roll=False):
176 | self.roll = roll
177 |
178 | def __call__(self, img_tuple):
179 | img_group, label = img_tuple
180 | if img_group[0].mode == 'L':
181 | return (np.concatenate([np.expand_dims(x, 2) for x in img_group], axis=2), label)
182 | elif img_group[0].mode == 'RGB':
183 | if self.roll:
184 | return (np.concatenate([np.array(x)[:, :, ::-1] for x in img_group], axis=2), label)
185 | else:
186 | return (np.concatenate(img_group, axis=2), label)
187 |
188 |
189 | class ToTorchFormatTensor(object):
190 | """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
191 | to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """
192 | def __init__(self, div=True):
193 | self.div = div
194 |
195 | def __call__(self, pic_tuple):
196 | pic, label = pic_tuple
197 | if isinstance(pic, np.ndarray):
198 | # handle numpy array
199 | img = torch.from_numpy(pic).permute(2, 0, 1).contiguous()
200 | else:
201 | # handle PIL Image
202 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
203 | img = img.view(pic.size[1], pic.size[0], len(pic.mode))
204 | # put it from HWC to CHW format
205 | # yikes, this transpose takes 80% of the loading time/CPU
206 | img = img.transpose(0, 1).transpose(0, 2).contiguous()
207 | return (img.float().div(255.) if self.div else img.float(), label)
208 |
209 |
210 | class IdentityTransform(object):
211 |
212 | def __call__(self, data):
213 | return data
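
A sketch of a typical group pipeline built from the classes above (assumed composition, not the project's fixed recipe; every stage passes a (frames, label) tuple along):

    import torchvision

    group_pipeline = torchvision.transforms.Compose([
        GroupScale(256),                  # shorter edge -> 256
        GroupRandomCrop(224),             # same crop window for every frame
        Stack(roll=False),                # (H, W, 3*T) uint8
        ToTorchFormatTensor(div=True),    # (3*T, H, W) float in [0, 1]
        GroupNormalize(mean=[0.485, 0.456, 0.406],   # ImageNet stats, repeated per frame
                       std=[0.229, 0.224, 0.225]),
    ])
    # tensor, label = group_pipeline((pil_frames, label))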
--------------------------------------------------------------------------------
/datasets/transforms/volume_transforms.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image, ImageOps
3 | import torch
4 | import numbers
5 | import random
6 | import torchvision.transforms.functional as TF
7 |
8 | def convert_img(img):
9 | """Converts (H, W, C) numpy.ndarray to (C, W, H) format"""
10 | if len(img.shape) == 3:
11 | img = img.transpose(2, 0, 1)
12 | if len(img.shape) == 2:
13 | img = np.expand_dims(img, 0)
14 | return img
15 |
16 |
17 | class ClipToTensor(object):
18 | """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255]
19 | to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0]
20 | """
21 |
22 | def __init__(self, channel_nb=3, div_255=True, numpy=False):
23 | self.channel_nb = channel_nb
24 | self.div_255 = div_255
25 | self.numpy = numpy
26 |
27 | def __call__(self, clip):
28 | """
29 | Args: clip (list of numpy.ndarray): clip (list of images)
30 | to be converted to tensor.
31 | """
32 | # Retrieve shape
33 | if isinstance(clip[0], np.ndarray):
34 | h, w, ch = clip[0].shape
35 |             assert ch == self.channel_nb, "Got {0} instead of {1} channels".format(ch, self.channel_nb)
36 | elif isinstance(clip[0], Image.Image):
37 | w, h = clip[0].size
38 | else:
39 | raise TypeError(
40 | "Expected numpy.ndarray or PIL.Image\
41 | but got list of {0}".format(
42 | type(clip[0])
43 | )
44 | )
45 |
46 | np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)])
47 |
48 | # Convert
49 | for img_idx, img in enumerate(clip):
50 | if isinstance(img, np.ndarray):
51 | pass
52 | elif isinstance(img, Image.Image):
53 | img = np.array(img, copy=False)
54 | else:
55 | raise TypeError(
56 | "Expected numpy.ndarray or PIL.Image\
57 | but got list of {0}".format(
58 |                         type(img)
59 | )
60 | )
61 | img = convert_img(img)
62 | np_clip[:, img_idx, :, :] = img
63 | if self.numpy:
64 | if self.div_255:
65 | np_clip = np_clip / 255.0
66 | return np_clip
67 |
68 | else:
69 | tensor_clip = torch.from_numpy(np_clip)
70 |
71 | if not isinstance(tensor_clip, torch.FloatTensor):
72 | tensor_clip = tensor_clip.float()
73 | if self.div_255:
74 | tensor_clip = torch.div(tensor_clip, 255)
75 | return tensor_clip
76 |
77 |
78 | # Note: this normalizes data to [-1, 1]
79 | class ClipToTensor_K(object):
80 | """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255]
81 |     to a torch.FloatTensor of shape (C x m x H x W) in the range [-1.0, 1.0]
82 | """
83 |
84 | def __init__(self, channel_nb=3, div_255=True, numpy=False):
85 | self.channel_nb = channel_nb
86 | self.div_255 = div_255
87 | self.numpy = numpy
88 |
89 | def __call__(self, clip):
90 | """
91 | Args: clip (list of numpy.ndarray): clip (list of images)
92 | to be converted to tensor.
93 | """
94 | # Retrieve shape
95 | if isinstance(clip[0], np.ndarray):
96 | h, w, ch = clip[0].shape
97 |             assert ch == self.channel_nb, "Got {0} instead of {1} channels".format(ch, self.channel_nb)
98 | elif isinstance(clip[0], Image.Image):
99 | w, h = clip[0].size
100 | else:
101 | raise TypeError(
102 | "Expected numpy.ndarray or PIL.Image\
103 | but got list of {0}".format(
104 | type(clip[0])
105 | )
106 | )
107 |
108 | np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)])
109 |
110 | # Convert
111 | for img_idx, img in enumerate(clip):
112 | if isinstance(img, np.ndarray):
113 | pass
114 | elif isinstance(img, Image.Image):
115 | img = np.array(img, copy=False)
116 | else:
117 | raise TypeError(
118 | "Expected numpy.ndarray or PIL.Image\
119 | but got list of {0}".format(
120 |                         type(img)
121 | )
122 | )
123 | img = convert_img(img)
124 | np_clip[:, img_idx, :, :] = img
125 | if self.numpy:
126 | if self.div_255:
127 | np_clip = (np_clip - 127.5) / 127.5
128 | return np_clip
129 |
130 | else:
131 | tensor_clip = torch.from_numpy(np_clip)
132 |
133 | if not isinstance(tensor_clip, torch.FloatTensor):
134 | tensor_clip = tensor_clip.float()
135 | if self.div_255:
136 | tensor_clip = torch.div(torch.sub(tensor_clip, 127.5), 127.5)
137 | return tensor_clip
138 |
139 |
140 | class ToTensor(object):
141 | """Converts numpy array to tensor"""
142 |
143 | def __call__(self, array):
144 | tensor = torch.from_numpy(array)
145 | return tensor
146 |
147 |
148 | class RandomCrop(object):
149 |
150 | def __init__(self, size, padding=0, sequence_length=16):
151 | if isinstance(size, numbers.Number):
152 | self.size = (int(size), int(size))
153 | else:
154 | self.size = size
155 | self.sequence_length = sequence_length
156 | self.padding = padding
157 | self.count = 0
158 |
159 | def __call__(self, img):
160 |
161 | if self.padding > 0:
162 | img = ImageOps.expand(img, border=self.padding, fill=0)
163 |
164 | w, h = img.size
165 | th, tw = self.size
166 | if w == tw and h == th:
167 | return img
168 |
169 |         random.seed(self.count // self.sequence_length)  # one seed per clip: every frame gets the same crop window
170 | x1 = random.randint(0, w - tw)
171 | y1 = random.randint(0, h - th)
172 | # print(self.count, x1, y1)
173 | self.count += 1
174 | return img.crop((x1, y1, x1 + tw, y1 + th))
175 |
176 |
177 | class RandomHorizontalFlip(object):
178 | def __init__(self, sequence_length=16):
179 | self.count = 0
180 | self.sequence_length = sequence_length
181 |
182 | def __call__(self, img):
183 |         seed = self.count // self.sequence_length  # one seed per clip: consistent flip decision across frames
184 | random.seed(seed)
185 | prob = random.random()
186 | self.count += 1
187 | # print(self.count, seed, prob)
188 | if prob < 0.5:
189 | return img.transpose(Image.FLIP_LEFT_RIGHT)
190 | return img
191 |
192 |
193 | class RandomRotation(object):
194 | def __init__(self,degrees, sequence_length=16):
195 | self.degrees = degrees
196 | self.count = 0
197 | self.sequence_length = sequence_length
198 |
199 | def __call__(self, img):
200 | seed = self.count // self.sequence_length
201 | random.seed(seed)
202 | self.count += 1
203 | angle = random.randint(-self.degrees,self.degrees)
204 | return TF.rotate(img, angle)
205 |
206 |
207 |
208 | class ColorJitter(object):
209 | def __init__(self,brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1, sequence_length=16):
210 | self.brightness = brightness
211 | self.contrast = contrast
212 | self.saturation = saturation
213 | self.hue = hue
214 | self.count = 0
215 | self.sequence_length = sequence_length
216 |
217 | def __call__(self, img):
218 | seed = self.count // self.sequence_length
219 | random.seed(seed)
220 | self.count += 1
221 | brightness_factor = random.uniform(1 - self.brightness, 1 + self.brightness)
222 | contrast_factor = random.uniform(1 - self.contrast, 1 + self.contrast)
223 | saturation_factor = random.uniform(1 - self.saturation, 1 + self.saturation)
224 | hue_factor = random.uniform(- self.hue, self.hue)
225 |
226 | img_ = TF.adjust_brightness(img,brightness_factor)
227 | img_ = TF.adjust_contrast(img_,contrast_factor)
228 | img_ = TF.adjust_saturation(img_,saturation_factor)
229 | img_ = TF.adjust_hue(img_,hue_factor)
230 |
231 | return img_
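
A sketch of the clip-conversion step (assumed inputs: eight 224x224 RGB frames); ClipToTensor_K behaves the same but maps values to [-1, 1]:

    import numpy as np

    clip = [np.zeros((224, 224, 3), dtype=np.uint8) for _ in range(8)]
    video = ClipToTensor(channel_nb=3, div_255=True)(clip)
    # torch.FloatTensor of shape (3, 8, 224, 224) with values in [0, 1]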
--------------------------------------------------------------------------------
/downstream_phase/__pycache__/datasets_phase.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/downstream_phase/__pycache__/datasets_phase.cpython-310.pyc
--------------------------------------------------------------------------------
/downstream_phase/__pycache__/datasets_phase.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/downstream_phase/__pycache__/datasets_phase.cpython-311.pyc
--------------------------------------------------------------------------------
/downstream_phase/__pycache__/engine_for_phase.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/downstream_phase/__pycache__/engine_for_phase.cpython-310.pyc
--------------------------------------------------------------------------------
/downstream_phase/__pycache__/engine_for_phase.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/downstream_phase/__pycache__/engine_for_phase.cpython-311.pyc
--------------------------------------------------------------------------------
/downstream_phase/datasets_phase.py:
--------------------------------------------------------------------------------
1 | import os
2 | from datasets.transforms import *
3 | from datasets.transforms.surg_transforms import *
4 |
5 | from datasets.phase.Cholec80_phase import PhaseDataset_Cholec80
6 | from datasets.phase.AutoLaparo_phase import PhaseDataset_AutoLaparo
7 | from datasets.phase.PmLR50_phase import PhaseDataset_PmLR50
8 |
9 | def build_dataset(is_train, test_mode, fps, args):
10 | """Load video phase recognition dataset."""
11 |
12 | if args.data_set == "Cholec80":
13 | mode = None
14 | anno_path = None
15 | if is_train is True:
16 | mode = "train"
17 | anno_path = os.path.join(
18 | args.data_path, "labels", mode, fps + "train.pickle"
19 | )
20 | elif test_mode is True:
21 | mode = "test"
22 | anno_path = os.path.join(
23 | args.data_path, "labels", mode, fps + "val_test.pickle"
24 | )
25 | else:
26 | mode = "test"
27 | anno_path = os.path.join(args.data_path, "labels", mode, fps + "val_test.pickle")
28 |
29 | dataset = PhaseDataset_Cholec80(
30 | anno_path=anno_path,
31 | data_path=args.data_path,
32 | mode=mode,
33 | data_strategy=args.data_strategy,
34 | output_mode=args.output_mode,
35 | cut_black=args.cut_black,
36 | clip_len=args.num_frames,
37 | frame_sample_rate=args.sampling_rate,
38 | keep_aspect_ratio=True,
39 | crop_size=args.input_size,
40 | short_side_size=args.short_side_size,
41 | new_height=256,
42 | new_width=320,
43 | args=args,
44 | )
45 | nb_classes = 7
46 |
47 | elif args.data_set == "AutoLaparo":
48 | mode = None
49 | anno_path = None
50 | if is_train is True:
51 | mode = "train"
52 | anno_path = os.path.join(
53 | args.data_path, "labels_pkl", mode, fps + "train.pickle"
54 | )
55 | elif test_mode is True:
56 | mode = "test"
57 | anno_path = os.path.join(
58 | args.data_path, "labels_pkl", mode, fps + "test.pickle"
59 | )
60 | else:
61 | mode = "val"
62 | anno_path = os.path.join(args.data_path, "labels_pkl", mode, fps + "val.pickle")
63 |
64 | dataset = PhaseDataset_AutoLaparo(
65 | anno_path=anno_path,
66 | data_path=args.data_path,
67 | mode=mode,
68 | data_strategy=args.data_strategy,
69 | output_mode=args.output_mode,
70 | cut_black=args.cut_black,
71 | clip_len=args.num_frames,
72 | frame_sample_rate=args.sampling_rate,
73 | keep_aspect_ratio=True,
74 | crop_size=args.input_size,
75 | short_side_size=args.short_side_size,
76 | new_height=256,
77 | new_width=320,
78 | args=args,
79 | )
80 | nb_classes = 7
81 |
82 | elif args.data_set == "PmLR50":
83 | mode = None
84 | anno_path = None
85 | if is_train is True:
86 | mode = "train"
87 | anno_path = os.path.join(
88 | args.data_path, "labels", mode, fps + "train.pickle"
89 | )
90 | elif test_mode is True:
91 | mode = "test"
92 | anno_path = os.path.join(
93 | args.data_path, "labels", 'infer', fps + "infer.pickle"
94 | )
95 | else:
96 | mode = "val"
97 | anno_path = os.path.join(args.data_path, "labels", "test", fps + "test.pickle")
98 |
99 | dataset = PhaseDataset_PmLR50(
100 | anno_path=anno_path,
101 | data_path=args.data_path,
102 | mode=mode,
103 | data_strategy=args.data_strategy,
104 | output_mode=args.output_mode,
105 | cut_black=args.cut_black,
106 | clip_len=args.num_frames,
107 | frame_sample_rate=args.sampling_rate,
108 | keep_aspect_ratio=True,
109 | crop_size=args.input_size,
110 | short_side_size=args.short_side_size,
111 | new_height=256,
112 | new_width=320,
113 | args=args,
114 | )
115 | nb_classes = 5
116 | else:
117 | print("Error")
118 |
119 | assert nb_classes == args.nb_classes
120 | print("%s - %s : Number of the class = %d" % (mode, fps, args.nb_classes))
121 | print("Data Strategy: %s" % args.data_strategy)
122 | print("Output Mode: %s" % args.output_mode)
123 | print("Cut Black: %s" % args.cut_black)
124 | if args.sampling_rate == 0:
125 | print(
126 | "%s Frames with Temporal sample Rate %s (%s)"
127 | % (str(args.num_frames), str(args.sampling_rate), "Exponential Stride")
128 | )
129 | elif args.sampling_rate == -1:
130 | print(
131 | "%s Frames with Temporal sample Rate %s (%s)"
132 | % (str(args.num_frames), str(args.sampling_rate), "Random Stride (1-5)")
133 | )
134 | elif args.sampling_rate == -2:
135 | print(
136 | "%s Frames with Temporal sample Rate %s (%s)"
137 | % (str(args.num_frames), str(args.sampling_rate), "Incremental Stride")
138 | )
139 | else:
140 | print(
141 | "%s Frames with Temporal sample Rate %s"
142 | % (str(args.num_frames), str(args.sampling_rate))
143 | )
144 |
145 | return dataset, nb_classes
146 |
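
A hypothetical call site (field names follow the attributes read above; the values for data_strategy, output_mode, and the fps prefix are placeholders, not the project's defaults, and the dataset class may read further fields from args):

    from argparse import Namespace

    args = Namespace(
        data_set="Cholec80", data_path="/data/cholec80", nb_classes=7,
        data_strategy="online", output_mode="key_frame", cut_black=False,
        num_frames=16, sampling_rate=4, input_size=224, short_side_size=256,
    )
    dataset, nb_classes = build_dataset(is_train=True, test_mode=False,
                                        fps="1fps", args=args)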
--------------------------------------------------------------------------------
/evaluation_matlab/Evaluate.m:
--------------------------------------------------------------------------------
1 | function [ res, prec, rec, acc, f1 ] = Evaluate( gtLabelID, predLabelID, fps )
2 | %EVALUATE
3 | % A function to evaluate the performance of the phase recognition method
4 | % providing jaccard index, precision, and recall for each phase
5 | % and accuracy over the surgery. All metrics are computed in a relaxed
6 | % boundary mode.
7 | % OUTPUT:
8 | % res: the jaccard index per phase (relaxed) - NaN for non existing phase in GT
9 | % prec: precision per phase (relaxed) - NaN for non existing phase in GT
10 | % rec: recall per phase (relaxed) - NaN for non existing phase in GT
11 | % acc: the accuracy over the video (relaxed)
12 | %   f1:   the F1-score per phase (relaxed)
13 | oriT = 10 * fps; % 10 seconds relaxed boundary
14 |
15 | res = []; prec = []; rec = [];
16 | diff = predLabelID - gtLabelID;
17 | updatedDiff = [];
18 |
19 | % obtain the true positive with relaxed boundary
20 | for iPhase = 1:7 % nPhases
21 | gtConn = bwconncomp(gtLabelID == iPhase);
22 |
23 | for iConn = 1:gtConn.NumObjects
24 | startIdx = min(gtConn.PixelIdxList{iConn});
25 | endIdx = max(gtConn.PixelIdxList{iConn});
26 |
27 | curDiff = diff(startIdx:endIdx);
28 |
29 | % in the case where the phase is shorter than the relaxed boundary
30 | t = oriT;
31 | if(t > length(curDiff))
32 | t = length(curDiff);
33 | disp(['Very short phase ' num2str(iPhase)]);
34 | end
35 |
36 | % relaxed boundary
37 | % revised for cholec80 dataset !!!!!!!!!!!
38 | if(iPhase == 4 || iPhase == 5) % Gallbladder dissection and packaging might jump between two phases
39 | curDiff(curDiff(1:t)==-1) = 0; % late transition
40 | curDiff(curDiff(end-t+1:end)==1 | curDiff(end-t+1:end)==2) = 0; % early transition % 5 can be predicted as 6/7 at the end > 5 followed by 6/7
41 | elseif(iPhase == 6 || iPhase == 7) % Gallbladder dissection might jump between two phases
42 | curDiff(curDiff(1:t)==-1 | curDiff(1:t)==-2) = 0; % late transition
43 | curDiff(curDiff(end-t+1:end)==1 | curDiff(end-t+1:end)==2) = 0; % early transition
44 | else
45 | % general situation
46 | curDiff(curDiff(1:t)==-1) = 0; % late transition
47 | curDiff(curDiff(end-t+1:end)==1) = 0; % early transition
48 | end
49 |
50 | updatedDiff(startIdx:endIdx) = curDiff;
51 | end
52 | end
53 |
54 | % compute jaccard index, prec, and rec per phase
55 | for iPhase = 1:7
56 | gtConn = bwconncomp(gtLabelID == iPhase);
57 | predConn = bwconncomp(predLabelID == iPhase);
58 |
59 | if(gtConn.NumObjects == 0)
60 | % no iPhase in current ground truth, assigned NaN values
61 | % SHOULD be excluded in the computation of mean (use nanmean)
62 | res(end+1) = NaN;
63 | prec(end+1) = NaN;
64 | rec(end+1) = NaN;
65 | continue;
66 | end
67 |
68 | iPUnion = union(vertcat(predConn.PixelIdxList{:}), vertcat(gtConn.PixelIdxList{:}));
69 | tp = sum(updatedDiff(iPUnion) == 0);
70 | jaccard = tp/length(iPUnion);
71 | jaccard = jaccard * 100;
72 |
73 | % res(end+1, :) = [iPhase jaccard];
74 | res(end+1) = jaccard;
75 |
76 | % Compute prec and rec
77 | indx = (gtLabelID == iPhase);
78 |
79 | sumTP = tp; % sum(predLabelID(indx) == iPhase);
80 | sumPred = sum(predLabelID == iPhase);
81 | sumGT = sum(indx);
82 |
83 | prec(end+1) = sumTP * 100 / sumPred;
84 | rec(end+1) = sumTP * 100 / sumGT;
85 | end
86 |
87 | % compute accuracy
88 | acc = sum(updatedDiff==0) / length(gtLabelID);
89 | acc = acc * 100;
90 | f1 = zeros(1,7);
91 | for i=1:7
92 | % f1(i) = 2*(prec(i)*rec(i))/((prec(i)+rec(i))+0.0000001)
93 | f1(i) = 2*(prec(i)*rec(i))/((prec(i)+rec(i)));
94 | end
95 | end
96 |
97 |
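
For cross-checking in Python, a hedged NumPy sketch of the general (non-phase-specific) relaxed-boundary rule implemented above; `gt` and `pred` are assumed to be integer arrays with labels 1..n_phases and `fps` an integer:

    import numpy as np

    def relaxed_accuracy(gt, pred, fps, n_phases=7):
        """Frame accuracy where predictions of the adjacent phase within the
        10 s relaxed window at a boundary are counted as correct."""
        t = 10 * fps
        diff = pred - gt
        upd = diff.copy()
        for p in range(1, n_phases + 1):
            idx = np.flatnonzero(gt == p)
            if idx.size == 0:
                continue
            # contiguous segments of phase p in the ground truth
            for seg in np.split(idx, np.flatnonzero(np.diff(idx) > 1) + 1):
                w = min(t, seg.size)
                head, tail = seg[:w], seg[-w:]
                upd[head[diff[head] == -1]] = 0  # late transition forgiven
                upd[tail[diff[tail] == 1]] = 0   # early transition forgiven
        return 100.0 * np.mean(upd == 0)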
--------------------------------------------------------------------------------
/evaluation_matlab/Evaluate_Cataract101.m:
--------------------------------------------------------------------------------
1 | function [ res, prec, rec, acc, f1 ] = Evaluate( gtLabelID, predLabelID, fps )
2 | %EVALUATE
3 | % A function to evaluate the performance of the phase recognition method
4 | % providing jaccard index, precision, and recall for each phase
5 | % and accuracy over the surgery. All metrics are computed in a relaxed
6 | % boundary mode.
7 | % OUTPUT:
8 | % res: the jaccard index per phase (relaxed) - NaN for non existing phase in GT
9 | % prec: precision per phase (relaxed) - NaN for non existing phase in GT
10 | % rec: recall per phase (relaxed) - NaN for non existing phase in GT
11 | % acc: the accuracy over the video (relaxed)
12 | %   f1:   the F1-score per phase (relaxed)
13 | oriT = 10 * fps; % 10 seconds relaxed boundary
14 |
15 | res = []; prec = []; rec = [];
16 | diff = predLabelID - gtLabelID;
17 | updatedDiff = [];
18 |
19 | % obtain the true positive with relaxed boundary
20 | for iPhase = 1:10 % nPhases
21 | gtConn = bwconncomp(gtLabelID == iPhase);
22 |
23 | for iConn = 1:gtConn.NumObjects
24 | startIdx = min(gtConn.PixelIdxList{iConn});
25 | endIdx = max(gtConn.PixelIdxList{iConn});
26 |
27 | curDiff = diff(startIdx:endIdx);
28 |
29 | updatedDiff(startIdx:endIdx) = curDiff;
30 | end
31 | end
32 |
33 | % compute jaccard index, prec, and rec per phase
34 | for iPhase = 1:10
35 | gtConn = bwconncomp(gtLabelID == iPhase);
36 | predConn = bwconncomp(predLabelID == iPhase);
37 |
38 | if(gtConn.NumObjects == 0)
39 | % no iPhase in current ground truth, assigned NaN values
40 | % SHOULD be excluded in the computation of mean (use nanmean)
41 | res(end+1) = NaN;
42 | prec(end+1) = NaN;
43 | rec(end+1) = NaN;
44 | continue;
45 | end
46 |
47 | iPUnion = union(vertcat(predConn.PixelIdxList{:}), vertcat(gtConn.PixelIdxList{:}));
48 | tp = sum(updatedDiff(iPUnion) == 0);
49 | jaccard = tp/length(iPUnion);
50 | jaccard = jaccard * 100;
51 |
52 | % res(end+1, :) = [iPhase jaccard];
53 | res(end+1) = jaccard;
54 |
55 | % Compute prec and rec
56 | indx = (gtLabelID == iPhase);
57 |
58 | sumTP = tp; % sum(predLabelID(indx) == iPhase);
59 | sumPred = sum(predLabelID == iPhase);
60 | sumGT = sum(indx);
61 |
62 | prec(end+1) = sumTP * 100 / sumPred;
63 | rec(end+1) = sumTP * 100 / sumGT;
64 | end
65 |
66 | % compute accuracy
67 | acc = sum(updatedDiff==0) / length(gtLabelID);
68 | acc = acc * 100;
69 | f1 = zeros(1,10);
70 | for i=1:10
71 | % f1(i) = 2*(prec(i)*rec(i))/((prec(i)+rec(i))+0.0000001)
72 | f1(i) = 2*(prec(i)*rec(i))/((prec(i)+rec(i)));
73 | end
74 | end
75 |
76 |
--------------------------------------------------------------------------------
/evaluation_matlab/Evaluate_m2cai.m:
--------------------------------------------------------------------------------
1 | function [ res, prec, rec, acc ] = Evaluate( gtLabelID, predLabelID, fps )
2 | %EVALUATE
3 | % A function to evaluate the performance of the phase recognition method
4 | % providing jaccard index, precision, and recall for each phase
5 | % and accuracy over the surgery. All metrics are computed in a relaxed
6 | % boundary mode.
7 | % OUTPUT:
8 | % res: the jaccard index per phase (relaxed) - NaN for non existing phase in GT
9 | % prec: precision per phase (relaxed) - NaN for non existing phase in GT
10 | % rec: recall per phase (relaxed) - NaN for non existing phase in GT
11 | % acc: the accuracy over the video (relaxed)
12 |
13 | oriT = 10 * fps; % 10 seconds relaxed boundary
14 |
15 | res = []; prec = []; rec = [];
16 | diff = predLabelID - gtLabelID;
17 | updatedDiff = [];
18 |
19 | % obtain the true positive with relaxed boundary
20 | for iPhase = 1:8 % nPhases
21 | gtConn = bwconncomp(gtLabelID == iPhase);
22 |
23 | for iConn = 1:gtConn.NumObjects
24 | startIdx = min(gtConn.PixelIdxList{iConn});
25 | endIdx = max(gtConn.PixelIdxList{iConn});
26 |
27 | curDiff = diff(startIdx:endIdx);
28 |
29 | % in the case where the phase is shorter than the relaxed boundary
30 | t = oriT;
31 | if(t > length(curDiff))
32 | t = length(curDiff);
33 | disp(['Very short phase ' num2str(iPhase)]);
34 | end
35 |
36 | % relaxed boundary
37 | if(iPhase == 5 || iPhase == 6) % Gallbladder dissection and packaging might jump between two phases
38 | curDiff(curDiff(1:t)==-1) = 0; % late transition
39 | curDiff(curDiff(end-t+1:end)==1 | curDiff(end-t+1:end)==2) = 0; % early transition % 5 can be predicted as 6/7 at the end > 5 followed by 6/7
40 | elseif(iPhase == 7 || iPhase == 8) % Gallbladder dissection might jump between two phases
41 | curDiff(curDiff(1:t)==-1 | curDiff(1:t)==-2) = 0; % late transition
42 | curDiff(curDiff(end-t+1:end)==1 | curDiff(end-t+1:end)==2) = 0; % early transition
43 | else
44 | % general situation
45 | curDiff(curDiff(1:t)==-1) = 0; % late transition
46 | curDiff(curDiff(end-t+1:end)==1) = 0; % early transition
47 | end
48 |
49 | updatedDiff(startIdx:endIdx) = curDiff;
50 | end
51 | end
52 |
53 | % compute jaccard index, prec, and rec per phase
54 | for iPhase = 1:8
55 | gtConn = bwconncomp(gtLabelID == iPhase);
56 | predConn = bwconncomp(predLabelID == iPhase);
57 |
58 | if(gtConn.NumObjects == 0)
59 | % no iPhase in current ground truth, assigned NaN values
60 | % SHOULD be excluded in the computation of mean (use nanmean)
61 | res(end+1) = NaN;
62 | prec(end+1) = NaN;
63 | rec(end+1) = NaN;
64 | continue;
65 | end
66 |
67 | iPUnion = union(vertcat(predConn.PixelIdxList{:}), vertcat(gtConn.PixelIdxList{:}));
68 | tp = sum(updatedDiff(iPUnion) == 0);
69 | jaccard = tp/length(iPUnion);
70 | jaccard = jaccard * 100;
71 |
72 | % res(end+1, :) = [iPhase jaccard];
73 | res(end+1) = jaccard;
74 |
75 | % Compute prec and rec
76 | indx = (gtLabelID == iPhase);
77 |
78 | sumTP = tp; % sum(predLabelID(indx) == iPhase);
79 | sumPred = sum(predLabelID == iPhase);
80 | sumGT = sum(indx);
81 |
82 | prec(end+1) = sumTP * 100 / sumPred;
83 | rec(end+1) = sumTP * 100 / sumGT;
84 | end
85 |
86 | % compute accuracy
87 | acc = sum(updatedDiff==0) / length(gtLabelID);
88 | acc = acc * 100;
89 |
90 | end
91 |
92 |
--------------------------------------------------------------------------------
/evaluation_matlab/Main.m:
--------------------------------------------------------------------------------
1 | close all; clear all;
2 |
3 | phaseGroundTruths = {};
4 | gt_root_folder = 'phase_annotations/';
5 | for k = 41:80
6 | num = num2str(k);
7 | to_add = ['video-' num];
8 | video_name = [gt_root_folder to_add '.txt'];
9 | phaseGroundTruths = [phaseGroundTruths video_name];
10 | end
11 | % phaseGroundTruths = {'video41-phase.txt', ...
12 | % 'video42-phase.txt'};
13 | % phaseGroundTruths
14 |
15 | phases = {'Preparation', 'CalotTriangleDissection', ...
16 | 'ClippingCutting', 'GallbladderDissection', 'GallbladderPackaging', 'CleaningCoagulation', ...
17 | 'GallbladderRetraction'};
18 |
19 | fps = 1;
20 | jaccard1 = zeros(7, 40);
21 | prec1 = zeros(7, 40);
22 | rec1 = zeros(7, 40);
23 | acc1 = zeros(1, 40);
24 | f11 = zeros(7, 40);
25 |
26 |
27 | for i = 1:length(phaseGroundTruths)
28 | predroot = 'prediction/';
29 | phaseGroundTruth = phaseGroundTruths{i};
30 | predFile = [predroot 'video-' phaseGroundTruth(end-5:end-4) '.txt'];
31 | [gt] = ReadPhaseLabel(phaseGroundTruth);
32 | [pred] = ReadPhaseLabel(predFile);
33 |
34 | if(size(gt{1}, 1) ~= size(pred{1},1) || size(gt{2}, 1) ~= size(pred{2},1))
35 |         error(['ERROR: ' phaseGroundTruth '\nGround truth and prediction have different sizes']);
36 | end
37 |
38 | if(~isempty(find(gt{1} ~= pred{1})))
39 |         error(['ERROR: ' phaseGroundTruth '\nThe frame index in ground truth and prediction is not equal']);
40 | end
41 |
42 | t = length(gt{2});
43 | for z=1:t
44 | gt{2}{z} = gt{2}{z}(1);
45 | pred{2}{z} = pred{2}{z}(1);
46 | end
47 | % reassigning the phase labels to numbers
48 | gtLabelID = [];
49 | predLabelID = [];
50 | for j = 1:7
51 | gtLabelID(find(strcmp(num2str(j-1), gt{2}))) = j;
52 | predLabelID(find(strcmp(num2str(j-1), pred{2}))) = j;
53 | end
54 |
55 | % compute jaccard index, precision, recall, and the accuracy
56 | [jaccard, prec, rec, acc, f1] = Evaluate(gtLabelID, predLabelID, fps);
57 | jaccard1(:, i) = jaccard;
58 | prec1(:, i) = prec;
59 | rec1(:, i) = rec;
60 | acc1(i) = acc;
61 | f11(:, i) = f1;
62 | end
63 |
64 | acc = acc1;
65 | rec = rec1;
66 | prec = prec1;
67 | jaccard = jaccard1;
68 | f1 = f11;
69 |
70 | accPerVideo= acc;
71 |
72 | % Compute means and stds
73 | index = find(jaccard>100);
74 | jaccard(index)=100;
75 | meanJaccPerPhase = nanmean(jaccard, 2);
76 | meanJaccPerVideo = nanmean(jaccard, 1);
77 | meanJacc = mean(meanJaccPerPhase);
78 | stdJacc = std(meanJaccPerPhase);
79 | for h = 1:7
80 | jaccphase = jaccard(h,:);
81 | meanjaccphase(h) = nanmean(jaccphase);
82 | stdjaccphase(h) = nanstd(jaccphase);
83 | end
84 |
85 | index = find(f1>100);
86 | f1(index)=100;
87 | meanF1PerPhase = nanmean(f1, 2);
88 | meanF1PerVideo = nanmean(f1, 1);
89 | meanF1 = mean(meanF1PerPhase);
90 | stdF1 = std(meanF1PerVideo);
91 | for h = 1:7
92 | f1phase = f1(h,:);
93 | meanf1phase(h) = nanmean(f1phase);
94 | stdf1phase(h) = nanstd(f1phase);
95 | end
96 |
97 | index = find(prec>100);
98 | prec(index)=100;
99 | meanPrecPerPhase = nanmean(prec, 2);
100 | meanPrecPerVideo = nanmean(prec, 1);
101 | meanPrec = nanmean(meanPrecPerPhase);
102 | stdPrec = nanstd(meanPrecPerPhase);
103 | for h = 1:7
104 | precphase = prec(h,:);
105 | meanprecphase(h) = nanmean(precphase);
106 | stdprecphase(h) = nanstd(precphase);
107 | end
108 |
109 | index = find(rec>100);
110 | rec(index)=100;
111 | meanRecPerPhase = nanmean(rec, 2);
112 | meanRecPerVideo = nanmean(rec, 1);
113 | meanRec = mean(meanRecPerPhase);
114 | stdRec = std(meanRecPerPhase);
115 | for h = 1:7
116 | recphase = rec(h,:);
117 | meanrecphase(h) = nanmean(recphase);
118 | stdrecphase(h) = nanstd(recphase);
119 | end
120 |
121 |
122 | meanAcc = mean(acc);
123 | stdAcc = std(acc);
124 |
125 | % Display results
126 | % fprintf('model is :%s\n', model_rootfolder);
127 | disp('================================================');
128 | disp([sprintf('%25s', 'Phase') '|' sprintf('%6s', 'Jacc') '|'...
129 | sprintf('%6s', 'Prec') '|' sprintf('%6s', 'Rec') '|']);
130 | disp('================================================');
131 | for iPhase = 1:length(phases)
132 | disp([sprintf('%25s', phases{iPhase}) '|' sprintf('%6.2f', meanJaccPerPhase(iPhase)) '|' ...
133 | sprintf('%6.2f', meanPrecPerPhase(iPhase)) '|' sprintf('%6.2f', meanRecPerPhase(iPhase)) '|']);
134 | disp('---------------------------------------------');
135 | end
136 | disp('================================================');
137 |
138 | disp(['Mean jaccard: ' sprintf('%5.2f', meanJacc) '+-' sprintf('%5.2f', stdJacc)]);
139 | disp(['Mean f1-score: ' sprintf('%5.2f', meanF1) '+-' sprintf('%5.2f', stdF1)]);
140 | disp(['Mean accuracy: ' sprintf('%5.2f', meanAcc) '+-' sprintf('%5.2f', stdAcc)]);
141 | disp(['Mean precision: ' sprintf('%5.2f', meanPrec) '+-' sprintf('%5.2f', stdPrec)]);
142 | disp(['Mean recall: ' sprintf('%5.2f', meanRec) '+-' sprintf('%5.2f', stdRec)]);
143 |
--------------------------------------------------------------------------------
/evaluation_matlab/Main_AutoLaparo.m:
--------------------------------------------------------------------------------
1 | close all; clear all;
2 |
3 | phaseGroundTruths = {};
4 | gt_root_folder = 'phase_annotations/';
5 | for k = 15:21
6 | num = num2str(k);
7 | to_add = ['video-' num];
8 | video_name = [gt_root_folder to_add '.txt'];
9 | phaseGroundTruths = [phaseGroundTruths video_name];
10 | end
11 | % phaseGroundTruths = {'video41-phase.txt', ...
12 | % 'video42-phase.txt'};
13 | % phaseGroundTruths
14 |
15 | phases = {'Preparation', 'CalotTriangleDissection', ... % NOTE: display names reused from the Cholec80 script; AutoLaparo's own phase names differ
16 | 'ClippingCutting', 'GallbladderDissection', 'GallbladderPackaging', 'CleaningCoagulation', ...
17 | 'GallbladderRetraction'};
18 |
19 | fps = 1;
20 | jaccard1 = zeros(7, 7);
21 | prec1 = zeros(7, 7);
22 | rec1 = zeros(7, 7);
23 | acc1 = zeros(1, 7);
24 | f11 = zeros(7, 7);
25 |
26 |
27 | for i = 1:length(phaseGroundTruths)
28 | predroot = 'prediction/';
29 | %predroot = '../../Results/multi/phase';
30 | %predroot = '../../Results/multi_kl_best_890_882/phase_post';
31 | phaseGroundTruth = phaseGroundTruths{i};
32 | predFile = [predroot 'video-' phaseGroundTruth(end-5:end-4) '.txt'];
33 | [gt] = ReadPhaseLabel(phaseGroundTruth);
34 | [pred] = ReadPhaseLabel(predFile);
35 |
36 | if(size(gt{1}, 1) ~= size(pred{1},1) || size(gt{2}, 1) ~= size(pred{2},1))
37 |         error(['ERROR: ' phaseGroundTruth '\nGround truth and prediction have different sizes']);
38 | end
39 |
40 | if(~isempty(find(gt{1} ~= pred{1})))
41 |         error(['ERROR: ' phaseGroundTruth '\nThe frame index in ground truth and prediction is not equal']);
42 | end
43 |
44 | t = length(gt{2});
45 | for z=1:t
46 | gt{2}{z} = gt{2}{z}(1);
47 | pred{2}{z} = pred{2}{z}(1);
48 | end
49 | % reassigning the phase labels to numbers
50 | gtLabelID = [];
51 | predLabelID = [];
52 | for j = 1:7
53 | gtLabelID(find(strcmp(num2str(j-1), gt{2}))) = j;
54 | predLabelID(find(strcmp(num2str(j-1), pred{2}))) = j;
55 | end
56 |
57 | % compute jaccard index, precision, recall, and the accuracy
58 | [jaccard, prec, rec, acc, f1] = Evaluate(gtLabelID, predLabelID, fps);
59 | jaccard1(:, i) = jaccard;
60 | prec1(:, i) = prec;
61 | rec1(:, i) = rec;
62 | acc1(i) = acc;
63 | f11(:, i) = f1;
64 | end
65 |
66 | acc = acc1;
67 | rec = rec1;
68 | prec = prec1;
69 | jaccard = jaccard1;
70 | f1 = f11;
71 |
72 | accPerVideo= acc;
73 |
74 | % Compute means and stds
75 | index = find(jaccard>100);
76 | jaccard(index)=100;
77 | meanJaccPerPhase = nanmean(jaccard, 2);
78 | meanJaccPerVideo = nanmean(jaccard, 1);
79 | meanJacc = mean(meanJaccPerPhase);
80 | stdJacc = std(meanJaccPerPhase);
81 | for h = 1:7
82 | jaccphase = jaccard(h,:);
83 | meanjaccphase(h) = nanmean(jaccphase);
84 | stdjaccphase(h) = nanstd(jaccphase);
85 | end
86 |
87 | index = find(f1>100);
88 | f1(index)=100;
89 | meanF1PerPhase = nanmean(f1, 2);
90 | meanF1PerVideo = nanmean(f1, 1);
91 | meanF1 = mean(meanF1PerPhase);
92 | stdF1 = std(meanF1PerVideo);
93 | for h = 1:7
94 | f1phase = f1(h,:);
95 | meanf1phase(h) = nanmean(f1phase);
96 | stdf1phase(h) = nanstd(f1phase);
97 | end
98 |
99 | index = find(prec>100);
100 | prec(index)=100;
101 | meanPrecPerPhase = nanmean(prec, 2);
102 | meanPrecPerVideo = nanmean(prec, 1);
103 | meanPrec = nanmean(meanPrecPerPhase);
104 | stdPrec = nanstd(meanPrecPerPhase);
105 | for h = 1:7
106 | precphase = prec(h,:);
107 | meanprecphase(h) = nanmean(precphase);
108 | stdprecphase(h) = nanstd(precphase);
109 | end
110 |
111 | index = find(rec>100);
112 | rec(index)=100;
113 | meanRecPerPhase = nanmean(rec, 2);
114 | meanRecPerVideo = nanmean(rec, 1);
115 | meanRec = mean(meanRecPerPhase);
116 | stdRec = std(meanRecPerPhase);
117 | for h = 1:7
118 | recphase = rec(h,:);
119 | meanrecphase(h) = nanmean(recphase);
120 | stdrecphase(h) = nanstd(recphase);
121 | end
122 |
123 |
124 | meanAcc = mean(acc);
125 | stdAcc = std(acc);
126 |
127 | % Display results
128 | % fprintf('model is :%s\n', model_rootfolder);
129 | disp('================================================');
130 | disp([sprintf('%25s', 'Phase') '|' sprintf('%6s', 'Jacc') '|'...
131 | sprintf('%6s', 'Prec') '|' sprintf('%6s', 'Rec') '|']);
132 | disp('================================================');
133 | for iPhase = 1:length(phases)
134 | disp([sprintf('%25s', phases{iPhase}) '|' sprintf('%6.2f', meanJaccPerPhase(iPhase)) '|' ...
135 | sprintf('%6.2f', meanPrecPerPhase(iPhase)) '|' sprintf('%6.2f', meanRecPerPhase(iPhase)) '|']);
136 | disp('---------------------------------------------');
137 | end
138 | disp('================================================');
139 |
140 | disp(['Mean jaccard: ' sprintf('%5.2f', meanJacc) '+-' sprintf('%5.2f', stdJacc)]);
141 | disp(['Mean f1-score: ' sprintf('%5.2f', meanF1) '+-' sprintf('%5.2f', stdF1)]);
142 | disp(['Mean accuracy: ' sprintf('%5.2f', meanAcc) '+-' sprintf('%5.2f', stdAcc)]);
143 | disp(['Mean precision: ' sprintf('%5.2f', meanPrec) '+-' sprintf('%5.2f', stdPrec)]);
144 | disp(['Mean recall: ' sprintf('%5.2f', meanRec) '+-' sprintf('%5.2f', stdRec)]);
--------------------------------------------------------------------------------
/evaluation_matlab/Main_Cataract101.m:
--------------------------------------------------------------------------------
1 | close all; clear all;
2 |
3 | phaseGroundTruths = {};
4 | gt_root_folder = 'phase_annotations/';
5 | for k = 110:149
6 | num = num2str(k);
7 | to_add = ['video-' num];
8 | video_name = [gt_root_folder to_add '.txt'];
9 | phaseGroundTruths = [phaseGroundTruths video_name];
10 | end
11 | % phaseGroundTruths = {'video41-phase.txt', ...
12 | % 'video42-phase.txt'};
13 | % phaseGroundTruths
14 |
15 | phases = {"Incision",
16 | "Viscous agent injection",
17 | "Rhexis",
18 | "Hydrodissection",
19 | "Phacoemulsificiation",
20 | "Irrigation and aspiration",
21 | "Capsule polishing",
22 | "Lens implant setting-up",
23 | "Viscous agent removal",
24 | "Tonifying and antibiotics"};
25 |
26 | fps = 1;
27 | jaccard1 = zeros(10, 40);
28 | prec1 = zeros(10, 40);
29 | rec1 = zeros(10, 40);
30 | acc1 = zeros(1, 40);
31 | f11 = zeros(10, 40);
32 |
33 |
34 | for i = 1:length(phaseGroundTruths)
35 | predroot = 'prediction/';
36 | phaseGroundTruth = phaseGroundTruths{i};
37 | predFile = [predroot 'video-' phaseGroundTruth(end-6:end-4) '.txt'];
38 | [gt] = ReadPhaseLabel(phaseGroundTruth);
39 | [pred] = ReadPhaseLabel(predFile);
40 |
41 | if(size(gt{1}, 1) ~= size(pred{1},1) || size(gt{2}, 1) ~= size(pred{2},1))
42 |         error(['ERROR: ' phaseGroundTruth '\nGround truth and prediction have different sizes']);
43 | end
44 |
45 | if(~isempty(find(gt{1} ~= pred{1})))
46 |         error(['ERROR: ' phaseGroundTruth '\nThe frame index in ground truth and prediction is not equal']);
47 | end
48 |
49 | t = length(gt{2});
50 | for z=1:t
51 | gt{2}{z} = gt{2}{z}(1);
52 | pred{2}{z} = pred{2}{z}(1);
53 | end
54 | % reassigning the phase labels to numbers
55 | gtLabelID = [];
56 | predLabelID = [];
57 | for j = 1:10
58 | gtLabelID(find(strcmp(num2str(j-1), gt{2}))) = j;
59 | predLabelID(find(strcmp(num2str(j-1), pred{2}))) = j;
60 | end
61 |
62 | % compute jaccard index, precision, recall, and the accuracy
63 | [jaccard, prec, rec, acc, f1] = Evaluate_Cataract101(gtLabelID, predLabelID, fps);
64 | jaccard1(:, i) = jaccard;
65 | prec1(:, i) = prec;
66 | rec1(:, i) = rec;
67 | acc1(i) = acc;
68 | f11(:, i) = f1;
69 | end
70 |
71 | acc = acc1;
72 | rec = rec1;
73 | prec = prec1;
74 | jaccard = jaccard1;
75 | f1 = f11;
76 |
77 | accPerVideo= acc;
78 |
79 | % Compute means and stds
80 | index = find(jaccard>100);
81 | jaccard(index)=100;
82 | meanJaccPerPhase = nanmean(jaccard, 2);
83 | meanJaccPerVideo = nanmean(jaccard, 1);
84 | meanJacc = mean(meanJaccPerPhase);
85 | stdJacc = std(meanJaccPerPhase);
86 | for h = 1:10
87 | jaccphase = jaccard(h,:);
88 | meanjaccphase(h) = nanmean(jaccphase);
89 | stdjaccphase(h) = nanstd(jaccphase);
90 | end
91 |
92 | index = find(f1>100);
93 | f1(index)=100;
94 | meanF1PerPhase = nanmean(f1, 2);
95 | meanF1PerVideo = nanmean(f1, 1);
96 | meanF1 = mean(meanF1PerPhase);
97 | stdF1 = std(meanF1PerPhase);
98 | for h = 1:10
99 | f1phase = f1(h,:);
100 | meanf1phase(h) = nanmean(f1phase);
101 | stdf1phase(h) = nanstd(f1phase);
102 | end
103 |
104 | index = find(prec>100);
105 | prec(index)=100;
106 | meanPrecPerPhase = nanmean(prec, 2);
107 | meanPrecPerVideo = nanmean(prec, 1);
108 | meanPrec = nanmean(meanPrecPerPhase);
109 | stdPrec = nanstd(meanPrecPerPhase);
110 | for h = 1:10
111 | precphase = prec(h,:);
112 | meanprecphase(h) = nanmean(precphase);
113 | stdprecphase(h) = nanstd(precphase);
114 | end
115 |
116 | index = find(rec>100);
117 | rec(index)=100;
118 | meanRecPerPhase = nanmean(rec, 2);
119 | meanRecPerVideo = nanmean(rec, 1);
120 | meanRec = mean(meanRecPerPhase);
121 | stdRec = std(meanRecPerPhase);
122 | for h = 1:10
123 | recphase = rec(h,:);
124 | meanrecphase(h) = nanmean(recphase);
125 | stdrecphase(h) = nanstd(recphase);
126 | end
127 |
128 |
129 | meanAcc = mean(acc);
130 | stdAcc = std(acc);
131 |
132 | % Display results
133 | % fprintf('model is :%s\n', model_rootfolder);
134 | disp('================================================');
135 | disp([sprintf('%25s', 'Phase') '|' sprintf('%6s', 'Jacc') '|'...
136 | sprintf('%6s', 'Prec') '|' sprintf('%6s', 'Rec') '|']);
137 | disp('================================================');
138 | for iPhase = 1:length(phases)
139 | disp([sprintf('%25s', phases{iPhase}) '|' sprintf('%6.2f', meanJaccPerPhase(iPhase)) '|' ...
140 | sprintf('%6.2f', meanPrecPerPhase(iPhase)) '|' sprintf('%6.2f', meanRecPerPhase(iPhase)) '|']);
141 | disp('---------------------------------------------');
142 | end
143 | disp('================================================');
144 |
145 | disp(['Mean jaccard: ' sprintf('%5.2f', meanJacc) '+-' sprintf('%5.2f', stdJacc)]);
146 | disp(['Mean f1-score: ' sprintf('%5.2f', meanF1) '+-' sprintf('%5.2f', stdF1)]);
147 | disp(['Mean accuracy: ' sprintf('%5.2f', meanAcc) '+-' sprintf('%5.2f', stdAcc)]);
148 | disp(['Mean precision: ' sprintf('%5.2f', meanPrec) '+-' sprintf('%5.2f', stdPrec)]);
149 | disp(['Mean recall: ' sprintf('%5.2f', meanRec) '+-' sprintf('%5.2f', stdRec)]);
150 |
--------------------------------------------------------------------------------
/evaluation_matlab/Main_m2cai.m:
--------------------------------------------------------------------------------
1 | close all; clear all;
2 |
3 | phaseGroundTruths = {};
4 | gt_root_folder = '../gt-phase/';
5 | for k = 1:14
6 | num = num2str(k);
7 | to_add = ['video' num];
8 | video_name = [gt_root_folder to_add '-phase.txt'];
9 | phaseGroundTruths = [phaseGroundTruths video_name];
10 | end
11 | % phaseGroundTruths = {'video41-phase.txt', ...
12 | % 'video42-phase.txt'};
13 | % phaseGroundTruths
14 |
15 | phases = {'TrocarPlacement', 'Preparation', 'CalotTriangleDissection', ...
16 | 'ClippingCutting', 'GallbladderDissection', 'GallbladderPackaging', 'CleaningCoagulation', ...
17 | 'GallbladderRetraction'};
18 |
19 | fps = 25;
20 |
21 | for i = 1:length(phaseGroundTruths)
22 | predroot = '../phase/';
23 | %predroot = '../../Results/multi/phase';
24 | %predroot = '../../Results/multi_kl_best_890_882/phase_post';
25 | phaseGroundTruth = phaseGroundTruths{i};
26 | predFile = [predroot phaseGroundTruth(13:end-10) '-phase.txt'];
27 |
28 | [gt] = ReadPhaseLabel(phaseGroundTruth);
29 | [pred] = ReadPhaseLabel(predFile);
30 |
31 | if(size(gt{1}, 1) ~= size(pred{1},1) || size(gt{2}, 1) ~= size(pred{2},1))
32 |         error(['ERROR: ' phaseGroundTruth '\nGround truth and prediction have different sizes']);
33 | end
34 |
35 | if(~isempty(find(gt{1} ~= pred{1})))
36 |         error(['ERROR: ' phaseGroundTruth '\nThe frame index in ground truth and prediction is not equal']);
37 | end
38 |
39 | % reassigning the phase labels to numbers
40 | gtLabelID = [];
41 | predLabelID = [];
42 | for j = 1:8
43 | gtLabelID(find(strcmp(num2str(j-1), gt{2}))) = j;
44 | predLabelID(find(strcmp(num2str(j-1), pred{2}))) = j;
45 | end
46 |
47 | % compute jaccard index, precision, recall, and the accuracy
48 | [jaccard(:,i), prec(:,i), rec(:,i), acc(i)] = Evaluate_m2cai(gtLabelID, predLabelID, fps);
49 |
50 | end
51 |
52 | % Compute means and stds
53 | index = find(jaccard>100);
54 | jaccard(index)=100;
55 | meanJaccPerPhase = nanmean(jaccard, 2);
56 | meanJacc = mean(meanJaccPerPhase);
57 | stdJacc = std(meanJaccPerPhase);
58 | for h = 1:8
59 | jaccphase = jaccard(h,:);
60 | meanjaccphase(h) = nanmean(jaccphase);
61 | stdjaccphase(h) = nanstd(jaccphase);
62 | end
63 |
64 | index = find(prec>100);
65 | prec(index)=100;
66 | meanPrecPerPhase = nanmean(prec, 2);
67 | meanPrec = nanmean(meanPrecPerPhase);
68 | stdPrec = nanstd(meanPrecPerPhase);
69 | for h = 1:8
70 | precphase = prec(h,:);
71 | meanprecphase(h) = nanmean(precphase);
72 | stdprecphase(h) = nanstd(precphase);
73 | end
74 |
75 | index = find(rec>100);
76 | rec(index)=100;
77 | meanRecPerPhase = nanmean(rec, 2);
78 | meanRec = mean(meanRecPerPhase);
79 | stdRec = std(meanRecPerPhase);
80 | for h = 1:8
81 | recphase = rec(h,:);
82 | meanrecphase(h) = nanmean(recphase);
83 | stdrecphase(h) = nanstd(recphase);
84 | end
85 |
86 |
87 | meanAcc = mean(acc);
88 | stdAcc = std(acc);
89 |
90 | % Display results
91 | % fprintf('model is :%s\n', model_rootfolder);
92 | disp('================================================');
93 | disp([sprintf('%25s', 'Phase') '|' sprintf('%6s', 'Jacc') '|'...
94 | sprintf('%6s', 'Prec') '|' sprintf('%6s', 'Rec') '|']);
95 | disp('================================================');
96 | for iPhase = 1:length(phases)
97 | disp([sprintf('%25s', phases{iPhase}) '|' sprintf('%6.2f', meanJaccPerPhase(iPhase)) '|' ...
98 | sprintf('%6.2f', meanPrecPerPhase(iPhase)) '|' sprintf('%6.2f', meanRecPerPhase(iPhase)) '|']);
99 | disp('---------------------------------------------');
100 | end
101 | disp('================================================');
102 |
103 | disp(['Mean jaccard: ' sprintf('%5.2f', meanJacc) ' +- ' sprintf('%5.2f', stdJacc)]);
104 | disp(['Mean accuracy: ' sprintf('%5.2f', meanAcc) ' +- ' sprintf('%5.2f', stdAcc)]);
105 | disp(['Mean precision: ' sprintf('%5.2f', meanPrec) ' +- ' sprintf('%5.2f', stdPrec)]);
106 | disp(['Mean recall: ' sprintf('%5.2f', meanRec) ' +- ' sprintf('%5.2f', stdRec)]);
107 |
--------------------------------------------------------------------------------
/evaluation_matlab/README.md:
--------------------------------------------------------------------------------
1 | ### MATLAB scripts to perform the evaluation
2 | 
3 | ### Acknowledgement
4 | These scripts follow the evaluation protocol of the MICCAI M2CAI challenge; the official webpage of the challenge is http://camma.u-strasbg.fr/m2cai2016.
--------------------------------------------------------------------------------
/evaluation_matlab/ReadPhaseLabel.m:
--------------------------------------------------------------------------------
1 | function [ outp ] = ReadPhaseLabel( file )
2 | %READPHASELABEL
3 | % Read the phase label (annotation and prediction)
4 |
5 | fid = fopen(file, 'r');
6 |
7 | % read the header first
8 | tline = fgets(fid);
9 |
10 | % read the labels
11 | [outp] = textscan(fid, '%d\t%s\n', 'EndOfLine','\n' );
12 | % [outp] = textscan(fid, '%d\t%s\n');
13 | fclose(fid);
14 | end
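
% Hedged usage sketch (file layout assumed from the textscan format above:
% a one-line header, then tab-separated "frameIndex<TAB>phaseLabel" rows):
%   labels = ReadPhaseLabel('phase_annotations/video-15.txt');  % hypothetical path
%   frameIdx = labels{1};   % numeric frame indices
%   phaseStr = labels{2};   % cell array of phase label strings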
--------------------------------------------------------------------------------
/evaluation_matlab/matlab.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/evaluation_matlab/matlab.mat
--------------------------------------------------------------------------------
/evaluation_matlab/octave-workspace:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/evaluation_matlab/octave-workspace
--------------------------------------------------------------------------------
/evaluation_matlab/test.m:
--------------------------------------------------------------------------------
1 | % Scratch check of the label parser; set `file` to a phase-label path first.
2 | fid = fopen(file, 'r');
3 | tline = fgets(fid);  % skip the header line
4 | [outp] = textscan(fid, '%d %s\n');
5 | fclose(fid);
--------------------------------------------------------------------------------
/model/__pycache__/mambapy.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/model/__pycache__/mambapy.cpython-311.pyc
--------------------------------------------------------------------------------
/model/__pycache__/pmnet.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/model/__pycache__/pmnet.cpython-311.pyc
--------------------------------------------------------------------------------
/model/__pycache__/pscan.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RascalGdd/PmNet/4e4b71c7b521cb2fa1f078014f68ac4909f299dd/model/__pycache__/pscan.cpython-311.pyc
--------------------------------------------------------------------------------
/model/pmnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from functools import partial
4 | import utils
5 | from timm.models.layers import drop_path, to_2tuple, trunc_normal_
6 | from timm.models.registry import register_model
7 | from einops import rearrange
8 | from collections import OrderedDict
9 | import torch.nn.functional as F
10 | import torchvision
11 | from model.mambapy import Mamba_CSM, MambaConfig
12 |
13 | def _cfg(url="", **kwargs):
14 | return {
15 | "url": url,
16 | "num_classes": 7,
17 | "input_size": (3, 224, 224),
18 | "pool_size": None,
19 | "crop_pct": 0.9,
20 | "interpolation": "bicubic",
21 | "mean": (0.5, 0.5, 0.5),
22 | "std": (0.5, 0.5, 0.5),
23 | **kwargs,
24 | }
25 |
26 |
27 | def crop_and_pool(images, bboxes):
28 | """
29 | 使用 bbox 裁剪图片并进行平均池化。
30 |
31 | 参数:
32 | images: 形状为 (B, 3, T, 224, 224) 的张量。
33 | bboxes: 形状为 (B, T, 2, 2) 的张量,表示每个时间步的 bbox。
34 |
35 | 返回:
36 | 形状为 (B, 3, T, 1, 1) 的张量。
37 | """
38 | B, C, T, H, W = images.shape
39 | output = torch.zeros(B, C, T, 1, 1, device=images.device) # 初始化输出张量
40 |
41 | for b in range(B): # 遍历 batch
42 | for t in range(T): # 遍历时间步
43 | # 获取当前时间步的 bbox
44 | x1, y1 = bboxes[b, t, 0, 0], bboxes[b, t, 0, 1]
45 | x2, y2 = bboxes[b, t, 1, 0], bboxes[b, t, 1, 1]
46 |
47 | # 裁剪图片
48 | cropped = images[b, :, t, y1:y2, x1:x2] # 形状为 (3, h, w)
49 | # 对裁剪后的图片进行全局平均池化
50 | pooled = F.avg_pool2d(cropped.unsqueeze(0), (cropped.shape[1], cropped.shape[2])) # 形状为 (1, 3, 1, 1)
51 |
52 | # 将结果存入输出张量
53 | output[b, :, t] = pooled.squeeze(0)
54 |
55 | return output
56 |
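# Hedged usage sketch for crop_and_pool (toy tensors; bbox corners are assumed
# to be integer (x, y) pixel coordinates with x1 < x2 and y1 < y2):
#   imgs = torch.rand(1, 3, 2, 224, 224)
#   boxes = torch.tensor([[[[10, 20], [110, 120]],
#                          [[30, 40], [130, 140]]]])  # (B=1, T=2, 2, 2)
#   crop_and_pool(imgs, boxes).shape  # torch.Size([1, 3, 2, 1, 1])
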
57 | class DropPath(nn.Module):
58 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
59 |
60 | def __init__(self, drop_prob=None):
61 | super(DropPath, self).__init__()
62 | self.drop_prob = drop_prob
63 |
64 | def forward(self, x):
65 | return drop_path(x, self.drop_prob, self.training)
66 |
67 | def extra_repr(self) -> str:
68 | return "p={}".format(self.drop_prob)
69 |
70 |
71 | class Mlp(nn.Module):
72 | def __init__(
73 | self,
74 | in_features,
75 | hidden_features=None,
76 | out_features=None,
77 | act_layer=nn.GELU,
78 | drop=0.0,
79 | ):
80 | super().__init__()
81 | out_features = out_features or in_features
82 | hidden_features = hidden_features or in_features
83 | self.fc1 = nn.Linear(in_features, hidden_features)
84 | self.act = act_layer()
85 | self.fc2 = nn.Linear(hidden_features, out_features)
86 | self.drop = nn.Dropout(drop)
87 |
88 | def forward(self, x):
89 | x = self.fc1(x)
90 | x = self.act(x)
91 | x = self.drop(x)
92 | x = self.fc2(x)
93 | x = self.drop(x)
94 | return x
95 |
96 | class CrossAttention(nn.Module):
97 | def __init__(
98 | self,
99 | dim,
100 | num_heads=8,
101 | qkv_bias=False,
102 | qk_scale=None,
103 | attn_drop=0.0,
104 | proj_drop=0.0,
105 | ):
106 | super().__init__()
107 | self.num_heads = num_heads
108 | head_dim = dim // num_heads
109 | self.scale = qk_scale or head_dim**-0.5
110 |
111 | # Separate Linear layers for Query, Key, and Value
112 | self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
113 | self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)
114 | self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)
115 |
116 | self.proj = nn.Linear(dim, dim)
117 | self.proj_drop = nn.Dropout(proj_drop)
118 | self.attn_drop = nn.Dropout(attn_drop)
119 |
120 | def forward(self, query, key_value, B):
121 | """
122 | Args:
123 | query: Tensor of shape (B * K, T_query, C)
124 | key_value: Tensor of shape (B * K, T_key, C)
125 | B: Batch size
126 | """
127 | BK_query, T_query, C = query.shape
128 | BK_key, T_key, _ = key_value.shape
129 | K_query = BK_query // B
130 | K_key = BK_key // B
131 |
132 | # Generate Query, Key, Value
133 | q = self.q_proj(query) # (B * K_query, T_query, C)
134 | k = self.k_proj(key_value) # (B * K_key, T_key, C)
135 | v = self.v_proj(key_value) # (B * K_key, T_key, C)
136 |
137 | # Reshape for multi-head attention
138 | q = rearrange(q, "(b k) t (num_heads c) -> (b k) num_heads t c", k=K_query, num_heads=self.num_heads)
139 | k = rearrange(k, "(b k) t (num_heads c) -> (b k) num_heads t c", k=K_key, num_heads=self.num_heads)
140 | v = rearrange(v, "(b k) t (num_heads c) -> (b k) num_heads t c", k=K_key, num_heads=self.num_heads)
141 |
142 | # Compute attention scores
143 | attn = (q @ k.transpose(-2, -1)) * self.scale # (B * K_query, num_heads, T_query, T_key)
144 | attn = attn.softmax(dim=-1)
145 | attn = self.attn_drop(attn)
146 |
147 | # Compute attention output
148 | x = attn @ v # (B * K_query, num_heads, T_query, C_head)
149 | x = rearrange(x, "(b k) num_heads t c -> (b k) t (num_heads c)", b=B)
150 |
151 | # Project back to original dimension
152 | x = self.proj(x)
153 | return self.proj_drop(x)
154 |
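# Hedged shape walk-through for CrossAttention (toy sizes, not taken from the
# model config): one query token attending over a 20-step memory.
#   ca = CrossAttention(dim=64, num_heads=8)
#   out = ca(torch.rand(2, 1, 64), torch.rand(2, 20, 64), B=2)
#   out.shape  # torch.Size([2, 1, 64])
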
155 | class Attention_Spatial(nn.Module):
156 | def __init__(
157 | self,
158 | dim,
159 | num_heads=8,
160 | qkv_bias=False,
161 | qk_scale=None,
162 | attn_drop=0.0,
163 | proj_drop=0.0,
164 | with_qkv=True,
165 | ):
166 | super().__init__()
167 | self.num_heads = num_heads
168 | head_dim = dim // num_heads
169 | self.scale = qk_scale or head_dim**-0.5
170 | self.with_qkv = with_qkv
171 | if self.with_qkv:
172 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
173 | self.proj = nn.Linear(dim, dim)
174 | self.proj_drop = nn.Dropout(proj_drop)
175 | self.attn_drop = nn.Dropout(attn_drop)
176 |
177 | def forward(self, x, B):
178 | BT, K, C = x.shape
179 | T = BT // B
180 | qkv = self.qkv(x)
181 | # For Intra-Spatial: (BT, heads, K, C)
182 | # Atten: K*K, Values: K*C
183 | qkv = rearrange(
184 | qkv,
185 | "(b t) k (qkv num_heads c) -> qkv (b t) num_heads k c",
186 | t=T,
187 | qkv=3,
188 | num_heads=self.num_heads,
189 | )
190 | q, k, v = (
191 | qkv[0],
192 | qkv[1],
193 | qkv[2],
194 | ) # make torchscript happy (cannot use tensor as tuple)
195 |
196 | attn = (q @ k.transpose(-2, -1)) * self.scale
197 | attn = attn.softmax(dim=-1)
198 | attn = self.attn_drop(attn)
199 |
200 | x = attn @ v
201 | x = rearrange(
202 | x,
203 | "(b t) num_heads k c -> (b t) k (num_heads c)",
204 | b=B,
205 | )
206 | x = self.proj(x)
207 | return self.proj_drop(x)
208 |
209 |
210 | class Attention_Temporal(nn.Module):
211 | def __init__(
212 | self,
213 | dim,
214 | num_heads=8,
215 | qkv_bias=False,
216 | qk_scale=None,
217 | attn_drop=0.0,
218 | proj_drop=0.0,
219 | with_qkv=True,
220 | ):
221 | super().__init__()
222 | self.num_heads = num_heads
223 | head_dim = dim // num_heads
224 | self.scale = qk_scale or head_dim**-0.5
225 | self.with_qkv = with_qkv
226 | if self.with_qkv:
227 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
228 | self.proj = nn.Linear(dim, dim)
229 | self.proj_drop = nn.Dropout(proj_drop)
230 | self.attn_drop = nn.Dropout(attn_drop)
231 |
232 | def forward(self, x, B):
233 | BK, T, C = x.shape
234 | K = BK // B
235 | qkv = self.qkv(x)
236 |
237 | # For Intra-Spatial: (BK, heads, T, C)
238 | # Atten: T*T, Values: T*C
239 | qkv = rearrange(
240 | qkv,
241 | "(b k) t (qkv num_heads c) -> qkv (b k) num_heads t c",
242 | k=K,
243 | qkv=3,
244 | num_heads=self.num_heads,
245 | )
246 | q, k, v = (
247 | qkv[0],
248 | qkv[1],
249 | qkv[2],
250 | ) # make torchscript happy (cannot use tensor as tuple)
251 |
252 | attn = (q @ k.transpose(-2, -1)) * self.scale
253 | attn = attn.softmax(dim=-1)
254 | attn = self.attn_drop(attn)
255 |
256 | x = attn @ v
257 | x = rearrange(
258 | x,
259 | "(b k) num_heads t c -> (b k) t (num_heads c)",
260 | b=B,
261 | )
262 |
263 | x = self.proj(x)
264 | return self.proj_drop(x)
265 |
266 | class VisionTransformer(nn.Module):
267 | """Vision Transformer"""
268 |
269 | def __init__(
270 | self,
271 | img_size=224,
272 | patch_size=16,
273 | in_chans=3,
274 | num_classes=7,
275 | num_heads=12,
276 | qkv_bias=False,
277 | qk_scale=None,
278 | fc_drop_rate=0.0,
279 | drop_rate=0.0,
280 | attn_drop_rate=0.0,
281 | drop_path_rate=0.0,
282 | norm_layer=nn.LayerNorm,
283 | all_frames=16,
284 | ):
285 | super().__init__()
286 | embed_dim = 1536
287 | self.num_classes = num_classes
288 | self.num_features = (
289 | self.embed_dim
290 | ) = embed_dim # num_features for consistency with other models
291 |         backbone = torchvision.models.efficientnet_b3(weights=torchvision.models.EfficientNet_B3_Weights.DEFAULT)
292 | self.backbone = nn.Sequential(*list(backbone.children())[:-1])
293 | self.reduce_dim = nn.Linear(1536, 1536)
294 | self.NumberToVector = nn.Linear(1, 1536)
295 | self.norm1 = norm_layer(embed_dim)
296 |
297 | # knot and release feats
298 | self.act_feats = torch.zeros(1536).cuda()
299 | self.release_feats = torch.zeros(1536).cuda()
300 | self.knot_feats = torch.zeros(1536).cuda()
301 | self.act_cnt = torch.tensor(0, device='cuda')
302 | self.release_cnt = torch.tensor(0, device='cuda')
303 | self.knot_cnt = torch.tensor(0, device='cuda')
304 | self.alpha = 0.9
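        # EMA prototype update applied in forward():
        # proto <- alpha * proto + (1 - alpha) * mean(features of matching targets)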
305 |
306 | # Temporal Attention Parameters
307 | self.temporal_norm1 = norm_layer(embed_dim)
308 | self.temporal_attn = Attention_Temporal(
309 | embed_dim,
310 | num_heads=num_heads,
311 | qkv_bias=qkv_bias,
312 | qk_scale=qk_scale,
313 | attn_drop=0.0,
314 | proj_drop=0.0,
315 | )
316 | self.temporal_fc = nn.Linear(embed_dim, embed_dim)
317 |
318 | # SSM
319 | self.ca_norm = norm_layer(embed_dim)
320 | self.ca_attn = CrossAttention(
321 | embed_dim,
322 | num_heads=num_heads,
323 | qkv_bias=qkv_bias,
324 | qk_scale=qk_scale,
325 | attn_drop=0.0,
326 | proj_drop=0.0,
327 | )
328 | config = MambaConfig(d_model=self.embed_dim, n_layers=2)
329 | self.ssm = Mamba_CSM(config)
330 | self.ca_fc = nn.Linear(embed_dim, embed_dim)
331 |
332 | ## Positional Embeddings
333 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
334 | self.cls_token_swap = nn.Parameter(torch.zeros(1, 5, 1, embed_dim))
335 | self.time_embed = nn.Parameter(torch.zeros(1, all_frames, embed_dim))
336 | self.time_drop = nn.Dropout(p=drop_rate)
337 | self.mask = nn.Parameter(torch.zeros(embed_dim))
338 |
339 |
340 | self.norm = norm_layer(embed_dim)
341 |
342 | # Classifier head
343 | self.fc_dropout = (
344 | nn.Dropout(p=fc_drop_rate) if fc_drop_rate > 0 else nn.Identity()
345 | )
346 | self.head = (
347 | nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
348 | )
349 | self.fc_dropout_blocking = (
350 | nn.Dropout(p=fc_drop_rate) if fc_drop_rate > 0 else nn.Identity()
351 | )
352 | self.head_blocking = (
353 | nn.Linear(embed_dim, 2) if num_classes > 0 else nn.Identity()
354 | )
355 | trunc_normal_(self.cls_token, std=0.02)
356 | self.apply(self._init_weights)
357 |
358 | def _init_weights(self, m):
359 | if isinstance(m, nn.Linear):
360 | trunc_normal_(m.weight, std=0.02)
361 | if isinstance(m, nn.Linear) and m.bias is not None:
362 | nn.init.constant_(m.bias, 0)
363 | elif isinstance(m, nn.LayerNorm):
364 | nn.init.constant_(m.bias, 0)
365 | nn.init.constant_(m.weight, 1.0)
366 |
367 | @torch.jit.ignore
368 | def no_weight_decay(self):
369 | return {"pos_embed", "cls_token", "time_embed"}
370 |
371 | def get_classifier(self):
372 | return self.head
373 |
374 | def reset_classifier(self, num_classes, global_pool=""):
375 | self.num_classes = num_classes
376 | self.head = (
377 | nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
378 | )
379 |
380 | def forward(self, x, timestamp, bboxes, target=None):
381 | B, _, T, H, W = x.size()
382 |
383 | # pooling for RGB img/crop
384 | pooled_q = F.adaptive_avg_pool2d(x, (1, 1))
385 | pooled_q += crop_and_pool(x, bboxes)
386 | ssm_q = pooled_q.squeeze(-1).squeeze(-1).permute(0, 2, 1)
387 |
388 | x = rearrange(x, "b c t h w -> (b t) c h w")
389 | x = self.backbone(x)
390 | x = torch.squeeze(x)
391 | x = self.reduce_dim(x)
392 | x = rearrange(x, "(b t) c -> b t c", b=B)
393 | xt = x + self.time_embed # B, T, C
394 | timestamp_embed = self.NumberToVector(timestamp.unsqueeze(1))
395 | xt = xt + timestamp_embed.unsqueeze(1)
396 | xt = self.time_drop(xt)
397 |
398 |         # temporal pooling: depthwise 3-tap moving average along time (float16 kernel assumes half-precision activations)
399 | kernel = torch.tensor([1 / 3, 1 / 3, 1 / 3], dtype=torch.float16).view(1, 1, -1)
400 | kernel = kernel.repeat(1536, 1, 1).to(xt.device)
401 | output = F.conv1d(xt.permute(0, 2, 1), kernel, padding=1, groups=1536)
402 | pooled = output.permute(0, 2, 1) # (B, N, C)
403 | pooled_last_group_0 = pooled[:, -1, :]
404 |
405 | # add masks
406 | if self.act_cnt != 0:
407 | dot_product_b = torch.matmul(pooled_last_group_0, self.act_feats.to(torch.float16))
408 | norm_a = torch.norm(pooled_last_group_0, p=2, dim=1)
409 | norm_act = torch.norm(self.act_feats, p=2)
410 |
411 | similarity_act = dot_product_b / (norm_a * norm_act + 1e-8)
412 | similarity_idx = (similarity_act < 0)
413 | time_idx = (timestamp > 0.5)
414 | mask_idx = similarity_idx * time_idx
415 | if True in mask_idx:
416 | xt[mask_idx, -1, :] += self.mask
417 |
418 | groups = rearrange(xt, "b (n k) c -> b n k c", k=4)
419 | cls_token = self.cls_token_swap.expand(xt.size(0), -1, -1, -1)
420 | groups_with_cls = torch.cat([cls_token, groups], dim=2)
421 |
422 | num_interactions = 4 # 4 interactions for temporal integration
423 | for _ in range(num_interactions):
424 | groups_reshaped = rearrange(groups_with_cls, "b n k c -> (b n) k c")
425 | updated_groups = self.temporal_attn(self.temporal_norm1(groups_reshaped), B)
426 | updated_groups = self.temporal_fc(updated_groups) + groups_reshaped
427 | groups_with_cls = rearrange(updated_groups, "(b n) k c -> b n k c", b=B)
428 |
429 | # CLS token exchange
430 | cls_tokens = groups_with_cls[:, :, 0, :]
431 | cls_tokens = cls_tokens.roll(shifts=1, dims=1)
432 | groups_with_cls[:, :, 0, :] = cls_tokens
433 |
434 | groups_without_cls = groups_with_cls[:, :, 1:, :]
435 |
436 | last_group_feats = groups_without_cls[:, -1, :, :]
437 | pooled_last_group = last_group_feats[:, -1, :]
438 |
439 | # contrastive learning update prototype
440 | if self.training:
441 | act_idx = ((target == 1) + (target == 3))
442 | knot_idx = (target == 1)
443 | release_idx = (target == 3)
444 | if True in act_idx:
445 | act_cnt = sum(act_idx)
446 | act_feat = pooled_last_group[act_idx]
447 | summed_act = act_feat.sum(dim=0)
448 | # self.act_feats = self.act_feats * (self.act_cnt/(self.act_cnt+act_cnt)) + summed_act/(self.act_cnt+act_cnt)
449 | self.act_feats = self.act_feats * self.alpha + (1 - self.alpha) * (summed_act / act_cnt)
450 | self.act_cnt += act_cnt
451 |
452 | if True in knot_idx:
453 | knot_cnt = sum(knot_idx)
454 | knot_feat = pooled_last_group[knot_idx]
455 | summed_knot = knot_feat.sum(dim=0)
456 | self.knot_feats = self.knot_feats * self.alpha + (1 - self.alpha) * (summed_knot / knot_cnt)
457 | self.knot_cnt += knot_cnt
458 |
459 | if True in release_idx:
460 | release_cnt = sum(release_idx)
461 | release_feat = pooled_last_group[release_idx]
462 | summed_release = release_feat.sum(dim=0)
463 | self.release_feats = self.release_feats * self.alpha + (1 - self.alpha) * (summed_release / release_cnt)
464 | self.release_cnt += release_cnt
465 |
466 |         xt = rearrange(groups_without_cls, "b n k c -> b (n k) c")  # (B, 20, C), flatten the groups back into one sequence
467 |
468 | ssm_xt = xt
469 | ssm_out = self.ssm(ssm_xt, ssm_q)
470 |
471 | xt = xt[:, -1, :] + x[:, -1, :]
472 |
473 |         fused_xt = self.ca_attn(self.ca_norm(xt.unsqueeze(1)), self.ca_norm(ssm_out), B)  # output: (B*4, 5, C)
474 | fused_xt = self.ca_fc(fused_xt)[:, -1, :] + xt
475 |
476 | output = self.head(self.fc_dropout(fused_xt))
477 | output_blocking = self.head_blocking(self.fc_dropout_blocking(fused_xt))
478 |
479 | output_idx = output.argmax(-1)
480 | knot_FP = ((target == 3) * (output_idx == 1))
481 | release_FP = ((target == 1) * (output_idx == 3))
482 | contrastive_dict = {}
483 | if True in knot_FP:
484 | knot_FP_feats = pooled_last_group[knot_FP]
485 | contrastive_dict['knot_FP_feats'] = knot_FP_feats
486 | if True in release_FP:
487 | release_FP_feats = pooled_last_group[release_FP]
488 | contrastive_dict['release_FP_feats'] = release_FP_feats
489 |
490 | return output, output_blocking, contrastive_dict
491 |
492 | @register_model
493 | def pmnet(pretrained=False, pretrain_path=None, **kwargs):
494 | model = VisionTransformer(
495 | img_size=224,
496 | patch_size=16,
497 | num_heads=12,
498 | qkv_bias=True,
499 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
500 | **kwargs,
501 | )
502 | model.default_cfg = _cfg()
503 |
504 | return model
--------------------------------------------------------------------------------
/model/pscan.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn.functional as F
5 |
6 | """
7 |
8 | An implementation of the parallel scan operation in PyTorch (Blelloch version).
9 | Please see docs/pscan.ipynb for a detailed explanation of what happens here.
10 |
11 | """
12 |
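# A minimal sequential reference of the recurrence that PScan computes in
# parallel (a sketch for intuition; this helper is not part of the original
# module and is never called by it):
def sequential_scan_reference(A, X):
    """Naive O(L) scan over (B, L, D, N) tensors: H[t] = A[t] * H[t-1] + X[t], H[-1] = 0."""
    H = torch.zeros_like(X)
    h = torch.zeros_like(X[:, 0])
    for t in range(X.size(1)):
        h = A[:, t] * h + X[:, t]  # elementwise gated accumulation
        H[:, t] = h
    return H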
13 |
14 | def npo2(len):
15 | """
16 |     Returns the smallest power of 2 greater than or equal to len
17 | """
18 |
19 | return 2 ** math.ceil(math.log2(len))
20 |
21 |
22 | def pad_npo2(X):
23 | """
24 | Pads input length dim to the next power of 2
25 |
26 | Args:
27 | X : (B, L, D, N)
28 |
29 | Returns:
30 | Y : (B, npo2(L), D, N)
31 | """
32 |
33 | len_npo2 = npo2(X.size(1))
34 | pad_tuple = (0, 0, 0, 0, 0, len_npo2 - X.size(1))
35 | return F.pad(X, pad_tuple, "constant", 0)
36 |
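# Hedged toy check: a length-20 sequence is zero-padded to length 32,
# since npo2(20) == 32:
#   pad_npo2(torch.rand(2, 20, 8, 16)).shape  # torch.Size([2, 32, 8, 16])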
37 |
38 | class PScan(torch.autograd.Function):
39 | @staticmethod
40 | def pscan(A, X):
41 | # A : (B, D, L, N)
42 | # X : (B, D, L, N)
43 |
44 | # modifies X in place by doing a parallel scan.
45 | # more formally, X will be populated by these values :
46 |     # H[t] = A[t] * H[t-1] + X[t] with H[-1] = 0 (so H[0] = X[0])
47 | # which are computed in parallel (2*log2(T) sequential steps (ideally), instead of T sequential steps)
48 |
49 | # only supports L that is a power of two (mainly for a clearer code)
50 |
51 | B, D, L, _ = A.size()
52 | num_steps = int(math.log2(L))
53 |
54 | # up sweep (last 2 steps unfolded)
55 | Aa = A
56 | Xa = X
57 | for _ in range(num_steps - 2):
58 | T = Xa.size(2)
59 | Aa = Aa.view(B, D, T // 2, 2, -1)
60 | Xa = Xa.view(B, D, T // 2, 2, -1)
61 |
62 | Xa[:, :, :, 1].add_(Aa[:, :, :, 1].mul(Xa[:, :, :, 0]))
63 | Aa[:, :, :, 1].mul_(Aa[:, :, :, 0])
64 |
65 | Aa = Aa[:, :, :, 1]
66 | Xa = Xa[:, :, :, 1]
67 |
68 | # we have only 4, 2 or 1 nodes left
69 | if Xa.size(2) == 4:
70 | Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
71 | Aa[:, :, 1].mul_(Aa[:, :, 0])
72 |
73 | Xa[:, :, 3].add_(Aa[:, :, 3].mul(Xa[:, :, 2] + Aa[:, :, 2].mul(Xa[:, :, 1])))
74 | elif Xa.size(2) == 2:
75 | Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
76 | return
77 | else:
78 | return
79 |
80 | # down sweep (first 2 steps unfolded)
81 | Aa = A[:, :, 2 ** (num_steps - 2) - 1:L:2 ** (num_steps - 2)]
82 | Xa = X[:, :, 2 ** (num_steps - 2) - 1:L:2 ** (num_steps - 2)]
83 | Xa[:, :, 2].add_(Aa[:, :, 2].mul(Xa[:, :, 1]))
84 | Aa[:, :, 2].mul_(Aa[:, :, 1])
85 |
86 | for k in range(num_steps - 3, -1, -1):
87 | Aa = A[:, :, 2 ** k - 1:L:2 ** k]
88 | Xa = X[:, :, 2 ** k - 1:L:2 ** k]
89 |
90 | T = Xa.size(2)
91 | Aa = Aa.view(B, D, T // 2, 2, -1)
92 | Xa = Xa.view(B, D, T // 2, 2, -1)
93 |
94 | Xa[:, :, 1:, 0].add_(Aa[:, :, 1:, 0].mul(Xa[:, :, :-1, 1]))
95 | Aa[:, :, 1:, 0].mul_(Aa[:, :, :-1, 1])
96 |
97 | @staticmethod
98 | def pscan_rev(A, X):
99 | # A : (B, D, L, N)
100 | # X : (B, D, L, N)
101 |
102 | # the same function as above, but in reverse
103 | # (if you flip the input, call pscan, then flip the output, you get what this function outputs)
104 | # it is used in the backward pass
105 |
106 | # only supports L that is a power of two (mainly for a clearer code)
107 |
108 | B, D, L, _ = A.size()
109 | num_steps = int(math.log2(L))
110 |
111 | # up sweep (last 2 steps unfolded)
112 | Aa = A
113 | Xa = X
114 | for _ in range(num_steps - 2):
115 | T = Xa.size(2)
116 | Aa = Aa.view(B, D, T // 2, 2, -1)
117 | Xa = Xa.view(B, D, T // 2, 2, -1)
118 |
119 | Xa[:, :, :, 0].add_(Aa[:, :, :, 0].mul(Xa[:, :, :, 1]))
120 | Aa[:, :, :, 0].mul_(Aa[:, :, :, 1])
121 |
122 | Aa = Aa[:, :, :, 0]
123 | Xa = Xa[:, :, :, 0]
124 |
125 | # we have only 4, 2 or 1 nodes left
126 | if Xa.size(2) == 4:
127 | Xa[:, :, 2].add_(Aa[:, :, 2].mul(Xa[:, :, 3]))
128 | Aa[:, :, 2].mul_(Aa[:, :, 3])
129 |
130 | Xa[:, :, 0].add_(Aa[:, :, 0].mul(Xa[:, :, 1].add(Aa[:, :, 1].mul(Xa[:, :, 2]))))
131 | elif Xa.size(2) == 2:
132 | Xa[:, :, 0].add_(Aa[:, :, 0].mul(Xa[:, :, 1]))
133 | return
134 | else:
135 | return
136 |
137 | # down sweep (first 2 steps unfolded)
138 | Aa = A[:, :, 0:L:2 ** (num_steps - 2)]
139 | Xa = X[:, :, 0:L:2 ** (num_steps - 2)]
140 | Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 2]))
141 | Aa[:, :, 1].mul_(Aa[:, :, 2])
142 |
143 | for k in range(num_steps - 3, -1, -1):
144 | Aa = A[:, :, 0:L:2 ** k]
145 | Xa = X[:, :, 0:L:2 ** k]
146 |
147 | T = Xa.size(2)
148 | Aa = Aa.view(B, D, T // 2, 2, -1)
149 | Xa = Xa.view(B, D, T // 2, 2, -1)
150 |
151 | Xa[:, :, :-1, 1].add_(Aa[:, :, :-1, 1].mul(Xa[:, :, 1:, 0]))
152 | Aa[:, :, :-1, 1].mul_(Aa[:, :, 1:, 0])
153 |
154 | @staticmethod
155 | def forward(ctx, A_in, X_in):
156 | """
157 | Applies the parallel scan operation, as defined above. Returns a new tensor.
158 |         If you can, prefer sequence lengths that are powers of two.
159 |
160 | Args:
161 | A_in : (B, L, D, N)
162 | X_in : (B, L, D, N)
163 |
164 | Returns:
165 | H : (B, L, D, N)
166 | """
167 |
168 | L = X_in.size(1)
169 |
170 |         # cloning is required because of the in-place ops
171 | if L == npo2(L):
172 | A = A_in.clone()
173 | X = X_in.clone()
174 | else:
175 | # pad tensors (and clone btw)
176 | A = pad_npo2(A_in) # (B, npo2(L), D, N)
177 | X = pad_npo2(X_in) # (B, npo2(L), D, N)
178 |
179 | # prepare tensors
180 | A = A.transpose(2, 1) # (B, D, npo2(L), N)
181 | X = X.transpose(2, 1) # (B, D, npo2(L), N)
182 |
183 | # parallel scan (modifies X in-place)
184 | PScan.pscan(A, X)
185 |
186 | ctx.save_for_backward(A_in, X)
187 |
188 | # slice [:, :L] (cut if there was padding)
189 | return X.transpose(2, 1)[:, :L]
190 |
191 | @staticmethod
192 | def backward(ctx, grad_output_in):
193 | """
194 | Flows the gradient from the output to the input. Returns two new tensors.
195 |
196 | Args:
197 | ctx : A_in : (B, L, D, N), X : (B, D, L, N)
198 | grad_output_in : (B, L, D, N)
199 |
200 | Returns:
201 | gradA : (B, L, D, N), gradX : (B, L, D, N)
202 | """
203 |
204 | A_in, X = ctx.saved_tensors
205 |
206 | L = grad_output_in.size(1)
207 |
208 |         # cloning is required because of the in-place ops
209 | if L == npo2(L):
210 | grad_output = grad_output_in.clone()
211 | # the next padding will clone A_in
212 | else:
213 | grad_output = pad_npo2(grad_output_in) # (B, npo2(L), D, N)
214 | A_in = pad_npo2(A_in) # (B, npo2(L), D, N)
215 |
216 | # prepare tensors
217 | grad_output = grad_output.transpose(2, 1)
218 | A_in = A_in.transpose(2, 1) # (B, D, npo2(L), N)
219 | A = torch.nn.functional.pad(A_in[:, :, 1:],
220 | (0, 0, 0, 1)) # (B, D, npo2(L), N) shift 1 to the left (see hand derivation)
221 |
222 | # reverse parallel scan (modifies grad_output in-place)
223 | PScan.pscan_rev(A, grad_output)
224 |
225 | Q = torch.zeros_like(X)
226 | Q[:, :, 1:].add_(X[:, :, :-1] * grad_output[:, :, 1:])
227 |
228 | return Q.transpose(2, 1)[:, :L], grad_output.transpose(2, 1)[:, :L]
229 |
230 |
231 | pscan = PScan.apply
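
# Hedged end-to-end check of the autograd wrapper (toy sizes; the
# non-power-of-two length exercises the padding path):
#   A = torch.rand(2, 20, 8, 16, requires_grad=True)
#   X = torch.rand(2, 20, 8, 16, requires_grad=True)
#   H = pscan(A, X)     # (2, 20, 8, 16)
#   H.sum().backward()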
232 |
--------------------------------------------------------------------------------
/scripts/test_phase.sh:
--------------------------------------------------------------------------------
1 | export CC=gcc-11
2 | export CXX=g++-11
3 |
4 | CUDA_VISIBLE_DEVICES=6,7 python -m torch.distributed.launch \
5 | --nproc_per_node=2 \
6 | --master_port 12324 \
7 | downstream_phase/run_phase_training.py \
8 | --batch_size 8 \
9 | --epochs 50 \
10 | --save_ckpt_freq 10 \
11 | --model surgformer_HTA \
12 | --pretrained_path pretrain_params/timesformer_base_patch16_224_K400.pyth \
13 | --mixup 0.8 \
14 | --cutmix 1.0 \
15 | --smoothing 0.1 \
16 | --lr 5e-4 \
17 | --layer_decay 0.75 \
18 | --warmup_epochs 5 \
19 | --data_path /jhcnas4/syangcw/M2CAI16-workflow \
20 | --eval_data_path /jhcnas4/syangcw/M2CAI16-workflow \
21 | --nb_classes 8 \
22 | --data_strategy online \
23 | --output_mode key_frame \
24 | --num_frames 16 \
25 | --sampling_rate 4 \
26 | --eval \
27 | --finetune /jhcnas4/syangcw/Surgformerv2/M2CAI16/surgformer_HTA_M2CAI16_0.0005_0.75_online_key_frame_frame16_Fixed_Stride_4/checkpoint-best/mp_rank_00_model_states.pt \
28 | --data_set M2CAI16 \
29 | --data_fps 1fps \
30 | --output_dir /jhcnas4/syangcw/Surgformerv2/M2CAI16 \
31 | --log_dir /jhcnas4/syangcw/Surgformerv2/M2CAI16 \
32 | --num_workers 10 \
33 | --dist_eval \
34 | --enable_deepspeed \
35 | --no_auto_resume
--------------------------------------------------------------------------------
/scripts/train_phase.sh:
--------------------------------------------------------------------------------
1 | export CC=gcc-11
2 | export CXX=g++-11
3 |
4 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch \
5 | --nproc_per_node=4 \
6 | --master_port 12326 \
7 | downstream_phase/run_phase_training.py \
8 | --batch_size 8 \
9 | --epochs 50 \
10 | --save_ckpt_freq 10 \
11 | --model timesformer \
12 | --pretrained_path pretrain_params/timesformer_base_patch16_224_K400.pyth \
13 | --mixup 0.8 \
14 | --cutmix 1.0 \
15 | --smoothing 0.1 \
16 | --lr 5e-4 \
17 | --layer_decay 0.75 \
18 | --warmup_epochs 5 \
19 | --data_path /home/syangcw/LungSeg/LungSeg \
20 | --eval_data_path /home/syangcw/LungSeg/LungSeg \
21 | --nb_classes 7 \
22 | --data_strategy online \
23 | --output_mode key_frame \
24 | --num_frames 16 \
25 | --sampling_rate 4 \
26 | --data_set LungSeg \
27 | --data_fps 1fps \
28 | --output_dir /results/LungSeg \
29 | --log_dir /results/LungSeg \
30 | --num_workers 10 \
31 | --dist_eval \
32 | --enable_deepspeed \
33 | --no_auto_resume
--------------------------------------------------------------------------------
/test.sh:
--------------------------------------------------------------------------------
1 | python -m torch.distributed.launch \
2 | --nproc_per_node=2 \
3 | run_phase_training.py \
4 | --batch_size 8 \
5 | --epochs 50 \
6 | --save_ckpt_freq 10 \
7 | --model pmnet \
8 | --lr 3e-5 \
9 | --layer_decay 0.75 \
10 | --warmup_epochs 5 \
11 | --data_path data_path \
12 | --eval_data_path data_path \
13 | --nb_classes 5 \
14 | --data_strategy online \
15 | --output_mode key_frame \
16 | --num_frames 20 \
17 | --sampling_rate 8 \
18 | --data_set PmLR50 \
19 | --data_fps 1fps \
20 | --output_dir output_path \
21 | --log_dir output_path \
22 | --num_workers 10 \
23 | --enable_deepspeed \
24 | --no_auto_resume \
25 | --eval \
26 | --load_ckpt checkpoint_path
27 |
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | python -m torch.distributed.launch \
2 | --nproc_per_node=2 \
3 | run_phase_training.py \
4 | --batch_size 8 \
5 | --epochs 50 \
6 | --save_ckpt_freq 10 \
7 | --model pmnet \
8 | --lr 3e-5 \
9 | --layer_decay 0.75 \
10 | --warmup_epochs 5 \
11 | --data_path data_path \
12 | --eval_data_path data_path \
13 | --nb_classes 5 \
14 | --data_strategy online \
15 | --output_mode key_frame \
16 | --num_frames 20 \
17 | --sampling_rate 8 \
18 | --data_set PmLR50 \
19 | --data_fps 1fps \
20 | --output_dir output_path \
21 | --log_dir output_path \
22 | --num_workers 10 \
23 | --enable_deepspeed \
24 | --no_auto_resume
25 |
--------------------------------------------------------------------------------