├── .gitignore ├── .gitmodules ├── .vscode └── settings.json ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── __pycache__ ├── attributed_dataloader.cpython-37.pyc ├── brenden_number_game.cpython-37.pyc └── model.cpython-37.pyc ├── concept_data ├── CLEVR_val_scenes.json ├── clevr-metadata.json ├── clevr-properties.json ├── clevr_typed_fol_properties.json ├── meta_dataset_properties.sh └── v2_typed_simple_fol.json ├── dataloaders ├── __init__.py ├── adhoc_data_loader.py ├── build_sound_scene.py ├── get_dataloader.py ├── utils.py └── vocabulary.py ├── docs ├── model_generalization.pdf ├── model_generalization.png └── rel_net_alternate.png ├── hydra_cfg ├── data_files │ └── v2_typed_simple_fol_depth_6_trials_2000000_ban_1_max_scene_id_200.yaml ├── experiment.yaml ├── modality │ ├── image.yaml │ └── json.yaml ├── mode │ ├── eval.yaml │ └── train.yaml ├── pooling │ ├── concat.yaml │ ├── gap.yaml │ ├── rel_net.yaml │ └── trnsf.yaml ├── special_rules │ ├── json_concat.yaml │ ├── json_gap.yaml │ ├── json_rel_net.yaml │ ├── json_trnsf.yaml │ └── sound_concat.yaml └── task │ └── adhoc_concepts.yaml ├── hydra_eval.py ├── hydra_qualitative.py ├── hydra_train.py ├── hypothesis_generation ├── __init__.py ├── hypothesis_utils.py ├── prefix_postfix.py └── reduce_and_process_hypotheses.py ├── install.sh ├── launch_test_jobs.sh ├── launch_train_eval_jobs.sh ├── losses.py ├── models ├── __init__.py ├── __pycache__ │ └── .nfs00780000024e4221000025e7 ├── _evaluator.py ├── _map_evaluator.py ├── _trainer.py ├── audio_resnet.py ├── encoders.py ├── protonet.py ├── simple_lstm_decoder.py └── utils.py ├── notebooks ├── .ipynb_checkpoints │ ├── ahdoc_iclr_numbers-checkpoint.ipynb │ ├── structured_splits-checkpoint.ipynb │ ├── vis_meta_dataset-checkpoint.ipynb │ ├── visualize_alternate_hypotheses-checkpoint.ipynb │ └── visualize_meta_dataset-checkpoint.ipynb ├── ahdoc_iclr_numbers.ipynb ├── structured_splits.ipynb ├── visualize_alternate_hypotheses.ipynb └── visualize_meta_dataset.ipynb ├── paths.sh ├── runs └── hydra_train │ ├── .hydra │ ├── config.yaml │ ├── hydra.yaml │ └── overrides.yaml │ ├── events.out.tfevents.1625066218.devfair0241 │ └── hydra_train.log ├── scripts └── pick_best_checkpoints.py ├── third_party ├── __init__.py ├── image_utils.py ├── mlp.py └── relation_network.py └── utils ├── __init__.py ├── checkpointing.py └── visualize.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Gitignore 2 | 3 | # Reject hidden files 4 | ._* 5 | *.out 6 | datasets/* 7 | outputs/* 8 | *.pyc 9 | .ipynb_checkpoints/abstraction_model-Copy1-checkpoint.ipynb 10 | .ipynb_checkpoints/abstraction_model-checkpoint.ipynb 11 | __pycache__/* 12 | scenes_* 13 | performance_* 14 | concept_data/typed_simple_fol_depth_40_rejection_sample_trials_100000.json 15 | concept_data/typed_simple_fol_depth_40_rejection_sample_trials_100000.pkl 16 | hypothesis_generation/__pycache__/* 17 | concept_data/grammar_expander_typed_simple_fol_depth_40_ordinal_size.pkl 18 | results_dump.pkl 19 | concept_data/typed_simple_fol_clevr-properties.pkl 20 | concept_data/typed_simple_fol_depth_30_trials_400000.pkl 21 | concept_data/v1_typed_simple_fol_clevr-properties.pkl 22 | notebooks/.ipynb_checkpoints/Untitled-checkpoint.ipynb 23 | concept_data/temp_data/v1_typed_simple_fol_clevr_typed_fol_properties.pkl 24 | concept_data/temp_data/v2_typed_simple_fol_clevr_typed_fol_properties.pkl 25 | notebooks/hypothtesis_analysis.ipynb 26 | notebooks/Untitled.ipynb 27 | 
notebooks/.ipynb_checkpoints/hypothtesis_analysis-checkpoint.ipynb 28 | slurm-17729056.out 29 | log_test/* 30 | env/* 31 | notebooks/create_folder_hypothesis_and_json.ipynb 32 | notebooks/.ipynb_checkpoints/create_folder_hypothesis_and_json-checkpoint.ipynb 33 | notebooks/1.png 34 | notebooks/2.png 35 | notebooks/3.png 36 | notebooks/csv_download.csv 37 | notebooks/structured_splits.ipynb 38 | notebooks/visualize_meta_dataset.ipynb 39 | notebooks/*.png 40 | notebooks/.ipynb_checkpoints/* 41 | release/* 42 | runs/* 43 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "clevr-dataset-gen"] 2 | path = clevr-dataset-gen 3 | url = https://github.com/facebookresearch/clevr-dataset-gen 4 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.pythonPath": "/private/home/ramav/anaconda3/bin/python", 3 | "python.formatting.provider": "yapf", 4 | "editor.rulers": [ 5 | 80, 6 | 120 7 | ], 8 | "python.linting.pylintEnabled": true, 9 | "python.linting.enabled": true 10 | } -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Open Source Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to make participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. 6 | 7 | ## Our Standards 8 | 9 | Examples of behavior that contributes to creating a positive environment include: 10 | 11 | Using welcoming and inclusive language 12 | Being respectful of differing viewpoints and experiences 13 | Gracefully accepting constructive criticism 14 | Focusing on what is best for the community 15 | Showing empathy towards other community members 16 | Examples of unacceptable behavior by participants include: 17 | 18 | The use of sexualized language or imagery and unwelcome sexual attention or advances 19 | Trolling, insulting/derogatory comments, and personal or political attacks 20 | Public or private harassment 21 | Publishing others’ private information, such as a physical or electronic address, without explicit permission 22 | Other conduct which could reasonably be considered inappropriate in a professional setting 23 | 24 | ## Our Responsibilities 25 | 26 | Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior. 27 | 28 | Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. 
29 | 30 | ## Scope 31 | 32 | This Code of Conduct applies within all project spaces, and it also applies when an individual is representing the project or its community in public spaces. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers. 33 | 34 | ## Enforcement 35 | 36 | Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at opensource-conduct@fb.com. All complaints will be reviewed and investigated and will result in a response that is deemed necessary and appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately. 37 | 38 | Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project’s leadership. 39 | 40 | ## Attribution 41 | 42 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 43 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 44 | 45 | [homepage]: https://www.contributor-covenant.org -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to productive_concept_learning 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `master`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## Coding Style 30 | * 4 spaces for indentation rather than tabs 31 | * 80 character line length 32 | 33 | ## License 34 | By contributing to __________, you agree that your contributions will be licensed 35 | under the LICENSE file in the root directory of this source tree. 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Productive Concept Learning with the CURI Dataset (ICML 2021)

> We introduce the productive concept learning task and the CURI dataset, where few-shot meta-learners need to acquire concepts sampled from a structured, productive, compositional space and generalize, while demonstrating the ability to reason about compositionality under uncertainty.

## Compositional Reasoning Under Uncertainty
![Generalization](docs/model_generalization.png)
**Illustration of Compositional Reasoning Under Uncertainty**

Given images in a support set (left) consistent with two concepts:

1) "all objects are blue and for all objects the x coordinate is greater than the y coordinate"

2) "all objects are blue and there exists an object such that the x coordinate is greater than the y coordinate"

A model that understands the compositional structure of the space should first infer that both concepts can explain the observed images, and then reason that the former concept is more likely to explain the support set, making predictions on held-out images that take the relative likelihoods of both scenarios into account. The CURI benchmark tests such compositional reasoning.

### Systematic Splits and Compositionality Gap (Comp Gap)
Our paper introduces various systematic splits to test different aspects of productive concept learning. While similar methodologies have been used to create datasets and splits in the past, our work is unique in that we also introduce the notion of a "compositionality gap", which gives an objective, model-independent measure of how difficult a given compositional split is.

![Compositionality Gap](docs/rel_net_alternate.png)

### Models
This repository contains an implementation of prototypical networks (Snell et al.) with various architectures, input modalities, use of auxiliary information, etc. For more details see our paper:

**"CURI: A Benchmark for Productive Concept Learning Under Uncertainty."**
*Vedantam, Ramakrishna, Arthur Szlam, Maximilian Nickel, Ari Morcos, and Brenden Lake.*
ICML 2021. arXiv: http://arxiv.org/abs/2010.02855

----
----
## Setup
Run `source install.sh` to install all the dependencies for the project.

### Download the CURI dataset
Run the following command to download the CURI dataset:

```
wget https://dl.fbaipublicfiles.com/CURI/curi_v0.2.tar.gz -P /path/to/dataset
```
The dataset is very large (roughly 200GB after decompressing), so please ensure that the download location has enough storage to hold it.

Uncompress the following files:

```
cd /path/to/dataset/curi_release
tar -xvf images_200.tar.gz
tar -xvf scenes_200.tar.gz
```

### Before Training and Evaluating Models

Before training models, set the following two paths in `paths.sh` and then source it:

* Set `${RUN_DIR}=/path/to/runs` to specify where the results of the model sweeps should be stored.
* Set the `${CURI_DATASET_PATH}` variable to `/path/to/dataset/curi_release`.
* Run `source paths.sh`.

### Training a model
```
python hydra_train.py mode=train
```
trains a model with the default parameters. The project uses [Hydra](https://hydra.cc) to manage configuration; set the corresponding flags based on the configuration in `hydra_cfg/experiment.yaml` to run the various models of interest.

Run `source launch_train_eval_jobs.sh` to launch the full sweep of models explored in the paper, and `source launch_test_jobs.sh` to launch the corresponding sweep of test runs once the models have trained.

## Computing oracle metrics
See `launch_test_jobs.sh` for how to compute oracle metrics and the associated compositionality gap for the various splits discussed in the paper.

## Additional Information
### Mapping between the splits in the code and the splits in the paper
The following mapping exists between the split names used in the code and the splits mentioned in our paper:

| Code                | Paper                   |
|---------------------|-------------------------|
| Color Boolean       | Boolean                 |
| Color Count         | Counting                |
| Color Location      | Extrinsic Disentangling |
| Color Material      | Intrinsic Disentangling |
| Color               | Binding (Color)         |
| Comp                | Compositional Split     |
| IID                 | IID                     |
| Shape               | Binding (Shape)         |
| length_threshold_10 | Complexity              |

### Directory Structure
The dataset directory is laid out as follows:

```
CURI
| --> images
| ---->1
| ------> 1.png
| ------> 2.png
| --> scenes
| ---->1
| ------> 1.json
| ------> 2.json
| --> hypotheses
| ----> hypothesis_property_dir
| ------> split_1.pkl
| ------> split_2.pkl
.
.
```

## License
productive_concept_learning is CC-BY-NC 4.0 licensed, as found in the LICENSE file.
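
### Example: Reading the Dataset Layout
As a quick orientation, the snippet below is a minimal sketch of walking one episode directory under the layout shown above. The episode id, the `.png`/`.json` pairing, and the use of the `CURI_DATASET_PATH` environment variable set in `paths.sh` are assumptions for illustration, not a documented API.

```
import json
import os

# Root of the uncompressed release; assumes paths.sh exported CURI_DATASET_PATH.
dataset_root = os.environ.get("CURI_DATASET_PATH", "/path/to/dataset/curi_release")

episode_id = "1"  # hypothetical episode directory, as in the tree above
scene_dir = os.path.join(dataset_root, "scenes", episode_id)
image_dir = os.path.join(dataset_root, "images", episode_id)

for fname in sorted(os.listdir(scene_dir)):
    if not fname.endswith(".json"):
        continue
    # Each scene JSON describes the objects rendered in the paired image.
    with open(os.path.join(scene_dir, fname)) as f:
        scene = json.load(f)
    image_path = os.path.join(image_dir, fname.replace(".json", ".png"))
    print(image_path, "objects:", len(scene["objects"]))
```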
136 | -------------------------------------------------------------------------------- /__pycache__/attributed_dataloader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/productive_concept_learning/d00a3b85215bff0686ad3f834d50c8eae374f06a/__pycache__/attributed_dataloader.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/brenden_number_game.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/productive_concept_learning/d00a3b85215bff0686ad3f834d50c8eae374f06a/__pycache__/brenden_number_game.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/productive_concept_learning/d00a3b85215bff0686ad3f834d50c8eae374f06a/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /concept_data/clevr-metadata.json: -------------------------------------------------------------------------------- 1 | { 2 | "image_size_x": 480, 3 | "image_size_y": 320, 4 | "image_size_z": 16, 5 | "max_object_size": 0.7, 6 | "sizes": { 7 | "large": 0.7, 8 | "small": 0.35 9 | }, 10 | "dimensions_to_idx": { 11 | "x": 0, 12 | "y": 1, 13 | "z": 2 14 | } 15 | } -------------------------------------------------------------------------------- /concept_data/clevr-properties.json: -------------------------------------------------------------------------------- 1 | { 2 | "shapes": { 3 | "cube": "SmoothCube_v2", 4 | "sphere": "Sphere", 5 | "cylinder": "SmoothCylinder" 6 | }, 7 | "colors": { 8 | "gray": [87, 87, 87], 9 | "red": [173, 35, 35], 10 | "blue": [42, 75, 215], 11 | "green": [29, 105, 20], 12 | "brown": [129, 74, 25], 13 | "purple": [129, 38, 192], 14 | "cyan": [41, 208, 208], 15 | "yellow": [255, 238, 51] 16 | }, 17 | "materials": { 18 | "rubber": "Rubber", 19 | "metal": "MyMetal" 20 | }, 21 | "sizes": { 22 | "large": 0.7, 23 | "small": 0.35 24 | } 25 | } -------------------------------------------------------------------------------- /concept_data/clevr_typed_fol_properties.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "v0.2", 3 | "properties": { 4 | "shapes": { 5 | "cube": "SmoothCube_v2", 6 | "sphere": "Sphere", 7 | "cylinder": "SmoothCylinder" 8 | }, 9 | "colors": { 10 | "gray": [ 11 | 87, 12 | 87, 13 | 87 14 | ], 15 | "red": [ 16 | 173, 17 | 35, 18 | 35 19 | ], 20 | "blue": [ 21 | 42, 22 | 75, 23 | 215 24 | ], 25 | "green": [ 26 | 29, 27 | 105, 28 | 20 29 | ], 30 | "brown": [ 31 | 129, 32 | 74, 33 | 25 34 | ], 35 | "purple": [ 36 | 129, 37 | 38, 38 | 192 39 | ], 40 | "cyan": [ 41 | 41, 42 | 208, 43 | 208 44 | ], 45 | "yellow": [ 46 | 255, 47 | 238, 48 | 51 49 | ] 50 | }, 51 | "materials": { 52 | "rubber": "Rubber", 53 | "metal": "MyMetal" 54 | }, 55 | "sizes": { 56 | "large": 0.7, 57 | "small": 0.35 58 | } 59 | }, 60 | "metadata": { 61 | "image_size": { 62 | "x": 240, 63 | "y": 160, 64 | "z": 16 65 | }, 66 | "max_rotation": 360, 67 | "min_objects": 2, 68 | "max_objects": 5, 69 | "max_object_size": 0.7, 70 | "sizes": { 71 | "large": 0.7, 72 | "small": 0.35 73 | }, 74 | "dimensions_to_idx": { 75 | "x": 0, 76 | "y": 1, 77 | "z": 2 78 | }, 79 | 
"location_bins": { 80 | "x": 8, 81 | "y": 8, 82 | "z": 8 83 | } 84 | } 85 | } -------------------------------------------------------------------------------- /concept_data/meta_dataset_properties.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # Meta-learning dataset properties 7 | # 8 | # Provides the paths relevant to the meta-dataset generation code 9 | OUTPUT_DATASET_PATH="/checkpoint/ramav/adhoc_concept_data/" 10 | RAW_IMAGE_DATA_PATH="adhoc_images_slurm_v0.2" 11 | 12 | # 13 | DATASET_PROPERTIES="clevr_typed_fol_properties" 14 | GRAMMAR_TYPE="v2_typed_simple_fol" 15 | MAX_RECURSION_DEPTH=6 16 | NUM_HYPOTHESES=2000000 17 | BAN_STRINGS_WITH_SAME_ARGS=1 18 | MAX_SCENE_FILE_ID_FOR_EXEC=200 19 | RAW_DATASET_NUM_IMAGES=5000000 20 | META_DATASET_TRAIN_SIZE=500000 21 | 22 | META_DATASET_NAME=$(cat <", 47 | "*" 48 | ], 49 | [ 50 | "#L# #L# >", 51 | "*" 52 | ], 53 | [ 54 | "#NUM# #NUM# >", 55 | "*" 56 | ], 57 | [ 58 | "#SETFC# #C# for-all=", 59 | "*" 60 | ], 61 | [ 62 | "#SETFSH# #SH# for-all=", 63 | "*" 64 | ], 65 | [ 66 | "#SETFM# #M# for-all=", 67 | "*" 68 | ], 69 | [ 70 | "#SETFSI# #SI# for-all=", 71 | "*" 72 | ], 73 | [ 74 | "#SETFL# #L# for-all=", 75 | "*" 76 | ], 77 | [ 78 | "#SETFC# #C# exists=", 79 | "*" 80 | ], 81 | [ 82 | "#SETFSH# #SH# exists=", 83 | "*" 84 | ], 85 | [ 86 | "#SETFM# #M# exists=", 87 | "*" 88 | ], 89 | [ 90 | "#SETFSI# #SI# exists=", 91 | "*" 92 | ], 93 | [ 94 | "#SETFL# #L# exists=", 95 | "*" 96 | ] 97 | ] 98 | ], 99 | "NUM": [ 100 | [ 101 | "#SETFC# #C# count=", 102 | "#SETFSH# #SH# count=", 103 | "#SETFM# #M# count=", 104 | "#SETFSI# #SI# count=", 105 | "#SETFL# #L# count=" 106 | ], 107 | [ 108 | "1", 109 | "2", 110 | "3" 111 | ] 112 | ], 113 | "SETFC": "#SET# #FC#", 114 | "SETFSH": "#SET# #FSH#", 115 | "SETFM": "#SET# #FM#", 116 | "SETFSI": "#SET# #FSI#", 117 | "SETFL": "#SET# #FL#", 118 | "C": [ 119 | [ 120 | "gray", 121 | "red", 122 | "blue", 123 | "green", 124 | "brown", 125 | "purple", 126 | "cyan", 127 | "yellow" 128 | ], 129 | [ 130 | "#OBJECT# #FC#" 131 | ] 132 | ], 133 | "SH": [ 134 | [ 135 | "cube", 136 | "sphere", 137 | "cylinder" 138 | ], 139 | [ 140 | "#OBJECT# #FSH#" 141 | ] 142 | ], 143 | "M": [ 144 | [ 145 | "rubber", 146 | "metal" 147 | ], 148 | [ 149 | "#OBJECT# #FM#" 150 | ] 151 | ], 152 | "SI": [ 153 | [ 154 | "large", 155 | "small" 156 | ], 157 | [ 158 | "#OBJECT# #FSI#" 159 | ] 160 | ], 161 | "L": [ 162 | [ 163 | "1", 164 | "2", 165 | "3", 166 | "4", 167 | "5", 168 | "6", 169 | "7", 170 | "8" 171 | ], 172 | [ 173 | "#OBJECT# #FL#" 174 | ] 175 | ], 176 | "FC": "color?", 177 | "FSH": "shape?", 178 | "FM": "material?", 179 | "FSI": "size?", 180 | "FL": [ 181 | "locationX?", 182 | "locationY?" 183 | ], 184 | "OBJECT": "x", 185 | "SET": [ 186 | [ 187 | [ 188 | "S", 189 | "*" 190 | ], 191 | [ 192 | "non-x-S", 193 | "*" 194 | ] 195 | ] 196 | ] 197 | }, 198 | "metadata": { 199 | "variables": [ 200 | "x", 201 | "S", 202 | "non-x-S" 203 | ], 204 | "functions": [ 205 | "FC", 206 | "FSH", 207 | "FM", 208 | "FSI", 209 | "FL" 210 | ] 211 | } 212 | } -------------------------------------------------------------------------------- /dataloaders/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import dataloaders.get_dataloader as get_dataloader 7 | import dataloaders.adhoc_data_loader as adhoc_data_loader -------------------------------------------------------------------------------- /dataloaders/build_sound_scene.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import os 7 | import math 8 | import numpy as np 9 | import random 10 | import soundfile as sf 11 | import json 12 | import torch 13 | import logging 14 | 15 | from typing import List, Dict 16 | 17 | # TODO(ramav): Consolidate these settings into the rest of the code config. 18 | THIS_DIR = os.path.dirname(os.path.realpath(__file__)) 19 | SOURCE_DIR = '/checkpoint/aszlam/audio_clips/wavs' 20 | INSTRUMENTS = ['trumpet', 'clarinet', 'violin', 'flute', 'oboe', 'saxophone', 'french-horn', 'guitar'] 21 | CMAP = {"gray": 0, 22 | "red": 1, 23 | "blue": 2, 24 | "green": 3, 25 | "brown": 4, 26 | "purple": 5, 27 | "cyan": 6, 28 | "yellow": 7} 29 | 30 | NOTES = ['A4', 'B4', 'C4', 'D4', 'E4', 'F4', 'G4', 'A5'] 31 | NOTES_ALT = ['A3', 'B3', 'C3', 'D3', 'E3', 'F3', 'G3', 'A4'] #for french-horn and guitar 32 | 33 | #get from metadata? 34 | SHAPEMAP = {"cube":0, "sphere":1, "cylinder":2} 35 | SIZEMAP = {"large":1.0, "small":.3} 36 | L = 40000 37 | TL = 120000 38 | FILTN = 1000 39 | MAX_COMP = 8 40 | NUM_POSITIONS = 8 41 | VOLUMES = len(SIZEMAP) 42 | 43 | 44 | def filt(x): 45 | xn = np.linalg.norm(x) 46 | u = np.fft.fft(x) 47 | u[FILTN:] = 0 48 | z = np.real(np.fft.ifft(u)) 49 | zn = np.linalg.norm(z) 50 | return z*(xn/zn) 51 | 52 | MATERIALSMAP = {"rubber": lambda x: x, "metal": filt} 53 | 54 | 55 | 56 | def get_bins(obj, metadata): 57 | dimensions_to_idx = metadata['dimensions_to_idx'] 58 | nbinsx = metadata["location_bins"]["x"] 59 | nbinsy = metadata["location_bins"]["y"] 60 | xbin = int(math.floor(obj["pixel_coords"][dimensions_to_idx["x"]] / 61 | (metadata["image_size"]["x"] * (1.0 / nbinsx)))) 62 | ybin = int(math.floor(obj["pixel_coords"][dimensions_to_idx["y"]] / 63 | (metadata["image_size"]["y"] * (1.0 / nbinsy)))) 64 | return xbin, ybin 65 | 66 | 67 | class ClevrJsonToSoundTensor(object): 68 | def __init__(self, metadata_path): 69 | self.TL = TL 70 | self.metadata = json.load(open(metadata_path, 'r'))['metadata'] 71 | self.clips = [] 72 | self.clip_info = [] 73 | for fname in os.listdir(SOURCE_DIR): 74 | q = fname.split('_') 75 | data, _ = sf.read(os.path.join(SOURCE_DIR, fname)) 76 | if len(data) >= L: 77 | try: 78 | i = INSTRUMENTS.index(q[0]) 79 | if q[0] == "guitar" or q[0] == "french-horn": 80 | pitch = NOTES_ALT.index(q[1]) 81 | else: 82 | pitch = NOTES.index(q[1]) 83 | self.clips.append(data[:L]) 84 | self.clip_info.append([i, pitch]) 85 | except: 86 | continue 87 | self.clips_by_prop = {} 88 | for i in range(len(INSTRUMENTS)): 89 | self.clips_by_prop[i] = {} 90 | for j in range(len(NOTES)): 91 | self.clips_by_prop[i][j] = [] 92 | 93 | for c in range(len(self.clip_info)): 94 | i = self.clip_info[c][0] 95 | pitch = self.clip_info[c][1] 96 | self.clips_by_prop[i][pitch].append(c) 97 | 98 | self.offsets = np.floor((np.linspace(0, TL - L, NUM_POSITIONS))).astype('int64') 99 | self.masks = [np.ones(L), 100 | 
np.linspace(1, 0, L)**2, 101 | np.linspace(0, 1, L)**2] 102 | 103 | # TODO(ramav): Add material into the sound creation pipeline. 104 | def __call__(self, json_objects: List[Dict]): 105 | out = np.zeros(TL) 106 | for obj in json_objects: 107 | instrument = CMAP[obj['color']] 108 | shift, pitch = get_bins(obj, self.metadata) 109 | #FIX THE EMPTY clip_by_prop bin!!!!! 110 | try: 111 | clipid = random.choice(self.clips_by_prop[instrument][pitch]) 112 | except: 113 | logging.info('warning: bad translation bc missing pitch, FIXME') 114 | continue 115 | c = self.clips[clipid] 116 | c = MATERIALSMAP[obj['material']](c) 117 | mask = self.masks[SHAPEMAP[obj['shape']]] 118 | v = SIZEMAP[obj['size']] 119 | c = c*mask*v 120 | offset = self.offsets[shift] 121 | out[offset:offset + L] = out[offset:offset + L] + c 122 | return out 123 | 124 | def generate_from_json(self, fpath): 125 | with open(fpath) as j: 126 | spec = json.load(j) 127 | return self.__call__(spec["objects"]) 128 | 129 | def generate_random(self): 130 | num_comp = np.random.randint(MAX_COMP) 131 | clip = np.zeros(TL) 132 | info = [] 133 | for i in range(num_comp): 134 | j = np.random.randint(len(self.clips)) 135 | p = np.random.randint(NUM_POSITIONS) 136 | c = self.clips[j] 137 | vid = np.random.randint(VOLUMES) 138 | v = vid/VOLUMES 139 | mid = np.random.randint(len(self.masks)) 140 | mask = self.masks[mid] 141 | c = c*mask*v 142 | offset = self.offsets[p] 143 | clip[offset:offset + L] = clip[offset:offset + L] + c 144 | info.append([self.clip_info[j][0], 145 | self.clip_info[j][1], 146 | p, 147 | mid, 148 | vid]) 149 | return clip, info 150 | 151 | 152 | 153 | if __name__ == "__main__": 154 | import visdom 155 | import json 156 | vis = visdom.Visdom(server ='http://localhost') 157 | S = ClevrJsonToSoundTensor('/private/home/aszlam/junk/clevr_typed_fol_properties.json') 158 | 159 | -------------------------------------------------------------------------------- /dataloaders/get_dataloader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | import dataloaders 7 | 8 | def load(cfg, batch_size, splits): 9 | if cfg.data.dataset == 'adhoc_concepts': 10 | ds = ( 11 | dataloaders.adhoc_data_loader.get_adhoc_loader( 12 | cfg, batch_size, splits) 13 | ) 14 | else: 15 | raise ValueError("Unknown dataset: {:s}".format(cfg.data.dataset)) 16 | 17 | return ds 18 | 19 | class GetDataloader(object): 20 | def __init__(self, splits): 21 | if "&" in splits: 22 | self._splits = splits.split(" & ") 23 | elif "," in splits: 24 | self._splits = splits.split(",") 25 | else: 26 | self._splits = [splits] 27 | 28 | def __call__(self, cfg, batch_size=None): 29 | if batch_size == None: 30 | batch_size = cfg._data.batch_size 31 | 32 | data = load(cfg, batch_size, self._splits) 33 | 34 | if cfg.data.split_type not in ["comp", "iid", "color_count", 35 | "color_location", "color_material", 36 | "color", "shape", "color_boolean", 37 | "length_threshold_10"]: 38 | raise ValueError(f"Unknown split {cfg.data.split_type}") 39 | 40 | train_loader = data.get('train') 41 | eval_loader = data.get(cfg.eval_split_name) 42 | 43 | if train_loader is not None: 44 | if len( 45 | train_loader.dataset.hypotheses_in_split.intersection( 46 | eval_loader.dataset.hypotheses_in_split) 47 | ) != 0 and (cfg.data.split_type != "iid"): 48 | if cfg.eval_split_name != "train": 49 | raise ValueError( 50 | "Expect no overlap in concepts between train and eval" 51 | "splits.") 52 | 53 | if (train_loader.dataset.vocabulary.items != 54 | eval_loader.dataset.vocabulary.items): 55 | raise ValueError( 56 | "Expect identical vocabularies in train and val.") 57 | 58 | return data 59 | -------------------------------------------------------------------------------- /dataloaders/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | """A set of utilities for dataloaders.""" 7 | import os 8 | import numpy as np 9 | import tempfile 10 | import torch 11 | import torch.utils.data as data 12 | import json 13 | import pickle 14 | 15 | from typing import Union, Tuple 16 | from torchvision.datasets.folder import has_file_allowed_extension 17 | 18 | 19 | def has_allowed_extension(f, extension): 20 | if len(f) <= len(extension): 21 | return False 22 | if f[len(f)-len(extension):] == extension: 23 | return True 24 | return False 25 | 26 | 27 | def tokenizer_programs(prog_string): 28 | """A tokenizer function for programs.""" 29 | prog_string = prog_string.replace("lambda S.", "lambdaS.") 30 | return prog_string.split(' ') 31 | 32 | 33 | def clevr_json_loader(file): 34 | with open(file, 'r') as f: 35 | json_scene_data = json.load(f) 36 | 37 | return json_scene_data['objects'] 38 | 39 | def to_tensor_sound(x): 40 | return torch.Tensor(x) 41 | 42 | def sound_loader(file): 43 | with open(file, 'rb') as f: 44 | sound_data = pickle.load(f) 45 | return sound_data 46 | 47 | def _numeric_string_array_to_numbers(numeric_string_array, cast_type="float"): 48 | if cast_type=="float": 49 | cast_fn = float 50 | elif cast_type=="int": 51 | cast_fn = int 52 | 53 | numeric_array = [] 54 | for t in numeric_string_array: 55 | numeric_array.append( 56 | np.array([cast_fn(x) for x in t.split(",")])) 57 | return numeric_array 58 | 59 | class VisionDataset(data.Dataset): 60 | _repr_indent = 4 61 | 62 | def __init__(self, 63 | root, 64 | transforms=None, 65 | transform=None, 66 | target_transform=None): 67 | if isinstance(root, torch._six.string_classes): 68 | root = os.path.expanduser(root) 69 | self.root = root 70 | 71 | has_transforms = transforms is not None 72 | has_separate_transform = transform is not None or target_transform is not None 73 | if has_transforms and has_separate_transform: 74 | raise ValueError( 75 | "Only transforms or transform/target_transform can " 76 | "be passed as argument") 77 | 78 | # for backwards-compatibility 79 | self.transform = transform 80 | self.target_transform = target_transform 81 | 82 | self.transforms = transforms 83 | 84 | def __getitem__(self, index): 85 | raise NotImplementedError 86 | 87 | def __len__(self): 88 | raise NotImplementedError 89 | 90 | def __repr__(self): 91 | head = "Dataset " + self.__class__.__name__ 92 | body = ["Number of datapoints: {}".format(self.__len__())] 93 | if self.root is not None: 94 | body.append("Root location: {}".format(self.root)) 95 | body += self.extra_repr().splitlines() 96 | if hasattr(self, "transforms") and self.transforms is not None: 97 | body += [repr(self.transforms)] 98 | lines = [head] + [" " * self._repr_indent + line for line in body] 99 | return '\n'.join(lines) 100 | 101 | def _format_transform_repr(self, transform, head): 102 | lines = transform.__repr__().splitlines() 103 | return (["{}{}".format(head, lines[0])] + 104 | ["{}{}".format(" " * len(head), line) for line in lines[1:]]) 105 | 106 | def extra_repr(self): 107 | return "" 108 | 109 | 110 | class ImageAccess(object): 111 | def __init__(self, 112 | root_dir, 113 | image_string="ADHOC_train_%.8d.png", 114 | file_pattern=".png", 115 | debug=False): 116 | all_data = make_adhoc_dataset_with_buffer( 117 | root_dir, extensions=tuple([file_pattern])) 118 | self._all_data = all_data 119 | 120 | def __call__(self, image_id): 121 | return self._all_data[image_id] 122 | 123 | 124 | def get_validity_fn(extensions: Union[Tuple[str], str]): 125 | """Return a function that assesses validity of files based on 
extensions. 126 | 127 | Args: 128 | extensions: Tuple of str, or str, where each element is a possible extension 129 | Returns: 130 | A function that checks if a filepath is valid 131 | Raises: 132 | ValueError if extensions is not a list 133 | """ 134 | def is_valid_file(x): 135 | return has_file_allowed_extension(x, extensions) 136 | #return has_allowed_extension(x, extensions) 137 | 138 | return is_valid_file 139 | 140 | 141 | def make_adhoc_dataset_with_buffer(dir, 142 | extensions=None, 143 | is_valid_file=None, 144 | buffer_threshold=100000): 145 | """Implements a buffered way to create a folderdataset. 146 | 147 | A modification of `torchvision.datasets.folder.make_dataset` that uses a 148 | buffer mechanism for faster performance when the number of files in the 149 | dataset can potentially be very very large. 150 | 151 | Args: 152 | dir: Str, Directory where the dataset exists 153 | extensions: List of Str 154 | is_valid_file: Function 155 | buffer_threshold: Int, number of entries to flush the buffer with 156 | Returns: 157 | all_data: list of tuple of path to object and class index. 158 | Raises: 159 | ValueError: If both extensions and is_valid_file are None or if dataset 160 | files are not in the format x_y_z.extension or if we repeat an index that 161 | has already been processed 162 | RuntimeError: If the system command pfind fails 163 | """ 164 | dir = os.path.expanduser(dir) 165 | if not ((extensions is None) ^ (is_valid_file is None)): 166 | raise ValueError( 167 | "Both extensions and is_valid_file cannot be None or not None at the same time" 168 | ) 169 | if extensions is not None: 170 | is_valid_file = get_validity_fn(extensions) 171 | 172 | all_data = {} 173 | buffer = {} 174 | 175 | _, tfile = tempfile.mkstemp() 176 | ret = os.system('pfind %s > %s' % (dir, tfile)) 177 | if ret == 0: 178 | with open(tfile, 'r') as f: 179 | file_list = [x.rstrip() for x in f.readlines()] 180 | os.system('rm %s' % (tfile)) 181 | else: 182 | raise RuntimeError("System command pfind failed.") 183 | 184 | for path in sorted(file_list): 185 | if is_valid_file(path): 186 | fname = path.split('/')[-1] 187 | idx = fname.split('_') 188 | 189 | if len(idx) != 3: 190 | raise ValueError("Unexpected file format.") 191 | 192 | idx = int(fname.split('_')[-1].split('.')[0]) 193 | if idx in buffer.keys(): 194 | raise ValueError("Index already processed.") 195 | buffer[idx] = path 196 | 197 | if len(buffer) > buffer_threshold: 198 | all_data = {**all_data, **buffer} 199 | del buffer 200 | buffer = {} 201 | 202 | all_data = {**all_data, **buffer} 203 | 204 | return all_data 205 | 206 | 207 | class DatasetFolderPathIndexing(VisionDataset): 208 | def __init__(self, 209 | root, 210 | loader, 211 | extensions=None, 212 | transform=None, 213 | target_transform=None, 214 | is_valid_file=None): 215 | super(DatasetFolderPathIndexing, 216 | self).__init__(root, 217 | transform=transform, 218 | target_transform=target_transform) 219 | samples = make_adhoc_dataset_with_buffer(self.root, extensions, 220 | is_valid_file) 221 | if len(samples) == 0: 222 | raise (RuntimeError("Found 0 files in subfolders of: " + 223 | self.root + "\n" 224 | "Supported extensions are: " + 225 | ",".join(extensions))) 226 | 227 | self.loader = loader 228 | self.extensions = extensions 229 | 230 | self.samples = samples 231 | 232 | def __getitem__(self, index): 233 | """ 234 | Args: 235 | index (int): Index 236 | 237 | Returns: 238 | tuple: (sample, target) where target is class_index of the target class. 
239 | """ 240 | path = self.samples[index] 241 | sample = self.loader(path) 242 | if self.transform is not None: 243 | sample = self.transform(sample) 244 | return {"datum": sample, "path": path} 245 | 246 | def get_item_list(self, index_list): 247 | return [self.__getitem__(x) for x in index_list] 248 | 249 | def __len__(self): 250 | return len(self.samples) 251 | -------------------------------------------------------------------------------- /dataloaders/vocabulary.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """A vocabulary class definition. 7 | 8 | This code largely builds on the vocabulary class definition in FairSeq. 9 | """ 10 | import json 11 | import logging 12 | import numpy as np 13 | import torch 14 | 15 | from collections import Counter 16 | 17 | class Vocabulary(object): 18 | def __init__( 19 | self, 20 | all_sentences: list, 21 | tokenizer, 22 | pad='', 23 | eos='', 24 | unk='', 25 | bos='', 26 | ): 27 | """Initialize the vocabulary""" 28 | tokenized_sentences = [tokenizer(x) for x in all_sentences] 29 | vocabulary = [] 30 | 31 | vocabulary.append(pad) 32 | vocabulary.append(eos) 33 | vocabulary.append(unk) 34 | vocabulary.append(bos) 35 | 36 | # Closed world assumption, cannot see a sentence longer than what has been 37 | # shown at training time. 38 | max_length_dataset = np.max([len(x) for x in tokenized_sentences]) 39 | vocabulary.extend( 40 | sorted( 41 | Counter([x for y in tokenized_sentences for x in y]).keys())) 42 | 43 | logging.info("%d tokens found in the dataset." % (len(vocabulary))) 44 | 45 | self._items = vocabulary 46 | self._tokenizer = tokenizer 47 | # Add +2 below for bos and eos tokens. 48 | self._max_length_dataset = max_length_dataset + 2 49 | self._item_to_idx = {v: k for k, v in enumerate(vocabulary)} 50 | self._idx_to_item = {k: v for k, v in enumerate(vocabulary)} 51 | self._pad_string = pad 52 | self._eos_string = eos 53 | self._unk_string = unk 54 | self._bos_string = bos 55 | 56 | def index(self, sym): 57 | """Returns the index of the specified symbol""" 58 | assert isinstance(sym, str) 59 | if sym in self._items: 60 | return self._item_to_idx[sym] 61 | return self.unk() 62 | 63 | def encode_string(self, 64 | input_string, 65 | append_eos=True, 66 | add_start_token=True): 67 | tokenized_string = self._tokenizer(input_string) 68 | n_words = len(tokenized_string) 69 | 70 | additional_tokens = 0 71 | if add_start_token is True: 72 | additional_tokens += 1 73 | 74 | if append_eos is True: 75 | additional_tokens += 1 76 | 77 | if n_words + additional_tokens > self._max_length_dataset: 78 | raise ValueError("String is too long to encode.") 79 | ids = torch.IntTensor(self._max_length_dataset) 80 | if add_start_token is True: 81 | ids[0] = self.bos() 82 | 83 | for i, w in enumerate(tokenized_string): 84 | ids[i + additional_tokens - 1] = self.index(w) 85 | 86 | if append_eos == True: 87 | ids[len(tokenized_string) + additional_tokens - 1] = self.eos() 88 | 89 | for i in range( 90 | len(tokenized_string) + additional_tokens, 91 | self._max_length_dataset): 92 | ids[i] = self.pad() 93 | 94 | return ids 95 | 96 | def decode_string(self, tensor, escape_unk=False): 97 | """Helper for converting a tensor of token indices to a string. 98 | 99 | Can cfgionally remove BPE symbols or escape words. 
100 | """ 101 | if torch.is_tensor(tensor) and tensor.dim() == 2: 102 | return '\n'.join(self.string(t, escape_unk) for t in tensor) 103 | 104 | def token_string(i): 105 | if i == self.unk(): 106 | return self.unk_string(escape_unk) 107 | else: 108 | return self._idx_to_item[int(i)] 109 | 110 | sent = ' '.join(token_string(i) for i in tensor) 111 | return sent 112 | 113 | def __len__(self): 114 | return len(self._item_to_idx) 115 | 116 | def bos(self): 117 | """Helper to get index of beginning-of-sentence symbol""" 118 | return self._item_to_idx[self._bos_string] 119 | 120 | def pad(self): 121 | """Helper to get index of pad symbol""" 122 | return self._item_to_idx[self._pad_string] 123 | 124 | def eos(self): 125 | """Helper to get index of end-of-sentence symbol""" 126 | return self._item_to_idx[self._eos_string] 127 | 128 | def unk(self): 129 | """Helper to get index of unk symbol""" 130 | return self._item_to_idx[self._unk_string] 131 | 132 | def bos_string(self): 133 | """Helper to get index of beginning-of-sentence symbol""" 134 | return self._bos_string 135 | 136 | def pad_string(self): 137 | """Helper to get index of pad symbol""" 138 | return self._pad_string 139 | 140 | def eos_string(self): 141 | """Helper to get index of end-of-sentence symbol""" 142 | return self._eos_string 143 | 144 | def unk_string(self, escape_unk=False): 145 | """Helper to get index of unk symbol""" 146 | if escape_unk is True: 147 | return '' 148 | return self._unk_string 149 | 150 | @property 151 | def items(self): 152 | return self._items 153 | 154 | 155 | class ClevrJsonToTensor(object): 156 | def __init__(self, properties_file_path): 157 | with open(properties_file_path, "r") as f: 158 | properties_json = json.load(f) 159 | metadata = properties_json["metadata"] 160 | properties_vocabulary = [] 161 | 162 | # NOTE: The notion of categorical properties here is based on the 163 | # CLEVR JSON format, and and not based on the language of thought. 164 | # This is in many ways a deliberate choice. What is discrete vs not 165 | # in the language of thought is something that learning should 166 | # handle, ideally. 167 | cateogrical_properties = [] 168 | 169 | for key in properties_json["properties"]: 170 | properties_vocabulary.extend( 171 | list(properties_json["properties"][key].keys())) 172 | # Properties file has plural forms of the properties listed. 173 | # Make it singular. 
174 | cateogrical_properties.append(key.rstrip("s")) 175 | 176 | self._dimension_to_axis_name = {v: k for k, v in properties_json[ 177 | "metadata"]["dimensions_to_idx"].items()} 178 | self._max_objects_in_scene = properties_json["metadata"]["max_objects"] 179 | 180 | self._vocabulary = properties_vocabulary 181 | self._categorical_properties = cateogrical_properties 182 | self._metadata = metadata 183 | self._idx_to_value = {} 184 | self._word_to_idx = {v: k for k, v in enumerate(self._vocabulary)} 185 | self._BAN_FROM_ENCODING="3d_coords" 186 | 187 | for idx in range(len(self._vocabulary)): 188 | self._idx_to_value[idx] = 1 189 | 190 | self._word_to_idx["pixel_coords_x"] = len(self._word_to_idx) 191 | self._idx_to_value[len(self._word_to_idx)-1] = lambda x: x/float( 192 | metadata["image_size"]["x"]) 193 | 194 | self._word_to_idx["pixel_coords_y"] = len(self._word_to_idx) 195 | self._idx_to_value[len(self._word_to_idx)-1] = lambda x: x/float( 196 | metadata["image_size"]["y"] 197 | ) 198 | 199 | self._word_to_idx["pixel_coords_z"] = len(self._word_to_idx) 200 | self._idx_to_value[len(self._word_to_idx)-1] = lambda x: x/float( 201 | metadata["image_size"]["z"]) 202 | 203 | self._word_to_idx["rotation"] = len(self._word_to_idx) 204 | self._idx_to_value[len(self._word_to_idx)-1] = lambda x: x/float( 205 | metadata["max_rotation"] 206 | ) 207 | 208 | 209 | def _encode(self, obj): 210 | 211 | flat_obj = {} 212 | for prop, value in obj.items(): 213 | if prop == "pixel_coords": 214 | for loc_idx in range(len(value)): 215 | word_string = (prop + "_" + 216 | self._dimension_to_axis_name[loc_idx]) 217 | flat_obj[word_string] = value[loc_idx] 218 | elif prop not in self._BAN_FROM_ENCODING: 219 | flat_obj[prop] = value 220 | sorted_keys = sorted(list(flat_obj.keys())) 221 | overall_vec = [] 222 | for prop in sorted_keys: 223 | value = flat_obj[prop] 224 | vec = torch.zeros(len(self._word_to_idx)) 225 | if prop in self._categorical_properties: 226 | vec[self._word_to_idx[value]] = self._idx_to_value[ 227 | self._word_to_idx[value]] 228 | else: 229 | vec[self._word_to_idx[prop]] = self._idx_to_value[ 230 | self._word_to_idx[prop]](value) 231 | overall_vec.append(vec) 232 | 233 | return torch.stack(overall_vec, dim=0).sum(0) 234 | 235 | def __call__(self, scene): 236 | x = [] 237 | if len(scene) > self.max_objects_in_scene: 238 | raise ValueError( 239 | "Number of objects in scene greater than max objects in scene.") 240 | 241 | for obj in scene: 242 | x.append(self._encode(obj)) 243 | 244 | for _ in range(len(scene), self.max_objects_in_scene): 245 | x.append(-1 * torch.ones_like(x[0])) 246 | 247 | return torch.stack(x) 248 | 249 | @property 250 | def max_objects_in_scene(self): 251 | return self._max_objects_in_scene -------------------------------------------------------------------------------- /docs/model_generalization.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/productive_concept_learning/d00a3b85215bff0686ad3f834d50c8eae374f06a/docs/model_generalization.pdf -------------------------------------------------------------------------------- /docs/model_generalization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/productive_concept_learning/d00a3b85215bff0686ad3f834d50c8eae374f06a/docs/model_generalization.png -------------------------------------------------------------------------------- /docs/rel_net_alternate.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/productive_concept_learning/d00a3b85215bff0686ad3f834d50c8eae374f06a/docs/rel_net_alternate.png -------------------------------------------------------------------------------- /hydra_cfg/data_files/v2_typed_simple_fol_depth_6_trials_2000000_ban_1_max_scene_id_200.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # Copyright (c) Facebook, Inc. and its affiliates. 7 | # All rights reserved. 8 | # 9 | # This source code is licensed under the license found in the 10 | # LICENSE file in the root directory of this source tree. 11 | filetype: v2_typed_simple_fol_depth_6_trials_2000000_ban_1_max_scene_id_200 12 | 13 | train: "${data.split_type}_sampling_${data.hypothesis_prior}_train_threshold_\ 14 | 0.10_pos_im_${data.num_positives}_neg\ 15 | _im_${data.num_negatives}_train_examples_${data.train_examples}\ 16 | _neg_type_${data.negative_type}_\ 17 | alternate_hypo_1_random_seed_42.pkl" 18 | 19 | val: "${data.split_type}_sampling_${data.hypothesis_prior}_val_threshold_\ 20 | 0.10_pos_im_${data.num_positives}_neg_im_${data.num_negatives}_train_examples_\ 21 | ${data.train_examples}_neg_type_${data.negative_type}_\ 22 | alternate_hypo_1_random_seed_42.pkl" 23 | 24 | test: "${data.split_type}_sampling_${data.hypothesis_prior}_test\ 25 | _threshold_0.10_pos_im_${data.num_positives}_neg_im_${data.num_negatives}_train_examples_\ 26 | ${data.train_examples}_neg_type_${data.negative_type}\ 27 | _alternate_hypo_1_random_seed_42.pkl" 28 | 29 | qualitative: "qualitative_eval_inputs_for_hierarchy.pkl" 30 | 31 | cross_split: "${data.num_positives}_0.10_hypotheses_heavy.json" 32 | cross_split_hypothesis_image_mapping: "${data.num_positives}_0.10_image_hyp_mapping.json" 33 | -------------------------------------------------------------------------------- /hydra_cfg/experiment.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | defaults: 7 | # Expects to see this as the name of the folder containing the dataset. 8 | - data_files: v2_typed_simple_fol_depth_6_trials_2000000_ban_1_max_scene_id_200 9 | - task: adhoc_concepts 10 | - modality: image 11 | - pooling: concat 12 | - mode: train 13 | - special_rules: ${defaults.2.modality}_${defaults.3.pooling} 14 | optional: true 15 | 16 | # Points to the raw images, jsons, sounds etc. 17 | raw_data: 18 | data_dir: ${env:CURI_DATA_PATH} 19 | image_path: ${raw_data.data_dir}"/images" 20 | json_path: ${raw_data.data_dir}"/scenes" 21 | audio_path: ${raw_data.data_dir}"/sound_scenes" 22 | properties_file_path: ${env:PWD}/concept_data/clevr_typed_fol_properties.json 23 | 24 | data: 25 | dataset: "adhoc_concepts" 26 | split_type: "comp" 27 | negative_type: "alternate_hypotheses" # Only "alternate_hypothesis", "random" 28 | train_examples: 500000 29 | path: ${raw_data.data_dir}/hypotheses/${filetype} 30 | hypothesis_prior: "log_linear" 31 | num_negatives: 20 32 | num_positives: 5 33 | positive_threshold: 0.10 # Needs to be in %.2f format. 
TODO(ramav): Remove hardcoding of this. 34 | map_eval_num_images_per_concept: 3 35 | 36 | data_args: 37 | class: dataloaders.get_dataloader.GetDataloader 38 | params: 39 | splits: ${splits} 40 | 41 | model: 42 | name: protonet 43 | class: models.protonet.GetProtoNetModel 44 | params: 45 | feature_dim: 256 46 | obj_fdim: ${_model.obj_fdim} 47 | pooling: ${_model.pooling} # "global_average_pooling" Or "rel_net" Or "concat" Or "trnsf" 48 | modality: ${_data.modality} 49 | pretrained_encoder: False 50 | num_classes: ${num_classes} 51 | language_alpha: ${loss.params.alpha} 52 | input_dim: ${input_dim} 53 | init_to_use_pooling: ${_model.pooling_init} 54 | use_batch_norm_rel_net: ${_modality.use_batch_norm_rel_net} 55 | pairwise_position_encoding: ${_model.rel_pos_enc} 56 | absolute_position_encoding_for_pooling: ${_model.abs_pos_enc} 57 | absolute_position_encoding_for_modality: ${_modality.abs_pos_enc} 58 | im_fg: True 59 | 60 | opt: 61 | max_steps: 1000000 62 | checkpoint_every: 30000 63 | lr_gamma: 0.5 64 | lr_patience: 10 65 | num_workers: 10 66 | weight_decay: False 67 | 68 | loss: 69 | name: "nll" 70 | class: losses.NegativeLogLikelihoodMultiTask 71 | params: 72 | alpha: 0.1 73 | pad_token_idx: -10 # Will always be determined at runtime. 74 | num_classes: ${num_classes} 75 | 76 | costly_loss: 77 | name: "map" 78 | class: losses.MetaLearningMeanAveragePrecision 79 | 80 | device: "cuda" 81 | 82 | job_replica: 0 # Used to set the replica for running multiple jobs with same params. 83 | 84 | hydra: 85 | sweep: 86 | dir: ${env:RUN_DIR}/${hydra.job.name} 87 | subdir: ${hydra.job.override_dirname}/${job_replica} 88 | 89 | job: 90 | config: 91 | override_dirname: 92 | exclude_keys: ["job_replica", "mode", "opt.max_steps", 93 | "opt.checkpoint_every", "model_or_oracle_metrics", 94 | "eval_cfg.evaluate_once", "val", 95 | "splits", "eval_split_name", 96 | "test", "train", "eval_cfg.write_raw_metrics", 97 | "eval_cfg.evaluate_all", "eval_cfg.best_test_metric", 98 | "eval_cfg.sort_validation_checkpoints" 99 | ] 100 | 101 | run: 102 | dir: ${env:RUN_DIR}/${hydra.job.name}/${hydra.job.override_dirname} 103 | -------------------------------------------------------------------------------- /hydra_cfg/modality/image.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | input_dim: "3, 160, 240" 7 | _data: 8 | batch_size: 8 9 | map_eval_batch_size: 8 10 | modality: "image" 11 | 12 | _modality: 13 | abs_pos_enc: True 14 | use_batch_norm_rel_net: True 15 | 16 | opt: 17 | lr: 1e-4 -------------------------------------------------------------------------------- /hydra_cfg/modality/json.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | input_dim: "5, 19" 7 | _data: 8 | batch_size: 64 9 | map_eval_batch_size: 8 10 | modality: "json" 11 | 12 | _modality: 13 | abs_pos_enc: False 14 | use_batch_norm_rel_net: True # This is actually enabling batch norm. 
15 | 16 | opt: 17 | lr: 1e-3 -------------------------------------------------------------------------------- /hydra_cfg/mode/eval.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | _mode: "eval" 7 | mem_requirement: 80GB 8 | splits: "cross_split & val" 9 | 10 | model_or_oracle_metrics: "model" 11 | eval_split_name: "val" # "val" or "test" 12 | 13 | eval_cfg: 14 | write_raw_metrics: True 15 | sort_validation_checkpoints: False 16 | evaluate_all: False 17 | evaluate_once: False 18 | best_test_metric: "modelmap" -------------------------------------------------------------------------------- /hydra_cfg/mode/train.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | _mode: "train" 7 | mem_requirement: 80GB 8 | splits: "train & val" 9 | 10 | # Unused, just needs to be there. 11 | eval_split_name: "val" # "val" or "test" 12 | 13 | model_or_oracle_metrics: "model" 14 | evaluate_once: False -------------------------------------------------------------------------------- /hydra_cfg/pooling/concat.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | _model: 7 | pooling: "concat" 8 | abs_pos_enc: True 9 | rel_pos_enc: False 10 | obj_fdim: 64 # Only for sound and image 11 | pooling_init: "xavier" -------------------------------------------------------------------------------- /hydra_cfg/pooling/gap.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | _model: 7 | pooling: "gap" 8 | abs_pos_enc: True 9 | rel_pos_enc: False 10 | obj_fdim: 64 # Only for sounds and image 11 | pooling_init: "xavier" -------------------------------------------------------------------------------- /hydra_cfg/pooling/rel_net.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | _model: 7 | pooling: "rel_net" 8 | abs_pos_enc: False 9 | rel_pos_enc: True 10 | obj_fdim: 64 # Only for sound and image 11 | pooling_init: "xavier" -------------------------------------------------------------------------------- /hydra_cfg/pooling/trnsf.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | _model: 7 | pooling: "trnsf" 8 | abs_pos_enc: True 9 | rel_pos_enc: False 10 | obj_fdim: 64 11 | pooling_init: "bert" -------------------------------------------------------------------------------- /hydra_cfg/special_rules/json_concat.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | _model: 7 | obj_fdim: 96 -------------------------------------------------------------------------------- /hydra_cfg/special_rules/json_gap.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | _model: 7 | obj_fdim: 96 -------------------------------------------------------------------------------- /hydra_cfg/special_rules/json_rel_net.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | _model: 7 | obj_fdim: 96 8 | rel_pos_enc: False -------------------------------------------------------------------------------- /hydra_cfg/special_rules/json_trnsf.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | _model: 7 | obj_fdim: 96 8 | -------------------------------------------------------------------------------- /hydra_cfg/special_rules/sound_concat.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | opt: 7 | lr: 5e-4 -------------------------------------------------------------------------------- /hydra_cfg/task/adhoc_concepts.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | num_classes: 2 -------------------------------------------------------------------------------- /hydra_eval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # Eval script for running adhoc categorization. 
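#
# Illustrative invocation (assembled from launch_test_jobs.sh; the override
# values are examples, not defaults):
#
#   python hydra_eval.py mode=eval splits='test & cross_split' \
#       eval_split_name=test eval_cfg.best_test_metric=modelmap \
#       model_or_oracle_metrics=model modality=json pooling=trnsf \
#       data.split_type=comp data.negative_type=alternate_hypotheses \
#       loss.params.alpha=1.0 job_replica=0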
7 | import hydra 8 | import logging 9 | import os 10 | 11 | import dataloaders 12 | import losses 13 | import models 14 | 15 | from collections import defaultdict 16 | from tqdm import tqdm 17 | from time import time 18 | from time import sleep 19 | 20 | from dataloaders.get_dataloader import GetDataloader 21 | from models._trainer import _Trainer 22 | from models._evaluator import _Evaluator 23 | from models._map_evaluator import _MapEvaluator 24 | 25 | _PRINT_EVERY = 10 26 | _NUM_EVAL_BATCHES_FOR_SPLIT = defaultdict(lambda: None) 27 | _NUM_EVAL_BATCHES_FOR_SPLIT["train"] = 625 28 | 29 | 30 | def load_model(cfg, vocabulary): 31 | get_model_helper = hydra.utils.instantiate(cfg.model) 32 | return get_model_helper(vocabulary) 33 | 34 | 35 | class _Workplace(object): 36 | def __init__(self, cfg): 37 | self.cfg = cfg 38 | self._eval_split_name = self.cfg.eval_split_name 39 | 40 | if self._eval_split_name == "qualitative": 41 | raise ValueError("Use hydra_qualitative.py for qualitative eval.") 42 | 43 | if self._eval_split_name not in self.cfg.splits.split(' & '): 44 | raise ValueError("Need to load the evaluation split.") 45 | 46 | # Initialize the data loader. 47 | get_data_loader = hydra.utils.instantiate(self.cfg.data_args) 48 | dataloaders = get_data_loader(self.cfg, self.cfg._data.map_eval_batch_size) 49 | 50 | # Train, val and test, all vocabularies are the same. 51 | self._vocabulary = dataloaders[self._eval_split_name].dataset.vocabulary 52 | logging.warn("Please ensure train, val and test " 53 | "loaders give the same vocabulary.") 54 | 55 | # Initialize the model/trainer 56 | model = load_model(self.cfg, self._vocabulary) 57 | loss_fn = hydra.utils.instantiate(config=self.cfg.loss, 58 | pad_token_idx=self._vocabulary.pad()) 59 | 60 | if self.cfg.model_or_oracle_metrics == "oracle": 61 | self._best_test_metric_or_oracle = "oracle" 62 | else: 63 | self._best_test_metric_or_oracle = self.cfg.eval_cfg.best_test_metric.replace('/', '_') 64 | 65 | 66 | # If evaluating on test mention how the validation checkpoint was chosen. 67 | # replace '/' with '_' if we have a best metric like modelmetrics/acc 68 | if self._eval_split_name == "test": 69 | write_metrics_file = (self._best_test_metric_or_oracle + "_" + 70 | self.cfg.get(self._eval_split_name).rstrip(".pkl") 71 | + "_metrics.txt") 72 | else: 73 | write_metrics_file = (self.cfg.get(self._eval_split_name).rstrip(".pkl") 74 | + "_metrics.txt") 75 | 76 | # If we are not computing model metrics, then the best test metric used 77 | # to choose validation points is irrelevant 78 | trainer = _Trainer(config=self.cfg, 79 | dataloader=dataloaders.get("train"), 80 | models={"model": model}, 81 | loss_fn=loss_fn, 82 | serialization_dir=os.getcwd(), 83 | write_metrics_file=write_metrics_file) 84 | 85 | evaluator = _Evaluator( 86 | config=self.cfg, 87 | loss_fn=loss_fn, 88 | dataloader=dataloaders[self._eval_split_name], 89 | models={"model": model}, 90 | ) 91 | 92 | # Two kinds of evaluators: cheap and costly. 93 | costly_loss_fn = hydra.utils.instantiate(self.cfg.costly_loss) 94 | costly_evaluator = _MapEvaluator( 95 | config=self.cfg, 96 | loss_fn=costly_loss_fn, 97 | test_loader=dataloaders["cross_split"], 98 | dataloader=dataloaders[self._eval_split_name], 99 | models={"model": model}, 100 | ) 101 | 102 | self._trainer = trainer 103 | self._evaluator = evaluator 104 | self._costly_evaluator = costly_evaluator 105 | 106 | def run_eval(self): 107 | # Iterate over all the checkpoints. 
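# The loop below operates in one of three modes:
#   * oracle metrics: the oracle / weak_oracle / random baselines computed
#     above are written out once and the job exits;
#   * eval_cfg.evaluate_all: saved checkpoints are popped one by one (sorted
#     or shuffled) until none remain or an already-evaluated checkpoint is
#     reached;
#   * otherwise: the latest checkpoint (or, for the test split, the best
#     validation checkpoint) is evaluated, sleeping _WAIT seconds between
#     polls and terminating after 10 sleeps without a new checkpoint.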
108 | current_iteration = -1 109 | active_iteration = -1 110 | num_sleep = 0 111 | _WAIT=7200 112 | 113 | # Compute metrics that do not depend on a model. 114 | if self.cfg.model_or_oracle_metrics == "oracle": 115 | all_oracle_baselines = {} 116 | 117 | # Compute oracle metrics for the query accuracy metric 118 | logging.info("Computing query accuracy oracles and baselines.") 119 | for eval_object in ["weak_oracle", "random", "oracle"]: 120 | all_oracle_baselines["acc_" + eval_object] = self._evaluator.evaluate( 121 | eval_object=eval_object, 122 | num_batches=_NUM_EVAL_BATCHES_FOR_SPLIT[self._eval_split_name]) 123 | 124 | logging.info(f"Completed {eval_object} evaluation.") 125 | 126 | # Compute oracle metrics for the map metric 127 | logging.info("Computing mAP oracles and baselines.") 128 | for eval_object in ["weak_oracle", "random", "oracle"]: 129 | all_oracle_baselines["map_" + eval_object] = self._costly_evaluator.evaluate( 130 | eval_object=eval_object, 131 | num_batches=_NUM_EVAL_BATCHES_FOR_SPLIT[self._eval_split_name]) 132 | 133 | logging.info(f"Completed {eval_object} evaluation.") 134 | 135 | if self.cfg.model_or_oracle_metrics == "model" and self.cfg.eval_cfg.evaluate_all == True: 136 | all_checkpoint_paths_and_idx = ( 137 | self._trainer._checkpoint_manager.all_checkpoints( 138 | sort_iterations=self.cfg.eval_cfg.sort_validation_checkpoints, 139 | random_shuffle=not self.cfg.eval_cfg.sort_validation_checkpoints)) 140 | if not isinstance(all_checkpoint_paths_and_idx, list): 141 | raise ValueError("Not enough checkpoints to evaluate.") 142 | 143 | if self.cfg.eval_cfg.evaluate_once == True: 144 | raise ValueError("Evaluate once and evaluate all cannot be true at once.") 145 | 146 | while(True): 147 | if self.cfg.model_or_oracle_metrics == "model": 148 | if self.cfg.eval_cfg.evaluate_all == True: 149 | active_checkpoint, active_iteration = all_checkpoint_paths_and_idx.pop() 150 | else: 151 | # Test set is always evaluated with the best checkpoint from validation. 152 | if self._eval_split_name == "test": 153 | active_checkpoint, active_iteration = ( 154 | self._trainer._checkpoint_manager.best_checkpoint( 155 | based_on_metric=self._best_test_metric_or_oracle) 156 | ) 157 | else: 158 | active_checkpoint, active_iteration = ( 159 | self._trainer._checkpoint_manager.latest_checkpoint) 160 | 161 | # Active iteration is None when we are evaluating the best checkpoint. 162 | if active_iteration is None and self._eval_split_name != "test": 163 | raise ValueError("Expect active_iteration to not be None.") 164 | else: 165 | active_checkpoint, active_iteration = None, None 166 | 167 | if active_iteration is None or active_iteration != current_iteration: 168 | all_metrics = {"model": {}, "metric": []} 169 | if self.cfg.model_or_oracle_metrics == "model": 170 | logging.info(f"Evaluating checkpoint {active_checkpoint}") 171 | self._trainer.load_checkpoint(active_checkpoint) 172 | all_metrics = self._costly_evaluator.evaluate( 173 | num_batches=_NUM_EVAL_BATCHES_FOR_SPLIT[self._eval_split_name]) 174 | cheap_metrics = self._evaluator.evaluate( 175 | num_batches=_NUM_EVAL_BATCHES_FOR_SPLIT[self._eval_split_name]) 176 | # Combine all model metrics and write them together. 177 | all_metrics["model"].update(cheap_metrics["model"]) 178 | 179 | if self.cfg.model_or_oracle_metrics == "oracle": 180 | for _, oracle_or_baseline in all_oracle_baselines.items(): 181 | all_metrics["model"].update(oracle_or_baseline["model"]) 182 | 183 | # Serialize metric values for plotting. 
184 | # TODO(ramav): Make changes so that oracle jobs can also write 185 | # raw metrics. Something that is currently not possible. 186 | is_repeated_checkpoint = self._trainer.write_metrics(all_metrics, 187 | eval_split_name=self._eval_split_name, 188 | test_metric_name=self._best_test_metric_or_oracle, 189 | write_raw_metrics= 190 | (self._best_test_metric_or_oracle == "model" 191 | and self.cfg.eval_cfg.write_raw_metrics) 192 | ) 193 | 194 | # If we started evaluating checkpoints in descending order and we 195 | # hit a checkpoint that has already been evaluated, then stop the job. 196 | if is_repeated_checkpoint == True and self.cfg.eval_cfg.evaluate_all == True and ( 197 | self.cfg.eval_cfg.sort_validation_checkpoints == True 198 | ): 199 | logging.info("Reached a checkpoint that has already been " 200 | "evaluated, stopping eval now.") 201 | break 202 | 203 | 204 | current_iteration = active_iteration 205 | num_sleep = 0 206 | 207 | if (self.cfg.model_or_oracle_metrics == "oracle" or 208 | self.cfg.eval_cfg.evaluate_once == True or 209 | self._eval_split_name == "test"): 210 | logging.info("Finished evaluation.") 211 | break 212 | 213 | if self.cfg.eval_cfg.evaluate_all == True: 214 | if len(all_checkpoint_paths_and_idx) == 0: 215 | logging.info("No checkpoints left to evaluate. Finished evaluation.") 216 | break 217 | elif self.cfg.eval_cfg.evaluate_all == False: 218 | logging.info(f"Sleeping for {_WAIT} sec waiting for checkpoint.") 219 | sleep(_WAIT) 220 | num_sleep += 1 221 | 222 | if num_sleep == 10: 223 | logging.info(f"Terminating job after waiting for a new checkpoint.") 224 | break 225 | 226 | 227 | @hydra.main(config_path='hydra_cfg/experiment.yaml') 228 | def main(cfg): 229 | logging.info(cfg.pretty()) 230 | logging.info("Base Directory: %s", os.getcwd()) 231 | 232 | if cfg._mode != "eval": 233 | raise ValueError("Invalid mode %s" % cfg._mode) 234 | 235 | workplace = _Workplace(cfg) 236 | workplace.run_eval() 237 | 238 | 239 | if __name__ == "__main__": 240 | from hypothesis_generation.hypothesis_utils import MetaDatasetExample 241 | from hypothesis_generation.hypothesis_utils import HypothesisEval 242 | main() 243 | -------------------------------------------------------------------------------- /hydra_qualitative.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # Qualitative examples for running adhoc categorization. 
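#
# Illustrative invocation (override names follow hydra_eval.py and
# launch_test_jobs.sh; the values shown are examples):
#
#   python hydra_qualitative.py mode=eval eval_split_name=qualitative \
#       splits='qualitative & cross_split' eval_cfg.best_test_metric=modelmap \
#       modality=image pooling=trnsf
#
# The script requires eval_split_name=qualitative and writes raw predictions
# plus held-out image paths to a pickle instead of metric files.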
7 | import hydra 8 | import logging 9 | import os 10 | import pickle 11 | 12 | import dataloaders 13 | import losses 14 | import models 15 | 16 | from collections import defaultdict 17 | from tqdm import tqdm 18 | from time import time 19 | from time import sleep 20 | 21 | from dataloaders.get_dataloader import GetDataloader 22 | from models._trainer import _Trainer 23 | from models._evaluator import _Evaluator 24 | from models._map_evaluator import _MapEvaluator 25 | 26 | _PRINT_EVERY = 10 27 | _NUM_EVAL_BATCHES_FOR_SPLIT = defaultdict(lambda: None) 28 | _NUM_EVAL_BATCHES_FOR_SPLIT["train"] = 625 29 | 30 | 31 | def load_model(cfg, vocabulary): 32 | get_model_helper = hydra.utils.instantiate(cfg.model) 33 | return get_model_helper(vocabulary) 34 | 35 | 36 | class _Workplace(object): 37 | def __init__(self, cfg): 38 | self.cfg = cfg 39 | self._eval_split_name = self.cfg.eval_split_name 40 | if self._eval_split_name != "qualitative": 41 | raise ValueError("Expect qualitative split.") 42 | 43 | if self._eval_split_name not in self.cfg.splits.split(' & '): 44 | raise ValueError("Need to load the evaluation split.") 45 | 46 | # Initialize the data loader. 47 | get_data_loader = hydra.utils.instantiate(self.cfg.data_args) 48 | dataloaders = get_data_loader(self.cfg, self.cfg._data.map_eval_batch_size) 49 | 50 | # Train, val and test, all vocabularies are the same. 51 | self._vocabulary = dataloaders[self._eval_split_name].dataset.vocabulary 52 | logging.warn("Please ensure train, val and test " 53 | "loaders give the same vocabulary.") 54 | 55 | # Initialize the model/trainer 56 | model = load_model(self.cfg, self._vocabulary) 57 | loss_fn = hydra.utils.instantiate(config=self.cfg.loss, 58 | pad_token_idx=self._vocabulary.pad()) 59 | 60 | if self.cfg.model_or_oracle_metrics == "oracle": 61 | self._best_test_metric_or_oracle = "oracle" 62 | else: 63 | self._best_test_metric_or_oracle = self.cfg.eval_cfg.best_test_metric.replace('/', '_') 64 | 65 | 66 | # If evaluating on test mention how the validation checkpoint was chosen. 67 | # replace '/' with '_' if we have a best metric like modelmetrics/acc 68 | write_results_file = (self._best_test_metric_or_oracle + "_" + 69 | self.cfg.get(self._eval_split_name).rstrip(".pkl") 70 | + "_qualitative.pkl") 71 | 72 | # If we are not computing model metrics, then the best test metric used 73 | # to choose validation points is irrelevant 74 | trainer = _Trainer(config=self.cfg, 75 | dataloader=dataloaders.get("train"), 76 | models={"model": model}, 77 | loss_fn=loss_fn, 78 | serialization_dir=os.getcwd(), 79 | write_metrics_file=write_results_file) 80 | 81 | costly_loss_fn = hydra.utils.instantiate(self.cfg.costly_loss) 82 | costly_evaluator = _MapEvaluator( 83 | config=self.cfg, 84 | loss_fn=costly_loss_fn, 85 | test_loader=dataloaders["cross_split"], 86 | dataloader=dataloaders[self._eval_split_name], 87 | models={"model": model}, 88 | ) 89 | 90 | self._trainer = trainer 91 | self._costly_evaluator = costly_evaluator 92 | self._write_results_file = write_results_file 93 | 94 | def run_eval(self): 95 | # Iterate over all the checkpoints. 
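# For the qualitative split, the loop below resolves the best validation
# checkpoint (or pops checkpoints when eval_cfg.evaluate_all is set), runs the
# mAP evaluator with output_raw_results=True, and breaks after a single pass;
# the raw results are pickled once the loop exits.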
96 | current_iteration = -1 97 | active_iteration = -1 98 | num_sleep = 0 99 | _WAIT=7200 100 | 101 | if self.cfg.model_or_oracle_metrics == "model" and self.cfg.eval_cfg.evaluate_all == True: 102 | all_checkpoint_paths_and_idx = ( 103 | self._trainer._checkpoint_manager.all_checkpoints( 104 | sort_iterations=self.cfg.eval_cfg.sort_validation_checkpoints, 105 | random_shuffle=not self.cfg.eval_cfg.sort_validation_checkpoints)) 106 | if not isinstance(all_checkpoint_paths_and_idx, list): 107 | raise ValueError("Not enough checkpoints to evaluate.") 108 | 109 | if self.cfg.eval_cfg.evaluate_once == True: 110 | raise ValueError("Evaluate once and evaluate all cannot be true at once.") 111 | 112 | while(True): 113 | if self.cfg.model_or_oracle_metrics == "model": 114 | if self.cfg.eval_cfg.evaluate_all == True: 115 | active_checkpoint, active_iteration = all_checkpoint_paths_and_idx.pop() 116 | else: 117 | # Test set is always evaluated with the best checkpoint from validation. 118 | if self._eval_split_name == "test" or self._eval_split_name == "qualitative": 119 | active_checkpoint, active_iteration = ( 120 | self._trainer._checkpoint_manager.best_checkpoint( 121 | based_on_metric=self._best_test_metric_or_oracle) 122 | ) 123 | else: 124 | active_checkpoint, active_iteration = ( 125 | self._trainer._checkpoint_manager.latest_checkpoint) 126 | 127 | # Active iteration is None when we are evaluating the best checkpoint. 128 | if active_iteration is None and self._eval_split_name not in ["test", "qualitative"]: 129 | raise ValueError("Expect active_iteration to not be None.") 130 | else: 131 | active_checkpoint, active_iteration = None, None 132 | 133 | if active_iteration is None or active_iteration != current_iteration: 134 | logging.info(f"Evaluating checkpoint {active_checkpoint}") 135 | self._trainer.load_checkpoint(active_checkpoint) 136 | _, all_qualitative_results = self._costly_evaluator.evaluate( 137 | num_batches=_NUM_EVAL_BATCHES_FOR_SPLIT[self._eval_split_name], 138 | output_raw_results=True) 139 | held_out_image_paths = self._costly_evaluator._held_out_image_paths 140 | 141 | current_iteration = active_iteration 142 | num_sleep = 0 143 | 144 | if (self.cfg.model_or_oracle_metrics == "oracle" or 145 | self.cfg.eval_cfg.evaluate_once == True or 146 | self._eval_split_name == "qualitative"): 147 | logging.info("Finished evaluation.") 148 | break 149 | 150 | with open(self._write_results_file, 'wb') as f: 151 | pickle.dump({"qualitative_results": all_qualitative_results, 152 | "held_out_image_paths": held_out_image_paths}, f) 153 | 154 | 155 | @hydra.main(config_path='hydra_cfg/experiment.yaml') 156 | def main(cfg): 157 | logging.info(cfg.pretty()) 158 | logging.info("Base Directory: %s", os.getcwd()) 159 | 160 | if cfg._mode != "eval": 161 | raise ValueError("Invalid mode %s" % cfg._mode) 162 | 163 | workplace = _Workplace(cfg) 164 | workplace.run_eval() 165 | 166 | 167 | if __name__ == "__main__": 168 | from hypothesis_generation.hypothesis_utils import MetaDatasetExample 169 | from hypothesis_generation.hypothesis_utils import HypothesisEval 170 | main() 171 | -------------------------------------------------------------------------------- /hydra_train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
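# Illustrative training invocation (overrides mirror launch_train_eval_jobs.sh;
# the values shown are examples, not defaults):
#
#   python hydra_train.py mode=train model_or_oracle_metrics=model \
#       modality=image pooling=trnsf data.split_type=comp \
#       data.negative_type=alternate_hypotheses loss.params.alpha=1.0 \
#       job_replica=0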
6 | # Test script for running adhoc categorization. 7 | import hydra 8 | import logging 9 | import os 10 | 11 | import dataloaders 12 | import losses 13 | import models 14 | 15 | from tqdm import tqdm 16 | from time import time 17 | from time import sleep 18 | 19 | from dataloaders.get_dataloader import GetDataloader 20 | from models._trainer import _Trainer 21 | from models._evaluator import _Evaluator 22 | from models._map_evaluator import _MapEvaluator 23 | 24 | _PRINT_EVERY = 10 25 | 26 | 27 | def load_model(cfg, vocabulary): 28 | get_model_helper = hydra.utils.instantiate(cfg.model) 29 | return get_model_helper(vocabulary) 30 | 31 | 32 | class _Workplace(object): 33 | def __init__(self, cfg): 34 | self.cfg = cfg 35 | 36 | # Initialize the data loader. 37 | get_data_loader = hydra.utils.instantiate(self.cfg.data_args) 38 | dataloaders = get_data_loader(self.cfg) 39 | 40 | # Get the vocabulary we are training with. 41 | self._vocabulary = dataloaders["train"].dataset.vocabulary 42 | 43 | # Initialize the model/trainer 44 | model = load_model(self.cfg, self._vocabulary) 45 | total_params_model = sum(p.numel() for p in model.parameters() 46 | if p.requires_grad) 47 | logging.info(f"Model has {total_params_model} trainable parameters.") 48 | 49 | loss_fn = hydra.utils.instantiate(config=self.cfg.loss, 50 | pad_token_idx=self._vocabulary.pad()) 51 | 52 | trainer = _Trainer(config=self.cfg, 53 | dataloader=dataloaders["train"], 54 | models={"model": model}, 55 | loss_fn=loss_fn, 56 | serialization_dir=os.getcwd()) 57 | 58 | evaluator = _Evaluator( 59 | config=self.cfg, 60 | loss_fn=loss_fn, 61 | dataloader=dataloaders["val"], 62 | models={"model": model}, 63 | ) 64 | 65 | # Two kinds of evaluators: cheap and costly. 66 | costly_loss_fn = hydra.utils.instantiate(self.cfg.costly_loss) 67 | 68 | if dataloaders.get("cross_split") is not None: 69 | costly_evaluator = _MapEvaluator( 70 | config=self.cfg, 71 | loss_fn=costly_loss_fn, 72 | test_loader=dataloaders["cross_split"], 73 | dataloader=dataloaders["val"], 74 | models={"model": model}, 75 | ) 76 | else: 77 | costly_evaluator = None 78 | 79 | self._trainer = trainer 80 | self._evaluator = evaluator 81 | self._costly_evaluator = costly_evaluator 82 | 83 | def run_training(self): 84 | latest_checkpoint, latest_iteration = self._trainer._checkpoint_manager.latest_checkpoint 85 | self._trainer.load_checkpoint(latest_checkpoint) 86 | 87 | t = time() 88 | for step in range(latest_iteration + 1, self.cfg.opt.max_steps): 89 | loss = self._trainer.step()["loss"] 90 | if (step + 1) % self.cfg.opt.checkpoint_every == 0: 91 | # Clear all the metric values accumulated during training 92 | # TODO(ramav): Fix this more properly, this is a hack for now. 93 | self._trainer._loss_fn._reset_metrics() 94 | 95 | metrics = self._evaluator.evaluate() 96 | self._trainer.after_validation(metrics, step) 97 | 98 | if (step + 1) % _PRINT_EVERY == 0: 99 | time_elapsed = (time() - t) / _PRINT_EVERY 100 | logging.info("%d step]: loss: %f (%f sec per step)", 101 | step, loss.detach().cpu().item(), time_elapsed) 102 | t = time() 103 | 104 | def run_eval(self): 105 | # Iterate over all the checkpoints. 
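# Unlike hydra_eval.py, the loop below only polls for the latest checkpoint:
# it sleeps _WAIT seconds, re-runs the costly (mAP) evaluator whenever a new
# checkpoint appears, and terminates after 10 sleeps without progress. Note
# that main() below only calls run_training(), so this method is not invoked
# by the training entry point itself.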
106 | current_iteration = -1 107 | latest_iteration = -1 108 | num_sleep = 0 109 | _WAIT=7200 110 | 111 | while(True): 112 | logging.info(f"Sleeping for {_WAIT} sec waiting for checkpoint.") 113 | sleep(_WAIT) 114 | 115 | num_sleep += 1 116 | latest_checkpoint, latest_iteration = ( 117 | self._trainer._checkpoint_manager.latest_checkpoint) 118 | 119 | if latest_iteration != current_iteration: 120 | self._trainer.load_checkpoint(latest_checkpoint) 121 | costly_metrics = self._costly_evaluator.evaluate() 122 | self._trainer.write_metrics(costly_metrics, latest_iteration) 123 | 124 | current_iteration = latest_iteration 125 | num_sleep = 0 126 | 127 | if num_sleep == 10: 128 | logging.info(f"Terminating job after waiting for a new checkpoint.") 129 | break 130 | 131 | 132 | @hydra.main(config_path='hydra_cfg/experiment.yaml') 133 | def main(cfg): 134 | logging.info(cfg.pretty()) 135 | 136 | logging.info("Base directory: %s", os.getcwd()) 137 | 138 | workplace = _Workplace(cfg) 139 | 140 | workplace.run_training() 141 | 142 | 143 | if __name__ == "__main__": 144 | from hypothesis_generation.hypothesis_utils import MetaDatasetExample 145 | from hypothesis_generation.hypothesis_utils import HypothesisEval 146 | main() 147 | -------------------------------------------------------------------------------- /hypothesis_generation/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /hypothesis_generation/prefix_postfix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
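# Illustrative usage (mirrors the __main__ block at the bottom of this file;
# the paths are the ones used there):
#
#   converter = PrefixPostfix(
#       "concept_data/clevr_typed_fol_properties.json",
#       "concept_data/temp_data/v2_typed_simple_fol_clevr_typed_fol_properties.pkl")
#   converter.postfix_to_prefix(
#       "x shape? cylinder = x color? brown = and lambda S. for-all=")
#
# converts a postfix hypothesis string into a function-style expression with
# the quantifier printed in front.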
6 | """Prefix Postfix Utilities""" 7 | import json 8 | import pickle 9 | 10 | from hypothesis_generation.hypothesis_utils import Node 11 | from hypothesis_generation.hypothesis_utils import GrammarExpander 12 | from hypothesis_generation.execute_hypotheses_on_scene_json import FunctionBinders 13 | 14 | 15 | class PrefixPostfix(object): 16 | 17 | def __init__(self, properties_file, grammar_expander_file=None, grammar_expander=None): 18 | if grammar_expander_file is None and grammar_expander is None: 19 | raise ValueError("Expect to be provided either expander serialized or expander object.") 20 | 21 | with open(properties_file, 'r') as f: 22 | properties_metadata = json.load(f) 23 | metadata = properties_metadata["metadata"] 24 | size_mapping = properties_metadata["properties"]["sizes"] 25 | 26 | if grammar_expander is None: 27 | with open(grammar_expander_file, 'rb') as f: 28 | grammar_expander = pickle.load(f) 29 | constants_in_grammar = grammar_expander.terminal_constants 30 | 31 | function_binders = FunctionBinders(metadata, size_mapping, 32 | constants_in_grammar) 33 | self._function_binders = function_binders 34 | 35 | def postfix_to_tree(self, postfix_program): 36 | execution_stack = [] 37 | if not isinstance(postfix_program, str): 38 | raise ValueError 39 | 40 | postfix_program = [x for x in postfix_program.split(' ') if x != ' '] 41 | 42 | quantifier_reached = 0 43 | for token in postfix_program: 44 | if token == 'lambda' or token == 'S.': 45 | quantifier_reached = 1 46 | continue 47 | 48 | if token in self._function_binders.function_names: 49 | _, cardinality = self._function_binders[token] 50 | current_node = Node(token, expansion_of="function") 51 | 52 | if quantifier_reached == 0: 53 | args = [] 54 | for _ in range(cardinality): 55 | if len(execution_stack) == 0: 56 | raise RuntimeError( 57 | "Stack is empty, check if the postfix_program is valid.") 58 | args.append(execution_stack.pop()) 59 | args.reverse() 60 | 61 | for this_arg in args: 62 | current_node.add_child(this_arg) 63 | 64 | new_stack_top = current_node 65 | else: 66 | new_stack_top = Node(token, expansion_of="") 67 | execution_stack.append(new_stack_top) 68 | 69 | if len(execution_stack) not in [1, 2]: 70 | raise ValueError("Invalid Program.") 71 | return execution_stack 72 | 73 | def tree_to_postfix(self, list_tree): 74 | expression_tree = list_tree[0] 75 | 76 | if len(list_tree) == 2: 77 | quantifier = list_tree[1].op 78 | if len(list_tree[1]) != 0: 79 | raise ValueError 80 | quantifier_print = quantifier + r"x \in S " 81 | else: 82 | quantifier_print = "" 83 | 84 | def t2p(tree): 85 | if len(tree) == 0: 86 | return tree.op 87 | elif len(tree) == 1: 88 | return tree.op + "( " + t2p(tree._children[0]) + " )" 89 | elif len(tree) == 2: 90 | return tree.op + "(" + t2p(tree._children[0]) + ", " + t2p( 91 | tree._children[1]) + " )" 92 | else: 93 | raise ValueError("Maximum 2 children per node expected.") 94 | 95 | return quantifier_print + t2p(expression_tree) 96 | 97 | def postfix_to_prefix(self, postfix): 98 | tree = self.postfix_to_tree(postfix) 99 | prefix = self.tree_to_postfix(tree) 100 | return prefix 101 | 102 | 103 | if __name__ == "__main__": 104 | properties_file = "concept_data/clevr_typed_fol_properties.json" 105 | grammar_expander = "concept_data/temp_data/v2_typed_simple_fol_clevr_typed_fol_properties.pkl" 106 | 107 | program_converter = PrefixPostfix(properties_file, grammar_expander) 108 | 109 | list_postfix = [ 110 | "brown x color? = sphere x shape? = and lambda S. for-all=", 111 | "S locationY? 
x locationY? count= 2 > lambda S. exists=", 112 | "S color? x color? exists= S color? x color? for-all= and lambda S. for-all=", 113 | "x shape? cylinder = x color? brown = and lambda S. for-all=", 114 | "S color? red count= 1 > lambda S. for-all=", 115 | "S shape? x shape? for-all= S color? cyan for-all= and lambda S. for-all=", 116 | "S shape? cube for-all= S color? cyan for-all= and lambda S. exists=", 117 | "S color? x color? count= S locationY? x locationY? count= > lambda S. for-all=", 118 | "S locationX? x locationX? for-all= x shape? cylinder = and lambda S. exists=", 119 | "S size? 0.35 for-all= S material? rubber for-all= and lambda S. for-all=", 120 | "non-x-S color? brown for-all= non-x-S shape? sphere for-all= and lambda S. exists=", 121 | "S shape? sphere exists= red x color? = and lambda S. for-all=", 122 | "non-x-S locationX? 7 exists= non-x-S color? x color? exists= and lambda S. exists=", 123 | "blue x color? = S size? x size? for-all= and lambda S. for-all=", 124 | "non-x-S locationY? x locationY? for-all= S size? x size? exists= and lambda S. exists=", 125 | "S locationX? 1 for-all= S size? x size? for-all= and lambda S. exists=", 126 | ] 127 | for this_postfix in list_postfix: 128 | this_prefix = program_converter.postfix_to_prefix(this_postfix) 129 | print(this_postfix) 130 | print(this_prefix) 131 | 132 | print("______________") -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # Installation for setting up productive concept learning project. 7 | python3 -m venv release 8 | source release/bin/activate 9 | 10 | pip3 install torch torchvision torchaudio 11 | pip3 install fairseq==v0.9.0 12 | 13 | pip install OmegaConf==1.4.1 14 | pip install hydra-core==0.11 15 | pip install tensorboardX 16 | pip install tensorboard 17 | pip install soundfile 18 | pip install tqdm 19 | pip install scipy 20 | pip install matplotlib 21 | pip install scikit-learn==0.22.0 22 | pip install frozendict 23 | pip install pandas 24 | pip install submitit -------------------------------------------------------------------------------- /launch_test_jobs.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # Run jobs on TEST set 7 | source release/bin/activate 8 | source paths.sh 9 | 10 | ############################## Global parameters ############################### 11 | METRIC="modelmap" 12 | GREATER_BETTER="1" 13 | JOB_TYPE="eval" 14 | REPLICA_JOBS="0,1,2" # 3 jobs per run. 15 | MODEL_OR_ORACLE="model" # Compute oracle metrics separately 16 | 17 | 18 | ################################### Model sweep ################################ 19 | JOB_NAME="sweep_models" 20 | SWEEP_PATH="${RUN_DIR}/${JOB_NAME}" 21 | python scripts/pick_best_checkpoints.py \ 22 | --sweep_folder ${SWEEP_PATH} \ 23 | --val_metric ${METRIC} \ 24 | --greater_is_better ${GREATER_BETTER} 25 | 26 | # Sweep that trains all the models explained in the paper. 
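# (JOB_TYPE is "eval" in this script, so the loops below launch test-set
# evaluation jobs for every trained configuration: 9 split types x 3
# modalities x 4 pooling mechanisms x 2 language-loss alphas x 2 negative
# types x 3 replicas.)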
27 | for SPLIT_TYPE in "comp" "iid" "color_count" "color_location" "color_material" \ 28 | "color" "shape" "color_boolean" "length_threshold_10" 29 | do 30 | for MODALITY in "image" "json" "sound" 31 | do 32 | for POOLING in "gap" "concat" "rel_net" "trnsf" 33 | do 34 | for LANGUAGE_ALPHA in "0.0" "1.0" 35 | do 36 | for NEGATIVE_TYPE in "alternate_hypotheses" "random" 37 | do 38 | for REPLICA_JOBS in "0" "1" "2" 39 | do 40 | EVAL_STR="" 41 | if [ $JOB_TYPE = "eval" ]; 42 | then 43 | EVAL_STR="eval_cfg.write_raw_metrics=True" 44 | fi 45 | CMD="python hydra_${JOB_TYPE}.py \ 46 | splits='test & cross_split'\ 47 | eval_split_name=test\ 48 | eval_cfg.best_test_metric=${METRIC} \ 49 | ${EVAL_STR}\ 50 | hydra.job.name=${JOB_NAME}\ 51 | mode=${JOB_TYPE} \ 52 | model_or_oracle_metrics=${MODEL_OR_ORACLE} \ 53 | modality=${MODALITY}\ 54 | pooling=${POOLING}\ 55 | data.split_type=${SPLIT_TYPE} \ 56 | data.negative_type=${NEGATIVE_TYPE}\ 57 | loss.params.alpha=${LANGUAGE_ALPHA} \ 58 | job_replica=${REPLICA_JOBS} &" 59 | echo ${CMD} 60 | eval ${CMD} 61 | done 62 | done 63 | done 64 | done 65 | done 66 | done 67 | 68 | 69 | ############################## Oracle Metrics ######################################## 70 | JOB_NAME="oracle" # Test oracle jobs 71 | SWEEP_PATH="${RUN_DIR}/${JOB_NAME}" 72 | 73 | MODEL_OR_ORACLE="oracle" # Compute oracle metrics separately 74 | LANGUAGE_ALPHA="0.0" 75 | REPLICA_JOBS="0" 76 | MODALITY="json" 77 | POOLING="gap" 78 | 79 | NEGATIVE_TYPE="alternate_hypotheses,random" 80 | 81 | # Jobs are launched with for loops to avoid MaxJobArrayLimit on SLURM. 82 | for SPLIT_TYPE in "iid" "color_count" "color_location" "color_material"\ 83 | "color" "shape" "color_boolean" "length_threshold_10" "comp" 84 | do 85 | for NEGATIVE_TYPE in "alternate_hypotheses" "random" 86 | do 87 | CMD="python hydra_${JOB_TYPE}.py \ 88 | splits='test & cross_split'\ 89 | eval_split_name=test\ 90 | eval_cfg.best_test_metric=${METRIC} \ 91 | hydra.job.name=${JOB_NAME}\ 92 | mode=${JOB_TYPE} \ 93 | model_or_oracle_metrics=${MODEL_OR_ORACLE} \ 94 | modality=${MODALITY}\ 95 | pooling=${POOLING}\ 96 | data.split_type=${SPLIT_TYPE} \ 97 | data.negative_type=${NEGATIVE_TYPE}\ 98 | loss.params.alpha=${LANGUAGE_ALPHA} \ 99 | job_replica=${REPLICA_JOBS} &" 100 | echo ${CMD} 101 | eval ${CMD} 102 | done 103 | done -------------------------------------------------------------------------------- /launch_train_eval_jobs.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | source release/bin/activate 7 | source paths.sh 8 | 9 | MODEL_OR_ORACLE="model" 10 | 11 | JOB_NAME="sweep_models" 12 | 13 | # Sweep that trains all the models explained in the paper. 
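# Each configuration below is launched twice, once with JOB_TYPE="train" and
# once with JOB_TYPE="eval"; the eval jobs additionally pass
# eval_cfg.evaluate_all=True and eval_cfg.write_raw_metrics=True so that every
# saved checkpoint gets validated.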
14 | for SPLIT_TYPE in "comp" "iid" "color_count" "color_location" "color_material" \ 15 | "color" "shape" "color_boolean" "length_threshold_10" 16 | do 17 | for MODALITY in "image" "json" 18 | do 19 | for POOLING in "gap" "concat" "rel_net" "trnsf" 20 | do 21 | for JOB_TYPE in "train" "eval" 22 | do 23 | for LANGUAGE_ALPHA in "0.0" "1.0" 24 | do 25 | for NEGATIVE_TYPE in "alternate_hypotheses" "random" 26 | do 27 | for REPLICA_JOBS in "0" "1" "2" 28 | do 29 | EVAL_STR="" 30 | if [ $JOB_TYPE = "eval" ]; 31 | then 32 | EVAL_STR="eval_cfg.evaluate_all=True\ 33 | eval_cfg.write_raw_metrics=True" 34 | fi 35 | CMD="python hydra_${JOB_TYPE}.py \ 36 | ${EVAL_STR}\ 37 | hydra.job.name=${JOB_NAME}\ 38 | mode=${JOB_TYPE} \ 39 | model_or_oracle_metrics=${MODEL_OR_ORACLE} \ 40 | modality=${MODALITY}\ 41 | pooling=${POOLING}\ 42 | data.split_type=${SPLIT_TYPE} \ 43 | data.negative_type=${NEGATIVE_TYPE}\ 44 | loss.params.alpha=${LANGUAGE_ALPHA} \ 45 | job_replica=${REPLICA_JOBS}" 46 | echo ${CMD} 47 | eval ${CMD} 48 | done 49 | done 50 | done 51 | done 52 | done 53 | done 54 | done -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /models/__pycache__/.nfs00780000024e4221000025e7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/productive_concept_learning/d00a3b85215bff0686ad3f834d50c8eae374f06a/models/__pycache__/.nfs00780000024e4221000025e7 -------------------------------------------------------------------------------- /models/_evaluator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import math 7 | import torch 8 | from torch import nn 9 | from torch.utils.data import DataLoader 10 | from tqdm import tqdm 11 | 12 | from typing import Any, Dict, List, Optional, Type 13 | 14 | from omegaconf import OmegaConf 15 | 16 | from losses import _LossFun 17 | from losses import NegativeLogLikelihoodMultiTask 18 | 19 | 20 | class _Evaluator(object): 21 | r""" 22 | A base class for generic evaluation of models. This class can have multiple models interacting 23 | with each other, rather than a single model, which is suitable to our use-case (for example, 24 | ``module_training`` phase has two models: 25 | :class:`~probnmn.models.program_generator.ProgramGenerator` and 26 | :class:`~probnmn.models.nmn.NeuralModuleNetwork`). It offers full flexibility, with sensible 27 | defaults which may be changed (or disabled) while extending this class. 28 | 29 | Extended Summary 30 | ---------------- 31 | Extend this class and override :meth:`_do_iteration` method, with core evaluation loop - what 32 | happens every iteration, given a ``batch`` from the dataloader this class holds. 33 | 34 | Notes 35 | ----- 36 | 1. All models are `passed by assignment`, so they could be shared with an external trainer. 
37 | Do not set ``self._models = ...`` anywhere while extending this class. 38 | 39 | 2. An instantiation of this class will always be paired in conjunction to a 40 | :class:`~probnmn.trainers._trainer._Trainer`. Pass the models of trainer class while 41 | instantiating this class. 42 | 43 | Parameters 44 | ---------- 45 | config: Config 46 | A :class:`~probnmn.Config` object with all the relevant configuration parameters. 47 | dataloader: torch.utils.data.DataLoader 48 | A :class:`~torch.utils.data.DataLoader` which provides batches of evaluation examples. It 49 | wraps one of :mod:`probnmn.data.datasets` depending on the evaluation phase. 50 | models: Dict[str, Type[nn.Module]] 51 | All the models which interact with each other for evaluation. These are one or more from 52 | :mod:`probnmn.models` depending on the evaluation phase. 53 | gpu_ids: List[int], optional (default=[0]) 54 | List of GPU IDs to use or evaluation, ``[-1]`` - use CPU. 55 | """ 56 | 57 | def __init__( 58 | self, 59 | config: OmegaConf, 60 | loss_fn: _LossFun, 61 | dataloader: DataLoader, 62 | models: Dict[str, Type[nn.Module]], 63 | gpu_ids: List[int] = [0], 64 | ): 65 | self._C = config 66 | self._dataloader = dataloader 67 | self._models = models 68 | self._loss_fn = loss_fn 69 | 70 | # Set device according to specified GPU ids. This device is only required for batches, 71 | # models will already be on apropriate device already, if passed from trainer. 72 | self._device = torch.device(f"cuda:{gpu_ids[0]}" if gpu_ids[0] >= 0 else "cpu") 73 | 74 | @property 75 | def models(self): 76 | return self._models 77 | 78 | def evaluate(self, eval_object: str = "model", 79 | num_batches: Optional[int] = None, 80 | output_raw_preds=False) -> Dict[str, Any]: 81 | r""" 82 | Perform evaluation using first ``num_batches`` of dataloader and return all evaluation 83 | metrics from the models. 84 | 85 | Parameters 86 | ---------- 87 | eval_object: str, optional (default=None), can be one of "random", 88 | "weak_oracle" or "oracle", or "model" 89 | num_batches: int, optional (default=None) 90 | Number of batches to use from dataloader. If ``None``, use all batches. 91 | 92 | Returns 93 | ------- 94 | Dict[str, Any] 95 | Final evaluation metrics for all the models. 96 | Dict[str, Any] 97 | Raw predictions from the model 98 | """ 99 | # Switch all models to "eval" mode. 100 | for model_name in self._models: 101 | self._models[model_name].eval() 102 | 103 | model_outputs = [] 104 | with torch.no_grad(): 105 | for iteration, batch in enumerate(self._dataloader): 106 | for key in batch: 107 | if isinstance(batch[key], torch.Tensor): 108 | batch[key] = batch[key].to(self._device) 109 | 110 | model_outputs.append( 111 | self._do_iteration(batch, self._loss_fn, eval_object)) 112 | 113 | if num_batches is not None and iteration > num_batches: 114 | break 115 | 116 | eval_metrics = self._loss_fn.aggregate_metrics() 117 | 118 | if len(self._models) > 1 or list(self._models.keys())[0] != "model": 119 | raise NotImplementedError("Only supports one model.") 120 | 121 | # Switch all models back to "train" mode. 122 | for model_name in self._models: 123 | self._models[model_name].train() 124 | 125 | if output_raw_preds == True: 126 | return eval_metrics, model_outputs 127 | 128 | return eval_metrics 129 | 130 | def _do_iteration(self, batch: Dict[str, Any], loss_fn: _LossFun, 131 | eval_object: str = "model") -> Dict[str, Any]: 132 | r""" 133 | Core evaluation logic for one iteration, operates on a batch. 
This base class has a dummy 134 | implementation - just forward pass through some "model". 135 | 136 | Parameters 137 | ---------- 138 | batch: Dict[str, Any] 139 | A batch of evaluation examples sampled from dataloader. See :func:`evaluate` on how 140 | this batch is sampled. 141 | 142 | Returns 143 | ------- 144 | Dict[str, Any] 145 | An output dictionary typically returned by the models. This may contain predictions 146 | from models, validation loss etc. 147 | """ 148 | # Multiple evaluation objects are only supported for the meta 149 | # learning negative log likelihood metric 150 | if not isinstance(loss_fn, NegativeLogLikelihoodMultiTask): 151 | raise NotImplementedError("eval_object support only for NLL loss.") 152 | 153 | if eval_object == "model": 154 | output_dict = self._models["model"](batch) 155 | elif "oracle" in eval_object: 156 | if eval_object == "oracle": 157 | posterior_dist = batch["posterior_probs_sparse"] 158 | elif eval_object == "weak_oracle": 159 | posterior_dist = batch["posterior_probs_train_sparse"] 160 | 161 | # B x N x H 162 | query_multihot_perdata_labels = batch[ 163 | "query_multihot_perdata_labels"] 164 | b = query_multihot_perdata_labels.shape[0] 165 | 166 | batch_eval_scores = [] 167 | for it in range(b): 168 | this_multihot_labels = query_multihot_perdata_labels[it] 169 | eval_scores = ( 170 | posterior_dist[it].cpu().detach().unsqueeze(0) * 171 | this_multihot_labels.cpu() 172 | ).sum(-1) 173 | batch_eval_scores.append(eval_scores) 174 | 175 | batch_eval_scores = torch.stack(batch_eval_scores, dim=0) 176 | 177 | elif eval_object == "random": 178 | labels = batch["query_labels"] 179 | batch_eval_scores = torch.rand(labels.shape[0], labels.shape[1]) 180 | 181 | if eval_object != "model": 182 | eval_probs = torch.zeros(batch_eval_scores.shape[0], 183 | batch_eval_scores.shape[1], 184 | 2) 185 | 186 | eval_probs[:, :, self._dataloader.dataset.true_class_id] = batch_eval_scores 187 | eval_probs[:, :, 1 - self._dataloader.dataset.true_class_id] = 1 - batch_eval_scores 188 | 189 | # Add a small quantity for log 0, does not normalize but we dont 190 | # care here. 191 | output_dict = { 192 | "neg_log_p_y": -1 * torch.log(eval_probs + 1e-12), # For log 0 193 | } 194 | output_dict["neg_log_p_y"] = output_dict["neg_log_p_y"].to( 195 | self._device) 196 | 197 | loss_fn(output_dict, batch, metric_prefix=eval_object) 198 | return output_dict 199 | -------------------------------------------------------------------------------- /models/_map_evaluator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | from typing import Any, Dict, List, Optional, Type 7 | 8 | import torch 9 | import time 10 | import logging 11 | 12 | from collections import defaultdict 13 | 14 | from torch import nn 15 | from torch.utils.data import DataLoader 16 | from torch.utils.data import RandomSampler 17 | from omegaconf import OmegaConf 18 | 19 | from models._evaluator import _Evaluator 20 | from losses import _LossFun 21 | from losses import MetaLearningMeanAveragePrecision 22 | from models.utils import dict_to_device 23 | 24 | 25 | class _MapEvaluator(_Evaluator): 26 | r""" 27 | A base class for generic evaluation of models. 
This class can have multiple models interacting 28 | with each other, rather than a single model, which is suitable to our use-case (for example, 29 | ``module_training`` phase has two models: 30 | :class:`~probnmn.models.program_generator.ProgramGenerator` and 31 | :class:`~probnmn.models.nmn.NeuralModuleNetwork`). It offers full flexibility, with sensible 32 | defaults which may be changed (or disabled) while extending this class. 33 | 34 | Extended Summary 35 | ---------------- 36 | Extend this class and override :meth:`_do_iteration` method, with core evaluation loop - what 37 | happens every iteration, given a ``batch`` from the dataloader this class holds. 38 | 39 | Notes 40 | ----- 41 | 1. All models are `passed by assignment`, so they could be shared with an external trainer. 42 | Do not set ``self._models = ...`` anywhere while extending this class. 43 | 44 | 2. An instantiation of this class will always be paired in conjunction to a 45 | :class:`~probnmn.trainers._trainer._Trainer`. Pass the models of trainer class while 46 | instantiating this class. 47 | 48 | Parameters 49 | ---------- 50 | config: Config 51 | A :class:`~probnmn.Config` object with all the relevant configuration parameters. 52 | dataloader: torch.utils.data.DataLoader 53 | A :class:`~torch.utils.data.DataLoader` which provides batches of evaluation examples. It 54 | wraps one of :mod:`probnmn.data.datasets` depending on the evaluation phase. 55 | models: Dict[str, Type[nn.Module]] 56 | All the models which interact with each other for evaluation. These are one or more from 57 | :mod:`probnmn.models` depending on the evaluation phase. 58 | gpu_ids: List[int], optional (default=[0]) 59 | List of GPU IDs to use or evaluation, ``[-1]`` - use CPU. 60 | """ 61 | 62 | def __init__( 63 | self, 64 | config: OmegaConf, 65 | loss_fn: MetaLearningMeanAveragePrecision, 66 | dataloader: DataLoader, 67 | test_loader: DataLoader, 68 | models: Dict[str, Type[nn.Module]], 69 | gpu_ids: List[int] = [0], 70 | ): 71 | r""" 72 | Initialize the MapEvaluator object. 73 | 74 | Args: 75 | config: An `OmegaConf` object provided by hydra. 76 | dataloader: A meta learning dataloader with query and support. 77 | test_loader: A dataloader to access a distinct set of images. 78 | models: Dict, with key values str, and model. 79 | gpu_ids: Which gpus to run evaluation on. 80 | """ 81 | self._C = config 82 | self._dataloader = dataloader 83 | self._test_loader = test_loader 84 | self._models = models 85 | self._loss_fn = loss_fn 86 | 87 | if isinstance(self._test_loader.sampler, RandomSampler): 88 | raise ValueError("Expect no shuffling in the test loader.") 89 | 90 | if isinstance(self._dataloader.sampler, RandomSampler): 91 | raise ValueError("Expect no shuffling in the validation loader.") 92 | 93 | if not isinstance(self._loss_fn, MetaLearningMeanAveragePrecision): 94 | raise ValueError("Expect meta learning ap object for the loss") 95 | 96 | # Set device according to specified GPU ids. This device is only required for batches, 97 | # models will already be on apropriate device already, if passed from trainer. 
98 | self._device = torch.device( 99 | f"cuda:{gpu_ids[0]}" if gpu_ids[0] >= 0 else "cpu") 100 | 101 | self.all_hyp_str_to_idx = { 102 | v: k 103 | for k, v in enumerate(self._dataloader.dataset.all_hypotheses_across_splits) 104 | } 105 | self.num_total_hypotheses = len( 106 | self._dataloader.dataset.all_hypotheses_across_splits) 107 | 108 | logging.info("Setting up MAP evaluation.") 109 | with torch.no_grad(): 110 | self._held_out_image_hypotheses = [] 111 | for _, held_out_batch in enumerate(self._test_loader): 112 | self._held_out_image_hypotheses.append( 113 | held_out_batch["labels"].bool()) 114 | 115 | self._held_out_image_hypotheses = torch.cat( 116 | self._held_out_image_hypotheses, dim=0) 117 | logging.info("Done setting up map evaluation.") 118 | 119 | def evaluate(self, num_batches: Optional[int] = None, 120 | eval_object:str ="model", 121 | output_raw_results=False) -> Dict[str, Any]: 122 | r""" 123 | Perform evaluation using first ``num_batches`` of dataloader and return all evaluation 124 | metrics from the models. 125 | 126 | Args: 127 | num_batches: int, optional (default=None) 128 | Number of batches to use from dataloader. If ``None``, use all batches. 129 | eval_object: str, kind of model / approach to evaluate. 130 | output_raw_results: bool, whether we output raw results or not 131 | Returns: 132 | Dict[str, Any] 133 | Final evaluation metrics for all the models. 134 | """ 135 | if num_batches is not None: 136 | logging.info(f"Evaluating on {num_batches} batches.") 137 | 138 | self._held_out_features = [] 139 | self._held_out_image_paths = [] 140 | 141 | # Switch all models to "eval" mode. 142 | for model_name in self._models: 143 | self._models[model_name].eval() 144 | 145 | model_output = [] 146 | 147 | with torch.no_grad(): 148 | if eval_object == "model": 149 | for _, held_out_batch in enumerate(self._test_loader): 150 | held_out_batch = dict_to_device(held_out_batch, 151 | self._device) 152 | feat = self._models["model"].creator.encoder( 153 | held_out_batch["datum"]) 154 | self._held_out_features.append(feat) 155 | self._held_out_image_paths.append(held_out_batch["path"]) 156 | 157 | cpu_only_tensors = [ 158 | "all_consistent_hypotheses_idx_sparse", 159 | "posterior_probs_sparse", "posterior_probs_train_sparse" 160 | ] 161 | 162 | for iteration, batch in enumerate( 163 | self._dataloader): 164 | for key in batch: 165 | if isinstance(batch[key], torch.Tensor): 166 | if key not in cpu_only_tensors: 167 | batch[key] = batch[key].to(self._device) 168 | 169 | model_output.append( 170 | self._do_iteration(batch, self._loss_fn, eval_object)) 171 | 172 | if (iteration + 1) % 50 == 0: 173 | logging.info("Finished %d steps of evaluation.", iteration) 174 | 175 | if num_batches is not None and iteration > num_batches: 176 | break 177 | 178 | # keys: `self._models.keys()` 179 | eval_metrics = self._loss_fn.aggregate_metrics() 180 | 181 | # Switch all models back to "train" mode. 182 | for model_name in self._models: 183 | self._models[model_name].train() 184 | 185 | if output_raw_results == True: 186 | return eval_metrics, model_output 187 | 188 | return eval_metrics 189 | 190 | def _do_iteration(self, batch: Dict[str, Any], loss_fn: _LossFun, 191 | eval_object: str) -> Dict[str, Any]: 192 | r""" 193 | Takes as input a batch of meta learning examples to evaluate, and 194 | iterates over a large set of test data points (provided by the 195 | test_loader) to compute metrics. 
196 | 197 | Dynamically figures out the labels for the examples in the test set and 198 | computes the mean average precision loss accordingly. 199 | """ 200 | if not isinstance(loss_fn, MetaLearningMeanAveragePrecision): 201 | raise NotImplementedError("eval_object support only for mAP" 202 | " loss.") 203 | 204 | classifier = self._models["model"].creator(batch["support_images"], 205 | batch["support_labels"]) 206 | labels = batch["all_consistent_hypotheses_idx_sparse"].unsqueeze( 207 | 1) * self._held_out_image_hypotheses.unsqueeze(0) 208 | if eval_object == "model": 209 | eval_scores = [] 210 | for test_feat in self._held_out_features: 211 | log_prob_label = torch.squeeze( 212 | -1 * self._models["model"].applier(classifier, test_feat) 213 | ["neg_log_p_y"][:, :, self._dataloader.dataset. 214 | true_class_id]) # B x N x L 215 | eval_scores.append(log_prob_label.cpu().detach()) 216 | eval_scores = torch.cat(eval_scores, dim=1) 217 | elif eval_object == "oracle": 218 | eval_scores = ( 219 | batch["posterior_probs_sparse"].detach().unsqueeze(1) * 220 | self._held_out_image_hypotheses.type( 221 | torch.float).unsqueeze(0)).sum(-1) 222 | elif eval_object == "weak_oracle": 223 | eval_scores = ( 224 | batch["posterior_probs_train_sparse"].detach().unsqueeze(1) * 225 | self._held_out_image_hypotheses.type( 226 | torch.float).unsqueeze(0)).sum(-1) 227 | elif eval_object == "random": 228 | eval_scores = torch.rand(labels.shape[0], labels.shape[1]) 229 | loss_fn({ 230 | "scores": eval_scores, 231 | "gt_labels": labels, 232 | }, 233 | batch, 234 | metric_prefix=eval_object) 235 | # Filter the gt labels only corresponding to the hypotheses we are 236 | # working with based on the support set. Above we needed lables for 237 | # all hypotheses for the possibility of computing metrics like expected 238 | # mAP etc. 239 | gt_labels = [] 240 | for idx, this_hyp in enumerate(list(batch['hypotheses_idx_dense'])): 241 | gt_labels.append(labels[idx, :, this_hyp]) 242 | gt_labels = torch.stack(gt_labels, dim=0) 243 | 244 | return {"scores": eval_scores, "gt_labels": gt_labels} 245 | -------------------------------------------------------------------------------- /models/audio_resnet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """A resnet ultity with 1d convolutions. 7 | 8 | Adapts the ResNet [code]( 9 | https://pytorch.org/docs/stable/_modules/torchvision/models/resnet.html#resnet18) 10 | to include 1d convolutions instead of 2d convolutions. 11 | 12 | This is useful, for example when processing audio. 
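
Illustrative usage (the argument values are examples, not project defaults):

    model = resnet18(input_channels=1, num_classes=128)

The network consumes (batch, channels, time) tensors, since every convolution,
pooling and batch-norm layer here is the 1-D variant.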
13 | """ 14 | 15 | import torch 16 | import torch.nn as nn 17 | 18 | __all__ = ['OneDimResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 19 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 20 | 'wide_resnet50_2', 'wide_resnet101_2'] 21 | 22 | 23 | 24 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 25 | """3x3 convolution with padding""" 26 | return nn.Conv1d(in_planes, out_planes, kernel_size=3, stride=stride, 27 | padding=dilation, groups=groups, bias=False, dilation=dilation) 28 | 29 | 30 | def conv1x1(in_planes, out_planes, stride=1): 31 | """1x1 convolution""" 32 | return nn.Conv1d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 33 | 34 | 35 | class BasicBlock(nn.Module): 36 | expansion = 1 37 | __constants__ = ['downsample'] 38 | 39 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 40 | base_width=64, dilation=1, norm_layer=None): 41 | super(BasicBlock, self).__init__() 42 | if norm_layer is None: 43 | norm_layer = nn.BatchNorm1d 44 | if groups != 1 or base_width != 64: 45 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 46 | if dilation > 1: 47 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 48 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 49 | self.conv1 = conv3x3(inplanes, planes, stride) 50 | self.bn1 = norm_layer(planes) 51 | self.relu = nn.ReLU(inplace=True) 52 | self.conv2 = conv3x3(planes, planes) 53 | self.bn2 = norm_layer(planes) 54 | self.downsample = downsample 55 | self.stride = stride 56 | 57 | def forward(self, x): 58 | identity = x 59 | 60 | out = self.conv1(x) 61 | out = self.bn1(out) 62 | out = self.relu(out) 63 | 64 | out = self.conv2(out) 65 | out = self.bn2(out) 66 | 67 | if self.downsample is not None: 68 | identity = self.downsample(x) 69 | 70 | out += identity 71 | out = self.relu(out) 72 | 73 | return out 74 | 75 | 76 | class Bottleneck(nn.Module): 77 | expansion = 4 78 | __constants__ = ['downsample'] 79 | 80 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 81 | base_width=64, dilation=1, norm_layer=None): 82 | super(Bottleneck, self).__init__() 83 | if norm_layer is None: 84 | norm_layer = nn.BatchNorm1d 85 | width = int(planes * (base_width / 64.)) * groups 86 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 87 | self.conv1 = conv1x1(inplanes, width) 88 | self.bn1 = norm_layer(width) 89 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 90 | self.bn2 = norm_layer(width) 91 | self.conv3 = conv1x1(width, planes * self.expansion) 92 | self.bn3 = norm_layer(planes * self.expansion) 93 | self.relu = nn.ReLU(inplace=True) 94 | self.downsample = downsample 95 | self.stride = stride 96 | 97 | def forward(self, x): 98 | identity = x 99 | 100 | out = self.conv1(x) 101 | out = self.bn1(out) 102 | out = self.relu(out) 103 | 104 | out = self.conv2(out) 105 | out = self.bn2(out) 106 | out = self.relu(out) 107 | 108 | out = self.conv3(out) 109 | out = self.bn3(out) 110 | 111 | if self.downsample is not None: 112 | identity = self.downsample(x) 113 | 114 | out += identity 115 | out = self.relu(out) 116 | 117 | return out 118 | 119 | 120 | class OneDimResNet(nn.Module): 121 | 122 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 123 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 124 | norm_layer=None, input_channels=3): 125 | 126 | super(OneDimResNet, self).__init__() 127 | if norm_layer 
is None: 128 | norm_layer = nn.BatchNorm1d 129 | self._norm_layer = norm_layer 130 | 131 | self.inplanes = 64 132 | self.input_channels = input_channels 133 | self.dilation = 1 134 | if replace_stride_with_dilation is None: 135 | # each element in the tuple indicates if we should replace 136 | # the stride-2 convolution with a dilated convolution instead 137 | replace_stride_with_dilation = [False, False, False] 138 | if len(replace_stride_with_dilation) != 3: 139 | raise ValueError("replace_stride_with_dilation should be None " 140 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 141 | self.groups = groups 142 | self.base_width = width_per_group 143 | self.conv1 = nn.Conv1d(self.input_channels, self.inplanes, kernel_size=7, stride=2, padding=3, 144 | bias=False) 145 | self.bn1 = norm_layer(self.inplanes) 146 | self.relu = nn.ReLU(inplace=True) 147 | self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1) 148 | self.layer1 = self._make_layer(block, 64, layers[0]) 149 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 150 | dilate=replace_stride_with_dilation[0]) 151 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 152 | dilate=replace_stride_with_dilation[1]) 153 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 154 | dilate=replace_stride_with_dilation[2]) 155 | self.avgpool = nn.AdaptiveAvgPool1d(1)  # Adaptive pooling over 1-D inputs takes a single output length, not a 2-D tuple. 156 | self.fc = nn.Linear(512 * block.expansion, num_classes) 157 | 158 | for m in self.modules(): 159 | if isinstance(m, nn.Conv1d): 160 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 161 | elif isinstance(m, (nn.BatchNorm1d, nn.GroupNorm)): 162 | nn.init.constant_(m.weight, 1) 163 | nn.init.constant_(m.bias, 0) 164 | 165 | # Zero-initialize the last BN in each residual branch, 166 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
167 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 168 | if zero_init_residual: 169 | for m in self.modules(): 170 | if isinstance(m, Bottleneck): 171 | nn.init.constant_(m.bn3.weight, 0) 172 | elif isinstance(m, BasicBlock): 173 | nn.init.constant_(m.bn2.weight, 0) 174 | 175 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 176 | norm_layer = self._norm_layer 177 | downsample = None 178 | previous_dilation = self.dilation 179 | if dilate: 180 | self.dilation *= stride 181 | stride = 1 182 | if stride != 1 or self.inplanes != planes * block.expansion: 183 | downsample = nn.Sequential( 184 | conv1x1(self.inplanes, planes * block.expansion, stride), 185 | norm_layer(planes * block.expansion), 186 | ) 187 | 188 | layers = [] 189 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 190 | self.base_width, previous_dilation, norm_layer)) 191 | self.inplanes = planes * block.expansion 192 | for _ in range(1, blocks): 193 | layers.append(block(self.inplanes, planes, groups=self.groups, 194 | base_width=self.base_width, dilation=self.dilation, 195 | norm_layer=norm_layer)) 196 | 197 | return nn.Sequential(*layers) 198 | 199 | def _forward_impl(self, x): 200 | # See note [TorchScript super()] 201 | x = self.conv1(x) 202 | x = self.bn1(x) 203 | x = self.relu(x) 204 | x = self.maxpool(x) 205 | 206 | x = self.layer1(x) 207 | x = self.layer2(x) 208 | x = self.layer3(x) 209 | x = self.layer4(x) 210 | 211 | x = self.avgpool(x) 212 | x = torch.flatten(x, 1) 213 | x = self.fc(x) 214 | 215 | return x 216 | 217 | def forward(self, x): 218 | return self._forward_impl(x) 219 | 220 | 221 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 222 | model = OneDimResNet(block, layers, **kwargs) 223 | if pretrained == True: 224 | raise ValueError("No pretrained models available for audio.") 225 | return model 226 | 227 | 228 | def resnet18(pretrained=False, progress=True, **kwargs): 229 | r"""OneDimResNet-18 model from 230 | `"Deep Residual Learning for Image Recognition" `_ 231 | 232 | Args: 233 | pretrained (bool): If True, returns a model pre-trained on ImageNet 234 | progress (bool): If True, displays a progress bar of the download to stderr 235 | """ 236 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 237 | **kwargs) 238 | 239 | 240 | 241 | def resnet34(pretrained=False, progress=True, **kwargs): 242 | r"""OneDimResNet-34 model from 243 | `"Deep Residual Learning for Image Recognition" `_ 244 | 245 | Args: 246 | pretrained (bool): If True, returns a model pre-trained on ImageNet 247 | progress (bool): If True, displays a progress bar of the download to stderr 248 | """ 249 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 250 | **kwargs) 251 | 252 | 253 | 254 | def resnet50(pretrained=False, progress=True, **kwargs): 255 | r"""OneDimResNet-50 model from 256 | `"Deep Residual Learning for Image Recognition" `_ 257 | 258 | Args: 259 | pretrained (bool): If True, returns a model pre-trained on ImageNet 260 | progress (bool): If True, displays a progress bar of the download to stderr 261 | """ 262 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 263 | **kwargs) 264 | 265 | 266 | 267 | def resnet101(pretrained=False, progress=True, **kwargs): 268 | r"""OneDimResNet-101 model from 269 | `"Deep Residual Learning for Image Recognition" `_ 270 | 271 | Args: 272 | pretrained (bool): If True, returns a model pre-trained on ImageNet 273 | 
progress (bool): If True, displays a progress bar of the download to stderr 274 | """ 275 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 276 | **kwargs) 277 | 278 | 279 | 280 | def resnet152(pretrained=False, progress=True, **kwargs): 281 | r"""OneDimResNet-152 model from 282 | `"Deep Residual Learning for Image Recognition" `_ 283 | 284 | Args: 285 | pretrained (bool): If True, returns a model pre-trained on ImageNet 286 | progress (bool): If True, displays a progress bar of the download to stderr 287 | """ 288 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 289 | **kwargs) 290 | 291 | 292 | 293 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 294 | r"""ResNeXt-50 32x4d model from 295 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 296 | 297 | Args: 298 | pretrained (bool): If True, returns a model pre-trained on ImageNet 299 | progress (bool): If True, displays a progress bar of the download to stderr 300 | """ 301 | kwargs['groups'] = 32 302 | kwargs['width_per_group'] = 4 303 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 304 | pretrained, progress, **kwargs) 305 | 306 | 307 | 308 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 309 | r"""ResNeXt-101 32x8d model from 310 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 311 | 312 | Args: 313 | pretrained (bool): If True, returns a model pre-trained on ImageNet 314 | progress (bool): If True, displays a progress bar of the download to stderr 315 | """ 316 | kwargs['groups'] = 32 317 | kwargs['width_per_group'] = 8 318 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 319 | pretrained, progress, **kwargs) 320 | 321 | 322 | 323 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs): 324 | r"""Wide OneDimResNet-50-2 model from 325 | `"Wide Residual Networks" `_ 326 | 327 | The model is the same as OneDimResNet except for the bottleneck number of channels 328 | which is twice larger in every block. The number of channels in outer 1x1 329 | convolutions is the same, e.g. last block in OneDimResNet-50 has 2048-512-2048 330 | channels, and in Wide OneDimResNet-50-2 has 2048-1024-2048. 331 | 332 | Args: 333 | pretrained (bool): If True, returns a model pre-trained on ImageNet 334 | progress (bool): If True, displays a progress bar of the download to stderr 335 | """ 336 | kwargs['width_per_group'] = 64 * 2 337 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 338 | pretrained, progress, **kwargs) 339 | 340 | 341 | 342 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs): 343 | r"""Wide OneDimResNet-101-2 model from 344 | `"Wide Residual Networks" `_ 345 | 346 | The model is the same as OneDimResNet except for the bottleneck number of channels 347 | which is twice larger in every block. The number of channels in outer 1x1 348 | convolutions is the same, e.g. last block in OneDimResNet-50 has 2048-512-2048 349 | channels, and in Wide OneDimResNet-50-2 has 2048-1024-2048. 
350 | 351 | Args: 352 | pretrained (bool): If True, returns a model pre-trained on ImageNet 353 | progress (bool): If True, displays a progress bar of the download to stderr 354 | """ 355 | kwargs['width_per_group'] = 64 * 2 356 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 357 | pretrained, progress, **kwargs) -------------------------------------------------------------------------------- /models/encoders.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """Encoders for prototypical networks. 7 | 8 | Provides different encoders for prototypical networks corresponding to 9 | modalities like image, json and sound. 10 | """ 11 | import numpy as np 12 | import torch 13 | import torch.nn as nn 14 | import torch.functional as F 15 | 16 | from frozendict import frozendict 17 | from itertools import product 18 | 19 | from models.utils import SumPool 20 | from models.utils import build_resnet_base 21 | from models.utils import UnsqueezeRepeatTensor 22 | 23 | _MODEL_STAGE_RESNET = 3 24 | _RESNET_FEATURE_DIM_FOR_STAGE = frozendict({3: 256}) 25 | _SOUND_SPECTROGRAM_SIZE_FOR_SOUND = frozendict({120000: 201}) 26 | _IMAGE_ENCODER_POST_RESNET_STRIDE = (1, 2) 27 | _JSON_EMBED_TO_INPUT_RATIO = 3 28 | 29 | 30 | class PermuteTensor(nn.Module): 31 | def __init__(self, perm_order): 32 | super(PermuteTensor, self).__init__() 33 | self._perm_order = perm_order 34 | 35 | def forward(self, x): 36 | return x.permute(self._perm_order) 37 | 38 | 39 | class FlattenTensor(nn.Module): 40 | def __init__(self, dim): 41 | super(FlattenTensor, self).__init__() 42 | self._dim = dim 43 | 44 | def forward(self, x): 45 | return x.flatten(self._dim) 46 | 47 | 48 | #class OuterProductFlattenTensor(nn.Module): 49 | # def __init__(self, dim): 50 | # super(OuterProductFlattenTensor, self).__init__() 51 | # self_dim = dim 52 | # 53 | # def forward(self, x): 54 | # """Take an ensemble of differences and use it to create feature.""" 55 | # # x is [B x 5 x 19] 56 | # x = x.permute(0, ) 57 | # y = x.unsqueeze(-1).repeat(1, 1, 1, x.shape(1)).view(x.shape(0), x.shape(1) * x.shape(1), -1) 58 | # 59 | # x = x.sum(dim=) 60 | 61 | class OptionalPositionEncoding(nn.Module): 62 | """An nn.module class for position encoding.""" 63 | def __init__(self, num_objects_list, position_encoding=True): 64 | super(OptionalPositionEncoding, self).__init__() 65 | self.num_objects_list = num_objects_list 66 | self.position_encoding = position_encoding 67 | 68 | if position_encoding == True: 69 | position_embeds = torch.zeros(np.prod(self.num_objects_list), 70 | np.sum(self.num_objects_list)) 71 | locations_offset = ([0] + 72 | list(np.cumsum(self.num_objects_list)))[:-1] 73 | 74 | locations_iteator = [ 75 | range(offset, x + offset) 76 | for x, offset in zip(self.num_objects_list, locations_offset) 77 | ] 78 | 79 | for prod_idx, locations in enumerate(product(*locations_iteator)): 80 | for loc in locations: 81 | position_embeds[prod_idx][loc] = 1.0 82 | position_embeds = position_embeds.unsqueeze(0).permute(0, 2, 1) 83 | position_embeds = position_embeds.reshape( 84 | position_embeds.size(0), position_embeds.size(1), *self.num_objects_list) 85 | 86 | self.register_buffer('position_embeds', position_embeds) 87 | 88 | def forward(self, x): 89 | if self.position_encoding == True: 90 | 
position_embeds = self.position_embeds.repeat( 91 | x.size(0), *([1] * (1 + len(self.num_objects_list)))) 92 | return torch.cat([x, position_embeds], dim=1) 93 | return x 94 | 95 | def build_concat_pooling(embed_dim, object_dims, feature_dim): 96 | concat_pooling = nn.Sequential( 97 | FlattenTensor(dim=2), # Flatten all the objects. # B x F x {O} 98 | PermuteTensor((0, 2, 1)), # B x {O} x F 99 | FlattenTensor(dim=1), # B x F x {O} 100 | nn.Linear(embed_dim * np.prod(object_dims), 256), 101 | nn.BatchNorm1d(256), 102 | nn.ReLU(), 103 | nn.Linear(256, 512, bias=True), # Just replace with 1x1 conv. 104 | nn.BatchNorm1d(512), 105 | nn.ReLU(), 106 | nn.Linear(512, 256, bias=True), # Just replace with 1x1 conv. 107 | nn.BatchNorm1d(256), 108 | nn.ReLU(), 109 | nn.Linear(256, feature_dim, bias=True), 110 | ) 111 | return concat_pooling 112 | 113 | 114 | class TransformerPooling(nn.Module): 115 | """Inspired by the model in the following PGM paper: 116 | 117 | Wang, Duo, Mateja Jamnik, and Pietro Lio. 2019. 118 | Abstract Diagrammatic Reasoning with Multiplex Graph Networks. 119 | https://openreview.net/pdf?id=ByxQB1BKwH. 120 | """ 121 | def __init__(self, obj_fdim, num_objects_list, output_dim_f, 122 | n_head=2, num_layers=4, dim_feedforward=512): 123 | super(TransformerPooling, self).__init__() 124 | 125 | if obj_fdim % n_head != 0: 126 | raise ValueError("Object dim must be divisible by num heads.") 127 | 128 | encoder_layer = nn.TransformerEncoderLayer( 129 | d_model=obj_fdim, nhead=n_head, dim_feedforward=dim_feedforward) 130 | transformer_encoder = nn.TransformerEncoder(encoder_layer, 131 | num_layers=num_layers) 132 | self._trnsf = transformer_encoder 133 | 134 | # 4 here is because input is concat(max(), min(), sum(), mean()) 135 | self._embedder = nn.Linear(4 * obj_fdim, output_dim_f) 136 | self._num_objects_list = num_objects_list 137 | 138 | def forward(self, x): 139 | n_objects = np.prod(self._num_objects_list) 140 | 141 | # Reshape to [B x C x n_objects] 142 | x = x.view(x.size(0), x.size(1), n_objects) 143 | # Transpose to [n_objects x B x C] 144 | x = x.permute(2, 0, 1) 145 | feat_x = self._trnsf(x) 146 | 147 | # Transpose to [B x C x n_objects] 148 | feat_x = feat_x.permute(1, 2, 0) 149 | 150 | # Transpose to [B x 4*C] 151 | feat_x = torch.cat([feat_x.max(-1).values, 152 | feat_x.min(-1).values, 153 | feat_x.sum(-1), 154 | feat_x.mean(-1)], dim=-1) 155 | 156 | return self._embedder(feat_x) 157 | 158 | 159 | def build_json_object_encoder(feature_dim, pretrained_object_encoder, input_feature_dim): 160 | if pretrained_object_encoder == True: 161 | raise ValueError("Cannot use a pretrained encoder for _JSONs") 162 | 163 | input_true_feat_dim = int(input_feature_dim.split(",")[-1]) 164 | if len(input_feature_dim.split(",")) != 2: 165 | raise ValueError 166 | json_embed_dim = _JSON_EMBED_TO_INPUT_RATIO * input_true_feat_dim 167 | 168 | object_feature_object_encoder = nn.Sequential( 169 | PermuteTensor((0, 2, 1)), # B x 19 x 5 170 | nn.Conv1d(input_true_feat_dim, json_embed_dim, 1), # B x 19*3 x 5 171 | nn.ReLU(), 172 | nn.Conv1d(json_embed_dim, feature_dim, 1), # B x 19*3*2 x 5 173 | ) 174 | return object_feature_object_encoder 175 | 176 | 177 | def build_image_object_encoder(feature_dim, pretrained_object_encoder, 178 | input_feature_dim, 179 | image_encoder_stride=_IMAGE_ENCODER_POST_RESNET_STRIDE): 180 | encoder = build_resnet_base(pretrained=pretrained_object_encoder) 181 | return nn.Sequential( 182 | encoder, 183 | nn.Conv2d(_RESNET_FEATURE_DIM_FOR_STAGE[_MODEL_STAGE_RESNET], 184 | 
feature_dim * 2, 185 | 1, 186 | stride=image_encoder_stride), 187 | nn.ReLU(), 188 | nn.Conv2d(feature_dim * 2, 189 | feature_dim, 190 | 1, 191 | stride=1)) 192 | 193 | 194 | 195 | def infer_output_dim(encoder, input_feature_dim): 196 | """Infer the output dimensions of the feature that the encoder produces. 197 | 198 | This is useful for initializing a relation network. 199 | """ 200 | input_feature_dim = [int(x) for x in input_feature_dim.split(',')] 201 | 202 | with torch.no_grad(): 203 | # Needs a batch size of atleast 2 for batchnorm in forward pass. 204 | dummy_input = torch.zeros([2] + input_feature_dim) 205 | encoder_out = encoder.forward(dummy_input) 206 | 207 | # Report the last few dimensions. 208 | if encoder_out.dim() <= 2: 209 | raise ValueError("Encoder output should have more than 2 dims.") 210 | 211 | return encoder_out.shape[1], encoder_out.shape[ 212 | 2:] # First dimension is batch, second is channels -------------------------------------------------------------------------------- /models/simple_lstm_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """LSTM decoder model.""" 7 | import torch 8 | import torch.nn as nn 9 | from fairseq.models import FairseqDecoder 10 | from fairseq.data import Dictionary 11 | 12 | 13 | class SimpleLSTMDecoder(FairseqDecoder): 14 | 15 | def __init__( 16 | self, 17 | dictionary, 18 | encoder_hidden_dim=128, 19 | embed_dim=128, 20 | hidden_dim=128, 21 | dropout=0.1, 22 | ): 23 | super().__init__(dictionary) 24 | 25 | self.encoder_hidden_dim = encoder_hidden_dim 26 | self.embed_dim = embed_dim 27 | self.hidden_dim = hidden_dim 28 | self.dropout = dropout 29 | 30 | # Our decoder will embed the inputs before feeding them to the LSTM. 31 | self.embed_tokens = nn.Embedding( 32 | num_embeddings=len(dictionary), 33 | embedding_dim=embed_dim, 34 | padding_idx=dictionary.pad(), 35 | ) 36 | self.dropout = nn.Dropout(p=dropout) 37 | 38 | # We'll use a single-layer, unidirectional LSTM for simplicity. 39 | self.lstm = nn.LSTM( 40 | # For the first layer we'll concatenate the Encoder's final hidden 41 | # state with the embedded target tokens. 42 | input_size=encoder_hidden_dim + embed_dim, 43 | hidden_size=hidden_dim, 44 | num_layers=1, 45 | bidirectional=False, 46 | ) 47 | 48 | self.encoder_projection = nn.Linear(encoder_hidden_dim, hidden_dim) 49 | 50 | # Define the output projection. 51 | self.output_projection = nn.Linear(hidden_dim, len(dictionary)) 52 | 53 | self.register_buffer( 54 | 'sequence_prepend', (torch.ones(1, 1) * self.dictionary.eos()).long()) 55 | 56 | # During training Decoders are expected to take the entire target sequence 57 | # (shifted right by one position) and produce logits over the vocabulary. 58 | # The *prev_output_tokens* tensor begins with the end-of-sentence symbol, 59 | # ``dictionary.eos()``, followed by the target sequence. 60 | def forward(self, output_tokens, encoder_out): 61 | """ 62 | Args: 63 | output_tokens (LongTensor): outputs of shape 64 | `(batch, tgt_len)`, for teacher forcing, without any shifting. 
65 | encoder_out (Tensor, optional): output from the encoder, used for 66 | encoder-side attention 67 | 68 | Returns: 69 | tuple: 70 | - the last decoder layer's output of shape 71 | `(batch, tgt_len, vocab)` 72 | - the last decoder layer's attention weights of shape 73 | `(batch, tgt_len, src_len)` 74 | """ 75 | batch_size = output_tokens.size()[0] 76 | sequence_prepend = self.sequence_prepend.expand(batch_size, 1) 77 | prev_output_tokens = torch.cat( 78 | [sequence_prepend, output_tokens], 79 | dim=-1) 80 | 81 | bsz, tgt_len = prev_output_tokens.size() 82 | 83 | # Extract the final hidden state from the Encoder. 84 | final_encoder_hidden = encoder_out 85 | 86 | # Embed the target sequence, which has been shifted right by one 87 | # position and now starts with the end-of-sentence symbol. 88 | x = self.embed_tokens(prev_output_tokens) 89 | 90 | # Apply dropout. 91 | x = self.dropout(x) 92 | 93 | # Concatenate the Encoder's final hidden state to *every* embedded 94 | # target token. 95 | x = torch.cat( 96 | [x, final_encoder_hidden.unsqueeze(1).expand(bsz, tgt_len, -1)], 97 | dim=2, 98 | ) 99 | 100 | # Using PackedSequence objects in the Decoder is harder than in the 101 | # Encoder, since the targets are not sorted in descending length order, 102 | # which is a requirement of ``pack_padded_sequence()``. Instead we'll 103 | # feed nn.LSTM directly. 104 | initial_hidden_state = self.encoder_projection( 105 | final_encoder_hidden).unsqueeze(0) 106 | initial_state = ( 107 | initial_hidden_state, # hidden 108 | torch.zeros_like(initial_hidden_state), # cell 109 | ) 110 | output, _ = self.lstm( 111 | x.transpose(0, 1), # convert to shape `(tgt_len, bsz, dim)` 112 | initial_state, 113 | ) 114 | x = output.transpose(0, 1) # convert to shape `(bsz, tgt_len, hidden)` 115 | 116 | # Project the outputs to the size of the vocabulary, and retain the first 117 | # target_inds number of dimensions in the time dimension. 118 | x = self.output_projection(x)[:, :-1, :] 119 | 120 | # Return the logits and ``None`` for the attention weights 121 | return x, None -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
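"""Shared building blocks for the encoders: small tensor-reshaping modules
(Squeeze, Flatten, UnsqueezeRepeatTensor, ReshapeTensor, SumPool), a helper to
build a truncated ResNet trunk for image or audio inputs (build_resnet_base),
and utilities for prototype distances (euclidean_dist) and for moving batches
to a device (dict_to_device)."""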
6 | import torch 7 | import torch.nn as nn 8 | import torchvision 9 | import models.audio_resnet as audio_resnet 10 | 11 | 12 | class EncoderSequential(nn.Sequential): 13 | def __init__(self, *args, modality="image", feature_dim=None): 14 | super(EncoderSequential, self).__init__(*args) 15 | self._modality = modality 16 | self._feature_dim = feature_dim 17 | 18 | @property 19 | def modality(self): 20 | return self._modality 21 | 22 | @property 23 | def feature_dim(self): 24 | return self._feature_dim 25 | 26 | 27 | class Squeeze(nn.Module): 28 | def __init__(self): 29 | super(Squeeze, self).__init__() 30 | 31 | def forward(self, x): 32 | return x.squeeze() 33 | 34 | 35 | class Flatten(nn.Module): 36 | def __init__(self): 37 | super(Flatten, self).__init__() 38 | 39 | def forward(self, x): 40 | return x.view(x.size(0), -1) 41 | 42 | 43 | class UnsqueezeRepeatTensor(nn.Module): 44 | def __init__(self, dim, repeat=3): 45 | super(UnsqueezeRepeatTensor, self).__init__() 46 | self._dim = dim 47 | self._repeat = repeat 48 | 49 | def forward(self, x): 50 | unsqueeze_pick = [1] * (x.dim() + 1) 51 | unsqueeze_pick[self._dim] = self._repeat 52 | return x.unsqueeze(self._dim).repeat(unsqueeze_pick) 53 | 54 | 55 | class ReshapeTensor(nn.Module): 56 | def __init__(self, new_shape): 57 | super(ReshapeTensor, self).__init__() 58 | self._new_shape = new_shape 59 | 60 | def forward(self, x): 61 | return x.reshape(self._new_shape) 62 | 63 | class SumPool(nn.Module): 64 | def __init__(self, dim): 65 | super(SumPool, self).__init__() 66 | self._dim = dim 67 | 68 | def forward(self, x): 69 | return x.sum(dim=self._dim) 70 | 71 | 72 | def build_resnet_base(model_name='resnet18', model_stage=3, pretrained=False, 73 | one_dim_resnet=False, audio_input_channels=201): 74 | 75 | if not hasattr(torchvision.models, model_name): 76 | raise ValueError('Invalid model "%s"' % model_name) 77 | if not 'resnet' in model_name: 78 | raise ValueError('Feature extraction only supports ResNets') 79 | if one_dim_resnet == True: 80 | cnn = getattr(audio_resnet, model_name)(pretrained=pretrained, 81 | input_channels=audio_input_channels) 82 | else: 83 | cnn = getattr(torchvision.models, model_name)(pretrained=pretrained) 84 | layers = [ 85 | cnn.conv1, 86 | cnn.bn1, 87 | cnn.relu, 88 | cnn.maxpool, 89 | ] 90 | for i in range(model_stage): 91 | name = 'layer%d' % (i + 1) 92 | layers.append(getattr(cnn, name)) 93 | model = nn.Sequential(*layers) 94 | return model 95 | 96 | 97 | def euclidean_dist(x, y): 98 | # x: B x N x D 99 | # y: B x M x D 100 | b = x.size(0) 101 | n = x.size(1) 102 | m = y.size(1) 103 | d = x.size(2) 104 | assert d == y.size(2) 105 | assert b == y.size(0) 106 | 107 | x = x.unsqueeze(2).expand(b, n, m, d) 108 | y = y.unsqueeze(1).expand(b, n, m, d) 109 | 110 | return torch.pow(x - y, 2).sum(3) 111 | 112 | 113 | def dict_to_device(batch, device): 114 | """Move a batch from to a target device.""" 115 | batch = { 116 | k: v.to(device) if isinstance(v, torch.Tensor) else v 117 | for k, v in batch.items() 118 | } 119 | return batch 120 | -------------------------------------------------------------------------------- /notebooks/.ipynb_checkpoints/structured_splits-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [], 3 | "metadata": {}, 4 | "nbformat": 4, 5 | "nbformat_minor": 2 6 | } 7 | -------------------------------------------------------------------------------- /notebooks/visualize_alternate_hypotheses.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import sys\n", 10 | "sys.path.append('..')\n", 11 | "sys.path.append('../hypothesis_generation')\n", 12 | "\n", 13 | "\n", 14 | "from PIL import Image as PImage\n", 15 | "import matplotlib.pyplot as plt\n", 16 | "import matplotlib.image as mpimg\n", 17 | "\n", 18 | "import numpy as np\n", 19 | "import pickle\n", 20 | "import json\n", 21 | "from torchvision.utils import make_grid\n", 22 | "import torch\n", 23 | "from hypothesis_generation.prefix_postfix import PrefixPostfix\n", 24 | "from hypothesis_generator import GrammarExpander\n", 25 | "from reduce_and_process_hypotheses import MetaDatasetExample\n", 26 | "from reduce_and_process_hypotheses import HypothesisEval\n", 27 | "\n", 28 | "%matplotlib inline\n", 29 | "%load_ext autoreload\n", 30 | "%autoreload 2" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 2, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "properties_file = \"/private/home/ramav/code/ad_hoc_categories/concept_data/clevr_typed_fol_properties.json\"\n", 40 | "grammar_expander = (\"/private/home/ramav/code/ad_hoc_categories\"\n", 41 | " \"/concept_data/temp_data/v2_typed_simple_fol_clevr_typed_fol_properties.pkl\")\n", 42 | "program_converter = PrefixPostfix(\n", 43 | " properties_file, grammar_expander_file=grammar_expander\n", 44 | ")" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 3, 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "def visualize_example(query_and_support, data_idx, split_name=\"train\", nrows=5, filter_or=True,\n", 54 | " nimages_per_row=5):\n", 55 | " # TODO(ramav): nimages per row is not something that should be left as a free variable to supply. 
\n", 56 | " _TOTAL_NEGATIVES=20\n", 57 | " \n", 58 | " hypothesis = query_and_support[split_name].hypothesis\n", 59 | " positive_images_per_row = 5\n", 60 | " \n", 61 | " fig = plt.figure(figsize=(10, 20))\n", 62 | " for idx, image_path in enumerate(query_and_support[split_name].image_paths):\n", 63 | " ax = plt.subplot(1, positive_images_per_row, idx + 1)\n", 64 | " img = mpimg.imread(image_path)\n", 65 | " plt.imshow(img)\n", 66 | " plt.axis('off')\n", 67 | " if idx == 0:\n", 68 | " ax.set_title(\"%d] : %s\" % (data_idx, program_converter.postfix_to_prefix(hypothesis)))\n", 69 | " \n", 70 | " \n", 71 | " if filter_or == True:\n", 72 | " idx_to_print = [idx for idx, x in enumerate(\n", 73 | " query_and_support[split_name].alternate_hypothesis_str) if \" or \" not in x]\n", 74 | " \n", 75 | " else:\n", 76 | " idx_to_print = range(len(query_and_support[split_name].alternate_hypothesis_str))\n", 77 | " \n", 78 | " if len(query_and_support[split_name].alternate_hypothesis_str) * _TOTAL_NEGATIVES != len(\n", 79 | " query_and_support[split_name].image_paths_negative):\n", 80 | " print(\"%d ratio\" % (len(query_and_support[split_name].image_paths_negative)\n", 81 | " / len(query_and_support[split_name].alternate_hypothesis_str)))\n", 82 | " raise ValueError\n", 83 | " \n", 84 | " printed_idx_count = 0\n", 85 | " nrows = len(idx_to_print)\n", 86 | " fig = plt.figure(figsize=(10, 4 + 2 * nrows)) \n", 87 | "\n", 88 | " \n", 89 | " for idx, image_path in enumerate(query_and_support[split_name].image_paths_negative):\n", 90 | " if (int(idx/_TOTAL_NEGATIVES) not in idx_to_print) or (idx % _TOTAL_NEGATIVES >= nimages_per_row):\n", 91 | " continue\n", 92 | " \n", 93 | " ax = plt.subplot(nrows, nimages_per_row, printed_idx_count + 1)\n", 94 | " img = mpimg.imread(image_path)\n", 95 | " plt.imshow(img)\n", 96 | " plt.axis('off')\n", 97 | " if printed_idx_count % nimages_per_row == 0:\n", 98 | " ax.set_title(\"%d] : %s\" % (data_idx, program_converter.postfix_to_prefix(\n", 99 | " query_and_support[split_name].alternate_hypothesis_str[int(idx/_TOTAL_NEGATIVES)])))\n", 100 | " \n", 101 | " printed_idx_count += 1" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": 6, 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "dataset_json = (\"${CURI_DATA_PATH}/\"\n", 111 | " \"hypotheses/v2_typed_simple_fol_depth_6_trials_2000000_ban_1_max_scene_id_200/\"\n", 112 | " \"comp_sampling_log_linear_test_threshold_0.10_pos_im_5_neg_im_20_train_examples_\"\n", 113 | " \"500000_neg_type_alternate_hypotheses_alternate_hypo_1_random_seed_42.pkl\")" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 7, 119 | "metadata": {}, 120 | "outputs": [], 121 | "source": [ 122 | "with open(dataset_json, 'rb') as f:\n", 123 | " dataset = pickle.load(f)\n", 124 | " dataset = dataset[\"meta_dataset\"]" 125 | ] 126 | }, 127 | { 128 | "cell_type": "markdown", 129 | "metadata": {}, 130 | "source": [ 131 | "### Qualitative Results ###\n", 132 | "\n", 133 | "NOTE: Only showing alternate hypotheses for which there is no \"or\" clause, so that we are able to look at and focus on that subset. 
Location origin (0, 0) is at the top left corner, and the image is the bottom right quadrant.\n", 134 | "\n", 135 | "#### List of properties present in the dataset ####\n", 136 | " \n", 137 | "\"COUNTS\": [1, 2, 3],\n", 138 | "\n", 139 | "\"COLOR\": [\"gray\",\n", 140 | " \"red\",\n", 141 | " \"blue\",\n", 142 | " \"green\",\n", 143 | " \"brown\",\n", 144 | " \"purple\",\n", 145 | " \"cyan\",\n", 146 | " \"yellow\"]\n", 147 | " \n", 148 | "\"SHAPE\": \n", 149 | " [\n", 150 | " \"cube\",\n", 151 | " \"sphere\",\n", 152 | " \"cylinder\"\n", 153 | " ],\n", 154 | " \n", 155 | "\"MATERIAL\": \n", 156 | " [\n", 157 | " \"rubber\",\n", 158 | " \"metal\"\n", 159 | " ],\n", 160 | " \n", 161 | "\"SIZE\":\n", 162 | " [\n", 163 | " \"large\",\n", 164 | " \"small\"\n", 165 | " ],\n", 166 | " \n", 167 | "\"LOCX\":\n", 168 | " [\n", 169 | " \"1\",\n", 170 | " \"2\",\n", 171 | " \"3\",\n", 172 | " \"4\",\n", 173 | " \"5\",\n", 174 | " \"6\",\n", 175 | " \"7\",\n", 176 | " \"8\"\n", 177 | " ],\n", 178 | "\"LOCY\":\n", 179 | " [\n", 180 | " \"1\",\n", 181 | " \"2\",\n", 182 | " \"3\",\n", 183 | " \"4\",\n", 184 | " \"5\",\n", 185 | " \"6\",\n", 186 | " \"7\",\n", 187 | " \"8\"\n", 188 | " ],\n", 189 | " \n", 190 | "#### Format ####\n", 191 | "There is example number < original query> followed by images for that concept, and then example number followed by alternate images for the other concepts which explain the images corresponding to the " 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": 13, 197 | "metadata": { 198 | "scrolled": false 199 | }, 200 | "outputs": [ 201 | { 202 | "ename": "AttributeError", 203 | "evalue": "'MetaDatasetExample' object has no attribute 'image_paths_positive'", 204 | "output_type": "error", 205 | "traceback": [ 206 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 207 | "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", 208 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mthis_idx\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m10\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mbreak\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mvisualize_example\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdatum\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mthis_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"support\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 209 | "\u001b[0;32m\u001b[0m in \u001b[0;36mvisualize_example\u001b[0;34m(query_and_support, data_idx, split_name, nrows, filter_or, nimages_per_row)\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0mfig\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfigure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfigsize\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m20\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0midx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mimage_path\u001b[0m \u001b[0;32min\u001b[0m 
\u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mquery_and_support\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0msplit_name\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimage_paths_positive\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 11\u001b[0m \u001b[0max\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msubplot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpositive_images_per_row\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0midx\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0mimg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmpimg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimread\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimage_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 210 | "\u001b[0;31mAttributeError\u001b[0m: 'MetaDatasetExample' object has no attribute 'image_paths_positive'" 211 | ] 212 | }, 213 | { 214 | "data": { 215 | "text/plain": [ 216 | "
" 217 | ] 218 | }, 219 | "metadata": {}, 220 | "output_type": "display_data" 221 | } 222 | ], 223 | "source": [ 224 | "for this_idx, datum in enumerate(dataset):\n", 225 | " if this_idx > 10:\n", 226 | " break\n", 227 | " visualize_example(datum, this_idx, \"support\")" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": null, 233 | "metadata": {}, 234 | "outputs": [], 235 | "source": [] 236 | } 237 | ], 238 | "metadata": { 239 | "kernelspec": { 240 | "display_name": "Python 3", 241 | "language": "python", 242 | "name": "python3" 243 | }, 244 | "language_info": { 245 | "codemirror_mode": { 246 | "name": "ipython", 247 | "version": 3 248 | }, 249 | "file_extension": ".py", 250 | "mimetype": "text/x-python", 251 | "name": "python", 252 | "nbconvert_exporter": "python", 253 | "pygments_lexer": "ipython3", 254 | "version": "3.7.1" 255 | } 256 | }, 257 | "nbformat": 4, 258 | "nbformat_minor": 2 259 | } 260 | -------------------------------------------------------------------------------- /notebooks/visualize_meta_dataset.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%matplotlib inline\n", 10 | "%load_ext autoreload\n", 11 | "%autoreload 2\n", 12 | "\n", 13 | "import sys\n", 14 | "sys.path.append('..')\n", 15 | "\n", 16 | "from PIL import Image as PImage\n", 17 | "import matplotlib\n", 18 | "import matplotlib.pyplot as plt\n", 19 | "import matplotlib.gridspec as gridspec\n", 20 | "import matplotlib.image as mpimg\n", 21 | "\n", 22 | "import numpy as np\n", 23 | "import pickle\n", 24 | "import json\n", 25 | "from torchvision.utils import make_grid\n", 26 | "import torch\n", 27 | "from hypothesis_generation.prefix_postfix import PrefixPostfix\n", 28 | "from hypothesis_generation.hypothesis_utils import GrammarExpander\n", 29 | "from hypothesis_generation.hypothesis_utils import MetaDatasetExample\n", 30 | "from hypothesis_generation.hypothesis_utils import HypothesisEval\n", 31 | "from dataloaders.utils import ImageAccess\n", 32 | "\n", 33 | "from third_party.image_utils import plot_images" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 2, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "properties_file = \"/private/home/ramav/code/ad_hoc_categories/concept_data/clevr_typed_fol_properties.json\"\n", 43 | "grammar_expander = (\"/private/home/ramav/code/ad_hoc_categories\"\n", 44 | " \"/concept_data/temp_data/v2_typed_simple_fol_clevr_typed_fol_properties.pkl\")\n", 45 | "program_converter = PrefixPostfix(\n", 46 | " properties_file, grammar_expander_file=grammar_expander\n", 47 | ")" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": null, 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "image_path_access = ImageAccess(root_dir=\"/checkpoint/ramav/adhoc_concept_data/adhoc_images_slurm_v0.2/images\",)" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 4, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "dataset_pkl = (\"/checkpoint/ramav/adhoc_concept_data/adhoc_images_slurm_v0.2/hypotheses/\"\n", 66 | " \"v2_typed_simple_fol_depth_6_trials_2000000_ban_1_max_scene_id_200/\"\n", 67 | " \"comp_sampling_log_linear_test_threshold_0.10_pos_im_5_neg_im_20_\"\n", 68 | " \"train_examples_500000_neg_type_alternate_hypotheses_alternate_hypo_1_random_seed_42.pkl\")\n", 69 | "\n", 70 | "with 
open(dataset_pkl, 'rb') as f:\n", 71 | " dataset = pickle.load(f)['meta_dataset']" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 145, 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "def load_image_list(image_id_list):\n", 81 | " images = np.array([mpimg.imread(image_path_access(x)) for x in image_id_list])\n", 82 | " return images\n", 83 | "\n", 84 | "def get_image_grid(episode):\n", 85 | " raw_data_ids = episode.raw_data_ids\n", 86 | " images = np.array([mpimg.imread(image_path_access(x)) for x in raw_data_ids])\n", 87 | " labels = episode.data_labels\n", 88 | " stacked_images = plot_images(images, n=5, orig_labels=labels, gt_labels=[1]*len(labels))/255.0\n", 89 | " return stacked_images\n", 90 | "\n", 91 | "def pretty_hypotheses(hyp_str):\n", 92 | " hyp_str = hyp_str.replace(\"exists=(\", \"any(\")\n", 93 | " hyp_str = hyp_str.replace(\"for-all=(\", \"all(\")\n", 94 | " hyp_str = hyp_str.replace(\"exists=x \\in S\", \"exists x in S\")\n", 95 | " hyp_str = hyp_str.replace(\"for-all=x \\in S\", \"for-all x in S\")\n", 96 | " hyp_str = hyp_str.replace(\"non-x-S\", \"S_{-x}\")\n", 97 | " hyp_str = hyp_str.replace(\"lambda\", \"\\lambda\")\n", 98 | " return hyp_str\n", 99 | "\n", 100 | "def visualize_example(idx, meta_dataset_example):\n", 101 | " def get_sorted_hypotheses(all_hypotheses, logprobs, top_k=5):\n", 102 | " top_k = min(len(logprobs), top_k)\n", 103 | " sorted_idx = np.argsort(-1 * logprobs)[:top_k]\n", 104 | " return [\"* %s {log-prob [%.3f]}\" %(pretty_hypotheses(program_converter.postfix_to_prefix(all_hypotheses[x])),\n", 105 | " float(logprobs[x])) for x in sorted_idx]\n", 106 | " \n", 107 | " gs1 = gridspec.GridSpec(2, 2)\n", 108 | " gs1.update(hspace=0.05)\n", 109 | " fig = plt.figure(figsize=(20, 15), edgecolor='b')\n", 110 | " plt.suptitle(\"Productive Concept: %s\" % (pretty_hypotheses(program_converter.postfix_to_prefix(\n", 111 | " meta_dataset_example['support'].hypothesis))), fontsize=20)\n", 112 | "\n", 113 | " plt.axis([0, 10, 0, 10])\n", 114 | "\n", 115 | " plt.subplot(gs1[0])\n", 116 | " plt.axis('off')\n", 117 | " plt.imshow(get_image_grid(meta_dataset_example['support']))\n", 118 | " plt.title('Support', fontsize=20)\n", 119 | "\n", 120 | " plt.subplot(gs1[1])\n", 121 | " plt.imshow(get_image_grid(meta_dataset_example['query']))\n", 122 | " plt.title('Query', fontsize=20)\n", 123 | " plt.axis('off')\n", 124 | "\n", 125 | " \n", 126 | " ax = plt.subplot(gs1[2])\n", 127 | " plt.axis('off')\n", 128 | " text = \"\\n\".join(\n", 129 | " [\"Valid Hypotheses\", \" \"] + get_sorted_hypotheses(\n", 130 | " meta_dataset_example['support'].all_valid_hypotheses,\n", 131 | " meta_dataset_example['support'].posterior_logprobs))\n", 132 | " \n", 133 | " negatives_come_from = list(set(meta_dataset_example['support'].alternate_hypotheses_for_positives).difference(\n", 134 | " meta_dataset_example['support'].all_valid_hypotheses\n", 135 | " ))\n", 136 | " \n", 137 | " max_negatives = min(5, len(negatives_come_from))\n", 138 | " negatives_come_from = negatives_come_from[:max_negatives]\n", 139 | "\n", 140 | " text += \"\\n\\n\" + \"\\n\".join(\n", 141 | " [\"Hypotheses for Hard Negatives\", \" \"] + [\"* %s\" % x for x in [ \n", 142 | " pretty_hypotheses(program_converter.postfix_to_prefix(\n", 143 | " x)) for x in negatives_come_from]])\n", 144 | " \n", 145 | " plt.text(x=0.1, y=.15\n", 146 | " , s=text, wrap=True, fontsize=20)\n", 147 | " #plt.xlabel(text)\n", 148 | "\n", 149 | " #f = ax.get_figure()\n", 150 | " fig.tight_layout()\n", 
151 | " fig.subplots_adjust(top=0.95)" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": 136, 157 | "metadata": {}, 158 | "outputs": [ 159 | { 160 | "data": { 161 | "text/plain": [ 162 | "['exists x in S =(2, count=(color?( S_{-x} ), cyan ) )',\n", 163 | " 'exists x in S >(locationY?( x ), 6 )',\n", 164 | " '=(count=(color?( S ), brown ), 3 )',\n", 165 | " '>(count=(locationX?( S ), 3 ), 2 )',\n", 166 | " 'any(locationY?( S ), 6 )',\n", 167 | " '=(1, count=(locationY?( S ), 7 ) )',\n", 168 | " '=(3, count=(locationY?( S ), 3 ) )',\n", 169 | " 'all(locationX?( S ), 2 )',\n", 170 | " 'exists x in S all(locationY?( S_{-x} ), 5 )',\n", 171 | " '=(2, count=(color?( S ), blue ) )',\n", 172 | " 'for-all x in S not( >(6, locationX?( x ) ) )',\n", 173 | " '=(count=(color?( S ), gray ), 2 )',\n", 174 | " '=(2, count=(color?( S ), gray ) )']" 175 | ] 176 | }, 177 | "execution_count": 136, 178 | "metadata": {}, 179 | "output_type": "execute_result" 180 | } 181 | ], 182 | "source": [ 183 | "### Most probable concepts under the prior#########\n", 184 | "hyp =['2 non-x-S color? cyan count= = lambda S. exists=',\n", 185 | " 'x locationY? 6 > lambda S. exists=',\n", 186 | " 'S color? brown count= 3 = lambda S.',\n", 187 | " 'S locationX? 3 count= 2 > lambda S.',\n", 188 | " 'S locationY? 6 exists= lambda S.',\n", 189 | " '1 S locationY? 7 count= = lambda S.',\n", 190 | " '3 S locationY? 3 count= = lambda S.',\n", 191 | " 'S locationX? 2 for-all= lambda S.',\n", 192 | " 'non-x-S locationY? 5 for-all= lambda S. exists=',\n", 193 | " '2 S color? blue count= = lambda S.',\n", 194 | " '6 x locationX? > not lambda S. for-all=',\n", 195 | " 'S color? gray count= 2 = lambda S.',\n", 196 | " '2 S color? gray count= = lambda S.',]\n", 197 | "[pretty_hypotheses(program_converter.postfix_to_prefix(x)) for x in hyp]" 198 | ] 199 | }, 200 | { 201 | "cell_type": "markdown", 202 | "metadata": {}, 203 | "source": [ 204 | "### Qualitative Results ###\n", 205 | "\n", 206 | "NOTE: Only showing alternate hypotheses for which there is no \"or\" clause, so that we are able to look at and focus on that subset. 
Location origin (0, 0) is at the top left corner, and the image is the bottom right quadrant.\n", 207 | "\n", 208 | "#### List of properties present in the dataset ####\n", 209 | " \n", 210 | "\"COUNTS\": [1, 2, 3],\n", 211 | "\n", 212 | "\"COLOR\": [\"gray\",\n", 213 | " \"red\",\n", 214 | " \"blue\",\n", 215 | " \"green\",\n", 216 | " \"brown\",\n", 217 | " \"purple\",\n", 218 | " \"cyan\",\n", 219 | " \"yellow\"]\n", 220 | " \n", 221 | "\"SHAPE\": \n", 222 | " [\n", 223 | " \"cube\",\n", 224 | " \"sphere\",\n", 225 | " \"cylinder\"\n", 226 | " ],\n", 227 | " \n", 228 | "\"MATERIAL\": \n", 229 | " [\n", 230 | " \"rubber\",\n", 231 | " \"metal\"\n", 232 | " ],\n", 233 | " \n", 234 | "\"SIZE\":\n", 235 | " [\n", 236 | " \"large\",\n", 237 | " \"small\"\n", 238 | " ],\n", 239 | " \n", 240 | "\"LOCX\":\n", 241 | " [\n", 242 | " \"1\",\n", 243 | " \"2\",\n", 244 | " \"3\",\n", 245 | " \"4\",\n", 246 | " \"5\",\n", 247 | " \"6\",\n", 248 | " \"7\",\n", 249 | " \"8\"\n", 250 | " ],\n", 251 | "\"LOCY\":\n", 252 | " [\n", 253 | " \"1\",\n", 254 | " \"2\",\n", 255 | " \"3\",\n", 256 | " \"4\",\n", 257 | " \"5\",\n", 258 | " \"6\",\n", 259 | " \"7\",\n", 260 | " \"8\"\n", 261 | " ],\n", 262 | " \n", 263 | "#### Format ####\n", 264 | "There is example number < original query> followed by images for that concept, and then example number followed by alternate images for the other concepts which explain the images corresponding to the " 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": 146, 270 | "metadata": { 271 | "scrolled": false 272 | }, 273 | "outputs": [], 274 | "source": [ 275 | "NUM_VIS = 20\n", 276 | "dataset = sorted(dataset, key=lambda x: np.random.randn(1))\n", 277 | "\n", 278 | "\n", 279 | "for idx, datum in enumerate(dataset):\n", 280 | " if idx + 1 < NUM_VIS:\n", 281 | " visualize_example(idx, datum)\n", 282 | " plt.tight_layout()\n", 283 | " plt.savefig('qualitative_%d.png' % (idx + 1))\n", 284 | " plt.close()" 285 | ] 286 | } 287 | ], 288 | "metadata": { 289 | "kernelspec": { 290 | "display_name": "Python 3", 291 | "language": "python", 292 | "name": "python3" 293 | }, 294 | "language_info": { 295 | "codemirror_mode": { 296 | "name": "ipython", 297 | "version": 3 298 | }, 299 | "file_extension": ".py", 300 | "mimetype": "text/x-python", 301 | "name": "python", 302 | "nbconvert_exporter": "python", 303 | "pygments_lexer": "ipython3", 304 | "version": "3.7.1" 305 | } 306 | }, 307 | "nbformat": 4, 308 | "nbformat_minor": 2 309 | } 310 | -------------------------------------------------------------------------------- /paths.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
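# RUN_DIR is where hydra writes training/eval outputs (see hydra.run.dir and
# hydra.sweep.dir in runs/hydra_train/.hydra/hydra.yaml); CURI_DATA_PATH should
# point to the root of the downloaded data and is read by the configs as
# ${env:CURI_DATA_PATH} (raw_data.data_dir).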
6 | export RUN_DIR="./runs/" 7 | export CURI_DATA_PATH="" 8 | -------------------------------------------------------------------------------- /runs/hydra_train/.hydra/config.yaml: -------------------------------------------------------------------------------- 1 | _data: 2 | batch_size: 8 3 | map_eval_batch_size: 8 4 | modality: image 5 | _modality: 6 | abs_pos_enc: true 7 | use_batch_norm_rel_net: true 8 | _mode: train 9 | _model: 10 | abs_pos_enc: true 11 | obj_fdim: 64 12 | pooling: concat 13 | pooling_init: xavier 14 | rel_pos_enc: false 15 | costly_loss: 16 | class: losses.MetaLearningMeanAveragePrecision 17 | name: map 18 | cross_split: ${data.num_positives}_0.10_hypotheses_heavy.json 19 | cross_split_hypothesis_image_mapping: ${data.num_positives}_0.10_image_hyp_mapping.json 20 | data: 21 | dataset: adhoc_concepts 22 | hypothesis_prior: log_linear 23 | map_eval_num_images_per_concept: 3 24 | negative_type: alternate_hypotheses 25 | num_negatives: 20 26 | num_positives: 5 27 | path: ${raw_data.data_dir}/hypotheses/${filetype} 28 | positive_threshold: 0.1 29 | split_type: comp 30 | train_examples: 500000 31 | data_args: 32 | class: dataloaders.get_dataloader.GetDataloader 33 | params: 34 | splits: ${splits} 35 | device: cuda 36 | eval_split_name: val 37 | evaluate_once: false 38 | filetype: v2_typed_simple_fol_depth_6_trials_2000000_ban_1_max_scene_id_200 39 | input_dim: 3, 160, 240 40 | job_replica: 0 41 | loss: 42 | class: losses.NegativeLogLikelihoodMultiTask 43 | name: nll 44 | params: 45 | alpha: 0.1 46 | num_classes: ${num_classes} 47 | pad_token_idx: -10 48 | mem_requirement: 80GB 49 | model: 50 | class: models.protonet.GetProtoNetModel 51 | name: protonet 52 | params: 53 | absolute_position_encoding_for_modality: ${_modality.abs_pos_enc} 54 | absolute_position_encoding_for_pooling: ${_model.abs_pos_enc} 55 | feature_dim: 256 56 | im_fg: true 57 | init_to_use_pooling: ${_model.pooling_init} 58 | input_dim: ${input_dim} 59 | language_alpha: ${loss.params.alpha} 60 | modality: ${_data.modality} 61 | num_classes: ${num_classes} 62 | obj_fdim: ${_model.obj_fdim} 63 | pairwise_position_encoding: ${_model.rel_pos_enc} 64 | pooling: ${_model.pooling} 65 | pretrained_encoder: false 66 | use_batch_norm_rel_net: ${_modality.use_batch_norm_rel_net} 67 | model_or_oracle_metrics: model 68 | num_classes: 2 69 | opt: 70 | checkpoint_every: 30000 71 | lr: 0.0001 72 | lr_gamma: 0.5 73 | lr_patience: 10 74 | max_steps: 1000000 75 | num_workers: 10 76 | weight_decay: false 77 | qualitative: qualitative_eval_inputs_for_hierarchy.pkl 78 | raw_data: 79 | audio_path: ${raw_data.data_dir}"/sound_scenes" 80 | data_dir: ${env:CURI_DATA_PATH} 81 | image_path: ${raw_data.data_dir}"/images" 82 | json_path: ${raw_data.data_dir}"/scenes" 83 | properties_file_path: ${env:PWD}/concept_data/clevr_typed_fol_properties.json 84 | splits: train & val 85 | test: ${data.split_type}_sampling_${data.hypothesis_prior}_test_threshold_0.10_pos_im_${data.num_positives}_neg_im_${data.num_negatives}_train_examples_${data.train_examples}_neg_type_${data.negative_type}_alternate_hypo_1_random_seed_42.pkl 86 | train: ${data.split_type}_sampling_${data.hypothesis_prior}_train_threshold_0.10_pos_im_${data.num_positives}_neg_im_${data.num_negatives}_train_examples_${data.train_examples}_neg_type_${data.negative_type}_alternate_hypo_1_random_seed_42.pkl 87 | val: 
${data.split_type}_sampling_${data.hypothesis_prior}_val_threshold_0.10_pos_im_${data.num_positives}_neg_im_${data.num_negatives}_train_examples_${data.train_examples}_neg_type_${data.negative_type}_alternate_hypo_1_random_seed_42.pkl 88 | -------------------------------------------------------------------------------- /runs/hydra_train/.hydra/hydra.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | help: 3 | app_name: ${hydra.job.name} 4 | footer: 'Powered by Hydra (https://cli.dev) 5 | 6 | Use --hydra-help to view Hydra specific help 7 | 8 | ' 9 | header: '${hydra.help.app_name} is powered by Hydra. 10 | 11 | ' 12 | template: '${hydra.help.header} 13 | 14 | == Configuration groups == 15 | 16 | Compose your configuration from those groups (group=option) 17 | 18 | 19 | $APP_CONFIG_GROUPS 20 | 21 | 22 | == Config == 23 | 24 | Override anything in the config (foo.bar=value) 25 | 26 | 27 | $CONFIG 28 | 29 | 30 | ${hydra.help.footer} 31 | 32 | ' 33 | hydra_help: 34 | template: 'Hydra (${hydra.runtime.version}) 35 | 36 | See https://cli.dev for more info. 37 | 38 | 39 | == Flags == 40 | 41 | $FLAGS_HELP 42 | 43 | 44 | == Configuration groups == 45 | 46 | Compose your configuration from those groups (For example, append hydra/job_logging=disabled 47 | to command line) 48 | 49 | 50 | $HYDRA_CONFIG_GROUPS 51 | 52 | 53 | Use ''--cfg hydra'' to Show the Hydra config.' 54 | hydra_logging: 55 | disable_existing_loggers: false 56 | formatters: 57 | simple: 58 | format: '[%(asctime)s][HYDRA] %(message)s' 59 | handlers: 60 | console: 61 | class: logging.StreamHandler 62 | formatter: simple 63 | stream: ext://sys.stdout 64 | loggers: 65 | logging_example: 66 | level: DEBUG 67 | root: 68 | handlers: 69 | - console 70 | level: INFO 71 | version: 1 72 | job: 73 | config: 74 | override_dirname: 75 | exclude_keys: 76 | - job_replica 77 | - mode 78 | - opt.max_steps 79 | - opt.checkpoint_every 80 | - model_or_oracle_metrics 81 | - eval_cfg.evaluate_once 82 | - val 83 | - splits 84 | - eval_split_name 85 | - test 86 | - train 87 | - eval_cfg.write_raw_metrics 88 | - eval_cfg.evaluate_all 89 | - eval_cfg.best_test_metric 90 | - eval_cfg.sort_validation_checkpoints 91 | item_sep: ',' 92 | kv_sep: '=' 93 | config_file: experiment.yaml 94 | id: ??? 95 | name: hydra_train 96 | num: ??? 
97 | override_dirname: '' 98 | job_logging: 99 | disable_existing_loggers: false 100 | formatters: 101 | simple: 102 | format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s' 103 | handlers: 104 | console: 105 | class: logging.StreamHandler 106 | formatter: simple 107 | stream: ext://sys.stdout 108 | file: 109 | class: logging.FileHandler 110 | filename: ${hydra.job.name}.log 111 | formatter: simple 112 | root: 113 | handlers: 114 | - console 115 | - file 116 | level: INFO 117 | version: 1 118 | launcher: 119 | class: hydra._internal.core_plugins.basic_launcher.BasicLauncher 120 | params: 121 | queue_parameters: 122 | slurm: 123 | additional_parameters: 124 | constraint: '' 125 | job_name: ${hydra.job.name}_${_mode} 126 | max_num_timeout: 10 127 | mem: ${mem_requirement} 128 | partition: learnfair 129 | time: 4300 130 | output_subdir: .hydra 131 | overrides: 132 | hydra: [] 133 | task: 134 | - mode=train 135 | run: 136 | dir: ${env:RUN_DIR}/${hydra.job.name}/${hydra.job.override_dirname} 137 | runtime: 138 | cwd: /private/home/ramav/code/productive_concept_learning 139 | version: 0.11.0 140 | sweep: 141 | dir: ${env:RUN_DIR}/${hydra.job.name} 142 | subdir: ${hydra.job.override_dirname}/${job_replica} 143 | sweeper: 144 | class: hydra._internal.core_plugins.basic_sweeper.BasicSweeper 145 | verbose: false 146 | -------------------------------------------------------------------------------- /runs/hydra_train/.hydra/overrides.yaml: -------------------------------------------------------------------------------- 1 | - mode=train 2 | -------------------------------------------------------------------------------- /runs/hydra_train/events.out.tfevents.1625066218.devfair0241: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/productive_concept_learning/d00a3b85215bff0686ad3f834d50c8eae374f06a/runs/hydra_train/events.out.tfevents.1625066218.devfair0241 -------------------------------------------------------------------------------- /scripts/pick_best_checkpoints.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | r"""Code to pick best checkpoints from the validation runs. 7 | 8 | Assumes a nested folder structure for the runs as follows: 9 | 10 | sweep_folder/ 11 | job_1/ 12 | 0/ 13 | validation_results.txt 14 | 1/ 15 | validation_results.txt 16 | job_2/ 17 | 0/ 18 | validation_results.txt 19 | 1/ 20 | validation_results.txt 21 | 22 | The user needs to specify the sweep folder, and the script augments each job 23 | folder job_1/0, job_1/1 etc. with the details of the best checkpoint. 24 | Output of the script is a file like: 25 | sweep_folder/job_1/0/checkpoint_best_modelmap.pth 26 | 27 | which denotes the best validation checkpoint based on the metric of interest, 28 | here, modelmap. 29 | """ 30 | import argparse 31 | 32 | import logging 33 | import pandas as pd 34 | import shutil 35 | import os 36 | 37 | from glob import glob 38 | 39 | 40 | def compute_best_valid(job, metric, greater_is_better): 41 | all_txt_files = glob(job + "/*.txt") 42 | useful_txt_files = [] 43 | 44 | for this_txt_file in all_txt_files: 45 | # Skip / remove any validation metrics dumped during training. 
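# Only result files whose names contain "val" are kept below; all other
# .txt files in the job folder are ignored.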
46 | if "val" in this_txt_file: 47 | useful_txt_files.append(this_txt_file) 48 | del all_txt_files 49 | 50 | if len(useful_txt_files) == 0: 51 | logging.warning(f"No eval files found for job {job}") 52 | return 53 | 54 | result_table = None 55 | for idx, this_txt_file in enumerate(useful_txt_files): 56 | table = pd.read_csv(this_txt_file) 57 | table.columns = ["step", "metric", "value"] 58 | 59 | if idx == 0: 60 | result_table = table 61 | else: 62 | result_table = result_table.append(table) 63 | 64 | result_table = result_table[result_table["metric"] == metric] 65 | 66 | if len(result_table) == 0: 67 | logging.warning(f"{metric} not found in result files for {job}") 68 | return 69 | 70 | metric_to_print = metric.replace("/", "_") 71 | 72 | best_checkpoint_idx = (result_table.sort_values( 73 | "value", ascending=greater_is_better)["step"].iloc[-1]) 74 | best_checkpoint = os.path.join(job, 75 | f"checkpoint_{best_checkpoint_idx}.pth") 76 | target_checkpoint_path = os.path.join(job, 77 | f"checkpoint_best_{metric_to_print}.pth") 78 | shutil.copy(best_checkpoint, target_checkpoint_path) 79 | 80 | 81 | def main(args): 82 | all_jobs = glob(args.sweep_folder + "/*") 83 | 84 | folders = [] 85 | for job in all_jobs: 86 | folders.extend(glob(job + "/*")) 87 | del all_jobs 88 | 89 | for job in folders: 90 | compute_best_valid(job, args.val_metric, args.greater_is_better) 91 | print(f"Done with {job}") 92 | 93 | 94 | if __name__ == "__main__": 95 | parser = argparse.ArgumentParser() 96 | 97 | parser.add_argument("--sweep_folder", 98 | type=str, 99 | help="Folder which contains all the jobs.") 100 | parser.add_argument("--val_metric", 101 | default="modelmap", 102 | type=str, 103 | help="Metric to use for picking best checkpoint.") 104 | parser.add_argument("--greater_is_better", type=int, default=1, 105 | help="If greater is better for the chosen val metric," 106 | "1 if true, 0 if false.") 107 | 108 | args = parser.parse_args() 109 | 110 | if args.greater_is_better == 1: 111 | args.greater_is_better = True 112 | else: 113 | args.greater_is_better = False 114 | 115 | main(args) -------------------------------------------------------------------------------- /third_party/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /third_party/image_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # Copyright 2017 Google Inc. 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 
19 | 20 | """A few handy utilities for showing or saving images. 21 | 22 | plot_images lays a set of images out on a grid, with optional labels. 23 | 24 | show_image displays a single image, with optional border color. 25 | """ 26 | 27 | from __future__ import absolute_import 28 | from __future__ import division 29 | from __future__ import print_function 30 | 31 | import matplotlib as mpl 32 | mpl.use('Agg') 33 | import matplotlib.pyplot as plt # pylint: disable=g-import-not-at-top 34 | import numpy as np 35 | 36 | 37 | def get_color(orig_label=None, 38 | gt_label=None, 39 | target_label=None, 40 | adv_label=None, 41 | targeted_attack=False): 42 | """Get color according to different possible correctness combinations.""" 43 | if ((gt_label is None or orig_label == gt_label) and 44 | (adv_label is None or adv_label == gt_label)): 45 | color = '#023603' # 'drak-green': correct prediction 46 | elif adv_label is None: 47 | color = '#b81e26' # 'red': not adversarial, incorrect prediction 48 | elif adv_label == gt_label: 49 | color = '#a8db78' # 'lightgreen': adv_label correct, but orig_label wrong 50 | elif adv_label == orig_label: 51 | color = '#b09b3e' # 'yellow': incorrect, but no change to prediction 52 | elif not targeted_attack and adv_label is not None and adv_label != gt_label: 53 | color = '#b81e26' # 'red': untargeted attack, adv_label changed, success 54 | elif targeted_attack and adv_label == target_label: 55 | color = '#b81e26' # 'red': targeted attack, success 56 | elif targeted_attack and adv_label != target_label: 57 | color = '#d064a6' # 'pink': targeted attack, changed prediction but failed 58 | else: 59 | color = '#6780d8' # 'blue': unaccounted-for result 60 | return color 61 | 62 | 63 | def plot_images(images, 64 | n=3, 65 | figure_width=8, 66 | filename='', 67 | filetype='png', # pylint: disable=unused-argument 68 | orig_labels=None, 69 | gt_labels=None, 70 | target_labels=None, 71 | adv_labels=None, 72 | targeted_attack=False, 73 | skip_incorrect=False, 74 | blank_incorrect=False, 75 | blank_adv_correct=False, 76 | color_from_orig_label_if_blank_target=False, 77 | annotations=None, 78 | annotate_only_first_image=True, 79 | text_alpha=1.0, 80 | return_fig=False): 81 | """Plot images in a tight grid with optional labels.""" 82 | is_adv = adv_labels is not None 83 | has_labels = ( 84 | orig_labels is not None and gt_labels is not None and 85 | (not is_adv or not targeted_attack or target_labels is not None)) 86 | 87 | orig_labels = (orig_labels 88 | if orig_labels is not None else np.zeros(images.shape[0])) 89 | gt_labels = (gt_labels 90 | if gt_labels is not None else np.zeros(images.shape[0])) 91 | target_labels = (target_labels 92 | if target_labels is not None else np.zeros(images.shape[0])) 93 | adv_labels = (adv_labels 94 | if adv_labels is not None else np.zeros(images.shape[0])) 95 | 96 | annotations = (annotations if annotations is not None else 97 | [[] for _ in range(images.shape[0])]) 98 | 99 | w = n 100 | h = int(np.ceil(images.shape[0] / float(w))) 101 | ih, iw = images.shape[1:3] 102 | channels = (images.shape[-1] if len(images.shape) > 3 else 1) 103 | 104 | figure_height = h * ih * figure_width / (iw * float(w)) 105 | dpi = max(float(iw * w) / figure_width, float(ih * h) / figure_height) 106 | fig = plt.figure(figsize=(figure_width, figure_height), dpi=dpi) 107 | ax = fig.gca() 108 | 109 | minc = np.min(images) 110 | if minc < 0: 111 | # Assume images is in [-1, 1], and transform it back to [0, 1] for display 112 | images = images * 0.5 + 0.5 113 | 114 | figure 
= np.ones((ih * h, iw * w, channels)) * 0.5 115 | if channels == 1: 116 | figure = figure.squeeze(-1) 117 | 118 | ax.set_frame_on(False) 119 | ax.set_xticks([]) 120 | ax.set_yticks([]) 121 | ax.axis('off') 122 | fig.tight_layout(pad=0) 123 | 124 | cur_image = 0 125 | for i in range(h): 126 | for j in range(w): 127 | if i * w + j >= images.shape[0]: 128 | # White out overflow slots, but not slots from skipping images. 129 | figure[i * ih:(i + 1) * ih, j * iw:(j + 1) * iw] = 1.0 130 | continue 131 | 132 | try: 133 | image = None 134 | while image is None: 135 | image = images[cur_image] 136 | orig_label = orig_labels[cur_image] 137 | gt_label = gt_labels[cur_image] 138 | target_label = target_labels[cur_image] 139 | adv_label = adv_labels[cur_image] 140 | annots = annotations[cur_image] 141 | cur_image += 1 142 | 143 | if has_labels and orig_label != gt_label: 144 | if skip_incorrect: 145 | image = None 146 | elif blank_incorrect: 147 | image *= 0.2 148 | 149 | if (image is not None and has_labels and is_adv and blank_adv_correct 150 | and adv_label == gt_label): 151 | image *= 0.0 152 | 153 | except IndexError: 154 | continue 155 | 156 | image = image.reshape([ih, iw, channels]) 157 | if channels == 1: 158 | image = image.squeeze(-1) 159 | 160 | figure[i * ih:(i + 1) * ih, j * iw:(j + 1) * iw] = image 161 | 162 | if annotate_only_first_image and j == 0: 163 | annot_this = True 164 | elif annotate_only_first_image is False: 165 | annot_this = True 166 | else: 167 | annot_this = False 168 | annot_this &= annots is not None 169 | 170 | if annot_this: 171 | for k, a in enumerate(annots): 172 | ax.annotate( 173 | str(a['label']), 174 | xy=(float(j) / w + 0.05 / w, 175 | 1.0 - float(i) / h - (0.05 + k * 0.05) / h), 176 | xycoords='axes fraction', 177 | color=a['color'], 178 | verticalalignment='bottom', 179 | alpha=text_alpha, 180 | fontsize=(0.12 * 72 * figure_height / h)) 181 | if has_labels: 182 | color = get_color(orig_label, gt_label, 183 | None, None, False) 184 | ax.annotate( 185 | str(orig_label), 186 | xy=(float(j) / w + 0.05 / w, 1.0 - float(i) / h - 0.05 / h), 187 | xycoords='axes fraction', 188 | color=color, 189 | verticalalignment='top', 190 | alpha=text_alpha, 191 | fontsize=(0.30 * 72 * figure_height / h) 192 | ) 193 | 194 | if is_adv: 195 | if target_label or not color_from_orig_label_if_blank_target: 196 | color = get_color(orig_label, gt_label, target_label, adv_label, 197 | targeted_attack and target_label) 198 | s = str(adv_label) 199 | if targeted_attack and target_label: 200 | s = '%s|%s' % (str(adv_label), str(target_label)) 201 | ax.annotate( 202 | s, 203 | xy=(float(j) / w + 0.05 / w, 1.0 - float(i) / h - 0.18 / h), 204 | xycoords='axes fraction', 205 | color=color, 206 | verticalalignment='top', 207 | alpha=text_alpha, 208 | fontsize=(0.12 * 72 * figure_height / h) 209 | ) 210 | 211 | if channels == 1: 212 | fig.figimage(figure, cmap='Greys') 213 | else: 214 | fig.figimage(figure) 215 | 216 | figure = canvas_to_np(fig) 217 | 218 | if filename: 219 | with open('{filename}.{filetype}'.format(**locals()), 'w') as f: 220 | fig.savefig(f, dpi='figure') 221 | else: 222 | fig.show(warn=False) 223 | 224 | if return_fig: 225 | return fig 226 | 227 | plt.close(fig) 228 | 229 | return figure 230 | 231 | 232 | def show_image( 233 | image, 234 | color=None, 235 | filename=None, 236 | filetype='png', # pylint: disable=unused-argument 237 | show=True, 238 | return_fig=False): 239 | """Draw a single image as a stand-alone figure with an optional border.""" 240 | image = image[0] if 
image.ndim > 3 else image 241 | w = image.shape[1] 242 | h = image.shape[0] 243 | c = 3 if color else image.shape[-1] 244 | 245 | pad = int(w * 0.05) if color else 0 246 | wpad = 2 * pad + w 247 | hpad = 2 * pad + h 248 | 249 | figure_width = 2.0 250 | figure_height = figure_width * hpad / float(wpad) 251 | dpi = wpad / figure_width 252 | fig = plt.figure(figsize=(figure_width, figure_height), dpi=dpi) 253 | ax = fig.gca() 254 | 255 | minc = np.min(image) 256 | if minc < 0: 257 | # Assume image is in [-1, 1], and transform it back to [0, 1] for display 258 | image = image * 0.5 + 0.5 259 | 260 | ax.set_frame_on(False) 261 | ax.set_xticks([]) 262 | ax.set_yticks([]) 263 | ax.axis('off') 264 | fig.tight_layout(pad=0) 265 | 266 | color = mpl.colors.colorConverter.to_rgb(color) if color else 0.0 267 | 268 | figure = np.ones((hpad, wpad, c)) * color 269 | if c == 1: 270 | figure = figure.squeeze(-1) 271 | 272 | figure[pad:pad + h, pad:pad + w] = image 273 | 274 | if c == 1: 275 | fig.figimage(figure, cmap='Greys') 276 | else: 277 | fig.figimage(figure) 278 | 279 | print(figure.shape) 280 | 281 | figure = canvas_to_np(fig) 282 | 283 | if filename: 284 | with open('{filename}.{filetype}'.format(**locals()), 'w') as f: 285 | fig.savefig(f, dpi='figure') 286 | 287 | if show: 288 | fig.show(warn=False) 289 | 290 | if return_fig: 291 | return fig 292 | 293 | plt.close(fig) 294 | 295 | return figure 296 | 297 | 298 | def canvas_to_np(figure, rescale=False): 299 | """Turn a pyplt figure into an np image array of bytes or floats.""" 300 | figure.canvas.draw() 301 | 302 | # Collect the pixels back from the pyplot figure. 303 | buff, (width, height) = figure.canvas.print_to_buffer() 304 | img = np.frombuffer(buff, np.uint8).reshape(height, width, -1) 305 | if img.shape[-1] > 3: 306 | img = img[:, :, 0:3] 307 | 308 | img = img.astype(np.float32) 309 | if rescale: 310 | img /= 255.0 311 | 312 | return img -------------------------------------------------------------------------------- /third_party/mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | class MLP(nn.Module): 11 | def __init__(self, 12 | input_dim, 13 | hidden_dims, 14 | output_dim, 15 | tanh=False, 16 | nonlinear=False, 17 | batch_norm=False): 18 | super(MLP, self).__init__() 19 | self.input_dim = input_dim 20 | self.hidden_dims = hidden_dims 21 | self.output_dim = output_dim 22 | self.nonlinear = nonlinear 23 | self.use_batch_norm = batch_norm 24 | 25 | self.linears = nn.ModuleList( 26 | [nn.Linear(self.input_dim, self.hidden_dims[0])]) 27 | if self.use_batch_norm: 28 | self.batchnorms = nn.ModuleList([ 29 | nn.BatchNorm1d(self.hidden_dims[i]) 30 | for i in range(len(self.hidden_dims)) 31 | ]) 32 | 33 | for i in range(1, len(self.hidden_dims)): 34 | self.linears.append( 35 | nn.Linear(self.hidden_dims[i - 1], self.hidden_dims[i])) 36 | self.linears.append(nn.Linear(self.hidden_dims[-1], self.output_dim)) 37 | 38 | if tanh: 39 | self.activation = torch.tanh 40 | else: 41 | self.activation = torch.relu 42 | 43 | def forward(self, x): 44 | 45 | x = self.linears[0](x) 46 | x = self.activation(x) 47 | if self.use_batch_norm: 48 | x = self.batchnorms[0](x) 49 | 50 | for i in range(1, len(self.hidden_dims)): 51 | x = self.linears[i](x) 52 | x = self.activation(x) 53 | if self.use_batch_norm and ((i == len(self.hidden_dims) - 1) or 54 | (i == len(self.hidden_dims) - 2)): 55 | x = self.batchnorms[i](x) 56 | 57 | out = self.linears[-1](x) 58 | if self.nonlinear: 59 | out = self.activation(out) 60 | 61 | return out -------------------------------------------------------------------------------- /third_party/relation_network.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """Relation Network code for few-shot learning experiments. 
7 | 8 | Code from https://github.com/AndreaCossu/Relation-Network-PyTorch/ 9 | """ 10 | import torch 11 | import torch.nn as nn 12 | import numpy as np 13 | from third_party.mlp import MLP 14 | from itertools import product 15 | 16 | REDUCTION_FACTOR = 2 17 | 18 | 19 | def get_relative_spatial_feature(w, h, spatial_grid_feat_zero_center=True): 20 | h = int(h) 21 | w = int(w) 22 | d_sq = h * w 23 | d_four = d_sq*d_sq 24 | spatial_grid_feat = torch.zeros(1, 2, d_four) 25 | ctr = 0 26 | for o in range(d_sq): 27 | for i in range(d_sq): 28 | o_h = o // w 29 | o_w = o % w 30 | i_h = i // w 31 | i_w = i % w 32 | spatial_grid_feat[:, 0, ctr] = o_h - i_h 33 | spatial_grid_feat[:, 1, ctr] = o_w - i_w 34 | ctr += 1 35 | if spatial_grid_feat_zero_center: 36 | spatial_grid_feat -= (spatial_grid_feat.mean()) 37 | 38 | return spatial_grid_feat 39 | 40 | def get_relative_1d_feature(w, spatial_grid_feat_zero_center=True): 41 | w = int(w) 42 | spatial_grid_feat = torch.zeros(1, 1, w*w) 43 | 44 | ctr = 0 45 | for o in range(w): 46 | for i in range(w): 47 | spatial_grid_feat[:, 0, ctr] = o - i 48 | 49 | ctr += 1 50 | 51 | if spatial_grid_feat_zero_center: 52 | spatial_grid_feat -= spatial_grid_feat.mean() 53 | 54 | return spatial_grid_feat 55 | 56 | 57 | 58 | class RelationNetwork(nn.Module): 59 | def __init__(self, obj_fdim, num_objects_list, 60 | hidden_dims_g, output_dim_g, hidden_dims_f, 61 | output_dim_f, 62 | tanh, use_batch_norm_f=False, relative_position_encoding=True, 63 | image_concat_reduce=False): 64 | 65 | super(RelationNetwork, self).__init__() 66 | 67 | self.relative_position_encoding = relative_position_encoding 68 | self.obj_fdim = obj_fdim 69 | self.image_concat_reduce = image_concat_reduce 70 | self.num_objects_list = num_objects_list 71 | self.input_dim_g = 2 * self.obj_fdim # g analyzes pairs of objects 72 | self.hidden_dims_g = hidden_dims_g 73 | self.output_dim_g = output_dim_g 74 | self.input_dim_f = self.output_dim_g 75 | 76 | self.hidden_dims_f = hidden_dims_f 77 | self.output_dim_f = output_dim_f 78 | 79 | self.reduction_factor = 1 80 | if self.image_concat_reduce == True: 81 | self.reduction_factor = REDUCTION_FACTOR 82 | 83 | self.input_dim_g = self.input_dim_g * (self.reduction_factor)**2 84 | 85 | if self.relative_position_encoding == True: 86 | self.input_dim_g = self.input_dim_g + len(self.num_objects_list) 87 | 88 | self.g = MLP(self.input_dim_g, 89 | self.hidden_dims_g, 90 | self.output_dim_g, 91 | tanh=tanh, 92 | nonlinear=True, 93 | batch_norm=False) 94 | 95 | # Different from the original paper, we replace / drop dropout layers 96 | # since models never seem to overfit in our setting. 97 | self.f = MLP(self.input_dim_f, 98 | self.hidden_dims_f, 99 | self.output_dim_f, 100 | tanh=tanh, 101 | batch_norm=use_batch_norm_f) 102 | 103 | if self.relative_position_encoding == True: 104 | if len(self.num_objects_list) == 2: 105 | self.register_buffer( 106 | 'relative_grid_feat', 107 | get_relative_spatial_feature( 108 | self.num_objects_list[0] / self.reduction_factor, 109 | self.num_objects_list[1] / self.reduction_factor)) 110 | elif len(self.num_objects_list) == 1: 111 | self.register_buffer( 112 | 'relative_grid_feat', 113 | get_relative_1d_feature( 114 | self.num_objects_list[0] / self.reduction_factor)) 115 | else: 116 | raise ValueError("Num object dimensions must be 2 or 1.") 117 | 118 | 119 | def forward(self, x, q=None): 120 | """Forward the relation network model. 
121 | 122 | Args: 123 | x: A `Tensor` of [B x C x H x W] or [B x C x L]; C is the object 124 | q: A `Tensor` of [B x C x H x W] or [B x C x L]; C is the object 125 | Returns: 126 | A `Tensor` of [B x D] 127 | """ 128 | n_objects = np.prod(self.num_objects_list) 129 | 130 | # Reshape to [B x C x n_objects] 131 | x = x.view(x.size(0), x.size(1), n_objects) 132 | 133 | if self.image_concat_reduce == True: 134 | if len(self.num_objects_list) == 2: 135 | x = x.view(x.size(0), x.size(1), self.num_objects_list[0], 136 | self.num_objects_list[1]) 137 | # TODO(ramav): Remove this hardcoding. 138 | x = x.permute(0, 2, 3, 1) # [B x H x W x C] 139 | x = x.contiguous() 140 | if x.size(1) % self.reduction_factor != 0 or x.size(2) % self.reduction_factor != 0: 141 | raise ValueError("Expect feature width and height to " 142 | "be divisible by 2.") 143 | x = x.view(x.size(0), int(x.size(1)/self.reduction_factor), int(x.size(2)/self.reduction_factor 144 | ), -1) 145 | n_objects = int(n_objects / self.reduction_factor**2) 146 | x = x.permute(0, 3, 1, 2) 147 | x = x.contiguous() 148 | x = x.view(x.size(0), x.size(1), n_objects) 149 | else: 150 | raise NotImplementedError("Concat reduce is only implemented " 151 | "for image models.") 152 | xi = x.repeat(1, 1, n_objects) # [B x C x n_objects * n_objects] 153 | xj = x.unsqueeze(3) # [B x C x n_objects x 1] 154 | xj = xj.repeat(1, 1, 1, n_objects) # [B x C x n_objects x n_objects] 155 | xj = xj.view(x.size(0), x.size(1), 156 | -1) # [B x C x n_objects * n_objects] 157 | if q is not None: 158 | raise NotImplementedError 159 | 160 | pair_concat = torch.cat((xi, xj), 161 | dim=1) # (B, 2*C, n_objects * n_objects) 162 | 163 | if self.relative_position_encoding == True: 164 | pair_concat = torch.cat([ 165 | pair_concat, 166 | self.relative_grid_feat.repeat( 167 | pair_concat.size(0), 1, 1) 168 | ], dim=1) 169 | 170 | # MLP will take as input [B , n_objects * n_objects, 2*C] 171 | pair_concat = pair_concat.permute(0, 2, 1) 172 | relations = self.g(pair_concat.reshape( 173 | -1, pair_concat.size(2))) # (n_objects*n_objects, hidden_dim_g) 174 | relations = relations.view(pair_concat.size(0), pair_concat.size(1), 175 | self.output_dim_g) 176 | 177 | embedding = torch.sum(relations, dim=1) # (B x hidden_dim_g) 178 | 179 | out = self.f(embedding) # (B x hidden_dim_f) 180 | 181 | return out 182 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /utils/checkpointing.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | import copy 7 | import glob 8 | import numpy as np 9 | import os 10 | import random 11 | import torch 12 | 13 | from typing import Any, Dict, Optional, Union, Type 14 | from torch import nn, optim 15 | 16 | 17 | class CheckpointManager(object): 18 | r""" 19 | A :class:`CheckpointManager` periodically serializes models and optimizer as .pth files during 20 | training, and keeps track of best performing checkpoint based on an observed metric. 21 | 22 | Extended Summary 23 | ---------------- 24 | It saves state dicts of models and optimizer as ``.pth`` files in a specified directory. This 25 | class closely follows the API of PyTorch optimizers and learning rate schedulers. 26 | 27 | Notes 28 | ----- 29 | For :class:`~torch.nn.DataParallel` objects, ``.module.state_dict()`` is called instead of 30 | ``.state_dict()``. 31 | 32 | Parameters 33 | ---------- 34 | models: Dict[str, torch.nn.Module] 35 | Models which need to be serialized as a checkpoint. 36 | optimizer: torch.optim.Optimizer 37 | Optimizer which needs to be serialized as a checkpoint. 38 | serialization_dir: str 39 | Path to an empty or non-existent directory to save checkpoints. 40 | mode: str, optional (default="max") 41 | One of ``min``, ``max``. In ``min`` mode, best checkpoint will be recorded when metric 42 | hits a lower value; in `max` mode it will be recorded when metric hits a higher value. 43 | filename_prefix: str, optional (default="checkpoint") 44 | Prefix of the to-be-saved checkpoint files. 45 | 46 | Examples 47 | -------- 48 | >>> model = torch.nn.Linear(10, 2) 49 | >>> optimizer = torch.optim.Adam(model.parameters()) 50 | >>> ckpt_manager = CheckpointManager({"model": model}, optimizer, "/tmp/ckpt", mode="min") 51 | >>> num_epochs = 20 52 | >>> for epoch in range(num_epochs): 53 | ... train(model) 54 | ... val_loss = validate(model) 55 | ... ckpt_manager.step(val_loss, epoch) 56 | """ 57 | def __init__( 58 | self, 59 | models: Dict[str, nn.Module], 60 | optimizer: Type[optim.Optimizer], 61 | serialization_dir: str, 62 | mode: str = "max", 63 | filename_prefix: str = "checkpoint", 64 | ): 65 | r"""Initialize checkpoint manager.""" 66 | for key in models: 67 | if not isinstance(models[key], nn.Module): 68 | raise TypeError("{} is not a Module".format(type(models).__name__)) 69 | 70 | if not isinstance(optimizer, optim.Optimizer): 71 | raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) 72 | 73 | self._models = models 74 | self._optimizer = optimizer 75 | self._serialization_dir = serialization_dir 76 | 77 | self._mode = mode 78 | self._filename_prefix = filename_prefix 79 | 80 | # Initialize members to hold state dict of best checkpoint and its performance. 81 | self._best_metric: Optional[Union[float, torch.Tensor]] = None 82 | self._best_ckpt: Dict[str, Any] = {} 83 | 84 | def step(self, metric: Union[float, torch.Tensor], epoch_or_iteration: int): 85 | r"""Serialize checkpoint and update best checkpoint based on metric and mode.""" 86 | 87 | # Update best checkpoint based on metric and metric mode. 
88 | if self._best_metric is None: 89 | self._best_metric = metric 90 | 91 | models_state_dict: Dict[str, Any] = {} 92 | for key in self._models: 93 | if isinstance(self._models[key], nn.DataParallel): 94 | models_state_dict[key] = self._models[key].module.state_dict() 95 | else: 96 | models_state_dict[key] = self._models[key].state_dict() 97 | 98 | if (self._mode == "min" and metric < self._best_metric) or ( 99 | self._mode == "max" and metric > self._best_metric 100 | ): 101 | self._best_metric = metric 102 | self._best_ckpt = copy.copy(models_state_dict) 103 | 104 | # Serialize checkpoint corresponding to current epoch (or iteration). 105 | torch.save( 106 | {**models_state_dict, "optimizer": self._optimizer.state_dict()}, 107 | os.path.join( 108 | self._serialization_dir, f"{self._filename_prefix}_{epoch_or_iteration}.pth" 109 | ), 110 | ) 111 | # Serialize best performing checkpoint observed so far. By default, 112 | # the best checkpoint is saved as "_best". 113 | torch.save( 114 | self._best_ckpt, 115 | os.path.join(self._serialization_dir, f"{self._filename_prefix}_best.pth"), 116 | ) 117 | 118 | def best_checkpoint(self, based_on_metric: str = ""): 119 | r"""Returns the best checkpoint. 120 | 121 | Defaults to returning the best checkpoint stored by step() above when 122 | based_on_metric is not provided. Otherwise, looks for a file such as 123 | `checkpoint_best_modelmap.pth`, e.g. when `based_on_metric` is 124 | `modelmap`. 125 | """ 126 | # If based_on_metric is like modelmetrics/acc, then read it as 127 | # modelmetrics_acc 128 | based_on_metric = based_on_metric.replace("/", "_") 129 | best_checkpoint_path = os.path.join( 130 | self._serialization_dir, "%s.pth" % ("_".join( 131 | [self._filename_prefix, "best", based_on_metric])).rstrip("_")) 132 | 133 | if not os.path.exists(best_checkpoint_path): 134 | raise ValueError( 135 | f"Best checkpoint based on {based_on_metric} does not exist.") 136 | 137 | return best_checkpoint_path, None 138 | 139 | @property 140 | def latest_checkpoint(self): 141 | all_checkpoint_epochs_or_iterations = glob.glob( 142 | os.path.join(self._serialization_dir, 143 | f"{self._filename_prefix}_*.pth")) 144 | 145 | # TODO(ramav): This is a bit brittle, replace the "best" check with 146 | # an int check. 147 | all_checkpoint_epochs_or_iterations = [ 148 | int(x.split("_")[-1].split(".")[0]) 149 | for x in all_checkpoint_epochs_or_iterations if 'best' not in x 150 | ] 151 | 152 | if len(all_checkpoint_epochs_or_iterations) != 0: 153 | latest_epoch_or_iteration = np.max( 154 | all_checkpoint_epochs_or_iterations) 155 | return os.path.join( 156 | self._serialization_dir, 157 | f"{self._filename_prefix}_{latest_epoch_or_iteration}.pth" 158 | ), latest_epoch_or_iteration 159 | return None, -1 160 | 161 | def all_checkpoints(self, sort_iterations: bool = False, 162 | random_shuffle: bool = False): 163 | if sort_iterations is True and random_shuffle is True: 164 | raise ValueError("Only one of sorted or random can be true.") 165 | 166 | all_checkpoint_epochs_or_iterations = glob.glob( 167 | os.path.join(self._serialization_dir, 168 | f"{self._filename_prefix}_*.pth")) 169 | 170 | # TODO(ramav): This is a bit brittle, replace the "best" check with 171 | # an int check. 172 | all_checkpoint_epochs_or_iterations = [ 173 | int(x.split("_")[-1].split(".")[0]) 174 | for x in all_checkpoint_epochs_or_iterations if 'best' not in x 175 | ] 176 | 177 | # Sort iterations in increasing order, so that when we pop we 178 | # pick the latest iteration. 
179 | if sort_iterations == True: 180 | all_checkpoint_epochs_or_iterations = sorted( 181 | all_checkpoint_epochs_or_iterations) 182 | elif random_shuffle == True: 183 | random.shuffle(all_checkpoint_epochs_or_iterations) 184 | 185 | if len(all_checkpoint_epochs_or_iterations) != 0: 186 | return [(os.path.join( 187 | self._serialization_dir, f"{self._filename_prefix}_{x}.pth"), x) 188 | for x in all_checkpoint_epochs_or_iterations] 189 | 190 | return None, -1 191 | -------------------------------------------------------------------------------- /utils/visualize.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """Visualization routine for few-shot meta learning model.""" 7 | import torch 8 | import numpy as np 9 | 10 | from protonets.third_party.image_utils import plot_images 11 | 12 | 13 | def save_visuals_for_fewshot_model(samples, outputs, num_samples=4): 14 | """Save visualizations for a few-shot prototypical network. 15 | 16 | Args: 17 | samples: A dict provided during forward pass to the model. 18 | outputs: A list of items returned from the model. 19 | outputs[0]: negative log-probabilities over labels for the query examples. 20 | outputs[1]: ground truth labels for the query examples. 21 | """ 22 | query_images = samples["query_images"].detach()[:num_samples] 23 | num_images_row = query_images.shape[1] 24 | 25 | query_images_concat = torch.reshape( 26 | query_images, (-1, query_images.shape[-3], query_images.shape[-2], 27 | query_images.shape[-1])).numpy() 28 | # Make images N x H x W x C from N x C x H x W 29 | query_images_concat = np.swapaxes(query_images_concat, 1, 3) 30 | query_images_concat = np.swapaxes(query_images_concat, 1, 2) 31 | pred_labels = np.argmin( 32 | torch.reshape(outputs["neg_log_p_y"][:num_samples].detach(), 33 | (-1, outputs["neg_log_p_y"].shape[-1])).numpy(), 34 | axis=-1) 35 | 36 | targets = torch.reshape(outputs[1][:num_samples].detach(), [-1]).numpy() 37 | 38 | annotations = [] 39 | for idx in range(num_samples): 40 | for row_idx in range(num_images_row): 41 | annotations.append( 42 | [ 43 | {"label": samples["hypothesis_string"][idx], "color": "#000801"}, 44 | ] 45 | ) 46 | 47 | viz_images = plot_images( 48 | query_images_concat, 49 | n=num_images_row, 50 | gt_labels=targets, 51 | annotations=annotations, 52 | orig_labels=pred_labels) 53 | 54 | return viz_images.astype(np.uint8) --------------------------------------------------------------------------------
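
The tensor bookkeeping inside RelationNetwork.forward (flattening the H x W feature grid into objects, pairing every object with every other, summing g over the pairs, then applying f) is easiest to see with concrete shapes. The following is a minimal, illustrative sketch rather than anything shipped in the repository: it assumes the repository root is on PYTHONPATH, and the batch size and layer widths (64-dim objects on a 4 x 4 grid, the g and f hidden sizes) are made-up values, not the ones used in the released configs.

import torch

from third_party.relation_network import RelationNetwork

# A 4 x 4 grid of spatial cells, each described by a 64-dim feature vector
# (obj_fdim). With the default relative_position_encoding=True, a relative
# (dy, dx) offset feature is appended to every object pair before g.
model = RelationNetwork(
    obj_fdim=64,
    num_objects_list=[4, 4],
    hidden_dims_g=[128, 128],
    output_dim_g=128,
    hidden_dims_f=[64],
    output_dim_f=32,
    tanh=False)

x = torch.randn(8, 64, 4, 4)   # [B x C x H x W]
out = model(x)                 # forms all 16 x 16 ordered pairs, sums g(.), applies f(.)
print(out.shape)               # torch.Size([8, 32]) == [B x output_dim_f]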
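
In the same spirit, the interplay between CheckpointManager and scripts/pick_best_checkpoints.py can be sketched end to end: step() writes checkpoint_<iteration>.pth plus a rolling checkpoint_best.pth, the script later copies the best validation checkpoint into each job folder as checkpoint_best_<metric>.pth, and best_checkpoint() resolves that file by metric name. The directory, the training loop, and the stand-in metric values below are invented for illustration (again assuming the repository root is on PYTHONPATH); the script additionally expects validation result *.txt files that this sketch does not create.

import os

import torch
from utils.checkpointing import CheckpointManager

serialization_dir = "/tmp/sweep_folder/job_1/0"  # illustrative job folder
os.makedirs(serialization_dir, exist_ok=True)

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters())
ckpt_manager = CheckpointManager({"model": model}, optimizer,
                                 serialization_dir, mode="max")

for iteration in range(3):
    val_metric = 0.1 * iteration              # stand-in for a validation mAP
    ckpt_manager.step(val_metric, iteration)  # writes checkpoint_<iteration>.pth
                                              # and refreshes checkpoint_best.pth

# Once pick_best_checkpoints.py has copied checkpoint_best_modelmap.pth into the
# job folder, the manager can resolve it by metric name:
# best_path, _ = ckpt_manager.best_checkpoint(based_on_metric="modelmap")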