├── .gitignore ├── README.md ├── crop_dataset.py ├── figure ├── COCO-Stuff.png ├── Category-Wise.png └── Cityscapes.png ├── fine_tuning_mlp.py ├── fine_tuning_tr.py ├── loader ├── __init__.py ├── dataloader.py └── netloader.py ├── models ├── __init__.py ├── dinomaevit.py ├── dinov2vit.py ├── ibotvit.py └── msnvit.py ├── modules ├── __init__.py ├── segment.py └── segment_module.py ├── requirements.txt ├── run ├── test_mlp.py ├── test_tr.py ├── train_front_door_mlp.py ├── train_front_door_tr.py ├── train_mediator.py └── utils ├── __init__.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | CAUSE/ 3 | checkpoint/ 4 | logs/ 5 | results/ 6 | visualize_concept.py 7 | visualize_distance_matrix.py 8 | scikit_cluster.py 9 | linear_probing_mlp.py 10 | linear_probing_tr.py 11 | test_category_mlp.py 12 | test_category_tr.py 13 | crop.py 14 | */__pycache__ 15 | __pycache__ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## * Under Repairing 2 | # ***Title: [Causal Unsupervised Semantic Segmentation](https://arxiv.org/pdf/2310.07379v1.pdf)*** 3 | 4 | 5 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/causal-unsupervised-semantic-segmentation/unsupervised-semantic-segmentation-on-coco-7)](https://paperswithcode.com/sota/unsupervised-semantic-segmentation-on-coco-7?p=causal-unsupervised-semantic-segmentation) 6 | 7 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/causal-unsupervised-semantic-segmentation/unsupervised-semantic-segmentation-on)](https://paperswithcode.com/sota/unsupervised-semantic-segmentation-on?p=causal-unsupervised-semantic-segmentation) 8 | 9 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/causal-unsupervised-semantic-segmentation/unsupervised-semantic-segmentation-on-coco-8)](https://paperswithcode.com/sota/unsupervised-semantic-segmentation-on-coco-8?p=causal-unsupervised-semantic-segmentation) 10 | 11 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/causal-unsupervised-semantic-segmentation/unsupervised-semantic-segmentation-on-coco-6)](https://paperswithcode.com/sota/unsupervised-semantic-segmentation-on-coco-6?p=causal-unsupervised-semantic-segmentation) 12 | 13 | 14 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/causal-unsupervised-semantic-segmentation/unsupervised-semantic-segmentation-on-pascal-1)](https://paperswithcode.com/sota/unsupervised-semantic-segmentation-on-pascal-1?p=causal-unsupervised-semantic-segmentation) 15 | 16 | This is pytorch implementation code for realizing the technical part of *CAusal Unsupervised Semantic sEgmentation (CAUSE)* to improve performance of unsupervised semantic segmentation. 17 | This code is further developed by two baseline codes of [HP: Leveraging Hidden Positives for Unsupervised Semantic Segmentation](https://github.com/hynnsk/HP) accepted in [CVPR 2023](https://openaccess.thecvf.com/content/CVPR2023/papers/Seong_Leveraging_Hidden_Positives_for_Unsupervised_Semantic_Segmentation_CVPR_2023_paper.pdf) 18 | and [STEGO: Unsupervised Semantic Segmentation by Distilling Feature Correspondences](https://github.com/mhamilton723/STEGO) accepted in [ICLR 2022](https://iclr.cc/virtual/2022/poster/6068). 
19 | 20 | 21 | --- 22 | 23 | You can see the following bundle of images in Appendix. 24 | Further, we explain concrete implementation beyond the description of the main paper. 25 | 26 |
27 |
28 | Figure 1. Visual comparison of USS for COCO-Stuff. Note that, in contrast to 29 | true labels, baseline frameworks fail to 30 | achieve the targeted level of granularity, while CAUSE successfully clusters person, sports, vehicle, etc. 
32 | 33 | 34 |
35 | Figure 2. Qualitative comparison of unsupervised semantic segmentation for Cityscapes. 36 |
37 | 38 | 39 |
40 | Figure 3. Log scale of mIoU results for each categories in COCO-Stuff (Black: Thing / Gray: Stuff ) 41 |
42 | 43 | 44 |
45 | 46 | 47 | 48 | --- 49 | 50 | ## 🚀 Download Visual Quality, Seg Head Parameter, and Concept ClusterBook of CAUSE 51 | 52 | You can download the checkpoint files including CAUSE-trained parameters based on 53 | [DINO](https://openaccess.thecvf.com/content/ICCV2021/papers/Caron_Emerging_Properties_in_Self-Supervised_Vision_Transformers_ICCV_2021_paper.pdf), [DINOv2](https://arxiv.org/pdf/2304.07193.pdf), [iBOT](https://openreview.net/pdf?id=ydopy-e6Dg), [MSN](https://www.ecva.net/papers/eccv_2022/papers_ECCV/papers/136910442.pdf), [MAE](https://openaccess.thecvf.com/content/CVPR2022/papers/He_Masked_Autoencoders_Are_Scalable_Vision_Learners_CVPR_2022_paper.pdf) 54 | in self-supervised vision transformer framework. 55 | If you want to download the pretrained models of DINO in various structures the following CAUSE uses, 56 | you can download them in the following links: 57 | 58 | * [DINO](https://github.com/facebookresearch/dino), ICCV 2021 59 | * [DINOv2](https://github.com/facebookresearch/dinov2), ArXiv 2023 60 | * [iBOT](https://github.com/bytedance/ibot), ICLR 2022 61 | * [MSN](https://github.com/facebookresearch/msn), ECCV 2022 62 | * [MAE](https://github.com/facebookresearch/mae), CVPR 2022 63 | 64 | --- 65 | 66 | 67 | | Dataset | Method | Baseline | mIoU(%) | pAcc(%) | Visual Quality | Seg Head Parameter | Concept ClusterBook | 68 | |:------------|---------------|------------|:-------:|:-------:|:---------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:| 69 | | COCO-Stuff | DINO+**CAUSE-MLP** | ViT-S/8 | 27.9 | 66.8 | [[link]](https://drive.google.com/file/d/1Z0Zj9JWJQQk6qeRctcdAk9MfyZQCwkvW/view?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/1ByLMYly-lLAa4vBQZ8Sv8nLSWBLPbev-?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/14bq-B4Xj4V3Usl2b2SfobCOaap4lzIXl?usp=drive_link) | 70 | | COCO-Stuff | DINO+**CAUSE-TR** | ViT-S/8 | 32.4 | 69.6 | [[link]](https://drive.google.com/file/d/1x9LNwCiXtZel-fTh8TqtRgHmmmrIFPgg/view?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/1ByLMYly-lLAa4vBQZ8Sv8nLSWBLPbev-?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/14bq-B4Xj4V3Usl2b2SfobCOaap4lzIXl?usp=drive_link) | 71 | | COCO-Stuff | DINO+**CAUSE-MLP** | ViT-S/16 | 25.9 | 66.3 | [[link]](https://drive.google.com/file/d/1wcMomwarw5gQ3sSSmQlZICtP4r3kZMN8/view?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/1PfOHDxWF_YcPVOApUSK-xHDUSY32domZ?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/1GnVOgtOZdt8N7M6cudd5d59FqZAQDDG5?usp=drive_link)| 72 | | COCO-Stuff | DINO+**CAUSE-TR** | ViT-S/16 | 33.1 | 70.4 | [[link]](https://drive.google.com/file/d/198_-3BvN_GCI63_Mx4lEPCHl0L9Fk2p2/view?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/1PfOHDxWF_YcPVOApUSK-xHDUSY32domZ?usp=drive_link) |[[link]](https://drive.google.com/drive/folders/1GnVOgtOZdt8N7M6cudd5d59FqZAQDDG5?usp=drive_link) | 73 | | COCO-Stuff | DINO+**CAUSE-MLP** | ViT-B/8 | 34.3 | 72.8 | [[link]](https://drive.google.com/file/d/1fmUs3UOsWVhOXvcbxjG9c-VT2vEaVzte/view?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/1Og2U2ihbPBrxpAAeWuped_FH4u_ecb0P?usp=drive_link) |[[link]](https://drive.google.com/drive/folders/10bZecU1EzgOISoi0RkajqSR-ebfAWr_N?usp=drive_link) | 74 | | COCO-Stuff | 
DINO+**CAUSE-TR** | ViT-B/8 | 41.9 | 74.9 | [[link]](https://drive.google.com/file/d/107jUAW4Y6xCMB7AgtgMIFcBitLMmQHaT/view?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/1Og2U2ihbPBrxpAAeWuped_FH4u_ecb0P?usp=drive_link) |[[link]](https://drive.google.com/drive/folders/10bZecU1EzgOISoi0RkajqSR-ebfAWr_N?usp=drive_link) | 75 | | COCO-Stuff | DINOv2+**CAUSE-TR** | ViT-B/14| 45.3 | 78.0 | [[link]](https://drive.google.com/file/d/1e_Mub-u1EJOqzI7umk4BGgFApWukixmb/view?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/1CWMussJAvGulg94lUrNn3EooILbtgIOb?usp=drive_link) |[[link]](https://drive.google.com/drive/folders/1nBjBJnucRiBYFiJHeOJUuE5e1YSoYKaf?usp=drive_link) | 76 | | COCO-Stuff | iBOT+**CAUSE-TR** | ViT-B/16 | 39.5 | 73.8 | [[link]](https://drive.google.com/file/d/1px6M068h3TH4wAxhH9sHSKrMZqreL9z2/view?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/1BAMopzQNU7cmiaCyFv73SBLj2F6gikuU?usp=drive_link) |[[link]](https://drive.google.com/drive/folders/1mbJdzpOrR-sjmAk0O1hnqzsStXDU3CL9?usp=drive_link) | 77 | | COCO-Stuff | MSN+**CAUSE-TR** | ViT-S/16 | 34.1 | 72.1 | [[link]](https://drive.google.com/file/d/1R9KH3q9SxyitMzDGoKYQK4GkxuFMn7HQ/view?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/15F2aaVMbG40ISSXTL0f_UrVkX3UAFwZw?usp=drive_link) |[[link]](https://drive.google.com/drive/folders/19Mv7_5sM6e48eH80bAZSagXhO9CfEbCS?usp=drive_link) | 78 | | COCO-Stuff | MAE+**CAUSE-TR** | ViT-B/16 | 21.5 | 59.1 | [[link]](https://drive.google.com/file/d/1_vwGG51DN5rJliDKUcc-9DKLbklroJw9/view?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/1ubUbmSliqrN19v6Abqsb_djtWDbx9qbV?usp=drive_link) |[[link]](https://drive.google.com/drive/folders/1G91qCJx-Z3IpYFYLFAUyG1zqMMZOhQWx?usp=drive_link) | 79 | 80 | 81 | 82 | --- 83 | 84 | | Dataset | Method | Baseline | mIoU(%) | pAcc(%) | Visual Quality | Seg Head Parameter | Concept ClusterBook | 85 | |:------------|---------------|------------|:-------:|:-------:|:---------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:| 86 | | Cityscapes | DINO+**CAUSE-MLP** | ViT-S/8 | 21.7 | 87.7 | [[link]](https://drive.google.com/file/d/1sC7OltZGfXCCyPhEaHJ596mczBUMGiEr/view?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/1MT_HPyZvn09jEsvnlGci9ZLDB2e6h4PI?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/1-ZfobyjlUGx5nltnBnjSzKzLQLTqcD_r?usp=drive_link) | 87 | | Cityscapes | DINO+**CAUSE-TR** | ViT-S/8 | 24.6 | 89.4 | [[link]](https://drive.google.com/file/d/1HEk9DSFHV0i-9SNqCDtmKhQUcPhSsu2P/view?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/1MT_HPyZvn09jEsvnlGci9ZLDB2e6h4PI?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/1-ZfobyjlUGx5nltnBnjSzKzLQLTqcD_r?usp=drive_link) | 88 | | Cityscapes | DINO+**CAUSE-MLP** | ViT-B/8 | 25.7 | 90.3 | [[link]](https://drive.google.com/file/d/1T4urliZtG-mJgjr1k-AczlC7c6EmWovP/view?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/1Y7K3v_IUUn82rq5df6cQUagL_sNLZRdT?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/1EoidRFHOT1w8LFNt2ws7C1BCkYdI4fw1?usp=drive_link) | 89 | | Cityscapes | DINO+**CAUSE-TR** | ViT-B/8 | 28.0 | 90.8 | [[link]](https://drive.google.com/file/d/1hQUT8jmzj9StBF_3QN87SL2_HO5n9yxp/view?usp=drive_link) | 
[[link]](https://drive.google.com/drive/folders/1Y7K3v_IUUn82rq5df6cQUagL_sNLZRdT?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/1EoidRFHOT1w8LFNt2ws7C1BCkYdI4fw1?usp=drive_link) | 90 | | Cityscapes | DINOv2+**CAUSE-TR** | ViT-B/14 | 29.9 | 89.8 | [[link]](https://drive.google.com/file/d/1SUKv38yrayooAVsW2VWLbg6iy64syWnV/view?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/1fi_DvMD3CLaZEozEgrGhIh6nH7WFq_Sj?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/1t66yv8_otlAMwy-QQyff-6fiwP58kCvV?usp=drive_link) | 91 | | Cityscapes | iBOT+**CAUSE-TR** | ViT-B/16 | 23.0 | 89.1 | [[link]](https://drive.google.com/file/d/1ZDCr0k6WdmjWFw6J-S7Y6HFf88tGfEAO/view?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/1T9OqBTc9tw9h3zDzzi137l8ls29_uOrd?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/1t4qsGMYlIWoArvkAFr-uUerkwQYPR5u7?usp=drive_link) | 92 | | Cityscapes | MSN+**CAUSE-TR** | ViT-S/16 | 21.2 | 89.1 | [[link]](https://drive.google.com/file/d/1-jSkmwRObBKOHdiMuu3eLaXWgQFMeida/view?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/1UQnhVADQvbnQKLjIXzEpY_hW76Yeeuuj?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/1TLGaZjljYoVCFp4EjjOghtvczk-SzVR1?usp=drive_link) | 93 | | Cityscapes | MAE+**CAUSE-TR** | ViT-B/16 | 12.5 | 82.0 | [[link]](https://drive.google.com/file/d/1241UvEi0zc5JS88fga2rZCS4wkDuaE3c/view?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/1Ng9mVhzAipmY5aPzJkX35flqJIQ8rgIp?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/1SFYWWo5Khqoy8fIhvxuL2XoZEaKt1UH-?usp=drive_link) | 94 | 95 | 96 | --- 97 | 98 | | Dataset | Method | Baseline | mIoU(%) | pAcc(%) | Visual Quality | Seg Head Parameter | Concept ClusterBook | 99 | |:------------|---------------|------------|:-------:|:-------:|:---------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:| 100 | | Pascal VOC | DINO+**CAUSE-MLP** | ViT-S/8 | 46.0 | - | [[link]](https://drive.google.com/file/d/1nzZMGCqb7mYdSXN59xzMkQUitzjxJt-9/view?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/1o6AkLzqC1J-V4YB_S7BBhfGd2E6qdopO?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/1H9dvIDaEW1fsIKsI1HPETD4NC2dj6Z6S?usp=drive_link) | 101 | | Pascal VOC | DINO+**CAUSE-TR** | ViT-S/8 | 50.0 | - | [[link]](https://drive.google.com/file/d/1Q-2ey069mDHnziGlP1olEc-JSHBf7t6N/view?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/1o6AkLzqC1J-V4YB_S7BBhfGd2E6qdopO?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/1H9dvIDaEW1fsIKsI1HPETD4NC2dj6Z6S?usp=drive_link) | 102 | | Pascal VOC | DINO+**CAUSE-MLP** | ViT-B/8 | 47.9 | - | [[link]](https://drive.google.com/file/d/1EWlKNbcWGSNXBhZpdezv3ghCcxItR-Zj/view?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/1sPlG9jQ-DljVguPNPDS1g3xnW_rTtUyw?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/1zsTx1NOECcJ7DH1wa654wRH_NLV6FHWP?usp=drive_link) | 103 | | Pascal VOC | DINO+**CAUSE-TR** | ViT-B/8 | 53.3 | - | [[link]](https://drive.google.com/file/d/1pqJNoCpCz3wMMjIMxQJ-WOxtdwvsaJWM/view?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/1sPlG9jQ-DljVguPNPDS1g3xnW_rTtUyw?usp=drive_link) | 
[[link]](https://drive.google.com/drive/folders/1zsTx1NOECcJ7DH1wa654wRH_NLV6FHWP?usp=drive_link) | 104 | | Pascal VOC | DINOv2+**CAUSE-TR** | ViT-B/14 | 53.2 | 91.5 | [[link]](https://drive.google.com/file/d/17FBfHfyML6jyeY5NvPJXUDI_vaxC87vk/view?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/1qsuKMVvpqsaYcVvZj3rDecmhZBAHOTeK?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/1aGs3KSExQrdQytVFWigOGRs2yC3YLO12?usp=drive_link) | 105 | | Pascal VOC | iBOT+**CAUSE-TR** | ViT-B/16 | 53.4 | 89.6 | [[link]](https://drive.google.com/file/d/1UjkvZ0MFxL-P0kaeUQKSGnVrjPsaeHaY/view?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/1G9zvKcLNbhAyqKlJXUpt80CR6MBtuSdi?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/1zA9d5eo41GerRWuBOnHY6_AjhtjGETCy?usp=drive_link) | 106 | | Pascal VOC | MSN+**CAUSE-TR** | ViT-S/16 | 30.2 | 84.2 | [[link]](https://drive.google.com/file/d/1by4USHNiEzem17s7jWKKUZTfUylQWyIy/view?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/1s6nzSmzt_ZTt_tCDvf8vmhU2RmE3Cdy0?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/1bwLogZo3vJOJrpSanRpgZ_1WHTeB3e4q?usp=drive_link) | 107 | | Pascal VOC | MAE+**CAUSE-TR** | ViT-B/16 | 25.8 | 83.7 | [[link]](https://drive.google.com/file/d/1odjO5dgTTdmsWGuGG7xFPi3ZLsV6-stl/view?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/1Re_f8QgIdXDnrNwP_5g-kFL-6SoE9YPU?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/1wOVWzTnfH58My8PT8rXW0sw2oVdc4xkL?usp=drive_link) | 108 | 109 | 110 | --- 111 | 112 | | Dataset | Method | Baseline | mIoU(%) | pAcc(%) | Visual Quality | Seg Head Parameter | Concept ClusterBook | 113 | |:------------|---------------|------------|:-------:|:-------:|:---------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:| 114 | | COCO-81 | DINO+**CAUSE-MLP** | ViT-S/8 | 19.1 | 78.8 | [[link]](https://drive.google.com/file/d/1Glxb7DHHhjxPQjkGygQM2prak7oH8dTt/view?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/1SlQ1_3phGBvjaxizjcYD92Nglab7RV6k?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/1ON5vDLS_Wc5OGgTxVK_yFopQGHbKaOar?usp=drive_link) | 115 | | COCO-81 | DINO+**CAUSE-TR** | ViT-S/8 | 21.2 | 75.2 | [[link]](https://drive.google.com/file/d/1QJmkV57mhKx6_A0E-yQcrQifX8lMRspO/view?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/1SlQ1_3phGBvjaxizjcYD92Nglab7RV6k?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/1ON5vDLS_Wc5OGgTxVK_yFopQGHbKaOar?usp=drive_link) | 116 | | COCO-171 | DINO+**CAUSE-MLP** | ViT-S/8 | 10.6 | 44.9 | [[link]](https://drive.google.com/file/d/1EUDqFTHVlr2c8cIR9oTbjpS83Js6RW66/view?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/1SlQ1_3phGBvjaxizjcYD92Nglab7RV6k?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/1ON5vDLS_Wc5OGgTxVK_yFopQGHbKaOar?usp=drive_link) | 117 | | COCO-171 | DINO+**CAUSE-TR** | ViT-S/8 | 15.2 | 46.6 | [[link]](https://drive.google.com/file/d/1Gv6306XUb-rbWB980O5m5vxQeZKIModT/view?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/1SlQ1_3phGBvjaxizjcYD92Nglab7RV6k?usp=drive_link) | [[link]](https://drive.google.com/drive/folders/1ON5vDLS_Wc5OGgTxVK_yFopQGHbKaOar?usp=drive_link) | 118 | 119 | 120 | --- 121 | 122 | ## 🤖 CAUSE 
Framework (Top-Level File Directory Layout) 123 | . 124 | ├── loader 125 | │ ├── netloader.py # Self-Supervised Pretrained Model Loader & Segmentation Head Loader 126 | │ └── dataloader.py # Dataloader (Thanks to STEGO [ICLR 2022]) 127 | │ 128 | ├── models # Model Designs of Self-Supervised Pretrained ViTs: [DINO/DINOv2/iBOT/MAE/MSN] 129 | │ ├── dinomaevit.py # ViT Structure of DINO and MAE 130 | │ ├── dinov2vit.py # ViT Structure of DINOv2 131 | │ ├── ibotvit.py # ViT Structure of iBOT 132 | │ └── msnvit.py # ViT Structure of MSN 133 | │ 134 | ├── modules # Segmentation Head and Its Necessary Functions 135 | │ ├── segment_module.py # Tools for Generating Concept Cluster Book and Contrastive Learning 136 | │ └── segment.py # [MLP & TR] Segmentation Head Structures (Segment_MLP & Segment_TR) 137 | │ 138 | ├── utils 139 | │ └── utils.py # Auxiliary utility tools 140 | │ 141 | ├── train_mediator.py # (STEP 1) [MLP & TR] Generating Concept Cluster Book as a Mediator 142 | │ 143 | ├── train_front_door_mlp.py # (STEP 2) [MLP] Frontdoor Adjustment through Unsupervised Semantic Segmentation 144 | ├── fine_tuning_mlp.py # (STEP 3) [MLP] Fine-Tuning Cluster Probe 145 | │ 146 | ├── train_front_door_tr.py # (STEP 2) [TR] Frontdoor Adjustment through Unsupervised Semantic Segmentation 147 | ├── fine_tuning_tr.py # (STEP 3) [TR] Fine-Tuning Cluster Probe 148 | │ 149 | ├── test_mlp.py # [MLP] Evaluating Unsupervised Semantic Segmentation Performance (Post-Processing) 150 | ├── test_tr.py # [TR] Evaluating Unsupervised Semantic Segmentation Performance (Post-Processing) 151 | │ 152 | ├── requirements.txt 153 | └── README.md 154 | 155 | 156 | --- 157 | ## 📊 How to Run CAUSE? 158 | 159 | 160 | First, we should generate the cropped datasets by following [STEGO](https://github.com/mhamilton723/STEGO) (ICLR 2022). 
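The shell commands below generate the crops. For reference, here is a minimal sketch (not part of the repository; `make_crop_func` and `_size` are hypothetical helper names) of what each `--crop_type` option expands to, mirroring `RandomCropComputer` in `crop_dataset.py` shown later in this dump and assuming a (C, H, W) image tensor:

```python
# Reference sketch only: mirrors the crop_func construction in crop_dataset.py.
from torchvision.transforms.functional import five_crop, ten_crop

def _size(img, ratio):
    # crop size as a fraction of the input tensor's spatial resolution
    return [int(img.shape[-2] * ratio), int(img.shape[-1] * ratio)]

def make_crop_func(crop_type, crop_ratio=0.5):
    if crop_type == "five":
        # 4 corner crops + 1 center crop at the given --crop_ratio
        return lambda x: five_crop(x, _size(x, crop_ratio))
    elif crop_type == "double":
        # ten-crop at two scales (0.5 and 0.8); --crop_ratio is ignored
        return lambda x: ten_crop(x, _size(x, 0.5)) + ten_crop(x, _size(x, 0.8))
    elif crop_type == "super":
        # ten-crop at five scales (0.3 to 0.7); --crop_ratio is ignored
        return lambda x: (ten_crop(x, _size(x, 0.3)) + ten_crop(x, _size(x, 0.4))
                          + ten_crop(x, _size(x, 0.5)) + ten_crop(x, _size(x, 0.6))
                          + ten_crop(x, _size(x, 0.7)))
    raise ValueError(f"unknown crop_type: {crop_type}")
```

In short, only `five` uses the configurable `--crop_ratio`, while `double` and `super` use fixed multi-scale ten-crops.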
161 | 162 | 163 | ```shell script 164 | python crop_dataset.py --dataset "cocostuff27" --crop_type "five" 165 | python crop_dataset.py --dataset "cityscapes" --crop_type "five" 166 | python crop_dataset.py --dataset "pascalvoc" --crop_type "super" 167 | python crop_dataset.py --dataset "coco81" --crop_type "double" 168 | python crop_dataset.py --dataset "coco171" --crop_type "double" 169 | ``` 170 | 171 | And then, run the integrated pipeline: 172 | 173 | ```shell script 174 | bash run # All of the following three steps integrated 175 | ``` 176 | 177 | In this shell script file (`run`), you can see the following code: 178 | 179 | ```shell script 180 | #!/bin/bash 181 | ###################################### 182 | # [OPTION] DATASET 183 | # cocostuff27 184 | dataset="cocostuff27" 185 | ############# 186 | 187 | ###################################### 188 | # [OPTION] STRUCTURE 189 | structure="TR" 190 | ###################################### 191 | 192 | ###################################### 193 | # [OPTION] Self-Supervised Method 194 | ckpt="checkpoint/dino_vit_base_8.pth" 195 | ###################################### 196 | 197 | ###################################### 198 | # GPU and PORT 199 | if [ "$structure" = "MLP" ] 200 | then 201 | train_gpu="0,1,2,3" 202 | elif [ "$structure" = "TR" ] 203 | then 204 | train_gpu="4,5,6,7" 205 | fi 206 | 207 | # Non-Changeable Variable 208 | test_gpu="${train_gpu:0}" 209 | port=$(($RANDOM%800+1200)) 210 | ###################################### 211 | 212 | ###################################### 213 | # [STEP1] MEDIATOR 214 | python train_mediator.py --dataset $dataset --ckpt $ckpt --gpu $train_gpu --port $port 215 | ###################################### 216 | 217 | ###################################### 218 | # [STEP2] CAUSE 219 | if [ "$structure" = "MLP" ] 220 | then 221 | python train_front_door_mlp.py --dataset $dataset --ckpt $ckpt --gpu $train_gpu --port $port 222 | python fine_tuning_mlp.py --dataset $dataset --ckpt $ckpt --gpu $train_gpu --port $port 223 | elif [ "$structure" = "TR" ] 224 | then 225 | python train_front_door_tr.py --dataset $dataset --ckpt $ckpt --gpu $train_gpu --port $port 226 | python fine_tuning_tr.py --dataset $dataset --ckpt $ckpt --gpu $train_gpu --port $port 227 | fi 228 | ###################################### 229 | 230 | ###################################### 231 | # TEST 232 | if [ "$structure" = "MLP" ] 233 | then 234 | python test_mlp.py --dataset $dataset --ckpt $ckpt --gpu $test_gpu 235 | elif [ "$structure" = "TR" ] 236 | then 237 | python test_tr.py --dataset $dataset --ckpt $ckpt --gpu $test_gpu 238 | fi 239 | ###################################### 240 | ``` 241 | 242 | 243 | ### 1. Training CAUSE 244 | 245 | ### (STEP 1): Generating Mediator based on Modularity 246 | 247 | ```shell script 248 | python train_mediator.py # DINO/DINOv2/iBOT/MSN/MAE 249 | ``` 250 | 251 | ### (STEP 2): Frontdoor Adjustment through Contrastive Learning 252 | 253 | ```shell script 254 | python train_front_door_mlp.py # CAUSE-MLP 255 | 256 | # or 257 | 258 | python train_front_door_tr.py # CAUSE-TR 259 | ``` 260 | 261 | 262 | ### (STEP 3): *Technical Step: Fine-Tuning Cluster Probe* 263 | 264 | ```shell script 265 | python fine_tuning_mlp.py # CAUSE-MLP 266 | 267 | # or 268 | 269 | python fine_tuning_tr.py # CAUSE-TR 270 | ``` 271 | 272 | --- 273 | 274 | ### 2. 
Testing CAUSE 275 | 276 | ```shell script 277 | python test_mlp.py # CAUSE-MLP 278 | 279 | # or 280 | 281 | python test_tr.py # CAUSE-TR 282 | ``` 283 | 284 | --- 285 | 286 | 287 | ## 💡 Environment Settings 288 | 289 | * Creating Virtual Environment by Anaconda 290 | > conda create -y -n neurips python=3.9 291 | 292 | * Installing [PyTorch]((https://pytorch.org/)) Package in Virtual Envrionment 293 | > pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 294 | 295 | * Installing Pip Package 296 | > pip install -r requirements.txt 297 | 298 | * [Optional] Removing Conda and PIP Cache if Conda and PIP have been locked by unknown reasons 299 | > conda clean -a && pip cache purge 300 | 301 | --- 302 | 303 | ## 🍅 Download Datasets 304 | ### Available Datasets 305 | * [COCO-Stuff](https://paperswithcode.com/dataset/coco-stuff) 306 | * [Cityscapes](https://paperswithcode.com/dataset/cityscapes) 307 | * [Pascal VOC](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html) 308 | 309 | *Note: Pascal VOC is not necessary to download because dataloader will automatically download in your own dataset path* 310 | 311 | ### Try the following scripts 312 | > * wget https://marhamilresearch4.blob.core.windows.net/stego-public/pytorch_data/cityscapes.zip 313 | > * wget https://marhamilresearch4.blob.core.windows.net/stego-public/pytorch_data/cocostuff.zip 314 | 315 | ### If the above do not work, then download [azcopy](https://learn.microsoft.com/en-us/azure/storage/common/storage-use-azcopy-v10?toc=%2Fazure%2Fstorage%2Fblobs%2Ftoc.json&bc=%2Fazure%2Fstorage%2Fblobs%2Fbreadcrumb%2Ftoc.json) and follow the below scripts 316 | > * azcopy copy "https://marhamilresearch4.blob.core.windows.net/stego-public/pytorch_data/cityscapes.zip" "custom_path" --recursive 317 | > * azcopy copy "https://marhamilresearch4.blob.core.windows.net/stego-public/pytorch_data/cocostuff.zip" "custom_path" --recursive 318 | 319 | 320 | ### Unzip Datasets 321 | 322 | ```shell script 323 | unzip cocostuff.zip && unzip cityscapes.zip 324 | ``` 325 | 326 | --- 327 | -------------------------------------------------------------------------------- /crop_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | from PIL import Image 5 | from os.path import join 6 | from utils.utils import * 7 | from torch.utils.data import DataLoader 8 | from loader.dataloader import ContrastiveSegDataset 9 | from torchvision.transforms.functional import five_crop, ten_crop 10 | from tqdm import tqdm 11 | from torch.utils.data import Dataset 12 | from torchvision import transforms as T 13 | 14 | class RandomCropComputer(Dataset): 15 | 16 | @staticmethod 17 | def _get_size(img, crop_ratio): 18 | if len(img.shape) == 3: 19 | return [int(img.shape[1] * crop_ratio), int(img.shape[2] * crop_ratio)] 20 | elif len(img.shape) == 2: 21 | return [int(img.shape[0] * crop_ratio), int(img.shape[1] * crop_ratio)] 22 | else: 23 | raise ValueError("Bad image shape {}".format(img.shape)) 24 | 25 | def __init__(self, args, dataset_name, img_set, crop_type, crop_ratio): 26 | self.pytorch_data_dir = args.data_dir 27 | self.crop_ratio = crop_ratio 28 | 29 | if crop_type == 'five': 30 | crop_func = lambda x: five_crop(x, self._get_size(x, crop_ratio)) 31 | elif crop_type == 'double': 32 | crop_ratio = 0 33 | crop_func = lambda x: ten_crop(x, self._get_size(x, 0.5))\ 34 | + ten_crop(x, self._get_size(x, 0.8)) 35 | elif crop_type == 'super': 36 | 
crop_ratio = 0 37 | crop_func = lambda x: ten_crop(x, self._get_size(x, 0.3))\ 38 | + ten_crop(x, self._get_size(x, 0.4))\ 39 | + ten_crop(x, self._get_size(x, 0.5))\ 40 | + ten_crop(x, self._get_size(x, 0.6))\ 41 | + ten_crop(x, self._get_size(x, 0.7)) 42 | 43 | if args.dataset=='coco171': 44 | self.save_dir = join( 45 | args.data_dir, 'cocostuff', "cropped", "coco171_{}_crop_{}".format(crop_type, crop_ratio)) 46 | elif args.dataset=='coco81': 47 | self.save_dir = join( 48 | args.data_dir, 'cocostuff', "cropped", "coco81_{}_crop_{}".format(crop_type, crop_ratio)) 49 | else: 50 | self.save_dir = join( 51 | args.data_dir, dataset_name, "cropped", "{}_{}_crop_{}".format(dataset_name, crop_type, crop_ratio)) 52 | self.args = args 53 | 54 | self.img_dir = join(self.save_dir, "img", img_set) 55 | self.label_dir = join(self.save_dir, "label", img_set) 56 | os.makedirs(self.img_dir, exist_ok=True) 57 | os.makedirs(self.label_dir, exist_ok=True) 58 | 59 | # train dataset 60 | self.dataset = ContrastiveSegDataset( 61 | pytorch_data_dir=args.data_dir, 62 | dataset_name=args.dataset, 63 | crop_type=None, 64 | image_set=img_set, 65 | transform=T.ToTensor(), 66 | target_transform=ToTargetTensor(), 67 | extra_transform=crop_func 68 | ) 69 | 70 | def __getitem__(self, item): 71 | return self.dataset[item] 72 | 73 | def __len__(self): 74 | return len(self.dataset) 75 | 76 | 77 | def my_app(): 78 | 79 | # fetch args 80 | parser = argparse.ArgumentParser() 81 | 82 | # fixed parameter 83 | parser.add_argument('--num_workers', default=int(os.cpu_count() / 8), type=int) 84 | 85 | # dataset and baseline 86 | parser.add_argument('--data_dir', default='/mnt/hard2/lbk-iccv/datasets', type=str) 87 | parser.add_argument('--dataset', default='cocostuff27', type=str) 88 | parser.add_argument('--gpu', default=1, type=int) 89 | parser.add_argument('--distributed', default='false', type=str2bool) 90 | parser.add_argument('--crop_type', default='five', type=str) 91 | parser.add_argument('--crop_ratio', default=0.5, type=float) 92 | 93 | args = parser.parse_args() 94 | 95 | # setting gpu id of this process 96 | torch.cuda.set_device(args.gpu) 97 | 98 | counter = 0 99 | dataset = RandomCropComputer(args, args.dataset, "train", args.crop_type, args.crop_ratio) 100 | loader = DataLoader(dataset, 1, shuffle=False, num_workers=args.num_workers, collate_fn=lambda l: l) 101 | for batch in tqdm(loader): 102 | imgs = batch[0]['img'] 103 | labels = batch[0]['label'] 104 | for img, label in zip(imgs, labels): 105 | img_arr = img.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() 106 | label_arr = (label + 1).unsqueeze(0).permute(1, 2, 0).to('cpu', torch.uint8).numpy().squeeze(-1) 107 | Image.fromarray(img_arr).save(join(dataset.img_dir, "{}.jpg".format(counter)), 'JPEG') 108 | Image.fromarray(label_arr).save(join(dataset.label_dir, "{}.png".format(counter)), 'PNG') 109 | counter+=1 110 | 111 | if __name__ == "__main__": 112 | my_app() -------------------------------------------------------------------------------- /figure/COCO-Stuff.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ByungKwanLee/Causal-Unsupervised-Segmentation/ee0aa8478a6b6704f4db44ecc70e44a14fe5067f/figure/COCO-Stuff.png -------------------------------------------------------------------------------- /figure/Category-Wise.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ByungKwanLee/Causal-Unsupervised-Segmentation/ee0aa8478a6b6704f4db44ecc70e44a14fe5067f/figure/Category-Wise.png -------------------------------------------------------------------------------- /figure/Cityscapes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ByungKwanLee/Causal-Unsupervised-Segmentation/ee0aa8478a6b6704f4db44ecc70e44a14fe5067f/figure/Cityscapes.png -------------------------------------------------------------------------------- /fine_tuning_mlp.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch.nn.init 4 | from tqdm import tqdm 5 | from utils.utils import * 6 | from modules.segment_module import transform, untransform, compute_modularity_based_codebook 7 | import torch.distributed as dist 8 | import torch.multiprocessing as mp 9 | import torch.backends.cudnn as cudnn 10 | from loader.dataloader import dataloader 11 | from torch.cuda.amp import autocast, GradScaler 12 | from loader.netloader import network_loader, segment_mlp_loader, cluster_mlp_loader 13 | 14 | cudnn.benchmark = True 15 | scaler = GradScaler() 16 | 17 | def ddp_setup(args, rank, world_size): 18 | os.environ['MASTER_ADDR'] = 'localhost' 19 | os.environ['MASTER_PORT'] = args.port 20 | 21 | # initialize 22 | dist.init_process_group("nccl", rank=rank, world_size=world_size) 23 | 24 | 25 | def ddp_clean(): 26 | dist.destroy_process_group() 27 | 28 | 29 | @Wrapper.EpochPrint 30 | def train(args, net, segment, cluster, train_loader, optimizer_segment, optimizer_cluster): 31 | global counter 32 | segment.train() 33 | 34 | total_acc = 0 35 | total_loss = 0 36 | total_loss_linear = 0 37 | total_loss_mod = 0 38 | 39 | 40 | prog_bar = tqdm(enumerate(train_loader), total=len(train_loader), leave=True) 41 | for idx, batch in prog_bar: 42 | 43 | # optimizer 44 | with autocast(): 45 | 46 | # image and label and self supervised feature 47 | img = batch["img"].cuda() 48 | label = batch["label"].cuda() 49 | 50 | # intermediate features 51 | feat = net(img)[:, 1:, :] 52 | seg_feat_ema = segment.head_ema(feat, segment.dropout) 53 | 54 | # computing modularity based codebook 55 | loss_mod = compute_modularity_based_codebook(cluster.cluster_probe, seg_feat_ema, grid=args.grid) 56 | 57 | # linear probe loss 58 | linear_logits = segment.linear(seg_feat_ema) 59 | linear_logits = F.interpolate(linear_logits, label.shape[-2:], mode='bilinear', align_corners=False) 60 | flat_linear_logits = linear_logits.permute(0, 2, 3, 1).reshape(-1, args.n_classes) 61 | flat_label = label.reshape(-1) 62 | flat_label_mask = (flat_label >= 0) & (flat_label < args.n_classes) 63 | loss_linear = F.cross_entropy(flat_linear_logits[flat_label_mask], flat_label[flat_label_mask]) 64 | 65 | # loss 66 | loss = loss_linear + loss_mod 67 | 68 | # optimizer 69 | optimizer_segment.zero_grad() 70 | optimizer_cluster.zero_grad() 71 | scaler.scale(loss).backward() 72 | if args.dataset=='cityscapes': 73 | scaler.unscale_(optimizer_segment) 74 | torch.nn.utils.clip_grad_norm_(segment.parameters(), 1) 75 | elif args.dataset=='cocostuff27': 76 | scaler.unscale_(optimizer_segment) 77 | torch.nn.utils.clip_grad_norm_(segment.parameters(), 0.1) 78 | scaler.step(optimizer_segment) 79 | scaler.step(optimizer_cluster) 80 | scaler.update() 81 | 82 | # linear probe acc check 83 | pred_label = linear_logits.argmax(dim=1) 84 | flat_pred_label = pred_label.reshape(-1) 85 | acc = 
(flat_pred_label[flat_label_mask] == flat_label[flat_label_mask]).sum() / flat_label[ 86 | flat_label_mask].numel() 87 | total_acc += acc.item() 88 | 89 | # loss check 90 | total_loss += loss.item() 91 | total_loss_linear += loss_linear.item() 92 | total_loss_mod += loss_mod.item() 93 | 94 | # real-time print 95 | desc = f'[Train] Loss: {total_loss / (idx + 1):.2f}={total_loss_linear / (idx + 1):.2f}{total_loss_mod / (idx + 1):.2f}' 96 | desc += f' ACC: {100. * total_acc / (idx + 1):.1f}%' 97 | prog_bar.set_description(desc, refresh=True) 98 | 99 | # Interrupt for sync GPU Process 100 | if args.distributed: dist.barrier() 101 | 102 | 103 | @Wrapper.TestPrint 104 | def test(args, net, segment, cluster, nice, test_loader): 105 | global counter_test 106 | segment.eval() 107 | 108 | total_acc = 0 109 | prog_bar = tqdm(enumerate(test_loader), total=len(test_loader), leave=True) 110 | for idx, batch in prog_bar: 111 | # image and label and self supervised feature 112 | img = batch["img"].cuda() 113 | label = batch["label"].cuda() 114 | 115 | # intermediate feature 116 | with autocast(): 117 | 118 | feat = net(img)[:, 1:, :] 119 | seg_feat_ema = segment.head_ema(feat) 120 | 121 | # linear probe loss 122 | linear_logits = segment.linear(seg_feat_ema) 123 | linear_logits = F.interpolate(linear_logits, label.shape[-2:], mode='bilinear', align_corners=False) 124 | flat_label = label.reshape(-1) 125 | flat_label_mask = (flat_label >= 0) & (flat_label < args.n_classes) 126 | 127 | # interp feat 128 | interp_seg_feat = F.interpolate(transform(seg_feat_ema), label.shape[-2:], mode='bilinear', align_corners=False) 129 | 130 | # cluster 131 | cluster_preds = cluster.forward_centroid(untransform(interp_seg_feat), inference=True) 132 | 133 | # linear probe acc check 134 | pred_label = linear_logits.argmax(dim=1) 135 | flat_pred_label = pred_label.reshape(-1) 136 | acc = (flat_pred_label[flat_label_mask] == flat_label[flat_label_mask]).sum() / flat_label[ 137 | flat_label_mask].numel() 138 | total_acc += acc.item() 139 | 140 | # nice evaluation 141 | _, desc_nice = nice.eval(cluster_preds, label) 142 | 143 | # real-time print 144 | desc = f'[TEST] Acc (Linear): {100. 
* total_acc / (idx + 1):.1f}% | {desc_nice}' 145 | prog_bar.set_description(desc, refresh=True) 146 | 147 | # evaludation metric reset 148 | nice.reset() 149 | 150 | # Interrupt for sync GPU Process 151 | if args.distributed: dist.barrier() 152 | 153 | 154 | def main(rank, args, ngpus_per_node): 155 | 156 | # setup ddp process 157 | if args.distributed: ddp_setup(args, rank, ngpus_per_node) 158 | 159 | # setting gpu id of this process 160 | torch.cuda.set_device(rank) 161 | 162 | # print argparse 163 | print_argparse(args, rank) 164 | 165 | # dataset loader 166 | train_loader, test_loader, sampler = dataloader(args) 167 | 168 | # network loader 169 | net = network_loader(args, rank) 170 | segment = segment_mlp_loader(args, rank) 171 | cluster = cluster_mlp_loader(args, rank) 172 | 173 | # distributed parsing 174 | if args.distributed: net = net.module; segment = segment.module; cluster = cluster.module 175 | 176 | # optimizer 177 | if args.dataset=='cityscapes': 178 | optimizer_segment = torch.optim.Adam(segment.parameters(), lr=1e-3 * ngpus_per_node) 179 | optimizer_cluster = torch.optim.Adam(cluster.parameters(), lr=1e-3 * ngpus_per_node) 180 | else: 181 | optimizer_segment = torch.optim.Adam(segment.parameters(), lr=1e-3 * ngpus_per_node, weight_decay=1e-4) 182 | optimizer_cluster = torch.optim.Adam(cluster.parameters(), lr=1e-3 * ngpus_per_node) 183 | 184 | # scheduler 185 | scheduler_segment = torch.optim.lr_scheduler.StepLR(optimizer_segment, step_size=2, gamma=0.5) 186 | scheduler_cluster = torch.optim.lr_scheduler.StepLR(optimizer_cluster, step_size=2, gamma=0.5) 187 | 188 | # evaluation 189 | nice = NiceTool(args.n_classes) 190 | 191 | ################################################################################### 192 | # First, run train_mediator.py 193 | path, is_exist = pickle_path_and_exist(args) 194 | 195 | # early save for time 196 | if is_exist: 197 | # load 198 | codebook = np.load(path) 199 | cb = torch.from_numpy(codebook).cuda() 200 | cluster.codebook.data = cb 201 | cluster.codebook.requires_grad = False 202 | 203 | # print successful loading modularity 204 | rprint(f'Modularity {path} loaded', rank) 205 | 206 | # Interrupt for sync GPU Process 207 | if args.distributed: dist.barrier() 208 | 209 | else: 210 | rprint('Train Modularity-based Codebook First', rank) 211 | return 212 | ################################################################################### 213 | 214 | 215 | # train 216 | for epoch in range(args.epoch): 217 | 218 | # for shuffle 219 | if args.distributed: sampler.set_epoch(epoch) 220 | 221 | 222 | # train 223 | train( 224 | epoch, # for decorator 225 | rank, # for decorator 226 | args, 227 | net, 228 | segment, 229 | cluster, 230 | train_loader, 231 | optimizer_segment, 232 | optimizer_cluster) 233 | 234 | test( 235 | epoch, # for decorator 236 | rank, # for decorator 237 | args, 238 | net, 239 | segment, 240 | cluster, 241 | nice, 242 | test_loader) 243 | 244 | scheduler_segment.step() 245 | scheduler_cluster.step() 246 | 247 | if (rank == 0): 248 | x = segment.state_dict() 249 | baseline = args.ckpt.split('/')[-1].split('.')[0] 250 | 251 | # filepath hierarchy 252 | check_dir(f'CAUSE/{args.dataset}/{baseline}/{args.num_codebook}') 253 | 254 | # save path 255 | y = f'CAUSE/{args.dataset}/{baseline}/{args.num_codebook}/segment_mlp.pth' 256 | torch.save(segment.state_dict(), y) 257 | 258 | y = f'CAUSE/{args.dataset}/{baseline}/{args.num_codebook}/cluster_mlp.pth' 259 | torch.save(cluster.state_dict(), y) 260 | print(f'-----------------TEST 
Epoch {epoch}: SAVING CHECKPOINT IN {y}-----------------') 261 | 262 | # Interrupt for sync GPU Process 263 | if args.distributed: dist.barrier() 264 | 265 | # Closing DDP 266 | if args.distributed: dist.barrier(); dist.destroy_process_group() 267 | 268 | if __name__ == "__main__": 269 | 270 | # fetch args 271 | parser = argparse.ArgumentParser() 272 | # model parameter 273 | parser.add_argument('--NAME-TAG', default='CAUSE-MLP', type=str) 274 | parser.add_argument('--data_dir', default='/mnt/hard2/lbk-iccv/datasets/', type=str) 275 | parser.add_argument('--dataset', default='cocostuff27', type=str) 276 | parser.add_argument('--ckpt', default='checkpoint/dino_vit_base_8.pth', type=str) 277 | parser.add_argument('--epoch', default=1, type=int) 278 | parser.add_argument('--distributed', default=True, type=str2bool) 279 | parser.add_argument('--load_segment', default=True, type=str2bool) 280 | parser.add_argument('--load_cluster', default=False, type=str2bool) 281 | parser.add_argument('--train_resolution', default=224, type=int) 282 | parser.add_argument('--test_resolution', default=320, type=int) 283 | parser.add_argument('--batch_size', default=16, type=int) 284 | parser.add_argument('--num_workers', default=int(os.cpu_count() / 8), type=int) 285 | 286 | # DDP 287 | parser.add_argument('--gpu', default='0,1,2,3', type=str) 288 | parser.add_argument('--port', default='12355', type=str) 289 | 290 | # codebook parameter 291 | parser.add_argument('--grid', default='yes', type=str2bool) 292 | parser.add_argument('--num_codebook', default=2048, type=int) 293 | 294 | # model parameter 295 | parser.add_argument('--reduced_dim', default=90, type=int) 296 | parser.add_argument('--projection_dim', default=2048, type=int) 297 | 298 | args = parser.parse_args() 299 | 300 | if 'dinov2' in args.ckpt: 301 | args.test_resolution=322 302 | if 'small' in args.ckpt: 303 | args.dim=384 304 | elif 'base' in args.ckpt: 305 | args.dim=768 306 | 307 | # the number of gpus for multi-process 308 | gpu_list = list(map(int, args.gpu.split(','))) 309 | ngpus_per_node = len(gpu_list) 310 | 311 | if args.distributed: 312 | # cuda visible devices 313 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 314 | # multiprocess spawn 315 | mp.spawn(main, args=(args, ngpus_per_node), nprocs=ngpus_per_node, join=True) 316 | else: 317 | # first gpu index is activated once there are several gpu in args.gpu 318 | main(rank=gpu_list[0], args=args, ngpus_per_node=1) 319 | -------------------------------------------------------------------------------- /fine_tuning_tr.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch.nn.init 4 | from tqdm import tqdm 5 | from utils.utils import * 6 | from modules.segment_module import transform, untransform, compute_modularity_based_codebook 7 | import torch.distributed as dist 8 | import torch.multiprocessing as mp 9 | import torch.backends.cudnn as cudnn 10 | from loader.dataloader import dataloader 11 | from torch.cuda.amp import autocast, GradScaler 12 | from loader.netloader import network_loader, segment_tr_loader, cluster_tr_loader 13 | 14 | cudnn.benchmark = True 15 | scaler = GradScaler() 16 | 17 | cmap = create_pascal_label_colormap() 18 | 19 | def ddp_setup(args, rank, world_size): 20 | os.environ['MASTER_ADDR'] = 'localhost' 21 | os.environ['MASTER_PORT'] = args.port 22 | 23 | # initialize 24 | dist.init_process_group("nccl", rank=rank, world_size=world_size) 25 | 26 | 27 | def ddp_clean(): 28 | 
dist.destroy_process_group() 29 | 30 | 31 | @Wrapper.EpochPrint 32 | def train(args, net, segment, cluster, train_loader, optimizer_segment, optimizer_cluster): 33 | global counter 34 | segment.train() 35 | 36 | total_acc = 0 37 | total_loss = 0 38 | total_loss_linear = 0 39 | total_loss_mod = 0 40 | 41 | prog_bar = tqdm(enumerate(train_loader), total=len(train_loader), leave=True) 42 | for idx, batch in prog_bar: 43 | 44 | # optimizer 45 | with autocast(): 46 | 47 | # image and label and self supervised feature 48 | img = batch["img"].cuda() 49 | label = batch["label"].cuda() 50 | 51 | # intermediate features 52 | feat = net(img)[:, 1:, :] 53 | seg_feat_ema = segment.head_ema(feat, segment.dropout) 54 | 55 | # computing modularity based codebook 56 | loss_mod = compute_modularity_based_codebook(cluster.cluster_probe, seg_feat_ema, grid=args.grid) 57 | 58 | # linear probe loss 59 | linear_logits = segment.linear(seg_feat_ema) 60 | linear_logits = F.interpolate(linear_logits, label.shape[-2:], mode='bilinear', align_corners=False) 61 | flat_linear_logits = linear_logits.permute(0, 2, 3, 1).reshape(-1, args.n_classes) 62 | flat_label = label.reshape(-1) 63 | flat_label_mask = (flat_label >= 0) & (flat_label < args.n_classes) 64 | loss_linear = F.cross_entropy(flat_linear_logits[flat_label_mask], flat_label[flat_label_mask]) 65 | 66 | # loss 67 | loss = loss_linear + loss_mod 68 | 69 | # optimizer 70 | optimizer_segment.zero_grad() 71 | optimizer_cluster.zero_grad() 72 | scaler.scale(loss).backward() 73 | if args.dataset=='cityscapes': 74 | scaler.unscale_(optimizer_segment) 75 | torch.nn.utils.clip_grad_norm_(segment.parameters(), 1) 76 | elif args.dataset=='cocostuff27': 77 | scaler.unscale_(optimizer_segment) 78 | torch.nn.utils.clip_grad_norm_(segment.parameters(), 2) 79 | scaler.step(optimizer_segment) 80 | scaler.step(optimizer_cluster) 81 | scaler.update() 82 | 83 | # linear probe acc check 84 | pred_label = linear_logits.argmax(dim=1) 85 | flat_pred_label = pred_label.reshape(-1) 86 | acc = (flat_pred_label[flat_label_mask] == flat_label[flat_label_mask]).sum() / flat_label[ 87 | flat_label_mask].numel() 88 | total_acc += acc.item() 89 | 90 | # loss check 91 | total_loss += loss.item() 92 | total_loss_linear += loss_linear.item() 93 | total_loss_mod += loss_mod.item() 94 | 95 | # real-time print 96 | desc = f'[Train] Loss: {total_loss / (idx + 1):.2f}={total_loss_linear / (idx + 1):.2f}{total_loss_mod / (idx + 1):.2f}' 97 | desc += f' ACC: {100. 
* total_acc / (idx + 1):.1f}%' 98 | prog_bar.set_description(desc, refresh=True) 99 | 100 | # Interrupt for sync GPU Process 101 | if args.distributed: dist.barrier() 102 | 103 | 104 | @Wrapper.TestPrint 105 | def test(args, net, segment, cluster, nice, test_loader): 106 | global counter_test 107 | segment.eval() 108 | 109 | total_acc = 0 110 | prog_bar = tqdm(enumerate(test_loader), total=len(test_loader), leave=True) 111 | for idx, batch in prog_bar: 112 | # image and label and self supervised feature 113 | img = batch["img"].cuda() 114 | label = batch["label"].cuda() 115 | 116 | # intermediate feature 117 | with autocast(): 118 | 119 | feat = net(img)[:, 1:, :] 120 | seg_feat_ema = segment.head_ema(feat) 121 | 122 | # linear probe loss 123 | linear_logits = segment.linear(seg_feat_ema) 124 | linear_logits = F.interpolate(linear_logits, label.shape[-2:], mode='bilinear', align_corners=False) 125 | flat_label = label.reshape(-1) 126 | flat_label_mask = (flat_label >= 0) & (flat_label < args.n_classes) 127 | 128 | # interp feat 129 | interp_seg_feat = F.interpolate(transform(seg_feat_ema), label.shape[-2:], mode='bilinear', align_corners=False) 130 | 131 | # cluster 132 | cluster_preds = cluster.forward_centroid(untransform(interp_seg_feat), inference=True) 133 | 134 | # linear probe acc check 135 | pred_label = linear_logits.argmax(dim=1) 136 | flat_pred_label = pred_label.reshape(-1) 137 | acc = (flat_pred_label[flat_label_mask] == flat_label[flat_label_mask]).sum() / flat_label[ 138 | flat_label_mask].numel() 139 | total_acc += acc.item() 140 | 141 | # nice evaluation 142 | _, desc_nice = nice.eval(cluster_preds, label) 143 | 144 | # real-time print 145 | desc = f'[TEST] Acc (Linear): {100. * total_acc / (idx + 1):.1f}% | {desc_nice}' 146 | prog_bar.set_description(desc, refresh=True) 147 | 148 | # evaludation metric reset 149 | nice.reset() 150 | 151 | # Interrupt for sync GPU Process 152 | if args.distributed: dist.barrier() 153 | 154 | 155 | 156 | 157 | 158 | def main(rank, args, ngpus_per_node): 159 | 160 | # setup ddp process 161 | if args.distributed: ddp_setup(args, rank, ngpus_per_node) 162 | 163 | # setting gpu id of this process 164 | torch.cuda.set_device(rank) 165 | 166 | # print argparse 167 | print_argparse(args, rank) 168 | 169 | # dataset loader 170 | train_loader, test_loader, sampler = dataloader(args) 171 | 172 | # network loader 173 | net = network_loader(args, rank) 174 | segment = segment_tr_loader(args, rank) 175 | cluster = cluster_tr_loader(args, rank) 176 | 177 | # distributed parsing 178 | if args.distributed: net = net.module; segment = segment.module; cluster = cluster.module 179 | 180 | # optimizer 181 | if args.dataset=='cityscapes': 182 | optimizer_segment = torch.optim.Adam(segment.parameters(), lr=1e-3 * ngpus_per_node) 183 | optimizer_cluster = torch.optim.Adam(cluster.parameters(), lr=1e-3 * ngpus_per_node) 184 | else: 185 | optimizer_segment = torch.optim.Adam(segment.parameters(), lr=1e-3 * ngpus_per_node, weight_decay=1e-4) 186 | optimizer_cluster = torch.optim.Adam(cluster.parameters(), lr=1e-3 * ngpus_per_node) 187 | 188 | # scheduler 189 | scheduler_segment = torch.optim.lr_scheduler.StepLR(optimizer_segment, step_size=2, gamma=0.5) 190 | scheduler_cluster = torch.optim.lr_scheduler.StepLR(optimizer_cluster, step_size=2, gamma=0.5) 191 | 192 | # evaluation 193 | nice = NiceTool(args.n_classes) 194 | 195 | ################################################################################### 196 | # First, run train_mediator.py 197 | path, is_exist 
= pickle_path_and_exist(args) 198 | 199 | # early save for time 200 | if is_exist: 201 | # load 202 | codebook = np.load(path) 203 | cb = torch.from_numpy(codebook).cuda() 204 | cluster.codebook.data = cb 205 | cluster.codebook.requires_grad = False 206 | segment.head.codebook = cb 207 | segment.head_ema.codebook = cb 208 | 209 | # print successful loading modularity 210 | rprint(f'Modularity {path} loaded', rank) 211 | 212 | # Interrupt for sync GPU Process 213 | if args.distributed: dist.barrier() 214 | 215 | else: 216 | rprint('Train Modularity-based Codebook First', rank) 217 | return 218 | ################################################################################### 219 | 220 | 221 | # train 222 | for epoch in range(args.epoch): 223 | 224 | # for shuffle 225 | if args.distributed: sampler.set_epoch(epoch) 226 | 227 | 228 | # train 229 | train( 230 | epoch, # for decorator 231 | rank, # for decorator 232 | args, 233 | net, 234 | segment, 235 | cluster, 236 | train_loader, 237 | optimizer_segment, 238 | optimizer_cluster) 239 | 240 | test( 241 | epoch, # for decorator 242 | rank, # for decorator 243 | args, 244 | net, 245 | segment, 246 | cluster, 247 | nice, 248 | test_loader) 249 | 250 | scheduler_segment.step() 251 | scheduler_cluster.step() 252 | 253 | if (rank == 0): 254 | baseline = args.ckpt.split('/')[-1].split('.')[0] 255 | 256 | # filepath hierarchy 257 | check_dir(f'CAUSE/{args.dataset}/{baseline}/{args.num_codebook}') 258 | 259 | # save path 260 | y = f'CAUSE/{args.dataset}/{baseline}/{args.num_codebook}/segment_tr.pth' 261 | torch.save(segment.state_dict(), y) 262 | 263 | y = f'CAUSE/{args.dataset}/{baseline}/{args.num_codebook}/cluster_tr.pth' 264 | torch.save(cluster.state_dict(), y) 265 | print(f'-----------------TEST Epoch {epoch}: SAVING CHECKPOINT IN {y}-----------------') 266 | 267 | # Interrupt for sync GPU Process 268 | if args.distributed: dist.barrier() 269 | 270 | # Closing DDP 271 | if args.distributed: dist.barrier(); dist.destroy_process_group() 272 | 273 | if __name__ == "__main__": 274 | 275 | # fetch args 276 | parser = argparse.ArgumentParser() 277 | # model parameter 278 | parser.add_argument('--NAME-TAG', default='CAUSE-TR', type=str) 279 | parser.add_argument('--data_dir', default='/mnt/hard2/lbk-iccv/datasets/', type=str) 280 | parser.add_argument('--dataset', default='cocostuff27', type=str) 281 | parser.add_argument('--ckpt', default='checkpoint/dino_vit_base_8.pth', type=str) 282 | parser.add_argument('--epoch', default=5, type=int) 283 | parser.add_argument('--distributed', default=True, type=str2bool) 284 | parser.add_argument('--load_segment', default=True, type=str2bool) 285 | parser.add_argument('--load_cluster', default=False, type=str2bool) 286 | parser.add_argument('--train_resolution', default=320, type=int) 287 | parser.add_argument('--test_resolution', default=320, type=int) 288 | parser.add_argument('--batch_size', default=16, type=int) 289 | parser.add_argument('--num_workers', default=int(os.cpu_count() / 8), type=int) 290 | 291 | # DDP 292 | parser.add_argument('--gpu', default='0,1,2,3', type=str) 293 | parser.add_argument('--port', default='12355', type=str) 294 | 295 | # codebook parameter 296 | parser.add_argument('--grid', default='yes', type=str2bool) 297 | parser.add_argument('--num_codebook', default=2048, type=int) 298 | 299 | # model parameter 300 | parser.add_argument('--reduced_dim', default=90, type=int) 301 | parser.add_argument('--projection_dim', default=2048, type=int) 302 | 303 | args = parser.parse_args() 304 
| 305 | if 'dinov2' in args.ckpt: 306 | args.train_resolution=322 307 | args.test_resolution=322 308 | if 'small' in args.ckpt: 309 | args.dim=384 310 | elif 'base' in args.ckpt: 311 | args.dim=768 312 | args.num_queries=args.train_resolution**2 // int(args.ckpt.split('_')[-1].split('.')[0])**2 313 | 314 | # the number of gpus for multi-process 315 | gpu_list = list(map(int, args.gpu.split(','))) 316 | ngpus_per_node = len(gpu_list) 317 | 318 | if args.distributed: 319 | # cuda visible devices 320 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 321 | # multiprocess spawn 322 | mp.spawn(main, args=(args, ngpus_per_node), nprocs=ngpus_per_node, join=True) 323 | else: 324 | # first gpu index is activated once there are several gpu in args.gpu 325 | main(rank=gpu_list[0], args=args, ngpus_per_node=1) 326 | -------------------------------------------------------------------------------- /loader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ByungKwanLee/Causal-Unsupervised-Segmentation/ee0aa8478a6b6704f4db44ecc70e44a14fe5067f/loader/__init__.py -------------------------------------------------------------------------------- /loader/netloader.py: -------------------------------------------------------------------------------- 1 | from utils.utils import * 2 | from modules.segment import Segment_MLP 3 | from modules.segment import Segment_TR 4 | from modules.segment_module import Cluster 5 | from torch.nn.parallel import DistributedDataParallel 6 | 7 | def network_loader(args, rank=0): 8 | # load network 9 | net = load_model(args.ckpt, rank).cuda() 10 | if args.distributed: 11 | net = DistributedDataParallel(net, device_ids=[rank]) 12 | freeze(net) 13 | return net 14 | 15 | def cluster_mlp_loader(args, rank): 16 | cluster = Cluster(args).cuda() 17 | 18 | if args.load_cluster: 19 | baseline = args.ckpt.split('/')[-1].split('.')[0] 20 | y = f'CAUSE/{args.dataset}/{baseline}/{args.num_codebook}/cluster_mlp.pth' 21 | cluster.load_state_dict(torch.load(y, map_location=f'cuda:{rank}'), strict=False) 22 | rprint(f'[Cluster] {y} loaded', rank) 23 | 24 | if args.distributed: 25 | cluster = DistributedDataParallel(cluster, device_ids=[rank]) 26 | return cluster 27 | 28 | 29 | def cluster_tr_loader(args, rank): 30 | cluster = Cluster(args).cuda() 31 | 32 | if args.load_cluster: 33 | baseline = args.ckpt.split('/')[-1].split('.')[0] 34 | y = f'CAUSE/{args.dataset}/{baseline}/{args.num_codebook}/cluster_tr.pth' 35 | cluster.load_state_dict(torch.load(y, map_location=f'cuda:{rank}'), strict=False) 36 | rprint(f'[Cluster] {y} loaded', rank) 37 | 38 | if args.distributed: 39 | cluster = DistributedDataParallel(cluster, device_ids=[rank]) 40 | return cluster 41 | 42 | def segment_mlp_loader(args, rank=0): 43 | segment = Segment_MLP(args).cuda() 44 | 45 | if args.load_segment: 46 | baseline = args.ckpt.split('/')[-1].split('.')[0] 47 | y = f'CAUSE/{args.dataset}/{baseline}/{args.num_codebook}/segment_mlp.pth' 48 | segment.load_state_dict(torch.load(y, map_location=f'cuda:{rank}'), strict=False) 49 | rprint(f'[Segment] {y} loaded', rank) 50 | 51 | if args.distributed: 52 | segment = DistributedDataParallel(segment, device_ids=[rank]) 53 | 54 | return segment 55 | 56 | def segment_tr_loader(args, rank=0): 57 | segment = Segment_TR(args).cuda() 58 | 59 | if args.load_segment: 60 | baseline = args.ckpt.split('/')[-1].split('.')[0] 61 | y = f'CAUSE/{args.dataset}/{baseline}/{args.num_codebook}/segment_tr.pth' 62 | 
segment.load_state_dict(torch.load(y, map_location=f'cuda:{rank}'), strict=False) 63 | rprint(f'[Segment] {y} loaded', rank) 64 | 65 | if args.distributed: 66 | segment = DistributedDataParallel(segment, device_ids=[rank]) 67 | 68 | return segment 69 | 70 | 71 | def checkpoint_module(checkpoint, net): 72 | from collections import OrderedDict 73 | new_state_dict = OrderedDict() 74 | for k, v in checkpoint.items(): 75 | name = k[7:] # remove `module.` 76 | new_state_dict[name] = v 77 | msg = net.load_state_dict(new_state_dict, strict=False) 78 | return msg 79 | 80 | def load_model(ckpt, rank=0): 81 | # name and arch 82 | name = ckpt_to_name(ckpt) 83 | arch = ckpt_to_arch(ckpt) 84 | 85 | if name == "dino" or name == "mae": 86 | import models.dinomaevit as model 87 | elif name == "dinov2": 88 | import models.dinov2vit as model 89 | elif name == "ibot": 90 | import models.ibotvit as model 91 | elif name == "msn": 92 | import models.msnvit as model 93 | else: 94 | raise ValueError 95 | 96 | net = getattr(model, arch)() 97 | checkpoint = torch.load(ckpt, map_location=torch.device(f'cuda:{rank}')) 98 | if name == "mae": 99 | msg = net.load_state_dict(checkpoint["model"], strict=False) 100 | elif name == "dino": 101 | msg = net.load_state_dict(checkpoint, strict=False) 102 | elif name == "dinov2": 103 | msg = net.load_state_dict(checkpoint, strict=False) 104 | elif name == "ibot": 105 | msg = net.load_state_dict(checkpoint['state_dict'], strict=False) 106 | elif name == "msn": 107 | msg = checkpoint_module(checkpoint['target_encoder'], net) 108 | 109 | # check incompatible layer or variables 110 | rprint(msg, rank) 111 | 112 | return net 113 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ByungKwanLee/Causal-Unsupervised-Segmentation/ee0aa8478a6b6704f4db44ecc70e44a14fe5067f/models/__init__.py -------------------------------------------------------------------------------- /models/dinomaevit.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | import timm.models.layers 7 | import timm.models.vision_transformer as vit 8 | 9 | #TODO: Remaining 10 | class PatchEmbed_DimensionFree(timm.models.layers.PatchEmbed): 11 | """ 2D Image to Patch Embedding 12 | """ 13 | def __init__( 14 | self, 15 | img_size=224, 16 | patch_size=16, 17 | in_chans=3, 18 | embed_dim=768, 19 | norm_layer=None, 20 | flatten=True, 21 | bias=True, 22 | ): 23 | super().__init__(img_size=img_size, 24 | patch_size=patch_size, 25 | in_chans=in_chans, 26 | embed_dim=embed_dim, 27 | norm_layer=norm_layer, 28 | flatten=flatten, 29 | bias=bias,) 30 | 31 | def forward(self, x): 32 | x = self.proj(x) 33 | if self.flatten: 34 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 35 | x = self.norm(x) 36 | return x 37 | 38 | # DINO VIT 39 | class DINOMAEVisionTransformer(vit.VisionTransformer): 40 | def __init__(self, **kwargs): 41 | super(DINOMAEVisionTransformer, self).__init__(embed_layer=PatchEmbed_DimensionFree, **kwargs) 42 | 43 | def forward_features(self, x): 44 | w, h = x.shape[2:] 45 | x = self.patch_embed(x) 46 | x = self.interpolate_pos_embed(x, w, h) # to fit changed resolution 47 | x = self.norm_pre(x) 48 | x = self.blocks(x) 49 | x = self.norm(x) 50 | return x 51 | 52 | def forward(self, *args, **kwargs): return 
self.forward_features(*args, **kwargs) 53 | 54 | def interpolate_pos_embed(self, x, w, h): 55 | if self.no_embed_class: 56 | # deit-3, updated JAX (big vision) 57 | # position embedding does not overlap with class token, add then concat 58 | x = x + self.pos_embed 59 | if self.cls_token is not None: 60 | x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) 61 | else: 62 | # original timm, JAX, and deit vit impl 63 | # pos_embed has entry for class token, concat then add 64 | if self.cls_token is not None: 65 | x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) 66 | 67 | # original version is just adding positional embedding 68 | # x = x + self.pos_embed 69 | # but to adaptively control the image resolution, interpolation of positional embedding is necessary 70 | x = x + self.interpolate_pos_encoding(x, w, h) 71 | return self.pos_drop(x) 72 | 73 | def interpolate_pos_encoding(self, x, w, h): 74 | npatch = x.shape[1] - 1 75 | N = self.pos_embed.shape[1] - 1 76 | if npatch == N and w == h: 77 | return self.pos_embed 78 | class_pos_embed = self.pos_embed[:, 0] 79 | patch_pos_embed = self.pos_embed[:, 1:] 80 | dim = x.shape[-1] 81 | w0 = w // self.patch_embed.patch_size[0] 82 | h0 = h // self.patch_embed.patch_size[1] 83 | # we add a small number to avoid floating point error in the interpolation 84 | # see discussion at https://github.com/facebookresearch/dino/issues/8 85 | w0, h0 = w0 + 0.1, h0 + 0.1 86 | patch_pos_embed = nn.functional.interpolate( 87 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), 88 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), 89 | mode='bicubic', 90 | ) 91 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] 92 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 93 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) 94 | 95 | # DINO 96 | def dino_vit_small_8(**kwargs): 97 | model = DINOMAEVisionTransformer( 98 | patch_size=8, embed_dim=384, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 99 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 100 | return model 101 | 102 | def dino_vit_base_8(**kwargs): 103 | model = DINOMAEVisionTransformer( 104 | patch_size=8, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 105 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 106 | return model 107 | def dino_vit_small_16(**kwargs): 108 | model = DINOMAEVisionTransformer( 109 | patch_size=16, embed_dim=384, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 110 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 111 | return model 112 | 113 | def dino_vit_base_16(**kwargs): 114 | model = DINOMAEVisionTransformer( 115 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 116 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 117 | return model 118 | 119 | # MAE 120 | def mae_vit_base_16(**kwargs): 121 | model = DINOMAEVisionTransformer( 122 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 123 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 124 | return model -------------------------------------------------------------------------------- /models/dinov2vit.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.utils.checkpoint 7 | from torch.nn.init import trunc_normal_ 
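# NOTE: the classes below largely follow the reference DINOv2 ViT implementation
# (Attention, MemEffAttention, NestedTensorBlock, DinoVisionTransformer); xFormers is
# an optional dependency, needed only for the memory-efficient / nested-tensor attention
# paths guarded by XFORMERS_AVAILABLE further down.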
8 | from torch import Tensor, nn 9 | from typing import Callable, Optional, Tuple, Union, Sequence, Dict, Any, List 10 | 11 | 12 | def drop_path(x, drop_prob: float = 0.0, training: bool = False): 13 | if drop_prob == 0.0 or not training: 14 | return x 15 | keep_prob = 1 - drop_prob 16 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 17 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 18 | if keep_prob > 0.0: 19 | random_tensor.div_(keep_prob) 20 | output = x * random_tensor 21 | return output 22 | 23 | 24 | class DropPath(nn.Module): 25 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" 26 | 27 | def __init__(self, drop_prob=None): 28 | super(DropPath, self).__init__() 29 | self.drop_prob = drop_prob 30 | 31 | def forward(self, x): 32 | return drop_path(x, self.drop_prob, self.training) 33 | 34 | class LayerScale(nn.Module): 35 | def __init__( 36 | self, 37 | dim: int, 38 | init_values: Union[float, Tensor] = 1e-5, 39 | inplace: bool = False, 40 | ) -> None: 41 | super().__init__() 42 | self.inplace = inplace 43 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 44 | 45 | def forward(self, x: Tensor) -> Tensor: 46 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 47 | 48 | 49 | try: 50 | from xformers.ops import fmha 51 | from xformers.ops import scaled_index_add, index_select_cat 52 | 53 | XFORMERS_AVAILABLE = True 54 | except ImportError: 55 | XFORMERS_AVAILABLE = False 56 | 57 | 58 | def drop_add_residual_stochastic_depth( 59 | x: Tensor, 60 | residual_func: Callable[[Tensor], Tensor], 61 | sample_drop_ratio: float = 0.0, 62 | ) -> Tensor: 63 | # 1) extract subset using permutation 64 | b, n, d = x.shape 65 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) 66 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size] 67 | x_subset = x[brange] 68 | 69 | # 2) apply residual_func to get residual 70 | residual = residual_func(x_subset) 71 | 72 | x_flat = x.flatten(1) 73 | residual = residual.flatten(1) 74 | 75 | residual_scale_factor = b / sample_subset_size 76 | 77 | # 3) add the residual 78 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) 79 | return x_plus_residual.view_as(x) 80 | 81 | 82 | def get_branges_scales(x, sample_drop_ratio=0.0): 83 | b, n, d = x.shape 84 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) 85 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size] 86 | residual_scale_factor = b / sample_subset_size 87 | return brange, residual_scale_factor 88 | 89 | 90 | def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): 91 | if scaling_vector is None: 92 | x_flat = x.flatten(1) 93 | residual = residual.flatten(1) 94 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) 95 | else: 96 | x_plus_residual = scaled_index_add( 97 | x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor 98 | ) 99 | return x_plus_residual 100 | 101 | 102 | attn_bias_cache: Dict[Tuple, Any] = {} 103 | 104 | 105 | def get_attn_bias_and_cat(x_list, branges=None): 106 | """ 107 | this will perform the index select, cat the tensors, and provide the attn_bias from cache 108 | """ 109 | batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] 110 | all_shapes = tuple((b, x.shape[1]) for b, x in 
zip(batch_sizes, x_list)) 111 | if all_shapes not in attn_bias_cache.keys(): 112 | seqlens = [] 113 | for b, x in zip(batch_sizes, x_list): 114 | for _ in range(b): 115 | seqlens.append(x.shape[1]) 116 | attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) 117 | attn_bias._batch_sizes = batch_sizes 118 | attn_bias_cache[all_shapes] = attn_bias 119 | 120 | if branges is not None: 121 | cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) 122 | else: 123 | tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) 124 | cat_tensors = torch.cat(tensors_bs1, dim=1) 125 | 126 | return attn_bias_cache[all_shapes], cat_tensors 127 | 128 | 129 | def drop_add_residual_stochastic_depth_list( 130 | x_list: List[Tensor], 131 | residual_func: Callable[[Tensor, Any], Tensor], 132 | sample_drop_ratio: float = 0.0, 133 | scaling_vector=None, 134 | ) -> Tensor: 135 | # 1) generate random set of indices for dropping samples in the batch 136 | branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] 137 | branges = [s[0] for s in branges_scales] 138 | residual_scale_factors = [s[1] for s in branges_scales] 139 | 140 | # 2) get attention bias and index+concat the tensors 141 | attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) 142 | 143 | # 3) apply residual_func to get residual, and split the result 144 | residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore 145 | 146 | outputs = [] 147 | for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): 148 | outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) 149 | return outputs 150 | 151 | 152 | 153 | 154 | try: 155 | from xformers.ops import memory_efficient_attention, unbind, fmha 156 | 157 | XFORMERS_AVAILABLE = True 158 | except ImportError: 159 | XFORMERS_AVAILABLE = False 160 | 161 | 162 | class Mlp(nn.Module): 163 | def __init__( 164 | self, 165 | in_features: int, 166 | hidden_features: Optional[int] = None, 167 | out_features: Optional[int] = None, 168 | act_layer: Callable[..., nn.Module] = nn.GELU, 169 | drop: float = 0.0, 170 | bias: bool = True, 171 | ) -> None: 172 | super().__init__() 173 | out_features = out_features or in_features 174 | hidden_features = hidden_features or in_features 175 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) 176 | self.act = act_layer() 177 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) 178 | self.drop = nn.Dropout(drop) 179 | 180 | def forward(self, x: Tensor) -> Tensor: 181 | x = self.fc1(x) 182 | x = self.act(x) 183 | x = self.drop(x) 184 | x = self.fc2(x) 185 | x = self.drop(x) 186 | return x 187 | 188 | class Attention(nn.Module): 189 | def __init__( 190 | self, 191 | dim: int, 192 | num_heads: int = 8, 193 | qkv_bias: bool = False, 194 | proj_bias: bool = True, 195 | attn_drop: float = 0.0, 196 | proj_drop: float = 0.0, 197 | ) -> None: 198 | super().__init__() 199 | self.num_heads = num_heads 200 | head_dim = dim // num_heads 201 | self.scale = head_dim**-0.5 202 | 203 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 204 | self.attn_drop = nn.Dropout(attn_drop) 205 | self.proj = nn.Linear(dim, dim, bias=proj_bias) 206 | self.proj_drop = nn.Dropout(proj_drop) 207 | 208 | def forward(self, x: Tensor) -> Tensor: 209 | B, N, C = x.shape 210 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 
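# qkv has shape (3, B, num_heads, N, head_dim) after the reshape/permute above,
# so indexing dim 0 yields q, k, v; note that q is pre-scaled by head_dim ** -0.5
# (self.scale) here rather than scaling the attention logits afterwards.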
211 | 212 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] 213 | attn = q @ k.transpose(-2, -1) 214 | 215 | attn = attn.softmax(dim=-1) 216 | attn = self.attn_drop(attn) 217 | 218 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 219 | x = self.proj(x) 220 | x = self.proj_drop(x) 221 | return x 222 | 223 | 224 | class MemEffAttention(Attention): 225 | def forward(self, x: Tensor, attn_bias=None) -> Tensor: 226 | if not XFORMERS_AVAILABLE: 227 | assert attn_bias is None, "xFormers is required for nested tensors usage" 228 | return super().forward(x) 229 | 230 | B, N, C = x.shape 231 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) 232 | 233 | q, k, v = unbind(qkv, 2) 234 | 235 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) 236 | x = x.reshape([B, N, C]) 237 | 238 | x = self.proj(x) 239 | x = self.proj_drop(x) 240 | return x 241 | 242 | 243 | 244 | class Block(nn.Module): 245 | def __init__( 246 | self, 247 | dim: int, 248 | num_heads: int, 249 | mlp_ratio: float = 4.0, 250 | qkv_bias: bool = False, 251 | proj_bias: bool = True, 252 | ffn_bias: bool = True, 253 | drop: float = 0.0, 254 | attn_drop: float = 0.0, 255 | init_values=None, 256 | drop_path: float = 0.0, 257 | act_layer: Callable[..., nn.Module] = nn.GELU, 258 | norm_layer: Callable[..., nn.Module] = nn.LayerNorm, 259 | attn_class: Callable[..., nn.Module] = Attention, 260 | ffn_layer: Callable[..., nn.Module] = Mlp, 261 | ) -> None: 262 | super().__init__() 263 | # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") 264 | self.norm1 = norm_layer(dim) 265 | self.attn = attn_class( 266 | dim, 267 | num_heads=num_heads, 268 | qkv_bias=qkv_bias, 269 | proj_bias=proj_bias, 270 | attn_drop=attn_drop, 271 | proj_drop=drop, 272 | ) 273 | self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 274 | self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 275 | 276 | self.norm2 = norm_layer(dim) 277 | mlp_hidden_dim = int(dim * mlp_ratio) 278 | self.mlp = ffn_layer( 279 | in_features=dim, 280 | hidden_features=mlp_hidden_dim, 281 | act_layer=act_layer, 282 | drop=drop, 283 | bias=ffn_bias, 284 | ) 285 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 286 | self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 287 | 288 | self.sample_drop_ratio = drop_path 289 | 290 | def forward(self, x: Tensor) -> Tensor: 291 | def attn_residual_func(x: Tensor) -> Tensor: 292 | return self.ls1(self.attn(self.norm1(x))) 293 | 294 | def ffn_residual_func(x: Tensor) -> Tensor: 295 | return self.ls2(self.mlp(self.norm2(x))) 296 | 297 | if self.training and self.sample_drop_ratio > 0.1: 298 | # the overhead is compensated only for a drop path rate larger than 0.1 299 | x = drop_add_residual_stochastic_depth( 300 | x, 301 | residual_func=attn_residual_func, 302 | sample_drop_ratio=self.sample_drop_ratio, 303 | ) 304 | x = drop_add_residual_stochastic_depth( 305 | x, 306 | residual_func=ffn_residual_func, 307 | sample_drop_ratio=self.sample_drop_ratio, 308 | ) 309 | elif self.training and self.sample_drop_ratio > 0.0: 310 | x = x + self.drop_path1(attn_residual_func(x)) 311 | x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 312 | else: 313 | x = x + attn_residual_func(x) 314 | x = x + ffn_residual_func(x) 315 | return x 316 | 317 | 318 | class NestedTensorBlock(Block): 319 | def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: 320 | """ 321 | x_list contains a list of 
tensors to nest together and run 322 | """ 323 | assert isinstance(self.attn, MemEffAttention) 324 | 325 | if self.training and self.sample_drop_ratio > 0.0: 326 | 327 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 328 | return self.attn(self.norm1(x), attn_bias=attn_bias) 329 | 330 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 331 | return self.mlp(self.norm2(x)) 332 | 333 | x_list = drop_add_residual_stochastic_depth_list( 334 | x_list, 335 | residual_func=attn_residual_func, 336 | sample_drop_ratio=self.sample_drop_ratio, 337 | scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, 338 | ) 339 | x_list = drop_add_residual_stochastic_depth_list( 340 | x_list, 341 | residual_func=ffn_residual_func, 342 | sample_drop_ratio=self.sample_drop_ratio, 343 | scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, 344 | ) 345 | return x_list 346 | else: 347 | 348 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 349 | return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) 350 | 351 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 352 | return self.ls2(self.mlp(self.norm2(x))) 353 | 354 | attn_bias, x = get_attn_bias_and_cat(x_list) 355 | x = x + attn_residual_func(x, attn_bias=attn_bias) 356 | x = x + ffn_residual_func(x) 357 | return attn_bias.split(x) 358 | 359 | def forward(self, x_or_x_list): 360 | if isinstance(x_or_x_list, Tensor): 361 | return super().forward(x_or_x_list) 362 | elif isinstance(x_or_x_list, list): 363 | assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage" 364 | return self.forward_nested(x_or_x_list) 365 | else: 366 | raise AssertionError 367 | 368 | 369 | 370 | def make_2tuple(x): 371 | if isinstance(x, tuple): 372 | assert len(x) == 2 373 | return x 374 | 375 | assert isinstance(x, int) 376 | return (x, x) 377 | 378 | 379 | class PatchEmbed(nn.Module): 380 | """ 381 | 2D image to patch embedding: (B,C,H,W) -> (B,N,D) 382 | 383 | Args: 384 | img_size: Image size. 385 | patch_size: Patch token size. 386 | in_chans: Number of input image channels. 387 | embed_dim: Number of linear projection output channels. 388 | norm_layer: Normalization layer. 
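flatten_embedding: If True, return a flattened (B, N, D) token sequence; if False, keep the (B, H, W, D) patch grid (see forward below).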
389 | """ 390 | 391 | def __init__( 392 | self, 393 | img_size: Union[int, Tuple[int, int]] = 224, 394 | patch_size: Union[int, Tuple[int, int]] = 16, 395 | in_chans: int = 3, 396 | embed_dim: int = 768, 397 | norm_layer: Optional[Callable] = None, 398 | flatten_embedding: bool = True, 399 | ) -> None: 400 | super().__init__() 401 | 402 | image_HW = make_2tuple(img_size) 403 | patch_HW = make_2tuple(patch_size) 404 | patch_grid_size = ( 405 | image_HW[0] // patch_HW[0], 406 | image_HW[1] // patch_HW[1], 407 | ) 408 | 409 | self.img_size = image_HW 410 | self.patch_size = patch_HW 411 | self.patches_resolution = patch_grid_size 412 | self.num_patches = patch_grid_size[0] * patch_grid_size[1] 413 | 414 | self.in_chans = in_chans 415 | self.embed_dim = embed_dim 416 | 417 | self.flatten_embedding = flatten_embedding 418 | 419 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) 420 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 421 | 422 | def forward(self, x: Tensor) -> Tensor: 423 | _, _, H, W = x.shape 424 | patch_H, patch_W = self.patch_size 425 | 426 | assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" 427 | assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" 428 | 429 | x = self.proj(x) # B C H W 430 | H, W = x.size(2), x.size(3) 431 | x = x.flatten(2).transpose(1, 2) # B HW C 432 | x = self.norm(x) 433 | if not self.flatten_embedding: 434 | x = x.reshape(-1, H, W, self.embed_dim) # B H W C 435 | return x 436 | 437 | def flops(self) -> float: 438 | Ho, Wo = self.patches_resolution 439 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 440 | if self.norm is not None: 441 | flops += Ho * Wo * self.embed_dim 442 | return flops 443 | 444 | 445 | 446 | 447 | 448 | def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: 449 | if not depth_first and include_root: 450 | fn(module=module, name=name) 451 | for child_name, child_module in module.named_children(): 452 | child_name = ".".join((name, child_name)) if name else child_name 453 | named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) 454 | if depth_first and include_root: 455 | fn(module=module, name=name) 456 | return module 457 | 458 | 459 | class BlockChunk(nn.ModuleList): 460 | def forward(self, x): 461 | for b in self: 462 | x = b(x) 463 | return x 464 | 465 | 466 | class DinoVisionTransformer(nn.Module): 467 | def __init__( 468 | self, 469 | img_size=518, 470 | patch_size=16, 471 | in_chans=3, 472 | embed_dim=768, 473 | depth=12, 474 | num_heads=12, 475 | mlp_ratio=4.0, 476 | qkv_bias=True, 477 | ffn_bias=True, 478 | proj_bias=True, 479 | drop_path_rate=0.0, 480 | drop_path_uniform=False, 481 | init_values=1.0, # for layerscale: None or 0 => no layerscale 482 | embed_layer=PatchEmbed, 483 | act_layer=nn.GELU, 484 | block_fn=Block, 485 | ffn_layer="mlp", 486 | block_chunks=0, 487 | ): 488 | """ 489 | Args: 490 | img_size (int, tuple): input image size 491 | patch_size (int, tuple): patch size 492 | in_chans (int): number of input channels 493 | embed_dim (int): embedding dimension 494 | depth (int): depth of transformer 495 | num_heads (int): number of attention heads 496 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 497 | qkv_bias (bool): enable bias for qkv if True 498 | proj_bias (bool): enable bias for proj in attn if True 499 | ffn_bias 
(bool): enable bias for ffn if True 500 | drop_path_rate (float): stochastic depth rate 501 | drop_path_uniform (bool): apply uniform drop rate across blocks 502 | weight_init (str): weight init scheme 503 | init_values (float): layer-scale init values 504 | embed_layer (nn.Module): patch embedding layer 505 | act_layer (nn.Module): MLP activation layer 506 | block_fn (nn.Module): transformer block class 507 | ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" 508 | block_chunks: (int) split block sequence into block_chunks units for FSDP wrap 509 | """ 510 | super().__init__() 511 | norm_layer = partial(nn.LayerNorm, eps=1e-6) 512 | 513 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 514 | self.num_tokens = 1 515 | self.n_blocks = depth 516 | self.num_heads = num_heads 517 | self.patch_size = patch_size 518 | 519 | self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 520 | num_patches = self.patch_embed.num_patches 521 | 522 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 523 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) 524 | 525 | if drop_path_uniform is True: 526 | dpr = [drop_path_rate] * depth 527 | else: 528 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 529 | 530 | if ffn_layer == "mlp": 531 | ffn_layer = Mlp 532 | else: 533 | raise NotImplementedError 534 | 535 | blocks_list = [ 536 | block_fn( 537 | dim=embed_dim, 538 | num_heads=num_heads, 539 | mlp_ratio=mlp_ratio, 540 | qkv_bias=qkv_bias, 541 | proj_bias=proj_bias, 542 | ffn_bias=ffn_bias, 543 | drop_path=dpr[i], 544 | norm_layer=norm_layer, 545 | act_layer=act_layer, 546 | ffn_layer=ffn_layer, 547 | init_values=init_values, 548 | ) 549 | for i in range(depth) 550 | ] 551 | if block_chunks > 0: 552 | self.chunked_blocks = True 553 | chunked_blocks = [] 554 | chunksize = depth // block_chunks 555 | for i in range(0, depth, chunksize): 556 | # this is to keep the block index consistent if we chunk the block list 557 | chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) 558 | self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) 559 | else: 560 | self.chunked_blocks = False 561 | self.blocks = nn.ModuleList(blocks_list) 562 | 563 | self.norm = norm_layer(embed_dim) 564 | self.head = nn.Identity() 565 | 566 | self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) 567 | 568 | self.init_weights() 569 | 570 | def init_weights(self): 571 | trunc_normal_(self.pos_embed, std=0.02) 572 | nn.init.normal_(self.cls_token, std=1e-6) 573 | named_apply(init_weights_vit_timm, self) 574 | 575 | def interpolate_pos_encoding(self, x, w, h): 576 | previous_dtype = x.dtype 577 | npatch = x.shape[1] - 1 578 | N = self.pos_embed.shape[1] - 1 579 | if npatch == N and w == h: 580 | return self.pos_embed 581 | pos_embed = self.pos_embed.float() 582 | class_pos_embed = pos_embed[:, 0] 583 | patch_pos_embed = pos_embed[:, 1:] 584 | dim = x.shape[-1] 585 | w0 = w // self.patch_size 586 | h0 = h // self.patch_size 587 | # we add a small number to avoid floating point error in the interpolation 588 | # see discussion at https://github.com/facebookresearch/dino/issues/8 589 | w0, h0 = w0 + 0.1, h0 + 0.1 590 | 591 | patch_pos_embed = nn.functional.interpolate( 592 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), 593 | scale_factor=(w0 / math.sqrt(N), 
h0 / math.sqrt(N)), 594 | mode="bicubic", 595 | ) 596 | 597 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] 598 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 599 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) 600 | 601 | def prepare_tokens_with_masks(self, x, masks=None): 602 | B, nc, w, h = x.shape 603 | x = self.patch_embed(x) 604 | if masks is not None: 605 | x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) 606 | 607 | x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) 608 | x = x + self.interpolate_pos_encoding(x, w, h) 609 | 610 | return x 611 | 612 | def forward_features_list(self, x_list, masks_list): 613 | x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] 614 | for blk in self.blocks: 615 | x = blk(x) 616 | 617 | all_x = x 618 | output = [] 619 | for x, masks in zip(all_x, masks_list): 620 | x_norm = self.norm(x) 621 | output.append( 622 | { 623 | "x_norm_clstoken": x_norm[:, 0], 624 | "x_norm_patchtokens": x_norm[:, 1:], 625 | "x_prenorm": x, 626 | "masks": masks, 627 | } 628 | ) 629 | return output 630 | 631 | def forward_features(self, x, masks=None): 632 | if isinstance(x, list): 633 | return self.forward_features_list(x, masks) 634 | 635 | x = self.prepare_tokens_with_masks(x, masks) 636 | for blk in self.blocks: 637 | x = blk(x) 638 | 639 | x_norm = self.norm(x) 640 | return x_norm 641 | 642 | def _get_intermediate_layers_not_chunked(self, x, n=1): 643 | x = self.prepare_tokens_with_masks(x) 644 | # If n is an int, take the n last blocks. If it's a list, take them 645 | output, total_block_len = [], len(self.blocks) 646 | blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n 647 | for i, blk in enumerate(self.blocks): 648 | x = blk(x) 649 | if i in blocks_to_take: 650 | output.append(x) 651 | assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" 652 | return output 653 | 654 | def _get_intermediate_layers_chunked(self, x, n=1): 655 | x = self.prepare_tokens_with_masks(x) 656 | output, i, total_block_len = [], 0, len(self.blocks[-1]) 657 | # If n is an int, take the n last blocks. 
If it's a list, take them 658 | blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n 659 | for block_chunk in self.blocks: 660 | for blk in block_chunk[i:]: # Passing the nn.Identity() 661 | x = blk(x) 662 | if i in blocks_to_take: 663 | output.append(x) 664 | i += 1 665 | assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" 666 | return output 667 | 668 | def get_intermediate_layers( 669 | self, 670 | x: torch.Tensor, 671 | n: Union[int, Sequence] = 1, # Layers or n last layers to take 672 | reshape: bool = False, 673 | return_class_token: bool = False, 674 | norm=True, 675 | ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: 676 | if self.chunked_blocks: 677 | outputs = self._get_intermediate_layers_chunked(x, n) 678 | else: 679 | outputs = self._get_intermediate_layers_not_chunked(x, n) 680 | if norm: 681 | outputs = [self.norm(out) for out in outputs] 682 | class_tokens = [out[:, 0] for out in outputs] 683 | outputs = [out[:, 1:] for out in outputs] 684 | if reshape: 685 | B, _, w, h = x.shape 686 | outputs = [ 687 | out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() 688 | for out in outputs 689 | ] 690 | if return_class_token: 691 | return tuple(zip(outputs, class_tokens)) 692 | return tuple(outputs) 693 | 694 | def forward(self, *args, **kwargs): 695 | ret = self.forward_features(*args, **kwargs) 696 | return ret 697 | 698 | 699 | def init_weights_vit_timm(module: nn.Module, name: str = ""): 700 | """ViT weight initialization, original timm impl (for reproducibility)""" 701 | if isinstance(module, nn.Linear): 702 | trunc_normal_(module.weight, std=0.02) 703 | if module.bias is not None: 704 | nn.init.zeros_(module.bias) 705 | 706 | 707 | # DINO V2 708 | def dinov2_vit_small_14(**kwargs): 709 | model = DinoVisionTransformer( 710 | patch_size=14, 711 | embed_dim=384, 712 | depth=12, 713 | num_heads=6, 714 | mlp_ratio=4, 715 | block_fn=partial(Block, attn_class=MemEffAttention), 716 | **kwargs, 717 | ) 718 | return model 719 | 720 | 721 | def dinov2_vit_base_14(**kwargs): 722 | model = DinoVisionTransformer( 723 | patch_size=14, 724 | embed_dim=768, 725 | depth=12, 726 | num_heads=12, 727 | mlp_ratio=4, 728 | block_fn=partial(Block, attn_class=MemEffAttention), 729 | **kwargs, 730 | ) 731 | return model 732 | 733 | 734 | def dinov2_vit_large_14(**kwargs): 735 | model = DinoVisionTransformer( 736 | patch_size=14, 737 | embed_dim=1024, 738 | depth=24, 739 | num_heads=16, 740 | mlp_ratio=4, 741 | block_fn=partial(Block, attn_class=MemEffAttention), 742 | **kwargs, 743 | ) 744 | return model -------------------------------------------------------------------------------- /models/ibotvit.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import warnings 4 | import torch.nn as nn 5 | from functools import partial 6 | from torch import Tensor 7 | 8 | 9 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 10 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 11 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 12 | def norm_cdf(x): 13 | # Computes standard normal cumulative distribution function 14 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 15 | 16 | if (mean < a - 2 * std) or (mean > b + 2 * std): 17 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. 
" 18 | "The distribution of values may be incorrect.", 19 | stacklevel=2) 20 | 21 | with torch.no_grad(): 22 | # Values are generated by using a truncated uniform distribution and 23 | # then using the inverse CDF for the normal distribution. 24 | # Get upper and lower cdf values 25 | l = norm_cdf((a - mean) / std) 26 | u = norm_cdf((b - mean) / std) 27 | 28 | # Uniformly fill tensor with values from [l, u], then translate to 29 | # [2l-1, 2u-1]. 30 | tensor.uniform_(2 * l - 1, 2 * u - 1) 31 | 32 | # Use inverse cdf transform for normal distribution to get truncated 33 | # standard normal 34 | tensor.erfinv_() 35 | 36 | # Transform to proper mean, std 37 | tensor.mul_(std * math.sqrt(2.)) 38 | tensor.add_(mean) 39 | 40 | # Clamp to ensure it's in the proper range 41 | tensor.clamp_(min=a, max=b) 42 | return tensor 43 | 44 | 45 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 46 | # type: (Tensor, float, float, float, float) -> Tensor 47 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 48 | 49 | def drop_path(x, drop_prob: float = 0., training: bool = False): 50 | if drop_prob == 0. or not training: 51 | return x 52 | keep_prob = 1 - drop_prob 53 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 54 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 55 | random_tensor.floor_() # binarize 56 | output = x.div(keep_prob) * random_tensor 57 | return output 58 | 59 | 60 | class DropPath(nn.Module): 61 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 62 | """ 63 | def __init__(self, drop_prob=None): 64 | super(DropPath, self).__init__() 65 | self.drop_prob = drop_prob 66 | 67 | def forward(self, x): 68 | return drop_path(x, self.drop_prob, self.training) 69 | 70 | 71 | class Mlp(nn.Module): 72 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 73 | super().__init__() 74 | out_features = out_features or in_features 75 | hidden_features = hidden_features or in_features 76 | self.fc1 = nn.Linear(in_features, hidden_features) 77 | self.act = act_layer() 78 | self.fc2 = nn.Linear(hidden_features, out_features) 79 | self.drop = nn.Dropout(drop) 80 | 81 | def forward(self, x): 82 | x = self.fc1(x) 83 | x = self.act(x) 84 | x = self.drop(x) 85 | x = self.fc2(x) 86 | x = self.drop(x) 87 | return x 88 | 89 | 90 | class Attention(nn.Module): 91 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 92 | super().__init__() 93 | self.num_heads = num_heads 94 | head_dim = dim // num_heads 95 | self.scale = qk_scale or head_dim ** -0.5 96 | 97 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 98 | self.attn_drop = nn.Dropout(attn_drop) 99 | self.proj = nn.Linear(dim, dim) 100 | self.proj_drop = nn.Dropout(proj_drop) 101 | 102 | def forward(self, x): 103 | B, N, C = x.shape 104 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 105 | q, k, v = qkv[0], qkv[1], qkv[2] 106 | 107 | attn = (q @ k.transpose(-2, -1)) * self.scale 108 | attn = attn.softmax(dim=-1) 109 | attn = self.attn_drop(attn) 110 | 111 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 112 | x = self.proj(x) 113 | x = self.proj_drop(x) 114 | return x, attn 115 | 116 | class Block(nn.Module): 117 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., 118 | attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, 
init_values=0): 119 | super().__init__() 120 | self.norm1 = norm_layer(dim) 121 | self.attn = Attention( 122 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 123 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 124 | self.norm2 = norm_layer(dim) 125 | mlp_hidden_dim = int(dim * mlp_ratio) 126 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 127 | 128 | if init_values > 0: 129 | self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) 130 | self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) 131 | else: 132 | self.gamma_1, self.gamma_2 = None, None 133 | 134 | def forward(self, x, return_attention=False): 135 | y, attn = self.attn(self.norm1(x)) 136 | if return_attention: 137 | return attn 138 | if self.gamma_1 is None: 139 | x = x + self.drop_path(y) 140 | x = x + self.drop_path(self.mlp(self.norm2(x))) 141 | else: 142 | x = x + self.drop_path(self.gamma_1 * y) 143 | x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) 144 | return x 145 | 146 | class PatchEmbed(nn.Module): 147 | """ Image to Patch Embedding 148 | """ 149 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 150 | super().__init__() 151 | num_patches = (img_size // patch_size) * (img_size // patch_size) 152 | self.img_size = img_size 153 | self.patch_size = patch_size 154 | self.num_patches = num_patches 155 | 156 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 157 | 158 | def forward(self, x): 159 | B, C, H, W = x.shape 160 | return self.proj(x) 161 | 162 | class VisionTransformer(nn.Module): 163 | """ Vision Transformer """ 164 | def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12, 165 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 166 | drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6), return_all_tokens=False, 167 | init_values=0, use_mean_pooling=False, masked_im_modeling=False): 168 | super().__init__() 169 | self.num_features = self.embed_dim = embed_dim 170 | self.return_all_tokens = return_all_tokens 171 | 172 | self.patch_embed = PatchEmbed( 173 | img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 174 | num_patches = self.patch_embed.num_patches 175 | 176 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 177 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 178 | self.pos_drop = nn.Dropout(p=drop_rate) 179 | 180 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 181 | self.blocks = nn.ModuleList([ 182 | Block( 183 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 184 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, 185 | init_values=init_values) 186 | for i in range(depth)]) 187 | 188 | self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim) 189 | self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None 190 | # Classifier head 191 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 192 | 193 | trunc_normal_(self.pos_embed, std=.02) 194 | trunc_normal_(self.cls_token, std=.02) 195 | self.apply(self._init_weights) 196 | 197 | # masked image modeling 198 | self.masked_im_modeling = 
masked_im_modeling 199 | if masked_im_modeling: 200 | self.masked_embed = nn.Parameter(torch.zeros(1, embed_dim)) 201 | 202 | def _init_weights(self, m): 203 | if isinstance(m, nn.Linear): 204 | trunc_normal_(m.weight, std=.02) 205 | if isinstance(m, nn.Linear) and m.bias is not None: 206 | nn.init.constant_(m.bias, 0) 207 | elif isinstance(m, nn.LayerNorm): 208 | nn.init.constant_(m.bias, 0) 209 | nn.init.constant_(m.weight, 1.0) 210 | 211 | def interpolate_pos_encoding(self, x, w, h): 212 | npatch = x.shape[1] - 1 213 | N = self.pos_embed.shape[1] - 1 214 | if npatch == N and w == h: 215 | return self.pos_embed 216 | class_pos_embed = self.pos_embed[:, 0] 217 | patch_pos_embed = self.pos_embed[:, 1:] 218 | dim = x.shape[-1] 219 | w0 = w // self.patch_embed.patch_size 220 | h0 = h // self.patch_embed.patch_size 221 | # we add a small number to avoid floating point error in the interpolation 222 | # see discussion at https://github.com/facebookresearch/dino/issues/8 223 | w0, h0 = w0 + 0.1, h0 + 0.1 224 | patch_pos_embed = nn.functional.interpolate( 225 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), 226 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), 227 | mode='bicubic', 228 | ) 229 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] 230 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 231 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) 232 | 233 | def prepare_tokens(self, x, mask=None): 234 | B, nc, w, h = x.shape 235 | # patch linear embedding 236 | x = self.patch_embed(x) 237 | 238 | # mask image modeling 239 | if mask is not None: 240 | x = self.mask_model(x, mask) 241 | x = x.flatten(2).transpose(1, 2) 242 | 243 | # add the [CLS] token to the embed patch tokens 244 | cls_tokens = self.cls_token.expand(B, -1, -1) 245 | x = torch.cat((cls_tokens, x), dim=1) 246 | 247 | # add positional encoding to each token 248 | x = x + self.interpolate_pos_encoding(x, w, h) 249 | 250 | return self.pos_drop(x) 251 | 252 | def forward(self, x): 253 | # mim 254 | x = self.prepare_tokens(x) 255 | for f in self.blocks: 256 | x = f(x) 257 | x = self.norm(x) 258 | return x 259 | 260 | def get_last_selfattention(self, x): 261 | x = self.prepare_tokens(x) 262 | for i, blk in enumerate(self.blocks): 263 | if i < len(self.blocks) - 1: 264 | x = blk(x) 265 | else: 266 | # return attention of the last block 267 | return blk(x, return_attention=True) 268 | 269 | def get_intermediate_layers(self, x, n=1): 270 | x = self.prepare_tokens(x) 271 | # we return the output tokens from the `n` last blocks 272 | output = [] 273 | for i, blk in enumerate(self.blocks): 274 | x = blk(x) 275 | if len(self.blocks) - i <= n: 276 | output.append(self.norm(x)) 277 | return output 278 | 279 | def get_num_layers(self): 280 | return len(self.blocks) 281 | 282 | def mask_model(self, x, mask): 283 | x.permute(0, 2, 3, 1)[mask, :] = self.masked_embed.to(x.dtype) 284 | return x 285 | 286 | def ibot_vit_small_16(**kwargs): 287 | model = VisionTransformer( 288 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, 289 | qkv_bias=True, **kwargs) 290 | return model 291 | 292 | def ibot_vit_base_16(**kwargs): 293 | model = VisionTransformer( 294 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, 295 | qkv_bias=True, **kwargs) 296 | return model -------------------------------------------------------------------------------- /models/msnvit.py: 
-------------------------------------------------------------------------------- 1 | import math 2 | from functools import partial 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch import Tensor 8 | 9 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 10 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 11 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 12 | def norm_cdf(x): 13 | # Computes standard normal cumulative distribution function 14 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 15 | 16 | with torch.no_grad(): 17 | # Values are generated by using a truncated uniform distribution and 18 | # then using the inverse CDF for the normal distribution. 19 | # Get upper and lower cdf values 20 | l = norm_cdf((a - mean) / std) 21 | u = norm_cdf((b - mean) / std) 22 | 23 | # Uniformly fill tensor with values from [l, u], then translate to 24 | # [2l-1, 2u-1]. 25 | tensor.uniform_(2 * l - 1, 2 * u - 1) 26 | 27 | # Use inverse cdf transform for normal distribution to get truncated 28 | # standard normal 29 | tensor.erfinv_() 30 | 31 | # Transform to proper mean, std 32 | tensor.mul_(std * math.sqrt(2.)) 33 | tensor.add_(mean) 34 | 35 | # Clamp to ensure it's in the proper range 36 | tensor.clamp_(min=a, max=b) 37 | return tensor 38 | 39 | 40 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 41 | # type: (Tensor, float, float, float, float) -> Tensor 42 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 43 | 44 | def drop_path(x, drop_prob: float = 0., training: bool = False): 45 | if drop_prob == 0. or not training: 46 | return x 47 | keep_prob = 1 - drop_prob 48 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 49 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 50 | random_tensor.floor_() # binarize 51 | output = x.div(keep_prob) * random_tensor 52 | return output 53 | 54 | class DropPath(nn.Module): 55 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
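During training, each sample keeps its residual branch with probability 1 - drop_prob and is rescaled by 1 / (1 - drop_prob); at evaluation time the module is a no-op.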
56 | """ 57 | def __init__(self, drop_prob=None): 58 | super(DropPath, self).__init__() 59 | self.drop_prob = drop_prob 60 | 61 | def forward(self, x): 62 | return drop_path(x, self.drop_prob, self.training) 63 | 64 | 65 | class MLP(nn.Module): 66 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 67 | super().__init__() 68 | out_features = out_features or in_features 69 | hidden_features = hidden_features or in_features 70 | self.fc1 = nn.Linear(in_features, hidden_features) 71 | self.act = act_layer() 72 | self.fc2 = nn.Linear(hidden_features, out_features) 73 | self.drop = nn.Dropout(drop) 74 | 75 | def forward(self, x): 76 | x = self.fc1(x) 77 | x = self.act(x) 78 | x = self.drop(x) 79 | x = self.fc2(x) 80 | x = self.drop(x) 81 | return x 82 | 83 | 84 | class Attention(nn.Module): 85 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 86 | super().__init__() 87 | self.num_heads = num_heads 88 | head_dim = dim // num_heads 89 | self.scale = qk_scale or head_dim ** -0.5 90 | 91 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 92 | self.attn_drop = nn.Dropout(attn_drop) 93 | self.proj = nn.Linear(dim, dim) 94 | self.proj_drop = nn.Dropout(proj_drop) 95 | 96 | def forward(self, x): 97 | B, N, C = x.shape 98 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 99 | q, k, v = qkv[0], qkv[1], qkv[2] 100 | 101 | attn = (q @ k.transpose(-2, -1)) * self.scale 102 | attn = attn.softmax(dim=-1) 103 | attn = self.attn_drop(attn) 104 | 105 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 106 | x = self.proj(x) 107 | x = self.proj_drop(x) 108 | return x, attn 109 | 110 | 111 | class Block(nn.Module): 112 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 113 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 114 | super().__init__() 115 | self.norm1 = norm_layer(dim) 116 | self.attn = Attention( 117 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 118 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 119 | self.norm2 = norm_layer(dim) 120 | mlp_hidden_dim = int(dim * mlp_ratio) 121 | self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 122 | 123 | def forward(self, x, return_attention=False): 124 | y, attn = self.attn(self.norm1(x)) 125 | if return_attention: 126 | return attn 127 | x = x + self.drop_path(y) 128 | x = x + self.drop_path(self.mlp(self.norm2(x))) 129 | return x 130 | 131 | 132 | class PatchEmbed(nn.Module): 133 | """ Image to Patch Embedding 134 | """ 135 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 136 | super().__init__() 137 | num_patches = (img_size // patch_size) * (img_size // patch_size) 138 | self.img_size = img_size 139 | self.patch_size = patch_size 140 | self.num_patches = num_patches 141 | 142 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 143 | 144 | def forward(self, x): 145 | B, C, H, W = x.shape 146 | x = self.proj(x).flatten(2).transpose(1, 2) 147 | return x 148 | 149 | 150 | class ConvEmbed(nn.Module): 151 | """ 152 | 3x3 Convolution stems for ViT following ViTC models 153 | """ 154 | 155 | def __init__(self, channels, strides, img_size=224, in_chans=3, batch_norm=True): 156 | super().__init__() 157 | # Build the stems 158 | stem = [] 159 | channels = [in_chans] + channels 160 | for i in range(len(channels) - 2): 161 | stem += [nn.Conv2d(channels[i], channels[i+1], kernel_size=3, 162 | stride=strides[i], padding=1, bias=(not batch_norm))] 163 | if batch_norm: 164 | stem += [nn.BatchNorm2d(channels[i+1])] 165 | stem += [nn.ReLU(inplace=True)] 166 | stem += [nn.Conv2d(channels[-2], channels[-1], kernel_size=1, stride=strides[-1])] 167 | self.stem = nn.Sequential(*stem) 168 | 169 | # Comptute the number of patches 170 | stride_prod = int(np.prod(strides)) 171 | self.num_patches = (img_size[0] // stride_prod)**2 172 | 173 | def forward(self, x): 174 | p = self.stem(x) 175 | return p.flatten(2).transpose(1, 2) 176 | 177 | 178 | class VisionTransformer(nn.Module): 179 | """ Vision Transformer """ 180 | def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12, 181 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 182 | drop_path_rate=0., norm_layer=nn.LayerNorm, 183 | conv_stem=False, conv_stem_channels=None, conv_stem_strides=None, **kwargs): 184 | super().__init__() 185 | self.num_features = self.embed_dim = embed_dim 186 | 187 | if conv_stem: 188 | self.patch_embed = ConvEmbed(conv_stem_channels, conv_stem_strides, 189 | in_chans=in_chans, img_size=img_size) 190 | else: 191 | self.patch_embed = PatchEmbed( 192 | img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 193 | num_patches = self.patch_embed.num_patches 194 | 195 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 196 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 197 | self.pos_drop = nn.Dropout(p=drop_rate) 198 | 199 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 200 | self.blocks = nn.ModuleList([ 201 | Block( 202 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 203 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 204 | for i in range(depth)]) 205 | self.norm = norm_layer(embed_dim) 206 | 207 | # Classifier head 208 | self.fc = nn.Linear(embed_dim, num_classes) if 
num_classes > 0 else nn.Identity() 209 | self.pred = None 210 | 211 | trunc_normal_(self.pos_embed, std=.02) 212 | trunc_normal_(self.cls_token, std=.02) 213 | self.apply(self._init_weights) 214 | 215 | def _init_weights(self, m): 216 | if isinstance(m, nn.Linear): 217 | trunc_normal_(m.weight, std=.02) 218 | if isinstance(m, nn.Linear) and m.bias is not None: 219 | nn.init.constant_(m.bias, 0) 220 | elif isinstance(m, nn.LayerNorm): 221 | nn.init.constant_(m.bias, 0) 222 | nn.init.constant_(m.weight, 1.0) 223 | elif isinstance(m, nn.Conv2d): 224 | nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu') 225 | if isinstance(m, nn.Conv2d) and m.bias is not None: 226 | nn.init.constant_(m.bias, 0) 227 | 228 | def forward(self, x): 229 | B = x.shape[0] 230 | x = self.patch_embed(x) 231 | 232 | cls_tokens = self.cls_token.expand(B, -1, -1) 233 | x = torch.cat((cls_tokens, x), dim=1) 234 | pos_embed = self.interpolate_pos_encoding(x, self.pos_embed) 235 | x = x + pos_embed 236 | x = self.pos_drop(x) 237 | for f in self.blocks: 238 | x = f(x) 239 | x = self.norm(x) 240 | return x 241 | 242 | def interpolate_pos_encoding(self, x, pos_embed): 243 | npatch = x.shape[1] - 1 244 | N = pos_embed.shape[1] - 1 245 | if npatch == N: 246 | return pos_embed 247 | class_emb = pos_embed[:, 0] 248 | pos_embed = pos_embed[:, 1:] 249 | dim = x.shape[-1] 250 | pos_embed = nn.functional.interpolate( 251 | pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), 252 | scale_factor=math.sqrt(npatch / N), 253 | mode='bicubic', 254 | ) 255 | pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 256 | return torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1) 257 | 258 | def msn_vit_small_16(patch_size=16, **kwargs): 259 | model = VisionTransformer( 260 | patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, 261 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 262 | return model 263 | 264 | def msn_vit_base_16(patch_size=16, **kwargs): 265 | model = VisionTransformer( 266 | patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, 267 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 268 | return model -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ByungKwanLee/Causal-Unsupervised-Segmentation/ee0aa8478a6b6704f4db44ecc70e44a14fe5067f/modules/__init__.py -------------------------------------------------------------------------------- /modules/segment.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from modules.segment_module import HeadSegment, Decoder, ProjectionSegment 4 | 5 | class Segment_MLP(nn.Module): 6 | def __init__(self, args): 7 | super().__init__() 8 | 9 | ################################################################################## 10 | # dropout 11 | self.dropout = torch.nn.Dropout(p=0.1) 12 | ################################################################################## 13 | 14 | ################################################################################## 15 | # MLP Head 16 | self.head = HeadSegment(args.dim, args.reduced_dim) 17 | self.projection_head = ProjectionSegment(nn.Conv2d(args.reduced_dim, args.projection_dim, kernel_size=1), is_untrans=True) 18 | 
################################################################################## 19 | 20 | ################################################################################## 21 | # MLP EMA Head 22 | self.head_ema = HeadSegment(args.dim, args.reduced_dim) 23 | self.projection_head_ema = ProjectionSegment(nn.Conv2d(args.reduced_dim, args.projection_dim, kernel_size=1), is_untrans=True) 24 | self.linear = ProjectionSegment(nn.Conv2d(args.reduced_dim, args.n_classes, kernel_size=1), is_untrans=False) 25 | ################################################################################## 26 | 27 | class Segment_TR(nn.Module): 28 | def __init__(self, args): 29 | super().__init__() 30 | 31 | ################################################################################## 32 | # dropout 33 | self.dropout = nn.Dropout(p=0.1) 34 | ################################################################################## 35 | 36 | ################################################################################## 37 | # TR Decoder Head 38 | self.head = Decoder(args) 39 | self.projection_head = ProjectionSegment(nn.Conv2d(args.reduced_dim, args.projection_dim, kernel_size=1), is_untrans=True) 40 | ################################################################################## 41 | 42 | ################################################################################## 43 | # TR Decoder EMA Head 44 | self.head_ema = Decoder(args) 45 | self.projection_head_ema = ProjectionSegment(nn.Conv2d(args.reduced_dim, args.projection_dim, kernel_size=1), is_untrans=True) 46 | self.linear = ProjectionSegment(nn.Conv2d(args.reduced_dim, args.n_classes, kernel_size=1), is_untrans=False) 47 | ################################################################################## 48 | 49 | 50 | 51 | -------------------------------------------------------------------------------- /modules/segment_module.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import cat 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch import randperm as perm 7 | 8 | """ 9 | Below are Classes for CAUSE 10 | """ 11 | 12 | class HeadSegment(nn.Module): 13 | def __init__(self, dim, reduced_dim): 14 | super().__init__() 15 | self.dim = dim 16 | self.reduced_dim = reduced_dim 17 | self.f1 = nn.Conv2d(self.dim, self.reduced_dim, (1, 1)) 18 | self.f2 = nn.Sequential(nn.Conv2d(self.dim, self.dim, (1, 1)), 19 | nn.ReLU(), 20 | nn.Conv2d(self.dim, self.reduced_dim, (1, 1))) 21 | 22 | def forward(self, feat, drop=nn.Identity()): 23 | feat = transform(feat) 24 | feat = self.f1(drop(feat)) + self.f2(drop(feat)) 25 | return untransform(feat) 26 | 27 | class ProjectionSegment(nn.Module): 28 | def __init__(self, func, is_untrans): 29 | super().__init__() 30 | self.f = func 31 | self.is_untrans = is_untrans 32 | 33 | def forward(self, feat): 34 | feat = transform(feat) 35 | feat = self.f(feat) 36 | return untransform(feat) if self.is_untrans else feat 37 | 38 | class TRDecoder(nn.Module): 39 | 40 | def __init__(self, dim, reduced_dim, hidden_dim=2048, nhead=1, dropout=0.1): 41 | super().__init__() 42 | self.self_attn = nn.MultiheadAttention(dim, nhead, dropout=dropout) 43 | self.multihead_attn = nn.MultiheadAttention(dim, nhead, dropout=dropout) 44 | 45 | self.linear1 = nn.Linear(dim, hidden_dim) 46 | self.dropout = nn.Dropout(dropout) 47 | self.linear2 = nn.Linear(hidden_dim, dim) 48 | 49 | self.norm1 = nn.LayerNorm(dim) 50 | self.norm2 = nn.LayerNorm(dim) 51 
| self.norm3 = nn.LayerNorm(dim) 52 | self.dropout1 = nn.Dropout(dropout) 53 | self.dropout2 = nn.Dropout(dropout) 54 | self.dropout3 = nn.Dropout(dropout) 55 | 56 | self.f1 = nn.Conv2d(dim, reduced_dim, (1, 1)) 57 | self.f2 = nn.Sequential(nn.Conv2d(dim, dim, (1, 1)), 58 | nn.ReLU(), 59 | nn.Conv2d(dim, reduced_dim, (1, 1))) 60 | 61 | def forward(self, tgt, memory, pos, drop): 62 | q = k = tgt + pos 63 | tgt2 = self.self_attn(q, k, value=tgt)[0] 64 | tgt = tgt + self.dropout1(tgt2) 65 | tgt = self.norm1(tgt) 66 | tgt2 = self.multihead_attn(query=tgt + pos, key=memory, value=memory)[0] 67 | tgt = tgt + self.dropout2(tgt2) 68 | tgt = self.norm2(tgt) 69 | tgt2 = self.linear2(self.dropout(F.relu(self.linear1(tgt)))) 70 | tgt = tgt + self.dropout3(tgt2) 71 | tgt = memory + self.norm3(tgt) 72 | tgt = transform(tgt.transpose(0, 1)) 73 | tgt = self.f1(drop(tgt)) + self.f2(drop(tgt)) 74 | tgt = untransform(tgt) 75 | return tgt 76 | 77 | class Decoder(nn.Module): 78 | def __init__(self, args, codebook=None): 79 | super().__init__() 80 | self.codebook = codebook 81 | 82 | # TR decoder 83 | self.query_pos = nn.Parameter(torch.randn(args.num_queries, args.dim)) 84 | self.tr = TRDecoder(args.dim, args.reduced_dim) 85 | 86 | def forward(self, feat, drop=nn.Identity()): 87 | discrete_query = vqt(feat, self.codebook) 88 | tr_feat = self.tr(discrete_query.transpose(0, 1), feat.transpose(0, 1), 89 | self.query_pos.unsqueeze(1), drop) 90 | return tr_feat 91 | 92 | class Cluster(nn.Module): 93 | def __init__(self, args): 94 | super().__init__() 95 | 96 | # projection dim 97 | self.dim = args.dim 98 | self.reduced_dim = args.reduced_dim 99 | self.projection_dim = args.projection_dim 100 | 101 | # num codebook 102 | self.num_codebook = args.num_codebook 103 | 104 | # Codebook 105 | self.codebook = nn.Parameter(torch.empty(self.num_codebook, self.dim)) 106 | reset(self.codebook, args.num_codebook) 107 | 108 | # cluster centroid 109 | self.cluster_probe = torch.nn.Parameter(torch.randn(args.n_classes, self.reduced_dim)) 110 | reset(self.cluster_probe, args.n_classes) 111 | 112 | def bank_init(self): 113 | self.prime_bank = {} 114 | start_of_tensor = torch.empty([0, self.projection_dim]).cuda() 115 | for i in range(self.num_codebook): 116 | self.prime_bank[i] = start_of_tensor 117 | 118 | def bank_update(self, feat, proj_feat_ema, max_num=100): 119 | # load all and bank collection 120 | quant_ind = quantize_index(feat, self.codebook) 121 | for i in quant_ind.unique(): 122 | # key bank 123 | key = proj_feat_ema[torch.where(quant_ind == i)] 124 | 125 | # 50% random cutting 126 | key = key[perm(len(key))][:int(len(key)*0.5)] 127 | 128 | # merging 129 | self.prime_bank[i.item()] = cat([self.prime_bank[i.item()], key], dim=0) 130 | 131 | # bank length 132 | length = len(self.prime_bank[i.item()]) 133 | 134 | # if maximum number is over, slice by the order of the older 135 | if length >= max_num: 136 | self.prime_bank[i.item()] = self.prime_bank[i.item()][length-max_num:] 137 | 138 | def bank_compute(self): 139 | bank_vq_feat = torch.empty([0, self.dim]).cuda() 140 | bank_proj_feat_ema = torch.empty([0, self.projection_dim]).cuda() 141 | for key in self.prime_bank.keys(): 142 | num = self.prime_bank[key].shape[0] 143 | if num == 0: continue 144 | bank_vq_feat = cat([bank_vq_feat, self.codebook[key].unsqueeze(0).repeat(num, 1)], dim=0) 145 | bank_proj_feat_ema = cat([bank_proj_feat_ema, self.prime_bank[key]], dim=0) 146 | 147 | # normalized feature and flat its feature for computing correspondence 148 | 
self.flat_norm_bank_vq_feat = F.normalize(bank_vq_feat, dim=1) 149 | self.flat_norm_bank_proj_feat_ema = F.normalize(bank_proj_feat_ema, dim=1) 150 | 151 | 152 | def contrastive_ema_with_codebook_bank(self, feat, proj_feat, proj_feat_ema, temp=0.07, pos_thresh=0.3, neg_thresh=0.1): 153 | """ 154 | get all anchors and positive samples with same codebook index 155 | """ 156 | 157 | # quantized feature to positive sample and negative sample 158 | vq_feat = vqt(feat, self.codebook) 159 | norm_vq_feat = F.normalize(vq_feat, dim=2) 160 | flat_norm_vq_feat = flatten(norm_vq_feat) 161 | 162 | # normalized feature and flat its feature for computing correspondence 163 | norm_proj_feat = F.normalize(proj_feat, dim=2) 164 | 165 | # normalized feature and flat its feature for computing correspondence 166 | norm_proj_feat_ema = F.normalize(proj_feat_ema, dim=2) 167 | flat_norm_proj_feat_ema = flatten(norm_proj_feat_ema) 168 | 169 | # selecting anchors by one-batch for all correspondence to all-batches 170 | # positive/negative 171 | loss_NCE_list = [] 172 | for batch_ind in range(proj_feat.shape[0]): 173 | 174 | # anchor selection 175 | anchor_vq_feat = norm_vq_feat[batch_ind] 176 | anchor_proj_feat = norm_proj_feat[batch_ind] 177 | 178 | # cosine similarity of student-teacher 179 | cs_st = anchor_proj_feat @ flat_norm_proj_feat_ema.T 180 | 181 | # Codebook distance 182 | codebook_distance = anchor_vq_feat @ flat_norm_vq_feat.T 183 | bank_codebook_distance = anchor_vq_feat @ self.flat_norm_bank_vq_feat.T 184 | 185 | # [1] student-teacher (in-batch, local) 186 | pos_mask = (codebook_distance > pos_thresh) 187 | neg_mask = (codebook_distance < neg_thresh) 188 | 189 | auto_mask = torch.ones_like(pos_mask) 190 | auto_mask[:, batch_ind * pos_mask.shape[0]:(batch_ind + 1) * pos_mask.shape[0]].fill_diagonal_(0) 191 | pos_mask *= auto_mask 192 | 193 | cs_teacher = cs_st / temp 194 | shifted_cs_teacher = cs_teacher - cs_teacher.max(dim=1, keepdim=True)[0].detach() 195 | shifted_cs_teacher_with_only_neg = shifted_cs_teacher.exp() * (pos_mask + neg_mask) 196 | pos_neg_loss_matrix_teacher = -shifted_cs_teacher + torch.log(shifted_cs_teacher_with_only_neg.sum(dim=1, keepdim=True)) 197 | loss_NCE_list.append(pos_neg_loss_matrix_teacher[torch.where(pos_mask!=0)].mean()) 198 | 199 | # [2] student-teacher bank (out-batch, global) 200 | if self.flat_norm_bank_proj_feat_ema.shape[0] != 0: 201 | 202 | # cosine similarity of student-teacher bank 203 | cs_st_bank = anchor_proj_feat @ self.flat_norm_bank_proj_feat_ema.T 204 | 205 | bank_pos_mask = (bank_codebook_distance > pos_thresh) 206 | bank_neg_mask = (bank_codebook_distance < neg_thresh) 207 | 208 | cs_teacher_bank = cs_st_bank / temp 209 | shifted_cs_teacher_bank = cs_teacher_bank - cs_teacher_bank.max(dim=1, keepdim=True)[0].detach() 210 | shifted_cs_teacher_bank_with_only_neg = shifted_cs_teacher_bank.exp() * (bank_pos_mask + bank_neg_mask) 211 | pos_neg_loss_matrix_teacher_bank = -shifted_cs_teacher_bank + torch.log(shifted_cs_teacher_bank_with_only_neg.sum(dim=1, keepdim=True)) 212 | 213 | # loss append 214 | loss_NCE_list.append(pos_neg_loss_matrix_teacher_bank[torch.where(bank_pos_mask!=0)].mean()) 215 | 216 | # front 217 | loss_front = sum(loss_NCE_list) / float(len(loss_NCE_list)) 218 | return loss_front 219 | 220 | def forward_centroid(self, x, inference=False, alpha=3, crf=False): 221 | normed_features = F.normalize(transform(x.detach()), dim=1) 222 | normed_clusters = F.normalize(self.cluster_probe, dim=1) 223 | inner_products = torch.einsum("bchw,nc->bnhw", 
normed_features, normed_clusters) 224 | 225 | if inference: 226 | return torch.argmax(inner_products, dim=1) 227 | 228 | if crf: 229 | return torch.log_softmax(inner_products * alpha, dim=1) 230 | 231 | cluster_probs = F.one_hot(torch.argmax(inner_products, dim=1), self.cluster_probe.shape[0]) \ 232 | .permute(0, 3, 1, 2).to(torch.float32) 233 | 234 | cluster_loss = -(cluster_probs * inner_products).sum(1).mean() 235 | return cluster_loss, cluster_probs.argmax(1) 236 | 237 | 238 | 239 | """ 240 | Below are functions 241 | """ 242 | 243 | 244 | def transform(x): 245 | """ 246 | B, P, D => B, D, root(P), root(P) 247 | 248 | Ex) 128, 400, 768 => 128, 768, 20, 20 249 | """ 250 | B, P, D = x.shape 251 | return x.permute(0, 2, 1).view(B, D, int(math.sqrt(P)), int(math.sqrt(P))) 252 | 253 | def untransform(x): 254 | """ 255 | B, D, P, P => B, P*P, D, 256 | 257 | Ex) 128, 768, 20, 20 => 128, 400, 768 258 | """ 259 | B, D, P, P = x.shape 260 | return x.view(B, D, -1).permute(0, 2, 1) 261 | 262 | def flatten(x): 263 | """ 264 | B, P, D => B*P, D 265 | 266 | Ex) 16, 400, 768 => 6400, 768 267 | """ 268 | B, P, D = x.shape 269 | return x.contiguous().view(B*P, D) 270 | 271 | def unflatten(x, batch_size=16): 272 | """ 273 | B*P, D => B, P, D 274 | 275 | Ex) 6400, 768 => 16, 400, 768 276 | """ 277 | P, D = x.shape 278 | return x.contiguous().view(batch_size, P//batch_size, D) 279 | 280 | def stochastic_sampling(x, order=None, k=4): 281 | """ 282 | pooling 283 | """ 284 | x = transform(x) 285 | x_patch = x.unfold(2, k, k).unfold(3, k, k) 286 | x_patch = x_patch.permute(0, 2, 3, 4, 5, 1) 287 | x_patch = x_patch.reshape(-1, x_patch.shape[3:5].numel(), x_patch.shape[5]) 288 | 289 | if order==None: order = torch.randint(k ** 2, size=(x_patch.shape[0],)) 290 | 291 | x_patch = x_patch[range(x_patch.shape[0]), order].reshape(x.shape[0], x.shape[2]//k, x.shape[3]//k, -1) 292 | x_patch = x_patch.permute(0, 3, 1, 2) 293 | x = untransform(x_patch) 294 | return x, order 295 | 296 | 297 | def quantize_index(z, c, mode='cos'): 298 | if mode == 'cos': 299 | # computing distance 300 | dist = cos_distance_matrix(z, c) 301 | elif mode == 'l2': 302 | dist = l2_distance_matrix(z, c) 303 | 304 | # quantize 305 | return dist.argmax(dim=2) 306 | 307 | def cos_distance_matrix(z, c): 308 | # flatten z 309 | z_flattened = z.contiguous().view(-1, z.shape[-1]) 310 | norm_z = F.normalize(z_flattened, dim=1) 311 | norm_embed = F.normalize(c, dim=1) 312 | return torch.einsum("ab,cb->ac", norm_z, norm_embed).view(*z.shape[:-1], -1) 313 | 314 | def l2_distance_matrix(z, c): 315 | # flatten z 316 | z_flattened = z.contiguous().view(-1, z.shape[-1]) 317 | dist = (z_flattened.square().sum(dim=1, keepdims=True) + c.square().sum(dim=1).unsqueeze(0) 318 | -2 * z_flattened @ c.transpose(0, 1)) / c.shape[1] 319 | return torch.exp(-dist/z.shape[2]/2).view(*z.shape[:-1], -1) 320 | 321 | def codebook_index(z, c): 322 | # computing distance 323 | dist = cos_distance_matrix(z, c) 324 | 325 | # codebook index 326 | return dist.argmax(dim=2) 327 | 328 | def vqt(z, c): 329 | """ 330 | Return Vector-Quantized Tensor 331 | """ 332 | codebook_ind = codebook_index(z, c) 333 | return c[codebook_ind].view(*z.shape[:-1], c.shape[1]) 334 | 335 | 336 | def auto_cs(x): 337 | a = F.normalize(x, dim=1) 338 | return a @ a.T 339 | 340 | def reset(x, n_c): x.data.uniform_(-1.0 / n_c, 1.0 / n_c) 341 | 342 | def ema_init(x, x_ema): 343 | for param, param_ema in zip(x.parameters(), x_ema.parameters()): param_ema.data = param.data; param_ema.requires_grad = False 344 | 
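# ------------------------------------------------------------------------------
# NOTE (illustrative sketch only, not part of the original file): a minimal,
# self-contained check of the helpers defined above, assuming a square ViT patch
# grid (B=2 images, P=400 patches, D=768 dims) and a K=2048-entry codebook.
# transform/untransform round-trip between (B, P, D) and (B, D, 20, 20), and
# vqt replaces every patch feature by its nearest codebook vector (cosine).
def _example_shapes_and_vq():
    feat = torch.randn(2, 400, 768)               # (B, P, D) backbone patch features
    grid = transform(feat)                        # (B, D, 20, 20)
    assert untransform(grid).shape == feat.shape  # back to (B, P, D)

    codebook = torch.randn(2048, 768)             # (K, D) mediator codebook
    sim = cos_distance_matrix(feat, codebook)     # (B, P, K) cosine similarities
    index = sim.argmax(dim=2)                     # nearest codebook id per patch
    vq_feat = vqt(feat, codebook)                 # (B, P, D) quantized features
    assert torch.equal(vq_feat, codebook[index])
# ------------------------------------------------------------------------------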
345 | def ema_update(x, x_ema, lamb=0.99): 346 | for student_params, teacher_params in zip(x.parameters(), x_ema.parameters()): 347 | teacher_params.data = lamb * teacher_params.data + (1-lamb) * student_params.data 348 | 349 | 350 | def img_to_patch_for_affinity(img, patch_size): 351 | img_patch = img.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size) 352 | img_patch = img_patch.permute(0, 2, 3, 1, 4, 5) 353 | img_patch = img_patch.reshape(*img_patch.shape[:3], -1) 354 | img_patch = img_patch.reshape(img_patch.shape[0], -1, img_patch.shape[3]) 355 | return img_patch 356 | 357 | def get_modularity_matrix_and_edge(x, mode='cos'): 358 | """ 359 | getting W=(A-ddT/2m) and getting all edges (e) 360 | """ 361 | if mode=='cos': 362 | norm = F.normalize(x, dim=2) 363 | A = (norm @ norm.transpose(2, 1)).clamp(0) 364 | elif mode=='l2': 365 | A = compute_self_distance_batch(x) 366 | 367 | A = A - A * torch.eye(A.shape[1]).cuda() 368 | d = A.sum(dim=2, keepdims=True) 369 | e = A.sum(dim=(1, 2), keepdims=True) 370 | W = A - (d / e) @ (d.transpose(2, 1) / e) * e 371 | return W, e 372 | 373 | def cluster_assignment_matrix(z, c): 374 | norm_z = F.normalize(z, dim=2) 375 | norm_c = F.normalize(c, dim=1) 376 | return (norm_z @ norm_c.unsqueeze(0).transpose(2, 1)).clamp(0) 377 | 378 | def compute_modularity_based_codebook(c, x, temp=0.1, grid=False): 379 | 380 | # detach 381 | x = x.detach() 382 | 383 | # pooling for reducing GPU memory allocation 384 | if grid: x, _ = stochastic_sampling(x) 385 | 386 | # modularity matrix and its edge matrix 387 | W, e = get_modularity_matrix_and_edge(x) 388 | 389 | 390 | # cluster assignment matrix 391 | C = cluster_assignment_matrix(x, c) 392 | 393 | # tanh with temperature 394 | D = C.transpose(2, 1) 395 | E = torch.tanh(D.unsqueeze(3) @ D.unsqueeze(2) / temp) 396 | delta, _ = E.max(dim=1) 397 | Q = (W / e) @ delta 398 | 399 | # trace 400 | diag = Q.diagonal(offset=0, dim1=-2, dim2=-1) 401 | trace = diag.sum(dim=-1) 402 | 403 | return -trace.mean() 404 | 405 | def compute_self_distance_batch(x): 406 | dist = x.square().sum(dim=2, keepdims=True) + x.square().sum(dim=2).unsqueeze(1) -2 * (x @ x.transpose(2, 1)) 407 | return torch.exp(-dist/x.shape[2]) 408 | 409 | def img_to_patch(img, patch_size=16): 410 | img_patch = img.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size) 411 | img_patch = img_patch.permute(0, 2, 3, 1, 4, 5) 412 | return img_patch.reshape(-1, *img_patch.shape[3:]) 413 | 414 | def patch_to_img(patch, batch_size=16, patch_size=16, img_size=320): 415 | patch_ = patch.reshape(batch_size, img_size//patch_size, img_size//patch_size, 3, patch_size, patch_size) 416 | patch_ = patch_.permute(0, 3, 1, 4, 2, 5) 417 | return patch_.reshape(batch_size, 3, img_size, img_size) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib 2 | timm 3 | tqdm 4 | scipy 5 | numpy 6 | tensorboardX 7 | wget 8 | scikit-image 9 | scikit-learn 10 | xformers 11 | -------------------------------------------------------------------------------- /run: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ###################################### 3 | # [OPTION] DATASET 4 | 5 | # cocostuff27 6 | dataset="cocostuff27" 7 | 8 | # cityscapes 9 | # dataset="cityscapes" 10 | 11 | # pascalvoc 12 | # dataset="pascalvoc" 13 | 14 | # coco-81 15 | # dataset="coco81" 16 | 17 | # coco-171 
18 | # dataset="coco171" 19 | ###################################### 20 | 21 | ###################################### 22 | # [OPTION] STRUCTURE 23 | # structure="MLP" 24 | structure="TR" 25 | ###################################### 26 | 27 | ###################################### 28 | # [OPTION] Self-Supervised Method 29 | 30 | # DINO 31 | # ckpt="checkpoint/dino_vit_small_8.pth" 32 | # ckpt="checkpoint/dino_vit_small_16.pth" 33 | ckpt="checkpoint/dino_vit_base_8.pth" 34 | # ckpt="checkpoint/dino_vit_base_16.pth" 35 | 36 | # DINOv2 37 | # ckpt="checkpoint/dinov2_vit_base_14.pth" 38 | 39 | # iBOT 40 | # ckpt="checkpoint/ibot_vit_base_16.pth" 41 | 42 | # MSN 43 | # ckpt="checkpoint/msn_vit_small_16.pth" 44 | 45 | # MAE 46 | # ckpt="checkpoint/mae_vit_base_16.pth" 47 | ###################################### 48 | 49 | ###################################### 50 | # GPU and PORT 51 | if [ "$structure" = "MLP" ] 52 | then 53 | train_gpu="0,1,2,3" 54 | elif [ "$structure" = "TR" ] 55 | then 56 | train_gpu="4,5,6,7" 57 | fi 58 | 59 | # Non-Changeable Variable 60 | test_gpu="${train_gpu:0}" 61 | port=$(($RANDOM%800+1200)) 62 | ###################################### 63 | 64 | ###################################### 65 | # [STEP1] MEDIATOR 66 | python train_mediator.py --dataset $dataset --ckpt $ckpt --gpu $train_gpu --port $port 67 | ###################################### 68 | 69 | ###################################### 70 | # [STEP2] CAUSE 71 | if [ "$structure" = "MLP" ] 72 | then 73 | python train_front_door_mlp.py --dataset $dataset --ckpt $ckpt --gpu $train_gpu --port $port 74 | python fine_tuning_mlp.py --dataset $dataset --ckpt $ckpt --gpu $train_gpu --port $port 75 | elif [ "$structure" = "TR" ] 76 | then 77 | python train_front_door_tr.py --dataset $dataset --ckpt $ckpt --gpu $train_gpu --port $port 78 | python fine_tuning_tr.py --dataset $dataset --ckpt $ckpt --gpu $train_gpu --port $port 79 | fi 80 | ###################################### 81 | 82 | ###################################### 83 | # TEST 84 | if [ "$structure" = "MLP" ] 85 | then 86 | python test_mlp.py --dataset $dataset --ckpt $ckpt --gpu $test_gpu 87 | elif [ "$structure" = "TR" ] 88 | then 89 | python test_tr.py --dataset $dataset --ckpt $ckpt --gpu $test_gpu 90 | fi 91 | ###################################### -------------------------------------------------------------------------------- /test_mlp.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from tqdm import tqdm 4 | from utils.utils import * 5 | from loader.dataloader import dataloader 6 | from torch.cuda.amp import autocast 7 | from modules.segment_module import transform, untransform 8 | from loader.netloader import network_loader, segment_mlp_loader, cluster_mlp_loader 9 | 10 | 11 | def test(args, net, segment, cluster, nice, test_loader, cmap): 12 | segment.eval() 13 | 14 | prog_bar = tqdm(enumerate(test_loader), total=len(test_loader), leave=True) 15 | with Pool(40) as pool: 16 | for _, batch in prog_bar: 17 | # image and label and self supervised feature 18 | ind = batch["ind"].cuda() 19 | img = batch["img"].cuda() 20 | label = batch["label"].cuda() 21 | 22 | with autocast(): 23 | # intermediate feature 24 | feat = net(img)[:, 1:, :] 25 | feat_flip = net(img.flip(dims=[3]))[:, 1:, :] 26 | seg_feat = transform(segment.head_ema(feat)) 27 | seg_feat_flip = transform(segment.head_ema(feat_flip)) 28 | seg_feat = untransform((seg_feat + seg_feat_flip.flip(dims=[3])) / 2) 29 | 30 | # interp feat 31 | 
interp_seg_feat = F.interpolate(transform(seg_feat), label.shape[-2:], mode='bilinear', align_corners=False) 32 | 33 | # cluster preds 34 | cluster_preds = cluster.forward_centroid(untransform(interp_seg_feat), crf=True) 35 | 36 | # crf 37 | crf_preds = do_crf(pool, img, cluster_preds).argmax(1).cuda() 38 | 39 | # nice evaluation 40 | _, desc_nice = nice.eval(crf_preds, label) 41 | 42 | # hungarian 43 | hungarian_preds = nice.do_hungarian(crf_preds) 44 | 45 | # save images 46 | save_all(args, ind, img, label, cluster_preds.argmax(dim=1), crf_preds, hungarian_preds, cmap) 47 | 48 | # real-time print 49 | desc = f'{desc_nice}' 50 | prog_bar.set_description(desc, refresh=True) 51 | 52 | # evaludation metric reset 53 | nice.reset() 54 | 55 | 56 | def test_without_crf(args, net, segment, cluster, nice, test_loader): 57 | segment.eval() 58 | 59 | total_acc = 0 60 | prog_bar = tqdm(enumerate(test_loader), total=len(test_loader), leave=True) 61 | for idx, batch in prog_bar: 62 | # image and label and self supervised feature 63 | ind = batch["ind"].cuda() 64 | img = batch["img"].cuda() 65 | label = batch["label"].cuda() 66 | 67 | # intermediate feature 68 | with autocast(): 69 | 70 | feat = net(img)[:, 1:, :] 71 | seg_feat_ema = segment.head_ema(feat) 72 | 73 | # linear probe loss 74 | linear_logits = segment.linear(seg_feat_ema) 75 | linear_logits = F.interpolate(linear_logits, label.shape[-2:], mode='bilinear', align_corners=False) 76 | flat_label = label.reshape(-1) 77 | flat_label_mask = (flat_label >= 0) & (flat_label < args.n_classes) 78 | 79 | # interp feat 80 | interp_seg_feat = F.interpolate(transform(seg_feat_ema), label.shape[-2:], mode='bilinear', align_corners=False) 81 | 82 | # cluster 83 | cluster_preds = cluster.forward_centroid(untransform(interp_seg_feat), inference=True) 84 | 85 | # nice evaluation 86 | _, desc_nice = nice.eval(cluster_preds, label) 87 | 88 | # linear probe acc check 89 | pred_label = linear_logits.argmax(dim=1) 90 | flat_pred_label = pred_label.reshape(-1) 91 | acc = (flat_pred_label[flat_label_mask] == flat_label[flat_label_mask]).sum() / flat_label[ 92 | flat_label_mask].numel() 93 | total_acc += acc.item() 94 | 95 | # real-time print 96 | desc = f'[TEST] Acc (Linear): {100. 
* total_acc / (idx + 1):.1f}% | {desc_nice}' 97 | prog_bar.set_description(desc, refresh=True) 98 | 99 | # evaludation metric reset 100 | nice.reset() 101 | 102 | 103 | def test_linear_without_crf(args, net, segment, nice, test_loader): 104 | segment.eval() 105 | 106 | prog_bar = tqdm(enumerate(test_loader), total=len(test_loader), leave=True) 107 | with Pool(40) as pool: 108 | for _, batch in prog_bar: 109 | # image and label and self supervised feature 110 | ind = batch["ind"].cuda() 111 | img = batch["img"].cuda() 112 | label = batch["label"].cuda() 113 | 114 | with autocast(): 115 | # intermediate feature 116 | feat = net(img)[:, 1:, :] 117 | feat_flip = net(img.flip(dims=[3]))[:, 1:, :] 118 | seg_feat = segment.transform(segment.head_ema(feat)) 119 | seg_feat_flip = segment.transform(segment.head_ema(feat_flip)) 120 | seg_feat = segment.untransform((seg_feat + seg_feat_flip.flip(dims=[3])) / 2) 121 | 122 | # interp feat 123 | interp_seg_feat = F.interpolate(segment.transform(seg_feat), label.shape[-2:], mode='bilinear', align_corners=False) 124 | 125 | # linear probe interp feat 126 | linear_logits = segment.linear(segment.untransform(interp_seg_feat)) 127 | 128 | # cluster preds 129 | cluster_preds = linear_logits.argmax(dim=1) 130 | 131 | # nice evaluation 132 | _, desc_nice = nice.eval(cluster_preds, label) 133 | 134 | # real-time print 135 | desc = f'{desc_nice}' 136 | prog_bar.set_description(desc, refresh=True) 137 | 138 | # evaludation metric reset 139 | nice.reset() 140 | 141 | 142 | 143 | def test_linear(args, net, segment, nice, test_loader): 144 | segment.eval() 145 | 146 | prog_bar = tqdm(enumerate(test_loader), total=len(test_loader), leave=True) 147 | with Pool(40) as pool: 148 | for _, batch in prog_bar: 149 | # image and label and self supervised feature 150 | ind = batch["ind"].cuda() 151 | img = batch["img"].cuda() 152 | label = batch["label"].cuda() 153 | 154 | with autocast(): 155 | # intermediate feature 156 | feat = net(img)[:, 1:, :] 157 | feat_flip = net(img.flip(dims=[3]))[:, 1:, :] 158 | seg_feat = segment.transform(segment.head_ema(feat)) 159 | seg_feat_flip = segment.transform(segment.head_ema(feat_flip)) 160 | seg_feat = segment.untransform((seg_feat + seg_feat_flip.flip(dims=[3])) / 2) 161 | 162 | 163 | # interp feat 164 | interp_seg_feat = F.interpolate(segment.transform(seg_feat), label.shape[-2:], mode='bilinear', align_corners=False) 165 | 166 | # linear probe interp feat 167 | linear_logits = segment.linear(segment.untransform(interp_seg_feat)) 168 | 169 | # cluster preds 170 | cluster_preds = torch.log_softmax(linear_logits * 10, dim=1).argmax(dim=1) 171 | 172 | # crf 173 | onehot = F.one_hot(cluster_preds.to(torch.int64), args.n_classes).to(torch.float32) 174 | crf_preds = do_crf(pool, img, onehot.permute(0, 3, 1, 2)).argmax(1).cuda() 175 | 176 | # nice evaluation 177 | _, desc_nice = nice.eval(crf_preds, label) 178 | 179 | # real-time print 180 | desc = f'{desc_nice}' 181 | prog_bar.set_description(desc, refresh=True) 182 | 183 | # evaludation metric reset 184 | nice.reset() 185 | 186 | 187 | def main(rank, args): 188 | 189 | # setting gpu id of this process 190 | torch.cuda.set_device(rank) 191 | 192 | # print argparse 193 | print_argparse(args, rank=0) 194 | 195 | # dataset loader 196 | _, test_loader, _ = dataloader(args, False) 197 | 198 | # network loader 199 | net = network_loader(args, rank) 200 | segment = segment_mlp_loader(args, rank) 201 | cluster = cluster_mlp_loader(args, rank) 202 | 203 | # evaluation 204 | nice = 
NiceTool(args.n_classes) 205 | 206 | # color map 207 | cmap = create_cityscapes_colormap() if args.dataset == 'cityscapes' else create_pascal_label_colormap() 208 | 209 | # param size 210 | print(f'# of Parameters: {num_param(segment)/10**6:.2f}(M)') 211 | 212 | 213 | # post-processing with crf and hungarian matching 214 | test_without_crf( 215 | args, 216 | net, 217 | segment, 218 | cluster, 219 | nice, 220 | test_loader) 221 | 222 | # post-processing with crf and hungarian matching 223 | test( 224 | args, 225 | net, 226 | segment, 227 | cluster, 228 | nice, 229 | test_loader, 230 | cmap) 231 | 232 | # post-processing with crf and hungarian matching 233 | # test_linear_without_crf( 234 | # args, 235 | # net, 236 | # segment, 237 | # nice, 238 | # test_loader) 239 | 240 | # test_linear( 241 | # args, 242 | # net, 243 | # segment, 244 | # nice, 245 | # test_loader) 246 | 247 | 248 | if __name__ == "__main__": 249 | 250 | # fetch args 251 | parser = argparse.ArgumentParser() 252 | # model parameter 253 | parser.add_argument('--NAME-TAG', default='CAUSE-MLP', type=str) 254 | parser.add_argument('--data_dir', default='/mnt/hard2/lbk-iccv/datasets', type=str) 255 | parser.add_argument('--dataset', default='coco171', type=str) 256 | parser.add_argument('--port', default='12355', type=str) 257 | parser.add_argument('--load_segment', default=True, type=str2bool) 258 | parser.add_argument('--load_cluster', default=True, type=str2bool) 259 | parser.add_argument('--ckpt', default='checkpoint/dino_vit_small_8.pth', type=str) 260 | parser.add_argument('--distributed', default=False, type=str2bool) 261 | parser.add_argument('--train_resolution', default=224, type=int) 262 | parser.add_argument('--test_resolution', default=320, type=int) 263 | parser.add_argument('--batch_size', default=32, type=int) 264 | parser.add_argument('--num_workers', default=int(os.cpu_count() / 8), type=int) 265 | parser.add_argument('--gpu', default='4', type=str) 266 | parser.add_argument('--num_codebook', default=2048, type=int) 267 | 268 | # model parameter 269 | parser.add_argument('--reduced_dim', default=90, type=int) 270 | parser.add_argument('--projection_dim', default=2048, type=int) 271 | 272 | args = parser.parse_args() 273 | 274 | if 'dinov2' in args.ckpt: args.test_resolution=322 275 | if 'small' in args.ckpt: 276 | args.dim = 384 277 | elif 'base' in args.ckpt: 278 | args.dim = 768 279 | 280 | # the number of gpus for multi-process 281 | gpu_list = list(map(int, args.gpu.split(','))) 282 | ngpus_per_node = len(gpu_list) 283 | 284 | # first gpu index is activated once there are several gpu in args.gpu 285 | main(rank=gpu_list[0], args=args) 286 | -------------------------------------------------------------------------------- /test_tr.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from tqdm import tqdm 4 | from utils.utils import * 5 | from modules.segment_module import transform, untransform 6 | from loader.dataloader import dataloader 7 | from torch.cuda.amp import autocast 8 | from loader.netloader import network_loader, segment_tr_loader, cluster_tr_loader 9 | 10 | 11 | def test(args, net, segment, cluster, nice, test_loader, cmap): 12 | segment.eval() 13 | 14 | prog_bar = tqdm(enumerate(test_loader), total=len(test_loader), leave=True) 15 | with Pool(40) as pool: 16 | for _, batch in prog_bar: 17 | # image and label and self supervised feature 18 | ind = batch["ind"].cuda() 19 | img = batch["img"].cuda() 20 | label = batch["label"].cuda() 21 | 
22 | with autocast(): 23 | # intermediate feature 24 | feat = net(img)[:, 1:, :] 25 | feat_flip = net(img.flip(dims=[3]))[:, 1:, :] 26 | seg_feat = transform(segment.head_ema(feat)) 27 | seg_feat_flip = transform(segment.head_ema(feat_flip)) 28 | seg_feat = untransform((seg_feat + seg_feat_flip.flip(dims=[3])) / 2) 29 | 30 | # interp feat 31 | interp_seg_feat = F.interpolate(transform(seg_feat), label.shape[-2:], mode='bilinear', align_corners=False) 32 | 33 | # cluster preds 34 | cluster_preds = cluster.forward_centroid(untransform(interp_seg_feat), crf=True) 35 | 36 | # crf 37 | crf_preds = do_crf(pool, img, cluster_preds).argmax(1).cuda() 38 | 39 | # nice evaluation 40 | _, desc_nice = nice.eval(crf_preds, label) 41 | 42 | # hungarian 43 | hungarian_preds = nice.do_hungarian(crf_preds) 44 | 45 | # save images 46 | save_all(args, ind, img, label, cluster_preds.argmax(dim=1), crf_preds, hungarian_preds, cmap, is_tr=True) 47 | 48 | # real-time print 49 | desc = f'{desc_nice}' 50 | prog_bar.set_description(desc, refresh=True) 51 | 52 | # evaludation metric reset 53 | nice.reset() 54 | 55 | 56 | 57 | def test_without_crf(args, net, segment, cluster, nice, test_loader): 58 | segment.eval() 59 | 60 | total_acc = 0 61 | prog_bar = tqdm(enumerate(test_loader), total=len(test_loader), leave=True) 62 | for idx, batch in prog_bar: 63 | # image and label and self supervised feature 64 | ind = batch["ind"].cuda() 65 | img = batch["img"].cuda() 66 | label = batch["label"].cuda() 67 | 68 | cmap = create_pascal_label_colormap() 69 | a = invTrans(img)[0].permute(1,2,0) 70 | b = cmap[label[0].cpu()] 71 | 72 | # intermediate feature 73 | with autocast(): 74 | 75 | feat = net(img)[:, 1:, :] 76 | seg_feat_ema = segment.head_ema(feat) 77 | 78 | # linear probe loss 79 | linear_logits = segment.linear(seg_feat_ema) 80 | linear_logits = F.interpolate(linear_logits, label.shape[-2:], mode='bilinear', align_corners=False) 81 | flat_label = label.reshape(-1) 82 | flat_label_mask = (flat_label >= 0) & (flat_label < args.n_classes) 83 | 84 | # interp feat 85 | interp_seg_feat = F.interpolate(transform(seg_feat_ema), label.shape[-2:], mode='bilinear', align_corners=False) 86 | 87 | # cluster 88 | cluster_preds = cluster.forward_centroid(untransform(interp_seg_feat), inference=True) 89 | 90 | # nice evaluation 91 | _, desc_nice = nice.eval(cluster_preds, label) 92 | 93 | # linear probe acc check 94 | pred_label = linear_logits.argmax(dim=1) 95 | flat_pred_label = pred_label.reshape(-1) 96 | acc = (flat_pred_label[flat_label_mask] == flat_label[flat_label_mask]).sum() / flat_label[ 97 | flat_label_mask].numel() 98 | total_acc += acc.item() 99 | 100 | # real-time print 101 | desc = f'[TEST] Acc (Linear): {100. 
* total_acc / (idx + 1):.1f}% | {desc_nice}' 102 | prog_bar.set_description(desc, refresh=True) 103 | 104 | # evaludation metric reset 105 | nice.reset() 106 | 107 | 108 | def test_linear_without_crf(args, net, segment, nice, test_loader): 109 | segment.eval() 110 | 111 | prog_bar = tqdm(enumerate(test_loader), total=len(test_loader), leave=True) 112 | with Pool(40) as pool: 113 | for _, batch in prog_bar: 114 | # image and label and self supervised feature 115 | ind = batch["ind"].cuda() 116 | img = batch["img"].cuda() 117 | label = batch["label"].cuda() 118 | 119 | with autocast(): 120 | # intermediate feature 121 | feat = net(img)[:, 1:, :] 122 | feat_flip = net(img.flip(dims=[3]))[:, 1:, :] 123 | seg_feat = transform(segment.head_ema(feat)) 124 | seg_feat_flip = transform(segment.head_ema(feat_flip)) 125 | seg_feat = untransform((seg_feat + seg_feat_flip.flip(dims=[3])) / 2) 126 | 127 | # interp feat 128 | interp_seg_feat = F.interpolate(transform(seg_feat), label.shape[-2:], mode='bilinear', align_corners=False) 129 | 130 | # linear probe interp feat 131 | linear_logits = segment.linear(untransform(interp_seg_feat)) 132 | 133 | # cluster preds 134 | cluster_preds = linear_logits.argmax(dim=1) 135 | 136 | # nice evaluation 137 | _, desc_nice = nice.eval(cluster_preds, label) 138 | 139 | # real-time print 140 | desc = f'{desc_nice}' 141 | prog_bar.set_description(desc, refresh=True) 142 | 143 | # evaludation metric reset 144 | nice.reset() 145 | 146 | 147 | 148 | def test_linear(args, net, segment, nice, test_loader): 149 | segment.eval() 150 | 151 | prog_bar = tqdm(enumerate(test_loader), total=len(test_loader), leave=True) 152 | with Pool(40) as pool: 153 | for _, batch in prog_bar: 154 | # image and label and self supervised feature 155 | ind = batch["ind"].cuda() 156 | img = batch["img"].cuda() 157 | label = batch["label"].cuda() 158 | 159 | with autocast(): 160 | # intermediate feature 161 | feat = net(img)[:, 1:, :] 162 | feat_flip = net(img.flip(dims=[3]))[:, 1:, :] 163 | seg_feat = transform(segment.head_ema(feat)) 164 | seg_feat_flip = transform(segment.head_ema(feat_flip)) 165 | seg_feat = untransform((seg_feat + seg_feat_flip.flip(dims=[3])) / 2) 166 | 167 | # interp feat 168 | interp_seg_feat = F.interpolate(transform(seg_feat), label.shape[-2:], mode='bilinear', align_corners=False) 169 | 170 | # linear probe interp feat 171 | linear_logits = segment.linear(untransform(interp_seg_feat)) 172 | 173 | # cluster preds 174 | cluster_preds = torch.log_softmax(linear_logits, dim=1) 175 | 176 | # crf 177 | crf_preds = do_crf(pool, img, cluster_preds).argmax(1).cuda() 178 | 179 | # nice evaluation 180 | _, desc_nice = nice.eval(crf_preds, label) 181 | 182 | # real-time print 183 | desc = f'{desc_nice}' 184 | prog_bar.set_description(desc, refresh=True) 185 | 186 | # evaludation metric reset 187 | nice.reset() 188 | 189 | 190 | def main(rank, args): 191 | 192 | # setting gpu id of this process 193 | torch.cuda.set_device(rank) 194 | 195 | # print argparse 196 | print_argparse(args, rank=0) 197 | 198 | # dataset loader 199 | train_loader, test_loader, _ = dataloader(args, False) 200 | 201 | # network loader 202 | net = network_loader(args, rank) 203 | segment = segment_tr_loader(args, rank) 204 | cluster = cluster_tr_loader(args, rank) 205 | 206 | # evaluation 207 | nice = NiceTool(args.n_classes) 208 | 209 | # color map 210 | cmap = create_cityscapes_colormap() if args.dataset == 'cityscapes' else create_pascal_label_colormap() 211 | 212 | 213 | 
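# The block below is the handshake with train_mediator.py: the modularity-based
# codebook saved by that script is loaded as a frozen tensor into the cluster
# module and into both TR decoder heads, so inference uses the same mediator
# codebook that the front-door training stage used.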
################################################################################### 214 | # First, run train_mediator.py 215 | path, is_exist = pickle_path_and_exist(args) 216 | 217 | # early save for time 218 | if is_exist: 219 | # load 220 | codebook = np.load(path) 221 | cb = torch.from_numpy(codebook).cuda() 222 | cluster.codebook.data = cb 223 | cluster.codebook.requires_grad = False 224 | segment.head.codebook = cb 225 | segment.head_ema.codebook = cb 226 | 227 | # print successful loading modularity 228 | rprint(f'Modularity {path} loaded', rank) 229 | 230 | else: 231 | rprint('Train Modularity-based Codebook First', rank) 232 | return 233 | ################################################################################### 234 | 235 | # param size 236 | print(f'# of Parameters: {num_param(segment)/10**6:.2f}(M)') 237 | 238 | # post-processing with crf and hungarian matching 239 | test_without_crf( 240 | args, 241 | net, 242 | segment, 243 | cluster, 244 | nice, 245 | test_loader) 246 | 247 | # post-processing with crf and hungarian matching 248 | test( 249 | args, 250 | net, 251 | segment, 252 | cluster, 253 | nice, 254 | test_loader, 255 | cmap) 256 | 257 | # post-processing with crf and hungarian matching 258 | # test_linear_without_crf( 259 | # args, 260 | # net, 261 | # segment, 262 | # nice, 263 | # test_loader) 264 | 265 | # test_linear( 266 | # args, 267 | # net, 268 | # segment, 269 | # nice, 270 | # test_loader) 271 | 272 | 273 | if __name__ == "__main__": 274 | 275 | # fetch args 276 | parser = argparse.ArgumentParser() 277 | 278 | # model parameter 279 | parser.add_argument('--NAME-TAG', default='CAUSE-TR', type=str) 280 | parser.add_argument('--data_dir', default='/mnt/hard2/lbk-iccv/datasets', type=str) 281 | parser.add_argument('--dataset', default='pascalvoc', type=str) 282 | parser.add_argument('--port', default='12355', type=str) 283 | parser.add_argument('--ckpt', default='checkpoint/dino_vit_small_8.pth', type=str) 284 | parser.add_argument('--distributed', default=False, type=str2bool) 285 | parser.add_argument('--load_segment', default=True, type=str2bool) 286 | parser.add_argument('--load_cluster', default=True, type=str2bool) 287 | parser.add_argument('--train_resolution', default=320, type=int) 288 | parser.add_argument('--test_resolution', default=320, type=int) 289 | parser.add_argument('--batch_size', default=16, type=int) 290 | parser.add_argument('--num_workers', default=int(os.cpu_count() / 8), type=int) 291 | parser.add_argument('--gpu', default='4', type=str) 292 | parser.add_argument('--num_codebook', default=2048, type=int) 293 | 294 | # model parameter 295 | parser.add_argument('--reduced_dim', default=90, type=int) 296 | parser.add_argument('--projection_dim', default=2048, type=int) 297 | 298 | args = parser.parse_args() 299 | 300 | 301 | if 'dinov2' in args.ckpt: 302 | args.train_resolution=322 303 | args.test_resolution=322 304 | if 'small' in args.ckpt: 305 | args.dim=384 306 | elif 'base' in args.ckpt: 307 | args.dim=768 308 | args.num_queries=args.train_resolution**2 // int(args.ckpt.split('_')[-1].split('.')[0])**2 309 | 310 | 311 | # the number of gpus for multi-process 312 | gpu_list = list(map(int, args.gpu.split(','))) 313 | ngpus_per_node = len(gpu_list) 314 | 315 | # first gpu index is activated once there are several gpu in args.gpu 316 | main(rank=gpu_list[0], args=args) -------------------------------------------------------------------------------- /train_front_door_mlp.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import warnings 4 | warnings.filterwarnings("ignore", category=RuntimeWarning) 5 | 6 | 7 | from tqdm import tqdm 8 | from utils.utils import * 9 | import torch.distributed as dist 10 | import torch.multiprocessing as mp 11 | import torch.backends.cudnn as cudnn 12 | from modules.segment_module import stochastic_sampling, ema_init, ema_update 13 | from loader.dataloader import dataloader 14 | from torch.cuda.amp import autocast, GradScaler 15 | from loader.netloader import network_loader, segment_mlp_loader, cluster_mlp_loader 16 | from tensorboardX import SummaryWriter 17 | 18 | cudnn.benchmark = True 19 | scaler = GradScaler() 20 | 21 | # tensorboard 22 | counter = 0 23 | counter_test = 0 24 | 25 | def ddp_setup(args, rank, world_size): 26 | os.environ['MASTER_ADDR'] = 'localhost' 27 | os.environ['MASTER_PORT'] = args.port 28 | 29 | # initialize 30 | dist.init_process_group("nccl", rank=rank, world_size=world_size) 31 | 32 | 33 | def ddp_clean(): 34 | dist.destroy_process_group() 35 | 36 | 37 | @Wrapper.EpochPrint 38 | def train(args, net, segment, cluster, train_loader, optimizer_segment, writer, rank): 39 | global counter 40 | segment.train() 41 | 42 | total_acc = 0 43 | total_loss = 0 44 | total_loss_front = 0 45 | total_loss_linear = 0 46 | 47 | prog_bar = tqdm(enumerate(train_loader), total=len(train_loader), leave=True) 48 | for idx, batch in prog_bar: 49 | 50 | # optimizer 51 | with autocast(): 52 | 53 | # image and label and self supervised feature 54 | img = batch["img"].cuda() 55 | label = batch["label"].cuda() 56 | 57 | # intermediate features 58 | feat = net(img)[:, 1:, :] 59 | orig_seg_feat_ema = segment.head_ema(feat, drop=segment.dropout) 60 | 61 | if args.grid: feat, _ = stochastic_sampling(feat) 62 | 63 | ###################################################################### 64 | # teacher 65 | seg_feat_ema = segment.head_ema(feat, drop=segment.dropout) 66 | proj_feat_ema = segment.projection_head_ema(seg_feat_ema) 67 | ###################################################################### 68 | 69 | ###################################################################### 70 | # student 71 | seg_feat = segment.head(feat, drop=segment.dropout) 72 | proj_feat = segment.projection_head(seg_feat) 73 | ###################################################################### 74 | 75 | ###################################################################### 76 | # bank compute and contrastive loss 77 | cluster.bank_compute() 78 | loss_front = cluster.contrastive_ema_with_codebook_bank(feat, proj_feat, proj_feat_ema) 79 | ###################################################################### 80 | 81 | # linear probe loss 82 | linear_logits = segment.linear(orig_seg_feat_ema) 83 | linear_logits = F.interpolate(linear_logits, label.shape[-2:], mode='bilinear', align_corners=False) 84 | flat_linear_logits = linear_logits.permute(0, 2, 3, 1).view(-1, args.n_classes) 85 | flat_label = label.view(-1) 86 | flat_label_mask = (flat_label >= 0) & (flat_label < args.n_classes) 87 | loss_linear = F.cross_entropy(flat_linear_logits[flat_label_mask], flat_label[flat_label_mask]) 88 | 89 | # loss 90 | loss = loss_front + loss_linear 91 | 92 | # optimizer 93 | optimizer_segment.zero_grad() 94 | scaler.scale(loss).backward() 95 | if args.dataset=='cityscapes': 96 | scaler.unscale_(optimizer_segment) 97 | torch.nn.utils.clip_grad_norm_(segment.parameters(), 1) 98 | elif 
args.dataset=='cocostuff27': 99 | scaler.unscale_(optimizer_segment) 100 | torch.nn.utils.clip_grad_norm_(segment.parameters(), 0.1) 101 | else: 102 | raise NotImplementedError 103 | scaler.step(optimizer_segment) 104 | scaler.update() 105 | 106 | # ema update 107 | ema_update(segment.head, segment.head_ema) 108 | ema_update(segment.projection_head, segment.projection_head_ema) 109 | 110 | # bank update 111 | cluster.bank_update(feat, proj_feat_ema) 112 | 113 | # linear probe acc check 114 | pred_label = linear_logits.argmax(dim=1) 115 | flat_pred_label = pred_label.view(-1) 116 | acc = (flat_pred_label[flat_label_mask] == flat_label[flat_label_mask]).sum() / flat_label[ 117 | flat_label_mask].numel() 118 | total_acc += acc.item() 119 | 120 | # loss check 121 | total_loss += loss.item() 122 | total_loss_front += loss_front.item() 123 | total_loss_linear += loss_linear.item() 124 | 125 | # real-time print 126 | desc = f'[Train] Loss: {total_loss / (idx + 1):.2f}={total_loss_front / (idx + 1):.2f}+{total_loss_linear / (idx + 1):.2f}' 127 | desc += f' ACC: {100. * total_acc / (idx + 1):.1f}%' 128 | prog_bar.set_description(desc, refresh=True) 129 | 130 | 131 | # tensorboard 132 | if (args.distributed == True) and (rank == 0): 133 | writer.add_scalar('Train/Contrastive', loss_front, counter) 134 | writer.add_scalar('Train/Linear', loss_linear, counter) 135 | writer.add_scalar('Train/Acc', total_acc / (idx + 1), counter) 136 | counter += 1 137 | 138 | # Interrupt for sync GPU Process 139 | if args.distributed: dist.barrier() 140 | 141 | 142 | @Wrapper.TestPrint 143 | def test(args, net, segment, nice, test_loader): 144 | global counter_test 145 | segment.eval() 146 | 147 | total_acc = 0 148 | prog_bar = tqdm(enumerate(test_loader), total=len(test_loader), leave=True) 149 | for idx, batch in prog_bar: 150 | # image and label and self supervised feature 151 | img = batch["img"].cuda() 152 | label = batch["label"].cuda() 153 | 154 | # intermediate feature 155 | with autocast(): 156 | feat = net(img)[:, 1:, :] 157 | seg_feat_ema = segment.head_ema(feat) 158 | 159 | # linear probe loss 160 | linear_logits = segment.linear(seg_feat_ema) 161 | linear_logits = F.interpolate(linear_logits, label.shape[-2:], mode='bilinear', align_corners=False) 162 | flat_label = label.view(-1) 163 | flat_label_mask = (flat_label >= 0) & (flat_label < args.n_classes) 164 | 165 | # linear probe acc check 166 | pred_label = linear_logits.argmax(dim=1) 167 | flat_pred_label = pred_label.view(-1) 168 | acc = (flat_pred_label[flat_label_mask] == flat_label[flat_label_mask]).sum() / flat_label[ 169 | flat_label_mask].numel() 170 | total_acc += acc.item() 171 | 172 | # real-time print 173 | desc = f'[TEST] Acc (Linear): {100. 
* total_acc / (idx + 1):.1f}%' 174 | prog_bar.set_description(desc, refresh=True) 175 | 176 | # evaluation metric reset 177 | nice.reset() 178 | 179 | # Interrupt for sync GPU Process 180 | if args.distributed: dist.barrier() 181 | 182 | 183 | def main(rank, args, ngpus_per_node): 184 | 185 | # setup ddp process 186 | if args.distributed: ddp_setup(args, rank, ngpus_per_node) 187 | 188 | # setting gpu id of this process 189 | torch.cuda.set_device(rank) 190 | 191 | # print argparse 192 | print_argparse(args, rank) 193 | 194 | # dataset loader 195 | train_loader, test_loader, sampler = dataloader(args) 196 | 197 | # network loader 198 | net = network_loader(args, rank) 199 | segment = segment_mlp_loader(args, rank) 200 | cluster = cluster_mlp_loader(args, rank) 201 | 202 | # distributed parsing 203 | if args.distributed: net = net.module; segment = segment.module; cluster = cluster.module 204 | 205 | # Bank and EMA initialization 206 | cluster.bank_init() 207 | ema_init(segment.head, segment.head_ema) 208 | ema_init(segment.projection_head, segment.projection_head_ema) 209 | 210 | ################################################################################### 211 | # First, run train_mediator.py 212 | path, is_exist = pickle_path_and_exist(args) 213 | 214 | # early save for time 215 | if is_exist: 216 | # load 217 | codebook = np.load(path) 218 | cluster.codebook.data = torch.from_numpy(codebook).cuda() 219 | cluster.codebook.requires_grad = False 220 | 221 | # print successful loading modularity 222 | rprint(f'Modularity {path} loaded', rank) 223 | 224 | # Interrupt for sync GPU Process 225 | if args.distributed: dist.barrier() 226 | 227 | else: 228 | rprint('Train Modularity-based Codebook First', rank) 229 | return 230 | ################################################################################### 231 | 232 | # optimizer 233 | if args.dataset=='cityscapes': 234 | optimizer_segment = torch.optim.Adam(segment.parameters(), lr=1e-3 * ngpus_per_node) 235 | else: 236 | optimizer_segment = torch.optim.Adam(segment.parameters(), lr=1e-3 * ngpus_per_node, weight_decay=1e-4) 237 | 238 | # tensorboard 239 | if (args.distributed == True) and (rank == 0): 240 | from datetime import datetime 241 | log_dir = os.path.join('logs', 242 | datetime.today().strftime(" %m:%d_%H:%M")[2:], 243 | args.dataset, 244 | "_".join( 245 | [args.ckpt.split('/')[-1].split('.')[0], 246 | str( args.num_codebook), 247 | os.path.abspath(__file__).split('/')[-1]])) 248 | check_dir(log_dir) 249 | writer = SummaryWriter(log_dir=log_dir) if (rank == 0) and (args.distributed == True) else None 250 | 251 | # evaluation 252 | nice = NiceTool(args.n_classes) 253 | 254 | 255 | # train 256 | for epoch in range(args.epoch): 257 | 258 | # for shuffle 259 | if args.distributed: sampler.set_epoch(epoch) 260 | 261 | 262 | # train 263 | train( 264 | epoch, # for decorator 265 | rank, # for decorator 266 | args, 267 | net, 268 | segment, 269 | cluster, 270 | train_loader, 271 | optimizer_segment, 272 | writer, rank) 273 | 274 | 275 | test( 276 | epoch, # for decorator 277 | rank, # for decorator 278 | args, 279 | net, 280 | segment, 281 | nice, 282 | test_loader) 283 | 284 | if (rank == 0): 285 | x = segment.state_dict() 286 | baseline = args.ckpt.split('/')[-1].split('.')[0] 287 | 288 | # filepath hierarchy 289 | check_dir(f'CAUSE/{args.dataset}/{baseline}/{args.num_codebook}') 290 | 291 | # save path 292 | y = f'CAUSE/{args.dataset}/{baseline}/{args.num_codebook}/segment_mlp.pth' 293 | torch.save(x, y) 294 | 
print(f'-----------------TEST Epoch {epoch}: SAVING CHECKPOINT IN {y}-----------------') 295 | 296 | # Interrupt for sync GPU Process 297 | if args.distributed: dist.barrier() 298 | 299 | # Closing DDP 300 | if args.distributed: dist.barrier(); dist.destroy_process_group() 301 | 302 | 303 | if __name__ == "__main__": 304 | 305 | # fetch args 306 | parser = argparse.ArgumentParser() 307 | # model parameter 308 | parser.add_argument('--NAME-TAG', default='CAUSE-MLP', type=str) 309 | parser.add_argument('--data_dir', default='/mnt/hard2/lbk-iccv/datasets', type=str) 310 | parser.add_argument('--dataset', default='cocostuff27', type=str) 311 | parser.add_argument('--ckpt', default='checkpoint/dino_vit_base_8.pth', type=str) 312 | parser.add_argument('--epoch', default=2, type=int) 313 | parser.add_argument('--distributed', default=True, type=str2bool) 314 | parser.add_argument('--load_segment', default=False, type=str2bool) 315 | parser.add_argument('--load_cluster', default=False, type=str2bool) 316 | parser.add_argument('--train_resolution', default=224, type=int) 317 | parser.add_argument('--test_resolution', default=320, type=int) 318 | parser.add_argument('--batch_size', default=16, type=int) 319 | parser.add_argument('--num_workers', default=int(os.cpu_count() / 8), type=int) 320 | 321 | # DDP 322 | parser.add_argument('--gpu', default='0,1,2,3', type=str) 323 | parser.add_argument('--port', default='12355', type=str) 324 | 325 | # codebook parameter 326 | parser.add_argument('--grid', default='yes', type=str2bool) 327 | parser.add_argument('--num_codebook', default=2048, type=int) 328 | 329 | # model parameter 330 | parser.add_argument('--reduced_dim', default=90, type=int) 331 | parser.add_argument('--projection_dim', default=2048, type=int) 332 | 333 | args = parser.parse_args() 334 | 335 | if 'dinov2' in args.ckpt: args.test_resolution=322 336 | if 'small' in args.ckpt: 337 | args.dim = 384 338 | elif 'base' in args.ckpt: 339 | args.dim = 768 340 | 341 | # the number of gpus for multi-process 342 | gpu_list = list(map(int, args.gpu.split(','))) 343 | ngpus_per_node = len(gpu_list) 344 | 345 | if args.distributed: 346 | # cuda visible devices 347 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 348 | # multiprocess spawn 349 | mp.spawn(main, args=(args, ngpus_per_node), nprocs=ngpus_per_node, join=True) 350 | else: 351 | # first gpu index is activated once there are several gpu in args.gpu 352 | main(rank=gpu_list[0], args=args, ngpus_per_node=1) 353 | -------------------------------------------------------------------------------- /train_front_door_tr.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import warnings 4 | warnings.filterwarnings("ignore", category=RuntimeWarning) 5 | 6 | 7 | from tqdm import tqdm 8 | from utils.utils import * 9 | import torch.distributed as dist 10 | import torch.multiprocessing as mp 11 | import torch.backends.cudnn as cudnn 12 | from modules.segment_module import stochastic_sampling, ema_init, ema_update 13 | from loader.dataloader import dataloader 14 | from torch.cuda.amp import autocast, GradScaler 15 | from loader.netloader import network_loader, segment_tr_loader, cluster_tr_loader 16 | from tensorboardX import SummaryWriter 17 | 18 | cudnn.benchmark = True 19 | scaler = GradScaler() 20 | 21 | # tensorboard 22 | counter = 0 23 | counter_test = 0 24 | 25 | def ddp_setup(args, rank, world_size): 26 | os.environ['MASTER_ADDR'] = 'localhost' 27 | os.environ['MASTER_PORT'] = args.port 28 | 
29 | # initialize 30 | dist.init_process_group("nccl", rank=rank, world_size=world_size) 31 | 32 | 33 | def ddp_clean(): 34 | dist.destroy_process_group() 35 | 36 | 37 | @Wrapper.EpochPrint 38 | def train(args, net, segment, cluster, train_loader, optimizer_segment, writer, rank): 39 | global counter 40 | segment.train() 41 | 42 | total_acc = 0 43 | total_loss = 0 44 | total_loss_front = 0 45 | total_loss_linear = 0 46 | 47 | prog_bar = tqdm(enumerate(train_loader), total=len(train_loader), leave=True) 48 | for idx, batch in prog_bar: 49 | 50 | # optimizer 51 | with autocast(): 52 | 53 | # image and label and self supervised feature 54 | img = batch["img"].cuda() 55 | label = batch["label"].cuda() 56 | 57 | # intermediate features 58 | feat = net(img)[:, 1:, :] 59 | 60 | ###################################################################### 61 | # teacher 62 | seg_feat_ema = segment.head_ema(feat, drop=segment.dropout) 63 | proj_feat_ema = segment.projection_head_ema(seg_feat_ema) 64 | ###################################################################### 65 | 66 | ###################################################################### 67 | # student 68 | seg_feat = segment.head(feat, drop=segment.dropout) 69 | proj_feat = segment.projection_head(seg_feat) 70 | ###################################################################### 71 | 72 | ###################################################################### 73 | # grid 74 | if args.grid: 75 | feat, order = stochastic_sampling(feat) 76 | proj_feat, _ = stochastic_sampling(proj_feat, order=order) 77 | proj_feat_ema, _ = stochastic_sampling(proj_feat_ema, order=order) 78 | ###################################################################### 79 | 80 | ###################################################################### 81 | # bank compute and contrastive loss 82 | cluster.bank_compute() 83 | loss_front = cluster.contrastive_ema_with_codebook_bank(feat, proj_feat, proj_feat_ema) 84 | ###################################################################### 85 | 86 | # linear probe loss 87 | linear_logits = segment.linear(seg_feat_ema) 88 | linear_logits = F.interpolate(linear_logits, label.shape[-2:], mode='bilinear', align_corners=False) 89 | flat_linear_logits = linear_logits.permute(0, 2, 3, 1).reshape(-1, args.n_classes) 90 | flat_label = label.reshape(-1) 91 | flat_label_mask = (flat_label >= 0) & (flat_label < args.n_classes) 92 | loss_linear = F.cross_entropy(flat_linear_logits[flat_label_mask], flat_label[flat_label_mask]) 93 | 94 | # loss 95 | loss = loss_front + loss_linear 96 | 97 | # optimizer 98 | optimizer_segment.zero_grad() 99 | scaler.scale(loss).backward() 100 | if args.dataset=='cityscapes': 101 | scaler.unscale_(optimizer_segment) 102 | torch.nn.utils.clip_grad_norm_(segment.parameters(), 1) 103 | elif args.dataset=='cocostuff27': 104 | scaler.unscale_(optimizer_segment) 105 | torch.nn.utils.clip_grad_norm_(segment.parameters(), 2) 106 | else: 107 | raise NotImplementedError 108 | scaler.step(optimizer_segment) 109 | scaler.update() 110 | 111 | # ema update 112 | ema_update(segment.head, segment.head_ema) 113 | ema_update(segment.projection_head, segment.projection_head_ema) 114 | 115 | # bank update 116 | cluster.bank_update(feat, proj_feat_ema) 117 | 118 | # linear probe acc check 119 | pred_label = linear_logits.argmax(dim=1) 120 | flat_pred_label = pred_label.view(-1) 121 | acc = (flat_pred_label[flat_label_mask] == flat_label[flat_label_mask]).sum() / flat_label[ 122 | flat_label_mask].numel() 123 | total_acc += 
acc.item() 124 | 125 | # loss check 126 | total_loss += loss.item() 127 | total_loss_front += loss_front.item() 128 | total_loss_linear += loss_linear.item() 129 | 130 | # real-time print 131 | desc = f'[Train] Loss: {total_loss / (idx + 1):.2f}={total_loss_front / (idx + 1):.2f}+{total_loss_linear / (idx + 1):.2f}' 132 | desc += f' ACC: {100. * total_acc / (idx + 1):.1f}%' 133 | prog_bar.set_description(desc, refresh=True) 134 | 135 | 136 | # tensorboard 137 | if (args.distributed == True) and (rank == 0): 138 | writer.add_scalar('Train/Contrastive', loss_front, counter) 139 | writer.add_scalar('Train/Linear', loss_linear, counter) 140 | writer.add_scalar('Train/Acc', total_acc / (idx + 1), counter) 141 | counter += 1 142 | 143 | # Interrupt for sync GPU Process 144 | if args.distributed: dist.barrier() 145 | 146 | 147 | @Wrapper.TestPrint 148 | def test(args, net, segment, nice, test_loader): 149 | global counter_test 150 | segment.eval() 151 | 152 | total_acc = 0 153 | prog_bar = tqdm(enumerate(test_loader), total=len(test_loader), leave=True) 154 | for idx, batch in prog_bar: 155 | # image and label and self supervised feature 156 | img = batch["img"].cuda() 157 | label = batch["label"].cuda() 158 | 159 | # intermediate feature 160 | with autocast(): 161 | feat = net(img)[:, 1:, :] 162 | seg_feat_ema = segment.head_ema(feat) 163 | 164 | # linear probe loss 165 | linear_logits = segment.linear(seg_feat_ema) 166 | linear_logits = F.interpolate(linear_logits, label.shape[-2:], mode='bilinear', align_corners=False) 167 | flat_label = label.view(-1) 168 | flat_label_mask = (flat_label >= 0) & (flat_label < args.n_classes) 169 | 170 | # linear probe acc check 171 | pred_label = linear_logits.argmax(dim=1) 172 | flat_pred_label = pred_label.view(-1) 173 | acc = (flat_pred_label[flat_label_mask] == flat_label[flat_label_mask]).sum() / flat_label[ 174 | flat_label_mask].numel() 175 | total_acc += acc.item() 176 | 177 | # real-time print 178 | desc = f'[TEST] Acc (Linear): {100. 
* total_acc / (idx + 1):.1f}%' 179 | prog_bar.set_description(desc, refresh=True) 180 | 181 | # evaluation metric reset 182 | nice.reset() 183 | 184 | # Interrupt for sync GPU Process 185 | if args.distributed: dist.barrier() 186 | 187 | 188 | def main(rank, args, ngpus_per_node): 189 | 190 | # setup ddp process 191 | if args.distributed: ddp_setup(args, rank, ngpus_per_node) 192 | 193 | # setting gpu id of this process 194 | torch.cuda.set_device(rank) 195 | 196 | # print argparse 197 | print_argparse(args, rank) 198 | 199 | # dataset loader 200 | train_loader, test_loader, sampler = dataloader(args) 201 | 202 | # network loader 203 | net = network_loader(args, rank) 204 | segment = segment_tr_loader(args, rank) 205 | cluster = cluster_tr_loader(args, rank) 206 | 207 | # distributed parsing 208 | if args.distributed: net = net.module; segment = segment.module; cluster = cluster.module 209 | 210 | # Bank and EMA 211 | cluster.bank_init() 212 | ema_init(segment.head, segment.head_ema) 213 | ema_init(segment.projection_head, segment.projection_head_ema) 214 | 215 | ################################################################################### 216 | # First, run train_mediator.py 217 | path, is_exist = pickle_path_and_exist(args) 218 | 219 | # early save for time 220 | if is_exist: 221 | # load 222 | codebook = np.load(path) 223 | cluster.codebook.data = torch.from_numpy(codebook).cuda() 224 | cluster.codebook.requires_grad = False 225 | segment.head.codebook = torch.from_numpy(codebook).cuda() 226 | segment.head_ema.codebook = torch.from_numpy(codebook).cuda() 227 | 228 | # print successful loading modularity 229 | rprint(f'Modularity {path} loaded', rank) 230 | 231 | # Interrupt for sync GPU Process 232 | if args.distributed: dist.barrier() 233 | 234 | else: 235 | rprint('Train Modularity-based Codebook First', rank) 236 | return 237 | ################################################################################### 238 | 239 | # optimizer 240 | if args.dataset=='cityscapes': 241 | optimizer_segment = torch.optim.Adam(segment.parameters(), lr=1e-3 * ngpus_per_node) 242 | else: 243 | optimizer_segment = torch.optim.Adam(segment.parameters(), lr=1e-3 * ngpus_per_node, weight_decay=1e-4) 244 | 245 | # tensorboard 246 | if (args.distributed == True) and (rank == 0): 247 | from datetime import datetime 248 | log_dir = os.path.join('logs', 249 | datetime.today().strftime(" %m:%d_%H:%M")[2:], 250 | args.dataset, 251 | "_".join( 252 | [args.ckpt.split('/')[-1].split('.')[0], 253 | str(args.num_codebook), 254 | os.path.abspath(__file__).split('/')[-1]])) 255 | check_dir(log_dir) 256 | writer = SummaryWriter(log_dir=log_dir) if (rank == 0) and (args.distributed == True) else None 257 | 258 | # evaluation 259 | nice = NiceTool(args.n_classes) 260 | 261 | 262 | # train 263 | for epoch in range(args.epoch): 264 | 265 | # for shuffle 266 | if args.distributed: sampler.set_epoch(epoch) 267 | 268 | 269 | # train 270 | train( 271 | epoch, # for decorator 272 | rank, # for decorator 273 | args, 274 | net, 275 | segment, 276 | cluster, 277 | train_loader, 278 | optimizer_segment, 279 | writer, rank) 280 | 281 | 282 | test( 283 | epoch, # for decorator 284 | rank, # for decorator 285 | args, 286 | net, 287 | segment, 288 | nice, 289 | test_loader) 290 | 291 | if (rank == 0): 292 | x = segment.state_dict() 293 | baseline = args.ckpt.split('/')[-1].split('.')[0] 294 | 295 | # filepath hierarchy 296 | check_dir(f'CAUSE/{args.dataset}/{baseline}/{args.num_codebook}') 297 | 298 | # save path 299 | y = 
f'CAUSE/{args.dataset}/{baseline}/{args.num_codebook}/segment_tr.pth' 300 | torch.save(x, y) 301 | print(f'-----------------TEST Epoch {epoch}: SAVING CHECKPOINT IN {y}-----------------') 302 | 303 | # Interrupt for sync GPU Process 304 | if args.distributed: dist.barrier() 305 | 306 | # Closing DDP 307 | if args.distributed: dist.barrier(); dist.destroy_process_group() 308 | 309 | 310 | if __name__ == "__main__": 311 | 312 | # fetch args 313 | parser = argparse.ArgumentParser() 314 | # model parameter 315 | parser.add_argument('--NAME-TAG', default='CAUSE-TR', type=str) 316 | parser.add_argument('--data_dir', default='/mnt/hard2/lbk-iccv/datasets', type=str) 317 | parser.add_argument('--dataset', default='cocostuff27', type=str) 318 | parser.add_argument('--ckpt', default='checkpoint/dino_vit_base_8.pth', type=str) 319 | parser.add_argument('--epoch', default=2, type=int) 320 | parser.add_argument('--distributed', default=True, type=str2bool) 321 | parser.add_argument('--load_segment', default=False, type=str2bool) 322 | parser.add_argument('--load_cluster', default=False, type=str2bool) 323 | parser.add_argument('--train_resolution', default=320, type=int) 324 | parser.add_argument('--test_resolution', default=320, type=int) 325 | parser.add_argument('--batch_size', default=16, type=int) 326 | parser.add_argument('--num_workers', default=int(os.cpu_count() / 8), type=int) 327 | 328 | # DDP 329 | parser.add_argument('--gpu', default='0,1,2,3', type=str) 330 | parser.add_argument('--port', default='12355', type=str) 331 | 332 | # codebook parameter 333 | parser.add_argument('--grid', default='yes', type=str2bool) 334 | parser.add_argument('--num_codebook', default=2048, type=int) 335 | 336 | # model parameter 337 | parser.add_argument('--reduced_dim', default=90, type=int) 338 | parser.add_argument('--projection_dim', default=2048, type=int) 339 | 340 | args = parser.parse_args() 341 | 342 | if 'dinov2' in args.ckpt: 343 | args.train_resolution=322 344 | args.test_resolution=322 345 | if 'small' in args.ckpt: 346 | args.dim=384 347 | elif 'base' in args.ckpt: 348 | args.dim=768 349 | args.num_queries=args.train_resolution**2 // int(args.ckpt.split('_')[-1].split('.')[0])**2 350 | 351 | # the number of gpus for multi-process 352 | gpu_list = list(map(int, args.gpu.split(','))) 353 | ngpus_per_node = len(gpu_list) 354 | 355 | if args.distributed: 356 | # cuda visible devices 357 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 358 | # multiprocess spawn 359 | mp.spawn(main, args=(args, ngpus_per_node), nprocs=ngpus_per_node, join=True) 360 | else: 361 | # first gpu index is activated once there are several gpu in args.gpu 362 | main(rank=gpu_list[0], args=args, ngpus_per_node=1) 363 | -------------------------------------------------------------------------------- /train_mediator.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from tqdm import tqdm 4 | from utils.utils import * 5 | import torch.distributed as dist 6 | import torch.multiprocessing as mp 7 | import torch.backends.cudnn as cudnn 8 | from modules.segment_module import compute_modularity_based_codebook 9 | from loader.dataloader import dataloader 10 | from torch.cuda.amp import autocast, GradScaler 11 | from loader.netloader import network_loader, cluster_mlp_loader 12 | 13 | cudnn.benchmark = True 14 | scaler = GradScaler() 15 | 16 | def ddp_setup(args, rank, world_size): 17 | os.environ['MASTER_ADDR'] = 'localhost' 18 | os.environ['MASTER_PORT'] = args.port 19 | 20 
| # initialize 21 | dist.init_process_group("nccl", rank=rank, world_size=world_size) 22 | 23 | def ddp_clean(): 24 | dist.destroy_process_group() 25 | 26 | @Wrapper.EpochPrint 27 | def train(args, net, cluster, train_loader, optimizer): 28 | prog_bar = tqdm(enumerate(train_loader), total=len(train_loader), leave=True) 29 | for idx, batch in prog_bar: 30 | # image and label and self supervised feature 31 | img = batch["img"].cuda() 32 | 33 | # intermediate feature 34 | with autocast(): 35 | feat = net(img)[:, 1:, :] 36 | 37 | # computing modularity based codebook 38 | loss_mod = compute_modularity_based_codebook(cluster.codebook, feat, grid=args.grid) 39 | 40 | # optimization 41 | optimizer.zero_grad() 42 | scaler.scale(loss_mod).backward() 43 | scaler.step(optimizer) 44 | scaler.update() 45 | 46 | # real-time print 47 | desc = f'[Train]' 48 | prog_bar.set_description(desc, refresh=True) 49 | 50 | # Interrupt for sync GPU Process 51 | if args.distributed: dist.barrier() 52 | 53 | def main(rank, args, ngpus_per_node): 54 | # setup ddp process 55 | if args.distributed: ddp_setup(args, rank, ngpus_per_node) 56 | 57 | # setting gpu id of this process 58 | torch.cuda.set_device(rank) 59 | 60 | # print argparse 61 | print_argparse(args, rank) 62 | 63 | # dataset loader 64 | train_loader, _, sampler = dataloader(args) 65 | 66 | # network loader 67 | net = network_loader(args, rank) 68 | cluster = cluster_mlp_loader(args, rank) 69 | 70 | # distributed parsing 71 | if args.distributed: net = net.module; cluster = cluster.module 72 | 73 | # optimizer and scheduler 74 | optimizer = torch.optim.Adam(cluster.parameters(), lr=1e-3 * ngpus_per_node) 75 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.2) 76 | 77 | ################################################################################### 78 | # train only modularity? 
79 |     path, is_exist = pickle_path_and_exist(args)
80 | 
81 |     # train and save the codebook only if it does not already exist
82 |     if not is_exist:
83 |         rprint("No File Exists!!", rank)
84 |         # train
85 |         for epoch in range(args.epoch):
86 | 
87 |             # for shuffle
88 |             if args.distributed: sampler.set_epoch(epoch)
89 | 
90 |             # train
91 |             train(
92 |                 epoch,  # for decorator
93 |                 rank,  # for decorator
94 |                 args,
95 |                 net,
96 |                 cluster,
97 |                 train_loader,
98 |                 optimizer)
99 | 
100 |             # scheduler step
101 |             scheduler.step()
102 | 
103 |             # save
104 |             if rank == 0:
105 |                 np.save(path,
106 |                         cluster.codebook.detach().cpu().numpy())
107 | 
108 |             # Interrupt for sync GPU Process
109 |             if args.distributed: dist.barrier()
110 | 
111 |     else:
112 |         rprint("Already Exists!!", rank)
113 |     ###################################################################################
114 | 
115 | 
116 |     # clean ddp process
117 |     if args.distributed: ddp_clean()
118 | 
119 | 
120 | if __name__ == "__main__":
121 | 
122 |     # fetch args
123 |     parser = argparse.ArgumentParser()
124 | 
125 |     # fixed parameter
126 |     parser.add_argument('--epoch', default=1, type=int)
127 |     parser.add_argument('--distributed', default=True, type=str2bool)
128 |     parser.add_argument('--load_segment', default=False, type=str2bool)
129 |     parser.add_argument('--load_cluster', default=False, type=str2bool)
130 |     parser.add_argument('--train_resolution', default=320, type=int)
131 |     parser.add_argument('--test_resolution', default=320, type=int)
132 |     parser.add_argument('--batch_size', default=16, type=int)
133 |     parser.add_argument('--num_workers', default=int(os.cpu_count() / 8), type=int)
134 | 
135 |     # dataset and baseline
136 |     parser.add_argument('--data_dir', default='/mnt/hard2/lbk-iccv/datasets', type=str)
137 |     parser.add_argument('--dataset', default='cocostuff27', type=str)
138 |     parser.add_argument('--ckpt', default='checkpoint/dino_vit_base_8.pth', type=str)
139 | 
140 |     # DDP
141 |     parser.add_argument('--gpu', default='0,1,2,3', type=str)
142 |     parser.add_argument('--port', default='12355', type=str)
143 | 
144 |     # parameter
145 |     parser.add_argument('--grid', default='yes', type=str2bool)
146 |     parser.add_argument('--num_codebook', default=2048, type=int)
147 | 
148 |     # model parameter
149 |     parser.add_argument('--reduced_dim', default=90, type=int)
150 |     parser.add_argument('--projection_dim', default=2048, type=int)
151 | 
152 |     args = parser.parse_args()
153 | 
154 |     if 'dinov2' in args.ckpt:
155 |         args.train_resolution=322
156 |         args.test_resolution=322
157 |     if 'small' in args.ckpt:
158 |         args.dim=384
159 |     elif 'base' in args.ckpt:
160 |         args.dim=768
161 | 
162 |     # the number of gpus for multi-process
163 |     gpu_list = list(map(int, args.gpu.split(',')))
164 |     ngpus_per_node = len(gpu_list)
165 | 
166 |     if args.distributed:
167 |         # cuda visible devices
168 |         os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
169 |         # multiprocess spawn
170 |         mp.spawn(main, args=(args, ngpus_per_node), nprocs=ngpus_per_node, join=True)
171 |     else:
172 |         # only the first gpu index in args.gpu is used when running without DDP
173 |         main(rank=gpu_list[0], args=args, ngpus_per_node=1)
174 | 
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ByungKwanLee/Causal-Unsupervised-Segmentation/ee0aa8478a6b6704f4db44ecc70e44a14fe5067f/utils/__init__.py
--------------------------------------------------------------------------------
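Note: a minimal sketch (not part of the repository) of how the modularity codebook written by train_mediator.py can be inspected before running the front-door training scripts. The file path follows pickle_path_and_exist in utils/utils.py with the default arguments above (dataset cocostuff27, checkpoint dino_vit_base_8, 2048 codewords); the exact array shape depends on how the cluster module defines its codebook.

import numpy as np

# codebook saved by train_mediator.py under the default arguments (path is an example)
codebook_path = 'CAUSE/cocostuff27/modularity/dino_vit_base_8/2048/modular.npy'
codebook = np.load(codebook_path)

# the front-door stages (train_front_door_mlp.py / train_front_door_tr.py) presumably
# pick this file up; printing the shape is a quick sanity check
print(codebook.shape, codebook.dtype)
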
/utils/utils.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms
2 | import os
3 | import numpy as np
4 | import torchvision
5 | import matplotlib
6 | import matplotlib.pyplot as plt
7 | from multiprocessing import Pool
8 | 
9 | invTrans = transforms.Compose([transforms.Normalize(mean=[0., 0., 0.],
10 |                                                     std=[1 / 0.229, 1 / 0.224, 1 / 0.225]),
11 |                                transforms.Normalize(mean=[-0.485, -0.456, -0.406],
12 |                                                     std=[1., 1., 1.]),
13 |                                ])
14 | Trans = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
15 | 
16 | is_sym = lambda x: (x.transpose(1, 0) == x).all().item()
17 | 
18 | import datetime
19 | class Wrapper(object):
20 |     @staticmethod
21 |     def InitializePrint(func):
22 |         def wrapper(rank, *args, **kwargs):
23 |             rprint(f'-------------Initialize VQ-VAE-------------', rank)
24 |             func(*args, **kwargs)
25 |         return wrapper
26 |     @staticmethod
27 |     def TestPrint(func):
28 |         def wrapper(epoch, rank, *args, **kwargs):
29 |             rprint(f'-------------TEST EPOCH: {epoch+1}-------------', rank)
30 |             return func(*args, **kwargs)
31 |         return wrapper
32 |     @staticmethod
33 |     def EpochPrint(func):
34 |         def wrapper(epoch, rank, *args, **kwargs):
35 |             rprint(f'-------------TRAIN EPOCH: {epoch+1}-------------', rank)
36 |             func(*args, **kwargs)
37 |         return wrapper
38 |     @staticmethod
39 |     def KmeansPrint(func):
40 |         def wrapper(rank, *args, **kwargs):
41 |             rprint(f'-------------K-Means Clustering-------------', rank)
42 |             func(*args, **kwargs)
43 |         return wrapper
44 |     @staticmethod
45 |     def TimePrint(func):
46 |         def wrapper(*args, **kwargs):
47 |             start = datetime.datetime.now()
48 |             out = func(*args, **kwargs)
49 |             end = datetime.datetime.now()
50 |             print(f'[{func.__name__}] Time: {(end - start).total_seconds():.2f}sec')
51 |             return out
52 |         return wrapper
53 | 
54 | def pickle_path_and_exist(args):
55 |     from os.path import exists
56 |     baseline = args.ckpt.split('/')[-1].split('.')[0]
57 |     check_dir(f'CAUSE/{args.dataset}')
58 |     check_dir(f'CAUSE/{args.dataset}/modularity')
59 |     check_dir(f'CAUSE/{args.dataset}/modularity/{baseline}')
60 |     check_dir(f'CAUSE/{args.dataset}/modularity/{baseline}/{args.num_codebook}')
61 |     filepath = f'CAUSE/{args.dataset}/modularity/{baseline}/{args.num_codebook}/modular.npy'
62 |     return filepath, exists(filepath)
63 | 
64 | def freeze(net):
65 |     # net eval and freeze
66 |     net.eval()
67 |     for param in net.parameters():
68 |         param.requires_grad = False
69 | 
70 | def no_freeze(net):
71 |     # net eval but keep parameters trainable (unfreeze)
72 |     net.eval()
73 |     for param in net.parameters():
74 |         param.requires_grad = True
75 | 
76 | def str2bool(v):
77 |     if v.lower() in ('yes', 'true', 't', 'y', '1'):
78 |         return True
79 |     elif v.lower() in ('no', 'false', 'f', 'n', '0'):
80 |         return False
81 |     else:
82 |         assert False
83 | 
84 | def check_dir(dir_path):
85 |     if not os.path.exists(dir_path):
86 |         os.makedirs(dir_path)
87 | 
88 | def save_all(args, ind, img, label, cluster_preds, crf_preds, hungarian_preds, cmap, is_tr=False):
89 |     baseline = args.ckpt.split('/')[-1].split('.')[0]
90 |     y = f'{args.num_codebook}'
91 |     check_dir('results')
92 |     if is_tr:
93 |         root = os.path.join('results', args.dataset, 'TR', baseline, y)
94 |     else:
95 |         root = os.path.join('results', args.dataset, 'MLP', baseline, y)
96 | 
97 |     check_dir(f'{root}/imgs')
98 |     check_dir(f'{root}/labels')
99 |     check_dir(f'{root}/kmeans')
100 |     check_dir(f'{root}/crfs')
101 |     check_dir(f'{root}/hungarians')
102 |     # save image
103 |     for id, i in [(id, x.item()) for id, x in enumerate(list(ind))]:
104 |         torchvision.utils.save_image(invTrans(img)[id].cpu(), f'{root}/imgs/imgs_{i}.png')
105 |         torchvision.utils.save_image(torch.from_numpy(cmap[label[id].cpu()]).permute(2, 0, 1),
106 |                                      f'{root}/labels/labels_{i}.png')
107 |         torchvision.utils.save_image(torch.from_numpy(cmap[cluster_preds[id].cpu()]).permute(2, 0, 1),
108 |                                      f'{root}/kmeans/kmeans_{i}.png')
109 |         torchvision.utils.save_image(torch.from_numpy(cmap[crf_preds[id].cpu()]).permute(2, 0, 1),
110 |                                      f'{root}/crfs/crfs_{i}.png')
111 |         torchvision.utils.save_image(torch.from_numpy(cmap[hungarian_preds[id].cpu()]).permute(2, 0, 1),
112 |                                      f'{root}/hungarians/hungarians_{i}.png')
113 | 
114 | def rprint(msg, rank=0):
115 |     if rank==0: print(msg)
116 | 
117 | def num_param(f):
118 |     out = 0
119 |     for param in f.head.parameters():
120 |         out += param.numel()
121 |     return out
122 | 
123 | def imshow(img):
124 |     a = 255 * invTrans(img).permute(1, 2, 0).cpu().numpy()
125 |     plt.imshow(a.astype(np.int64))
126 |     plt.show()
127 | 
128 | def plot(x):
129 |     plt.plot(x.cpu().numpy())
130 |     plt.show()
131 | 
132 | def cmshow(img):
133 |     # color map
134 |     cmap = create_cityscapes_colormap()
135 |     plt.imshow(cmap[img.cpu().numpy()])
136 |     plt.show()
137 | 
138 | def getCMap(n_classes=27, cmapName='jet'):
139 | 
140 |     # Get jet color map from Matlab
141 |     labelCount = n_classes
142 |     cmapGen = matplotlib.cm.get_cmap(cmapName, labelCount)
143 |     cmap = cmapGen(np.arange(labelCount))
144 |     cmap = cmap[:, 0:3]
145 | 
146 |     # Reduce value/brightness of stuff colors (easier in HSV format)
147 |     cmap = cmap.reshape((-1, 1, 3))
148 |     hsv = matplotlib.colors.rgb_to_hsv(cmap)
149 |     hsv[:, 0, 2] = hsv[:, 0, 2] * 0.7
150 |     cmap = matplotlib.colors.hsv_to_rgb(hsv)
151 |     cmap = cmap.reshape((-1, 3))
152 | 
153 |     # Permute entries to avoid classes with similar names having similar colors
154 |     st0 = np.random.get_state()
155 |     np.random.seed(42)
156 |     perm = np.random.permutation(labelCount)
157 |     np.random.set_state(st0)
158 |     cmap = cmap[perm, :]
159 | 
160 |     return cmap
161 | 
162 | def ckpt_to_name(ckpt):
163 |     name = ckpt.split('/')[-1].split('_')[0]
164 |     return name
165 | 
166 | def ckpt_to_arch(ckpt):
167 |     arch = ckpt.split('/')[-1].split('.')[0]
168 |     return arch
169 | 
170 | def print_argparse(args, rank=0):
171 |     dict = vars(args)
172 |     if rank == 0:
173 |         print('------------------Configurations------------------')
174 |         for key in dict.keys(): print("{}: {}".format(key, dict[key]))
175 |         print('-------------------------------------------------')
176 | 
177 | 
178 | """
179 | 
180 | Below are functions adapted from STEGO
181 | 
182 | """
183 | 
184 | from collections import OrderedDict
185 | class NiceTool(object):
186 | 
187 |     def __init__(self, n_classes):
188 |         self.n_classes = n_classes
189 |         self.histogram = torch.zeros((self.n_classes, self.n_classes)).cuda()
190 | 
191 |     def scores(self, label_trues, label_preds):
192 |         mask = (label_trues >= 0) & (label_trues < self.n_classes) & (label_preds >= 0) & (label_preds < self.n_classes)  # Exclude unlabelled data.
193 | hist = torch.bincount(self.n_classes * label_trues[mask] + label_preds[mask], \ 194 | minlength=self.n_classes ** 2).reshape(self.n_classes, self.n_classes).t().cuda() 195 | return hist 196 | 197 | def eval(self, pred, label): 198 | pred = pred.reshape(-1) 199 | label = label.reshape(-1) 200 | self.histogram += self.scores(label, pred) 201 | 202 | self.assignments = linear_sum_assignment(self.histogram.cpu(), maximize=True) 203 | hist = self.histogram[np.argsort(self.assignments[1]), :] 204 | 205 | tp = torch.diag(hist) 206 | fp = torch.sum(hist, dim=0) - tp 207 | fn = torch.sum(hist, dim=1) - tp 208 | 209 | iou = tp / (tp + fp + fn) 210 | prc = tp / (tp + fn) 211 | opc = torch.sum(tp) / torch.sum(hist) 212 | 213 | # metric 214 | metric_dict = OrderedDict({"mIoU": iou[~torch.isnan(iou)].mean().item() * 100, 215 | # "Precision per Class (%)": prc * 100, 216 | "mAP": prc[~torch.isnan(prc)].mean().item() * 100, 217 | "Acc": opc.item() * 100}) 218 | 219 | 220 | self.metric_dict_by_class = OrderedDict({"mIoU": iou * 100, 221 | # "Precision per Class (%)": prc * 100, 222 | "mAP": prc * 100, 223 | "Acc": (torch.diag(hist) / hist.sum(dim=1)) * 100}) 224 | 225 | 226 | # generate desc 227 | sentence = '' 228 | for key, value in metric_dict.items(): 229 | if type(value) == torch.Tensor: continue 230 | sentence += f'[{key}]: {value:.1f}, ' 231 | return metric_dict, sentence 232 | 233 | def reset(self): 234 | self.histogram = torch.zeros((self.n_classes, self.n_classes)).cuda() 235 | 236 | def do_hungarian(self, clusters): 237 | return torch.tensor(self.assignments[1])[clusters.cpu()] 238 | 239 | 240 | from scipy.optimize import linear_sum_assignment 241 | import pydensecrf.densecrf as dcrf 242 | import pydensecrf.utils as utils 243 | import torch 244 | import torch.nn.functional as F 245 | import torchvision.transforms.functional as VF 246 | 247 | def dense_crf(image_tensor: torch.FloatTensor, output_logits: torch.FloatTensor, max_iter: int): 248 | MAX_ITER = max_iter 249 | POS_W = 3 250 | POS_XY_STD = 1 251 | Bi_W = 4 252 | Bi_XY_STD = 67 253 | Bi_RGB_STD = 3 254 | 255 | image = np.array(VF.to_pil_image(invTrans(image_tensor)))[:, :, ::-1] 256 | H, W = image.shape[:2] 257 | image = np.ascontiguousarray(image) 258 | 259 | output_logits = F.interpolate(output_logits.unsqueeze(0), size=(H, W), mode="bilinear", 260 | align_corners=False).squeeze() 261 | output_probs = F.softmax(output_logits, dim=0).cpu().numpy() 262 | c = output_probs.shape[0] 263 | h = output_probs.shape[1] 264 | w = output_probs.shape[2] 265 | 266 | U = utils.unary_from_softmax(output_probs) 267 | U = np.ascontiguousarray(U) 268 | 269 | d = dcrf.DenseCRF2D(w, h, c) 270 | d.setUnaryEnergy(U) 271 | d.addPairwiseGaussian(sxy=POS_XY_STD, compat=POS_W) 272 | d.addPairwiseBilateral(sxy=Bi_XY_STD, srgb=Bi_RGB_STD, rgbim=image, compat=Bi_W) 273 | 274 | Q = d.inference(MAX_ITER) 275 | Q = np.array(Q).reshape((c, h, w)) 276 | return Q 277 | 278 | def _apply_crf(tup, max_iter): 279 | return dense_crf(tup[0], tup[1], max_iter=max_iter) 280 | 281 | def do_crf(pool, img_tensor, prob_tensor, max_iter=10): 282 | from functools import partial 283 | outputs = pool.map(partial(_apply_crf, max_iter=max_iter), zip(img_tensor.detach().cpu(), prob_tensor.detach().cpu())) 284 | return torch.cat([torch.from_numpy(arr).unsqueeze(0) for arr in outputs], dim=0) 285 | 286 | def create_pascal_label_colormap(): 287 | def bit_get(val, idx): 288 | return (val >> idx) & 1 289 | colormap = np.zeros((512, 3), dtype=int) 290 | ind = np.arange(512, dtype=int) 291 | 
292 | for shift in reversed(list(range(8))): 293 | for channel in range(3): 294 | colormap[:, channel] |= bit_get(ind, channel) << shift 295 | ind >>= 3 296 | 297 | return colormap / 255 298 | 299 | def create_cityscapes_colormap(): 300 | colors = [(128, 64, 128), 301 | (244, 35, 232), 302 | (250, 170, 160), 303 | (230, 150, 140), 304 | (70, 70, 70), 305 | (102, 102, 156), 306 | (190, 153, 153), 307 | (180, 165, 180), 308 | (150, 100, 100), 309 | (150, 120, 90), 310 | (153, 153, 153), 311 | (153, 153, 153), 312 | (250, 170, 30), 313 | (220, 220, 0), 314 | (107, 142, 35), 315 | (152, 251, 152), 316 | (70, 130, 180), 317 | (220, 20, 60), 318 | (255, 0, 0), 319 | (0, 0, 142), 320 | (0, 0, 70), 321 | (0, 60, 100), 322 | (0, 0, 90), 323 | (0, 0, 110), 324 | (0, 80, 100), 325 | (0, 0, 230), 326 | (119, 11, 32), 327 | (0, 0, 0)] 328 | return np.array(colors, dtype=int) / 255 329 | 330 | 331 | 332 | class ToTargetTensor(object): 333 | def __call__(self, target): 334 | return torch.as_tensor(np.array(target), dtype=torch.int64).unsqueeze(0) 335 | 336 | 337 | 338 | from torchvision.transforms import InterpolationMode 339 | # DATA Transformation 340 | def get_cococity_transform(res, is_label): 341 | 342 | if is_label: 343 | return transforms.Compose([transforms.Resize(res, interpolation=InterpolationMode.NEAREST), 344 | transforms.CenterCrop(res), 345 | ToTargetTensor()]) 346 | else: 347 | return transforms.Compose([transforms.Resize(res, interpolation=InterpolationMode.NEAREST), 348 | transforms.CenterCrop(res), 349 | transforms.ToTensor()]) 350 | 351 | # DATA Transformation 352 | def get_pascal_transform(res, is_label): 353 | if is_label: 354 | return transforms.Compose([transforms.Resize((res, res), interpolation=InterpolationMode.NEAREST), 355 | ToTargetTensor()]) 356 | else: 357 | return transforms.Compose([transforms.Resize((res, res), interpolation=InterpolationMode.NEAREST), 358 | transforms.ToTensor()]) --------------------------------------------------------------------------------
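
For reference, a minimal usage sketch (an illustration under stated assumptions, not code from the repository) of the evaluation helpers in utils/utils.py: NiceTool accumulates a class confusion histogram, Hungarian-matches cluster ids to ground-truth classes via scipy's linear_sum_assignment, and reports mIoU/mAP/Acc. It assumes a CUDA device is available, since the histogram is allocated on the GPU, and uses random tensors purely as stand-in predictions; real code would feed the cluster assignments produced by the segmentation head.

import torch
from utils.utils import NiceTool

# stand-in predictions and labels (27 classes, matching the cocostuff27 default)
pred = torch.randint(0, 27, (2, 320, 320))
label = torch.randint(0, 27, (2, 320, 320))

evaluator = NiceTool(n_classes=27)                    # histogram lives on the GPU
metric_dict, sentence = evaluator.eval(pred, label)   # Hungarian matching + mIoU/mAP/Acc
matched = evaluator.do_hungarian(pred)                # remap cluster ids to matched class ids
print(sentence)
evaluator.reset()                                     # clear the histogram between evaluations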