├── docs ├── .gitignore ├── source │ ├── img │ │ ├── layers.png │ │ ├── cog │ │ │ └── cog_sample.png │ │ ├── core_concepts.png │ │ ├── model_class_diagram.png │ │ ├── worker_class_diagram.png │ │ ├── worker_state_diagram.png │ │ ├── configuration_sections.png │ │ ├── problem_class_diagram.png │ │ ├── episode_sequence_diagram.png │ │ ├── parameter_registry_tree.png │ │ ├── worker_grid_class_diagram.png │ │ ├── grid_configuration_sections.png │ │ ├── inference_sequence_diagram.png │ │ ├── problem_class_diagram_wide.png │ │ ├── worker_basic_class_diagram.png │ │ ├── algorithmic │ │ │ ├── psychometric_tests.png │ │ │ ├── recall │ │ │ │ ├── scratch_pad.png │ │ │ │ ├── serial_recall.png │ │ │ │ └── repeat_serial_recall.png │ │ │ ├── exemplary_sequence_data.png │ │ │ ├── problem_taxonomy_simple.png │ │ │ ├── problem_taxonomy_complex.png │ │ │ ├── exemplary_sequence_markers.png │ │ │ ├── dual_comparison │ │ │ │ ├── sequence_equality.png │ │ │ │ ├── sequence_symmetry.png │ │ │ │ └── sequence_comparison.png │ │ │ ├── manipulation_temporal │ │ │ │ ├── skip_recall.png │ │ │ │ ├── reverse_recall.png │ │ │ │ ├── repeat_reverse_recall.png │ │ │ │ └── manipulation_temporal_rotation.png │ │ │ └── manipulation_spatial │ │ │ │ ├── manipulation_spatial_not.png │ │ │ │ └── manipulation_spatial_rotation.png │ │ └── initialization_sequence_diagram.png │ ├── tutorials │ │ └── 1_lenet_mnist.rst │ ├── _static │ │ └── theme_overrides.css │ ├── api_reference │ │ ├── 6_helpers.rst │ │ ├── 4_workers.rst │ │ ├── 5_grid_workers.rst │ │ └── 1_utils.rst │ ├── mip_primer │ │ ├── 6_helpers.rst │ │ ├── 5_grid_workers.rst │ │ ├── 3_models_explained.rst │ │ └── 2_problems_explained.rst │ ├── index.rst │ ├── research │ │ └── 1_introduction.rst │ └── notes │ │ └── 1_installation.rst ├── Makefile ├── build_docs.sh └── make.bat ├── configs ├── cog │ ├── cog_cog.yaml │ ├── cog_mac.yaml │ ├── cog_generate.yaml │ └── default_cog.yaml ├── maes_baselines │ ├── lstm │ │ ├── default_lstm.yaml │ │ ├── 
lstm_reverse_recall.yaml │ │ ├── lstm_serial_recall.yaml │ │ ├── lstm_sequence_comparison.yaml │ │ ├── lstm_sequence_symmetry.yaml │ │ ├── lstm_sequence_equality.yaml │ │ └── lstm_odd_recall.yaml │ ├── es_lstm │ │ ├── default_es_lstm.yaml │ │ ├── es_lstm_serial_recall.yaml │ │ ├── es_lstm_reverse_recall.yaml │ │ ├── es_lstm_sequence_equality.yaml │ │ ├── es_lstm_sequence_symmetry.yaml │ │ ├── es_lstm_sequence_comparison.yaml │ │ └── es_lstm_odd_recall.yaml │ ├── ntm │ │ ├── ntm_reverse_recall.yaml │ │ ├── ntm_serial_recall.yaml │ │ ├── ntm_sequence_comparison.yaml │ │ ├── ntm_sequence_equality.yaml │ │ ├── ntm_sequence_symmetry.yaml │ │ ├── default_ntm.yaml │ │ └── ntm_odd_recall.yaml │ ├── dnc │ │ ├── dnc_serial_recall.yaml │ │ ├── dnc_reverse_recall.yaml │ │ ├── dnc_sequence_equality.yaml │ │ ├── dnc_sequence_symmetry.yaml │ │ ├── dnc_sequence_comparison.yaml │ │ ├── default_dnc.yaml │ │ ├── dnc_odd_recall.yaml │ │ └── dnc_default_training.yaml │ ├── maes │ │ ├── default_maes.yaml │ │ ├── maes_serial_recall.yaml │ │ ├── maes_sequence_equality.yaml │ │ ├── maes_sequence_symmetry.yaml │ │ ├── maes_reverse_recall.yaml │ │ ├── maes_sequence_comparison.yaml │ │ ├── maes_odd_recall.yaml │ │ ├── maes_repeat_reverse_recall.yaml │ │ └── maes_repeat_serial_recall.yaml │ ├── default_training.yaml │ ├── es_ntm │ │ └── es_ntm_serial_recall.yaml │ ├── default_problem.yaml │ └── maes_grid_training.yaml ├── dwm_baselines │ ├── lstm │ │ ├── default_lstm.yaml │ │ ├── distraction_forget.yaml │ │ ├── distraction_ignore.yaml │ │ ├── serial_recall.yaml │ │ ├── reverse_recall.yaml │ │ ├── operation_span.yaml │ │ ├── manipulation_spatial_rotation.yaml │ │ ├── reading_span.yaml │ │ └── scratch_pad.yaml │ ├── dwm │ │ ├── reading_span.yaml │ │ ├── interruption_not.yaml │ │ ├── scratch_pad.yaml │ │ ├── reverse_recall.yaml │ │ ├── serial_recall.yaml │ │ ├── distraction_forget.yaml │ │ ├── distraction_ignore.yaml │ │ ├── manipulation_spatial_not.yaml │ │ ├── default_dwm.yaml │ │ ├── 
interruption_reverse_recall.yaml │ │ ├── distraction_carry.yaml │ │ ├── operation_span.yaml │ │ ├── manipulation_spatial_rotation.yaml │ │ ├── manipulation_temporal_swap.yaml │ │ └── interruption_swap_recall.yaml │ ├── dnc │ │ ├── reading_span.yaml │ │ ├── scratch_pad.yaml │ │ ├── distraction_forget.yaml │ │ ├── reverse_recall.yaml │ │ ├── serial_recall.yaml │ │ ├── distraction_ignore.yaml │ │ ├── operation_span.yaml │ │ ├── manipulation_spatial_rotation.yaml │ │ ├── default_dnc.yaml │ │ ├── dnc_default_settings_simple_task.yaml │ │ └── dnc_default_settings_dual_task.yaml │ ├── default_settings_simple_task.yaml │ └── default_settings_dual_task.yaml ├── mac │ ├── s_mac_clevr.yaml │ ├── s_mac_cogent.yaml │ ├── mac_clevr.yaml │ ├── mac_cogent.yaml │ ├── default_cogent.yaml │ ├── default_clevr.yaml │ ├── mac_smac_cogent_a_finetuning.yaml │ └── mac_smac_cogent_b_finetuning.yaml ├── ntm │ ├── serial_recall_no_mask.yaml │ └── serial_recall.yaml ├── problem_inits │ └── COG_dataset_generator.yaml ├── vqa_baselines │ ├── cnn_lstm │ │ ├── clevr.yaml │ │ ├── vqa_med.yaml │ │ ├── sort_of_clevr.yaml │ │ └── shape_color_query.yaml │ ├── multi_hops_stacked_attention_networks │ │ └── shape_color_query.yaml │ ├── stacked_attention_networks │ │ ├── sort_of_clevr.yaml │ │ └── shape_color_query.yaml │ ├── default_vqa_med.yaml │ ├── default_sort_of_clevr.yaml │ ├── default_shape_color_query.yaml │ └── default_clevr.yaml ├── vision │ ├── alexnet_mnist.yaml │ ├── alexnet_cifar10.yaml │ ├── lenet5_mnist.yaml │ ├── grid_trainer_mnist.yaml │ ├── simplecnn_cifar10.yaml │ └── simplecnn_mnist.yaml ├── mental_model │ └── mm_cog.yaml ├── thalnet │ ├── seq_pixel_mnist.yaml │ └── serial_recall.yaml ├── relational_net │ └── sort_of_clevr.yaml └── text2text │ └── translation.yaml ├── miprometheus ├── helpers │ └── __init__.py ├── problems │ ├── seq_to_seq │ │ ├── video_text_to_class │ │ │ └── cog │ │ │ │ └── cog_utils │ │ │ │ └── roboto.ttf │ │ └── seq_to_seq_problem.py │ └── problem_factory.py ├── 
__init__.py ├── workers │ └── __init__.py ├── grid_workers │ └── __init__.py ├── models │ ├── controllers │ │ ├── __init__.py │ │ ├── feedforward_controller.py │ │ ├── gru_controller.py │ │ ├── ffgru_controller.py │ │ ├── lstm_controller.py │ │ ├── rnn_controller.py │ │ └── controller_factory.py │ ├── __init__.py │ ├── mental_model │ │ └── ops.py │ ├── mac │ │ ├── utils_mac.py │ │ ├── image_encoding.py │ │ └── output_unit.py │ ├── dnc │ │ ├── plot_data.py │ │ └── memory.py │ ├── VWM_model │ │ ├── utils_VWM.py │ │ └── thought_unit.py │ ├── vision │ │ └── lenet5.py │ ├── vqa_baselines │ │ └── stacked_attention_networks │ │ │ └── image_encoding.py │ └── dwm │ │ └── memory.py └── utils │ ├── singleton.py │ ├── __init__.py │ └── split_indices.py ├── .readthedocs.yml ├── requirements.txt ├── .gitignore ├── .github ├── ISSUE_TEMPLATE │ ├── feature_request.md │ └── bug_report.md └── PULL_REQUEST_TEMPLATE │ └── PULL_REQUEST_TEMPLATE.md └── .travis.yml /docs/.gitignore: -------------------------------------------------------------------------------- 1 | build/ 2 | _build/ 3 | source/python.rst 4 | -------------------------------------------------------------------------------- /docs/source/img/layers.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/mi-prometheus/HEAD/docs/source/img/layers.png -------------------------------------------------------------------------------- /docs/source/img/cog/cog_sample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/mi-prometheus/HEAD/docs/source/img/cog/cog_sample.png -------------------------------------------------------------------------------- /docs/source/img/core_concepts.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/mi-prometheus/HEAD/docs/source/img/core_concepts.png 
-------------------------------------------------------------------------------- /docs/source/img/model_class_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/mi-prometheus/HEAD/docs/source/img/model_class_diagram.png -------------------------------------------------------------------------------- /docs/source/img/worker_class_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/mi-prometheus/HEAD/docs/source/img/worker_class_diagram.png -------------------------------------------------------------------------------- /docs/source/img/worker_state_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/mi-prometheus/HEAD/docs/source/img/worker_state_diagram.png -------------------------------------------------------------------------------- /docs/source/img/configuration_sections.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/mi-prometheus/HEAD/docs/source/img/configuration_sections.png -------------------------------------------------------------------------------- /docs/source/img/problem_class_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/mi-prometheus/HEAD/docs/source/img/problem_class_diagram.png -------------------------------------------------------------------------------- /docs/source/img/episode_sequence_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/mi-prometheus/HEAD/docs/source/img/episode_sequence_diagram.png -------------------------------------------------------------------------------- /docs/source/img/parameter_registry_tree.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/mi-prometheus/HEAD/docs/source/img/parameter_registry_tree.png -------------------------------------------------------------------------------- /docs/source/img/worker_grid_class_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/mi-prometheus/HEAD/docs/source/img/worker_grid_class_diagram.png -------------------------------------------------------------------------------- /docs/source/img/grid_configuration_sections.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/mi-prometheus/HEAD/docs/source/img/grid_configuration_sections.png -------------------------------------------------------------------------------- /docs/source/img/inference_sequence_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/mi-prometheus/HEAD/docs/source/img/inference_sequence_diagram.png -------------------------------------------------------------------------------- /docs/source/img/problem_class_diagram_wide.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/mi-prometheus/HEAD/docs/source/img/problem_class_diagram_wide.png -------------------------------------------------------------------------------- /docs/source/img/worker_basic_class_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/mi-prometheus/HEAD/docs/source/img/worker_basic_class_diagram.png -------------------------------------------------------------------------------- /docs/source/img/algorithmic/psychometric_tests.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/IBM/mi-prometheus/HEAD/docs/source/img/algorithmic/psychometric_tests.png -------------------------------------------------------------------------------- /docs/source/img/algorithmic/recall/scratch_pad.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/mi-prometheus/HEAD/docs/source/img/algorithmic/recall/scratch_pad.png -------------------------------------------------------------------------------- /configs/cog/cog_cog.yaml: -------------------------------------------------------------------------------- 1 | # Load the COG config first. 2 | default_configs: 3 | configs/cog/default_cog.yaml 4 | 5 | model: 6 | name: CogModel 7 | 8 | -------------------------------------------------------------------------------- /docs/source/img/algorithmic/recall/serial_recall.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/mi-prometheus/HEAD/docs/source/img/algorithmic/recall/serial_recall.png -------------------------------------------------------------------------------- /docs/source/img/initialization_sequence_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/mi-prometheus/HEAD/docs/source/img/initialization_sequence_diagram.png -------------------------------------------------------------------------------- /docs/source/img/algorithmic/exemplary_sequence_data.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/mi-prometheus/HEAD/docs/source/img/algorithmic/exemplary_sequence_data.png -------------------------------------------------------------------------------- /docs/source/img/algorithmic/problem_taxonomy_simple.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/IBM/mi-prometheus/HEAD/docs/source/img/algorithmic/problem_taxonomy_simple.png -------------------------------------------------------------------------------- /docs/source/img/algorithmic/problem_taxonomy_complex.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/mi-prometheus/HEAD/docs/source/img/algorithmic/problem_taxonomy_complex.png -------------------------------------------------------------------------------- /docs/source/img/algorithmic/exemplary_sequence_markers.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/mi-prometheus/HEAD/docs/source/img/algorithmic/exemplary_sequence_markers.png -------------------------------------------------------------------------------- /docs/source/img/algorithmic/recall/repeat_serial_recall.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/mi-prometheus/HEAD/docs/source/img/algorithmic/recall/repeat_serial_recall.png -------------------------------------------------------------------------------- /configs/maes_baselines/lstm/default_lstm.yaml: -------------------------------------------------------------------------------- 1 | # Model parameters: 2 | model: 3 | name: LSTM 4 | # Hidden state. 
5 | hidden_state_size: 256 6 | num_layers: 1 7 | -------------------------------------------------------------------------------- /docs/source/img/algorithmic/dual_comparison/sequence_equality.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/mi-prometheus/HEAD/docs/source/img/algorithmic/dual_comparison/sequence_equality.png -------------------------------------------------------------------------------- /docs/source/img/algorithmic/dual_comparison/sequence_symmetry.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/mi-prometheus/HEAD/docs/source/img/algorithmic/dual_comparison/sequence_symmetry.png -------------------------------------------------------------------------------- /docs/source/img/algorithmic/manipulation_temporal/skip_recall.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/mi-prometheus/HEAD/docs/source/img/algorithmic/manipulation_temporal/skip_recall.png -------------------------------------------------------------------------------- /configs/dwm_baselines/lstm/default_lstm.yaml: -------------------------------------------------------------------------------- 1 | # Model parameters: 2 | model: 3 | name: LSTM 4 | # Controller parameters. 
5 | num_layers: 3 6 | hidden_state_size: 512 7 | -------------------------------------------------------------------------------- /docs/source/img/algorithmic/dual_comparison/sequence_comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/mi-prometheus/HEAD/docs/source/img/algorithmic/dual_comparison/sequence_comparison.png -------------------------------------------------------------------------------- /docs/source/img/algorithmic/manipulation_temporal/reverse_recall.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/mi-prometheus/HEAD/docs/source/img/algorithmic/manipulation_temporal/reverse_recall.png -------------------------------------------------------------------------------- /docs/source/img/algorithmic/manipulation_temporal/repeat_reverse_recall.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/mi-prometheus/HEAD/docs/source/img/algorithmic/manipulation_temporal/repeat_reverse_recall.png -------------------------------------------------------------------------------- /configs/maes_baselines/es_lstm/default_es_lstm.yaml: -------------------------------------------------------------------------------- 1 | # Model parameters: 2 | model: 3 | name: EncoderSolverLSTM 4 | # Controller hidden state. 
5 | hidden_state_size: 256 6 | num_layers: 1 7 | -------------------------------------------------------------------------------- /docs/source/img/algorithmic/manipulation_spatial/manipulation_spatial_not.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/mi-prometheus/HEAD/docs/source/img/algorithmic/manipulation_spatial/manipulation_spatial_not.png -------------------------------------------------------------------------------- /miprometheus/helpers/__init__.py: -------------------------------------------------------------------------------- 1 | # Helpers. 2 | from .index_splitter import IndexSplitter 3 | from .problem_initializer import ProblemInitializer 4 | 5 | __all__ = ['IndexSplitter', 'ProblemInitializer'] 6 | -------------------------------------------------------------------------------- /miprometheus/problems/seq_to_seq/video_text_to_class/cog/cog_utils/roboto.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/mi-prometheus/HEAD/miprometheus/problems/seq_to_seq/video_text_to_class/cog/cog_utils/roboto.ttf -------------------------------------------------------------------------------- /docs/source/img/algorithmic/manipulation_spatial/manipulation_spatial_rotation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/mi-prometheus/HEAD/docs/source/img/algorithmic/manipulation_spatial/manipulation_spatial_rotation.png -------------------------------------------------------------------------------- /docs/source/img/algorithmic/manipulation_temporal/manipulation_temporal_rotation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/mi-prometheus/HEAD/docs/source/img/algorithmic/manipulation_temporal/manipulation_temporal_rotation.png 
-------------------------------------------------------------------------------- /docs/source/tutorials/1_lenet_mnist.rst: -------------------------------------------------------------------------------- 1 | Basics: LeNet-5 on MNIST 2 | =========================== 3 | `@author: Tomasz Kornuta` 4 | 5 | This tutorial will be a step-by-step explanation of how to run LeNet-5 on MNIST using Mi-Prometheus. 6 | 7 | `Coming soon!` -------------------------------------------------------------------------------- /miprometheus/__init__.py: -------------------------------------------------------------------------------- 1 | from .grid_workers import * 2 | 3 | from .helpers import * 4 | 5 | from .models import * 6 | from .models.controllers import * 7 | 8 | from .problems import * 9 | 10 | from .utils import * 11 | from .workers import * 12 | -------------------------------------------------------------------------------- /configs/mac/s_mac_clevr.yaml: -------------------------------------------------------------------------------- 1 | # Load the CLEVR config first. 2 | default_configs: 3 | configs/mac/default_clevr.yaml 4 | 5 | # Add the model parameters: 6 | model: 7 | name: sMacNetwork 8 | dim: 512 9 | embed_hidden: 300 10 | max_step: 12 11 | dropout: 0.15 12 | -------------------------------------------------------------------------------- /configs/mac/s_mac_cogent.yaml: -------------------------------------------------------------------------------- 1 | # Load the CoGenT config first. 
2 | default_configs: 3 | configs/mac/default_cogent.yaml 4 | 5 | # Add the model parameters: 6 | model: 7 | name: sMacNetwork 8 | dim: 512 9 | embed_hidden: 300 10 | max_step: 12 11 | dropout: 0.15 12 | -------------------------------------------------------------------------------- /configs/ntm/serial_recall_no_mask.yaml: -------------------------------------------------------------------------------- 1 | default_configs: 2 | configs/ntm/serial_recall.yaml 3 | 4 | # Problem parameters: 5 | training: 6 | problem: 7 | use_mask: False 8 | 9 | validation: 10 | problem: 11 | use_mask: False 12 | 13 | testing: 14 | problem: 15 | use_mask: False 16 | -------------------------------------------------------------------------------- /configs/mac/mac_clevr.yaml: -------------------------------------------------------------------------------- 1 | # Load the CLEVR config first. 2 | default_configs: 3 | configs/mac/default_clevr.yaml 4 | 5 | # Add the model parameters: 6 | model: 7 | name: MACNetwork 8 | dim: 512 9 | embed_hidden: 300 10 | max_step: 12 11 | self_attention: False 12 | memory_gate: False 13 | dropout: 0.15 14 | -------------------------------------------------------------------------------- /configs/mac/mac_cogent.yaml: -------------------------------------------------------------------------------- 1 | # Load the CoGenT config first. 
2 | default_configs: 3 | configs/mac/default_cogent.yaml 4 | 5 | # Add the model parameters: 6 | model: 7 | name: MACNetwork 8 | dim: 512 9 | embed_hidden: 300 10 | max_step: 12 11 | self_attention: False 12 | memory_gate: False 13 | dropout: 0.15 14 | -------------------------------------------------------------------------------- /configs/problem_inits/COG_dataset_generator.yaml: -------------------------------------------------------------------------------- 1 | # Problem parameters: 2 | training: 3 | problem: 4 | name: COG 5 | data_folder: '~/data/cog' 6 | set: train 7 | dataset_type: generated 8 | generation: 9 | examples_per_task: 256 10 | sequence_length: 5 11 | memory_length: 4 12 | max_distractors: 2 13 | nr_processors: 6 14 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yml 2 | 3 | # build image specification 4 | build: 5 | image: latest 6 | 7 | # Python configuration 8 | python: 9 | version: 3.6 10 | setup_py_install: false 11 | use_system_site_packages: false 12 | 13 | #requirements_file: docs/requirements.txt 14 | 15 | # Don't build any extra formats 16 | formats: [] 17 | -------------------------------------------------------------------------------- /miprometheus/workers/__init__.py: -------------------------------------------------------------------------------- 1 | from .worker import Worker 2 | from .trainer import Trainer 3 | from .offline_trainer import OfflineTrainer 4 | from .online_trainer import OnlineTrainer 5 | from .tester import Tester 6 | 7 | __all__ = [ 8 | 'Worker', 9 | 'Trainer', 10 | 'OfflineTrainer', 11 | 'OnlineTrainer', 12 | 'Tester' 13 | ] 14 | -------------------------------------------------------------------------------- /configs/vqa_baselines/cnn_lstm/clevr.yaml: -------------------------------------------------------------------------------- 1 | # Load the CLEVR config 
first. 2 | default_configs: 3 | configs/vqa_baselines/default_clevr.yaml 4 | 5 | # Model parameters: 6 | model: 7 | name: CNN_LSTM 8 | lstm: 9 | hidden_size: 64 10 | num_layers: 1 11 | bidirectional: False 12 | dropout: 0 13 | classifier: 14 | nb_hidden_nodes: 256 15 | 16 | -------------------------------------------------------------------------------- /configs/vqa_baselines/cnn_lstm/vqa_med.yaml: -------------------------------------------------------------------------------- 1 | # Load the VQA_MED config first. 2 | default_configs: 3 | configs/vqa_baselines/default_vqa_med.yaml 4 | 5 | # Model parameters: 6 | model: 7 | name: CNN_LSTM 8 | lstm: 9 | hidden_size: 64 10 | num_layers: 1 11 | bidirectional: False 12 | dropout: 0 13 | classifier: 14 | nb_hidden_nodes: 256 15 | 16 | -------------------------------------------------------------------------------- /configs/vqa_baselines/cnn_lstm/sort_of_clevr.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/vqa_baselines/default_sort_of_clevr.yaml 4 | 5 | # Model parameters: 6 | model: 7 | name: CNN_LSTM 8 | lstm: 9 | hidden_size: 64 10 | num_layers: 1 11 | bidirectional: False 12 | dropout: 0 13 | classifier: 14 | nb_hidden_nodes: 256 15 | -------------------------------------------------------------------------------- /configs/vqa_baselines/cnn_lstm/shape_color_query.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 
2 | default_configs: 3 | configs/vqa_baselines/default_shape_color_query.yaml 4 | 5 | # Model parameters: 6 | model: 7 | name: CNN_LSTM 8 | lstm: 9 | hidden_size: 64 10 | num_layers: 1 11 | bidirectional: False 12 | dropout: 0 13 | classifier: 14 | nb_hidden_nodes: 256 15 | 16 | -------------------------------------------------------------------------------- /docs/source/_static/theme_overrides.css: -------------------------------------------------------------------------------- 1 | /* override table width restrictions */ 2 | @media screen and (min-width: 767px) { 3 | 4 | .wy-table-responsive table td { 5 | /* !important prevents the common CSS stylesheets from overriding 6 | this as on RTD they are loaded after this stylesheet */ 7 | white-space: normal !important; 8 | } 9 | 10 | .wy-table-responsive { 11 | overflow: visible !important; 12 | } 13 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | alabaster==0.7.12 2 | Babel==2.6.0 3 | certifi==2018.11.29 4 | chardet==3.0.4 5 | docutils==0.14 6 | idna==2.8 7 | imagesize==1.1.0 8 | Jinja2==2.10.1 9 | MarkupSafe==1.1.0 10 | packaging==19.0 11 | Pillow==5.4.1 12 | Pygments==2.3.1 13 | pyparsing==2.3.1 14 | pytz==2018.9 15 | requests==2.21.0 16 | six==1.12.0 17 | snowballstemmer==1.2.1 18 | Sphinx==1.8.4 19 | sphinx-rtd-theme==0.4.2 20 | sphinxcontrib-websupport==1.1.0 21 | urllib3==1.24.2 22 | -------------------------------------------------------------------------------- /configs/dwm_baselines/dwm/reading_span.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/dwm_baselines/dwm/default_dwm.yaml, 4 | configs/dwm_baselines/default_settings_dual_task.yaml 5 | 6 | # Then overwrite problem name(s). 
7 | training: 8 | problem: 9 | name: &name reading_span 10 | 11 | validation: 12 | problem: 13 | name: *name 14 | 15 | testing: 16 | problem: 17 | name: *name 18 | -------------------------------------------------------------------------------- /miprometheus/grid_workers/__init__.py: -------------------------------------------------------------------------------- 1 | # Grid workers. 2 | from .grid_worker import GridWorker 3 | from .grid_trainer_cpu import GridTrainerCPU 4 | from .grid_trainer_gpu import GridTrainerGPU 5 | from .grid_tester_cpu import GridTesterCPU 6 | from .grid_tester_gpu import GridTesterGPU 7 | from .grid_analyzer import GridAnalyzer 8 | 9 | 10 | __all__ = ['GridWorker', 'GridTrainerCPU', 'GridTrainerGPU', 11 | 'GridTesterCPU', 'GridTesterGPU', 'GridAnalyzer'] 12 | -------------------------------------------------------------------------------- /configs/dwm_baselines/dnc/reading_span.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/dwm_baselines/dnc/default_dnc.yaml, 4 | configs/dwm_baselines/dnc/dnc_default_settings_dual_task.yaml 5 | 6 | # Then overwrite problem name(s). 7 | training: 8 | problem: 9 | name: &name reading_span 10 | 11 | validation: 12 | problem: 13 | name: *name 14 | 15 | testing: 16 | problem: 17 | name: *name 18 | -------------------------------------------------------------------------------- /configs/dwm_baselines/dwm/interruption_not.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/dwm_baselines/dwm/default_dwm.yaml, 4 | configs/dwm_baselines/default_settings_dual_task.yaml 5 | 6 | # Then overwrite problem name(s). 
7 | training: 8 | problem: 9 | name: &name interruption_not 10 | 11 | validation: 12 | problem: 13 | name: *name 14 | 15 | testing: 16 | problem: 17 | name: *name 18 | -------------------------------------------------------------------------------- /configs/dwm_baselines/dwm/scratch_pad.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/dwm_baselines/dwm/default_dwm.yaml, 4 | configs/dwm_baselines/default_settings_dual_task.yaml 5 | 6 | # Then overwrite problem name(s). 7 | training: 8 | problem: 9 | name: &name ScratchPadCommandLines 10 | 11 | validation: 12 | problem: 13 | name: *name 14 | 15 | testing: 16 | problem: 17 | name: *name 18 | -------------------------------------------------------------------------------- /configs/dwm_baselines/dnc/scratch_pad.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/dwm_baselines/dnc/default_dnc.yaml, 4 | configs/dwm_baselines/dnc/dnc_default_settings_dual_task.yaml 5 | 6 | # Then overwrite problem name(s). 7 | training: 8 | problem: 9 | name: &name ScratchPadCommandLines 10 | 11 | validation: 12 | problem: 13 | name: *name 14 | 15 | testing: 16 | problem: 17 | name: *name 18 | -------------------------------------------------------------------------------- /configs/dwm_baselines/dwm/reverse_recall.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/dwm_baselines/dwm/default_dwm.yaml, 4 | configs/dwm_baselines/default_settings_simple_task.yaml 5 | 6 | # Then overwrite problem name(s). 
7 | training: 8 | problem: 9 | name: &name ReverseRecallCommandLines 10 | 11 | validation: 12 | problem: 13 | name: *name 14 | 15 | testing: 16 | problem: 17 | name: *name 18 | -------------------------------------------------------------------------------- /configs/dwm_baselines/dwm/serial_recall.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/dwm_baselines/dwm/default_dwm.yaml, 4 | configs/dwm_baselines/default_settings_simple_task.yaml 5 | 6 | # Then overwrite problem name(s). 7 | training: 8 | problem: 9 | name: &name SerialRecallCommandLines 10 | 11 | validation: 12 | problem: 13 | name: *name 14 | 15 | testing: 16 | problem: 17 | name: *name 18 | -------------------------------------------------------------------------------- /configs/dwm_baselines/lstm/distraction_forget.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/dwm_baselines/lstm/default_lstm.yaml, 4 | configs/dwm_baselines/default_settings_dual_task.yaml 5 | 6 | # Then overwrite problem name(s). 7 | training: 8 | problem: 9 | name: &name distraction_forget 10 | 11 | validation: 12 | problem: 13 | name: *name 14 | 15 | testing: 16 | problem: 17 | name: *name 18 | -------------------------------------------------------------------------------- /configs/dwm_baselines/lstm/distraction_ignore.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/dwm_baselines/lstm/default_lstm.yaml, 4 | configs/dwm_baselines/default_settings_dual_task.yaml 5 | 6 | # Then overwrite problem name(s). 
7 | training: 8 | problem: 9 | name: &name distraction_ignore 10 | 11 | validation: 12 | problem: 13 | name: *name 14 | 15 | testing: 16 | problem: 17 | name: *name 18 | -------------------------------------------------------------------------------- /configs/dwm_baselines/lstm/serial_recall.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/dwm_baselines/lstm/default_lstm.yaml, 4 | configs/dwm_baselines/default_settings_simple_task.yaml 5 | 6 | # Then overwrite problem name(s). 7 | training: 8 | problem: 9 | name: &name SerialRecallCommandLines 10 | 11 | validation: 12 | problem: 13 | name: *name 14 | 15 | testing: 16 | problem: 17 | name: *name 18 | -------------------------------------------------------------------------------- /configs/dwm_baselines/dwm/distraction_forget.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/dwm_baselines/dwm/default_dwm.yaml, 4 | configs/dwm_baselines/default_settings_dual_task.yaml 5 | 6 | # Then overwrite problem name(s). 7 | training: 8 | problem: 9 | name: &name distraction_forget 10 | 11 | validation: 12 | problem: 13 | name: *name 14 | 15 | testing: 16 | problem: 17 | name: *name 18 | 19 | -------------------------------------------------------------------------------- /configs/dwm_baselines/lstm/reverse_recall.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/dwm_baselines/lstm/default_lstm.yaml, 4 | configs/dwm_baselines/default_settings_simple_task.yaml 5 | 6 | # Then overwrite problem name(s). 
7 | training: 8 | problem: 9 | name: &name ReverseRecallCommandLines 10 | 11 | validation: 12 | problem: 13 | name: *name 14 | 15 | testing: 16 | problem: 17 | name: *name 18 | -------------------------------------------------------------------------------- /configs/dwm_baselines/dnc/distraction_forget.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/dwm_baselines/dnc/default_dnc.yaml, 4 | configs/dwm_baselines/dnc/dnc_default_settings_dual_task.yaml 5 | 6 | # Then overwrite problem name(s). 7 | training: 8 | problem: 9 | name: &name distraction_forget 10 | 11 | validation: 12 | problem: 13 | name: *name 14 | 15 | testing: 16 | problem: 17 | name: *name 18 | 19 | -------------------------------------------------------------------------------- /configs/dwm_baselines/dnc/reverse_recall.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/dwm_baselines/dnc/default_dnc.yaml, 4 | configs/dwm_baselines/dnc/dnc_default_settings_simple_task.yaml 5 | 6 | # Then overwrite problem name(s). 7 | training: 8 | problem: 9 | name: &name ReverseRecallCommandLines 10 | 11 | validation: 12 | problem: 13 | name: *name 14 | 15 | testing: 16 | problem: 17 | name: *name 18 | -------------------------------------------------------------------------------- /configs/dwm_baselines/dnc/serial_recall.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/dwm_baselines/dnc/default_dnc.yaml, 4 | configs/dwm_baselines/dnc/dnc_default_settings_simple_task.yaml 5 | 6 | # Then overwrite problem name(s). 
7 | training: 8 | problem: 9 | name: &name SerialRecallCommandLines 10 | 11 | validation: 12 | problem: 13 | name: *name 14 | 15 | testing: 16 | problem: 17 | name: *name 18 | -------------------------------------------------------------------------------- /configs/dwm_baselines/dwm/distraction_ignore.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/dwm_baselines/dwm/default_dwm.yaml, 4 | configs/dwm_baselines/default_settings_dual_task.yaml 5 | 6 | # Then overwrite problem name(s). 7 | training: 8 | problem: 9 | name: &name distraction_ignore 10 | 11 | validation: 12 | problem: 13 | name: *name 14 | 15 | testing: 16 | problem: 17 | name: *name 18 | 19 | 20 | -------------------------------------------------------------------------------- /configs/dwm_baselines/dwm/manipulation_spatial_not.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/dwm_baselines/dwm/default_dwm.yaml, 4 | configs/dwm_baselines/default_settings_simple_task.yaml 5 | 6 | # Then overwrite problem name(s). 7 | training: 8 | problem: 9 | name: &name manipulation_spatial_not 10 | 11 | validation: 12 | problem: 13 | name: *name 14 | 15 | testing: 16 | problem: 17 | name: *name 18 | -------------------------------------------------------------------------------- /configs/dwm_baselines/dnc/distraction_ignore.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/dwm_baselines/dnc/default_dnc.yaml, 4 | configs/dwm_baselines/dnc/dnc_default_settings_dual_task.yaml 5 | 6 | # Then overwrite problem name(s). 
7 | training: 8 | problem: 9 | name: &name distraction_ignore 10 | 11 | validation: 12 | problem: 13 | name: *name 14 | 15 | testing: 16 | problem: 17 | name: *name 18 | 19 | 20 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything 2 | * 3 | # But descend into directories 4 | !*/ 5 | # Recursively allow files under subtree 6 | 7 | !Readme.md 8 | !readthedocs.yml 9 | !setup.py 10 | !doc_build.sh 11 | !__init__.py 12 | !/configs/** 13 | !/docs/** 14 | !/miprometheus/** 15 | !.github/** 16 | 17 | # You can be specific with these rules 18 | __pycache__* 19 | *.swp 20 | *.vector_cache 21 | !.gitignore 22 | 23 | # Ignore every DS_Store 24 | **/.DS_Store 25 | 26 | # Ignore build directory 27 | /build/** -------------------------------------------------------------------------------- /configs/dwm_baselines/dwm/default_dwm.yaml: -------------------------------------------------------------------------------- 1 | # Model parameters: 2 | model: 3 | name: DWM # dynamic_working_memory 4 | # Controller hidden state. 5 | hidden_state_size: 5 6 | # Memory (?). 7 | memory_content_size: 10 8 | memory_addresses_size: -1 9 | # Number of heads. 10 | num_heads: 1 11 | # Addressing parameters. 12 | use_content_addressing: False 13 | shift_size: 3 14 | # Activate the plotting of the memory and attention 15 | plot_memory: False 16 | -------------------------------------------------------------------------------- /configs/dwm_baselines/dwm/interruption_reverse_recall.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/dwm_baselines/dwm/default_dwm.yaml, 4 | configs/dwm_baselines/default_settings_dual_task.yaml 5 | 6 | # Then overwrite problem name(s).
7 | training: 8 | problem: 9 | name: &name interruption_reverse_recall 10 | 11 | validation: 12 | problem: 13 | name: *name 14 | 15 | testing: 16 | problem: 17 | name: *name 18 | 19 | 20 | -------------------------------------------------------------------------------- /configs/dwm_baselines/dwm/distraction_carry.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/dwm_baselines/dwm/default_dwm.yaml, 4 | configs/dwm_baselines/default_settings_dual_task.yaml 5 | 6 | # Then overwrite problem name(s). 7 | training: 8 | problem: 9 | name: &name distraction_carry 10 | 11 | validation: 12 | problem: 13 | name: *name 14 | 15 | testing: 16 | problem: 17 | name: *name 18 | 19 | 20 | 21 | 22 | 23 | 24 | -------------------------------------------------------------------------------- /configs/maes_baselines/ntm/ntm_reverse_recall.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/maes_baselines/ntm/default_ntm.yaml, 4 | configs/maes_baselines/default_problem.yaml, 5 | configs/maes_baselines/default_training.yaml 6 | 7 | # Then overwrite problem name(s). 8 | training: 9 | problem: 10 | name: &name ReverseRecallCommandLines 11 | 12 | validation: 13 | problem: 14 | name: *name 15 | 16 | testing: 17 | problem: 18 | name: *name 19 | -------------------------------------------------------------------------------- /configs/maes_baselines/ntm/ntm_serial_recall.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/maes_baselines/ntm/default_ntm.yaml, 4 | configs/maes_baselines/default_problem.yaml, 5 | configs/maes_baselines/default_training.yaml 6 | 7 | # Then overwrite problem name(s). 
8 | training: 9 | problem: 10 | name: &name SerialRecallCommandLines 11 | 12 | validation: 13 | problem: 14 | name: *name 15 | 16 | testing: 17 | problem: 18 | name: *name 19 | -------------------------------------------------------------------------------- /configs/maes_baselines/lstm/lstm_reverse_recall.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/maes_baselines/lstm/default_lstm.yaml, 4 | configs/maes_baselines/default_problem.yaml, 5 | configs/maes_baselines/default_training.yaml 6 | 7 | # Then overwrite problem name(s). 8 | training: 9 | problem: 10 | name: &name ReverseRecallCommandLines 11 | 12 | validation: 13 | problem: 14 | name: *name 15 | 16 | testing: 17 | problem: 18 | name: *name 19 | -------------------------------------------------------------------------------- /configs/maes_baselines/lstm/lstm_serial_recall.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/maes_baselines/lstm/default_lstm.yaml, 4 | configs/maes_baselines/default_problem.yaml, 5 | configs/maes_baselines/default_training.yaml 6 | 7 | # Then overwrite problem name(s). 8 | training: 9 | problem: 10 | name: &name SerialRecallCommandLines 11 | 12 | validation: 13 | problem: 14 | name: *name 15 | 16 | testing: 17 | problem: 18 | name: *name 19 | -------------------------------------------------------------------------------- /configs/maes_baselines/dnc/dnc_serial_recall.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/maes_baselines/dnc/default_dnc.yaml, 4 | configs/maes_baselines/default_problem.yaml, 5 | configs/maes_baselines/dnc/dnc_default_training.yaml 6 | 7 | # Then overwrite problem name(s). 
8 | training: 9 | problem: 10 | name: &name SerialRecallCommandLines 11 | 12 | validation: 13 | problem: 14 | name: *name 15 | 16 | testing: 17 | problem: 18 | name: *name 19 | -------------------------------------------------------------------------------- /configs/maes_baselines/ntm/ntm_sequence_comparison.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/maes_baselines/ntm/default_ntm.yaml, 4 | configs/maes_baselines/default_problem.yaml, 5 | configs/maes_baselines/default_training.yaml 6 | 7 | # Then overwrite problem name(s). 8 | training: 9 | problem: 10 | name: &name SequenceComparisonCommandLines 11 | 12 | validation: 13 | problem: 14 | name: *name 15 | 16 | testing: 17 | problem: 18 | name: *name 19 | -------------------------------------------------------------------------------- /configs/maes_baselines/ntm/ntm_sequence_equality.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/maes_baselines/ntm/default_ntm.yaml, 4 | configs/maes_baselines/default_problem.yaml, 5 | configs/maes_baselines/default_training.yaml 6 | 7 | # Then overwrite problem name(s). 8 | training: 9 | problem: 10 | name: &name SequenceEqualityCommandLines 11 | 12 | validation: 13 | problem: 14 | name: *name 15 | 16 | testing: 17 | problem: 18 | name: *name 19 | -------------------------------------------------------------------------------- /configs/maes_baselines/dnc/dnc_reverse_recall.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/maes_baselines/dnc/default_dnc.yaml, 4 | configs/maes_baselines/default_problem.yaml, 5 | configs/maes_baselines/dnc/dnc_default_training.yaml 6 | 7 | # Then overwrite problem name(s). 
8 | training: 9 | problem: 10 | name: &name ReverseRecallCommandLines 11 | 12 | validation: 13 | problem: 14 | name: *name 15 | 16 | testing: 17 | problem: 18 | name: *name 19 | 20 | -------------------------------------------------------------------------------- /configs/maes_baselines/dnc/dnc_sequence_equality.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/maes_baselines/dnc/default_dnc.yaml, 4 | configs/maes_baselines/default_problem.yaml, 5 | configs/maes_baselines/dnc/dnc_default_training.yaml 6 | 7 | # Then overwrite problem name(s). 8 | training: 9 | problem: 10 | name: &name SequenceEqualityCommandLines 11 | 12 | validation: 13 | problem: 14 | name: *name 15 | 16 | testing: 17 | problem: 18 | name: *name 19 | -------------------------------------------------------------------------------- /configs/maes_baselines/dnc/dnc_sequence_symmetry.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/maes_baselines/dnc/default_dnc.yaml, 4 | configs/maes_baselines/default_problem.yaml, 5 | configs/maes_baselines/dnc/dnc_default_training.yaml 6 | 7 | # Then overwrite problem name(s). 8 | training: 9 | problem: 10 | name: &name SequenceSymmetryCommandLines 11 | 12 | validation: 13 | problem: 14 | name: *name 15 | 16 | testing: 17 | problem: 18 | name: *name 19 | -------------------------------------------------------------------------------- /configs/maes_baselines/es_lstm/es_lstm_serial_recall.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 
2 | default_configs: 3 | configs/maes_baselines/es_lstm/default_es_lstm.yaml, 4 | configs/maes_baselines/default_problem.yaml, 5 | configs/maes_baselines/default_training.yaml 6 | 7 | # Then overwrite problem name(s). 8 | training: 9 | problem: 10 | name: &name SerialRecallCommandLines 11 | 12 | validation: 13 | problem: 14 | name: *name 15 | 16 | testing: 17 | problem: 18 | name: *name 19 | -------------------------------------------------------------------------------- /configs/maes_baselines/lstm/lstm_sequence_comparison.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/maes_baselines/lstm/default_lstm.yaml, 4 | configs/maes_baselines/default_problem.yaml, 5 | configs/maes_baselines/default_training.yaml 6 | 7 | # Then overwrite problem name(s). 8 | training: 9 | problem: 10 | name: &name SequenceComparisonCommandLines 11 | 12 | validation: 13 | problem: 14 | name: *name 15 | 16 | testing: 17 | problem: 18 | name: *name 19 | -------------------------------------------------------------------------------- /configs/maes_baselines/lstm/lstm_sequence_symmetry.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/maes_baselines/lstm/default_lstm.yaml, 4 | configs/maes_baselines/default_problem.yaml, 5 | configs/maes_baselines/default_training.yaml 6 | 7 | # Then overwrite problem name(s). 8 | training: 9 | problem: 10 | name: &name SequenceSymmetryCommandLines 11 | 12 | validation: 13 | problem: 14 | name: *name 15 | 16 | testing: 17 | problem: 18 | name: *name 19 | -------------------------------------------------------------------------------- /configs/maes_baselines/ntm/ntm_sequence_symmetry.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 
2 | default_configs: 3 | configs/maes_baselines/ntm/default_ntm.yaml, 4 | configs/maes_baselines/default_problem.yaml, 5 | configs/maes_baselines/default_training.yaml 6 | 7 | # Then overwrite problem name(s). 8 | training: 9 | problem: 10 | name: &name SequenceSymmetryCommandLines 11 | 12 | validation: 13 | problem: 14 | name: *name 15 | 16 | testing: 17 | problem: 18 | name: *name 19 | 20 | -------------------------------------------------------------------------------- /configs/vqa_baselines/multi_hops_stacked_attention_networks/shape_color_query.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/vqa_baselines/default_shape_color_query.yaml 4 | 5 | # Model parameters: 6 | model: 7 | name: MultiHopsStackedAttentionNetwork 8 | use_pretrained_cnn: False 9 | pretrained_cnn: 10 | name: 'resnet18' 11 | num_layers: 2 12 | attention_layer: 13 | nb_nodes: 128 14 | classifier: 15 | nb_hidden_nodes: 256 16 | default_nb_hops: 3 17 | -------------------------------------------------------------------------------- /configs/maes_baselines/dnc/dnc_sequence_comparison.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/maes_baselines/dnc/default_dnc.yaml, 4 | configs/maes_baselines/default_problem.yaml, 5 | configs/maes_baselines/dnc/dnc_default_training.yaml 6 | 7 | # Then overwrite problem name(s). 
8 | training: 9 | problem: 10 | name: &name SequenceComparisonCommandLines 11 | 12 | validation: 13 | problem: 14 | name: *name 15 | 16 | testing: 17 | problem: 18 | name: *name 19 | -------------------------------------------------------------------------------- /configs/maes_baselines/es_lstm/es_lstm_reverse_recall.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/maes_baselines/es_lstm/default_es_lstm.yaml, 4 | configs/maes_baselines/default_problem.yaml, 5 | configs/maes_baselines/default_training.yaml 6 | 7 | # Then overwrite problem name(s). 8 | training: 9 | problem: 10 | name: &name ReverseRecallCommandLines 11 | 12 | validation: 13 | problem: 14 | name: *name 15 | 16 | testing: 17 | problem: 18 | name: *name 19 | -------------------------------------------------------------------------------- /configs/maes_baselines/es_lstm/es_lstm_sequence_equality.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/maes_baselines/es_lstm/default_es_lstm.yaml, 4 | configs/maes_baselines/default_problem.yaml, 5 | configs/maes_baselines/default_training.yaml 6 | 7 | # Then overwrite problem name(s). 8 | training: 9 | problem: 10 | name: &name SequenceEqualityCommandLines 11 | 12 | validation: 13 | problem: 14 | name: *name 15 | 16 | testing: 17 | problem: 18 | name: *name 19 | -------------------------------------------------------------------------------- /configs/maes_baselines/es_lstm/es_lstm_sequence_symmetry.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 
2 | default_configs: 3 | configs/maes_baselines/es_lstm/default_es_lstm.yaml, 4 | configs/maes_baselines/default_problem.yaml, 5 | configs/maes_baselines/default_training.yaml 6 | 7 | # Then overwrite problem name(s). 8 | training: 9 | problem: 10 | name: &name SequenceSymmetryCommandLines 11 | 12 | validation: 13 | problem: 14 | name: *name 15 | 16 | testing: 17 | problem: 18 | name: *name 19 | -------------------------------------------------------------------------------- /configs/maes_baselines/lstm/lstm_sequence_equality.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/maes_baselines/lstm/default_lstm.yaml, 4 | configs/maes_baselines/default_problem.yaml, 5 | configs/maes_baselines/default_training.yaml 6 | 7 | # Then overwrite problem name(s). 8 | training: 9 | problem: 10 | name: &name SequenceEqualityCommandLines 11 | 12 | validation: 13 | problem: 14 | name: *name 15 | 16 | testing: 17 | problem: 18 | name: *name 19 | 20 | 21 | -------------------------------------------------------------------------------- /configs/maes_baselines/es_lstm/es_lstm_sequence_comparison.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/maes_baselines/es_lstm/default_es_lstm.yaml, 4 | configs/maes_baselines/default_problem.yaml, 5 | configs/maes_baselines/default_training.yaml 6 | 7 | # Then overwrite problem name(s). 
8 | training: 9 | problem: 10 | name: &name SequenceComparisonCommandLines 11 | 12 | validation: 13 | problem: 14 | name: *name 15 | 16 | testing: 17 | problem: 18 | name: *name 19 | -------------------------------------------------------------------------------- /configs/dwm_baselines/dwm/operation_span.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/dwm_baselines/dwm/default_dwm.yaml, 4 | configs/dwm_baselines/default_settings_dual_task.yaml 5 | 6 | # Then overwrite problem name(s). 7 | training: 8 | problem: 9 | name: &name operation_span 10 | num_rotation: &nrot 0.5 11 | 12 | validation: 13 | problem: 14 | name: *name 15 | num_rotation: *nrot 16 | 17 | testing: 18 | problem: 19 | name: *name 20 | num_rotation: *nrot 21 | -------------------------------------------------------------------------------- /configs/dwm_baselines/dnc/operation_span.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/dwm_baselines/dnc/default_dnc.yaml, 4 | configs/dwm_baselines/dnc/dnc_default_settings_dual_task.yaml 5 | 6 | # Then overwrite problem name(s). 7 | training: 8 | problem: 9 | name: &name operation_span 10 | num_rotation: &nrot 0.5 11 | 12 | validation: 13 | problem: 14 | name: *name 15 | num_rotation: *nrot 16 | 17 | testing: 18 | problem: 19 | name: *name 20 | num_rotation: *nrot 21 | -------------------------------------------------------------------------------- /configs/dwm_baselines/lstm/operation_span.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/dwm_baselines/lstm/default_lstm.yaml, 4 | configs/dwm_baselines/default_settings_dual_task.yaml 5 | 6 | # Then overwrite problem name(s). 
7 | training: 8 | problem: 9 | name: &name operation_span 10 | num_rotation: &nrot 0.5 11 | 12 | validation: 13 | problem: 14 | name: *name 15 | num_rotation: *nrot 16 | 17 | testing: 18 | problem: 19 | name: *name 20 | num_rotation: *nrot 21 | 22 | -------------------------------------------------------------------------------- /configs/dwm_baselines/dwm/manipulation_spatial_rotation.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/dwm_baselines/dwm/default_dwm.yaml, 4 | configs/dwm_baselines/default_settings_simple_task.yaml 5 | 6 | # Then overwrite problem name(s). 7 | training: 8 | problem: 9 | name: &name manipulation_spatial_rotation 10 | num_bits: &nrot 0.5 11 | 12 | validation: 13 | problem: 14 | name: *name 15 | num_bits: *nrot 16 | 17 | testing: 18 | problem: 19 | name: *name 20 | num_bits: *nrot 21 | -------------------------------------------------------------------------------- /configs/dwm_baselines/dwm/manipulation_temporal_swap.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/dwm_baselines/dwm/default_dwm.yaml, 4 | configs/dwm_baselines/default_settings_simple_task.yaml 5 | 6 | # Then overwrite problem name(s). 7 | training: 8 | problem: 9 | name: &name manipulation_temporal_swap 10 | num_items: &nrot 0.5 11 | 12 | validation: 13 | problem: 14 | name: *name 15 | num_items: *nrot 16 | 17 | testing: 18 | problem: 19 | name: *name 20 | num_items: *nrot 21 | -------------------------------------------------------------------------------- /configs/dwm_baselines/lstm/manipulation_spatial_rotation.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 
2 | default_configs: 3 | configs/dwm_baselines/lstm/default_lstm.yaml, 4 | configs/dwm_baselines/default_settings_dual_task.yaml 5 | 6 | # Then overwrite problem name(s). 7 | training: 8 | problem: 9 | name: &name manipulation_spatial_rotation 10 | num_bits: &nrot 0.5 11 | 12 | validation: 13 | problem: 14 | name: *name 15 | num_bits: *nrot 16 | 17 | testing: 18 | problem: 19 | name: *name 20 | num_bits: *nrot 21 | -------------------------------------------------------------------------------- /configs/maes_baselines/maes/default_maes.yaml: -------------------------------------------------------------------------------- 1 | # Model parameters: 2 | model: 3 | name: MAES 4 | # Pass the whole state from encoder to solver cell. 5 | pass_cell_state: True 6 | # Controller parameters. 7 | controller: 8 | name: RNNController 9 | hidden_state_size: 20 10 | num_layers: 1 11 | non_linearity: sigmoid 12 | # Interface 13 | mae_interface: 14 | shift_size: 3 15 | mas_interface: 16 | shift_size: 3 17 | # Memory parameters. 18 | memory: 19 | num_content_bits: 15 20 | num_addresses: -1 -------------------------------------------------------------------------------- /configs/dwm_baselines/dnc/manipulation_spatial_rotation.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/dwm_baselines/dnc/default_dnc.yaml, 4 | configs/dwm_baselines/dnc/dnc_default_settings_simple_task.yaml 5 | 6 | # Then overwrite problem name(s). 
7 | training: 8 | problem: 9 | name: &name manipulation_spatial_rotation 10 | num_bits: &nrot 0.5 11 | 12 | validation: 13 | problem: 14 | name: *name 15 | num_bits: *nrot 16 | 17 | testing: 18 | problem: 19 | name: *name 20 | num_bits: *nrot 21 | -------------------------------------------------------------------------------- /configs/dwm_baselines/dwm/interruption_swap_recall.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/dwm_baselines/dwm/default_dwm.yaml, 4 | configs/dwm_baselines/default_settings_dual_task.yaml 5 | 6 | # Then overwrite problem name(s). 7 | training: 8 | problem: 9 | name: &name interruption_swap_recall 10 | num_rotation: &nrot 0.5 11 | 12 | validation: 13 | problem: 14 | name: *name 15 | num_rotation: *nrot 16 | 17 | testing: 18 | problem: 19 | name: *name 20 | num_rotation: *nrot 21 | 22 | -------------------------------------------------------------------------------- /configs/maes_baselines/ntm/default_ntm.yaml: -------------------------------------------------------------------------------- 1 | # Model parameters: 2 | model: 3 | name: NTM 4 | # Optional parameter: visualization. 5 | visualization_mode: 2 6 | # Controller parameters. 7 | controller: 8 | name: RNNController 9 | hidden_state_size: 20 10 | num_layers: 1 11 | non_linearity: sigmoid 12 | # Interface 13 | interface: 14 | num_read_heads: 1 15 | shift_size: 3 16 | use_content_based_addressing: True 17 | # Memory parameters. 18 | memory: 19 | num_content_bits: 15 20 | num_addresses: -1 -------------------------------------------------------------------------------- /configs/vqa_baselines/stacked_attention_networks/sort_of_clevr.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 
2 | default_configs: 3 | configs/vqa_baselines/default_sort_of_clevr.yaml 4 | 5 | # Model parameters: 6 | model: 7 | name: StackedAttentionNetwork 8 | use_pretrained_cnn: True 9 | pretrained_cnn: 10 | name: 'resnet18' 11 | num_layers: 2 12 | lstm: 13 | hidden_size: 64 14 | num_layers: 1 15 | bidirectional: False 16 | dropout: 0 17 | attention_layer: 18 | nb_nodes: 128 19 | classifier: 20 | nb_hidden_nodes: 256 21 | -------------------------------------------------------------------------------- /configs/dwm_baselines/lstm/reading_span.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/dwm_baselines/lstm/default_lstm.yaml, 4 | configs/dwm_baselines/default_settings_dual_task.yaml 5 | 6 | # Then overwrite problem name(s). 7 | training: 8 | problem: 9 | name: &name reading_span 10 | control_bits: &cbits 2 11 | 12 | validation: 13 | problem: 14 | name: *name 15 | control_bits: *cbits 16 | 17 | testing: 18 | problem: 19 | name: *name 20 | control_bits: *cbits 21 | 22 | 23 | model: 24 | control_bits: *cbits 25 | -------------------------------------------------------------------------------- /configs/vqa_baselines/stacked_attention_networks/shape_color_query.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 
2 | default_configs: 3 | configs/vqa_baselines/default_shape_color_query.yaml 4 | 5 | # Model parameters: 6 | model: 7 | name: StackedAttentionNetwork 8 | use_pretrained_cnn: True 9 | pretrained_cnn: 10 | name: 'resnet18' 11 | num_layers: 2 12 | lstm: 13 | hidden_size: 64 14 | num_layers: 1 15 | bidirectional: False 16 | dropout: 0 17 | attention_layer: 18 | nb_nodes: 128 19 | classifier: 20 | nb_hidden_nodes: 256 21 | 22 | 23 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | 5 | --- 6 | 7 | **Is your feature request related to a problem? Please describe.** 8 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 9 | 10 | **Describe the solution you'd like** 11 | A clear and concise description of what you want to happen. 12 | 13 | **Describe alternatives you've considered** 14 | A clear and concise description of any alternative solutions or features you've considered. 15 | 16 | **Additional context** 17 | Add any other context or screenshots about the feature request here. 18 | -------------------------------------------------------------------------------- /configs/dwm_baselines/lstm/scratch_pad.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/dwm_baselines/lstm/default_lstm.yaml, 4 | configs/dwm_baselines/default_settings_dual_task.yaml 5 | 6 | # Then overwrite problem name(s). 7 | training: 8 | problem: 9 | name: &name ScratchPadCommandLines 10 | # Parameters denoting min and max lengths. 
11 | min_sequence_length: 1 12 | max_sequence_length: 10 13 | num_subseq_min: 1 14 | num_subseq_max: 3 15 | 16 | 17 | validation: 18 | problem: 19 | name: *name 20 | 21 | 22 | testing: 23 | problem: 24 | name: *name 25 | -------------------------------------------------------------------------------- /configs/dwm_baselines/dnc/default_dnc.yaml: -------------------------------------------------------------------------------- 1 | # Model parameters: 2 | model: 3 | name: DNC # differentiable neural computer 4 | # Controller hidden state. 5 | hidden_state_size: 20 6 | # Memory (?). 7 | memory_content_size: 10 8 | memory_addresses_size: -1 9 | # Number of heads. 10 | num_writes: 1 11 | num_reads: 1 12 | # Addressing parameters. 13 | shift_size: 3 14 | controller_type: LSTMController 15 | use_ntm_write: False 16 | use_ntm_read: False 17 | use_ntm_order: False 18 | use_extra_write_gate: False 19 | non_linearity: sigmoid 20 | # activate the plotting of the memory and attention 21 | plot_memory: False 22 | 23 | -------------------------------------------------------------------------------- /configs/maes_baselines/dnc/default_dnc.yaml: -------------------------------------------------------------------------------- 1 | # Model parameters: 2 | model: 3 | name: DNC # differentiable neural computer 4 | # Controller hidden state. 5 | hidden_state_size: 20 6 | # Memory (?). 7 | memory_content_size: 15 8 | memory_addresses_size: -1 9 | # Number of heads. 10 | num_writes: 1 11 | num_reads: 1 12 | # Addressing parameters. 
13 | shift_size: 3 14 | controller_type: LSTMController 15 | use_ntm_write: False 16 | use_ntm_read: False 17 | use_ntm_order: False 18 | use_extra_write_gate: False 19 | non_linearity: sigmoid 20 | # activate the plotting of the memory and attention 21 | plot_memory: False 22 | 23 | -------------------------------------------------------------------------------- /docs/source/api_reference/6_helpers.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | .. currentmodule:: miprometheus.helpers 5 | 6 | miprometheus.helpers 7 | =================================== 8 | .. automodule:: miprometheus.helpers 9 | :members: 10 | :special-members: 11 | :exclude-members: __dict__,__weakref__ 12 | 13 | IndexSplitter 14 | -------------- 15 | .. autoclass:: IndexSplitter 16 | :members: 17 | :special-members: 18 | :exclude-members: __dict__,__weakref__ 19 | 20 | ProblemInitializer 21 | ------------------- 22 | .. autoclass:: ProblemInitializer 23 | :members: 24 | :special-members: 25 | :exclude-members: __dict__,__weakref__ 26 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SPHINXPROJ = MI-Prometheus 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /miprometheus/models/controllers/__init__.py: -------------------------------------------------------------------------------- 1 | from .controller_factory import ControllerFactory 2 | from .feedforward_controller import FeedforwardController 3 | from .ffgru_controller import FFGRUStateTuple, FFGRUController 4 | from .gru_controller import GRUStateTuple, GRUController 5 | from .lstm_controller import LSTMStateTuple, LSTMController 6 | from .rnn_controller import RNNStateTuple, RNNController 7 | 8 | __all__ = [ 9 | 'ControllerFactory', 10 | 'FeedforwardController', 11 | 'FFGRUStateTuple', 12 | 'FFGRUController', 13 | 'GRUStateTuple', 14 | 'GRUController', 15 | 'LSTMStateTuple', 16 | 'LSTMController', 17 | 'RNNStateTuple', 18 | 'RNNController'] 19 | -------------------------------------------------------------------------------- /configs/cog/cog_mac.yaml: -------------------------------------------------------------------------------- 1 | # Load the COG config first. 
2 | default_configs: 3 | configs/cog/default_cog.yaml 4 | 5 | model: 6 | name: MACNetworkSequential 7 | dim: 128 8 | embed_hidden: 64 9 | max_step: 12 10 | slot : 8 11 | classes : 49 12 | words_embed_length : 64 13 | nwords : 24 14 | trainable_init : False 15 | self_attention: False 16 | memory_gate: False 17 | dropout: 0.15 18 | memory_pass: False 19 | control_pass: False 20 | 21 | 22 | training: 23 | # fix the seeds 24 | seed_torch: 0 25 | seed_numpy: 0 26 | 27 | validation: 28 | # fix the seeds 29 | seed_torch: 0 30 | seed_numpy: 0 31 | 32 | testing: 33 | # fix the seeds 34 | seed_torch: 0 35 | seed_numpy: 0 36 | -------------------------------------------------------------------------------- /configs/maes_baselines/lstm/lstm_odd_recall.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/maes_baselines/lstm/default_lstm.yaml, 4 | configs/maes_baselines/default_problem.yaml, 5 | configs/maes_baselines/default_training.yaml 6 | 7 | # Then overwrite problem name(s). 8 | training: 9 | problem: 10 | name: &name SkipRecallCommandLines 11 | # Skip params. 12 | seq_start: &start 0 13 | skip_step: &skip 2 14 | 15 | validation: 16 | problem: 17 | name: *name 18 | seq_start: *start 19 | skip_step: *skip 20 | 21 | testing: 22 | problem: 23 | name: *name 24 | seq_start: *start 25 | skip_step: *skip 26 | -------------------------------------------------------------------------------- /configs/maes_baselines/dnc/dnc_odd_recall.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/maes_baselines/dnc/default_dnc.yaml, 4 | configs/maes_baselines/default_problem.yaml, 5 | configs/maes_baselines/dnc/dnc_default_training.yaml 6 | 7 | # Then overwrite problem name(s). 8 | training: 9 | problem: 10 | name: &name SkipRecallCommandLines 11 | # Skip params. 
12 | seq_start: &start 0 13 | skip_step: &skip 2 14 | 15 | validation: 16 | problem: 17 | name: *name 18 | seq_start: *start 19 | skip_step: *skip 20 | 21 | testing: 22 | problem: 23 | name: *name 24 | seq_start: *start 25 | skip_step: *skip 26 | -------------------------------------------------------------------------------- /configs/maes_baselines/ntm/ntm_odd_recall.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/maes_baselines/ntm/default_ntm.yaml, 4 | configs/maes_baselines/default_problem.yaml, 5 | configs/maes_baselines/default_training.yaml 6 | 7 | # Then overwrite problem name(s). 8 | training: 9 | problem: 10 | name: &name SkipRecallCommandLines 11 | # Skip params. 12 | seq_start: &start 0 13 | skip_step: &skip 2 14 | 15 | validation: 16 | problem: 17 | name: *name 18 | seq_start: *start 19 | skip_step: *skip 20 | 21 | testing: 22 | problem: 23 | name: *name 24 | seq_start: *start 25 | skip_step: *skip 26 | 27 | 28 | -------------------------------------------------------------------------------- /configs/maes_baselines/es_lstm/es_lstm_odd_recall.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/maes_baselines/es_lstm/default_es_lstm.yaml, 4 | configs/maes_baselines/default_problem.yaml, 5 | configs/maes_baselines/default_training.yaml 6 | 7 | # Then overwrite problem name(s). 8 | training: 9 | problem: 10 | name: &name SkipRecallCommandLines 11 | # Skip params. 
12 | seq_start: &start 0 13 | skip_step: &skip 2 14 | 15 | validation: 16 | problem: 17 | name: *name 18 | seq_start: *start 19 | skip_step: *skip 20 | 21 | testing: 22 | problem: 23 | name: *name 24 | seq_start: *start 25 | skip_step: *skip 26 | -------------------------------------------------------------------------------- /configs/maes_baselines/maes/maes_serial_recall.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/maes_baselines/maes/default_maes.yaml, 4 | configs/maes_baselines/default_problem.yaml, 5 | configs/maes_baselines/default_training.yaml 6 | 7 | # Then overwrite problem name(s). 8 | training: 9 | problem: 10 | name: &name SerialRecallCommandLines 11 | 12 | 13 | validation: 14 | problem: 15 | name: *name 16 | 17 | testing: 18 | problem: 19 | name: *name 20 | 21 | #model: 22 | # Save/load encoder. 23 | #save_encoder: True 24 | #load_encoder: ./experiments/maes_baselines/serial_recall_cl/maes/20180702_132024/models/encoder_episode_09400.pth.tar 25 | #freeze_encoder: True 26 | 27 | -------------------------------------------------------------------------------- /configs/maes_baselines/maes/maes_sequence_equality.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/maes_baselines/maes/default_maes.yaml, 4 | configs/maes_baselines/default_problem.yaml, 5 | configs/maes_baselines/default_training.yaml 6 | 7 | # Then overwrite problem name(s). 8 | training: 9 | problem: 10 | name: &name SequenceEqualityCommandLines 11 | 12 | validation: 13 | problem: 14 | name: *name 15 | 16 | testing: 17 | problem: 18 | name: *name 19 | 20 | #model: 21 | # Save/load encoder. 
22 | #save_encoder: True 23 | #load_encoder: ./experiments/maes_baselines/serial_recall_cl/maes/20180702_132024/models/encoder_episode_09400.pth.tar 24 | #freeze_encoder: True 25 | 26 | -------------------------------------------------------------------------------- /configs/maes_baselines/maes/maes_sequence_symmetry.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/maes_baselines/maes/default_maes.yaml, 4 | configs/maes_baselines/default_problem.yaml, 5 | configs/maes_baselines/default_training.yaml 6 | 7 | # Then overwrite problem name(s). 8 | training: 9 | problem: 10 | name: &name SequenceSymmetryCommandLines 11 | 12 | validation: 13 | problem: 14 | name: *name 15 | 16 | testing: 17 | problem: 18 | name: *name 19 | 20 | #model: 21 | # Save/load encoder. 22 | #save_encoder: True 23 | #load_encoder: ./experiments/maes_baselines/serial_recall_cl/maes/20180702_132024/models/encoder_episode_09400.pth.tar 24 | #freeze_encoder: True 25 | 26 | -------------------------------------------------------------------------------- /docs/build_docs.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Copyright (C) tkornuta, IBM Corporation 2018-2019 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | cd docs 17 | rm -rf build 18 | 19 | # create html pages 20 | sphinx-build -b html source build 21 | 22 | -------------------------------------------------------------------------------- /docs/source/mip_primer/6_helpers.rst: -------------------------------------------------------------------------------- 1 | Helpers Explained 2 | =================== 3 | `@author: Tomasz Kornuta & Vincent Marois` 4 | 5 | A helper is an application that is useful from the point of view of running experiments, but independent of (external to) the Workers. 6 | Currently MI-Prometheus offers two types of helpers: 7 | 8 | - **Problem Initializer**, responsible for initialization of a problem (i.e. downloading the required data from the internet or generating all samples) in advance, before the real experiment starts. 9 | - **Index Splitter**, responsible for generation of files with indices splitting a given dataset (in fact, a set of indices) into two. The resulting files can later be used in training/verification testing when using ``SubsetRandomSampler``. 10 | 11 | We expect this list to grow soon. 12 | -------------------------------------------------------------------------------- /configs/maes_baselines/maes/maes_reverse_recall.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/maes_baselines/maes/default_maes.yaml, 4 | configs/maes_baselines/default_problem.yaml, 5 | configs/maes_baselines/default_training.yaml 6 | 7 | # Then overwrite problem name(s). 8 | training: 9 | problem: 10 | name: &name ReverseRecallCommandLines 11 | 12 | validation: 13 | problem: 14 | name: *name 15 | 16 | testing: 17 | problem: 18 | name: *name 19 | 20 | # And overwrite model parameters. 21 | #model: 22 | # Save/load encoder. 
23 | #save_encoder: True 24 | #load_encoder: ./experiments/maes_baselines/serial_recall_cl/maes/20180702_132024/models/encoder_episode_09400.pth.tar 25 | #freeze_encoder: True 26 | -------------------------------------------------------------------------------- /configs/maes_baselines/maes/maes_sequence_comparison.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/maes_baselines/maes/default_maes.yaml, 4 | configs/maes_baselines/default_problem.yaml, 5 | configs/maes_baselines/default_training.yaml 6 | 7 | # Then overwrite problem name(s). 8 | training: 9 | problem: 10 | name: &name SequenceComparisonCommandLines 11 | 12 | validation: 13 | problem: 14 | name: *name 15 | 16 | testing: 17 | problem: 18 | name: *name 19 | 20 | # And overwrite model parameters. 21 | #model: 22 | # Save/load encoder. 23 | #save_encoder: True 24 | #load_encoder: ./experiments/maes_baselines/serial_recall_cl/maes/20180702_132024/models/encoder_episode_09400.pth.tar 25 | #freeze_encoder: True 26 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | 5 | --- 6 | 7 | **Describe the bug** 8 | A clear and concise description of what the bug is. 9 | 10 | **To Reproduce** 11 | Steps to reproduce the behavior: 12 | 1. Go to '...' 13 | 2. Click on '....' 14 | 3. Scroll down to '....' 15 | 4. See error 16 | 17 | **Expected behavior** 18 | A clear and concise description of what you expected to happen. 19 | 20 | **Screenshots** 21 | If applicable, add screenshots to help explain your problem. 22 | 23 | **Desktop (please complete the following information):** 24 | - OS: [e.g. Ubuntu] 25 | - OS version [e.g. 14.04] 26 | - Python version [e.g. 
3.7] 27 | - PyTorch version [e.g. 0.4.0] 28 | 29 | **Additional context** 30 | Add any other context about the problem here. 31 | -------------------------------------------------------------------------------- /configs/maes_baselines/default_training.yaml: -------------------------------------------------------------------------------- 1 | # Problem parameters: 2 | training: 3 | # Curriculum learning - optional. 4 | curriculum_learning: 5 | initial_max_sequence_length: 5 6 | # must_finish: false 7 | 8 | # Sampler. 9 | sampler: 10 | name: RandomSampler 11 | 12 | # Optimizer parameters: 13 | optimizer: 14 | # Exact name of the pytorch optimizer function 15 | name: Adam 16 | # Function arguments of the optimizer, by name 17 | lr: 0.01 18 | 19 | # Optional parameter, its presence results in clipping gradient to a range (-gradient_clipping, gradient_clipping) 20 | gradient_clipping: 10 21 | 22 | # Terminal condition parameters: 23 | terminal_conditions: 24 | loss_stop: 1.0e-5 25 | #early_stop_delta: 1.0e-5 26 | episode_limit: 100000 27 | #epoch_limit: 100 28 | -------------------------------------------------------------------------------- /configs/mac/default_cogent.yaml: -------------------------------------------------------------------------------- 1 | # Load the default CLEVR config first. 2 | default_configs: 3 | configs/mac/default_clevr.yaml 4 | 5 | # Then overwrite problem params. 
6 | training: 7 | problem: 8 | settings: 9 | data_folder: &dir '~/data/CLEVR_CoGenT_v1.0' 10 | set: 'trainA' 11 | dataset_variant: &var 'CLEVR-CoGenT' 12 | questions: 13 | embedding_source: 'CLEVR-CoGenT' 14 | 15 | # fix the seeds 16 | seed_torch: 0 17 | seed_numpy: 0 18 | 19 | validation: 20 | problem: 21 | settings: 22 | data_folder: *dir 23 | set: 'valA' 24 | dataset_variant: *var 25 | questions: 26 | embedding_source: 'CLEVR-CoGenT' 27 | 28 | testing: 29 | problem: 30 | settings: 31 | data_folder: *dir 32 | set: 'valB' 33 | dataset_variant: *var 34 | questions: 35 | embedding_source: 'CLEVR-CoGenT' 36 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | set SPHINXPROJ=MI-Prometheus 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 20 | echo.installed, then set the SPHINXBUILD environment variable to point 21 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 22 | echo.may add the Sphinx directory to PATH. 23 | echo. 
24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /configs/maes_baselines/dnc/dnc_default_training.yaml: -------------------------------------------------------------------------------- 1 | # Problem parameters: 2 | training: 3 | # Curriculum learning - optional. 4 | curriculum_learning: 5 | initial_max_sequence_length: 5 6 | 7 | # Optimizer parameters: 8 | optimizer: 9 | # Exact name of the pytorch optimizer function 10 | name: Adam 11 | # Function arguments of the optimizer, by name 12 | lr: 0.0005 13 | 14 | # Optional parameter, its presence results in clipping gradient to a range (-gradient_clipping, gradient_clipping) 15 | gradient_clipping: 10 16 | # How often the model will be validated/saved. 17 | validation_interval: 100 18 | # Optional parameters denoting number of - used when validation section is not present. 19 | length_loss: 10 20 | 21 | # Terminal condition parameters: 22 | terminal_conditions: 23 | loss_stop: 1.0e-5 24 | #early_stop_delta: 1.0e-5 25 | episode_limit: 100000 26 | #epoch_limit: 100 27 | -------------------------------------------------------------------------------- /miprometheus/utils/singleton.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | # Copyright (C) IBM Corporation 2018 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 
8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | class SingletonMetaClass(type): 19 | _instances = {} 20 | 21 | def __call__(cls, *args, **kwargs): 22 | if cls not in cls._instances: 23 | cls._instances[cls] = super( 24 | SingletonMetaClass, cls).__call__(*args, **kwargs) 25 | return cls._instances[cls] 26 | -------------------------------------------------------------------------------- /configs/cog/cog_generate.yaml: -------------------------------------------------------------------------------- 1 | # Problem parameters: 2 | training: 3 | cuda: True 4 | problem: 5 | name: COG 6 | # Size of generated input: [batch_size x sequence_length x classes]. 7 | initialization_only: True 8 | batch_size: 48 9 | tasks: &task all 10 | data_folder: '~/data/cog' 11 | set: train 12 | use_mask: False 13 | # Generate hard dataset 14 | dataset_type: generated 15 | generation: 16 | examples_per_task: 100 17 | sequence_length: 8 18 | memory_length: 7 19 | max_distractors: 10 20 | nr_processors: 1 21 | 22 | validation: 23 | cuda: True 24 | problem: 25 | name: COG 26 | # Size of generated input: [batch_size x sequence_length x classes]. 
27 | batch_size: 48 28 | tasks: *task 29 | data_folder: '~/data/cog' 30 | set: val 31 | use_mask: False 32 | # Generate hard dataset 33 | dataset_type: generated 34 | generation: 35 | examples_per_task: 100 36 | sequence_length: 8 37 | memory_length: 7 38 | max_distractors: 10 39 | nr_processors: 1 40 | 41 | 42 | model: 43 | name: CogModel 44 | -------------------------------------------------------------------------------- /configs/maes_baselines/maes/maes_odd_recall.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/maes_baselines/maes/default_maes.yaml, 4 | configs/maes_baselines/default_problem.yaml, 5 | configs/maes_baselines/default_training.yaml 6 | 7 | # Then overwrite problem name(s). 8 | training: 9 | problem: 10 | name: &name SkipRecallCommandLines 11 | # Skip params. 12 | seq_start: &start 0 13 | skip_step: &skip 2 14 | 15 | validation: 16 | problem: 17 | name: *name 18 | # Skip params. 19 | seq_start: *start 20 | skip_step: *skip 21 | 22 | testing: 23 | problem: 24 | name: *name 25 | # Skip params. 26 | seq_start: *start 27 | skip_step: *skip 28 | 29 | # And overwrite model parameters. 30 | #model: 31 | # Save/load encoder. 32 | #save_encoder: True 33 | #load_encoder: ./experiments/maes_baselines/serial_recall_cl/maes/20180702_132024/models/encoder_episode_09400.pth.tar 34 | #freeze_encoder: True 35 | 36 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## Pull Request template 2 | Please, go through these steps before you submit a PR. 3 | 4 | 1. Make sure that your PR is not a duplicate. 5 | 2. If not, then make sure that: 6 | 7 | 2.1. You have done your changes in a separate branch. 
Branches should have descriptive names that start with either the `fix/` or the `feature/` prefix. Good examples are: `fix/signin-issue` or `feature/new-model`. 8 | 9 | 2.2. You have descriptive commit messages with short titles (first line). 10 | 11 | 2.3. You have only one commit (if not, squash them into one commit). 12 | 13 | 3. **After** these steps, you're ready to open a pull request. 14 | 15 | 3.1. Give a descriptive title to your PR. 16 | 17 | 3.2. Provide a description of your changes. 18 | 19 | 3.3. Put `closes #XXXX` in your comment to auto-close the issue that your PR fixes (if any). 20 | 21 | Important: Please review the [CONTRIBUTING.md](../CONTRIBUTING.md) file for detailed contributing guidelines. 22 | 23 | *Please remove this template before submitting.* 24 | -------------------------------------------------------------------------------- /configs/maes_baselines/maes/maes_repeat_reverse_recall.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/maes_baselines/maes/default_maes.yaml, 4 | configs/maes_baselines/default_problem.yaml, 5 | configs/maes_baselines/default_training.yaml 6 | 7 | # Then overwrite problem name(s). 8 | training: 9 | problem: 10 | name: &name RepeatReverseRecallCommandLines 11 | # Number of repetitions. 12 | min_recall_number: 1 13 | max_recall_number: 4 14 | 15 | validation: 16 | problem: 17 | name: *name 18 | # Number of repetitions. 19 | min_recall_number: 5 20 | max_recall_number: 5 21 | 22 | testing: 23 | problem: 24 | name: *name 25 | # Number of repetitions. 26 | min_recall_number: 10 27 | max_recall_number: 10 28 | 29 | # And overwrite model parameters. 30 | #model: 31 | # Save/load encoder. 
32 | #save_encoder: True 33 | #load_encoder: ./experiments/maes_baselines/serial_recall_cl/maes/20180702_132024/models/encoder_episode_09400.pth.tar 34 | #freeze_encoder: True 35 | 36 | -------------------------------------------------------------------------------- /configs/maes_baselines/maes/maes_repeat_serial_recall.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/maes_baselines/maes/default_maes.yaml, 4 | configs/maes_baselines/default_problem.yaml, 5 | configs/maes_baselines/default_training.yaml 6 | 7 | # Then overwrite problem name(s). 8 | training: 9 | problem: 10 | name: &name RepeatSerialRecallCommandLines 11 | # Number of repetitions. 12 | min_recall_number: 1 13 | max_recall_number: 4 14 | 15 | validation: 16 | problem: 17 | name: *name 18 | # Number of repetitions. 19 | min_recall_number: 5 20 | max_recall_number: 5 21 | 22 | testing: 23 | problem: 24 | name: *name 25 | # Number of repetitions. 26 | min_recall_number: 10 27 | max_recall_number: 10 28 | 29 | # And overwrite model parameters. 30 | #model: 31 | # Save/load encoder. 32 | #save_encoder: True 33 | #load_encoder: ./experiments/maes_baselines/serial_recall_cl/maes/20180702_132024/models/encoder_episode_09400.pth.tar 34 | #freeze_encoder: True 35 | 36 | -------------------------------------------------------------------------------- /configs/maes_baselines/es_ntm/es_ntm_serial_recall.yaml: -------------------------------------------------------------------------------- 1 | # Load the following (default) configs first. 2 | default_configs: 3 | configs/maes_baselines/default_problem.yaml, 4 | configs/maes_baselines/default_training.yaml 5 | 6 | # Then overwrite problem name(s). 
7 | training: 8 | problem: 9 | name: &name SerialRecallCommandLines 10 | 11 | validation: 12 | problem: 13 | name: *name 14 | 15 | testing: 16 | problem: 17 | name: *name 18 | 19 | # Model parameters: 20 | model: 21 | name: EncoderSolverNTM 22 | # Optional parameter: visualization. 23 | #visualization_mode: 2 24 | # Pass the whole state from encoder to solver cell. 25 | pass_cell_state: True 26 | # Controller parameters. 27 | controller: 28 | name: RNNController 29 | hidden_state_size: 20 30 | num_layers: 1 31 | non_linearity: sigmoid 32 | # Interface 33 | interface: 34 | shift_size: 3 35 | num_read_heads: 1 36 | use_content_based_addressing: False 37 | # Memory parameters. 38 | memory: 39 | num_content_bits: 15 40 | num_addresses: -1 41 | -------------------------------------------------------------------------------- /configs/vision/alexnet_mnist.yaml: -------------------------------------------------------------------------------- 1 | # Problem parameters: 2 | training: 3 | #seed_numpy: 4354 4 | #seed_torch: 2452 5 | problem: 6 | name: &name MNIST 7 | batch_size: &b 64 8 | use_train_data: True 9 | resize: [224, 224] 10 | # Use sampler that operates on a subset. 
11 | sampler: 12 | name: SubsetRandomSampler 13 | indices: [0, 55000] 14 | # optimizer parameters: 15 | optimizer: 16 | name: Adam 17 | lr: 0.01 18 | # Terminal condition parameters: 19 | terminal_conditions: 20 | loss_stop: 1.0e-2 21 | episode_limit: 50000 22 | epoch_limit: 10 23 | 24 | # Problem parameters: 25 | validation: 26 | problem: 27 | name: *name 28 | batch_size: *b 29 | use_train_data: True # True because we are splitting the training set to: validation and training 30 | resize: [224, 224] 31 | 32 | # Problem parameters: 33 | testing: 34 | problem: 35 | name: *name 36 | batch_size: *b 37 | use_train_data: False 38 | resize: [224, 224] 39 | 40 | # Model parameters: 41 | model: 42 | name: AlexnetWrapper 43 | pretrained: False -------------------------------------------------------------------------------- /configs/vqa_baselines/default_vqa_med.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | # Problem parameters: 3 | problem: 4 | name: &name VQAMED 5 | batch_size: 64 6 | image: &image 7 | width: 256 8 | height: 128 9 | question: &question 10 | embedding_type: random 11 | embedding_dim: 300 12 | settings: 13 | data_folder: '~/data/ImageClef-2019-VQA-Med-Training/' 14 | split: 'train' 15 | # Set optimizer. 
16 | optimizer: 17 | name: Adam 18 | lr: 0.005 19 | # Optional parameter, its presence results in clipping gradient to a range (-gradient_clipping, gradient_clipping) 20 | gradient_clipping: 10 21 | # Terminal condition parameters: 22 | terminal_conditions: 23 | loss_stop: 0.0001 24 | episode_limit: 100000 25 | epoch_limit: 20 26 | 27 | 28 | validation: 29 | # Problem parameters: 30 | problem: 31 | name: *name 32 | batch_size: 64 33 | image: *image 34 | question: *question 35 | settings: 36 | data_folder: '~/data/ImageClef-2019-VQA-Med-Validation/' 37 | split: 'valid' 38 | -------------------------------------------------------------------------------- /configs/mental_model/mm_cog.yaml: -------------------------------------------------------------------------------- 1 | # Problem parameters: 2 | training: 3 | cuda: True 4 | problem: 5 | name: COG 6 | # Size of generated input: [batch_size x sequence_length x classes]. 7 | batch_size: 48 8 | tasks: &task all 9 | data_folder: '~/data/cog' 10 | set: train 11 | use_mask: False 12 | dataset_type: canonical 13 | # Set optimizer. 14 | optimizer: 15 | name: Adam 16 | lr: 0.00005 17 | weight_decay: 0.00002 18 | amsgrad: True 19 | gradient_clipping: 10 20 | terminal_conditions: 21 | epoch_limit: 20 22 | episode_limit: 10000000 23 | 24 | testing: 25 | cuda: True 26 | problem: 27 | name: COG 28 | # Size of generated input: [batch_size x sequence_length x classes]. 29 | batch_size: 48 30 | tasks: *task 31 | data_folder: '~/data/cog' 32 | set: test 33 | dataset_type: canonical 34 | 35 | validation: 36 | cuda: True 37 | problem: 38 | name: COG 39 | # Size of generated input: [batch_size x sequence_length x classes]. 
40 | batch_size: 48 41 | tasks: *task 42 | data_folder: '~/data/cog' 43 | set: val 44 | use_mask: False 45 | dataset_type: canonical 46 | partial_validation_interval: 100 47 | 48 | model: 49 | name: MentalModel 50 | -------------------------------------------------------------------------------- /docs/source/api_reference/4_workers.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | .. currentmodule:: miprometheus.workers 5 | 6 | miprometheus.workers 7 | =================================== 8 | .. automodule:: miprometheus.workers 9 | :members: 10 | :special-members: 11 | :exclude-members: __dict__,__weakref__ 12 | 13 | Worker 14 | -------------- 15 | .. autoclass:: Worker 16 | :members: 17 | :special-members: 18 | :exclude-members: __dict__,__weakref__ 19 | 20 | Trainer 21 | ---------------------------- 22 | .. autoclass:: Trainer 23 | :members: 24 | :special-members: 25 | :exclude-members: __dict__,__weakref__ 26 | 27 | OfflineTrainer 28 | ---------------------------- 29 | .. autoclass:: OfflineTrainer 30 | :members: 31 | :special-members: 32 | :exclude-members: __dict__,__weakref__ 33 | 34 | OnlineTrainer 35 | ---------------------------- 36 | .. autoclass:: OnlineTrainer 37 | :members: 38 | :special-members: 39 | :exclude-members: __dict__,__weakref__ 40 | 41 | Tester 42 | ---------------------------- 43 | .. 
autoclass:: Tester 44 | :members: 45 | :special-members: 46 | :exclude-members: __dict__,__weakref__ 47 | -------------------------------------------------------------------------------- /miprometheus/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .app_state import AppState 2 | from .param_interface import ParamInterface 3 | from .param_registry import MetaSingletonABC, ParamRegistry 4 | from .sampler_factory import SamplerFactory 5 | from .singleton import SingletonMetaClass 6 | from .split_indices import split_indices 7 | from .statistics_collector import StatisticsCollector 8 | from .statistics_aggregator import StatisticsAggregator 9 | from .time_plot import TimePlot 10 | from .data_dict import DataDict 11 | 12 | from .loss.masked_cross_entropy_loss import MaskedCrossEntropyLoss 13 | from .loss.masked_bce_with_logits_loss import MaskedBCEWithLogitsLoss 14 | 15 | from .problems_utils.generate_feature_maps import GenerateFeatureMaps 16 | from .problems_utils.language import Language 17 | 18 | __all__ = [ 19 | 'AppState', 20 | 'ParamInterface', 21 | 'MetaSingletonABC', 22 | 'ParamRegistry', 23 | 'SamplerFactory', 24 | 'SingletonMetaClass', 25 | 'split_indices', 26 | 'StatisticsCollector', 27 | 'StatisticsAggregator', 28 | 'TimePlot', 29 | 'DataDict', 30 | 'MaskedCrossEntropyLoss', 31 | 'MaskedBCEWithLogitsLoss', 32 | 'GenerateFeatureMaps', 33 | 'Language' 34 | ] 35 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. Mi-Prometheus documentation master file, created by 2 | sphinx-quickstart on Wed Apr 25 11:52:47 2018. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 
5 | 6 | :github_url: https://github.com/IBM/mi-prometheus 7 | 8 | MI Prometheus documentation 9 | =================================== 10 | 11 | MI Prometheus is an open source Python library, built using PyTorch, that enables reproducible Machine Learning research. 12 | 13 | .. toctree:: 14 | :glob: 15 | :maxdepth: 1 16 | :caption: Notes 17 | 18 | notes/* 19 | 20 | .. toctree:: 21 | :glob: 22 | :maxdepth: 1 23 | :caption: MI-Prometheus Primer 24 | 25 | mip_primer/* 26 | 27 | .. toctree:: 28 | :glob: 29 | :maxdepth: 1 30 | :caption: Reproducible Research 31 | 32 | research/* 33 | 34 | .. toctree:: 35 | :glob: 36 | :maxdepth: 1 37 | :caption: Tutorials 38 | 39 | tutorials/* 40 | 41 | 42 | .. toctree:: 43 | :glob: 44 | :maxdepth: 4 45 | :caption: Package Reference 46 | 47 | api_reference/* 48 | 49 | Indices and tables 50 | ================== 51 | 52 | * :ref:`genindex` 53 | * :ref:`modindex` 54 | * :ref:`search` 55 | -------------------------------------------------------------------------------- /configs/thalnet/seq_pixel_mnist.yaml: -------------------------------------------------------------------------------- 1 | # Problem parameters: 2 | training: 3 | problem: 4 | name: &name SequentialPixelMNIST 5 | batch_size: &bs 64 6 | index: [0, 54999] 7 | use_train_data: True 8 | root_dir: '~/data/mnist' 9 | # optimizer parameters: 10 | optimizer: 11 | # Exact name of the pytorch optimizer function 12 | name: Adam 13 | # Function arguments of the optimizer, by name 14 | lr: 0.01 15 | terminal_conditions: 16 | loss_stop: 1.0e-5 17 | episode_limit: 50000 18 | 19 | # Problem parameters: 20 | testing: 21 | problem: 22 | name: *name 23 | batch_size: *bs 24 | use_train_data: False 25 | root_dir: '~/data/mnist' 26 | 27 | # Problem parameters: 28 | validation: 29 | problem: 30 | name: *name 31 | batch_size: *bs 32 | index: [55000, 59999] 33 | use_train_data: True 34 | root_dir: '~/data/mnist' 35 | 36 | # Model parameters: 37 | model: 38 | name: ThalNetModel 39 | # Controller 
parameters. 40 | context_input_size: 32 41 | input_size: 1 # row_size 42 | output_size: 10 # number of classes 43 | center_size_per_module: 32 44 | num_modules: 4 45 | -------------------------------------------------------------------------------- /configs/vision/alexnet_cifar10.yaml: -------------------------------------------------------------------------------- 1 | # Problem parameters: 2 | training: 3 | #seed_numpy: 4354 4 | #seed_torch: 2452 5 | problem: 6 | name: &name CIFAR10 7 | batch_size: &b 64 8 | use_train_data: True 9 | resize: [224, 224] 10 | # Use sampler that operates on a subset. 11 | sampler: 12 | name: SubsetRandomSampler 13 | indices: [0, 45000] 14 | # optimizer parameters: 15 | optimizer: 16 | name: Adam 17 | lr: 0.01 18 | terminal_conditions: 19 | loss_stop: 1.0e-5 20 | episode_limit: 50000 21 | 22 | # Problem parameters: 23 | validation: 24 | problem: 25 | name: *name 26 | batch_size: *b 27 | use_train_data: True # True because we are splitting the training set to: validation and training 28 | resize: [224, 224] 29 | # Use sampler that operates on a subset. 30 | sampler: 31 | name: SubsetRandomSampler 32 | indices: [45000, 50000] 33 | 34 | # Problem parameters: 35 | testing: 36 | problem: 37 | name: *name 38 | batch_size: *b 39 | use_train_data: False 40 | resize: [224, 224] 41 | 42 | # Model parameters: 43 | model: 44 | name: AlexnetWrapper 45 | pretrained: False 46 | -------------------------------------------------------------------------------- /configs/vision/lenet5_mnist.yaml: -------------------------------------------------------------------------------- 1 | # Training parameters: 2 | training: 3 | problem: 4 | name: &name MNIST 5 | batch_size: &b 64 6 | use_train_data: True 7 | data_folder: &folder '~/data/mnist' 8 | resize: [32, 32] 9 | # Use sampler that operates on a subset. 
10 | sampler: 11 | name: SubsetRandomSampler 12 | indices: [0, 55000] 13 | # optimizer parameters: 14 | optimizer: 15 | name: Adam 16 | lr: 0.01 17 | # settings parameters 18 | terminal_conditions: 19 | loss_stop: 1.0e-2 20 | episode_limit: 10000 21 | epoch_limit: 10 22 | 23 | # Validation parameters: 24 | validation: 25 | partial_validation_interval: 500 26 | problem: 27 | name: *name 28 | batch_size: *b 29 | use_train_data: True # True because we are splitting the training set to: validation and training 30 | data_folder: *folder 31 | resize: [32, 32] 32 | # Use sampler that operates on a subset. 33 | sampler: 34 | name: SubsetRandomSampler 35 | indices: [55000, 60000] 36 | 37 | # Testing parameters: 38 | testing: 39 | problem: 40 | name: *name 41 | batch_size: *b 42 | use_train_data: False 43 | data_folder: *folder 44 | resize: [32, 32] 45 | 46 | # Model parameters: 47 | model: 48 | name: LeNet5 49 | -------------------------------------------------------------------------------- /configs/vqa_baselines/default_sort_of_clevr.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | # Problem parameters: 3 | problem: 4 | name: &name SortOfCLEVR 5 | batch_size: &b 64 6 | data_folder: '~/data/sort-of-clevr/' 7 | split: 'train' 8 | dataset_size: &ds 12000 9 | regenerate: False 10 | img_size: &imgs 128 11 | # Set optimizer. 
12 | optimizer: 13 | name: Adam 14 | lr: 0.005 15 | # Optional parameter, its presence results in clipping gradient to a range (-gradient_clipping, gradient_clipping) 16 | gradient_clipping: 10 17 | # Terminal condition parameters: 18 | terminal_conditions: 19 | loss_stop: 0.0001 20 | episode_limit: 100000 21 | epoch_limit: 20 22 | 23 | testing: 24 | # Problem parameters: 25 | problem: 26 | name: *name 27 | batch_size: *b 28 | data_folder: '~/data/sort-of-clevr/' 29 | dataset_size: *ds 30 | split: 'test' 31 | img_size: *imgs 32 | regenerate: False 33 | 34 | validation: 35 | # Problem parameters: 36 | problem: 37 | name: *name 38 | batch_size: *b 39 | data_folder: '~/data/sort-of-clevr/' 40 | split: 'val' 41 | dataset_size: *ds 42 | img_size: *imgs 43 | regenerate: False -------------------------------------------------------------------------------- /configs/vision/grid_trainer_mnist.yaml: -------------------------------------------------------------------------------- 1 | grid_tasks: 2 | - 3 | default_configs: configs/vision/lenet5_mnist.yaml 4 | - 5 | default_configs: configs/vision/simplecnn_mnist.yaml 6 | 7 | # Set exactly the same experiment conditions for the 2 tasks. 8 | grid_overwrite: 9 | training: 10 | problem: 11 | batch_size: &b 1000 12 | sampler: 13 | name: SubsetRandomSampler 14 | indices: [0, 55000] 15 | # Set the same optimizer parameters. 16 | optimizer: 17 | name: Adam 18 | lr: 0.01 19 | # Set the same terminal conditions. 20 | terminal_conditions: 21 | loss_stop: 4.0e-2 22 | episode_limit: 10000 23 | epoch_limit: 10 24 | 25 | # Problem parameters: 26 | validation: 27 | problem: 28 | batch_size: *b 29 | sampler: 30 | name: SubsetRandomSampler 31 | indices: [55000, 60000] 32 | 33 | testing: 34 | problem: 35 | batch_size: *b 36 | 37 | grid_settings: 38 | # Set number of repetitions of each experiments. 39 | experiment_repetitions: 5 40 | # Set number of concurrent running experiments. 41 | max_concurrent_runs: 4 42 | # Set trainer. 
43 | trainer: mip-online-trainer 44 | -------------------------------------------------------------------------------- /configs/vqa_baselines/default_shape_color_query.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | # Problem parameters: 3 | cuda: True 4 | problem: 5 | name: &name ShapeColorQuery 6 | batch_size: &b 64 7 | data_folder: '~/data/shape-color-query/' 8 | split: 'train' 9 | dataset_size: &ds 12000 10 | regenerate: False 11 | img_size: &imgs 128 12 | 13 | # Set optimizer. 14 | optimizer: 15 | name: Adam 16 | lr: 0.005 17 | # Optional parameter, its presence results in clipping gradient to a range (-gradient_clipping, gradient_clipping) 18 | gradient_clipping: 10 19 | # Terminal condition parameters: 20 | terminal_condition: 21 | loss_stop: 0.0001 22 | max_episodes: 100000 23 | 24 | testing: 25 | # Problem parameters: 26 | problem: 27 | name: *name 28 | batch_size: *b 29 | data_folder: '~/data/shape-color-query/' 30 | split: 'test' 31 | dataset_size: *ds 32 | regenerate: False 33 | img_size: *imgs 34 | 35 | validation: 36 | # Problem parameters: 37 | problem: 38 | name: *name 39 | batch_size: *b 40 | data_folder: '~/data/shape-color-query/' 41 | split: 'val' 42 | dataset_size: *ds 43 | regenerate: False 44 | img_size: *imgs -------------------------------------------------------------------------------- /configs/relational_net/sort_of_clevr.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | # Problem parameters: 3 | problem: 4 | name: &name SortOfCLEVR 5 | batch_size: &b 64 6 | data_folder: '~/data/sort-of-clevr/' 7 | split: 'train' 8 | dataset_size: &ds 12000 9 | regenerate: False 10 | img_size: &imgs 128 11 | # Set optimizer. 
12 | optimizer: 13 | name: Adam 14 | lr: 2.5e-4 15 | # Optional parameter, its presence results in clipping gradient to a range (-gradient_clipping, gradient_clipping) 16 | gradient_clipping: 10 17 | # Terminal condition parameters: 18 | terminal_conditions: 19 | loss_stop: 0.1 20 | episode_limit: 100000 21 | epoch_limit: 10 22 | 23 | testing: 24 | # Problem parameters: 25 | problem: 26 | name: *name 27 | batch_size: *b 28 | data_folder: '~/data/sort-of-clevr/' 29 | dataset_size: *ds 30 | split: 'test' 31 | img_size: *imgs 32 | regenerate: False 33 | 34 | validation: 35 | # Problem parameters: 36 | problem: 37 | name: *name 38 | batch_size: *b 39 | data_folder: '~/data/sort-of-clevr/' 40 | split: 'val' 41 | dataset_size: *ds 42 | img_size: *imgs 43 | regenerate: False 44 | 45 | # Model parameters: 46 | model: 47 | name: RelationalNetwork 48 | -------------------------------------------------------------------------------- /configs/dwm_baselines/dnc/dnc_default_settings_simple_task.yaml: -------------------------------------------------------------------------------- 1 | # Problem parameters: 2 | training: 3 | problem: 4 | # Size of generated input: [batch_size x sequence_length x number of control and data bits]. 5 | batch_size: 16 6 | # Parameters denoting min and max lengths. 7 | min_sequence_length: 1 8 | max_sequence_length: 10 9 | # Set optimizer. 10 | optimizer: 11 | name: Adam 12 | lr: 0.0005 13 | # Optional parameter, its presence results in clipping gradient to a range (-gradient_clipping, gradient_clipping) 14 | gradient_clipping: 10 15 | # Terminal condition parameters: 16 | terminal_conditions: 17 | loss_stop: 0.0001 18 | episode_limit: 100000 19 | 20 | # Problem parameters: 21 | validation: 22 | problem: 23 | # Size of generated input: [batch_size x sequence_length x number of control + data bits]. 24 | batch_size: 64 25 | # Parameters denoting min and max lengths. 
26 | min_sequence_length: 100 27 | max_sequence_length: 100 28 | 29 | # Problem parameters: 30 | testing: 31 | problem: 32 | # Size of generated input: [batch_size x sequence_length x number of control + data bits]. 33 | batch_size: 64 34 | # Parameters denoting min and max lengths. 35 | min_sequence_length: 1000 36 | max_sequence_length: 1000 37 | -------------------------------------------------------------------------------- /docs/source/api_reference/5_grid_workers.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | .. currentmodule:: miprometheus.grid_workers 5 | 6 | miprometheus.grid_workers 7 | =================================== 8 | .. automodule:: miprometheus.grid_workers 9 | :members: 10 | :special-members: 11 | :exclude-members: __dict__,__weakref__ 12 | 13 | GridWorker 14 | ---------------------------- 15 | .. autoclass:: GridWorker 16 | :members: 17 | :special-members: 18 | :exclude-members: __dict__,__weakref__ 19 | 20 | 21 | GridTrainerCPU 22 | -------------- 23 | .. autoclass:: GridTrainerCPU 24 | :members: 25 | :special-members: 26 | :exclude-members: __dict__,__weakref__ 27 | 28 | GridTrainerGPU 29 | -------------- 30 | .. autoclass:: GridTrainerGPU 31 | :members: 32 | :special-members: 33 | :exclude-members: __dict__,__weakref__ 34 | 35 | 36 | GridTesterCPU 37 | -------------- 38 | .. autoclass:: GridTesterCPU 39 | :members: 40 | :special-members: 41 | :exclude-members: __dict__,__weakref__ 42 | 43 | GridTesterGPU 44 | -------------- 45 | .. autoclass:: GridTesterGPU 46 | :members: 47 | :special-members: 48 | :exclude-members: __dict__,__weakref__ 49 | 50 | 51 | GridAnalyzer 52 | -------------- 53 | .. 
autoclass:: GridAnalyzer 54 | :members: 55 | :special-members: 56 | :exclude-members: __dict__,__weakref__ 57 | -------------------------------------------------------------------------------- /configs/dwm_baselines/default_settings_simple_task.yaml: -------------------------------------------------------------------------------- 1 | # Problem parameters: 2 | training: 3 | problem: 4 | # Size of generated input: [batch_size x sequence_length x number of control and data bits]. 5 | batch_size: 64 6 | # Parameters denoting min and max lengths. 7 | min_sequence_length: 1 8 | max_sequence_length: 10 9 | # Set optimizer. 10 | optimizer: 11 | name: Adam 12 | lr: 0.005 13 | # Optional parameter, its presence results in clipping gradient to a range (-gradient_clipping, gradient_clipping) 14 | gradient_clipping: 10 15 | # Terminal condition parameters: 16 | terminal_conditions: 17 | loss_stop: 0.0001 18 | episode_limit: 100000 19 | 20 | 21 | # Problem parameters: 22 | validation: 23 | problem: 24 | # Size of generated input: [batch_size x sequence_length x number of control + data bits]. 25 | batch_size: 64 26 | # Parameters denoting min and max lengths. 27 | min_sequence_length: 100 28 | max_sequence_length: 100 29 | 30 | 31 | # Problem parameters: 32 | testing: 33 | problem: 34 | # Size of generated input: [batch_size x sequence_length x number of control + data bits]. 35 | batch_size: 64 36 | # Parameters denoting min and max lengths. 37 | min_sequence_length: 1000 38 | max_sequence_length: 1000 39 | 40 | -------------------------------------------------------------------------------- /configs/vision/simplecnn_cifar10.yaml: -------------------------------------------------------------------------------- 1 | # Problem parameters: 2 | training: 3 | #seed_numpy: 4354 4 | #seed_torch: 2452 5 | problem: 6 | name: &name CIFAR10 7 | batch_size: &b 64 8 | use_train_data: True 9 | # Use sampler that operates on a subset. 
10 | sampler: 11 | name: SubsetRandomSampler 12 | indices: [0, 45000] 13 | # optimizer parameters: 14 | optimizer: 15 | name: Adam 16 | lr: 0.01 17 | terminal_conditions: 18 | loss_stop: 1.0e-5 19 | episode_limit: 50000 20 | 21 | # Problem parameters: 22 | validation: 23 | problem: 24 | name: *name 25 | batch_size: *b 26 | use_train_data: True # True because we are splitting the training set to: validation and training 27 | # Use sampler that operates on a subset. 28 | sampler: 29 | name: SubsetRandomSampler 30 | indices: [45000, 50000] 31 | 32 | # Problem parameters: 33 | testing: 34 | problem: 35 | name: *name 36 | batch_size: *b 37 | use_train_data: False 38 | 39 | # Model parameters: 40 | model: 41 | name: SimpleConvNet 42 | conv1: 43 | out_channels: 6 44 | kernel_size: 5 45 | stride: 1 46 | padding: 0 47 | conv2: 48 | out_channels: 16 49 | kernel_size: 5 50 | stride: 1 51 | padding: 0 52 | maxpool1: 53 | kernel_size: 2 54 | maxpool2: 55 | kernel_size: 2 56 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | # Copyright (C) tkornuta, IBM Corporation 2018-2019 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | language: python 16 | python: 3.6 17 | 18 | # Safelist: focus Travis' attention on the master and develop branches only. 
19 | branches: 20 | only: 21 | - master 22 | - develop 23 | 24 | install: 25 | - pip3 install flake8 -r requirements.txt 26 | 27 | before_script: 28 | # Test flake8 compatibility. 29 | - python3 -m flake8 --version # flake8 3.6.0 on CPython 3.6.5 on Linux 30 | # stop the build if there are Python syntax errors or undefined names 31 | - python3 -m flake8 . --count --select=E9,F63,F72,F82 --show-source --statistics 32 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 33 | - python3 -m flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 34 | 35 | script: 36 | # Try to build documentation. 37 | - ./docs/build_docs.sh 38 | -------------------------------------------------------------------------------- /configs/dwm_baselines/dnc/dnc_default_settings_dual_task.yaml: -------------------------------------------------------------------------------- 1 | # Problem parameters: 2 | training: 3 | problem: 4 | # Size of generated input: [batch_size x sequence_length x number of control and data bits]. 5 | batch_size: 64 6 | # Parameters denoting min and max lengths. 7 | min_sequence_length: 1 8 | max_sequence_length: 6 9 | num_subseq_min: 1 10 | num_subseq_max: 3 11 | # Set optimizer. 12 | optimizer: 13 | name: Adam 14 | lr: 0.0005 15 | # Optional parameter, its presence results in clipping gradient to a range (-gradient_clipping, gradient_clipping) 16 | gradient_clipping: 10 17 | # Terminal condition parameters: 18 | terminal_conditions: 19 | loss_stop: 0.0001 20 | episode_limit: 100000 21 | 22 | # Problem parameters: 23 | testing: 24 | problem: 25 | # Size of generated input: [batch_size x sequence_length x number of control + data bits]. 26 | batch_size: 64 27 | # Parameters denoting min and max lengths. 
28 | min_sequence_length: 50 29 | max_sequence_length: 50 30 | num_subseq_min: 20 31 | num_subseq_max: 20 32 | 33 | # Problem parameters: 34 | validation: 35 | problem: 36 | # Size of generated input: [batch_size x sequence_length x number of control + data bits]. 37 | batch_size: 64 38 | # Parameters denoting min and max lengths. 39 | min_sequence_length: 20 40 | max_sequence_length: 20 41 | num_subseq_min: 5 42 | num_subseq_max: 5 43 | -------------------------------------------------------------------------------- /configs/dwm_baselines/default_settings_dual_task.yaml: -------------------------------------------------------------------------------- 1 | # Problem parameters: 2 | training: 3 | problem: 4 | # Size of generated input: [batch_size x sequence_length x number of control and data bits]. 5 | batch_size: 64 6 | # Parameters denoting min and max lengths. 7 | min_sequence_length: 1 8 | max_sequence_length: 6 9 | num_subseq_min: 1 10 | num_subseq_max: 3 11 | # Set optimizer. 12 | optimizer: 13 | name: Adam 14 | lr: 0.005 15 | # Optional parameter, its presence results in clipping gradient to a range (-gradient_clipping, gradient_clipping) 16 | gradient_clipping: 10 17 | # Terminal condition parameters: 18 | terminal_conditions: 19 | loss_stop: 0.0001 20 | episode_limit: 100000 21 | 22 | 23 | # Problem parameters: 24 | validation: 25 | problem: 26 | # Size of generated input: [batch_size x sequence_length x number of control + data bits]. 27 | batch_size: 64 28 | # Parameters denoting min and max lengths. 29 | min_sequence_length: 20 30 | max_sequence_length: 20 31 | num_subseq_min: 5 32 | num_subseq_max: 5 33 | 34 | 35 | # Problem parameters: 36 | testing: 37 | problem: 38 | # Size of generated input: [batch_size x sequence_length x number of control + data bits]. 39 | batch_size: 64 40 | # Parameters denoting min and max lengths. 
41 | min_sequence_length: 50 42 | max_sequence_length: 50 43 | num_subseq_min: 20 44 | num_subseq_max: 20 45 | -------------------------------------------------------------------------------- /docs/source/research/1_introduction.rst: -------------------------------------------------------------------------------- 1 | Reproducible Research Primer 2 | ================================================== 3 | `@author: Vincent Marois, Tomasz Kornuta` 4 | 5 | One of the main reasons for the existence of this framework is to enable reproducible research, \ 6 | i.e. to enable other researchers to run experiments designed by us and reproduce our results. 7 | 8 | This section gathers information about our published work and describes how to reproduce the associated experiments. 9 | 10 | +---------------------------------------+ 11 | | Description | 12 | +=======================================+ 13 | | |vigil_decription| | 14 | | | 15 | | See :ref:`vigil-experiments` | 16 | | | 17 | | |vigil_reference| `arxiv/1811.06529`_ | 18 | +---------------------------------------+ 19 | | |cog_decription| | 20 | | | 21 | | See :ref:`cog-experiments` | 22 | +---------------------------------------+ 23 | 24 | .. |vigil_decription| replace:: We have tested the compositional generalization of the MAC and S-MAC models using CLEVR and CoGenT datasets. Our work [1] has been accepted and presented during the `ViGIL Workshop`_ @ NeurIPS 2018. 25 | 26 | .. |vigil_reference| replace:: [1] Marois, Jayram, Albouy, Kornuta, Bouhadjar, Ozcan (2018). On transfer learning using a MAC model variant. 27 | 28 | .. |cog_decription| replace:: Preliminary results of experiments with COG dataset. 29 | 30 | .. _ViGIL Workshop: https://nips2018vigil.github.io/ 31 | .. 
_arxiv/1811.06529: https://arxiv.org/abs/1811.06529 -------------------------------------------------------------------------------- /configs/vqa_baselines/default_clevr.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | # Problem parameters: 3 | problem: 4 | name: &name CLEVR 5 | batch_size: &b 64 6 | settings: 7 | data_folder: &dir '~/data/CLEVR_v1.0' 8 | set: 'train' 9 | dataset_variant: &var 'CLEVR' 10 | images: 11 | raw_images: True 12 | questions: 13 | embedding_type: &emb 'random' 14 | embedding_dim: 300 15 | 16 | # Set optimizer. 17 | optimizer: 18 | name: Adam 19 | lr: 1.0e-4 20 | # Optional parameter, its presence results in clipping gradient to a range (-gradient_clipping, gradient_clipping) 21 | gradient_clipping: 10 22 | # Terminal condition parameters: 23 | terminal_conditions: 24 | loss_stop: 0.03 25 | epoch_limit: 20 26 | 27 | validation: 28 | partial_validation_interval: 200 29 | # Problem parameters: 30 | problem: 31 | name: *name 32 | batch_size: *b 33 | settings: 34 | data_folder: *dir 35 | set: 'val' 36 | dataset_variant: *var 37 | images: 38 | raw_images: True 39 | questions: 40 | embedding_type: *emb 41 | embedding_dim: 300 42 | 43 | testing: 44 | # Problem parameters: 45 | problem: 46 | name: *name 47 | batch_size: *b 48 | settings: 49 | data_folder: *dir 50 | set: 'val' 51 | dataset_variant: *var 52 | images: 53 | raw_images: True 54 | questions: 55 | embedding_type: *emb 56 | embedding_dim: 300 57 | max_test_episodes: -1 # do an entire epoch. 
58 | -------------------------------------------------------------------------------- /configs/vision/simplecnn_mnist.yaml: -------------------------------------------------------------------------------- 1 | # Problem parameters: 2 | training: 3 | #seed_numpy: 4354 4 | #seed_torch: 2452 5 | problem: 6 | name: &name MNIST 7 | batch_size: &b 64 8 | data_folder: &folder '~/data/mnist' 9 | use_train_data: True 10 | resize: [32, 32] 11 | sampler: 12 | name: SubsetRandomSampler 13 | indices: [0, 55000] 14 | #indices: ~/data/mnist/split_a.txt 15 | # optimizer parameters: 16 | optimizer: 17 | name: Adam 18 | lr: 0.01 19 | # settings parameters 20 | terminal_conditions: 21 | loss_stop: 1.0e-2 22 | episode_limit: 1000 23 | epoch_limit: 1 24 | 25 | # Problem parameters: 26 | validation: 27 | problem: 28 | name: *name 29 | batch_size: *b 30 | data_folder: *folder 31 | use_train_data: True # True because we are splitting the training set to: validation and training 32 | resize: [32, 32] 33 | sampler: 34 | name: SubsetRandomSampler 35 | indices: [55000, 60000] 36 | #indices: ~/data/mnist/split_b.txt 37 | #dataloader: 38 | # drop_last: True 39 | 40 | # Problem parameters: 41 | testing: 42 | #seed_numpy: 4354 43 | #seed_torch: 2452 44 | problem: 45 | name: *name 46 | batch_size: *b 47 | data_folder: *folder 48 | use_train_data: False 49 | resize: [32, 32] 50 | 51 | 52 | # Model parameters: 53 | model: 54 | name: SimpleConvNet 55 | conv1: 56 | out_channels: 6 57 | kernel_size: 5 58 | stride: 1 59 | padding: 0 60 | conv2: 61 | out_channels: 16 62 | kernel_size: 5 63 | stride: 1 64 | padding: 0 65 | maxpool1: 66 | kernel_size: 2 67 | maxpool2: 68 | kernel_size: 2 69 | -------------------------------------------------------------------------------- /configs/maes_baselines/default_problem.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | # Set random seeds that will be used during training (and validation). 
When not present using random values. 3 | #seed_numpy: -1 4 | #seed_torch: -1 5 | problem: 6 | # Size of generated input: [batch_size x sequence_length x number of control + data bits]. 7 | control_bits: &cbits 3 8 | data_bits: &dbits 8 9 | batch_size: &bs 64 10 | #randomize_control_lines: True 11 | # Parameters denoting min and max lengths. 12 | min_sequence_length: 3 13 | max_sequence_length: 20 14 | # Size of the dataset. Influences how often curriculum learning will be triggered (when the dataset is exhausted!) 15 | size: 32000 # i.e. every 500 episodes, as in the paper 16 | 17 | # This section is optional. 18 | validation: 19 | # How often the model will be validated/saved. 20 | partial_validation_interval: 100 21 | problem: 22 | # Size of generated input: [batch_size x sequence_length x number of control + data bits]. 23 | control_bits: *cbits 24 | data_bits: *dbits 25 | batch_size: 64 26 | # Parameters denoting min and max lengths. 27 | min_sequence_length: 21 28 | max_sequence_length: 21 29 | size: 64 # Which means that partial validation = full validation. 30 | 31 | 32 | testing: 33 | # Set random seeds that will be used during testing. When not present, random values are used. 34 | #seed_numpy: -1 35 | #seed_torch: -1 36 | # Problem definition. 37 | problem: 38 | # Size of generated input: [batch_size x sequence_length x number of control + data bits]. 39 | control_bits: *cbits 40 | data_bits: *dbits 41 | batch_size: 64 42 | # Parameters denoting min and max lengths. 43 | min_sequence_length: 100 44 | max_sequence_length: 100 45 | size: 64 46 | -------------------------------------------------------------------------------- /configs/thalnet/serial_recall.yaml: -------------------------------------------------------------------------------- 1 | # Problem parameters: 2 | training: 3 | cuda: True 4 | problem: 5 | name: &name SerialRecall 6 | # Size of generated input: [batch_size x sequence_length x number of control and data bits]. 
7 | control_bits: &cbits 2 8 | data_bits: &dbits 8 9 | batch_size: 64 10 | # Parameters denoting min and max lengths. 11 | min_sequence_length: 1 12 | max_sequence_length: 10 13 | # Set optimizer. 14 | optimizer: 15 | name: Adam 16 | lr: 0.005 17 | # Optional parameter, its presence results in clipping gradient to a range (-gradient_clipping, gradient_clipping) 18 | gradient_clipping: 10 19 | # Terminal condition parameters: 20 | terminal_conditions: 21 | loss_stop: 0.0001 22 | episode_limit: 100000 23 | 24 | # Problem parameters: 25 | testing: 26 | problem: 27 | name: *name 28 | # Size of generated input: [batch_size x sequence_length x number of control + data bits]. 29 | control_bits: *cbits 30 | data_bits: *dbits 31 | batch_size: 64 32 | # Parameters denoting min and max lengths. 33 | min_sequence_length: 1000 34 | max_sequence_length: 1000 35 | bias: 0.5 36 | 37 | # Problem parameters: 38 | validation: 39 | problem: 40 | name: *name 41 | # Size of generated input: [batch_size x sequence_length x number of control + data bits]. 42 | control_bits: *cbits 43 | data_bits: *dbits 44 | batch_size: 64 45 | # Parameters denoting min and max lengths. 46 | min_sequence_length: 100 47 | max_sequence_length: 100 48 | bias: 0.5 49 | 50 | # Model parameters: 51 | model: 52 | name: ThalNetModel 53 | # Controller parameters. 54 | context_input_size: 32 55 | input_size: 10 # row_size 56 | output_size: 8 # number of classes 57 | center_size_per_module: 32 58 | num_modules: 4 59 | 60 | -------------------------------------------------------------------------------- /miprometheus/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Main imports. 2 | from .model import Model 3 | from .model_factory import ModelFactory 4 | from .sequential_model import SequentialModel 5 | 6 | # .cog 7 | from .cog.network import CogModel 8 | 9 | # MANN models. 
10 | from .dnc.dnc_model import DNC 11 | from .dwm.dwm_model import DWM 12 | from .ntm.ntm_model import NTM 13 | from .thalnet.thalnet_model import ThalNetModel 14 | 15 | # .encoder_solver 16 | from .encoder_solver import * 17 | from .encoder_solver.es_lstm_model import EncoderSolverLSTM 18 | from .encoder_solver.es_ntm_model import EncoderSolverNTM 19 | from .encoder_solver.maes_model import MAES 20 | 21 | # Other models. 22 | from .lstm.lstm_model import LSTM 23 | from .mental_model.mental_model import MentalModel 24 | 25 | # VQA models. 26 | from .mac.model import MACNetwork 27 | from .s_mac.s_mac import sMacNetwork 28 | from .VWM_model.model import MACNetworkSequential 29 | from .relational_net.relational_network import RelationalNetwork 30 | 31 | # .vqa_baselines 32 | from .vqa_baselines.cnn_lstm import CNN_LSTM 33 | from .vqa_baselines.stacked_attention_networks.stacked_attention_model import StackedAttentionNetwork 34 | 35 | # .vision 36 | from .vision.alexnet_wrapper import AlexnetWrapper 37 | from .vision.lenet5 import LeNet5 38 | from .vision.simple_cnn import SimpleConvNet 39 | 40 | 41 | 42 | __all__ = [ 43 | 'Model', 44 | 'ModelFactory', 45 | 'SequentialModel', 46 | # .cog 47 | 'CogModel', 48 | # .controllers 49 | # MANN models. 50 | 'DNC', 51 | 'DWM', 52 | 'NTM', 53 | 'ThalNetModel', 54 | # .encoder_solver 55 | 'EncoderSolverLSTM', 56 | 'EncoderSolverNTM', 57 | 'MAES', 58 | # Other models. 59 | 'LSTM', 60 | 'MentalModel', 61 | # VQA models. 
62 | 'MACNetwork', 63 | 'sMacNetwork', 64 | 'MACNetworkSequential', 65 | 'RelationalNetwork', 66 | # .vqa_baselines 67 | 'CNN_LSTM', 68 | 'StackedAttentionNetwork', 69 | # .vision 70 | 'AlexnetWrapper', 'LeNet5', 'SimpleConvNet', 71 | ] 72 | -------------------------------------------------------------------------------- /configs/text2text/translation.yaml: -------------------------------------------------------------------------------- 1 | # Problem parameters: 2 | training: 3 | problem: 4 | name: &name translation 5 | batch_size: &b 32 6 | training_size: &ts 0.90 7 | output_lang_name: 'fra' 8 | max_sequence_length: &seq 15 9 | eng_prefixes: ["i am ", "i m ", "he is", "he s ", "she is", "she s", "you are", "you re ", "we are", "we re ", "they are", "they re "] 10 | use_train_data: True 11 | data_folder: '~/data/language' 12 | reverse: False 13 | 14 | cuda: True 15 | 16 | # set optimizer 17 | optimizer: 18 | # Exact name of the pytorch optimizer function 19 | name: SGD 20 | # Function arguments of the optimizer, by name 21 | lr: 0.01 22 | 23 | terminal_condition: 24 | loss_stop: 0.1 25 | max_episodes: 1300 26 | 27 | 28 | # Problem parameters: 29 | validation: 30 | problem: 31 | name: *name 32 | batch_size: *b 33 | training_size: *ts 34 | output_lang_name: 'fra' 35 | max_sequence_length: *seq 36 | eng_prefixes: ["i am ", "i m ", "he is", "he s ", "she is", "she s", "you are", "you re ", "we are", "we re ", "they are", "they re "] 37 | use_train_data: True 38 | data_folder: '~/data/language' 39 | reverse: False 40 | 41 | 42 | # Problem parameters: 43 | testing: 44 | problem: 45 | name: *name 46 | batch_size: *b 47 | training_size: *ts 48 | output_lang_name: 'fra' 49 | max_sequence_length: *seq 50 | eng_prefixes: ["i am ", "i m ", "he is", "he s ", "she is", "she s", "you are", "you re ", "we are", "we re ", "they are", "they re "] 51 | use_train_data: False 52 | data_folder: '~/data/language' 53 | reverse: False 54 | 55 | 56 | 57 | # Model parameters: 58 | model: 59 
| name: simple_encoder_decoder 60 | max_length: *seq 61 | input_voc_size: 3506 # this value is coming from the problem class (problem.input_lang.n_words) 62 | hidden_size: 256 63 | output_voc_size: 5231 64 | encoder_bidirectional: True 65 | -------------------------------------------------------------------------------- /configs/mac/default_clevr.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | # Problem parameters: 3 | problem: 4 | name: &name CLEVR 5 | batch_size: &b 64 6 | settings: 7 | data_folder: &dir '~/data/CLEVR_v1.0' 8 | set: 'train' 9 | dataset_variant: &var 'CLEVR' 10 | images: 11 | raw_images: False 12 | feature_extractor: 13 | cnn_model: &cnn 'resnet101' 14 | num_blocks: 4 15 | questions: 16 | embedding_type: &emb 'random' 17 | embedding_dim: 300 18 | 19 | # Set optimizer. 20 | optimizer: 21 | name: Adam 22 | lr: 1.0e-4 23 | # Optional parameter, its presence results in clipping gradient to a range (-gradient_clipping, gradient_clipping) 24 | gradient_clipping: 10 25 | # Terminal condition parameters: 26 | terminal_conditions: 27 | loss_stop: 0.03 28 | epoch_limit: 20 29 | 30 | # fix the seeds 31 | seed_torch: 0 32 | seed_numpy: 0 33 | 34 | validation: 35 | partial_validation_interval: 200 36 | # Problem parameters: 37 | problem: 38 | name: *name 39 | batch_size: *b 40 | settings: 41 | data_folder: *dir 42 | set: 'val' 43 | dataset_variant: *var 44 | images: 45 | raw_images: False 46 | feature_extractor: 47 | cnn_model: *cnn 48 | num_blocks: 4 49 | questions: 50 | embedding_type: *emb 51 | embedding_dim: 300 52 | 53 | testing: 54 | # Problem parameters: 55 | problem: 56 | name: *name 57 | batch_size: *b 58 | settings: 59 | data_folder: *dir 60 | set: 'val' 61 | dataset_variant: *var 62 | images: 63 | raw_images: False 64 | feature_extractor: 65 | cnn_model: *cnn 66 | num_blocks: 4 67 | questions: 68 | embedding_type: *emb 69 | embedding_dim: 300 70 | max_test_episodes: -1 # do an entire epoch. 
71 | -------------------------------------------------------------------------------- /miprometheus/models/mental_model/ops.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defines attention subunit for mental model 3 | """ 4 | 5 | __author__ = "Emre Sevgen" 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | class SemanticAttention(nn.Module): 11 | """ 12 | ``SemanticAttention`` performs Bahdanau attention on post-LSTM question encoding. 13 | 14 | """ 15 | def __init__(self,attention_size, attention_input_size): 16 | """ 17 | Constructor of the ``SemanticAttention``. Instantiates linear layers. 18 | 19 | :param attention_size: Output size of the linear layer for attention. 20 | :type attention_size: Int 21 | 22 | :param attention_input_size: Input size to the linear layer that generates attention. 23 | :type attention_input_size: Int 24 | 25 | """ 26 | super(SemanticAttention,self).__init__() 27 | self.attention_size = attention_size 28 | 29 | self.attn1 = nn.Linear(attention_input_size,attention_size) 30 | self.trainable_weights = nn.Parameter(torch.randn(attention_size)*0.1) 31 | 32 | # Initialize network 33 | nn.init.xavier_uniform_(self.attn1.weight) 34 | self.attn1.bias.data.fill_(0.0) 35 | 36 | def forward(self,keys,query): 37 | """ 38 | Forward pass of ``SemanticAttention``. 39 | 40 | :param keys: LSTM encoded question. 41 | :type keys: torch.Tensor 42 | 43 | :param query: Query vector generated by linear layer. 
44 | :type query: torch.Tensor 45 | 46 | """ 47 | 48 | query = self.attn1(query) 49 | #print(keys.size()) 50 | #print(self.trainable_weights.size()) 51 | #print(query.size()) 52 | # keys is batch x sequence x embedding length 53 | # trainable weights is embedding length 54 | # query is batch x embedding length 55 | # similarity is batch x sequence 56 | 57 | similarity = torch.sum( 58 | self.trainable_weights.unsqueeze(0).unsqueeze(0) * 59 | torch.tanh(query.unsqueeze(1) + keys), -1 ) 60 | 61 | post_attention = torch.sum( 62 | keys * similarity.unsqueeze(-1), 1) 63 | 64 | return post_attention 65 | 66 | 67 | if __name__ == '__main__': 68 | 69 | semantic_attn = SemanticAttention(8,128) 70 | keys = torch.rand((2,4,8)) 71 | query = torch.rand((2,128)) 72 | semantic_attn(keys,query) 73 | 74 | 75 | -------------------------------------------------------------------------------- /configs/maes_baselines/maes_grid_training.yaml: -------------------------------------------------------------------------------- 1 | # Lists of tasks that will be executed in batch. 2 | grid_tasks: 3 | - 4 | default_configs: configs/maes_baselines/maes/maes_sequence_comparison.yaml 5 | # Parameters that will be overwritten for a given task. 6 | overwrite: 7 | testing: 8 | problem: 9 | min_sequence_length: 99 10 | max_sequence_length: 99 11 | - 12 | default_configs: configs/maes_baselines/maes/maes_sequence_equality.yaml 13 | - 14 | default_configs: configs/maes_baselines/maes/maes_sequence_symmetry.yaml 15 | - 16 | default_configs: configs/maes_baselines/maes/maes_odd_recall.yaml 17 | - 18 | default_configs: configs/maes_baselines/maes/maes_serial_recall.yaml 19 | - 20 | default_configs: configs/maes_baselines/maes/maes_reverse_recall.yaml 21 | 22 | # Parameters that will be overwritten for all tasks. 23 | grid_overwrite: 24 | # Model configuration. 25 | model: 26 | num_content_bits: 15 27 | 28 | # Training configuration. 
29 | training: 30 | problem: 31 | min_sequence_length: 3 32 | max_sequence_length: 20 33 | control_bits: &cbits 3 34 | data_bits: &dbits 8 35 | # Curriculum learning - optional. 36 | curriculum_learning: 37 | interval: 500 38 | initial_max_sequence_length: 5 39 | must_finish: false 40 | # Terminal condition parameters: 41 | terminal_conditions: 42 | loss_stop: 1.0e-5 43 | episodes_limit: 100 44 | 45 | # Validation configuration. 46 | validation: 47 | # How often the model will be validated/saved. 48 | partial_validation_interval: 1000 49 | problem: 50 | control_bits: *cbits 51 | data_bits: *dbits 52 | min_sequence_length: 21 53 | max_sequence_length: 21 54 | 55 | # Testing configuration. 56 | testing: 57 | problem: 58 | min_sequence_length: 101 59 | max_sequence_length: 101 60 | control_bits: *cbits 61 | data_bits: *dbits 62 | 63 | 64 | grid_settings: 65 | # Number of times each task will be repeated. 66 | experiment_repetitions: 3 67 | # Max runs (that will be limited by the actual number of available CPUs/GPUs) 68 | max_concurrent_runs: 7 69 | -------------------------------------------------------------------------------- /docs/source/mip_primer/5_grid_workers.rst: -------------------------------------------------------------------------------- 1 | 2 | Grid Workers Explained 3 | ====================== 4 | `@author: Tomasz Kornuta & Vincent Marois` 5 | 6 | There are five Grid Workers, i.e. scripts which manage sets of experiments on grids of CPUs/GPUs. 7 | These are: 8 | 9 | - two Grid Trainers (separate versions for collections of CPUs and GPUs) spawning several trainings in parallel, 10 | - two Grid Testers (similarly), 11 | - a single Grid Analyzer, which collects the results of several trainings & tests in a given experiment directory into a single CSV file. 12 | 13 | 14 | .. figure:: ../img/worker_grid_class_diagram.png 15 | :scale: 50 % 16 | :alt: Class diagram of the grid workers. 17 | :align: center 18 | 19 | The class inheritance of the grid workers.
The Trainer & Tester classes inherit from a base Worker class, to follow OOP best practices. 20 | 21 | 22 | The Grid Trainers and Testers in fact spawn several instances of base Trainers and Testers respectively. 23 | The CPU & GPU versions execute different operations, i.e. the CPU grid workers assign one processor to each child, whereas the GPU ones assign a single GPU instead. 24 | 25 | Fig. 7 presents the most important sections of the grid trainer configuration files. The `grid_tasks` section defines the grid of experiments that need to be executed, reusing the mechanism of default configuration nesting. 26 | Additionally, in `grid_settings`, the user needs to define the number of repetitions of each experiment, as well as the maximum number of authorized concurrent runs (which later on will be compared to the number of available CPUs/GPUs). 27 | Optionally, the user might overwrite some parameters of a given experiment (in the `overwrite` section) or all experiments at once (`grid_overwrite`). 28 | 29 | As a result of running these Grid Trainers and Testers, the user ends up with an experiment directory containing several models and statistics collected during several training, validation and test repetitions. 30 | The role of the last script, Grid Analyzer, is to iterate through those directories, collecting all statistics and merging them into a single file that facilitates a further analysis of results, the comparison of the models' performance, etc. 31 | -------------------------------------------------------------------------------- /configs/ntm/serial_recall.yaml: -------------------------------------------------------------------------------- 1 | # Problem parameters: 2 | training: 3 | problem: 4 | name: SerialRecallCommandLines 5 | # Size of generated input: [batch_size x sequence_length x number of command and data bits]. 6 | batch_size: &bs 64 7 | # Parameters denoting min and max lengths.
8 | min_sequence_length: 2 9 | max_sequence_length: 20 10 | size: 32000 # i.e. every 500 episodes = 1 epoch 11 | 12 | #Curriculum learning - optional. 13 | curriculum_learning: 14 | initial_max_sequence_length: 5 15 | must_finish: false 16 | 17 | # Optional parameter, its presence results in clipping gradient to a range (-gradient_clipping, gradient_clipping) 18 | gradient_clipping: 10 19 | 20 | # Set optimizer. 21 | optimizer: 22 | name: Adam 23 | lr: 0.01 24 | 25 | # Terminal condition parameters: 26 | terminal_conditions: 27 | loss_stop: 1.0e-4 28 | episode_limit: 10000 29 | 30 | 31 | # Problem parameters: 32 | validation: 33 | problem: 34 | name: SerialRecallCommandLines 35 | # Size of generated input: [batch_size x sequence_length x number of control + data bits]. 36 | batch_size: 64 37 | # Parameters denoting min and max lengths. 38 | min_sequence_length: 21 39 | max_sequence_length: 21 40 | 41 | 42 | # Problem parameters: 43 | testing: 44 | problem: 45 | name: SerialRecallCommandLines 46 | # Size of generated input: [batch_size x sequence_length x number of control + data bits]. 47 | batch_size: 64 48 | # Parameters denoting min and max lengths. 49 | min_sequence_length: 100 50 | max_sequence_length: 100 51 | 52 | # Model parameters: 53 | model: 54 | name: NTM 55 | # Optional parameter: visualization. 56 | visualization_mode: 2 57 | # Controller parameters. 58 | controller: 59 | name: RNNController 60 | hidden_state_size: 20 61 | num_layers: 1 62 | non_linearity: sigmoid 63 | # Interface 64 | interface: 65 | num_read_heads: 1 66 | shift_size: 3 67 | # Memory parameters. 
68 | memory: 69 | num_content_bits: 10 70 | num_addresses: -1 71 | -------------------------------------------------------------------------------- /configs/cog/default_cog.yaml: -------------------------------------------------------------------------------- 1 | # Problem parameters: 2 | training: 3 | problem: 4 | name: COG 5 | # Size of generated input: [batch_size x sequence_length x classes]. 6 | batch_size: 48 7 | tasks: [AndCompareColor, AndCompareShape , AndSimpleCompareColor, AndSimpleCompareShape, CompareColor, CompareShape, Exist,ExistColor, ExistColorOf, ExistColorSpace, ExistLastColorSameShape, ExistLastObjectSameObject, ExistLastShapeSameColor, ExistShape, ExistShapeOf, ExistShapeSpace, ExistSpace, GetColor, GetColorSpace, GetShape, GetShapeSpace, SimpleCompareColor, SimpleCompareShape] 8 | words_embed_length : 64 9 | nwords : 24 10 | data_folder: '~/data/cog' 11 | set: val 12 | use_mask: False 13 | dataset_type: canonical 14 | # Set optimizer. 15 | optimizer: 16 | name: Adam 17 | lr: 0.0002 18 | weight_decay: 0.00002 19 | gradient_clipping: 10 20 | terminal_conditions: 21 | epoch_limit: 40 22 | episodes: 100000000 23 | 24 | validation: 25 | partial_validation_interval: 100 26 | problem: 27 | name: COG 28 | # Size of generated input: [batch_size x sequence_length x classes]. 29 | batch_size: 48 30 | tasks: [AndCompareColor, AndCompareShape , AndSimpleCompareColor, AndSimpleCompareShape, CompareColor, CompareShape, Exist,ExistColor, ExistColorOf, ExistColorSpace, ExistLastColorSameShape, ExistLastObjectSameObject, ExistLastShapeSameColor, ExistShape, ExistShapeOf, ExistShapeSpace, ExistSpace, GetColor, GetColorSpace, GetShape, GetShapeSpace, SimpleCompareColor, SimpleCompareShape] 31 | data_folder: '~/data/cog' 32 | set: val 33 | use_mask: False 34 | dataset_type: canonical 35 | 36 | testing: 37 | problem: 38 | name: COG 39 | # Size of generated input: [batch_size x sequence_length x classes]. 
40 | batch_size: 48 41 | tasks: [AndCompareColor, AndCompareShape , AndSimpleCompareColor, AndSimpleCompareShape, CompareColor, CompareShape, Exist,ExistColor, ExistColorOf, ExistColorSpace, ExistLastColorSameShape, ExistLastObjectSameObject, ExistLastShapeSameColor, ExistShape, ExistShapeOf, ExistShapeSpace, ExistSpace, GetColor, GetColorSpace, GetShape, GetShapeSpace, SimpleCompareColor, SimpleCompareShape] 42 | data_folder: '~/data/cog' 43 | set: test 44 | dataset_type: canonical 45 | -------------------------------------------------------------------------------- /miprometheus/utils/split_indices.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | # Copyright (C) IBM Corporation 2018 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | """ 19 | split_indices.py: 20 | 21 | - Contains the definition of a split_indices function. 22 | 23 | """ 24 | __author__ = "Tomasz Kornuta" 25 | 26 | import numpy as np 27 | 28 | 29 | def split_indices(length, split, logger, random_sampling=True): 30 | """ 31 | Splits the indices of an array of a given ``length`` into two parts, using the ``split`` as the divider. 32 | 33 | Random sampling is used by default, but can be turned off. 34 | 35 | :param length: Length (size) of the dataset. 
36 | :type length: int 37 | 38 | :param split: Index at which to split: the first ``split`` indices belong to subset a, the remaining ones to subset b. 39 | :type split: int 40 | 41 | :param logger: Logging utility. 42 | :type logger: logging.Logger 43 | 44 | :param random_sampling: Use random sampling (DEFAULT: ``True``). If set to ``False``, will return two ranges \ 45 | instead of arrays with indices. 46 | :type random_sampling: bool 47 | 48 | :return: Two numpy arrays with indices (when random_sampling is ``True``), or two numpy arrays with two elements - \ 49 | the bounds of each range (when ``False``). 50 | 51 | """ 52 | if random_sampling: 53 | logger.info('Using random sampling') 54 | # Random indices. 55 | indices = np.random.permutation(length) 56 | # Split into two pieces. 57 | split_a = indices[0:split] 58 | split_b = indices[split:length] 59 | else: 60 | logger.info('Splitting into two ranges without random sampling') 61 | # Split into two ranges. 62 | split_a = np.asarray([0,split], dtype=int) # outer right indices won't be included! 63 | split_b = np.asarray([split,length], dtype=int) 64 | 65 | return split_a, split_b 66 | -------------------------------------------------------------------------------- /miprometheus/models/controllers/feedforward_controller.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | # Copyright (C) IBM Corporation 2018 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | """feedforward_controller.py: pytorch module implementing wrapper for feedforward controller.""" 19 | __author__ = "Tomasz Kornuta/Ryan L. McAvoy" 20 | 21 | import torch 22 | from torch.nn import Module 23 | 24 | 25 | class FeedforwardController(Module): 26 | """ 27 | A wrapper class for a feedforward controller. 28 | """ 29 | 30 | def __init__(self, params): 31 | """ 32 | Constructor. 33 | 34 | :param params: Dictionary of parameters; must contain the ``input_size`` and ``output_size`` keys. 35 | 36 | """ 37 | # Call constructor of base class. 38 | super(FeedforwardController, self).__init__() 39 | 40 | # Parse parameters. 41 | # Set input and hidden dimensions. 42 | self.input_size = params["input_size"] 43 | self.ctrl_hidden_state_size = params["output_size"] 44 | 45 | # Processes input and produces hidden state of the controller. 46 | self.ff = torch.nn.Linear(self.input_size, self.ctrl_hidden_state_size) 47 | 48 | def init_state(self, batch_size): 49 | """ 50 | Returns 'zero' (initial) state tuple - in this case an empty tuple. 51 | 52 | :param batch_size: Size of the batch in given iteration/epoch (unused). 53 | :returns: Initial state tuple - empty (). 54 | """ 55 | return () 56 | 57 | def forward(self, inputs_BxI, prev_state_tuple): 58 | """ 59 | Controller forward function. 60 | 61 | :param inputs_BxI: a Tensor of input data of size [BATCH_SIZE x INPUT_SIZE] 62 | :param prev_state_tuple: unused - empty tuple () 63 | :returns: outputs a Tensor of size [BATCH_SIZE x OUTPUT_SIZE] and empty tuple. 64 | 65 | """ 66 | # Execute feedforward pass. 67 | hidden_state = self.ff(inputs_BxI) 68 | 69 | # Return hidden_state (as output) and empty state tuple.
70 | return hidden_state, () 71 | -------------------------------------------------------------------------------- /docs/source/mip_primer/3_models_explained.rst: -------------------------------------------------------------------------------- 1 | Models Explained 2 | =================== 3 | `@author: Tomasz Kornuta & Vincent Marois` 4 | 5 | 6 | From the application point of view, the models are the most important components of the system, as the main goal of 7 | the MI-Prometheus framework is to facilitate the development and training of new models. 8 | 9 | Currently, the framework contains a suite of models that can be divided into three categories. 10 | 11 | The first category includes several classic models, such as AlexNet (Krizhevsky & Hinton, 2009) or the LSTM (Long Short-Term Memory) (Hochreiter & Schmidhuber, 12 | 1997). Those wrapper classes, aside from being common baselines, are examples showing how one should incorporate external models into MI-Prometheus. 13 | 14 | Other classical models that we incorporated are sequence-to-sequence models for neural translation such as an LSTM-based encoder-decoder (Sutskever et al., 2014) and its variation with GRUs and attention (Bahdanau et al., 2014). 15 | 16 | The second category includes Memory-Augmented Neural Networks (Samir, 2017). It includes some well-acclaimed models, such as the NTM (Neural Turing Machine) (Graves et al., 2014) or the DNC (Differentiable Neural Computer) (Graves et al., 2016), 17 | alongside several new architectures, e.g. ThalNet (Hafner et al., 2017). 18 | 19 | The third category includes models designed for the VQA (Visual Question Answering) (Antol et al., 2015) problem domain. Here, MI-Prometheus currently offers some baseline models (e.g. a CNN coupled with an LSTM) and several state-of-the-art models such as SANs (Stacked Attention Networks) (Yang et al., 2016), RNs (Relational Networks) (Santoro et al., 2017) or MAC (Memory Attention Composition) (Hudson & Manning, 2018).
20 | 21 | Each model defines several methods and attributes: 22 | 23 | - A forward pass, resulting in predictions, 24 | - A plot function, creating a dynamic visualization of parameters of the model during execution, 25 | - A definition of the inputs that it expects/accepts, 26 | - Getter & setter methods for the collected statistics, 27 | - Getter & setter methods for the statistics which are `aggregated` based on the collected statistics: this makes it possible to compute, e.g., the mean, min, max, std of the loss or accuracy over an epoch or the whole validation set. 28 | 29 | 30 | 31 | .. figure:: ../img/model_class_diagram.png 32 | :scale: 50 % 33 | :alt: Class diagram of the models. 34 | :align: center 35 | 36 | The class inheritance of the models. 37 | 38 | The base `Model` class defines useful features for all models, such as an instance of the Application State, the Parameter Registry or the logger. 39 | This base class inherits from `torch.nn.Module` to access the `forward` method, and to have its weights be differentiable. 40 | 41 | We also defined a `SequentialModel` class which can be subclassed by recurrent or sequential models, such as MANNs. -------------------------------------------------------------------------------- /docs/source/notes/1_installation.rst: -------------------------------------------------------------------------------- 1 | .. _installation: 2 | 3 | Installation 4 | =================== 5 | `@author: Vincent Marois & Tomasz Kornuta` 6 | 7 | 8 | To use Mi-Prometheus, we need to first install PyTorch_. Refer to the official installation guide_ of PyTorch for its installation. 9 | It is easily installable via conda_, or you can compile it from source to optimize it for your machine. 10 | 11 | Mi-Prometheus is not (yet) available as a pip_ package, or on conda_. 12 | However, we provide the `setup.py` script and recommend using it for the installation of Mi-Prometheus.
13 | First, please clone the MI-Prometheus repository:: 14 | 15 | git clone git@github.com:IBM/mi-prometheus.git 16 | cd mi-prometheus/ 17 | 18 | Then, install the dependencies by running:: 19 | 20 | python setup.py install 21 | 22 | This command will install all dependencies of Mi-Prometheus via pip_. 23 | If you plan to develop and introduce changes, please call the following command instead:: 24 | 25 | python setup.py develop 26 | 27 | This will enable you to change the code of the existing problems/models/workers and still be able to run them by calling the associated ``mip-*`` commands. 28 | More on that subject can be found in the setuptools documentation on dev_mode_. 29 | 30 | .. _guide: https://github.com/pytorch/pytorch#installation 31 | .. _PyTorch: https://github.com/pytorch/pytorch 32 | .. _conda: https://anaconda.org/pytorch/pytorch 33 | .. _pip: https://pip.pypa.io/en/stable/quickstart/ 34 | .. _dev_mode: https://setuptools.readthedocs.io/en/latest/setuptools.html#development-mode 35 | 36 | 37 | **Please note that it will not install PyTorch**, as we have observed inconsistent and erratic errors when we were installing it from pip_. 38 | Use conda_, or compile it from source instead. The way recommended by the PyTorch team is conda_.
39 | 40 | The `setup.py` will register Mi-Prometheus as a package in your `Python` environment so that you will be able to `import` it: 41 | 42 | >>> import miprometheus as mip 43 | 44 | And then you will be able to access the API, for instance: 45 | 46 | >>> datadict = mip.utils.DataDict() 47 | 48 | Additionally, `setup.py` creates aliases for the workers, so that you can use them as regular commands:: 49 | 50 | mip-offline-trainer --c path/to/your/config/file 51 | 52 | The currently available commands are: 53 | 54 | - mip-offline-trainer 55 | - mip-online-trainer 56 | - mip-tester 57 | - mip-grid-trainer-cpu 58 | - mip-grid-trainer-gpu 59 | - mip-grid-tester-cpu 60 | - mip-grid-tester-gpu 61 | - mip-grid-analyzer 62 | - mip-index-splitter 63 | 64 | Use `--h` to see the available flags for each command. 65 | 66 | You can then delete the cloned repository and use these commands to run a particular worker with your configuration files. 67 | 68 | **Please note** that we provide multiple configuration files in `configs/` (which we use daily for our research & development). 69 | Feel free to copy this folder somewhere and use these files as you would like. We are investigating how we could make these files usable in an easier way. 70 | 71 | -------------------------------------------------------------------------------- /docs/source/mip_primer/2_problems_explained.rst: -------------------------------------------------------------------------------- 1 | Problems Explained 2 | =================== 3 | `@author: Tomasz Kornuta & Vincent Marois` 4 | 5 | 6 | A Problem class generalizes the idea of a dataset (a limited set of samples, already generated) with a data generator (a potentially unlimited set of samples, generated `on-the-fly`). 7 | Its main role is twofold: generate batches of samples (e.g.
loading them from disk or generating them `on-the-fly`) and compute metrics such as the loss and accuracy, which are typically dependent on the type of problem at hand. 8 | 9 | In MI-Prometheus, we created a hierarchy of classes that follows a taxonomy driven by the input-output combinations. 10 | For instance, MNIST is derived from the abstract `ImageToClassProblem` class, which itself inherits from the main `Problem` class. 11 | TranslationAnki (a toy machine translation problem) is derived from the abstract `TextToTextProblem` class, which inherits from 12 | the `SequenceToSequenceProblem` class, which is finally derived from the main Problem class. 13 | 14 | Such a hierarchy speeds up the integration of new problem classes, as similar problems share common characteristics (same loss function, same metrics, same definition of the data structures, etc.) 15 | and limits the amount of code to write. 16 | 17 | It is also straightforward to affect all problem classes (e.g. by adding a new feature) by implementing the change in the base Problem class. 18 | 19 | It is worth noting that the abstract base Problem class is derived from the PyTorch Dataset_ class. 20 | This makes it possible to use concurrent data loading workers (through the use of PyTorch’s DataLoader_). 21 | 22 | .. _Dataset: https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset 23 | .. _DataLoader: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader 24 | 25 | This feature works with all MI-Prometheus problem classes by default. 26 | Hence, every problem class defines (through inheritance or directly) several methods & attributes (among others): 27 | 28 | - The necessary functions `__getitem__`, `len` and `collate`, used by PyTorch’s DataLoader, 29 | - The loss function used and a method `evaluate_loss`, 30 | - A definition of the inputs that it expects/accepts, 31 | - Getter & setter functions for the collected statistics,
32 | - Getter & setter methods for the statistics which are `aggregated` based on the collected statistics. 33 | 34 | .. figure:: ../img/problem_class_diagram_wide.png 35 | :scale: 50 % 36 | :alt: Class diagram of the problems. 37 | :align: center 38 | 39 | The class inheritance of the problems. For clarity, we present only the most important fields and methods of the base Problem class. 40 | 41 | The figure above presents a class diagram with a subset of problems that are currently present in the framework, starting from classical image classification (such as MNIST and CIFAR10), machine translation (TranslationAnki) 42 | or Visual Question Answering (e.g. CLEVR (Johnson et al., 2017) and SortOf-CLEVR (Santoro et al., 2017)). 43 | It also contains several algorithmic problems such as SerialRecall and ReverseRecall used in (Graves et al., 2014). 44 | -------------------------------------------------------------------------------- /miprometheus/models/mac/utils_mac.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | # MIT License 5 | # 6 | # Copyright (c) 2018 Kim Seonghyeon 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | # 15 | # The above copyright notice and this permission notice shall be included in all 16 | # copies or substantial portions of the Software. 
17 | # 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | # SOFTWARE. 25 | # 26 | # ------------------------------------------------------------------------------ 27 | # 28 | # Copyright (C) IBM Corporation 2018 29 | # 30 | # Licensed under the Apache License, Version 2.0 (the "License"); 31 | # you may not use this file except in compliance with the License. 32 | # You may obtain a copy of the License at 33 | # 34 | # http://www.apache.org/licenses/LICENSE-2.0 35 | # 36 | # Unless required by applicable law or agreed to in writing, software 37 | # distributed under the License is distributed on an "AS IS" BASIS, 38 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 39 | # See the License for the specific language governing permissions and 40 | # limitations under the License. 41 | 42 | """ 43 | utils_mac.py: Implementation of utils methods for the MAC network. Cf https://arxiv.org/abs/1803.03067 for the \ 44 | reference paper. 45 | """ 46 | __author__ = "Vincent Marois" 47 | 48 | from torch import nn 49 | 50 | 51 | def linear(input_dim, output_dim, bias=True): 52 | """ 53 | Defines a Linear layer. 
Specifies Xavier as the initialization type of the weights, to respect the original \ 54 | implementation: https://github.com/stanfordnlp/mac-network/blob/master/ops.py#L20 55 | 56 | :param input_dim: input dimension 57 | :type input_dim: int 58 | 59 | :param output_dim: output dimension 60 | :type output_dim: int 61 | 62 | :param bias: If set to True, the layer will learn an additive bias initially set to zero \ 63 | (as original implementation https://github.com/stanfordnlp/mac-network/blob/master/ops.py#L40) 64 | :type bias: bool 65 | 66 | :return: Initialized Linear layer 67 | 68 | """ 69 | 70 | linear_layer = nn.Linear(input_dim, output_dim, bias=bias) 71 | nn.init.xavier_uniform_(linear_layer.weight) 72 | if bias: 73 | linear_layer.bias.data.zero_() 74 | 75 | return linear_layer 76 | -------------------------------------------------------------------------------- /docs/source/api_reference/1_utils.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | .. currentmodule:: miprometheus.utils 5 | 6 | miprometheus.utils 7 | =================================== 8 | .. automodule:: miprometheus.utils 9 | :members: 10 | :special-members: 11 | :exclude-members: __dict__,__weakref__ 12 | 13 | AppState 14 | ---------- 15 | .. autoclass:: AppState 16 | :members: 17 | :special-members: 18 | :exclude-members: __dict__,__weakref__ 19 | 20 | DataDict 21 | ---------- 22 | .. autoclass:: DataDict 23 | :members: 24 | :special-members: 25 | :exclude-members: __dict__,__weakref__ 26 | 27 | ParamInterface 28 | ----------------- 29 | .. autoclass:: ParamInterface 30 | :members: 31 | :special-members: 32 | :exclude-members: __dict__,__weakref__ 33 | 34 | ParamRegistry 35 | ----------------- 36 | .. autoclass:: MetaSingletonABC 37 | :members: 38 | :special-members: 39 | :exclude-members: __dict__,__weakref__ 40 | 41 | ..
autoclass:: SingletonMetaClass 42 | :members: 43 | :special-members: 44 | :exclude-members: __dict__,__weakref__ 45 | 46 | .. autoclass:: ParamRegistry 47 | :members: 48 | :special-members: 49 | :exclude-members: __dict__,__weakref__ 50 | 51 | SamplerFactory 52 | ----------------------- 53 | .. autoclass:: SamplerFactory 54 | :members: 55 | :special-members: 56 | :exclude-members: __dict__,__weakref__ 57 | 58 | Split Indices 59 | ----------------------- 60 | .. automodule:: miprometheus.utils.split_indices 61 | :members: split_indices 62 | 63 | .. currentmodule:: miprometheus.utils 64 | 65 | StatisticsCollector 66 | ----------------------- 67 | .. autoclass:: StatisticsCollector 68 | :members: 69 | :special-members: 70 | :exclude-members: __dict__,__weakref__ 71 | 72 | StatisticsAggregator 73 | ----------------------- 74 | .. autoclass:: StatisticsAggregator 75 | :members: 76 | :special-members: 77 | :exclude-members: __dict__,__weakref__ 78 | 79 | TimePlot 80 | ---------- 81 | .. autoclass:: TimePlot 82 | :members: 83 | :special-members: 84 | :exclude-members: __dict__,__weakref__ 85 | 86 | Losses 87 | ---------- 88 | 89 | :hidden:`Masked BCEWithLogitsLoss` 90 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 91 | .. autoclass:: MaskedBCEWithLogitsLoss 92 | :members: 93 | :special-members: 94 | :exclude-members: __dict__,__weakref__ 95 | 96 | :hidden:`Masked CrossEntropyLoss` 97 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 98 | .. autoclass:: MaskedCrossEntropyLoss 99 | :members: 100 | :special-members: 101 | :exclude-members: __dict__,__weakref__ 102 | 103 | Problems Utils 104 | -------------------- 105 | 106 | :hidden:`GenerateFeatureMaps` 107 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 108 | .. autoclass:: GenerateFeatureMaps 109 | :members: 110 | :special-members: 111 | :exclude-members: __dict__,__weakref__ 112 | 113 | :hidden:`Language` 114 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 115 | .. 
autoclass:: Language 116 | :members: 117 | :special-members: 118 | :exclude-members: __dict__,__weakref__ 119 | -------------------------------------------------------------------------------- /miprometheus/problems/seq_to_seq/seq_to_seq_problem.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | # Copyright (C) IBM Corporation 2018 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | """ 19 | seq_to_seq_problem.py: contains base class for all sequence to sequence problems. 20 | 21 | """ 22 | 23 | __author__ = "Tomasz Kornuta & Vincent Marois" 24 | 25 | from miprometheus.problems.problem import Problem 26 | import torch 27 | 28 | 29 | class SeqToSeqProblem(Problem): 30 | """ 31 | Class representing base class for all sequential problems. 32 | """ 33 | 34 | def __init__(self, params): 35 | """ 36 | Initializes problem object. Calls base constructor. 37 | 38 | :param params: Dictionary of parameters (read from configuration ``.yaml`` file). 39 | 40 | """ 41 | super(SeqToSeqProblem, self).__init__(params) 42 | 43 | # "Default" problem name. 44 | self.name = 'SeqToSeqProblem' 45 | 46 | # Set default data_definitions dict for all Seq2Seq problems. 
47 | self.data_definitions = {'sequences': {'size': [-1, -1, -1], 'type': [torch.Tensor]}, 48 | 'targets': {'size': [-1, -1, -1], 'type': [torch.Tensor]}, 49 | 'masks': {'size': [-1, -1, 1], 'type': [torch.Tensor]}, 50 | 'sequences_length': {'size': [-1, 1], 'type': [torch.Tensor]} 51 | } 52 | 53 | 54 | # Check if predictions/targets should be masked (DEFAULT: True). 55 | params.add_default_params({'use_mask': True}) 56 | self.use_mask = params["use_mask"] 57 | 58 | def evaluate_loss(self, data_dict, logits): 59 | """ Calculates the loss by applying ``self.loss_function`` to the logits and the targets. 60 | WARNING: If ``use_mask`` is set, the masks are passed to the loss function along with the logits and targets! 61 | 62 | :param data_dict: DataDict({'sequences', 'sequences_length', 'targets', 'masks'}). 63 | 64 | :param logits: Predictions being output of the model. 65 | 66 | """ 67 | # Check if mask should be used - if so, use the correct loss 68 | # function. 69 | if self.use_mask: 70 | loss = self.loss_function( 71 | logits, data_dict['targets'], data_dict['masks']) 72 | else: 73 | loss = self.loss_function(logits, data_dict['targets']) 74 | 75 | return loss 76 | 77 | 78 | if __name__ == '__main__': 79 | 80 | from miprometheus.utils.param_interface import ParamInterface 81 | 82 | sample = SeqToSeqProblem(ParamInterface())[0] 83 | # equivalent to SeqToSeqProblem(ParamInterface()).__getitem__(0) 84 | 85 | print(repr(sample)) 86 | -------------------------------------------------------------------------------- /miprometheus/models/dnc/plot_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | # Copyright (C) IBM Corporation 2018 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at 9 | # 10 | #      http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | """plot_data.py: Out of Date function for plotting memory and attention in the DNC """ 19 | __author__ = " Ryan L. McAvoy" 20 | 21 | import matplotlib 22 | import numpy as np 23 | 24 | 25 | def plot_memory_attention(prediction, memory, wt_read, wt_write, usage, label): 26 | matplotlib.pyplot.clf() 27 | fig = matplotlib.pyplot.figure(1) 28 | ax1 = fig.add_subplot(221) 29 | ax2 = fig.add_subplot(333) 30 | ax3 = fig.add_subplot(212) 31 | ax4 = fig.add_subplot(222) 32 | 33 | ax1.xaxis.set_major_locator(matplotlib.ticker.MaxNLocator(integer=True)) 34 | ax1.set_title("Attention", fontname='Times New Roman', fontsize=15) 35 | ax1.plot(np.arange(wt_read.size()[-1]), 36 | wt_read[0, 37 | 0, 38 | :].detach().numpy(), 39 | 'go', 40 | label="read head") 41 | 42 | # ax1.xaxis.set_major_locator(matplotlib.ticker.MaxNLocator(integer=True)) 43 | #ax1.set_title("Attention Write", fontname='Times New Roman', fontsize=15) 44 | ax1.plot(np.arange(wt_write.size()[-1]), 45 | wt_write[0, 46 | 0, 47 | :].detach().numpy(), 48 | 'o', 49 | label="write head") 50 | 51 | ax1.plot( 52 | np.arange(usage.size()[-1]), 53 | usage[0, :].detach().numpy(), 54 | 'o', label="write head") 55 | 56 | ax3.xaxis.set_major_locator(matplotlib.ticker.MaxNLocator(integer=True)) 57 | ax3.set_ylabel("Word size", fontname='Times New Roman', fontsize=15) 58 | ax3.set_xlabel("Memory addresses", fontname='Times New Roman', fontsize=15) 59 | ax3.set_title("Task: xxx", fontname='Times New Roman', fontsize=15) 60 | ax3.imshow(memory[0, :, :].detach().numpy(), 
interpolation='nearest') 61 | 62 | ax4.xaxis.set_major_locator(matplotlib.ticker.MaxNLocator(integer=True)) 63 | ax4.set_title("Prediction", fontname='Times New Roman', fontsize=15) 64 | ax4.imshow(np.transpose((prediction[0, ...]).detach().numpy(), [1, 0])) 65 | 66 | matplotlib.pyplot.pause(0.2) 67 | 68 | 69 | def plot_memory(memory): 70 | matplotlib.pyplot.clf() 71 | fig = matplotlib.pyplot.figure(1) 72 | 73 | ax2 = fig.add_subplot(111) 74 | 75 | ax2.xaxis.set_major_locator(matplotlib.ticker.MaxNLocator(integer=True)) 76 | ax2.set_ylabel("Word size", fontname='Times New Roman', fontsize=15) 77 | ax2.set_xlabel("Memory addresses", fontname='Times New Roman', fontsize=15) 78 | 79 | ax2.imshow(memory[0, :, :], interpolation='nearest', cmap='Greys') 80 | matplotlib.pyplot.pause(0.1) 81 | # input("pause") 82 | -------------------------------------------------------------------------------- /miprometheus/models/VWM_model/utils_VWM.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | # MIT License 5 | # 6 | # Copyright (c) 2018 Kim Seonghyeon 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | # 15 | # The above copyright notice and this permission notice shall be included in all 16 | # copies or substantial portions of the Software. 
17 | # 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | # SOFTWARE. 25 | # 26 | # ------------------------------------------------------------------------------ 27 | # 28 | # Copyright (C) IBM Corporation 2018 29 | # 30 | # Licensed under the Apache License, Version 2.0 (the "License"); 31 | # you may not use this file except in compliance with the License. 32 | # You may obtain a copy of the License at 33 | # 34 | # http://www.apache.org/licenses/LICENSE-2.0 35 | # 36 | # Unless required by applicable law or agreed to in writing, software 37 | # distributed under the License is distributed on an "AS IS" BASIS, 38 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 39 | # See the License for the specific language governing permissions and 40 | # limitations under the License. 41 | 42 | """ 43 | utils_mac.py: Implementation of utils methods for the MAC network. Cf https://arxiv.org/abs/1803.03067 for the \ 44 | reference paper. 45 | """ 46 | __author__ = "Vincent Albouy, T.S. Jayram" 47 | 48 | from torch import nn 49 | 50 | def linear(input_dim, output_dim, bias=True): 51 | 52 | """ 53 | Defines a Linear layer. 
Specifies Xavier as the initialization type of the weights, to respect the original \ 54 | implementation: https://github.com/stanfordnlp/mac-network/blob/master/ops.py#L20 55 | 56 | :param input_dim: input dimension 57 | :type input_dim: int 58 | 59 | :param output_dim: output dimension 60 | :type output_dim: int 61 | 62 | :param bias: If set to True, the layer will learn an additive bias initially set to zero \ 63 | (as original implementation https://github.com/stanfordnlp/mac-network/blob/master/ops.py#L40) 64 | :type bias: bool 65 | 66 | :return: Initialized Linear layer 67 | :type: torch layer 68 | 69 | """ 70 | 71 | #define linear layer from torch.nn library 72 | linear_layer = nn.Linear(input_dim, output_dim, bias=bias) 73 | 74 | #initialize weights 75 | nn.init.xavier_uniform_(linear_layer.weight) 76 | 77 | #initialize biases to zero 78 | if bias: 79 | linear_layer.bias.data.zero_() 80 | 81 | return linear_layer 82 | -------------------------------------------------------------------------------- /miprometheus/models/controllers/gru_controller.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | # Copyright (C) IBM Corporation 2018 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | """gru_controller.py: pytorch module implementing wrapper for gru controller of NTM.""" 19 | __author__ = "Tomasz Kornuta/Ryan L.
McAvoy/Younes Bouhadjar" 20 | 21 | import torch 22 | import collections 23 | from torch.nn import Module 24 | 25 | from miprometheus.utils.app_state import AppState 26 | 27 | _GRUStateTuple = collections.namedtuple('GRUStateTuple', ('hidden_state')) 28 | 29 | 30 | class GRUStateTuple(_GRUStateTuple): 31 | """ 32 | Tuple used by GRU Cells for storing current/past state information. 33 | """ 34 | __slots__ = () 35 | 36 | 37 | class GRUController(Module): 38 | """ 39 | A wrapper class for a GRU cell-based controller. 40 | """ 41 | def __init__(self, params): 42 | """ 43 | Constructor. 44 | 45 | :param params: Dictionary of parameters. 46 | 47 | """ 48 | 49 | self.input_size = params["input_size"] 50 | self.ctrl_hidden_state_size = params["output_size"] 51 | #self.hidden_state_dim = params["hidden_state_dim"] 52 | self.num_layers = params["num_layers"] 53 | assert self.num_layers > 0, "Number of layers should be > 0" 54 | 55 | super(GRUController, self).__init__() 56 | 57 | self.gru = torch.nn.GRUCell(self.input_size, self.ctrl_hidden_state_size) 58 | 59 | def init_state(self, batch_size): 60 | """ 61 | Returns 'zero' (initial) state tuple. 62 | 63 | :param batch_size: Size of the batch in given iteraction/epoch. 64 | :returns: Initial state tuple - object of GRUStateTuple class. 65 | :returns: Initial state tuple - object of GRUStateTuple class. 66 | 67 | """ 68 | # Initialize GRU hidden state [BATCH_SIZE x CTRL_HIDDEN_SIZE]. 69 | dtype = AppState().dtype 70 | hidden_state = torch.zeros( 71 | (batch_size, 72 | self.ctrl_hidden_state_size), 73 | requires_grad=False).type(dtype) 74 | 75 | return GRUStateTuple(hidden_state) 76 | 77 | def forward(self, x, prev_state_tuple): 78 | """ 79 | Controller forward function. 
80 | 81 | :param x: a Tensor of input data of size [BATCH_SIZE x INPUT_SIZE] (generally the read data and input word concatenated) 82 | :param prev_state_tuple: Tuple holding the previous hidden state (read from its ``hidden_state`` field) 83 | :returns: outputs a Tensor of size [BATCH_SIZE x OUTPUT_SIZE] and a GRU state tuple. 84 | 85 | """ 86 | 87 | hidden_state_prev = prev_state_tuple.hidden_state 88 | hidden_state = self.gru(x, hidden_state_prev) 89 | 90 | return hidden_state, GRUStateTuple(hidden_state) 91 | -------------------------------------------------------------------------------- /miprometheus/models/vision/lenet5.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | # Copyright (C) IBM Corporation 2018 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | """ 19 | lenet5: a classical LeNet-5 model for MNIST digit classification. \ 20 | To be taken as an illustrative example. 21 | """ 22 | 23 | __author__ = "Tomasz Kornuta & Vincent Marois" 24 | 25 | 26 | import torch 27 | 28 | # Import useful MI-Prometheus classes. 29 | from miprometheus.models.model import Model 30 | 31 | 32 | class LeNet5(Model): 33 | """ 34 | A classical LeNet-5 model for MNIST digits classification. 35 | """ 36 | def __init__(self, params_, problem_default_values_): 37 | """ 38 | Initializes the ``LeNet5`` model, creates the required layers.
39 | 40 | :param params: Parameters read from configuration file. 41 | :type params: ``miprometheus.utils.ParamInterface`` 42 | 43 | :param problem_default_values_: dict of parameters values coming from the problem class. 44 | :type problem_default_values_: dict 45 | 46 | """ 47 | super(LeNet5, self).__init__(params_, problem_default_values_) 48 | self.name = 'LeNet5' 49 | 50 | # Create the LeNet-5 layers. 51 | self.conv1 = torch.nn.Conv2d(1, 6, kernel_size=(5, 5)) 52 | self.maxpool1 = torch.nn.MaxPool2d(kernel_size=(2, 2), stride=2) 53 | self.conv2 = torch.nn.Conv2d(6, 16, kernel_size=(5, 5)) 54 | self.maxpool2 = torch.nn.MaxPool2d(kernel_size=(2, 2), stride=2) 55 | self.conv3 = torch.nn.Conv2d(16, 120, kernel_size=(5, 5)) 56 | self.linear1 = torch.nn.Linear(120, 84) 57 | self.linear2 = torch.nn.Linear(84, 10) 58 | 59 | # Create Model data definitions - indicate what a given model needs. 60 | self.data_definitions = { 61 | 'images': {'size': [-1, 1, 32, 32], 'type': [torch.Tensor]}} 62 | 63 | def forward(self, data_dict): 64 | """ 65 | Main forward pass of the ``LeNet5`` model. 66 | 67 | :param data_dict: DataDict({'images',**}), where: 68 | 69 | - images: [batch_size, num_channels, width, height] 70 | 71 | :type data_dict: ``miprometheus.utils.DataDict`` 72 | 73 | :return: Predictions [batch_size, num_classes] 74 | 75 | """ 76 | 77 | # Unpack DataDict. 78 | img = data_dict['images'] 79 | 80 | # Pass inputs through layers. 81 | x = self.conv1(img) 82 | x = torch.nn.functional.relu(x) 83 | x = self.maxpool1(x) 84 | x = self.conv2(x) 85 | x = torch.nn.functional.relu(x) 86 | x = self.maxpool2(x) 87 | x = self.conv3(x) 88 | x = torch.nn.functional.relu(x) 89 | x = x.view(-1, 120) 90 | x = self.linear1(x) 91 | x = torch.nn.functional.relu(x) 92 | x = self.linear2(x) 93 | 94 | return x # return logits. 
95 | -------------------------------------------------------------------------------- /configs/mac/mac_smac_cogent_a_finetuning.yaml: -------------------------------------------------------------------------------- 1 | # Lists of tasks that will be executed in parallel. 2 | grid_tasks: 3 | - 4 | default_configs: configs/mac/mac_cogent.yaml # MAC on CoGenT 5 | 6 | overwrite: 7 | # Add the model parameters: 8 | model: 9 | load: 'path/to/the/trained/model/to/finetune' # the CLEVR trained mac model to finetune on CoGenT-A 10 | 11 | # indicate the following so that the CLEVR-trained model can properly handle CoGenT data. 12 | training: 13 | problem: 14 | questions: 15 | embedding_source: 'CLEVR' 16 | validation: 17 | problem: 18 | questions: 19 | embedding_source: 'CLEVR' 20 | testing: 21 | problem: 22 | questions: 23 | embedding_source: 'CLEVR' 24 | - 25 | default_configs: configs/mac/s_mac_cogent.yaml # S-MAC on CoGenT 26 | 27 | overwrite: 28 | # Add the model parameters: 29 | model: 30 | load: 'path/to/the/trained/model/to/finetune' # the CLEVR trained s-mac model to finetune on CoGenT-A 31 | 32 | # indicate the following so that the CLEVR-trained model can properly handle CoGenT data. 33 | training: 34 | problem: 35 | questions: 36 | embedding_source: 'CLEVR' 37 | validation: 38 | problem: 39 | questions: 40 | embedding_source: 'CLEVR' 41 | testing: 42 | problem: 43 | questions: 44 | embedding_source: 'CLEVR' 45 | 46 | # Parameters that will be overwritten for all tasks. 47 | grid_overwrite: 48 | 49 | training: 50 | #use_EMA: True # EMA: keep track of exponential moving averages of the models weights. 51 | 52 | problem: 53 | settings: 54 | set: 'valA' # finetune on CoGenT-A, on 30k samples indicated by the indices below. 55 | sampler: 56 | name: 'SubsetRandomSampler' 57 | indices: '~/data/CLEVR_CoGenT_v1.0/vigil_cogent_finetuning_valA_indices.txt' 58 | terminal_conditions: 59 | epoch_limit: 10 # finetune for 10 epochs.
60 | 61 | # fix the seeds 62 | seed_torch: 0 63 | seed_numpy: 0 64 | 65 | validation: 66 | problem: 67 | settings: 68 | set: 'trainA' # use CoGenT-A as validation set (-> 10% of the true training set) as we will test on CoGenT. 69 | sampler: 70 | name: 'SubsetRandomSampler' 71 | indices: '~/data/CLEVR_CoGenT_v1.0/vigil_cogent_val_set_indices.txt' 72 | 73 | testing: 74 | problem: 75 | settings: 76 | set: 'valB' 77 | sampler: 78 | name: 'SubsetRandomSampler' 79 | indices: '~/data/CLEVR_CoGenT_v1.0/vigil_cogent_test_valB_indices.txt' 80 | 81 | # test on the CoGenT-B validation set, but on the complementary samples (not used for finetuning). 82 | # test on condition A. We simply point to a file containing all the indices of valA. 83 | multi_tests: { 84 | set: ['valB', 'valA'], 85 | indices: ['~/data/CLEVR_CoGenT_v1.0/vigil_cogent_valB_full_indices.txt', 86 | '~/data/CLEVR_CoGenT_v1.0/vigil_cogent_test_valA_indices.txt'], 87 | max_test_episodes: [-1, -1] 88 | } 89 | 90 | 91 | grid_settings: 92 | # Set number of repetitions of each experiments. 93 | experiment_repetitions: 1 94 | # Set number of concurrent running experiments (will be limited by the actual number of available CPUs/GPUs). 95 | max_concurrent_runs: 7 96 | # Set trainer. 97 | trainer: mip-offline-trainer -------------------------------------------------------------------------------- /miprometheus/models/controllers/ffgru_controller.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | # Copyright (C) IBM Corporation 2018 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 
8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | """ffgru_controller.py: pytorch module implementing wrapper for FF-GRU controller used in ThalNet.""" 19 | __author__ = "Younes Bouhadjar" 20 | 21 | import torch 22 | import collections 23 | from torch.nn import Module 24 | 25 | from miprometheus.utils.app_state import AppState 26 | 27 | _FFGRUStateTuple = collections.namedtuple('FFGRUStateTuple', ('hidden_state')) 28 | 29 | 30 | class FFGRUStateTuple(_FFGRUStateTuple): 31 | """ 32 | Tuple used by gru Cells for storing current/past state information. 33 | """ 34 | __slots__ = () 35 | 36 | 37 | class FFGRUController(Module): 38 | """ 39 | A wrapper class for a feedforward controller with a GRU cell. 40 | """ 41 | 42 | def __init__(self, params): 43 | """ 44 | Constructor. 45 | 46 | :param params: Dictionary of parameters. 47 | 48 | """ 49 | 50 | self.input_size = params["input_size"] 51 | self.ctrl_hidden_state_size = params["output_size"] 52 | #self.hidden_state_dim = params["hidden_state_dim"] 53 | self.num_layers = params["num_layers"] 54 | self.ff_output_size = params["ff_output_size"] 55 | assert self.num_layers > 0, "Number of layers should be > 0" 56 | 57 | super(FFGRUController, self).__init__() 58 | 59 | self.ff = torch.nn.Linear(self.input_size, self.ff_output_size) 60 | self.gru = torch.nn.GRUCell(self.ff_output_size, self.ctrl_hidden_state_size) 61 | 62 | def init_state(self, batch_size): 63 | """ 64 | Returns 'zero' (initial) state tuple. 65 | 66 | :param batch_size: Size of the batch in given iteraction/epoch. 
67 | :returns: Initial state tuple - object of FFGRUStateTuple class. 68 | 69 | """ 70 | # Initialize GRU hidden state [BATCH_SIZE x CTRL_HIDDEN_SIZE]. 71 | dtype = AppState().dtype 72 | hidden_state = torch.zeros( 73 | (batch_size, 74 | self.ctrl_hidden_state_size), 75 | requires_grad=False).type(dtype) 76 | 77 | return FFGRUStateTuple(hidden_state) 78 | 79 | def forward(self, x, prev_state_tuple): 80 | """ 81 | Controller forward function. 82 | 83 | :param x: a Tensor of input data of size [BATCH_SIZE x INPUT_SIZE] (generally the read data and input word concatenated) 84 | :param prev_state_tuple: Tuple holding the previous hidden state (read from its ``hidden_state`` field) 85 | :returns: outputs a Tensor of size [BATCH_SIZE x OUTPUT_SIZE] and an FFGRUStateTuple. 86 | 87 | """ 88 | 89 | hidden_state_prev = prev_state_tuple.hidden_state 90 | 91 | input = self.ff(x) 92 | hidden_state = self.gru(input, hidden_state_prev) 93 | 94 | return hidden_state, FFGRUStateTuple(hidden_state) 95 | -------------------------------------------------------------------------------- /miprometheus/models/controllers/lstm_controller.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | # Copyright (C) IBM Corporation 2018 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License.
17 | 18 | """lstm_controller.py: pytorch module implementing wrapper for lstm controller of NTM.""" 19 | __author__ = "Tomasz Kornuta/Ryan L. McAvoy" 20 | 21 | import torch 22 | import collections 23 | from torch.nn import Module 24 | 25 | from miprometheus.utils.app_state import AppState 26 | 27 | _LSTMStateTuple = collections.namedtuple( 28 | 'LSTMStateTuple', ('hidden_state', 'cell_state')) 29 | 30 | 31 | class LSTMStateTuple(_LSTMStateTuple): 32 | """ 33 | Tuple used by LSTM Cells for storing current/past state information. 34 | """ 35 | __slots__ = () 36 | 37 | 38 | class LSTMController(Module): 39 | """ 40 | A wrapper class for a LSTM-based controller. 41 | """ 42 | def __init__(self, params): 43 | """ 44 | Constructor. 45 | 46 | :param params: Dictionary of parameters. 47 | 48 | """ 49 | 50 | self.input_size = params["input_size"] 51 | self.ctrl_hidden_state_size = params["output_size"] 52 | #self.hidden_state_dim = params["hidden_state_dim"] 53 | self.num_layers = params["num_layers"] 54 | assert self.num_layers > 0, "Number of layers should be > 0" 55 | 56 | super(LSTMController, self).__init__() 57 | 58 | self.lstm = torch.nn.LSTMCell(self.input_size, self.ctrl_hidden_state_size) 59 | 60 | def init_state(self, batch_size): 61 | """ 62 | Returns 'zero' (initial) state tuple. 63 | 64 | :param batch_size: Size of the batch in given iteraction/epoch. 65 | :returns: Initial state tuple - object of LSTMStateTuple class. 66 | 67 | """ 68 | dtype = AppState().dtype 69 | # Initialize LSTM hidden state [BATCH_SIZE x CTRL_HIDDEN_SIZE]. 70 | hidden_state = torch.zeros( 71 | (batch_size, 72 | self.ctrl_hidden_state_size), 73 | requires_grad=False).type(dtype) 74 | # Initialize LSTM memory cell [BATCH_SIZE x CTRL_HIDDEN_SIZE]. 
75 | cell_state = torch.zeros( 76 | (batch_size, 77 | self.ctrl_hidden_state_size), 78 | requires_grad=False).type(dtype) 79 | 80 | return LSTMStateTuple(hidden_state, cell_state) 81 | 82 | def forward(self, x, prev_state_tuple): 83 | """ 84 | Controller forward function. 85 | 86 | :param x: a Tensor of input data of size [BATCH_SIZE x INPUT_SIZE] (generally the read data and input word concatenated) 87 | :param prev_state_tuple: Tuple of the previous hidden and cell state 88 | :returns: outputs a Tensor of size [BATCH_SIZE x OUTPUT_SIZE] and an LSTM state tuple. 89 | 90 | """ 91 | 92 | hidden_state, cell_state = self.lstm(x, prev_state_tuple) 93 | 94 | return hidden_state, LSTMStateTuple(hidden_state, cell_state) 95 | -------------------------------------------------------------------------------- /miprometheus/models/vqa_baselines/stacked_attention_networks/image_encoding.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | # Copyright (C) IBM Corporation 2018 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | """ 19 | image_encoding.py: Contains a class using a pretrained CNN from ``torchvision`` as the image encoding \ 20 | for the Stacked Attention Network. 
class PretrainedImageEncoding(Module):
    """
    Wrapper class over a ``torchvision.model`` to produce feature maps for the SAN model.

    """

    def __init__(self, cnn_model='resnet18', num_layers=2):
        """
        Constructor of the ``PretrainedImageEncoding`` class.

        :param cnn_model: select which pretrained model to load (name of a constructor in ``torchvision.models``).
        :type cnn_model: str

        .. warning::

            This class has only been tested with the ``resnet18`` model.

        :param num_layers: Number of ``layerN`` stages (``layer1`` .. ``layer<num_layers>``) to select \
        from the ``cnn_model``.
        :type num_layers: int

        """
        # call base constructor
        super(PretrainedImageEncoding, self).__init__()

        # Get pretrained cnn model
        self.cnn_model = cnn_model
        cnn = getattr(torchvision.models, self.cnn_model)(pretrained=True)

        # Stem of the network (first conv expects 3 input channels), followed by the selected stages.
        layers = [cnn.conv1, cnn.bn1, cnn.relu, cnn.maxpool]
        layers.extend(getattr(cnn, 'layer%d' % i) for i in range(1, num_layers + 1))

        self.model = torch.nn.Sequential(*layers)

    def get_output_nb_filters(self):
        """
        :return: The number of filters of the last conv layer, or ``None`` (after printing a message) \
        when it cannot be determined from the structure of the model.
        """
        try:
            last_block = self.model[-1][-1]
            # Bottleneck-style blocks end with ``bn3``; basic blocks (e.g. resnet18) end with ``bn2``.
            last_norm = getattr(last_block, 'bn3', None)
            if last_norm is None:
                last_norm = last_block.bn2
            return last_norm.num_features
        except (AttributeError, IndexError, TypeError):
            # Best effort: unexpected layouts (e.g. num_layers=0) are reported, not raised.
            print('Could not get the number of output channels of the model {}'.format(self.cnn_model))
            return None

    def forward(self, img):
        """
        Forward pass of a pretrained cnn model.

        :param img: input image [batch_size, num_channels, height, width]
        :type img: torch.tensor

        :return x: feature maps, [batch_size, output_channels, new_height, new_width]

        """
        # Apply model image encoding
        return self.model(img)


