├── .idea
│   ├── inspectionProfiles
│   │   └── Project_Default.xml
│   ├── modules.xml
│   └── translating-images-into-maps.iml
├── LICENSE.md
├── README.md
├── images
│   ├── ._image_to_bev_motivation.gif
│   └── image_to_bev_motivation.gif
├── src
│   ├── __init__.py
│   ├── data
│   │   ├── augmentations.py
│   │   ├── collate_funcs.py
│   │   ├── data_generation.py
│   │   ├── dataloader.py
│   │   ├── dataloader_new.py
│   │   ├── nuscenes
│   │   │   ├── __init__.py
│   │   │   ├── dataset.py
│   │   │   ├── splits.py
│   │   │   └── utils.py
│   │   └── utils.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── axial_attention.py
│   │   ├── backbone_utils.py
│   │   ├── bev_transform.py
│   │   ├── dla.py
│   │   ├── dla_up.py
│   │   ├── fpn.py
│   │   ├── general_modules.py
│   │   ├── intermediate_layer_getter.py
│   │   ├── loss.py
│   │   ├── mocha.py
│   │   ├── mocha2.py
│   │   ├── monotonic_attention.py
│   │   ├── monotonic_attention00.py
│   │   ├── monotonic_attention_fairseq.py
│   │   ├── network.py
│   │   ├── positional_encoding.py
│   │   ├── resnet.py
│   │   └── transformer.py
│   └── utils.py
└── train.py
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/translating-images-into-maps.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
2 |
3 | Creative Commons Corporation (“Creative Commons”) is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an “as-is” basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible.
4 |
5 | ### Using Creative Commons Public Licenses
6 |
7 | Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses.
8 |
9 | * __Considerations for licensors:__ Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. [More considerations for licensors](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensors).
10 |
11 | * __Considerations for the public:__ By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor’s permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. [More considerations for the public](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensees).
12 |
13 | ## Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License
14 |
15 | By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions.
16 |
17 | ### Section 1 – Definitions.
18 |
19 | a. __Adapted Material__ means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image.
20 |
21 | b. __Adapter's License__ means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License.
22 |
23 | c. __BY-NC-SA Compatible License__ means a license listed at [creativecommons.org/compatiblelicenses](http://creativecommons.org/compatiblelicenses), approved by Creative Commons as essentially the equivalent of this Public License.
24 |
25 | d. __Copyright and Similar Rights__ means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights.
26 |
27 | e. __Effective Technological Measures__ means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements.
28 |
29 | f. __Exceptions and Limitations__ means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material.
30 |
31 | g. __License Elements__ means the license attributes listed in the name of a Creative Commons Public License. The License Elements of this Public License are Attribution, NonCommercial, and ShareAlike.
32 |
33 | h. __Licensed Material__ means the artistic or literary work, database, or other material to which the Licensor applied this Public License.
34 |
35 | i. __Licensed Rights__ means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license.
36 |
37 | j. __Licensor__ means the individual(s) or entity(ies) granting rights under this Public License.
38 | 
39 | k. __NonCommercial__ means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange.
40 | 
41 | l. __Share__ means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them.
42 | 
43 | m. __Sui Generis Database Rights__ means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world.
44 | 
45 | n. __You__ means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning.
46 |
47 | ### Section 2 – Scope.
48 |
49 | a. ___License grant.___
50 |
51 | 1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to:
52 |
53 | A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and
54 |
55 | B. produce, reproduce, and Share Adapted Material for NonCommercial purposes only.
56 |
57 | 2. __Exceptions and Limitations.__ For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions.
58 |
59 | 3. __Term.__ The term of this Public License is specified in Section 6(a).
60 |
61 | 4. __Media and formats; technical modifications allowed.__ The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material.
62 |
63 | 5. __Downstream recipients.__
64 |
65 | A. __Offer from the Licensor – Licensed Material.__ Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License.
66 |
67 | B. __Additional offer from the Licensor – Adapted Material.__ Every recipient of Adapted Material from You automatically receives an offer from the Licensor to exercise the Licensed Rights in the Adapted Material under the conditions of the Adapter’s License You apply.
68 |
69 | C. __No downstream restrictions.__ You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material.
70 |
71 | 6. __No endorsement.__ Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i).
72 |
73 | b. ___Other rights.___
74 |
75 | 1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise.
76 |
77 | 2. Patent and trademark rights are not licensed under this Public License.
78 |
79 | 3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes.
80 |
81 | ### Section 3 – License Conditions.
82 |
83 | Your exercise of the Licensed Rights is expressly made subject to the following conditions.
84 |
85 | a. ___Attribution.___
86 |
87 | 1. If You Share the Licensed Material (including in modified form), You must:
88 |
89 | A. retain the following if it is supplied by the Licensor with the Licensed Material:
90 |
91 | i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated);
92 |
93 | ii. a copyright notice;
94 |
95 | iii. a notice that refers to this Public License;
96 |
97 | iv. a notice that refers to the disclaimer of warranties;
98 |
99 | v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable;
100 |
101 | B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and
102 |
103 | C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License.
104 |
105 | 2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information.
106 |
107 | 3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable.
108 |
109 | b. ___ShareAlike.___
110 |
111 | In addition to the conditions in Section 3(a), if You Share Adapted Material You produce, the following conditions also apply.
112 |
113 | 1. The Adapter’s License You apply must be a Creative Commons license with the same License Elements, this version or later, or a BY-NC-SA Compatible License.
114 |
115 | 2. You must include the text of, or the URI or hyperlink to, the Adapter's License You apply. You may satisfy this condition in any reasonable manner based on the medium, means, and context in which You Share Adapted Material.
116 |
117 | 3. You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, Adapted Material that restrict exercise of the rights granted under the Adapter's License You apply.
118 |
119 | ### Section 4 – Sui Generis Database Rights.
120 |
121 | Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material:
122 |
123 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only;
124 |
125 | b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material, including for purposes of Section 3(b); and
126 |
127 | c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database.
128 |
129 | For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights.
130 |
131 | ### Section 5 – Disclaimer of Warranties and Limitation of Liability.
132 |
133 | a. __Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.__
134 |
135 | b. __To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.__
136 |
137 | c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability.
138 |
139 | ### Section 6 – Term and Termination.
140 |
141 | a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically.
142 |
143 | b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates:
144 |
145 | 1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or
146 |
147 | 2. upon express reinstatement by the Licensor.
148 |
149 | For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License.
150 |
151 | c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License.
152 |
153 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License.
154 |
155 | ### Section 7 – Other Terms and Conditions.
156 |
157 | a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed.
158 |
159 | b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License.
160 |
161 | ### Section 8 – Interpretation.
162 |
163 | a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License.
164 |
165 | b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions.
166 |
167 | c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor.
168 |
169 | d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority.
170 |
171 | > Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at [creativecommons.org/policies](http://creativecommons.org/policies), Creative Commons does not authorize the use of the trademark “Creative Commons” or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses.
172 | >
173 | > Creative Commons may be contacted at creativecommons.org
174 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Translating Images into Maps
2 | #### Avishkar Saha, Oscar Mendez Maldonado, Chris Russell and Richard Bowden
3 |
4 | This is the official code for the paper [Translating Images Into Maps](https://arxiv.org/abs/2110.00966) presented at ICRA 2022.
5 |
6 | ### Translating Images into Maps
7 |
8 | ![Translating images into maps](images/image_to_bev_motivation.gif)
9 | 
10 | 
11 | 
12 | ### Setup
13 | The code was written using Python 3.7.
14 | The following libraries are the minimum required for this repo:
15 | ```python
16 | pytorch
17 | cv2
18 | numpy
19 | pickle
20 | pyquaternion
21 | shapely
22 | lmdb
23 | ```
24 |
25 | ### Data
26 | The official nuScenes data will be required to train the entire model.
27 | For convenience, however, we provide the nuScenes mini dataset wrapped into
28 | LMDBs, which can be downloaded from either of the links below:
29 | ```
30 | https://www.icloud.com/iclouddrive/0aaSjW59DEqgUDKyy1uw0iSVg#nuscenes%5Fdata
31 |
32 | https://drive.google.com/drive/folders/1-1dZXeHnPiuqX-w8ruJHqfxBuMYMONRT?usp=share_link
33 | ```
34 |
35 | The contents of the downloaded folder need to be unzipped and placed in a
36 | `nuscenes_data` folder, which is created as follows:
37 | ```
38 | cd translating-images-into-maps
39 | mkdir nuscenes_data
40 | ```
41 |
42 | This folder contains the ground-truth maps (already generated for the
43 | mini dataset), the input images, and the camera intrinsics.
44 |
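Once `nuscenes_data` is populated, the mini dataset can also be inspected directly through the classes in `src/data`. The snippet below is only an illustrative sketch; the path, batch size and worker count are placeholders, not values used by `train.py`:
```python
from torch.utils.data import DataLoader

from src.data.dataloader import nuScenesMaps
from src.data.collate_funcs import collate_nusc_s

# Point `root` at the unzipped nuscenes_data folder (placeholder path).
dataset = nuScenesMaps(
    root="nuscenes_data",
    split="train_mini",
    mini=True,                        # load the v1.0-mini nuScenes metadata
    desired_image_size=(1280, 720),
    gt_out_size=(100, 100),
)
loader = DataLoader(dataset, batch_size=4, collate_fn=collate_nusc_s, num_workers=4)

(images, calibs, grids2d), (cls_maps, vis_masks) = next(iter(loader))
print(images.shape, cls_maps.shape, vis_masks.shape)
```
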
45 | ### Using the code:
46 | To train a model with the configuration in the paper, simply run:
47 | ```bash
48 | python train.py
49 | ```
50 |
51 | ### Pretrained model
52 | Pretrained models, along with the configs required to load and train them, can be downloaded from here (updated with the correct models on 09-09-2024):
53 | ```
54 | https://www.icloud.com/iclouddrive/041FdACyj8m0pM4L383luzZJg#tiim_checkpoints
55 | ```
56 |
57 |
58 | ### Citation
59 | If you find this code useful, please cite the following papers:
60 | ```
61 | @inproceedings{saha2022translating,
62 | title={Translating Images into Maps},
63 | author={Saha, Avishkar and Mendez, Oscar and Russell, Chris and Bowden, Richard},
64 | booktitle={2022 IEEE International Conference on Robotics and Automation (ICRA)},
65 | year={2022},
66 | organization={IEEE}
67 | }
68 | @inproceedings{saha2021enabling,
69 | title={Enabling spatio-temporal aggregation in birds-eye-view vehicle estimation},
70 | author={Saha, Avishkar and Mendez, Oscar and Russell, Chris and Bowden, Richard},
71 | booktitle={2021 IEEE International Conference on Robotics and Automation (ICRA)},
72 | pages={5133--5139},
73 | year={2021},
74 | organization={IEEE}
75 | }
76 | @inproceedings{saha2022pedestrian,
77 | title={"The Pedestrian next to the Lamppost" Adaptive Object Graphs for Better Instantaneous Mapping},
78 | author={Saha, Avishkar and Mendez, Oscar and Russell, Chris and Bowden, Richard},
79 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
80 | pages={19528--19537},
81 | year={2022}
82 | }
83 | ```
84 |
--------------------------------------------------------------------------------
/images/._image_to_bev_motivation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avishkarsaha/translating-images-into-maps/a452fe7a1eb0062f96133d14380d504e7bc0c8d1/images/._image_to_bev_motivation.gif
--------------------------------------------------------------------------------
/images/image_to_bev_motivation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avishkarsaha/translating-images-into-maps/a452fe7a1eb0062f96133d14380d504e7bc0c8d1/images/image_to_bev_motivation.gif
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avishkarsaha/translating-images-into-maps/a452fe7a1eb0062f96133d14380d504e7bc0c8d1/src/__init__.py
--------------------------------------------------------------------------------
/src/data/augmentations.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 |
4 |
5 | class AugmentedMapDataset(Dataset):
6 |
7 | def __init__(self, dataset, hflip=True, desired_image_size=None):
8 | self.dataset = dataset
9 | self.hflip = hflip
10 | self.desired_image_size = desired_image_size
11 |
12 | def __len__(self):
13 | return len(self.dataset)
14 |
15 | def __getitem__(self, index):
16 | image, calib, labels, mask = self.dataset[index]
17 |
18 | # Apply data augmentation
19 | if self.hflip:
20 | image, labels, mask = random_hflip(image, labels, mask)
21 | ## TODO: add padding+cropping augmentation (copied from the original dataloader) - don't we need to change the gt and mask OR the z_intervals when the image is being changed?
22 | # if self.desired_image_size:
23 | # image, calib = image_calib_pad_and_crop(image, calib, desired_image_size)
24 |
25 | return image, calib, labels, mask
26 |
27 |
28 | def random_hflip(image, labels, mask):
29 | image = torch.flip(image, (-1,))
30 | labels = torch.flip(labels.int(), (-1,)).bool()
31 | mask = torch.flip(mask.int(), (-1,)).bool()
32 | return image, labels, mask
33 |
34 | def image_calib_pad_and_crop(image, calib, desired_image_size):
35 |
36 | og_w, og_h = image.size  # assumes a PIL image (as in the original dataloader), matching the .resize/.size/.crop calls below
37 | desired_w, desired_h = desired_image_size
38 | scale_w, scale_h = desired_w / og_w, desired_h / og_h
39 | # Scale image
40 | image = image.resize((int(image.size[0] * scale_w), int(image.size[1] * scale_h)))
41 | # Pad images to the same dimensions
42 | w = image.size[0]
43 | h = image.size[1]
44 | delta_w = desired_w - w
45 | delta_h = desired_h - h
46 | pad_left = int(delta_w / 2)
47 | pad_right = delta_w - pad_left
48 | pad_top = int(delta_h / 2)
49 | pad_bottom = delta_h - pad_top
50 | left = 0 - pad_left
51 | right = pad_right + w
52 | top = 0 - pad_top
53 | bottom = pad_bottom + h
54 | image = image.crop((left, top, right, bottom))
55 |
56 | # Modify calibration matrices
57 | # Scale first two rows of calibration matrix
58 | calib[:2, :] *= scale_w
59 | # cx' = cx - du
60 | calib[0, 2] = calib[0, 2] + pad_left
61 | # cy' = cy - dv
62 | calib[1, 2] = calib[1, 2] + pad_top
63 |
64 | return image, calib
65 |
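
# Illustrative usage sketch (the dummy dataset below is a placeholder, not part
# of the training pipeline): AugmentedMapDataset wraps any dataset that yields
# (image, calib, labels, mask) tuples, e.g. NuScenesMapDataset from
# src/data/nuscenes/dataset.py, and flips every sample horizontally when
# hflip=True. Random tensors stand in for real data so this runs on its own.
if __name__ == "__main__":

    class _DummyMapDataset(Dataset):
        def __len__(self):
            return 4

        def __getitem__(self, index):
            image = torch.rand(3, 720, 1280)           # RGB image, CHW
            calib = torch.eye(3)                       # placeholder intrinsics
            labels = torch.zeros(9, 100, 100).bool()   # per-class BEV maps
            mask = torch.ones(1, 100, 100).bool()      # visibility mask
            return image, calib, labels, mask

    augmented = AugmentedMapDataset(_DummyMapDataset(), hflip=True)
    image, calib, labels, mask = augmented[0]
    print(image.shape, labels.shape, mask.shape)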
--------------------------------------------------------------------------------
/src/data/collate_funcs.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torchvision.transforms.functional import to_tensor
4 |
5 | from src.utils import make_uvcoords, merge_nusc_static_classes, downsample_gt
6 | from src.utils import merge_classes2, merge_classes_lyft
7 |
8 |
9 | def collate_nusc_s_occ_200down100up(batch):
10 | """
11 | Collate fuction for:
12 | - NuScenes
13 | - Singe image input models
14 | - Ground truths with occluded regions masked out
15 |
16 | Merges input classes "road_segment" + "lane" into "drivable_area"
17 |
18 | Down Up scheme: downsample 200 to 100, then upsample 100 to required resolution
19 | Output resolution: 100
20 | """
21 |
22 | idxs, images, calibs, gt_maps, grids2d, class_dict, des_image_size = zip(*batch)
23 |
24 | # All class dicts are same, so only select first
25 | class_dict = class_dict[0]
26 |
27 | og_w, og_h = 1600, 900
28 | desired_w, desired_h = des_image_size[0]
29 | scale_w, scale_h = desired_w / og_w, desired_h / og_h
30 |
31 | # Scale image
32 | images = [
33 | image.resize((int(image.size[0] * scale_w), int(image.size[1] * scale_h)))
34 | for image in images
35 | ]
36 |
37 | # Pad images to the same dimensions
38 | w = [img.size[0] for img in images]
39 | h = [img.size[1] for img in images]
40 |
41 | delta_w = [desired_w - img.size[0] for img in images]
42 | delta_h = [desired_h - img.size[1] for img in images]
43 |
44 | pad_left = [int(d / 2) for d in delta_w]
45 | pad_right = [delta_w[i] - pad_left[i] for i in range(len(images))]
46 | pad_top = [int(d / 2) for d in delta_h]
47 | pad_bottom = [delta_h[i] - pad_top[i] for i in range(len(images))]
48 |
49 | left = [0 - pad_left[i] for i in range(len(images))]
50 | right = [pad_right[i] + w[i] for i in range(len(images))]
51 | top = [0 - pad_top[i] for i in range(len(images))]
52 | bottom = [pad_bottom[i] + h[i] for i in range(len(images))]
53 |
54 | images = [
55 | images[i].crop((left[i], top[i], right[i], bottom[i]))
56 | for i in range(len(images))
57 | ]
58 |
59 | w = [img.size[0] for img in images]
60 | h = [img.size[1] for img in images]
61 |
62 | # Stack images and calibration matrices along the batch dimension
63 | images = torch.stack([to_tensor(img) for img in images])
64 | gt_maps = torch.stack(
65 | [
66 | torch.stack([to_tensor(class_map) for class_map in gt_map])
67 | for gt_map in gt_maps
68 | ]
69 | ).squeeze(2)
70 | calibs = torch.stack(calibs)
71 | grids2d = torch.stack(grids2d)
72 |
73 | # Create visibility mask from lidar and fov masks
74 | lidar_ray_mask = gt_maps[:, -2:-1]
75 | fov_mask = gt_maps[:, -1:]
76 | vis_mask = lidar_ray_mask * fov_mask
77 | vis_mask = downsample_gt(vis_mask, [[100, 100]])[0]
78 |
79 | # Downsample all classes from 200x200 to 100x100
80 | gt_maps_all_classes = gt_maps[:, :-2]
81 | gt_maps_200down100 = downsample_gt(gt_maps_all_classes, [[100, 100]])[0]
82 |
83 | # Apply vis mask
84 | gt_maps = gt_maps_200down100 * vis_mask
85 |
86 | # Modify calibration matrices
87 | # Scale first two rows of calibration matrix
88 | calibs[:, :2, :] *= scale_w
89 | # cx' = cx - du
90 | calibs[:, 0, 2] = calibs[:, 0, 2] + torch.tensor(pad_left)
91 | # cy' = cy - dv
92 | calibs[:, 1, 2] = calibs[:, 1, 2] + torch.tensor(pad_top)
93 |
94 | # Create a vector of indices
95 | idxs = torch.LongTensor(idxs)
96 |
97 | # Add drivable area to first cell in FOV
98 | # gt_maps = gt_maps_masked.squeeze(2)
99 | gt_maps[:, 0, 0, int(gt_maps.shape[-1] / 2 - 1)] = 1
100 | gt_maps[:, 0, 0, int(gt_maps.shape[-1] / 2)] = 1
101 |
102 | # Merge ground truth for road_segment + lane into one and create new gt
103 | classes_to_merge = ["drivable_area", "road_segment", "lane"]
104 | class_idx_to_merge = [class_dict[name] for name in classes_to_merge]
105 | merged_class_idx = class_dict["drivable_area"]
106 | gt_maps_new = torch.stack(
107 | [
108 | merge_nusc_static_classes(gt, class_idx_to_merge, merged_class_idx)
109 | for gt in gt_maps
110 | ]
111 | )
112 |
113 | return idxs, images, calibs, gt_maps_new, grids2d, vis_mask
114 |
115 |
116 | def collate_nusc_s(batch):
117 | """
118 | Collate function for:
119 | - NuScenes
120 | - Single-image input models
121 | - Ground truths with occluded regions masked out
122 | """
123 |
124 | images, cls_maps, vis_masks, calibs, grids2d = zip(*batch)
125 |
126 | # Stack images and calibration matrices along the batch dimension
127 | images = torch.stack(images)
128 | cls_maps = torch.stack(cls_maps)
129 | vis_masks = torch.stack(vis_masks)
130 | calibs = torch.stack(calibs)
131 | grids2d = torch.stack(grids2d)
132 |
133 | return (images, calibs, grids2d), (cls_maps, vis_masks)
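
# Illustrative usage sketch: collate_nusc_s is intended as the collate_fn of a
# torch DataLoader wrapping nuScenesMaps (src/data/dataloader.py), whose
# __getitem__ returns (image, cls_maps, vis_mask, calib, grid2d). The random
# tensors below are placeholders so the sketch runs without the lmdb data.
if __name__ == "__main__":
    fake_sample = (
        torch.rand(3, 720, 1280),    # input image
        torch.rand(9, 100, 100),     # per-class ground-truth maps
        torch.ones(1, 100, 100),     # visibility mask
        torch.eye(3),                # camera intrinsics
        torch.rand(50, 50, 2),       # BEV grid coordinates (shape illustrative)
    )
    (images, calibs, grids2d), (cls_maps, vis_masks) = collate_nusc_s(
        [fake_sample, fake_sample]
    )
    print(images.shape, cls_maps.shape, vis_masks.shape)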
--------------------------------------------------------------------------------
/src/data/dataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import csv
4 |
5 | from pathlib import Path
6 | import io
7 | import lmdb
8 | import pickle
9 | import numpy as np
10 | import matplotlib.pyplot as plt
11 | from torch.utils.data import DataLoader
12 | from PIL import Image
13 | 
14 | from torch.utils.data import Dataset
15 | from torchvision.transforms.functional import to_tensor, to_pil_image
16 | import torch.nn.functional as F
17 |
18 | from nuscenes.nuscenes import NuScenes
19 | from nuscenes.utils.geometry_utils import (
20 | view_points,
21 | box_in_image,
22 | BoxVisibility,
23 | transform_matrix,
24 | )
25 | from nuscenes.map_expansion.map_api import NuScenesMap, NuScenesMapExplorer
26 | from nuscenes.utils.data_classes import LidarPointCloud
27 |
28 | from src import utils
29 |
30 |
31 | class nuScenesMaps(Dataset):
32 |
33 | def __init__(
34 | self,
35 | root="temp",
36 | split="train_mini",
37 | grid_size=(50.0, 50.0),
38 | grid_res=1.0,
39 | classes=[
40 | "bus",
41 | "bicycle",
42 | "car",
43 | "construction_vehicle",
44 | "motorcycle",
45 | "trailer",
46 | "truck",
47 | "pedestrian",
48 | ],
49 | dataset_size=1.0,
50 | mini=False,
51 | desired_image_size=(1280, 720),
52 | gt_out_size=(100, 100)
53 | ):
54 | self.dataset_size = dataset_size
55 | self.desired_image_size = desired_image_size
56 | self.gt_out_size = gt_out_size
57 |
58 | # paths for data files
59 | self.root = os.path.join(root)
60 | self.gtmaps_db_path = os.path.join(
61 | root, "lmdb",
62 | "semantic_maps_new_200x200"
63 | )
64 | self.images_db_path = os.path.join(
65 | root, "lmdb",
66 | "samples", "CAM_FRONT"
67 | )
68 |
69 | # databases
70 | if mini:
71 | self.nusc = NuScenes(version="v1.0-mini",
72 | dataroot=self.root,
73 | verbose=False)
74 | else:
75 | self.nusc = NuScenes(version="v1.0-trainval", dataroot=self.root, verbose=False)
76 |
77 | self.tokens = read_split(
78 | os.path.join(root, "splits", "{}.txt".format(split))
79 | )
80 | self.gtmaps_db = lmdb.open(
81 | path=self.gtmaps_db_path,
82 | readonly=True,
83 | readahead=False,
84 | max_spare_txns=128,
85 | lock=False,
86 | )
87 | self.images_db = lmdb.open(
88 | path=self.images_db_path,
89 | readonly=True,
90 | readahead=False,
91 | max_spare_txns=128,
92 | lock=False,
93 | )
94 |
95 | # Set classes
96 | self.classes = list(classes)
97 | self.classes.append("lidar_ray_mask_dense")
98 | self.class2idx = {
99 | name: idx for idx, name in enumerate(self.classes)
100 | }
101 | self.nusc_classes = [
102 | "vehicle.bus",
103 | "vehicle.bicycle",
104 | "vehicle.car",
105 | "vehicle.construction",
106 | "vehicle.motorcycle",
107 | "vehicle.trailer",
108 | "vehicle.truck",
109 | "human.pedestrian",
110 | ]
111 | self.nuscclass2idx = {
112 | name: idx for idx, name in enumerate(self.nusc_classes)
113 | }
114 | # load FOV mask
115 | self.fov_mask = Image.open(
116 | os.path.join(root, "lmdb", "semantic_maps_new_200x200", "fov_mask.png")
117 | )
118 | # Make grid
119 | self.grid2d = utils.make_grid2d(grid_size, (-grid_size[0] / 2.0, 0.0), grid_res)
120 |
121 | def __len__(self):
122 | return int(len(self.tokens) * self.dataset_size - 1)
123 |
124 | def __getitem__(self, index):
125 |
126 |
127 | # Load sample ID
128 | sample_token = self.tokens[index]
129 | sample_record = self.nusc.get("sample", sample_token)
130 | cam_token = sample_record["data"]["CAM_FRONT"]
131 | cam_record = self.nusc.get("sample_data", cam_token)
132 | cam_path = self.nusc.get_sample_data_path(cam_token)
133 | id = Path(cam_path).stem
134 |
135 | # Load intrinsics
136 | calib = self.nusc.get(
137 | "calibrated_sensor", cam_record["calibrated_sensor_token"]
138 | )["camera_intrinsic"]
139 | calib = np.array(calib)
140 |
141 | # Load input images
142 | image_input_key = pickle.dumps(id)
143 | with self.images_db.begin() as txn:
144 | value = txn.get(key=image_input_key)
145 | image = Image.open(io.BytesIO(value)).convert(mode='RGB')
146 |
147 | # resize/augment images
148 | image, calib = self.image_calib_pad_and_crop(image, calib)
149 | image = to_tensor(image)
150 | calib = to_tensor(calib).reshape(3, 3)
151 |
152 | # Load ground truth maps
153 | gtmaps_key = [pickle.dumps("{}___{}".format(id, cls)) for cls in self.classes]
154 | with self.gtmaps_db.begin() as txn:
155 | value = [txn.get(key=key) for key in gtmaps_key]
156 | gtmaps = [Image.open(io.BytesIO(im)) for im in value]
157 |
158 | # each map is of shape [1, 200, 200]
159 | mapsdict = {cls: to_tensor(map) for cls, map in zip(self.classes, gtmaps)}
160 | mapsdict["fov_mask"] = to_tensor(self.fov_mask)
161 | mapsdict = self.merge_map_classes(mapsdict)
162 |
163 | # Create visibility mask from lidar and fov masks
164 | lidar_ray_mask = mapsdict['lidar_ray_mask_dense']
165 | fov_mask = mapsdict['fov_mask']
166 | vis_mask = lidar_ray_mask * fov_mask
167 | mapsdict['vis_mask'] = vis_mask
168 |
169 | del mapsdict['lidar_ray_mask_dense'], mapsdict['fov_mask']
170 |
171 | # downsample maps to required output resolution
172 | mapsdict = {
173 | cls: F.interpolate(cls_map.unsqueeze(0), size=self.gt_out_size).squeeze(0)
174 | for cls, cls_map in mapsdict.items()
175 | }
176 |
177 | # apply vis mask to maps
178 | mapsdict = {
179 | cls: cls_map * mapsdict['vis_mask'] for cls, cls_map in mapsdict.items()
180 | }
181 |
182 | cls_maps = torch.cat(
183 | [cls_map for cls, cls_map in mapsdict.items() if 'mask' not in cls], dim=0
184 | )
185 | vis_mask = mapsdict['vis_mask']
186 |
187 | return (
188 | image, cls_maps, vis_mask, calib, self.grid2d
189 | )
190 |
191 | def merge_map_classes(self, mapsdict):
192 | classes_to_merge = ["drivable_area", "road_segment", "lane"]
193 | merged_class = 'drivable_area'
194 | maps2merge = torch.stack([mapsdict[k] for k in classes_to_merge]) # [n, 1, 200, 200]
195 | maps2merge = maps2merge.sum(dim=0)
196 | maps2merge = (maps2merge > 0).float()
197 | mapsdict[merged_class] = maps2merge
198 | del mapsdict['road_segment'], mapsdict['lane']
199 | return mapsdict
200 |
201 | def image_calib_pad_and_crop(self, image, calib):
202 |
203 | og_w, og_h = 1600, 900
204 | desired_w, desired_h = self.desired_image_size
205 | scale_w, scale_h = desired_w / og_w, desired_h / og_h
206 | # Scale image
207 | image = image.resize((int(image.size[0] * scale_w), int(image.size[1] * scale_h)))
208 | # Pad images to the same dimensions
209 | w = image.size[0]
210 | h = image.size[1]
211 | delta_w = desired_w - w
212 | delta_h = desired_h - h
213 | pad_left = int(delta_w / 2)
214 | pad_right = delta_w - pad_left
215 | pad_top = int(delta_h / 2)
216 | pad_bottom = delta_h - pad_top
217 | left = 0 - pad_left
218 | right = pad_right + w
219 | top = 0 - pad_top
220 | bottom = pad_bottom + h
221 | image = image.crop((left, top, right, bottom))
222 |
223 | # Modify calibration matrices
224 | # Scale first two rows of calibration matrix
225 | calib[:2, :] *= scale_w
226 | # cx' = cx - du
227 | calib[0, 2] = calib[0, 2] + pad_left
228 | # cy' = cy - dv
229 | calib[1, 2] = calib[1, 2] + pad_top
230 |
231 | return image, calib
232 |
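# Worked example (illustrative numbers) for image_calib_pad_and_crop above:
# with the default desired_image_size of (1280, 720), the 1600x900 input gives
# scale_w = scale_h = 0.8, so the resize lands exactly on 1280x720 and
# delta_w = delta_h = 0, i.e. no padding or cropping is applied. The intrinsics
# follow the same transform: fx, fy, cx and cy are all multiplied by 0.8, and
# the principal point is additionally shifted by (pad_left, pad_top) whenever
# the resized image has to be padded or cropped to reach the target size.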
233 |
234 | def read_split(filename):
235 | """
236 | Read a list of NuScenes sample tokens
237 | """
238 | with open(filename, "r") as f:
239 | lines = f.read().split("\n")
240 | return [val for val in lines if val != ""]
241 |
242 |
243 | def create_batch_indices_old(split, batch_size, seq_len, n_pred_frames):
244 | nuscenes_root = "/vol/research/sceneEvolution/data/nuscenes"
245 | scene_len_file = os.path.join(
246 | nuscenes_root, "splits", (split + "_with_seq_len.txt")
247 | )
248 | scene_len = np.array(read_split(scene_len_file), dtype=np.int)
249 | cumsum_scene_len = np.cumsum(scene_len)
250 | zeros = np.zeros(len(cumsum_scene_len) + 1, dtype=np.int)
251 | zeros[1:] = cumsum_scene_len
252 | cumsum_scene_len = zeros
253 |
254 | idxs_batch_start = []
255 | idxs_batch_end = []
256 | idxs_batch_pred = []
257 |
258 | for idx_scene, scene in enumerate(scene_len):
259 | # Split scene length into chunks of size seq_len
260 | nbatches_in_scene = (scene - 1) // seq_len
261 | local_batch_num = (np.arange(nbatches_in_scene, dtype=np.int) + 1) * seq_len
262 | z = np.zeros(len(local_batch_num) + 1, dtype=np.int)
263 | z[1:] = local_batch_num
264 | local_batch_idx = z
265 |
266 | # Add cumsum scene_lengths to get global idx
267 | global_batch_idx = local_batch_idx + cumsum_scene_len[idx_scene]
268 |
269 | start_batch_idx = global_batch_idx[:-1]
270 | end_batch_idx = global_batch_idx[1:] - 1
271 | pred_batch_idx = end_batch_idx + n_pred_frames
272 | pred_batch_idx = np.clip(
273 | pred_batch_idx, a_min=0, a_max=scene - 1 + cumsum_scene_len[idx_scene]
274 | )
275 |
276 | idxs_batch_start.extend(list(start_batch_idx))
277 | idxs_batch_end.extend(list(end_batch_idx))
278 | idxs_batch_pred.extend(list(pred_batch_idx))
279 |
280 | return idxs_batch_start, idxs_batch_end, idxs_batch_pred
281 |
282 |
283 | def create_batch_indices(split, batch_size, seq_len, n_pred_frames):
284 | nuscenes_root = "/vol/research/sceneEvolution/data/nuscenes"
285 | scene_len_file = os.path.join(
286 | nuscenes_root, "splits", (split + "_with_seq_len.txt")
287 | )
288 | scene_len = np.array(read_split(scene_len_file), dtype=np.int)
289 | cumsum_scene_len = np.cumsum(scene_len)
290 | zeros = np.zeros(len(cumsum_scene_len) + 1, dtype=np.int)
291 | zeros[1:] = cumsum_scene_len
292 | cumsum_scene_len = zeros
293 |
294 | # Offset for sliding window through sequences
295 | offset = seq_len // 2
296 |
297 | idxs_batch_start = []
298 | idxs_batch_end = []
299 | idxs_batch_pred = []
300 |
301 | for idx_scene, scene in enumerate(scene_len):
302 | # Split scene length into chunks of size seq_len
303 | nbatches_in_scene = (scene - 1) // seq_len
304 | local_batch_num = (np.arange(nbatches_in_scene, dtype=np.int) + 1) * seq_len
305 | z = np.zeros(len(local_batch_num) + 1, dtype=np.int)
306 | z[1:] = local_batch_num
307 | local_batch_idx = z
308 |
309 | # Add cumsum scene_lengths to get global idx
310 | global_batch_idx = local_batch_idx + cumsum_scene_len[idx_scene]
311 |
312 | start_batch_idx = global_batch_idx[:-1]
313 | end_batch_idx = global_batch_idx[1:] - 1
314 | pred_batch_idx = end_batch_idx + n_pred_frames
315 | pred_batch_idx = np.clip(
316 | pred_batch_idx, a_min=0, a_max=scene - 1 + cumsum_scene_len[idx_scene]
317 | )
318 |
319 | idxs_batch_start.extend(list(start_batch_idx))
320 | idxs_batch_end.extend(list(end_batch_idx))
321 | idxs_batch_pred.extend(list(pred_batch_idx))
322 |
323 | # Create intermediate sequences (first trim either end)
324 | start_batch_idx = start_batch_idx[1:-1] - offset
325 | end_batch_idx = end_batch_idx[1:-1] - offset
326 | pred_batch_idx = pred_batch_idx[1:-1] - offset
327 |
328 | idxs_batch_start.extend(list(start_batch_idx))
329 | idxs_batch_end.extend(list(end_batch_idx))
330 | idxs_batch_pred.extend(list(pred_batch_idx))
331 |
332 | return idxs_batch_start, idxs_batch_end, idxs_batch_pred
333 |
334 |
335 | def create_batch_indices_wo_int(split, batch_size, seq_len, n_pred_frames):
336 | nuscenes_root = "/vol/research/sceneEvolution/data/nuscenes"
337 | scene_len_file = os.path.join(
338 | nuscenes_root, "splits", (split + "_with_seq_len.txt")
339 | )
340 | scene_len = np.array(read_split(scene_len_file), dtype=np.int)
341 | cumsum_scene_len = np.cumsum(scene_len)
342 | zeros = np.zeros(len(cumsum_scene_len) + 1, dtype=np.int)
343 | zeros[1:] = cumsum_scene_len
344 | cumsum_scene_len = zeros
345 |
346 | # Offset for sliding window through sequences
347 | offset = seq_len // 2
348 |
349 | idxs_batch_start = []
350 | idxs_batch_end = []
351 | idxs_batch_pred = []
352 |
353 | for idx_scene, scene in enumerate(scene_len):
354 | # Split scene length into chunks of size seq_len
355 | nbatches_in_scene = (scene - 1) // seq_len
356 | local_batch_num = (np.arange(nbatches_in_scene, dtype=np.int) + 1) * seq_len
357 | z = np.zeros(len(local_batch_num) + 1, dtype=np.int)
358 | z[1:] = local_batch_num
359 | local_batch_idx = z
360 |
361 | # Add cumsum scene_lengths to get global idx
362 | global_batch_idx = local_batch_idx + cumsum_scene_len[idx_scene]
363 |
364 | start_batch_idx = global_batch_idx[:-1]
365 | end_batch_idx = global_batch_idx[1:] - 1
366 | pred_batch_idx = end_batch_idx + n_pred_frames
367 | pred_batch_idx = np.clip(
368 | pred_batch_idx, a_min=0, a_max=scene - 1 + cumsum_scene_len[idx_scene]
369 | )
370 |
371 | idxs_batch_start.extend(list(start_batch_idx))
372 | idxs_batch_end.extend(list(end_batch_idx))
373 | idxs_batch_pred.extend(list(pred_batch_idx))
374 |
375 | # # Create intermediate sequences (first trim either end)
376 | # start_batch_idx = start_batch_idx[1:-1] - offset
377 | # end_batch_idx = end_batch_idx[1:-1] - offset
378 | # pred_batch_idx = pred_batch_idx[1:-1] - offset
379 | #
380 | # idxs_batch_start.extend(list(start_batch_idx))
381 | # idxs_batch_end.extend(list(end_batch_idx))
382 | # idxs_batch_pred.extend(list(pred_batch_idx))
383 |
384 | return idxs_batch_start, idxs_batch_end, idxs_batch_pred
385 |
386 |
387 | def create_batch_indices2(split, seq_len):
388 | nuscenes_root = "/vol/research/sceneEvolution/data/nuscenes"
389 | scene_len_file = os.path.join(
390 | nuscenes_root, "splits", (split + "_with_seq_len.txt")
391 | )
392 | scene_len = np.array(read_split(scene_len_file), dtype=np.int)
393 | cumsum_scene_len = np.cumsum(scene_len)
394 | zeros = np.zeros(len(cumsum_scene_len) + 1, dtype=np.int)
395 | zeros[1:] = cumsum_scene_len
396 | cumsum_scene_len = zeros
397 |
398 | idxs_batch_start = []
399 | for idx_scene, scene in enumerate(scene_len):
400 | # Split scene length into chunks of size seq_len
401 | nbatches_in_scene = (scene - 1) // seq_len
402 | local_batch_num = (np.arange(nbatches_in_scene, dtype=np.int) + 1) * seq_len
403 | z = np.zeros(len(local_batch_num) + 1, dtype=np.int)
404 | z[1:] = local_batch_num
405 | local_batch_idx = z
406 |
407 | # Add cumsum scene_lengths to get global idx
408 | global_batch_idx = local_batch_idx + cumsum_scene_len[idx_scene]
409 | idxs_batch_start.extend(list(global_batch_idx))
410 |
411 | idxs_batch_end = list(np.array(idxs_batch_start, dtype=np.int) - 1)
412 |
413 | return idxs_batch_start, idxs_batch_end
414 |
415 |
416 | def create_batch_indices_mini(split, batch_size, seq_len, n_pred_frames):
417 | nuscenes_root = "/vol/research/sceneEvolution/data/nuscenes"
418 | scene_len_file = os.path.join(
419 | nuscenes_root, "splits", (split + "_mini_with_seq_len.txt")
420 | )
421 | scene_len = np.array(read_split(scene_len_file), dtype=np.int)
422 | cumsum_scene_len = np.cumsum(scene_len)
423 | zeros = np.zeros(len(cumsum_scene_len) + 1, dtype=np.int)
424 | zeros[1:] = cumsum_scene_len
425 | cumsum_scene_len = zeros
426 |
427 | idxs_batch_start = []
428 | idxs_batch_end = []
429 | idxs_batch_pred = []
430 |
431 | for idx_scene, scene in enumerate(scene_len):
432 | # Split scene length into chunks of size seq_len
433 | nbatches_in_scene = (scene - 1) // seq_len
434 | local_batch_num = (np.arange(nbatches_in_scene, dtype=np.int) + 1) * seq_len
435 | z = np.zeros(len(local_batch_num) + 1, dtype=np.int)
436 | z[1:] = local_batch_num
437 | local_batch_idx = z
438 |
439 | # Add cumsum scene_lengths to get global idx
440 | global_batch_idx = local_batch_idx + cumsum_scene_len[idx_scene]
441 |
442 | start_batch_idx = global_batch_idx[:-1]
443 | end_batch_idx = global_batch_idx[1:] - 1
444 | pred_batch_idx = end_batch_idx + n_pred_frames
445 | pred_batch_idx = np.clip(
446 | pred_batch_idx, a_min=0, a_max=scene - 1 + cumsum_scene_len[idx_scene]
447 | )
448 |
449 | idxs_batch_start.extend(list(start_batch_idx))
450 | idxs_batch_end.extend(list(end_batch_idx))
451 | idxs_batch_pred.extend(list(pred_batch_idx))
452 |
453 | return idxs_batch_start, idxs_batch_end, idxs_batch_pred
454 |
455 |
--------------------------------------------------------------------------------
/src/data/dataloader_new.py:
--------------------------------------------------------------------------------
1 | import os
2 | from torch.utils.data import DataLoader, RandomSampler
3 | from .augmentations import AugmentedMapDataset
4 |
5 | from nuscenes import NuScenes
6 | from .nuscenes.dataset import NuScenesMapDataset
7 | from .nuscenes.splits import TRAIN_SCENES, VAL_SCENES, CALIBRATION_SCENES
8 |
9 |
10 | def build_nuscenes_datasets(config):
11 | print('==> Loading NuScenes dataset...')
12 | nuscenes = NuScenes(config.nuscenes_version,
13 | os.path.expandvars(config.dataroot))
14 |
15 | # Exclude calibration scenes
16 | if config.hold_out_calibration:
17 | train_scenes = list(set(TRAIN_SCENES) - set(CALIBRATION_SCENES))
18 | else:
19 | train_scenes = TRAIN_SCENES
20 |
21 | train_data = NuScenesMapDataset(nuscenes, config.label_root,
22 | config.img_size, train_scenes)
23 | val_data = NuScenesMapDataset(nuscenes, config.label_root,
24 | config.img_size, VAL_SCENES)
25 | return train_data, val_data
26 |
27 | def build_datasets(dataset_name, config):
28 | if dataset_name == 'nuscenes':
29 | return build_nuscenes_datasets(config)
30 | else:
31 | raise ValueError(f"Unknown dataset option '{dataset_name}'")
32 |
33 |
34 | def build_trainval_datasets(dataset_name, config):
35 | # Construct the base dataset
36 | train_data, val_data = build_datasets(dataset_name, config)
37 |
38 | # Add data augmentation to train dataset
39 | train_data = AugmentedMapDataset(train_data, config.hflip)
40 |
41 | return train_data, val_data
42 |
43 |
44 | def build_dataloaders(dataset_name, config):
45 | # Build training and validation datasets
46 | train_data, val_data = build_trainval_datasets(dataset_name, config)
47 |
48 | # Create training set dataloader
49 | sampler = RandomSampler(train_data, True, config.epoch_size)
50 | train_loader = DataLoader(train_data, config.batch_size, sampler=sampler,
51 | num_workers=config.num_workers)
52 |
53 | # Create validation dataloader
54 | val_loader = DataLoader(val_data, config.batch_size,
55 | num_workers=config.num_workers)
56 |
57 | return train_loader, val_loader
58 |
59 |
60 |
61 |
62 |
63 |
64 |
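# Usage sketch (illustrative): build_dataloaders expects a config object
# exposing the attributes read above. The values below are placeholders, not
# the repository's actual defaults.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        nuscenes_version="v1.0-mini",
        dataroot="$NUSCENES_ROOT",                 # expanded via os.path.expandvars
        label_root="$NUSCENES_ROOT/map_labels",    # pre-rendered label images
        img_size=(800, 450),
        hold_out_calibration=False,
        hflip=True,
        epoch_size=1000,
        batch_size=4,
        num_workers=4,
    )
    train_loader, val_loader = build_dataloaders("nuscenes", config)
    print(len(train_loader), len(val_loader))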
--------------------------------------------------------------------------------
/src/data/nuscenes/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avishkarsaha/translating-images-into-maps/a452fe7a1eb0062f96133d14380d504e7bc0c8d1/src/data/nuscenes/__init__.py
--------------------------------------------------------------------------------
/src/data/nuscenes/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch.utils.data import Dataset
4 | from PIL import Image, ImageFile
5 | from nuscenes import NuScenes
6 | from torchvision.transforms.functional import to_tensor
7 |
8 | from .utils import CAMERA_NAMES, NUSCENES_CLASS_NAMES, iterate_samples
9 | from ..utils import decode_binary_labels
10 |
11 | class NuScenesMapDataset(Dataset):
12 |
13 | def __init__(self, nuscenes, map_root, image_size=(800, 450),
14 | scene_names=None):
15 |
16 | self.nuscenes = nuscenes
17 | self.map_root = os.path.expandvars(map_root)
18 | self.image_size = image_size
19 |
20 | # Preload the list of tokens in the dataset
21 | self.get_tokens(scene_names)
22 |
23 | # Allow PIL to load partially corrupted images
24 | # (otherwise training crashes at the most inconvenient possible times!)
25 | ImageFile.LOAD_TRUNCATED_IMAGES = True
26 |
27 |
28 | def get_tokens(self, scene_names=None):
29 |
30 | self.tokens = list()
31 |
32 | # Iterate over scenes
33 | for scene in self.nuscenes.scene:
34 |
35 | # Ignore scenes which don't belong to the current split
36 | if scene_names is not None and scene['name'] not in scene_names:
37 | continue
38 |
39 | # Iterate over samples
40 | for sample in iterate_samples(self.nuscenes,
41 | scene['first_sample_token']):
42 |
43 | # Iterate over cameras
44 | for camera in CAMERA_NAMES:
45 | self.tokens.append(sample['data'][camera])
46 |
47 | return self.tokens
48 |
49 |
50 | def __len__(self):
51 | return len(self.tokens)
52 |
53 | def __getitem__(self, index):
54 | token = self.tokens[index]
55 |
56 | image = self.load_image(token)
57 | calib = self.load_calib(token)
58 | labels, mask = self.load_labels(token)
59 |
60 | return image, calib, labels, mask
61 |
62 |
63 | def load_image(self, token):
64 |
65 | # Load image as a PIL image
66 | image = Image.open(self.nuscenes.get_sample_data_path(token))
67 |
68 | # Resize to input resolution
69 | image = image.resize(self.image_size)
70 |
71 | # Convert to a torch tensor
72 | return to_tensor(image)
73 |
74 |
75 | def load_calib(self, token):
76 |
77 | # Load camera intrinsics matrix
78 | sample_data = self.nuscenes.get('sample_data', token)
79 | sensor = self.nuscenes.get(
80 | 'calibrated_sensor', sample_data['calibrated_sensor_token'])
81 | intrinsics = torch.tensor(sensor['camera_intrinsic'])
82 |
83 | # Scale calibration matrix to account for image downsampling
84 | intrinsics[0] *= self.image_size[0] / sample_data['width']
85 | intrinsics[1] *= self.image_size[1] / sample_data['height']
86 | return intrinsics
87 |
88 |
89 | def load_labels(self, token):
90 |
91 | # Load label image as a torch tensor
92 | label_path = os.path.join(self.map_root, token + '.png')
93 | encoded_labels = to_tensor(Image.open(label_path)).long()
94 |
95 | # Decode to binary labels
96 | num_class = len(NUSCENES_CLASS_NAMES)
97 | labels = decode_binary_labels(encoded_labels, num_class + 1)
98 | labels, mask = labels[:-1], ~labels[-1]
99 |
100 | return labels, mask
101 |
102 |
103 |
104 |
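# Illustrative usage sketch (paths are placeholders): the dataset pairs each
# camera image with an encoded label .png named after its sample_data token,
# rendered beforehand into map_root.
if __name__ == "__main__":
    nuscenes = NuScenes("v1.0-mini", "/data/nuscenes")
    dataset = NuScenesMapDataset(nuscenes, map_root="/data/nuscenes/map_labels")
    image, calib, labels, mask = dataset[0]
    print(image.shape, calib.shape, labels.shape, mask.shape)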
--------------------------------------------------------------------------------
/src/data/nuscenes/splits.py:
--------------------------------------------------------------------------------
1 |
2 | TRAIN_SCENES = [
3 | "scene-0002", "scene-0003", "scene-0004", "scene-0005", "scene-0006",
4 | "scene-0007", "scene-0008", "scene-0009", "scene-0012", "scene-0013",
5 | "scene-0014", "scene-0015", "scene-0016", "scene-0017", "scene-0018",
6 | "scene-0019", "scene-0021", "scene-0022", "scene-0023", "scene-0024",
7 | "scene-0025", "scene-0026", "scene-0027", "scene-0028", "scene-0029",
8 | "scene-0030", "scene-0031", "scene-0032", "scene-0033", "scene-0034",
9 | "scene-0035", "scene-0036", "scene-0039", "scene-0042", "scene-0043",
10 | "scene-0044", "scene-0045", "scene-0046", "scene-0047", "scene-0048",
11 | "scene-0049", "scene-0050", "scene-0051", "scene-0052", "scene-0055",
12 | "scene-0056", "scene-0057", "scene-0058", "scene-0059", "scene-0060",
13 | "scene-0061", "scene-0062", "scene-0063", "scene-0064", "scene-0065",
14 | "scene-0066", "scene-0067", "scene-0068", "scene-0069", "scene-0070",
15 | "scene-0071", "scene-0072", "scene-0073", "scene-0074", "scene-0075",
16 | "scene-0076", "scene-0092", "scene-0093", "scene-0094", "scene-0095",
17 | "scene-0096", "scene-0097", "scene-0098", "scene-0099", "scene-0100",
18 | "scene-0101", "scene-0102", "scene-0103", "scene-0104", "scene-0105",
19 | "scene-0106", "scene-0107", "scene-0108", "scene-0109", "scene-0110",
20 | "scene-0120", "scene-0123", "scene-0124", "scene-0125", "scene-0126",
21 | "scene-0127", "scene-0128", "scene-0129", "scene-0130", "scene-0131",
22 | "scene-0132", "scene-0133", "scene-0134", "scene-0135", "scene-0138",
23 | "scene-0149", "scene-0150", "scene-0151", "scene-0154", "scene-0155",
24 | "scene-0157", "scene-0158", "scene-0159", "scene-0161", "scene-0162",
25 | "scene-0163", "scene-0164", "scene-0165", "scene-0166", "scene-0167",
26 | "scene-0168", "scene-0170", "scene-0171", "scene-0172", "scene-0173",
27 | "scene-0174", "scene-0175", "scene-0176", "scene-0177", "scene-0178",
28 | "scene-0179", "scene-0180", "scene-0181", "scene-0182", "scene-0183",
29 | "scene-0185", "scene-0187", "scene-0188", "scene-0190", "scene-0191",
30 | "scene-0192", "scene-0193", "scene-0194", "scene-0195", "scene-0196",
31 | "scene-0199", "scene-0200", "scene-0202", "scene-0203", "scene-0204",
32 | "scene-0206", "scene-0207", "scene-0208", "scene-0209", "scene-0210",
33 | "scene-0211", "scene-0212", "scene-0213", "scene-0214", "scene-0218",
34 | "scene-0219", "scene-0220", "scene-0221", "scene-0222", "scene-0224",
35 | "scene-0225", "scene-0226", "scene-0227", "scene-0228", "scene-0229",
36 | "scene-0230", "scene-0231", "scene-0232", "scene-0233", "scene-0234",
37 | "scene-0235", "scene-0236", "scene-0237", "scene-0238", "scene-0239",
38 | "scene-0240", "scene-0241", "scene-0242", "scene-0243", "scene-0244",
39 | "scene-0245", "scene-0246", "scene-0247", "scene-0248", "scene-0249",
40 | "scene-0250", "scene-0251", "scene-0252", "scene-0253", "scene-0254",
41 | "scene-0255", "scene-0256", "scene-0257", "scene-0258", "scene-0259",
42 | "scene-0260", "scene-0261", "scene-0262", "scene-0263", "scene-0264",
43 | "scene-0268", "scene-0270", "scene-0271", "scene-0272", "scene-0273",
44 | "scene-0274", "scene-0275", "scene-0276", "scene-0277", "scene-0278",
45 | "scene-0283", "scene-0284", "scene-0285", "scene-0286", "scene-0287",
46 | "scene-0288", "scene-0289", "scene-0290", "scene-0291", "scene-0292",
47 | "scene-0293", "scene-0294", "scene-0295", "scene-0296", "scene-0297",
48 | "scene-0298", "scene-0299", "scene-0300", "scene-0301", "scene-0302",
49 | "scene-0303", "scene-0304", "scene-0305", "scene-0306", "scene-0315",
50 | "scene-0316", "scene-0317", "scene-0318", "scene-0321", "scene-0323",
51 | "scene-0324", "scene-0328", "scene-0329", "scene-0330", "scene-0331",
52 | "scene-0332", "scene-0344", "scene-0345", "scene-0346", "scene-0349",
53 | "scene-0350", "scene-0351", "scene-0352", "scene-0353", "scene-0354",
54 | "scene-0355", "scene-0356", "scene-0357", "scene-0358", "scene-0359",
55 | "scene-0360", "scene-0361", "scene-0362", "scene-0363", "scene-0364",
56 | "scene-0365", "scene-0367", "scene-0370", "scene-0371", "scene-0372",
57 | "scene-0373", "scene-0374", "scene-0375", "scene-0376", "scene-0377",
58 | "scene-0379", "scene-0380", "scene-0381", "scene-0382", "scene-0383",
59 | "scene-0384", "scene-0385", "scene-0386", "scene-0388", "scene-0399",
60 | "scene-0400", "scene-0401", "scene-0402", "scene-0403", "scene-0405",
61 | "scene-0406", "scene-0407", "scene-0408", "scene-0420", "scene-0421",
62 | "scene-0422", "scene-0423", "scene-0424", "scene-0425", "scene-0426",
63 | "scene-0427", "scene-0428", "scene-0429", "scene-0430", "scene-0431",
64 | "scene-0432", "scene-0433", "scene-0434", "scene-0435", "scene-0436",
65 | "scene-0437", "scene-0438", "scene-0439", "scene-0440", "scene-0441",
66 | "scene-0442", "scene-0443", "scene-0444", "scene-0445", "scene-0446",
67 | "scene-0447", "scene-0448", "scene-0449", "scene-0450", "scene-0451",
68 | "scene-0452", "scene-0453", "scene-0454", "scene-0455", "scene-0456",
69 | "scene-0457", "scene-0458", "scene-0459", "scene-0461", "scene-0462",
70 | "scene-0463", "scene-0464", "scene-0465", "scene-0467", "scene-0468",
71 | "scene-0469", "scene-0471", "scene-0472", "scene-0474", "scene-0475",
72 | "scene-0476", "scene-0477", "scene-0478", "scene-0479", "scene-0480",
73 | "scene-0499", "scene-0500", "scene-0501", "scene-0502", "scene-0504",
74 | "scene-0505", "scene-0506", "scene-0507", "scene-0508", "scene-0509",
75 | "scene-0510", "scene-0511", "scene-0512", "scene-0513", "scene-0514",
76 | "scene-0515", "scene-0517", "scene-0518", "scene-0519", "scene-0520",
77 | "scene-0521", "scene-0522", "scene-0523", "scene-0524", "scene-0552",
78 | "scene-0553", "scene-0554", "scene-0555", "scene-0559", "scene-0560",
79 | "scene-0561", "scene-0562", "scene-0563", "scene-0564", "scene-0565",
80 | "scene-0584", "scene-0585", "scene-0586", "scene-0587", "scene-0588",
81 | "scene-0589", "scene-0590", "scene-0591", "scene-0592", "scene-0593",
82 | "scene-0594", "scene-0595", "scene-0596", "scene-0597", "scene-0598",
83 | "scene-0599", "scene-0600", "scene-0625", "scene-0626", "scene-0627",
84 | "scene-0629", "scene-0630", "scene-0632", "scene-0633", "scene-0634",
85 | "scene-0635", "scene-0636", "scene-0637", "scene-0638", "scene-0639",
86 | "scene-0640", "scene-0652", "scene-0653", "scene-0654", "scene-0655",
87 | "scene-0656", "scene-0657", "scene-0658", "scene-0659", "scene-0660",
88 | "scene-0661", "scene-0662", "scene-0663", "scene-0664", "scene-0665",
89 | "scene-0666", "scene-0667", "scene-0668", "scene-0669", "scene-0670",
90 | "scene-0671", "scene-0672", "scene-0673", "scene-0674", "scene-0675",
91 | "scene-0676", "scene-0677", "scene-0678", "scene-0679", "scene-0681",
92 | "scene-0683", "scene-0684", "scene-0685", "scene-0686", "scene-0687",
93 | "scene-0688", "scene-0689", "scene-0695", "scene-0696", "scene-0697",
94 | "scene-0698", "scene-0700", "scene-0701", "scene-0703", "scene-0704",
95 | "scene-0705", "scene-0706", "scene-0707", "scene-0708", "scene-0709",
96 | "scene-0710", "scene-0711", "scene-0712", "scene-0713", "scene-0714",
97 | "scene-0715", "scene-0716", "scene-0717", "scene-0718", "scene-0719",
98 | "scene-0726", "scene-0727", "scene-0728", "scene-0730", "scene-0731",
99 | "scene-0733", "scene-0734", "scene-0735", "scene-0736", "scene-0737",
100 | "scene-0738", "scene-0780", "scene-0781", "scene-0782", "scene-0783",
101 | "scene-0784", "scene-0786", "scene-0787", "scene-0789", "scene-0790",
102 | "scene-0791", "scene-0792", "scene-0802", "scene-0806", "scene-0808",
103 | "scene-0809", "scene-0810", "scene-0811", "scene-0812", "scene-0813",
104 | "scene-0815", "scene-0816", "scene-0817", "scene-0819", "scene-0820",
105 | "scene-0821", "scene-0822", "scene-0847", "scene-0848", "scene-0849",
106 | "scene-0850", "scene-0851", "scene-0852", "scene-0853", "scene-0854",
107 | "scene-0855", "scene-0856", "scene-0858", "scene-0860", "scene-0861",
108 | "scene-0862", "scene-0863", "scene-0864", "scene-0865", "scene-0866",
109 | "scene-0868", "scene-0869", "scene-0870", "scene-0871", "scene-0872",
110 | "scene-0873", "scene-0875", "scene-0876", "scene-0877", "scene-0878",
111 | "scene-0880", "scene-0882", "scene-0883", "scene-0884", "scene-0885",
112 | "scene-0886", "scene-0887", "scene-0888", "scene-0889", "scene-0890",
113 | "scene-0891", "scene-0892", "scene-0893", "scene-0894", "scene-0895",
114 | "scene-0896", "scene-0897", "scene-0898", "scene-0899", "scene-0900",
115 | "scene-0901", "scene-0902", "scene-0903", "scene-0904", "scene-0905",
116 | "scene-0906", "scene-0907", "scene-0908", "scene-0909", "scene-0916",
117 | "scene-0917", "scene-0921", "scene-0922", "scene-0923", "scene-0925",
118 | "scene-0926", "scene-0927", "scene-0928", "scene-0929", "scene-0930",
119 | "scene-0931", "scene-0945", "scene-0947", "scene-0949", "scene-0952",
120 | "scene-0953", "scene-0955", "scene-0956", "scene-0957", "scene-0958",
121 | "scene-0959", "scene-0960", "scene-0961", "scene-0966", "scene-0967",
122 | "scene-0968", "scene-0969", "scene-0971", "scene-0972", "scene-0975",
123 | "scene-0976", "scene-0977", "scene-0978", "scene-0979", "scene-0980",
124 | "scene-0981", "scene-0982", "scene-0983", "scene-0984", "scene-0988",
125 | "scene-0989", "scene-0990", "scene-0991", "scene-0992", "scene-0994",
126 | "scene-0995", "scene-0996", "scene-0997", "scene-0998", "scene-0999",
127 | "scene-1000", "scene-1001", "scene-1004", "scene-1005", "scene-1006",
128 | "scene-1007", "scene-1008", "scene-1009", "scene-1010", "scene-1011",
129 | "scene-1012", "scene-1013", "scene-1014", "scene-1015", "scene-1019",
130 | "scene-1020", "scene-1021", "scene-1022", "scene-1023", "scene-1024",
131 | "scene-1025", "scene-1044", "scene-1045", "scene-1046", "scene-1047",
132 | "scene-1048", "scene-1049", "scene-1050", "scene-1051", "scene-1052",
133 | "scene-1053", "scene-1054", "scene-1064", "scene-1065", "scene-1066",
134 | "scene-1067", "scene-1068", "scene-1069", "scene-1070", "scene-1071",
135 | "scene-1072", "scene-1073", "scene-1074", "scene-1075", "scene-1076",
136 | "scene-1077", "scene-1078", "scene-1079", "scene-1080", "scene-1081",
137 | "scene-1082", "scene-1083", "scene-1084", "scene-1085", "scene-1086",
138 | "scene-1087", "scene-1088", "scene-1089", "scene-1090", "scene-1091",
139 | "scene-1092", "scene-1093", "scene-1094", "scene-1095", "scene-1096",
140 | "scene-1097", "scene-1098", "scene-1099", "scene-1100", "scene-1101",
141 | "scene-1102", "scene-1104", "scene-1105", "scene-1106", "scene-1107",
142 | "scene-1108", "scene-1109", "scene-1110"]
143 |
144 | VAL_SCENES = [
145 | "scene-0001", "scene-0010", "scene-0011", "scene-0020", "scene-0038",
146 | "scene-0041", "scene-0053", "scene-0054", "scene-0121", "scene-0122",
147 | "scene-0139", "scene-0152", "scene-0160", "scene-0184", "scene-0269",
148 | "scene-0347", "scene-0348", "scene-0366", "scene-0368", "scene-0369",
149 | "scene-0378", "scene-0389", "scene-0390", "scene-0391", "scene-0392",
150 | "scene-0393", "scene-0394", "scene-0395", "scene-0396", "scene-0397",
151 | "scene-0398", "scene-0411", "scene-0412", "scene-0413", "scene-0414",
152 | "scene-0415", "scene-0416", "scene-0417", "scene-0418", "scene-0419",
153 | "scene-0525", "scene-0526", "scene-0527", "scene-0528", "scene-0529",
154 | "scene-0530", "scene-0531", "scene-0532", "scene-0533", "scene-0534",
155 | "scene-0535", "scene-0536", "scene-0537", "scene-0538", "scene-0539",
156 | "scene-0541", "scene-0542", "scene-0543", "scene-0544", "scene-0545",
157 | "scene-0546", "scene-0556", "scene-0557", "scene-0558", "scene-0566",
158 | "scene-0568", "scene-0570", "scene-0571", "scene-0572", "scene-0573",
159 | "scene-0574", "scene-0575", "scene-0576", "scene-0577", "scene-0578",
160 | "scene-0580", "scene-0582", "scene-0583", "scene-0642", "scene-0643",
161 | "scene-0644", "scene-0645", "scene-0646", "scene-0647", "scene-0648",
162 | "scene-0649", "scene-0650", "scene-0651", "scene-0739", "scene-0740",
163 | "scene-0741", "scene-0744", "scene-0746", "scene-0747", "scene-0749",
164 | "scene-0750", "scene-0751", "scene-0752", "scene-0757", "scene-0758",
165 | "scene-0759", "scene-0760", "scene-0761", "scene-0762", "scene-0763",
166 | "scene-0764", "scene-0765", "scene-0767", "scene-0768", "scene-0769",
167 | "scene-0770", "scene-0771", "scene-0775", "scene-0777", "scene-0778",
168 | "scene-0794", "scene-0795", "scene-0796", "scene-0797", "scene-0798",
169 | "scene-0799", "scene-0800", "scene-0803", "scene-0804", "scene-0911",
170 | "scene-0912", "scene-0913", "scene-0914", "scene-0915", "scene-0919",
171 | "scene-0920", "scene-0924", "scene-0962", "scene-0963", "scene-1002",
172 | "scene-1003", "scene-1016", "scene-1017", "scene-1018", "scene-1055",
173 | "scene-1056", "scene-1057", "scene-1058", "scene-1059", "scene-1060",
174 | "scene-1061", "scene-1062", "scene-1063"]
175 |
176 |
177 | CALIBRATION_SCENES = [
178 | "scene-0852", "scene-0429", "scene-0956", "scene-0194", "scene-0811",
179 | "scene-1110", "scene-1107", "scene-0294", "scene-0900", "scene-0596",
180 | "scene-0296", "scene-0885", "scene-0866", "scene-0105", "scene-0782",
181 | "scene-0191", "scene-0876", "scene-0133", "scene-0231", "scene-0847",
182 | "scene-0363", "scene-0026", "scene-0791", "scene-0909", "scene-0002",
183 | "scene-0283", "scene-0007", "scene-0251", "scene-1100", "scene-0668",
184 | "scene-0584", "scene-0287", "scene-0260", "scene-0171", "scene-0789",
185 | "scene-0108", "scene-0190", "scene-0206", "scene-0635", "scene-0815",
186 | "scene-0058", "scene-0710", "scene-0302", "scene-0639", "scene-0166",
187 | "scene-0094", "scene-0735", "scene-0321", "scene-1091", "scene-0344"
188 | ]
--------------------------------------------------------------------------------
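The scene lists above partition the nuScenes trainval release by scene name. A minimal sketch of how they could be used to select scene records; the helper below is hypothetical (not part of the repo) and assumes the nuscenes-devkit API:

from nuscenes import NuScenes

from src.data.nuscenes.splits import TRAIN_SCENES, VAL_SCENES


def scenes_for_split(nusc, split_names):
    # Keep only the scene records whose names appear in the requested split.
    wanted = set(split_names)
    return [scene for scene in nusc.scene if scene["name"] in wanted]


# nusc = NuScenes(version="v1.0-trainval", dataroot="/path/to/nuscenes")
# train_scenes = scenes_for_split(nusc, TRAIN_SCENES)
# val_scenes = scenes_for_split(nusc, VAL_SCENES)
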
/src/data/nuscenes/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from shapely import geometry, affinity
4 | from pyquaternion import Quaternion
5 |
6 | from nuscenes.eval.detection.utils import category_to_detection_name
7 | from nuscenes.eval.detection.constants import DETECTION_NAMES
8 | from nuscenes.utils.data_classes import LidarPointCloud
9 |
10 | from ..utils import transform_polygon, render_polygon, transform
11 |
12 | CAMERA_NAMES = ['CAM_FRONT', 'CAM_FRONT_LEFT', 'CAM_FRONT_RIGHT',
13 | 'CAM_BACK_LEFT', 'CAM_BACK_RIGHT', 'CAM_BACK']
14 |
15 | NUSCENES_CLASS_NAMES = [
16 | 'drivable_area', 'ped_crossing', 'walkway', 'carpark', 'car', 'truck',
17 | 'bus', 'trailer', 'construction_vehicle', 'pedestrian', 'motorcycle',
18 | 'bicycle', 'traffic_cone', 'barrier'
19 | ]
20 |
21 | STATIC_CLASSES = ['drivable_area', 'ped_crossing', 'walkway', 'carpark_area']
22 |
23 | LOCATIONS = ['boston-seaport', 'singapore-onenorth', 'singapore-queenstown',
24 | 'singapore-hollandvillage']
25 |
26 |
27 | def iterate_samples(nuscenes, start_token):
28 | sample_token = start_token
29 | while sample_token != '':
30 | sample = nuscenes.get('sample', sample_token)
31 | yield sample
32 | sample_token = sample['next']
33 |
34 |
35 | def get_map_masks(nuscenes, map_data, sample_data, extents, resolution):
36 |
37 | # Render each layer sequentially
38 | layers = [get_layer_mask(nuscenes, polys, sample_data, extents,
39 | resolution) for layer, polys in map_data.items()]
40 |
41 | return np.stack(layers, axis=0)
42 |
43 |
44 | def get_layer_mask(nuscenes, polygons, sample_data, extents, resolution):
45 |
46 | # Get the 2D affine transform from bev coords to map coords
47 | tfm = get_sensor_transform(nuscenes, sample_data)[[0, 1, 3]][:, [0, 2, 3]]
48 | inv_tfm = np.linalg.inv(tfm)
49 |
50 | # Create a patch representing the birds-eye-view region in map coordinates
51 | map_patch = geometry.box(*extents)
52 | map_patch = transform_polygon(map_patch, tfm)
53 |
54 | # Initialise the map mask
55 | x1, z1, x2, z2 = extents
56 | mask = np.zeros((int((z2 - z1) / resolution), int((x2 - x1) / resolution)),
57 | dtype=np.uint8)
58 |
59 | # Find all polygons which intersect with the area of interest
60 | for polygon in polygons.query(map_patch):
61 |
62 | polygon = polygon.intersection(map_patch)
63 |
64 | # Transform into map coordinates
65 | polygon = transform_polygon(polygon, inv_tfm)
66 |
67 | # Render the polygon to the mask
68 | render_shapely_polygon(mask, polygon, extents, resolution)
69 |
70 | return mask.astype(bool)
71 |
72 |
73 |
74 |
75 | def get_object_masks(nuscenes, sample_data, extents, resolution):
76 |
77 | # Initialize object masks
78 | nclass = len(DETECTION_NAMES) + 1
79 | grid_width = int((extents[2] - extents[0]) / resolution)
80 | grid_height = int((extents[3] - extents[1]) / resolution)
81 | masks = np.zeros((nclass, grid_height, grid_width), dtype=np.uint8)
82 |
83 | # Get the 2D affine transform from bev coords to map coords
84 | tfm = get_sensor_transform(nuscenes, sample_data)[[0, 1, 3]][:, [0, 2, 3]]
85 | inv_tfm = np.linalg.inv(tfm)
86 |
87 | for box in nuscenes.get_boxes(sample_data['token']):
88 |
89 | # Get the index of the class
90 | det_name = category_to_detection_name(box.name)
91 | if det_name not in DETECTION_NAMES:
92 | class_id = -1
93 | else:
94 | class_id = DETECTION_NAMES.index(det_name)
95 |
96 | # Get bounding box coordinates in the grid coordinate frame
97 | bbox = box.bottom_corners()[:2]
98 | local_bbox = np.dot(inv_tfm[:2, :2], bbox).T + inv_tfm[:2, 2]
99 |
100 | # Render the rotated bounding box to the mask
101 | render_polygon(masks[class_id], local_bbox, extents, resolution)
102 |
103 | return masks.astype(bool)
104 |
105 |
106 | def get_sensor_transform(nuscenes, sample_data):
107 |
108 | # Load sensor transform data
109 | sensor = nuscenes.get(
110 | 'calibrated_sensor', sample_data['calibrated_sensor_token'])
111 | sensor_tfm = make_transform_matrix(sensor)
112 |
113 | # Load ego pose data
114 | pose = nuscenes.get('ego_pose', sample_data['ego_pose_token'])
115 | pose_tfm = make_transform_matrix(pose)
116 |
117 | return np.dot(pose_tfm, sensor_tfm)
118 |
119 |
120 | def load_point_cloud(nuscenes, sample_data):
121 |
122 | # Load point cloud
123 | lidar_path = os.path.join(nuscenes.dataroot, sample_data['filename'])
124 | pcl = LidarPointCloud.from_file(lidar_path)
125 | return pcl.points[:3, :].T
126 |
127 |
128 | def make_transform_matrix(record):
129 | """
130 | Create a 4x4 transform matrix from a calibrated_sensor or ego_pose record
131 | """
132 | transform = np.eye(4)
133 | transform[:3, :3] = Quaternion(record['rotation']).rotation_matrix
134 | transform[:3, 3] = np.array(record['translation'])
135 | return transform
136 |
137 |
138 | def render_shapely_polygon(mask, polygon, extents, resolution):
139 |
140 | if polygon.geom_type == 'Polygon':
141 |
142 | # Render exteriors
143 | render_polygon(mask, polygon.exterior.coords, extents, resolution, 1)
144 |
145 | # Render interiors
146 | for hole in polygon.interiors:
147 | render_polygon(mask, hole.coords, extents, resolution, 0)
148 |
149 | # Handle the case of compound shapes
150 | else:
151 | for poly in polygon.geoms:
152 | render_shapely_polygon(mask, poly, extents, resolution)
153 |
154 |
155 |
156 |
157 |
--------------------------------------------------------------------------------
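A minimal, self-contained sketch of the transform composition used by get_sensor_transform() above: each nuScenes record (quaternion rotation plus translation) becomes a 4x4 homogeneous matrix, and pose @ sensor maps sensor-frame points into the global map frame. The helper mirrors make_transform_matrix() above and the records are made up for illustration:

import numpy as np
from pyquaternion import Quaternion


def make_transform_matrix(record):
    tfm = np.eye(4)
    tfm[:3, :3] = Quaternion(record["rotation"]).rotation_matrix
    tfm[:3, 3] = np.array(record["translation"])
    return tfm


# Illustrative calibrated_sensor and ego_pose records (quaternions are [w, x, y, z])
sensor = {"rotation": [0.5, -0.5, 0.5, -0.5], "translation": [1.7, 0.0, 1.5]}
pose = {"rotation": [1.0, 0.0, 0.0, 0.0], "translation": [100.0, 200.0, 0.0]}

sensor_to_map = make_transform_matrix(pose) @ make_transform_matrix(sensor)

point_sensor = np.array([0.0, 0.0, 10.0, 1.0])  # 10 m along the sensor z-axis (homogeneous)
point_map = sensor_to_map @ point_sensor
print(point_map[:3])
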
/src/data/utils.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torch
4 | from shapely import affinity
5 |
6 | def decode_binary_labels(labels, nclass):
7 | bits = torch.pow(2, torch.arange(nclass))
8 | return (labels & bits.view(-1, 1, 1)) > 0
9 |
10 |
11 | def encode_binary_labels(masks):
12 | bits = np.power(2, np.arange(len(masks), dtype=np.int32))
13 | return (masks.astype(np.int32) * bits.reshape(-1, 1, 1)).sum(0)
14 |
15 |
16 | def transform(matrix, vectors):
17 | vectors = np.dot(matrix[:-1, :-1], vectors.T)
18 | vectors = vectors.T + matrix[:-1, -1]
19 | return vectors
20 |
21 |
22 | def transform_polygon(polygon, affine):
23 | """
24 | Transform a 2D polygon
25 | """
26 | a, b, tx, c, d, ty = affine.flatten()[:6]
27 | return affinity.affine_transform(polygon, [a, b, c, d, tx, ty])
28 |
29 |
30 | def render_polygon(mask, polygon, extents, resolution, value=1):
31 | if len(polygon) == 0:
32 | return
33 | polygon = (polygon - np.array(extents[:2])) / resolution
34 | polygon = np.ascontiguousarray(polygon).round().astype(np.int32)
35 | cv2.fillConvexPoly(mask, polygon, value)
36 |
37 |
38 | def get_visible_mask(instrinsics, image_width, extents, resolution):
39 |
40 | # Get calibration parameters
41 | fu, cu = instrinsics[0, 0], instrinsics[0, 2]
42 |
43 | # Construct a grid of image coordinates
44 | x1, z1, x2, z2 = extents
45 | x, z = np.arange(x1, x2, resolution), np.arange(z1, z2, resolution)
46 | ucoords = x / z[:, None] * fu + cu
47 |
48 | # Return all points which lie within the camera bounds
49 | return (ucoords >= 0) & (ucoords < image_width)
50 |
51 |
52 | def get_occlusion_mask(points, extents, resolution):
53 |
54 | x1, z1, x2, z2 = extents
55 |
56 | # A 'ray' is defined by the ratio between x and z coordinates
57 | ray_width = resolution / z2
58 | ray_offset = x1 / ray_width
59 | max_rays = int((x2 - x1) / ray_width)
60 |
61 | # Group LiDAR points into bins
62 | rayid = np.round(points[:, 0] / points[:, 2] / ray_width - ray_offset)
63 | depth = points[:, 2]
64 |
65 | # Ignore rays which do not correspond to any grid cells in the BEV
66 | valid = (rayid > 0) & (rayid < max_rays) & (depth > 0)
67 | rayid = rayid[valid]
68 | depth = depth[valid]
69 |
70 | # Find the LiDAR point with maximum depth within each bin
71 | max_depth = np.zeros((max_rays,))
72 | np.maximum.at(max_depth, rayid.astype(np.int32), depth)
73 |
74 | # For each bev grid point, sample the max depth along the corresponding ray
75 | x = np.arange(x1, x2, resolution)
76 | z = np.arange(z1, z2, resolution)[:, None]
77 | grid_rayid = np.round(x / z / ray_width - ray_offset).astype(np.int32)
78 | grid_max_depth = max_depth[grid_rayid]
79 |
80 | # A grid position is considered occluded if there are no LiDAR points
81 | # passing through it
82 | occluded = grid_max_depth < z
83 | return occluded
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
--------------------------------------------------------------------------------
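A minimal round-trip sketch of encode_binary_labels() / decode_binary_labels() above: a stack of per-class binary masks is packed into a single integer image (one bit per class) and unpacked again. It assumes the repo is on PYTHONPATH and that the module's own dependencies (OpenCV, Shapely) are installed; the shapes are illustrative:

import numpy as np
import torch

from src.data.utils import encode_binary_labels, decode_binary_labels

nclass, h, w = 4, 3, 3
masks = np.random.rand(nclass, h, w) > 0.5     # [nclass, H, W] boolean masks

encoded = encode_binary_labels(masks)          # [H, W] integer bit-field
decoded = decode_binary_labels(torch.from_numpy(encoded), nclass)

assert torch.equal(decoded, torch.from_numpy(masks))
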
/src/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avishkarsaha/translating-images-into-maps/a452fe7a1eb0062f96133d14380d504e7bc0c8d1/src/model/__init__.py
--------------------------------------------------------------------------------
/src/model/axial_attention.py:
--------------------------------------------------------------------------------
1 | """
2 | Copied from https://github.com/lucidrains/axial-attention.git
3 | """
4 |
5 | import torch
6 | from torch import nn
7 | from operator import itemgetter
8 |
9 |
10 | def map_el_ind(arr, ind):
11 | return list(map(itemgetter(ind), arr))
12 |
13 |
14 | def sort_and_return_indices(arr):
15 | indices = [ind for ind in range(len(arr))]
16 | arr = zip(arr, indices)
17 | arr = sorted(arr)
18 | return map_el_ind(arr, 0), map_el_ind(arr, 1)
19 |
20 |
21 | # calculates the permutation to bring the input tensor to something attend-able
22 | # also calculates the inverse permutation to bring the tensor back to its original shape
23 |
24 |
25 | def calculate_permutations(num_dimensions, emb_dim):
26 | total_dimensions = num_dimensions + 2
27 | emb_dim = emb_dim if emb_dim > 0 else (emb_dim + total_dimensions)
28 | axial_dims = [ind for ind in range(1, total_dimensions) if ind != emb_dim]
29 |
30 | permutations = []
31 |
32 | for axial_dim in axial_dims:
33 | last_two_dims = [axial_dim, emb_dim]
34 | dims_rest = set(range(0, total_dimensions)) - set(last_two_dims)
35 | permutation = [*dims_rest, *last_two_dims]
36 | permutations.append(permutation)
37 |
38 | return permutations
39 |
40 |
41 | # helper classes
42 |
43 |
44 | class Rezero(nn.Module):
45 | def __init__(self, fn):
46 | super().__init__()
47 | self.fn = fn
48 | self.g = nn.Parameter(torch.tensor(0.0))
49 |
50 | def forward(self, x):
51 | return self.fn(x) * self.g
52 |
53 |
54 | class Sequential(nn.Module):
55 | def __init__(self, blocks):
56 | super().__init__()
57 | self.blocks = blocks
58 |
59 | def forward(self, x):
60 | for f, g in self.blocks:
61 | x = x + f(x) + g(x)
62 | return x
63 |
64 |
65 | class PermuteToFrom(nn.Module):
66 | def __init__(self, permutation, fn):
67 | super().__init__()
68 | self.fn = fn
69 | _, inv_permutation = sort_and_return_indices(permutation)
70 | self.permutation = permutation
71 | self.inv_permutation = inv_permutation
72 |
73 | def forward(self, x, **kwargs):
74 | axial = x.permute(*self.permutation).contiguous()
75 |
76 | shape = axial.shape
77 | *_, t, d = shape
78 |
79 | # merge all but axial dimension
80 | axial = axial.reshape(-1, t, d)
81 |
82 | # attention
83 | axial = self.fn(axial, **kwargs)
84 |
85 | # restore to original shape and permutation
86 | axial = axial.reshape(*shape)
87 | axial = axial.permute(*self.inv_permutation).contiguous()
88 | return axial
89 |
90 |
91 | class AxialPositionalEmbedding(nn.Module):
92 | def __init__(self, emb_dim, emb_dim_index, dimensions):
93 | super().__init__()
94 | parameters = []
95 | total_dimensions = len(dimensions) + 2
96 | ax_dim_indexes = [i for i in range(1, total_dimensions) if i != emb_dim_index]
97 |
98 | for axial_dim, axial_dim_index in zip(dimensions, ax_dim_indexes):
99 | shape = [1] * total_dimensions
100 | shape[emb_dim_index] = emb_dim
101 | shape[axial_dim_index] = axial_dim
102 | parameter = nn.Parameter(torch.randn(*shape))
103 | parameters.append(parameter)
104 |
105 | self.params = nn.ParameterList(parameters)
106 |
107 | def forward(self, x):
108 | for param in self.params:
109 | x = x + param
110 | return x
111 |
112 |
113 | # classic multi-head attention
114 |
115 |
116 | def attention(q, k, v, h):
117 | b, t, d = q.shape
118 | e = d // h
119 |
120 | merge_heads = lambda x: x.reshape(b, -1, h, e).transpose(1, 2).reshape(b * h, -1, e)
121 | q, k, v = map(merge_heads, (q, k, v))
122 | dots = torch.einsum("bie,bje->bij", q, k) * ((d // h) ** -0.5)
123 | dots = dots.softmax(dim=-1)
124 | out = torch.einsum("bij,bje->bie", dots, v)
125 | out = out.reshape(b, h, -1, e).transpose(1, 2).reshape(b, -1, d)
126 | return out
127 |
128 |
129 | class SelfAttention(nn.Module):
130 | def __init__(self, dim, heads, dim_heads=None):
131 | super().__init__()
132 | self.dim_heads = (dim // heads) if dim_heads is None else dim_heads
133 | dim_hidden = self.dim_heads * heads
134 |
135 | self.heads = heads
136 | self.to_q = nn.Linear(dim, dim_hidden, bias=False)
137 | self.to_kv = nn.Linear(dim, 2 * dim_hidden, bias=False)
138 | self.to_out = nn.Linear(dim_hidden, dim)
139 |
140 | def forward(self, x, kv=None):
141 | kv = x if kv is None else kv
142 | q, k, v = (self.to_q(x), *self.to_kv(kv).chunk(2, dim=-1))
143 |
144 | b, t, d, h, e = *q.shape, self.heads, self.dim_heads
145 |
146 | merge_heads = (
147 | lambda x: x.reshape(b, -1, h, e).transpose(1, 2).reshape(b * h, -1, e)
148 | )
149 | q, k, v = map(merge_heads, (q, k, v))
150 |
151 | dots = torch.einsum("bie,bje->bij", q, k) * ((d // h) ** -0.5)
152 | dots = dots.softmax(dim=-1)
153 | out = torch.einsum("bij,bje->bie", dots, v)
154 |
155 | out = out.reshape(b, h, -1, e).transpose(1, 2).reshape(b, -1, d)
156 | out = self.to_out(out)
157 | return out
158 |
159 |
160 | class InducedSetAttention(nn.Module):
161 | def __init__(self, num_queries, dim, heads, dim_heads=None):
162 | super().__init__()
163 | self.queries = nn.Parameter(torch.randn(1, num_queries, dim))
164 | self.attn_in = SelfAttention(dim, heads)
165 | self.attn_out = SelfAttention(dim, heads)
166 |
167 | def forward(self, x):
168 | b = x.shape[0]
169 | q = self.queries.expand(b, -1, -1)
170 | q_out = self.attn_in(q, x)
171 | out = self.attn_out(x, q_out)
172 | return out
173 |
174 |
175 | # axial attention class
176 |
177 |
178 | class AxialAttention(nn.Module):
179 | def __init__(
180 | self,
181 | dim,
182 | num_dimensions=2,
183 | heads=8,
184 | dim_heads=None,
185 | dim_index=-1,
186 | sum_axial_out=True,
187 | ):
188 | assert (
189 | dim % heads
190 | ) == 0, "hidden dimension must be divisible by number of heads"
191 | super().__init__()
192 | self.dim = dim
193 | self.total_dimensions = num_dimensions + 2
194 | self.dim_index = (
195 | dim_index if dim_index > 0 else (dim_index + self.total_dimensions)
196 | )
197 |
198 | attentions = []
199 | for permutation in calculate_permutations(num_dimensions, dim_index):
200 | attentions.append(
201 | PermuteToFrom(permutation, SelfAttention(dim, heads, dim_heads))
202 | )
203 |
204 | self.axial_attentions = nn.ModuleList(attentions)
205 | self.sum_axial_out = sum_axial_out
206 |
207 | def forward(self, x):
208 | assert (
209 | len(x.shape) == self.total_dimensions
210 | ), "input tensor does not have the correct number of dimensions"
211 | assert (
212 | x.shape[self.dim_index] == self.dim
213 | ), "input tensor does not have the correct input dimension"
214 |
215 | if self.sum_axial_out:
216 | return sum(map(lambda axial_attn: axial_attn(x), self.axial_attentions))
217 |
218 | axial_attn = self.axial_attentions[0]
219 | out = axial_attn(x)
220 | for axial_attn in self.axial_attentions[1:]:
221 | out = axial_attn(out)
222 | return out
223 |
--------------------------------------------------------------------------------
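A minimal usage sketch of the AxialAttention module above: attention is applied separately along the height and width axes of a [B, C, H, W] feature map and the per-axis outputs are summed. The sizes below are illustrative:

import torch

from src.model.axial_attention import AxialAttention, AxialPositionalEmbedding

feats = torch.randn(2, 64, 16, 16)             # [B, C, H, W]

pos_emb = AxialPositionalEmbedding(emb_dim=64, emb_dim_index=1, dimensions=(16, 16))
attn = AxialAttention(dim=64, num_dimensions=2, heads=8, dim_index=1)

out = attn(pos_emb(feats))                     # output keeps the input shape
assert out.shape == feats.shape
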
/src/model/backbone_utils.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from torch import nn
3 | from torchvision.ops.feature_pyramid_network import (
4 | FeaturePyramidNetwork,
5 | LastLevelMaxPool,
6 | )
7 |
8 | from torchvision.ops import misc as misc_nn_ops
9 | from torchvision.models._utils import IntermediateLayerGetter
10 | from torch.utils.checkpoint import checkpoint, checkpoint_sequential
11 |
12 | from src.model.intermediate_layer_getter import IntermediateLayerGetter as MidGetter
13 |
14 | from src.model import resnet
15 | from src.model.fpn import FeaturePyramidNetworkRoddick, LLMaxPool
16 |
17 |
18 | class BackboneWithFPN(nn.Module):
19 | """
20 | Adds a FPN on top of a model.
21 | Internally, it uses torchvision.models._utils.IntermediateLayerGetter to
22 | extract a submodel that returns the feature maps specified in return_layers.
23 | The same limitations of IntermediatLayerGetter apply here.
24 | Arguments:
25 | backbone (nn.Module)
26 | return_layers (Dict[name, new_name]): a dict containing the names
27 | of the modules for which the activations will be returned as
28 | the key of the dict, and the value of the dict is the name
29 | of the returned activation (which the user can specify).
30 | in_channels_list (List[int]): number of channels for each feature map
31 | that is returned, in the order they are present in the OrderedDict
32 | out_channels (int): number of channels in the FPN.
33 | Attributes:
34 | out_channels (int): the number of channels in the FPN
35 | """
36 |
37 | def __init__(self, backbone, return_layers, in_channels_list, out_channels):
38 | super(BackboneWithFPN, self).__init__()
39 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
40 | self.fpn = FeaturePyramidNetwork(
41 | in_channels_list=in_channels_list,
42 | out_channels=out_channels,
43 | extra_blocks=LastLevelMaxPool(),
44 | )
45 | self.out_channels = out_channels
46 |
47 | def forward(self, x):
48 | x = self.body(x)
49 | x = self.fpn(x)
50 | return x
51 |
52 |
53 | class BackboneWithFPNRoddick(nn.Module):
54 | """
55 | Adds a FPN on top of a model.
56 | Internally, it uses torchvision.models._utils.IntermediateLayerGetter to
57 | extract a submodel that returns the feature maps specified in return_layers.
58 | The same limitations of IntermediatLayerGetter apply here.
59 | Arguments:
60 | backbone (nn.Module)
61 | return_layers (Dict[name, new_name]): a dict containing the names
62 | of the modules for which the activations will be returned as
63 | the key of the dict, and the value of the dict is the name
64 | of the returned activation (which the user can specify).
65 | in_channels_list (List[int]): number of channels for each feature map
66 | that is returned, in the order they are present in the OrderedDict
67 | out_channels (int): number of channels in the FPN.
68 | Attributes:
69 | out_channels (int): the number of channels in the FPN
70 | """
71 |
72 | def __init__(self, backbone, return_layers, in_channels_list, out_channels):
73 | super(BackboneWithFPNRoddick, self).__init__()
74 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
75 | self.fpn = FeaturePyramidNetworkRoddick(
76 | in_channels_list=in_channels_list,
77 | out_channels=out_channels,
78 | extra_blocks=LLMaxPool(),
79 | )
80 | self.out_channels = out_channels
81 |
82 | def forward(self, x):
83 | x = self.body(x)
84 | x = self.fpn(x)
85 | return x
86 |
87 |
88 | def resnet_fpn_backbone(backbone_name, pretrained):
89 | backbone = resnet.__dict__[backbone_name](
90 | pretrained=pretrained, norm_layer=misc_nn_ops.FrozenBatchNorm2d
91 | )
92 | # freeze layers
93 | for name, parameter in backbone.named_parameters():
94 | if "layer2" not in name and "layer3" not in name and "layer4" not in name:
95 | parameter.requires_grad_(False)
96 |
97 | return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
98 |
99 | in_channels_stage2 = backbone.inplanes // 8
100 | in_channels_list = [
101 | in_channels_stage2,
102 | in_channels_stage2 * 2,
103 | in_channels_stage2 * 4,
104 | in_channels_stage2 * 8,
105 | ]
106 | out_channels = 256
107 | return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels)
108 |
109 |
110 | def resnet_fpn_backbone_roddick(backbone_name, pretrained):
111 | backbone = resnet.__dict__["resnet50Roddick"](
112 | pretrained=pretrained, norm_layer=misc_nn_ops.FrozenBatchNorm2d
113 | )
114 | # freeze layers
115 | for name, parameter in backbone.named_parameters():
116 | if (
117 | "layer2" not in name
118 | and "layer3" not in name
119 | and "layer4" not in name
120 | and "layer5" not in name
121 | ):
122 | parameter.requires_grad_(False)
123 |
124 | return_layers = {
125 | "layer1": "0",
126 | "layer2": "1",
127 | "layer3": "2",
128 | "layer4": "3",
129 | "layer5": "4",
130 | }
131 |
132 | in_channels_stage2 = backbone.inplanes // 8
133 | in_channels_list = [
134 | in_channels_stage2,
135 | in_channels_stage2 * 2,
136 | in_channels_stage2 * 4,
137 | in_channels_stage2 * 8,
138 | in_channels_stage2 * 8,
139 | ]
140 | out_channels = 256
141 | return BackboneWithFPNRoddick(
142 | backbone, return_layers, in_channels_list, out_channels
143 | )
144 |
--------------------------------------------------------------------------------
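A minimal usage sketch of resnet_fpn_backbone() above, assuming src.model.resnet exposes the standard torchvision-style constructors (e.g. resnet50) and that torchvision is installed. The backbone returns an OrderedDict of multi-scale FPN feature maps, one entry per return layer plus the extra pooled level:

import torch

from src.model.backbone_utils import resnet_fpn_backbone

backbone = resnet_fpn_backbone("resnet50", pretrained=False)

image = torch.randn(1, 3, 448, 800)            # [B, 3, H, W], illustrative size
features = backbone(image)

for name, feat in features.items():
    print(name, tuple(feat.shape))             # keys '0'..'3' plus 'pool', each with 256 channels
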
/src/model/bev_transform.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import matplotlib.pyplot as plt
5 | from .. import utils
6 |
7 | EPSILON = 1e-6
8 |
9 |
10 | class BEVT_04(nn.Module):
11 | """
12 | Changes from BEVT:
13 | Condensed bottleneck along channel dimension,
14 | Increased dropout probability,
15 | """
16 |
17 | def __init__(self, in_height, z_max, z_min, cell_size):
18 | super().__init__()
19 |
20 | # [B, C, H, W] --> [B, C_condensed, H, W]
21 | self.conv_c = nn.Sequential(
22 | nn.Conv2d(256, 32, kernel_size=1),
23 | nn.GroupNorm(2, 32),
24 | nn.ReLU(),
25 | )
26 | # [B, C_condensed, W, H] --> [B, C_condensed, W, 1]
27 | self.linear_v = nn.Sequential(
28 | nn.Linear(in_height, 1),
29 | nn.GroupNorm(16, 32),
30 | nn.ReLU(),
31 | nn.Dropout(p=0.5),
32 | )
33 |
34 | # [B, 1, C_condensed, W] --> [B, Z, C_condensed, W]
35 | depth = (z_max - z_min) / cell_size
36 |
37 | self.z_extend = nn.Conv2d(1, int(depth), kernel_size=1)
38 |
39 | self.bev_expand_c = nn.Conv2d(32, 256, kernel_size=1)
40 |
41 | self.z_max = z_max
42 | self.z_min = z_min
43 | self.cell_size = cell_size
44 |
45 | def forward(self, features, calib, grid):
46 | # print(' BEV input shape:', features.shape)
47 | # Condense channel dimensions
48 | # [B, C, H, W] --> [B, C_cond, H, W]
49 | condensed_c = self.conv_c(features)
50 |
51 | # Reshape input tensor then condense input
52 | # features along the y dimension (height)
53 | # [B, C_cond, H, W] --> [B, C_cond, W, H] --> [B, C_cond, W, 1]
54 | bottleneck = self.linear_v(condensed_c.permute(0, 1, 3, 2))
55 |
56 | # condense collapsed v features along channels
57 | # [B, C, W, 1] --> [B, 1, W, C] --> [B, 1, W, C_condensed]
58 | # bottleneck = self.linear_c(condensed_h.permute((0, 3, 2, 1)))
59 |
60 | # Expand the bottleneck along the z dimension (depth)
61 | # [B, 1, C, W] --> [B, Z, C, W] --> [B, C, Z, W]
62 | bev_polar_feats = self.z_extend(bottleneck.permute(0, 3, 1, 2)).permute(
63 | 0, 2, 1, 3
64 | )
65 | # bev_polar_feats_gn = self.gn_polar(bev_polar_feats)
66 |
67 | # print(' BEV polar feats shape:', bev_polar_feats.shape)
68 |
69 | # Normalise grid to [-1, 1]
70 | norm_grid = self.normalise_grid(grid, calib)
71 |
72 | # TODO compute features within voxels of 2D grid instead of at coordinates
73 |
74 | bev_cart_feats = F.grid_sample(bev_polar_feats, norm_grid, align_corners=True)
75 |
76 | bev_cart_feats_expand = self.bev_expand_c(bev_cart_feats)
77 |
78 | # print(' BEV cart feats shape:', bev_cart_feats.shape)
79 | return bev_cart_feats_expand
80 |
81 | def normalise_grid(self, grid, calib):
82 | """
83 | :param grid: BEV grid with coords in the ranges
84 | [grid_h1, grid_h2] and [-grid_w/2, grid_w/2]
85 | :param calib:
86 | :return:
87 | """
88 |
89 | f, cu = calib[:, 0, 0], calib[:, 0, 2]
90 | batch_size = len(calib)
91 |
92 | # Compute positive x dimension at z_max and z_min
93 | # Computed x dimension is half grid width
94 | x_zmax = self.z_max / f * cu
95 | x_zmin = self.z_min / f * cu
96 |
97 | # Compute normalising constant for each row along the z-axis
98 | sample_res_z = (self.z_max - self.z_min) / self.cell_size
99 | sample_res_x = grid.shape[2] / self.cell_size
100 |
101 | norm_z = (
102 | 2
103 | * (grid[..., 1] - grid[..., 1].min())
104 | / (grid[..., 1].max() - grid[..., 1].min())
105 | - 1
106 | )
107 |
108 | norm_scale_x = torch.stack(
109 | [
110 | torch.linspace(float(x_zmin[i]), float(x_zmax[i]), int(sample_res_z))
111 | for i in range(batch_size)
112 | ]
113 | )
114 |
115 | grid_ones = torch.ones_like(grid)
116 | grid_ones[..., 0] *= norm_scale_x.view(batch_size, -1, 1).cuda()
117 |
118 | # Normalise grid to [-1, 1]
119 | norm_grid = grid / grid_ones
120 |
121 | norm_grid[..., 1] = norm_z
122 |
123 | # print(' norm grid', norm_grid[...,0].max(), norm_grid[...,0].min(),
124 | # norm_grid[...,1].max(), norm_grid[...,1].min())
125 |
126 | return norm_grid
127 |
128 |
129 | class BEVT_H(nn.Module):
130 | def __init__(
131 | self,
132 | in_height,
133 | z_max,
134 | z_min,
135 | cell_size,
136 | additions_linear=False,
137 | additions_conv=False,
138 | kernel_h=9,
139 | stride_h=3,
140 | padding_h=4,
141 | ):
142 | super().__init__()
143 |
144 | self.horizontal = nn.Sequential(
145 | nn.Conv2d(
146 | 256,
147 | 256,
148 | kernel_size=[3, kernel_h],
149 | stride=[1, stride_h],
150 | padding=[1, padding_h],
151 | ),
152 | nn.GroupNorm(16, 256),
153 | nn.ReLU(),
154 | )
155 |
156 | # [B, C, W, H] --> [B, C, W, 1]
157 | if additions_linear:
158 | self.linear = nn.Sequential(
159 | nn.Linear(in_height, 1),
160 | nn.GroupNorm(16, 256),
161 | nn.ReLU(),
162 | nn.Dropout(p=0.25),
163 | )
164 | else:
165 | self.linear = nn.Linear(in_height, 1)
166 |
167 | # [B, 1, C, W] --> [B, Z, C, W]
168 | depth = (z_max - z_min) / cell_size
169 |
170 | self.additions_conv = additions_conv
171 | if additions_conv:
172 | self.z_extend = nn.Sequential(
173 | nn.Conv2d(1, int(depth), kernel_size=1),
174 | )
175 | self.gn_polar = nn.GroupNorm(16, 256)
176 | else:
177 | self.z_extend = nn.Conv2d(1, int(depth), kernel_size=1)
178 |
179 | self.z_max = z_max
180 | self.z_min = z_min
181 | self.cell_size = cell_size
182 |
183 | def forward(self, features, calib, grid):
184 |
185 | # Convolve horizontally first
186 | features = self.horizontal(features)
187 |
188 | # Reshape input tensor then collapse input
189 | # features along the y dimension (height)
190 | # [B, C, H, W] --> [B, C, W, H] --> [B, C, W, 1]
191 | bottleneck = self.linear(features.permute(0, 1, 3, 2))
192 |
193 | # Expand the bottleneck along the z dimension (depth)
194 | # [B, 1, C, W] --> [B, Z, C, W] --> [B, C, Z, W]
195 | bev_polar_feats = self.z_extend(bottleneck.permute(0, 3, 1, 2)).permute(
196 | 0, 2, 1, 3
197 | )
198 | # bev_polar_feats_gn = self.gn_polar(bev_polar_feats)
199 |
200 | # print(' BEV polar feats shape:', bev_polar_feats.shape)
201 |
202 | # Normalise grid to [-1, 1]
203 | norm_grid = self.normalise_grid(grid, calib)
204 |
205 | if self.additions_conv:
206 | bev_polar_feats_gn = self.gn_polar(bev_polar_feats)
207 | # Sample BEV polar features at grid locations
208 | # [B, C, Z, W] --> [B, C, Z, X]
209 | bev_cart_feats = F.grid_sample(
210 | bev_polar_feats_gn, norm_grid, align_corners=True
211 | )
212 | # print(' BEV cart feats shape:', bev_cart_feats.shape)
213 | return bev_cart_feats
214 | else:
215 | bev_cart_feats = F.grid_sample(
216 | bev_polar_feats, norm_grid, align_corners=True
217 | )
218 | # print(' BEV cart feats shape:', bev_cart_feats.shape)
219 | return bev_cart_feats
220 |
221 | def normalise_grid(self, grid, calib):
222 | """
223 | :param grid: BEV grid with coords in the ranges
224 | [grid_h1, grid_h2] and [-grid_w/2, grid_w/2]
225 | :param calib:
226 | :return:
227 | """
228 |
229 | f, cu = calib[:, 0, 0], calib[:, 0, 2]
230 | batch_size = len(calib)
231 |
232 | # Compute positive x dimension at z_max and z_min
233 | # Computed x dimension is half grid width
234 | x_zmax = self.z_max / f * cu
235 | x_zmin = self.z_min / f * cu
236 |
237 | # Compute normalising constant for each row along the z-axis
238 | sample_res_z = (self.z_max - self.z_min) / self.cell_size
239 | sample_res_x = grid.shape[2] / self.cell_size
240 |
241 | norm_z = (
242 | 2
243 | * (grid[..., 1] - grid[..., 1].min())
244 | / (grid[..., 1].max() - grid[..., 1].min())
245 | - 1
246 | )
247 |
248 | norm_scale_x = torch.stack(
249 | [
250 | torch.linspace(float(x_zmin[i]), float(x_zmax[i]), int(sample_res_z))
251 | for i in range(batch_size)
252 | ]
253 | )
254 |
255 | grid_ones = torch.ones_like(grid)
256 | grid_ones[..., 0] *= norm_scale_x.view(batch_size, -1, 1).cuda()
257 |
258 | # Normalise grid to [-1, 1]
259 | norm_grid = grid / grid_ones
260 |
261 | norm_grid[..., 1] = norm_z
262 |
263 | return norm_grid
264 |
265 |
266 | class BEVT(nn.Module):
267 | def __init__(
268 | self,
269 | in_height,
270 | z_max,
271 | z_min,
272 | cell_size,
273 | additions_linear=False,
274 | additions_conv=False,
275 | ):
276 | super().__init__()
277 |
278 | # [B, C, W, H] --> [B, C, W, 1]
279 | if additions_linear:
280 | self.linear = nn.Sequential(
281 | nn.Linear(in_height, 1),
282 | nn.GroupNorm(16, 256),
283 | nn.ReLU(),
284 | nn.Dropout(p=0.25),
285 | )
286 | else:
287 | self.linear = nn.Linear(in_height, 1)
288 |
289 | # [B, 1, C, W] --> [B, Z, C, W]
290 | depth = (z_max - z_min) / cell_size
291 |
292 | self.additions_conv = additions_conv
293 | if additions_conv:
294 | self.z_extend = nn.Sequential(
295 | nn.Conv2d(1, int(depth), kernel_size=1),
296 | )
297 | self.gn_polar = nn.GroupNorm(16, 256)
298 | else:
299 | self.z_extend = nn.Conv2d(1, int(depth), kernel_size=1)
300 |
301 | self.z_max = z_max
302 | self.z_min = z_min
303 | self.cell_size = cell_size
304 |
305 | def forward(self, features, calib, grid):
306 | # print(' BEV input shape:', features.shape)
307 | # Reshape input tensor then collapse input
308 | # features along the y dimension (height)
309 | # [B, C, H, W] --> [B, C, W, H] --> [B, C, W, 1]
310 | bottleneck = self.linear(features.permute(0, 1, 3, 2))
311 |
312 | # Expand the bottleneck along the z dimension (depth)
313 | # [B, 1, C, W] --> [B, Z, C, W] --> [B, C, Z, W]
314 | bev_polar_feats = self.z_extend(bottleneck.permute(0, 3, 1, 2)).permute(
315 | 0, 2, 1, 3
316 | )
317 | # bev_polar_feats_gn = self.gn_polar(bev_polar_feats)
318 |
319 | # print(' BEV polar feats shape:', bev_polar_feats.shape)
320 |
321 | # Normalise grid to [-1, 1]
322 | norm_grid = self.normalise_grid(grid, calib)
323 |
324 | if self.additions_conv:
325 | bev_polar_feats_gn = self.gn_polar(bev_polar_feats)
326 | # Sample BEV polar features at grid locations
327 | # [B, C, Z, W] --> [B, C, Z, X]
328 | bev_cart_feats = F.grid_sample(
329 | bev_polar_feats_gn, norm_grid, align_corners=True
330 | )
331 | # print(' BEV cart feats shape:', bev_cart_feats.shape)
332 | return bev_cart_feats
333 | else:
334 | bev_cart_feats = F.grid_sample(
335 | bev_polar_feats, norm_grid, align_corners=True
336 | )
337 | # print(' BEV cart feats shape:', bev_cart_feats.shape)
338 | return bev_cart_feats
339 |
340 | def normalise_grid(self, grid, calib):
341 | """
342 | :param grid: BEV grid with coords in the ranges
343 | [grid_h1, grid_h2] and [-grid_w/2, grid_w/2]
344 | :param calib:
345 | :return:
346 | """
347 |
348 | f, cu = calib[:, 0, 0], calib[:, 0, 2]
349 | batch_size = len(calib)
350 |
351 | # Compute positive x dimension at z_max and z_min
352 | # Computed x dimension is half grid width
353 | x_zmax = self.z_max / f * cu
354 | x_zmin = self.z_min / f * cu
355 |
356 | # Compute normalising constant for each row along the z-axis
357 | sample_res_z = (self.z_max - self.z_min) / self.cell_size
358 | sample_res_x = grid.shape[2] / self.cell_size
359 |
360 | norm_z = (
361 | 2
362 | * (grid[..., 1] - grid[..., 1].min())
363 | / (grid[..., 1].max() - grid[..., 1].min())
364 | - 1
365 | )
366 |
367 | norm_scale_x = torch.stack(
368 | [
369 | torch.linspace(float(x_zmin[i]), float(x_zmax[i]), int(sample_res_z))
370 | for i in range(batch_size)
371 | ]
372 | )
373 |
374 | grid_ones = torch.ones_like(grid)
375 | grid_ones[..., 0] *= norm_scale_x.view(batch_size, -1, 1).cuda()
376 |
377 | # Normalise grid to [-1, 1]
378 | norm_grid = grid / grid_ones
379 |
380 | norm_grid[..., 1] = norm_z
381 |
382 | return norm_grid
383 |
384 |
385 | class sample_polar2cart(nn.Module):
386 | def __init__(
387 | self,
388 | z_max,
389 | z_min,
390 | cell_size,
391 | ):
392 | super().__init__()
393 |
394 | self.z_max = z_max
395 | self.z_min = z_min
396 | self.cell_size = cell_size
397 |
398 | def forward(self, features, calib, grid):
399 |
400 | # Normalise grid to [-1, 1]
401 | norm_grid = self.normalise_grid(grid, calib)
402 |
403 | bev_cart_feats = F.grid_sample(features, norm_grid, align_corners=True)
404 | return bev_cart_feats
405 |
406 | def normalise_grid(self, grid, calib):
407 | """
408 | :param grid: BEV grid with coords in the ranges
409 | [grid_h1, grid_h2] and [-grid_w/2, grid_w/2]
410 | :param calib:
411 | :return:
412 | """
413 |
414 | f, cu = calib[:, 0, 0], calib[:, 0, 2]
415 | batch_size = len(calib)
416 |
417 | # Compute positive x dimension at z_max and z_min
418 | # Computed x dimension is half grid width
419 | x_zmax = self.z_max / f * cu
420 | x_zmin = self.z_min / f * cu
421 |
422 | # Compute normalising constant for each row along the z-axis
423 | sample_res_z = (self.z_max - self.z_min) / self.cell_size
424 | sample_res_x = grid.shape[2] / self.cell_size
425 |
426 | norm_z = (
427 | 2
428 | * (grid[..., 1] - grid[..., 1].min())
429 | / (grid[..., 1].max() - grid[..., 1].min())
430 | - 1
431 | )
432 |
433 | norm_scale_x = torch.stack(
434 | [
435 | torch.linspace(float(x_zmin[i]), float(x_zmax[i]), int(sample_res_z))
436 | for i in range(batch_size)
437 | ]
438 | )
439 |
440 | grid_ones = torch.ones_like(grid)
441 | grid_ones[..., 0] *= norm_scale_x.view(batch_size, -1, 1).cuda()
442 |
443 | # Normalise grid to [-1, 1]
444 | norm_grid = grid / grid_ones
445 |
446 | norm_grid[..., 1] = norm_z
447 |
448 | return norm_grid
449 |
450 |
451 | def integral_image(features):
452 | return torch.cumsum(torch.cumsum(features, dim=-1), dim=-2)
453 |
--------------------------------------------------------------------------------
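A minimal, CPU-only sketch of the shape flow inside BEVT.forward() above: image features are collapsed along the vertical (height) axis into a bottleneck, then expanded along depth with a 1x1 convolution to form a polar BEV feature map, which grid_sample subsequently resamples onto a Cartesian grid. The sizes below are illustrative and the layers are stand-ins for self.linear and self.z_extend:

import torch
import torch.nn as nn

B, C, H, W = 1, 256, 12, 50
z_min, z_max, cell_size = 1.0, 50.0, 0.5
Z = int((z_max - z_min) / cell_size)           # number of depth bins

feats = torch.randn(B, C, H, W)                # [B, C, H, W] image features

collapse_h = nn.Linear(H, 1)                   # [B, C, W, H] -> [B, C, W, 1]
expand_z = nn.Conv2d(1, Z, kernel_size=1)      # [B, 1, C, W] -> [B, Z, C, W]

bottleneck = collapse_h(feats.permute(0, 1, 3, 2))   # [B, C, W, 1]
polar = expand_z(bottleneck.permute(0, 3, 1, 2))     # [B, Z, C, W]
polar = polar.permute(0, 2, 1, 3)                    # [B, C, Z, W] polar BEV features

print(polar.shape)                             # torch.Size([1, 256, 98, 50])
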
/src/model/dla.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | import math
5 | from os.path import join
6 |
7 | import torch
8 | from torch import nn
9 | import torch.utils.model_zoo as model_zoo
10 |
11 | import src.model.dla_dataset as dataset
12 |
13 | BatchNorm = nn.BatchNorm2d
14 |
15 | WEB_ROOT = "http://dl.yf.io/dla/models"
16 |
17 |
18 | def get_model_url(data, name):
19 | return join(WEB_ROOT, data.name, "{}-{}.pth".format(name, data.model_hash[name]))
20 |
21 |
22 | def conv3x3(in_planes, out_planes, stride=1):
23 | "3x3 convolution with padding"
24 | return nn.Conv2d(
25 | in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
26 | )
27 |
28 |
29 | class BasicBlock(nn.Module):
30 | def __init__(self, inplanes, planes, stride=1, dilation=1):
31 | super(BasicBlock, self).__init__()
32 | self.conv1 = nn.Conv2d(
33 | inplanes,
34 | planes,
35 | kernel_size=3,
36 | stride=stride,
37 | padding=dilation,
38 | bias=False,
39 | dilation=dilation,
40 | )
41 | self.bn1 = BatchNorm(planes)
42 | self.relu = nn.ReLU(inplace=True)
43 | self.conv2 = nn.Conv2d(
44 | planes,
45 | planes,
46 | kernel_size=3,
47 | stride=1,
48 | padding=dilation,
49 | bias=False,
50 | dilation=dilation,
51 | )
52 | self.bn2 = BatchNorm(planes)
53 | self.stride = stride
54 |
55 | def forward(self, x, residual=None):
56 | if residual is None:
57 | residual = x
58 |
59 | out = self.conv1(x)
60 | out = self.bn1(out)
61 | out = self.relu(out)
62 |
63 | out = self.conv2(out)
64 | out = self.bn2(out)
65 |
66 | out += residual
67 | out = self.relu(out)
68 |
69 | return out
70 |
71 |
72 | class Bottleneck(nn.Module):
73 | expansion = 2
74 |
75 | def __init__(self, inplanes, planes, stride=1, dilation=1):
76 | super(Bottleneck, self).__init__()
77 | expansion = Bottleneck.expansion
78 | bottle_planes = planes // expansion
79 | self.conv1 = nn.Conv2d(inplanes, bottle_planes, kernel_size=1, bias=False)
80 | self.bn1 = BatchNorm(bottle_planes)
81 | self.conv2 = nn.Conv2d(
82 | bottle_planes,
83 | bottle_planes,
84 | kernel_size=3,
85 | stride=stride,
86 | padding=dilation,
87 | bias=False,
88 | dilation=dilation,
89 | )
90 | self.bn2 = BatchNorm(bottle_planes)
91 | self.conv3 = nn.Conv2d(bottle_planes, planes, kernel_size=1, bias=False)
92 | self.bn3 = BatchNorm(planes)
93 | self.relu = nn.ReLU(inplace=True)
94 | self.stride = stride
95 |
96 | def forward(self, x, residual=None):
97 | if residual is None:
98 | residual = x
99 |
100 | out = self.conv1(x)
101 | out = self.bn1(out)
102 | out = self.relu(out)
103 |
104 | out = self.conv2(out)
105 | out = self.bn2(out)
106 | out = self.relu(out)
107 |
108 | out = self.conv3(out)
109 | out = self.bn3(out)
110 |
111 | out += residual
112 | out = self.relu(out)
113 |
114 | return out
115 |
116 |
117 | class BottleneckX(nn.Module):
118 | expansion = 2
119 | cardinality = 32
120 |
121 | def __init__(self, inplanes, planes, stride=1, dilation=1):
122 | super(BottleneckX, self).__init__()
123 | cardinality = BottleneckX.cardinality
124 | # dim = int(math.floor(planes * (BottleneckV5.expansion / 64.0)))
125 | # bottle_planes = dim * cardinality
126 | bottle_planes = planes * cardinality // 32
127 | self.conv1 = nn.Conv2d(inplanes, bottle_planes, kernel_size=1, bias=False)
128 | self.bn1 = BatchNorm(bottle_planes)
129 | self.conv2 = nn.Conv2d(
130 | bottle_planes,
131 | bottle_planes,
132 | kernel_size=3,
133 | stride=stride,
134 | padding=dilation,
135 | bias=False,
136 | dilation=dilation,
137 | groups=cardinality,
138 | )
139 | self.bn2 = BatchNorm(bottle_planes)
140 | self.conv3 = nn.Conv2d(bottle_planes, planes, kernel_size=1, bias=False)
141 | self.bn3 = BatchNorm(planes)
142 | self.relu = nn.ReLU(inplace=True)
143 | self.stride = stride
144 |
145 | def forward(self, x, residual=None):
146 | if residual is None:
147 | residual = x
148 |
149 | out = self.conv1(x)
150 | out = self.bn1(out)
151 | out = self.relu(out)
152 |
153 | out = self.conv2(out)
154 | out = self.bn2(out)
155 | out = self.relu(out)
156 |
157 | out = self.conv3(out)
158 | out = self.bn3(out)
159 |
160 | out += residual
161 | out = self.relu(out)
162 |
163 | return out
164 |
165 |
166 | class Root(nn.Module):
167 | def __init__(self, in_channels, out_channels, kernel_size, residual):
168 | super(Root, self).__init__()
169 | self.conv = nn.Conv2d(
170 | in_channels,
171 | out_channels,
172 | kernel_size,
173 | stride=1,
174 | bias=False,
175 | padding=(kernel_size - 1) // 2,
176 | )
177 | self.bn = BatchNorm(out_channels)
178 | self.relu = nn.ReLU(inplace=True)
179 | self.residual = residual
180 |
181 | def forward(self, *x):
182 | children = x
183 | x = self.conv(torch.cat(x, 1))
184 | x = self.bn(x)
185 | if self.residual:
186 | x += children[0]
187 | x = self.relu(x)
188 |
189 | return x
190 |
191 |
192 | class Tree(nn.Module):
193 | def __init__(
194 | self,
195 | levels,
196 | block,
197 | in_channels,
198 | out_channels,
199 | stride=1,
200 | level_root=False,
201 | root_dim=0,
202 | root_kernel_size=1,
203 | dilation=1,
204 | root_residual=False,
205 | ):
206 | super(Tree, self).__init__()
207 | if root_dim == 0:
208 | root_dim = 2 * out_channels
209 | if level_root:
210 | root_dim += in_channels
211 | if levels == 1:
212 | self.tree1 = block(in_channels, out_channels, stride, dilation=dilation)
213 | self.tree2 = block(out_channels, out_channels, 1, dilation=dilation)
214 | else:
215 | self.tree1 = Tree(
216 | levels - 1,
217 | block,
218 | in_channels,
219 | out_channels,
220 | stride,
221 | root_dim=0,
222 | root_kernel_size=root_kernel_size,
223 | dilation=dilation,
224 | root_residual=root_residual,
225 | )
226 | self.tree2 = Tree(
227 | levels - 1,
228 | block,
229 | out_channels,
230 | out_channels,
231 | root_dim=root_dim + out_channels,
232 | root_kernel_size=root_kernel_size,
233 | dilation=dilation,
234 | root_residual=root_residual,
235 | )
236 | if levels == 1:
237 | self.root = Root(root_dim, out_channels, root_kernel_size, root_residual)
238 | self.level_root = level_root
239 | self.root_dim = root_dim
240 | self.downsample = None
241 | self.project = None
242 | self.levels = levels
243 | if stride > 1:
244 | self.downsample = nn.MaxPool2d(stride, stride=stride)
245 | if in_channels != out_channels:
246 | self.project = nn.Sequential(
247 | nn.Conv2d(
248 | in_channels, out_channels, kernel_size=1, stride=1, bias=False
249 | ),
250 | BatchNorm(out_channels),
251 | )
252 |
253 | def forward(self, x, residual=None, children=None):
254 | children = [] if children is None else children
255 | bottom = self.downsample(x) if self.downsample else x
256 | residual = self.project(bottom) if self.project else bottom
257 | if self.level_root:
258 | children.append(bottom)
259 | x1 = self.tree1(x, residual)
260 | if self.levels == 1:
261 | x2 = self.tree2(x1)
262 | x = self.root(x2, x1, *children)
263 | else:
264 | children.append(x1)
265 | x = self.tree2(x1, children=children)
266 | return x
267 |
268 |
269 | class DLA(nn.Module):
270 | def __init__(
271 | self,
272 | levels,
273 | channels,
274 | num_classes=1000,
275 | block=BasicBlock,
276 | residual_root=False,
277 | return_levels=False,
278 | pool_size=7,
279 | linear_root=False,
280 | ):
281 | super(DLA, self).__init__()
282 | self.channels = channels
283 | self.return_levels = return_levels
284 | self.num_classes = num_classes
285 | self.base_layer = nn.Sequential(
286 | nn.Conv2d(3, channels[0], kernel_size=7, stride=1, padding=3, bias=False),
287 | BatchNorm(channels[0]),
288 | nn.ReLU(inplace=True),
289 | )
290 | self.level0 = self._make_conv_level(channels[0], channels[0], levels[0])
291 | self.level1 = self._make_conv_level(
292 | channels[0], channels[1], levels[1], stride=2
293 | )
294 | self.level2 = Tree(
295 | levels[2],
296 | block,
297 | channels[1],
298 | channels[2],
299 | 2,
300 | level_root=False,
301 | root_residual=residual_root,
302 | )
303 | self.level3 = Tree(
304 | levels[3],
305 | block,
306 | channels[2],
307 | channels[3],
308 | 2,
309 | level_root=True,
310 | root_residual=residual_root,
311 | )
312 | self.level4 = Tree(
313 | levels[4],
314 | block,
315 | channels[3],
316 | channels[4],
317 | 2,
318 | level_root=True,
319 | root_residual=residual_root,
320 | )
321 | self.level5 = Tree(
322 | levels[5],
323 | block,
324 | channels[4],
325 | channels[5],
326 | 2,
327 | level_root=True,
328 | root_residual=residual_root,
329 | )
330 |
331 | self.avgpool = nn.AvgPool2d(pool_size)
332 | self.fc = nn.Conv2d(
333 | channels[-1], num_classes, kernel_size=1, stride=1, padding=0, bias=True
334 | )
335 |
336 | for m in self.modules():
337 | if isinstance(m, nn.Conv2d):
338 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
339 | m.weight.data.normal_(0, math.sqrt(2.0 / n))
340 | elif isinstance(m, BatchNorm):
341 | m.weight.data.fill_(1)
342 | m.bias.data.zero_()
343 |
344 | def _make_level(self, block, inplanes, planes, blocks, stride=1):
345 | downsample = None
346 | if stride != 1 or inplanes != planes:
347 | downsample = nn.Sequential(
348 | nn.MaxPool2d(stride, stride=stride),
349 | nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False),
350 | BatchNorm(planes),
351 | )
352 |
353 | layers = []
354 | layers.append(block(inplanes, planes, stride, downsample=downsample))
355 | for i in range(1, blocks):
356 | layers.append(block(inplanes, planes))
357 |
358 | return nn.Sequential(*layers)
359 |
360 | def _make_conv_level(self, inplanes, planes, convs, stride=1, dilation=1):
361 | modules = []
362 | for i in range(convs):
363 | modules.extend(
364 | [
365 | nn.Conv2d(
366 | inplanes,
367 | planes,
368 | kernel_size=3,
369 | stride=stride if i == 0 else 1,
370 | padding=dilation,
371 | bias=False,
372 | dilation=dilation,
373 | ),
374 | BatchNorm(planes),
375 | nn.ReLU(inplace=True),
376 | ]
377 | )
378 | inplanes = planes
379 | return nn.Sequential(*modules)
380 |
381 | def forward(self, x):
382 | y = []
383 | x = self.base_layer(x)
384 | for i in range(6):
385 | x = getattr(self, "level{}".format(i))(x)
386 | y.append(x)
387 | if self.return_levels:
388 | return y
389 | else:
390 | x = self.avgpool(x)
391 | x = self.fc(x)
392 | x = x.view(x.size(0), -1)
393 |
394 | return x
395 |
396 | def load_pretrained_model(self, data_name, name):
397 | assert data_name in dataset.__dict__, "No pretrained model for {}".format(
398 | data_name
399 | )
400 | data = dataset.__dict__[data_name]
401 | fc = self.fc
402 | if self.num_classes != data.classes:
403 | self.fc = nn.Conv2d(
404 | self.channels[-1],
405 | data.classes,
406 | kernel_size=1,
407 | stride=1,
408 | padding=0,
409 | bias=True,
410 | )
411 | try:
412 | model_url = get_model_url(data, name)
413 | except KeyError:
414 |             raise ValueError("{} trained on {} does not exist.".format(name, data.name))
415 | self.load_state_dict(model_zoo.load_url(model_url))
416 | self.fc = fc
417 |
418 |
419 | def dla34(pretrained=None, **kwargs): # DLA-34
420 | model = DLA(
421 | [1, 1, 1, 2, 2, 1], [16, 32, 64, 128, 256, 512], block=BasicBlock, **kwargs
422 | )
423 | if pretrained is not None:
424 | model.load_pretrained_model(pretrained, "dla34")
425 | return model
426 |
427 |
428 | def dla46_c(pretrained=None, **kwargs): # DLA-46-C
429 | Bottleneck.expansion = 2
430 | model = DLA(
431 | [1, 1, 1, 2, 2, 1], [16, 32, 64, 64, 128, 256], block=Bottleneck, **kwargs
432 | )
433 | if pretrained is not None:
434 | model.load_pretrained_model(pretrained, "dla46_c")
435 | return model
436 |
437 |
438 | def dla46x_c(pretrained=None, **kwargs): # DLA-X-46-C
439 | BottleneckX.expansion = 2
440 | model = DLA(
441 | [1, 1, 1, 2, 2, 1], [16, 32, 64, 64, 128, 256], block=BottleneckX, **kwargs
442 | )
443 | if pretrained is not None:
444 | model.load_pretrained_model(pretrained, "dla46x_c")
445 | return model
446 |
447 |
448 | def dla60x_c(pretrained=None, **kwargs): # DLA-X-60-C
449 | BottleneckX.expansion = 2
450 | model = DLA(
451 | [1, 1, 1, 2, 3, 1], [16, 32, 64, 64, 128, 256], block=BottleneckX, **kwargs
452 | )
453 | if pretrained is not None:
454 | model.load_pretrained_model(pretrained, "dla60x_c")
455 | return model
456 |
457 |
458 | def dla60(pretrained=None, **kwargs): # DLA-60
459 | Bottleneck.expansion = 2
460 | model = DLA(
461 | [1, 1, 1, 2, 3, 1], [16, 32, 128, 256, 512, 1024], block=Bottleneck, **kwargs
462 | )
463 | if pretrained is not None:
464 | model.load_pretrained_model(pretrained, "dla60")
465 | return model
466 |
467 |
468 | def dla60x(pretrained=None, **kwargs): # DLA-X-60
469 | BottleneckX.expansion = 2
470 | model = DLA(
471 | [1, 1, 1, 2, 3, 1], [16, 32, 128, 256, 512, 1024], block=BottleneckX, **kwargs
472 | )
473 | if pretrained is not None:
474 | model.load_pretrained_model(pretrained, "dla60x")
475 | return model
476 |
477 |
478 | def dla102(pretrained=None, **kwargs): # DLA-102
479 | Bottleneck.expansion = 2
480 | model = DLA(
481 | [1, 1, 1, 3, 4, 1],
482 | [16, 32, 128, 256, 512, 1024],
483 | block=Bottleneck,
484 | residual_root=True,
485 | **kwargs
486 | )
487 | if pretrained is not None:
488 | model.load_pretrained_model(pretrained, "dla102")
489 | return model
490 |
491 |
492 | def dla102x(pretrained=None, **kwargs): # DLA-X-102
493 | BottleneckX.expansion = 2
494 | model = DLA(
495 | [1, 1, 1, 3, 4, 1],
496 | [16, 32, 128, 256, 512, 1024],
497 | block=BottleneckX,
498 | residual_root=True,
499 | **kwargs
500 | )
501 | if pretrained is not None:
502 | model.load_pretrained_model(pretrained, "dla102x")
503 | return model
504 |
505 |
506 | def dla102x2(pretrained=None, **kwargs): # DLA-X-102 64
507 | BottleneckX.cardinality = 64
508 | model = DLA(
509 | [1, 1, 1, 3, 4, 1],
510 | [16, 32, 128, 256, 512, 1024],
511 | block=BottleneckX,
512 | residual_root=True,
513 | **kwargs
514 | )
515 | if pretrained is not None:
516 | model.load_pretrained_model(pretrained, "dla102x2")
517 | return model
518 |
519 |
520 | def dla169(pretrained=None, **kwargs): # DLA-169
521 | Bottleneck.expansion = 2
522 | model = DLA(
523 | [1, 1, 2, 3, 5, 1],
524 | [16, 32, 128, 256, 512, 1024],
525 | block=Bottleneck,
526 | residual_root=True,
527 | **kwargs
528 | )
529 | if pretrained is not None:
530 | model.load_pretrained_model(pretrained, "dla169")
531 | return model
532 |
--------------------------------------------------------------------------------
/src/model/dla_up.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import numpy as np
4 | import torch
5 | from torch import nn
6 |
7 | import src.model.dla as dla
8 |
9 | BatchNorm = nn.BatchNorm2d
10 |
11 |
12 | def set_bn(bn):
13 | global BatchNorm
14 | BatchNorm = bn
15 | dla.BatchNorm = bn
16 |
17 |
18 | class Identity(nn.Module):
19 | def __init__(self):
20 | super(Identity, self).__init__()
21 |
22 | def forward(self, x):
23 | return x
24 |
25 |
26 | def fill_up_weights(up):
27 | w = up.weight.data
28 | f = math.ceil(w.size(2) / 2)
29 | c = (2 * f - 1 - f % 2) / (2.0 * f)
30 | for i in range(w.size(2)):
31 | for j in range(w.size(3)):
32 | w[0, 0, i, j] = (1 - math.fabs(i / f - c)) * (1 - math.fabs(j / f - c))
33 | for c in range(1, w.size(0)):
34 | w[c, 0, :, :] = w[0, 0, :, :]
35 |
36 |
37 | class IDAUp(nn.Module):
38 | def __init__(self, node_kernel, out_dim, channels, up_factors):
39 | super(IDAUp, self).__init__()
40 | self.channels = channels
41 | self.out_dim = out_dim
42 | for i, c in enumerate(channels):
43 | if c == out_dim:
44 | proj = Identity()
45 | else:
46 | proj = nn.Sequential(
47 | nn.Conv2d(c, out_dim, kernel_size=1, stride=1, bias=False),
48 | BatchNorm(out_dim),
49 | nn.ReLU(inplace=True),
50 | )
51 | f = int(up_factors[i])
52 | if f == 1:
53 | up = Identity()
54 | else:
55 | up = nn.ConvTranspose2d(
56 | out_dim,
57 | out_dim,
58 | f * 2,
59 | stride=f,
60 | padding=f // 2,
61 | output_padding=0,
62 | groups=out_dim,
63 | bias=False,
64 | )
65 | fill_up_weights(up)
66 | setattr(self, "proj_" + str(i), proj)
67 | setattr(self, "up_" + str(i), up)
68 |
69 | for i in range(1, len(channels)):
70 | node = nn.Sequential(
71 | nn.Conv2d(
72 | out_dim * 2,
73 | out_dim,
74 | kernel_size=node_kernel,
75 | stride=1,
76 | padding=node_kernel // 2,
77 | bias=False,
78 | ),
79 | BatchNorm(out_dim),
80 | nn.ReLU(inplace=True),
81 | )
82 | setattr(self, "node_" + str(i), node)
83 |
84 | for m in self.modules():
85 | if isinstance(m, nn.Conv2d):
86 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
87 | m.weight.data.normal_(0, math.sqrt(2.0 / n))
88 | elif isinstance(m, BatchNorm):
89 | m.weight.data.fill_(1)
90 | m.bias.data.zero_()
91 |
92 | def forward(self, layers):
93 | assert len(self.channels) == len(layers), "{} vs {} layers".format(
94 | len(self.channels), len(layers)
95 | )
96 | layers = list(layers)
97 | for i, l in enumerate(layers):
98 | upsample = getattr(self, "up_" + str(i))
99 | project = getattr(self, "proj_" + str(i))
100 | layers[i] = upsample(project(l))
101 | x = layers[0]
102 | y = []
103 | for i in range(1, len(layers)):
104 | node = getattr(self, "node_" + str(i))
105 | x = node(torch.cat([x, layers[i]], 1))
106 | y.append(x)
107 | return x, y
108 |
109 |
110 | class DLAUp(nn.Module):
111 | def __init__(self, channels, scales=(1, 2, 4, 8, 16), in_channels=None):
112 | super(DLAUp, self).__init__()
113 | if in_channels is None:
114 | in_channels = channels
115 | self.channels = channels
116 | channels = list(channels)
117 | scales = np.array(scales, dtype=int)
118 | for i in range(len(channels) - 1):
119 | j = -i - 2
120 | setattr(
121 | self,
122 | "ida_{}".format(i),
123 | IDAUp(3, channels[j], in_channels[j:], scales[j:] // scales[j]),
124 | )
125 | scales[j + 1 :] = scales[j]
126 | in_channels[j + 1 :] = [channels[j] for _ in channels[j + 1 :]]
127 |
128 | def forward(self, layers):
129 | layers = list(layers)
130 | assert len(layers) > 1
131 | for i in range(len(layers) - 1):
132 | ida = getattr(self, "ida_{}".format(i))
133 | x, y = ida(layers[-i - 2 :])
134 | layers[-i - 1 :] = y
135 | return x
136 |
137 |
138 | class DLASeg(nn.Module):
139 | def __init__(self, base_name, classes, pretrained_base=None, down_ratio=2):
140 | super(DLASeg, self).__init__()
141 | assert down_ratio in [2, 4, 8, 16]
142 | self.first_level = int(np.log2(down_ratio))
143 | self.base = dla.__dict__[base_name](
144 | pretrained=pretrained_base, return_levels=True
145 | )
146 | channels = self.base.channels
147 | scales = [2 ** i for i in range(len(channels[self.first_level :]))]
148 | self.dla_up = DLAUp(channels[self.first_level :], scales=scales)
149 | print(self.dla_up)
150 | self.fc = nn.Sequential(
151 | nn.Conv2d(
152 | channels[self.first_level],
153 | classes,
154 | kernel_size=1,
155 | stride=1,
156 | padding=0,
157 | bias=True,
158 | )
159 | )
160 | up_factor = 2 ** self.first_level
161 | if up_factor > 1:
162 | up = nn.ConvTranspose2d(
163 | classes,
164 | classes,
165 | up_factor * 2,
166 | stride=up_factor,
167 | padding=up_factor // 2,
168 | output_padding=0,
169 | groups=classes,
170 | bias=False,
171 | )
172 | fill_up_weights(up)
173 | up.weight.requires_grad = False
174 | else:
175 | up = Identity()
176 | self.up = up
177 | self.softmax = nn.LogSoftmax(dim=1)
178 |
179 | for m in self.fc.modules():
180 | if isinstance(m, nn.Conv2d):
181 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
182 | m.weight.data.normal_(0, math.sqrt(2.0 / n))
183 | elif isinstance(m, BatchNorm):
184 | m.weight.data.fill_(1)
185 | m.bias.data.zero_()
186 |
187 | def forward(self, x):
188 | x = self.base(x)
189 | x = self.dla_up(x[self.first_level :])
190 | x = self.fc(x)
191 | y = self.softmax(self.up(x))
192 | return y, x
193 |
194 | def optim_parameters(self, memo=None):
195 | for param in self.base.parameters():
196 | yield param
197 | for param in self.dla_up.parameters():
198 | yield param
199 | for param in self.fc.parameters():
200 | yield param
201 |
202 |
203 | def dla34up(classes, pretrained_base=None, **kwargs):
204 | model = DLASeg("dla34", classes, pretrained_base=pretrained_base, **kwargs)
205 | return model
206 |
207 |
208 | def dla60up(classes, pretrained_base=None, **kwargs):
209 | model = DLASeg("dla60", classes, pretrained_base=pretrained_base, **kwargs)
210 | return model
211 |
212 |
213 | def dla102up(classes, pretrained_base=None, **kwargs):
214 | model = DLASeg("dla102", classes, pretrained_base=pretrained_base, **kwargs)
215 | return model
216 |
217 |
218 | def dla169up(classes, pretrained_base=None, **kwargs):
219 | model = DLASeg("dla169", classes, pretrained_base=pretrained_base, **kwargs)
220 | return model
221 |
222 |
223 | if __name__ == "__main__":
224 |     # Quick sanity check; guarded so importing this module has no side effects.
225 |     model = dla34up(classes=3, pretrained_base=None, down_ratio=2)
226 |     print(model)
225 |
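227 |     # Minimal forward-pass sketch (not part of the original file): the 64x64
228 |     # input size is illustrative; with down_ratio=2, DLASeg should return
229 |     # full-resolution log-probabilities plus the half-resolution logits.
230 |     log_probs, logits = model(torch.randn(1, 3, 64, 64))
231 |     print(log_probs.shape, logits.shape)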
--------------------------------------------------------------------------------
/src/model/fpn.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from torch import nn, Tensor
6 |
7 | from torch.jit.annotations import Tuple, List, Dict, Optional
8 |
9 |
10 | class ExtraFPNBlock(nn.Module):
11 | """
12 | Base class for the extra block in the FPN.
13 | Arguments:
14 | results (List[Tensor]): the result of the FPN
15 | x (List[Tensor]): the original feature maps
16 | names (List[str]): the names for each one of the
17 | original feature maps
18 | Returns:
19 | results (List[Tensor]): the extended set of results
20 | of the FPN
21 | names (List[str]): the extended set of names for the results
22 | """
23 |
24 | def forward(
25 | self,
26 | results: List[Tensor],
27 | x: List[Tensor],
28 | names: List[str],
29 | ) -> Tuple[List[Tensor], List[str]]:
30 | pass
31 |
32 |
33 | class FeaturePyramidNetworkRoddick(nn.Module):
34 | """
35 |     Module that adds an FPN on top of a set of feature maps. This is based on
36 |     `"Feature Pyramid Networks for Object Detection" <https://arxiv.org/abs/1612.03144>`_.
37 | The feature maps are currently supposed to be in increasing depth
38 | order.
39 | The input to the model is expected to be an OrderedDict[Tensor], containing
40 | the feature maps on top of which the FPN will be added.
41 | Arguments:
42 | in_channels_list (list[int]): number of channels for each feature map that
43 | is passed to the module
44 | out_channels (int): number of channels of the FPN representation
45 | extra_blocks (ExtraFPNBlock or None): if provided, extra operations will
46 | be performed. It is expected to take the fpn features, the original
47 | features and the names of the original features as input, and returns
48 | a new list of feature maps and their corresponding names
49 | Examples::
50 | >>> m = torchvision.ops.FeaturePyramidNetwork([10, 20, 30], 5)
51 | >>> # get some dummy data
52 | >>> x = OrderedDict()
53 | >>> x['feat0'] = torch.rand(1, 10, 64, 64)
54 | >>> x['feat2'] = torch.rand(1, 20, 16, 16)
55 | >>> x['feat3'] = torch.rand(1, 30, 8, 8)
56 | >>> # compute the FPN on top of x
57 | >>> output = m(x)
58 | >>> print([(k, v.shape) for k, v in output.items()])
59 | >>> # returns
60 | >>> [('feat0', torch.Size([1, 5, 64, 64])),
61 | >>> ('feat2', torch.Size([1, 5, 16, 16])),
62 | >>> ('feat3', torch.Size([1, 5, 8, 8]))]
63 | """
64 |
65 | def __init__(
66 | self,
67 | in_channels_list: List[int],
68 | out_channels: int,
69 | extra_blocks: Optional[ExtraFPNBlock] = None,
70 | ):
71 | super(FeaturePyramidNetworkRoddick, self).__init__()
72 | self.inner_blocks = nn.ModuleList()
73 | self.layer_blocks = nn.ModuleList()
74 | for in_channels in in_channels_list:
75 | if in_channels == 0:
76 | raise ValueError("in_channels=0 is currently not supported")
77 | inner_block_module = nn.Conv2d(in_channels, out_channels, 1)
78 | layer_block_module = nn.Conv2d(out_channels, out_channels, 3, padding=1)
79 | self.inner_blocks.append(inner_block_module)
80 | self.layer_blocks.append(layer_block_module)
81 |
82 | # initialize parameters now to avoid modifying the initialization of top_blocks
83 | for m in self.children():
84 | if isinstance(m, nn.Conv2d):
85 | nn.init.kaiming_uniform_(m.weight, a=1)
86 | nn.init.constant_(m.bias, 0)
87 |
88 |         if extra_blocks is not None:
89 |             assert isinstance(extra_blocks, ExtraFPNBlock)
90 |         self.extra_blocks = extra_blocks
91 |
92 | def get_result_from_inner_blocks(self, x: Tensor, idx: int) -> Tensor:
93 | """
94 | This is equivalent to self.inner_blocks[idx](x),
95 | but torchscript doesn't support this yet
96 | """
97 | num_blocks = 0
98 | for m in self.inner_blocks:
99 | num_blocks += 1
100 | if idx < 0:
101 | idx += num_blocks
102 | i = 0
103 | out = x
104 | for module in self.inner_blocks:
105 | if i == idx:
106 | out = module(x)
107 | i += 1
108 | return out
109 |
110 | def get_result_from_layer_blocks(self, x: Tensor, idx: int) -> Tensor:
111 | """
112 | This is equivalent to self.layer_blocks[idx](x),
113 | but torchscript doesn't support this yet
114 | """
115 | num_blocks = 0
116 | for m in self.layer_blocks:
117 | num_blocks += 1
118 | if idx < 0:
119 | idx += num_blocks
120 | i = 0
121 | out = x
122 | for module in self.layer_blocks:
123 | if i == idx:
124 | out = module(x)
125 | i += 1
126 | return out
127 |
128 | def forward(self, x: Dict[str, Tensor]) -> Dict[str, Tensor]:
129 | """
130 | Computes the FPN for a set of feature maps.
131 | Arguments:
132 | x (OrderedDict[Tensor]): feature maps for each feature level.
133 | Returns:
134 | results (OrderedDict[Tensor]): feature maps after FPN layers.
135 | They are ordered from highest resolution first.
136 | """
137 | # unpack OrderedDict into two lists for easier handling
138 | names = list(x.keys())
139 | x = list(x.values())
140 |
141 | last_inner = self.get_result_from_inner_blocks(x[-1], -1)
142 | results = []
143 | results.append(self.get_result_from_layer_blocks(last_inner, -1))
144 |
145 | for idx in range(len(x) - 2, -1, -1):
146 | inner_lateral = self.get_result_from_inner_blocks(x[idx], idx)
147 | feat_shape = inner_lateral.shape[-2:]
148 | inner_top_down = F.interpolate(last_inner, size=feat_shape, mode="nearest")
149 | last_inner = inner_lateral + inner_top_down
150 | results.insert(0, self.get_result_from_layer_blocks(last_inner, idx))
151 |
152 | if self.extra_blocks is not None:
153 | results, names = self.extra_blocks(results, x, names)
154 |
155 | # make it back an OrderedDict
156 | out = OrderedDict([(k, v) for k, v in zip(names, results)])
157 |
158 | return out
159 |
160 |
161 | class LLMaxPool(ExtraFPNBlock):
162 | """
163 | Applies a max_pool2d on top of the last feature map
164 | """
165 |
166 | def forward(
167 | self,
168 | x: List[Tensor],
169 | y: List[Tensor],
170 | names: List[str],
171 | ) -> Tuple[List[Tensor], List[str]]:
172 | names.append("pool")
173 | x.append(F.max_pool2d(x[-1], 1, 2, 0))
174 | return x, names
175 |
176 |
177 | class LastLevelP6P7(ExtraFPNBlock):
178 | """
179 | This module is used in RetinaNet to generate extra layers, P6 and P7.
180 | """
181 |
182 | def __init__(self, in_channels: int, out_channels: int):
183 | super(LastLevelP6P7, self).__init__()
184 | self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1)
185 | self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1)
186 | for module in [self.p6, self.p7]:
187 | nn.init.kaiming_uniform_(module.weight, a=1)
188 | nn.init.constant_(module.bias, 0)
189 | self.use_P5 = in_channels == out_channels
190 |
191 | def forward(
192 | self,
193 | p: List[Tensor],
194 | c: List[Tensor],
195 | names: List[str],
196 | ) -> Tuple[List[Tensor], List[str]]:
197 | p5, c5 = p[-1], c[-1]
198 | x = p5 if self.use_P5 else c5
199 | p6 = self.p6(x)
200 | p7 = self.p7(F.relu(p6))
201 | p.extend([p6, p7])
202 | names.extend(["p6", "p7"])
203 | return p, names
204 |
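205 |
206 | # Minimal usage sketch (not part of the original file), mirroring the docstring
207 | # example above: the FPN takes an OrderedDict of feature maps in increasing
208 | # depth order and should return maps with out_channels channels at every level.
209 | # The channel counts and spatial sizes below are illustrative only.
210 | if __name__ == "__main__":
211 |     m = FeaturePyramidNetworkRoddick([10, 20, 30], 5)
212 |     x = OrderedDict()
213 |     x["feat0"] = torch.rand(1, 10, 64, 64)
214 |     x["feat2"] = torch.rand(1, 20, 16, 16)
215 |     x["feat3"] = torch.rand(1, 30, 8, 8)
216 |     output = m(x)
217 |     print([(k, v.shape) for k, v in output.items()])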
--------------------------------------------------------------------------------
/src/model/intermediate_layer_getter.py:
--------------------------------------------------------------------------------
1 | import functools
2 | from collections import OrderedDict
3 |
4 |
5 | # using wonder's beautiful simplification: https://stackoverflow.com/questions/31174295/getattr-and-setattr-on-nested-objects/31174427?noredirect=1#comment86638618_31174427
6 | def rgetattr(obj, attr, *args):
7 | def _getattr(obj, attr):
8 | return getattr(obj, attr, *args)
9 |
10 | return functools.reduce(_getattr, [obj] + attr.split("."))
11 |
12 |
13 | class IntermediateLayerGetter:
14 | def __init__(self, model, return_layers, keep_output=True):
15 | """Wraps a Pytorch module to get intermediate values
16 |
17 | Arguments:
18 | model {nn.module} -- The Pytorch module to call
19 | return_layers {dict} -- Dictionary with the selected submodules
20 | to return the output (format: {[current_module_name]: [desired_output_name]},
21 | current_module_name can be a nested submodule, e.g. submodule1.submodule2.submodule3)
22 |
23 | Keyword Arguments:
24 | keep_output {bool} -- If True model_output contains the final model's output
25 | in the other case model_output is None (default: {True})
26 | Returns:
27 | (mid_outputs {OrderedDict}, model_output {any}) -- mid_outputs keys are
28 | your desired_output_name (s) and their values are the returned tensors
29 |             of those submodules (OrderedDict([(desired_output_name, tensor(...)), ...])).
30 |             See keep_output argument for model_output description.
31 |             In case a submodule is called more than once, all its outputs are
32 | stored in a list.
33 | """
34 | self._model = model
35 | self.return_layers = return_layers
36 | self.keep_output = keep_output
37 |
38 | def __call__(self, *args, **kwargs):
39 | ret = OrderedDict()
40 | handles = []
41 | for name, new_name in self.return_layers.items():
42 | layer = rgetattr(self._model, name)
43 |
44 | def hook(module, input, output, new_name=new_name):
45 | if new_name in ret:
46 | if type(ret[new_name]) is list:
47 | ret[new_name].append(output)
48 | else:
49 | ret[new_name] = [ret[new_name], output]
50 | else:
51 | ret[new_name] = output
52 |
53 | try:
54 | h = layer.register_forward_hook(hook)
55 | except AttributeError as e:
56 | raise AttributeError(f"Module {name} not found")
57 | handles.append(h)
58 |
59 | if self.keep_output:
60 | output = self._model(*args, **kwargs)
61 | else:
62 | self._model(*args, **kwargs)
63 | output = None
64 |
65 | for h in handles:
66 | h.remove()
67 |
68 | return ret, output
69 |
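70 |
71 | # Minimal usage sketch (not part of the original file): wrap a small
72 | # nn.Sequential and capture the output of its first layer by name. The
73 | # module, layer names and shapes here are purely illustrative.
74 | if __name__ == "__main__":
75 |     import torch
76 |     from torch import nn
77 |
78 |     model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
79 |     getter = IntermediateLayerGetter(model, return_layers={"0": "first_linear"})
80 |     mid_outputs, final_output = getter(torch.randn(3, 4))
81 |     print(mid_outputs["first_linear"].shape, final_output.shape)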
--------------------------------------------------------------------------------
/src/model/loss.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import numpy as np
6 | import matplotlib.pyplot as plt
7 |
8 | import src
9 | from torch.nn.utils.rnn import pad_sequence
10 |
11 |
12 | def class_obj_size():
13 | """
14 | NuScenes class object size on Roddick Split
15 | gt_dir/name: semantic_maps_new_200x200
16 | """
17 |
18 | # Object size per class relative to map resolution
19 | s200 = [0.263, 0.031, 0.055, 0.029, 0.009, 0.002, 0.010, 0.001, 0.008, 0.006, 0.001]
20 | s100 = [0.296, 0.037, 0.066, 0.035, 0.011, 0.003, 0.013, 0.002, 0.010, 0.008, 0.002]
21 | s50 = [0.331, 0.046, 0.083, 0.044, 0.015, 0.004, 0.019, 0.003, 0.014, 0.011, 0.003]
22 | s25 = [0.380, 0.065, 0.114, 0.061, 0.022, 0.007, 0.031, 0.006, 0.022, 0.017, 0.008]
23 | s13 = [0.465, 0.105, 0.169, 0.093, 0.037, 0.016, 0.058, 0.013, 0.039, 0.031, 0.021]
24 | s7 = [0.599, 0.196, 0.279, 0.148, 0.069, 0.037, 0.117, 0.033, 0.076, 0.065, 0.054]
25 |
26 | ms_obj_size = torch.stack(
27 | [
28 | torch.tensor(s100),
29 | torch.tensor(s50),
30 | torch.tensor(s25),
31 | torch.tensor(s13),
32 | torch.tensor(s7),
33 | ]
34 | )
35 |
36 | return ms_obj_size.cuda()
37 |
38 |
39 | def get_class_frequency():
40 | """
41 | NuScenes class frequency on Roddick Split
42 | gt_dir/name: semantic_maps_new_200x200
43 | """
44 |
45 | # Class frequency
46 | s200 = [
47 | 0.999,
48 | 0.489,
49 | 0.969,
50 | 0.294,
51 | 0.094,
52 | 0.080,
53 | 0.758,
54 | 0.0646,
55 | 0.070,
56 | 0.080,
57 | 0.313,
58 | 0.485,
59 | ]
60 |
61 | return torch.tensor(s200).cuda()
62 |
63 |
64 | def class_weight():
65 | """
66 |     NuScenes class weights on Roddick Split
67 | gt_dir/name: semantic_maps_new_200x200
68 | """
69 |
70 | # Class weights
71 | w = torch.tensor([1.31, 5.43, 1.00, 2.13, 8.04, 1.19, 1.75, 5.71])
72 |
73 | return w
74 |
75 |
76 | def class_flood_level():
77 | ms_flood = torch.tensor(
78 | [
79 | [
80 | 0.0356,
81 | 0.2656,
82 | 0.1070,
83 | 0.6454,
84 | 0.8,
85 | 0.75,
86 | 0.1832,
87 | 0.8,
88 | 0.75,
89 | 0.6633,
90 | 0.5591,
91 | ],
92 | [
93 | 0.0294,
94 | 0.2418,
95 | 0.0902,
96 | 0.5982,
97 | 0.75,
98 | 0.7,
99 | 0.1589,
100 | 0.8,
101 | 0.75,
102 | 0.6302,
103 | 0.5116,
104 | ],
105 | [
106 | 0.0247,
107 | 0.2363,
108 | 0.0787,
109 | 0.5765,
110 | 0.75,
111 | 0.75,
112 | 0.1416,
113 | 0.8,
114 | 0.75,
115 | 0.6317,
116 | 0.4982,
117 | ],
118 | ]
119 | ).cuda()
120 | return ms_flood
121 |
122 |
123 | class MultiTaskLoss(nn.Module):
124 | def __init__(self, l1_reduction="mean"):
125 | super().__init__()
126 |
127 | self.log_vars = nn.Parameter(torch.zeros(2))
128 | self.l1_reduction = l1_reduction
129 |
130 | def forward(self, preds, labels, bev, bev_pred):
131 | precision1 = torch.exp(-self.log_vars[0])
132 | precision2 = torch.exp(-self.log_vars[1])
133 |
134 | s1_loss = dice_loss_mean(preds[0], labels[0])
135 | s2_loss = dice_loss_mean(preds[1], labels[1])
136 | s4_loss = dice_loss_mean(preds[2], labels[2])
137 |
138 | dice_l = s1_loss + s2_loss + s4_loss
139 | bev_l = F.l1_loss(bev_pred, bev, reduction=self.l1_reduction)
140 |
141 | total_loss = (precision1 * dice_l + self.log_vars[0]) + (
142 | precision2 * bev_l + self.log_vars[1]
143 | )
144 |
145 | return (
146 | total_loss,
147 | float(self.log_vars[0]),
148 | float(self.log_vars[1]),
149 | float(dice_l),
150 | float(bev_l),
151 | )
152 |
153 |
154 | def lovasz_grad(gt_sorted):
155 | """
156 | Computes gradient of the Lovasz extension w.r.t sorted errors
157 | See Alg. 1 in paper
158 | """
159 | p = len(gt_sorted)
160 | gts = gt_sorted.sum()
161 | intersection = gts - gt_sorted.float().cumsum(0)
162 | union = gts + (1 - gt_sorted).float().cumsum(0)
163 | jaccard = 1.0 - intersection / union
164 | if p > 1: # cover 1-pixel case
165 | jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
166 | return jaccard
167 |
168 |
169 | def lovasz_hinge_flat(logits, labels):
170 | """
171 | Binary Lovasz hinge loss
172 | logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
173 | labels: [P] Tensor, binary ground truth labels (0 or 1)
174 | ignore: label to ignore
175 | """
176 | logits = logits.view(-1)
177 | labels = labels.view(-1)
178 | signs = 2.0 * labels.float() - 1.0
179 | errors = 1.0 - logits * signs
180 | errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
181 | perm = perm.data
182 | gt_sorted = labels[perm]
183 | grad = lovasz_grad(gt_sorted)
184 | loss = torch.dot(F.relu(errors_sorted), grad)
185 | return loss
186 |
187 |
188 | def basic_weighted_dice_loss_mean(pred, label, idx_scale=None, args=None):
189 | pred = torch.sigmoid(pred)
190 | label = label.float()
191 | intersection = 2 * pred * label
192 | union = pred + label
193 | iou = (intersection.float().sum(dim=0).sum(dim=-1).sum(dim=-1)) / (
194 | union.float().sum(dim=0).sum(dim=-1).sum(dim=-1) + 1e-5
195 | )
196 |
197 | # Set up class weighting
198 | weight = class_weight()
199 |
200 | loss_mean = (weight * (1 - iou)).mean()
201 |
202 | return loss_mean
203 |
204 |
205 | def weighted_dice_loss_mean(pred, label, idx_scale, args):
206 | pred = torch.sigmoid(pred)
207 | label = label.float()
208 | intersection = 2 * pred * label
209 | union = pred + label
210 | iou = (intersection.float().sum(dim=0).sum(dim=-1).sum(dim=-1)) / (
211 | union.float().sum(dim=0).sum(dim=-1).sum(dim=-1) + 1e-5
212 | )
213 |
214 | # Set up class weighting
215 | obj_size = class_obj_size()[idx_scale]
216 |     class_freq = get_class_frequency()[idx_scale]
217 | weight = 1 / obj_size ** args.exp_os * 1 / class_freq ** args.exp_cf
218 |
219 | # Set large class' weights to one
220 | idxs_to_one = torch.tensor([0, 1, 2, 3, 6]).long()
221 | weight[idxs_to_one] = 1
222 |
223 | loss_mean = (weight * (1 - iou)).mean()
224 |
225 | return loss_mean
226 |
227 |
228 | def weighted_dice_loss_sum(pred, label, idx_scale, args):
229 | pred = torch.sigmoid(pred)
230 | label = label.float()
231 | intersection = 2 * pred * label
232 | union = pred + label
233 | iou = (intersection.float().sum(dim=0).sum(dim=-1).sum(dim=-1) + 1e-5) / (
234 | union.float().sum(dim=0).sum(dim=-1).sum(dim=-1) + 1e-5
235 | )
236 |
237 | # Set up class weighting
238 | obj_size = class_obj_size()[idx_scale]
239 |     class_freq = get_class_frequency()[idx_scale]
240 | weight = 1 / obj_size ** args.exp_os * 1 / class_freq ** args.exp_cf
241 |
242 | loss_mean = (weight * (1 - iou)).sum()
243 |
244 | return loss_mean
245 |
246 |
247 | def dice_loss_sum(pred, label, scale_idx=None, args=None):
248 | pred = torch.sigmoid(pred)
249 | label = label.float()
250 | intersection = 2 * pred * label
251 | union = pred + label
252 | iou = (intersection.float().sum(dim=0).sum(dim=-1).sum(dim=-1)) / (
253 | union.float().sum(dim=0).sum(dim=-1).sum(dim=-1) + 1e-5
254 | )
255 |
256 | loss_mean = (1 - iou).sum()
257 |
258 | return loss_mean
259 |
260 |
261 | def dice_loss_mean(pred, label, vis_mask=None, scale_idx=None, args=None):
262 | pred = torch.sigmoid(pred)
263 | label = label.float()
264 | intersection = 2 * pred * label
265 | union = pred + label
266 | iou = (intersection.float().sum(dim=0).sum(dim=-1).sum(dim=-1)) / (
267 | union.float().sum(dim=0).sum(dim=-1).sum(dim=-1) + 1e-5
268 | )
269 |
270 | loss_mean = 1 - iou.mean()
271 |
272 | return loss_mean
273 |
274 |
275 | def dice_loss_mean_wo_sigmoid(pred, label, vis_mask=None, scale_idx=None, args=None):
276 | label = label.float()
277 | intersection = 2 * pred * label
278 | union = pred + label
279 | iou = (intersection.float().sum()) / (union.float().sum() + 1e-5)
280 |
281 | loss_mean = 1 - iou.mean()
282 |
283 | return loss_mean
284 |
285 |
286 | def dice_loss_mean_vis_only(pred, label, vis_mask, scale_idx=None, args=None):
287 | # Calculates the loss in the visible areas only
288 | pred = torch.sigmoid(pred) * vis_mask.float()
289 | label = label.float()
290 | intersection = 2 * pred * label
291 | union = pred + label
292 | iou = (intersection.float().sum(dim=0).sum(dim=-1).sum(dim=-1)) / (
293 | union.float().sum(dim=0).sum(dim=-1).sum(dim=-1) + 1e-5
294 | )
295 |
296 | loss_mean = 1 - iou.mean()
297 |
298 | return loss_mean
299 |
300 |
301 | def dice2_loss_mean(pred, label, scale_idx=None, args=None):
302 | pred = torch.sigmoid(pred)
303 | label = label.float()
304 | intersection = 2 * pred * label
305 | union = pred + label
306 | iou = (intersection.float().sum(dim=0).sum(dim=-1).sum(dim=-1) + 1.0) / (
307 | union.float().sum(dim=0).sum(dim=-1).sum(dim=-1) + 1.0
308 | )
309 |
310 | loss_mean = 1 - iou.mean()
311 |
312 | return loss_mean
313 |
314 |
315 | def dice_loss_mean_mix1(pred, label, scale_idx=None, args=None):
316 | pred = torch.sigmoid(pred)
317 | label = label.float()
318 | intersection = 2 * pred * label
319 | union = pred + label
320 |
321 | e_numerator = torch.zeros(pred.shape[1]) + 1e-5
322 | idxs_to_zero_out = torch.tensor([4, 5, 7, 8, 10]).long()
323 | e_numerator[idxs_to_zero_out] = 0
324 | e_numerator = e_numerator.cuda()
325 |
326 | iou = (intersection.float().sum(dim=0).sum(dim=-1).sum(dim=-1) + e_numerator) / (
327 | union.float().sum(dim=0).sum(dim=-1).sum(dim=-1) + 1e-5
328 | )
329 |
330 | loss_mean = 1 - iou.mean()
331 |
332 | return loss_mean
333 |
334 |
335 | def dice_loss_mean_mix2(pred, label, scale_idx=None, args=None):
336 | pred = torch.sigmoid(pred)
337 | label = label.float()
338 | intersection = 2 * pred * label
339 | union = pred + label
340 |
341 | e_numerator = torch.zeros(pred.shape[1]) + 1e-5
342 | idxs_to_zero_out = torch.tensor([4, 5, 7, 8]).long()
343 | e_numerator[idxs_to_zero_out] = 0
344 | e_numerator = e_numerator.cuda()
345 |
346 | iou = (intersection.float().sum(dim=0).sum(dim=-1).sum(dim=-1) + e_numerator) / (
347 | union.float().sum(dim=0).sum(dim=-1).sum(dim=-1) + 1e-5
348 | )
349 |
350 | loss_mean = 1 - iou.mean()
351 |
352 | return loss_mean
353 |
354 |
355 | def dice_loss_per_class(pred, labels, scale_idx=None):
356 | pred = torch.sigmoid(pred)
357 | labels = labels.float()
358 | intersection = 2 * pred * labels
359 | union = pred + labels
360 | iou = (intersection.float().sum(dim=0).sum(dim=-1).sum(dim=-1)) / (
361 | union.float().sum(dim=0).sum(dim=-1).sum(dim=-1) + 1e-5
362 | )
363 | return 1 - iou
364 |
365 |
366 | def dice_loss_per_class_infer(pred, labels, scale_idx=None):
367 | pred = torch.sigmoid(pred)
368 | labels = labels.float()
369 | intersection = (2 * pred * labels).float().sum(dim=0).sum(dim=-1).sum(dim=-1)
370 | union = (pred + labels).float().sum(dim=0).sum(dim=-1).sum(dim=-1)
371 | iou = (intersection) / (union + 1e-5)
372 | return 1 - iou, intersection, union
373 |
374 |
375 | def flooded_dice_loss_mean(dice_loss, scale_idx):
376 |
377 | # Flood level per class
378 | flood_level = class_flood_level()[scale_idx]
379 |
380 | flood_loss = (dice_loss - flood_level).abs() + flood_level
381 |
382 | return flood_loss.mean()
383 |
384 |
385 | def flooded_dice_loss_sum(dice_loss, scale_idx):
386 |
387 | # Flood level per class
388 | flood_level = class_flood_level()[scale_idx]
389 |
390 | flood_loss = (dice_loss - flood_level).abs() + flood_level
391 |
392 | return flood_loss.sum()
393 |
394 |
395 | def iou_loss(pred, labels):
396 | e = 1e-6
397 | pred = torch.sigmoid(pred)
398 | labels = labels.float()
399 | intersection = pred * labels
400 | union = (pred + labels) - (pred * labels)
401 | iou = intersection.sum() / (union.sum() + e)
402 | return 1.0 - iou
403 |
404 |
405 | def bce_plus_dice_loss(pred, labels, weights, alpha):
406 | bce = F.binary_cross_entropy_with_logits(
407 | input=pred, target=labels, weight=weights, reduction="mean"
408 | )
409 | dice = dice_loss_mean(pred, labels)
410 | total = (alpha[0] * bce + alpha[1] * dice).sum()
411 | return total
412 |
413 |
414 | def bce_plus_iou_loss(pred, labels, weights, alpha):
415 | bce = F.binary_cross_entropy_with_logits(
416 | input=pred, target=labels, weight=weights, reduction="mean"
417 | )
418 | iou = iou_loss(pred, labels)
419 | total = (alpha[0] * bce + alpha[1] * iou).sum()
420 | return total
421 |
422 |
423 | def class_balanced_loss(
424 | labels, logits, samples_per_cls, no_of_classes, loss_type, alpha=None
425 | ):
426 | """Compute the Class Balanced Loss between `logits` and the ground truth `labels`.
427 | Class Balanced Loss: ((1-beta)/(1-beta^n))*Loss(labels, logits)
428 | where Loss is one of the standard losses used for Neural Networks.
429 | Args:
430 | labels: A int tensor of size [batch].
431 | logits: A float tensor of size [batch, no_of_classes].
432 | samples_per_cls: A python list of size [no_of_classes].
433 | no_of_classes: total number of classes. int
434 |         loss_type: string. One of "bce", "iou", "dice", "lovasz".
435 |         alpha: optional pair of weights for the (commented-out) combined BCE losses.
437 | Returns:
438 | cb_loss: A float tensor representing class balanced loss
439 | """
440 | # effective_num = 1.0 - np.power(beta, samples_per_cls)
441 | # weights = (1.0 - beta) / np.array(effective_num)
442 | # weights = weights / np.sum(weights) * no_of_classes
443 |
444 | if loss_type == "bce":
445 | labels_one_hot = F.one_hot(labels.long(), no_of_classes).float()
446 | weights = 1 / np.array(samples_per_cls)
447 | weights = torch.tensor(weights).float().cuda()
448 | weights = weights.view(1, 1, 1, 1, -1)
449 | weights = (
450 | weights.repeat(
451 | labels_one_hot.shape[0],
452 | labels_one_hot.shape[1],
453 | labels_one_hot.shape[2],
454 | labels_one_hot.shape[3],
455 | 1,
456 | )
457 | * labels_one_hot
458 | )
459 | weights = weights.sum(-1)
460 | cb_loss = F.binary_cross_entropy_with_logits(
461 | input=logits, target=labels, weight=weights
462 | )
463 | elif loss_type == "iou":
464 | cb_loss = iou_loss(logits, labels)
465 | elif loss_type == "dice":
466 | cb_loss = dice_loss_mean(logits, labels)
467 | elif loss_type == "lovasz":
468 | cb_loss = lovasz_hinge_flat(logits, labels)
469 | # elif loss_type == "bce_iou":
470 | # cb_loss = bce_plus_iou_loss(logits, labels, weights, alpha)
471 | # elif loss_type == "bce_dice":
472 | # cb_loss = bce_plus_dice_loss(logits, labels, weights, alpha)
473 |
474 | return cb_loss
475 |
476 |
477 | def weighted_bce_loss(pred, label, vis_mask):
478 | # Apply visibility mask to predictions
479 | pred = pred * vis_mask.float()
480 |
481 | pred = torch.sigmoid(pred)
482 |
483 | label = label.float()
484 |
485 | eps = 1e-12
486 |
487 | # Reshape weights for broadcasting
488 | cf = get_class_frequency()[None, :, None, None]
489 | pos_weights = torch.sqrt(1 / cf)
490 | neg_weights = torch.sqrt(1 / (1 - cf))
491 |
492 | # labels * -log(sigmoid(logits)) +
493 | # (1 - labels) * -log(1 - sigmoid(logits))
494 |
495 | loss = pos_weights * label * -torch.log(pred + eps) + (1 - label) * -torch.log(
496 | 1 - pred + eps
497 | )
498 |
499 | return loss.sum(dim=-1).sum(dim=-1).sum(dim=0)
500 |
501 |
502 | def uncertainty_loss(pred, vis_mask):
503 | # Invert visibility mask
504 | vis_mask = 1 - vis_mask.float()
505 | eps = 1e-12
506 | pred = pred * vis_mask
507 | pred = torch.sigmoid(pred)
508 | loss = 1 - pred * torch.log(pred + eps)
509 |
510 | return loss.sum(dim=-1).sum(dim=-1).sum(dim=0)
511 |
512 |
513 | def aux_recognition_loss(pred, label):
514 | """
515 | BCE loss for class probabilities, multi-label classification
516 | """
517 | # pred = torch.sigmoid(pred)
518 | label = src.utils.count_classes_per_sample(label)
519 |
520 | assert pred.shape == label.shape
521 |
522 | class_freq = get_class_frequency()[None, :]
523 | weights = (1 / class_freq).repeat(len(pred), 1)
524 |
525 | loss = F.binary_cross_entropy_with_logits(
526 | pred, label, weight=weights, reduction="none"
527 | )
528 | return loss
529 |
530 |
531 | def focal_loss(pred, gt, alpha, gamma):
532 | BCE_loss = F.binary_cross_entropy_with_logits(
533 | pred, gt.float(), reduction="none"
534 | ).sum(-1)
535 | gt, _ = torch.max(gt, dim=-1)
536 | gt = gt.long()
537 | at = torch.tensor(alpha, device=gt.device).gather(0, gt.data.view(-1))
538 | pt = torch.exp(-BCE_loss)
539 | F_loss = at * (1 - pt) ** gamma * BCE_loss
540 | return F_loss
541 |
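542 |
543 | # Minimal usage sketch (not part of the original file): dice_loss_mean operates
544 | # on raw logits and binary targets of shape [batch, classes, H, W]; unlike the
545 | # class-weighting helpers above it has no CUDA dependency. Shapes are
546 | # illustrative only.
547 | if __name__ == "__main__":
548 |     logits = torch.randn(2, 4, 8, 8)
549 |     targets = (torch.rand(2, 4, 8, 8) > 0.5).float()
550 |     print(float(dice_loss_mean(logits, targets)))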
--------------------------------------------------------------------------------
/src/model/mocha2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class Energy(nn.Module):
7 | def __init__(self, enc_dim=10, dec_dim=10, att_dim=10, init_r=-4):
8 | """
9 |         [Modified Bahdanau attention] from
10 | "Online and Linear-Time Attention by Enforcing Monotonic Alignment" (ICML 2017)
11 | http://arxiv.org/abs/1704.00784
12 | Used for Monotonic Attention and Chunk Attention
13 | """
14 | super().__init__()
15 | self.tanh = nn.Tanh()
16 | self.W = nn.Linear(enc_dim, att_dim, bias=False)
17 | self.V = nn.Linear(dec_dim, att_dim, bias=False)
18 | self.b = nn.Parameter(torch.Tensor(att_dim).normal_())
19 |
20 |         self.v = nn.utils.weight_norm(nn.Linear(att_dim, 1))
21 | self.v.weight_g.data = torch.Tensor([1 / att_dim]).sqrt()
22 |
23 | self.r = nn.Parameter(torch.Tensor([init_r]))
24 |
25 | def forward(self, encoder_outputs, decoder_h):
26 | """
27 | Args:
28 | encoder_outputs: [batch_size, sequence_length, enc_dim]
29 | decoder_h: [batch_size, dec_dim]
30 | Return:
31 | Energy [batch_size, sequence_length]
32 | """
33 | batch_size, sequence_length, enc_dim = encoder_outputs.size()
34 | encoder_outputs = encoder_outputs.view(-1, enc_dim)
35 |         energy = self.tanh(
36 |             self.W(encoder_outputs)
37 |             # repeat_interleave keeps each decoder state aligned with the
38 |             # batch-major flattening of encoder_outputs above; plain repeat()
39 |             # would interleave the batch entries.
40 |             + self.V(decoder_h).repeat_interleave(sequence_length, dim=0)
41 |             + self.b
42 |         )
40 | energy = self.v(energy).squeeze(-1) + self.r
41 |
42 | return energy.view(batch_size, sequence_length)
43 |
44 |
45 | class MonotonicAttention(nn.Module):
46 | def __init__(self):
47 | """
48 | [Monotonic Attention] from
49 | "Online and Linear-Time Attention by Enforcing Monotonic Alignment" (ICML 2017)
50 | http://arxiv.org/abs/1704.00784
51 | """
52 | super().__init__()
53 |
54 | self.monotonic_energy = Energy()
55 | self.sigmoid = nn.Sigmoid()
56 |
57 | def gaussian_noise(self, *size):
58 |         """Additive gaussian noise to encourage discreteness"""
59 | if torch.cuda.is_available():
60 | return torch.cuda.FloatTensor(*size).normal_()
61 | else:
62 | return torch.Tensor(*size).normal_()
63 |
64 | def safe_cumprod(self, x):
65 | """Numerically stable cumulative product by cumulative sum in log-space"""
66 | return torch.exp(
67 | torch.cumsum(torch.log(torch.clamp(x, min=1e-10, max=1)), dim=1)
68 | )
69 |
70 | def exclusive_cumprod(self, x):
71 | """Exclusive cumulative product [a, b, c] => [1, a, a * b]
72 | * TensorFlow: https://www.tensorflow.org/api_docs/python/tf/cumprod
73 | * PyTorch: https://discuss.pytorch.org/t/cumprod-exclusive-true-equivalences/2614
74 | """
75 | batch_size, sequence_length = x.size()
76 | if torch.cuda.is_available():
77 | one_x = torch.cat([torch.ones(batch_size, 1).cuda(), x], dim=1)[:, :-1]
78 | else:
79 | one_x = torch.cat([torch.ones(batch_size, 1), x], dim=1)[:, :-1]
80 | return torch.cumprod(one_x, dim=1)
81 |
82 | def soft(self, encoder_outputs, decoder_h, previous_alpha=None):
83 | """
84 | Soft monotonic attention (Train)
85 | Args:
86 | encoder_outputs [batch_size, sequence_length, enc_dim]
87 | decoder_h [batch_size, dec_dim]
88 | previous_alpha [batch_size, sequence_length]
89 | Return:
90 | alpha [batch_size, sequence_length]
91 | """
92 | batch_size, sequence_length, enc_dim = encoder_outputs.size()
93 |
94 | monotonic_energy = self.monotonic_energy(encoder_outputs, decoder_h)
95 | p_select = self.sigmoid(
96 | monotonic_energy + self.gaussian_noise(monotonic_energy.size())
97 | )
98 | cumprod_1_minus_p = self.safe_cumprod(1 - p_select)
99 |
100 | if previous_alpha is None:
101 | # First iteration => alpha = [1, 0, 0 ... 0]
102 | alpha = torch.zeros(batch_size, sequence_length)
103 | alpha[:, 0] = torch.ones(batch_size)
104 |             if torch.cuda.is_available():
105 | alpha = alpha.cuda()
106 |
107 | else:
108 | alpha = (
109 | p_select
110 | * cumprod_1_minus_p
111 | * torch.cumsum(previous_alpha / cumprod_1_minus_p, dim=1)
112 | )
113 |
114 | return alpha
115 |
116 | def hard(self, encoder_outputs, decoder_h, previous_attention=None):
117 | """
118 | Hard monotonic attention (Test)
119 | Args:
120 | encoder_outputs [batch_size, sequence_length, enc_dim]
121 | decoder_h [batch_size, dec_dim]
122 | previous_attention [batch_size, sequence_length]
123 | Return:
124 | alpha [batch_size, sequence_length]
125 | """
126 | batch_size, sequence_length, enc_dim = encoder_outputs.size()
127 |
128 | if previous_attention is None:
129 | # First iteration => alpha = [1, 0, 0 ... 0]
130 | attention = torch.zeros(batch_size, sequence_length)
131 | attention[:, 0] = torch.ones(batch_size)
132 |             if torch.cuda.is_available():
133 | attention = attention.cuda()
134 | else:
135 | # TODO: Linear Time Decoding
136 | # It's not clear if authors' TF implementation decodes in linear time.
137 | # https://github.com/craffel/mad/blob/master/example_decoder.py#L235
138 | # They calculate energies for whole encoder outputs
139 | # instead of scanning from previous attended encoder output.
140 | monotonic_energy = self.monotonic_energy(encoder_outputs, decoder_h)
141 |
142 | # Hard Sigmoid
143 | # Attend when monotonic energy is above threshold (Sigmoid > 0.5)
144 | above_threshold = (monotonic_energy > 0).float()
145 |
146 | p_select = above_threshold * torch.cumsum(previous_attention, dim=1)
147 | attention = p_select * self.exclusive_cumprod(1 - p_select)
148 |
149 | # Not attended => attend at last encoder output
150 | # Assume that encoder outputs are not padded
151 | attended = attention.sum(dim=1)
152 | for batch_i in range(batch_size):
153 | if not attended[batch_i]:
154 | attention[batch_i, -1] = 1
155 |
156 | # Ex)
157 | # p_select = [0, 0, 0, 1, 1, 0, 1, 1]
158 | # 1 - p_select = [1, 1, 1, 0, 0, 1, 0, 0]
159 | # exclusive_cumprod(1 - p_select) = [1, 1, 1, 1, 0, 0, 0, 0]
160 | # attention: product of above = [0, 0, 0, 1, 0, 0, 0, 0]
161 | return attention
162 |
163 |
164 | class MoChA(MonotonicAttention):
165 | def __init__(self, chunk_size=3):
166 | """
167 | [Monotonic Chunkwise Attention] from
168 | "Monotonic Chunkwise Attention" (ICLR 2018)
169 | https://openreview.net/forum?id=Hko85plCW
170 | """
171 | super().__init__()
172 | self.chunk_size = chunk_size
173 | self.chunk_energy = Energy()
174 | self.softmax = nn.Softmax(dim=1)
175 |
176 | def moving_sum(self, x, back, forward):
177 | """Parallel moving sum with 1D Convolution"""
178 | # Pad window before applying convolution
179 | # [batch_size, back + sequence_length + forward]
180 | x_padded = F.pad(x, pad=[back, forward])
181 |
182 | # Fake channel dimension for conv1d
183 | # [batch_size, 1, back + sequence_length + forward]
184 | x_padded = x_padded.unsqueeze(1)
185 |
186 | # Apply conv1d with filter of all ones for moving sum
187 | filters = torch.ones(1, 1, back + forward + 1)
188 | if torch.cuda.is_available():
189 | filters = filters.cuda()
190 | x_sum = F.conv1d(x_padded, filters)
191 |
192 | # Remove fake channel dimension
193 | # [batch_size, sequence_length]
194 | return x_sum.squeeze(1)
195 |
196 | def chunkwise_attention_soft(self, alpha, u):
197 | """
198 | Args:
199 | alpha [batch_size, sequence_length]: emission probability in monotonic attention
200 | u [batch_size, sequence_length]: chunk energy
201 | chunk_size (int): window size of chunk
202 | Return
203 | beta [batch_size, sequence_length]: MoChA weights
204 | """
205 |
206 | # Numerical stability
207 | # Divide by same exponent => doesn't affect softmax
208 | u -= torch.max(u, dim=1, keepdim=True)[0]
209 | exp_u = torch.exp(u)
210 | # Limit range of logit
211 | exp_u = torch.clamp(exp_u, min=1e-5)
212 |
213 | # Moving sum:
214 | # Zero-pad (chunk size - 1) on the left + 1D conv with filters of 1s.
215 | # [batch_size, sequence_length]
216 | denominators = self.moving_sum(exp_u, back=self.chunk_size - 1, forward=0)
217 |
218 | # Compute beta (MoChA weights)
219 | beta = exp_u * self.moving_sum(
220 | alpha / denominators, back=0, forward=self.chunk_size - 1
221 | )
222 | return beta
223 |
224 | def chunkwise_attention_hard(self, monotonic_attention, chunk_energy):
225 | """
226 | Mask non-attended area with '-inf'
227 | Args:
228 | monotonic_attention [batch_size, sequence_length]
229 | chunk_energy [batch_size, sequence_length]
230 | Return:
231 | masked_energy [batch_size, sequence_length]
232 | """
233 | batch_size, sequence_length = monotonic_attention.size()
234 |
235 | # [batch_size]
236 | attended_indices = monotonic_attention.nonzero().cpu().data[:, 1].tolist()
237 |
238 | i = [[], []]
239 | total_i = 0
240 | for batch_i, attended_idx in enumerate(attended_indices):
241 | for window in range(self.chunk_size):
242 | if attended_idx - window >= 0:
243 | i[0].append(batch_i)
244 | i[1].append(attended_idx - window)
245 | total_i += 1
246 | i = torch.LongTensor(i)
247 | v = torch.FloatTensor([1] * total_i)
248 | mask = torch.sparse.FloatTensor(i, v, monotonic_attention.size())
249 |         mask = ~mask.to_dense().cuda().bool()
250 |
251 | # mask '-inf' energy before softmax
252 | masked_energy = chunk_energy.masked_fill_(mask, -float("inf"))
253 | return masked_energy
254 |
255 | def soft(self, encoder_outputs, decoder_h, previous_alpha=None):
256 | """
257 | Soft monotonic chunkwise attention (Train)
258 | Args:
259 | encoder_outputs [batch_size, sequence_length, enc_dim]
260 | decoder_h [batch_size, dec_dim]
261 | previous_alpha [batch_size, sequence_length]
262 | Return:
263 | alpha [batch_size, sequence_length]
264 | beta [batch_size, sequence_length]
265 | """
266 | alpha = super().soft(encoder_outputs, decoder_h, previous_alpha)
267 | chunk_energy = self.chunk_energy(encoder_outputs, decoder_h)
268 | beta = self.chunkwise_attention_soft(alpha, chunk_energy)
269 | return alpha, beta
270 |
271 | def hard(self, encoder_outputs, decoder_h, previous_attention=None):
272 | """
273 | Hard monotonic chunkwise attention (Test)
274 | Args:
275 | encoder_outputs [batch_size, sequence_length, enc_dim]
276 | decoder_h [batch_size, dec_dim]
277 | previous_attention [batch_size, sequence_length]
278 | Return:
279 | monotonic_attention [batch_size, sequence_length]: hard alpha
280 | chunkwise_attention [batch_size, sequence_length]: hard beta
281 | """
282 | # hard attention (one-hot)
283 | # [batch_size, sequence_length]
284 | monotonic_attention = super().hard(
285 | encoder_outputs, decoder_h, previous_attention
286 | )
287 | chunk_energy = self.chunk_energy(encoder_outputs, decoder_h)
288 | masked_energy = self.chunkwise_attention_hard(monotonic_attention, chunk_energy)
289 | chunkwise_attention = self.softmax(masked_energy)
290 | return monotonic_attention, chunkwise_attention
291 |
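292 |
293 | # Minimal usage sketch (not part of the original file): with the default
294 | # enc_dim = dec_dim = att_dim = 10, MoChA.soft() should return the monotonic
295 | # weights alpha and the chunkwise weights beta, both [batch, seq_len].
296 | # Batch size and sequence length below are illustrative only.
297 | if __name__ == "__main__":
298 |     attention = MoChA(chunk_size=3)
299 |     enc_out = torch.randn(2, 6, 10)  # [batch, seq_len, enc_dim]
300 |     dec_h = torch.randn(2, 10)  # [batch, dec_dim]
301 |     alpha, beta = attention.soft(enc_out, dec_h, previous_alpha=None)
302 |     print(alpha.shape, beta.shape)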
--------------------------------------------------------------------------------
/src/model/monotonic_attention.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import random
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
9 | class MonotonicAttention(nn.Module):
10 | """
11 | Monotonic Multihead Attention with Infinite Lookback
12 | """
13 |
14 | def __init__(
15 | self,
16 | embed_dim,
17 | num_heads,
18 | kdim=None,
19 | vdim=None,
20 | dropout=0.0,
21 | bias=True,
22 | energy_bias=True,
23 | ):
24 |         super().__init__()
25 |         self.embed_dim = embed_dim
25 | self.kdim = kdim if kdim is not None else embed_dim
26 | self.vdim = vdim if vdim is not None else embed_dim
27 | self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
28 |
29 | self.k_proj_mono = nn.Linear(embed_dim, embed_dim, bias=bias)
30 | self.q_proj_mono = nn.Linear(embed_dim, embed_dim, bias=bias)
31 | self.k_proj_soft = nn.Linear(embed_dim, embed_dim, bias=bias)
32 | self.q_proj_soft = nn.Linear(embed_dim, embed_dim, bias=bias)
33 | self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
34 |
35 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
36 |
37 | self.num_heads = num_heads
38 | self.head_dim = embed_dim // num_heads
39 | assert (
40 | self.head_dim * num_heads == self.embed_dim
41 | ), "embed_dim must be divisible by num_heads"
42 | self.scaling = self.head_dim ** -0.5
43 |
44 |         self.eps = 1e-6
45 |
46 |         # Referenced later (expected_attention / expected_alignment_train) but
47 |         # never initialised in the original; the defaults assumed here follow
48 |         # the usual fairseq settings.
49 |         self.dropout_module = nn.Dropout(dropout)
50 |         self.mass_preservation = True
51 |
46 | self.noise_mean = 0.0
47 | self.noise_var = 1.0
48 |
49 | self.energy_bias_init = -2.0
50 | self.energy_bias = (
51 | nn.Parameter(self.energy_bias_init * torch.ones([1]))
52 | if energy_bias is True
53 | else 0
54 | )
55 |
56 | self.reset_parameters()
57 |
58 | self.k_in_proj = {"monotonic": self.k_proj_mono, "soft": self.k_proj_soft}
59 | self.q_in_proj = {"monotonic": self.q_proj_mono, "soft": self.q_proj_soft}
60 | self.v_in_proj = {"output": self.v_proj}
61 |
62 | def reset_parameters(self):
63 | if self.qkv_same_dim:
64 | # Empirically observed the convergence to be much better with
65 | # the scaled initialization
66 | nn.init.xavier_uniform_(self.k_proj_mono.weight, gain=1 / math.sqrt(2))
67 | nn.init.xavier_uniform_(self.k_proj_soft.weight, gain=1 / math.sqrt(2))
68 | nn.init.xavier_uniform_(self.q_proj_mono.weight, gain=1 / math.sqrt(2))
69 | nn.init.xavier_uniform_(self.q_proj_soft.weight, gain=1 / math.sqrt(2))
70 | nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
71 |
72 | else:
73 | nn.init.xavier_uniform_(self.k_proj_mono.weight)
74 | nn.init.xavier_uniform_(self.k_proj_soft.weight)
75 | nn.init.xavier_uniform_(self.q_proj_mono.weight)
76 | nn.init.xavier_uniform_(self.q_proj_soft.weight)
77 | nn.init.xavier_uniform_(self.v_proj.weight)
78 |
79 | nn.init.xavier_uniform_(self.out_proj.weight)
80 | if self.out_proj.bias is not None:
81 | nn.init.constant_(self.out_proj.bias, 0.0)
82 |
83 | def input_projections(self, query, key, value, name):
84 | """
85 | Prepare inputs for multihead attention
86 | ============================================================
87 | Expected input size
88 | query: tgt_len, bsz, embed_dim
89 | key: src_len, bsz, embed_dim
90 | value: src_len, bsz, embed_dim
91 | name: monotonic or soft
92 | """
93 |
94 | if query is not None:
95 | bsz = query.size(1)
96 |             q = self.q_in_proj[name](query)
97 | q *= self.scaling
98 | q = (
99 | q.contiguous()
100 | .view(-1, bsz * self.num_heads, self.head_dim)
101 | .transpose(0, 1)
102 | )
103 | else:
104 | q = None
105 |
106 | if key is not None:
107 | bsz = key.size(1)
108 |             k = self.k_in_proj[name](key)
109 | k = (
110 | k.contiguous()
111 | .view(-1, bsz * self.num_heads, self.head_dim)
112 | .transpose(0, 1)
113 | )
114 | else:
115 | k = None
116 |
117 | if value is not None:
118 | bsz = value.size(1)
119 | v = self.v_proj(value)
120 | v = (
121 | v.contiguous()
122 | .view(-1, bsz * self.num_heads, self.head_dim)
123 | .transpose(0, 1)
124 | )
125 | else:
126 | v = None
127 |
128 | return q, k, v
129 |
130 | def p_choose(self, query, key, key_padding_mask=None):
131 | """
132 | Calculating step wise prob for reading and writing
133 | 1 to read, 0 to write
134 | ============================================================
135 | Expected input size
136 | query: bsz, tgt_len, embed_dim
137 | key: bsz, src_len, embed_dim
138 |         key_padding_mask: bsz, src_len
142 | """
143 |
144 | # prepare inputs
145 | q_proj, k_proj, _ = self.input_projections(query, key, None, "monotonic")
146 |
147 | # attention energy
148 | attn_energy = self.attn_energy(q_proj, k_proj, key_padding_mask)
149 |
150 | noise = 0
151 |
152 | if self.training:
153 | # add noise here to encourage discretness
154 | noise = (
155 | torch.normal(self.noise_mean, self.noise_var, attn_energy.size())
156 | .type_as(attn_energy)
157 | .to(attn_energy.device)
158 | )
159 |
160 | p_choose = torch.sigmoid(attn_energy + noise)
161 | _, _, tgt_len, src_len = p_choose.size()
162 |
163 | # p_choose: bsz * self.num_heads, tgt_len, src_len
164 | return p_choose.view(-1, tgt_len, src_len)
165 |
166 | def attn_energy(self, q_proj, k_proj, key_padding_mask=None):
167 | """
168 | Calculating monotonic energies
169 | ============================================================
170 | Expected input size
171 | q_proj: bsz * num_heads, tgt_len, self.head_dim
172 | k_proj: bsz * num_heads, src_len, self.head_dim
173 | key_padding_mask: bsz, src_len
174 | attn_mask: tgt_len, src_len
175 | """
176 | bsz, tgt_len, embed_dim = q_proj.size()
177 | bsz = bsz // self.num_heads
178 | src_len = k_proj.size(1)
179 |
180 | attn_energy = torch.bmm(q_proj, k_proj.transpose(1, 2)) + self.energy_bias
181 |
182 | attn_energy = attn_energy.view(bsz, self.num_heads, tgt_len, src_len)
183 |
184 | if key_padding_mask is not None:
185 | attn_energy = attn_energy.masked_fill(
186 | key_padding_mask.unsqueeze(1).unsqueeze(2).bool(),
187 | float("-inf"),
188 | )
189 |
190 | return attn_energy
191 |
192 | def expected_alignment_train(self, p_choose, key_padding_mask):
193 | """
194 | Calculating expected alignment for MMA
195 |         Mask is not needed because p_choose will be 0 if masked
196 |         q_{i,j} = (1 - p_{i,j-1}) * q_{i,j-1} + a_{i-1,j}
197 |         a_{i,j} = p_{i,j} * q_{i,j}
198 |         Parallel solution:
199 |         a_i = p_i * cumprod(1 - p_i) * cumsum(a_{i-1} / cumprod(1 - p_i))
200 | ============================================================
201 | Expected input size
202 | p_choose: bsz * num_heads, tgt_len, src_len
203 | """
204 |
205 | # p_choose: bsz * num_heads, tgt_len, src_len
206 | bsz_num_heads, tgt_len, src_len = p_choose.size()
207 |
208 | # cumprod_1mp : bsz * num_heads, tgt_len, src_len
209 | cumprod_1mp = exclusive_cumprod(1 - p_choose, dim=2, eps=self.eps)
210 | cumprod_1mp_clamp = torch.clamp(cumprod_1mp, self.eps, 1.0)
211 |
212 | init_attention = p_choose.new_zeros([bsz_num_heads, 1, src_len])
213 | init_attention[:, :, 0] = 1.0
214 |
215 | previous_attn = [init_attention]
216 |
217 | for i in range(tgt_len):
218 | # p_choose: bsz * num_heads, tgt_len, src_len
219 | # cumprod_1mp_clamp : bsz * num_heads, tgt_len, src_len
220 | # previous_attn[i]: bsz * num_heads, 1, src_len
221 | # alpha_i: bsz * num_heads, src_len
222 | alpha_i = (
223 | p_choose[:, i]
224 | * cumprod_1mp[:, i]
225 | * torch.cumsum(previous_attn[i][:, 0] / cumprod_1mp_clamp[:, i], dim=1)
226 | ).clamp(0, 1.0)
227 | previous_attn.append(alpha_i.unsqueeze(1))
228 |
229 | # alpha: bsz * num_heads, tgt_len, src_len
230 | alpha = torch.cat(previous_attn[1:], dim=1)
231 |
232 | if self.mass_preservation:
233 | # Last token has the residual probabilities
234 | alpha[:, :, -1] = 1 - alpha[:, :, :-1].sum(dim=-1).clamp(0.0, 1.0)
235 |
236 | assert not torch.isnan(alpha).any(), "NaN detected in alpha."
237 |
238 | return alpha
239 |
240 | def expected_attention(
241 | self, alpha, query, key, value, key_padding_mask, incremental_state
242 | ):
243 | # monotonic attention, we will calculate milk here
244 | bsz_x_num_heads, tgt_len, src_len = alpha.size()
245 | bsz = int(bsz_x_num_heads / self.num_heads)
246 |
247 | q, k, _ = self.input_projections(query, key, None, "soft")
248 | soft_energy = self.attn_energy(q, k, key_padding_mask)
249 |
250 | assert list(soft_energy.size()) == [bsz, self.num_heads, tgt_len, src_len]
251 |
252 | soft_energy = soft_energy.view(bsz * self.num_heads, tgt_len, src_len)
253 |
254 | if incremental_state is not None:
255 | monotonic_cache = self._get_monotonic_buffer(incremental_state)
256 | monotonic_step = monotonic_cache["step"] + 1
257 | step_offset = 0
258 | if key_padding_mask is not None:
259 | if key_padding_mask[:, 0].any():
260 | # left_pad_source = True:
261 | step_offset = key_padding_mask.sum(dim=-1, keepdim=True)
262 | monotonic_step += step_offset
263 | mask = lengths_to_mask(
264 | monotonic_step.view(-1), soft_energy.size(2), 1
265 | ).unsqueeze(1)
266 |
267 | soft_energy = soft_energy.masked_fill(~mask.bool(), float("-inf"))
268 | soft_energy = soft_energy - soft_energy.max(dim=2, keepdim=True)[0]
269 | exp_soft_energy = torch.exp(soft_energy)
270 | exp_soft_energy_sum = exp_soft_energy.sum(dim=2)
271 | beta = exp_soft_energy / exp_soft_energy_sum.unsqueeze(2)
272 |
273 | else:
274 | # bsz * num_heads, tgt_len, src_len
275 | soft_energy = soft_energy - soft_energy.max(dim=2, keepdim=True)[0]
276 | exp_soft_energy = torch.exp(soft_energy)
277 | exp_soft_energy_cumsum = torch.cumsum(exp_soft_energy, dim=2)
278 |
279 | if key_padding_mask is not None:
280 | if key_padding_mask.any():
281 | exp_soft_energy_cumsum = (
282 | exp_soft_energy_cumsum.view(
283 | -1, self.num_heads, tgt_len, src_len
284 | )
285 | .masked_fill(
286 | key_padding_mask.unsqueeze(1).unsqueeze(1), self.eps
287 | )
288 | .view(-1, tgt_len, src_len)
289 | )
290 |
291 | inner_items = alpha / exp_soft_energy_cumsum
292 |
293 | beta = exp_soft_energy * torch.cumsum(
294 | inner_items.flip(dims=[2]), dim=2
295 | ).flip(dims=[2])
296 |
297 | beta = self.dropout_module(beta)
298 |
299 | assert not torch.isnan(beta).any(), "NaN detected in beta."
300 |
301 | return beta
302 |
303 | def forward(
304 | self,
305 | query,
306 | key,
307 | value,
308 | key_padding_mask=None,
309 | incremental_state=None,
310 | *args,
311 | **kwargs,
312 | ):
313 |
314 | tgt_len, bsz, embed_dim = query.size()
315 | src_len = value.size(0)
316 |
317 | # stepwise prob
318 | # p_choose: bsz * self.num_heads, tgt_len, src_len
319 | p_choose = self.p_choose(query, key, key_padding_mask)
320 |
321 | # expected alignment alpha
322 | # bsz * self.num_heads, tgt_len, src_len
323 |
324 | alpha = self.expected_alignment_train(p_choose, key_padding_mask)
325 |
326 | # expected attention beta
327 | # bsz * self.num_heads, tgt_len, src_len
328 | beta = self.expected_attention(
329 | alpha, query, key, value, key_padding_mask, incremental_state
330 | )
331 |
332 | attn_weights = beta
333 |
334 | _, _, v_proj = self.input_projections(None, None, value, "output")
335 | attn = torch.bmm(attn_weights.type_as(v_proj), v_proj)
336 |
337 | attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
338 |
339 | attn = self.out_proj(attn)
340 |
341 | beta = beta.view(bsz, self.num_heads, tgt_len, src_len)
342 | alpha = alpha.view(bsz, self.num_heads, tgt_len, src_len)
343 | p_choose = p_choose.view(bsz, self.num_heads, tgt_len, src_len)
344 |
345 | return attn, alpha, beta, p_choose
346 |
347 |
348 | def exclusive_cumprod(tensor, dim: int, eps: float = 1e-10):
349 | """
350 | Implementing exclusive cumprod.
351 | There is cumprod in pytorch, however there is no exclusive mode.
352 | cumprod(x) = [x1, x1x2, x1x2x3, ..., prod_{i=1}^n x_i]
353 | exclusive means cumprod(x) = [1, x1, x1x2, x1x2x3, ..., prod_{i=1}^{n-1} x_i]
354 | """
355 | tensor_size = list(tensor.size())
356 | tensor_size[dim] = 1
357 | return_tensor = safe_cumprod(
358 | torch.cat([torch.ones(tensor_size).type_as(tensor), tensor], dim=dim),
359 | dim=dim,
360 | eps=eps,
361 | )
362 |
363 | if dim == 0:
364 | return return_tensor[:-1]
365 | elif dim == 1:
366 | return return_tensor[:, :-1]
367 | elif dim == 2:
368 | return return_tensor[:, :, :-1]
369 | else:
370 | raise RuntimeError("Cumprod on dimension 3 and more is not implemented")
371 |
372 |
373 | def safe_cumprod(tensor, dim: int, eps: float = 1e-10):
374 | """
375 | An implementation of cumprod to prevent precision issue.
376 | cumprod(x)
377 | = [x1, x1x2, x1x2x3, ....]
378 | = [exp(log(x1)), exp(log(x1) + log(x2)), exp(log(x1) + log(x2) + log(x3)), ...]
379 | = exp(cumsum(log(x)))
380 | """
381 |
382 | if (tensor + eps < 0).any().item():
383 | raise RuntimeError(
384 | "Safe cumprod can only take non-negative tensors as input."
385 | "Consider use torch.cumprod if you want to calculate negative values."
386 | )
387 |
388 | log_tensor = torch.log(tensor + eps)
389 | cumsum_log_tensor = torch.cumsum(log_tensor, dim)
390 | exp_cumsum_log_tensor = torch.exp(cumsum_log_tensor)
391 | return exp_cumsum_log_tensor
392 |
393 |
394 | def lengths_to_mask(lengths, max_len: int, dim: int = 0, negative_mask: bool = False):
395 | """
396 | Convert a tensor of lengths to mask
397 | For example, lengths = [[2, 3, 4]], max_len = 5
398 | mask =
399 | [[1, 1, 1],
400 | [1, 1, 1],
401 | [0, 1, 1],
402 | [0, 0, 1],
403 | [0, 0, 0]]
404 | """
405 | assert len(lengths.size()) <= 2
406 | if len(lengths.size()) == 2:
407 | if dim == 1:
408 | lengths = lengths.t()
409 | lengths = lengths
410 | else:
411 | lengths = lengths.unsqueeze(1)
412 |
413 | # lengths : batch_size, 1
414 | lengths = lengths.view(-1, 1)
415 |
416 | batch_size = lengths.size(0)
417 | # batch_size, max_len
418 | mask = torch.arange(max_len).expand(batch_size, max_len).type_as(lengths) < lengths
419 |
420 | if negative_mask:
421 | mask = ~mask
422 |
423 | if dim == 0:
424 | # max_len, batch_size
425 | mask = mask.t()
426 |
427 | return mask
428 |
429 |
430 | def moving_sum(x, start_idx: int, end_idx: int):
431 | """
432 | From MONOTONIC CHUNKWISE ATTENTION
433 | https://arxiv.org/pdf/1712.05382.pdf
434 | Equation (18)
435 | x = [x_1, x_2, ..., x_N]
436 | MovingSum(x, start_idx, end_idx)_n = Sigma_{m=n−(start_idx−1)}^{n+end_idx-1} x_m
437 | for n in {1, 2, 3, ..., N}
438 | x : src_len, batch_size
439 | start_idx : start idx
440 | end_idx : end idx
441 | Example
442 | src_len = 5
443 | batch_size = 3
444 | x =
445 | [[ 0, 5, 10],
446 | [ 1, 6, 11],
447 | [ 2, 7, 12],
448 | [ 3, 8, 13],
449 | [ 4, 9, 14]]
450 | MovingSum(x, 3, 1) =
451 | [[ 0, 5, 10],
452 | [ 1, 11, 21],
453 | [ 3, 18, 33],
454 | [ 6, 21, 36],
455 | [ 9, 24, 39]]
456 | MovingSum(x, 1, 3) =
457 | [[ 3, 18, 33],
458 | [ 6, 21, 36],
459 | [ 9, 24, 39],
460 | [ 7, 17, 27],
461 | [ 4, 9, 14]]
462 | """
463 | assert start_idx > 0 and end_idx > 0
464 | assert len(x.size()) == 2
465 | src_len, batch_size = x.size()
466 | # batch_size, 1, src_len
467 | x = x.t().unsqueeze(1)
468 | # 1, 1, end_idx + start_idx - 1
469 | moving_sum_weight = x.new_ones([1, 1, end_idx + start_idx - 1])
470 |
471 | moving_sum = (
472 | torch.nn.functional.conv1d(
473 | x, moving_sum_weight, padding=start_idx + end_idx - 1
474 | )
475 | .squeeze(1)
476 | .t()
477 | )
478 | moving_sum = moving_sum[end_idx:-start_idx]
479 |
480 | assert src_len == moving_sum.size(0)
481 | assert batch_size == moving_sum.size(1)
482 |
483 | return moving_sum
484 |
--------------------------------------------------------------------------------
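The `moving_sum` helper above implements the MoChA windowed sum (Eq. 18 of Chiu & Raffel, arXiv:1712.05382) as a padded 1-D convolution with an all-ones kernel: padding by `start_idx + end_idx - 1` and then cropping `[end_idx:-start_idx]` aligns output position n with the window `[n - start_idx + 1, n + end_idx - 1]`. The sketch below mirrors that trick in a self-contained form (nothing is imported from the repository; `moving_sum_conv1d` and `moving_sum_loop` are illustrative names), reproducing the docstring example and cross-checking an arbitrary window against a brute-force loop.

```python
import torch
import torch.nn.functional as F


def moving_sum_conv1d(x, start_idx, end_idx):
    # Same trick as `moving_sum` above: an all-ones kernel of width
    # start_idx + end_idx - 1, padded and then cropped so that output n
    # sums x[n - start_idx + 1 : n + end_idx] (clipped to the sequence bounds).
    src_len, batch_size = x.size()
    x = x.t().unsqueeze(1)                                   # (batch, 1, src_len)
    weight = x.new_ones(1, 1, start_idx + end_idx - 1)
    out = F.conv1d(x, weight, padding=start_idx + end_idx - 1).squeeze(1).t()
    return out[end_idx:-start_idx]


def moving_sum_loop(x, start_idx, end_idx):
    # Brute-force reference: sum each window explicitly.
    src_len = x.size(0)
    rows = []
    for n in range(src_len):
        lo, hi = max(0, n - (start_idx - 1)), min(src_len, n + end_idx)
        rows.append(x[lo:hi].sum(dim=0))
    return torch.stack(rows)


# The (src_len=5, batch_size=3) example from the docstring above.
x = torch.arange(15.0).view(3, 5).t()
expected = torch.tensor([[0., 5., 10.],
                         [1., 11., 21.],
                         [3., 18., 33.],
                         [6., 21., 36.],
                         [9., 24., 39.]])
assert torch.allclose(moving_sum_conv1d(x, 3, 1), expected)
assert torch.allclose(moving_sum_conv1d(x, 1, 3), moving_sum_loop(x, 1, 3))
```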
/src/model/monotonic_attention00.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import math
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | logger = logging.getLogger(__name__)
9 | class MonotonicEnergy(nn.Module):
10 | def __init__(
11 | self,
12 | kdim,
13 | qdim,
14 | adim,
15 | atype,
16 | n_heads,
17 | init_r,
18 | bias=True,
19 | param_init="",
20 | conv1d=False,
21 | conv_kernel_size=5,
22 | ):
23 | """Energy function for the monotonic attenion.
24 | Args:
25 | kdim (int): dimension of key
26 | qdim (int): dimension of query
27 | adim (int): dimension of attention space
28 | atype (str): type of attention mechanism
29 | n_heads (int): number of monotonic attention heads
30 | init_r (int): initial value for offset r
31 | bias (bool): use bias term in linear layers
32 | param_init (str): parameter initialization method
33 | conv1d (bool): use 1D causal convolution for energy calculation
34 | conv_kernel_size (int): kernel size for 1D convolution
35 | """
36 | super().__init__()
37 |
38 | assert conv_kernel_size % 2 == 1, "Kernel size should be odd for 'same' conv."
39 | self.key = None
40 | self.mask = None
41 |
42 | self.atype = atype
43 | assert adim % n_heads == 0
44 | self.d_k = adim // n_heads
45 | self.n_heads = n_heads
46 | self.scale = math.sqrt(adim)
47 |
48 | if atype == "add":
49 | self.w_key = nn.Linear(kdim, adim)
50 | self.v = nn.Linear(adim, n_heads, bias=False)
51 | self.w_query = nn.Linear(qdim, adim, bias=False)
52 | elif atype == "scaled_dot":
53 | self.w_key = nn.Linear(kdim, adim, bias=bias)
54 | self.w_query = nn.Linear(qdim, adim, bias=bias)
55 | else:
56 | raise NotImplementedError(atype)
57 |
58 | self.r = nn.Parameter(torch.Tensor([init_r]))
59 | logger.info("init_r is initialized with %d" % init_r)
60 |
61 | self.conv1d = None
62 | if conv1d:
63 | self.conv1d = CausalConv1d(
64 | in_channels=kdim,
65 | out_channels=kdim,
66 | kernel_size=conv_kernel_size,
67 | param_init=param_init,
68 | )
69 | # padding=(conv_kernel_size - 1) // 2
70 |
71 | if atype == "add":
72 | self.v = nn.utils.weight_norm(self.v, name="weight", dim=0)
73 | # initialization
74 | self.v.weight_g.data = torch.Tensor([1 / adim]).sqrt()
75 | elif atype == "scaled_dot":
76 | if param_init == "xavier_uniform":
77 | self.reset_parameters(bias)
78 |
79 | def reset_parameters(self, bias):
80 | """Initialize parameters with Xavier uniform distribution."""
81 | logger.info(
82 | "===== Initialize %s with Xavier uniform distribution ====="
83 | % self.__class__.__name__
84 | )
85 | # NOTE: see https://github.com/pytorch/fairseq/blob/master/fairseq/modules/multihead_attention.py
86 | nn.init.xavier_uniform_(self.w_key.weight, gain=1 / math.sqrt(2))
87 | nn.init.xavier_uniform_(self.w_query.weight, gain=1 / math.sqrt(2))
88 | if bias:
89 | nn.init.constant_(self.w_key.bias, 0.0)
90 | nn.init.constant_(self.w_query.bias, 0.0)
91 |
92 | def reset(self):
93 | self.key = None
94 | self.mask = None
95 |
96 | def forward(self, key, query, mask, cache=False, boundary_leftmost=0):
97 | """Compute monotonic energy.
98 | Args:
99 | key (FloatTensor): `[B, klen, kdim]`
100 | query (FloatTensor): `[B, qlen, qdim]`
101 | mask (ByteTensor): `[B, qlen, klen]`
102 | cache (bool): cache key and mask
103 | Returns:
104 | e (FloatTensor): `[B, H_ma, qlen, klen]`
105 | """
106 | bs, klen, kdim = key.size()
107 | qlen = query.size(1)
108 |
109 | # Pre-computation of encoder-side features for computing scores
110 | if self.key is None or not cache:
111 | # 1d conv
112 | if self.conv1d is not None:
113 | key = torch.relu(self.conv1d(key))
114 | key = self.w_key(key).view(bs, -1, self.n_heads, self.d_k)
115 | self.key = key.transpose(2, 1).contiguous() # `[B, H_ma, klen, d_k]`
116 | self.mask = mask
117 | if mask is not None:
118 | self.mask = self.mask.unsqueeze(1).repeat(
119 | [1, self.n_heads, 1, 1]
120 | ) # `[B, H_ma, qlen, klen]`
121 | assert self.mask.size() == (bs, self.n_heads, qlen, klen), (
122 | self.mask.size(),
123 | (bs, self.n_heads, qlen, klen),
124 | )
125 |
126 | query = self.w_query(query).view(bs, -1, self.n_heads, self.d_k)
127 | query = query.transpose(2, 1).contiguous() # `[B, H_ma, qlen, d_k]`
128 | m = self.mask
129 |
130 | if self.atype == "add":
131 | k = self.key.unsqueeze(2) # `[B, H_ma, 1, klen, d_k]`
132 | # Truncate encoder memories
133 | if boundary_leftmost > 0:
134 | k = k[:, :, :, boundary_leftmost:]
135 | klen = k.size(3)
136 | if m is not None:
137 | m = m[:, :, :, boundary_leftmost:]
138 | e = torch.relu(k + query.unsqueeze(3)) # `[B, H_ma, qlen, klen, d_k]`
139 | e = e.permute(0, 2, 3, 1, 4).contiguous().view(bs, qlen, klen, -1)
140 | e = self.v(e).permute(0, 3, 1, 2) # `[B, H_ma, qlen, klen]`
141 | elif self.atype == "scaled_dot":
142 | k = self.key.transpose(3, 2)
143 | e = torch.matmul(query, k) / self.scale
144 |
145 | if self.r is not None:
146 | e = e + self.r
147 | if m is not None:
148 | NEG_INF = float(np.finfo(torch.tensor(0, dtype=e.dtype).numpy().dtype).min)
149 | e = e.masked_fill_(m == 0, NEG_INF)
150 | assert e.size() == (bs, self.n_heads, qlen, klen), (
151 | e.size(),
152 | (bs, self.n_heads, qlen, klen),
153 | )
154 | return e
155 |
156 |
157 | class ChunkEnergy(nn.Module):
158 | def __init__(self, kdim, qdim, adim, atype, n_heads=1, bias=True, param_init=""):
159 | """Energy function for the chunkwise attention.
160 | Args:
161 | kdim (int): dimension of key
162 | qdim (int): dimension of query
163 | adim (int): dimension of attention space
164 | atype (str): type of attention mechanism
165 | n_heads (int): number of chunkwise attention heads
166 | bias (bool): use bias term in linear layers
167 | param_init (str): parameter initialization method
168 | """
169 | super().__init__()
170 |
171 | self.key = None
172 | self.mask = None
173 |
174 | self.atype = atype
175 | assert adim % n_heads == 0
176 | self.d_k = adim // n_heads
177 | self.n_heads = n_heads
178 | self.scale = math.sqrt(adim)
179 |
180 | if atype == "add":
181 | self.w_key = nn.Linear(kdim, adim)
182 | self.w_query = nn.Linear(qdim, adim, bias=False)
183 | self.v = nn.Linear(adim, n_heads, bias=False)
184 | elif atype == "scaled_dot":
185 | self.w_key = nn.Linear(kdim, adim, bias=bias)
186 | self.w_query = nn.Linear(qdim, adim, bias=bias)
187 | if param_init == "xavier_uniform":
188 | self.reset_parameters(bias)
189 | else:
190 | raise NotImplementedError(atype)
191 |
192 | def reset_parameters(self, bias):
193 | """Initialize parameters with Xavier uniform distribution."""
194 | logger.info(
195 | "===== Initialize %s with Xavier uniform distribution ====="
196 | % self.__class__.__name__
197 | )
198 | # NOTE: see https://github.com/pytorch/fairseq/blob/master/fairseq/modules/multihead_attention.py
199 | nn.init.xavier_uniform_(self.w_key.weight, gain=1 / math.sqrt(2))
200 | nn.init.xavier_uniform_(self.w_query.weight, gain=1 / math.sqrt(2))
201 | if bias:
202 | nn.init.constant_(self.w_key.bias, 0.0)
203 | nn.init.constant_(self.w_query.bias, 0.0)
204 |
205 | def reset(self):
206 | self.key = None
207 | self.mask = None
208 |
209 | def forward(
210 | self,
211 | key,
212 | query,
213 | mask,
214 | cache=False,
215 | boundary_leftmost=0,
216 | boundary_rightmost=int(10e6),
217 | ):
218 | """Compute chunkwise energy.
219 | Args:
220 | key (FloatTensor): `[B, klen, kdim]`
221 | query (FloatTensor): `[B, qlen, qdim]`
222 | mask (ByteTensor): `[B, qlen, klen]`
223 | cache (bool): cache key and mask
224 | Returns:
225 | e (FloatTensor): `[B, H_ca, qlen, klen]`
226 | """
227 | bs, klen, kdim = key.size()
228 | qlen = query.size(1)
229 |
230 | # Pre-computation of encoder-side features for computing scores
231 | if self.key is None or not cache:
232 | key = self.w_key(key).view(bs, -1, self.n_heads, self.d_k)
233 | self.key = key.transpose(2, 1).contiguous() # `[B, H_ca, klen, d_k]`
234 | self.mask = mask
235 | if mask is not None:
236 | self.mask = self.mask.unsqueeze(1).repeat(
237 | [1, self.n_heads, 1, 1]
238 | ) # `[B, H_ca, qlen, klen]`
239 | assert self.mask.size() == (bs, self.n_heads, qlen, klen), (
240 | self.mask.size(),
241 | (bs, self.n_heads, qlen, klen),
242 | )
243 |
244 | query = self.w_query(query).view(bs, -1, self.n_heads, self.d_k)
245 | query = query.transpose(2, 1).contiguous() # `[B, H_ca, qlen, d_k]`
246 | m = self.mask
247 |
248 | if self.atype == "add":
249 | k = self.key.unsqueeze(2) # `[B, H_ca, 1, klen, d_k]`
250 | # Truncate
251 | k = k[:, :, :, boundary_leftmost:boundary_rightmost]
252 | klen = k.size(3)
253 | if m is not None:
254 | m = m[:, :, :, boundary_leftmost:boundary_rightmost]
255 |
256 | r = torch.relu(k + query.unsqueeze(3)) # `[B, H_ca, qlen, klen, d_k]`
257 | r = (
258 | r.permute(0, 2, 3, 1, 4).contiguous().view(bs, qlen, klen, -1)
259 | ) # `[B, qlen, klen, H_ca * d_k]`
260 | r = self.v(r).permute(0, 3, 1, 2).contiguous() # `[B, H_ca, qlen, klen]`
261 | elif self.atype == "scaled_dot":
262 | k = self.key.transpose(3, 2)
263 | r = torch.matmul(query, k) / self.scale
264 |
265 | if m is not None:
266 | NEG_INF = float(np.finfo(torch.tensor(0, dtype=r.dtype).numpy().dtype).min)
267 | r = r.masked_fill_(m == 0, NEG_INF)
268 | assert r.size() == (bs, self.n_heads, qlen, klen), (
269 | r.size(),
270 | (bs, self.n_heads, qlen, klen),
271 | )
272 | return r
273 |
274 |
275 | class MonotonicAttention(nn.Module):
276 | def __init__(self, embed_dim, n_heads, dropout=0.1):
277 | super(MonotonicAttention, self).__init__()
278 |
279 | self.n_heads = n_heads
280 |
281 | self.monotonic_energy = MonotonicEnergy(
282 | kdim=embed_dim,
283 | qdim=embed_dim,
284 | adim=embed_dim,
285 | atype="scaled_dot",
286 | n_heads=n_heads,
287 | init_r=0,
288 | param_init="xavier_uniform",
289 | )
290 |
291 | self.chunk_energy = ChunkEnergy(
292 | kdim=embed_dim,
293 | qdim=embed_dim,
294 | adim=embed_dim,
295 | atype="scaled_dot",
296 | n_heads=n_heads,
297 | param_init="xavier_uniform",
298 | )
299 | self.dropout_attn = nn.Dropout(p=dropout) # for beta
300 | self.bd_offset = 0
301 |
302 | def forward(self, key, value, query, aw_prev=None, mask=None, cache=None):
303 |
304 | bs, klen = key.size()[:2]
305 | qlen = query.size(1)
306 |
307 | if aw_prev is None:
308 | # aw_prev = [1, 0, 0 ... 0]
309 | aw_prev = key.new_zeros(bs, self.n_heads, 1, klen)
310 | aw_prev[:, :, :, 0:1] = key.new_ones(bs, self.n_heads, 1, 1)
311 |
312 | # Compute monotonic energy
313 | e_ma = self.monotonic_energy(
314 | key, query, mask, cache=cache, boundary_leftmost=self.bd_offset
315 | ) # `[B, H_ma, qlen, klen]`
316 |
317 | p_choose = torch.sigmoid(
318 | add_gaussian_noise(e_ma, self.noise_std)
319 | ) # `[B, H_ma, qlen, klen]`
320 |
321 | # safe_cumprod computes cumprod in logspace with numeric checks
322 | cumprod_1mp_choose = safe_cumprod(
323 | 1 - p_choose, eps=self.eps
324 | ) # `[B, H_ma, qlen, klen]`
325 |
326 | # Compute recurrence relation solution
327 | alpha = []
328 | for i in range(qlen):
329 | denom = (
330 | 1
331 | if self.no_denom
332 | else torch.clamp(
333 | cumprod_1mp_choose[:, :, i : i + 1], min=self.eps, max=1.0
334 | )
335 | )
336 | aw_prev = (
337 | p_choose[:, :, i : i + 1]
338 | * cumprod_1mp_choose[:, :, i : i + 1]
339 | * torch.cumsum(aw_prev / denom, dim=-1)
340 | ) # `[B, H_ma, 1, klen]`
341 | # Mask the right part from the trigger point
342 | alpha.append(aw_prev)
343 |
344 | alpha = (
345 | torch.cat(alpha, dim=2) if qlen > 1 else alpha[-1]
346 | ) # `[B, H_ma, qlen, klen]`
347 | alpha_masked = alpha.clone()
348 |
349 | # Compute chunk energy
350 | e_ca = self.chunk_energy(
351 | key,
352 | query,
353 | mask,
354 | ) # `[B, (H_ma*)H_ca, qlen, klen]`
355 |
356 | beta = efficient_chunkwise_attention(
357 | alpha_masked,
358 | e_ca,
359 | mask,
360 | chunk_size=-1,
361 | n_heads_chunk=self.n_heads,
362 | sharpening_factor=1.0,
363 | share_chunkwise_attention=True,
364 | )
365 | beta = self.dropout_attn(beta)
366 |
367 | value = self.w_value(value).view(
368 | bs, -1, self.n_heads_ma * self.n_heads_ca, self.d_k
369 | )
370 | value = value.transpose(2, 1).contiguous() # `[B, H_ma * H_ca, klen, d_k]`
371 | cv = torch.matmul(
372 | alpha if self.w == 1 else beta, value
373 | ) # `[B, H_ma * H_ca, qlen, d_k]`
374 | cv = (
375 | cv.transpose(2, 1)
376 | .contiguous()
377 | .view(bs, -1, self.n_heads_ma * self.n_heads_ca * self.d_k)
378 | )
379 | cv = self.w_out(cv) # `[B, qlen, adim]`
380 | return cv
381 |
382 | def add_gaussian_noise(xs, std):
383 | """Add Gaussian nosie to encourage discreteness."""
384 | noise = xs.new_zeros(xs.size()).normal_(std=std)
385 | return xs + noise
386 |
387 |
388 | def safe_cumprod(x, eps):
389 | """Numerically stable cumulative product by cumulative sum in log-space.
390 | Args:
391 | x (FloatTensor): `[B, H, qlen, klen]`
392 | Returns:
393 | x (FloatTensor): `[B, H, qlen, klen]`
394 | """
395 | return torch.exp(exclusive_cumsum(torch.log(torch.clamp(x, min=eps, max=1.0))))
396 |
397 |
398 | def exclusive_cumsum(x):
399 | """Exclusive cumulative summation [a, b, c] => [0, a, a + b].
400 | Args:
401 | x (FloatTensor): `[B, H, qlen, klen]`
402 | Returns:
403 | x (FloatTensor): `[B, H, qlen, klen]`
404 | """
405 | return torch.cumsum(
406 | torch.cat(
407 | [x.new_zeros(x.size(0), x.size(1), x.size(2), 1), x[:, :, :, :-1]], dim=-1
408 | ),
409 | dim=-1,
410 | )
411 |
412 |
413 | def exclusive_cumprod(x):
414 | """Exclusive cumulative product [a, b, c] => [1, a, a * b].
415 | Args:
416 | x (FloatTensor): `[B, H, qlen, klen]`
417 | Returns:
418 | x (FloatTensor): `[B, H, qlen, klen]`
419 | """
420 | return torch.cumprod(
421 | torch.cat(
422 | [x.new_ones(x.size(0), x.size(1), x.size(2), 1), x[:, :, :, :-1]], dim=-1
423 | ),
424 | dim=-1,
425 | )
426 |
427 |
428 | def efficient_chunkwise_attention(
429 | alpha,
430 | u,
431 | mask,
432 | chunk_size,
433 | n_heads_chunk,
434 | sharpening_factor,
435 | share_chunkwise_attention,
436 | ):
437 | """Compute chunkwise attention efficiently by clipping logits at training time.
438 | Args:
439 | alpha (FloatTensor): `[B, H_ma, qlen, klen]`
440 | u (FloatTensor): `[B, (H_ma*)H_ca, qlen, klen]`
441 | mask (ByteTensor): `[B, qlen, klen]`
442 | chunk_size (int): window size for chunkwise attention
443 | n_heads_chunk (int): number of chunkwise attention heads
444 | sharpening_factor (float): sharpening factor for beta calculation
445 | share_chunkwise_attention (bool): share CA heads among MA heads
446 | Returns:
447 | beta (FloatTensor): `[B, H_ma * H_ca, qlen, klen]`
448 | """
449 | bs, n_heads_mono, qlen, klen = alpha.size()
450 | alpha = alpha.unsqueeze(2) # `[B, H_ma, 1, qlen, klen]`
451 | u = u.unsqueeze(1) # `[B, 1, (H_ma*)H_ca, qlen, klen]`
452 | if n_heads_chunk > 1:
453 | alpha = alpha.repeat([1, 1, n_heads_chunk, 1, 1])
454 | if n_heads_mono > 1 and not share_chunkwise_attention:
455 | u = u.view(bs, n_heads_mono, n_heads_chunk, qlen, klen)
456 | # Shift logits to avoid overflow
457 | u -= torch.max(u, dim=-1, keepdim=True)[0]
458 | # Limit the range for numerical stability
459 | softmax_exp = torch.clamp(torch.exp(u), min=1e-5)
460 | # Compute chunkwise softmax denominators
461 | if chunk_size == -1:
462 | # infinite lookback attention
463 | # inner_items = alpha * sharpening_factor / torch.cumsum(softmax_exp, dim=-1)
464 | # beta = softmax_exp * torch.cumsum(inner_items.flip(dims=[-1]), dim=-1).flip(dims=[-1])
465 | # beta = beta.masked_fill(mask.unsqueeze(1), 0)
466 | # beta = beta / beta.sum(dim=-1, keepdim=True)
467 |
468 | softmax_denominators = torch.cumsum(softmax_exp, dim=-1)
469 | # Compute \beta_{i, :}. emit_probs are \alpha_{i, :}.
470 | beta = softmax_exp * moving_sum(
471 | alpha * sharpening_factor / softmax_denominators, back=0, forward=klen - 1
472 | )
473 | else:
474 | softmax_denominators = moving_sum(softmax_exp, back=chunk_size - 1, forward=0)
475 | # Compute \beta_{i, :}. emit_probs are \alpha_{i, :}.
476 | beta = softmax_exp * moving_sum(
477 | alpha * sharpening_factor / softmax_denominators,
478 | back=0,
479 | forward=chunk_size - 1,
480 | )
481 | return beta.view(bs, -1, qlen, klen)
482 |
483 |
484 | def moving_sum(x, back, forward):
485 | """Compute the moving sum of x over a chunk_size with the provided bounds.
486 | Args:
487 | x (FloatTensor): `[B, H_ma, H_ca, qlen, klen]`
488 | back (int):
489 | forward (int):
490 | Returns:
491 | x_sum (FloatTensor): `[B, H_ma, H_ca, qlen, klen]`
492 | """
493 | bs, n_heads_mono, n_heads_chunk, qlen, klen = x.size()
494 | x = x.view(-1, klen)
495 | # Moving sum is computed as a carefully-padded 1D convolution with ones
496 | x_padded = F.pad(
497 | x, pad=[back, forward]
498 | ) # `[B * H_ma * H_ca * qlen, back + klen + forward]`
499 | # Add a "channel" dimension
500 | x_padded = x_padded.unsqueeze(1)
501 | # Construct filters
502 | filters = x.new_ones(1, 1, back + forward + 1)
503 | x_sum = F.conv1d(x_padded, filters)
504 | x_sum = x_sum.squeeze(1).view(bs, n_heads_mono, n_heads_chunk, qlen, -1)
505 | return x_sum
506 |
--------------------------------------------------------------------------------
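`safe_cumprod` above evaluates the exclusive cumulative product of `1 - p_choose` in log space (an `exclusive_cumsum` of logs, exponentiated), which keeps the product of many near-zero factors from underflowing in float32. A minimal standalone check, assuming a random `[B, H_ma, qlen, klen]` probability tensor (illustrative names only; nothing is imported from the repository):

```python
import torch


def exclusive_cumsum(x):
    # [a, b, c] -> [0, a, a + b] along the last dimension, as in the file above.
    zeros = x.new_zeros(*x.shape[:-1], 1)
    return torch.cumsum(torch.cat([zeros, x[..., :-1]], dim=-1), dim=-1)


def safe_exclusive_cumprod(x, eps=1e-10):
    # exp(exclusive_cumsum(log x)): an exclusive cumprod evaluated in log space.
    return torch.exp(exclusive_cumsum(torch.log(torch.clamp(x, min=eps, max=1.0))))


p_choose = torch.rand(2, 4, 3, 6)      # stand-in for `[B, H_ma, qlen, klen]` probabilities
one_minus_p = 1.0 - p_choose

# Reference: prepend ones, take an ordinary cumprod, drop the last element.
ones = one_minus_p.new_ones(2, 4, 3, 1)
reference = torch.cumprod(torch.cat([ones, one_minus_p[..., :-1]], dim=-1), dim=-1)

assert torch.allclose(safe_exclusive_cumprod(one_minus_p), reference, atol=1e-4)
```

Up to the `eps` clamp the two forms agree; the log-space version is the numerically stable one, since chains of `1 - p_choose` factors shrink quickly.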
/src/model/positional_encoding.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | class PositionalEncoding_2(nn.Module):
8 | r"""Inject some information about the relative or absolute position of the tokens
9 | in the sequence. The positional encodings have the same dimension as
10 | the embeddings, so that the two can be summed. Here, we use sine and cosine
11 | functions of different frequencies.
12 | .. math::
13 | \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model))
14 | \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model))
15 | \text{where pos is the word position and i is the embed idx}
16 | Args:
17 | d_model: the embed dim (required).
18 | dropout: the dropout value (default=0.1).
19 | max_len: the max. length of the incoming sequence (default=5000).
20 | """
21 |
22 | def __init__(self, d_model, dropout=0.1, max_len=5000):
23 | super(PositionalEncoding_2, self).__init__()
24 | self.dropout = nn.Dropout(p=dropout)
25 |
26 | pe = torch.zeros(max_len, d_model)
27 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
28 | div_term = torch.exp(
29 | torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
30 | )
31 | pe[:, 0::2] = torch.sin(position * div_term)
32 | pe[:, 1::2] = torch.cos(position * div_term)
33 | pe = pe.unsqueeze(0).transpose(0, 1)
34 | self.register_buffer("pe", pe)
35 |
36 | def forward(self, x):
37 | r"""Inputs of forward function
38 | Args:
39 | x: the sequence fed to the positional encoder model (required).
40 | Shape:
41 | x: [sequence length, batch size, embed dim]
42 | output: [sequence length, batch size, embed dim]
43 | Examples:
44 | """
45 |
46 | x = x + self.pe[: x.size(0), :]
47 | return self.dropout(x)
48 |
49 |
50 | class PositionalEncoding(nn.Module):
51 | r"""Inject some information about the relative or absolute position of the tokens
52 | in the sequence. The positional encodings have the same dimension as
53 | the embeddings, so that the two can be summed. Here, we use sine and cosine
54 | functions of different frequencies.
55 | .. math::
56 | \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model))
57 | \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model))
58 | \text{where pos is the word position and i is the embed idx}
59 | Args:
60 | d_model: the embed dim (required).
61 | dropout: the dropout value (default=0.1).
62 | max_len: the max. length of the incoming sequence (default=5000).
63 | """
64 |
65 | def __init__(self, d_model, dropout=0.1, max_len=5000):
66 | super(PositionalEncoding, self).__init__()
67 | self.dropout = nn.Dropout(p=dropout)
68 |
69 | pe = torch.zeros(max_len, d_model)
70 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
71 | div_term = torch.exp(
72 | torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
73 | )
74 | pe[:, 0::2] = torch.sin(position * div_term)
75 | pe[:, 1::2] = torch.cos(position * div_term)
76 | pe = pe.unsqueeze(0).transpose(0, 1)
77 | self.register_buffer("pe", pe)
78 |
79 | def forward(self, x):
80 | r"""Inputs of forward function
81 | Args:
82 | x: the sequence fed to the positional encoder model (required).
83 | Shape:
84 | x: [sequence length, batch size, embed dim]
85 | output: [sequence length, batch size, embed dim]
86 | Examples:
87 | """
88 |
89 | x = x + self.pe[: x.size(0), :]
90 | return self.dropout(x)
91 |
92 |
93 | class PositionalEncoding_00(nn.Module):
94 | r"""Inject some information about the relative or absolute position of the tokens
95 | in the sequence. The positional encodings have the same dimension as
96 | the embeddings, so that the two can be summed. Here, we use sine and cosine
97 | functions of different frequencies.
98 | .. math::
99 | \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model))
100 | \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model))
101 | \text{where pos is the word position and i is the embed idx}
102 | Args:
103 | d_model: the embed dim (required).
104 | max_len: the max. length of the incoming sequence (default=5000).
105 | Note: unlike the classes above, this variant applies no dropout.
106 | """
107 |
108 | def __init__(self, d_model, max_len=5000):
109 | super(PositionalEncoding_00, self).__init__()
110 |
111 | pe = torch.zeros(max_len, d_model)
112 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
113 | div_term = torch.exp(
114 | torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
115 | )
116 | pe[:, 0::2] = torch.sin(position * div_term)
117 | pe[:, 1::2] = torch.cos(position * div_term)
118 | pe = pe.unsqueeze(0).transpose(0, 1)
119 | self.register_buffer("pe", pe)
120 |
121 | def forward(self, stop, step):
122 | r"""Inputs of forward function
123 | Args:
124 | x: the sequence fed to the positional encoder model (required).
125 | Shape:
126 | x: [sequence length, batch size, embed dim]
127 | output: [sequence length, batch size, embed dim]
128 | Examples:
129 | """
130 |
131 | return self.pe[0:stop:step, :]
132 |
133 |
134 | class PositionEmbeddingSine(nn.Module):
135 | """
136 | This is a more standard version of the position embedding, very similar to the one
137 | used by the Attention is all you need paper, generalized to work on images.
138 | """
139 |
140 | def __init__(
141 | self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
142 | ):
143 | super().__init__()
144 | self.num_pos_feats = num_pos_feats
145 | self.temperature = temperature
146 | self.normalize = normalize
147 | if scale is not None and normalize is False:
148 | raise ValueError("normalize should be True if scale is passed")
149 | if scale is None:
150 | scale = 2 * math.pi
151 | self.scale = scale
152 |
153 | def forward(self, x):
154 | """
155 | x: [B, C, h, w]
156 | """
157 | B, C, h, w = x.shape
158 | ones = torch.ones_like(x[:, 0])
159 | y_embed = ones.cumsum(dim=1)
160 | x_embed = ones.cumsum(dim=2)
161 |
162 | if self.normalize:
163 | eps = 1e-6
164 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
165 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
166 |
167 | i = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
168 | dim_t = self.temperature ** (2 * (i // 2) / self.num_pos_feats)
169 |
170 | pos_x = x_embed[:, :, :, None] / dim_t
171 | pos_y = y_embed[:, :, :, None] / dim_t
172 | pos_x = torch.stack(
173 | [pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()], dim=4
174 | ).flatten(3)
175 | pos_y = torch.stack(
176 | (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
177 | ).flatten(3)
178 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
179 | return pos
180 |
181 |
182 | class PositionalEncoding2DwDropout(nn.Module):
183 | def __init__(self, d_model, dropout=0.1, max_len=1000):
184 | super().__init__()
185 |
186 | self.dropout = nn.Dropout(p=dropout)
187 | self.d_model = d_model
188 |
189 | def forward(self, x, height, width):
190 | pe = torch.zeros(self.d_model, height, width, device=x.device)
191 |
192 | # Each dimension use half of d_model
193 | d_model = int(self.d_model / 2)
194 | div_term = torch.exp(
195 | torch.arange(0.0, d_model, 2, device=x.device) * -(math.log(10000.0) / d_model)
196 | )
197 | pos_w = torch.arange(0.0, width // 2, device=x.device)
198 | pos_w = torch.cat([pos_w, pos_w.flip(dims=[0])]).unsqueeze(1)
199 | pos_h = torch.arange(0.0, height, device=x.device).unsqueeze(1)
200 | pe[0:d_model:2, :, :] = (
201 | torch.sin(pos_w * div_term)
202 | .transpose(0, 1)
203 | .unsqueeze(1)
204 | .repeat(1, height, 1)
205 | )
206 | pe[1:d_model:2, :, :] = (
207 | torch.cos(pos_w * div_term)
208 | .transpose(0, 1)
209 | .unsqueeze(1)
210 | .repeat(1, height, 1)
211 | )
212 |
213 | pe[d_model::2, :, :] = (
214 | torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
215 | )
216 | pe[d_model + 1 :: 2, :, :] = (
217 | torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
218 | )
219 |
220 | x = x + pe[None, :, :, :]
221 | return self.dropout(x)
222 |
223 |
224 | class PositionalEncoding3D(nn.Module):
225 | def __init__(self, channels):
226 | """
227 | :param channels: The last dimension of the tensor you want to apply pos emb to.
228 | """
229 | super(PositionalEncoding3D, self).__init__()
230 | channels = int(np.ceil(channels / 3))
231 | if channels % 2:
232 | channels += 1
233 | self.channels = channels
234 | inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels))
235 | self.register_buffer("inv_freq", inv_freq)
236 |
237 | def forward(self, tensor):
238 | """
239 | :param tensor: A 5d tensor of size (batch_size, x, y, z, ch)
240 | :return: Positional Encoding Matrix of size (batch_size, x, y, z, ch)
241 | """
242 | if len(tensor.shape) != 5:
243 | raise RuntimeError("The input tensor has to be 5d!")
244 |
245 | _, x, y, z, orig_ch = tensor.shape
246 |
247 | pos_x = torch.arange(x, device=tensor.device).type(self.inv_freq.type())
248 | pos_y = torch.arange(y, device=tensor.device).type(self.inv_freq.type())
249 | pos_z = torch.arange(z, device=tensor.device).type(self.inv_freq.type())
250 | sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq)
251 | sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq)
252 | sin_inp_z = torch.einsum("i,j->ij", pos_z, self.inv_freq)
253 | emb_x = (
254 | torch.cat((sin_inp_x.sin(), sin_inp_x.cos()), dim=-1)
255 | .unsqueeze(1)
256 | .unsqueeze(1)
257 | )
258 | emb_y = torch.cat((sin_inp_y.sin(), sin_inp_y.cos()), dim=-1).unsqueeze(1)
259 | emb_z = torch.cat((sin_inp_z.sin(), sin_inp_z.cos()), dim=-1)
260 | emb = torch.zeros((x, y, z, self.channels * 3), device=tensor.device).type(
261 | tensor.type()
262 | )
263 | emb[:, :, :, : self.channels] = emb_x
264 | emb[:, :, :, self.channels : 2 * self.channels] = emb_y
265 | emb[:, :, :, 2 * self.channels :] = emb_z
266 |
267 | return emb[None, :, :, :, :orig_ch]
268 |
--------------------------------------------------------------------------------
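Every sinusoidal class above precomputes a single `(max_len, 1, d_model)` table and adds a slice of it to its input. The self-contained sketch below (mirroring the construction in `PositionalEncoding`; the choices of `d_model`, `max_len`, and the spot-checked position are arbitrary) verifies the table against the docstring formula and shows how the size-1 middle dimension broadcasts over the batch of a `[seq_len, batch, d_model]` input.

```python
import math

import torch

# Build the table exactly as the classes above do.
d_model, max_len = 8, 16
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(
    torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)

# Spot-check against PE(pos, 2i) = sin(pos / 10000^(2i / d_model)) and its cos twin.
pos, i = 5, 3
assert abs(pe[pos, 2 * i].item() - math.sin(pos / 10000 ** (2 * i / d_model))) < 1e-5
assert abs(pe[pos, 2 * i + 1].item() - math.cos(pos / 10000 ** (2 * i / d_model))) < 1e-5

# The registered buffer has shape (max_len, 1, d_model), so pe[:seq_len] broadcasts
# over the batch dimension of an input shaped (seq_len, batch, d_model).
pe = pe.unsqueeze(0).transpose(0, 1)
x = torch.randn(10, 4, d_model)
out = x + pe[: x.size(0), :]
assert out.shape == (10, 4, d_model)
```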