├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── baselines
│   │   ├── cgqa
│   │   │   ├── aopp.yml
│   │   │   ├── le+.yml
│   │   │   ├── symnet.yml
│   │   │   └── tmn.yml
│   │   ├── mit
│   │   │   ├── aopp.yml
│   │   │   ├── le+.yml
│   │   │   ├── symnet.yml
│   │   │   └── tmn.yml
│   │   └── utzppos
│   │       ├── aopp.yml
│   │       ├── le+.yml
│   │       ├── symnet.yml
│   │       └── tmn.yml
│   ├── cge
│   │   ├── cgqa.yml
│   │   ├── mit.yml
│   │   └── utzappos.yml
│   └── compcos
│       ├── cgqa
│       │   ├── compcos.yml
│       │   └── compcos_cw.yml
│       ├── mit
│       │   ├── compcos.yml
│       │   └── compcos_cw.yml
│       └── utzppos
│           ├── compcos.yml
│           └── compcos_cw.yml
├── data
│   ├── __init__.py
│   └── dataset.py
├── environment.yml
├── flags.py
├── models
│   ├── __init__.py
│   ├── common.py
│   ├── compcos.py
│   ├── gcn.py
│   ├── graph_method.py
│   ├── image_extractor.py
│   ├── manifold_methods.py
│   ├── modular_methods.py
│   ├── svm.py
│   ├── symnet.py
│   ├── visual_product.py
│   └── word_embedding.py
├── notebooks
│   └── analysis.ipynb
├── test.py
├── train.py
└── utils
    ├── __init__.py
    ├── cgqa-graph.t7
    ├── config_model.py
    ├── download_data.sh
    ├── download_embeddings.py
    ├── img.png
    ├── mitstates-graph.t7
    ├── reorganize_utzap.py
    ├── utils.py
    └── utzappos-graph.t7
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # IDE
132 | .vscode/
133 |
134 | #logs and dataset
135 | cv/*
136 | data/antonyms.txt
137 | data/glove/*
138 | data/mit-states/*
139 | data/ut-zap50k/*
140 | logs/*
141 | tensor-completion/*
142 | .DS_Store
143 | ._.DS_Store
144 | *.csv
145 | *.txt
146 | notebooks/._mit_graph.pdf
147 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Compositional Zero-Shot Learning
3 | This is the official PyTorch code of the CVPR 2021 papers [Learning Graph Embeddings for Compositional Zero-shot Learning](https://arxiv.org/pdf/2102.01987.pdf) and [Open World Compositional Zero-Shot Learning](https://arxiv.org/pdf/2101.12609.pdf). The code provides the implementation of CGE and CompCos, together with other baselines (e.g. SymNet, AoP, TMN, LabelEmbed+, RedWine). It also provides training and testing code for the Open World CZSL setting and the new C-GQA benchmark.
4 |
5 | **Important note:** the C-GQA dataset has been updated (see [this issue](https://github.com/ExplainableML/czsl/issues/3)) and the code will automatically download the new version. The results of all models for the updated benchmark can be found in the [Co-CGE](https://arxiv.org/abs/2105.01017) and [KG-SP](https://openaccess.thecvf.com/content/CVPR2022/html/Karthik_KG-SP_Knowledge_Guided_Simple_Primitives_for_Open_World_Compositional_Zero-Shot_CVPR_2022_paper.html) papers.
6 |
7 |
8 |
9 |
10 |
11 | ## Check also:
12 | - [Co-CGE](https://ieeexplore.ieee.org/document/9745371/) and its [repo](https://github.com/ExplainableML/co-cge) if you are interested in a stronger OW-CZSL model and a faster OW evaluation code.
13 | - [KG-SP](https://openaccess.thecvf.com/content/CVPR2022/html/Karthik_KG-SP_Knowledge_Guided_Simple_Primitives_for_Open_World_Compositional_Zero-Shot_CVPR_2022_paper.html) and its [repo](https://github.com/ExplainableML/KG-SP) if you are interested in the partial CZSL setting and a simple but effective OW model.
14 |
15 |
16 | ## Setup
17 |
18 | 1. Clone the repo
19 |
20 | 2. We recommend using Anaconda for environment setup. To create the environment and activate it, please run:
21 | ```
22 | conda env create --file environment.yml
23 | conda activate czsl
24 | ```
25 |
26 | 3. Go to the cloned repo and open a terminal. Download the datasets and embeddings, specifying the desired path (e.g. `DATA_ROOT` in the example):
27 | ```
28 | bash ./utils/download_data.sh DATA_ROOT
29 | mkdir logs
30 | ```
31 |
32 |
33 | ## Training
34 | **Closed World.** To train a model, the command is simply:
35 | ```
36 | python train.py --config CONFIG_FILE
37 | ```
38 | where `CONFIG_FILE` is the path to the configuration file of the model.
39 | The folder `configs` contains configuration files for all methods, i.e. CGE in `configs/cge`, CompCos in `configs/compcos`, and the other methods in `configs/baselines`.
40 |
41 | To run CGE on MitStates, the command is just:
42 | ```
43 | python train.py --config configs/cge/mit.yml
44 | ```
45 | On UT-Zappos, the command is:
46 | ```
47 | python train.py --config configs/cge/utzappos.yml
48 | ```
49 |
50 | **Open World.** To train CompCos (in the open world scenario) on MitStates, run:
51 | ```
52 | python train.py --config configs/compcos/mit/compcos.yml
53 | ```
54 |
55 | To run experiments in the open world setting for a non-open world method, just add `--open_world` after the command. E.g. for running SymNet in the open world scenario on MitStates, the command is:
56 | ```
57 | python train.py --config configs/baselines/mit/symnet.yml --open_world
58 | ```
59 | **Note:** To create a new config, all the available arguments are listed in `flags.py`; a minimal config template is sketched below.
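
For reference, the provided configs all share the structure below. This is only an illustrative sketch: the keys map onto the arguments defined in `flags.py`, and the values here (including the experiment name `mymodel`) are placeholders rather than tuned settings.
```
---
experiment:
  name: mymodel/mitstates/       # logs and checkpoints are stored under this name
dataset:
  data_dir: mit-states           # dataset folder inside the data root
  dataset: mitstates
  splitname: compositional-split-natural
model_params:
  model: compcos                 # any model supported in flags.py
  emb_dim: 300
  emb_init: word2vec
  image_extractor: resnet18
training:
  batch_size: 128
  lr: 5.0e-05
  max_epochs: 300
  test_set: val
```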
60 |
61 | ## Test
62 |
63 |
64 | **Closed World.** To test a model, run:
65 | ```
66 | python test.py --logpath LOG_DIR
67 | ```
68 | where `LOG_DIR` is the directory containing the logs of a model.
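The logs should be stored under the experiment `name` from the config, inside `cv_dir` (default `logs/`, see `flags.py`). For instance, after training CGE on MitStates with the provided config, the call would typically look like:
```
python test.py --logpath logs/graphembed/mitstates/base
```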
69 |
70 |
71 | **Open World.** To test a model in the open world setting, run:
72 | ```
73 | python test.py --logpath LOG_DIR --open_world
74 | ```
75 |
76 | To test a CompCos model in the open world setting with hard masking, run:
77 | ```
78 | python test.py --logpath LOG_DIR_COMPCOS --open_world --hard_masking
79 | ```
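
If needed, the threshold used for hard masking can also be fixed or tuned explicitly via the `--threshold` and `--threshold_trials` arguments listed in `flags.py`, e.g.:
```
python test.py --logpath LOG_DIR_COMPCOS --open_world --hard_masking --threshold THRESHOLD
```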
80 |
81 |
82 | ## References
83 | If you use this code, please cite
84 | ```
85 | @inproceedings{naeem2021learning,
86 | title={Learning Graph Embeddings for Compositional Zero-shot Learning},
87 | author={Naeem, MF and Xian, Y and Tombari, F and Akata, Zeynep},
88 | booktitle={34th IEEE Conference on Computer Vision and Pattern Recognition},
89 | year={2021},
90 | organization={IEEE}
91 | }
92 | ```
93 | and
94 | ```
95 | @inproceedings{mancini2021open,
96 | title={Open World Compositional Zero-Shot Learning},
97 | author={Mancini, M and Naeem, MF and Xian, Y and Akata, Zeynep},
98 | booktitle={34th IEEE Conference on Computer Vision and Pattern Recognition},
99 | year={2021},
100 | organization={IEEE}
101 | }
102 |
103 | ```
104 |
105 | **Note**: Some of the scripts are adapted from the Attributes as Operators repository. GCN and GCNII implementations are imported from their respective repositories. If you find those parts useful, please consider citing:
106 | ```
107 | @inproceedings{nagarajan2018attributes,
108 | title={Attributes as operators: factorizing unseen attribute-object compositions},
109 | author={Nagarajan, Tushar and Grauman, Kristen},
110 | booktitle={Proceedings of the European Conference on Computer Vision (ECCV)},
111 | pages={169--185},
112 | year={2018}
113 | }
114 | ```
115 |
--------------------------------------------------------------------------------
/configs/baselines/cgqa/aopp.yml:
--------------------------------------------------------------------------------
1 | ---
2 | experiment:
3 | name: aopp/cgqa/
4 | dataset:
5 | data_dir: cgqa
6 | dataset: cgqa
7 | splitname: compositional-split-natural
8 | model_params:
9 | model: attributeop
10 | emb_dim: 300
11 | emb_init: glove
12 | eval_type: dist_fast
13 | image_extractor: resnet18
14 | train_only: true
15 | static_inp: false
16 | composition: add
17 | loss:
18 | lambda_aux: 1.0
19 | lambda_comm: 1.0
20 | training:
21 | batch_size: 128
22 | eval_val_every: 2
23 | load:
24 | lr: 5.0e-05
25 | lrg: 0.001
26 | margin: 0.5
27 | max_epochs: 1000
28 | norm_family: imagenet
29 | save_every: 10000
30 | test_batch_size: 32
31 | test_set: val
32 | topk: 1
33 | wd: 5.0e-05
34 | workers: 8
35 | update_features: false
36 | freeze_features: false
37 | extra:
38 | lambda_attr: 0
39 | lambda_obj: 0
40 | lambda_sub: 0
41 | graph: false
42 | hardk: null
--------------------------------------------------------------------------------
/configs/baselines/cgqa/le+.yml:
--------------------------------------------------------------------------------
1 | ---
2 | experiment:
3 | name: labelembed/cgqa/
4 | dataset:
5 | data_dir: cgqa
6 | dataset: cgqa
7 | splitname: compositional-split-natural
8 | model_params:
9 | model: labelembed+
10 | dropout: false
11 | norm: false
12 | nlayers: 2
13 | fc_emb: 512,512
14 | emb_dim: 300
15 | emb_init: glove
16 | eval_type: dist_fast
17 | image_extractor: resnet18
18 | train_only: true
19 | static_inp: false
20 | composition: add
21 | loss:
22 | lambda_ce: 1.0
23 | lambda_trip_smart: 0
24 | training:
25 | batch_size: 512
26 | eval_val_every: 2
27 | load:
28 | lr: 5.0e-05
29 | margin: 0.5
30 | max_epochs: 2000
31 | norm_family: imagenet
32 | save_every: 10000
33 | test_batch_size: 32
34 | test_set: val
35 | topk: 1
36 | wd: 5.0e-05
37 | workers: 8
38 | update_features: false
39 | freeze_features: false
--------------------------------------------------------------------------------
/configs/baselines/cgqa/symnet.yml:
--------------------------------------------------------------------------------
1 | ---
2 | experiment:
3 | name: symnet/cgqa/
4 | dataset:
5 | data_dir: cgqa
6 | dataset: cgqa
7 | splitname: compositional-split-natural
8 | model_params:
9 | model: symnet
10 | dropout: true
11 | norm: true
12 | nlayers: 3
13 | fc_emb: 768,1024,1200
14 | emb_dim: 300
15 | emb_init: glove
16 | image_extractor: resnet18
17 | loss:
18 | lambda_cls_attr: 1.0
19 | lambda_cls_obj: 1.0
20 | lambda_trip: 0.5
21 | lambda_sym: 0.01
22 | lambda_axiom: 0.03
23 | training:
24 | batch_size: 512
25 | eval_val_every: 2
26 | load:
27 | lr: 5.0e-05
28 | lrg: 0.001
29 | margin: 0.5
30 | max_epochs: 2000
31 | norm_family: imagenet
32 | save_every: 10000
33 | test_batch_size: 32
34 | test_set: val
35 | topk: 1
36 | wd: 5.0e-05
37 | workers: 8
38 | update_features: false
39 | freeze_features: false
--------------------------------------------------------------------------------
/configs/baselines/cgqa/tmn.yml:
--------------------------------------------------------------------------------
1 | ---
2 | experiment:
3 | name: tmn/cgqa/
4 | dataset:
5 | data_dir: cgqa
6 | dataset: cgqa
7 | splitname: compositional-split-natural
8 | model_params:
9 | model: tmn
10 | emb_init: glove
11 | image_extractor: resnet18
12 | num_negs: 600
13 | embed_rank: 64
14 | emb_dim: 16
15 | nmods: 24
16 | training:
17 | batch_size: 256
18 | eval_val_every: 2
19 | load:
20 | lr: 0.0001
21 | lrg: 0.01
22 | margin: 0.5
23 | max_epochs: 100
24 | norm_family: imagenet
25 | save_every: 10000
26 | test_batch_size: 32
27 | test_set: val
28 | topk: 1
29 | wd: 5.0e-05
30 | workers: 8
31 | update_features: false
32 | freeze_features: false
--------------------------------------------------------------------------------
/configs/baselines/mit/aopp.yml:
--------------------------------------------------------------------------------
1 | ---
2 | experiment:
3 | name: aopp/mitstates/
4 | dataset:
5 | data_dir: mit-states
6 | dataset: mitstates
7 | splitname: compositional-split-natural
8 | model_params:
9 | model: attributeop
10 | emb_dim: 300
11 | emb_init: glove
12 | eval_type: dist_fast
13 | image_extractor: resnet18
14 | train_only: true
15 | static_inp: false
16 | composition: add
17 | loss:
18 | lambda_aux: 1000.0
19 | lambda_comm: 1.0
20 | training:
21 | batch_size: 512
22 | eval_val_every: 2
23 | load:
24 | lr: 5.0e-05
25 | lrg: 0.001
26 | margin: 0.5
27 | max_epochs: 1000
28 | norm_family: imagenet
29 | save_every: 10000
30 | test_batch_size: 32
31 | test_set: val
32 | topk: 1
33 | wd: 5.0e-05
34 | workers: 8
35 | update_features: false
36 | freeze_features: false
37 | extra:
38 | lambda_attr: 0
39 | lambda_obj: 0
40 | lambda_sub: 0
41 | graph: false
42 | hardk: null
--------------------------------------------------------------------------------
/configs/baselines/mit/le+.yml:
--------------------------------------------------------------------------------
1 | ---
2 | experiment:
3 | name: labelembed/mitstates/
4 | dataset:
5 | data_dir: mit-states
6 | dataset: mitstates
7 | splitname: compositional-split-natural
8 | model_params:
9 | model: labelembed+
10 | dropout: false
11 | norm: false
12 | nlayers: 2
13 | fc_emb: 512,512
14 | emb_dim: 300
15 | emb_init: glove
16 | eval_type: dist_fast
17 | image_extractor: resnet18
18 | train_only: true
19 | static_inp: false
20 | composition: add
21 | training:
22 | batch_size: 512
23 | eval_val_every: 2
24 | load:
25 | lr: 5.0e-05
26 | margin: 0.5
27 | max_epochs: 2000
28 | norm_family: imagenet
29 | save_every: 10000
30 | test_batch_size: 32
31 | test_set: val
32 | topk: 1
33 | wd: 5.0e-05
34 | workers: 8
35 | update_features: false
36 | freeze_features: false
--------------------------------------------------------------------------------
/configs/baselines/mit/symnet.yml:
--------------------------------------------------------------------------------
1 | ---
2 | experiment:
3 | name: symnet/mit-states/
4 | dataset:
5 | data_dir: mit-states
6 | dataset: mitstates
7 | splitname: compositional-split-natural
8 | model_params:
9 | model: symnet
10 | dropout: true
11 | norm: true
12 | nlayers: 3
13 | fc_emb: 768,1024,1200
14 | emb_dim: 300
15 | emb_init: glove
16 | image_extractor: resnet18
17 | obj_fixed: True
18 | loss:
19 | lambda_cls_attr: 1.0
20 | lambda_cls_obj: 0.1
21 | lambda_trip: 0.5
22 | lambda_sym: 0.01
23 | lambda_axiom: 0.03
24 | training:
25 | batch_size: 450
26 | eval_val_every: 2
27 | load:
28 | lr: 5.0e-05
29 | lrg: 5.0e-06
30 | margin: 0.5
31 | max_epochs: 2000
32 | norm_family: imagenet
33 | save_every: 10000
34 | test_batch_size: 32
35 | test_set: val
36 | topk: 1
37 | wd: 5.0e-05
38 | workers: 8
39 | update_features: false
40 | freeze_features: false
--------------------------------------------------------------------------------
/configs/baselines/mit/tmn.yml:
--------------------------------------------------------------------------------
1 | ---
2 | experiment:
3 | name: tmn/mitstates/
4 | dataset:
5 | data_dir: mit-states
6 | dataset: mitstates
7 | splitname: compositional-split-natural
8 | model_params:
9 | model: tmn
10 | emb_init: glove
11 | image_extractor: resnet18
12 | num_negs: 600
13 | embed_rank: 64
14 | emb_dim: 16
15 | nmods: 24
16 | training:
17 | batch_size: 256
18 | eval_val_every: 2
19 | load:
20 | lr: 0.0001
21 | lrg: 1.0e-06
22 | margin: 0.5
23 | max_epochs: 100
24 | norm_family: imagenet
25 | save_every: 10000
26 | test_batch_size: 32
27 | test_set: val
28 | topk: 1
29 | wd: 5.0e-05
30 | workers: 8
31 | update_features: true
32 | freeze_features: false
--------------------------------------------------------------------------------
/configs/baselines/utzppos/aopp.yml:
--------------------------------------------------------------------------------
1 | ---
2 | experiment:
3 | name: aopp/utzappos/
4 | dataset:
5 | data_dir: ut-zap50k
6 | dataset: utzappos
7 | splitname: compositional-split-natural
8 | model_params:
9 | model: attributeop
10 | emb_dim: 300
11 | emb_init: glove
12 | eval_type: dist_fast
13 | image_extractor: resnet18
14 | train_only: true
15 | static_inp: false
16 | composition: add
17 | loss:
18 | lambda_aux: 1.0
19 | lambda_comm: 1.0
20 | training:
21 | batch_size: 512
22 | eval_val_every: 2
23 | load:
24 | lr: 5.0e-05
25 | lrg: 0.001
26 | margin: 0.5
27 | max_epochs: 1000
28 | norm_family: imagenet
29 | save_every: 10000
30 | test_batch_size: 32
31 | test_set: val
32 | topk: 1
33 | wd: 5.0e-05
34 | workers: 8
35 | update_features: false
36 | freeze_features: false
37 | extra:
38 | lambda_attr: 0
39 | lambda_obj: 0
40 | lambda_sub: 0
41 | graph: false
42 | hardk: null
--------------------------------------------------------------------------------
/configs/baselines/utzppos/le+.yml:
--------------------------------------------------------------------------------
1 | ---
2 | experiment:
3 | name: labelembed/utzappos/
4 | dataset:
5 | data_dir: ut-zap50k
6 | dataset: utzappos
7 | splitname: compositional-split-natural
8 | model_params:
9 | model: labelembed+
10 | dropout: false
11 | norm: false
12 | nlayers: 2
13 | fc_emb: 512,512
14 | emb_dim: 300
15 | emb_init: glove
16 | eval_type: dist_fast
17 | image_extractor: resnet18
18 | train_only: true
19 | static_inp: false
20 | composition: add
21 | training:
22 | batch_size: 512
23 | eval_val_every: 2
24 | load:
25 | lr: 5.0e-05
26 | margin: 0.5
27 | max_epochs: 2000
28 | norm_family: imagenet
29 | save_every: 10000
30 | test_batch_size: 32
31 | test_set: val
32 | topk: 1
33 | wd: 5.0e-05
34 | workers: 8
35 | update_features: false
36 | freeze_features: false
--------------------------------------------------------------------------------
/configs/baselines/utzppos/symnet.yml:
--------------------------------------------------------------------------------
1 | ---
2 | experiment:
3 | name: symnet/utzappos/
4 | dataset:
5 | data_dir: ut-zap50k
6 | dataset: utzappos
7 | splitname: compositional-split-natural
8 | model_params:
9 | model: symnet
10 | dropout: true
11 | norm: true
12 | nlayers: 3
13 | fc_emb: 768,1024,1200
14 | emb_dim: 300
15 | emb_init: fasttext
16 | image_extractor: resnet18
17 | loss:
18 | lambda_cls_attr: 1.0
19 | lambda_cls_obj: 1.0
20 | lambda_trip: 0.5
21 | lambda_sym: 0.01
22 | lambda_axiom: 0.03
23 | training:
24 | batch_size: 512
25 | eval_val_every: 2
26 | load:
27 | lr: 5.0e-05
28 | lrg: 0.001
29 | margin: 0.5
30 | max_epochs: 2000
31 | norm_family: imagenet
32 | save_every: 10000
33 | test_batch_size: 32
34 | test_set: val
35 | topk: 1
36 | wd: 5.0e-05
37 | workers: 8
38 | update_features: false
39 | freeze_features: false
--------------------------------------------------------------------------------
/configs/baselines/utzppos/tmn.yml:
--------------------------------------------------------------------------------
1 | ---
2 | experiment:
3 | name: tmn/utzappos/
4 | dataset:
5 | data_dir: ut-zap50k
6 | dataset: utzappos
7 | splitname: compositional-split-natural
8 | model_params:
9 | model: tmn
10 | emb_init: word2vec
11 | emb_dim: 300
12 | image_extractor: resnet18
13 | train_only: true
14 | static_inp: true
15 | num_negs: 600
16 | embed_rank: 64
17 | nmods: 24
18 | training:
19 | batch_size: 256
20 | eval_val_every: 10
21 | load:
22 | lr: 0.0001
23 | lrg: 1.0e-06
24 | margin: 0.5
25 | max_epochs: 100
26 | norm_family: imagenet
27 | save_every: 10000
28 | test_batch_size: 32
29 | test_set: val
30 | topk: 1
31 | wd: 5.0e-05
32 | workers: 8
33 | update_features: false
34 | freeze_features: false
--------------------------------------------------------------------------------
/configs/cge/cgqa.yml:
--------------------------------------------------------------------------------
1 | experiment:
2 | name: graphembed/cgqa
3 | dataset:
4 | data_dir: cgqa
5 | dataset: cgqa
6 | splitname: compositional-split-natural
7 | model_params:
8 | model: graphfull
9 | dropout: true
10 | norm: true
11 | nlayers:
12 | fc_emb: 768,1024,1200
13 | gr_emb: d4096,d
14 | emb_dim: 512
15 | emb_init: word2vec
16 | image_extractor: resnet18
17 | train_only: true
18 | training:
19 | batch_size: 256
20 | eval_val_every: 2
21 | load:
22 | lr: 5.0e-05
23 | wd: 5.0e-05
24 | lrg: 5.0e-06
25 | margin: 2
26 | max_epochs: 2000
27 | norm_family: imagenet
28 | save_every: 10000
29 | test_batch_size: 32
30 | test_set: val
31 | topk: 1
32 | workers: 8
33 | update_features: true
34 |   freeze_features: false
35 |
--------------------------------------------------------------------------------
/configs/cge/mit.yml:
--------------------------------------------------------------------------------
1 | ---
2 | experiment:
3 | name: graphembed/mitstates/base
4 | dataset:
5 | data_dir: mit-states
6 | dataset: mitstates
7 | splitname: compositional-split-natural
8 | model_params:
9 | model: graphfull
10 | dropout: true
11 | norm: true
12 | nlayers:
13 | gr_emb: d4096,d
14 | emb_dim: 512
15 | graph_init: utils/mitstates-graph.t7
16 | image_extractor: resnet18
17 | train_only: true
18 | training:
19 | batch_size: 128
20 | eval_val_every: 2
21 | load:
22 | lr: 5.0e-05
23 | wd: 5.0e-05
24 | lrg: 5.0e-6
25 | margin: 0.5
26 | max_epochs: 200
27 | norm_family: imagenet
28 | save_every: 10000
29 | test_batch_size: 32
30 | test_set: val
31 | topk: 1
32 | workers: 8
33 | update_features: true
34 | freeze_features: false
35 |
36 |
--------------------------------------------------------------------------------
/configs/cge/utzappos.yml:
--------------------------------------------------------------------------------
1 | ---
2 | experiment:
3 | name: graphembed/utzappos/base
4 | dataset:
5 | data_dir: ut-zap50k
6 | dataset: utzappos
7 | splitname: compositional-split-natural
8 | model_params:
9 | model: graphfull
10 | dropout: true
11 | norm: true
12 | nlayers: 1
13 | fc_emb: 768,1024,1200
14 | gr_emb: d4096,d
15 | emb_dim: 300
16 | emb_init: null
17 | graph_init: utils/utzappos-graph.t7
18 | image_extractor: resnet18
19 | train_only: true
20 | training:
21 | batch_size: 128
22 | eval_val_every: 2
23 | load:
24 | lr: 5.0e-05
25 | wd: 5.0e-05
26 | lrg: 5.0e-6
27 | margin: 0.5
28 | max_epochs: 2000
29 | norm_family: imagenet
30 | save_every: 10000
31 | test_batch_size: 32
32 | test_set: val
33 | topk: 1
34 | workers: 8
35 | update_features: true
36 | freeze_features: false
37 | extra:
38 | lambda_attr: 0
39 | lambda_obj: 0
40 | lambda_sub: 0
41 | graph: false
42 | hardk: null
43 | cutmix: false
44 | cutmix_prob: 1.0
45 | beta: 1.0
--------------------------------------------------------------------------------
/configs/compcos/cgqa/compcos.yml:
--------------------------------------------------------------------------------
1 | ---
2 | experiment:
3 | name: compcos/cgqa
4 | dataset:
5 | data_dir: cgqa
6 | dataset: cgqa
7 | splitname: compositional-split-natural
8 | open_world: true
9 | model_params:
10 | model: compcos
11 | dropout: true
12 | norm: true
13 | nlayers: 2
14 | relu: false
15 | fc_emb: 768,1024,1200
16 | emb_dim: 300
17 | emb_init: word2vec
18 | image_extractor: resnet18
19 | train_only: false
20 | static_inp: false
21 | training:
22 | batch_size: 128
23 | load:
24 | lr: 5.0e-05
25 | lrg: 0.001
26 | margin: 0.4
27 | cosine_scale: 100
28 | max_epochs: 300
29 | norm_family: imagenet
30 | save_every: 10000
31 | test_batch_size: 64
32 | test_set: val
33 | topk: 1
34 | wd: 5.0e-05
35 | workers: 8
36 | update_features: false
37 | freeze_features: false
38 | epoch_max_margin: 15
39 |
--------------------------------------------------------------------------------
/configs/compcos/cgqa/compcos_cw.yml:
--------------------------------------------------------------------------------
1 | ---
2 | experiment:
3 | name: compcos_cw/cgqa/
4 | dataset:
5 | data_dir: cgqa
6 | dataset: cgqa
7 | splitname: compositional-split-natural
8 | model_params:
9 | model: compcos
10 | dropout: true
11 | norm: true
12 | nlayers: 2
13 | relu: false
14 | fc_emb: 768,1024,1200
15 | emb_dim: 300
16 | emb_init: word2vec
17 | image_extractor: resnet18
18 | train_only: true
19 | static_inp: false
20 | training:
21 | batch_size: 128
22 | load:
23 | lr: 5.0e-05
24 | lrg: 0.001
25 | margin: 0.4
26 | cosine_scale: 100
27 | max_epochs: 300
28 | norm_family: imagenet
29 | save_every: 10000
30 | test_batch_size: 64
31 | test_set: val
32 | topk: 1
33 | wd: 5.0e-05
34 | workers: 8
35 | update_features: false
36 | freeze_features: false
37 | epoch_max_margin: 15
38 |
--------------------------------------------------------------------------------
/configs/compcos/mit/compcos.yml:
--------------------------------------------------------------------------------
1 | ---
2 | experiment:
3 | name: compcos/mitstates
4 | dataset:
5 | data_dir: mit-states
6 | dataset: mitstates
7 | splitname: compositional-split-natural
8 | open_world: true
9 | model_params:
10 | model: compcos
11 | dropout: true
12 | norm: true
13 | nlayers: 2
14 | relu: false
15 | fc_emb: 768,1024,1200
16 | emb_dim: 600
17 | emb_init: ft+w2v
18 | image_extractor: resnet18
19 | train_only: false
20 | static_inp: false
21 | training:
22 | batch_size: 128
23 | load:
24 | lr: 5.0e-05
25 | lrg: 0.001
26 | margin: 0.4
27 | cosine_scale: 20
28 | max_epochs: 300
29 | norm_family: imagenet
30 | save_every: 10000
31 | test_batch_size: 64
32 | test_set: val
33 | topk: 1
34 | wd: 5.0e-05
35 | workers: 8
36 | update_features: false
37 | freeze_features: false
38 | epoch_max_margin: 15
39 |
--------------------------------------------------------------------------------
/configs/compcos/mit/compcos_cw.yml:
--------------------------------------------------------------------------------
1 | ---
2 | experiment:
3 | name: compcos_cw/mitstates
4 | dataset:
5 | data_dir: mit-states
6 | dataset: mitstates
7 | splitname: compositional-split-natural
8 | model_params:
9 | model: compcos
10 | dropout: true
11 | norm: true
12 | nlayers: 2
13 | relu: false
14 | fc_emb: 768,1024,1200
15 | emb_dim: 600
16 | emb_init: ft+w2v
17 | image_extractor: resnet18
18 | train_only: true
19 | static_inp: false
20 | training:
21 | batch_size: 128
22 | load:
23 | lr: 5.0e-05
24 | lrg: 0.001
25 | margin: 0.4
26 | cosine_scale: 20
27 | max_epochs: 300
28 | norm_family: imagenet
29 | save_every: 10000
30 | test_batch_size: 64
31 | test_set: val
32 | topk: 1
33 | wd: 5.0e-05
34 | workers: 8
35 | update_features: false
36 | freeze_features: false
37 | epoch_max_margin: 15
38 |
--------------------------------------------------------------------------------
/configs/compcos/utzppos/compcos.yml:
--------------------------------------------------------------------------------
1 | ---
2 | experiment:
3 | name: compcos/utzappos
4 | dataset:
5 | data_dir: ut-zap50k
6 | dataset: utzappos
7 | splitname: compositional-split-natural
8 | open_world: true
9 | model_params:
10 | model: compcos
11 | dropout: true
12 | norm: true
13 | nlayers: 2
14 | relu: false
15 | fc_emb: 768,1024
16 | emb_dim: 300
17 | emb_init: word2vec
18 | image_extractor: resnet18
19 | train_only: false
20 | static_inp: false
21 | training:
22 | batch_size: 128
23 | load:
24 | lr: 5.0e-05
25 | lrg: 0.001
26 | margin: 1.0
27 | cosine_scale: 50
28 | max_epochs: 300
29 | norm_family: imagenet
30 | save_every: 10000
31 | test_batch_size: 64
32 | test_set: val
33 | topk: 1
34 | wd: 5.0e-05
35 | workers: 8
36 | update_features: false
37 | freeze_features: false
38 | epoch_max_margin: 100
--------------------------------------------------------------------------------
/configs/compcos/utzppos/compcos_cw.yml:
--------------------------------------------------------------------------------
1 | ---
2 | experiment:
3 |   name: compcos_cw/utzappos
4 | dataset:
5 | data_dir: ut-zap50k
6 | dataset: utzappos
7 | splitname: compositional-split-natural
8 | model_params:
9 | model: compcos
10 | dropout: true
11 | norm: true
12 | nlayers: 2
13 | relu: false
14 | fc_emb: 768,1024
15 | emb_dim: 300
16 | emb_init: word2vec
17 | image_extractor: resnet18
18 | train_only: true
19 | static_inp: false
20 | training:
21 | batch_size: 128
22 | load:
23 | lr: 5.0e-05
24 | lrg: 0.001
25 | margin: 1.0
26 | cosine_scale: 50
27 | max_epochs: 300
28 | norm_family: imagenet
29 | save_every: 10000
30 | test_batch_size: 64
31 | test_set: val
32 | topk: 1
33 | wd: 5.0e-05
34 | workers: 8
35 | update_features: false
36 | freeze_features: false
37 | epoch_max_margin: 100
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ExplainableML/czsl/22ab1aca82c64d4f351a20a887986147edfb1905/data/__init__.py
--------------------------------------------------------------------------------
/data/dataset.py:
--------------------------------------------------------------------------------
1 | #external libs
2 | import numpy as np
3 | from tqdm import tqdm
4 | from PIL import Image
5 | import os
6 | import random
7 | from os.path import join as ospj
8 | from glob import glob
9 | #torch libs
10 | from torch.utils.data import Dataset
11 | import torch
12 | import torchvision.transforms as transforms
13 | #local libs
14 | from utils.utils import get_norm_values, chunks
15 | from models.image_extractor import get_image_extractor
16 | from itertools import product
17 |
18 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
19 |
20 | class ImageLoader:
21 | def __init__(self, root):
22 | self.root_dir = root
23 |
24 | def __call__(self, img):
25 | img = Image.open(ospj(self.root_dir,img)).convert('RGB') #We don't want alpha
26 | return img
27 |
28 |
29 | def dataset_transform(phase, norm_family = 'imagenet'):
30 | '''
31 | Inputs
32 | phase: String controlling which set of transforms to use
33 | norm_family: String controlling which normalization values to use
34 |
35 | Returns
36 | transform: A list of pytorch transforms
37 | '''
38 | mean, std = get_norm_values(norm_family=norm_family)
39 |
40 | if phase == 'train':
41 | transform = transforms.Compose([
42 | transforms.RandomResizedCrop(224),
43 | transforms.RandomHorizontalFlip(),
44 | transforms.ToTensor(),
45 | transforms.Normalize(mean, std)
46 | ])
47 |
48 | elif phase == 'val' or phase == 'test':
49 | transform = transforms.Compose([
50 | transforms.Resize(256),
51 | transforms.CenterCrop(224),
52 | transforms.ToTensor(),
53 | transforms.Normalize(mean, std)
54 | ])
55 | elif phase == 'all':
56 | transform = transforms.Compose([
57 | transforms.Resize(256),
58 | transforms.CenterCrop(224),
59 | transforms.ToTensor(),
60 | transforms.Normalize(mean, std)
61 | ])
62 | else:
63 | raise ValueError('Invalid transform')
64 |
65 | return transform
66 |
67 | def filter_data(all_data, pairs_gt, topk = 5):
68 | '''
69 | Helper function to clean data
70 | '''
71 | valid_files = []
72 | with open('/home/ubuntu/workspace/top'+str(topk)+'.txt') as f:
73 | for line in f:
74 | valid_files.append(line.strip())
75 |
76 | data, pairs, attr, obj = [], [], [], []
77 | for current in all_data:
78 | if current[0] in valid_files:
79 | data.append(current)
80 | pairs.append((current[1],current[2]))
81 | attr.append(current[1])
82 | obj.append(current[2])
83 |
84 | counter = 0
85 | for current in pairs_gt:
86 | if current in pairs:
87 | counter+=1
88 | print('Matches ', counter, ' out of ', len(pairs_gt))
89 | print('Samples ', len(data), ' out of ', len(all_data))
90 | return data, sorted(list(set(pairs))), sorted(list(set(attr))), sorted(list(set(obj)))
91 |
92 | # Dataset class now
93 |
94 | class CompositionDataset(Dataset):
95 | '''
96 | Inputs
97 | root: String of base dir of dataset
98 | phase: String train, val, test
99 | split: String dataset split
100 | subset: Boolean if true uses a subset of train at each epoch
101 | num_negs: Int, numbers of negative pairs per batch
102 | pair_dropout: Percentage of pairs to leave in current epoch
103 | '''
104 | def __init__(
105 | self,
106 | root,
107 | phase,
108 | split = 'compositional-split',
109 | model = 'resnet18',
110 | norm_family = 'imagenet',
111 | subset = False,
112 | num_negs = 1,
113 | pair_dropout = 0.0,
114 | update_features = False,
115 | return_images = False,
116 | train_only = False,
117 | open_world=False
118 | ):
119 | self.root = root
120 | self.phase = phase
121 | self.split = split
122 | self.num_negs = num_negs
123 | self.pair_dropout = pair_dropout
124 | self.norm_family = norm_family
125 | self.return_images = return_images
126 | self.update_features = update_features
127 | self.feat_dim = 512 if 'resnet18' in model else 2048 # todo, unify this with models
128 | self.open_world = open_world
129 |
130 | self.attrs, self.objs, self.pairs, self.train_pairs, \
131 | self.val_pairs, self.test_pairs = self.parse_split()
132 | self.train_data, self.val_data, self.test_data = self.get_split_info()
133 | self.full_pairs = list(product(self.attrs,self.objs))
134 |
135 | # Clean only was here
136 | self.obj2idx = {obj: idx for idx, obj in enumerate(self.objs)}
137 | self.attr2idx = {attr : idx for idx, attr in enumerate(self.attrs)}
138 | if self.open_world:
139 | self.pairs = self.full_pairs
140 |
141 | self.all_pair2idx = {pair: idx for idx, pair in enumerate(self.pairs)}
142 |
143 | if train_only and self.phase == 'train':
144 | print('Using only train pairs')
145 | self.pair2idx = {pair : idx for idx, pair in enumerate(self.train_pairs)}
146 | else:
147 | print('Using all pairs')
148 | self.pair2idx = {pair : idx for idx, pair in enumerate(self.pairs)}
149 |
150 | if self.phase == 'train':
151 | self.data = self.train_data
152 | elif self.phase == 'val':
153 | self.data = self.val_data
154 | elif self.phase == 'test':
155 | self.data = self.test_data
156 | elif self.phase == 'all':
157 | print('Using all data')
158 | self.data = self.train_data + self.val_data + self.test_data
159 | else:
160 | raise ValueError('Invalid training phase')
161 |
162 | self.all_data = self.train_data + self.val_data + self.test_data
163 | print('Dataset loaded')
164 | print('Train pairs: {}, Validation pairs: {}, Test Pairs: {}'.format(
165 | len(self.train_pairs), len(self.val_pairs), len(self.test_pairs)))
166 | print('Train images: {}, Validation images: {}, Test images: {}'.format(
167 | len(self.train_data), len(self.val_data), len(self.test_data)))
168 |
169 | if subset:
170 | ind = np.arange(len(self.data))
171 | ind = ind[::len(ind) // 1000]
172 | self.data = [self.data[i] for i in ind]
173 |
174 |
175 | # Keeping a list of all pairs that occur with each object
176 | self.obj_affordance = {}
177 | self.train_obj_affordance = {}
178 | for _obj in self.objs:
179 | candidates = [attr for (_, attr, obj) in self.train_data+self.test_data if obj==_obj]
180 | self.obj_affordance[_obj] = list(set(candidates))
181 |
182 | candidates = [attr for (_, attr, obj) in self.train_data if obj==_obj]
183 | self.train_obj_affordance[_obj] = list(set(candidates))
184 |
185 | self.sample_indices = list(range(len(self.data)))
186 | self.sample_pairs = self.train_pairs
187 |
188 | # Load based on what to output
189 | self.transform = dataset_transform(self.phase, self.norm_family)
190 | self.loader = ImageLoader(ospj(self.root, 'images'))
191 | if not self.update_features:
192 | feat_file = ospj(root, model+'_featurers.t7')
193 | print(f'Using {model} and feature file {feat_file}')
194 | if not os.path.exists(feat_file):
195 | with torch.no_grad():
196 | self.generate_features(feat_file, model)
197 | self.phase = phase
198 | activation_data = torch.load(feat_file)
199 | self.activations = dict(
200 | zip(activation_data['files'], activation_data['features']))
201 | self.feat_dim = activation_data['features'].size(1)
202 | print('{} activations loaded'.format(len(self.activations)))
203 |
204 |
205 | def parse_split(self):
206 | '''
207 | Helper function to read splits of object-attribute pairs
208 | Returns
209 | all_attrs: List of all attributes
210 | all_objs: List of all objects
211 | all_pairs: List of all combination of attrs and objs
212 | tr_pairs: List of train pairs of attrs and objs
213 | vl_pairs: List of validation pairs of attrs and objs
214 | ts_pairs: List of test pairs of attrs and objs
215 | '''
216 | def parse_pairs(pair_list):
217 | '''
218 | Helper function to parse each phase to object-attribute vectors
219 | Inputs
220 | pair_list: path to textfile
221 | '''
222 | with open(pair_list, 'r') as f:
223 | pairs = f.read().strip().split('\n')
224 | pairs = [line.split() for line in pairs]
225 | pairs = list(map(tuple, pairs))
226 |
227 | attrs, objs = zip(*pairs)
228 | return attrs, objs, pairs
229 |
230 | tr_attrs, tr_objs, tr_pairs = parse_pairs(
231 | ospj(self.root, self.split, 'train_pairs.txt')
232 | )
233 | vl_attrs, vl_objs, vl_pairs = parse_pairs(
234 | ospj(self.root, self.split, 'val_pairs.txt')
235 | )
236 | ts_attrs, ts_objs, ts_pairs = parse_pairs(
237 | ospj(self.root, self.split, 'test_pairs.txt')
238 | )
239 |
240 | #now we compose all objs, attrs and pairs
241 | all_attrs, all_objs = sorted(
242 | list(set(tr_attrs + vl_attrs + ts_attrs))), sorted(
243 | list(set(tr_objs + vl_objs + ts_objs)))
244 | all_pairs = sorted(list(set(tr_pairs + vl_pairs + ts_pairs)))
245 |
246 | return all_attrs, all_objs, all_pairs, tr_pairs, vl_pairs, ts_pairs
247 |
248 | def get_split_info(self):
249 | '''
250 | Helper method to read image, attrs, objs samples
251 |
252 | Returns
253 | train_data, val_data, test_data: List of tuple of image, attrs, obj
254 | '''
255 | data = torch.load(ospj(self.root, 'metadata_{}.t7'.format(self.split)))
256 |
257 | train_data, val_data, test_data = [], [], []
258 |
259 | for instance in data:
260 | image, attr, obj, settype = instance['image'], instance['attr'], \
261 | instance['obj'], instance['set']
262 | curr_data = [image, attr, obj]
263 |
264 | if attr == 'NA' or (attr, obj) not in self.pairs or settype == 'NA':
265 | # Skip incomplete pairs, unknown pairs and unknown set
266 | continue
267 |
268 | if settype == 'train':
269 | train_data.append(curr_data)
270 | elif settype == 'val':
271 | val_data.append(curr_data)
272 | else:
273 | test_data.append(curr_data)
274 |
275 | return train_data, val_data, test_data
276 |
277 | def get_dict_data(self, data, pairs):
278 | data_dict = {}
279 | for current in pairs:
280 | data_dict[current] = []
281 |
282 | for current in data:
283 | image, attr, obj = current
284 | data_dict[(attr, obj)].append(image)
285 |
286 | return data_dict
287 |
288 |
289 | def reset_dropout(self):
290 | '''
291 | Helper function to sample new subset of data containing a subset of pairs of objs and attrs
292 | '''
293 | self.sample_indices = list(range(len(self.data)))
294 | self.sample_pairs = self.train_pairs
295 |
296 | # Using sampling from random instead of 2 step numpy
297 | n_pairs = int((1 - self.pair_dropout) * len(self.train_pairs))
298 |
299 | self.sample_pairs = random.sample(self.train_pairs, n_pairs)
300 | print('Sampled new subset')
301 | print('Using {} pairs out of {} pairs right now'.format(
302 | n_pairs, len(self.train_pairs)))
303 |
304 | self.sample_indices = [ i for i in range(len(self.data))
305 | if (self.data[i][1], self.data[i][2]) in self.sample_pairs
306 | ]
307 | print('Using {} images out of {} images right now'.format(
308 | len(self.sample_indices), len(self.data)))
309 |
310 | def sample_negative(self, attr, obj):
311 | '''
312 | Inputs
313 | attr: String of valid attribute
314 | obj: String of valid object
315 | Returns
316 | Tuple of a different attribute, object indexes
317 | '''
318 | new_attr, new_obj = self.sample_pairs[np.random.choice(
319 | len(self.sample_pairs))]
320 |
321 | while new_attr == attr and new_obj == obj:
322 | new_attr, new_obj = self.sample_pairs[np.random.choice(
323 | len(self.sample_pairs))]
324 |
325 | return (self.attr2idx[new_attr], self.obj2idx[new_obj])
326 |
327 | def sample_affordance(self, attr, obj):
328 | '''
329 | Inputs
330 | attr: String of valid attribute
331 | obj: String of valid object
332 | Return
333 | Idx of a different attribute for the same object
334 | '''
335 | new_attr = np.random.choice(self.obj_affordance[obj])
336 |
337 | while new_attr == attr:
338 | new_attr = np.random.choice(self.obj_affordance[obj])
339 |
340 | return self.attr2idx[new_attr]
341 |
342 | def sample_train_affordance(self, attr, obj):
343 | '''
344 | Inputs
345 | attr: String of valid attribute
346 | obj: String of valid object
347 | Return
348 | Idx of a different attribute for the same object from the training pairs
349 | '''
350 | new_attr = np.random.choice(self.train_obj_affordance[obj])
351 |
352 | while new_attr == attr:
353 | new_attr = np.random.choice(self.train_obj_affordance[obj])
354 |
355 | return self.attr2idx[new_attr]
356 |
357 | def generate_features(self, out_file, model):
358 | '''
359 | Inputs
360 | out_file: Path to save features
361 | model: String of extraction model
362 | '''
363 | # data = self.all_data
364 | data = ospj(self.root,'images')
365 | files_before = glob(ospj(data, '**', '*.jpg'), recursive=True)
366 | files_all = []
367 | for current in files_before:
368 | parts = current.split('/')
369 | if "cgqa" in self.root:
370 | files_all.append(parts[-1])
371 | else:
372 | files_all.append(os.path.join(parts[-2],parts[-1]))
373 | transform = dataset_transform('test', self.norm_family)
374 | feat_extractor = get_image_extractor(arch = model).eval()
375 | feat_extractor = feat_extractor.to(device)
376 |
377 | image_feats = []
378 | image_files = []
379 | for chunk in tqdm(
380 | chunks(files_all, 512), total=len(files_all) // 512, desc=f'Extracting features {model}'):
381 |
382 | files = chunk
383 | imgs = list(map(self.loader, files))
384 | imgs = list(map(transform, imgs))
385 | feats = feat_extractor(torch.stack(imgs, 0).to(device))
386 | image_feats.append(feats.data.cpu())
387 | image_files += files
388 | image_feats = torch.cat(image_feats, 0)
389 | print('features for %d images generated' % (len(image_files)))
390 |
391 | torch.save({'features': image_feats, 'files': image_files}, out_file)
392 |
393 | def __getitem__(self, index):
394 | '''
395 | Call for getting samples
396 | '''
397 | index = self.sample_indices[index]
398 |
399 | image, attr, obj = self.data[index]
400 |
401 | # Decide what to output
402 | if not self.update_features:
403 | img = self.activations[image]
404 | else:
405 | img = self.loader(image)
406 | img = self.transform(img)
407 |
408 | data = [img, self.attr2idx[attr], self.obj2idx[obj], self.pair2idx[(attr, obj)]]
409 |
410 | if self.phase == 'train':
411 | all_neg_attrs = []
412 | all_neg_objs = []
413 |
414 | for curr in range(self.num_negs):
415 | neg_attr, neg_obj = self.sample_negative(attr, obj) # negative for triplet lose
416 | all_neg_attrs.append(neg_attr)
417 | all_neg_objs.append(neg_obj)
418 |
419 | neg_attr, neg_obj = torch.LongTensor(all_neg_attrs), torch.LongTensor(all_neg_objs)
420 |
421 | #note here
422 | if len(self.train_obj_affordance[obj])>1:
423 | inv_attr = self.sample_train_affordance(attr, obj) # attribute for inverse regularizer
424 | else:
425 | inv_attr = (all_neg_attrs[0])
426 |
427 | comm_attr = self.sample_affordance(inv_attr, obj) # attribute for commutative regularizer
428 |
429 |
430 | data += [neg_attr, neg_obj, inv_attr, comm_attr]
431 |
432 | # Return image paths if requested as the last element of the list
433 | if self.return_images and self.phase != 'train':
434 | data.append(image)
435 |
436 | return data
437 |
438 | def __len__(self):
439 | '''
440 | Call for length
441 | '''
442 | return len(self.sample_indices)
443 |
--------------------------------------------------------------------------------
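
For context, here is a minimal usage sketch (not part of the repository) showing how `CompositionDataset` above can be instantiated directly. `DATA_ROOT` is assumed to be the folder passed to `utils/download_data.sh` in the README, and the keyword values mirror the provided configs:
```
from torch.utils.data import DataLoader

from data.dataset import CompositionDataset

# Hypothetical paths/values for illustration only.
trainset = CompositionDataset(
    root='DATA_ROOT/mit-states',           # base dir of one dataset
    phase='train',                         # 'train', 'val', 'test' or 'all'
    split='compositional-split-natural',   # split name used by the configs
    model='resnet18',                      # backbone for the cached features
    num_negs=1,                            # negatives per positive sample
    update_features=False,                 # False -> load pre-extracted features
    train_only=True,                       # pair indices over train pairs only
)

loader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=8)

# In the 'train' phase each item is [img, attr_idx, obj_idx, pair_idx,
# neg_attr, neg_obj, inv_attr, comm_attr] (see __getitem__ above).
img, attrs, objs, pairs = next(iter(loader))[:4]
```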
/environment.yml:
--------------------------------------------------------------------------------
1 | name: czsl
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - _libgcc_mutex=0.1=main
7 | - blas=1.0=mkl
8 | - ca-certificates=2021.1.19=h06a4308_0
9 | - certifi=2020.12.5=py38h06a4308_0
10 | - cudatoolkit=11.0.221=h6bb024c_0
11 | - freetype=2.10.4=h5ab3b9f_0
12 | - intel-openmp=2020.2=254
13 | - jpeg=9b=h024ee3a_2
14 | - lcms2=2.11=h396b838_0
15 | - ld_impl_linux-64=2.33.1=h53a641e_7
16 | - libedit=3.1.20191231=h14c3975_1
17 | - libffi=3.3=he6710b0_2
18 | - libgcc-ng=9.1.0=hdf63c60_0
19 | - libgfortran-ng=7.3.0=hdf63c60_0
20 | - libpng=1.6.37=hbc83047_0
21 | - libstdcxx-ng=9.1.0=hdf63c60_0
22 | - libtiff=4.1.0=h2733197_1
23 | - libuv=1.40.0=h7b6447c_0
24 | - lz4-c=1.9.3=h2531618_0
25 | - mkl=2020.2=256
26 | - mkl-service=2.3.0=py38he904b0f_0
27 | - mkl_fft=1.2.0=py38h23d657b_0
28 | - mkl_random=1.1.1=py38h0573a6f_0
29 | - ncurses=6.2=he6710b0_1
30 | - ninja=1.10.2=py38hff7bd54_0
31 | - numpy=1.19.2=py38h54aff64_0
32 | - numpy-base=1.19.2=py38hfa32c7d_0
33 | - olefile=0.46=py_0
34 | - openssl=1.1.1i=h27cfd23_0
35 | - pillow=8.1.0=py38he98fc37_0
36 | - pip=20.3.3=py38h06a4308_0
37 | - python=3.8.5=h7579374_1
38 | - pytorch=1.7.1=py3.8_cuda11.0.221_cudnn8.0.5_0
39 | - readline=8.1=h27cfd23_0
40 | - scipy=1.5.2=py38h0b6359f_0
41 | - setuptools=52.0.0=py38h06a4308_0
42 | - six=1.15.0=py38h06a4308_0
43 | - sqlite=3.33.0=h62c20be_0
44 | - tk=8.6.10=hbc83047_0
45 | - torchaudio=0.7.2=py38
46 | - torchvision=0.8.2=py38_cu110
47 | - typing_extensions=3.7.4.3=pyh06a4308_0
48 | - wheel=0.36.2=pyhd3eb1b0_0
49 | - xz=5.2.5=h7b6447c_0
50 | - yaml=0.2.5=h7b6447c_0
51 | - zlib=1.2.11=h7b6447c_3
52 | - zstd=1.4.5=h9ceee32_0
53 | - pip:
54 | - absl-py==0.11.0
55 | - cachetools==4.2.1
56 | - chardet==4.0.0
57 | - fasttext==0.9.2
58 | - gensim==3.8.3
59 | - google-auth==1.24.0
60 | - google-auth-oauthlib==0.4.2
61 | - grpcio==1.35.0
62 | - idna==2.10
63 | - markdown==3.3.3
64 | - oauthlib==3.1.0
65 | - protobuf==3.14.0
66 | - pyasn1==0.4.8
67 | - pyasn1-modules==0.2.8
68 | - pybind11==2.6.2
69 | - pyyaml==5.4.1
70 | - requests==2.25.1
71 | - requests-oauthlib==1.3.0
72 | - rsa==4.7
73 | - smart-open==4.1.2
74 | - tensorboard==2.4.1
75 | - tensorboard-plugin-wit==1.8.0
76 | - tqdm==4.56.0
77 | - urllib3==1.26.3
78 | - werkzeug==1.0.1
79 |
80 |
--------------------------------------------------------------------------------
/flags.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | DATA_FOLDER = "ROOT_FOLDER"
4 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
5 |
6 | parser.add_argument('--config', default='configs/args.yml', help='path of the config file (training only)')
7 | parser.add_argument('--dataset', default='mitstates', help='mitstates|utzappos|cgqa')
8 | parser.add_argument('--data_dir', default='mit-states', help='local path to data root dir from ' + DATA_FOLDER)
9 | parser.add_argument('--logpath', default=None, help='Path to dir where the logs are stored (test only)')
10 | parser.add_argument('--splitname', default='compositional-split-natural', help="dataset split")
11 | parser.add_argument('--cv_dir', default='logs/', help='dir to save checkpoints and logs to')
12 | parser.add_argument('--name', default='temp', help='Name of exp used to name models')
13 | parser.add_argument('--load', default=None, help='path to checkpoint to load from')
14 | parser.add_argument('--image_extractor', default = 'resnet18', help = 'Feature extractor model')
15 | parser.add_argument('--norm_family', default = 'imagenet', help = 'Normalization values from dataset')
16 | parser.add_argument('--num_negs', type=int, default=1, help='Number of negatives to sample per positive (triplet loss)')
17 | parser.add_argument('--pair_dropout', type=float, default=0.0, help='Each epoch drop this fraction of train pairs')
18 | parser.add_argument('--test_set', default='val', help='val|test mode')
19 | parser.add_argument('--clean_only', action='store_true', default=False, help='use only clean subset of data (mitstates)')
20 | parser.add_argument('--subset', action='store_true', default=False, help='test on a 1000 image subset (debug purpose)')
21 | parser.add_argument('--open_world', action='store_true', default=False, help='perform open world experiment')
22 | parser.add_argument('--test_batch_size', type=int, default=32, help="Batch size at test/eval time")
23 | parser.add_argument('--cpu_eval', action='store_true', help='Perform test on cpu')
24 |
25 | # Model parameters
26 | parser.add_argument('--model', default='graphfull', help='visprodNN|redwine|labelembed+|attributeop|tmn|compcos')
27 | parser.add_argument('--emb_dim', type=int, default=300, help='dimension of share embedding space')
28 | parser.add_argument('--nlayers', type=int, default=3, help='Layers in the image embedder')
29 | parser.add_argument('--nmods', type=int, default=24, help='number of mods per layer for TMN')
30 | parser.add_argument('--embed_rank', type=int, default=64, help='intermediate dimension in the gating model for TMN')
31 | parser.add_argument('--bias', type=float, default=1e3, help='Bias value for unseen concepts')
32 | parser.add_argument('--update_features', action = 'store_true', default=False, help='If specified, train feature extractor')
33 | parser.add_argument('--freeze_features', action = 'store_true', default=False, help='If specified, put extractor in eval mode')
34 | parser.add_argument('--emb_init', default=None, help='w2v|ft|gl|glove|word2vec|fasttext, name of embeddings to use for initializing the primitives')
35 | parser.add_argument('--clf_init', action='store_true', default=False, help='initialize inputs with SVM weights')
36 | parser.add_argument('--static_inp', action='store_true', default=False, help='do not optimize primitives representations')
37 | parser.add_argument('--composition', default='mlp_add', help='add|mul|mlp|mlp_add, how to compose primitives')
38 | parser.add_argument('--relu', action='store_true', default=False, help='Use relu in image embedder')
39 | parser.add_argument('--dropout', action='store_true', default=False, help='Use dropout in image embedder')
40 | parser.add_argument('--norm', action='store_true', default=False, help='Use normalization in image embedder')
41 | parser.add_argument('--train_only', action='store_true', default=False, help='Optimize only for train pairs')
42 |
43 | #CGE
44 | parser.add_argument('--graph', action='store_true', default=False, help='graph l2 distance triplet') # Do we need this
45 | parser.add_argument('--graph_init', default=None, help='filename, file from which initializing the nodes and adjacency matrix of the graph')
46 | parser.add_argument('--gcn_type', default='gcn', help='GCN Version')
47 |
48 | # Forward
49 | parser.add_argument('--eval_type', default='dist', help='dist|prod|direct, function for computing the predictions')
50 |
51 | # Primitive-based loss
52 | parser.add_argument('--lambda_aux', type=float, default=0.0, help='weight of auxiliary losses on the primitives')
53 |
54 | # AoP
55 | parser.add_argument('--lambda_inv', type=float, default=0.0, help='weight of inverse consistency loss')
56 | parser.add_argument('--lambda_comm', type=float, default=0.0, help='weight of commutative loss')
57 | parser.add_argument('--lambda_ant', type=float, default=0.0, help='weight of antonyms loss')
58 |
59 |
60 | # SymNet
61 | parser.add_argument("--lambda_trip", type=float, default=0, help='weight of triplet loss')
62 | parser.add_argument("--lambda_sym", type=float, default=0, help='weight of symmetry loss')
63 | parser.add_argument("--lambda_axiom", type=float, default=0, help='weight of axiom loss')
64 | parser.add_argument("--lambda_cls_attr", type=float, default=0, help='weight of classification loss on attributes')
65 | parser.add_argument("--lambda_cls_obj", type=float, default=0, help='weight of classification loss on objects')
66 |
67 | # CompCos (for the margin, see below)
68 | parser.add_argument('--cosine_scale', type=float, default=20,help="Scale for cosine similarity")
69 | parser.add_argument('--epoch_max_margin', type=float, default=15,help="Epoch of max margin")
70 | parser.add_argument('--update_feasibility_every', type=float, default=1,help="Frequency of feasibility scores update in epochs")
71 | parser.add_argument('--hard_masking', action='store_true', default=False, help='hard masking during test time')
72 | parser.add_argument('--threshold', default=None,help="Apply a specific threshold at test time for the hard masking")
73 | parser.add_argument('--threshold_trials', type=int, default=50,help="how many threshold values to try")
74 |
75 | # Hyperparameters
76 | parser.add_argument('--topk', type=int, default=1,help="Compute topk accuracy")
77 | parser.add_argument('--margin', type=float, default=2,help="Margin for triplet loss or feasibility scores in CompCos")
78 | parser.add_argument('--workers', type=int, default=8,help="Number of workers")
79 | parser.add_argument('--batch_size', type=int, default=512,help="Training batch size")
80 | parser.add_argument('--lr', type=float, default=5e-5,help="Learning rate")
81 | parser.add_argument('--lrg', type=float, default=1e-3,help="Learning rate feature extractor")
82 | parser.add_argument('--wd', type=float, default=5e-5,help="Weight decay")
83 | parser.add_argument('--save_every', type=int, default=10000,help="Frequency of snapshots in epochs")
84 | parser.add_argument('--eval_val_every', type=int, default=1,help="Frequency of eval in epochs")
85 | parser.add_argument('--max_epochs', type=int, default=800,help="Max number of epochs")
86 | parser.add_argument("--fc_emb", default='768,1024,1200',help="Image embedder layer config")
87 | parser.add_argument("--gr_emb", default='d4096,d',help="graph layers config")
88 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ExplainableML/czsl/22ab1aca82c64d4f351a20a887986147edfb1905/models/__init__.py
--------------------------------------------------------------------------------
/models/common.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import copy
6 | from scipy.stats import hmean
7 |
8 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
9 |
10 | class MLP(nn.Module):
11 | '''
12 | Baseclass to create a simple MLP
13 | Inputs
14 | inp_dim: Int, Input dimension
15 | out_dim: Int, Output dimension
16 | num_layers: Int, Number of hidden layers
17 | relu: Bool, Use non linear function at output
18 | bias: Bool, Use bias
19 | '''
20 | def __init__(self, inp_dim, out_dim, num_layers = 1, relu = True, bias = True, dropout = False, norm = False, layers = []):
21 | super(MLP, self).__init__()
22 | mod = []
23 | incoming = inp_dim
24 | for layer in range(num_layers - 1):
25 | if len(layers) == 0:
26 | outgoing = incoming
27 | else:
28 | outgoing = layers.pop(0)
29 | mod.append(nn.Linear(incoming, outgoing, bias = bias))
30 |
31 | incoming = outgoing
32 | if norm:
33 | mod.append(nn.LayerNorm(outgoing))
34 | # mod.append(nn.BatchNorm1d(outgoing))
35 | mod.append(nn.ReLU(inplace = True))
36 | # mod.append(nn.LeakyReLU(inplace=True, negative_slope=0.2))
37 | if dropout:
38 | mod.append(nn.Dropout(p = 0.5))
39 |
40 | mod.append(nn.Linear(incoming, out_dim, bias = bias))
41 |
42 | if relu:
43 | mod.append(nn.ReLU(inplace = True))
44 | # mod.append(nn.LeakyReLU(inplace=True, negative_slope=0.2))
45 | self.mod = nn.Sequential(*mod)
46 |
47 | def forward(self, x):
48 | return self.mod(x)
49 |
50 | class Reshape(nn.Module):
51 | def __init__(self, *args):
52 | super(Reshape, self).__init__()
53 | self.shape = args
54 |
55 | def forward(self, x):
56 | return x.view(self.shape)
57 |
58 | def calculate_margines(domain_embedding, gt, margin_range = 5):
59 | '''
60 | domain_embedding: pairs * feats
61 | gt: batch * feats
62 | '''
63 | batch_size, pairs, features = gt.shape[0], domain_embedding.shape[0], domain_embedding.shape[1]
64 | gt_expanded = gt[:,None,:].expand(-1, pairs, -1)
65 | domain_embedding_expanded = domain_embedding[None, :, :].expand(batch_size, -1, -1)
66 | margin = (gt_expanded - domain_embedding_expanded)**2
67 | margin = margin.sum(2)
68 | max_margin, _ = torch.max(margin, dim = 0)
69 | margin /= max_margin
70 | margin *= margin_range
71 | return margin
72 |
73 | def l2_all_batched(image_embedding, domain_embedding):
74 | '''
75 | Image Embedding: Tensor of Batch_size * pairs * Feature_dim
76 | domain_embedding: Tensor of pairs * Feature_dim
77 | '''
78 |     batch_size, pairs = image_embedding.shape[0], image_embedding.shape[1]
79 |     domain_embedding_extended = domain_embedding[None, :, :].expand(batch_size, -1, -1)
80 |     l2_loss = (image_embedding - domain_embedding_extended) ** 2
81 | l2_loss = l2_loss.sum(2)
82 | l2_loss = l2_loss.sum() / l2_loss.numel()
83 | return l2_loss
84 |
85 | def same_domain_triplet_loss(image_embedding, trip_images, gt, hard_k = None, margin = 2):
86 | '''
87 | Image Embedding: Tensor of Batch_size * Feature_dim
88 | Triplet Images: Tensor of Batch_size * num_pairs * Feature_dim
89 | GT: Tensor of Batch_size
90 | '''
91 | batch_size, pairs, features = trip_images.shape
92 | batch_iterator = torch.arange(batch_size).to(device)
93 | image_embedding_expanded = image_embedding[:,None,:].expand(-1, pairs, -1)
94 |
95 | diff = (image_embedding_expanded - trip_images)**2
96 | diff = diff.sum(2)
97 |
98 | positive_anchor = diff[batch_iterator, gt][:,None]
99 | positive_anchor = positive_anchor.expand(-1, pairs)
100 |
101 | # Calculating triplet loss
102 | triplet_loss = positive_anchor - diff + margin
103 |
104 | # Setting positive anchor loss to 0
105 | triplet_loss[batch_iterator, gt] = 0
106 |
107 | # Removing easy triplets
108 | triplet_loss[triplet_loss < 0] = 0
109 |
110 | # If only mining hard triplets
111 | if hard_k:
112 | triplet_loss, _ = triplet_loss.topk(hard_k)
113 |
114 | # Counting number of valid pairs
115 | num_positive_triplets = triplet_loss[triplet_loss > 1e-16].size(0)
116 |
117 | # Calculating the final loss
118 | triplet_loss = triplet_loss.sum() / (num_positive_triplets + 1e-16)
119 |
120 | return triplet_loss
121 |
122 |
123 | def cross_domain_triplet_loss(image_embedding, domain_embedding, gt, hard_k = None, margin = 2):
124 | '''
125 | Image Embedding: Tensor of Batch_size * Feature_dim
126 | Domain Embedding: Tensor of Num_pairs * Feature_dim
127 | gt: Tensor of Batch_size with ground truth labels
128 | margin: Float of margin
129 | Returns:
130 | Triplet loss of all valid triplets
131 | '''
132 | batch_size, pairs, features = image_embedding.shape[0], domain_embedding.shape[0], domain_embedding.shape[1]
133 | batch_iterator = torch.arange(batch_size).to(device)
134 | # Now dimensions will be Batch_size * Num_pairs * Feature_dim
135 | image_embedding = image_embedding[:,None,:].expand(-1, pairs, -1)
136 | domain_embedding = domain_embedding[None,:,:].expand(batch_size, -1, -1)
137 |
138 | # Calculating difference
139 | diff = (image_embedding - domain_embedding)**2
140 | diff = diff.sum(2)
141 |
142 | # Getting the positive pair
143 | positive_anchor = diff[batch_iterator, gt][:,None]
144 | positive_anchor = positive_anchor.expand(-1, pairs)
145 |
146 | # Calculating triplet loss
147 | triplet_loss = positive_anchor - diff + margin
148 |
149 | # Setting positive anchor loss to 0
150 | triplet_loss[batch_iterator, gt] = 0
151 |
152 | # Removing easy triplets
153 | triplet_loss[triplet_loss < 0] = 0
154 |
155 | # If only mining hard triplets
156 | if hard_k:
157 | triplet_loss, _ = triplet_loss.topk(hard_k)
158 |
159 | # Counting number of valid pairs
160 | num_positive_triplets = triplet_loss[triplet_loss > 1e-16].size(0)
161 |
162 | # Calculating the final loss
163 | triplet_loss = triplet_loss.sum() / (num_positive_triplets + 1e-16)
164 |
165 | return triplet_loss
166 |
167 | def same_domain_triplet_loss_old(image_embedding, positive_anchor, negative_anchor, margin = 2):
168 | '''
169 | Image Embedding: Tensor of Batch_size * Feature_dim
170 | Positive anchor: Tensor of Batch_size * Feature_dim
171 | negative anchor: Tensor of Batch_size * negs *Feature_dim
172 | '''
173 | batch_size, negs, features = negative_anchor.shape
174 | dist_pos = (image_embedding - positive_anchor)**2
175 | dist_pos = dist_pos.sum(1)
176 | dist_pos = dist_pos[:, None].expand(-1, negs)
177 |
178 | image_embedding_expanded = image_embedding[:,None,:].expand(-1, negs, -1)
179 | dist_neg = (image_embedding_expanded - negative_anchor)**2
180 | dist_neg = dist_neg.sum(2)
181 |
182 | triplet_loss = dist_pos - dist_neg + margin
183 | triplet_loss[triplet_loss < 0] = 0
184 | num_positive_triplets = triplet_loss[triplet_loss > 1e-16].size(0)
185 |
186 | triplet_loss = triplet_loss.sum() / (num_positive_triplets + 1e-16)
187 | return triplet_loss
188 |
189 |
190 | def pairwise_distances(x, y=None):
191 | '''
192 | Input: x is a Nxd matrix
193 |            y is an optional Mxd matrix
194 | Output: dist is a NxM matrix where dist[i,j] is the square norm between x[i,:] and y[j,:]
195 | if y is not given then use 'y=x'.
196 | i.e. dist[i,j] = ||x[i,:]-y[j,:]||^2
197 | '''
198 | x_norm = (x**2).sum(1).view(-1, 1)
199 | if y is not None:
200 | y_t = torch.transpose(y, 0, 1)
201 | y_norm = (y**2).sum(1).view(1, -1)
202 | else:
203 | y_t = torch.transpose(x, 0, 1)
204 | y_norm = x_norm.view(1, -1)
205 |
206 | dist = x_norm + y_norm - 2.0 * torch.mm(x, y_t)
207 | # Ensure diagonal is zero if x=y
208 | # if y is None:
209 | # dist = dist - torch.diag(dist.diag)
210 | return torch.clamp(dist, 0.0, np.inf)
211 |
212 | class Evaluator:
213 |
214 | def __init__(self, dset, model):
215 |
216 | self.dset = dset
217 |
218 | # Convert text pairs to idx tensors: [('sliced', 'apple'), ('ripe', 'apple'), ...] --> torch.LongTensor([[0,1],[1,1], ...])
219 | pairs = [(dset.attr2idx[attr], dset.obj2idx[obj]) for attr, obj in dset.pairs]
220 | self.train_pairs = [(dset.attr2idx[attr], dset.obj2idx[obj]) for attr, obj in dset.train_pairs]
221 | self.pairs = torch.LongTensor(pairs)
222 |
223 | # Mask over pairs that occur in closed world
224 | # Select set based on phase
225 | if dset.phase == 'train':
226 | print('Evaluating with train pairs')
227 | test_pair_set = set(dset.train_pairs)
228 | test_pair_gt = set(dset.train_pairs)
229 | elif dset.phase == 'val':
230 | print('Evaluating with validation pairs')
231 | test_pair_set = set(dset.val_pairs + dset.train_pairs)
232 | test_pair_gt = set(dset.val_pairs)
233 | else:
234 | print('Evaluating with test pairs')
235 | test_pair_set = set(dset.test_pairs + dset.train_pairs)
236 | test_pair_gt = set(dset.test_pairs)
237 |
238 | self.test_pair_dict = [(dset.attr2idx[attr], dset.obj2idx[obj]) for attr, obj in test_pair_gt]
239 | self.test_pair_dict = dict.fromkeys(self.test_pair_dict, 0)
240 |
241 | # dict values are pair val, score, total
242 | for attr, obj in test_pair_gt:
243 | pair_val = dset.pair2idx[(attr,obj)]
244 | key = (dset.attr2idx[attr], dset.obj2idx[obj])
245 | self.test_pair_dict[key] = [pair_val, 0, 0]
246 |
247 | if dset.open_world:
248 | masks = [1 for _ in dset.pairs]
249 | else:
250 | masks = [1 if pair in test_pair_set else 0 for pair in dset.pairs]
251 |
252 | self.closed_mask = torch.BoolTensor(masks)
253 | # Mask of seen concepts
254 | seen_pair_set = set(dset.train_pairs)
255 | mask = [1 if pair in seen_pair_set else 0 for pair in dset.pairs]
256 | self.seen_mask = torch.BoolTensor(mask)
257 |
258 | # Object specific mask over which pairs occur in the object oracle setting
259 | oracle_obj_mask = []
260 | for _obj in dset.objs:
261 | mask = [1 if _obj == obj else 0 for attr, obj in dset.pairs]
262 | oracle_obj_mask.append(torch.BoolTensor(mask))
263 | self.oracle_obj_mask = torch.stack(oracle_obj_mask, 0)
264 |
265 | # Decide if the model under evaluation is a manifold model or not
266 | self.score_model = self.score_manifold_model
267 |
268 | # Generate mask for each settings, mask scores, and get prediction labels
269 | def generate_predictions(self, scores, obj_truth, bias = 0.0, topk = 5): # (Batch, #pairs)
270 | '''
271 | Inputs
272 | scores: Output scores
273 | obj_truth: Ground truth object
274 | Returns
275 | results: dict of results in 3 settings
276 | '''
277 | def get_pred_from_scores(_scores, topk):
278 | '''
279 |             Given a tensor of scores, returns the top-k attribute and object
280 |             predictions for each sample
281 | '''
282 |             _, pair_pred = _scores.topk(topk, dim = 1) # topk returns indices of the k largest values
283 | pair_pred = pair_pred.contiguous().view(-1)
284 | attr_pred, obj_pred = self.pairs[pair_pred][:, 0].view(-1, topk), \
285 | self.pairs[pair_pred][:, 1].view(-1, topk)
286 | return (attr_pred, obj_pred)
287 |
288 | results = {}
289 | orig_scores = scores.clone()
290 |         mask = self.seen_mask.repeat(scores.shape[0],1) # Repeat seen-pair mask for every sample in the batch
291 |         scores[~mask] += bias # Add bias to unseen pairs
292 |
293 | 
294 |         # Open world setting - no mask, all pairs of the dataset;
295 |         # the 'unbiased_open' entry uses the original scores without the seen/unseen bias
296 | results.update({'open': get_pred_from_scores(scores, topk)})
297 | results.update({'unbiased_open': get_pred_from_scores(orig_scores, topk)})
298 |         # Closed world setting - set the score for all non-test pairs to -1e10,
299 |         # which excludes pairs that are not part of the evaluation set
300 | mask = self.closed_mask.repeat(scores.shape[0], 1)
301 | closed_scores = scores.clone()
302 | closed_scores[~mask] = -1e10
303 | closed_orig_scores = orig_scores.clone()
304 | closed_orig_scores[~mask] = -1e10
305 | results.update({'closed': get_pred_from_scores(closed_scores, topk)})
306 | results.update({'unbiased_closed': get_pred_from_scores(closed_orig_scores, topk)})
307 |
308 |         # Object oracle setting - set the score to -1e10 for all pairs that do not contain the ground-truth object (the closed scores could be used here as well)
309 | mask = self.oracle_obj_mask[obj_truth]
310 | oracle_obj_scores = scores.clone()
311 | oracle_obj_scores[~mask] = -1e10
312 | oracle_obj_scores_unbiased = orig_scores.clone()
313 | oracle_obj_scores_unbiased[~mask] = -1e10
314 | results.update({'object_oracle': get_pred_from_scores(oracle_obj_scores, 1)})
315 | results.update({'object_oracle_unbiased': get_pred_from_scores(oracle_obj_scores_unbiased, 1)})
316 |
317 | return results
318 |
319 | def score_clf_model(self, scores, obj_truth, topk = 5):
320 | '''
321 | Wrapper function to call generate_predictions for CLF models
322 | '''
323 | attr_pred, obj_pred = scores
324 |
325 | # Go to CPU
326 | attr_pred, obj_pred, obj_truth = attr_pred.to('cpu'), obj_pred.to('cpu'), obj_truth.to('cpu')
327 |
328 | # Gather scores (P(a), P(o)) for all relevant (a,o) pairs
329 | # Multiply P(a) * P(o) to get P(pair)
330 | attr_subset = attr_pred.index_select(1, self.pairs[:,0]) # Return only attributes that are in our pairs
331 | obj_subset = obj_pred.index_select(1, self.pairs[:, 1])
332 | scores = (attr_subset * obj_subset) # (Batch, #pairs)
333 |
334 | results = self.generate_predictions(scores, obj_truth)
335 | results['biased_scores'] = scores
336 |
337 | return results
338 |
339 | def score_manifold_model(self, scores, obj_truth, bias = 0.0, topk = 5):
340 | '''
341 | Wrapper function to call generate_predictions for manifold models
342 | '''
343 | # Go to CPU
344 | scores = {k: v.to('cpu') for k, v in scores.items()}
345 |         obj_truth = obj_truth.to('cpu')
346 |
347 | # Gather scores for all relevant (a,o) pairs
348 | scores = torch.stack(
349 | [scores[(attr,obj)] for attr, obj in self.dset.pairs], 1
350 | ) # (Batch, #pairs)
351 | orig_scores = scores.clone()
352 | results = self.generate_predictions(scores, obj_truth, bias, topk)
353 | results['scores'] = orig_scores
354 | return results
355 |
356 | def score_fast_model(self, scores, obj_truth, bias = 0.0, topk = 5):
357 | '''
358 |         Fast path that computes biased closed-world predictions directly (used for the bias sweep when computing AUC)
359 | '''
360 |
361 | results = {}
362 |         mask = self.seen_mask.repeat(scores.shape[0],1) # Repeat seen-pair mask for every sample in the batch
363 |         scores[~mask] += bias # Add bias to unseen pairs
364 |
365 | mask = self.closed_mask.repeat(scores.shape[0], 1)
366 | closed_scores = scores.clone()
367 | closed_scores[~mask] = -1e10
368 |
369 |         _, pair_pred = closed_scores.topk(topk, dim = 1) # topk returns indices of the k largest values
370 | pair_pred = pair_pred.contiguous().view(-1)
371 | attr_pred, obj_pred = self.pairs[pair_pred][:, 0].view(-1, topk), \
372 | self.pairs[pair_pred][:, 1].view(-1, topk)
373 |
374 | results.update({'closed': (attr_pred, obj_pred)})
375 | return results
376 |
377 | def evaluate_predictions(self, predictions, attr_truth, obj_truth, pair_truth, allpred, topk = 1):
378 | # Go to CPU
379 | attr_truth, obj_truth, pair_truth = attr_truth.to('cpu'), obj_truth.to('cpu'), pair_truth.to('cpu')
380 |
381 | pairs = list(
382 | zip(list(attr_truth.numpy()), list(obj_truth.numpy())))
383 |
384 |
385 | seen_ind, unseen_ind = [], []
386 | for i in range(len(attr_truth)):
387 | if pairs[i] in self.train_pairs:
388 | seen_ind.append(i)
389 | else:
390 | unseen_ind.append(i)
391 |
392 |
393 | seen_ind, unseen_ind = torch.LongTensor(seen_ind), torch.LongTensor(unseen_ind)
394 | def _process(_scores):
395 | # Top k pair accuracy
396 | # Attribute, object and pair
397 | attr_match = (attr_truth.unsqueeze(1).repeat(1, topk) == _scores[0][:, :topk])
398 | obj_match = (obj_truth.unsqueeze(1).repeat(1, topk) == _scores[1][:, :topk])
399 |
400 | # Match of object pair
401 | match = (attr_match * obj_match).any(1).float()
402 | attr_match = attr_match.any(1).float()
403 | obj_match = obj_match.any(1).float()
404 | # Match of seen and unseen pairs
405 | seen_match = match[seen_ind]
406 | unseen_match = match[unseen_ind]
407 | ### Calculating class average accuracy
408 |
409 | # local_score_dict = copy.deepcopy(self.test_pair_dict)
410 | # for pair_gt, pair_pred in zip(pairs, match):
411 | # # print(pair_gt)
412 | # local_score_dict[pair_gt][2] += 1.0 #increase counter
413 | # if int(pair_pred) == 1:
414 | # local_score_dict[pair_gt][1] += 1.0
415 |
416 | # # Now we have hits and totals for classes in evaluation set
417 | # seen_score, unseen_score = [], []
418 | # for key, (idx, hits, total) in local_score_dict.items():
419 | # score = hits/total
420 | # if bool(self.seen_mask[idx]) == True:
421 | # seen_score.append(score)
422 | # else:
423 | # unseen_score.append(score)
424 |
425 |             seen_score, unseen_score = torch.ones(512,5), torch.ones(512,5) # placeholders: the class-average computation above is commented out
426 |
427 | return attr_match, obj_match, match, seen_match, unseen_match, \
428 | torch.Tensor(seen_score+unseen_score), torch.Tensor(seen_score), torch.Tensor(unseen_score)
429 |
430 | def _add_to_dict(_scores, type_name, stats):
431 | base = ['_attr_match', '_obj_match', '_match', '_seen_match', '_unseen_match', '_ca', '_seen_ca', '_unseen_ca']
432 | for val, name in zip(_scores, base):
433 | stats[type_name + name] = val
434 |
435 |         ##################### Match in places where the object is correct
436 | obj_oracle_match = (attr_truth == predictions['object_oracle'][0][:, 0]).float() #object is already conditioned
437 | obj_oracle_match_unbiased = (attr_truth == predictions['object_oracle_unbiased'][0][:, 0]).float()
438 |
439 | stats = dict(obj_oracle_match = obj_oracle_match, obj_oracle_match_unbiased = obj_oracle_match_unbiased)
440 |
441 | #################### Closed world
442 | closed_scores = _process(predictions['closed'])
443 | unbiased_closed = _process(predictions['unbiased_closed'])
444 | _add_to_dict(closed_scores, 'closed', stats)
445 | _add_to_dict(unbiased_closed, 'closed_ub', stats)
446 |
447 | #################### Calculating AUC
448 | scores = predictions['scores']
449 | # getting score for each ground truth class
450 | correct_scores = scores[torch.arange(scores.shape[0]), pair_truth][unseen_ind]
451 |
452 | # Getting top predicted score for these unseen classes
453 | max_seen_scores = predictions['scores'][unseen_ind][:, self.seen_mask].topk(topk, dim=1)[0][:, topk - 1]
454 |
455 | # Getting difference between these scores
456 | unseen_score_diff = max_seen_scores - correct_scores
457 |
458 | # Getting matched classes at max bias for diff
459 | unseen_matches = stats['closed_unseen_match'].bool()
460 | correct_unseen_score_diff = unseen_score_diff[unseen_matches] - 1e-4
461 |
462 | # sorting these diffs
463 | correct_unseen_score_diff = torch.sort(correct_unseen_score_diff)[0]
464 | magic_binsize = 20
465 | # getting step size for these bias values
466 | bias_skip = max(len(correct_unseen_score_diff) // magic_binsize, 1)
467 | # Getting list
468 | biaslist = correct_unseen_score_diff[::bias_skip]
469 |
470 | seen_match_max = float(stats['closed_seen_match'].mean())
471 | unseen_match_max = float(stats['closed_unseen_match'].mean())
472 | seen_accuracy, unseen_accuracy = [], []
473 |
474 | # Go to CPU
475 | base_scores = {k: v.to('cpu') for k, v in allpred.items()}
476 | obj_truth = obj_truth.to('cpu')
477 |
478 | # Gather scores for all relevant (a,o) pairs
479 | base_scores = torch.stack(
480 | [allpred[(attr,obj)] for attr, obj in self.dset.pairs], 1
481 | ) # (Batch, #pairs)
482 |
483 | for bias in biaslist:
484 | scores = base_scores.clone()
485 | results = self.score_fast_model(scores, obj_truth, bias = bias, topk = topk)
486 | results = results['closed'] # we only need biased
487 | results = _process(results)
488 | seen_match = float(results[3].mean())
489 | unseen_match = float(results[4].mean())
490 | seen_accuracy.append(seen_match)
491 | unseen_accuracy.append(unseen_match)
492 |
493 | seen_accuracy.append(seen_match_max)
494 | unseen_accuracy.append(unseen_match_max)
495 | seen_accuracy, unseen_accuracy = np.array(seen_accuracy), np.array(unseen_accuracy)
496 | area = np.trapz(seen_accuracy, unseen_accuracy)
497 |
498 | for key in stats:
499 | stats[key] = float(stats[key].mean())
500 |
501 | harmonic_mean = hmean([seen_accuracy, unseen_accuracy], axis = 0)
502 | max_hm = np.max(harmonic_mean)
503 | idx = np.argmax(harmonic_mean)
504 | if idx == len(biaslist):
505 | bias_term = 1e3
506 | else:
507 | bias_term = biaslist[idx]
508 | stats['biasterm'] = float(bias_term)
509 | stats['best_unseen'] = np.max(unseen_accuracy)
510 | stats['best_seen'] = np.max(seen_accuracy)
511 | stats['AUC'] = area
512 | stats['hm_unseen'] = unseen_accuracy[idx]
513 | stats['hm_seen'] = seen_accuracy[idx]
514 | stats['best_hm'] = max_hm
515 | return stats
--------------------------------------------------------------------------------
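
A minimal usage sketch of the MLP class above (not part of the repository; it assumes the repo root is on PYTHONPATH). The num_layers - 1 hidden Linear layers take their widths from the layers list, and a final Linear maps to out_dim:

import torch
from models.common import MLP  # assumes the repo root is importable

# Three layers in total: hidden widths 768 and 1024, then a Linear to 300.
embedder = MLP(inp_dim=512, out_dim=300, num_layers=3, relu=False,
               dropout=True, norm=True, layers=[768, 1024])
feats = torch.randn(8, 512)
print(embedder(feats).shape)  # torch.Size([8, 300])
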
/models/compcos.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from .word_embedding import load_word_embeddings
5 | from .common import MLP
6 |
7 | from itertools import product
8 |
9 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
10 |
11 | def compute_cosine_similarity(names, weights, return_dict=True):
12 | pairing_names = list(product(names, names))
13 | normed_weights = F.normalize(weights,dim=1)
14 | similarity = torch.mm(normed_weights, normed_weights.t())
15 | if return_dict:
16 | dict_sim = {}
17 | for i,n in enumerate(names):
18 | for j,m in enumerate(names):
19 | dict_sim[(n,m)]=similarity[i,j].item()
20 | return dict_sim
21 | return pairing_names, similarity.to('cpu')
22 |
23 | class CompCos(nn.Module):
24 |
25 | def __init__(self, dset, args):
26 | super(CompCos, self).__init__()
27 | self.args = args
28 | self.dset = dset
29 |
30 | def get_all_ids(relevant_pairs):
31 | # Precompute validation pairs
32 | attrs, objs = zip(*relevant_pairs)
33 | attrs = [dset.attr2idx[attr] for attr in attrs]
34 | objs = [dset.obj2idx[obj] for obj in objs]
35 | pairs = [a for a in range(len(relevant_pairs))]
36 | attrs = torch.LongTensor(attrs).to(device)
37 | objs = torch.LongTensor(objs).to(device)
38 | pairs = torch.LongTensor(pairs).to(device)
39 | return attrs, objs, pairs
40 |
41 | # Validation
42 | self.val_attrs, self.val_objs, self.val_pairs = get_all_ids(self.dset.pairs)
43 |
44 |         # for individual projections
45 | self.uniq_attrs, self.uniq_objs = torch.arange(len(self.dset.attrs)).long().to(device), \
46 | torch.arange(len(self.dset.objs)).long().to(device)
47 | self.factor = 2
48 |
49 | self.scale = self.args.cosine_scale
50 |
51 | if dset.open_world:
52 | self.train_forward = self.train_forward_open
53 | self.known_pairs = dset.train_pairs
54 | seen_pair_set = set(self.known_pairs)
55 | mask = [1 if pair in seen_pair_set else 0 for pair in dset.pairs]
56 | self.seen_mask = torch.BoolTensor(mask).to(device) * 1.
57 |
58 | self.activated = False
59 |
60 | # Init feasibility-related variables
61 | self.attrs = dset.attrs
62 | self.objs = dset.objs
63 | self.possible_pairs = dset.pairs
64 |
65 | self.validation_pairs = dset.val_pairs
66 |
67 | self.feasibility_margin = (1-self.seen_mask).float()
68 | self.epoch_max_margin = self.args.epoch_max_margin
69 | self.cosine_margin_factor = -args.margin
70 |
71 |             # Instantiate attribute-object relations, needed just to evaluate mined pairs
72 | self.obj_by_attrs_train = {k: [] for k in self.attrs}
73 | for (a, o) in self.known_pairs:
74 | self.obj_by_attrs_train[a].append(o)
75 |
76 |             # Instantiate attribute-object relations, needed just to evaluate mined pairs
77 | self.attrs_by_obj_train = {k: [] for k in self.objs}
78 | for (a, o) in self.known_pairs:
79 | self.attrs_by_obj_train[o].append(a)
80 |
81 | else:
82 | self.train_forward = self.train_forward_closed
83 |
84 | # Precompute training compositions
85 | if args.train_only:
86 | self.train_attrs, self.train_objs, self.train_pairs = get_all_ids(self.dset.train_pairs)
87 | else:
88 | self.train_attrs, self.train_objs, self.train_pairs = self.val_attrs, self.val_objs, self.val_pairs
89 |
90 | try:
91 | self.args.fc_emb = self.args.fc_emb.split(',')
92 | except:
93 | self.args.fc_emb = [self.args.fc_emb]
94 | layers = []
95 | for a in self.args.fc_emb:
96 | a = int(a)
97 | layers.append(a)
98 |
99 |
100 | self.image_embedder = MLP(dset.feat_dim, int(args.emb_dim), relu=args.relu, num_layers=args.nlayers,
101 | dropout=self.args.dropout,
102 | norm=self.args.norm, layers=layers)
103 |
104 | # Fixed
105 | self.composition = args.composition
106 |
107 | input_dim = args.emb_dim
108 | self.attr_embedder = nn.Embedding(len(dset.attrs), input_dim)
109 | self.obj_embedder = nn.Embedding(len(dset.objs), input_dim)
110 |
111 | # init with word embeddings
112 | if args.emb_init:
113 | pretrained_weight = load_word_embeddings(args.emb_init, dset.attrs)
114 | self.attr_embedder.weight.data.copy_(pretrained_weight)
115 | pretrained_weight = load_word_embeddings(args.emb_init, dset.objs)
116 | self.obj_embedder.weight.data.copy_(pretrained_weight)
117 |
118 | # static inputs
119 | if args.static_inp:
120 | for param in self.attr_embedder.parameters():
121 | param.requires_grad = False
122 | for param in self.obj_embedder.parameters():
123 | param.requires_grad = False
124 |
125 | # Composition MLP
126 | self.projection = nn.Linear(input_dim * 2, args.emb_dim)
127 |
128 |
129 | def freeze_representations(self):
130 | print('Freezing representations')
131 | for param in self.image_embedder.parameters():
132 | param.requires_grad = False
133 | for param in self.attr_embedder.parameters():
134 | param.requires_grad = False
135 | for param in self.obj_embedder.parameters():
136 | param.requires_grad = False
137 |
138 |
139 | def compose(self, attrs, objs):
140 | attrs, objs = self.attr_embedder(attrs), self.obj_embedder(objs)
141 | inputs = torch.cat([attrs, objs], 1)
142 | output = self.projection(inputs)
143 | output = F.normalize(output, dim=1)
144 | return output
145 |
146 |
147 | def compute_feasibility(self):
148 |         obj_embeddings = self.obj_embedder(torch.arange(len(self.objs)).long().to(device))
149 | obj_embedding_sim = compute_cosine_similarity(self.objs, obj_embeddings,
150 | return_dict=True)
151 |         attr_embeddings = self.attr_embedder(torch.arange(len(self.attrs)).long().to(device))
152 | attr_embedding_sim = compute_cosine_similarity(self.attrs, attr_embeddings,
153 | return_dict=True)
154 |
155 | feasibility_scores = self.seen_mask.clone().float()
156 | for a in self.attrs:
157 | for o in self.objs:
158 | if (a, o) not in self.known_pairs:
159 | idx = self.dset.all_pair2idx[(a, o)]
160 | score_obj = self.get_pair_scores_objs(a, o, obj_embedding_sim)
161 | score_attr = self.get_pair_scores_attrs(a, o, attr_embedding_sim)
162 | score = (score_obj + score_attr) / 2
163 | feasibility_scores[idx] = score
164 |
165 | self.feasibility_scores = feasibility_scores
166 |
167 | return feasibility_scores * (1 - self.seen_mask.float())
168 |
169 |
170 | def get_pair_scores_objs(self, attr, obj, obj_embedding_sim):
171 | score = -1.
172 | for o in self.objs:
173 | if o!=obj and attr in self.attrs_by_obj_train[o]:
174 | temp_score = obj_embedding_sim[(obj,o)]
175 | if temp_score>score:
176 | score=temp_score
177 | return score
178 |
179 | def get_pair_scores_attrs(self, attr, obj, attr_embedding_sim):
180 | score = -1.
181 | for a in self.attrs:
182 | if a != attr and obj in self.obj_by_attrs_train[a]:
183 | temp_score = attr_embedding_sim[(attr, a)]
184 | if temp_score > score:
185 | score = temp_score
186 | return score
187 |
188 | def update_feasibility(self,epoch):
189 | self.activated = True
190 | feasibility_scores = self.compute_feasibility()
191 | self.feasibility_margin = min(1.,epoch/self.epoch_max_margin) * \
192 | (self.cosine_margin_factor*feasibility_scores.float().to(device))
193 |
194 |
195 | def val_forward(self, x):
196 | img = x[0]
197 | img_feats = self.image_embedder(img)
198 | img_feats_normed = F.normalize(img_feats, dim=1)
199 | pair_embeds = self.compose(self.val_attrs, self.val_objs).permute(1, 0) # Evaluate all pairs
200 | score = torch.matmul(img_feats_normed, pair_embeds)
201 |
202 | scores = {}
203 | for itr, pair in enumerate(self.dset.pairs):
204 | scores[pair] = score[:, self.dset.all_pair2idx[pair]]
205 |
206 | return None, scores
207 |
208 |
209 | def val_forward_with_threshold(self, x, th=0.):
210 | img = x[0]
211 | img_feats = self.image_embedder(img)
212 | img_feats_normed = F.normalize(img_feats, dim=1)
213 | pair_embeds = self.compose(self.val_attrs, self.val_objs).permute(1, 0) # Evaluate all pairs
214 | score = torch.matmul(img_feats_normed, pair_embeds)
215 |
216 | # Note: Pairs are already aligned here
217 | mask = (self.feasibility_scores>=th).float()
218 | score = score*mask + (1.-mask)*(-1.)
219 |
220 | scores = {}
221 | for itr, pair in enumerate(self.dset.pairs):
222 | scores[pair] = score[:, self.dset.all_pair2idx[pair]]
223 |
224 | return None, scores
225 |
226 |
227 | def train_forward_open(self, x):
228 | img, attrs, objs, pairs = x[0], x[1], x[2], x[3]
229 | img_feats = self.image_embedder(img)
230 |
231 | pair_embed = self.compose(self.train_attrs, self.train_objs).permute(1, 0)
232 | img_feats_normed = F.normalize(img_feats, dim=1)
233 |
234 | pair_pred = torch.matmul(img_feats_normed, pair_embed)
235 |
236 | if self.activated:
237 | pair_pred += (1 - self.seen_mask) * self.feasibility_margin
238 | loss_cos = F.cross_entropy(self.scale * pair_pred, pairs)
239 | else:
240 | pair_pred = pair_pred * self.seen_mask + (1 - self.seen_mask) * (-10)
241 | loss_cos = F.cross_entropy(self.scale * pair_pred, pairs)
242 |
243 | return loss_cos.mean(), None
244 |
245 |
246 | def train_forward_closed(self, x):
247 | img, attrs, objs, pairs = x[0], x[1], x[2], x[3]
248 | img_feats = self.image_embedder(img)
249 |
250 | pair_embed = self.compose(self.train_attrs, self.train_objs).permute(1, 0)
251 | img_feats_normed = F.normalize(img_feats, dim=1)
252 |
253 | pair_pred = torch.matmul(img_feats_normed, pair_embed)
254 |
255 | loss_cos = F.cross_entropy(self.scale * pair_pred, pairs)
256 |
257 | return loss_cos.mean(), None
258 |
259 |
260 | def forward(self, x):
261 | if self.training:
262 | loss, pred = self.train_forward(x)
263 | else:
264 | with torch.no_grad():
265 | loss, pred = self.val_forward(x)
266 | return loss, pred
267 |
268 |
269 |
270 |
--------------------------------------------------------------------------------
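
A toy illustration (not from the repository) of the hard-masking rule used in val_forward_with_threshold above: any pair whose feasibility score falls below the threshold has its cosine score replaced by -1, effectively removing it from the open-world ranking. Seen pairs keep a feasibility of 1.0, as in compute_feasibility:

import torch

score = torch.tensor([[0.9, 0.2, 0.6]])              # cosine scores for three pairs
feasibility_scores = torch.tensor([1.0, 0.1, 0.5])   # seen pairs keep feasibility 1.0
th = 0.4                                             # hypothetical --threshold value
mask = (feasibility_scores >= th).float()
masked_score = score * mask + (1. - mask) * (-1.)
print(masked_score)  # tensor([[ 0.9000, -1.0000,  0.6000]])
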
/models/gcn.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.sparse as sp
3 |
4 | import math
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torch.nn.init import xavier_uniform_
9 |
10 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
11 |
12 | def normt_spm(mx, method='in'):
13 | if method == 'in':
14 | mx = mx.transpose()
15 | rowsum = np.array(mx.sum(1))
16 | r_inv = np.power(rowsum, -1).flatten()
17 | r_inv[np.isinf(r_inv)] = 0.
18 | r_mat_inv = sp.diags(r_inv)
19 | mx = r_mat_inv.dot(mx)
20 | return mx
21 |
22 | if method == 'sym':
23 | rowsum = np.array(mx.sum(1))
24 | r_inv = np.power(rowsum, -0.5).flatten()
25 | r_inv[np.isinf(r_inv)] = 0.
26 | r_mat_inv = sp.diags(r_inv)
27 | mx = mx.dot(r_mat_inv).transpose().dot(r_mat_inv)
28 | return mx
29 |
30 | def spm_to_tensor(sparse_mx):
31 | sparse_mx = sparse_mx.tocoo().astype(np.float32)
32 | indices = torch.from_numpy(np.vstack(
33 | (sparse_mx.row, sparse_mx.col))).long()
34 | values = torch.from_numpy(sparse_mx.data)
35 | shape = torch.Size(sparse_mx.shape)
36 | return torch.sparse.FloatTensor(indices, values, shape)
37 |
38 | class GraphConv(nn.Module):
39 |
40 | def __init__(self, in_channels, out_channels, dropout=False, relu=True):
41 | super().__init__()
42 |
43 | if dropout:
44 | self.dropout = nn.Dropout(p=0.5)
45 | else:
46 | self.dropout = None
47 |
48 | self.layer = nn.Linear(in_channels, out_channels)
49 |
50 | if relu:
51 | # self.relu = nn.LeakyReLU(negative_slope=0.2)
52 | self.relu = nn.ReLU()
53 | else:
54 | self.relu = None
55 |
56 | def forward(self, inputs, adj):
57 | if self.dropout is not None:
58 | inputs = self.dropout(inputs)
59 |
60 | outputs = torch.mm(adj, torch.mm(inputs, self.layer.weight.T)) + self.layer.bias
61 |
62 | if self.relu is not None:
63 | outputs = self.relu(outputs)
64 | return outputs
65 |
66 |
67 | class GCN(nn.Module):
68 |
69 | def __init__(self, adj, in_channels, out_channels, hidden_layers):
70 | super().__init__()
71 |
72 | adj = normt_spm(adj, method='in')
73 | adj = spm_to_tensor(adj)
74 | self.adj = adj.to(device)
75 |
76 | self.train_adj = self.adj
77 |
78 | hl = hidden_layers.split(',')
79 | if hl[-1] == 'd':
80 | dropout_last = True
81 | hl = hl[:-1]
82 | else:
83 | dropout_last = False
84 |
85 | i = 0
86 | layers = []
87 | last_c = in_channels
88 | for c in hl:
89 | if c[0] == 'd':
90 | dropout = True
91 | c = c[1:]
92 | else:
93 | dropout = False
94 | c = int(c)
95 |
96 | i += 1
97 | conv = GraphConv(last_c, c, dropout=dropout)
98 | self.add_module('conv{}'.format(i), conv)
99 | layers.append(conv)
100 |
101 | last_c = c
102 |
103 | conv = GraphConv(last_c, out_channels, relu=False, dropout=dropout_last)
104 | self.add_module('conv-last', conv)
105 | layers.append(conv)
106 |
107 | self.layers = layers
108 |
109 | def forward(self, x):
110 | if self.training:
111 | for conv in self.layers:
112 | x = conv(x, self.train_adj)
113 | else:
114 | for conv in self.layers:
115 | x = conv(x, self.adj)
116 | return F.normalize(x)
117 |
118 | ### GCNII
119 | class GraphConvolution(nn.Module):
120 |
121 | def __init__(self, in_features, out_features, dropout=False, relu=True, residual=False, variant=False):
122 | super(GraphConvolution, self).__init__()
123 | self.variant = variant
124 | if self.variant:
125 | self.in_features = 2*in_features
126 | else:
127 | self.in_features = in_features
128 |
129 | if dropout:
130 | self.dropout = nn.Dropout(p=0.5)
131 | else:
132 | self.dropout = None
133 |
134 | if relu:
135 | self.relu = nn.ReLU()
136 | else:
137 | self.relu = None
138 |
139 | self.out_features = out_features
140 | self.residual = residual
141 | self.layer = nn.Linear(self.in_features, self.out_features, bias = False)
142 |
143 | def reset_parameters(self):
144 | stdv = 1. / math.sqrt(self.out_features)
145 |         self.layer.weight.data.uniform_(-stdv, stdv)
146 |
147 | def forward(self, input, adj , h0 , lamda, alpha, l):
148 | if self.dropout is not None:
149 | input = self.dropout(input)
150 |
151 | theta = math.log(lamda/l+1)
152 | hi = torch.spmm(adj, input)
153 | if self.variant:
154 | support = torch.cat([hi,h0],1)
155 | r = (1-alpha)*hi+alpha*h0
156 | else:
157 | support = (1-alpha)*hi+alpha*h0
158 | r = support
159 |
160 | mm_term = torch.mm(support, self.layer.weight.T)
161 |
162 | output = theta*mm_term+(1-theta)*r
163 | if self.residual:
164 | output = output+input
165 |
166 | if self.relu is not None:
167 | output = self.relu(output)
168 |
169 | return output
170 |
171 | class GCNII(nn.Module):
172 | def __init__(self, adj, in_channels , out_channels, hidden_dim, hidden_layers, lamda, alpha, variant, dropout = True):
173 | super(GCNII, self).__init__()
174 |
175 | self.alpha = alpha
176 | self.lamda = lamda
177 |
178 | adj = normt_spm(adj, method='in')
179 | adj = spm_to_tensor(adj)
180 | self.adj = adj.to(device)
181 |
182 | i = 0
183 | layers = nn.ModuleList()
184 | self.fc_dim = nn.Linear(in_channels, hidden_dim)
185 | self.relu = nn.ReLU()
186 | self.dropout = nn.Dropout()
187 |
188 |         for i in range(hidden_layers):
189 | conv = GraphConvolution(hidden_dim, hidden_dim, variant=variant, dropout=dropout)
190 | layers.append(conv)
191 |
192 | self.layers = layers
193 | self.fc_out = nn.Linear(hidden_dim, out_channels)
194 |
195 | def forward(self, x):
196 | _layers = []
197 | layer_inner = self.relu(self.fc_dim(self.dropout(x)))
198 | # layer_inner = x
199 | _layers.append(layer_inner)
200 |
201 | for i,con in enumerate(self.layers):
202 | layer_inner = con(layer_inner,self.adj,_layers[0],self.lamda,self.alpha,i+1)
203 |
204 | layer_inner = self.fc_out(self.dropout(layer_inner))
205 | return layer_inner
--------------------------------------------------------------------------------
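
A small sketch (not from the repository) that instantiates the GCN above on a toy self-loop graph with the default '--gr_emb' spec 'd4096,d', mainly to show the expected shapes; it assumes the repo root is on PYTHONPATH:

import scipy.sparse as sp
import torch
from models.gcn import GCN  # assumes the repo root is importable

device = 'cuda' if torch.cuda.is_available() else 'cpu'
n_nodes, in_dim, out_dim = 5, 300, 512
adj = sp.eye(n_nodes, dtype='float32', format='coo')   # toy graph: self-loops only
gcn = GCN(adj, in_channels=in_dim, out_channels=out_dim, hidden_layers='d4096,d').to(device)
gcn.eval()                                              # disable dropout
x = torch.randn(n_nodes, in_dim, device=device)
print(gcn(x).shape)  # torch.Size([5, 512]); rows are L2-normalized by the final F.normalize
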
/models/graph_method.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from .common import MLP
6 | from .gcn import GCN, GCNII
7 | from .word_embedding import load_word_embeddings
8 | import scipy.sparse as sp
9 |
10 |
11 | def adj_to_edges(adj):
12 | # Adj sparse matrix to list of edges
13 | rows, cols = np.nonzero(adj)
14 | edges = list(zip(rows.tolist(), cols.tolist()))
15 | return edges
16 |
17 |
18 | def edges_to_adj(edges, n):
19 | # List of edges to Adj sparse matrix
20 | edges = np.array(edges)
21 | adj = sp.coo_matrix((np.ones(len(edges)), (edges[:, 0], edges[:, 1])),
22 | shape=(n, n), dtype='float32')
23 | return adj
24 |
25 |
26 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
27 |
28 | class GraphFull(nn.Module):
29 | def __init__(self, dset, args):
30 | super(GraphFull, self).__init__()
31 | self.args = args
32 | self.dset = dset
33 |
34 | self.val_forward = self.val_forward_dotpr
35 | self.train_forward = self.train_forward_normal
36 | # Image Embedder
37 | self.num_attrs, self.num_objs, self.num_pairs = len(dset.attrs), len(dset.objs), len(dset.pairs)
38 | self.pairs = dset.pairs
39 |
40 | if self.args.train_only:
41 | train_idx = []
42 | for current in dset.train_pairs:
43 | train_idx.append(dset.all_pair2idx[current]+self.num_attrs+self.num_objs)
44 | self.train_idx = torch.LongTensor(train_idx).to(device)
45 |
46 | self.args.fc_emb = self.args.fc_emb.split(',')
47 | layers = []
48 | for a in self.args.fc_emb:
49 | a = int(a)
50 | layers.append(a)
51 |
52 | if args.nlayers:
53 | self.image_embedder = MLP(dset.feat_dim, args.emb_dim, num_layers= args.nlayers, dropout = self.args.dropout,
54 | norm = self.args.norm, layers = layers, relu = True)
55 |
56 | all_words = list(self.dset.attrs) + list(self.dset.objs)
57 | self.displacement = len(all_words)
58 |
59 | self.obj_to_idx = {word: idx for idx, word in enumerate(self.dset.objs)}
60 | self.attr_to_idx = {word: idx for idx, word in enumerate(self.dset.attrs)}
61 |
62 | if args.graph_init is not None:
63 | path = args.graph_init
64 | graph = torch.load(path)
65 | embeddings = graph['embeddings'].to(device)
66 | adj = graph['adj']
67 | self.embeddings = embeddings
68 | else:
69 | embeddings = self.init_embeddings(all_words).to(device)
70 | adj = self.adj_from_pairs()
71 | self.embeddings = embeddings
72 |
73 | hidden_layers = self.args.gr_emb
74 | if args.gcn_type == 'gcn':
75 | self.gcn = GCN(adj, self.embeddings.shape[1], args.emb_dim, hidden_layers)
76 | else:
77 | self.gcn = GCNII(adj, self.embeddings.shape[1], args.emb_dim, args.hidden_dim, args.gcn_nlayers, lamda = 0.5, alpha = 0.1, variant = False)
78 |
79 |
80 |
81 | def init_embeddings(self, all_words):
82 |
83 | def get_compositional_embeddings(embeddings, pairs):
84 | # Getting compositional embeddings from base embeddings
85 | composition_embeds = []
86 | for (attr, obj) in pairs:
87 | attr_embed = embeddings[self.attr_to_idx[attr]]
88 | obj_embed = embeddings[self.obj_to_idx[obj]+self.num_attrs]
89 | composed_embed = (attr_embed + obj_embed) / 2
90 | composition_embeds.append(composed_embed)
91 | composition_embeds = torch.stack(composition_embeds)
92 | print('Compositional Embeddings are ', composition_embeds.shape)
93 | return composition_embeds
94 |
95 | # init with word embeddings
96 | embeddings = load_word_embeddings(self.args.emb_init, all_words)
97 |
98 | composition_embeds = get_compositional_embeddings(embeddings, self.pairs)
99 | full_embeddings = torch.cat([embeddings, composition_embeds], dim=0)
100 |
101 | return full_embeddings
102 |
103 |
104 | def update_dict(self, wdict, row,col,data):
105 | wdict['row'].append(row)
106 | wdict['col'].append(col)
107 | wdict['data'].append(data)
108 |
109 | def adj_from_pairs(self):
110 |
111 | def edges_from_pairs(pairs):
112 | weight_dict = {'data':[],'row':[],'col':[]}
113 |
114 |
115 | for i in range(self.displacement):
116 | self.update_dict(weight_dict,i,i,1.)
117 |
118 | for idx, (attr, obj) in enumerate(pairs):
119 | attr_idx, obj_idx = self.attr_to_idx[attr], self.obj_to_idx[obj] + self.num_attrs
120 |
121 | self.update_dict(weight_dict, attr_idx, obj_idx, 1.)
122 | self.update_dict(weight_dict, obj_idx, attr_idx, 1.)
123 |
124 | node_id = idx + self.displacement
125 | self.update_dict(weight_dict,node_id,node_id,1.)
126 |
127 | self.update_dict(weight_dict, node_id, attr_idx, 1.)
128 | self.update_dict(weight_dict, node_id, obj_idx, 1.)
129 |
130 |
131 | self.update_dict(weight_dict, attr_idx, node_id, 1.)
132 | self.update_dict(weight_dict, obj_idx, node_id, 1.)
133 |
134 | return weight_dict
135 |
136 | edges = edges_from_pairs(self.pairs)
137 | adj = sp.csr_matrix((edges['data'], (edges['row'], edges['col'])),
138 | shape=(len(self.pairs)+self.displacement, len(self.pairs)+self.displacement))
139 |
140 | return adj
141 |
142 |
143 |
144 | def train_forward_normal(self, x):
145 | img, attrs, objs, pairs = x[0], x[1], x[2], x[3]
146 |
147 | if self.args.nlayers:
148 | img_feats = self.image_embedder(img)
149 | else:
150 | img_feats = (img)
151 |
152 | current_embeddings = self.gcn(self.embeddings)
153 |
154 | if self.args.train_only:
155 | pair_embed = current_embeddings[self.train_idx]
156 | else:
157 | pair_embed = current_embeddings[self.num_attrs+self.num_objs:self.num_attrs+self.num_objs+self.num_pairs,:]
158 |
159 | pair_embed = pair_embed.permute(1,0)
160 | pair_pred = torch.matmul(img_feats, pair_embed)
161 | loss = F.cross_entropy(pair_pred, pairs)
162 |
163 | return loss, None
164 |
165 | def val_forward_dotpr(self, x):
166 | img = x[0]
167 |
168 | if self.args.nlayers:
169 | img_feats = self.image_embedder(img)
170 | else:
171 | img_feats = (img)
172 |
173 |         current_embeddings = self.gcn(self.embeddings)
174 | 
175 |         pair_embeds = current_embeddings[self.num_attrs+self.num_objs:self.num_attrs+self.num_objs+self.num_pairs,:].permute(1,0)
176 |
177 | score = torch.matmul(img_feats, pair_embeds)
178 |
179 | scores = {}
180 | for itr, pair in enumerate(self.dset.pairs):
181 | scores[pair] = score[:,self.dset.all_pair2idx[pair]]
182 |
183 | return None, scores
184 |
185 | def val_forward_distance_fast(self, x):
186 | img = x[0]
187 |
188 | img_feats = (self.image_embedder(img))
189 | current_embeddings = self.gcn(self.embeddings)
190 | pair_embeds = current_embeddings[self.num_attrs+self.num_objs:,:]
191 |
192 | batch_size, pairs, features = img_feats.shape[0], pair_embeds.shape[0], pair_embeds.shape[1]
193 | img_feats = img_feats[:,None,:].expand(-1, pairs, -1)
194 | pair_embeds = pair_embeds[None,:,:].expand(batch_size, -1, -1)
195 | diff = (img_feats - pair_embeds)**2
196 | score = diff.sum(2) * -1
197 |
198 | scores = {}
199 | for itr, pair in enumerate(self.dset.pairs):
200 | scores[pair] = score[:,self.dset.all_pair2idx[pair]]
201 |
202 | return None, scores
203 |
204 | def forward(self, x):
205 | if self.training:
206 | loss, pred = self.train_forward(x)
207 | else:
208 | with torch.no_grad():
209 | loss, pred = self.val_forward(x)
210 | return loss, pred
211 |
--------------------------------------------------------------------------------
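
A toy, hand-rolled version (not from the repository) of the node layout produced by GraphFull.adj_from_pairs above: indices [0, num_attrs) are attribute nodes, [num_attrs, num_attrs + num_objs) are object nodes, and one node per composition is appended after them, connected to its attribute and object. The vocabulary here is illustrative only:

import numpy as np
import scipy.sparse as sp

attrs, objs = ['ripe', 'sliced'], ['apple', 'banana']   # toy vocabulary
pairs = [('ripe', 'apple'), ('sliced', 'banana')]       # toy compositions
num_attrs, num_objs = len(attrs), len(objs)
displacement = num_attrs + num_objs                     # pair nodes start at this index

row, col = [], []
def add_edge(i, j):
    row.append(i)
    col.append(j)

for i in range(displacement):                           # self-loops on attr/obj nodes
    add_edge(i, i)
for idx, (a, o) in enumerate(pairs):
    ai, oi = attrs.index(a), objs.index(o) + num_attrs
    pi = displacement + idx
    add_edge(ai, oi); add_edge(oi, ai)                  # attribute <-> object
    add_edge(pi, pi)                                    # pair self-loop
    add_edge(pi, ai); add_edge(pi, oi)                  # pair <-> its attribute/object
    add_edge(ai, pi); add_edge(oi, pi)

n = displacement + len(pairs)
adj = sp.csr_matrix((np.ones(len(row)), (row, col)), shape=(n, n))
print(adj.toarray().astype(int))
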
/models/image_extractor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torchvision import models
5 | from torchvision.models.resnet import ResNet, BasicBlock
6 |
7 | class ResNet18_conv(ResNet):
8 | def __init__(self):
9 | super(ResNet18_conv, self).__init__(BasicBlock, [2, 2, 2, 2])
10 |
11 | def forward(self, x):
12 | # change forward here
13 | x = self.conv1(x)
14 | x = self.bn1(x)
15 | x = self.relu(x)
16 | x = self.maxpool(x)
17 |
18 | x = self.layer1(x)
19 | x = self.layer2(x)
20 | x = self.layer3(x)
21 | x = self.layer4(x)
22 |
23 | return x
24 |
25 |
26 | def get_image_extractor(arch = 'resnet18', pretrained = True, feature_dim = None, checkpoint = ''):
27 | '''
28 | Inputs
29 | arch: Base architecture
30 | pretrained: Bool, Imagenet weights
31 | feature_dim: Int, output feature dimension
32 | checkpoint: String, not implemented
33 | Returns
34 | Pytorch model
35 | '''
36 |
37 | if arch == 'resnet18':
38 | model = models.resnet18(pretrained = pretrained)
39 | if feature_dim is None:
40 | model.fc = nn.Sequential()
41 | else:
42 | model.fc = nn.Linear(512, feature_dim)
43 |
44 |     elif arch == 'resnet18_conv':
45 | model = ResNet18_conv()
46 | model.load_state_dict(models.resnet18(pretrained=True).state_dict())
47 |
48 | elif arch == 'resnet50':
49 | model = models.resnet50(pretrained = pretrained)
50 | if feature_dim is None:
51 | model.fc = nn.Sequential()
52 | else:
53 | model.fc = nn.Linear(2048, feature_dim)
54 |
55 | elif arch == 'resnet50_cutmix':
56 | model = models.resnet50(pretrained = pretrained)
57 | checkpoint = torch.load('/home/ubuntu/workspace/pretrained/resnet50_cutmix.tar')
58 | model.load_state_dict(checkpoint['state_dict'], strict=False)
59 | if feature_dim is None:
60 | model.fc = nn.Sequential()
61 | else:
62 | model.fc = nn.Linear(2048, feature_dim)
63 |
64 | elif arch == 'resnet152':
65 | model = models.resnet152(pretrained = pretrained)
66 | if feature_dim is None:
67 | model.fc = nn.Sequential()
68 | else:
69 | model.fc = nn.Linear(2048, feature_dim)
70 |
71 | elif arch == 'vgg16':
72 | model = models.vgg16(pretrained = pretrained)
73 | modules = list(model.classifier.children())[:-3]
74 | model.classifier=torch.nn.Sequential(*modules)
75 | if feature_dim is not None:
76 | model.classifier[3]=torch.nn.Linear(4096,feature_dim)
77 |
78 | return model
79 |
80 |
--------------------------------------------------------------------------------
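
A short usage sketch (not from the repository) of get_image_extractor above. With feature_dim=None the classifier head of resnet18 is replaced by an empty nn.Sequential, so the extractor returns the 512-d pooled features; pretrained=False is used here only to avoid downloading weights:

import torch
from models.image_extractor import get_image_extractor  # assumes the repo root is importable

extractor = get_image_extractor(arch='resnet18', pretrained=False)
extractor.eval()
with torch.no_grad():
    feats = extractor(torch.randn(2, 3, 224, 224))
print(feats.shape)  # torch.Size([2, 512])
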
/models/manifold_methods.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from .word_embedding import load_word_embeddings
6 | from .common import MLP, Reshape
7 | from flags import DATA_FOLDER
8 |
9 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
10 |
11 | class ManifoldModel(nn.Module):
12 |
13 | def __init__(self, dset, args):
14 | super(ManifoldModel, self).__init__()
15 | self.args = args
16 | self.dset = dset
17 |
18 | def get_all_ids(relevant_pairs):
19 | # Precompute validation pairs
20 | attrs, objs = zip(*relevant_pairs)
21 | attrs = [dset.attr2idx[attr] for attr in attrs]
22 | objs = [dset.obj2idx[obj] for obj in objs]
23 | pairs = [a for a in range(len(relevant_pairs))]
24 | attrs = torch.LongTensor(attrs).to(device)
25 | objs = torch.LongTensor(objs).to(device)
26 | pairs = torch.LongTensor(pairs).to(device)
27 | return attrs, objs, pairs
28 |
29 | #Validation
30 | self.val_attrs, self.val_objs, self.val_pairs = get_all_ids(self.dset.pairs)
31 |         # for individual projections
32 | self.uniq_attrs, self.uniq_objs = torch.arange(len(self.dset.attrs)).long().to(device), \
33 | torch.arange(len(self.dset.objs)).long().to(device)
34 | self.factor = 2
35 | # Precompute training compositions
36 | if args.train_only:
37 | self.train_attrs, self.train_objs, self.train_pairs = get_all_ids(self.dset.train_pairs)
38 | else:
39 |             self.train_attrs, self.train_objs, self.train_pairs = self.val_attrs, self.val_objs, self.val_pairs
40 |
41 | if args.lambda_aux > 0 or args.lambda_cls_attr > 0 or args.lambda_cls_obj > 0:
42 | print('Initializing classifiers')
43 | self.obj_clf = nn.Linear(args.emb_dim, len(dset.objs))
44 | self.attr_clf = nn.Linear(args.emb_dim, len(dset.attrs))
45 |
46 | def train_forward_bce(self, x):
47 | img, attrs, objs = x[0], x[1], x[2]
48 | neg_attrs, neg_objs = x[4][:,0],x[5][:,0] #todo: do a less hacky version
49 |
50 | img_feat = self.image_embedder(img)
51 | # Sample 25% positive and 75% negative pairs
52 | labels = np.random.binomial(1, 0.25, attrs.shape[0])
53 | labels = torch.from_numpy(labels).bool().to(device)
54 | sampled_attrs, sampled_objs = neg_attrs.clone(), neg_objs.clone()
55 | sampled_attrs[labels] = attrs[labels]
56 | sampled_objs[labels] = objs[labels]
57 | labels = labels.float()
58 |
59 |         composed_clf = self.compose(sampled_attrs, sampled_objs)
60 | p = torch.sigmoid((img_feat*composed_clf).sum(1))
61 | loss = F.binary_cross_entropy(p, labels)
62 |
63 | return loss, None
64 |
65 | def train_forward_triplet(self, x):
66 | img, attrs, objs = x[0], x[1], x[2]
67 | neg_attrs, neg_objs = x[4][:,0], x[5][:,0] #todo:do a less hacky version
68 |
69 | img_feats = self.image_embedder(img)
70 | positive = self.compose(attrs, objs)
71 | negative = self.compose(neg_attrs, neg_objs)
72 | loss = F.triplet_margin_loss(img_feats, positive, negative, margin = self.args.margin)
73 |
74 |         # Auxiliary object/attribute prediction loss: both the attribute and the object need to be predicted correctly
75 | if self.args.lambda_aux > 0:
76 | obj_pred = self.obj_clf(positive)
77 | attr_pred = self.attr_clf(positive)
78 | loss_aux = F.cross_entropy(attr_pred, attrs) + F.cross_entropy(obj_pred, objs)
79 | loss += self.args.lambda_aux * loss_aux
80 |
81 | return loss, None
82 |
83 |
84 | def val_forward_distance(self, x):
85 | img = x[0]
86 | batch_size = img.shape[0]
87 |
88 | img_feats = self.image_embedder(img)
89 | scores = {}
90 | pair_embeds = self.compose(self.val_attrs, self.val_objs)
91 |
92 | for itr, pair in enumerate(self.dset.pairs):
93 | pair_embed = pair_embeds[itr, None].expand(batch_size, pair_embeds.size(1))
94 | score = self.compare_metric(img_feats, pair_embed)
95 | scores[pair] = score
96 |
97 | return None, scores
98 |
99 | def val_forward_distance_fast(self, x):
100 | img = x[0]
101 | batch_size = img.shape[0]
102 |
103 | img_feats = self.image_embedder(img)
104 | pair_embeds = self.compose(self.val_attrs, self.val_objs) # Evaluate all pairs
105 |
106 | batch_size, pairs, features = img_feats.shape[0], pair_embeds.shape[0], pair_embeds.shape[1]
107 | img_feats = img_feats[:,None,:].expand(-1, pairs, -1)
108 | pair_embeds = pair_embeds[None,:,:].expand(batch_size, -1, -1)
109 | diff = (img_feats - pair_embeds)**2
110 | score = diff.sum(2) * -1
111 |
112 | scores = {}
113 | for itr, pair in enumerate(self.dset.pairs):
114 | scores[pair] = score[:,self.dset.all_pair2idx[pair]]
115 |
116 | return None, scores
117 |
118 | def val_forward_direct(self, x):
119 | img = x[0]
120 | batch_size = img.shape[0]
121 |
122 | img_feats = self.image_embedder(img)
123 | pair_embeds = self.compose(self.val_attrs, self.val_objs).permute(1,0) # Evaluate all pairs
124 | score = torch.matmul(img_feats, pair_embeds)
125 |
126 | scores = {}
127 | for itr, pair in enumerate(self.dset.pairs):
128 | scores[pair] = score[:,self.dset.all_pair2idx[pair]]
129 |
130 | return None, scores
131 |
132 | def forward(self, x):
133 | if self.training:
134 | loss, pred = self.train_forward(x)
135 | else:
136 | with torch.no_grad():
137 | loss, pred = self.val_forward(x)
138 | return loss, pred
139 |
140 | #########################################
141 | class RedWine(ManifoldModel):
142 |
143 | def __init__(self, dset, args):
144 | super(RedWine, self).__init__(dset, args)
145 | self.image_embedder = lambda img: img
146 | self.compare_metric = lambda img_feats, pair_embed: torch.sigmoid((img_feats*pair_embed).sum(1))
147 | self.train_forward = self.train_forward_bce
148 | self.val_forward = self.val_forward_distance_fast
149 |
150 | in_dim = dset.feat_dim if args.clf_init else args.emb_dim
151 | self.T = nn.Sequential(
152 | nn.Linear(2*in_dim, 3*in_dim),
153 | nn.LeakyReLU(0.1, True),
154 | nn.Linear(3*in_dim, 3*in_dim//2),
155 | nn.LeakyReLU(0.1, True),
156 | nn.Linear(3*in_dim//2, dset.feat_dim))
157 |
158 | self.attr_embedder = nn.Embedding(len(dset.attrs), in_dim)
159 | self.obj_embedder = nn.Embedding(len(dset.objs), in_dim)
160 |
161 |         # initialize the embedders with word embeddings (emb_init) or SVM classifier weights (clf_init)
162 | if args.emb_init:
163 | pretrained_weight = load_word_embeddings(args.emb_init, dset.attrs)
164 | self.attr_embedder.weight.data.copy_(pretrained_weight)
165 | pretrained_weight = load_word_embeddings(args.emb_init, dset.objs)
166 | self.obj_embedder.weight.data.copy_(pretrained_weight)
167 |
168 | elif args.clf_init:
169 | for idx, attr in enumerate(dset.attrs):
170 | at_id = self.dset.attr2idx[attr]
171 | weight = torch.load('%s/svm/attr_%d'%(args.data_dir, at_id)).coef_.squeeze()
172 | self.attr_embedder.weight[idx].data.copy_(torch.from_numpy(weight))
173 | for idx, obj in enumerate(dset.objs):
174 | obj_id = self.dset.obj2idx[obj]
175 | weight = torch.load('%s/svm/obj_%d'%(args.data_dir, obj_id)).coef_.squeeze()
176 | self.obj_embedder.weight[idx].data.copy_(torch.from_numpy(weight))
177 | else:
178 | print ('init must be either glove or clf')
179 | return
180 |
181 | if args.static_inp:
182 | for param in self.attr_embedder.parameters():
183 | param.requires_grad = False
184 | for param in self.obj_embedder.parameters():
185 | param.requires_grad = False
186 |
187 | def compose(self, attrs, objs):
188 | attr_wt = self.attr_embedder(attrs)
189 | obj_wt = self.obj_embedder(objs)
190 | inp_wts = torch.cat([attr_wt, obj_wt], 1) # 2D
191 | composed_clf = self.T(inp_wts)
192 | return composed_clf
193 |
194 | class LabelEmbedPlus(ManifoldModel):
195 | def __init__(self, dset, args):
196 | super(LabelEmbedPlus, self).__init__(dset, args)
197 | if 'conv' in args.image_extractor:
198 | self.image_embedder = torch.nn.Sequential(torch.nn.Conv2d(dset.feat_dim,args.emb_dim,7),
199 | torch.nn.ReLU(True),
200 | Reshape(-1,args.emb_dim)
201 | )
202 | else:
203 | self.image_embedder = MLP(dset.feat_dim, args.emb_dim)
204 |
205 | self.compare_metric = lambda img_feats, pair_embed: -F.pairwise_distance(img_feats, pair_embed)
206 | self.train_forward = self.train_forward_triplet
207 | self.val_forward = self.val_forward_distance_fast
208 |
209 | input_dim = dset.feat_dim if args.clf_init else args.emb_dim
210 | self.attr_embedder = nn.Embedding(len(dset.attrs), input_dim)
211 | self.obj_embedder = nn.Embedding(len(dset.objs), input_dim)
212 | self.T = MLP(2*input_dim, args.emb_dim, num_layers= args.nlayers)
213 |
214 | # init with word embeddings
215 | if args.emb_init:
216 | pretrained_weight = load_word_embeddings(args.emb_init, dset.attrs)
217 | self.attr_embedder.weight.data.copy_(pretrained_weight)
218 | pretrained_weight = load_word_embeddings(args.emb_init, dset.objs)
219 | self.obj_embedder.weight.data.copy_(pretrained_weight)
220 |
221 | # init with classifier weights
222 | elif args.clf_init:
223 | for idx, attr in enumerate(dset.attrs):
224 | at_id = dset.attrs.index(attr)
225 | weight = torch.load('%s/svm/attr_%d'%(args.data_dir, at_id)).coef_.squeeze()
226 | self.attr_embedder.weight[idx].data.copy_(torch.from_numpy(weight))
227 | for idx, obj in enumerate(dset.objs):
228 | obj_id = dset.objs.index(obj)
229 | weight = torch.load('%s/svm/obj_%d'%(args.data_dir, obj_id)).coef_.squeeze()
230 |                 self.obj_embedder.weight[idx].data.copy_(torch.from_numpy(weight))
231 |
232 | # static inputs
233 | if args.static_inp:
234 | for param in self.attr_embedder.parameters():
235 | param.requires_grad = False
236 | for param in self.obj_embedder.parameters():
237 | param.requires_grad = False
238 |
239 | def compose(self, attrs, objs):
240 | inputs = [self.attr_embedder(attrs), self.obj_embedder(objs)]
241 | inputs = torch.cat(inputs, 1)
242 | output = self.T(inputs)
243 | return output
244 |
245 | class AttributeOperator(ManifoldModel):
246 | def __init__(self, dset, args):
247 | super(AttributeOperator, self).__init__(dset, args)
248 | self.image_embedder = MLP(dset.feat_dim, args.emb_dim)
249 | self.compare_metric = lambda img_feats, pair_embed: -F.pairwise_distance(img_feats, pair_embed)
250 | self.val_forward = self.val_forward_distance_fast
251 |
252 | self.attr_ops = nn.ParameterList([nn.Parameter(torch.eye(args.emb_dim)) for _ in range(len(self.dset.attrs))])
253 | self.obj_embedder = nn.Embedding(len(dset.objs), args.emb_dim)
254 |
255 | if args.emb_init:
256 | pretrained_weight = load_word_embeddings('glove', dset.objs)
257 | self.obj_embedder.weight.data.copy_(pretrained_weight)
258 |
259 | self.inverse_cache = {}
260 |
261 | if args.lambda_ant>0 and args.dataset=='mitstates':
262 | antonym_list = open(DATA_FOLDER+'/data/antonyms.txt').read().strip().split('\n')
263 | antonym_list = [l.split() for l in antonym_list]
264 | antonym_list = [[self.dset.attrs.index(a1), self.dset.attrs.index(a2)] for a1, a2 in antonym_list]
265 | antonyms = {}
266 | antonyms.update({a1:a2 for a1, a2 in antonym_list})
267 | antonyms.update({a2:a1 for a1, a2 in antonym_list})
268 | self.antonyms, self.antonym_list = antonyms, antonym_list
269 |
270 | if args.static_inp:
271 | for param in self.obj_embedder.parameters():
272 | param.requires_grad = False
273 |
274 | def apply_ops(self, ops, rep):
275 | out = torch.bmm(ops, rep.unsqueeze(2)).squeeze(2)
276 | out = F.relu(out)
277 | return out
278 |
279 | def compose(self, attrs, objs):
280 | obj_rep = self.obj_embedder(objs)
281 | attr_ops = torch.stack([self.attr_ops[attr.item()] for attr in attrs])
282 | embedded_reps = self.apply_ops(attr_ops, obj_rep)
283 | return embedded_reps
284 |
285 | def apply_inverse(self, img_rep, attrs):
286 | inverse_ops = []
287 | for i in range(img_rep.size(0)):
288 | attr = attrs[i]
289 | if attr not in self.inverse_cache:
290 | self.inverse_cache[attr] = self.attr_ops[attr].inverse()
291 | inverse_ops.append(self.inverse_cache[attr])
292 | inverse_ops = torch.stack(inverse_ops) # (B,512,512)
293 | obj_rep = self.apply_ops(inverse_ops, img_rep)
294 | return obj_rep
295 |
296 | def train_forward(self, x):
297 | img, attrs, objs = x[0], x[1], x[2]
298 | neg_attrs, neg_objs, inv_attrs, comm_attrs = x[4][:,0], x[5][:,0], x[6], x[7]
299 | batch_size = img.size(0)
300 |
301 | loss = []
302 |
303 | anchor = self.image_embedder(img)
304 |
305 | obj_emb = self.obj_embedder(objs)
306 | pos_ops = torch.stack([self.attr_ops[attr.item()] for attr in attrs])
307 | positive = self.apply_ops(pos_ops, obj_emb)
308 |
309 | neg_obj_emb = self.obj_embedder(neg_objs)
310 | neg_ops = torch.stack([self.attr_ops[attr.item()] for attr in neg_attrs])
311 | negative = self.apply_ops(neg_ops, neg_obj_emb)
312 |
313 | loss_triplet = F.triplet_margin_loss(anchor, positive, negative, margin=self.args.margin)
314 | loss.append(loss_triplet)
315 |
316 |
317 | # Auxiliary object/attribute loss
318 | if self.args.lambda_aux>0:
319 | obj_pred = self.obj_clf(positive)
320 | attr_pred = self.attr_clf(positive)
321 | loss_aux = F.cross_entropy(attr_pred, attrs) + F.cross_entropy(obj_pred, objs)
322 | loss.append(self.args.lambda_aux*loss_aux)
323 |
324 | # Inverse Consistency
325 | if self.args.lambda_inv>0:
326 | obj_rep = self.apply_inverse(anchor, attrs)
327 | new_ops = torch.stack([self.attr_ops[attr.item()] for attr in inv_attrs])
328 | new_rep = self.apply_ops(new_ops, obj_rep)
329 | new_positive = self.apply_ops(new_ops, obj_emb)
330 | loss_inv = F.triplet_margin_loss(new_rep, new_positive, positive, margin=self.args.margin)
331 | loss.append(self.args.lambda_inv*loss_inv)
332 |
333 | # Commutative Operators
334 | if self.args.lambda_comm>0:
335 | B = torch.stack([self.attr_ops[attr.item()] for attr in comm_attrs])
336 | BA = self.apply_ops(B, positive)
337 | AB = self.apply_ops(pos_ops, self.apply_ops(B, obj_emb))
338 | loss_comm = ((AB-BA)**2).sum(1).mean()
339 | loss.append(self.args.lambda_comm*loss_comm)
340 |
341 | # Antonym Consistency
342 | if self.args.lambda_ant>0:
343 |
344 | select_idx = [i for i in range(batch_size) if attrs[i].item() in self.antonyms]
345 | if len(select_idx)>0:
346 | select_idx = torch.LongTensor(select_idx).cuda()
347 | attr_subset = attrs[select_idx]
348 | antonym_ops = torch.stack([self.attr_ops[self.antonyms[attr.item()]] for attr in attr_subset])
349 |
350 | Ao = anchor[select_idx]
351 | if self.args.lambda_inv>0:
352 | o = obj_rep[select_idx]
353 | else:
354 | o = self.apply_inverse(Ao, attr_subset)
355 | BAo = self.apply_ops(antonym_ops, Ao)
356 |
357 | loss_cycle = ((BAo-o)**2).sum(1).mean()
358 | loss.append(self.args.lambda_ant*loss_cycle)
359 |
360 |
361 | loss = sum(loss)
362 | return loss, None
363 |
364 | def forward(self, x):
365 | loss, pred = super(AttributeOperator, self).forward(x)
366 | self.inverse_cache = {}
367 | return loss, pred
--------------------------------------------------------------------------------
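Note on AttributeOperator (manifold_methods.py above): each attribute is modelled as a learned linear operator applied to an object embedding (`compose`/`apply_ops`), and its matrix inverse is used to strip the attribute off an image embedding again (`apply_inverse`); the auxiliary triplet terms then enforce inverse, commutative and antonym consistency of those operators. A minimal self-contained sketch of the core mechanism, with toy dimensions chosen only for illustration:

    import torch
    import torch.nn.functional as F

    emb_dim, num_attrs, num_objs = 8, 3, 4   # toy sizes, illustration only
    # one matrix per attribute, initialised near the identity as in the model
    attr_ops = [torch.eye(emb_dim) + 0.01 * torch.randn(emb_dim, emb_dim)
                for _ in range(num_attrs)]
    obj_embedder = torch.nn.Embedding(num_objs, emb_dim)

    def apply_ops(ops, rep):
        # batched matrix-vector product followed by ReLU, as in AttributeOperator.apply_ops
        return F.relu(torch.bmm(ops, rep.unsqueeze(2)).squeeze(2))

    attrs, objs = torch.tensor([0, 2]), torch.tensor([1, 3])
    ops = torch.stack([attr_ops[a.item()] for a in attrs])   # (B, emb_dim, emb_dim)
    pair_emb = apply_ops(ops, obj_embedder(objs))            # compose(attrs, objs)

    # apply_inverse: remove the attribute again with the inverse operator
    inv_ops = torch.stack([attr_ops[a.item()].inverse() for a in attrs])
    recovered_obj = apply_ops(inv_ops, pair_emb)
    print(pair_emb.shape, recovered_obj.shape)               # both (2, 8)

Because of the ReLU the recovery is only approximate, which is why the model supervises it with a triplet loss rather than an exact constraint.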
/models/modular_methods.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import torchvision.models as tmodels
11 | import numpy as np
12 | from .word_embedding import load_word_embeddings
13 | import itertools
14 | import math
15 | import collections
16 | from torch.distributions.bernoulli import Bernoulli
17 | import pdb
18 | import sys
19 |
20 |
21 | class Flatten(nn.Module):
22 | def forward(self, input):
23 | return input.view(input.size(0), -1)
24 |
25 |
26 | class GatingSampler(nn.Module):
27 | """Docstring for GatingSampler. """
28 |
29 | def __init__(self, gater, stoch_sample=True, temperature=1.0):
30 | """TODO: to be defined1.
31 |
32 | """
33 | nn.Module.__init__(self)
34 | self.gater = gater
35 | self._stoch_sample = stoch_sample
36 | self._temperature = temperature
37 |
38 | def disable_stochastic_sampling(self):
39 | self._stoch_sample = False
40 |
41 | def enable_stochastic_sampling(self):
42 | self._stoch_sample = True
43 |
44 | def forward(self, tdesc=None, return_additional=False, gating_wt=None):
45 | if self.gater is None and return_additional:
46 | return None, None
47 | elif self.gater is None:
48 | return None
49 |
50 | if gating_wt is not None:
51 | return_wts = gating_wt
52 | gating_g = self.gater(tdesc, gating_wt=gating_wt)
53 | else:
54 | gating_g = self.gater(tdesc)
55 | return_wts = None
56 | if isinstance(gating_g, tuple):
57 | return_wts = gating_g[1]
58 | gating_g = gating_g[0]
59 |
60 | if not self._stoch_sample:
61 | sampled_g = gating_g
62 | else:
63 |             raise NotImplementedError
64 |
65 | if return_additional:
66 | return sampled_g, return_wts
67 | return sampled_g
68 |
69 |
70 | class GatedModularNet(nn.Module):
71 | """
72 | An interface for creating modular nets
73 | """
74 |
75 | def __init__(self,
76 | module_list,
77 | start_modules=None,
78 | end_modules=None,
79 | single_head=False,
80 | chain=False):
81 | """TODO: to be defined1.
82 |
83 | :module_list: TODO
84 | :g: TODO
85 |
86 | """
87 | nn.Module.__init__(self)
88 | self._module_list = nn.ModuleList(
89 | [nn.ModuleList(m) for m in module_list])
90 |
91 | self.num_layers = len(self._module_list)
92 | if start_modules is not None:
93 | self._start_modules = nn.ModuleList(start_modules)
94 | else:
95 | self._start_modules = None
96 |
97 | if end_modules is not None:
98 | self._end_modules = nn.ModuleList(end_modules)
99 | self.num_layers += 1
100 | else:
101 | self._end_modules = None
102 |
103 | self.sampled_g = None
104 | self.single_head = single_head
105 | self._chain = chain
106 |
107 | def forward(self, x, sampled_g=None, t=None, return_feat=False):
108 | """TODO: Docstring for forward.
109 |
110 | :x: Input data
111 | :g: Gating tensor (#Task x )#num_layer x #num_mods x #num_mods
112 | :t: task ID
113 | :returns: TODO
114 |
115 | """
116 |
117 | if t is None:
118 | t_not_set = True
119 | t = torch.tensor([0] * x.shape[0], dtype=x.dtype).long()
120 | else:
121 | t_not_set = False
122 | t = t.squeeze()
123 |
124 | if self._start_modules is not None:
125 | prev_out = [mod(x) for mod in self._start_modules]
126 | else:
127 | prev_out = [x]
128 |
129 | if sampled_g is None:
130 | # NON-Gated Module network
131 | prev_out = sum(prev_out) / float(len(prev_out))
132 | #prev_out = torch.mean(prev_out, 0)
133 | for li in range(len(self._module_list)):
134 | prev_out = sum([
135 | mod(prev_out) for mod in self._module_list[li]
136 | ]) / float(len(self._module_list[li]))
137 | features = prev_out
138 | if self._end_modules is not None:
139 | if t_not_set or self.single_head:
140 | prev_out = self._end_modules[0](prev_out)
141 | else:
142 | prev_out = torch.cat([
143 | self._end_modules[tid](prev_out[bi:bi + 1])
144 | for bi, tid in enumerate(t)
145 | ], 0)
146 | if return_feat:
147 | return prev_out, features
148 | return prev_out
149 | else:
150 | # Forward prop with sampled Gs
151 | for li in range(len(self._module_list)):
152 | curr_out = []
153 | for j in range(len(self._module_list[li])):
154 | gind = j if not self._chain else 0
155 | # Dim: #Batch x C
156 | module_in_wt = sampled_g[li + 1][gind]
157 | # Module input weights rearranged to match inputs
158 | module_in_wt = module_in_wt.transpose(0, 1)
159 | add_dims = prev_out[0].dim() + 1 - module_in_wt.dim()
160 | module_in_wt = module_in_wt.view(*module_in_wt.shape,
161 | *([1] * add_dims))
162 | module_in_wt = module_in_wt.expand(
163 | len(prev_out), *prev_out[0].shape)
164 | module_in = sum([
165 | module_in_wt[i] * prev_out[i]
166 | for i in range(len(prev_out))
167 | ])
168 | mod = self._module_list[li][j]
169 | curr_out.append(mod(module_in))
170 | prev_out = curr_out
171 |
172 | # Output modules (with sampled Gs)
173 | if self._end_modules is not None:
174 | li = self.num_layers - 1
175 | if t_not_set or self.single_head:
176 | # Dim: #Batch x C
177 | module_in_wt = sampled_g[li + 1][0]
178 | # Module input weights rearranged to match inputs
179 | module_in_wt = module_in_wt.transpose(0, 1)
180 | add_dims = prev_out[0].dim() + 1 - module_in_wt.dim()
181 | module_in_wt = module_in_wt.view(*module_in_wt.shape,
182 | *([1] * add_dims))
183 | module_in_wt = module_in_wt.expand(
184 | len(prev_out), *prev_out[0].shape)
185 | module_in = sum([
186 | module_in_wt[i] * prev_out[i]
187 | for i in range(len(prev_out))
188 | ])
189 | features = module_in
190 | prev_out = self._end_modules[0](module_in)
191 | else:
192 | curr_out = []
193 | for bi, tid in enumerate(t):
194 | # Dim: #Batch x C
195 | gind = tid if not self._chain else 0
196 | module_in_wt = sampled_g[li + 1][gind]
197 | # Module input weights rearranged to match inputs
198 | module_in_wt = module_in_wt.transpose(0, 1)
199 | add_dims = prev_out[0].dim() + 1 - module_in_wt.dim()
200 | module_in_wt = module_in_wt.view(
201 | *module_in_wt.shape, *([1] * add_dims))
202 | module_in_wt = module_in_wt.expand(
203 | len(prev_out), *prev_out[0].shape)
204 | module_in = sum([
205 | module_in_wt[i] * prev_out[i]
206 | for i in range(len(prev_out))
207 | ])
208 | features = module_in
209 | mod = self._end_modules[tid]
210 | curr_out.append(mod(module_in[bi:bi + 1]))
211 | prev_out = curr_out
212 | prev_out = torch.cat(prev_out, 0)
213 | if return_feat:
214 | return prev_out, features
215 | return prev_out
216 |
217 |
218 | class CompositionalModel(nn.Module):
219 | def __init__(self, dset, args):
220 | super(CompositionalModel, self).__init__()
221 | self.args = args
222 | self.dset = dset
223 |
224 | # precompute validation pairs
225 | attrs, objs = zip(*self.dset.pairs)
226 | attrs = [dset.attr2idx[attr] for attr in attrs]
227 | objs = [dset.obj2idx[obj] for obj in objs]
228 | self.val_attrs = torch.LongTensor(attrs).cuda()
229 | self.val_objs = torch.LongTensor(objs).cuda()
230 |
231 | def train_forward_softmax(self, x):
232 | img, attrs, objs = x[0], x[1], x[2]
233 | neg_attrs, neg_objs = x[4], x[5]
234 | inv_attrs, comm_attrs = x[6], x[7]
235 |
236 | sampled_attrs = torch.cat((attrs.unsqueeze(1), neg_attrs), 1)
237 | sampled_objs = torch.cat((objs.unsqueeze(1), neg_objs), 1)
238 | img_ind = torch.arange(sampled_objs.shape[0]).unsqueeze(1).repeat(
239 | 1, sampled_attrs.shape[1])
240 |
241 | flat_sampled_attrs = sampled_attrs.view(-1)
242 | flat_sampled_objs = sampled_objs.view(-1)
243 | flat_img_ind = img_ind.view(-1)
244 | labels = torch.zeros_like(sampled_attrs[:, 0]).long()
245 |
246 | self.composed_g = self.compose(flat_sampled_attrs, flat_sampled_objs)
247 |
248 | cls_scores, feat = self.comp_network(
249 | img[flat_img_ind], self.composed_g, return_feat=True)
250 | pair_scores = cls_scores[:, :1]
251 | pair_scores = pair_scores.view(*sampled_attrs.shape)
252 |
253 | loss = 0
254 | loss_cls = F.cross_entropy(pair_scores, labels)
255 | loss += loss_cls
256 |
257 | loss_obj = torch.FloatTensor([0])
258 | loss_attr = torch.FloatTensor([0])
259 | loss_sparse = torch.FloatTensor([0])
260 | loss_unif = torch.FloatTensor([0])
261 | loss_aux = torch.FloatTensor([0])
262 |
263 | acc = (pair_scores.argmax(1) == labels).sum().float() / float(
264 | len(labels))
265 | all_losses = {}
266 | all_losses['total_loss'] = loss
267 | all_losses['main_loss'] = loss_cls
268 | all_losses['aux_loss'] = loss_aux
269 | all_losses['obj_loss'] = loss_obj
270 | all_losses['attr_loss'] = loss_attr
271 | all_losses['sparse_loss'] = loss_sparse
272 | all_losses['unif_loss'] = loss_unif
273 |
274 | return loss, all_losses, acc, (pair_scores, feat)
275 |
276 | def val_forward(self, x):
277 | img = x[0]
278 | batch_size = img.shape[0]
279 | pair_scores = torch.zeros(batch_size, len(self.val_attrs))
280 | pair_feats = torch.zeros(batch_size, len(self.val_attrs),
281 | self.args.emb_dim)
282 | pair_bs = len(self.val_attrs)
283 |
284 | for pi in range(math.ceil(len(self.val_attrs) / pair_bs)):
285 | self.compose_g = self.compose(
286 | self.val_attrs[pi * pair_bs:(pi + 1) * pair_bs],
287 | self.val_objs[pi * pair_bs:(pi + 1) * pair_bs])
288 | compose_g = self.compose_g
289 | expanded_im = img.unsqueeze(1).repeat(
290 | 1, compose_g[0][0].shape[0],
291 | *tuple([1] * (img.dim() - 1))).view(-1, *img.shape[1:])
292 | expanded_compose_g = [[
293 | g.unsqueeze(0).repeat(batch_size, *tuple([1] * g.dim())).view(
294 | -1, *g.shape[1:]) for g in layer_g
295 | ] for layer_g in compose_g]
296 | this_pair_scores, this_feat = self.comp_network(
297 | expanded_im, expanded_compose_g, return_feat=True)
298 | featnorm = torch.norm(this_feat, p=2, dim=-1)
299 | this_feat = this_feat.div(
300 | featnorm.unsqueeze(-1).expand_as(this_feat))
301 |
302 | this_pair_scores = this_pair_scores[:, :1].view(batch_size, -1)
303 | this_feat = this_feat.view(batch_size, -1, self.args.emb_dim)
304 |
305 | pair_scores[:, pi * pair_bs:pi * pair_bs +
306 | this_pair_scores.shape[1]] = this_pair_scores[:, :]
307 | pair_feats[:, pi * pair_bs:pi * pair_bs +
308 | this_pair_scores.shape[1], :] = this_feat[:]
309 |
310 | scores = {}
311 | feats = {}
312 | for i, (attr, obj) in enumerate(self.dset.pairs):
313 | scores[(attr, obj)] = pair_scores[:, i]
314 | feats[(attr, obj)] = pair_feats[:, i]
315 |
316 | # return None, (scores, feats)
317 | return None, scores
318 |
319 | def forward(self, x, with_grad=False):
320 | if self.training:
321 | loss, loss_aux, acc, pred = self.train_forward(x)
322 | else:
323 | loss_aux = torch.Tensor([0])
324 | loss = torch.Tensor([0])
325 | if not with_grad:
326 | with torch.no_grad():
327 | acc, pred = self.val_forward(x)
328 | else:
329 | acc, pred = self.val_forward(x)
330 | # return loss, loss_aux, acc, pred
331 | return loss, pred
332 |
333 |
334 | class GatedGeneralNN(CompositionalModel):
335 | """Docstring for GatedCompositionalModel. """
336 |
337 | def __init__(self,
338 | dset,
339 | args,
340 | num_layers=2,
341 | num_modules_per_layer=3,
342 | stoch_sample=False,
343 | use_full_model=False,
344 | num_classes=[2],
345 | gater_type='general'):
346 | """TODO: to be defined1.
347 |
348 | :dset: TODO
349 | :args: TODO
350 |
351 | """
352 | CompositionalModel.__init__(self, dset, args)
353 |
354 | self.train_forward = self.train_forward_softmax
355 | self.compose_type = 'nn' #todo: could be different
356 |
357 | gating_in_dim = 128
358 | if args.emb_init:
359 | gating_in_dim = 300
360 | elif args.clf_init:
361 | gating_in_dim = 512
362 |
363 | if self.compose_type == 'nn':
364 | tdim = gating_in_dim * 2
365 | inter_tdim = self.args.embed_rank
366 | # Change this to allow only obj, only attr gatings
367 | self.attr_embedder = nn.Embedding(
368 | len(dset.attrs) + 1,
369 | gating_in_dim,
370 | padding_idx=len(dset.attrs),
371 | )
372 | self.obj_embedder = nn.Embedding(
373 | len(dset.objs) + 1,
374 | gating_in_dim,
375 | padding_idx=len(dset.objs),
376 | )
377 |
378 | # initialize the weights of the embedders with the svm weights
379 | if args.emb_init:
380 | pretrained_weight = load_word_embeddings(
381 | args.emb_init, dset.attrs)
382 | self.attr_embedder.weight[:-1, :].data.copy_(pretrained_weight)
383 | pretrained_weight = load_word_embeddings(
384 | args.emb_init, dset.objs)
385 | self.obj_embedder.weight.data[:-1, :].copy_(pretrained_weight)
386 | elif args.clf_init:
387 | for idx, attr in enumerate(dset.attrs):
388 | at_id = self.dset.attr2idx[attr]
389 | weight = torch.load(
390 | '%s/svm/attr_%d' % (args.data_dir,
391 | at_id)).coef_.squeeze()
392 | self.attr_embedder.weight[idx].data.copy_(
393 | torch.from_numpy(weight))
394 | for idx, obj in enumerate(dset.objs):
395 | obj_id = self.dset.obj2idx[obj]
396 | weight = torch.load(
397 | '%s/svm/obj_%d' % (args.data_dir,
398 | obj_id)).coef_.squeeze()
399 | self.obj_embedder.weight[idx].data.copy_(
400 | torch.from_numpy(weight))
401 | else:
402 | n_attr = len(dset.attrs)
403 | gating_in_dim = 300
404 | tdim = gating_in_dim * 2 + n_attr
405 | self.attr_embedder = nn.Embedding(
406 | n_attr,
407 | n_attr,
408 | )
409 | self.attr_embedder.weight.data.copy_(
410 | torch.from_numpy(np.eye(n_attr)))
411 | self.obj_embedder = nn.Embedding(
412 | len(dset.objs) + 1,
413 | gating_in_dim,
414 | padding_idx=len(dset.objs),
415 | )
416 | pretrained_weight = load_word_embeddings(
417 |                 'glove', dset.objs)
418 | self.obj_embedder.weight.data[:-1, :].copy_(pretrained_weight)
419 | else:
420 |             raise NotImplementedError
421 |
422 | self.comp_network, self.gating_network, self.nummods, _ = modular_general(
423 | num_layers=num_layers,
424 | num_modules_per_layer=num_modules_per_layer,
425 | feat_dim=dset.feat_dim,
426 | inter_dim=args.emb_dim,
427 | stoch_sample=stoch_sample,
428 | use_full_model=use_full_model,
429 | tdim=tdim,
430 | inter_tdim=inter_tdim,
431 | gater_type=gater_type,
432 | )
433 |
434 | if args.static_inp:
435 | for param in self.attr_embedder.parameters():
436 | param.requires_grad = False
437 | for param in self.obj_embedder.parameters():
438 | param.requires_grad = False
439 |
440 | def compose(self, attrs, objs):
441 | obj_wt = self.obj_embedder(objs)
442 | if self.compose_type == 'nn':
443 | attr_wt = self.attr_embedder(attrs)
444 | inp_wts = torch.cat([attr_wt, obj_wt], 1) # 2D
445 | else:
446 |             raise NotImplementedError
447 | composed_g, composed_g_wt = self.gating_network(
448 | inp_wts, return_additional=True)
449 |
450 | return composed_g
451 |
452 |
453 | class GeneralNormalizedNN(nn.Module):
454 | """Docstring for GatedCompositionalModel. """
455 |
456 | def __init__(self, num_layers, num_modules_per_layer, in_dim, inter_dim):
457 | """TODO: to be defined1. """
458 | nn.Module.__init__(self)
459 | self.start_modules = [nn.Sequential()]
460 | self.layer1 = [[
461 | nn.Sequential(
462 | nn.Linear(in_dim, inter_dim), nn.BatchNorm1d(inter_dim),
463 | nn.ReLU()) for _ in range(num_modules_per_layer)
464 | ]]
465 | if num_layers > 1:
466 | self.layer2 = [[
467 | nn.Sequential(
468 | nn.BatchNorm1d(inter_dim), nn.Linear(inter_dim, inter_dim),
469 | nn.BatchNorm1d(inter_dim), nn.ReLU())
470 | for _m in range(num_modules_per_layer)
471 | ] for _l in range(num_layers - 1)]
472 | self.avgpool = nn.Sequential()
473 | self.fc = [
474 | nn.Sequential(nn.BatchNorm1d(inter_dim), nn.Linear(inter_dim, 1))
475 | ]
476 |
477 |
478 | class GeneralGatingNN(nn.Module):
479 | def __init__(
480 | self,
481 | num_mods,
482 | tdim,
483 | inter_tdim,
484 | randinit=False,
485 | ):
486 | """TODO: to be defined1.
487 |
488 | :num_mods: TODO
489 | :tdim: TODO
490 |
491 | """
492 | nn.Module.__init__(self)
493 |
494 | self._num_mods = num_mods
495 | self._tdim = tdim
496 | self._inter_tdim = inter_tdim
497 | task_outdim = self._inter_tdim
498 |
499 | self.task_linear1 = nn.Linear(self._tdim, task_outdim, bias=False)
500 | self.task_bn1 = nn.BatchNorm1d(task_outdim)
501 | self.task_linear2 = nn.Linear(task_outdim, task_outdim, bias=False)
502 | self.task_bn2 = nn.BatchNorm1d(task_outdim)
503 | self.joint_linear1 = nn.Linear(task_outdim, task_outdim, bias=False)
504 | self.joint_bn1 = nn.BatchNorm1d(task_outdim)
505 |
506 | num_out = [[1]] + [[
507 | self._num_mods[i - 1] for _ in range(self._num_mods[i])
508 | ] for i in range(1, len(self._num_mods))]
509 | count = 0
510 | out_ind = []
511 | for i in range(len(num_out)):
512 | this_out_ind = []
513 | for j in range(len(num_out[i])):
514 | this_out_ind.append([count, count + num_out[i][j]])
515 | count += num_out[i][j]
516 | out_ind.append(this_out_ind)
517 | self.out_ind = out_ind
518 | self.out_count = count
519 |
520 | self.joint_linear2 = nn.Linear(task_outdim, count, bias=False)
521 |
522 | def apply_init(m):
523 | if isinstance(m, nn.Linear):
524 | nn.init.normal_(m.weight, 0, 0.1)
525 | if m.bias is not None:
526 | nn.init.constant_(m.bias, 0)
527 | elif isinstance(m, nn.BatchNorm1d):
528 | nn.init.constant_(m.weight, 1)
529 | nn.init.constant_(m.bias, 0)
530 |
531 | for m in self.modules():
532 | if isinstance(m, nn.ModuleList):
533 | for subm in m:
534 | if isinstance(subm, nn.ModuleList):
535 | for subsubm in subm:
536 | apply_init(subsubm)
537 | else:
538 | apply_init(subm)
539 | else:
540 | apply_init(m)
541 |
542 | if not randinit:
543 | self.joint_linear2.weight.data.zero_()
544 |
545 | def forward(self, tdesc=None):
546 | """TODO: Docstring for function.
547 |
548 | :arg1: TODO
549 | :returns: TODO
550 |
551 | """
552 | if tdesc is None:
553 | return None
554 |
555 | x = tdesc
556 | task_embeds1 = F.relu(self.task_bn1(self.task_linear1(x)))
557 | joint_embed = task_embeds1
558 | joint_embed = self.joint_linear2(joint_embed)
559 | joint_embed = [[
560 | joint_embed[:, self.out_ind[i][j][0]:self.out_ind[i][j][1]]
561 | for j in range(len(self.out_ind[i]))
562 | ] for i in range(len(self.out_ind))]
563 |
564 | gating_wt = joint_embed
565 | prob_g = [[F.softmax(wt, -1) for wt in gating_wt[i]]
566 | for i in range(len(gating_wt))]
567 | return prob_g, gating_wt
568 |
569 |
570 | def modularize_network(
571 | model,
572 | stoch_sample=False,
573 | use_full_model=False,
574 | tdim=200,
575 | inter_tdim=200,
576 | gater_type='general',
577 | single_head=True,
578 | num_classes=[2],
579 | num_lookup_gating=10,
580 | ):
581 | # Copy start modules and end modules
582 | start_modules = model.start_modules
583 | end_modules = [
584 | nn.Sequential(model.avgpool, Flatten(), fci) for fci in model.fc
585 | ]
586 |
587 | # Create module_list as list of lists [[layer1 modules], [layer2 modules], ...]
588 | module_list = []
589 | li = 1
590 | while True:
591 | if hasattr(model, 'layer{}'.format(li)):
592 | module_list.extend(getattr(model, 'layer{}'.format(li)))
593 | li += 1
594 | else:
595 | break
596 |
597 | num_module_list = [len(start_modules)] + [len(layer) for layer in module_list] \
598 | + [len(end_modules)]
599 | gated_model_func = GatedModularNet
600 | gated_net = gated_model_func(
601 | module_list,
602 | start_modules=start_modules,
603 | end_modules=end_modules,
604 | single_head=single_head)
605 |
606 | gater_func = GeneralGatingNN
607 | gater = gater_func(
608 | num_mods=num_module_list, tdim=tdim, inter_tdim=inter_tdim)
609 |
610 | fan_in = num_module_list
611 | if use_full_model:
612 | gater = None
613 |
614 | # Create Gating Sampler
615 | gating_sampler = GatingSampler(gater=gater, stoch_sample=stoch_sample)
616 | return gated_net, gating_sampler, num_module_list, fan_in
617 |
618 |
619 | def modular_general(
620 | num_layers,
621 | num_modules_per_layer,
622 | feat_dim,
623 | inter_dim,
624 | stoch_sample=False,
625 | use_full_model=False,
626 | num_classes=[2],
627 | single_head=True,
628 | gater_type='general',
629 | tdim=300,
630 | inter_tdim=300,
631 | num_lookup_gating=10,
632 | ):
633 |     # First create the base (non-modular) feed-forward network
634 | model = GeneralNormalizedNN(
635 | num_layers,
636 | num_modules_per_layer,
637 | feat_dim,
638 | inter_dim,
639 | )
640 |
641 | # Modularize the model and create gating funcs
642 | gated_net, gating_sampler, num_module_list, fan_in = modularize_network(
643 | model,
644 | stoch_sample,
645 | use_full_model=use_full_model,
646 | gater_type=gater_type,
647 | tdim=tdim,
648 | inter_tdim=inter_tdim,
649 | single_head=single_head,
650 | num_classes=num_classes,
651 | num_lookup_gating=num_lookup_gating,
652 | )
653 | return gated_net, gating_sampler, num_module_list, fan_in
--------------------------------------------------------------------------------
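Note on modular_methods.py above: `modular_general` builds the plain backbone (`GeneralNormalizedNN`), splits it into per-layer module lists via `modularize_network`, wraps it in a `GatedModularNet`, and pairs it with a `GatingSampler` whose `GeneralGatingNN` maps a task descriptor (the concatenated attribute and object embeddings) to per-layer softmax weights over modules. A usage sketch, under the assumption that it is run from the repository root and with random tensors standing in for image features and word embeddings:

    import torch
    from models.modular_methods import modular_general

    feat_dim, emb_dim, tdim = 32, 16, 20     # toy dimensions, illustration only
    comp_net, gating_sampler, num_mods, _ = modular_general(
        num_layers=2, num_modules_per_layer=3,
        feat_dim=feat_dim, inter_dim=emb_dim,
        tdim=tdim, inter_tdim=tdim)

    img = torch.randn(4, feat_dim)           # stand-in for image features
    task_desc = torch.randn(4, tdim)         # stand-in for concatenated attr+obj embeddings

    sampled_g = gating_sampler(task_desc)    # per-layer module mixing weights
    scores, feats = comp_net(img, sampled_g, return_feat=True)
    print(scores.shape, feats.shape)         # (4, 1) pair score, (4, emb_dim) features

Passing `sampled_g=None` instead runs the ungated network, which simply averages the module outputs of each layer.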
/models/svm.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tqdm
3 | from data import dataset as dset
4 | import os
5 | from utils import utils
6 | import torch
7 | from torch.autograd import Variable
8 | import h5py
9 | from sklearn.svm import LinearSVC
10 | from sklearn.model_selection import GridSearchCV
11 | import torch.nn.functional as F
12 | from joblib import Parallel, delayed
13 | import glob
14 | import scipy.io
15 | from sklearn.calibration import CalibratedClassifierCV
16 |
17 |
18 | import argparse
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument('--batch_size', type=int, default=512)
21 | parser.add_argument('--dataset', default='mitstates', help='mitstates|zappos')
22 | parser.add_argument('--data_dir', default='data/mit-states/')
23 | parser.add_argument('--generate', action ='store_true', default=False)
24 | parser.add_argument('--evalsvm', action ='store_true', default=False)
25 | parser.add_argument('--evaltf', action ='store_true', default=False)
26 | parser.add_argument('--completed', default='tensor-completion/completed/complete.mat')
27 | args = parser.parse_args()
28 |
29 | #----------------------------------------------------------------------------------------#
30 |
31 | search_params = {'C': np.logspace(-5,5,11)}
32 | def train_svm(Y_train, sample_weight=None):
33 |
34 | try:
35 | clf = GridSearchCV(LinearSVC(class_weight='balanced', fit_intercept=False), search_params, scoring='f1', cv=4)
36 | clf.fit(X_train, Y_train, sample_weight=sample_weight)
37 | except:
38 | clf = LinearSVC(C=0.1, class_weight='balanced', fit_intercept=False)
39 | if Y_train.sum()==len(Y_train) or Y_train.sum()==0:
40 | return None
41 | clf.fit(X_train, Y_train, sample_weight=sample_weight)
42 |
43 | return clf
44 |
45 | def generate_svms():
46 | # train an SVM for every attribute, object and pair primitive
47 |
48 |     Y = [(train_attrs==attr).astype(int) for attr in range(len(dataset.attrs))]
49 | attr_clfs = Parallel(n_jobs=32, verbose=16)(delayed(train_svm)(Y[attr]) for attr in range(len(dataset.attrs)))
50 | for attr, clf in enumerate(attr_clfs):
51 | print (attr, dataset.attrs[attr])
52 | print ('params:', clf.best_params_)
53 | print ('-'*30)
54 | torch.save(clf.best_estimator_, '%s/svm/attr_%d'%(args.data_dir, attr))
55 |
56 |     Y = [(train_objs==obj).astype(int) for obj in range(len(dataset.objs))]
57 | obj_clfs = Parallel(n_jobs=32, verbose=16)(delayed(train_svm)(Y[obj]) for obj in range(len(dataset.objs)))
58 | for obj, clf in enumerate(obj_clfs):
59 | print (obj, dataset.objs[obj])
60 | print ('params:', clf.best_params_)
61 | print ('-'*30)
62 | torch.save(clf.best_estimator_, '%s/svm/obj_%d'%(args.data_dir, obj))
63 |
64 |
65 | Y, Y_attr, sample_weight = [], [], []
66 | for idx, (attr, obj) in enumerate(dataset.train_pairs):
67 |         Y_train = ((train_attrs==attr)*(train_objs==obj)).astype(int)
68 |
69 | # reweight instances to get a little more training data
70 |         Y_train_attr = (train_attrs==attr).astype(int)
71 | instance_weights = 0.1*Y_train_attr
72 | instance_weights[Y_train.nonzero()[0]] = 1.0
73 |
74 | Y.append(Y_train)
75 | Y_attr.append(Y_train_attr)
76 | sample_weight.append(instance_weights)
77 |
78 | pair_clfs = Parallel(n_jobs=32, verbose=16)(delayed(train_svm)(Y[pair]) for pair in range(len(dataset.train_pairs)))
79 | for idx, (attr, obj) in enumerate(dataset.train_pairs):
80 | clf = pair_clfs[idx]
81 | print (dataset.attrs[attr], dataset.objs[obj])
82 |
83 | try:
84 | print ('params:', clf.best_params_)
85 | torch.save(clf.best_estimator_, '%s/svm/pair_%d_%d'%(args.data_dir, attr, obj))
86 | except:
87 | print ('FAILED! #positive:', Y[idx].sum(), len(Y[idx]))
88 |
89 | return
90 |
91 | def make_svm_tensor():
92 | subs, vals, size = [], [], (len(dataset.attrs), len(dataset.objs), X.shape[1])
93 | fullsubs, fullvals = [], []
94 | composite_clfs = glob.glob('%s/svm/pair*'%args.data_dir)
95 | print ('%d composite classifiers found'%(len(composite_clfs)))
96 | for clf in tqdm.tqdm(composite_clfs):
97 | _, attr, obj = os.path.basename(clf).split('_')
98 | attr, obj = int(attr), int(obj)
99 |
100 | clf = torch.load(clf)
101 | weight = clf.coef_.squeeze()
102 | for i in range(len(weight)):
103 | subs.append((attr, obj, i))
104 | vals.append(weight[i])
105 |
106 | for attr, obj in dataset.pairs:
107 | for i in range(X.shape[1]):
108 | fullsubs.append((attr, obj, i))
109 |
110 | subs, vals = np.array(subs), np.array(vals).reshape(-1,1)
111 | fullsubs, fullvals = np.array(fullsubs), np.ones(len(fullsubs)).reshape(-1,1)
112 | savedat = {'subs':subs, 'vals':vals, 'size':size, 'fullsubs':fullsubs, 'fullvals':fullvals}
113 | scipy.io.savemat('tensor-completion/incomplete/%s.mat'%args.dataset, savedat)
114 | print (subs.shape, vals.shape, size)
115 |
116 | def evaluate_svms():
117 |
118 | attr_clfs = [torch.load('%s/svm/attr_%d'%(args.data_dir, attr)) for attr in range(len(dataset.attrs))]
119 | obj_clfs = [torch.load('%s/svm/obj_%d'%(args.data_dir, obj)) for obj in range(len(dataset.objs))]
120 |
121 | # Calibrate all classifiers first
122 |     Y = [(train_attrs==attr).astype(int) for attr in range(len(dataset.attrs))]
123 | for attr in tqdm.tqdm(range(len(dataset.attrs))):
124 | clf = attr_clfs[attr]
125 | calibrated = CalibratedClassifierCV(clf, method='sigmoid', cv='prefit')
126 | calibrated.fit(X_train, Y[attr])
127 | attr_clfs[attr] = calibrated
128 |
129 |     Y = [(train_objs==obj).astype(int) for obj in range(len(dataset.objs))]
130 | for obj in tqdm.tqdm(range(len(dataset.objs))):
131 | clf = obj_clfs[obj]
132 | calibrated = CalibratedClassifierCV(clf, method='sigmoid', cv='prefit')
133 | calibrated.fit(X_train, Y[obj])
134 | obj_clfs[obj] = calibrated
135 |
136 | # Generate all the scores
137 | attr_scores, obj_scores = [], []
138 | for attr in tqdm.tqdm(range(len(dataset.attrs))):
139 | clf = attr_clfs[attr]
140 | score = clf.predict_proba(X_test)[:,1]
141 | attr_scores.append(score)
142 | attr_scores = np.vstack(attr_scores)
143 |
144 | for obj in tqdm.tqdm(range(len(dataset.objs))):
145 | clf = obj_clfs[obj]
146 | score = clf.predict_proba(X_test)[:,1]
147 | obj_scores.append(score)
148 | obj_scores = np.vstack(obj_scores)
149 |
150 | attr_pred = torch.from_numpy(attr_scores).transpose(0,1)
151 | obj_pred = torch.from_numpy(obj_scores).transpose(0,1)
152 |
153 | x = [None, Variable(torch.from_numpy(test_attrs)).long(), Variable(torch.from_numpy(test_objs)).long(), Variable(torch.from_numpy(test_pairs)).long()]
154 | attr_pred, obj_pred, _ = utils.generate_prediction_tensors([attr_pred, obj_pred], dataset, x[2].data, source='classification')
155 | attr_match, obj_match, zsl_match, gzsl_match, fixobj_match = utils.performance_stats(attr_pred, obj_pred, x)
156 | print (attr_match.mean(), obj_match.mean(), zsl_match.mean(), gzsl_match.mean(), fixobj_match.mean())
157 |
158 | def evaluate_tensorcompletion():
159 |
160 | def parse_tensor(fl):
161 | tensor = scipy.io.loadmat(fl)
162 |         nz_idx = list(zip(*(tensor['subs'])))
163 | composite_clfs = np.zeros((len(dataset.attrs), len(dataset.objs), X.shape[1]))
164 | composite_clfs[nz_idx[0], nz_idx[1], nz_idx[2]] = tensor['vals'].squeeze()
165 | return composite_clfs, nz_idx, tensor['vals'].squeeze()
166 |
167 | # see recon error
168 | tr_file = 'tensor-completion/incomplete/%s.mat'%args.dataset
169 | ts_file = args.completed
170 |
171 | tr_clfs, tr_nz_idx, tr_vals = parse_tensor(tr_file)
172 | ts_clfs, ts_nz_idx, ts_vals = parse_tensor(ts_file)
173 |
174 | print (tr_vals.min(), tr_vals.max(), tr_vals.mean())
175 | print (ts_vals.min(), ts_vals.max(), ts_vals.mean())
176 |
177 | print ('Completed Tensor: %s'%args.completed)
178 |
179 | # see train recon error
180 | err = 1.0*((tr_clfs[tr_nz_idx[0], tr_nz_idx[1], tr_nz_idx[2]]-ts_clfs[tr_nz_idx[0], tr_nz_idx[1], tr_nz_idx[2]])**2).sum()/(len(tr_vals))
181 | print ('recon error:', err)
182 |
183 | # Create and scale classifiers for each pair
184 | clfs = {}
185 | test_pair_set = set(map(tuple, dataset.test_pairs.numpy().tolist()))
186 | for idx, (attr, obj) in tqdm.tqdm(enumerate(dataset.pairs), total=len(dataset.pairs)):
187 | clf = LinearSVC(fit_intercept=False)
188 | clf.fit(np.eye(2), [0,1])
189 |
190 | if (attr, obj) in test_pair_set:
191 | X_ = X_test
192 |             Y_ = (test_attrs==attr).astype(int)*(test_objs==obj).astype(int)
193 | clf.coef_ = ts_clfs[attr, obj][None,:]
194 | else:
195 | X_ = X_train
196 |             Y_ = (train_attrs==attr).astype(int)*(train_objs==obj).astype(int)
197 | clf.coef_ = tr_clfs[attr, obj][None,:]
198 |
199 | calibrated = CalibratedClassifierCV(clf, method='sigmoid', cv='prefit')
200 | calibrated.fit(X_, Y_)
201 | clfs[(attr, obj)] = calibrated
202 |
203 | scores = {}
204 | for attr, obj in tqdm.tqdm(dataset.pairs):
205 | score = clfs[(attr, obj)].predict_proba(X_test)[:,1]
206 | scores[(attr, obj)] = torch.from_numpy(score).float().unsqueeze(1)
207 |
208 | x = [None, Variable(torch.from_numpy(test_attrs)).long(), Variable(torch.from_numpy(test_objs)).long(), Variable(torch.from_numpy(test_pairs)).long()]
209 | attr_pred, obj_pred, _ = utils.generate_prediction_tensors(scores, dataset, x[2].data, source='manifold')
210 | attr_match, obj_match, zsl_match, gzsl_match, fixobj_match = utils.performance_stats(attr_pred, obj_pred, x)
211 | print (attr_match.mean(), obj_match.mean(), zsl_match.mean(), gzsl_match.mean(), fixobj_match.mean())
212 |
213 |
214 |
215 | #----------------------------------------------------------------------------------------#
216 | if args.dataset == 'mitstates':
217 | DSet = dset.MITStatesActivations
218 | elif args.dataset == 'zappos':
219 | DSet = dset.UTZapposActivations
220 |
221 | dataset = DSet(root=args.data_dir, phase='train')
222 | train_idx, train_attrs, train_objs, train_pairs = map(np.array, zip(*dataset.train_data))
223 | test_idx, test_attrs, test_objs, test_pairs = map(np.array, zip(*dataset.test_data))
224 |
225 | X = dataset.activations.numpy()
226 | X_train, X_test = X[train_idx,:], X[test_idx,:]
227 |
228 | print (len(dataset.attrs), len(dataset.objs), len(dataset.pairs))
229 | print (X_train.shape, X_test.shape)
230 |
231 | if args.generate:
232 | generate_svms()
233 | make_svm_tensor()
234 |
235 | if args.evalsvm:
236 | evaluate_svms()
237 |
238 | if args.evaltf:
239 | evaluate_tensorcompletion()
240 |
--------------------------------------------------------------------------------
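Note on svm.py above: the script trains one-vs-rest LinearSVCs for every attribute, object and training pair (grid-searching C), and later sigmoid-calibrates them with `CalibratedClassifierCV` before combining the probabilities. A minimal self-contained sketch of that train-then-calibrate pattern on synthetic features (array names and sizes are made up for illustration):

    import numpy as np
    from sklearn.svm import LinearSVC
    from sklearn.model_selection import GridSearchCV
    from sklearn.calibration import CalibratedClassifierCV

    rng = np.random.RandomState(0)
    X_train = rng.randn(200, 32)                 # stand-in for image activations
    y_train = (X_train[:, 0] > 0).astype(int)    # stand-in label, e.g. "has attribute k"

    # grid-search the margin penalty, as in train_svm
    search = GridSearchCV(LinearSVC(class_weight='balanced', fit_intercept=False),
                          {'C': np.logspace(-5, 5, 11)}, scoring='f1', cv=4)
    search.fit(X_train, y_train)

    # turn the decision function into probabilities, as in evaluate_svms
    calibrated = CalibratedClassifierCV(search.best_estimator_, method='sigmoid', cv='prefit')
    calibrated.fit(X_train, y_train)
    print(search.best_params_, calibrated.predict_proba(X_train)[:3, 1])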
/models/symnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from .word_embedding import load_word_embeddings
6 | from .common import MLP
7 |
8 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
9 |
10 |
11 |
12 | class Symnet(nn.Module):
13 | def __init__(self, dset, args):
14 | super(Symnet, self).__init__()
15 | self.dset = dset
16 | self.args = args
17 | self.num_attrs = len(dset.attrs)
18 | self.num_objs = len(dset.objs)
19 |
20 | self.image_embedder = MLP(dset.feat_dim, args.emb_dim, relu = False)
21 | self.ce_loss = nn.CrossEntropyLoss()
22 | self.mse_loss = nn.MSELoss()
23 | self.nll_loss = nn.NLLLoss()
24 | self.softmax = nn.Softmax(dim=1)
25 | # Attribute embedder and object classifier
26 | self.attr_embedder = nn.Embedding(len(dset.attrs), args.emb_dim)
27 |
28 | self.obj_classifier = MLP(args.emb_dim, len(dset.objs), num_layers = 2, relu = False, dropout = True, layers = [512]) # original paper uses a second fc classifier for final product
29 | self.attr_classifier = MLP(args.emb_dim, len(dset.attrs), num_layers = 2, relu = False, dropout = True, layers = [512])
30 | # CoN and DecoN
31 | self.CoN_fc_attention = MLP(args.emb_dim, args.emb_dim, num_layers = 2, relu = False, dropout = True, layers = [512])
32 | self.CoN_emb = MLP(args.emb_dim + args.emb_dim, args.emb_dim, num_layers = 2, relu = False, dropout = True, layers = [768])
33 | self.DecoN_fc_attention = MLP(args.emb_dim, args.emb_dim, num_layers = 2, relu = False, dropout = True, layers = [512])
34 | self.DecoN_emb = MLP(args.emb_dim + args.emb_dim, args.emb_dim, num_layers = 2, relu = False, dropout = True, layers = [768])
35 |
36 | # if args.glove_init:
37 | pretrained_weight = load_word_embeddings(args.emb_init, dset.attrs)
38 | self.attr_embedder.weight.data.copy_(pretrained_weight)
39 |
40 | for param in self.attr_embedder.parameters():
41 | param.requires_grad = False
42 |
43 | def CoN(self, img_embedding, attr_embedding):
44 | attention = torch.sigmoid(self.CoN_fc_attention(attr_embedding))
45 | img_embedding = attention*img_embedding + img_embedding
46 |
47 | hidden = torch.cat([img_embedding, attr_embedding], dim = 1)
48 | output = self.CoN_emb(hidden)
49 | return output
50 |
51 | def DeCoN(self, img_embedding, attr_embedding):
52 | attention = torch.sigmoid(self.DecoN_fc_attention(attr_embedding))
53 | img_embedding = attention*img_embedding + img_embedding
54 |
55 | hidden = torch.cat([img_embedding, attr_embedding], dim = 1)
56 | output = self.DecoN_emb(hidden)
57 | return output
58 |
59 |
60 | def distance_metric(self, a, b):
61 | return torch.norm(a-b, dim = -1)
62 |
63 | def RMD_prob(self, feat_plus, feat_minus, repeat_img_feat):
64 | """return attribute classification probability with our RMD"""
65 | # feat_plus, feat_minus: shape=(bz, #attr, dim_emb)
66 | # d_plus: distance between feature before&after CoN
67 | # d_minus: distance between feature before&after DecoN
68 | d_plus = self.distance_metric(feat_plus, repeat_img_feat).reshape(-1, self.num_attrs)
69 | d_minus = self.distance_metric(feat_minus, repeat_img_feat).reshape(-1, self.num_attrs)
70 | # not adding softmax because it is part of cross entropy loss
71 | return d_minus - d_plus
72 |
73 | def train_forward(self, x):
74 | pos_image_feat, pos_attr_id, pos_obj_id = x[0], x[1], x[2]
75 | neg_attr_id = x[4][:,0]
76 |
77 | batch_size = pos_image_feat.size(0)
78 |
79 | loss = []
80 | pos_attr_emb = self.attr_embedder(pos_attr_id)
81 | neg_attr_emb = self.attr_embedder(neg_attr_id)
82 |
83 | pos_img = self.image_embedder(pos_image_feat)
84 |
85 | # rA = remove positive attribute A
86 | # aA = add positive attribute A
87 | # rB = remove negative attribute B
88 | # aB = add negative attribute B
89 | pos_rA = self.DeCoN(pos_img, pos_attr_emb)
90 | pos_aA = self.CoN(pos_img, pos_attr_emb)
91 | pos_rB = self.DeCoN(pos_img, neg_attr_emb)
92 | pos_aB = self.CoN(pos_img, neg_attr_emb)
93 |
94 | # get all attr embedding #attr, embedding
95 | attr_emb = torch.LongTensor(np.arange(self.num_attrs)).to(device)
96 | attr_emb = self.attr_embedder(attr_emb)
97 | tile_attr_emb = attr_emb.repeat(batch_size, 1) # (batch*attr, dim_emb)
98 |
99 | # Now we calculate all the losses
100 | if self.args.lambda_cls_attr > 0:
101 | # Original image
102 | score_pos_A = self.attr_classifier(pos_img)
103 | loss_cls_pos_a = self.ce_loss(score_pos_A, pos_attr_id)
104 |
105 | # After removing pos_attr
106 | score_pos_rA_A = self.attr_classifier(pos_rA)
107 | total = sum(score_pos_rA_A)
108 | # prob_pos_rA_A = 1 - self.softmax(score_pos_rA_A)
109 | # loss_cls_pos_rA_a = self.nll_loss(torch.log(prob_pos_rA_A), pos_attr_id)
110 | loss_cls_pos_rA_a = self.ce_loss(total - score_pos_rA_A, pos_attr_id) #should be maximum for the gt label
111 |
112 | # rmd time
113 | repeat_img_feat = torch.repeat_interleave(pos_img, self.num_attrs, 0) #(batch*attr, dim_rep)
114 | feat_plus = self.CoN(repeat_img_feat, tile_attr_emb)
115 | feat_minus = self.DeCoN(repeat_img_feat, tile_attr_emb)
116 | score_cls_rmd = self.RMD_prob(feat_plus, feat_minus, repeat_img_feat)
117 | loss_cls_rmd = self.ce_loss(score_cls_rmd, pos_attr_id)
118 |
119 | loss_cls_attr = self.args.lambda_cls_attr*sum([loss_cls_pos_a, loss_cls_pos_rA_a, loss_cls_rmd])
120 | loss.append(loss_cls_attr)
121 |
122 | if self.args.lambda_cls_obj > 0:
123 | # Original image
124 | score_pos_O = self.obj_classifier(pos_img)
125 | loss_cls_pos_o = self.ce_loss(score_pos_O, pos_obj_id)
126 |
127 | # After removing pos attr
128 | score_pos_rA_O = self.obj_classifier(pos_rA)
129 | loss_cls_pos_rA_o = self.ce_loss(score_pos_rA_O, pos_obj_id)
130 |
131 | # After adding neg attr
132 | score_pos_aB_O = self.obj_classifier(pos_aB)
133 | loss_cls_pos_aB_o = self.ce_loss(score_pos_aB_O, pos_obj_id)
134 |
135 | loss_cls_obj = self.args.lambda_cls_obj * sum([loss_cls_pos_o, loss_cls_pos_rA_o, loss_cls_pos_aB_o])
136 |
137 | loss.append(loss_cls_obj)
138 |
139 | if self.args.lambda_sym > 0:
140 |
141 | loss_sys_pos = self.mse_loss(pos_aA, pos_img)
142 | loss_sys_neg = self.mse_loss(pos_rB, pos_img)
143 | loss_sym = self.args.lambda_sym * (loss_sys_pos + loss_sys_neg)
144 |
145 | loss.append(loss_sym)
146 |
147 | ##### Axiom losses
148 | if self.args.lambda_axiom > 0:
149 | loss_clo = loss_inv = loss_com = 0
150 | # closure
151 | pos_aA_rA = self.DeCoN(pos_aA, pos_attr_emb)
152 | pos_rB_aB = self.CoN(pos_rB, neg_attr_emb)
153 | loss_clo = self.mse_loss(pos_aA_rA, pos_rA) + self.mse_loss(pos_rB_aB, pos_aB)
154 |
155 | # invertibility
156 | pos_rA_aA = self.CoN(pos_rA, pos_attr_emb)
157 | pos_aB_rB = self.DeCoN(pos_aB, neg_attr_emb)
158 | loss_inv = self.mse_loss(pos_rA_aA, pos_img) + self.mse_loss(pos_aB_rB, pos_img)
159 |
160 | # commutative
161 | pos_aA_rB = self.DeCoN(pos_aA, neg_attr_emb)
162 | pos_rB_aA = self.DeCoN(pos_rB, pos_attr_emb)
163 | loss_com = self.mse_loss(pos_aA_rB, pos_rB_aA)
164 |
165 | loss_axiom = self.args.lambda_axiom * (loss_clo + loss_inv + loss_com)
166 | loss.append(loss_axiom)
167 |
168 | # triplet loss
169 | if self.args.lambda_trip > 0:
170 | pos_triplet = F.triplet_margin_loss(pos_img, pos_aA, pos_rA)
171 | neg_triplet = F.triplet_margin_loss(pos_img, pos_rB, pos_aB)
172 |
173 | loss_triplet = self.args.lambda_trip * (pos_triplet + neg_triplet)
174 | loss.append(loss_triplet)
175 |
176 |
177 | loss = sum(loss)
178 | return loss, None
179 |
180 | def val_forward(self, x):
181 | pos_image_feat, pos_attr_id, pos_obj_id = x[0], x[1], x[2]
182 | batch_size = pos_image_feat.shape[0]
183 | pos_img = self.image_embedder(pos_image_feat)
184 | repeat_img_feat = torch.repeat_interleave(pos_img, self.num_attrs, 0) #(batch*attr, dim_rep)
185 |
186 | # get all attr embedding #attr, embedding
187 | attr_emb = torch.LongTensor(np.arange(self.num_attrs)).to(device)
188 | attr_emb = self.attr_embedder(attr_emb)
189 | tile_attr_emb = attr_emb.repeat(batch_size, 1) # (batch*attr, dim_emb)
190 |
191 | feat_plus = self.CoN(repeat_img_feat, tile_attr_emb)
192 | feat_minus = self.DeCoN(repeat_img_feat, tile_attr_emb)
193 | score_cls_rmd = self.RMD_prob(feat_plus, feat_minus, repeat_img_feat)
194 | prob_A_rmd = F.softmax(score_cls_rmd, dim = 1)
195 |
196 | score_obj = self.obj_classifier(pos_img)
197 | prob_O = F.softmax(score_obj, dim = 1)
198 |
199 | scores = {}
200 | for itr, (attr, obj) in enumerate(self.dset.pairs):
201 | attr_id, obj_id = self.dset.attr2idx[attr], self.dset.obj2idx[obj]
202 | score = prob_A_rmd[:,attr_id] * prob_O[:, obj_id]
203 |
204 | scores[(attr, obj)] = score
205 |
206 | return None, scores
207 |
208 | def forward(self, x):
209 | if self.training:
210 | loss, pred = self.train_forward(x)
211 | else:
212 | with torch.no_grad():
213 | loss, pred = self.val_forward(x)
214 | return loss, pred
--------------------------------------------------------------------------------
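Note on symnet.py above: the relative moving distance (RMD) in `RMD_prob` scores an attribute by how much further the remove network (DeCoN) moves the image embedding than the add network (CoN) does, i.e. score = d_minus - d_plus; an attribute that is already present should barely move under CoN and move a lot under DeCoN. A tiny numeric sketch of that scoring rule, with all tensors made up for illustration:

    import torch

    img = torch.tensor([[1.0, 0.0]])                        # one 2-D image embedding
    # pretend CoN / DeCoN outputs for 3 candidate attributes, shape (1, #attr, dim)
    feat_plus = torch.tensor([[[1.0, 0.1], [2.0, 0.0], [0.0, 1.0]]])    # after "adding"
    feat_minus = torch.tensor([[[3.0, 0.0], [1.1, 0.0], [1.0, 0.2]]])   # after "removing"

    repeat_img = img.unsqueeze(1).expand_as(feat_plus)
    d_plus = torch.norm(feat_plus - repeat_img, dim=-1)     # movement caused by CoN
    d_minus = torch.norm(feat_minus - repeat_img, dim=-1)   # movement caused by DeCoN
    rmd_score = d_minus - d_plus                            # higher = attribute more likely present
    print(rmd_score)    # attribute 0 wins in this toy example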
/models/visual_product.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torchvision.models as models
5 | import numpy as np
6 | from .common import MLP
7 |
8 | class VisualProductNN(nn.Module):
9 | def __init__(self, dset, args):
10 | super(VisualProductNN, self).__init__()
11 | self.attr_clf = MLP(dset.feat_dim, len(dset.attrs), 2, relu = False)
12 | self.obj_clf = MLP(dset.feat_dim, len(dset.objs), 2, relu = False)
13 | self.dset = dset
14 |
15 | def train_forward(self, x):
16 | img, attrs, objs = x[0],x[1], x[2]
17 |
18 | attr_pred = self.attr_clf(img)
19 | obj_pred = self.obj_clf(img)
20 |
21 | attr_loss = F.cross_entropy(attr_pred, attrs)
22 | obj_loss = F.cross_entropy(obj_pred, objs)
23 |
24 | loss = attr_loss + obj_loss
25 |
26 | return loss, None
27 |
28 | def val_forward(self, x):
29 | img = x[0]
30 | attr_pred = F.softmax(self.attr_clf(img), dim =1)
31 | obj_pred = F.softmax(self.obj_clf(img), dim = 1)
32 |
33 | scores = {}
34 | for itr, (attr, obj) in enumerate(self.dset.pairs):
35 | attr_id, obj_id = self.dset.attr2idx[attr], self.dset.obj2idx[obj]
36 | score = attr_pred[:,attr_id] * obj_pred[:, obj_id]
37 |
38 | scores[(attr, obj)] = score
39 | return None, scores
40 |
41 | def forward(self, x):
42 | if self.training:
43 | loss, pred = self.train_forward(x)
44 | else:
45 | with torch.no_grad():
46 | loss, pred = self.val_forward(x)
47 |
48 | return loss, pred
--------------------------------------------------------------------------------
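Note on visual_product.py above: the pair score is simply the product of the independent attribute and object softmax probabilities. A tiny numeric sketch of that scoring step (the probabilities are made up):

    import torch

    attr_probs = torch.tensor([[0.7, 0.2, 0.1]])    # softmax over 3 attributes, one image
    obj_probs = torch.tensor([[0.6, 0.4]])          # softmax over 2 objects
    pairs = [(0, 0), (0, 1), (2, 1)]                # (attr_id, obj_id) pairs in the split

    scores = {p: attr_probs[:, p[0]] * obj_probs[:, p[1]] for p in pairs}
    print(scores)    # e.g. pair (0, 0) -> 0.7 * 0.6 = 0.42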
/models/word_embedding.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from flags import DATA_FOLDER
4 | def load_word_embeddings(emb_type, vocab):
5 | if emb_type == 'glove':
6 | embeds = load_glove_embeddings(vocab)
7 | elif emb_type == 'fasttext':
8 | embeds = load_fasttext_embeddings(vocab)
9 | elif emb_type == 'word2vec':
10 | embeds = load_word2vec_embeddings(vocab)
11 | elif emb_type == 'ft+w2v':
12 | embeds1 = load_fasttext_embeddings(vocab)
13 | embeds2 = load_word2vec_embeddings(vocab)
14 | embeds = torch.cat([embeds1, embeds2], dim = 1)
15 | print('Combined embeddings are ',embeds.shape)
16 | elif emb_type == 'ft+gl':
17 | embeds1 = load_fasttext_embeddings(vocab)
18 | embeds2 = load_glove_embeddings(vocab)
19 | embeds = torch.cat([embeds1, embeds2], dim = 1)
20 | print('Combined embeddings are ',embeds.shape)
21 | elif emb_type == 'ft+ft':
22 | embeds1 = load_fasttext_embeddings(vocab)
23 | embeds = torch.cat([embeds1, embeds1], dim = 1)
24 | print('Combined embeddings are ',embeds.shape)
25 | elif emb_type == 'gl+w2v':
26 | embeds1 = load_glove_embeddings(vocab)
27 | embeds2 = load_word2vec_embeddings(vocab)
28 | embeds = torch.cat([embeds1, embeds2], dim = 1)
29 | print('Combined embeddings are ',embeds.shape)
30 | elif emb_type == 'ft+w2v+gl':
31 | embeds1 = load_fasttext_embeddings(vocab)
32 | embeds2 = load_word2vec_embeddings(vocab)
33 | embeds3 = load_glove_embeddings(vocab)
34 | embeds = torch.cat([embeds1, embeds2, embeds3], dim = 1)
35 | print('Combined embeddings are ',embeds.shape)
36 | else:
37 | raise ValueError('Invalid embedding')
38 | return embeds
39 |
40 | def load_fasttext_embeddings(vocab):
41 | custom_map = {
42 | 'Faux.Fur': 'fake fur',
43 | 'Faux.Leather': 'fake leather',
44 | 'Full.grain.leather': 'thick leather',
45 | 'Hair.Calf': 'hairy leather',
46 | 'Patent.Leather': 'shiny leather',
47 | 'Boots.Ankle': 'ankle boots',
48 | 'Boots.Knee.High': 'kneehigh boots',
49 | 'Boots.Mid-Calf': 'midcalf boots',
50 | 'Shoes.Boat.Shoes': 'boatshoes',
51 | 'Shoes.Clogs.and.Mules': 'clogs shoes',
52 | 'Shoes.Flats': 'flats shoes',
53 | 'Shoes.Heels': 'heels',
54 | 'Shoes.Loafers': 'loafers',
55 | 'Shoes.Oxfords': 'oxford shoes',
56 | 'Shoes.Sneakers.and.Athletic.Shoes': 'sneakers',
57 | 'traffic_light': 'traficlight',
58 | 'trash_can': 'trashcan',
59 | 'dry-erase_board' : 'dry_erase_board',
60 | 'black_and_white' : 'black_white',
61 | 'eiffel_tower' : 'tower'
62 | }
63 | vocab_lower = [v.lower() for v in vocab]
64 | vocab = []
65 | for current in vocab_lower:
66 | if current in custom_map:
67 | vocab.append(custom_map[current])
68 | else:
69 | vocab.append(current)
70 |
71 | import fasttext.util
72 | ft = fasttext.load_model(DATA_FOLDER+'/fast/cc.en.300.bin')
73 | embeds = []
74 | for k in vocab:
75 | if '_' in k:
76 | ks = k.split('_')
77 | emb = np.stack([ft.get_word_vector(it) for it in ks]).mean(axis=0)
78 | else:
79 | emb = ft.get_word_vector(k)
80 | embeds.append(emb)
81 |
82 | embeds = torch.Tensor(np.stack(embeds))
83 | print('Fasttext Embeddings loaded, total embeddings: {}'.format(embeds.size()))
84 | return embeds
85 |
86 | def load_word2vec_embeddings(vocab):
87 | # vocab = [v.lower() for v in vocab]
88 |
89 | from gensim import models
90 | model = models.KeyedVectors.load_word2vec_format(
91 | DATA_FOLDER+'/w2v/GoogleNews-vectors-negative300.bin', binary=True)
92 |
93 | custom_map = {
94 | 'Faux.Fur': 'fake_fur',
95 | 'Faux.Leather': 'fake_leather',
96 | 'Full.grain.leather': 'thick_leather',
97 | 'Hair.Calf': 'hair_leather',
98 | 'Patent.Leather': 'shiny_leather',
99 | 'Boots.Ankle': 'ankle_boots',
100 | 'Boots.Knee.High': 'knee_high_boots',
101 | 'Boots.Mid-Calf': 'midcalf_boots',
102 | 'Shoes.Boat.Shoes': 'boat_shoes',
103 | 'Shoes.Clogs.and.Mules': 'clogs_shoes',
104 | 'Shoes.Flats': 'flats_shoes',
105 | 'Shoes.Heels': 'heels',
106 | 'Shoes.Loafers': 'loafers',
107 | 'Shoes.Oxfords': 'oxford_shoes',
108 | 'Shoes.Sneakers.and.Athletic.Shoes': 'sneakers',
109 | 'traffic_light': 'traffic_light',
110 | 'trash_can': 'trashcan',
111 | 'dry-erase_board' : 'dry_erase_board',
112 | 'black_and_white' : 'black_white',
113 | 'eiffel_tower' : 'tower'
114 | }
115 |
116 | embeds = []
117 | for k in vocab:
118 | if k in custom_map:
119 | k = custom_map[k]
120 | if '_' in k and k not in model:
121 | ks = k.split('_')
122 | emb = np.stack([model[it] for it in ks]).mean(axis=0)
123 | else:
124 | emb = model[k]
125 | embeds.append(emb)
126 | embeds = torch.Tensor(np.stack(embeds))
127 | print('Word2Vec Embeddings loaded, total embeddings: {}'.format(embeds.size()))
128 | return embeds
129 |
130 | def load_glove_embeddings(vocab):
131 | '''
132 | Inputs
133 | emb_file: Text file with word embedding pairs e.g. Glove, Processed in lower case.
134 | vocab: List of words
135 | Returns
136 | Embedding Matrix
137 | '''
138 | vocab = [v.lower() for v in vocab]
139 | emb_file = DATA_FOLDER+'/glove/glove.6B.300d.txt'
140 | model = {} # populating a dictionary of word and embeddings
141 | for line in open(emb_file, 'r'):
142 | line = line.strip().split(' ') # Word-embedding
143 | wvec = torch.FloatTensor(list(map(float, line[1:])))
144 | model[line[0]] = wvec
145 |
146 | # Adding some vectors for UT Zappos
147 | custom_map = {
148 | 'faux.fur': 'fake_fur',
149 | 'faux.leather': 'fake_leather',
150 | 'full.grain.leather': 'thick_leather',
151 | 'hair.calf': 'hair_leather',
152 | 'patent.leather': 'shiny_leather',
153 | 'boots.ankle': 'ankle_boots',
154 | 'boots.knee.high': 'knee_high_boots',
155 | 'boots.mid-calf': 'midcalf_boots',
156 | 'shoes.boat.shoes': 'boat_shoes',
157 | 'shoes.clogs.and.mules': 'clogs_shoes',
158 | 'shoes.flats': 'flats_shoes',
159 | 'shoes.heels': 'heels',
160 | 'shoes.loafers': 'loafers',
161 | 'shoes.oxfords': 'oxford_shoes',
162 | 'shoes.sneakers.and.athletic.shoes': 'sneakers',
163 | 'traffic_light': 'traffic_light',
164 | 'trash_can': 'trashcan',
165 | 'dry-erase_board' : 'dry_erase_board',
166 | 'black_and_white' : 'black_white',
167 | 'eiffel_tower' : 'tower',
168 | 'nubuck' : 'grainy_leather',
169 | }
170 |
171 | embeds = []
172 | for k in vocab:
173 | if k in custom_map:
174 | k = custom_map[k]
175 | if '_' in k:
176 | ks = k.split('_')
177 | emb = torch.stack([model[it] for it in ks]).mean(dim=0)
178 | else:
179 | emb = model[k]
180 | embeds.append(emb)
181 | embeds = torch.stack(embeds)
182 | print('Glove Embeddings loaded, total embeddings: {}'.format(embeds.size()))
183 | return embeds
--------------------------------------------------------------------------------
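Note on word_embedding.py above: `load_word_embeddings` dispatches on the `emb_type` string and, for composite types such as 'gl+w2v' or 'ft+w2v+gl', concatenates the individual 300-d embeddings along the feature dimension (600-d, 900-d, ...). A usage sketch, assuming the repository root is on the Python path and the GloVe file is present under DATA_FOLDER/glove/ as expected by `load_glove_embeddings`:

    from models.word_embedding import load_word_embeddings

    vocab = ['red', 'ancient', 'tower']           # real callers pass dset.attrs or dset.objs
    glove = load_word_embeddings('glove', vocab)  # one 300-d vector per word
    print(glove.shape)                            # torch.Size([3, 300])

    # 'gl+w2v' would give (3, 600) instead, provided the word2vec binary
    # is also available under DATA_FOLDER/w2v/.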
/notebooks/analysis.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import sys\n",
10 | "sys.path.append('../')\n",
11 | "# Torch imports\n",
12 | "import torch\n",
13 | "import torch.nn as nn\n",
14 | "import torch.optim as optim\n",
15 | "import torch.nn.functional as F\n",
16 | "from torch.utils.tensorboard import SummaryWriter\n",
17 | "import torch.backends.cudnn as cudnn\n",
18 | "cudnn.benchmark = True\n",
19 | "\n",
20 | "# Python imports\n",
21 | "import numpy as np\n",
22 | "import tqdm\n",
23 | "import torchvision.models as tmodels\n",
24 | "from tqdm import tqdm\n",
25 | "import os\n",
26 | "from os.path import join as ospj\n",
27 | "import itertools\n",
28 | "import glob\n",
29 | "import random\n",
30 | "\n",
31 | "#Local imports\n",
32 | "from data import dataset as dset\n",
33 | "from models.common import Evaluator\n",
34 | "from models.image_extractor import get_image_extractor\n",
35 | "from models.manifold_methods import RedWine, LabelEmbedPlus, AttributeOperator\n",
36 | "from models.modular_methods import GatedGeneralNN\n",
37 | "from models.symnet import Symnet\n",
38 | "from utils.utils import save_args, UnNormalizer, load_args\n",
39 | "from utils.config_model import configure_model\n",
40 | "from flags import parser\n",
41 | "from PIL import Image\n",
42 | "import matplotlib.pyplot as plt\n",
43 | "import importlib\n",
44 | "\n",
45 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
46 | "args, unknown = parser.parse_known_args()"
47 | ]
48 | },
49 | {
50 | "cell_type": "markdown",
51 | "metadata": {},
52 | "source": [
53 | "### Run one of the cells to load the dataset you want to run test for and move to the next section"
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": null,
59 | "metadata": {},
60 | "outputs": [],
61 | "source": [
62 | "best_mit = '../logs/graphembed/mitstates/base/mit.yml'\n",
63 | "load_args(best_mit,args)\n",
64 | "args.graph_init = '../'+args.graph_init\n",
65 | "args.load = best_mit[:-7] + 'ckpt_best_auc.t7'"
66 | ]
67 | },
68 | {
69 | "cell_type": "code",
70 | "execution_count": null,
71 | "metadata": {},
72 | "outputs": [],
73 | "source": [
74 | "best_ut = '../logs/graphembed/utzappos/base/utzappos.yml'\n",
75 | "load_args(best_ut,args)\n",
76 | "args.graph_init = '../'+args.graph_init\n",
77 | "args.load = best_ut[:-12] + 'ckpt_best_auc.t7'"
78 | ]
79 | },
80 | {
81 | "cell_type": "markdown",
82 | "metadata": {},
83 | "source": [
84 | "### Loading arguments and dataset"
85 | ]
86 | },
87 | {
88 | "cell_type": "code",
89 | "execution_count": null,
90 | "metadata": {},
91 | "outputs": [],
92 | "source": [
93 | "args.data_dir = '../'+args.data_dir\n",
94 | "args.test_set = 'test'\n",
95 | "testset = dset.CompositionDataset(\n",
96 | " root= args.data_dir,\n",
97 | " phase=args.test_set,\n",
98 | " split=args.splitname,\n",
99 | " model =args.image_extractor,\n",
100 | " subset=args.subset,\n",
101 | " return_images = True,\n",
102 | " update_features = args.update_features,\n",
103 | " clean_only = args.clean_only\n",
104 | " )\n",
105 | "testloader = torch.utils.data.DataLoader(\n",
106 | " testset,\n",
107 | " batch_size=args.test_batch_size,\n",
108 | " shuffle=True,\n",
109 | " num_workers=args.workers)\n",
110 | "\n",
111 | "print('Objs ', len(testset.objs), ' Attrs ', len(testset.attrs))"
112 | ]
113 | },
114 | {
115 | "cell_type": "code",
116 | "execution_count": null,
117 | "metadata": {},
118 | "outputs": [],
119 | "source": [
120 | "image_extractor, model, optimizer = configure_model(args, testset)\n",
121 | "evaluator = Evaluator(testset, model)"
122 | ]
123 | },
124 | {
125 | "cell_type": "code",
126 | "execution_count": null,
127 | "metadata": {},
128 | "outputs": [],
129 | "source": [
130 | "if args.load is not None:\n",
131 | " checkpoint = torch.load(args.load)\n",
132 | " if image_extractor:\n",
133 | " try:\n",
134 | " image_extractor.load_state_dict(checkpoint['image_extractor'])\n",
135 | " image_extractor.eval()\n",
136 | " except:\n",
137 | " print('No Image extractor in checkpoint')\n",
138 | " model.load_state_dict(checkpoint['net'])\n",
139 | " model.eval()\n",
140 | " print('Loaded model from ', args.load)\n",
141 | " print('Best AUC: ', checkpoint['AUC'])"
142 | ]
143 | },
144 | {
145 | "cell_type": "code",
146 | "execution_count": null,
147 | "metadata": {},
148 | "outputs": [],
149 | "source": [
150 | "def print_results(scores, exp):\n",
151 | " print(exp)\n",
152 | " result = scores[exp]\n",
153 | " attr = [evaluator.dset.attrs[result[0][idx,a]] for a in range(topk)]\n",
154 | " obj = [evaluator.dset.objs[result[1][idx,a]] for a in range(topk)]\n",
155 | " attr_gt, obj_gt = evaluator.dset.attrs[data[1][idx]], evaluator.dset.objs[data[2][idx]]\n",
156 | " print(f'Ground truth: {attr_gt} {obj_gt}')\n",
157 | " prediction = ''\n",
158 | " for a,o in zip(attr, obj):\n",
159 | " prediction += a + ' ' + o + '| '\n",
160 | " print('Predictions: ', prediction)\n",
161 | " print('__'*50)"
162 | ]
163 | },
164 | {
165 | "cell_type": "markdown",
166 | "metadata": {},
167 | "source": [
168 | "### An example of predictions\n",
169 | "closed -> Biased for unseen classes\n",
170 | "\n",
171 | "unbiiased -> Biased against unseen classes"
172 | ]
173 | },
174 | {
175 | "cell_type": "code",
176 | "execution_count": null,
177 | "metadata": {},
178 | "outputs": [],
179 | "source": [
180 | "data = next(iter(testloader))\n",
181 | "images = data[-1]\n",
182 | "data = [d.to(device) for d in data[:-1]]\n",
183 | "if image_extractor:\n",
184 | " data[0] = image_extractor(data[0])\n",
185 | "_, predictions = model(data)\n",
186 | "data = [d.to('cpu') for d in data]\n",
187 | "topk = 5\n",
188 | "results = evaluator.score_model(predictions, data[2], bias = 1000, topk=topk)"
189 | ]
190 | },
191 | {
192 | "cell_type": "code",
193 | "execution_count": null,
194 | "metadata": {
195 | "scrolled": true
196 | },
197 | "outputs": [],
198 | "source": [
199 | "for idx in range(len(images)):\n",
200 | " seen = bool(evaluator.seen_mask[data[3][idx]])\n",
201 | " if seen:\n",
202 | " continue\n",
203 | " image = Image.open(ospj( args.data_dir,'images', images[idx]))\n",
204 | " \n",
205 | " plt.figure(dpi=300)\n",
206 | " plt.imshow(image)\n",
207 | " plt.axis('off')\n",
208 | " plt.show()\n",
209 | " \n",
210 | " print(f'GT pair seen: {seen}')\n",
211 | " print_results(results, 'closed')\n",
212 | " print_results(results, 'unbiased_closed')"
213 | ]
214 | },
215 | {
216 | "cell_type": "markdown",
217 | "metadata": {},
218 | "source": [
219 | "### Run Evaluation"
220 | ]
221 | },
222 | {
223 | "cell_type": "code",
224 | "execution_count": null,
225 | "metadata": {},
226 | "outputs": [],
227 | "source": [
228 | "model.eval()\n",
229 | "args.bias = 1e3\n",
230 | "accuracies, all_attr_gt, all_obj_gt, all_pair_gt, all_pred = [], [], [], [], []\n",
231 | "\n",
232 | "for idx, data in tqdm(enumerate(testloader), total=len(testloader), desc = 'Testing'):\n",
233 | " data.pop()\n",
234 | " data = [d.to(device) for d in data]\n",
235 | " if image_extractor:\n",
236 | " data[0] = image_extractor(data[0])\n",
237 | "\n",
238 | " _, predictions = model(data) # todo: Unify outputs across models\n",
239 | "\n",
240 | " attr_truth, obj_truth, pair_truth = data[1], data[2], data[3]\n",
241 | " all_pred.append(predictions)\n",
242 | " all_attr_gt.append(attr_truth)\n",
243 | " all_obj_gt.append(obj_truth)\n",
244 | " all_pair_gt.append(pair_truth)\n",
245 | "\n",
246 | "all_attr_gt, all_obj_gt, all_pair_gt = torch.cat(all_attr_gt), torch.cat(all_obj_gt), torch.cat(all_pair_gt)\n",
247 | "\n",
248 | "all_pred_dict = {}\n",
249 | "# Gather values as dict of (attr, obj) as key and list of predictions as values\n",
250 | "for k in all_pred[0].keys():\n",
251 | " all_pred_dict[k] = torch.cat(\n",
252 | " [all_pred[i][k] for i in range(len(all_pred))])\n",
253 | "\n",
254 | "# Calculate best unseen accuracy\n",
255 | "attr_truth, obj_truth = all_attr_gt.to('cpu'), all_obj_gt.to('cpu')\n",
256 | "pairs = list(\n",
257 | " zip(list(attr_truth.numpy()), list(obj_truth.numpy())))"
258 | ]
259 | },
260 | {
261 | "cell_type": "code",
262 | "execution_count": null,
263 | "metadata": {},
264 | "outputs": [],
265 | "source": [
266 |     "topk = 1  # for top-k results\n",
267 | "our_results = evaluator.score_model(all_pred_dict, all_obj_gt, bias = 1e3, topk = topk)\n",
268 | "stats = evaluator.evaluate_predictions(our_results, all_attr_gt, all_obj_gt, all_pair_gt, all_pred_dict, topk = topk)"
269 | ]
270 | },
271 | {
272 | "cell_type": "code",
273 | "execution_count": null,
274 | "metadata": {},
275 | "outputs": [],
276 | "source": [
277 | "for k, v in stats.items():\n",
278 | " print(k, v)"
279 | ]
280 | },
281 | {
282 | "cell_type": "code",
283 | "execution_count": null,
284 | "metadata": {},
285 | "outputs": [],
286 | "source": []
287 | }
288 | ],
289 | "metadata": {
290 | "kernelspec": {
291 | "display_name": "Python 3",
292 | "language": "python",
293 | "name": "python3"
294 | },
295 | "language_info": {
296 | "codemirror_mode": {
297 | "name": "ipython",
298 | "version": 3
299 | },
300 | "file_extension": ".py",
301 | "mimetype": "text/x-python",
302 | "name": "python",
303 | "nbconvert_exporter": "python",
304 | "pygments_lexer": "ipython3",
305 | "version": "3.8.5"
306 | }
307 | },
308 | "nbformat": 4,
309 | "nbformat_minor": 4
310 | }
311 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | # Torch imports
2 | import torch
3 | from torch.utils.tensorboard import SummaryWriter
4 | import torch.backends.cudnn as cudnn
5 | import numpy as np
6 | from flags import DATA_FOLDER
7 |
8 | cudnn.benchmark = True
9 |
10 | # Python imports
11 | import tqdm
12 | from tqdm import tqdm
13 | import os
14 | from os.path import join as ospj
15 |
16 | # Local imports
17 | from data import dataset as dset
18 | from models.common import Evaluator
19 | from utils.utils import load_args
20 | from utils.config_model import configure_model
21 | from flags import parser
22 |
23 |
24 |
25 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
26 |
27 | def main():
28 | # Get arguments and start logging
29 | args = parser.parse_args()
30 | logpath = args.logpath
31 | config = [os.path.join(logpath, _) for _ in os.listdir(logpath) if _.endswith('yml')][0]
32 | load_args(config, args)
33 |
34 | # Get dataset
35 | trainset = dset.CompositionDataset(
36 | root=os.path.join(DATA_FOLDER,args.data_dir),
37 | phase='train',
38 | split=args.splitname,
39 | model=args.image_extractor,
40 | update_features=args.update_features,
41 | train_only=args.train_only,
42 | subset=args.subset,
43 | open_world=args.open_world
44 | )
45 |
46 | valset = dset.CompositionDataset(
47 | root=os.path.join(DATA_FOLDER,args.data_dir),
48 | phase='val',
49 | split=args.splitname,
50 | model=args.image_extractor,
51 | subset=args.subset,
52 | update_features=args.update_features,
53 | open_world=args.open_world
54 | )
55 |
56 | valoader = torch.utils.data.DataLoader(
57 | valset,
58 | batch_size=args.test_batch_size,
59 | shuffle=False,
60 | num_workers=8)
61 |
62 | testset = dset.CompositionDataset(
63 | root=os.path.join(DATA_FOLDER,args.data_dir),
64 | phase='test',
65 | split=args.splitname,
66 | model =args.image_extractor,
67 | subset=args.subset,
68 | update_features = args.update_features,
69 | open_world=args.open_world
70 | )
71 | testloader = torch.utils.data.DataLoader(
72 | testset,
73 | batch_size=args.test_batch_size,
74 | shuffle=False,
75 | num_workers=args.workers)
76 |
77 |
78 | # Get model and optimizer
79 | image_extractor, model, optimizer = configure_model(args, trainset)
80 | args.extractor = image_extractor
81 |
82 |
83 | args.load = ospj(logpath,'ckpt_best_auc.t7')
84 |
85 | checkpoint = torch.load(args.load)
86 | if image_extractor:
87 | try:
88 | image_extractor.load_state_dict(checkpoint['image_extractor'])
89 | image_extractor.eval()
90 | except:
91 | print('No Image extractor in checkpoint')
92 | model.load_state_dict(checkpoint['net'])
93 | model.eval()
94 |
95 | threshold = None
96 | if args.open_world and args.hard_masking:
97 | assert args.model == 'compcos', args.model + ' does not have hard masking.'
98 | if args.threshold is not None:
99 | threshold = args.threshold
100 | else:
101 | evaluator_val = Evaluator(valset, model)
102 | unseen_scores = model.compute_feasibility().to('cpu')
103 | seen_mask = model.seen_mask.to('cpu')
104 | min_feasibility = (unseen_scores+seen_mask*10.).min()
105 | max_feasibility = (unseen_scores-seen_mask*10.).max()
106 | thresholds = np.linspace(min_feasibility,max_feasibility, num=args.threshold_trials)
107 | best_auc = 0.
108 | best_th = -10
109 | with torch.no_grad():
110 | for th in thresholds:
111 | results = test(image_extractor,model,valoader,evaluator_val,args,threshold=th,print_results=False)
112 | auc = results['AUC']
113 | if auc > best_auc:
114 | best_auc = auc
115 | best_th = th
116 | print('New best AUC',best_auc)
117 | print('Threshold',best_th)
118 |
119 | threshold = best_th
120 |
121 | evaluator = Evaluator(testset, model)
122 |
123 | with torch.no_grad():
124 | test(image_extractor, model, testloader, evaluator, args, threshold)
125 |
126 |
127 | def test(image_extractor, model, testloader, evaluator, args, threshold=None, print_results=True):
128 | if image_extractor:
129 | image_extractor.eval()
130 |
131 | model.eval()
132 |
133 | accuracies, all_sub_gt, all_attr_gt, all_obj_gt, all_pair_gt, all_pred = [], [], [], [], [], []
134 |
135 | for idx, data in tqdm(enumerate(testloader), total=len(testloader), desc='Testing'):
136 | data = [d.to(device) for d in data]
137 |
138 | if image_extractor:
139 | data[0] = image_extractor(data[0])
140 | if threshold is None:
141 | _, predictions = model(data)
142 | else:
143 | _, predictions = model.val_forward_with_threshold(data,threshold)
144 |
145 | attr_truth, obj_truth, pair_truth = data[1], data[2], data[3]
146 |
147 | all_pred.append(predictions)
148 | all_attr_gt.append(attr_truth)
149 | all_obj_gt.append(obj_truth)
150 | all_pair_gt.append(pair_truth)
151 |
152 | if args.cpu_eval:
153 | all_attr_gt, all_obj_gt, all_pair_gt = torch.cat(all_attr_gt), torch.cat(all_obj_gt), torch.cat(all_pair_gt)
154 | else:
155 | all_attr_gt, all_obj_gt, all_pair_gt = torch.cat(all_attr_gt).to('cpu'), torch.cat(all_obj_gt).to(
156 | 'cpu'), torch.cat(all_pair_gt).to('cpu')
157 |
158 | all_pred_dict = {}
159 | # Gather values as dict of (attr, obj) as key and list of predictions as values
160 | if args.cpu_eval:
161 | for k in all_pred[0].keys():
162 | all_pred_dict[k] = torch.cat(
163 | [all_pred[i][k].to('cpu') for i in range(len(all_pred))])
164 | else:
165 | for k in all_pred[0].keys():
166 | all_pred_dict[k] = torch.cat(
167 | [all_pred[i][k] for i in range(len(all_pred))])
168 |
169 | # Calculate best unseen accuracy
170 | results = evaluator.score_model(all_pred_dict, all_obj_gt, bias=args.bias, topk=args.topk)
171 | stats = evaluator.evaluate_predictions(results, all_attr_gt, all_obj_gt, all_pair_gt, all_pred_dict,
172 | topk=args.topk)
173 |
174 |
175 | result = ''
176 | for key in stats:
177 | result = result + key + ' ' + str(round(stats[key], 4)) + '| '
178 |
179 | result = result + args.name
180 | if print_results:
181 | print(f'Results')
182 | print(result)
183 |     return stats  # stats from evaluate_predictions carries AUC; raw score_model output does not
184 |
185 |
186 | if __name__ == '__main__':
187 | main()
188 |
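189 | # Usage sketch (hedged; assumes flags.py defines a --logpath flag, as implied by
190 | # args.logpath above):
191 | #   python test.py --logpath <path/to/experiment/logdir>
192 | # The log directory is expected to contain the training .yml config and the
193 | # ckpt_best_auc.t7 checkpoint written by train.py; main() loads both before
194 | # evaluating on the test split.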
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # Torch imports
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.utils.tensorboard import SummaryWriter
6 | import torch.backends.cudnn as cudnn
7 | cudnn.benchmark = True
8 |
9 | # Python imports
10 | import tqdm
11 | from tqdm import tqdm
12 | import os
13 | from os.path import join as ospj
14 | import csv
15 |
16 | #Local imports
17 | from data import dataset as dset
18 | from models.common import Evaluator
19 | from utils.utils import save_args, load_args
20 | from utils.config_model import configure_model
21 | from flags import parser, DATA_FOLDER
22 |
23 | best_auc = 0
24 | best_hm = 0
25 | compose_switch = True
26 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
27 |
28 | def main():
29 | # Get arguments and start logging
30 | args = parser.parse_args()
31 | load_args(args.config, args)
32 | logpath = os.path.join(args.cv_dir, args.name)
33 | os.makedirs(logpath, exist_ok=True)
34 | save_args(args, logpath, args.config)
35 | writer = SummaryWriter(log_dir = logpath, flush_secs = 30)
36 |
37 | # Get dataset
38 | trainset = dset.CompositionDataset(
39 | root=os.path.join(DATA_FOLDER,args.data_dir),
40 | phase='train',
41 | split=args.splitname,
42 | model =args.image_extractor,
43 | num_negs=args.num_negs,
44 | pair_dropout=args.pair_dropout,
45 | update_features = args.update_features,
46 | train_only= args.train_only,
47 | open_world=args.open_world
48 | )
49 | trainloader = torch.utils.data.DataLoader(
50 | trainset,
51 | batch_size=args.batch_size,
52 | shuffle=True,
53 | num_workers=args.workers)
54 | testset = dset.CompositionDataset(
55 | root=os.path.join(DATA_FOLDER,args.data_dir),
56 | phase=args.test_set,
57 | split=args.splitname,
58 | model =args.image_extractor,
59 | subset=args.subset,
60 | update_features = args.update_features,
61 | open_world=args.open_world
62 | )
63 | testloader = torch.utils.data.DataLoader(
64 | testset,
65 | batch_size=args.test_batch_size,
66 | shuffle=False,
67 | num_workers=args.workers)
68 |
69 |
70 | # Get model and optimizer
71 | image_extractor, model, optimizer = configure_model(args, trainset)
72 | args.extractor = image_extractor
73 |
74 | train = train_normal
75 |
76 | evaluator_val = Evaluator(testset, model)
77 |
78 | print(model)
79 |
80 | start_epoch = 0
81 | # Load checkpoint
82 | if args.load is not None:
83 | checkpoint = torch.load(args.load)
84 | if image_extractor:
85 | try:
86 | image_extractor.load_state_dict(checkpoint['image_extractor'])
87 | if args.freeze_features:
88 | print('Freezing image extractor')
89 | image_extractor.eval()
90 | for param in image_extractor.parameters():
91 | param.requires_grad = False
92 | except:
93 | print('No Image extractor in checkpoint')
94 | model.load_state_dict(checkpoint['net'])
95 | start_epoch = checkpoint['epoch']
96 | print('Loaded model from ', args.load)
97 |
98 | for epoch in tqdm(range(start_epoch, args.max_epochs + 1), desc = 'Current epoch'):
99 | train(epoch, image_extractor, model, trainloader, optimizer, writer)
100 |         if model.is_open and args.model == 'compcos' and (epoch + 1) % args.update_feasibility_every == 0:
101 | print('Updating feasibility scores')
102 | model.update_feasibility(epoch+1.)
103 |
104 | if epoch % args.eval_val_every == 0:
105 | with torch.no_grad(): # todo: might not be needed
106 | test(epoch, image_extractor, model, testloader, evaluator_val, writer, args, logpath)
107 | print('Best AUC achieved is ', best_auc)
108 | print('Best HM achieved is ', best_hm)
109 |
110 |
111 | def train_normal(epoch, image_extractor, model, trainloader, optimizer, writer):
112 | '''
113 | Runs training for an epoch
114 | '''
115 |
116 | if image_extractor:
117 | image_extractor.train()
118 | model.train() # Let's switch to training
119 |
120 | train_loss = 0.0
121 | for idx, data in tqdm(enumerate(trainloader), total=len(trainloader), desc = 'Training'):
122 | data = [d.to(device) for d in data]
123 |
124 | if image_extractor:
125 | data[0] = image_extractor(data[0])
126 |
127 | loss, _ = model(data)
128 |
129 | optimizer.zero_grad()
130 | loss.backward()
131 | optimizer.step()
132 |
133 | train_loss += loss.item()
134 |
135 | train_loss = train_loss/len(trainloader)
136 | writer.add_scalar('Loss/train_total', train_loss, epoch)
137 | print('Epoch: {}| Loss: {}'.format(epoch, round(train_loss, 2)))
138 |
139 |
140 | def test(epoch, image_extractor, model, testloader, evaluator, writer, args, logpath):
141 | '''
142 | Runs testing for an epoch
143 | '''
144 | global best_auc, best_hm
145 |
146 | def save_checkpoint(filename):
147 | state = {
148 | 'net': model.state_dict(),
149 | 'epoch': epoch,
150 | 'AUC': stats['AUC']
151 | }
152 | if image_extractor:
153 | state['image_extractor'] = image_extractor.state_dict()
154 | torch.save(state, os.path.join(logpath, 'ckpt_{}.t7'.format(filename)))
155 |
156 | if image_extractor:
157 | image_extractor.eval()
158 |
159 | model.eval()
160 |
161 | accuracies, all_sub_gt, all_attr_gt, all_obj_gt, all_pair_gt, all_pred = [], [], [], [], [], []
162 |
163 | for idx, data in tqdm(enumerate(testloader), total=len(testloader), desc='Testing'):
164 | data = [d.to(device) for d in data]
165 |
166 | if image_extractor:
167 | data[0] = image_extractor(data[0])
168 |
169 | _, predictions = model(data)
170 |
171 | attr_truth, obj_truth, pair_truth = data[1], data[2], data[3]
172 |
173 | all_pred.append(predictions)
174 | all_attr_gt.append(attr_truth)
175 | all_obj_gt.append(obj_truth)
176 | all_pair_gt.append(pair_truth)
177 |
178 | if args.cpu_eval:
179 | all_attr_gt, all_obj_gt, all_pair_gt = torch.cat(all_attr_gt), torch.cat(all_obj_gt), torch.cat(all_pair_gt)
180 | else:
181 | all_attr_gt, all_obj_gt, all_pair_gt = torch.cat(all_attr_gt).to('cpu'), torch.cat(all_obj_gt).to(
182 | 'cpu'), torch.cat(all_pair_gt).to('cpu')
183 |
184 | all_pred_dict = {}
185 | # Gather values as dict of (attr, obj) as key and list of predictions as values
186 | if args.cpu_eval:
187 | for k in all_pred[0].keys():
188 | all_pred_dict[k] = torch.cat(
189 | [all_pred[i][k].to('cpu') for i in range(len(all_pred))])
190 | else:
191 | for k in all_pred[0].keys():
192 | all_pred_dict[k] = torch.cat(
193 | [all_pred[i][k] for i in range(len(all_pred))])
194 |
195 | # Calculate best unseen accuracy
196 | results = evaluator.score_model(all_pred_dict, all_obj_gt, bias=args.bias, topk=args.topk)
197 | stats = evaluator.evaluate_predictions(results, all_attr_gt, all_obj_gt, all_pair_gt, all_pred_dict, topk=args.topk)
198 |
199 | stats['a_epoch'] = epoch
200 |
201 | result = ''
202 | # write to Tensorboard
203 | for key in stats:
204 | writer.add_scalar(key, stats[key], epoch)
205 | result = result + key + ' ' + str(round(stats[key], 4)) + '| '
206 |
207 | result = result + args.name
208 | print(f'Test Epoch: {epoch}')
209 | print(result)
210 | if epoch > 0 and epoch % args.save_every == 0:
211 | save_checkpoint(epoch)
212 | if stats['AUC'] > best_auc:
213 | best_auc = stats['AUC']
214 | print('New best AUC ', best_auc)
215 | save_checkpoint('best_auc')
216 |
217 | if stats['best_hm'] > best_hm:
218 | best_hm = stats['best_hm']
219 | print('New best HM ', best_hm)
220 | save_checkpoint('best_hm')
221 |
222 | # Logs
223 | with open(ospj(logpath, 'logs.csv'), 'a') as f:
224 | w = csv.DictWriter(f, stats.keys())
225 | if epoch == 0:
226 | w.writeheader()
227 | w.writerow(stats)
228 |
229 |
230 | if __name__ == '__main__':
231 | try:
232 | main()
233 | except KeyboardInterrupt:
234 | print('Best AUC achieved is ', best_auc)
235 | print('Best HM achieved is ', best_hm)
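236 |
237 | # Usage sketch (hedged; assumes flags.py defines a --config flag, as implied by
238 | # args.config above):
239 | #   python train.py --config <path/to/model_config.yml>
240 | # Logs, TensorBoard summaries and checkpoints are written to <cv_dir>/<name>; the
241 | # best-AUC checkpoint saved here (ckpt_best_auc.t7) is what test.py later reloads.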
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ExplainableML/czsl/22ab1aca82c64d4f351a20a887986147edfb1905/utils/__init__.py
--------------------------------------------------------------------------------
/utils/cgqa-graph.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ExplainableML/czsl/22ab1aca82c64d4f351a20a887986147edfb1905/utils/cgqa-graph.t7
--------------------------------------------------------------------------------
/utils/config_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.optim as optim
3 |
4 | from models.image_extractor import get_image_extractor
5 | from models.visual_product import VisualProductNN
6 | from models.manifold_methods import RedWine, LabelEmbedPlus, AttributeOperator
7 | from models.modular_methods import GatedGeneralNN
8 | from models.graph_method import GraphFull
9 | from models.symnet import Symnet
10 | from models.compcos import CompCos
11 |
12 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
13 |
14 |
15 | def configure_model(args, dataset):
16 | image_extractor = None
17 | is_open = False
18 |
19 | if args.model == 'visprodNN':
20 | model = VisualProductNN(dataset, args)
21 | elif args.model == 'redwine':
22 | model = RedWine(dataset, args)
23 | elif args.model == 'labelembed+':
24 | model = LabelEmbedPlus(dataset, args)
25 | elif args.model == 'attributeop':
26 | model = AttributeOperator(dataset, args)
27 | elif args.model == 'tmn':
28 | model = GatedGeneralNN(dataset, args, num_layers=args.nlayers, num_modules_per_layer=args.nmods)
29 | elif args.model == 'symnet':
30 | model = Symnet(dataset, args)
31 | elif args.model == 'graphfull':
32 | model = GraphFull(dataset, args)
33 | elif args.model == 'compcos':
34 | model = CompCos(dataset, args)
35 | if dataset.open_world and not args.train_only:
36 | is_open = True
37 | else:
38 | raise NotImplementedError
39 |
40 | model = model.to(device)
41 |
42 | if args.update_features:
43 | print('Learnable image_embeddings')
44 | image_extractor = get_image_extractor(arch = args.image_extractor, pretrained = True)
45 | image_extractor = image_extractor.to(device)
46 |
47 | # configuring optimizer
48 | if args.model=='redwine':
49 | optim_params = filter(lambda p: p.requires_grad, model.parameters())
50 | elif args.model=='attributeop':
51 | attr_params = [param for name, param in model.named_parameters() if 'attr_op' in name and param.requires_grad]
52 | other_params = [param for name, param in model.named_parameters() if 'attr_op' not in name and param.requires_grad]
53 | optim_params = [{'params':attr_params, 'lr':0.1*args.lr}, {'params':other_params}]
54 | elif args.model=='tmn':
55 | gating_params = [
56 | param for name, param in model.named_parameters()
57 | if 'gating_network' in name and param.requires_grad
58 | ]
59 | network_params = [
60 | param for name, param in model.named_parameters()
61 | if 'gating_network' not in name and param.requires_grad
62 | ]
63 | optim_params = [
64 | {
65 | 'params': network_params,
66 | },
67 | {
68 | 'params': gating_params,
69 | 'lr': args.lrg
70 | },
71 | ]
72 | else:
73 | model_params = [param for name, param in model.named_parameters() if param.requires_grad]
74 | optim_params = [{'params':model_params}]
75 | if args.update_features:
76 | ie_parameters = [param for name, param in image_extractor.named_parameters()]
77 | optim_params.append({'params': ie_parameters,
78 | 'lr': args.lrg})
79 | optimizer = optim.Adam(optim_params, lr=args.lr, weight_decay=args.wd)
80 |
81 | model.is_open = is_open
82 |
83 | return image_extractor, model, optimizer
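84 |
85 | # Note (a reading of the code above, not upstream documentation): every model gets a
86 | # single Adam optimizer, but TMN's gating network and a fine-tuned image extractor are
87 | # placed in their own parameter groups so they can use the separate learning rate
88 | # args.lrg, while all remaining parameters use the base args.lr and args.wd.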
--------------------------------------------------------------------------------
/utils/download_data.sh:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | CURRENT_DIR=$(pwd)
9 |
10 | FOLDER=$1
11 |
12 | mkdir $FOLDER
13 | mkdir $FOLDER/fast
14 | mkdir $FOLDER/glove
15 | mkdir $FOLDER/w2v
16 |
17 | cp utils/download_embeddings.py $FOLDER/fast
18 |
19 |
20 | # Download everything
21 | wget --show-progress -O $FOLDER/attr-ops-data.tar.gz https://utexas.box.com/shared/static/h7ckk2gx5ykw53h8o641no8rdyi7185g.gz
22 | wget --show-progress -O $FOLDER/mitstates.zip http://wednesday.csail.mit.edu/joseph_result/state_and_transformation/release_dataset.zip
23 | wget --show-progress -O $FOLDER/utzap.zip http://vision.cs.utexas.edu/projects/finegrained/utzap50k/ut-zap50k-images.zip
24 | wget --show-progress -O $FOLDER/splits.tar.gz https://www.senthilpurushwalkam.com/publication/compositional/compositional_split_natural.tar.gz
25 | wget --show-progress -O $FOLDER/cgqa.zip https://s3.mlcloud.uni-tuebingen.de/czsl/cgqa-updated.zip
26 |
27 | echo "Data downloaded. Extracting files..."
28 | sed -i "s|ROOT_FOLDER|$FOLDER|g" utils/reorganize_utzap.py
29 | sed -i "s|ROOT_FOLDER|$FOLDER|g" flags.py
30 |
31 | # Dataset metadata, pretrained SVMs and features, tensor completion data
32 |
33 | cd $FOLDER
34 |
35 | tar -zxvf attr-ops-data.tar.gz --strip 1
36 |
37 | # MIT-States
38 | unzip mitstates.zip 'release_dataset/images/*' -d mit-states/
39 | mv mit-states/release_dataset/images mit-states/images/
40 | rm -r mit-states/release_dataset
41 | rename "s/ /_/g" mit-states/images/*
42 |
43 | # UT-Zappos50k
44 | unzip utzap.zip -d ut-zap50k/
45 | mv ut-zap50k/ut-zap50k-images ut-zap50k/_images/
46 |
47 | # C-GQA
48 | unzip cgqa.zip -d cgqa/
49 |
50 | # Download the new splits from Purushwalkam et al.
51 | tar -zxvf splits.tar.gz
52 |
53 | # remove all zip files and temporary files
54 | rm -r attr-ops-data.tar.gz mitstates.zip utzap.zip splits.tar.gz cgqa.zip
55 |
56 | # Download embeddings
57 |
58 | # Glove (from attribute as operators)
59 | mv data/glove/* glove/
60 |
61 | # FastText
62 | cd fast
63 | python download_embeddings.py
64 | rm cc.en.300.bin.gz
65 |
66 | # Word2Vec
67 | cd ../w2v
68 | wget -c "https://s3.amazonaws.com/dl4j-distribution/GoogleNews-vectors-negative300.bin.gz"
69 | gzip -d GoogleNews-vectors-negative300.bin.gz
70 | rm GoogleNews-vectors-negative300.bin.gz
71 |
72 | cd ..
73 | rm -r data
74 |
75 | cd $CURRENT_DIR
76 | python utils/reorganize_utzap.py
77 |
--------------------------------------------------------------------------------
/utils/download_embeddings.py:
--------------------------------------------------------------------------------
1 | import fasttext.util
2 | fasttext.util.download_model('en', if_exists='ignore') # English
3 |
4 |
--------------------------------------------------------------------------------
/utils/img.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ExplainableML/czsl/22ab1aca82c64d4f351a20a887986147edfb1905/utils/img.png
--------------------------------------------------------------------------------
/utils/mitstates-graph.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ExplainableML/czsl/22ab1aca82c64d4f351a20a887986147edfb1905/utils/mitstates-graph.t7
--------------------------------------------------------------------------------
/utils/reorganize_utzap.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | """
8 | Reorganize the UT-Zappos dataset to resemble the MIT-States dataset
9 | root/attr_obj/img1.jpg
10 | root/attr_obj/img2.jpg
11 | root/attr_obj/img3.jpg
12 | ...
13 | """
14 |
15 | import os
16 | import torch
17 | import shutil
18 | import tqdm
19 |
20 | DATA_FOLDER= "ROOT_FOLDER"
21 |
22 | root = DATA_FOLDER+'/ut-zap50k/'
23 | os.makedirs(root+'/images',exist_ok=True)
24 |
25 | data = torch.load(root+'/metadata_compositional-split-natural.t7')
26 | for instance in tqdm.tqdm(data):
27 | image, attr, obj = instance['_image'], instance['attr'], instance['obj']
28 | old_file = '%s/_images/%s'%(root, image)
29 | new_dir = '%s/images/%s_%s/'%(root, attr, obj)
30 | os.makedirs(new_dir, exist_ok=True)
31 | shutil.copy(old_file, new_dir)
32 |
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os.path import join as ospj
3 | import torch
4 | import random
5 | import copy
6 | import shutil
7 | import sys
8 | import yaml
9 |
10 |
11 | def chunks(l, n):
12 | """Yield successive n-sized chunks from l."""
13 | for i in range(0, len(l), n):
14 | yield l[i:i + n]
15 |
16 | def get_norm_values(norm_family = 'imagenet'):
17 | '''
18 | Inputs
19 | norm_family: String of norm_family
20 | Returns
21 | mean, std : tuple of 3 channel values
22 | '''
23 | if norm_family == 'imagenet':
24 | mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
25 | else:
26 | raise ValueError('Incorrect normalization family')
27 | return mean, std
28 |
29 | def save_args(args, log_path, argfile):
30 | shutil.copy('train.py', log_path)
31 | modelfiles = ospj(log_path, 'models')
32 | try:
33 | shutil.copy(argfile, log_path)
34 | except:
35 | print('Config exists')
36 | try:
37 | shutil.copytree('models/', modelfiles)
38 | except:
39 | print('Already exists')
40 | with open(ospj(log_path,'args_all.yaml'),'w') as f:
41 | yaml.dump(args, f, default_flow_style=False, allow_unicode=True)
42 | with open(ospj(log_path, 'args.txt'), 'w') as f:
43 | f.write('\n'.join(sys.argv[1:]))
44 |
45 | class UnNormalizer:
46 | '''
47 | Unnormalize a given tensor using mean and std of a dataset family
48 | Inputs
49 | norm_family: String, dataset
50 | tensor: Torch tensor
51 | Outputs
52 | tensor: Unnormalized tensor
53 | '''
54 | def __init__(self, norm_family = 'imagenet'):
55 | self.mean, self.std = get_norm_values(norm_family=norm_family)
56 | self.mean, self.std = torch.Tensor(self.mean).view(1, 3, 1, 1), torch.Tensor(self.std).view(1, 3, 1, 1)
57 |
58 | def __call__(self, tensor):
59 | return (tensor * self.std) + self.mean
60 |
61 | def load_args(filename, args):
62 | with open(filename, 'r') as stream:
63 | data_loaded = yaml.safe_load(stream)
64 |     for group in data_loaded.values():
65 | for key, val in group.items():
66 | setattr(args, key, val)
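67 |
68 | # Example of the two-level .yml layout load_args expects (a hedged sketch; group names
69 | # and values are illustrative, only keys read elsewhere, e.g. lr or data_dir, matter):
70 | #
71 | # training:
72 | #   lr: 5.0e-05
73 | #   batch_size: 128
74 | # dataset:
75 | #   data_dir: mit-states
76 | #   splitname: compositional-split-natural
77 | #
78 | # Each leaf is copied onto args via setattr, so args.lr, args.batch_size, args.data_dir
79 | # and args.splitname become available after load_args(<config>.yml, args).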
--------------------------------------------------------------------------------
/utils/utzappos-graph.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ExplainableML/czsl/22ab1aca82c64d4f351a20a887986147edfb1905/utils/utzappos-graph.t7
--------------------------------------------------------------------------------