if __name__ == '__main__':

    img_encoding = PretrainedImageEncoding()
    print(img_encoding.get_output_nb_filters())
class ThoughtUnit(Module):
    """
    ``ThoughtUnit`` of the VWM network: merges a summary vector and a context vector
    through one linear layer.
    """

    def __init__(self, dim):
        """
        Builds the unit.

        :param dim: global 'd' hidden dimension
        :type dim: int

        """
        super(ThoughtUnit, self).__init__()

        # Maps the concatenated pair (size 2 * dim) back to size dim.
        # NOTE(review): ``linear`` comes from utils_VWM; presumably a torch.nn.Linear factory - confirm.
        self.concat_layer = linear(2 * dim, dim, bias=True)

    def forward(self, summary_output, context_output):
        """
        Forward pass of the ``ThoughtUnit``.

        :param summary_output: previous memory states, each of shape [batch_size x dim].
        :type summary_output: torch.tensor

        :param context_output: current read vector (output of the visual retrieval unit), shape [batch_size x dim].
        :type context_output: torch.tensor

        :return: next_context_output: shape [batch_size x dim]
        :type: next_context_output: torch.tensor

        """
        # Context first, summary second, joined along dimension 1.
        stacked = torch.cat((context_output, summary_output), dim=1)
        return self.concat_layer(stacked)
