├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── __init__.py ├── configs ├── agents │ ├── all.yaml │ ├── all_no_fuse.yaml │ ├── igridson_agents.yaml │ ├── non_psg.yaml │ └── psg.yaml ├── config.yaml ├── data_gen │ ├── 10k.yaml │ └── 5k.yaml ├── experiment │ ├── add_remove_objs.yaml │ ├── agent_no_priors.yaml │ ├── agent_ok_priors.yaml │ ├── example.yaml │ ├── large_scenes.yaml │ ├── model_no_fuse.yaml │ ├── noisy_priors.yaml │ ├── save_images.yaml │ ├── simulate_agents.yaml │ └── small_scenes.yaml ├── model │ ├── gcn.yaml │ ├── gcn_no_fuse.yaml │ ├── han.yaml │ ├── heat.yaml │ ├── heat_no_fuse.yaml │ ├── hgcn.yaml │ ├── hgt.yaml │ ├── mlp.yaml │ └── recurrent_gcn.yaml ├── scene_gen │ ├── igridson_scenes.yaml │ ├── large_scenes.yaml │ └── small_scenes.yaml └── task │ ├── find_object.yaml │ ├── predict_env_dynamics.yaml │ ├── predict_location.yaml │ └── predict_locations.yaml ├── environment.yml ├── install.sh ├── memsearch ├── __init__.py ├── agents.py ├── dataset.py ├── experiment_configs.py ├── graphs.py ├── igridson_env.py ├── igridson_utils.py ├── metrics.py ├── models.py ├── rl │ └── complex_input_network.py ├── running.py ├── scene.py ├── tasks.py ├── training.py └── util.py ├── priors ├── coarse_prior_graph.pickle ├── detailed_prior_graph.pickle ├── hardcoded_placement_probs.yaml └── object_metadata.yaml ├── requirements.txt ├── rl_scripts ├── check_env.py ├── interactive_test.py ├── main.py ├── train.py ├── train_dqn.py ├── train_dqn_flat.py └── train_dqn_lstm_flat.py ├── scripts ├── collect_data.py ├── eval.py ├── gen_experiment_configs.py ├── gen_prior_graphs.py ├── igridson_video.py ├── make_gifs.py ├── plot_eval_results.ipynb ├── plot_eval_results.py ├── print_experiment_results.py ├── run_experiment.sh ├── submit_slurm_job.sh └── train.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "dependencies/gym-minigrid"] 2 | path = dependencies/gym-minigrid 3 | url = git@github.com:maximecb/gym-minigrid.git 4 | [submodule "dependencies/Griddly"] 5 | path = dependencies/Griddly 6 | url = git@github.com:Bam4d/Griddly.git 7 | [submodule "dependencies/mini_behavior"] 8 | path = dependencies/mini_behavior 9 | url = git@github.com:stanfordvl/mini_behavior.git 10 | branch = psg 11 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Modeling Dynamic Environments with Scene Graph Memory (ICML 2023) 2 | 3 | ## Abstract 4 | 5 | Embodied AI agents that search for objects in dynamic environments such as households often need to make efficient decisions by predicting object locations based on partial information. We pose this problem as a new type of link prediction problem: **link prediction on partially observable dynamic graphs**. Our graph is a representation of a scene in which rooms and objects are nodes, and their relationship (e.g., containment) is encoded in the edges; only parts of the changing graph are known to the agent at each step. This partial observability poses a challenge to existing link prediction approaches, which we address. We propose a novel representation of the agent’s accumulated set of observations into a state representation called a Scene Graph Memory (SGM), as well as a neural net architecture called a Node Edge Predictor (NEP) that extracts information from the SGM to search efficiently. We evaluate our method in the Dynamic House Simulator, a new benchmark that creates diverse dynamic graphs following the semantic patterns typically seen at homes, and show that NEP can be trained to predict the locations of objects in a variety of environments with diverse object movement dynamics, outperforming baselines both in terms of new scene adaptability and overall accuracy.
6 | 7 | ## Installation 8 | 9 | Note: You may use mamba instead of conda 10 | 11 | ```bash 12 | git clone git@github.com:andreykurenkov/memory_object_search.git 13 | cd memory_object_search 14 | conda env create -f environment.yml 15 | conda activate mos 16 | python -m spacy download en_core_web_sm 17 | chmod +x install.sh 18 | ./install.sh 19 | ``` 20 | 21 | ## Updating 22 | 23 | ```bash 24 | conda env update --file environment.yml --prune 25 | ``` 26 | 27 | ## Running 28 | 29 | ```bash 30 | # Generate prior graphs 31 | python scripts/gen_prior_graphs.py 32 | 33 | # Generate the data 34 | python scripts/collect_data.py 35 | 36 | # Train the model 37 | python scripts/train.py --multirun model=mlp,gcn,heat 38 | 39 | # Evaluate the model 40 | python scripts/eval.py 41 | ``` 42 | ## Specific experiments 43 | To run specific experiments, just append `experiment=<name>` to the above commands after the script name, e.g.: 44 | ```bash 45 | python scripts/gen_prior_graphs.py experiment=example 46 | python scripts/train.py --multirun experiment=example model=mlp,gcn,heat 47 | python scripts/eval.py 48 | ``` 49 | ## Running simulations with iGridson 50 | iGridson simulations support two modes: headless and render. In headless mode, the agent navigates through the apartment environment without producing or saving any visualizations. In render mode, the visualization is shown in a window and each frame is saved to outputs/igridson_simulations/. This is controlled by the `viz_mode` parameter in configs/experiment/simulate_agents.yaml, which can be set to either `headless` or `render`. 51 | 52 | Next, copy any trained models you want the simulator to use to `outputs/simulate_agents/models`. 53 | 54 | Finally, run: 55 | ```bash 56 | python scripts/igridson_video.py experiment=simulate_agents 57 | ``` 58 | 59 | ## Sweeping multiple parameters 60 | ```bash 61 | python scripts/eval.py --multirun +eval=predict_location ++changes_per_step=0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0 ++observation_prob=0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,1.0 62 | ``` 63 | 64 | ## Generating config files for experiment(s) 65 | 66 | Use `python scripts/gen_experiment_configs.py` to generate config files for experiments. The script expects a text file as input. Either enter formatted strings such as pl_l_d_r_l_za_npp (one on each line), or tab-separated variables such as 67 | `predict_dynamics small detailed coarse large none None` 68 | (also one set per line). 69 | 70 | Note that the tab-separated variables are what you would get if you simply copied multiple rows from a spreadsheet. 71 | 72 | The script by default expects a text file at data/iclr_experiments.txt containing formatted strings (such as pl_l_d_r_l_za_npp). 73 | 74 | To use your own text file, use 75 | ``` 76 | python scripts/gen_experiment_configs.py --text_file_path <path_to_text_file> 77 | ``` 78 | 79 | To use tab separated variable names, use 80 | ``` 81 | python scripts/gen_experiment_configs.py --use_tab_separated_vars 82 | ``` 83 | 84 | To print out the formatted strings for generated config files, use 85 | ``` 86 | python scripts/gen_experiment_configs.py --print_formatted_strings 87 | ``` 88 |
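For example, with the default formatted-string mode, the input text file simply lists one experiment name per line; the names below are taken from memsearch/experiment_configs.py (an illustrative sample, not a required set):
```
pl_l_d_r_l_za_npp
pd_l_d_c_l_n_n
ps_s_d_d_l_n_n
```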
89 | ## Reinforcement Learning (Experimental) 90 | Some scripts for training RL agents are provided under rl_scripts/. Note that this is an experimental feature that hasn't been thoroughly tested and did not give exceptional results in our preliminary tests. To run a basic training script, run: 91 | ```bash 92 | python rl_scripts/train.py experiment=simulate_agents 93 | ``` 94 | 95 | Also note that it needs a configuration file similar to simulate_agents, as it uses a headless iGridson environment. 96 | 97 | ## Cite 98 | 99 | ```bibtex 100 | @inproceedings{kurenkov2023modeling, 101 | title={Modeling Dynamic Environments with Scene Graph Memory}, 102 | author={Kurenkov, Andrey and Lingelbach, Michael and Agarwal, Tanmay and Li, Chengshu and Jin, Emily and Fei-Fei, Li and Wu, Jiajun and Savarese, Silvio and Mart{\'i}n-Mart{\'i}n, Roberto}, 103 | booktitle={International Conference on Machine Learning}, 104 | year={2023}, 105 | organization={PMLR} 106 | } 107 | ``` 108 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andreykurenkov/modeling_env_dynamics/74e5f9d722469f2d1148fe131aa85dfb049da7cb/__init__.py -------------------------------------------------------------------------------- /configs/agents/all.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | agents: 3 | cfg_name: all-agents 4 | agent_priors_type: detailed 5 | constant_prior_probs: True 6 | linear_schedule_num_steps: 100 7 | sgm_memorization_baseline: False 8 | memorization_use_priors: False 9 | frequentist_use_priors: False 10 | bayesian_prior_var: 0.05 11 | sgm_use_priors: True 12 | sgm_use_model: True 13 | observation_prob: 0.75 14 | 15 | agent_types: 16 | - random 17 | - priors 18 | - memorization 19 | - frequentist 20 | - bayesian 21 | - sgm_mlp 22 | - sgm_gcn 23 | - sgm_heat 24 | - sgm_hgt 25 | #- sgm_hgcn 26 | - sgm_han 27 | - upper_bound 28 | -------------------------------------------------------------------------------- /configs/agents/all_no_fuse.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | agents: 3 | cfg_name: all-agents 4 | agent_priors_type: detailed 5 | constant_prior_probs: True 6 | linear_schedule_num_steps: 0 7 | psg_memorization_baseline: False 8 | memorization_use_priors: False 9 | counts_use_priors: False 10 | psg_use_priors: True 11 | psg_use_model: True 12 | observation_prob: 0.9 13 | 14 | agent_types: 15 | - random 16 | - priors 17 | - memorization 18 | - psg_mlp 19 | - psg_gcn_no_fuse 20 | - psg_heat_no_fuse 21 | - upper_bound 22 | -------------------------------------------------------------------------------- /configs/agents/igridson_agents.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | agents: 3 | cfg_name: igridson_agents 4 | agent_priors_type: detailed 5 | constant_prior_probs: True 6 | linear_schedule_num_steps: 100 7 | sgm_memorization_baseline: False 8 | memorization_use_priors: False 9 | frequentist_use_priors: False 10 | bayesian_prior_var: 0.05 11 | sgm_use_priors: True 12 | sgm_use_model: True 13 | observation_prob: 0.75 14 | 15 | agent_types: 16 | - random 17 | - priors 18 | - memorization 19 | - frequentist 20 | - bayesian 21 | - sgm_mlp 22 | - sgm_gcn 23 | - sgm_heat 24 | - upper_bound -------------------------------------------------------------------------------- /configs/agents/non_psg.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | agents: 3 | cfg_name: non_psg_agents 4 | agent_priors_type: detailed 5 |
constant_prior_probs: True 6 | linear_schedule_num_steps: 0 7 | psg_memorization_baseline: False 8 | memorization_use_priors: False 9 | counts_use_priors: False 10 | psg_use_priors: True 11 | psg_use_model: True 12 | observation_prob: 0.6 13 | 14 | agent_types: 15 | - random 16 | - priors 17 | - counts 18 | - memorization 19 | - upper_bound 20 | - oracle_upper_bound 21 | -------------------------------------------------------------------------------- /configs/agents/psg.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | agents: 3 | cfg_name: all-agents 4 | agent_priors_type: detailed 5 | constant_prior_probs: True 6 | linear_schedule_num_steps: 0 7 | sgm_memorization_baseline: False 8 | memorization_use_priors: False 9 | counts_use_priors: False 10 | sgm_use_priors: True 11 | sgm_use_model: True 12 | observation_prob: 0.9 13 | 14 | agent_types: 15 | - psg_mlp 16 | - psg_gcn 17 | - psg_heat 18 | -------------------------------------------------------------------------------- /configs/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - experiment: example 4 | 5 | log_path: outputs 6 | collect_data_dir: ${log_path}/${run_name}/collected_data 7 | processed_dataset_dir: ${log_path}/${run_name}/processed_dataset 8 | models_dir: ${log_path}/${run_name}/models 9 | 10 | save_dir: ${hydra:run.dir} 11 | save_images: False 12 | model: 13 | dataset: ${dataset} 14 | models_dir: ${models_dir} 15 | psg_memorization_baseline: ${agents.psg_memorization_baseline} 16 | 17 | dataset: ${run_name} 18 | no_cache: False 19 | process_graphs_after_collection: True #may break if you don't have enough RAM 20 | 21 | collect_data_num_workers: 10 22 | process_data_num_workers: 1 23 | 24 | num_train_epochs: 25 25 | batch_size: 200 26 | train_labels_per_batch: 250 27 | test_labels_per_batch: 500 28 | psg_training: True 29 | 30 | eval_in_parallel: True 31 | 32 | hydra: 33 | output_subdir: hydra 34 | run: 35 | dir: ${log_path}/${run_name}/ 36 | sweep: 37 | dir: ${log_path}/${run_name}/ 38 | subdir: ${hydra.job.override_dirname} 39 | job: 40 | chdir: False 41 | config: 42 | override_dirname: 43 | exclude_keys: 44 | - eval 45 | -------------------------------------------------------------------------------- /configs/data_gen/10k.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | data_gen: 3 | cfg_name: 10k 4 | num_steps_train : 10000 5 | num_steps_test : 5000 6 | agent_type: memorization 7 | -------------------------------------------------------------------------------- /configs/data_gen/5k.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | data_gen: 3 | cfg_name: 5k 4 | num_steps_train : 5000 5 | num_steps_test : 1000 6 | agent_type: upper_bound 7 | -------------------------------------------------------------------------------- /configs/experiment/add_remove_objs.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /scene_gen: large_scenes 4 | - /agents: all 5 | - /data_gen: 10k 6 | - /task: predict_location 7 | - /model: gcn 8 | 9 | run_name: add_or_remove_objs 10 | scene_gen: 11 | add_or_remove_objs: True 12 | dataset: add_or_remove_objs 13 | -------------------------------------------------------------------------------- /configs/experiment/agent_no_priors.yaml: 
-------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /scene_gen: large_scenes 4 | - /agents: all 5 | - /data_gen: 10k 6 | - /task: predict_location 7 | - /model: gcn 8 | 9 | run_name: psg_no_priors 10 | dataset: no_priors 11 | agents: 12 | psg_use_priors: False 13 | -------------------------------------------------------------------------------- /configs/experiment/agent_ok_priors.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /scene_gen: large_scenes 4 | - /agents: all 5 | - /data_gen: 10k 6 | - /task: predict_location 7 | - /model: gcn 8 | 9 | run_name: ok_priors 10 | dataset: ok_priors 11 | agents: 12 | agent_priors_type: coarse 13 | -------------------------------------------------------------------------------- /configs/experiment/example.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /scene_gen: large_scenes 4 | - /agents: all 5 | - /data_gen: 10k 6 | - /task: predict_location 7 | - /model: heat 8 | 9 | run_name: example 10 | -------------------------------------------------------------------------------- /configs/experiment/large_scenes.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /scene_gen: large_scenes 4 | - /agents: all 5 | - /data_gen: 10k 6 | - /task: predict_location 7 | - /model: gcn 8 | 9 | run_name: large_scenes 10 | dataset: large_scenes 11 | -------------------------------------------------------------------------------- /configs/experiment/model_no_fuse.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /scene_gen: large_scenes 4 | - /agents: all 5 | - /data_gen: 5k 6 | - /task: predict_location 7 | - /model: gcn 8 | 9 | run_name: no_fuse 10 | model: 11 | fuse_node_features: False 12 | -------------------------------------------------------------------------------- /configs/experiment/noisy_priors.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /scene_gen: large_scenes 4 | - /agents: all 5 | - /data_gen: 5k 6 | - /task: predict_location 7 | - /model: gcn 8 | 9 | run_name: noisy_priors 10 | dataset: noisy_priors 11 | scene_gen: 12 | priors_noise: 0.5 13 | -------------------------------------------------------------------------------- /configs/experiment/save_images.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /scene_gen: small_scenes 4 | 5 | run_name: save_images 6 | dataset: save_images 7 | 8 | model: 9 | dataset: save_images 10 | 11 | task: 12 | eps_per_scene: 200 13 | 14 | data_gen: 15 | num_steps_train: 200 16 | num_steps_test: 0 17 | -------------------------------------------------------------------------------- /configs/experiment/simulate_agents.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /scene_gen: igridson_scenes 4 | - /agents: igridson_agents 5 | - /task: find_object 6 | - /model: heat 7 | - /data_gen: 10k 8 | 9 | run_name: simulate_agents 10 | num_queries: 10 11 | max_attempts: 12 12 | viz_mode: headless -------------------------------------------------------------------------------- 
/configs/experiment/small_scenes.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /scene_gen: small_scenes 4 | - /agents: all 5 | - /data_gen: 10k 6 | - /task: predict_location 7 | - /model: gcn 8 | 9 | run_name: small_scenes 10 | dataset: small_scenes 11 | -------------------------------------------------------------------------------- /configs/model/gcn.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | model: 3 | cfg_name: gcn 4 | model_type: gcn 5 | model_name: gcn 6 | 7 | edge_features: all 8 | node_features: all 9 | add_num_nodes: False 10 | add_num_edges: False 11 | use_edge_weights: False 12 | 13 | embed_text_with_transformer: False 14 | include_transformer: True 15 | 16 | node_embedding_dim: 64 17 | edge_embedding_dim: 32 18 | node_mlp_hidden_layers: [64,64] 19 | edge_mlp_hidden_layers: [64,64] 20 | output_mlp_hidden_layers: [64,64] 21 | gnn_hidden_layers: [] 22 | node_fuse_method: 'avg' 23 | 24 | psg_memorization_baseline: False 25 | zero_priors_to_objects: False 26 | add_self_loops: True 27 | reversed_edges: True 28 | -------------------------------------------------------------------------------- /configs/model/gcn_no_fuse.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | model: 3 | cfg_name: gcn 4 | model_type: gcn 5 | model_name: gcn_no_fuse 6 | recurrent: False 7 | 8 | node_embed_size: 64 9 | edge_features: all 10 | node_features: all 11 | fuse_node_features: False 12 | add_num_nodes: False 13 | add_num_edges: False 14 | use_edge_weights: False 15 | 16 | psg_memorization_baseline: False 17 | zero_priors_to_objects: False 18 | models_dir: data/models 19 | dataset: ${task.cfg_name}_${scene_gen.cfg_name} 20 | -------------------------------------------------------------------------------- /configs/model/han.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | model: 3 | cfg_name: han 4 | model_type: han 5 | model_name: han 6 | 7 | edge_features: all 8 | node_features: all 9 | add_num_nodes: False 10 | add_num_edges: False 11 | use_edge_weights: False 12 | 13 | embed_text_with_transformer: False 14 | include_transformer: True 15 | 16 | node_embedding_dim: 64 17 | edge_embedding_dim: 32 18 | node_mlp_hidden_layers: [64,64] 19 | edge_mlp_hidden_layers: [64,64] 20 | output_mlp_hidden_layers: [64,64] 21 | gnn_hidden_layers: [] 22 | node_fuse_method: 'avg' 23 | 24 | psg_memorization_baseline: False 25 | zero_priors_to_objects: False 26 | add_self_loops: True 27 | reversed_edges: True 28 | -------------------------------------------------------------------------------- /configs/model/heat.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | model: 3 | cfg_name: heat 4 | model_type: heat 5 | model_name: heat 6 | 7 | edge_features: all 8 | node_features: all 9 | add_num_nodes: False 10 | add_num_edges: False 11 | use_edge_weights: False 12 | 13 | embed_text_with_transformer: False 14 | include_transformer: True 15 | 16 | node_embedding_dim: 64 17 | edge_embedding_dim: 32 18 | edge_type_emb_dim: 2 19 | node_mlp_hidden_layers: [64,64] 20 | edge_mlp_hidden_layers: [64,64] 21 | output_mlp_hidden_layers: [64,64] 22 | gnn_hidden_layers: [] 23 | node_fuse_method: 'avg' 24 | 25 | psg_memorization_baseline: False 26 | zero_priors_to_objects: False 27 | 
add_self_loops: True 28 | reversed_edges: True 29 | -------------------------------------------------------------------------------- /configs/model/heat_no_fuse.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | model: 3 | cfg_name: heat 4 | model_type: heat 5 | model_name: heat_no_fuse 6 | recurrent: False 7 | 8 | node_embed_size: 64 9 | edge_features: all 10 | node_features: all 11 | fuse_node_features: False 12 | add_num_nodes: False 13 | add_num_edges: False 14 | use_edge_weights: False 15 | 16 | psg_memorization_baseline: False 17 | zero_priors_to_objects: False 18 | models_dir: data/models 19 | dataset: ${task.cfg_name}_${scene_gen.cfg_name} 20 | -------------------------------------------------------------------------------- /configs/model/hgcn.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | model: 3 | cfg_name: hgcn 4 | model_type: hgcn 5 | model_name: hgcn 6 | recurrent: False 7 | 8 | edge_features: all 9 | node_features: all 10 | add_num_nodes: False 11 | add_num_edges: False 12 | use_edge_weights: False 13 | 14 | embed_text_with_transformer: False 15 | include_transformer: True 16 | 17 | node_embedding_dim: 64 18 | edge_embedding_dim: 64 19 | node_mlp_hidden_layers: [64,64] 20 | edge_mlp_hidden_layers: [64,64] 21 | output_mlp_hidden_layers: [64,64] 22 | gnn_hidden_layers: [] 23 | node_fuse_method: 'avg' 24 | 25 | psg_memorization_baseline: False 26 | zero_priors_to_objects: False 27 | add_self_loops: False 28 | reversed_edges: True 29 | -------------------------------------------------------------------------------- /configs/model/hgt.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | model: 3 | cfg_name: hgt 4 | model_type: hgt 5 | model_name: hgt 6 | 7 | edge_features: all 8 | node_features: all 9 | fuse_node_features: True 10 | add_num_nodes: False 11 | add_num_edges: False 12 | use_edge_weights: False 13 | 14 | embed_text_with_transformer: False 15 | include_transformer: True 16 | 17 | node_embedding_dim: 64 18 | edge_embedding_dim: 32 19 | node_mlp_hidden_layers: [64,64] 20 | edge_mlp_hidden_layers: [64,64] 21 | output_mlp_hidden_layers: [64,64] 22 | gnn_hidden_layers: [] 23 | node_fuse_method: 'avg' 24 | 25 | psg_memorization_baseline: False 26 | zero_priors_to_objects: False 27 | add_self_loops: True 28 | reversed_edges: True 29 | -------------------------------------------------------------------------------- /configs/model/mlp.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | model: 3 | cfg_name: mlp 4 | model_type: mlp 5 | model_name: mlp 6 | 7 | edge_features: all 8 | node_features: all 9 | add_num_nodes: False 10 | add_num_edges: False 11 | use_edge_weights: False 12 | embed_text_with_transformer: False 13 | include_transformer: True 14 | 15 | node_embedding_dim: 64 16 | edge_embedding_dim: 32 17 | node_mlp_hidden_layers: [64,64] 18 | edge_mlp_hidden_layers: [64,64] 19 | output_mlp_hidden_layers: [64,64] 20 | node_fuse_method: 'avg' 21 | 22 | psg_memorization_baseline: False 23 | zero_priors_to_objects: False 24 | reversed_edges: True 25 | -------------------------------------------------------------------------------- /configs/model/recurrent_gcn.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | model_type: gcn 3 | recurrent: True 4 | 5 | 
node_embed_size: 64 6 | edge_features: all 7 | node_features: all 8 | fuse_node_features: True 9 | add_num_nodes: False 10 | add_num_edges: False 11 | use_edge_weights: False 12 | 13 | psg_memorization_baseline: False 14 | zero_priors_to_objects: False 15 | -------------------------------------------------------------------------------- /configs/scene_gen/igridson_scenes.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | scene_gen: 3 | cfg_name: fixed_scene 4 | scene_priors_type: detailed 5 | evolver_priors_nodes_from_scene: True 6 | object_move_ratio_per_step: 0.075 7 | add_or_remove_objs: False 8 | min_floors: 1 9 | max_floors: 1 10 | min_rooms: 4 11 | max_rooms: 4 12 | min_furniture: 10 13 | max_furniture: 10 14 | min_objects: 1 15 | max_objects: 4 16 | priors_noise: 0 17 | priors_sparsity_level: 0.2 18 | use_move_freq: True 19 | random_evolver: False 20 | -------------------------------------------------------------------------------- /configs/scene_gen/large_scenes.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | scene_gen: 3 | cfg_name: large_scenes 4 | scene_priors_type: detailed 5 | evolver_priors_nodes_from_scene: True 6 | object_move_ratio_per_step: 0.05 7 | add_or_remove_objs: True 8 | min_floors: 1 9 | max_floors: 1 10 | min_rooms: 4 11 | max_rooms: 4 12 | min_furniture: 8 13 | max_furniture: 8 14 | min_objects: 6 15 | max_objects: 6 16 | priors_noise: 0.25 17 | priors_sparsity_level: 0.25 18 | use_move_freq: True 19 | random_evolver: False 20 | -------------------------------------------------------------------------------- /configs/scene_gen/small_scenes.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | scene_gen: 3 | cfg_name: small-scenes 4 | scene_priors_type: detailed 5 | evolver_priors_nodes_from_scene: True 6 | object_move_ratio_per_step: 0.025 7 | add_or_remove_objs: False 8 | min_floors: 1 9 | max_floors: 1 10 | min_rooms: 3 11 | max_rooms: 3 12 | min_furniture: 6 13 | max_furniture: 6 14 | min_objects: 4 15 | max_objects: 4 16 | priors_noise: 0.2 17 | priors_sparsity_level: 0.2 18 | use_move_freq: True 19 | random_evolver: False 20 | -------------------------------------------------------------------------------- /configs/task/find_object.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | task: 3 | cfg_name: find_object 4 | name: find_object 5 | num_smoothing_steps: 10 6 | eps_per_scene : 110 7 | num_steps : 11000 8 | top_k: 10 9 | max_attempts: 10 10 | -------------------------------------------------------------------------------- /configs/task/predict_env_dynamics.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | task: 3 | cfg_name: predict_env_dynamics 4 | name: predict_env_dynamics 5 | num_steps : 11000 6 | eps_per_scene : 110 7 | top_k : 3 8 | num_smoothing_steps : 100 9 | -------------------------------------------------------------------------------- /configs/task/predict_location.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | task: 3 | cfg_name: predict_location 4 | name: predict_location 5 | num_smoothing_steps: 10 6 | eps_per_scene : 110 7 | num_steps : 11000 8 | num_objects: 1 9 | -------------------------------------------------------------------------------- 
/configs/task/predict_locations.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | task: 3 | cfg_name: predict_locations 4 | name: predict_location 5 | num_smoothing_steps: 10 6 | eps_per_scene : 110 7 | num_steps : 5500 8 | num_objects: 5 9 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: mos 2 | channels: 3 | - pyg 4 | - pytorch 5 | - nvidia 6 | - conda-forge 7 | dependencies: 8 | - python=3.8 9 | - pip 10 | - pillow 11 | - pytorch=1.13.0 12 | - torchvision=0.14.0 13 | - torchaudio=0.13.0 14 | - pytorch-cuda=11.7 15 | - pyg 16 | - opencv 17 | - numpy=1.23.5 18 | - pip: 19 | - torch-geometric-temporal 20 | - networkx 21 | - hydra-core 22 | - spacy 23 | - prior 24 | - gensim 25 | - matplotlib 26 | - shapely 27 | - seaborn 28 | - gym 29 | - tensorboard 30 | - tensorboardx 31 | - coloredlogs 32 | - ray[all]==2.2.0 33 | - sentence-transformers 34 | -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | pip install -e . 2 | pip install -e ./dependencies/gym-minigrid 3 | pip install -e ./dependencies/mini_behavior 4 | -------------------------------------------------------------------------------- /memsearch/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | ROOT_PATH = os.path.dirname(__file__) 4 | CONFIG_PATH = os.path.join(ROOT_PATH, "..", "configs") 5 | -------------------------------------------------------------------------------- /memsearch/experiment_configs.py: -------------------------------------------------------------------------------- 1 | default_configs = { 2 | 'defaults': [{'/scene_gen': 'large_scenes'}, {'/agents': 'all'}, {'/data_gen': '10k'}, {'/task': 'predict_location'}, {'/model': 'gcn'}], 3 | 'run_name': 'pl_l_d_d_l_za_n', 4 | 'dataset': 'pl_l_d_d_l_za_n', 5 | 'node_features': ['word_vec', 'time_since_observed', 'times_observed', 'time_since_state_change', 'node_type'], 6 | 'edge_features': ['time_since_observed', 'time_since_state_change', 'times_observed', 'times_state_true', 'last_observed_state', 'freq_true', 'edge_type', 'prior_prob'], 7 | 'agents': 8 | { 9 | 'agent_priors_type': 'detailed', 10 | 'psg_use_priors': True 11 | }, 12 | 'scene_gen': 13 | { 14 | 'scene_priors_type': 'detailed', 15 | 'priors_noise': 0.2, 16 | 'priors_sparsity_level': 0.4 17 | } 18 | } 19 | 20 | TASK_KEY_MATCHES = { 21 | 'pl': 'predict_location', 22 | 'pd': 'predict_env_dynamics', 23 | 'ps': 'predict_scene' 24 | } 25 | SCENE_SIZE_KEY_MATCHES = { 26 | 'l': 'large_scenes', 27 | 's': 'small_scenes' 28 | } 29 | SCENE_PRIORS_KEY_MATCHES = { 30 | 'd': 'detailed', 31 | 'c': 'coarse', 32 | 'r': 'random' 33 | } 34 | AGENT_PRIORS_KEY_MATCHES = { 35 | 'd': 'detailed', 36 | 'c': 'coarse', 37 | 'r': 'random' 38 | } 39 | DATASET_SIZE_KEY_MATCHES = { 40 | 'l': '10k', 41 | 's': '5k', 42 | 't': '1k' 43 | } 44 | PRIORS_NOISE_KEY_MATCHES = { 45 | 'za': 0.2, 46 | 'n': 0 47 | } 48 | PRIORS_SPARSITY_LEVEL_KEY_MATHES = { 49 | 'za': 0.2, 50 | 'n': 0.2 51 | } 52 | 53 | DEFAULT_EXP_NAMES = ['pl_l_d_d_l_za_n', 54 | 'ps_l_d_d_l_za_n', 55 | 'pd_l_d_d_l_za_n', 56 | 'pl_l_d_c_l_za_n', 57 | 'ps_l_d_c_l_za_n', 58 | 'pd_l_d_c_l_za_n', 59 | 'pl_l_d_r_l_za_n', 60 | 'ps_l_d_r_l_za_n', 61 | 'pd_l_d_r_l_za_n', 62 | 'pl_s_d_d_l_za_n', 63 | 
'ps_s_d_d_l_za_n', 64 | 'pd_s_d_d_l_za_n', 65 | 'pl_s_d_c_l_za_n', 66 | 'ps_s_d_c_l_za_n', 67 | 'pd_s_d_c_l_za_n', 68 | 'pl_s_d_r_l_za_npp', 69 | 'ps_s_d_r_l_za_npp', 70 | 'pd_s_d_r_l_za_npp', 71 | 'pl_l_d_d_l_za_npp', 72 | 'ps_l_d_d_l_za_npp', 73 | 'pd_l_d_d_l_za_npp', 74 | 'pl_l_d_c_l_za_npp', 75 | 'ps_l_d_c_l_za_npp', 76 | 'pd_l_d_c_l_za_npp', 77 | 'pl_l_d_r_l_za_npp', 78 | 'ps_l_d_r_l_za_npp', 79 | 'pd_l_d_r_l_za_npp', 80 | 'pl_s_d_r_l_za_npp', 81 | 'ps_s_d_r_l_za_npp', 82 | 'pd_s_d_r_l_za_npp', 83 | 'pl_l_d_d_l_za_nwv', 84 | 'ps_l_d_d_l_za_nwv', 85 | 'pd_l_d_d_l_za_nwv', 86 | 'pl_l_d_c_l_za_nwv', 87 | 'ps_l_d_c_l_za_nwv', 88 | 'pd_l_d_c_l_za_nwv', 89 | 'pl_l_d_r_l_za_nwv', 90 | 'ps_l_d_r_l_za_nwv', 91 | 'pd_l_d_r_l_za_nwv', 92 | 'pl_l_d_d_l_za_ntf', 93 | 'ps_l_d_d_l_za_ntf', 94 | 'pd_l_d_d_l_za_ntf', 95 | 'pl_l_d_c_l_za_ntf', 96 | 'ps_l_d_c_l_za_ntf', 97 | 'pd_l_d_c_l_za_ntf', 98 | 'pl_l_d_r_l_za_ntf', 99 | 'ps_l_d_r_l_za_ntf', 100 | 'pd_l_d_r_l_za_ntf', 101 | 'pl_l_d_d_l_n_n', 102 | 'ps_l_d_d_l_n_n', 103 | 'pd_l_d_d_l_n_n', 104 | 'pl_l_d_c_l_n_n', 105 | 'ps_l_d_c_l_n_n', 106 | 'pd_l_d_c_l_n_n', 107 | 'pl_l_d_r_l_n_n', 108 | 'ps_l_d_r_l_n_n', 109 | 'pd_l_d_r_l_n_n', 110 | 'pl_s_d_d_l_n_n', 111 | 'ps_s_d_d_l_n_n', 112 | 'pd_s_d_d_l_n_n', 113 | 'pl_s_d_c_l_n_n', 114 | 'ps_s_d_c_l_n_n', 115 | 'pd_s_d_c_l_n_n', 116 | ] 117 | -------------------------------------------------------------------------------- /memsearch/igridson_env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym import spaces 3 | import gensim.downloader as api 4 | import random 5 | from mini_behavior.sampling import * 6 | from mini_behavior.envs.fixed_scene import FixedEnv 7 | from mini_behavior.objects import OBJECT_TO_IDX, IDX_TO_OBJECT 8 | from mini_behavior.grid import TILE_PIXELS 9 | 10 | from memsearch.graphs import NodeType, RECIPROCAL_EDGE_TYPES 11 | 12 | TILE_PIXELS = 32  # override the imported value 13 | 14 | EDGE_TYPE_TO_FUNC = { 15 | "onTop": put_ontop, 16 | "in": put_inside, 17 | "contains": put_contains, 18 | "under": put_under, 19 | } 20 | 21 | 22 | class SMGFixedEnv(FixedEnv):
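    """FixedEnv whose object layout is generated from a sampled scene graph; a scene evolver periodically moves or adds objects (see evolve())."""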
23 | def __init__(self, scene_sampler, scene_evolver, encode_obs_im=False, mission_mode='one_hot', scene=None, set_goal_icon=False, env_evolve_freq=100): 24 | self.scene_sampler = scene_sampler 25 | self.scene_evolver = scene_evolver 26 | 27 | self.scene = self.scene_sampler.sample() if scene is None else scene 28 | self.scene_evolver.set_new_scene(self.scene) 29 | 30 | self.env_evolve_freq = env_evolve_freq 31 | self.node_to_obj = None 32 | self.moved_objs = None 33 | 34 | num_objs = self.get_num_objs() 35 | self.mission_mode = mission_mode 36 | self.encode_obs_im = encode_obs_im 37 | 38 | self.initialized = False 39 | self.set_goal_icon = set_goal_icon 40 | 41 | super().__init__(num_objs=num_objs, agent_view_size=7) 42 | 43 | def validate_scene(self): 44 | # Check scene 45 | all_furniture_nodes = self.scene.scene_graph.get_nodes_with_type(NodeType.FURNITURE) 46 | 47 | for fn in all_furniture_nodes: 48 | obj_children = [node for node in fn.get_children_nodes() if node.type == NodeType.OBJECT] 49 | if len(obj_children) > 4: 50 | return False 51 | return True 52 | 53 | def step(self, action): 54 | obs, reward, done, info = super().step(action) 55 | # if env_evolve_freq = -1, don't evolve env during steps, only during resets 56 | if self.step_count > 1 and self.env_evolve_freq != -1 and self.step_count % self.env_evolve_freq == 0: 57 | self.evolve() 58 | return obs, reward, done, info 59 | 60 | def _reward(self): 61 | return -1 62 | 63 | def _gen_objs(self): 64 | super()._gen_objs() 65 | if self.node_to_obj is not None: 66 | self.graph_to_grid() 67 | 68 | def _set_obs_space(self): 69 | assert self.mission_mode in ['one_hot', 'word_vec', 'int'], "Only three modes supported: one hot, word vec or integer." 70 | 71 | if self.mission_mode == 'word_vec': 72 | self.word2vec_model = api.load("glove-twitter-25") 73 | mission_observation_space = spaces.Box( 74 | low=-1, 75 | high=1, 76 | shape=(25,), 77 | dtype='float32' 78 | ) 79 | elif self.mission_mode == 'int': 80 | mission_observation_space = spaces.Discrete(len(IDX_TO_OBJECT)) 81 | elif self.mission_mode == 'one_hot': 82 | mission_observation_space = spaces.Box( 83 | low=0, 84 | high=1, 85 | shape=(len(IDX_TO_OBJECT),), 86 | dtype='int' 87 | ) 88 | else: 89 | raise ValueError("need valid obs mode for mission") 90 | if self.encode_obs_im: 91 | image_observation_space = spaces.Box( 92 | low=0, 93 | high=255, 94 | shape=(self.agent_view_size, self.agent_view_size, 3), 95 | dtype=np.uint8 96 | ) 97 | else: 98 | image_observation_space = spaces.Box( 99 | low=0, 100 | high=255, 101 | shape=(self.agent_view_size * TILE_PIXELS, self.agent_view_size * TILE_PIXELS, 3), 102 | dtype=np.uint8 103 | ) 104 | 105 | self.observation_space = spaces.Dict({ 106 | "direction": spaces.Box(low=0, high=4, shape=(), dtype=np.uint8), 107 | 'image': image_observation_space, 108 | "mission": mission_observation_space, 109 | }) 110 | 111 | 112 | def reset(self): 113 | # Hack around nightmare inheritance chain 114 | if not self.initialized: 115 | self._set_obs_space() 116 | 117 | # Reinitialize episode-specific variables 118 | self.agent_pos = (-1, -1) 119 | self.agent_dir = -1 120 | 121 | self.carrying = set() 122 | 123 | for obj in self.obj_instances.values(): 124 | obj.reset() 125 | 126 | self.reward = 0 127 | 128 | # Generate a new random grid at the start of each episode 129 | # To keep the same grid for each episode, call env.seed() with the same seed before calling env.reset() 130 | self._gen_grid(self.width, self.height) 131 | 132 | # generate furniture view 133 | # self.furniture_view = self.grid.render_furniture(tile_size=TILE_PIXELS, obj_instances=self.obj_instances) 134 | 135 | # These fields should be defined by _gen_grid 136 | assert self.agent_pos is not None 137 | assert self.agent_dir is not None 138 | 139 | # Check that the agent doesn't overlap with an object 140 | assert self.grid.is_empty(*self.agent_pos) 141 | 142 | # Step count since episode start 143 | self.step_count = 0 144 | self.episode += 1 145 | 146 | # Make node to obj list 147 | self.set_node_to_obj() 148 | # TODO not sure 149 | if not self.initialized: 150 | self.graph_to_grid() 151 | self.initialized = True 152 | #TODO: Set the mission THIS IS RANDOM AND NEEDS TO GET FIXED 153 | self.set_random_mission() 154 | 155 | # Return first observation 156 | obs = self.gen_obs() 157 | 158 | 159 | 160 | 161 | 162 | if self.node_to_obj is not None: 163 | self.evolve() 164 | 165 | return obs 166 | 167 | def get_num_objs(self): 168 | num_objs = {} 169 | 170 | for node in self.scene.scene_graph.get_nodes_with_type(NodeType.FURNITURE) + self.scene.scene_graph.get_nodes_with_type(NodeType.OBJECT): 171 | num_objs[node.label] = num_objs.get(node.label, 0) + 1 172 | 173 | return num_objs 174 | 175 | def set_node_to_obj(self): 176 | """ 177 | populates self.node_to_obj: key = node, value = obj_instance 178 | """ 179 | self.node_to_obj = {} 180 | 181 | for obj_type, objs in self.objs.items(): 182 | nodes =
self.scene.scene_graph.get_nodes_with_label(obj_type) 183 | 184 | assert len(objs) == len(nodes) 185 | for i in range(len(objs)): 186 | self.node_to_obj[nodes[i]] = objs[i] 187 | 188 | def graph_to_grid(self): 189 | """ 190 | NOTE: each obj node has one furniture parent node, and there are two (reciprocal) edges between them 191 | """ 192 | # for every furniture 193 | for furniture_node in self.scene.scene_graph.get_nodes_with_type(NodeType.FURNITURE): 194 | furniture = self.node_to_obj[furniture_node] 195 | # for every obj related to the furniture 196 | for obj_node in furniture_node.get_children_nodes(): 197 | if obj_node.type == NodeType.OBJECT: 198 | obj = self.node_to_obj[obj_node] 199 | edges = obj_node.get_edges_to_me() 200 | if len(edges) > 0: 201 | edge = edges[0] # edge from obj to furniture 202 | assert edge.node2 == obj_node and edge.node1 == furniture_node 203 | EDGE_TYPE_TO_FUNC[RECIPROCAL_EDGE_TYPES[edge.type].value](self, obj, furniture) # put the obj on the grid 204 | else: 205 | print("Found 0 length edges") 206 | 207 | def sample_to_grid(self, obj_node): 208 | if obj_node.type == NodeType.OBJECT: 209 | obj = self.node_to_obj[obj_node] 210 | if obj.cur_pos is not None and not obj.check_abs_state(state='inhandofrobot'): 211 | self.grid.remove(*obj.cur_pos, obj) 212 | 213 | edge = \ 214 | [e for e in obj_node.edges if (e.node1.type == NodeType.FURNITURE or e.node2.type == NodeType.FURNITURE)][0] 215 | 216 | if edge.node1.type == NodeType.FURNITURE: 217 | furniture_node = edge.node1 218 | edge_type = RECIPROCAL_EDGE_TYPES[edge.type] 219 | else: 220 | furniture_node = edge.node2 221 | edge_type = edge.type 222 | 223 | furniture = self.node_to_obj[furniture_node] 224 | EDGE_TYPE_TO_FUNC[edge_type.value](self, obj, furniture) # put the obj on the grid 225 | 226 | # uncomment for debugging 227 | # check_state(self, obj, furniture, edge_type) 228 | 229 | def evolve(self): 230 | # self.scene_evolver.scene = self.scene 231 | self.moved_objs = self.scene_evolver.evolve() # list of objects that were moved 232 | for obj_node in self.moved_objs: 233 | if obj_node not in list(self.node_to_obj.keys()): # if it is an added obj 234 | obj_instance = self.add_objs({obj_node.label: 1})[0] 235 | node_to_obj = self.node_to_obj 236 | node_to_obj[obj_node] = obj_instance 237 | self.node_to_obj = node_to_obj 238 | assert obj_node in list(self.node_to_obj.keys()) 239 | self.sample_to_grid(obj_node) 240 | 241 | def _end_conditions(self): 242 | assert self.target_poses, "This function should only be called after set_mission" 243 | for target_pos in self.target_poses: 244 | if np.all(target_pos == self.front_pos) or self.step_count == self.max_steps: 245 | return True 246 | return False 247 | 248 | def set_mission(self, goal): 249 | """ 250 | Sets the mission of the env 251 | """ 252 | assert isinstance(goal, int) or isinstance(goal, str), "Expecting either obj index or obj name" 253 | 254 | if isinstance(goal, int): # Setting target by obj idx 255 | obj_label = IDX_TO_OBJECT[goal] 256 | obj_idx = goal 257 | elif isinstance(goal, str): 258 | obj_label = goal.lower() 259 | obj_idx = OBJECT_TO_IDX[obj_label] 260 | self.goal_obj_label = obj_label 261 | 262 | assert obj_label in self.objs.keys(), "Goal object not sampled in current scene."
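        # candidate goal positions: the episode can end once the agent faces any of these (see _end_conditions)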
263 | self.target_poses = [target_obj.cur_pos for target_obj in self.objs[obj_label]] 264 | 265 | # Set mission 266 | if self.mission_mode == 'one_hot': 267 | self.mission = np.eye(len(IDX_TO_OBJECT))[obj_idx] 268 | elif self.mission_mode == 'int': 269 | self.mission = obj_idx 270 | elif self.mission_mode == 'word_vec': 271 | model_inps = obj_label.split('_') 272 | vec = np.zeros((25)) 273 | for inp in model_inps: 274 | vec += self.word2vec_model.get_vector(inp, norm=True) 275 | vec /= len(model_inps) 276 | self.mission = vec 277 | else: 278 | raise ValueError("Missing obs mode") 279 | 280 | if self.set_goal_icon: # Set icon of goal object to be green 281 | goal_objs = self.objs[obj_label] 282 | for goal_obj in goal_objs: 283 | goal_obj.icon_color = 'green' 284 | 285 | def get_possible_missions(self): 286 | all_object_nodes = self.scene.scene_graph.get_nodes_with_type(NodeType.OBJECT) 287 | all_obj_labels = [self.node_to_obj[node].type for node in all_object_nodes] 288 | return all_obj_labels 289 | 290 | def set_random_mission(self): 291 | all_goals = self.get_possible_missions() 292 | random_goal = random.choice(all_goals) 293 | self.set_mission(random_goal) 294 | 295 | def set_mission_by_node(self, node): 296 | self.goal_node = node 297 | self.goal_obj_label = self.node_to_obj[node].type 298 | self.set_mission(self.goal_obj_label) 299 | -------------------------------------------------------------------------------- /memsearch/igridson_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from memsearch.scene import * 3 | from mini_behavior.utils.navigate import * 4 | from memsearch.igridson_env import * 5 | from mini_behavior.minibehavior import MiniBehaviorEnv 6 | from memsearch.tasks import TaskType 7 | from memsearch.agents import make_agent 8 | from memsearch.dataset import make_featurizers 9 | 10 | DEFAULT_ROOM_LAYOUT = { 11 | 'bedroom': {'bed': 1, 12 | 'desk': 1, 13 | 'chair': 1 14 | }, 15 | 'kitchen': {'fridge': 1, 16 | 'counter_top': 2, 17 | # 'can': 1 18 | }, 19 | 'bathroom': {'toilet': 1, 20 | 'sink': 1 21 | }, 22 | 'living_room': {'chair': 10, 23 | 'dining_table': 1, 24 | 'shelving_unit': 1, 25 | 'sofa': 1 26 | } 27 | } 28 | 29 | 30 | def make_agent_type(cfg, agent_type, task_type, scene_sampler, scene_evolver): 31 | node_featurizer, edge_featurizer = make_featurizers(cfg.model, False, cfg.task.num_steps) 32 | if 'sgm' in agent_type: 33 | model_config = cfg.model 34 | else: 35 | model_config = None 36 | agent = make_agent(cfg.agents, agent_type, task_type, node_featurizer, edge_featurizer, scene_sampler, 37 | scene_evolver, False, model_config) 38 | return agent 39 | 40 | 41 | def get_query(env): 42 | all_object_nodes = env.scene.scene_graph.get_nodes_with_type(NodeType.OBJECT) 43 | sample_from = all_object_nodes if len(env.moved_objs) == 0 else env.moved_objs 44 | return random.choice(sample_from) 45 | 46 | 47 | def simulate_agent(agent, env, query, task, max_attempts=5): 48 | current_scene = env.scene_evolver.scene 49 | task.goal_node = query 50 | env.set_mission_by_node(query) 51 | 52 | prediction = agent.make_predictions(TaskType.FIND_OBJECT, current_scene, query) 53 | obs, score, done, info = task.step(prediction, agent, current_scene, query, max_attempts=max_attempts) 54 | agent.receive_observation(obs) 55 | agent.step() 56 | 57 | return obs, score, done, info 58 | 59 |
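# evolve the scene, then re-sync the agent, the task, and the rendered grid with the updated scene graph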
60 | def evolve(env, window, agent, task): 61 | env.evolve() 62 | redraw(env, window) 63 | 64 | current_scene = env.scene_evolver.scene 65 | agent.transition_to_new_scene(current_scene) 66 | task.scene = env.scene_evolver.scene 67 | for edge in task.scene.scene_graph.get_edges(): 68 | edge.age += 1 69 | for node in env.moved_objs: 70 | node.times_moved += 1 71 | task.current_ep += 1 72 | 73 | 74 | def step(action, env, window): 75 | obs, reward, done, info = env.step(action) 76 | redraw(env, window) 77 | 78 | 79 | def is_obj_at_node(node, goal_node): 80 | for child in node.get_children_nodes(): 81 | if child.description == goal_node.description: 82 | return True 83 | return False 84 | 85 | def get_astar_path(env, visited_nodes, window=None, save_dir=None, exp_name=None): 86 | def choose_node(node): 87 | if type(node) == PriorsNode or type(node) == SGMNode: 88 | raise ValueError("node of type {} found. Expected scene graph node".format(type(node))) 89 | else: 90 | return node 91 | goal_instance = env.node_to_obj[env.goal_node] 92 | goal_instance.icon_color = 'green' 93 | env.place_agent(i=0, j=0) 94 | env.step_count = 0 95 | curr_agent_pos = None 96 | total_path_length = 0 97 | 98 | for node in visited_nodes: 99 | obj_node = choose_node(node) 100 | path, actions, end_pos = get_path_and_actions(env, obj_node, curr_agent_pos) 101 | curr_agent_pos = end_pos 102 | total_path_length += len(path) 103 | 104 | if window: 105 | save_path = get_save_path(save_dir, [exp_name, obj_node.unique_id]) 106 | for action in actions: 107 | step(MiniBehaviorEnv.Actions(action), env, window) 108 | save_img(window, os.path.join(save_path, f'x_pos{end_pos[0]}_y_pos{end_pos[1]}_step_{env.step_count}.png')) 109 | 110 | return total_path_length 111 | 112 | def get_path_and_actions(env, target_node, agent_pos=None): 113 | maze = env.grid.get_maze() 114 | obj_instance = env.node_to_obj[target_node] 115 | 116 | if not agent_pos: 117 | start_pos = env.agent_pos 118 | else: 119 | start_pos = agent_pos 120 | if type(obj_instance) == FurnitureObj: 121 | i, j = random.choice(obj_instance.all_pos) 122 | else: 123 | i, j = obj_instance.cur_pos 124 | end_pos = get_pos_next_to_obj(env, i, j) 125 | assert end_pos is not None, 'not able to reach obj {}, {}'.format(i, j) 126 | 127 | start_room = env.room_from_pos(*(start_pos[1], start_pos[0])) 128 | end_room = env.room_from_pos(*(end_pos[1], end_pos[0])) 129 | path = navigate_between_rooms(start_pos, end_pos, start_room, end_room, maze) 130 | actions = get_actions(env.agent_dir, path) 131 | return path, actions, end_pos 132 | 133 | def nodes_to_coords(env, nodes): 134 | coords = [] 135 | for node in nodes: 136 | obj_instance = env.node_to_obj[node] 137 | if type(obj_instance) == FurnitureObj: 138 | i, j = random.choice(obj_instance.all_pos) 139 | else: 140 | i, j = obj_instance.cur_pos 141 | coords.append((i,j)) 142 | return coords 143 | 144 | def cumulative_manhattan_distance(visited_nodes): 145 | total_dist = 0.0 146 | manhattan_dist = lambda x : (abs(x[0][0] - x[1][0]) + abs(x[0][1] - x[1][1])) 147 | for i in range(len(visited_nodes) - 1): 148 | total_dist += manhattan_dist([visited_nodes[i], visited_nodes[i+1]]) 149 | return total_dist 150 | 151 | def get_pos_next_to_obj(env, i, j): 152 | maze = env.grid.get_maze() 153 | for pos in [(i, j - 1), (i, j + 1), (i - 1, j), (i + 1, j)]: 154 | # if env.grid.is_empty(*pos): 155 | # return pos 156 | if maze[pos[1]][pos[0]] == 0: 157 | return tuple(pos) 158 | return None 159 | 160 | 161 | def get_save_path(save_dir, save_name): 162 | if not os.path.exists(save_dir): 163 | os.mkdir(save_dir) 164 | path = save_dir 165 | for subdir in save_name: 166 | path = os.path.join(path,
subdir) 167 | if not os.path.exists(path): 168 | os.mkdir(path) 169 | return path 170 | 171 | 172 | def make_scene_sampler_and_evolver_(cfg): 173 | priors_graph = load_priors_graph(cfg.scene_priors_type) 174 | scene_sampler = PriorsSceneSampler(priors_graph, 175 | min_floors=cfg.min_floors, 176 | max_floors=cfg.max_floors, 177 | min_rooms=cfg.min_rooms, 178 | max_rooms=cfg.max_rooms, 179 | min_furniture=cfg.min_furniture, 180 | max_furniture=cfg.max_furniture, 181 | min_objects=cfg.min_objects, 182 | max_objects=cfg.max_objects, 183 | priors_noise=cfg.priors_noise, 184 | sparsity_level=cfg.priors_sparsity_level, 185 | room_layout=DEFAULT_ROOM_LAYOUT) 186 | if cfg.random_evolver: 187 | scene_evolver = RandomSceneEvolver() 188 | else: 189 | scene_evolver = PriorsSceneEvolver(priors_graph, 190 | scene_sampler, 191 | object_move_ratio_per_step=cfg.object_move_ratio_per_step, 192 | prior_nodes_from_scene=cfg.evolver_priors_nodes_from_scene, 193 | use_move_freq=cfg.use_move_freq, 194 | add_or_remove_objs=cfg.add_or_remove_objs, 195 | sparsity_level=cfg.priors_sparsity_level) 196 | 197 | return scene_sampler, scene_evolver 198 | 199 | 200 | 201 | TILE_PIXELS = 32 202 | 203 | 204 | def redraw(env, window): 205 | if window is not None: 206 | img = env.render('rgb_array', tile_size=TILE_PIXELS) 207 | window.no_closeup() 208 | if hasattr(env, 'goal_node'): 209 | caption = f"Goal: {env.goal_obj_label}" 210 | window.set_caption(caption) 211 | window.show_img(img) 212 | 213 | def save_img(window, pathname='saved_grid.png'): 214 | window.save_img(pathname) 215 | 216 | 217 | def reset(env, window, task, agent): 218 | obs = env.reset() 219 | agent.transition_to_new_scene(env.scene_evolver.scene) 220 | task.scene = env.scene_evolver.scene 221 | for edge in task.scene.scene_graph.get_edges(): 222 | edge.age += 1 223 | for node in env.moved_objs: 224 | node.times_moved += 1 225 | task.current_ep += 1 226 | 227 | if window is not None: 228 | redraw(env, window) 229 | 230 | return obs 231 | 232 | 233 | def check_graph_and_grid(env): 234 | graph = env.scene.scene_graph 235 | node_to_obj = env.node_to_obj 236 | for furniture_node in graph.get_nodes_with_type(NodeType.FURNITURE): 237 | for obj_node in [node for node in furniture_node.get_children_nodes() if node.type == NodeType.OBJECT]: 238 | obj = node_to_obj[obj_node] 239 | furniture = node_to_obj[furniture_node] 240 | 241 | edge = [e for e in obj_node.edges if e.node1 == furniture_node or e.node2 == furniture_node][0] 242 | edge_type = edge.type if edge.node1 == obj_node else RECIPROCAL_EDGE_TYPES[edge.type] 243 | check_state(env, obj, furniture, edge_type) 244 | 245 | 246 | def check_state(env, obj, furniture, edge_type): 247 | if edge_type.value == "in": 248 | assert obj.check_rel_state(env, furniture, 'inside') 249 | elif edge_type.value == "contains": 250 | assert furniture.check_rel_state(env, obj, 'inside') 251 | else: 252 | assert obj.check_rel_state(env, furniture, edge_type.value) 253 | -------------------------------------------------------------------------------- /memsearch/metrics.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | from memsearch.tasks import TaskType 4 | from abc import abstractmethod 5 | from enum import Enum 6 | from memsearch.tasks import TaskType 7 | 8 | def rename(agent_type): 9 | if agent_type == 'counts': 10 | agent_type = 'Frequentist' 11 | elif agent_type == 'memorization': 12 | agent_type = 'Myopic' 13 | elif agent_type == 
'upper_bound': 14 | agent_type = 'Oracle' 15 | elif agent_type == 'sgm_heat': 16 | agent_type = 'NES' 17 | else: 18 | agent_type = agent_type.capitalize() 19 | return agent_type 20 | 21 | AGENT_COLORS = { 22 | 'Random': 'brown', 23 | 'Priors': 'orange', 24 | 'Frequentist': 'red', 25 | 'Myopic': 'green', 26 | 'Oracle': 'purple', 27 | 'SGM': 'blue', 28 | } 29 | 30 | class PlotType(Enum): 31 | LINE = "line" 32 | BAR = "bar" 33 | 34 | class Metric(object): 35 | metric_id = 0 36 | 37 | def __init__(self, name, save_dir, plot_type): 38 | Metric.metric_id += 1 39 | self.fig_num = Metric.metric_id 40 | self.name = name 41 | self.plot_type = plot_type 42 | self.save_dir = save_dir 43 | self.agent_evals = {} 44 | self.agent_evals_var = {} 45 | 46 | @abstractmethod 47 | def add_data_to_plot(self, x, y, y_std, label, log_f, **kwargs): 48 | pass 49 | 50 | @abstractmethod 51 | def make_plot(self, agent_types, task, **kwargs): 52 | pass 53 | 54 | @abstractmethod 55 | def get_metric_name(self): 56 | pass 57 | 58 | def save_plot(self, save_name=None): 59 | if save_name is None: 60 | save_name = self.name 61 | plt.figure(self.fig_num) 62 | plt.savefig(f'{self.save_dir}/{save_name}.png') 63 | 64 | def show_plot(self): 65 | plt.figure(self.fig_num) 66 | plt.show() 67 | 68 | def add_data_to_csv(self, log_f, data): 69 | data = ["{:.3f}".format(data_i) for data_i in data] 70 | log_str = ",".join(data) + '\n' 71 | log_f.write(log_str) 72 | 73 | class AvgAccuracy(Metric): 74 | def __init__(self, name, save_dir, ymin=0, ymax=1.0, use_std=True): 75 | super().__init__(name, save_dir, PlotType.LINE) 76 | self.ymin = ymin 77 | self.ymax = ymax 78 | self.use_std = use_std 79 | 80 | def get_metric_name(self): 81 | return 'average_accuracy' 82 | 83 | def add_data_to_plot(self, x, y, y_std, label): 84 | plt.figure(self.fig_num)#, figsize=(12.0,8.5)) 85 | dy = (y_std**2) 86 | plt.plot(x, y, label=label) 87 | if self.use_std: 88 | plt.fill_between(x, y - dy, y + dy, alpha=0.2) 89 | self.agent_evals[label] = np.mean(y) 90 | self.agent_evals_var[label] = np.mean(dy) 91 | 92 | def make_plot(self, agent_types, task): 93 | plt.figure(self.fig_num)#, figsize=(12.0,8.5)) 94 | plt.ylim(self.ymin, self.ymax) 95 | if task == TaskType.FIND_OBJECT: 96 | plt.legend(loc='center left', bbox_to_anchor=(1.0,0.5),ncol=1, fancybox=True, shadow=True) 97 | plt.title('Number of Tries to Find the Object vs Step') 98 | else: 99 | plt.legend(loc='upper center', bbox_to_anchor=(0.5, 0.98),ncol=3, fancybox=True, shadow=True) 100 | plt.title('Predict Location Average Accuracy vs Step') 101 | plt.tight_layout() 102 | 103 | class AvgAUC(Metric): 104 | def __init__(self, name, save_dir): 105 | super().__init__(name, save_dir, PlotType.BAR) 106 | 107 | def get_metric_name(self): 108 | return 'average_auc' 109 | 110 | def add_data_to_plot(self, x, y, y_std, label): 111 | plt.figure(self.fig_num) 112 | if isinstance(y, np.ndarray): 113 | auc = np.sum(y) 114 | auc_std = np.sqrt(np.sum(y_std**2)) 115 | else: 116 | auc = y 117 | auc_std = y_std 118 | plt.bar(x, auc, yerr=auc_std, align='center', alpha=0.5, ecolor='black', capsize=10, label=label) 119 | self.agent_evals[label] = auc/100.0 120 | self.agent_evals_var[label] = np.mean(auc_std)/100.0 121 | 122 | def make_plot(self, agent_types, task): 123 | plt.figure(self.fig_num) 124 | plt.title('Predict Location Overall Average Accuracy') 125 | plt.xticks(np.arange(len(agent_types)), labels=agent_types) 126 | plt.grid(visible=True, axis='y') 127 | # set x spacing so that labels dont overlap 128 | 
128 |         plt.gca().margins(x=0)
129 |         plt.gcf().canvas.draw()
130 |         tl = plt.gca().get_xticklabels()
131 |         maxsize = max([t.get_window_extent().width for t in tl])
132 |         m = 0.2 # inch margin
133 |         s = maxsize/plt.gcf().dpi*len(agent_types)+2*m
134 |         margin = m/plt.gcf().get_size_inches()[0]
135 |         plt.gcf().subplots_adjust(left=margin, right=1.-margin)
136 |         plt.gcf().set_size_inches(s, plt.gcf().get_size_inches()[1])
137 |         plt.gcf().tight_layout()
138 | 
139 | class DiscSumOfRewards(Metric):
140 |     def __init__(self, name, save_dir):
141 |         self.plot_type = PlotType.BAR
142 |         super().__init__(name, save_dir, PlotType.BAR)
143 | 
144 |     def get_metric_name(self):
145 |         return 'average_disc_sum_rewards'
146 | 
147 |     def add_data_to_plot(self, x, y, y_std, label):
148 |         plt.figure(self.fig_num)
149 |         discount_fac = 0.99
150 |         if isinstance(y, np.ndarray):
151 |             discount_coeffs = [discount_fac**i for i in range(1, len(y)+1)]
152 |             disc_sum = np.dot(discount_coeffs, y)
153 |             disc_sum_std = np.sqrt(np.dot(discount_coeffs, y_std**2))
154 |         else:
155 |             disc_sum = y
156 |             disc_sum_std = y_std
157 |         plt.bar(x, disc_sum, yerr=disc_sum_std, align='center', alpha=0.5, ecolor='black', capsize=10, label=label)
158 |         self.agent_evals[label] = disc_sum
159 |         self.agent_evals_var[label] = disc_sum_std
160 | 
161 |     def make_plot(self, agent_types, task):
162 |         plt.figure(self.fig_num)
163 |         plt.title('Discounted Sum of Rewards')
164 |         plt.xticks(np.arange(len(agent_types)), labels=agent_types)
165 |         plt.grid(visible=True, axis='y')
166 |         # set x spacing so that labels don't overlap
167 |         plt.gca().margins(x=0)
168 |         plt.gcf().canvas.draw()
169 |         tl = plt.gca().get_xticklabels()
170 |         maxsize = max([t.get_window_extent().width for t in tl])
171 |         m = 0.2 # inch margin
172 |         s = maxsize/plt.gcf().dpi*len(agent_types)+2*m
173 |         margin = m/plt.gcf().get_size_inches()[0]
174 |         plt.gcf().subplots_adjust(left=margin, right=1.-margin)
175 |         plt.gcf().set_size_inches(s, plt.gcf().get_size_inches()[1])
176 |         plt.gcf().tight_layout()
177 | 
178 | class PercentObjectsFound(Metric):
179 |     def __init__(self, name, save_dir, top_k):
180 |         super().__init__(name, save_dir, PlotType.BAR)
181 |         self.top_k = top_k
182 | 
183 |     def get_metric_name(self):
184 |         return 'percent_objects_found'
185 | 
186 |     def add_data_to_plot(self, x, score_matrix, label):
187 |         plt.figure(self.fig_num)
188 |         num_steps = score_matrix.shape[1]
189 |         num_scenes = score_matrix.shape[0]
190 |         y = np.array(
191 |             [np.count_nonzero(scene_scores != self.top_k+1) for scene_scores in score_matrix]
192 |         )
193 |         percent_found = np.sum(y) / (num_scenes * num_steps)
194 |         plt.bar(x, percent_found, align='center', alpha=0.5, ecolor='black', capsize=10, label=label)
195 |         self.agent_evals[label] = percent_found
196 |         self.agent_evals_var[label] = 0.0
197 | 
198 |     def make_plot(self, agent_types, task):
199 |         plt.figure(self.fig_num)
200 |         plt.title('Percent Objects Found')
201 |         plt.xticks(np.arange(len(agent_types)), labels=agent_types)
202 |         plt.grid(visible=True, axis='y')
203 |         # set x spacing so that labels don't overlap
204 |         plt.gca().margins(x=0)
205 |         plt.gcf().canvas.draw()
206 |         tl = plt.gca().get_xticklabels()
207 |         maxsize = max([t.get_window_extent().width for t in tl])
208 |         m = 0.2 # inch margin
209 |         s = maxsize/plt.gcf().dpi*len(agent_types)+2*m
210 |         margin = m/plt.gcf().get_size_inches()[0]
211 |         plt.gcf().subplots_adjust(left=margin, right=1.-margin)
212 |         plt.gcf().set_size_inches(s, plt.gcf().get_size_inches()[1])
213 |         plt.tight_layout()
214 | 
215 | class PercObjectsFoundOverTime(Metric):
216 |     def __init__(self, name, save_dir, top_k):
217 |         super().__init__(name, save_dir, PlotType.LINE)
218 |         self.top_k = top_k
219 | 
220 |     def get_metric_name(self):
221 |         return 'percent_objects_found'
222 | 
223 |     def add_data_to_plot(self, x, score_matrix, label):
224 |         plt.figure(self.fig_num)#, figsize=(12.0,8.5))
225 |         num_steps = score_matrix.shape[1]
226 |         num_scenes = score_matrix.shape[0]
227 |         y_object_found_arr = (score_matrix != self.top_k+1)
228 |         y = np.array(
229 |             [np.count_nonzero(y_object_found_arr[:, step_i]) / num_scenes for step_i in range(num_steps)]
230 |         )
231 | 
232 |         y_std = np.std(y_object_found_arr, axis=0)
233 |         dy = (y_std**2)
234 |         plt.plot(x, y, label=label)
235 |         # plt.fill_between(x, y - dy, y + dy, alpha=0.2)
236 |         self.agent_evals[label] = np.mean(y)
237 |         self.agent_evals_var[label] = np.mean(dy)
238 | 
239 |     def make_plot(self, agent_types, task):
240 |         plt.figure(self.fig_num, figsize=(12.0,25.0))
241 |         plt.legend(loc='center left', bbox_to_anchor=(1.0, 0.5),
242 |                    ncol=1, fancybox=True, shadow=True)
243 |         plt.ylim(0,1.0)
244 |         plt.title('Percent Objects Found vs Step')
245 |         plt.tight_layout()
246 | 
247 | class AvgNumAttempts(Metric):
248 |     def __init__(self, name, save_dir):
249 |         super().__init__(name, save_dir, PlotType.BAR)
250 | 
251 |     def get_metric_name(self):
252 |         return 'avg_num_attempts'
253 | 
254 |     def add_data_to_plot(self, x, score_matrix, label):
255 |         plt.figure(self.fig_num)
256 |         avg_num_steps = np.mean(score_matrix)
257 |         plt.bar(x, avg_num_steps, align='center', alpha=0.5, ecolor='black', capsize=10, label=label)
258 |         self.agent_evals[label] = avg_num_steps
259 |         self.agent_evals_var[label] = 0.0
260 | 
261 |     def make_plot(self, agent_types, task):
262 |         plt.figure(self.fig_num)
263 |         plt.title('Average Number of Attempts')
264 |         plt.xticks(np.arange(len(agent_types)), labels=agent_types)
265 |         plt.grid(visible=True, axis='y')
266 |         # set x spacing so that labels don't overlap
267 |         plt.gca().margins(x=0)
268 |         plt.gcf().canvas.draw()
269 |         tl = plt.gca().get_xticklabels()
270 |         maxsize = max([t.get_window_extent().width for t in tl])
271 |         m = 0.2 # inch margin
272 |         s = maxsize/plt.gcf().dpi*len(agent_types)+2*m
273 |         margin = m/plt.gcf().get_size_inches()[0]
274 |         plt.gcf().subplots_adjust(left=margin, right=1.-margin)
275 |         plt.gcf().set_size_inches(s, plt.gcf().get_size_inches()[1])
276 |         plt.tight_layout()
277 | 
278 | def plot_agent_eval(num_steps,
279 |                     score_vecs,
280 |                     agent_type,
281 |                     x_ind,
282 |                     metrics,
283 |                     smoothing_kernel_size=10,
284 |                     task=TaskType.PREDICT_LOC,
285 |                     show_fig=False,
286 |                     save_fig=False,
287 |                     x_labels=None):
288 |     """
289 |     Adds plot information for each agent to the figure.
290 |     Makes the final plot whenever show_fig or save_fig are turned on.
291 | 
292 |     args:
293 |         num_steps: Number of steps the agent was evaluated for
294 |         score_vecs: Score vectors containing performance metrics for the agent
295 |         agent_type: Name of the agent
296 |         x_ind: x index to be used for this agent when making bar plots
297 |         smoothing_kernel_size: Width of the moving-average window applied to each score vector (0 disables smoothing)
298 |         metrics: Metrics to plot for each agent; each must be an instance of a Metric subclass
299 |         task: The TaskType being plotted (controls plot titles and legend placement)
300 |         show_fig / save_fig: To be turned on if the final figure is to be shown and/or saved
301 |         x_labels: Agent names to use as bar-plot tick labels
302 |     """
303 |     if smoothing_kernel_size!=0:
304 |         kernel = np.ones(smoothing_kernel_size) / float(smoothing_kernel_size)
305 |         smoothed_score_vecs = []
306 |         for score_vec in score_vecs:
307 |             smoothed_score_vec = np.convolve(score_vec, kernel, mode='valid')
308 |             smoothed_score_vecs.append(smoothed_score_vec)
309 | 
310 |         x = np.array(range(num_steps-smoothing_kernel_size+1))
311 |         score_matrix = np.stack(smoothed_score_vecs)
312 |     else:
313 |         x = np.array(range(num_steps))
314 |         score_matrix = np.stack(score_vecs)
315 |     y = np.mean(score_matrix, axis=0)
316 |     y_std = np.std(score_matrix, axis=0)
317 | 
318 |     x_labels = x_labels[:]
319 |     for i,agent_t in enumerate(x_labels):
320 |         x_labels[i] = rename(agent_t)
321 |     agent_type = rename(agent_type)
322 | 
323 |     for metric in metrics:
324 |         if type(metric) is AvgAccuracy:
325 |             metric.add_data_to_plot(x, y, y_std, agent_type)
326 |         elif type(metric) in [PercentObjectsFound, AvgNumAttempts]:
327 |             metric.add_data_to_plot(x_ind, score_matrix, agent_type)
328 |         elif type(metric) is PercObjectsFoundOverTime:
329 |             metric.add_data_to_plot(x, score_matrix, agent_type)
330 |         else:
331 |             metric.add_data_to_plot(x_ind, y, y_std, agent_type)
332 | 
333 |     if show_fig or save_fig:
334 |         # Final Plot
335 |         for metric in metrics:
336 |             metric.make_plot(x_labels, task)
337 |             if save_fig:
338 |                 metric.save_plot()
339 |         if show_fig:
340 |             plt.show()
341 | 
342 | def store_agent_eval(log_dir, scores, agent_type):
343 |     log_csv_path = f'{log_dir}/{agent_type}/eval.csv'
344 |     with open(log_csv_path, 'a') as csv_file:
345 |         for eval_run in scores:
346 |             eval_run_str = ["{:.3f}".format(score) for score in eval_run]
347 |             log_str = ",".join(eval_run_str)+'\n'
348 |             csv_file.write(log_str)
349 | 
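The module above follows an accumulate-then-render pattern: each `Metric` owns its own matplotlib figure, `add_data_to_plot` is called once per agent, and `make_plot`/`save_plot` finalize the figure. A minimal sketch with synthetic scores (the save directory and agent list are arbitrary):

```python
import numpy as np
from memsearch.metrics import AvgAccuracy, rename
from memsearch.tasks import TaskType

agents = ['random', 'counts', 'sgm_heat']
metric = AvgAccuracy('avg_accuracy', save_dir='/tmp')
x = np.arange(100)
for agent in agents:
    y = np.random.uniform(0.4, 0.8, size=100)        # synthetic accuracy curve
    metric.add_data_to_plot(x, y, y_std=np.full(100, 0.05), label=rename(agent))
metric.make_plot([rename(a) for a in agents], TaskType.PREDICT_LOC)
metric.save_plot()                                    # writes /tmp/avg_accuracy.png
```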
-------------------------------------------------------------------------------- /memsearch/models.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from memsearch.dataset import GRAPH_METADATA, REVERSED_GRAPH_METADATA
4 | from torch_geometric.nn import GCNConv, HEATConv, HGTConv, HANConv, HeteroConv, SAGEConv
5 | from torch_geometric.utils import add_self_loops
6 | from pathlib import Path
7 | from enum import Enum
8 | import torch_geometric.transforms as T
9 | 
10 | class NodeFuseMethod(Enum):
11 |     CONCAT = "concat"
12 |     AVERAGE = "avg"
13 |     MULTIPLY = "mult"
14 | 
15 | # TODO try TransformerConv, GATv2Conv, RGATConv, GENConv, GeneralConv
16 | class MLPEdgeClassifier(torch.nn.Module):
17 |     def __init__(self,
18 |                  node_features_dim,
19 |                  edge_features_dim,
20 |                  node_embedding_dim = 32,
21 |                  edge_embedding_dim = 32,
22 |                  node_mlp_hidden_layers = [32, 32],
23 |                  edge_mlp_hidden_layers = [32, 32],
24 |                  output_mlp_hidden_layers = [32, 32],
25 |                  node_fuse_method = NodeFuseMethod.CONCAT,
26 |                  include_transformer=True,
27 |                  reversed_edges=True):
28 |         super().__init__()
29 |         self.reversed_edges = reversed_edges
30 |         self.node_embed_mlp = nn.ModuleList()
31 |         fused_features_size = node_features_dim
32 |         for layer_size in node_mlp_hidden_layers + [node_embedding_dim]:
33 |             self.node_embed_mlp.append(nn.Linear(fused_features_size, layer_size))
34 |             self.node_embed_mlp.append(nn.ReLU())
35 |             fused_features_size = layer_size
36 | 
37 |         self.edge_embed_mlp = nn.ModuleList()
38 |         fused_features_size = edge_features_dim
39 |         for layer_size in edge_mlp_hidden_layers + [edge_embedding_dim]:
40 |             self.edge_embed_mlp.append(nn.Linear(fused_features_size, layer_size))
41 |             self.edge_embed_mlp.append(nn.ReLU())
42 |             fused_features_size = layer_size
43 | 
44 |         self.node_fuse_method = node_fuse_method
45 |         self.output_mlp = nn.ModuleList()
46 |         if node_fuse_method != NodeFuseMethod.CONCAT:
47 |             fused_features_size = node_embedding_dim + edge_embedding_dim
48 |         else:
49 |             fused_features_size = node_embedding_dim*2 + edge_embedding_dim
50 | 
51 |         self.include_transformer = include_transformer
52 |         if self.include_transformer:
53 |             transformer_layer = torch.nn.TransformerEncoderLayer(d_model=fused_features_size,
54 |                                                                  nhead=2,
55 |                                                                  dim_feedforward=fused_features_size,
56 |                                                                  dropout=0.25,
57 |                                                                  batch_first=True)
58 |             self.transformer_encoder = torch.nn.TransformerEncoder(transformer_layer, num_layers=2)
59 | 
60 |         for layer_size in output_mlp_hidden_layers:
61 |             self.output_mlp.append(nn.Linear(fused_features_size, layer_size))
62 |             self.output_mlp.append(nn.ReLU())
63 |             fused_features_size = layer_size
64 |         self.output_layer = nn.Linear(layer_size, 1)
65 |         self.sigmoid = nn.Sigmoid()
66 | 
67 |     def forward(self,
68 |                 pyg_data,
69 |                 classify_edge_index,
70 |                 classify_edge_features):
71 |         node_features = self.embed_nodes(pyg_data)
72 |         edge_features = self.embed_edges(classify_edge_features)
73 |         fused_features = self.fuse_features(node_features, edge_features, classify_edge_index)
74 |         return self.classify_edges(fused_features)
75 | 
76 |     def embed_nodes(self, pyg_data):
77 |         node_embeddings = pyg_data.x
78 |         for layer in self.node_embed_mlp:
79 |             node_embeddings = layer(node_embeddings)
80 |         return node_embeddings
81 | 
82 |     def embed_edges(self, edge_features):
83 |         edge_embeddings = edge_features
84 |         for layer in self.edge_embed_mlp:
85 |             edge_embeddings = layer(edge_embeddings)
86 |         return edge_embeddings
87 | 
88 |     def fuse_features(self, node_embeddings, edge_embeddings, classify_edge_index):
89 |         if self.node_fuse_method == NodeFuseMethod.CONCAT:
90 |             fused_features = torch.cat([node_embeddings[classify_edge_index[:,:,0]],
91 |                                         node_embeddings[classify_edge_index[:,:,1]]], dim=-1)
92 |         elif self.node_fuse_method == NodeFuseMethod.AVERAGE:
93 |             fused_features = (node_embeddings[classify_edge_index[:,:,0]] +
94 |                               node_embeddings[classify_edge_index[:,:,1]])/2.0
95 |         elif self.node_fuse_method == NodeFuseMethod.MULTIPLY:
96 |             fused_features = (node_embeddings[classify_edge_index[:,:,0]] *
97 |                               node_embeddings[classify_edge_index[:,:,1]])
98 |         fused_features = torch.cat([fused_features, edge_embeddings], dim=-1)
99 | 
100 |         if self.include_transformer:
101 |             padding_mask = (classify_edge_index[:,:,0].eq(0) & classify_edge_index[:,:,1].eq(0))
102 |             fused_features[padding_mask] = 0.0
103 |             fused_features = self.transformer_encoder(fused_features, src_key_padding_mask=padding_mask)
104 | 
105 |         return fused_features
106 | 
107 |     def classify_edges(self, fused_features):
108 |         for layer in self.output_mlp:
109 |             fused_features = layer(fused_features)
110 |         output = self.sigmoid(self.output_layer(fused_features))
111 |         return torch.squeeze(output)
112 | 
113 |     def is_heterogenous(self):
114 |         return False
115 | 
116 |     def is_recurrent(self):
117 |         return False
118 | 
119 |     def get_model_type(self):
120 |         return type(self).__name__
121 | 
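For orientation, a minimal shape sketch of `MLPEdgeClassifier` on synthetic CPU tensors (all dimensions here are illustrative; note the fused width 2*32+32 = 96 must stay divisible by the transformer's `nhead=2`):

```python
import torch
from torch_geometric.data import Data

model = MLPEdgeClassifier(node_features_dim=16, edge_features_dim=8)
data = Data(x=torch.randn(10, 16))          # 10 nodes with 16-dim features
query_edges = torch.randint(0, 10, (1, 5, 2))   # (batch, num_query_edges, 2) node-index pairs
query_edge_feats = torch.randn(1, 5, 8)         # (batch, num_query_edges, edge_feat)
probs = model(data, query_edges, query_edge_feats)  # (5,) probabilities after squeeze
```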
122 | class GNN(MLPEdgeClassifier):
123 |     def __init__(self,
124 |                  conv_layers,
125 |                  node_features_dim,
126 |                  edge_features_dim,
127 |                  node_embedding_dim = 32,
128 |                  edge_embedding_dim = 32,
129 |                  node_mlp_hidden_layers = [32, 32],
130 |                  edge_mlp_hidden_layers = [32, 32],
131 |                  output_mlp_hidden_layers = [32, 32],
132 |                  node_fuse_method = NodeFuseMethod.CONCAT,
133 |                  include_transformer=True,
134 |                  add_self_loops=True,
135 |                  reversed_edges=True):
136 |         super().__init__(
137 |             node_features_dim,
138 |             edge_features_dim,
139 |             node_embedding_dim,
140 |             edge_embedding_dim,
141 |             node_mlp_hidden_layers,
142 |             edge_mlp_hidden_layers,
143 |             output_mlp_hidden_layers,
144 |             node_fuse_method,
145 |             include_transformer,
146 |             reversed_edges)
147 | 
148 |         self.conv_layers = nn.ModuleList(conv_layers)
149 |         self.add_self_loops = add_self_loops
150 | 
151 |     def forward(self,
152 |                 pyg_data,
153 |                 classify_edge_index,
154 |                 classify_edge_features):
155 |         if self.add_self_loops:
156 |             pyg_data.edge_index, pyg_data.edge_attr = add_self_loops(pyg_data.edge_index, pyg_data.edge_attr)
157 |         node_features = self.embed_nodes(pyg_data)
158 |         edge_features = self.embed_edges(classify_edge_features)
159 |         fused_features = self.fuse_features(node_features, edge_features, classify_edge_index)
160 |         return self.classify_edges(fused_features)
161 | 
162 |     def embed_nodes(self, pyg_data):
163 |         node_embeddings = pyg_data.x
164 |         for layer in self.node_embed_mlp:
165 |             node_embeddings = layer(node_embeddings)
166 | 
167 |         edge_index = pyg_data.edge_index
168 |         # edge_weight = pyg_data.edge_attr
169 |         for layer in self.conv_layers:
170 |             node_embeddings = layer(node_embeddings, edge_index)
171 |         return node_embeddings
172 | 
173 | class GCN(GNN):
174 |     def __init__(self,
175 |                  node_features_dim,
176 |                  edge_features_dim,
177 |                  node_embedding_dim = 32,
178 |                  edge_embedding_dim = 32,
179 |                  node_mlp_hidden_layers = [32, 32],
180 |                  edge_mlp_hidden_layers = [32, 32],
181 |                  gnn_hidden_layers=[32],
182 |                  output_mlp_hidden_layers = [32, 32],
183 |                  node_fuse_method = NodeFuseMethod.CONCAT,
184 |                  include_transformer=True,
185 |                  add_self_loops=True,
186 |                  reversed_edges=True):
187 |         conv_layers = []
188 |         input_dim = node_embedding_dim
189 |         for layer_size in gnn_hidden_layers+[node_embedding_dim]:
190 |             conv_layers.append(GCNConv(input_dim, layer_size))
191 |             input_dim = layer_size
192 | 
193 |         super().__init__(
194 |             conv_layers,
195 |             node_features_dim,
196 |             edge_features_dim,
197 |             node_embedding_dim,
198 |             edge_embedding_dim,
199 |             node_mlp_hidden_layers,
200 |             edge_mlp_hidden_layers,
201 |             output_mlp_hidden_layers,
202 |             node_fuse_method,
203 |             include_transformer,
204 |             add_self_loops,
205 |             reversed_edges)
206 | 
207 | class HEAT(GNN):
208 |     def __init__(self,
209 |                  node_features_dim,
210 |                  edge_features_dim,
211 |                  node_embedding_dim = 32,
212 |                  edge_embedding_dim = 32,
213 |                  node_mlp_hidden_layers = [32, 32],
214 |                  edge_mlp_hidden_layers = [32, 32],
215 |                  gnn_hidden_layers=[32],
216 |                  output_mlp_hidden_layers = [32, 32],
217 |                  node_fuse_method = NodeFuseMethod.CONCAT,
218 |                  include_transformer=True,
219 |                  add_self_loops=False,
220 |                  reversed_edges=True,
221 |                  num_node_types=5,
222 |                  num_edge_types=5,
223 |                  edge_type_emb_dim=2):
224 |         conv_layers = []
225 |         input_dim = node_embedding_dim
226 |         for layer_size in gnn_hidden_layers+[node_embedding_dim]:
227 |             conv = HEATConv(
228 |                 in_channels=input_dim,
229 |                 out_channels=layer_size,
230 |                 num_node_types=num_node_types,
231 |                 num_edge_types=num_edge_types,
232 |                 edge_type_emb_dim=edge_type_emb_dim,
233 |                 edge_dim=edge_embedding_dim,
234 |                 edge_attr_emb_dim=edge_embedding_dim,
235 |                 dropout=0.05,
236 |                 heads=1
237 |             )
238 |             conv_layers.append(conv)
239 |             input_dim = layer_size
240 |         super().__init__(
241 |             conv_layers,
242 |             node_features_dim,
243 |             edge_features_dim,
244 |             node_embedding_dim,
245 |             edge_embedding_dim,
246 |             node_mlp_hidden_layers,
247 |             edge_mlp_hidden_layers,
248 |             output_mlp_hidden_layers,
249 |             node_fuse_method,
250 |             include_transformer,
251 |             add_self_loops,
252 |             reversed_edges)
253 | 
254 |     def forward(self,
255 |                 pyg_data,
256 |                 classify_edge_index,
257 |                 classify_edge_features):
258 |         if self.add_self_loops:
259 |             pyg_data.edge_index, pyg_data.edge_attr = add_self_loops(pyg_data.edge_index, pyg_data.edge_attr)
260 |         node_types, edge_types, classify_edge_features = self.extract_type_lists(pyg_data, classify_edge_features)
261 |         node_features = self.embed_nodes(pyg_data, node_types, edge_types)
262 |         edge_features = self.embed_edges(classify_edge_features)
263 |         fused_features = self.fuse_features(node_features, edge_features, classify_edge_index)
264 |         return self.classify_edges(fused_features)
265 | 
266 |     def embed_nodes(self, pyg_data, node_types, edge_types):
267 |         edge_index = pyg_data.edge_index
268 |         edge_attr = self.embed_edges(pyg_data.edge_attr)
269 | 
270 |         edge_types = edge_types.cpu().tolist() + [4]*(edge_index.size(1)-len(edge_types))
271 |         edge_types = torch.LongTensor(edge_types).cuda()
272 | 
273 |         node_embeddings = pyg_data.x
274 |         for layer in self.node_embed_mlp:
275 |             node_embeddings = layer(node_embeddings)
276 | 
277 |         for layer in self.conv_layers:
278 |             # Edge types and edge_index don't line up
279 |             node_embeddings = layer(node_embeddings, edge_index, node_types, edge_types, edge_attr).relu()
280 |         return node_embeddings
281 | 
282 |     @staticmethod
283 |     def extract_type_lists(data, loss_edge_features):
284 |         # a bunch of slightly hacky stuff specifically for the HEAT model:
285 |         # it requires a list of node and edge types instead of a dict like other heterogeneous pyg models,
286 |         # so we extract those from the node/edge feature vectors, and remove those dims from the feature vecs
287 |         node_attr = data.x.detach().cpu().numpy()
288 |         edge_attr = data.edge_attr.detach().cpu().numpy()
289 |         # -5 because the five last spots in vector are the type 1-hot code
290 |         edge_types = torch.LongTensor(edge_attr[:,-5:].nonzero()[1])
291 |         node_types = torch.LongTensor(node_attr[:,-5:].nonzero()[1])
292 |         x = torch.Tensor(node_attr[:,:-5])
293 |         edge_attr = torch.Tensor(edge_attr[:,:-5]).cuda()
294 |         node_types = node_types.cuda()
295 |         edge_types = edge_types.cuda()
296 |         data.x = x.cuda()
297 |         data.edge_attr = edge_attr.cuda()
298 |         loss_edge_features = loss_edge_features.detach().cpu().numpy()
299 |         loss_edge_features = torch.Tensor(loss_edge_features[:,:,:-5]).cuda()
300 |         return node_types, edge_types, loss_edge_features
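The one-hot slicing trick used by `extract_type_lists`, shown in isolation (plain numpy, illustrative values): the last five feature dimensions are a one-hot type code, so `nonzero()[1]` recovers the integer type id and the remaining dimensions are the real features.

```python
import numpy as np

feats = np.array([[0.3, 0.7, 0, 0, 1, 0, 0],    # real features [0.3, 0.7], type 2
                  [0.1, 0.2, 0, 0, 0, 0, 1]])   # real features [0.1, 0.2], type 4
type_ids = feats[:, -5:].nonzero()[1]           # -> array([2, 4])
real_feats = feats[:, :-5]                      # -> shape (2, 2)
```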
301 | 
302 | class HGNN(GNN):
303 | 
304 |     def is_heterogenous(self):
305 |         return True
306 | 
307 |     def forward(self,
308 |                 pyg_data,
309 |                 classify_edge_index,
310 |                 classify_edge_features,
311 |                 edge_key):
312 |         if self.add_self_loops:
313 |             pyg_data = T.AddSelfLoops()(pyg_data)
314 | 
315 |         node_features = self.embed_nodes(pyg_data)
316 |         edge_features = self.embed_edges(classify_edge_features)
317 |         fused_features = self.fuse_features(node_features, edge_features, classify_edge_index, edge_key)
318 |         return self.classify_edges(fused_features)
319 | 
320 |     def embed_nodes(self, pyg_data):
321 |         node_embeddings = pyg_data.x_dict
322 |         for layer in self.node_embed_mlp:
323 |             node_embeddings = {x: layer(node_embeddings[x]) for x in node_embeddings}
324 | 
325 |         for conv in self.conv_layers:
326 |             node_embeddings = conv(node_embeddings, pyg_data.edge_index_dict)
327 |             node_embeddings = {x: y.relu() for x, y in node_embeddings.items()}
328 | 
329 |         return node_embeddings
330 | 
331 |     def embed_edges(self, edge_features):
332 |         edge_embeddings = edge_features
333 |         for layer in self.edge_embed_mlp:
334 |             edge_embeddings = layer(edge_embeddings)
335 |         return edge_embeddings
336 | 
337 |     def fuse_features(self, node_embeddings, edge_embeddings, classify_edge_index, classify_edge_key):
338 | 
339 |         if self.node_fuse_method == NodeFuseMethod.CONCAT:
340 |             fused_features = torch.cat([node_embeddings[classify_edge_key[0]][classify_edge_index[:,:,0]],
341 |                                         node_embeddings[classify_edge_key[2]][classify_edge_index[:,:,1]]], dim=-1)
342 |         elif self.node_fuse_method == NodeFuseMethod.AVERAGE:
343 |             fused_features = (node_embeddings[classify_edge_key[0]][classify_edge_index[:,:,0]] +
344 |                               node_embeddings[classify_edge_key[2]][classify_edge_index[:,:,1]])/2.0
345 |         elif self.node_fuse_method == NodeFuseMethod.MULTIPLY:
346 |             fused_features = (node_embeddings[classify_edge_key[0]][classify_edge_index[:,:,0]] *
347 |                               node_embeddings[classify_edge_key[2]][classify_edge_index[:,:,1]])
348 | 
349 |         fused_features = torch.cat([fused_features, edge_embeddings], dim=-1)
350 | 
351 |         if self.include_transformer:
352 |             padding_mask = (classify_edge_index[:,:,0].eq(0) & classify_edge_index[:,:,1].eq(0))
353 |             fused_features[padding_mask] = 0.0
354 |             fused_features = self.transformer_encoder(fused_features, src_key_padding_mask=padding_mask)
355 | 
356 |         return fused_features
357 | 
358 | class HGCN(HGNN):
359 |     def __init__(self,
360 |                  node_features_dim,
361 |                  edge_features_dim,
362 |                  node_embedding_dim = 32,
363 |                  edge_embedding_dim = 32,
364 |                  node_mlp_hidden_layers = [32, 32],
365 |                  edge_mlp_hidden_layers = [32, 32],
366 |                  gnn_hidden_layers=[32],
367 |                  output_mlp_hidden_layers = [32, 32],
368 |                  node_fuse_method = NodeFuseMethod.CONCAT,
369 |                  include_transformer=True,
370 |                  add_self_loops=True,
371 |                  reversed_edges=True):
372 |         conv_layers = []
373 |         metadata = REVERSED_GRAPH_METADATA if reversed_edges else GRAPH_METADATA
374 |         input_dim = node_embedding_dim
375 |         for layer_size in gnn_hidden_layers+[node_embedding_dim]:
376 |             conv_layers_dict = {}
377 |             for key in metadata['edge_types']:
378 |                 if key[0] == key[2]:
379 |                     conv_layers_dict[key] = GCNConv(input_dim, node_embedding_dim)
380 |                 else:
381 |                     conv_layers_dict[key] = SAGEConv(input_dim, node_embedding_dim)
382 |             # one HeteroConv per GNN layer, covering every edge type
383 |             conv_layers.append(HeteroConv(conv_layers_dict, aggr='mean'))
384 |             input_dim = layer_size
385 | 
386 | 
387 |         super().__init__(
388 |             conv_layers,
389 |             node_features_dim,
390 |             edge_features_dim,
391 |             node_embedding_dim,
392 |             edge_embedding_dim,
393 |             node_mlp_hidden_layers,
394 |             edge_mlp_hidden_layers,
395 |             output_mlp_hidden_layers,
396 |             node_fuse_method,
397 |             include_transformer,
398 |             add_self_loops,
399 |             reversed_edges)
400 | 
401 | class HGT(HGNN):
402 |     def __init__(self,
403 |                  node_features_dim,
404 |                  edge_features_dim,
405 |                  node_embedding_dim = 32,
406 |                  edge_embedding_dim = 32,
407 |                  node_mlp_hidden_layers = [32, 32],
408 |                  gnn_hidden_layers=[32],
409 |                  edge_mlp_hidden_layers = [32, 32],
410 |                  output_mlp_hidden_layers = [32, 32],
411 |                  node_fuse_method = NodeFuseMethod.CONCAT,
412 |                  include_transformer=True,
413 |                  add_self_loops=True,
414 |                  reversed_edges=True):
415 |         metadata = REVERSED_GRAPH_METADATA if reversed_edges else GRAPH_METADATA
416 |         metadata = (metadata['node_types'], metadata['edge_types'])
417 |         conv_layers = []
418 |         input_dim = node_embedding_dim
419 |         for layer_size in gnn_hidden_layers+[node_embedding_dim]:
420 |             conv_layers.append(HGTConv(input_dim, node_embedding_dim, metadata))
421 |             input_dim = layer_size
422 | 
423 |         super().__init__(
424 |             conv_layers,
425 |             node_features_dim,
426 |             edge_features_dim,
427 |             node_embedding_dim,
428 |             edge_embedding_dim,
429 |             node_mlp_hidden_layers,
430 |             edge_mlp_hidden_layers,
431 |             output_mlp_hidden_layers,
432 |             node_fuse_method,
433 |             include_transformer,
434 |             add_self_loops,
435 |             reversed_edges)
436 | 
437 | class HAN(HGNN):
438 |     def __init__(self,
439 |                  node_features_dim,
440 |                  edge_features_dim,
441 |                  node_embedding_dim = 32,
442 |                  edge_embedding_dim = 32,
443 |                  node_mlp_hidden_layers = [32, 32],
444 |                  edge_mlp_hidden_layers = [32, 32],
445 |                  gnn_hidden_layers=[32],
446 |                  output_mlp_hidden_layers = [32, 32],
447 |                  node_fuse_method = NodeFuseMethod.CONCAT,
448 |                  include_transformer=True,
449 |                  add_self_loops=True,
450 |                  reversed_edges=True):
451 |         metadata = REVERSED_GRAPH_METADATA if reversed_edges else GRAPH_METADATA
452 |         metadata = (metadata['node_types'], metadata['edge_types'])
453 |         conv_layers = []
454 |         input_dim = node_embedding_dim
455 |         for layer_size in gnn_hidden_layers+[node_embedding_dim]:
456 |             conv_layers.append(HANConv(input_dim, node_embedding_dim, metadata))
457 |             input_dim = layer_size
458 | 
459 |         super().__init__(
460 |             conv_layers,
461 |             node_features_dim,
462 |             edge_features_dim,
463 |             node_embedding_dim,
464 |             edge_embedding_dim,
465 |             node_mlp_hidden_layers,
466 |             edge_mlp_hidden_layers,
467 |             output_mlp_hidden_layers,
468 |             node_fuse_method,
469 |             include_transformer,
470 |             add_self_loops,
471 |             reversed_edges)
472 | 
473 | def make_model(cfg, node_featurizer, edge_featurizer, load_model=False):
474 |     node_features_dim = node_featurizer.get_feature_size()
475 |     edge_features_dim = edge_featurizer.get_feature_size()
476 |     if cfg.add_num_nodes:
477 |         edge_features_dim+=1
478 |     if cfg.add_num_edges:
479 |         edge_features_dim+=1
480 |     if cfg.model_type == 'heat':
481 |         node_features_dim = node_features_dim-5
482 |         edge_features_dim = edge_features_dim-5
483 | 
484 |     assert cfg.node_fuse_method in [x.value for x in NodeFuseMethod]
485 |     node_fuse_method = [x for x in NodeFuseMethod if x.value==cfg.node_fuse_method][0]
486 |     args = {'node_features_dim':node_features_dim,
487 |             'edge_features_dim':edge_features_dim,
488 |             'node_embedding_dim':cfg.node_embedding_dim,
489 |             'edge_embedding_dim':cfg.edge_embedding_dim,
490 |             'node_mlp_hidden_layers':cfg.node_mlp_hidden_layers,
491 |             'edge_mlp_hidden_layers':cfg.edge_mlp_hidden_layers,
492 |             'output_mlp_hidden_layers':cfg.output_mlp_hidden_layers,
493 |             'node_fuse_method':node_fuse_method,
494 |             'include_transformer':cfg.include_transformer,
495 |             'reversed_edges':cfg.reversed_edges}
496 | 
497 |     if 'add_self_loops' in cfg and cfg.model_type != 'mlp':
498 |         args['add_self_loops'] = cfg.add_self_loops
499 |     if cfg.model_type != 'mlp':
500 |         args['gnn_hidden_layers'] = cfg.gnn_hidden_layers
501 |     model = None
502 |     if cfg.model_type == 'mlp':
503 |         model = MLPEdgeClassifier(**args)
504 |     elif cfg.model_type == 'gcn':
505 |         model = GCN(**args)
506 |     elif cfg.model_type == 'heat':
507 |         args['edge_type_emb_dim'] = cfg.edge_type_emb_dim
508 |         args['num_node_types'] = 5
509 |         args['num_edge_types'] = 5
510 |         model = HEAT(**args)
511 |     elif cfg.model_type == 'hgt':
512 |         model = HGT(**args)
513 |     elif cfg.model_type == 'hgcn':
514 |         model = HGCN(**args)
515 |     elif cfg.model_type == 'han':
516 |         model = HAN(**args)
517 |     else:
518 |         raise Exception(f"Invalid model type: {cfg.model_type}")
519 | 
520 |     if load_model:
521 |         model_path = create_path_to_model(cfg, node_featurizer, edge_featurizer)
522 |         model.load_state_dict(torch.load(model_path))
523 |         model.cuda()
524 | 
525 |     return model
526 | 
527 | def create_path_to_model(cfg, node_featurizer, edge_featurizer):
528 |     Path(cfg.models_dir).mkdir(exist_ok=True)
529 |     model_path = '%s/%s.pt'%(cfg.models_dir,
530 |                              cfg.model_name)
531 |     return model_path
532 | 
533 | def compute_output(model, input_data, edges, edge_features, edge_key=None):
534 |     if model.is_heterogenous():
535 |         out = model(input_data, edges, edge_features, edge_key)
536 |     else:
537 |         out = model(input_data, edges, edge_features)
538 |     return out
539 | 
540 | '''
541 | class RecurrentGCNNet(torch.nn.Module):
542 |     def __init__(self, node_features_dim=96, hidden_dim=32, gcn_out_dim=16, use_classifier_mlp=True):
543 |         super().__init__()
544 |         self.conv1 = GCLSTM(node_features_dim, hidden_dim, 4)
545 |         self.conv2 = GCNConv(hidden_dim, gcn_out_dim)
546 |         self.use_classifier_mlp = use_classifier_mlp
547 |         if self.use_classifier_mlp:
548 |             self.mlp = nn.Linear(gcn_out_dim*2, 1)
549 |         self.sigmoid = torch.nn.Sigmoid()
550 | 
551 |     def forward(self, x, edge_index, classify_edge_index, h, c):
552 |         h, c = self.conv1(x, edge_index, H=h, C=c)
553 |         x = h.relu()
554 |         node_features = self.conv2(x, edge_index)
555 |         if self.use_classifier_mlp:
556 |             node_pair_features = torch.cat([node_features[classify_edge_index[0]],
557 |                                             node_features[classify_edge_index[1]]], dim=-1)
558 |             mlp_output = torch.squeeze(self.mlp(node_pair_features))
559 |             out = self.sigmoid(mlp_output)
560 |         else:
561 |             out = self.sigmoid((node_features[classify_edge_index[0]] * \
562 |                                 node_features[classify_edge_index[1]]).sum(dim=-1)).view(-1)
563 |         return out, h, c
564 | 
565 | def compute_output_recurrent(model, input_data, edges, edge_features, h, c):
566 |     return model(input_data.x, input_data.edge_index, edges, edge_features, h, c)
567 | '''
568 | 
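`make_model` is config-driven end to end. A minimal sketch with an OmegaConf dict and a stub featurizer (both are illustrative stand-ins; the real featurizers come from `memsearch.dataset.make_featurizers`):

```python
from omegaconf import OmegaConf

class _StubFeaturizer:
    """Hypothetical stand-in exposing the one method make_model needs."""
    def get_feature_size(self):
        return 16

cfg = OmegaConf.create({
    'model_type': 'gcn',
    'add_num_nodes': False, 'add_num_edges': False,
    'node_embedding_dim': 32, 'edge_embedding_dim': 32,
    'node_mlp_hidden_layers': [32, 32], 'edge_mlp_hidden_layers': [32, 32],
    'output_mlp_hidden_layers': [32, 32], 'gnn_hidden_layers': [32],
    'node_fuse_method': 'concat', 'include_transformer': True,
    'reversed_edges': True, 'add_self_loops': True,
})
model = make_model(cfg, _StubFeaturizer(), _StubFeaturizer())  # returns a GCN
```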
-------------------------------------------------------------------------------- /memsearch/rl/complex_input_network.py: --------------------------------------------------------------------------------
1 | from gym.spaces import Box, Discrete, MultiDiscrete
2 | import numpy as np
3 | import tree  # pip install dm_tree
4 | 
5 | # TODO (sven): add IMPALA-style option.
6 | # from ray.rllib.examples.models.impala_vision_nets import TorchImpalaVisionNet
7 | from ray.rllib.models.torch.misc import (
8 |     normc_initializer as torch_normc_initializer,
9 |     SlimFC,
10 | )
11 | from ray.rllib.models.catalog import ModelCatalog
12 | from ray.rllib.models.modelv2 import ModelV2, restore_original_dimensions
13 | from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
14 | from ray.rllib.models.utils import get_filter_config
15 | from ray.rllib.policy.sample_batch import SampleBatch
16 | from ray.rllib.utils.annotations import override
17 | from ray.rllib.utils.framework import try_import_torch
18 | from ray.rllib.utils.spaces.space_utils import flatten_space
19 | from ray.rllib.utils.torch_utils import one_hot
20 | 
21 | torch, nn = try_import_torch()
22 | 
23 | 
24 | class ComplexInputNetwork(TorchModelV2, nn.Module):
25 |     """TorchModelV2 concat'ing CNN outputs to flat input(s), followed by FC(s).
26 | 
27 |     Note: This model should be used for complex (Dict or Tuple) observation
28 |     spaces that have one or more image components.
29 | 
30 |     The data flow is as follows:
31 | 
32 |     `obs` (e.g. Tuple[img0, img1, discrete0]) -> `CNN0 + CNN1 + ONE-HOT`
33 |     `CNN0 + CNN1 + ONE-HOT` -> concat all flat outputs -> `out`
34 |     `out` -> (optional) FC-stack -> `out2`
35 |     `out2` -> action (logits) and value heads.
36 |     """
37 | 
38 |     def __init__(self, obs_space, action_space, num_outputs, model_config, name):
39 |         self.original_space = (
40 |             obs_space.original_space
41 |             if hasattr(obs_space, "original_space")
42 |             else obs_space
43 |         )
44 | 
45 |         self.processed_obs_space = (
46 |             self.original_space
47 |             if model_config.get("_disable_preprocessor_api")
48 |             else obs_space
49 |         )
50 | 
51 |         nn.Module.__init__(self)
52 |         TorchModelV2.__init__(
53 |             self, self.original_space, action_space, num_outputs, model_config, name
54 |         )
55 | 
56 |         self.flattened_input_space = flatten_space(self.original_space)
57 | 
58 |         # Atari type CNNs or IMPALA type CNNs (with residual layers)?
59 |         # self.cnn_type = self.model_config["custom_model_config"].get(
60 |         #     "conv_type", "atari")
61 | 
62 |         # Build the CNN(s) given obs_space's image components.
63 |         self.cnns = nn.ModuleDict()
64 |         self.one_hot = nn.ModuleDict()
65 |         self.flatten_dims = {}
66 |         self.flatten = nn.ModuleDict()
67 |         concat_size = 0
68 |         for i, component in enumerate(self.flattened_input_space):
69 |             i = str(i)
70 |             # Image space.
71 |             if len(component.shape) == 3 and isinstance(component, Box):
72 |                 config = {
73 |                     "conv_filters": model_config["conv_filters"]
74 |                     if "conv_filters" in model_config
75 |                     else get_filter_config(component.shape),
76 |                     "conv_activation": model_config.get("conv_activation"),
77 |                     "post_fcnet_hiddens": [],
78 |                 }
79 |                 # if self.cnn_type == "atari":
80 |                 self.cnns[i] = ModelCatalog.get_model_v2(
81 |                     component,
82 |                     action_space,
83 |                     num_outputs=None,
84 |                     model_config=config,
85 |                     framework="torch",
86 |                     name="cnn_{}".format(i),
87 |                 )
88 |                 # TODO (sven): add IMPALA-style option.
89 |                 # else:
90 |                 #     cnn = TorchImpalaVisionNet(
91 |                 #         component,
92 |                 #         action_space,
93 |                 #         num_outputs=None,
94 |                 #         model_config=config,
95 |                 #         name="cnn_{}".format(i))
96 | 
97 |                 concat_size += self.cnns[i].num_outputs
98 |                 self.add_module("cnn_{}".format(i), self.cnns[i])
99 |             # Discrete|MultiDiscrete inputs -> One-hot encode.
100 |             elif isinstance(component, (Discrete, MultiDiscrete)):
101 |                 if isinstance(component, Discrete):
102 |                     size = component.n
103 |                 else:
104 |                     size = np.sum(component.nvec)
105 |                 config = {
106 |                     "fcnet_hiddens": model_config["fcnet_hiddens"],
107 |                     "fcnet_activation": model_config.get("fcnet_activation"),
108 |                     "post_fcnet_hiddens": [],
109 |                 }
110 |                 self.one_hot[i] = ModelCatalog.get_model_v2(
111 |                     Box(-1.0, 1.0, (size,), np.float32),
112 |                     action_space,
113 |                     num_outputs=None,
114 |                     model_config=config,
115 |                     framework="torch",
116 |                     name="one_hot_{}".format(i),
117 |                 )
118 |                 concat_size += self.one_hot[i].num_outputs
119 |                 self.add_module("one_hot_{}".format(i), self.one_hot[i])
120 |             # Everything else (1D Box).
121 |             else:
122 |                 size = int(np.prod(component.shape))
123 |                 config = {
124 |                     "fcnet_hiddens": model_config["fcnet_hiddens"],
125 |                     "fcnet_activation": model_config.get("fcnet_activation"),
126 |                     "post_fcnet_hiddens": [],
127 |                 }
128 |                 self.flatten[i] = ModelCatalog.get_model_v2(
129 |                     Box(-1.0, 1.0, (size,), np.float32),
130 |                     action_space,
131 |                     num_outputs=None,
132 |                     model_config=config,
133 |                     framework="torch",
134 |                     name="flatten_{}".format(i),
135 |                 )
136 |                 self.flatten_dims[i] = size
137 |                 concat_size += self.flatten[i].num_outputs
138 |                 self.add_module("flatten_{}".format(i), self.flatten[i])
139 | 
140 |         # Optional post-concat FC-stack.
141 |         post_fc_stack_config = {
142 |             "fcnet_hiddens": model_config.get("post_fcnet_hiddens", []),
143 |             "fcnet_activation": model_config.get("post_fcnet_activation", "relu"),
144 |         }
145 |         self.post_fc_stack = ModelCatalog.get_model_v2(
146 |             Box(float("-inf"), float("inf"), shape=(concat_size,), dtype=np.float32),
147 |             self.action_space,
148 |             None,
149 |             post_fc_stack_config,
150 |             framework="torch",
151 |             name="post_fc_stack",
152 |         )
153 | 
154 |         # Actions and value heads.
155 |         self.logits_layer = None
156 |         self.value_layer = None
157 |         self._value_out = None
158 | 
159 |         if num_outputs:
160 |             # Action-distribution head.
161 |             self.logits_layer = SlimFC(
162 |                 in_size=self.post_fc_stack.num_outputs,
163 |                 out_size=num_outputs,
164 |                 activation_fn=None,
165 |                 initializer=torch_normc_initializer(0.01),
166 |             )
167 |             # Create the value branch model.
168 |             self.value_layer = SlimFC(
169 |                 in_size=self.post_fc_stack.num_outputs,
170 |                 out_size=1,
171 |                 activation_fn=None,
172 |                 initializer=torch_normc_initializer(0.01),
173 |             )
174 |         else:
175 |             self.num_outputs = concat_size
176 | 
177 |     @override(ModelV2)
178 |     def forward(self, input_dict, state, seq_lens):
179 |         if SampleBatch.OBS in input_dict and "obs_flat" in input_dict:
180 |             orig_obs = input_dict[SampleBatch.OBS]
181 |         else:
182 |             orig_obs = restore_original_dimensions(
183 |                 input_dict[SampleBatch.OBS], self.processed_obs_space, tensorlib="torch"
184 |             )
185 |         # Push observations through the different components
186 |         # (CNNs, one-hot + FC, etc.).
187 |         outs = []
188 |         for i, component in enumerate(tree.flatten(orig_obs)):
189 |             i = str(i)
190 |             if i in self.cnns:
191 |                 cnn_out, _ = self.cnns[i](SampleBatch({SampleBatch.OBS: component}))
192 |                 outs.append(cnn_out)
193 |             elif i in self.one_hot:
194 |                 if component.dtype in [
195 |                     torch.int8,
196 |                     torch.int16,
197 |                     torch.int32,
198 |                     torch.int64,
199 |                     torch.uint8,
200 |                 ]:
201 |                     one_hot_in = {
202 |                         SampleBatch.OBS: one_hot(
203 |                             component, self.flattened_input_space[int(i)]
204 |                         )
205 |                     }
206 |                 else:
207 |                     one_hot_in = {SampleBatch.OBS: component}
208 |                 one_hot_out, _ = self.one_hot[i](SampleBatch(one_hot_in))
209 |                 outs.append(one_hot_out)
210 |             else:
211 |                 nn_out, _ = self.flatten[i](
212 |                     SampleBatch(
213 |                         {
214 |                             SampleBatch.OBS: torch.reshape(
215 |                                 component, [-1, self.flatten_dims[i]]
216 |                             )
217 |                         }
218 |                     )
219 |                 )
220 |                 outs.append(nn_out)
221 | 
222 |         # Concat all outputs and the non-image inputs.
223 |         out = torch.cat(outs, dim=1)
224 |         # Push through (optional) FC-stack (this may be an empty stack).
225 |         out, _ = self.post_fc_stack(SampleBatch({SampleBatch.OBS: out}))
226 | 
227 |         # No logits/value branches.
228 |         if self.logits_layer is None:
229 |             return out, []
230 | 
231 |         # Logits- and value branches.
232 |         logits, values = self.logits_layer(out), self.value_layer(out)
233 |         self._value_out = torch.reshape(values, [-1])
234 |         return logits, []
235 | 
236 |     @override(ModelV2)
237 |     def value_function(self):
238 |         return self._value_out
239 | 
240 | 
-------------------------------------------------------------------------------- /memsearch/running.py: --------------------------------------------------------------------------------
1 | import os
2 | import functools
3 | import numpy as np
4 | import warnings
5 | import itertools
6 | import random
7 | from multiprocessing import Pool, RLock
8 | from tqdm import tqdm
9 | from pathlib import Path
10 | from tensorboardX import SummaryWriter
11 | from memsearch.scene import make_scene_sampler_and_evolver
12 | from memsearch.agents import make_agent
13 | from memsearch.dataset import save_networkx_graph, make_featurizers
14 | from memsearch.tasks import make_task, TaskType
15 | 
16 | def run(cfg,
17 |         logger,
18 |         output_dir,
19 |         task_type = TaskType.PREDICT_LOC,
20 |         num_steps = 10000,
21 |         agent_type = 'random',
22 |         save_graphs_path = None,
23 |         save_images = False,
24 |         for_data_collection = False,
25 |         save_sgm_graphs = False,
26 |         worker_num = 0):
27 |     if not for_data_collection:
28 |         writer = SummaryWriter(os.path.join(output_dir, agent_type))
29 |     else:
30 |         # Can ignore warning from networkx about pickle - remove if upgrading to networkx 3.0
31 |         warnings.filterwarnings("ignore", category=DeprecationWarning)
32 |         writer = None
33 |     images_dir = os.path.join(output_dir,'images')
34 |     Path(images_dir).mkdir(exist_ok=True)
35 | 
36 |     scene_sampler, scene_evolver = make_scene_sampler_and_evolver(cfg.scene_gen)
37 |     node_featurizer, edge_featurizer = make_featurizers(cfg.model, for_data_collection, cfg.task.num_steps)
38 | 
39 |     agents_cfg = cfg.agents
40 |     if 'sgm' in agent_type or for_data_collection:
41 |         model_config = cfg.model
42 |     else:
43 |         model_config = None
44 |     agent = make_agent(agents_cfg, agent_type, task_type, node_featurizer, edge_featurizer, scene_sampler, scene_evolver, for_data_collection, model_config)
45 |     sgm_agent = None
46 |     if 'sgm' in agent_type:
47 |         sgm_agent = agent
48 |     elif for_data_collection:
49 |         sgm_agent = agent.sgm_agent
50 | 
51 |     task = make_task(
52 |         scene_sampler,
53 |         scene_evolver,
54 |         task_type,
55 |         eps_per_scene=cfg.task.eps_per_scene)
56 | 
57 |     score_histories = []
58 |     score_history = []
59 |     sgm_graphs = []
60 |     scene_num, scene_steps, scene_score, total_score, total_cost = 0, 0, 0, 0, 0
61 | 
62 |     current_scene = None
63 |     pbar = tqdm(range(num_steps), position=worker_num, leave = not for_data_collection)
64 | 
65 |     for step in pbar:
66 |         if task.scene is None or current_scene != task.scene:
67 |             scene_steps, scene_score = 0, 0
68 |             if current_scene is not None:
69 |                 if writer is not None:
70 |                     writer.add_scalar('average_score', average_score, scene_num)
71 |                 score_histories.append(np.array(score_history))
72 |             else:
73 |                 query_nodes = task.reset()
74 |             score_history = []
75 |             scene_num+=1
76 |             current_scene = task.scene
77 |             agent.transition_to_new_scene(current_scene)
78 | 
79 |             if save_images and scene_num < 3:
80 |                 scene_sampler.current_priors_graph.save_png(f'{images_dir}/sampler_probs_{agent_type}_{scene_num}.png', colorize_edges=True)
81 |                 scene_evolver.current_priors_graph.save_png(f'{images_dir}/evolver_probs_{agent_type}_{scene_num}.png', colorize_edges=True)
82 | 
83 | 
84 |         if task_type == TaskType.PREDICT_ENV_DYNAMICS:
85 |             prediction = agent.make_predictions(task_type, current_scene, query_nodes, top_k=cfg.task.top_k)
86 |             obs, score, done, info = task.step(prediction, top_k=cfg.task.top_k, agent_type=agent_type)
87 |         elif task_type == TaskType.FIND_OBJECT:
88 |             prediction = agent.make_predictions(task_type, current_scene, query_nodes, top_k=cfg.task.top_k)
89 |             obs, score, done, info = task.step(prediction, agent=agent, current_scene=current_scene, query_nodes=query_nodes, max_attempts=cfg.task.max_attempts, top_k=cfg.task.top_k)
90 |             total_cost += info['acc_cost']
91 | 
92 |         else:
93 |             prediction = agent.make_predictions(task_type, current_scene, query_nodes)
94 |             obs, score, done, info = task.step(prediction)
95 | 
96 |         if for_data_collection:
97 |             agent.mark_true_edges(current_scene.scene_graph)
98 |             has_true_edge = False
99 |             for edge in agent.sgm_agent.sgm_graph.get_edges():
100 |                 if edge.currently_true and edge.is_query_edge:
101 |                     has_true_edge = True
102 | 
103 |             if not has_true_edge:
104 |                 logger.warning('None of the sgm edges considered are correct.')
105 | 
106 |             if save_sgm_graphs:
107 |                 save_path = save_graphs_path+f'/{worker_num}_{scene_num}_{scene_steps}.pickle'
108 |                 sgm_graphs.append(save_path)
109 |                 save_networkx_graph(save_path,
110 |                                     sgm_agent.sgm_graph,
111 |                                     node_featurizer,
112 |                                     edge_featurizer)
113 | 
114 |             else:
115 |                 sgm_graphs.append(sgm_agent.sgm_graph.copy())
116 | 
117 |             agent.remove_hyp_query_edges()
118 | 
119 |         if save_images and step % 10 == 0 and scene_num < 3:
120 |             current_scene.scene_graph.save_pngs(f'{images_dir}/scene_{agent_type}_{scene_num}_{scene_steps}.png')
121 |             if 'sgm' in agent_type:
122 |                 agent.sgm_graph.save_pngs(f'{images_dir}/sgm_{scene_num}_{scene_steps}.png')
123 |             elif for_data_collection:
124 |                 agent.sgm_agent.sgm_graph.save_pngs(f'{images_dir}/sgm_{scene_num}_{scene_steps}.png')
125 | 
126 |         agent.receive_observation(obs)
127 | 
128 |         scene_steps+=1
129 |         scene_score+=float(score)
130 |         total_score+=float(score)
131 |         average_cost=total_cost/(step+1)
132 |         score_history.append(float(score))
133 |         average_score=total_score/(step+1)
134 |         scene_average_score=scene_score/(scene_steps)
135 | 
136 |         # short agent type to make things line up nicely
137 |         if agent_type == 'sgm':
138 |             short_agent_type = 'sgm   '
139 |         else:
140 |             short_agent_type = agent_type[:6]
141 |         if for_data_collection:
142 |             pbar.set_description(("Agent %s | Step %d | Scene %d with %d Nodes | sgm has %d nodes & %.0f edges | Scene Accuracy=%.2f | "+\
143 |                 "Overall Accuracy=%.2f")%(short_agent_type,
144 |                 step+1,
145 |                 scene_num,
146 |                 len(current_scene.scene_graph.nodes),
147 |                 float(len(sgm_agent.sgm_graph.nodes)),
148 |                 float(len(sgm_agent.sgm_graph.get_edges())),
149 |                 scene_average_score,
150 |                 average_score))
151 |         elif task_type == TaskType.FIND_OBJECT:
152 |             pbar.set_description(("Agent %s | Overall Accuracy=%.2f | Average Cost=%.2f")%(short_agent_type, average_score, average_cost))
153 |         else:
154 |             pbar.set_description(("Agent %s | Overall Accuracy=%.2f")%(short_agent_type, average_score))
155 | 
156 |         if done:
157 |             query_nodes = task.reset()
158 |             agent.step()
159 |     pbar.close()
160 |     if for_data_collection:
161 |         if worker_num == 0:
162 |             node_featurizer.save_text_embedding_dict()
163 |         return sgm_graphs
164 |     else:
165 |         if 'sgm' in agent_type:
166 |             node_featurizer.save_text_embedding_dict()
167 |         return score_histories
168 | 
169 | def collect_data(cfg,
170 |                  logger,
171 |                  output_dir,
172 |                  num_steps = 10000,
173 |                  agent_type = 'random',
174 |                  save_graphs_path = None,
175 |                  save_sgm_graphs = True):
176 |     task_type = TaskType.PREDICT_LOC
177 |     if cfg.collect_data_num_workers == 1:
178 |         data = run(cfg,
179 |                    logger,
180 |                    output_dir,
181 |                    task_type,
182 |                    num_steps,
183 |                    agent_type,
184 |                    save_graphs_path,
185 |                    cfg.save_images,
186 |                    for_data_collection = True,
187 |                    save_sgm_graphs = save_sgm_graphs)
188 |     else:
189 |         assert num_steps % cfg.collect_data_num_workers == 0
190 |         num_steps = int(num_steps / cfg.collect_data_num_workers)
191 |         run_in_parallel = functools.partial(run, cfg, logger, output_dir, task_type, num_steps,
192 |                                             agent_type, save_graphs_path, cfg.save_images, True,
193 |                                             save_sgm_graphs)
194 |         tqdm.set_lock(RLock())
195 |         with Pool(processes = cfg.collect_data_num_workers) as pool:
196 |             data_list = pool.map(run_in_parallel, list(range(cfg.collect_data_num_workers)))
197 |         # flatten into one list
198 |         data = list(itertools.chain.from_iterable(data_list))
199 |     return data
200 | 
201 | def eval_agent(cfg,
202 |                logger,
203 |                output_dir,
204 |                task_type = TaskType.PREDICT_LOC,
205 |                agent_type = 'random'):
206 |     worker_num = cfg.agents.agent_types.index(agent_type)
207 |     random.seed(0)
208 |     return run(cfg,
209 |                logger,
210 |                output_dir,
211 |                task_type=task_type,
212 |                num_steps=cfg.task.num_steps,
213 |                agent_type=agent_type,
214 |                save_images=cfg.save_images,
215 |                for_data_collection=False,
216 |                worker_num=worker_num)
217 | 
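A sketch of how a `scripts/eval.py`-style entry point might drive `eval_agent`: Hydra composes the config tree and each agent type is evaluated with the same fixed seed. The `config_path`/`config_name` values are assumptions based on the `configs/` directory, not verified against the actual script.

```python
import logging
import hydra

@hydra.main(config_path='../configs', config_name='config', version_base=None)
def main(cfg):
    logger = logging.getLogger(__name__)
    for agent_type in cfg.agents.agent_types:
        # returns one score history (np.array) per completed scene
        score_histories = eval_agent(cfg, logger, 'outputs',
                                     task_type=TaskType.PREDICT_LOC,
                                     agent_type=agent_type)

if __name__ == '__main__':
    main()
```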
-------------------------------------------------------------------------------- /memsearch/tasks.py: --------------------------------------------------------------------------------
1 | import gym
2 | import random
3 | from memsearch.util import sample_from_dict
4 | from memsearch.graphs import NodeType
5 | from enum import Enum
6 | from abc import abstractmethod
7 | 
8 | from memsearch.scene import get_cost_between_nodes
9 | 
10 | 
11 | class TaskType(Enum):
12 |     PREDICT_LOC = "predict_location"
13 |     PREDICT_ENV_DYNAMICS = "predict_env_dynamics"
14 |     FIND_OBJECT = "find_object"
15 | 
16 | class Task(gym.Env):
17 |     """A goal-based node-choice environment:
18 |     choose the correct parent node for a given goal node.
19 |     The observation and action spaces are implemented with integers
20 |     corresponding to nodes in the scene graph.
21 |     """
22 | 
23 |     def __init__(self,
24 |                  scene_sampler,
25 |                  scene_evolver,
26 |                  eps_per_scene=250,
27 |                  scene_obs_percent=0.1,
28 |                  pct_query_moved_object=0.5):
29 |         self.scene_sampler = scene_sampler
30 |         self.scene_evolver = scene_evolver
31 |         self.eps_per_scene = eps_per_scene
32 |         self.pct_query_moved_object = pct_query_moved_object
33 |         """
34 |         self.action_space = spaces.Discrete(max_nodes)
35 |         self.observation_space = spaces.Dict(dict(
36 |             node_features=spaces.Box(low=0, high=1, shape=(max_nodes,features_per_node), dtype=np.float32),
37 |             edges=spaces.Box(low=0, high=1, shape=(2,max_edges), dtype=np.float32),
38 |             num_nodes=spaces.Discrete(max_nodes),
39 |             num_edges=spaces.Discrete(max_edges),
40 |             achieved_goal=spaces.Box(low=0, high=1, shape=(goal_embedding_size,), dtype=np.int32),
41 |             desired_goal=spaces.Box(low=0, high=1, shape=(goal_embedding_size,), dtype=np.int32),
42 |             action_mask=spaces.Box(0, 1, shape=(max_avail_actions, )),
43 |         ))
44 |         """
45 |         self.current_ep = 0
46 |         self.goal_node = None
47 |         self.scene_obs_percent = scene_obs_percent
48 |         self.scene = None
49 | 
50 |     def pick_query(self, moved_objects, num_to_choose=1):
51 |         all_object_nodes = self.scene.scene_graph.get_nodes_with_type(NodeType.OBJECT)
52 |         object_move_probs = self.scene.scene_graph.get_object_node_move_probs()
53 |         if num_to_choose == 1:
54 |             if random.random() > self.pct_query_moved_object:
55 |                 query = sample_from_dict(object_move_probs)
56 |             else:
57 |                 sample_from = all_object_nodes if len(moved_objects) == 0 else moved_objects
58 |                 query = random.choice(sample_from)
59 |             self.goal_node = query
60 |         else:
61 |             query = []
62 |             while len(query) < num_to_choose:
63 |                 if random.random() > self.pct_query_moved_object:
64 |                     choice = sample_from_dict(object_move_probs)
65 |                 else:
66 |                     sample_from = all_object_nodes if len(moved_objects) == 0 else moved_objects
67 |                     choice = random.choice(sample_from)
68 |                 if choice in query:
69 |                     continue
70 |                 if choice in moved_objects:
71 |                     moved_objects.remove(choice)
72 |                 query.append(choice)
73 |         return query
74 | 
75 |     def reset(self):
76 |         if self.current_ep==0 or self.current_ep % self.eps_per_scene == 0:
77 |             self.scene = self.scene_sampler.sample()
78 |             self.scene_evolver.set_new_scene(self.scene)
79 |         moved_objects = self.scene_evolver.evolve()
80 |         for edge in self.scene.scene_graph.get_edges():
81 |             edge.age+=1
82 |         for node in moved_objects:
83 |             node.times_moved+=1
84 |         self.current_ep+=1
85 |         return moved_objects
86 | 
87 |     @abstractmethod
88 |     def get_task_type(self):
89 |         raise NotImplementedError()
90 | 
91 |     @abstractmethod
92 |     def step(self, action):
93 |         """
94 |         Run one timestep of the environment's dynamics.
95 |         Accepts an action and returns a tuple (observation, reward, done, info).
96 |         Args:
97 |             action (Node): the node to explore
98 |         Returns:
99 |             observation (object): agent's observation of the current environment
100 |             reward (float): amount of reward returned after previous action
101 |             done (boolean): whether the episode has ended, in which case further step() calls will return undefined results
102 |             info (dict): contains auxiliary diagnostic information (helpful for debugging, and sometimes learning)
103 |         """
104 |         raise NotImplementedError()
105 | 
106 |     def get_random_furniture_node(self):
107 |         all_furniture_nodes = self.scene.scene_graph.get_nodes(NodeType.FURNITURE)
108 |         return random.choice(all_furniture_nodes)
109 | 
110 | class PredictLocationsTask(Task):
111 |     def __init__(self,
112 |                  scene_sampler,
113 |                  scene_evolver,
114 |                  eps_per_scene=250,
115 |                  scene_obs_percent=0.1,
116 |                  num_objs_to_predict=1):
117 |         super().__init__(scene_sampler,
118 |                          scene_evolver,
119 |                          eps_per_scene,
120 |                          scene_obs_percent)
121 |         assert num_objs_to_predict >= 1, "Cannot predict the location of fewer than one object."
122 |         self.num_objs_to_predict = num_objs_to_predict
123 | 
124 |     def get_task_type(self):
125 |         return TaskType.PREDICT_LOC
126 | 
127 |     def reset(self):
128 |         moved_objects = super().reset()
129 |         query = self.pick_query(moved_objects, self.num_objs_to_predict)
130 |         return query
131 | 
132 |     def step(self, action):
133 |         reward = 0
134 |         if not isinstance(action, dict):
135 |             action = {self.goal_node.description: action}
136 |         for (goal_node_description, predicted_parent_node) in action.items():
137 |             for child in predicted_parent_node.get_children_nodes():
138 |                 if child.description == goal_node_description:
139 |                     reward += 1
140 |         reward = float(reward / len(action))
141 |         done = True
142 |         info = {}
143 |         obs = list(action.values())[0] # gets to observe the node
144 |         return obs, reward, done, info
145 | 
146 | class FindObjTask(Task):
147 | 
148 |     def reset(self):
149 |         moved_objects = super().reset()
150 |         query = self.pick_query(moved_objects, 1)
151 |         return query
152 | 
153 |     def step(self, action, agent=None, current_scene=None, query_nodes=None, top_k=3, max_attempts=10):
154 |         def is_obj_at_node(node):
155 |             for child in node.get_children_nodes():
156 |                 if child.description == self.goal_node.description:
157 |                     return True
158 |             return False
159 | 
160 |         # Iteratively ask agent for predictions and give it observations
161 | 
162 |         num_attempts = 1
163 |         pred_node = action
164 |         all_furniture_nodes = self.scene.scene_graph.get_nodes(NodeType.FURNITURE)
165 |         if not pred_node:
166 |             pred_node = random.choice(all_furniture_nodes)
167 |         visited_nodes = []
168 |         visited_node_ids = []
169 |         acc_cost = 0
170 |         curr_node = self.get_random_furniture_node()
171 |         while num_attempts <= max_attempts:
172 |             agent.receive_observation(pred_node)
173 |             acc_cost += get_cost_between_nodes(curr_node, pred_node, self.scene.scene_graph)
174 | 
175 |             visited_nodes.append(pred_node)
176 |             visited_node_ids.append(pred_node.unique_id)
177 | 
178 |             if is_obj_at_node(pred_node):
179 |                 reward = num_attempts
180 |                 obs = pred_node
181 |                 info = {'acc_cost': acc_cost, 'visited_nodes': visited_nodes}
182 |                 done = True
183 |                 return obs, reward, done, info
184 |             curr_node = pred_node # pred_node becomes the next "current node" for the agent.
185 |             pred_node = agent.make_predictions(TaskType.FIND_OBJECT, current_scene, query_nodes, ignore_nodes_to_pred=visited_node_ids)
186 |             if pred_node is None: # Could not find non-visited node, choose a random non-visited node
187 |                 filtered_options = [node for node in all_furniture_nodes if node.unique_id not in visited_node_ids]
188 |                 pred_node = random.choice(filtered_options)
189 |             num_attempts += 1
190 | 
191 |         # object not found within max_attempts
192 |         reward = num_attempts
193 |         obs = pred_node
194 |         info = {'acc_cost': acc_cost, 'visited_nodes': visited_nodes}
195 |         done = True
196 |         return obs, reward, done, info
197 | 
198 |     def get_task_type(self):
199 |         return TaskType.FIND_OBJECT
200 | 
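`FindObjTask.step` only relies on two agent methods, `receive_observation` and `make_predictions`. A hypothetical stand-in that satisfies that contract with random search over unvisited furniture (a sketch, not one of the project's agents):

```python
class RandomSearchAgent:
    """Hypothetical baseline implementing just the two calls FindObjTask.step makes."""
    def __init__(self, scene_graph):
        self.scene_graph = scene_graph

    def receive_observation(self, obs):
        pass  # a real agent would update its scene graph memory here

    def make_predictions(self, task_type, scene, query_nodes, ignore_nodes_to_pred=()):
        # pick any furniture node that has not been visited yet
        options = [n for n in self.scene_graph.get_nodes(NodeType.FURNITURE)
                   if n.unique_id not in ignore_nodes_to_pred]
        return random.choice(options) if options else None
```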
263 | # loss_edge_prob_i += 1.0 264 | loss_edge_prob_i /= len(gt_edges) # len of gt_edges <= top_k 265 | total_loss_edge_prob += loss_edge_prob_i 266 | total_loss_edge_prob /= (len(goal_object_nodes) - objs_with_no_edges) # average over all goal obj nodes 267 | info = {} 268 | reward = 1 - total_loss_edge_prob 269 | done = True 270 | if len(max_likelihood_parents) != 0: 271 | obs = max_likelihood_parents 272 | else: 273 | obs = [self.get_random_furniture_node()] 274 | return obs, reward, done, info 275 | 276 | def get_task_type(self): 277 | return TaskType.PREDICT_ENV_DYNAMICS 278 | 279 | def make_task(scene_sampler, scene_evolver, task_type, eps_per_scene, num_objs_to_predict=1): 280 | if task_type == TaskType.PREDICT_LOC: 281 | return PredictLocationsTask(scene_sampler, scene_evolver, eps_per_scene=eps_per_scene, num_objs_to_predict=num_objs_to_predict) 282 | elif task_type == TaskType.PREDICT_ENV_DYNAMICS: 283 | return PredictDynamicsTask(scene_sampler, scene_evolver, eps_per_scene=eps_per_scene) 284 | elif task_type == TaskType.FIND_OBJECT: 285 | return FindObjTask(scene_sampler, scene_evolver, eps_per_scene=eps_per_scene) 286 | else: 287 | raise ValueError("Task Type {} not implemented".format(task_type.value)) 288 | 289 | 290 | -------------------------------------------------------------------------------- /memsearch/training.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import random 4 | import logging 5 | import numpy as np 6 | from tqdm import tqdm 7 | from torch_geometric.loader import DataLoader 8 | from torch.nn.utils.rnn import pad_sequence 9 | from torch_geometric.transforms.to_undirected import ToUndirected 10 | #from torch_geometric.transforms import RandomLinkSplit 11 | from hydra.core.hydra_config import HydraConfig 12 | from memsearch.dataset import SGMDataset 13 | from memsearch.models import make_model, create_path_to_model, compute_output 14 | from tensorboardX import SummaryWriter 15 | from collections import defaultdict 16 | 17 | def sample_loss_edges_homogenous(data, max_to_sample, subsample=False, batch_by_node=True, reversed_edges=False): 18 | edge_labels = data.y.detach().numpy() 19 | edge_should_sample_for_loss = data.should_sample_for_loss.detach().numpy() 20 | if subsample: 21 | nonzero_indeces = edge_labels.nonzero()[0] 22 | nonzero_indeces = nonzero_indeces[edge_should_sample_for_loss[nonzero_indeces] == 1] 23 | flipped_edge_labels = 1 - edge_labels 24 | zero_indeces = flipped_edge_labels.nonzero()[0] 25 | zero_indeces = zero_indeces[edge_should_sample_for_loss[zero_indeces] == 1.0] 26 | num_to_sample = min([max_to_sample, len(zero_indeces), len(nonzero_indeces)]) 27 | chosen_nonzero_indeces = random.sample(list(nonzero_indeces), int(num_to_sample/2)) 28 | chosen_zero_indeces = random.sample(list(zero_indeces), int(num_to_sample/2)) 29 | indeces = chosen_zero_indeces + chosen_nonzero_indeces 30 | else: 31 | indeces = edge_should_sample_for_loss.nonzero()[0] 32 | edges = data.edge_index[:,indeces] 33 | edge_features = data.edge_attr[indeces,:] 34 | labels = data.y[indeces] 35 | 36 | if labels.shape[0] <= 1: 37 | return [], [], torch.tensor([]) 38 | 39 | if batch_by_node: 40 | batched_edges_dict = defaultdict(lambda: []) 41 | batched_edge_features_dict = defaultdict(lambda: []) 42 | batched_labels_dict = defaultdict(lambda: []) 43 | object_node_nums = set() 44 | for i in range(edges.size(dim=1)): 45 | if reversed_edges: 46 | object_node_num = edges[0,i].cpu().item() 47 | else: 48 | 
object_node_num = edges[1,i].cpu().item() 49 | object_node_nums.add(object_node_num) 50 | batched_edges_dict[object_node_num].append(edges[:,i]) 51 | batched_edge_features_dict[object_node_num].append(edge_features[i]) 52 | batched_labels_dict[object_node_num].append(labels[i].unsqueeze(0)) 53 | batched_edges_tensors = [torch.stack(batched_edges_dict[node_num]) for node_num in object_node_nums] 54 | batched_edge_features_tensors = [torch.stack(batched_edge_features_dict[node_num]) for node_num in object_node_nums] 55 | batched_labels_tensors = [torch.stack(batched_labels_dict[node_num]) for node_num in object_node_nums] 56 | edges = pad_sequence(batched_edges_tensors, batch_first=True) 57 | edge_features = pad_sequence(batched_edge_features_tensors, batch_first=True) 58 | labels = torch.squeeze(pad_sequence(batched_labels_tensors, padding_value=2, batch_first=True)) 59 | edges = edges.cuda() 60 | edge_features = edge_features.cuda() 61 | labels = labels.cuda() 62 | if len(labels.size()) == 1: 63 | labels = torch.unsqueeze(labels, dim=0) 64 | return edges, edge_features, labels 65 | 66 | def sample_loss_edges_heterogenous(data, max_to_sample, subsample=False, batch_by_node=True, reversed_edges=True): 67 | edges_dict = {} 68 | edge_features_dict = {} 69 | labels_dict = {} 70 | for key in data.edge_index_dict: 71 | edges, edge_features, labels = sample_loss_edges_homogenous(data[key], max_to_sample, subsample, batch_by_node, reversed_edges) 72 | if labels.shape[0] <= 1: 73 | continue 74 | edges_dict[key] = edges 75 | edge_features_dict[key] = edge_features 76 | labels_dict[key] = labels 77 | return edges_dict, edge_features_dict, labels_dict 78 | 79 | def sample_loss_edges(model, data, max_to_sample, subsample=False, batch_by_node=True, reversed_edges=True): 80 | if model.is_heterogenous(): 81 | return sample_loss_edges_heterogenous(data, max_to_sample, subsample, batch_by_node, reversed_edges) 82 | else: 83 | return sample_loss_edges_homogenous(data, max_to_sample, subsample, batch_by_node, reversed_edges) 84 | 85 | def train_step_homogeneous(model, 86 | input_data, 87 | loss_compute_edges, 88 | loss_compute_edge_features, 89 | loss_compute_labels, 90 | criterion, 91 | optimizer, 92 | do_optim_step=True, 93 | edge_key=None): 94 | optimizer.zero_grad() 95 | out = compute_output(model, input_data, loss_compute_edges, loss_compute_edge_features, edge_key) 96 | if len(loss_compute_labels.size()) > 1: 97 | # if batched by node 98 | loss_mask = (loss_compute_labels != 2).long() 99 | loss_compute_labels = loss_compute_labels*loss_mask 100 | out = out*loss_mask 101 | num_zeros = torch.sum(torch.abs((loss_compute_labels-1))).item() 102 | num_ones = torch.sum(torch.abs((loss_compute_labels))).item() 103 | loss = criterion(out, loss_compute_labels) 104 | loss_compute_labels = loss_compute_labels.long() 105 | if num_ones != 0: 106 | # reduce loss associated with label 0 since there are more of those labels 107 | zero_loss_scaling = float(num_ones)/num_zeros 108 | loss[(1-loss_compute_labels).bool()] = loss[(1-loss_compute_labels).bool()]*zero_loss_scaling 109 | loss = torch.mean(loss) 110 | else: 111 | loss = criterion(out, loss_compute_labels) 112 | loss.backward() 113 | if do_optim_step: 114 | optimizer.step() 115 | accuracy = (out.argmax(dim=1) == loss_compute_labels.argmax(dim=1)).float().mean() 116 | return out, loss, accuracy 117 | 118 | def train_step_heterogeneous(model, 119 | input_data, 120 | loss_compute_edges, 121 | loss_compute_edge_features, 122 | loss_compute_labels, 123 | criterion, 124 
| optimizer): 125 | optimizer.zero_grad() 126 | outs = {} 127 | accuracies = [] 128 | losses = [] 129 | for key in loss_compute_labels: 130 | out, loss, accuracy = train_step_homogeneous(model, 131 | input_data, 132 | loss_compute_edges[key], 133 | loss_compute_edge_features[key], 134 | loss_compute_labels[key], 135 | criterion, 136 | optimizer, 137 | do_optim_step=False, 138 | edge_key=key) 139 | outs[key] = out 140 | accuracies.append(accuracy.cpu()) 141 | losses.append(loss.cpu().detach().numpy()) 142 | 143 | optimizer.step() 144 | 145 | mean_accuracy = np.mean(accuracies) 146 | mean_loss = np.mean(losses) 147 | 148 | return outs, mean_loss, mean_accuracy 149 | 150 | def train_step(model, 151 | input_data, 152 | loss_compute_edges, 153 | loss_compute_edge_features, 154 | loss_compute_labels, 155 | criterion, 156 | optimizer): 157 | if model.is_heterogenous(): 158 | return train_step_heterogeneous(model, 159 | input_data, 160 | loss_compute_edges, 161 | loss_compute_edge_features, 162 | loss_compute_labels, 163 | criterion, 164 | optimizer) 165 | else: 166 | return train_step_homogeneous(model, 167 | input_data, 168 | loss_compute_edges, 169 | loss_compute_edge_features, 170 | loss_compute_labels, 171 | criterion, 172 | optimizer) 173 | ''' 174 | def train_step_recurrent(model, 175 | input_data, 176 | loss_compute_edges, 177 | loss_compute_edge_features, 178 | loss_compute_labels, 179 | criterion, 180 | optimizer, 181 | h, c): 182 | out, h, c = model(input_data.x, input_data.edge_index, loss_compute_edges, h, c) 183 | loss = criterion(out, loss_compute_labels) 184 | return loss, h, c 185 | ''' 186 | 187 | def train(cfg,#TODO refactor to not take cfg 188 | num_epochs, 189 | node_featurizer, 190 | edge_featurizer, 191 | add_num_nodes=True, 192 | add_num_edges=True, 193 | use_edge_weights=False, 194 | num_labels_per_batch=1000, 195 | use_undirected_edges=False, 196 | logger=None): 197 | if logger is None: 198 | logger = logging.getLogger(__name__) 199 | logger.info('Training...') 200 | writer = SummaryWriter(os.path.join(HydraConfig.get().runtime.output_dir, 'train')) 201 | model = make_model(cfg.model, node_featurizer, edge_featurizer) 202 | model.train() 203 | to_undirected = ToUndirected() 204 | 205 | input_data_path = '%s/train/'%cfg.collect_data_dir 206 | output_data_path = '%s/train/'%cfg.processed_dataset_dir 207 | 208 | dataset = SGMDataset(input_data_path, 209 | output_data_path, 210 | node_featurizer, 211 | edge_featurizer, 212 | include_labels=True, 213 | add_num_nodes=add_num_nodes, 214 | add_num_edges=add_num_edges, 215 | num_workers=cfg.process_data_num_workers, 216 | pre_transform=[to_undirected] if use_undirected_edges else None, 217 | heterogenous=model.is_heterogenous(), 218 | reverse_edges=model.reversed_edges) 219 | if cfg.no_cache: 220 | dataset.process() 221 | optimizer = torch.optim.Adam(model.parameters(), lr=0.0001) 222 | """ 223 | splitter = RandomLinkSplit(num_val=0, 224 | num_test=1.0, 225 | is_undirected=True, 226 | add_negative_train_samples=True, 227 | neg_sampling_ratio=1.0, 228 | disjoint_train_ratio=0.4) 229 | """ 230 | if model.is_recurrent(): 231 | loader = DataLoader(dataset, batch_size=5, shuffle=False) 232 | else: 233 | loader = DataLoader(dataset, batch_size=cfg.batch_size, shuffle=True) 234 | 235 | #if model.is_recurrent(): 236 | # h, c = None, None 237 | 238 | # does not work for some reason... 
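    # (Editor's note, an assumption about the commented-out criterion below:)
    # torch.nn.CrossEntropyLoss expects raw logits with integer class targets,
    # while this pipeline produces per-edge probabilities, pads grouped labels
    # with the sentinel value 2, and masks by multiplication -- a combination
    # that only composes cleanly with an elementwise BCELoss(reduction='none').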
239 |     #criterion = torch.nn.CrossEntropyLoss(reduction='none')
240 | 
241 |     criterion = torch.nn.BCELoss(reduction='none')
242 |     l2_loss = torch.nn.MSELoss()
243 |     for epoch in range(num_epochs):
244 |         l2_loss_sum = 0
245 |         accuracy_sum = 0
246 |         epoch_loss = 0
247 |         batch_count = 0
248 |         pbar = tqdm(loader)
249 |         for data in pbar:
250 |             batch_count+=1
251 |             if not use_edge_weights:
252 |                 data.edge_weight = None
253 |             #if not training:
254 |             #    input_data, val_data, loss_compute_data = splitter(data.cuda())
255 | 
256 |             loss_inputs = sample_loss_edges(model, data, num_labels_per_batch, reversed_edges = model.reversed_edges)#, subsample=not model.include_transformer)
257 |             loss_edges, loss_edge_features, loss_labels = loss_inputs
258 | 
259 |             data = data.cuda()
260 | 
261 |             #recurrent is not currently supported
262 |             '''
263 |             if model.is_recurrent():
264 |                 loss, h, c = train_step(model, data, loss_edges, loss_labels, criterion, optimizer, h, c, model.is_recurrent())
265 |                 total_loss+=loss/10
266 |                 epoch_loss+=loss.item()
267 |                 if (batch_count+1)%20 == 0:
268 |                     h, c = None, None
269 |                     total_loss.backward()
270 |                     total_loss = 0
271 |                     optimizer.step()
272 |                     optimizer.zero_grad()
273 |             '''
274 | 
275 |             out, loss, accuracy = train_step(
276 |                 model,
277 |                 data,
278 |                 loss_edges,
279 |                 loss_edge_features,
280 |                 loss_labels,
281 |                 criterion,
282 |                 optimizer)
283 |             accuracy_sum += accuracy
284 |             if type(out) is dict:
285 |                 for key in out:
286 |                     loss_labels[key] = loss_labels[key] * (loss_labels[key] != 2).long()
287 |                     l2_loss_sum+=l2_loss(out[key], loss_labels[key]).mean().cpu().item()
288 |             else:
289 |                 loss_labels = loss_labels * (loss_labels != 2).long()
290 |                 l2_loss_sum+=l2_loss(out, loss_labels).mean().cpu().item()
291 |             epoch_loss+=loss
292 |             current_loss = epoch_loss/batch_count
293 |             current_accuracy = accuracy_sum/batch_count
294 |             l2_loss_avg = l2_loss_sum/batch_count
295 |             pbar.set_description("Epoch %d | Avg Loss=%.4f | Avg l2 error = %.3f | Avg Accuracy=%.3f"%(epoch+1,
296 |                                                                                                        current_loss,
297 |                                                                                                        l2_loss_avg,
298 |                                                                                                        current_accuracy))
299 |         writer.add_scalar('avg_loss', current_loss, epoch + 1)
300 |         writer.add_scalar('avg_Accuracy', current_accuracy, epoch + 1)
301 | 
302 |     logger.info('Train l2 error for %s model: %.3f'%(model.get_model_type(), l2_loss_avg))
303 |     logger.info('Train avg accuracy for %s model: %.3f'%(model.get_model_type(), current_accuracy))
304 | 
305 |     model_path = create_path_to_model(
306 |         cfg = cfg.model,
307 |         node_featurizer = node_featurizer,
308 |         edge_featurizer = edge_featurizer,
309 |     )
310 |     torch.save(model.state_dict(), model_path)
311 | 
312 |     return model
313 | 
314 | def test(cfg,#TODO refactor to not take cfg
315 |          model,
316 |          node_featurizer,
317 |          edge_featurizer,
318 |          add_num_nodes=True,
319 |          add_num_edges=True,
320 |          use_edge_weights=False,
321 |          num_labels_per_batch=100,
322 |          use_undirected_edges=False,
323 |          logger=None):
324 |     if logger is None:
325 |         logger = logging.getLogger(__name__)
326 |     logger.info('Testing...')
327 |     model.eval()
328 |     to_undirected = ToUndirected()
329 |     input_data_path = '%s/test/'%cfg.collect_data_dir
330 |     output_data_path = '%s/test/'%cfg.processed_dataset_dir
331 |     dataset = SGMDataset(input_data_path,
332 |                          output_data_path,
333 |                          node_featurizer,
334 |                          edge_featurizer,
335 |                          include_labels=True,
336 |                          add_num_nodes=add_num_nodes,
337 |                          add_num_edges=add_num_edges,
338 |                          num_workers=cfg.process_data_num_workers,
339 |                          pre_transform=[to_undirected] if use_undirected_edges else None,
340 |                          heterogenous=model.is_heterogenous(),
341 |                          reverse_edges=model.reversed_edges)
342 |     if model.is_recurrent():
343 |         loader = DataLoader(dataset, batch_size=5, shuffle=False)
344 |     else:
345 |         loader = DataLoader(dataset, batch_size=cfg.batch_size, shuffle=True)
346 | 
347 |     #if recurrent:
348 |     #    h, c = None, None
349 | 
350 |     l2_loss = torch.nn.MSELoss()
351 |     l2_loss_sum = 0
352 |     accuracy_sum = 0
353 |     batch_count = 0
354 |     pbar = tqdm(loader)
355 |     for data in pbar:
356 |         if not use_edge_weights:
357 |             data.edge_weight = None
358 |         #if not training:
359 |         #    input_data, val_data, loss_compute_data = splitter(data.cuda())
360 |         test_edges, test_edge_features, test_labels = sample_loss_edges(model, data, num_labels_per_batch, reversed_edges = model.reversed_edges)
361 |         data = data.cuda()
362 | 
363 |         if model.is_heterogenous():
364 |             for key in test_edges:
365 |                 out = compute_output(model, data, test_edges[key], test_edge_features[key], edge_key=key)
366 |                 if len(out.size()) < 2 or len(test_labels[key].size()) < 2:
367 |                     continue
368 |                 test_labels[key] = test_labels[key] * (test_labels[key] != 2).long()
369 |                 accuracy = (out.argmax(dim=1) == test_labels[key].argmax(dim=1)).float().mean()
370 |                 accuracy_sum+=accuracy.item()
371 |                 l2_loss_sum+=l2_loss(out, test_labels[key]).mean().cpu().item()
372 |                 batch_count+=1
373 |         else:
374 |             out = compute_output(model, data, test_edges, test_edge_features)
375 | 
376 |             test_labels = test_labels * (test_labels != 2).long()
377 |             if len(out.size()) == 1:
378 |                 continue
379 |             accuracy_sum+=(out.argmax(dim=1) == test_labels.argmax(dim=1)).float().mean().cpu().item()
380 |             l2_loss_sum+=l2_loss(out, test_labels).mean().cpu().item()
381 |             batch_count+=1
382 | 
383 |     logger.info('Test avg l2 error for %s model: %.3f'%(model.get_model_type(), l2_loss_sum/batch_count))
384 |     logger.info('Test avg accuracy for %s model: %.3f\n'%(model.get_model_type(), accuracy_sum/batch_count))
385 | 
386 |     return accuracy_sum/batch_count
387 | 
--------------------------------------------------------------------------------
/memsearch/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import logging
4 | import coloredlogs
5 | from hydra.core.hydra_config import HydraConfig
6 | 
7 | def sample_from_dict(dict_with_weights):
8 |     d = dict_with_weights
9 |     return random.choices(list(d.keys()), weights=d.values(), k=1)[0]
10 | 
11 | def float_to_hex(number, base = 16):
12 |     if number < 0: # Check if the number is negative to manage the sign
13 |         sign = "-" # Set the negative sign; it will be used later to generate the first element of the result list
14 |         number = -number # Change the number sign to positive
15 |     else:
16 |         sign = "" # Set the positive sign; it will be used later to generate the first element of the result list
17 | 
18 |     s = [sign + str(int(number)) + '.'] # Generate the list; the first element will be the integer part of the input number
19 |     number -= int(number) # Remove the integer part from the number
20 | 
21 |     for i in range(base): # Iterate N times where N is the required base
22 |         y = int(number * 16) # Multiply the number by 16 and take the integer part
23 |         s.append(hex(y)[2:]) # Append the hex value of y to the list; the result is in format 0x00 so we take the value from position 2 to the end
24 |         number = number * 16 - y # Calculate the next number required for the conversion
25 | 
26 |     return ''.join(s).rstrip('0')
27 | 
28 | def char_code_at(testS):
29 |     l = list(bytes(testS,
'utf-16'))[2:]
30 |     for i, c in enumerate([(b<<8)|a for a,b in list(zip(l,l[1:]))[::2]]):
31 |         return c
32 | 
33 | def string_to_color(text):
34 |     hash = 0
35 |     for x in text:
36 |         hash = char_code_at(x) + ((hash << 5) - hash)
37 | 
38 |     colour = '#'
39 |     for i in range(3):
40 |         value = (hash >> (i * 8)) & 0xFF
41 |         colour += format(value, '02x') # two hex digits per channel, giving a full #rrggbb color
42 |     return colour
43 | 
44 | FORMAT_STR = '%(asctime)s - %(message)s'
45 | DATE_FORMAT = '%Y-%m-%d %H:%M:%S'
46 | 
47 | def configure_logging(dir_path=None, format_strs=[None], name='log', log_suffix=''):
48 |     if dir_path is None:
49 |         dir_path = os.path.join(HydraConfig.get().runtime.output_dir)
50 |     logger = logging.getLogger() # root logger
51 |     formatter = logging.Formatter(FORMAT_STR, DATE_FORMAT)
52 |     file_to_delete = open("info.txt",'w')
53 |     file_to_delete.close()
54 |     file_path = "{0}/{1}.log".format(dir_path, name)
55 |     #file_handler = logging.FileHandler(filename="{0}/{1}.log".format(dir_path, name), mode='w')
56 |     #file_handler.setFormatter(formatter)
57 |     #logger.addHandler(file_handler)
58 |     if os.isatty(2):
59 |         coloredlogs.install(fmt=FORMAT_STR, level='INFO')
60 |     else:
61 |         stream_handler = logging.StreamHandler()
62 |         stream_handler.setFormatter(formatter)
63 |         logger.addHandler(stream_handler)
64 |     return logger
65 | 
--------------------------------------------------------------------------------
/priors/coarse_prior_graph.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/andreykurenkov/modeling_env_dynamics/74e5f9d722469f2d1148fe131aa85dfb049da7cb/priors/coarse_prior_graph.pickle
--------------------------------------------------------------------------------
/priors/detailed_prior_graph.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/andreykurenkov/modeling_env_dynamics/74e5f9d722469f2d1148fe131aa85dfb049da7cb/priors/detailed_prior_graph.pickle
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | aenum==3.1.11
2 | antlr4-python3-runtime==4.9.3
3 | appnope==0.1.3
4 | argon2-cffi==21.3.0
5 | argon2-cffi-bindings==21.2.0
6 | asttokens==2.0.8
7 | attrs==22.1.0
8 | backcall==0.2.0
9 | bddl==1.0.1
10 | beautifulsoup4==4.11.1
11 | bleach==5.0.1
12 | blis==0.7.8
13 | catalogue==2.0.8
14 | certifi==2022.6.15
15 | cffi==1.15.1
16 | charset-normalizer==2.1.0
17 | click==8.1.3
18 | cloudpickle==2.1.0
19 | cmake==3.24.1
20 | coloredlogs==15.0.1
21 | cycler==0.11.0
22 | cymem==2.0.6
23 | debugpy==1.6.3
24 | decorator==5.1.1
25 | defusedxml==0.7.1
26 | dill==0.3.5.1
27 | en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.4.1/en_core_web_sm-3.4.1-py3-none-any.whl
28 | entrypoints==0.4
29 | executing==0.10.0
30 | fastjsonschema==2.16.1
31 | filelock==3.8.0
32 | fonttools==4.33.3
33 | freetype-py==2.3.0
34 | future==0.18.2
35 | gitdb==4.0.9
36 | GitPython==3.1.27
37 | GPUtil==1.4.0
38 | gym==0.21.0
39 | gym-minigrid==1.0.3
40 | h5py==3.7.0
41 | huggingface-hub==0.10.1
42 | humanfriendly==10.0
43 | hydra-core==1.2.0
44 | idna==3.3
45 | iniconfig==1.1.1
46 | ipykernel==6.15.1
47 | ipython==8.4.0
48 | ipython-genutils==0.2.0
49 | ipywidgets==8.0.1
50 | jedi==0.18.1
51 | Jinja2==3.1.2
52 | joblib==1.2.0
53 | jsonschema==4.13.0
54 | jupyter==1.0.0
55 | jupyter-client==7.3.4
56 | jupyter-console==6.4.4
57 |
jupyter-core==4.11.1 58 | jupyterlab-pygments==0.2.2 59 | jupyterlab-widgets==3.0.2 60 | jupytext==1.14.1 61 | kiwisolver==1.4.3 62 | langcodes==3.3.0 63 | lxml==4.9.1 64 | markdown-it-py==2.1.0 65 | MarkupSafe==2.1.1 66 | matplotlib==3.5.2 67 | matplotlib-inline==0.1.6 68 | mdit-py-plugins==0.3.0 69 | mdurl==0.1.2 70 | mistune==0.8.4 71 | murmurhash==1.0.8 72 | nbclient==0.6.6 73 | nbconvert==6.5.3 74 | nbformat==5.4.0 75 | nest-asyncio==1.5.5 76 | networkx==2.8.5 77 | nltk==3.7 78 | notebook==6.4.12 79 | numpy==1.23.0 80 | omegaconf==2.2.3 81 | opencv-python==4.6.0.66 82 | packaging==21.3 83 | pandas==1.4.3 84 | pandocfilters==1.5.0 85 | parso==0.8.3 86 | pathy==0.6.2 87 | pexpect==4.8.0 88 | pickleshare==0.7.5 89 | Pillow==9.2.0 90 | pluggy==1.0.0 91 | preshed==3.0.7 92 | progressbar==2.5 93 | prometheus-client==0.14.1 94 | prompt-toolkit==3.0.30 95 | protobuf==3.20.1 96 | psutil==5.9.1 97 | ptyprocess==0.7.0 98 | pure-eval==0.2.2 99 | py==1.11.0 100 | py360convert==0.1.0 101 | pybullet-svl==3.1.6.4 102 | pycparser==2.21 103 | pydantic==1.9.2 104 | Pygments==2.13.0 105 | pyinstrument==4.2.0 106 | pyparsing==3.0.9 107 | pyrsistent==0.18.1 108 | pytest==7.1.2 109 | python-dateutil==2.8.2 110 | pytz==2022.2.1 111 | PyYAML==6.0 112 | pyzmq==23.2.1 113 | qtconsole==5.3.1 114 | QtPy==2.2.0 115 | regex==2022.9.13 116 | requests==2.28.1 117 | scikit-learn==1.1.3 118 | scipy==1.9.0 119 | Send2Trash==1.8.0 120 | sentence-transformers==2.2.2 121 | sentencepiece==0.1.97 122 | six==1.16.0 123 | smart-open==5.2.1 124 | smmap==5.0.0 125 | soupsieve==2.3.2.post1 126 | spacy==3.4.1 127 | spacy-legacy==3.0.10 128 | spacy-loggers==1.0.3 129 | srsly==2.4.4 130 | stable-baselines3==1.6.0 131 | stack-data==0.4.0 132 | tensorboardX==2.5.1 133 | terminado==0.15.0 134 | thinc==8.1.0 135 | threadpoolctl==3.1.0 136 | tinycss2==1.1.1 137 | tokenizers==0.13.1 138 | toml==0.10.2 139 | tomli==2.0.1 140 | tornado==6.2 141 | tqdm==4.64.0 142 | traitlets==5.3.0 143 | transformers==4.23.1 144 | transforms3d==0.3.1 145 | trimesh==3.13.5 146 | typer==0.4.2 147 | typing_extensions==4.3.0 148 | urllib3==1.26.11 149 | wasabi==0.10.1 150 | wcwidth==0.2.5 151 | webencodings==0.5.1 152 | widgetsnbextension==4.0.2 -------------------------------------------------------------------------------- /rl_scripts/check_env.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | 3 | import memsearch 4 | from memsearch.igridson_env import SMGFixedEnv 5 | from memsearch.igridson_utils import make_scene_sampler_and_evolver_, reset, visualize_agent, evolve 6 | 7 | import matplotlib.pyplot as plt 8 | from ray.rllib.utils import check_env 9 | 10 | @hydra.main(version_base=None, 11 | config_path=memsearch.CONFIG_PATH, 12 | config_name="config") 13 | def main(cfg): 14 | scene_sampler, scene_evolver = make_scene_sampler_and_evolver_(cfg.scene_gen) 15 | env = SMGFixedEnv(scene_sampler=scene_sampler, scene_evolver=scene_evolver) 16 | breakpoint() 17 | 18 | check_env(env) 19 | 20 | if __name__ == "__main__": 21 | main() # type: ignore 22 | 23 | -------------------------------------------------------------------------------- /rl_scripts/interactive_test.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | 3 | import memsearch 4 | from memsearch.igridson_env import SMGFixedEnv 5 | from memsearch.igridson_utils import make_scene_sampler_and_evolver_ 6 | import matplotlib 7 | import matplotlib.pyplot as plt 8 | from tqdm import tqdm 9 | import sys 10 | 
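# (Editor's note) KEYBOARD_ACTION_MAP below uses gym-minigrid's discrete action
# indices: 0 = turn left, 1 = turn right, 2 = move forward, so 'w'/'a'/'d' give
# forward/left/right teleop. If you extend the map (e.g. 3 = pickup, 5 = toggle),
# verify the indices against the pinned gym-minigrid==1.0.3 Actions enum.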
11 | matplotlib.rcParams.update({'font.size': 4}) 12 | 13 | KEYBOARD_ACTION_MAP = {'w': 2, 'a': 0, 'd': 1} 14 | 15 | def visualize(env, obs, reward, done): 16 | full_grid = env.render(mode="rgb_array", tile_size=32) 17 | # Visualize observed and full grid 18 | plt.subplot(1,2,1) 19 | plt.title("Observed view") 20 | plt.imshow(obs["image"]) 21 | plt.axis('off') 22 | plt.subplot(1,2,2) 23 | plt.title("Full Grid") 24 | plt.imshow(full_grid) 25 | plt.tight_layout() 26 | plt.suptitle("Goal: {}, Reward: {}, Done: {}".format(env.goal_obj_label, reward, done)) 27 | plt.axis('off') 28 | plt.show() 29 | 30 | def run_training(cfg): 31 | scene_sampler, scene_evolver = make_scene_sampler_and_evolver_(cfg.scene_gen) 32 | env = SMGFixedEnv(scene_sampler=scene_sampler, scene_evolver=scene_evolver, set_goal_icon=True, env_evolve_freq=5) 33 | 34 | def on_press(event): 35 | print("Target poses:", env.target_poses) 36 | print("Agent pose:", env.agent_pos) 37 | sys.stdout.flush() 38 | if event.key == 'r': 39 | print("Resetting...") 40 | env.reset() 41 | print("done.") 42 | init_action = env.action_space.sample() 43 | obs, reward, done, info = env.step(init_action) 44 | visualize(env, obs, reward, done) 45 | elif event.key in ['w', 'a', 'd']: 46 | action = KEYBOARD_ACTION_MAP[event.key] 47 | obs, reward, done, info = env.step(action) 48 | visualize(env, obs, reward, done) 49 | else: 50 | print("Invalid input", event.key) 51 | fig = plt.gcf() 52 | fig.canvas.mpl_connect('key_press_event', on_press) 53 | 54 | #Initialize 55 | init_action = env.action_space.sample() 56 | obs, reward, done, info = env.step(init_action) 57 | visualize(env, obs, reward, done) 58 | 59 | @hydra.main(version_base=None, 60 | config_path=memsearch.CONFIG_PATH, 61 | config_name="config") 62 | def main(cfg): 63 | run_training(cfg) 64 | 65 | if __name__ == "__main__": 66 | main() # type: ignore -------------------------------------------------------------------------------- /rl_scripts/main.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | 3 | import memsearch 4 | from memsearch.igridson_env import SMGFixedEnv 5 | from memsearch.igridson_utils import make_scene_sampler_and_evolver_ 6 | import matplotlib.pyplot as plt 7 | 8 | def visualize(env, obs): 9 | full_grid = env.render(mode="rgb_array", tile_size=32) 10 | 11 | # Visualize observed and full grid 12 | plt.subplot(1,2,1) 13 | plt.title("Observed view") 14 | plt.imshow(obs["image"]) 15 | plt.subplot(1,2,2) 16 | plt.title("Full Grid") 17 | plt.imshow(full_grid) 18 | plt.tight_layout() 19 | plt.show() 20 | env.reset() 21 | 22 | def run_training(cfg): 23 | scene_sampler, scene_evolver = make_scene_sampler_and_evolver_(cfg.scene_gen) 24 | env = SMGFixedEnv(scene_sampler=scene_sampler, scene_evolver=scene_evolver) 25 | for _ in range(100): 26 | env.reset() 27 | for idx in range(10000): 28 | action = env.action_space.sample() 29 | obs, reward, done, info = env.step(action) # obs is a dict with keys "image", "direction" and "mission" 30 | if done: 31 | print(f"Done at {idx}") 32 | break 33 | 34 | 35 | @hydra.main(version_base=None, 36 | config_path=memsearch.CONFIG_PATH, 37 | config_name="config") 38 | def main(cfg): 39 | run_training(cfg) 40 | 41 | if __name__ == "__main__": 42 | main() # type: ignore 43 | 44 | -------------------------------------------------------------------------------- /rl_scripts/train.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | 3 | import memsearch 4 | 
from memsearch.igridson_env import SMGFixedEnv 5 | from ray.rllib.algorithms.ppo.ppo import PPOConfig, PPO 6 | import tqdm 7 | from memsearch.igridson_utils import make_scene_sampler_and_evolver_ 8 | from ray.tune.registry import register_env 9 | from ray.rllib.models import ModelCatalog 10 | 11 | from memsearch.rl.complex_input_network import ComplexInputNetwork 12 | 13 | 14 | @hydra.main(version_base=None, config_path=memsearch.CONFIG_PATH, config_name="config") 15 | def main(cfg): 16 | def env_creator(env_config): 17 | return SMGFixedEnv(**env_config) 18 | 19 | register_env("SmgFixedEnv-v0", env_creator) 20 | ModelCatalog.register_custom_model("ComplexInputNetwork", ComplexInputNetwork) 21 | 22 | scene_sampler, scene_evolver = make_scene_sampler_and_evolver_(cfg.scene_gen) 23 | config = ( 24 | PPOConfig() 25 | .resources(num_gpus=1) 26 | .rollouts( num_rollout_workers=4, horizon=300) 27 | .framework("torch") 28 | .training( 29 | model={ 30 | "custom_model": "ComplexInputNetwork", 31 | "conv_filters": [[32, 4, 4], [64, 2, 2]] 32 | } 33 | ) 34 | .environment( 35 | "SmgFixedEnv-v0", 36 | env_config={ 37 | "scene_sampler": scene_sampler, 38 | "scene_evolver": scene_evolver, 39 | "mission_mode": "one_hot", 40 | }, 41 | ) 42 | ) 43 | 44 | trainer = PPO(config=config) 45 | for _ in tqdm.tqdm(range(100000)): 46 | trainer.train() 47 | 48 | 49 | if __name__ == "__main__": 50 | main() # type: ignore 51 | -------------------------------------------------------------------------------- /rl_scripts/train_dqn.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | 3 | import memsearch 4 | from memsearch.igridson_env import SMGFixedEnv 5 | from ray.rllib.algorithms.dqn.dqn import DQNConfig, DQN 6 | # from ray.rllib.algorithms.ppo.ppo import PPOConfig, PPO 7 | import tqdm 8 | from memsearch.igridson_utils import make_scene_sampler_and_evolver_ 9 | from ray.tune.registry import register_env 10 | from ray.rllib.models import ModelCatalog 11 | 12 | from memsearch.rl.complex_input_network import ComplexInputNetwork 13 | 14 | 15 | @hydra.main(version_base=None, config_path=memsearch.CONFIG_PATH, config_name="config") 16 | def main(cfg): 17 | def env_creator(env_config): 18 | return SMGFixedEnv(**env_config) 19 | 20 | register_env("SmgFixedEnv-v0", env_creator) 21 | ModelCatalog.register_custom_model("ComplexInputNetwork", ComplexInputNetwork) 22 | 23 | scene_sampler, scene_evolver = make_scene_sampler_and_evolver_(cfg.scene_gen) 24 | config = ( 25 | DQNConfig() 26 | .resources(num_gpus=1) 27 | .rollouts( num_rollout_workers=4, horizon=300) 28 | .framework("torch") 29 | .training( 30 | model={ 31 | "custom_model": "ComplexInputNetwork", 32 | "conv_filters": [[32, 4, 4], [64, 2, 2]] 33 | } 34 | ) 35 | .environment( 36 | "SmgFixedEnv-v0", 37 | env_config={ 38 | "scene_sampler": scene_sampler, 39 | "scene_evolver": scene_evolver, 40 | "mission_mode": "one_hot", 41 | }, 42 | ) 43 | ) 44 | 45 | config.replay_buffer_config.update( 46 | { 47 | "capacity": 10000, 48 | "prioritized_replay_alpha": 0.5, 49 | "prioritized_replay_beta": 0.5, 50 | "prioritized_replay_eps": 3e-6, 51 | } 52 | ) 53 | 54 | trainer = DQN(config=config) 55 | for _ in tqdm.tqdm(range(100000)): 56 | trainer.train() 57 | 58 | 59 | if __name__ == "__main__": 60 | main() # type: ignore 61 | -------------------------------------------------------------------------------- /rl_scripts/train_dqn_flat.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | 
import ray 3 | import ray.rllib.algorithms.dqn as dqn 4 | import ray.rllib.algorithms.ppo as ppo 5 | import numpy as np 6 | from ray.tune.registry import register_env 7 | import tqdm 8 | 9 | import memsearch 10 | import memsearch 11 | from memsearch.igridson_env import SMGFixedEnv 12 | from memsearch.igridson_utils import make_scene_sampler_and_evolver_ 13 | 14 | 15 | from gym.core import ObservationWrapper 16 | import gym.spaces as spaces 17 | # class FlattenObservation(ObservationWrapper): 18 | # r"""Observation wrapper that flattens the observation.""" 19 | # 20 | # def __init__(self, env): 21 | # super().__init__(env) 22 | # self.observation_space = spaces.flatten_space(env.observation_space) 23 | # 24 | # def observation(self, observation): 25 | # return np.concatenate((observation['image'].flatten(), observation['direction'].reshape(1), observation['mission'])) 26 | 27 | from gym.wrappers.flatten_observation import FlattenObservation 28 | 29 | @hydra.main(version_base=None, config_path=memsearch.CONFIG_PATH, config_name="config") 30 | def main(cfg): 31 | ray.init() 32 | 33 | def env_creator(env_config): 34 | return FlattenObservation(SMGFixedEnv(**env_config)) 35 | 36 | register_env("SmgFixedEnv-v0", env_creator) 37 | 38 | scene_sampler, scene_evolver = make_scene_sampler_and_evolver_(cfg.scene_gen) 39 | env = FlattenObservation(SMGFixedEnv( 40 | scene_sampler=scene_sampler, 41 | scene_evolver=scene_evolver, 42 | encode_obs_im=True, 43 | mission_mode='one_hot' 44 | )) 45 | 46 | obs = env.reset() 47 | action = env.action_space.sample() 48 | all = env.step(action) 49 | breakpoint() 50 | 51 | 52 | config = ( 53 | dqn.DQNConfig() 54 | .resources(num_gpus=1) 55 | .rollouts(num_rollout_workers=4, horizon=300) 56 | .framework("torch") 57 | .training( 58 | model={ 59 | # Auto-wrap the custom(!) model with an LSTM. 60 | # "use_lstm": True, 61 | "framestack": True, 62 | # To further customize the LSTM auto-wrapper. 63 | # "lstm_cell_size": 64, 64 | # Specify our custom model from above. 65 | # Extra kwargs to be passed to your model's c'tor. 
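                # (Editor's note) Because FlattenObservation collapses the dict
                # observation (image, direction, mission) into one flat Box vector,
                # this config sets no conv_filters or custom model; "framestack"
                # asks RLlib's default model to stack recent observations instead
                # (behavior worth verifying for non-image inputs).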
66 | # "custom_model_config": {}, 67 | } 68 | ) 69 | .environment( 70 | "SmgFixedEnv-v0", 71 | env_config={ 72 | "scene_sampler": scene_sampler, 73 | "scene_evolver": scene_evolver, 74 | "encode_obs_im": True, 75 | "mission_mode": "one_hot", 76 | }, 77 | ) 78 | ) 79 | 80 | config.replay_buffer_config.update( 81 | { 82 | "capacity": 10000, 83 | "prioritized_replay_alpha": 0.5, 84 | "prioritized_replay_beta": 0.5, 85 | "prioritized_replay_eps": 3e-6, 86 | } 87 | ) 88 | 89 | trainer = config.build() 90 | for _ in tqdm.tqdm(range(100000)): 91 | trainer.train() 92 | 93 | 94 | if __name__ == "__main__": 95 | main() # type: ignore 96 | -------------------------------------------------------------------------------- /rl_scripts/train_dqn_lstm_flat.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | 3 | import memsearch 4 | from memsearch.igridson_env import SMGFixedEnv 5 | from ray.rllib.algorithms.dqn.dqn import DQNConfig, DQN 6 | # from ray.rllib.algorithms.ppo.ppo import PPOConfig, PPO 7 | import tqdm 8 | from memsearch.igridson_utils import make_scene_sampler_and_evolver_ 9 | from ray.tune.registry import register_env 10 | 11 | from gym.wrappers import FlattenObservation 12 | 13 | @hydra.main(version_base=None, config_path=memsearch.CONFIG_PATH, config_name="config") 14 | def main(cfg): 15 | def env_creator(env_config): 16 | return FlattenObservation(SMGFixedEnv(**env_config)) 17 | 18 | register_env("SmgFixedEnv-v0", env_creator) 19 | 20 | scene_sampler, scene_evolver = make_scene_sampler_and_evolver_(cfg.scene_gen) 21 | config = ( 22 | DQNConfig() 23 | .resources(num_gpus=1) 24 | .rollouts( num_rollout_workers=4, horizon=300) 25 | .framework("torch") 26 | .training( 27 | model={ 28 | "framestack": 20 29 | } 30 | ) 31 | .environment( 32 | "SmgFixedEnv-v0", 33 | env_config={ 34 | "scene_sampler": scene_sampler, 35 | "scene_evolver": scene_evolver, 36 | "encode_obs_im": False, 37 | "mission_mode": 'one_hot' 38 | }, 39 | ) 40 | ) 41 | 42 | config.replay_buffer_config.update( 43 | { 44 | "capacity": 10000, 45 | "prioritized_replay_alpha": 0.5, 46 | "prioritized_replay_beta": 0.5, 47 | "prioritized_replay_eps": 3e-6, 48 | } 49 | ) 50 | 51 | trainer = DQN(config=config) 52 | for _ in tqdm.tqdm(range(100000)): 53 | trainer.train() 54 | 55 | 56 | if __name__ == "__main__": 57 | main() # type: ignore 58 | -------------------------------------------------------------------------------- /scripts/collect_data.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import os 3 | import gc 4 | import memsearch 5 | import logging 6 | 7 | from hydra.core.hydra_config import HydraConfig 8 | from memsearch.util import configure_logging 9 | from memsearch.tasks import TaskType 10 | from memsearch.running import collect_data 11 | from memsearch.dataset import make_featurizers, SGMDataset 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | def convert_to_pyg(cfg, node_featurizer, edge_featurizer, sgm_graphs, data_type): 16 | input_data_path = f'{cfg.collect_data_dir}/{data_type}/' 17 | output_data_path = f'{cfg.processed_dataset_dir}/{data_type}/' 18 | for is_heterogenous in [False, True]: 19 | if is_heterogenous: 20 | print(f'Creating tensors for heterogenous {data_type} graphs') 21 | else: 22 | print(f'Creating tensors for homogenous {data_type} graphs') 23 | # filter out super big graphs to make things faster 24 | sgm_graphs = [graph for graph in sgm_graphs if len(graph.get_edges()) < 2500 25 | and 
len(graph.nodes) < 225]
26 |         dataset = SGMDataset(input_data_path,
27 |                              output_data_path,
28 |                              node_featurizer,
29 |                              edge_featurizer,
30 |                              include_labels=True,
31 |                              add_num_nodes = cfg.model.add_num_nodes,
32 |                              add_num_edges = cfg.model.add_num_edges,
33 |                              num_workers=cfg.process_data_num_workers,
34 |                              heterogenous=is_heterogenous,
35 |                              reverse_edges=cfg.model.reversed_edges,
36 |                              data = sgm_graphs)
37 |         del dataset # this does processing upon construction, no need to store it
38 | 
39 | def run_data_collection(cfg):
40 |     logger = configure_logging(name='collect_data')
41 |     output_dir = HydraConfig.get().runtime.output_dir
42 |     configure_logging(cfg.log_path)
43 |     save_graphs_path=cfg.collect_data_dir
44 |     if cfg.process_graphs_after_collection:
45 |         node_featurizer, edge_featurizer = make_featurizers(cfg.model, True, cfg.task.num_steps)
46 |     if not cfg.process_graphs_after_collection and not os.path.isdir(save_graphs_path):
47 |         os.makedirs(save_graphs_path+'/train')
48 |         os.makedirs(save_graphs_path+'/test')
49 |     if (not cfg.process_graphs_after_collection and len(os.listdir(save_graphs_path+'/train')) == cfg.data_gen.num_steps_train) \
50 |        or (cfg.process_graphs_after_collection and os.path.isdir(cfg.processed_dataset_dir+'/train')):
51 |         print('Data already collected, quitting.')
52 |         return
53 |     cfg.agents.memorization_use_priors = True
54 |     print('Collecting train data with agent type %s\n'%cfg.data_gen.agent_type)
55 |     train_sgms = collect_data(cfg,
56 |                               logger,
57 |                               output_dir,
58 |                               num_steps = cfg.data_gen.num_steps_train,
59 |                               agent_type = cfg.data_gen.agent_type,
60 |                               save_graphs_path = '%s/train'%(save_graphs_path),
61 |                               save_sgm_graphs = not cfg.process_graphs_after_collection)
62 |     if cfg.process_graphs_after_collection:
63 |         convert_to_pyg(cfg, node_featurizer, edge_featurizer, train_sgms, 'train')
64 |     del train_sgms
65 |     gc.collect()
66 |     print('\nDone!')
67 | 
68 |     print('\nCollecting test data with agent type %s\n'%cfg.data_gen.agent_type)
69 |     test_sgms = collect_data(cfg,
70 |                              logger,
71 |                              output_dir,
72 |                              num_steps = cfg.data_gen.num_steps_test,
73 |                              agent_type=cfg.data_gen.agent_type,
74 |                              save_graphs_path='%s/test'%(save_graphs_path),
75 |                              save_sgm_graphs = not cfg.process_graphs_after_collection)
76 |     print('\nDone!')
77 | 
78 |     if cfg.process_graphs_after_collection:
79 |         convert_to_pyg(cfg, node_featurizer, edge_featurizer, test_sgms, 'test')
80 |     del test_sgms
81 | 
82 | @hydra.main(version_base=None, config_path=memsearch.CONFIG_PATH, config_name="config")
83 | def main(cfg):
84 |     run_data_collection(cfg)
85 | 
86 | if __name__ == '__main__':
87 |     main() #type: ignore
88 | 
--------------------------------------------------------------------------------
/scripts/eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import hydra
3 | import logging
4 | import numpy as np
5 | import seaborn as sns
6 | import memsearch
7 | import functools
8 | import multiprocessing
9 | from tqdm import tqdm
10 | 
11 | from memsearch.util import configure_logging
12 | from memsearch.metrics import AvgAccuracy, AvgAUC, DiscSumOfRewards, plot_agent_eval, store_agent_eval, rename
13 | from memsearch.running import eval_agent
14 | from memsearch.tasks import TaskType
15 | from hydra.core.hydra_config import HydraConfig
16 | 
17 | logger = logging.getLogger(__name__)
18 | 
19 | def eval_policies(cfg):
20 |     logger = configure_logging(name='eval')
21 |     task = TaskType(cfg.task.name)
22 |     task_name = task.value
23 |     final_averages = []
24 |     if type(cfg.agents.agent_types) is str:
25 |         agent_types = [cfg.agents.agent_types]
26 |     else:
27 |         agent_types = cfg.agents.agent_types
28 |     output_dir = os.path.join(HydraConfig.get().runtime.output_dir, 'eval')
29 |     images_dir = os.path.join(output_dir, 'images')
30 |     metrics = [
31 |         AvgAccuracy('%s_%s_avg_ac'%(cfg.run_name, task_name), images_dir),
32 |         AvgAUC('%s_%s_avg_AuC'%(cfg.run_name,task_name), images_dir),
33 |         DiscSumOfRewards('%s_%s_avg_DSoR'%(cfg.run_name, task_name), images_dir)
34 |     ]
35 | 
36 |     if cfg.eval_in_parallel and len(agent_types) > 1:
37 |         multiprocessing.set_start_method('spawn')
38 |         run_eval_agent = functools.partial(eval_agent, cfg, logger, output_dir, task)
39 |         tqdm.set_lock(multiprocessing.RLock())
40 |         with multiprocessing.Pool(processes = len(agent_types)) as pool:
41 |             all_score_vecs = pool.map(run_eval_agent, agent_types)
42 |     else:
43 |         all_score_vecs = []
44 |         for agent_type in agent_types:
45 |             logger.info('Evaluating %s'%agent_type)
46 |             all_score_vecs.append(eval_agent(cfg, logger, output_dir, task, agent_type))
47 | 
48 |     sns.set()
49 |     for i, agent_type in enumerate(agent_types):
50 |         score_vecs = all_score_vecs[i]
51 |         final_averages.append(np.mean(score_vecs))
52 |         store_agent_eval(output_dir, score_vecs, agent_type)
53 |         save_figs = (i == len(agent_types) - 1)
54 |         plot_agent_eval(cfg.task.eps_per_scene, score_vecs, agent_type, i, metrics,
55 |                         smoothing_kernel_size=cfg.task.num_smoothing_steps, task=task,
56 |                         show_fig=False, save_fig=save_figs, x_labels=agent_types)
57 | 
58 |     for metric in metrics:
59 |         logger.info('---')
60 |         logger.info('Results for metric %s'%metric.get_metric_name())
61 |         for i, agent_type in enumerate(cfg.agents.agent_types):
62 |             logger.info('%s agent final average score: %.3f +/- %.3f'%(agent_type,
63 |                         metric.agent_evals[rename(agent_type)],
64 |                         metric.agent_evals_var[rename(agent_type)]))
65 |         if 'upper_bound' in cfg.agents.agent_types:
66 |             logger.info('---')
67 |             logger.info('Results for metric %s normalized by upper bound'%metric.get_metric_name())
68 |             for i, agent_type in enumerate(cfg.agents.agent_types):
69 |                 logger.info('%s agent final normalized average score: %.3f'%(agent_type,
70 |                             metric.agent_evals[rename(agent_type)]/metric.agent_evals[rename('upper_bound')]))
71 | 
72 | @hydra.main(version_base=None, config_path=memsearch.CONFIG_PATH, config_name="config")
73 | def main(cfg):
74 |     eval_policies(cfg)
75 | 
76 | if __name__ == '__main__':
77 |     main() #type: ignore
78 | 
--------------------------------------------------------------------------------
/scripts/gen_experiment_configs.py:
--------------------------------------------------------------------------------
1 | from memsearch.experiment_configs import *
2 | import itertools
3 | from tqdm import tqdm
4 | import copy
5 | import yaml
6 | import argparse
7 | 
8 | default_configs = {
9 |     'defaults': [{'/scene_gen': 'large_scenes'}, {'/agents': 'all'}, {'/data_gen': '10k'}, {'/task': 'predict_location'}, {'/model': 'heat'}],
10 |     'run_name': 'pl_l_dn_d_d_za_n',
11 |     'model':
12 |     {
13 |         'include_transformer': True,
14 |         'node_features': 'all',
15 |         'edge_features': 'all',
16 |     },
17 |     'agents':
18 |     {
19 |         'agent_priors_type': 'detailed',
20 |         'sgm_use_priors': True
21 |     },
22 |     'scene_gen':
23 |     {
24 |         'add_or_remove_objs': True,
25 |         'scene_priors_type': 'detailed',
26 | 
27 |         'priors_noise': 0.25,
28 |         'priors_sparsity_level': 0.25
29 |     }
30 | }
31 | 
32 | TASK_MATCHES = {
33 |     'pl': 'predict_location',
34 |     'pls':
'predict_locations', 35 | 'pd': 'predict_env_dynamics', 36 | 'fo': 'find_object' 37 | } 38 | SCENE_SIZE_MATCHES = { 39 | 'l': 'large_scenes', 40 | 's': 'small_scenes' 41 | } 42 | SCENE_NODE_DYNAMICS_MATCHES = { 43 | 'dn': True, #dynamic nodes 44 | 'sn': False #static nodes 45 | } 46 | SCENE_PRIORS_MATCHES = { 47 | 'd': 'detailed', 48 | 'c': 'coarse', 49 | } 50 | AGENT_PRIORS_MATCHES = { 51 | 'd': 'detailed', 52 | 'c': 'coarse', 53 | } 54 | PRIORS_NOISE_MATCHES = { 55 | 'za': 0.25, 56 | 'n': 0 57 | } 58 | PRIORS_SPARSITY_LEVEL_MATCHES = { 59 | 'za': 0.25, 60 | 'n': 0.25 61 | } 62 | MODEL_INCLUDE_TRANSFORMER_MATCHES = { 63 | 'it': True, 64 | 'et': False 65 | } 66 | 67 | ALL_EXP_NAMES = ['_'.join(l) for l in itertools.product(TASK_MATCHES.keys(), 68 | #SCENE_SIZE_MATCHES.keys(), 69 | SCENE_NODE_DYNAMICS_MATCHES.keys(), 70 | SCENE_PRIORS_MATCHES.keys(), 71 | AGENT_PRIORS_MATCHES.keys(), 72 | PRIORS_NOISE_MATCHES.keys(), 73 | MODEL_INCLUDE_TRANSFORMER_MATCHES.keys(), 74 | ['n','nwv','ntf', 'npp'])] 75 | 76 | def assign_new_value(cfg_dict, primary_key, cfg_key, value_match, secondary_key=None): 77 | if cfg_key not in value_match: 78 | raise ValueError("Incorrect config {}. Looking for one of {}".format(cfg_key, list(value_match.keys()))) 79 | if secondary_key is None: 80 | cfg_dict[primary_key] = value_match[cfg_key] 81 | else: 82 | if isinstance(cfg_dict[primary_key], list): 83 | dict_list = cfg_dict[primary_key] 84 | match_found = False 85 | for i, sub_dict in enumerate(dict_list): 86 | if secondary_key in sub_dict: 87 | cfg_dict[primary_key][i][secondary_key] = value_match[cfg_key] 88 | match_found = True 89 | if not match_found: 90 | raise ValueError("Couldn't find {} in {}".format(secondary_key, dict_list)) 91 | else: 92 | cfg_dict[primary_key][secondary_key] = value_match[cfg_key] 93 | 94 | if __name__ == "__main__": 95 | parser = argparse.ArgumentParser() 96 | parser.add_argument( 97 | "--exp_names_file", 98 | help="Path to text file containing the strings to be used to generate config files.", 99 | required=False, 100 | default=None 101 | ) 102 | parser.add_argument( 103 | "--print_formatted_strings", 104 | help="Prints the formatted strings of the generated config files.", 105 | action="store_true" 106 | ) 107 | parser.add_argument( 108 | "--exp_names", 109 | help="Comma separated experiment names. 
Defaults generated if none given", 110 | default=None 111 | ) 112 | args = parser.parse_args() 113 | if args.exp_names_file: 114 | with open(args.exp_names_file, "r") as f: 115 | filenames = f.readlines() 116 | elif args.exp_names: 117 | filenames = args.exp_names.split(',') 118 | else: 119 | filenames = ALL_EXP_NAMES 120 | for filename in tqdm(filenames): 121 | filename = filename.strip() 122 | if args.print_formatted_strings: 123 | print(filename) 124 | config_keys = filename.split("_") 125 | config_dict = copy.deepcopy(default_configs) 126 | task_type, node_dynamics, scene_priors, agent_priors, priors_noise, include_transformer, ablations = config_keys 127 | assign_new_value(config_dict, "defaults", task_type, TASK_MATCHES, "/task") 128 | #assign_new_value(config_dict, "defaults", scene_size, SCENE_SIZE_MATCHES, "/scene_gen") 129 | assign_new_value(config_dict, "scene_gen", node_dynamics, SCENE_NODE_DYNAMICS_MATCHES, "add_or_remove_objs") 130 | assign_new_value(config_dict, "scene_gen", scene_priors, SCENE_PRIORS_MATCHES, "scene_priors_type") 131 | assign_new_value(config_dict, "agents", agent_priors, AGENT_PRIORS_MATCHES, "agent_priors_type") 132 | assign_new_value(config_dict, "scene_gen", priors_noise, PRIORS_NOISE_MATCHES, "priors_noise") 133 | assign_new_value(config_dict, "scene_gen", priors_noise, PRIORS_SPARSITY_LEVEL_MATCHES, "priors_sparsity_level") 134 | assign_new_value(config_dict, "model", include_transformer, MODEL_INCLUDE_TRANSFORMER_MATCHES, "include_transformer") 135 | 136 | # Ablations 137 | if ablations == "npp": 138 | config_dict['agents']['sgm_use_priors'] = False 139 | elif ablations == "nwv": 140 | config_dict["model"]["node_features"] = ['time_since_observed', 'times_observed', 'time_since_state_change', 'state_change_freq', 'node_type'] 141 | config_dict["model"]["edge_features"] = ['time_since_observed', 'time_since_state_change', 'times_observed', 'times_state_true', 'last_observed_state', 'freq_true', 'edge_type'] 142 | elif ablations == "ntf": 143 | config_dict["model"]["node_features"] = ['text_embedding', 'node_type'] 144 | config_dict["model"]["edge_features"] = ['cosine_similarity', 'last_observed_state', 'freq_true', 'prior_prob', 'edge_type'] 145 | 146 | config_dict["run_name"] = filename 147 | config_dict["dataset"] = filename 148 | with open("configs/experiment/{}.yaml".format(filename), 'w') as file: 149 | file.write("# @package _global_\n") 150 | with open("configs/experiment/{}.yaml".format(filename), 'a') as file: 151 | documents = yaml.dump(config_dict, file, sort_keys=False) 152 | -------------------------------------------------------------------------------- /scripts/gen_prior_graphs.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import prior 3 | from memsearch.graphs import * 4 | from collections import defaultdict 5 | from shapely.geometry import Point 6 | from shapely.geometry.polygon import Polygon 7 | 8 | PROBS_FILE_PATH = 'priors/hardcoded_placement_probs.yaml' 9 | 10 | def camel_to_snake(s): 11 | return ''.join(['_'+c.lower() if c.isupper() else c for c in s]).lstrip('_') 12 | 13 | def gen_priors_graph(group_by_parent=True): 14 | # makes a hand-coded graph that is an example of a prior graph 15 | room_nodes = {} 16 | furniture_nodes = {} 17 | object_nodes = {} 18 | 19 | house_node = PriorsNode('house',NodeType.HOUSE,'house') 20 | floor_node1 = PriorsNode('floor1',NodeType.FLOOR,'floor') 21 | #floor_node2 = PriorsNode('floor2',NodeType.FLOOR) 22 | 23 | top_nodes = [house_node, 
floor_node1] 24 | Edge(house_node, floor_node1, EdgeType.CONTAINS, 1.0) 25 | 26 | room_furniture_probs = defaultdict(lambda: defaultdict(lambda: 0)) 27 | furniture_object_probs = defaultdict(lambda: defaultdict(lambda: 0)) 28 | room_furniture_object_probs = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: 0))) 29 | #forget about floor 2 for now 30 | #Edge(house_node,floor_node2,EdgeType.CONTAINS, 0.1).add_to_nodes() 31 | #Edge(floor_node1,floor_node2,EdgeType.CONNECTED, 1.0).add_to_nodes() 32 | 33 | # hard code this for now 34 | room_names = ['kitchen', 'living_room','bedroom', 'bathroom'] 35 | for room_name in room_names: 36 | room_nodes[room_name] = PriorsNode(room_name, NodeType.ROOM, room_name) 37 | Edge(floor_node1, room_nodes[room_name], EdgeType.CONTAINS, 1.0) 38 | 39 | # ugly piece of logic to take the old-format hardcoded file and load it into a prior graph 40 | with open(PROBS_FILE_PATH, 'r') as probs_file: 41 | probs = yaml.safe_load(probs_file) 42 | # ignore 43 | for obj_label in probs: 44 | obj_label_dict = probs[obj_label] 45 | for obj_instance in obj_label_dict: 46 | edges = obj_label_dict[obj_instance] 47 | for edge_name in edges: 48 | furniture_type, room_type, edge_type = edge_name.split('-') 49 | edge_prob = edges[edge_name] 50 | if room_type not in room_nodes: 51 | continue 52 | 53 | if group_by_parent: 54 | furniture_cat = furniture_type+'-'+room_type 55 | else: 56 | furniture_cat = furniture_type 57 | if furniture_cat not in furniture_nodes: 58 | furniture_nodes[furniture_cat] = PriorsNode(furniture_type, NodeType.FURNITURE, furniture_cat) 59 | 60 | furniture_node = furniture_nodes[furniture_cat] 61 | 62 | room_node = room_nodes[room_type] 63 | if not furniture_node.has_edges_to(room_node): 64 | Edge(room_node, furniture_node, EdgeType.CONTAINS, 0) 65 | room_furniture_probs[room_type][furniture_type] = edge_prob 66 | 67 | if group_by_parent: 68 | obj_cat = room_type+'-'+obj_label 69 | else: 70 | obj_cat = obj_label 71 | if obj_cat not in object_nodes: 72 | object_nodes[obj_cat] = PriorsNode(obj_label, NodeType.OBJECT, obj_cat) 73 | object_node = object_nodes[obj_cat] 74 | 75 | if not object_node.has_edges_to(furniture_node): 76 | edge_type_enum = RECIPROCAL_EDGE_TYPES[edge_enum_from_str(edge_type)] 77 | Edge(furniture_node, object_node, edge_type_enum, edge_prob) 78 | furniture_object_probs[furniture_type][obj_label] = edge_prob 79 | room_furniture_object_probs[room_type][furniture_type][obj_label] = edge_prob 80 | # don't care about other instances lol 81 | break 82 | 83 | print('Loading procthor-10k...') 84 | dataset = prior.load_dataset("procthor-10k")['train'] 85 | room_furniture_counts = defaultdict(lambda: defaultdict(lambda: 0)) 86 | furniture_object_counts = defaultdict(lambda: defaultdict(lambda: 0)) 87 | room_furniture_object_counts = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: 0))) 88 | 89 | for i in range(len(dataset)): 90 | house = dataset[i] 91 | room_polygons = {} 92 | rooms_json = house['rooms'] 93 | for i, room in enumerate(rooms_json): 94 | room_type = camel_to_snake(room['roomType']) 95 | polygon_indices = [] 96 | for polygon_element in room['floorPolygon']: 97 | polygon_indices.append((polygon_element['x'],polygon_element['z'])) 98 | polygon = Polygon(polygon_indices) 99 | room_polygons[room_type+'-'+str(i)] = polygon 100 | #x,y = polygon.exterior.xy 101 | #plt.plot(x,y) 102 | #plt.show() 103 | 104 | objects_json = house['objects'] 105 | for obj_json in objects_json: 106 | obj_type = 
camel_to_snake(obj_json['id'].split('|')[0]) 107 | obj_pos = obj_json['position'] 108 | obj_pos = Point(obj_pos['x'], obj_pos['z']) 109 | is_furniture = 'children' in obj_json 110 | 111 | for numbered_room_type in room_polygons: 112 | if room_polygons[numbered_room_type].contains(obj_pos): 113 | room_type = numbered_room_type.split('-')[0] 114 | if is_furniture: 115 | furniture_type = obj_type 116 | room_furniture_counts[room_type][furniture_type]+=1 117 | break 118 | 119 | if room_type not in room_names: 120 | continue 121 | 122 | if is_furniture: 123 | if group_by_parent: 124 | furniture_cat = furniture_type+'-'+room_type 125 | else: 126 | furniture_cat = furniture_type 127 | 128 | if furniture_cat not in furniture_nodes: 129 | furniture_nodes[furniture_cat] = PriorsNode(furniture_type, NodeType.FURNITURE, furniture_cat) 130 | furniture_node = furniture_nodes[furniture_cat] 131 | 132 | room_node = room_nodes[room_type] 133 | if not furniture_node.has_edges_to(room_node): 134 | Edge(room_node, furniture_node, EdgeType.CONTAINS, 0) 135 | 136 | object_counts = defaultdict(lambda: 0) 137 | for child_obj_json in obj_json['children']: 138 | obj_pos = child_obj_json['position'] 139 | obj_pos = Point(obj_pos['x'], obj_pos['z']) 140 | obj_type = camel_to_snake(child_obj_json['id'].split('|')[0]) 141 | object_counts[obj_type]+=1 142 | 143 | for obj_type,object_count in object_counts.items(): 144 | furniture_object_counts[furniture_type][obj_type]+=object_count 145 | room_furniture_object_counts[room_type][furniture_type][obj_type]+=object_count 146 | if group_by_parent: 147 | obj_cat = room_type+'-'+obj_type 148 | else: 149 | obj_cat = obj_type 150 | 151 | if obj_cat not in object_nodes: 152 | object_nodes[obj_cat] = PriorsNode(obj_type, NodeType.OBJECT, obj_cat) 153 | object_node = object_nodes[obj_cat] 154 | 155 | if not object_node.has_edges_to(furniture_node): 156 | Edge(furniture_node, object_node, EdgeType.UNDER, 0.0) 157 | else: 158 | pass 159 | 160 | for room_node in room_nodes.values(): 161 | room_type = room_node.label 162 | furniture_counts = room_furniture_counts[room_type] 163 | furniture_probs = room_furniture_probs[room_type] 164 | furniture_max_count = float(max(list(furniture_counts.values()))) 165 | for furniture_node in room_node.get_children_nodes(): 166 | edge = room_node.get_edge_to(furniture_node) 167 | count = float(furniture_counts[furniture_node.label]) 168 | ig_prob = furniture_probs[furniture_node.label] 169 | if count != 0: 170 | edge.prob = count/furniture_max_count 171 | else: 172 | edge.prob = ig_prob 173 | if group_by_parent: 174 | object_counts = room_furniture_object_counts[room_type][furniture_node.label] 175 | object_probs = room_furniture_object_counts[room_type][furniture_node.label] 176 | if len(object_counts.values()) != 0: 177 | objects_max_count = float(max(list(object_counts.values()))) 178 | else: 179 | objects_max_count = 1 180 | for object_node in furniture_node.get_children_nodes(): 181 | ig_prob = object_probs[object_node.label] 182 | count = float(object_counts[object_node.label]) 183 | edge = furniture_node.get_edge_to(object_node) 184 | if count != 0: 185 | edge.prob = count/objects_max_count 186 | else: 187 | edge.prob = ig_prob 188 | 189 | if not group_by_parent: 190 | for furniture_node in furniture_nodes.values(): 191 | furniture_type = furniture_node.label 192 | if furniture_type not in furniture_object_counts: 193 | continue 194 | object_counts = furniture_object_counts[furniture_type] 195 | object_probs = 
    for room_node in room_nodes.values():
        room_type = room_node.label
        furniture_counts = room_furniture_counts[room_type]
        furniture_probs = room_furniture_probs[room_type]
        if len(furniture_counts.values()) != 0:
            furniture_max_count = float(max(list(furniture_counts.values())))
        else:
            furniture_max_count = 1
        for furniture_node in room_node.get_children_nodes():
            edge = room_node.get_edge_to(furniture_node)
            count = float(furniture_counts[furniture_node.label])
            ig_prob = furniture_probs[furniture_node.label]
            if count != 0:
                edge.prob = count/furniture_max_count
            else:
                edge.prob = ig_prob
            if group_by_parent:
                object_counts = room_furniture_object_counts[room_type][furniture_node.label]
                object_probs = room_furniture_object_probs[room_type][furniture_node.label]
                if len(object_counts.values()) != 0:
                    objects_max_count = float(max(list(object_counts.values())))
                else:
                    objects_max_count = 1
                for object_node in furniture_node.get_children_nodes():
                    ig_prob = object_probs[object_node.label]
                    count = float(object_counts[object_node.label])
                    edge = furniture_node.get_edge_to(object_node)
                    if count != 0:
                        edge.prob = count/objects_max_count
                    else:
                        edge.prob = ig_prob

    if not group_by_parent:
        for furniture_node in furniture_nodes.values():
            furniture_type = furniture_node.label
            if furniture_type not in furniture_object_counts:
                continue
            object_counts = furniture_object_counts[furniture_type]
            object_probs = furniture_object_probs[furniture_type]
            objects_max_count = float(max(list(object_counts.values())))
            for object_type, count in object_counts.items():
                object_node = object_nodes[object_type]
                edge = furniture_node.get_edge_to(object_node)
                ig_prob = object_probs[object_type]
                if count != 0:
                    edge.prob = float(count)/objects_max_count
                else:
                    edge.prob = ig_prob

    '''
    low_edge_object_nodes = []
    for key,node in object_nodes.items():
        if len(node.get_edges_to_me()) < 3:
            low_edge_object_nodes.append(key)
            for parent_node in node.get_parent_nodes():
                parent_node.remove_edges_to(node)

    for key in low_edge_object_nodes:
        del object_nodes[key]
    '''

    # Prune furniture nodes with very few outgoing edges.
    low_edge_furniture_nodes = []
    for key, node in furniture_nodes.items():
        if len(node.get_edges_from_me()) < 3:
            low_edge_furniture_nodes.append(key)
            for parent_node in node.get_parent_nodes():
                parent_node.remove_edges_to(node)
            for child_node in node.get_children_nodes():
                child_node.remove_edges_to(node)

    for key in low_edge_furniture_nodes:
        del furniture_nodes[key]
    graph = PriorsGraph(top_nodes + \
                        list(room_nodes.values()) + \
                        list(furniture_nodes.values()) + \
                        list(object_nodes.values()))

    #for node in graph.nodes:
    #    node.normalize_edges_from_me()

    return graph

def enrich_graph_with_metadata(graph):
    with open('priors/object_metadata.yaml', 'r') as f:
        metadata = yaml.safe_load(f)

    nodes_to_add = []
    nodes_to_remove = []
    for node in graph.nodes:
        if node.label not in metadata:
            nodes_to_remove.append(node)
            continue
        node_dict = metadata[node.label]
        node.unique_id = node_dict['unique_id']
        node.label = node_dict['label']
        node.max_count = node_dict['max_count']
        node.move_freq = float(node_dict['move_freq'])
        node.sample_prob = float(node_dict['sample_prob'])
        if 'spawn_prob' in node_dict:
            node.spawn_prob = node_dict['spawn_prob']
            # removal is given the same rate as spawning
            node.remove_prob = node_dict['spawn_prob']
        else:
            node.spawn_prob = 0.0
            node.remove_prob = 0.0
        node.adjectives = node_dict['adjectives']
        if 'same_probs' not in node_dict:
            continue
        # Clone this node for every label that shares its probabilities.
        for new_label in node_dict['same_probs']:
            new_node = node.copy(with_unique_id=False, with_edge_copies=True, exclude_edges=True)
            new_node.label = new_label
            new_node.category = node.category.replace(node.label, new_label)
            new_node.unique_id = node.unique_id.replace(node.label, new_label)
            nodes_to_add.append(new_node)
    for node in nodes_to_remove:
        graph.remove_node(node)
    for node in nodes_to_add:
        graph.add_node(node)

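# For reference, enrich_graph_with_metadata expects each entry in
# priors/object_metadata.yaml to look roughly like this hypothetical sketch
# (field names are taken from the reads above; the values are invented):
#
#   apple:
#     unique_id: apple_0
#     label: apple
#     max_count: 3
#     move_freq: 0.2
#     sample_prob: 0.5
#     spawn_prob: 0.05      # optional; remove_prob reuses this value
#     adjectives: [red, green]
#     same_probs: [pear]    # optional; clones the node for each listed label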
if __name__ == "__main__":
    print('Generating coarse priors graph...')
    print('---------------------------------')
    graph = gen_priors_graph(group_by_parent=False)
    print('Coarse priors graph has %d nodes and %d edges'%(len(graph.nodes), len(graph.get_edges())))
    enrich_graph_with_metadata(graph)
    print('Coarse priors graph with metadata has %d nodes and %d edges\n'%(len(graph.nodes), len(graph.get_edges())))
    save_graph(graph, 'priors/coarse_prior_graph.pickle')
    graph.save_png('priors/coarse_priors_graph.png')

    print('Generating detailed priors graph...')
    print('---------------------------------')
    graph = gen_priors_graph(group_by_parent=True)
    print('Detailed priors graph has %d nodes and %d edges'%(len(graph.nodes), len(graph.get_edges())))
    enrich_graph_with_metadata(graph)
    print('Detailed priors graph with metadata has %d nodes and %d edges'%(len(graph.nodes), len(graph.get_edges())))
    save_graph(graph, 'priors/detailed_prior_graph.pickle')
    graph.save_png('priors/detailed_priors_graph.png')
--------------------------------------------------------------------------------
/scripts/igridson_video.py:
--------------------------------------------------------------------------------
import hydra
import memsearch
import numpy as np
from memsearch.igridson_utils import *
from mini_behavior.window import Window
from memsearch.tasks import TaskType, make_task
from tqdm import tqdm
import multiprocessing
import functools
import pickle
import matplotlib.pyplot as plt
from pathlib import Path
from enum import Enum

class VizMode(Enum):
    """
    Two visualization modes are supported:
    HEADLESS: the igridson environment will not be rendered at all.
        Use this when you don't want any visualization and only wish to calculate path lengths.
    RENDER: the igridson environment will be rendered and visualized in a pop-up window in real time.
        Frames will be saved to the specified directory.
    """
    HEADLESS = 0
    RENDER = 1

def eval_agent(cfg, viz_mode=VizMode.HEADLESS, agent_type=None):
    # scene sampler and evolver
    scene_sampler, scene_evolver = make_scene_sampler_and_evolver_(cfg.scene_gen)
    scene = scene_sampler.sample()
    agent = make_agent_type(cfg, agent_type, TaskType.FIND_OBJECT, scene_sampler, scene_evolver)
    # The task env manages the agent and makes predictions in a symbolic space using the scene graph
    task = make_task(scene_sampler, scene_evolver, TaskType.FIND_OBJECT, cfg.task.eps_per_scene)
    # The scene graph is shared with the igridson environment too
    env = SMGFixedEnv(scene_sampler, scene_evolver, scene=scene, env_evolve_freq=-1)

    if viz_mode == VizMode.RENDER:
        window = Window(f'Memory Object Search -- agent type {agent_type}')
    else:
        window = None
    # Jointly reset the task and the igridson env
    reset(env, window, task, agent)

    print("Starting exp", agent_type)

    avg_reward = 0.0
    num_successes = 0
    num_steps = 0

    all_rewards, all_steps = [], []

    for _ in tqdm(range(cfg.num_queries)):
        # Generate a query using the igridson env
        query = get_query(env)
        # Simulate the agent using the task env and get all the visited nodes
        pred_node, score, done, info = simulate_agent(agent, env, query, task, max_attempts=cfg.max_attempts)
        visited_nodes = info['visited_nodes']

        # Next, run the visited nodes through the igridson env to calculate the A* path length
        # and visualize if specified
        if viz_mode == VizMode.HEADLESS:
            curr_reward = get_astar_path(env, visited_nodes)
        else:
            curr_reward = get_astar_path(
                env,
                visited_nodes,
                window,
                save_dir='./outputs/igridson_simulations',
                exp_name=cfg.run_name
            )
        all_rewards.append(curr_reward)
        avg_reward += curr_reward

        num_steps += score
        all_steps.append(score)

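        # A score of cfg.max_attempts + 1 is the failure sentinel: the agent
        # exhausted its attempts without finding the object. For example, with
        # max_attempts = 10, a failed query records 11 steps, so any smaller
        # score counts as a success below.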
        if score != (cfg.max_attempts + 1):
            num_successes += 1

        # Jointly reset the igridson and task envs.
        # This will force an evolution of the scene graph
        reset(env, window, task, agent)

    num_steps /= cfg.num_queries
    avg_reward /= cfg.num_queries
    success_rate = num_successes / cfg.num_queries
    steps_std_dev = np.std(all_steps)
    reward_std_dev = np.std(all_rewards)

    # Plot and save histograms
    plt.suptitle(agent_type)
    plt.subplot(121)
    plt.hist(all_rewards)
    plt.title("Path Length {:.2f} +/- {:.2f}".format(avg_reward, reward_std_dev))

    plt.subplot(122)

    plt.hist(all_steps, np.arange(1, cfg.max_attempts, 1))
    plt.title("Num Attempts {:.2f} +/- {:.2f}".format(num_steps, steps_std_dev))
    plt.tight_layout(pad=2.0)

    save_p = Path(f"./outputs/{cfg.run_name}/plots")
    save_p.mkdir(parents=True, exist_ok=True)
    plt.savefig(str(save_p / f"{agent_type}.png"))

    return {agent_type: {"Cost": {"mean": avg_reward, "std": reward_std_dev, 'distribution': all_rewards},
                         "Num Steps": {"mean": num_steps, "std": steps_std_dev, 'distribution': all_steps},
                         "avg success rate": success_rate}}

@hydra.main(version_base=None, config_path=memsearch.CONFIG_PATH, config_name="config")
def main(cfg):
    agent_types = cfg.agents.agent_types
    if cfg.viz_mode == "headless":
        viz_mode = VizMode.HEADLESS
    elif cfg.viz_mode == "render":
        viz_mode = VizMode.RENDER
    else:
        raise ValueError(f"Unknown viz_mode: {cfg.viz_mode}")

    if viz_mode != VizMode.RENDER:
        # Headless evaluation can run one process per agent type
        multiprocessing.set_start_method('spawn')
        run_eval_agent = functools.partial(eval_agent, cfg, viz_mode)
        tqdm.set_lock(multiprocessing.RLock())
        with multiprocessing.Pool(processes=len(agent_types)) as pool:
            all_score_vecs = pool.map(run_eval_agent, agent_types)
    else:
        all_score_vecs = [eval_agent(cfg, viz_mode, agent_type) for agent_type in agent_types]

    # Flatten the per-agent result dicts into {agent_type: metrics}
    rearranged_dict = {name: scores for agent_d in all_score_vecs for name, scores in agent_d.items()}
    metrics = list(rearranged_dict[agent_types[0]].keys())

    for metric in metrics:
        print(metric)
        for agent_type in agent_types:
            metric_val = rearranged_dict[agent_type][metric]
            if metric == "avg success rate":
                print("Agent: {}, Success Rate: {}".format(agent_type, metric_val))
            else:
                mean, std = metric_val['mean'], metric_val['std']
                print("Agent: {}, {:.2f} +/- {:.2f}".format(agent_type, mean, std))

    with open('FINAL_RESULTS.pkl', 'wb') as f:
        pickle.dump(all_score_vecs, f)

if __name__ == '__main__':
    main()  # type: ignore
--------------------------------------------------------------------------------
/scripts/make_gifs.py:
--------------------------------------------------------------------------------
from PIL import Image
import os
from pygifsicle import optimize

# filepaths
types = ['evolver_probs_psg_gcn',
         'evolver_probs_psg_mlp',
         'evolver_probs_psg_heat',
         'edge_age_psg_mlp',
         'edge_age_psg_gcn',
         'edge_age_psg_heat',
         'state_psg_mlp',
         'state_psg_gcn',
         'state_psg_heat',
         'node_moves_psg_mlp',
         'node_moves_psg_gcn',
         'node_moves_psg_heat',
         'priors',
         # 'psg',
         'times_observed',
         'freq_true_MLP',
         'freq_true_GCN',
         'freq_true_HEAT',
         'times_true',
         'from_priors',
         'from_observed',
         'last_outputs_MLP',
         'last_outputs_GCN',
         'last_outputs_HEAT']

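# The sort keys below assume frame filenames like these (hypothetical examples):
#   graphs/evolver_probs_psg_gcn_12.png      -> key 12
#   graphs/scene_edge_age_psg_mlp_3_250.png  -> key 3*500 + 250
#   graphs/psg_priors_3_250.png              -> key 3*500 + 250
# i.e. the trailing integers are a scene index and a step index, with at most
# 500 frames per scene, so frames are ordered scene by scene.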
for t in types:
    print(t)
    if 'evolver_' in t:
        paths = sorted(['graphs/'+f for f in os.listdir('graphs') if f.startswith(t)],
                       key=lambda x: int(x.split('_')[-1].replace('.png','')))
    elif 'edge_age' in t or 'node_moves' in t or 'state' in t:
        paths = sorted(['graphs/'+f for f in os.listdir('graphs') if f.startswith('scene_'+t)],
                       key=lambda x: int(x.split('_')[-2])*500 + int(x.split('_')[-1].replace('.png','')))
    else:
        paths = sorted(['graphs/'+f for f in os.listdir('graphs') if f.startswith('psg_'+t)],
                       key=lambda x: int(x.split('_')[-2])*500 + int(x.split('_')[-1].replace('.png','')))
    imgs = (Image.open(f) for f in paths)
    img = next(imgs)  # extract first image from iterator
    fp_out = 'gifs/'+t+".gif"
    img.save(fp=fp_out, format='GIF', append_images=imgs,
             save_all=True, duration=40000/len(paths), loop=1)
    optimize(fp_out)
--------------------------------------------------------------------------------
/scripts/plot_eval_results.py:
--------------------------------------------------------------------------------
from pathlib import Path
import numpy as np
from cycler import cycler
from scripts.eval import plot_agent_eval
from memsearch.metrics import AvgAccuracy, AvgAUC, DiscSumOfRewards, PercentObjectsFound, PercObjectsFoundOverTime, AvgNumAttempts
import argparse
import matplotlib.pyplot as plt
from memsearch.tasks import TaskType

def is_agent_name(test_str):
    ignore_keys = ['experiment=iclr', 'image', 'hydra']
    for key in ignore_keys:
        if key in test_str:
            return False
    return True

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "experiment_name",
    )
    parser.add_argument(
        "--task",
        help="Name of the task to be evaluated.",
        default=None
    )
    parser.add_argument(
        "--metrics",
        default="all"
    )
    parser.add_argument(
        "--save_dir",
        help="Dir where generated plots are to be stored.",
        default=None
    )
    parser.add_argument(
        "--agent_order",
        help="Comma-separated agent names, in the order to be plotted",
        #default="random,memorization,counts,sgm_heat,priors,upper_bound"
        default="random,counts,priors,memorization,sgm_heat,upper_bound"
    )
    parser.add_argument(
        "--num_smoothing_steps",
        help="Number of steps to use for smoothing for line graphs",
        default=10
    )

    args = parser.parse_args()
    log_dir = Path(f'outputs/{args.experiment_name}/eval')
    if args.task is None:
        # Infer the task from the experiment name. Check 'pls' before 'pl',
        # since 'pl' is a substring of 'pls' and would otherwise always match first.
        if 'pls' in args.experiment_name:
            task = TaskType.PREDICT_LOC
        elif 'pl' in args.experiment_name:
            task = TaskType.PREDICT_LOC
        elif 'pd' in args.experiment_name:
            task = TaskType.PREDICT_ENV_DYNAMICS
        elif 'fo' in args.experiment_name:
            task = TaskType.FIND_OBJECT
        else:
            task = TaskType.PREDICT_LOC
    else:
        task = TaskType(args.task)
    if args.save_dir is None:
        save_dir = log_dir / 'images'
    else:
        save_dir = Path(args.save_dir)
    num_smoothing_steps = int(args.num_smoothing_steps)
    Path(save_dir).mkdir(parents=True, exist_ok=True)
    metrics = [
        AvgAccuracy('%s_avg_ac'%task.value, save_dir, ymax=0.6),
        AvgAUC('%s_avg_AuC'%task.value, save_dir),
        DiscSumOfRewards('%s_avg_DSoR'%task.value, save_dir)
    ]
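    # For line plots, plot_agent_eval smooths each per-step score vector with
    # a kernel of num_smoothing_steps. A minimal sketch of that kind of
    # smoothing (assuming a simple moving average; the actual implementation
    # lives in scripts/eval.py):
    #
    #   kernel = np.ones(num_smoothing_steps) / num_smoothing_steps
    #   smoothed = np.convolve(score_vec, kernel, mode='valid')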
"random,memorization,counts,sgm_heat,priors,upper_bound" 81 | plt.rc('axes', prop_cycle=(cycler('color', ['blue', 'red', 'orange', 'purple','green', 'pink']))) 82 | elif args.metrics == 'bar': 83 | metrics = [ 84 | AvgAUC('%s_avg_AuC'%task.value, save_dir), 85 | DiscSumOfRewards('%s_avg_DSoR'%task.value, save_dir) 86 | ] 87 | args.agent_order = "random,counts,priors,memorization,sgm_heat,upper_bound" 88 | plt.rc('axes', prop_cycle=(cycler('color', ['blue', 'orange', 'green', 'red', 'purple', 'pink']))) 89 | """ 90 | all_agent_names = args.agent_order.split(',') 91 | if task == 'find_object': 92 | metrics = [ 93 | AvgAccuracy('%s_%s_avg_num_steps'%(args.experiment_name, task), save_dir, ymin=1.0, ymax=12, use_std=False), 94 | PercentObjectsFound('%s_%s_perc_found'%(args.experiment_name, task), save_dir, top_k=10), 95 | PercObjectsFoundOverTime('%s_%s_perc_found_over_time'%(args.experiment_name, task), save_dir, top_k=10), 96 | AvgNumAttempts('%s_%s_avg_num_attempts'%(args.experiment_name, task), save_dir) 97 | ] 98 | 99 | for i, agent_name in enumerate(all_agent_names): 100 | agent_csv_path = log_dir / agent_name 101 | csv_path = agent_csv_path / "eval.csv" 102 | agent_name = agent_csv_path.stem 103 | f = open(str(csv_path.resolve()), "r") 104 | score_vecs = [np.array(l.split(','), dtype=np.float32) for l in f.readlines()[-99:]] 105 | num_scenes, num_steps = len(score_vecs), score_vecs[0].shape[0] 106 | save_figs = (i == len(all_agent_names) - 1) 107 | plot_agent_eval(num_steps, score_vecs, agent_name, i, metrics, 108 | smoothing_kernel_size=num_smoothing_steps, task=task, 109 | show_fig=False, save_fig=save_figs, x_labels=all_agent_names) 110 | -------------------------------------------------------------------------------- /scripts/print_experiment_results.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | if __name__ == "__main__": 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument("experiment_name") 6 | args = parser.parse_args() 7 | path_to_results = f'outputs/{args.experiment_name}/eval.log' 8 | 9 | experiment_settings = args.experiment_name.split('_') 10 | print(f'Results for experiment {args.experiment_name}') 11 | 12 | with open(path_to_results, 'r') as results_f: 13 | results_text = results_f.readlines() 14 | 15 | for log_line in results_text: 16 | log_line = log_line.strip() 17 | if 'for metric' in log_line: 18 | print(log_line.split(' - ')[1]) 19 | if 'score: ' in log_line: 20 | #if "+" in log_line: 21 | # print(log_line.split(': ')[1]) 22 | #else: 23 | print(log_line.split(': ')[1]) 24 | if 'upper_bound' in log_line: 25 | print('---') 26 | 27 | 28 | -------------------------------------------------------------------------------- /scripts/run_experiment.sh: -------------------------------------------------------------------------------- 1 | # Generate the data 2 | python scripts/collect_data.py experiment=$1 3 | 4 | # Train the models 5 | python scripts/train.py model=mlp experiment=$1 6 | python scripts/train.py model=gcn experiment=$1 7 | python scripts/train.py model=heat experiment=$1 8 | python scripts/train.py model=hgt experiment=$1 9 | python scripts/train.py model=han experiment=$1 10 | #python scripts/train.py model=hgcn experiment=$1 11 | 12 | # Evaluate the model 13 | python scripts/eval.py experiment=$1 14 | -------------------------------------------------------------------------------- /scripts/submit_slurm_job.sh: 
--------------------------------------------------------------------------------
/scripts/submit_slurm_job.sh:
--------------------------------------------------------------------------------
#!/bin/bash
export experiment=$1
echo "Submitting sbatch job for experiment $experiment"
sbatch --output=/cvgl2/u/andreyk/projects/memory_object_search/outputs/slurm/$experiment.out \
       --account viscam \
       --error=/cvgl2/u/andreyk/projects/memory_object_search/outputs/slurm/$experiment.err \
       --job-name=$experiment \
       slurm_run_experiment.sbatch
--------------------------------------------------------------------------------
/scripts/train.py:
--------------------------------------------------------------------------------
import hydra
import random
import memsearch
from memsearch.util import configure_logging
from memsearch.training import train, test
from memsearch.dataset import make_featurizers

def run_training(cfg):
    logger = configure_logging(name='train')
    node_featurizer, edge_featurizer = make_featurizers(cfg.model, False, cfg.task.num_steps)
    # seeding with a string is valid and deterministic
    random.seed('training seed')
    model = train(cfg=cfg,
                  num_epochs=cfg.num_train_epochs,
                  node_featurizer=node_featurizer,
                  edge_featurizer=edge_featurizer,
                  add_num_nodes=cfg.model.add_num_nodes,
                  add_num_edges=cfg.model.add_num_edges,
                  use_edge_weights=cfg.model.use_edge_weights,
                  num_labels_per_batch=cfg.train_labels_per_batch,
                  logger=logger)
    random.seed('testing seed')
    test(cfg,
         model=model,
         node_featurizer=node_featurizer,
         edge_featurizer=edge_featurizer,
         add_num_nodes=cfg.model.add_num_nodes,
         add_num_edges=cfg.model.add_num_edges,
         use_edge_weights=cfg.model.use_edge_weights,
         num_labels_per_batch=cfg.test_labels_per_batch,
         logger=logger)

@hydra.main(version_base=None,
            config_path=memsearch.CONFIG_PATH,
            config_name="config")
def main(cfg):
    run_training(cfg)

if __name__ == "__main__":
    main()  # type: ignore
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
import os
from setuptools import setup, find_packages


def read(fname):
    with open(os.path.join(os.path.dirname(__file__), fname)) as f:
        return f.read()


setup(
    name='memsearch',
    packages=find_packages(),
)
--------------------------------------------------------------------------------