├── 2007_train.txt ├── 2007_val.txt ├── FER-YOLO-Mamba.jpg ├── README.md ├── model_data ├── sfew_classes.txt ├── simhei.ttf └── yolox_s.pth ├── nets ├── .ipynb_checkpoints │ └── yolo-checkpoint.py ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-311.pyc │ ├── darknet.cpython-311.pyc │ ├── yolo.cpython-311.pyc │ └── yolo_training.cpython-311.pyc ├── darknet.py ├── yolo.py └── yolo_training.py ├── predict.py ├── requirements.txt ├── train.py ├── utils ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-311.pyc │ ├── callbacks.cpython-311.pyc │ ├── dataloader.cpython-311.pyc │ ├── utils.cpython-311.pyc │ ├── utils_bbox.cpython-311.pyc │ ├── utils_fit.cpython-311.pyc │ └── utils_map.cpython-311.pyc ├── callbacks.py ├── dataloader.py ├── utils.py ├── utils_bbox.py ├── utils_fit.py └── utils_map.py └── yolo.py /2007_val.txt: -------------------------------------------------------------------------------- 1 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/Hangover_011409414_00000041.jpg 431,124,535,285,3 2 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/PursuitofHappiness_001943720_00000006.jpg 152,138,263,316,0 3 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/GirlWithAPearlEarring_003311120_00000044.jpg 304,181,462,416,3 4 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/SocialNetwork_003715767_00000028.jpg 214,95,356,347,4 5 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/ThereIsSomethingAboutMary_005941880_00000060.jpg 323,165,494,423,4 6 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/OceansTwelve_011845760_00000013.jpg 341,158,467,342,2 7 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/IAmSam_013827920__00000055.jpg 290,143,473,405,5 8 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/Saw3D_004056050_00000003.jpg 386,161,508,368,0 9 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/DeepBlueSea_000744200_00000010.jpg 237,101,319,226,4 10 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/HarryPotter_Deathly_Hallows_1_010156240_00000001.jpg 265,165,346,286,3 11 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/RevolutionaryRoad_000806440_00000052.jpg 429,151,480,231,0 12 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/HarryPotter_Half_Blood_Prince_004424374_00000028.jpg 248,177,327,286,2 13 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/MissMarch_000546640_00000065.jpg 274,188,431,452,4 14 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/NanyDiaries_012335760_00000006.jpg 234,80,402,391,0 15 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/SomethingBorrowed_002522600_00000009.jpg 288,162,448,370,3 16 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/NottingHill_010958154_00000029.jpg 170,134,286,360,4 17 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/OceansTwelve_000552400_00000036.jpg 260,132,413,384,2 18 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/SomethingBorrowed_012536680_00000001.jpg 374,170,420,235,5 19 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/Hangover_010805134_00000008.jpg 145,83,286,339,2 20 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/WrongTurn3_003354120_00000034.jpg 349,151,507,408,4 21 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/21_012215600_00000022.jpg 325,135,454,356,0 22 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/RevolutionaryRoad_010251320_00000001.jpg 281,160,409,404,5 23 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/OceansThirteen_003646520_00000002.jpg 230,166,277,259,4 24 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/AboutABoy_012436647_00000040.jpg 460,192,609,420,5 25 | 
/workspace/yolox/VOCdevkit/VOC2007/JPEGImages/PursuitofHappiness_002303560_00000086.jpg 532,153,614,310,0 26 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/NanyDiaries_002901240_00000031.jpg 335,189,512,478,4 27 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/DidYouHearAboutTheMorgans_003555207_00000003.jpg 291,143,371,268,6 28 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/JennifersBody_012437771_00000033.jpg 229,117,449,420,5 29 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/OceansEleven_005436800_00000038.jpg 314,152,393,294,4 30 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/AboutABoy_000403327_00000010.jpg 281,144,382,313,3 31 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/JennifersBody_003236644_00000001.jpg 384,87,492,271,6 32 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/AlexEmma_000958320_00000069.jpg 265,129,330,281,0 33 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/MarotAtTheWedding_010431280_00000042.jpg 392,110,519,302,6 34 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/DeepBlueSea_004628560_00000015.jpg 407,101,635,486,2 35 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/PrettyInPink_000543360_00000004.jpg 315,124,504,437,4 36 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/RevolutionaryRoad_000806440_00000038.jpg 440,171,479,244,0 37 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/AlexEmma_004507280_00000022.jpg 83,97,260,397,4 38 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/PrettyInPink_000543360_00000023.jpg 260,174,444,479,5 39 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/Informant_000946654_00000058.jpg 359,91,419,183,1 40 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/WrongTurn3_001131440_00000027.jpg 238,135,347,315,3 41 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/Bridesmaids_004523000_00000001.jpg 400,158,504,337,1 42 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/ACaseofYou_003118400_00000044.jpg 169,155,227,265,6 43 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/DecemberBoys_003642720_00000058.jpg 331,113,412,255,0 44 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/IAmSam_013820200_00000068.jpg 253,106,411,361,5 45 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/OneFlewOverCuckooNest_004157440_00000001.jpg 395,198,504,372,3 46 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/Saw3D_004056050_00000032.jpg 345,197,484,447,0 47 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/SomethingBorrowed_000513320_00000001.jpg 282,58,374,198,3 48 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/SocialNetwork_002249887_00000019.jpg 479,132,538,211,4 49 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/MissMarch_004423640_00000062.jpg 173,114,320,383,0 50 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/MissMarch_000257240_00000013.jpg 347,62,437,259,3 51 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/Terminal_014001000_00000051.jpg 273,104,425,341,3 52 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/Hangover_011409414_00000015.jpg 429,138,532,314,3 53 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/RememberMe_011005340_00000030.jpg 388,65,495,242,5 54 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/OrangeAndSunshine_011341440_00000004.jpg 189,112,301,329,5 55 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/Hangover_013036534_00000022.jpg 361,93,501,288,3 56 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/ItsComplicated_011340288_00000001.jpg 284,182,417,376,3 57 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/AlexEmma_000846120_00000015.jpg 401,170,467,284,0 58 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/ThereIsSomethingAboutMary_002830120_00000001.jpg 
236,209,416,461,3 59 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/RevolutionaryRoad_012136400_00000012.jpg 251,160,340,306,0 60 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/HarryPotter_Half_Blood_Prince_004424374_00000001.jpg 228,176,302,287,5 61 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/DeepBlueSea_010340320_00000019.jpg 221,181,415,475,2 62 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/OceansTwelve_002551440_00000064.jpg 302,194,424,367,2 63 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/OceansEleven_005846280_00000054.jpg 348,170,452,352,0 64 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/AlexEmma_005143680_00000023.jpg 188,113,410,461,3 65 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/FriendsWithBenefit_002137920_00000012.jpg 190,148,336,376,3 66 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/21_001108440_00000030.jpg 282,178,428,415,0 67 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/34.jpg 351,125,495,335,1 68 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/GirlWithAPearlEarring_012432920_00000021.jpg 287,240,366,388,0 69 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/Informant_000946654_00000014.jpg 350,83,415,197,4 70 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/ChangeUp_000423240_00000040.jpg 426,137,559,340,3 71 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/21_012224880_00000015.jpg 128,156,278,413,0 72 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/AlexEmma_012652960_00000060.jpg 417,153,532,347,5 73 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/Hangover_001949614_00000089.jpg 424,175,491,271,2 74 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/Unstoppable_004406878_00000020.jpg 444,197,540,381,6 75 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/DeepBlueSea_010340320_00000006.jpg 290,155,479,463,2 76 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/HauntingMollyHartely_002326760_00000046.jpg 163,82,264,237,1 77 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/Serendipity_001925727_00000001.jpg 306,149,432,383,3 78 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/LittleManhattan_000451774_00000052.jpg 271,166,360,294,5 79 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/JennifersBody_001947263_00000001.jpg 225,146,404,475,2 80 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/21_013508360_00000078.jpg 437,170,548,391,5 81 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/HarryPotter_Half_Blood_Prince_002404654_00000021.jpg 356,159,437,273,4 82 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/Descendants_002643280_00000048.jpg 212,162,342,340,5 83 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/ChangeUp_000423240_00000006.jpg 417,133,537,344,3 84 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/TheEye_003934640_00000006.jpg 192,187,324,408,5 85 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/21_010330280_00000013.jpg 269,147,351,272,1 86 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/Hangover_010805134_00000013.jpg 214,155,353,413,2 87 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/ItsComplicated_005221727_00000030.jpg 361,121,465,296,5 88 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/OneFlewOverCuckooNest_013610680_00000016.jpg 510,80,644,265,5 89 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/21_013508360_00000064.jpg 430,170,540,402,5 90 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/WrongTurn3_002350040_00000050.jpg 260,134,436,409,0 91 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/NotSuitableForChildren_002504480_00000039.jpg 436,89,606,426,3 92 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/GirlWithAPearlEarring_004432000_00000009.jpg 
321,140,403,313,6 93 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/ThereIsSomethingAboutMary_002830120_00000013.jpg 168,211,338,458,3 94 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/HarryPotter_Deathly_Hallows_2_000852880_00000044.jpg 230,172,341,338,5 95 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/Saw3D_000322047_00000001.jpg 338,150,446,313,2 96 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/HarryPotter_Half_Blood_Prince_000029301_00000004.jpg 247,111,492,472,4 97 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/21_012152480_00000001.jpg 341,154,445,330,2 98 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/Bridesmaids_012949160_00000012.jpg 453,179,504,254,5 99 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/Juno_010939560_00000065.jpg 387,102,546,376,6 100 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/HarryPotter_Deathly_Hallows_1_005528520_00000001.jpg 353,183,518,428,5 101 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/Terminal_004949960_00000001.jpg 193,151,322,364,0 102 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/Town_010806840_00000001.jpg 419,127,596,438,0 103 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/DecemberBoys_002006160_00000052.jpg 351,159,471,368,4 104 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/RememberMe_001635160_00000012.jpg 219,161,321,311,2 105 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/ChangeUp_001855040_00000009.jpg 345,139,486,397,0 106 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/JennifersBody_012514049_00000002.jpg 260,69,464,393,5 107 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/MissMarch_002616840_00000024.jpg 340,275,486,485,1 108 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/ItsComplicated_013152487_00000001.jpg 334,194,452,376,4 109 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/LittleManhattan_010156083_00000001.jpg 266,182,376,337,4 110 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/ItsComplicated_000857887_00000049.jpg 279,117,420,308,3 111 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/HarryPotter_GobletOfFire_000658734_00000001.jpg 245,168,373,360,6 112 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/TheShining_000534800_00000020.jpg 267,185,398,353,3 113 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/RevolutionaryRoad_012134720_00000035.jpg 278,132,353,250,0 114 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/YouveGotAMail_004828414_00000001.jpg 86,152,233,379,0 115 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/PrettyInPink_001802800_00000013.jpg 222,175,359,404,4 116 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/OceansEleven_002010400_00000015.jpg 345,198,424,341,1 117 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/GirlWithAPearlEarring_010643040_00000021.jpg 172,175,293,375,3 118 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/AlexEmma_004411160_00000026.jpg 129,118,269,366,3 119 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/21_015145521_00000044.jpg 156,168,263,340,2 120 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/ItsComplicated_003722087_00000051.jpg 298,121,410,338,6 121 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/OceansTwelve_011847320_00000054.jpg 430,160,548,351,2 122 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/PrettyInPink_002751240_00000035.jpg 158,115,308,382,5 123 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/OneFlewOverCuckooNest_011532650_00000016.jpg 262,140,407,340,5 124 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/31.jpg 218,138,295,276,2 125 | /workspace/yolox/VOCdevkit/VOC2007/JPEGImages/SomethingBorrowed_001318800_00000001.jpg 300,142,402,321,4 126 | 
/workspace/yolox/VOCdevkit/VOC2007/JPEGImages/Informant_001817054_00000015.jpg 361,74,452,237,4 127 | -------------------------------------------------------------------------------- /FER-YOLO-Mamba.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SwjtuMa/FER-YOLO-Mamba/e301ceeaf53f0eec2f6552cf0a6eb09a2ce90587/FER-YOLO-Mamba.jpg -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FER-YOLO-Mamba: Facial Expression Detection and Classification Based on Selective State Space [arXiv](https://arxiv.org/pdf/2405.01828) 2 | 3 | ![image](https://github.com/SwjtuMa/FER-YOLO-Mamba/blob/main/FER-YOLO-Mamba.jpg) 4 | 5 | # 📝Introduction📝 6 | 7 | This repository is the code implementation of the paper FER-YOLO-Mamba: Facial Expression Detection and Classification Based on Selective State Space. 8 | 9 | The current branch has been tested on Linux with PyTorch 2.0.0 and CUDA 11.7, supports Python 3.11.8, and is compatible with CUDA 11.7+. 10 | 11 | If you find this project helpful, please give us a star ⭐️; your support is our greatest motivation. 12 | 13 | # 📌Installation📌 14 | - pip install torch==2.0.0 torchvision==0.15.1 torchaudio==2.0.1 15 | 16 | - pip install packaging 17 | 18 | - pip install timm==0.4.12 19 | 20 | - pip install causal-conv1d==1.2.0.post2 21 | 22 | - pip install mamba-ssm==1.2.0.post1 23 | 24 | # 📜Other requirements📜 25 | - Linux System 26 | 27 | - NVIDIA GPU 28 | 29 | - CUDA 11.7+ 30 | 31 | # ✨Acknowledgments✨ 32 | We thank the authors of [VMamba](https://github.com/MzeroMiko/VMamba), [MedMamba](https://github.com/YubiaoYue/MedMamba) and [YOLOX](https://github.com/bubbliiiing/yolox-pytorch) for their open-source code. 33 | 34 | # 💞Citation💞 35 | If you use the code or performance benchmarks of this project in your research, please cite FER-YOLO-Mamba with the following BibTeX entry. 
36 | ``` 37 | @article{ma2024feryolomamba, 38 | title={FER-YOLO-Mamba: Facial Expression Detection and Classification Based on Selective State Space}, 39 | author={Hui Ma and Sen Lei and Turgay Celik and Heng-Chao Li}, 40 | journal={arXiv preprint arXiv:2405.01828}, 41 | year={2024} 42 | } 43 | ``` 44 | 45 | # Contact Us 46 | If you have any other questions❓, please feel free to contact us 👬 47 | -------------------------------------------------------------------------------- /model_data/sfew_classes.txt: -------------------------------------------------------------------------------- 1 | Angry 2 | Disgust 3 | Fear 4 | Happy 5 | Neutral 6 | Sad 7 | Surprise 8 | -------------------------------------------------------------------------------- /model_data/simhei.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SwjtuMa/FER-YOLO-Mamba/e301ceeaf53f0eec2f6552cf0a6eb09a2ce90587/model_data/simhei.ttf -------------------------------------------------------------------------------- /model_data/yolox_s.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SwjtuMa/FER-YOLO-Mamba/e301ceeaf53f0eec2f6552cf0a6eb09a2ce90587/model_data/yolox_s.pth -------------------------------------------------------------------------------- /nets/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------- /nets/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SwjtuMa/FER-YOLO-Mamba/e301ceeaf53f0eec2f6552cf0a6eb09a2ce90587/nets/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /nets/__pycache__/darknet.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SwjtuMa/FER-YOLO-Mamba/e301ceeaf53f0eec2f6552cf0a6eb09a2ce90587/nets/__pycache__/darknet.cpython-311.pyc -------------------------------------------------------------------------------- /nets/__pycache__/yolo.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SwjtuMa/FER-YOLO-Mamba/e301ceeaf53f0eec2f6552cf0a6eb09a2ce90587/nets/__pycache__/yolo.cpython-311.pyc -------------------------------------------------------------------------------- /nets/__pycache__/yolo_training.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SwjtuMa/FER-YOLO-Mamba/e301ceeaf53f0eec2f6552cf0a6eb09a2ce90587/nets/__pycache__/yolo_training.cpython-311.pyc -------------------------------------------------------------------------------- /nets/darknet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Copyright (c) Megvii, Inc. and its affiliates.
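# This module provides the YOLOX-style CSPDarknet backbone used by FER-YOLO-Mamba: Focus (space-to-depth stem), BaseConv (conv + BN + activation), DWConv (depthwise-separable conv), SPPBottleneck (spatial pyramid pooling) and CSPLayer (cross-stage-partial bottleneck stack), assembled into CSPDarknet whose dark3/dark4/dark5 feature maps feed the detection neck.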
4 | 5 | import torch 6 | from torch import nn 7 | 8 | class SiLU(nn.Module): 9 | @staticmethod 10 | def forward(x): 11 | return x * torch.sigmoid(x) 12 | 13 | def get_activation(name="silu", inplace=True): 14 | if name == "silu": 15 | module = SiLU() 16 | elif name == "relu": 17 | module = nn.ReLU(inplace=inplace) 18 | elif name == "lrelu": 19 | module = nn.LeakyReLU(0.1, inplace=inplace) 20 | else: 21 | raise AttributeError("Unsupported act type: {}".format(name)) 22 | return module 23 | 24 | class Focus(nn.Module): 25 | def __init__(self, in_channels, out_channels, ksize=1, stride=1, act="silu"): 26 | super().__init__() 27 | self.conv = BaseConv(in_channels * 4, out_channels, ksize, stride, act=act) 28 | 29 | def forward(self, x): 30 | patch_top_left = x[..., ::2, ::2] 31 | patch_bot_left = x[..., 1::2, ::2] 32 | patch_top_right = x[..., ::2, 1::2] 33 | patch_bot_right = x[..., 1::2, 1::2] 34 | x = torch.cat((patch_top_left, patch_bot_left, patch_top_right, patch_bot_right,), dim=1,) 35 | return self.conv(x) 36 | 37 | class BaseConv(nn.Module): 38 | def __init__(self, in_channels, out_channels, ksize, stride, groups=1, bias=False, act="silu"): 39 | super().__init__() 40 | pad = (ksize - 1) // 2 41 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=ksize, stride=stride, padding=pad, groups=groups, bias=bias) 42 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.03) 43 | self.act = get_activation(act, inplace=True) 44 | 45 | def forward(self, x): 46 | return self.act(self.bn(self.conv(x))) 47 | 48 | def fuseforward(self, x): 49 | return self.act(self.conv(x)) 50 | 51 | class DWConv(nn.Module): 52 | def __init__(self, in_channels, out_channels, ksize, stride=1, act="silu"): 53 | super().__init__() 54 | self.dconv = BaseConv(in_channels, in_channels, ksize=ksize, stride=stride, groups=in_channels, act=act,) 55 | self.pconv = BaseConv(in_channels, out_channels, ksize=1, stride=1, groups=1, act=act) 56 | 57 | def forward(self, x): 58 | x = self.dconv(x) 59 | return self.pconv(x) 60 | 61 | class SPPBottleneck(nn.Module): 62 | def __init__(self, in_channels, out_channels, kernel_sizes=(5, 9, 13), activation="silu"): 63 | super().__init__() 64 | hidden_channels = in_channels // 2 65 | self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=activation) 66 | self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2) for ks in kernel_sizes]) 67 | conv2_channels = hidden_channels * (len(kernel_sizes) + 1) 68 | self.conv2 = BaseConv(conv2_channels, out_channels, 1, stride=1, act=activation) 69 | 70 | def forward(self, x): 71 | x = self.conv1(x) 72 | x = torch.cat([x] + [m(x) for m in self.m], dim=1) 73 | x = self.conv2(x) 74 | return x 75 | 76 | class Bottleneck(nn.Module): 77 | # Standard bottleneck 78 | def __init__(self, in_channels, out_channels, shortcut=True, expansion=0.5, depthwise=False, act="silu",): 79 | super().__init__() 80 | hidden_channels = int(out_channels * expansion) 81 | Conv = DWConv if depthwise else BaseConv 82 | self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act) 83 | self.conv2 = Conv(hidden_channels, out_channels, 3, stride=1, act=act) 84 | self.use_add = shortcut and in_channels == out_channels 85 | 86 | def forward(self, x): 87 | y = self.conv2(self.conv1(x)) 88 | if self.use_add: 89 | y = y + x 90 | return y 91 | 92 | class CSPLayer(nn.Module): 93 | def __init__(self, in_channels, out_channels, n=1, shortcut=True, expansion=0.5, depthwise=False, act="silu",): 94 | # ch_in, ch_out, number, 
shortcut, groups, expansion 95 | super().__init__() 96 | hidden_channels = int(out_channels * expansion) 97 | self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act) 98 | self.conv2 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act) 99 | self.conv3 = BaseConv(2 * hidden_channels, out_channels, 1, stride=1, act=act) 100 | module_list = [Bottleneck(hidden_channels, hidden_channels, shortcut, 1.0, depthwise, act=act) for _ in range(n)] 101 | self.m = nn.Sequential(*module_list) 102 | 103 | def forward(self, x): 104 | x_1 = self.conv1(x) 105 | x_2 = self.conv2(x) 106 | x_1 = self.m(x_1) 107 | x = torch.cat((x_1, x_2), dim=1) 108 | return self.conv3(x) 109 | 110 | class CSPDarknet(nn.Module): 111 | def __init__(self, dep_mul, wid_mul, out_features=("dark3", "dark4", "dark5"), depthwise=False, act="silu",): 112 | super().__init__() 113 | assert out_features, "please provide output features of Darknet" 114 | self.out_features = out_features 115 | Conv = DWConv if depthwise else BaseConv 116 | base_channels = int(wid_mul * 64) 117 | base_depth = max(round(dep_mul * 3), 1) 118 | self.stem = Focus(3, base_channels, ksize=3, act=act) 119 | self.dark2 = nn.Sequential( 120 | Conv(base_channels, base_channels * 2, 3, 2, act=act), 121 | CSPLayer(base_channels * 2, base_channels * 2, n=base_depth, depthwise=depthwise, act=act), 122 | ) 123 | self.dark3 = nn.Sequential( 124 | Conv(base_channels * 2, base_channels * 4, 3, 2, act=act), 125 | CSPLayer(base_channels * 4, base_channels * 4, n=base_depth * 3, depthwise=depthwise, act=act), 126 | ) 127 | self.dark4 = nn.Sequential( 128 | Conv(base_channels * 4, base_channels * 8, 3, 2, act=act), 129 | CSPLayer(base_channels * 8, base_channels * 8, n=base_depth * 3, depthwise=depthwise, act=act), 130 | ) 131 | self.dark5 = nn.Sequential( 132 | Conv(base_channels * 8, base_channels * 16, 3, 2, act=act), 133 | SPPBottleneck(base_channels * 16, base_channels * 16, activation=act), 134 | CSPLayer(base_channels * 16, base_channels * 16, n=base_depth, shortcut=False, depthwise=depthwise, act=act), 135 | ) 136 | 137 | def forward(self, x): 138 | outputs = {} 139 | x = self.stem(x) 140 | outputs["stem"] = x 141 | x = self.dark2(x) 142 | outputs["dark2"] = x 143 | x = self.dark3(x) 144 | outputs["dark3"] = x 145 | x = self.dark4(x) 146 | outputs["dark4"] = x 147 | x = self.dark5(x) 148 | outputs["dark5"] = x 149 | return {k: v for k, v in outputs.items() if k in self.out_features} 150 | 151 | 152 | if __name__ == '__main__': 153 | print(CSPDarknet(1, 1)) -------------------------------------------------------------------------------- /nets/yolo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Copyright (c) Megvii, Inc. and its affiliates. 
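# This module defines the FER-YOLO-Mamba network itself: the VMamba-style selective-scan machinery (the SelectiveScan* autograd functions, CrossScan/CrossMerge over eight scan directions, cross_selective_scan and the SS2D block), the ConvSSM1 block that fuses an SS2D branch with a channel-attention convolution branch, and the YOLOX-style decoupled head (YOLOXHead) and PAFPN neck (YOLOPAFPN) built on the CSPDarknet backbone.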
4 | 5 | import torch 6 | import torch.nn as nn 7 | from functools import partial 8 | from typing import Optional, Callable, Any 9 | from .darknet import BaseConv, CSPDarknet, CSPLayer, DWConv 10 | import math 11 | from einops import rearrange, repeat 12 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 13 | import torch.nn.functional as F 14 | try: 15 | from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref 16 | except: 17 | pass 18 | 19 | try: 20 | "sscore acts the same as mamba_ssm" 21 | SSMODE = "sscore" 22 | import selective_scan_cuda_core 23 | except Exception as e: 24 | print(e, flush=True) 25 | "you should install mamba_ssm to use this" 26 | SSMODE = "mamba_ssm" 27 | import selective_scan_cuda 28 | # from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref 29 | 30 | # an alternative for mamba_ssm (in which causal_conv1d is needed) 31 | try: 32 | from selective_scan import selective_scan_fn as selective_scan_fn_v1 33 | from selective_scan import selective_scan_ref as selective_scan_ref_v1 34 | except: 35 | pass 36 | 37 | class PatchEmbed2D(nn.Module): 38 | r""" Image to Patch Embedding 39 | Args: 40 | patch_size (int): Patch token size. Default: 4. 41 | in_chans (int): Number of input image channels. Default: 3. 42 | embed_dim (int): Number of linear projection output channels. Default: 96. 43 | norm_layer (nn.Module, optional): Normalization layer. Default: None 44 | """ 45 | def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, **kwargs): 46 | super().__init__() 47 | if isinstance(patch_size, int): 48 | patch_size = (patch_size, patch_size) 49 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 50 | if norm_layer is not None: 51 | self.norm = norm_layer(embed_dim) 52 | else: 53 | self.norm = None 54 | 55 | def forward(self, x): 56 | x = self.proj(x).permute(0, 2, 3, 1) 57 | if self.norm is not None: 58 | x = self.norm(x) 59 | return x 60 | 61 | 62 | class PatchMerging2D(nn.Module): 63 | r""" Patch Merging Layer. 64 | Args: 65 | input_resolution (tuple[int]): Resolution of input feature. 66 | dim (int): Number of input channels. 67 | norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm 68 | """ 69 | 70 | def __init__(self, dim, norm_layer=nn.LayerNorm): 71 | super().__init__() 72 | self.dim = dim 73 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 74 | self.norm = norm_layer(4 * dim) 75 | 76 | def forward(self, x): 77 | B, H, W, C = x.shape 78 | 79 | SHAPE_FIX = [-1, -1] 80 | if (W % 2 != 0) or (H % 2 != 0): 81 | print(f"Warning, x.shape {x.shape} is not match even ===========", flush=True) 82 | SHAPE_FIX[0] = H // 2 83 | SHAPE_FIX[1] = W // 2 84 | 85 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C 86 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C 87 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C 88 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C 89 | 90 | if SHAPE_FIX[0] > 0: 91 | x0 = x0[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :] 92 | x1 = x1[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :] 93 | x2 = x2[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :] 94 | x3 = x3[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :] 95 | 96 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C 97 | x = x.view(B, H//2, W//2, 4 * C) # B H/2*W/2 4*C 98 | 99 | x = self.norm(x) 100 | x = self.reduction(x) 101 | 102 | return x 103 | 104 | 105 | class PatchExpand2D(nn.Module): 106 | def __init__(self, dim, dim_scale=2, norm_layer=nn.LayerNorm): 107 | super().__init__() 108 | self.dim = dim*2 109 | self.dim_scale = dim_scale 110 | self.expand = nn.Linear(self.dim, dim_scale*self.dim, bias=False) 111 | self.norm = norm_layer(self.dim // dim_scale) 112 | 113 | def forward(self, x): 114 | B, H, W, C = x.shape 115 | x = self.expand(x) 116 | 117 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//self.dim_scale) 118 | x= self.norm(x) 119 | 120 | return x 121 | 122 | 123 | class Final_PatchExpand2D(nn.Module): 124 | def __init__(self, dim, dim_scale=4, norm_layer=nn.LayerNorm): 125 | super().__init__() 126 | self.dim = dim 127 | self.dim_scale = dim_scale 128 | self.expand = nn.Linear(self.dim, dim_scale*self.dim, bias=False) 129 | self.norm = norm_layer(self.dim // dim_scale) 130 | 131 | def forward(self, x): 132 | B, H, W, C = x.shape 133 | x = self.expand(x) 134 | 135 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//self.dim_scale) 136 | x= self.norm(x) 137 | 138 | return x 139 | 140 | 141 | class SelectiveScanMamba(torch.autograd.Function): 142 | # comment all checks if inside cross_selective_scan 143 | @staticmethod 144 | @torch.cuda.amp.custom_fwd 145 | def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, nrows=1, backnrows=1, oflex=True): 146 | ctx.delta_softplus = delta_softplus 147 | out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, None, delta_bias, delta_softplus) 148 | ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) 149 | return out 150 | 151 | @staticmethod 152 | @torch.cuda.amp.custom_bwd 153 | def backward(ctx, dout, *args): 154 | u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors 155 | if dout.stride(-1) != 1: 156 | dout = dout.contiguous() 157 | 158 | du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( 159 | u, delta, A, B, C, D, None, delta_bias, dout, x, None, None, ctx.delta_softplus, 160 | False 161 | ) 162 | return (du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None, None) 163 | 164 | 165 | class SelectiveScanCore(torch.autograd.Function): 166 | @staticmethod 167 | @torch.cuda.amp.custom_fwd 168 | def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, nrows=1, backnrows=1, oflex=True): 169 | 
ctx.delta_softplus = delta_softplus 170 | if SSMODE == "mamba_ssm": 171 | out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, None, delta_bias, delta_softplus) 172 | else: 173 | out, x, *rest = selective_scan_cuda_core.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows) 174 | ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) 175 | # out, x, *rest = selective_scan_cuda_core.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1) 176 | # ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) 177 | return out 178 | @staticmethod 179 | @torch.cuda.amp.custom_bwd 180 | def backward(ctx, dout, *args): 181 | u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors 182 | if dout.stride(-1) != 1: 183 | dout = dout.contiguous() 184 | 185 | if SSMODE == "mamba_ssm": 186 | du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( 187 | u, delta, A, B, C, D, None, delta_bias, dout, x, None, None, ctx.delta_softplus, 188 | False # option to recompute out_z, not used here 189 | ) 190 | else: 191 | du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda_core.bwd( 192 | u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1 193 | ) 194 | 195 | return (du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None, None) 196 | 197 | 198 | class SelectiveScanOflex(torch.autograd.Function): 199 | # comment all checks if inside cross_selective_scan 200 | @staticmethod 201 | @torch.cuda.amp.custom_fwd 202 | def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, nrows=1, backnrows=1, oflex=True): 203 | ctx.delta_softplus = delta_softplus 204 | out, x, *rest = selective_scan_cuda_oflex.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1, oflex) 205 | ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) 206 | return out 207 | 208 | @staticmethod 209 | @torch.cuda.amp.custom_bwd 210 | def backward(ctx, dout, *args): 211 | u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors 212 | if dout.stride(-1) != 1: 213 | dout = dout.contiguous() 214 | du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda_oflex.bwd( 215 | u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1 216 | ) 217 | return (du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None, None) 218 | 219 | 220 | class SelectiveScanFake(torch.autograd.Function): 221 | # comment all checks if inside cross_selective_scan 222 | @staticmethod 223 | @torch.cuda.amp.custom_fwd 224 | def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, nrows=1, backnrows=1, oflex=True): 225 | ctx.delta_softplus = delta_softplus 226 | ctx.backnrows = backnrows 227 | x = delta 228 | out = u 229 | ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) 230 | return out 231 | 232 | @staticmethod 233 | @torch.cuda.amp.custom_bwd 234 | def backward(ctx, dout, *args): 235 | u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors 236 | if dout.stride(-1) != 1: 237 | dout = dout.contiguous() 238 | du, ddelta, dA, dB, dC, dD, ddelta_bias = u * 0, delta * 0, A * 0, B * 0, C * 0, C * 0, (D * 0 if D else None), (delta_bias * 0 if delta_bias else None) 239 | return (du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None, None) 240 | 241 | # ============= 242 | def antidiagonal_gather(tensor): 243 | B, C, H, W = tensor.size() 244 | shift = torch.arange(H, device=tensor.device).unsqueeze(1) 245 | index = (torch.arange(W, device=tensor.device) - shift) % W 246 | expanded_index = index.unsqueeze(0).unsqueeze(0).expand(B, C, -1, -1) 247 | 
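# Row i gathers columns (j - i) mod W, so after the transpose and flatten below the feature map is read along wrap-around anti-diagonals (row + col constant mod W); diagonal_gather does the same for the main diagonals, and the *_scatter functions invert these traversals when merging scan results.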
return tensor.gather(3, expanded_index).transpose(-1,-2).reshape(B, C, H*W) 248 | 249 | def diagonal_gather(tensor): 250 | B, C, H, W = tensor.size() 251 | shift = torch.arange(H, device=tensor.device).unsqueeze(1) 252 | index = (shift + torch.arange(W, device=tensor.device)) % W 253 | expanded_index = index.unsqueeze(0).unsqueeze(0).expand(B, C, -1, -1) 254 | return tensor.gather(3, expanded_index).transpose(-1,-2).reshape(B, C, H*W) 255 | 256 | def diagonal_scatter(tensor_flat, original_shape): 257 | B, C, H, W = original_shape 258 | shift = torch.arange(H, device=tensor_flat.device).unsqueeze(1) 259 | index = (shift + torch.arange(W, device=tensor_flat.device)) % W 260 | expanded_index = index.unsqueeze(0).unsqueeze(0).expand(B, C, -1, -1) 261 | result_tensor = torch.zeros(B, C, H, W, device=tensor_flat.device, dtype=tensor_flat.dtype) 262 | tensor_reshaped = tensor_flat.reshape(B, C, W, H).transpose(-1, -2) 263 | result_tensor.scatter_(3, expanded_index, tensor_reshaped) 264 | return result_tensor 265 | 266 | def antidiagonal_scatter(tensor_flat, original_shape): 267 | B, C, H, W = original_shape 268 | shift = torch.arange(H, device=tensor_flat.device).unsqueeze(1) 269 | index = (torch.arange(W, device=tensor_flat.device) - shift) % W 270 | expanded_index = index.unsqueeze(0).unsqueeze(0).expand(B, C, -1, -1) 271 | result_tensor = torch.zeros(B, C, H, W, device=tensor_flat.device, dtype=tensor_flat.dtype) 272 | tensor_reshaped = tensor_flat.reshape(B, C, W, H).transpose(-1, -2) 273 | result_tensor.scatter_(3, expanded_index, tensor_reshaped) 274 | return result_tensor 275 | 276 | class CrossScan(torch.autograd.Function): 277 | @staticmethod 278 | def forward(ctx, x: torch.Tensor): 279 | B, C, H, W = x.shape 280 | ctx.shape = (B, C, H, W) 281 | # xs = x.new_empty((B, 4, C, H * W)) 282 | xs = x.new_empty((B, 8, C, H * W)) 283 | xs[:, 0] = x.flatten(2, 3) 284 | xs[:, 1] = x.transpose(dim0=2, dim1=3).flatten(2, 3) 285 | xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1]) 286 | 287 | xs[:, 4] = diagonal_gather(x) 288 | xs[:, 5] = antidiagonal_gather(x) 289 | xs[:, 6:8] = torch.flip(xs[:, 4:6], dims=[-1]) 290 | 291 | return xs 292 | 293 | @staticmethod 294 | def backward(ctx, ys: torch.Tensor): 295 | # out: (b, k, d, l) 296 | B, C, H, W = ctx.shape 297 | L = H * W 298 | # ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L) 299 | y_rb = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L) 300 | # y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, -1, L) 301 | y_rb = y_rb[:, 0] + y_rb[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, -1, L) 302 | y_rb = y_rb.view(B, -1, H, W) 303 | 304 | y_da = ys[:, 4:6] + ys[:, 6:8].flip(dims=[-1]).view(B, 2, -1, L) 305 | y_da = diagonal_scatter(y_da[:, 0], (B,C,H,W)) + antidiagonal_scatter(y_da[:, 1], (B,C,H,W)) 306 | 307 | y_res = y_rb + y_da 308 | # return y.view(B, -1, H, W) 309 | return y_res 310 | 311 | 312 | class CrossMerge(torch.autograd.Function): 313 | @staticmethod 314 | def forward(ctx, ys: torch.Tensor): 315 | B, K, D, H, W = ys.shape 316 | ctx.shape = (H, W) 317 | ys = ys.view(B, K, D, -1) 318 | # ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1) 319 | # y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1) 320 | 321 | y_rb = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1) 322 | y_rb = y_rb[:, 0] + y_rb[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1) 323 | y_rb = y_rb.view(B, -1, 
H, W) 324 | 325 | y_da = ys[:, 4:6] + ys[:, 6:8].flip(dims=[-1]).view(B, 2, D, -1) 326 | y_da = diagonal_scatter(y_da[:, 0], (B,D,H,W)) + antidiagonal_scatter(y_da[:, 1], (B,D,H,W)) 327 | 328 | y_res = y_rb + y_da 329 | return y_res.view(B, D, -1) 330 | # return y 331 | 332 | @staticmethod 333 | def backward(ctx, x: torch.Tensor): 334 | # B, D, L = x.shape 335 | # out: (b, k, d, l) 336 | H, W = ctx.shape 337 | B, C, L = x.shape 338 | # xs = x.new_empty((B, 4, C, L)) 339 | xs = x.new_empty((B, 8, C, L)) 340 | 341 | # 横向和竖向扫描 342 | xs[:, 0] = x 343 | xs[:, 1] = x.view(B, C, H, W).transpose(dim0=2, dim1=3).flatten(2, 3) 344 | xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1]) 345 | # xs = xs.view(B, 4, C, H, W) 346 | 347 | xs[:, 4] = diagonal_gather(x.view(B,C,H,W)) 348 | xs[:, 5] = antidiagonal_gather(x.view(B,C,H,W)) 349 | xs[:, 6:8] = torch.flip(xs[:, 4:6], dims=[-1]) 350 | 351 | # return xs 352 | return xs.view(B, 8, C, H, W) 353 | 354 | 355 | # these are for ablations ============= 356 | class CrossScan_Ab_2direction(torch.autograd.Function): 357 | @staticmethod 358 | def forward(ctx, x: torch.Tensor): 359 | B, C, H, W = x.shape 360 | ctx.shape = (B, C, H, W) 361 | xs = x.new_empty((B, 4, C, H * W)) 362 | xs[:, 0] = x.flatten(2, 3) 363 | xs[:, 1] = x.flatten(2, 3) 364 | xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1]) 365 | return xs 366 | 367 | @staticmethod 368 | def backward(ctx, ys: torch.Tensor): 369 | # out: (b, k, d, l) 370 | B, C, H, W = ctx.shape 371 | L = H * W 372 | ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L) 373 | y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, -1, L) 374 | return y.view(B, -1, H, W) 375 | 376 | 377 | class CrossMerge_Ab_2direction(torch.autograd.Function): 378 | @staticmethod 379 | def forward(ctx, ys: torch.Tensor): 380 | B, K, D, H, W = ys.shape 381 | ctx.shape = (H, W) 382 | ys = ys.view(B, K, D, -1) 383 | ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1) 384 | y = ys.sum(dim=1) 385 | return y 386 | 387 | @staticmethod 388 | def backward(ctx, x: torch.Tensor): 389 | # B, D, L = x.shape 390 | # out: (b, k, d, l) 391 | H, W = ctx.shape 392 | B, C, L = x.shape 393 | xs = x.new_empty((B, 4, C, L)) 394 | xs[:, 0] = x 395 | xs[:, 1] = x 396 | xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1]) 397 | xs = xs.view(B, 4, C, H, W) 398 | return xs 399 | 400 | 401 | class CrossScan_Ab_1direction(torch.autograd.Function): 402 | @staticmethod 403 | def forward(ctx, x: torch.Tensor): 404 | B, C, H, W = x.shape 405 | ctx.shape = (B, C, H, W) 406 | xs = x.view(B, 1, C, H * W).repeat(1, 4, 1, 1).contiguous() 407 | return xs 408 | 409 | @staticmethod 410 | def backward(ctx, ys: torch.Tensor): 411 | # out: (b, k, d, l) 412 | B, C, H, W = ctx.shape 413 | y = ys.sum(dim=1).view(B, C, H, W) 414 | return y 415 | 416 | 417 | class CrossMerge_Ab_1direction(torch.autograd.Function): 418 | @staticmethod 419 | def forward(ctx, ys: torch.Tensor): 420 | B, K, D, H, W = ys.shape 421 | ctx.shape = (H, W) 422 | y = ys.sum(dim=1).view(B, D, H * W) 423 | return y 424 | 425 | @staticmethod 426 | def backward(ctx, x: torch.Tensor): 427 | # B, D, L = x.shape 428 | # out: (b, k, d, l) 429 | H, W = ctx.shape 430 | B, C, L = x.shape 431 | xs = x.view(B, 1, C, L).repeat(1, 4, 1, 1).contiguous().view(B, 4, C, H, W) 432 | return xs 433 | 434 | def cross_selective_scan( 435 | x: torch.Tensor=None, 436 | x_proj_weight: torch.Tensor=None, 437 | x_proj_bias: torch.Tensor=None, 438 | dt_projs_weight: torch.Tensor=None, 439 | dt_projs_bias: 
torch.Tensor=None, 440 | A_logs: torch.Tensor=None, 441 | Ds: torch.Tensor=None, 442 | delta_softplus = True, 443 | out_norm: torch.nn.Module=None, 444 | out_norm_shape="v0", 445 | # ============================== 446 | to_dtype=True, # True: final out to dtype 447 | force_fp32=False, # True: input fp32 448 | # ============================== 449 | nrows = -1, # for SelectiveScanNRow; 0: auto; -1: disable; 450 | backnrows = -1, # for SelectiveScanNRow; 0: auto; -1: disable; 451 | ssoflex=True, # True: out fp32 in SSOflex; else, SSOflex is the same as SSCore 452 | # ============================== 453 | SelectiveScan=None, 454 | CrossScan=CrossScan, 455 | CrossMerge=CrossMerge, 456 | ): 457 | # out_norm: whatever fits (B, L, C); LayerNorm; Sigmoid; Softmax(dim=1);... 458 | 459 | B, D, H, W = x.shape 460 | D, N = A_logs.shape 461 | K, D, R = dt_projs_weight.shape 462 | L = H * W 463 | 464 | if nrows == 0: 465 | if D % 4 == 0: 466 | nrows = 4 467 | elif D % 3 == 0: 468 | nrows = 3 469 | elif D % 2 == 0: 470 | nrows = 2 471 | else: 472 | nrows = 1 473 | 474 | if backnrows == 0: 475 | if D % 4 == 0: 476 | backnrows = 4 477 | elif D % 3 == 0: 478 | backnrows = 3 479 | elif D % 2 == 0: 480 | backnrows = 2 481 | else: 482 | backnrows = 1 483 | 484 | def selective_scan(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=True): 485 | return SelectiveScan.apply(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows, backnrows, ssoflex) 486 | 487 | xs = CrossScan.apply(x) 488 | 489 | x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, x_proj_weight) 490 | if x_proj_bias is not None: 491 | x_dbl = x_dbl + x_proj_bias.view(1, K, -1, 1) 492 | dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=2) 493 | dts = torch.einsum("b k r l, k d r -> b k d l", dts, dt_projs_weight) 494 | xs = xs.view(B, -1, L) 495 | dts = dts.contiguous().view(B, -1, L) 496 | As = -torch.exp(A_logs.to(torch.float)) # (k * c, d_state) 497 | Bs = Bs.contiguous() 498 | Cs = Cs.contiguous() 499 | Ds = Ds.to(torch.float) # (K * c) 500 | delta_bias = dt_projs_bias.view(-1).to(torch.float) 501 | 502 | if force_fp32: 503 | xs = xs.to(torch.float) 504 | dts = dts.to(torch.float) 505 | Bs = Bs.to(torch.float) 506 | Cs = Cs.to(torch.float) 507 | ys: torch.Tensor = selective_scan( 508 | xs, dts, As, Bs, Cs, Ds, delta_bias, delta_softplus 509 | ).view(B, K, -1, H, W) 510 | y: torch.Tensor = CrossMerge.apply(ys) 511 | 512 | if out_norm_shape in ["v1"]: # (B, C, H, W) 513 | y = out_norm(y.view(B, -1, H, W)).permute(0, 2, 3, 1) # (B, H, W, C) 514 | else: # (B, L, C) 515 | y = y.transpose(dim0=1, dim1=2).contiguous() # (B, L, C) 516 | y = out_norm(y).view(B, H, W, -1) 517 | 518 | return (y.to(x.dtype) if to_dtype else y) 519 | 520 | class SS2D(nn.Module): 521 | def __init__( 522 | self, 523 | # basic dims =========== 524 | d_model=64, 525 | d_state=16, 526 | ssm_ratio=2.0, 527 | dt_rank="auto", 528 | act_layer=nn.SiLU, 529 | # dwconv =============== 530 | d_conv=3, # < 2 means no conv 531 | conv_bias=True, 532 | # ====================== 533 | dropout=0.0, 534 | bias=False, 535 | # dt init ============== 536 | dt_min=0.001, 537 | dt_max=0.1, 538 | dt_init="random", 539 | dt_scale=1.0, 540 | dt_init_floor=1e-4, 541 | initialize="v0", 542 | # ====================== 543 | forward_type="v2", 544 | # ====================== 545 | **kwargs, 546 | ): 547 | factory_kwargs = {"device": None, "dtype": None} 548 | super().__init__() 549 | d_inner = int(ssm_ratio * d_model) 550 | dt_rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rank 551 | 
self.d_conv = d_conv 552 | 553 | # tags for forward_type ============================== 554 | def checkpostfix(tag, value): 555 | ret = value[-len(tag):] == tag 556 | if ret: 557 | value = value[:-len(tag)] 558 | return ret, value 559 | 560 | self.disable_force32, forward_type = checkpostfix("no32", forward_type) 561 | self.disable_z, forward_type = checkpostfix("noz", forward_type) 562 | self.disable_z_act, forward_type = checkpostfix("nozact", forward_type) 563 | 564 | # softmax | sigmoid | dwconv | norm =========================== 565 | if forward_type[-len("none"):] == "none": 566 | forward_type = forward_type[:-len("none")] 567 | self.out_norm = nn.Identity() 568 | elif forward_type[-len("dwconv3"):] == "dwconv3": 569 | forward_type = forward_type[:-len("dwconv3")] 570 | self.out_norm = nn.Conv2d(d_inner, d_inner, kernel_size=3, padding=1, groups=d_inner, bias=False) 571 | self.out_norm_shape = "v1" 572 | elif forward_type[-len("softmax"):] == "softmax": 573 | forward_type = forward_type[:-len("softmax")] 574 | self.out_norm = nn.Softmax(dim=1) 575 | elif forward_type[-len("sigmoid"):] == "sigmoid": 576 | forward_type = forward_type[:-len("sigmoid")] 577 | self.out_norm = nn.Sigmoid() 578 | else: 579 | self.out_norm = nn.LayerNorm(d_inner) 580 | 581 | # forward_type debug ======================================= 582 | FORWARD_TYPES = dict( 583 | v0=self.forward_corev0, 584 | # v2=partial(self.forward_corev2, force_fp32=(not self.disable_force32), SelectiveScan=SelectiveScanCore), 585 | v2=partial(self.forward_corev2, force_fp32=True, SelectiveScan=SelectiveScanCore), 586 | v3=partial(self.forward_corev2, force_fp32=False, SelectiveScan=SelectiveScanOflex), 587 | v31d=partial(self.forward_corev2, force_fp32=False, SelectiveScan=SelectiveScanOflex, cross_selective_scan=partial( 588 | cross_selective_scan, CrossScan=CrossScan_Ab_1direction, CrossMerge=CrossMerge_Ab_1direction, 589 | )), 590 | v32d=partial(self.forward_corev2, force_fp32=False, SelectiveScan=SelectiveScanOflex, cross_selective_scan=partial( 591 | cross_selective_scan, CrossScan=CrossScan_Ab_2direction, CrossMerge=CrossMerge_Ab_2direction, 592 | )), 593 | # =============================== 594 | fake=partial(self.forward_corev2, force_fp32=(not self.disable_force32), SelectiveScan=SelectiveScanFake), 595 | v1=partial(self.forward_corev2, force_fp32=True, SelectiveScan=SelectiveScanOflex), 596 | v01=partial(self.forward_corev2, force_fp32=(not self.disable_force32), SelectiveScan=SelectiveScanMamba), 597 | ) 598 | if forward_type.startswith("debug"): 599 | from .ss2d_ablations import SS2D_ForwardCoreSpeedAblations, SS2D_ForwardCoreModeAblations, cross_selective_scanv2 600 | FORWARD_TYPES.update(dict( 601 | debugforward_core_mambassm_seq=partial(SS2D_ForwardCoreSpeedAblations.forward_core_mambassm_seq, self), 602 | debugforward_core_mambassm=partial(SS2D_ForwardCoreSpeedAblations.forward_core_mambassm, self), 603 | debugforward_core_mambassm_fp16=partial(SS2D_ForwardCoreSpeedAblations.forward_core_mambassm_fp16, self), 604 | debugforward_core_mambassm_fusecs=partial(SS2D_ForwardCoreSpeedAblations.forward_core_mambassm_fusecs, self), 605 | debugforward_core_mambassm_fusecscm=partial(SS2D_ForwardCoreSpeedAblations.forward_core_mambassm_fusecscm, self), 606 | debugforward_core_sscore_fusecscm=partial(SS2D_ForwardCoreSpeedAblations.forward_core_sscore_fusecscm, self), 607 | debugforward_core_sscore_fusecscm_fwdnrow=partial(SS2D_ForwardCoreSpeedAblations.forward_core_ssnrow_fusecscm_fwdnrow, self), 608 | 
debugforward_core_sscore_fusecscm_bwdnrow=partial(SS2D_ForwardCoreSpeedAblations.forward_core_ssnrow_fusecscm_bwdnrow, self), 609 | debugforward_core_sscore_fusecscm_fbnrow=partial(SS2D_ForwardCoreSpeedAblations.forward_core_ssnrow_fusecscm_fbnrow, self), 610 | debugforward_core_ssoflex_fusecscm=partial(SS2D_ForwardCoreSpeedAblations.forward_core_ssoflex_fusecscm, self), 611 | debugforward_core_ssoflex_fusecscm_i16o32=partial(SS2D_ForwardCoreSpeedAblations.forward_core_ssoflex_fusecscm_i16o32, self), 612 | debugscan_sharessm=partial(self.forward_corev2, force_fp32=False, SelectiveScan=SelectiveScanOflex, cross_selective_scan=cross_selective_scanv2), 613 | )) 614 | self.forward_core = FORWARD_TYPES.get(forward_type, None) 615 | k_group = 8 if forward_type not in ["debugscan_sharessm"] else 1 616 | 617 | # in proj ======================================= 618 | d_proj = d_inner if self.disable_z else (d_inner * 2) 619 | self.in_proj = nn.Linear(d_model, d_proj, bias=bias, **factory_kwargs) 620 | self.act: nn.Module = act_layer() 621 | 622 | # conv ======================================= 623 | if d_conv > 1: 624 | self.conv2d = nn.Conv2d( 625 | in_channels=d_inner, 626 | out_channels=d_inner, 627 | groups=d_inner, 628 | bias=conv_bias, 629 | kernel_size=d_conv, 630 | padding=(d_conv - 1) // 2, 631 | **factory_kwargs, 632 | ) 633 | 634 | # x proj ============================ 635 | self.x_proj = [ 636 | nn.Linear(d_inner, (dt_rank + d_state * 2), bias=False, **factory_kwargs) 637 | for _ in range(k_group) 638 | ] 639 | self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K, N, inner) 640 | del self.x_proj 641 | 642 | # out proj ======================================= 643 | self.out_proj = nn.Linear(d_inner, d_model, bias=bias, **factory_kwargs) 644 | self.dropout = nn.Dropout(dropout) if dropout > 0. 
else nn.Identity() 645 | 646 | if initialize in ["v0"]: 647 | # dt proj ============================ 648 | self.dt_projs = [ 649 | self.dt_init(dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs) 650 | for _ in range(k_group) 651 | ] 652 | self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0)) # (K, inner, rank) 653 | self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0)) # (K, inner) 654 | del self.dt_projs 655 | 656 | # A, D ======================================= 657 | self.A_logs = self.A_log_init(d_state, d_inner, copies=k_group, merge=True) # (K * D, N) 658 | self.Ds = self.D_init(d_inner, copies=k_group, merge=True) # (K * D) 659 | elif initialize in ["v1"]: 660 | # simple init dt_projs, A_logs, Ds 661 | self.Ds = nn.Parameter(torch.ones((k_group * d_inner))) 662 | self.A_logs = nn.Parameter(torch.randn((k_group * d_inner, d_state))) # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1 663 | self.dt_projs_weight = nn.Parameter(torch.randn((k_group, d_inner, dt_rank))) 664 | self.dt_projs_bias = nn.Parameter(torch.randn((k_group, d_inner))) 665 | elif initialize in ["v2"]: 666 | # simple init dt_projs, A_logs, Ds 667 | self.Ds = nn.Parameter(torch.ones((k_group * d_inner))) 668 | self.A_logs = nn.Parameter(torch.zeros((k_group * d_inner, d_state))) # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1 669 | self.dt_projs_weight = nn.Parameter(torch.randn((k_group, d_inner, dt_rank))) 670 | self.dt_projs_bias = nn.Parameter(torch.randn((k_group, d_inner))) 671 | 672 | @staticmethod 673 | def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4, **factory_kwargs): 674 | dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs) 675 | 676 | # Initialize special dt projection to preserve variance at initialization 677 | dt_init_std = dt_rank**-0.5 * dt_scale 678 | if dt_init == "constant": 679 | nn.init.constant_(dt_proj.weight, dt_init_std) 680 | elif dt_init == "random": 681 | nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std) 682 | else: 683 | raise NotImplementedError 684 | 685 | # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max 686 | dt = torch.exp( 687 | torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) 688 | + math.log(dt_min) 689 | ).clamp(min=dt_init_floor) 690 | # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 691 | inv_dt = dt + torch.log(-torch.expm1(-dt)) 692 | with torch.no_grad(): 693 | dt_proj.bias.copy_(inv_dt) 694 | # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit 695 | # dt_proj.bias._no_reinit = True 696 | 697 | return dt_proj 698 | 699 | @staticmethod 700 | def A_log_init(d_state, d_inner, copies=-1, device=None, merge=True): 701 | # S4D real initialization 702 | A = repeat( 703 | torch.arange(1, d_state + 1, dtype=torch.float32, device=device), 704 | "n -> d n", 705 | d=d_inner, 706 | ).contiguous() 707 | A_log = torch.log(A) # Keep A_log in fp32 708 | if copies > 0: 709 | A_log = repeat(A_log, "d n -> r d n", r=copies) 710 | if merge: 711 | A_log = A_log.flatten(0, 1) 712 | A_log = nn.Parameter(A_log) 713 | A_log._no_weight_decay = True 714 | return A_log 715 | 716 | @staticmethod 717 | def D_init(d_inner, copies=-1, device=None, merge=True): 718 | # D "skip" parameter 719 | D = torch.ones(d_inner, device=device) 720 | if copies > 0: 721 | D = repeat(D, "n1 -> r n1", r=copies) 722 | if 
merge: 723 | D = D.flatten(0, 1) 724 | D = nn.Parameter(D) # Keep in fp32 725 | D._no_weight_decay = True 726 | return D 727 | 728 | # only used to run previous version 729 | def forward_corev0(self, x: torch.Tensor, to_dtype=False, channel_first=False): 730 | def selective_scan(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=True, nrows=1): 731 | return SelectiveScanCore.apply(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows, False) 732 | 733 | if not channel_first: 734 | x = x.permute(0, 3, 1, 2).contiguous() 735 | B, D, H, W = x.shape 736 | D, N = self.A_logs.shape 737 | K, D, R = self.dt_projs_weight.shape 738 | L = H * W 739 | 740 | x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L) 741 | xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l) 742 | 743 | x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, self.x_proj_weight) 744 | # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1) 745 | dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=2) 746 | dts = torch.einsum("b k r l, k d r -> b k d l", dts, self.dt_projs_weight) 747 | 748 | xs = xs.float().view(B, -1, L) # (b, k * d, l) 749 | dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l) 750 | Bs = Bs.float() # (b, k, d_state, l) 751 | Cs = Cs.float() # (b, k, d_state, l) 752 | 753 | As = -torch.exp(self.A_logs.float()) # (k * d, d_state) 754 | Ds = self.Ds.float() # (k * d) 755 | dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d) 756 | out_y = selective_scan( 757 | xs, dts, 758 | As, Bs, Cs, Ds, 759 | delta_bias=dt_projs_bias, 760 | delta_softplus=True, 761 | ).view(B, K, -1, L) 762 | # assert out_y.dtype == torch.float 763 | 764 | inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L) 765 | wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) 766 | invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) 767 | y = out_y[:, 0] + inv_y[:, 0] + wh_y + invwh_y 768 | y = y.transpose(dim0=1, dim1=2).contiguous() # (B, L, C) 769 | y = self.out_norm(y).view(B, H, W, -1) 770 | 771 | return (y.to(x.dtype) if to_dtype else y) 772 | 773 | def forward_corev2(self, x: torch.Tensor, channel_first=False, SelectiveScan=SelectiveScanOflex, cross_selective_scan=cross_selective_scan, force_fp32=None): 774 | if not channel_first: 775 | x = x.permute(0, 3, 1, 2).contiguous() 776 | x = cross_selective_scan( 777 | x, self.x_proj_weight, None, self.dt_projs_weight, self.dt_projs_bias, 778 | self.A_logs, self.Ds, delta_softplus=True, 779 | out_norm=getattr(self, "out_norm", None), 780 | out_norm_shape=getattr(self, "out_norm_shape", "v0"), 781 | force_fp32=force_fp32, 782 | SelectiveScan=SelectiveScan, 783 | ) 784 | return x 785 | 786 | def forward(self, x: torch.Tensor, **kwargs): 787 | with_dconv = (self.d_conv > 1) 788 | x = self.in_proj(x) 789 | if not self.disable_z: 790 | x, z = x.chunk(2, dim=-1) # (b, h, w, d) 791 | if not self.disable_z_act: 792 | z = self.act(z) 793 | if with_dconv: 794 | x = x.permute(0, 3, 1, 2).contiguous() 795 | x = self.conv2d(x) # (b, d, h, w) 796 | x = self.act(x) 797 | y = self.forward_core(x, channel_first=with_dconv) 798 | if not self.disable_z: 799 | y = y * z 800 | out = self.dropout(self.out_proj(y)) 801 | return out 802 | 803 | class AttentionBlockWithMLP(nn.Module): 804 | def __init__(self, in_channels, reduction=16): 805 | super(AttentionBlockWithMLP, self).__init__() 806 | 
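# Squeeze-and-excitation-style channel attention: global average pooling gives a per-channel descriptor, a three-layer bottleneck MLP (reduction sets the hidden width) maps it through a sigmoid to per-channel weights, and forward() rescales the input feature map channel-wise with these weights.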
self.avg_pool = nn.AdaptiveAvgPool2d(1) 807 | self.fc = nn.Sequential( 808 | nn.Linear(in_channels, in_channels // reduction, bias=False), 809 | nn.ReLU(inplace=True), 810 | nn.Linear(in_channels // reduction, in_channels // reduction, bias=False), 811 | nn.ReLU(inplace=True), 812 | nn.Linear(in_channels // reduction, in_channels, bias=False), 813 | nn.Sigmoid() 814 | ) 815 | def forward(self, x): 816 | b, c, _, _ = x.size() 817 | y = self.avg_pool(x).view(b, c) 818 | y = self.fc(y).view(b, c, 1, 1) 819 | return x * y.expand_as(x) 820 | 821 | 822 | class ConvSSM1(nn.Module): 823 | def __init__( 824 | self, 825 | hidden_dim: int = 0, 826 | drop_path: float = 0, 827 | norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), 828 | attn_drop_rate: float = 0, 829 | d_state: int = 16, 830 | **kwargs, 831 | ): 832 | super().__init__() 833 | self.ln_1 = norm_layer(hidden_dim//2) 834 | self.self_attention = SS2D(d_model=hidden_dim//2, dropout=attn_drop_rate, d_state=d_state, **kwargs) 835 | self.drop_path = DropPath(drop_path) 836 | self.frm = nn.Sequential( 837 | nn.Conv2d(in_channels=hidden_dim//2,out_channels=hidden_dim//4, kernel_size=1, bias=False), 838 | nn.BatchNorm2d(hidden_dim//4) , 839 | nn.ReLU(inplace=True) , 840 | AttentionBlockWithMLP(hidden_dim//4) , 841 | nn.Conv2d(hidden_dim//4, hidden_dim//2, kernel_size=1, bias=False) , 842 | nn.BatchNorm2d(hidden_dim//2) 843 | ) 844 | self.finalconv11 = nn.Conv2d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=1, stride=1) 845 | def forward(self, input: torch.Tensor): 846 | input_left, input_right = input.chunk(2,dim=-1) 847 | x = input_right + self.drop_path(self.self_attention(self.ln_1(input_right))) 848 | input_left = input_left.permute(0,3,1,2).contiguous() 849 | input_left = self.frm(input_left) 850 | x = x.permute(0,3,1,2).contiguous() 851 | output = torch.cat((input_left,x),dim=1) 852 | output = self.finalconv11(output).permute(0,2,3,1).contiguous() 853 | return output+input 854 | 855 | class YOLOXHead(nn.Module): 856 | def __init__(self, num_classes, width = 1.0, in_channels = [256, 512, 1024], act = "silu", depthwise = False,): 857 | super().__init__() 858 | Conv = DWConv if depthwise else BaseConv 859 | 860 | self.cls_convs = nn.ModuleList() 861 | self.reg_convs = nn.ModuleList() 862 | self.cls_preds = nn.ModuleList() 863 | self.reg_preds = nn.ModuleList() 864 | self.obj_preds = nn.ModuleList() 865 | self.stems = nn.ModuleList() 866 | 867 | for i in range(len(in_channels)): 868 | self.stems.append(BaseConv(in_channels = int(in_channels[i] * width), out_channels = int(256 * width), ksize = 1, stride = 1, act = act)) 869 | self.cls_convs.append(nn.Sequential(*[ 870 | Conv(in_channels = int(256 * width), out_channels = int(256 * width), ksize = 3, stride = 1, act = act), 871 | Conv(in_channels = int(256 * width), out_channels = int(256 * width), ksize = 3, stride = 1, act = act), 872 | ])) 873 | self.cls_preds.append( 874 | nn.Conv2d(in_channels = int(256 * width), out_channels = num_classes, kernel_size = 1, stride = 1, padding = 0) 875 | ) 876 | 877 | 878 | self.reg_convs.append(nn.Sequential(*[ 879 | Conv(in_channels = int(256 * width), out_channels = int(256 * width), ksize = 3, stride = 1, act = act), 880 | Conv(in_channels = int(256 * width), out_channels = int(256 * width), ksize = 3, stride = 1, act = act) 881 | ])) 882 | self.reg_preds.append( 883 | nn.Conv2d(in_channels = int(256 * width), out_channels = 4, kernel_size = 1, stride = 1, padding = 0) 884 | ) 885 | self.obj_preds.append( 886 | 
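# ---------------------------------------------------------------------------
# Illustrative standalone sketch (assumes only torch; hypothetical sizes):
# AttentionBlockWithMLP above is a squeeze-and-excitation style channel gate -
# global average pooling squeezes each channel to a scalar, the MLP plus
# sigmoid produces a per-channel weight in (0, 1), and the feature map is
# rescaled channel-wise by that weight.
import torch
import torch.nn as nn

channels, reduction = 64, 16
gate = nn.Sequential(
    nn.Linear(channels, channels // reduction), nn.ReLU(inplace=True),
    nn.Linear(channels // reduction, channels // reduction), nn.ReLU(inplace=True),
    nn.Linear(channels // reduction, channels), nn.Sigmoid(),
)
feat = torch.randn(2, channels, 20, 20)
squeezed = feat.mean(dim=(2, 3))                   # global average pool -> (B, C)
weights = gate(squeezed).view(2, channels, 1, 1)   # per-channel weights in (0, 1)
out = feat * weights                               # same shape, re-weighted channels
assert out.shape == feat.shape
# ---------------------------------------------------------------------------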
nn.Conv2d(in_channels = int(256 * width), out_channels = 1, kernel_size = 1, stride = 1, padding = 0) 887 | ) 888 | 889 | def forward(self, inputs): 890 | outputs = [] 891 | for k, x in enumerate(inputs): 892 | cls_feat = self.cls_convs[k](x) 893 | cls_output = self.cls_preds[k](cls_feat) 894 | reg_feat = self.reg_convs[k](x) 895 | reg_output = self.reg_preds[k](reg_feat) 896 | obj_output = self.obj_preds[k](reg_feat) 897 | 898 | output = torch.cat([reg_output, obj_output, cls_output], 1) 899 | outputs.append(output) 900 | return outputs 901 | 902 | class YOLOPAFPN(nn.Module): 903 | def __init__(self, depth = 1.0, width = 1.0, in_features = ("dark3", "dark4", "dark5"), in_channels = [256, 512, 1024], depthwise = False, act = "silu"): 904 | super().__init__() 905 | Conv = DWConv if depthwise else BaseConv 906 | self.backbone = CSPDarknet(depth, width, depthwise = depthwise, act = act) 907 | self.in_features = in_features 908 | self.upsample = nn.Upsample(scale_factor=2, mode="nearest") 909 | norm_layer=nn.LayerNorm 910 | self.lateral_conv0 = BaseConv(int(in_channels[2] * width), int(in_channels[1] * width), 1, 1, act=act) 911 | self.C3_p4 = nn.ModuleList([ConvSSM1( 912 | hidden_dim=int(512), 913 | drop_path=0., 914 | norm_layer=norm_layer, 915 | attn_drop_rate=0., 916 | d_state=16, 917 | )]) 918 | self.reduce_conv1 = BaseConv(512,128, 1, 1, act=act) 919 | self.C3_p3 = nn.ModuleList([ConvSSM1( 920 | hidden_dim=int(256), 921 | drop_path=0., 922 | norm_layer=norm_layer, 923 | attn_drop_rate=0., 924 | d_state=16, 925 | )]) 926 | self.bu_conv2 = Conv(256,128, 3, 2, act=act) 927 | self.reduce_conv111 = Conv(256,128, 3, 2, act=act) 928 | self.C3_n3 = nn.ModuleList([ConvSSM1( 929 | hidden_dim=int(256), 930 | drop_path=0., 931 | norm_layer=norm_layer, 932 | attn_drop_rate=0., 933 | d_state=16, 934 | )]) 935 | self.bu_conv1 = Conv(256,256, 3, 2, act=act) 936 | self.C3_n4 = nn.ModuleList([ConvSSM1( 937 | hidden_dim=int(512), 938 | drop_path=0., 939 | norm_layer=norm_layer, 940 | attn_drop_rate=0., 941 | d_state=16, 942 | )]) 943 | def forward(self, input): 944 | out_features = self.backbone.forward(input) 945 | [feat1, feat2, feat3] = [out_features[f] for f in self.in_features] #torch.Size([8, 128, 40, 40]) torch.Size([8, 256, 20, 20]) torch.Size([8, 512, 10, 10]) 946 | #-------------------------------------------# 947 | P5 = self.lateral_conv0(feat3) # torch.Size([2, 256, 10, 10]) 948 | #-------------------------------------------# 949 | P5_upsample = self.upsample(P5) # torch.Size([2, 256, 20, 20]) 950 | #-------------------------------------------# 951 | P5_upsample = torch.cat([P5_upsample, feat2], 1) # torch.Size([2, 512, 20, 20]) 952 | P5_upsample=P5_upsample.permute(0,2,3,1) 953 | 954 | for blk in self.C3_p4: 955 | P5_upsample = blk(P5_upsample) # torch.Size([8, 20, 20, 256]) 956 | #-------------------------------------------# 957 | P5_upsample=P5_upsample.permute(0,3,1,2) 958 | P4 = self.reduce_conv1(P5_upsample) # torch.Size([8, 128, 20, 20]) 959 | P4_upsample = self.upsample(P4) # torch.Size([8, 128, 40, 40]) 960 | #-------------------------------------------# 961 | P4_upsample = torch.cat([P4_upsample, feat1], 1) # 962 | P4_upsample=P4_upsample.permute(0,2,3,1) # torch.Size([8, 40, 40, 256]) 963 | for blk3 in self.C3_p3: 964 | P3_out = blk3(P4_upsample) # torch.Size([8, 40, 40, 128]) 965 | P3_out1=P3_out.permute(0,3,1,2) # 966 | P3_out=self.reduce_conv111(P3_out1) 967 | P3_downsample = self.bu_conv2(P3_out1) # torch.Size([8, 128, 20, 20]) 968 | #-------------------------------------------# 
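# ---------------------------------------------------------------------------
# Illustrative standalone sketch (assumes a 320x320 input as configured in
# train.py and seven SFEW expression classes; the numbers are assumptions,
# not taken from the original file): each head level above concatenates
# reg(4) + obj(1) + cls(num_classes) channels, in that order, at strides
# 8/16/32 of the input resolution.
num_classes = 7
input_size = 320
for stride in (8, 16, 32):
    h = w = input_size // stride
    channels = 4 + 1 + num_classes
    print(f"stride {stride:2d}: output shape (B, {channels}, {h}, {w})")
# stride  8: output shape (B, 12, 40, 40)
# stride 16: output shape (B, 12, 20, 20)
# stride 32: output shape (B, 12, 10, 10)
# ---------------------------------------------------------------------------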
969 | P3_downsample = torch.cat([P3_downsample, P4], 1) # torch.Size([8, 256, 20, 20])) 970 | P3_downsample=P3_downsample.permute(0,2,3,1) # 971 | for blk4 in self.C3_n3: 972 | P4_out = blk4(P3_downsample) 973 | #-------------------------------------------# 974 | P4_out=P4_out.permute(0,3,1,2) #torch.Size([8, 256, 20, 20]) 975 | P4_downsample = self.bu_conv1(P4_out) # torch.Size([8, 256, 10, 10]) 976 | P4_downsample = torch.cat([P4_downsample, P5], 1) # torch.Size([8, 512, 10, 10]) 977 | P4_downsample=P4_downsample.permute(0,2,3,1) 978 | for blk5 in self.C3_n4: 979 | P5_out = blk5(P4_downsample) 980 | P5_out=P5_out.permute(0,3,1,2) #torch.Size([8, 512, 10, 10]) 981 | return (P3_out, P4_out, P5_out) 982 | 983 | class YoloBody(nn.Module): 984 | def __init__(self, num_classes, phi): 985 | super().__init__() 986 | depth_dict = {'nano': 0.33, 'tiny': 0.33, 's' : 0.33, 'm' : 0.67, 'l' : 1.00, 'x' : 1.33,} 987 | width_dict = {'nano': 0.25, 'tiny': 0.375, 's' : 0.50, 'm' : 0.75, 'l' : 1.00, 'x' : 1.25,} 988 | depth, width = depth_dict[phi], width_dict[phi] 989 | depthwise = True if phi == 'nano' else False 990 | 991 | self.backbone = YOLOPAFPN(depth, width, depthwise=depthwise) #.to(device) 992 | self.head = YOLOXHead(num_classes, width, depthwise=depthwise)#.to(device) 993 | 994 | def forward(self, x): 995 | fpn_outs = self.backbone.forward(x) 996 | outputs = self.head.forward(fpn_outs) 997 | return outputs 998 | -------------------------------------------------------------------------------- /nets/yolo_training.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Copyright (c) Megvii, Inc. and its affiliates. 4 | import math 5 | from copy import deepcopy 6 | from functools import partial 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | class IOUloss(nn.Module): 13 | def __init__(self, reduction="none", loss_type="iou"): 14 | super(IOUloss, self).__init__() 15 | self.reduction = reduction 16 | self.loss_type = loss_type 17 | 18 | def forward(self, pred, target): 19 | assert pred.shape[0] == target.shape[0] 20 | 21 | pred = pred.view(-1, 4) 22 | target = target.view(-1, 4) 23 | tl = torch.max( 24 | (pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2) 25 | ) 26 | br = torch.min( 27 | (pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2) 28 | ) 29 | 30 | area_p = torch.prod(pred[:, 2:], 1) 31 | area_g = torch.prod(target[:, 2:], 1) 32 | 33 | en = (tl < br).type(tl.type()).prod(dim=1) 34 | area_i = torch.prod(br - tl, 1) * en 35 | area_u = area_p + area_g - area_i 36 | iou = (area_i) / (area_u + 1e-16) 37 | 38 | if self.loss_type == "iou": 39 | loss = 1 - iou ** 2 40 | elif self.loss_type == "giou": 41 | c_tl = torch.min( 42 | (pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2) 43 | ) 44 | c_br = torch.max( 45 | (pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2) 46 | ) 47 | area_c = torch.prod(c_br - c_tl, 1) 48 | giou = iou - (area_c - area_u) / area_c.clamp(1e-16) 49 | loss = 1 - giou.clamp(min=-1.0, max=1.0) 50 | 51 | if self.reduction == "mean": 52 | loss = loss.mean() 53 | elif self.reduction == "sum": 54 | loss = loss.sum() 55 | 56 | return loss 57 | 58 | class YOLOLoss(nn.Module): 59 | def __init__(self, num_classes, fp16, strides=[8, 16, 32]): 60 | super().__init__() 61 | self.num_classes = num_classes 62 | self.strides = strides 63 | 64 | self.bcewithlog_loss = 
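# ---------------------------------------------------------------------------
# Illustrative standalone sketch (assumes only torch; toy boxes are made up):
# IOUloss above expects boxes in (cx, cy, w, h) form, intersects them via
# their corners, and returns 1 - IoU**2 (or 1 - GIoU). For two unit squares
# offset by half a box width:
import torch

pred   = torch.tensor([[0.5, 0.5, 1.0, 1.0]])    # cx, cy, w, h
target = torch.tensor([[1.0, 0.5, 1.0, 1.0]])
tl = torch.max(pred[:, :2] - pred[:, 2:] / 2, target[:, :2] - target[:, 2:] / 2)
br = torch.min(pred[:, :2] + pred[:, 2:] / 2, target[:, :2] + target[:, 2:] / 2)
inter = torch.prod((br - tl).clamp(min=0), 1)
union = pred[:, 2:].prod(1) + target[:, 2:].prod(1) - inter
iou = inter / (union + 1e-16)
print(iou.item(), (1 - iou ** 2).item())          # ~0.3333, ~0.8889
# ---------------------------------------------------------------------------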
nn.BCEWithLogitsLoss(reduction="none") 65 | self.iou_loss = IOUloss(reduction="none") 66 | self.grids = [torch.zeros(1)] * len(strides) 67 | self.fp16 = fp16 68 | 69 | def forward(self, inputs, labels=None): 70 | outputs = [] 71 | x_shifts = [] 72 | y_shifts = [] 73 | expanded_strides = [] 74 | 75 | for k, (stride, output) in enumerate(zip(self.strides, inputs)): 76 | output, grid = self.get_output_and_grid(output, k, stride) 77 | x_shifts.append(grid[:, :, 0]) 78 | y_shifts.append(grid[:, :, 1]) 79 | expanded_strides.append(torch.ones_like(grid[:, :, 0]) * stride) 80 | outputs.append(output) 81 | 82 | return self.get_losses(x_shifts, y_shifts, expanded_strides, labels, torch.cat(outputs, 1)) 83 | 84 | def get_output_and_grid(self, output, k, stride): 85 | grid = self.grids[k] 86 | hsize, wsize = output.shape[-2:] 87 | if grid.shape[2:4] != output.shape[2:4]: 88 | yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)]) 89 | grid = torch.stack((xv, yv), 2).view(1, hsize, wsize, 2).type(output.type()) 90 | self.grids[k] = grid 91 | grid = grid.view(1, -1, 2) 92 | 93 | output = output.flatten(start_dim=2).permute(0, 2, 1) 94 | output[..., :2] = (output[..., :2] + grid.type_as(output)) * stride 95 | output[..., 2:4] = torch.exp(output[..., 2:4]) * stride 96 | return output, grid 97 | 98 | def get_losses(self, x_shifts, y_shifts, expanded_strides, labels, outputs): 99 | bbox_preds = outputs[:, :, :4] 100 | obj_preds = outputs[:, :, 4:5] 101 | cls_preds = outputs[:, :, 5:] 102 | 103 | total_num_anchors = outputs.shape[1] 104 | x_shifts = torch.cat(x_shifts, 1).type_as(outputs) 105 | y_shifts = torch.cat(y_shifts, 1).type_as(outputs) 106 | expanded_strides = torch.cat(expanded_strides, 1).type_as(outputs) 107 | 108 | cls_targets = [] 109 | reg_targets = [] 110 | obj_targets = [] 111 | fg_masks = [] 112 | 113 | num_fg = 0.0 114 | for batch_idx in range(outputs.shape[0]): 115 | num_gt = len(labels[batch_idx]) 116 | if num_gt == 0: 117 | cls_target = outputs.new_zeros((0, self.num_classes)) 118 | reg_target = outputs.new_zeros((0, 4)) 119 | obj_target = outputs.new_zeros((total_num_anchors, 1)) 120 | fg_mask = outputs.new_zeros(total_num_anchors).bool() 121 | else: 122 | gt_bboxes_per_image = labels[batch_idx][..., :4].type_as(outputs) 123 | gt_classes = labels[batch_idx][..., 4].type_as(outputs) 124 | bboxes_preds_per_image = bbox_preds[batch_idx] 125 | cls_preds_per_image = cls_preds[batch_idx] 126 | obj_preds_per_image = obj_preds[batch_idx] 127 | 128 | gt_matched_classes, fg_mask, pred_ious_this_matching, matched_gt_inds, num_fg_img = self.get_assignments( 129 | num_gt, total_num_anchors, gt_bboxes_per_image, gt_classes, bboxes_preds_per_image, cls_preds_per_image, obj_preds_per_image, 130 | expanded_strides, x_shifts, y_shifts, 131 | ) 132 | torch.cuda.empty_cache() 133 | num_fg += num_fg_img 134 | cls_target = F.one_hot(gt_matched_classes.to(torch.int64), self.num_classes).float() * pred_ious_this_matching.unsqueeze(-1) 135 | obj_target = fg_mask.unsqueeze(-1) 136 | reg_target = gt_bboxes_per_image[matched_gt_inds] 137 | cls_targets.append(cls_target) 138 | reg_targets.append(reg_target) 139 | obj_targets.append(obj_target.type(cls_target.type())) 140 | fg_masks.append(fg_mask) 141 | 142 | cls_targets = torch.cat(cls_targets, 0) 143 | reg_targets = torch.cat(reg_targets, 0) 144 | obj_targets = torch.cat(obj_targets, 0) 145 | fg_masks = torch.cat(fg_masks, 0) 146 | 147 | num_fg = max(num_fg, 1) 148 | loss_iou = (self.iou_loss(bbox_preds.view(-1, 4)[fg_masks], reg_targets)).sum() 149 
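# ---------------------------------------------------------------------------
# Illustrative standalone sketch (assumes only math; cell indices are made
# up): in get_output_and_grid above, raw (x, y) offsets are added to the grid
# coordinates and scaled by the stride, while (w, h) are exponentiated, so a
# zero prediction at grid cell (3, 2) on the stride-8 level decodes to a box
# centred on that cell with size equal to one stride.
import math

stride, gx, gy = 8, 3, 2
raw = [0.0, 0.0, 0.0, 0.0]              # x, y, w, h straight from the conv
cx = (raw[0] + gx) * stride             # 24.0
cy = (raw[1] + gy) * stride             # 16.0
w  = math.exp(raw[2]) * stride          # 8.0
h  = math.exp(raw[3]) * stride          # 8.0
print(cx, cy, w, h)                     # 24.0 16.0 8.0 8.0
# ---------------------------------------------------------------------------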
| loss_obj = (self.bcewithlog_loss(obj_preds.view(-1, 1), obj_targets)).sum() 150 | loss_cls = (self.bcewithlog_loss(cls_preds.view(-1, self.num_classes)[fg_masks], cls_targets)).sum() 151 | reg_weight = 5.0 152 | loss = reg_weight * loss_iou + loss_obj + loss_cls 153 | 154 | return loss / num_fg 155 | 156 | @torch.no_grad() 157 | def get_assignments(self, num_gt, total_num_anchors, gt_bboxes_per_image, gt_classes, bboxes_preds_per_image, cls_preds_per_image, obj_preds_per_image, expanded_strides, x_shifts, y_shifts): 158 | fg_mask, is_in_boxes_and_center = self.get_in_boxes_info(gt_bboxes_per_image, expanded_strides, x_shifts, y_shifts, total_num_anchors, num_gt) 159 | bboxes_preds_per_image = bboxes_preds_per_image[fg_mask] 160 | cls_preds_ = cls_preds_per_image[fg_mask] 161 | obj_preds_ = obj_preds_per_image[fg_mask] 162 | num_in_boxes_anchor = bboxes_preds_per_image.shape[0] 163 | pair_wise_ious = self.bboxes_iou(gt_bboxes_per_image, bboxes_preds_per_image, False) 164 | pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8) 165 | if self.fp16: 166 | with torch.cuda.amp.autocast(enabled=False): 167 | cls_preds_ = cls_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_() * obj_preds_.unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_() 168 | gt_cls_per_image = F.one_hot(gt_classes.to(torch.int64), self.num_classes).float().unsqueeze(1).repeat(1, num_in_boxes_anchor, 1) 169 | pair_wise_cls_loss = F.binary_cross_entropy(cls_preds_.sqrt_(), gt_cls_per_image, reduction="none").sum(-1) 170 | else: 171 | cls_preds_ = cls_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_() * obj_preds_.unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_() 172 | gt_cls_per_image = F.one_hot(gt_classes.to(torch.int64), self.num_classes).float().unsqueeze(1).repeat(1, num_in_boxes_anchor, 1) 173 | pair_wise_cls_loss = F.binary_cross_entropy(cls_preds_.sqrt_(), gt_cls_per_image, reduction="none").sum(-1) 174 | del cls_preds_ 175 | 176 | cost = pair_wise_cls_loss + 3.0 * pair_wise_ious_loss + 100000.0 * (~is_in_boxes_and_center).float() 177 | 178 | num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds = self.dynamic_k_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask) 179 | del pair_wise_cls_loss, cost, pair_wise_ious, pair_wise_ious_loss 180 | return gt_matched_classes, fg_mask, pred_ious_this_matching, matched_gt_inds, num_fg 181 | 182 | def bboxes_iou(self, bboxes_a, bboxes_b, xyxy=True): 183 | if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4: 184 | raise IndexError 185 | 186 | if xyxy: 187 | tl = torch.max(bboxes_a[:, None, :2], bboxes_b[:, :2]) 188 | br = torch.min(bboxes_a[:, None, 2:], bboxes_b[:, 2:]) 189 | area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1) 190 | area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1) 191 | else: 192 | tl = torch.max( 193 | (bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2), 194 | (bboxes_b[:, :2] - bboxes_b[:, 2:] / 2), 195 | ) 196 | br = torch.min( 197 | (bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2), 198 | (bboxes_b[:, :2] + bboxes_b[:, 2:] / 2), 199 | ) 200 | 201 | area_a = torch.prod(bboxes_a[:, 2:], 1) 202 | area_b = torch.prod(bboxes_b[:, 2:], 1) 203 | en = (tl < br).type(tl.type()).prod(dim=2) 204 | area_i = torch.prod(br - tl, 2) * en 205 | return area_i / (area_a[:, None] + area_b - area_i) 206 | 207 | def get_in_boxes_info(self, gt_bboxes_per_image, expanded_strides, x_shifts, y_shifts, total_num_anchors, num_gt, center_radius = 2.5): 208 | expanded_strides_per_image = expanded_strides[0] 209 | x_centers_per_image = 
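# ---------------------------------------------------------------------------
# Illustrative standalone sketch (assumes only torch; the toy losses/IoUs are
# made up): the SimOTA assignment cost above combines the pairwise
# classification loss, three times the -log(IoU) loss, and a huge penalty for
# anchors that are not inside both the GT box and its centre region, so those
# anchors are effectively priced out of the top-k matching.
import torch

pair_wise_cls_loss = torch.tensor([[0.2, 0.5, 0.3]])        # 1 GT x 3 candidates
pair_wise_ious     = torch.tensor([[0.8, 0.6, 0.7]])
is_in_boxes_and_center = torch.tensor([[True, True, False]])
cost = (pair_wise_cls_loss
        + 3.0 * (-torch.log(pair_wise_ious + 1e-8))
        + 100000.0 * (~is_in_boxes_and_center).float())
print(cost)                             # ~[[0.87, 2.03, 100001.37]]
# ---------------------------------------------------------------------------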
((x_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0).repeat(num_gt, 1) 210 | y_centers_per_image = ((y_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0).repeat(num_gt, 1) 211 | gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0] - 0.5 * gt_bboxes_per_image[:, 2]).unsqueeze(1).repeat(1, total_num_anchors) 212 | gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0] + 0.5 * gt_bboxes_per_image[:, 2]).unsqueeze(1).repeat(1, total_num_anchors) 213 | gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1] - 0.5 * gt_bboxes_per_image[:, 3]).unsqueeze(1).repeat(1, total_num_anchors) 214 | gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1] + 0.5 * gt_bboxes_per_image[:, 3]).unsqueeze(1).repeat(1, total_num_anchors) 215 | b_l = x_centers_per_image - gt_bboxes_per_image_l 216 | b_r = gt_bboxes_per_image_r - x_centers_per_image 217 | b_t = y_centers_per_image - gt_bboxes_per_image_t 218 | b_b = gt_bboxes_per_image_b - y_centers_per_image 219 | bbox_deltas = torch.stack([b_l, b_t, b_r, b_b], 2) 220 | is_in_boxes = bbox_deltas.min(dim=-1).values > 0.0 221 | is_in_boxes_all = is_in_boxes.sum(dim=0) > 0 222 | 223 | gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(1, total_num_anchors) - center_radius * expanded_strides_per_image.unsqueeze(0) 224 | gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(1, total_num_anchors) + center_radius * expanded_strides_per_image.unsqueeze(0) 225 | gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(1, total_num_anchors) - center_radius * expanded_strides_per_image.unsqueeze(0) 226 | gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(1, total_num_anchors) + center_radius * expanded_strides_per_image.unsqueeze(0) 227 | 228 | c_l = x_centers_per_image - gt_bboxes_per_image_l 229 | c_r = gt_bboxes_per_image_r - x_centers_per_image 230 | c_t = y_centers_per_image - gt_bboxes_per_image_t 231 | c_b = gt_bboxes_per_image_b - y_centers_per_image 232 | center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2) 233 | is_in_centers = center_deltas.min(dim=-1).values > 0.0 234 | is_in_centers_all = is_in_centers.sum(dim=0) > 0 235 | 236 | is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all 237 | is_in_boxes_and_center = is_in_boxes[:, is_in_boxes_anchor] & is_in_centers[:, is_in_boxes_anchor] 238 | return is_in_boxes_anchor, is_in_boxes_and_center 239 | 240 | def dynamic_k_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask): 241 | matching_matrix = torch.zeros_like(cost) 242 | 243 | n_candidate_k = min(10, pair_wise_ious.size(1)) 244 | topk_ious, _ = torch.topk(pair_wise_ious, n_candidate_k, dim=1) 245 | dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1) 246 | 247 | for gt_idx in range(num_gt): 248 | _, pos_idx = torch.topk(cost[gt_idx], k=dynamic_ks[gt_idx].item(), largest=False) 249 | matching_matrix[gt_idx][pos_idx] = 1.0 250 | del topk_ious, dynamic_ks, pos_idx 251 | anchor_matching_gt = matching_matrix.sum(0) 252 | if (anchor_matching_gt > 1).sum() > 0: 253 | _, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0) 254 | matching_matrix[:, anchor_matching_gt > 1] *= 0.0 255 | matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1.0 256 | 257 | fg_mask_inboxes = matching_matrix.sum(0) > 0.0 258 | num_fg = fg_mask_inboxes.sum().item() 259 | fg_mask[fg_mask.clone()] = fg_mask_inboxes 260 | matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0) 261 | gt_matched_classes = gt_classes[matched_gt_inds] 262 | 263 | pred_ious_this_matching = 
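# ---------------------------------------------------------------------------
# Illustrative standalone sketch (assumes only torch; the toy IoU table is
# made up): in dynamic_k_matching above, each GT gets k = sum of its top-10
# candidate IoUs (at least 1), and its k lowest-cost anchors are selected;
# anchors claimed by several GTs are later kept only for the cheapest GT.
import torch

pair_wise_ious = torch.tensor([[0.90, 0.80, 0.40, 0.20],
                               [0.05, 0.71, 0.68, 0.55]])   # 2 GTs x 4 anchors
n_candidate_k = min(10, pair_wise_ious.size(1))
topk_ious, _ = torch.topk(pair_wise_ious, n_candidate_k, dim=1)
dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
print(dynamic_ks.tolist())              # [2, 1]  (IoU sums 2.3 and 1.99)
# ---------------------------------------------------------------------------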
(matching_matrix * pair_wise_ious).sum(0)[fg_mask_inboxes] 264 | return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds 265 | 266 | def is_parallel(model): 267 | return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel) 268 | 269 | def de_parallel(model): 270 | return model.module if is_parallel(model) else model 271 | 272 | def copy_attr(a, b, include=(), exclude=()): 273 | for k, v in b.__dict__.items(): 274 | if (len(include) and k not in include) or k.startswith('_') or k in exclude: 275 | continue 276 | else: 277 | setattr(a, k, v) 278 | 279 | class ModelEMA: 280 | """ Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models 281 | Keeps a moving average of everything in the model state_dict (parameters and buffers) 282 | For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage 283 | """ 284 | 285 | def __init__(self, model, decay=0.9999, tau=2000, updates=0): 286 | # Create EMA 287 | self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA 288 | # if next(model.parameters()).device.type != 'cpu': 289 | # self.ema.half() # FP16 EMA 290 | self.updates = updates # number of EMA updates 291 | self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs) 292 | for p in self.ema.parameters(): 293 | p.requires_grad_(False) 294 | 295 | def update(self, model): 296 | # Update EMA parameters 297 | with torch.no_grad(): 298 | self.updates += 1 299 | d = self.decay(self.updates) 300 | 301 | msd = de_parallel(model).state_dict() 302 | for k, v in self.ema.state_dict().items(): 303 | if v.dtype.is_floating_point: 304 | v *= d 305 | v += (1 - d) * msd[k].detach() 306 | 307 | def update_attr(self, model, include=(), exclude=('process_group', 'reducer')): 308 | copy_attr(self.ema, model, include, exclude) 309 | 310 | def weights_init(net, init_type='normal', init_gain = 0.02): 311 | def init_func(m): 312 | classname = m.__class__.__name__ 313 | if hasattr(m, 'weight') and classname.find('Conv') != -1: 314 | if init_type == 'normal': 315 | torch.nn.init.normal_(m.weight.data, 0.0, init_gain) 316 | elif init_type == 'xavier': 317 | torch.nn.init.xavier_normal_(m.weight.data, gain=init_gain) 318 | elif init_type == 'kaiming': 319 | torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 320 | elif init_type == 'orthogonal': 321 | torch.nn.init.orthogonal_(m.weight.data, gain=init_gain) 322 | else: 323 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 324 | elif classname.find('BatchNorm2d') != -1: 325 | torch.nn.init.normal_(m.weight.data, 1.0, 0.02) 326 | torch.nn.init.constant_(m.bias.data, 0.0) 327 | print('initialize network with %s type' % init_type) 328 | net.apply(init_func) 329 | 330 | def get_lr_scheduler(lr_decay_type, lr, min_lr, total_iters, warmup_iters_ratio = 0.05, warmup_lr_ratio = 0.1, no_aug_iter_ratio = 0.05, step_num = 10): 331 | def yolox_warm_cos_lr(lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter, iters): 332 | if iters <= warmup_total_iters: 333 | lr = (lr - warmup_lr_start) * pow(iters / float(warmup_total_iters), 2) + warmup_lr_start 334 | elif iters >= total_iters - no_aug_iter: 335 | lr = min_lr 336 | else: 337 | lr = min_lr + 0.5 * (lr - min_lr) * ( 338 | 1.0 + math.cos(math.pi* (iters - warmup_total_iters) / (total_iters - warmup_total_iters - no_aug_iter)) 339 | ) 340 | return lr 341 | 342 | def step_lr(lr, decay_rate, step_size, 
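# ---------------------------------------------------------------------------
# Illustrative standalone sketch (assumes only math): ModelEMA above ramps its
# effective decay from ~0 towards 0.9999 as the update count grows, so early
# epochs track the raw weights closely while later epochs average over a long
# horizon.
import math

decay, tau = 0.9999, 2000
ramp = lambda updates: decay * (1 - math.exp(-updates / tau))
for updates in (1, 100, 2000, 20000):
    print(updates, round(ramp(updates), 6))
# 1 ~0.0005, 100 ~0.0488, 2000 ~0.6321, 20000 ~0.9999
# ---------------------------------------------------------------------------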
iters): 343 | if step_size < 1: 344 | raise ValueError("step_size must above 1.") 345 | n = iters // step_size 346 | out_lr = lr * decay_rate ** n 347 | return out_lr 348 | 349 | if lr_decay_type == "cos": 350 | warmup_total_iters = min(max(warmup_iters_ratio * total_iters, 1), 3) 351 | warmup_lr_start = max(warmup_lr_ratio * lr, 1e-6) 352 | no_aug_iter = min(max(no_aug_iter_ratio * total_iters, 1), 15) 353 | func = partial(yolox_warm_cos_lr ,lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter) 354 | else: 355 | decay_rate = (min_lr / lr) ** (1 / (step_num - 1)) 356 | step_size = total_iters / step_num 357 | func = partial(step_lr, lr, decay_rate, step_size) 358 | 359 | return func 360 | 361 | def set_optimizer_lr(optimizer, lr_scheduler_func, epoch): 362 | lr = lr_scheduler_func(epoch) 363 | for param_group in optimizer.param_groups: 364 | param_group['lr'] = lr 365 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import cv2 4 | import numpy as np 5 | from PIL import Image 6 | 7 | from yolo import YOLO 8 | 9 | if __name__ == "__main__": 10 | yolo = YOLO() 11 | mode = "heatmap" 12 | crop = False 13 | count = False 14 | video_path = 0 15 | video_save_path = "" 16 | video_fps = 25.0 17 | test_interval = 100 18 | fps_image_path = "img/street.jpg" 19 | dir_origin_path = "img/" 20 | dir_save_path = "img_out/" 21 | heatmap_save_path = "model_data/heatmap_vision.png" 22 | simplify = True 23 | onnx_save_path = "model_data/models.onnx" 24 | 25 | if mode == "predict": 26 | while True: 27 | img = input('Input image filename:') 28 | try: 29 | image = Image.open(img) 30 | except: 31 | print('Open Error! Try again!') 32 | continue 33 | else: 34 | r_image = yolo.detect_image(image, crop = crop, count=count) 35 | r_image.show() 36 | elif mode == "dir_predict": 37 | import os 38 | 39 | from tqdm import tqdm 40 | 41 | img_names = os.listdir(dir_origin_path) 42 | for img_name in tqdm(img_names): 43 | if img_name.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')): 44 | image_path = os.path.join(dir_origin_path, img_name) 45 | image = Image.open(image_path) 46 | r_image = yolo.detect_image(image) 47 | if not os.path.exists(dir_save_path): 48 | os.makedirs(dir_save_path) 49 | r_image.save(os.path.join(dir_save_path, img_name.replace(".jpg", ".png")), quality=95, subsampling=0) 50 | 51 | elif mode == "heatmap": 52 | while True: 53 | img = input('Input image filename:') 54 | try: 55 | image = Image.open(img) 56 | except: 57 | print('Open Error! 
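# ---------------------------------------------------------------------------
# Illustrative standalone sketch (assumes only math; helper names are my own):
# with lr_decay_type == "cos" the schedule above warms up quadratically for at
# most 3 epochs, decays along a half-cosine, and pins the final no-aug epochs
# to min_lr. Evaluating it for a 300-epoch run with lr=0.01, min_lr=0.0001:
import math

lr, min_lr, total = 0.01, 0.0001, 300
warmup, warmup_start, no_aug = 3, max(0.1 * lr, 1e-6), 15

def cos_lr(i):
    if i <= warmup:
        return (lr - warmup_start) * (i / warmup) ** 2 + warmup_start
    if i >= total - no_aug:
        return min_lr
    return min_lr + 0.5 * (lr - min_lr) * (
        1 + math.cos(math.pi * (i - warmup) / (total - warmup - no_aug)))

for epoch in (0, 3, 150, 290):
    print(epoch, round(cos_lr(epoch), 6))
# 0 0.001, 3 0.01, 150 ~0.0047, 290 0.0001
# ---------------------------------------------------------------------------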
Try again!') 58 | continue 59 | else: 60 | yolo.detect_heatmap(image, heatmap_save_path) 61 | 62 | elif mode == "export_onnx": 63 | yolo.convert_to_onnx(simplify, onnx_save_path) 64 | 65 | else: 66 | raise AssertionError("Please specify the correct mode: 'predict', 'video', 'fps', 'heatmap', 'export_onnx', 'dir_predict'.") 67 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | tensorboard 4 | scipy==1.2.1 5 | numpy==1.17.0 6 | matplotlib==3.1.2 7 | opencv_python==4.1.2.30 8 | tqdm==4.60.0 9 | Pillow==8.2.0 10 | h5py==2.10.0 11 | causal-conv1d==1.2.0.post2 12 | mamba-ssm==1.2.0.post1 -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import datetime 3 | import os 4 | from functools import partial 5 | 6 | import numpy as np 7 | import torch 8 | import torch.backends.cudnn as cudnn 9 | import torch.distributed as dist 10 | import torch.nn as nn 11 | import torch.optim as optim 12 | from torch.utils.data import DataLoader 13 | from nets.yolo import YoloBody 14 | from nets.yolo_training import (ModelEMA, YOLOLoss, get_lr_scheduler, 15 | set_optimizer_lr, weights_init) 16 | from utils.callbacks import EvalCallback, LossHistory 17 | from utils.dataloader import YoloDataset, yolo_dataset_collate 18 | from utils.utils import (get_classes, seed_everything, show_config, 19 | worker_init_fn) 20 | from utils.utils_fit import fit_one_epoch 21 | 22 | if __name__ == "__main__": 23 | 24 | Cuda = True 25 | seed = 11 26 | distributed = False 27 | sync_bn = False 28 | fp16 = False 29 | classes_path = 'model_data/sfew_classes.txt' 30 | model_path = 'model_data/yolox_s.pth' 31 | input_shape = [320, 320] 32 | phi = 's' 33 | mosaic = True 34 | mosaic_prob = 0.5 35 | mixup = True 36 | mixup_prob = 0.5 37 | special_aug_ratio = 0.7 38 | Init_Epoch = 0 39 | Freeze_Epoch = 100 40 | Freeze_batch_size = 32 41 | UnFreeze_Epoch = 300 42 | Unfreeze_batch_size = 16 43 | Freeze_Train = True 44 | Init_lr = 1e-2 45 | Min_lr = Init_lr * 0.01 46 | optimizer_type = "sgd" 47 | momentum = 0.937 48 | weight_decay = 5e-4 49 | lr_decay_type = "cos" 50 | save_period = 10 51 | save_dir = 'logs' 52 | eval_flag = True 53 | eval_period = 10 54 | num_workers = 0 55 | 56 | train_annotation_path = '2007_train.txt' 57 | val_annotation_path = '2007_val.txt' 58 | 59 | seed_everything(seed) 60 | 61 | ngpus_per_node = torch.cuda.device_count() 62 | if distributed: 63 | dist.init_process_group(backend="nccl") 64 | local_rank = int(os.environ["LOCAL_RANK"]) 65 | rank = int(os.environ["RANK"]) 66 | device = torch.device("cuda", local_rank) 67 | if local_rank == 0: 68 | print(f"[{os.getpid()}] (rank = {rank}, local_rank = {local_rank}) training...") 69 | print("Gpu Device Count : ", ngpus_per_node) 70 | else: 71 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 72 | local_rank = 0 73 | rank = 0 74 | 75 | class_names, num_classes = get_classes(classes_path) 76 | 77 | model = YoloBody(num_classes, phi) 78 | weights_init(model) 79 | if model_path != '': 80 | 81 | if local_rank == 0: 82 | print('Load weights {}.'.format(model_path)) 83 | model_dict = model.state_dict() 84 | pretrained_dict = torch.load(model_path, map_location = device) 85 | load_key, no_load_key, temp_dict = [], [], {} 86 | for k, v in 
pretrained_dict.items(): 87 | if k in model_dict.keys() and np.shape(model_dict[k]) == np.shape(v): 88 | temp_dict[k] = v 89 | load_key.append(k) 90 | else: 91 | no_load_key.append(k) 92 | model_dict.update(temp_dict) 93 | model.load_state_dict(model_dict) 94 | 95 | if local_rank == 0: 96 | print("\nSuccessful Load Key:", str(load_key)[:500], "……\nSuccessful Load Key Num:", len(load_key)) 97 | print("\nFail To Load Key:", str(no_load_key)[:500], "……\nFail To Load Key num:", len(no_load_key)) 98 | 99 | yolo_loss = YOLOLoss(num_classes, fp16) 100 | 101 | if local_rank == 0: 102 | time_str = datetime.datetime.strftime(datetime.datetime.now(),'%Y_%m_%d_%H_%M_%S') 103 | log_dir = os.path.join(save_dir, "loss_" + str(time_str)) 104 | loss_history = LossHistory(log_dir, model, input_shape=input_shape) 105 | else: 106 | loss_history = None 107 | if fp16: 108 | from torch.cuda.amp import GradScaler as GradScaler 109 | scaler = GradScaler() 110 | else: 111 | scaler = None 112 | 113 | model_train = model.train() 114 | if sync_bn and ngpus_per_node > 1 and distributed: 115 | model_train = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_train) 116 | elif sync_bn: 117 | print("Sync_bn is not support in one gpu or not distributed.") 118 | model_train = model_train.cuda() 119 | 120 | ema = ModelEMA(model_train) 121 | with open(train_annotation_path, encoding='utf-8') as f: 122 | train_lines = f.readlines() 123 | with open(val_annotation_path, encoding='utf-8') as f: 124 | val_lines = f.readlines() 125 | num_train = len(train_lines) 126 | num_val = len(val_lines) 127 | 128 | if local_rank == 0: 129 | show_config( 130 | classes_path = classes_path, model_path = model_path, input_shape = input_shape, \ 131 | Init_Epoch = Init_Epoch, Freeze_Epoch = Freeze_Epoch, UnFreeze_Epoch = UnFreeze_Epoch, Freeze_batch_size = Freeze_batch_size, Unfreeze_batch_size = Unfreeze_batch_size, Freeze_Train = Freeze_Train, \ 132 | Init_lr = Init_lr, Min_lr = Min_lr, optimizer_type = optimizer_type, momentum = momentum, lr_decay_type = lr_decay_type, \ 133 | save_period = save_period, save_dir = save_dir, num_workers = num_workers, num_train = num_train, num_val = num_val 134 | ) 135 | wanted_step = 5e4 if optimizer_type == "sgd" else 1.5e4 136 | total_step = num_train // Unfreeze_batch_size * UnFreeze_Epoch 137 | if True: 138 | UnFreeze_flag = False 139 | if Freeze_Train: 140 | for param in model.backbone.parameters(): 141 | param.requires_grad = False 142 | batch_size = Freeze_batch_size if Freeze_Train else Unfreeze_batch_size 143 | 144 | nbs = 64 145 | lr_limit_max = 1e-3 if optimizer_type == 'adam' else 5e-2 146 | lr_limit_min = 3e-4 if optimizer_type == 'adam' else 5e-4 147 | Init_lr_fit = min(max(batch_size / nbs * Init_lr, lr_limit_min), lr_limit_max) 148 | Min_lr_fit = min(max(batch_size / nbs * Min_lr, lr_limit_min * 1e-2), lr_limit_max * 1e-2) 149 | 150 | pg0, pg1, pg2 = [], [], [] 151 | for k, v in model.named_modules(): 152 | if hasattr(v, "bias") and isinstance(v.bias, nn.Parameter): 153 | pg2.append(v.bias) 154 | if isinstance(v, nn.BatchNorm2d) or "bn" in k: 155 | pg0.append(v.weight) 156 | elif hasattr(v, "weight") and isinstance(v.weight, nn.Parameter): 157 | pg1.append(v.weight) 158 | optimizer = { 159 | 'adam' : optim.Adam(pg0, Init_lr_fit, betas = (momentum, 0.999)), 160 | 'sgd' : optim.SGD(pg0, Init_lr_fit, momentum = momentum, nesterov=True) 161 | }[optimizer_type] 162 | optimizer.add_param_group({"params": pg1, "weight_decay": weight_decay}) 163 | optimizer.add_param_group({"params": pg2}) 164 | 165 | 
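# ---------------------------------------------------------------------------
# Illustrative standalone sketch (assumes only Python; values copied from the
# config above): the block above scales the configured learning rate linearly
# with the batch size against a reference batch of nbs=64 and clamps it to
# optimizer-specific limits, so the SGD defaults give:
nbs, Init_lr, Min_lr = 64, 1e-2, 1e-4
lr_limit_max, lr_limit_min = 5e-2, 5e-4                  # sgd limits
for batch_size in (32, 16):                              # freeze / unfreeze batch sizes
    Init_lr_fit = min(max(batch_size / nbs * Init_lr, lr_limit_min), lr_limit_max)
    Min_lr_fit  = min(max(batch_size / nbs * Min_lr, lr_limit_min * 1e-2), lr_limit_max * 1e-2)
    print(batch_size, Init_lr_fit, Min_lr_fit)
# 32 0.005 5e-05
# 16 0.0025 2.5e-05
# ---------------------------------------------------------------------------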
lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr_fit, Min_lr_fit, UnFreeze_Epoch) 166 | 167 | epoch_step = num_train // batch_size 168 | epoch_step_val = num_val // batch_size 169 | 170 | if ema: 171 | ema.updates = epoch_step * Init_Epoch 172 | train_dataset = YoloDataset(train_lines, input_shape, num_classes, epoch_length = UnFreeze_Epoch, \ 173 | mosaic=mosaic, mixup=mixup, mosaic_prob=mosaic_prob, mixup_prob=mixup_prob, train=True, special_aug_ratio=special_aug_ratio) 174 | val_dataset = YoloDataset(val_lines, input_shape, num_classes, epoch_length = UnFreeze_Epoch, \ 175 | mosaic=False, mixup=False, mosaic_prob=0, mixup_prob=0, train=False, special_aug_ratio=0) 176 | 177 | if distributed: 178 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True,) 179 | val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False,) 180 | batch_size = batch_size // ngpus_per_node 181 | shuffle = False 182 | else: 183 | train_sampler = None 184 | val_sampler = None 185 | shuffle = True 186 | 187 | gen = DataLoader(train_dataset, shuffle = shuffle, batch_size = batch_size, num_workers = num_workers, pin_memory=True, 188 | drop_last=True, collate_fn=yolo_dataset_collate, sampler=train_sampler, 189 | worker_init_fn=partial(worker_init_fn, rank=rank, seed=seed)) 190 | gen_val = DataLoader(val_dataset , shuffle = shuffle, batch_size = batch_size, num_workers = num_workers, pin_memory=True, 191 | drop_last=True, collate_fn=yolo_dataset_collate, sampler=val_sampler, 192 | worker_init_fn=partial(worker_init_fn, rank=rank, seed=seed)) 193 | if local_rank == 0: 194 | eval_callback = EvalCallback(model, input_shape, class_names, num_classes, val_lines, log_dir, Cuda, \ 195 | eval_flag=eval_flag, period=eval_period) 196 | else: 197 | eval_callback = None 198 | for epoch in range(Init_Epoch, UnFreeze_Epoch): 199 | if epoch >= Freeze_Epoch and not UnFreeze_flag and Freeze_Train: 200 | batch_size = Unfreeze_batch_size 201 | nbs = 64 202 | lr_limit_max = 1e-3 if optimizer_type == 'adam' else 5e-2 203 | lr_limit_min = 3e-4 if optimizer_type == 'adam' else 5e-4 204 | Init_lr_fit = min(max(batch_size / nbs * Init_lr, lr_limit_min), lr_limit_max) 205 | Min_lr_fit = min(max(batch_size / nbs * Min_lr, lr_limit_min * 1e-2), lr_limit_max * 1e-2) 206 | lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr_fit, Min_lr_fit, UnFreeze_Epoch) 207 | 208 | for param in model.backbone.parameters(): 209 | param.requires_grad = True 210 | 211 | epoch_step = num_train // batch_size 212 | epoch_step_val = num_val // batch_size 213 | 214 | if distributed: 215 | batch_size = batch_size // ngpus_per_node 216 | 217 | if ema: 218 | ema.updates = epoch_step * epoch 219 | 220 | gen = DataLoader(train_dataset, shuffle = shuffle, batch_size = batch_size, num_workers = num_workers, pin_memory=True, 221 | drop_last=True, collate_fn=yolo_dataset_collate, sampler=train_sampler, 222 | worker_init_fn=partial(worker_init_fn, rank=rank, seed=seed)) 223 | gen_val = DataLoader(val_dataset , shuffle = shuffle, batch_size = batch_size, num_workers = num_workers, pin_memory=True, 224 | drop_last=True, collate_fn=yolo_dataset_collate, sampler=val_sampler, 225 | worker_init_fn=partial(worker_init_fn, rank=rank, seed=seed)) 226 | UnFreeze_flag = True 227 | 228 | gen.dataset.epoch_now = epoch 229 | gen_val.dataset.epoch_now = epoch 230 | 231 | if distributed: 232 | train_sampler.set_epoch(epoch) 233 | 234 | set_optimizer_lr(optimizer, lr_scheduler_func, epoch) 235 | 236 | 
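# ---------------------------------------------------------------------------
# Illustrative standalone sketch (assumes only torch; the two Linear layers
# are stand-ins, not the real backbone/head): freeze training keeps the
# backbone's parameters out of the gradient computation for the first
# Freeze_Epoch epochs, then re-enables them (and rebuilds the dataloaders
# with the smaller unfreeze batch size) for the rest of the run.
import torch.nn as nn

backbone, head = nn.Linear(8, 8), nn.Linear(8, 7)
for p in backbone.parameters():
    p.requires_grad = False                              # frozen phase
frozen_trainable = sum(p.numel() for p in list(backbone.parameters()) + list(head.parameters()) if p.requires_grad)
for p in backbone.parameters():
    p.requires_grad = True                               # unfreeze phase
full_trainable = sum(p.numel() for p in list(backbone.parameters()) + list(head.parameters()) if p.requires_grad)
print(frozen_trainable, full_trainable)                  # 63 135
# ---------------------------------------------------------------------------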
fit_one_epoch(model_train, model, ema, yolo_loss, loss_history, eval_callback, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, UnFreeze_Epoch, Cuda, fp16, scaler, save_period, save_dir, local_rank) 237 | 238 | if distributed: 239 | dist.barrier() 240 | 241 | if local_rank == 0: 242 | loss_history.writer.close() 243 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SwjtuMa/FER-YOLO-Mamba/e301ceeaf53f0eec2f6552cf0a6eb09a2ce90587/utils/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /utils/__pycache__/callbacks.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SwjtuMa/FER-YOLO-Mamba/e301ceeaf53f0eec2f6552cf0a6eb09a2ce90587/utils/__pycache__/callbacks.cpython-311.pyc -------------------------------------------------------------------------------- /utils/__pycache__/dataloader.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SwjtuMa/FER-YOLO-Mamba/e301ceeaf53f0eec2f6552cf0a6eb09a2ce90587/utils/__pycache__/dataloader.cpython-311.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SwjtuMa/FER-YOLO-Mamba/e301ceeaf53f0eec2f6552cf0a6eb09a2ce90587/utils/__pycache__/utils.cpython-311.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils_bbox.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SwjtuMa/FER-YOLO-Mamba/e301ceeaf53f0eec2f6552cf0a6eb09a2ce90587/utils/__pycache__/utils_bbox.cpython-311.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils_fit.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SwjtuMa/FER-YOLO-Mamba/e301ceeaf53f0eec2f6552cf0a6eb09a2ce90587/utils/__pycache__/utils_fit.cpython-311.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils_map.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SwjtuMa/FER-YOLO-Mamba/e301ceeaf53f0eec2f6552cf0a6eb09a2ce90587/utils/__pycache__/utils_map.cpython-311.pyc -------------------------------------------------------------------------------- /utils/callbacks.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import matplotlib 5 | matplotlib.use('Agg') 6 | import scipy.signal 7 | from matplotlib import pyplot as plt 8 | from torch.utils.tensorboard import SummaryWriter 9 | 10 | import shutil 11 | import numpy as np 12 | 13 | from PIL import Image 14 | from tqdm import tqdm 15 | from .utils import cvtColor, preprocess_input, resize_image 16 | from .utils_bbox import 
decode_outputs, non_max_suppression 17 | from .utils_map import get_coco_map, get_map 18 | 19 | 20 | class LossHistory(): 21 | def __init__(self, log_dir, model, input_shape): 22 | self.log_dir = log_dir 23 | self.losses = [] 24 | self.val_loss = [] 25 | 26 | os.makedirs(self.log_dir) 27 | self.writer = SummaryWriter(self.log_dir) 28 | try: 29 | dummy_input = torch.randn(2, 3, input_shape[0], input_shape[1]) 30 | self.writer.add_graph(model, dummy_input) 31 | except: 32 | pass 33 | 34 | def append_loss(self, epoch, loss, val_loss): 35 | if not os.path.exists(self.log_dir): 36 | os.makedirs(self.log_dir) 37 | 38 | self.losses.append(loss) 39 | self.val_loss.append(val_loss) 40 | 41 | with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f: 42 | f.write(str(loss)) 43 | f.write("\n") 44 | with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f: 45 | f.write(str(val_loss)) 46 | f.write("\n") 47 | 48 | self.writer.add_scalar('loss', loss, epoch) 49 | self.writer.add_scalar('val_loss', val_loss, epoch) 50 | self.loss_plot() 51 | 52 | def loss_plot(self): 53 | iters = range(len(self.losses)) 54 | 55 | plt.figure() 56 | plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss') 57 | plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss') 58 | try: 59 | if len(self.losses) < 25: 60 | num = 5 61 | else: 62 | num = 15 63 | 64 | plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss') 65 | plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss') 66 | except: 67 | pass 68 | 69 | plt.grid(True) 70 | plt.xlabel('Epoch') 71 | plt.ylabel('Loss') 72 | plt.legend(loc="upper right") 73 | 74 | plt.savefig(os.path.join(self.log_dir, "epoch_loss.png")) 75 | 76 | plt.cla() 77 | plt.close("all") 78 | 79 | class EvalCallback(): 80 | def __init__(self, net, input_shape, class_names, num_classes, val_lines, log_dir, cuda, \ 81 | map_out_path=".temp_map_out", max_boxes=100, confidence=0.05, nms_iou=0.5, letterbox_image=True, MINOVERLAP=0.5, eval_flag=True, period=1): 82 | super(EvalCallback, self).__init__() 83 | 84 | self.net = net 85 | self.input_shape = input_shape 86 | self.class_names = class_names 87 | self.num_classes = num_classes 88 | self.val_lines = val_lines 89 | self.log_dir = log_dir 90 | self.cuda = cuda 91 | self.map_out_path = map_out_path 92 | self.max_boxes = max_boxes 93 | self.confidence = confidence 94 | self.nms_iou = nms_iou 95 | self.letterbox_image = letterbox_image 96 | self.MINOVERLAP = MINOVERLAP 97 | self.eval_flag = eval_flag 98 | self.period = period 99 | 100 | self.maps = [0] 101 | self.epoches = [0] 102 | if self.eval_flag: 103 | with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f: 104 | f.write(str(0)) 105 | f.write("\n") 106 | 107 | def get_map_txt(self, image_id, image, class_names, map_out_path): 108 | f = open(os.path.join(map_out_path, "detection-results/"+image_id+".txt"),"w") 109 | image_shape = np.array(np.shape(image)[0:2]) 110 | #---------------------------------------------------------# 111 | # 在这里将图像转换成RGB图像,防止灰度图在预测时报错。 112 | # 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB 113 | #---------------------------------------------------------# 114 | image = cvtColor(image) 115 | #---------------------------------------------------------# 116 | # 给图像增加灰条,实现不失真的resize 117 | # 也可以直接resize进行识别 118 | #---------------------------------------------------------# 119 | image_data 
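# ---------------------------------------------------------------------------
# Illustrative standalone sketch (assumes scipy and numpy; the loss curve is
# synthetic): loss_plot above overlays a Savitzky-Golay smoothed curve on the
# raw losses, using a window of 5 for short runs and 15 once more than 25
# epochs have been logged.
import numpy as np
import scipy.signal

losses = (np.exp(-np.linspace(0, 3, 40)) + 0.05 * np.random.rand(40)).tolist()
window = 5 if len(losses) < 25 else 15
smoothed = scipy.signal.savgol_filter(losses, window, 3)
print(len(smoothed), float(smoothed[0]) > float(smoothed[-1]))   # 40 True
# ---------------------------------------------------------------------------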
= resize_image(image, (self.input_shape[1],self.input_shape[0]), self.letterbox_image) 120 | #---------------------------------------------------------# 121 | # 添加上batch_size维度 122 | #---------------------------------------------------------# 123 | image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0) 124 | 125 | with torch.no_grad(): 126 | images = torch.from_numpy(image_data) 127 | if self.cuda: 128 | images = images.cuda() 129 | #---------------------------------------------------------# 130 | # 将图像输入网络当中进行预测! 131 | #---------------------------------------------------------# 132 | outputs = self.net(images) 133 | outputs = decode_outputs(outputs, self.input_shape) 134 | #---------------------------------------------------------# 135 | # 将预测框进行堆叠,然后进行非极大抑制 136 | #---------------------------------------------------------# 137 | results = non_max_suppression(outputs, self.num_classes, self.input_shape, 138 | image_shape, self.letterbox_image, conf_thres = self.confidence, nms_thres = self.nms_iou) 139 | 140 | if results[0] is None: 141 | return 142 | 143 | top_label = np.array(results[0][:, 6], dtype = 'int32') 144 | top_conf = results[0][:, 4] * results[0][:, 5] 145 | top_boxes = results[0][:, :4] 146 | 147 | top_100 = np.argsort(top_conf)[::-1][:self.max_boxes] 148 | top_boxes = top_boxes[top_100] 149 | top_conf = top_conf[top_100] 150 | top_label = top_label[top_100] 151 | 152 | for i, c in list(enumerate(top_label)): 153 | predicted_class = self.class_names[int(c)] 154 | box = top_boxes[i] 155 | score = str(top_conf[i]) 156 | 157 | top, left, bottom, right = box 158 | if predicted_class not in class_names: 159 | continue 160 | 161 | f.write("%s %s %s %s %s %s\n" % (predicted_class, score[:6], str(int(left)), str(int(top)), str(int(right)),str(int(bottom)))) 162 | 163 | f.close() 164 | return 165 | 166 | def on_epoch_end(self, epoch, model_eval): 167 | if epoch % self.period == 0 and self.eval_flag: 168 | self.net = model_eval 169 | if not os.path.exists(self.map_out_path): 170 | os.makedirs(self.map_out_path) 171 | if not os.path.exists(os.path.join(self.map_out_path, "ground-truth")): 172 | os.makedirs(os.path.join(self.map_out_path, "ground-truth")) 173 | if not os.path.exists(os.path.join(self.map_out_path, "detection-results")): 174 | os.makedirs(os.path.join(self.map_out_path, "detection-results")) 175 | print("Get map.") 176 | for annotation_line in tqdm(self.val_lines): 177 | line = annotation_line.split() 178 | image_id = os.path.basename(line[0]).split('.')[0] 179 | #------------------------------# 180 | # 读取图像并转换成RGB图像 181 | #------------------------------# 182 | image = Image.open(line[0]) 183 | #------------------------------# 184 | # 获得预测框 185 | #------------------------------# 186 | gt_boxes = np.array([np.array(list(map(int,box.split(',')))) for box in line[1:]]) 187 | #------------------------------# 188 | # 获得预测txt 189 | #------------------------------# 190 | self.get_map_txt(image_id, image, self.class_names, self.map_out_path) 191 | 192 | #------------------------------# 193 | # 获得真实框txt 194 | #------------------------------# 195 | with open(os.path.join(self.map_out_path, "ground-truth/"+image_id+".txt"), "w") as new_f: 196 | for box in gt_boxes: 197 | left, top, right, bottom, obj = box 198 | obj_name = self.class_names[obj] 199 | new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom)) 200 | 201 | print("Calculate Map.") 202 | try: 203 | temp_map = get_coco_map(class_names = self.class_names, 
path = self.map_out_path)[1] 204 | except: 205 | temp_map = get_map(self.MINOVERLAP, False, path = self.map_out_path) 206 | self.maps.append(temp_map) 207 | self.epoches.append(epoch) 208 | 209 | with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f: 210 | f.write(str(temp_map)) 211 | f.write("\n") 212 | 213 | plt.figure() 214 | plt.plot(self.epoches, self.maps, 'red', linewidth = 2, label='train map') 215 | 216 | plt.grid(True) 217 | plt.xlabel('Epoch') 218 | plt.ylabel('Map %s'%str(self.MINOVERLAP)) 219 | plt.title('A Map Curve') 220 | plt.legend(loc="upper right") 221 | 222 | plt.savefig(os.path.join(self.log_dir, "epoch_map.png")) 223 | plt.cla() 224 | plt.close("all") 225 | 226 | print("Get map done.") 227 | shutil.rmtree(self.map_out_path) 228 | -------------------------------------------------------------------------------- /utils/dataloader.py: -------------------------------------------------------------------------------- 1 | from random import sample, shuffle 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | from PIL import Image 7 | from torch.utils.data.dataset import Dataset 8 | 9 | from utils.utils import cvtColor, preprocess_input 10 | 11 | 12 | class YoloDataset(Dataset): 13 | def __init__(self, annotation_lines, input_shape, num_classes, epoch_length, \ 14 | mosaic, mixup, mosaic_prob, mixup_prob, train, special_aug_ratio = 0.7): 15 | super(YoloDataset, self).__init__() 16 | self.annotation_lines = annotation_lines 17 | self.input_shape = input_shape 18 | self.num_classes = num_classes 19 | self.epoch_length = epoch_length 20 | self.mosaic = mosaic 21 | self.mosaic_prob = mosaic_prob 22 | self.mixup = mixup 23 | self.mixup_prob = mixup_prob 24 | self.train = train 25 | self.special_aug_ratio = special_aug_ratio 26 | 27 | self.epoch_now = -1 28 | self.length = len(self.annotation_lines) 29 | 30 | def __len__(self): 31 | return self.length 32 | 33 | def __getitem__(self, index): 34 | index = index % self.length 35 | 36 | #---------------------------------------------------# 37 | # 训练时进行数据的随机增强 38 | # 验证时不进行数据的随机增强 39 | #---------------------------------------------------# 40 | if self.mosaic and self.rand() < self.mosaic_prob and self.epoch_now < self.epoch_length * self.special_aug_ratio: 41 | lines = sample(self.annotation_lines, 3) 42 | lines.append(self.annotation_lines[index]) 43 | shuffle(lines) 44 | image, box = self.get_random_data_with_Mosaic(lines, self.input_shape) 45 | 46 | if self.mixup and self.rand() < self.mixup_prob: 47 | lines = sample(self.annotation_lines, 1) 48 | image_2, box_2 = self.get_random_data(lines[0], self.input_shape, random = self.train) 49 | image, box = self.get_random_data_with_MixUp(image, box, image_2, box_2) 50 | else: 51 | image, box = self.get_random_data(self.annotation_lines[index], self.input_shape, random = self.train) 52 | 53 | image = np.transpose(preprocess_input(np.array(image, dtype=np.float32)), (2, 0, 1)) 54 | box = np.array(box, dtype=np.float32) 55 | if len(box) != 0: 56 | box[:, 2:4] = box[:, 2:4] - box[:, 0:2] 57 | box[:, 0:2] = box[:, 0:2] + box[:, 2:4] / 2 58 | return image, box 59 | 60 | def rand(self, a=0, b=1): 61 | return np.random.rand()*(b-a) + a 62 | 63 | def get_random_data(self, annotation_line, input_shape, jitter=.3, hue=.1, sat=0.7, val=0.4, random=True): 64 | line = annotation_line.split() 65 | #------------------------------# 66 | # 读取图像并转换成RGB图像 67 | #------------------------------# 68 | image = Image.open(line[0]) 69 | image = cvtColor(image) 70 | 
#------------------------------# 71 | # 获得图像的高宽与目标高宽 72 | #------------------------------# 73 | iw, ih = image.size 74 | h, w = input_shape 75 | #------------------------------# 76 | # 获得预测框 77 | #------------------------------# 78 | box = np.array([np.array(list(map(int,box.split(',')))) for box in line[1:]]) 79 | 80 | if not random: 81 | scale = min(w/iw, h/ih) 82 | nw = int(iw*scale) 83 | nh = int(ih*scale) 84 | dx = (w-nw)//2 85 | dy = (h-nh)//2 86 | 87 | #---------------------------------# 88 | # 将图像多余的部分加上灰条 89 | #---------------------------------# 90 | image = image.resize((nw,nh), Image.BICUBIC) 91 | new_image = Image.new('RGB', (w,h), (128,128,128)) 92 | new_image.paste(image, (dx, dy)) 93 | image_data = np.array(new_image, np.float32) 94 | 95 | #---------------------------------# 96 | # 对真实框进行调整 97 | #---------------------------------# 98 | if len(box)>0: 99 | np.random.shuffle(box) 100 | box[:, [0,2]] = box[:, [0,2]]*nw/iw + dx 101 | box[:, [1,3]] = box[:, [1,3]]*nh/ih + dy 102 | box[:, 0:2][box[:, 0:2]<0] = 0 103 | box[:, 2][box[:, 2]>w] = w 104 | box[:, 3][box[:, 3]>h] = h 105 | box_w = box[:, 2] - box[:, 0] 106 | box_h = box[:, 3] - box[:, 1] 107 | box = box[np.logical_and(box_w>1, box_h>1)] # discard invalid box 108 | 109 | return image_data, box 110 | 111 | #------------------------------------------# 112 | # 对图像进行缩放并且进行长和宽的扭曲 113 | #------------------------------------------# 114 | new_ar = iw/ih * self.rand(1-jitter,1+jitter) / self.rand(1-jitter,1+jitter) 115 | scale = self.rand(.25, 2) 116 | if new_ar < 1: 117 | nh = int(scale*h) 118 | nw = int(nh*new_ar) 119 | else: 120 | nw = int(scale*w) 121 | nh = int(nw/new_ar) 122 | image = image.resize((nw,nh), Image.BICUBIC) 123 | 124 | #------------------------------------------# 125 | # 将图像多余的部分加上灰条 126 | #------------------------------------------# 127 | dx = int(self.rand(0, w-nw)) 128 | dy = int(self.rand(0, h-nh)) 129 | new_image = Image.new('RGB', (w,h), (128,128,128)) 130 | new_image.paste(image, (dx, dy)) 131 | image = new_image 132 | 133 | #------------------------------------------# 134 | # 翻转图像 135 | #------------------------------------------# 136 | flip = self.rand()<.5 137 | if flip: image = image.transpose(Image.FLIP_LEFT_RIGHT) 138 | 139 | image_data = np.array(image, np.uint8) 140 | #---------------------------------# 141 | # 对图像进行色域变换 142 | # 计算色域变换的参数 143 | #---------------------------------# 144 | r = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1 145 | #---------------------------------# 146 | # 将图像转到HSV上 147 | #---------------------------------# 148 | hue, sat, val = cv2.split(cv2.cvtColor(image_data, cv2.COLOR_RGB2HSV)) 149 | dtype = image_data.dtype 150 | #---------------------------------# 151 | # 应用变换 152 | #---------------------------------# 153 | x = np.arange(0, 256, dtype=r.dtype) 154 | lut_hue = ((x * r[0]) % 180).astype(dtype) 155 | lut_sat = np.clip(x * r[1], 0, 255).astype(dtype) 156 | lut_val = np.clip(x * r[2], 0, 255).astype(dtype) 157 | 158 | image_data = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val))) 159 | image_data = cv2.cvtColor(image_data, cv2.COLOR_HSV2RGB) 160 | 161 | #---------------------------------# 162 | # 对真实框进行调整 163 | #---------------------------------# 164 | if len(box)>0: 165 | np.random.shuffle(box) 166 | box[:, [0,2]] = box[:, [0,2]]*nw/iw + dx 167 | box[:, [1,3]] = box[:, [1,3]]*nh/ih + dy 168 | if flip: box[:, [0,2]] = w - box[:, [2,0]] 169 | box[:, 0:2][box[:, 0:2]<0] = 0 170 | box[:, 2][box[:, 2]>w] = w 171 | box[:, 3][box[:, 3]>h] 
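# ---------------------------------------------------------------------------
# Illustrative standalone sketch (assumes opencv-python and numpy; the toy
# image is random): the colour jitter above draws one random gain per HSV
# channel and applies it through 256-entry lookup tables - hue wraps modulo
# 180 (OpenCV's hue range), saturation and value are clipped to 255.
import cv2
import numpy as np

hue, sat, val = 0.1, 0.7, 0.4
image = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)   # toy RGB image
r = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1
h, s, v = cv2.split(cv2.cvtColor(image, cv2.COLOR_RGB2HSV))
x = np.arange(0, 256, dtype=r.dtype)
lut_hue = ((x * r[0]) % 180).astype(image.dtype)
lut_sat = np.clip(x * r[1], 0, 255).astype(image.dtype)
lut_val = np.clip(x * r[2], 0, 255).astype(image.dtype)
jittered = cv2.cvtColor(
    cv2.merge((cv2.LUT(h, lut_hue), cv2.LUT(s, lut_sat), cv2.LUT(v, lut_val))),
    cv2.COLOR_HSV2RGB)
print(jittered.shape, jittered.dtype)                            # (32, 32, 3) uint8
# ---------------------------------------------------------------------------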
= h 172 | box_w = box[:, 2] - box[:, 0] 173 | box_h = box[:, 3] - box[:, 1] 174 | box = box[np.logical_and(box_w>1, box_h>1)] 175 | 176 | return image_data, box 177 | 178 | def merge_bboxes(self, bboxes, cutx, cuty): 179 | merge_bbox = [] 180 | for i in range(len(bboxes)): 181 | for box in bboxes[i]: 182 | tmp_box = [] 183 | x1, y1, x2, y2 = box[0], box[1], box[2], box[3] 184 | 185 | if i == 0: 186 | if y1 > cuty or x1 > cutx: 187 | continue 188 | if y2 >= cuty and y1 <= cuty: 189 | y2 = cuty 190 | if x2 >= cutx and x1 <= cutx: 191 | x2 = cutx 192 | 193 | if i == 1: 194 | if y2 < cuty or x1 > cutx: 195 | continue 196 | if y2 >= cuty and y1 <= cuty: 197 | y1 = cuty 198 | if x2 >= cutx and x1 <= cutx: 199 | x2 = cutx 200 | 201 | if i == 2: 202 | if y2 < cuty or x2 < cutx: 203 | continue 204 | if y2 >= cuty and y1 <= cuty: 205 | y1 = cuty 206 | if x2 >= cutx and x1 <= cutx: 207 | x1 = cutx 208 | 209 | if i == 3: 210 | if y1 > cuty or x2 < cutx: 211 | continue 212 | if y2 >= cuty and y1 <= cuty: 213 | y2 = cuty 214 | if x2 >= cutx and x1 <= cutx: 215 | x1 = cutx 216 | tmp_box.append(x1) 217 | tmp_box.append(y1) 218 | tmp_box.append(x2) 219 | tmp_box.append(y2) 220 | tmp_box.append(box[-1]) 221 | merge_bbox.append(tmp_box) 222 | return merge_bbox 223 | 224 | def get_random_data_with_Mosaic(self, annotation_line, input_shape, jitter=0.3, hue=.1, sat=0.7, val=0.4): 225 | h, w = input_shape 226 | min_offset_x = self.rand(0.3, 0.7) 227 | min_offset_y = self.rand(0.3, 0.7) 228 | 229 | image_datas = [] 230 | box_datas = [] 231 | index = 0 232 | for line in annotation_line: 233 | #---------------------------------# 234 | # 每一行进行分割 235 | #---------------------------------# 236 | line_content = line.split() 237 | #---------------------------------# 238 | # 打开图片 239 | #---------------------------------# 240 | image = Image.open(line_content[0]) 241 | image = cvtColor(image) 242 | 243 | #---------------------------------# 244 | # 图片的大小 245 | #---------------------------------# 246 | iw, ih = image.size 247 | #---------------------------------# 248 | # 保存框的位置 249 | #---------------------------------# 250 | box = np.array([np.array(list(map(int,box.split(',')))) for box in line_content[1:]]) 251 | 252 | #---------------------------------# 253 | # 是否翻转图片 254 | #---------------------------------# 255 | flip = self.rand()<.5 256 | if flip and len(box)>0: 257 | image = image.transpose(Image.FLIP_LEFT_RIGHT) 258 | box[:, [0,2]] = iw - box[:, [2,0]] 259 | 260 | #------------------------------------------# 261 | # 对图像进行缩放并且进行长和宽的扭曲 262 | #------------------------------------------# 263 | new_ar = iw/ih * self.rand(1-jitter,1+jitter) / self.rand(1-jitter,1+jitter) 264 | scale = self.rand(.4, 1) 265 | if new_ar < 1: 266 | nh = int(scale*h) 267 | nw = int(nh*new_ar) 268 | else: 269 | nw = int(scale*w) 270 | nh = int(nw/new_ar) 271 | image = image.resize((nw, nh), Image.BICUBIC) 272 | 273 | #-----------------------------------------------# 274 | # 将图片进行放置,分别对应四张分割图片的位置 275 | #-----------------------------------------------# 276 | if index == 0: 277 | dx = int(w*min_offset_x) - nw 278 | dy = int(h*min_offset_y) - nh 279 | elif index == 1: 280 | dx = int(w*min_offset_x) - nw 281 | dy = int(h*min_offset_y) 282 | elif index == 2: 283 | dx = int(w*min_offset_x) 284 | dy = int(h*min_offset_y) 285 | elif index == 3: 286 | dx = int(w*min_offset_x) 287 | dy = int(h*min_offset_y) - nh 288 | 289 | new_image = Image.new('RGB', (w,h), (128,128,128)) 290 | new_image.paste(image, (dx, dy)) 291 | image_data = np.array(new_image) 292 
| 293 | index = index + 1 294 | box_data = [] 295 | #---------------------------------# 296 | # 对box进行重新处理 297 | #---------------------------------# 298 | if len(box)>0: 299 | np.random.shuffle(box) 300 | box[:, [0,2]] = box[:, [0,2]]*nw/iw + dx 301 | box[:, [1,3]] = box[:, [1,3]]*nh/ih + dy 302 | box[:, 0:2][box[:, 0:2]<0] = 0 303 | box[:, 2][box[:, 2]>w] = w 304 | box[:, 3][box[:, 3]>h] = h 305 | box_w = box[:, 2] - box[:, 0] 306 | box_h = box[:, 3] - box[:, 1] 307 | box = box[np.logical_and(box_w>1, box_h>1)] 308 | box_data = np.zeros((len(box),5)) 309 | box_data[:len(box)] = box 310 | 311 | image_datas.append(image_data) 312 | box_datas.append(box_data) 313 | 314 | #---------------------------------# 315 | # 将图片分割,放在一起 316 | #---------------------------------# 317 | cutx = int(w * min_offset_x) 318 | cuty = int(h * min_offset_y) 319 | 320 | new_image = np.zeros([h, w, 3]) 321 | new_image[:cuty, :cutx, :] = image_datas[0][:cuty, :cutx, :] 322 | new_image[cuty:, :cutx, :] = image_datas[1][cuty:, :cutx, :] 323 | new_image[cuty:, cutx:, :] = image_datas[2][cuty:, cutx:, :] 324 | new_image[:cuty, cutx:, :] = image_datas[3][:cuty, cutx:, :] 325 | 326 | new_image = np.array(new_image, np.uint8) 327 | #---------------------------------# 328 | # 对图像进行色域变换 329 | # 计算色域变换的参数 330 | #---------------------------------# 331 | r = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1 332 | #---------------------------------# 333 | # 将图像转到HSV上 334 | #---------------------------------# 335 | hue, sat, val = cv2.split(cv2.cvtColor(new_image, cv2.COLOR_RGB2HSV)) 336 | dtype = new_image.dtype 337 | #---------------------------------# 338 | # 应用变换 339 | #---------------------------------# 340 | x = np.arange(0, 256, dtype=r.dtype) 341 | lut_hue = ((x * r[0]) % 180).astype(dtype) 342 | lut_sat = np.clip(x * r[1], 0, 255).astype(dtype) 343 | lut_val = np.clip(x * r[2], 0, 255).astype(dtype) 344 | 345 | new_image = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val))) 346 | new_image = cv2.cvtColor(new_image, cv2.COLOR_HSV2RGB) 347 | 348 | #---------------------------------# 349 | # 对框进行进一步的处理 350 | #---------------------------------# 351 | new_boxes = self.merge_bboxes(box_datas, cutx, cuty) 352 | 353 | return new_image, new_boxes 354 | 355 | def get_random_data_with_MixUp(self, image_1, box_1, image_2, box_2): 356 | new_image = np.array(image_1, np.float32) * 0.5 + np.array(image_2, np.float32) * 0.5 357 | if len(box_1) == 0: 358 | new_boxes = box_2 359 | elif len(box_2) == 0: 360 | new_boxes = box_1 361 | else: 362 | new_boxes = np.concatenate([box_1, box_2], axis=0) 363 | return new_image, new_boxes 364 | 365 | # DataLoader中collate_fn使用 366 | def yolo_dataset_collate(batch): 367 | images = [] 368 | bboxes = [] 369 | for img, box in batch: 370 | images.append(img) 371 | bboxes.append(box) 372 | images = torch.from_numpy(np.array(images)).type(torch.FloatTensor) 373 | bboxes = [torch.from_numpy(ann).type(torch.FloatTensor) for ann in bboxes] 374 | return images, bboxes 375 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | from PIL import Image 6 | 7 | 8 | #---------------------------------------------------------# 9 | # 将图像转换成RGB图像,防止灰度图在预测时报错。 10 | # 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB 11 | #---------------------------------------------------------# 12 | def cvtColor(image): 13 | if 
len(np.shape(image)) == 3 and np.shape(image)[2] == 3: 14 | return image 15 | else: 16 | image = image.convert('RGB') 17 | return image 18 | 19 | #---------------------------------------------------# 20 | # 对输入图像进行resize 21 | #---------------------------------------------------# 22 | def resize_image(image, size, letterbox_image): 23 | iw, ih = image.size 24 | w, h = size 25 | if letterbox_image: 26 | scale = min(w/iw, h/ih) 27 | nw = int(iw*scale) 28 | nh = int(ih*scale) 29 | 30 | image = image.resize((nw,nh), Image.BICUBIC) 31 | new_image = Image.new('RGB', size, (128,128,128)) 32 | new_image.paste(image, ((w-nw)//2, (h-nh)//2)) 33 | else: 34 | new_image = image.resize((w, h), Image.BICUBIC) 35 | return new_image 36 | 37 | #---------------------------------------------------# 38 | # 获得类 39 | #---------------------------------------------------# 40 | def get_classes(classes_path): 41 | with open(classes_path, encoding='utf-8') as f: 42 | class_names = f.readlines() 43 | class_names = [c.strip() for c in class_names] 44 | return class_names, len(class_names) 45 | 46 | #---------------------------------------------------# 47 | # 设置种子 48 | #---------------------------------------------------# 49 | def seed_everything(seed=11): 50 | random.seed(seed) 51 | np.random.seed(seed) 52 | torch.manual_seed(seed) 53 | torch.cuda.manual_seed(seed) 54 | torch.cuda.manual_seed_all(seed) 55 | torch.backends.cudnn.deterministic = True 56 | torch.backends.cudnn.benchmark = False 57 | 58 | #---------------------------------------------------# 59 | # 设置Dataloader的种子 60 | #---------------------------------------------------# 61 | def worker_init_fn(worker_id, rank, seed): 62 | worker_seed = rank + seed 63 | random.seed(worker_seed) 64 | np.random.seed(worker_seed) 65 | torch.manual_seed(worker_seed) 66 | 67 | def preprocess_input(image): 68 | image /= 255.0 69 | image -= np.array([0.485, 0.456, 0.406]) 70 | image /= np.array([0.229, 0.224, 0.225]) 71 | return image 72 | 73 | #---------------------------------------------------# 74 | # 获得学习率 75 | #---------------------------------------------------# 76 | def get_lr(optimizer): 77 | for param_group in optimizer.param_groups: 78 | return param_group['lr'] 79 | 80 | def show_config(**kwargs): 81 | print('Configurations:') 82 | print('-' * 70) 83 | print('|%25s | %40s|' % ('keys', 'values')) 84 | print('-' * 70) 85 | for key, value in kwargs.items(): 86 | print('|%25s | %40s|' % (str(key), str(value))) 87 | print('-' * 70) 88 | -------------------------------------------------------------------------------- /utils/utils_bbox.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torchvision.ops import nms, boxes 4 | 5 | def yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image): 6 | #-----------------------------------------------------------------# 7 | # 把y轴放前面是因为方便预测框和图像的宽高进行相乘 8 | #-----------------------------------------------------------------# 9 | box_yx = box_xy[..., ::-1] 10 | box_hw = box_wh[..., ::-1] 11 | input_shape = np.array(input_shape) 12 | image_shape = np.array(image_shape) 13 | 14 | if letterbox_image: 15 | #-----------------------------------------------------------------# 16 | # 这里求出来的offset是图像有效区域相对于图像左上角的偏移情况 17 | # new_shape指的是宽高缩放情况 18 | #-----------------------------------------------------------------# 19 | new_shape = np.round(image_shape * np.min(input_shape/image_shape)) 20 | offset = (input_shape - new_shape)/2./input_shape 21 | 
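#-----------------------------------------------------------------#
# Annotation (not part of utils_bbox.py): worked example of the letterbox
# correction computed here, assuming input_shape=(320,320) and
# image_shape=(480,640) (h, w):
#   min(input_shape/image_shape) = min(320/480, 320/640) = 0.5
#   new_shape = round((480,640) * 0.5) = (240, 320)   # valid area inside the 320x320 input
#   offset    = ((320-240)/2/320, (320-320)/2/320) = (0.125, 0.0)
#   scale     = (320/240, 320/320) = (1.333..., 1.0)
# so predicted y-coordinates are shifted by -0.125 and stretched by 4/3 to undo
# the gray letterbox bars added in resize_image before mapping back to the image.
#-----------------------------------------------------------------#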
scale = input_shape/new_shape 22 | 23 | box_yx = (box_yx - offset) * scale 24 | box_hw *= scale 25 | 26 | box_mins = box_yx - (box_hw / 2.) 27 | box_maxes = box_yx + (box_hw / 2.) 28 | boxes = np.concatenate([box_mins[..., 0:1], box_mins[..., 1:2], box_maxes[..., 0:1], box_maxes[..., 1:2]], axis=-1) 29 | boxes *= np.concatenate([image_shape, image_shape], axis=-1) 30 | return boxes 31 | 32 | def decode_outputs(outputs, input_shape): 33 | grids = [] 34 | strides = [] 35 | hw = [x.shape[-2:] for x in outputs] 36 | #---------------------------------------------------# 37 | # outputs输入前代表每个特征层的预测结果 38 | # batch_size, 4 + 1 + num_classes, 80, 80 => batch_size, 4 + 1 + num_classes, 6400 39 | # batch_size, 5 + num_classes, 40, 40 40 | # batch_size, 5 + num_classes, 20, 20 41 | # batch_size, 4 + 1 + num_classes, 6400 + 1600 + 400 -> batch_size, 4 + 1 + num_classes, 8400 42 | # 堆叠后为batch_size, 8400, 5 + num_classes 43 | #---------------------------------------------------# 44 | outputs = torch.cat([x.flatten(start_dim=2) for x in outputs], dim=2).permute(0, 2, 1) 45 | #---------------------------------------------------# 46 | # 获得每一个特征点属于每一个种类的概率 47 | #---------------------------------------------------# 48 | outputs[:, :, 4:] = torch.sigmoid(outputs[:, :, 4:]) 49 | for h, w in hw: 50 | #---------------------------# 51 | # 根据特征层的高宽生成网格点 52 | #---------------------------# 53 | grid_y, grid_x = torch.meshgrid([torch.arange(h), torch.arange(w)]) 54 | #---------------------------# 55 | # 1, 6400, 2 56 | # 1, 1600, 2 57 | # 1, 400, 2 58 | #---------------------------# 59 | grid = torch.stack((grid_x, grid_y), 2).view(1, -1, 2) 60 | shape = grid.shape[:2] 61 | 62 | grids.append(grid) 63 | strides.append(torch.full((shape[0], shape[1], 1), input_shape[0] / h)) 64 | #---------------------------# 65 | # 将网格点堆叠到一起 66 | # 1, 6400, 2 67 | # 1, 1600, 2 68 | # 1, 400, 2 69 | # 70 | # 1, 8400, 2 71 | #---------------------------# 72 | grids = torch.cat(grids, dim=1).type(outputs.type()) 73 | strides = torch.cat(strides, dim=1).type(outputs.type()) 74 | #------------------------# 75 | # 根据网格点进行解码 76 | #------------------------# 77 | outputs[..., :2] = (outputs[..., :2] + grids) * strides 78 | outputs[..., 2:4] = torch.exp(outputs[..., 2:4]) * strides 79 | #-----------------# 80 | # 归一化 81 | #-----------------# 82 | outputs[..., [0,2]] = outputs[..., [0,2]] / input_shape[1] 83 | outputs[..., [1,3]] = outputs[..., [1,3]] / input_shape[0] 84 | return outputs 85 | 86 | def non_max_suppression(prediction, num_classes, input_shape, image_shape, letterbox_image, conf_thres=0.5, nms_thres=0.4): 87 | #----------------------------------------------------------# 88 | # 将预测结果的格式转换成左上角右下角的格式。 89 | # prediction [batch_size, num_anchors, 85] 90 | #----------------------------------------------------------# 91 | box_corner = prediction.new(prediction.shape) 92 | box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2 93 | box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2 94 | box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2 95 | box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2 96 | prediction[:, :, :4] = box_corner[:, :, :4] 97 | 98 | output = [None for _ in range(len(prediction))] 99 | #----------------------------------------------------------# 100 | # 对输入图片进行循环,一般只会进行一次 101 | #----------------------------------------------------------# 102 | for i, image_pred in enumerate(prediction): 103 | #----------------------------------------------------------# 104 | # 
对种类预测部分取max。 105 | # class_conf [num_anchors, 1] 种类置信度 106 | # class_pred [num_anchors, 1] 种类 107 | #----------------------------------------------------------# 108 | class_conf, class_pred = torch.max(image_pred[:, 5:5 + num_classes], 1, keepdim=True) 109 | 110 | #----------------------------------------------------------# 111 | # 利用置信度进行第一轮筛选 112 | #----------------------------------------------------------# 113 | conf_mask = (image_pred[:, 4] * class_conf[:, 0] >= conf_thres).squeeze() 114 | 115 | if not image_pred.size(0): 116 | continue 117 | #-------------------------------------------------------------------------# 118 | # detections [num_anchors, 7] 119 | # 7的内容为:x1, y1, x2, y2, obj_conf, class_conf, class_pred 120 | #-------------------------------------------------------------------------# 121 | detections = torch.cat((image_pred[:, :5], class_conf, class_pred.float()), 1) 122 | detections = detections[conf_mask] 123 | 124 | nms_out_index = boxes.batched_nms( 125 | detections[:, :4], 126 | detections[:, 4] * detections[:, 5], 127 | detections[:, 6], 128 | nms_thres, 129 | ) 130 | 131 | output[i] = detections[nms_out_index] 132 | 133 | # #------------------------------------------# 134 | # # 获得预测结果中包含的所有种类 135 | # #------------------------------------------# 136 | # unique_labels = detections[:, -1].cpu().unique() 137 | 138 | # if prediction.is_cuda: 139 | # unique_labels = unique_labels.cuda() 140 | # detections = detections.cuda() 141 | 142 | # for c in unique_labels: 143 | # #------------------------------------------# 144 | # # 获得某一类得分筛选后全部的预测结果 145 | # #------------------------------------------# 146 | # detections_class = detections[detections[:, -1] == c] 147 | 148 | # #------------------------------------------# 149 | # # 使用官方自带的非极大抑制会速度更快一些! 
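        #             #------------------------------------------#
        #             #   Annotation (not part of utils_bbox.py): the active code path above uses
        #             #   torchvision.ops.boxes.batched_nms(boxes, scores, idxs, iou_threshold),
        #             #   which runs class-aware NMS in a single call. A minimal sketch with
        #             #   made-up boxes:
        #             #
        #             #     import torch
        #             #     from torchvision.ops import batched_nms
        #             #     boxes  = torch.tensor([[0., 0., 10., 10.], [1., 1., 11., 11.], [0., 0., 10., 10.]])
        #             #     scores = torch.tensor([0.9, 0.8, 0.7])
        #             #     idxs   = torch.tensor([0, 0, 1])      # third box belongs to a different class
        #             #     batched_nms(boxes, scores, idxs, 0.5) # -> tensor([0, 2])
        #             #
        #             #   the second box is suppressed by the first (same class, IoU ~ 0.68 > 0.5),
        #             #   while the identical box of class 1 is kept.
        #             #------------------------------------------#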
150 | # #------------------------------------------# 151 | # keep = nms( 152 | # detections_class[:, :4], 153 | # detections_class[:, 4] * detections_class[:, 5], 154 | # nms_thres 155 | # ) 156 | # max_detections = detections_class[keep] 157 | 158 | # # # 按照存在物体的置信度排序 159 | # # _, conf_sort_index = torch.sort(detections_class[:, 4]*detections_class[:, 5], descending=True) 160 | # # detections_class = detections_class[conf_sort_index] 161 | # # # 进行非极大抑制 162 | # # max_detections = [] 163 | # # while detections_class.size(0): 164 | # # # 取出这一类置信度最高的,一步一步往下判断,判断重合程度是否大于nms_thres,如果是则去除掉 165 | # # max_detections.append(detections_class[0].unsqueeze(0)) 166 | # # if len(detections_class) == 1: 167 | # # break 168 | # # ious = bbox_iou(max_detections[-1], detections_class[1:]) 169 | # # detections_class = detections_class[1:][ious < nms_thres] 170 | # # # 堆叠 171 | # # max_detections = torch.cat(max_detections).data 172 | 173 | # # Add max detections to outputs 174 | # output[i] = max_detections if output[i] is None else torch.cat((output[i], max_detections)) 175 | 176 | if output[i] is not None: 177 | output[i] = output[i].cpu().numpy() 178 | box_xy, box_wh = (output[i][:, 0:2] + output[i][:, 2:4])/2, output[i][:, 2:4] - output[i][:, 0:2] 179 | output[i][:, :4] = yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image) 180 | return output 181 | -------------------------------------------------------------------------------- /utils/utils_fit.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from tqdm import tqdm 5 | 6 | from utils.utils import get_lr 7 | 8 | 9 | def fit_one_epoch(model_train, model, ema, yolo_loss, loss_history, eval_callback, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch, cuda, fp16, scaler, save_period, save_dir, local_rank=0): 10 | loss = 0 11 | val_loss = 0 12 | 13 | if local_rank == 0: 14 | print('Start Train') 15 | pbar = tqdm(total=epoch_step,desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3) 16 | model_train.train() 17 | for iteration, batch in enumerate(gen): 18 | if iteration >= epoch_step: 19 | break 20 | 21 | images, targets = batch[0], batch[1] 22 | with torch.no_grad(): 23 | if cuda: 24 | images = images.cuda(local_rank) 25 | targets = [ann.cuda(local_rank) for ann in targets] 26 | #----------------------# 27 | # 清零梯度 28 | #----------------------# 29 | optimizer.zero_grad() 30 | if not fp16: 31 | #----------------------# 32 | # 前向传播 33 | #----------------------# 34 | outputs = model_train(images) 35 | 36 | #----------------------# 37 | # 计算损失 38 | #----------------------# 39 | loss_value = yolo_loss(outputs, targets) 40 | 41 | #----------------------# 42 | # 反向传播 43 | #----------------------# 44 | loss_value.backward() 45 | optimizer.step() 46 | else: 47 | from torch.cuda.amp import autocast 48 | with autocast(): 49 | outputs = model_train(images) 50 | #----------------------# 51 | # 计算损失 52 | #----------------------# 53 | loss_value = yolo_loss(outputs, targets) 54 | 55 | #----------------------# 56 | # 反向传播 57 | #----------------------# 58 | scaler.scale(loss_value).backward() 59 | scaler.step(optimizer) 60 | scaler.update() 61 | if ema: 62 | ema.update(model_train) 63 | 64 | loss += loss_value.item() 65 | 66 | if local_rank == 0: 67 | pbar.set_postfix(**{'loss' : loss / (iteration + 1), 68 | 'lr' : get_lr(optimizer)}) 69 | pbar.update(1) 70 | 71 | if local_rank == 0: 72 | pbar.close() 73 | print('Finish Train') 74 | print('Start 
Validation') 75 | pbar = tqdm(total=epoch_step_val, desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3) 76 | 77 | if ema: 78 | model_train_eval = ema.ema 79 | else: 80 | model_train_eval = model_train.eval() 81 | 82 | for iteration, batch in enumerate(gen_val): 83 | if iteration >= epoch_step_val: 84 | break 85 | images, targets = batch[0], batch[1] 86 | with torch.no_grad(): 87 | if cuda: 88 | images = images.cuda(local_rank) 89 | targets = [ann.cuda(local_rank) for ann in targets] 90 | #----------------------# 91 | # 清零梯度 92 | #----------------------# 93 | optimizer.zero_grad() 94 | #----------------------# 95 | # 前向传播 96 | #----------------------# 97 | outputs = model_train_eval(images) 98 | 99 | #----------------------# 100 | # 计算损失 101 | #----------------------# 102 | loss_value = yolo_loss(outputs, targets) 103 | 104 | val_loss += loss_value.item() 105 | if local_rank == 0: 106 | pbar.set_postfix(**{'val_loss': val_loss / (iteration + 1)}) 107 | pbar.update(1) 108 | 109 | if local_rank == 0: 110 | pbar.close() 111 | print('Finish Validation') 112 | loss_history.append_loss(epoch + 1, loss / epoch_step, val_loss / epoch_step_val) 113 | eval_callback.on_epoch_end(epoch + 1, model_train_eval) 114 | print('Epoch:'+ str(epoch + 1) + '/' + str(Epoch)) 115 | print('Total Loss: %.3f || Val Loss: %.3f ' % (loss / epoch_step, val_loss / epoch_step_val)) 116 | 117 | #-----------------------------------------------# 118 | # 保存权值 119 | #-----------------------------------------------# 120 | if ema: 121 | save_state_dict = ema.ema.state_dict() 122 | else: 123 | save_state_dict = model.state_dict() 124 | 125 | if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch: 126 | torch.save(save_state_dict, os.path.join(save_dir, "ep%03d-loss%.3f-val_loss%.3f.pth" % (epoch + 1, loss / epoch_step, val_loss / epoch_step_val))) 127 | 128 | if len(loss_history.val_loss) <= 1 or (val_loss / epoch_step_val) <= min(loss_history.val_loss): 129 | print('Save best model to best_epoch_weights.pth') 130 | torch.save(save_state_dict, os.path.join(save_dir, "best_epoch_weights.pth")) 131 | 132 | torch.save(save_state_dict, os.path.join(save_dir, "last_epoch_weights.pth")) -------------------------------------------------------------------------------- /utils/utils_map.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | import math 4 | import operator 5 | import os 6 | import shutil 7 | import sys 8 | try: 9 | from pycocotools.coco import COCO 10 | from pycocotools.cocoeval import COCOeval 11 | except: 12 | pass 13 | import cv2 14 | import matplotlib 15 | matplotlib.use('Agg') 16 | from matplotlib import pyplot as plt 17 | import numpy as np 18 | 19 | ''' 20 | 0,0 ------> x (width) 21 | | 22 | | (Left,Top) 23 | | *_________ 24 | | | | 25 | | | 26 | y |_________| 27 | (height) * 28 | (Right,Bottom) 29 | ''' 30 | 31 | def log_average_miss_rate(precision, fp_cumsum, num_images): 32 | """ 33 | log-average miss rate: 34 | Calculated by averaging miss rates at 9 evenly spaced FPPI points 35 | between 10e-2 and 10e0, in log-space. 36 | 37 | output: 38 | lamr | log-average miss rate 39 | mr | miss rate 40 | fppi | false positives per image 41 | 42 | references: 43 | [1] Dollar, Piotr, et al. "Pedestrian Detection: An Evaluation of the 44 | State of the Art." Pattern Analysis and Machine Intelligence, IEEE 45 | Transactions on 34.4 (2012): 743 - 761. 
46 | """ 47 | 48 | if precision.size == 0: 49 | lamr = 0 50 | mr = 1 51 | fppi = 0 52 | return lamr, mr, fppi 53 | 54 | fppi = fp_cumsum / float(num_images) 55 | mr = (1 - precision) 56 | 57 | fppi_tmp = np.insert(fppi, 0, -1.0) 58 | mr_tmp = np.insert(mr, 0, 1.0) 59 | 60 | ref = np.logspace(-2.0, 0.0, num = 9) 61 | for i, ref_i in enumerate(ref): 62 | j = np.where(fppi_tmp <= ref_i)[-1][-1] 63 | ref[i] = mr_tmp[j] 64 | 65 | lamr = math.exp(np.mean(np.log(np.maximum(1e-10, ref)))) 66 | 67 | return lamr, mr, fppi 68 | 69 | """ 70 | throw error and exit 71 | """ 72 | def error(msg): 73 | print(msg) 74 | sys.exit(0) 75 | 76 | """ 77 | check if the number is a float between 0.0 and 1.0 78 | """ 79 | def is_float_between_0_and_1(value): 80 | try: 81 | val = float(value) 82 | if val > 0.0 and val < 1.0: 83 | return True 84 | else: 85 | return False 86 | except ValueError: 87 | return False 88 | 89 | """ 90 | Calculate the AP given the recall and precision array 91 | 1st) We compute a version of the measured precision/recall curve with 92 | precision monotonically decreasing 93 | 2nd) We compute the AP as the area under this curve by numerical integration. 94 | """ 95 | def voc_ap(rec, prec): 96 | """ 97 | --- Official matlab code VOC2012--- 98 | mrec=[0 ; rec ; 1]; 99 | mpre=[0 ; prec ; 0]; 100 | for i=numel(mpre)-1:-1:1 101 | mpre(i)=max(mpre(i),mpre(i+1)); 102 | end 103 | i=find(mrec(2:end)~=mrec(1:end-1))+1; 104 | ap=sum((mrec(i)-mrec(i-1)).*mpre(i)); 105 | """ 106 | rec.insert(0, 0.0) # insert 0.0 at begining of list 107 | rec.append(1.0) # insert 1.0 at end of list 108 | mrec = rec[:] 109 | prec.insert(0, 0.0) # insert 0.0 at begining of list 110 | prec.append(0.0) # insert 0.0 at end of list 111 | mpre = prec[:] 112 | """ 113 | This part makes the precision monotonically decreasing 114 | (goes from the end to the beginning) 115 | matlab: for i=numel(mpre)-1:-1:1 116 | mpre(i)=max(mpre(i),mpre(i+1)); 117 | """ 118 | for i in range(len(mpre)-2, -1, -1): 119 | mpre[i] = max(mpre[i], mpre[i+1]) 120 | """ 121 | This part creates a list of indexes where the recall changes 122 | matlab: i=find(mrec(2:end)~=mrec(1:end-1))+1; 123 | """ 124 | i_list = [] 125 | for i in range(1, len(mrec)): 126 | if mrec[i] != mrec[i-1]: 127 | i_list.append(i) # if it was matlab would be i + 1 128 | """ 129 | The Average Precision (AP) is the area under the curve 130 | (numerical integration) 131 | matlab: ap=sum((mrec(i)-mrec(i-1)).*mpre(i)); 132 | """ 133 | ap = 0.0 134 | for i in i_list: 135 | ap += ((mrec[i]-mrec[i-1])*mpre[i]) 136 | return ap, mrec, mpre 137 | 138 | 139 | """ 140 | Convert the lines of a file to a list 141 | """ 142 | def file_lines_to_list(path): 143 | # open txt file lines to a list 144 | with open(path) as f: 145 | content = f.readlines() 146 | # remove whitespace characters like `\n` at the end of each line 147 | content = [x.strip() for x in content] 148 | return content 149 | 150 | """ 151 | Draws text in image 152 | """ 153 | def draw_text_in_image(img, text, pos, color, line_width): 154 | font = cv2.FONT_HERSHEY_PLAIN 155 | fontScale = 1 156 | lineType = 1 157 | bottomLeftCornerOfText = pos 158 | cv2.putText(img, text, 159 | bottomLeftCornerOfText, 160 | font, 161 | fontScale, 162 | color, 163 | lineType) 164 | text_width, _ = cv2.getTextSize(text, font, fontScale, lineType)[0] 165 | return img, (line_width + text_width) 166 | 167 | """ 168 | Plot - adjust axes 169 | """ 170 | def adjust_axes(r, t, fig, axes): 171 | # get text width for re-scaling 172 | bb = 
t.get_window_extent(renderer=r) 173 | text_width_inches = bb.width / fig.dpi 174 | # get axis width in inches 175 | current_fig_width = fig.get_figwidth() 176 | new_fig_width = current_fig_width + text_width_inches 177 | propotion = new_fig_width / current_fig_width 178 | # get axis limit 179 | x_lim = axes.get_xlim() 180 | axes.set_xlim([x_lim[0], x_lim[1]*propotion]) 181 | 182 | """ 183 | Draw plot using Matplotlib 184 | """ 185 | def draw_plot_func(dictionary, n_classes, window_title, plot_title, x_label, output_path, to_show, plot_color, true_p_bar): 186 | # sort the dictionary by decreasing value, into a list of tuples 187 | sorted_dic_by_value = sorted(dictionary.items(), key=operator.itemgetter(1)) 188 | # unpacking the list of tuples into two lists 189 | sorted_keys, sorted_values = zip(*sorted_dic_by_value) 190 | # 191 | if true_p_bar != "": 192 | """ 193 | Special case to draw in: 194 | - green -> TP: True Positives (object detected and matches ground-truth) 195 | - red -> FP: False Positives (object detected but does not match ground-truth) 196 | - orange -> FN: False Negatives (object not detected but present in the ground-truth) 197 | """ 198 | fp_sorted = [] 199 | tp_sorted = [] 200 | for key in sorted_keys: 201 | fp_sorted.append(dictionary[key] - true_p_bar[key]) 202 | tp_sorted.append(true_p_bar[key]) 203 | plt.barh(range(n_classes), fp_sorted, align='center', color='crimson', label='False Positive') 204 | plt.barh(range(n_classes), tp_sorted, align='center', color='forestgreen', label='True Positive', left=fp_sorted) 205 | # add legend 206 | plt.legend(loc='lower right') 207 | """ 208 | Write number on side of bar 209 | """ 210 | fig = plt.gcf() # gcf - get current figure 211 | axes = plt.gca() 212 | r = fig.canvas.get_renderer() 213 | for i, val in enumerate(sorted_values): 214 | fp_val = fp_sorted[i] 215 | tp_val = tp_sorted[i] 216 | fp_str_val = " " + str(fp_val) 217 | tp_str_val = fp_str_val + " " + str(tp_val) 218 | # trick to paint multicolor with offset: 219 | # first paint everything and then repaint the first number 220 | t = plt.text(val, i, tp_str_val, color='forestgreen', va='center', fontweight='bold') 221 | plt.text(val, i, fp_str_val, color='crimson', va='center', fontweight='bold') 222 | if i == (len(sorted_values)-1): # largest bar 223 | adjust_axes(r, t, fig, axes) 224 | else: 225 | plt.barh(range(n_classes), sorted_values, color=plot_color) 226 | """ 227 | Write number on side of bar 228 | """ 229 | fig = plt.gcf() # gcf - get current figure 230 | axes = plt.gca() 231 | r = fig.canvas.get_renderer() 232 | for i, val in enumerate(sorted_values): 233 | str_val = " " + str(val) # add a space before 234 | if val < 1.0: 235 | str_val = " {0:.2f}".format(val) 236 | t = plt.text(val, i, str_val, color=plot_color, va='center', fontweight='bold') 237 | # re-set axes to show number inside the figure 238 | if i == (len(sorted_values)-1): # largest bar 239 | adjust_axes(r, t, fig, axes) 240 | # set window title 241 | fig.canvas.set_window_title(window_title) 242 | # write classes in y axis 243 | tick_font_size = 12 244 | plt.yticks(range(n_classes), sorted_keys, fontsize=tick_font_size) 245 | """ 246 | Re-scale height accordingly 247 | """ 248 | init_height = fig.get_figheight() 249 | # comput the matrix height in points and inches 250 | dpi = fig.dpi 251 | height_pt = n_classes * (tick_font_size * 1.4) # 1.4 (some spacing) 252 | height_in = height_pt / dpi 253 | # compute the required figure height 254 | top_margin = 0.15 # in percentage of the figure height 255 | 
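    # Annotation (not part of utils_map.py): worked example of the height re-scaling
    # performed here, assuming the matplotlib defaults dpi=100 and an initial figure
    # height of 4.8 in, with tick_font_size=12 (i.e. 12*1.4 = 16.8 pt per class):
    #   n_classes = 7  -> height_pt = 117.6, height_in ~ 1.18, figure_height ~ 1.47 (below 4.8, figure kept)
    #   n_classes = 40 -> height_pt = 672.0, height_in = 6.72, figure_height = 8.4  (figure is enlarged)
    # The 1/(1 - top_margin - bottom_margin) factor reserves 20% of the height for margins.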
bottom_margin = 0.05 # in percentage of the figure height 256 | figure_height = height_in / (1 - top_margin - bottom_margin) 257 | # set new height 258 | if figure_height > init_height: 259 | fig.set_figheight(figure_height) 260 | 261 | # set plot title 262 | plt.title(plot_title, fontsize=14) 263 | # set axis titles 264 | # plt.xlabel('classes') 265 | plt.xlabel(x_label, fontsize='large') 266 | # adjust size of window 267 | fig.tight_layout() 268 | # save the plot 269 | fig.savefig(output_path) 270 | # show image 271 | if to_show: 272 | plt.show() 273 | # close the plot 274 | plt.close() 275 | 276 | def get_map(MINOVERLAP, draw_plot, score_threhold=0.5, path = './map_out'): 277 | GT_PATH = os.path.join(path, 'ground-truth') 278 | DR_PATH = os.path.join(path, 'detection-results') 279 | IMG_PATH = os.path.join(path, 'images-optional') 280 | TEMP_FILES_PATH = os.path.join(path, '.temp_files') 281 | RESULTS_FILES_PATH = os.path.join(path, 'results') 282 | 283 | show_animation = True 284 | if os.path.exists(IMG_PATH): 285 | for dirpath, dirnames, files in os.walk(IMG_PATH): 286 | if not files: 287 | show_animation = False 288 | else: 289 | show_animation = False 290 | 291 | if not os.path.exists(TEMP_FILES_PATH): 292 | os.makedirs(TEMP_FILES_PATH) 293 | 294 | if os.path.exists(RESULTS_FILES_PATH): 295 | shutil.rmtree(RESULTS_FILES_PATH) 296 | else: 297 | os.makedirs(RESULTS_FILES_PATH) 298 | if draw_plot: 299 | try: 300 | matplotlib.use('TkAgg') 301 | except: 302 | pass 303 | os.makedirs(os.path.join(RESULTS_FILES_PATH, "AP")) 304 | os.makedirs(os.path.join(RESULTS_FILES_PATH, "F1")) 305 | os.makedirs(os.path.join(RESULTS_FILES_PATH, "Recall")) 306 | os.makedirs(os.path.join(RESULTS_FILES_PATH, "Precision")) 307 | if show_animation: 308 | os.makedirs(os.path.join(RESULTS_FILES_PATH, "images", "detections_one_by_one")) 309 | 310 | ground_truth_files_list = glob.glob(GT_PATH + '/*.txt') 311 | if len(ground_truth_files_list) == 0: 312 | error("Error: No ground-truth files found!") 313 | ground_truth_files_list.sort() 314 | gt_counter_per_class = {} 315 | counter_images_per_class = {} 316 | 317 | for txt_file in ground_truth_files_list: 318 | file_id = txt_file.split(".txt", 1)[0] 319 | file_id = os.path.basename(os.path.normpath(file_id)) 320 | temp_path = os.path.join(DR_PATH, (file_id + ".txt")) 321 | if not os.path.exists(temp_path): 322 | error_msg = "Error. 
File not found: {}\n".format(temp_path) 323 | error(error_msg) 324 | lines_list = file_lines_to_list(txt_file) 325 | bounding_boxes = [] 326 | is_difficult = False 327 | already_seen_classes = [] 328 | for line in lines_list: 329 | try: 330 | if "difficult" in line: 331 | class_name, left, top, right, bottom, _difficult = line.split() 332 | is_difficult = True 333 | else: 334 | class_name, left, top, right, bottom = line.split() 335 | except: 336 | if "difficult" in line: 337 | line_split = line.split() 338 | _difficult = line_split[-1] 339 | bottom = line_split[-2] 340 | right = line_split[-3] 341 | top = line_split[-4] 342 | left = line_split[-5] 343 | class_name = "" 344 | for name in line_split[:-5]: 345 | class_name += name + " " 346 | class_name = class_name[:-1] 347 | is_difficult = True 348 | else: 349 | line_split = line.split() 350 | bottom = line_split[-1] 351 | right = line_split[-2] 352 | top = line_split[-3] 353 | left = line_split[-4] 354 | class_name = "" 355 | for name in line_split[:-4]: 356 | class_name += name + " " 357 | class_name = class_name[:-1] 358 | 359 | bbox = left + " " + top + " " + right + " " + bottom 360 | if is_difficult: 361 | bounding_boxes.append({"class_name":class_name, "bbox":bbox, "used":False, "difficult":True}) 362 | is_difficult = False 363 | else: 364 | bounding_boxes.append({"class_name":class_name, "bbox":bbox, "used":False}) 365 | if class_name in gt_counter_per_class: 366 | gt_counter_per_class[class_name] += 1 367 | else: 368 | gt_counter_per_class[class_name] = 1 369 | 370 | if class_name not in already_seen_classes: 371 | if class_name in counter_images_per_class: 372 | counter_images_per_class[class_name] += 1 373 | else: 374 | counter_images_per_class[class_name] = 1 375 | already_seen_classes.append(class_name) 376 | 377 | with open(TEMP_FILES_PATH + "/" + file_id + "_ground_truth.json", 'w') as outfile: 378 | json.dump(bounding_boxes, outfile) 379 | 380 | gt_classes = list(gt_counter_per_class.keys()) 381 | gt_classes = sorted(gt_classes) 382 | n_classes = len(gt_classes) 383 | 384 | dr_files_list = glob.glob(DR_PATH + '/*.txt') 385 | dr_files_list.sort() 386 | for class_index, class_name in enumerate(gt_classes): 387 | bounding_boxes = [] 388 | for txt_file in dr_files_list: 389 | file_id = txt_file.split(".txt",1)[0] 390 | file_id = os.path.basename(os.path.normpath(file_id)) 391 | temp_path = os.path.join(GT_PATH, (file_id + ".txt")) 392 | if class_index == 0: 393 | if not os.path.exists(temp_path): 394 | error_msg = "Error. 
File not found: {}\n".format(temp_path) 395 | error(error_msg) 396 | lines = file_lines_to_list(txt_file) 397 | for line in lines: 398 | try: 399 | tmp_class_name, confidence, left, top, right, bottom = line.split() 400 | except: 401 | line_split = line.split() 402 | bottom = line_split[-1] 403 | right = line_split[-2] 404 | top = line_split[-3] 405 | left = line_split[-4] 406 | confidence = line_split[-5] 407 | tmp_class_name = "" 408 | for name in line_split[:-5]: 409 | tmp_class_name += name + " " 410 | tmp_class_name = tmp_class_name[:-1] 411 | 412 | if tmp_class_name == class_name: 413 | bbox = left + " " + top + " " + right + " " +bottom 414 | bounding_boxes.append({"confidence":confidence, "file_id":file_id, "bbox":bbox}) 415 | 416 | bounding_boxes.sort(key=lambda x:float(x['confidence']), reverse=True) 417 | with open(TEMP_FILES_PATH + "/" + class_name + "_dr.json", 'w') as outfile: 418 | json.dump(bounding_boxes, outfile) 419 | 420 | sum_AP = 0.0 421 | ap_dictionary = {} 422 | lamr_dictionary = {} 423 | with open(RESULTS_FILES_PATH + "/results.txt", 'w') as results_file: 424 | results_file.write("# AP and precision/recall per class\n") 425 | count_true_positives = {} 426 | 427 | for class_index, class_name in enumerate(gt_classes): 428 | count_true_positives[class_name] = 0 429 | dr_file = TEMP_FILES_PATH + "/" + class_name + "_dr.json" 430 | dr_data = json.load(open(dr_file)) 431 | 432 | nd = len(dr_data) 433 | tp = [0] * nd 434 | fp = [0] * nd 435 | score = [0] * nd 436 | score_threhold_idx = 0 437 | for idx, detection in enumerate(dr_data): 438 | file_id = detection["file_id"] 439 | score[idx] = float(detection["confidence"]) 440 | if score[idx] >= score_threhold: 441 | score_threhold_idx = idx 442 | 443 | if show_animation: 444 | ground_truth_img = glob.glob1(IMG_PATH, file_id + ".*") 445 | if len(ground_truth_img) == 0: 446 | error("Error. Image not found with id: " + file_id) 447 | elif len(ground_truth_img) > 1: 448 | error("Error. Multiple image with id: " + file_id) 449 | else: 450 | img = cv2.imread(IMG_PATH + "/" + ground_truth_img[0]) 451 | img_cumulative_path = RESULTS_FILES_PATH + "/images/" + ground_truth_img[0] 452 | if os.path.isfile(img_cumulative_path): 453 | img_cumulative = cv2.imread(img_cumulative_path) 454 | else: 455 | img_cumulative = img.copy() 456 | bottom_border = 60 457 | BLACK = [0, 0, 0] 458 | img = cv2.copyMakeBorder(img, 0, bottom_border, 0, 0, cv2.BORDER_CONSTANT, value=BLACK) 459 | 460 | gt_file = TEMP_FILES_PATH + "/" + file_id + "_ground_truth.json" 461 | ground_truth_data = json.load(open(gt_file)) 462 | ovmax = -1 463 | gt_match = -1 464 | bb = [float(x) for x in detection["bbox"].split()] 465 | for obj in ground_truth_data: 466 | if obj["class_name"] == class_name: 467 | bbgt = [ float(x) for x in obj["bbox"].split() ] 468 | bi = [max(bb[0],bbgt[0]), max(bb[1],bbgt[1]), min(bb[2],bbgt[2]), min(bb[3],bbgt[3])] 469 | iw = bi[2] - bi[0] + 1 470 | ih = bi[3] - bi[1] + 1 471 | if iw > 0 and ih > 0: 472 | ua = (bb[2] - bb[0] + 1) * (bb[3] - bb[1] + 1) + (bbgt[2] - bbgt[0] 473 | + 1) * (bbgt[3] - bbgt[1] + 1) - iw * ih 474 | ov = iw * ih / ua 475 | if ov > ovmax: 476 | ovmax = ov 477 | gt_match = obj 478 | 479 | if show_animation: 480 | status = "NO MATCH FOUND!" 
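                # Annotation (not part of utils_map.py): standalone sketch of the overlap test
                # in the loop above, using the same inclusive-pixel (+1) convention; boxes are
                # [left, top, right, bottom] in pixels.
                def voc_iou(bb, bbgt):
                    iw = min(bb[2], bbgt[2]) - max(bb[0], bbgt[0]) + 1
                    ih = min(bb[3], bbgt[3]) - max(bb[1], bbgt[1]) + 1
                    if iw <= 0 or ih <= 0:
                        return 0.0
                    ua = ((bb[2] - bb[0] + 1) * (bb[3] - bb[1] + 1)
                          + (bbgt[2] - bbgt[0] + 1) * (bbgt[3] - bbgt[1] + 1)
                          - iw * ih)
                    return iw * ih / ua
                # e.g. voc_iou([0, 0, 9, 9], [5, 0, 14, 9]) = 50/150 ~ 0.33, below a typical
                # MINOVERLAP of 0.5, so such a detection would be counted as a false positive.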
481 | 482 | min_overlap = MINOVERLAP 483 | if ovmax >= min_overlap: 484 | if "difficult" not in gt_match: 485 | if not bool(gt_match["used"]): 486 | tp[idx] = 1 487 | gt_match["used"] = True 488 | count_true_positives[class_name] += 1 489 | with open(gt_file, 'w') as f: 490 | f.write(json.dumps(ground_truth_data)) 491 | if show_animation: 492 | status = "MATCH!" 493 | else: 494 | fp[idx] = 1 495 | if show_animation: 496 | status = "REPEATED MATCH!" 497 | else: 498 | fp[idx] = 1 499 | if ovmax > 0: 500 | status = "INSUFFICIENT OVERLAP" 501 | 502 | """ 503 | Draw image to show animation 504 | """ 505 | if show_animation: 506 | height, widht = img.shape[:2] 507 | white = (255,255,255) 508 | light_blue = (255,200,100) 509 | green = (0,255,0) 510 | light_red = (30,30,255) 511 | margin = 10 512 | # 1nd line 513 | v_pos = int(height - margin - (bottom_border / 2.0)) 514 | text = "Image: " + ground_truth_img[0] + " " 515 | img, line_width = draw_text_in_image(img, text, (margin, v_pos), white, 0) 516 | text = "Class [" + str(class_index) + "/" + str(n_classes) + "]: " + class_name + " " 517 | img, line_width = draw_text_in_image(img, text, (margin + line_width, v_pos), light_blue, line_width) 518 | if ovmax != -1: 519 | color = light_red 520 | if status == "INSUFFICIENT OVERLAP": 521 | text = "IoU: {0:.2f}% ".format(ovmax*100) + "< {0:.2f}% ".format(min_overlap*100) 522 | else: 523 | text = "IoU: {0:.2f}% ".format(ovmax*100) + ">= {0:.2f}% ".format(min_overlap*100) 524 | color = green 525 | img, _ = draw_text_in_image(img, text, (margin + line_width, v_pos), color, line_width) 526 | # 2nd line 527 | v_pos += int(bottom_border / 2.0) 528 | rank_pos = str(idx+1) 529 | text = "Detection #rank: " + rank_pos + " confidence: {0:.2f}% ".format(float(detection["confidence"])*100) 530 | img, line_width = draw_text_in_image(img, text, (margin, v_pos), white, 0) 531 | color = light_red 532 | if status == "MATCH!": 533 | color = green 534 | text = "Result: " + status + " " 535 | img, line_width = draw_text_in_image(img, text, (margin + line_width, v_pos), color, line_width) 536 | 537 | font = cv2.FONT_HERSHEY_SIMPLEX 538 | if ovmax > 0: 539 | bbgt = [ int(round(float(x))) for x in gt_match["bbox"].split() ] 540 | cv2.rectangle(img,(bbgt[0],bbgt[1]),(bbgt[2],bbgt[3]),light_blue,2) 541 | cv2.rectangle(img_cumulative,(bbgt[0],bbgt[1]),(bbgt[2],bbgt[3]),light_blue,2) 542 | cv2.putText(img_cumulative, class_name, (bbgt[0],bbgt[1] - 5), font, 0.6, light_blue, 1, cv2.LINE_AA) 543 | bb = [int(i) for i in bb] 544 | cv2.rectangle(img,(bb[0],bb[1]),(bb[2],bb[3]),color,2) 545 | cv2.rectangle(img_cumulative,(bb[0],bb[1]),(bb[2],bb[3]),color,2) 546 | cv2.putText(img_cumulative, class_name, (bb[0],bb[1] - 5), font, 0.6, color, 1, cv2.LINE_AA) 547 | 548 | cv2.imshow("Animation", img) 549 | cv2.waitKey(20) 550 | output_img_path = RESULTS_FILES_PATH + "/images/detections_one_by_one/" + class_name + "_detection" + str(idx) + ".jpg" 551 | cv2.imwrite(output_img_path, img) 552 | cv2.imwrite(img_cumulative_path, img_cumulative) 553 | 554 | cumsum = 0 555 | for idx, val in enumerate(fp): 556 | fp[idx] += cumsum 557 | cumsum += val 558 | 559 | cumsum = 0 560 | for idx, val in enumerate(tp): 561 | tp[idx] += cumsum 562 | cumsum += val 563 | 564 | rec = tp[:] 565 | for idx, val in enumerate(tp): 566 | rec[idx] = float(tp[idx]) / np.maximum(gt_counter_per_class[class_name], 1) 567 | 568 | prec = tp[:] 569 | for idx, val in enumerate(tp): 570 | prec[idx] = float(tp[idx]) / np.maximum((fp[idx] + tp[idx]), 1) 571 | 572 | ap, mrec, mprec = 
voc_ap(rec[:], prec[:]) 573 | F1 = np.array(rec)*np.array(prec)*2 / np.where((np.array(prec)+np.array(rec))==0, 1, (np.array(prec)+np.array(rec))) 574 | 575 | sum_AP += ap 576 | text = "{0:.2f}%".format(ap*100) + " = " + class_name + " AP " #class_name + " AP = {0:.2f}%".format(ap*100) 577 | 578 | if len(prec)>0: 579 | F1_text = "{0:.2f}".format(F1[score_threhold_idx]) + " = " + class_name + " F1 " 580 | Recall_text = "{0:.2f}%".format(rec[score_threhold_idx]*100) + " = " + class_name + " Recall " 581 | Precision_text = "{0:.2f}%".format(prec[score_threhold_idx]*100) + " = " + class_name + " Precision " 582 | else: 583 | F1_text = "0.00" + " = " + class_name + " F1 " 584 | Recall_text = "0.00%" + " = " + class_name + " Recall " 585 | Precision_text = "0.00%" + " = " + class_name + " Precision " 586 | 587 | rounded_prec = [ '%.2f' % elem for elem in prec ] 588 | rounded_rec = [ '%.2f' % elem for elem in rec ] 589 | results_file.write(text + "\n Precision: " + str(rounded_prec) + "\n Recall :" + str(rounded_rec) + "\n\n") 590 | 591 | if len(prec)>0: 592 | print(text + "\t||\tscore_threhold=" + str(score_threhold) + " : " + "F1=" + "{0:.2f}".format(F1[score_threhold_idx])\ 593 | + " ; Recall=" + "{0:.2f}%".format(rec[score_threhold_idx]*100) + " ; Precision=" + "{0:.2f}%".format(prec[score_threhold_idx]*100)) 594 | else: 595 | print(text + "\t||\tscore_threhold=" + str(score_threhold) + " : " + "F1=0.00% ; Recall=0.00% ; Precision=0.00%") 596 | ap_dictionary[class_name] = ap 597 | 598 | n_images = counter_images_per_class[class_name] 599 | lamr, mr, fppi = log_average_miss_rate(np.array(rec), np.array(fp), n_images) 600 | lamr_dictionary[class_name] = lamr 601 | 602 | if draw_plot: 603 | plt.plot(rec, prec, '-o') 604 | area_under_curve_x = mrec[:-1] + [mrec[-2]] + [mrec[-1]] 605 | area_under_curve_y = mprec[:-1] + [0.0] + [mprec[-1]] 606 | plt.fill_between(area_under_curve_x, 0, area_under_curve_y, alpha=0.2, edgecolor='r') 607 | 608 | fig = plt.gcf() 609 | fig.canvas.set_window_title('AP ' + class_name) 610 | 611 | plt.title('class: ' + text) 612 | plt.xlabel('Recall') 613 | plt.ylabel('Precision') 614 | axes = plt.gca() 615 | axes.set_xlim([0.0,1.0]) 616 | axes.set_ylim([0.0,1.05]) 617 | fig.savefig(RESULTS_FILES_PATH + "/AP/" + class_name + ".png") 618 | plt.cla() 619 | 620 | plt.plot(score, F1, "-", color='orangered') 621 | plt.title('class: ' + F1_text + "\nscore_threhold=" + str(score_threhold)) 622 | plt.xlabel('Score_Threhold') 623 | plt.ylabel('F1') 624 | axes = plt.gca() 625 | axes.set_xlim([0.0,1.0]) 626 | axes.set_ylim([0.0,1.05]) 627 | fig.savefig(RESULTS_FILES_PATH + "/F1/" + class_name + ".png") 628 | plt.cla() 629 | 630 | plt.plot(score, rec, "-H", color='gold') 631 | plt.title('class: ' + Recall_text + "\nscore_threhold=" + str(score_threhold)) 632 | plt.xlabel('Score_Threhold') 633 | plt.ylabel('Recall') 634 | axes = plt.gca() 635 | axes.set_xlim([0.0,1.0]) 636 | axes.set_ylim([0.0,1.05]) 637 | fig.savefig(RESULTS_FILES_PATH + "/Recall/" + class_name + ".png") 638 | plt.cla() 639 | 640 | plt.plot(score, prec, "-s", color='palevioletred') 641 | plt.title('class: ' + Precision_text + "\nscore_threhold=" + str(score_threhold)) 642 | plt.xlabel('Score_Threhold') 643 | plt.ylabel('Precision') 644 | axes = plt.gca() 645 | axes.set_xlim([0.0,1.0]) 646 | axes.set_ylim([0.0,1.05]) 647 | fig.savefig(RESULTS_FILES_PATH + "/Precision/" + class_name + ".png") 648 | plt.cla() 649 | 650 | if show_animation: 651 | cv2.destroyAllWindows() 652 | if n_classes == 0: 653 | 
print("未检测到任何种类,请检查标签信息与get_map.py中的classes_path是否修改。") 654 | return 0 655 | results_file.write("\n# mAP of all classes\n") 656 | mAP = sum_AP / n_classes 657 | text = "mAP = {0:.2f}%".format(mAP*100) 658 | results_file.write(text + "\n") 659 | print(text) 660 | 661 | shutil.rmtree(TEMP_FILES_PATH) 662 | 663 | """ 664 | Count total of detection-results 665 | """ 666 | det_counter_per_class = {} 667 | for txt_file in dr_files_list: 668 | lines_list = file_lines_to_list(txt_file) 669 | for line in lines_list: 670 | class_name = line.split()[0] 671 | if class_name in det_counter_per_class: 672 | det_counter_per_class[class_name] += 1 673 | else: 674 | det_counter_per_class[class_name] = 1 675 | dr_classes = list(det_counter_per_class.keys()) 676 | 677 | """ 678 | Write number of ground-truth objects per class to results.txt 679 | """ 680 | with open(RESULTS_FILES_PATH + "/results.txt", 'a') as results_file: 681 | results_file.write("\n# Number of ground-truth objects per class\n") 682 | for class_name in sorted(gt_counter_per_class): 683 | results_file.write(class_name + ": " + str(gt_counter_per_class[class_name]) + "\n") 684 | 685 | """ 686 | Finish counting true positives 687 | """ 688 | for class_name in dr_classes: 689 | if class_name not in gt_classes: 690 | count_true_positives[class_name] = 0 691 | 692 | """ 693 | Write number of detected objects per class to results.txt 694 | """ 695 | with open(RESULTS_FILES_PATH + "/results.txt", 'a') as results_file: 696 | results_file.write("\n# Number of detected objects per class\n") 697 | for class_name in sorted(dr_classes): 698 | n_det = det_counter_per_class[class_name] 699 | text = class_name + ": " + str(n_det) 700 | text += " (tp:" + str(count_true_positives[class_name]) + "" 701 | text += ", fp:" + str(n_det - count_true_positives[class_name]) + ")\n" 702 | results_file.write(text) 703 | 704 | """ 705 | Plot the total number of occurences of each class in the ground-truth 706 | """ 707 | if draw_plot: 708 | window_title = "ground-truth-info" 709 | plot_title = "ground-truth\n" 710 | plot_title += "(" + str(len(ground_truth_files_list)) + " files and " + str(n_classes) + " classes)" 711 | x_label = "Number of objects per class" 712 | output_path = RESULTS_FILES_PATH + "/ground-truth-info.png" 713 | to_show = False 714 | plot_color = 'forestgreen' 715 | draw_plot_func( 716 | gt_counter_per_class, 717 | n_classes, 718 | window_title, 719 | plot_title, 720 | x_label, 721 | output_path, 722 | to_show, 723 | plot_color, 724 | '', 725 | ) 726 | 727 | # """ 728 | # Plot the total number of occurences of each class in the "detection-results" folder 729 | # """ 730 | # if draw_plot: 731 | # window_title = "detection-results-info" 732 | # # Plot title 733 | # plot_title = "detection-results\n" 734 | # plot_title += "(" + str(len(dr_files_list)) + " files and " 735 | # count_non_zero_values_in_dictionary = sum(int(x) > 0 for x in list(det_counter_per_class.values())) 736 | # plot_title += str(count_non_zero_values_in_dictionary) + " detected classes)" 737 | # # end Plot title 738 | # x_label = "Number of objects per class" 739 | # output_path = RESULTS_FILES_PATH + "/detection-results-info.png" 740 | # to_show = False 741 | # plot_color = 'forestgreen' 742 | # true_p_bar = count_true_positives 743 | # draw_plot_func( 744 | # det_counter_per_class, 745 | # len(det_counter_per_class), 746 | # window_title, 747 | # plot_title, 748 | # x_label, 749 | # output_path, 750 | # to_show, 751 | # plot_color, 752 | # true_p_bar 753 | # ) 754 | 755 | """ 756 | 
Draw log-average miss rate plot (Show lamr of all classes in decreasing order) 757 | """ 758 | if draw_plot: 759 | window_title = "lamr" 760 | plot_title = "log-average miss rate" 761 | x_label = "log-average miss rate" 762 | output_path = RESULTS_FILES_PATH + "/lamr.png" 763 | to_show = False 764 | plot_color = 'royalblue' 765 | draw_plot_func( 766 | lamr_dictionary, 767 | n_classes, 768 | window_title, 769 | plot_title, 770 | x_label, 771 | output_path, 772 | to_show, 773 | plot_color, 774 | "" 775 | ) 776 | 777 | """ 778 | Draw mAP plot (Show AP's of all classes in decreasing order) 779 | """ 780 | if draw_plot: 781 | window_title = "mAP" 782 | plot_title = "mAP = {0:.2f}%".format(mAP*100) 783 | x_label = "Average Precision" 784 | output_path = RESULTS_FILES_PATH + "/mAP.png" 785 | to_show = True 786 | plot_color = 'royalblue' 787 | draw_plot_func( 788 | ap_dictionary, 789 | n_classes, 790 | window_title, 791 | plot_title, 792 | x_label, 793 | output_path, 794 | to_show, 795 | plot_color, 796 | "" 797 | ) 798 | return mAP 799 | 800 | def preprocess_gt(gt_path, class_names): 801 | image_ids = os.listdir(gt_path) 802 | results = {} 803 | 804 | images = [] 805 | bboxes = [] 806 | for i, image_id in enumerate(image_ids): 807 | lines_list = file_lines_to_list(os.path.join(gt_path, image_id)) 808 | boxes_per_image = [] 809 | image = {} 810 | image_id = os.path.splitext(image_id)[0] 811 | image['file_name'] = image_id + '.jpg' 812 | image['width'] = 1 813 | image['height'] = 1 814 | #-----------------------------------------------------------------# 815 | # 感谢 多学学英语吧 的提醒 816 | # 解决了'Results do not correspond to current coco set'问题 817 | #-----------------------------------------------------------------# 818 | image['id'] = str(image_id) 819 | 820 | for line in lines_list: 821 | difficult = 0 822 | if "difficult" in line: 823 | line_split = line.split() 824 | left, top, right, bottom, _difficult = line_split[-5:] 825 | class_name = "" 826 | for name in line_split[:-5]: 827 | class_name += name + " " 828 | class_name = class_name[:-1] 829 | difficult = 1 830 | else: 831 | line_split = line.split() 832 | left, top, right, bottom = line_split[-4:] 833 | class_name = "" 834 | for name in line_split[:-4]: 835 | class_name += name + " " 836 | class_name = class_name[:-1] 837 | 838 | left, top, right, bottom = float(left), float(top), float(right), float(bottom) 839 | if class_name not in class_names: 840 | continue 841 | cls_id = class_names.index(class_name) + 1 842 | bbox = [left, top, right - left, bottom - top, difficult, str(image_id), cls_id, (right - left) * (bottom - top) - 10.0] 843 | boxes_per_image.append(bbox) 844 | images.append(image) 845 | bboxes.extend(boxes_per_image) 846 | results['images'] = images 847 | 848 | categories = [] 849 | for i, cls in enumerate(class_names): 850 | category = {} 851 | category['supercategory'] = cls 852 | category['name'] = cls 853 | category['id'] = i + 1 854 | categories.append(category) 855 | results['categories'] = categories 856 | 857 | annotations = [] 858 | for i, box in enumerate(bboxes): 859 | annotation = {} 860 | annotation['area'] = box[-1] 861 | annotation['category_id'] = box[-2] 862 | annotation['image_id'] = box[-3] 863 | annotation['iscrowd'] = box[-4] 864 | annotation['bbox'] = box[:4] 865 | annotation['id'] = i 866 | annotations.append(annotation) 867 | results['annotations'] = annotations 868 | return results 869 | 870 | def preprocess_dr(dr_path, class_names): 871 | image_ids = os.listdir(dr_path) 872 | results = [] 873 | for image_id 
in image_ids: 874 | lines_list = file_lines_to_list(os.path.join(dr_path, image_id)) 875 | image_id = os.path.splitext(image_id)[0] 876 | for line in lines_list: 877 | line_split = line.split() 878 | confidence, left, top, right, bottom = line_split[-5:] 879 | class_name = "" 880 | for name in line_split[:-5]: 881 | class_name += name + " " 882 | class_name = class_name[:-1] 883 | left, top, right, bottom = float(left), float(top), float(right), float(bottom) 884 | result = {} 885 | result["image_id"] = str(image_id) 886 | if class_name not in class_names: 887 | continue 888 | result["category_id"] = class_names.index(class_name) + 1 889 | result["bbox"] = [left, top, right - left, bottom - top] 890 | result["score"] = float(confidence) 891 | results.append(result) 892 | return results 893 | 894 | def get_coco_map(class_names, path): 895 | GT_PATH = os.path.join(path, 'ground-truth') 896 | DR_PATH = os.path.join(path, 'detection-results') 897 | COCO_PATH = os.path.join(path, 'coco_eval') 898 | 899 | if not os.path.exists(COCO_PATH): 900 | os.makedirs(COCO_PATH) 901 | 902 | GT_JSON_PATH = os.path.join(COCO_PATH, 'instances_gt.json') 903 | DR_JSON_PATH = os.path.join(COCO_PATH, 'instances_dr.json') 904 | 905 | with open(GT_JSON_PATH, "w") as f: 906 | results_gt = preprocess_gt(GT_PATH, class_names) 907 | json.dump(results_gt, f, indent=4) 908 | 909 | with open(DR_JSON_PATH, "w") as f: 910 | results_dr = preprocess_dr(DR_PATH, class_names) 911 | json.dump(results_dr, f, indent=4) 912 | if len(results_dr) == 0: 913 | print("未检测到任何目标。") 914 | return [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 915 | 916 | cocoGt = COCO(GT_JSON_PATH) 917 | cocoDt = cocoGt.loadRes(DR_JSON_PATH) 918 | cocoEval = COCOeval(cocoGt, cocoDt, 'bbox') 919 | cocoEval.evaluate() 920 | cocoEval.accumulate() 921 | cocoEval.summarize() 922 | 923 | return cocoEval.stats -------------------------------------------------------------------------------- /yolo.py: -------------------------------------------------------------------------------- 1 | import colorsys 2 | import os 3 | import time 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | from PIL import ImageDraw, ImageFont 8 | from nets.yolo import YoloBody 9 | from utils.utils import (cvtColor, get_classes, preprocess_input, resize_image, 10 | show_config) 11 | from utils.utils_bbox import decode_outputs, non_max_suppression 12 | class YOLO(object): 13 | _defaults = { 14 | "model_path" : 'logs/best_epoch_weights.pth', 15 | "classes_path" : 'model_data/sfew_classes.txt', 16 | "input_shape" : [320, 320], 17 | "phi" : 's', 18 | "confidence" : 0.5, 19 | "nms_iou" : 0.3, 20 | "letterbox_image" : True, 21 | "cuda" : True, 22 | } 23 | 24 | @classmethod 25 | def get_defaults(cls, n): 26 | if n in cls._defaults: 27 | return cls._defaults[n] 28 | else: 29 | return "Unrecognized attribute name '" + n + "'" 30 | def __init__(self, **kwargs): 31 | self.__dict__.update(self._defaults) 32 | for name, value in kwargs.items(): 33 | setattr(self, name, value) 34 | self._defaults[name] = value 35 | self.class_names, self.num_classes = get_classes(self.classes_path) 36 | hsv_tuples = [(x / self.num_classes, 1., 1.) 
for x in range(self.num_classes)] 37 | self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples)) 38 | self.colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), self.colors)) 39 | self.generate() 40 | show_config(**self._defaults) 41 | def generate(self, onnx=False): 42 | self.net = YoloBody(self.num_classes, self.phi) 43 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 44 | self.net.load_state_dict(torch.load(self.model_path, map_location=device)) 45 | self.net = self.net.eval() 46 | print('{} model, and classes loaded.'.format(self.model_path)) 47 | if not onnx: 48 | if self.cuda: 49 | self.net = nn.DataParallel(self.net) 50 | self.net = self.net.cuda() 51 | def detect_image(self, image, crop = False, count = False): 52 | image_shape = np.array(np.shape(image)[0:2]) 53 | 54 | image = cvtColor(image) 55 | 56 | image_data = resize_image(image, (self.input_shape[1],self.input_shape[0]), self.letterbox_image) 57 | 58 | image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0) 59 | 60 | with torch.no_grad(): 61 | images = torch.from_numpy(image_data) 62 | if self.cuda: 63 | images = images.cuda() 64 | 65 | outputs = self.net(images) 66 | outputs = decode_outputs(outputs, self.input_shape) 67 | 68 | results = non_max_suppression(outputs, self.num_classes, self.input_shape, 69 | image_shape, self.letterbox_image, conf_thres = self.confidence, nms_thres = self.nms_iou) 70 | 71 | if results[0] is None: 72 | return image 73 | 74 | top_label = np.array(results[0][:, 6], dtype = 'int32') 75 | top_conf = results[0][:, 4] * results[0][:, 5] 76 | top_boxes = results[0][:, :4] 77 | 78 | font = ImageFont.truetype(font='model_data/simhei.ttf', size=np.floor(3e-2 * image.size[1] + 0.5).astype('int32')) 79 | thickness = int(max((image.size[0] + image.size[1]) // np.mean(self.input_shape), 1)) 80 | if count: 81 | print("top_label:", top_label) 82 | classes_nums = np.zeros([self.num_classes]) 83 | for i in range(self.num_classes): 84 | num = np.sum(top_label == i) 85 | if num > 0: 86 | print(self.class_names[i], " : ", num) 87 | classes_nums[i] = num 88 | print("classes_nums:", classes_nums) 89 | if crop: 90 | for i, c in list(enumerate(top_label)): 91 | top, left, bottom, right = top_boxes[i] 92 | top = max(0, np.floor(top).astype('int32')) 93 | left = max(0, np.floor(left).astype('int32')) 94 | bottom = min(image.size[1], np.floor(bottom).astype('int32')) 95 | right = min(image.size[0], np.floor(right).astype('int32')) 96 | 97 | dir_save_path = "img_crop" 98 | if not os.path.exists(dir_save_path): 99 | os.makedirs(dir_save_path) 100 | crop_image = image.crop([left, top, right, bottom]) 101 | crop_image.save(os.path.join(dir_save_path, "crop_" + str(i) + ".png"), quality=95, subsampling=0) 102 | print("save crop_" + str(i) + ".png to " + dir_save_path) 103 | 104 | for i, c in list(enumerate(top_label)): 105 | predicted_class = self.class_names[int(c)] 106 | box = top_boxes[i] 107 | score = top_conf[i] 108 | 109 | top, left, bottom, right = box 110 | 111 | top = max(0, np.floor(top).astype('int32')) 112 | left = max(0, np.floor(left).astype('int32')) 113 | bottom = min(image.size[1], np.floor(bottom).astype('int32')) 114 | right = min(image.size[0], np.floor(right).astype('int32')) 115 | 116 | label = '{} {:.2f}'.format(predicted_class, score) 117 | draw = ImageDraw.Draw(image) 118 | label_size = draw.textsize(label, font) 119 | label = label.encode('utf-8') 120 | print(label, top, left, 
bottom, right) 121 | 122 | if top - label_size[1] >= 0: 123 | text_origin = np.array([left, top - label_size[1]]) 124 | else: 125 | text_origin = np.array([left, top + 1]) 126 | 127 | for i in range(thickness): 128 | draw.rectangle([left + i, top + i, right - i, bottom - i], outline=self.colors[c]) 129 | draw.rectangle([tuple(text_origin), tuple(text_origin + label_size)], fill=self.colors[c]) 130 | draw.text(text_origin, str(label,'UTF-8'), fill=(0, 0, 0), font=font) 131 | del draw 132 | 133 | return image 134 | 135 | def detect_heatmap(self, image, heatmap_save_path): 136 | import cv2 137 | import matplotlib 138 | matplotlib.use('Agg') 139 | import matplotlib.pyplot as plt 140 | def sigmoid(x): 141 | y = 1.0 / (1.0 + np.exp(-x)) 142 | return y 143 | 144 | image_shape = np.array(np.shape(image)[0:2]) 145 | 146 | image = cvtColor(image) 147 | 148 | image_data = resize_image(image, (self.input_shape[1],self.input_shape[0]), self.letterbox_image) 149 | 150 | image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0) 151 | 152 | with torch.no_grad(): 153 | images = torch.from_numpy(image_data) 154 | if self.cuda: 155 | images = images.cuda() 156 | outputs = self.net(images) 157 | 158 | outputs = [output.cpu().numpy() for output in outputs] 159 | plt.imshow(image, alpha=1) 160 | plt.axis('off') 161 | mask = np.zeros((image.size[1], image.size[0])) 162 | for sub_output in outputs: 163 | b, c, h, w = np.shape(sub_output) 164 | sub_output = np.transpose(sub_output, [0, 2, 3, 1])[0] 165 | score = np.max(sigmoid(sub_output[..., 5:]), -1) * sigmoid(sub_output[..., 4]) 166 | score = cv2.resize(score, (image.size[0], image.size[1])) 167 | normed_score = (score * 255).astype('uint8') 168 | mask = np.maximum(mask, normed_score) 169 | 170 | plt.imshow(mask, alpha=0.5, interpolation='nearest', cmap="jet") 171 | 172 | plt.axis('off') 173 | plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0) 174 | plt.margins(0, 0) 175 | plt.savefig(heatmap_save_path, dpi=200) 176 | print("Save to the " + heatmap_save_path) 177 | plt.cla() --------------------------------------------------------------------------------
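For reference, a minimal, hypothetical usage sketch of the YOLO wrapper defined in yolo.py above; the image path and output name are placeholders, and any keyword argument passed to the constructor overrides the corresponding entry in YOLO._defaults.

from PIL import Image

from yolo import YOLO

if __name__ == "__main__":
    yolo = YOLO(confidence=0.5, nms_iou=0.3)     # same values as the defaults, shown for illustration
    image = Image.open("test.jpg")               # placeholder input image
    result = yolo.detect_image(image, crop=False, count=True)
    result.save("test_result.jpg")               # image with boxes, class names and scores drawn on it
    # yolo.detect_heatmap(image, "test_heatmap.png")  # optional: class-confidence heatmap overlay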