class ProblemFactory(object):
    """
    ProblemFactory: Class instantiating the specified problem class using the passed params.
    """

    @staticmethod
    def build(params):
        """
        Static method returning a particular problem, depending on the name \
        provided in the list of parameters.

        :param params: Parameters used to instantiate the Problem class.
        :type params: :py:class:`miprometheus.utils.ParamInterface`

        .. note::

            ``params`` should contain the exact (case-sensitive) class name of the Problem to instantiate. \
            Only the last path component of ``params['name']`` is used.

        :return: Instance of a given problem.

        :raises SystemExit: (exit code -1, after logging an error) when the ``name`` key is missing, \
        when no attribute of that name exists in the ``problems`` package, or when it is not a class \
        derived from ``Problem``.

        """
        logging.basicConfig(level=logging.INFO)
        logger = logging.getLogger('ProblemFactory')

        # Check presence of the name
        if 'name' not in params:
            logger.error("Problem configuration section does not contain the key 'name'")
            raise SystemExit(-1)

        # Get the class name.
        name = os.path.basename(params['name'])

        # Verify that the specified class is in the problems package
        # (otherwise the getattr below would raise a bare AttributeError).
        if not hasattr(problems, name):
            logger.error("Could not find the specified class '{}' in the problems package.".format(name))
            raise SystemExit(-1)

        # Get the actual class.
        problem_class = getattr(problems, name)

        # Check if class is derived (even indirectly) from Problem.
        # The comparison is made on class names, as in the original implementation.
        inherits = inspect.isclass(problem_class) and any(
            c.__name__ == problems.Problem.__name__ for c in inspect.getmro(problem_class))
        if not inherits:
            logger.error("The specified class '{}' is not derived from the Problem class".format(name))
            raise SystemExit(-1)

        # Ok, proceed.
        logger.info('Loading the {} problem from {}'.format(name, problem_class.__module__))
        # return the instantiated problem class
        return problem_class(params)


