├── .gitignore ├── LICENSE ├── README.md ├── config ├── xtrans_scanrefer.yaml └── xtrans_scanrefer_rl.yaml ├── data └── scannet │ ├── meta_data │ ├── nyu40_labels.csv │ ├── referit-labels.tsv │ ├── render_option.json │ ├── scannet_means.npz │ ├── scannet_reference_means.npz │ ├── scannetv2-labels.combined.tsv │ ├── scannetv2.txt │ ├── scannetv2_test.txt │ ├── scannetv2_train.txt │ └── scannetv2_val.txt │ └── model_util_scannet.py ├── figures └── pipeline.png ├── in_out ├── __pycache__ │ └── arguments.cpython-36.pyc └── arguments.py ├── lib ├── ap_helper.py ├── capeval │ ├── bleu │ │ ├── .gitignore │ │ ├── __init__.py │ │ ├── bleu.py │ │ └── bleu_scorer.py │ ├── cider │ │ ├── .gitignore │ │ ├── __init__.py │ │ ├── cider.py │ │ └── cider_scorer.py │ ├── meteor │ │ ├── .gitignore │ │ ├── __init__.py │ │ ├── data │ │ │ └── paraphrase-en.gz │ │ ├── meteor-1.5.jar │ │ └── meteor.py │ └── rouge │ │ ├── .gitignore │ │ ├── __init__.py │ │ └── rouge.py ├── config.py ├── dataset.py ├── eval_helper.py ├── loss_helper.py ├── pointnet2 │ ├── _ext_src │ │ ├── include │ │ │ ├── ball_query.h │ │ │ ├── cuda_utils.h │ │ │ ├── group_points.h │ │ │ ├── interpolate.h │ │ │ ├── sampling.h │ │ │ └── utils.h │ │ └── src │ │ │ ├── ball_query.cpp │ │ │ ├── ball_query_gpu.cu │ │ │ ├── bindings.cpp │ │ │ ├── group_points.cpp │ │ │ ├── group_points_gpu.cu │ │ │ ├── interpolate.cpp │ │ │ ├── interpolate_gpu.cu │ │ │ ├── sampling.cpp │ │ │ └── sampling_gpu.cu │ ├── _version.py │ ├── pointnet2_modules.py │ ├── pointnet2_test.py │ ├── pointnet2_utils.py │ ├── pytorch_utils.py │ └── setup.py ├── reference_dataset.py └── solver.py ├── models ├── backbone_module.py ├── capnet.py ├── proposal_module.py ├── transformer │ ├── __pycache__ │ │ ├── attention.cpython-36.pyc │ │ ├── beam_search.cpython-36.pyc │ │ ├── containers.cpython-36.pyc │ │ ├── decoders.cpython-36.pyc │ │ ├── encoders.cpython-36.pyc │ │ ├── m2_transformer.cpython-36.pyc │ │ └── utils.cpython-36.pyc │ ├── attention.py │ ├── beam_search.py │ ├── containers.py │ ├── decoders.py │ ├── encoders.py │ ├── m2_transformer.py │ └── utils.py ├── utils.py ├── voting_module.py └── xtrans.py ├── scripts ├── eval.py ├── organize.py └── train.py └── utils ├── __init__.py ├── box_util.py ├── eta.py ├── eval_det.py ├── metric_util.py ├── nms.py ├── nn_distance.py └── pc_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # dataset 2 | data/scanrefer* 3 | data/ScanRefer* 4 | data/glove* 5 | data/scannet/scannet_data 6 | data/scannet/scans 7 | data/scannetv2_enet.pth 8 | data/Scan2CAD_dataset/ 9 | data/votenet_estimated_viewpoint* 10 | data/nr3d* 11 | 12 | # cache 13 | __pycache__/ 14 | 15 | # pointnet2 16 | lib/pointnet2/build/ 17 | lib/pointnet2/dist/ 18 | lib/pointnet2/pointnet2.egg-info/ 19 | 20 | # # pretrained models 21 | # pretrained/ 22 | 23 | # logs 24 | logs/ 25 | 26 | # output 27 | outputs/ 28 | archive/ 29 | 30 | # node 31 | viewer/node_modules/ 32 | viewer/server/node_modules/ 33 | 34 | # build 35 | viewer/client/build/ 36 | 37 | # features 38 | gt_features 39 | gt_features/ 40 | data/glove.p 41 | data/scannet/scans 42 | data/scannet/scannet_data 43 | votenet_features 44 | votenet_features/ 45 | 46 | # misc 47 | viewer/server/static/ScanNetv2/ 48 | viewer/server/static/ScanNetv1/scans 49 | viewer/server/static/ScanNetv2/scans 50 | viewer/server/static/ScanNetv2/ScanNet_objects 51 | pretrained_bbox/ 52 | 53 | # Mac 54 | .DS_Store 55 | data/.DS_Store 56 | 57 | # IDE 58 | .vscode 59 | .idea 60 | 
data/2d_feature_agg.npz 61 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # X-Trans2Cap 2 | **[CVPR2022]** X-Trans2Cap: Cross-Modal Knowledge Transfer using Transformer for 3D Dense Captioning [[Arxiv Paper]](https://arxiv.org/abs/2203.00843) 3 | 4 | Zhihao Yuan, [Xu Yan](https://github.com/yanx27), Yinghong Liao, Yao Guo, Guanbin Li, Shuguang Cui, [Zhen Li*](https://mypage.cuhk.edu.cn/academics/lizhen/) 5 | ![](figures/pipeline.png) 6 | 7 | ## Citation 8 | 9 | If you find our work useful in your research, please consider citing: 10 | ```bibtex 11 | @InProceedings{Yuan_2022_CVPR, 12 | author = {Yuan, Zhihao and Yan, Xu and Liao, Yinghong and Guo, Yao and Li, Guanbin and Cui, Shuguang and Li, Zhen}, 13 | title = {X-Trans2Cap: Cross-Modal Knowledge Transfer Using Transformer for 3D Dense Captioning}, 14 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 15 | month = {June}, 16 | year = {2022}, 17 | pages = {8563-8573} 18 | } 19 | ``` 20 | 21 | ## Prerequisites 22 | * Python 3.6.9 (e.g., conda create -n xtrans_env python=3.6.9) 23 | * PyTorch 1.7.1 (e.g., conda install pytorch==1.7.1 cudatoolkit=11.0 -c pytorch) 24 | * Install other common packages (numpy, [transformers](https://huggingface.co/docs/transformers/index), etc.) 25 | 26 | ## Installation 27 | - Clone the repository 28 | 29 | ``` 30 | git clone https://github.com/CurryYuan/X-Trans2Cap.git 31 | ``` 32 | 33 | - To use the PointNet++ visual encoder, compile its CUDA layers for [PointNet++](http://arxiv.org/abs/1706.02413): 34 | ```Note: this compilation also requires gcc 5.4 or later.``` 35 | ``` 36 | cd lib/pointnet2 37 | python setup.py install 38 | ``` 39 | 40 | ## Data 41 | 42 | ### ScanRefer 43 | 44 | If you would like to access the ScanRefer dataset, please fill out [this form](https://forms.gle/aLtzXN12DsYDMSXX6). Once your request is accepted, you will receive an email with the download link. 45 | 46 | > Note: In addition to the language annotations in the ScanRefer dataset, you also need access to the original ScanNet dataset. Please refer to the [ScanNet Instructions](data/scannet/README.md) for more details. 47 | 48 | Download the dataset by simply executing the wget command: 49 | ```shell 50 | wget 51 | ``` 52 | 53 | Run this command to organize the ScanRefer data: 54 | ```bash 55 | python scripts/organize.py 56 | ``` 57 | 58 | ### Processed 2D Features 59 | You can download the processed 2D image features from [OneDrive](https://cuhko365-my.sharepoint.com/:u:/g/personal/221019046_link_cuhk_edu_cn/EYoVKnDvr89OoWstNIK2aDEBWjBmxAovQjg6bP34xZ3j2w?e=zvGRom). The feature extraction code is borrowed from [bottom-up-attention.pytorch](https://github.com/MILVLG/bottom-up-attention.pytorch). 60 | 61 | Change the data paths (e.g., `CONF.PATH.BASE`) in `lib/config.py` to match your setup. 62 | 63 | ## Training 64 | 65 | Run this command to train the model: 66 | 67 | ```bash 68 | python scripts/train.py --config config/xtrans_scanrefer.yaml 69 | ``` 70 | 71 | Run CIDEr optimization: 72 | ```bash 73 | python scripts/train.py --config config/xtrans_scanrefer_rl.yaml 74 | ``` 75 | 76 | Our code also supports training on the Nr3D/Sr3D datasets. Please organize the data in the same way as ScanRefer and change the `dataset` argument in the config file, as shown below.
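For example, to train on Nr3D you can copy `config/xtrans_scanrefer.yaml` and edit its `GENERAL` block (a minimal sketch — the file name `config/xtrans_nr3d.yaml` and the `tag` value are placeholders; per the note above, only the `dataset` entry is the required change):

```yaml
GENERAL:
  tag: xtrans_nr3d        # placeholder experiment tag
  model: xtrans
  dataset: Nr3D           # switched from ScanRefer (see --dataset in in_out/arguments.py)
  dataloader: dataset
  detection: False
  use_gt_ins: True
  mode: gt
  use_tf: True
  num_proposals: 64
  use_rl: False
```

Then train with `python scripts/train.py --config config/xtrans_nr3d.yaml`.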
77 | 78 | ## Evaluation 79 | 80 | ```bash 81 | python scripts/eval.py --config config/xtrans_scanrefer.yaml --use_pretrained xtrans_scanrefer_rl --force 82 | ``` 83 | -------------------------------------------------------------------------------- /config/xtrans_scanrefer.yaml: -------------------------------------------------------------------------------- 1 | GENERAL: 2 | tag: xtrans_scanrefer 3 | model: xtrans 4 | dataset: ScanRefer 5 | dataloader: dataset 6 | detection: False 7 | use_gt_ins: True 8 | mode: gt 9 | use_tf: True 10 | num_proposals: 64 11 | use_rl: False 12 | 13 | DATA: 14 | use_color: True 15 | num_points: 40000 16 | 17 | TRAIN: 18 | batch_size: 32 19 | epoch: 25 20 | use_pretrain: False 21 | lr: 0.0001 22 | -------------------------------------------------------------------------------- /config/xtrans_scanrefer_rl.yaml: -------------------------------------------------------------------------------- 1 | GENERAL: 2 | tag: scanrefer_xtrans_rl 3 | model: xtrans 4 | dataset: ScanRefer 5 | dataloader: dataset 6 | detection: False 7 | use_gt_ins: True 8 | mode: gt 9 | use_tf: True 10 | num_proposals: 64 11 | use_rl: True 12 | pretrained_path: scanrefer_xtrans 13 | 14 | DATA: 15 | use_color: True 16 | num_points: 40000 17 | 18 | TRAIN: 19 | batch_size: 32 20 | epoch: 5 21 | use_pretrain: False 22 | lr: 0.00001 23 | 24 | -------------------------------------------------------------------------------- /data/scannet/meta_data/nyu40_labels.csv: -------------------------------------------------------------------------------- 1 | nyu40id,nyu40class,mappedId,mappedIdConsecutive,weight 2 | 1,wall,(ignore),19,0.0 3 | 2,floor,(ignore),19,0.0 4 | 3,cabinet,3,1,3.9644974086960434 5 | 4,bed,4,2,5.459494152836571 6 | 5,chair,5,3,2.241522691584157 7 | 6,sofa,6,4,4.820655512680854 8 | 7,table,7,5,3.565918577548873 9 | 8,door,8,6,3.538498341919445 10 | 9,window,9,7,4.636521236560596 11 | 10,bookshelf,10,8,5.445050937449535 12 | 11,picture,11,9,5.079250281008131 13 | 12,counter,12,10,6.2030429647735845 14 | 13,blinds,(ignore),19,0.0 15 | 14,desk,14,11,4.622662494840168 16 | 15,shelves,(ignore),19,0.0 17 | 16,curtain,16,12,5.956294301248057 18 | 17,dresser,(ignore),19,0.0 19 | 18,pillow,(ignore),19,0.0 20 | 19,mirror,(ignore),19,0.0 21 | 20,floor_mat,(ignore),19,0.0 22 | 21,clothes,(ignore),19,0.0 23 | 22,ceiling,(ignore),19,0.0 24 | 23,books,(ignore),19,0.0 25 | 24,refridgerator,24,13,5.459141107819665 26 | 25,television,(ignore),19,0.0 27 | 26,paper,(ignore),19,0.0 28 | 27,towel,(ignore),19,0.0 29 | 28,shower_curtain,28,14,6.724871661883906 30 | 29,box,(ignore),19,0.0 31 | 30,whiteboard,(ignore),19,0.0 32 | 31,person,(ignore),19,0.0 33 | 32,night_stand,(ignore),19,0.0 34 | 33,toilet,33,15,5.832442848923174 35 | 34,sink,34,16,5.064773947290611 36 | 35,lamp,(ignore),19,0.0 37 | 36,bathtub,36,17,6.738988357113375 38 | 37,bag,(ignore),19,0.0 39 | 38,otherstructure,(ignore),19,0.0 40 | 39,otherfurniture,39,18,3.375217918833916 41 | 40,otherprop,(ignore),19,0.0 -------------------------------------------------------------------------------- /data/scannet/meta_data/render_option.json: -------------------------------------------------------------------------------- 1 | { 2 | "background_color" : [ 1, 1, 1 ], 3 | "class_name" : "RenderOption", 4 | "default_mesh_color" : [ 0.69999999999999996, 0.69999999999999996, 0.69999999999999996 ], 5 | "image_max_depth" : 3000, 6 | "image_stretch_option" : 0, 7 | "interpolation_option" : 0, 8 | "light0_color" : [ 1, 1, 1 ], 9 | "light0_diffuse_power" : 
0.66000000000000003, 10 | "light0_position" : [ 0, 0, 2 ], 11 | "light0_specular_power" : 0.20000000000000001, 12 | "light0_specular_shininess" : 100, 13 | "light1_color" : [ 1, 1, 1 ], 14 | "light1_diffuse_power" : 0.66000000000000003, 15 | "light1_position" : [ 0, 0, 2 ], 16 | "light1_specular_power" : 0.20000000000000001, 17 | "light1_specular_shininess" : 100, 18 | "light2_color" : [ 1, 1, 1 ], 19 | "light2_diffuse_power" : 0.66000000000000003, 20 | "light2_position" : [ 0, 0, -2 ], 21 | "light2_specular_power" : 0.20000000000000001, 22 | "light2_specular_shininess" : 100, 23 | "light3_color" : [ 1, 1, 1 ], 24 | "light3_diffuse_power" : 0.66000000000000003, 25 | "light3_position" : [ 0, 0, -2 ], 26 | "light3_specular_power" : 0.20000000000000001, 27 | "light3_specular_shininess" : 100, 28 | "light_ambient_color" : [ 0, 0, 0 ], 29 | "light_on" : true, 30 | "mesh_color_option" : 1, 31 | "mesh_shade_option" : 0, 32 | "mesh_show_back_face" : false, 33 | "mesh_show_wireframe" : false, 34 | "point_color_option" : 9, 35 | "point_show_normal" : false, 36 | "point_size" : 5, 37 | "show_coordinate_frame" : false, 38 | "version_major" : 1, 39 | "version_minor" : 0 40 | } -------------------------------------------------------------------------------- /data/scannet/meta_data/scannet_means.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CurryYuan/X-Trans2Cap/aebe6e2d421034f2de8742fa9946e669bcb497e6/data/scannet/meta_data/scannet_means.npz -------------------------------------------------------------------------------- /data/scannet/meta_data/scannet_reference_means.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CurryYuan/X-Trans2Cap/aebe6e2d421034f2de8742fa9946e669bcb497e6/data/scannet/meta_data/scannet_reference_means.npz -------------------------------------------------------------------------------- /data/scannet/meta_data/scannetv2_test.txt: -------------------------------------------------------------------------------- 1 | scene0707_00 2 | scene0708_00 3 | scene0709_00 4 | scene0710_00 5 | scene0711_00 6 | scene0712_00 7 | scene0713_00 8 | scene0714_00 9 | scene0715_00 10 | scene0716_00 11 | scene0717_00 12 | scene0718_00 13 | scene0719_00 14 | scene0720_00 15 | scene0721_00 16 | scene0722_00 17 | scene0723_00 18 | scene0724_00 19 | scene0725_00 20 | scene0726_00 21 | scene0727_00 22 | scene0728_00 23 | scene0729_00 24 | scene0730_00 25 | scene0731_00 26 | scene0732_00 27 | scene0733_00 28 | scene0734_00 29 | scene0735_00 30 | scene0736_00 31 | scene0737_00 32 | scene0738_00 33 | scene0739_00 34 | scene0740_00 35 | scene0741_00 36 | scene0742_00 37 | scene0743_00 38 | scene0744_00 39 | scene0745_00 40 | scene0746_00 41 | scene0747_00 42 | scene0748_00 43 | scene0749_00 44 | scene0750_00 45 | scene0751_00 46 | scene0752_00 47 | scene0753_00 48 | scene0754_00 49 | scene0755_00 50 | scene0756_00 51 | scene0757_00 52 | scene0758_00 53 | scene0759_00 54 | scene0760_00 55 | scene0761_00 56 | scene0762_00 57 | scene0763_00 58 | scene0764_00 59 | scene0765_00 60 | scene0766_00 61 | scene0767_00 62 | scene0768_00 63 | scene0769_00 64 | scene0770_00 65 | scene0771_00 66 | scene0772_00 67 | scene0773_00 68 | scene0774_00 69 | scene0775_00 70 | scene0776_00 71 | scene0777_00 72 | scene0778_00 73 | scene0779_00 74 | scene0780_00 75 | scene0781_00 76 | scene0782_00 77 | scene0783_00 78 | scene0784_00 79 | scene0785_00 80 | scene0786_00 81 | scene0787_00 82 | 
scene0788_00 83 | scene0789_00 84 | scene0790_00 85 | scene0791_00 86 | scene0792_00 87 | scene0793_00 88 | scene0794_00 89 | scene0795_00 90 | scene0796_00 91 | scene0797_00 92 | scene0798_00 93 | scene0799_00 94 | scene0800_00 95 | scene0801_00 96 | scene0802_00 97 | scene0803_00 98 | scene0804_00 99 | scene0805_00 100 | scene0806_00 101 | -------------------------------------------------------------------------------- /data/scannet/meta_data/scannetv2_val.txt: -------------------------------------------------------------------------------- 1 | scene0011_00 2 | scene0011_01 3 | scene0015_00 4 | scene0019_00 5 | scene0019_01 6 | scene0025_00 7 | scene0025_01 8 | scene0025_02 9 | scene0030_00 10 | scene0030_01 11 | scene0030_02 12 | scene0046_00 13 | scene0046_01 14 | scene0046_02 15 | scene0050_00 16 | scene0050_01 17 | scene0050_02 18 | scene0063_00 19 | scene0064_00 20 | scene0064_01 21 | scene0077_00 22 | scene0077_01 23 | scene0081_00 24 | scene0081_01 25 | scene0081_02 26 | scene0084_00 27 | scene0084_01 28 | scene0084_02 29 | scene0086_00 30 | scene0086_01 31 | scene0086_02 32 | scene0088_00 33 | scene0088_01 34 | scene0088_02 35 | scene0088_03 36 | scene0095_00 37 | scene0095_01 38 | scene0100_00 39 | scene0100_01 40 | scene0100_02 41 | scene0131_00 42 | scene0131_01 43 | scene0131_02 44 | scene0139_00 45 | scene0144_00 46 | scene0144_01 47 | scene0146_00 48 | scene0146_01 49 | scene0146_02 50 | scene0149_00 51 | scene0153_00 52 | scene0153_01 53 | scene0164_00 54 | scene0164_01 55 | scene0164_02 56 | scene0164_03 57 | scene0169_00 58 | scene0169_01 59 | scene0187_00 60 | scene0187_01 61 | scene0193_00 62 | scene0193_01 63 | scene0196_00 64 | scene0203_00 65 | scene0203_01 66 | scene0203_02 67 | scene0207_00 68 | scene0207_01 69 | scene0207_02 70 | scene0208_00 71 | scene0217_00 72 | scene0221_00 73 | scene0221_01 74 | scene0222_00 75 | scene0222_01 76 | scene0231_00 77 | scene0231_01 78 | scene0231_02 79 | scene0246_00 80 | scene0249_00 81 | scene0251_00 82 | scene0256_00 83 | scene0256_01 84 | scene0256_02 85 | scene0257_00 86 | scene0277_00 87 | scene0277_01 88 | scene0277_02 89 | scene0278_00 90 | scene0278_01 91 | scene0300_00 92 | scene0300_01 93 | scene0304_00 94 | scene0307_00 95 | scene0307_01 96 | scene0307_02 97 | scene0314_00 98 | scene0316_00 99 | scene0328_00 100 | scene0329_00 101 | scene0329_01 102 | scene0329_02 103 | scene0334_00 104 | scene0334_01 105 | scene0334_02 106 | scene0338_00 107 | scene0338_01 108 | scene0338_02 109 | scene0342_00 110 | scene0343_00 111 | scene0351_00 112 | scene0351_01 113 | scene0353_00 114 | scene0353_01 115 | scene0353_02 116 | scene0354_00 117 | scene0355_00 118 | scene0355_01 119 | scene0356_00 120 | scene0356_01 121 | scene0356_02 122 | scene0357_00 123 | scene0357_01 124 | scene0377_00 125 | scene0377_01 126 | scene0377_02 127 | scene0378_00 128 | scene0378_01 129 | scene0378_02 130 | scene0382_00 131 | scene0382_01 132 | scene0389_00 133 | scene0406_00 134 | scene0406_01 135 | scene0406_02 136 | scene0412_00 137 | scene0412_01 138 | scene0414_00 139 | scene0423_00 140 | scene0423_01 141 | scene0423_02 142 | scene0426_00 143 | scene0426_01 144 | scene0426_02 145 | scene0426_03 146 | scene0427_00 147 | scene0430_00 148 | scene0430_01 149 | scene0432_00 150 | scene0432_01 151 | scene0435_00 152 | scene0435_01 153 | scene0435_02 154 | scene0435_03 155 | scene0441_00 156 | scene0458_00 157 | scene0458_01 158 | scene0461_00 159 | scene0462_00 160 | scene0474_00 161 | scene0474_01 162 | scene0474_02 163 | scene0474_03 164 | 
scene0474_04 165 | scene0474_05 166 | scene0488_00 167 | scene0488_01 168 | scene0490_00 169 | scene0494_00 170 | scene0496_00 171 | scene0500_00 172 | scene0500_01 173 | scene0518_00 174 | scene0527_00 175 | scene0535_00 176 | scene0549_00 177 | scene0549_01 178 | scene0550_00 179 | scene0552_00 180 | scene0552_01 181 | scene0553_00 182 | scene0553_01 183 | scene0553_02 184 | scene0558_00 185 | scene0558_01 186 | scene0558_02 187 | scene0559_00 188 | scene0559_01 189 | scene0559_02 190 | scene0565_00 191 | scene0568_00 192 | scene0568_01 193 | scene0568_02 194 | scene0574_00 195 | scene0574_01 196 | scene0574_02 197 | scene0575_00 198 | scene0575_01 199 | scene0575_02 200 | scene0578_00 201 | scene0578_01 202 | scene0578_02 203 | scene0580_00 204 | scene0580_01 205 | scene0583_00 206 | scene0583_01 207 | scene0583_02 208 | scene0591_00 209 | scene0591_01 210 | scene0591_02 211 | scene0593_00 212 | scene0593_01 213 | scene0595_00 214 | scene0598_00 215 | scene0598_01 216 | scene0598_02 217 | scene0599_00 218 | scene0599_01 219 | scene0599_02 220 | scene0606_00 221 | scene0606_01 222 | scene0606_02 223 | scene0607_00 224 | scene0607_01 225 | scene0608_00 226 | scene0608_01 227 | scene0608_02 228 | scene0609_00 229 | scene0609_01 230 | scene0609_02 231 | scene0609_03 232 | scene0616_00 233 | scene0616_01 234 | scene0618_00 235 | scene0621_00 236 | scene0629_00 237 | scene0629_01 238 | scene0629_02 239 | scene0633_00 240 | scene0633_01 241 | scene0643_00 242 | scene0644_00 243 | scene0645_00 244 | scene0645_01 245 | scene0645_02 246 | scene0647_00 247 | scene0647_01 248 | scene0648_00 249 | scene0648_01 250 | scene0651_00 251 | scene0651_01 252 | scene0651_02 253 | scene0652_00 254 | scene0653_00 255 | scene0653_01 256 | scene0655_00 257 | scene0655_01 258 | scene0655_02 259 | scene0658_00 260 | scene0660_00 261 | scene0663_00 262 | scene0663_01 263 | scene0663_02 264 | scene0664_00 265 | scene0664_01 266 | scene0664_02 267 | scene0665_00 268 | scene0665_01 269 | scene0670_00 270 | scene0670_01 271 | scene0671_00 272 | scene0671_01 273 | scene0678_00 274 | scene0678_01 275 | scene0678_02 276 | scene0684_00 277 | scene0684_01 278 | scene0685_00 279 | scene0685_01 280 | scene0685_02 281 | scene0686_00 282 | scene0686_01 283 | scene0686_02 284 | scene0689_00 285 | scene0690_00 286 | scene0690_01 287 | scene0693_00 288 | scene0693_01 289 | scene0693_02 290 | scene0695_00 291 | scene0695_01 292 | scene0695_02 293 | scene0695_03 294 | scene0696_00 295 | scene0696_01 296 | scene0696_02 297 | scene0697_00 298 | scene0697_01 299 | scene0697_02 300 | scene0697_03 301 | scene0699_00 302 | scene0700_00 303 | scene0700_01 304 | scene0700_02 305 | scene0701_00 306 | scene0701_01 307 | scene0701_02 308 | scene0702_00 309 | scene0702_01 310 | scene0702_02 311 | scene0704_00 312 | scene0704_01 313 | -------------------------------------------------------------------------------- /data/scannet/model_util_scannet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from: https://github.com/facebookresearch/votenet/blob/master/scannet/model_util_scannet.py 3 | """ 4 | 5 | import numpy as np 6 | import sys 7 | import os 8 | 9 | import torch 10 | 11 | sys.path.append(os.path.join(os.getcwd(), os.pardir, "lib")) # HACK add the lib folder 12 | from lib.config import CONF 13 | 14 | 15 | def in_hull(p, hull): 16 | from scipy.spatial import Delaunay 17 | if not isinstance(hull, Delaunay): 18 | hull = Delaunay(hull) 19 | return hull.find_simplex(p) >= 0 20 | 21 | 22 | 
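# Illustrative usage sketch of in_hull() (hypothetical helper, not called anywhere in this repo):
# in_hull() builds a scipy Delaunay triangulation of the hull points and marks a query point as
# inside whenever find_simplex() assigns it to a simplex (index >= 0).
def _demo_in_hull():
    # eight corners of the unit cube serve as the hull
    cube_corners = np.array([[x, y, z] for x in (0.0, 1.0) for y in (0.0, 1.0) for z in (0.0, 1.0)])
    query_points = np.array([
        [0.5, 0.5, 0.5],   # inside the cube  -> True
        [0.2, 0.8, 0.1],   # inside the cube  -> True
        [1.5, 0.5, 0.5],   # outside the cube -> False
    ])
    return in_hull(query_points, cube_corners)  # -> array([ True,  True, False])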
def extract_pc_in_box3d(pc, box3d): 23 | ''' pc: (N,3), box3d: (8,3) ''' 24 | box3d_roi_inds = in_hull(pc[:, 0:3], box3d) 25 | return pc[box3d_roi_inds, :], box3d_roi_inds 26 | 27 | 28 | def rotate_aligned_boxes(input_boxes, rot_mat): 29 | centers, lengths = input_boxes[:, 0:3], input_boxes[:, 3:6] 30 | new_centers = np.dot(centers, np.transpose(rot_mat)) 31 | 32 | dx, dy = lengths[:, 0] / 2.0, lengths[:, 1] / 2.0 33 | new_x = np.zeros((dx.shape[0], 4)) 34 | new_y = np.zeros((dx.shape[0], 4)) 35 | 36 | for i, crnr in enumerate([(-1, -1), (1, -1), (1, 1), (-1, 1)]): 37 | crnrs = np.zeros((dx.shape[0], 3)) 38 | crnrs[:, 0] = crnr[0] * dx 39 | crnrs[:, 1] = crnr[1] * dy 40 | crnrs = np.dot(crnrs, np.transpose(rot_mat)) 41 | new_x[:, i] = crnrs[:, 0] 42 | new_y[:, i] = crnrs[:, 1] 43 | 44 | new_dx = 2.0 * np.max(new_x, 1) 45 | new_dy = 2.0 * np.max(new_y, 1) 46 | new_lengths = np.stack((new_dx, new_dy, lengths[:, 2]), axis=1) 47 | 48 | return np.concatenate([new_centers, new_lengths], axis=1) 49 | 50 | 51 | def rotate_aligned_boxes_along_axis(input_boxes, rot_mat, axis): 52 | centers, lengths = input_boxes[:, 0:3], input_boxes[:, 3:6] 53 | new_centers = np.dot(centers, np.transpose(rot_mat)) 54 | 55 | if axis == "x": 56 | d1, d2 = lengths[:, 1] / 2.0, lengths[:, 2] / 2.0 57 | elif axis == "y": 58 | d1, d2 = lengths[:, 0] / 2.0, lengths[:, 2] / 2.0 59 | else: 60 | d1, d2 = lengths[:, 0] / 2.0, lengths[:, 1] / 2.0 61 | 62 | new_1 = np.zeros((d1.shape[0], 4)) 63 | new_2 = np.zeros((d1.shape[0], 4)) 64 | 65 | for i, crnr in enumerate([(-1, -1), (1, -1), (1, 1), (-1, 1)]): 66 | crnrs = np.zeros((d1.shape[0], 3)) 67 | crnrs[:, 0] = crnr[0] * d1 68 | crnrs[:, 1] = crnr[1] * d2 69 | crnrs = np.dot(crnrs, np.transpose(rot_mat)) 70 | new_1[:, i] = crnrs[:, 0] 71 | new_2[:, i] = crnrs[:, 1] 72 | 73 | new_d1 = 2.0 * np.max(new_1, 1) 74 | new_d2 = 2.0 * np.max(new_2, 1) 75 | 76 | if axis == "x": 77 | new_lengths = np.stack((lengths[:, 0], new_d1, new_d2), axis=1) 78 | elif axis == "y": 79 | new_lengths = np.stack((new_d1, lengths[:, 1], new_d2), axis=1) 80 | else: 81 | new_lengths = np.stack((new_d1, new_d2, lengths[:, 2]), axis=1) 82 | 83 | return np.concatenate([new_centers, new_lengths], axis=1) 84 | 85 | 86 | class ScannetDatasetConfig(object): 87 | def __init__(self): 88 | self.type2class = {'cabinet': 0, 'bed': 1, 'chair': 2, 'sofa': 3, 'table': 4, 'door': 5, 89 | 'window': 6, 'bookshelf': 7, 'picture': 8, 'counter': 9, 'desk': 10, 'curtain': 11, 90 | 'refrigerator': 12, 'shower curtain': 13, 'toilet': 14, 'sink': 15, 'bathtub': 16, 91 | 'others': 17} 92 | self.class2type = {self.type2class[t]: t for t in self.type2class} 93 | 94 | self.nyu40ids = np.array( 95 | [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 23, 24, 25, 26, 27, 28, 29, 30, 31, 96 | 32, 33, 34, 35, 36, 37, 38, 39, 40]) # exclude wall (1), floor (2), ceiling (22) 97 | self.nyu40id2class = self._get_nyu40id2class() 98 | self.mean_size_arr = np.load(os.path.join(CONF.PATH.SCANNET, 'meta_data/scannet_reference_means.npz'))['arr_0'] 99 | 100 | self.num_class = len(self.type2class.keys()) 101 | self.num_heading_bin = 1 102 | self.num_size_cluster = len(self.type2class.keys()) 103 | 104 | self.type_mean_size = {} 105 | for i in range(self.num_size_cluster): 106 | self.type_mean_size[self.class2type[i]] = self.mean_size_arr[i, :] 107 | 108 | def _get_nyu40id2class(self): 109 | lines = [line.rstrip() for line in 110 | open(os.path.join(CONF.PATH.SCANNET, 'meta_data/scannetv2-labels.combined.tsv'))] 111 | lines = 
lines[1:] 112 | nyu40ids2class = {} 113 | for i in range(len(lines)): 114 | label_classes_set = set(self.type2class.keys()) 115 | elements = lines[i].split('\t') 116 | nyu40_id = int(elements[4]) 117 | nyu40_name = elements[7] 118 | if nyu40_id in self.nyu40ids: 119 | if nyu40_name not in label_classes_set: 120 | nyu40ids2class[nyu40_id] = self.type2class["others"] 121 | else: 122 | nyu40ids2class[nyu40_id] = self.type2class[nyu40_name] 123 | 124 | return nyu40ids2class 125 | 126 | def angle2class(self, angle): 127 | ''' Convert continuous angle to discrete class 128 | [optional] also returns a small regression number from 129 | the class center angle to the current angle. 130 | 131 | angle is from 0-2pi (or -pi~pi), class center at 0, 1*(2pi/N), 2*(2pi/N) ... (N-1)*(2pi/N) 132 | return is class of int32 of 0,1,...,N-1 and a number such that 133 | class*(2pi/N) + number = angle 134 | 135 | NOT USED. 136 | ''' 137 | assert (False) 138 | 139 | def class2angle(self, pred_cls, residual, to_label_format=True): 140 | ''' Inverse function to angle2class. 141 | 142 | As ScanNet only has axis-aligned boxes, angles are always 0. ''' 143 | return 0 144 | 145 | def class2angle_batch(self, pred_cls, residual, to_label_format=True): 146 | ''' Inverse function to angle2class. 147 | 148 | As ScanNet only has axis-aligned boxes, angles are always 0. ''' 149 | return np.zeros(pred_cls.shape[0]) 150 | 151 | def size2class(self, size, type_name): 152 | ''' Convert 3D box size (l,w,h) to size class and size residual ''' 153 | size_class = self.type2class[type_name] 154 | size_residual = size - self.type_mean_size[type_name] 155 | return size_class, size_residual 156 | 157 | def class2size(self, pred_cls, residual): 158 | ''' Inverse function to size2class ''' 159 | return self.mean_size_arr[pred_cls] + residual 160 | 161 | def class2size_batch(self, pred_cls, residual): 162 | ''' Inverse function to size2class ''' 163 | return self.mean_size_arr[pred_cls] + residual 164 | 165 | def param2obb(self, center, heading_class, heading_residual, size_class, size_residual): 166 | heading_angle = self.class2angle(heading_class, heading_residual) 167 | box_size = self.class2size(int(size_class), size_residual) 168 | obb = np.zeros((7,)) 169 | obb[0:3] = center 170 | obb[3:6] = box_size 171 | obb[6] = heading_angle * -1 172 | return obb 173 | 174 | def param2obb_batch(self, center, heading_class, heading_residual, size_class, size_residual): 175 | heading_angle = self.class2angle_batch(heading_class, heading_residual) 176 | box_size = self.class2size_batch(size_class, size_residual) 177 | obb = np.zeros((heading_class.shape[0], 7)) 178 | obb[:, 0:3] = center 179 | obb[:, 3:6] = box_size 180 | obb[:, 6] = heading_angle * -1 181 | return obb 182 | 183 | def param2obb_torch(self, center, heading_class, heading_residual, size_class, size_residual): 184 | device = center.device 185 | heading_angle = torch.zeros(heading_class.shape[0], device=device) 186 | 187 | mean_size_arr = torch.tensor(self.mean_size_arr, device=device) 188 | box_size = mean_size_arr[size_class] + size_residual 189 | obb = torch.zeros((heading_class.shape[0], 7), device=device) 190 | obb[:, 0:3] = center 191 | obb[:, 3:6] = box_size 192 | obb[:, 6] = heading_angle * -1 193 | return obb 194 | -------------------------------------------------------------------------------- /figures/pipeline.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/CurryYuan/X-Trans2Cap/aebe6e2d421034f2de8742fa9946e669bcb497e6/figures/pipeline.png -------------------------------------------------------------------------------- /in_out/__pycache__/arguments.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CurryYuan/X-Trans2Cap/aebe6e2d421034f2de8742fa9946e669bcb497e6/in_out/__pycache__/arguments.cpython-36.pyc -------------------------------------------------------------------------------- /in_out/arguments.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import yaml 4 | 5 | 6 | def parse_arguments(): 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--config_file', type=str, help='config file') 9 | 10 | parser.add_argument("--tag", type=str, help="tag for the training, e.g. cuda_wl", default="") 11 | parser.add_argument("--dataset", type=str, help="Choose a dataset: ScanRefer or Nr3D", default="ScanRefer") 12 | parser.add_argument("--gpu", type=str, help="gpu", default="0") 13 | parser.add_argument("--seed", type=int, default=400, help="random seed") 14 | 15 | parser.add_argument("--batch_size", type=int, help="batch size", default=8) 16 | parser.add_argument("--epoch", type=int, help="number of epochs", default=20) 17 | parser.add_argument("--verbose", type=int, help="iterations of showing verbose", default=40) 18 | parser.add_argument("--val_step", type=int, help="iterations of validating", default=2000) 19 | parser.add_argument("--lr", type=float, help="learning rate", default=1e-3) 20 | parser.add_argument("--wd", type=float, help="weight decay", default=1e-5) 21 | 22 | parser.add_argument("--num_points", type=int, default=40000, help="Point Number [default: 40000]") 23 | parser.add_argument("--num_proposals", type=int, default=256, help="Proposal number [default: 256]") 24 | parser.add_argument("--num_locals", type=int, default=-1, help="Number of local objects [default: -1]") 25 | parser.add_argument("--num_scenes", type=int, default=-1, help="Number of scenes [default: -1]") 26 | 27 | parser.add_argument("--criterion", type=str, default="cider", 28 | help="criterion for selecting the best model [choices: bleu-1, bleu-2, bleu-3, bleu-4, cider, rouge, meteor, sum]") 29 | 30 | parser.add_argument("--no_height", action="store_true", help="Do NOT use height signal in input.") 31 | parser.add_argument("--no_augment", action="store_true", help="Do NOT use data augmentation on the input.") 32 | parser.add_argument("--no_caption", action="store_true", help="Do NOT train the caption module.") 33 | 34 | parser.add_argument("--use_tf", action="store_true", help="enable teacher forcing in inference.") 35 | parser.add_argument("--use_color", action="store_true", help="Use RGB color in input.") 36 | parser.add_argument("--use_normal", action="store_true", help="Use normals in input.") 37 | parser.add_argument("--use_pretrained", type=str, 38 | help="Specify the folder name containing the pretrained detection module.") 39 | 40 | parser.add_argument("--use_checkpoint", type=str, help="Specify the checkpoint root", default="") 41 | parser.add_argument("--debug", action="store_true", help="Debug mode.") 42 | 43 | parser.add_argument("--use_train", action="store_true", help="Use train split in evaluation.") 44 | parser.add_argument("--use_last", action="store_true", help="Use the last model") 45 | parser.add_argument("--force", action="store_true", help="generate the
results by force") 46 | parser.add_argument("--save_interm", action="store_true", help="Save the intermediate results") 47 | parser.add_argument("--min_iou", type=float, default=0.25, help="Min IoU threshold for evaluation") 48 | 49 | 50 | args = parser.parse_args() 51 | 52 | with open(args.config_file, 'r') as fin: 53 | configs_dict = yaml.load(fin, Loader=yaml.FullLoader) 54 | apply_configs(args, configs_dict) 55 | 56 | return args 57 | 58 | def apply_configs(args, config_dict): 59 | for key in config_dict: 60 | for k, v in config_dict[key].items(): 61 | setattr(args, k, v) -------------------------------------------------------------------------------- /lib/capeval/bleu/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ -------------------------------------------------------------------------------- /lib/capeval/bleu/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CurryYuan/X-Trans2Cap/aebe6e2d421034f2de8742fa9946e669bcb497e6/lib/capeval/bleu/__init__.py -------------------------------------------------------------------------------- /lib/capeval/bleu/bleu.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : bleu.py 4 | # 5 | # Description : Wrapper for BLEU scorer. 6 | # 7 | # Creation Date : 06-01-2015 8 | # Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT 9 | # Authors : Hao Fang and Tsung-Yi Lin 10 | 11 | from .bleu_scorer import BleuScorer 12 | 13 | 14 | class Bleu: 15 | def __init__(self, n=4): 16 | # default compute Blue score up to 4 17 | self._n = n 18 | self._hypo_for_image = {} 19 | self.ref_for_image = {} 20 | 21 | def compute_score(self, gts, res): 22 | 23 | assert(gts.keys() == res.keys()) 24 | imgIds = gts.keys() 25 | 26 | bleu_scorer = BleuScorer(n=self._n) 27 | for id in imgIds: 28 | hypo = res[id] 29 | ref = gts[id] 30 | 31 | # Sanity check. 32 | assert(type(hypo) is list) 33 | assert(len(hypo) >= 1) 34 | assert(type(ref) is list) 35 | assert(len(ref) >= 1) 36 | 37 | bleu_scorer += (hypo[0], ref) 38 | 39 | #score, scores = bleu_scorer.compute_score(option='shortest') 40 | score, scores = bleu_scorer.compute_score(option='closest', verbose=0) 41 | #score, scores = bleu_scorer.compute_score(option='average', verbose=1) 42 | 43 | # return (bleu, bleu_info) 44 | return score, scores 45 | 46 | def method(self): 47 | return "Bleu" 48 | -------------------------------------------------------------------------------- /lib/capeval/bleu/bleu_scorer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # bleu_scorer.py 4 | # David Chiang 5 | 6 | # Copyright (c) 2004-2006 University of Maryland. All rights 7 | # reserved. Do not redistribute without permission from the 8 | # author. Not for commercial use. 9 | 10 | # Modified by: 11 | # Hao Fang 12 | # Tsung-Yi Lin 13 | 14 | '''Provides: 15 | cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test(). 16 | cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked(). 17 | ''' 18 | 19 | import copy 20 | import sys, math, re 21 | from collections import defaultdict 22 | 23 | def precook(s, n=4, out=False): 24 | """Takes a string as input and returns an object that can be given to 25 | either cook_refs or cook_test. 
This is optional: cook_refs and cook_test 26 | can take string arguments as well.""" 27 | words = s.split() 28 | counts = defaultdict(int) 29 | for k in range(1,n+1): 30 | for i in range(len(words)-k+1): 31 | ngram = tuple(words[i:i+k]) 32 | counts[ngram] += 1 33 | return (len(words), counts) 34 | 35 | def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "average" 36 | '''Takes a list of reference sentences for a single segment 37 | and returns an object that encapsulates everything that BLEU 38 | needs to know about them.''' 39 | 40 | reflen = [] 41 | maxcounts = {} 42 | for ref in refs: 43 | rl, counts = precook(ref, n) 44 | reflen.append(rl) 45 | for (ngram,count) in counts.items(): 46 | maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 47 | 48 | # Calculate effective reference sentence length. 49 | if eff == "shortest": 50 | reflen = min(reflen) 51 | elif eff == "average": 52 | reflen = float(sum(reflen))/len(reflen) 53 | 54 | ## lhuang: N.B.: leave reflen computaiton to the very end!! 55 | 56 | ## lhuang: N.B.: in case of "closest", keep a list of reflens!! (bad design) 57 | 58 | return (reflen, maxcounts) 59 | 60 | def cook_test(test, refs, eff=None, n=4): 61 | '''Takes a test sentence and returns an object that 62 | encapsulates everything that BLEU needs to know about it.''' 63 | 64 | reflen, refmaxcounts = refs 65 | testlen, counts = precook(test, n, True) 66 | 67 | result = {} 68 | 69 | # Calculate effective reference sentence length. 70 | 71 | if eff == "closest": 72 | result["reflen"] = min((abs(l-testlen), l) for l in reflen)[1] 73 | else: ## i.e., "average" or "shortest" or None 74 | result["reflen"] = reflen 75 | 76 | result["testlen"] = testlen 77 | 78 | result["guess"] = [max(0,testlen-k+1) for k in range(1,n+1)] 79 | 80 | result['correct'] = [0]*n 81 | for (ngram, count) in counts.items(): 82 | result["correct"][len(ngram)-1] += min(refmaxcounts.get(ngram,0), count) 83 | 84 | return result 85 | 86 | class BleuScorer(object): 87 | """Bleu scorer. 88 | """ 89 | 90 | __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen" 91 | # special_reflen is used in oracle (proportional effective ref len for a node). 
92 | 93 | def copy(self): 94 | ''' copy the refs.''' 95 | new = BleuScorer(n=self.n) 96 | new.ctest = copy.copy(self.ctest) 97 | new.crefs = copy.copy(self.crefs) 98 | new._score = None 99 | return new 100 | 101 | def __init__(self, test=None, refs=None, n=4, special_reflen=None): 102 | ''' singular instance ''' 103 | 104 | self.n = n 105 | self.crefs = [] 106 | self.ctest = [] 107 | self.cook_append(test, refs) 108 | self.special_reflen = special_reflen 109 | 110 | def cook_append(self, test, refs): 111 | '''called by constructor and __iadd__ to avoid creating new instances.''' 112 | 113 | if refs is not None: 114 | self.crefs.append(cook_refs(refs)) 115 | if test is not None: 116 | cooked_test = cook_test(test, self.crefs[-1]) 117 | self.ctest.append(cooked_test) ## N.B.: -1 118 | else: 119 | self.ctest.append(None) # lens of crefs and ctest have to match 120 | 121 | self._score = None ## need to recompute 122 | 123 | def ratio(self, option=None): 124 | self.compute_score(option=option) 125 | return self._ratio 126 | 127 | def score_ratio(self, option=None): 128 | '''return (bleu, len_ratio) pair''' 129 | return (self.fscore(option=option), self.ratio(option=option)) 130 | 131 | def score_ratio_str(self, option=None): 132 | return "%.4f (%.2f)" % self.score_ratio(option) 133 | 134 | def reflen(self, option=None): 135 | self.compute_score(option=option) 136 | return self._reflen 137 | 138 | def testlen(self, option=None): 139 | self.compute_score(option=option) 140 | return self._testlen 141 | 142 | def retest(self, new_test): 143 | if type(new_test) is str: 144 | new_test = [new_test] 145 | assert len(new_test) == len(self.crefs), new_test 146 | self.ctest = [] 147 | for t, rs in zip(new_test, self.crefs): 148 | self.ctest.append(cook_test(t, rs)) 149 | self._score = None 150 | 151 | return self 152 | 153 | def rescore(self, new_test): 154 | ''' replace test(s) with new test(s), and returns the new score.''' 155 | 156 | return self.retest(new_test).compute_score() 157 | 158 | def size(self): 159 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest)) 160 | return len(self.crefs) 161 | 162 | def __iadd__(self, other): 163 | '''add an instance (e.g., from another sentence).''' 164 | 165 | if type(other) is tuple: 166 | ## avoid creating new BleuScorer instances 167 | self.cook_append(other[0], other[1]) 168 | else: 169 | assert self.compatible(other), "incompatible BLEUs." 
170 | self.ctest.extend(other.ctest) 171 | self.crefs.extend(other.crefs) 172 | self._score = None ## need to recompute 173 | 174 | return self 175 | 176 | def compatible(self, other): 177 | return isinstance(other, BleuScorer) and self.n == other.n 178 | 179 | def single_reflen(self, option="average"): 180 | return self._single_reflen(self.crefs[0][0], option) 181 | 182 | def _single_reflen(self, reflens, option=None, testlen=None): 183 | 184 | if option == "shortest": 185 | reflen = min(reflens) 186 | elif option == "average": 187 | reflen = float(sum(reflens))/len(reflens) 188 | elif option == "closest": 189 | reflen = min((abs(l-testlen), l) for l in reflens)[1] 190 | else: 191 | assert False, "unsupported reflen option %s" % option 192 | 193 | return reflen 194 | 195 | def recompute_score(self, option=None, verbose=0): 196 | self._score = None 197 | return self.compute_score(option, verbose) 198 | 199 | def compute_score(self, option=None, verbose=0): 200 | n = self.n 201 | small = 1e-9 202 | tiny = 1e-15 ## so that if guess is 0 still return 0 203 | bleu_list = [[] for _ in range(n)] 204 | 205 | if self._score is not None: 206 | return self._score 207 | 208 | if option is None: 209 | option = "average" if len(self.crefs) == 1 else "closest" 210 | 211 | self._testlen = 0 212 | self._reflen = 0 213 | totalcomps = {'testlen':0, 'reflen':0, 'guess':[0]*n, 'correct':[0]*n} 214 | 215 | # for each sentence 216 | for comps in self.ctest: 217 | testlen = comps['testlen'] 218 | self._testlen += testlen 219 | 220 | if self.special_reflen is None: ## need computation 221 | reflen = self._single_reflen(comps['reflen'], option, testlen) 222 | else: 223 | reflen = self.special_reflen 224 | 225 | self._reflen += reflen 226 | 227 | for key in ['guess','correct']: 228 | for k in range(n): 229 | totalcomps[key][k] += comps[key][k] 230 | 231 | # append per image bleu score 232 | bleu = 1. 233 | for k in range(n): 234 | bleu *= (float(comps['correct'][k]) + tiny) \ 235 | /(float(comps['guess'][k]) + small) 236 | bleu_list[k].append(bleu ** (1./(k+1))) 237 | ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division 238 | if ratio < 1: 239 | for k in range(n): 240 | bleu_list[k][-1] *= math.exp(1 - 1/ratio) 241 | 242 | if verbose > 1: 243 | print(comps, reflen) 244 | 245 | totalcomps['reflen'] = self._reflen 246 | totalcomps['testlen'] = self._testlen 247 | 248 | bleus = [] 249 | bleu = 1. 
250 | for k in range(n): 251 | bleu *= float(totalcomps['correct'][k] + tiny) \ 252 | / (totalcomps['guess'][k] + small) 253 | bleus.append(bleu ** (1./(k+1))) 254 | ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division 255 | if ratio < 1: 256 | for k in range(n): 257 | bleus[k] *= math.exp(1 - 1/ratio) 258 | 259 | if verbose > 0: 260 | print(totalcomps) 261 | print("ratio:", ratio) 262 | 263 | self._score = bleus 264 | return self._score, bleu_list 265 | -------------------------------------------------------------------------------- /lib/capeval/cider/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ -------------------------------------------------------------------------------- /lib/capeval/cider/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CurryYuan/X-Trans2Cap/aebe6e2d421034f2de8742fa9946e669bcb497e6/lib/capeval/cider/__init__.py -------------------------------------------------------------------------------- /lib/capeval/cider/cider.py: -------------------------------------------------------------------------------- 1 | # Filename: cider.py 2 | # 3 | # Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric 4 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726) 5 | # 6 | # Creation Date: Sun Feb 8 14:16:54 2015 7 | # 8 | # Authors: Ramakrishna Vedantam and Tsung-Yi Lin 9 | 10 | from .cider_scorer import CiderScorer 11 | import pdb 12 | 13 | class Cider: 14 | """ 15 | Main Class to compute the CIDEr metric 16 | 17 | """ 18 | def __init__(self, test=None, refs=None, n=4, sigma=6.0): 19 | # set cider to sum over 1 to 4-grams 20 | self._n = n 21 | # set the standard deviation parameter for gaussian penalty 22 | self._sigma = sigma 23 | 24 | def compute_score(self, gts, res): 25 | """ 26 | Main function to compute CIDEr score 27 | :param hypo_for_image (dict) : dictionary with key and value 28 | ref_for_image (dict) : dictionary with key and value 29 | :return: cider (float) : computed CIDEr score for the corpus 30 | """ 31 | 32 | assert(gts.keys() == res.keys()) 33 | imgIds = gts.keys() 34 | 35 | cider_scorer = CiderScorer(n=self._n, sigma=self._sigma) 36 | 37 | for id in imgIds: 38 | hypo = res[id] 39 | ref = gts[id] 40 | 41 | # Sanity check. 42 | assert(type(hypo) is list) 43 | assert(len(hypo) >= 1) 44 | assert(type(ref) is list) 45 | assert(len(ref) > 0) 46 | 47 | cider_scorer += (hypo[0], ref) 48 | 49 | (score, scores) = cider_scorer.compute_score() 50 | 51 | return score, scores 52 | 53 | def method(self): 54 | return "CIDEr" 55 | -------------------------------------------------------------------------------- /lib/capeval/cider/cider_scorer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Tsung-Yi Lin 3 | # Ramakrishna Vedantam 4 | 5 | import copy 6 | from collections import defaultdict 7 | import numpy as np 8 | import pdb 9 | import math 10 | 11 | def precook(s, n=4, out=False): 12 | """ 13 | Takes a string as input and returns an object that can be given to 14 | either cook_refs or cook_test. This is optional: cook_refs and cook_test 15 | can take string arguments as well. 
16 | :param s: string : sentence to be converted into ngrams 17 | :param n: int : number of ngrams for which representation is calculated 18 | :return: term frequency vector for occuring ngrams 19 | """ 20 | words = s.split() 21 | counts = defaultdict(int) 22 | for k in range(1,n+1): 23 | for i in range(len(words)-k+1): 24 | ngram = tuple(words[i:i+k]) 25 | counts[ngram] += 1 26 | return counts 27 | 28 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average" 29 | '''Takes a list of reference sentences for a single segment 30 | and returns an object that encapsulates everything that BLEU 31 | needs to know about them. 32 | :param refs: list of string : reference sentences for some image 33 | :param n: int : number of ngrams for which (ngram) representation is calculated 34 | :return: result (list of dict) 35 | ''' 36 | return [precook(ref, n) for ref in refs] 37 | 38 | def cook_test(test, n=4): 39 | '''Takes a test sentence and returns an object that 40 | encapsulates everything that BLEU needs to know about it. 41 | :param test: list of string : hypothesis sentence for some image 42 | :param n: int : number of ngrams for which (ngram) representation is calculated 43 | :return: result (dict) 44 | ''' 45 | return precook(test, n, True) 46 | 47 | class CiderScorer(object): 48 | """CIDEr scorer. 49 | """ 50 | 51 | def copy(self): 52 | ''' copy the refs.''' 53 | new = CiderScorer(n=self.n) 54 | new.ctest = copy.copy(self.ctest) 55 | new.crefs = copy.copy(self.crefs) 56 | return new 57 | 58 | def __init__(self, test=None, refs=None, n=4, sigma=6.0): 59 | ''' singular instance ''' 60 | self.n = n 61 | self.sigma = sigma 62 | self.crefs = [] 63 | self.ctest = [] 64 | self.document_frequency = defaultdict(float) 65 | self.cook_append(test, refs) 66 | self.ref_len = None 67 | 68 | def cook_append(self, test, refs): 69 | '''called by constructor and __iadd__ to avoid creating new instances.''' 70 | 71 | if refs is not None: 72 | self.crefs.append(cook_refs(refs)) 73 | if test is not None: 74 | self.ctest.append(cook_test(test)) ## N.B.: -1 75 | else: 76 | self.ctest.append(None) # lens of crefs and ctest have to match 77 | 78 | def size(self): 79 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest)) 80 | return len(self.crefs) 81 | 82 | def __iadd__(self, other): 83 | '''add an instance (e.g., from another sentence).''' 84 | 85 | if type(other) is tuple: 86 | ## avoid creating new CiderScorer instances 87 | self.cook_append(other[0], other[1]) 88 | else: 89 | self.ctest.extend(other.ctest) 90 | self.crefs.extend(other.crefs) 91 | 92 | return self 93 | def compute_doc_freq(self): 94 | ''' 95 | Compute term frequency for reference data. 96 | This will be used to compute idf (inverse document frequency later) 97 | The term frequency is stored in the object 98 | :return: None 99 | ''' 100 | for refs in self.crefs: 101 | # refs, k ref captions of one image 102 | for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]): 103 | self.document_frequency[ngram] += 1 104 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 105 | 106 | def compute_cider(self): 107 | def counts2vec(cnts): 108 | """ 109 | Function maps counts of ngram to vector of tfidf weights. 110 | The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights. 111 | The n-th entry of array denotes length of n-grams. 
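Note: `vec[n][ngram]` below equals term_freq * (self.ref_len - df) = tf * log(num_images / max(1, doc_freq)), i.e. a tf-idf weight computed in natural-log space.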
112 | :param cnts: 113 | :return: vec (array of dict), norm (array of float), length (int) 114 | """ 115 | vec = [defaultdict(float) for _ in range(self.n)] 116 | length = 0 117 | norm = [0.0 for _ in range(self.n)] 118 | for (ngram,term_freq) in cnts.items(): 119 | # give word count 1 if it doesn't appear in reference corpus 120 | df = np.log(max(1.0, self.document_frequency[ngram])) 121 | # ngram index 122 | n = len(ngram)-1 123 | # tf (term_freq) * idf (precomputed idf) for n-grams 124 | vec[n][ngram] = float(term_freq)*(self.ref_len - df) 125 | # compute norm for the vector. the norm will be used for computing similarity 126 | norm[n] += pow(vec[n][ngram], 2) 127 | 128 | if n == 1: 129 | length += term_freq 130 | norm = [np.sqrt(n) for n in norm] 131 | return vec, norm, length 132 | 133 | def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref): 134 | ''' 135 | Compute the cosine similarity of two vectors. 136 | :param vec_hyp: array of dictionary for vector corresponding to hypothesis 137 | :param vec_ref: array of dictionary for vector corresponding to reference 138 | :param norm_hyp: array of float for vector corresponding to hypothesis 139 | :param norm_ref: array of float for vector corresponding to reference 140 | :param length_hyp: int containing length of hypothesis 141 | :param length_ref: int containing length of reference 142 | :return: array of score for each n-grams cosine similarity 143 | ''' 144 | delta = float(length_hyp - length_ref) 145 | # measure consine similarity 146 | val = np.array([0.0 for _ in range(self.n)]) 147 | for n in range(self.n): 148 | # ngram 149 | for (ngram,count) in vec_hyp[n].items(): 150 | # vrama91 : added clipping 151 | val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram] 152 | 153 | if (norm_hyp[n] != 0) and (norm_ref[n] != 0): 154 | val[n] /= (norm_hyp[n]*norm_ref[n]) 155 | 156 | assert(not math.isnan(val[n])) 157 | # vrama91: added a length based gaussian penalty 158 | val[n] *= np.e**(-(delta**2)/(2*self.sigma**2)) 159 | return val 160 | 161 | # compute log reference length 162 | self.ref_len = np.log(float(len(self.crefs))) 163 | 164 | scores = [] 165 | for test, refs in zip(self.ctest, self.crefs): 166 | # compute vector for test captions 167 | vec, norm, length = counts2vec(test) 168 | # compute vector for ref captions 169 | score = np.array([0.0 for _ in range(self.n)]) 170 | for ref in refs: 171 | vec_ref, norm_ref, length_ref = counts2vec(ref) 172 | score += sim(vec, vec_ref, norm, norm_ref, length, length_ref) 173 | # change by vrama91 - mean of ngram scores, instead of sum 174 | score_avg = np.mean(score) 175 | # divide by number of references 176 | score_avg /= len(refs) 177 | # multiply score by 10 178 | score_avg *= 10.0 179 | # append score of an image to the score list 180 | scores.append(score_avg) 181 | return scores 182 | 183 | def compute_score(self, option=None, verbose=0): 184 | # compute idf 185 | self.compute_doc_freq() 186 | # assert to check document frequency 187 | assert(len(self.ctest) >= max(self.document_frequency.values())) 188 | # compute cider score 189 | score = self.compute_cider() 190 | # debug 191 | # print score 192 | return np.mean(np.array(score)), np.array(score) 193 | -------------------------------------------------------------------------------- /lib/capeval/meteor/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ -------------------------------------------------------------------------------- 
/lib/capeval/meteor/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CurryYuan/X-Trans2Cap/aebe6e2d421034f2de8742fa9946e669bcb497e6/lib/capeval/meteor/__init__.py -------------------------------------------------------------------------------- /lib/capeval/meteor/data/paraphrase-en.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CurryYuan/X-Trans2Cap/aebe6e2d421034f2de8742fa9946e669bcb497e6/lib/capeval/meteor/data/paraphrase-en.gz -------------------------------------------------------------------------------- /lib/capeval/meteor/meteor-1.5.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CurryYuan/X-Trans2Cap/aebe6e2d421034f2de8742fa9946e669bcb497e6/lib/capeval/meteor/meteor-1.5.jar -------------------------------------------------------------------------------- /lib/capeval/meteor/meteor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Python wrapper for METEOR implementation, by Xinlei Chen 4 | # Acknowledge Michael Denkowski for the generous discussion and help 5 | 6 | import os 7 | import sys 8 | import subprocess 9 | import threading 10 | 11 | # Assumes meteor-1.5.jar is in the same directory as meteor.py. Change as needed. 12 | METEOR_JAR = 'meteor-1.5.jar' 13 | # print METEOR_JAR 14 | 15 | class Meteor: 16 | 17 | def __init__(self): 18 | self.meteor_cmd = ['java', '-jar', '-Xmx2G', METEOR_JAR, \ 19 | '-', '-', '-stdio', '-l', 'en', '-norm'] 20 | self.meteor_p = subprocess.Popen(self.meteor_cmd, \ 21 | cwd=os.path.dirname(os.path.abspath(__file__)), \ 22 | stdin=subprocess.PIPE, \ 23 | stdout=subprocess.PIPE, \ 24 | stderr=subprocess.PIPE) 25 | # Used to guarantee thread safety 26 | self.lock = threading.Lock() 27 | 28 | def compute_score(self, gts, res): 29 | assert(gts.keys() == res.keys()) 30 | imgIds = gts.keys() 31 | scores = [] 32 | 33 | eval_line = 'EVAL' 34 | self.lock.acquire() 35 | for i in imgIds: 36 | assert(len(res[i]) >= 1) 37 | stat = self._stat(res[i][0], gts[i]) 38 | eval_line += ' ||| {}'.format(stat) 39 | 40 | self.meteor_p.stdin.write('{}\n'.format(eval_line).encode()) 41 | self.meteor_p.stdin.flush() 42 | for i in range(0,len(imgIds)): 43 | scores.append(float(self.meteor_p.stdout.readline().strip())) 44 | score = float(self.meteor_p.stdout.readline().strip()) 45 | self.lock.release() 46 | 47 | return score, scores 48 | 49 | def method(self): 50 | return "METEOR" 51 | 52 | def _stat(self, hypothesis_str, reference_list): 53 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words 54 | hypothesis_str = hypothesis_str.replace('|||','').replace(' ',' ') 55 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str)) 56 | self.meteor_p.stdin.write('{}\n'.format(score_line).encode()) 57 | self.meteor_p.stdin.flush() 58 | return self.meteor_p.stdout.readline().decode().strip() 59 | 60 | def _score(self, hypothesis_str, reference_list): 61 | self.lock.acquire() 62 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words 63 | hypothesis_str = hypothesis_str.replace('|||','').replace(' ',' ') 64 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str)) 65 | self.meteor_p.stdin.write('{}\n'.format(score_line)) 66 | stats = self.meteor_p.stdout.readline().strip() 67 | eval_line = 'EVAL ||| 
{}'.format(stats) 68 | # EVAL ||| stats 69 | self.meteor_p.stdin.write('{}\n'.format(eval_line)) 70 | score = float(self.meteor_p.stdout.readline().strip()) 71 | # bug fix: there are two values returned by the jar file, one average, and one all, so do it twice 72 | # thanks for Andrej for pointing this out 73 | score = float(self.meteor_p.stdout.readline().strip()) 74 | self.lock.release() 75 | return score 76 | 77 | def __del__(self): 78 | self.lock.acquire() 79 | self.meteor_p.stdin.close() 80 | self.meteor_p.kill() 81 | self.meteor_p.wait() 82 | self.lock.release() 83 | -------------------------------------------------------------------------------- /lib/capeval/rouge/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ -------------------------------------------------------------------------------- /lib/capeval/rouge/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CurryYuan/X-Trans2Cap/aebe6e2d421034f2de8742fa9946e669bcb497e6/lib/capeval/rouge/__init__.py -------------------------------------------------------------------------------- /lib/capeval/rouge/rouge.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : rouge.py 4 | # 5 | # Description : Computes ROUGE-L metric as described by Lin and Hovey (2004) 6 | # 7 | # Creation Date : 2015-01-07 06:03 8 | # Author : Ramakrishna Vedantam 9 | 10 | import numpy as np 11 | import pdb 12 | 13 | def my_lcs(string, sub): 14 | """ 15 | Calculates longest common subsequence for a pair of tokenized strings 16 | :param string : list of str : tokens from a string split using whitespace 17 | :param sub : list of str : shorter string, also split using whitespace 18 | :returns: length (list of int): length of the longest common subsequence between the two strings 19 | 20 | Note: my_lcs only gives length of the longest common subsequence, not the actual LCS 21 | """ 22 | if(len(string)< len(sub)): 23 | sub, string = string, sub 24 | 25 | lengths = [[0 for i in range(0,len(sub)+1)] for j in range(0,len(string)+1)] 26 | 27 | for j in range(1,len(sub)+1): 28 | for i in range(1,len(string)+1): 29 | if(string[i-1] == sub[j-1]): 30 | lengths[i][j] = lengths[i-1][j-1] + 1 31 | else: 32 | lengths[i][j] = max(lengths[i-1][j] , lengths[i][j-1]) 33 | 34 | return lengths[len(string)][len(sub)] 35 | 36 | class Rouge(): 37 | ''' 38 | Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set 39 | 40 | ''' 41 | def __init__(self): 42 | # vrama91: updated the value below based on discussion with Hovey 43 | self.beta = 1.2 44 | 45 | def calc_score(self, candidate, refs): 46 | """ 47 | Compute ROUGE-L score given one candidate and references for an image 48 | :param candidate: str : candidate sentence to be evaluated 49 | :param refs: list of str : COCO reference sentences for the particular image to be evaluated 50 | :returns score: int (ROUGE-L score for the candidate evaluated against references) 51 | """ 52 | # assert(len(candidate)==0) 53 | # assert(len(refs)>0) 54 | prec = [] 55 | rec = [] 56 | 57 | # split into tokens 58 | token_c = candidate[0].split(" ") 59 | 60 | for reference in refs: 61 | # split into tokens 62 | token_r = reference.split(" ") 63 | # compute the longest common subsequence 64 | lcs = my_lcs(token_r, token_c) 65 | prec.append(lcs/float(len(token_c))) 66 | rec.append(lcs/float(len(token_r))) 67 | 68 
| prec_max = max(prec) 69 | rec_max = max(rec) 70 | 71 | if(prec_max!=0 and rec_max !=0): 72 | score = ((1 + self.beta**2)*prec_max*rec_max)/float(rec_max + self.beta**2*prec_max) 73 | else: 74 | score = 0.0 75 | return score 76 | 77 | def compute_score(self, gts, res): 78 | """ 79 | Computes Rouge-L score given a set of reference and candidate sentences for the dataset 80 | Invoked by evaluate_captions.py 81 | :param hypo_for_image: dict : candidate / test sentences with "image name" key and "tokenized sentences" as values 82 | :param ref_for_image: dict : reference MS-COCO sentences with "image name" key and "tokenized sentences" as values 83 | :returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images) 84 | """ 85 | assert(gts.keys() == res.keys()) 86 | imgIds = gts.keys() 87 | 88 | score = [] 89 | for id in imgIds: 90 | hypo = res[id] 91 | ref = gts[id] 92 | 93 | score.append(self.calc_score(hypo, ref)) 94 | 95 | # Sanity check. 96 | assert(type(hypo) is list) 97 | assert(len(hypo) >= 1) 98 | assert(type(ref) is list) 99 | assert(len(ref) > 0) 100 | 101 | average_score = np.mean(np.array(score)) 102 | return average_score, np.array(score) 103 | 104 | def method(self): 105 | return "Rouge" 106 | -------------------------------------------------------------------------------- /lib/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from easydict import EasyDict 4 | 5 | CONF = EasyDict() 6 | 7 | # path 8 | CONF.PATH = EasyDict() 9 | CONF.PATH.BASE = "/home/yuanzhihao/Projects/X-Trans2Cap/" # TODO: change this 10 | CONF.PATH.CLUSTER = "/mntntfs/med_data1/yuanzhihao/X-Trans2Cap/" # TODO: change this 11 | CONF.PATH.DATA = os.path.join(CONF.PATH.BASE, 'data') 12 | CONF.PATH.SCANNET = os.path.join(CONF.PATH.DATA, "scannet") 13 | CONF.PATH.LIB = os.path.join(CONF.PATH.BASE, "lib") 14 | CONF.PATH.MODELS = os.path.join(CONF.PATH.BASE, "models") 15 | CONF.PATH.UTILS = os.path.join(CONF.PATH.BASE, "utils") 16 | 17 | # append to syspath 18 | for _, path in CONF.PATH.items(): 19 | sys.path.append(path) 20 | 21 | # scannet data 22 | CONF.PATH.SCANNET_SCANS = os.path.join(CONF.PATH.SCANNET, "scans") 23 | CONF.PATH.SCANNET_META = os.path.join(CONF.PATH.SCANNET, "meta_data") 24 | CONF.PATH.SCANNET_DATA = os.path.join(CONF.PATH.SCANNET, "scannet_data") 25 | 26 | # data 27 | CONF.NYU40_LABELS = os.path.join(CONF.PATH.SCANNET_META, "nyu40_labels.csv") 28 | 29 | # scannet 30 | CONF.SCANNETV2_TRAIN = os.path.join(CONF.PATH.SCANNET_META, "scannetv2_train.txt") 31 | CONF.SCANNETV2_VAL = os.path.join(CONF.PATH.SCANNET_META, "scannetv2_val.txt") 32 | CONF.SCANNETV2_TEST = os.path.join(CONF.PATH.SCANNET_META, "scannetv2_test.txt") 33 | CONF.SCANNETV2_LIST = os.path.join(CONF.PATH.SCANNET_META, "scannetv2.txt") 34 | 35 | # output 36 | CONF.PATH.OUTPUT = os.path.join(CONF.PATH.BASE, "outputs") 37 | 38 | # train 39 | CONF.TRAIN = EasyDict() 40 | CONF.TRAIN.MAX_DES_LEN = 30 41 | CONF.TRAIN.SEED = 42 42 | CONF.TRAIN.OVERLAID_THRESHOLD = 0.5 43 | CONF.TRAIN.MIN_IOU_THRESHOLD = 0.25 44 | CONF.TRAIN.NUM_BINS = 6 45 | 46 | # eval 47 | CONF.EVAL = EasyDict() 48 | CONF.EVAL.MIN_IOU_THRESHOLD = 0.5 49 | 50 | # data path 51 | CONF.SCANNET_V2_TSV = os.path.join(CONF.PATH.SCANNET_META, "scannetv2-labels.combined.tsv") 52 | CONF.VOCAB = os.path.join(CONF.PATH.DATA, "{}_vocabulary.json") # dataset_name 53 | CONF.GLOVE_PICKLE = os.path.join(CONF.PATH.DATA, "glove.p") 54 | CONF.VOCAB_WEIGHTS = os.path.join(CONF.PATH.DATA, 
"{}_vocabulary_weights.json") # dataset_name 55 | CONF.PATH.DATA_2D = os.path.join(CONF.PATH.DATA, "2d_feature_agg.npz") # processed 2D features -------------------------------------------------------------------------------- /lib/loss_helper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | import torch.nn as nn 8 | import numpy as np 9 | import sys 10 | import os 11 | 12 | from icecream import ic 13 | 14 | # sys.path.append(os.path.join(os.getcwd(), "lib")) # HACK add the lib folder 15 | from utils.nn_distance import nn_distance 16 | from lib.config import CONF 17 | 18 | 19 | def compute_cap_loss(data_dict, mode="gt"): 20 | """ Compute cluster caption loss 21 | 22 | Args: 23 | data_dict: dict (read-only) 24 | 25 | Returns: 26 | cap_loss, cap_acc 27 | """ 28 | 29 | if mode == "gt": 30 | # unpack 31 | pred_caps = data_dict["lang_cap"] # (B, num_words - 1, num_vocabs) 32 | des_lens = data_dict["lang_len"] # batch_size 33 | num_words = des_lens.max() 34 | target_caps = data_dict["lang_ids"][:, 1:num_words] # (B, num_words - 1) 35 | _, _, num_vocabs = pred_caps.shape 36 | 37 | # caption loss 38 | criterion = nn.CrossEntropyLoss(ignore_index=0) 39 | cap_loss = criterion(pred_caps.reshape(-1, num_vocabs), target_caps.reshape(-1)) 40 | 41 | if 'lang_cap_s' in data_dict: 42 | pred_caps_s = data_dict["lang_cap_s"] # (B, num_words - 1, num_vocabs) 43 | cap_loss += criterion(pred_caps_s.reshape(-1, num_vocabs), target_caps.reshape(-1)) 44 | cap_loss /= 2 45 | 46 | # caption acc 47 | pred_caps = pred_caps.reshape(-1, num_vocabs).argmax(-1) # B * (num_words - 1) 48 | target_caps = target_caps.reshape(-1) # B * (num_words - 1) 49 | masks = target_caps != 0 50 | masked_pred_caps = pred_caps[masks] 51 | masked_target_caps = target_caps[masks] 52 | cap_acc = (masked_pred_caps == masked_target_caps).sum().float() / masks.sum().float() 53 | elif mode == "votenet": 54 | # unpack 55 | pred_caps = data_dict["lang_cap"] # (B, num_words - 1, num_vocabs) 56 | des_lens = data_dict["lang_len"] # batch_size 57 | num_words = des_lens.max() 58 | target_caps = data_dict["lang_ids"][:, 1:num_words] # (B, num_words - 1) 59 | 60 | _, _, num_vocabs = pred_caps.shape 61 | 62 | # caption loss 63 | criterion = nn.CrossEntropyLoss(ignore_index=0, reduction="none") 64 | cap_loss = criterion(pred_caps.reshape(-1, num_vocabs), target_caps.reshape(-1)) 65 | 66 | # mask out bad boxes 67 | good_bbox_masks = data_dict["good_bbox_masks"].unsqueeze(1).repeat(1, num_words - 1) # (B, num_words - 1) 68 | good_bbox_masks = good_bbox_masks.reshape(-1) # (B * num_words - 1) 69 | cap_loss = torch.sum(cap_loss * good_bbox_masks) / (torch.sum(good_bbox_masks) + 1e-6) 70 | 71 | num_good_bbox = data_dict["good_bbox_masks"].sum() 72 | if num_good_bbox > 0: # only apply loss on the good boxes 73 | pred_caps = pred_caps[data_dict["good_bbox_masks"]] # num_good_bbox 74 | target_caps = target_caps[data_dict["good_bbox_masks"]] # num_good_bbox 75 | 76 | # caption acc 77 | pred_caps = pred_caps.reshape(-1, num_vocabs).argmax(-1) # num_good_bbox * (num_words - 1) 78 | target_caps = target_caps.reshape(-1) # num_good_bbox * (num_words - 1) 79 | masks = target_caps != 0 80 | masked_pred_caps = pred_caps[masks] 81 | masked_target_caps = target_caps[masks] 82 | cap_acc = (masked_pred_caps == 
masked_target_caps).sum().float() / masks.sum().float() 83 | else: # zero placeholder if there is no good box 84 | cap_acc = torch.zeros(1)[0].cuda() 85 | 86 | return cap_loss, cap_acc 87 | 88 | 89 | def radian_to_label(radians, num_bins=6): 90 | """ 91 | convert radians to labels 92 | 93 | Arguments: 94 | radians: a tensor representing the rotation radians, (batch_size) 95 | radians: a binary tensor representing the valid masks, (batch_size) 96 | num_bins: number of bins for discretizing the rotation degrees 97 | 98 | Return: 99 | labels: a long tensor representing the discretized rotation degree classes, (batch_size) 100 | """ 101 | 102 | boundaries = torch.arange(np.pi / num_bins, np.pi - 1e-8, np.pi / num_bins).cuda() 103 | labels = torch.bucketize(radians, boundaries) 104 | 105 | return labels 106 | 107 | 108 | def get_loss(data_dict, mode="gt", use_rl=False): 109 | """ Loss functions 110 | Returns: 111 | loss: pytorch scalar tensor 112 | data_dict: dict 113 | """ 114 | 115 | if not use_rl: 116 | cap_loss, cap_acc = compute_cap_loss(data_dict, mode) 117 | 118 | # store 119 | data_dict["cap_loss"] = cap_loss 120 | data_dict["cap_acc"] = cap_acc 121 | else: 122 | # store 123 | data_dict["cap_acc"] = torch.zeros(1)[0].cuda() 124 | 125 | # Final loss function 126 | loss = data_dict["cap_loss"] 127 | 128 | # loss *= 10 # amplify 129 | 130 | if 'kd_loss' in data_dict: 131 | kd_loss = data_dict['kd_loss'] 132 | else: 133 | kd_loss = 0 134 | data_dict['kd_loss'] = torch.zeros(1)[0].cuda() 135 | 136 | data_dict["loss"] = loss + kd_loss 137 | 138 | return data_dict 139 | -------------------------------------------------------------------------------- /lib/pointnet2/_ext_src/include/ball_query.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, 5 | const int nsample); 6 | -------------------------------------------------------------------------------- /lib/pointnet2/_ext_src/include/cuda_utils.h: -------------------------------------------------------------------------------- 1 | #ifndef _CUDA_UTILS_H 2 | #define _CUDA_UTILS_H 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include 9 | #include 10 | 11 | #include 12 | 13 | #define TOTAL_THREADS 512 14 | 15 | inline int opt_n_threads(int work_size) { 16 | const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); 17 | 18 | return max(min(1 << pow_2, TOTAL_THREADS), 1); 19 | } 20 | 21 | inline dim3 opt_block_config(int x, int y) { 22 | const int x_threads = opt_n_threads(x); 23 | const int y_threads = 24 | max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1); 25 | dim3 block_config(x_threads, y_threads, 1); 26 | 27 | return block_config; 28 | } 29 | 30 | #define CUDA_CHECK_ERRORS() \ 31 | do { \ 32 | cudaError_t err = cudaGetLastError(); \ 33 | if (cudaSuccess != err) { \ 34 | fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \ 35 | cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \ 36 | __FILE__); \ 37 | exit(-1); \ 38 | } \ 39 | } while (0) 40 | 41 | #endif 42 | -------------------------------------------------------------------------------- /lib/pointnet2/_ext_src/include/group_points.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | at::Tensor group_points(at::Tensor points, at::Tensor idx); 5 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); 6 | 
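// group_points gathers per-point features (B, C, N) at neighbour indices
// (B, npoints, nsample) -- typically the indices returned by ball_query --
// giving a (B, C, npoints, nsample) tensor; group_points_grad scatters the
// gradients back onto the N input points.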
-------------------------------------------------------------------------------- /lib/pointnet2/_ext_src/include/interpolate.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | std::vector three_nn(at::Tensor unknowns, at::Tensor knows); 7 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, 8 | at::Tensor weight); 9 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, 10 | at::Tensor weight, const int m); 11 | -------------------------------------------------------------------------------- /lib/pointnet2/_ext_src/include/sampling.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | at::Tensor gather_points(at::Tensor points, at::Tensor idx); 5 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); 6 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples); 7 | -------------------------------------------------------------------------------- /lib/pointnet2/_ext_src/include/utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | #define CHECK_CUDA(x) \ 6 | do { \ 7 | AT_ASSERT(x.is_cuda(), #x " must be a CUDA tensor"); \ 8 | } while (0) 9 | 10 | #define CHECK_CONTIGUOUS(x) \ 11 | do { \ 12 | AT_ASSERT(x.is_contiguous(), #x " must be a contiguous tensor"); \ 13 | } while (0) 14 | 15 | #define CHECK_IS_INT(x) \ 16 | do { \ 17 | AT_ASSERT(x.scalar_type() == at::ScalarType::Int, \ 18 | #x " must be an int tensor"); \ 19 | } while (0) 20 | 21 | #define CHECK_IS_FLOAT(x) \ 22 | do { \ 23 | AT_ASSERT(x.scalar_type() == at::ScalarType::Float, \ 24 | #x " must be a float tensor"); \ 25 | } while (0) 26 | -------------------------------------------------------------------------------- /lib/pointnet2/_ext_src/src/ball_query.cpp: -------------------------------------------------------------------------------- 1 | #include "ball_query.h" 2 | #include "utils.h" 3 | 4 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, 5 | int nsample, const float *new_xyz, 6 | const float *xyz, int *idx); 7 | 8 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, 9 | const int nsample) { 10 | CHECK_CONTIGUOUS(new_xyz); 11 | CHECK_CONTIGUOUS(xyz); 12 | CHECK_IS_FLOAT(new_xyz); 13 | CHECK_IS_FLOAT(xyz); 14 | 15 | if (new_xyz.is_cuda()) { 16 | CHECK_CUDA(xyz); 17 | } 18 | 19 | at::Tensor idx = 20 | torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample}, 21 | at::device(new_xyz.device()).dtype(at::ScalarType::Int)); 22 | 23 | if (new_xyz.is_cuda()) { 24 | query_ball_point_kernel_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1), 25 | radius, nsample, new_xyz.data_ptr(), 26 | xyz.data_ptr(), idx.data_ptr()); 27 | } else { 28 | AT_ASSERT(false, "CPU not supported"); 29 | } 30 | 31 | return idx; 32 | } 33 | -------------------------------------------------------------------------------- /lib/pointnet2/_ext_src/src/ball_query_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "cuda_utils.h" 6 | 7 | // input: new_xyz(b, m, 3) xyz(b, n, 3) 8 | // output: idx(b, m, nsample) 9 | __global__ void query_ball_point_kernel(int b, int n, int m, float radius, 10 | int nsample, 11 | const float *__restrict__ new_xyz, 12 | const float *__restrict__ xyz, 13 | int *__restrict__ idx) { 14 
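// One block per batch element; each thread strides over the m query points,
// collecting up to nsample neighbour indices within `radius` and padding the
// remaining slots with the first hit so every entry holds a valid index.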
| int batch_index = blockIdx.x; 15 | xyz += batch_index * n * 3; 16 | new_xyz += batch_index * m * 3; 17 | idx += m * nsample * batch_index; 18 | 19 | int index = threadIdx.x; 20 | int stride = blockDim.x; 21 | 22 | float radius2 = radius * radius; 23 | for (int j = index; j < m; j += stride) { 24 | float new_x = new_xyz[j * 3 + 0]; 25 | float new_y = new_xyz[j * 3 + 1]; 26 | float new_z = new_xyz[j * 3 + 2]; 27 | for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) { 28 | float x = xyz[k * 3 + 0]; 29 | float y = xyz[k * 3 + 1]; 30 | float z = xyz[k * 3 + 2]; 31 | float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + 32 | (new_z - z) * (new_z - z); 33 | if (d2 < radius2) { 34 | if (cnt == 0) { 35 | for (int l = 0; l < nsample; ++l) { 36 | idx[j * nsample + l] = k; 37 | } 38 | } 39 | idx[j * nsample + cnt] = k; 40 | ++cnt; 41 | } 42 | } 43 | } 44 | } 45 | 46 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, 47 | int nsample, const float *new_xyz, 48 | const float *xyz, int *idx) { 49 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 50 | query_ball_point_kernel<<>>( 51 | b, n, m, radius, nsample, new_xyz, xyz, idx); 52 | 53 | CUDA_CHECK_ERRORS(); 54 | } 55 | -------------------------------------------------------------------------------- /lib/pointnet2/_ext_src/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include "ball_query.h" 2 | #include "group_points.h" 3 | #include "interpolate.h" 4 | #include "sampling.h" 5 | 6 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 7 | m.def("gather_points", &gather_points); 8 | m.def("gather_points_grad", &gather_points_grad); 9 | m.def("furthest_point_sampling", &furthest_point_sampling); 10 | 11 | m.def("three_nn", &three_nn); 12 | m.def("three_interpolate", &three_interpolate); 13 | m.def("three_interpolate_grad", &three_interpolate_grad); 14 | 15 | m.def("ball_query", &ball_query); 16 | 17 | m.def("group_points", &group_points); 18 | m.def("group_points_grad", &group_points_grad); 19 | } 20 | -------------------------------------------------------------------------------- /lib/pointnet2/_ext_src/src/group_points.cpp: -------------------------------------------------------------------------------- 1 | #include "group_points.h" 2 | #include "utils.h" 3 | 4 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, 5 | const float *points, const int *idx, 6 | float *out); 7 | 8 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 9 | int nsample, const float *grad_out, 10 | const int *idx, float *grad_points); 11 | 12 | at::Tensor group_points(at::Tensor points, at::Tensor idx) { 13 | CHECK_CONTIGUOUS(points); 14 | CHECK_CONTIGUOUS(idx); 15 | CHECK_IS_FLOAT(points); 16 | CHECK_IS_INT(idx); 17 | 18 | if (points.is_cuda()) { 19 | CHECK_CUDA(idx); 20 | } 21 | 22 | at::Tensor output = 23 | torch::zeros({points.size(0), points.size(1), idx.size(1), idx.size(2)}, 24 | at::device(points.device()).dtype(at::ScalarType::Float)); 25 | 26 | if (points.is_cuda()) { 27 | group_points_kernel_wrapper(points.size(0), points.size(1), points.size(2), 28 | idx.size(1), idx.size(2), 29 | points.data_ptr(), idx.data_ptr(), 30 | output.data_ptr()); 31 | } else { 32 | AT_ASSERT(false, "CPU not supported"); 33 | } 34 | 35 | return output; 36 | } 37 | 38 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n) { 39 | CHECK_CONTIGUOUS(grad_out); 40 | CHECK_CONTIGUOUS(idx); 41 | CHECK_IS_FLOAT(grad_out); 42 | 
CHECK_IS_INT(idx); 43 | 44 | if (grad_out.is_cuda()) { 45 | CHECK_CUDA(idx); 46 | } 47 | 48 | at::Tensor output = 49 | torch::zeros({grad_out.size(0), grad_out.size(1), n}, 50 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 51 | 52 | if (grad_out.is_cuda()) { 53 | group_points_grad_kernel_wrapper( 54 | grad_out.size(0), grad_out.size(1), n, idx.size(1), idx.size(2), 55 | grad_out.data_ptr(), idx.data_ptr(), 56 | output.data_ptr()); 57 | } else { 58 | AT_ASSERT(false, "CPU not supported"); 59 | } 60 | 61 | return output; 62 | } 63 | -------------------------------------------------------------------------------- /lib/pointnet2/_ext_src/src/group_points_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "cuda_utils.h" 5 | 6 | // input: points(b, c, n) idx(b, npoints, nsample) 7 | // output: out(b, c, npoints, nsample) 8 | __global__ void group_points_kernel(int b, int c, int n, int npoints, 9 | int nsample, 10 | const float *__restrict__ points, 11 | const int *__restrict__ idx, 12 | float *__restrict__ out) { 13 | int batch_index = blockIdx.x; 14 | points += batch_index * n * c; 15 | idx += batch_index * npoints * nsample; 16 | out += batch_index * npoints * nsample * c; 17 | 18 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 19 | const int stride = blockDim.y * blockDim.x; 20 | for (int i = index; i < c * npoints; i += stride) { 21 | const int l = i / npoints; 22 | const int j = i % npoints; 23 | for (int k = 0; k < nsample; ++k) { 24 | int ii = idx[j * nsample + k]; 25 | out[(l * npoints + j) * nsample + k] = points[l * n + ii]; 26 | } 27 | } 28 | } 29 | 30 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, 31 | const float *points, const int *idx, 32 | float *out) { 33 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 34 | 35 | group_points_kernel<<>>( 36 | b, c, n, npoints, nsample, points, idx, out); 37 | 38 | CUDA_CHECK_ERRORS(); 39 | } 40 | 41 | // input: grad_out(b, c, npoints, nsample), idx(b, npoints, nsample) 42 | // output: grad_points(b, c, n) 43 | __global__ void group_points_grad_kernel(int b, int c, int n, int npoints, 44 | int nsample, 45 | const float *__restrict__ grad_out, 46 | const int *__restrict__ idx, 47 | float *__restrict__ grad_points) { 48 | int batch_index = blockIdx.x; 49 | grad_out += batch_index * npoints * nsample * c; 50 | idx += batch_index * npoints * nsample; 51 | grad_points += batch_index * n * c; 52 | 53 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 54 | const int stride = blockDim.y * blockDim.x; 55 | for (int i = index; i < c * npoints; i += stride) { 56 | const int l = i / npoints; 57 | const int j = i % npoints; 58 | for (int k = 0; k < nsample; ++k) { 59 | int ii = idx[j * nsample + k]; 60 | atomicAdd(grad_points + l * n + ii, 61 | grad_out[(l * npoints + j) * nsample + k]); 62 | } 63 | } 64 | } 65 | 66 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 67 | int nsample, const float *grad_out, 68 | const int *idx, float *grad_points) { 69 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 70 | 71 | group_points_grad_kernel<<>>( 72 | b, c, n, npoints, nsample, grad_out, idx, grad_points); 73 | 74 | CUDA_CHECK_ERRORS(); 75 | } 76 | -------------------------------------------------------------------------------- /lib/pointnet2/_ext_src/src/interpolate.cpp: -------------------------------------------------------------------------------- 1 | #include 
"interpolate.h" 2 | #include "utils.h" 3 | 4 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown, 5 | const float *known, float *dist2, int *idx); 6 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n, 7 | const float *points, const int *idx, 8 | const float *weight, float *out); 9 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, 10 | const float *grad_out, 11 | const int *idx, const float *weight, 12 | float *grad_points); 13 | 14 | std::vector three_nn(at::Tensor unknowns, at::Tensor knows) { 15 | CHECK_CONTIGUOUS(unknowns); 16 | CHECK_CONTIGUOUS(knows); 17 | CHECK_IS_FLOAT(unknowns); 18 | CHECK_IS_FLOAT(knows); 19 | 20 | if (unknowns.is_cuda()) { 21 | CHECK_CUDA(knows); 22 | } 23 | 24 | at::Tensor idx = 25 | torch::zeros({unknowns.size(0), unknowns.size(1), 3}, 26 | at::device(unknowns.device()).dtype(at::ScalarType::Int)); 27 | at::Tensor dist2 = 28 | torch::zeros({unknowns.size(0), unknowns.size(1), 3}, 29 | at::device(unknowns.device()).dtype(at::ScalarType::Float)); 30 | 31 | if (unknowns.is_cuda()) { 32 | three_nn_kernel_wrapper(unknowns.size(0), unknowns.size(1), knows.size(1), 33 | unknowns.data_ptr(), knows.data_ptr(), 34 | dist2.data_ptr(), idx.data_ptr()); 35 | } else { 36 | AT_ASSERT(false, "CPU not supported"); 37 | } 38 | 39 | return {dist2, idx}; 40 | } 41 | 42 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, 43 | at::Tensor weight) { 44 | CHECK_CONTIGUOUS(points); 45 | CHECK_CONTIGUOUS(idx); 46 | CHECK_CONTIGUOUS(weight); 47 | CHECK_IS_FLOAT(points); 48 | CHECK_IS_INT(idx); 49 | CHECK_IS_FLOAT(weight); 50 | 51 | if (points.is_cuda()) { 52 | CHECK_CUDA(idx); 53 | CHECK_CUDA(weight); 54 | } 55 | 56 | at::Tensor output = 57 | torch::zeros({points.size(0), points.size(1), idx.size(1)}, 58 | at::device(points.device()).dtype(at::ScalarType::Float)); 59 | 60 | if (points.is_cuda()) { 61 | three_interpolate_kernel_wrapper( 62 | points.size(0), points.size(1), points.size(2), idx.size(1), 63 | points.data_ptr(), idx.data_ptr(), weight.data_ptr(), 64 | output.data_ptr()); 65 | } else { 66 | AT_ASSERT(false, "CPU not supported"); 67 | } 68 | 69 | return output; 70 | } 71 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, 72 | at::Tensor weight, const int m) { 73 | CHECK_CONTIGUOUS(grad_out); 74 | CHECK_CONTIGUOUS(idx); 75 | CHECK_CONTIGUOUS(weight); 76 | CHECK_IS_FLOAT(grad_out); 77 | CHECK_IS_INT(idx); 78 | CHECK_IS_FLOAT(weight); 79 | 80 | if (grad_out.is_cuda()) { 81 | CHECK_CUDA(idx); 82 | CHECK_CUDA(weight); 83 | } 84 | 85 | at::Tensor output = 86 | torch::zeros({grad_out.size(0), grad_out.size(1), m}, 87 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 88 | 89 | if (grad_out.is_cuda()) { 90 | three_interpolate_grad_kernel_wrapper( 91 | grad_out.size(0), grad_out.size(1), grad_out.size(2), m, 92 | grad_out.data_ptr(), idx.data_ptr(), 93 | weight.data_ptr(), output.data_ptr()); 94 | } else { 95 | AT_ASSERT(false, "CPU not supported"); 96 | } 97 | 98 | return output; 99 | } 100 | -------------------------------------------------------------------------------- /lib/pointnet2/_ext_src/src/interpolate_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "cuda_utils.h" 6 | 7 | // input: unknown(b, n, 3) known(b, m, 3) 8 | // output: dist2(b, n, 3), idx(b, n, 3) 9 | __global__ void three_nn_kernel(int b, int n, int m, 10 | const float *__restrict__ unknown, 11 | const float 
*__restrict__ known, 12 | float *__restrict__ dist2, 13 | int *__restrict__ idx) { 14 | int batch_index = blockIdx.x; 15 | unknown += batch_index * n * 3; 16 | known += batch_index * m * 3; 17 | dist2 += batch_index * n * 3; 18 | idx += batch_index * n * 3; 19 | 20 | int index = threadIdx.x; 21 | int stride = blockDim.x; 22 | for (int j = index; j < n; j += stride) { 23 | float ux = unknown[j * 3 + 0]; 24 | float uy = unknown[j * 3 + 1]; 25 | float uz = unknown[j * 3 + 2]; 26 | 27 | double best1 = 1e40, best2 = 1e40, best3 = 1e40; 28 | int besti1 = 0, besti2 = 0, besti3 = 0; 29 | for (int k = 0; k < m; ++k) { 30 | float x = known[k * 3 + 0]; 31 | float y = known[k * 3 + 1]; 32 | float z = known[k * 3 + 2]; 33 | float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z); 34 | if (d < best1) { 35 | best3 = best2; 36 | besti3 = besti2; 37 | best2 = best1; 38 | besti2 = besti1; 39 | best1 = d; 40 | besti1 = k; 41 | } else if (d < best2) { 42 | best3 = best2; 43 | besti3 = besti2; 44 | best2 = d; 45 | besti2 = k; 46 | } else if (d < best3) { 47 | best3 = d; 48 | besti3 = k; 49 | } 50 | } 51 | dist2[j * 3 + 0] = best1; 52 | dist2[j * 3 + 1] = best2; 53 | dist2[j * 3 + 2] = best3; 54 | 55 | idx[j * 3 + 0] = besti1; 56 | idx[j * 3 + 1] = besti2; 57 | idx[j * 3 + 2] = besti3; 58 | } 59 | } 60 | 61 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown, 62 | const float *known, float *dist2, int *idx) { 63 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 64 | three_nn_kernel<<>>(b, n, m, unknown, known, 65 | dist2, idx); 66 | 67 | CUDA_CHECK_ERRORS(); 68 | } 69 | 70 | // input: points(b, c, m), idx(b, n, 3), weight(b, n, 3) 71 | // output: out(b, c, n) 72 | __global__ void three_interpolate_kernel(int b, int c, int m, int n, 73 | const float *__restrict__ points, 74 | const int *__restrict__ idx, 75 | const float *__restrict__ weight, 76 | float *__restrict__ out) { 77 | int batch_index = blockIdx.x; 78 | points += batch_index * m * c; 79 | 80 | idx += batch_index * n * 3; 81 | weight += batch_index * n * 3; 82 | 83 | out += batch_index * n * c; 84 | 85 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 86 | const int stride = blockDim.y * blockDim.x; 87 | for (int i = index; i < c * n; i += stride) { 88 | const int l = i / n; 89 | const int j = i % n; 90 | float w1 = weight[j * 3 + 0]; 91 | float w2 = weight[j * 3 + 1]; 92 | float w3 = weight[j * 3 + 2]; 93 | 94 | int i1 = idx[j * 3 + 0]; 95 | int i2 = idx[j * 3 + 1]; 96 | int i3 = idx[j * 3 + 2]; 97 | 98 | out[i] = points[l * m + i1] * w1 + points[l * m + i2] * w2 + 99 | points[l * m + i3] * w3; 100 | } 101 | } 102 | 103 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n, 104 | const float *points, const int *idx, 105 | const float *weight, float *out) { 106 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 107 | three_interpolate_kernel<<>>( 108 | b, c, m, n, points, idx, weight, out); 109 | 110 | CUDA_CHECK_ERRORS(); 111 | } 112 | 113 | // input: grad_out(b, c, n), idx(b, n, 3), weight(b, n, 3) 114 | // output: grad_points(b, c, m) 115 | 116 | __global__ void three_interpolate_grad_kernel( 117 | int b, int c, int n, int m, const float *__restrict__ grad_out, 118 | const int *__restrict__ idx, const float *__restrict__ weight, 119 | float *__restrict__ grad_points) { 120 | int batch_index = blockIdx.x; 121 | grad_out += batch_index * n * c; 122 | idx += batch_index * n * 3; 123 | weight += batch_index * n * 3; 124 | grad_points += batch_index * m * c; 125 | 126 | const int 
index = threadIdx.y * blockDim.x + threadIdx.x; 127 | const int stride = blockDim.y * blockDim.x; 128 | for (int i = index; i < c * n; i += stride) { 129 | const int l = i / n; 130 | const int j = i % n; 131 | float w1 = weight[j * 3 + 0]; 132 | float w2 = weight[j * 3 + 1]; 133 | float w3 = weight[j * 3 + 2]; 134 | 135 | int i1 = idx[j * 3 + 0]; 136 | int i2 = idx[j * 3 + 1]; 137 | int i3 = idx[j * 3 + 2]; 138 | 139 | atomicAdd(grad_points + l * m + i1, grad_out[i] * w1); 140 | atomicAdd(grad_points + l * m + i2, grad_out[i] * w2); 141 | atomicAdd(grad_points + l * m + i3, grad_out[i] * w3); 142 | } 143 | } 144 | 145 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, 146 | const float *grad_out, 147 | const int *idx, const float *weight, 148 | float *grad_points) { 149 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 150 | three_interpolate_grad_kernel<<>>( 151 | b, c, n, m, grad_out, idx, weight, grad_points); 152 | 153 | CUDA_CHECK_ERRORS(); 154 | } 155 | -------------------------------------------------------------------------------- /lib/pointnet2/_ext_src/src/sampling.cpp: -------------------------------------------------------------------------------- 1 | #include "sampling.h" 2 | #include "utils.h" 3 | 4 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints, 5 | const float *points, const int *idx, 6 | float *out); 7 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 8 | const float *grad_out, const int *idx, 9 | float *grad_points); 10 | 11 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m, 12 | const float *dataset, float *temp, 13 | int *idxs); 14 | 15 | at::Tensor gather_points(at::Tensor points, at::Tensor idx) { 16 | CHECK_CONTIGUOUS(points); 17 | CHECK_CONTIGUOUS(idx); 18 | CHECK_IS_FLOAT(points); 19 | CHECK_IS_INT(idx); 20 | 21 | if (points.is_cuda()) { 22 | CHECK_CUDA(idx); 23 | } 24 | 25 | at::Tensor output = 26 | torch::zeros({points.size(0), points.size(1), idx.size(1)}, 27 | at::device(points.device()).dtype(at::ScalarType::Float)); 28 | 29 | if (points.is_cuda()) { 30 | gather_points_kernel_wrapper(points.size(0), points.size(1), points.size(2), 31 | idx.size(1), points.data_ptr(), 32 | idx.data_ptr(), output.data_ptr()); 33 | } else { 34 | AT_ASSERT(false, "CPU not supported"); 35 | } 36 | 37 | return output; 38 | } 39 | 40 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, 41 | const int n) { 42 | CHECK_CONTIGUOUS(grad_out); 43 | CHECK_CONTIGUOUS(idx); 44 | CHECK_IS_FLOAT(grad_out); 45 | CHECK_IS_INT(idx); 46 | 47 | if (grad_out.is_cuda()) { 48 | CHECK_CUDA(idx); 49 | } 50 | 51 | at::Tensor output = 52 | torch::zeros({grad_out.size(0), grad_out.size(1), n}, 53 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 54 | 55 | if (grad_out.is_cuda()) { 56 | gather_points_grad_kernel_wrapper(grad_out.size(0), grad_out.size(1), n, 57 | idx.size(1), grad_out.data_ptr(), 58 | idx.data_ptr(), 59 | output.data_ptr()); 60 | } else { 61 | AT_ASSERT(false, "CPU not supported"); 62 | } 63 | 64 | return output; 65 | } 66 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples) { 67 | CHECK_CONTIGUOUS(points); 68 | CHECK_IS_FLOAT(points); 69 | 70 | at::Tensor output = 71 | torch::zeros({points.size(0), nsamples}, 72 | at::device(points.device()).dtype(at::ScalarType::Int)); 73 | 74 | at::Tensor tmp = 75 | torch::full({points.size(0), points.size(1)}, 1e10, 76 | at::device(points.device()).dtype(at::ScalarType::Float)); 77 | 78 | if (points.is_cuda()) { 
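// tmp holds each point's running squared distance to the already-selected
// samples (initialised to 1e10 above); the CUDA kernel shrinks it with
// min(d, temp[k]) as new farthest points are picked.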
79 | furthest_point_sampling_kernel_wrapper( 80 | points.size(0), points.size(1), nsamples, points.data_ptr(), 81 | tmp.data_ptr(), output.data_ptr()); 82 | } else { 83 | AT_ASSERT(false, "CPU not supported"); 84 | } 85 | 86 | return output; 87 | } 88 | -------------------------------------------------------------------------------- /lib/pointnet2/_ext_src/src/sampling_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "cuda_utils.h" 5 | 6 | // input: points(b, c, n) idx(b, m) 7 | // output: out(b, c, m) 8 | __global__ void gather_points_kernel(int b, int c, int n, int m, 9 | const float *__restrict__ points, 10 | const int *__restrict__ idx, 11 | float *__restrict__ out) { 12 | for (int i = blockIdx.x; i < b; i += gridDim.x) { 13 | for (int l = blockIdx.y; l < c; l += gridDim.y) { 14 | for (int j = threadIdx.x; j < m; j += blockDim.x) { 15 | int a = idx[i * m + j]; 16 | out[(i * c + l) * m + j] = points[(i * c + l) * n + a]; 17 | } 18 | } 19 | } 20 | } 21 | 22 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints, 23 | const float *points, const int *idx, 24 | float *out) { 25 | gather_points_kernel<<>>(b, c, n, npoints, 27 | points, idx, out); 28 | 29 | CUDA_CHECK_ERRORS(); 30 | } 31 | 32 | // input: grad_out(b, c, m) idx(b, m) 33 | // output: grad_points(b, c, n) 34 | __global__ void gather_points_grad_kernel(int b, int c, int n, int m, 35 | const float *__restrict__ grad_out, 36 | const int *__restrict__ idx, 37 | float *__restrict__ grad_points) { 38 | for (int i = blockIdx.x; i < b; i += gridDim.x) { 39 | for (int l = blockIdx.y; l < c; l += gridDim.y) { 40 | for (int j = threadIdx.x; j < m; j += blockDim.x) { 41 | int a = idx[i * m + j]; 42 | atomicAdd(grad_points + (i * c + l) * n + a, 43 | grad_out[(i * c + l) * m + j]); 44 | } 45 | } 46 | } 47 | } 48 | 49 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 50 | const float *grad_out, const int *idx, 51 | float *grad_points) { 52 | gather_points_grad_kernel<<>>( 54 | b, c, n, npoints, grad_out, idx, grad_points); 55 | 56 | CUDA_CHECK_ERRORS(); 57 | } 58 | 59 | __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, 60 | int idx1, int idx2) { 61 | const float v1 = dists[idx1], v2 = dists[idx2]; 62 | const int i1 = dists_i[idx1], i2 = dists_i[idx2]; 63 | dists[idx1] = max(v1, v2); 64 | dists_i[idx1] = v2 > v1 ? 
i2 : i1; 65 | } 66 | 67 | // Input dataset: (b, n, 3), tmp: (b, n) 68 | // Ouput idxs (b, m) 69 | template 70 | __global__ void furthest_point_sampling_kernel( 71 | int b, int n, int m, const float *__restrict__ dataset, 72 | float *__restrict__ temp, int *__restrict__ idxs) { 73 | if (m <= 0) return; 74 | __shared__ float dists[block_size]; 75 | __shared__ int dists_i[block_size]; 76 | 77 | int batch_index = blockIdx.x; 78 | dataset += batch_index * n * 3; 79 | temp += batch_index * n; 80 | idxs += batch_index * m; 81 | 82 | int tid = threadIdx.x; 83 | const int stride = block_size; 84 | 85 | int old = 0; 86 | if (threadIdx.x == 0) idxs[0] = old; 87 | 88 | __syncthreads(); 89 | for (int j = 1; j < m; j++) { 90 | int besti = 0; 91 | float best = -1; 92 | float x1 = dataset[old * 3 + 0]; 93 | float y1 = dataset[old * 3 + 1]; 94 | float z1 = dataset[old * 3 + 2]; 95 | for (int k = tid; k < n; k += stride) { 96 | float x2, y2, z2; 97 | x2 = dataset[k * 3 + 0]; 98 | y2 = dataset[k * 3 + 1]; 99 | z2 = dataset[k * 3 + 2]; 100 | float mag = (x2 * x2) + (y2 * y2) + (z2 * z2); 101 | if (mag <= 1e-3) continue; 102 | 103 | float d = 104 | (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1); 105 | 106 | float d2 = min(d, temp[k]); 107 | temp[k] = d2; 108 | besti = d2 > best ? k : besti; 109 | best = d2 > best ? d2 : best; 110 | } 111 | dists[tid] = best; 112 | dists_i[tid] = besti; 113 | __syncthreads(); 114 | 115 | if (block_size >= 512) { 116 | if (tid < 256) { 117 | __update(dists, dists_i, tid, tid + 256); 118 | } 119 | __syncthreads(); 120 | } 121 | if (block_size >= 256) { 122 | if (tid < 128) { 123 | __update(dists, dists_i, tid, tid + 128); 124 | } 125 | __syncthreads(); 126 | } 127 | if (block_size >= 128) { 128 | if (tid < 64) { 129 | __update(dists, dists_i, tid, tid + 64); 130 | } 131 | __syncthreads(); 132 | } 133 | if (block_size >= 64) { 134 | if (tid < 32) { 135 | __update(dists, dists_i, tid, tid + 32); 136 | } 137 | __syncthreads(); 138 | } 139 | if (block_size >= 32) { 140 | if (tid < 16) { 141 | __update(dists, dists_i, tid, tid + 16); 142 | } 143 | __syncthreads(); 144 | } 145 | if (block_size >= 16) { 146 | if (tid < 8) { 147 | __update(dists, dists_i, tid, tid + 8); 148 | } 149 | __syncthreads(); 150 | } 151 | if (block_size >= 8) { 152 | if (tid < 4) { 153 | __update(dists, dists_i, tid, tid + 4); 154 | } 155 | __syncthreads(); 156 | } 157 | if (block_size >= 4) { 158 | if (tid < 2) { 159 | __update(dists, dists_i, tid, tid + 2); 160 | } 161 | __syncthreads(); 162 | } 163 | if (block_size >= 2) { 164 | if (tid < 1) { 165 | __update(dists, dists_i, tid, tid + 1); 166 | } 167 | __syncthreads(); 168 | } 169 | 170 | old = dists_i[0]; 171 | if (tid == 0) idxs[j] = old; 172 | } 173 | } 174 | 175 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m, 176 | const float *dataset, float *temp, 177 | int *idxs) { 178 | unsigned int n_threads = opt_n_threads(n); 179 | 180 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 181 | 182 | switch (n_threads) { 183 | case 512: 184 | furthest_point_sampling_kernel<512> 185 | <<>>(b, n, m, dataset, temp, idxs); 186 | break; 187 | case 256: 188 | furthest_point_sampling_kernel<256> 189 | <<>>(b, n, m, dataset, temp, idxs); 190 | break; 191 | case 128: 192 | furthest_point_sampling_kernel<128> 193 | <<>>(b, n, m, dataset, temp, idxs); 194 | break; 195 | case 64: 196 | furthest_point_sampling_kernel<64> 197 | <<>>(b, n, m, dataset, temp, idxs); 198 | break; 199 | case 32: 200 | furthest_point_sampling_kernel<32> 
201 | <<>>(b, n, m, dataset, temp, idxs); 202 | break; 203 | case 16: 204 | furthest_point_sampling_kernel<16> 205 | <<>>(b, n, m, dataset, temp, idxs); 206 | break; 207 | case 8: 208 | furthest_point_sampling_kernel<8> 209 | <<>>(b, n, m, dataset, temp, idxs); 210 | break; 211 | case 4: 212 | furthest_point_sampling_kernel<4> 213 | <<>>(b, n, m, dataset, temp, idxs); 214 | break; 215 | case 2: 216 | furthest_point_sampling_kernel<2> 217 | <<>>(b, n, m, dataset, temp, idxs); 218 | break; 219 | case 1: 220 | furthest_point_sampling_kernel<1> 221 | <<>>(b, n, m, dataset, temp, idxs); 222 | break; 223 | default: 224 | furthest_point_sampling_kernel<512> 225 | <<>>(b, n, m, dataset, temp, idxs); 226 | } 227 | 228 | CUDA_CHECK_ERRORS(); 229 | } 230 | -------------------------------------------------------------------------------- /lib/pointnet2/_version.py: -------------------------------------------------------------------------------- 1 | __version__ = "3.0.0" 2 | -------------------------------------------------------------------------------- /lib/pointnet2/pointnet2_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | ''' Testing customized ops. ''' 7 | 8 | import torch 9 | from torch.autograd import gradcheck 10 | import numpy as np 11 | 12 | import os 13 | import sys 14 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 15 | sys.path.append(BASE_DIR) 16 | import pointnet2_utils 17 | 18 | def test_interpolation_grad(): 19 | batch_size = 1 20 | feat_dim = 2 21 | m = 4 22 | feats = torch.randn(batch_size, feat_dim, m, requires_grad=True).float().cuda() 23 | 24 | def interpolate_func(inputs): 25 | idx = torch.from_numpy(np.array([[[0,1,2],[1,2,3]]])).int().cuda() 26 | weight = torch.from_numpy(np.array([[[1,1,1],[2,2,2]]])).float().cuda() 27 | interpolated_feats = pointnet2_utils.three_interpolate(inputs, idx, weight) 28 | return interpolated_feats 29 | 30 | assert (gradcheck(interpolate_func, feats, atol=1e-1, rtol=1e-1)) 31 | 32 | if __name__=='__main__': 33 | test_interpolation_grad() 34 | -------------------------------------------------------------------------------- /lib/pointnet2/pytorch_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
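# The wrappers below add optional batch norm / pre-activation around
# Conv1d/2d/3d and Linear layers; the PointNet++ set-abstraction modules build
# their shared per-point MLPs from them, e.g. (illustrative)
# SharedMLP([input_feature_dim, 64, 64, 128], bn=True) expands into a stack of
# 1x1 Conv2d -> BatchNorm2d -> ReLU blocks.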
5 | 6 | ''' Modified based on Ref: https://github.com/erikwijmans/Pointnet2_PyTorch ''' 7 | import torch 8 | import torch.nn as nn 9 | from typing import List, Tuple 10 | 11 | class SharedMLP(nn.Sequential): 12 | 13 | def __init__( 14 | self, 15 | args: List[int], 16 | *, 17 | bn: bool = False, 18 | activation=nn.ReLU(inplace=True), 19 | preact: bool = False, 20 | first: bool = False, 21 | name: str = "" 22 | ): 23 | super().__init__() 24 | 25 | for i in range(len(args) - 1): 26 | self.add_module( 27 | name + 'layer{}'.format(i), 28 | Conv2d( 29 | args[i], 30 | args[i + 1], 31 | bn=(not first or not preact or (i != 0)) and bn, 32 | activation=activation 33 | if (not first or not preact or (i != 0)) else None, 34 | preact=preact 35 | ) 36 | ) 37 | 38 | 39 | class _BNBase(nn.Sequential): 40 | 41 | def __init__(self, in_size, batch_norm=None, name=""): 42 | super().__init__() 43 | self.add_module(name + "bn", batch_norm(in_size)) 44 | 45 | nn.init.constant_(self[0].weight, 1.0) 46 | nn.init.constant_(self[0].bias, 0) 47 | 48 | 49 | class BatchNorm1d(_BNBase): 50 | 51 | def __init__(self, in_size: int, *, name: str = ""): 52 | super().__init__(in_size, batch_norm=nn.BatchNorm1d, name=name) 53 | 54 | 55 | class BatchNorm2d(_BNBase): 56 | 57 | def __init__(self, in_size: int, name: str = ""): 58 | super().__init__(in_size, batch_norm=nn.BatchNorm2d, name=name) 59 | 60 | 61 | class BatchNorm3d(_BNBase): 62 | 63 | def __init__(self, in_size: int, name: str = ""): 64 | super().__init__(in_size, batch_norm=nn.BatchNorm3d, name=name) 65 | 66 | 67 | class _ConvBase(nn.Sequential): 68 | 69 | def __init__( 70 | self, 71 | in_size, 72 | out_size, 73 | kernel_size, 74 | stride, 75 | padding, 76 | activation, 77 | bn, 78 | init, 79 | conv=None, 80 | batch_norm=None, 81 | bias=True, 82 | preact=False, 83 | name="" 84 | ): 85 | super().__init__() 86 | 87 | bias = bias and (not bn) 88 | conv_unit = conv( 89 | in_size, 90 | out_size, 91 | kernel_size=kernel_size, 92 | stride=stride, 93 | padding=padding, 94 | bias=bias 95 | ) 96 | init(conv_unit.weight) 97 | if bias: 98 | nn.init.constant_(conv_unit.bias, 0) 99 | 100 | if bn: 101 | if not preact: 102 | bn_unit = batch_norm(out_size) 103 | else: 104 | bn_unit = batch_norm(in_size) 105 | 106 | if preact: 107 | if bn: 108 | self.add_module(name + 'bn', bn_unit) 109 | 110 | if activation is not None: 111 | self.add_module(name + 'activation', activation) 112 | 113 | self.add_module(name + 'conv', conv_unit) 114 | 115 | if not preact: 116 | if bn: 117 | self.add_module(name + 'bn', bn_unit) 118 | 119 | if activation is not None: 120 | self.add_module(name + 'activation', activation) 121 | 122 | 123 | class Conv1d(_ConvBase): 124 | 125 | def __init__( 126 | self, 127 | in_size: int, 128 | out_size: int, 129 | *, 130 | kernel_size: int = 1, 131 | stride: int = 1, 132 | padding: int = 0, 133 | activation=nn.ReLU(inplace=True), 134 | bn: bool = False, 135 | init=nn.init.kaiming_normal_, 136 | bias: bool = True, 137 | preact: bool = False, 138 | name: str = "" 139 | ): 140 | super().__init__( 141 | in_size, 142 | out_size, 143 | kernel_size, 144 | stride, 145 | padding, 146 | activation, 147 | bn, 148 | init, 149 | conv=nn.Conv1d, 150 | batch_norm=BatchNorm1d, 151 | bias=bias, 152 | preact=preact, 153 | name=name 154 | ) 155 | 156 | 157 | class Conv2d(_ConvBase): 158 | 159 | def __init__( 160 | self, 161 | in_size: int, 162 | out_size: int, 163 | *, 164 | kernel_size: Tuple[int, int] = (1, 1), 165 | stride: Tuple[int, int] = (1, 1), 166 | padding: Tuple[int, int] = (0, 
0), 167 | activation=nn.ReLU(inplace=True), 168 | bn: bool = False, 169 | init=nn.init.kaiming_normal_, 170 | bias: bool = True, 171 | preact: bool = False, 172 | name: str = "" 173 | ): 174 | super().__init__( 175 | in_size, 176 | out_size, 177 | kernel_size, 178 | stride, 179 | padding, 180 | activation, 181 | bn, 182 | init, 183 | conv=nn.Conv2d, 184 | batch_norm=BatchNorm2d, 185 | bias=bias, 186 | preact=preact, 187 | name=name 188 | ) 189 | 190 | 191 | class Conv3d(_ConvBase): 192 | 193 | def __init__( 194 | self, 195 | in_size: int, 196 | out_size: int, 197 | *, 198 | kernel_size: Tuple[int, int, int] = (1, 1, 1), 199 | stride: Tuple[int, int, int] = (1, 1, 1), 200 | padding: Tuple[int, int, int] = (0, 0, 0), 201 | activation=nn.ReLU(inplace=True), 202 | bn: bool = False, 203 | init=nn.init.kaiming_normal_, 204 | bias: bool = True, 205 | preact: bool = False, 206 | name: str = "" 207 | ): 208 | super().__init__( 209 | in_size, 210 | out_size, 211 | kernel_size, 212 | stride, 213 | padding, 214 | activation, 215 | bn, 216 | init, 217 | conv=nn.Conv3d, 218 | batch_norm=BatchNorm3d, 219 | bias=bias, 220 | preact=preact, 221 | name=name 222 | ) 223 | 224 | 225 | class FC(nn.Sequential): 226 | 227 | def __init__( 228 | self, 229 | in_size: int, 230 | out_size: int, 231 | *, 232 | activation=nn.ReLU(inplace=True), 233 | bn: bool = False, 234 | init=None, 235 | preact: bool = False, 236 | name: str = "" 237 | ): 238 | super().__init__() 239 | 240 | fc = nn.Linear(in_size, out_size, bias=not bn) 241 | if init is not None: 242 | init(fc.weight) 243 | if not bn: 244 | nn.init.constant_(fc.bias, 0) 245 | 246 | if preact: 247 | if bn: 248 | self.add_module(name + 'bn', BatchNorm1d(in_size)) 249 | 250 | if activation is not None: 251 | self.add_module(name + 'activation', activation) 252 | 253 | self.add_module(name + 'fc', fc) 254 | 255 | if not preact: 256 | if bn: 257 | self.add_module(name + 'bn', BatchNorm1d(out_size)) 258 | 259 | if activation is not None: 260 | self.add_module(name + 'activation', activation) 261 | 262 | def set_bn_momentum_default(bn_momentum): 263 | 264 | def fn(m): 265 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): 266 | m.momentum = bn_momentum 267 | 268 | return fn 269 | 270 | 271 | class BNMomentumScheduler(object): 272 | 273 | def __init__( 274 | self, model, bn_lambda, last_epoch=-1, 275 | setter=set_bn_momentum_default 276 | ): 277 | if not isinstance(model, nn.Module): 278 | raise RuntimeError( 279 | "Class '{}' is not a PyTorch nn Module".format( 280 | type(model).__name__ 281 | ) 282 | ) 283 | 284 | self.model = model 285 | self.setter = setter 286 | self.lmbd = bn_lambda 287 | 288 | self.step(last_epoch + 1) 289 | self.last_epoch = last_epoch 290 | 291 | def step(self, epoch=None): 292 | if epoch is None: 293 | epoch = self.last_epoch + 1 294 | 295 | self.last_epoch = epoch 296 | self.model.apply(self.setter(self.lmbd(epoch))) 297 | 298 | 299 | -------------------------------------------------------------------------------- /lib/pointnet2/setup.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import os.path as osp 4 | 5 | from setuptools import find_packages, setup 6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 7 | 8 | _this_dir = osp.dirname(osp.abspath(__file__)) 9 | _ext_src_root = "_ext_src" 10 | _ext_sources = glob.glob("{}/src/*.cpp".format(_ext_src_root)) + glob.glob( 11 | "{}/src/*.cu".format(_ext_src_root) 12 | ) 13 | _ext_headers = 
glob.glob("{}/include/*".format(_ext_src_root)) 14 | 15 | requirements = ["torch>=1.4"] 16 | 17 | os.environ["TORCH_CUDA_ARCH_LIST"] = "3.7+PTX;5.0;6.0;6.1;6.2;7.0;7.5" 18 | 19 | exec(open("_version.py").read()) 20 | 21 | setup( 22 | name='pointnet2', 23 | version=__version__, 24 | packages=find_packages(), 25 | install_requires=requirements, 26 | ext_modules=[ 27 | CUDAExtension( 28 | name='pointnet2._ext', 29 | sources=_ext_sources, 30 | extra_compile_args={ 31 | "cxx": ["-O3"], 32 | "nvcc": ["-O3", "-Xfatbin", "-compress-all"], 33 | }, 34 | include_dirs=[osp.join(_this_dir, _ext_src_root, "include")], 35 | ) 36 | ], 37 | cmdclass={"build_ext": BuildExtension}, 38 | include_package_data=True, 39 | ) -------------------------------------------------------------------------------- /models/backbone_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import sys 6 | import os 7 | 8 | from lib.pointnet2.pointnet2_modules import PointnetSAModuleVotes, PointnetFPModule 9 | 10 | class Pointnet2Backbone(nn.Module): 11 | r""" 12 | Backbone network for point cloud feature learning. 13 | Based on Pointnet++ single-scale grouping network. 14 | 15 | Parameters 16 | ---------- 17 | input_feature_dim: int 18 | Number of input channels in the feature descriptor for each point. 19 | e.g. 3 for RGB. 20 | """ 21 | def __init__(self, input_feature_dim=0): 22 | super().__init__() 23 | 24 | self.input_feature_dim = input_feature_dim 25 | 26 | # --------- 4 SET ABSTRACTION LAYERS --------- 27 | self.sa1 = PointnetSAModuleVotes( 28 | npoint=2048, 29 | radius=0.2, 30 | nsample=64, 31 | mlp=[input_feature_dim, 64, 64, 128], 32 | use_xyz=True, 33 | normalize_xyz=True 34 | ) 35 | 36 | self.sa2 = PointnetSAModuleVotes( 37 | npoint=1024, 38 | radius=0.4, 39 | nsample=32, 40 | mlp=[128, 128, 128, 256], 41 | use_xyz=True, 42 | normalize_xyz=True 43 | ) 44 | 45 | self.sa3 = PointnetSAModuleVotes( 46 | npoint=512, 47 | radius=0.8, 48 | nsample=16, 49 | mlp=[256, 128, 128, 256], 50 | use_xyz=True, 51 | normalize_xyz=True 52 | ) 53 | 54 | self.sa4 = PointnetSAModuleVotes( 55 | npoint=256, 56 | radius=1.2, 57 | nsample=16, 58 | mlp=[256, 128, 128, 256], 59 | use_xyz=True, 60 | normalize_xyz=True 61 | ) 62 | 63 | # --------- 2 FEATURE UPSAMPLING LAYERS -------- 64 | self.fp1 = PointnetFPModule(mlp=[256+256,256,256]) 65 | self.fp2 = PointnetFPModule(mlp=[256+256,256,256]) 66 | 67 | def _break_up_pc(self, pc): 68 | xyz = pc[..., :3].contiguous() 69 | features = pc[..., 3:].transpose(1, 2).contiguous() if pc.size(-1) > 3 else None 70 | 71 | return xyz, features 72 | 73 | def forward(self, data_dict): 74 | r""" 75 | Forward pass of the network 76 | 77 | Parameters 78 | ---------- 79 | pointcloud: Variable(torch.cuda.FloatTensor) 80 | (B, N, 3 + input_feature_dim) tensor 81 | Point cloud to run predicts on 82 | Each point in the point-cloud MUST 83 | be formated as (x, y, z, features...) 
84 | 85 | Returns 86 | ---------- 87 | data_dict: {XXX_xyz, XXX_features, XXX_inds} 88 | XXX_xyz: float32 Tensor of shape (B,K,3) 89 | XXX_features: float32 Tensor of shape (B,K,D) 90 | XXX-inds: int64 Tensor of shape (B,K) values in [0,N-1] 91 | """ 92 | 93 | pointcloud = data_dict["point_clouds"] 94 | 95 | batch_size = pointcloud.shape[0] 96 | 97 | xyz, features = self._break_up_pc(pointcloud) 98 | 99 | # --------- 4 SET ABSTRACTION LAYERS --------- 100 | xyz, features, fps_inds = self.sa1(xyz, features) 101 | data_dict['sa1_inds'] = fps_inds 102 | data_dict['sa1_xyz'] = xyz 103 | data_dict['sa1_features'] = features 104 | 105 | xyz, features, fps_inds = self.sa2(xyz, features) # this fps_inds is just 0,1,...,1023 106 | data_dict['sa2_inds'] = fps_inds 107 | data_dict['sa2_xyz'] = xyz 108 | data_dict['sa2_features'] = features 109 | 110 | xyz, features, fps_inds = self.sa3(xyz, features) # this fps_inds is just 0,1,...,511 111 | data_dict['sa3_xyz'] = xyz 112 | data_dict['sa3_features'] = features 113 | 114 | xyz, features, fps_inds = self.sa4(xyz, features) # this fps_inds is just 0,1,...,255 115 | data_dict['sa4_xyz'] = xyz 116 | data_dict['sa4_features'] = features 117 | 118 | # --------- 2 FEATURE UPSAMPLING LAYERS -------- 119 | features = self.fp1(data_dict['sa3_xyz'], data_dict['sa4_xyz'], data_dict['sa3_features'], data_dict['sa4_features']) 120 | features = self.fp2(data_dict['sa2_xyz'], data_dict['sa3_xyz'], data_dict['sa2_features'], features) 121 | data_dict['fp2_features'] = features 122 | data_dict['fp2_xyz'] = data_dict['sa2_xyz'] 123 | num_seed = data_dict['fp2_xyz'].shape[1] 124 | data_dict['fp2_inds'] = data_dict['sa1_inds'][:,0:num_seed] # indices among the entire input point clouds 125 | 126 | return data_dict 127 | 128 | if __name__=='__main__': 129 | backbone_net = Pointnet2Backbone(input_feature_dim=3).cuda() 130 | print(backbone_net) 131 | backbone_net.eval() 132 | out = backbone_net(torch.rand(16,20000,6).cuda()) 133 | for key in sorted(out.keys()): 134 | print(key, '\t', out[key].shape) 135 | -------------------------------------------------------------------------------- /models/capnet.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | import sys 6 | import os 7 | 8 | from models.backbone_module import Pointnet2Backbone 9 | from models.voting_module import VotingModule 10 | from models.proposal_module import ProposalModule 11 | 12 | 13 | class CapNet(nn.Module): 14 | def __init__(self, 15 | args, 16 | num_class, 17 | vocabulary, 18 | embeddings, 19 | num_heading_bin, 20 | num_size_cluster, 21 | mean_size_arr, 22 | input_feature_dim=0, 23 | num_proposal=256, 24 | vote_factor=1, 25 | sampling="vote_fps", 26 | detection=True, 27 | no_caption=False, 28 | emb_size=300, 29 | hidden_size=512, 30 | dataset=None): 31 | super().__init__() 32 | 33 | self.num_class = num_class 34 | self.num_heading_bin = num_heading_bin 35 | self.num_size_cluster = num_size_cluster 36 | self.mean_size_arr = mean_size_arr 37 | assert (mean_size_arr.shape[0] == self.num_size_cluster) 38 | self.input_feature_dim = input_feature_dim 39 | self.num_proposal = num_proposal 40 | self.vote_factor = vote_factor 41 | self.sampling = sampling 42 | self.no_caption = no_caption 43 | self.detection = detection 44 | 45 | if detection: 46 | # --------- PROPOSAL GENERATION --------- 47 | # Backbone point feature learning 48 | self.backbone_net = 
Pointnet2Backbone(input_feature_dim=self.input_feature_dim) 49 | 50 | # Hough voting 51 | self.vgen = VotingModule(self.vote_factor, 256) 52 | 53 | # Vote aggregation and object proposal 54 | self.proposal = ProposalModule(num_class, num_heading_bin, num_size_cluster, mean_size_arr, num_proposal, sampling) 55 | 56 | module = importlib.import_module('models.' + args.model) 57 | TransformerCaptionModule = getattr(module, 'TransformerCaptionModule') 58 | self.caption = TransformerCaptionModule(vocabulary, 59 | embeddings, 60 | emb_size, 61 | 128, 62 | hidden_size, 63 | num_proposal, 64 | use_gt_ins=args.use_gt_ins, 65 | use_rl=args.use_rl, 66 | dataset=dataset) 67 | 68 | def forward(self, data_dict, use_tf=True, is_eval=False): 69 | """ Forward pass of the network 70 | 71 | Args: 72 | data_dict: dict 73 | { 74 | point_clouds, 75 | lang_feat 76 | } 77 | 78 | point_clouds: Variable(torch.cuda.FloatTensor) 79 | (B, N, 3 + input_channels) tensor 80 | Point cloud to run predicts on 81 | Each point in the point-cloud MUST 82 | be formated as (x, y, z, features...) 83 | Returns: 84 | end_points: dict 85 | """ 86 | 87 | ####################################### 88 | # # 89 | # DETECTION BRANCH # 90 | # # 91 | ####################################### 92 | if self.detection: 93 | # --------- HOUGH VOTING --------- 94 | data_dict = self.backbone_net(data_dict) 95 | 96 | # --------- HOUGH VOTING --------- 97 | xyz = data_dict["fp2_xyz"] 98 | features = data_dict["fp2_features"] 99 | data_dict["seed_inds"] = data_dict["fp2_inds"] 100 | data_dict["seed_xyz"] = xyz 101 | data_dict["seed_features"] = features 102 | 103 | xyz, features = self.vgen(xyz, features) 104 | features_norm = torch.norm(features, p=2, dim=1) 105 | features = features.div(features_norm.unsqueeze(1)) 106 | data_dict["vote_xyz"] = xyz 107 | data_dict["vote_features"] = features 108 | 109 | # --------- PROPOSAL GENERATION --------- 110 | data_dict = self.proposal(xyz, features, data_dict) 111 | 112 | # --------- CAPTION GENERATION --------- 113 | data_dict = self.caption(data_dict, use_tf, is_eval) 114 | 115 | return data_dict 116 | -------------------------------------------------------------------------------- /models/proposal_module.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from: https://github.com/facebookresearch/votenet/blob/master/models/proposal_module.py 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import numpy as np 9 | import os 10 | import sys 11 | 12 | sys.path.append(os.path.join(os.getcwd(), "lib")) # HACK add the lib folder 13 | import lib.pointnet2.pointnet2_utils 14 | from data.scannet.model_util_scannet import ScannetDatasetConfig 15 | from lib.pointnet2.pointnet2_modules import PointnetSAModuleVotes 16 | from utils.box_util import get_3d_box_batch 17 | 18 | # constants 19 | DC = ScannetDatasetConfig() 20 | 21 | 22 | class ProposalModule(nn.Module): 23 | def __init__(self, num_class, num_heading_bin, num_size_cluster, mean_size_arr, num_proposal, sampling, 24 | seed_feat_dim=256): 25 | super().__init__() 26 | 27 | self.num_class = num_class 28 | self.num_heading_bin = num_heading_bin 29 | self.num_size_cluster = num_size_cluster 30 | self.mean_size_arr = mean_size_arr 31 | self.num_proposal = num_proposal 32 | self.sampling = sampling 33 | self.seed_feat_dim = seed_feat_dim 34 | 35 | # Vote clustering 36 | self.vote_aggregation = PointnetSAModuleVotes( 37 | npoint=self.num_proposal, 38 | radius=0.3, 39 | 
nsample=16, 40 | mlp=[self.seed_feat_dim, 128, 128, 128], 41 | use_xyz=True, 42 | normalize_xyz=True 43 | ) 44 | 45 | # Object proposal/detection 46 | # Objectness scores (2), center residual (3), 47 | # heading class+residual (num_heading_bin*2), size class+residual(num_size_cluster*4) 48 | self.proposal = nn.Sequential( 49 | nn.Conv1d(128, 128, 1, bias=False), 50 | nn.BatchNorm1d(128), 51 | nn.ReLU(), 52 | nn.Conv1d(128, 128, 1, bias=False), 53 | nn.BatchNorm1d(128), 54 | nn.ReLU(), 55 | nn.Conv1d(128, 2 + 3 + num_heading_bin * 2 + num_size_cluster * 4 + self.num_class, 1) 56 | ) 57 | 58 | def forward(self, xyz, features, data_dict): 59 | """ 60 | Args: 61 | xyz: (B,K,3) 62 | features: (B,C,K) 63 | Returns: 64 | scores: (B,num_proposal,2+3+NH*2+NS*4) 65 | """ 66 | 67 | # Farthest point sampling (FPS) on votes 68 | xyz, features, fps_inds = self.vote_aggregation(xyz, features) 69 | 70 | sample_inds = fps_inds 71 | 72 | data_dict['aggregated_vote_xyz'] = xyz # (batch_size, num_proposal, 3) 73 | data_dict['aggregated_vote_features'] = features.permute(0, 2, 74 | 1).contiguous() # (batch_size, num_proposal, 128) 75 | data_dict[ 76 | 'aggregated_vote_inds'] = sample_inds # (batch_size, num_proposal,) # should be 0,1,2,...,num_proposal 77 | 78 | # --------- PROPOSAL GENERATION --------- 79 | net = self.proposal(features) 80 | data_dict = self.decode_scores(net, data_dict, self.num_class, self.num_heading_bin, self.num_size_cluster, 81 | self.mean_size_arr) 82 | 83 | return data_dict 84 | 85 | def decode_pred_box(self, data_dict): 86 | # predicted bbox 87 | pred_center = data_dict["center"].detach().cpu().numpy() # (B,K,3) 88 | pred_heading_class = torch.argmax(data_dict["heading_scores"], -1) # B,num_proposal 89 | pred_heading_residual = torch.gather(data_dict["heading_residuals"], 2, 90 | pred_heading_class.unsqueeze(-1)) # B,num_proposal,1 91 | pred_heading_class = pred_heading_class.detach().cpu().numpy() # B,num_proposal 92 | pred_heading_residual = pred_heading_residual.squeeze(2).detach().cpu().numpy() # B,num_proposal 93 | pred_size_class = torch.argmax(data_dict["size_scores"], -1) # B,num_proposal 94 | pred_size_residual = torch.gather(data_dict["size_residuals"], 2, 95 | pred_size_class.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, 1, 96 | 3)) # B,num_proposal,1,3 97 | pred_size_class = pred_size_class.detach().cpu().numpy() 98 | pred_size_residual = pred_size_residual.squeeze(2).detach().cpu().numpy() # B,num_proposal,3 99 | 100 | batch_size, num_proposals, _ = pred_center.shape 101 | pred_bboxes = [] 102 | pred_obbs = [] 103 | for i in range(batch_size): 104 | # convert the bbox parameters to bbox corners 105 | pred_obb_batch = DC.param2obb_batch(pred_center[i, :, 0:3], pred_heading_class[i], pred_heading_residual[i], 106 | pred_size_class[i], pred_size_residual[i]) 107 | pred_bbox_batch = get_3d_box_batch(pred_obb_batch[:, 3:6], pred_obb_batch[:, 6], pred_obb_batch[:, 0:3]) 108 | pred_bboxes.append(torch.from_numpy(pred_bbox_batch).cuda().unsqueeze(0)) 109 | pred_obbs.append(torch.from_numpy(pred_obb_batch).cuda().unsqueeze(0)) 110 | 111 | pred_bboxes = torch.cat(pred_bboxes, dim=0) # batch_size, num_proposals, 8, 3 112 | pred_obbs = torch.cat(pred_obbs, dim=0).float() # (batch_size, num_proposals, 7) 113 | 114 | data_dict["bbox_corner"] = pred_bboxes 115 | data_dict['bbox_obb'] = pred_obbs 116 | 117 | 118 | def decode_scores(self, net, data_dict, num_class, num_heading_bin, num_size_cluster, mean_size_arr): 119 | """ 120 | decode the predicted parameters for the bounding boxes 121 | 
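        The per-proposal prediction vector produced by self.proposal is laid out as
        [objectness (2) | center offset (3) | heading scores (num_heading_bin) |
        heading residuals (num_heading_bin) | size scores (num_size_cluster) |
        size residuals (num_size_cluster * 3) | semantic scores (num_class)],
        which is the order the slices over net_transposed below follow.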
122 | """ 123 | net_transposed = net.transpose(2, 1).contiguous() # (batch_size, 1024, ..) 124 | batch_size = net_transposed.shape[0] 125 | num_proposal = net_transposed.shape[1] 126 | 127 | objectness_scores = net_transposed[:, :, 0:2] 128 | 129 | base_xyz = data_dict['aggregated_vote_xyz'] # (batch_size, num_proposal, 3) 130 | center = base_xyz + net_transposed[:, :, 2:5] # (batch_size, num_proposal, 3) 131 | 132 | heading_scores = net_transposed[:, :, 5:5 + num_heading_bin] 133 | heading_residuals_normalized = net_transposed[:, :, 5 + num_heading_bin:5 + num_heading_bin * 2] 134 | 135 | size_scores = net_transposed[:, :, 5 + num_heading_bin * 2:5 + num_heading_bin * 2 + num_size_cluster] 136 | size_residuals_normalized = net_transposed[:, :, 137 | 5 + num_heading_bin * 2 + num_size_cluster:5 + num_heading_bin * 2 + num_size_cluster * 4].view( 138 | [batch_size, num_proposal, num_size_cluster, 3]) # Bxnum_proposalxnum_size_clusterx3 139 | 140 | sem_cls_scores = net_transposed[:, :, 5 + num_heading_bin * 2 + num_size_cluster * 4:] # Bxnum_proposalx10 141 | 142 | # store 143 | data_dict['objectness_scores'] = objectness_scores 144 | data_dict['center'] = center 145 | data_dict['heading_scores'] = heading_scores # Bxnum_proposalxnum_heading_bin 146 | # B x num_proposal x num_heading_bin (should be -1 to 1) 147 | data_dict['heading_residuals_normalized'] = heading_residuals_normalized 148 | # B x num_proposal x num_heading_bin 149 | data_dict['heading_residuals'] = heading_residuals_normalized * (np.pi / num_heading_bin) 150 | data_dict['size_scores'] = size_scores 151 | data_dict['size_residuals_normalized'] = size_residuals_normalized 152 | data_dict['size_residuals'] = size_residuals_normalized * torch.from_numpy( 153 | mean_size_arr.astype(np.float32)).cuda().unsqueeze(0).unsqueeze(0) 154 | data_dict['sem_cls_scores'] = sem_cls_scores 155 | # processed box info 156 | self.decode_pred_box(data_dict) # bounding box corner coordinates 157 | data_dict["bbox_feature"] = data_dict["aggregated_vote_features"] 158 | data_dict["bbox_mask"] = objectness_scores.argmax(-1) 159 | data_dict['bbox_sems'] = sem_cls_scores.argmax(-1) 160 | data_dict['sem_cls'] = sem_cls_scores.argmax(-1) 161 | 162 | return data_dict 163 | -------------------------------------------------------------------------------- /models/transformer/__pycache__/attention.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CurryYuan/X-Trans2Cap/aebe6e2d421034f2de8742fa9946e669bcb497e6/models/transformer/__pycache__/attention.cpython-36.pyc -------------------------------------------------------------------------------- /models/transformer/__pycache__/beam_search.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CurryYuan/X-Trans2Cap/aebe6e2d421034f2de8742fa9946e669bcb497e6/models/transformer/__pycache__/beam_search.cpython-36.pyc -------------------------------------------------------------------------------- /models/transformer/__pycache__/containers.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CurryYuan/X-Trans2Cap/aebe6e2d421034f2de8742fa9946e669bcb497e6/models/transformer/__pycache__/containers.cpython-36.pyc -------------------------------------------------------------------------------- /models/transformer/__pycache__/decoders.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/CurryYuan/X-Trans2Cap/aebe6e2d421034f2de8742fa9946e669bcb497e6/models/transformer/__pycache__/decoders.cpython-36.pyc -------------------------------------------------------------------------------- /models/transformer/__pycache__/encoders.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CurryYuan/X-Trans2Cap/aebe6e2d421034f2de8742fa9946e669bcb497e6/models/transformer/__pycache__/encoders.cpython-36.pyc -------------------------------------------------------------------------------- /models/transformer/__pycache__/m2_transformer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CurryYuan/X-Trans2Cap/aebe6e2d421034f2de8742fa9946e669bcb497e6/models/transformer/__pycache__/m2_transformer.cpython-36.pyc -------------------------------------------------------------------------------- /models/transformer/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CurryYuan/X-Trans2Cap/aebe6e2d421034f2de8742fa9946e669bcb497e6/models/transformer/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /models/transformer/attention.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from icecream import ic 4 | from torch import nn 5 | from .containers import Module 6 | 7 | 8 | class ScaledDotProductAttention(nn.Module): 9 | ''' 10 | Scaled dot-product attention 11 | ''' 12 | 13 | def __init__(self, d_model, d_k, d_v, h): 14 | ''' 15 | :param d_model: Output dimensionality of the model 16 | :param d_k: Dimensionality of queries and keys 17 | :param d_v: Dimensionality of values 18 | :param h: Number of heads 19 | ''' 20 | super(ScaledDotProductAttention, self).__init__() 21 | self.fc_q = nn.Linear(d_model, h * d_k) 22 | self.fc_k = nn.Linear(d_model, h * d_k) 23 | self.fc_v = nn.Linear(d_model, h * d_v) 24 | self.fc_o = nn.Linear(h * d_v, d_model) 25 | 26 | self.d_model = d_model 27 | self.d_k = d_k 28 | self.d_v = d_v 29 | self.h = h 30 | 31 | self.init_weights() 32 | 33 | def init_weights(self): 34 | nn.init.xavier_uniform_(self.fc_q.weight) 35 | nn.init.xavier_uniform_(self.fc_k.weight) 36 | nn.init.xavier_uniform_(self.fc_v.weight) 37 | nn.init.xavier_uniform_(self.fc_o.weight) 38 | nn.init.constant_(self.fc_q.bias, 0) 39 | nn.init.constant_(self.fc_k.bias, 0) 40 | nn.init.constant_(self.fc_v.bias, 0) 41 | nn.init.constant_(self.fc_o.bias, 0) 42 | 43 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None): 44 | ''' 45 | Computes 46 | :param queries: Queries (b_s, nq, d_model) 47 | :param keys: Keys (b_s, nk, d_model) 48 | :param values: Values (b_s, nk, d_model) 49 | :param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking. 50 | :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk). 
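        Shape example (illustrative values): with b_s=2, nq=nk=16, h=8, d_k=d_v=64 and
        d_model=512, queries/keys/values of shape (2, 16, 512) produce an output of
        shape (2, 16, 512); the intermediate attention map has shape (2, 8, 16, 16).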
51 | :return: 52 | ''' 53 | b_s, nq = queries.shape[:2] 54 | nk = keys.shape[1] 55 | 56 | q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3) # (b_s, h, nq, d_k) 57 | k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1) # (b_s, h, d_k, nk) 58 | v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3) # (b_s, h, nk, d_v) 59 | att = torch.matmul(q, k) / np.sqrt(self.d_k) # (b_s, h, nq, nk) 60 | 61 | if attention_weights is not None: 62 | att = att * attention_weights 63 | if attention_mask is not None: 64 | att = att.masked_fill(attention_mask, -1e9) 65 | att = torch.softmax(att, -1) 66 | out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v) # (b_s, nq, h*d_v) 67 | out = self.fc_o(out) # (b_s, nq, d_model) 68 | return out 69 | 70 | 71 | class ScaledDotProductAttentionMemory(nn.Module): 72 | ''' 73 | Scaled dot-product attention with memory 74 | ''' 75 | 76 | def __init__(self, d_model, d_k, d_v, h, m): 77 | ''' 78 | :param d_model: Output dimensionality of the model 79 | :param d_k: Dimensionality of queries and keys 80 | :param d_v: Dimensionality of values 81 | :param h: Number of heads 82 | :param m: Number of memory slots 83 | ''' 84 | super(ScaledDotProductAttentionMemory, self).__init__() 85 | self.fc_q = nn.Linear(d_model, h * d_k) 86 | self.fc_k = nn.Linear(d_model, h * d_k) 87 | self.fc_v = nn.Linear(d_model, h * d_v) 88 | self.fc_o = nn.Linear(h * d_v, d_model) 89 | self.m_k = nn.Parameter(torch.FloatTensor(1, m, h * d_k)) 90 | self.m_v = nn.Parameter(torch.FloatTensor(1, m, h * d_v)) 91 | 92 | self.d_model = d_model 93 | self.d_k = d_k 94 | self.d_v = d_v 95 | self.h = h 96 | self.m = m 97 | 98 | self.init_weights() 99 | 100 | def init_weights(self): 101 | nn.init.xavier_uniform_(self.fc_q.weight) 102 | nn.init.xavier_uniform_(self.fc_k.weight) 103 | nn.init.xavier_uniform_(self.fc_v.weight) 104 | nn.init.xavier_uniform_(self.fc_o.weight) 105 | nn.init.normal_(self.m_k, 0, 1 / self.d_k) 106 | nn.init.normal_(self.m_v, 0, 1 / self.m) 107 | nn.init.constant_(self.fc_q.bias, 0) 108 | nn.init.constant_(self.fc_k.bias, 0) 109 | nn.init.constant_(self.fc_v.bias, 0) 110 | nn.init.constant_(self.fc_o.bias, 0) 111 | 112 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None): 113 | ''' 114 | Computes 115 | :param queries: Queries (b_s, nq, d_model) 116 | :param keys: Keys (b_s, nk, d_model) 117 | :param values: Values (b_s, nk, d_model) 118 | :param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking. 119 | :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk). 
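        Note that the m learnable memory slots are concatenated to the projected keys
        and values, so attention is computed over nk + m positions, while
        attention_mask and attention_weights only affect the first nk (non-memory)
        positions.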
120 | :return: 121 | ''' 122 | b_s, nq = queries.shape[:2] 123 | nk = keys.shape[1] 124 | 125 | m_k = np.sqrt(self.d_k) * self.m_k.expand(b_s, self.m, self.h * self.d_k) 126 | m_v = np.sqrt(self.m) * self.m_v.expand(b_s, self.m, self.h * self.d_v) 127 | 128 | q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3) # (b_s, h, nq, d_k) 129 | k = torch.cat([self.fc_k(keys), m_k], 1).view(b_s, nk + self.m, self.h, self.d_k).permute(0, 2, 3, 1) # (b_s, h, d_k, nk) 130 | v = torch.cat([self.fc_v(values), m_v], 1).view(b_s, nk + self.m, self.h, self.d_v).permute(0, 2, 1, 3) # (b_s, h, nk, d_v) 131 | 132 | att = torch.matmul(q, k) / np.sqrt(self.d_k) # (b_s, h, nq, nk) 133 | if attention_weights is not None: 134 | att = torch.cat([att[:, :, :, :nk] * attention_weights, att[:, :, :, nk:]], -1) 135 | if attention_mask is not None: 136 | att[:, :, :, :nk] = att[:, :, :, :nk].masked_fill(attention_mask, -np.inf) 137 | att = torch.softmax(att, -1) 138 | out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v) # (b_s, nq, h*d_v) 139 | out = self.fc_o(out) # (b_s, nq, d_model) 140 | return out 141 | 142 | 143 | class MultiHeadAttention(Module): 144 | ''' 145 | Multi-head attention layer with Dropout and Layer Normalization. 146 | ''' 147 | 148 | def __init__(self, d_model, d_k, d_v, h, dropout=.1, identity_map_reordering=False, can_be_stateful=False, 149 | attention_module=None, attention_module_kwargs=None): 150 | super(MultiHeadAttention, self).__init__() 151 | self.identity_map_reordering = identity_map_reordering 152 | if attention_module is not None: 153 | if attention_module_kwargs is not None: 154 | self.attention = attention_module(d_model=d_model, d_k=d_k, d_v=d_v, h=h, **attention_module_kwargs) 155 | else: 156 | self.attention = attention_module(d_model=d_model, d_k=d_k, d_v=d_v, h=h) 157 | else: 158 | self.attention = ScaledDotProductAttention(d_model=d_model, d_k=d_k, d_v=d_v, h=h) 159 | self.dropout = nn.Dropout(p=dropout) 160 | self.layer_norm = nn.LayerNorm(d_model) 161 | 162 | self.can_be_stateful = can_be_stateful 163 | if self.can_be_stateful: 164 | self.register_state('running_keys', torch.zeros((0, d_model))) 165 | self.register_state('running_values', torch.zeros((0, d_model))) 166 | 167 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None): 168 | if self.can_be_stateful and self._is_stateful: 169 | self.running_keys = torch.cat([self.running_keys, keys], 1) 170 | keys = self.running_keys 171 | 172 | self.running_values = torch.cat([self.running_values, values], 1) 173 | values = self.running_values 174 | 175 | if self.identity_map_reordering: 176 | q_norm = self.layer_norm(queries) 177 | k_norm = self.layer_norm(keys) 178 | v_norm = self.layer_norm(values) 179 | out = self.attention(q_norm, k_norm, v_norm, attention_mask, attention_weights) 180 | out = queries + self.dropout(torch.relu(out)) 181 | else: 182 | out = self.attention(queries, keys, values, attention_mask, attention_weights) 183 | out = self.dropout(out) 184 | out = self.layer_norm(queries + out) 185 | return out 186 | -------------------------------------------------------------------------------- /models/transformer/beam_search.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from icecream import ic 3 | 4 | from .containers import Module 5 | from .utils import * 6 | 7 | 8 | class BeamSearch(object): 9 | def __init__(self, model, max_len: int, eos_idx: int, beam_size: int): 
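        # Minimal usage sketch (names are illustrative placeholders, not from this repo):
        #   searcher = BeamSearch(caption_model, max_len=30, eos_idx=eos_id, beam_size=5)
        #   seqs, log_probs, enc_output = searcher.apply(object_feats, out_size=1)
        # apply() runs max_len decoding steps inside the model's statefulness() context
        # and returns beams sorted by total log-probability (see apply() below).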
10 | self.model = model 11 | self.max_len = max_len 12 | self.eos_idx = eos_idx 13 | self.beam_size = beam_size 14 | self.b_s = None 15 | self.device = None 16 | self.seq_mask = None 17 | self.seq_logprob = None 18 | self.outputs = None 19 | self.log_probs = None 20 | self.selected_words = None 21 | self.all_log_probs = None 22 | 23 | def _expand_state(self, selected_beam, cur_beam_size): 24 | def fn(s): 25 | shape = [int(sh) for sh in s.shape] 26 | beam = selected_beam 27 | for _ in shape[1:]: 28 | beam = beam.unsqueeze(-1) 29 | s = torch.gather(s.view(*([self.b_s, cur_beam_size] + shape[1:])), 1, 30 | beam.expand(*([self.b_s, self.beam_size] + shape[1:]))) 31 | s = s.view(*([-1, ] + shape[1:])) 32 | return s 33 | 34 | return fn 35 | 36 | def _expand_visual(self, visual: TensorOrSequence, cur_beam_size: int, selected_beam: torch.Tensor): 37 | if isinstance(visual, torch.Tensor): 38 | visual_shape = visual.shape 39 | visual_exp_shape = (self.b_s, cur_beam_size) + visual_shape[1:] 40 | visual_red_shape = (self.b_s * self.beam_size,) + visual_shape[1:] 41 | selected_beam_red_size = (self.b_s, self.beam_size) + tuple(1 for _ in range(len(visual_exp_shape) - 2)) 42 | selected_beam_exp_size = (self.b_s, self.beam_size) + visual_exp_shape[2:] 43 | visual_exp = visual.view(visual_exp_shape) 44 | selected_beam_exp = selected_beam.view(selected_beam_red_size).expand(selected_beam_exp_size) 45 | visual = torch.gather(visual_exp, 1, selected_beam_exp).view(visual_red_shape) 46 | else: 47 | new_visual = [] 48 | for im in visual: 49 | visual_shape = im.shape 50 | visual_exp_shape = (self.b_s, cur_beam_size) + visual_shape[1:] 51 | visual_red_shape = (self.b_s * self.beam_size,) + visual_shape[1:] 52 | selected_beam_red_size = (self.b_s, self.beam_size) + tuple(1 for _ in range(len(visual_exp_shape) - 2)) 53 | selected_beam_exp_size = (self.b_s, self.beam_size) + visual_exp_shape[2:] 54 | visual_exp = im.view(visual_exp_shape) 55 | selected_beam_exp = selected_beam.view(selected_beam_red_size).expand(selected_beam_exp_size) 56 | new_im = torch.gather(visual_exp, 1, selected_beam_exp).view(visual_red_shape) 57 | new_visual.append(new_im) 58 | visual = tuple(new_visual) 59 | return visual 60 | 61 | def apply(self, visual: TensorOrSequence, out_size=1, return_probs=False, **kwargs): 62 | self.b_s = get_batch_size(visual) 63 | self.device = get_device(visual) 64 | self.seq_mask = torch.ones((self.b_s, self.beam_size, 1), device=self.device) 65 | self.seq_logprob = torch.zeros((self.b_s, 1, 1), device=self.device) 66 | self.log_probs = [] 67 | self.selected_words = None 68 | if return_probs: 69 | self.all_log_probs = [] 70 | 71 | outputs = [] 72 | with self.model.statefulness(self.b_s): 73 | for t in range(self.max_len): 74 | visual, outputs = self.iter(t, visual, outputs, return_probs, **kwargs) 75 | 76 | enc_output = self.model.enc_output 77 | 78 | # Sort result 79 | seq_logprob, sort_idxs = torch.sort(self.seq_logprob, 1, descending=True) 80 | outputs = torch.cat(outputs, -1) 81 | outputs = torch.gather(outputs, 1, sort_idxs.expand(self.b_s, self.beam_size, self.max_len)) 82 | log_probs = torch.cat(self.log_probs, -1) 83 | log_probs = torch.gather(log_probs, 1, sort_idxs.expand(self.b_s, self.beam_size, self.max_len)) 84 | if return_probs: 85 | all_log_probs = torch.cat(self.all_log_probs, 2) 86 | all_log_probs = torch.gather(all_log_probs, 1, sort_idxs.unsqueeze(-1).expand(self.b_s, self.beam_size, 87 | self.max_len, 88 | all_log_probs.shape[-1])) 89 | 90 | outputs = outputs.contiguous()[:, :out_size] 91 
| log_probs = log_probs.contiguous()[:, :out_size] 92 | if out_size == 1: 93 | outputs = outputs.squeeze(1) 94 | log_probs = log_probs.squeeze(1) 95 | 96 | if return_probs: 97 | return outputs, log_probs, all_log_probs 98 | else: 99 | return outputs, log_probs, enc_output 100 | 101 | def select(self, t, candidate_logprob, **kwargs): 102 | selected_logprob, selected_idx = torch.sort(candidate_logprob.view(self.b_s, -1), -1, descending=True) 103 | selected_logprob, selected_idx = selected_logprob[:, :self.beam_size], selected_idx[:, :self.beam_size] 104 | return selected_idx, selected_logprob 105 | 106 | def iter(self, t: int, visual: TensorOrSequence, outputs, return_probs, **kwargs): 107 | cur_beam_size = 1 if t == 0 else self.beam_size 108 | 109 | word_logprob = self.model.step(t, self.selected_words, visual, None, mode='feedback', **kwargs) 110 | word_logprob = word_logprob.view(self.b_s, cur_beam_size, -1) 111 | candidate_logprob = self.seq_logprob + word_logprob 112 | 113 | # Mask sequence if it reaches EOS 114 | if t > 0: 115 | mask = (self.selected_words.view(self.b_s, cur_beam_size) != self.eos_idx).float().unsqueeze(-1) 116 | self.seq_mask = self.seq_mask * mask 117 | word_logprob = word_logprob * self.seq_mask.expand_as(word_logprob) 118 | old_seq_logprob = self.seq_logprob.expand_as(candidate_logprob).contiguous() 119 | old_seq_logprob[:, :, 1:] = -999 120 | candidate_logprob = self.seq_mask * candidate_logprob + old_seq_logprob * (1 - self.seq_mask) 121 | 122 | selected_idx, selected_logprob = self.select(t, candidate_logprob, **kwargs) 123 | selected_beam = selected_idx // candidate_logprob.shape[-1] 124 | selected_words = selected_idx - selected_beam * candidate_logprob.shape[-1] 125 | 126 | self.model.apply_to_states(self._expand_state(selected_beam, cur_beam_size)) 127 | visual = self._expand_visual(visual, cur_beam_size, selected_beam) 128 | 129 | self.seq_logprob = selected_logprob.unsqueeze(-1) 130 | self.seq_mask = torch.gather(self.seq_mask, 1, selected_beam.unsqueeze(-1)) 131 | outputs = list(torch.gather(o, 1, selected_beam.unsqueeze(-1)) for o in outputs) 132 | outputs.append(selected_words.unsqueeze(-1)) 133 | 134 | if return_probs: 135 | if t == 0: 136 | self.all_log_probs.append(word_logprob.expand((self.b_s, self.beam_size, -1)).unsqueeze(2)) 137 | else: 138 | self.all_log_probs.append(word_logprob.unsqueeze(2)) 139 | 140 | this_word_logprob = torch.gather(word_logprob, 1, 141 | selected_beam.unsqueeze(-1).expand(self.b_s, self.beam_size, 142 | word_logprob.shape[-1])) 143 | this_word_logprob = torch.gather(this_word_logprob, 2, selected_words.unsqueeze(-1)) 144 | self.log_probs = list( 145 | torch.gather(o, 1, selected_beam.unsqueeze(-1).expand(self.b_s, self.beam_size, 1)) for o in self.log_probs) 146 | self.log_probs.append(this_word_logprob) 147 | self.selected_words = selected_words.view(-1, 1) 148 | 149 | return visual, outputs 150 | 151 | -------------------------------------------------------------------------------- /models/transformer/containers.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | from torch import nn 3 | from .utils import TensorOrNone 4 | 5 | 6 | class Module(nn.Module): 7 | def __init__(self): 8 | super(Module, self).__init__() 9 | self._is_stateful = False 10 | self._state_names = [] 11 | self._state_defaults = dict() 12 | 13 | def register_state(self, name: str, default: TensorOrNone): 14 | self._state_names.append(name) 15 | if default is None: 16 | 
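            # register_state() pairs with enable_statefulness()/statefulness() below:
            # while stateful, buffers such as MultiHeadAttention.running_keys/
            # running_values and MeshedDecoder.running_seq persist across step() calls,
            # which is what lets BeamSearch decode one token at a time without
            # recomputing attention over the already-generated prefix.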
self._state_defaults[name] = None 17 | else: 18 | self._state_defaults[name] = default.clone().detach() 19 | self.register_buffer(name, default) 20 | 21 | def states(self): 22 | for name in self._state_names: 23 | yield self._buffers[name] 24 | for m in self.children(): 25 | if isinstance(m, Module): 26 | yield from m.states() 27 | 28 | def apply_to_states(self, fn): 29 | for name in self._state_names: 30 | self._buffers[name] = fn(self._buffers[name]) 31 | for name, m in self.named_children(): 32 | if isinstance(m, Module): 33 | m.apply_to_states(fn) 34 | 35 | def _init_states(self, batch_size: int): 36 | for name in self._state_names: 37 | if self._state_defaults[name] is None: 38 | self._buffers[name] = None 39 | else: 40 | self._buffers[name] = self._state_defaults[name].clone().detach().to(self._buffers[name].device) 41 | self._buffers[name] = self._buffers[name].unsqueeze(0) 42 | self._buffers[name] = self._buffers[name].expand([batch_size, ] + list(self._buffers[name].shape[1:])) 43 | self._buffers[name] = self._buffers[name].contiguous() 44 | 45 | def _reset_states(self): 46 | for name in self._state_names: 47 | if self._state_defaults[name] is None: 48 | self._buffers[name] = None 49 | else: 50 | self._buffers[name] = self._state_defaults[name].clone().detach().to(self._buffers[name].device) 51 | 52 | def enable_statefulness(self, batch_size: int): 53 | for m in self.children(): 54 | if isinstance(m, Module): 55 | m.enable_statefulness(batch_size) 56 | self._init_states(batch_size) 57 | self._is_stateful = True 58 | 59 | def disable_statefulness(self): 60 | for m in self.children(): 61 | if isinstance(m, Module): 62 | m.disable_statefulness() 63 | self._reset_states() 64 | self._is_stateful = False 65 | 66 | @contextmanager 67 | def statefulness(self, batch_size: int): 68 | self.enable_statefulness(batch_size) 69 | try: 70 | yield 71 | finally: 72 | self.disable_statefulness() 73 | 74 | 75 | class ModuleList(nn.ModuleList, Module): 76 | pass 77 | 78 | 79 | class ModuleDict(nn.ModuleDict, Module): 80 | pass 81 | -------------------------------------------------------------------------------- /models/transformer/decoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from icecream import ic 3 | from torch import nn 4 | from torch.nn import functional as F 5 | import numpy as np 6 | 7 | from .attention import MultiHeadAttention 8 | from .utils import sinusoid_encoding_table, PositionWiseFeedForward 9 | from .containers import Module, ModuleList 10 | 11 | 12 | class MeshedDecoderLayer(Module): 13 | def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, self_att_module=None, 14 | enc_att_module=None, self_att_module_kwargs=None, enc_att_module_kwargs=None): 15 | super(MeshedDecoderLayer, self).__init__() 16 | self.self_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=True, 17 | attention_module=self_att_module, 18 | attention_module_kwargs=self_att_module_kwargs) 19 | self.enc_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=False, 20 | attention_module=enc_att_module, 21 | attention_module_kwargs=enc_att_module_kwargs) 22 | self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout) 23 | 24 | self.fc_alpha1 = nn.Linear(d_model + d_model, d_model) 25 | self.fc_alpha2 = nn.Linear(d_model + d_model, d_model) 26 | self.fc_alpha3 = nn.Linear(d_model + d_model, d_model) 27 | 28 | self.init_weights() 29 | 30 | def init_weights(self): 31 | 
nn.init.xavier_uniform_(self.fc_alpha1.weight) 32 | nn.init.xavier_uniform_(self.fc_alpha2.weight) 33 | nn.init.xavier_uniform_(self.fc_alpha3.weight) 34 | nn.init.constant_(self.fc_alpha1.bias, 0) 35 | nn.init.constant_(self.fc_alpha2.bias, 0) 36 | nn.init.constant_(self.fc_alpha3.bias, 0) 37 | 38 | def forward(self, input, enc_output, mask_pad, mask_self_att, mask_enc_att): 39 | self_att = self.self_att(input, input, input, mask_self_att) 40 | self_att = self_att * mask_pad 41 | 42 | enc_att1 = self.enc_att(self_att, enc_output[:, 0], enc_output[:, 0], mask_enc_att) * mask_pad 43 | enc_att2 = self.enc_att(self_att, enc_output[:, 1], enc_output[:, 1], mask_enc_att) * mask_pad 44 | enc_att3 = self.enc_att(self_att, enc_output[:, 2], enc_output[:, 2], mask_enc_att) * mask_pad 45 | 46 | alpha1 = torch.sigmoid(self.fc_alpha1(torch.cat([self_att, enc_att1], -1))) 47 | alpha2 = torch.sigmoid(self.fc_alpha2(torch.cat([self_att, enc_att2], -1))) 48 | alpha3 = torch.sigmoid(self.fc_alpha3(torch.cat([self_att, enc_att3], -1))) 49 | 50 | enc_att = (enc_att1 * alpha1 + enc_att2 * alpha2 + enc_att3 * alpha3) / np.sqrt(3) 51 | enc_att = enc_att * mask_pad 52 | 53 | ff = self.pwff(enc_att) 54 | ff = ff * mask_pad 55 | return ff 56 | 57 | 58 | class MeshedDecoder(Module): 59 | def __init__(self, vocab_size, max_len, N_dec, padding_idx, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, 60 | self_att_module=None, enc_att_module=None, self_att_module_kwargs=None, enc_att_module_kwargs=None): 61 | super().__init__() 62 | self.d_model = d_model 63 | self.word_emb = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx) 64 | self.pos_emb = nn.Embedding.from_pretrained(sinusoid_encoding_table(max_len + 1, d_model, 0), freeze=True) 65 | self.layers = ModuleList( 66 | [MeshedDecoderLayer(d_model, d_k, d_v, h, d_ff, dropout, self_att_module=self_att_module, 67 | enc_att_module=enc_att_module, self_att_module_kwargs=self_att_module_kwargs, 68 | enc_att_module_kwargs=enc_att_module_kwargs) for _ in range(N_dec)]) 69 | self.fc = nn.Linear(d_model, vocab_size, bias=False) 70 | self.max_len = max_len 71 | self.padding_idx = padding_idx 72 | self.N = N_dec 73 | 74 | self.register_state('running_mask_self_attention', torch.zeros((1, 1, 0)).byte()) 75 | self.register_state('running_seq', torch.zeros((1,)).long()) 76 | 77 | def forward(self, input, encoder_output, mask_encoder): 78 | # input (b_s, seq_len) 79 | b_s, seq_len = input.shape[:2] 80 | mask_queries = (input != self.padding_idx).unsqueeze(-1).float() # (b_s, seq_len, 1) 81 | mask_self_attention = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8, device=input.device), 82 | diagonal=1) 83 | mask_self_attention = mask_self_attention.unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, seq_len) 84 | mask_self_attention = mask_self_attention + (input == self.padding_idx).unsqueeze(1).unsqueeze(1).byte() 85 | mask_self_attention = mask_self_attention.gt(0) # (b_s, 1, seq_len, seq_len) 86 | if self._is_stateful: 87 | self.running_mask_self_attention = torch.cat([self.running_mask_self_attention, mask_self_attention], -1) 88 | mask_self_attention = self.running_mask_self_attention 89 | mask_self_attention = mask_self_attention.bool() 90 | 91 | seq = torch.arange(1, seq_len + 1).view(1, -1).expand(b_s, -1).to(input.device) # (b_s, seq_len) 92 | seq = seq.masked_fill(mask_queries.squeeze(-1) == 0, 0) 93 | if self._is_stateful: 94 | self.running_seq.add_(1) 95 | seq = self.running_seq 96 | 97 | out = self.word_emb(input) + self.pos_emb(seq) 98 | 99 | 
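        # Each decoder layer cross-attends to all encoder levels (the "meshed"
        # connections), and its per-layer output is collected in intermediate_feats,
        # which a caller can reuse, e.g. as the per-layer extra_feat consumed by
        # DualPathMemoryAugmentedEncoder in encoders.py.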
intermediate_feats = [] 100 | for i, l in enumerate(self.layers): 101 | out = l(out, encoder_output, mask_queries, mask_self_attention, mask_encoder) 102 | intermediate_feats.append(out) 103 | 104 | intermediate_feats = torch.stack(intermediate_feats) 105 | out = self.fc(out) 106 | return out, intermediate_feats 107 | -------------------------------------------------------------------------------- /models/transformer/encoders.py: -------------------------------------------------------------------------------- 1 | from icecream import ic 2 | from torch.nn import functional as F 3 | from .utils import PositionWiseFeedForward 4 | import torch 5 | from torch import nn 6 | from .attention import MultiHeadAttention 7 | 8 | 9 | class EncoderLayer(nn.Module): 10 | def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, identity_map_reordering=False, 11 | attention_module=None, attention_module_kwargs=None): 12 | super(EncoderLayer, self).__init__() 13 | self.identity_map_reordering = identity_map_reordering 14 | self.mhatt = MultiHeadAttention(d_model, d_k, d_v, h, dropout, identity_map_reordering=identity_map_reordering, 15 | attention_module=attention_module, 16 | attention_module_kwargs=attention_module_kwargs) 17 | self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout, identity_map_reordering=identity_map_reordering) 18 | 19 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None): 20 | att = self.mhatt(queries, keys, values, attention_mask, attention_weights) 21 | ff = self.pwff(att) 22 | return ff 23 | 24 | 25 | class MultiLevelEncoder(nn.Module): 26 | def __init__(self, N, padding_idx, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, 27 | identity_map_reordering=False, attention_module=None, attention_module_kwargs=None): 28 | super(MultiLevelEncoder, self).__init__() 29 | self.d_model = d_model 30 | self.dropout = dropout 31 | self.layers = nn.ModuleList([EncoderLayer(d_model, d_k, d_v, h, d_ff, dropout, 32 | identity_map_reordering=identity_map_reordering, 33 | attention_module=attention_module, 34 | attention_module_kwargs=attention_module_kwargs) 35 | for _ in range(N)]) 36 | self.padding_idx = padding_idx 37 | 38 | def forward(self, input, attention_mask, attention_weights=None): 39 | 40 | outs = [] 41 | out = input 42 | for l in self.layers: 43 | out = l(out, out, out, attention_mask, attention_weights) 44 | outs.append(out.unsqueeze(1)) 45 | 46 | outs = torch.cat(outs, 1) 47 | return outs 48 | 49 | 50 | class MemoryAugmentedEncoder(MultiLevelEncoder): 51 | def __init__(self, N, padding_idx, d_in=2048, **kwargs): 52 | super(MemoryAugmentedEncoder, self).__init__(N, padding_idx, **kwargs) 53 | self.fc = nn.Linear(d_in, self.d_model) 54 | self.dropout = nn.Dropout(p=self.dropout) 55 | self.layer_norm = nn.LayerNorm(self.d_model) 56 | 57 | def forward(self, input, attention_mask, attention_weights=None): 58 | out = F.relu(self.fc(input)) 59 | out = self.dropout(out) 60 | out = self.layer_norm(out) 61 | return super(MemoryAugmentedEncoder, self).forward(out, attention_mask, attention_weights=attention_weights) 62 | 63 | 64 | class DualPathEncoder(nn.Module): 65 | def __init__(self, N, padding_idx, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, 66 | identity_map_reordering=False, attention_module=None, attention_module_kwargs=None): 67 | super().__init__() 68 | self.d_model = d_model 69 | self.dropout = dropout 70 | self.layers = nn.ModuleList([EncoderLayer(d_model, d_k, d_v, h, d_ff, dropout, 71 | 
identity_map_reordering=identity_map_reordering, 72 | attention_module=attention_module, 73 | attention_module_kwargs=attention_module_kwargs) 74 | for _ in range(N)]) 75 | 76 | self.padding_idx = padding_idx 77 | self.fc_fuse = nn.Linear(2*d_model, d_model) 78 | 79 | def forward(self, input, extra_feat, attention_mask, attention_weights=None): 80 | outs = [] 81 | out = input 82 | 83 | for i, l in enumerate(self.layers): 84 | out = l(out, out, out, attention_mask, attention_weights) 85 | out = out + extra_feat[:, i] 86 | outs.append(out.unsqueeze(1)) 87 | 88 | outs = torch.cat(outs, 1) 89 | return outs 90 | 91 | 92 | class DualPathMemoryAugmentedEncoder(DualPathEncoder): 93 | def __init__(self, N, padding_idx, d_in=2048, **kwargs): 94 | super(DualPathMemoryAugmentedEncoder, self).__init__(N, padding_idx, **kwargs) 95 | self.fc = nn.Linear(d_in, self.d_model) 96 | self.dropout = nn.Dropout(p=self.dropout) 97 | self.layer_norm = nn.LayerNorm(self.d_model) 98 | 99 | def forward(self, input, extra_feat, attention_mask, attention_weights=None): 100 | out = F.relu(self.fc(input)) 101 | out = self.dropout(out) 102 | out = self.layer_norm(out) 103 | 104 | return super(DualPathMemoryAugmentedEncoder, self).forward(out, extra_feat, attention_mask, 105 | attention_weights=attention_weights) 106 | -------------------------------------------------------------------------------- /models/transformer/m2_transformer.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from icecream import ic 7 | 8 | from .containers import Module 9 | from .decoders import MeshedDecoder 10 | from .encoders import MemoryAugmentedEncoder, DualPathMemoryAugmentedEncoder 11 | from .attention import ScaledDotProductAttentionMemory, ScaledDotProductAttention 12 | from .beam_search import BeamSearch 13 | from .utils import TensorOrSequence, get_batch_size, get_device 14 | 15 | 16 | class M2Transformer(Module): 17 | def __init__(self, vocab, max_seq_len, object_latent_dim, padding_idx): 18 | super(M2Transformer, self).__init__() 19 | self.padding_idx = padding_idx 20 | self.bos_idx = vocab['word2idx']['sos'] 21 | self.eos_idx = vocab['word2idx']['eos'] 22 | self.vocab = vocab 23 | 24 | self.encoder = MemoryAugmentedEncoder(3, 0, d_in=object_latent_dim, 25 | attention_module=ScaledDotProductAttentionMemory, 26 | attention_module_kwargs={'m': 40} 27 | ) 28 | self.decoder = MeshedDecoder(len(vocab["word2idx"]), max_seq_len, 1, padding_idx) 29 | 30 | self.register_state('enc_output', None) 31 | self.register_state('mask_enc', None) 32 | 33 | def forward(self, objects_features, tokens): 34 | # input (b_s, seq_len, d_in) 35 | mask_enc = (torch.sum(objects_features, -1) == self.padding_idx).unsqueeze(1).unsqueeze( 36 | 1) # (b_s, 1, 1, seq_len) 37 | 38 | objects_features = self.encoder(objects_features, mask_enc) # (B, 3, n_object, 512) 39 | 40 | dec_outputs, intermediate_feats = self.decoder(tokens, objects_features, mask_enc) # (B, max_len, vocab_size) 41 | 42 | return dec_outputs, intermediate_feats, objects_features 43 | 44 | def step(self, t, prev_output, visual, seq, mode='teacher_forcing', **kwargs): 45 | it = None 46 | if mode == 'teacher_forcing': 47 | raise NotImplementedError 48 | elif mode == 'feedback': 49 | if t == 0: 50 | self.mask_enc = (torch.sum(visual, -1) == self.padding_idx).unsqueeze(1).unsqueeze( 51 | 1) # (b_s, 1, 1, seq_len) 52 | self.enc_output = self.encoder(visual, 
self.mask_enc) 53 | if isinstance(visual, torch.Tensor): 54 | it = visual.data.new_full((visual.shape[0], 1), self.bos_idx).long() 55 | else: 56 | it = visual[0].data.new_full((visual[0].shape[0], 1), self.bos_idx).long() 57 | else: 58 | it = prev_output 59 | 60 | output = self.decoder(it, self.enc_output, self.mask_enc)[0] 61 | return F.log_softmax(output, dim=-1) 62 | 63 | def beam_search(self, visual: TensorOrSequence, max_len: int, beam_size: int, out_size=1, 64 | return_probs=False, **kwargs): 65 | bs = BeamSearch(self, max_len, self.eos_idx, beam_size) 66 | return bs.apply(visual, out_size, return_probs, **kwargs) 67 | 68 | 69 | class DualM2Transformer(Module): 70 | def __init__(self, vocab, max_seq_len, object_latent_dim, padding_idx): 71 | super(DualM2Transformer, self).__init__() 72 | self.padding_idx = padding_idx 73 | self.bos_idx = vocab['word2idx']['sos'] 74 | self.eos_idx = vocab['word2idx']['eos'] 75 | self.vocab = vocab 76 | 77 | self.encoder = DualPathMemoryAugmentedEncoder(3, 0, d_in=object_latent_dim, 78 | attention_module=ScaledDotProductAttentionMemory, 79 | attention_module_kwargs={'m': 40}) 80 | # self.decoder_t = MeshedDecoder(len(vocab["word2idx"]), max_seq_len, 1, padding_idx) 81 | self.decoder = MeshedDecoder(len(vocab["word2idx"]), max_seq_len, 1, padding_idx) 82 | 83 | self.register_state('enc_output', None) 84 | self.register_state('mask_enc', None) 85 | 86 | def forward(self, feats, extra_feats, tokens): 87 | # input (b_s, seq_len, d_in) 88 | mask_enc = (torch.sum(feats, -1) == self.padding_idx).unsqueeze(1).unsqueeze( 89 | 1) # (b_s, 1, 1, seq_len) 90 | 91 | feats = self.encoder(feats, extra_feats, mask_enc) # (B, 3, n_object, 512) 92 | 93 | dec_outputs, intermediate_feats = self.decoder(tokens, feats, mask_enc) # (B, max_len, vocab_size) 94 | 95 | return dec_outputs, intermediate_feats 96 | 97 | def step(self, t, prev_output, visual, seq, mode='teacher_forcing', **kwargs): 98 | it = None 99 | if mode == 'teacher_forcing': 100 | raise NotImplementedError 101 | elif mode == 'feedback': 102 | if t == 0: 103 | self.mask_enc = (torch.sum(visual[0], -1) == self.padding_idx).unsqueeze(1).unsqueeze( 104 | 1) # (b_s, 1, 1, seq_len) 105 | self.enc_output = self.encoder(visual[0], visual[1], self.mask_enc) 106 | if isinstance(visual, torch.Tensor): 107 | it = visual.data.new_full((visual.shape[0], 1), self.bos_idx).long() 108 | else: 109 | it = visual[0].data.new_full((visual[0].shape[0], 1), self.bos_idx).long() 110 | else: 111 | it = prev_output 112 | 113 | output = self.decoder(it, self.enc_output, self.mask_enc)[0] 114 | return F.log_softmax(output, dim=-1) 115 | 116 | def beam_search(self, visual: TensorOrSequence, max_len: int, beam_size: int, out_size=1, 117 | return_probs=False, **kwargs): 118 | bs = BeamSearch(self, max_len, self.eos_idx, beam_size) 119 | return bs.apply(visual, out_size, return_probs, **kwargs) -------------------------------------------------------------------------------- /models/transformer/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Sequence 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | TensorOrSequence = Union[Sequence[torch.Tensor], torch.Tensor] 8 | TensorOrNone = Union[torch.Tensor, None] 9 | 10 | 11 | def position_embedding(input, d_model): 12 | input = input.view(-1, 1) 13 | dim = torch.arange(d_model // 2, dtype=torch.float32, device=input.device).view(1, -1) 14 | sin = torch.sin(input / 10000 ** (2 * dim / 
d_model)) 15 | cos = torch.cos(input / 10000 ** (2 * dim / d_model)) 16 | 17 | out = torch.zeros((input.shape[0], d_model), device=input.device) 18 | out[:, ::2] = sin 19 | out[:, 1::2] = cos 20 | return out 21 | 22 | 23 | def sinusoid_encoding_table(max_len, d_model, padding_idx=None): 24 | pos = torch.arange(max_len, dtype=torch.float32) 25 | out = position_embedding(pos, d_model) 26 | 27 | if padding_idx is not None: 28 | out[padding_idx] = 0 29 | return out 30 | 31 | 32 | class PositionWiseFeedForward(nn.Module): 33 | ''' 34 | Position-wise feed forward layer 35 | ''' 36 | 37 | def __init__(self, d_model=512, d_ff=2048, dropout=.1, identity_map_reordering=False): 38 | super(PositionWiseFeedForward, self).__init__() 39 | self.identity_map_reordering = identity_map_reordering 40 | self.fc1 = nn.Linear(d_model, d_ff) 41 | self.fc2 = nn.Linear(d_ff, d_model) 42 | self.dropout = nn.Dropout(p=dropout) 43 | self.dropout_2 = nn.Dropout(p=dropout) 44 | self.layer_norm = nn.LayerNorm(d_model) 45 | 46 | def forward(self, input): 47 | if self.identity_map_reordering: 48 | out = self.layer_norm(input) 49 | out = self.fc2(self.dropout_2(F.relu(self.fc1(out)))) 50 | out = input + self.dropout(torch.relu(out)) 51 | else: 52 | out = self.fc2(self.dropout_2(F.relu(self.fc1(input)))) 53 | out = self.dropout(out) 54 | out = self.layer_norm(input + out) 55 | return out 56 | 57 | 58 | def get_batch_size(x: TensorOrSequence) -> int: 59 | if isinstance(x, torch.Tensor): 60 | b_s = x.size(0) 61 | else: 62 | b_s = x[0].size(0) 63 | return b_s 64 | 65 | 66 | def get_device(x: TensorOrSequence) -> torch.device: 67 | if isinstance(x, torch.Tensor): 68 | b_s = x.device 69 | else: 70 | b_s = x[0].device 71 | return b_s 72 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | from torch import nn, Tensor 2 | import numpy as np 3 | from lib.pointnet2.pointnet2_modules import PointnetSAModule 4 | 5 | 6 | def get_siamese_features(net, in_features, aggregator=None): 7 | """ Applies a network in a siamese way, to 'each' in_feature independently 8 | :param net: nn.Module, Feat-Dim to new-Feat-Dim 9 | :param in_features: B x N-objects x Feat-Dim 10 | :param aggregator, (opt, None, torch.stack, or torch.cat) 11 | :return: B x N-objects x new-Feat-Dim 12 | """ 13 | independent_dim = 1 14 | n_items = in_features.size(independent_dim) 15 | out_features = [] 16 | for i in range(n_items): 17 | out_features.append(net(in_features[:, i])) 18 | if aggregator is not None: 19 | out_features = aggregator(out_features, dim=independent_dim) 20 | return out_features 21 | 22 | 23 | def break_up_pc(pc: Tensor): 24 | """ 25 | Split the pointcloud into xyz positions and features tensors. 26 | This method is taken from VoteNet codebase (https://github.com/facebookresearch/votenet) 27 | 28 | @param pc: pointcloud [N, 3 + C] 29 | :return: the xyz tensor and the feature tensor 30 | """ 31 | xyz = pc[..., 0:3].contiguous() 32 | features = ( 33 | pc[..., 3:].transpose(1, 2).contiguous() 34 | if pc.size(-1) > 3 else None 35 | ) 36 | return xyz, features 37 | 38 | 39 | class PointNetPP(nn.Module): 40 | """ 41 | Pointnet++ encoder. 
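    An illustrative configuration (placeholder values, not the ones used in this repo):
        PointNetPP(sa_n_points=[32, 16, None], sa_n_samples=[32, 32, None],
                   sa_radii=[0.2, 0.4, None],
                   sa_mlps=[[3, 64, 128], [128, 128, 256], [256, 256, 512]])
    The final nn.Linear then maps the flattened last-level features to sa_mlps[-1][-1]
    dimensions.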
42 | For the hyper parameters please advise the paper (https://arxiv.org/abs/1706.02413) 43 | """ 44 | 45 | def __init__(self, sa_n_points: list, 46 | sa_n_samples: list, 47 | sa_radii: list, 48 | sa_mlps: list, 49 | bn=True, 50 | use_xyz=True): 51 | super().__init__() 52 | 53 | n_sa = len(sa_n_points) 54 | if not (n_sa == len(sa_n_samples) == len(sa_radii) == len(sa_mlps)): 55 | raise ValueError('Lens of given hyper-params are not compatible') 56 | 57 | self.encoder = nn.ModuleList() 58 | 59 | for i in range(n_sa): 60 | self.encoder.append(PointnetSAModule( 61 | npoint=sa_n_points[i], 62 | nsample=sa_n_samples[i], 63 | radius=sa_radii[i], 64 | mlp=sa_mlps[i], 65 | bn=bn, 66 | use_xyz=use_xyz, 67 | )) 68 | 69 | out_n_points = sa_n_points[-1] if sa_n_points[-1] is not None else 1 70 | self.fc = nn.Linear(out_n_points * sa_mlps[-1][-1], sa_mlps[-1][-1]) 71 | 72 | def forward(self, features): 73 | """ 74 | @param features: B x N_objects x N_Points x 3 + C 75 | """ 76 | xyz, features = break_up_pc(features) 77 | for i in range(len(self.encoder)): 78 | xyz, features = self.encoder[i](xyz, features) 79 | 80 | return self.fc(features.view(features.size(0), -1)) 81 | 82 | 83 | def show_point_clouds(pts, out): 84 | fout = open(out, 'w') 85 | MEAN_COLOR_RGB = np.array([109.8, 97.2, 83.8]) 86 | color = pts[:, 3:6] + MEAN_COLOR_RGB / 255 87 | for i in range(pts.shape[0]): 88 | fout.write('v %f %f %f %f %f %f\n' % ( 89 | pts[i, 0], pts[i, 1], pts[i, 2], color[i, 0], color[i, 1], color[i, 2])) 90 | fout.close() -------------------------------------------------------------------------------- /models/voting_module.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Voting module: generate votes from XYZ and features of seed points. 3 | 4 | Modified from: https://github.com/facebookresearch/votenet/blob/master/models/voting_module.py 5 | ''' 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | class VotingModule(nn.Module): 12 | def __init__(self, vote_factor, seed_feature_dim): 13 | """ Votes generation from seed point features. 14 | 15 | Args: 16 | vote_facotr: int 17 | number of votes generated from each seed point 18 | seed_feature_dim: int 19 | number of channels of seed point features 20 | vote_feature_dim: int 21 | number of channels of vote features 22 | """ 23 | super().__init__() 24 | self.vote_factor = vote_factor 25 | self.in_dim = seed_feature_dim 26 | self.out_dim = self.in_dim # due to residual feature, in_dim has to be == out_dim 27 | self.conv1 = torch.nn.Conv1d(self.in_dim, self.in_dim, 1) 28 | self.conv2 = torch.nn.Conv1d(self.in_dim, self.in_dim, 1) 29 | self.conv3 = torch.nn.Conv1d(self.in_dim, (3+self.out_dim) * self.vote_factor, 1) 30 | self.bn1 = torch.nn.BatchNorm1d(self.in_dim) 31 | self.bn2 = torch.nn.BatchNorm1d(self.in_dim) 32 | 33 | def forward(self, seed_xyz, seed_features): 34 | """ Forward pass. 
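        Each seed point predicts vote_factor (xyz offset, feature residual) pairs:
        vote_xyz = seed_xyz + offset and vote_features = seed_features + residual,
        so num_vote = num_seed * vote_factor.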
35 | 36 | Arguments: 37 | seed_xyz: (batch_size, num_seed, 3) Pytorch tensor 38 | seed_features: (batch_size, feature_dim, num_seed) Pytorch tensor 39 | Returns: 40 | vote_xyz: (batch_size, num_seed*vote_factor, 3) 41 | vote_features: (batch_size, vote_feature_dim, num_seed*vote_factor) 42 | """ 43 | batch_size = seed_xyz.shape[0] 44 | num_seed = seed_xyz.shape[1] 45 | num_vote = num_seed*self.vote_factor 46 | net = F.relu(self.bn1(self.conv1(seed_features))) 47 | net = F.relu(self.bn2(self.conv2(net))) 48 | net = self.conv3(net) # (batch_size, (3+out_dim)*vote_factor, num_seed) 49 | 50 | net = net.transpose(2,1).view(batch_size, num_seed, self.vote_factor, 3+self.out_dim) 51 | offset = net[:,:,:,0:3] 52 | vote_xyz = seed_xyz.contiguous().unsqueeze(2) + offset.contiguous() 53 | vote_xyz = vote_xyz.contiguous().view(batch_size, num_vote, 3) 54 | 55 | residual_features = net[:,:,:,3:] # (batch_size, num_seed, vote_factor, out_dim) 56 | vote_features = seed_features.transpose(2,1).unsqueeze(2) + residual_features 57 | vote_features = vote_features.contiguous().view(batch_size, num_vote, self.out_dim) 58 | vote_features = vote_features.transpose(2,1).contiguous() 59 | 60 | return vote_xyz, vote_features 61 | 62 | if __name__=='__main__': 63 | net = VotingModule(2, 256).cuda() 64 | xyz, features = net(torch.rand(8,1024,3).cuda(), torch.rand(8,256,1024).cuda()) 65 | print('xyz', xyz.shape) 66 | print('features', features.shape) 67 | -------------------------------------------------------------------------------- /scripts/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import torch 5 | import numpy as np 6 | 7 | from copy import deepcopy 8 | from torch.utils.data import DataLoader 9 | 10 | sys.path.insert(0, os.getcwd()) # HACK add the root folder 11 | 12 | from data.scannet.model_util_scannet import ScannetDatasetConfig 13 | from lib.dataset import Dataset 14 | from lib.config import CONF 15 | from models.xtrans import TransformerCaptionModule 16 | from lib.eval_helper import eval_cap 17 | from in_out.arguments import parse_arguments 18 | 19 | # constants 20 | DC = ScannetDatasetConfig() 21 | 22 | 23 | def get_dataloader(args, scanrefer, all_scene_list): 24 | dataset = Dataset( 25 | scanrefer=scanrefer, 26 | scanrefer_all_scene=all_scene_list, 27 | name=args.dataset, 28 | split='val', 29 | num_points=args.num_points, 30 | augment=False, 31 | use_color=args.use_color, 32 | ) 33 | 34 | dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=4) 35 | 36 | return dataset, dataloader 37 | 38 | 39 | def get_scannet_scene_list(data): 40 | # scene_list = sorted([line.rstrip() for line in open(os.path.join(CONF.PATH.DATA, "ScanRefer_filtered_{}.txt".format(split)))]) 41 | scene_list = sorted(list(set([d["scene_id"] for d in data]))) 42 | 43 | return scene_list 44 | 45 | 46 | def get_eval_data(args): 47 | if args.dataset == "ScanRefer": 48 | scanrefer_train = json.load(open(os.path.join(CONF.PATH.DATA, "ScanRefer_filtered_train.json"))) 49 | scanrefer_val = json.load(open(os.path.join(CONF.PATH.DATA, "ScanRefer_filtered_val.json"))) 50 | elif args.dataset == "Nr3d": 51 | scanrefer_train = json.load(open(os.path.join(CONF.PATH.DATA, "nr3d_train.json"))) 52 | scanrefer_val = json.load(open(os.path.join(CONF.PATH.DATA, "nr3d_val.json"))) 53 | else: 54 | raise ValueError("Invalid dataset.") 55 | 56 | eval_scene_list = get_scannet_scene_list(scanrefer_train) if args.use_train else 
get_scannet_scene_list( 57 | scanrefer_val) 58 | scanrefer_eval = [] 59 | for scene_id in eval_scene_list: 60 | data = deepcopy(scanrefer_train[0]) if args.use_train else deepcopy(scanrefer_val[0]) 61 | data["scene_id"] = scene_id 62 | scanrefer_eval.append(data) 63 | 64 | print("eval on {} samples".format(len(scanrefer_eval))) 65 | 66 | return scanrefer_eval, eval_scene_list 67 | 68 | 69 | def eval_caption(args): 70 | print("initializing...") 71 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 72 | 73 | # get eval data 74 | scanrefer_eval, eval_scene_list = get_eval_data(args) 75 | 76 | # get dataloader 77 | dataset, dataloader = get_dataloader(args, scanrefer_eval, eval_scene_list) 78 | 79 | # get model 80 | model = TransformerCaptionModule(args, dataset) 81 | 82 | # load 83 | model_name = "model_last.pth" if args.use_last else "model.pth" 84 | model_path = os.path.join(CONF.PATH.OUTPUT, args.use_pretrained, model_name) 85 | model.load_state_dict(torch.load(model_path), strict=True) 86 | 87 | model.to(device) 88 | 89 | if args.use_train: 90 | pharse = 'train' 91 | else: 92 | pharse = 'val' 93 | 94 | # evaluate 95 | bleu, cider, rouge, meteor = eval_cap(args.mode, model, dataset, dataloader, pharse, args.use_pretrained, args.use_tf, 96 | force=args.force, save_interm=args.save_interm, min_iou=args.min_iou) 97 | 98 | # report 99 | print("\n----------------------Evaluation-----------------------") 100 | print("[BLEU-1] Mean: {:.4f}, Max: {:.4f}, Min: {:.4f}".format(bleu[0][0], max(bleu[1][0]), min(bleu[1][0]))) 101 | print("[BLEU-2] Mean: {:.4f}, Max: {:.4f}, Min: {:.4f}".format(bleu[0][1], max(bleu[1][1]), min(bleu[1][1]))) 102 | print("[BLEU-3] Mean: {:.4f}, Max: {:.4f}, Min: {:.4f}".format(bleu[0][2], max(bleu[1][2]), min(bleu[1][2]))) 103 | print("[BLEU-4] Mean: {:.4f}, Max: {:.4f}, Min: {:.4f}".format(bleu[0][3], max(bleu[1][3]), min(bleu[1][3]))) 104 | print("[CIDEr] Mean: {:.4f}, Max: {:.4f}, Min: {:.4f}".format(cider[0], max(cider[1]), min(cider[1]))) 105 | print("[ROUGE-L] Mean: {:.4f}, Max: {:.4f}, Min: {:.4f}".format(rouge[0], max(rouge[1]), min(rouge[1]))) 106 | print("[METEOR] Mean: {:.4f}, Max: {:.4f}, Min: {:.4f}".format(meteor[0], max(meteor[1]), min(meteor[1]))) 107 | print() 108 | 109 | 110 | if __name__ == "__main__": 111 | args = parse_arguments() 112 | 113 | # setting 114 | os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(args.gpu) 115 | 116 | # reproducibility 117 | torch.manual_seed(args.seed) 118 | torch.backends.cudnn.deterministic = True 119 | torch.backends.cudnn.benchmark = False 120 | np.random.seed(args.seed) 121 | 122 | # evaluate 123 | eval_caption(args) 124 | -------------------------------------------------------------------------------- /scripts/organize.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | 5 | sys.path.append(os.path.join(os.getcwd())) # HACK add the root folder 6 | from lib.config import CONF 7 | 8 | SCANREFER = json.load(open(os.path.join(CONF.PATH.DATA, "ScanRefer_filtered.json"))) 9 | 10 | organized = {} 11 | for data in SCANREFER: 12 | scene_id = data["scene_id"] 13 | object_id = data["object_id"] 14 | ann_id = data["ann_id"] 15 | 16 | # store 17 | if scene_id not in organized: 18 | organized[scene_id] = {} 19 | 20 | if object_id not in organized[scene_id]: 21 | organized[scene_id][object_id] = {} 22 | 23 | if ann_id not in organized[scene_id][object_id]: 24 | organized[scene_id][object_id][ann_id] = None 25 | 26 | 
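    # The resulting JSON is nested as scene_id -> object_id -> ann_id -> annotation,
    # e.g. organized["scene0000_00"]["3"]["0"] holds a single ScanRefer entry
    # (the ids in this example are illustrative).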
organized[scene_id][object_id][ann_id] = data 27 | 28 | with open(os.path.join(CONF.PATH.DATA, "ScanRefer_organized.json"), "w") as f: 29 | json.dump(organized, f, indent=4) 30 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import h5py 5 | import argparse 6 | import importlib 7 | import torch 8 | import torch.optim as optim 9 | import torch.nn as nn 10 | import numpy as np 11 | 12 | from torch.utils.data import DataLoader 13 | from datetime import datetime 14 | from copy import deepcopy 15 | 16 | sys.path.insert(0, os.getcwd()) # HACK add the root folder 17 | from data.scannet.model_util_scannet import ScannetDatasetConfig 18 | from lib.solver import Solver 19 | from lib.config import CONF 20 | from models.xtrans import TransformerCaptionModule 21 | from lib.dataset import Dataset 22 | from in_out.arguments import parse_arguments 23 | 24 | # constants 25 | DC = ScannetDatasetConfig() 26 | 27 | 28 | def get_dataloader(args, scanrefer, all_scene_list, split, augment): 29 | dataset = Dataset( 30 | scanrefer=scanrefer, 31 | scanrefer_all_scene=all_scene_list, 32 | name=args.dataset, 33 | split=split, 34 | num_points=args.num_points, 35 | augment=augment, 36 | use_color=args.use_color, 37 | ) 38 | is_shuffle = True if split == 'train' else False 39 | dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=is_shuffle, num_workers=8, pin_memory=True) 40 | 41 | return dataset, dataloader 42 | 43 | 44 | def get_num_params(model): 45 | model_parameters = filter(lambda p: p.requires_grad, model.parameters()) 46 | num_params = int(sum([np.prod(p.size()) for p in model_parameters])) 47 | 48 | return num_params 49 | 50 | 51 | def get_solver(args, dataset, dataloader): 52 | # initiate model 53 | model = TransformerCaptionModule(args, dataset["train"]) 54 | # to device 55 | model.cuda() 56 | 57 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd) 58 | 59 | if args.use_rl: 60 | model_path = os.path.join(CONF.PATH.OUTPUT, args.pretrained_path, "model.pth") 61 | model.load_state_dict(torch.load(model_path), strict=True) 62 | 63 | if args.use_checkpoint: 64 | print("loading checkpoint {}...".format(args.use_checkpoint)) 65 | stamp = args.use_checkpoint 66 | root = os.path.join(CONF.PATH.OUTPUT, stamp) 67 | checkpoint = torch.load(os.path.join(CONF.PATH.OUTPUT, args.use_checkpoint, "checkpoint.tar")) 68 | model.load_state_dict(checkpoint["model_state_dict"]) 69 | optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) 70 | else: 71 | stamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 72 | if args.tag: 73 | stamp = args.tag 74 | root = os.path.join(CONF.PATH.OUTPUT, stamp) 75 | os.makedirs(root, exist_ok=True) 76 | 77 | # scheduler parameters for training solely the detection pipeline 78 | LR_DECAY_STEP = [15, 20] 79 | LR_DECAY_RATE = 0.1 80 | BN_DECAY_STEP = None 81 | BN_DECAY_RATE = None 82 | 83 | solver = Solver( 84 | mode=args.mode, 85 | model=model, 86 | config=DC, 87 | dataset=dataset, 88 | dataloader=dataloader, 89 | optimizer=optimizer, 90 | stamp=stamp, 91 | val_step=args.val_step, 92 | use_tf=args.use_tf, 93 | use_rl=args.use_rl, 94 | lr_decay_step=LR_DECAY_STEP, 95 | lr_decay_rate=LR_DECAY_RATE, 96 | bn_decay_step=BN_DECAY_STEP, 97 | bn_decay_rate=BN_DECAY_RATE, 98 | criterion=args.criterion 99 | ) 100 | num_params = get_num_params(model) 101 | print('params: ', 
num_params) 102 | 103 | return solver, num_params, root 104 | 105 | 106 | def save_info(args, root, num_params, dataset): 107 | info = {} 108 | for key, value in vars(args).items(): 109 | info[key] = value 110 | 111 | info["num_train"] = len(dataset["train"]) 112 | info["num_eval_train"] = len(dataset["eval"]["train"]) 113 | info["num_eval_val"] = len(dataset["eval"]["val"]) 114 | info["num_train_scenes"] = len(dataset["train"].scene_list) 115 | info["num_eval_train_scenes"] = len(dataset["eval"]["train"].scene_list) 116 | info["num_eval_val_scenes"] = len(dataset["eval"]["val"].scene_list) 117 | info["num_params"] = num_params 118 | 119 | with open(os.path.join(root, "info.json"), "w") as f: 120 | json.dump(info, f, indent=4) 121 | 122 | 123 | def get_scanrefer(args): 124 | if args.dataset == "ScanRefer": 125 | scanrefer_train = json.load(open(os.path.join(CONF.PATH.DATA, "ScanRefer_filtered_train.json"))) 126 | scanrefer_eval_val = json.load(open(os.path.join(CONF.PATH.DATA, "ScanRefer_filtered_val.json"))) 127 | elif args.dataset == "Nr3d": 128 | scanrefer_train = json.load(open(os.path.join(CONF.PATH.DATA, "nr3d_train.json"))) 129 | scanrefer_eval_val = json.load(open(os.path.join(CONF.PATH.DATA, "nr3d_val.json"))) 130 | else: 131 | raise ValueError("Invalid dataset.") 132 | 133 | if args.debug: 134 | scanrefer_train = [scanrefer_train[0]] 135 | scanrefer_eval_val = [scanrefer_train[0]] 136 | 137 | train_scene_list = sorted(list(set([data["scene_id"] for data in scanrefer_train]))) 138 | val_scene_list = sorted(list(set([data["scene_id"] for data in scanrefer_eval_val]))) 139 | 140 | # eval 141 | scanrefer_eval_train = [] 142 | for scene_id in train_scene_list: 143 | data = deepcopy(scanrefer_train[0]) 144 | data["scene_id"] = scene_id 145 | scanrefer_eval_train.append(data) 146 | 147 | scanrefer_eval_val = [] 148 | for scene_id in val_scene_list: 149 | data = deepcopy(scanrefer_train[0]) 150 | data["scene_id"] = scene_id 151 | scanrefer_eval_val.append(data) 152 | 153 | print("train on {} samples from {} scenes".format(len(scanrefer_eval_train), len(train_scene_list))) 154 | print("eval on {} scenes from train and {} scenes from val".format(len(train_scene_list), len(val_scene_list))) 155 | 156 | return scanrefer_train, scanrefer_eval_train, scanrefer_eval_val, train_scene_list, val_scene_list 157 | 158 | 159 | def train(args): 160 | # init training dataset 161 | print("preparing data...") 162 | scanrefer_train, scanrefer_eval_train, scanrefer_eval_val, train_scene_list, val_scene_list = get_scanrefer(args) 163 | 164 | # dataloader 165 | train_dataset, train_dataloader = get_dataloader(args, scanrefer_train, train_scene_list, "train", not args.no_augment) 166 | eval_train_dataset, eval_train_dataloader = get_dataloader(args, scanrefer_eval_train, train_scene_list, "train", False) 167 | eval_val_dataset, eval_val_dataloader = get_dataloader(args, scanrefer_eval_val, val_scene_list, "val", False) 168 | dataset = { 169 | "train": train_dataset, 170 | "eval": { 171 | "train": eval_train_dataset, 172 | "val": eval_val_dataset 173 | } 174 | } 175 | dataloader = { 176 | "train": train_dataloader, 177 | "eval": { 178 | "train": eval_train_dataloader, 179 | "val": eval_val_dataloader 180 | } 181 | } 182 | 183 | print("initializing...") 184 | solver, num_params, root = get_solver(args, dataset, dataloader) 185 | 186 | print("Start training...\n") 187 | save_info(args, root, num_params, dataset) 188 | solver(args.epoch, args.verbose) 189 | 190 | 191 | if __name__ == "__main__": 192 | args = 
parse_arguments() 193 | 194 | # setting 195 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 196 | 197 | # reproducibility 198 | torch.manual_seed(args.seed) 199 | torch.backends.cudnn.deterministic = True 200 | torch.backends.cudnn.benchmark = False 201 | np.random.seed(args.seed) 202 | 203 | train(args) 204 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CurryYuan/X-Trans2Cap/aebe6e2d421034f2de8742fa9946e669bcb497e6/utils/__init__.py -------------------------------------------------------------------------------- /utils/eta.py: -------------------------------------------------------------------------------- 1 | ''' 2 | File Created: Monday, 25th November 2019 1:35:30 pm 3 | Author: Dave Zhenyu Chen (zhenyu.chen@tum.de) 4 | ''' 5 | 6 | def get_eta(start, end, extra, num_left): 7 | exe_s = end - start 8 | eta_s = (exe_s + extra) * num_left 9 | eta = {'h': 0, 'm': 0, 's': 0} 10 | if eta_s < 60: 11 | eta['s'] = int(eta_s) 12 | elif eta_s >= 60 and eta_s < 3600: 13 | eta['m'] = int(eta_s / 60) 14 | eta['s'] = int(eta_s % 60) 15 | else: 16 | eta['h'] = int(eta_s / (60 * 60)) 17 | eta['m'] = int(eta_s % (60 * 60) / 60) 18 | eta['s'] = int(eta_s % (60 * 60) % 60) 19 | 20 | return eta 21 | 22 | def decode_eta(eta_sec): 23 | eta = {'h': 0, 'm': 0, 's': 0} 24 | if eta_sec < 60: 25 | eta['s'] = int(eta_sec) 26 | elif eta_sec >= 60 and eta_sec < 3600: 27 | eta['m'] = int(eta_sec / 60) 28 | eta['s'] = int(eta_sec % 60) 29 | else: 30 | eta['h'] = int(eta_sec / (60 * 60)) 31 | eta['m'] = int(eta_sec % (60 * 60) / 60) 32 | eta['s'] = int(eta_sec % (60 * 60) % 60) 33 | 34 | return eta -------------------------------------------------------------------------------- /utils/eval_det.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generic Code for Object Detection Evaluation 3 | From: https://github.com/facebookresearch/votenet/blob/master/utils/eval_det.py 4 | 5 | Input: 6 | For each class: 7 | For each image: 8 | Predictions: box, score 9 | Groundtruths: box 10 | 11 | Output: 12 | For each class: 13 | precision-recal and average precision 14 | 15 | Author: Charles R. Qi 16 | 17 | Ref: https://raw.githubusercontent.com/rbgirshick/py-faster-rcnn/master/lib/datasets/voc_eval.py 18 | """ 19 | import numpy as np 20 | 21 | def voc_ap(rec, prec, use_07_metric=False): 22 | """ ap = voc_ap(rec, prec, [use_07_metric]) 23 | Compute VOC AP given precision and recall. 24 | If use_07_metric is true, uses the 25 | VOC 07 11 point method (default:False). 26 | """ 27 | if use_07_metric: 28 | # 11 point metric 29 | ap = 0. 30 | for t in np.arange(0., 1.1, 0.1): 31 | if np.sum(rec >= t) == 0: 32 | p = 0 33 | else: 34 | p = np.max(prec[rec >= t]) 35 | ap = ap + p / 11. 
36 | else: 37 | # correct AP calculation 38 | # first append sentinel values at the end 39 | mrec = np.concatenate(([0.], rec, [1.])) 40 | mpre = np.concatenate(([0.], prec, [0.])) 41 | 42 | # compute the precision envelope 43 | for i in range(mpre.size - 1, 0, -1): 44 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 45 | 46 | # to calculate area under PR curve, look for points 47 | # where X axis (recall) changes value 48 | i = np.where(mrec[1:] != mrec[:-1])[0] 49 | 50 | # and sum (\Delta recall) * prec 51 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 52 | return ap 53 | 54 | import os 55 | import sys 56 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 57 | from utils.metric_util import calc_iou # axis-aligned 3D box IoU 58 | def get_iou(bb1, bb2): 59 | """ Compute IoU of two bounding boxes. 60 | ** Define your bod IoU function HERE ** 61 | """ 62 | #pass 63 | iou3d = calc_iou(bb1, bb2) 64 | return iou3d 65 | 66 | from utils.box_util import box3d_iou 67 | def get_iou_obb(bb1,bb2): 68 | iou3d = box3d_iou(bb1,bb2) 69 | return iou3d 70 | 71 | def get_iou_main(get_iou_func, args): 72 | return get_iou_func(*args) 73 | 74 | def eval_det_cls(pred, gt, ovthresh=0.25, use_07_metric=False, get_iou_func=get_iou): 75 | """ Generic functions to compute precision/recall for object detection 76 | for a single class. 77 | Input: 78 | pred: map of {img_id: [(bbox, score)]} where bbox is numpy array 79 | gt: map of {img_id: [bbox]} 80 | ovthresh: scalar, iou threshold 81 | use_07_metric: bool, if True use VOC07 11 point method 82 | Output: 83 | rec: numpy array of length nd 84 | prec: numpy array of length nd 85 | ap: scalar, average precision 86 | """ 87 | 88 | # construct gt objects 89 | class_recs = {} # {img_id: {'bbox': bbox list, 'det': matched list}} 90 | npos = 0 91 | for img_id in gt.keys(): 92 | bbox = np.array(gt[img_id]) 93 | det = [False] * len(bbox) 94 | npos += len(bbox) 95 | class_recs[img_id] = {'bbox': bbox, 'det': det} 96 | # pad empty list to all other imgids 97 | for img_id in pred.keys(): 98 | if img_id not in gt: 99 | class_recs[img_id] = {'bbox': np.array([]), 'det': []} 100 | 101 | # construct dets 102 | image_ids = [] 103 | confidence = [] 104 | BB = [] 105 | for img_id in pred.keys(): 106 | for box,score in pred[img_id]: 107 | image_ids.append(img_id) 108 | confidence.append(score) 109 | BB.append(box) 110 | confidence = np.array(confidence) 111 | BB = np.array(BB) # (nd,4 or 8,3 or 6) 112 | 113 | # sort by confidence 114 | sorted_ind = np.argsort(-confidence) 115 | sorted_scores = np.sort(-confidence) 116 | BB = BB[sorted_ind, ...] 117 | image_ids = [image_ids[x] for x in sorted_ind] 118 | 119 | # go down dets and mark TPs and FPs 120 | nd = len(image_ids) 121 | tp = np.zeros(nd) 122 | fp = np.zeros(nd) 123 | for d in range(nd): 124 | #if d%100==0: print(d) 125 | R = class_recs[image_ids[d]] 126 | bb = BB[d,...].astype(float) 127 | ovmax = -np.inf 128 | BBGT = R['bbox'].astype(float) 129 | 130 | if BBGT.size > 0: 131 | # compute overlaps 132 | for j in range(BBGT.shape[0]): 133 | iou = get_iou_main(get_iou_func, (bb, BBGT[j,...])) 134 | if iou > ovmax: 135 | ovmax = iou 136 | jmax = j 137 | 138 | #print d, ovmax 139 | if ovmax > ovthresh: 140 | if not R['det'][jmax]: 141 | tp[d] = 1. 142 | R['det'][jmax] = 1 143 | else: 144 | fp[d] = 1. 145 | else: 146 | fp[d] = 1. 
147 | 148 | # compute precision recall 149 | fp = np.cumsum(fp) 150 | tp = np.cumsum(tp) 151 | rec = tp / float(npos + 1e-8) 152 | #print('NPOS: ', npos) 153 | # avoid divide by zero in case the first detection matches a difficult 154 | # ground truth 155 | prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps) 156 | ap = voc_ap(rec, prec, use_07_metric) 157 | 158 | return rec, prec, ap 159 | 160 | def eval_det_cls_wrapper(arguments): 161 | pred, gt, ovthresh, use_07_metric, get_iou_func = arguments 162 | rec, prec, ap = eval_det_cls(pred, gt, ovthresh, use_07_metric, get_iou_func) 163 | return (rec, prec, ap) 164 | 165 | def eval_det(pred_all, gt_all, ovthresh=0.25, use_07_metric=False, get_iou_func=get_iou): 166 | """ Generic functions to compute precision/recall for object detection 167 | for multiple classes. 168 | Input: 169 | pred_all: map of {img_id: [(classname, bbox, score)]} 170 | gt_all: map of {img_id: [(classname, bbox)]} 171 | ovthresh: scalar, iou threshold 172 | use_07_metric: bool, if true use VOC07 11 point method 173 | Output: 174 | rec: {classname: rec} 175 | prec: {classname: prec_all} 176 | ap: {classname: scalar} 177 | """ 178 | pred = {} # map {classname: pred} 179 | gt = {} # map {classname: gt} 180 | for img_id in pred_all.keys(): 181 | for classname, bbox, score in pred_all[img_id]: 182 | if classname not in pred: pred[classname] = {} 183 | if img_id not in pred[classname]: 184 | pred[classname][img_id] = [] 185 | if classname not in gt: gt[classname] = {} 186 | if img_id not in gt[classname]: 187 | gt[classname][img_id] = [] 188 | pred[classname][img_id].append((bbox,score)) 189 | for img_id in gt_all.keys(): 190 | for classname, bbox in gt_all[img_id]: 191 | if classname not in gt: gt[classname] = {} 192 | if img_id not in gt[classname]: 193 | gt[classname][img_id] = [] 194 | gt[classname][img_id].append(bbox) 195 | 196 | rec = {} 197 | prec = {} 198 | ap = {} 199 | for classname in gt.keys(): 200 | print('Computing AP for class: ', classname) 201 | rec[classname], prec[classname], ap[classname] = eval_det_cls(pred[classname], gt[classname], ovthresh, use_07_metric, get_iou_func) 202 | print(classname, ap[classname]) 203 | 204 | return rec, prec, ap 205 | 206 | from multiprocessing import Pool 207 | def eval_det_multiprocessing(pred_all, gt_all, ovthresh=0.25, use_07_metric=False, get_iou_func=get_iou): 208 | """ Generic functions to compute precision/recall for object detection 209 | for multiple classes. 
210 | Input: 211 | pred_all: map of {img_id: [(classname, bbox, score)]} 212 | gt_all: map of {img_id: [(classname, bbox)]} 213 | ovthresh: scalar, iou threshold 214 | use_07_metric: bool, if true use VOC07 11 point method 215 | Output: 216 | rec: {classname: rec} 217 | prec: {classname: prec_all} 218 | ap: {classname: scalar} 219 | """ 220 | pred = {} # map {classname: pred} 221 | gt = {} # map {classname: gt} 222 | for img_id in pred_all.keys(): 223 | for classname, bbox, score in pred_all[img_id]: 224 | if classname not in pred: pred[classname] = {} 225 | if img_id not in pred[classname]: 226 | pred[classname][img_id] = [] 227 | if classname not in gt: gt[classname] = {} 228 | if img_id not in gt[classname]: 229 | gt[classname][img_id] = [] 230 | pred[classname][img_id].append((bbox,score)) 231 | for img_id in gt_all.keys(): 232 | for classname, bbox in gt_all[img_id]: 233 | if classname not in gt: gt[classname] = {} 234 | if img_id not in gt[classname]: 235 | gt[classname][img_id] = [] 236 | gt[classname][img_id].append(bbox) 237 | 238 | rec = {} 239 | prec = {} 240 | ap = {} 241 | p = Pool(processes=10) 242 | ret_values = p.map(eval_det_cls_wrapper, [(pred[classname], gt[classname], ovthresh, use_07_metric, get_iou_func) for classname in gt.keys() if classname in pred]) 243 | p.close() 244 | for i, classname in enumerate(gt.keys()): 245 | if classname in pred: 246 | rec[classname], prec[classname], ap[classname] = ret_values[i] 247 | else: 248 | rec[classname] = 0 249 | prec[classname] = 0 250 | ap[classname] = 0 251 | print(classname, ap[classname]) 252 | 253 | return rec, prec, ap 254 | -------------------------------------------------------------------------------- /utils/metric_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions for metric evaluation. 3 | From: https://github.com/facebookresearch/votenet/blob/master/utils/metric_util.py 4 | 5 | Author: Or Litany and Charles R. Qi 6 | """ 7 | 8 | import os 9 | import sys 10 | import torch 11 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 12 | sys.path.append(BASE_DIR) 13 | 14 | import numpy as np 15 | 16 | # Mesh IO 17 | import trimesh 18 | 19 | 20 | # ---------------------------------------- 21 | # Precision and Recall 22 | # ---------------------------------------- 23 | 24 | def multi_scene_precision_recall(labels, pred, iou_thresh, conf_thresh, label_mask, pred_mask=None): 25 | ''' 26 | Args: 27 | labels: (B, N, 6) 28 | pred: (B, M, 6) 29 | iou_thresh: scalar 30 | conf_thresh: scalar 31 | label_mask: (B, N,) with values in 0 or 1 to indicate which GT boxes to consider. 32 | pred_mask: (B, M,) with values in 0 or 1 to indicate which PRED boxes to consider. 
33 | Returns: 34 | TP,FP,FN,Precision,Recall 35 | ''' 36 | # Make sure the masks are not Torch tensor, otherwise the mask==1 returns uint8 array instead 37 | # of True/False array as in numpy 38 | assert(not torch.is_tensor(label_mask)) 39 | assert(not torch.is_tensor(pred_mask)) 40 | TP, FP, FN = 0, 0, 0 41 | if label_mask is None: label_mask = np.ones((labels.shape[0], labels.shape[1])) 42 | if pred_mask is None: pred_mask = np.ones((pred.shape[0], pred.shape[1])) 43 | for batch_idx in range(labels.shape[0]): 44 | TP_i, FP_i, FN_i = single_scene_precision_recall(labels[batch_idx, label_mask[batch_idx,:]==1, :], 45 | pred[batch_idx, pred_mask[batch_idx,:]==1, :], 46 | iou_thresh, conf_thresh) 47 | TP += TP_i 48 | FP += FP_i 49 | FN += FN_i 50 | 51 | return TP, FP, FN, precision_recall(TP, FP, FN) 52 | 53 | 54 | def single_scene_precision_recall(labels, pred, iou_thresh, conf_thresh): 55 | """Compute P and R for predicted bounding boxes. Ignores classes! 56 | Args: 57 | labels: (N x bbox) ground-truth bounding boxes (6 dims) 58 | pred: (M x (bbox + conf)) predicted bboxes with confidence and maybe classification 59 | Returns: 60 | TP, FP, FN 61 | """ 62 | 63 | 64 | # for each pred box with high conf (C), compute IoU with all gt boxes. 65 | # TP = number of times IoU > th ; FP = C - TP 66 | # FN - number of scene objects without good match 67 | 68 | gt_bboxes = labels[:, :6] 69 | 70 | num_scene_bboxes = gt_bboxes.shape[0] 71 | conf = pred[:, 6] 72 | 73 | conf_pred_bbox = pred[np.where(conf > conf_thresh)[0], :6] 74 | num_conf_pred_bboxes = conf_pred_bbox.shape[0] 75 | 76 | # init an array to keep iou between generated and scene bboxes 77 | iou_arr = np.zeros([num_conf_pred_bboxes, num_scene_bboxes]) 78 | for g_idx in range(num_conf_pred_bboxes): 79 | for s_idx in range(num_scene_bboxes): 80 | iou_arr[g_idx, s_idx] = calc_iou(conf_pred_bbox[g_idx ,:], gt_bboxes[s_idx, :]) 81 | 82 | 83 | good_match_arr = (iou_arr >= iou_thresh) 84 | 85 | TP = good_match_arr.any(axis=1).sum() 86 | FP = num_conf_pred_bboxes - TP 87 | FN = num_scene_bboxes - good_match_arr.any(axis=0).sum() 88 | 89 | return TP, FP, FN 90 | 91 | 92 | def precision_recall(TP, FP, FN): 93 | Prec = 1.0 * TP / (TP + FP) if TP+FP>0 else 0 94 | Rec = 1.0 * TP / (TP + FN) 95 | return Prec, Rec 96 | 97 | 98 | def calc_iou(box_a, box_b): 99 | """Computes IoU of two axis aligned bboxes. 
100 | Args: 101 | box_a, box_b: 6D of center and lengths 102 | Returns: 103 | iou 104 | """ 105 | 106 | max_a = box_a[0:3] + box_a[3:6]/2 107 | max_b = box_b[0:3] + box_b[3:6]/2 108 | min_max = np.array([max_a, max_b]).min(0) 109 | 110 | min_a = box_a[0:3] - box_a[3:6]/2 111 | min_b = box_b[0:3] - box_b[3:6]/2 112 | max_min = np.array([min_a, min_b]).max(0) 113 | if not ((min_max > max_min).all()): 114 | return 0.0 115 | 116 | intersection = (min_max - max_min).prod() 117 | vol_a = box_a[3:6].prod() 118 | vol_b = box_b[3:6].prod() 119 | union = vol_a + vol_b - intersection 120 | return 1.0*intersection / union 121 | 122 | 123 | if __name__ == '__main__': 124 | print('running some tests') 125 | 126 | ############ 127 | ## Test IoU 128 | ############ 129 | box_a = np.array([0,0,0,1,1,1]) 130 | box_b = np.array([0,0,0,2,2,2]) 131 | expected_iou = 1.0/8 132 | pred_iou = calc_iou(box_a, box_b) 133 | assert expected_iou == pred_iou, 'function returned wrong IoU' 134 | 135 | box_a = np.array([0,0,0,1,1,1]) 136 | box_b = np.array([10,10,10,2,2,2]) 137 | expected_iou = 0.0 138 | pred_iou = calc_iou(box_a, box_b) 139 | assert expected_iou == pred_iou, 'function returned wrong IoU' 140 | 141 | print('IoU test -- PASSED') 142 | 143 | ######################### 144 | ## Test Precition Recall 145 | ######################### 146 | gt_boxes = np.array([[0,0,0,1,1,1],[3, 0, 1, 1, 10, 1]]) 147 | detected_boxes = np.array([[0,0,0,1,1,1, 1.0],[3, 0, 1, 1, 10, 1, 0.9]]) 148 | TP, FP, FN = single_scene_precision_recall(gt_boxes, detected_boxes, 0.5, 0.5) 149 | assert TP == 2 and FP == 0 and FN == 0 150 | assert precision_recall(TP, FP, FN) == (1, 1) 151 | 152 | detected_boxes = np.array([[0,0,0,1,1,1, 1.0]]) 153 | TP, FP, FN = single_scene_precision_recall(gt_boxes, detected_boxes, 0.5, 0.5) 154 | assert TP == 1 and FP == 0 and FN == 1 155 | assert precision_recall(TP, FP, FN) == (1, 0.5) 156 | 157 | detected_boxes = np.array([[0,0,0,1,1,1, 1.0], [-1,-1,0,0.1,0.1,1, 1.0]]) 158 | TP, FP, FN = single_scene_precision_recall(gt_boxes, detected_boxes, 0.5, 0.5) 159 | assert TP == 1 and FP == 1 and FN == 1 160 | assert precision_recall(TP, FP, FN) == (0.5, 0.5) 161 | 162 | # wrong box has low confidence 163 | detected_boxes = np.array([[0,0,0,1,1,1, 1.0], [-1,-1,0,0.1,0.1,1, 0.1]]) 164 | TP, FP, FN = single_scene_precision_recall(gt_boxes, detected_boxes, 0.5, 0.5) 165 | assert TP == 1 and FP == 0 and FN == 1 166 | assert precision_recall(TP, FP, FN) == (1, 0.5) 167 | 168 | print('Precition Recall test -- PASSED') 169 | 170 | -------------------------------------------------------------------------------- /utils/nms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from utils.pc_utils import bbox_corner_dist_measure 3 | 4 | # boxes are axis aigned 2D boxes of shape (n,5) in FLOAT numbers with (x1,y1,x2,y2,score) 5 | ''' Ref: https://www.pyimagesearch.com/2015/02/16/faster-non-maximum-suppression-python/ 6 | Ref: https://github.com/vickyboy47/nms-python/blob/master/nms.py 7 | ''' 8 | def nms_2d(boxes, overlap_threshold): 9 | x1 = boxes[:,0] 10 | y1 = boxes[:,1] 11 | x2 = boxes[:,2] 12 | y2 = boxes[:,3] 13 | score = boxes[:,4] 14 | area = (x2-x1)*(y2-y1) 15 | 16 | I = np.argsort(score) 17 | pick = [] 18 | while (I.size!=0): 19 | last = I.size 20 | i = I[-1] 21 | pick.append(i) 22 | suppress = [last-1] 23 | for pos in range(last-1): 24 | j = I[pos] 25 | xx1 = max(x1[i],x1[j]) 26 | yy1 = max(y1[i],y1[j]) 27 | xx2 = min(x2[i],x2[j]) 28 | yy2 = min(y2[i],y2[j]) 
29 | w = xx2-xx1 30 | h = yy2-yy1 31 | if (w>0 and h>0): 32 | o = w*h/area[j] 33 | print('Overlap is', o) 34 | if (o>overlap_threshold): 35 | suppress.append(pos) 36 | I = np.delete(I,suppress) 37 | return pick 38 | 39 | def nms_2d_faster(boxes, overlap_threshold, old_type=False): 40 | x1 = boxes[:,0] 41 | y1 = boxes[:,1] 42 | x2 = boxes[:,2] 43 | y2 = boxes[:,3] 44 | score = boxes[:,4] 45 | area = (x2-x1)*(y2-y1) 46 | 47 | I = np.argsort(score) 48 | pick = [] 49 | while (I.size!=0): 50 | last = I.size 51 | i = I[-1] 52 | pick.append(i) 53 | 54 | xx1 = np.maximum(x1[i], x1[I[:last-1]]) 55 | yy1 = np.maximum(y1[i], y1[I[:last-1]]) 56 | xx2 = np.minimum(x2[i], x2[I[:last-1]]) 57 | yy2 = np.minimum(y2[i], y2[I[:last-1]]) 58 | 59 | w = np.maximum(0, xx2-xx1) 60 | h = np.maximum(0, yy2-yy1) 61 | 62 | if old_type: 63 | o = (w*h)/area[I[:last-1]] 64 | else: 65 | inter = w*h 66 | o = inter / (area[i] + area[I[:last-1]] - inter) 67 | 68 | I = np.delete(I, np.concatenate(([last-1], np.where(o>overlap_threshold)[0]))) 69 | 70 | return pick 71 | 72 | def nms_3d_faster(boxes, overlap_threshold, old_type=False): 73 | x1 = boxes[:,0] 74 | y1 = boxes[:,1] 75 | z1 = boxes[:,2] 76 | x2 = boxes[:,3] 77 | y2 = boxes[:,4] 78 | z2 = boxes[:,5] 79 | score = boxes[:,6] 80 | area = (x2-x1)*(y2-y1)*(z2-z1) 81 | 82 | I = np.argsort(score) 83 | pick = [] 84 | while (I.size!=0): 85 | last = I.size 86 | i = I[-1] 87 | pick.append(i) 88 | 89 | xx1 = np.maximum(x1[i], x1[I[:last-1]]) 90 | yy1 = np.maximum(y1[i], y1[I[:last-1]]) 91 | zz1 = np.maximum(z1[i], z1[I[:last-1]]) 92 | xx2 = np.minimum(x2[i], x2[I[:last-1]]) 93 | yy2 = np.minimum(y2[i], y2[I[:last-1]]) 94 | zz2 = np.minimum(z2[i], z2[I[:last-1]]) 95 | 96 | l = np.maximum(0, xx2-xx1) 97 | w = np.maximum(0, yy2-yy1) 98 | h = np.maximum(0, zz2-zz1) 99 | 100 | if old_type: 101 | o = (l*w*h)/area[I[:last-1]] 102 | else: 103 | inter = l*w*h 104 | o = inter / (area[i] + area[I[:last-1]] - inter) 105 | 106 | I = np.delete(I, np.concatenate(([last-1], np.where(o>overlap_threshold)[0]))) 107 | 108 | return pick 109 | 110 | def nms_3d_faster_samecls(boxes, overlap_threshold, old_type=False): 111 | x1 = boxes[:,0] 112 | y1 = boxes[:,1] 113 | z1 = boxes[:,2] 114 | x2 = boxes[:,3] 115 | y2 = boxes[:,4] 116 | z2 = boxes[:,5] 117 | score = boxes[:,6] 118 | cls = boxes[:,7] 119 | area = (x2-x1)*(y2-y1)*(z2-z1) 120 | 121 | I = np.argsort(score) 122 | pick = [] 123 | while (I.size!=0): 124 | last = I.size 125 | i = I[-1] 126 | pick.append(i) 127 | 128 | xx1 = np.maximum(x1[i], x1[I[:last-1]]) 129 | yy1 = np.maximum(y1[i], y1[I[:last-1]]) 130 | zz1 = np.maximum(z1[i], z1[I[:last-1]]) 131 | xx2 = np.minimum(x2[i], x2[I[:last-1]]) 132 | yy2 = np.minimum(y2[i], y2[I[:last-1]]) 133 | zz2 = np.minimum(z2[i], z2[I[:last-1]]) 134 | cls1 = cls[i] 135 | cls2 = cls[I[:last-1]] 136 | 137 | l = np.maximum(0, xx2-xx1) 138 | w = np.maximum(0, yy2-yy1) 139 | h = np.maximum(0, zz2-zz1) 140 | 141 | if old_type: 142 | o = (l*w*h)/area[I[:last-1]] 143 | else: 144 | inter = l*w*h 145 | o = inter / (area[i] + area[I[:last-1]] - inter + 1e-8) 146 | o = o * (cls1==cls2) 147 | 148 | I = np.delete(I, np.concatenate(([last-1], np.where(o>overlap_threshold)[0]))) 149 | 150 | return pick 151 | 152 | 153 | def nms_crnr_dist(boxes, conf, overlap_threshold): 154 | 155 | I = np.argsort(conf) 156 | pick = [] 157 | while (I.size!=0): 158 | last = I.size 159 | i = I[-1] 160 | pick.append(i) 161 | 162 | scores = [] 163 | for ind in I[:-1]: 164 | scores.append(bbox_corner_dist_measure(boxes[i,:], boxes[ind, :])) 165 | 
166 | I = np.delete(I, np.concatenate(([last-1], np.where(np.array(scores)>overlap_threshold)[0]))) 167 | 168 | return pick 169 | 170 | if __name__=='__main__': 171 | a = np.random.random((100,5)) 172 | print(nms_2d(a,0.9)) 173 | print(nms_2d_faster(a,0.9)) 174 | -------------------------------------------------------------------------------- /utils/nn_distance.py: -------------------------------------------------------------------------------- 1 | """ 2 | Chamfer distance in Pytorch. 3 | Author: Charles R. Qi 4 | 5 | From: https://github.com/facebookresearch/votenet/blob/master/utils/nn_distance.py 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import numpy as np 11 | 12 | 13 | def huber_loss(error, delta=1.0): 14 | """ 15 | Args: 16 | error: Torch tensor (d1,d2,...,dk) 17 | Returns: 18 | loss: Torch tensor (d1,d2,...,dk) 19 | 20 | x = error = pred - gt or dist(pred,gt) 21 | 0.5 * |x|^2 if |x|<=d 22 | 0.5 * d^2 + d * (|x|-d) if |x|>d 23 | Ref: https://github.com/charlesq34/frustum-pointnets/blob/master/models/model_util.py 24 | """ 25 | abs_error = torch.abs(error) 26 | # quadratic = torch.min(abs_error, torch.FloatTensor([delta])) 27 | quadratic = torch.clamp(abs_error, max=delta) 28 | linear = (abs_error - quadratic) 29 | loss = 0.5 * quadratic ** 2 + delta * linear 30 | return loss 31 | 32 | 33 | def nn_distance(pc1, pc2, l1smooth=False, delta=1.0, l1=False): 34 | """ 35 | Input: 36 | pc1: (B,N,C) torch tensor 37 | pc2: (B,M,C) torch tensor 38 | l1smooth: bool, whether to use l1smooth loss 39 | delta: scalar, the delta used in l1smooth loss 40 | Output: 41 | dist1: (B,N) torch float32 tensor 42 | idx1: (B,N) torch int64 tensor 43 | dist2: (B,M) torch float32 tensor 44 | idx2: (B,M) torch int64 tensor 45 | """ 46 | N = pc1.shape[1] 47 | M = pc2.shape[1] 48 | pc1_expand_tile = pc1.unsqueeze(2).repeat(1, 1, M, 1) 49 | pc2_expand_tile = pc2.unsqueeze(1).repeat(1, N, 1, 1) 50 | pc_diff = pc1_expand_tile - pc2_expand_tile 51 | 52 | if l1smooth: 53 | pc_dist = torch.sum(huber_loss(pc_diff, delta), dim=-1) # (B,N,M) 54 | elif l1: 55 | pc_dist = torch.sum(torch.abs(pc_diff), dim=-1) # (B,N,M) 56 | else: 57 | pc_dist = torch.sum(pc_diff ** 2, dim=-1) # (B,N,M) 58 | dist1, idx1 = torch.min(pc_dist, dim=2) # (B,N) 59 | dist2, idx2 = torch.min(pc_dist, dim=1) # (B,M) 60 | return dist1, idx1, dist2, idx2 61 | 62 | 63 | def demo_nn_distance(): 64 | np.random.seed(0) 65 | pc1arr = np.random.random((1, 5, 3)) 66 | pc2arr = np.random.random((1, 6, 3)) 67 | pc1 = torch.from_numpy(pc1arr.astype(np.float32)) 68 | pc2 = torch.from_numpy(pc2arr.astype(np.float32)) 69 | dist1, idx1, dist2, idx2 = nn_distance(pc1, pc2) 70 | print(dist1) 71 | print(idx1) 72 | dist = np.zeros((5, 6)) 73 | for i in range(5): 74 | for j in range(6): 75 | dist[i, j] = np.sum((pc1arr[0, i, :] - pc2arr[0, j, :]) ** 2) 76 | print(dist) 77 | print('-' * 30) 78 | print('L1smooth dists:') 79 | dist1, idx1, dist2, idx2 = nn_distance(pc1, pc2, True) 80 | print(dist1) 81 | print(idx1) 82 | dist = np.zeros((5, 6)) 83 | for i in range(5): 84 | for j in range(6): 85 | error = np.abs(pc1arr[0, i, :] - pc2arr[0, j, :]) 86 | quad = np.minimum(error, 1.0) 87 | linear = error - quad 88 | loss = 0.5 * quad ** 2 + 1.0 * linear 89 | dist[i, j] = np.sum(loss) 90 | print(dist) 91 | 92 | 93 | if __name__ == '__main__': 94 | demo_nn_distance() 95 | --------------------------------------------------------------------------------
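The standalone helpers in `utils/metric_util.py`, `utils/nms.py`, and `utils/nn_distance.py` above each expect a different box or point-cloud layout. The following minimal sketch (an editorial illustration, not part of the repository) shows those expected inputs, assuming the repository root is on `sys.path` and its dependencies (numpy, torch, trimesh) are installed:

```python
# Minimal usage sketch for the utility modules listed above (illustrative only).
import numpy as np
import torch

from utils.metric_util import calc_iou      # boxes as (cx, cy, cz, dx, dy, dz)
from utils.nms import nms_3d_faster         # boxes as (x1, y1, z1, x2, y2, z2, score)
from utils.nn_distance import nn_distance   # point clouds as (B, N, 3) torch tensors

# Axis-aligned IoU on center-plus-size boxes.
box_a = np.array([0.0, 0.0, 0.0, 1.0, 1.0, 1.0])
box_b = np.array([0.5, 0.0, 0.0, 1.0, 1.0, 1.0])
print("IoU:", calc_iou(box_a, box_b))  # 0.5 / 1.5 = 0.333...

# 3D NMS on min/max-corner boxes with a confidence score in the last column.
boxes = np.array([
    [0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.9],
    [0.1, 0.0, 0.0, 1.1, 1.0, 1.0, 0.8],  # heavy overlap with the first box -> suppressed
    [5.0, 5.0, 5.0, 6.0, 6.0, 6.0, 0.7],
])
keep = nms_3d_faster(boxes, overlap_threshold=0.25)
print("kept indices:", keep)  # [0, 2]

# Chamfer-style nearest-neighbour distances between two point clouds.
pc1 = torch.rand(1, 128, 3)
pc2 = torch.rand(1, 256, 3)
dist1, idx1, dist2, idx2 = nn_distance(pc1, pc2)
print(dist1.shape, dist2.shape)  # torch.Size([1, 128]) torch.Size([1, 256])
```

Note the two box conventions: `calc_iou` takes center-plus-size boxes, while the NMS routines take min/max corners with a trailing confidence column (and, for `nms_3d_faster_samecls`, an additional class-id column).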