├── README.md ├── environment.yml ├── meshed-memory-transformer-master ├── LICENSE ├── README.md ├── data │ ├── __init__.py │ ├── dataset.py │ ├── example.py │ ├── field.py │ ├── utils.py │ └── vocab.py ├── environment.yml ├── evaluation │ ├── __init__.py │ ├── bleu │ │ ├── __init__.py │ │ ├── bleu.py │ │ └── bleu_scorer.py │ ├── cider │ │ ├── __init__.py │ │ ├── cider.py │ │ └── cider_scorer.py │ ├── meteor │ │ ├── __init__.py │ │ └── meteor.py │ ├── rouge │ │ ├── __init__.py │ │ └── rouge.py │ └── tokenizer.py ├── images │ ├── m2.png │ └── results.png ├── models │ ├── __init__.py │ ├── beam_search │ │ ├── __init__.py │ │ └── beam_search.py │ ├── captioning_model.py │ ├── containers.py │ └── transformer │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── decoders.py │ │ ├── encoders.py │ │ ├── transformer.py │ │ └── utils.py ├── output_logs │ └── meshed_memory_transformer_test_o ├── test.py ├── train.py ├── utils │ ├── __init__.py │ ├── typing.py │ └── utils.py └── vocab.pkl ├── model ├── Model.py ├── __init__.py ├── __pycache__ │ ├── Model.cpython-36.pyc │ ├── __init__.cpython-36.pyc │ └── view_transformer_random4.cpython-36.pyc ├── best.py ├── gvcnn.py ├── gvcnn_random.py ├── mvcnn.py ├── mvcnn_random.py ├── view_transformer.py └── view_transformer_random4.py ├── models ├── Vit.py ├── __init__.py ├── __pycache__ │ ├── Vit.cpython-36.pyc │ ├── Vit.cpython-38.pyc │ ├── __init__.cpython-36.pyc │ ├── captioning_model.cpython-36.pyc │ ├── containers.cpython-36.pyc │ ├── utils.cpython-36.pyc │ └── utils.cpython-38.pyc ├── beam_search │ ├── __init__.py │ ├── __pycache__ │ │ └── __init__.cpython-36.pyc │ └── beam_search.py ├── captioning_model.py ├── containers.py ├── transformer │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── attention.cpython-36.pyc │ │ ├── decoders.cpython-36.pyc │ │ ├── encoders.cpython-36.pyc │ │ ├── transformer.cpython-36.pyc │ │ └── utils.cpython-36.pyc │ ├── attention.py │ ├── decoders.py │ ├── encoders.py │ ├── transformer.py │ └── utils.py ├── utils.py └── utils │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── typing.cpython-36.pyc │ └── utils.cpython-36.pyc │ ├── typing.py │ └── utils.py ├── otk ├── cross_attention.py ├── data_utils.py ├── layers.py ├── models.py ├── models_deepsea.py ├── sinkhorn.py └── utils.py ├── tools ├── ImgDataset_m40_6_20.py ├── ImgDataset_obj_6_20.py ├── Trainer_ours_m40r4.py ├── __init__.py ├── replas.py ├── test_tools.py └── view_gcn_utils.py └── train4.py /README.md: -------------------------------------------------------------------------------- 1 | # Pytorch code for CVR [ICCV2021]. 2 | 3 | Xin Wei*, Yifei Gong*, Fudong Wang, Xing Sun, Jian Sun. **Learning Canonical View Representation for 3D Shape Recognition with Arbitrary Views**. ICCV, accepted, 2021. 
[[pdf]](https://openaccess.thecvf.com/content/ICCV2021/papers/Wei_Learning_Canonical_View_Representation_for_3D_Shape_Recognition_With_Arbitrary_ICCV_2021_paper.pdf) [[supp]](https://openaccess.thecvf.com/content/ICCV2021/supplemental/Wei_Learning_Canonical_View_ICCV_2021_supplemental.pdf) 4 | 5 | ## Citation 6 | If you find our work useful in your research, please consider citing: 7 | ``` 8 | @InProceedings{Wei_2021_ICCV, 9 | author = {Wei, Xin and Gong, Yifei and Wang, Fudong and Sun, Xing and Sun, Jian}, 10 | title = {Learning Canonical View Representation for 3D Shape Recognition With Arbitrary Views}, 11 | booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)}, 12 | month = {October}, 13 | year = {2021}, 14 | pages = {407-416} 15 | } 16 | ``` 17 | 18 | ## Training 19 | 20 | ### Requirements 21 | 22 | This code is tested with Python 3.6 and PyTorch 1.8+. 23 | 24 | ### Dataset 25 | 26 | First, download the arbitrary views ModelNet40 dataset and put it under `data`: 27 | 28 | `https://drive.google.com/file/d/1RfE0aJ_IXNspVs610BkMgcWDP6FB0cYX/view?usp=sharing` 29 | 30 | Link to the arbitrary views ScanObjectNN dataset: 31 | 32 | `https://drive.google.com/file/d/10xl-S8-XlaX5187Dkv91pX4JQ5DZxeM8/view?usp=sharing` 33 | 34 | Aligned-ScanObjectNN dataset: `https://drive.google.com/file/d/1ihR6Fv88-6FOVUWdfHVMfDbUrx2eIPpR/view?usp=sharing` 35 | 36 | Rotated-ScanObjectNN dataset: `https://drive.google.com/file/d/1GCwgrfbO_uO3Qh9UNPWRCuz2yr8UyRRT/view?usp=sharing` 37 | 38 | #### The code is borrowed from [[OTK]](https://github.com/claying/OTK) and [[view-GCN]](https://github.com/weixmath/view-GCN). 39 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: view 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=conda_forge 8 | - _openmp_mutex=4.5=1_llvm 9 | - blas=1.0=mkl 10 | - bzip2=1.0.8=h7f98852_4 11 | - ca-certificates=2021.1.19=h06a4308_1 12 | - certifi=2020.12.5=py38h06a4308_0 13 | - cudatoolkit=11.1.1=h6406543_8 14 | - ffmpeg=4.3=hf484d3e_0 15 | - freetype=2.10.4=h0708190_1 16 | - gmp=6.2.1=h58526e2_0 17 | - gnutls=3.6.13=h85f3911_1 18 | - jpeg=9b=h024ee3a_2 19 | - lame=3.100=h7f98852_1001 20 | - lcms2=2.12=h3be6417_0 21 | - ld_impl_linux-64=2.35.1=hea4e1c9_2 22 | - libffi=3.3=h58526e2_2 23 | - libgcc-ng=9.3.0=h2828fa1_19 24 | - libiconv=1.16=h516909a_0 25 | - libpng=1.6.37=h21135ba_2 26 | - libstdcxx-ng=9.3.0=h6de172a_19 27 | - libtiff=4.1.0=h2733197_1 28 | - libuv=1.41.0=h7f98852_0 29 | - llvm-openmp=11.1.0=h4bd325d_1 30 | - lz4-c=1.9.3=h9c3ff4c_0 31 | - mkl=2020.4=h726a3e6_304 32 | - mkl-service=2.3.0=py38h1e0a361_2 33 | - mkl_fft=1.3.0=py38h5c078b8_1 34 | - mkl_random=1.2.0=py38hc5bc63f_1 35 | - ncurses=6.2=h58526e2_4 36 | - nettle=3.6=he412f7d_0 37 | - ninja=1.10.2=h4bd325d_0 38 | - numpy=1.19.2=py38h54aff64_0 39 | - numpy-base=1.19.2=py38hfa32c7d_0 40 | - olefile=0.46=pyh9f0ad1d_1 41 | - openh264=2.1.1=h780b84a_0 42 | - openssl=1.1.1k=h27cfd23_0 43 | - pillow=8.2.0=py38he98fc37_0 44 | - pip=21.0.1=pyhd8ed1ab_0 45 | - python=3.8.8=hffdb5ce_0_cpython 46 | - python_abi=3.8=1_cp38 47 | - pytorch=1.8.1=py3.8_cuda11.1_cudnn8.0.5_0 48 | - readline=8.0=he28a2e2_2 49 | - setuptools=49.6.0=py38h578d9bd_3 50 | - six=1.15.0=pyh9f0ad1d_0 51 | - sqlite=3.35.4=h74cdb3f_0 52 | - tk=8.6.10=h21135ba_1 53 | - torchaudio=0.8.1=py38 54 | - torchvision=0.9.1=py38_cu111 55 | - 
typing_extensions=3.7.4.3=py_0 56 | - wheel=0.36.2=pyhd3deb0d_0 57 | - xz=5.2.5=h516909a_1 58 | - zlib=1.2.11=h516909a_1010 59 | - zstd=1.4.9=ha95c52a_0 60 | - pip: 61 | - chardet==4.0.0 62 | - idna==2.10 63 | - opencv-python==4.5.3.56 64 | - protobuf==3.15.8 65 | - requests==2.25.1 66 | - scipy==1.6.3 67 | - tensorboardx==2.2 68 | - urllib3==1.26.4 69 | prefix: /public/home/weixin/.conda/envs/view 70 | -------------------------------------------------------------------------------- /meshed-memory-transformer-master/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2019, AImageLab 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /meshed-memory-transformer-master/README.md: -------------------------------------------------------------------------------- 1 | # M²: Meshed-Memory Transformer 2 | This repository contains the reference code for the paper _[Meshed-Memory Transformer for Image Captioning](https://arxiv.org/abs/1912.08226)_ (CVPR 2020). 3 | 4 | Please cite with the following BibTeX: 5 | 6 | ``` 7 | @inproceedings{cornia2020m2, 8 | title={{Meshed-Memory Transformer for Image Captioning}}, 9 | author={Cornia, Marcella and Stefanini, Matteo and Baraldi, Lorenzo and Cucchiara, Rita}, 10 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 11 | year={2020} 12 | } 13 | ``` 14 |

15 | ![Meshed-Memory Transformer](images/m2.png)

17 | 18 | ## Environment setup 19 | Clone the repository and create the `m2release` conda environment using the `environment.yml` file: 20 | ``` 21 | conda env create -f environment.yml 22 | conda activate m2release 23 | ``` 24 | 25 | Then download spacy data by executing the following command: 26 | ``` 27 | python -m spacy download en 28 | ``` 29 | 30 | Note: Python 3.6 is required to run our code. 31 | 32 | 33 | ## Data preparation 34 | To run the code, annotations and detection features for the COCO dataset are needed. Please download the annotations file [annotations.zip](https://drive.google.com/file/d/1i8mqKFKhqvBr8kEp3DbIh9-9UNAfKGmE/view?usp=sharing) and extract it. 35 | 36 | Detection features are computed with the code provided by [1]. To reproduce our result, please download the COCO features file [coco_detections.hdf5](https://drive.google.com/open?id=1MV6dSnqViQfyvgyHrmAT_lLpFbkzp3mx) (~53.5 GB), in which detections of each image are stored under the `<image_id>_features` key. `<image_id>` is the id of each COCO image, without leading zeros (e.g. the `<image_id>` for `COCO_val2014_000000037209.jpg` is `37209`), and each value should be a `(N, 2048)` tensor, where `N` is the number of detections. 37 | 38 | 39 | ## Evaluation 40 | To reproduce the results reported in our paper, download the pretrained model file [meshed_memory_transformer.pth](https://drive.google.com/file/d/1naUSnVqXSMIdoiNz_fjqKwh9tXF7x8Nx/view?usp=sharing) and place it in the code folder. 41 | 42 | Run `python test.py` using the following arguments: 43 | 44 | | Argument | Possible values | 45 | |------|------| 46 | | `--batch_size` | Batch size (default: 10) | 47 | | `--workers` | Number of workers (default: 0) | 48 | | `--features_path` | Path to detection features file | 49 | | `--annotation_folder` | Path to folder with COCO annotations | 50 | 51 | #### Expected output 52 | Under `output_logs/`, you may also find the expected output of the evaluation code. 53 | 54 | 55 | ## Training procedure 56 | Run `python train.py` using the following arguments: 57 | 58 | | Argument | Possible values | 59 | |------|------| 60 | | `--exp_name` | Experiment name| 61 | | `--batch_size` | Batch size (default: 10) | 62 | | `--workers` | Number of workers (default: 0) | 63 | | `--m` | Number of memory vectors (default: 40) | 64 | | `--head` | Number of heads (default: 8) | 65 | | `--warmup` | Warmup value for learning rate scheduling (default: 10000) | 66 | | `--resume_last` | If used, the training will be resumed from the last checkpoint. | 67 | | `--resume_best` | If used, the training will be resumed from the best checkpoint. | 68 | | `--features_path` | Path to detection features file | 69 | | `--annotation_folder` | Path to folder with COCO annotations | 70 | | `--logs_folder` | Path to folder for tensorboard logs (default: "tensorboard_logs")| 71 | 72 | For example, to train our model with the parameters used in our experiments, use 73 | ``` 74 | python train.py --exp_name m2_transformer --batch_size 50 --m 40 --head 8 --warmup 10000 --features_path /path/to/features --annotation_folder /path/to/annotations 75 | ``` 76 | 77 |
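Before pointing `test.py` or `train.py` at a features file via `--features_path`, it can be useful to sanity-check that the detection features follow the layout described in the Data preparation section above. The snippet below is a minimal sketch and is not part of the original codebase: the file path is a placeholder, the key pattern `<image_id>_features` and the example id `37209` are taken from the README description, and `h5py` is assumed to be available (it is listed among the pip requirements of the `m2release` environment).

```
# Minimal sketch (not part of the original repo): inspect coco_detections.hdf5
# and check that one "<image_id>_features" dataset of shape (N, 2048) exists
# per image, as described in the Data preparation section.
import h5py

features_path = '/path/to/coco_detections.hdf5'  # placeholder path
image_id = 37209  # e.g. COCO_val2014_000000037209.jpg

with h5py.File(features_path, 'r') as f:
    key = '%d_features' % image_id
    feats = f[key][()]       # numpy array of shape (N, 2048)
    print(key, feats.shape)  # N = number of detections for this image
```

If the shapes look as expected, the same file can then be passed to the evaluation or training scripts through `--features_path`.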

78 | ![Sample Results](images/results.png)

80 | 81 | #### References 82 | [1] P. Anderson, X. He, C. Buehler, D. Teney, M. Johnson, S. Gould, and L. Zhang. Bottom-up and top-down attention for image captioning and visual question answering. In _Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition_, 2018. 83 | -------------------------------------------------------------------------------- /meshed-memory-transformer-master/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .field import RawField, Merge, ImageDetectionsField, TextField 2 | from .dataset import COCO 3 | from torch.utils.data import DataLoader as TorchDataLoader 4 | 5 | class DataLoader(TorchDataLoader): 6 | def __init__(self, dataset, *args, **kwargs): 7 | super(DataLoader, self).__init__(dataset, *args, collate_fn=dataset.collate_fn(), **kwargs) 8 | -------------------------------------------------------------------------------- /meshed-memory-transformer-master/data/example.py: -------------------------------------------------------------------------------- 1 | 2 | class Example(object): 3 | """Defines a single training or test example. 4 | Stores each column of the example as an attribute. 5 | """ 6 | @classmethod 7 | def fromdict(cls, data): 8 | ex = cls(data) 9 | return ex 10 | 11 | def __init__(self, data): 12 | for key, val in data.items(): 13 | super(Example, self).__setattr__(key, val) 14 | 15 | def __setattr__(self, key, value): 16 | raise AttributeError 17 | 18 | def __hash__(self): 19 | return hash(tuple(x for x in self.__dict__.values())) 20 | 21 | def __eq__(self, other): 22 | this = tuple(x for x in self.__dict__.values()) 23 | other = tuple(x for x in other.__dict__.values()) 24 | return this == other 25 | 26 | def __ne__(self, other): 27 | return not self.__eq__(other) 28 | -------------------------------------------------------------------------------- /meshed-memory-transformer-master/data/utils.py: -------------------------------------------------------------------------------- 1 | import contextlib, sys 2 | 3 | class DummyFile(object): 4 | def write(self, x): pass 5 | 6 | @contextlib.contextmanager 7 | def nostdout(): 8 | save_stdout = sys.stdout 9 | sys.stdout = DummyFile() 10 | yield 11 | sys.stdout = save_stdout 12 | 13 | def reporthook(t): 14 | """https://github.com/tqdm/tqdm""" 15 | last_b = [0] 16 | 17 | def inner(b=1, bsize=1, tsize=None): 18 | """ 19 | b: int, optionala 20 | Number of blocks just transferred [default: 1]. 21 | bsize: int, optional 22 | Size of each block (in tqdm units) [default: 1]. 23 | tsize: int, optional 24 | Total size (in tqdm units). If [default: None] remains unchanged. 25 | """ 26 | if tsize is not None: 27 | t.total = tsize 28 | t.update((b - last_b[0]) * bsize) 29 | last_b[0] = b 30 | return inner 31 | 32 | def get_tokenizer(tokenizer): 33 | if callable(tokenizer): 34 | return tokenizer 35 | if tokenizer == "spacy": 36 | try: 37 | import spacy 38 | spacy_en = spacy.load('en') 39 | return lambda s: [tok.text for tok in spacy_en.tokenizer(s)] 40 | except ImportError: 41 | print("Please install SpaCy and the SpaCy English tokenizer. " 42 | "See the docs at https://spacy.io for more information.") 43 | raise 44 | except AttributeError: 45 | print("Please install SpaCy and the SpaCy English tokenizer. 
" 46 | "See the docs at https://spacy.io for more information.") 47 | raise 48 | elif tokenizer == "moses": 49 | try: 50 | from nltk.tokenize.moses import MosesTokenizer 51 | moses_tokenizer = MosesTokenizer() 52 | return moses_tokenizer.tokenize 53 | except ImportError: 54 | print("Please install NLTK. " 55 | "See the docs at http://nltk.org for more information.") 56 | raise 57 | except LookupError: 58 | print("Please install the necessary NLTK corpora. " 59 | "See the docs at http://nltk.org for more information.") 60 | raise 61 | elif tokenizer == 'revtok': 62 | try: 63 | import revtok 64 | return revtok.tokenize 65 | except ImportError: 66 | print("Please install revtok.") 67 | raise 68 | elif tokenizer == 'subword': 69 | try: 70 | import revtok 71 | return lambda x: revtok.tokenize(x, decap=True) 72 | except ImportError: 73 | print("Please install revtok.") 74 | raise 75 | raise ValueError("Requested tokenizer {}, valid choices are a " 76 | "callable that takes a single string as input, " 77 | "\"revtok\" for the revtok reversible tokenizer, " 78 | "\"subword\" for the revtok caps-aware tokenizer, " 79 | "\"spacy\" for the SpaCy English tokenizer, or " 80 | "\"moses\" for the NLTK port of the Moses tokenization " 81 | "script.".format(tokenizer)) 82 | -------------------------------------------------------------------------------- /meshed-memory-transformer-master/environment.yml: -------------------------------------------------------------------------------- 1 | name: m2release 2 | channels: 3 | - anaconda 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - asn1crypto=1.2.0=py36_0 8 | - blas=1.0=mkl 9 | - ca-certificates=2019.10.16=0 10 | - certifi=2019.9.11=py36_0 11 | - cffi=1.13.2=py36h2e261b9_0 12 | - chardet=3.0.4=py36_1003 13 | - cryptography=2.8=py36h1ba5d50_0 14 | - cython=0.29.14=py36he6710b0_0 15 | - dill=0.2.9=py36_0 16 | - idna=2.8=py36_0 17 | - intel-openmp=2019.5=281 18 | - libedit=3.1.20181209=hc058e9b_0 19 | - libffi=3.2.1=hd88cf55_4 20 | - libgcc-ng=9.1.0=hdf63c60_0 21 | - libgfortran-ng=7.3.0=hdf63c60_0 22 | - libstdcxx-ng=9.1.0=hdf63c60_0 23 | - mkl=2019.5=281 24 | - mkl-service=2.3.0=py36he904b0f_0 25 | - mkl_fft=1.0.15=py36ha843d7b_0 26 | - mkl_random=1.1.0=py36hd6b4f25_0 27 | - msgpack-numpy=0.4.4.3=py_0 28 | - msgpack-python=0.5.6=py36h6bb024c_1 29 | - ncurses=6.1=he6710b0_1 30 | - openjdk=8.0.152=h46b5887_1 31 | - openssl=1.1.1=h7b6447c_0 32 | - pip=19.3.1=py36_0 33 | - pycparser=2.19=py_0 34 | - pyopenssl=19.1.0=py36_0 35 | - pysocks=1.7.1=py36_0 36 | - python=3.6.9=h265db76_0 37 | - readline=7.0=h7b6447c_5 38 | - requests=2.22.0=py36_0 39 | - setuptools=41.6.0=py36_0 40 | - six=1.13.0=py36_0 41 | - spacy=2.0.11=py36h04863e7_2 42 | - sqlite=3.30.1=h7b6447c_0 43 | - termcolor=1.1.0=py36_1 44 | - thinc=6.11.2=py36hedc7406_1 45 | - tk=8.6.8=hbc83047_0 46 | - toolz=0.10.0=py_0 47 | - urllib3=1.24.2=py36_0 48 | - wheel=0.33.6=py36_0 49 | - xz=5.2.4=h14c3975_4 50 | - zlib=1.2.11=h7b6447c_3 51 | - pip: 52 | - absl-py==0.8.1 53 | - cycler==0.10.0 54 | - cymem==1.31.2 55 | - cytoolz==0.9.0.1 56 | - future==0.17.1 57 | - grpcio==1.25.0 58 | - h5py==2.8.0 59 | - kiwisolver==1.1.0 60 | - markdown==3.1.1 61 | - matplotlib==2.2.3 62 | - msgpack==0.6.2 63 | - multiprocess==0.70.9 64 | - murmurhash==0.28.0 65 | - numpy==1.16.4 66 | - pathlib==1.0.1 67 | - pathos==0.2.3 68 | - pillow==6.2.1 69 | - plac==0.9.6 70 | - pox==0.2.7 71 | - ppft==1.6.6.1 72 | - preshed==1.0.1 73 | - protobuf==3.10.0 74 | - pycocotools==2.0.0 75 | - pyparsing==2.4.5 76 | - 
python-dateutil==2.8.1 77 | - pytz==2019.3 78 | - regex==2017.4.5 79 | - tensorboard==1.14.0 80 | - torch==1.1.0 81 | - torchvision==0.3.0 82 | - tqdm==4.32.2 83 | - ujson==1.35 84 | - werkzeug==0.16.0 85 | - wrapt==1.10.11 86 | 87 | -------------------------------------------------------------------------------- /meshed-memory-transformer-master/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .bleu import Bleu 2 | from .meteor import Meteor 3 | from .rouge import Rouge 4 | from .cider import Cider 5 | from .tokenizer import PTBTokenizer 6 | 7 | def compute_scores(gts, gen): 8 | metrics = (Bleu(), Meteor(), Rouge(), Cider()) 9 | all_score = {} 10 | all_scores = {} 11 | for metric in metrics: 12 | score, scores = metric.compute_score(gts, gen) 13 | all_score[str(metric)] = score 14 | all_scores[str(metric)] = scores 15 | 16 | return all_score, all_scores 17 | -------------------------------------------------------------------------------- /meshed-memory-transformer-master/evaluation/bleu/__init__.py: -------------------------------------------------------------------------------- 1 | from .bleu import Bleu -------------------------------------------------------------------------------- /meshed-memory-transformer-master/evaluation/bleu/bleu.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : bleu.py 4 | # 5 | # Description : Wrapper for BLEU scorer. 6 | # 7 | # Creation Date : 06-01-2015 8 | # Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT 9 | # Authors : Hao Fang and Tsung-Yi Lin 10 | 11 | from .bleu_scorer import BleuScorer 12 | 13 | 14 | class Bleu: 15 | def __init__(self, n=4): 16 | # default compute Blue score up to 4 17 | self._n = n 18 | self._hypo_for_image = {} 19 | self.ref_for_image = {} 20 | 21 | def compute_score(self, gts, res): 22 | 23 | assert(gts.keys() == res.keys()) 24 | imgIds = gts.keys() 25 | 26 | bleu_scorer = BleuScorer(n=self._n) 27 | for id in imgIds: 28 | hypo = res[id] 29 | ref = gts[id] 30 | 31 | # Sanity check. 32 | assert(type(hypo) is list) 33 | assert(len(hypo) == 1) 34 | assert(type(ref) is list) 35 | assert(len(ref) >= 1) 36 | 37 | bleu_scorer += (hypo[0], ref) 38 | 39 | # score, scores = bleu_scorer.compute_score(option='shortest') 40 | score, scores = bleu_scorer.compute_score(option='closest', verbose=0) 41 | # score, scores = bleu_scorer.compute_score(option='average', verbose=1) 42 | 43 | return score, scores 44 | 45 | def __str__(self): 46 | return 'BLEU' 47 | -------------------------------------------------------------------------------- /meshed-memory-transformer-master/evaluation/bleu/bleu_scorer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # bleu_scorer.py 4 | # David Chiang 5 | 6 | # Copyright (c) 2004-2006 University of Maryland. All rights 7 | # reserved. Do not redistribute without permission from the 8 | # author. Not for commercial use. 9 | 10 | # Modified by: 11 | # Hao Fang 12 | # Tsung-Yi Lin 13 | 14 | ''' Provides: 15 | cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test(). 16 | cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked(). 
17 | ''' 18 | 19 | import copy 20 | import sys, math, re 21 | from collections import defaultdict 22 | 23 | 24 | def precook(s, n=4, out=False): 25 | """Takes a string as input and returns an object that can be given to 26 | either cook_refs or cook_test. This is optional: cook_refs and cook_test 27 | can take string arguments as well.""" 28 | words = s.split() 29 | counts = defaultdict(int) 30 | for k in range(1, n + 1): 31 | for i in range(len(words) - k + 1): 32 | ngram = tuple(words[i:i + k]) 33 | counts[ngram] += 1 34 | return (len(words), counts) 35 | 36 | 37 | def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "average" 38 | '''Takes a list of reference sentences for a single segment 39 | and returns an object that encapsulates everything that BLEU 40 | needs to know about them.''' 41 | 42 | reflen = [] 43 | maxcounts = {} 44 | for ref in refs: 45 | rl, counts = precook(ref, n) 46 | reflen.append(rl) 47 | for (ngram, count) in counts.items(): 48 | maxcounts[ngram] = max(maxcounts.get(ngram, 0), count) 49 | 50 | # Calculate effective reference sentence length. 51 | if eff == "shortest": 52 | reflen = min(reflen) 53 | elif eff == "average": 54 | reflen = float(sum(reflen)) / len(reflen) 55 | 56 | ## lhuang: N.B.: leave reflen computaiton to the very end!! 57 | 58 | ## lhuang: N.B.: in case of "closest", keep a list of reflens!! (bad design) 59 | 60 | return (reflen, maxcounts) 61 | 62 | 63 | def cook_test(test, ref_tuple, eff=None, n=4): 64 | '''Takes a test sentence and returns an object that 65 | encapsulates everything that BLEU needs to know about it.''' 66 | 67 | testlen, counts = precook(test, n, True) 68 | reflen, refmaxcounts = ref_tuple 69 | 70 | result = {} 71 | 72 | # Calculate effective reference sentence length. 73 | 74 | if eff == "closest": 75 | result["reflen"] = min((abs(l - testlen), l) for l in reflen)[1] 76 | else: ## i.e., "average" or "shortest" or None 77 | result["reflen"] = reflen 78 | 79 | result["testlen"] = testlen 80 | 81 | result["guess"] = [max(0, testlen - k + 1) for k in range(1, n + 1)] 82 | 83 | result['correct'] = [0] * n 84 | for (ngram, count) in counts.items(): 85 | result["correct"][len(ngram) - 1] += min(refmaxcounts.get(ngram, 0), count) 86 | 87 | return result 88 | 89 | 90 | class BleuScorer(object): 91 | """Bleu scorer. 92 | """ 93 | 94 | __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen" 95 | 96 | # special_reflen is used in oracle (proportional effective ref len for a node). 
97 | 98 | def copy(self): 99 | ''' copy the refs.''' 100 | new = BleuScorer(n=self.n) 101 | new.ctest = copy.copy(self.ctest) 102 | new.crefs = copy.copy(self.crefs) 103 | new._score = None 104 | return new 105 | 106 | def __init__(self, test=None, refs=None, n=4, special_reflen=None): 107 | ''' singular instance ''' 108 | 109 | self.n = n 110 | self.crefs = [] 111 | self.ctest = [] 112 | self.cook_append(test, refs) 113 | self.special_reflen = special_reflen 114 | 115 | def cook_append(self, test, refs): 116 | '''called by constructor and __iadd__ to avoid creating new instances.''' 117 | 118 | if refs is not None: 119 | self.crefs.append(cook_refs(refs)) 120 | if test is not None: 121 | cooked_test = cook_test(test, self.crefs[-1]) 122 | self.ctest.append(cooked_test) ## N.B.: -1 123 | else: 124 | self.ctest.append(None) # lens of crefs and ctest have to match 125 | 126 | self._score = None ## need to recompute 127 | 128 | def ratio(self, option=None): 129 | self.compute_score(option=option) 130 | return self._ratio 131 | 132 | def score_ratio(self, option=None): 133 | ''' 134 | return (bleu, len_ratio) pair 135 | ''' 136 | 137 | return self.fscore(option=option), self.ratio(option=option) 138 | 139 | def score_ratio_str(self, option=None): 140 | return "%.4f (%.2f)" % self.score_ratio(option) 141 | 142 | def reflen(self, option=None): 143 | self.compute_score(option=option) 144 | return self._reflen 145 | 146 | def testlen(self, option=None): 147 | self.compute_score(option=option) 148 | return self._testlen 149 | 150 | def retest(self, new_test): 151 | if type(new_test) is str: 152 | new_test = [new_test] 153 | assert len(new_test) == len(self.crefs), new_test 154 | self.ctest = [] 155 | for t, rs in zip(new_test, self.crefs): 156 | self.ctest.append(cook_test(t, rs)) 157 | self._score = None 158 | 159 | return self 160 | 161 | def rescore(self, new_test): 162 | ''' replace test(s) with new test(s), and returns the new score.''' 163 | 164 | return self.retest(new_test).compute_score() 165 | 166 | def size(self): 167 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest)) 168 | return len(self.crefs) 169 | 170 | def __iadd__(self, other): 171 | '''add an instance (e.g., from another sentence).''' 172 | 173 | if type(other) is tuple: 174 | ## avoid creating new BleuScorer instances 175 | self.cook_append(other[0], other[1]) 176 | else: 177 | assert self.compatible(other), "incompatible BLEUs." 
178 | self.ctest.extend(other.ctest) 179 | self.crefs.extend(other.crefs) 180 | self._score = None ## need to recompute 181 | 182 | return self 183 | 184 | def compatible(self, other): 185 | return isinstance(other, BleuScorer) and self.n == other.n 186 | 187 | def single_reflen(self, option="average"): 188 | return self._single_reflen(self.crefs[0][0], option) 189 | 190 | def _single_reflen(self, reflens, option=None, testlen=None): 191 | 192 | if option == "shortest": 193 | reflen = min(reflens) 194 | elif option == "average": 195 | reflen = float(sum(reflens)) / len(reflens) 196 | elif option == "closest": 197 | reflen = min((abs(l - testlen), l) for l in reflens)[1] 198 | else: 199 | assert False, "unsupported reflen option %s" % option 200 | 201 | return reflen 202 | 203 | def recompute_score(self, option=None, verbose=0): 204 | self._score = None 205 | return self.compute_score(option, verbose) 206 | 207 | def compute_score(self, option=None, verbose=0): 208 | n = self.n 209 | small = 1e-9 210 | tiny = 1e-15 ## so that if guess is 0 still return 0 211 | bleu_list = [[] for _ in range(n)] 212 | 213 | if self._score is not None: 214 | return self._score 215 | 216 | if option is None: 217 | option = "average" if len(self.crefs) == 1 else "closest" 218 | 219 | self._testlen = 0 220 | self._reflen = 0 221 | totalcomps = {'testlen': 0, 'reflen': 0, 'guess': [0] * n, 'correct': [0] * n} 222 | 223 | # for each sentence 224 | for comps in self.ctest: 225 | testlen = comps['testlen'] 226 | self._testlen += testlen 227 | 228 | if self.special_reflen is None: ## need computation 229 | reflen = self._single_reflen(comps['reflen'], option, testlen) 230 | else: 231 | reflen = self.special_reflen 232 | 233 | self._reflen += reflen 234 | 235 | for key in ['guess', 'correct']: 236 | for k in range(n): 237 | totalcomps[key][k] += comps[key][k] 238 | 239 | # append per image bleu score 240 | bleu = 1. 241 | for k in range(n): 242 | bleu *= (float(comps['correct'][k]) + tiny) \ 243 | / (float(comps['guess'][k]) + small) 244 | bleu_list[k].append(bleu ** (1. / (k + 1))) 245 | ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division 246 | if ratio < 1: 247 | for k in range(n): 248 | bleu_list[k][-1] *= math.exp(1 - 1 / ratio) 249 | 250 | if verbose > 1: 251 | print(comps, reflen) 252 | 253 | totalcomps['reflen'] = self._reflen 254 | totalcomps['testlen'] = self._testlen 255 | 256 | bleus = [] 257 | bleu = 1. 258 | for k in range(n): 259 | bleu *= float(totalcomps['correct'][k] + tiny) \ 260 | / (totalcomps['guess'][k] + small) 261 | bleus.append(bleu ** (1. 
/ (k + 1))) 262 | ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division 263 | if ratio < 1: 264 | for k in range(n): 265 | bleus[k] *= math.exp(1 - 1 / ratio) 266 | 267 | if verbose > 0: 268 | print(totalcomps) 269 | print("ratio:", ratio) 270 | 271 | self._score = bleus 272 | return self._score, bleu_list 273 | -------------------------------------------------------------------------------- /meshed-memory-transformer-master/evaluation/cider/__init__.py: -------------------------------------------------------------------------------- 1 | from .cider import Cider -------------------------------------------------------------------------------- /meshed-memory-transformer-master/evaluation/cider/cider.py: -------------------------------------------------------------------------------- 1 | # Filename: cider.py 2 | # 3 | # Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric 4 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726) 5 | # 6 | # Creation Date: Sun Feb 8 14:16:54 2015 7 | # 8 | # Authors: Ramakrishna Vedantam and Tsung-Yi Lin 9 | 10 | from .cider_scorer import CiderScorer 11 | 12 | class Cider: 13 | """ 14 | Main Class to compute the CIDEr metric 15 | 16 | """ 17 | def __init__(self, gts=None, n=4, sigma=6.0): 18 | # set cider to sum over 1 to 4-grams 19 | self._n = n 20 | # set the standard deviation parameter for gaussian penalty 21 | self._sigma = sigma 22 | self.doc_frequency = None 23 | self.ref_len = None 24 | if gts is not None: 25 | tmp_cider = CiderScorer(gts, n=self._n, sigma=self._sigma) 26 | self.doc_frequency = tmp_cider.doc_frequency 27 | self.ref_len = tmp_cider.ref_len 28 | 29 | def compute_score(self, gts, res): 30 | """ 31 | Main function to compute CIDEr score 32 | :param gts (dict) : dictionary with key and value 33 | res (dict) : dictionary with key and value 34 | :return: cider (float) : computed CIDEr score for the corpus 35 | """ 36 | assert(gts.keys() == res.keys()) 37 | cider_scorer = CiderScorer(gts, test=res, n=self._n, sigma=self._sigma, doc_frequency=self.doc_frequency, 38 | ref_len=self.ref_len) 39 | return cider_scorer.compute_score() 40 | 41 | def __str__(self): 42 | return 'CIDEr' 43 | -------------------------------------------------------------------------------- /meshed-memory-transformer-master/evaluation/cider/cider_scorer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Tsung-Yi Lin 3 | # Ramakrishna Vedantam 4 | 5 | import copy 6 | from collections import defaultdict 7 | import numpy as np 8 | import math 9 | 10 | def precook(s, n=4): 11 | """ 12 | Takes a string as input and returns an object that can be given to 13 | either cook_refs or cook_test. This is optional: cook_refs and cook_test 14 | can take string arguments as well. 15 | :param s: string : sentence to be converted into ngrams 16 | :param n: int : number of ngrams for which representation is calculated 17 | :return: term frequency vector for occuring ngrams 18 | """ 19 | words = s.split() 20 | counts = defaultdict(int) 21 | for k in range(1,n+1): 22 | for i in range(len(words)-k+1): 23 | ngram = tuple(words[i:i+k]) 24 | counts[ngram] += 1 25 | return counts 26 | 27 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average" 28 | '''Takes a list of reference sentences for a single segment 29 | and returns an object that encapsulates everything that BLEU 30 | needs to know about them. 
31 | :param refs: list of string : reference sentences for some image 32 | :param n: int : number of ngrams for which (ngram) representation is calculated 33 | :return: result (list of dict) 34 | ''' 35 | return [precook(ref, n) for ref in refs] 36 | 37 | def cook_test(test, n=4): 38 | '''Takes a test sentence and returns an object that 39 | encapsulates everything that BLEU needs to know about it. 40 | :param test: list of string : hypothesis sentence for some image 41 | :param n: int : number of ngrams for which (ngram) representation is calculated 42 | :return: result (dict) 43 | ''' 44 | return precook(test, n) 45 | 46 | class CiderScorer(object): 47 | """CIDEr scorer. 48 | """ 49 | 50 | def __init__(self, refs, test=None, n=4, sigma=6.0, doc_frequency=None, ref_len=None): 51 | ''' singular instance ''' 52 | self.n = n 53 | self.sigma = sigma 54 | self.crefs = [] 55 | self.ctest = [] 56 | self.doc_frequency = defaultdict(float) 57 | self.ref_len = None 58 | 59 | for k in refs.keys(): 60 | self.crefs.append(cook_refs(refs[k])) 61 | if test is not None: 62 | self.ctest.append(cook_test(test[k][0])) ## N.B.: -1 63 | else: 64 | self.ctest.append(None) # lens of crefs and ctest have to match 65 | 66 | if doc_frequency is None and ref_len is None: 67 | # compute idf 68 | self.compute_doc_freq() 69 | # compute log reference length 70 | self.ref_len = np.log(float(len(self.crefs))) 71 | else: 72 | self.doc_frequency = doc_frequency 73 | self.ref_len = ref_len 74 | 75 | def compute_doc_freq(self): 76 | ''' 77 | Compute term frequency for reference data. 78 | This will be used to compute idf (inverse document frequency later) 79 | The term frequency is stored in the object 80 | :return: None 81 | ''' 82 | for refs in self.crefs: 83 | # refs, k ref captions of one image 84 | for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]): 85 | self.doc_frequency[ngram] += 1 86 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 87 | 88 | def compute_cider(self): 89 | def counts2vec(cnts): 90 | """ 91 | Function maps counts of ngram to vector of tfidf weights. 92 | The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights. 93 | The n-th entry of array denotes length of n-grams. 94 | :param cnts: 95 | :return: vec (array of dict), norm (array of float), length (int) 96 | """ 97 | vec = [defaultdict(float) for _ in range(self.n)] 98 | length = 0 99 | norm = [0.0 for _ in range(self.n)] 100 | for (ngram,term_freq) in cnts.items(): 101 | # give word count 1 if it doesn't appear in reference corpus 102 | df = np.log(max(1.0, self.doc_frequency[ngram])) 103 | # ngram index 104 | n = len(ngram)-1 105 | # tf (term_freq) * idf (precomputed idf) for n-grams 106 | vec[n][ngram] = float(term_freq)*(self.ref_len - df) 107 | # compute norm for the vector. the norm will be used for computing similarity 108 | norm[n] += pow(vec[n][ngram], 2) 109 | 110 | if n == 1: 111 | length += term_freq 112 | norm = [np.sqrt(n) for n in norm] 113 | return vec, norm, length 114 | 115 | def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref): 116 | ''' 117 | Compute the cosine similarity of two vectors. 
118 | :param vec_hyp: array of dictionary for vector corresponding to hypothesis 119 | :param vec_ref: array of dictionary for vector corresponding to reference 120 | :param norm_hyp: array of float for vector corresponding to hypothesis 121 | :param norm_ref: array of float for vector corresponding to reference 122 | :param length_hyp: int containing length of hypothesis 123 | :param length_ref: int containing length of reference 124 | :return: array of score for each n-grams cosine similarity 125 | ''' 126 | delta = float(length_hyp - length_ref) 127 | # measure consine similarity 128 | val = np.array([0.0 for _ in range(self.n)]) 129 | for n in range(self.n): 130 | # ngram 131 | for (ngram,count) in vec_hyp[n].items(): 132 | # vrama91 : added clipping 133 | val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram] 134 | 135 | if (norm_hyp[n] != 0) and (norm_ref[n] != 0): 136 | val[n] /= (norm_hyp[n]*norm_ref[n]) 137 | 138 | assert(not math.isnan(val[n])) 139 | # vrama91: added a length based gaussian penalty 140 | val[n] *= np.e**(-(delta**2)/(2*self.sigma**2)) 141 | return val 142 | 143 | scores = [] 144 | for test, refs in zip(self.ctest, self.crefs): 145 | # compute vector for test captions 146 | vec, norm, length = counts2vec(test) 147 | # compute vector for ref captions 148 | score = np.array([0.0 for _ in range(self.n)]) 149 | for ref in refs: 150 | vec_ref, norm_ref, length_ref = counts2vec(ref) 151 | score += sim(vec, vec_ref, norm, norm_ref, length, length_ref) 152 | # change by vrama91 - mean of ngram scores, instead of sum 153 | score_avg = np.mean(score) 154 | # divide by number of references 155 | score_avg /= len(refs) 156 | # multiply score by 10 157 | score_avg *= 10.0 158 | # append score of an image to the score list 159 | scores.append(score_avg) 160 | return scores 161 | 162 | def compute_score(self): 163 | # compute cider score 164 | score = self.compute_cider() 165 | # debug 166 | # print score 167 | return np.mean(np.array(score)), np.array(score) -------------------------------------------------------------------------------- /meshed-memory-transformer-master/evaluation/meteor/__init__.py: -------------------------------------------------------------------------------- 1 | from .meteor import Meteor -------------------------------------------------------------------------------- /meshed-memory-transformer-master/evaluation/meteor/meteor.py: -------------------------------------------------------------------------------- 1 | # Python wrapper for METEOR implementation, by Xinlei Chen 2 | # Acknowledge Michael Denkowski for the generous discussion and help 3 | 4 | import os 5 | import subprocess 6 | import threading 7 | import tarfile 8 | from utils import download_from_url 9 | 10 | METEOR_GZ_URL = 'http://aimagelab.ing.unimore.it/speaksee/data/meteor.tgz' 11 | METEOR_JAR = 'meteor-1.5.jar' 12 | 13 | class Meteor: 14 | def __init__(self): 15 | base_path = os.path.dirname(os.path.abspath(__file__)) 16 | jar_path = os.path.join(base_path, METEOR_JAR) 17 | gz_path = os.path.join(base_path, os.path.basename(METEOR_GZ_URL)) 18 | if not os.path.isfile(jar_path): 19 | if not os.path.isfile(gz_path): 20 | download_from_url(METEOR_GZ_URL, gz_path) 21 | tar = tarfile.open(gz_path, "r") 22 | tar.extractall(path=os.path.dirname(os.path.abspath(__file__))) 23 | tar.close() 24 | os.remove(gz_path) 25 | 26 | self.meteor_cmd = ['java', '-jar', '-Xmx2G', METEOR_JAR, \ 27 | '-', '-', '-stdio', '-l', 'en', '-norm'] 28 | self.meteor_p = subprocess.Popen(self.meteor_cmd, \ 
29 | cwd=os.path.dirname(os.path.abspath(__file__)), \ 30 | stdin=subprocess.PIPE, \ 31 | stdout=subprocess.PIPE, \ 32 | stderr=subprocess.PIPE) 33 | # Used to guarantee thread safety 34 | self.lock = threading.Lock() 35 | 36 | def compute_score(self, gts, res): 37 | assert(gts.keys() == res.keys()) 38 | imgIds = gts.keys() 39 | scores = [] 40 | 41 | eval_line = 'EVAL' 42 | self.lock.acquire() 43 | for i in imgIds: 44 | assert(len(res[i]) == 1) 45 | stat = self._stat(res[i][0], gts[i]) 46 | eval_line += ' ||| {}'.format(stat) 47 | 48 | self.meteor_p.stdin.write('{}\n'.format(eval_line).encode()) 49 | self.meteor_p.stdin.flush() 50 | for i in range(0,len(imgIds)): 51 | scores.append(float(self.meteor_p.stdout.readline().strip())) 52 | score = float(self.meteor_p.stdout.readline().strip()) 53 | self.lock.release() 54 | 55 | return score, scores 56 | 57 | def _stat(self, hypothesis_str, reference_list): 58 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words 59 | hypothesis_str = hypothesis_str.replace('|||','').replace(' ',' ') 60 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str)) 61 | self.meteor_p.stdin.write('{}\n'.format(score_line).encode()) 62 | self.meteor_p.stdin.flush() 63 | raw = self.meteor_p.stdout.readline().decode().strip() 64 | numbers = [str(int(float(n))) for n in raw.split()] 65 | return ' '.join(numbers) 66 | 67 | def __del__(self): 68 | self.lock.acquire() 69 | self.meteor_p.stdin.close() 70 | self.meteor_p.kill() 71 | self.meteor_p.wait() 72 | self.lock.release() 73 | 74 | def __str__(self): 75 | return 'METEOR' 76 | -------------------------------------------------------------------------------- /meshed-memory-transformer-master/evaluation/rouge/__init__.py: -------------------------------------------------------------------------------- 1 | from .rouge import Rouge -------------------------------------------------------------------------------- /meshed-memory-transformer-master/evaluation/rouge/rouge.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : rouge.py 4 | # 5 | # Description : Computes ROUGE-L metric as described by Lin and Hovey (2004) 6 | # 7 | # Creation Date : 2015-01-07 06:03 8 | # Author : Ramakrishna Vedantam 9 | 10 | import numpy as np 11 | import pdb 12 | 13 | 14 | def my_lcs(string, sub): 15 | """ 16 | Calculates longest common subsequence for a pair of tokenized strings 17 | :param string : list of str : tokens from a string split using whitespace 18 | :param sub : list of str : shorter string, also split using whitespace 19 | :returns: length (list of int): length of the longest common subsequence between the two strings 20 | 21 | Note: my_lcs only gives length of the longest common subsequence, not the actual LCS 22 | """ 23 | if (len(string) < len(sub)): 24 | sub, string = string, sub 25 | 26 | lengths = [[0 for i in range(0, len(sub) + 1)] for j in range(0, len(string) + 1)] 27 | 28 | for j in range(1, len(sub) + 1): 29 | for i in range(1, len(string) + 1): 30 | if (string[i - 1] == sub[j - 1]): 31 | lengths[i][j] = lengths[i - 1][j - 1] + 1 32 | else: 33 | lengths[i][j] = max(lengths[i - 1][j], lengths[i][j - 1]) 34 | 35 | return lengths[len(string)][len(sub)] 36 | 37 | 38 | class Rouge(): 39 | ''' 40 | Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set 41 | 42 | ''' 43 | 44 | def __init__(self): 45 | # vrama91: updated the value below based on discussion with Hovey 
46 | self.beta = 1.2 47 | 48 | def calc_score(self, candidate, refs): 49 | """ 50 | Compute ROUGE-L score given one candidate and references for an image 51 | :param candidate: str : candidate sentence to be evaluated 52 | :param refs: list of str : COCO reference sentences for the particular image to be evaluated 53 | :returns score: int (ROUGE-L score for the candidate evaluated against references) 54 | """ 55 | assert (len(candidate) == 1) 56 | assert (len(refs) > 0) 57 | prec = [] 58 | rec = [] 59 | 60 | # split into tokens 61 | token_c = candidate[0].split(" ") 62 | 63 | for reference in refs: 64 | # split into tokens 65 | token_r = reference.split(" ") 66 | # compute the longest common subsequence 67 | lcs = my_lcs(token_r, token_c) 68 | prec.append(lcs / float(len(token_c))) 69 | rec.append(lcs / float(len(token_r))) 70 | 71 | prec_max = max(prec) 72 | rec_max = max(rec) 73 | 74 | if (prec_max != 0 and rec_max != 0): 75 | score = ((1 + self.beta ** 2) * prec_max * rec_max) / float(rec_max + self.beta ** 2 * prec_max) 76 | else: 77 | score = 0.0 78 | return score 79 | 80 | def compute_score(self, gts, res): 81 | """ 82 | Computes Rouge-L score given a set of reference and candidate sentences for the dataset 83 | Invoked by evaluate_captions.py 84 | :param hypo_for_image: dict : candidate / test sentences with "image name" key and "tokenized sentences" as values 85 | :param ref_for_image: dict : reference MS-COCO sentences with "image name" key and "tokenized sentences" as values 86 | :returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images) 87 | """ 88 | assert (gts.keys() == res.keys()) 89 | imgIds = gts.keys() 90 | 91 | score = [] 92 | for id in imgIds: 93 | hypo = res[id] 94 | ref = gts[id] 95 | 96 | score.append(self.calc_score(hypo, ref)) 97 | 98 | # Sanity check. 99 | assert (type(hypo) is list) 100 | assert (len(hypo) == 1) 101 | assert (type(ref) is list) 102 | assert (len(ref) > 0) 103 | 104 | average_score = np.mean(np.array(score)) 105 | return average_score, np.array(score) 106 | 107 | def __str__(self): 108 | return 'ROUGE' 109 | -------------------------------------------------------------------------------- /meshed-memory-transformer-master/evaluation/tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : ptbtokenizer.py 4 | # 5 | # Description : Do the PTB Tokenization and remove punctuations. 
6 | # 7 | # Creation Date : 29-12-2014 8 | # Last Modified : Thu Mar 19 09:53:35 2015 9 | # Authors : Hao Fang and Tsung-Yi Lin 10 | 11 | import os 12 | import subprocess 13 | import tempfile 14 | 15 | class PTBTokenizer(object): 16 | """Python wrapper of Stanford PTBTokenizer""" 17 | 18 | corenlp_jar = 'stanford-corenlp-3.4.1.jar' 19 | punctuations = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", \ 20 | ".", "?", "!", ",", ":", "-", "--", "...", ";"] 21 | 22 | @classmethod 23 | def tokenize(cls, corpus): 24 | cmd = ['java', '-cp', cls.corenlp_jar, \ 25 | 'edu.stanford.nlp.process.PTBTokenizer', \ 26 | '-preserveLines', '-lowerCase'] 27 | 28 | if isinstance(corpus, list) or isinstance(corpus, tuple): 29 | if isinstance(corpus[0], list) or isinstance(corpus[0], tuple): 30 | corpus = {i:c for i, c in enumerate(corpus)} 31 | else: 32 | corpus = {i: [c, ] for i, c in enumerate(corpus)} 33 | 34 | # prepare data for PTB Tokenizer 35 | tokenized_corpus = {} 36 | image_id = [k for k, v in list(corpus.items()) for _ in range(len(v))] 37 | sentences = '\n'.join([c.replace('\n', ' ') for k, v in corpus.items() for c in v]) 38 | 39 | # save sentences to temporary file 40 | path_to_jar_dirname=os.path.dirname(os.path.abspath(__file__)) 41 | tmp_file = tempfile.NamedTemporaryFile(delete=False, dir=path_to_jar_dirname) 42 | tmp_file.write(sentences.encode()) 43 | tmp_file.close() 44 | 45 | # tokenize sentence 46 | cmd.append(os.path.basename(tmp_file.name)) 47 | p_tokenizer = subprocess.Popen(cmd, cwd=path_to_jar_dirname, \ 48 | stdout=subprocess.PIPE, stderr=open(os.devnull, 'w')) 49 | token_lines = p_tokenizer.communicate(input=sentences.rstrip())[0] 50 | token_lines = token_lines.decode() 51 | lines = token_lines.split('\n') 52 | # remove temp file 53 | os.remove(tmp_file.name) 54 | 55 | # create dictionary for tokenized captions 56 | for k, line in zip(image_id, lines): 57 | if not k in tokenized_corpus: 58 | tokenized_corpus[k] = [] 59 | tokenized_caption = ' '.join([w for w in line.rstrip().split(' ') \ 60 | if w not in cls.punctuations]) 61 | tokenized_corpus[k].append(tokenized_caption) 62 | 63 | return tokenized_corpus -------------------------------------------------------------------------------- /meshed-memory-transformer-master/images/m2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weixmath/CVR/af2ea6859230c1a5f56868a2242c06c2fe02f05a/meshed-memory-transformer-master/images/m2.png -------------------------------------------------------------------------------- /meshed-memory-transformer-master/images/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weixmath/CVR/af2ea6859230c1a5f56868a2242c06c2fe02f05a/meshed-memory-transformer-master/images/results.png -------------------------------------------------------------------------------- /meshed-memory-transformer-master/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .transformer import Transformer 2 | from .captioning_model import CaptioningModel 3 | -------------------------------------------------------------------------------- /meshed-memory-transformer-master/models/beam_search/__init__.py: -------------------------------------------------------------------------------- 1 | from .beam_search import BeamSearch 2 | -------------------------------------------------------------------------------- 
/meshed-memory-transformer-master/models/beam_search/beam_search.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import utils 3 | 4 | 5 | class BeamSearch(object): 6 | def __init__(self, model, max_len: int, eos_idx: int, beam_size: int): 7 | self.model = model 8 | self.max_len = max_len 9 | self.eos_idx = eos_idx 10 | self.beam_size = beam_size 11 | self.b_s = None 12 | self.device = None 13 | self.seq_mask = None 14 | self.seq_logprob = None 15 | self.outputs = None 16 | self.log_probs = None 17 | self.selected_words = None 18 | self.all_log_probs = None 19 | 20 | def _expand_state(self, selected_beam, cur_beam_size): 21 | def fn(s): 22 | shape = [int(sh) for sh in s.shape] 23 | beam = selected_beam 24 | for _ in shape[1:]: 25 | beam = beam.unsqueeze(-1) 26 | s = torch.gather(s.view(*([self.b_s, cur_beam_size] + shape[1:])), 1, 27 | beam.expand(*([self.b_s, self.beam_size] + shape[1:]))) 28 | s = s.view(*([-1, ] + shape[1:])) 29 | return s 30 | 31 | return fn 32 | 33 | def _expand_visual(self, visual: utils.TensorOrSequence, cur_beam_size: int, selected_beam: torch.Tensor): 34 | if isinstance(visual, torch.Tensor): 35 | visual_shape = visual.shape 36 | visual_exp_shape = (self.b_s, cur_beam_size) + visual_shape[1:] 37 | visual_red_shape = (self.b_s * self.beam_size,) + visual_shape[1:] 38 | selected_beam_red_size = (self.b_s, self.beam_size) + tuple(1 for _ in range(len(visual_exp_shape) - 2)) 39 | selected_beam_exp_size = (self.b_s, self.beam_size) + visual_exp_shape[2:] 40 | visual_exp = visual.view(visual_exp_shape) 41 | selected_beam_exp = selected_beam.view(selected_beam_red_size).expand(selected_beam_exp_size) 42 | visual = torch.gather(visual_exp, 1, selected_beam_exp).view(visual_red_shape) 43 | else: 44 | new_visual = [] 45 | for im in visual: 46 | visual_shape = im.shape 47 | visual_exp_shape = (self.b_s, cur_beam_size) + visual_shape[1:] 48 | visual_red_shape = (self.b_s * self.beam_size,) + visual_shape[1:] 49 | selected_beam_red_size = (self.b_s, self.beam_size) + tuple(1 for _ in range(len(visual_exp_shape) - 2)) 50 | selected_beam_exp_size = (self.b_s, self.beam_size) + visual_exp_shape[2:] 51 | visual_exp = im.view(visual_exp_shape) 52 | selected_beam_exp = selected_beam.view(selected_beam_red_size).expand(selected_beam_exp_size) 53 | new_im = torch.gather(visual_exp, 1, selected_beam_exp).view(visual_red_shape) 54 | new_visual.append(new_im) 55 | visual = tuple(new_visual) 56 | return visual 57 | 58 | def apply(self, visual: utils.TensorOrSequence, out_size=1, return_probs=False, **kwargs): 59 | self.b_s = utils.get_batch_size(visual) 60 | self.device = utils.get_device(visual) 61 | self.seq_mask = torch.ones((self.b_s, self.beam_size, 1), device=self.device) 62 | self.seq_logprob = torch.zeros((self.b_s, 1, 1), device=self.device) 63 | self.log_probs = [] 64 | self.selected_words = None 65 | if return_probs: 66 | self.all_log_probs = [] 67 | 68 | outputs = [] 69 | with self.model.statefulness(self.b_s): 70 | for t in range(self.max_len): 71 | visual, outputs = self.iter(t, visual, outputs, return_probs, **kwargs) 72 | 73 | # Sort result 74 | seq_logprob, sort_idxs = torch.sort(self.seq_logprob, 1, descending=True) 75 | outputs = torch.cat(outputs, -1) 76 | outputs = torch.gather(outputs, 1, sort_idxs.expand(self.b_s, self.beam_size, self.max_len)) 77 | log_probs = torch.cat(self.log_probs, -1) 78 | log_probs = torch.gather(log_probs, 1, sort_idxs.expand(self.b_s, self.beam_size, self.max_len)) 79 | if 
return_probs: 80 | all_log_probs = torch.cat(self.all_log_probs, 2) 81 | all_log_probs = torch.gather(all_log_probs, 1, sort_idxs.unsqueeze(-1).expand(self.b_s, self.beam_size, 82 | self.max_len, 83 | all_log_probs.shape[-1])) 84 | 85 | outputs = outputs.contiguous()[:, :out_size] 86 | log_probs = log_probs.contiguous()[:, :out_size] 87 | if out_size == 1: 88 | outputs = outputs.squeeze(1) 89 | log_probs = log_probs.squeeze(1) 90 | 91 | if return_probs: 92 | return outputs, log_probs, all_log_probs 93 | else: 94 | return outputs, log_probs 95 | 96 | def select(self, t, candidate_logprob, **kwargs): 97 | selected_logprob, selected_idx = torch.sort(candidate_logprob.view(self.b_s, -1), -1, descending=True) 98 | selected_logprob, selected_idx = selected_logprob[:, :self.beam_size], selected_idx[:, :self.beam_size] 99 | return selected_idx, selected_logprob 100 | 101 | def iter(self, t: int, visual: utils.TensorOrSequence, outputs, return_probs, **kwargs): 102 | cur_beam_size = 1 if t == 0 else self.beam_size 103 | 104 | word_logprob = self.model.step(t, self.selected_words, visual, None, mode='feedback', **kwargs) 105 | word_logprob = word_logprob.view(self.b_s, cur_beam_size, -1) 106 | candidate_logprob = self.seq_logprob + word_logprob 107 | 108 | # Mask sequence if it reaches EOS 109 | if t > 0: 110 | mask = (self.selected_words.view(self.b_s, cur_beam_size) != self.eos_idx).float().unsqueeze(-1) 111 | self.seq_mask = self.seq_mask * mask 112 | word_logprob = word_logprob * self.seq_mask.expand_as(word_logprob) 113 | old_seq_logprob = self.seq_logprob.expand_as(candidate_logprob).contiguous() 114 | old_seq_logprob[:, :, 1:] = -999 115 | candidate_logprob = self.seq_mask * candidate_logprob + old_seq_logprob * (1 - self.seq_mask) 116 | 117 | selected_idx, selected_logprob = self.select(t, candidate_logprob, **kwargs) 118 | selected_beam = selected_idx / candidate_logprob.shape[-1] 119 | selected_words = selected_idx - selected_beam * candidate_logprob.shape[-1] 120 | 121 | self.model.apply_to_states(self._expand_state(selected_beam, cur_beam_size)) 122 | visual = self._expand_visual(visual, cur_beam_size, selected_beam) 123 | 124 | self.seq_logprob = selected_logprob.unsqueeze(-1) 125 | self.seq_mask = torch.gather(self.seq_mask, 1, selected_beam.unsqueeze(-1)) 126 | outputs = list(torch.gather(o, 1, selected_beam.unsqueeze(-1)) for o in outputs) 127 | outputs.append(selected_words.unsqueeze(-1)) 128 | 129 | if return_probs: 130 | if t == 0: 131 | self.all_log_probs.append(word_logprob.expand((self.b_s, self.beam_size, -1)).unsqueeze(2)) 132 | else: 133 | self.all_log_probs.append(word_logprob.unsqueeze(2)) 134 | 135 | this_word_logprob = torch.gather(word_logprob, 1, 136 | selected_beam.unsqueeze(-1).expand(self.b_s, self.beam_size, 137 | word_logprob.shape[-1])) 138 | this_word_logprob = torch.gather(this_word_logprob, 2, selected_words.unsqueeze(-1)) 139 | self.log_probs = list( 140 | torch.gather(o, 1, selected_beam.unsqueeze(-1).expand(self.b_s, self.beam_size, 1)) for o in self.log_probs) 141 | self.log_probs.append(this_word_logprob) 142 | self.selected_words = selected_words.view(-1, 1) 143 | 144 | return visual, outputs 145 | -------------------------------------------------------------------------------- /meshed-memory-transformer-master/models/captioning_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import distributions 3 | import utils 4 | from models.containers import Module 5 | from models.beam_search 
import * 6 | 7 | 8 | class CaptioningModel(Module): 9 | def __init__(self): 10 | super(CaptioningModel, self).__init__() 11 | 12 | def init_weights(self): 13 | raise NotImplementedError 14 | 15 | def step(self, t, prev_output, visual, seq, mode='teacher_forcing', **kwargs): 16 | raise NotImplementedError 17 | 18 | def forward(self, images, seq, *args): 19 | device = images.device 20 | b_s = images.size(0) 21 | seq_len = seq.size(1) 22 | state = self.init_state(b_s, device) 23 | out = None 24 | 25 | outputs = [] 26 | for t in range(seq_len): 27 | out, state = self.step(t, state, out, images, seq, *args, mode='teacher_forcing') 28 | outputs.append(out) 29 | 30 | outputs = torch.cat([o.unsqueeze(1) for o in outputs], 1) 31 | return outputs 32 | 33 | def test(self, visual: utils.TensorOrSequence, max_len: int, eos_idx: int, **kwargs) -> utils.Tuple[torch.Tensor, torch.Tensor]: 34 | b_s = utils.get_batch_size(visual) 35 | device = utils.get_device(visual) 36 | outputs = [] 37 | log_probs = [] 38 | 39 | mask = torch.ones((b_s,), device=device) 40 | with self.statefulness(b_s): 41 | out = None 42 | for t in range(max_len): 43 | log_probs_t = self.step(t, out, visual, None, mode='feedback', **kwargs) 44 | out = torch.max(log_probs_t, -1)[1] 45 | mask = mask * (out.squeeze(-1) != eos_idx).float() 46 | log_probs.append(log_probs_t * mask.unsqueeze(-1).unsqueeze(-1)) 47 | outputs.append(out) 48 | 49 | return torch.cat(outputs, 1), torch.cat(log_probs, 1) 50 | 51 | def sample_rl(self, visual: utils.TensorOrSequence, max_len: int, **kwargs) -> utils.Tuple[torch.Tensor, torch.Tensor]: 52 | b_s = utils.get_batch_size(visual) 53 | outputs = [] 54 | log_probs = [] 55 | 56 | with self.statefulness(b_s): 57 | out = None 58 | for t in range(max_len): 59 | out = self.step(t, out, visual, None, mode='feedback', **kwargs) 60 | distr = distributions.Categorical(logits=out[:, 0]) 61 | out = distr.sample().unsqueeze(1) 62 | outputs.append(out) 63 | log_probs.append(distr.log_prob(out).unsqueeze(1)) 64 | 65 | return torch.cat(outputs, 1), torch.cat(log_probs, 1) 66 | 67 | def beam_search(self, visual: utils.TensorOrSequence, max_len: int, eos_idx: int, beam_size: int, out_size=1, 68 | return_probs=False, **kwargs): 69 | bs = BeamSearch(self, max_len, eos_idx, beam_size) 70 | return bs.apply(visual, out_size, return_probs, **kwargs) 71 | -------------------------------------------------------------------------------- /meshed-memory-transformer-master/models/containers.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | from torch import nn 3 | from utils.typing import * 4 | 5 | 6 | class Module(nn.Module): 7 | def __init__(self): 8 | super(Module, self).__init__() 9 | self._is_stateful = False 10 | self._state_names = [] 11 | self._state_defaults = dict() 12 | 13 | def register_state(self, name: str, default: TensorOrNone): 14 | self._state_names.append(name) 15 | if default is None: 16 | self._state_defaults[name] = None 17 | else: 18 | self._state_defaults[name] = default.clone().detach() 19 | self.register_buffer(name, default) 20 | 21 | def states(self): 22 | for name in self._state_names: 23 | yield self._buffers[name] 24 | for m in self.children(): 25 | if isinstance(m, Module): 26 | yield from m.states() 27 | 28 | def apply_to_states(self, fn): 29 | for name in self._state_names: 30 | self._buffers[name] = fn(self._buffers[name]) 31 | for m in self.children(): 32 | if isinstance(m, Module): 33 | m.apply_to_states(fn) 34 | 35 | def 
_init_states(self, batch_size: int): 36 | for name in self._state_names: 37 | if self._state_defaults[name] is None: 38 | self._buffers[name] = None 39 | else: 40 | self._buffers[name] = self._state_defaults[name].clone().detach().to(self._buffers[name].device) 41 | self._buffers[name] = self._buffers[name].unsqueeze(0) 42 | self._buffers[name] = self._buffers[name].expand([batch_size, ] + list(self._buffers[name].shape[1:])) 43 | self._buffers[name] = self._buffers[name].contiguous() 44 | 45 | def _reset_states(self): 46 | for name in self._state_names: 47 | if self._state_defaults[name] is None: 48 | self._buffers[name] = None 49 | else: 50 | self._buffers[name] = self._state_defaults[name].clone().detach().to(self._buffers[name].device) 51 | 52 | def enable_statefulness(self, batch_size: int): 53 | for m in self.children(): 54 | if isinstance(m, Module): 55 | m.enable_statefulness(batch_size) 56 | self._init_states(batch_size) 57 | self._is_stateful = True 58 | 59 | def disable_statefulness(self): 60 | for m in self.children(): 61 | if isinstance(m, Module): 62 | m.disable_statefulness() 63 | self._reset_states() 64 | self._is_stateful = False 65 | 66 | @contextmanager 67 | def statefulness(self, batch_size: int): 68 | self.enable_statefulness(batch_size) 69 | try: 70 | yield 71 | finally: 72 | self.disable_statefulness() 73 | 74 | 75 | class ModuleList(nn.ModuleList, Module): 76 | pass 77 | 78 | 79 | class ModuleDict(nn.ModuleDict, Module): 80 | pass 81 | -------------------------------------------------------------------------------- /meshed-memory-transformer-master/models/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .transformer import * 2 | from .encoders import * 3 | from .decoders import * 4 | from .attention import * 5 | -------------------------------------------------------------------------------- /meshed-memory-transformer-master/models/transformer/attention.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from models.containers import Module 5 | 6 | 7 | class ScaledDotProductAttention(nn.Module): 8 | ''' 9 | Scaled dot-product attention 10 | ''' 11 | 12 | def __init__(self, d_model, d_k, d_v, h): 13 | ''' 14 | :param d_model: Output dimensionality of the model 15 | :param d_k: Dimensionality of queries and keys 16 | :param d_v: Dimensionality of values 17 | :param h: Number of heads 18 | ''' 19 | super(ScaledDotProductAttention, self).__init__() 20 | self.fc_q = nn.Linear(d_model, h * d_k) 21 | self.fc_k = nn.Linear(d_model, h * d_k) 22 | self.fc_v = nn.Linear(d_model, h * d_v) 23 | self.fc_o = nn.Linear(h * d_v, d_model) 24 | 25 | self.d_model = d_model 26 | self.d_k = d_k 27 | self.d_v = d_v 28 | self.h = h 29 | 30 | self.init_weights() 31 | 32 | def init_weights(self): 33 | nn.init.xavier_uniform_(self.fc_q.weight) 34 | nn.init.xavier_uniform_(self.fc_k.weight) 35 | nn.init.xavier_uniform_(self.fc_v.weight) 36 | nn.init.xavier_uniform_(self.fc_o.weight) 37 | nn.init.constant_(self.fc_q.bias, 0) 38 | nn.init.constant_(self.fc_k.bias, 0) 39 | nn.init.constant_(self.fc_v.bias, 0) 40 | nn.init.constant_(self.fc_o.bias, 0) 41 | 42 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None): 43 | ''' 44 | Computes 45 | :param queries: Queries (b_s, nq, d_model) 46 | :param keys: Keys (b_s, nk, d_model) 47 | :param values: Values (b_s, nk, d_model) 48 | :param attention_mask: 
Mask over attention values (b_s, h, nq, nk). True indicates masking. 49 | :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk). 50 | :return: 51 | ''' 52 | b_s, nq = queries.shape[:2] 53 | nk = keys.shape[1] 54 | q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3) # (b_s, h, nq, d_k) 55 | k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1) # (b_s, h, d_k, nk) 56 | v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3) # (b_s, h, nk, d_v) 57 | 58 | att = torch.matmul(q, k) / np.sqrt(self.d_k) # (b_s, h, nq, nk) 59 | if attention_weights is not None: 60 | att = att * attention_weights 61 | if attention_mask is not None: 62 | att = att.masked_fill(attention_mask, -np.inf) 63 | att = torch.softmax(att, -1) 64 | out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v) # (b_s, nq, h*d_v) 65 | out = self.fc_o(out) # (b_s, nq, d_model) 66 | return out 67 | 68 | 69 | class ScaledDotProductAttentionMemory(nn.Module): 70 | ''' 71 | Scaled dot-product attention with memory 72 | ''' 73 | 74 | def __init__(self, d_model, d_k, d_v, h, m): 75 | ''' 76 | :param d_model: Output dimensionality of the model 77 | :param d_k: Dimensionality of queries and keys 78 | :param d_v: Dimensionality of values 79 | :param h: Number of heads 80 | :param m: Number of memory slots 81 | ''' 82 | super(ScaledDotProductAttentionMemory, self).__init__() 83 | self.fc_q = nn.Linear(d_model, h * d_k) 84 | self.fc_k = nn.Linear(d_model, h * d_k) 85 | self.fc_v = nn.Linear(d_model, h * d_v) 86 | self.fc_o = nn.Linear(h * d_v, d_model) 87 | self.m_k = nn.Parameter(torch.FloatTensor(1, m, h * d_k)) 88 | self.m_v = nn.Parameter(torch.FloatTensor(1, m, h * d_v)) 89 | 90 | self.d_model = d_model 91 | self.d_k = d_k 92 | self.d_v = d_v 93 | self.h = h 94 | self.m = m 95 | 96 | self.init_weights() 97 | 98 | def init_weights(self): 99 | nn.init.xavier_uniform_(self.fc_q.weight) 100 | nn.init.xavier_uniform_(self.fc_k.weight) 101 | nn.init.xavier_uniform_(self.fc_v.weight) 102 | nn.init.xavier_uniform_(self.fc_o.weight) 103 | nn.init.normal_(self.m_k, 0, 1 / self.d_k) 104 | nn.init.normal_(self.m_v, 0, 1 / self.m) 105 | nn.init.constant_(self.fc_q.bias, 0) 106 | nn.init.constant_(self.fc_k.bias, 0) 107 | nn.init.constant_(self.fc_v.bias, 0) 108 | nn.init.constant_(self.fc_o.bias, 0) 109 | 110 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None): 111 | ''' 112 | Computes 113 | :param queries: Queries (b_s, nq, d_model) 114 | :param keys: Keys (b_s, nk, d_model) 115 | :param values: Values (b_s, nk, d_model) 116 | :param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking. 117 | :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk). 
118 | :return: 119 | ''' 120 | b_s, nq = queries.shape[:2] 121 | nk = keys.shape[1] 122 | 123 | m_k = np.sqrt(self.d_k) * self.m_k.expand(b_s, self.m, self.h * self.d_k) 124 | m_v = np.sqrt(self.m) * self.m_v.expand(b_s, self.m, self.h * self.d_v) 125 | 126 | q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3) # (b_s, h, nq, d_k) 127 | k = torch.cat([self.fc_k(keys), m_k], 1).view(b_s, nk + self.m, self.h, self.d_k).permute(0, 2, 3, 1) # (b_s, h, d_k, nk) 128 | v = torch.cat([self.fc_v(values), m_v], 1).view(b_s, nk + self.m, self.h, self.d_v).permute(0, 2, 1, 3) # (b_s, h, nk, d_v) 129 | 130 | att = torch.matmul(q, k) / np.sqrt(self.d_k) # (b_s, h, nq, nk) 131 | if attention_weights is not None: 132 | att = torch.cat([att[:, :, :, :nk] * attention_weights, att[:, :, :, nk:]], -1) 133 | if attention_mask is not None: 134 | att[:, :, :, :nk] = att[:, :, :, :nk].masked_fill(attention_mask, -np.inf) 135 | att = torch.softmax(att, -1) 136 | out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v) # (b_s, nq, h*d_v) 137 | out = self.fc_o(out) # (b_s, nq, d_model) 138 | return out 139 | 140 | 141 | class MultiHeadAttention(Module): 142 | ''' 143 | Multi-head attention layer with Dropout and Layer Normalization. 144 | ''' 145 | 146 | def __init__(self, d_model, d_k, d_v, h, dropout=.1, identity_map_reordering=False, can_be_stateful=False, 147 | attention_module=None, attention_module_kwargs=None): 148 | super(MultiHeadAttention, self).__init__() 149 | self.identity_map_reordering = identity_map_reordering 150 | if attention_module is not None: 151 | if attention_module_kwargs is not None: 152 | self.attention = attention_module(d_model=d_model, d_k=d_k, d_v=d_v, h=h, **attention_module_kwargs) 153 | else: 154 | self.attention = attention_module(d_model=d_model, d_k=d_k, d_v=d_v, h=h) 155 | else: 156 | self.attention = ScaledDotProductAttention(d_model=d_model, d_k=d_k, d_v=d_v, h=h) 157 | self.dropout = nn.Dropout(p=dropout) 158 | self.layer_norm = nn.LayerNorm(d_model) 159 | 160 | self.can_be_stateful = can_be_stateful 161 | if self.can_be_stateful: 162 | self.register_state('running_keys', torch.zeros((0, d_model))) 163 | self.register_state('running_values', torch.zeros((0, d_model))) 164 | 165 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None): 166 | if self.can_be_stateful and self._is_stateful: 167 | self.running_keys = torch.cat([self.running_keys, keys], 1) 168 | keys = self.running_keys 169 | 170 | self.running_values = torch.cat([self.running_values, values], 1) 171 | values = self.running_values 172 | 173 | if self.identity_map_reordering: 174 | q_norm = self.layer_norm(queries) 175 | k_norm = self.layer_norm(keys) 176 | v_norm = self.layer_norm(values) 177 | out = self.attention(q_norm, k_norm, v_norm, attention_mask, attention_weights) 178 | out = queries + self.dropout(torch.relu(out)) 179 | else: 180 | out = self.attention(queries, keys, values, attention_mask, attention_weights) 181 | out = self.dropout(out) 182 | out = self.layer_norm(queries + out) 183 | return out 184 | -------------------------------------------------------------------------------- /meshed-memory-transformer-master/models/transformer/decoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | import numpy as np 5 | 6 | from models.transformer.attention import MultiHeadAttention 7 | from 
models.transformer.utils import sinusoid_encoding_table, PositionWiseFeedForward 8 | from models.containers import Module, ModuleList 9 | 10 | 11 | class MeshedDecoderLayer(Module): 12 | def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, self_att_module=None, 13 | enc_att_module=None, self_att_module_kwargs=None, enc_att_module_kwargs=None): 14 | super(MeshedDecoderLayer, self).__init__() 15 | self.self_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=True, 16 | attention_module=self_att_module, 17 | attention_module_kwargs=self_att_module_kwargs) 18 | self.enc_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=False, 19 | attention_module=enc_att_module, 20 | attention_module_kwargs=enc_att_module_kwargs) 21 | self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout) 22 | 23 | self.fc_alpha1 = nn.Linear(d_model + d_model, d_model) 24 | self.fc_alpha2 = nn.Linear(d_model + d_model, d_model) 25 | self.fc_alpha3 = nn.Linear(d_model + d_model, d_model) 26 | 27 | self.init_weights() 28 | 29 | def init_weights(self): 30 | nn.init.xavier_uniform_(self.fc_alpha1.weight) 31 | nn.init.xavier_uniform_(self.fc_alpha2.weight) 32 | nn.init.xavier_uniform_(self.fc_alpha3.weight) 33 | nn.init.constant_(self.fc_alpha1.bias, 0) 34 | nn.init.constant_(self.fc_alpha2.bias, 0) 35 | nn.init.constant_(self.fc_alpha3.bias, 0) 36 | 37 | def forward(self, input, enc_output, mask_pad, mask_self_att, mask_enc_att): 38 | self_att = self.self_att(input, input, input, mask_self_att) 39 | self_att = self_att * mask_pad 40 | 41 | enc_att1 = self.enc_att(self_att, enc_output[:, 0], enc_output[:, 0], mask_enc_att) * mask_pad 42 | enc_att2 = self.enc_att(self_att, enc_output[:, 1], enc_output[:, 1], mask_enc_att) * mask_pad 43 | enc_att3 = self.enc_att(self_att, enc_output[:, 2], enc_output[:, 2], mask_enc_att) * mask_pad 44 | 45 | alpha1 = torch.sigmoid(self.fc_alpha1(torch.cat([self_att, enc_att1], -1))) 46 | alpha2 = torch.sigmoid(self.fc_alpha2(torch.cat([self_att, enc_att2], -1))) 47 | alpha3 = torch.sigmoid(self.fc_alpha3(torch.cat([self_att, enc_att3], -1))) 48 | 49 | enc_att = (enc_att1 * alpha1 + enc_att2 * alpha2 + enc_att3 * alpha3) / np.sqrt(3) 50 | enc_att = enc_att * mask_pad 51 | 52 | ff = self.pwff(enc_att) 53 | ff = ff * mask_pad 54 | return ff 55 | 56 | 57 | class MeshedDecoder(Module): 58 | def __init__(self, vocab_size, max_len, N_dec, padding_idx, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, 59 | self_att_module=None, enc_att_module=None, self_att_module_kwargs=None, enc_att_module_kwargs=None): 60 | super(MeshedDecoder, self).__init__() 61 | self.d_model = d_model 62 | self.word_emb = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx) 63 | self.pos_emb = nn.Embedding.from_pretrained(sinusoid_encoding_table(max_len + 1, d_model, 0), freeze=True) 64 | self.layers = ModuleList( 65 | [MeshedDecoderLayer(d_model, d_k, d_v, h, d_ff, dropout, self_att_module=self_att_module, 66 | enc_att_module=enc_att_module, self_att_module_kwargs=self_att_module_kwargs, 67 | enc_att_module_kwargs=enc_att_module_kwargs) for _ in range(N_dec)]) 68 | self.fc = nn.Linear(d_model, vocab_size, bias=False) 69 | self.max_len = max_len 70 | self.padding_idx = padding_idx 71 | self.N = N_dec 72 | 73 | self.register_state('running_mask_self_attention', torch.zeros((1, 1, 0)).byte()) 74 | self.register_state('running_seq', torch.zeros((1,)).long()) 75 | 76 | def forward(self, input, encoder_output, mask_encoder): 77 | # input (b_s, seq_len) 78 | 
b_s, seq_len = input.shape[:2] 79 | mask_queries = (input != self.padding_idx).unsqueeze(-1).float() # (b_s, seq_len, 1) 80 | mask_self_attention = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8, device=input.device), 81 | diagonal=1) 82 | mask_self_attention = mask_self_attention.unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, seq_len) 83 | mask_self_attention = mask_self_attention + (input == self.padding_idx).unsqueeze(1).unsqueeze(1).byte() 84 | mask_self_attention = mask_self_attention.gt(0) # (b_s, 1, seq_len, seq_len) 85 | if self._is_stateful: 86 | self.running_mask_self_attention = torch.cat([self.running_mask_self_attention, mask_self_attention], -1) 87 | mask_self_attention = self.running_mask_self_attention 88 | 89 | seq = torch.arange(1, seq_len + 1).view(1, -1).expand(b_s, -1).to(input.device) # (b_s, seq_len) 90 | seq = seq.masked_fill(mask_queries.squeeze(-1) == 0, 0) 91 | if self._is_stateful: 92 | self.running_seq.add_(1) 93 | seq = self.running_seq 94 | 95 | out = self.word_emb(input) + self.pos_emb(seq) 96 | for i, l in enumerate(self.layers): 97 | out = l(out, encoder_output, mask_queries, mask_self_attention, mask_encoder) 98 | 99 | out = self.fc(out) 100 | return F.log_softmax(out, dim=-1) 101 | -------------------------------------------------------------------------------- /meshed-memory-transformer-master/models/transformer/encoders.py: -------------------------------------------------------------------------------- 1 | from torch.nn import functional as F 2 | from models.transformer.utils import PositionWiseFeedForward 3 | import torch 4 | from torch import nn 5 | from models.transformer.attention import MultiHeadAttention 6 | 7 | 8 | class EncoderLayer(nn.Module): 9 | def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, identity_map_reordering=False, 10 | attention_module=None, attention_module_kwargs=None): 11 | super(EncoderLayer, self).__init__() 12 | self.identity_map_reordering = identity_map_reordering 13 | self.mhatt = MultiHeadAttention(d_model, d_k, d_v, h, dropout, identity_map_reordering=identity_map_reordering, 14 | attention_module=attention_module, 15 | attention_module_kwargs=attention_module_kwargs) 16 | self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout, identity_map_reordering=identity_map_reordering) 17 | 18 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None): 19 | att = self.mhatt(queries, keys, values, attention_mask, attention_weights) 20 | ff = self.pwff(att) 21 | return ff 22 | 23 | 24 | class MultiLevelEncoder(nn.Module): 25 | def __init__(self, N, padding_idx, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, 26 | identity_map_reordering=False, attention_module=None, attention_module_kwargs=None): 27 | super(MultiLevelEncoder, self).__init__() 28 | self.d_model = d_model 29 | self.dropout = dropout 30 | self.layers = nn.ModuleList([EncoderLayer(d_model, d_k, d_v, h, d_ff, dropout, 31 | identity_map_reordering=identity_map_reordering, 32 | attention_module=attention_module, 33 | attention_module_kwargs=attention_module_kwargs) 34 | for _ in range(N)]) 35 | self.padding_idx = padding_idx 36 | 37 | def forward(self, input, attention_weights=None): 38 | # input (b_s, seq_len, d_in) 39 | attention_mask = (torch.sum(input, -1) == self.padding_idx).unsqueeze(1).unsqueeze(1) # (b_s, 1, 1, seq_len) 40 | 41 | outs = [] 42 | out = input 43 | for l in self.layers: 44 | out = l(out, out, out, attention_mask, attention_weights) 45 | outs.append(out.unsqueeze(1)) 46 | 
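# Each encoder layer contributes one (b_s, 1, seq_len, d_model) slice via unsqueeze(1);
# concatenating along dim 1 below yields (b_s, N, seq_len, d_model), i.e. the per-layer
# outputs that the meshed decoder later cross-attends to as enc_output[:, 0..N-1].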
47 | outs = torch.cat(outs, 1) 48 | return outs, attention_mask 49 | 50 | 51 | class MemoryAugmentedEncoder(MultiLevelEncoder): 52 | def __init__(self, N, padding_idx, d_in=2048, **kwargs): 53 | super(MemoryAugmentedEncoder, self).__init__(N, padding_idx, **kwargs) 54 | self.fc = nn.Linear(d_in, self.d_model) 55 | self.dropout = nn.Dropout(p=self.dropout) 56 | self.layer_norm = nn.LayerNorm(self.d_model) 57 | 58 | def forward(self, input, attention_weights=None): 59 | out = F.relu(self.fc(input)) 60 | out = self.dropout(out) 61 | out = self.layer_norm(out) 62 | return super(MemoryAugmentedEncoder, self).forward(out, attention_weights=attention_weights) 63 | -------------------------------------------------------------------------------- /meshed-memory-transformer-master/models/transformer/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import copy 4 | from models.containers import ModuleList 5 | from ..captioning_model import CaptioningModel 6 | 7 | 8 | class Transformer(CaptioningModel): 9 | def __init__(self, bos_idx, encoder, decoder): 10 | super(Transformer, self).__init__() 11 | self.bos_idx = bos_idx 12 | self.encoder = encoder 13 | self.decoder = decoder 14 | self.register_state('enc_output', None) 15 | self.register_state('mask_enc', None) 16 | self.init_weights() 17 | 18 | @property 19 | def d_model(self): 20 | return self.decoder.d_model 21 | 22 | def init_weights(self): 23 | for p in self.parameters(): 24 | if p.dim() > 1: 25 | nn.init.xavier_uniform_(p) 26 | 27 | def forward(self, images, seq, *args): 28 | enc_output, mask_enc = self.encoder(images) 29 | dec_output = self.decoder(seq, enc_output, mask_enc) 30 | return dec_output 31 | 32 | def init_state(self, b_s, device): 33 | return [torch.zeros((b_s, 0), dtype=torch.long, device=device), 34 | None, None] 35 | 36 | def step(self, t, prev_output, visual, seq, mode='teacher_forcing', **kwargs): 37 | it = None 38 | if mode == 'teacher_forcing': 39 | raise NotImplementedError 40 | elif mode == 'feedback': 41 | if t == 0: 42 | self.enc_output, self.mask_enc = self.encoder(visual) 43 | if isinstance(visual, torch.Tensor): 44 | it = visual.data.new_full((visual.shape[0], 1), self.bos_idx).long() 45 | else: 46 | it = visual[0].data.new_full((visual[0].shape[0], 1), self.bos_idx).long() 47 | else: 48 | it = prev_output 49 | 50 | return self.decoder(it, self.enc_output, self.mask_enc) 51 | 52 | 53 | class TransformerEnsemble(CaptioningModel): 54 | def __init__(self, model: Transformer, weight_files): 55 | super(TransformerEnsemble, self).__init__() 56 | self.n = len(weight_files) 57 | self.models = ModuleList([copy.deepcopy(model) for _ in range(self.n)]) 58 | for i in range(self.n): 59 | state_dict_i = torch.load(weight_files[i])['state_dict'] 60 | self.models[i].load_state_dict(state_dict_i) 61 | 62 | def step(self, t, prev_output, visual, seq, mode='teacher_forcing', **kwargs): 63 | out_ensemble = [] 64 | for i in range(self.n): 65 | out_i = self.models[i].step(t, prev_output, visual, seq, mode, **kwargs) 66 | out_ensemble.append(out_i.unsqueeze(0)) 67 | 68 | return torch.mean(torch.cat(out_ensemble, 0), dim=0) 69 | -------------------------------------------------------------------------------- /meshed-memory-transformer-master/models/transformer/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | def 
position_embedding(input, d_model): 7 | input = input.view(-1, 1) 8 | dim = torch.arange(d_model // 2, dtype=torch.float32, device=input.device).view(1, -1) 9 | sin = torch.sin(input / 10000 ** (2 * dim / d_model)) 10 | cos = torch.cos(input / 10000 ** (2 * dim / d_model)) 11 | 12 | out = torch.zeros((input.shape[0], d_model), device=input.device) 13 | out[:, ::2] = sin 14 | out[:, 1::2] = cos 15 | return out 16 | 17 | 18 | def sinusoid_encoding_table(max_len, d_model, padding_idx=None): 19 | pos = torch.arange(max_len, dtype=torch.float32) 20 | out = position_embedding(pos, d_model) 21 | 22 | if padding_idx is not None: 23 | out[padding_idx] = 0 24 | return out 25 | 26 | 27 | class PositionWiseFeedForward(nn.Module): 28 | ''' 29 | Position-wise feed forward layer 30 | ''' 31 | 32 | def __init__(self, d_model=512, d_ff=2048, dropout=.1, identity_map_reordering=False): 33 | super(PositionWiseFeedForward, self).__init__() 34 | self.identity_map_reordering = identity_map_reordering 35 | self.fc1 = nn.Linear(d_model, d_ff) 36 | self.fc2 = nn.Linear(d_ff, d_model) 37 | self.dropout = nn.Dropout(p=dropout) 38 | self.dropout_2 = nn.Dropout(p=dropout) 39 | self.layer_norm = nn.LayerNorm(d_model) 40 | 41 | def forward(self, input): 42 | if self.identity_map_reordering: 43 | out = self.layer_norm(input) 44 | out = self.fc2(self.dropout_2(F.relu(self.fc1(out)))) 45 | out = input + self.dropout(torch.relu(out)) 46 | else: 47 | out = self.fc2(self.dropout_2(F.relu(self.fc1(input)))) 48 | out = self.dropout(out) 49 | out = self.layer_norm(input + out) 50 | return out -------------------------------------------------------------------------------- /meshed-memory-transformer-master/output_logs/meshed_memory_transformer_test_o: -------------------------------------------------------------------------------- 1 | Meshed-Memory Transformer Evaluation 2 | {'BLEU': [0.8076084272899184, 0.65337618312199, 0.5093125587687117, 0.3909357911782391], 'METEOR': 0.2918900660095916, 'ROUGE': 0.5863539878042495, 'CIDEr': 1.3119740267338893} 3 | -------------------------------------------------------------------------------- /meshed-memory-transformer-master/test.py: -------------------------------------------------------------------------------- 1 | import random 2 | from data import ImageDetectionsField, TextField, RawField 3 | from data import COCO, DataLoader 4 | import evaluation 5 | from models.transformer import Transformer, MemoryAugmentedEncoder, MeshedDecoder, ScaledDotProductAttentionMemory 6 | import torch 7 | from tqdm import tqdm 8 | import argparse 9 | import pickle 10 | import numpy as np 11 | 12 | random.seed(1234) 13 | torch.manual_seed(1234) 14 | np.random.seed(1234) 15 | 16 | 17 | def predict_captions(model, dataloader, text_field): 18 | import itertools 19 | model.eval() 20 | gen = {} 21 | gts = {} 22 | with tqdm(desc='Evaluation', unit='it', total=len(dataloader)) as pbar: 23 | for it, (images, caps_gt) in enumerate(iter(dataloader)): 24 | images = images.to(device) 25 | with torch.no_grad(): 26 | out, _ = model.beam_search(images, 20, text_field.vocab.stoi[''], 5, out_size=1) 27 | caps_gen = text_field.decode(out, join_words=False) 28 | for i, (gts_i, gen_i) in enumerate(zip(caps_gt, caps_gen)): 29 | gen_i = ' '.join([k for k, g in itertools.groupby(gen_i)]) 30 | gen['%d_%d' % (it, i)] = [gen_i.strip(), ] 31 | gts['%d_%d' % (it, i)] = gts_i 32 | pbar.update() 33 | 34 | gts = evaluation.PTBTokenizer.tokenize(gts) 35 | gen = evaluation.PTBTokenizer.tokenize(gen) 36 | scores, _ = 
evaluation.compute_scores(gts, gen) 37 | 38 | return scores 39 | 40 | 41 | if __name__ == '__main__': 42 | device = torch.device('cuda') 43 | 44 | parser = argparse.ArgumentParser(description='Meshed-Memory Transformer') 45 | parser.add_argument('--batch_size', type=int, default=10) 46 | parser.add_argument('--workers', type=int, default=0) 47 | parser.add_argument('--features_path', type=str) 48 | parser.add_argument('--annotation_folder', type=str) 49 | args = parser.parse_args() 50 | 51 | print('Meshed-Memory Transformer Evaluation') 52 | 53 | # Pipeline for image regions 54 | image_field = ImageDetectionsField(detections_path=args.features_path, max_detections=50, load_in_tmp=False) 55 | 56 | # Pipeline for text 57 | text_field = TextField(init_token='', eos_token='', lower=True, tokenize='spacy', 58 | remove_punctuation=True, nopoints=False) 59 | 60 | # Create the dataset 61 | dataset = COCO(image_field, text_field, 'coco/images/', args.annotation_folder, args.annotation_folder) 62 | _, _, test_dataset = dataset.splits 63 | text_field.vocab = pickle.load(open('vocab.pkl', 'rb')) 64 | 65 | # Model and dataloaders 66 | encoder = MemoryAugmentedEncoder(3, 0, attention_module=ScaledDotProductAttentionMemory, 67 | attention_module_kwargs={'m': 40}) 68 | decoder = MeshedDecoder(len(text_field.vocab), 54, 3, text_field.vocab.stoi['']) 69 | model = Transformer(text_field.vocab.stoi[''], encoder, decoder).to(device) 70 | 71 | data = torch.load('meshed_memory_transformer.pth') 72 | model.load_state_dict(data['state_dict']) 73 | 74 | dict_dataset_test = test_dataset.image_dictionary({'image': image_field, 'text': RawField()}) 75 | dict_dataloader_test = DataLoader(dict_dataset_test, batch_size=args.batch_size, num_workers=args.workers) 76 | 77 | scores = predict_captions(model, dict_dataloader_test, text_field) 78 | print(scores) 79 | -------------------------------------------------------------------------------- /meshed-memory-transformer-master/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import download_from_url 2 | from .typing import * 3 | 4 | def get_batch_size(x: TensorOrSequence) -> int: 5 | if isinstance(x, torch.Tensor): 6 | b_s = x.size(0) 7 | else: 8 | b_s = x[0].size(0) 9 | return b_s 10 | 11 | 12 | def get_device(x: TensorOrSequence) -> int: 13 | if isinstance(x, torch.Tensor): 14 | b_s = x.device 15 | else: 16 | b_s = x[0].device 17 | return b_s 18 | -------------------------------------------------------------------------------- /meshed-memory-transformer-master/utils/typing.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Sequence, Tuple 2 | import torch 3 | 4 | TensorOrSequence = Union[Sequence[torch.Tensor], torch.Tensor] 5 | TensorOrNone = Union[torch.Tensor, None] 6 | -------------------------------------------------------------------------------- /meshed-memory-transformer-master/utils/utils.py: -------------------------------------------------------------------------------- 1 | import requests 2 | 3 | def download_from_url(url, path): 4 | """Download file, with logic (from tensor2tensor) for Google Drive""" 5 | if 'drive.google.com' not in url: 6 | print('Downloading %s; may take a few minutes' % url) 7 | r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}) 8 | with open(path, "wb") as file: 9 | file.write(r.content) 10 | return 11 | print('Downloading from Google Drive; may take a few minutes') 12 | confirm_token = None 13 | 
session = requests.Session() 14 | response = session.get(url, stream=True) 15 | for k, v in response.cookies.items(): 16 | if k.startswith("download_warning"): 17 | confirm_token = v 18 | 19 | if confirm_token: 20 | url = url + "&confirm=" + confirm_token 21 | response = session.get(url, stream=True) 22 | 23 | chunk_size = 16 * 1024 24 | with open(path, "wb") as f: 25 | for chunk in response.iter_content(chunk_size): 26 | if chunk: 27 | f.write(chunk) 28 | -------------------------------------------------------------------------------- /meshed-memory-transformer-master/vocab.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weixmath/CVR/af2ea6859230c1a5f56868a2242c06c2fe02f05a/meshed-memory-transformer-master/vocab.pkl -------------------------------------------------------------------------------- /model/Model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | import glob 5 | class Model(nn.Module): 6 | 7 | def __init__(self, name): 8 | super(Model, self).__init__() 9 | self.name = name 10 | 11 | def save(self, path, epoch=0): 12 | complete_path = os.path.join(path, self.name) 13 | if not os.path.exists(complete_path): 14 | os.makedirs(complete_path) 15 | torch.save(self.state_dict(), 16 | os.path.join(complete_path, 17 | "model-{}.pth".format(str(epoch).zfill(5)))) 18 | 19 | def save_results(self, path, data): 20 | raise NotImplementedError("Model subclass must implement this method.") 21 | 22 | def load(self, path, modelfile=None): 23 | complete_path = os.path.join(path, self.name) 24 | if not os.path.exists(complete_path): 25 | raise IOError("{} directory does not exist in {}".format(self.name, path)) 26 | 27 | if modelfile is None: 28 | model_files = glob.glob(complete_path + "/*") 29 | mf = max(model_files) 30 | else: 31 | mf = os.path.join(complete_path, modelfile) 32 | 33 | self.load_state_dict(torch.load(mf)) 34 | 35 | 36 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /model/__pycache__/Model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weixmath/CVR/af2ea6859230c1a5f56868a2242c06c2fe02f05a/model/__pycache__/Model.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weixmath/CVR/af2ea6859230c1a5f56868a2242c06c2fe02f05a/model/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/view_transformer_random4.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weixmath/CVR/af2ea6859230c1a5f56868a2242c06c2fe02f05a/model/__pycache__/view_transformer_random4.cpython-36.pyc -------------------------------------------------------------------------------- /model/gvcnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torch.nn as nn 4 | # from torchsummary import summary 5 | 6 | 7 | def 
fc_bn_block(input, output): 8 | return nn.Sequential( 9 | nn.Linear(input, output), 10 | nn.BatchNorm1d(output), 11 | nn.ReLU(inplace=True)) 12 | 13 | 14 | def cal_scores(scores): 15 | n = len(scores) 16 | s = 0 17 | for score in scores: 18 | s += torch.ceil(score * n) 19 | s /= n 20 | return s 21 | 22 | 23 | def group_fusion(view_group, weight_group): 24 | shape_des = map(lambda a, b: a * b, view_group, weight_group) 25 | shape_des = sum(shape_des) / sum(weight_group) 26 | return shape_des 27 | 28 | 29 | def group_pooling(final_views, views_score, group_num): 30 | interval = 1.0 / group_num 31 | 32 | def onebatch_grouping(onebatch_views, onebatch_scores): 33 | viewgroup_onebatch = [[] for i in range(group_num)] 34 | scoregroup_onebatch = [[] for i in range(group_num)] 35 | 36 | for i in range(group_num): 37 | left = i * interval 38 | right = (i + 1) * interval 39 | for j, score in enumerate(onebatch_scores): 40 | if left <= score < right: 41 | viewgroup_onebatch[i].append(onebatch_views[j]) 42 | scoregroup_onebatch[i].append(score) 43 | else: 44 | pass 45 | # print(len(scoregroup_onebatch)) 46 | view_group = [sum(views) / len(views) for views in viewgroup_onebatch if len(views) > 0] 47 | weight_group = [cal_scores(scores) for scores in scoregroup_onebatch if len(scores) > 0] 48 | onebatch_shape_des = group_fusion(view_group, weight_group) 49 | return onebatch_shape_des 50 | 51 | shape_descriptors = [] 52 | for (onebatch_views, onebatch_scores) in zip(final_views, views_score): 53 | shape_descriptors.append(onebatch_grouping(onebatch_views, onebatch_scores)) 54 | shape_descriptor = torch.stack(shape_descriptors, 0) 55 | # shape_descriptor: [B, 1024] 56 | return shape_descriptor 57 | 58 | 59 | class GVCNN(nn.Module): 60 | def __init__(self, num_classes=40, group_num=8, model_name='GOOGLENET', pretrained=True): 61 | super(GVCNN, self).__init__() 62 | 63 | self.num_classes = num_classes 64 | self.group_num = group_num 65 | 66 | if model_name == 'GOOGLENET': 67 | base_model = torchvision.models.googlenet(pretrained=pretrained) 68 | 69 | self.FCN = nn.Sequential(*list(base_model.children())[:6]) 70 | self.CNN = nn.Sequential(*list(base_model.children())[:-2]) 71 | self.FC = nn.Sequential(fc_bn_block(256 * 28 * 28, 256), 72 | fc_bn_block(256, 1)) 73 | self.fc_block_1 = fc_bn_block(1024, 512) 74 | self.drop_1 = nn.Dropout(0.5) 75 | self.fc_block_2 = fc_bn_block(512, 256) 76 | self.drop_2 = nn.Dropout(0.5) 77 | self.linear = nn.Linear(256, self.num_classes) 78 | 79 | def forward(self, views): 80 | ''' 81 | params views: B V C H W (B 12 3 224 224) 82 | return result: B num_classes 83 | ''' 84 | # print(views.size()) 85 | # views = views.cpu() 86 | batch_size, num_views, channel, image_size = views.size(0), views.size(1), views.size(2), views.size(3) 87 | 88 | views = views.view(batch_size * num_views, channel, image_size, image_size) 89 | raw_views = self.FCN(views) 90 | # print(raw_views.size()) 91 | # raw_views: [B*V 256 28 28] 92 | final_views = self.CNN(views) 93 | # final_views: [B*V 1024 1 1] 94 | final_views = final_views.view(batch_size, num_views, 1024) 95 | views_score = self.FC(raw_views.view(batch_size * num_views, -1)) 96 | views_score = torch.sigmoid(torch.tanh(torch.abs(views_score))) 97 | views_score = views_score.view(batch_size, num_views, -1) 98 | # views_score: [B V] 99 | shape_descriptor = group_pooling(final_views, views_score, self.group_num) 100 | # print(shape_descriptor.size()) 101 | 102 | out = self.fc_block_1(shape_descriptor) 103 | out = self.drop_1(out) 104 | out = 
self.fc_block_2(out) 105 | viewcnn_feature = out 106 | out = self.drop_2(out) 107 | pred = self.linear(out) 108 | 109 | return pred -------------------------------------------------------------------------------- /model/gvcnn_random.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torch.nn as nn 4 | # from torchsummary import summary 5 | 6 | 7 | def fc_bn_block(input, output): 8 | return nn.Sequential( 9 | nn.Linear(input, output), 10 | # nn.BatchNorm1d(output), 11 | nn.ReLU(inplace=True)) 12 | 13 | 14 | def cal_scores(scores): 15 | n = len(scores) 16 | s = 0 17 | for score in scores: 18 | s += torch.ceil(score * n) 19 | s /= n 20 | return s 21 | 22 | def group_fusion(view_group, weight_group): 23 | shape_des = map(lambda a, b: a * b, view_group, weight_group) 24 | shape_des = sum(shape_des) / sum(weight_group) 25 | return shape_des 26 | 27 | def group_pooling(final_views, views_score, group_num): 28 | interval = 1.0 / group_num 29 | 30 | def onebatch_grouping(onebatch_views, onebatch_scores): 31 | viewgroup_onebatch = [[] for i in range(group_num)] 32 | scoregroup_onebatch = [[] for i in range(group_num)] 33 | 34 | for i in range(group_num): 35 | left = i * interval 36 | right = (i + 1) * interval 37 | for j, score in enumerate(onebatch_scores): 38 | if left <= score < right: 39 | viewgroup_onebatch[i].append(onebatch_views[j]) 40 | scoregroup_onebatch[i].append(score) 41 | else: 42 | pass 43 | # print(len(scoregroup_onebatch)) 44 | view_group = [sum(views) / len(views) for views in viewgroup_onebatch if len(views) > 0] 45 | weight_group = [cal_scores(scores) for scores in scoregroup_onebatch if len(scores) > 0] 46 | onebatch_shape_des = group_fusion(view_group, weight_group) 47 | return onebatch_shape_des 48 | 49 | shape_descriptors = [] 50 | for (onebatch_views, onebatch_scores) in zip(final_views, views_score): 51 | shape_descriptors.append(onebatch_grouping(onebatch_views, onebatch_scores)) 52 | shape_descriptor = torch.stack(shape_descriptors, 0) 53 | # shape_descriptor: [B, 1024] 54 | return shape_descriptor 55 | 56 | 57 | class GVCNN(nn.Module): 58 | def __init__(self, num_classes=40, group_num=8, model_name='GOOGLENET', pretrained=True): 59 | super(GVCNN, self).__init__() 60 | 61 | self.num_classes = num_classes 62 | self.group_num = group_num 63 | 64 | if model_name == 'GOOGLENET': 65 | base_model = torchvision.models.googlenet(pretrained=pretrained) 66 | 67 | self.FCN = nn.Sequential(*list(base_model.children())[:6]) 68 | self.CNN = nn.Sequential(*list(base_model.children())[:-2]) 69 | self.FC = nn.Sequential(fc_bn_block(256 * 28 * 28, 256), 70 | fc_bn_block(256, 1)) 71 | self.fc_block_1 = fc_bn_block(1024, 512) 72 | self.drop_1 = nn.Dropout(0.5) 73 | self.fc_block_2 = fc_bn_block(512, 256) 74 | self.drop_2 = nn.Dropout(0.5) 75 | self.linear = nn.Linear(256, self.num_classes) 76 | 77 | def forward(self,views,rand_view_num,N): 78 | ''' 79 | params views: B V C H W (B 12 3 224 224) 80 | return result: B num_classes 81 | ''' 82 | # print(views.size()) 83 | # views = views.cpu() 84 | # batch_size, num_views, channel, image_size = views.size(0), views.size(1), views.size(2), views.size(3) 85 | # 86 | # views = views.view(batch_size * num_views, channel, image_size, image_size) 87 | raw_views = self.FCN(views) 88 | # print(raw_views.size()) 89 | # raw_views: [B*V 256 28 28] 90 | final_views = self.CNN(views) 91 | # final_views: [B*V 1024 1 1] 92 | final_views = final_views.squeeze() 93 | m = 0 94 | 
pred_final = [] 95 | for i in range(N): 96 | raw_views0 = raw_views[m:m+rand_view_num[i],:,:,:] 97 | final_views0 = final_views[m:m+rand_view_num[i],:].unsqueeze(0) 98 | m = m + rand_view_num[i] 99 | views_score = self.FC(raw_views0.view(rand_view_num[i],-1)).unsqueeze(0) 100 | views_score = torch.sigmoid(torch.tanh(torch.abs(views_score))) 101 | # views_score = views_score.view(batch_size, num_views, -1) 102 | # views_score: [B V] 103 | shape_descriptor = group_pooling(final_views0, views_score, self.group_num) 104 | # print(shape_descriptor.size()) 105 | out = self.fc_block_1(shape_descriptor) 106 | out = self.drop_1(out) 107 | out = self.fc_block_2(out) 108 | viewcnn_feature = out 109 | out = self.drop_2(out) 110 | pred = self.linear(out) 111 | pred_final.append(pred) 112 | pred_final = torch.cat(pred_final,0) 113 | return pred_final -------------------------------------------------------------------------------- /model/mvcnn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torchvision.models as models 5 | from .Model import Model 6 | mean = torch.tensor([0.485, 0.456, 0.406],dtype=torch.float, requires_grad=False) 7 | std = torch.tensor([0.229, 0.224, 0.225],dtype=torch.float, requires_grad=False) 8 | def flip(x, dim): 9 | xsize = x.size() 10 | dim = x.dim() + dim if dim < 0 else dim 11 | x = x.view(-1, *xsize[dim:]) 12 | x = x.view(x.size(0), x.size(1), -1)[:, getattr(torch.arange(x.size(1) - 1, 13 | -1, -1), ('cpu', 'cuda')[x.is_cuda])().long(), :] 14 | return x.view(xsize) 15 | 16 | class SVCNN(Model): 17 | def __init__(self, name, nclasses=40, pretraining=True, cnn_name='resnet18'): 18 | super(SVCNN, self).__init__(name) 19 | if nclasses == 40: 20 | self.classnames = ['airplane', 'bathtub', 'bed', 'bench', 'bookshelf', 'bottle', 'bowl', 'car', 'chair', 21 | 'cone', 'cup', 'curtain', 'desk', 'door', 'dresser', 'flower_pot', 'glass_box', 22 | 'guitar', 'keyboard', 'lamp', 'laptop', 'mantel', 'monitor', 'night_stand', 23 | 'person', 'piano', 'plant', 'radio', 'range_hood', 'sink', 'sofa', 'stairs', 24 | 'stool', 'table', 'tent', 'toilet', 'tv_stand', 'vase', 'wardrobe', 'xbox'] 25 | elif nclasses==15: 26 | self.classnames = ['bag', 'bed', 'bin', 'box', 'cabinet', 'chair', 'desk', 'display' 27 | , 'door', 'pillow', 'shelf', 'sink', 'sofa', 'table', 'toilet'] 28 | else: 29 | self.classnames = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', 30 | '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', 31 | '22', '23', '24', '25', '26', '27', '28', '29', '30', '31', 32 | '32', '33', '34', '35', '36', '37', '38', '39', '40', '41', 33 | '42', '43', '44', '45', '46', '47', '48', '49', '50', '51', '52', '53', '54'] 34 | 35 | self.nclasses = nclasses 36 | self.pretraining = pretraining 37 | self.cnn_name = cnn_name 38 | self.use_resnet = cnn_name.startswith('resnet') 39 | self.mean = torch.tensor([0.485, 0.456, 0.406],dtype=torch.float, requires_grad=False) 40 | self.std = torch.tensor([0.229, 0.224, 0.225],dtype=torch.float, requires_grad=False) 41 | 42 | if self.use_resnet: 43 | if self.cnn_name == 'resnet18': 44 | self.net = models.resnet18(pretrained=self.pretraining) 45 | self.net.fc = nn.Linear(512, self.nclasses) 46 | elif self.cnn_name == 'resnet34': 47 | self.net = models.resnet34(pretrained=self.pretraining) 48 | self.net.fc = nn.Linear(512, self.nclasses) 49 | elif self.cnn_name == 'resnet50': 50 | self.net = models.resnet50(pretrained=self.pretraining) 
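# torchvision's resnet50 ends in a 2048-d pooled feature (resnet18/34 end in 512-d),
# so its classification head below is rebuilt with in_features=2048.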
51 | self.net.fc = nn.Linear(2048, self.nclasses) 52 | else: 53 | if self.cnn_name == 'alexnet': 54 | self.net_1 = models.alexnet(pretrained=self.pretraining).features 55 | self.net_2 = models.alexnet(pretrained=self.pretraining).classifier 56 | elif self.cnn_name == 'vgg11': 57 | self.net_1 = models.vgg11_bn(pretrained=self.pretraining).features 58 | self.net_2 = models.vgg11_bn(pretrained=self.pretraining).classifier 59 | elif self.cnn_name == 'vgg16': 60 | self.net_1 = models.vgg16(pretrained=self.pretraining).features 61 | self.net_2 = models.vgg16(pretrained=self.pretraining).classifier 62 | 63 | self.net_2._modules['6'] = nn.Linear(4096, self.nclasses) 64 | 65 | def forward(self, x): 66 | if self.use_resnet: 67 | return self.net(x) 68 | else: 69 | y = self.net_1(x) 70 | return self.net_2(y.view(y.shape[0], -1)) 71 | 72 | class view_GCN(Model): 73 | def __init__(self,name, model, nclasses=40, cnn_name='resnet18', num_views=20): 74 | super(view_GCN,self).__init__(name) 75 | if nclasses == 40: 76 | self.classnames = ['airplane', 'bathtub', 'bed', 'bench', 'bookshelf', 'bottle', 'bowl', 'car', 'chair', 77 | 'cone', 'cup', 'curtain', 'desk', 'door', 'dresser', 'flower_pot', 'glass_box', 78 | 'guitar', 'keyboard', 'lamp', 'laptop', 'mantel', 'monitor', 'night_stand', 79 | 'person', 'piano', 'plant', 'radio', 'range_hood', 'sink', 'sofa', 'stairs', 80 | 'stool', 'table', 'tent', 'toilet', 'tv_stand', 'vase', 'wardrobe', 'xbox'] 81 | elif nclasses==15: 82 | self.classnames = ['bag', 'bed', 'bin', 'box', 'cabinet', 'chair', 'desk', 'display' 83 | , 'door', 'pillow', 'shelf', 'sink', 'sofa', 'table', 'toilet'] 84 | else: 85 | self.classnames = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', 86 | '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', 87 | '22', '23', '24', '25', '26', '27', '28', '29', '30', '31', 88 | '32', '33', '34', '35', '36', '37', '38', '39', '40', '41', 89 | '42', '43', '44', '45', '46', '47', '48', '49', '50', '51', '52', '53', '54'] 90 | self.nclasses = nclasses 91 | self.mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float, requires_grad=False) 92 | self.std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float, requires_grad=False) 93 | self.use_resnet = cnn_name.startswith('resnet') 94 | if self.use_resnet: 95 | self.net_1 = nn.Sequential(*list(model.net.children())[:-1]) 96 | self.net_2 = model.net.fc 97 | else: 98 | self.net_1 = model.net_1 99 | self.net_2 = model.net_2 100 | self.num_views = num_views 101 | for m in self.modules(): 102 | if isinstance(m, nn.Linear): 103 | nn.init.kaiming_uniform_(m.weight) 104 | elif isinstance(m, nn.Conv1d): 105 | nn.init.kaiming_uniform_(m.weight) 106 | def forward(self,x): 107 | y = self.net_1(x) 108 | y = y.view( 109 | (int(x.shape[0] / self.num_views), self.num_views, y.shape[-3], y.shape[-2], y.shape[-1])) # (8,12,512,7,7) 110 | return self.net_2(torch.max(y, 1)[0].view(y.shape[0], -1)) -------------------------------------------------------------------------------- /model/mvcnn_random.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torchvision.models as models 5 | from models.utils import PositionWiseFeedForward 6 | from .Model import Model 7 | from tools.view_gcn_utils import * 8 | from otk.layers import OTKernel 9 | from otk.utils import normalize 10 | from models.Vit import TransformerEncoderLayer 11 | mean = torch.tensor([0.485, 0.456, 0.406],dtype=torch.float, 
requires_grad=False) 12 | std = torch.tensor([0.229, 0.224, 0.225],dtype=torch.float, requires_grad=False) 13 | def flip(x, dim): 14 | xsize = x.size() 15 | dim = x.dim() + dim if dim < 0 else dim 16 | x = x.view(-1, *xsize[dim:]) 17 | x = x.view(x.size(0), x.size(1), -1)[:, getattr(torch.arange(x.size(1) - 1, 18 | -1, -1), ('cpu', 'cuda')[x.is_cuda])().long(), :] 19 | return x.view(xsize) 20 | 21 | class SVCNN(Model): 22 | def __init__(self, name, nclasses=40, pretraining=True, cnn_name='resnet18'): 23 | super(SVCNN, self).__init__(name) 24 | if nclasses == 40: 25 | self.classnames = ['airplane', 'bathtub', 'bed', 'bench', 'bookshelf', 'bottle', 'bowl', 'car', 'chair', 26 | 'cone', 'cup', 'curtain', 'desk', 'door', 'dresser', 'flower_pot', 'glass_box', 27 | 'guitar', 'keyboard', 'lamp', 'laptop', 'mantel', 'monitor', 'night_stand', 28 | 'person', 'piano', 'plant', 'radio', 'range_hood', 'sink', 'sofa', 'stairs', 29 | 'stool', 'table', 'tent', 'toilet', 'tv_stand', 'vase', 'wardrobe', 'xbox'] 30 | elif nclasses==15: 31 | self.classnames = ['bag', 'bed', 'bin', 'box', 'cabinet', 'chair', 'desk', 'display' 32 | , 'door', 'pillow', 'shelf', 'sink', 'sofa', 'table', 'toilet'] 33 | else: 34 | self.classnames = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', 35 | '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', 36 | '22', '23', '24', '25', '26', '27', '28', '29', '30', '31', 37 | '32', '33', '34', '35', '36', '37', '38', '39', '40', '41', 38 | '42', '43', '44', '45', '46', '47', '48', '49', '50', '51', '52', '53', '54'] 39 | 40 | self.nclasses = nclasses 41 | self.pretraining = pretraining 42 | self.cnn_name = cnn_name 43 | self.use_resnet = cnn_name.startswith('resnet') 44 | self.mean = torch.tensor([0.485, 0.456, 0.406],dtype=torch.float, requires_grad=False) 45 | self.std = torch.tensor([0.229, 0.224, 0.225],dtype=torch.float, requires_grad=False) 46 | 47 | if self.use_resnet: 48 | if self.cnn_name == 'resnet18': 49 | self.net = models.resnet18(pretrained=self.pretraining) 50 | self.net.fc = nn.Linear(512, self.nclasses) 51 | elif self.cnn_name == 'resnet34': 52 | self.net = models.resnet34(pretrained=self.pretraining) 53 | self.net.fc = nn.Linear(512, self.nclasses) 54 | elif self.cnn_name == 'resnet50': 55 | self.net = models.resnet50(pretrained=self.pretraining) 56 | self.net.fc = nn.Linear(2048, self.nclasses) 57 | else: 58 | if self.cnn_name == 'alexnet': 59 | self.net_1 = models.alexnet(pretrained=self.pretraining).features 60 | self.net_2 = models.alexnet(pretrained=self.pretraining).classifier 61 | elif self.cnn_name == 'vgg11': 62 | self.net_1 = models.vgg11_bn(pretrained=self.pretraining).features 63 | self.net_2 = models.vgg11_bn(pretrained=self.pretraining).classifier 64 | elif self.cnn_name == 'vgg16': 65 | self.net_1 = models.vgg16(pretrained=self.pretraining).features 66 | self.net_2 = models.vgg16(pretrained=self.pretraining).classifier 67 | 68 | self.net_2._modules['6'] = nn.Linear(4096, self.nclasses) 69 | 70 | def forward(self, x): 71 | if self.use_resnet: 72 | return self.net(x) 73 | else: 74 | y = self.net_1(x) 75 | return self.net_2(y.view(y.shape[0], -1)) 76 | 77 | class view_GCN(Model): 78 | def __init__(self,name, model, nclasses=40, cnn_name='resnet18', num_views=20): 79 | super(view_GCN,self).__init__(name) 80 | if nclasses == 40: 81 | self.classnames = ['airplane', 'bathtub', 'bed', 'bench', 'bookshelf', 'bottle', 'bowl', 'car', 'chair', 82 | 'cone', 'cup', 'curtain', 'desk', 'door', 'dresser', 'flower_pot', 'glass_box', 83 | 'guitar', 
'keyboard', 'lamp', 'laptop', 'mantel', 'monitor', 'night_stand', 84 | 'person', 'piano', 'plant', 'radio', 'range_hood', 'sink', 'sofa', 'stairs', 85 | 'stool', 'table', 'tent', 'toilet', 'tv_stand', 'vase', 'wardrobe', 'xbox'] 86 | elif nclasses==15: 87 | self.classnames = ['bag', 'bed', 'bin', 'box', 'cabinet', 'chair', 'desk', 'display' 88 | , 'door', 'pillow', 'shelf', 'sink', 'sofa', 'table', 'toilet'] 89 | else: 90 | self.classnames = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', 91 | '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', 92 | '22', '23', '24', '25', '26', '27', '28', '29', '30', '31', 93 | '32', '33', '34', '35', '36', '37', '38', '39', '40', '41', 94 | '42', '43', '44', '45', '46', '47', '48', '49', '50', '51', '52', '53', '54'] 95 | self.nclasses = nclasses 96 | self.mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float, requires_grad=False) 97 | self.std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float, requires_grad=False) 98 | self.use_resnet = cnn_name.startswith('resnet') 99 | if self.use_resnet: 100 | self.net_1 = nn.Sequential(*list(model.net.children())[:-1]) 101 | self.net_2 = model.net.fc 102 | else: 103 | self.net_1 = model.net_1 104 | self.net_2 = model.net_2 105 | self.num_views = num_views 106 | # for m in self.modules(): 107 | # if isinstance(m,nn.Linear): 108 | # nn.init.kaiming_uniform_(m.weight) 109 | # elif isinstance(m,nn.Conv1d): 110 | # nn.init.kaiming_uniform_(m.weight) 111 | def forward(self,x,rand_view_num,N): 112 | y = self.net_1(x) 113 | y = y.squeeze() 114 | y0 = my_pad_sequence(sequences=y, view_num=rand_view_num, N=N, max_length=self.num_views, padding_value=0) 115 | pooled_view = y0.max(1)[0] 116 | pooled_view = self.net_2(pooled_view) 117 | return pooled_view -------------------------------------------------------------------------------- /model/view_transformer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torchvision.models as models 5 | from models.utils import PositionWiseFeedForward 6 | from .Model import Model 7 | from tools.view_gcn_utils import * 8 | from otk.layers import OTKernel 9 | from otk.utils import normalize 10 | from models.Vit import TransformerEncoderLayer 11 | mean = torch.tensor([0.485, 0.456, 0.406],dtype=torch.float, requires_grad=False) 12 | std = torch.tensor([0.229, 0.224, 0.225],dtype=torch.float, requires_grad=False) 13 | def flip(x, dim): 14 | xsize = x.size() 15 | dim = x.dim() + dim if dim < 0 else dim 16 | x = x.view(-1, *xsize[dim:]) 17 | x = x.view(x.size(0), x.size(1), -1)[:, getattr(torch.arange(x.size(1) - 1, 18 | -1, -1), ('cpu', 'cuda')[x.is_cuda])().long(), :] 19 | return x.view(xsize) 20 | 21 | class SVCNN(Model): 22 | def __init__(self, name, nclasses=40, pretraining=True, cnn_name='resnet18'): 23 | super(SVCNN, self).__init__(name) 24 | if nclasses == 40: 25 | self.classnames = ['airplane', 'bathtub', 'bed', 'bench', 'bookshelf', 'bottle', 'bowl', 'car', 'chair', 26 | 'cone', 'cup', 'curtain', 'desk', 'door', 'dresser', 'flower_pot', 'glass_box', 27 | 'guitar', 'keyboard', 'lamp', 'laptop', 'mantel', 'monitor', 'night_stand', 28 | 'person', 'piano', 'plant', 'radio', 'range_hood', 'sink', 'sofa', 'stairs', 29 | 'stool', 'table', 'tent', 'toilet', 'tv_stand', 'vase', 'wardrobe', 'xbox'] 30 | elif nclasses==15: 31 | self.classnames = ['bag', 'bed', 'bin', 'box', 'cabinet', 'chair', 'desk', 'display' 32 | , 'door', 'pillow', 'shelf', 'sink', 
'sofa', 'table', 'toilet'] 33 | else: 34 | self.classnames = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', 35 | '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', 36 | '22', '23', '24', '25', '26', '27', '28', '29', '30', '31', 37 | '32', '33', '34', '35', '36', '37', '38', '39', '40', '41', 38 | '42', '43', '44', '45', '46', '47', '48', '49', '50', '51', '52', '53', '54'] 39 | 40 | self.nclasses = nclasses 41 | self.pretraining = pretraining 42 | self.cnn_name = cnn_name 43 | self.use_resnet = cnn_name.startswith('resnet') 44 | self.mean = torch.tensor([0.485, 0.456, 0.406],dtype=torch.float, requires_grad=False) 45 | self.std = torch.tensor([0.229, 0.224, 0.225],dtype=torch.float, requires_grad=False) 46 | 47 | if self.use_resnet: 48 | if self.cnn_name == 'resnet18': 49 | self.net = models.resnet18(pretrained=self.pretraining) 50 | self.net.fc = nn.Linear(512, self.nclasses) 51 | elif self.cnn_name == 'resnet34': 52 | self.net = models.resnet34(pretrained=self.pretraining) 53 | self.net.fc = nn.Linear(512, self.nclasses) 54 | elif self.cnn_name == 'resnet50': 55 | self.net = models.resnet50(pretrained=self.pretraining) 56 | self.net.fc = nn.Linear(2048, self.nclasses) 57 | else: 58 | if self.cnn_name == 'alexnet': 59 | self.net_1 = models.alexnet(pretrained=self.pretraining).features 60 | self.net_2 = models.alexnet(pretrained=self.pretraining).classifier 61 | elif self.cnn_name == 'vgg11': 62 | self.net_1 = models.vgg11_bn(pretrained=self.pretraining).features 63 | self.net_2 = models.vgg11_bn(pretrained=self.pretraining).classifier 64 | elif self.cnn_name == 'vgg16': 65 | self.net_1 = models.vgg16(pretrained=self.pretraining).features 66 | self.net_2 = models.vgg16(pretrained=self.pretraining).classifier 67 | 68 | self.net_2._modules['6'] = nn.Linear(4096, self.nclasses) 69 | 70 | def forward(self, x): 71 | if self.use_resnet: 72 | return self.net(x) 73 | else: 74 | y = self.net_1(x) 75 | return self.net_2(y.view(y.shape[0], -1)) 76 | 77 | class view_GCN(Model): 78 | def __init__(self,name, model, nclasses=40, cnn_name='resnet18', num_views=20): 79 | super(view_GCN,self).__init__(name) 80 | if nclasses == 40: 81 | self.classnames = ['airplane', 'bathtub', 'bed', 'bench', 'bookshelf', 'bottle', 'bowl', 'car', 'chair', 82 | 'cone', 'cup', 'curtain', 'desk', 'door', 'dresser', 'flower_pot', 'glass_box', 83 | 'guitar', 'keyboard', 'lamp', 'laptop', 'mantel', 'monitor', 'night_stand', 84 | 'person', 'piano', 'plant', 'radio', 'range_hood', 'sink', 'sofa', 'stairs', 85 | 'stool', 'table', 'tent', 'toilet', 'tv_stand', 'vase', 'wardrobe', 'xbox'] 86 | elif nclasses==15: 87 | self.classnames = ['bag', 'bed', 'bin', 'box', 'cabinet', 'chair', 'desk', 'display' 88 | , 'door', 'pillow', 'shelf', 'sink', 'sofa', 'table', 'toilet'] 89 | else: 90 | self.classnames = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', 91 | '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', 92 | '22', '23', '24', '25', '26', '27', '28', '29', '30', '31', 93 | '32', '33', '34', '35', '36', '37', '38', '39', '40', '41', 94 | '42', '43', '44', '45', '46', '47', '48', '49', '50', '51', '52', '53', '54'] 95 | self.nclasses = nclasses 96 | self.mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float, requires_grad=False) 97 | self.std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float, requires_grad=False) 98 | self.use_resnet = cnn_name.startswith('resnet') 99 | if self.use_resnet: 100 | self.net_1 = nn.Sequential(*list(model.net.children())[:-1]) 101 | self.net_2 = 
model.net.fc 102 | else: 103 | self.net_1 = model.net_1 104 | self.net_2 = model.net_2 105 | self.num_views = num_views 106 | self.zdim = 8 107 | self.tgt=nn.Parameter(torch.Tensor(self.zdim,512)) 108 | nn.init.xavier_normal_(self.tgt) 109 | self.coord_encoder = nn.Sequential( 110 | nn.Linear(512,64), 111 | nn.ReLU(), 112 | nn.Linear(64,3) 113 | ) 114 | self.coord_decoder = nn.Sequential( 115 | nn.Linear(3,64), 116 | nn.ReLU(), 117 | nn.Linear(64,512) 118 | ) 119 | self.otk_layer = OTKernel(in_dim=512, out_size=self.zdim, heads=1, max_iter=100, eps=0.05) 120 | self.dim = 512 121 | self.encoder1 = TransformerEncoderLayer(d_model=self.dim,nhead=4) 122 | self.encoder2 = TransformerEncoderLayer(d_model=self.dim,nhead=4) 123 | self.ff = PositionWiseFeedForward(d_model=512, d_ff=512) 124 | self.cls = nn.Sequential( 125 | nn.Linear(512, 256), 126 | nn.Dropout(), 127 | nn.LeakyReLU(0.2, inplace=True), 128 | ) 129 | self.cls2 = nn.Linear(256,self.nclasses) 130 | def forward(self,x): 131 | y = self.net_1(x) 132 | y = y.view((int(x.shape[0] / self.num_views), self.num_views, -1)) 133 | y0 = self.encoder1(src=y.transpose(0,1), src_key_padding_mask=None, pos=None) 134 | y0 = y0.transpose(0,1) 135 | y1 = self.otk_layer(y0) 136 | y11 = self.ff(y1) 137 | pos0 = normalize(self.coord_encoder(y11)) 138 | pos = self.coord_decoder(pos0) 139 | y2 = self.encoder2(src=y11.transpose(0,1),src_key_padding_mask=None, pos=pos.transpose(0,1)) 140 | y2 =y2.transpose(0,1) 141 | weight = self.otk_layer.weight 142 | cos_sim = torch.matmul(normalize(weight), normalize(weight).transpose(1, 2)) - torch.eye(self.zdim, 143 | self.zdim).cuda() 144 | cos_sim2 = torch.matmul(normalize(y1), normalize(y1).transpose(1, 2)) - torch.eye(self.zdim, self.zdim).cuda() 145 | pooled_view = y2.mean(1) 146 | feature = self.cls(pooled_view) 147 | pooled_view = self.cls2(feature) 148 | return pooled_view,cos_sim,cos_sim2,pos0 -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .transformer import Transformer 2 | from .captioning_model import CaptioningModel 3 | -------------------------------------------------------------------------------- /models/__pycache__/Vit.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weixmath/CVR/af2ea6859230c1a5f56868a2242c06c2fe02f05a/models/__pycache__/Vit.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/Vit.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weixmath/CVR/af2ea6859230c1a5f56868a2242c06c2fe02f05a/models/__pycache__/Vit.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weixmath/CVR/af2ea6859230c1a5f56868a2242c06c2fe02f05a/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/captioning_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weixmath/CVR/af2ea6859230c1a5f56868a2242c06c2fe02f05a/models/__pycache__/captioning_model.cpython-36.pyc 
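The `view_GCN` class above is the CVR aggregation head: per-view ResNet features are refined by one Transformer encoder layer, condensed into `zdim = 8` canonical-view features by the optimal-transport `OTKernel`, assigned 3D coordinates by the `coord_encoder`/`coord_decoder` pair, refined by a second encoder layer, and mean-pooled for classification. The snippet below is a minimal, hypothetical sketch of how `SVCNN` and `view_GCN` compose at inference time; the batch size, the 224x224 resolution and the `.cuda()` placement are assumptions, not values taken from the repository's training scripts.

```
import torch
from model.view_transformer import SVCNN, view_GCN

num_views = 20  # views per shape (the default argument of view_GCN)
svcnn = SVCNN('svcnn', nclasses=40, pretraining=True, cnn_name='resnet18')
cvr = view_GCN('cvr', svcnn, nclasses=40, cnn_name='resnet18', num_views=num_views).cuda()

# forward() expects all views of a batch stacked along dim 0: (batch_size * num_views, 3, H, W)
x = torch.randn(2 * num_views, 3, 224, 224).cuda()
logits, cos_sim, cos_sim2, pos0 = cvr(x)

# logits:   (2, 40)   class scores from the mean-pooled canonical-view features
# cos_sim:  (1, 8, 8) pairwise cosine similarity of the OT reference vectors, identity removed
# cos_sim2: (2, 8, 8) the same measure on the aggregated features y1
# pos0:     (2, 8, 3) normalized 3D coordinates predicted for the canonical views
```

The three auxiliary outputs are returned alongside the logits, presumably so the training script can add regularization terms on top of the classification loss; those terms are defined outside this file.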
-------------------------------------------------------------------------------- /models/__pycache__/containers.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weixmath/CVR/af2ea6859230c1a5f56868a2242c06c2fe02f05a/models/__pycache__/containers.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weixmath/CVR/af2ea6859230c1a5f56868a2242c06c2fe02f05a/models/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weixmath/CVR/af2ea6859230c1a5f56868a2242c06c2fe02f05a/models/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /models/beam_search/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /models/beam_search/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weixmath/CVR/af2ea6859230c1a5f56868a2242c06c2fe02f05a/models/beam_search/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/beam_search/beam_search.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import utils 3 | 4 | 5 | class BeamSearch(object): 6 | def __init__(self, model, max_len: int, eos_idx: int, beam_size: int): 7 | self.model = model 8 | self.max_len = max_len 9 | self.eos_idx = eos_idx 10 | self.beam_size = beam_size 11 | self.b_s = None 12 | self.device = None 13 | self.seq_mask = None 14 | self.seq_logprob = None 15 | self.outputs = None 16 | self.log_probs = None 17 | self.selected_words = None 18 | self.all_log_probs = None 19 | 20 | def _expand_state(self, selected_beam, cur_beam_size): 21 | def fn(s): 22 | shape = [int(sh) for sh in s.shape] 23 | beam = selected_beam 24 | for _ in shape[1:]: 25 | beam = beam.unsqueeze(-1) 26 | s = torch.gather(s.view(*([self.b_s, cur_beam_size] + shape[1:])), 1, 27 | beam.expand(*([self.b_s, self.beam_size] + shape[1:]))) 28 | s = s.view(*([-1, ] + shape[1:])) 29 | return s 30 | 31 | return fn 32 | 33 | def _expand_visual(self, visual: utils.TensorOrSequence, cur_beam_size: int, selected_beam: torch.Tensor): 34 | if isinstance(visual, torch.Tensor): 35 | visual_shape = visual.shape 36 | visual_exp_shape = (self.b_s, cur_beam_size) + visual_shape[1:] 37 | visual_red_shape = (self.b_s * self.beam_size,) + visual_shape[1:] 38 | selected_beam_red_size = (self.b_s, self.beam_size) + tuple(1 for _ in range(len(visual_exp_shape) - 2)) 39 | selected_beam_exp_size = (self.b_s, self.beam_size) + visual_exp_shape[2:] 40 | visual_exp = visual.view(visual_exp_shape) 41 | selected_beam_exp = selected_beam.view(selected_beam_red_size).expand(selected_beam_exp_size) 42 | visual = torch.gather(visual_exp, 1, selected_beam_exp).view(visual_red_shape) 43 | else: 44 | new_visual = [] 45 | for im in visual: 46 | visual_shape = im.shape 47 | visual_exp_shape = (self.b_s, cur_beam_size) + 
visual_shape[1:] 48 | visual_red_shape = (self.b_s * self.beam_size,) + visual_shape[1:] 49 | selected_beam_red_size = (self.b_s, self.beam_size) + tuple(1 for _ in range(len(visual_exp_shape) - 2)) 50 | selected_beam_exp_size = (self.b_s, self.beam_size) + visual_exp_shape[2:] 51 | visual_exp = im.view(visual_exp_shape) 52 | selected_beam_exp = selected_beam.view(selected_beam_red_size).expand(selected_beam_exp_size) 53 | new_im = torch.gather(visual_exp, 1, selected_beam_exp).view(visual_red_shape) 54 | new_visual.append(new_im) 55 | visual = tuple(new_visual) 56 | return visual 57 | 58 | def apply(self, visual: utils.TensorOrSequence, out_size=1, return_probs=False, **kwargs): 59 | self.b_s = utils.get_batch_size(visual) 60 | self.device = utils.get_device(visual) 61 | self.seq_mask = torch.ones((self.b_s, self.beam_size, 1), device=self.device) 62 | self.seq_logprob = torch.zeros((self.b_s, 1, 1), device=self.device) 63 | self.log_probs = [] 64 | self.selected_words = None 65 | if return_probs: 66 | self.all_log_probs = [] 67 | 68 | outputs = [] 69 | with self.model.statefulness(self.b_s): 70 | for t in range(self.max_len): 71 | visual, outputs = self.iter(t, visual, outputs, return_probs, **kwargs) 72 | 73 | # Sort result 74 | seq_logprob, sort_idxs = torch.sort(self.seq_logprob, 1, descending=True) 75 | outputs = torch.cat(outputs, -1) 76 | outputs = torch.gather(outputs, 1, sort_idxs.expand(self.b_s, self.beam_size, self.max_len)) 77 | log_probs = torch.cat(self.log_probs, -1) 78 | log_probs = torch.gather(log_probs, 1, sort_idxs.expand(self.b_s, self.beam_size, self.max_len)) 79 | if return_probs: 80 | all_log_probs = torch.cat(self.all_log_probs, 2) 81 | all_log_probs = torch.gather(all_log_probs, 1, sort_idxs.unsqueeze(-1).expand(self.b_s, self.beam_size, 82 | self.max_len, 83 | all_log_probs.shape[-1])) 84 | 85 | outputs = outputs.contiguous()[:, :out_size] 86 | log_probs = log_probs.contiguous()[:, :out_size] 87 | if out_size == 1: 88 | outputs = outputs.squeeze(1) 89 | log_probs = log_probs.squeeze(1) 90 | 91 | if return_probs: 92 | return outputs, log_probs, all_log_probs 93 | else: 94 | return outputs, log_probs 95 | 96 | def select(self, t, candidate_logprob, **kwargs): 97 | selected_logprob, selected_idx = torch.sort(candidate_logprob.view(self.b_s, -1), -1, descending=True) 98 | selected_logprob, selected_idx = selected_logprob[:, :self.beam_size], selected_idx[:, :self.beam_size] 99 | return selected_idx, selected_logprob 100 | 101 | def iter(self, t: int, visual: utils.TensorOrSequence, outputs, return_probs, **kwargs): 102 | cur_beam_size = 1 if t == 0 else self.beam_size 103 | 104 | word_logprob = self.model.step(t, self.selected_words, visual, None, mode='feedback', **kwargs) 105 | word_logprob = word_logprob.view(self.b_s, cur_beam_size, -1) 106 | candidate_logprob = self.seq_logprob + word_logprob 107 | 108 | # Mask sequence if it reaches EOS 109 | if t > 0: 110 | mask = (self.selected_words.view(self.b_s, cur_beam_size) != self.eos_idx).float().unsqueeze(-1) 111 | self.seq_mask = self.seq_mask * mask 112 | word_logprob = word_logprob * self.seq_mask.expand_as(word_logprob) 113 | old_seq_logprob = self.seq_logprob.expand_as(candidate_logprob).contiguous() 114 | old_seq_logprob[:, :, 1:] = -999 115 | candidate_logprob = self.seq_mask * candidate_logprob + old_seq_logprob * (1 - self.seq_mask) 116 | 117 | selected_idx, selected_logprob = self.select(t, candidate_logprob, **kwargs) 118 | selected_beam = selected_idx / candidate_logprob.shape[-1] 119 | selected_words = 
selected_idx - selected_beam * candidate_logprob.shape[-1] 120 | 121 | self.model.apply_to_states(self._expand_state(selected_beam, cur_beam_size)) 122 | visual = self._expand_visual(visual, cur_beam_size, selected_beam) 123 | 124 | self.seq_logprob = selected_logprob.unsqueeze(-1) 125 | self.seq_mask = torch.gather(self.seq_mask, 1, selected_beam.unsqueeze(-1)) 126 | outputs = list(torch.gather(o, 1, selected_beam.unsqueeze(-1)) for o in outputs) 127 | outputs.append(selected_words.unsqueeze(-1)) 128 | 129 | if return_probs: 130 | if t == 0: 131 | self.all_log_probs.append(word_logprob.expand((self.b_s, self.beam_size, -1)).unsqueeze(2)) 132 | else: 133 | self.all_log_probs.append(word_logprob.unsqueeze(2)) 134 | 135 | this_word_logprob = torch.gather(word_logprob, 1, 136 | selected_beam.unsqueeze(-1).expand(self.b_s, self.beam_size, 137 | word_logprob.shape[-1])) 138 | this_word_logprob = torch.gather(this_word_logprob, 2, selected_words.unsqueeze(-1)) 139 | self.log_probs = list( 140 | torch.gather(o, 1, selected_beam.unsqueeze(-1).expand(self.b_s, self.beam_size, 1)) for o in self.log_probs) 141 | self.log_probs.append(this_word_logprob) 142 | self.selected_words = selected_words.view(-1, 1) 143 | 144 | return visual, outputs 145 | -------------------------------------------------------------------------------- /models/captioning_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import distributions 3 | import models.utils as utils 4 | from models.containers import Module 5 | from models.beam_search import * 6 | 7 | 8 | class CaptioningModel(Module): 9 | def __init__(self): 10 | super(CaptioningModel, self).__init__() 11 | 12 | def init_weights(self): 13 | raise NotImplementedError 14 | 15 | def step(self, t, prev_output, visual, seq, mode='teacher_forcing', **kwargs): 16 | raise NotImplementedError 17 | 18 | def forward(self, images, seq, *args): 19 | device = images.device 20 | b_s = images.size(0) 21 | seq_len = seq.size(1) 22 | state = self.init_state(b_s, device) 23 | out = None 24 | 25 | outputs = [] 26 | for t in range(seq_len): 27 | out, state = self.step(t, state, out, images, seq, *args, mode='teacher_forcing') 28 | outputs.append(out) 29 | 30 | outputs = torch.cat([o.unsqueeze(1) for o in outputs], 1) 31 | return outputs 32 | 33 | def test(self, visual: utils.TensorOrSequence, max_len: int, eos_idx: int, **kwargs) -> utils.Tuple[torch.Tensor, torch.Tensor]: 34 | b_s = utils.get_batch_size(visual) 35 | device = utils.get_device(visual) 36 | outputs = [] 37 | log_probs = [] 38 | 39 | mask = torch.ones((b_s,), device=device) 40 | with self.statefulness(b_s): 41 | out = None 42 | for t in range(max_len): 43 | log_probs_t = self.step(t, out, visual, None, mode='feedback', **kwargs) 44 | out = torch.max(log_probs_t, -1)[1] 45 | mask = mask * (out.squeeze(-1) != eos_idx).float() 46 | log_probs.append(log_probs_t * mask.unsqueeze(-1).unsqueeze(-1)) 47 | outputs.append(out) 48 | 49 | return torch.cat(outputs, 1), torch.cat(log_probs, 1) 50 | 51 | def sample_rl(self, visual: utils.TensorOrSequence, max_len: int, **kwargs) -> utils.Tuple[torch.Tensor, torch.Tensor]: 52 | b_s = utils.get_batch_size(visual) 53 | outputs = [] 54 | log_probs = [] 55 | 56 | with self.statefulness(b_s): 57 | out = None 58 | for t in range(max_len): 59 | out = self.step(t, out, visual, None, mode='feedback', **kwargs) 60 | distr = distributions.Categorical(logits=out[:, 0]) 61 | out = distr.sample().unsqueeze(1) 62 | outputs.append(out) 
63 | log_probs.append(distr.log_prob(out).unsqueeze(1)) 64 | 65 | return torch.cat(outputs, 1), torch.cat(log_probs, 1) 66 | 67 | def beam_search(self, visual: utils.TensorOrSequence, max_len: int, eos_idx: int, beam_size: int, out_size=1, 68 | return_probs=False, **kwargs): 69 | bs = BeamSearch(self, max_len, eos_idx, beam_size) 70 | return bs.apply(visual, out_size, return_probs, **kwargs) 71 | -------------------------------------------------------------------------------- /models/containers.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | from torch import nn 3 | from typing import Union, Sequence, Tuple 4 | import torch 5 | 6 | TensorOrSequence = Union[Sequence[torch.Tensor], torch.Tensor] 7 | TensorOrNone = Union[torch.Tensor, None] 8 | 9 | 10 | class Module(nn.Module): 11 | def __init__(self): 12 | super(Module, self).__init__() 13 | self._is_stateful = False 14 | self._state_names = [] 15 | self._state_defaults = dict() 16 | 17 | def register_state(self, name: str, default: TensorOrNone): 18 | self._state_names.append(name) 19 | if default is None: 20 | self._state_defaults[name] = None 21 | else: 22 | self._state_defaults[name] = default.clone().detach() 23 | self.register_buffer(name, default) 24 | 25 | def states(self): 26 | for name in self._state_names: 27 | yield self._buffers[name] 28 | for m in self.children(): 29 | if isinstance(m, Module): 30 | yield from m.states() 31 | 32 | def apply_to_states(self, fn): 33 | for name in self._state_names: 34 | self._buffers[name] = fn(self._buffers[name]) 35 | for m in self.children(): 36 | if isinstance(m, Module): 37 | m.apply_to_states(fn) 38 | 39 | def _init_states(self, batch_size: int): 40 | for name in self._state_names: 41 | if self._state_defaults[name] is None: 42 | self._buffers[name] = None 43 | else: 44 | self._buffers[name] = self._state_defaults[name].clone().detach().to(self._buffers[name].device) 45 | self._buffers[name] = self._buffers[name].unsqueeze(0) 46 | self._buffers[name] = self._buffers[name].expand([batch_size, ] + list(self._buffers[name].shape[1:])) 47 | self._buffers[name] = self._buffers[name].contiguous() 48 | 49 | def _reset_states(self): 50 | for name in self._state_names: 51 | if self._state_defaults[name] is None: 52 | self._buffers[name] = None 53 | else: 54 | self._buffers[name] = self._state_defaults[name].clone().detach().to(self._buffers[name].device) 55 | 56 | def enable_statefulness(self, batch_size: int): 57 | for m in self.children(): 58 | if isinstance(m, Module): 59 | m.enable_statefulness(batch_size) 60 | self._init_states(batch_size) 61 | self._is_stateful = True 62 | 63 | def disable_statefulness(self): 64 | for m in self.children(): 65 | if isinstance(m, Module): 66 | m.disable_statefulness() 67 | self._reset_states() 68 | self._is_stateful = False 69 | 70 | @contextmanager 71 | def statefulness(self, batch_size: int): 72 | self.enable_statefulness(batch_size) 73 | try: 74 | yield 75 | finally: 76 | self.disable_statefulness() 77 | 78 | 79 | class ModuleList(nn.ModuleList, Module): 80 | pass 81 | 82 | 83 | class ModuleDict(nn.ModuleDict, Module): 84 | pass 85 | -------------------------------------------------------------------------------- /models/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .transformer import * 2 | from .encoders import * 3 | from .decoders import * 4 | from .attention import * 5 | 
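The `Module` class in `models/containers.py` above is what lets `BeamSearch` carry per-step decoding state: tensors registered with `register_state` become buffers that are re-initialized and expanded to the batch size when `statefulness` is entered, can be read and reassigned by name while decoding, and are restored to their defaults when the context exits (this is how `MeshedDecoder` keeps `running_mask_self_attention` and `running_seq`). The toy module below is a hypothetical illustration of that contract, not code taken from the repository.

```
import torch
from models.containers import Module

class Counter(Module):
    def __init__(self):
        super().__init__()
        # default state of shape (1,); expanded to (batch_size, 1) inside statefulness()
        self.register_state('steps', torch.zeros((1,)).long())

    def forward(self):
        if self._is_stateful:
            self.steps = self.steps + 1  # reassignment updates the registered buffer
        return self.steps

m = Counter()
with m.statefulness(batch_size=4):
    m()
    out = m()            # shape (4, 1), every entry equal to 2
print(out.shape, int(out[0]))
print(m.steps)           # back to the (1,)-shaped zero default after the context exits
```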
-------------------------------------------------------------------------------- /models/transformer/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weixmath/CVR/af2ea6859230c1a5f56868a2242c06c2fe02f05a/models/transformer/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/transformer/__pycache__/attention.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weixmath/CVR/af2ea6859230c1a5f56868a2242c06c2fe02f05a/models/transformer/__pycache__/attention.cpython-36.pyc -------------------------------------------------------------------------------- /models/transformer/__pycache__/decoders.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weixmath/CVR/af2ea6859230c1a5f56868a2242c06c2fe02f05a/models/transformer/__pycache__/decoders.cpython-36.pyc -------------------------------------------------------------------------------- /models/transformer/__pycache__/encoders.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weixmath/CVR/af2ea6859230c1a5f56868a2242c06c2fe02f05a/models/transformer/__pycache__/encoders.cpython-36.pyc -------------------------------------------------------------------------------- /models/transformer/__pycache__/transformer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weixmath/CVR/af2ea6859230c1a5f56868a2242c06c2fe02f05a/models/transformer/__pycache__/transformer.cpython-36.pyc -------------------------------------------------------------------------------- /models/transformer/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weixmath/CVR/af2ea6859230c1a5f56868a2242c06c2fe02f05a/models/transformer/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /models/transformer/decoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | import numpy as np 5 | 6 | from models.transformer.attention import MultiHeadAttention 7 | from models.transformer.utils import sinusoid_encoding_table, PositionWiseFeedForward 8 | from models.containers import Module, ModuleList 9 | 10 | 11 | class MeshedDecoderLayer(Module): 12 | def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, self_att_module=None, 13 | enc_att_module=None, self_att_module_kwargs=None, enc_att_module_kwargs=None): 14 | super(MeshedDecoderLayer, self).__init__() 15 | self.self_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=True, 16 | attention_module=self_att_module, 17 | attention_module_kwargs=self_att_module_kwargs) 18 | self.enc_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=False, 19 | attention_module=enc_att_module, 20 | attention_module_kwargs=enc_att_module_kwargs) 21 | self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout) 22 | 23 | self.fc_alpha1 = nn.Linear(d_model + d_model, d_model) 24 | self.fc_alpha2 = nn.Linear(d_model + d_model, d_model) 25 | self.fc_alpha3 = 
nn.Linear(d_model + d_model, d_model) 26 | 27 | self.init_weights() 28 | 29 | def init_weights(self): 30 | nn.init.xavier_uniform_(self.fc_alpha1.weight) 31 | nn.init.xavier_uniform_(self.fc_alpha2.weight) 32 | nn.init.xavier_uniform_(self.fc_alpha3.weight) 33 | nn.init.constant_(self.fc_alpha1.bias, 0) 34 | nn.init.constant_(self.fc_alpha2.bias, 0) 35 | nn.init.constant_(self.fc_alpha3.bias, 0) 36 | 37 | def forward(self, input, enc_output, mask_pad, mask_self_att, mask_enc_att): 38 | self_att = self.self_att(input, input, input, mask_self_att) 39 | self_att = self_att * mask_pad 40 | 41 | enc_att1 = self.enc_att(self_att, enc_output[:, 0], enc_output[:, 0], mask_enc_att) * mask_pad 42 | enc_att2 = self.enc_att(self_att, enc_output[:, 1], enc_output[:, 1], mask_enc_att) * mask_pad 43 | enc_att3 = self.enc_att(self_att, enc_output[:, 2], enc_output[:, 2], mask_enc_att) * mask_pad 44 | 45 | alpha1 = torch.sigmoid(self.fc_alpha1(torch.cat([self_att, enc_att1], -1))) 46 | alpha2 = torch.sigmoid(self.fc_alpha2(torch.cat([self_att, enc_att2], -1))) 47 | alpha3 = torch.sigmoid(self.fc_alpha3(torch.cat([self_att, enc_att3], -1))) 48 | 49 | enc_att = (enc_att1 * alpha1 + enc_att2 * alpha2 + enc_att3 * alpha3) / np.sqrt(3) 50 | enc_att = enc_att * mask_pad 51 | 52 | ff = self.pwff(enc_att) 53 | ff = ff * mask_pad 54 | return ff 55 | 56 | 57 | class MeshedDecoder(Module): 58 | def __init__(self, vocab_size, max_len, N_dec, padding_idx, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, 59 | self_att_module=None, enc_att_module=None, self_att_module_kwargs=None, enc_att_module_kwargs=None): 60 | super(MeshedDecoder, self).__init__() 61 | self.d_model = d_model 62 | self.word_emb = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx) 63 | self.pos_emb = nn.Embedding.from_pretrained(sinusoid_encoding_table(max_len + 1, d_model, 0), freeze=True) 64 | self.layers = ModuleList( 65 | [MeshedDecoderLayer(d_model, d_k, d_v, h, d_ff, dropout, self_att_module=self_att_module, 66 | enc_att_module=enc_att_module, self_att_module_kwargs=self_att_module_kwargs, 67 | enc_att_module_kwargs=enc_att_module_kwargs) for _ in range(N_dec)]) 68 | self.fc = nn.Linear(d_model, vocab_size, bias=False) 69 | self.max_len = max_len 70 | self.padding_idx = padding_idx 71 | self.N = N_dec 72 | 73 | self.register_state('running_mask_self_attention', torch.zeros((1, 1, 0)).byte()) 74 | self.register_state('running_seq', torch.zeros((1,)).long()) 75 | 76 | def forward(self, input, encoder_output, mask_encoder): 77 | # input (b_s, seq_len) 78 | b_s, seq_len = input.shape[:2] 79 | mask_queries = (input != self.padding_idx).unsqueeze(-1).float() # (b_s, seq_len, 1) 80 | mask_self_attention = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8, device=input.device), 81 | diagonal=1) 82 | mask_self_attention = mask_self_attention.unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, seq_len) 83 | mask_self_attention = mask_self_attention + (input == self.padding_idx).unsqueeze(1).unsqueeze(1).byte() 84 | mask_self_attention = mask_self_attention.gt(0) # (b_s, 1, seq_len, seq_len) 85 | if self._is_stateful: 86 | self.running_mask_self_attention = torch.cat([self.running_mask_self_attention, mask_self_attention], -1) 87 | mask_self_attention = self.running_mask_self_attention 88 | 89 | seq = torch.arange(1, seq_len + 1).view(1, -1).expand(b_s, -1).to(input.device) # (b_s, seq_len) 90 | seq = seq.masked_fill(mask_queries.squeeze(-1) == 0, 0) 91 | if self._is_stateful: 92 | self.running_seq.add_(1) 93 | seq = self.running_seq 94 | 
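        # When the decoder is not stateful, the causal mask and position indices above cover the
        # full (b_s, seq_len) caption. Inside `statefulness(...)` (used by BeamSearch), seq_len is
        # 1 at every call, so `running_mask_self_attention` grows by one key position per step and
        # `running_seq` counts the current time step: the newly fed token gets the correct causal
        # mask and positional embedding without recomputing earlier steps.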
95 | out = self.word_emb(input) + self.pos_emb(seq) 96 | for i, l in enumerate(self.layers): 97 | out = l(out, encoder_output, mask_queries, mask_self_attention, mask_encoder) 98 | 99 | out = self.fc(out) 100 | return F.log_softmax(out, dim=-1) 101 | -------------------------------------------------------------------------------- /models/transformer/encoders.py: -------------------------------------------------------------------------------- 1 | from torch.nn import functional as F 2 | from models.transformer.utils import PositionWiseFeedForward,PositionWiseFeedForward_BN 3 | import torch 4 | from torch import nn 5 | from models.transformer.attention import MultiHeadAttention,MultiHeadAttention_BN,MultiHeadAttention_BN2 6 | 7 | class EncoderLayer(nn.Module): 8 | def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, identity_map_reordering=False, 9 | attention_module=None, attention_module_kwargs=None): 10 | super(EncoderLayer, self).__init__() 11 | self.identity_map_reordering = identity_map_reordering 12 | self.mhatt = MultiHeadAttention(d_model, d_k, d_v, h, dropout, identity_map_reordering=identity_map_reordering, 13 | attention_module=attention_module, 14 | attention_module_kwargs=attention_module_kwargs) 15 | self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout, identity_map_reordering=identity_map_reordering) 16 | 17 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None): 18 | att = self.mhatt(queries, keys, values, attention_mask, attention_weights) 19 | # ff = self.pwff(att) 20 | return att 21 | class EncoderLayer_BN(nn.Module): 22 | def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, identity_map_reordering=False, 23 | attention_module=None, attention_module_kwargs=None): 24 | super(EncoderLayer_BN, self).__init__() 25 | self.identity_map_reordering = identity_map_reordering 26 | self.mhatt = MultiHeadAttention_BN(d_model, d_k, d_v, h, dropout, identity_map_reordering=identity_map_reordering, 27 | attention_module=attention_module, 28 | attention_module_kwargs=attention_module_kwargs) 29 | self.pwff = PositionWiseFeedForward_BN(d_model, d_ff, dropout, identity_map_reordering=identity_map_reordering) 30 | 31 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None): 32 | att = self.mhatt(queries, keys, values, attention_mask, attention_weights) 33 | # ff = self.pwff(att) 34 | return att 35 | class EncoderLayer_BN2(nn.Module): 36 | def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, identity_map_reordering=False, 37 | attention_module=None, attention_module_kwargs=None): 38 | super(EncoderLayer_BN2, self).__init__() 39 | self.identity_map_reordering = identity_map_reordering 40 | self.mhatt = MultiHeadAttention_BN2(d_model, d_k, d_v, h, dropout, identity_map_reordering=identity_map_reordering, 41 | attention_module=attention_module, 42 | attention_module_kwargs=attention_module_kwargs) 43 | self.pwff = PositionWiseFeedForward_BN(d_model, d_ff, dropout, identity_map_reordering=identity_map_reordering) 44 | 45 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None): 46 | att = self.mhatt(queries, keys, values, attention_mask, attention_weights) 47 | # ff = self.pwff(att) 48 | return att 49 | class MultiLevelEncoder(nn.Module): 50 | def __init__(self, N, padding_idx, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, 51 | identity_map_reordering=False, attention_module=None, attention_module_kwargs=None): 52 | 
super(MultiLevelEncoder, self).__init__() 53 | self.d_model = d_model 54 | self.dropout = dropout 55 | self.layers = nn.ModuleList([EncoderLayer(d_model, d_k, d_v, h, d_ff, dropout, 56 | identity_map_reordering=identity_map_reordering, 57 | attention_module=attention_module, 58 | attention_module_kwargs=attention_module_kwargs) 59 | for _ in range(N)]) 60 | self.padding_idx = padding_idx 61 | 62 | def forward(self, input, attention_weights=None): 63 | # input (b_s, seq_len, d_in) 64 | attention_mask = (torch.sum(input, -1) == self.padding_idx).unsqueeze(1).unsqueeze(1) # (b_s, 1, 1, seq_len) 65 | 66 | outs = [] 67 | out = input 68 | for l in self.layers: 69 | out = l(out, out, out, attention_mask, attention_weights) 70 | outs.append(out.unsqueeze(1)) 71 | 72 | outs = torch.cat(outs, 1) 73 | return outs, attention_mask 74 | 75 | 76 | class MemoryAugmentedEncoder(MultiLevelEncoder): 77 | def __init__(self, N, padding_idx, d_in=2048, **kwargs): 78 | super(MemoryAugmentedEncoder, self).__init__(N, padding_idx, **kwargs) 79 | self.fc = nn.Linear(d_in, self.d_model) 80 | self.dropout = nn.Dropout(p=self.dropout) 81 | self.layer_norm = nn.LayerNorm(self.d_model) 82 | 83 | def forward(self, input, attention_weights=None): 84 | out = F.relu(self.fc(input)) 85 | out = self.dropout(out) 86 | out = self.layer_norm(out) 87 | return super(MemoryAugmentedEncoder, self).forward(out, attention_weights=attention_weights) 88 | -------------------------------------------------------------------------------- /models/transformer/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import copy 4 | from models.containers import ModuleList 5 | from ..captioning_model import CaptioningModel 6 | 7 | 8 | class Transformer(CaptioningModel): 9 | def __init__(self, bos_idx, encoder, decoder): 10 | super(Transformer, self).__init__() 11 | self.bos_idx = bos_idx 12 | self.encoder = encoder 13 | self.decoder = decoder 14 | self.register_state('enc_output', None) 15 | self.register_state('mask_enc', None) 16 | self.init_weights() 17 | 18 | @property 19 | def d_model(self): 20 | return self.decoder.d_model 21 | 22 | def init_weights(self): 23 | for p in self.parameters(): 24 | if p.dim() > 1: 25 | nn.init.xavier_uniform_(p) 26 | 27 | def forward(self, images, seq, *args): 28 | enc_output, mask_enc = self.encoder(images) 29 | dec_output = self.decoder(seq, enc_output, mask_enc) 30 | return dec_output 31 | 32 | def init_state(self, b_s, device): 33 | return [torch.zeros((b_s, 0), dtype=torch.long, device=device), 34 | None, None] 35 | 36 | def step(self, t, prev_output, visual, seq, mode='teacher_forcing', **kwargs): 37 | it = None 38 | if mode == 'teacher_forcing': 39 | raise NotImplementedError 40 | elif mode == 'feedback': 41 | if t == 0: 42 | self.enc_output, self.mask_enc = self.encoder(visual) 43 | if isinstance(visual, torch.Tensor): 44 | it = visual.data.new_full((visual.shape[0], 1), self.bos_idx).long() 45 | else: 46 | it = visual[0].data.new_full((visual[0].shape[0], 1), self.bos_idx).long() 47 | else: 48 | it = prev_output 49 | 50 | return self.decoder(it, self.enc_output, self.mask_enc) 51 | 52 | 53 | class TransformerEnsemble(CaptioningModel): 54 | def __init__(self, model: Transformer, weight_files): 55 | super(TransformerEnsemble, self).__init__() 56 | self.n = len(weight_files) 57 | self.models = ModuleList([copy.deepcopy(model) for _ in range(self.n)]) 58 | for i in range(self.n): 59 | state_dict_i = 
torch.load(weight_files[i])['state_dict'] 60 | self.models[i].load_state_dict(state_dict_i) 61 | 62 | def step(self, t, prev_output, visual, seq, mode='teacher_forcing', **kwargs): 63 | out_ensemble = [] 64 | for i in range(self.n): 65 | out_i = self.models[i].step(t, prev_output, visual, seq, mode, **kwargs) 66 | out_ensemble.append(out_i.unsqueeze(0)) 67 | 68 | return torch.mean(torch.cat(out_ensemble, 0), dim=0) 69 | -------------------------------------------------------------------------------- /models/transformer/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | def position_embedding(input, d_model): 7 | input = input.view(-1, 1) 8 | dim = torch.arange(d_model // 2, dtype=torch.float32, device=input.device).view(1, -1) 9 | sin = torch.sin(input / 10000 ** (2 * dim / d_model)) 10 | cos = torch.cos(input / 10000 ** (2 * dim / d_model)) 11 | 12 | out = torch.zeros((input.shape[0], d_model), device=input.device) 13 | out[:, ::2] = sin 14 | out[:, 1::2] = cos 15 | return out 16 | 17 | 18 | def sinusoid_encoding_table(max_len, d_model, padding_idx=None): 19 | pos = torch.arange(max_len, dtype=torch.float32) 20 | out = position_embedding(pos, d_model) 21 | 22 | if padding_idx is not None: 23 | out[padding_idx] = 0 24 | return out 25 | 26 | 27 | class PositionWiseFeedForward(nn.Module): 28 | ''' 29 | Position-wise feed forward layer 30 | ''' 31 | 32 | def __init__(self, d_model=512, d_ff=2048, dropout=.1, identity_map_reordering=False): 33 | super(PositionWiseFeedForward, self).__init__() 34 | self.identity_map_reordering = identity_map_reordering 35 | self.fc1 = nn.Linear(d_model, d_ff) 36 | self.fc2 = nn.Linear(d_ff, d_model) 37 | self.dropout = nn.Dropout(p=dropout) 38 | self.dropout_2 = nn.Dropout(p=dropout) 39 | self.layer_norm = nn.LayerNorm(d_model) 40 | 41 | def forward(self, input): 42 | if self.identity_map_reordering: 43 | out = self.layer_norm(input) 44 | out = self.fc2(self.dropout_2(F.relu(self.fc1(out)))) 45 | out = input + self.dropout(torch.relu(out)) 46 | else: 47 | out = self.fc2(self.dropout_2(F.relu(self.fc1(input)))) 48 | # out = self.fc2(self.dropout_2(F.gelu(self.fc1(input)))) 49 | out = self.dropout(out) 50 | out = input + out 51 | # out = self.layer_norm(input + out) 52 | return out 53 | 54 | class PositionWiseFeedForward_BN(nn.Module): 55 | ''' 56 | Position-wise feed forward layer 57 | ''' 58 | 59 | def __init__(self, d_model=512, d_ff=2048, dropout=.1, identity_map_reordering=False): 60 | super(PositionWiseFeedForward_BN, self).__init__() 61 | self.identity_map_reordering = identity_map_reordering 62 | self.fc1 = nn.Linear(d_model, d_ff) 63 | self.fc2 = nn.Linear(d_ff, d_model) 64 | self.dropout = nn.Dropout(p=dropout) 65 | self.dropout_2 = nn.Dropout(p=dropout) 66 | self.layer_norm = nn.LayerNorm(d_model) 67 | self.bn = nn.BatchNorm1d(d_ff) 68 | def forward(self, input): 69 | if self.identity_map_reordering: 70 | out = self.layer_norm(input) 71 | out = self.fc2(self.dropout_2(F.relu(self.fc1(out)))) 72 | out = input + self.dropout(torch.relu(out)) 73 | else: 74 | # out = self.fc2(self.dropout_2(F.relu(self.fc1(input)))) 75 | # out = self.dropout(out) 76 | # out = (input + out).tranpose(1,2) 77 | # out = self.bn(out).transpose(1,2) 78 | # out = self.layer_norm(input + out) 79 | out = self.fc1(input).transpose(1,2) 80 | out = F.relu(self.bn(out).transpose(1,2)) 81 | out = self.fc2(out)+input 82 | 83 | return out 84 | 85 | class 
PositionWiseFeedForward_LN(nn.Module): 86 | ''' 87 | Position-wise feed forward layer 88 | ''' 89 | 90 | def __init__(self, d_model=512, d_ff=2048, dropout=.1, identity_map_reordering=False): 91 | super(PositionWiseFeedForward_LN, self).__init__() 92 | self.identity_map_reordering = identity_map_reordering 93 | self.fc1 = nn.Linear(d_model, d_ff) 94 | self.fc2 = nn.Linear(d_ff, d_model) 95 | self.dropout = nn.Dropout(p=dropout) 96 | self.dropout_2 = nn.Dropout(p=dropout) 97 | self.layer_norm = nn.LayerNorm(d_model) 98 | 99 | def forward(self, input): 100 | if self.identity_map_reordering: 101 | out = self.layer_norm(input) 102 | out = self.fc2(self.dropout_2(F.relu(self.fc1(out)))) 103 | out = input + self.dropout(torch.relu(out)) 104 | else: 105 | out = self.fc2(self.dropout_2(F.relu(self.fc1(input)))) 106 | # out = self.fc2(self.dropout_2(F.gelu(self.fc1(input)))) 107 | out = self.dropout(out) 108 | # out = input + out 109 | out = self.layer_norm(input + out) 110 | return out -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | def position_embedding(input, d_model): 7 | input = input.view(-1, 1) 8 | dim = torch.arange(d_model // 2, dtype=torch.float32, device=input.device).view(1, -1) 9 | sin = torch.sin(input / 10000 ** (2 * dim / d_model)) 10 | cos = torch.cos(input / 10000 ** (2 * dim / d_model)) 11 | 12 | out = torch.zeros((input.shape[0], d_model), device=input.device) 13 | out[:, ::2] = sin 14 | out[:, 1::2] = cos 15 | return out 16 | 17 | 18 | def sinusoid_encoding_table(max_len, d_model, padding_idx=None): 19 | pos = torch.arange(max_len, dtype=torch.float32) 20 | out = position_embedding(pos, d_model) 21 | 22 | if padding_idx is not None: 23 | out[padding_idx] = 0 24 | return out 25 | 26 | 27 | class PositionWiseFeedForward(nn.Module): 28 | ''' 29 | Position-wise feed forward layer 30 | ''' 31 | 32 | def __init__(self, d_model=512, d_ff=2048, dropout=.1, identity_map_reordering=False): 33 | super(PositionWiseFeedForward, self).__init__() 34 | self.identity_map_reordering = identity_map_reordering 35 | self.fc1 = nn.Linear(d_model, d_ff) 36 | self.fc2 = nn.Linear(d_ff, d_model) 37 | self.dropout = nn.Dropout(p=dropout) 38 | self.dropout_2 = nn.Dropout(p=dropout) 39 | self.layer_norm = nn.LayerNorm(d_model) 40 | 41 | def forward(self, input): 42 | if self.identity_map_reordering: 43 | out = self.layer_norm(input) 44 | out = self.fc2(self.dropout_2(F.relu(self.fc1(out)))) 45 | out = input + self.dropout(torch.relu(out)) 46 | else: 47 | out = self.fc2(self.dropout_2(F.relu(self.fc1(input)))) 48 | out = self.dropout(out) 49 | out = self.layer_norm(out+input) 50 | return out -------------------------------------------------------------------------------- /models/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import download_from_url 2 | from .typing import * 3 | 4 | def get_batch_size(x: TensorOrSequence) -> int: 5 | if isinstance(x, torch.Tensor): 6 | b_s = x.size(0) 7 | else: 8 | b_s = x[0].size(0) 9 | return b_s 10 | 11 | 12 | def get_device(x: TensorOrSequence) -> int: 13 | if isinstance(x, torch.Tensor): 14 | b_s = x.device 15 | else: 16 | b_s = x[0].device 17 | return b_s 18 | -------------------------------------------------------------------------------- 
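`sinusoid_encoding_table` above (defined in `models/transformer/utils.py` and duplicated in `models/utils.py`) is the table that `MeshedDecoder` freezes into its positional embedding, with row 0 zeroed out for padding. Below is a small, hypothetical shape check mirroring the decoder's own construction; the `max_len` value here is arbitrary.

```
import torch
from torch import nn
from models.transformer.utils import sinusoid_encoding_table

max_len, d_model = 20, 512
table = sinusoid_encoding_table(max_len + 1, d_model, padding_idx=0)
print(table.shape)                  # torch.Size([21, 512])
print(float(table[0].abs().sum()))  # 0.0 -- the padding row is zeroed

# Same construction as in MeshedDecoder.__init__: a frozen embedding indexed by position.
pos_emb = nn.Embedding.from_pretrained(table, freeze=True)
positions = torch.arange(1, 11).unsqueeze(0)   # positions 1..10 for a length-10 caption
print(pos_emb(positions).shape)                # torch.Size([1, 10, 512])
```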
/models/utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weixmath/CVR/af2ea6859230c1a5f56868a2242c06c2fe02f05a/models/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/utils/__pycache__/typing.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weixmath/CVR/af2ea6859230c1a5f56868a2242c06c2fe02f05a/models/utils/__pycache__/typing.cpython-36.pyc -------------------------------------------------------------------------------- /models/utils/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weixmath/CVR/af2ea6859230c1a5f56868a2242c06c2fe02f05a/models/utils/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /models/utils/typing.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Sequence, Tuple 2 | import torch 3 | 4 | TensorOrSequence = Union[Sequence[torch.Tensor], torch.Tensor] 5 | TensorOrNone = Union[torch.Tensor, None] 6 | -------------------------------------------------------------------------------- /models/utils/utils.py: -------------------------------------------------------------------------------- 1 | import requests 2 | 3 | def download_from_url(url, path): 4 | """Download file, with logic (from tensor2tensor) for Google Drive""" 5 | if 'drive.google.com' not in url: 6 | print('Downloading %s; may take a few minutes' % url) 7 | r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}) 8 | with open(path, "wb") as file: 9 | file.write(r.content) 10 | return 11 | print('Downloading from Google Drive; may take a few minutes') 12 | confirm_token = None 13 | session = requests.Session() 14 | response = session.get(url, stream=True) 15 | for k, v in response.cookies.items(): 16 | if k.startswith("download_warning"): 17 | confirm_token = v 18 | 19 | if confirm_token: 20 | url = url + "&confirm=" + confirm_token 21 | response = session.get(url, stream=True) 22 | 23 | chunk_size = 16 * 1024 24 | with open(path, "wb") as f: 25 | for chunk in response.iter_content(chunk_size): 26 | if chunk: 27 | f.write(chunk) 28 | -------------------------------------------------------------------------------- /otk/cross_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from einops import rearrange, repeat 4 | from torch import nn 5 | 6 | MIN_NUM_PATCHES = 16 7 | 8 | class Residual(nn.Module): 9 | def __init__(self, fn): 10 | super().__init__() 11 | self.fn = fn 12 | def forward(self, x, **kwargs): 13 | return self.fn(x, **kwargs) + x 14 | 15 | class PreNorm(nn.Module): 16 | def __init__(self, dim, fn): 17 | super().__init__() 18 | self.norm = nn.LayerNorm(dim) 19 | self.fn = fn 20 | def forward(self, x, **kwargs): 21 | return self.fn(self.norm(x), **kwargs) 22 | 23 | class FeedForward(nn.Module): 24 | def __init__(self, dim, hidden_dim, dropout = 0.): 25 | super().__init__() 26 | self.net = nn.Sequential( 27 | nn.Linear(dim, hidden_dim), 28 | nn.GELU(), 29 | nn.Dropout(dropout), 30 | nn.Linear(hidden_dim, dim), 31 | nn.Dropout(dropout) 32 | ) 33 | def forward(self, x): 34 | return self.net(x) 35 | class 
Attention_entropy(nn.Module): 36 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 37 | super().__init__() 38 | inner_dim = dim_head * heads 39 | self.heads = heads 40 | self.scale = dim_head ** -0.5 41 | 42 | self.to_qk = nn.Linear(dim, inner_dim * 2, bias = False) 43 | self.to_v = nn.Linear(dim,inner_dim,bias=False) 44 | self.to_out = nn.Sequential( 45 | nn.Linear(inner_dim, dim), 46 | nn.Dropout(dropout) 47 | ) 48 | 49 | def forward(self, x, mask = None,mm=None): 50 | b, n, _, h = *x.shape, self.heads 51 | qk = self.to_qk(x) 52 | v = self.to_v(x) 53 | qkv = torch.cat((qk,v),-1) 54 | qkv = qkv.chunk(3, dim = -1) 55 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) 56 | 57 | dots = torch.einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 58 | mask_value = -torch.finfo(dots.dtype).max 59 | 60 | if mask is not None: 61 | # mask = F.pad(mask.flatten(1), (1, 0), value = True) 62 | assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions' 63 | mask = rearrange(mask, 'b i -> b () i ()') * rearrange(mask, 'b j -> b () () j') 64 | dots.masked_fill_(~mask, mask_value) 65 | del mask 66 | mm = mm.unsqueeze(1).repeat(b,h,1,1) 67 | dots = dots + mm 68 | attn = dots.softmax(dim=-1) 69 | out = torch.einsum('b h i j, b h j d -> b h i d', attn, v) 70 | out = rearrange(out, 'b h n d -> b n (h d)') 71 | out = self.to_out(out) 72 | return out 73 | class Attention(nn.Module): 74 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 75 | super().__init__() 76 | inner_dim = dim_head * heads 77 | self.heads = heads 78 | self.scale = dim_head ** -0.5 79 | 80 | self.to_qk = nn.Linear(dim, inner_dim * 2, bias = False) 81 | self.to_v = nn.Linear(dim,inner_dim,bias=False) 82 | self.to_out = nn.Sequential( 83 | nn.Linear(inner_dim, dim), 84 | nn.Dropout(dropout) 85 | ) 86 | 87 | def forward(self, x, mask = None,mm=None): 88 | b, n, _, h = *x.shape, self.heads 89 | qk = self.to_qk(x+mm) 90 | v = self.to_v(x) 91 | qkv = torch.cat((qk,v),-1) 92 | qkv = qkv.chunk(3, dim = -1) 93 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) 94 | 95 | dots = torch.einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 96 | mask_value = -torch.finfo(dots.dtype).max 97 | 98 | if mask is not None: 99 | # mask = F.pad(mask.flatten(1), (1, 0), value = True) 100 | assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions' 101 | mask = rearrange(mask, 'b i -> b () i ()') * rearrange(mask, 'b j -> b () () j') 102 | dots.masked_fill_(~mask, mask_value) 103 | del mask 104 | 105 | attn = dots.softmax(dim=-1) 106 | 107 | out = torch.einsum('b h i j, b h j d -> b h i d', attn, v) 108 | out = rearrange(out, 'b h n d -> b n (h d)') 109 | out = self.to_out(out) 110 | return out 111 | 112 | class Transformer(nn.Module): 113 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout): 114 | super().__init__() 115 | self.layers = nn.ModuleList([]) 116 | for _ in range(depth): 117 | self.layers.append(nn.ModuleList([ 118 | Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))), 119 | Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))) 120 | ])) 121 | def forward(self, x, mask = None,mm=None): 122 | for attn, ff in self.layers: 123 | x = attn(x, mask = mask,mm=mm) 124 | x = ff(x) 125 | return x 126 | class Transformer_noff(nn.Module): 127 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout): 128 | super().__init__() 129 | self.layers = nn.ModuleList([]) 130 | 
for _ in range(depth): 131 | self.layers.append(nn.ModuleList([ 132 | Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))), 133 | Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))) 134 | ])) 135 | def forward(self, x, mask = None,mm=None): 136 | for attn, ff in self.layers: 137 | x = attn(x, mask = mask,mm=mm) 138 | # x = ff(x) 139 | return x 140 | 141 | class ViT(nn.Module): 142 | def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.): 143 | super().__init__() 144 | assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.' 145 | num_patches = (image_size // patch_size) ** 2 146 | patch_dim = channels * patch_size ** 2 147 | assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective (at least 16). Try decreasing your patch size' 148 | assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' 149 | 150 | self.patch_size = patch_size 151 | 152 | self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) 153 | self.patch_to_embedding = nn.Linear(patch_dim, dim) 154 | self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) 155 | self.dropout = nn.Dropout(emb_dropout) 156 | 157 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) 158 | 159 | self.pool = pool 160 | self.to_latent = nn.Identity() 161 | 162 | self.mlp_head = nn.Sequential( 163 | nn.LayerNorm(dim), 164 | nn.Linear(dim, num_classes) 165 | ) 166 | 167 | def forward(self, img, mask = None): 168 | p = self.patch_size 169 | 170 | x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p) 171 | x = self.patch_to_embedding(x) 172 | b, n, _ = x.shape 173 | 174 | cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) 175 | x = torch.cat((cls_tokens, x), dim=1) 176 | x += self.pos_embedding[:, :(n + 1)] 177 | x = self.dropout(x) 178 | 179 | x = self.transformer(x, mask) 180 | 181 | x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0] 182 | 183 | x = self.to_latent(x) 184 | return self.mlp_head(x) -------------------------------------------------------------------------------- /otk/data_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import h5py 3 | import scipy.io as sio 4 | import torch 5 | from torch.utils.data import Dataset 6 | 7 | 8 | def _load_mat_file(filepath, sequence_key, targets_key=None): 9 | """ 10 | Loads data from a `*.mat` file or a `*.h5` file. 11 | Parameters 12 | ---------- 13 | filepath : str 14 | The path to the file to load the data from. 15 | sequence_key : str 16 | The key for the sequences data matrix. 17 | targets_key : str, optional 18 | Default is None. The key for the targets data matrix. 19 | Returns 20 | ------- 21 | (sequences, targets, h5py_filehandle) : \ 22 | tuple(array-like, array-like, h5py.File) 23 | If the matrix files can be loaded with `scipy.io`, 24 | the tuple will only be (sequences, targets). Otherwise, 25 | the 2 matrices and the h5py file handle are returned. 
26 | """ 27 | try: # see if we can load the file using scipy first 28 | mat = sio.loadmat(filepath) 29 | targets = None 30 | if targets_key: 31 | targets = mat[targets_key] 32 | return (mat[sequence_key], targets) 33 | except (NotImplementedError, ValueError): 34 | mat = h5py.File(filepath, 'r') 35 | sequences = mat[sequence_key] 36 | targets = None 37 | if targets_key: 38 | targets = mat[targets_key] 39 | return (sequences, targets, mat) 40 | 41 | 42 | class MatDataset(Dataset): 43 | def __init__(self, filepath, split='train'): 44 | super().__init__() 45 | filepath = filepath + "/{}.mat".format(split) 46 | sequence_key = "{}xdata".format(split) 47 | targets_key = "{}data".format(split) 48 | out = _load_mat_file(filepath, sequence_key, targets_key) 49 | self.data_tensor = out[0] 50 | self.target_tensor = out[1] 51 | self.split = split 52 | 53 | def __getitem__(self, index): 54 | if self.split == "train": 55 | data_tensor = self.data_tensor[:, :, index] 56 | data_tensor = data_tensor.transpose().astype('float32') 57 | target_tensor = self.target_tensor[:, index].astype('float32') 58 | else: 59 | data_tensor = self.data_tensor[index].astype('float32') 60 | target_tensor = self.target_tensor[index].astype('float32') 61 | data_tensor = torch.from_numpy(data_tensor) 62 | target_tensor = torch.from_numpy(target_tensor) 63 | return data_tensor, target_tensor 64 | 65 | def __len__(self): 66 | if self.split == 'train': 67 | return self.target_tensor.shape[1] 68 | return self.target_tensor.shape[0] 69 | -------------------------------------------------------------------------------- /otk/layers.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import math 4 | from torch import nn 5 | import torch.optim as optim 6 | from .utils import spherical_kmeans, normalize 7 | from .sinkhorn import wasserstein_kmeans, multihead_attn 8 | import numpy as np 9 | 10 | 11 | class OTKernel(nn.Module): 12 | def __init__(self, in_dim, out_size, heads=1, eps=0.1, max_iter=100, 13 | log_domain=False, position_encoding=None, position_sigma=0.1): 14 | super().__init__() 15 | self.in_dim = in_dim 16 | self.out_size = out_size 17 | self.heads = heads 18 | self.eps = eps 19 | self.max_iter = max_iter 20 | 21 | self.weight = nn.Parameter( 22 | torch.Tensor(heads, out_size, in_dim)) 23 | 24 | self.log_domain = log_domain 25 | self.position_encoding = position_encoding 26 | self.position_sigma = position_sigma 27 | 28 | self.reset_parameter() 29 | nn.init.xavier_normal_(self.weight) 30 | # nn.init.kaiming_normal_(self.weight) 31 | def reset_parameter(self): 32 | stdv = 1. 
/ math.sqrt(self.out_size) 33 | for w in self.parameters(): 34 | w.data.uniform_(-stdv, stdv) 35 | # def reset_parameter(self): 36 | # for w in self.parameters(): 37 | def get_position_filter(self, input, out_size): 38 | if input.ndim == 4: 39 | in_size1 = input.shape[1] 40 | in_size2 = input.shape[2] 41 | out_size = int(math.sqrt(out_size)) 42 | if self.position_encoding is None: 43 | return self.position_encoding 44 | elif self.position_encoding == "gaussian": 45 | sigma = self.position_sigma 46 | a1 = torch.arange(1., in_size1 + 1.).view(-1, 1) / in_size1 47 | a2 = torch.arange(1., in_size2 + 1.).view(-1, 1) / in_size2 48 | b = torch.arange(1., out_size + 1.).view(1, -1) / out_size 49 | position_filter1 = torch.exp(-((a1 - b) / sigma) ** 2) 50 | position_filter2 = torch.exp(-((a2 - b) / sigma) ** 2) 51 | position_filter = position_filter1.view( 52 | in_size1, 1, out_size, 1) * position_filter2.view( 53 | 1, in_size2, 1, out_size) 54 | if self.weight.is_cuda: 55 | position_filter = position_filter.cuda() 56 | return position_filter.reshape(1, 1, in_size1 * in_size2, out_size * out_size) 57 | in_size = input.shape[1] 58 | if self.position_encoding is None: 59 | return self.position_encoding 60 | elif self.position_encoding == "gaussian": 61 | # sigma = 1. / out_size 62 | sigma = self.position_sigma 63 | a = torch.arange(0., in_size).view(-1, 1) / in_size 64 | b = torch.arange(0., out_size).view(1, -1) / out_size 65 | position_filter = torch.exp(-((a - b) / sigma) ** 2) 66 | elif self.position_encoding == "hard": 67 | # sigma = 1. / out_size 68 | sigma = self.position_sigma 69 | a = torch.arange(0., in_size).view(-1, 1) / in_size 70 | b = torch.arange(0., out_size).view(1, -1) / out_size 71 | position_filter = torch.abs(a - b) < sigma 72 | position_filter = position_filter.float() 73 | else: 74 | raise ValueError("Unrecognizied position encoding") 75 | if self.weight.is_cuda: 76 | position_filter = position_filter.cuda() 77 | position_filter = position_filter.view(1, 1, in_size, out_size) 78 | return position_filter 79 | 80 | def get_attn(self, input, mask=None, position_filter=None): 81 | """Compute the attention weight using Sinkhorn OT 82 | input: batch_size x in_size x in_dim 83 | mask: batch_size x in_size 84 | self.weight: heads x out_size x in_dim 85 | output: batch_size x (out_size x heads) x in_size 86 | """ 87 | return multihead_attn( 88 | input, self.weight, mask=mask, eps=self.eps, 89 | max_iter=self.max_iter, log_domain=self.log_domain, 90 | position_filter=position_filter) 91 | 92 | def forward(self, input, mask=None): 93 | """ 94 | input: batch_size x in_size x in_dim 95 | output: batch_size x out_size x (heads x in_dim) 96 | """ 97 | batch_size = input.shape[0] 98 | position_filter = self.get_position_filter(input, self.out_size) 99 | in_ndim = input.ndim 100 | if in_ndim == 4: 101 | input = input.view(batch_size, -1, self.in_dim) 102 | attn_weight = self.get_attn(input, mask, position_filter) 103 | # attn_weight: batch_size x out_size x heads x in_size 104 | # aa = attn_weight.detach().cpu().numpy() 105 | output = torch.bmm( 106 | attn_weight.view(batch_size, self.out_size * self.heads, -1), input) 107 | if in_ndim == 4: 108 | out_size = int(math.sqrt(self.out_size)) 109 | output = output.reshape(batch_size, out_size, out_size, -1) 110 | else: 111 | output = output.reshape(batch_size, self.out_size, -1) 112 | return output 113 | 114 | def unsup_train(self, input, wb=False, inplace=True, use_cuda=False): 115 | """K-meeans for learning parameters 116 | input: n_samples x 
in_size x in_dim 117 | weight: heads x out_size x in_dim 118 | """ 119 | input_normalized = normalize(input, inplace=inplace) 120 | # input_normalized = input 121 | block_size = int(1e9) // (input.shape[1] * input.shape[2] * 4) 122 | print("Starting Wasserstein K-means") 123 | weight = wasserstein_kmeans( 124 | input_normalized, self.heads, self.out_size, eps=self.eps, 125 | block_size=block_size, wb=wb, log_domain=self.log_domain, use_cuda=use_cuda) 126 | self.weight.data.copy_(weight) 127 | 128 | def random_sample(self, input): 129 | idx = torch.randint(0, input.shape[0], (1,)) 130 | self.weight.data.copy_(input[idx].view_as(self.weight)) 131 | 132 | class Linear(nn.Linear): 133 | def forward(self, input): 134 | bias = self.bias 135 | if bias is not None and hasattr(self, 'scale_bias') and self.scale_bias is not None: 136 | bias = self.scale_bias * bias 137 | out = torch.nn.functional.linear(input, self.weight, bias) 138 | return out 139 | 140 | def fit(self, Xtr, ytr, criterion, reg=0.0, epochs=100, optimizer=None, use_cuda=False): 141 | if optimizer is None: 142 | optimizer = optim.LBFGS(self.parameters(), lr=1.0, history_size=10) 143 | if self.bias is not None: 144 | scale_bias = (Xtr ** 2).mean(-1).sqrt().mean().item() 145 | self.scale_bias = scale_bias 146 | self.train() 147 | if use_cuda: 148 | self.cuda() 149 | Xtr = Xtr.cuda() 150 | ytr = ytr.cuda() 151 | def closure(): 152 | optimizer.zero_grad() 153 | output = self(Xtr) 154 | loss = criterion(output, ytr) 155 | loss = loss + 0.5 * reg * self.weight.pow(2).sum() 156 | loss.backward() 157 | return loss 158 | 159 | for epoch in range(epochs): 160 | optimizer.step(closure) 161 | if self.bias is not None: 162 | self.bias.data.mul_(self.scale_bias) 163 | self.scale_bias = None 164 | 165 | def score(self, X, y): 166 | self.eval() 167 | with torch.no_grad(): 168 | scores = self(X) 169 | scores = scores.argmax(-1) 170 | scores = scores.cpu() 171 | return torch.mean((scores == y).float()).item() 172 | -------------------------------------------------------------------------------- /otk/models.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | from torch import nn 4 | from .layers import OTKernel, Linear 5 | from ckn.layers import BioEmbedding 6 | from ckn.models import CKNSequential 7 | 8 | 9 | class SeqAttention(nn.Module): 10 | def __init__(self, in_channels, nclass, hidden_sizes, filter_sizes, 11 | subsamplings, kernel_args=None, eps=0.1, heads=1, 12 | out_size=1, max_iter=50, alpha=0., fit_bias=True, 13 | mask_zeros=True): 14 | super().__init__() 15 | self.embed_layer = BioEmbedding( 16 | in_channels, False, mask_zeros=True, no_embed=True) 17 | self.ckn_model = CKNSequential( 18 | in_channels, hidden_sizes, filter_sizes, 19 | subsamplings, kernel_args_list=kernel_args) 20 | self.attention = OTKernel(hidden_sizes[-1], out_size, heads=heads, 21 | eps=eps, max_iter=max_iter) 22 | self.out_features = out_size * heads * hidden_sizes[-1] 23 | self.nclass = nclass 24 | 25 | self.classifier = Linear(self.out_features, nclass, bias=fit_bias) 26 | self.alpha = alpha 27 | self.mask_zeros = mask_zeros 28 | 29 | def feature_parameters(self): 30 | import itertools 31 | return itertools.chain(self.ckn_model.parameters(), 32 | self.attention.parameters()) 33 | 34 | def normalize_(self): 35 | self.ckn_model.normalize_() 36 | 37 | def ckn_representation_at(self, input, n=0): 38 | output = self.embed_layer(input) 39 | mask = self.embed_layer.compute_mask(input) 40 | 
output = self.ckn_model.representation(output, n) 41 | mask = self.ckn_model.compute_mask(mask, n) 42 | return output, mask 43 | 44 | def ckn_representation(self, input): 45 | output = self.embed_layer(input) 46 | output = self.ckn_model(output).permute(0, 2, 1).contiguous() 47 | return output 48 | 49 | def representation(self, input): 50 | output = self.embed_layer(input) 51 | mask = self.embed_layer.compute_mask(input) 52 | output = self.ckn_model(output).permute(0, 2, 1).contiguous() 53 | mask = self.ckn_model.compute_mask(mask) 54 | if not self.mask_zeros: 55 | mask = None 56 | output = self.attention(output, mask).reshape(output.shape[0], -1) 57 | return output 58 | 59 | def forward(self, input): 60 | output = self.representation(input) 61 | return self.classifier(output) 62 | 63 | def predict(self, data_loader, only_repr=False, use_cuda=False): 64 | n_samples = len(data_loader.dataset) 65 | target_output = torch.LongTensor(n_samples) 66 | batch_start = 0 67 | for i, (data, target) in enumerate(data_loader): 68 | batch_size = data.shape[0] 69 | if use_cuda: 70 | data = data.cuda() 71 | with torch.no_grad(): 72 | if only_repr: 73 | batch_out = self.representation(data).data.cpu() 74 | else: 75 | batch_out = self(data).data.cpu() 76 | if i == 0: 77 | output = batch_out.new_empty([n_samples] + 78 | list(batch_out.shape[1:])) 79 | output[batch_start:batch_start + batch_size] = batch_out 80 | target_output[batch_start:batch_start + batch_size] = target 81 | batch_start += batch_size 82 | return output, target_output 83 | 84 | def train_classifier(self, data_loader, criterion=None, epochs=100, 85 | optimizer=None, use_cuda=False): 86 | encoded_train, encoded_target = self.predict( 87 | data_loader, only_repr=True, use_cuda=use_cuda) 88 | self.classifier.fit(encoded_train, encoded_target, criterion, 89 | reg=self.alpha, epochs=epochs, optimizer=optimizer, 90 | use_cuda=use_cuda) 91 | 92 | def unsup_train(self, data_loader, n_sampling_patches=300000, 93 | n_samples=5000, wb=False, use_cuda=False): 94 | self.eval() 95 | if use_cuda: 96 | self.cuda() 97 | 98 | for i, ckn_layer in enumerate(self.ckn_model): 99 | print("Training ckn layer {}".format(i)) 100 | n_patches = 0 101 | try: 102 | n_patches_per_batch = ( 103 | n_sampling_patches + len(data_loader) - 1 104 | ) // len(data_loader) 105 | except: 106 | n_patches_per_batch = 1000 107 | patches = torch.Tensor(n_sampling_patches, ckn_layer.patch_dim) 108 | if use_cuda: 109 | patches = patches.cuda() 110 | 111 | for data, _ in data_loader: 112 | if n_patches >= n_sampling_patches: 113 | continue 114 | if use_cuda: 115 | data = data.cuda() 116 | with torch.no_grad(): 117 | data, mask = self.ckn_representation_at(data, i) 118 | data_patches = ckn_layer.sample_patches( 119 | data, mask, n_patches_per_batch) 120 | size = data_patches.size(0) 121 | if n_patches + size > n_sampling_patches: 122 | size = n_sampling_patches - n_patches 123 | data_patches = data_patches[:size] 124 | patches[n_patches: n_patches + size] = data_patches 125 | n_patches += size 126 | 127 | print("total number of patches: {}".format(n_patches)) 128 | patches = patches[:n_patches] 129 | ckn_layer.unsup_train(patches, init=None) 130 | 131 | n_samples = min(n_samples, len(data_loader.dataset)) 132 | cur_samples = 0 133 | print("Training attention layer") 134 | for i, (data, _) in enumerate(data_loader): 135 | if cur_samples >= n_samples: 136 | continue 137 | if use_cuda: 138 | data = data.cuda() 139 | with torch.no_grad(): 140 | data = self.ckn_representation(data) 141 | 142 | 
if i == 0: 143 | patches = torch.empty([n_samples]+list(data.shape[1:])) 144 | 145 | size = data.shape[0] 146 | if cur_samples + size > n_samples: 147 | size = n_samples - cur_samples 148 | data = data[:size] 149 | patches[cur_samples: cur_samples + size] = data 150 | cur_samples += size 151 | print(patches.shape) 152 | self.attention.unsup_train(patches, wb=wb, use_cuda=use_cuda) 153 | -------------------------------------------------------------------------------- /otk/models_deepsea.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | from torch import nn 4 | from .layers import OTKernel 5 | 6 | 7 | class OTLayer(nn.Module): 8 | def __init__(self, in_dim, out_size, heads=1, eps=0.1, max_iter=10, 9 | position_encoding=None, position_sigma=0.1, out_dim=None, 10 | dropout=0.4): 11 | super().__init__() 12 | self.out_size = out_size 13 | self.heads = heads 14 | if out_dim is None: 15 | out_dim = in_dim 16 | 17 | self.layer = nn.Sequential( 18 | OTKernel(in_dim, out_size, heads, eps, max_iter, log_domain=True, 19 | position_encoding=position_encoding, position_sigma=position_sigma), 20 | nn.Linear(heads * in_dim, out_dim), 21 | nn.ReLU(inplace=True), 22 | nn.Dropout(dropout) 23 | ) 24 | nn.init.xavier_uniform_(self.layer[0].weight) 25 | nn.init.xavier_uniform_(self.layer[1].weight) 26 | 27 | def forward(self, input): 28 | output = self.layer(input) 29 | return output 30 | 31 | class SeqAttention(nn.Module): 32 | def __init__(self, nclass, hidden_size, filter_size, 33 | n_attn_layers, eps=0.1, heads=1, 34 | out_size=1, max_iter=10, hidden_layer=False, 35 | position_encoding=None, position_sigma=0.1): 36 | super().__init__() 37 | self.embed = nn.Sequential( 38 | nn.Conv1d(4, hidden_size, kernel_size=filter_size), 39 | nn.ReLU(inplace=True), 40 | ) 41 | 42 | attn_layers = [OTLayer( 43 | hidden_size, out_size, heads, eps, max_iter, 44 | position_encoding, position_sigma=position_sigma)] + [OTLayer( 45 | hidden_size, out_size, heads, eps, max_iter, position_encoding, position_sigma=position_sigma 46 | ) for _ in range(n_attn_layers - 1)] 47 | self.attn_layers = nn.Sequential(*attn_layers) 48 | 49 | self.out_features = out_size * hidden_size 50 | self.nclass = nclass 51 | 52 | if hidden_layer: 53 | self.classifier = nn.Sequential( 54 | nn.Linear(self.out_features, nclass), 55 | nn.ReLU(inplace=True), 56 | nn.Linear(nclass, nclass)) 57 | else: 58 | self.classifier = nn.Linear(self.out_features, nclass) 59 | 60 | def representation(self, input): 61 | output = self.embed(input).transpose(1, 2).contiguous() 62 | output = self.attn_layers(output) 63 | output = output.reshape(output.shape[0], -1) 64 | return output 65 | 66 | def forward(self, input): 67 | output = self.representation(input) 68 | return self.classifier(output) 69 | 70 | def predict(self, data_loader, only_repr=False, use_cuda=False): 71 | n_samples = len(data_loader.dataset) 72 | target_output = torch.LongTensor(n_samples) 73 | batch_start = 0 74 | for i, (data, target) in enumerate(data_loader): 75 | batch_size = data.shape[0] 76 | if use_cuda: 77 | data = data.cuda() 78 | with torch.no_grad(): 79 | if only_repr: 80 | batch_out = self.representation(data).data.cpu() 81 | else: 82 | batch_out = self(data).data.cpu() 83 | if i == 0: 84 | output = batch_out.new_empty([n_samples] + list(batch_out.shape[1:])) 85 | output[batch_start:batch_start + batch_size] = batch_out 86 | target_output[batch_start:batch_start + batch_size] = target 87 | batch_start += 
batch_size 88 | return output, target_output 89 | -------------------------------------------------------------------------------- /otk/sinkhorn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import math 3 | import torch 4 | from .utils import spherical_kmeans, normalize 5 | 6 | def sinkhorn(dot, mask=None, eps=1e-03, return_kernel=False, max_iter=100): 7 | """ 8 | dot: n x in_size x out_size 9 | mask: n x in_size 10 | output: n x in_size x out_size 11 | """ 12 | n, in_size, out_size = dot.shape 13 | if return_kernel: 14 | K = torch.exp(dot / eps) 15 | else: 16 | K = dot 17 | # K: n x in_size x out_size 18 | u = K.new_ones((n, in_size)) 19 | v = K.new_ones((n, out_size)) 20 | a = float(out_size / in_size) 21 | if mask is not None: 22 | mask = mask.float() 23 | a = out_size / mask.sum(1, keepdim=True) 24 | for _ in range(max_iter): 25 | u = a / torch.bmm(K, v.view(n, out_size, 1)).view(n, in_size) 26 | if mask is not None: 27 | u = u * mask 28 | v = 1. / torch.bmm(u.view(n, 1, in_size), K).view(n, out_size) 29 | K = u.view(n, in_size, 1) * (K * v.view(n, 1, out_size)) 30 | if return_kernel: 31 | K = K / out_size 32 | return (K * dot).sum(dim=[1, 2]) 33 | return K 34 | 35 | def log_sinkhorn(K, mask=None, eps=1.0, return_kernel=False, max_iter=100): 36 | """ 37 | K: n x in_size x out_size 38 | mask: n x in_size 39 | output: n x in_size x out_size 40 | """ 41 | batch_size, in_size, out_size = K.shape 42 | def min_eps(u, v, dim): 43 | Z = (K + u.view(batch_size, in_size, 1) + v.view(batch_size, 1, out_size)) / eps 44 | return -torch.logsumexp(Z, dim=dim) 45 | # K: batch_size x in_size x out_size 46 | u = K.new_zeros((batch_size, in_size)) 47 | v = K.new_zeros((batch_size, out_size)) 48 | a = torch.ones_like(u).fill_(out_size / in_size) 49 | if mask is not None: 50 | a = out_size / mask.float().sum(1, keepdim=True) 51 | a = torch.log(a) 52 | for _ in range(max_iter): 53 | u = eps * (a + min_eps(u, v, dim=-1)) + u 54 | if mask is not None: 55 | u = u.masked_fill(~mask, -1e8) 56 | v = eps * min_eps(u, v, dim=1) + v 57 | if return_kernel: 58 | output = torch.exp( 59 | (K + u.view(batch_size, in_size, 1) + v.view(batch_size, 1, out_size)) / eps) 60 | output = output / out_size 61 | return (output * K).sum(dim=[1, 2]) 62 | K = torch.exp( 63 | (K + u.view(batch_size, in_size, 1) + v.view(batch_size, 1, out_size)) / eps) 64 | return K 65 | 66 | def multihead_attn(input, weight, mask=None, eps=1.0, return_kernel=False, 67 | max_iter=100, log_domain=False, position_filter=None): 68 | """Compute the attention weight using Sinkhorn OT 69 | input: n x in_size x in_dim 70 | mask: n x in_size 71 | weight: m x out_size x in_dim (m: number of heads/ref) 72 | output: n x out_size x m x in_size 73 | """ 74 | n, in_size, in_dim = input.shape 75 | m, out_size = weight.shape[:-1] 76 | # K = torch.tensordot(torch.nn.functional.normalize(input,dim=-1), torch.nn.functional.normalize(weight,dim=-1), dims=[[-1], [-1]]) 77 | K = torch.tensordot(normalize(input), normalize(weight), dims=[[-1], [-1]]) 78 | # K = torch.tensordot(input, weight, dims=[[-1], [-1]]) 79 | K = K.permute(0, 2, 1, 3) 80 | if position_filter is not None: 81 | K = position_filter * K 82 | # K: n x m x in_size x out_size 83 | K = K.reshape(-1, in_size, out_size) 84 | # K = K-K.min() 85 | # K = K/K.max() 86 | # K: nm x in_size x out_size 87 | if mask is not None: 88 | mask = mask.repeat_interleave(m, dim=0) 89 | if log_domain: 90 | K = log_sinkhorn(K, mask, eps, 
return_kernel=return_kernel, max_iter=max_iter) 91 | else: 92 | if not return_kernel: 93 | K = torch.exp(K / eps) 94 | K = sinkhorn(K, mask, eps, return_kernel=return_kernel, max_iter=max_iter) 95 | # K: nm x in_size x out_size 96 | if return_kernel: 97 | return K.reshape(n, m) 98 | K = K.reshape(n, m, in_size, out_size) 99 | if position_filter is not None: 100 | K = position_filter * K 101 | K = K.permute(0, 3, 1, 2).contiguous() 102 | return K 103 | 104 | def wasserstein_barycenter(x, c, eps=1.0, max_iter=100, sinkhorn_iter=50, log_domain=False): 105 | """ 106 | x: n x in_size x in_dim 107 | c: out_size x in_dim 108 | """ 109 | prev_c = c 110 | for i in range(max_iter): 111 | T = multihead_attn(x, c.unsqueeze(0), eps=eps, log_domain=log_domain, max_iter=sinkhorn_iter).squeeze(2) 112 | # T: n x out_size x in_size 113 | c = 0.5*c + 0.5*torch.bmm(T, x).mean(dim=0) / math.sqrt(c.shape[0]) 114 | c /= c.norm(dim=-1, keepdim=True).clamp(min=1e-06) 115 | if ((c - prev_c) ** 2).sum() < 1e-06: 116 | break 117 | prev_c = c 118 | return c 119 | 120 | def wasserstein_kmeans(x, n_clusters, out_size, eps=1.0, block_size=None, max_iter=100, 121 | sinkhorn_iter=50, wb=False, verbose=True, log_domain=False, use_cuda=False): 122 | """ 123 | x: n x in_size x in_dim 124 | output: n_clusters x out_size x in_dim 125 | out_size <= in_size 126 | """ 127 | n, in_size, in_dim = x.shape 128 | if n_clusters == 1: 129 | if use_cuda: 130 | x = x.cuda() 131 | clusters = spherical_kmeans(x.view(-1, in_dim), out_size, block_size=block_size) 132 | if wb: 133 | clusters = wasserstein_barycenter(x, clusters, eps=0.1, log_domain=False) 134 | clusters = clusters.unsqueeze_(0) 135 | return clusters 136 | ## initialization 137 | indices = torch.randperm(n)[:n_clusters] 138 | clusters = x[indices, :out_size, :].clone() 139 | if use_cuda: 140 | clusters = clusters.cuda() 141 | 142 | wass_sim = x.new_empty(n) 143 | assign = x.new_empty(n, dtype=torch.long) 144 | if block_size is None or block_size == 0: 145 | block_size = n 146 | 147 | prev_sim = float('inf') 148 | for n_iter in range(max_iter): 149 | for i in range(0, n, block_size): 150 | end_i = min(i + block_size, n) 151 | x_batch = x[i: end_i] 152 | if use_cuda: 153 | x_batch = x_batch.cuda() 154 | tmp_sim = multihead_attn(x_batch, clusters, eps=eps, return_kernel=True, max_iter=sinkhorn_iter, log_domain=log_domain) 155 | tmp_sim = tmp_sim.cpu() 156 | wass_sim[i : end_i], assign[i: end_i] = tmp_sim.max(dim=-1) 157 | del x_batch 158 | sim = wass_sim.mean() 159 | if verbose and (n_iter + 1) % 10 == 0: 160 | print("Wasserstein spherical kmeans iter {}, objective value {}".format( 161 | n_iter + 1, sim)) 162 | 163 | for j in range(n_clusters): 164 | index = assign == j 165 | if index.sum() == 0: 166 | idx = wass_sim.argmin() 167 | clusters[j].copy_(x[idx, :out_size, :]) 168 | wass_sim[idx] = 1 169 | else: 170 | xj = x[index] 171 | if use_cuda: 172 | xj = xj.cuda() 173 | c = spherical_kmeans(xj.view(-1, in_dim), out_size, block_size=block_size, verbose=False) 174 | if wb: 175 | c = wasserstein_barycenter(xj, c, eps=0.001, log_domain=True, sinkhorn_iter=50) 176 | clusters[j] = c 177 | if torch.abs(prev_sim - sim) / sim.clamp(min=1e-10) < 1e-6: 178 | break 179 | prev_sim = sim 180 | return clusters 181 | -------------------------------------------------------------------------------- /otk/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import math 4 | import random 5 | import numpy as np 6 | 7 | EPS = 1e-6 8 | 9 | 10 | def 
normalize(x, p=2, dim=-1, inplace=False): 11 | norm = x.norm(p=p, dim=dim, keepdim=True) 12 | if inplace: 13 | x.div_(norm.clamp(min=EPS)) 14 | else: 15 | x = x / norm.clamp(min=EPS) 16 | return x 17 | 18 | def spherical_kmeans(x, n_clusters, max_iters=100, block_size=None, verbose=True, 19 | init=None, eps=1e-4): 20 | """Spherical kmeans 21 | Args: 22 | x (Tensor n_samples x kmer_size x n_features): data points 23 | n_clusters (int): number of clusters 24 | """ 25 | use_cuda = x.is_cuda 26 | if x.ndim == 3: 27 | n_samples, kmer_size, n_features = x.size() 28 | else: 29 | n_samples, n_features = x.size() 30 | if init is None: 31 | indices = torch.randperm(n_samples)[:n_clusters] 32 | if use_cuda: 33 | indices = indices.cuda() 34 | clusters = x[indices] 35 | 36 | prev_sim = np.inf 37 | tmp = x.new_empty(n_samples) 38 | assign = x.new_empty(n_samples, dtype=torch.long) 39 | if block_size is None or block_size == 0: 40 | block_size = x.shape[0] 41 | 42 | for n_iter in range(max_iters): 43 | for i in range(0, n_samples, block_size): 44 | end_i = min(i + block_size, n_samples) 45 | cos_sim = x[i: end_i].view(end_i - i, -1).mm(clusters.view(n_clusters, -1).t()) 46 | tmp[i: end_i], assign[i: end_i] = cos_sim.max(dim=-1) 47 | sim = tmp.mean() 48 | if (n_iter + 1) % 10 == 0 and verbose: 49 | print("Spherical kmeans iter {}, objective value {}".format( 50 | n_iter + 1, sim)) 51 | 52 | # update clusters 53 | for j in range(n_clusters): 54 | index = assign == j 55 | if index.sum() == 0: 56 | idx = tmp.argmin() 57 | clusters[j] = x[idx] 58 | tmp[idx] = 1 59 | else: 60 | xj = x[index] 61 | c = xj.mean(0) 62 | clusters[j] = c / c.norm(dim=-1, keepdim=True).clamp(min=EPS) 63 | 64 | if torch.abs(prev_sim - sim)/(torch.abs(sim)+1e-20) < 1e-6: 65 | break 66 | prev_sim = sim 67 | return clusters 68 | -------------------------------------------------------------------------------- /tools/Trainer_ours_m40r4.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import numpy as np 6 | from torch.optim.lr_scheduler import CosineAnnealingLR 7 | from tensorboardX import SummaryWriter 8 | import math 9 | from torch.cuda.amp import autocast as autocast 10 | from torch.cuda.amp import GradScaler as GradScaler 11 | from otk.utils import normalize 12 | class ModelNetTrainer(object): 13 | def __init__(self, model, train_loader, val_loader, optimizer, loss_fn, \ 14 | model_name, log_dir, num_views=12): 15 | self.optimizer = optimizer 16 | self.model = model 17 | self.train_loader = train_loader 18 | self.val_loader = val_loader 19 | self.loss_fn = loss_fn 20 | self.model_name = model_name 21 | self.log_dir = log_dir 22 | self.num_views = num_views 23 | self.model.cuda() 24 | if self.log_dir is not None: 25 | self.writer = SummaryWriter(log_dir) 26 | def train(self, n_epochs): 27 | best_acc = 0 28 | i_acc = 0 29 | self.model.train() 30 | scalar = GradScaler() 31 | # if self.model_name =='view-gcn': 32 | # self.model.weight_init(self.train_loader) 33 | scheduler = CosineAnnealingLR(self.optimizer,T_max=((len(self.train_loader.dataset.rand_view_num)//20)*60),eta_min=1e-5) 34 | for epoch in range(n_epochs): 35 | rand_idx = np.random.permutation(int(len(self.train_loader.dataset.rand_view_num))) 36 | filepaths_new = [] 37 | rand_view_num_new = [] 38 | view_coord_new = [] 39 | for i in range(len(rand_idx)): 40 | filepaths_new.extend(self.train_loader.dataset.filepaths[ 41 | 
rand_idx[i] * self.num_views:(rand_idx[i] + 1) * self.num_views]) 42 | rand_view_num_new.append(self.train_loader.dataset.rand_view_num[rand_idx[i]]) 43 | view_coord_new.append(self.train_loader.dataset.view_coord[rand_idx[i]]) 44 | self.train_loader.dataset.filepaths = filepaths_new 45 | self.train_loader.dataset.rand_view_num = np.array(rand_view_num_new) 46 | self.train_loader.dataset.view_coord= np.array(view_coord_new) 47 | # plot learning rate 48 | lr = self.optimizer.state_dict()['param_groups'][0]['lr'] 49 | self.writer.add_scalar('params/lr', lr, epoch) 50 | # train one epoch 51 | out_data = None 52 | in_data = None 53 | for i, data in enumerate(self.train_loader): 54 | if epoch == 0: 55 | for param_group in self.optimizer.param_groups: 56 | param_group['lr'] = lr * ((i + 1) / (len(self.train_loader.dataset.rand_view_num) // 20)) 57 | if self.model_name == 'svcnn': 58 | in_data = Variable(data[1].cuda()) 59 | else: 60 | view_coord = data[3].cuda() 61 | rand_view_num = data[2].cuda() 62 | N, V, C, H, W = data[1].size() 63 | in_data = Variable(data[1]).cuda() 64 | in_data_new = [] 65 | for ii in range(N): 66 | in_data_new.append(in_data[ii,0:rand_view_num[ii]]) 67 | in_data = torch.cat(in_data_new, 0) 68 | 69 | target = Variable(data[0]).cuda().long() 70 | target2 = Variable(data[0]).cuda().long() 71 | vert = [[1, 1, 1], [1, 1, -1], [1, -1, 1], [1, -1, -1], [-1, 1, 1], [-1, 1, -1], [-1, -1, 1], 72 | [-1, -1, -1]] 73 | # vert = [[1,1],[1,-1],[-1,1],[-1,-1]] 74 | # vert = [[1,1,1,1], 75 | # [-1,1,1,1],[1,-1,1,1],[1,1,-1,1],[1,1,1,-1], 76 | # [-1,-1,1,1],[-1,1,-1,1],[-1,1,1,-1],[1,-1,-1,1],[1,-1,1,-1],[1,1,-1,-1], 77 | # [-1,-1,-1,1],[-1,-1,1,-1],[-1,1,-1,-1],[1,-1,-1,-1], 78 | # [-1,-1,-1,-1]] 79 | vert = torch.Tensor(vert).cuda() 80 | # target_ = target.unsqueeze(1).repeat(1, 2*(10+5)).view(-1) 81 | self.optimizer.zero_grad() 82 | with autocast(): 83 | if self.model_name == 'svcnn': 84 | out_data = self.model(in_data) 85 | loss = self.loss_fn(out_data, target) 86 | else: 87 | out_data,cos_sim,cos_sim2,pos= self.model(in_data,rand_view_num,N) 88 | cos_loss = cos_sim[torch.where(cos_sim>-1)].mean() 89 | cos_loss2 = cos_sim2[torch.where(cos_sim2>-1)].mean() 90 | # part_loss = self.loss_fn(part.reshape(-1,40),target2) 91 | pos_loss = torch.norm(normalize(vert)-normalize(pos),p=2,dim=-1).mean() 92 | 93 | loss = self.loss_fn(out_data, target)+0.1*pos_loss 94 | 95 | self.writer.add_scalar('train/train_loss', loss, i_acc + i + 1) 96 | 97 | pred = torch.max(out_data, 1)[1] 98 | results = pred == target 99 | correct_points = torch.sum(results.long()) 100 | 101 | acc = correct_points.float() / results.size()[0] 102 | self.writer.add_scalar('train/train_overall_acc', acc, i_acc + i + 1) 103 | # print('lr = ', str(param_group['lr'])) 104 | scalar.scale(loss).backward() 105 | scalar.unscale_(self.optimizer) 106 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), 20) 107 | scalar.step(self.optimizer) 108 | scalar.update() 109 | #loss.backward() 110 | #torch.nn.utils.clip_grad_norm_(self.model.parameters(), 20) 111 | #self.optimizer.step() 112 | if epoch>0: 113 | scheduler.step() 114 | # self.optimizer.zero_grad() 115 | log_str = 'epoch %d, step %d: train_loss %.3f;cos_loss %.3f;pos_loss%.3f; train_acc %.3f' % (epoch + 1, i + 1, loss,cos_loss2,pos_loss, acc) 116 | if (i + 1) % 1 == 0: 117 | print(log_str) 118 | i_acc += i 119 | # evaluation 120 | if (epoch + 1) % 1 == 0: 121 | with torch.no_grad(): 122 | loss, val_overall_acc, val_mean_class_acc = self.update_validation_accuracy(epoch) 123 | 
self.writer.add_scalar('val/val_mean_class_acc', val_mean_class_acc, epoch + 1) 124 | self.writer.add_scalar('val/val_overall_acc', val_overall_acc, epoch + 1) 125 | self.writer.add_scalar('val/val_loss', loss, epoch + 1) 126 | # self.model.save(self.log_dir, epoch) 127 | # save best model 128 | if val_overall_acc > best_acc: 129 | best_acc = val_overall_acc 130 | print('best_acc', best_acc) 131 | # export scalar data to JSON for external processing 132 | self.writer.export_scalars_to_json(self.log_dir + "/all_scalars.json") 133 | self.writer.close() 134 | def update_validation_accuracy(self, epoch): 135 | all_correct_points = 0 136 | all_points = 0 137 | count = 0 138 | wrong_class = np.zeros(40) 139 | samples_class = np.zeros(40) 140 | all_loss = 0 141 | self.model.eval() 142 | for _, data in enumerate(self.val_loader, 0): 143 | if self.model_name == 'svcnn': 144 | in_data = Variable(data[1].cuda()) 145 | else: 146 | view_coord = data[3].cuda() 147 | rand_view_num = data[2].cuda() 148 | N, V, C, H, W = data[1].size() 149 | in_data = Variable(data[1]).cuda() 150 | in_data_new = [] 151 | for ii in range(N): 152 | in_data_new.append(in_data[ii, 0:rand_view_num[ii]]) 153 | in_data = torch.cat(in_data_new, 0) 154 | target = Variable(data[0]).cuda() 155 | # 156 | if self.model_name == 'svcnn': 157 | out_data = self.model(in_data) 158 | else: 159 | out_data,cos_sim,cos_sim2,pos= self.model(in_data, rand_view_num, N) 160 | # cos_loss = cos_sim[torch.where(cos_sim > 0)].mean() 161 | pred = torch.max(out_data, 1)[1] 162 | all_loss += self.loss_fn(out_data, target).cpu().data.numpy() 163 | results = pred == target 164 | 165 | for i in range(results.size()[0]): 166 | if not bool(results[i].cpu().data.numpy()): 167 | wrong_class[target.cpu().data.numpy().astype('int')[i]] += 1 168 | samples_class[target.cpu().data.numpy().astype('int')[i]] += 1 169 | correct_points = torch.sum(results.long()) 170 | 171 | all_correct_points += correct_points 172 | all_points += results.size()[0] 173 | 174 | print('Total # of test models: ', all_points) 175 | class_acc = (samples_class - wrong_class) / samples_class 176 | val_mean_class_acc = np.mean(class_acc) 177 | acc = all_correct_points.float() / all_points 178 | val_overall_acc = acc.cpu().data.numpy() 179 | loss = all_loss / len(self.val_loader) 180 | 181 | 182 | print('val mean class acc. : ', val_mean_class_acc) 183 | print('val overall acc. 
: ', val_overall_acc) 184 | print('val loss : ', loss) 185 | # print('cos loss : ', cos_loss) 186 | print(class_acc) 187 | self.model.train() 188 | 189 | return loss, val_overall_acc, val_mean_class_acc 190 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /tools/replas.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.io as sio 3 | import glob 4 | data_path = '/data/xxinwei/dataset/objectnn_hardest/' 5 | # classnames=['airplane','bathtub','bed','bench','bookshelf','bottle','bowl','car','chair', 6 | # 'cone','cup','curtain','desk','door','dresser','flower_pot','glass_box', 7 | # 'guitar','keyboard','lamp','laptop','mantel','monitor','night_stand', 8 | # 'person','piano','plant','radio','range_hood','sink','sofa','stairs', 9 | # 'stool','table','tent','toilet','tv_stand','vase','wardrobe','xbox'] 10 | classnames = ['bag', 'bed', 'bin', 'box', 'cabinet', 'chair', 'desk', 'display', 'door', 'pillow', 'shelf', 'sink', 'sofa', 'table', 'toilet'] 11 | for i in range(len(classnames)): 12 | #view = sio.loadmat(data_path+classnames[i]+'/train/view.mat') 13 | #view1 = view['ll'] 14 | #np.save(data_path + classnames[i] + '/train/view.npy', view1) 15 | #rand_view_num = np.load('/data/xxinwei/view-transformer-hardest/ModelNet40_hardest/'+classnames[i]+'/train/random_view_num.npy') 16 | numb = glob.glob(data_path + classnames[i]+'/test/*') 17 | rand_view_num = np.random.randint(6,20,(len(numb)-2)//20) 18 | np.save(data_path + classnames[i]+'/test/random_view_num.npy',rand_view_num) 19 | ss = sio.loadmat(data_path + classnames[i]+'/test/view.mat') 20 | ss = ss['ll'] 21 | np.save(data_path + classnames[i]+'/test/view.npy',ss) -------------------------------------------------------------------------------- /train4.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch.nn as nn 3 | import torch 4 | import torch.optim as optim 5 | import os,shutil,json 6 | import argparse 7 | from tools.Trainer_ours_m40r4 import ModelNetTrainer 8 | from tools.ImgDataset_m40_6_20 import RandomMultiviewImgDataset,RandomSingleImgDataset 9 | from model.view_transformer_random4 import * 10 | 11 | def seed_torch(seed=0): 12 | random.seed(seed) 13 | os.environ['PYTHONHASHSEED'] = str(seed) 14 | np.random.seed(seed) 15 | torch.manual_seed(seed) 16 | torch.cuda.manual_seed(seed) 17 | # torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. 
18 | torch.backends.cudnn.benchmark = False 19 | torch.backends.cudnn.deterministic = True 20 | torch.backends.cudnn.enabled = True 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument("-name", "--name", type=str, help="Name of the experiment", default="4") 23 | parser.add_argument("-bs", "--batchSize", type=int, help="Batch size for the second stage", default=20)# it will be *12 images in each batch for mvcnn 24 | parser.add_argument("-num_models", type=int, help="number of models per class", default=0) 25 | parser.add_argument("-lr", type=float, help="learning rate", default=1e-3) 26 | parser.add_argument("-weight_decay", type=float, help="weight decay", default=1e-3) 27 | parser.add_argument("-no_pretraining", dest='no_pretraining', action='store_true') 28 | parser.add_argument("-cnn_name", "--cnn_name", type=str, help="cnn model name", default="resnet18") 29 | parser.add_argument("-num_views", type=int, help="number of views", default=20) 30 | parser.add_argument("-train_path", type=str, default="/home/sun/weixin/view-transformer/ModelNet40_hardest_20/*/train") 31 | parser.add_argument("-val_path", type=str, default="/home/sun/weixin/view-transformer/ModelNet40_hardest_20/*/test") 32 | parser.set_defaults(train=False) 33 | # os.environ['CUDA_VISIBLE_DEVICES']='2' 34 | def create_folder(log_dir): 35 | if not os.path.exists(log_dir): 36 | os.mkdir(log_dir) 37 | else: 38 | print('WARNING: summary folder already exists!! It will be overwritten!!') 39 | shutil.rmtree(log_dir) 40 | os.mkdir(log_dir) 41 | 42 | if __name__ == '__main__': 43 | seed_torch() 44 | args = parser.parse_args() 45 | pretraining = not args.no_pretraining 46 | log_dir = args.name 47 | create_folder(args.name) 48 | config_f = open(os.path.join(log_dir, 'config.json'), 'w') 49 | json.dump(vars(args), config_f) 50 | config_f.close() 51 | cnet = SVCNN(args.name, nclasses=40, pretraining=pretraining, cnn_name=args.cnn_name) 52 | n_models_train = args.num_models * args.num_views 53 | log_dir = args.name+'_stage_2' 54 | create_folder(log_dir) 55 | cnet_2 = view_GCN(args.name, cnet, nclasses=40, cnn_name=args.cnn_name, num_views=args.num_views) 56 | optimizer = optim.SGD(cnet_2.parameters(), lr=args.lr, weight_decay=args.weight_decay,momentum=0.9) 57 | train_dataset = RandomMultiviewImgDataset(args.train_path, scale_aug=False, rot_aug=False, num_models=n_models_train, num_views=args.num_views,test_mode=True) 58 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batchSize, shuffle=False, num_workers=12,pin_memory=True)# shuffle needs to be false! it's done within the trainer 59 | val_dataset = RandomMultiviewImgDataset(args.val_path, scale_aug=False, rot_aug=False, num_views=args.num_views,test_mode=True) 60 | val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batchSize, shuffle=False, num_workers=12,pin_memory=True) 61 | print('num_train_files: '+str(len(train_dataset.filepaths))) 62 | print('num_val_files: '+str(len(val_dataset.filepaths))) 63 | trainer = ModelNetTrainer(cnet_2, train_loader, val_loader, optimizer, nn.CrossEntropyLoss(), 'view-gcn', log_dir, num_views=args.num_views) 64 | trainer.train(60) --------------------------------------------------------------------------------
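The optimal-transport attention used throughout `otk/` (`OTKernel.get_attn` in `otk/layers.py`, backed by `multihead_attn` in `otk/sinkhorn.py`) maps a variable-size set of view features onto a fixed number of reference points, which is what lets the model accept an arbitrary number of views per shape. Below is a minimal shape-check sketch of that call; it assumes the repository root is on the Python path, and the tensor sizes and hyper-parameters (`eps`, `max_iter`) are illustrative values, not the settings used for the reported results.

```python
import torch
from otk.sinkhorn import multihead_attn

n, in_size, in_dim = 2, 30, 64   # 2 shapes, 30 views each, 64-d view features (illustrative)
heads, out_size = 1, 8           # one reference set with 8 support points (illustrative)

x = torch.randn(n, in_size, in_dim)
w = torch.randn(heads, out_size, in_dim)

# Transport plan between input elements and reference points:
# shape n x out_size x heads x in_size, as documented in multihead_attn.
plan = multihead_attn(x, w, eps=0.1, max_iter=30, log_domain=True)
print(plan.shape)  # torch.Size([2, 8, 1, 30])

# Aggregation as in OTKernel.forward: batch_size x out_size x (heads x in_dim).
out = torch.bmm(plan.reshape(n, out_size * heads, in_size), x)
print(out.reshape(n, out_size, -1).shape)  # torch.Size([2, 8, 64])
```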
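Similarly, `OTKernel.unsup_train` (`otk/layers.py`) initialises the reference weights without labels by running the Wasserstein K-means routine from `otk/sinkhorn.py` on normalised feature sets. A minimal sketch of that routine in isolation, under the same path assumption and again with illustrative sizes and settings:

```python
import torch
from otk.sinkhorn import wasserstein_kmeans
from otk.utils import normalize

n_samples, in_size, in_dim = 100, 30, 64                 # illustrative sizes
x = normalize(torch.randn(n_samples, in_size, in_dim))   # unit-norm features, as in unsup_train

heads, out_size = 1, 8
# Learns `heads` references of `out_size` points each, returned as heads x out_size x in_dim.
weight = wasserstein_kmeans(x, heads, out_size, eps=0.1, max_iter=20,
                            sinkhorn_iter=30, log_domain=True)
print(weight.shape)  # torch.Size([1, 8, 64])
```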