if __name__ == "__main__":
    """
    Tests Problem Factory.
    """
    from miprometheus.utils.param_interface import ParamInterface
    params = ParamInterface()
    params.add_default_params({
        'name': 'SerialRecallCommandLines',
        'control_bits': 3,
        'data_bits': 8,
        'min_sequence_length': 1,
        'max_sequence_length': 5
    })

    problem = ProblemFactory.build(params)
    print(type(problem))
class Memory:
    """
    Implementation of the memory of the DWM.

    Holds one memory tensor and delegates the numerical work to the ``sim`` and
    ``outer_prod`` helpers imported from ``tensor_utils``.
    """
    def __init__(self, mem_t):
        """
        Stores the given tensor as the current memory.

        :param mem_t: the memory at time t [batch_size, memory_content_size, memory_addresses_size]

        """
        self._memory = mem_t

    def attention_read(self, wt):
        """
        Reads from memory using the given head weights.

        :param wt: head's weights [batch_size, num_heads, memory_addresses_size]
        :return: the read data [batch_size, num_heads, memory_content_size]

        """
        read_data = sim(wt, self._memory)
        return read_data

    def add_weighted(self, add, wt):
        """
        Writes data to memory. The stored memory is replaced by the updated tensor; \
        nothing is returned.

        :param wt: head's weights [batch_size, num_heads, memory_addresses_size]
        :param add: the data to be added to memory [batch_size, num_heads, memory_content_size]

        """
        # memory = memory + sum_{head h} weighted add(h)
        per_head_add = outer_prod(add, wt)
        self._memory = self._memory + torch.sum(per_head_add, dim=-3)

    def erase_weighted(self, erase, wt):
        """
        Erases elements from memory. The stored memory is replaced by the updated tensor; \
        nothing is returned.

        :param wt: head's weights [batch_size, num_heads, memory_addresses_size]
        :param erase: data to be erased from memory [batch_size, num_heads, memory_content_size]

        """
        # memory = memory * product_{head h} (1 - weighted erase(h))
        per_head_keep = 1 - outer_prod(erase, wt)
        self._memory = self._memory * torch.prod(per_head_keep, dim=-3)

    def content_similarity(self, k):
        """
        Compares the keys with the memory for content-aware addressing (``sim`` is called \
        with ``l2_normalize=True`` and ``aligned=False``).

        :param k: the keys emitted by the controller [batch_size, num_heads, memory_content_size]
        :return: the similarity between the keys and the memory [batch_size, num_heads, memory_addresses_size]

        """
        similarity = sim(k, self._memory, l2_normalize=True, aligned=False)
        return similarity

    @property
    def size(self):
        """
        Returns the size of the memory.

        :return: result of ``size()`` called on the memory tensor

        """
        return self._memory.size()

    @property
    def content(self):
        """
        Returns the entire memory.

        :return: the stored memory tensor itself (not a copy)

        """
        return self._memory
