├── .idea ├── inspectionProfiles │ └── Project_Default.xml ├── modules.xml └── translating-images-into-maps.iml ├── LICENSE.md ├── README.md ├── images ├── ._image_to_bev_motivation.gif └── image_to_bev_motivation.gif ├── src ├── __init__.py ├── data │ ├── augmentations.py │ ├── collate_funcs.py │ ├── data_generation.py │ ├── dataloader.py │ ├── dataloader_new.py │ ├── nuscenes │ │ ├── __init__.py │ │ ├── dataset.py │ │ ├── splits.py │ │ └── utils.py │ └── utils.py ├── model │ ├── __init__.py │ ├── axial_attention.py │ ├── backbone_utils.py │ ├── bev_transform.py │ ├── dla.py │ ├── dla_up.py │ ├── fpn.py │ ├── general_modules.py │ ├── intermediate_layer_getter.py │ ├── loss.py │ ├── mocha.py │ ├── mocha2.py │ ├── monotonic_attention.py │ ├── monotonic_attention00.py │ ├── monotonic_attention_fairseq.py │ ├── network.py │ ├── positional_encoding.py │ ├── resnet.py │ └── transformer.py └── utils.py └── train.py /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 92 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/translating-images-into-maps.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International 2 | 3 | Creative Commons Corporation (“Creative Commons”) is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an “as-is” basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible. 4 | 5 | ### Using Creative Commons Public Licenses 6 | 7 | Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses. 8 | 9 | * __Considerations for licensors:__ Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. 
[More considerations for licensors](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensors). 10 | 11 | * __Considerations for the public:__ By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor’s permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. [More considerations for the public](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensees). 12 | 13 | ## Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License 14 | 15 | By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions. 16 | 17 | ### Section 1 – Definitions. 18 | 19 | a. __Adapted Material__ means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image. 20 | 21 | b. __Adapter's License__ means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License. 22 | 23 | c. __BY-NC-SA Compatible License__ means a license listed at [creativecommons.org/compatiblelicenses](http://creativecommons.org/compatiblelicenses), approved by Creative Commons as essentially the equivalent of this Public License. 24 | 25 | d. __Copyright and Similar Rights__ means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights. 26 | 27 | e. __Effective Technological Measures__ means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements. 
28 | 29 | f. __Exceptions and Limitations__ means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material. 30 | 31 | g. __License Elements__ means the license attributes listed in the name of a Creative Commons Public License. The License Elements of this Public License are Attribution, NonCommercial, and ShareAlike. 32 | 33 | h. __Licensed Material__ means the artistic or literary work, database, or other material to which the Licensor applied this Public License. 34 | 35 | i. __Licensed Rights__ means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license. 36 | 37 | h. __Licensor__ means the individual(s) or entity(ies) granting rights under this Public License. 38 | 39 | i. __NonCommercial__ means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange. 40 | 41 | j. __Share__ means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them. 42 | 43 | k. __Sui Generis Database Rights__ means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world. 44 | 45 | l. __You__ means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning. 46 | 47 | ### Section 2 – Scope. 48 | 49 | a. ___License grant.___ 50 | 51 | 1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to: 52 | 53 | A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and 54 | 55 | B. produce, reproduce, and Share Adapted Material for NonCommercial purposes only. 56 | 57 | 2. __Exceptions and Limitations.__ For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions. 58 | 59 | 3. __Term.__ The term of this Public License is specified in Section 6(a). 60 | 61 | 4. __Media and formats; technical modifications allowed.__ The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. 
The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material. 62 | 63 | 5. __Downstream recipients.__ 64 | 65 | A. __Offer from the Licensor – Licensed Material.__ Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License. 66 | 67 | B. __Additional offer from the Licensor – Adapted Material.__ Every recipient of Adapted Material from You automatically receives an offer from the Licensor to exercise the Licensed Rights in the Adapted Material under the conditions of the Adapter’s License You apply. 68 | 69 | C. __No downstream restrictions.__ You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material. 70 | 71 | 6. __No endorsement.__ Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i). 72 | 73 | b. ___Other rights.___ 74 | 75 | 1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise. 76 | 77 | 2. Patent and trademark rights are not licensed under this Public License. 78 | 79 | 3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes. 80 | 81 | ### Section 3 – License Conditions. 82 | 83 | Your exercise of the Licensed Rights is expressly made subject to the following conditions. 84 | 85 | a. ___Attribution.___ 86 | 87 | 1. If You Share the Licensed Material (including in modified form), You must: 88 | 89 | A. retain the following if it is supplied by the Licensor with the Licensed Material: 90 | 91 | i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated); 92 | 93 | ii. a copyright notice; 94 | 95 | iii. a notice that refers to this Public License; 96 | 97 | iv. a notice that refers to the disclaimer of warranties; 98 | 99 | v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable; 100 | 101 | B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and 102 | 103 | C. 
indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License. 104 | 105 | 2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information. 106 | 107 | 3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable. 108 | 109 | b. ___ShareAlike.___ 110 | 111 | In addition to the conditions in Section 3(a), if You Share Adapted Material You produce, the following conditions also apply. 112 | 113 | 1. The Adapter’s License You apply must be a Creative Commons license with the same License Elements, this version or later, or a BY-NC-SA Compatible License. 114 | 115 | 2. You must include the text of, or the URI or hyperlink to, the Adapter's License You apply. You may satisfy this condition in any reasonable manner based on the medium, means, and context in which You Share Adapted Material. 116 | 117 | 3. You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, Adapted Material that restrict exercise of the rights granted under the Adapter's License You apply. 118 | 119 | ### Section 4 – Sui Generis Database Rights. 120 | 121 | Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material: 122 | 123 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only; 124 | 125 | b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material, including for purposes of Section 3(b); and 126 | 127 | c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database. 128 | 129 | For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights. 130 | 131 | ### Section 5 – Disclaimer of Warranties and Limitation of Liability. 132 | 133 | a. __Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.__ 134 | 135 | b. 
__To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.__ 136 | 137 | c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability. 138 | 139 | ### Section 6 – Term and Termination. 140 | 141 | a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically. 142 | 143 | b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates: 144 | 145 | 1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or 146 | 147 | 2. upon express reinstatement by the Licensor. 148 | 149 | For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License. 150 | 151 | c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License. 152 | 153 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License. 154 | 155 | ### Section 7 – Other Terms and Conditions. 156 | 157 | a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed. 158 | 159 | b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License. 160 | 161 | ### Section 8 – Interpretation. 162 | 163 | a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License. 164 | 165 | b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions. 166 | 167 | c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor. 168 | 169 | d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority. 170 | 171 | > Creative Commons is not a party to its public licenses. 
Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at [creativecommons.org/policies](http://creativecommons.org/policies), Creative Commons does not authorize the use of the trademark “Creative Commons” or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses. 172 | > 173 | > Creative Commons may be contacted at creativecommons.org 174 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Translating Images into Maps 2 | #### Avishkar Saha, Oscar Mendez Maldonado, Chris Russell and Richard Bowden 3 | 4 | This is the official code for the paper [Translating Images Into Maps](https://arxiv.org/abs/2110.00966) presented at ICRA 2022. 5 | 6 | ### Translating Images into Maps 7 |
8 | <img src="images/image_to_bev_motivation.gif" alt="Image-to-BEV motivation" />
9 | 
10 | 
11 | 
12 | ### Setup
13 | The code was written using Python 3.7.
14 | The following libraries are the minimum required for this repo:
15 | ```python
16 | pytorch
17 | cv2
18 | numpy
19 | pickle
20 | pyquaternion
21 | shapely
22 | lmdb
23 | ```
24 | 
25 | ### Data
26 | The official nuScenes data is required to train the entire model.
27 | For convenience, we provide the nuScenes mini dataset wrapped into
28 | lmdbs; they can be downloaded from either of the links below:
29 | ```
30 | https://www.icloud.com/iclouddrive/0aaSjW59DEqgUDKyy1uw0iSVg#nuscenes%5Fdata
31 | 
32 | https://drive.google.com/drive/folders/1-1dZXeHnPiuqX-w8ruJHqfxBuMYMONRT?usp=share_link
33 | ```
34 | 
35 | The downloaded contents need to be unzipped and placed in a `nuscenes_data` folder,
36 | created as follows:
37 | ```
38 | cd translating-images-into-maps
39 | mkdir nuscenes_data
40 | ```
41 | 
42 | This folder then contains the input images, the camera intrinsics, and the
43 | ground truth maps already generated for the mini dataset.
44 | 
45 | ### Using the code
46 | To train a model with the configuration in the paper, simply run:
47 | ```bash
48 | python train.py
49 | ```
50 | 
51 | ### Pretrained model
52 | Pretrained models and the configs required to load/train them can be downloaded from here (updated with correct models 09-09-2024):
53 | ````
54 | https://www.icloud.com/iclouddrive/041FdACyj8m0pM4L383luzZJg#tiim_checkpoints
55 | ````
56 | 
57 | 
58 | ### Citation
59 | If you find this code useful, please cite the following papers:
60 | ```
61 | @inproceedings{saha2022translating,
62 |   title={Translating Images into Maps},
63 |   author={Saha, Avishkar and Mendez, Oscar and Russell, Chris and Bowden, Richard},
64 |   booktitle={2022 IEEE International Conference on Robotics and Automation (ICRA)},
65 |   year={2022},
66 |   organization={IEEE}
67 | }
68 | @inproceedings{saha2021enabling,
69 |   title={Enabling spatio-temporal aggregation in birds-eye-view vehicle estimation},
70 |   author={Saha, Avishkar and Mendez, Oscar and Russell, Chris and Bowden, Richard},
71 |   booktitle={2021 IEEE International Conference on Robotics and Automation (ICRA)},
72 |   pages={5133--5139},
73 |   year={2021},
74 |   organization={IEEE}
75 | }
76 | @inproceedings{saha2022pedestrian,
77 |   title={"The Pedestrian next to the Lamppost" Adaptive Object Graphs for Better Instantaneous Mapping},
78 |   author={Saha, Avishkar and Mendez, Oscar and Russell, Chris and Bowden, Richard},
79 |   booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
80 |   pages={19528--19537},
81 |   year={2022}
82 | }
83 | ```
84 | 
--------------------------------------------------------------------------------
/images/._image_to_bev_motivation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avishkarsaha/translating-images-into-maps/a452fe7a1eb0062f96133d14380d504e7bc0c8d1/images/._image_to_bev_motivation.gif
--------------------------------------------------------------------------------
/images/image_to_bev_motivation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avishkarsaha/translating-images-into-maps/a452fe7a1eb0062f96133d14380d504e7bc0c8d1/images/image_to_bev_motivation.gif
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avishkarsaha/translating-images-into-maps/a452fe7a1eb0062f96133d14380d504e7bc0c8d1/src/__init__.py -------------------------------------------------------------------------------- /src/data/augmentations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | 4 | 5 | class AugmentedMapDataset(Dataset): 6 | 7 | def __init__(self, dataset, hflip=True, desired_image_size=None): 8 | self.dataset = dataset 9 | self.hflip = hflip 10 | self.desired_image_size = desired_image_size 11 | 12 | def __len__(self): 13 | return len(self.dataset) 14 | 15 | def __getitem__(self, index): 16 | image, calib, labels, mask = self.dataset[index] 17 | 18 | # Apply data augmentation 19 | if self.hflip: 20 | image, labels, mask = random_hflip(image, labels, mask) 21 | ## TODO: add padding+cropping augmentation (copied from the original dataloader) - don't we need to change the gt and mask OR the z_intervals when the image is being changed? 22 | # if self.desired_image_size: 23 | # image, calib = image_calib_pad_and_crop(image, calib, desired_image_size) 24 | 25 | return image, calib, labels, mask 26 | 27 | 28 | def random_hflip(image, labels, mask): 29 | image = torch.flip(image, (-1,)) 30 | labels = torch.flip(labels.int(), (-1,)).bool() 31 | mask = torch.flip(mask.int(), (-1,)).bool() 32 | return image, labels, mask 33 | 34 | def image_calib_pad_and_crop(image, calib, desired_image_size): 35 | 36 | og_w, og_h = image.shape[-2:] 37 | desired_w, desired_h = desired_image_size 38 | scale_w, scale_h = desired_w / og_w, desired_h / og_h 39 | # Scale image 40 | image = image.resize((int(image.size[0] * scale_w), int(image.size[1] * scale_h))) 41 | # Pad images to the same dimensions 42 | w = image.size[0] 43 | h = image.size[1] 44 | delta_w = desired_w - w 45 | delta_h = desired_h - h 46 | pad_left = int(delta_w / 2) 47 | pad_right = delta_w - pad_left 48 | pad_top = int(delta_h / 2) 49 | pad_bottom = delta_h - pad_top 50 | left = 0 - pad_left 51 | right = pad_right + w 52 | top = 0 - pad_top 53 | bottom = pad_bottom + h 54 | image = image.crop((left, top, right, bottom)) 55 | 56 | # Modify calibration matrices 57 | # Scale first two rows of calibration matrix 58 | calib[:2, :] *= scale_w 59 | # cx' = cx - du 60 | calib[0, 2] = calib[0, 2] + pad_left 61 | # cy' = cy - dv 62 | calib[1, 2] = calib[1, 2] + pad_top 63 | 64 | return image, calib 65 | -------------------------------------------------------------------------------- /src/data/collate_funcs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torchvision.transforms.functional import to_tensor 4 | 5 | from src.utils import merge_classes2 6 | from src.utils import make_uvcoords, merge_nusc_static_classes, downsample_gt, merge_classes2, merge_classes_lyft 7 | 8 | 9 | def collate_nusc_s_occ_200down100up(batch): 10 | """ 11 | Collate fuction for: 12 | - NuScenes 13 | - Singe image input models 14 | - Ground truths with occluded regions masked out 15 | 16 | Merges input classes "road_segment" + "lane" into "drivable_area" 17 | 18 | Down Up scheme: downsample 200 to 100, then upsample 100 to required resolution 19 | Output resolution: 100 20 | """ 21 | 22 | idxs, images, calibs, gt_maps, grids2d, class_dict, des_image_size = zip(*batch) 23 | 24 | # All class dicts are same, so only select first 25 | class_dict = class_dict[0] 26 | 27 | og_w, og_h = 1600, 900 28 | 
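# The block below resizes each PIL image from the native nuScenes resolution
# (og_w x og_h = 1600x900) to the desired input size and pads/crops every image
# to a common shape. The camera intrinsics are updated to match further down:
# the first two rows of each calibration matrix are scaled by the width resize
# factor (scale_w) and the principal point (cx, cy) is shifted by the padding offsets.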
desired_w, desired_h = des_image_size[0] 29 | scale_w, scale_h = desired_w / og_w, desired_h / og_h 30 | 31 | # Scale image 32 | images = [ 33 | image.resize((int(image.size[0] * scale_w), int(image.size[1] * scale_h))) 34 | for image in images 35 | ] 36 | 37 | # Pad images to the same dimensions 38 | w = [img.size[0] for img in images] 39 | h = [img.size[1] for img in images] 40 | 41 | delta_w = [desired_w - img.size[0] for img in images] 42 | delta_h = [desired_h - img.size[1] for img in images] 43 | 44 | pad_left = [int(d / 2) for d in delta_w] 45 | pad_right = [delta_w[i] - pad_left[i] for i in range(len(images))] 46 | pad_top = [int(d / 2) for d in delta_h] 47 | pad_bottom = [delta_h[i] - pad_top[i] for i in range(len(images))] 48 | 49 | left = [0 - pad_left[i] for i in range(len(images))] 50 | right = [pad_right[i] + w[i] for i in range(len(images))] 51 | top = [0 - pad_top[i] for i in range(len(images))] 52 | bottom = [pad_bottom[i] + h[i] for i in range(len(images))] 53 | 54 | images = [ 55 | images[i].crop((left[i], top[i], right[i], bottom[i])) 56 | for i in range(len(images)) 57 | ] 58 | 59 | w = [img.size[0] for img in images] 60 | h = [img.size[1] for img in images] 61 | 62 | # Stack images and calibration matrices along the batch dimension 63 | images = torch.stack([to_tensor(img) for img in images]) 64 | gt_maps = torch.stack( 65 | [ 66 | torch.stack([to_tensor(class_map) for class_map in gt_map]) 67 | for gt_map in gt_maps 68 | ] 69 | ).squeeze(2) 70 | calibs = torch.stack(calibs) 71 | grids2d = torch.stack(grids2d) 72 | 73 | # Create visbility mask from lidar and fov masks 74 | lidar_ray_mask = gt_maps[:, -2:-1] 75 | fov_mask = gt_maps[:, -1:] 76 | vis_mask = lidar_ray_mask * fov_mask 77 | vis_mask = downsample_gt(vis_mask, [[100, 100]])[0] 78 | 79 | # Downsample all classes from 200x200 to 100x100 80 | gt_maps_all_classes = gt_maps[:, :-2] 81 | gt_maps_200down100 = downsample_gt(gt_maps_all_classes, [[100, 100]])[0] 82 | 83 | # Apply vis mask 84 | gt_maps = gt_maps_200down100 * vis_mask 85 | 86 | # Modify calibration matrices 87 | # Scale first two rows of calibration matrix 88 | calibs[:, :2, :] *= scale_w 89 | # cx' = cx - du 90 | calibs[:, 0, 2] = calibs[:, 0, 2] + torch.tensor(pad_left) 91 | # cy' = cy - dv 92 | calibs[:, 1, 2] = calibs[:, 1, 2] + torch.tensor(pad_top) 93 | 94 | # Create a vector of indices 95 | idxs = torch.LongTensor(idxs) 96 | 97 | # Add drivable area to first cell in FOV 98 | # gt_maps = gt_maps_masked.squeeze(2) 99 | gt_maps[:, 0, 0, int(gt_maps.shape[-1] / 2 - 1)] = 1 100 | gt_maps[:, 0, 0, int(gt_maps.shape[-1] / 2)] = 1 101 | 102 | # Merge ground truth for road_segment + lane into one and create new gt 103 | classes_to_merge = ["drivable_area", "road_segment", "lane"] 104 | class_idx_to_merge = [class_dict[name] for name in classes_to_merge] 105 | merged_class_idx = class_dict["drivable_area"] 106 | gt_maps_new = torch.stack( 107 | [ 108 | merge_nusc_static_classes(gt, class_idx_to_merge, merged_class_idx) 109 | for gt in gt_maps 110 | ] 111 | ) 112 | 113 | return idxs, images, calibs, gt_maps_new, grids2d, vis_mask 114 | 115 | 116 | def collate_nusc_s(batch): 117 | """ 118 | Collate fuction for: 119 | - NuScenes 120 | - Singe image input models 121 | - Ground truths with occluded regions masked out 122 | """ 123 | 124 | images, cls_maps, vis_masks, calibs, grids2d = zip(*batch) 125 | 126 | # Stack images and calibration matrices along the batch dimension 127 | images = torch.stack(images) 128 | cls_maps = torch.stack(cls_maps) 129 | vis_masks 
= torch.stack(vis_masks) 130 | calibs = torch.stack(calibs) 131 | grids2d = torch.stack(grids2d) 132 | 133 | return (images, calibs, grids2d), (cls_maps, vis_masks) -------------------------------------------------------------------------------- /src/data/dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import csv 4 | 5 | from pathlib import Path 6 | import io 7 | import lmdb 8 | import pickle 9 | import numpy as np 10 | import matplotlib.pyplot as plt 11 | from torch.utils.data import DataLoader 12 | from PIL import Image 13 | import numpy as np 14 | from torch.utils.data import Dataset 15 | from torchvision.transforms.functional import to_tensor, to_pil_image 16 | import torch.nn.functional as F 17 | 18 | from nuscenes.nuscenes import NuScenes 19 | from nuscenes.utils.geometry_utils import ( 20 | view_points, 21 | box_in_image, 22 | BoxVisibility, 23 | transform_matrix, 24 | ) 25 | from nuscenes.map_expansion.map_api import NuScenesMap, NuScenesMapExplorer 26 | from nuscenes.utils.data_classes import LidarPointCloud 27 | 28 | from src import utils 29 | 30 | 31 | class nuScenesMaps(Dataset): 32 | 33 | def __init__( 34 | self, 35 | root="temp", 36 | split="train_mini", 37 | grid_size=(50.0, 50.0), 38 | grid_res=1.0, 39 | classes=[ 40 | "bus", 41 | "bicycle", 42 | "car", 43 | "construction_vehicle", 44 | "motorcycle", 45 | "trailer", 46 | "truck", 47 | "pedestrian", 48 | ], 49 | dataset_size=1.0, 50 | mini=False, 51 | desired_image_size=(1280, 720), 52 | gt_out_size=(100, 100) 53 | ): 54 | self.dataset_size = dataset_size 55 | self.desired_image_size = desired_image_size 56 | self.gt_out_size = gt_out_size 57 | 58 | # paths for data files 59 | self.root = os.path.join(root) 60 | self.gtmaps_db_path = os.path.join( 61 | root, "lmdb", 62 | "semantic_maps_new_200x200" 63 | ) 64 | self.images_db_path = os.path.join( 65 | root, "lmdb", 66 | "samples", "CAM_FRONT" 67 | ) 68 | 69 | # databases 70 | if mini: 71 | self.nusc = NuScenes(version="v1.0-mini", 72 | dataroot=self.root, 73 | verbose=False) 74 | else: 75 | self.nusc = NuScenes(version="v1.0-trainval", dataroot=self.root, verbose=False) 76 | 77 | self.tokens = read_split( 78 | os.path.join(root, "splits", "{}.txt".format(split)) 79 | ) 80 | self.gtmaps_db = lmdb.open( 81 | path=self.gtmaps_db_path, 82 | readonly=True, 83 | readahead=False, 84 | max_spare_txns=128, 85 | lock=False, 86 | ) 87 | self.images_db = lmdb.open( 88 | path=self.images_db_path, 89 | readonly=True, 90 | readahead=False, 91 | max_spare_txns=128, 92 | lock=False, 93 | ) 94 | 95 | # Set classes 96 | self.classes = list(classes) 97 | self.classes.append("lidar_ray_mask_dense") 98 | self.class2idx = { 99 | name: idx for idx, name in enumerate(self.classes) 100 | } 101 | self.nusc_classes = [ 102 | "vehicle.bus", 103 | "vehicle.bicycle", 104 | "vehicle.car", 105 | "vehicle.construction", 106 | "vehicle.motorcycle", 107 | "vehicle.trailer", 108 | "vehicle.truck", 109 | "human.pedestrian", 110 | ] 111 | self.nuscclass2idx = { 112 | name: idx for idx, name in enumerate(self.nusc_classes) 113 | } 114 | # load FOV mask 115 | self.fov_mask = Image.open( 116 | os.path.join(root, "lmdb", "semantic_maps_new_200x200", "fov_mask.png") 117 | ) 118 | # Make grid 119 | self.grid2d = utils.make_grid2d(grid_size, (-grid_size[0] / 2.0, 0.0), grid_res) 120 | 121 | def __len__(self): 122 | return int(len(self.tokens) * self.dataset_size - 1) 123 | 124 | def __getitem__(self, index): 125 | 126 | 127 | # Load sample 
ID 128 | sample_token = self.tokens[index] 129 | sample_record = self.nusc.get("sample", sample_token) 130 | cam_token = sample_record["data"]["CAM_FRONT"] 131 | cam_record = self.nusc.get("sample_data", cam_token) 132 | cam_path = self.nusc.get_sample_data_path(cam_token) 133 | id = Path(cam_path).stem 134 | 135 | # Load intrinsincs 136 | calib = self.nusc.get( 137 | "calibrated_sensor", cam_record["calibrated_sensor_token"] 138 | )["camera_intrinsic"] 139 | calib = np.array(calib) 140 | 141 | # Load input images 142 | image_input_key = pickle.dumps(id) 143 | with self.images_db.begin() as txn: 144 | value = txn.get(key=image_input_key) 145 | image = Image.open(io.BytesIO(value)).convert(mode='RGB') 146 | 147 | # resize/augment images 148 | image, calib = self.image_calib_pad_and_crop(image, calib) 149 | image = to_tensor(image) 150 | calib = to_tensor(calib).reshape(3, 3) 151 | 152 | # Load ground truth maps 153 | gtmaps_key = [pickle.dumps("{}___{}".format(id, cls)) for cls in self.classes] 154 | with self.gtmaps_db.begin() as txn: 155 | value = [txn.get(key=key) for key in gtmaps_key] 156 | gtmaps = [Image.open(io.BytesIO(im)) for im in value] 157 | 158 | # each map is of shape [1, 200, 200] 159 | mapsdict = {cls: to_tensor(map) for cls, map in zip(self.classes, gtmaps)} 160 | mapsdict["fov_mask"] = to_tensor(self.fov_mask) 161 | mapsdict = self.merge_map_classes(mapsdict) 162 | 163 | # Create visbility mask from lidar and fov masks 164 | lidar_ray_mask = mapsdict['lidar_ray_mask_dense'] 165 | fov_mask = mapsdict['fov_mask'] 166 | vis_mask = lidar_ray_mask * fov_mask 167 | mapsdict['vis_mask'] = vis_mask 168 | 169 | del mapsdict['lidar_ray_mask_dense'], mapsdict['fov_mask'] 170 | 171 | # downsample maps to required output resolution 172 | mapsdict = { 173 | cls: F.interpolate(cls_map.unsqueeze(0), size=self.gt_out_size).squeeze(0) 174 | for cls, cls_map in mapsdict.items() 175 | } 176 | 177 | # apply vis mask to maps 178 | mapsdict = { 179 | cls: cls_map * mapsdict['vis_mask'] for cls, cls_map in mapsdict.items() 180 | } 181 | 182 | cls_maps = torch.cat( 183 | [cls_map for cls, cls_map in mapsdict.items() if 'mask' not in cls], dim=0 184 | ) 185 | vis_mask = mapsdict['vis_mask'] 186 | 187 | return ( 188 | image, cls_maps, vis_mask, calib, self.grid2d 189 | ) 190 | 191 | def merge_map_classes(self, mapsdict): 192 | classes_to_merge = ["drivable_area", "road_segment", "lane"] 193 | merged_class = 'drivable_area' 194 | maps2merge = torch.stack([mapsdict[k] for k in classes_to_merge]) # [n, 1, 200, 200] 195 | maps2merge = maps2merge.sum(dim=0) 196 | maps2merge = (maps2merge > 0).float() 197 | mapsdict[merged_class] = maps2merge 198 | del mapsdict['road_segment'], mapsdict['lane'] 199 | return mapsdict 200 | 201 | def image_calib_pad_and_crop(self, image, calib): 202 | 203 | og_w, og_h = 1600, 900 204 | desired_w, desired_h = self.desired_image_size 205 | scale_w, scale_h = desired_w / og_w, desired_h / og_h 206 | # Scale image 207 | image = image.resize((int(image.size[0] * scale_w), int(image.size[1] * scale_h))) 208 | # Pad images to the same dimensions 209 | w = image.size[0] 210 | h = image.size[1] 211 | delta_w = desired_w - w 212 | delta_h = desired_h - h 213 | pad_left = int(delta_w / 2) 214 | pad_right = delta_w - pad_left 215 | pad_top = int(delta_h / 2) 216 | pad_bottom = delta_h - pad_top 217 | left = 0 - pad_left 218 | right = pad_right + w 219 | top = 0 - pad_top 220 | bottom = pad_bottom + h 221 | image = image.crop((left, top, right, bottom)) 222 | 223 | # Modify calibration 
matrices 224 | # Scale first two rows of calibration matrix 225 | calib[:2, :] *= scale_w 226 | # cx' = cx - du 227 | calib[0, 2] = calib[0, 2] + pad_left 228 | # cy' = cy - dv 229 | calib[1, 2] = calib[1, 2] + pad_top 230 | 231 | return image, calib 232 | 233 | 234 | def read_split(filename): 235 | """ 236 | Read a list of NuScenes sample tokens 237 | """ 238 | with open(filename, "r") as f: 239 | lines = f.read().split("\n") 240 | return [val for val in lines if val != ""] 241 | 242 | 243 | def create_batch_indices_old(split, batch_size, seq_len, n_pred_frames): 244 | nuscenes_root = "/vol/research/sceneEvolution/data/nuscenes" 245 | scene_len_file = os.path.join( 246 | nuscenes_root, "splits", (split + "_with_seq_len.txt") 247 | ) 248 | scene_len = np.array(read_split(scene_len_file), dtype=np.int) 249 | cumsum_scene_len = np.cumsum(scene_len) 250 | zeros = np.zeros(len(cumsum_scene_len) + 1, dtype=np.int) 251 | zeros[1:] = cumsum_scene_len 252 | cumsum_scene_len = zeros 253 | 254 | idxs_batch_start = [] 255 | idxs_batch_end = [] 256 | idxs_batch_pred = [] 257 | 258 | for idx_scene, scene in enumerate(scene_len): 259 | # Split scene length into chunks of size seq_len 260 | nbatches_in_scene = (scene - 1) // seq_len 261 | local_batch_num = (np.arange(nbatches_in_scene, dtype=np.int) + 1) * seq_len 262 | z = np.zeros(len(local_batch_num) + 1, dtype=np.int) 263 | z[1:] = local_batch_num 264 | local_batch_idx = z 265 | 266 | # Add cumsum scene_lengths to get global idx 267 | global_batch_idx = local_batch_idx + cumsum_scene_len[idx_scene] 268 | 269 | start_batch_idx = global_batch_idx[:-1] 270 | end_batch_idx = global_batch_idx[1:] - 1 271 | pred_batch_idx = end_batch_idx + n_pred_frames 272 | pred_batch_idx = np.clip( 273 | pred_batch_idx, a_min=0, a_max=scene - 1 + cumsum_scene_len[idx_scene] 274 | ) 275 | 276 | idxs_batch_start.extend(list(start_batch_idx)) 277 | idxs_batch_end.extend(list(end_batch_idx)) 278 | idxs_batch_pred.extend(list(pred_batch_idx)) 279 | 280 | return idxs_batch_start, idxs_batch_end, idxs_batch_pred 281 | 282 | 283 | def create_batch_indices(split, batch_size, seq_len, n_pred_frames): 284 | nuscenes_root = "/vol/research/sceneEvolution/data/nuscenes" 285 | scene_len_file = os.path.join( 286 | nuscenes_root, "splits", (split + "_with_seq_len.txt") 287 | ) 288 | scene_len = np.array(read_split(scene_len_file), dtype=np.int) 289 | cumsum_scene_len = np.cumsum(scene_len) 290 | zeros = np.zeros(len(cumsum_scene_len) + 1, dtype=np.int) 291 | zeros[1:] = cumsum_scene_len 292 | cumsum_scene_len = zeros 293 | 294 | # Offset for sliding window through sequences 295 | offset = seq_len // 2 296 | 297 | idxs_batch_start = [] 298 | idxs_batch_end = [] 299 | idxs_batch_pred = [] 300 | 301 | for idx_scene, scene in enumerate(scene_len): 302 | # Split scene length into chunks of size seq_len 303 | nbatches_in_scene = (scene - 1) // seq_len 304 | local_batch_num = (np.arange(nbatches_in_scene, dtype=np.int) + 1) * seq_len 305 | z = np.zeros(len(local_batch_num) + 1, dtype=np.int) 306 | z[1:] = local_batch_num 307 | local_batch_idx = z 308 | 309 | # Add cumsum scene_lengths to get global idx 310 | global_batch_idx = local_batch_idx + cumsum_scene_len[idx_scene] 311 | 312 | start_batch_idx = global_batch_idx[:-1] 313 | end_batch_idx = global_batch_idx[1:] - 1 314 | pred_batch_idx = end_batch_idx + n_pred_frames 315 | pred_batch_idx = np.clip( 316 | pred_batch_idx, a_min=0, a_max=scene - 1 + cumsum_scene_len[idx_scene] 317 | ) 318 | 319 | idxs_batch_start.extend(list(start_batch_idx)) 
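# The matching end and prediction indices for these windows are appended next.
# For example, with seq_len=4, a 10-frame scene starting at global index 0 yields
# windows starting at [0, 4] and ending at [3, 7]; the block below then trims the
# first and last window and shifts the rest back by offset = seq_len // 2 to add
# overlapping intermediate sequences.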
320 | idxs_batch_end.extend(list(end_batch_idx)) 321 | idxs_batch_pred.extend(list(pred_batch_idx)) 322 | 323 | # Create intermediate sequences (first trim either end) 324 | start_batch_idx = start_batch_idx[1:-1] - offset 325 | end_batch_idx = end_batch_idx[1:-1] - offset 326 | pred_batch_idx = pred_batch_idx[1:-1] - offset 327 | 328 | idxs_batch_start.extend(list(start_batch_idx)) 329 | idxs_batch_end.extend(list(end_batch_idx)) 330 | idxs_batch_pred.extend(list(pred_batch_idx)) 331 | 332 | return idxs_batch_start, idxs_batch_end, idxs_batch_pred 333 | 334 | 335 | def create_batch_indices_wo_int(split, batch_size, seq_len, n_pred_frames): 336 | nuscenes_root = "/vol/research/sceneEvolution/data/nuscenes" 337 | scene_len_file = os.path.join( 338 | nuscenes_root, "splits", (split + "_with_seq_len.txt") 339 | ) 340 | scene_len = np.array(read_split(scene_len_file), dtype=np.int) 341 | cumsum_scene_len = np.cumsum(scene_len) 342 | zeros = np.zeros(len(cumsum_scene_len) + 1, dtype=np.int) 343 | zeros[1:] = cumsum_scene_len 344 | cumsum_scene_len = zeros 345 | 346 | # Offset for sliding window through sequences 347 | offset = seq_len // 2 348 | 349 | idxs_batch_start = [] 350 | idxs_batch_end = [] 351 | idxs_batch_pred = [] 352 | 353 | for idx_scene, scene in enumerate(scene_len): 354 | # Split scene length into chunks of size seq_len 355 | nbatches_in_scene = (scene - 1) // seq_len 356 | local_batch_num = (np.arange(nbatches_in_scene, dtype=np.int) + 1) * seq_len 357 | z = np.zeros(len(local_batch_num) + 1, dtype=np.int) 358 | z[1:] = local_batch_num 359 | local_batch_idx = z 360 | 361 | # Add cumsum scene_lengths to get global idx 362 | global_batch_idx = local_batch_idx + cumsum_scene_len[idx_scene] 363 | 364 | start_batch_idx = global_batch_idx[:-1] 365 | end_batch_idx = global_batch_idx[1:] - 1 366 | pred_batch_idx = end_batch_idx + n_pred_frames 367 | pred_batch_idx = np.clip( 368 | pred_batch_idx, a_min=0, a_max=scene - 1 + cumsum_scene_len[idx_scene] 369 | ) 370 | 371 | idxs_batch_start.extend(list(start_batch_idx)) 372 | idxs_batch_end.extend(list(end_batch_idx)) 373 | idxs_batch_pred.extend(list(pred_batch_idx)) 374 | 375 | # # Create intermediate sequences (first trim either end) 376 | # start_batch_idx = start_batch_idx[1:-1] - offset 377 | # end_batch_idx = end_batch_idx[1:-1] - offset 378 | # pred_batch_idx = pred_batch_idx[1:-1] - offset 379 | # 380 | # idxs_batch_start.extend(list(start_batch_idx)) 381 | # idxs_batch_end.extend(list(end_batch_idx)) 382 | # idxs_batch_pred.extend(list(pred_batch_idx)) 383 | 384 | return idxs_batch_start, idxs_batch_end, idxs_batch_pred 385 | 386 | 387 | def create_batch_indices2(split, seq_len): 388 | nuscenes_root = "/vol/research/sceneEvolution/data/nuscenes" 389 | scene_len_file = os.path.join( 390 | nuscenes_root, "splits", (split + "_with_seq_len.txt") 391 | ) 392 | scene_len = np.array(read_split(scene_len_file), dtype=np.int) 393 | cumsum_scene_len = np.cumsum(scene_len) 394 | zeros = np.zeros(len(cumsum_scene_len) + 1, dtype=np.int) 395 | zeros[1:] = cumsum_scene_len 396 | cumsum_scene_len = zeros 397 | 398 | idxs_batch_start = [] 399 | for idx_scene, scene in enumerate(scene_len): 400 | # Split scene length into chunks of size seq_len 401 | nbatches_in_scene = (scene - 1) // seq_len 402 | local_batch_num = (np.arange(nbatches_in_scene, dtype=np.int) + 1) * seq_len 403 | z = np.zeros(len(local_batch_num) + 1, dtype=np.int) 404 | z[1:] = local_batch_num 405 | local_batch_idx = z 406 | 407 | # Add cumsum scene_lengths to get global idx 408 
| global_batch_idx = local_batch_idx + cumsum_scene_len[idx_scene] 409 | idxs_batch_start.extend(list(global_batch_idx)) 410 | 411 | idxs_batch_end = list(np.array(idxs_batch_start, dtype=np.int) - 1) 412 | 413 | return idxs_batch_start, idxs_batch_end 414 | 415 | 416 | def create_batch_indices_mini(split, batch_size, seq_len, n_pred_frames): 417 | nuscenes_root = "/vol/research/sceneEvolution/data/nuscenes" 418 | scene_len_file = os.path.join( 419 | nuscenes_root, "splits", (split + "_mini_with_seq_len.txt") 420 | ) 421 | scene_len = np.array(read_split(scene_len_file), dtype=np.int) 422 | cumsum_scene_len = np.cumsum(scene_len) 423 | zeros = np.zeros(len(cumsum_scene_len) + 1, dtype=np.int) 424 | zeros[1:] = cumsum_scene_len 425 | cumsum_scene_len = zeros 426 | 427 | idxs_batch_start = [] 428 | idxs_batch_end = [] 429 | idxs_batch_pred = [] 430 | 431 | for idx_scene, scene in enumerate(scene_len): 432 | # Split scene length into chunks of size seq_len 433 | nbatches_in_scene = (scene - 1) // seq_len 434 | local_batch_num = (np.arange(nbatches_in_scene, dtype=np.int) + 1) * seq_len 435 | z = np.zeros(len(local_batch_num) + 1, dtype=np.int) 436 | z[1:] = local_batch_num 437 | local_batch_idx = z 438 | 439 | # Add cumsum scene_lengths to get global idx 440 | global_batch_idx = local_batch_idx + cumsum_scene_len[idx_scene] 441 | 442 | start_batch_idx = global_batch_idx[:-1] 443 | end_batch_idx = global_batch_idx[1:] - 1 444 | pred_batch_idx = end_batch_idx + n_pred_frames 445 | pred_batch_idx = np.clip( 446 | pred_batch_idx, a_min=0, a_max=scene - 1 + cumsum_scene_len[idx_scene] 447 | ) 448 | 449 | idxs_batch_start.extend(list(start_batch_idx)) 450 | idxs_batch_end.extend(list(end_batch_idx)) 451 | idxs_batch_pred.extend(list(pred_batch_idx)) 452 | 453 | return idxs_batch_start, idxs_batch_end, idxs_batch_pred 454 | 455 | -------------------------------------------------------------------------------- /src/data/dataloader_new.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.data import DataLoader, RandomSampler 3 | from .augmentations import AugmentedMapDataset 4 | 5 | from nuscenes import NuScenes 6 | from .nuscenes.dataset import NuScenesMapDataset 7 | from .nuscenes.splits import TRAIN_SCENES, VAL_SCENES, CALIBRATION_SCENES 8 | 9 | 10 | def build_nuscenes_datasets(config): 11 | print('==> Loading NuScenes dataset...') 12 | nuscenes = NuScenes(config.nuscenes_version, 13 | os.path.expandvars(config.dataroot)) 14 | 15 | # Exclude calibration scenes 16 | if config.hold_out_calibration: 17 | train_scenes = list(set(TRAIN_SCENES) - set(CALIBRATION_SCENES)) 18 | else: 19 | train_scenes = TRAIN_SCENES 20 | 21 | train_data = NuScenesMapDataset(nuscenes, config.label_root, 22 | config.img_size, train_scenes) 23 | val_data = NuScenesMapDataset(nuscenes, config.label_root, 24 | config.img_size, VAL_SCENES) 25 | return train_data, val_data 26 | 27 | def build_datasets(dataset_name, config): 28 | if dataset_name == 'nuscenes': 29 | return build_nuscenes_datasets(config) 30 | else: 31 | raise ValueError(f"Unknown dataset option '{dataset_name}'") 32 | 33 | 34 | def build_trainval_datasets(dataset_name, config): 35 | # Construct the base dataset 36 | train_data, val_data = build_datasets(dataset_name, config) 37 | 38 | # Add data augmentation to train dataset 39 | train_data = AugmentedMapDataset(train_data, config.hflip) 40 | 41 | return train_data, val_data 42 | 43 | 44 | def build_dataloaders(dataset_name, config): 45 | # Build 
training and validation datasets 46 | train_data, val_data = build_trainval_datasets(dataset_name, config) 47 | 48 | # Create training set dataloader 49 | sampler = RandomSampler(train_data, True, config.epoch_size) 50 | train_loader = DataLoader(train_data, config.batch_size, sampler=sampler, 51 | num_workers=config.num_workers) 52 | 53 | # Create validation dataloader 54 | val_loader = DataLoader(val_data, config.batch_size, 55 | num_workers=config.num_workers) 56 | 57 | return train_loader, val_loader 58 | 59 | 60 | 61 | 62 | 63 | 64 | -------------------------------------------------------------------------------- /src/data/nuscenes/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avishkarsaha/translating-images-into-maps/a452fe7a1eb0062f96133d14380d504e7bc0c8d1/src/data/nuscenes/__init__.py -------------------------------------------------------------------------------- /src/data/nuscenes/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.data import Dataset 4 | from PIL import Image, ImageFile 5 | from nuscenes import NuScenes 6 | from torchvision.transforms.functional import to_tensor 7 | 8 | from .utils import CAMERA_NAMES, NUSCENES_CLASS_NAMES, iterate_samples 9 | from ..utils import decode_binary_labels 10 | 11 | class NuScenesMapDataset(Dataset): 12 | 13 | def __init__(self, nuscenes, map_root, image_size=(800, 450), 14 | scene_names=None): 15 | 16 | self.nuscenes = nuscenes 17 | self.map_root = os.path.expandvars(map_root) 18 | self.image_size = image_size 19 | 20 | # Preload the list of tokens in the dataset 21 | self.get_tokens(scene_names) 22 | 23 | # Allow PIL to load partially corrupted images 24 | # (otherwise training crashes at the most inconvenient possible times!) 
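# With this flag set, PIL decodes as much of a truncated image file as it can and
# returns the partial image instead of raising an error.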
25 | ImageFile.LOAD_TRUNCATED_IMAGES = True 26 | 27 | 28 | def get_tokens(self, scene_names=None): 29 | 30 | self.tokens = list() 31 | 32 | # Iterate over scenes 33 | for scene in self.nuscenes.scene: 34 | 35 | # Ignore scenes which don't belong to the current split 36 | if scene_names is not None and scene['name'] not in scene_names: 37 | continue 38 | 39 | # Iterate over samples 40 | for sample in iterate_samples(self.nuscenes, 41 | scene['first_sample_token']): 42 | 43 | # Iterate over cameras 44 | for camera in CAMERA_NAMES: 45 | self.tokens.append(sample['data'][camera]) 46 | 47 | return self.tokens 48 | 49 | 50 | def __len__(self): 51 | return len(self.tokens) 52 | 53 | def __getitem__(self, index): 54 | token = self.tokens[index] 55 | 56 | image = self.load_image(token) 57 | calib = self.load_calib(token) 58 | labels, mask = self.load_labels(token) 59 | 60 | return image, calib, labels, mask 61 | 62 | 63 | def load_image(self, token): 64 | 65 | # Load image as a PIL image 66 | image = Image.open(self.nuscenes.get_sample_data_path(token)) 67 | 68 | # Resize to input resolution 69 | image = image.resize(self.image_size) 70 | 71 | # Convert to a torch tensor 72 | return to_tensor(image) 73 | 74 | 75 | def load_calib(self, token): 76 | 77 | # Load camera intrinsics matrix 78 | sample_data = self.nuscenes.get('sample_data', token) 79 | sensor = self.nuscenes.get( 80 | 'calibrated_sensor', sample_data['calibrated_sensor_token']) 81 | intrinsics = torch.tensor(sensor['camera_intrinsic']) 82 | 83 | # Scale calibration matrix to account for image downsampling 84 | intrinsics[0] *= self.image_size[0] / sample_data['width'] 85 | intrinsics[1] *= self.image_size[1] / sample_data['height'] 86 | return intrinsics 87 | 88 | 89 | def load_labels(self, token): 90 | 91 | # Load label image as a torch tensor 92 | label_path = os.path.join(self.map_root, token + '.png') 93 | encoded_labels = to_tensor(Image.open(label_path)).long() 94 | 95 | # Decode to binary labels 96 | num_class = len(NUSCENES_CLASS_NAMES) 97 | labels = decode_binary_labels(encoded_labels, num_class + 1) 98 | labels, mask = labels[:-1], ~labels[-1] 99 | 100 | return labels, mask 101 | 102 | 103 | 104 | -------------------------------------------------------------------------------- /src/data/nuscenes/splits.py: -------------------------------------------------------------------------------- 1 | 2 | TRAIN_SCENES = [ 3 | "scene-0002", "scene-0003", "scene-0004", "scene-0005", "scene-0006", 4 | "scene-0007", "scene-0008", "scene-0009", "scene-0012", "scene-0013", 5 | "scene-0014", "scene-0015", "scene-0016", "scene-0017", "scene-0018", 6 | "scene-0019", "scene-0021", "scene-0022", "scene-0023", "scene-0024", 7 | "scene-0025", "scene-0026", "scene-0027", "scene-0028", "scene-0029", 8 | "scene-0030", "scene-0031", "scene-0032", "scene-0033", "scene-0034", 9 | "scene-0035", "scene-0036", "scene-0039", "scene-0042", "scene-0043", 10 | "scene-0044", "scene-0045", "scene-0046", "scene-0047", "scene-0048", 11 | "scene-0049", "scene-0050", "scene-0051", "scene-0052", "scene-0055", 12 | "scene-0056", "scene-0057", "scene-0058", "scene-0059", "scene-0060", 13 | "scene-0061", "scene-0062", "scene-0063", "scene-0064", "scene-0065", 14 | "scene-0066", "scene-0067", "scene-0068", "scene-0069", "scene-0070", 15 | "scene-0071", "scene-0072", "scene-0073", "scene-0074", "scene-0075", 16 | "scene-0076", "scene-0092", "scene-0093", "scene-0094", "scene-0095", 17 | "scene-0096", "scene-0097", "scene-0098", "scene-0099", "scene-0100", 18 | "scene-0101", 
"scene-0102", "scene-0103", "scene-0104", "scene-0105", 19 | "scene-0106", "scene-0107", "scene-0108", "scene-0109", "scene-0110", 20 | "scene-0120", "scene-0123", "scene-0124", "scene-0125", "scene-0126", 21 | "scene-0127", "scene-0128", "scene-0129", "scene-0130", "scene-0131", 22 | "scene-0132", "scene-0133", "scene-0134", "scene-0135", "scene-0138", 23 | "scene-0149", "scene-0150", "scene-0151", "scene-0154", "scene-0155", 24 | "scene-0157", "scene-0158", "scene-0159", "scene-0161", "scene-0162", 25 | "scene-0163", "scene-0164", "scene-0165", "scene-0166", "scene-0167", 26 | "scene-0168", "scene-0170", "scene-0171", "scene-0172", "scene-0173", 27 | "scene-0174", "scene-0175", "scene-0176", "scene-0177", "scene-0178", 28 | "scene-0179", "scene-0180", "scene-0181", "scene-0182", "scene-0183", 29 | "scene-0185", "scene-0187", "scene-0188", "scene-0190", "scene-0191", 30 | "scene-0192", "scene-0193", "scene-0194", "scene-0195", "scene-0196", 31 | "scene-0199", "scene-0200", "scene-0202", "scene-0203", "scene-0204", 32 | "scene-0206", "scene-0207", "scene-0208", "scene-0209", "scene-0210", 33 | "scene-0211", "scene-0212", "scene-0213", "scene-0214", "scene-0218", 34 | "scene-0219", "scene-0220", "scene-0221", "scene-0222", "scene-0224", 35 | "scene-0225", "scene-0226", "scene-0227", "scene-0228", "scene-0229", 36 | "scene-0230", "scene-0231", "scene-0232", "scene-0233", "scene-0234", 37 | "scene-0235", "scene-0236", "scene-0237", "scene-0238", "scene-0239", 38 | "scene-0240", "scene-0241", "scene-0242", "scene-0243", "scene-0244", 39 | "scene-0245", "scene-0246", "scene-0247", "scene-0248", "scene-0249", 40 | "scene-0250", "scene-0251", "scene-0252", "scene-0253", "scene-0254", 41 | "scene-0255", "scene-0256", "scene-0257", "scene-0258", "scene-0259", 42 | "scene-0260", "scene-0261", "scene-0262", "scene-0263", "scene-0264", 43 | "scene-0268", "scene-0270", "scene-0271", "scene-0272", "scene-0273", 44 | "scene-0274", "scene-0275", "scene-0276", "scene-0277", "scene-0278", 45 | "scene-0283", "scene-0284", "scene-0285", "scene-0286", "scene-0287", 46 | "scene-0288", "scene-0289", "scene-0290", "scene-0291", "scene-0292", 47 | "scene-0293", "scene-0294", "scene-0295", "scene-0296", "scene-0297", 48 | "scene-0298", "scene-0299", "scene-0300", "scene-0301", "scene-0302", 49 | "scene-0303", "scene-0304", "scene-0305", "scene-0306", "scene-0315", 50 | "scene-0316", "scene-0317", "scene-0318", "scene-0321", "scene-0323", 51 | "scene-0324", "scene-0328", "scene-0329", "scene-0330", "scene-0331", 52 | "scene-0332", "scene-0344", "scene-0345", "scene-0346", "scene-0349", 53 | "scene-0350", "scene-0351", "scene-0352", "scene-0353", "scene-0354", 54 | "scene-0355", "scene-0356", "scene-0357", "scene-0358", "scene-0359", 55 | "scene-0360", "scene-0361", "scene-0362", "scene-0363", "scene-0364", 56 | "scene-0365", "scene-0367", "scene-0370", "scene-0371", "scene-0372", 57 | "scene-0373", "scene-0374", "scene-0375", "scene-0376", "scene-0377", 58 | "scene-0379", "scene-0380", "scene-0381", "scene-0382", "scene-0383", 59 | "scene-0384", "scene-0385", "scene-0386", "scene-0388", "scene-0399", 60 | "scene-0400", "scene-0401", "scene-0402", "scene-0403", "scene-0405", 61 | "scene-0406", "scene-0407", "scene-0408", "scene-0420", "scene-0421", 62 | "scene-0422", "scene-0423", "scene-0424", "scene-0425", "scene-0426", 63 | "scene-0427", "scene-0428", "scene-0429", "scene-0430", "scene-0431", 64 | "scene-0432", "scene-0433", "scene-0434", "scene-0435", "scene-0436", 65 | "scene-0437", "scene-0438", "scene-0439", 
"scene-0440", "scene-0441", 66 | "scene-0442", "scene-0443", "scene-0444", "scene-0445", "scene-0446", 67 | "scene-0447", "scene-0448", "scene-0449", "scene-0450", "scene-0451", 68 | "scene-0452", "scene-0453", "scene-0454", "scene-0455", "scene-0456", 69 | "scene-0457", "scene-0458", "scene-0459", "scene-0461", "scene-0462", 70 | "scene-0463", "scene-0464", "scene-0465", "scene-0467", "scene-0468", 71 | "scene-0469", "scene-0471", "scene-0472", "scene-0474", "scene-0475", 72 | "scene-0476", "scene-0477", "scene-0478", "scene-0479", "scene-0480", 73 | "scene-0499", "scene-0500", "scene-0501", "scene-0502", "scene-0504", 74 | "scene-0505", "scene-0506", "scene-0507", "scene-0508", "scene-0509", 75 | "scene-0510", "scene-0511", "scene-0512", "scene-0513", "scene-0514", 76 | "scene-0515", "scene-0517", "scene-0518", "scene-0519", "scene-0520", 77 | "scene-0521", "scene-0522", "scene-0523", "scene-0524", "scene-0552", 78 | "scene-0553", "scene-0554", "scene-0555", "scene-0559", "scene-0560", 79 | "scene-0561", "scene-0562", "scene-0563", "scene-0564", "scene-0565", 80 | "scene-0584", "scene-0585", "scene-0586", "scene-0587", "scene-0588", 81 | "scene-0589", "scene-0590", "scene-0591", "scene-0592", "scene-0593", 82 | "scene-0594", "scene-0595", "scene-0596", "scene-0597", "scene-0598", 83 | "scene-0599", "scene-0600", "scene-0625", "scene-0626", "scene-0627", 84 | "scene-0629", "scene-0630", "scene-0632", "scene-0633", "scene-0634", 85 | "scene-0635", "scene-0636", "scene-0637", "scene-0638", "scene-0639", 86 | "scene-0640", "scene-0652", "scene-0653", "scene-0654", "scene-0655", 87 | "scene-0656", "scene-0657", "scene-0658", "scene-0659", "scene-0660", 88 | "scene-0661", "scene-0662", "scene-0663", "scene-0664", "scene-0665", 89 | "scene-0666", "scene-0667", "scene-0668", "scene-0669", "scene-0670", 90 | "scene-0671", "scene-0672", "scene-0673", "scene-0674", "scene-0675", 91 | "scene-0676", "scene-0677", "scene-0678", "scene-0679", "scene-0681", 92 | "scene-0683", "scene-0684", "scene-0685", "scene-0686", "scene-0687", 93 | "scene-0688", "scene-0689", "scene-0695", "scene-0696", "scene-0697", 94 | "scene-0698", "scene-0700", "scene-0701", "scene-0703", "scene-0704", 95 | "scene-0705", "scene-0706", "scene-0707", "scene-0708", "scene-0709", 96 | "scene-0710", "scene-0711", "scene-0712", "scene-0713", "scene-0714", 97 | "scene-0715", "scene-0716", "scene-0717", "scene-0718", "scene-0719", 98 | "scene-0726", "scene-0727", "scene-0728", "scene-0730", "scene-0731", 99 | "scene-0733", "scene-0734", "scene-0735", "scene-0736", "scene-0737", 100 | "scene-0738", "scene-0780", "scene-0781", "scene-0782", "scene-0783", 101 | "scene-0784", "scene-0786", "scene-0787", "scene-0789", "scene-0790", 102 | "scene-0791", "scene-0792", "scene-0802", "scene-0806", "scene-0808", 103 | "scene-0809", "scene-0810", "scene-0811", "scene-0812", "scene-0813", 104 | "scene-0815", "scene-0816", "scene-0817", "scene-0819", "scene-0820", 105 | "scene-0821", "scene-0822", "scene-0847", "scene-0848", "scene-0849", 106 | "scene-0850", "scene-0851", "scene-0852", "scene-0853", "scene-0854", 107 | "scene-0855", "scene-0856", "scene-0858", "scene-0860", "scene-0861", 108 | "scene-0862", "scene-0863", "scene-0864", "scene-0865", "scene-0866", 109 | "scene-0868", "scene-0869", "scene-0870", "scene-0871", "scene-0872", 110 | "scene-0873", "scene-0875", "scene-0876", "scene-0877", "scene-0878", 111 | "scene-0880", "scene-0882", "scene-0883", "scene-0884", "scene-0885", 112 | "scene-0886", "scene-0887", "scene-0888", "scene-0889", 
"scene-0890", 113 | "scene-0891", "scene-0892", "scene-0893", "scene-0894", "scene-0895", 114 | "scene-0896", "scene-0897", "scene-0898", "scene-0899", "scene-0900", 115 | "scene-0901", "scene-0902", "scene-0903", "scene-0904", "scene-0905", 116 | "scene-0906", "scene-0907", "scene-0908", "scene-0909", "scene-0916", 117 | "scene-0917", "scene-0921", "scene-0922", "scene-0923", "scene-0925", 118 | "scene-0926", "scene-0927", "scene-0928", "scene-0929", "scene-0930", 119 | "scene-0931", "scene-0945", "scene-0947", "scene-0949", "scene-0952", 120 | "scene-0953", "scene-0955", "scene-0956", "scene-0957", "scene-0958", 121 | "scene-0959", "scene-0960", "scene-0961", "scene-0966", "scene-0967", 122 | "scene-0968", "scene-0969", "scene-0971", "scene-0972", "scene-0975", 123 | "scene-0976", "scene-0977", "scene-0978", "scene-0979", "scene-0980", 124 | "scene-0981", "scene-0982", "scene-0983", "scene-0984", "scene-0988", 125 | "scene-0989", "scene-0990", "scene-0991", "scene-0992", "scene-0994", 126 | "scene-0995", "scene-0996", "scene-0997", "scene-0998", "scene-0999", 127 | "scene-1000", "scene-1001", "scene-1004", "scene-1005", "scene-1006", 128 | "scene-1007", "scene-1008", "scene-1009", "scene-1010", "scene-1011", 129 | "scene-1012", "scene-1013", "scene-1014", "scene-1015", "scene-1019", 130 | "scene-1020", "scene-1021", "scene-1022", "scene-1023", "scene-1024", 131 | "scene-1025", "scene-1044", "scene-1045", "scene-1046", "scene-1047", 132 | "scene-1048", "scene-1049", "scene-1050", "scene-1051", "scene-1052", 133 | "scene-1053", "scene-1054", "scene-1064", "scene-1065", "scene-1066", 134 | "scene-1067", "scene-1068", "scene-1069", "scene-1070", "scene-1071", 135 | "scene-1072", "scene-1073", "scene-1074", "scene-1075", "scene-1076", 136 | "scene-1077", "scene-1078", "scene-1079", "scene-1080", "scene-1081", 137 | "scene-1082", "scene-1083", "scene-1084", "scene-1085", "scene-1086", 138 | "scene-1087", "scene-1088", "scene-1089", "scene-1090", "scene-1091", 139 | "scene-1092", "scene-1093", "scene-1094", "scene-1095", "scene-1096", 140 | "scene-1097", "scene-1098", "scene-1099", "scene-1100", "scene-1101", 141 | "scene-1102", "scene-1104", "scene-1105", "scene-1106", "scene-1107", 142 | "scene-1108", "scene-1109", "scene-1110"] 143 | 144 | VAL_SCENES = [ 145 | "scene-0001", "scene-0010", "scene-0011", "scene-0020", "scene-0038", 146 | "scene-0041", "scene-0053", "scene-0054", "scene-0121", "scene-0122", 147 | "scene-0139", "scene-0152", "scene-0160", "scene-0184", "scene-0269", 148 | "scene-0347", "scene-0348", "scene-0366", "scene-0368", "scene-0369", 149 | "scene-0378", "scene-0389", "scene-0390", "scene-0391", "scene-0392", 150 | "scene-0393", "scene-0394", "scene-0395", "scene-0396", "scene-0397", 151 | "scene-0398", "scene-0411", "scene-0412", "scene-0413", "scene-0414", 152 | "scene-0415", "scene-0416", "scene-0417", "scene-0418", "scene-0419", 153 | "scene-0525", "scene-0526", "scene-0527", "scene-0528", "scene-0529", 154 | "scene-0530", "scene-0531", "scene-0532", "scene-0533", "scene-0534", 155 | "scene-0535", "scene-0536", "scene-0537", "scene-0538", "scene-0539", 156 | "scene-0541", "scene-0542", "scene-0543", "scene-0544", "scene-0545", 157 | "scene-0546", "scene-0556", "scene-0557", "scene-0558", "scene-0566", 158 | "scene-0568", "scene-0570", "scene-0571", "scene-0572", "scene-0573", 159 | "scene-0574", "scene-0575", "scene-0576", "scene-0577", "scene-0578", 160 | "scene-0580", "scene-0582", "scene-0583", "scene-0642", "scene-0643", 161 | "scene-0644", "scene-0645", 
"scene-0646", "scene-0647", "scene-0648", 162 | "scene-0649", "scene-0650", "scene-0651", "scene-0739", "scene-0740", 163 | "scene-0741", "scene-0744", "scene-0746", "scene-0747", "scene-0749", 164 | "scene-0750", "scene-0751", "scene-0752", "scene-0757", "scene-0758", 165 | "scene-0759", "scene-0760", "scene-0761", "scene-0762", "scene-0763", 166 | "scene-0764", "scene-0765", "scene-0767", "scene-0768", "scene-0769", 167 | "scene-0770", "scene-0771", "scene-0775", "scene-0777", "scene-0778", 168 | "scene-0794", "scene-0795", "scene-0796", "scene-0797", "scene-0798", 169 | "scene-0799", "scene-0800", "scene-0803", "scene-0804", "scene-0911", 170 | "scene-0912", "scene-0913", "scene-0914", "scene-0915", "scene-0919", 171 | "scene-0920", "scene-0924", "scene-0962", "scene-0963", "scene-1002", 172 | "scene-1003", "scene-1016", "scene-1017", "scene-1018", "scene-1055", 173 | "scene-1056", "scene-1057", "scene-1058", "scene-1059", "scene-1060", 174 | "scene-1061", "scene-1062", "scene-1063"] 175 | 176 | 177 | CALIBRATION_SCENES = [ 178 | "scene-0852", "scene-0429", "scene-0956", "scene-0194", "scene-0811", 179 | "scene-1110", "scene-1107", "scene-0294", "scene-0900", "scene-0596", 180 | "scene-0296", "scene-0885", "scene-0866", "scene-0105", "scene-0782", 181 | "scene-0191", "scene-0876", "scene-0133", "scene-0231", "scene-0847", 182 | "scene-0363", "scene-0026", "scene-0791", "scene-0909", "scene-0002", 183 | "scene-0283", "scene-0007", "scene-0251", "scene-1100", "scene-0668", 184 | "scene-0584", "scene-0287", "scene-0260", "scene-0171", "scene-0789", 185 | "scene-0108", "scene-0190", "scene-0206", "scene-0635", "scene-0815", 186 | "scene-0058", "scene-0710", "scene-0302", "scene-0639", "scene-0166", 187 | "scene-0094", "scene-0735", "scene-0321", "scene-1091", "scene-0344" 188 | ] -------------------------------------------------------------------------------- /src/data/nuscenes/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from shapely import geometry, affinity 4 | from pyquaternion import Quaternion 5 | 6 | from nuscenes.eval.detection.utils import category_to_detection_name 7 | from nuscenes.eval.detection.constants import DETECTION_NAMES 8 | from nuscenes.utils.data_classes import LidarPointCloud 9 | 10 | from ..utils import transform_polygon, render_polygon, transform 11 | 12 | CAMERA_NAMES = ['CAM_FRONT', 'CAM_FRONT_LEFT', 'CAM_FRONT_RIGHT', 13 | 'CAM_BACK_LEFT', 'CAM_BACK_RIGHT', 'CAM_BACK'] 14 | 15 | NUSCENES_CLASS_NAMES = [ 16 | 'drivable_area', 'ped_crossing', 'walkway', 'carpark', 'car', 'truck', 17 | 'bus', 'trailer', 'construction_vehicle', 'pedestrian', 'motorcycle', 18 | 'bicycle', 'traffic_cone', 'barrier' 19 | ] 20 | 21 | STATIC_CLASSES = ['drivable_area', 'ped_crossing', 'walkway', 'carpark_area'] 22 | 23 | LOCATIONS = ['boston-seaport', 'singapore-onenorth', 'singapore-queenstown', 24 | 'singapore-hollandvillage'] 25 | 26 | 27 | def iterate_samples(nuscenes, start_token): 28 | sample_token = start_token 29 | while sample_token != '': 30 | sample = nuscenes.get('sample', sample_token) 31 | yield sample 32 | sample_token = sample['next'] 33 | 34 | 35 | def get_map_masks(nuscenes, map_data, sample_data, extents, resolution): 36 | 37 | # Render each layer sequentially 38 | layers = [get_layer_mask(nuscenes, polys, sample_data, extents, 39 | resolution) for layer, polys in map_data.items()] 40 | 41 | return np.stack(layers, axis=0) 42 | 43 | 44 | def get_layer_mask(nuscenes, polygons, 
sample_data, extents, resolution): 45 | 46 | # Get the 2D affine transform from bev coords to map coords 47 | tfm = get_sensor_transform(nuscenes, sample_data)[[0, 1, 3]][:, [0, 2, 3]] 48 | inv_tfm = np.linalg.inv(tfm) 49 | 50 | # Create a patch representing the birds-eye-view region in map coordinates 51 | map_patch = geometry.box(*extents) 52 | map_patch = transform_polygon(map_patch, tfm) 53 | 54 | # Initialise the map mask 55 | x1, z1, x2, z2 = extents 56 | mask = np.zeros((int((z2 - z1) / resolution), int((x2 - x1) / resolution)), 57 | dtype=np.uint8) 58 | 59 | # Find all polygons which intersect with the area of interest 60 | for polygon in polygons.query(map_patch): 61 | 62 | polygon = polygon.intersection(map_patch) 63 | 64 | # Transform into map coordinates 65 | polygon = transform_polygon(polygon, inv_tfm) 66 | 67 | # Render the polygon to the mask 68 | render_shapely_polygon(mask, polygon, extents, resolution) 69 | 70 | return mask.astype(np.bool) 71 | 72 | 73 | 74 | 75 | def get_object_masks(nuscenes, sample_data, extents, resolution): 76 | 77 | # Initialize object masks 78 | nclass = len(DETECTION_NAMES) + 1 79 | grid_width = int((extents[2] - extents[0]) / resolution) 80 | grid_height = int((extents[3] - extents[1]) / resolution) 81 | masks = np.zeros((nclass, grid_height, grid_width), dtype=np.uint8) 82 | 83 | # Get the 2D affine transform from bev coords to map coords 84 | tfm = get_sensor_transform(nuscenes, sample_data)[[0, 1, 3]][:, [0, 2, 3]] 85 | inv_tfm = np.linalg.inv(tfm) 86 | 87 | for box in nuscenes.get_boxes(sample_data['token']): 88 | 89 | # Get the index of the class 90 | det_name = category_to_detection_name(box.name) 91 | if det_name not in DETECTION_NAMES: 92 | class_id = -1 93 | else: 94 | class_id = DETECTION_NAMES.index(det_name) 95 | 96 | # Get bounding box coordinates in the grid coordinate frame 97 | bbox = box.bottom_corners()[:2] 98 | local_bbox = np.dot(inv_tfm[:2, :2], bbox).T + inv_tfm[:2, 2] 99 | 100 | # Render the rotated bounding box to the mask 101 | render_polygon(masks[class_id], local_bbox, extents, resolution) 102 | 103 | return masks.astype(np.bool) 104 | 105 | 106 | def get_sensor_transform(nuscenes, sample_data): 107 | 108 | # Load sensor transform data 109 | sensor = nuscenes.get( 110 | 'calibrated_sensor', sample_data['calibrated_sensor_token']) 111 | sensor_tfm = make_transform_matrix(sensor) 112 | 113 | # Load ego pose data 114 | pose = nuscenes.get('ego_pose', sample_data['ego_pose_token']) 115 | pose_tfm = make_transform_matrix(pose) 116 | 117 | return np.dot(pose_tfm, sensor_tfm) 118 | 119 | 120 | def load_point_cloud(nuscenes, sample_data): 121 | 122 | # Load point cloud 123 | lidar_path = os.path.join(nuscenes.dataroot, sample_data['filename']) 124 | pcl = LidarPointCloud.from_file(lidar_path) 125 | return pcl.points[:3, :].T 126 | 127 | 128 | def make_transform_matrix(record): 129 | """ 130 | Create a 4x4 transform matrix from a calibrated_sensor or ego_pose record 131 | """ 132 | transform = np.eye(4) 133 | transform[:3, :3] = Quaternion(record['rotation']).rotation_matrix 134 | transform[:3, 3] = np.array(record['translation']) 135 | return transform 136 | 137 | 138 | def render_shapely_polygon(mask, polygon, extents, resolution): 139 | 140 | if polygon.geom_type == 'Polygon': 141 | 142 | # Render exteriors 143 | render_polygon(mask, polygon.exterior.coords, extents, resolution, 1) 144 | 145 | # Render interiors 146 | for hole in polygon.interiors: 147 | render_polygon(mask, hole.coords, extents, resolution, 0) 148 | 149 | # 
Handle the case of compound shapes 150 | else: 151 | for poly in polygon: 152 | render_shapely_polygon(mask, poly, extents, resolution) 153 | 154 | 155 | 156 | 157 | -------------------------------------------------------------------------------- /src/data/utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | from shapely import affinity 5 | 6 | def decode_binary_labels(labels, nclass): 7 | bits = torch.pow(2, torch.arange(nclass)) 8 | return (labels & bits.view(-1, 1, 1)) > 0 9 | 10 | 11 | def encode_binary_labels(masks): 12 | bits = np.power(2, np.arange(len(masks), dtype=np.int32)) 13 | return (masks.astype(np.int32) * bits.reshape(-1, 1, 1)).sum(0) 14 | 15 | 16 | def transform(matrix, vectors): 17 | vectors = np.dot(matrix[:-1, :-1], vectors.T) 18 | vectors = vectors.T + matrix[:-1, -1] 19 | return vectors 20 | 21 | 22 | def transform_polygon(polygon, affine): 23 | """ 24 | Transform a 2D polygon 25 | """ 26 | a, b, tx, c, d, ty = affine.flatten()[:6] 27 | return affinity.affine_transform(polygon, [a, b, c, d, tx, ty]) 28 | 29 | 30 | def render_polygon(mask, polygon, extents, resolution, value=1): 31 | if len(polygon) == 0: 32 | return 33 | polygon = (polygon - np.array(extents[:2])) / resolution 34 | polygon = np.ascontiguousarray(polygon).round().astype(np.int32) 35 | cv2.fillConvexPoly(mask, polygon, value) 36 | 37 | 38 | def get_visible_mask(instrinsics, image_width, extents, resolution): 39 | 40 | # Get calibration parameters 41 | fu, cu = instrinsics[0, 0], instrinsics[0, 2] 42 | 43 | # Construct a grid of image coordinates 44 | x1, z1, x2, z2 = extents 45 | x, z = np.arange(x1, x2, resolution), np.arange(z1, z2, resolution) 46 | ucoords = x / z[:, None] * fu + cu 47 | 48 | # Return all points which lie within the camera bounds 49 | return (ucoords >= 0) & (ucoords < image_width) 50 | 51 | 52 | def get_occlusion_mask(points, extents, resolution): 53 | 54 | x1, z1, x2, z2 = extents 55 | 56 | # A 'ray' is defined by the ratio between x and z coordinates 57 | ray_width = resolution / z2 58 | ray_offset = x1 / ray_width 59 | max_rays = int((x2 - x1) / ray_width) 60 | 61 | # Group LiDAR points into bins 62 | rayid = np.round(points[:, 0] / points[:, 2] / ray_width - ray_offset) 63 | depth = points[:, 2] 64 | 65 | # Ignore rays which do not correspond to any grid cells in the BEV 66 | valid = (rayid > 0) & (rayid < max_rays) & (depth > 0) 67 | rayid = rayid[valid] 68 | depth = depth[valid] 69 | 70 | # Find the LiDAR point with maximum depth within each bin 71 | max_depth = np.zeros((max_rays,)) 72 | np.maximum.at(max_depth, rayid.astype(np.int32), depth) 73 | 74 | # For each bev grid point, sample the max depth along the corresponding ray 75 | x = np.arange(x1, x2, resolution) 76 | z = np.arange(z1, z2, resolution)[:, None] 77 | grid_rayid = np.round(x / z / ray_width - ray_offset).astype(np.int32) 78 | grid_max_depth = max_depth[grid_rayid] 79 | 80 | # A grid position is considered occluded if the there are no LiDAR points 81 | # passing through it 82 | occluded = grid_max_depth < z 83 | return occluded 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | -------------------------------------------------------------------------------- /src/model/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/avishkarsaha/translating-images-into-maps/a452fe7a1eb0062f96133d14380d504e7bc0c8d1/src/model/__init__.py -------------------------------------------------------------------------------- /src/model/axial_attention.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copied from https://github.com/lucidrains/axial-attention.git 3 | """ 4 | 5 | import torch 6 | from torch import nn 7 | from operator import itemgetter 8 | 9 | 10 | def map_el_ind(arr, ind): 11 | return list(map(itemgetter(ind), arr)) 12 | 13 | 14 | def sort_and_return_indices(arr): 15 | indices = [ind for ind in range(len(arr))] 16 | arr = zip(arr, indices) 17 | arr = sorted(arr) 18 | return map_el_ind(arr, 0), map_el_ind(arr, 1) 19 | 20 | 21 | # calculates the permutation to bring the input tensor to something attend-able 22 | # also calculates the inverse permutation to bring the tensor back to its original shape 23 | 24 | 25 | def calculate_permutations(num_dimensions, emb_dim): 26 | total_dimensions = num_dimensions + 2 27 | emb_dim = emb_dim if emb_dim > 0 else (emb_dim + total_dimensions) 28 | axial_dims = [ind for ind in range(1, total_dimensions) if ind != emb_dim] 29 | 30 | permutations = [] 31 | 32 | for axial_dim in axial_dims: 33 | last_two_dims = [axial_dim, emb_dim] 34 | dims_rest = set(range(0, total_dimensions)) - set(last_two_dims) 35 | permutation = [*dims_rest, *last_two_dims] 36 | permutations.append(permutation) 37 | 38 | return permutations 39 | 40 | 41 | # helper classes 42 | 43 | 44 | class Rezero(nn.Module): 45 | def __init__(self, fn): 46 | super().__init__() 47 | self.fn = fn 48 | self.g = nn.Parameter(torch.tensor(0.0)) 49 | 50 | def forward(self, x): 51 | return self.fn(x) * self.g 52 | 53 | 54 | class Sequential(nn.Module): 55 | def __init__(self, blocks): 56 | super().__init__() 57 | self.blocks = blocks 58 | 59 | def forward(self, x): 60 | for f, g in self.blocks: 61 | x = x + f(x) + g(x) 62 | return x 63 | 64 | 65 | class PermuteToFrom(nn.Module): 66 | def __init__(self, permutation, fn): 67 | super().__init__() 68 | self.fn = fn 69 | _, inv_permutation = sort_and_return_indices(permutation) 70 | self.permutation = permutation 71 | self.inv_permutation = inv_permutation 72 | 73 | def forward(self, x, **kwargs): 74 | axial = x.permute(*self.permutation).contiguous() 75 | 76 | shape = axial.shape 77 | *_, t, d = shape 78 | 79 | # merge all but axial dimension 80 | axial = axial.reshape(-1, t, d) 81 | 82 | # attention 83 | axial = self.fn(axial, **kwargs) 84 | 85 | # restore to original shape and permutation 86 | axial = axial.reshape(*shape) 87 | axial = axial.permute(*self.inv_permutation).contiguous() 88 | return axial 89 | 90 | 91 | class AxialPositionalEmbedding(nn.Module): 92 | def __init__(self, emb_dim, emb_dim_index, dimensions): 93 | super().__init__() 94 | parameters = [] 95 | total_dimensions = len(dimensions) + 2 96 | ax_dim_indexes = [i for i in range(1, total_dimensions) if i != emb_dim_index] 97 | 98 | for axial_dim, axial_dim_index in zip(dimensions, ax_dim_indexes): 99 | shape = [1] * total_dimensions 100 | shape[emb_dim_index] = emb_dim 101 | shape[axial_dim_index] = axial_dim 102 | parameter = nn.Parameter(torch.randn(*shape)) 103 | parameters.append(parameter) 104 | 105 | self.params = nn.ParameterList(parameters) 106 | 107 | def forward(self, x): 108 | for param in self.params: 109 | x = x + param 110 | return x 111 | 112 | 113 | # classic multi-head attention 114 | 115 | 116 | def attention(q, k, v, 
h): 117 | b, t, d = q.shape 118 | e = d // h 119 | 120 | merge_heads = lambda x: x.reshape(b, -1, h, e).transpose(1, 2).reshape(b * h, -1, e) 121 | q, k, v = map(merge_heads, (q, k, v)) 122 | dots = torch.einsum("bie,bje->bij", q, k) * ((d // h) ** -0.5) 123 | dots = dots.softmax(dim=-1) 124 | out = torch.einsum("bij,bje->bie", dots, v) 125 | out = out.reshape(b, h, -1, e).transpose(1, 2).reshape(b, -1, d) 126 | return out 127 | 128 | 129 | class SelfAttention(nn.Module): 130 | def __init__(self, dim, heads, dim_heads=None): 131 | super().__init__() 132 | self.dim_heads = (dim // heads) if dim_heads is None else dim_heads 133 | dim_hidden = self.dim_heads * heads 134 | 135 | self.heads = heads 136 | self.to_q = nn.Linear(dim, dim_hidden, bias=False) 137 | self.to_kv = nn.Linear(dim, 2 * dim_hidden, bias=False) 138 | self.to_out = nn.Linear(dim_hidden, dim) 139 | 140 | def forward(self, x, kv=None): 141 | kv = x if kv is None else kv 142 | q, k, v = (self.to_q(x), *self.to_kv(kv).chunk(2, dim=-1)) 143 | 144 | b, t, d, h, e = *q.shape, self.heads, self.dim_heads 145 | 146 | merge_heads = ( 147 | lambda x: x.reshape(b, -1, h, e).transpose(1, 2).reshape(b * h, -1, e) 148 | ) 149 | q, k, v = map(merge_heads, (q, k, v)) 150 | 151 | dots = torch.einsum("bie,bje->bij", q, k) * ((d // h) ** -0.5) 152 | dots = dots.softmax(dim=-1) 153 | out = torch.einsum("bij,bje->bie", dots, v) 154 | 155 | out = out.reshape(b, h, -1, e).transpose(1, 2).reshape(b, -1, d) 156 | out = self.to_out(out) 157 | return out 158 | 159 | 160 | class InducedSetAttention(nn.Module): 161 | def __init__(self, num_queries, dim, heads, dim_heads=None): 162 | super().__init__() 163 | self.queries = nn.Parameter(torch.randn(1, num_queries, dim)) 164 | self.attn_in = SelfAttention(dim, heads) 165 | self.attn_out = SelfAttention(dim, heads) 166 | 167 | def forward(self, x): 168 | b = x.shape[0] 169 | q = self.queries.expand(b, -1, -1) 170 | q_out = self.attn_in(q, x) 171 | out = self.attn_out(x, q_out) 172 | return out 173 | 174 | 175 | # axial attention class 176 | 177 | 178 | class AxialAttention(nn.Module): 179 | def __init__( 180 | self, 181 | dim, 182 | num_dimensions=2, 183 | heads=8, 184 | dim_heads=None, 185 | dim_index=-1, 186 | sum_axial_out=True, 187 | ): 188 | assert ( 189 | dim % heads 190 | ) == 0, "hidden dimension must be divisible by number of heads" 191 | super().__init__() 192 | self.dim = dim 193 | self.total_dimensions = num_dimensions + 2 194 | self.dim_index = ( 195 | dim_index if dim_index > 0 else (dim_index + self.total_dimensions) 196 | ) 197 | 198 | attentions = [] 199 | for permutation in calculate_permutations(num_dimensions, dim_index): 200 | attentions.append( 201 | PermuteToFrom(permutation, SelfAttention(dim, heads, dim_heads)) 202 | ) 203 | 204 | self.axial_attentions = nn.ModuleList(attentions) 205 | self.sum_axial_out = sum_axial_out 206 | 207 | def forward(self, x): 208 | assert ( 209 | len(x.shape) == self.total_dimensions 210 | ), "input tensor does not have the correct number of dimensions" 211 | assert ( 212 | x.shape[self.dim_index] == self.dim 213 | ), "input tensor does not have the correct input dimension" 214 | 215 | if self.sum_axial_out: 216 | return sum(map(lambda axial_attn: axial_attn(x), self.axial_attentions)) 217 | 218 | axial_attn = self.axial_attentions[0] 219 | out = axial_attn(x) 220 | for axial_attn in self.axial_attentions[1:]: 221 | out = axial_attn(out) 222 | return out 223 | -------------------------------------------------------------------------------- 
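Note on the axial-attention module above: it is self-contained, so its expected tensor layout can be checked in isolation. The following is a minimal, hypothetical usage sketch (it is not taken from the training code; the feature-map sizes are illustrative and the import path simply mirrors the file listing above). It applies AxialAttention to a 4-D feature map of assumed shape [batch, channels, depth, width], with dim_index=1 pointing at the channel dimension so attention runs along the two remaining spatial axes:

import torch
from src.model.axial_attention import AxialAttention, AxialPositionalEmbedding

# Assumed, illustrative sizes: 2 samples, 256 channels, 98 depth bins, 100 columns.
feats = torch.randn(2, 256, 98, 100)

# Optional learned positional embedding for the two spatial axes (channels sit at index 1).
pos = AxialPositionalEmbedding(emb_dim=256, emb_dim_index=1, dimensions=(98, 100))

# Self-attention along each spatial axis in turn; the input shape is preserved.
attn = AxialAttention(dim=256, num_dimensions=2, heads=8, dim_index=1)

out = attn(pos(feats))
print(out.shape)  # torch.Size([2, 256, 98, 100])

Because sum_axial_out defaults to True, the two axial passes are computed independently and summed; passing sum_axial_out=False chains them sequentially instead, as the forward method above shows.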
/src/model/backbone_utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from torch import nn 3 | from torchvision.ops.feature_pyramid_network import ( 4 | FeaturePyramidNetwork, 5 | LastLevelMaxPool, 6 | ) 7 | 8 | from torchvision.ops import misc as misc_nn_ops 9 | from torchvision.models._utils import IntermediateLayerGetter 10 | from torch.utils.checkpoint import checkpoint, checkpoint_sequential 11 | 12 | from src.model.intermediate_layer_getter import IntermediateLayerGetter as MidGetter 13 | 14 | from src.model import resnet 15 | from src.model.fpn import FeaturePyramidNetworkRoddick, LLMaxPool 16 | 17 | 18 | class BackboneWithFPN(nn.Module): 19 | """ 20 | Adds a FPN on top of a model. 21 | Internally, it uses torchvision.models._utils.IntermediateLayerGetter to 22 | extract a submodel that returns the feature maps specified in return_layers. 23 | The same limitations of IntermediatLayerGetter apply here. 24 | Arguments: 25 | backbone (nn.Module) 26 | return_layers (Dict[name, new_name]): a dict containing the names 27 | of the modules for which the activations will be returned as 28 | the key of the dict, and the value of the dict is the name 29 | of the returned activation (which the user can specify). 30 | in_channels_list (List[int]): number of channels for each feature map 31 | that is returned, in the order they are present in the OrderedDict 32 | out_channels (int): number of channels in the FPN. 33 | Attributes: 34 | out_channels (int): the number of channels in the FPN 35 | """ 36 | 37 | def __init__(self, backbone, return_layers, in_channels_list, out_channels): 38 | super(BackboneWithFPN, self).__init__() 39 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) 40 | self.fpn = FeaturePyramidNetwork( 41 | in_channels_list=in_channels_list, 42 | out_channels=out_channels, 43 | extra_blocks=LastLevelMaxPool(), 44 | ) 45 | self.out_channels = out_channels 46 | 47 | def forward(self, x): 48 | x = self.body(x) 49 | x = self.fpn(x) 50 | return x 51 | 52 | 53 | class BackboneWithFPNRoddick(nn.Module): 54 | """ 55 | Adds a FPN on top of a model. 56 | Internally, it uses torchvision.models._utils.IntermediateLayerGetter to 57 | extract a submodel that returns the feature maps specified in return_layers. 58 | The same limitations of IntermediatLayerGetter apply here. 59 | Arguments: 60 | backbone (nn.Module) 61 | return_layers (Dict[name, new_name]): a dict containing the names 62 | of the modules for which the activations will be returned as 63 | the key of the dict, and the value of the dict is the name 64 | of the returned activation (which the user can specify). 65 | in_channels_list (List[int]): number of channels for each feature map 66 | that is returned, in the order they are present in the OrderedDict 67 | out_channels (int): number of channels in the FPN. 
68 | Attributes: 69 | out_channels (int): the number of channels in the FPN 70 | """ 71 | 72 | def __init__(self, backbone, return_layers, in_channels_list, out_channels): 73 | super(BackboneWithFPNRoddick, self).__init__() 74 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) 75 | self.fpn = FeaturePyramidNetworkRoddick( 76 | in_channels_list=in_channels_list, 77 | out_channels=out_channels, 78 | extra_blocks=LLMaxPool(), 79 | ) 80 | self.out_channels = out_channels 81 | 82 | def forward(self, x): 83 | x = self.body(x) 84 | x = self.fpn(x) 85 | return x 86 | 87 | 88 | def resnet_fpn_backbone(backbone_name, pretrained): 89 | backbone = resnet.__dict__[backbone_name]( 90 | pretrained=pretrained, norm_layer=misc_nn_ops.FrozenBatchNorm2d 91 | ) 92 | # freeze layers 93 | for name, parameter in backbone.named_parameters(): 94 | if "layer2" not in name and "layer3" not in name and "layer4" not in name: 95 | parameter.requires_grad_(False) 96 | 97 | return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} 98 | 99 | in_channels_stage2 = backbone.inplanes // 8 100 | in_channels_list = [ 101 | in_channels_stage2, 102 | in_channels_stage2 * 2, 103 | in_channels_stage2 * 4, 104 | in_channels_stage2 * 8, 105 | ] 106 | out_channels = 256 107 | return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels) 108 | 109 | 110 | def resnet_fpn_backbone_roddick(backbone_name, pretrained): 111 | backbone = resnet.__dict__["resnet50Roddick"]( 112 | pretrained=pretrained, norm_layer=misc_nn_ops.FrozenBatchNorm2d 113 | ) 114 | # freeze layers 115 | for name, parameter in backbone.named_parameters(): 116 | if ( 117 | "layer2" not in name 118 | and "layer3" not in name 119 | and "layer4" not in name 120 | and "layer5" not in name 121 | ): 122 | parameter.requires_grad_(False) 123 | 124 | return_layers = { 125 | "layer1": "0", 126 | "layer2": "1", 127 | "layer3": "2", 128 | "layer4": "3", 129 | "layer5": "4", 130 | } 131 | 132 | in_channels_stage2 = backbone.inplanes // 8 133 | in_channels_list = [ 134 | in_channels_stage2, 135 | in_channels_stage2 * 2, 136 | in_channels_stage2 * 4, 137 | in_channels_stage2 * 8, 138 | in_channels_stage2 * 8, 139 | ] 140 | out_channels = 256 141 | return BackboneWithFPNRoddick( 142 | backbone, return_layers, in_channels_list, out_channels 143 | ) 144 | -------------------------------------------------------------------------------- /src/model/bev_transform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import matplotlib.pyplot as plt 5 | from .. 
import utils 6 | 7 | EPSILON = 1e-6 8 | 9 | 10 | class BEVT_04(nn.Module): 11 | """ 12 | Changes from BEVT: 13 | Condensed bottleneck along channel dimension, 14 | Increased dropout probability, 15 | """ 16 | 17 | def __init__(self, in_height, z_max, z_min, cell_size): 18 | super().__init__() 19 | 20 | # [B, C, H, W] --> [B, C_condensed, H, W] 21 | self.conv_c = nn.Sequential( 22 | nn.Conv2d(256, 32, kernel_size=1), 23 | nn.GroupNorm(2, 32), 24 | nn.ReLU(), 25 | ) 26 | # [B, C_condensed, W, H] --> [B, C_condensed, W, 1] 27 | self.linear_v = nn.Sequential( 28 | nn.Linear(in_height, 1), 29 | nn.GroupNorm(16, 32), 30 | nn.ReLU(), 31 | nn.Dropout(p=0.5), 32 | ) 33 | 34 | # [B, 1, C_condensed, W] --> [B, Z, C_condensed, W] 35 | depth = (z_max - z_min) / cell_size 36 | 37 | self.z_extend = nn.Conv2d(1, int(depth), kernel_size=1) 38 | 39 | self.bev_expand_c = nn.Conv2d(32, 256, kernel_size=1) 40 | 41 | self.z_max = z_max 42 | self.z_min = z_min 43 | self.cell_size = cell_size 44 | 45 | def forward(self, features, calib, grid): 46 | # print(' BEV input shape:', features.shape) 47 | # Condense channel dimensions 48 | # [B, C, H, W] --> [B, C_cond, H, W] 49 | condensed_c = self.conv_c(features) 50 | 51 | # Reshape input tensor then condense input 52 | # features along the y dimension (height) 53 | # [B, C_cond, H, W] --> [B, C_cond, W, H] --> [B, C_cond, W, 1] 54 | bottleneck = self.linear_v(condensed_c.permute(0, 1, 3, 2)) 55 | 56 | # condense collapsed v features along channels 57 | # [B, C, W, 1] --> [B, 1, W, C] --> [B, 1, W, C_condensed] 58 | # bottleneck = self.linear_c(condensed_h.permute((0, 3, 2, 1))) 59 | 60 | # Expand the bottleneck along the z dimension (depth) 61 | # [B, 1, C, W] --> [B, Z, C, W] --> [B, C, Z, W] 62 | bev_polar_feats = self.z_extend(bottleneck.permute(0, 3, 1, 2)).permute( 63 | 0, 2, 1, 3 64 | ) 65 | # bev_polar_feats_gn = self.gn_polar(bev_polar_feats) 66 | 67 | # print(' BEV polar feats shape:', bev_polar_feats.shape) 68 | 69 | # Normalise grid to [-1, 1] 70 | norm_grid = self.normalise_grid(grid, calib) 71 | 72 | # TODO compute features within voxels of 2D grid instead of at coordinates 73 | 74 | bev_cart_feats = F.grid_sample(bev_polar_feats, norm_grid, align_corners=True) 75 | 76 | bev_cart_feats_expand = self.bev_expand_c(bev_cart_feats) 77 | 78 | # print(' BEV cart feats shape:', bev_cart_feats.shape) 79 | return bev_cart_feats_expand 80 | 81 | def normalise_grid(self, grid, calib): 82 | """ 83 | :param grid: BEV grid in with coords range 84 | [grid_h1, grid_h2] and [-grid_w/2, grid_w/2] 85 | :param calib: 86 | :return: 87 | """ 88 | 89 | f, cu = calib[:, 0, 0], calib[:, 0, 2] 90 | batch_size = len(calib) 91 | 92 | # Compute positive x dimension at z_max and z_min 93 | # Computed x dimension is half grid width 94 | x_zmax = self.z_max / f * cu 95 | x_zmin = self.z_min / f * cu 96 | 97 | # Compute normalising constant for each row along the z-axis 98 | sample_res_z = (self.z_max - self.z_min) / self.cell_size 99 | sample_res_x = grid.shape[2] / self.cell_size 100 | 101 | norm_z = ( 102 | 2 103 | * (grid[..., 1] - grid[..., 1].min()) 104 | / (grid[..., 1].max() - grid[..., 1].min()) 105 | - 1 106 | ) 107 | 108 | norm_scale_x = torch.stack( 109 | [ 110 | torch.linspace(float(x_zmin[i]), float(x_zmax[i]), int(sample_res_z)) 111 | for i in range(batch_size) 112 | ] 113 | ) 114 | 115 | grid_ones = torch.ones_like(grid) 116 | grid_ones[..., 0] *= norm_scale_x.view(batch_size, -1, 1).cuda() 117 | 118 | # Normalise grid to [-1, 1] 119 | norm_grid = grid / grid_ones 120 | 
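# The division by grid_ones above rescales each row's x coordinate by the visible
# half-width at that depth computed from the calibration (norm_scale_x), so points
# inside the camera frustum map to roughly [-1, 1]. The depth coordinate is then
# replaced below by norm_z, a linear mapping of z to [-1, 1], giving normalised
# sampling locations in the form expected by F.grid_sample.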
121 | norm_grid[..., 1] = norm_z 122 | 123 | # print(' norm grid', norm_grid[...,0].max(), norm_grid[...,0].min(), 124 | # norm_grid[...,1].max(), norm_grid[...,1].min()) 125 | 126 | return norm_grid 127 | 128 | 129 | class BEVT_H(nn.Module): 130 | def __init__( 131 | self, 132 | in_height, 133 | z_max, 134 | z_min, 135 | cell_size, 136 | additions_linear=False, 137 | additions_conv=False, 138 | kernel_h=9, 139 | stride_h=3, 140 | padding_h=4, 141 | ): 142 | super().__init__() 143 | 144 | self.horizontal = nn.Sequential( 145 | nn.Conv2d( 146 | 256, 147 | 256, 148 | kernel_size=[3, kernel_h], 149 | stride=[1, stride_h], 150 | padding=[1, padding_h], 151 | ), 152 | nn.GroupNorm(16, 256), 153 | nn.ReLU(), 154 | ) 155 | 156 | # [B, C, W, H] --> [B, C, W, 1] 157 | if additions_linear: 158 | self.linear = nn.Sequential( 159 | nn.Linear(in_height, 1), 160 | nn.GroupNorm(16, 256), 161 | nn.ReLU(), 162 | nn.Dropout(p=0.25), 163 | ) 164 | else: 165 | self.linear = nn.Linear(in_height, 1) 166 | 167 | # [B, 1, C, W] --> [B, Z, C, W] 168 | depth = (z_max - z_min) / cell_size 169 | 170 | self.additions_conv = additions_conv 171 | if additions_conv: 172 | self.z_extend = nn.Sequential( 173 | nn.Conv2d(1, int(depth), kernel_size=1), 174 | ) 175 | self.gn_polar = nn.GroupNorm(16, 256) 176 | else: 177 | self.z_extend = nn.Conv2d(1, int(depth), kernel_size=1) 178 | 179 | self.z_max = z_max 180 | self.z_min = z_min 181 | self.cell_size = cell_size 182 | 183 | def forward(self, features, calib, grid): 184 | 185 | # Convolve horizontally first 186 | features = self.horizontal(features) 187 | 188 | # Reshape input tensor then collapse input 189 | # features along the y dimension (height) 190 | # [B, C, H, W] --> [B, C, W, H] --> [B, C, W, 1] 191 | bottleneck = self.linear(features.permute(0, 1, 3, 2)) 192 | 193 | # Expand the bottleneck along the z dimension (depth) 194 | # [B, 1, C, W] --> [B, Z, C, W] --> [B, C, Z, W] 195 | bev_polar_feats = self.z_extend(bottleneck.permute(0, 3, 1, 2)).permute( 196 | 0, 2, 1, 3 197 | ) 198 | # bev_polar_feats_gn = self.gn_polar(bev_polar_feats) 199 | 200 | # print(' BEV polar feats shape:', bev_polar_feats.shape) 201 | 202 | # Normalise grid to [-1, 1] 203 | norm_grid = self.normalise_grid(grid, calib) 204 | 205 | if self.additions_conv: 206 | bev_polar_feats_gn = self.gn_polar(bev_polar_feats) 207 | # Sample BEV polar features at grid locations 208 | # [B, C, Z, W] --> [B, C, Z, X] 209 | bev_cart_feats = F.grid_sample( 210 | bev_polar_feats_gn, norm_grid, align_corners=True 211 | ) 212 | # print(' BEV cart feats shape:', bev_cart_feats.shape) 213 | return bev_cart_feats 214 | else: 215 | bev_cart_feats = F.grid_sample( 216 | bev_polar_feats, norm_grid, align_corners=True 217 | ) 218 | # print(' BEV cart feats shape:', bev_cart_feats.shape) 219 | return bev_cart_feats 220 | 221 | def normalise_grid(self, grid, calib): 222 | """ 223 | :param grid: BEV grid in with coords range 224 | [grid_h1, grid_h2] and [-grid_w/2, grid_w/2] 225 | :param calib: 226 | :return: 227 | """ 228 | 229 | f, cu = calib[:, 0, 0], calib[:, 0, 2] 230 | batch_size = len(calib) 231 | 232 | # Compute positive x dimension at z_max and z_min 233 | # Computed x dimension is half grid width 234 | x_zmax = self.z_max / f * cu 235 | x_zmin = self.z_min / f * cu 236 | 237 | # Compute normalising constant for each row along the z-axis 238 | sample_res_z = (self.z_max - self.z_min) / self.cell_size 239 | sample_res_x = grid.shape[2] / self.cell_size 240 | 241 | norm_z = ( 242 | 2 243 | * (grid[..., 1] - grid[..., 
1].min()) 244 | / (grid[..., 1].max() - grid[..., 1].min()) 245 | - 1 246 | ) 247 | 248 | norm_scale_x = torch.stack( 249 | [ 250 | torch.linspace(float(x_zmin[i]), float(x_zmax[i]), int(sample_res_z)) 251 | for i in range(batch_size) 252 | ] 253 | ) 254 | 255 | grid_ones = torch.ones_like(grid) 256 | grid_ones[..., 0] *= norm_scale_x.view(batch_size, -1, 1).cuda() 257 | 258 | # Normalise grid to [-1, 1] 259 | norm_grid = grid / grid_ones 260 | 261 | norm_grid[..., 1] = norm_z 262 | 263 | return norm_grid 264 | 265 | 266 | class BEVT(nn.Module): 267 | def __init__( 268 | self, 269 | in_height, 270 | z_max, 271 | z_min, 272 | cell_size, 273 | additions_linear=False, 274 | additions_conv=False, 275 | ): 276 | super().__init__() 277 | 278 | # [B, C, W, H] --> [B, C, W, 1] 279 | if additions_linear: 280 | self.linear = nn.Sequential( 281 | nn.Linear(in_height, 1), 282 | nn.GroupNorm(16, 256), 283 | nn.ReLU(), 284 | nn.Dropout(p=0.25), 285 | ) 286 | else: 287 | self.linear = nn.Linear(in_height, 1) 288 | 289 | # [B, 1, C, W] --> [B, Z, C, W] 290 | depth = (z_max - z_min) / cell_size 291 | 292 | self.additions_conv = additions_conv 293 | if additions_conv: 294 | self.z_extend = nn.Sequential( 295 | nn.Conv2d(1, int(depth), kernel_size=1), 296 | ) 297 | self.gn_polar = nn.GroupNorm(16, 256) 298 | else: 299 | self.z_extend = nn.Conv2d(1, int(depth), kernel_size=1) 300 | 301 | self.z_max = z_max 302 | self.z_min = z_min 303 | self.cell_size = cell_size 304 | 305 | def forward(self, features, calib, grid): 306 | # print(' BEV input shape:', features.shape) 307 | # Reshape input tensor then collapse input 308 | # features along the y dimension (height) 309 | # [B, C, H, W] --> [B, C, W, H] --> [B, C, W, 1] 310 | bottleneck = self.linear(features.permute(0, 1, 3, 2)) 311 | 312 | # Expand the bottleneck along the z dimension (depth) 313 | # [B, 1, C, W] --> [B, Z, C, W] --> [B, C, Z, W] 314 | bev_polar_feats = self.z_extend(bottleneck.permute(0, 3, 1, 2)).permute( 315 | 0, 2, 1, 3 316 | ) 317 | # bev_polar_feats_gn = self.gn_polar(bev_polar_feats) 318 | 319 | # print(' BEV polar feats shape:', bev_polar_feats.shape) 320 | 321 | # Normalise grid to [-1, 1] 322 | norm_grid = self.normalise_grid(grid, calib) 323 | 324 | if self.additions_conv: 325 | bev_polar_feats_gn = self.gn_polar(bev_polar_feats) 326 | # Sample BEV polar features at grid locations 327 | # [B, C, Z, W] --> [B, C, Z, X] 328 | bev_cart_feats = F.grid_sample( 329 | bev_polar_feats_gn, norm_grid, align_corners=True 330 | ) 331 | # print(' BEV cart feats shape:', bev_cart_feats.shape) 332 | return bev_cart_feats 333 | else: 334 | bev_cart_feats = F.grid_sample( 335 | bev_polar_feats, norm_grid, align_corners=True 336 | ) 337 | # print(' BEV cart feats shape:', bev_cart_feats.shape) 338 | return bev_cart_feats 339 | 340 | def normalise_grid(self, grid, calib): 341 | """ 342 | :param grid: BEV grid in with coords range 343 | [grid_h1, grid_h2] and [-grid_w/2, grid_w/2] 344 | :param calib: 345 | :return: 346 | """ 347 | 348 | f, cu = calib[:, 0, 0], calib[:, 0, 2] 349 | batch_size = len(calib) 350 | 351 | # Compute positive x dimension at z_max and z_min 352 | # Computed x dimension is half grid width 353 | x_zmax = self.z_max / f * cu 354 | x_zmin = self.z_min / f * cu 355 | 356 | # Compute normalising constant for each row along the z-axis 357 | sample_res_z = (self.z_max - self.z_min) / self.cell_size 358 | sample_res_x = grid.shape[2] / self.cell_size 359 | 360 | norm_z = ( 361 | 2 362 | * (grid[..., 1] - grid[..., 1].min()) 363 | / 
(grid[..., 1].max() - grid[..., 1].min()) 364 | - 1 365 | ) 366 | 367 | norm_scale_x = torch.stack( 368 | [ 369 | torch.linspace(float(x_zmin[i]), float(x_zmax[i]), int(sample_res_z)) 370 | for i in range(batch_size) 371 | ] 372 | ) 373 | 374 | grid_ones = torch.ones_like(grid) 375 | grid_ones[..., 0] *= norm_scale_x.view(batch_size, -1, 1).cuda() 376 | 377 | # Normalise grid to [-1, 1] 378 | norm_grid = grid / grid_ones 379 | 380 | norm_grid[..., 1] = norm_z 381 | 382 | return norm_grid 383 | 384 | 385 | class sample_polar2cart(nn.Module): 386 | def __init__( 387 | self, 388 | z_max, 389 | z_min, 390 | cell_size, 391 | ): 392 | super().__init__() 393 | 394 | self.z_max = z_max 395 | self.z_min = z_min 396 | self.cell_size = cell_size 397 | 398 | def forward(self, features, calib, grid): 399 | 400 | # Normalise grid to [-1, 1] 401 | norm_grid = self.normalise_grid(grid, calib) 402 | 403 | bev_cart_feats = F.grid_sample(features, norm_grid, align_corners=True) 404 | return bev_cart_feats 405 | 406 | def normalise_grid(self, grid, calib): 407 | """ 408 | :param grid: BEV grid in with coords range 409 | [grid_h1, grid_h2] and [-grid_w/2, grid_w/2] 410 | :param calib: 411 | :return: 412 | """ 413 | 414 | f, cu = calib[:, 0, 0], calib[:, 0, 2] 415 | batch_size = len(calib) 416 | 417 | # Compute positive x dimension at z_max and z_min 418 | # Computed x dimension is half grid width 419 | x_zmax = self.z_max / f * cu 420 | x_zmin = self.z_min / f * cu 421 | 422 | # Compute normalising constant for each row along the z-axis 423 | sample_res_z = (self.z_max - self.z_min) / self.cell_size 424 | sample_res_x = grid.shape[2] / self.cell_size 425 | 426 | norm_z = ( 427 | 2 428 | * (grid[..., 1] - grid[..., 1].min()) 429 | / (grid[..., 1].max() - grid[..., 1].min()) 430 | - 1 431 | ) 432 | 433 | norm_scale_x = torch.stack( 434 | [ 435 | torch.linspace(float(x_zmin[i]), float(x_zmax[i]), int(sample_res_z)) 436 | for i in range(batch_size) 437 | ] 438 | ) 439 | 440 | grid_ones = torch.ones_like(grid) 441 | grid_ones[..., 0] *= norm_scale_x.view(batch_size, -1, 1).cuda() 442 | 443 | # Normalise grid to [-1, 1] 444 | norm_grid = grid / grid_ones 445 | 446 | norm_grid[..., 1] = norm_z 447 | 448 | return norm_grid 449 | 450 | 451 | def integral_image(features): 452 | return torch.cumsum(torch.cumsum(features, dim=-1), dim=-2) 453 | -------------------------------------------------------------------------------- /src/model/dla.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import math 5 | from os.path import join 6 | 7 | import torch 8 | from torch import nn 9 | import torch.utils.model_zoo as model_zoo 10 | 11 | import src.model.dla_dataset as dataset 12 | 13 | BatchNorm = nn.BatchNorm2d 14 | 15 | WEB_ROOT = "http://dl.yf.io/dla/models" 16 | 17 | 18 | def get_model_url(data, name): 19 | return join(WEB_ROOT, data.name, "{}-{}.pth".format(name, data.model_hash[name])) 20 | 21 | 22 | def conv3x3(in_planes, out_planes, stride=1): 23 | "3x3 convolution with padding" 24 | return nn.Conv2d( 25 | in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False 26 | ) 27 | 28 | 29 | class BasicBlock(nn.Module): 30 | def __init__(self, inplanes, planes, stride=1, dilation=1): 31 | super(BasicBlock, self).__init__() 32 | self.conv1 = nn.Conv2d( 33 | inplanes, 34 | planes, 35 | kernel_size=3, 36 | stride=stride, 37 | padding=dilation, 38 | bias=False, 39 | dilation=dilation, 40 | ) 41 | self.bn1 = 
BatchNorm(planes) 42 | self.relu = nn.ReLU(inplace=True) 43 | self.conv2 = nn.Conv2d( 44 | planes, 45 | planes, 46 | kernel_size=3, 47 | stride=1, 48 | padding=dilation, 49 | bias=False, 50 | dilation=dilation, 51 | ) 52 | self.bn2 = BatchNorm(planes) 53 | self.stride = stride 54 | 55 | def forward(self, x, residual=None): 56 | if residual is None: 57 | residual = x 58 | 59 | out = self.conv1(x) 60 | out = self.bn1(out) 61 | out = self.relu(out) 62 | 63 | out = self.conv2(out) 64 | out = self.bn2(out) 65 | 66 | out += residual 67 | out = self.relu(out) 68 | 69 | return out 70 | 71 | 72 | class Bottleneck(nn.Module): 73 | expansion = 2 74 | 75 | def __init__(self, inplanes, planes, stride=1, dilation=1): 76 | super(Bottleneck, self).__init__() 77 | expansion = Bottleneck.expansion 78 | bottle_planes = planes // expansion 79 | self.conv1 = nn.Conv2d(inplanes, bottle_planes, kernel_size=1, bias=False) 80 | self.bn1 = BatchNorm(bottle_planes) 81 | self.conv2 = nn.Conv2d( 82 | bottle_planes, 83 | bottle_planes, 84 | kernel_size=3, 85 | stride=stride, 86 | padding=dilation, 87 | bias=False, 88 | dilation=dilation, 89 | ) 90 | self.bn2 = BatchNorm(bottle_planes) 91 | self.conv3 = nn.Conv2d(bottle_planes, planes, kernel_size=1, bias=False) 92 | self.bn3 = BatchNorm(planes) 93 | self.relu = nn.ReLU(inplace=True) 94 | self.stride = stride 95 | 96 | def forward(self, x, residual=None): 97 | if residual is None: 98 | residual = x 99 | 100 | out = self.conv1(x) 101 | out = self.bn1(out) 102 | out = self.relu(out) 103 | 104 | out = self.conv2(out) 105 | out = self.bn2(out) 106 | out = self.relu(out) 107 | 108 | out = self.conv3(out) 109 | out = self.bn3(out) 110 | 111 | out += residual 112 | out = self.relu(out) 113 | 114 | return out 115 | 116 | 117 | class BottleneckX(nn.Module): 118 | expansion = 2 119 | cardinality = 32 120 | 121 | def __init__(self, inplanes, planes, stride=1, dilation=1): 122 | super(BottleneckX, self).__init__() 123 | cardinality = BottleneckX.cardinality 124 | # dim = int(math.floor(planes * (BottleneckV5.expansion / 64.0))) 125 | # bottle_planes = dim * cardinality 126 | bottle_planes = planes * cardinality // 32 127 | self.conv1 = nn.Conv2d(inplanes, bottle_planes, kernel_size=1, bias=False) 128 | self.bn1 = BatchNorm(bottle_planes) 129 | self.conv2 = nn.Conv2d( 130 | bottle_planes, 131 | bottle_planes, 132 | kernel_size=3, 133 | stride=stride, 134 | padding=dilation, 135 | bias=False, 136 | dilation=dilation, 137 | groups=cardinality, 138 | ) 139 | self.bn2 = BatchNorm(bottle_planes) 140 | self.conv3 = nn.Conv2d(bottle_planes, planes, kernel_size=1, bias=False) 141 | self.bn3 = BatchNorm(planes) 142 | self.relu = nn.ReLU(inplace=True) 143 | self.stride = stride 144 | 145 | def forward(self, x, residual=None): 146 | if residual is None: 147 | residual = x 148 | 149 | out = self.conv1(x) 150 | out = self.bn1(out) 151 | out = self.relu(out) 152 | 153 | out = self.conv2(out) 154 | out = self.bn2(out) 155 | out = self.relu(out) 156 | 157 | out = self.conv3(out) 158 | out = self.bn3(out) 159 | 160 | out += residual 161 | out = self.relu(out) 162 | 163 | return out 164 | 165 | 166 | class Root(nn.Module): 167 | def __init__(self, in_channels, out_channels, kernel_size, residual): 168 | super(Root, self).__init__() 169 | self.conv = nn.Conv2d( 170 | in_channels, 171 | out_channels, 172 | kernel_size, 173 | stride=1, 174 | bias=False, 175 | padding=(kernel_size - 1) // 2, 176 | ) 177 | self.bn = BatchNorm(out_channels) 178 | self.relu = nn.ReLU(inplace=True) 179 | self.residual = 
residual 180 | 181 | def forward(self, *x): 182 | children = x 183 | x = self.conv(torch.cat(x, 1)) 184 | x = self.bn(x) 185 | if self.residual: 186 | x += children[0] 187 | x = self.relu(x) 188 | 189 | return x 190 | 191 | 192 | class Tree(nn.Module): 193 | def __init__( 194 | self, 195 | levels, 196 | block, 197 | in_channels, 198 | out_channels, 199 | stride=1, 200 | level_root=False, 201 | root_dim=0, 202 | root_kernel_size=1, 203 | dilation=1, 204 | root_residual=False, 205 | ): 206 | super(Tree, self).__init__() 207 | if root_dim == 0: 208 | root_dim = 2 * out_channels 209 | if level_root: 210 | root_dim += in_channels 211 | if levels == 1: 212 | self.tree1 = block(in_channels, out_channels, stride, dilation=dilation) 213 | self.tree2 = block(out_channels, out_channels, 1, dilation=dilation) 214 | else: 215 | self.tree1 = Tree( 216 | levels - 1, 217 | block, 218 | in_channels, 219 | out_channels, 220 | stride, 221 | root_dim=0, 222 | root_kernel_size=root_kernel_size, 223 | dilation=dilation, 224 | root_residual=root_residual, 225 | ) 226 | self.tree2 = Tree( 227 | levels - 1, 228 | block, 229 | out_channels, 230 | out_channels, 231 | root_dim=root_dim + out_channels, 232 | root_kernel_size=root_kernel_size, 233 | dilation=dilation, 234 | root_residual=root_residual, 235 | ) 236 | if levels == 1: 237 | self.root = Root(root_dim, out_channels, root_kernel_size, root_residual) 238 | self.level_root = level_root 239 | self.root_dim = root_dim 240 | self.downsample = None 241 | self.project = None 242 | self.levels = levels 243 | if stride > 1: 244 | self.downsample = nn.MaxPool2d(stride, stride=stride) 245 | if in_channels != out_channels: 246 | self.project = nn.Sequential( 247 | nn.Conv2d( 248 | in_channels, out_channels, kernel_size=1, stride=1, bias=False 249 | ), 250 | BatchNorm(out_channels), 251 | ) 252 | 253 | def forward(self, x, residual=None, children=None): 254 | children = [] if children is None else children 255 | bottom = self.downsample(x) if self.downsample else x 256 | residual = self.project(bottom) if self.project else bottom 257 | if self.level_root: 258 | children.append(bottom) 259 | x1 = self.tree1(x, residual) 260 | if self.levels == 1: 261 | x2 = self.tree2(x1) 262 | x = self.root(x2, x1, *children) 263 | else: 264 | children.append(x1) 265 | x = self.tree2(x1, children=children) 266 | return x 267 | 268 | 269 | class DLA(nn.Module): 270 | def __init__( 271 | self, 272 | levels, 273 | channels, 274 | num_classes=1000, 275 | block=BasicBlock, 276 | residual_root=False, 277 | return_levels=False, 278 | pool_size=7, 279 | linear_root=False, 280 | ): 281 | super(DLA, self).__init__() 282 | self.channels = channels 283 | self.return_levels = return_levels 284 | self.num_classes = num_classes 285 | self.base_layer = nn.Sequential( 286 | nn.Conv2d(3, channels[0], kernel_size=7, stride=1, padding=3, bias=False), 287 | BatchNorm(channels[0]), 288 | nn.ReLU(inplace=True), 289 | ) 290 | self.level0 = self._make_conv_level(channels[0], channels[0], levels[0]) 291 | self.level1 = self._make_conv_level( 292 | channels[0], channels[1], levels[1], stride=2 293 | ) 294 | self.level2 = Tree( 295 | levels[2], 296 | block, 297 | channels[1], 298 | channels[2], 299 | 2, 300 | level_root=False, 301 | root_residual=residual_root, 302 | ) 303 | self.level3 = Tree( 304 | levels[3], 305 | block, 306 | channels[2], 307 | channels[3], 308 | 2, 309 | level_root=True, 310 | root_residual=residual_root, 311 | ) 312 | self.level4 = Tree( 313 | levels[4], 314 | block, 315 | channels[3], 316 | 
channels[4], 317 | 2, 318 | level_root=True, 319 | root_residual=residual_root, 320 | ) 321 | self.level5 = Tree( 322 | levels[5], 323 | block, 324 | channels[4], 325 | channels[5], 326 | 2, 327 | level_root=True, 328 | root_residual=residual_root, 329 | ) 330 | 331 | self.avgpool = nn.AvgPool2d(pool_size) 332 | self.fc = nn.Conv2d( 333 | channels[-1], num_classes, kernel_size=1, stride=1, padding=0, bias=True 334 | ) 335 | 336 | for m in self.modules(): 337 | if isinstance(m, nn.Conv2d): 338 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 339 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 340 | elif isinstance(m, BatchNorm): 341 | m.weight.data.fill_(1) 342 | m.bias.data.zero_() 343 | 344 | def _make_level(self, block, inplanes, planes, blocks, stride=1): 345 | downsample = None 346 | if stride != 1 or inplanes != planes: 347 | downsample = nn.Sequential( 348 | nn.MaxPool2d(stride, stride=stride), 349 | nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False), 350 | BatchNorm(planes), 351 | ) 352 | 353 | layers = [] 354 | layers.append(block(inplanes, planes, stride, downsample=downsample)) 355 | for i in range(1, blocks): 356 | layers.append(block(inplanes, planes)) 357 | 358 | return nn.Sequential(*layers) 359 | 360 | def _make_conv_level(self, inplanes, planes, convs, stride=1, dilation=1): 361 | modules = [] 362 | for i in range(convs): 363 | modules.extend( 364 | [ 365 | nn.Conv2d( 366 | inplanes, 367 | planes, 368 | kernel_size=3, 369 | stride=stride if i == 0 else 1, 370 | padding=dilation, 371 | bias=False, 372 | dilation=dilation, 373 | ), 374 | BatchNorm(planes), 375 | nn.ReLU(inplace=True), 376 | ] 377 | ) 378 | inplanes = planes 379 | return nn.Sequential(*modules) 380 | 381 | def forward(self, x): 382 | y = [] 383 | x = self.base_layer(x) 384 | for i in range(6): 385 | x = getattr(self, "level{}".format(i))(x) 386 | y.append(x) 387 | if self.return_levels: 388 | return y 389 | else: 390 | x = self.avgpool(x) 391 | x = self.fc(x) 392 | x = x.view(x.size(0), -1) 393 | 394 | return x 395 | 396 | def load_pretrained_model(self, data_name, name): 397 | assert data_name in dataset.__dict__, "No pretrained model for {}".format( 398 | data_name 399 | ) 400 | data = dataset.__dict__[data_name] 401 | fc = self.fc 402 | if self.num_classes != data.classes: 403 | self.fc = nn.Conv2d( 404 | self.channels[-1], 405 | data.classes, 406 | kernel_size=1, 407 | stride=1, 408 | padding=0, 409 | bias=True, 410 | ) 411 | try: 412 | model_url = get_model_url(data, name) 413 | except KeyError: 414 | raise ValueError("{} trained on {} does not exist.".format(data.name, name)) 415 | self.load_state_dict(model_zoo.load_url(model_url)) 416 | self.fc = fc 417 | 418 | 419 | def dla34(pretrained=None, **kwargs): # DLA-34 420 | model = DLA( 421 | [1, 1, 1, 2, 2, 1], [16, 32, 64, 128, 256, 512], block=BasicBlock, **kwargs 422 | ) 423 | if pretrained is not None: 424 | model.load_pretrained_model(pretrained, "dla34") 425 | return model 426 | 427 | 428 | def dla46_c(pretrained=None, **kwargs): # DLA-46-C 429 | Bottleneck.expansion = 2 430 | model = DLA( 431 | [1, 1, 1, 2, 2, 1], [16, 32, 64, 64, 128, 256], block=Bottleneck, **kwargs 432 | ) 433 | if pretrained is not None: 434 | model.load_pretrained_model(pretrained, "dla46_c") 435 | return model 436 | 437 | 438 | def dla46x_c(pretrained=None, **kwargs): # DLA-X-46-C 439 | BottleneckX.expansion = 2 440 | model = DLA( 441 | [1, 1, 1, 2, 2, 1], [16, 32, 64, 64, 128, 256], block=BottleneckX, **kwargs 442 | ) 443 | if pretrained is not None: 
444 | model.load_pretrained_model(pretrained, "dla46x_c") 445 | return model 446 | 447 | 448 | def dla60x_c(pretrained=None, **kwargs): # DLA-X-60-C 449 | BottleneckX.expansion = 2 450 | model = DLA( 451 | [1, 1, 1, 2, 3, 1], [16, 32, 64, 64, 128, 256], block=BottleneckX, **kwargs 452 | ) 453 | if pretrained is not None: 454 | model.load_pretrained_model(pretrained, "dla60x_c") 455 | return model 456 | 457 | 458 | def dla60(pretrained=None, **kwargs): # DLA-60 459 | Bottleneck.expansion = 2 460 | model = DLA( 461 | [1, 1, 1, 2, 3, 1], [16, 32, 128, 256, 512, 1024], block=Bottleneck, **kwargs 462 | ) 463 | if pretrained is not None: 464 | model.load_pretrained_model(pretrained, "dla60") 465 | return model 466 | 467 | 468 | def dla60x(pretrained=None, **kwargs): # DLA-X-60 469 | BottleneckX.expansion = 2 470 | model = DLA( 471 | [1, 1, 1, 2, 3, 1], [16, 32, 128, 256, 512, 1024], block=BottleneckX, **kwargs 472 | ) 473 | if pretrained is not None: 474 | model.load_pretrained_model(pretrained, "dla60x") 475 | return model 476 | 477 | 478 | def dla102(pretrained=None, **kwargs): # DLA-102 479 | Bottleneck.expansion = 2 480 | model = DLA( 481 | [1, 1, 1, 3, 4, 1], 482 | [16, 32, 128, 256, 512, 1024], 483 | block=Bottleneck, 484 | residual_root=True, 485 | **kwargs 486 | ) 487 | if pretrained is not None: 488 | model.load_pretrained_model(pretrained, "dla102") 489 | return model 490 | 491 | 492 | def dla102x(pretrained=None, **kwargs): # DLA-X-102 493 | BottleneckX.expansion = 2 494 | model = DLA( 495 | [1, 1, 1, 3, 4, 1], 496 | [16, 32, 128, 256, 512, 1024], 497 | block=BottleneckX, 498 | residual_root=True, 499 | **kwargs 500 | ) 501 | if pretrained is not None: 502 | model.load_pretrained_model(pretrained, "dla102x") 503 | return model 504 | 505 | 506 | def dla102x2(pretrained=None, **kwargs): # DLA-X-102 64 507 | BottleneckX.cardinality = 64 508 | model = DLA( 509 | [1, 1, 1, 3, 4, 1], 510 | [16, 32, 128, 256, 512, 1024], 511 | block=BottleneckX, 512 | residual_root=True, 513 | **kwargs 514 | ) 515 | if pretrained is not None: 516 | model.load_pretrained_model(pretrained, "dla102x2") 517 | return model 518 | 519 | 520 | def dla169(pretrained=None, **kwargs): # DLA-169 521 | Bottleneck.expansion = 2 522 | model = DLA( 523 | [1, 1, 2, 3, 5, 1], 524 | [16, 32, 128, 256, 512, 1024], 525 | block=Bottleneck, 526 | residual_root=True, 527 | **kwargs 528 | ) 529 | if pretrained is not None: 530 | model.load_pretrained_model(pretrained, "dla169") 531 | return model 532 | -------------------------------------------------------------------------------- /src/model/dla_up.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | 7 | import src.model.dla as dla 8 | 9 | BatchNorm = nn.BatchNorm2d 10 | 11 | 12 | def set_bn(bn): 13 | global BatchNorm 14 | BatchNorm = bn 15 | dla.BatchNorm = bn 16 | 17 | 18 | class Identity(nn.Module): 19 | def __init__(self): 20 | super(Identity, self).__init__() 21 | 22 | def forward(self, x): 23 | return x 24 | 25 | 26 | def fill_up_weights(up): 27 | w = up.weight.data 28 | f = math.ceil(w.size(2) / 2) 29 | c = (2 * f - 1 - f % 2) / (2.0 * f) 30 | for i in range(w.size(2)): 31 | for j in range(w.size(3)): 32 | w[0, 0, i, j] = (1 - math.fabs(i / f - c)) * (1 - math.fabs(j / f - c)) 33 | for c in range(1, w.size(0)): 34 | w[c, 0, :, :] = w[0, 0, :, :] 35 | 36 | 37 | class IDAUp(nn.Module): 38 | def __init__(self, node_kernel, out_dim, channels, up_factors): 39 | 
super(IDAUp, self).__init__() 40 | self.channels = channels 41 | self.out_dim = out_dim 42 | for i, c in enumerate(channels): 43 | if c == out_dim: 44 | proj = Identity() 45 | else: 46 | proj = nn.Sequential( 47 | nn.Conv2d(c, out_dim, kernel_size=1, stride=1, bias=False), 48 | BatchNorm(out_dim), 49 | nn.ReLU(inplace=True), 50 | ) 51 | f = int(up_factors[i]) 52 | if f == 1: 53 | up = Identity() 54 | else: 55 | up = nn.ConvTranspose2d( 56 | out_dim, 57 | out_dim, 58 | f * 2, 59 | stride=f, 60 | padding=f // 2, 61 | output_padding=0, 62 | groups=out_dim, 63 | bias=False, 64 | ) 65 | fill_up_weights(up) 66 | setattr(self, "proj_" + str(i), proj) 67 | setattr(self, "up_" + str(i), up) 68 | 69 | for i in range(1, len(channels)): 70 | node = nn.Sequential( 71 | nn.Conv2d( 72 | out_dim * 2, 73 | out_dim, 74 | kernel_size=node_kernel, 75 | stride=1, 76 | padding=node_kernel // 2, 77 | bias=False, 78 | ), 79 | BatchNorm(out_dim), 80 | nn.ReLU(inplace=True), 81 | ) 82 | setattr(self, "node_" + str(i), node) 83 | 84 | for m in self.modules(): 85 | if isinstance(m, nn.Conv2d): 86 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 87 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 88 | elif isinstance(m, BatchNorm): 89 | m.weight.data.fill_(1) 90 | m.bias.data.zero_() 91 | 92 | def forward(self, layers): 93 | assert len(self.channels) == len(layers), "{} vs {} layers".format( 94 | len(self.channels), len(layers) 95 | ) 96 | layers = list(layers) 97 | for i, l in enumerate(layers): 98 | upsample = getattr(self, "up_" + str(i)) 99 | project = getattr(self, "proj_" + str(i)) 100 | layers[i] = upsample(project(l)) 101 | x = layers[0] 102 | y = [] 103 | for i in range(1, len(layers)): 104 | node = getattr(self, "node_" + str(i)) 105 | x = node(torch.cat([x, layers[i]], 1)) 106 | y.append(x) 107 | return x, y 108 | 109 | 110 | class DLAUp(nn.Module): 111 | def __init__(self, channels, scales=(1, 2, 4, 8, 16), in_channels=None): 112 | super(DLAUp, self).__init__() 113 | if in_channels is None: 114 | in_channels = channels 115 | self.channels = channels 116 | channels = list(channels) 117 | scales = np.array(scales, dtype=int) 118 | for i in range(len(channels) - 1): 119 | j = -i - 2 120 | setattr( 121 | self, 122 | "ida_{}".format(i), 123 | IDAUp(3, channels[j], in_channels[j:], scales[j:] // scales[j]), 124 | ) 125 | scales[j + 1 :] = scales[j] 126 | in_channels[j + 1 :] = [channels[j] for _ in channels[j + 1 :]] 127 | 128 | def forward(self, layers): 129 | layers = list(layers) 130 | assert len(layers) > 1 131 | for i in range(len(layers) - 1): 132 | ida = getattr(self, "ida_{}".format(i)) 133 | x, y = ida(layers[-i - 2 :]) 134 | layers[-i - 1 :] = y 135 | return x 136 | 137 | 138 | class DLASeg(nn.Module): 139 | def __init__(self, base_name, classes, pretrained_base=None, down_ratio=2): 140 | super(DLASeg, self).__init__() 141 | assert down_ratio in [2, 4, 8, 16] 142 | self.first_level = int(np.log2(down_ratio)) 143 | self.base = dla.__dict__[base_name]( 144 | pretrained=pretrained_base, return_levels=True 145 | ) 146 | channels = self.base.channels 147 | scales = [2 ** i for i in range(len(channels[self.first_level :]))] 148 | self.dla_up = DLAUp(channels[self.first_level :], scales=scales) 149 | print(self.dla_up) 150 | self.fc = nn.Sequential( 151 | nn.Conv2d( 152 | channels[self.first_level], 153 | classes, 154 | kernel_size=1, 155 | stride=1, 156 | padding=0, 157 | bias=True, 158 | ) 159 | ) 160 | up_factor = 2 ** self.first_level 161 | if up_factor > 1: 162 | up = nn.ConvTranspose2d( 163 | 
classes, 164 | classes, 165 | up_factor * 2, 166 | stride=up_factor, 167 | padding=up_factor // 2, 168 | output_padding=0, 169 | groups=classes, 170 | bias=False, 171 | ) 172 | fill_up_weights(up) 173 | up.weight.requires_grad = False 174 | else: 175 | up = Identity() 176 | self.up = up 177 | self.softmax = nn.LogSoftmax(dim=1) 178 | 179 | for m in self.fc.modules(): 180 | if isinstance(m, nn.Conv2d): 181 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 182 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 183 | elif isinstance(m, BatchNorm): 184 | m.weight.data.fill_(1) 185 | m.bias.data.zero_() 186 | 187 | def forward(self, x): 188 | x = self.base(x) 189 | x = self.dla_up(x[self.first_level :]) 190 | x = self.fc(x) 191 | y = self.softmax(self.up(x)) 192 | return y, x 193 | 194 | def optim_parameters(self, memo=None): 195 | for param in self.base.parameters(): 196 | yield param 197 | for param in self.dla_up.parameters(): 198 | yield param 199 | for param in self.fc.parameters(): 200 | yield param 201 | 202 | 203 | def dla34up(classes, pretrained_base=None, **kwargs): 204 | model = DLASeg("dla34", classes, pretrained_base=pretrained_base, **kwargs) 205 | return model 206 | 207 | 208 | def dla60up(classes, pretrained_base=None, **kwargs): 209 | model = DLASeg("dla60", classes, pretrained_base=pretrained_base, **kwargs) 210 | return model 211 | 212 | 213 | def dla102up(classes, pretrained_base=None, **kwargs): 214 | model = DLASeg("dla102", classes, pretrained_base=pretrained_base, **kwargs) 215 | return model 216 | 217 | 218 | def dla169up(classes, pretrained_base=None, **kwargs): 219 | model = DLASeg("dla169", classes, pretrained_base=pretrained_base, **kwargs) 220 | return model 221 | 222 | 223 | model = dla34up(classes=3, pretrained_base=None, down_ratio=2) 224 | print(model) 225 | -------------------------------------------------------------------------------- /src/model/fpn.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, Tensor 6 | 7 | from torch.jit.annotations import Tuple, List, Dict, Optional 8 | 9 | 10 | class ExtraFPNBlock(nn.Module): 11 | """ 12 | Base class for the extra block in the FPN. 13 | Arguments: 14 | results (List[Tensor]): the result of the FPN 15 | x (List[Tensor]): the original feature maps 16 | names (List[str]): the names for each one of the 17 | original feature maps 18 | Returns: 19 | results (List[Tensor]): the extended set of results 20 | of the FPN 21 | names (List[str]): the extended set of names for the results 22 | """ 23 | 24 | def forward( 25 | self, 26 | results: List[Tensor], 27 | x: List[Tensor], 28 | names: List[str], 29 | ) -> Tuple[List[Tensor], List[str]]: 30 | pass 31 | 32 | 33 | class FeaturePyramidNetworkRoddick(nn.Module): 34 | """ 35 | Module that adds a FPN from on top of a set of feature maps. This is based on 36 | `"Feature Pyramid Network for Object Detection" `_. 37 | The feature maps are currently supposed to be in increasing depth 38 | order. 39 | The input to the model is expected to be an OrderedDict[Tensor], containing 40 | the feature maps on top of which the FPN will be added. 41 | Arguments: 42 | in_channels_list (list[int]): number of channels for each feature map that 43 | is passed to the module 44 | out_channels (int): number of channels of the FPN representation 45 | extra_blocks (ExtraFPNBlock or None): if provided, extra operations will 46 | be performed. 
It is expected to take the fpn features, the original 47 | features and the names of the original features as input, and returns 48 | a new list of feature maps and their corresponding names 49 | Examples:: 50 | >>> m = torchvision.ops.FeaturePyramidNetwork([10, 20, 30], 5) 51 | >>> # get some dummy data 52 | >>> x = OrderedDict() 53 | >>> x['feat0'] = torch.rand(1, 10, 64, 64) 54 | >>> x['feat2'] = torch.rand(1, 20, 16, 16) 55 | >>> x['feat3'] = torch.rand(1, 30, 8, 8) 56 | >>> # compute the FPN on top of x 57 | >>> output = m(x) 58 | >>> print([(k, v.shape) for k, v in output.items()]) 59 | >>> # returns 60 | >>> [('feat0', torch.Size([1, 5, 64, 64])), 61 | >>> ('feat2', torch.Size([1, 5, 16, 16])), 62 | >>> ('feat3', torch.Size([1, 5, 8, 8]))] 63 | """ 64 | 65 | def __init__( 66 | self, 67 | in_channels_list: List[int], 68 | out_channels: int, 69 | extra_blocks: Optional[ExtraFPNBlock] = None, 70 | ): 71 | super(FeaturePyramidNetworkRoddick, self).__init__() 72 | self.inner_blocks = nn.ModuleList() 73 | self.layer_blocks = nn.ModuleList() 74 | for in_channels in in_channels_list: 75 | if in_channels == 0: 76 | raise ValueError("in_channels=0 is currently not supported") 77 | inner_block_module = nn.Conv2d(in_channels, out_channels, 1) 78 | layer_block_module = nn.Conv2d(out_channels, out_channels, 3, padding=1) 79 | self.inner_blocks.append(inner_block_module) 80 | self.layer_blocks.append(layer_block_module) 81 | 82 | # initialize parameters now to avoid modifying the initialization of top_blocks 83 | for m in self.children(): 84 | if isinstance(m, nn.Conv2d): 85 | nn.init.kaiming_uniform_(m.weight, a=1) 86 | nn.init.constant_(m.bias, 0) 87 | 88 | if extra_blocks is not None: 89 | assert isinstance(extra_blocks, ExtraFPNBlock) 90 | self.extra_blocks = extra_blocks 91 | 92 | def get_result_from_inner_blocks(self, x: Tensor, idx: int) -> Tensor: 93 | """ 94 | This is equivalent to self.inner_blocks[idx](x), 95 | but torchscript doesn't support this yet 96 | """ 97 | num_blocks = 0 98 | for m in self.inner_blocks: 99 | num_blocks += 1 100 | if idx < 0: 101 | idx += num_blocks 102 | i = 0 103 | out = x 104 | for module in self.inner_blocks: 105 | if i == idx: 106 | out = module(x) 107 | i += 1 108 | return out 109 | 110 | def get_result_from_layer_blocks(self, x: Tensor, idx: int) -> Tensor: 111 | """ 112 | This is equivalent to self.layer_blocks[idx](x), 113 | but torchscript doesn't support this yet 114 | """ 115 | num_blocks = 0 116 | for m in self.layer_blocks: 117 | num_blocks += 1 118 | if idx < 0: 119 | idx += num_blocks 120 | i = 0 121 | out = x 122 | for module in self.layer_blocks: 123 | if i == idx: 124 | out = module(x) 125 | i += 1 126 | return out 127 | 128 | def forward(self, x: Dict[str, Tensor]) -> Dict[str, Tensor]: 129 | """ 130 | Computes the FPN for a set of feature maps. 131 | Arguments: 132 | x (OrderedDict[Tensor]): feature maps for each feature level. 133 | Returns: 134 | results (OrderedDict[Tensor]): feature maps after FPN layers. 135 | They are ordered from highest resolution first. 
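            Example (a minimal sketch only, reusing the dummy shapes from the class-level
            Examples block; the names `x` and `out` are illustrative)::
                >>> m = FeaturePyramidNetworkRoddick([10, 20, 30], 5)
                >>> x = OrderedDict()
                >>> x['feat0'] = torch.rand(1, 10, 64, 64)
                >>> x['feat2'] = torch.rand(1, 20, 16, 16)
                >>> x['feat3'] = torch.rand(1, 30, 8, 8)
                >>> out = m(x)
                >>> print([(k, v.shape) for k, v in out.items()])
                >>> # returns
                >>> [('feat0', torch.Size([1, 5, 64, 64])),
                >>>  ('feat2', torch.Size([1, 5, 16, 16])),
                >>>  ('feat3', torch.Size([1, 5, 8, 8]))]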
136 | """ 137 | # unpack OrderedDict into two lists for easier handling 138 | names = list(x.keys()) 139 | x = list(x.values()) 140 | 141 | last_inner = self.get_result_from_inner_blocks(x[-1], -1) 142 | results = [] 143 | results.append(self.get_result_from_layer_blocks(last_inner, -1)) 144 | 145 | for idx in range(len(x) - 2, -1, -1): 146 | inner_lateral = self.get_result_from_inner_blocks(x[idx], idx) 147 | feat_shape = inner_lateral.shape[-2:] 148 | inner_top_down = F.interpolate(last_inner, size=feat_shape, mode="nearest") 149 | last_inner = inner_lateral + inner_top_down 150 | results.insert(0, self.get_result_from_layer_blocks(last_inner, idx)) 151 | 152 | if self.extra_blocks is not None: 153 | results, names = self.extra_blocks(results, x, names) 154 | 155 | # make it back an OrderedDict 156 | out = OrderedDict([(k, v) for k, v in zip(names, results)]) 157 | 158 | return out 159 | 160 | 161 | class LLMaxPool(ExtraFPNBlock): 162 | """ 163 | Applies a max_pool2d on top of the last feature map 164 | """ 165 | 166 | def forward( 167 | self, 168 | x: List[Tensor], 169 | y: List[Tensor], 170 | names: List[str], 171 | ) -> Tuple[List[Tensor], List[str]]: 172 | names.append("pool") 173 | x.append(F.max_pool2d(x[-1], 1, 2, 0)) 174 | return x, names 175 | 176 | 177 | class LastLevelP6P7(ExtraFPNBlock): 178 | """ 179 | This module is used in RetinaNet to generate extra layers, P6 and P7. 180 | """ 181 | 182 | def __init__(self, in_channels: int, out_channels: int): 183 | super(LastLevelP6P7, self).__init__() 184 | self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1) 185 | self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1) 186 | for module in [self.p6, self.p7]: 187 | nn.init.kaiming_uniform_(module.weight, a=1) 188 | nn.init.constant_(module.bias, 0) 189 | self.use_P5 = in_channels == out_channels 190 | 191 | def forward( 192 | self, 193 | p: List[Tensor], 194 | c: List[Tensor], 195 | names: List[str], 196 | ) -> Tuple[List[Tensor], List[str]]: 197 | p5, c5 = p[-1], c[-1] 198 | x = p5 if self.use_P5 else c5 199 | p6 = self.p6(x) 200 | p7 = self.p7(F.relu(p6)) 201 | p.extend([p6, p7]) 202 | names.extend(["p6", "p7"]) 203 | return p, names 204 | -------------------------------------------------------------------------------- /src/model/intermediate_layer_getter.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from collections import OrderedDict 3 | 4 | 5 | # using wonder's beautiful simplification: https://stackoverflow.com/questions/31174295/getattr-and-setattr-on-nested-objects/31174427?noredirect=1#comment86638618_31174427 6 | def rgetattr(obj, attr, *args): 7 | def _getattr(obj, attr): 8 | return getattr(obj, attr, *args) 9 | 10 | return functools.reduce(_getattr, [obj] + attr.split(".")) 11 | 12 | 13 | class IntermediateLayerGetter: 14 | def __init__(self, model, return_layers, keep_output=True): 15 | """Wraps a Pytorch module to get intermediate values 16 | 17 | Arguments: 18 | model {nn.module} -- The Pytorch module to call 19 | return_layers {dict} -- Dictionary with the selected submodules 20 | to return the output (format: {[current_module_name]: [desired_output_name]}, 21 | current_module_name can be a nested submodule, e.g. 
submodule1.submodule2.submodule3) 22 | 23 | Keyword Arguments: 24 | keep_output {bool} -- If True model_output contains the final model's output 25 | in the other case model_output is None (default: {True}) 26 | Returns: 27 | (mid_outputs {OrderedDict}, model_output {any}) -- mid_outputs keys are 28 | your desired_output_name (s) and their values are the returned tensors 29 | of those submodules (OrderedDict([(desired_output_name,tensor(...)), ...). 30 | See keep_output argument for model_output description. 31 | In case a submodule is called more than one time, all it's outputs are 32 | stored in a list. 33 | """ 34 | self._model = model 35 | self.return_layers = return_layers 36 | self.keep_output = keep_output 37 | 38 | def __call__(self, *args, **kwargs): 39 | ret = OrderedDict() 40 | handles = [] 41 | for name, new_name in self.return_layers.items(): 42 | layer = rgetattr(self._model, name) 43 | 44 | def hook(module, input, output, new_name=new_name): 45 | if new_name in ret: 46 | if type(ret[new_name]) is list: 47 | ret[new_name].append(output) 48 | else: 49 | ret[new_name] = [ret[new_name], output] 50 | else: 51 | ret[new_name] = output 52 | 53 | try: 54 | h = layer.register_forward_hook(hook) 55 | except AttributeError as e: 56 | raise AttributeError(f"Module {name} not found") 57 | handles.append(h) 58 | 59 | if self.keep_output: 60 | output = self._model(*args, **kwargs) 61 | else: 62 | self._model(*args, **kwargs) 63 | output = None 64 | 65 | for h in handles: 66 | h.remove() 67 | 68 | return ret, output 69 | -------------------------------------------------------------------------------- /src/model/loss.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | 8 | import src 9 | from torch.nn.utils.rnn import pad_sequence 10 | 11 | 12 | def class_obj_size(): 13 | """ 14 | NuScenes class object size on Roddick Split 15 | gt_dir/name: semantic_maps_new_200x200 16 | """ 17 | 18 | # Object size per class relative to map resolution 19 | s200 = [0.263, 0.031, 0.055, 0.029, 0.009, 0.002, 0.010, 0.001, 0.008, 0.006, 0.001] 20 | s100 = [0.296, 0.037, 0.066, 0.035, 0.011, 0.003, 0.013, 0.002, 0.010, 0.008, 0.002] 21 | s50 = [0.331, 0.046, 0.083, 0.044, 0.015, 0.004, 0.019, 0.003, 0.014, 0.011, 0.003] 22 | s25 = [0.380, 0.065, 0.114, 0.061, 0.022, 0.007, 0.031, 0.006, 0.022, 0.017, 0.008] 23 | s13 = [0.465, 0.105, 0.169, 0.093, 0.037, 0.016, 0.058, 0.013, 0.039, 0.031, 0.021] 24 | s7 = [0.599, 0.196, 0.279, 0.148, 0.069, 0.037, 0.117, 0.033, 0.076, 0.065, 0.054] 25 | 26 | ms_obj_size = torch.stack( 27 | [ 28 | torch.tensor(s100), 29 | torch.tensor(s50), 30 | torch.tensor(s25), 31 | torch.tensor(s13), 32 | torch.tensor(s7), 33 | ] 34 | ) 35 | 36 | return ms_obj_size.cuda() 37 | 38 | 39 | def get_class_frequency(): 40 | """ 41 | NuScenes class frequency on Roddick Split 42 | gt_dir/name: semantic_maps_new_200x200 43 | """ 44 | 45 | # Class frequency 46 | s200 = [ 47 | 0.999, 48 | 0.489, 49 | 0.969, 50 | 0.294, 51 | 0.094, 52 | 0.080, 53 | 0.758, 54 | 0.0646, 55 | 0.070, 56 | 0.080, 57 | 0.313, 58 | 0.485, 59 | ] 60 | 61 | return torch.tensor(s200).cuda() 62 | 63 | 64 | def class_weight(): 65 | """ 66 | NuScenes class frequency on Roddick Split 67 | gt_dir/name: semantic_maps_new_200x200 68 | """ 69 | 70 | # Class weights 71 | w = torch.tensor([1.31, 5.43, 1.00, 2.13, 8.04, 1.19, 1.75, 5.71]) 72 | 73 | return 
w 74 | 75 | 76 | def class_flood_level(): 77 | ms_flood = torch.tensor( 78 | [ 79 | [ 80 | 0.0356, 81 | 0.2656, 82 | 0.1070, 83 | 0.6454, 84 | 0.8, 85 | 0.75, 86 | 0.1832, 87 | 0.8, 88 | 0.75, 89 | 0.6633, 90 | 0.5591, 91 | ], 92 | [ 93 | 0.0294, 94 | 0.2418, 95 | 0.0902, 96 | 0.5982, 97 | 0.75, 98 | 0.7, 99 | 0.1589, 100 | 0.8, 101 | 0.75, 102 | 0.6302, 103 | 0.5116, 104 | ], 105 | [ 106 | 0.0247, 107 | 0.2363, 108 | 0.0787, 109 | 0.5765, 110 | 0.75, 111 | 0.75, 112 | 0.1416, 113 | 0.8, 114 | 0.75, 115 | 0.6317, 116 | 0.4982, 117 | ], 118 | ] 119 | ).cuda() 120 | return ms_flood 121 | 122 | 123 | class MultiTaskLoss(nn.Module): 124 | def __init__(self, l1_reduction="mean"): 125 | super().__init__() 126 | 127 | self.log_vars = nn.Parameter(torch.zeros(2)) 128 | self.l1_reduction = l1_reduction 129 | 130 | def forward(self, preds, labels, bev, bev_pred): 131 | precision1 = torch.exp(-self.log_vars[0]) 132 | precision2 = torch.exp(-self.log_vars[1]) 133 | 134 | s1_loss = dice_loss_mean(preds[0], labels[0]) 135 | s2_loss = dice_loss_mean(preds[1], labels[1]) 136 | s4_loss = dice_loss_mean(preds[2], labels[2]) 137 | 138 | dice_l = s1_loss + s2_loss + s4_loss 139 | bev_l = F.l1_loss(bev_pred, bev, reduction=self.l1_reduction) 140 | 141 | total_loss = (precision1 * dice_l + self.log_vars[0]) + ( 142 | precision2 * bev_l + self.log_vars[1] 143 | ) 144 | 145 | return ( 146 | total_loss, 147 | float(self.log_vars[0]), 148 | float(self.log_vars[1]), 149 | float(dice_l), 150 | float(bev_l), 151 | ) 152 | 153 | 154 | def lovasz_grad(gt_sorted): 155 | """ 156 | Computes gradient of the Lovasz extension w.r.t sorted errors 157 | See Alg. 1 in paper 158 | """ 159 | p = len(gt_sorted) 160 | gts = gt_sorted.sum() 161 | intersection = gts - gt_sorted.float().cumsum(0) 162 | union = gts + (1 - gt_sorted).float().cumsum(0) 163 | jaccard = 1.0 - intersection / union 164 | if p > 1: # cover 1-pixel case 165 | jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] 166 | return jaccard 167 | 168 | 169 | def lovasz_hinge_flat(logits, labels): 170 | """ 171 | Binary Lovasz hinge loss 172 | logits: [P] Variable, logits at each prediction (between -\infty and +\infty) 173 | labels: [P] Tensor, binary ground truth labels (0 or 1) 174 | ignore: label to ignore 175 | """ 176 | logits = logits.view(-1) 177 | labels = labels.view(-1) 178 | signs = 2.0 * labels.float() - 1.0 179 | errors = 1.0 - logits * signs 180 | errors_sorted, perm = torch.sort(errors, dim=0, descending=True) 181 | perm = perm.data 182 | gt_sorted = labels[perm] 183 | grad = lovasz_grad(gt_sorted) 184 | loss = torch.dot(F.relu(errors_sorted), grad) 185 | return loss 186 | 187 | 188 | def basic_weighted_dice_loss_mean(pred, label, idx_scale=None, args=None): 189 | pred = torch.sigmoid(pred) 190 | label = label.float() 191 | intersection = 2 * pred * label 192 | union = pred + label 193 | iou = (intersection.float().sum(dim=0).sum(dim=-1).sum(dim=-1)) / ( 194 | union.float().sum(dim=0).sum(dim=-1).sum(dim=-1) + 1e-5 195 | ) 196 | 197 | # Set up class weighting 198 | weight = class_weight() 199 | 200 | loss_mean = (weight * (1 - iou)).mean() 201 | 202 | return loss_mean 203 | 204 | 205 | def weighted_dice_loss_mean(pred, label, idx_scale, args): 206 | pred = torch.sigmoid(pred) 207 | label = label.float() 208 | intersection = 2 * pred * label 209 | union = pred + label 210 | iou = (intersection.float().sum(dim=0).sum(dim=-1).sum(dim=-1)) / ( 211 | union.float().sum(dim=0).sum(dim=-1).sum(dim=-1) + 1e-5 212 | ) 213 | 214 | # Set up class weighting 215 | obj_size = 
class_obj_size()[idx_scale] 216 | class_freq = class_frequency()[idx_scale] 217 | weight = 1 / obj_size ** args.exp_os * 1 / class_freq ** args.exp_cf 218 | 219 | # Set large class' weights to one 220 | idxs_to_one = torch.tensor([0, 1, 2, 3, 6]).long() 221 | weight[idxs_to_one] = 1 222 | 223 | loss_mean = (weight * (1 - iou)).mean() 224 | 225 | return loss_mean 226 | 227 | 228 | def weighted_dice_loss_sum(pred, label, idx_scale, args): 229 | pred = torch.sigmoid(pred) 230 | label = label.float() 231 | intersection = 2 * pred * label 232 | union = pred + label 233 | iou = (intersection.float().sum(dim=0).sum(dim=-1).sum(dim=-1) + 1e-5) / ( 234 | union.float().sum(dim=0).sum(dim=-1).sum(dim=-1) + 1e-5 235 | ) 236 | 237 | # Set up class weighting 238 | obj_size = class_obj_size()[idx_scale] 239 | class_freq = class_frequency()[idx_scale] 240 | weight = 1 / obj_size ** args.exp_os * 1 / class_freq ** args.exp_cf 241 | 242 | loss_mean = (weight * (1 - iou)).sum() 243 | 244 | return loss_mean 245 | 246 | 247 | def dice_loss_sum(pred, label, scale_idx=None, args=None): 248 | pred = torch.sigmoid(pred) 249 | label = label.float() 250 | intersection = 2 * pred * label 251 | union = pred + label 252 | iou = (intersection.float().sum(dim=0).sum(dim=-1).sum(dim=-1)) / ( 253 | union.float().sum(dim=0).sum(dim=-1).sum(dim=-1) + 1e-5 254 | ) 255 | 256 | loss_mean = (1 - iou).sum() 257 | 258 | return loss_mean 259 | 260 | 261 | def dice_loss_mean(pred, label, vis_mask=None, scale_idx=None, args=None): 262 | pred = torch.sigmoid(pred) 263 | label = label.float() 264 | intersection = 2 * pred * label 265 | union = pred + label 266 | iou = (intersection.float().sum(dim=0).sum(dim=-1).sum(dim=-1)) / ( 267 | union.float().sum(dim=0).sum(dim=-1).sum(dim=-1) + 1e-5 268 | ) 269 | 270 | loss_mean = 1 - iou.mean() 271 | 272 | return loss_mean 273 | 274 | 275 | def dice_loss_mean_wo_sigmoid(pred, label, vis_mask=None, scale_idx=None, args=None): 276 | label = label.float() 277 | intersection = 2 * pred * label 278 | union = pred + label 279 | iou = (intersection.float().sum()) / (union.float().sum() + 1e-5) 280 | 281 | loss_mean = 1 - iou.mean() 282 | 283 | return loss_mean 284 | 285 | 286 | def dice_loss_mean_vis_only(pred, label, vis_mask, scale_idx=None, args=None): 287 | # Calculates the loss in the visible areas only 288 | pred = torch.sigmoid(pred) * vis_mask.float() 289 | label = label.float() 290 | intersection = 2 * pred * label 291 | union = pred + label 292 | iou = (intersection.float().sum(dim=0).sum(dim=-1).sum(dim=-1)) / ( 293 | union.float().sum(dim=0).sum(dim=-1).sum(dim=-1) + 1e-5 294 | ) 295 | 296 | loss_mean = 1 - iou.mean() 297 | 298 | return loss_mean 299 | 300 | 301 | def dice2_loss_mean(pred, label, scale_idx=None, args=None): 302 | pred = torch.sigmoid(pred) 303 | label = label.float() 304 | intersection = 2 * pred * label 305 | union = pred + label 306 | iou = (intersection.float().sum(dim=0).sum(dim=-1).sum(dim=-1) + 1.0) / ( 307 | union.float().sum(dim=0).sum(dim=-1).sum(dim=-1) + 1.0 308 | ) 309 | 310 | loss_mean = 1 - iou.mean() 311 | 312 | return loss_mean 313 | 314 | 315 | def dice_loss_mean_mix1(pred, label, scale_idx=None, args=None): 316 | pred = torch.sigmoid(pred) 317 | label = label.float() 318 | intersection = 2 * pred * label 319 | union = pred + label 320 | 321 | e_numerator = torch.zeros(pred.shape[1]) + 1e-5 322 | idxs_to_zero_out = torch.tensor([4, 5, 7, 8, 10]).long() 323 | e_numerator[idxs_to_zero_out] = 0 324 | e_numerator = e_numerator.cuda() 325 | 326 | iou = 
(intersection.float().sum(dim=0).sum(dim=-1).sum(dim=-1) + e_numerator) / ( 327 | union.float().sum(dim=0).sum(dim=-1).sum(dim=-1) + 1e-5 328 | ) 329 | 330 | loss_mean = 1 - iou.mean() 331 | 332 | return loss_mean 333 | 334 | 335 | def dice_loss_mean_mix2(pred, label, scale_idx=None, args=None): 336 | pred = torch.sigmoid(pred) 337 | label = label.float() 338 | intersection = 2 * pred * label 339 | union = pred + label 340 | 341 | e_numerator = torch.zeros(pred.shape[1]) + 1e-5 342 | idxs_to_zero_out = torch.tensor([4, 5, 7, 8]).long() 343 | e_numerator[idxs_to_zero_out] = 0 344 | e_numerator = e_numerator.cuda() 345 | 346 | iou = (intersection.float().sum(dim=0).sum(dim=-1).sum(dim=-1) + e_numerator) / ( 347 | union.float().sum(dim=0).sum(dim=-1).sum(dim=-1) + 1e-5 348 | ) 349 | 350 | loss_mean = 1 - iou.mean() 351 | 352 | return loss_mean 353 | 354 | 355 | def dice_loss_per_class(pred, labels, scale_idx=None): 356 | pred = torch.sigmoid(pred) 357 | labels = labels.float() 358 | intersection = 2 * pred * labels 359 | union = pred + labels 360 | iou = (intersection.float().sum(dim=0).sum(dim=-1).sum(dim=-1)) / ( 361 | union.float().sum(dim=0).sum(dim=-1).sum(dim=-1) + 1e-5 362 | ) 363 | return 1 - iou 364 | 365 | 366 | def dice_loss_per_class_infer(pred, labels, scale_idx=None): 367 | pred = torch.sigmoid(pred) 368 | labels = labels.float() 369 | intersection = (2 * pred * labels).float().sum(dim=0).sum(dim=-1).sum(dim=-1) 370 | union = (pred + labels).float().sum(dim=0).sum(dim=-1).sum(dim=-1) 371 | iou = (intersection) / (union + 1e-5) 372 | return 1 - iou, intersection, union 373 | 374 | 375 | def flooded_dice_loss_mean(dice_loss, scale_idx): 376 | 377 | # Flood level per class 378 | flood_level = class_flood_level()[scale_idx] 379 | 380 | flood_loss = (dice_loss - flood_level).abs() + flood_level 381 | 382 | return flood_loss.mean() 383 | 384 | 385 | def flooded_dice_loss_sum(dice_loss, scale_idx): 386 | 387 | # Flood level per class 388 | flood_level = class_flood_level()[scale_idx] 389 | 390 | flood_loss = (dice_loss - flood_level).abs() + flood_level 391 | 392 | return flood_loss.sum() 393 | 394 | 395 | def iou_loss(pred, labels): 396 | e = 1e-6 397 | pred = torch.sigmoid(pred) 398 | labels = labels.float() 399 | intersection = pred * labels 400 | union = (pred + labels) - (pred * labels) 401 | iou = intersection.sum() / (union.sum() + e) 402 | return 1.0 - iou 403 | 404 | 405 | def bce_plus_dice_loss(pred, labels, weights, alpha): 406 | bce = F.binary_cross_entropy_with_logits( 407 | input=pred, target=labels, weight=weights, reduction="mean" 408 | ) 409 | dice = dice_loss_mean(pred, labels) 410 | total = (alpha[0] * bce + alpha[1] * dice).sum() 411 | return total 412 | 413 | 414 | def bce_plus_iou_loss(pred, labels, weights, alpha): 415 | bce = F.binary_cross_entropy_with_logits( 416 | input=pred, target=labels, weight=weights, reduction="mean" 417 | ) 418 | iou = iou_loss(pred, labels) 419 | total = (alpha[0] * bce + alpha[1] * iou).sum() 420 | return total 421 | 422 | 423 | def class_balanced_loss( 424 | labels, logits, samples_per_cls, no_of_classes, loss_type, alpha=None 425 | ): 426 | """Compute the Class Balanced Loss between `logits` and the ground truth `labels`. 427 | Class Balanced Loss: ((1-beta)/(1-beta^n))*Loss(labels, logits) 428 | where Loss is one of the standard losses used for Neural Networks. 429 | Args: 430 | labels: A int tensor of size [batch]. 431 | logits: A float tensor of size [batch, no_of_classes]. 
432 | samples_per_cls: A python list of size [no_of_classes]. 433 | no_of_classes: total number of classes. int 434 | loss_type: string. One of "sigmoid", "focal", "softmax". 435 | beta: float. Hyperparameter for Class balanced loss. 436 | gamma: float. Hyperparameter for Focal loss. 437 | Returns: 438 | cb_loss: A float tensor representing class balanced loss 439 | """ 440 | # effective_num = 1.0 - np.power(beta, samples_per_cls) 441 | # weights = (1.0 - beta) / np.array(effective_num) 442 | # weights = weights / np.sum(weights) * no_of_classes 443 | 444 | if loss_type == "bce": 445 | labels_one_hot = F.one_hot(labels.long(), no_of_classes).float() 446 | weights = 1 / np.array(samples_per_cls) 447 | weights = torch.tensor(weights).float().cuda() 448 | weights = weights.view(1, 1, 1, 1, -1) 449 | weights = ( 450 | weights.repeat( 451 | labels_one_hot.shape[0], 452 | labels_one_hot.shape[1], 453 | labels_one_hot.shape[2], 454 | labels_one_hot.shape[3], 455 | 1, 456 | ) 457 | * labels_one_hot 458 | ) 459 | weights = weights.sum(-1) 460 | cb_loss = F.binary_cross_entropy_with_logits( 461 | input=logits, target=labels, weight=weights 462 | ) 463 | elif loss_type == "iou": 464 | cb_loss = iou_loss(logits, labels) 465 | elif loss_type == "dice": 466 | cb_loss = dice_loss_mean(logits, labels) 467 | elif loss_type == "lovasz": 468 | cb_loss = lovasz_hinge_flat(logits, labels) 469 | # elif loss_type == "bce_iou": 470 | # cb_loss = bce_plus_iou_loss(logits, labels, weights, alpha) 471 | # elif loss_type == "bce_dice": 472 | # cb_loss = bce_plus_dice_loss(logits, labels, weights, alpha) 473 | 474 | return cb_loss 475 | 476 | 477 | def weighted_bce_loss(pred, label, vis_mask): 478 | # Apply visibility mask to predictions 479 | pred = pred * vis_mask.float() 480 | 481 | pred = torch.sigmoid(pred) 482 | 483 | label = label.float() 484 | 485 | eps = 1e-12 486 | 487 | # Reshape weights for broadcasting 488 | cf = get_class_frequency()[None, :, None, None] 489 | pos_weights = torch.sqrt(1 / cf) 490 | neg_weights = torch.sqrt(1 / (1 - cf)) 491 | 492 | # labels * -log(sigmoid(logits)) + 493 | # (1 - labels) * -log(1 - sigmoid(logits)) 494 | 495 | loss = pos_weights * label * -torch.log(pred + eps) + (1 - label) * -torch.log( 496 | 1 - pred + eps 497 | ) 498 | 499 | return loss.sum(dim=-1).sum(dim=-1).sum(dim=0) 500 | 501 | 502 | def uncertainty_loss(pred, vis_mask): 503 | # Invert visibility mask 504 | vis_mask = 1 - vis_mask.float() 505 | eps = 1e-12 506 | pred = pred * vis_mask 507 | pred = torch.sigmoid(pred) 508 | loss = 1 - pred * torch.log(pred + eps) 509 | 510 | return loss.sum(dim=-1).sum(dim=-1).sum(dim=0) 511 | 512 | 513 | def aux_recognition_loss(pred, label): 514 | """ 515 | BCE loss for class probabilities, multi-label classification 516 | """ 517 | # pred = torch.sigmoid(pred) 518 | label = src.utils.count_classes_per_sample(label) 519 | 520 | assert pred.shape == label.shape 521 | 522 | class_freq = get_class_frequency()[None, :] 523 | weights = (1 / class_freq).repeat(len(pred), 1) 524 | 525 | loss = F.binary_cross_entropy_with_logits( 526 | pred, label, weight=weights, reduction="none" 527 | ) 528 | return loss 529 | 530 | 531 | def focal_loss(pred, gt, alpha, gamma): 532 | BCE_loss = F.binary_cross_entropy_with_logits( 533 | pred, gt.float(), reduction="none" 534 | ).sum(-1) 535 | gt, _ = torch.max(gt, dim=-1) 536 | gt = gt.long() 537 | at = torch.tensor(alpha, device=gt.device).gather(0, gt.data.view(-1)) 538 | pt = torch.exp(-BCE_loss) 539 | F_loss = at * (1 - pt) ** gamma * BCE_loss 540 
| return F_loss 541 | -------------------------------------------------------------------------------- /src/model/mocha2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Energy(nn.Module): 7 | def __init__(self, enc_dim=10, dec_dim=10, att_dim=10, init_r=-4): 8 | """ 9 | [Modified Bahdahnau attention] from 10 | "Online and Linear-Time Attention by Enforcing Monotonic Alignment" (ICML 2017) 11 | http://arxiv.org/abs/1704.00784 12 | Used for Monotonic Attention and Chunk Attention 13 | """ 14 | super().__init__() 15 | self.tanh = nn.Tanh() 16 | self.W = nn.Linear(enc_dim, att_dim, bias=False) 17 | self.V = nn.Linear(dec_dim, att_dim, bias=False) 18 | self.b = nn.Parameter(torch.Tensor(att_dim).normal_()) 19 | 20 | self.v = nn.utils.weight_norm(nn.Linear(10, 1)) 21 | self.v.weight_g.data = torch.Tensor([1 / att_dim]).sqrt() 22 | 23 | self.r = nn.Parameter(torch.Tensor([init_r])) 24 | 25 | def forward(self, encoder_outputs, decoder_h): 26 | """ 27 | Args: 28 | encoder_outputs: [batch_size, sequence_length, enc_dim] 29 | decoder_h: [batch_size, dec_dim] 30 | Return: 31 | Energy [batch_size, sequence_length] 32 | """ 33 | batch_size, sequence_length, enc_dim = encoder_outputs.size() 34 | encoder_outputs = encoder_outputs.view(-1, enc_dim) 35 | energy = self.tanh( 36 | self.W(encoder_outputs) 37 | + self.V(decoder_h).repeat(sequence_length, 1) 38 | + self.b 39 | ) 40 | energy = self.v(energy).squeeze(-1) + self.r 41 | 42 | return energy.view(batch_size, sequence_length) 43 | 44 | 45 | class MonotonicAttention(nn.Module): 46 | def __init__(self): 47 | """ 48 | [Monotonic Attention] from 49 | "Online and Linear-Time Attention by Enforcing Monotonic Alignment" (ICML 2017) 50 | http://arxiv.org/abs/1704.00784 51 | """ 52 | super().__init__() 53 | 54 | self.monotonic_energy = Energy() 55 | self.sigmoid = nn.Sigmoid() 56 | 57 | def gaussian_noise(self, *size): 58 | """Additive gaussian nosie to encourage discreteness""" 59 | if torch.cuda.is_available(): 60 | return torch.cuda.FloatTensor(*size).normal_() 61 | else: 62 | return torch.Tensor(*size).normal_() 63 | 64 | def safe_cumprod(self, x): 65 | """Numerically stable cumulative product by cumulative sum in log-space""" 66 | return torch.exp( 67 | torch.cumsum(torch.log(torch.clamp(x, min=1e-10, max=1)), dim=1) 68 | ) 69 | 70 | def exclusive_cumprod(self, x): 71 | """Exclusive cumulative product [a, b, c] => [1, a, a * b] 72 | * TensorFlow: https://www.tensorflow.org/api_docs/python/tf/cumprod 73 | * PyTorch: https://discuss.pytorch.org/t/cumprod-exclusive-true-equivalences/2614 74 | """ 75 | batch_size, sequence_length = x.size() 76 | if torch.cuda.is_available(): 77 | one_x = torch.cat([torch.ones(batch_size, 1).cuda(), x], dim=1)[:, :-1] 78 | else: 79 | one_x = torch.cat([torch.ones(batch_size, 1), x], dim=1)[:, :-1] 80 | return torch.cumprod(one_x, dim=1) 81 | 82 | def soft(self, encoder_outputs, decoder_h, previous_alpha=None): 83 | """ 84 | Soft monotonic attention (Train) 85 | Args: 86 | encoder_outputs [batch_size, sequence_length, enc_dim] 87 | decoder_h [batch_size, dec_dim] 88 | previous_alpha [batch_size, sequence_length] 89 | Return: 90 | alpha [batch_size, sequence_length] 91 | """ 92 | batch_size, sequence_length, enc_dim = encoder_outputs.size() 93 | 94 | monotonic_energy = self.monotonic_energy(encoder_outputs, decoder_h) 95 | p_select = self.sigmoid( 96 | monotonic_energy + 
self.gaussian_noise(monotonic_energy.size()) 97 | ) 98 | cumprod_1_minus_p = self.safe_cumprod(1 - p_select) 99 | 100 | if previous_alpha is None: 101 | # First iteration => alpha = [1, 0, 0 ... 0] 102 | alpha = torch.zeros(batch_size, sequence_length) 103 | alpha[:, 0] = torch.ones(batch_size) 104 | if torch.cuda.is_available: 105 | alpha = alpha.cuda() 106 | 107 | else: 108 | alpha = ( 109 | p_select 110 | * cumprod_1_minus_p 111 | * torch.cumsum(previous_alpha / cumprod_1_minus_p, dim=1) 112 | ) 113 | 114 | return alpha 115 | 116 | def hard(self, encoder_outputs, decoder_h, previous_attention=None): 117 | """ 118 | Hard monotonic attention (Test) 119 | Args: 120 | encoder_outputs [batch_size, sequence_length, enc_dim] 121 | decoder_h [batch_size, dec_dim] 122 | previous_attention [batch_size, sequence_length] 123 | Return: 124 | alpha [batch_size, sequence_length] 125 | """ 126 | batch_size, sequence_length, enc_dim = encoder_outputs.size() 127 | 128 | if previous_attention is None: 129 | # First iteration => alpha = [1, 0, 0 ... 0] 130 | attention = torch.zeros(batch_size, sequence_length) 131 | attention[:, 0] = torch.ones(batch_size) 132 | if torch.cuda.is_available: 133 | attention = attention.cuda() 134 | else: 135 | # TODO: Linear Time Decoding 136 | # It's not clear if authors' TF implementation decodes in linear time. 137 | # https://github.com/craffel/mad/blob/master/example_decoder.py#L235 138 | # They calculate energies for whole encoder outputs 139 | # instead of scanning from previous attended encoder output. 140 | monotonic_energy = self.monotonic_energy(encoder_outputs, decoder_h) 141 | 142 | # Hard Sigmoid 143 | # Attend when monotonic energy is above threshold (Sigmoid > 0.5) 144 | above_threshold = (monotonic_energy > 0).float() 145 | 146 | p_select = above_threshold * torch.cumsum(previous_attention, dim=1) 147 | attention = p_select * self.exclusive_cumprod(1 - p_select) 148 | 149 | # Not attended => attend at last encoder output 150 | # Assume that encoder outputs are not padded 151 | attended = attention.sum(dim=1) 152 | for batch_i in range(batch_size): 153 | if not attended[batch_i]: 154 | attention[batch_i, -1] = 1 155 | 156 | # Ex) 157 | # p_select = [0, 0, 0, 1, 1, 0, 1, 1] 158 | # 1 - p_select = [1, 1, 1, 0, 0, 1, 0, 0] 159 | # exclusive_cumprod(1 - p_select) = [1, 1, 1, 1, 0, 0, 0, 0] 160 | # attention: product of above = [0, 0, 0, 1, 0, 0, 0, 0] 161 | return attention 162 | 163 | 164 | class MoChA(MonotonicAttention): 165 | def __init__(self, chunk_size=3): 166 | """ 167 | [Monotonic Chunkwise Attention] from 168 | "Monotonic Chunkwise Attention" (ICLR 2018) 169 | https://openreview.net/forum?id=Hko85plCW 170 | """ 171 | super().__init__() 172 | self.chunk_size = chunk_size 173 | self.chunk_energy = Energy() 174 | self.softmax = nn.Softmax(dim=1) 175 | 176 | def moving_sum(self, x, back, forward): 177 | """Parallel moving sum with 1D Convolution""" 178 | # Pad window before applying convolution 179 | # [batch_size, back + sequence_length + forward] 180 | x_padded = F.pad(x, pad=[back, forward]) 181 | 182 | # Fake channel dimension for conv1d 183 | # [batch_size, 1, back + sequence_length + forward] 184 | x_padded = x_padded.unsqueeze(1) 185 | 186 | # Apply conv1d with filter of all ones for moving sum 187 | filters = torch.ones(1, 1, back + forward + 1) 188 | if torch.cuda.is_available(): 189 | filters = filters.cuda() 190 | x_sum = F.conv1d(x_padded, filters) 191 | 192 | # Remove fake channel dimension 193 | # [batch_size, sequence_length] 194 | return 
x_sum.squeeze(1) 195 | 196 | def chunkwise_attention_soft(self, alpha, u): 197 | """ 198 | Args: 199 | alpha [batch_size, sequence_length]: emission probability in monotonic attention 200 | u [batch_size, sequence_length]: chunk energy 201 | chunk_size (int): window size of chunk 202 | Return 203 | beta [batch_size, sequence_length]: MoChA weights 204 | """ 205 | 206 | # Numerical stability 207 | # Divide by same exponent => doesn't affect softmax 208 | u -= torch.max(u, dim=1, keepdim=True)[0] 209 | exp_u = torch.exp(u) 210 | # Limit range of logit 211 | exp_u = torch.clamp(exp_u, min=1e-5) 212 | 213 | # Moving sum: 214 | # Zero-pad (chunk size - 1) on the left + 1D conv with filters of 1s. 215 | # [batch_size, sequence_length] 216 | denominators = self.moving_sum(exp_u, back=self.chunk_size - 1, forward=0) 217 | 218 | # Compute beta (MoChA weights) 219 | beta = exp_u * self.moving_sum( 220 | alpha / denominators, back=0, forward=self.chunk_size - 1 221 | ) 222 | return beta 223 | 224 | def chunkwise_attention_hard(self, monotonic_attention, chunk_energy): 225 | """ 226 | Mask non-attended area with '-inf' 227 | Args: 228 | monotonic_attention [batch_size, sequence_length] 229 | chunk_energy [batch_size, sequence_length] 230 | Return: 231 | masked_energy [batch_size, sequence_length] 232 | """ 233 | batch_size, sequence_length = monotonic_attention.size() 234 | 235 | # [batch_size] 236 | attended_indices = monotonic_attention.nonzero().cpu().data[:, 1].tolist() 237 | 238 | i = [[], []] 239 | total_i = 0 240 | for batch_i, attended_idx in enumerate(attended_indices): 241 | for window in range(self.chunk_size): 242 | if attended_idx - window >= 0: 243 | i[0].append(batch_i) 244 | i[1].append(attended_idx - window) 245 | total_i += 1 246 | i = torch.LongTensor(i) 247 | v = torch.FloatTensor([1] * total_i) 248 | mask = torch.sparse.FloatTensor(i, v, monotonic_attention.size()) 249 | mask = ~mask.to_dense().cuda().byte() 250 | 251 | # mask '-inf' energy before softmax 252 | masked_energy = chunk_energy.masked_fill_(mask, -float("inf")) 253 | return masked_energy 254 | 255 | def soft(self, encoder_outputs, decoder_h, previous_alpha=None): 256 | """ 257 | Soft monotonic chunkwise attention (Train) 258 | Args: 259 | encoder_outputs [batch_size, sequence_length, enc_dim] 260 | decoder_h [batch_size, dec_dim] 261 | previous_alpha [batch_size, sequence_length] 262 | Return: 263 | alpha [batch_size, sequence_length] 264 | beta [batch_size, sequence_length] 265 | """ 266 | alpha = super().soft(encoder_outputs, decoder_h, previous_alpha) 267 | chunk_energy = self.chunk_energy(encoder_outputs, decoder_h) 268 | beta = self.chunkwise_attention_soft(alpha, chunk_energy) 269 | return alpha, beta 270 | 271 | def hard(self, encoder_outputs, decoder_h, previous_attention=None): 272 | """ 273 | Hard monotonic chunkwise attention (Test) 274 | Args: 275 | encoder_outputs [batch_size, sequence_length, enc_dim] 276 | decoder_h [batch_size, dec_dim] 277 | previous_attention [batch_size, sequence_length] 278 | Return: 279 | monotonic_attention [batch_size, sequence_length]: hard alpha 280 | chunkwise_attention [batch_size, sequence_length]: hard beta 281 | """ 282 | # hard attention (one-hot) 283 | # [batch_size, sequence_length] 284 | monotonic_attention = super().hard( 285 | encoder_outputs, decoder_h, previous_attention 286 | ) 287 | chunk_energy = self.chunk_energy(encoder_outputs, decoder_h) 288 | masked_energy = self.chunkwise_attention_hard(monotonic_attention, chunk_energy) 289 | chunkwise_attention = 
self.softmax(masked_energy) 290 | return monotonic_attention, chunkwise_attention 291 | -------------------------------------------------------------------------------- /src/model/monotonic_attention.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import random 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class MonotonicAttention(nn.Module): 10 | """ 11 | Monotonic Multihead Attention with Infinite Lookback 12 | """ 13 | 14 | def __init__( 15 | self, 16 | embed_dim, 17 | num_heads, 18 | kdim=None, 19 | vdim=None, 20 | dropout=0.0, 21 | bias=True, 22 | energy_bias=True, 23 | ): 24 | self.embed_dim = embed_dim 25 | self.kdim = kdim if kdim is not None else embed_dim 26 | self.vdim = vdim if vdim is not None else embed_dim 27 | self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim 28 | 29 | self.k_proj_mono = nn.Linear(embed_dim, embed_dim, bias=bias) 30 | self.q_proj_mono = nn.Linear(embed_dim, embed_dim, bias=bias) 31 | self.k_proj_soft = nn.Linear(embed_dim, embed_dim, bias=bias) 32 | self.q_proj_soft = nn.Linear(embed_dim, embed_dim, bias=bias) 33 | self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias) 34 | 35 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 36 | 37 | self.num_heads = num_heads 38 | self.head_dim = embed_dim // num_heads 39 | assert ( 40 | self.head_dim * num_heads == self.embed_dim 41 | ), "embed_dim must be divisible by num_heads" 42 | self.scaling = self.head_dim ** -0.5 43 | 44 | self.eps = 1e-6 45 | 46 | self.noise_mean = 0.0 47 | self.noise_var = 1.0 48 | 49 | self.energy_bias_init = -2.0 50 | self.energy_bias = ( 51 | nn.Parameter(self.energy_bias_init * torch.ones([1])) 52 | if energy_bias is True 53 | else 0 54 | ) 55 | 56 | self.reset_parameters() 57 | 58 | self.k_in_proj = {"monotonic": self.k_proj_mono, "soft": self.k_proj_soft} 59 | self.q_in_proj = {"monotonic": self.q_proj_mono, "soft": self.q_proj_soft} 60 | self.v_in_proj = {"output": self.v_proj} 61 | 62 | def reset_parameters(self): 63 | if self.qkv_same_dim: 64 | # Empirically observed the convergence to be much better with 65 | # the scaled initialization 66 | nn.init.xavier_uniform_(self.k_proj_mono.weight, gain=1 / math.sqrt(2)) 67 | nn.init.xavier_uniform_(self.k_proj_soft.weight, gain=1 / math.sqrt(2)) 68 | nn.init.xavier_uniform_(self.q_proj_mono.weight, gain=1 / math.sqrt(2)) 69 | nn.init.xavier_uniform_(self.q_proj_soft.weight, gain=1 / math.sqrt(2)) 70 | nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) 71 | 72 | else: 73 | nn.init.xavier_uniform_(self.k_proj_mono.weight) 74 | nn.init.xavier_uniform_(self.k_proj_soft.weight) 75 | nn.init.xavier_uniform_(self.q_proj_mono.weight) 76 | nn.init.xavier_uniform_(self.q_proj_soft.weight) 77 | nn.init.xavier_uniform_(self.v_proj.weight) 78 | 79 | nn.init.xavier_uniform_(self.out_proj.weight) 80 | if self.out_proj.bias is not None: 81 | nn.init.constant_(self.out_proj.bias, 0.0) 82 | 83 | def input_projections(self, query, key, value, name): 84 | """ 85 | Prepare inputs for multihead attention 86 | ============================================================ 87 | Expected input size 88 | query: tgt_len, bsz, embed_dim 89 | key: src_len, bsz, embed_dim 90 | value: src_len, bsz, embed_dim 91 | name: monotonic or soft 92 | """ 93 | 94 | if query is not None: 95 | bsz = query.size(1) 96 | q = self.q_in_proj(query) 97 | q *= self.scaling 98 | q = ( 99 | q.contiguous() 100 | .view(-1, 
bsz * self.num_heads, self.head_dim) 101 | .transpose(0, 1) 102 | ) 103 | else: 104 | q = None 105 | 106 | if key is not None: 107 | bsz = key.size(1) 108 | k = self.k_in_proj(key) 109 | k = ( 110 | k.contiguous() 111 | .view(-1, bsz * self.num_heads, self.head_dim) 112 | .transpose(0, 1) 113 | ) 114 | else: 115 | k = None 116 | 117 | if value is not None: 118 | bsz = value.size(1) 119 | v = self.v_proj(value) 120 | v = ( 121 | v.contiguous() 122 | .view(-1, bsz * self.num_heads, self.head_dim) 123 | .transpose(0, 1) 124 | ) 125 | else: 126 | v = None 127 | 128 | return q, k, v 129 | 130 | def p_choose(self, query, key, key_padding_mask=None): 131 | """ 132 | Calculating step wise prob for reading and writing 133 | 1 to read, 0 to write 134 | ============================================================ 135 | Expected input size 136 | query: bsz, tgt_len, embed_dim 137 | key: bsz, src_len, embed_dim 138 | value: bsz, src_len, embed_dim 139 | key_padding_mask: bsz, src_len 140 | attn_mask: bsz, src_len 141 | query: bsz, tgt_len, embed_dim 142 | """ 143 | 144 | # prepare inputs 145 | q_proj, k_proj, _ = self.input_projections(query, key, None, "monotonic") 146 | 147 | # attention energy 148 | attn_energy = self.attn_energy(q_proj, k_proj, key_padding_mask) 149 | 150 | noise = 0 151 | 152 | if self.training: 153 | # add noise here to encourage discretness 154 | noise = ( 155 | torch.normal(self.noise_mean, self.noise_var, attn_energy.size()) 156 | .type_as(attn_energy) 157 | .to(attn_energy.device) 158 | ) 159 | 160 | p_choose = torch.sigmoid(attn_energy + noise) 161 | _, _, tgt_len, src_len = p_choose.size() 162 | 163 | # p_choose: bsz * self.num_heads, tgt_len, src_len 164 | return p_choose.view(-1, tgt_len, src_len) 165 | 166 | def attn_energy(self, q_proj, k_proj, key_padding_mask=None): 167 | """ 168 | Calculating monotonic energies 169 | ============================================================ 170 | Expected input size 171 | q_proj: bsz * num_heads, tgt_len, self.head_dim 172 | k_proj: bsz * num_heads, src_len, self.head_dim 173 | key_padding_mask: bsz, src_len 174 | attn_mask: tgt_len, src_len 175 | """ 176 | bsz, tgt_len, embed_dim = q_proj.size() 177 | bsz = bsz // self.num_heads 178 | src_len = k_proj.size(1) 179 | 180 | attn_energy = torch.bmm(q_proj, k_proj.transpose(1, 2)) + self.energy_bias 181 | 182 | attn_energy = attn_energy.view(bsz, self.num_heads, tgt_len, src_len) 183 | 184 | if key_padding_mask is not None: 185 | attn_energy = attn_energy.masked_fill( 186 | key_padding_mask.unsqueeze(1).unsqueeze(2).bool(), 187 | float("-inf"), 188 | ) 189 | 190 | return attn_energy 191 | 192 | def expected_alignment_train(self, p_choose, key_padding_mask): 193 | """ 194 | Calculating expected alignment for MMA 195 | Mask is not need because p_choose will be 0 if masked 196 | q_ij = (1 − p_{ij−1})q_{ij−1} + a+{i−1j} 197 | a_ij = p_ij q_ij 198 | parellel solution: 199 | ai = p_i * cumprod(1 − pi) * cumsum(a_i / cumprod(1 − pi)) 200 | ============================================================ 201 | Expected input size 202 | p_choose: bsz * num_heads, tgt_len, src_len 203 | """ 204 | 205 | # p_choose: bsz * num_heads, tgt_len, src_len 206 | bsz_num_heads, tgt_len, src_len = p_choose.size() 207 | 208 | # cumprod_1mp : bsz * num_heads, tgt_len, src_len 209 | cumprod_1mp = exclusive_cumprod(1 - p_choose, dim=2, eps=self.eps) 210 | cumprod_1mp_clamp = torch.clamp(cumprod_1mp, self.eps, 1.0) 211 | 212 | init_attention = p_choose.new_zeros([bsz_num_heads, 1, src_len]) 213 | init_attention[:, 
:, 0] = 1.0 214 | 215 | previous_attn = [init_attention] 216 | 217 | for i in range(tgt_len): 218 | # p_choose: bsz * num_heads, tgt_len, src_len 219 | # cumprod_1mp_clamp : bsz * num_heads, tgt_len, src_len 220 | # previous_attn[i]: bsz * num_heads, 1, src_len 221 | # alpha_i: bsz * num_heads, src_len 222 | alpha_i = ( 223 | p_choose[:, i] 224 | * cumprod_1mp[:, i] 225 | * torch.cumsum(previous_attn[i][:, 0] / cumprod_1mp_clamp[:, i], dim=1) 226 | ).clamp(0, 1.0) 227 | previous_attn.append(alpha_i.unsqueeze(1)) 228 | 229 | # alpha: bsz * num_heads, tgt_len, src_len 230 | alpha = torch.cat(previous_attn[1:], dim=1) 231 | 232 | if self.mass_preservation: 233 | # Last token has the residual probabilities 234 | alpha[:, :, -1] = 1 - alpha[:, :, :-1].sum(dim=-1).clamp(0.0, 1.0) 235 | 236 | assert not torch.isnan(alpha).any(), "NaN detected in alpha." 237 | 238 | return alpha 239 | 240 | def expected_attention( 241 | self, alpha, query, key, value, key_padding_mask, incremental_state 242 | ): 243 | # monotonic attention, we will calculate milk here 244 | bsz_x_num_heads, tgt_len, src_len = alpha.size() 245 | bsz = int(bsz_x_num_heads / self.num_heads) 246 | 247 | q, k, _ = self.input_projections(query, key, None, "soft") 248 | soft_energy = self.attn_energy(q, k, key_padding_mask) 249 | 250 | assert list(soft_energy.size()) == [bsz, self.num_heads, tgt_len, src_len] 251 | 252 | soft_energy = soft_energy.view(bsz * self.num_heads, tgt_len, src_len) 253 | 254 | if incremental_state is not None: 255 | monotonic_cache = self._get_monotonic_buffer(incremental_state) 256 | monotonic_step = monotonic_cache["step"] + 1 257 | step_offset = 0 258 | if key_padding_mask is not None: 259 | if key_padding_mask[:, 0].any(): 260 | # left_pad_source = True: 261 | step_offset = key_padding_mask.sum(dim=-1, keepdim=True) 262 | monotonic_step += step_offset 263 | mask = lengths_to_mask( 264 | monotonic_step.view(-1), soft_energy.size(2), 1 265 | ).unsqueeze(1) 266 | 267 | soft_energy = soft_energy.masked_fill(~mask.bool(), float("-inf")) 268 | soft_energy = soft_energy - soft_energy.max(dim=2, keepdim=True)[0] 269 | exp_soft_energy = torch.exp(soft_energy) 270 | exp_soft_energy_sum = exp_soft_energy.sum(dim=2) 271 | beta = exp_soft_energy / exp_soft_energy_sum.unsqueeze(2) 272 | 273 | else: 274 | # bsz * num_heads, tgt_len, src_len 275 | soft_energy = soft_energy - soft_energy.max(dim=2, keepdim=True)[0] 276 | exp_soft_energy = torch.exp(soft_energy) 277 | exp_soft_energy_cumsum = torch.cumsum(exp_soft_energy, dim=2) 278 | 279 | if key_padding_mask is not None: 280 | if key_padding_mask.any(): 281 | exp_soft_energy_cumsum = ( 282 | exp_soft_energy_cumsum.view( 283 | -1, self.num_heads, tgt_len, src_len 284 | ) 285 | .masked_fill( 286 | key_padding_mask.unsqueeze(1).unsqueeze(1), self.eps 287 | ) 288 | .view(-1, tgt_len, src_len) 289 | ) 290 | 291 | inner_items = alpha / exp_soft_energy_cumsum 292 | 293 | beta = exp_soft_energy * torch.cumsum( 294 | inner_items.flip(dims=[2]), dim=2 295 | ).flip(dims=[2]) 296 | 297 | beta = self.dropout_module(beta) 298 | 299 | assert not torch.isnan(beta).any(), "NaN detected in beta." 
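        # `beta` is the expected soft attention of this monotonic-infinite-lookback head.
        # In the non-incremental (training) branch above it equals
        #   beta_j = exp(u_j) * sum_{k >= j} alpha_k / sum_{l <= k} exp(u_l),
        # i.e. the emission probabilities alpha are spread over all source positions up to
        # each candidate stopping point k, which the flip/cumsum/flip computes in parallel;
        # in the incremental branch it is a plain softmax over the positions allowed by the
        # current monotonic step.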
300 | 301 | return beta 302 | 303 | def forward( 304 | self, 305 | query, 306 | key, 307 | value, 308 | key_padding_mask=None, 309 | incremental_state=None, 310 | *args, 311 | **kwargs, 312 | ): 313 | 314 | tgt_len, bsz, embed_dim = query.size() 315 | src_len = value.size(0) 316 | 317 | # stepwise prob 318 | # p_choose: bsz * self.num_heads, tgt_len, src_len 319 | p_choose = self.p_choose(query, key, key_padding_mask) 320 | 321 | # expected alignment alpha 322 | # bsz * self.num_heads, tgt_len, src_len 323 | 324 | alpha = self.expected_alignment_train(p_choose, key_padding_mask) 325 | 326 | # expected attention beta 327 | # bsz * self.num_heads, tgt_len, src_len 328 | beta = self.expected_attention( 329 | alpha, query, key, value, key_padding_mask, incremental_state 330 | ) 331 | 332 | attn_weights = beta 333 | 334 | _, _, v_proj = self.input_projections(None, None, value, "output") 335 | attn = torch.bmm(attn_weights.type_as(v_proj), v_proj) 336 | 337 | attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) 338 | 339 | attn = self.out_proj(attn) 340 | 341 | beta = beta.view(bsz, self.num_heads, tgt_len, src_len) 342 | alpha = alpha.view(bsz, self.num_heads, tgt_len, src_len) 343 | p_choose = p_choose.view(bsz, self.num_heads, tgt_len, src_len) 344 | 345 | return attn, alpha, beta, p_choose 346 | 347 | 348 | def exclusive_cumprod(tensor, dim: int, eps: float = 1e-10): 349 | """ 350 | Implementing exclusive cumprod. 351 | There is cumprod in pytorch, however there is no exclusive mode. 352 | cumprod(x) = [x1, x1x2, x2x3x4, ..., prod_{i=1}^n x_i] 353 | exclusive means cumprod(x) = [1, x1, x1x2, x1x2x3, ..., prod_{i=1}^{n-1} x_i] 354 | """ 355 | tensor_size = list(tensor.size()) 356 | tensor_size[dim] = 1 357 | return_tensor = safe_cumprod( 358 | torch.cat([torch.ones(tensor_size).type_as(tensor), tensor], dim=dim), 359 | dim=dim, 360 | eps=eps, 361 | ) 362 | 363 | if dim == 0: 364 | return return_tensor[:-1] 365 | elif dim == 1: 366 | return return_tensor[:, :-1] 367 | elif dim == 2: 368 | return return_tensor[:, :, :-1] 369 | else: 370 | raise RuntimeError("Cumprod on dimension 3 and more is not implemented") 371 | 372 | 373 | def safe_cumprod(tensor, dim: int, eps: float = 1e-10): 374 | """ 375 | An implementation of cumprod to prevent precision issue. 376 | cumprod(x) 377 | = [x1, x1x2, x1x2x3, ....] 378 | = [exp(log(x1)), exp(log(x1) + log(x2)), exp(log(x1) + log(x2) + log(x3)), ...] 379 | = exp(cumsum(log(x))) 380 | """ 381 | 382 | if (tensor + eps < 0).any().item(): 383 | raise RuntimeError( 384 | "Safe cumprod can only take non-negative tensors as input." 385 | "Consider use torch.cumprod if you want to calculate negative values." 
386 | ) 387 | 388 | log_tensor = torch.log(tensor + eps) 389 | cumsum_log_tensor = torch.cumsum(log_tensor, dim) 390 | exp_cumsum_log_tensor = torch.exp(cumsum_log_tensor) 391 | return exp_cumsum_log_tensor 392 | 393 | 394 | def lengths_to_mask(lengths, max_len: int, dim: int = 0, negative_mask: bool = False): 395 | """ 396 | Convert a tensor of lengths to mask 397 | For example, lengths = [[2, 3, 4]], max_len = 5 398 | mask = 399 | [[1, 1, 1], 400 | [1, 1, 1], 401 | [0, 1, 1], 402 | [0, 0, 1], 403 | [0, 0, 0]] 404 | """ 405 | assert len(lengths.size()) <= 2 406 | if len(lengths) == 2: 407 | if dim == 1: 408 | lengths = lengths.t() 409 | lengths = lengths 410 | else: 411 | lengths = lengths.unsqueeze(1) 412 | 413 | # lengths : batch_size, 1 414 | lengths = lengths.view(-1, 1) 415 | 416 | batch_size = lengths.size(0) 417 | # batch_size, max_len 418 | mask = torch.arange(max_len).expand(batch_size, max_len).type_as(lengths) < lengths 419 | 420 | if negative_mask: 421 | mask = ~mask 422 | 423 | if dim == 0: 424 | # max_len, batch_size 425 | mask = mask.t() 426 | 427 | return mask 428 | 429 | 430 | def moving_sum(x, start_idx: int, end_idx: int): 431 | """ 432 | From MONOTONIC CHUNKWISE ATTENTION 433 | https://arxiv.org/pdf/1712.05382.pdf 434 | Equation (18) 435 | x = [x_1, x_2, ..., x_N] 436 | MovingSum(x, start_idx, end_idx)_n = Sigma_{m=n−(start_idx−1)}^{n+end_idx-1} x_m 437 | for n in {1, 2, 3, ..., N} 438 | x : src_len, batch_size 439 | start_idx : start idx 440 | end_idx : end idx 441 | Example 442 | src_len = 5 443 | batch_size = 3 444 | x = 445 | [[ 0, 5, 10], 446 | [ 1, 6, 11], 447 | [ 2, 7, 12], 448 | [ 3, 8, 13], 449 | [ 4, 9, 14]] 450 | MovingSum(x, 3, 1) = 451 | [[ 0, 5, 10], 452 | [ 1, 11, 21], 453 | [ 3, 18, 33], 454 | [ 6, 21, 36], 455 | [ 9, 24, 39]] 456 | MovingSum(x, 1, 3) = 457 | [[ 3, 18, 33], 458 | [ 6, 21, 36], 459 | [ 9, 24, 39], 460 | [ 7, 17, 27], 461 | [ 4, 9, 14]] 462 | """ 463 | assert start_idx > 0 and end_idx > 0 464 | assert len(x.size()) == 2 465 | src_len, batch_size = x.size() 466 | # batch_size, 1, src_len 467 | x = x.t().unsqueeze(1) 468 | # batch_size, 1, src_len 469 | moving_sum_weight = x.new_ones([1, 1, end_idx + start_idx - 1]) 470 | 471 | moving_sum = ( 472 | torch.nn.functional.conv1d( 473 | x, moving_sum_weight, padding=start_idx + end_idx - 1 474 | ) 475 | .squeeze(1) 476 | .t() 477 | ) 478 | moving_sum = moving_sum[end_idx:-start_idx] 479 | 480 | assert src_len == moving_sum.size(0) 481 | assert batch_size == moving_sum.size(1) 482 | 483 | return moving_sum 484 | -------------------------------------------------------------------------------- /src/model/monotonic_attention00.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import random 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class MonotonicEnergy(nn.Module): 10 | def __init__( 11 | self, 12 | kdim, 13 | qdim, 14 | adim, 15 | atype, 16 | n_heads, 17 | init_r, 18 | bias=True, 19 | param_init="", 20 | conv1d=False, 21 | conv_kernel_size=5, 22 | ): 23 | """Energy function for the monotonic attenion. 
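        The learned scalar offset r (initialised to init_r) is added to every energy in
        forward(); a negative init_r keeps the sigmoid'd selection probabilities small at
        the start of training, which is the motivation given in the monotonic-attention /
        MoChA papers.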
24 | Args: 25 | kdim (int): dimension of key 26 | qdim (int): dimension of quary 27 | adim (int): dimension of attention space 28 | atype (str): type of attention mechanism 29 | n_heads (int): number of monotonic attention heads 30 | init_r (int): initial value for offset r 31 | bias (bool): use bias term in linear layers 32 | param_init (str): parameter initialization method 33 | conv1d (bool): use 1D causal convolution for energy calculation 34 | conv_kernel_size (int): kernel size for 1D convolution 35 | """ 36 | super().__init__() 37 | 38 | assert conv_kernel_size % 2 == 1, "Kernel size should be odd for 'same' conv." 39 | self.key = None 40 | self.mask = None 41 | 42 | self.atype = atype 43 | assert adim % n_heads == 0 44 | self.d_k = adim // n_heads 45 | self.n_heads = n_heads 46 | self.scale = math.sqrt(adim) 47 | 48 | if atype == "add": 49 | self.w_key = nn.Linear(kdim, adim) 50 | self.v = nn.Linear(adim, n_heads, bias=False) 51 | self.w_query = nn.Linear(qdim, adim, bias=False) 52 | elif atype == "scaled_dot": 53 | self.w_key = nn.Linear(kdim, adim, bias=bias) 54 | self.w_query = nn.Linear(qdim, adim, bias=bias) 55 | else: 56 | raise NotImplementedError(atype) 57 | 58 | self.r = nn.Parameter(torch.Tensor([init_r])) 59 | logger.info("init_r is initialized with %d" % init_r) 60 | 61 | self.conv1d = None 62 | if conv1d: 63 | self.conv1d = CausalConv1d( 64 | in_channels=kdim, 65 | out_channels=kdim, 66 | kernel_size=conv_kernel_size, 67 | param_init=param_init, 68 | ) 69 | # padding=(conv_kernel_size - 1) // 2 70 | 71 | if atype == "add": 72 | self.v = nn.utils.weight_norm(self.v, name="weight", dim=0) 73 | # initialization 74 | self.v.weight_g.data = torch.Tensor([1 / adim]).sqrt() 75 | elif atype == "scaled_dot": 76 | if param_init == "xavier_uniform": 77 | self.reset_parameters(bias) 78 | 79 | def reset_parameters(self, bias): 80 | """Initialize parameters with Xavier uniform distribution.""" 81 | logger.info( 82 | "===== Initialize %s with Xavier uniform distribution =====" 83 | % self.__class__.__name__ 84 | ) 85 | # NOTE: see https://github.com/pytorch/fairseq/blob/master/fairseq/modules/multihead_attention.py 86 | nn.init.xavier_uniform_(self.w_key.weight, gain=1 / math.sqrt(2)) 87 | nn.init.xavier_uniform_(self.w_query.weight, gain=1 / math.sqrt(2)) 88 | if bias: 89 | nn.init.constant_(self.w_key.bias, 0.0) 90 | nn.init.constant_(self.w_query.bias, 0.0) 91 | 92 | def reset(self): 93 | self.key = None 94 | self.mask = None 95 | 96 | def forward(self, key, query, mask, cache=False, boundary_leftmost=0): 97 | """Compute monotonic energy. 
98 | Args: 99 | key (FloatTensor): `[B, klen, kdim]` 100 | query (FloatTensor): `[B, qlen, qdim]` 101 | mask (ByteTensor): `[B, qlen, klen]` 102 | cache (bool): cache key and mask 103 | Returns: 104 | e (FloatTensor): `[B, H_ma, qlen, klen]` 105 | """ 106 | bs, klen, kdim = key.size() 107 | qlen = query.size(1) 108 | 109 | # Pre-computation of encoder-side features for computing scores 110 | if self.key is None or not cache: 111 | # 1d conv 112 | if self.conv1d is not None: 113 | key = torch.relu(self.conv1d(key)) 114 | key = self.w_key(key).view(bs, -1, self.n_heads, self.d_k) 115 | self.key = key.transpose(2, 1).contiguous() # `[B, H_ma, klen, d_k]` 116 | self.mask = mask 117 | if mask is not None: 118 | self.mask = self.mask.unsqueeze(1).repeat( 119 | [1, self.n_heads, 1, 1] 120 | ) # `[B, H_ma, qlen, klen]` 121 | assert self.mask.size() == (bs, self.n_heads, qlen, klen), ( 122 | self.mask.size(), 123 | (bs, self.n_heads, qlen, klen), 124 | ) 125 | 126 | query = self.w_query(query).view(bs, -1, self.n_heads, self.d_k) 127 | query = query.transpose(2, 1).contiguous() # `[B, H_ma, qlen, d_k]` 128 | m = self.mask 129 | 130 | if self.atype == "add": 131 | k = self.key.unsqueeze(2) # `[B, H_ma, 1, klen, d_k]` 132 | # Truncate encoder memories 133 | if boundary_leftmost > 0: 134 | k = k[:, :, :, boundary_leftmost:] 135 | klen = k.size(3) 136 | if m is not None: 137 | m = m[:, :, :, boundary_leftmost:] 138 | e = torch.relu(k + query.unsqueeze(3)) # `[B, H_ma, qlen, klen, d_k]` 139 | e = e.permute(0, 2, 3, 1, 4).contiguous().view(bs, qlen, klen, -1) 140 | e = self.v(e).permute(0, 3, 1, 2) # `[B, qlen, klen, H_ma]` 141 | elif self.atype == "scaled_dot": 142 | k = self.key.transpose(3, 2) 143 | e = torch.matmul(query, k) / self.scale 144 | 145 | if self.r is not None: 146 | e = e + self.r 147 | if m is not None: 148 | NEG_INF = float(np.finfo(torch.tensor(0, dtype=e.dtype).numpy().dtype).min) 149 | e = e.masked_fill_(m == 0, NEG_INF) 150 | assert e.size() == (bs, self.n_heads, qlen, klen), ( 151 | e.size(), 152 | (bs, self.n_heads, qlen, klen), 153 | ) 154 | return e 155 | 156 | 157 | class ChunkEnergy(nn.Module): 158 | def __init__(self, kdim, qdim, adim, atype, n_heads=1, bias=True, param_init=""): 159 | """Energy function for the chunkwise attention. 
160 | Args: 161 | kdim (int): dimension of key 162 | qdim (int): dimension of quary 163 | adim (int): dimension of attention space 164 | atype (str): type of attention mechanism 165 | n_heads (int): number of chunkwise attention heads 166 | bias (bool): use bias term in linear layers 167 | param_init (str): parameter initialization method 168 | """ 169 | super().__init__() 170 | 171 | self.key = None 172 | self.mask = None 173 | 174 | self.atype = atype 175 | assert adim % n_heads == 0 176 | self.d_k = adim // n_heads 177 | self.n_heads = n_heads 178 | self.scale = math.sqrt(adim) 179 | 180 | if atype == "add": 181 | self.w_key = nn.Linear(kdim, adim) 182 | self.w_query = nn.Linear(qdim, adim, bias=False) 183 | self.v = nn.Linear(adim, n_heads, bias=False) 184 | elif atype == "scaled_dot": 185 | self.w_key = nn.Linear(kdim, adim, bias=bias) 186 | self.w_query = nn.Linear(qdim, adim, bias=bias) 187 | if param_init == "xavier_uniform": 188 | self.reset_parameters(bias) 189 | else: 190 | raise NotImplementedError(atype) 191 | 192 | def reset_parameters(self, bias): 193 | """Initialize parameters with Xavier uniform distribution.""" 194 | logger.info( 195 | "===== Initialize %s with Xavier uniform distribution =====" 196 | % self.__class__.__name__ 197 | ) 198 | # NOTE: see https://github.com/pytorch/fairseq/blob/master/fairseq/modules/multihead_attention.py 199 | nn.init.xavier_uniform_(self.w_key.weight, gain=1 / math.sqrt(2)) 200 | nn.init.xavier_uniform_(self.w_query.weight, gain=1 / math.sqrt(2)) 201 | if bias: 202 | nn.init.constant_(self.w_key.bias, 0.0) 203 | nn.init.constant_(self.w_query.bias, 0.0) 204 | 205 | def reset(self): 206 | self.key = None 207 | self.mask = None 208 | 209 | def forward( 210 | self, 211 | key, 212 | query, 213 | mask, 214 | cache=False, 215 | boundary_leftmost=0, 216 | boundary_rightmost=10e6, 217 | ): 218 | """Compute chunkwise energy. 
219 | Args: 220 | key (FloatTensor): `[B, klen, kdim]` 221 | query (FloatTensor): `[B, qlen, qdim]` 222 | mask (ByteTensor): `[B, qlen, klen]` 223 | cache (bool): cache key and mask 224 | Returns: 225 | e (FloatTensor): `[B, H_ca, qlen, klen]` 226 | """ 227 | bs, klen, kdim = key.size() 228 | qlen = query.size(1) 229 | 230 | # Pre-computation of encoder-side features for computing scores 231 | if self.key is None or not cache: 232 | key = self.w_key(key).view(bs, -1, self.n_heads, self.d_k) 233 | self.key = key.transpose(2, 1).contiguous() # `[B, H_ca, klen, d_k]` 234 | self.mask = mask 235 | if mask is not None: 236 | self.mask = self.mask.unsqueeze(1).repeat( 237 | [1, self.n_heads, 1, 1] 238 | ) # `[B, H_ca, qlen, klen]` 239 | assert self.mask.size() == (bs, self.n_heads, qlen, klen), ( 240 | self.mask.size(), 241 | (bs, self.n_heads, qlen, klen), 242 | ) 243 | 244 | query = self.w_query(query).view(bs, -1, self.n_heads, self.d_k) 245 | query = query.transpose(2, 1).contiguous() # `[B, H_ca, qlen, d_k]` 246 | m = self.mask 247 | 248 | if self.atype == "add": 249 | k = self.key.unsqueeze(2) # `[B, H_ca, 1, klen, d_k]` 250 | # Truncate 251 | k = k[:, :, :, boundary_leftmost:boundary_rightmost] 252 | klen = k.size(3) 253 | if m is not None: 254 | m = m[:, :, :, boundary_leftmost:boundary_rightmost] 255 | 256 | r = torch.relu(k + query.unsqueeze(3)) # `[B, H_ca, qlen, klen, d_k]` 257 | r = ( 258 | r.permute(0, 2, 3, 1, 4).contiguous().view(bs, qlen, klen, -1) 259 | ) # `[B, qlen, klen, H_ca * d_k]` 260 | r = self.v(r).permute(0, 3, 1, 2).contiguous() # `[B, H_ca, qlen, klen]` 261 | elif self.atype == "scaled_dot": 262 | k = self.key.transpose(3, 2) 263 | r = torch.matmul(query, k) / self.scale 264 | 265 | if m is not None: 266 | NEG_INF = float(np.finfo(torch.tensor(0, dtype=r.dtype).numpy().dtype).min) 267 | r = r.masked_fill_(m == 0, NEG_INF) 268 | assert r.size() == (bs, self.n_heads, qlen, klen), ( 269 | r.size(), 270 | (bs, self.n_heads, qlen, klen), 271 | ) 272 | return r 273 | 274 | 275 | class MonotonicAttention(nn.Module): 276 | def __init__(self, embed_dim, n_heads, dropout=0.1): 277 | super(MonotonicAttention, self).__init__() 278 | 279 | self.n_heads = n_heads 280 | 281 | self.monotonic_energy = MonotonicEnergy( 282 | kdim=embed_dim, 283 | qdim=embed_dim, 284 | adim=embed_dim, 285 | atype="scaled_dot", 286 | n_heads=n_heads, 287 | init_r=0, 288 | param_init="xavier_uniform", 289 | ) 290 | 291 | self.chunk_energy = ChunkEnergy( 292 | kdim=embed_dim, 293 | qdim=embed_dim, 294 | adim=embed_dim, 295 | atype="scaled_dot", 296 | n_heads=n_heads, 297 | param_init="xavier_uniform", 298 | ) 299 | self.dropout_attn = nn.Dropout(p=dropout) # for beta 300 | self.bd_offset = 0 301 | 302 | def forward(self, key, value, query, aw_prev=None, mask=None, cache=None): 303 | 304 | bs, klen = key.size()[:2] 305 | qlen = query.size(1) 306 | 307 | if aw_prev is None: 308 | # aw_prev = [1, 0, 0 ... 
0] 309 | aw_prev = key.new_zeros(bs, self.n_heads_ma, 1, klen) 310 | aw_prev[:, :, :, 0:1] = key.new_ones(bs, self.n_heads_ma, 1, 1) 311 | 312 | # Compute monotonic energy 313 | e_ma = self.monotonic_energy( 314 | key, query, mask, cache=cache, boundary_leftmost=self.bd_offset 315 | ) # `[B, H_ma, qlen, klen]` 316 | 317 | p_choose = torch.sigmoid( 318 | add_gaussian_noise(e_ma, self.noise_std) 319 | ) # `[B, H_ma, qlen, klen]` 320 | 321 | # safe_cumprod computes cumprod in logspace with numeric checks 322 | cumprod_1mp_choose = safe_cumprod( 323 | 1 - p_choose, eps=self.eps 324 | ) # `[B, H_ma, qlen, klen]` 325 | 326 | # Compute recurrence relation solution 327 | alpha = [] 328 | for i in range(qlen): 329 | denom = ( 330 | 1 331 | if self.no_denom 332 | else torch.clamp( 333 | cumprod_1mp_choose[:, :, i : i + 1], min=self.eps, max=1.0 334 | ) 335 | ) 336 | aw_prev = ( 337 | p_choose[:, :, i : i + 1] 338 | * cumprod_1mp_choose[:, :, i : i + 1] 339 | * torch.cumsum(aw_prev / denom, dim=-1) 340 | ) # `[B, H_ma, 1, klen]` 341 | # Mask the right part from the trigger point 342 | alpha.append(aw_prev) 343 | 344 | alpha = ( 345 | torch.cat(alpha, dim=2) if qlen > 1 else alpha[-1] 346 | ) # `[B, H_ma, qlen, klen]` 347 | alpha_masked = alpha.clone() 348 | 349 | # Compute chunk energy 350 | e_ca = self.chunk_energy( 351 | key, 352 | query, 353 | mask, 354 | ) # `[B, (H_ma*)H_ca, qlen, ken]` 355 | 356 | beta = efficient_chunkwise_attention( 357 | alpha_masked, 358 | e_ca, 359 | mask, 360 | chunk_size=-1, 361 | n_heads_chunk=self.n_heads, 362 | sharpening_factor=1.0, 363 | share_chunkwise_attention=True, 364 | ) 365 | beta = self.dropout_attn(beta) 366 | 367 | value = self.w_value(value).view( 368 | bs, -1, self.n_heads_ma * self.n_heads_ca, self.d_k 369 | ) 370 | value = value.transpose(2, 1).contiguous() # `[B, H_ma * H_ca, klen, d_k]` 371 | cv = torch.matmul( 372 | alpha if self.w == 1 else beta, value 373 | ) # `[B, H_ma * H_ca, qlen, d_k]` 374 | cv = ( 375 | cv.transpose(2, 1) 376 | .contiguous() 377 | .view(bs, -1, self.n_heads_ma * self.n_heads_ca * self.d_k) 378 | ) 379 | cv = self.w_out(cv) # `[B, qlen, adim]` 380 | 381 | 382 | def add_gaussian_noise(xs, std): 383 | """Add Gaussian nosie to encourage discreteness.""" 384 | noise = xs.new_zeros(xs.size()).normal_(std=std) 385 | return xs + noise 386 | 387 | 388 | def safe_cumprod(x, eps): 389 | """Numerically stable cumulative product by cumulative sum in log-space. 390 | Args: 391 | x (FloatTensor): `[B, H, qlen, klen]` 392 | Returns: 393 | x (FloatTensor): `[B, H, qlen, klen]` 394 | """ 395 | return torch.exp(exclusive_cumsum(torch.log(torch.clamp(x, min=eps, max=1.0)))) 396 | 397 | 398 | def exclusive_cumsum(x): 399 | """Exclusive cumulative summation [a, b, c] => [0, a, a + b]. 400 | Args: 401 | x (FloatTensor): `[B, H, qlen, klen]` 402 | Returns: 403 | x (FloatTensor): `[B, H, qlen, klen]` 404 | """ 405 | return torch.cumsum( 406 | torch.cat( 407 | [x.new_zeros(x.size(0), x.size(1), x.size(2), 1), x[:, :, :, :-1]], dim=-1 408 | ), 409 | dim=-1, 410 | ) 411 | 412 | 413 | def exclusive_cumprod(x): 414 | """Exclusive cumulative product [a, b, c] => [1, a, a * b]. 
415 | Args: 416 | x (FloatTensor): `[B, H, qlen, klen]` 417 | Returns: 418 | x (FloatTensor): `[B, H, qlen, klen]` 419 | """ 420 | return torch.cumprod( 421 | torch.cat( 422 | [x.new_ones(x.size(0), x.size(1), x.size(2), 1), x[:, :, :, :-1]], dim=-1 423 | ), 424 | dim=-1, 425 | ) 426 | 427 | 428 | def efficient_chunkwise_attention( 429 | alpha, 430 | u, 431 | mask, 432 | chunk_size, 433 | n_heads_chunk, 434 | sharpening_factor, 435 | share_chunkwise_attention, 436 | ): 437 | """Compute chunkwise attention efficiently by clipping logits at training time. 438 | Args: 439 | alpha (FloatTensor): `[B, H_ma, qlen, klen]` 440 | u (FloatTensor): `[B, (H_ma*)H_ca, qlen, klen]` 441 | mask (ByteTensor): `[B, qlen, klen]` 442 | chunk_size (int): window size for chunkwise attention 443 | n_heads_chunk (int): number of chunkwise attention heads 444 | sharpening_factor (float): sharping factor for beta calculation 445 | share_chunkwise_attention (int): share CA heads among MA heads 446 | Returns: 447 | beta (FloatTensor): `[B, H_ma * H_ca, qlen, klen]` 448 | """ 449 | bs, n_heads_mono, qlen, klen = alpha.size() 450 | alpha = alpha.unsqueeze(2) # `[B, H_ma, 1, qlen, klen]` 451 | u = u.unsqueeze(1) # `[B, 1, (H_ma*)H_ca, qlen, klen]` 452 | if n_heads_chunk > 1: 453 | alpha = alpha.repeat([1, 1, n_heads_chunk, 1, 1]) 454 | if n_heads_mono > 1 and not share_chunkwise_attention: 455 | u = u.view(bs, n_heads_mono, n_heads_chunk, qlen, klen) 456 | # Shift logits to avoid overflow 457 | u -= torch.max(u, dim=-1, keepdim=True)[0] 458 | # Limit the range for numerical stability 459 | softmax_exp = torch.clamp(torch.exp(u), min=1e-5) 460 | # Compute chunkwise softmax denominators 461 | if chunk_size == -1: 462 | # infinite lookback attention 463 | # inner_items = alpha * sharpening_factor / torch.cumsum(softmax_exp, dim=-1) 464 | # beta = softmax_exp * torch.cumsum(inner_items.flip(dims=[-1]), dim=-1).flip(dims=[-1]) 465 | # beta = beta.masked_fill(mask.unsqueeze(1), 0) 466 | # beta = beta / beta.sum(dim=-1, keepdim=True) 467 | 468 | softmax_denominators = torch.cumsum(softmax_exp, dim=-1) 469 | # Compute \beta_{i, :}. emit_probs are \alpha_{i, :}. 470 | beta = softmax_exp * moving_sum( 471 | alpha * sharpening_factor / softmax_denominators, back=0, forward=klen - 1 472 | ) 473 | else: 474 | softmax_denominators = moving_sum(softmax_exp, back=chunk_size - 1, forward=0) 475 | # Compute \beta_{i, :}. emit_probs are \alpha_{i, :}. 476 | beta = softmax_exp * moving_sum( 477 | alpha * sharpening_factor / softmax_denominators, 478 | back=0, 479 | forward=chunk_size - 1, 480 | ) 481 | return beta.view(bs, -1, qlen, klen) 482 | 483 | 484 | def moving_sum(x, back, forward): 485 | """Compute the moving sum of x over a chunk_size with the provided bounds. 
486 |     Args:
487 |         x (FloatTensor): `[B, H_ma, H_ca, qlen, klen]`
488 |         back (int): number of earlier positions included in the sum
489 |         forward (int): number of later positions included in the sum
490 |     Returns:
491 |         x_sum (FloatTensor): `[B, H_ma, H_ca, qlen, klen]`
492 |     """
493 |     bs, n_heads_mono, n_heads_chunk, qlen, klen = x.size()
494 |     x = x.view(-1, klen)
495 |     # Moving sum is computed as a carefully-padded 1D convolution with ones
496 |     x_padded = F.pad(
497 |         x, pad=[back, forward]
498 |     )  # `[B * H_ma * H_ca * qlen, back + klen + forward]`
499 |     # Add a "channel" dimension
500 |     x_padded = x_padded.unsqueeze(1)
501 |     # Construct filters
502 |     filters = x.new_ones(1, 1, back + forward + 1)
503 |     x_sum = F.conv1d(x_padded, filters)
504 |     x_sum = x_sum.squeeze(1).view(bs, n_heads_mono, n_heads_chunk, qlen, -1)
505 |     return x_sum
506 | 
--------------------------------------------------------------------------------
/src/model/positional_encoding.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | 
6 | 
7 | class PositionalEncoding_2(nn.Module):
8 |     r"""Inject some information about the relative or absolute position of the tokens
9 |         in the sequence. The positional encodings have the same dimension as
10 |         the embeddings, so that the two can be summed. Here, we use sine and cosine
11 |         functions of different frequencies.
12 |     .. math::
13 |         \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model))
14 |         \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model))
15 |         \text{where pos is the word position and i is the embed idx}
16 |     Args:
17 |         d_model: the embed dim (required).
18 |         dropout: the dropout value (default=0.1).
19 |         max_len: the max. length of the incoming sequence (default=5000).
20 |     """
21 | 
22 |     def __init__(self, d_model, dropout=0.1, max_len=5000):
23 |         super(PositionalEncoding_2, self).__init__()
24 |         self.dropout = nn.Dropout(p=dropout)
25 | 
26 |         pe = torch.zeros(max_len, d_model)
27 |         position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
28 |         div_term = torch.exp(
29 |             torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
30 |         )
31 |         pe[:, 0::2] = torch.sin(position * div_term)
32 |         pe[:, 1::2] = torch.cos(position * div_term)
33 |         pe = pe.unsqueeze(0).transpose(0, 1)
34 |         self.register_buffer("pe", pe)
35 | 
36 |     def forward(self, x):
37 |         r"""Inputs of forward function
38 |         Args:
39 |             x: the sequence fed to the positional encoder model (required).
40 |         Shape:
41 |             x: [sequence length, batch size, embed dim]
42 |             output: [sequence length, batch size, embed dim]
43 |         Examples:
44 |         """
45 | 
46 |         x = x + self.pe[: x.size(0), :]
47 |         return self.dropout(x)
48 | 
49 | 
50 | class PositionalEncoding(nn.Module):
51 |     r"""Inject some information about the relative or absolute position of the tokens
52 |         in the sequence. The positional encodings have the same dimension as
53 |         the embeddings, so that the two can be summed. Here, we use sine and cosine
54 |         functions of different frequencies.
55 |     .. math::
56 |         \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model))
57 |         \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model))
58 |         \text{where pos is the word position and i is the embed idx}
59 |     Args:
60 |         d_model: the embed dim (required).
61 |         dropout: the dropout value (default=0.1).
62 |         max_len: the max. length of the incoming sequence (default=5000).
63 | """ 64 | 65 | def __init__(self, d_model, dropout=0.1, max_len=5000): 66 | super(PositionalEncoding, self).__init__() 67 | self.dropout = nn.Dropout(p=dropout) 68 | 69 | pe = torch.zeros(max_len, d_model) 70 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 71 | div_term = torch.exp( 72 | torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model) 73 | ) 74 | pe[:, 0::2] = torch.sin(position * div_term) 75 | pe[:, 1::2] = torch.sin(position * div_term) 76 | pe = pe.unsqueeze(0).transpose(0, 1) 77 | self.register_buffer("pe", pe) 78 | 79 | def forward(self, x): 80 | r"""Inputs of forward function 81 | Args: 82 | x: the sequence fed to the positional encoder model (required). 83 | Shape: 84 | x: [sequence length, batch size, embed dim] 85 | output: [sequence length, batch size, embed dim] 86 | Examples: 87 | """ 88 | 89 | x = x + self.pe[: x.size(0), :] 90 | return self.dropout(x) 91 | 92 | 93 | class PositionalEncoding_00(nn.Module): 94 | r"""Inject some information about the relative or absolute position of the tokens 95 | in the sequence. The positional encodings have the same dimension as 96 | the embeddings, so that the two can be summed. Here, we use sine and cosine 97 | functions of different frequencies. 98 | .. math:: 99 | \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model)) 100 | \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model)) 101 | \text{where pos is the word position and i is the embed idx) 102 | Args: 103 | d_model: the embed dim (required). 104 | dropout: the dropout value (default=0.1). 105 | max_len: the max. length of the incoming sequence (default=5000). 106 | """ 107 | 108 | def __init__(self, d_model, max_len=5000): 109 | super(PositionalEncoding_00, self).__init__() 110 | 111 | pe = torch.zeros(max_len, d_model) 112 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 113 | div_term = torch.exp( 114 | torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model) 115 | ) 116 | pe[:, 0::2] = torch.sin(position * div_term) 117 | pe[:, 1::2] = torch.cos(position * div_term) 118 | pe = pe.unsqueeze(0).transpose(0, 1) 119 | self.register_buffer("pe", pe) 120 | 121 | def forward(self, stop, step): 122 | r"""Inputs of forward function 123 | Args: 124 | x: the sequence fed to the positional encoder model (required). 125 | Shape: 126 | x: [sequence length, batch size, embed dim] 127 | output: [sequence length, batch size, embed dim] 128 | Examples: 129 | """ 130 | 131 | return self.pe[0:stop:step, :] 132 | 133 | 134 | class PositionEmbeddingSine(nn.Module): 135 | """ 136 | This is a more standard version of the position embedding, very similar to the one 137 | used by the Attention is all you need paper, generalized to work on images. 
138 | """ 139 | 140 | def __init__( 141 | self, num_pos_feats=64, temperature=10000, normalize=False, scale=None 142 | ): 143 | super().__init__() 144 | self.num_pos_feats = num_pos_feats 145 | self.temperature = temperature 146 | self.normalize = normalize 147 | if scale is not None and normalize is False: 148 | raise ValueError("normalize should be True if scale is passed") 149 | if scale is None: 150 | scale = 2 * math.pi 151 | self.scale = scale 152 | 153 | def forward(self, x): 154 | """ 155 | x: [B, C, h, w] 156 | """ 157 | B, C, h, w = x.shape 158 | ones = torch.ones_like(x[:, 0]) 159 | y_embed = ones.cumsum(dim=1) 160 | x_embed = ones.cumsum(dim=2) 161 | 162 | if self.normalize: 163 | eps = 1e-6 164 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 165 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 166 | 167 | i = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 168 | dim_t = self.temperature ** (2 * (i // 2) / self.num_pos_feats) 169 | 170 | pos_x = x_embed[:, :, :, None] / dim_t 171 | pos_y = y_embed[:, :, :, None] / dim_t 172 | pos_x = torch.stack( 173 | [pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()], dim=4 174 | ).flatten(3) 175 | pos_y = torch.stack( 176 | (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 177 | ).flatten(3) 178 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 179 | return pos 180 | 181 | 182 | class PositionalEncoding2DwDropout(nn.Module): 183 | def __init__(self, d_model, dropout=0.1, max_len=1000): 184 | super().__init__() 185 | 186 | self.dropout = nn.Dropout(p=dropout) 187 | self.d_model = d_model 188 | 189 | def forward(self, x, height, width): 190 | pe = torch.zeros(self.d_model, height, width, device=x.device) 191 | 192 | # Each dimension use half of d_model 193 | d_model = int(self.d_model / 2) 194 | div_term = torch.exp( 195 | torch.arange(0.0, d_model, 2) * -(math.log(10000.0) / d_model) 196 | ).cuda() 197 | pos_w = torch.arange(0.0, width // 2, device=x.device) 198 | pos_w = torch.cat([pos_w, pos_w.flip(dims=[0])]).unsqueeze(1) 199 | pos_h = torch.arange(0.0, height, device=x.device).unsqueeze(1) 200 | pe[0:d_model:2, :, :] = ( 201 | torch.sin(pos_w * div_term) 202 | .transpose(0, 1) 203 | .unsqueeze(1) 204 | .repeat(1, height, 1) 205 | ) 206 | pe[1:d_model:2, :, :] = ( 207 | torch.cos(pos_w * div_term) 208 | .transpose(0, 1) 209 | .unsqueeze(1) 210 | .repeat(1, height, 1) 211 | ) 212 | 213 | pe[d_model::2, :, :] = ( 214 | torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width) 215 | ) 216 | pe[d_model + 1 :: 2, :, :] = ( 217 | torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width) 218 | ) 219 | 220 | x = x + pe[None, :, :, :] 221 | return self.dropout(x) 222 | 223 | 224 | class PositionalEncoding3D(nn.Module): 225 | def __init__(self, channels): 226 | """ 227 | :param channels: The last dimension of the tensor you want to apply pos emb to. 
228 | """ 229 | super(PositionalEncoding3D, self).__init__() 230 | channels = int(np.ceil(channels / 3)) 231 | if channels % 2: 232 | channels += 1 233 | self.channels = channels 234 | inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels)) 235 | self.register_buffer("inv_freq", inv_freq) 236 | 237 | def forward(self, tensor): 238 | """ 239 | :param tensor: A 5d tensor of size (batch_size, x, y, z, ch) 240 | :return: Positional Encoding Matrix of size (batch_size, x, y, z, ch) 241 | """ 242 | if len(tensor.shape) != 5: 243 | raise RuntimeError("The input tensor has to be 5d!") 244 | 245 | _, x, y, z, orig_ch = tensor.shape 246 | 247 | pos_x = torch.arange(x, device=tensor.device).type(self.inv_freq.type()) 248 | pos_y = torch.arange(y, device=tensor.device).type(self.inv_freq.type()) 249 | pos_z = torch.arange(z, device=tensor.device).type(self.inv_freq.type()) 250 | sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq) 251 | sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq) 252 | sin_inp_z = torch.einsum("i,j->ij", pos_z, self.inv_freq) 253 | emb_x = ( 254 | torch.cat((sin_inp_x.sin(), sin_inp_x.cos()), dim=-1) 255 | .unsqueeze(1) 256 | .unsqueeze(1) 257 | ) 258 | emb_y = torch.cat((sin_inp_y.sin(), sin_inp_y.cos()), dim=-1).unsqueeze(1) 259 | emb_z = torch.cat((sin_inp_z.sin(), sin_inp_z.cos()), dim=-1) 260 | emb = torch.zeros((x, y, z, self.channels * 3), device=tensor.device).type( 261 | tensor.type() 262 | ) 263 | emb[:, :, :, : self.channels] = emb_x 264 | emb[:, :, :, self.channels : 2 * self.channels] = emb_y 265 | emb[:, :, :, 2 * self.channels :] = emb_z 266 | 267 | return emb[None, :, :, :, :orig_ch] 268 | --------------------------------------------------------------------------------