├── README.md
├── environment.yml
├── meshed-memory-transformer-master
│   ├── LICENSE
│   ├── README.md
│   ├── data
│   │   ├── __init__.py
│   │   ├── dataset.py
│   │   ├── example.py
│   │   ├── field.py
│   │   ├── utils.py
│   │   └── vocab.py
│   ├── environment.yml
│   ├── evaluation
│   │   ├── __init__.py
│   │   ├── bleu
│   │   │   ├── __init__.py
│   │   │   ├── bleu.py
│   │   │   └── bleu_scorer.py
│   │   ├── cider
│   │   │   ├── __init__.py
│   │   │   ├── cider.py
│   │   │   └── cider_scorer.py
│   │   ├── meteor
│   │   │   ├── __init__.py
│   │   │   └── meteor.py
│   │   ├── rouge
│   │   │   ├── __init__.py
│   │   │   └── rouge.py
│   │   └── tokenizer.py
│   ├── images
│   │   ├── m2.png
│   │   └── results.png
│   ├── models
│   │   ├── __init__.py
│   │   ├── beam_search
│   │   │   ├── __init__.py
│   │   │   └── beam_search.py
│   │   ├── captioning_model.py
│   │   ├── containers.py
│   │   └── transformer
│   │       ├── __init__.py
│   │       ├── attention.py
│   │       ├── decoders.py
│   │       ├── encoders.py
│   │       ├── transformer.py
│   │       └── utils.py
│   ├── output_logs
│   │   └── meshed_memory_transformer_test_o
│   ├── test.py
│   ├── train.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── typing.py
│   │   └── utils.py
│   └── vocab.pkl
├── model
│   ├── Model.py
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── Model.cpython-36.pyc
│   │   ├── __init__.cpython-36.pyc
│   │   └── view_transformer_random4.cpython-36.pyc
│   ├── best.py
│   ├── gvcnn.py
│   ├── gvcnn_random.py
│   ├── mvcnn.py
│   ├── mvcnn_random.py
│   ├── view_transformer.py
│   └── view_transformer_random4.py
├── models
│   ├── Vit.py
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── Vit.cpython-36.pyc
│   │   ├── Vit.cpython-38.pyc
│   │   ├── __init__.cpython-36.pyc
│   │   ├── captioning_model.cpython-36.pyc
│   │   ├── containers.cpython-36.pyc
│   │   ├── utils.cpython-36.pyc
│   │   └── utils.cpython-38.pyc
│   ├── beam_search
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   └── __init__.cpython-36.pyc
│   │   └── beam_search.py
│   ├── captioning_model.py
│   ├── containers.py
│   ├── transformer
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── attention.cpython-36.pyc
│   │   │   ├── decoders.cpython-36.pyc
│   │   │   ├── encoders.cpython-36.pyc
│   │   │   ├── transformer.cpython-36.pyc
│   │   │   └── utils.cpython-36.pyc
│   │   ├── attention.py
│   │   ├── decoders.py
│   │   ├── encoders.py
│   │   ├── transformer.py
│   │   └── utils.py
│   ├── utils.py
│   └── utils
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── __init__.cpython-36.pyc
│       │   ├── typing.cpython-36.pyc
│       │   └── utils.cpython-36.pyc
│       ├── typing.py
│       └── utils.py
├── otk
│   ├── cross_attention.py
│   ├── data_utils.py
│   ├── layers.py
│   ├── models.py
│   ├── models_deepsea.py
│   ├── sinkhorn.py
│   └── utils.py
├── tools
│   ├── ImgDataset_m40_6_20.py
│   ├── ImgDataset_obj_6_20.py
│   ├── Trainer_ours_m40r4.py
│   ├── __init__.py
│   ├── replas.py
│   ├── test_tools.py
│   └── view_gcn_utils.py
└── train4.py
/README.md:
--------------------------------------------------------------------------------
1 | # PyTorch code for CVR [ICCV 2021].
2 |
3 | Xin Wei*, Yifei Gong*, Fudong Wang, Xing Sun, Jian Sun. **Learning Canonical View Representation for 3D Shape Recognition with Arbitrary Views**. ICCV, accepted, 2021. [[pdf]](https://openaccess.thecvf.com/content/ICCV2021/papers/Wei_Learning_Canonical_View_Representation_for_3D_Shape_Recognition_With_Arbitrary_ICCV_2021_paper.pdf) [[supp]](https://openaccess.thecvf.com/content/ICCV2021/supplemental/Wei_Learning_Canonical_View_ICCV_2021_supplemental.pdf)
4 |
5 | ## Citation
6 | If you find our work useful in your research, please consider citing:
7 | ```
8 | @InProceedings{Wei_2021_ICCV,
9 | author = {Wei, Xin and Gong, Yifei and Wang, Fudong and Sun, Xing and Sun, Jian},
10 | title = {Learning Canonical View Representation for 3D Shape Recognition With Arbitrary Views},
11 | booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
12 | month = {October},
13 | year = {2021},
14 | pages = {407-416}
15 | }
16 | ```
17 |
18 | ## Training
19 |
20 | ### Requirements
21 |
22 | This code is tested with Python 3.6 and PyTorch 1.8+. The exact package versions are listed in the provided `environment.yml` (conda environment `view`).
23 |
24 | ### Dataset
25 |
26 | First, download the arbitrary views ModelNet40 dataset and put it under `data`:
27 |
28 | `https://drive.google.com/file/d/1RfE0aJ_IXNspVs610BkMgcWDP6FB0cYX/view?usp=sharing`
29 |
30 | Link to the arbitrary views ScanObjectNN dataset:
31 |
32 | `https://drive.google.com/file/d/10xl-S8-XlaX5187Dkv91pX4JQ5DZxeM8/view?usp=sharing`
33 |
34 | Aligned-ScanObjectNN dataset: `https://drive.google.com/file/d/1ihR6Fv88-6FOVUWdfHVMfDbUrx2eIPpR/view?usp=sharing`
35 |
36 | Rotated-ScanObjectNN dataset: `https://drive.google.com/file/d/1GCwgrfbO_uO3Qh9UNPWRCuz2yr8UyRRT/view?usp=sharing`
37 |
38 | #### The code is borrowed from [[OTK]](https://github.com/claying/OTK) and [[view-GCN]](https://github.com/weixmath/view-GCN).
39 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: view
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - defaults
6 | dependencies:
7 | - _libgcc_mutex=0.1=conda_forge
8 | - _openmp_mutex=4.5=1_llvm
9 | - blas=1.0=mkl
10 | - bzip2=1.0.8=h7f98852_4
11 | - ca-certificates=2021.1.19=h06a4308_1
12 | - certifi=2020.12.5=py38h06a4308_0
13 | - cudatoolkit=11.1.1=h6406543_8
14 | - ffmpeg=4.3=hf484d3e_0
15 | - freetype=2.10.4=h0708190_1
16 | - gmp=6.2.1=h58526e2_0
17 | - gnutls=3.6.13=h85f3911_1
18 | - jpeg=9b=h024ee3a_2
19 | - lame=3.100=h7f98852_1001
20 | - lcms2=2.12=h3be6417_0
21 | - ld_impl_linux-64=2.35.1=hea4e1c9_2
22 | - libffi=3.3=h58526e2_2
23 | - libgcc-ng=9.3.0=h2828fa1_19
24 | - libiconv=1.16=h516909a_0
25 | - libpng=1.6.37=h21135ba_2
26 | - libstdcxx-ng=9.3.0=h6de172a_19
27 | - libtiff=4.1.0=h2733197_1
28 | - libuv=1.41.0=h7f98852_0
29 | - llvm-openmp=11.1.0=h4bd325d_1
30 | - lz4-c=1.9.3=h9c3ff4c_0
31 | - mkl=2020.4=h726a3e6_304
32 | - mkl-service=2.3.0=py38h1e0a361_2
33 | - mkl_fft=1.3.0=py38h5c078b8_1
34 | - mkl_random=1.2.0=py38hc5bc63f_1
35 | - ncurses=6.2=h58526e2_4
36 | - nettle=3.6=he412f7d_0
37 | - ninja=1.10.2=h4bd325d_0
38 | - numpy=1.19.2=py38h54aff64_0
39 | - numpy-base=1.19.2=py38hfa32c7d_0
40 | - olefile=0.46=pyh9f0ad1d_1
41 | - openh264=2.1.1=h780b84a_0
42 | - openssl=1.1.1k=h27cfd23_0
43 | - pillow=8.2.0=py38he98fc37_0
44 | - pip=21.0.1=pyhd8ed1ab_0
45 | - python=3.8.8=hffdb5ce_0_cpython
46 | - python_abi=3.8=1_cp38
47 | - pytorch=1.8.1=py3.8_cuda11.1_cudnn8.0.5_0
48 | - readline=8.0=he28a2e2_2
49 | - setuptools=49.6.0=py38h578d9bd_3
50 | - six=1.15.0=pyh9f0ad1d_0
51 | - sqlite=3.35.4=h74cdb3f_0
52 | - tk=8.6.10=h21135ba_1
53 | - torchaudio=0.8.1=py38
54 | - torchvision=0.9.1=py38_cu111
55 | - typing_extensions=3.7.4.3=py_0
56 | - wheel=0.36.2=pyhd3deb0d_0
57 | - xz=5.2.5=h516909a_1
58 | - zlib=1.2.11=h516909a_1010
59 | - zstd=1.4.9=ha95c52a_0
60 | - pip:
61 | - chardet==4.0.0
62 | - idna==2.10
63 | - opencv-python==4.5.3.56
64 | - protobuf==3.15.8
65 | - requests==2.25.1
66 | - scipy==1.6.3
67 | - tensorboardx==2.2
68 | - urllib3==1.26.4
69 | prefix: /public/home/weixin/.conda/envs/view
70 |
--------------------------------------------------------------------------------
/meshed-memory-transformer-master/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2019, AImageLab
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | 3. Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/meshed-memory-transformer-master/README.md:
--------------------------------------------------------------------------------
1 | # M²: Meshed-Memory Transformer
2 | This repository contains the reference code for the paper _[Meshed-Memory Transformer for Image Captioning](https://arxiv.org/abs/1912.08226)_ (CVPR 2020).
3 |
4 | Please cite with the following BibTeX:
5 |
6 | ```
7 | @inproceedings{cornia2020m2,
8 | title={{Meshed-Memory Transformer for Image Captioning}},
9 | author={Cornia, Marcella and Stefanini, Matteo and Baraldi, Lorenzo and Cucchiara, Rita},
10 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
11 | year={2020}
12 | }
13 | ```
14 |
15 | ![Meshed-Memory Transformer](images/m2.png)
16 |
17 |
18 | ## Environment setup
19 | Clone the repository and create the `m2release` conda environment using the `environment.yml` file:
20 | ```
21 | conda env create -f environment.yml
22 | conda activate m2release
23 | ```
24 |
25 | Then download spacy data by executing the following command:
26 | ```
27 | python -m spacy download en
28 | ```
29 |
30 | Note: Python 3.6 is required to run our code.
31 |
32 |
33 | ## Data preparation
34 | To run the code, annotations and detection features for the COCO dataset are needed. Please download the annotations file [annotations.zip](https://drive.google.com/file/d/1i8mqKFKhqvBr8kEp3DbIh9-9UNAfKGmE/view?usp=sharing) and extract it.
35 |
36 | Detection features are computed with the code provided by [1]. To reproduce our result, please download the COCO features file [coco_detections.hdf5](https://drive.google.com/open?id=1MV6dSnqViQfyvgyHrmAT_lLpFbkzp3mx) (~53.5 GB), in which detections of each image are stored under the `<image_id>_features` key. `<image_id>` is the id of each COCO image, without leading zeros (e.g. the `<image_id>` for `COCO_val2014_000000037209.jpg` is `37209`), and each value should be a `(N, 2048)` tensor, where `N` is the number of detections.
37 |
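
For a quick sanity check of the feature file, the snippet below (a minimal sketch, not part of the original repository) reads the detections of a single image with `h5py` (already listed in `environment.yml`), assuming the `<image_id>_features` key layout described above:

```
# Minimal sketch: inspect one image's detection features in coco_detections.hdf5.
# Assumes the '<image_id>_features' key layout described above; the example id
# 37209 corresponds to COCO_val2014_000000037209.jpg.
import h5py
import numpy as np

with h5py.File('coco_detections.hdf5', 'r') as f:
    feats = np.asarray(f['37209_features'])

print(feats.shape)  # expected: (N, 2048), where N is the number of detections
```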
38 |
39 | ## Evaluation
40 | To reproduce the results reported in our paper, download the pretrained model file [meshed_memory_transformer.pth](https://drive.google.com/file/d/1naUSnVqXSMIdoiNz_fjqKwh9tXF7x8Nx/view?usp=sharing) and place it in the code folder.
41 |
42 | Run `python test.py` using the following arguments:
43 |
44 | | Argument | Possible values |
45 | |------|------|
46 | | `--batch_size` | Batch size (default: 10) |
47 | | `--workers` | Number of workers (default: 0) |
48 | | `--features_path` | Path to detection features file |
49 | | `--annotation_folder` | Path to folder with COCO annotations |
50 |
51 | #### Expected output
52 | Under `output_logs/`, you may also find the expected output of the evaluation code.
53 |
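
Captioning metrics can also be computed directly through the bundled `evaluation` package. The sketch below (illustrative only, not part of the original instructions) assumes it is run from the code folder; the two example captions are made up, and METEOR additionally requires a Java runtime:

```
# Minimal usage sketch of the bundled evaluation package. Both dictionaries map
# an image id to a list of caption strings (references may contain several).
import evaluation

gts = {0: ['a man is riding a horse on the beach'],
       1: ['two dogs are playing in the snow']}
gen = {0: ['a man rides a horse'],
       1: ['dogs playing in the snow']}

scores, _ = evaluation.compute_scores(gts, gen)
print(scores)  # corpus-level scores keyed by 'BLEU', 'METEOR', 'ROUGE', 'CIDEr'
```

`compute_scores` returns a dict of corpus-level scores keyed by metric name together with a dict of per-image scores.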
54 |
55 | ## Training procedure
56 | Run `python train.py` using the following arguments:
57 |
58 | | Argument | Possible values |
59 | |------|------|
60 | | `--exp_name` | Experiment name|
61 | | `--batch_size` | Batch size (default: 10) |
62 | | `--workers` | Number of workers (default: 0) |
63 | | `--m` | Number of memory vectors (default: 40) |
64 | | `--head` | Number of heads (default: 8) |
65 | | `--warmup` | Warmup value for learning rate scheduling (default: 10000) |
66 | | `--resume_last` | If used, the training will be resumed from the last checkpoint. |
67 | | `--resume_best` | If used, the training will be resumed from the best checkpoint. |
68 | | `--features_path` | Path to detection features file |
69 | | `--annotation_folder` | Path to folder with COCO annotations |
70 | | `--logs_folder` | Path to folder for tensorboard logs (default: "tensorboard_logs") |
71 |
72 | For example, to train our model with the parameters used in our experiments, use
73 | ```
74 | python train.py --exp_name m2_transformer --batch_size 50 --m 40 --head 8 --warmup 10000 --features_path /path/to/features --annotation_folder /path/to/annotations
75 | ```
76 | ![Results](images/results.png)
77 |
78 |
81 | #### References
82 | [1] P. Anderson, X. He, C. Buehler, D. Teney, M. Johnson, S. Gould, and L. Zhang. Bottom-up and top-down attention for image captioning and visual question answering. In _Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition_, 2018.
83 |
--------------------------------------------------------------------------------
/meshed-memory-transformer-master/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .field import RawField, Merge, ImageDetectionsField, TextField
2 | from .dataset import COCO
3 | from torch.utils.data import DataLoader as TorchDataLoader
4 |
5 | class DataLoader(TorchDataLoader):
6 | def __init__(self, dataset, *args, **kwargs):
7 | super(DataLoader, self).__init__(dataset, *args, collate_fn=dataset.collate_fn(), **kwargs)
8 |
--------------------------------------------------------------------------------
/meshed-memory-transformer-master/data/example.py:
--------------------------------------------------------------------------------
1 |
2 | class Example(object):
3 | """Defines a single training or test example.
4 | Stores each column of the example as an attribute.
5 | """
6 | @classmethod
7 | def fromdict(cls, data):
8 | ex = cls(data)
9 | return ex
10 |
11 | def __init__(self, data):
12 | for key, val in data.items():
13 | super(Example, self).__setattr__(key, val)
14 |
15 | def __setattr__(self, key, value):
16 | raise AttributeError
17 |
18 | def __hash__(self):
19 | return hash(tuple(x for x in self.__dict__.values()))
20 |
21 | def __eq__(self, other):
22 | this = tuple(x for x in self.__dict__.values())
23 | other = tuple(x for x in other.__dict__.values())
24 | return this == other
25 |
26 | def __ne__(self, other):
27 | return not self.__eq__(other)
28 |
--------------------------------------------------------------------------------
/meshed-memory-transformer-master/data/utils.py:
--------------------------------------------------------------------------------
1 | import contextlib, sys
2 |
3 | class DummyFile(object):
4 | def write(self, x): pass
5 |
6 | @contextlib.contextmanager
7 | def nostdout():
8 | save_stdout = sys.stdout
9 | sys.stdout = DummyFile()
10 | yield
11 | sys.stdout = save_stdout
12 |
13 | def reporthook(t):
14 | """https://github.com/tqdm/tqdm"""
15 | last_b = [0]
16 |
17 | def inner(b=1, bsize=1, tsize=None):
18 | """
19 | b: int, optional
20 | Number of blocks just transferred [default: 1].
21 | bsize: int, optional
22 | Size of each block (in tqdm units) [default: 1].
23 | tsize: int, optional
24 | Total size (in tqdm units). If [default: None] remains unchanged.
25 | """
26 | if tsize is not None:
27 | t.total = tsize
28 | t.update((b - last_b[0]) * bsize)
29 | last_b[0] = b
30 | return inner
31 |
32 | def get_tokenizer(tokenizer):
33 | if callable(tokenizer):
34 | return tokenizer
35 | if tokenizer == "spacy":
36 | try:
37 | import spacy
38 | spacy_en = spacy.load('en')
39 | return lambda s: [tok.text for tok in spacy_en.tokenizer(s)]
40 | except ImportError:
41 | print("Please install SpaCy and the SpaCy English tokenizer. "
42 | "See the docs at https://spacy.io for more information.")
43 | raise
44 | except AttributeError:
45 | print("Please install SpaCy and the SpaCy English tokenizer. "
46 | "See the docs at https://spacy.io for more information.")
47 | raise
48 | elif tokenizer == "moses":
49 | try:
50 | from nltk.tokenize.moses import MosesTokenizer
51 | moses_tokenizer = MosesTokenizer()
52 | return moses_tokenizer.tokenize
53 | except ImportError:
54 | print("Please install NLTK. "
55 | "See the docs at http://nltk.org for more information.")
56 | raise
57 | except LookupError:
58 | print("Please install the necessary NLTK corpora. "
59 | "See the docs at http://nltk.org for more information.")
60 | raise
61 | elif tokenizer == 'revtok':
62 | try:
63 | import revtok
64 | return revtok.tokenize
65 | except ImportError:
66 | print("Please install revtok.")
67 | raise
68 | elif tokenizer == 'subword':
69 | try:
70 | import revtok
71 | return lambda x: revtok.tokenize(x, decap=True)
72 | except ImportError:
73 | print("Please install revtok.")
74 | raise
75 | raise ValueError("Requested tokenizer {}, valid choices are a "
76 | "callable that takes a single string as input, "
77 | "\"revtok\" for the revtok reversible tokenizer, "
78 | "\"subword\" for the revtok caps-aware tokenizer, "
79 | "\"spacy\" for the SpaCy English tokenizer, or "
80 | "\"moses\" for the NLTK port of the Moses tokenization "
81 | "script.".format(tokenizer))
82 |
--------------------------------------------------------------------------------
/meshed-memory-transformer-master/environment.yml:
--------------------------------------------------------------------------------
1 | name: m2release
2 | channels:
3 | - anaconda
4 | - defaults
5 | dependencies:
6 | - _libgcc_mutex=0.1=main
7 | - asn1crypto=1.2.0=py36_0
8 | - blas=1.0=mkl
9 | - ca-certificates=2019.10.16=0
10 | - certifi=2019.9.11=py36_0
11 | - cffi=1.13.2=py36h2e261b9_0
12 | - chardet=3.0.4=py36_1003
13 | - cryptography=2.8=py36h1ba5d50_0
14 | - cython=0.29.14=py36he6710b0_0
15 | - dill=0.2.9=py36_0
16 | - idna=2.8=py36_0
17 | - intel-openmp=2019.5=281
18 | - libedit=3.1.20181209=hc058e9b_0
19 | - libffi=3.2.1=hd88cf55_4
20 | - libgcc-ng=9.1.0=hdf63c60_0
21 | - libgfortran-ng=7.3.0=hdf63c60_0
22 | - libstdcxx-ng=9.1.0=hdf63c60_0
23 | - mkl=2019.5=281
24 | - mkl-service=2.3.0=py36he904b0f_0
25 | - mkl_fft=1.0.15=py36ha843d7b_0
26 | - mkl_random=1.1.0=py36hd6b4f25_0
27 | - msgpack-numpy=0.4.4.3=py_0
28 | - msgpack-python=0.5.6=py36h6bb024c_1
29 | - ncurses=6.1=he6710b0_1
30 | - openjdk=8.0.152=h46b5887_1
31 | - openssl=1.1.1=h7b6447c_0
32 | - pip=19.3.1=py36_0
33 | - pycparser=2.19=py_0
34 | - pyopenssl=19.1.0=py36_0
35 | - pysocks=1.7.1=py36_0
36 | - python=3.6.9=h265db76_0
37 | - readline=7.0=h7b6447c_5
38 | - requests=2.22.0=py36_0
39 | - setuptools=41.6.0=py36_0
40 | - six=1.13.0=py36_0
41 | - spacy=2.0.11=py36h04863e7_2
42 | - sqlite=3.30.1=h7b6447c_0
43 | - termcolor=1.1.0=py36_1
44 | - thinc=6.11.2=py36hedc7406_1
45 | - tk=8.6.8=hbc83047_0
46 | - toolz=0.10.0=py_0
47 | - urllib3=1.24.2=py36_0
48 | - wheel=0.33.6=py36_0
49 | - xz=5.2.4=h14c3975_4
50 | - zlib=1.2.11=h7b6447c_3
51 | - pip:
52 | - absl-py==0.8.1
53 | - cycler==0.10.0
54 | - cymem==1.31.2
55 | - cytoolz==0.9.0.1
56 | - future==0.17.1
57 | - grpcio==1.25.0
58 | - h5py==2.8.0
59 | - kiwisolver==1.1.0
60 | - markdown==3.1.1
61 | - matplotlib==2.2.3
62 | - msgpack==0.6.2
63 | - multiprocess==0.70.9
64 | - murmurhash==0.28.0
65 | - numpy==1.16.4
66 | - pathlib==1.0.1
67 | - pathos==0.2.3
68 | - pillow==6.2.1
69 | - plac==0.9.6
70 | - pox==0.2.7
71 | - ppft==1.6.6.1
72 | - preshed==1.0.1
73 | - protobuf==3.10.0
74 | - pycocotools==2.0.0
75 | - pyparsing==2.4.5
76 | - python-dateutil==2.8.1
77 | - pytz==2019.3
78 | - regex==2017.4.5
79 | - tensorboard==1.14.0
80 | - torch==1.1.0
81 | - torchvision==0.3.0
82 | - tqdm==4.32.2
83 | - ujson==1.35
84 | - werkzeug==0.16.0
85 | - wrapt==1.10.11
86 |
87 |
--------------------------------------------------------------------------------
/meshed-memory-transformer-master/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | from .bleu import Bleu
2 | from .meteor import Meteor
3 | from .rouge import Rouge
4 | from .cider import Cider
5 | from .tokenizer import PTBTokenizer
6 |
7 | def compute_scores(gts, gen):
8 | metrics = (Bleu(), Meteor(), Rouge(), Cider())
9 | all_score = {}
10 | all_scores = {}
11 | for metric in metrics:
12 | score, scores = metric.compute_score(gts, gen)
13 | all_score[str(metric)] = score
14 | all_scores[str(metric)] = scores
15 |
16 | return all_score, all_scores
17 |
--------------------------------------------------------------------------------
/meshed-memory-transformer-master/evaluation/bleu/__init__.py:
--------------------------------------------------------------------------------
1 | from .bleu import Bleu
--------------------------------------------------------------------------------
/meshed-memory-transformer-master/evaluation/bleu/bleu.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #
3 | # File Name : bleu.py
4 | #
5 | # Description : Wrapper for BLEU scorer.
6 | #
7 | # Creation Date : 06-01-2015
8 | # Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT
9 | # Authors : Hao Fang and Tsung-Yi Lin
10 |
11 | from .bleu_scorer import BleuScorer
12 |
13 |
14 | class Bleu:
15 | def __init__(self, n=4):
16 | # default: compute BLEU score up to 4-grams
17 | self._n = n
18 | self._hypo_for_image = {}
19 | self.ref_for_image = {}
20 |
21 | def compute_score(self, gts, res):
22 |
23 | assert(gts.keys() == res.keys())
24 | imgIds = gts.keys()
25 |
26 | bleu_scorer = BleuScorer(n=self._n)
27 | for id in imgIds:
28 | hypo = res[id]
29 | ref = gts[id]
30 |
31 | # Sanity check.
32 | assert(type(hypo) is list)
33 | assert(len(hypo) == 1)
34 | assert(type(ref) is list)
35 | assert(len(ref) >= 1)
36 |
37 | bleu_scorer += (hypo[0], ref)
38 |
39 | # score, scores = bleu_scorer.compute_score(option='shortest')
40 | score, scores = bleu_scorer.compute_score(option='closest', verbose=0)
41 | # score, scores = bleu_scorer.compute_score(option='average', verbose=1)
42 |
43 | return score, scores
44 |
45 | def __str__(self):
46 | return 'BLEU'
47 |
--------------------------------------------------------------------------------
/meshed-memory-transformer-master/evaluation/bleu/bleu_scorer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | # bleu_scorer.py
4 | # David Chiang
5 |
6 | # Copyright (c) 2004-2006 University of Maryland. All rights
7 | # reserved. Do not redistribute without permission from the
8 | # author. Not for commercial use.
9 |
10 | # Modified by:
11 | # Hao Fang
12 | # Tsung-Yi Lin
13 |
14 | ''' Provides:
15 | cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test().
16 | cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked().
17 | '''
18 |
19 | import copy
20 | import sys, math, re
21 | from collections import defaultdict
22 |
23 |
24 | def precook(s, n=4, out=False):
25 | """Takes a string as input and returns an object that can be given to
26 | either cook_refs or cook_test. This is optional: cook_refs and cook_test
27 | can take string arguments as well."""
28 | words = s.split()
29 | counts = defaultdict(int)
30 | for k in range(1, n + 1):
31 | for i in range(len(words) - k + 1):
32 | ngram = tuple(words[i:i + k])
33 | counts[ngram] += 1
34 | return (len(words), counts)
35 |
36 |
37 | def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "average"
38 | '''Takes a list of reference sentences for a single segment
39 | and returns an object that encapsulates everything that BLEU
40 | needs to know about them.'''
41 |
42 | reflen = []
43 | maxcounts = {}
44 | for ref in refs:
45 | rl, counts = precook(ref, n)
46 | reflen.append(rl)
47 | for (ngram, count) in counts.items():
48 | maxcounts[ngram] = max(maxcounts.get(ngram, 0), count)
49 |
50 | # Calculate effective reference sentence length.
51 | if eff == "shortest":
52 | reflen = min(reflen)
53 | elif eff == "average":
54 | reflen = float(sum(reflen)) / len(reflen)
55 |
56 | ## lhuang: N.B.: leave reflen computation to the very end!!
57 |
58 | ## lhuang: N.B.: in case of "closest", keep a list of reflens!! (bad design)
59 |
60 | return (reflen, maxcounts)
61 |
62 |
63 | def cook_test(test, ref_tuple, eff=None, n=4):
64 | '''Takes a test sentence and returns an object that
65 | encapsulates everything that BLEU needs to know about it.'''
66 |
67 | testlen, counts = precook(test, n, True)
68 | reflen, refmaxcounts = ref_tuple
69 |
70 | result = {}
71 |
72 | # Calculate effective reference sentence length.
73 |
74 | if eff == "closest":
75 | result["reflen"] = min((abs(l - testlen), l) for l in reflen)[1]
76 | else: ## i.e., "average" or "shortest" or None
77 | result["reflen"] = reflen
78 |
79 | result["testlen"] = testlen
80 |
81 | result["guess"] = [max(0, testlen - k + 1) for k in range(1, n + 1)]
82 |
83 | result['correct'] = [0] * n
84 | for (ngram, count) in counts.items():
85 | result["correct"][len(ngram) - 1] += min(refmaxcounts.get(ngram, 0), count)
86 |
87 | return result
88 |
89 |
90 | class BleuScorer(object):
91 | """Bleu scorer.
92 | """
93 |
94 | __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen"
95 |
96 | # special_reflen is used in oracle (proportional effective ref len for a node).
97 |
98 | def copy(self):
99 | ''' copy the refs.'''
100 | new = BleuScorer(n=self.n)
101 | new.ctest = copy.copy(self.ctest)
102 | new.crefs = copy.copy(self.crefs)
103 | new._score = None
104 | return new
105 |
106 | def __init__(self, test=None, refs=None, n=4, special_reflen=None):
107 | ''' singular instance '''
108 |
109 | self.n = n
110 | self.crefs = []
111 | self.ctest = []
112 | self.cook_append(test, refs)
113 | self.special_reflen = special_reflen
114 |
115 | def cook_append(self, test, refs):
116 | '''called by constructor and __iadd__ to avoid creating new instances.'''
117 |
118 | if refs is not None:
119 | self.crefs.append(cook_refs(refs))
120 | if test is not None:
121 | cooked_test = cook_test(test, self.crefs[-1])
122 | self.ctest.append(cooked_test) ## N.B.: -1
123 | else:
124 | self.ctest.append(None) # lens of crefs and ctest have to match
125 |
126 | self._score = None ## need to recompute
127 |
128 | def ratio(self, option=None):
129 | self.compute_score(option=option)
130 | return self._ratio
131 |
132 | def score_ratio(self, option=None):
133 | '''
134 | return (bleu, len_ratio) pair
135 | '''
136 |
137 | return self.fscore(option=option), self.ratio(option=option)
138 |
139 | def score_ratio_str(self, option=None):
140 | return "%.4f (%.2f)" % self.score_ratio(option)
141 |
142 | def reflen(self, option=None):
143 | self.compute_score(option=option)
144 | return self._reflen
145 |
146 | def testlen(self, option=None):
147 | self.compute_score(option=option)
148 | return self._testlen
149 |
150 | def retest(self, new_test):
151 | if type(new_test) is str:
152 | new_test = [new_test]
153 | assert len(new_test) == len(self.crefs), new_test
154 | self.ctest = []
155 | for t, rs in zip(new_test, self.crefs):
156 | self.ctest.append(cook_test(t, rs))
157 | self._score = None
158 |
159 | return self
160 |
161 | def rescore(self, new_test):
162 | ''' replace test(s) with new test(s), and returns the new score.'''
163 |
164 | return self.retest(new_test).compute_score()
165 |
166 | def size(self):
167 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
168 | return len(self.crefs)
169 |
170 | def __iadd__(self, other):
171 | '''add an instance (e.g., from another sentence).'''
172 |
173 | if type(other) is tuple:
174 | ## avoid creating new BleuScorer instances
175 | self.cook_append(other[0], other[1])
176 | else:
177 | assert self.compatible(other), "incompatible BLEUs."
178 | self.ctest.extend(other.ctest)
179 | self.crefs.extend(other.crefs)
180 | self._score = None ## need to recompute
181 |
182 | return self
183 |
184 | def compatible(self, other):
185 | return isinstance(other, BleuScorer) and self.n == other.n
186 |
187 | def single_reflen(self, option="average"):
188 | return self._single_reflen(self.crefs[0][0], option)
189 |
190 | def _single_reflen(self, reflens, option=None, testlen=None):
191 |
192 | if option == "shortest":
193 | reflen = min(reflens)
194 | elif option == "average":
195 | reflen = float(sum(reflens)) / len(reflens)
196 | elif option == "closest":
197 | reflen = min((abs(l - testlen), l) for l in reflens)[1]
198 | else:
199 | assert False, "unsupported reflen option %s" % option
200 |
201 | return reflen
202 |
203 | def recompute_score(self, option=None, verbose=0):
204 | self._score = None
205 | return self.compute_score(option, verbose)
206 |
207 | def compute_score(self, option=None, verbose=0):
208 | n = self.n
209 | small = 1e-9
210 | tiny = 1e-15 ## so that if guess is 0 still return 0
211 | bleu_list = [[] for _ in range(n)]
212 |
213 | if self._score is not None:
214 | return self._score
215 |
216 | if option is None:
217 | option = "average" if len(self.crefs) == 1 else "closest"
218 |
219 | self._testlen = 0
220 | self._reflen = 0
221 | totalcomps = {'testlen': 0, 'reflen': 0, 'guess': [0] * n, 'correct': [0] * n}
222 |
223 | # for each sentence
224 | for comps in self.ctest:
225 | testlen = comps['testlen']
226 | self._testlen += testlen
227 |
228 | if self.special_reflen is None: ## need computation
229 | reflen = self._single_reflen(comps['reflen'], option, testlen)
230 | else:
231 | reflen = self.special_reflen
232 |
233 | self._reflen += reflen
234 |
235 | for key in ['guess', 'correct']:
236 | for k in range(n):
237 | totalcomps[key][k] += comps[key][k]
238 |
239 | # append per image bleu score
240 | bleu = 1.
241 | for k in range(n):
242 | bleu *= (float(comps['correct'][k]) + tiny) \
243 | / (float(comps['guess'][k]) + small)
244 | bleu_list[k].append(bleu ** (1. / (k + 1)))
245 | ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division
246 | if ratio < 1:
247 | for k in range(n):
248 | bleu_list[k][-1] *= math.exp(1 - 1 / ratio)
249 |
250 | if verbose > 1:
251 | print(comps, reflen)
252 |
253 | totalcomps['reflen'] = self._reflen
254 | totalcomps['testlen'] = self._testlen
255 |
256 | bleus = []
257 | bleu = 1.
258 | for k in range(n):
259 | bleu *= float(totalcomps['correct'][k] + tiny) \
260 | / (totalcomps['guess'][k] + small)
261 | bleus.append(bleu ** (1. / (k + 1)))
262 | ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division
263 | if ratio < 1:
264 | for k in range(n):
265 | bleus[k] *= math.exp(1 - 1 / ratio)
266 |
267 | if verbose > 0:
268 | print(totalcomps)
269 | print("ratio:", ratio)
270 |
271 | self._score = bleus
272 | return self._score, bleu_list
273 |
--------------------------------------------------------------------------------
/meshed-memory-transformer-master/evaluation/cider/__init__.py:
--------------------------------------------------------------------------------
1 | from .cider import Cider
--------------------------------------------------------------------------------
/meshed-memory-transformer-master/evaluation/cider/cider.py:
--------------------------------------------------------------------------------
1 | # Filename: cider.py
2 | #
3 | # Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric
4 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726)
5 | #
6 | # Creation Date: Sun Feb 8 14:16:54 2015
7 | #
8 | # Authors: Ramakrishna Vedantam and Tsung-Yi Lin
9 |
10 | from .cider_scorer import CiderScorer
11 |
12 | class Cider:
13 | """
14 | Main Class to compute the CIDEr metric
15 |
16 | """
17 | def __init__(self, gts=None, n=4, sigma=6.0):
18 | # set cider to sum over 1 to 4-grams
19 | self._n = n
20 | # set the standard deviation parameter for gaussian penalty
21 | self._sigma = sigma
22 | self.doc_frequency = None
23 | self.ref_len = None
24 | if gts is not None:
25 | tmp_cider = CiderScorer(gts, n=self._n, sigma=self._sigma)
26 | self.doc_frequency = tmp_cider.doc_frequency
27 | self.ref_len = tmp_cider.ref_len
28 |
29 | def compute_score(self, gts, res):
30 | """
31 | Main function to compute CIDEr score
32 | :param gts (dict) : dictionary with key and value
33 | res (dict) : dictionary with key and value
34 | :return: cider (float) : computed CIDEr score for the corpus
35 | """
36 | assert(gts.keys() == res.keys())
37 | cider_scorer = CiderScorer(gts, test=res, n=self._n, sigma=self._sigma, doc_frequency=self.doc_frequency,
38 | ref_len=self.ref_len)
39 | return cider_scorer.compute_score()
40 |
41 | def __str__(self):
42 | return 'CIDEr'
43 |
--------------------------------------------------------------------------------
/meshed-memory-transformer-master/evaluation/cider/cider_scorer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Tsung-Yi Lin
3 | # Ramakrishna Vedantam
4 |
5 | import copy
6 | from collections import defaultdict
7 | import numpy as np
8 | import math
9 |
10 | def precook(s, n=4):
11 | """
12 | Takes a string as input and returns an object that can be given to
13 | either cook_refs or cook_test. This is optional: cook_refs and cook_test
14 | can take string arguments as well.
15 | :param s: string : sentence to be converted into ngrams
16 | :param n: int : number of ngrams for which representation is calculated
17 | :return: term frequency vector for occurring ngrams
18 | """
19 | words = s.split()
20 | counts = defaultdict(int)
21 | for k in range(1,n+1):
22 | for i in range(len(words)-k+1):
23 | ngram = tuple(words[i:i+k])
24 | counts[ngram] += 1
25 | return counts
26 |
27 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average"
28 | '''Takes a list of reference sentences for a single segment
29 | and returns an object that encapsulates everything that BLEU
30 | needs to know about them.
31 | :param refs: list of string : reference sentences for some image
32 | :param n: int : number of ngrams for which (ngram) representation is calculated
33 | :return: result (list of dict)
34 | '''
35 | return [precook(ref, n) for ref in refs]
36 |
37 | def cook_test(test, n=4):
38 | '''Takes a test sentence and returns an object that
39 | encapsulates everything that BLEU needs to know about it.
40 | :param test: list of string : hypothesis sentence for some image
41 | :param n: int : number of ngrams for which (ngram) representation is calculated
42 | :return: result (dict)
43 | '''
44 | return precook(test, n)
45 |
46 | class CiderScorer(object):
47 | """CIDEr scorer.
48 | """
49 |
50 | def __init__(self, refs, test=None, n=4, sigma=6.0, doc_frequency=None, ref_len=None):
51 | ''' singular instance '''
52 | self.n = n
53 | self.sigma = sigma
54 | self.crefs = []
55 | self.ctest = []
56 | self.doc_frequency = defaultdict(float)
57 | self.ref_len = None
58 |
59 | for k in refs.keys():
60 | self.crefs.append(cook_refs(refs[k]))
61 | if test is not None:
62 | self.ctest.append(cook_test(test[k][0])) ## N.B.: -1
63 | else:
64 | self.ctest.append(None) # lens of crefs and ctest have to match
65 |
66 | if doc_frequency is None and ref_len is None:
67 | # compute idf
68 | self.compute_doc_freq()
69 | # compute log reference length
70 | self.ref_len = np.log(float(len(self.crefs)))
71 | else:
72 | self.doc_frequency = doc_frequency
73 | self.ref_len = ref_len
74 |
75 | def compute_doc_freq(self):
76 | '''
77 | Compute term frequency for reference data.
78 | This will be used to compute idf (inverse document frequency) later.
79 | The term frequency is stored in the object.
80 | :return: None
81 | '''
82 | for refs in self.crefs:
83 | # refs, k ref captions of one image
84 | for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]):
85 | self.doc_frequency[ngram] += 1
86 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count)
87 |
88 | def compute_cider(self):
89 | def counts2vec(cnts):
90 | """
91 | Function maps counts of ngram to vector of tfidf weights.
92 | The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights.
93 | The n-th entry of array denotes length of n-grams.
94 | :param cnts:
95 | :return: vec (array of dict), norm (array of float), length (int)
96 | """
97 | vec = [defaultdict(float) for _ in range(self.n)]
98 | length = 0
99 | norm = [0.0 for _ in range(self.n)]
100 | for (ngram,term_freq) in cnts.items():
101 | # give word count 1 if it doesn't appear in reference corpus
102 | df = np.log(max(1.0, self.doc_frequency[ngram]))
103 | # ngram index
104 | n = len(ngram)-1
105 | # tf (term_freq) * idf (precomputed idf) for n-grams
106 | vec[n][ngram] = float(term_freq)*(self.ref_len - df)
107 | # compute norm for the vector. the norm will be used for computing similarity
108 | norm[n] += pow(vec[n][ngram], 2)
109 |
110 | if n == 1:
111 | length += term_freq
112 | norm = [np.sqrt(n) for n in norm]
113 | return vec, norm, length
114 |
115 | def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref):
116 | '''
117 | Compute the cosine similarity of two vectors.
118 | :param vec_hyp: array of dictionary for vector corresponding to hypothesis
119 | :param vec_ref: array of dictionary for vector corresponding to reference
120 | :param norm_hyp: array of float for vector corresponding to hypothesis
121 | :param norm_ref: array of float for vector corresponding to reference
122 | :param length_hyp: int containing length of hypothesis
123 | :param length_ref: int containing length of reference
124 | :return: array of score for each n-grams cosine similarity
125 | '''
126 | delta = float(length_hyp - length_ref)
127 | # measure cosine similarity
128 | val = np.array([0.0 for _ in range(self.n)])
129 | for n in range(self.n):
130 | # ngram
131 | for (ngram,count) in vec_hyp[n].items():
132 | # vrama91 : added clipping
133 | val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram]
134 |
135 | if (norm_hyp[n] != 0) and (norm_ref[n] != 0):
136 | val[n] /= (norm_hyp[n]*norm_ref[n])
137 |
138 | assert(not math.isnan(val[n]))
139 | # vrama91: added a length based gaussian penalty
140 | val[n] *= np.e**(-(delta**2)/(2*self.sigma**2))
141 | return val
142 |
143 | scores = []
144 | for test, refs in zip(self.ctest, self.crefs):
145 | # compute vector for test captions
146 | vec, norm, length = counts2vec(test)
147 | # compute vector for ref captions
148 | score = np.array([0.0 for _ in range(self.n)])
149 | for ref in refs:
150 | vec_ref, norm_ref, length_ref = counts2vec(ref)
151 | score += sim(vec, vec_ref, norm, norm_ref, length, length_ref)
152 | # change by vrama91 - mean of ngram scores, instead of sum
153 | score_avg = np.mean(score)
154 | # divide by number of references
155 | score_avg /= len(refs)
156 | # multiply score by 10
157 | score_avg *= 10.0
158 | # append score of an image to the score list
159 | scores.append(score_avg)
160 | return scores
161 |
162 | def compute_score(self):
163 | # compute cider score
164 | score = self.compute_cider()
165 | # debug
166 | # print score
167 | return np.mean(np.array(score)), np.array(score)
--------------------------------------------------------------------------------
/meshed-memory-transformer-master/evaluation/meteor/__init__.py:
--------------------------------------------------------------------------------
1 | from .meteor import Meteor
--------------------------------------------------------------------------------
/meshed-memory-transformer-master/evaluation/meteor/meteor.py:
--------------------------------------------------------------------------------
1 | # Python wrapper for METEOR implementation, by Xinlei Chen
2 | # Acknowledge Michael Denkowski for the generous discussion and help
3 |
4 | import os
5 | import subprocess
6 | import threading
7 | import tarfile
8 | from utils import download_from_url
9 |
10 | METEOR_GZ_URL = 'http://aimagelab.ing.unimore.it/speaksee/data/meteor.tgz'
11 | METEOR_JAR = 'meteor-1.5.jar'
12 |
13 | class Meteor:
14 | def __init__(self):
15 | base_path = os.path.dirname(os.path.abspath(__file__))
16 | jar_path = os.path.join(base_path, METEOR_JAR)
17 | gz_path = os.path.join(base_path, os.path.basename(METEOR_GZ_URL))
18 | if not os.path.isfile(jar_path):
19 | if not os.path.isfile(gz_path):
20 | download_from_url(METEOR_GZ_URL, gz_path)
21 | tar = tarfile.open(gz_path, "r")
22 | tar.extractall(path=os.path.dirname(os.path.abspath(__file__)))
23 | tar.close()
24 | os.remove(gz_path)
25 |
26 | self.meteor_cmd = ['java', '-jar', '-Xmx2G', METEOR_JAR, \
27 | '-', '-', '-stdio', '-l', 'en', '-norm']
28 | self.meteor_p = subprocess.Popen(self.meteor_cmd, \
29 | cwd=os.path.dirname(os.path.abspath(__file__)), \
30 | stdin=subprocess.PIPE, \
31 | stdout=subprocess.PIPE, \
32 | stderr=subprocess.PIPE)
33 | # Used to guarantee thread safety
34 | self.lock = threading.Lock()
35 |
36 | def compute_score(self, gts, res):
37 | assert(gts.keys() == res.keys())
38 | imgIds = gts.keys()
39 | scores = []
40 |
41 | eval_line = 'EVAL'
42 | self.lock.acquire()
43 | for i in imgIds:
44 | assert(len(res[i]) == 1)
45 | stat = self._stat(res[i][0], gts[i])
46 | eval_line += ' ||| {}'.format(stat)
47 |
48 | self.meteor_p.stdin.write('{}\n'.format(eval_line).encode())
49 | self.meteor_p.stdin.flush()
50 | for i in range(0,len(imgIds)):
51 | scores.append(float(self.meteor_p.stdout.readline().strip()))
52 | score = float(self.meteor_p.stdout.readline().strip())
53 | self.lock.release()
54 |
55 | return score, scores
56 |
57 | def _stat(self, hypothesis_str, reference_list):
58 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words
59 | hypothesis_str = hypothesis_str.replace('|||','').replace('  ', ' ')
60 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str))
61 | self.meteor_p.stdin.write('{}\n'.format(score_line).encode())
62 | self.meteor_p.stdin.flush()
63 | raw = self.meteor_p.stdout.readline().decode().strip()
64 | numbers = [str(int(float(n))) for n in raw.split()]
65 | return ' '.join(numbers)
66 |
67 | def __del__(self):
68 | self.lock.acquire()
69 | self.meteor_p.stdin.close()
70 | self.meteor_p.kill()
71 | self.meteor_p.wait()
72 | self.lock.release()
73 |
74 | def __str__(self):
75 | return 'METEOR'
76 |
--------------------------------------------------------------------------------
/meshed-memory-transformer-master/evaluation/rouge/__init__.py:
--------------------------------------------------------------------------------
1 | from .rouge import Rouge
--------------------------------------------------------------------------------
/meshed-memory-transformer-master/evaluation/rouge/rouge.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #
3 | # File Name : rouge.py
4 | #
5 | # Description : Computes ROUGE-L metric as described by Lin and Hovy (2004)
6 | #
7 | # Creation Date : 2015-01-07 06:03
8 | # Author : Ramakrishna Vedantam
9 |
10 | import numpy as np
11 | import pdb
12 |
13 |
14 | def my_lcs(string, sub):
15 | """
16 | Calculates longest common subsequence for a pair of tokenized strings
17 | :param string : list of str : tokens from a string split using whitespace
18 | :param sub : list of str : shorter string, also split using whitespace
19 | :returns: length (list of int): length of the longest common subsequence between the two strings
20 |
21 | Note: my_lcs only gives length of the longest common subsequence, not the actual LCS
22 | """
23 | if (len(string) < len(sub)):
24 | sub, string = string, sub
25 |
26 | lengths = [[0 for i in range(0, len(sub) + 1)] for j in range(0, len(string) + 1)]
27 |
28 | for j in range(1, len(sub) + 1):
29 | for i in range(1, len(string) + 1):
30 | if (string[i - 1] == sub[j - 1]):
31 | lengths[i][j] = lengths[i - 1][j - 1] + 1
32 | else:
33 | lengths[i][j] = max(lengths[i - 1][j], lengths[i][j - 1])
34 |
35 | return lengths[len(string)][len(sub)]
36 |
37 |
38 | class Rouge():
39 | '''
40 | Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set
41 |
42 | '''
43 |
44 | def __init__(self):
45 | # vrama91: updated the value below based on discussion with Hovey
46 | self.beta = 1.2
47 |
48 | def calc_score(self, candidate, refs):
49 | """
50 | Compute ROUGE-L score given one candidate and references for an image
51 | :param candidate: str : candidate sentence to be evaluated
52 | :param refs: list of str : COCO reference sentences for the particular image to be evaluated
53 | :returns score: int (ROUGE-L score for the candidate evaluated against references)
54 | """
55 | assert (len(candidate) == 1)
56 | assert (len(refs) > 0)
57 | prec = []
58 | rec = []
59 |
60 | # split into tokens
61 | token_c = candidate[0].split(" ")
62 |
63 | for reference in refs:
64 | # split into tokens
65 | token_r = reference.split(" ")
66 | # compute the longest common subsequence
67 | lcs = my_lcs(token_r, token_c)
68 | prec.append(lcs / float(len(token_c)))
69 | rec.append(lcs / float(len(token_r)))
70 |
71 | prec_max = max(prec)
72 | rec_max = max(rec)
73 |
74 | if (prec_max != 0 and rec_max != 0):
75 | score = ((1 + self.beta ** 2) * prec_max * rec_max) / float(rec_max + self.beta ** 2 * prec_max)
76 | else:
77 | score = 0.0
78 | return score
79 |
80 | def compute_score(self, gts, res):
81 | """
82 | Computes Rouge-L score given a set of reference and candidate sentences for the dataset
83 | Invoked by evaluate_captions.py
84 | :param hypo_for_image: dict : candidate / test sentences with "image name" key and "tokenized sentences" as values
85 | :param ref_for_image: dict : reference MS-COCO sentences with "image name" key and "tokenized sentences" as values
86 | :returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images)
87 | """
88 | assert (gts.keys() == res.keys())
89 | imgIds = gts.keys()
90 |
91 | score = []
92 | for id in imgIds:
93 | hypo = res[id]
94 | ref = gts[id]
95 |
96 | score.append(self.calc_score(hypo, ref))
97 |
98 | # Sanity check.
99 | assert (type(hypo) is list)
100 | assert (len(hypo) == 1)
101 | assert (type(ref) is list)
102 | assert (len(ref) > 0)
103 |
104 | average_score = np.mean(np.array(score))
105 | return average_score, np.array(score)
106 |
107 | def __str__(self):
108 | return 'ROUGE'
109 |
--------------------------------------------------------------------------------
/meshed-memory-transformer-master/evaluation/tokenizer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #
3 | # File Name : ptbtokenizer.py
4 | #
5 | # Description : Do the PTB Tokenization and remove punctuations.
6 | #
7 | # Creation Date : 29-12-2014
8 | # Last Modified : Thu Mar 19 09:53:35 2015
9 | # Authors : Hao Fang and Tsung-Yi Lin
10 |
11 | import os
12 | import subprocess
13 | import tempfile
14 |
15 | class PTBTokenizer(object):
16 | """Python wrapper of Stanford PTBTokenizer"""
17 |
18 | corenlp_jar = 'stanford-corenlp-3.4.1.jar'
19 | punctuations = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", \
20 | ".", "?", "!", ",", ":", "-", "--", "...", ";"]
21 |
22 | @classmethod
23 | def tokenize(cls, corpus):
24 | cmd = ['java', '-cp', cls.corenlp_jar, \
25 | 'edu.stanford.nlp.process.PTBTokenizer', \
26 | '-preserveLines', '-lowerCase']
27 |
28 | if isinstance(corpus, list) or isinstance(corpus, tuple):
29 | if isinstance(corpus[0], list) or isinstance(corpus[0], tuple):
30 | corpus = {i:c for i, c in enumerate(corpus)}
31 | else:
32 | corpus = {i: [c, ] for i, c in enumerate(corpus)}
33 |
34 | # prepare data for PTB Tokenizer
35 | tokenized_corpus = {}
36 | image_id = [k for k, v in list(corpus.items()) for _ in range(len(v))]
37 | sentences = '\n'.join([c.replace('\n', ' ') for k, v in corpus.items() for c in v])
38 |
39 | # save sentences to temporary file
40 | path_to_jar_dirname=os.path.dirname(os.path.abspath(__file__))
41 | tmp_file = tempfile.NamedTemporaryFile(delete=False, dir=path_to_jar_dirname)
42 | tmp_file.write(sentences.encode())
43 | tmp_file.close()
44 |
45 | # tokenize sentence
46 | cmd.append(os.path.basename(tmp_file.name))
47 | p_tokenizer = subprocess.Popen(cmd, cwd=path_to_jar_dirname, \
48 | stdout=subprocess.PIPE, stderr=open(os.devnull, 'w'))
49 | token_lines = p_tokenizer.communicate(input=sentences.rstrip())[0]
50 | token_lines = token_lines.decode()
51 | lines = token_lines.split('\n')
52 | # remove temp file
53 | os.remove(tmp_file.name)
54 |
55 | # create dictionary for tokenized captions
56 | for k, line in zip(image_id, lines):
57 | if not k in tokenized_corpus:
58 | tokenized_corpus[k] = []
59 | tokenized_caption = ' '.join([w for w in line.rstrip().split(' ') \
60 | if w not in cls.punctuations])
61 | tokenized_corpus[k].append(tokenized_caption)
62 |
63 | return tokenized_corpus
--------------------------------------------------------------------------------
/meshed-memory-transformer-master/images/m2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixmath/CVR/af2ea6859230c1a5f56868a2242c06c2fe02f05a/meshed-memory-transformer-master/images/m2.png
--------------------------------------------------------------------------------
/meshed-memory-transformer-master/images/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixmath/CVR/af2ea6859230c1a5f56868a2242c06c2fe02f05a/meshed-memory-transformer-master/images/results.png
--------------------------------------------------------------------------------
/meshed-memory-transformer-master/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .transformer import Transformer
2 | from .captioning_model import CaptioningModel
3 |
--------------------------------------------------------------------------------
/meshed-memory-transformer-master/models/beam_search/__init__.py:
--------------------------------------------------------------------------------
1 | from .beam_search import BeamSearch
2 |
--------------------------------------------------------------------------------
/meshed-memory-transformer-master/models/beam_search/beam_search.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import utils
3 |
4 |
5 | class BeamSearch(object):
6 | def __init__(self, model, max_len: int, eos_idx: int, beam_size: int):
7 | self.model = model
8 | self.max_len = max_len
9 | self.eos_idx = eos_idx
10 | self.beam_size = beam_size
11 | self.b_s = None
12 | self.device = None
13 | self.seq_mask = None
14 | self.seq_logprob = None
15 | self.outputs = None
16 | self.log_probs = None
17 | self.selected_words = None
18 | self.all_log_probs = None
19 |
20 | def _expand_state(self, selected_beam, cur_beam_size):
21 | def fn(s):
22 | shape = [int(sh) for sh in s.shape]
23 | beam = selected_beam
24 | for _ in shape[1:]:
25 | beam = beam.unsqueeze(-1)
26 | s = torch.gather(s.view(*([self.b_s, cur_beam_size] + shape[1:])), 1,
27 | beam.expand(*([self.b_s, self.beam_size] + shape[1:])))
28 | s = s.view(*([-1, ] + shape[1:]))
29 | return s
30 |
31 | return fn
32 |
33 | def _expand_visual(self, visual: utils.TensorOrSequence, cur_beam_size: int, selected_beam: torch.Tensor):
34 | if isinstance(visual, torch.Tensor):
35 | visual_shape = visual.shape
36 | visual_exp_shape = (self.b_s, cur_beam_size) + visual_shape[1:]
37 | visual_red_shape = (self.b_s * self.beam_size,) + visual_shape[1:]
38 | selected_beam_red_size = (self.b_s, self.beam_size) + tuple(1 for _ in range(len(visual_exp_shape) - 2))
39 | selected_beam_exp_size = (self.b_s, self.beam_size) + visual_exp_shape[2:]
40 | visual_exp = visual.view(visual_exp_shape)
41 | selected_beam_exp = selected_beam.view(selected_beam_red_size).expand(selected_beam_exp_size)
42 | visual = torch.gather(visual_exp, 1, selected_beam_exp).view(visual_red_shape)
43 | else:
44 | new_visual = []
45 | for im in visual:
46 | visual_shape = im.shape
47 | visual_exp_shape = (self.b_s, cur_beam_size) + visual_shape[1:]
48 | visual_red_shape = (self.b_s * self.beam_size,) + visual_shape[1:]
49 | selected_beam_red_size = (self.b_s, self.beam_size) + tuple(1 for _ in range(len(visual_exp_shape) - 2))
50 | selected_beam_exp_size = (self.b_s, self.beam_size) + visual_exp_shape[2:]
51 | visual_exp = im.view(visual_exp_shape)
52 | selected_beam_exp = selected_beam.view(selected_beam_red_size).expand(selected_beam_exp_size)
53 | new_im = torch.gather(visual_exp, 1, selected_beam_exp).view(visual_red_shape)
54 | new_visual.append(new_im)
55 | visual = tuple(new_visual)
56 | return visual
57 |
58 | def apply(self, visual: utils.TensorOrSequence, out_size=1, return_probs=False, **kwargs):
59 | self.b_s = utils.get_batch_size(visual)
60 | self.device = utils.get_device(visual)
61 | self.seq_mask = torch.ones((self.b_s, self.beam_size, 1), device=self.device)
62 | self.seq_logprob = torch.zeros((self.b_s, 1, 1), device=self.device)
63 | self.log_probs = []
64 | self.selected_words = None
65 | if return_probs:
66 | self.all_log_probs = []
67 |
68 | outputs = []
69 | with self.model.statefulness(self.b_s):
70 | for t in range(self.max_len):
71 | visual, outputs = self.iter(t, visual, outputs, return_probs, **kwargs)
72 |
73 | # Sort result
74 | seq_logprob, sort_idxs = torch.sort(self.seq_logprob, 1, descending=True)
75 | outputs = torch.cat(outputs, -1)
76 | outputs = torch.gather(outputs, 1, sort_idxs.expand(self.b_s, self.beam_size, self.max_len))
77 | log_probs = torch.cat(self.log_probs, -1)
78 | log_probs = torch.gather(log_probs, 1, sort_idxs.expand(self.b_s, self.beam_size, self.max_len))
79 | if return_probs:
80 | all_log_probs = torch.cat(self.all_log_probs, 2)
81 | all_log_probs = torch.gather(all_log_probs, 1, sort_idxs.unsqueeze(-1).expand(self.b_s, self.beam_size,
82 | self.max_len,
83 | all_log_probs.shape[-1]))
84 |
85 | outputs = outputs.contiguous()[:, :out_size]
86 | log_probs = log_probs.contiguous()[:, :out_size]
87 | if out_size == 1:
88 | outputs = outputs.squeeze(1)
89 | log_probs = log_probs.squeeze(1)
90 |
91 | if return_probs:
92 | return outputs, log_probs, all_log_probs
93 | else:
94 | return outputs, log_probs
95 |
96 | def select(self, t, candidate_logprob, **kwargs):
97 | selected_logprob, selected_idx = torch.sort(candidate_logprob.view(self.b_s, -1), -1, descending=True)
98 | selected_logprob, selected_idx = selected_logprob[:, :self.beam_size], selected_idx[:, :self.beam_size]
99 | return selected_idx, selected_logprob
100 |
101 | def iter(self, t: int, visual: utils.TensorOrSequence, outputs, return_probs, **kwargs):
102 | cur_beam_size = 1 if t == 0 else self.beam_size
103 |
104 | word_logprob = self.model.step(t, self.selected_words, visual, None, mode='feedback', **kwargs)
105 | word_logprob = word_logprob.view(self.b_s, cur_beam_size, -1)
106 | candidate_logprob = self.seq_logprob + word_logprob
107 |
108 | # Mask sequence if it reaches EOS
109 | if t > 0:
110 | mask = (self.selected_words.view(self.b_s, cur_beam_size) != self.eos_idx).float().unsqueeze(-1)
111 | self.seq_mask = self.seq_mask * mask
112 | word_logprob = word_logprob * self.seq_mask.expand_as(word_logprob)
113 | old_seq_logprob = self.seq_logprob.expand_as(candidate_logprob).contiguous()
114 | old_seq_logprob[:, :, 1:] = -999
115 | candidate_logprob = self.seq_mask * candidate_logprob + old_seq_logprob * (1 - self.seq_mask)
116 |
117 | selected_idx, selected_logprob = self.select(t, candidate_logprob, **kwargs)
118 | selected_beam = selected_idx / candidate_logprob.shape[-1]
119 | selected_words = selected_idx - selected_beam * candidate_logprob.shape[-1]
120 |
121 | self.model.apply_to_states(self._expand_state(selected_beam, cur_beam_size))
122 | visual = self._expand_visual(visual, cur_beam_size, selected_beam)
123 |
124 | self.seq_logprob = selected_logprob.unsqueeze(-1)
125 | self.seq_mask = torch.gather(self.seq_mask, 1, selected_beam.unsqueeze(-1))
126 | outputs = list(torch.gather(o, 1, selected_beam.unsqueeze(-1)) for o in outputs)
127 | outputs.append(selected_words.unsqueeze(-1))
128 |
129 | if return_probs:
130 | if t == 0:
131 | self.all_log_probs.append(word_logprob.expand((self.b_s, self.beam_size, -1)).unsqueeze(2))
132 | else:
133 | self.all_log_probs.append(word_logprob.unsqueeze(2))
134 |
135 | this_word_logprob = torch.gather(word_logprob, 1,
136 | selected_beam.unsqueeze(-1).expand(self.b_s, self.beam_size,
137 | word_logprob.shape[-1]))
138 | this_word_logprob = torch.gather(this_word_logprob, 2, selected_words.unsqueeze(-1))
139 | self.log_probs = list(
140 | torch.gather(o, 1, selected_beam.unsqueeze(-1).expand(self.b_s, self.beam_size, 1)) for o in self.log_probs)
141 | self.log_probs.append(this_word_logprob)
142 | self.selected_words = selected_words.view(-1, 1)
143 |
144 | return visual, outputs
145 |
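
Note (annotation, not part of the repository): a minimal standalone sketch of the index arithmetic used in `BeamSearch.iter` — the flat top-k index over `beam_size * vocab_size` candidates is split back into a beam index and a word index with floor division and remainder. `b_s`, `beam_size` and `vocab_size` are toy values.

```python
import torch

b_s, beam_size, vocab_size = 2, 3, 10
candidate_logprob = torch.randn(b_s, beam_size, vocab_size)

flat = candidate_logprob.view(b_s, -1)                       # (b_s, beam_size * vocab_size)
sorted_logprob, selected_idx = flat.sort(-1, descending=True)
selected_idx = selected_idx[:, :beam_size]                   # keep the top beam_size candidates

selected_beam = selected_idx // vocab_size                   # beam each candidate extends
selected_words = selected_idx - selected_beam * vocab_size   # word each candidate appends

# The recovered (beam, word) pairs index the same values that sorting produced.
recovered = candidate_logprob[torch.arange(b_s).unsqueeze(1), selected_beam, selected_words]
assert torch.equal(recovered, sorted_logprob[:, :beam_size])
```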
--------------------------------------------------------------------------------
/meshed-memory-transformer-master/models/captioning_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import distributions
3 | import utils
4 | from models.containers import Module
5 | from models.beam_search import *
6 |
7 |
8 | class CaptioningModel(Module):
9 | def __init__(self):
10 | super(CaptioningModel, self).__init__()
11 |
12 | def init_weights(self):
13 | raise NotImplementedError
14 |
15 | def step(self, t, prev_output, visual, seq, mode='teacher_forcing', **kwargs):
16 | raise NotImplementedError
17 |
18 | def forward(self, images, seq, *args):
19 | device = images.device
20 | b_s = images.size(0)
21 | seq_len = seq.size(1)
22 | state = self.init_state(b_s, device)
23 | out = None
24 |
25 | outputs = []
26 | for t in range(seq_len):
27 | out, state = self.step(t, state, out, images, seq, *args, mode='teacher_forcing')
28 | outputs.append(out)
29 |
30 | outputs = torch.cat([o.unsqueeze(1) for o in outputs], 1)
31 | return outputs
32 |
33 | def test(self, visual: utils.TensorOrSequence, max_len: int, eos_idx: int, **kwargs) -> utils.Tuple[torch.Tensor, torch.Tensor]:
34 | b_s = utils.get_batch_size(visual)
35 | device = utils.get_device(visual)
36 | outputs = []
37 | log_probs = []
38 |
39 | mask = torch.ones((b_s,), device=device)
40 | with self.statefulness(b_s):
41 | out = None
42 | for t in range(max_len):
43 | log_probs_t = self.step(t, out, visual, None, mode='feedback', **kwargs)
44 | out = torch.max(log_probs_t, -1)[1]
45 | mask = mask * (out.squeeze(-1) != eos_idx).float()
46 | log_probs.append(log_probs_t * mask.unsqueeze(-1).unsqueeze(-1))
47 | outputs.append(out)
48 |
49 | return torch.cat(outputs, 1), torch.cat(log_probs, 1)
50 |
51 | def sample_rl(self, visual: utils.TensorOrSequence, max_len: int, **kwargs) -> utils.Tuple[torch.Tensor, torch.Tensor]:
52 | b_s = utils.get_batch_size(visual)
53 | outputs = []
54 | log_probs = []
55 |
56 | with self.statefulness(b_s):
57 | out = None
58 | for t in range(max_len):
59 | out = self.step(t, out, visual, None, mode='feedback', **kwargs)
60 | distr = distributions.Categorical(logits=out[:, 0])
61 | out = distr.sample().unsqueeze(1)
62 | outputs.append(out)
63 | log_probs.append(distr.log_prob(out).unsqueeze(1))
64 |
65 | return torch.cat(outputs, 1), torch.cat(log_probs, 1)
66 |
67 | def beam_search(self, visual: utils.TensorOrSequence, max_len: int, eos_idx: int, beam_size: int, out_size=1,
68 | return_probs=False, **kwargs):
69 | bs = BeamSearch(self, max_len, eos_idx, beam_size)
70 | return bs.apply(visual, out_size, return_probs, **kwargs)
71 |
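
Note (annotation, not part of the repository): `sample_rl` relies on `torch.distributions.Categorical` to both draw the next word and expose its log-probability for a REINFORCE-style objective. A standalone sketch with assumed toy shapes:

```python
import torch
from torch import distributions

b_s, vocab_size = 4, 100
word_logprob = torch.log_softmax(torch.randn(b_s, vocab_size), dim=-1)

distr = distributions.Categorical(logits=word_logprob)
sampled = distr.sample()         # (b_s,) word indices drawn from the predicted distribution
logp = distr.log_prob(sampled)   # (b_s,) log-probabilities of the sampled words

# In self-critical training these per-word log-probs would be weighted by a
# reward signal, e.g. loss = -(reward * logp).mean()
```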
--------------------------------------------------------------------------------
/meshed-memory-transformer-master/models/containers.py:
--------------------------------------------------------------------------------
1 | from contextlib import contextmanager
2 | from torch import nn
3 | from utils.typing import *
4 |
5 |
6 | class Module(nn.Module):
7 | def __init__(self):
8 | super(Module, self).__init__()
9 | self._is_stateful = False
10 | self._state_names = []
11 | self._state_defaults = dict()
12 |
13 | def register_state(self, name: str, default: TensorOrNone):
14 | self._state_names.append(name)
15 | if default is None:
16 | self._state_defaults[name] = None
17 | else:
18 | self._state_defaults[name] = default.clone().detach()
19 | self.register_buffer(name, default)
20 |
21 | def states(self):
22 | for name in self._state_names:
23 | yield self._buffers[name]
24 | for m in self.children():
25 | if isinstance(m, Module):
26 | yield from m.states()
27 |
28 | def apply_to_states(self, fn):
29 | for name in self._state_names:
30 | self._buffers[name] = fn(self._buffers[name])
31 | for m in self.children():
32 | if isinstance(m, Module):
33 | m.apply_to_states(fn)
34 |
35 | def _init_states(self, batch_size: int):
36 | for name in self._state_names:
37 | if self._state_defaults[name] is None:
38 | self._buffers[name] = None
39 | else:
40 | self._buffers[name] = self._state_defaults[name].clone().detach().to(self._buffers[name].device)
41 | self._buffers[name] = self._buffers[name].unsqueeze(0)
42 | self._buffers[name] = self._buffers[name].expand([batch_size, ] + list(self._buffers[name].shape[1:]))
43 | self._buffers[name] = self._buffers[name].contiguous()
44 |
45 | def _reset_states(self):
46 | for name in self._state_names:
47 | if self._state_defaults[name] is None:
48 | self._buffers[name] = None
49 | else:
50 | self._buffers[name] = self._state_defaults[name].clone().detach().to(self._buffers[name].device)
51 |
52 | def enable_statefulness(self, batch_size: int):
53 | for m in self.children():
54 | if isinstance(m, Module):
55 | m.enable_statefulness(batch_size)
56 | self._init_states(batch_size)
57 | self._is_stateful = True
58 |
59 | def disable_statefulness(self):
60 | for m in self.children():
61 | if isinstance(m, Module):
62 | m.disable_statefulness()
63 | self._reset_states()
64 | self._is_stateful = False
65 |
66 | @contextmanager
67 | def statefulness(self, batch_size: int):
68 | self.enable_statefulness(batch_size)
69 | try:
70 | yield
71 | finally:
72 | self.disable_statefulness()
73 |
74 |
75 | class ModuleList(nn.ModuleList, Module):
76 | pass
77 |
78 |
79 | class ModuleDict(nn.ModuleDict, Module):
80 | pass
81 |
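
Note (annotation, not part of the repository): a minimal sketch of the stateful-`Module` pattern, assuming the repository root is on `PYTHONPATH`; `RunningSum` is made up. A buffer registered with `register_state` is expanded to the batch size when `statefulness` is entered, can be updated step by step, and is reset to its default when the context exits.

```python
import torch
from models.containers import Module

class RunningSum(Module):
    def __init__(self):
        super(RunningSum, self).__init__()
        self.register_state('running_total', torch.zeros((1,)))

    def forward(self, x):
        if self._is_stateful:
            self.running_total = self.running_total + x  # update the registered buffer
        return self.running_total

m = RunningSum()
with m.statefulness(batch_size=2):
    for _ in range(3):
        out = m(torch.ones((2, 1)))
print(out)  # tensor([[3.], [3.]]) -- accumulated over the three steps
```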
--------------------------------------------------------------------------------
/meshed-memory-transformer-master/models/transformer/__init__.py:
--------------------------------------------------------------------------------
1 | from .transformer import *
2 | from .encoders import *
3 | from .decoders import *
4 | from .attention import *
5 |
--------------------------------------------------------------------------------
/meshed-memory-transformer-master/models/transformer/attention.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 | from models.containers import Module
5 |
6 |
7 | class ScaledDotProductAttention(nn.Module):
8 | '''
9 | Scaled dot-product attention
10 | '''
11 |
12 | def __init__(self, d_model, d_k, d_v, h):
13 | '''
14 | :param d_model: Output dimensionality of the model
15 | :param d_k: Dimensionality of queries and keys
16 | :param d_v: Dimensionality of values
17 | :param h: Number of heads
18 | '''
19 | super(ScaledDotProductAttention, self).__init__()
20 | self.fc_q = nn.Linear(d_model, h * d_k)
21 | self.fc_k = nn.Linear(d_model, h * d_k)
22 | self.fc_v = nn.Linear(d_model, h * d_v)
23 | self.fc_o = nn.Linear(h * d_v, d_model)
24 |
25 | self.d_model = d_model
26 | self.d_k = d_k
27 | self.d_v = d_v
28 | self.h = h
29 |
30 | self.init_weights()
31 |
32 | def init_weights(self):
33 | nn.init.xavier_uniform_(self.fc_q.weight)
34 | nn.init.xavier_uniform_(self.fc_k.weight)
35 | nn.init.xavier_uniform_(self.fc_v.weight)
36 | nn.init.xavier_uniform_(self.fc_o.weight)
37 | nn.init.constant_(self.fc_q.bias, 0)
38 | nn.init.constant_(self.fc_k.bias, 0)
39 | nn.init.constant_(self.fc_v.bias, 0)
40 | nn.init.constant_(self.fc_o.bias, 0)
41 |
42 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
43 | '''
44 | Computes scaled dot-product attention over the projected queries, keys and values.
45 | :param queries: Queries (b_s, nq, d_model)
46 | :param keys: Keys (b_s, nk, d_model)
47 | :param values: Values (b_s, nk, d_model)
48 | :param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking.
49 | :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk).
50 | :return:
51 | '''
52 | b_s, nq = queries.shape[:2]
53 | nk = keys.shape[1]
54 | q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3) # (b_s, h, nq, d_k)
55 | k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1) # (b_s, h, d_k, nk)
56 | v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3) # (b_s, h, nk, d_v)
57 |
58 | att = torch.matmul(q, k) / np.sqrt(self.d_k) # (b_s, h, nq, nk)
59 | if attention_weights is not None:
60 | att = att * attention_weights
61 | if attention_mask is not None:
62 | att = att.masked_fill(attention_mask, -np.inf)
63 | att = torch.softmax(att, -1)
64 | out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v) # (b_s, nq, h*d_v)
65 | out = self.fc_o(out) # (b_s, nq, d_model)
66 | return out
67 |
68 |
69 | class ScaledDotProductAttentionMemory(nn.Module):
70 | '''
71 | Scaled dot-product attention with memory
72 | '''
73 |
74 | def __init__(self, d_model, d_k, d_v, h, m):
75 | '''
76 | :param d_model: Output dimensionality of the model
77 | :param d_k: Dimensionality of queries and keys
78 | :param d_v: Dimensionality of values
79 | :param h: Number of heads
80 | :param m: Number of memory slots
81 | '''
82 | super(ScaledDotProductAttentionMemory, self).__init__()
83 | self.fc_q = nn.Linear(d_model, h * d_k)
84 | self.fc_k = nn.Linear(d_model, h * d_k)
85 | self.fc_v = nn.Linear(d_model, h * d_v)
86 | self.fc_o = nn.Linear(h * d_v, d_model)
87 | self.m_k = nn.Parameter(torch.FloatTensor(1, m, h * d_k))
88 | self.m_v = nn.Parameter(torch.FloatTensor(1, m, h * d_v))
89 |
90 | self.d_model = d_model
91 | self.d_k = d_k
92 | self.d_v = d_v
93 | self.h = h
94 | self.m = m
95 |
96 | self.init_weights()
97 |
98 | def init_weights(self):
99 | nn.init.xavier_uniform_(self.fc_q.weight)
100 | nn.init.xavier_uniform_(self.fc_k.weight)
101 | nn.init.xavier_uniform_(self.fc_v.weight)
102 | nn.init.xavier_uniform_(self.fc_o.weight)
103 | nn.init.normal_(self.m_k, 0, 1 / self.d_k)
104 | nn.init.normal_(self.m_v, 0, 1 / self.m)
105 | nn.init.constant_(self.fc_q.bias, 0)
106 | nn.init.constant_(self.fc_k.bias, 0)
107 | nn.init.constant_(self.fc_v.bias, 0)
108 | nn.init.constant_(self.fc_o.bias, 0)
109 |
110 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
111 | '''
112 | Computes scaled dot-product attention, with m learned memory slots appended to the keys and values.
113 | :param queries: Queries (b_s, nq, d_model)
114 | :param keys: Keys (b_s, nk, d_model)
115 | :param values: Values (b_s, nk, d_model)
116 | :param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking.
117 | :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk).
118 | :return:
119 | '''
120 | b_s, nq = queries.shape[:2]
121 | nk = keys.shape[1]
122 |
123 | m_k = np.sqrt(self.d_k) * self.m_k.expand(b_s, self.m, self.h * self.d_k)
124 | m_v = np.sqrt(self.m) * self.m_v.expand(b_s, self.m, self.h * self.d_v)
125 |
126 | q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3) # (b_s, h, nq, d_k)
127 | k = torch.cat([self.fc_k(keys), m_k], 1).view(b_s, nk + self.m, self.h, self.d_k).permute(0, 2, 3, 1) # (b_s, h, d_k, nk+m)
128 | v = torch.cat([self.fc_v(values), m_v], 1).view(b_s, nk + self.m, self.h, self.d_v).permute(0, 2, 1, 3) # (b_s, h, nk+m, d_v)
129 |
130 | att = torch.matmul(q, k) / np.sqrt(self.d_k) # (b_s, h, nq, nk+m)
131 | if attention_weights is not None:
132 | att = torch.cat([att[:, :, :, :nk] * attention_weights, att[:, :, :, nk:]], -1)
133 | if attention_mask is not None:
134 | att[:, :, :, :nk] = att[:, :, :, :nk].masked_fill(attention_mask, -np.inf)
135 | att = torch.softmax(att, -1)
136 | out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v) # (b_s, nq, h*d_v)
137 | out = self.fc_o(out) # (b_s, nq, d_model)
138 | return out
139 |
140 |
141 | class MultiHeadAttention(Module):
142 | '''
143 | Multi-head attention layer with Dropout and Layer Normalization.
144 | '''
145 |
146 | def __init__(self, d_model, d_k, d_v, h, dropout=.1, identity_map_reordering=False, can_be_stateful=False,
147 | attention_module=None, attention_module_kwargs=None):
148 | super(MultiHeadAttention, self).__init__()
149 | self.identity_map_reordering = identity_map_reordering
150 | if attention_module is not None:
151 | if attention_module_kwargs is not None:
152 | self.attention = attention_module(d_model=d_model, d_k=d_k, d_v=d_v, h=h, **attention_module_kwargs)
153 | else:
154 | self.attention = attention_module(d_model=d_model, d_k=d_k, d_v=d_v, h=h)
155 | else:
156 | self.attention = ScaledDotProductAttention(d_model=d_model, d_k=d_k, d_v=d_v, h=h)
157 | self.dropout = nn.Dropout(p=dropout)
158 | self.layer_norm = nn.LayerNorm(d_model)
159 |
160 | self.can_be_stateful = can_be_stateful
161 | if self.can_be_stateful:
162 | self.register_state('running_keys', torch.zeros((0, d_model)))
163 | self.register_state('running_values', torch.zeros((0, d_model)))
164 |
165 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
166 | if self.can_be_stateful and self._is_stateful:
167 | self.running_keys = torch.cat([self.running_keys, keys], 1)
168 | keys = self.running_keys
169 |
170 | self.running_values = torch.cat([self.running_values, values], 1)
171 | values = self.running_values
172 |
173 | if self.identity_map_reordering:
174 | q_norm = self.layer_norm(queries)
175 | k_norm = self.layer_norm(keys)
176 | v_norm = self.layer_norm(values)
177 | out = self.attention(q_norm, k_norm, v_norm, attention_mask, attention_weights)
178 | out = queries + self.dropout(torch.relu(out))
179 | else:
180 | out = self.attention(queries, keys, values, attention_mask, attention_weights)
181 | out = self.dropout(out)
182 | out = self.layer_norm(queries + out)
183 | return out
184 |
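
Note (annotation, not part of the repository): a shape-check sketch, assuming the repository root is on `PYTHONPATH`. The memory-augmented attention appends `m` learned key/value slots, so attention runs over `nk + m` positions while the output keeps the query length `nq`; the toy sizes below are arbitrary.

```python
import torch
from models.transformer.attention import MultiHeadAttention, ScaledDotProductAttentionMemory

b_s, nq, nk, d_model = 2, 5, 7, 512
queries = torch.randn(b_s, nq, d_model)
keys = values = torch.randn(b_s, nk, d_model)

mha = MultiHeadAttention(d_model=512, d_k=64, d_v=64, h=8,
                         attention_module=ScaledDotProductAttentionMemory,
                         attention_module_kwargs={'m': 40})
out = mha(queries, keys, values)
assert out.shape == (b_s, nq, d_model)  # memory slots change nk inside, not the output shape
```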
--------------------------------------------------------------------------------
/meshed-memory-transformer-master/models/transformer/decoders.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 | import numpy as np
5 |
6 | from models.transformer.attention import MultiHeadAttention
7 | from models.transformer.utils import sinusoid_encoding_table, PositionWiseFeedForward
8 | from models.containers import Module, ModuleList
9 |
10 |
11 | class MeshedDecoderLayer(Module):
12 | def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, self_att_module=None,
13 | enc_att_module=None, self_att_module_kwargs=None, enc_att_module_kwargs=None):
14 | super(MeshedDecoderLayer, self).__init__()
15 | self.self_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=True,
16 | attention_module=self_att_module,
17 | attention_module_kwargs=self_att_module_kwargs)
18 | self.enc_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=False,
19 | attention_module=enc_att_module,
20 | attention_module_kwargs=enc_att_module_kwargs)
21 | self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout)
22 |
23 | self.fc_alpha1 = nn.Linear(d_model + d_model, d_model)
24 | self.fc_alpha2 = nn.Linear(d_model + d_model, d_model)
25 | self.fc_alpha3 = nn.Linear(d_model + d_model, d_model)
26 |
27 | self.init_weights()
28 |
29 | def init_weights(self):
30 | nn.init.xavier_uniform_(self.fc_alpha1.weight)
31 | nn.init.xavier_uniform_(self.fc_alpha2.weight)
32 | nn.init.xavier_uniform_(self.fc_alpha3.weight)
33 | nn.init.constant_(self.fc_alpha1.bias, 0)
34 | nn.init.constant_(self.fc_alpha2.bias, 0)
35 | nn.init.constant_(self.fc_alpha3.bias, 0)
36 |
37 | def forward(self, input, enc_output, mask_pad, mask_self_att, mask_enc_att):
38 | self_att = self.self_att(input, input, input, mask_self_att)
39 | self_att = self_att * mask_pad
40 |
41 | enc_att1 = self.enc_att(self_att, enc_output[:, 0], enc_output[:, 0], mask_enc_att) * mask_pad
42 | enc_att2 = self.enc_att(self_att, enc_output[:, 1], enc_output[:, 1], mask_enc_att) * mask_pad
43 | enc_att3 = self.enc_att(self_att, enc_output[:, 2], enc_output[:, 2], mask_enc_att) * mask_pad
44 |
45 | alpha1 = torch.sigmoid(self.fc_alpha1(torch.cat([self_att, enc_att1], -1)))
46 | alpha2 = torch.sigmoid(self.fc_alpha2(torch.cat([self_att, enc_att2], -1)))
47 | alpha3 = torch.sigmoid(self.fc_alpha3(torch.cat([self_att, enc_att3], -1)))
48 |
49 | enc_att = (enc_att1 * alpha1 + enc_att2 * alpha2 + enc_att3 * alpha3) / np.sqrt(3)
50 | enc_att = enc_att * mask_pad
51 |
52 | ff = self.pwff(enc_att)
53 | ff = ff * mask_pad
54 | return ff
55 |
56 |
57 | class MeshedDecoder(Module):
58 | def __init__(self, vocab_size, max_len, N_dec, padding_idx, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1,
59 | self_att_module=None, enc_att_module=None, self_att_module_kwargs=None, enc_att_module_kwargs=None):
60 | super(MeshedDecoder, self).__init__()
61 | self.d_model = d_model
62 | self.word_emb = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx)
63 | self.pos_emb = nn.Embedding.from_pretrained(sinusoid_encoding_table(max_len + 1, d_model, 0), freeze=True)
64 | self.layers = ModuleList(
65 | [MeshedDecoderLayer(d_model, d_k, d_v, h, d_ff, dropout, self_att_module=self_att_module,
66 | enc_att_module=enc_att_module, self_att_module_kwargs=self_att_module_kwargs,
67 | enc_att_module_kwargs=enc_att_module_kwargs) for _ in range(N_dec)])
68 | self.fc = nn.Linear(d_model, vocab_size, bias=False)
69 | self.max_len = max_len
70 | self.padding_idx = padding_idx
71 | self.N = N_dec
72 |
73 | self.register_state('running_mask_self_attention', torch.zeros((1, 1, 0)).byte())
74 | self.register_state('running_seq', torch.zeros((1,)).long())
75 |
76 | def forward(self, input, encoder_output, mask_encoder):
77 | # input (b_s, seq_len)
78 | b_s, seq_len = input.shape[:2]
79 | mask_queries = (input != self.padding_idx).unsqueeze(-1).float() # (b_s, seq_len, 1)
80 | mask_self_attention = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8, device=input.device),
81 | diagonal=1)
82 | mask_self_attention = mask_self_attention.unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, seq_len)
83 | mask_self_attention = mask_self_attention + (input == self.padding_idx).unsqueeze(1).unsqueeze(1).byte()
84 | mask_self_attention = mask_self_attention.gt(0) # (b_s, 1, seq_len, seq_len)
85 | if self._is_stateful:
86 | self.running_mask_self_attention = torch.cat([self.running_mask_self_attention, mask_self_attention], -1)
87 | mask_self_attention = self.running_mask_self_attention
88 |
89 | seq = torch.arange(1, seq_len + 1).view(1, -1).expand(b_s, -1).to(input.device) # (b_s, seq_len)
90 | seq = seq.masked_fill(mask_queries.squeeze(-1) == 0, 0)
91 | if self._is_stateful:
92 | self.running_seq.add_(1)
93 | seq = self.running_seq
94 |
95 | out = self.word_emb(input) + self.pos_emb(seq)
96 | for i, l in enumerate(self.layers):
97 | out = l(out, encoder_output, mask_queries, mask_self_attention, mask_encoder)
98 |
99 | out = self.fc(out)
100 | return F.log_softmax(out, dim=-1)
101 |
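
Note (annotation, not part of the repository): a toy restatement of the self-attention mask built in `MeshedDecoder.forward` — an upper-triangular causal mask combined with a padding mask, where a True entry means the position is blocked. `padding_idx` and the token ids are assumed values.

```python
import torch

padding_idx = 0
inp = torch.tensor([[5, 9, 2, 0],      # last position is padding
                    [7, 3, 0, 0]])     # last two positions are padding
b_s, seq_len = inp.shape

causal = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8), diagonal=1)
causal = causal.unsqueeze(0).unsqueeze(0)                      # (1, 1, seq_len, seq_len)
pad = (inp == padding_idx).unsqueeze(1).unsqueeze(1).byte()    # (b_s, 1, 1, seq_len)
mask_self_attention = (causal + pad).gt(0)                     # (b_s, 1, seq_len, seq_len)

print(mask_self_attention[0, 0].int())  # row i marks positions token i may NOT attend to
```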
--------------------------------------------------------------------------------
/meshed-memory-transformer-master/models/transformer/encoders.py:
--------------------------------------------------------------------------------
1 | from torch.nn import functional as F
2 | from models.transformer.utils import PositionWiseFeedForward
3 | import torch
4 | from torch import nn
5 | from models.transformer.attention import MultiHeadAttention
6 |
7 |
8 | class EncoderLayer(nn.Module):
9 | def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, identity_map_reordering=False,
10 | attention_module=None, attention_module_kwargs=None):
11 | super(EncoderLayer, self).__init__()
12 | self.identity_map_reordering = identity_map_reordering
13 | self.mhatt = MultiHeadAttention(d_model, d_k, d_v, h, dropout, identity_map_reordering=identity_map_reordering,
14 | attention_module=attention_module,
15 | attention_module_kwargs=attention_module_kwargs)
16 | self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout, identity_map_reordering=identity_map_reordering)
17 |
18 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
19 | att = self.mhatt(queries, keys, values, attention_mask, attention_weights)
20 | ff = self.pwff(att)
21 | return ff
22 |
23 |
24 | class MultiLevelEncoder(nn.Module):
25 | def __init__(self, N, padding_idx, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1,
26 | identity_map_reordering=False, attention_module=None, attention_module_kwargs=None):
27 | super(MultiLevelEncoder, self).__init__()
28 | self.d_model = d_model
29 | self.dropout = dropout
30 | self.layers = nn.ModuleList([EncoderLayer(d_model, d_k, d_v, h, d_ff, dropout,
31 | identity_map_reordering=identity_map_reordering,
32 | attention_module=attention_module,
33 | attention_module_kwargs=attention_module_kwargs)
34 | for _ in range(N)])
35 | self.padding_idx = padding_idx
36 |
37 | def forward(self, input, attention_weights=None):
38 | # input (b_s, seq_len, d_in)
39 | attention_mask = (torch.sum(input, -1) == self.padding_idx).unsqueeze(1).unsqueeze(1) # (b_s, 1, 1, seq_len)
40 |
41 | outs = []
42 | out = input
43 | for l in self.layers:
44 | out = l(out, out, out, attention_mask, attention_weights)
45 | outs.append(out.unsqueeze(1))
46 |
47 | outs = torch.cat(outs, 1)
48 | return outs, attention_mask
49 |
50 |
51 | class MemoryAugmentedEncoder(MultiLevelEncoder):
52 | def __init__(self, N, padding_idx, d_in=2048, **kwargs):
53 | super(MemoryAugmentedEncoder, self).__init__(N, padding_idx, **kwargs)
54 | self.fc = nn.Linear(d_in, self.d_model)
55 | self.dropout = nn.Dropout(p=self.dropout)
56 | self.layer_norm = nn.LayerNorm(self.d_model)
57 |
58 | def forward(self, input, attention_weights=None):
59 | out = F.relu(self.fc(input))
60 | out = self.dropout(out)
61 | out = self.layer_norm(out)
62 | return super(MemoryAugmentedEncoder, self).forward(out, attention_weights=attention_weights)
63 |
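
Note (annotation, not part of the repository): a shape-check sketch mirroring how `test.py` builds the encoder, assuming the repository root is on `PYTHONPATH`. The multi-level encoder stacks the outputs of all `N` layers along dim 1 and also returns the padding mask derived from all-zero feature rows.

```python
import torch
from models.transformer.encoders import MemoryAugmentedEncoder
from models.transformer.attention import ScaledDotProductAttentionMemory

b_s, n_regions, d_in = 2, 50, 2048
features = torch.randn(b_s, n_regions, d_in)

encoder = MemoryAugmentedEncoder(3, 0, attention_module=ScaledDotProductAttentionMemory,
                                 attention_module_kwargs={'m': 40})
outs, mask = encoder(features)
assert outs.shape == (b_s, 3, n_regions, 512)   # one 512-d output per encoder layer
assert mask.shape == (b_s, 1, 1, n_regions)
```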
--------------------------------------------------------------------------------
/meshed-memory-transformer-master/models/transformer/transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import copy
4 | from models.containers import ModuleList
5 | from ..captioning_model import CaptioningModel
6 |
7 |
8 | class Transformer(CaptioningModel):
9 | def __init__(self, bos_idx, encoder, decoder):
10 | super(Transformer, self).__init__()
11 | self.bos_idx = bos_idx
12 | self.encoder = encoder
13 | self.decoder = decoder
14 | self.register_state('enc_output', None)
15 | self.register_state('mask_enc', None)
16 | self.init_weights()
17 |
18 | @property
19 | def d_model(self):
20 | return self.decoder.d_model
21 |
22 | def init_weights(self):
23 | for p in self.parameters():
24 | if p.dim() > 1:
25 | nn.init.xavier_uniform_(p)
26 |
27 | def forward(self, images, seq, *args):
28 | enc_output, mask_enc = self.encoder(images)
29 | dec_output = self.decoder(seq, enc_output, mask_enc)
30 | return dec_output
31 |
32 | def init_state(self, b_s, device):
33 | return [torch.zeros((b_s, 0), dtype=torch.long, device=device),
34 | None, None]
35 |
36 | def step(self, t, prev_output, visual, seq, mode='teacher_forcing', **kwargs):
37 | it = None
38 | if mode == 'teacher_forcing':
39 | raise NotImplementedError
40 | elif mode == 'feedback':
41 | if t == 0:
42 | self.enc_output, self.mask_enc = self.encoder(visual)
43 | if isinstance(visual, torch.Tensor):
44 | it = visual.data.new_full((visual.shape[0], 1), self.bos_idx).long()
45 | else:
46 | it = visual[0].data.new_full((visual[0].shape[0], 1), self.bos_idx).long()
47 | else:
48 | it = prev_output
49 |
50 | return self.decoder(it, self.enc_output, self.mask_enc)
51 |
52 |
53 | class TransformerEnsemble(CaptioningModel):
54 | def __init__(self, model: Transformer, weight_files):
55 | super(TransformerEnsemble, self).__init__()
56 | self.n = len(weight_files)
57 | self.models = ModuleList([copy.deepcopy(model) for _ in range(self.n)])
58 | for i in range(self.n):
59 | state_dict_i = torch.load(weight_files[i])['state_dict']
60 | self.models[i].load_state_dict(state_dict_i)
61 |
62 | def step(self, t, prev_output, visual, seq, mode='teacher_forcing', **kwargs):
63 | out_ensemble = []
64 | for i in range(self.n):
65 | out_i = self.models[i].step(t, prev_output, visual, seq, mode, **kwargs)
66 | out_ensemble.append(out_i.unsqueeze(0))
67 |
68 | return torch.mean(torch.cat(out_ensemble, 0), dim=0)
69 |
--------------------------------------------------------------------------------
/meshed-memory-transformer-master/models/transformer/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 |
6 | def position_embedding(input, d_model):
7 | input = input.view(-1, 1)
8 | dim = torch.arange(d_model // 2, dtype=torch.float32, device=input.device).view(1, -1)
9 | sin = torch.sin(input / 10000 ** (2 * dim / d_model))
10 | cos = torch.cos(input / 10000 ** (2 * dim / d_model))
11 |
12 | out = torch.zeros((input.shape[0], d_model), device=input.device)
13 | out[:, ::2] = sin
14 | out[:, 1::2] = cos
15 | return out
16 |
17 |
18 | def sinusoid_encoding_table(max_len, d_model, padding_idx=None):
19 | pos = torch.arange(max_len, dtype=torch.float32)
20 | out = position_embedding(pos, d_model)
21 |
22 | if padding_idx is not None:
23 | out[padding_idx] = 0
24 | return out
25 |
26 |
27 | class PositionWiseFeedForward(nn.Module):
28 | '''
29 | Position-wise feed forward layer
30 | '''
31 |
32 | def __init__(self, d_model=512, d_ff=2048, dropout=.1, identity_map_reordering=False):
33 | super(PositionWiseFeedForward, self).__init__()
34 | self.identity_map_reordering = identity_map_reordering
35 | self.fc1 = nn.Linear(d_model, d_ff)
36 | self.fc2 = nn.Linear(d_ff, d_model)
37 | self.dropout = nn.Dropout(p=dropout)
38 | self.dropout_2 = nn.Dropout(p=dropout)
39 | self.layer_norm = nn.LayerNorm(d_model)
40 |
41 | def forward(self, input):
42 | if self.identity_map_reordering:
43 | out = self.layer_norm(input)
44 | out = self.fc2(self.dropout_2(F.relu(self.fc1(out))))
45 | out = input + self.dropout(torch.relu(out))
46 | else:
47 | out = self.fc2(self.dropout_2(F.relu(self.fc1(input))))
48 | out = self.dropout(out)
49 | out = self.layer_norm(input + out)
50 | return out
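
Note (annotation, not part of the repository): a quick sketch of the positional table, assuming the repository root is on `PYTHONPATH`. Sine and cosine components are interleaved along the feature axis, and the `padding_idx` row is zeroed so padded positions get a null embedding; `max_len=54` matches the value used in `test.py`.

```python
import torch
from models.transformer.utils import sinusoid_encoding_table

table = sinusoid_encoding_table(54 + 1, 512, padding_idx=0)
print(table.shape)                   # torch.Size([55, 512])
print(float(table[0].abs().sum()))   # 0.0 -- the zeroed padding row
pos_emb = torch.nn.Embedding.from_pretrained(table, freeze=True)
```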
--------------------------------------------------------------------------------
/meshed-memory-transformer-master/output_logs/meshed_memory_transformer_test_o:
--------------------------------------------------------------------------------
1 | Meshed-Memory Transformer Evaluation
2 | {'BLEU': [0.8076084272899184, 0.65337618312199, 0.5093125587687117, 0.3909357911782391], 'METEOR': 0.2918900660095916, 'ROUGE': 0.5863539878042495, 'CIDEr': 1.3119740267338893}
3 |
--------------------------------------------------------------------------------
/meshed-memory-transformer-master/test.py:
--------------------------------------------------------------------------------
1 | import random
2 | from data import ImageDetectionsField, TextField, RawField
3 | from data import COCO, DataLoader
4 | import evaluation
5 | from models.transformer import Transformer, MemoryAugmentedEncoder, MeshedDecoder, ScaledDotProductAttentionMemory
6 | import torch
7 | from tqdm import tqdm
8 | import argparse
9 | import pickle
10 | import numpy as np
11 |
12 | random.seed(1234)
13 | torch.manual_seed(1234)
14 | np.random.seed(1234)
15 |
16 |
17 | def predict_captions(model, dataloader, text_field):
18 | import itertools
19 | model.eval()
20 | gen = {}
21 | gts = {}
22 | with tqdm(desc='Evaluation', unit='it', total=len(dataloader)) as pbar:
23 | for it, (images, caps_gt) in enumerate(iter(dataloader)):
24 | images = images.to(device)
25 | with torch.no_grad():
26 | out, _ = model.beam_search(images, 20, text_field.vocab.stoi['<eos>'], 5, out_size=1)
27 | caps_gen = text_field.decode(out, join_words=False)
28 | for i, (gts_i, gen_i) in enumerate(zip(caps_gt, caps_gen)):
29 | gen_i = ' '.join([k for k, g in itertools.groupby(gen_i)])
30 | gen['%d_%d' % (it, i)] = [gen_i.strip(), ]
31 | gts['%d_%d' % (it, i)] = gts_i
32 | pbar.update()
33 |
34 | gts = evaluation.PTBTokenizer.tokenize(gts)
35 | gen = evaluation.PTBTokenizer.tokenize(gen)
36 | scores, _ = evaluation.compute_scores(gts, gen)
37 |
38 | return scores
39 |
40 |
41 | if __name__ == '__main__':
42 | device = torch.device('cuda')
43 |
44 | parser = argparse.ArgumentParser(description='Meshed-Memory Transformer')
45 | parser.add_argument('--batch_size', type=int, default=10)
46 | parser.add_argument('--workers', type=int, default=0)
47 | parser.add_argument('--features_path', type=str)
48 | parser.add_argument('--annotation_folder', type=str)
49 | args = parser.parse_args()
50 |
51 | print('Meshed-Memory Transformer Evaluation')
52 |
53 | # Pipeline for image regions
54 | image_field = ImageDetectionsField(detections_path=args.features_path, max_detections=50, load_in_tmp=False)
55 |
56 | # Pipeline for text
57 | text_field = TextField(init_token='<bos>', eos_token='<eos>', lower=True, tokenize='spacy',
58 | remove_punctuation=True, nopoints=False)
59 |
60 | # Create the dataset
61 | dataset = COCO(image_field, text_field, 'coco/images/', args.annotation_folder, args.annotation_folder)
62 | _, _, test_dataset = dataset.splits
63 | text_field.vocab = pickle.load(open('vocab.pkl', 'rb'))
64 |
65 | # Model and dataloaders
66 | encoder = MemoryAugmentedEncoder(3, 0, attention_module=ScaledDotProductAttentionMemory,
67 | attention_module_kwargs={'m': 40})
68 | decoder = MeshedDecoder(len(text_field.vocab), 54, 3, text_field.vocab.stoi['<pad>'])
69 | model = Transformer(text_field.vocab.stoi['<bos>'], encoder, decoder).to(device)
70 |
71 | data = torch.load('meshed_memory_transformer.pth')
72 | model.load_state_dict(data['state_dict'])
73 |
74 | dict_dataset_test = test_dataset.image_dictionary({'image': image_field, 'text': RawField()})
75 | dict_dataloader_test = DataLoader(dict_dataset_test, batch_size=args.batch_size, num_workers=args.workers)
76 |
77 | scores = predict_captions(model, dict_dataloader_test, text_field)
78 | print(scores)
79 |
--------------------------------------------------------------------------------
/meshed-memory-transformer-master/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import download_from_url
2 | from .typing import *
3 |
4 | def get_batch_size(x: TensorOrSequence) -> int:
5 | if isinstance(x, torch.Tensor):
6 | b_s = x.size(0)
7 | else:
8 | b_s = x[0].size(0)
9 | return b_s
10 |
11 |
12 | def get_device(x: TensorOrSequence) -> int:
13 | if isinstance(x, torch.Tensor):
14 | b_s = x.device
15 | else:
16 | b_s = x[0].device
17 | return b_s
18 |
--------------------------------------------------------------------------------
/meshed-memory-transformer-master/utils/typing.py:
--------------------------------------------------------------------------------
1 | from typing import Union, Sequence, Tuple
2 | import torch
3 |
4 | TensorOrSequence = Union[Sequence[torch.Tensor], torch.Tensor]
5 | TensorOrNone = Union[torch.Tensor, None]
6 |
--------------------------------------------------------------------------------
/meshed-memory-transformer-master/utils/utils.py:
--------------------------------------------------------------------------------
1 | import requests
2 |
3 | def download_from_url(url, path):
4 | """Download file, with logic (from tensor2tensor) for Google Drive"""
5 | if 'drive.google.com' not in url:
6 | print('Downloading %s; may take a few minutes' % url)
7 | r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'})
8 | with open(path, "wb") as file:
9 | file.write(r.content)
10 | return
11 | print('Downloading from Google Drive; may take a few minutes')
12 | confirm_token = None
13 | session = requests.Session()
14 | response = session.get(url, stream=True)
15 | for k, v in response.cookies.items():
16 | if k.startswith("download_warning"):
17 | confirm_token = v
18 |
19 | if confirm_token:
20 | url = url + "&confirm=" + confirm_token
21 | response = session.get(url, stream=True)
22 |
23 | chunk_size = 16 * 1024
24 | with open(path, "wb") as f:
25 | for chunk in response.iter_content(chunk_size):
26 | if chunk:
27 | f.write(chunk)
28 |
--------------------------------------------------------------------------------
/meshed-memory-transformer-master/vocab.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixmath/CVR/af2ea6859230c1a5f56868a2242c06c2fe02f05a/meshed-memory-transformer-master/vocab.pkl
--------------------------------------------------------------------------------
/model/Model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import os
4 | import glob
5 | class Model(nn.Module):
6 |
7 | def __init__(self, name):
8 | super(Model, self).__init__()
9 | self.name = name
10 |
11 | def save(self, path, epoch=0):
12 | complete_path = os.path.join(path, self.name)
13 | if not os.path.exists(complete_path):
14 | os.makedirs(complete_path)
15 | torch.save(self.state_dict(),
16 | os.path.join(complete_path,
17 | "model-{}.pth".format(str(epoch).zfill(5))))
18 |
19 | def save_results(self, path, data):
20 | raise NotImplementedError("Model subclass must implement this method.")
21 |
22 | def load(self, path, modelfile=None):
23 | complete_path = os.path.join(path, self.name)
24 | if not os.path.exists(complete_path):
25 | raise IOError("{} directory does not exist in {}".format(self.name, path))
26 |
27 | if modelfile is None:
28 | model_files = glob.glob(complete_path + "/*")
29 | mf = max(model_files)
30 | else:
31 | mf = os.path.join(complete_path, modelfile)
32 |
33 | self.load_state_dict(torch.load(mf))
34 |
35 |
36 |
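
Note (annotation, not part of the repository): a minimal usage sketch, assuming the repository root is on `PYTHONPATH`; `TinyNet` and the `./checkpoints` path are made up. `save` writes `<path>/<name>/model-XXXXX.pth`, and `load` with no filename restores the lexicographically largest checkpoint in that folder.

```python
import torch
import torch.nn as nn
from model.Model import Model

class TinyNet(Model):
    def __init__(self):
        super(TinyNet, self).__init__('tiny_net')
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

net = TinyNet()
net.save('./checkpoints', epoch=3)   # writes ./checkpoints/tiny_net/model-00003.pth
net.load('./checkpoints')            # reloads the latest checkpoint in that folder
```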
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/model/__pycache__/Model.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixmath/CVR/af2ea6859230c1a5f56868a2242c06c2fe02f05a/model/__pycache__/Model.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixmath/CVR/af2ea6859230c1a5f56868a2242c06c2fe02f05a/model/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/view_transformer_random4.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixmath/CVR/af2ea6859230c1a5f56868a2242c06c2fe02f05a/model/__pycache__/view_transformer_random4.cpython-36.pyc
--------------------------------------------------------------------------------
/model/gvcnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import torch.nn as nn
4 | # from torchsummary import summary
5 |
6 |
7 | def fc_bn_block(input, output):
8 | return nn.Sequential(
9 | nn.Linear(input, output),
10 | nn.BatchNorm1d(output),
11 | nn.ReLU(inplace=True))
12 |
13 |
14 | def cal_scores(scores):
15 | n = len(scores)
16 | s = 0
17 | for score in scores:
18 | s += torch.ceil(score * n)
19 | s /= n
20 | return s
21 |
22 |
23 | def group_fusion(view_group, weight_group):
24 | shape_des = map(lambda a, b: a * b, view_group, weight_group)
25 | shape_des = sum(shape_des) / sum(weight_group)
26 | return shape_des
27 |
28 |
29 | def group_pooling(final_views, views_score, group_num):
30 | interval = 1.0 / group_num
31 |
32 | def onebatch_grouping(onebatch_views, onebatch_scores):
33 | viewgroup_onebatch = [[] for i in range(group_num)]
34 | scoregroup_onebatch = [[] for i in range(group_num)]
35 |
36 | for i in range(group_num):
37 | left = i * interval
38 | right = (i + 1) * interval
39 | for j, score in enumerate(onebatch_scores):
40 | if left <= score < right:
41 | viewgroup_onebatch[i].append(onebatch_views[j])
42 | scoregroup_onebatch[i].append(score)
43 | else:
44 | pass
45 | # print(len(scoregroup_onebatch))
46 | view_group = [sum(views) / len(views) for views in viewgroup_onebatch if len(views) > 0]
47 | weight_group = [cal_scores(scores) for scores in scoregroup_onebatch if len(scores) > 0]
48 | onebatch_shape_des = group_fusion(view_group, weight_group)
49 | return onebatch_shape_des
50 |
51 | shape_descriptors = []
52 | for (onebatch_views, onebatch_scores) in zip(final_views, views_score):
53 | shape_descriptors.append(onebatch_grouping(onebatch_views, onebatch_scores))
54 | shape_descriptor = torch.stack(shape_descriptors, 0)
55 | # shape_descriptor: [B, 1024]
56 | return shape_descriptor
57 |
58 |
59 | class GVCNN(nn.Module):
60 | def __init__(self, num_classes=40, group_num=8, model_name='GOOGLENET', pretrained=True):
61 | super(GVCNN, self).__init__()
62 |
63 | self.num_classes = num_classes
64 | self.group_num = group_num
65 |
66 | if model_name == 'GOOGLENET':
67 | base_model = torchvision.models.googlenet(pretrained=pretrained)
68 |
69 | self.FCN = nn.Sequential(*list(base_model.children())[:6])
70 | self.CNN = nn.Sequential(*list(base_model.children())[:-2])
71 | self.FC = nn.Sequential(fc_bn_block(256 * 28 * 28, 256),
72 | fc_bn_block(256, 1))
73 | self.fc_block_1 = fc_bn_block(1024, 512)
74 | self.drop_1 = nn.Dropout(0.5)
75 | self.fc_block_2 = fc_bn_block(512, 256)
76 | self.drop_2 = nn.Dropout(0.5)
77 | self.linear = nn.Linear(256, self.num_classes)
78 |
79 | def forward(self, views):
80 | '''
81 | params views: B V C H W (B 12 3 224 224)
82 | return result: B num_classes
83 | '''
84 | # print(views.size())
85 | # views = views.cpu()
86 | batch_size, num_views, channel, image_size = views.size(0), views.size(1), views.size(2), views.size(3)
87 |
88 | views = views.view(batch_size * num_views, channel, image_size, image_size)
89 | raw_views = self.FCN(views)
90 | # print(raw_views.size())
91 | # raw_views: [B*V 256 28 28]
92 | final_views = self.CNN(views)
93 | # final_views: [B*V 1024 1 1]
94 | final_views = final_views.view(batch_size, num_views, 1024)
95 | views_score = self.FC(raw_views.view(batch_size * num_views, -1))
96 | views_score = torch.sigmoid(torch.tanh(torch.abs(views_score)))
97 | views_score = views_score.view(batch_size, num_views, -1)
98 | # views_score: [B V]
99 | shape_descriptor = group_pooling(final_views, views_score, self.group_num)
100 | # print(shape_descriptor.size())
101 |
102 | out = self.fc_block_1(shape_descriptor)
103 | out = self.drop_1(out)
104 | out = self.fc_block_2(out)
105 | viewcnn_feature = out
106 | out = self.drop_2(out)
107 | pred = self.linear(out)
108 |
109 | return pred
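
Note (annotation, not part of the repository): a toy call to `group_pooling`, assuming the repository root is on `PYTHONPATH`. Views whose discrimination scores fall in the same of the `group_num` equal-width intervals of [0, 1) are averaged, and the group means are fused with score-derived weights; the feature dimension 6 and the scores are made up.

```python
import torch
from model.gvcnn import group_pooling

final_views = torch.randn(1, 4, 6)                      # (B, V, D): 1 shape, 4 view descriptors
views_score = torch.tensor([[0.05, 0.10, 0.62, 0.71]])  # (B, V): per-view discrimination scores
shape_descriptor = group_pooling(final_views, views_score, group_num=8)
print(shape_descriptor.shape)                           # torch.Size([1, 6])
```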
--------------------------------------------------------------------------------
/model/gvcnn_random.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import torch.nn as nn
4 | # from torchsummary import summary
5 |
6 |
7 | def fc_bn_block(input, output):
8 | return nn.Sequential(
9 | nn.Linear(input, output),
10 | # nn.BatchNorm1d(output),
11 | nn.ReLU(inplace=True))
12 |
13 |
14 | def cal_scores(scores):
15 | n = len(scores)
16 | s = 0
17 | for score in scores:
18 | s += torch.ceil(score * n)
19 | s /= n
20 | return s
21 |
22 | def group_fusion(view_group, weight_group):
23 | shape_des = map(lambda a, b: a * b, view_group, weight_group)
24 | shape_des = sum(shape_des) / sum(weight_group)
25 | return shape_des
26 |
27 | def group_pooling(final_views, views_score, group_num):
28 | interval = 1.0 / group_num
29 |
30 | def onebatch_grouping(onebatch_views, onebatch_scores):
31 | viewgroup_onebatch = [[] for i in range(group_num)]
32 | scoregroup_onebatch = [[] for i in range(group_num)]
33 |
34 | for i in range(group_num):
35 | left = i * interval
36 | right = (i + 1) * interval
37 | for j, score in enumerate(onebatch_scores):
38 | if left <= score < right:
39 | viewgroup_onebatch[i].append(onebatch_views[j])
40 | scoregroup_onebatch[i].append(score)
41 | else:
42 | pass
43 | # print(len(scoregroup_onebatch))
44 | view_group = [sum(views) / len(views) for views in viewgroup_onebatch if len(views) > 0]
45 | weight_group = [cal_scores(scores) for scores in scoregroup_onebatch if len(scores) > 0]
46 | onebatch_shape_des = group_fusion(view_group, weight_group)
47 | return onebatch_shape_des
48 |
49 | shape_descriptors = []
50 | for (onebatch_views, onebatch_scores) in zip(final_views, views_score):
51 | shape_descriptors.append(onebatch_grouping(onebatch_views, onebatch_scores))
52 | shape_descriptor = torch.stack(shape_descriptors, 0)
53 | # shape_descriptor: [B, 1024]
54 | return shape_descriptor
55 |
56 |
57 | class GVCNN(nn.Module):
58 | def __init__(self, num_classes=40, group_num=8, model_name='GOOGLENET', pretrained=True):
59 | super(GVCNN, self).__init__()
60 |
61 | self.num_classes = num_classes
62 | self.group_num = group_num
63 |
64 | if model_name == 'GOOGLENET':
65 | base_model = torchvision.models.googlenet(pretrained=pretrained)
66 |
67 | self.FCN = nn.Sequential(*list(base_model.children())[:6])
68 | self.CNN = nn.Sequential(*list(base_model.children())[:-2])
69 | self.FC = nn.Sequential(fc_bn_block(256 * 28 * 28, 256),
70 | fc_bn_block(256, 1))
71 | self.fc_block_1 = fc_bn_block(1024, 512)
72 | self.drop_1 = nn.Dropout(0.5)
73 | self.fc_block_2 = fc_bn_block(512, 256)
74 | self.drop_2 = nn.Dropout(0.5)
75 | self.linear = nn.Linear(256, self.num_classes)
76 |
77 | def forward(self,views,rand_view_num,N):
78 | '''
79 | params views: (sum of rand_view_num) x C x H x W -- views of all shapes in the batch stacked along dim 0
80 | return result: N x num_classes, one prediction per shape
81 | '''
82 | # print(views.size())
83 | # views = views.cpu()
84 | # batch_size, num_views, channel, image_size = views.size(0), views.size(1), views.size(2), views.size(3)
85 | #
86 | # views = views.view(batch_size * num_views, channel, image_size, image_size)
87 | raw_views = self.FCN(views)
88 | # print(raw_views.size())
89 | # raw_views: [B*V 256 28 28]
90 | final_views = self.CNN(views)
91 | # final_views: [B*V 1024 1 1]
92 | final_views = final_views.squeeze()
93 | m = 0
94 | pred_final = []
95 | for i in range(N):
96 | raw_views0 = raw_views[m:m+rand_view_num[i],:,:,:]
97 | final_views0 = final_views[m:m+rand_view_num[i],:].unsqueeze(0)
98 | m = m + rand_view_num[i]
99 | views_score = self.FC(raw_views0.view(rand_view_num[i],-1)).unsqueeze(0)
100 | views_score = torch.sigmoid(torch.tanh(torch.abs(views_score)))
101 | # views_score = views_score.view(batch_size, num_views, -1)
102 | # views_score: [B V]
103 | shape_descriptor = group_pooling(final_views0, views_score, self.group_num)
104 | # print(shape_descriptor.size())
105 | out = self.fc_block_1(shape_descriptor)
106 | out = self.drop_1(out)
107 | out = self.fc_block_2(out)
108 | viewcnn_feature = out
109 | out = self.drop_2(out)
110 | pred = self.linear(out)
111 | pred_final.append(pred)
112 | pred_final = torch.cat(pred_final,0)
113 | return pred_final
--------------------------------------------------------------------------------
/model/mvcnn.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torchvision.models as models
5 | from .Model import Model
6 | mean = torch.tensor([0.485, 0.456, 0.406],dtype=torch.float, requires_grad=False)
7 | std = torch.tensor([0.229, 0.224, 0.225],dtype=torch.float, requires_grad=False)
8 | def flip(x, dim):
9 | xsize = x.size()
10 | dim = x.dim() + dim if dim < 0 else dim
11 | x = x.view(-1, *xsize[dim:])
12 | x = x.view(x.size(0), x.size(1), -1)[:, getattr(torch.arange(x.size(1) - 1,
13 | -1, -1), ('cpu', 'cuda')[x.is_cuda])().long(), :]
14 | return x.view(xsize)
15 |
16 | class SVCNN(Model):
17 | def __init__(self, name, nclasses=40, pretraining=True, cnn_name='resnet18'):
18 | super(SVCNN, self).__init__(name)
19 | if nclasses == 40:
20 | self.classnames = ['airplane', 'bathtub', 'bed', 'bench', 'bookshelf', 'bottle', 'bowl', 'car', 'chair',
21 | 'cone', 'cup', 'curtain', 'desk', 'door', 'dresser', 'flower_pot', 'glass_box',
22 | 'guitar', 'keyboard', 'lamp', 'laptop', 'mantel', 'monitor', 'night_stand',
23 | 'person', 'piano', 'plant', 'radio', 'range_hood', 'sink', 'sofa', 'stairs',
24 | 'stool', 'table', 'tent', 'toilet', 'tv_stand', 'vase', 'wardrobe', 'xbox']
25 | elif nclasses==15:
26 | self.classnames = ['bag', 'bed', 'bin', 'box', 'cabinet', 'chair', 'desk', 'display'
27 | , 'door', 'pillow', 'shelf', 'sink', 'sofa', 'table', 'toilet']
28 | else:
29 | self.classnames = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11',
30 | '12', '13', '14', '15', '16', '17', '18', '19', '20', '21',
31 | '22', '23', '24', '25', '26', '27', '28', '29', '30', '31',
32 | '32', '33', '34', '35', '36', '37', '38', '39', '40', '41',
33 | '42', '43', '44', '45', '46', '47', '48', '49', '50', '51', '52', '53', '54']
34 |
35 | self.nclasses = nclasses
36 | self.pretraining = pretraining
37 | self.cnn_name = cnn_name
38 | self.use_resnet = cnn_name.startswith('resnet')
39 | self.mean = torch.tensor([0.485, 0.456, 0.406],dtype=torch.float, requires_grad=False)
40 | self.std = torch.tensor([0.229, 0.224, 0.225],dtype=torch.float, requires_grad=False)
41 |
42 | if self.use_resnet:
43 | if self.cnn_name == 'resnet18':
44 | self.net = models.resnet18(pretrained=self.pretraining)
45 | self.net.fc = nn.Linear(512, self.nclasses)
46 | elif self.cnn_name == 'resnet34':
47 | self.net = models.resnet34(pretrained=self.pretraining)
48 | self.net.fc = nn.Linear(512, self.nclasses)
49 | elif self.cnn_name == 'resnet50':
50 | self.net = models.resnet50(pretrained=self.pretraining)
51 | self.net.fc = nn.Linear(2048, self.nclasses)
52 | else:
53 | if self.cnn_name == 'alexnet':
54 | self.net_1 = models.alexnet(pretrained=self.pretraining).features
55 | self.net_2 = models.alexnet(pretrained=self.pretraining).classifier
56 | elif self.cnn_name == 'vgg11':
57 | self.net_1 = models.vgg11_bn(pretrained=self.pretraining).features
58 | self.net_2 = models.vgg11_bn(pretrained=self.pretraining).classifier
59 | elif self.cnn_name == 'vgg16':
60 | self.net_1 = models.vgg16(pretrained=self.pretraining).features
61 | self.net_2 = models.vgg16(pretrained=self.pretraining).classifier
62 |
63 | self.net_2._modules['6'] = nn.Linear(4096, self.nclasses)
64 |
65 | def forward(self, x):
66 | if self.use_resnet:
67 | return self.net(x)
68 | else:
69 | y = self.net_1(x)
70 | return self.net_2(y.view(y.shape[0], -1))
71 |
72 | class view_GCN(Model):
73 | def __init__(self,name, model, nclasses=40, cnn_name='resnet18', num_views=20):
74 | super(view_GCN,self).__init__(name)
75 | if nclasses == 40:
76 | self.classnames = ['airplane', 'bathtub', 'bed', 'bench', 'bookshelf', 'bottle', 'bowl', 'car', 'chair',
77 | 'cone', 'cup', 'curtain', 'desk', 'door', 'dresser', 'flower_pot', 'glass_box',
78 | 'guitar', 'keyboard', 'lamp', 'laptop', 'mantel', 'monitor', 'night_stand',
79 | 'person', 'piano', 'plant', 'radio', 'range_hood', 'sink', 'sofa', 'stairs',
80 | 'stool', 'table', 'tent', 'toilet', 'tv_stand', 'vase', 'wardrobe', 'xbox']
81 | elif nclasses==15:
82 | self.classnames = ['bag', 'bed', 'bin', 'box', 'cabinet', 'chair', 'desk', 'display'
83 | , 'door', 'pillow', 'shelf', 'sink', 'sofa', 'table', 'toilet']
84 | else:
85 | self.classnames = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11',
86 | '12', '13', '14', '15', '16', '17', '18', '19', '20', '21',
87 | '22', '23', '24', '25', '26', '27', '28', '29', '30', '31',
88 | '32', '33', '34', '35', '36', '37', '38', '39', '40', '41',
89 | '42', '43', '44', '45', '46', '47', '48', '49', '50', '51', '52', '53', '54']
90 | self.nclasses = nclasses
91 | self.mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float, requires_grad=False)
92 | self.std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float, requires_grad=False)
93 | self.use_resnet = cnn_name.startswith('resnet')
94 | if self.use_resnet:
95 | self.net_1 = nn.Sequential(*list(model.net.children())[:-1])
96 | self.net_2 = model.net.fc
97 | else:
98 | self.net_1 = model.net_1
99 | self.net_2 = model.net_2
100 | self.num_views = num_views
101 | for m in self.modules():
102 | if isinstance(m, nn.Linear):
103 | nn.init.kaiming_uniform_(m.weight)
104 | elif isinstance(m, nn.Conv1d):
105 | nn.init.kaiming_uniform_(m.weight)
106 | def forward(self,x):
107 | y = self.net_1(x)
108 | y = y.view(
109 | (int(x.shape[0] / self.num_views), self.num_views, y.shape[-3], y.shape[-2], y.shape[-1])) # (8,12,512,7,7)
110 | return self.net_2(torch.max(y, 1)[0].view(y.shape[0], -1))
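
Note (annotation, not part of the repository): a shape sketch, assuming the repository root is on `PYTHONPATH`; `pretraining=False` is used only to avoid a weight download. `view_GCN` expects the views of all shapes stacked along the batch dimension, extracts per-view ResNet features, and max-pools over the view axis before classifying.

```python
import torch
from model.mvcnn import SVCNN, view_GCN

num_views, b = 12, 2
svcnn = SVCNN('svcnn', nclasses=40, pretraining=False, cnn_name='resnet18')
mv = view_GCN('mvcnn', svcnn, nclasses=40, cnn_name='resnet18', num_views=num_views)

x = torch.randn(b * num_views, 3, 224, 224)   # B*V view images stacked along dim 0
logits = mv(x)
assert logits.shape == (b, 40)                # one prediction per shape after view max-pooling
```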
--------------------------------------------------------------------------------
/model/mvcnn_random.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torchvision.models as models
5 | from models.utils import PositionWiseFeedForward
6 | from .Model import Model
7 | from tools.view_gcn_utils import *
8 | from otk.layers import OTKernel
9 | from otk.utils import normalize
10 | from models.Vit import TransformerEncoderLayer
11 | mean = torch.tensor([0.485, 0.456, 0.406],dtype=torch.float, requires_grad=False)
12 | std = torch.tensor([0.229, 0.224, 0.225],dtype=torch.float, requires_grad=False)
13 | def flip(x, dim):
14 | xsize = x.size()
15 | dim = x.dim() + dim if dim < 0 else dim
16 | x = x.view(-1, *xsize[dim:])
17 | x = x.view(x.size(0), x.size(1), -1)[:, getattr(torch.arange(x.size(1) - 1,
18 | -1, -1), ('cpu', 'cuda')[x.is_cuda])().long(), :]
19 | return x.view(xsize)
20 |
21 | class SVCNN(Model):
22 | def __init__(self, name, nclasses=40, pretraining=True, cnn_name='resnet18'):
23 | super(SVCNN, self).__init__(name)
24 | if nclasses == 40:
25 | self.classnames = ['airplane', 'bathtub', 'bed', 'bench', 'bookshelf', 'bottle', 'bowl', 'car', 'chair',
26 | 'cone', 'cup', 'curtain', 'desk', 'door', 'dresser', 'flower_pot', 'glass_box',
27 | 'guitar', 'keyboard', 'lamp', 'laptop', 'mantel', 'monitor', 'night_stand',
28 | 'person', 'piano', 'plant', 'radio', 'range_hood', 'sink', 'sofa', 'stairs',
29 | 'stool', 'table', 'tent', 'toilet', 'tv_stand', 'vase', 'wardrobe', 'xbox']
30 | elif nclasses==15:
31 | self.classnames = ['bag', 'bed', 'bin', 'box', 'cabinet', 'chair', 'desk', 'display'
32 | , 'door', 'pillow', 'shelf', 'sink', 'sofa', 'table', 'toilet']
33 | else:
34 | self.classnames = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11',
35 | '12', '13', '14', '15', '16', '17', '18', '19', '20', '21',
36 | '22', '23', '24', '25', '26', '27', '28', '29', '30', '31',
37 | '32', '33', '34', '35', '36', '37', '38', '39', '40', '41',
38 | '42', '43', '44', '45', '46', '47', '48', '49', '50', '51', '52', '53', '54']
39 |
40 | self.nclasses = nclasses
41 | self.pretraining = pretraining
42 | self.cnn_name = cnn_name
43 | self.use_resnet = cnn_name.startswith('resnet')
44 | self.mean = torch.tensor([0.485, 0.456, 0.406],dtype=torch.float, requires_grad=False)
45 | self.std = torch.tensor([0.229, 0.224, 0.225],dtype=torch.float, requires_grad=False)
46 |
47 | if self.use_resnet:
48 | if self.cnn_name == 'resnet18':
49 | self.net = models.resnet18(pretrained=self.pretraining)
50 | self.net.fc = nn.Linear(512, self.nclasses)
51 | elif self.cnn_name == 'resnet34':
52 | self.net = models.resnet34(pretrained=self.pretraining)
53 | self.net.fc = nn.Linear(512, self.nclasses)
54 | elif self.cnn_name == 'resnet50':
55 | self.net = models.resnet50(pretrained=self.pretraining)
56 | self.net.fc = nn.Linear(2048, self.nclasses)
57 | else:
58 | if self.cnn_name == 'alexnet':
59 | self.net_1 = models.alexnet(pretrained=self.pretraining).features
60 | self.net_2 = models.alexnet(pretrained=self.pretraining).classifier
61 | elif self.cnn_name == 'vgg11':
62 | self.net_1 = models.vgg11_bn(pretrained=self.pretraining).features
63 | self.net_2 = models.vgg11_bn(pretrained=self.pretraining).classifier
64 | elif self.cnn_name == 'vgg16':
65 | self.net_1 = models.vgg16(pretrained=self.pretraining).features
66 | self.net_2 = models.vgg16(pretrained=self.pretraining).classifier
67 |
68 | self.net_2._modules['6'] = nn.Linear(4096, self.nclasses)
69 |
70 | def forward(self, x):
71 | if self.use_resnet:
72 | return self.net(x)
73 | else:
74 | y = self.net_1(x)
75 | return self.net_2(y.view(y.shape[0], -1))
76 |
77 | class view_GCN(Model):
78 | def __init__(self,name, model, nclasses=40, cnn_name='resnet18', num_views=20):
79 | super(view_GCN,self).__init__(name)
80 | if nclasses == 40:
81 | self.classnames = ['airplane', 'bathtub', 'bed', 'bench', 'bookshelf', 'bottle', 'bowl', 'car', 'chair',
82 | 'cone', 'cup', 'curtain', 'desk', 'door', 'dresser', 'flower_pot', 'glass_box',
83 | 'guitar', 'keyboard', 'lamp', 'laptop', 'mantel', 'monitor', 'night_stand',
84 | 'person', 'piano', 'plant', 'radio', 'range_hood', 'sink', 'sofa', 'stairs',
85 | 'stool', 'table', 'tent', 'toilet', 'tv_stand', 'vase', 'wardrobe', 'xbox']
86 | elif nclasses==15:
87 | self.classnames = ['bag', 'bed', 'bin', 'box', 'cabinet', 'chair', 'desk', 'display'
88 | , 'door', 'pillow', 'shelf', 'sink', 'sofa', 'table', 'toilet']
89 | else:
90 | self.classnames = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11',
91 | '12', '13', '14', '15', '16', '17', '18', '19', '20', '21',
92 | '22', '23', '24', '25', '26', '27', '28', '29', '30', '31',
93 | '32', '33', '34', '35', '36', '37', '38', '39', '40', '41',
94 | '42', '43', '44', '45', '46', '47', '48', '49', '50', '51', '52', '53', '54']
95 | self.nclasses = nclasses
96 | self.mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float, requires_grad=False)
97 | self.std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float, requires_grad=False)
98 | self.use_resnet = cnn_name.startswith('resnet')
99 | if self.use_resnet:
100 | self.net_1 = nn.Sequential(*list(model.net.children())[:-1])
101 | self.net_2 = model.net.fc
102 | else:
103 | self.net_1 = model.net_1
104 | self.net_2 = model.net_2
105 | self.num_views = num_views
106 | # for m in self.modules():
107 | # if isinstance(m,nn.Linear):
108 | # nn.init.kaiming_uniform_(m.weight)
109 | # elif isinstance(m,nn.Conv1d):
110 | # nn.init.kaiming_uniform_(m.weight)
111 | def forward(self,x,rand_view_num,N):
112 | y = self.net_1(x)
113 | y = y.squeeze()
114 | y0 = my_pad_sequence(sequences=y, view_num=rand_view_num, N=N, max_length=self.num_views, padding_value=0)
115 | pooled_view = y0.max(1)[0]
116 | pooled_view = self.net_2(pooled_view)
117 | return pooled_view
--------------------------------------------------------------------------------
/model/view_transformer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torchvision.models as models
5 | from models.utils import PositionWiseFeedForward
6 | from .Model import Model
7 | from tools.view_gcn_utils import *
8 | from otk.layers import OTKernel
9 | from otk.utils import normalize
10 | from models.Vit import TransformerEncoderLayer
11 | mean = torch.tensor([0.485, 0.456, 0.406],dtype=torch.float, requires_grad=False)
12 | std = torch.tensor([0.229, 0.224, 0.225],dtype=torch.float, requires_grad=False)
13 | def flip(x, dim):
14 | xsize = x.size()
15 | dim = x.dim() + dim if dim < 0 else dim
16 | x = x.view(-1, *xsize[dim:])
17 | x = x.view(x.size(0), x.size(1), -1)[:, getattr(torch.arange(x.size(1) - 1,
18 | -1, -1), ('cpu', 'cuda')[x.is_cuda])().long(), :]
19 | return x.view(xsize)
20 |
21 | class SVCNN(Model):
22 | def __init__(self, name, nclasses=40, pretraining=True, cnn_name='resnet18'):
23 | super(SVCNN, self).__init__(name)
24 | if nclasses == 40:
25 | self.classnames = ['airplane', 'bathtub', 'bed', 'bench', 'bookshelf', 'bottle', 'bowl', 'car', 'chair',
26 | 'cone', 'cup', 'curtain', 'desk', 'door', 'dresser', 'flower_pot', 'glass_box',
27 | 'guitar', 'keyboard', 'lamp', 'laptop', 'mantel', 'monitor', 'night_stand',
28 | 'person', 'piano', 'plant', 'radio', 'range_hood', 'sink', 'sofa', 'stairs',
29 | 'stool', 'table', 'tent', 'toilet', 'tv_stand', 'vase', 'wardrobe', 'xbox']
30 | elif nclasses==15:
31 | self.classnames = ['bag', 'bed', 'bin', 'box', 'cabinet', 'chair', 'desk', 'display'
32 | , 'door', 'pillow', 'shelf', 'sink', 'sofa', 'table', 'toilet']
33 | else:
34 | self.classnames = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11',
35 | '12', '13', '14', '15', '16', '17', '18', '19', '20', '21',
36 | '22', '23', '24', '25', '26', '27', '28', '29', '30', '31',
37 | '32', '33', '34', '35', '36', '37', '38', '39', '40', '41',
38 | '42', '43', '44', '45', '46', '47', '48', '49', '50', '51', '52', '53', '54']
39 |
40 | self.nclasses = nclasses
41 | self.pretraining = pretraining
42 | self.cnn_name = cnn_name
43 | self.use_resnet = cnn_name.startswith('resnet')
44 | self.mean = torch.tensor([0.485, 0.456, 0.406],dtype=torch.float, requires_grad=False)
45 | self.std = torch.tensor([0.229, 0.224, 0.225],dtype=torch.float, requires_grad=False)
46 |
47 | if self.use_resnet:
48 | if self.cnn_name == 'resnet18':
49 | self.net = models.resnet18(pretrained=self.pretraining)
50 | self.net.fc = nn.Linear(512, self.nclasses)
51 | elif self.cnn_name == 'resnet34':
52 | self.net = models.resnet34(pretrained=self.pretraining)
53 | self.net.fc = nn.Linear(512, self.nclasses)
54 | elif self.cnn_name == 'resnet50':
55 | self.net = models.resnet50(pretrained=self.pretraining)
56 | self.net.fc = nn.Linear(2048, self.nclasses)
57 | else:
58 | if self.cnn_name == 'alexnet':
59 | self.net_1 = models.alexnet(pretrained=self.pretraining).features
60 | self.net_2 = models.alexnet(pretrained=self.pretraining).classifier
61 | elif self.cnn_name == 'vgg11':
62 | self.net_1 = models.vgg11_bn(pretrained=self.pretraining).features
63 | self.net_2 = models.vgg11_bn(pretrained=self.pretraining).classifier
64 | elif self.cnn_name == 'vgg16':
65 | self.net_1 = models.vgg16(pretrained=self.pretraining).features
66 | self.net_2 = models.vgg16(pretrained=self.pretraining).classifier
67 |
68 | self.net_2._modules['6'] = nn.Linear(4096, self.nclasses)
69 |
70 | def forward(self, x):
71 | if self.use_resnet:
72 | return self.net(x)
73 | else:
74 | y = self.net_1(x)
75 | return self.net_2(y.view(y.shape[0], -1))
76 |
77 | class view_GCN(Model):
78 | def __init__(self,name, model, nclasses=40, cnn_name='resnet18', num_views=20):
79 | super(view_GCN,self).__init__(name)
80 | if nclasses == 40:
81 | self.classnames = ['airplane', 'bathtub', 'bed', 'bench', 'bookshelf', 'bottle', 'bowl', 'car', 'chair',
82 | 'cone', 'cup', 'curtain', 'desk', 'door', 'dresser', 'flower_pot', 'glass_box',
83 | 'guitar', 'keyboard', 'lamp', 'laptop', 'mantel', 'monitor', 'night_stand',
84 | 'person', 'piano', 'plant', 'radio', 'range_hood', 'sink', 'sofa', 'stairs',
85 | 'stool', 'table', 'tent', 'toilet', 'tv_stand', 'vase', 'wardrobe', 'xbox']
86 | elif nclasses==15:
87 | self.classnames = ['bag', 'bed', 'bin', 'box', 'cabinet', 'chair', 'desk', 'display'
88 | , 'door', 'pillow', 'shelf', 'sink', 'sofa', 'table', 'toilet']
89 | else:
90 | self.classnames = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11',
91 | '12', '13', '14', '15', '16', '17', '18', '19', '20', '21',
92 | '22', '23', '24', '25', '26', '27', '28', '29', '30', '31',
93 | '32', '33', '34', '35', '36', '37', '38', '39', '40', '41',
94 | '42', '43', '44', '45', '46', '47', '48', '49', '50', '51', '52', '53', '54']
95 | self.nclasses = nclasses
96 | self.mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float, requires_grad=False)
97 | self.std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float, requires_grad=False)
98 | self.use_resnet = cnn_name.startswith('resnet')
99 | if self.use_resnet:
100 | self.net_1 = nn.Sequential(*list(model.net.children())[:-1])
101 | self.net_2 = model.net.fc
102 | else:
103 | self.net_1 = model.net_1
104 | self.net_2 = model.net_2
105 | self.num_views = num_views
106 | self.zdim = 8
107 | self.tgt=nn.Parameter(torch.Tensor(self.zdim,512))
108 | nn.init.xavier_normal_(self.tgt)
109 | self.coord_encoder = nn.Sequential(
110 | nn.Linear(512,64),
111 | nn.ReLU(),
112 | nn.Linear(64,3)
113 | )
114 | self.coord_decoder = nn.Sequential(
115 | nn.Linear(3,64),
116 | nn.ReLU(),
117 | nn.Linear(64,512)
118 | )
119 | self.otk_layer = OTKernel(in_dim=512, out_size=self.zdim, heads=1, max_iter=100, eps=0.05)
120 | self.dim = 512
121 | self.encoder1 = TransformerEncoderLayer(d_model=self.dim,nhead=4)
122 | self.encoder2 = TransformerEncoderLayer(d_model=self.dim,nhead=4)
123 | self.ff = PositionWiseFeedForward(d_model=512, d_ff=512)
124 | self.cls = nn.Sequential(
125 | nn.Linear(512, 256),
126 | nn.Dropout(),
127 | nn.LeakyReLU(0.2, inplace=True),
128 | )
129 | self.cls2 = nn.Linear(256,self.nclasses)
130 | def forward(self,x):
131 | y = self.net_1(x)
132 | y = y.view((int(x.shape[0] / self.num_views), self.num_views, -1))
133 | y0 = self.encoder1(src=y.transpose(0,1), src_key_padding_mask=None, pos=None)
134 | y0 = y0.transpose(0,1)
135 | y1 = self.otk_layer(y0)
136 | y11 = self.ff(y1)
137 | pos0 = normalize(self.coord_encoder(y11))
138 | pos = self.coord_decoder(pos0)
139 | y2 = self.encoder2(src=y11.transpose(0,1),src_key_padding_mask=None, pos=pos.transpose(0,1))
140 |         y2 = y2.transpose(0, 1)
141 | weight = self.otk_layer.weight
142 | cos_sim = torch.matmul(normalize(weight), normalize(weight).transpose(1, 2)) - torch.eye(self.zdim,
143 | self.zdim).cuda()
144 | cos_sim2 = torch.matmul(normalize(y1), normalize(y1).transpose(1, 2)) - torch.eye(self.zdim, self.zdim).cuda()
145 | pooled_view = y2.mean(1)
146 | feature = self.cls(pooled_view)
147 | pooled_view = self.cls2(feature)
148 | return pooled_view,cos_sim,cos_sim2,pos0
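A rough, self-contained sketch of the data flow in `view_GCN.forward` above: per-view CNN features are aggregated into a small fixed set of "canonical view" slots and then mean-pooled for classification. Plain softmax attention against learned prototypes stands in for `OTKernel`, and the coordinate encoder/decoder branch is omitted; all names and sizes below are illustrative only, not the repository's implementation.

```
import torch
import torch.nn as nn

B, V, D, Z, C = 2, 20, 512, 8, 40   # batch, views, feature dim, canonical slots, classes

class ToyCanonicalPooling(nn.Module):
    def __init__(self):
        super().__init__()
        self.prototypes = nn.Parameter(torch.randn(Z, D))   # plays the role of otk_layer.weight
        self.head = nn.Sequential(nn.Linear(D, 256), nn.LeakyReLU(0.2), nn.Linear(256, C))

    def forward(self, view_feats):                           # (B, V, D) per-view features
        attn = torch.softmax(view_feats @ self.prototypes.t() / D ** 0.5, dim=1)  # (B, V, Z)
        canonical = attn.transpose(1, 2) @ view_feats        # (B, Z, D) aggregated slots
        return self.head(canonical.mean(1))                  # (B, C) logits

feats = torch.randn(B * V, D).view(B, V, D)   # mimics net_1 output reshaped per object
print(ToyCanonicalPooling()(feats).shape)     # torch.Size([2, 40])
```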
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .transformer import Transformer
2 | from .captioning_model import CaptioningModel
3 |
--------------------------------------------------------------------------------
/models/__pycache__/Vit.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixmath/CVR/af2ea6859230c1a5f56868a2242c06c2fe02f05a/models/__pycache__/Vit.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/Vit.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixmath/CVR/af2ea6859230c1a5f56868a2242c06c2fe02f05a/models/__pycache__/Vit.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixmath/CVR/af2ea6859230c1a5f56868a2242c06c2fe02f05a/models/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/captioning_model.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixmath/CVR/af2ea6859230c1a5f56868a2242c06c2fe02f05a/models/__pycache__/captioning_model.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/containers.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixmath/CVR/af2ea6859230c1a5f56868a2242c06c2fe02f05a/models/__pycache__/containers.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixmath/CVR/af2ea6859230c1a5f56868a2242c06c2fe02f05a/models/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixmath/CVR/af2ea6859230c1a5f56868a2242c06c2fe02f05a/models/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/models/beam_search/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/models/beam_search/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixmath/CVR/af2ea6859230c1a5f56868a2242c06c2fe02f05a/models/beam_search/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/models/beam_search/beam_search.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import utils
3 |
4 |
5 | class BeamSearch(object):
6 | def __init__(self, model, max_len: int, eos_idx: int, beam_size: int):
7 | self.model = model
8 | self.max_len = max_len
9 | self.eos_idx = eos_idx
10 | self.beam_size = beam_size
11 | self.b_s = None
12 | self.device = None
13 | self.seq_mask = None
14 | self.seq_logprob = None
15 | self.outputs = None
16 | self.log_probs = None
17 | self.selected_words = None
18 | self.all_log_probs = None
19 |
20 | def _expand_state(self, selected_beam, cur_beam_size):
21 | def fn(s):
22 | shape = [int(sh) for sh in s.shape]
23 | beam = selected_beam
24 | for _ in shape[1:]:
25 | beam = beam.unsqueeze(-1)
26 | s = torch.gather(s.view(*([self.b_s, cur_beam_size] + shape[1:])), 1,
27 | beam.expand(*([self.b_s, self.beam_size] + shape[1:])))
28 | s = s.view(*([-1, ] + shape[1:]))
29 | return s
30 |
31 | return fn
32 |
33 | def _expand_visual(self, visual: utils.TensorOrSequence, cur_beam_size: int, selected_beam: torch.Tensor):
34 | if isinstance(visual, torch.Tensor):
35 | visual_shape = visual.shape
36 | visual_exp_shape = (self.b_s, cur_beam_size) + visual_shape[1:]
37 | visual_red_shape = (self.b_s * self.beam_size,) + visual_shape[1:]
38 | selected_beam_red_size = (self.b_s, self.beam_size) + tuple(1 for _ in range(len(visual_exp_shape) - 2))
39 | selected_beam_exp_size = (self.b_s, self.beam_size) + visual_exp_shape[2:]
40 | visual_exp = visual.view(visual_exp_shape)
41 | selected_beam_exp = selected_beam.view(selected_beam_red_size).expand(selected_beam_exp_size)
42 | visual = torch.gather(visual_exp, 1, selected_beam_exp).view(visual_red_shape)
43 | else:
44 | new_visual = []
45 | for im in visual:
46 | visual_shape = im.shape
47 | visual_exp_shape = (self.b_s, cur_beam_size) + visual_shape[1:]
48 | visual_red_shape = (self.b_s * self.beam_size,) + visual_shape[1:]
49 | selected_beam_red_size = (self.b_s, self.beam_size) + tuple(1 for _ in range(len(visual_exp_shape) - 2))
50 | selected_beam_exp_size = (self.b_s, self.beam_size) + visual_exp_shape[2:]
51 | visual_exp = im.view(visual_exp_shape)
52 | selected_beam_exp = selected_beam.view(selected_beam_red_size).expand(selected_beam_exp_size)
53 | new_im = torch.gather(visual_exp, 1, selected_beam_exp).view(visual_red_shape)
54 | new_visual.append(new_im)
55 | visual = tuple(new_visual)
56 | return visual
57 |
58 | def apply(self, visual: utils.TensorOrSequence, out_size=1, return_probs=False, **kwargs):
59 | self.b_s = utils.get_batch_size(visual)
60 | self.device = utils.get_device(visual)
61 | self.seq_mask = torch.ones((self.b_s, self.beam_size, 1), device=self.device)
62 | self.seq_logprob = torch.zeros((self.b_s, 1, 1), device=self.device)
63 | self.log_probs = []
64 | self.selected_words = None
65 | if return_probs:
66 | self.all_log_probs = []
67 |
68 | outputs = []
69 | with self.model.statefulness(self.b_s):
70 | for t in range(self.max_len):
71 | visual, outputs = self.iter(t, visual, outputs, return_probs, **kwargs)
72 |
73 | # Sort result
74 | seq_logprob, sort_idxs = torch.sort(self.seq_logprob, 1, descending=True)
75 | outputs = torch.cat(outputs, -1)
76 | outputs = torch.gather(outputs, 1, sort_idxs.expand(self.b_s, self.beam_size, self.max_len))
77 | log_probs = torch.cat(self.log_probs, -1)
78 | log_probs = torch.gather(log_probs, 1, sort_idxs.expand(self.b_s, self.beam_size, self.max_len))
79 | if return_probs:
80 | all_log_probs = torch.cat(self.all_log_probs, 2)
81 | all_log_probs = torch.gather(all_log_probs, 1, sort_idxs.unsqueeze(-1).expand(self.b_s, self.beam_size,
82 | self.max_len,
83 | all_log_probs.shape[-1]))
84 |
85 | outputs = outputs.contiguous()[:, :out_size]
86 | log_probs = log_probs.contiguous()[:, :out_size]
87 | if out_size == 1:
88 | outputs = outputs.squeeze(1)
89 | log_probs = log_probs.squeeze(1)
90 |
91 | if return_probs:
92 | return outputs, log_probs, all_log_probs
93 | else:
94 | return outputs, log_probs
95 |
96 | def select(self, t, candidate_logprob, **kwargs):
97 | selected_logprob, selected_idx = torch.sort(candidate_logprob.view(self.b_s, -1), -1, descending=True)
98 | selected_logprob, selected_idx = selected_logprob[:, :self.beam_size], selected_idx[:, :self.beam_size]
99 | return selected_idx, selected_logprob
100 |
101 | def iter(self, t: int, visual: utils.TensorOrSequence, outputs, return_probs, **kwargs):
102 | cur_beam_size = 1 if t == 0 else self.beam_size
103 |
104 | word_logprob = self.model.step(t, self.selected_words, visual, None, mode='feedback', **kwargs)
105 | word_logprob = word_logprob.view(self.b_s, cur_beam_size, -1)
106 | candidate_logprob = self.seq_logprob + word_logprob
107 |
108 | # Mask sequence if it reaches EOS
109 | if t > 0:
110 | mask = (self.selected_words.view(self.b_s, cur_beam_size) != self.eos_idx).float().unsqueeze(-1)
111 | self.seq_mask = self.seq_mask * mask
112 | word_logprob = word_logprob * self.seq_mask.expand_as(word_logprob)
113 | old_seq_logprob = self.seq_logprob.expand_as(candidate_logprob).contiguous()
114 | old_seq_logprob[:, :, 1:] = -999
115 | candidate_logprob = self.seq_mask * candidate_logprob + old_seq_logprob * (1 - self.seq_mask)
116 |
117 | selected_idx, selected_logprob = self.select(t, candidate_logprob, **kwargs)
118 |         selected_beam = torch.div(selected_idx, candidate_logprob.shape[-1], rounding_mode='floor')  # integer division keeps a long index for the gathers below
119 |         selected_words = selected_idx - selected_beam * candidate_logprob.shape[-1]
120 |
121 | self.model.apply_to_states(self._expand_state(selected_beam, cur_beam_size))
122 | visual = self._expand_visual(visual, cur_beam_size, selected_beam)
123 |
124 | self.seq_logprob = selected_logprob.unsqueeze(-1)
125 | self.seq_mask = torch.gather(self.seq_mask, 1, selected_beam.unsqueeze(-1))
126 | outputs = list(torch.gather(o, 1, selected_beam.unsqueeze(-1)) for o in outputs)
127 | outputs.append(selected_words.unsqueeze(-1))
128 |
129 | if return_probs:
130 | if t == 0:
131 | self.all_log_probs.append(word_logprob.expand((self.b_s, self.beam_size, -1)).unsqueeze(2))
132 | else:
133 | self.all_log_probs.append(word_logprob.unsqueeze(2))
134 |
135 | this_word_logprob = torch.gather(word_logprob, 1,
136 | selected_beam.unsqueeze(-1).expand(self.b_s, self.beam_size,
137 | word_logprob.shape[-1]))
138 | this_word_logprob = torch.gather(this_word_logprob, 2, selected_words.unsqueeze(-1))
139 | self.log_probs = list(
140 | torch.gather(o, 1, selected_beam.unsqueeze(-1).expand(self.b_s, self.beam_size, 1)) for o in self.log_probs)
141 | self.log_probs.append(this_word_logprob)
142 | self.selected_words = selected_words.view(-1, 1)
143 |
144 | return visual, outputs
145 |
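A minimal, self-contained illustration of the index bookkeeping in `iter` above: the flat top-k index over `beam_size * vocab` candidates is split into a beam index (integer division) and a word index (remainder). Shapes and values below are made up.

```
import torch

b_s, beam_size, vocab = 1, 3, 5
candidate_logprob = torch.randn(b_s, beam_size, vocab)
flat = candidate_logprob.view(b_s, -1)
_, idx = flat.sort(-1, descending=True)
idx = idx[:, :beam_size]                                       # keep the beam_size best candidates
selected_beam = torch.div(idx, vocab, rounding_mode='floor')   # which beam each candidate came from
selected_words = idx - selected_beam * vocab                   # which word in the vocabulary
print(selected_beam, selected_words)
```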
--------------------------------------------------------------------------------
/models/captioning_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import distributions
3 | import models.utils as utils
4 | from models.containers import Module
5 | from models.beam_search import *
6 |
7 |
8 | class CaptioningModel(Module):
9 | def __init__(self):
10 | super(CaptioningModel, self).__init__()
11 |
12 | def init_weights(self):
13 | raise NotImplementedError
14 |
15 | def step(self, t, prev_output, visual, seq, mode='teacher_forcing', **kwargs):
16 | raise NotImplementedError
17 |
18 | def forward(self, images, seq, *args):
19 | device = images.device
20 | b_s = images.size(0)
21 | seq_len = seq.size(1)
22 | state = self.init_state(b_s, device)
23 | out = None
24 |
25 | outputs = []
26 | for t in range(seq_len):
27 | out, state = self.step(t, state, out, images, seq, *args, mode='teacher_forcing')
28 | outputs.append(out)
29 |
30 | outputs = torch.cat([o.unsqueeze(1) for o in outputs], 1)
31 | return outputs
32 |
33 | def test(self, visual: utils.TensorOrSequence, max_len: int, eos_idx: int, **kwargs) -> utils.Tuple[torch.Tensor, torch.Tensor]:
34 | b_s = utils.get_batch_size(visual)
35 | device = utils.get_device(visual)
36 | outputs = []
37 | log_probs = []
38 |
39 | mask = torch.ones((b_s,), device=device)
40 | with self.statefulness(b_s):
41 | out = None
42 | for t in range(max_len):
43 | log_probs_t = self.step(t, out, visual, None, mode='feedback', **kwargs)
44 | out = torch.max(log_probs_t, -1)[1]
45 | mask = mask * (out.squeeze(-1) != eos_idx).float()
46 | log_probs.append(log_probs_t * mask.unsqueeze(-1).unsqueeze(-1))
47 | outputs.append(out)
48 |
49 | return torch.cat(outputs, 1), torch.cat(log_probs, 1)
50 |
51 | def sample_rl(self, visual: utils.TensorOrSequence, max_len: int, **kwargs) -> utils.Tuple[torch.Tensor, torch.Tensor]:
52 | b_s = utils.get_batch_size(visual)
53 | outputs = []
54 | log_probs = []
55 |
56 | with self.statefulness(b_s):
57 | out = None
58 | for t in range(max_len):
59 | out = self.step(t, out, visual, None, mode='feedback', **kwargs)
60 | distr = distributions.Categorical(logits=out[:, 0])
61 | out = distr.sample().unsqueeze(1)
62 | outputs.append(out)
63 | log_probs.append(distr.log_prob(out).unsqueeze(1))
64 |
65 | return torch.cat(outputs, 1), torch.cat(log_probs, 1)
66 |
67 | def beam_search(self, visual: utils.TensorOrSequence, max_len: int, eos_idx: int, beam_size: int, out_size=1,
68 | return_probs=False, **kwargs):
69 | bs = BeamSearch(self, max_len, eos_idx, beam_size)
70 | return bs.apply(visual, out_size, return_probs, **kwargs)
71 |
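A self-contained illustration of the sampling step used in `sample_rl` above: draw one token per batch element from a step's log-probabilities and keep its log-probability for the policy-gradient update. The vocabulary size and tensors here are made up.

```
import torch
from torch import distributions

b_s, vocab = 4, 1000
log_probs_t = torch.log_softmax(torch.randn(b_s, 1, vocab), dim=-1)   # step output, (b_s, 1, vocab)
distr = distributions.Categorical(logits=log_probs_t[:, 0])
token = distr.sample()                                                # (b_s,) sampled word ids
token_logprob = distr.log_prob(token)                                 # (b_s,) their log-probabilities
out = token.unsqueeze(1)                                              # fed back as the next input
```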
--------------------------------------------------------------------------------
/models/containers.py:
--------------------------------------------------------------------------------
1 | from contextlib import contextmanager
2 | from torch import nn
3 | from typing import Union, Sequence, Tuple
4 | import torch
5 |
6 | TensorOrSequence = Union[Sequence[torch.Tensor], torch.Tensor]
7 | TensorOrNone = Union[torch.Tensor, None]
8 |
9 |
10 | class Module(nn.Module):
11 | def __init__(self):
12 | super(Module, self).__init__()
13 | self._is_stateful = False
14 | self._state_names = []
15 | self._state_defaults = dict()
16 |
17 | def register_state(self, name: str, default: TensorOrNone):
18 | self._state_names.append(name)
19 | if default is None:
20 | self._state_defaults[name] = None
21 | else:
22 | self._state_defaults[name] = default.clone().detach()
23 | self.register_buffer(name, default)
24 |
25 | def states(self):
26 | for name in self._state_names:
27 | yield self._buffers[name]
28 | for m in self.children():
29 | if isinstance(m, Module):
30 | yield from m.states()
31 |
32 | def apply_to_states(self, fn):
33 | for name in self._state_names:
34 | self._buffers[name] = fn(self._buffers[name])
35 | for m in self.children():
36 | if isinstance(m, Module):
37 | m.apply_to_states(fn)
38 |
39 | def _init_states(self, batch_size: int):
40 | for name in self._state_names:
41 | if self._state_defaults[name] is None:
42 | self._buffers[name] = None
43 | else:
44 | self._buffers[name] = self._state_defaults[name].clone().detach().to(self._buffers[name].device)
45 | self._buffers[name] = self._buffers[name].unsqueeze(0)
46 | self._buffers[name] = self._buffers[name].expand([batch_size, ] + list(self._buffers[name].shape[1:]))
47 | self._buffers[name] = self._buffers[name].contiguous()
48 |
49 | def _reset_states(self):
50 | for name in self._state_names:
51 | if self._state_defaults[name] is None:
52 | self._buffers[name] = None
53 | else:
54 | self._buffers[name] = self._state_defaults[name].clone().detach().to(self._buffers[name].device)
55 |
56 | def enable_statefulness(self, batch_size: int):
57 | for m in self.children():
58 | if isinstance(m, Module):
59 | m.enable_statefulness(batch_size)
60 | self._init_states(batch_size)
61 | self._is_stateful = True
62 |
63 | def disable_statefulness(self):
64 | for m in self.children():
65 | if isinstance(m, Module):
66 | m.disable_statefulness()
67 | self._reset_states()
68 | self._is_stateful = False
69 |
70 | @contextmanager
71 | def statefulness(self, batch_size: int):
72 | self.enable_statefulness(batch_size)
73 | try:
74 | yield
75 | finally:
76 | self.disable_statefulness()
77 |
78 |
79 | class ModuleList(nn.ModuleList, Module):
80 | pass
81 |
82 |
83 | class ModuleDict(nn.ModuleDict, Module):
84 | pass
85 |
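Usage sketch for the stateful `Module` machinery above (assumes the repository root is on `PYTHONPATH`): a registered state gets a per-batch copy inside `statefulness` and is reset when the context exits. The `RunningSum` class is a made-up example.

```
import torch
from models.containers import Module

class RunningSum(Module):
    def __init__(self):
        super().__init__()
        self.register_state('total', torch.zeros(1))   # default value, expanded to the batch size

    def forward(self, x):                              # x: (batch, 1)
        if self._is_stateful:
            self.total = self.total + x
            return self.total
        return x

m = RunningSum()
with m.statefulness(batch_size=4):
    for _ in range(3):
        out = m(torch.ones(4, 1))
print(out)   # (4, 1) tensor filled with 3.0; 'total' is reset once the with-block exits
```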
--------------------------------------------------------------------------------
/models/transformer/__init__.py:
--------------------------------------------------------------------------------
1 | from .transformer import *
2 | from .encoders import *
3 | from .decoders import *
4 | from .attention import *
5 |
--------------------------------------------------------------------------------
/models/transformer/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixmath/CVR/af2ea6859230c1a5f56868a2242c06c2fe02f05a/models/transformer/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/models/transformer/__pycache__/attention.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixmath/CVR/af2ea6859230c1a5f56868a2242c06c2fe02f05a/models/transformer/__pycache__/attention.cpython-36.pyc
--------------------------------------------------------------------------------
/models/transformer/__pycache__/decoders.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixmath/CVR/af2ea6859230c1a5f56868a2242c06c2fe02f05a/models/transformer/__pycache__/decoders.cpython-36.pyc
--------------------------------------------------------------------------------
/models/transformer/__pycache__/encoders.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixmath/CVR/af2ea6859230c1a5f56868a2242c06c2fe02f05a/models/transformer/__pycache__/encoders.cpython-36.pyc
--------------------------------------------------------------------------------
/models/transformer/__pycache__/transformer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixmath/CVR/af2ea6859230c1a5f56868a2242c06c2fe02f05a/models/transformer/__pycache__/transformer.cpython-36.pyc
--------------------------------------------------------------------------------
/models/transformer/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixmath/CVR/af2ea6859230c1a5f56868a2242c06c2fe02f05a/models/transformer/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/models/transformer/decoders.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 | import numpy as np
5 |
6 | from models.transformer.attention import MultiHeadAttention
7 | from models.transformer.utils import sinusoid_encoding_table, PositionWiseFeedForward
8 | from models.containers import Module, ModuleList
9 |
10 |
11 | class MeshedDecoderLayer(Module):
12 | def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, self_att_module=None,
13 | enc_att_module=None, self_att_module_kwargs=None, enc_att_module_kwargs=None):
14 | super(MeshedDecoderLayer, self).__init__()
15 | self.self_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=True,
16 | attention_module=self_att_module,
17 | attention_module_kwargs=self_att_module_kwargs)
18 | self.enc_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=False,
19 | attention_module=enc_att_module,
20 | attention_module_kwargs=enc_att_module_kwargs)
21 | self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout)
22 |
23 | self.fc_alpha1 = nn.Linear(d_model + d_model, d_model)
24 | self.fc_alpha2 = nn.Linear(d_model + d_model, d_model)
25 | self.fc_alpha3 = nn.Linear(d_model + d_model, d_model)
26 |
27 | self.init_weights()
28 |
29 | def init_weights(self):
30 | nn.init.xavier_uniform_(self.fc_alpha1.weight)
31 | nn.init.xavier_uniform_(self.fc_alpha2.weight)
32 | nn.init.xavier_uniform_(self.fc_alpha3.weight)
33 | nn.init.constant_(self.fc_alpha1.bias, 0)
34 | nn.init.constant_(self.fc_alpha2.bias, 0)
35 | nn.init.constant_(self.fc_alpha3.bias, 0)
36 |
37 | def forward(self, input, enc_output, mask_pad, mask_self_att, mask_enc_att):
38 | self_att = self.self_att(input, input, input, mask_self_att)
39 | self_att = self_att * mask_pad
40 |
41 | enc_att1 = self.enc_att(self_att, enc_output[:, 0], enc_output[:, 0], mask_enc_att) * mask_pad
42 | enc_att2 = self.enc_att(self_att, enc_output[:, 1], enc_output[:, 1], mask_enc_att) * mask_pad
43 | enc_att3 = self.enc_att(self_att, enc_output[:, 2], enc_output[:, 2], mask_enc_att) * mask_pad
44 |
45 | alpha1 = torch.sigmoid(self.fc_alpha1(torch.cat([self_att, enc_att1], -1)))
46 | alpha2 = torch.sigmoid(self.fc_alpha2(torch.cat([self_att, enc_att2], -1)))
47 | alpha3 = torch.sigmoid(self.fc_alpha3(torch.cat([self_att, enc_att3], -1)))
48 |
49 | enc_att = (enc_att1 * alpha1 + enc_att2 * alpha2 + enc_att3 * alpha3) / np.sqrt(3)
50 | enc_att = enc_att * mask_pad
51 |
52 | ff = self.pwff(enc_att)
53 | ff = ff * mask_pad
54 | return ff
55 |
56 |
57 | class MeshedDecoder(Module):
58 | def __init__(self, vocab_size, max_len, N_dec, padding_idx, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1,
59 | self_att_module=None, enc_att_module=None, self_att_module_kwargs=None, enc_att_module_kwargs=None):
60 | super(MeshedDecoder, self).__init__()
61 | self.d_model = d_model
62 | self.word_emb = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx)
63 | self.pos_emb = nn.Embedding.from_pretrained(sinusoid_encoding_table(max_len + 1, d_model, 0), freeze=True)
64 | self.layers = ModuleList(
65 | [MeshedDecoderLayer(d_model, d_k, d_v, h, d_ff, dropout, self_att_module=self_att_module,
66 | enc_att_module=enc_att_module, self_att_module_kwargs=self_att_module_kwargs,
67 | enc_att_module_kwargs=enc_att_module_kwargs) for _ in range(N_dec)])
68 | self.fc = nn.Linear(d_model, vocab_size, bias=False)
69 | self.max_len = max_len
70 | self.padding_idx = padding_idx
71 | self.N = N_dec
72 |
73 | self.register_state('running_mask_self_attention', torch.zeros((1, 1, 0)).byte())
74 | self.register_state('running_seq', torch.zeros((1,)).long())
75 |
76 | def forward(self, input, encoder_output, mask_encoder):
77 | # input (b_s, seq_len)
78 | b_s, seq_len = input.shape[:2]
79 | mask_queries = (input != self.padding_idx).unsqueeze(-1).float() # (b_s, seq_len, 1)
80 | mask_self_attention = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8, device=input.device),
81 | diagonal=1)
82 | mask_self_attention = mask_self_attention.unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, seq_len)
83 | mask_self_attention = mask_self_attention + (input == self.padding_idx).unsqueeze(1).unsqueeze(1).byte()
84 | mask_self_attention = mask_self_attention.gt(0) # (b_s, 1, seq_len, seq_len)
85 | if self._is_stateful:
86 | self.running_mask_self_attention = torch.cat([self.running_mask_self_attention, mask_self_attention], -1)
87 | mask_self_attention = self.running_mask_self_attention
88 |
89 | seq = torch.arange(1, seq_len + 1).view(1, -1).expand(b_s, -1).to(input.device) # (b_s, seq_len)
90 | seq = seq.masked_fill(mask_queries.squeeze(-1) == 0, 0)
91 | if self._is_stateful:
92 | self.running_seq.add_(1)
93 | seq = self.running_seq
94 |
95 | out = self.word_emb(input) + self.pos_emb(seq)
96 | for i, l in enumerate(self.layers):
97 | out = l(out, encoder_output, mask_queries, mask_self_attention, mask_encoder)
98 |
99 | out = self.fc(out)
100 | return F.log_softmax(out, dim=-1)
101 |
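A self-contained check of the self-attention mask built in `MeshedDecoder.forward` above: a causal upper-triangular mask is OR-ed with the padding positions, so each position can attend only to non-padded earlier positions. `padding_idx = 0` here is an assumption for the example.

```
import torch

inp = torch.tensor([[5, 7, 2, 0, 0]])                  # (b_s, seq_len), 0 = padding
seq_len = inp.shape[1]
causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.uint8), diagonal=1)
mask = causal.unsqueeze(0).unsqueeze(0) + (inp == 0).unsqueeze(1).unsqueeze(1).byte()
mask = mask.gt(0)                                      # True = masked out
print(mask[0, 0].int())                                # lower triangle open, last two columns masked
```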
--------------------------------------------------------------------------------
/models/transformer/encoders.py:
--------------------------------------------------------------------------------
1 | from torch.nn import functional as F
2 | from models.transformer.utils import PositionWiseFeedForward,PositionWiseFeedForward_BN
3 | import torch
4 | from torch import nn
5 | from models.transformer.attention import MultiHeadAttention,MultiHeadAttention_BN,MultiHeadAttention_BN2
6 |
7 | class EncoderLayer(nn.Module):
8 | def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, identity_map_reordering=False,
9 | attention_module=None, attention_module_kwargs=None):
10 | super(EncoderLayer, self).__init__()
11 | self.identity_map_reordering = identity_map_reordering
12 | self.mhatt = MultiHeadAttention(d_model, d_k, d_v, h, dropout, identity_map_reordering=identity_map_reordering,
13 | attention_module=attention_module,
14 | attention_module_kwargs=attention_module_kwargs)
15 | self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout, identity_map_reordering=identity_map_reordering)
16 |
17 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
18 | att = self.mhatt(queries, keys, values, attention_mask, attention_weights)
19 | # ff = self.pwff(att)
20 | return att
21 | class EncoderLayer_BN(nn.Module):
22 | def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, identity_map_reordering=False,
23 | attention_module=None, attention_module_kwargs=None):
24 | super(EncoderLayer_BN, self).__init__()
25 | self.identity_map_reordering = identity_map_reordering
26 | self.mhatt = MultiHeadAttention_BN(d_model, d_k, d_v, h, dropout, identity_map_reordering=identity_map_reordering,
27 | attention_module=attention_module,
28 | attention_module_kwargs=attention_module_kwargs)
29 | self.pwff = PositionWiseFeedForward_BN(d_model, d_ff, dropout, identity_map_reordering=identity_map_reordering)
30 |
31 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
32 | att = self.mhatt(queries, keys, values, attention_mask, attention_weights)
33 | # ff = self.pwff(att)
34 | return att
35 | class EncoderLayer_BN2(nn.Module):
36 | def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, identity_map_reordering=False,
37 | attention_module=None, attention_module_kwargs=None):
38 | super(EncoderLayer_BN2, self).__init__()
39 | self.identity_map_reordering = identity_map_reordering
40 | self.mhatt = MultiHeadAttention_BN2(d_model, d_k, d_v, h, dropout, identity_map_reordering=identity_map_reordering,
41 | attention_module=attention_module,
42 | attention_module_kwargs=attention_module_kwargs)
43 | self.pwff = PositionWiseFeedForward_BN(d_model, d_ff, dropout, identity_map_reordering=identity_map_reordering)
44 |
45 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
46 | att = self.mhatt(queries, keys, values, attention_mask, attention_weights)
47 | # ff = self.pwff(att)
48 | return att
49 | class MultiLevelEncoder(nn.Module):
50 | def __init__(self, N, padding_idx, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1,
51 | identity_map_reordering=False, attention_module=None, attention_module_kwargs=None):
52 | super(MultiLevelEncoder, self).__init__()
53 | self.d_model = d_model
54 | self.dropout = dropout
55 | self.layers = nn.ModuleList([EncoderLayer(d_model, d_k, d_v, h, d_ff, dropout,
56 | identity_map_reordering=identity_map_reordering,
57 | attention_module=attention_module,
58 | attention_module_kwargs=attention_module_kwargs)
59 | for _ in range(N)])
60 | self.padding_idx = padding_idx
61 |
62 | def forward(self, input, attention_weights=None):
63 | # input (b_s, seq_len, d_in)
64 | attention_mask = (torch.sum(input, -1) == self.padding_idx).unsqueeze(1).unsqueeze(1) # (b_s, 1, 1, seq_len)
65 |
66 | outs = []
67 | out = input
68 | for l in self.layers:
69 | out = l(out, out, out, attention_mask, attention_weights)
70 | outs.append(out.unsqueeze(1))
71 |
72 | outs = torch.cat(outs, 1)
73 | return outs, attention_mask
74 |
75 |
76 | class MemoryAugmentedEncoder(MultiLevelEncoder):
77 | def __init__(self, N, padding_idx, d_in=2048, **kwargs):
78 | super(MemoryAugmentedEncoder, self).__init__(N, padding_idx, **kwargs)
79 | self.fc = nn.Linear(d_in, self.d_model)
80 | self.dropout = nn.Dropout(p=self.dropout)
81 | self.layer_norm = nn.LayerNorm(self.d_model)
82 |
83 | def forward(self, input, attention_weights=None):
84 | out = F.relu(self.fc(input))
85 | out = self.dropout(out)
86 | out = self.layer_norm(out)
87 | return super(MemoryAugmentedEncoder, self).forward(out, attention_weights=attention_weights)
88 |
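A tiny check of the padding detection used in `MultiLevelEncoder.forward` above: a region whose feature vector sums to `padding_idx` (0 in this example, which assumes zero-padded regions) is excluded from attention.

```
import torch

feats = torch.zeros(1, 3, 4)                       # (b_s, seq_len, d_in), last region zero-padded
feats[0, :2] = torch.randn(2, 4)                   # two real regions
attention_mask = (feats.sum(-1) == 0).unsqueeze(1).unsqueeze(1)   # (b_s, 1, 1, seq_len)
print(attention_mask)                              # tensor([[[[False, False,  True]]]])
```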
--------------------------------------------------------------------------------
/models/transformer/transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import copy
4 | from models.containers import ModuleList
5 | from ..captioning_model import CaptioningModel
6 |
7 |
8 | class Transformer(CaptioningModel):
9 | def __init__(self, bos_idx, encoder, decoder):
10 | super(Transformer, self).__init__()
11 | self.bos_idx = bos_idx
12 | self.encoder = encoder
13 | self.decoder = decoder
14 | self.register_state('enc_output', None)
15 | self.register_state('mask_enc', None)
16 | self.init_weights()
17 |
18 | @property
19 | def d_model(self):
20 | return self.decoder.d_model
21 |
22 | def init_weights(self):
23 | for p in self.parameters():
24 | if p.dim() > 1:
25 | nn.init.xavier_uniform_(p)
26 |
27 | def forward(self, images, seq, *args):
28 | enc_output, mask_enc = self.encoder(images)
29 | dec_output = self.decoder(seq, enc_output, mask_enc)
30 | return dec_output
31 |
32 | def init_state(self, b_s, device):
33 | return [torch.zeros((b_s, 0), dtype=torch.long, device=device),
34 | None, None]
35 |
36 | def step(self, t, prev_output, visual, seq, mode='teacher_forcing', **kwargs):
37 | it = None
38 | if mode == 'teacher_forcing':
39 | raise NotImplementedError
40 | elif mode == 'feedback':
41 | if t == 0:
42 | self.enc_output, self.mask_enc = self.encoder(visual)
43 | if isinstance(visual, torch.Tensor):
44 | it = visual.data.new_full((visual.shape[0], 1), self.bos_idx).long()
45 | else:
46 | it = visual[0].data.new_full((visual[0].shape[0], 1), self.bos_idx).long()
47 | else:
48 | it = prev_output
49 |
50 | return self.decoder(it, self.enc_output, self.mask_enc)
51 |
52 |
53 | class TransformerEnsemble(CaptioningModel):
54 | def __init__(self, model: Transformer, weight_files):
55 | super(TransformerEnsemble, self).__init__()
56 | self.n = len(weight_files)
57 | self.models = ModuleList([copy.deepcopy(model) for _ in range(self.n)])
58 | for i in range(self.n):
59 | state_dict_i = torch.load(weight_files[i])['state_dict']
60 | self.models[i].load_state_dict(state_dict_i)
61 |
62 | def step(self, t, prev_output, visual, seq, mode='teacher_forcing', **kwargs):
63 | out_ensemble = []
64 | for i in range(self.n):
65 | out_i = self.models[i].step(t, prev_output, visual, seq, mode, **kwargs)
66 | out_ensemble.append(out_i.unsqueeze(0))
67 |
68 | return torch.mean(torch.cat(out_ensemble, 0), dim=0)
69 |
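A hedged wiring sketch for the classes above. Hyperparameters are illustrative, the repository root is assumed to be on `PYTHONPATH`, and the attention module in `models/transformer/attention.py` (not shown here) is assumed to follow the standard meshed-memory-transformer interface.

```
import torch
from models.transformer import Transformer, MemoryAugmentedEncoder, MeshedDecoder

vocab_size, max_len, pad_idx, bos_idx = 1000, 20, 0, 2
encoder = MemoryAugmentedEncoder(N=3, padding_idx=0, d_in=2048)        # three encoder layers
decoder = MeshedDecoder(vocab_size, max_len, N_dec=3, padding_idx=pad_idx)
model = Transformer(bos_idx, encoder, decoder)

images = torch.randn(2, 50, 2048)                   # (b_s, regions, d_in) region features
seq = torch.randint(3, vocab_size, (2, max_len))    # (b_s, seq_len) ground-truth captions
log_probs = model(images, seq)                      # (b_s, seq_len, vocab_size) log-probabilities
```

Note that `MeshedDecoder` indexes the first three encoder levels (`enc_output[:, 0]` through `enc_output[:, 2]`), so the encoder must be built with at least three layers.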
--------------------------------------------------------------------------------
/models/transformer/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 |
6 | def position_embedding(input, d_model):
7 | input = input.view(-1, 1)
8 | dim = torch.arange(d_model // 2, dtype=torch.float32, device=input.device).view(1, -1)
9 | sin = torch.sin(input / 10000 ** (2 * dim / d_model))
10 | cos = torch.cos(input / 10000 ** (2 * dim / d_model))
11 |
12 | out = torch.zeros((input.shape[0], d_model), device=input.device)
13 | out[:, ::2] = sin
14 | out[:, 1::2] = cos
15 | return out
16 |
17 |
18 | def sinusoid_encoding_table(max_len, d_model, padding_idx=None):
19 | pos = torch.arange(max_len, dtype=torch.float32)
20 | out = position_embedding(pos, d_model)
21 |
22 | if padding_idx is not None:
23 | out[padding_idx] = 0
24 | return out
25 |
26 |
27 | class PositionWiseFeedForward(nn.Module):
28 | '''
29 | Position-wise feed forward layer
30 | '''
31 |
32 | def __init__(self, d_model=512, d_ff=2048, dropout=.1, identity_map_reordering=False):
33 | super(PositionWiseFeedForward, self).__init__()
34 | self.identity_map_reordering = identity_map_reordering
35 | self.fc1 = nn.Linear(d_model, d_ff)
36 | self.fc2 = nn.Linear(d_ff, d_model)
37 | self.dropout = nn.Dropout(p=dropout)
38 | self.dropout_2 = nn.Dropout(p=dropout)
39 | self.layer_norm = nn.LayerNorm(d_model)
40 |
41 | def forward(self, input):
42 | if self.identity_map_reordering:
43 | out = self.layer_norm(input)
44 | out = self.fc2(self.dropout_2(F.relu(self.fc1(out))))
45 | out = input + self.dropout(torch.relu(out))
46 | else:
47 | out = self.fc2(self.dropout_2(F.relu(self.fc1(input))))
48 | # out = self.fc2(self.dropout_2(F.gelu(self.fc1(input))))
49 | out = self.dropout(out)
50 | out = input + out
51 | # out = self.layer_norm(input + out)
52 | return out
53 |
54 | class PositionWiseFeedForward_BN(nn.Module):
55 | '''
56 | Position-wise feed forward layer
57 | '''
58 |
59 | def __init__(self, d_model=512, d_ff=2048, dropout=.1, identity_map_reordering=False):
60 | super(PositionWiseFeedForward_BN, self).__init__()
61 | self.identity_map_reordering = identity_map_reordering
62 | self.fc1 = nn.Linear(d_model, d_ff)
63 | self.fc2 = nn.Linear(d_ff, d_model)
64 | self.dropout = nn.Dropout(p=dropout)
65 | self.dropout_2 = nn.Dropout(p=dropout)
66 | self.layer_norm = nn.LayerNorm(d_model)
67 | self.bn = nn.BatchNorm1d(d_ff)
68 | def forward(self, input):
69 | if self.identity_map_reordering:
70 | out = self.layer_norm(input)
71 | out = self.fc2(self.dropout_2(F.relu(self.fc1(out))))
72 | out = input + self.dropout(torch.relu(out))
73 | else:
74 | # out = self.fc2(self.dropout_2(F.relu(self.fc1(input))))
75 | # out = self.dropout(out)
76 |             # out = (input + out).transpose(1,2)
77 | # out = self.bn(out).transpose(1,2)
78 | # out = self.layer_norm(input + out)
79 | out = self.fc1(input).transpose(1,2)
80 | out = F.relu(self.bn(out).transpose(1,2))
81 | out = self.fc2(out)+input
82 |
83 | return out
84 |
85 | class PositionWiseFeedForward_LN(nn.Module):
86 | '''
87 | Position-wise feed forward layer
88 | '''
89 |
90 | def __init__(self, d_model=512, d_ff=2048, dropout=.1, identity_map_reordering=False):
91 | super(PositionWiseFeedForward_LN, self).__init__()
92 | self.identity_map_reordering = identity_map_reordering
93 | self.fc1 = nn.Linear(d_model, d_ff)
94 | self.fc2 = nn.Linear(d_ff, d_model)
95 | self.dropout = nn.Dropout(p=dropout)
96 | self.dropout_2 = nn.Dropout(p=dropout)
97 | self.layer_norm = nn.LayerNorm(d_model)
98 |
99 | def forward(self, input):
100 | if self.identity_map_reordering:
101 | out = self.layer_norm(input)
102 | out = self.fc2(self.dropout_2(F.relu(self.fc1(out))))
103 | out = input + self.dropout(torch.relu(out))
104 | else:
105 | out = self.fc2(self.dropout_2(F.relu(self.fc1(input))))
106 | # out = self.fc2(self.dropout_2(F.gelu(self.fc1(input))))
107 | out = self.dropout(out)
108 | # out = input + out
109 | out = self.layer_norm(input + out)
110 | return out
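Quick usage check for `sinusoid_encoding_table` above (assumes the repository root is on `PYTHONPATH`): the row at `padding_idx` is zeroed, so the table can be loaded directly into a frozen `nn.Embedding`, as done in `MeshedDecoder`.

```
import torch
from models.transformer.utils import sinusoid_encoding_table

table = sinusoid_encoding_table(max_len=50, d_model=512, padding_idx=0)
print(table.shape)              # torch.Size([50, 512])
print(table[0].abs().sum())     # tensor(0.) -- the padding row
pos_emb = torch.nn.Embedding.from_pretrained(table, freeze=True)
```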
--------------------------------------------------------------------------------
/models/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 |
6 | def position_embedding(input, d_model):
7 | input = input.view(-1, 1)
8 | dim = torch.arange(d_model // 2, dtype=torch.float32, device=input.device).view(1, -1)
9 | sin = torch.sin(input / 10000 ** (2 * dim / d_model))
10 | cos = torch.cos(input / 10000 ** (2 * dim / d_model))
11 |
12 | out = torch.zeros((input.shape[0], d_model), device=input.device)
13 | out[:, ::2] = sin
14 | out[:, 1::2] = cos
15 | return out
16 |
17 |
18 | def sinusoid_encoding_table(max_len, d_model, padding_idx=None):
19 | pos = torch.arange(max_len, dtype=torch.float32)
20 | out = position_embedding(pos, d_model)
21 |
22 | if padding_idx is not None:
23 | out[padding_idx] = 0
24 | return out
25 |
26 |
27 | class PositionWiseFeedForward(nn.Module):
28 | '''
29 | Position-wise feed forward layer
30 | '''
31 |
32 | def __init__(self, d_model=512, d_ff=2048, dropout=.1, identity_map_reordering=False):
33 | super(PositionWiseFeedForward, self).__init__()
34 | self.identity_map_reordering = identity_map_reordering
35 | self.fc1 = nn.Linear(d_model, d_ff)
36 | self.fc2 = nn.Linear(d_ff, d_model)
37 | self.dropout = nn.Dropout(p=dropout)
38 | self.dropout_2 = nn.Dropout(p=dropout)
39 | self.layer_norm = nn.LayerNorm(d_model)
40 |
41 | def forward(self, input):
42 | if self.identity_map_reordering:
43 | out = self.layer_norm(input)
44 | out = self.fc2(self.dropout_2(F.relu(self.fc1(out))))
45 | out = input + self.dropout(torch.relu(out))
46 | else:
47 | out = self.fc2(self.dropout_2(F.relu(self.fc1(input))))
48 | out = self.dropout(out)
49 | out = self.layer_norm(out+input)
50 | return out
--------------------------------------------------------------------------------
/models/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import download_from_url
2 | from .typing import *
3 |
4 | def get_batch_size(x: TensorOrSequence) -> int:
5 | if isinstance(x, torch.Tensor):
6 | b_s = x.size(0)
7 | else:
8 | b_s = x[0].size(0)
9 | return b_s
10 |
11 |
12 | def get_device(x: TensorOrSequence) -> int:
13 | if isinstance(x, torch.Tensor):
14 | b_s = x.device
15 | else:
16 | b_s = x[0].device
17 | return b_s
18 |
--------------------------------------------------------------------------------
/models/utils/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixmath/CVR/af2ea6859230c1a5f56868a2242c06c2fe02f05a/models/utils/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/models/utils/__pycache__/typing.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixmath/CVR/af2ea6859230c1a5f56868a2242c06c2fe02f05a/models/utils/__pycache__/typing.cpython-36.pyc
--------------------------------------------------------------------------------
/models/utils/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixmath/CVR/af2ea6859230c1a5f56868a2242c06c2fe02f05a/models/utils/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/models/utils/typing.py:
--------------------------------------------------------------------------------
1 | from typing import Union, Sequence, Tuple
2 | import torch
3 |
4 | TensorOrSequence = Union[Sequence[torch.Tensor], torch.Tensor]
5 | TensorOrNone = Union[torch.Tensor, None]
6 |
--------------------------------------------------------------------------------
/models/utils/utils.py:
--------------------------------------------------------------------------------
1 | import requests
2 |
3 | def download_from_url(url, path):
4 | """Download file, with logic (from tensor2tensor) for Google Drive"""
5 | if 'drive.google.com' not in url:
6 | print('Downloading %s; may take a few minutes' % url)
7 | r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'})
8 | with open(path, "wb") as file:
9 | file.write(r.content)
10 | return
11 | print('Downloading from Google Drive; may take a few minutes')
12 | confirm_token = None
13 | session = requests.Session()
14 | response = session.get(url, stream=True)
15 | for k, v in response.cookies.items():
16 | if k.startswith("download_warning"):
17 | confirm_token = v
18 |
19 | if confirm_token:
20 | url = url + "&confirm=" + confirm_token
21 | response = session.get(url, stream=True)
22 |
23 | chunk_size = 16 * 1024
24 | with open(path, "wb") as f:
25 | for chunk in response.iter_content(chunk_size):
26 | if chunk:
27 | f.write(chunk)
28 |
--------------------------------------------------------------------------------
/otk/cross_attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from einops import rearrange, repeat
4 | from torch import nn
5 |
6 | MIN_NUM_PATCHES = 16
7 |
8 | class Residual(nn.Module):
9 | def __init__(self, fn):
10 | super().__init__()
11 | self.fn = fn
12 | def forward(self, x, **kwargs):
13 | return self.fn(x, **kwargs) + x
14 |
15 | class PreNorm(nn.Module):
16 | def __init__(self, dim, fn):
17 | super().__init__()
18 | self.norm = nn.LayerNorm(dim)
19 | self.fn = fn
20 | def forward(self, x, **kwargs):
21 | return self.fn(self.norm(x), **kwargs)
22 |
23 | class FeedForward(nn.Module):
24 | def __init__(self, dim, hidden_dim, dropout = 0.):
25 | super().__init__()
26 | self.net = nn.Sequential(
27 | nn.Linear(dim, hidden_dim),
28 | nn.GELU(),
29 | nn.Dropout(dropout),
30 | nn.Linear(hidden_dim, dim),
31 | nn.Dropout(dropout)
32 | )
33 | def forward(self, x):
34 | return self.net(x)
35 | class Attention_entropy(nn.Module):
36 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
37 | super().__init__()
38 | inner_dim = dim_head * heads
39 | self.heads = heads
40 | self.scale = dim_head ** -0.5
41 |
42 | self.to_qk = nn.Linear(dim, inner_dim * 2, bias = False)
43 | self.to_v = nn.Linear(dim,inner_dim,bias=False)
44 | self.to_out = nn.Sequential(
45 | nn.Linear(inner_dim, dim),
46 | nn.Dropout(dropout)
47 | )
48 |
49 | def forward(self, x, mask = None,mm=None):
50 | b, n, _, h = *x.shape, self.heads
51 | qk = self.to_qk(x)
52 | v = self.to_v(x)
53 | qkv = torch.cat((qk,v),-1)
54 | qkv = qkv.chunk(3, dim = -1)
55 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
56 |
57 | dots = torch.einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
58 | mask_value = -torch.finfo(dots.dtype).max
59 |
60 | if mask is not None:
61 | # mask = F.pad(mask.flatten(1), (1, 0), value = True)
62 | assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
63 | mask = rearrange(mask, 'b i -> b () i ()') * rearrange(mask, 'b j -> b () () j')
64 | dots.masked_fill_(~mask, mask_value)
65 | del mask
66 | mm = mm.unsqueeze(1).repeat(b,h,1,1)
67 | dots = dots + mm
68 | attn = dots.softmax(dim=-1)
69 | out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
70 | out = rearrange(out, 'b h n d -> b n (h d)')
71 | out = self.to_out(out)
72 | return out
73 | class Attention(nn.Module):
74 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
75 | super().__init__()
76 | inner_dim = dim_head * heads
77 | self.heads = heads
78 | self.scale = dim_head ** -0.5
79 |
80 | self.to_qk = nn.Linear(dim, inner_dim * 2, bias = False)
81 | self.to_v = nn.Linear(dim,inner_dim,bias=False)
82 | self.to_out = nn.Sequential(
83 | nn.Linear(inner_dim, dim),
84 | nn.Dropout(dropout)
85 | )
86 |
87 | def forward(self, x, mask = None,mm=None):
88 | b, n, _, h = *x.shape, self.heads
89 | qk = self.to_qk(x+mm)
90 | v = self.to_v(x)
91 | qkv = torch.cat((qk,v),-1)
92 | qkv = qkv.chunk(3, dim = -1)
93 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
94 |
95 | dots = torch.einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
96 | mask_value = -torch.finfo(dots.dtype).max
97 |
98 | if mask is not None:
99 | # mask = F.pad(mask.flatten(1), (1, 0), value = True)
100 | assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
101 | mask = rearrange(mask, 'b i -> b () i ()') * rearrange(mask, 'b j -> b () () j')
102 | dots.masked_fill_(~mask, mask_value)
103 | del mask
104 |
105 | attn = dots.softmax(dim=-1)
106 |
107 | out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
108 | out = rearrange(out, 'b h n d -> b n (h d)')
109 | out = self.to_out(out)
110 | return out
111 |
112 | class Transformer(nn.Module):
113 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout):
114 | super().__init__()
115 | self.layers = nn.ModuleList([])
116 | for _ in range(depth):
117 | self.layers.append(nn.ModuleList([
118 | Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
119 | Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
120 | ]))
121 | def forward(self, x, mask = None,mm=None):
122 | for attn, ff in self.layers:
123 | x = attn(x, mask = mask,mm=mm)
124 | x = ff(x)
125 | return x
126 | class Transformer_noff(nn.Module):
127 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout):
128 | super().__init__()
129 | self.layers = nn.ModuleList([])
130 | for _ in range(depth):
131 | self.layers.append(nn.ModuleList([
132 | Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
133 | Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
134 | ]))
135 | def forward(self, x, mask = None,mm=None):
136 | for attn, ff in self.layers:
137 | x = attn(x, mask = mask,mm=mm)
138 | # x = ff(x)
139 | return x
140 |
141 | class ViT(nn.Module):
142 | def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
143 | super().__init__()
144 | assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
145 | num_patches = (image_size // patch_size) ** 2
146 | patch_dim = channels * patch_size ** 2
147 | assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective (at least 16). Try decreasing your patch size'
148 | assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
149 |
150 | self.patch_size = patch_size
151 |
152 | self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
153 | self.patch_to_embedding = nn.Linear(patch_dim, dim)
154 | self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
155 | self.dropout = nn.Dropout(emb_dropout)
156 |
157 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
158 |
159 | self.pool = pool
160 | self.to_latent = nn.Identity()
161 |
162 | self.mlp_head = nn.Sequential(
163 | nn.LayerNorm(dim),
164 | nn.Linear(dim, num_classes)
165 | )
166 |
167 | def forward(self, img, mask = None):
168 | p = self.patch_size
169 |
170 | x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
171 | x = self.patch_to_embedding(x)
172 | b, n, _ = x.shape
173 |
174 | cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
175 | x = torch.cat((cls_tokens, x), dim=1)
176 | x += self.pos_embedding[:, :(n + 1)]
177 | x = self.dropout(x)
178 |
179 | x = self.transformer(x, mask)
180 |
181 | x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
182 |
183 | x = self.to_latent(x)
184 | return self.mlp_head(x)
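A smoke-test sketch for the `Attention` block above (assumes `einops` is installed and the repository root is on `PYTHONPATH`). This variant adds the `mm` tensor to the input before projecting queries and keys, so `mm` must be supplied; the zero tensor below is just a placeholder.

```
import torch
from otk.cross_attention import Attention

attn = Attention(dim=64, heads=4, dim_head=16)
x = torch.randn(2, 10, 64)            # (batch, tokens, dim)
mm = torch.zeros_like(x)              # positional term added to queries/keys; placeholder here
out = attn(x, mask=None, mm=mm)
print(out.shape)                      # torch.Size([2, 10, 64])
```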
--------------------------------------------------------------------------------
/otk/data_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import h5py
3 | import scipy.io as sio
4 | import torch
5 | from torch.utils.data import Dataset
6 |
7 |
8 | def _load_mat_file(filepath, sequence_key, targets_key=None):
9 | """
10 | Loads data from a `*.mat` file or a `*.h5` file.
11 | Parameters
12 | ----------
13 | filepath : str
14 | The path to the file to load the data from.
15 | sequence_key : str
16 | The key for the sequences data matrix.
17 | targets_key : str, optional
18 | Default is None. The key for the targets data matrix.
19 | Returns
20 | -------
21 | (sequences, targets, h5py_filehandle) : \
22 | tuple(array-like, array-like, h5py.File)
23 | If the matrix files can be loaded with `scipy.io`,
24 | the tuple will only be (sequences, targets). Otherwise,
25 | the 2 matrices and the h5py file handle are returned.
26 | """
27 | try: # see if we can load the file using scipy first
28 | mat = sio.loadmat(filepath)
29 | targets = None
30 | if targets_key:
31 | targets = mat[targets_key]
32 | return (mat[sequence_key], targets)
33 | except (NotImplementedError, ValueError):
34 | mat = h5py.File(filepath, 'r')
35 | sequences = mat[sequence_key]
36 | targets = None
37 | if targets_key:
38 | targets = mat[targets_key]
39 | return (sequences, targets, mat)
40 |
41 |
42 | class MatDataset(Dataset):
43 | def __init__(self, filepath, split='train'):
44 | super().__init__()
45 | filepath = filepath + "/{}.mat".format(split)
46 | sequence_key = "{}xdata".format(split)
47 | targets_key = "{}data".format(split)
48 | out = _load_mat_file(filepath, sequence_key, targets_key)
49 | self.data_tensor = out[0]
50 | self.target_tensor = out[1]
51 | self.split = split
52 |
53 | def __getitem__(self, index):
54 | if self.split == "train":
55 | data_tensor = self.data_tensor[:, :, index]
56 | data_tensor = data_tensor.transpose().astype('float32')
57 | target_tensor = self.target_tensor[:, index].astype('float32')
58 | else:
59 | data_tensor = self.data_tensor[index].astype('float32')
60 | target_tensor = self.target_tensor[index].astype('float32')
61 | data_tensor = torch.from_numpy(data_tensor)
62 | target_tensor = torch.from_numpy(target_tensor)
63 | return data_tensor, target_tensor
64 |
65 | def __len__(self):
66 | if self.split == 'train':
67 | return self.target_tensor.shape[1]
68 | return self.target_tensor.shape[0]
69 |
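Hypothetical usage of `MatDataset` with a `DataLoader`; the data directory below is a placeholder, and the `.mat` files (with `trainxdata`/`traindata` keys, per the code above) are not shipped with the repository.

```
from torch.utils.data import DataLoader
from otk.data_utils import MatDataset

train_set = MatDataset("/path/to/data", split="train")      # expects /path/to/data/train.mat
loader = DataLoader(train_set, batch_size=64, shuffle=True)
sequences, targets = next(iter(loader))                     # float32 tensors; shapes depend on the .mat file
```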
--------------------------------------------------------------------------------
/otk/layers.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 | import math
4 | from torch import nn
5 | import torch.optim as optim
6 | from .utils import spherical_kmeans, normalize
7 | from .sinkhorn import wasserstein_kmeans, multihead_attn
8 | import numpy as np
9 |
10 |
11 | class OTKernel(nn.Module):
12 | def __init__(self, in_dim, out_size, heads=1, eps=0.1, max_iter=100,
13 | log_domain=False, position_encoding=None, position_sigma=0.1):
14 | super().__init__()
15 | self.in_dim = in_dim
16 | self.out_size = out_size
17 | self.heads = heads
18 | self.eps = eps
19 | self.max_iter = max_iter
20 |
21 | self.weight = nn.Parameter(
22 | torch.Tensor(heads, out_size, in_dim))
23 |
24 | self.log_domain = log_domain
25 | self.position_encoding = position_encoding
26 | self.position_sigma = position_sigma
27 |
28 | self.reset_parameter()
29 | nn.init.xavier_normal_(self.weight)
30 | # nn.init.kaiming_normal_(self.weight)
31 | def reset_parameter(self):
32 | stdv = 1. / math.sqrt(self.out_size)
33 | for w in self.parameters():
34 | w.data.uniform_(-stdv, stdv)
35 | # def reset_parameter(self):
36 | # for w in self.parameters():
37 | def get_position_filter(self, input, out_size):
38 | if input.ndim == 4:
39 | in_size1 = input.shape[1]
40 | in_size2 = input.shape[2]
41 | out_size = int(math.sqrt(out_size))
42 | if self.position_encoding is None:
43 | return self.position_encoding
44 | elif self.position_encoding == "gaussian":
45 | sigma = self.position_sigma
46 | a1 = torch.arange(1., in_size1 + 1.).view(-1, 1) / in_size1
47 | a2 = torch.arange(1., in_size2 + 1.).view(-1, 1) / in_size2
48 | b = torch.arange(1., out_size + 1.).view(1, -1) / out_size
49 | position_filter1 = torch.exp(-((a1 - b) / sigma) ** 2)
50 | position_filter2 = torch.exp(-((a2 - b) / sigma) ** 2)
51 | position_filter = position_filter1.view(
52 | in_size1, 1, out_size, 1) * position_filter2.view(
53 | 1, in_size2, 1, out_size)
54 | if self.weight.is_cuda:
55 | position_filter = position_filter.cuda()
56 | return position_filter.reshape(1, 1, in_size1 * in_size2, out_size * out_size)
57 | in_size = input.shape[1]
58 | if self.position_encoding is None:
59 | return self.position_encoding
60 | elif self.position_encoding == "gaussian":
61 | # sigma = 1. / out_size
62 | sigma = self.position_sigma
63 | a = torch.arange(0., in_size).view(-1, 1) / in_size
64 | b = torch.arange(0., out_size).view(1, -1) / out_size
65 | position_filter = torch.exp(-((a - b) / sigma) ** 2)
66 | elif self.position_encoding == "hard":
67 | # sigma = 1. / out_size
68 | sigma = self.position_sigma
69 | a = torch.arange(0., in_size).view(-1, 1) / in_size
70 | b = torch.arange(0., out_size).view(1, -1) / out_size
71 | position_filter = torch.abs(a - b) < sigma
72 | position_filter = position_filter.float()
73 | else:
74 |             raise ValueError("Unrecognized position encoding")
75 | if self.weight.is_cuda:
76 | position_filter = position_filter.cuda()
77 | position_filter = position_filter.view(1, 1, in_size, out_size)
78 | return position_filter
79 |
80 | def get_attn(self, input, mask=None, position_filter=None):
81 | """Compute the attention weight using Sinkhorn OT
82 | input: batch_size x in_size x in_dim
83 | mask: batch_size x in_size
84 | self.weight: heads x out_size x in_dim
85 | output: batch_size x (out_size x heads) x in_size
86 | """
87 | return multihead_attn(
88 | input, self.weight, mask=mask, eps=self.eps,
89 | max_iter=self.max_iter, log_domain=self.log_domain,
90 | position_filter=position_filter)
91 |
92 | def forward(self, input, mask=None):
93 | """
94 | input: batch_size x in_size x in_dim
95 | output: batch_size x out_size x (heads x in_dim)
96 | """
97 | batch_size = input.shape[0]
98 | position_filter = self.get_position_filter(input, self.out_size)
99 | in_ndim = input.ndim
100 | if in_ndim == 4:
101 | input = input.view(batch_size, -1, self.in_dim)
102 | attn_weight = self.get_attn(input, mask, position_filter)
103 | # attn_weight: batch_size x out_size x heads x in_size
104 | # aa = attn_weight.detach().cpu().numpy()
105 | output = torch.bmm(
106 | attn_weight.view(batch_size, self.out_size * self.heads, -1), input)
107 | if in_ndim == 4:
108 | out_size = int(math.sqrt(self.out_size))
109 | output = output.reshape(batch_size, out_size, out_size, -1)
110 | else:
111 | output = output.reshape(batch_size, self.out_size, -1)
112 | return output
113 |
114 | def unsup_train(self, input, wb=False, inplace=True, use_cuda=False):
115 | """K-meeans for learning parameters
116 | input: n_samples x in_size x in_dim
117 | weight: heads x out_size x in_dim
118 | """
119 | input_normalized = normalize(input, inplace=inplace)
120 | # input_normalized = input
121 | block_size = int(1e9) // (input.shape[1] * input.shape[2] * 4)
122 | print("Starting Wasserstein K-means")
123 | weight = wasserstein_kmeans(
124 | input_normalized, self.heads, self.out_size, eps=self.eps,
125 | block_size=block_size, wb=wb, log_domain=self.log_domain, use_cuda=use_cuda)
126 | self.weight.data.copy_(weight)
127 |
128 | def random_sample(self, input):
129 | idx = torch.randint(0, input.shape[0], (1,))
130 | self.weight.data.copy_(input[idx].view_as(self.weight))
131 |
132 | class Linear(nn.Linear):
133 | def forward(self, input):
134 | bias = self.bias
135 | if bias is not None and hasattr(self, 'scale_bias') and self.scale_bias is not None:
136 | bias = self.scale_bias * bias
137 | out = torch.nn.functional.linear(input, self.weight, bias)
138 | return out
139 |
140 | def fit(self, Xtr, ytr, criterion, reg=0.0, epochs=100, optimizer=None, use_cuda=False):
141 | if optimizer is None:
142 | optimizer = optim.LBFGS(self.parameters(), lr=1.0, history_size=10)
143 | if self.bias is not None:
144 | scale_bias = (Xtr ** 2).mean(-1).sqrt().mean().item()
145 | self.scale_bias = scale_bias
146 | self.train()
147 | if use_cuda:
148 | self.cuda()
149 | Xtr = Xtr.cuda()
150 | ytr = ytr.cuda()
151 | def closure():
152 | optimizer.zero_grad()
153 | output = self(Xtr)
154 | loss = criterion(output, ytr)
155 | loss = loss + 0.5 * reg * self.weight.pow(2).sum()
156 | loss.backward()
157 | return loss
158 |
159 | for epoch in range(epochs):
160 | optimizer.step(closure)
161 | if self.bias is not None:
162 | self.bias.data.mul_(self.scale_bias)
163 | self.scale_bias = None
164 |
165 | def score(self, X, y):
166 | self.eval()
167 | with torch.no_grad():
168 | scores = self(X)
169 | scores = scores.argmax(-1)
170 | scores = scores.cpu()
171 | return torch.mean((scores == y).float()).item()
172 |
--------------------------------------------------------------------------------
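For reference, a minimal usage sketch of the `OTKernel` layer defined in `/otk/layers.py` above; the batch and feature sizes are arbitrary toy values, and the expected output shape follows the `forward` docstring (`batch_size x out_size x (heads x in_dim)`).

```python
import torch
from otk.layers import OTKernel

# Toy input: 8 sets of 50 local features of dimension 128.
x = torch.randn(8, 50, 128)

# Aggregate each set onto 10 references with 2 heads of Sinkhorn OT attention.
otk = OTKernel(in_dim=128, out_size=10, heads=2, eps=0.1, max_iter=30)

out = otk(x)
print(out.shape)  # torch.Size([8, 10, 256]) == batch x out_size x (heads * in_dim)
```
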
/otk/models.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 | from torch import nn
4 | from .layers import OTKernel, Linear
5 | from ckn.layers import BioEmbedding
6 | from ckn.models import CKNSequential
7 |
8 |
9 | class SeqAttention(nn.Module):
10 | def __init__(self, in_channels, nclass, hidden_sizes, filter_sizes,
11 | subsamplings, kernel_args=None, eps=0.1, heads=1,
12 | out_size=1, max_iter=50, alpha=0., fit_bias=True,
13 | mask_zeros=True):
14 | super().__init__()
15 | self.embed_layer = BioEmbedding(
16 | in_channels, False, mask_zeros=True, no_embed=True)
17 | self.ckn_model = CKNSequential(
18 | in_channels, hidden_sizes, filter_sizes,
19 | subsamplings, kernel_args_list=kernel_args)
20 | self.attention = OTKernel(hidden_sizes[-1], out_size, heads=heads,
21 | eps=eps, max_iter=max_iter)
22 | self.out_features = out_size * heads * hidden_sizes[-1]
23 | self.nclass = nclass
24 |
25 | self.classifier = Linear(self.out_features, nclass, bias=fit_bias)
26 | self.alpha = alpha
27 | self.mask_zeros = mask_zeros
28 |
29 | def feature_parameters(self):
30 | import itertools
31 | return itertools.chain(self.ckn_model.parameters(),
32 | self.attention.parameters())
33 |
34 | def normalize_(self):
35 | self.ckn_model.normalize_()
36 |
37 | def ckn_representation_at(self, input, n=0):
38 | output = self.embed_layer(input)
39 | mask = self.embed_layer.compute_mask(input)
40 | output = self.ckn_model.representation(output, n)
41 | mask = self.ckn_model.compute_mask(mask, n)
42 | return output, mask
43 |
44 | def ckn_representation(self, input):
45 | output = self.embed_layer(input)
46 | output = self.ckn_model(output).permute(0, 2, 1).contiguous()
47 | return output
48 |
49 | def representation(self, input):
50 | output = self.embed_layer(input)
51 | mask = self.embed_layer.compute_mask(input)
52 | output = self.ckn_model(output).permute(0, 2, 1).contiguous()
53 | mask = self.ckn_model.compute_mask(mask)
54 | if not self.mask_zeros:
55 | mask = None
56 | output = self.attention(output, mask).reshape(output.shape[0], -1)
57 | return output
58 |
59 | def forward(self, input):
60 | output = self.representation(input)
61 | return self.classifier(output)
62 |
63 | def predict(self, data_loader, only_repr=False, use_cuda=False):
64 | n_samples = len(data_loader.dataset)
65 | target_output = torch.LongTensor(n_samples)
66 | batch_start = 0
67 | for i, (data, target) in enumerate(data_loader):
68 | batch_size = data.shape[0]
69 | if use_cuda:
70 | data = data.cuda()
71 | with torch.no_grad():
72 | if only_repr:
73 | batch_out = self.representation(data).data.cpu()
74 | else:
75 | batch_out = self(data).data.cpu()
76 | if i == 0:
77 | output = batch_out.new_empty([n_samples] +
78 | list(batch_out.shape[1:]))
79 | output[batch_start:batch_start + batch_size] = batch_out
80 | target_output[batch_start:batch_start + batch_size] = target
81 | batch_start += batch_size
82 | return output, target_output
83 |
84 | def train_classifier(self, data_loader, criterion=None, epochs=100,
85 | optimizer=None, use_cuda=False):
86 | encoded_train, encoded_target = self.predict(
87 | data_loader, only_repr=True, use_cuda=use_cuda)
88 | self.classifier.fit(encoded_train, encoded_target, criterion,
89 | reg=self.alpha, epochs=epochs, optimizer=optimizer,
90 | use_cuda=use_cuda)
91 |
92 | def unsup_train(self, data_loader, n_sampling_patches=300000,
93 | n_samples=5000, wb=False, use_cuda=False):
94 | self.eval()
95 | if use_cuda:
96 | self.cuda()
97 |
98 | for i, ckn_layer in enumerate(self.ckn_model):
99 | print("Training ckn layer {}".format(i))
100 | n_patches = 0
101 | try:
102 | n_patches_per_batch = (
103 | n_sampling_patches + len(data_loader) - 1
104 | ) // len(data_loader)
105 |             except Exception:
106 | n_patches_per_batch = 1000
107 | patches = torch.Tensor(n_sampling_patches, ckn_layer.patch_dim)
108 | if use_cuda:
109 | patches = patches.cuda()
110 |
111 | for data, _ in data_loader:
112 | if n_patches >= n_sampling_patches:
113 | continue
114 | if use_cuda:
115 | data = data.cuda()
116 | with torch.no_grad():
117 | data, mask = self.ckn_representation_at(data, i)
118 | data_patches = ckn_layer.sample_patches(
119 | data, mask, n_patches_per_batch)
120 | size = data_patches.size(0)
121 | if n_patches + size > n_sampling_patches:
122 | size = n_sampling_patches - n_patches
123 | data_patches = data_patches[:size]
124 | patches[n_patches: n_patches + size] = data_patches
125 | n_patches += size
126 |
127 | print("total number of patches: {}".format(n_patches))
128 | patches = patches[:n_patches]
129 | ckn_layer.unsup_train(patches, init=None)
130 |
131 | n_samples = min(n_samples, len(data_loader.dataset))
132 | cur_samples = 0
133 | print("Training attention layer")
134 | for i, (data, _) in enumerate(data_loader):
135 | if cur_samples >= n_samples:
136 | continue
137 | if use_cuda:
138 | data = data.cuda()
139 | with torch.no_grad():
140 | data = self.ckn_representation(data)
141 |
142 | if i == 0:
143 | patches = torch.empty([n_samples]+list(data.shape[1:]))
144 |
145 | size = data.shape[0]
146 | if cur_samples + size > n_samples:
147 | size = n_samples - cur_samples
148 | data = data[:size]
149 | patches[cur_samples: cur_samples + size] = data
150 | cur_samples += size
151 | print(patches.shape)
152 | self.attention.unsup_train(patches, wb=wb, use_cuda=use_cuda)
153 |
--------------------------------------------------------------------------------
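`SeqAttention.predict` above accumulates batch outputs into preallocated buffers. The same pattern is sketched below as a standalone, hypothetical helper (`batched_predict` is not part of the repo); it assumes any `nn.Module` and a standard `DataLoader` yielding `(data, target)` pairs.

```python
import torch

def batched_predict(model, data_loader, use_cuda=False):
    """Mirror of the buffer-filling loop in SeqAttention.predict:
    preallocate the target buffer, lazily allocate the output buffer
    from the first batch, then fill both slice by slice."""
    n_samples = len(data_loader.dataset)
    targets = torch.LongTensor(n_samples)
    output, start = None, 0
    model.eval()
    for data, target in data_loader:
        if use_cuda:
            data = data.cuda()
        with torch.no_grad():
            batch_out = model(data).cpu()
        if output is None:
            output = batch_out.new_empty([n_samples] + list(batch_out.shape[1:]))
        batch_size = data.shape[0]
        output[start:start + batch_size] = batch_out
        targets[start:start + batch_size] = target
        start += batch_size
    return output, targets
```
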
/otk/models_deepsea.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 | from torch import nn
4 | from .layers import OTKernel
5 |
6 |
7 | class OTLayer(nn.Module):
8 | def __init__(self, in_dim, out_size, heads=1, eps=0.1, max_iter=10,
9 | position_encoding=None, position_sigma=0.1, out_dim=None,
10 | dropout=0.4):
11 | super().__init__()
12 | self.out_size = out_size
13 | self.heads = heads
14 | if out_dim is None:
15 | out_dim = in_dim
16 |
17 | self.layer = nn.Sequential(
18 | OTKernel(in_dim, out_size, heads, eps, max_iter, log_domain=True,
19 | position_encoding=position_encoding, position_sigma=position_sigma),
20 | nn.Linear(heads * in_dim, out_dim),
21 | nn.ReLU(inplace=True),
22 | nn.Dropout(dropout)
23 | )
24 | nn.init.xavier_uniform_(self.layer[0].weight)
25 | nn.init.xavier_uniform_(self.layer[1].weight)
26 |
27 | def forward(self, input):
28 | output = self.layer(input)
29 | return output
30 |
31 | class SeqAttention(nn.Module):
32 | def __init__(self, nclass, hidden_size, filter_size,
33 | n_attn_layers, eps=0.1, heads=1,
34 | out_size=1, max_iter=10, hidden_layer=False,
35 | position_encoding=None, position_sigma=0.1):
36 | super().__init__()
37 | self.embed = nn.Sequential(
38 | nn.Conv1d(4, hidden_size, kernel_size=filter_size),
39 | nn.ReLU(inplace=True),
40 | )
41 |
42 | attn_layers = [OTLayer(
43 | hidden_size, out_size, heads, eps, max_iter,
44 | position_encoding, position_sigma=position_sigma)] + [OTLayer(
45 | hidden_size, out_size, heads, eps, max_iter, position_encoding, position_sigma=position_sigma
46 | ) for _ in range(n_attn_layers - 1)]
47 | self.attn_layers = nn.Sequential(*attn_layers)
48 |
49 | self.out_features = out_size * hidden_size
50 | self.nclass = nclass
51 |
52 | if hidden_layer:
53 | self.classifier = nn.Sequential(
54 | nn.Linear(self.out_features, nclass),
55 | nn.ReLU(inplace=True),
56 | nn.Linear(nclass, nclass))
57 | else:
58 | self.classifier = nn.Linear(self.out_features, nclass)
59 |
60 | def representation(self, input):
61 | output = self.embed(input).transpose(1, 2).contiguous()
62 | output = self.attn_layers(output)
63 | output = output.reshape(output.shape[0], -1)
64 | return output
65 |
66 | def forward(self, input):
67 | output = self.representation(input)
68 | return self.classifier(output)
69 |
70 | def predict(self, data_loader, only_repr=False, use_cuda=False):
71 | n_samples = len(data_loader.dataset)
72 | target_output = torch.LongTensor(n_samples)
73 | batch_start = 0
74 | for i, (data, target) in enumerate(data_loader):
75 | batch_size = data.shape[0]
76 | if use_cuda:
77 | data = data.cuda()
78 | with torch.no_grad():
79 | if only_repr:
80 | batch_out = self.representation(data).data.cpu()
81 | else:
82 | batch_out = self(data).data.cpu()
83 | if i == 0:
84 | output = batch_out.new_empty([n_samples] + list(batch_out.shape[1:]))
85 | output[batch_start:batch_start + batch_size] = batch_out
86 | target_output[batch_start:batch_start + batch_size] = target
87 | batch_start += batch_size
88 | return output, target_output
89 |
--------------------------------------------------------------------------------
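A minimal sketch of running the DeepSEA-style `SeqAttention` above on a toy batch. The 919 classes and 1000-bp inputs follow the usual DeepSEA setup, but every size here is only an illustrative assumption.

```python
import torch
from otk.models_deepsea import SeqAttention

model = SeqAttention(
    nclass=919, hidden_size=64, filter_size=16, n_attn_layers=2,
    heads=1, out_size=32, max_iter=10,
    position_encoding="gaussian", position_sigma=0.1)

x = torch.randn(2, 4, 1000)   # 2 one-hot-style DNA sequences, 4 channels, length 1000
logits = model(x)
print(logits.shape)           # torch.Size([2, 919])
```
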
/otk/sinkhorn.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import math
3 | import torch
4 | from .utils import spherical_kmeans, normalize
5 |
6 | def sinkhorn(dot, mask=None, eps=1e-03, return_kernel=False, max_iter=100):
7 | """
8 | dot: n x in_size x out_size
9 | mask: n x in_size
10 | output: n x in_size x out_size
11 | """
12 | n, in_size, out_size = dot.shape
13 | if return_kernel:
14 | K = torch.exp(dot / eps)
15 | else:
16 | K = dot
17 | # K: n x in_size x out_size
18 | u = K.new_ones((n, in_size))
19 | v = K.new_ones((n, out_size))
20 | a = float(out_size / in_size)
21 | if mask is not None:
22 | mask = mask.float()
23 | a = out_size / mask.sum(1, keepdim=True)
24 | for _ in range(max_iter):
25 | u = a / torch.bmm(K, v.view(n, out_size, 1)).view(n, in_size)
26 | if mask is not None:
27 | u = u * mask
28 | v = 1. / torch.bmm(u.view(n, 1, in_size), K).view(n, out_size)
29 | K = u.view(n, in_size, 1) * (K * v.view(n, 1, out_size))
30 | if return_kernel:
31 | K = K / out_size
32 | return (K * dot).sum(dim=[1, 2])
33 | return K
34 |
35 | def log_sinkhorn(K, mask=None, eps=1.0, return_kernel=False, max_iter=100):
36 | """
37 |     K: n x in_size x out_size (raw scores, processed in the log domain)
38 | mask: n x in_size
39 | output: n x in_size x out_size
40 | """
41 | batch_size, in_size, out_size = K.shape
42 | def min_eps(u, v, dim):
43 | Z = (K + u.view(batch_size, in_size, 1) + v.view(batch_size, 1, out_size)) / eps
44 | return -torch.logsumexp(Z, dim=dim)
45 | # K: batch_size x in_size x out_size
46 | u = K.new_zeros((batch_size, in_size))
47 | v = K.new_zeros((batch_size, out_size))
48 | a = torch.ones_like(u).fill_(out_size / in_size)
49 | if mask is not None:
50 | a = out_size / mask.float().sum(1, keepdim=True)
51 | a = torch.log(a)
52 | for _ in range(max_iter):
53 | u = eps * (a + min_eps(u, v, dim=-1)) + u
54 | if mask is not None:
55 | u = u.masked_fill(~mask, -1e8)
56 | v = eps * min_eps(u, v, dim=1) + v
57 | if return_kernel:
58 | output = torch.exp(
59 | (K + u.view(batch_size, in_size, 1) + v.view(batch_size, 1, out_size)) / eps)
60 | output = output / out_size
61 | return (output * K).sum(dim=[1, 2])
62 | K = torch.exp(
63 | (K + u.view(batch_size, in_size, 1) + v.view(batch_size, 1, out_size)) / eps)
64 | return K
65 |
66 | def multihead_attn(input, weight, mask=None, eps=1.0, return_kernel=False,
67 | max_iter=100, log_domain=False, position_filter=None):
68 | """Comput the attention weight using Sinkhorn OT
69 | input: n x in_size x in_dim
70 | mask: n x in_size
71 | weight: m x out_size x in_dim (m: number of heads/ref)
72 | output: n x out_size x m x in_size
73 | """
74 | n, in_size, in_dim = input.shape
75 | m, out_size = weight.shape[:-1]
76 | # K = torch.tensordot(torch.nn.functional.normalize(input,dim=-1), torch.nn.functional.normalize(weight,dim=-1), dims=[[-1], [-1]])
77 | K = torch.tensordot(normalize(input), normalize(weight), dims=[[-1], [-1]])
78 | # K = torch.tensordot(input, weight, dims=[[-1], [-1]])
79 | K = K.permute(0, 2, 1, 3)
80 | if position_filter is not None:
81 | K = position_filter * K
82 | # K: n x m x in_size x out_size
83 | K = K.reshape(-1, in_size, out_size)
84 | # K = K-K.min()
85 | # K = K/K.max()
86 | # K: nm x in_size x out_size
87 | if mask is not None:
88 | mask = mask.repeat_interleave(m, dim=0)
89 | if log_domain:
90 | K = log_sinkhorn(K, mask, eps, return_kernel=return_kernel, max_iter=max_iter)
91 | else:
92 | if not return_kernel:
93 | K = torch.exp(K / eps)
94 | K = sinkhorn(K, mask, eps, return_kernel=return_kernel, max_iter=max_iter)
95 | # K: nm x in_size x out_size
96 | if return_kernel:
97 | return K.reshape(n, m)
98 | K = K.reshape(n, m, in_size, out_size)
99 | if position_filter is not None:
100 | K = position_filter * K
101 | K = K.permute(0, 3, 1, 2).contiguous()
102 | return K
103 |
104 | def wasserstein_barycenter(x, c, eps=1.0, max_iter=100, sinkhorn_iter=50, log_domain=False):
105 | """
106 | x: n x in_size x in_dim
107 | c: out_size x in_dim
108 | """
109 | prev_c = c
110 | for i in range(max_iter):
111 |         T = multihead_attn(x, c.unsqueeze(0), eps=eps, log_domain=log_domain, max_iter=sinkhorn_iter).squeeze(2)
112 | # T: n x out_size x in_size
113 | c = 0.5*c + 0.5*torch.bmm(T, x).mean(dim=0) / math.sqrt(c.shape[0])
114 | c /= c.norm(dim=-1, keepdim=True).clamp(min=1e-06)
115 | if ((c - prev_c) ** 2).sum() < 1e-06:
116 | break
117 | prev_c = c
118 | return c
119 |
120 | def wasserstein_kmeans(x, n_clusters, out_size, eps=1.0, block_size=None, max_iter=100,
121 | sinkhorn_iter=50, wb=False, verbose=True, log_domain=False, use_cuda=False):
122 | """
123 | x: n x in_size x in_dim
124 | output: n_clusters x out_size x in_dim
125 | out_size <= in_size
126 | """
127 | n, in_size, in_dim = x.shape
128 | if n_clusters == 1:
129 | if use_cuda:
130 | x = x.cuda()
131 | clusters = spherical_kmeans(x.view(-1, in_dim), out_size, block_size=block_size)
132 | if wb:
133 | clusters = wasserstein_barycenter(x, clusters, eps=0.1, log_domain=False)
134 | clusters = clusters.unsqueeze_(0)
135 | return clusters
136 |     # initialization
137 | indices = torch.randperm(n)[:n_clusters]
138 | clusters = x[indices, :out_size, :].clone()
139 | if use_cuda:
140 | clusters = clusters.cuda()
141 |
142 | wass_sim = x.new_empty(n)
143 | assign = x.new_empty(n, dtype=torch.long)
144 | if block_size is None or block_size == 0:
145 | block_size = n
146 |
147 | prev_sim = float('inf')
148 | for n_iter in range(max_iter):
149 | for i in range(0, n, block_size):
150 | end_i = min(i + block_size, n)
151 | x_batch = x[i: end_i]
152 | if use_cuda:
153 | x_batch = x_batch.cuda()
154 | tmp_sim = multihead_attn(x_batch, clusters, eps=eps, return_kernel=True, max_iter=sinkhorn_iter, log_domain=log_domain)
155 | tmp_sim = tmp_sim.cpu()
156 | wass_sim[i : end_i], assign[i: end_i] = tmp_sim.max(dim=-1)
157 | del x_batch
158 | sim = wass_sim.mean()
159 | if verbose and (n_iter + 1) % 10 == 0:
160 | print("Wasserstein spherical kmeans iter {}, objective value {}".format(
161 | n_iter + 1, sim))
162 |
163 | for j in range(n_clusters):
164 | index = assign == j
165 | if index.sum() == 0:
166 | idx = wass_sim.argmin()
167 | clusters[j].copy_(x[idx, :out_size, :])
168 | wass_sim[idx] = 1
169 | else:
170 | xj = x[index]
171 | if use_cuda:
172 | xj = xj.cuda()
173 | c = spherical_kmeans(xj.view(-1, in_dim), out_size, block_size=block_size, verbose=False)
174 | if wb:
175 | c = wasserstein_barycenter(xj, c, eps=0.001, log_domain=True, sinkhorn_iter=50)
176 | clusters[j] = c
177 | if torch.abs(prev_sim - sim) / sim.clamp(min=1e-10) < 1e-6:
178 | break
179 | prev_sim = sim
180 | return clusters
181 |
--------------------------------------------------------------------------------
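A small sanity check of the two solvers above: `sinkhorn` expects the already-exponentiated kernel when `return_kernel=False`, while `log_sinkhorn` works directly on the raw scores in the log domain, so the two transport plans should coincide. The shapes and `eps` below are arbitrary test values.

```python
import torch
from otk.sinkhorn import sinkhorn, log_sinkhorn

torch.manual_seed(0)
eps = 0.5
dot = torch.randn(2, 6, 3)    # n x in_size x out_size similarity scores

plan_exp = sinkhorn(torch.exp(dot / eps), eps=eps, max_iter=100)  # exp-domain solver
plan_log = log_sinkhorn(dot, eps=eps, max_iter=100)               # log-domain solver

print((plan_exp - plan_log).abs().max())  # should be close to zero
print(plan_log.sum(dim=1))                # each column carries ~1 unit of mass
```
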
/otk/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 | import math
4 | import random
5 | import numpy as np
6 |
7 | EPS = 1e-6
8 |
9 |
10 | def normalize(x, p=2, dim=-1, inplace=False):
11 | norm = x.norm(p=p, dim=dim, keepdim=True)
12 | if inplace:
13 | x.div_(norm.clamp(min=EPS))
14 | else:
15 | x = x / norm.clamp(min=EPS)
16 | return x
17 |
18 | def spherical_kmeans(x, n_clusters, max_iters=100, block_size=None, verbose=True,
19 | init=None, eps=1e-4):
20 | """Spherical kmeans
21 | Args:
22 | x (Tensor n_samples x kmer_size x n_features): data points
23 | n_clusters (int): number of clusters
24 | """
25 | use_cuda = x.is_cuda
26 | if x.ndim == 3:
27 | n_samples, kmer_size, n_features = x.size()
28 | else:
29 | n_samples, n_features = x.size()
30 |     if init is None:
31 |         indices = torch.randperm(n_samples)[:n_clusters]
32 |         if use_cuda:
33 |             indices = indices.cuda()
34 |         init = x[indices]
35 |     clusters = init.clone()
36 | prev_sim = np.inf
37 | tmp = x.new_empty(n_samples)
38 | assign = x.new_empty(n_samples, dtype=torch.long)
39 | if block_size is None or block_size == 0:
40 | block_size = x.shape[0]
41 |
42 | for n_iter in range(max_iters):
43 | for i in range(0, n_samples, block_size):
44 | end_i = min(i + block_size, n_samples)
45 | cos_sim = x[i: end_i].view(end_i - i, -1).mm(clusters.view(n_clusters, -1).t())
46 | tmp[i: end_i], assign[i: end_i] = cos_sim.max(dim=-1)
47 | sim = tmp.mean()
48 | if (n_iter + 1) % 10 == 0 and verbose:
49 | print("Spherical kmeans iter {}, objective value {}".format(
50 | n_iter + 1, sim))
51 |
52 | # update clusters
53 | for j in range(n_clusters):
54 | index = assign == j
55 | if index.sum() == 0:
56 | idx = tmp.argmin()
57 | clusters[j] = x[idx]
58 | tmp[idx] = 1
59 | else:
60 | xj = x[index]
61 | c = xj.mean(0)
62 | clusters[j] = c / c.norm(dim=-1, keepdim=True).clamp(min=EPS)
63 |
64 | if torch.abs(prev_sim - sim)/(torch.abs(sim)+1e-20) < 1e-6:
65 | break
66 | prev_sim = sim
67 | return clusters
68 |
--------------------------------------------------------------------------------
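A quick sketch of `spherical_kmeans` on synthetic unit-norm data (the cluster count and dimensions are arbitrary); the returned centroids are themselves re-projected onto the unit sphere.

```python
import torch
from otk.utils import spherical_kmeans, normalize

torch.manual_seed(0)
x = normalize(torch.randn(1000, 32))   # 1000 samples on the unit sphere in R^32

centers = spherical_kmeans(x, n_clusters=8, max_iters=50, verbose=False)
print(centers.shape)         # torch.Size([8, 32])
print(centers.norm(dim=-1))  # all ~1: centroids are renormalized after each update
```
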
/tools/Trainer_ours_m40r4.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | import numpy as np
6 | from torch.optim.lr_scheduler import CosineAnnealingLR
7 | from tensorboardX import SummaryWriter
8 | import math
9 | from torch.cuda.amp import autocast
10 | from torch.cuda.amp import GradScaler
11 | from otk.utils import normalize
12 | class ModelNetTrainer(object):
13 | def __init__(self, model, train_loader, val_loader, optimizer, loss_fn, \
14 | model_name, log_dir, num_views=12):
15 | self.optimizer = optimizer
16 | self.model = model
17 | self.train_loader = train_loader
18 | self.val_loader = val_loader
19 | self.loss_fn = loss_fn
20 | self.model_name = model_name
21 | self.log_dir = log_dir
22 | self.num_views = num_views
23 | self.model.cuda()
24 | if self.log_dir is not None:
25 | self.writer = SummaryWriter(log_dir)
26 | def train(self, n_epochs):
27 | best_acc = 0
28 | i_acc = 0
29 | self.model.train()
30 |         scaler = GradScaler()
31 | # if self.model_name =='view-gcn':
32 | # self.model.weight_init(self.train_loader)
33 | scheduler = CosineAnnealingLR(self.optimizer,T_max=((len(self.train_loader.dataset.rand_view_num)//20)*60),eta_min=1e-5)
34 | for epoch in range(n_epochs):
35 | rand_idx = np.random.permutation(int(len(self.train_loader.dataset.rand_view_num)))
36 | filepaths_new = []
37 | rand_view_num_new = []
38 | view_coord_new = []
39 | for i in range(len(rand_idx)):
40 | filepaths_new.extend(self.train_loader.dataset.filepaths[
41 | rand_idx[i] * self.num_views:(rand_idx[i] + 1) * self.num_views])
42 | rand_view_num_new.append(self.train_loader.dataset.rand_view_num[rand_idx[i]])
43 | view_coord_new.append(self.train_loader.dataset.view_coord[rand_idx[i]])
44 | self.train_loader.dataset.filepaths = filepaths_new
45 | self.train_loader.dataset.rand_view_num = np.array(rand_view_num_new)
46 |             self.train_loader.dataset.view_coord = np.array(view_coord_new)
47 | # plot learning rate
48 | lr = self.optimizer.state_dict()['param_groups'][0]['lr']
49 | self.writer.add_scalar('params/lr', lr, epoch)
50 | # train one epoch
51 | out_data = None
52 | in_data = None
53 | for i, data in enumerate(self.train_loader):
54 | if epoch == 0:
55 | for param_group in self.optimizer.param_groups:
56 | param_group['lr'] = lr * ((i + 1) / (len(self.train_loader.dataset.rand_view_num) // 20))
57 | if self.model_name == 'svcnn':
58 | in_data = Variable(data[1].cuda())
59 | else:
60 | view_coord = data[3].cuda()
61 | rand_view_num = data[2].cuda()
62 | N, V, C, H, W = data[1].size()
63 | in_data = Variable(data[1]).cuda()
64 | in_data_new = []
65 | for ii in range(N):
66 | in_data_new.append(in_data[ii,0:rand_view_num[ii]])
67 | in_data = torch.cat(in_data_new, 0)
68 |
69 | target = Variable(data[0]).cuda().long()
70 | target2 = Variable(data[0]).cuda().long()
71 | vert = [[1, 1, 1], [1, 1, -1], [1, -1, 1], [1, -1, -1], [-1, 1, 1], [-1, 1, -1], [-1, -1, 1],
72 | [-1, -1, -1]]
73 | # vert = [[1,1],[1,-1],[-1,1],[-1,-1]]
74 | # vert = [[1,1,1,1],
75 | # [-1,1,1,1],[1,-1,1,1],[1,1,-1,1],[1,1,1,-1],
76 | # [-1,-1,1,1],[-1,1,-1,1],[-1,1,1,-1],[1,-1,-1,1],[1,-1,1,-1],[1,1,-1,-1],
77 | # [-1,-1,-1,1],[-1,-1,1,-1],[-1,1,-1,-1],[1,-1,-1,-1],
78 | # [-1,-1,-1,-1]]
79 | vert = torch.Tensor(vert).cuda()
80 | # target_ = target.unsqueeze(1).repeat(1, 2*(10+5)).view(-1)
81 | self.optimizer.zero_grad()
82 | with autocast():
83 | if self.model_name == 'svcnn':
84 | out_data = self.model(in_data)
85 | loss = self.loss_fn(out_data, target)
86 | else:
87 |                         out_data, cos_sim, cos_sim2, pos = self.model(in_data, rand_view_num, N)
88 | cos_loss = cos_sim[torch.where(cos_sim>-1)].mean()
89 | cos_loss2 = cos_sim2[torch.where(cos_sim2>-1)].mean()
90 | # part_loss = self.loss_fn(part.reshape(-1,40),target2)
91 | pos_loss = torch.norm(normalize(vert)-normalize(pos),p=2,dim=-1).mean()
92 |
93 | loss = self.loss_fn(out_data, target)+0.1*pos_loss
94 |
95 | self.writer.add_scalar('train/train_loss', loss, i_acc + i + 1)
96 |
97 | pred = torch.max(out_data, 1)[1]
98 | results = pred == target
99 | correct_points = torch.sum(results.long())
100 |
101 | acc = correct_points.float() / results.size()[0]
102 | self.writer.add_scalar('train/train_overall_acc', acc, i_acc + i + 1)
103 | # print('lr = ', str(param_group['lr']))
104 |                 scaler.scale(loss).backward()
105 |                 scaler.unscale_(self.optimizer)
106 |                 torch.nn.utils.clip_grad_norm_(self.model.parameters(), 20)
107 |                 scaler.step(self.optimizer)
108 |                 scaler.update()
109 | #loss.backward()
110 | #torch.nn.utils.clip_grad_norm_(self.model.parameters(), 20)
111 | #self.optimizer.step()
112 | if epoch>0:
113 | scheduler.step()
114 | # self.optimizer.zero_grad()
115 |                 log_str = 'epoch %d, step %d: train_loss %.3f; cos_loss %.3f; pos_loss %.3f; train_acc %.3f' % (epoch + 1, i + 1, loss, cos_loss2, pos_loss, acc)
116 | if (i + 1) % 1 == 0:
117 | print(log_str)
118 | i_acc += i
119 | # evaluation
120 | if (epoch + 1) % 1 == 0:
121 | with torch.no_grad():
122 | loss, val_overall_acc, val_mean_class_acc = self.update_validation_accuracy(epoch)
123 | self.writer.add_scalar('val/val_mean_class_acc', val_mean_class_acc, epoch + 1)
124 | self.writer.add_scalar('val/val_overall_acc', val_overall_acc, epoch + 1)
125 | self.writer.add_scalar('val/val_loss', loss, epoch + 1)
126 | # self.model.save(self.log_dir, epoch)
127 | # save best model
128 | if val_overall_acc > best_acc:
129 | best_acc = val_overall_acc
130 | print('best_acc', best_acc)
131 | # export scalar data to JSON for external processing
132 | self.writer.export_scalars_to_json(self.log_dir + "/all_scalars.json")
133 | self.writer.close()
134 | def update_validation_accuracy(self, epoch):
135 | all_correct_points = 0
136 | all_points = 0
137 | count = 0
138 | wrong_class = np.zeros(40)
139 | samples_class = np.zeros(40)
140 | all_loss = 0
141 | self.model.eval()
142 | for _, data in enumerate(self.val_loader, 0):
143 | if self.model_name == 'svcnn':
144 | in_data = Variable(data[1].cuda())
145 | else:
146 | view_coord = data[3].cuda()
147 | rand_view_num = data[2].cuda()
148 | N, V, C, H, W = data[1].size()
149 | in_data = Variable(data[1]).cuda()
150 | in_data_new = []
151 | for ii in range(N):
152 | in_data_new.append(in_data[ii, 0:rand_view_num[ii]])
153 | in_data = torch.cat(in_data_new, 0)
154 | target = Variable(data[0]).cuda()
155 | #
156 | if self.model_name == 'svcnn':
157 | out_data = self.model(in_data)
158 | else:
159 |                 out_data, cos_sim, cos_sim2, pos = self.model(in_data, rand_view_num, N)
160 | # cos_loss = cos_sim[torch.where(cos_sim > 0)].mean()
161 | pred = torch.max(out_data, 1)[1]
162 | all_loss += self.loss_fn(out_data, target).cpu().data.numpy()
163 | results = pred == target
164 |
165 | for i in range(results.size()[0]):
166 | if not bool(results[i].cpu().data.numpy()):
167 | wrong_class[target.cpu().data.numpy().astype('int')[i]] += 1
168 | samples_class[target.cpu().data.numpy().astype('int')[i]] += 1
169 | correct_points = torch.sum(results.long())
170 |
171 | all_correct_points += correct_points
172 | all_points += results.size()[0]
173 |
174 | print('Total # of test models: ', all_points)
175 | class_acc = (samples_class - wrong_class) / samples_class
176 | val_mean_class_acc = np.mean(class_acc)
177 | acc = all_correct_points.float() / all_points
178 | val_overall_acc = acc.cpu().data.numpy()
179 | loss = all_loss / len(self.val_loader)
180 |
181 |
182 | print('val mean class acc. : ', val_mean_class_acc)
183 | print('val overall acc. : ', val_overall_acc)
184 | print('val loss : ', loss)
185 | # print('cos loss : ', cos_loss)
186 | print(class_acc)
187 | self.model.train()
188 |
189 | return loss, val_overall_acc, val_mean_class_acc
190 |
--------------------------------------------------------------------------------
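The training loop above relies on PyTorch automatic mixed precision with gradient clipping. Below is a stripped-down sketch of the same autocast / GradScaler pattern as a single generic step; the function name and arguments are hypothetical and not part of the repo.

```python
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler

def amp_train_step(model, batch, target, optimizer, loss_fn, scaler, max_norm=20):
    """One mixed-precision step mirroring ModelNetTrainer.train:
    forward under autocast, scaled backward, unscale before clipping,
    then optimizer step and scaler update."""
    optimizer.zero_grad()
    with autocast():
        out = model(batch)
        loss = loss_fn(out, target)
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)   # bring gradients back to fp32 scale before clipping
    nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    scaler.step(optimizer)
    scaler.update()
    return loss.item()

# Usage: create the scaler once, outside the epoch loop.
# scaler = GradScaler()
# loss = amp_train_step(model, batch, target, optimizer, loss_fn, scaler)
```
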
/tools/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/tools/replas.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.io as sio
3 | import glob
4 | data_path = '/data/xxinwei/dataset/objectnn_hardest/'
5 | # classnames=['airplane','bathtub','bed','bench','bookshelf','bottle','bowl','car','chair',
6 | # 'cone','cup','curtain','desk','door','dresser','flower_pot','glass_box',
7 | # 'guitar','keyboard','lamp','laptop','mantel','monitor','night_stand',
8 | # 'person','piano','plant','radio','range_hood','sink','sofa','stairs',
9 | # 'stool','table','tent','toilet','tv_stand','vase','wardrobe','xbox']
10 | classnames = ['bag', 'bed', 'bin', 'box', 'cabinet', 'chair', 'desk', 'display', 'door', 'pillow', 'shelf', 'sink', 'sofa', 'table', 'toilet']
11 | for i in range(len(classnames)):
12 | #view = sio.loadmat(data_path+classnames[i]+'/train/view.mat')
13 | #view1 = view['ll']
14 | #np.save(data_path + classnames[i] + '/train/view.npy', view1)
15 | #rand_view_num = np.load('/data/xxinwei/view-transformer-hardest/ModelNet40_hardest/'+classnames[i]+'/train/random_view_num.npy')
16 | numb = glob.glob(data_path + classnames[i]+'/test/*')
17 | rand_view_num = np.random.randint(6,20,(len(numb)-2)//20)
18 | np.save(data_path + classnames[i]+'/test/random_view_num.npy',rand_view_num)
19 | ss = sio.loadmat(data_path + classnames[i]+'/test/view.mat')
20 | ss = ss['ll']
21 | np.save(data_path + classnames[i]+'/test/view.npy',ss)
--------------------------------------------------------------------------------
/train4.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.optim as optim
6 | import os, shutil, json, argparse
7 | from tools.Trainer_ours_m40r4 import ModelNetTrainer
8 | from tools.ImgDataset_m40_6_20 import RandomMultiviewImgDataset, RandomSingleImgDataset
9 | from model.view_transformer_random4 import *
10 |
11 | def seed_torch(seed=0):
12 | random.seed(seed)
13 | os.environ['PYTHONHASHSEED'] = str(seed)
14 | np.random.seed(seed)
15 | torch.manual_seed(seed)
16 | torch.cuda.manual_seed(seed)
17 | # torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
18 | torch.backends.cudnn.benchmark = False
19 | torch.backends.cudnn.deterministic = True
20 | torch.backends.cudnn.enabled = True
21 | parser = argparse.ArgumentParser()
22 | parser.add_argument("-name", "--name", type=str, help="Name of the experiment", default="4")
23 | parser.add_argument("-bs", "--batchSize", type=int, help="Batch size for the second stage", default=20)# it will be *12 images in each batch for mvcnn
24 | parser.add_argument("-num_models", type=int, help="number of models per class", default=0)
25 | parser.add_argument("-lr", type=float, help="learning rate", default=1e-3)
26 | parser.add_argument("-weight_decay", type=float, help="weight decay", default=1e-3)
27 | parser.add_argument("-no_pretraining", dest='no_pretraining', action='store_true')
28 | parser.add_argument("-cnn_name", "--cnn_name", type=str, help="cnn model name", default="resnet18")
29 | parser.add_argument("-num_views", type=int, help="number of views", default=20)
30 | parser.add_argument("-train_path", type=str, default="/home/sun/weixin/view-transformer/ModelNet40_hardest_20/*/train")
31 | parser.add_argument("-val_path", type=str, default="/home/sun/weixin/view-transformer/ModelNet40_hardest_20/*/test")
32 | parser.set_defaults(train=False)
33 | # os.environ['CUDA_VISIBLE_DEVICES']='2'
34 | def create_folder(log_dir):
35 | if not os.path.exists(log_dir):
36 | os.mkdir(log_dir)
37 | else:
38 | print('WARNING: summary folder already exists!! It will be overwritten!!')
39 | shutil.rmtree(log_dir)
40 | os.mkdir(log_dir)
41 |
42 | if __name__ == '__main__':
43 | seed_torch()
44 | args = parser.parse_args()
45 | pretraining = not args.no_pretraining
46 | log_dir = args.name
47 | create_folder(args.name)
48 | config_f = open(os.path.join(log_dir, 'config.json'), 'w')
49 | json.dump(vars(args), config_f)
50 | config_f.close()
51 | cnet = SVCNN(args.name, nclasses=40, pretraining=pretraining, cnn_name=args.cnn_name)
52 | n_models_train = args.num_models * args.num_views
53 | log_dir = args.name+'_stage_2'
54 | create_folder(log_dir)
55 | cnet_2 = view_GCN(args.name, cnet, nclasses=40, cnn_name=args.cnn_name, num_views=args.num_views)
56 | optimizer = optim.SGD(cnet_2.parameters(), lr=args.lr, weight_decay=args.weight_decay,momentum=0.9)
57 | train_dataset = RandomMultiviewImgDataset(args.train_path, scale_aug=False, rot_aug=False, num_models=n_models_train, num_views=args.num_views,test_mode=True)
58 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batchSize, shuffle=False, num_workers=12,pin_memory=True)# shuffle needs to be false! it's done within the trainer
59 | val_dataset = RandomMultiviewImgDataset(args.val_path, scale_aug=False, rot_aug=False, num_views=args.num_views,test_mode=True)
60 | val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batchSize, shuffle=False, num_workers=12,pin_memory=True)
61 | print('num_train_files: '+str(len(train_dataset.filepaths)))
62 | print('num_val_files: '+str(len(val_dataset.filepaths)))
63 | trainer = ModelNetTrainer(cnet_2, train_loader, val_loader, optimizer, nn.CrossEntropyLoss(), 'view-gcn', log_dir, num_views=args.num_views)
64 | trainer.train(60)
--------------------------------------------------------------------------------