class Memory(object):
    """
    Memory used by the DNC model (copy of the DWM equivalent).

    Wraps a single memory tensor; reading, writing and addressing rely on the
    ``sim`` and ``outer_prod`` helpers imported from ``tensor_utils``.
    """
    def __init__(self, mem_t):
        """
        Initializes the memory.

        :param mem_t of shape (batch_size, memory_content_size, memory_addresses_size): the memory at time t

        """
        self._memory = mem_t

    def attention_read(self, wt):
        """
        Returns the data read from memory.

        :param wt of shape (batch_size, num_heads, memory_addresses_size) : head's weights
        :return: the read data of shape (batch_size, num_heads, memory_content_size)

        """
        data = sim(wt, self._memory)
        return data

    def add_weighted(self, add, wt):
        """
        Writes data to memory. Rebinds the stored memory to the updated tensor and returns nothing.

        :param wt of shape (batch_size, num_heads, memory_addresses_size) : head's weights
        :param add of shape (batch_size, num_heads, memory_content_size) : the data to be added to memory

        """
        # memory = memory + sum_{head h} weighted add(h)
        contribution = torch.sum(outer_prod(add, wt), dim=-3)
        self._memory = self._memory + contribution

    def erase_weighted(self, erase, wt):
        """
        Erases elements from memory. Rebinds the stored memory to the updated tensor and returns nothing.

        :param wt of shape (batch_size, num_heads, memory_addresses_size) : head's weights
        :param erase of shape (batch_size, num_heads, memory_content_size) : data to be erased from memory

        """
        # memory = memory * product_{head h} (1 - weighted erase(h))
        retention = torch.prod(1 - outer_prod(erase, wt), dim=-3)
        self._memory = self._memory * retention

    def content_similarity(self, k):
        """
        Calculates the similarity used for content-aware addressing (``sim`` is called \
        with ``l2_normalize=True`` and ``aligned=False``).

        :param k of shape (batch_size, num_heads, memory_content_size): the keys emitted by the controller
        :return: the similarity between the keys and the memory of shape (batch_size, num_heads, memory_addresses_size)

        """
        return sim(k, self._memory, l2_normalize=True, aligned=False)

    @property
    def size(self):
        """
        Returns the size of the memory.

        :return: result of ``size()`` called on the memory tensor

        """
        dims = self._memory.size()
        return dims

    @property
    def content(self):
        """
        Returns the entire memory.

        :return: the stored memory tensor itself (not a copy)

        """
        return self._memory
