├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── baselines
│   │   ├── cgqa
│   │   │   ├── aopp.yml
│   │   │   ├── le+.yml
│   │   │   ├── symnet.yml
│   │   │   └── tmn.yml
│   │   ├── mit
│   │   │   ├── aopp.yml
│   │   │   ├── le+.yml
│   │   │   ├── symnet.yml
│   │   │   └── tmn.yml
│   │   └── utzppos
│   │       ├── aopp.yml
│   │       ├── le+.yml
│   │       ├── symnet.yml
│   │       └── tmn.yml
│   ├── cge
│   │   ├── cgqa.yml
│   │   ├── mit.yml
│   │   └── utzappos.yml
│   └── compcos
│       ├── cgqa
│       │   ├── compcos.yml
│       │   └── compcos_cw.yml
│       ├── mit
│       │   ├── compcos.yml
│       │   └── compcos_cw.yml
│       └── utzppos
│           ├── compcos.yml
│           └── compcos_cw.yml
├── data
│   ├── __init__.py
│   └── dataset.py
├── environment.yml
├── flags.py
├── models
│   ├── __init__.py
│   ├── common.py
│   ├── compcos.py
│   ├── gcn.py
│   ├── graph_method.py
│   ├── image_extractor.py
│   ├── manifold_methods.py
│   ├── modular_methods.py
│   ├── svm.py
│   ├── symnet.py
│   ├── visual_product.py
│   └── word_embedding.py
├── notebooks
│   └── analysis.ipynb
├── test.py
├── train.py
└── utils
    ├── __init__.py
    ├── cgqa-graph.t7
    ├── config_model.py
    ├── download_data.sh
    ├── download_embeddings.py
    ├── img.png
    ├── mitstates-graph.t7
    ├── reorganize_utzap.py
    ├── utils.py
    └── utzappos-graph.t7
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # IDE 132 | .vscode/ 133 | 134 | #logs and dataset 135 | cv/* 136 | data/antonyms.txt 137 | data/glove/* 138 | data/mit-states/* 139 | data/ut-zap50k/* 140 | logs/* 141 | tensor-completion/* 142 | .DS_Store 143 | ._.DS_Store 144 | *.csv 145 | *.txt 146 | notebooks/._mit_graph.pdf 147 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Compositional Zero-Shot Learning 3 | This is the official PyTorch code of the CVPR 2021 papers [Learning Graph Embeddings for Compositional Zero-shot Learning](https://arxiv.org/pdf/2102.01987.pdf) and [Open World Compositional Zero-Shot Learning](https://arxiv.org/pdf/2101.12609.pdf). The code provides the implementation of the methods CGE and CompCos, together with other baselines (e.g. SymNet, AoP, TMN, LabelEmbed+, RedWine). It also provides training and testing code for the Open World CZSL setting and the new C-GQA benchmark. 4 | 5 | **Important note:** the C-GQA dataset has been updated (see [this issue](https://github.com/ExplainableML/czsl/issues/3)) and the code will automatically download the new version. The results of all models for the updated benchmark can be found in the [Co-CGE](https://arxiv.org/abs/2105.01017) and [KG-SP](https://openaccess.thecvf.com/content/CVPR2022/html/Karthik_KG-SP_Knowledge_Guided_Simple_Primitives_for_Open_World_Compositional_Zero-Shot_CVPR_2022_paper.html) papers. 6 | 7 |

8 | 9 |

10 | 11 | ## Check also: 12 | - [Co-CGE](https://ieeexplore.ieee.org/document/9745371/) and its [repo](https://github.com/ExplainableML/co-cge) if you are interested in a stronger OW-CZSL model and faster OW evaluation code. 13 | - [KG-SP](https://openaccess.thecvf.com/content/CVPR2022/html/Karthik_KG-SP_Knowledge_Guided_Simple_Primitives_for_Open_World_Compositional_Zero-Shot_CVPR_2022_paper.html) and its [repo](https://github.com/ExplainableML/KG-SP) if you are interested in the partial CZSL setting and a simple but effective OW model. 14 | 15 | 16 | ## Setup 17 | 18 | 1. Clone the repo 19 | 20 | 2. We recommend using Anaconda for environment setup. To create the environment and activate it, please run: 21 | ``` 22 | conda env create --file environment.yml 23 | conda activate czsl 24 | ``` 25 | 26 | 3. Go to the cloned repo and open a terminal. Download the datasets and embeddings, specifying the desired path (e.g. `DATA_ROOT` in the example): 27 | ``` 28 | bash ./utils/download_data.sh DATA_ROOT 29 | mkdir logs 30 | ``` 31 | 32 | 33 | ## Training 34 | **Closed World.** To train a model, the command is simply: 35 | ``` 36 | python train.py --config CONFIG_FILE 37 | ``` 38 | where `CONFIG_FILE` is the path to the configuration file of the model. 39 | The folder `configs` contains configuration files for all methods, i.e. CGE in `configs/cge`, CompCos in `configs/compcos`, and the other methods in `configs/baselines`. 40 | 41 | To run CGE on MitStates, the command is: 42 | ``` 43 | python train.py --config configs/cge/mit.yml 44 | ``` 45 | On UT-Zappos, the command is: 46 | ``` 47 | python train.py --config configs/cge/utzappos.yml 48 | ``` 49 | 50 | **Open World.** To train CompCos (in the open world scenario) on MitStates, run: 51 | ``` 52 | python train.py --config configs/compcos/mit/compcos.yml 53 | ``` 54 | 55 | To run experiments in the open world setting for a non-open world method, just add `--open_world` after the command. E.g. for running SymNet in the open world scenario on MitStates, the command is: 56 | ``` 57 | python train.py --config configs/baselines/mit/symnet.yml --open_world 58 | ``` 59 | **Note:** To create a new config, all the available arguments are listed in `flags.py`. 60 | 61 | ## Test 62 | 63 | 64 | **Closed World.** To test a model, the command is: 65 | ``` 66 | python test.py --logpath LOG_DIR 67 | ``` 68 | where `LOG_DIR` is the directory containing the logs of a model.
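For instance, the CGE MitStates config (`configs/cge/mit.yml`, experiment `name: graphembed/mitstates/base`) saves its checkpoints and logs under the default `--cv_dir` of `logs/` (see `flags.py`); assuming the usual `<cv_dir>/<name>` layout, the concrete call would look like:
```
python test.py --logpath logs/graphembed/mitstates/base
```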
69 | 70 | 71 | **Open World.** To test a model in the open world setting, run: 72 | ``` 73 | python test.py --logpath LOG_DIR --open_world 74 | ``` 75 | 76 | To test a CompCos model in the open world setting with hard masking, run: 77 | ``` 78 | python test.py --logpath LOG_DIR_COMPCOS --open_world --hard_masking 79 | ``` 80 | 81 | 82 | ## References 83 | If you use this code, please cite 84 | ``` 85 | @inproceedings{naeem2021learning, 86 | title={Learning Graph Embeddings for Compositional Zero-shot Learning}, 87 | author={Naeem, MF and Xian, Y and Tombari, F and Akata, Zeynep}, 88 | booktitle={34th IEEE Conference on Computer Vision and Pattern Recognition}, 89 | year={2021}, 90 | organization={IEEE} 91 | } 92 | ``` 93 | and 94 | ``` 95 | @inproceedings{mancini2021open, 96 | title={Open World Compositional Zero-Shot Learning}, 97 | author={Mancini, M and Naeem, MF and Xian, Y and Akata, Zeynep}, 98 | booktitle={34th IEEE Conference on Computer Vision and Pattern Recognition}, 99 | year={2021}, 100 | organization={IEEE} 101 | } 102 | 103 | ``` 104 | 105 | **Note**: Some of the scripts are adapted from AttributeasOperators repository. GCN and GCNII implementations are imported from their respective repositories. If you find those parts useful, please consider citing: 106 | ``` 107 | @inproceedings{nagarajan2018attributes, 108 | title={Attributes as operators: factorizing unseen attribute-object compositions}, 109 | author={Nagarajan, Tushar and Grauman, Kristen}, 110 | booktitle={Proceedings of the European Conference on Computer Vision (ECCV)}, 111 | pages={169--185}, 112 | year={2018} 113 | } 114 | ``` 115 | -------------------------------------------------------------------------------- /configs/baselines/cgqa/aopp.yml: -------------------------------------------------------------------------------- 1 | --- 2 | experiment: 3 | name: aopp/cgqa/ 4 | dataset: 5 | data_dir: cgqa 6 | dataset: cgqa 7 | splitname: compositional-split-natural 8 | model_params: 9 | model: attributeop 10 | emb_dim: 300 11 | emb_init: glove 12 | eval_type: dist_fast 13 | image_extractor: resnet18 14 | train_only: true 15 | static_inp: false 16 | composition: add 17 | loss: 18 | lambda_aux: 1.0 19 | lambda_comm: 1.0 20 | training: 21 | batch_size: 128 22 | eval_val_every: 2 23 | load: 24 | lr: 5.0e-05 25 | lrg: 0.001 26 | margin: 0.5 27 | max_epochs: 1000 28 | norm_family: imagenet 29 | save_every: 10000 30 | test_batch_size: 32 31 | test_set: val 32 | topk: 1 33 | wd: 5.0e-05 34 | workers: 8 35 | update_features: false 36 | freeze_features: false 37 | extra: 38 | lambda_attr: 0 39 | lambda_obj: 0 40 | lambda_sub: 0 41 | graph: false 42 | hardk: null -------------------------------------------------------------------------------- /configs/baselines/cgqa/le+.yml: -------------------------------------------------------------------------------- 1 | --- 2 | experiment: 3 | name: labelembed/cgqa/ 4 | dataset: 5 | data_dir: cgqa 6 | dataset: cgqa 7 | splitname: compositional-split-natural 8 | model_params: 9 | model: labelembed+ 10 | dropout: false 11 | norm: false 12 | nlayers: 2 13 | fc_emb: 512,512 14 | emb_dim: 300 15 | emb_init: glove 16 | eval_type: dist_fast 17 | image_extractor: resnet18 18 | train_only: true 19 | static_inp: false 20 | composition: add 21 | loss: 22 | lambda_ce: 1.0 23 | lambda_trip_smart: 0 24 | training: 25 | batch_size: 512 26 | eval_val_every: 2 27 | load: 28 | lr: 5.0e-05 29 | margin: 0.5 30 | max_epochs: 2000 31 | norm_family: imagenet 32 | save_every: 10000 33 | test_batch_size: 32 34 | 
test_set: val 35 | topk: 1 36 | wd: 5.0e-05 37 | workers: 8 38 | update_features: false 39 | freeze_features: false -------------------------------------------------------------------------------- /configs/baselines/cgqa/symnet.yml: -------------------------------------------------------------------------------- 1 | --- 2 | experiment: 3 | name: symnet/cgqa/ 4 | dataset: 5 | data_dir: cgqa 6 | dataset: cgqa 7 | splitname: compositional-split-natural 8 | model_params: 9 | model: symnet 10 | dropout: true 11 | norm: true 12 | nlayers: 3 13 | fc_emb: 768,1024,1200 14 | emb_dim: 300 15 | emb_init: glove 16 | image_extractor: resnet18 17 | loss: 18 | lambda_cls_attr: 1.0 19 | lambda_cls_obj: 1.0 20 | lambda_trip: 0.5 21 | lambda_sym: 0.01 22 | lambda_axiom: 0.03 23 | training: 24 | batch_size: 512 25 | eval_val_every: 2 26 | load: 27 | lr: 5.0e-05 28 | lrg: 0.001 29 | margin: 0.5 30 | max_epochs: 2000 31 | norm_family: imagenet 32 | save_every: 10000 33 | test_batch_size: 32 34 | test_set: val 35 | topk: 1 36 | wd: 5.0e-05 37 | workers: 8 38 | update_features: false 39 | freeze_features: false -------------------------------------------------------------------------------- /configs/baselines/cgqa/tmn.yml: -------------------------------------------------------------------------------- 1 | --- 2 | experiment: 3 | name: tmn/cgqa/ 4 | dataset: 5 | data_dir: cgqa 6 | dataset: cgqa 7 | splitname: compositional-split-natural 8 | model_params: 9 | model: tmn 10 | emb_init: glove 11 | image_extractor: resnet18 12 | num_negs: 600 13 | embed_rank: 64 14 | emb_dim: 16 15 | nmods: 24 16 | training: 17 | batch_size: 256 18 | eval_val_every: 2 19 | load: 20 | lr: 0.0001 21 | lrg: 0.01 22 | margin: 0.5 23 | max_epochs: 100 24 | norm_family: imagenet 25 | save_every: 10000 26 | test_batch_size: 32 27 | test_set: val 28 | topk: 1 29 | wd: 5.0e-05 30 | workers: 8 31 | update_features: false 32 | freeze_features: false -------------------------------------------------------------------------------- /configs/baselines/mit/aopp.yml: -------------------------------------------------------------------------------- 1 | --- 2 | experiment: 3 | name: aopp/mitstates/ 4 | dataset: 5 | data_dir: mit-states 6 | dataset: mitstates 7 | splitname: compositional-split-natural 8 | model_params: 9 | model: attributeop 10 | emb_dim: 300 11 | emb_init: glove 12 | eval_type: dist_fast 13 | image_extractor: resnet18 14 | train_only: true 15 | static_inp: false 16 | composition: add 17 | loss: 18 | lambda_aux: 1000.0 19 | lambda_comm: 1.0 20 | training: 21 | batch_size: 512 22 | eval_val_every: 2 23 | load: 24 | lr: 5.0e-05 25 | lrg: 0.001 26 | margin: 0.5 27 | max_epochs: 1000 28 | norm_family: imagenet 29 | save_every: 10000 30 | test_batch_size: 32 31 | test_set: val 32 | topk: 1 33 | wd: 5.0e-05 34 | workers: 8 35 | update_features: false 36 | freeze_features: false 37 | extra: 38 | lambda_attr: 0 39 | lambda_obj: 0 40 | lambda_sub: 0 41 | graph: false 42 | hardk: null -------------------------------------------------------------------------------- /configs/baselines/mit/le+.yml: -------------------------------------------------------------------------------- 1 | --- 2 | experiment: 3 | name: labelembed/mitstates/ 4 | dataset: 5 | data_dir: mit-states 6 | dataset: mitstates 7 | splitname: compositional-split-natural 8 | model_params: 9 | model: labelembed+ 10 | dropout: false 11 | norm: false 12 | nlayers: 2 13 | fc_emb: 512,512 14 | emb_dim: 300 15 | emb_init: glove 16 | eval_type: dist_fast 17 | image_extractor: resnet18 18 | 
train_only: true 19 | static_inp: false 20 | composition: add 21 | training: 22 | batch_size: 512 23 | eval_val_every: 2 24 | load: 25 | lr: 5.0e-05 26 | margin: 0.5 27 | max_epochs: 2000 28 | norm_family: imagenet 29 | save_every: 10000 30 | test_batch_size: 32 31 | test_set: val 32 | topk: 1 33 | wd: 5.0e-05 34 | workers: 8 35 | update_features: false 36 | freeze_features: false -------------------------------------------------------------------------------- /configs/baselines/mit/symnet.yml: -------------------------------------------------------------------------------- 1 | --- 2 | experiment: 3 | name: symnet/mit-states/ 4 | dataset: 5 | data_dir: mit-states 6 | dataset: mitstates 7 | splitname: compositional-split-natural 8 | model_params: 9 | model: symnet 10 | dropout: true 11 | norm: true 12 | nlayers: 3 13 | fc_emb: 768,1024,1200 14 | emb_dim: 300 15 | emb_init: glove 16 | image_extractor: resnet18 17 | obj_fixed: True 18 | loss: 19 | lambda_cls_attr: 1.0 20 | lambda_cls_obj: 0.1 21 | lambda_trip: 0.5 22 | lambda_sym: 0.01 23 | lambda_axiom: 0.03 24 | training: 25 | batch_size: 450 26 | eval_val_every: 2 27 | load: 28 | lr: 5.0e-05 29 | lrg: 5.0e-06 30 | margin: 0.5 31 | max_epochs: 2000 32 | norm_family: imagenet 33 | save_every: 10000 34 | test_batch_size: 32 35 | test_set: val 36 | topk: 1 37 | wd: 5.0e-05 38 | workers: 8 39 | update_features: false 40 | freeze_features: false -------------------------------------------------------------------------------- /configs/baselines/mit/tmn.yml: -------------------------------------------------------------------------------- 1 | --- 2 | experiment: 3 | name: tmn/mitstates/ 4 | dataset: 5 | data_dir: mit-states 6 | dataset: mitstates 7 | splitname: compositional-split-natural 8 | model_params: 9 | model: tmn 10 | emb_init: glove 11 | image_extractor: resnet18 12 | num_negs: 600 13 | embed_rank: 64 14 | emb_dim: 16 15 | nmods: 24 16 | training: 17 | batch_size: 256 18 | eval_val_every: 2 19 | load: 20 | lr: 0.0001 21 | lrg: 1.0e-06 22 | margin: 0.5 23 | max_epochs: 100 24 | norm_family: imagenet 25 | save_every: 10000 26 | test_batch_size: 32 27 | test_set: val 28 | topk: 1 29 | wd: 5.0e-05 30 | workers: 8 31 | update_features: true 32 | freeze_features: false -------------------------------------------------------------------------------- /configs/baselines/utzppos/aopp.yml: -------------------------------------------------------------------------------- 1 | --- 2 | experiment: 3 | name: aopp/utzappos/ 4 | dataset: 5 | data_dir: ut-zap50k 6 | dataset: utzappos 7 | splitname: compositional-split-natural 8 | model_params: 9 | model: attributeop 10 | emb_dim: 300 11 | emb_init: glove 12 | eval_type: dist_fast 13 | image_extractor: resnet18 14 | train_only: true 15 | static_inp: false 16 | composition: add 17 | loss: 18 | lambda_aux: 1.0 19 | lambda_comm: 1.0 20 | training: 21 | batch_size: 512 22 | eval_val_every: 2 23 | load: 24 | lr: 5.0e-05 25 | lrg: 0.001 26 | margin: 0.5 27 | max_epochs: 1000 28 | norm_family: imagenet 29 | save_every: 10000 30 | test_batch_size: 32 31 | test_set: val 32 | topk: 1 33 | wd: 5.0e-05 34 | workers: 8 35 | update_features: false 36 | freeze_features: false 37 | extra: 38 | lambda_attr: 0 39 | lambda_obj: 0 40 | lambda_sub: 0 41 | graph: false 42 | hardk: null -------------------------------------------------------------------------------- /configs/baselines/utzppos/le+.yml: -------------------------------------------------------------------------------- 1 | --- 2 | experiment: 3 | name: 
labelembed/utzappos/ 4 | dataset: 5 | data_dir: ut-zap50k 6 | dataset: utzappos 7 | splitname: compositional-split-natural 8 | model_params: 9 | model: labelembed+ 10 | dropout: false 11 | norm: false 12 | nlayers: 2 13 | fc_emb: 512,512 14 | emb_dim: 300 15 | emb_init: glove 16 | eval_type: dist_fast 17 | image_extractor: resnet18 18 | train_only: true 19 | static_inp: false 20 | composition: add 21 | training: 22 | batch_size: 512 23 | eval_val_every: 2 24 | load: 25 | lr: 5.0e-05 26 | margin: 0.5 27 | max_epochs: 2000 28 | norm_family: imagenet 29 | save_every: 10000 30 | test_batch_size: 32 31 | test_set: val 32 | topk: 1 33 | wd: 5.0e-05 34 | workers: 8 35 | update_features: false 36 | freeze_features: false -------------------------------------------------------------------------------- /configs/baselines/utzppos/symnet.yml: -------------------------------------------------------------------------------- 1 | --- 2 | experiment: 3 | name: symnet/utzappos/ 4 | dataset: 5 | data_dir: ut-zap50k 6 | dataset: utzappos 7 | splitname: compositional-split-natural 8 | model_params: 9 | model: symnet 10 | dropout: true 11 | norm: true 12 | nlayers: 3 13 | fc_emb: 768,1024,1200 14 | emb_dim: 300 15 | emb_init: fasttext 16 | image_extractor: resnet18 17 | loss: 18 | lambda_cls_attr: 1.0 19 | lambda_cls_obj: 1.0 20 | lambda_trip: 0.5 21 | lambda_sym: 0.01 22 | lambda_axiom: 0.03 23 | training: 24 | batch_size: 512 25 | eval_val_every: 2 26 | load: 27 | lr: 5.0e-05 28 | lrg: 0.001 29 | margin: 0.5 30 | max_epochs: 2000 31 | norm_family: imagenet 32 | save_every: 10000 33 | test_batch_size: 32 34 | test_set: val 35 | topk: 1 36 | wd: 5.0e-05 37 | workers: 8 38 | update_features: false 39 | freeze_features: false -------------------------------------------------------------------------------- /configs/baselines/utzppos/tmn.yml: -------------------------------------------------------------------------------- 1 | --- 2 | experiment: 3 | name: tmn/utzappos/ 4 | dataset: 5 | data_dir: ut-zap50k 6 | dataset: utzappos 7 | splitname: compositional-split-natural 8 | model_params: 9 | model: tmn 10 | emb_init: word2vec 11 | emb_dim: 300 12 | image_extractor: resnet18 13 | train_only: true 14 | static_inp: true 15 | num_negs: 600 16 | embed_rank: 64 17 | nmods: 24 18 | training: 19 | batch_size: 256 20 | eval_val_every: 10 21 | load: 22 | lr: 0.0001 23 | lrg: 1.0e-06 24 | margin: 0.5 25 | max_epochs: 100 26 | norm_family: imagenet 27 | save_every: 10000 28 | test_batch_size: 32 29 | test_set: val 30 | topk: 1 31 | wd: 5.0e-05 32 | workers: 8 33 | update_features: false 34 | freeze_features: false 35 | -------------------------------------------------------------------------------- /configs/cge/cgqa.yml: -------------------------------------------------------------------------------- 1 | experiment: 2 | name: graphembed/cgqa 3 | dataset: 4 | data_dir: cgqa 5 | dataset: cgqa 6 | splitname: compositional-split-natural 7 | model_params: 8 | model: graphfull 9 | dropout: true 10 | norm: true 11 | nlayers: 12 | fc_emb: 768,1024,1200 13 | gr_emb: d4096,d 14 | emb_dim: 512 15 | emb_init: word2vec 16 | image_extractor: resnet18 17 | train_only: true 18 | training: 19 | batch_size: 256 20 | eval_val_every: 2 21 | load: 22 | lr: 5.0e-05 23 | wd: 5.0e-05 24 | lrg: 5.0e-06 25 | margin: 2 26 | max_epochs: 2000 27 | norm_family: imagenet 28 | save_every: 10000 29 | test_batch_size: 32 30 | test_set: val 31 | topk: 1 32 | workers: 8 33 | update_features: true 34 | freeze_features: false 35 |
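Each of these YAML files simply groups the flat command-line arguments defined in `flags.py` into sections (`experiment`, `dataset`, `model_params`, `training`, ...). The sketch below is an illustration only of how such a file could be flattened onto the argparse namespace; the helpers `flatten` and `load_config` are hypothetical, and the repo's actual loader lives in `utils/config_model.py` and may behave differently.
```python
import argparse
import yaml  # pyyaml is pinned in environment.yml

def flatten(section, out=None):
    """Recursively collapse nested config sections into a flat {flag_name: value} dict."""
    out = {} if out is None else out
    for key, value in section.items():
        if isinstance(value, dict):
            flatten(value, out)
        else:
            out[key] = value
    return out

def load_config(path, parser):
    """Start from the flags.py defaults, then override them with the YAML values."""
    args = parser.parse_args([])            # argparse defaults only
    with open(path) as f:
        cfg = yaml.safe_load(f) or {}
    for key, value in flatten(cfg).items():
        setattr(args, key, value)           # YAML keys mirror the flag names
    return args

# e.g.: args = load_config('configs/cge/cgqa.yml', parser)  # parser defined in flags.py
```
Under this sketch, the `configs/cge/cgqa.yml` file just above would yield e.g. `args.model == 'graphfull'` and `args.emb_dim == 512`.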
-------------------------------------------------------------------------------- /configs/cge/mit.yml: -------------------------------------------------------------------------------- 1 | --- 2 | experiment: 3 | name: graphembed/mitstates/base 4 | dataset: 5 | data_dir: mit-states 6 | dataset: mitstates 7 | splitname: compositional-split-natural 8 | model_params: 9 | model: graphfull 10 | dropout: true 11 | norm: true 12 | nlayers: 13 | gr_emb: d4096,d 14 | emb_dim: 512 15 | graph_init: utils/mitstates-graph.t7 16 | image_extractor: resnet18 17 | train_only: true 18 | training: 19 | batch_size: 128 20 | eval_val_every: 2 21 | load: 22 | lr: 5.0e-05 23 | wd: 5.0e-05 24 | lrg: 5.0e-6 25 | margin: 0.5 26 | max_epochs: 200 27 | norm_family: imagenet 28 | save_every: 10000 29 | test_batch_size: 32 30 | test_set: val 31 | topk: 1 32 | workers: 8 33 | update_features: true 34 | freeze_features: false 35 | 36 | -------------------------------------------------------------------------------- /configs/cge/utzappos.yml: -------------------------------------------------------------------------------- 1 | --- 2 | experiment: 3 | name: graphembed/utzappos/base 4 | dataset: 5 | data_dir: ut-zap50k 6 | dataset: utzappos 7 | splitname: compositional-split-natural 8 | model_params: 9 | model: graphfull 10 | dropout: true 11 | norm: true 12 | nlayers: 1 13 | fc_emb: 768,1024,1200 14 | gr_emb: d4096,d 15 | emb_dim: 300 16 | emb_init: null 17 | graph_init: utils/utzappos-graph.t7 18 | image_extractor: resnet18 19 | train_only: true 20 | training: 21 | batch_size: 128 22 | eval_val_every: 2 23 | load: 24 | lr: 5.0e-05 25 | wd: 5.0e-05 26 | lrg: 5.0e-6 27 | margin: 0.5 28 | max_epochs: 2000 29 | norm_family: imagenet 30 | save_every: 10000 31 | test_batch_size: 32 32 | test_set: val 33 | topk: 1 34 | workers: 8 35 | update_features: true 36 | freeze_features: false 37 | extra: 38 | lambda_attr: 0 39 | lambda_obj: 0 40 | lambda_sub: 0 41 | graph: false 42 | hardk: null 43 | cutmix: false 44 | cutmix_prob: 1.0 45 | beta: 1.0 -------------------------------------------------------------------------------- /configs/compcos/cgqa/compcos.yml: -------------------------------------------------------------------------------- 1 | --- 2 | experiment: 3 | name: compcos/cgqa 4 | dataset: 5 | data_dir: cgqa 6 | dataset: cgqa 7 | splitname: compositional-split-natural 8 | open_world: true 9 | model_params: 10 | model: compcos 11 | dropout: true 12 | norm: true 13 | nlayers: 2 14 | relu: false 15 | fc_emb: 768,1024,1200 16 | emb_dim: 300 17 | emb_init: word2vec 18 | image_extractor: resnet18 19 | train_only: false 20 | static_inp: false 21 | training: 22 | batch_size: 128 23 | load: 24 | lr: 5.0e-05 25 | lrg: 0.001 26 | margin: 0.4 27 | cosine_scale: 100 28 | max_epochs: 300 29 | norm_family: imagenet 30 | save_every: 10000 31 | test_batch_size: 64 32 | test_set: val 33 | topk: 1 34 | wd: 5.0e-05 35 | workers: 8 36 | update_features: false 37 | freeze_features: false 38 | epoch_max_margin: 15 39 | -------------------------------------------------------------------------------- /configs/compcos/cgqa/compcos_cw.yml: -------------------------------------------------------------------------------- 1 | --- 2 | experiment: 3 | name: compcos_cw/cgqa/ 4 | dataset: 5 | data_dir: cgqa 6 | dataset: cgqa 7 | splitname: compositional-split-natural 8 | model_params: 9 | model: compcos 10 | dropout: true 11 | norm: true 12 | nlayers: 2 13 | relu: false 14 | fc_emb: 768,1024,1200 15 | emb_dim: 300 16 | emb_init: word2vec 17 | image_extractor: 
resnet18 18 | train_only: true 19 | static_inp: false 20 | training: 21 | batch_size: 128 22 | load: 23 | lr: 5.0e-05 24 | lrg: 0.001 25 | margin: 0.4 26 | cosine_scale: 100 27 | max_epochs: 300 28 | norm_family: imagenet 29 | save_every: 10000 30 | test_batch_size: 64 31 | test_set: val 32 | topk: 1 33 | wd: 5.0e-05 34 | workers: 8 35 | update_features: false 36 | freeze_features: false 37 | epoch_max_margin: 15 38 | -------------------------------------------------------------------------------- /configs/compcos/mit/compcos.yml: -------------------------------------------------------------------------------- 1 | --- 2 | experiment: 3 | name: compcos/mitstates 4 | dataset: 5 | data_dir: mit-states 6 | dataset: mitstates 7 | splitname: compositional-split-natural 8 | open_world: true 9 | model_params: 10 | model: compcos 11 | dropout: true 12 | norm: true 13 | nlayers: 2 14 | relu: false 15 | fc_emb: 768,1024,1200 16 | emb_dim: 600 17 | emb_init: ft+w2v 18 | image_extractor: resnet18 19 | train_only: false 20 | static_inp: false 21 | training: 22 | batch_size: 128 23 | load: 24 | lr: 5.0e-05 25 | lrg: 0.001 26 | margin: 0.4 27 | cosine_scale: 20 28 | max_epochs: 300 29 | norm_family: imagenet 30 | save_every: 10000 31 | test_batch_size: 64 32 | test_set: val 33 | topk: 1 34 | wd: 5.0e-05 35 | workers: 8 36 | update_features: false 37 | freeze_features: false 38 | epoch_max_margin: 15 39 | -------------------------------------------------------------------------------- /configs/compcos/mit/compcos_cw.yml: -------------------------------------------------------------------------------- 1 | --- 2 | experiment: 3 | name: compcos_cw/mitstates 4 | dataset: 5 | data_dir: mit-states 6 | dataset: mitstates 7 | splitname: compositional-split-natural 8 | model_params: 9 | model: compcos 10 | dropout: true 11 | norm: true 12 | nlayers: 2 13 | relu: false 14 | fc_emb: 768,1024,1200 15 | emb_dim: 600 16 | emb_init: ft+w2v 17 | image_extractor: resnet18 18 | train_only: true 19 | static_inp: false 20 | training: 21 | batch_size: 128 22 | load: 23 | lr: 5.0e-05 24 | lrg: 0.001 25 | margin: 0.4 26 | cosine_scale: 20 27 | max_epochs: 300 28 | norm_family: imagenet 29 | save_every: 10000 30 | test_batch_size: 64 31 | test_set: val 32 | topk: 1 33 | wd: 5.0e-05 34 | workers: 8 35 | update_features: false 36 | freeze_features: false 37 | epoch_max_margin: 15 38 | -------------------------------------------------------------------------------- /configs/compcos/utzppos/compcos.yml: -------------------------------------------------------------------------------- 1 | --- 2 | experiment: 3 | name: compcos/utzappos 4 | dataset: 5 | data_dir: ut-zap50k 6 | dataset: utzappos 7 | splitname: compositional-split-natural 8 | open_world: true 9 | model_params: 10 | model: compcos 11 | dropout: true 12 | norm: true 13 | nlayers: 2 14 | relu: false 15 | fc_emb: 768,1024 16 | emb_dim: 300 17 | emb_init: word2vec 18 | image_extractor: resnet18 19 | train_only: false 20 | static_inp: false 21 | training: 22 | batch_size: 128 23 | load: 24 | lr: 5.0e-05 25 | lrg: 0.001 26 | margin: 1.0 27 | cosine_scale: 50 28 | max_epochs: 300 29 | norm_family: imagenet 30 | save_every: 10000 31 | test_batch_size: 64 32 | test_set: val 33 | topk: 1 34 | wd: 5.0e-05 35 | workers: 8 36 | update_features: false 37 | freeze_features: false 38 | epoch_max_margin: 100 -------------------------------------------------------------------------------- /configs/compcos/utzppos/compcos_cw.yml: 
-------------------------------------------------------------------------------- 1 | --- 2 | experiment: 3 | name: compcos/utzappos 4 | dataset: 5 | data_dir: ut-zap50k 6 | dataset: utzappos 7 | splitname: compositional-split-natural 8 | model_params: 9 | model: compcos 10 | dropout: true 11 | norm: true 12 | nlayers: 2 13 | relu: false 14 | fc_emb: 768,1024 15 | emb_dim: 300 16 | emb_init: word2vec 17 | image_extractor: resnet18 18 | train_only: true 19 | static_inp: false 20 | training: 21 | batch_size: 128 22 | load: 23 | lr: 5.0e-05 24 | lrg: 0.001 25 | margin: 1.0 26 | cosine_scale: 50 27 | max_epochs: 300 28 | norm_family: imagenet 29 | save_every: 10000 30 | test_batch_size: 64 31 | test_set: val 32 | topk: 1 33 | wd: 5.0e-05 34 | workers: 8 35 | update_features: false 36 | freeze_features: false 37 | epoch_max_margin: 100 -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ExplainableML/czsl/22ab1aca82c64d4f351a20a887986147edfb1905/data/__init__.py -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | #external libs 2 | import numpy as np 3 | from tqdm import tqdm 4 | from PIL import Image 5 | import os 6 | import random 7 | from os.path import join as ospj 8 | from glob import glob 9 | #torch libs 10 | from torch.utils.data import Dataset 11 | import torch 12 | import torchvision.transforms as transforms 13 | #local libs 14 | from utils.utils import get_norm_values, chunks 15 | from models.image_extractor import get_image_extractor 16 | from itertools import product 17 | 18 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 19 | 20 | class ImageLoader: 21 | def __init__(self, root): 22 | self.root_dir = root 23 | 24 | def __call__(self, img): 25 | img = Image.open(ospj(self.root_dir,img)).convert('RGB') #We don't want alpha 26 | return img 27 | 28 | 29 | def dataset_transform(phase, norm_family = 'imagenet'): 30 | ''' 31 | Inputs 32 | phase: String controlling which set of transforms to use 33 | norm_family: String controlling which normaliztion values to use 34 | 35 | Returns 36 | transform: A list of pytorch transforms 37 | ''' 38 | mean, std = get_norm_values(norm_family=norm_family) 39 | 40 | if phase == 'train': 41 | transform = transforms.Compose([ 42 | transforms.RandomResizedCrop(224), 43 | transforms.RandomHorizontalFlip(), 44 | transforms.ToTensor(), 45 | transforms.Normalize(mean, std) 46 | ]) 47 | 48 | elif phase == 'val' or phase == 'test': 49 | transform = transforms.Compose([ 50 | transforms.Resize(256), 51 | transforms.CenterCrop(224), 52 | transforms.ToTensor(), 53 | transforms.Normalize(mean, std) 54 | ]) 55 | elif phase == 'all': 56 | transform = transforms.Compose([ 57 | transforms.Resize(256), 58 | transforms.CenterCrop(224), 59 | transforms.ToTensor(), 60 | transforms.Normalize(mean, std) 61 | ]) 62 | else: 63 | raise ValueError('Invalid transform') 64 | 65 | return transform 66 | 67 | def filter_data(all_data, pairs_gt, topk = 5): 68 | ''' 69 | Helper function to clean data 70 | ''' 71 | valid_files = [] 72 | with open('/home/ubuntu/workspace/top'+str(topk)+'.txt') as f: 73 | for line in f: 74 | valid_files.append(line.strip()) 75 | 76 | data, pairs, attr, obj = [], [], [], [] 77 | for current in all_data: 78 | if current[0] in valid_files: 79 | 
data.append(current) 80 | pairs.append((current[1],current[2])) 81 | attr.append(current[1]) 82 | obj.append(current[2]) 83 | 84 | counter = 0 85 | for current in pairs_gt: 86 | if current in pairs: 87 | counter+=1 88 | print('Matches ', counter, ' out of ', len(pairs_gt)) 89 | print('Samples ', len(data), ' out of ', len(all_data)) 90 | return data, sorted(list(set(pairs))), sorted(list(set(attr))), sorted(list(set(obj))) 91 | 92 | # Dataset class now 93 | 94 | class CompositionDataset(Dataset): 95 | ''' 96 | Inputs 97 | root: String of base dir of dataset 98 | phase: String train, val, test 99 | split: String dataset split 100 | subset: Boolean if true uses a subset of train at each epoch 101 | num_negs: Int, numbers of negative pairs per batch 102 | pair_dropout: Percentage of pairs to leave in current epoch 103 | ''' 104 | def __init__( 105 | self, 106 | root, 107 | phase, 108 | split = 'compositional-split', 109 | model = 'resnet18', 110 | norm_family = 'imagenet', 111 | subset = False, 112 | num_negs = 1, 113 | pair_dropout = 0.0, 114 | update_features = False, 115 | return_images = False, 116 | train_only = False, 117 | open_world=False 118 | ): 119 | self.root = root 120 | self.phase = phase 121 | self.split = split 122 | self.num_negs = num_negs 123 | self.pair_dropout = pair_dropout 124 | self.norm_family = norm_family 125 | self.return_images = return_images 126 | self.update_features = update_features 127 | self.feat_dim = 512 if 'resnet18' in model else 2048 # todo, unify this with models 128 | self.open_world = open_world 129 | 130 | self.attrs, self.objs, self.pairs, self.train_pairs, \ 131 | self.val_pairs, self.test_pairs = self.parse_split() 132 | self.train_data, self.val_data, self.test_data = self.get_split_info() 133 | self.full_pairs = list(product(self.attrs,self.objs)) 134 | 135 | # Clean only was here 136 | self.obj2idx = {obj: idx for idx, obj in enumerate(self.objs)} 137 | self.attr2idx = {attr : idx for idx, attr in enumerate(self.attrs)} 138 | if self.open_world: 139 | self.pairs = self.full_pairs 140 | 141 | self.all_pair2idx = {pair: idx for idx, pair in enumerate(self.pairs)} 142 | 143 | if train_only and self.phase == 'train': 144 | print('Using only train pairs') 145 | self.pair2idx = {pair : idx for idx, pair in enumerate(self.train_pairs)} 146 | else: 147 | print('Using all pairs') 148 | self.pair2idx = {pair : idx for idx, pair in enumerate(self.pairs)} 149 | 150 | if self.phase == 'train': 151 | self.data = self.train_data 152 | elif self.phase == 'val': 153 | self.data = self.val_data 154 | elif self.phase == 'test': 155 | self.data = self.test_data 156 | elif self.phase == 'all': 157 | print('Using all data') 158 | self.data = self.train_data + self.val_data + self.test_data 159 | else: 160 | raise ValueError('Invalid training phase') 161 | 162 | self.all_data = self.train_data + self.val_data + self.test_data 163 | print('Dataset loaded') 164 | print('Train pairs: {}, Validation pairs: {}, Test Pairs: {}'.format( 165 | len(self.train_pairs), len(self.val_pairs), len(self.test_pairs))) 166 | print('Train images: {}, Validation images: {}, Test images: {}'.format( 167 | len(self.train_data), len(self.val_data), len(self.test_data))) 168 | 169 | if subset: 170 | ind = np.arange(len(self.data)) 171 | ind = ind[::len(ind) // 1000] 172 | self.data = [self.data[i] for i in ind] 173 | 174 | 175 | # Keeping a list of all pairs that occur with each object 176 | self.obj_affordance = {} 177 | self.train_obj_affordance = {} 178 | for _obj in self.objs: 179 | 
candidates = [attr for (_, attr, obj) in self.train_data+self.test_data if obj==_obj] 180 | self.obj_affordance[_obj] = list(set(candidates)) 181 | 182 | candidates = [attr for (_, attr, obj) in self.train_data if obj==_obj] 183 | self.train_obj_affordance[_obj] = list(set(candidates)) 184 | 185 | self.sample_indices = list(range(len(self.data))) 186 | self.sample_pairs = self.train_pairs 187 | 188 | # Load based on what to output 189 | self.transform = dataset_transform(self.phase, self.norm_family) 190 | self.loader = ImageLoader(ospj(self.root, 'images')) 191 | if not self.update_features: 192 | feat_file = ospj(root, model+'_featurers.t7') 193 | print(f'Using {model} and feature file {feat_file}') 194 | if not os.path.exists(feat_file): 195 | with torch.no_grad(): 196 | self.generate_features(feat_file, model) 197 | self.phase = phase 198 | activation_data = torch.load(feat_file) 199 | self.activations = dict( 200 | zip(activation_data['files'], activation_data['features'])) 201 | self.feat_dim = activation_data['features'].size(1) 202 | print('{} activations loaded'.format(len(self.activations))) 203 | 204 | 205 | def parse_split(self): 206 | ''' 207 | Helper function to read splits of object atrribute pair 208 | Returns 209 | all_attrs: List of all attributes 210 | all_objs: List of all objects 211 | all_pairs: List of all combination of attrs and objs 212 | tr_pairs: List of train pairs of attrs and objs 213 | vl_pairs: List of validation pairs of attrs and objs 214 | ts_pairs: List of test pairs of attrs and objs 215 | ''' 216 | def parse_pairs(pair_list): 217 | ''' 218 | Helper function to parse each phase to object attrribute vectors 219 | Inputs 220 | pair_list: path to textfile 221 | ''' 222 | with open(pair_list, 'r') as f: 223 | pairs = f.read().strip().split('\n') 224 | pairs = [line.split() for line in pairs] 225 | pairs = list(map(tuple, pairs)) 226 | 227 | attrs, objs = zip(*pairs) 228 | return attrs, objs, pairs 229 | 230 | tr_attrs, tr_objs, tr_pairs = parse_pairs( 231 | ospj(self.root, self.split, 'train_pairs.txt') 232 | ) 233 | vl_attrs, vl_objs, vl_pairs = parse_pairs( 234 | ospj(self.root, self.split, 'val_pairs.txt') 235 | ) 236 | ts_attrs, ts_objs, ts_pairs = parse_pairs( 237 | ospj(self.root, self.split, 'test_pairs.txt') 238 | ) 239 | 240 | #now we compose all objs, attrs and pairs 241 | all_attrs, all_objs = sorted( 242 | list(set(tr_attrs + vl_attrs + ts_attrs))), sorted( 243 | list(set(tr_objs + vl_objs + ts_objs))) 244 | all_pairs = sorted(list(set(tr_pairs + vl_pairs + ts_pairs))) 245 | 246 | return all_attrs, all_objs, all_pairs, tr_pairs, vl_pairs, ts_pairs 247 | 248 | def get_split_info(self): 249 | ''' 250 | Helper method to read image, attrs, objs samples 251 | 252 | Returns 253 | train_data, val_data, test_data: List of tuple of image, attrs, obj 254 | ''' 255 | data = torch.load(ospj(self.root, 'metadata_{}.t7'.format(self.split))) 256 | 257 | train_data, val_data, test_data = [], [], [] 258 | 259 | for instance in data: 260 | image, attr, obj, settype = instance['image'], instance['attr'], \ 261 | instance['obj'], instance['set'] 262 | curr_data = [image, attr, obj] 263 | 264 | if attr == 'NA' or (attr, obj) not in self.pairs or settype == 'NA': 265 | # Skip incomplete pairs, unknown pairs and unknown set 266 | continue 267 | 268 | if settype == 'train': 269 | train_data.append(curr_data) 270 | elif settype == 'val': 271 | val_data.append(curr_data) 272 | else: 273 | test_data.append(curr_data) 274 | 275 | return train_data, val_data, test_data 276 
| 277 | def get_dict_data(self, data, pairs): 278 | data_dict = {} 279 | for current in pairs: 280 | data_dict[current] = [] 281 | 282 | for current in data: 283 | image, attr, obj = current 284 | data_dict[(attr, obj)].append(image) 285 | 286 | return data_dict 287 | 288 | 289 | def reset_dropout(self): 290 | ''' 291 | Helper function to sample new subset of data containing a subset of pairs of objs and attrs 292 | ''' 293 | self.sample_indices = list(range(len(self.data))) 294 | self.sample_pairs = self.train_pairs 295 | 296 | # Using sampling from random instead of 2 step numpy 297 | n_pairs = int((1 - self.pair_dropout) * len(self.train_pairs)) 298 | 299 | self.sample_pairs = random.sample(self.train_pairs, n_pairs) 300 | print('Sampled new subset') 301 | print('Using {} pairs out of {} pairs right now'.format( 302 | n_pairs, len(self.train_pairs))) 303 | 304 | self.sample_indices = [ i for i in range(len(self.data)) 305 | if (self.data[i][1], self.data[i][2]) in self.sample_pairs 306 | ] 307 | print('Using {} images out of {} images right now'.format( 308 | len(self.sample_indices), len(self.data))) 309 | 310 | def sample_negative(self, attr, obj): 311 | ''' 312 | Inputs 313 | attr: String of valid attribute 314 | obj: String of valid object 315 | Returns 316 | Tuple of a different attribute, object indexes 317 | ''' 318 | new_attr, new_obj = self.sample_pairs[np.random.choice( 319 | len(self.sample_pairs))] 320 | 321 | while new_attr == attr and new_obj == obj: 322 | new_attr, new_obj = self.sample_pairs[np.random.choice( 323 | len(self.sample_pairs))] 324 | 325 | return (self.attr2idx[new_attr], self.obj2idx[new_obj]) 326 | 327 | def sample_affordance(self, attr, obj): 328 | ''' 329 | Inputs 330 | attr: String of valid attribute 331 | obj: String of valid object 332 | Return 333 | Idx of a different attribute for the same object 334 | ''' 335 | new_attr = np.random.choice(self.obj_affordance[obj]) 336 | 337 | while new_attr == attr: 338 | new_attr = np.random.choice(self.obj_affordance[obj]) 339 | 340 | return self.attr2idx[new_attr] 341 | 342 | def sample_train_affordance(self, attr, obj): 343 | ''' 344 | Inputs 345 | attr: String of valid attribute 346 | obj: String of valid object 347 | Return 348 | Idx of a different attribute for the same object from the training pairs 349 | ''' 350 | new_attr = np.random.choice(self.train_obj_affordance[obj]) 351 | 352 | while new_attr == attr: 353 | new_attr = np.random.choice(self.train_obj_affordance[obj]) 354 | 355 | return self.attr2idx[new_attr] 356 | 357 | def generate_features(self, out_file, model): 358 | ''' 359 | Inputs 360 | out_file: Path to save features 361 | model: String of extraction model 362 | ''' 363 | # data = self.all_data 364 | data = ospj(self.root,'images') 365 | files_before = glob(ospj(data, '**', '*.jpg'), recursive=True) 366 | files_all = [] 367 | for current in files_before: 368 | parts = current.split('/') 369 | if "cgqa" in self.root: 370 | files_all.append(parts[-1]) 371 | else: 372 | files_all.append(os.path.join(parts[-2],parts[-1])) 373 | transform = dataset_transform('test', self.norm_family) 374 | feat_extractor = get_image_extractor(arch = model).eval() 375 | feat_extractor = feat_extractor.to(device) 376 | 377 | image_feats = [] 378 | image_files = [] 379 | for chunk in tqdm( 380 | chunks(files_all, 512), total=len(files_all) // 512, desc=f'Extracting features {model}'): 381 | 382 | files = chunk 383 | imgs = list(map(self.loader, files)) 384 | imgs = list(map(transform, imgs)) 385 | feats = 
feat_extractor(torch.stack(imgs, 0).to(device)) 386 | image_feats.append(feats.data.cpu()) 387 | image_files += files 388 | image_feats = torch.cat(image_feats, 0) 389 | print('features for %d images generated' % (len(image_files))) 390 | 391 | torch.save({'features': image_feats, 'files': image_files}, out_file) 392 | 393 | def __getitem__(self, index): 394 | ''' 395 | Call for getting samples 396 | ''' 397 | index = self.sample_indices[index] 398 | 399 | image, attr, obj = self.data[index] 400 | 401 | # Decide what to output 402 | if not self.update_features: 403 | img = self.activations[image] 404 | else: 405 | img = self.loader(image) 406 | img = self.transform(img) 407 | 408 | data = [img, self.attr2idx[attr], self.obj2idx[obj], self.pair2idx[(attr, obj)]] 409 | 410 | if self.phase == 'train': 411 | all_neg_attrs = [] 412 | all_neg_objs = [] 413 | 414 | for curr in range(self.num_negs): 415 | neg_attr, neg_obj = self.sample_negative(attr, obj) # negative for triplet lose 416 | all_neg_attrs.append(neg_attr) 417 | all_neg_objs.append(neg_obj) 418 | 419 | neg_attr, neg_obj = torch.LongTensor(all_neg_attrs), torch.LongTensor(all_neg_objs) 420 | 421 | #note here 422 | if len(self.train_obj_affordance[obj])>1: 423 | inv_attr = self.sample_train_affordance(attr, obj) # attribute for inverse regularizer 424 | else: 425 | inv_attr = (all_neg_attrs[0]) 426 | 427 | comm_attr = self.sample_affordance(inv_attr, obj) # attribute for commutative regularizer 428 | 429 | 430 | data += [neg_attr, neg_obj, inv_attr, comm_attr] 431 | 432 | # Return image paths if requested as the last element of the list 433 | if self.return_images and self.phase != 'train': 434 | data.append(image) 435 | 436 | return data 437 | 438 | def __len__(self): 439 | ''' 440 | Call for length 441 | ''' 442 | return len(self.sample_indices) 443 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: czsl 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - blas=1.0=mkl 8 | - ca-certificates=2021.1.19=h06a4308_0 9 | - certifi=2020.12.5=py38h06a4308_0 10 | - cudatoolkit=11.0.221=h6bb024c_0 11 | - freetype=2.10.4=h5ab3b9f_0 12 | - intel-openmp=2020.2=254 13 | - jpeg=9b=h024ee3a_2 14 | - lcms2=2.11=h396b838_0 15 | - ld_impl_linux-64=2.33.1=h53a641e_7 16 | - libedit=3.1.20191231=h14c3975_1 17 | - libffi=3.3=he6710b0_2 18 | - libgcc-ng=9.1.0=hdf63c60_0 19 | - libgfortran-ng=7.3.0=hdf63c60_0 20 | - libpng=1.6.37=hbc83047_0 21 | - libstdcxx-ng=9.1.0=hdf63c60_0 22 | - libtiff=4.1.0=h2733197_1 23 | - libuv=1.40.0=h7b6447c_0 24 | - lz4-c=1.9.3=h2531618_0 25 | - mkl=2020.2=256 26 | - mkl-service=2.3.0=py38he904b0f_0 27 | - mkl_fft=1.2.0=py38h23d657b_0 28 | - mkl_random=1.1.1=py38h0573a6f_0 29 | - ncurses=6.2=he6710b0_1 30 | - ninja=1.10.2=py38hff7bd54_0 31 | - numpy=1.19.2=py38h54aff64_0 32 | - numpy-base=1.19.2=py38hfa32c7d_0 33 | - olefile=0.46=py_0 34 | - openssl=1.1.1i=h27cfd23_0 35 | - pillow=8.1.0=py38he98fc37_0 36 | - pip=20.3.3=py38h06a4308_0 37 | - python=3.8.5=h7579374_1 38 | - pytorch=1.7.1=py3.8_cuda11.0.221_cudnn8.0.5_0 39 | - readline=8.1=h27cfd23_0 40 | - scipy=1.5.2=py38h0b6359f_0 41 | - setuptools=52.0.0=py38h06a4308_0 42 | - six=1.15.0=py38h06a4308_0 43 | - sqlite=3.33.0=h62c20be_0 44 | - tk=8.6.10=hbc83047_0 45 | - torchaudio=0.7.2=py38 46 | - torchvision=0.8.2=py38_cu110 47 | - typing_extensions=3.7.4.3=pyh06a4308_0 48 | - 
wheel=0.36.2=pyhd3eb1b0_0 49 | - xz=5.2.5=h7b6447c_0 50 | - yaml=0.2.5=h7b6447c_0 51 | - zlib=1.2.11=h7b6447c_3 52 | - zstd=1.4.5=h9ceee32_0 53 | - pip: 54 | - absl-py==0.11.0 55 | - cachetools==4.2.1 56 | - chardet==4.0.0 57 | - fasttext==0.9.2 58 | - gensim==3.8.3 59 | - google-auth==1.24.0 60 | - google-auth-oauthlib==0.4.2 61 | - grpcio==1.35.0 62 | - idna==2.10 63 | - markdown==3.3.3 64 | - oauthlib==3.1.0 65 | - protobuf==3.14.0 66 | - pyasn1==0.4.8 67 | - pyasn1-modules==0.2.8 68 | - pybind11==2.6.2 69 | - pyyaml==5.4.1 70 | - requests==2.25.1 71 | - requests-oauthlib==1.3.0 72 | - rsa==4.7 73 | - smart-open==4.1.2 74 | - tensorboard==2.4.1 75 | - tensorboard-plugin-wit==1.8.0 76 | - tqdm==4.56.0 77 | - urllib3==1.26.3 78 | - werkzeug==1.0.1 79 | 80 | -------------------------------------------------------------------------------- /flags.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | DATA_FOLDER = "ROOT_FOLDER" 4 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 5 | 6 | parser.add_argument('--config', default='configs/args.yml', help='path of the config file (training only)') 7 | parser.add_argument('--dataset', default='mitstates', help='mitstates|zappos') 8 | parser.add_argument('--data_dir', default='mit-states', help='local path to data root dir from ' + DATA_FOLDER) 9 | parser.add_argument('--logpath', default=None, help='Path to dir where to logs are stored (test only)') 10 | parser.add_argument('--splitname', default='compositional-split-natural', help="dataset split") 11 | parser.add_argument('--cv_dir', default='logs/', help='dir to save checkpoints and logs to') 12 | parser.add_argument('--name', default='temp', help='Name of exp used to name models') 13 | parser.add_argument('--load', default=None, help='path to checkpoint to load from') 14 | parser.add_argument('--image_extractor', default = 'resnet18', help = 'Feature extractor model') 15 | parser.add_argument('--norm_family', default = 'imagenet', help = 'Normalization values from dataset') 16 | parser.add_argument('--num_negs', type=int, default=1, help='Number of negatives to sample per positive (triplet loss)') 17 | parser.add_argument('--pair_dropout', type=float, default=0.0, help='Each epoch drop this fraction of train pairs') 18 | parser.add_argument('--test_set', default='val', help='val|test mode') 19 | parser.add_argument('--clean_only', action='store_true', default=False, help='use only clean subset of data (mitstates)') 20 | parser.add_argument('--subset', action='store_true', default=False, help='test on a 1000 image subset (debug purpose)') 21 | parser.add_argument('--open_world', action='store_true', default=False, help='perform open world experiment') 22 | parser.add_argument('--test_batch_size', type=int, default=32, help="Batch size at test/eval time") 23 | parser.add_argument('--cpu_eval', action='store_true', help='Perform test on cpu') 24 | 25 | # Model parameters 26 | parser.add_argument('--model', default='graphfull', help='visprodNN|redwine|labelembed+|attributeop|tmn|compcos') 27 | parser.add_argument('--emb_dim', type=int, default=300, help='dimension of share embedding space') 28 | parser.add_argument('--nlayers', type=int, default=3, help='Layers in the image embedder') 29 | parser.add_argument('--nmods', type=int, default=24, help='number of mods per layer for TMN') 30 | parser.add_argument('--embed_rank', type=int, default=64, help='intermediate dimension in the gating model for TMN') 
31 | parser.add_argument('--bias', type=float, default=1e3, help='Bias value for unseen concepts') 32 | parser.add_argument('--update_features', action = 'store_true', default=False, help='If specified, train feature extractor') 33 | parser.add_argument('--freeze_features', action = 'store_true', default=False, help='If specified, put extractor in eval mode') 34 | parser.add_argument('--emb_init', default=None, help='w2v|ft|gl|glove|word2vec|fasttext, name of embeddings to use for initializing the primitives') 35 | parser.add_argument('--clf_init', action='store_true', default=False, help='initialize inputs with SVM weights') 36 | parser.add_argument('--static_inp', action='store_true', default=False, help='do not optimize primitives representations') 37 | parser.add_argument('--composition', default='mlp_add', help='add|mul|mlp|mlp_add, how to compose primitives') 38 | parser.add_argument('--relu', action='store_true', default=False, help='Use relu in image embedder') 39 | parser.add_argument('--dropout', action='store_true', default=False, help='Use dropout in image embedder') 40 | parser.add_argument('--norm', action='store_true', default=False, help='Use normalization in image embedder') 41 | parser.add_argument('--train_only', action='store_true', default=False, help='Optimize only for train pairs') 42 | 43 | #CGE 44 | parser.add_argument('--graph', action='store_true', default=False, help='graph l2 distance triplet') # Do we need this 45 | parser.add_argument('--graph_init', default=None, help='filename, file from which initializing the nodes and adjacency matrix of the graph') 46 | parser.add_argument('--gcn_type', default='gcn', help='GCN Version') 47 | 48 | # Forward 49 | parser.add_argument('--eval_type', default='dist', help='dist|prod|direct, function for computing the predictions') 50 | 51 | # Primitive-based loss 52 | parser.add_argument('--lambda_aux', type=float, default=0.0, help='weight of auxiliar losses on the primitives') 53 | 54 | # AoP 55 | parser.add_argument('--lambda_inv', type=float, default=0.0, help='weight of inverse consistency loss') 56 | parser.add_argument('--lambda_comm', type=float, default=0.0, help='weight of commutative loss') 57 | parser.add_argument('--lambda_ant', type=float, default=0.0, help='weight of antonyms loss') 58 | 59 | 60 | # SymNet 61 | parser.add_argument("--lambda_trip", type=float, default=0, help='weight of triplet loss') 62 | parser.add_argument("--lambda_sym", type=float, default=0, help='weight of symmetry loss') 63 | parser.add_argument("--lambda_axiom", type=float, default=0, help='weight of axiom loss') 64 | parser.add_argument("--lambda_cls_attr", type=float, default=0, help='weight of classification loss on attributes') 65 | parser.add_argument("--lambda_cls_obj", type=float, default=0, help='weight of classification loss on objects') 66 | 67 | # CompCos (for the margin, see below) 68 | parser.add_argument('--cosine_scale', type=float, default=20,help="Scale for cosine similarity") 69 | parser.add_argument('--epoch_max_margin', type=float, default=15,help="Epoch of max margin") 70 | parser.add_argument('--update_feasibility_every', type=float, default=1,help="Frequency of feasibility scores update in epochs") 71 | parser.add_argument('--hard_masking', action='store_true', default=False, help='hard masking during test time') 72 | parser.add_argument('--threshold', default=None,help="Apply a specific threshold at test time for the hard masking") 73 | parser.add_argument('--threshold_trials', type=int, default=50,help="how many 
threshold values to try") 74 | 75 | # Hyperparameters 76 | parser.add_argument('--topk', type=int, default=1,help="Compute topk accuracy") 77 | parser.add_argument('--margin', type=float, default=2,help="Margin for triplet loss or feasibility scores in CompCos") 78 | parser.add_argument('--workers', type=int, default=8,help="Number of workers") 79 | parser.add_argument('--batch_size', type=int, default=512,help="Training batch size") 80 | parser.add_argument('--lr', type=float, default=5e-5,help="Learning rate") 81 | parser.add_argument('--lrg', type=float, default=1e-3,help="Learning rate feature extractor") 82 | parser.add_argument('--wd', type=float, default=5e-5,help="Weight decay") 83 | parser.add_argument('--save_every', type=int, default=10000,help="Frequency of snapshots in epochs") 84 | parser.add_argument('--eval_val_every', type=int, default=1,help="Frequency of eval in epochs") 85 | parser.add_argument('--max_epochs', type=int, default=800,help="Max number of epochs") 86 | parser.add_argument("--fc_emb", default='768,1024,1200',help="Image embedder layer config") 87 | parser.add_argument("--gr_emb", default='d4096,d',help="graph layers config") 88 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ExplainableML/czsl/22ab1aca82c64d4f351a20a887986147edfb1905/models/__init__.py -------------------------------------------------------------------------------- /models/common.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import copy 6 | from scipy.stats import hmean 7 | 8 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 9 | 10 | class MLP(nn.Module): 11 | ''' 12 | Baseclass to create a simple MLP 13 | Inputs 14 | inp_dim: Int, Input dimension 15 | out-dim: Int, Output dimension 16 | num_layer: Number of hidden layers 17 | relu: Bool, Use non linear function at output 18 | bias: Bool, Use bias 19 | ''' 20 | def __init__(self, inp_dim, out_dim, num_layers = 1, relu = True, bias = True, dropout = False, norm = False, layers = []): 21 | super(MLP, self).__init__() 22 | mod = [] 23 | incoming = inp_dim 24 | for layer in range(num_layers - 1): 25 | if len(layers) == 0: 26 | outgoing = incoming 27 | else: 28 | outgoing = layers.pop(0) 29 | mod.append(nn.Linear(incoming, outgoing, bias = bias)) 30 | 31 | incoming = outgoing 32 | if norm: 33 | mod.append(nn.LayerNorm(outgoing)) 34 | # mod.append(nn.BatchNorm1d(outgoing)) 35 | mod.append(nn.ReLU(inplace = True)) 36 | # mod.append(nn.LeakyReLU(inplace=True, negative_slope=0.2)) 37 | if dropout: 38 | mod.append(nn.Dropout(p = 0.5)) 39 | 40 | mod.append(nn.Linear(incoming, out_dim, bias = bias)) 41 | 42 | if relu: 43 | mod.append(nn.ReLU(inplace = True)) 44 | # mod.append(nn.LeakyReLU(inplace=True, negative_slope=0.2)) 45 | self.mod = nn.Sequential(*mod) 46 | 47 | def forward(self, x): 48 | return self.mod(x) 49 | 50 | class Reshape(nn.Module): 51 | def __init__(self, *args): 52 | super(Reshape, self).__init__() 53 | self.shape = args 54 | 55 | def forward(self, x): 56 | return x.view(self.shape) 57 | 58 | def calculate_margines(domain_embedding, gt, margin_range = 5): 59 | ''' 60 | domain_embedding: pairs * feats 61 | gt: batch * feats 62 | ''' 63 | batch_size, pairs, features = gt.shape[0], domain_embedding.shape[0], 
domain_embedding.shape[1] 64 | gt_expanded = gt[:,None,:].expand(-1, pairs, -1) 65 | domain_embedding_expanded = domain_embedding[None, :, :].expand(batch_size, -1, -1) 66 | margin = (gt_expanded - domain_embedding_expanded)**2 67 | margin = margin.sum(2) 68 | max_margin, _ = torch.max(margin, dim = 0) 69 | margin /= max_margin 70 | margin *= margin_range 71 | return margin 72 | 73 | def l2_all_batched(image_embedding, domain_embedding): 74 | ''' 75 | Image Embedding: Tensor of Batch_size * pairs * Feature_dim 76 | domain_embedding: Tensor of pairs * Feature_dim 77 | ''' 78 | pairs = image_embedding.shape[1] 79 | domain_embedding_extended = image_embedding[:,None,:].expand(-1,pairs,-1) 80 | l2_loss = (image_embedding - domain_embedding_extended) ** 2 81 | l2_loss = l2_loss.sum(2) 82 | l2_loss = l2_loss.sum() / l2_loss.numel() 83 | return l2_loss 84 | 85 | def same_domain_triplet_loss(image_embedding, trip_images, gt, hard_k = None, margin = 2): 86 | ''' 87 | Image Embedding: Tensor of Batch_size * Feature_dim 88 | Triplet Images: Tensor of Batch_size * num_pairs * Feature_dim 89 | GT: Tensor of Batch_size 90 | ''' 91 | batch_size, pairs, features = trip_images.shape 92 | batch_iterator = torch.arange(batch_size).to(device) 93 | image_embedding_expanded = image_embedding[:,None,:].expand(-1, pairs, -1) 94 | 95 | diff = (image_embedding_expanded - trip_images)**2 96 | diff = diff.sum(2) 97 | 98 | positive_anchor = diff[batch_iterator, gt][:,None] 99 | positive_anchor = positive_anchor.expand(-1, pairs) 100 | 101 | # Calculating triplet loss 102 | triplet_loss = positive_anchor - diff + margin 103 | 104 | # Setting positive anchor loss to 0 105 | triplet_loss[batch_iterator, gt] = 0 106 | 107 | # Removing easy triplets 108 | triplet_loss[triplet_loss < 0] = 0 109 | 110 | # If only mining hard triplets 111 | if hard_k: 112 | triplet_loss, _ = triplet_loss.topk(hard_k) 113 | 114 | # Counting number of valid pairs 115 | num_positive_triplets = triplet_loss[triplet_loss > 1e-16].size(0) 116 | 117 | # Calculating the final loss 118 | triplet_loss = triplet_loss.sum() / (num_positive_triplets + 1e-16) 119 | 120 | return triplet_loss 121 | 122 | 123 | def cross_domain_triplet_loss(image_embedding, domain_embedding, gt, hard_k = None, margin = 2): 124 | ''' 125 | Image Embedding: Tensor of Batch_size * Feature_dim 126 | Domain Embedding: Tensor of Num_pairs * Feature_dim 127 | gt: Tensor of Batch_size with ground truth labels 128 | margin: Float of margin 129 | Returns: 130 | Triplet loss of all valid triplets 131 | ''' 132 | batch_size, pairs, features = image_embedding.shape[0], domain_embedding.shape[0], domain_embedding.shape[1] 133 | batch_iterator = torch.arange(batch_size).to(device) 134 | # Now dimensions will be Batch_size * Num_pairs * Feature_dim 135 | image_embedding = image_embedding[:,None,:].expand(-1, pairs, -1) 136 | domain_embedding = domain_embedding[None,:,:].expand(batch_size, -1, -1) 137 | 138 | # Calculating difference 139 | diff = (image_embedding - domain_embedding)**2 140 | diff = diff.sum(2) 141 | 142 | # Getting the positive pair 143 | positive_anchor = diff[batch_iterator, gt][:,None] 144 | positive_anchor = positive_anchor.expand(-1, pairs) 145 | 146 | # Calculating triplet loss 147 | triplet_loss = positive_anchor - diff + margin 148 | 149 | # Setting positive anchor loss to 0 150 | triplet_loss[batch_iterator, gt] = 0 151 | 152 | # Removing easy triplets 153 | triplet_loss[triplet_loss < 0] = 0 154 | 155 | # If only mining hard triplets 156 | if hard_k: 157 | 
triplet_loss, _ = triplet_loss.topk(hard_k) 158 | 159 | # Counting number of valid pairs 160 | num_positive_triplets = triplet_loss[triplet_loss > 1e-16].size(0) 161 | 162 | # Calculating the final loss 163 | triplet_loss = triplet_loss.sum() / (num_positive_triplets + 1e-16) 164 | 165 | return triplet_loss 166 | 167 | def same_domain_triplet_loss_old(image_embedding, positive_anchor, negative_anchor, margin = 2): 168 | ''' 169 | Image Embedding: Tensor of Batch_size * Feature_dim 170 | Positive anchor: Tensor of Batch_size * Feature_dim 171 | negative anchor: Tensor of Batch_size * negs *Feature_dim 172 | ''' 173 | batch_size, negs, features = negative_anchor.shape 174 | dist_pos = (image_embedding - positive_anchor)**2 175 | dist_pos = dist_pos.sum(1) 176 | dist_pos = dist_pos[:, None].expand(-1, negs) 177 | 178 | image_embedding_expanded = image_embedding[:,None,:].expand(-1, negs, -1) 179 | dist_neg = (image_embedding_expanded - negative_anchor)**2 180 | dist_neg = dist_neg.sum(2) 181 | 182 | triplet_loss = dist_pos - dist_neg + margin 183 | triplet_loss[triplet_loss < 0] = 0 184 | num_positive_triplets = triplet_loss[triplet_loss > 1e-16].size(0) 185 | 186 | triplet_loss = triplet_loss.sum() / (num_positive_triplets + 1e-16) 187 | return triplet_loss 188 | 189 | 190 | def pairwise_distances(x, y=None): 191 | ''' 192 | Input: x is a Nxd matrix 193 | y is an optional Mxd matirx 194 | Output: dist is a NxM matrix where dist[i,j] is the square norm between x[i,:] and y[j,:] 195 | if y is not given then use 'y=x'. 196 | i.e. dist[i,j] = ||x[i,:]-y[j,:]||^2 197 | ''' 198 | x_norm = (x**2).sum(1).view(-1, 1) 199 | if y is not None: 200 | y_t = torch.transpose(y, 0, 1) 201 | y_norm = (y**2).sum(1).view(1, -1) 202 | else: 203 | y_t = torch.transpose(x, 0, 1) 204 | y_norm = x_norm.view(1, -1) 205 | 206 | dist = x_norm + y_norm - 2.0 * torch.mm(x, y_t) 207 | # Ensure diagonal is zero if x=y 208 | # if y is None: 209 | # dist = dist - torch.diag(dist.diag) 210 | return torch.clamp(dist, 0.0, np.inf) 211 | 212 | class Evaluator: 213 | 214 | def __init__(self, dset, model): 215 | 216 | self.dset = dset 217 | 218 | # Convert text pairs to idx tensors: [('sliced', 'apple'), ('ripe', 'apple'), ...] 
--> torch.LongTensor([[0,1],[1,1], ...]) 219 | pairs = [(dset.attr2idx[attr], dset.obj2idx[obj]) for attr, obj in dset.pairs] 220 | self.train_pairs = [(dset.attr2idx[attr], dset.obj2idx[obj]) for attr, obj in dset.train_pairs] 221 | self.pairs = torch.LongTensor(pairs) 222 | 223 | # Mask over pairs that occur in closed world 224 | # Select set based on phase 225 | if dset.phase == 'train': 226 | print('Evaluating with train pairs') 227 | test_pair_set = set(dset.train_pairs) 228 | test_pair_gt = set(dset.train_pairs) 229 | elif dset.phase == 'val': 230 | print('Evaluating with validation pairs') 231 | test_pair_set = set(dset.val_pairs + dset.train_pairs) 232 | test_pair_gt = set(dset.val_pairs) 233 | else: 234 | print('Evaluating with test pairs') 235 | test_pair_set = set(dset.test_pairs + dset.train_pairs) 236 | test_pair_gt = set(dset.test_pairs) 237 | 238 | self.test_pair_dict = [(dset.attr2idx[attr], dset.obj2idx[obj]) for attr, obj in test_pair_gt] 239 | self.test_pair_dict = dict.fromkeys(self.test_pair_dict, 0) 240 | 241 | # dict values are pair val, score, total 242 | for attr, obj in test_pair_gt: 243 | pair_val = dset.pair2idx[(attr,obj)] 244 | key = (dset.attr2idx[attr], dset.obj2idx[obj]) 245 | self.test_pair_dict[key] = [pair_val, 0, 0] 246 | 247 | if dset.open_world: 248 | masks = [1 for _ in dset.pairs] 249 | else: 250 | masks = [1 if pair in test_pair_set else 0 for pair in dset.pairs] 251 | 252 | self.closed_mask = torch.BoolTensor(masks) 253 | # Mask of seen concepts 254 | seen_pair_set = set(dset.train_pairs) 255 | mask = [1 if pair in seen_pair_set else 0 for pair in dset.pairs] 256 | self.seen_mask = torch.BoolTensor(mask) 257 | 258 | # Object specific mask over which pairs occur in the object oracle setting 259 | oracle_obj_mask = [] 260 | for _obj in dset.objs: 261 | mask = [1 if _obj == obj else 0 for attr, obj in dset.pairs] 262 | oracle_obj_mask.append(torch.BoolTensor(mask)) 263 | self.oracle_obj_mask = torch.stack(oracle_obj_mask, 0) 264 | 265 | # Decide if the model under evaluation is a manifold model or not 266 | self.score_model = self.score_manifold_model 267 | 268 | # Generate mask for each settings, mask scores, and get prediction labels 269 | def generate_predictions(self, scores, obj_truth, bias = 0.0, topk = 5): # (Batch, #pairs) 270 | ''' 271 | Inputs 272 | scores: Output scores 273 | obj_truth: Ground truth object 274 | Returns 275 | results: dict of results in 3 settings 276 | ''' 277 | def get_pred_from_scores(_scores, topk): 278 | ''' 279 | Given list of scores, returns top 10 attr and obj predictions 280 | Check later 281 | ''' 282 | _, pair_pred = _scores.topk(topk, dim = 1) #sort returns indices of k largest values 283 | pair_pred = pair_pred.contiguous().view(-1) 284 | attr_pred, obj_pred = self.pairs[pair_pred][:, 0].view(-1, topk), \ 285 | self.pairs[pair_pred][:, 1].view(-1, topk) 286 | return (attr_pred, obj_pred) 287 | 288 | results = {} 289 | orig_scores = scores.clone() 290 | mask = self.seen_mask.repeat(scores.shape[0],1) # Repeat mask along pairs dimension 291 | scores[~mask] += bias # Add bias to test pairs 292 | 293 | # Unbiased setting 294 | 295 | # Open world setting --no mask, all pairs of the dataset 296 | results.update({'open': get_pred_from_scores(scores, topk)}) 297 | results.update({'unbiased_open': get_pred_from_scores(orig_scores, topk)}) 298 | # Closed world setting - set the score for all Non test pairs to -1e10, 299 | # this excludes the pairs from set not in evaluation 300 | mask = 
self.closed_mask.repeat(scores.shape[0], 1) 301 | closed_scores = scores.clone() 302 | closed_scores[~mask] = -1e10 303 | closed_orig_scores = orig_scores.clone() 304 | closed_orig_scores[~mask] = -1e10 305 | results.update({'closed': get_pred_from_scores(closed_scores, topk)}) 306 | results.update({'unbiased_closed': get_pred_from_scores(closed_orig_scores, topk)}) 307 | 308 | # Object_oracle setting - set the score to -1e10 for all pairs where the true object does Not participate, can also use the closed score 309 | mask = self.oracle_obj_mask[obj_truth] 310 | oracle_obj_scores = scores.clone() 311 | oracle_obj_scores[~mask] = -1e10 312 | oracle_obj_scores_unbiased = orig_scores.clone() 313 | oracle_obj_scores_unbiased[~mask] = -1e10 314 | results.update({'object_oracle': get_pred_from_scores(oracle_obj_scores, 1)}) 315 | results.update({'object_oracle_unbiased': get_pred_from_scores(oracle_obj_scores_unbiased, 1)}) 316 | 317 | return results 318 | 319 | def score_clf_model(self, scores, obj_truth, topk = 5): 320 | ''' 321 | Wrapper function to call generate_predictions for CLF models 322 | ''' 323 | attr_pred, obj_pred = scores 324 | 325 | # Go to CPU 326 | attr_pred, obj_pred, obj_truth = attr_pred.to('cpu'), obj_pred.to('cpu'), obj_truth.to('cpu') 327 | 328 | # Gather scores (P(a), P(o)) for all relevant (a,o) pairs 329 | # Multiply P(a) * P(o) to get P(pair) 330 | attr_subset = attr_pred.index_select(1, self.pairs[:,0]) # Return only attributes that are in our pairs 331 | obj_subset = obj_pred.index_select(1, self.pairs[:, 1]) 332 | scores = (attr_subset * obj_subset) # (Batch, #pairs) 333 | 334 | results = self.generate_predictions(scores, obj_truth) 335 | results['biased_scores'] = scores 336 | 337 | return results 338 | 339 | def score_manifold_model(self, scores, obj_truth, bias = 0.0, topk = 5): 340 | ''' 341 | Wrapper function to call generate_predictions for manifold models 342 | ''' 343 | # Go to CPU 344 | scores = {k: v.to('cpu') for k, v in scores.items()} 345 | obj_truth = obj_truth.to(device) 346 | 347 | # Gather scores for all relevant (a,o) pairs 348 | scores = torch.stack( 349 | [scores[(attr,obj)] for attr, obj in self.dset.pairs], 1 350 | ) # (Batch, #pairs) 351 | orig_scores = scores.clone() 352 | results = self.generate_predictions(scores, obj_truth, bias, topk) 353 | results['scores'] = orig_scores 354 | return results 355 | 356 | def score_fast_model(self, scores, obj_truth, bias = 0.0, topk = 5): 357 | ''' 358 | Wrapper function to call generate_predictions for manifold models 359 | ''' 360 | 361 | results = {} 362 | mask = self.seen_mask.repeat(scores.shape[0],1) # Repeat mask along pairs dimension 363 | scores[~mask] += bias # Add bias to test pairs 364 | 365 | mask = self.closed_mask.repeat(scores.shape[0], 1) 366 | closed_scores = scores.clone() 367 | closed_scores[~mask] = -1e10 368 | 369 | _, pair_pred = closed_scores.topk(topk, dim = 1) #sort returns indices of k largest values 370 | pair_pred = pair_pred.contiguous().view(-1) 371 | attr_pred, obj_pred = self.pairs[pair_pred][:, 0].view(-1, topk), \ 372 | self.pairs[pair_pred][:, 1].view(-1, topk) 373 | 374 | results.update({'closed': (attr_pred, obj_pred)}) 375 | return results 376 | 377 | def evaluate_predictions(self, predictions, attr_truth, obj_truth, pair_truth, allpred, topk = 1): 378 | # Go to CPU 379 | attr_truth, obj_truth, pair_truth = attr_truth.to('cpu'), obj_truth.to('cpu'), pair_truth.to('cpu') 380 | 381 | pairs = list( 382 | zip(list(attr_truth.numpy()), list(obj_truth.numpy()))) 383 | 384 
| 385 | seen_ind, unseen_ind = [], [] 386 | for i in range(len(attr_truth)): 387 | if pairs[i] in self.train_pairs: 388 | seen_ind.append(i) 389 | else: 390 | unseen_ind.append(i) 391 | 392 | 393 | seen_ind, unseen_ind = torch.LongTensor(seen_ind), torch.LongTensor(unseen_ind) 394 | def _process(_scores): 395 | # Top k pair accuracy 396 | # Attribute, object and pair 397 | attr_match = (attr_truth.unsqueeze(1).repeat(1, topk) == _scores[0][:, :topk]) 398 | obj_match = (obj_truth.unsqueeze(1).repeat(1, topk) == _scores[1][:, :topk]) 399 | 400 | # Match of object pair 401 | match = (attr_match * obj_match).any(1).float() 402 | attr_match = attr_match.any(1).float() 403 | obj_match = obj_match.any(1).float() 404 | # Match of seen and unseen pairs 405 | seen_match = match[seen_ind] 406 | unseen_match = match[unseen_ind] 407 | ### Calculating class average accuracy 408 | 409 | # local_score_dict = copy.deepcopy(self.test_pair_dict) 410 | # for pair_gt, pair_pred in zip(pairs, match): 411 | # # print(pair_gt) 412 | # local_score_dict[pair_gt][2] += 1.0 #increase counter 413 | # if int(pair_pred) == 1: 414 | # local_score_dict[pair_gt][1] += 1.0 415 | 416 | # # Now we have hits and totals for classes in evaluation set 417 | # seen_score, unseen_score = [], [] 418 | # for key, (idx, hits, total) in local_score_dict.items(): 419 | # score = hits/total 420 | # if bool(self.seen_mask[idx]) == True: 421 | # seen_score.append(score) 422 | # else: 423 | # unseen_score.append(score) 424 | 425 | seen_score, unseen_score = torch.ones(512,5), torch.ones(512,5) 426 | 427 | return attr_match, obj_match, match, seen_match, unseen_match, \ 428 | torch.Tensor(seen_score+unseen_score), torch.Tensor(seen_score), torch.Tensor(unseen_score) 429 | 430 | def _add_to_dict(_scores, type_name, stats): 431 | base = ['_attr_match', '_obj_match', '_match', '_seen_match', '_unseen_match', '_ca', '_seen_ca', '_unseen_ca'] 432 | for val, name in zip(_scores, base): 433 | stats[type_name + name] = val 434 | 435 | ##################### Match in places where corrent object 436 | obj_oracle_match = (attr_truth == predictions['object_oracle'][0][:, 0]).float() #object is already conditioned 437 | obj_oracle_match_unbiased = (attr_truth == predictions['object_oracle_unbiased'][0][:, 0]).float() 438 | 439 | stats = dict(obj_oracle_match = obj_oracle_match, obj_oracle_match_unbiased = obj_oracle_match_unbiased) 440 | 441 | #################### Closed world 442 | closed_scores = _process(predictions['closed']) 443 | unbiased_closed = _process(predictions['unbiased_closed']) 444 | _add_to_dict(closed_scores, 'closed', stats) 445 | _add_to_dict(unbiased_closed, 'closed_ub', stats) 446 | 447 | #################### Calculating AUC 448 | scores = predictions['scores'] 449 | # getting score for each ground truth class 450 | correct_scores = scores[torch.arange(scores.shape[0]), pair_truth][unseen_ind] 451 | 452 | # Getting top predicted score for these unseen classes 453 | max_seen_scores = predictions['scores'][unseen_ind][:, self.seen_mask].topk(topk, dim=1)[0][:, topk - 1] 454 | 455 | # Getting difference between these scores 456 | unseen_score_diff = max_seen_scores - correct_scores 457 | 458 | # Getting matched classes at max bias for diff 459 | unseen_matches = stats['closed_unseen_match'].bool() 460 | correct_unseen_score_diff = unseen_score_diff[unseen_matches] - 1e-4 461 | 462 | # sorting these diffs 463 | correct_unseen_score_diff = torch.sort(correct_unseen_score_diff)[0] 464 | magic_binsize = 20 465 | # getting step size for 
these bias values 466 | bias_skip = max(len(correct_unseen_score_diff) // magic_binsize, 1) 467 | # Getting list 468 | biaslist = correct_unseen_score_diff[::bias_skip] 469 | 470 | seen_match_max = float(stats['closed_seen_match'].mean()) 471 | unseen_match_max = float(stats['closed_unseen_match'].mean()) 472 | seen_accuracy, unseen_accuracy = [], [] 473 | 474 | # Go to CPU 475 | base_scores = {k: v.to('cpu') for k, v in allpred.items()} 476 | obj_truth = obj_truth.to('cpu') 477 | 478 | # Gather scores for all relevant (a,o) pairs 479 | base_scores = torch.stack( 480 | [allpred[(attr,obj)] for attr, obj in self.dset.pairs], 1 481 | ) # (Batch, #pairs) 482 | 483 | for bias in biaslist: 484 | scores = base_scores.clone() 485 | results = self.score_fast_model(scores, obj_truth, bias = bias, topk = topk) 486 | results = results['closed'] # we only need biased 487 | results = _process(results) 488 | seen_match = float(results[3].mean()) 489 | unseen_match = float(results[4].mean()) 490 | seen_accuracy.append(seen_match) 491 | unseen_accuracy.append(unseen_match) 492 | 493 | seen_accuracy.append(seen_match_max) 494 | unseen_accuracy.append(unseen_match_max) 495 | seen_accuracy, unseen_accuracy = np.array(seen_accuracy), np.array(unseen_accuracy) 496 | area = np.trapz(seen_accuracy, unseen_accuracy) 497 | 498 | for key in stats: 499 | stats[key] = float(stats[key].mean()) 500 | 501 | harmonic_mean = hmean([seen_accuracy, unseen_accuracy], axis = 0) 502 | max_hm = np.max(harmonic_mean) 503 | idx = np.argmax(harmonic_mean) 504 | if idx == len(biaslist): 505 | bias_term = 1e3 506 | else: 507 | bias_term = biaslist[idx] 508 | stats['biasterm'] = float(bias_term) 509 | stats['best_unseen'] = np.max(unseen_accuracy) 510 | stats['best_seen'] = np.max(seen_accuracy) 511 | stats['AUC'] = area 512 | stats['hm_unseen'] = unseen_accuracy[idx] 513 | stats['hm_seen'] = seen_accuracy[idx] 514 | stats['best_hm'] = max_hm 515 | return stats -------------------------------------------------------------------------------- /models/compcos.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .word_embedding import load_word_embeddings 5 | from .common import MLP 6 | 7 | from itertools import product 8 | 9 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 10 | 11 | def compute_cosine_similarity(names, weights, return_dict=True): 12 | pairing_names = list(product(names, names)) 13 | normed_weights = F.normalize(weights,dim=1) 14 | similarity = torch.mm(normed_weights, normed_weights.t()) 15 | if return_dict: 16 | dict_sim = {} 17 | for i,n in enumerate(names): 18 | for j,m in enumerate(names): 19 | dict_sim[(n,m)]=similarity[i,j].item() 20 | return dict_sim 21 | return pairing_names, similarity.to('cpu') 22 | 23 | class CompCos(nn.Module): 24 | 25 | def __init__(self, dset, args): 26 | super(CompCos, self).__init__() 27 | self.args = args 28 | self.dset = dset 29 | 30 | def get_all_ids(relevant_pairs): 31 | # Precompute validation pairs 32 | attrs, objs = zip(*relevant_pairs) 33 | attrs = [dset.attr2idx[attr] for attr in attrs] 34 | objs = [dset.obj2idx[obj] for obj in objs] 35 | pairs = [a for a in range(len(relevant_pairs))] 36 | attrs = torch.LongTensor(attrs).to(device) 37 | objs = torch.LongTensor(objs).to(device) 38 | pairs = torch.LongTensor(pairs).to(device) 39 | return attrs, objs, pairs 40 | 41 | # Validation 42 | self.val_attrs, self.val_objs, self.val_pairs = 
get_all_ids(self.dset.pairs) 43 | 44 | # for indivual projections 45 | self.uniq_attrs, self.uniq_objs = torch.arange(len(self.dset.attrs)).long().to(device), \ 46 | torch.arange(len(self.dset.objs)).long().to(device) 47 | self.factor = 2 48 | 49 | self.scale = self.args.cosine_scale 50 | 51 | if dset.open_world: 52 | self.train_forward = self.train_forward_open 53 | self.known_pairs = dset.train_pairs 54 | seen_pair_set = set(self.known_pairs) 55 | mask = [1 if pair in seen_pair_set else 0 for pair in dset.pairs] 56 | self.seen_mask = torch.BoolTensor(mask).to(device) * 1. 57 | 58 | self.activated = False 59 | 60 | # Init feasibility-related variables 61 | self.attrs = dset.attrs 62 | self.objs = dset.objs 63 | self.possible_pairs = dset.pairs 64 | 65 | self.validation_pairs = dset.val_pairs 66 | 67 | self.feasibility_margin = (1-self.seen_mask).float() 68 | self.epoch_max_margin = self.args.epoch_max_margin 69 | self.cosine_margin_factor = -args.margin 70 | 71 | # Intantiate attribut-object relations, needed just to evaluate mined pairs 72 | self.obj_by_attrs_train = {k: [] for k in self.attrs} 73 | for (a, o) in self.known_pairs: 74 | self.obj_by_attrs_train[a].append(o) 75 | 76 | # Intantiate attribut-object relations, needed just to evaluate mined pairs 77 | self.attrs_by_obj_train = {k: [] for k in self.objs} 78 | for (a, o) in self.known_pairs: 79 | self.attrs_by_obj_train[o].append(a) 80 | 81 | else: 82 | self.train_forward = self.train_forward_closed 83 | 84 | # Precompute training compositions 85 | if args.train_only: 86 | self.train_attrs, self.train_objs, self.train_pairs = get_all_ids(self.dset.train_pairs) 87 | else: 88 | self.train_attrs, self.train_objs, self.train_pairs = self.val_attrs, self.val_objs, self.val_pairs 89 | 90 | try: 91 | self.args.fc_emb = self.args.fc_emb.split(',') 92 | except: 93 | self.args.fc_emb = [self.args.fc_emb] 94 | layers = [] 95 | for a in self.args.fc_emb: 96 | a = int(a) 97 | layers.append(a) 98 | 99 | 100 | self.image_embedder = MLP(dset.feat_dim, int(args.emb_dim), relu=args.relu, num_layers=args.nlayers, 101 | dropout=self.args.dropout, 102 | norm=self.args.norm, layers=layers) 103 | 104 | # Fixed 105 | self.composition = args.composition 106 | 107 | input_dim = args.emb_dim 108 | self.attr_embedder = nn.Embedding(len(dset.attrs), input_dim) 109 | self.obj_embedder = nn.Embedding(len(dset.objs), input_dim) 110 | 111 | # init with word embeddings 112 | if args.emb_init: 113 | pretrained_weight = load_word_embeddings(args.emb_init, dset.attrs) 114 | self.attr_embedder.weight.data.copy_(pretrained_weight) 115 | pretrained_weight = load_word_embeddings(args.emb_init, dset.objs) 116 | self.obj_embedder.weight.data.copy_(pretrained_weight) 117 | 118 | # static inputs 119 | if args.static_inp: 120 | for param in self.attr_embedder.parameters(): 121 | param.requires_grad = False 122 | for param in self.obj_embedder.parameters(): 123 | param.requires_grad = False 124 | 125 | # Composition MLP 126 | self.projection = nn.Linear(input_dim * 2, args.emb_dim) 127 | 128 | 129 | def freeze_representations(self): 130 | print('Freezing representations') 131 | for param in self.image_embedder.parameters(): 132 | param.requires_grad = False 133 | for param in self.attr_embedder.parameters(): 134 | param.requires_grad = False 135 | for param in self.obj_embedder.parameters(): 136 | param.requires_grad = False 137 | 138 | 139 | def compose(self, attrs, objs): 140 | attrs, objs = self.attr_embedder(attrs), self.obj_embedder(objs) 141 | inputs = torch.cat([attrs, 
objs], 1) 142 | output = self.projection(inputs) 143 | output = F.normalize(output, dim=1) 144 | return output 145 | 146 | 147 | def compute_feasibility(self): 148 | obj_embeddings = self.obj_embedder(torch.arange(len(self.objs)).long().to('cuda')) 149 | obj_embedding_sim = compute_cosine_similarity(self.objs, obj_embeddings, 150 | return_dict=True) 151 | attr_embeddings = self.attr_embedder(torch.arange(len(self.attrs)).long().to('cuda')) 152 | attr_embedding_sim = compute_cosine_similarity(self.attrs, attr_embeddings, 153 | return_dict=True) 154 | 155 | feasibility_scores = self.seen_mask.clone().float() 156 | for a in self.attrs: 157 | for o in self.objs: 158 | if (a, o) not in self.known_pairs: 159 | idx = self.dset.all_pair2idx[(a, o)] 160 | score_obj = self.get_pair_scores_objs(a, o, obj_embedding_sim) 161 | score_attr = self.get_pair_scores_attrs(a, o, attr_embedding_sim) 162 | score = (score_obj + score_attr) / 2 163 | feasibility_scores[idx] = score 164 | 165 | self.feasibility_scores = feasibility_scores 166 | 167 | return feasibility_scores * (1 - self.seen_mask.float()) 168 | 169 | 170 | def get_pair_scores_objs(self, attr, obj, obj_embedding_sim): 171 | score = -1. 172 | for o in self.objs: 173 | if o!=obj and attr in self.attrs_by_obj_train[o]: 174 | temp_score = obj_embedding_sim[(obj,o)] 175 | if temp_score>score: 176 | score=temp_score 177 | return score 178 | 179 | def get_pair_scores_attrs(self, attr, obj, attr_embedding_sim): 180 | score = -1. 181 | for a in self.attrs: 182 | if a != attr and obj in self.obj_by_attrs_train[a]: 183 | temp_score = attr_embedding_sim[(attr, a)] 184 | if temp_score > score: 185 | score = temp_score 186 | return score 187 | 188 | def update_feasibility(self,epoch): 189 | self.activated = True 190 | feasibility_scores = self.compute_feasibility() 191 | self.feasibility_margin = min(1.,epoch/self.epoch_max_margin) * \ 192 | (self.cosine_margin_factor*feasibility_scores.float().to(device)) 193 | 194 | 195 | def val_forward(self, x): 196 | img = x[0] 197 | img_feats = self.image_embedder(img) 198 | img_feats_normed = F.normalize(img_feats, dim=1) 199 | pair_embeds = self.compose(self.val_attrs, self.val_objs).permute(1, 0) # Evaluate all pairs 200 | score = torch.matmul(img_feats_normed, pair_embeds) 201 | 202 | scores = {} 203 | for itr, pair in enumerate(self.dset.pairs): 204 | scores[pair] = score[:, self.dset.all_pair2idx[pair]] 205 | 206 | return None, scores 207 | 208 | 209 | def val_forward_with_threshold(self, x, th=0.): 210 | img = x[0] 211 | img_feats = self.image_embedder(img) 212 | img_feats_normed = F.normalize(img_feats, dim=1) 213 | pair_embeds = self.compose(self.val_attrs, self.val_objs).permute(1, 0) # Evaluate all pairs 214 | score = torch.matmul(img_feats_normed, pair_embeds) 215 | 216 | # Note: Pairs are already aligned here 217 | mask = (self.feasibility_scores>=th).float() 218 | score = score*mask + (1.-mask)*(-1.) 
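        # Hard masking: the scores above are cosine similarities in [-1, 1], so pairs whose
        # feasibility falls below the threshold are clamped to -1 and pushed to the bottom of
        # the ranking. E.g. with feasibility_scores = [0.9, 0.2] and th = 0.5, mask = [1., 0.]
        # and the second pair's score becomes -1 while the first keeps its predicted value.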
219 | 220 | scores = {} 221 | for itr, pair in enumerate(self.dset.pairs): 222 | scores[pair] = score[:, self.dset.all_pair2idx[pair]] 223 | 224 | return None, scores 225 | 226 | 227 | def train_forward_open(self, x): 228 | img, attrs, objs, pairs = x[0], x[1], x[2], x[3] 229 | img_feats = self.image_embedder(img) 230 | 231 | pair_embed = self.compose(self.train_attrs, self.train_objs).permute(1, 0) 232 | img_feats_normed = F.normalize(img_feats, dim=1) 233 | 234 | pair_pred = torch.matmul(img_feats_normed, pair_embed) 235 | 236 | if self.activated: 237 | pair_pred += (1 - self.seen_mask) * self.feasibility_margin 238 | loss_cos = F.cross_entropy(self.scale * pair_pred, pairs) 239 | else: 240 | pair_pred = pair_pred * self.seen_mask + (1 - self.seen_mask) * (-10) 241 | loss_cos = F.cross_entropy(self.scale * pair_pred, pairs) 242 | 243 | return loss_cos.mean(), None 244 | 245 | 246 | def train_forward_closed(self, x): 247 | img, attrs, objs, pairs = x[0], x[1], x[2], x[3] 248 | img_feats = self.image_embedder(img) 249 | 250 | pair_embed = self.compose(self.train_attrs, self.train_objs).permute(1, 0) 251 | img_feats_normed = F.normalize(img_feats, dim=1) 252 | 253 | pair_pred = torch.matmul(img_feats_normed, pair_embed) 254 | 255 | loss_cos = F.cross_entropy(self.scale * pair_pred, pairs) 256 | 257 | return loss_cos.mean(), None 258 | 259 | 260 | def forward(self, x): 261 | if self.training: 262 | loss, pred = self.train_forward(x) 263 | else: 264 | with torch.no_grad(): 265 | loss, pred = self.val_forward(x) 266 | return loss, pred 267 | 268 | 269 | 270 | -------------------------------------------------------------------------------- /models/gcn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.sparse as sp 3 | 4 | import math 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.nn.init import xavier_uniform_ 9 | 10 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 11 | 12 | def normt_spm(mx, method='in'): 13 | if method == 'in': 14 | mx = mx.transpose() 15 | rowsum = np.array(mx.sum(1)) 16 | r_inv = np.power(rowsum, -1).flatten() 17 | r_inv[np.isinf(r_inv)] = 0. 18 | r_mat_inv = sp.diags(r_inv) 19 | mx = r_mat_inv.dot(mx) 20 | return mx 21 | 22 | if method == 'sym': 23 | rowsum = np.array(mx.sum(1)) 24 | r_inv = np.power(rowsum, -0.5).flatten() 25 | r_inv[np.isinf(r_inv)] = 0. 
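        # Symmetric normalization: r_inv holds the node degrees raised to -1/2, so the two dot
        # products below compute D^{-1/2} A^T D^{-1/2}. Note that GCN and GCNII further down
        # build their adjacency with method='in', i.e. the row-normalized transpose D^{-1} A^T.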
26 | r_mat_inv = sp.diags(r_inv) 27 | mx = mx.dot(r_mat_inv).transpose().dot(r_mat_inv) 28 | return mx 29 | 30 | def spm_to_tensor(sparse_mx): 31 | sparse_mx = sparse_mx.tocoo().astype(np.float32) 32 | indices = torch.from_numpy(np.vstack( 33 | (sparse_mx.row, sparse_mx.col))).long() 34 | values = torch.from_numpy(sparse_mx.data) 35 | shape = torch.Size(sparse_mx.shape) 36 | return torch.sparse.FloatTensor(indices, values, shape) 37 | 38 | class GraphConv(nn.Module): 39 | 40 | def __init__(self, in_channels, out_channels, dropout=False, relu=True): 41 | super().__init__() 42 | 43 | if dropout: 44 | self.dropout = nn.Dropout(p=0.5) 45 | else: 46 | self.dropout = None 47 | 48 | self.layer = nn.Linear(in_channels, out_channels) 49 | 50 | if relu: 51 | # self.relu = nn.LeakyReLU(negative_slope=0.2) 52 | self.relu = nn.ReLU() 53 | else: 54 | self.relu = None 55 | 56 | def forward(self, inputs, adj): 57 | if self.dropout is not None: 58 | inputs = self.dropout(inputs) 59 | 60 | outputs = torch.mm(adj, torch.mm(inputs, self.layer.weight.T)) + self.layer.bias 61 | 62 | if self.relu is not None: 63 | outputs = self.relu(outputs) 64 | return outputs 65 | 66 | 67 | class GCN(nn.Module): 68 | 69 | def __init__(self, adj, in_channels, out_channels, hidden_layers): 70 | super().__init__() 71 | 72 | adj = normt_spm(adj, method='in') 73 | adj = spm_to_tensor(adj) 74 | self.adj = adj.to(device) 75 | 76 | self.train_adj = self.adj 77 | 78 | hl = hidden_layers.split(',') 79 | if hl[-1] == 'd': 80 | dropout_last = True 81 | hl = hl[:-1] 82 | else: 83 | dropout_last = False 84 | 85 | i = 0 86 | layers = [] 87 | last_c = in_channels 88 | for c in hl: 89 | if c[0] == 'd': 90 | dropout = True 91 | c = c[1:] 92 | else: 93 | dropout = False 94 | c = int(c) 95 | 96 | i += 1 97 | conv = GraphConv(last_c, c, dropout=dropout) 98 | self.add_module('conv{}'.format(i), conv) 99 | layers.append(conv) 100 | 101 | last_c = c 102 | 103 | conv = GraphConv(last_c, out_channels, relu=False, dropout=dropout_last) 104 | self.add_module('conv-last', conv) 105 | layers.append(conv) 106 | 107 | self.layers = layers 108 | 109 | def forward(self, x): 110 | if self.training: 111 | for conv in self.layers: 112 | x = conv(x, self.train_adj) 113 | else: 114 | for conv in self.layers: 115 | x = conv(x, self.adj) 116 | return F.normalize(x) 117 | 118 | ### GCNII 119 | class GraphConvolution(nn.Module): 120 | 121 | def __init__(self, in_features, out_features, dropout=False, relu=True, residual=False, variant=False): 122 | super(GraphConvolution, self).__init__() 123 | self.variant = variant 124 | if self.variant: 125 | self.in_features = 2*in_features 126 | else: 127 | self.in_features = in_features 128 | 129 | if dropout: 130 | self.dropout = nn.Dropout(p=0.5) 131 | else: 132 | self.dropout = None 133 | 134 | if relu: 135 | self.relu = nn.ReLU() 136 | else: 137 | self.relu = None 138 | 139 | self.out_features = out_features 140 | self.residual = residual 141 | self.layer = nn.Linear(self.in_features, self.out_features, bias = False) 142 | 143 | def reset_parameters(self): 144 | stdv = 1. 
/ math.sqrt(self.out_features) 145 | self.weight.data.uniform_(-stdv, stdv) 146 | 147 | def forward(self, input, adj , h0 , lamda, alpha, l): 148 | if self.dropout is not None: 149 | input = self.dropout(input) 150 | 151 | theta = math.log(lamda/l+1) 152 | hi = torch.spmm(adj, input) 153 | if self.variant: 154 | support = torch.cat([hi,h0],1) 155 | r = (1-alpha)*hi+alpha*h0 156 | else: 157 | support = (1-alpha)*hi+alpha*h0 158 | r = support 159 | 160 | mm_term = torch.mm(support, self.layer.weight.T) 161 | 162 | output = theta*mm_term+(1-theta)*r 163 | if self.residual: 164 | output = output+input 165 | 166 | if self.relu is not None: 167 | output = self.relu(output) 168 | 169 | return output 170 | 171 | class GCNII(nn.Module): 172 | def __init__(self, adj, in_channels , out_channels, hidden_dim, hidden_layers, lamda, alpha, variant, dropout = True): 173 | super(GCNII, self).__init__() 174 | 175 | self.alpha = alpha 176 | self.lamda = lamda 177 | 178 | adj = normt_spm(adj, method='in') 179 | adj = spm_to_tensor(adj) 180 | self.adj = adj.to(device) 181 | 182 | i = 0 183 | layers = nn.ModuleList() 184 | self.fc_dim = nn.Linear(in_channels, hidden_dim) 185 | self.relu = nn.ReLU() 186 | self.dropout = nn.Dropout() 187 | 188 | for i, c in enumerate(range(hidden_layers)): 189 | conv = GraphConvolution(hidden_dim, hidden_dim, variant=variant, dropout=dropout) 190 | layers.append(conv) 191 | 192 | self.layers = layers 193 | self.fc_out = nn.Linear(hidden_dim, out_channels) 194 | 195 | def forward(self, x): 196 | _layers = [] 197 | layer_inner = self.relu(self.fc_dim(self.dropout(x))) 198 | # layer_inner = x 199 | _layers.append(layer_inner) 200 | 201 | for i,con in enumerate(self.layers): 202 | layer_inner = con(layer_inner,self.adj,_layers[0],self.lamda,self.alpha,i+1) 203 | 204 | layer_inner = self.fc_out(self.dropout(layer_inner)) 205 | return layer_inner -------------------------------------------------------------------------------- /models/graph_method.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from .common import MLP 6 | from .gcn import GCN, GCNII 7 | from .word_embedding import load_word_embeddings 8 | import scipy.sparse as sp 9 | 10 | 11 | def adj_to_edges(adj): 12 | # Adj sparse matrix to list of edges 13 | rows, cols = np.nonzero(adj) 14 | edges = list(zip(rows.tolist(), cols.tolist())) 15 | return edges 16 | 17 | 18 | def edges_to_adj(edges, n): 19 | # List of edges to Adj sparse matrix 20 | edges = np.array(edges) 21 | adj = sp.coo_matrix((np.ones(len(edges)), (edges[:, 0], edges[:, 1])), 22 | shape=(n, n), dtype='float32') 23 | return adj 24 | 25 | 26 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 27 | 28 | class GraphFull(nn.Module): 29 | def __init__(self, dset, args): 30 | super(GraphFull, self).__init__() 31 | self.args = args 32 | self.dset = dset 33 | 34 | self.val_forward = self.val_forward_dotpr 35 | self.train_forward = self.train_forward_normal 36 | # Image Embedder 37 | self.num_attrs, self.num_objs, self.num_pairs = len(dset.attrs), len(dset.objs), len(dset.pairs) 38 | self.pairs = dset.pairs 39 | 40 | if self.args.train_only: 41 | train_idx = [] 42 | for current in dset.train_pairs: 43 | train_idx.append(dset.all_pair2idx[current]+self.num_attrs+self.num_objs) 44 | self.train_idx = torch.LongTensor(train_idx).to(device) 45 | 46 | self.args.fc_emb = self.args.fc_emb.split(',') 47 | layers = [] 48 | for a in 
self.args.fc_emb: 49 | a = int(a) 50 | layers.append(a) 51 | 52 | if args.nlayers: 53 | self.image_embedder = MLP(dset.feat_dim, args.emb_dim, num_layers= args.nlayers, dropout = self.args.dropout, 54 | norm = self.args.norm, layers = layers, relu = True) 55 | 56 | all_words = list(self.dset.attrs) + list(self.dset.objs) 57 | self.displacement = len(all_words) 58 | 59 | self.obj_to_idx = {word: idx for idx, word in enumerate(self.dset.objs)} 60 | self.attr_to_idx = {word: idx for idx, word in enumerate(self.dset.attrs)} 61 | 62 | if args.graph_init is not None: 63 | path = args.graph_init 64 | graph = torch.load(path) 65 | embeddings = graph['embeddings'].to(device) 66 | adj = graph['adj'] 67 | self.embeddings = embeddings 68 | else: 69 | embeddings = self.init_embeddings(all_words).to(device) 70 | adj = self.adj_from_pairs() 71 | self.embeddings = embeddings 72 | 73 | hidden_layers = self.args.gr_emb 74 | if args.gcn_type == 'gcn': 75 | self.gcn = GCN(adj, self.embeddings.shape[1], args.emb_dim, hidden_layers) 76 | else: 77 | self.gcn = GCNII(adj, self.embeddings.shape[1], args.emb_dim, args.hidden_dim, args.gcn_nlayers, lamda = 0.5, alpha = 0.1, variant = False) 78 | 79 | 80 | 81 | def init_embeddings(self, all_words): 82 | 83 | def get_compositional_embeddings(embeddings, pairs): 84 | # Getting compositional embeddings from base embeddings 85 | composition_embeds = [] 86 | for (attr, obj) in pairs: 87 | attr_embed = embeddings[self.attr_to_idx[attr]] 88 | obj_embed = embeddings[self.obj_to_idx[obj]+self.num_attrs] 89 | composed_embed = (attr_embed + obj_embed) / 2 90 | composition_embeds.append(composed_embed) 91 | composition_embeds = torch.stack(composition_embeds) 92 | print('Compositional Embeddings are ', composition_embeds.shape) 93 | return composition_embeds 94 | 95 | # init with word embeddings 96 | embeddings = load_word_embeddings(self.args.emb_init, all_words) 97 | 98 | composition_embeds = get_compositional_embeddings(embeddings, self.pairs) 99 | full_embeddings = torch.cat([embeddings, composition_embeds], dim=0) 100 | 101 | return full_embeddings 102 | 103 | 104 | def update_dict(self, wdict, row,col,data): 105 | wdict['row'].append(row) 106 | wdict['col'].append(col) 107 | wdict['data'].append(data) 108 | 109 | def adj_from_pairs(self): 110 | 111 | def edges_from_pairs(pairs): 112 | weight_dict = {'data':[],'row':[],'col':[]} 113 | 114 | 115 | for i in range(self.displacement): 116 | self.update_dict(weight_dict,i,i,1.) 117 | 118 | for idx, (attr, obj) in enumerate(pairs): 119 | attr_idx, obj_idx = self.attr_to_idx[attr], self.obj_to_idx[obj] + self.num_attrs 120 | 121 | self.update_dict(weight_dict, attr_idx, obj_idx, 1.) 122 | self.update_dict(weight_dict, obj_idx, attr_idx, 1.) 123 | 124 | node_id = idx + self.displacement 125 | self.update_dict(weight_dict,node_id,node_id,1.) 126 | 127 | self.update_dict(weight_dict, node_id, attr_idx, 1.) 128 | self.update_dict(weight_dict, node_id, obj_idx, 1.) 129 | 130 | 131 | self.update_dict(weight_dict, attr_idx, node_id, 1.) 132 | self.update_dict(weight_dict, obj_idx, node_id, 1.) 
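            # The resulting graph has one node per attribute, per object and per composition
            # (num_attrs + num_objs + num_pairs nodes in total): every node carries a self-loop,
            # each pair contributes symmetric attribute<->object edges, and every composition
            # node is linked symmetrically to its attribute and its object, e.g. ('sliced', 'apple')
            # adds sliced<->apple, pair<->sliced and pair<->apple edges with weight 1.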
133 | 134 | return weight_dict 135 | 136 | edges = edges_from_pairs(self.pairs) 137 | adj = sp.csr_matrix((edges['data'], (edges['row'], edges['col'])), 138 | shape=(len(self.pairs)+self.displacement, len(self.pairs)+self.displacement)) 139 | 140 | return adj 141 | 142 | 143 | 144 | def train_forward_normal(self, x): 145 | img, attrs, objs, pairs = x[0], x[1], x[2], x[3] 146 | 147 | if self.args.nlayers: 148 | img_feats = self.image_embedder(img) 149 | else: 150 | img_feats = (img) 151 | 152 | current_embeddings = self.gcn(self.embeddings) 153 | 154 | if self.args.train_only: 155 | pair_embed = current_embeddings[self.train_idx] 156 | else: 157 | pair_embed = current_embeddings[self.num_attrs+self.num_objs:self.num_attrs+self.num_objs+self.num_pairs,:] 158 | 159 | pair_embed = pair_embed.permute(1,0) 160 | pair_pred = torch.matmul(img_feats, pair_embed) 161 | loss = F.cross_entropy(pair_pred, pairs) 162 | 163 | return loss, None 164 | 165 | def val_forward_dotpr(self, x): 166 | img = x[0] 167 | 168 | if self.args.nlayers: 169 | img_feats = self.image_embedder(img) 170 | else: 171 | img_feats = (img) 172 | 173 | current_embedddings = self.gcn(self.embeddings) 174 | 175 | pair_embeds = current_embedddings[self.num_attrs+self.num_objs:self.num_attrs+self.num_objs+self.num_pairs,:].permute(1,0) 176 | 177 | score = torch.matmul(img_feats, pair_embeds) 178 | 179 | scores = {} 180 | for itr, pair in enumerate(self.dset.pairs): 181 | scores[pair] = score[:,self.dset.all_pair2idx[pair]] 182 | 183 | return None, scores 184 | 185 | def val_forward_distance_fast(self, x): 186 | img = x[0] 187 | 188 | img_feats = (self.image_embedder(img)) 189 | current_embeddings = self.gcn(self.embeddings) 190 | pair_embeds = current_embeddings[self.num_attrs+self.num_objs:,:] 191 | 192 | batch_size, pairs, features = img_feats.shape[0], pair_embeds.shape[0], pair_embeds.shape[1] 193 | img_feats = img_feats[:,None,:].expand(-1, pairs, -1) 194 | pair_embeds = pair_embeds[None,:,:].expand(batch_size, -1, -1) 195 | diff = (img_feats - pair_embeds)**2 196 | score = diff.sum(2) * -1 197 | 198 | scores = {} 199 | for itr, pair in enumerate(self.dset.pairs): 200 | scores[pair] = score[:,self.dset.all_pair2idx[pair]] 201 | 202 | return None, scores 203 | 204 | def forward(self, x): 205 | if self.training: 206 | loss, pred = self.train_forward(x) 207 | else: 208 | with torch.no_grad(): 209 | loss, pred = self.val_forward(x) 210 | return loss, pred 211 | -------------------------------------------------------------------------------- /models/image_extractor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchvision import models 5 | from torchvision.models.resnet import ResNet, BasicBlock 6 | 7 | class ResNet18_conv(ResNet): 8 | def __init__(self): 9 | super(ResNet18_conv, self).__init__(BasicBlock, [2, 2, 2, 2]) 10 | 11 | def forward(self, x): 12 | # change forward here 13 | x = self.conv1(x) 14 | x = self.bn1(x) 15 | x = self.relu(x) 16 | x = self.maxpool(x) 17 | 18 | x = self.layer1(x) 19 | x = self.layer2(x) 20 | x = self.layer3(x) 21 | x = self.layer4(x) 22 | 23 | return x 24 | 25 | 26 | def get_image_extractor(arch = 'resnet18', pretrained = True, feature_dim = None, checkpoint = ''): 27 | ''' 28 | Inputs 29 | arch: Base architecture 30 | pretrained: Bool, Imagenet weights 31 | feature_dim: Int, output feature dimension 32 | checkpoint: String, not implemented 33 | Returns 34 | Pytorch model 35 | 
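        Supported arch values in this snippet: resnet18 | resnet18_conv | resnet50 | resnet50_cutmix | resnet152 | vgg16.
        For the resnet variants the fc head is replaced by nn.Sequential() when feature_dim is None
        and by a linear layer of size feature_dim otherwise; resnet18_conv returns the spatial
        feature map, resnet50_cutmix additionally loads weights from a local checkpoint, and
        vgg16 truncates its classifier instead.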
''' 36 | 37 | if arch == 'resnet18': 38 | model = models.resnet18(pretrained = pretrained) 39 | if feature_dim is None: 40 | model.fc = nn.Sequential() 41 | else: 42 | model.fc = nn.Linear(512, feature_dim) 43 | 44 | elif arch == 'resnet18_conv': 45 | model = ResNet18_conv() 46 | model.load_state_dict(models.resnet18(pretrained=True).state_dict()) 47 | 48 | elif arch == 'resnet50': 49 | model = models.resnet50(pretrained = pretrained) 50 | if feature_dim is None: 51 | model.fc = nn.Sequential() 52 | else: 53 | model.fc = nn.Linear(2048, feature_dim) 54 | 55 | elif arch == 'resnet50_cutmix': 56 | model = models.resnet50(pretrained = pretrained) 57 | checkpoint = torch.load('/home/ubuntu/workspace/pretrained/resnet50_cutmix.tar') 58 | model.load_state_dict(checkpoint['state_dict'], strict=False) 59 | if feature_dim is None: 60 | model.fc = nn.Sequential() 61 | else: 62 | model.fc = nn.Linear(2048, feature_dim) 63 | 64 | elif arch == 'resnet152': 65 | model = models.resnet152(pretrained = pretrained) 66 | if feature_dim is None: 67 | model.fc = nn.Sequential() 68 | else: 69 | model.fc = nn.Linear(2048, feature_dim) 70 | 71 | elif arch == 'vgg16': 72 | model = models.vgg16(pretrained = pretrained) 73 | modules = list(model.classifier.children())[:-3] 74 | model.classifier=torch.nn.Sequential(*modules) 75 | if feature_dim is not None: 76 | model.classifier[3]=torch.nn.Linear(4096,feature_dim) 77 | 78 | return model 79 | 80 | -------------------------------------------------------------------------------- /models/manifold_methods.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from .word_embedding import load_word_embeddings 6 | from .common import MLP, Reshape 7 | from flags import DATA_FOLDER 8 | 9 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 10 | 11 | class ManifoldModel(nn.Module): 12 | 13 | def __init__(self, dset, args): 14 | super(ManifoldModel, self).__init__() 15 | self.args = args 16 | self.dset = dset 17 | 18 | def get_all_ids(relevant_pairs): 19 | # Precompute validation pairs 20 | attrs, objs = zip(*relevant_pairs) 21 | attrs = [dset.attr2idx[attr] for attr in attrs] 22 | objs = [dset.obj2idx[obj] for obj in objs] 23 | pairs = [a for a in range(len(relevant_pairs))] 24 | attrs = torch.LongTensor(attrs).to(device) 25 | objs = torch.LongTensor(objs).to(device) 26 | pairs = torch.LongTensor(pairs).to(device) 27 | return attrs, objs, pairs 28 | 29 | #Validation 30 | self.val_attrs, self.val_objs, self.val_pairs = get_all_ids(self.dset.pairs) 31 | # for individual projections 32 | self.uniq_attrs, self.uniq_objs = torch.arange(len(self.dset.attrs)).long().to(device), \ 33 | torch.arange(len(self.dset.objs)).long().to(device) 34 | self.factor = 2 35 | # Precompute training compositions 36 | if args.train_only: 37 | self.train_attrs, self.train_objs, self.train_pairs = get_all_ids(self.dset.train_pairs) 38 | else: 39 | self.train_attrs, self.train_objs, self.train_pairs = self.val_attrs, self.val_objs, self.val_pairs 40 | 41 | if args.lambda_aux > 0 or args.lambda_cls_attr > 0 or args.lambda_cls_obj > 0: 42 | print('Initializing classifiers') 43 | self.obj_clf = nn.Linear(args.emb_dim, len(dset.objs)) 44 | self.attr_clf = nn.Linear(args.emb_dim, len(dset.attrs)) 45 | 46 | def train_forward_bce(self, x): 47 | img, attrs, objs = x[0], x[1], x[2] 48 | neg_attrs, neg_objs = x[4][:,0],x[5][:,0] #todo: do a less hacky version 49 | 50 | img_feat =
self.image_embedder(img) 51 | # Sample 25% positive and 75% negative pairs 52 | labels = np.random.binomial(1, 0.25, attrs.shape[0]) 53 | labels = torch.from_numpy(labels).bool().to(device) 54 | sampled_attrs, sampled_objs = neg_attrs.clone(), neg_objs.clone() 55 | sampled_attrs[labels] = attrs[labels] 56 | sampled_objs[labels] = objs[labels] 57 | labels = labels.float() 58 | 59 | composed_clf = self.compose(attrs, objs) 60 | p = torch.sigmoid((img_feat*composed_clf).sum(1)) 61 | loss = F.binary_cross_entropy(p, labels) 62 | 63 | return loss, None 64 | 65 | def train_forward_triplet(self, x): 66 | img, attrs, objs = x[0], x[1], x[2] 67 | neg_attrs, neg_objs = x[4][:,0], x[5][:,0] #todo:do a less hacky version 68 | 69 | img_feats = self.image_embedder(img) 70 | positive = self.compose(attrs, objs) 71 | negative = self.compose(neg_attrs, neg_objs) 72 | loss = F.triplet_margin_loss(img_feats, positive, negative, margin = self.args.margin) 73 | 74 | # Auxiliary object/ attribute prediction loss both need to be correct 75 | if self.args.lambda_aux > 0: 76 | obj_pred = self.obj_clf(positive) 77 | attr_pred = self.attr_clf(positive) 78 | loss_aux = F.cross_entropy(attr_pred, attrs) + F.cross_entropy(obj_pred, objs) 79 | loss += self.args.lambda_aux * loss_aux 80 | 81 | return loss, None 82 | 83 | 84 | def val_forward_distance(self, x): 85 | img = x[0] 86 | batch_size = img.shape[0] 87 | 88 | img_feats = self.image_embedder(img) 89 | scores = {} 90 | pair_embeds = self.compose(self.val_attrs, self.val_objs) 91 | 92 | for itr, pair in enumerate(self.dset.pairs): 93 | pair_embed = pair_embeds[itr, None].expand(batch_size, pair_embeds.size(1)) 94 | score = self.compare_metric(img_feats, pair_embed) 95 | scores[pair] = score 96 | 97 | return None, scores 98 | 99 | def val_forward_distance_fast(self, x): 100 | img = x[0] 101 | batch_size = img.shape[0] 102 | 103 | img_feats = self.image_embedder(img) 104 | pair_embeds = self.compose(self.val_attrs, self.val_objs) # Evaluate all pairs 105 | 106 | batch_size, pairs, features = img_feats.shape[0], pair_embeds.shape[0], pair_embeds.shape[1] 107 | img_feats = img_feats[:,None,:].expand(-1, pairs, -1) 108 | pair_embeds = pair_embeds[None,:,:].expand(batch_size, -1, -1) 109 | diff = (img_feats - pair_embeds)**2 110 | score = diff.sum(2) * -1 111 | 112 | scores = {} 113 | for itr, pair in enumerate(self.dset.pairs): 114 | scores[pair] = score[:,self.dset.all_pair2idx[pair]] 115 | 116 | return None, scores 117 | 118 | def val_forward_direct(self, x): 119 | img = x[0] 120 | batch_size = img.shape[0] 121 | 122 | img_feats = self.image_embedder(img) 123 | pair_embeds = self.compose(self.val_attrs, self.val_objs).permute(1,0) # Evaluate all pairs 124 | score = torch.matmul(img_feats, pair_embeds) 125 | 126 | scores = {} 127 | for itr, pair in enumerate(self.dset.pairs): 128 | scores[pair] = score[:,self.dset.all_pair2idx[pair]] 129 | 130 | return None, scores 131 | 132 | def forward(self, x): 133 | if self.training: 134 | loss, pred = self.train_forward(x) 135 | else: 136 | with torch.no_grad(): 137 | loss, pred = self.val_forward(x) 138 | return loss, pred 139 | 140 | ######################################### 141 | class RedWine(ManifoldModel): 142 | 143 | def __init__(self, dset, args): 144 | super(RedWine, self).__init__(dset, args) 145 | self.image_embedder = lambda img: img 146 | self.compare_metric = lambda img_feats, pair_embed: torch.sigmoid((img_feats*pair_embed).sum(1)) 147 | self.train_forward = self.train_forward_bce 148 | self.val_forward = 
self.val_forward_distance_fast 149 | 150 | in_dim = dset.feat_dim if args.clf_init else args.emb_dim 151 | self.T = nn.Sequential( 152 | nn.Linear(2*in_dim, 3*in_dim), 153 | nn.LeakyReLU(0.1, True), 154 | nn.Linear(3*in_dim, 3*in_dim//2), 155 | nn.LeakyReLU(0.1, True), 156 | nn.Linear(3*in_dim//2, dset.feat_dim)) 157 | 158 | self.attr_embedder = nn.Embedding(len(dset.attrs), in_dim) 159 | self.obj_embedder = nn.Embedding(len(dset.objs), in_dim) 160 | 161 | # initialize the weights of the embedders with the svm weights 162 | if args.emb_init: 163 | pretrained_weight = load_word_embeddings(args.emb_init, dset.attrs) 164 | self.attr_embedder.weight.data.copy_(pretrained_weight) 165 | pretrained_weight = load_word_embeddings(args.emb_init, dset.objs) 166 | self.obj_embedder.weight.data.copy_(pretrained_weight) 167 | 168 | elif args.clf_init: 169 | for idx, attr in enumerate(dset.attrs): 170 | at_id = self.dset.attr2idx[attr] 171 | weight = torch.load('%s/svm/attr_%d'%(args.data_dir, at_id)).coef_.squeeze() 172 | self.attr_embedder.weight[idx].data.copy_(torch.from_numpy(weight)) 173 | for idx, obj in enumerate(dset.objs): 174 | obj_id = self.dset.obj2idx[obj] 175 | weight = torch.load('%s/svm/obj_%d'%(args.data_dir, obj_id)).coef_.squeeze() 176 | self.obj_embedder.weight[idx].data.copy_(torch.from_numpy(weight)) 177 | else: 178 | print ('init must be either glove or clf') 179 | return 180 | 181 | if args.static_inp: 182 | for param in self.attr_embedder.parameters(): 183 | param.requires_grad = False 184 | for param in self.obj_embedder.parameters(): 185 | param.requires_grad = False 186 | 187 | def compose(self, attrs, objs): 188 | attr_wt = self.attr_embedder(attrs) 189 | obj_wt = self.obj_embedder(objs) 190 | inp_wts = torch.cat([attr_wt, obj_wt], 1) # 2D 191 | composed_clf = self.T(inp_wts) 192 | return composed_clf 193 | 194 | class LabelEmbedPlus(ManifoldModel): 195 | def __init__(self, dset, args): 196 | super(LabelEmbedPlus, self).__init__(dset, args) 197 | if 'conv' in args.image_extractor: 198 | self.image_embedder = torch.nn.Sequential(torch.nn.Conv2d(dset.feat_dim,args.emb_dim,7), 199 | torch.nn.ReLU(True), 200 | Reshape(-1,args.emb_dim) 201 | ) 202 | else: 203 | self.image_embedder = MLP(dset.feat_dim, args.emb_dim) 204 | 205 | self.compare_metric = lambda img_feats, pair_embed: -F.pairwise_distance(img_feats, pair_embed) 206 | self.train_forward = self.train_forward_triplet 207 | self.val_forward = self.val_forward_distance_fast 208 | 209 | input_dim = dset.feat_dim if args.clf_init else args.emb_dim 210 | self.attr_embedder = nn.Embedding(len(dset.attrs), input_dim) 211 | self.obj_embedder = nn.Embedding(len(dset.objs), input_dim) 212 | self.T = MLP(2*input_dim, args.emb_dim, num_layers= args.nlayers) 213 | 214 | # init with word embeddings 215 | if args.emb_init: 216 | pretrained_weight = load_word_embeddings(args.emb_init, dset.attrs) 217 | self.attr_embedder.weight.data.copy_(pretrained_weight) 218 | pretrained_weight = load_word_embeddings(args.emb_init, dset.objs) 219 | self.obj_embedder.weight.data.copy_(pretrained_weight) 220 | 221 | # init with classifier weights 222 | elif args.clf_init: 223 | for idx, attr in enumerate(dset.attrs): 224 | at_id = dset.attrs.index(attr) 225 | weight = torch.load('%s/svm/attr_%d'%(args.data_dir, at_id)).coef_.squeeze() 226 | self.attr_embedder.weight[idx].data.copy_(torch.from_numpy(weight)) 227 | for idx, obj in enumerate(dset.objs): 228 | obj_id = dset.objs.index(obj) 229 | weight = torch.load('%s/svm/obj_%d'%(args.data_dir, 
obj_id)).coef_.squeeze() 230 | self.obj_embedder.weight[idx].data.copy_(torch.from_numpy(weight)) 231 | 232 | # static inputs 233 | if args.static_inp: 234 | for param in self.attr_embedder.parameters(): 235 | param.requires_grad = False 236 | for param in self.obj_embedder.parameters(): 237 | param.requires_grad = False 238 | 239 | def compose(self, attrs, objs): 240 | inputs = [self.attr_embedder(attrs), self.obj_embedder(objs)] 241 | inputs = torch.cat(inputs, 1) 242 | output = self.T(inputs) 243 | return output 244 | 245 | class AttributeOperator(ManifoldModel): 246 | def __init__(self, dset, args): 247 | super(AttributeOperator, self).__init__(dset, args) 248 | self.image_embedder = MLP(dset.feat_dim, args.emb_dim) 249 | self.compare_metric = lambda img_feats, pair_embed: -F.pairwise_distance(img_feats, pair_embed) 250 | self.val_forward = self.val_forward_distance_fast 251 | 252 | self.attr_ops = nn.ParameterList([nn.Parameter(torch.eye(args.emb_dim)) for _ in range(len(self.dset.attrs))]) 253 | self.obj_embedder = nn.Embedding(len(dset.objs), args.emb_dim) 254 | 255 | if args.emb_init: 256 | pretrained_weight = load_word_embeddings('glove', dset.objs) 257 | self.obj_embedder.weight.data.copy_(pretrained_weight) 258 | 259 | self.inverse_cache = {} 260 | 261 | if args.lambda_ant>0 and args.dataset=='mitstates': 262 | antonym_list = open(DATA_FOLDER+'/data/antonyms.txt').read().strip().split('\n') 263 | antonym_list = [l.split() for l in antonym_list] 264 | antonym_list = [[self.dset.attrs.index(a1), self.dset.attrs.index(a2)] for a1, a2 in antonym_list] 265 | antonyms = {} 266 | antonyms.update({a1:a2 for a1, a2 in antonym_list}) 267 | antonyms.update({a2:a1 for a1, a2 in antonym_list}) 268 | self.antonyms, self.antonym_list = antonyms, antonym_list 269 | 270 | if args.static_inp: 271 | for param in self.obj_embedder.parameters(): 272 | param.requires_grad = False 273 | 274 | def apply_ops(self, ops, rep): 275 | out = torch.bmm(ops, rep.unsqueeze(2)).squeeze(2) 276 | out = F.relu(out) 277 | return out 278 | 279 | def compose(self, attrs, objs): 280 | obj_rep = self.obj_embedder(objs) 281 | attr_ops = torch.stack([self.attr_ops[attr.item()] for attr in attrs]) 282 | embedded_reps = self.apply_ops(attr_ops, obj_rep) 283 | return embedded_reps 284 | 285 | def apply_inverse(self, img_rep, attrs): 286 | inverse_ops = [] 287 | for i in range(img_rep.size(0)): 288 | attr = attrs[i] 289 | if attr not in self.inverse_cache: 290 | self.inverse_cache[attr] = self.attr_ops[attr].inverse() 291 | inverse_ops.append(self.inverse_cache[attr]) 292 | inverse_ops = torch.stack(inverse_ops) # (B,512,512) 293 | obj_rep = self.apply_ops(inverse_ops, img_rep) 294 | return obj_rep 295 | 296 | def train_forward(self, x): 297 | img, attrs, objs = x[0], x[1], x[2] 298 | neg_attrs, neg_objs, inv_attrs, comm_attrs = x[4][:,0], x[5][:,0], x[6], x[7] 299 | batch_size = img.size(0) 300 | 301 | loss = [] 302 | 303 | anchor = self.image_embedder(img) 304 | 305 | obj_emb = self.obj_embedder(objs) 306 | pos_ops = torch.stack([self.attr_ops[attr.item()] for attr in attrs]) 307 | positive = self.apply_ops(pos_ops, obj_emb) 308 | 309 | neg_obj_emb = self.obj_embedder(neg_objs) 310 | neg_ops = torch.stack([self.attr_ops[attr.item()] for attr in neg_attrs]) 311 | negative = self.apply_ops(neg_ops, neg_obj_emb) 312 | 313 | loss_triplet = F.triplet_margin_loss(anchor, positive, negative, margin=self.args.margin) 314 | loss.append(loss_triplet) 315 | 316 | 317 | # Auxiliary object/attribute loss 318 | if self.args.lambda_aux>0: 319
| obj_pred = self.obj_clf(positive) 320 | attr_pred = self.attr_clf(positive) 321 | loss_aux = F.cross_entropy(attr_pred, attrs) + F.cross_entropy(obj_pred, objs) 322 | loss.append(self.args.lambda_aux*loss_aux) 323 | 324 | # Inverse Consistency 325 | if self.args.lambda_inv>0: 326 | obj_rep = self.apply_inverse(anchor, attrs) 327 | new_ops = torch.stack([self.attr_ops[attr.item()] for attr in inv_attrs]) 328 | new_rep = self.apply_ops(new_ops, obj_rep) 329 | new_positive = self.apply_ops(new_ops, obj_emb) 330 | loss_inv = F.triplet_margin_loss(new_rep, new_positive, positive, margin=self.args.margin) 331 | loss.append(self.args.lambda_inv*loss_inv) 332 | 333 | # Commutative Operators 334 | if self.args.lambda_comm>0: 335 | B = torch.stack([self.attr_ops[attr.item()] for attr in comm_attrs]) 336 | BA = self.apply_ops(B, positive) 337 | AB = self.apply_ops(pos_ops, self.apply_ops(B, obj_emb)) 338 | loss_comm = ((AB-BA)**2).sum(1).mean() 339 | loss.append(self.args.lambda_comm*loss_comm) 340 | 341 | # Antonym Consistency 342 | if self.args.lambda_ant>0: 343 | 344 | select_idx = [i for i in range(batch_size) if attrs[i].item() in self.antonyms] 345 | if len(select_idx)>0: 346 | select_idx = torch.LongTensor(select_idx).cuda() 347 | attr_subset = attrs[select_idx] 348 | antonym_ops = torch.stack([self.attr_ops[self.antonyms[attr.item()]] for attr in attr_subset]) 349 | 350 | Ao = anchor[select_idx] 351 | if self.args.lambda_inv>0: 352 | o = obj_rep[select_idx] 353 | else: 354 | o = self.apply_inverse(Ao, attr_subset) 355 | BAo = self.apply_ops(antonym_ops, Ao) 356 | 357 | loss_cycle = ((BAo-o)**2).sum(1).mean() 358 | loss.append(self.args.lambda_ant*loss_cycle) 359 | 360 | 361 | loss = sum(loss) 362 | return loss, None 363 | 364 | def forward(self, x): 365 | loss, pred = super(AttributeOperator, self).forward(x) 366 | self.inverse_cache = {} 367 | return loss, pred -------------------------------------------------------------------------------- /models/modular_methods.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torchvision.models as tmodels 11 | import numpy as np 12 | from .word_embedding import load_word_embeddings 13 | import itertools 14 | import math 15 | import collections 16 | from torch.distributions.bernoulli import Bernoulli 17 | import pdb 18 | import sys 19 | 20 | 21 | class Flatten(nn.Module): 22 | def forward(self, input): 23 | return input.view(input.size(0), -1) 24 | 25 | 26 | class GatingSampler(nn.Module): 27 | """Docstring for GatingSampler. """ 28 | 29 | def __init__(self, gater, stoch_sample=True, temperature=1.0): 30 | """TODO: to be defined1. 
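        :gater: module mapping a task descriptor to gating weights (it may return a
            (gating, weights) tuple), or None to disable gating
        :stoch_sample: stochastic sampling of the gates; not implemented in the code shown,
            so it has to be disabled before calling forward
        :temperature: stored for the sampler, unused in this snippet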
31 | 32 | """ 33 | nn.Module.__init__(self) 34 | self.gater = gater 35 | self._stoch_sample = stoch_sample 36 | self._temperature = temperature 37 | 38 | def disable_stochastic_sampling(self): 39 | self._stoch_sample = False 40 | 41 | def enable_stochastic_sampling(self): 42 | self._stoch_sample = True 43 | 44 | def forward(self, tdesc=None, return_additional=False, gating_wt=None): 45 | if self.gater is None and return_additional: 46 | return None, None 47 | elif self.gater is None: 48 | return None 49 | 50 | if gating_wt is not None: 51 | return_wts = gating_wt 52 | gating_g = self.gater(tdesc, gating_wt=gating_wt) 53 | else: 54 | gating_g = self.gater(tdesc) 55 | return_wts = None 56 | if isinstance(gating_g, tuple): 57 | return_wts = gating_g[1] 58 | gating_g = gating_g[0] 59 | 60 | if not self._stoch_sample: 61 | sampled_g = gating_g 62 | else: 63 | raise (NotImplementedError) 64 | 65 | if return_additional: 66 | return sampled_g, return_wts 67 | return sampled_g 68 | 69 | 70 | class GatedModularNet(nn.Module): 71 | """ 72 | An interface for creating modular nets 73 | """ 74 | 75 | def __init__(self, 76 | module_list, 77 | start_modules=None, 78 | end_modules=None, 79 | single_head=False, 80 | chain=False): 81 | """TODO: to be defined1. 82 | 83 | :module_list: TODO 84 | :g: TODO 85 | 86 | """ 87 | nn.Module.__init__(self) 88 | self._module_list = nn.ModuleList( 89 | [nn.ModuleList(m) for m in module_list]) 90 | 91 | self.num_layers = len(self._module_list) 92 | if start_modules is not None: 93 | self._start_modules = nn.ModuleList(start_modules) 94 | else: 95 | self._start_modules = None 96 | 97 | if end_modules is not None: 98 | self._end_modules = nn.ModuleList(end_modules) 99 | self.num_layers += 1 100 | else: 101 | self._end_modules = None 102 | 103 | self.sampled_g = None 104 | self.single_head = single_head 105 | self._chain = chain 106 | 107 | def forward(self, x, sampled_g=None, t=None, return_feat=False): 108 | """TODO: Docstring for forward. 
109 | 110 | :x: Input data 111 | :g: Gating tensor (#Task x )#num_layer x #num_mods x #num_mods 112 | :t: task ID 113 | :returns: TODO 114 | 115 | """ 116 | 117 | if t is None: 118 | t_not_set = True 119 | t = torch.tensor([0] * x.shape[0], dtype=x.dtype).long() 120 | else: 121 | t_not_set = False 122 | t = t.squeeze() 123 | 124 | if self._start_modules is not None: 125 | prev_out = [mod(x) for mod in self._start_modules] 126 | else: 127 | prev_out = [x] 128 | 129 | if sampled_g is None: 130 | # NON-Gated Module network 131 | prev_out = sum(prev_out) / float(len(prev_out)) 132 | #prev_out = torch.mean(prev_out, 0) 133 | for li in range(len(self._module_list)): 134 | prev_out = sum([ 135 | mod(prev_out) for mod in self._module_list[li] 136 | ]) / float(len(self._module_list[li])) 137 | features = prev_out 138 | if self._end_modules is not None: 139 | if t_not_set or self.single_head: 140 | prev_out = self._end_modules[0](prev_out) 141 | else: 142 | prev_out = torch.cat([ 143 | self._end_modules[tid](prev_out[bi:bi + 1]) 144 | for bi, tid in enumerate(t) 145 | ], 0) 146 | if return_feat: 147 | return prev_out, features 148 | return prev_out 149 | else: 150 | # Forward prop with sampled Gs 151 | for li in range(len(self._module_list)): 152 | curr_out = [] 153 | for j in range(len(self._module_list[li])): 154 | gind = j if not self._chain else 0 155 | # Dim: #Batch x C 156 | module_in_wt = sampled_g[li + 1][gind] 157 | # Module input weights rearranged to match inputs 158 | module_in_wt = module_in_wt.transpose(0, 1) 159 | add_dims = prev_out[0].dim() + 1 - module_in_wt.dim() 160 | module_in_wt = module_in_wt.view(*module_in_wt.shape, 161 | *([1] * add_dims)) 162 | module_in_wt = module_in_wt.expand( 163 | len(prev_out), *prev_out[0].shape) 164 | module_in = sum([ 165 | module_in_wt[i] * prev_out[i] 166 | for i in range(len(prev_out)) 167 | ]) 168 | mod = self._module_list[li][j] 169 | curr_out.append(mod(module_in)) 170 | prev_out = curr_out 171 | 172 | # Output modules (with sampled Gs) 173 | if self._end_modules is not None: 174 | li = self.num_layers - 1 175 | if t_not_set or self.single_head: 176 | # Dim: #Batch x C 177 | module_in_wt = sampled_g[li + 1][0] 178 | # Module input weights rearranged to match inputs 179 | module_in_wt = module_in_wt.transpose(0, 1) 180 | add_dims = prev_out[0].dim() + 1 - module_in_wt.dim() 181 | module_in_wt = module_in_wt.view(*module_in_wt.shape, 182 | *([1] * add_dims)) 183 | module_in_wt = module_in_wt.expand( 184 | len(prev_out), *prev_out[0].shape) 185 | module_in = sum([ 186 | module_in_wt[i] * prev_out[i] 187 | for i in range(len(prev_out)) 188 | ]) 189 | features = module_in 190 | prev_out = self._end_modules[0](module_in) 191 | else: 192 | curr_out = [] 193 | for bi, tid in enumerate(t): 194 | # Dim: #Batch x C 195 | gind = tid if not self._chain else 0 196 | module_in_wt = sampled_g[li + 1][gind] 197 | # Module input weights rearranged to match inputs 198 | module_in_wt = module_in_wt.transpose(0, 1) 199 | add_dims = prev_out[0].dim() + 1 - module_in_wt.dim() 200 | module_in_wt = module_in_wt.view( 201 | *module_in_wt.shape, *([1] * add_dims)) 202 | module_in_wt = module_in_wt.expand( 203 | len(prev_out), *prev_out[0].shape) 204 | module_in = sum([ 205 | module_in_wt[i] * prev_out[i] 206 | for i in range(len(prev_out)) 207 | ]) 208 | features = module_in 209 | mod = self._end_modules[tid] 210 | curr_out.append(mod(module_in[bi:bi + 1])) 211 | prev_out = curr_out 212 | prev_out = torch.cat(prev_out, 0) 213 | if return_feat: 214 | return prev_out, 
features 215 | return prev_out 216 | 217 | 218 | class CompositionalModel(nn.Module): 219 | def __init__(self, dset, args): 220 | super(CompositionalModel, self).__init__() 221 | self.args = args 222 | self.dset = dset 223 | 224 | # precompute validation pairs 225 | attrs, objs = zip(*self.dset.pairs) 226 | attrs = [dset.attr2idx[attr] for attr in attrs] 227 | objs = [dset.obj2idx[obj] for obj in objs] 228 | self.val_attrs = torch.LongTensor(attrs).cuda() 229 | self.val_objs = torch.LongTensor(objs).cuda() 230 | 231 | def train_forward_softmax(self, x): 232 | img, attrs, objs = x[0], x[1], x[2] 233 | neg_attrs, neg_objs = x[4], x[5] 234 | inv_attrs, comm_attrs = x[6], x[7] 235 | 236 | sampled_attrs = torch.cat((attrs.unsqueeze(1), neg_attrs), 1) 237 | sampled_objs = torch.cat((objs.unsqueeze(1), neg_objs), 1) 238 | img_ind = torch.arange(sampled_objs.shape[0]).unsqueeze(1).repeat( 239 | 1, sampled_attrs.shape[1]) 240 | 241 | flat_sampled_attrs = sampled_attrs.view(-1) 242 | flat_sampled_objs = sampled_objs.view(-1) 243 | flat_img_ind = img_ind.view(-1) 244 | labels = torch.zeros_like(sampled_attrs[:, 0]).long() 245 | 246 | self.composed_g = self.compose(flat_sampled_attrs, flat_sampled_objs) 247 | 248 | cls_scores, feat = self.comp_network( 249 | img[flat_img_ind], self.composed_g, return_feat=True) 250 | pair_scores = cls_scores[:, :1] 251 | pair_scores = pair_scores.view(*sampled_attrs.shape) 252 | 253 | loss = 0 254 | loss_cls = F.cross_entropy(pair_scores, labels) 255 | loss += loss_cls 256 | 257 | loss_obj = torch.FloatTensor([0]) 258 | loss_attr = torch.FloatTensor([0]) 259 | loss_sparse = torch.FloatTensor([0]) 260 | loss_unif = torch.FloatTensor([0]) 261 | loss_aux = torch.FloatTensor([0]) 262 | 263 | acc = (pair_scores.argmax(1) == labels).sum().float() / float( 264 | len(labels)) 265 | all_losses = {} 266 | all_losses['total_loss'] = loss 267 | all_losses['main_loss'] = loss_cls 268 | all_losses['aux_loss'] = loss_aux 269 | all_losses['obj_loss'] = loss_obj 270 | all_losses['attr_loss'] = loss_attr 271 | all_losses['sparse_loss'] = loss_sparse 272 | all_losses['unif_loss'] = loss_unif 273 | 274 | return loss, all_losses, acc, (pair_scores, feat) 275 | 276 | def val_forward(self, x): 277 | img = x[0] 278 | batch_size = img.shape[0] 279 | pair_scores = torch.zeros(batch_size, len(self.val_attrs)) 280 | pair_feats = torch.zeros(batch_size, len(self.val_attrs), 281 | self.args.emb_dim) 282 | pair_bs = len(self.val_attrs) 283 | 284 | for pi in range(math.ceil(len(self.val_attrs) / pair_bs)): 285 | self.compose_g = self.compose( 286 | self.val_attrs[pi * pair_bs:(pi + 1) * pair_bs], 287 | self.val_objs[pi * pair_bs:(pi + 1) * pair_bs]) 288 | compose_g = self.compose_g 289 | expanded_im = img.unsqueeze(1).repeat( 290 | 1, compose_g[0][0].shape[0], 291 | *tuple([1] * (img.dim() - 1))).view(-1, *img.shape[1:]) 292 | expanded_compose_g = [[ 293 | g.unsqueeze(0).repeat(batch_size, *tuple([1] * g.dim())).view( 294 | -1, *g.shape[1:]) for g in layer_g 295 | ] for layer_g in compose_g] 296 | this_pair_scores, this_feat = self.comp_network( 297 | expanded_im, expanded_compose_g, return_feat=True) 298 | featnorm = torch.norm(this_feat, p=2, dim=-1) 299 | this_feat = this_feat.div( 300 | featnorm.unsqueeze(-1).expand_as(this_feat)) 301 | 302 | this_pair_scores = this_pair_scores[:, :1].view(batch_size, -1) 303 | this_feat = this_feat.view(batch_size, -1, self.args.emb_dim) 304 | 305 | pair_scores[:, pi * pair_bs:pi * pair_bs + 306 | this_pair_scores.shape[1]] = this_pair_scores[:, :] 307 | 
pair_feats[:, pi * pair_bs:pi * pair_bs + 308 | this_pair_scores.shape[1], :] = this_feat[:] 309 | 310 | scores = {} 311 | feats = {} 312 | for i, (attr, obj) in enumerate(self.dset.pairs): 313 | scores[(attr, obj)] = pair_scores[:, i] 314 | feats[(attr, obj)] = pair_feats[:, i] 315 | 316 | # return None, (scores, feats) 317 | return None, scores 318 | 319 | def forward(self, x, with_grad=False): 320 | if self.training: 321 | loss, loss_aux, acc, pred = self.train_forward(x) 322 | else: 323 | loss_aux = torch.Tensor([0]) 324 | loss = torch.Tensor([0]) 325 | if not with_grad: 326 | with torch.no_grad(): 327 | acc, pred = self.val_forward(x) 328 | else: 329 | acc, pred = self.val_forward(x) 330 | # return loss, loss_aux, acc, pred 331 | return loss, pred 332 | 333 | 334 | class GatedGeneralNN(CompositionalModel): 335 | """Docstring for GatedCompositionalModel. """ 336 | 337 | def __init__(self, 338 | dset, 339 | args, 340 | num_layers=2, 341 | num_modules_per_layer=3, 342 | stoch_sample=False, 343 | use_full_model=False, 344 | num_classes=[2], 345 | gater_type='general'): 346 | """TODO: to be defined1. 347 | 348 | :dset: TODO 349 | :args: TODO 350 | 351 | """ 352 | CompositionalModel.__init__(self, dset, args) 353 | 354 | self.train_forward = self.train_forward_softmax 355 | self.compose_type = 'nn' #todo: could be different 356 | 357 | gating_in_dim = 128 358 | if args.emb_init: 359 | gating_in_dim = 300 360 | elif args.clf_init: 361 | gating_in_dim = 512 362 | 363 | if self.compose_type == 'nn': 364 | tdim = gating_in_dim * 2 365 | inter_tdim = self.args.embed_rank 366 | # Change this to allow only obj, only attr gatings 367 | self.attr_embedder = nn.Embedding( 368 | len(dset.attrs) + 1, 369 | gating_in_dim, 370 | padding_idx=len(dset.attrs), 371 | ) 372 | self.obj_embedder = nn.Embedding( 373 | len(dset.objs) + 1, 374 | gating_in_dim, 375 | padding_idx=len(dset.objs), 376 | ) 377 | 378 | # initialize the weights of the embedders with the svm weights 379 | if args.emb_init: 380 | pretrained_weight = load_word_embeddings( 381 | args.emb_init, dset.attrs) 382 | self.attr_embedder.weight[:-1, :].data.copy_(pretrained_weight) 383 | pretrained_weight = load_word_embeddings( 384 | args.emb_init, dset.objs) 385 | self.obj_embedder.weight.data[:-1, :].copy_(pretrained_weight) 386 | elif args.clf_init: 387 | for idx, attr in enumerate(dset.attrs): 388 | at_id = self.dset.attr2idx[attr] 389 | weight = torch.load( 390 | '%s/svm/attr_%d' % (args.data_dir, 391 | at_id)).coef_.squeeze() 392 | self.attr_embedder.weight[idx].data.copy_( 393 | torch.from_numpy(weight)) 394 | for idx, obj in enumerate(dset.objs): 395 | obj_id = self.dset.obj2idx[obj] 396 | weight = torch.load( 397 | '%s/svm/obj_%d' % (args.data_dir, 398 | obj_id)).coef_.squeeze() 399 | self.obj_embedder.weight[idx].data.copy_( 400 | torch.from_numpy(weight)) 401 | else: 402 | n_attr = len(dset.attrs) 403 | gating_in_dim = 300 404 | tdim = gating_in_dim * 2 + n_attr 405 | self.attr_embedder = nn.Embedding( 406 | n_attr, 407 | n_attr, 408 | ) 409 | self.attr_embedder.weight.data.copy_( 410 | torch.from_numpy(np.eye(n_attr))) 411 | self.obj_embedder = nn.Embedding( 412 | len(dset.objs) + 1, 413 | gating_in_dim, 414 | padding_idx=len(dset.objs), 415 | ) 416 | pretrained_weight = load_word_embeddings( 417 | '/home/ubuntu/workspace/czsl/data/glove/glove.6B.300d.txt', dset.objs) 418 | self.obj_embedder.weight.data[:-1, :].copy_(pretrained_weight) 419 | else: 420 | raise (NotImplementedError) 421 | 422 | self.comp_network, self.gating_network, 
self.nummods, _ = modular_general( 423 | num_layers=num_layers, 424 | num_modules_per_layer=num_modules_per_layer, 425 | feat_dim=dset.feat_dim, 426 | inter_dim=args.emb_dim, 427 | stoch_sample=stoch_sample, 428 | use_full_model=use_full_model, 429 | tdim=tdim, 430 | inter_tdim=inter_tdim, 431 | gater_type=gater_type, 432 | ) 433 | 434 | if args.static_inp: 435 | for param in self.attr_embedder.parameters(): 436 | param.requires_grad = False 437 | for param in self.obj_embedder.parameters(): 438 | param.requires_grad = False 439 | 440 | def compose(self, attrs, objs): 441 | obj_wt = self.obj_embedder(objs) 442 | if self.compose_type == 'nn': 443 | attr_wt = self.attr_embedder(attrs) 444 | inp_wts = torch.cat([attr_wt, obj_wt], 1) # 2D 445 | else: 446 | raise (NotImplementedError) 447 | composed_g, composed_g_wt = self.gating_network( 448 | inp_wts, return_additional=True) 449 | 450 | return composed_g 451 | 452 | 453 | class GeneralNormalizedNN(nn.Module): 454 | """Docstring for GatedCompositionalModel. """ 455 | 456 | def __init__(self, num_layers, num_modules_per_layer, in_dim, inter_dim): 457 | """TODO: to be defined1. """ 458 | nn.Module.__init__(self) 459 | self.start_modules = [nn.Sequential()] 460 | self.layer1 = [[ 461 | nn.Sequential( 462 | nn.Linear(in_dim, inter_dim), nn.BatchNorm1d(inter_dim), 463 | nn.ReLU()) for _ in range(num_modules_per_layer) 464 | ]] 465 | if num_layers > 1: 466 | self.layer2 = [[ 467 | nn.Sequential( 468 | nn.BatchNorm1d(inter_dim), nn.Linear(inter_dim, inter_dim), 469 | nn.BatchNorm1d(inter_dim), nn.ReLU()) 470 | for _m in range(num_modules_per_layer) 471 | ] for _l in range(num_layers - 1)] 472 | self.avgpool = nn.Sequential() 473 | self.fc = [ 474 | nn.Sequential(nn.BatchNorm1d(inter_dim), nn.Linear(inter_dim, 1)) 475 | ] 476 | 477 | 478 | class GeneralGatingNN(nn.Module): 479 | def __init__( 480 | self, 481 | num_mods, 482 | tdim, 483 | inter_tdim, 484 | randinit=False, 485 | ): 486 | """TODO: to be defined1. 
487 | 488 | :num_mods: TODO 489 | :tdim: TODO 490 | 491 | """ 492 | nn.Module.__init__(self) 493 | 494 | self._num_mods = num_mods 495 | self._tdim = tdim 496 | self._inter_tdim = inter_tdim 497 | task_outdim = self._inter_tdim 498 | 499 | self.task_linear1 = nn.Linear(self._tdim, task_outdim, bias=False) 500 | self.task_bn1 = nn.BatchNorm1d(task_outdim) 501 | self.task_linear2 = nn.Linear(task_outdim, task_outdim, bias=False) 502 | self.task_bn2 = nn.BatchNorm1d(task_outdim) 503 | self.joint_linear1 = nn.Linear(task_outdim, task_outdim, bias=False) 504 | self.joint_bn1 = nn.BatchNorm1d(task_outdim) 505 | 506 | num_out = [[1]] + [[ 507 | self._num_mods[i - 1] for _ in range(self._num_mods[i]) 508 | ] for i in range(1, len(self._num_mods))] 509 | count = 0 510 | out_ind = [] 511 | for i in range(len(num_out)): 512 | this_out_ind = [] 513 | for j in range(len(num_out[i])): 514 | this_out_ind.append([count, count + num_out[i][j]]) 515 | count += num_out[i][j] 516 | out_ind.append(this_out_ind) 517 | self.out_ind = out_ind 518 | self.out_count = count 519 | 520 | self.joint_linear2 = nn.Linear(task_outdim, count, bias=False) 521 | 522 | def apply_init(m): 523 | if isinstance(m, nn.Linear): 524 | nn.init.normal_(m.weight, 0, 0.1) 525 | if m.bias is not None: 526 | nn.init.constant_(m.bias, 0) 527 | elif isinstance(m, nn.BatchNorm1d): 528 | nn.init.constant_(m.weight, 1) 529 | nn.init.constant_(m.bias, 0) 530 | 531 | for m in self.modules(): 532 | if isinstance(m, nn.ModuleList): 533 | for subm in m: 534 | if isinstance(subm, nn.ModuleList): 535 | for subsubm in subm: 536 | apply_init(subsubm) 537 | else: 538 | apply_init(subm) 539 | else: 540 | apply_init(m) 541 | 542 | if not randinit: 543 | self.joint_linear2.weight.data.zero_() 544 | 545 | def forward(self, tdesc=None): 546 | """TODO: Docstring for function. 547 | 548 | :arg1: TODO 549 | :returns: TODO 550 | 551 | """ 552 | if tdesc is None: 553 | return None 554 | 555 | x = tdesc 556 | task_embeds1 = F.relu(self.task_bn1(self.task_linear1(x))) 557 | joint_embed = task_embeds1 558 | joint_embed = self.joint_linear2(joint_embed) 559 | joint_embed = [[ 560 | joint_embed[:, self.out_ind[i][j][0]:self.out_ind[i][j][1]] 561 | for j in range(len(self.out_ind[i])) 562 | ] for i in range(len(self.out_ind))] 563 | 564 | gating_wt = joint_embed 565 | prob_g = [[F.softmax(wt, -1) for wt in gating_wt[i]] 566 | for i in range(len(gating_wt))] 567 | return prob_g, gating_wt 568 | 569 | 570 | def modularize_network( 571 | model, 572 | stoch_sample=False, 573 | use_full_model=False, 574 | tdim=200, 575 | inter_tdim=200, 576 | gater_type='general', 577 | single_head=True, 578 | num_classes=[2], 579 | num_lookup_gating=10, 580 | ): 581 | # Copy start modules and end modules 582 | start_modules = model.start_modules 583 | end_modules = [ 584 | nn.Sequential(model.avgpool, Flatten(), fci) for fci in model.fc 585 | ] 586 | 587 | # Create module_list as list of lists [[layer1 modules], [layer2 modules], ...] 
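# Worked example (assuming GatedGeneralNN's defaults, num_layers=2 and num_modules_per_layer=3):
# each model.layer{i} attribute is a list of per-layer module lists, so extend() keeps that
# nesting and the loop below yields module_list = [[m11, m12, m13], [m21, m22, m23]].
# num_module_list then becomes [1, 3, 3, 1] (one start module, two gated layers of three
# modules each, and a single output head), which is exactly the num_mods/fan_in description
# passed to GeneralGatingNN further down.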
588 | module_list = [] 589 | li = 1 590 | while True: 591 | if hasattr(model, 'layer{}'.format(li)): 592 | module_list.extend(getattr(model, 'layer{}'.format(li))) 593 | li += 1 594 | else: 595 | break 596 | 597 | num_module_list = [len(start_modules)] + [len(layer) for layer in module_list] \ 598 | + [len(end_modules)] 599 | gated_model_func = GatedModularNet 600 | gated_net = gated_model_func( 601 | module_list, 602 | start_modules=start_modules, 603 | end_modules=end_modules, 604 | single_head=single_head) 605 | 606 | gater_func = GeneralGatingNN 607 | gater = gater_func( 608 | num_mods=num_module_list, tdim=tdim, inter_tdim=inter_tdim) 609 | 610 | fan_in = num_module_list 611 | if use_full_model: 612 | gater = None 613 | 614 | # Create Gating Sampler 615 | gating_sampler = GatingSampler(gater=gater, stoch_sample=stoch_sample) 616 | return gated_net, gating_sampler, num_module_list, fan_in 617 | 618 | 619 | def modular_general( 620 | num_layers, 621 | num_modules_per_layer, 622 | feat_dim, 623 | inter_dim, 624 | stoch_sample=False, 625 | use_full_model=False, 626 | num_classes=[2], 627 | single_head=True, 628 | gater_type='general', 629 | tdim=300, 630 | inter_tdim=300, 631 | num_lookup_gating=10, 632 | ): 633 | # First create a ResNext model 634 | model = GeneralNormalizedNN( 635 | num_layers, 636 | num_modules_per_layer, 637 | feat_dim, 638 | inter_dim, 639 | ) 640 | 641 | # Modularize the model and create gating funcs 642 | gated_net, gating_sampler, num_module_list, fan_in = modularize_network( 643 | model, 644 | stoch_sample, 645 | use_full_model=use_full_model, 646 | gater_type=gater_type, 647 | tdim=tdim, 648 | inter_tdim=inter_tdim, 649 | single_head=single_head, 650 | num_classes=num_classes, 651 | num_lookup_gating=num_lookup_gating, 652 | ) 653 | return gated_net, gating_sampler, num_module_list, fan_in -------------------------------------------------------------------------------- /models/svm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tqdm 3 | from data import dataset as dset 4 | import os 5 | from utils import utils 6 | import torch 7 | from torch.autograd import Variable 8 | import h5py 9 | from sklearn.svm import LinearSVC 10 | from sklearn.model_selection import GridSearchCV 11 | import torch.nn.functional as F 12 | from joblib import Parallel, delayed 13 | import glob 14 | import scipy.io 15 | from sklearn.calibration import CalibratedClassifierCV 16 | 17 | 18 | import argparse 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--batch_size', type=int, default=512) 21 | parser.add_argument('--dataset', default='mitstates', help='mitstates|zappos') 22 | parser.add_argument('--data_dir', default='data/mit-states/') 23 | parser.add_argument('--generate', action ='store_true', default=False) 24 | parser.add_argument('--evalsvm', action ='store_true', default=False) 25 | parser.add_argument('--evaltf', action ='store_true', default=False) 26 | parser.add_argument('--completed', default='tensor-completion/completed/complete.mat') 27 | args = parser.parse_args() 28 | 29 | #----------------------------------------------------------------------------------------# 30 | 31 | search_params = {'C': np.logspace(-5,5,11)} 32 | def train_svm(Y_train, sample_weight=None): 33 | 34 | try: 35 | clf = GridSearchCV(LinearSVC(class_weight='balanced', fit_intercept=False), search_params, scoring='f1', cv=4) 36 | clf.fit(X_train, Y_train, sample_weight=sample_weight) 37 | except: 38 | clf = 
LinearSVC(C=0.1, class_weight='balanced', fit_intercept=False) 39 | if Y_train.sum()==len(Y_train) or Y_train.sum()==0: 40 | return None 41 | clf.fit(X_train, Y_train, sample_weight=sample_weight) 42 | 43 | return clf 44 | 45 | def generate_svms(): 46 | # train an SVM for every attribute, object and pair primitive 47 | 48 | Y = [(train_attrs==attr).astype(np.int) for attr in range(len(dataset.attrs))] 49 | attr_clfs = Parallel(n_jobs=32, verbose=16)(delayed(train_svm)(Y[attr]) for attr in range(len(dataset.attrs))) 50 | for attr, clf in enumerate(attr_clfs): 51 | print (attr, dataset.attrs[attr]) 52 | print ('params:', clf.best_params_) 53 | print ('-'*30) 54 | torch.save(clf.best_estimator_, '%s/svm/attr_%d'%(args.data_dir, attr)) 55 | 56 | Y = [(train_objs==obj).astype(np.int) for obj in range(len(dataset.objs))] 57 | obj_clfs = Parallel(n_jobs=32, verbose=16)(delayed(train_svm)(Y[obj]) for obj in range(len(dataset.objs))) 58 | for obj, clf in enumerate(obj_clfs): 59 | print (obj, dataset.objs[obj]) 60 | print ('params:', clf.best_params_) 61 | print ('-'*30) 62 | torch.save(clf.best_estimator_, '%s/svm/obj_%d'%(args.data_dir, obj)) 63 | 64 | 65 | Y, Y_attr, sample_weight = [], [], [] 66 | for idx, (attr, obj) in enumerate(dataset.train_pairs): 67 | Y_train = ((train_attrs==attr)*(train_objs==obj)).astype(np.int) 68 | 69 | # reweight instances to get a little more training data 70 | Y_train_attr = (train_attrs==attr).astype(np.int) 71 | instance_weights = 0.1*Y_train_attr 72 | instance_weights[Y_train.nonzero()[0]] = 1.0 73 | 74 | Y.append(Y_train) 75 | Y_attr.append(Y_train_attr) 76 | sample_weight.append(instance_weights) 77 | 78 | pair_clfs = Parallel(n_jobs=32, verbose=16)(delayed(train_svm)(Y[pair]) for pair in range(len(dataset.train_pairs))) 79 | for idx, (attr, obj) in enumerate(dataset.train_pairs): 80 | clf = pair_clfs[idx] 81 | print (dataset.attrs[attr], dataset.objs[obj]) 82 | 83 | try: 84 | print ('params:', clf.best_params_) 85 | torch.save(clf.best_estimator_, '%s/svm/pair_%d_%d'%(args.data_dir, attr, obj)) 86 | except: 87 | print ('FAILED! 
#positive:', Y[idx].sum(), len(Y[idx])) 88 | 89 | return 90 | 91 | def make_svm_tensor(): 92 | subs, vals, size = [], [], (len(dataset.attrs), len(dataset.objs), X.shape[1]) 93 | fullsubs, fullvals = [], [] 94 | composite_clfs = glob.glob('%s/svm/pair*'%args.data_dir) 95 | print ('%d composite classifiers found'%(len(composite_clfs))) 96 | for clf in tqdm.tqdm(composite_clfs): 97 | _, attr, obj = os.path.basename(clf).split('_') 98 | attr, obj = int(attr), int(obj) 99 | 100 | clf = torch.load(clf) 101 | weight = clf.coef_.squeeze() 102 | for i in range(len(weight)): 103 | subs.append((attr, obj, i)) 104 | vals.append(weight[i]) 105 | 106 | for attr, obj in dataset.pairs: 107 | for i in range(X.shape[1]): 108 | fullsubs.append((attr, obj, i)) 109 | 110 | subs, vals = np.array(subs), np.array(vals).reshape(-1,1) 111 | fullsubs, fullvals = np.array(fullsubs), np.ones(len(fullsubs)).reshape(-1,1) 112 | savedat = {'subs':subs, 'vals':vals, 'size':size, 'fullsubs':fullsubs, 'fullvals':fullvals} 113 | scipy.io.savemat('tensor-completion/incomplete/%s.mat'%args.dataset, savedat) 114 | print (subs.shape, vals.shape, size) 115 | 116 | def evaluate_svms(): 117 | 118 | attr_clfs = [torch.load('%s/svm/attr_%d'%(args.data_dir, attr)) for attr in range(len(dataset.attrs))] 119 | obj_clfs = [torch.load('%s/svm/obj_%d'%(args.data_dir, obj)) for obj in range(len(dataset.objs))] 120 | 121 | # Calibrate all classifiers first 122 | Y = [(train_attrs==attr).astype(np.int) for attr in range(len(dataset.attrs))] 123 | for attr in tqdm.tqdm(range(len(dataset.attrs))): 124 | clf = attr_clfs[attr] 125 | calibrated = CalibratedClassifierCV(clf, method='sigmoid', cv='prefit') 126 | calibrated.fit(X_train, Y[attr]) 127 | attr_clfs[attr] = calibrated 128 | 129 | Y = [(train_objs==obj).astype(np.int) for obj in range(len(dataset.objs))] 130 | for obj in tqdm.tqdm(range(len(dataset.objs))): 131 | clf = obj_clfs[obj] 132 | calibrated = CalibratedClassifierCV(clf, method='sigmoid', cv='prefit') 133 | calibrated.fit(X_train, Y[obj]) 134 | obj_clfs[obj] = calibrated 135 | 136 | # Generate all the scores 137 | attr_scores, obj_scores = [], [] 138 | for attr in tqdm.tqdm(range(len(dataset.attrs))): 139 | clf = attr_clfs[attr] 140 | score = clf.predict_proba(X_test)[:,1] 141 | attr_scores.append(score) 142 | attr_scores = np.vstack(attr_scores) 143 | 144 | for obj in tqdm.tqdm(range(len(dataset.objs))): 145 | clf = obj_clfs[obj] 146 | score = clf.predict_proba(X_test)[:,1] 147 | obj_scores.append(score) 148 | obj_scores = np.vstack(obj_scores) 149 | 150 | attr_pred = torch.from_numpy(attr_scores).transpose(0,1) 151 | obj_pred = torch.from_numpy(obj_scores).transpose(0,1) 152 | 153 | x = [None, Variable(torch.from_numpy(test_attrs)).long(), Variable(torch.from_numpy(test_objs)).long(), Variable(torch.from_numpy(test_pairs)).long()] 154 | attr_pred, obj_pred, _ = utils.generate_prediction_tensors([attr_pred, obj_pred], dataset, x[2].data, source='classification') 155 | attr_match, obj_match, zsl_match, gzsl_match, fixobj_match = utils.performance_stats(attr_pred, obj_pred, x) 156 | print (attr_match.mean(), obj_match.mean(), zsl_match.mean(), gzsl_match.mean(), fixobj_match.mean()) 157 | 158 | def evaluate_tensorcompletion(): 159 | 160 | def parse_tensor(fl): 161 | tensor = scipy.io.loadmat(fl) 162 | nz_idx = zip(*(tensor['subs'])) 163 | composite_clfs = np.zeros((len(dataset.attrs), len(dataset.objs), X.shape[1])) 164 | composite_clfs[nz_idx[0], nz_idx[1], nz_idx[2]] = tensor['vals'].squeeze() 165 | return composite_clfs, nz_idx, 
tensor['vals'].squeeze() 166 | 167 | # see recon error 168 | tr_file = 'tensor-completion/incomplete/%s.mat'%args.dataset 169 | ts_file = args.completed 170 | 171 | tr_clfs, tr_nz_idx, tr_vals = parse_tensor(tr_file) 172 | ts_clfs, ts_nz_idx, ts_vals = parse_tensor(ts_file) 173 | 174 | print (tr_vals.min(), tr_vals.max(), tr_vals.mean()) 175 | print (ts_vals.min(), ts_vals.max(), ts_vals.mean()) 176 | 177 | print ('Completed Tensor: %s'%args.completed) 178 | 179 | # see train recon error 180 | err = 1.0*((tr_clfs[tr_nz_idx[0], tr_nz_idx[1], tr_nz_idx[2]]-ts_clfs[tr_nz_idx[0], tr_nz_idx[1], tr_nz_idx[2]])**2).sum()/(len(tr_vals)) 181 | print ('recon error:', err) 182 | 183 | # Create and scale classifiers for each pair 184 | clfs = {} 185 | test_pair_set = set(map(tuple, dataset.test_pairs.numpy().tolist())) 186 | for idx, (attr, obj) in tqdm.tqdm(enumerate(dataset.pairs), total=len(dataset.pairs)): 187 | clf = LinearSVC(fit_intercept=False) 188 | clf.fit(np.eye(2), [0,1]) 189 | 190 | if (attr, obj) in test_pair_set: 191 | X_ = X_test 192 | Y_ = (test_attrs==attr).astype(np.int)*(test_objs==obj).astype(np.int) 193 | clf.coef_ = ts_clfs[attr, obj][None,:] 194 | else: 195 | X_ = X_train 196 | Y_ = (train_attrs==attr).astype(np.int)*(train_objs==obj).astype(np.int) 197 | clf.coef_ = tr_clfs[attr, obj][None,:] 198 | 199 | calibrated = CalibratedClassifierCV(clf, method='sigmoid', cv='prefit') 200 | calibrated.fit(X_, Y_) 201 | clfs[(attr, obj)] = calibrated 202 | 203 | scores = {} 204 | for attr, obj in tqdm.tqdm(dataset.pairs): 205 | score = clfs[(attr, obj)].predict_proba(X_test)[:,1] 206 | scores[(attr, obj)] = torch.from_numpy(score).float().unsqueeze(1) 207 | 208 | x = [None, Variable(torch.from_numpy(test_attrs)).long(), Variable(torch.from_numpy(test_objs)).long(), Variable(torch.from_numpy(test_pairs)).long()] 209 | attr_pred, obj_pred, _ = utils.generate_prediction_tensors(scores, dataset, x[2].data, source='manifold') 210 | attr_match, obj_match, zsl_match, gzsl_match, fixobj_match = utils.performance_stats(attr_pred, obj_pred, x) 211 | print (attr_match.mean(), obj_match.mean(), zsl_match.mean(), gzsl_match.mean(), fixobj_match.mean()) 212 | 213 | 214 | 215 | #----------------------------------------------------------------------------------------# 216 | if args.dataset == 'mitstates': 217 | DSet = dset.MITStatesActivations 218 | elif args.dataset == 'zappos': 219 | DSet = dset.UTZapposActivations 220 | 221 | dataset = DSet(root=args.data_dir, phase='train') 222 | train_idx, train_attrs, train_objs, train_pairs = map(np.array, zip(*dataset.train_data)) 223 | test_idx, test_attrs, test_objs, test_pairs = map(np.array, zip(*dataset.test_data)) 224 | 225 | X = dataset.activations.numpy() 226 | X_train, X_test = X[train_idx,:], X[test_idx,:] 227 | 228 | print (len(dataset.attrs), len(dataset.objs), len(dataset.pairs)) 229 | print (X_train.shape, X_test.shape) 230 | 231 | if args.generate: 232 | generate_svms() 233 | make_svm_tensor() 234 | 235 | if args.evalsvm: 236 | evaluate_svms() 237 | 238 | if args.evaltf: 239 | evaluate_tensorcompletion() 240 | -------------------------------------------------------------------------------- /models/symnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from .word_embedding import load_word_embeddings 6 | from .common import MLP 7 | 8 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 9 | 10 | 11 | 12 
| class Symnet(nn.Module): 13 | def __init__(self, dset, args): 14 | super(Symnet, self).__init__() 15 | self.dset = dset 16 | self.args = args 17 | self.num_attrs = len(dset.attrs) 18 | self.num_objs = len(dset.objs) 19 | 20 | self.image_embedder = MLP(dset.feat_dim, args.emb_dim, relu = False) 21 | self.ce_loss = nn.CrossEntropyLoss() 22 | self.mse_loss = nn.MSELoss() 23 | self.nll_loss = nn.NLLLoss() 24 | self.softmax = nn.Softmax(dim=1) 25 | # Attribute embedder and object classifier 26 | self.attr_embedder = nn.Embedding(len(dset.attrs), args.emb_dim) 27 | 28 | self.obj_classifier = MLP(args.emb_dim, len(dset.objs), num_layers = 2, relu = False, dropout = True, layers = [512]) # original paper uses a second fc classifier for final product 29 | self.attr_classifier = MLP(args.emb_dim, len(dset.attrs), num_layers = 2, relu = False, dropout = True, layers = [512]) 30 | # CoN and DecoN 31 | self.CoN_fc_attention = MLP(args.emb_dim, args.emb_dim, num_layers = 2, relu = False, dropout = True, layers = [512]) 32 | self.CoN_emb = MLP(args.emb_dim + args.emb_dim, args.emb_dim, num_layers = 2, relu = False, dropout = True, layers = [768]) 33 | self.DecoN_fc_attention = MLP(args.emb_dim, args.emb_dim, num_layers = 2, relu = False, dropout = True, layers = [512]) 34 | self.DecoN_emb = MLP(args.emb_dim + args.emb_dim, args.emb_dim, num_layers = 2, relu = False, dropout = True, layers = [768]) 35 | 36 | # if args.glove_init: 37 | pretrained_weight = load_word_embeddings(args.emb_init, dset.attrs) 38 | self.attr_embedder.weight.data.copy_(pretrained_weight) 39 | 40 | for param in self.attr_embedder.parameters(): 41 | param.requires_grad = False 42 | 43 | def CoN(self, img_embedding, attr_embedding): 44 | attention = torch.sigmoid(self.CoN_fc_attention(attr_embedding)) 45 | img_embedding = attention*img_embedding + img_embedding 46 | 47 | hidden = torch.cat([img_embedding, attr_embedding], dim = 1) 48 | output = self.CoN_emb(hidden) 49 | return output 50 | 51 | def DeCoN(self, img_embedding, attr_embedding): 52 | attention = torch.sigmoid(self.DecoN_fc_attention(attr_embedding)) 53 | img_embedding = attention*img_embedding + img_embedding 54 | 55 | hidden = torch.cat([img_embedding, attr_embedding], dim = 1) 56 | output = self.DecoN_emb(hidden) 57 | return output 58 | 59 | 60 | def distance_metric(self, a, b): 61 | return torch.norm(a-b, dim = -1) 62 | 63 | def RMD_prob(self, feat_plus, feat_minus, repeat_img_feat): 64 | """return attribute classification probability with our RMD""" 65 | # feat_plus, feat_minus: shape=(bz, #attr, dim_emb) 66 | # d_plus: distance between feature before&after CoN 67 | # d_minus: distance between feature before&after DecoN 68 | d_plus = self.distance_metric(feat_plus, repeat_img_feat).reshape(-1, self.num_attrs) 69 | d_minus = self.distance_metric(feat_minus, repeat_img_feat).reshape(-1, self.num_attrs) 70 | # not adding softmax because it is part of cross entropy loss 71 | return d_minus - d_plus 72 | 73 | def train_forward(self, x): 74 | pos_image_feat, pos_attr_id, pos_obj_id = x[0], x[1], x[2] 75 | neg_attr_id = x[4][:,0] 76 | 77 | batch_size = pos_image_feat.size(0) 78 | 79 | loss = [] 80 | pos_attr_emb = self.attr_embedder(pos_attr_id) 81 | neg_attr_emb = self.attr_embedder(neg_attr_id) 82 | 83 | pos_img = self.image_embedder(pos_image_feat) 84 | 85 | # rA = remove positive attribute A 86 | # aA = add positive attribute A 87 | # rB = remove negative attribute B 88 | # aB = add negative attribute B 89 | pos_rA = self.DeCoN(pos_img, pos_attr_emb) 90 | pos_aA = 
self.CoN(pos_img, pos_attr_emb) 91 | pos_rB = self.DeCoN(pos_img, neg_attr_emb) 92 | pos_aB = self.CoN(pos_img, neg_attr_emb) 93 | 94 | # get all attr embedding #attr, embedding 95 | attr_emb = torch.LongTensor(np.arange(self.num_attrs)).to(device) 96 | attr_emb = self.attr_embedder(attr_emb) 97 | tile_attr_emb = attr_emb.repeat(batch_size, 1) # (batch*attr, dim_emb) 98 | 99 | # Now we calculate all the losses 100 | if self.args.lambda_cls_attr > 0: 101 | # Original image 102 | score_pos_A = self.attr_classifier(pos_img) 103 | loss_cls_pos_a = self.ce_loss(score_pos_A, pos_attr_id) 104 | 105 | # After removing pos_attr 106 | score_pos_rA_A = self.attr_classifier(pos_rA) 107 | total = sum(score_pos_rA_A) 108 | # prob_pos_rA_A = 1 - self.softmax(score_pos_rA_A) 109 | # loss_cls_pos_rA_a = self.nll_loss(torch.log(prob_pos_rA_A), pos_attr_id) 110 | loss_cls_pos_rA_a = self.ce_loss(total - score_pos_rA_A, pos_attr_id) #should be maximum for the gt label 111 | 112 | # rmd time 113 | repeat_img_feat = torch.repeat_interleave(pos_img, self.num_attrs, 0) #(batch*attr, dim_rep) 114 | feat_plus = self.CoN(repeat_img_feat, tile_attr_emb) 115 | feat_minus = self.DeCoN(repeat_img_feat, tile_attr_emb) 116 | score_cls_rmd = self.RMD_prob(feat_plus, feat_minus, repeat_img_feat) 117 | loss_cls_rmd = self.ce_loss(score_cls_rmd, pos_attr_id) 118 | 119 | loss_cls_attr = self.args.lambda_cls_attr*sum([loss_cls_pos_a, loss_cls_pos_rA_a, loss_cls_rmd]) 120 | loss.append(loss_cls_attr) 121 | 122 | if self.args.lambda_cls_obj > 0: 123 | # Original image 124 | score_pos_O = self.obj_classifier(pos_img) 125 | loss_cls_pos_o = self.ce_loss(score_pos_O, pos_obj_id) 126 | 127 | # After removing pos attr 128 | score_pos_rA_O = self.obj_classifier(pos_rA) 129 | loss_cls_pos_rA_o = self.ce_loss(score_pos_rA_O, pos_obj_id) 130 | 131 | # After adding neg attr 132 | score_pos_aB_O = self.obj_classifier(pos_aB) 133 | loss_cls_pos_aB_o = self.ce_loss(score_pos_aB_O, pos_obj_id) 134 | 135 | loss_cls_obj = self.args.lambda_cls_obj * sum([loss_cls_pos_o, loss_cls_pos_rA_o, loss_cls_pos_aB_o]) 136 | 137 | loss.append(loss_cls_obj) 138 | 139 | if self.args.lambda_sym > 0: 140 | 141 | loss_sys_pos = self.mse_loss(pos_aA, pos_img) 142 | loss_sys_neg = self.mse_loss(pos_rB, pos_img) 143 | loss_sym = self.args.lambda_sym * (loss_sys_pos + loss_sys_neg) 144 | 145 | loss.append(loss_sym) 146 | 147 | ##### Axiom losses 148 | if self.args.lambda_axiom > 0: 149 | loss_clo = loss_inv = loss_com = 0 150 | # closure 151 | pos_aA_rA = self.DeCoN(pos_aA, pos_attr_emb) 152 | pos_rB_aB = self.CoN(pos_rB, neg_attr_emb) 153 | loss_clo = self.mse_loss(pos_aA_rA, pos_rA) + self.mse_loss(pos_rB_aB, pos_aB) 154 | 155 | # invertibility 156 | pos_rA_aA = self.CoN(pos_rA, pos_attr_emb) 157 | pos_aB_rB = self.DeCoN(pos_aB, neg_attr_emb) 158 | loss_inv = self.mse_loss(pos_rA_aA, pos_img) + self.mse_loss(pos_aB_rB, pos_img) 159 | 160 | # commutative 161 | pos_aA_rB = self.DeCoN(pos_aA, neg_attr_emb) 162 | pos_rB_aA = self.DeCoN(pos_rB, pos_attr_emb) 163 | loss_com = self.mse_loss(pos_aA_rB, pos_rB_aA) 164 | 165 | loss_axiom = self.args.lambda_axiom * (loss_clo + loss_inv + loss_com) 166 | loss.append(loss_axiom) 167 | 168 | # triplet loss 169 | if self.args.lambda_trip > 0: 170 | pos_triplet = F.triplet_margin_loss(pos_img, pos_aA, pos_rA) 171 | neg_triplet = F.triplet_margin_loss(pos_img, pos_rB, pos_aB) 172 | 173 | loss_triplet = self.args.lambda_trip * (pos_triplet + neg_triplet) 174 | loss.append(loss_triplet) 175 | 176 | 177 | loss = sum(loss) 178 | 
return loss, None 179 | 180 | def val_forward(self, x): 181 | pos_image_feat, pos_attr_id, pos_obj_id = x[0], x[1], x[2] 182 | batch_size = pos_image_feat.shape[0] 183 | pos_img = self.image_embedder(pos_image_feat) 184 | repeat_img_feat = torch.repeat_interleave(pos_img, self.num_attrs, 0) #(batch*attr, dim_rep) 185 | 186 | # get all attr embedding #attr, embedding 187 | attr_emb = torch.LongTensor(np.arange(self.num_attrs)).to(device) 188 | attr_emb = self.attr_embedder(attr_emb) 189 | tile_attr_emb = attr_emb.repeat(batch_size, 1) # (batch*attr, dim_emb) 190 | 191 | feat_plus = self.CoN(repeat_img_feat, tile_attr_emb) 192 | feat_minus = self.DeCoN(repeat_img_feat, tile_attr_emb) 193 | score_cls_rmd = self.RMD_prob(feat_plus, feat_minus, repeat_img_feat) 194 | prob_A_rmd = F.softmax(score_cls_rmd, dim = 1) 195 | 196 | score_obj = self.obj_classifier(pos_img) 197 | prob_O = F.softmax(score_obj, dim = 1) 198 | 199 | scores = {} 200 | for itr, (attr, obj) in enumerate(self.dset.pairs): 201 | attr_id, obj_id = self.dset.attr2idx[attr], self.dset.obj2idx[obj] 202 | score = prob_A_rmd[:,attr_id] * prob_O[:, obj_id] 203 | 204 | scores[(attr, obj)] = score 205 | 206 | return None, scores 207 | 208 | def forward(self, x): 209 | if self.training: 210 | loss, pred = self.train_forward(x) 211 | else: 212 | with torch.no_grad(): 213 | loss, pred = self.val_forward(x) 214 | return loss, pred -------------------------------------------------------------------------------- /models/visual_product.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision.models as models 5 | import numpy as np 6 | from .common import MLP 7 | 8 | class VisualProductNN(nn.Module): 9 | def __init__(self, dset, args): 10 | super(VisualProductNN, self).__init__() 11 | self.attr_clf = MLP(dset.feat_dim, len(dset.attrs), 2, relu = False) 12 | self.obj_clf = MLP(dset.feat_dim, len(dset.objs), 2, relu = False) 13 | self.dset = dset 14 | 15 | def train_forward(self, x): 16 | img, attrs, objs = x[0],x[1], x[2] 17 | 18 | attr_pred = self.attr_clf(img) 19 | obj_pred = self.obj_clf(img) 20 | 21 | attr_loss = F.cross_entropy(attr_pred, attrs) 22 | obj_loss = F.cross_entropy(obj_pred, objs) 23 | 24 | loss = attr_loss + obj_loss 25 | 26 | return loss, None 27 | 28 | def val_forward(self, x): 29 | img = x[0] 30 | attr_pred = F.softmax(self.attr_clf(img), dim =1) 31 | obj_pred = F.softmax(self.obj_clf(img), dim = 1) 32 | 33 | scores = {} 34 | for itr, (attr, obj) in enumerate(self.dset.pairs): 35 | attr_id, obj_id = self.dset.attr2idx[attr], self.dset.obj2idx[obj] 36 | score = attr_pred[:,attr_id] * obj_pred[:, obj_id] 37 | 38 | scores[(attr, obj)] = score 39 | return None, scores 40 | 41 | def forward(self, x): 42 | if self.training: 43 | loss, pred = self.train_forward(x) 44 | else: 45 | with torch.no_grad(): 46 | loss, pred = self.val_forward(x) 47 | 48 | return loss, pred -------------------------------------------------------------------------------- /models/word_embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from flags import DATA_FOLDER 4 | def load_word_embeddings(emb_type, vocab): 5 | if emb_type == 'glove': 6 | embeds = load_glove_embeddings(vocab) 7 | elif emb_type == 'fasttext': 8 | embeds = load_fasttext_embeddings(vocab) 9 | elif emb_type == 'word2vec': 10 | embeds = load_word2vec_embeddings(vocab) 11 | elif 
emb_type == 'ft+w2v': 12 | embeds1 = load_fasttext_embeddings(vocab) 13 | embeds2 = load_word2vec_embeddings(vocab) 14 | embeds = torch.cat([embeds1, embeds2], dim = 1) 15 | print('Combined embeddings are ',embeds.shape) 16 | elif emb_type == 'ft+gl': 17 | embeds1 = load_fasttext_embeddings(vocab) 18 | embeds2 = load_glove_embeddings(vocab) 19 | embeds = torch.cat([embeds1, embeds2], dim = 1) 20 | print('Combined embeddings are ',embeds.shape) 21 | elif emb_type == 'ft+ft': 22 | embeds1 = load_fasttext_embeddings(vocab) 23 | embeds = torch.cat([embeds1, embeds1], dim = 1) 24 | print('Combined embeddings are ',embeds.shape) 25 | elif emb_type == 'gl+w2v': 26 | embeds1 = load_glove_embeddings(vocab) 27 | embeds2 = load_word2vec_embeddings(vocab) 28 | embeds = torch.cat([embeds1, embeds2], dim = 1) 29 | print('Combined embeddings are ',embeds.shape) 30 | elif emb_type == 'ft+w2v+gl': 31 | embeds1 = load_fasttext_embeddings(vocab) 32 | embeds2 = load_word2vec_embeddings(vocab) 33 | embeds3 = load_glove_embeddings(vocab) 34 | embeds = torch.cat([embeds1, embeds2, embeds3], dim = 1) 35 | print('Combined embeddings are ',embeds.shape) 36 | else: 37 | raise ValueError('Invalid embedding') 38 | return embeds 39 | 40 | def load_fasttext_embeddings(vocab): 41 | custom_map = { 42 | 'Faux.Fur': 'fake fur', 43 | 'Faux.Leather': 'fake leather', 44 | 'Full.grain.leather': 'thick leather', 45 | 'Hair.Calf': 'hairy leather', 46 | 'Patent.Leather': 'shiny leather', 47 | 'Boots.Ankle': 'ankle boots', 48 | 'Boots.Knee.High': 'kneehigh boots', 49 | 'Boots.Mid-Calf': 'midcalf boots', 50 | 'Shoes.Boat.Shoes': 'boatshoes', 51 | 'Shoes.Clogs.and.Mules': 'clogs shoes', 52 | 'Shoes.Flats': 'flats shoes', 53 | 'Shoes.Heels': 'heels', 54 | 'Shoes.Loafers': 'loafers', 55 | 'Shoes.Oxfords': 'oxford shoes', 56 | 'Shoes.Sneakers.and.Athletic.Shoes': 'sneakers', 57 | 'traffic_light': 'traficlight', 58 | 'trash_can': 'trashcan', 59 | 'dry-erase_board' : 'dry_erase_board', 60 | 'black_and_white' : 'black_white', 61 | 'eiffel_tower' : 'tower' 62 | } 63 | vocab_lower = [v.lower() for v in vocab] 64 | vocab = [] 65 | for current in vocab_lower: 66 | if current in custom_map: 67 | vocab.append(custom_map[current]) 68 | else: 69 | vocab.append(current) 70 | 71 | import fasttext.util 72 | ft = fasttext.load_model(DATA_FOLDER+'/fast/cc.en.300.bin') 73 | embeds = [] 74 | for k in vocab: 75 | if '_' in k: 76 | ks = k.split('_') 77 | emb = np.stack([ft.get_word_vector(it) for it in ks]).mean(axis=0) 78 | else: 79 | emb = ft.get_word_vector(k) 80 | embeds.append(emb) 81 | 82 | embeds = torch.Tensor(np.stack(embeds)) 83 | print('Fasttext Embeddings loaded, total embeddings: {}'.format(embeds.size())) 84 | return embeds 85 | 86 | def load_word2vec_embeddings(vocab): 87 | # vocab = [v.lower() for v in vocab] 88 | 89 | from gensim import models 90 | model = models.KeyedVectors.load_word2vec_format( 91 | DATA_FOLDER+'/w2v/GoogleNews-vectors-negative300.bin', binary=True) 92 | 93 | custom_map = { 94 | 'Faux.Fur': 'fake_fur', 95 | 'Faux.Leather': 'fake_leather', 96 | 'Full.grain.leather': 'thick_leather', 97 | 'Hair.Calf': 'hair_leather', 98 | 'Patent.Leather': 'shiny_leather', 99 | 'Boots.Ankle': 'ankle_boots', 100 | 'Boots.Knee.High': 'knee_high_boots', 101 | 'Boots.Mid-Calf': 'midcalf_boots', 102 | 'Shoes.Boat.Shoes': 'boat_shoes', 103 | 'Shoes.Clogs.and.Mules': 'clogs_shoes', 104 | 'Shoes.Flats': 'flats_shoes', 105 | 'Shoes.Heels': 'heels', 106 | 'Shoes.Loafers': 'loafers', 107 | 'Shoes.Oxfords': 'oxford_shoes', 108 | 
'Shoes.Sneakers.and.Athletic.Shoes': 'sneakers', 109 | 'traffic_light': 'traffic_light', 110 | 'trash_can': 'trashcan', 111 | 'dry-erase_board' : 'dry_erase_board', 112 | 'black_and_white' : 'black_white', 113 | 'eiffel_tower' : 'tower' 114 | } 115 | 116 | embeds = [] 117 | for k in vocab: 118 | if k in custom_map: 119 | k = custom_map[k] 120 | if '_' in k and k not in model: 121 | ks = k.split('_') 122 | emb = np.stack([model[it] for it in ks]).mean(axis=0) 123 | else: 124 | emb = model[k] 125 | embeds.append(emb) 126 | embeds = torch.Tensor(np.stack(embeds)) 127 | print('Word2Vec Embeddings loaded, total embeddings: {}'.format(embeds.size())) 128 | return embeds 129 | 130 | def load_glove_embeddings(vocab): 131 | ''' 132 | Inputs 133 | emb_file: Text file with word embedding pairs e.g. Glove, Processed in lower case. 134 | vocab: List of words 135 | Returns 136 | Embedding Matrix 137 | ''' 138 | vocab = [v.lower() for v in vocab] 139 | emb_file = DATA_FOLDER+'/glove/glove.6B.300d.txt' 140 | model = {} # populating a dictionary of word and embeddings 141 | for line in open(emb_file, 'r'): 142 | line = line.strip().split(' ') # Word-embedding 143 | wvec = torch.FloatTensor(list(map(float, line[1:]))) 144 | model[line[0]] = wvec 145 | 146 | # Adding some vectors for UT Zappos 147 | custom_map = { 148 | 'faux.fur': 'fake_fur', 149 | 'faux.leather': 'fake_leather', 150 | 'full.grain.leather': 'thick_leather', 151 | 'hair.calf': 'hair_leather', 152 | 'patent.leather': 'shiny_leather', 153 | 'boots.ankle': 'ankle_boots', 154 | 'boots.knee.high': 'knee_high_boots', 155 | 'boots.mid-calf': 'midcalf_boots', 156 | 'shoes.boat.shoes': 'boat_shoes', 157 | 'shoes.clogs.and.mules': 'clogs_shoes', 158 | 'shoes.flats': 'flats_shoes', 159 | 'shoes.heels': 'heels', 160 | 'shoes.loafers': 'loafers', 161 | 'shoes.oxfords': 'oxford_shoes', 162 | 'shoes.sneakers.and.athletic.shoes': 'sneakers', 163 | 'traffic_light': 'traffic_light', 164 | 'trash_can': 'trashcan', 165 | 'dry-erase_board' : 'dry_erase_board', 166 | 'black_and_white' : 'black_white', 167 | 'eiffel_tower' : 'tower', 168 | 'nubuck' : 'grainy_leather', 169 | } 170 | 171 | embeds = [] 172 | for k in vocab: 173 | if k in custom_map: 174 | k = custom_map[k] 175 | if '_' in k: 176 | ks = k.split('_') 177 | emb = torch.stack([model[it] for it in ks]).mean(dim=0) 178 | else: 179 | emb = model[k] 180 | embeds.append(emb) 181 | embeds = torch.stack(embeds) 182 | print('Glove Embeddings loaded, total embeddings: {}'.format(embeds.size())) 183 | return embeds -------------------------------------------------------------------------------- /notebooks/analysis.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import sys\n", 10 | "sys.path.append('../')\n", 11 | "# Torch imports\n", 12 | "import torch\n", 13 | "import torch.nn as nn\n", 14 | "import torch.optim as optim\n", 15 | "import torch.nn.functional as F\n", 16 | "from torch.utils.tensorboard import SummaryWriter\n", 17 | "import torch.backends.cudnn as cudnn\n", 18 | "cudnn.benchmark = True\n", 19 | "\n", 20 | "# Python imports\n", 21 | "import numpy as np\n", 22 | "import tqdm\n", 23 | "import torchvision.models as tmodels\n", 24 | "from tqdm import tqdm\n", 25 | "import os\n", 26 | "from os.path import join as ospj\n", 27 | "import itertools\n", 28 | "import glob\n", 29 | "import random\n", 30 | "\n", 31 | "#Local imports\n", 
32 | "from data import dataset as dset\n", 33 | "from models.common import Evaluator\n", 34 | "from models.image_extractor import get_image_extractor\n", 35 | "from models.manifold_methods import RedWine, LabelEmbedPlus, AttributeOperator\n", 36 | "from models.modular_methods import GatedGeneralNN\n", 37 | "from models.symnet import Symnet\n", 38 | "from utils.utils import save_args, UnNormalizer, load_args\n", 39 | "from utils.config_model import configure_model\n", 40 | "from flags import parser\n", 41 | "from PIL import Image\n", 42 | "import matplotlib.pyplot as plt\n", 43 | "import importlib\n", 44 | "\n", 45 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 46 | "args, unknown = parser.parse_known_args()" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "metadata": {}, 52 | "source": [ 53 | "### Run one of the cells to load the dataset you want to run test for and move to the next section" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": null, 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "best_mit = '../logs/graphembed/mitstates/base/mit.yml'\n", 63 | "load_args(best_mit,args)\n", 64 | "args.graph_init = '../'+args.graph_init\n", 65 | "args.load = best_mit[:-7] + 'ckpt_best_auc.t7'" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "best_ut = '../logs/graphembed/utzappos/base/utzappos.yml'\n", 75 | "load_args(best_ut,args)\n", 76 | "args.graph_init = '../'+args.graph_init\n", 77 | "args.load = best_ut[:-12] + 'ckpt_best_auc.t7'" 78 | ] 79 | }, 80 | { 81 | "cell_type": "markdown", 82 | "metadata": {}, 83 | "source": [ 84 | "### Loading arguments and dataset" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "args.data_dir = '../'+args.data_dir\n", 94 | "args.test_set = 'test'\n", 95 | "testset = dset.CompositionDataset(\n", 96 | " root= args.data_dir,\n", 97 | " phase=args.test_set,\n", 98 | " split=args.splitname,\n", 99 | " model =args.image_extractor,\n", 100 | " subset=args.subset,\n", 101 | " return_images = True,\n", 102 | " update_features = args.update_features,\n", 103 | " clean_only = args.clean_only\n", 104 | " )\n", 105 | "testloader = torch.utils.data.DataLoader(\n", 106 | " testset,\n", 107 | " batch_size=args.test_batch_size,\n", 108 | " shuffle=True,\n", 109 | " num_workers=args.workers)\n", 110 | "\n", 111 | "print('Objs ', len(testset.objs), ' Attrs ', len(testset.attrs))" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "image_extractor, model, optimizer = configure_model(args, testset)\n", 121 | "evaluator = Evaluator(testset, model)" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": null, 127 | "metadata": {}, 128 | "outputs": [], 129 | "source": [ 130 | "if args.load is not None:\n", 131 | " checkpoint = torch.load(args.load)\n", 132 | " if image_extractor:\n", 133 | " try:\n", 134 | " image_extractor.load_state_dict(checkpoint['image_extractor'])\n", 135 | " image_extractor.eval()\n", 136 | " except:\n", 137 | " print('No Image extractor in checkpoint')\n", 138 | " model.load_state_dict(checkpoint['net'])\n", 139 | " model.eval()\n", 140 | " print('Loaded model from ', args.load)\n", 141 | " print('Best AUC: ', checkpoint['AUC'])" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | 
"execution_count": null, 147 | "metadata": {}, 148 | "outputs": [], 149 | "source": [ 150 | "def print_results(scores, exp):\n", 151 | " print(exp)\n", 152 | " result = scores[exp]\n", 153 | " attr = [evaluator.dset.attrs[result[0][idx,a]] for a in range(topk)]\n", 154 | " obj = [evaluator.dset.objs[result[1][idx,a]] for a in range(topk)]\n", 155 | " attr_gt, obj_gt = evaluator.dset.attrs[data[1][idx]], evaluator.dset.objs[data[2][idx]]\n", 156 | " print(f'Ground truth: {attr_gt} {obj_gt}')\n", 157 | " prediction = ''\n", 158 | " for a,o in zip(attr, obj):\n", 159 | " prediction += a + ' ' + o + '| '\n", 160 | " print('Predictions: ', prediction)\n", 161 | " print('__'*50)" 162 | ] 163 | }, 164 | { 165 | "cell_type": "markdown", 166 | "metadata": {}, 167 | "source": [ 168 | "### An example of predictions\n", 169 | "closed -> Biased for unseen classes\n", 170 | "\n", 171 | "unbiiased -> Biased against unseen classes" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": null, 177 | "metadata": {}, 178 | "outputs": [], 179 | "source": [ 180 | "data = next(iter(testloader))\n", 181 | "images = data[-1]\n", 182 | "data = [d.to(device) for d in data[:-1]]\n", 183 | "if image_extractor:\n", 184 | " data[0] = image_extractor(data[0])\n", 185 | "_, predictions = model(data)\n", 186 | "data = [d.to('cpu') for d in data]\n", 187 | "topk = 5\n", 188 | "results = evaluator.score_model(predictions, data[2], bias = 1000, topk=topk)" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": null, 194 | "metadata": { 195 | "scrolled": true 196 | }, 197 | "outputs": [], 198 | "source": [ 199 | "for idx in range(len(images)):\n", 200 | " seen = bool(evaluator.seen_mask[data[3][idx]])\n", 201 | " if seen:\n", 202 | " continue\n", 203 | " image = Image.open(ospj( args.data_dir,'images', images[idx]))\n", 204 | " \n", 205 | " plt.figure(dpi=300)\n", 206 | " plt.imshow(image)\n", 207 | " plt.axis('off')\n", 208 | " plt.show()\n", 209 | " \n", 210 | " print(f'GT pair seen: {seen}')\n", 211 | " print_results(results, 'closed')\n", 212 | " print_results(results, 'unbiased_closed')" 213 | ] 214 | }, 215 | { 216 | "cell_type": "markdown", 217 | "metadata": {}, 218 | "source": [ 219 | "### Run Evaluation" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": null, 225 | "metadata": {}, 226 | "outputs": [], 227 | "source": [ 228 | "model.eval()\n", 229 | "args.bias = 1e3\n", 230 | "accuracies, all_attr_gt, all_obj_gt, all_pair_gt, all_pred = [], [], [], [], []\n", 231 | "\n", 232 | "for idx, data in tqdm(enumerate(testloader), total=len(testloader), desc = 'Testing'):\n", 233 | " data.pop()\n", 234 | " data = [d.to(device) for d in data]\n", 235 | " if image_extractor:\n", 236 | " data[0] = image_extractor(data[0])\n", 237 | "\n", 238 | " _, predictions = model(data) # todo: Unify outputs across models\n", 239 | "\n", 240 | " attr_truth, obj_truth, pair_truth = data[1], data[2], data[3]\n", 241 | " all_pred.append(predictions)\n", 242 | " all_attr_gt.append(attr_truth)\n", 243 | " all_obj_gt.append(obj_truth)\n", 244 | " all_pair_gt.append(pair_truth)\n", 245 | "\n", 246 | "all_attr_gt, all_obj_gt, all_pair_gt = torch.cat(all_attr_gt), torch.cat(all_obj_gt), torch.cat(all_pair_gt)\n", 247 | "\n", 248 | "all_pred_dict = {}\n", 249 | "# Gather values as dict of (attr, obj) as key and list of predictions as values\n", 250 | "for k in all_pred[0].keys():\n", 251 | " all_pred_dict[k] = torch.cat(\n", 252 | " [all_pred[i][k] for i in range(len(all_pred))])\n", 
253 | "\n", 254 | "# Calculate best unseen accuracy\n", 255 | "attr_truth, obj_truth = all_attr_gt.to('cpu'), all_obj_gt.to('cpu')\n", 256 | "pairs = list(\n", 257 | " zip(list(attr_truth.numpy()), list(obj_truth.numpy())))" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": null, 263 | "metadata": {}, 264 | "outputs": [], 265 | "source": [ 266 | "topk = 1 ### For topk results\n", 267 | "our_results = evaluator.score_model(all_pred_dict, all_obj_gt, bias = 1e3, topk = topk)\n", 268 | "stats = evaluator.evaluate_predictions(our_results, all_attr_gt, all_obj_gt, all_pair_gt, all_pred_dict, topk = topk)" 269 | ] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "execution_count": null, 274 | "metadata": {}, 275 | "outputs": [], 276 | "source": [ 277 | "for k, v in stats.items():\n", 278 | " print(k, v)" 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": null, 284 | "metadata": {}, 285 | "outputs": [], 286 | "source": [] 287 | } 288 | ], 289 | "metadata": { 290 | "kernelspec": { 291 | "display_name": "Python 3", 292 | "language": "python", 293 | "name": "python3" 294 | }, 295 | "language_info": { 296 | "codemirror_mode": { 297 | "name": "ipython", 298 | "version": 3 299 | }, 300 | "file_extension": ".py", 301 | "mimetype": "text/x-python", 302 | "name": "python", 303 | "nbconvert_exporter": "python", 304 | "pygments_lexer": "ipython3", 305 | "version": "3.8.5" 306 | } 307 | }, 308 | "nbformat": 4, 309 | "nbformat_minor": 4 310 | } 311 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # Torch imports 2 | import torch 3 | from torch.utils.tensorboard import SummaryWriter 4 | import torch.backends.cudnn as cudnn 5 | import numpy as np 6 | from flags import DATA_FOLDER 7 | 8 | cudnn.benchmark = True 9 | 10 | # Python imports 11 | import tqdm 12 | from tqdm import tqdm 13 | import os 14 | from os.path import join as ospj 15 | 16 | # Local imports 17 | from data import dataset as dset 18 | from models.common import Evaluator 19 | from utils.utils import load_args 20 | from utils.config_model import configure_model 21 | from flags import parser 22 | 23 | 24 | 25 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 26 | 27 | def main(): 28 | # Get arguments and start logging 29 | args = parser.parse_args() 30 | logpath = args.logpath 31 | config = [os.path.join(logpath, _) for _ in os.listdir(logpath) if _.endswith('yml')][0] 32 | load_args(config, args) 33 | 34 | # Get dataset 35 | trainset = dset.CompositionDataset( 36 | root=os.path.join(DATA_FOLDER,args.data_dir), 37 | phase='train', 38 | split=args.splitname, 39 | model=args.image_extractor, 40 | update_features=args.update_features, 41 | train_only=args.train_only, 42 | subset=args.subset, 43 | open_world=args.open_world 44 | ) 45 | 46 | valset = dset.CompositionDataset( 47 | root=os.path.join(DATA_FOLDER,args.data_dir), 48 | phase='val', 49 | split=args.splitname, 50 | model=args.image_extractor, 51 | subset=args.subset, 52 | update_features=args.update_features, 53 | open_world=args.open_world 54 | ) 55 | 56 | valoader = torch.utils.data.DataLoader( 57 | valset, 58 | batch_size=args.test_batch_size, 59 | shuffle=False, 60 | num_workers=8) 61 | 62 | testset = dset.CompositionDataset( 63 | root=os.path.join(DATA_FOLDER,args.data_dir), 64 | phase='test', 65 | split=args.splitname, 66 | model =args.image_extractor, 67 | subset=args.subset, 68 | update_features = 
args.update_features, 69 | open_world=args.open_world 70 | ) 71 | testloader = torch.utils.data.DataLoader( 72 | testset, 73 | batch_size=args.test_batch_size, 74 | shuffle=False, 75 | num_workers=args.workers) 76 | 77 | 78 | # Get model and optimizer 79 | image_extractor, model, optimizer = configure_model(args, trainset) 80 | args.extractor = image_extractor 81 | 82 | 83 | args.load = ospj(logpath,'ckpt_best_auc.t7') 84 | 85 | checkpoint = torch.load(args.load) 86 | if image_extractor: 87 | try: 88 | image_extractor.load_state_dict(checkpoint['image_extractor']) 89 | image_extractor.eval() 90 | except: 91 | print('No Image extractor in checkpoint') 92 | model.load_state_dict(checkpoint['net']) 93 | model.eval() 94 | 95 | threshold = None 96 | if args.open_world and args.hard_masking: 97 | assert args.model == 'compcos', args.model + ' does not have hard masking.' 98 | if args.threshold is not None: 99 | threshold = args.threshold 100 | else: 101 | evaluator_val = Evaluator(valset, model) 102 | unseen_scores = model.compute_feasibility().to('cpu') 103 | seen_mask = model.seen_mask.to('cpu') 104 | min_feasibility = (unseen_scores+seen_mask*10.).min() 105 | max_feasibility = (unseen_scores-seen_mask*10.).max() 106 | thresholds = np.linspace(min_feasibility,max_feasibility, num=args.threshold_trials) 107 | best_auc = 0. 108 | best_th = -10 109 | with torch.no_grad(): 110 | for th in thresholds: 111 | results = test(image_extractor,model,valoader,evaluator_val,args,threshold=th,print_results=False) 112 | auc = results['AUC'] 113 | if auc > best_auc: 114 | best_auc = auc 115 | best_th = th 116 | print('New best AUC',best_auc) 117 | print('Threshold',best_th) 118 | 119 | threshold = best_th 120 | 121 | evaluator = Evaluator(testset, model) 122 | 123 | with torch.no_grad(): 124 | test(image_extractor, model, testloader, evaluator, args, threshold) 125 | 126 | 127 | def test(image_extractor, model, testloader, evaluator, args, threshold=None, print_results=True): 128 | if image_extractor: 129 | image_extractor.eval() 130 | 131 | model.eval() 132 | 133 | accuracies, all_sub_gt, all_attr_gt, all_obj_gt, all_pair_gt, all_pred = [], [], [], [], [], [] 134 | 135 | for idx, data in tqdm(enumerate(testloader), total=len(testloader), desc='Testing'): 136 | data = [d.to(device) for d in data] 137 | 138 | if image_extractor: 139 | data[0] = image_extractor(data[0]) 140 | if threshold is None: 141 | _, predictions = model(data) 142 | else: 143 | _, predictions = model.val_forward_with_threshold(data,threshold) 144 | 145 | attr_truth, obj_truth, pair_truth = data[1], data[2], data[3] 146 | 147 | all_pred.append(predictions) 148 | all_attr_gt.append(attr_truth) 149 | all_obj_gt.append(obj_truth) 150 | all_pair_gt.append(pair_truth) 151 | 152 | if args.cpu_eval: 153 | all_attr_gt, all_obj_gt, all_pair_gt = torch.cat(all_attr_gt), torch.cat(all_obj_gt), torch.cat(all_pair_gt) 154 | else: 155 | all_attr_gt, all_obj_gt, all_pair_gt = torch.cat(all_attr_gt).to('cpu'), torch.cat(all_obj_gt).to( 156 | 'cpu'), torch.cat(all_pair_gt).to('cpu') 157 | 158 | all_pred_dict = {} 159 | # Gather values as dict of (attr, obj) as key and list of predictions as values 160 | if args.cpu_eval: 161 | for k in all_pred[0].keys(): 162 | all_pred_dict[k] = torch.cat( 163 | [all_pred[i][k].to('cpu') for i in range(len(all_pred))]) 164 | else: 165 | for k in all_pred[0].keys(): 166 | all_pred_dict[k] = torch.cat( 167 | [all_pred[i][k] for i in range(len(all_pred))]) 168 | 169 | # Calculate best unseen accuracy 170 | results = 
evaluator.score_model(all_pred_dict, all_obj_gt, bias=args.bias, topk=args.topk) 171 | stats = evaluator.evaluate_predictions(results, all_attr_gt, all_obj_gt, all_pair_gt, all_pred_dict, 172 | topk=args.topk) 173 | 174 | 175 | result = '' 176 | for key in stats: 177 | result = result + key + ' ' + str(round(stats[key], 4)) + '| ' 178 | 179 | result = result + args.name 180 | if print_results: 181 | print('Results') 182 | print(result) 183 | return stats 184 | 185 | 186 | if __name__ == '__main__': 187 | main() 188 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Torch imports 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.utils.tensorboard import SummaryWriter 6 | import torch.backends.cudnn as cudnn 7 | cudnn.benchmark = True 8 | 9 | # Python imports 10 | import tqdm 11 | from tqdm import tqdm 12 | import os 13 | from os.path import join as ospj 14 | import csv 15 | 16 | # Local imports 17 | from data import dataset as dset 18 | from models.common import Evaluator 19 | from utils.utils import save_args, load_args 20 | from utils.config_model import configure_model 21 | from flags import parser, DATA_FOLDER 22 | 23 | best_auc = 0 24 | best_hm = 0 25 | compose_switch = True 26 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 27 | 28 | def main(): 29 | # Get arguments and start logging 30 | args = parser.parse_args() 31 | load_args(args.config, args) 32 | logpath = os.path.join(args.cv_dir, args.name) 33 | os.makedirs(logpath, exist_ok=True) 34 | save_args(args, logpath, args.config) 35 | writer = SummaryWriter(log_dir = logpath, flush_secs = 30) 36 | 37 | # Get dataset 38 | trainset = dset.CompositionDataset( 39 | root=os.path.join(DATA_FOLDER,args.data_dir), 40 | phase='train', 41 | split=args.splitname, 42 | model =args.image_extractor, 43 | num_negs=args.num_negs, 44 | pair_dropout=args.pair_dropout, 45 | update_features = args.update_features, 46 | train_only= args.train_only, 47 | open_world=args.open_world 48 | ) 49 | trainloader = torch.utils.data.DataLoader( 50 | trainset, 51 | batch_size=args.batch_size, 52 | shuffle=True, 53 | num_workers=args.workers) 54 | testset = dset.CompositionDataset( 55 | root=os.path.join(DATA_FOLDER,args.data_dir), 56 | phase=args.test_set, 57 | split=args.splitname, 58 | model =args.image_extractor, 59 | subset=args.subset, 60 | update_features = args.update_features, 61 | open_world=args.open_world 62 | ) 63 | testloader = torch.utils.data.DataLoader( 64 | testset, 65 | batch_size=args.test_batch_size, 66 | shuffle=False, 67 | num_workers=args.workers) 68 | 69 | 70 | # Get model and optimizer 71 | image_extractor, model, optimizer = configure_model(args, trainset) 72 | args.extractor = image_extractor 73 | 74 | train = train_normal 75 | 76 | evaluator_val = Evaluator(testset, model) 77 | 78 | print(model) 79 | 80 | start_epoch = 0 81 | # Load checkpoint 82 | if args.load is not None: 83 | checkpoint = torch.load(args.load) 84 | if image_extractor: 85 | try: 86 | image_extractor.load_state_dict(checkpoint['image_extractor']) 87 | if args.freeze_features: 88 | print('Freezing image extractor') 89 | image_extractor.eval() 90 | for param in image_extractor.parameters(): 91 | param.requires_grad = False 92 | except: 93 | print('No Image extractor in checkpoint') 94 | model.load_state_dict(checkpoint['net']) 95 | start_epoch = checkpoint['epoch'] 96 | print('Loaded model 
from ', args.load) 97 | 98 | for epoch in tqdm(range(start_epoch, args.max_epochs + 1), desc = 'Current epoch'): 99 | train(epoch, image_extractor, model, trainloader, optimizer, writer) 100 | if model.is_open and args.model=='compcos' and ((epoch+1)%args.update_feasibility_every)==0 : 101 | print('Updating feasibility scores') 102 | model.update_feasibility(epoch+1.) 103 | 104 | if epoch % args.eval_val_every == 0: 105 | with torch.no_grad(): # todo: might not be needed 106 | test(epoch, image_extractor, model, testloader, evaluator_val, writer, args, logpath) 107 | print('Best AUC achieved is ', best_auc) 108 | print('Best HM achieved is ', best_hm) 109 | 110 | 111 | def train_normal(epoch, image_extractor, model, trainloader, optimizer, writer): 112 | ''' 113 | Runs training for an epoch 114 | ''' 115 | 116 | if image_extractor: 117 | image_extractor.train() 118 | model.train() # Let's switch to training 119 | 120 | train_loss = 0.0 121 | for idx, data in tqdm(enumerate(trainloader), total=len(trainloader), desc = 'Training'): 122 | data = [d.to(device) for d in data] 123 | 124 | if image_extractor: 125 | data[0] = image_extractor(data[0]) 126 | 127 | loss, _ = model(data) 128 | 129 | optimizer.zero_grad() 130 | loss.backward() 131 | optimizer.step() 132 | 133 | train_loss += loss.item() 134 | 135 | train_loss = train_loss/len(trainloader) 136 | writer.add_scalar('Loss/train_total', train_loss, epoch) 137 | print('Epoch: {}| Loss: {}'.format(epoch, round(train_loss, 2))) 138 | 139 | 140 | def test(epoch, image_extractor, model, testloader, evaluator, writer, args, logpath): 141 | ''' 142 | Runs testing for an epoch 143 | ''' 144 | global best_auc, best_hm 145 | 146 | def save_checkpoint(filename): 147 | state = { 148 | 'net': model.state_dict(), 149 | 'epoch': epoch, 150 | 'AUC': stats['AUC'] 151 | } 152 | if image_extractor: 153 | state['image_extractor'] = image_extractor.state_dict() 154 | torch.save(state, os.path.join(logpath, 'ckpt_{}.t7'.format(filename))) 155 | 156 | if image_extractor: 157 | image_extractor.eval() 158 | 159 | model.eval() 160 | 161 | accuracies, all_sub_gt, all_attr_gt, all_obj_gt, all_pair_gt, all_pred = [], [], [], [], [], [] 162 | 163 | for idx, data in tqdm(enumerate(testloader), total=len(testloader), desc='Testing'): 164 | data = [d.to(device) for d in data] 165 | 166 | if image_extractor: 167 | data[0] = image_extractor(data[0]) 168 | 169 | _, predictions = model(data) 170 | 171 | attr_truth, obj_truth, pair_truth = data[1], data[2], data[3] 172 | 173 | all_pred.append(predictions) 174 | all_attr_gt.append(attr_truth) 175 | all_obj_gt.append(obj_truth) 176 | all_pair_gt.append(pair_truth) 177 | 178 | if args.cpu_eval: 179 | all_attr_gt, all_obj_gt, all_pair_gt = torch.cat(all_attr_gt), torch.cat(all_obj_gt), torch.cat(all_pair_gt) 180 | else: 181 | all_attr_gt, all_obj_gt, all_pair_gt = torch.cat(all_attr_gt).to('cpu'), torch.cat(all_obj_gt).to( 182 | 'cpu'), torch.cat(all_pair_gt).to('cpu') 183 | 184 | all_pred_dict = {} 185 | # Gather values as dict of (attr, obj) as key and list of predictions as values 186 | if args.cpu_eval: 187 | for k in all_pred[0].keys(): 188 | all_pred_dict[k] = torch.cat( 189 | [all_pred[i][k].to('cpu') for i in range(len(all_pred))]) 190 | else: 191 | for k in all_pred[0].keys(): 192 | all_pred_dict[k] = torch.cat( 193 | [all_pred[i][k] for i in range(len(all_pred))]) 194 | 195 | # Calculate best unseen accuracy 196 | results = evaluator.score_model(all_pred_dict, all_obj_gt, bias=args.bias, topk=args.topk) 197 | stats = 
evaluator.evaluate_predictions(results, all_attr_gt, all_obj_gt, all_pair_gt, all_pred_dict, topk=args.topk) 198 | 199 | stats['a_epoch'] = epoch 200 | 201 | result = '' 202 | # write to Tensorboard 203 | for key in stats: 204 | writer.add_scalar(key, stats[key], epoch) 205 | result = result + key + ' ' + str(round(stats[key], 4)) + '| ' 206 | 207 | result = result + args.name 208 | print(f'Test Epoch: {epoch}') 209 | print(result) 210 | if epoch > 0 and epoch % args.save_every == 0: 211 | save_checkpoint(epoch) 212 | if stats['AUC'] > best_auc: 213 | best_auc = stats['AUC'] 214 | print('New best AUC ', best_auc) 215 | save_checkpoint('best_auc') 216 | 217 | if stats['best_hm'] > best_hm: 218 | best_hm = stats['best_hm'] 219 | print('New best HM ', best_hm) 220 | save_checkpoint('best_hm') 221 | 222 | # Logs 223 | with open(ospj(logpath, 'logs.csv'), 'a') as f: 224 | w = csv.DictWriter(f, stats.keys()) 225 | if epoch == 0: 226 | w.writeheader() 227 | w.writerow(stats) 228 | 229 | 230 | if __name__ == '__main__': 231 | try: 232 | main() 233 | except KeyboardInterrupt: 234 | print('Best AUC achieved is ', best_auc) 235 | print('Best HM achieved is ', best_hm) -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ExplainableML/czsl/22ab1aca82c64d4f351a20a887986147edfb1905/utils/__init__.py -------------------------------------------------------------------------------- /utils/cgqa-graph.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ExplainableML/czsl/22ab1aca82c64d4f351a20a887986147edfb1905/utils/cgqa-graph.t7 -------------------------------------------------------------------------------- /utils/config_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | 4 | from models.image_extractor import get_image_extractor 5 | from models.visual_product import VisualProductNN 6 | from models.manifold_methods import RedWine, LabelEmbedPlus, AttributeOperator 7 | from models.modular_methods import GatedGeneralNN 8 | from models.graph_method import GraphFull 9 | from models.symnet import Symnet 10 | from models.compcos import CompCos 11 | 12 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 13 | 14 | 15 | def configure_model(args, dataset): 16 | image_extractor = None 17 | is_open = False 18 | 19 | if args.model == 'visprodNN': 20 | model = VisualProductNN(dataset, args) 21 | elif args.model == 'redwine': 22 | model = RedWine(dataset, args) 23 | elif args.model == 'labelembed+': 24 | model = LabelEmbedPlus(dataset, args) 25 | elif args.model == 'attributeop': 26 | model = AttributeOperator(dataset, args) 27 | elif args.model == 'tmn': 28 | model = GatedGeneralNN(dataset, args, num_layers=args.nlayers, num_modules_per_layer=args.nmods) 29 | elif args.model == 'symnet': 30 | model = Symnet(dataset, args) 31 | elif args.model == 'graphfull': 32 | model = GraphFull(dataset, args) 33 | elif args.model == 'compcos': 34 | model = CompCos(dataset, args) 35 | if dataset.open_world and not args.train_only: 36 | is_open = True 37 | else: 38 | raise NotImplementedError 39 | 40 | model = model.to(device) 41 | 42 | if args.update_features: 43 | print('Learnable image_embeddings') 44 | image_extractor = get_image_extractor(arch = args.image_extractor, pretrained = True) 45 | image_extractor = 
image_extractor.to(device) 46 | 47 | # configuring optimizer 48 | if args.model=='redwine': 49 | optim_params = filter(lambda p: p.requires_grad, model.parameters()) 50 | elif args.model=='attributeop': 51 | attr_params = [param for name, param in model.named_parameters() if 'attr_op' in name and param.requires_grad] 52 | other_params = [param for name, param in model.named_parameters() if 'attr_op' not in name and param.requires_grad] 53 | optim_params = [{'params':attr_params, 'lr':0.1*args.lr}, {'params':other_params}] 54 | elif args.model=='tmn': 55 | gating_params = [ 56 | param for name, param in model.named_parameters() 57 | if 'gating_network' in name and param.requires_grad 58 | ] 59 | network_params = [ 60 | param for name, param in model.named_parameters() 61 | if 'gating_network' not in name and param.requires_grad 62 | ] 63 | optim_params = [ 64 | { 65 | 'params': network_params, 66 | }, 67 | { 68 | 'params': gating_params, 69 | 'lr': args.lrg 70 | }, 71 | ] 72 | else: 73 | model_params = [param for name, param in model.named_parameters() if param.requires_grad] 74 | optim_params = [{'params':model_params}] 75 | if args.update_features: 76 | ie_parameters = [param for name, param in image_extractor.named_parameters()] 77 | optim_params.append({'params': ie_parameters, 78 | 'lr': args.lrg}) 79 | optimizer = optim.Adam(optim_params, lr=args.lr, weight_decay=args.wd) 80 | 81 | model.is_open = is_open 82 | 83 | return image_extractor, model, optimizer -------------------------------------------------------------------------------- /utils/download_data.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | CURRENT_DIR=$(pwd) 9 | 10 | FOLDER=$1 11 | 12 | mkdir $FOLDER 13 | mkdir $FOLDER/fast 14 | mkdir $FOLDER/glove 15 | mkdir $FOLDER/w2v 16 | 17 | cp utils/download_embeddings.py $FOLDER/fast 18 | 19 | 20 | # Download everything 21 | wget --show-progress -O $FOLDER/attr-ops-data.tar.gz https://utexas.box.com/shared/static/h7ckk2gx5ykw53h8o641no8rdyi7185g.gz 22 | wget --show-progress -O $FOLDER/mitstates.zip http://wednesday.csail.mit.edu/joseph_result/state_and_transformation/release_dataset.zip 23 | wget --show-progress -O $FOLDER/utzap.zip http://vision.cs.utexas.edu/projects/finegrained/utzap50k/ut-zap50k-images.zip 24 | wget --show-progress -O $FOLDER/splits.tar.gz https://www.senthilpurushwalkam.com/publication/compositional/compositional_split_natural.tar.gz 25 | wget --show-progress -O $FOLDER/cgqa.zip https://s3.mlcloud.uni-tuebingen.de/czsl/cgqa-updated.zip 26 | 27 | echo "Data downloaded. Extracting files..." 28 | sed -i "s|ROOT_FOLDER|$FOLDER|g" utils/reorganize_utzap.py 29 | sed -i "s|ROOT_FOLDER|$FOLDER|g" flags.py 30 | 31 | # Dataset metadata, pretrained SVMs and features, tensor completion data 32 | 33 | cd $FOLDER 34 | 35 | tar -zxvf attr-ops-data.tar.gz --strip 1 36 | 37 | # MIT-States 38 | unzip mitstates.zip 'release_dataset/images/*' -d mit-states/ 39 | mv mit-states/release_dataset/images mit-states/images/ 40 | rm -r mit-states/release_dataset 41 | rename "s/ /_/g" mit-states/images/* 42 | 43 | # UT-Zappos50k 44 | unzip utzap.zip -d ut-zap50k/ 45 | mv ut-zap50k/ut-zap50k-images ut-zap50k/_images/ 46 | 47 | # C-GQA 48 | unzip cgqa.zip -d cgqa/ 49 | 50 | # Download new splits for Purushwalkam et. 
al 51 | tar -zxvf splits.tar.gz 52 | 53 | # remove all zip files and temporary files 54 | rm -r attr-ops-data.tar.gz mitstates.zip utzap.zip splits.tar.gz cgqa.zip 55 | 56 | # Download embeddings 57 | 58 | # Glove (from attribute as operators) 59 | mv data/glove/* glove/ 60 | 61 | # FastText 62 | cd fast 63 | python download_embeddings.py 64 | rm cc.en.300.bin.gz 65 | 66 | # Word2Vec 67 | cd ../w2v 68 | wget -c "https://s3.amazonaws.com/dl4j-distribution/GoogleNews-vectors-negative300.bin.gz" 69 | gzip -d GoogleNews-vectors-negative300.bin.gz 70 | rm GoogleNews-vectors-negative300.bin.gz 71 | 72 | cd .. 73 | rm -r data 74 | 75 | cd $CURRENT_DIR 76 | python utils/reorganize_utzap.py 77 | -------------------------------------------------------------------------------- /utils/download_embeddings.py: -------------------------------------------------------------------------------- 1 | import fasttext.util 2 | fasttext.util.download_model('en', if_exists='ignore') # English 3 | 4 | -------------------------------------------------------------------------------- /utils/img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ExplainableML/czsl/22ab1aca82c64d4f351a20a887986147edfb1905/utils/img.png -------------------------------------------------------------------------------- /utils/mitstates-graph.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ExplainableML/czsl/22ab1aca82c64d4f351a20a887986147edfb1905/utils/mitstates-graph.t7 -------------------------------------------------------------------------------- /utils/reorganize_utzap.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | """ 8 | Reorganize the UT-Zappos dataset to resemble the MIT-States dataset 9 | root/attr_obj/img1.jpg 10 | root/attr_obj/img2.jpg 11 | root/attr_obj/img3.jpg 12 | ... 
13 | """ 14 | 15 | import os 16 | import torch 17 | import shutil 18 | import tqdm 19 | 20 | DATA_FOLDER= "ROOT_FOLDER" 21 | 22 | root = DATA_FOLDER+'/ut-zap50k/' 23 | os.makedirs(root+'/images',exist_ok=True) 24 | 25 | data = torch.load(root+'/metadata_compositional-split-natural.t7') 26 | for instance in tqdm.tqdm(data): 27 | image, attr, obj = instance['_image'], instance['attr'], instance['obj'] 28 | old_file = '%s/_images/%s'%(root, image) 29 | new_dir = '%s/images/%s_%s/'%(root, attr, obj) 30 | os.makedirs(new_dir, exist_ok=True) 31 | shutil.copy(old_file, new_dir) 32 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join as ospj 3 | import torch 4 | import random 5 | import copy 6 | import shutil 7 | import sys 8 | import yaml 9 | 10 | 11 | def chunks(l, n): 12 | """Yield successive n-sized chunks from l.""" 13 | for i in range(0, len(l), n): 14 | yield l[i:i + n] 15 | 16 | def get_norm_values(norm_family = 'imagenet'): 17 | ''' 18 | Inputs 19 | norm_family: String of norm_family 20 | Returns 21 | mean, std : tuple of 3 channel values 22 | ''' 23 | if norm_family == 'imagenet': 24 | mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] 25 | else: 26 | raise ValueError('Incorrect normalization family') 27 | return mean, std 28 | 29 | def save_args(args, log_path, argfile): 30 | shutil.copy('train.py', log_path) 31 | modelfiles = ospj(log_path, 'models') 32 | try: 33 | shutil.copy(argfile, log_path) 34 | except: 35 | print('Config exists') 36 | try: 37 | shutil.copytree('models/', modelfiles) 38 | except: 39 | print('Already exists') 40 | with open(ospj(log_path,'args_all.yaml'),'w') as f: 41 | yaml.dump(args, f, default_flow_style=False, allow_unicode=True) 42 | with open(ospj(log_path, 'args.txt'), 'w') as f: 43 | f.write('\n'.join(sys.argv[1:])) 44 | 45 | class UnNormalizer: 46 | ''' 47 | Unnormalize a given tensor using mean and std of a dataset family 48 | Inputs 49 | norm_family: String, dataset 50 | tensor: Torch tensor 51 | Outputs 52 | tensor: Unnormalized tensor 53 | ''' 54 | def __init__(self, norm_family = 'imagenet'): 55 | self.mean, self.std = get_norm_values(norm_family=norm_family) 56 | self.mean, self.std = torch.Tensor(self.mean).view(1, 3, 1, 1), torch.Tensor(self.std).view(1, 3, 1, 1) 57 | 58 | def __call__(self, tensor): 59 | return (tensor * self.std) + self.mean 60 | 61 | def load_args(filename, args): 62 | with open(filename, 'r') as stream: 63 | data_loaded = yaml.safe_load(stream) 64 | for key, group in data_loaded.items(): 65 | for key, val in group.items(): 66 | setattr(args, key, val) -------------------------------------------------------------------------------- /utils/utzappos-graph.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ExplainableML/czsl/22ab1aca82c64d4f351a20a887986147edfb1905/utils/utzappos-graph.t7 --------------------------------------------------------------------------------