# Trailing comma matters: ('hidden_state') is just a string, ('hidden_state',) is a 1-tuple.
_RNNStateTuple = collections.namedtuple('RNNStateTuple', ('hidden_state',))


class RNNStateTuple(_RNNStateTuple):
    """
    Named tuple with the single field ``hidden_state``, storing the current/past state of the RNN controller.
    """
    __slots__ = ()


class RNNController(Module):
    """
    A wrapper class for a simple recurrent controller: one linear layer applied to the \
    concatenation of the input and the previous hidden state, followed by an optional non-linearity.
    """
    def __init__(self, params):
        """
        Constructor for a RNN.

        :param params: Dictionary of parameters. The keys read here are ``input_size``, \
        ``output_size`` (used as the size of the hidden state), ``non_linearity`` and ``num_layers``.

        .. note::

            ``num_layers`` is only validated (must be > 0); a single linear layer is created \
            whatever its value.

        """
        # The base constructor has to run before any attribute is assigned on the module
        # (requirement stated in the torch.nn.Module documentation).
        super(RNNController, self).__init__()

        self.input_size = params["input_size"]
        self.ctrl_hidden_state_size = params["output_size"]
        self.non_linearity = params["non_linearity"]
        self.num_layers = params["num_layers"]
        assert self.num_layers > 0, "Number of layers should be > 0"

        # The layer consumes [input, previous hidden state] concatenated along the last dimension.
        full_size = self.input_size + self.ctrl_hidden_state_size
        self.rnn = torch.nn.Linear(full_size, self.ctrl_hidden_state_size)

    def init_state(self, batch_size):
        """
        Returns 'zero' (initial) state tuple.

        :param batch_size: Size of the batch in given iteration/epoch.
        :returns: Initial state tuple - object of RNNStateTuple class, holding a zero tensor of size \
        [BATCH_SIZE x CTRL_HIDDEN_SIZE] cast to ``AppState().dtype``.

        """
        dtype = AppState().dtype
        # Initialize RNN hidden state [BATCH_SIZE x CTRL_HIDDEN_SIZE].
        hidden_state = torch.zeros(
            (batch_size, self.ctrl_hidden_state_size),
            requires_grad=False).type(dtype)

        return RNNStateTuple(hidden_state)

    def forward(self, inputs, prev_hidden_state_tuple):
        """
        Controller forward function.

        :param inputs: a Tensor of input data of size [BATCH_SIZE x INPUT_SIZE] (generally the read data and input word concatenated)
        :param prev_hidden_state_tuple: Tuple whose first element is the previous hidden state
        :returns: outputs a Tensor of size [BATCH_SIZE x OUTPUT_SIZE] and an RNN state tuple.

        """
        h = prev_hidden_state_tuple[0]
        combo = torch.cat((inputs, h), dim=-1)
        hidden_state = self.rnn(combo)

        # torch.sigmoid / torch.tanh replace the deprecated torch.nn.functional variants.
        # Any other value of ``non_linearity`` leaves the linear output unchanged.
        if self.non_linearity == "sigmoid":
            hidden_state = torch.sigmoid(hidden_state)
        elif self.non_linearity == "tanh":
            hidden_state = torch.tanh(hidden_state)
        elif self.non_linearity == "relu":
            hidden_state = torch.nn.functional.relu(hidden_state)

        return hidden_state, RNNStateTuple(hidden_state)
class ImageProcessing(Module):
    """
    Image encoding using a 2-layers CNN assuming the images have been already \
    preprocessed by `ResNet101`.
    """

    def __init__(self, dim, in_channels=1024):
        """
        Constructor for the 2-layers CNN.

        :param dim: global 'd' hidden dimension
        :type dim: int

        :param in_channels: number of channels of the incoming feature maps (default: 1024, \
        the value previously hard-coded).
        :type in_channels: int

        """

        # call base constructor
        super(ImageProcessing, self).__init__()

        # Two 3x3 convolutions with padding 1 (spatial size is preserved), each followed by an ELU.
        self.conv = torch.nn.Sequential(torch.nn.Conv2d(in_channels, dim, 3, padding=1),
                                        torch.nn.ELU(),
                                        torch.nn.Conv2d(dim, dim, 3, padding=1),
                                        torch.nn.ELU())

        # specify weights initialization: Kaiming-uniform weights and zero biases for both
        # convolutions (indices 0 and 2 of the sequential container).
        for index in (0, 2):
            torch.nn.init.kaiming_uniform_(self.conv[index].weight)
            self.conv[index].bias.data.zero_()

    def forward(self, feature_maps):
        """
        Apply the constructed CNN model on the feature maps (coming from \
        `ResNet101`).

        :param feature_maps: [batch_size x nb_kernels x feat_H x feat_W] coming from `ResNet101`. \
        ``nb_kernels`` must match ``in_channels``; with the defaults \
        [nb_kernels x feat_H x feat_W] = [1024 x 14 x 14].
        :type feature_maps: torch.tensor

        :return feature_maps: feature map, shape [batch_size, dim, new_height, new_width]

        """

        new_feature_maps = self.conv(feature_maps)

        return new_feature_maps
class ControllerFactory(object):
    """
    Class returning concrete controller depending on the name provided in the \
    list of parameters.
    """

    @staticmethod
    def build(params):
        """
        Static method returning particular controller, depending on the name \
        provided in the list of parameters.

        :param params: Parameters used to instantiate the controller.
        :type params: ``utils.param_interface.ParamInterface``

        .. note::

            ``params`` should contain the exact (case-sensitive) class name of the controller to instantiate.

        :return: Instance of a given controller.

        :raises SystemExit: (exit code -1, after logging an error) when the ``name`` key is missing, \
        when the name is not found in the controllers package, or when it is not a class derived \
        from ``nn.Module``.
        """
        logging.basicConfig(level=logging.INFO)
        logger = logging.getLogger('ControllerFactory')

        # Check presence of the name attribute.
        if 'name' not in params:
            logger.error("Controller configuration section does not contain the key 'name'")
            raise SystemExit(-1)

        # Get the class name.
        name = params['name']

        # Verify that the specified class is in the controller package.
        if name not in dir(miprometheus.models.controllers):
            logger.error("Could not find the specified class '{}' in the controllers package.".format(name))
            raise SystemExit(-1)

        # Get the actual class.
        controller_class = getattr(miprometheus.models.controllers, name)

        # Check if class is derived (even indirectly) from nn.Module.
        # The comparison is made on class names, as in the original implementation;
        # non-class attributes (e.g. sub-modules of the package) are rejected as well.
        inherits = inspect.isclass(controller_class) and any(
            c.__name__ == Module.__name__ for c in inspect.getmro(controller_class))
        if not inherits:
            logger.error("The specified class '{}' is not derived from the nn.Module class".format(name))
            raise SystemExit(-1)

        # Ok, proceed.
        logger.info('Loading the {} controller from {}'.format(name, controller_class.__module__))

        # return the instantiated controller class
        return controller_class(params)


if __name__ == "__main__":
    """
    Tests ControllerFactory.
    """
    from miprometheus.utils.param_interface import ParamInterface

    controller_params = ParamInterface()
    controller_params.add_default_params({'name': 'RNNController',
                                          'input_size': 11,
                                          'output_size': 11,
                                          'hidden_state_size': 20,
                                          'num_layers': 1,
                                          'non_linearity': 'sigmoid'})

    # The factory method is called ``build`` (there is no ``build_controller``).
    controller = ControllerFactory.build(controller_params)
    print(type(controller))
class OutputUnit(Module):
    """
    ``OutputUnit`` of the MAC network: a 2-layers MLP classifier applied to the final
    memory state concatenated with the question encodings.
    """

    def __init__(self, dim, nb_classes):
        """
        Builds the classifier.

        :param dim: global 'd' dimension.
        :type dim: int

        :param nb_classes: number of classes to consider (classification problem).
        :type nb_classes: int

        """
        super(OutputUnit, self).__init__()

        # 2-layers MLP: (3 * dim -> dim), ELU, (dim -> nb_classes).
        # NOTE(review): ``linear`` comes from utils_mac; presumably a torch.nn.Linear factory - confirm.
        hidden_layer = linear(dim * 3, dim, bias=True)
        output_layer = linear(dim, nb_classes, bias=True)
        self.classifier = torch.nn.Sequential(hidden_layer, torch.nn.ELU(), output_layer)

        # Only the weight of the first layer is re-initialized here.
        torch.nn.init.kaiming_uniform_(self.classifier[0].weight)

    def forward(self, mem_state, question_encodings):
        """
        Forward pass of the ``OutputUnit``.

        :param mem_state: final memory state, shape [batch_size x dim]
        :type mem_state: torch.tensor

        :param question_encodings: questions encodings, shape [batch_size x (2*dim)]
        :type question_encodings: torch.tensor

        :return: logits over the classes (no softmax is applied here), [batch_size x nb_classes]

        """
        # memory state first, question encodings second, joined along dimension 1
        classifier_input = torch.cat([mem_state, question_encodings], dim=1)
        return self.classifier(classifier_input)
19 | training: 20 | problem: 21 | questions: 22 | embedding_source: 'CLEVR' 23 | validation: 24 | problem: 25 | questions: 26 | embedding_source: 'CLEVR' 27 | testing: 28 | problem: 29 | questions: 30 | embedding_source: 'CLEVR' 31 | - 32 | default_configs: configs/mac/s_mac_cogent.yaml # S-MAC on CoGenT 33 | 34 | overwrite: 35 | # Add the model parameters: 36 | model: 37 | load: 'path/to/the/trained/model/to/finetune' # the CoGenT-A trained s-mac model to finetune on CoGenT-B 38 | - 39 | default_configs: configs/mac/s_mac_cogent.yaml # S-MAC on CoGenT 40 | 41 | overwrite: 42 | # Add the model parameters: 43 | model: 44 | load: 'path/to/the/trained/model/to/finetune' # the CLEVR trained s-mac model to finetune on CoGenT-B 45 | 46 | # indicate the following so that the CLEVR-trained model can properly handle CoGenT data. 47 | training: 48 | problem: 49 | questions: 50 | embedding_source: 'CLEVR' 51 | validation: 52 | problem: 53 | questions: 54 | embedding_source: 'CLEVR' 55 | testing: 56 | problem: 57 | questions: 58 | embedding_source: 'CLEVR' 59 | 60 | # Parameters that will be overwritten for all tasks. 61 | grid_overwrite: 62 | 63 | training: 64 | #use_EMA: True # EMA: keep track of exponential moving averages of the model's weights. 65 | 66 | problem: 67 | settings: 68 | set: 'valB' # finetune on CoGenT-B, on 30k samples indicated by the indices below. 69 | sampler: 70 | name: 'SubsetRandomSampler' 71 | indices: '~/data/CLEVR_CoGenT_v1.0/vigil_cogent_finetuning_valB_indices.txt' 72 | terminal_conditions: 73 | epoch_limit: 10 # finetune for 10 epochs. 74 | 75 | # fix the seeds 76 | seed_torch: 0 77 | seed_numpy: 0 78 | 79 | validation: 80 | problem: 81 | settings: 82 | set: 'trainA' # use CoGenT-A as validation set (-> 10% of the true training set) as we will test on CoGenT.
83 | sampler: 84 | name: 'SubsetRandomSampler' 85 | indices: '~/data/CLEVR_CoGenT_v1.0/vigil_cogent_val_set_indices.txt' 86 | 87 | testing: 88 | problem: 89 | settings: 90 | set: 'valB' 91 | sampler: 92 | name: 'SubsetRandomSampler' 93 | indices: '~/data/CLEVR_CoGenT_v1.0/vigil_cogent_test_valB_indices.txt' 94 | 95 | # test on the CoGenT-B validation set, but on the complementary samples (not used for finetuning). 96 | # also test on condition A: we simply point to a file containing all the indices of valA. 97 | multi_tests: { 98 | set: ['valB', 'valA'], 99 | indices: ['~/data/CLEVR_CoGenT_v1.0/vigil_cogent_test_valB_indices.txt', 100 | '~/data/CLEVR_CoGenT_v1.0/vigil_cogent_valA_full_indices.txt'], 101 | max_test_episodes: [-1, -1] 102 | } 103 | 104 | 105 | grid_settings: 106 | # Set number of repetitions of each experiment. 107 | experiment_repetitions: 1 108 | # Set number of concurrently running experiments (will be limited by the actual number of available CPUs/GPUs). 109 | max_concurrent_runs: 7 110 | # Set trainer. 111 | trainer: mip-offline-trainer --------------------------------------------------------------------------------