├── tests └── __init__.py ├── src └── nett │ ├── utils │ ├── __init__.py │ ├── io.py │ ├── job.py │ ├── callbacks.py │ ├── logger.py │ ├── environment.py │ └── train.py │ ├── analysis │ ├── .Rapp.history │ ├── .gitattributes │ ├── ChickData │ │ ├── .DS_Store │ │ ├── viewinvariant.csv │ │ ├── parsing.csv │ │ └── binding.csv │ ├── NETT_train_viz.R │ ├── NETT_merge_csvs.R │ └── NETT_test_viz.R │ ├── brain │ ├── encoders │ │ ├── disembodied_models │ │ │ └── archs │ │ │ │ ├── __init__.py │ │ │ │ ├── resnet_1b.py │ │ │ │ ├── resnet_2b.py │ │ │ │ └── resnet_3b.py │ │ ├── __init__.py │ │ ├── frozensimclr.py │ │ ├── dinov1.py │ │ ├── vit.py │ │ ├── sam.py │ │ ├── dinov2.py │ │ ├── cnnlstm.py │ │ ├── resnet18.py │ │ └── resnet10.py │ └── __init__.py │ ├── _version.py │ ├── environment │ ├── __init__.py │ ├── builder.py │ └── configs.py │ ├── __init__.py │ └── body │ ├── __init__.py │ ├── wrappers │ └── dvs.py │ └── builder.py ├── docs ├── source │ ├── 2documentation │ │ ├── NETT.rst │ │ ├── body.rst │ │ ├── environment.rst │ │ ├── brain.rst │ │ └── index.rst │ ├── 3papers │ │ ├── index.rst │ │ ├── ViewInvariant.md │ │ └── Parsing.md │ ├── _static │ │ ├── images │ │ │ ├── parsing.png │ │ │ └── viewpoint.png │ │ ├── video │ │ │ ├── parsing.mp4 │ │ │ └── viewpoint.mp4 │ │ └── custom_styles.css │ ├── 1gettingstarted │ │ └── index.rst │ ├── index.rst │ └── conf.py ├── assets │ ├── images │ │ ├── banner.png │ │ ├── digital_twin.jpg │ │ └── digital_twin_cropped.jpg │ └── newbornembodied_0131_V1.docx ├── requirements.txt ├── Makefile ├── make.bat └── dev │ └── developer-notes.md ├── MANIFEST.in ├── .github ├── ISSUE_TEMPLATE │ ├── feature_request.md │ └── bug_report.md └── workflows │ ├── lint.yml │ └── docs.yml ├── examples ├── run │ ├── configuration │ │ ├── parsing.yaml │ │ ├── viewinvariant.yaml │ │ └── binding.yaml │ ├── wrapper │ │ └── dvs_wrapper.py │ └── run.py └── notebooks │ └── Getting Started.ipynb ├── LICENSE ├── pyproject.toml ├── scripts └── publish.sh ├── .gitignore └── README.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/nett/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/nett/analysis/.Rapp.history: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/nett/brain/encoders/disembodied_models/archs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/nett/_version.py: -------------------------------------------------------------------------------- 1 | """Version of the NETT library""" 2 | __version__ = "0.4.1" -------------------------------------------------------------------------------- /docs/source/2documentation/NETT.rst: -------------------------------------------------------------------------------- 1 | NETT 2 | ==== 3 | 4 | .. autoclass:: nett.NETT 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/2documentation/body.rst: -------------------------------------------------------------------------------- 1 | body 2 | ==== 3 | 4 | .. 
autoclass:: nett.Body 5 | :members: -------------------------------------------------------------------------------- /src/nett/analysis/.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /docs/source/3papers/index.rst: -------------------------------------------------------------------------------- 1 | Research Papers 2 | =============== 3 | 4 | .. toctree:: 5 | 6 | Parsing 7 | ViewInvariant -------------------------------------------------------------------------------- /docs/assets/images/banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buildingamind/NewbornEmbodiedTuringTest/HEAD/docs/assets/images/banner.png -------------------------------------------------------------------------------- /docs/source/2documentation/environment.rst: -------------------------------------------------------------------------------- 1 | environment 2 | =========== 3 | 4 | .. autoclass:: nett.Environment 5 | :members: 6 | -------------------------------------------------------------------------------- /src/nett/environment/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Initializes the environment module. 3 | """ 4 | from nett.environment.configs import * 5 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | # needed to include non-python files to installation 2 | include src/nett/analysis/*.R 3 | include src/nett/analysis/ChickData/*.csv 4 | -------------------------------------------------------------------------------- /docs/assets/images/digital_twin.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buildingamind/NewbornEmbodiedTuringTest/HEAD/docs/assets/images/digital_twin.jpg -------------------------------------------------------------------------------- /docs/source/_static/images/parsing.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buildingamind/NewbornEmbodiedTuringTest/HEAD/docs/source/_static/images/parsing.png -------------------------------------------------------------------------------- /docs/source/_static/video/parsing.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buildingamind/NewbornEmbodiedTuringTest/HEAD/docs/source/_static/video/parsing.mp4 -------------------------------------------------------------------------------- /docs/source/_static/video/viewpoint.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buildingamind/NewbornEmbodiedTuringTest/HEAD/docs/source/_static/video/viewpoint.mp4 -------------------------------------------------------------------------------- /src/nett/analysis/ChickData/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buildingamind/NewbornEmbodiedTuringTest/HEAD/src/nett/analysis/ChickData/.DS_Store -------------------------------------------------------------------------------- /docs/assets/newbornembodied_0131_V1.docx: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/buildingamind/NewbornEmbodiedTuringTest/HEAD/docs/assets/newbornembodied_0131_V1.docx -------------------------------------------------------------------------------- /docs/source/_static/images/viewpoint.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buildingamind/NewbornEmbodiedTuringTest/HEAD/docs/source/_static/images/viewpoint.png -------------------------------------------------------------------------------- /docs/assets/images/digital_twin_cropped.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buildingamind/NewbornEmbodiedTuringTest/HEAD/docs/assets/images/digital_twin_cropped.jpg -------------------------------------------------------------------------------- /docs/source/2documentation/brain.rst: -------------------------------------------------------------------------------- 1 | brain 2 | ===== 3 | 4 | .. autoclass:: nett.Brain 5 | :members: 6 | 7 | .. automodule:: nett.brain 8 | :members: 9 | -------------------------------------------------------------------------------- /docs/source/2documentation/index.rst: -------------------------------------------------------------------------------- 1 | Documentation 2 | ============= 3 | 4 | .. toctree:: 5 | :maxdepth: 4 6 | 7 | brain 8 | body 9 | environment 10 | NETT 11 | -------------------------------------------------------------------------------- /src/nett/analysis/ChickData/viewinvariant.csv: -------------------------------------------------------------------------------- 1 | test.cond,avg,std_dev,avg_dev,se 2 | non-matched,0.667702422,0.096817389,0.071405403,0.020187821 3 | matched,0.588036395,0.102649388,0.076107463,0.029632326 4 | -------------------------------------------------------------------------------- /docs/source/1gettingstarted/index.rst: -------------------------------------------------------------------------------- 1 | Getting Started 2 | =============== 3 | 4 | .. 
include:: ../../../README.md 5 | :parser: myst_parser.sphinx_ 6 | :start-after: 7 | :end-before: 8 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | mlagents 2 | stable-baselines3[extra] 3 | sb3-contrib 4 | torchvision 5 | timm 6 | nvidia-ml-py 7 | lightning 8 | lightning-bolts 9 | scikit-learn 10 | sphinx 11 | mkdocs 12 | sphinx-rtd-theme 13 | myst-parser -------------------------------------------------------------------------------- /src/nett/analysis/ChickData/parsing.csv: -------------------------------------------------------------------------------- 1 | test.cond,avg,std_dev,avg_dev,se 2 | Imprinted Object Familiar,0.652731096,0.118585046,0.091188669,0.021298503 3 | Novel Familiar,0.722057353,0.100726968,0.081165209,0.018091098 4 | Both Familiar,0.707892313,0.109402977,0.087105549,0.019649355 5 | Both Unfamiliar,0.725346606,0.101022594,0.084554842,0.018144193 6 | -------------------------------------------------------------------------------- /src/nett/brain/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | """Simplifies imports for encoders""" 2 | from .resnet18 import Resnet18CNN 3 | from .resnet10 import Resnet10CNN 4 | from .dinov1 import DinoV1 5 | from .dinov2 import DinoV2 6 | from .sam import SegmentAnything 7 | from .vit import ViT 8 | from .cnnlstm import CNNLSTM 9 | from .frozensimclr import FrozenSimCLR 10 | -------------------------------------------------------------------------------- /src/nett/analysis/ChickData/binding.csv: -------------------------------------------------------------------------------- 1 | test.cond,avg,std_dev,avg_dev,se 2 | 1color,0.658206456,0.116711762,0.081252495,0.0369075 3 | 1shape,0.610765667,0.084799737,0.062939971,0.026816031 4 | 1shape&color,0.709508299,0.113511354,0.07979121,0.035895442 5 | 2color,0.783648938,0.139327049,0.104022025,0.044059081 6 | 2shape,0.703164725,0.134004893,0.101561082,0.042376068 7 | 2shape&color,0.807412561,0.11440134,0.080392005,0.03617688 8 | binding,0.672718326,0.102088828,0.072680404,0.032283322 9 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. NETT documentation master file, created by 2 | sphinx-quickstart on Thu Feb 8 11:01:54 2024. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | =================================== 7 | Newborn Embodied Turing Test (NETT) 8 | =================================== 9 | 10 | .. toctree:: 11 | :titlesonly: 12 | :glob: 13 | 14 | */index 15 | 16 | Indices and tables 17 | ================== 18 | 19 | * :ref:`genindex` 20 | 21 | .. * :ref:`search` 22 | .. 
* :ref:`modindex` 23 | -------------------------------------------------------------------------------- /docs/source/_static/custom_styles.css: -------------------------------------------------------------------------------- 1 | /* custom_styles.css example to alter toctree list appearance */ 2 | 3 | .rst-content .toctree-wrapper .toctree-l1 { 4 | list-style-type: none; /* Remove bullets */ 5 | font-size: 22px; /* Larger text */ 6 | font-weight: bold; /* Bold text */ 7 | margin-left: 0; /* Remove indentation */ 8 | margin-bottom: 24px; /* Add space between items */ 9 | } 10 | 11 | .rst-content .toctree-wrapper .toctree-l2 { 12 | list-style: disc; /* Change bullets */ 13 | font-size: 16px; /* Smaller text */ 14 | font-weight: normal; /* Normal text */ 15 | /* margin-left: 24px; Add back indentation */ 16 | } 17 | 18 | 19 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
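# For example, "make html" is routed through this rule and runs
# "sphinx-build -M html source build", leaving the rendered site in build/html/.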
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /examples/run/configuration/parsing.yaml: -------------------------------------------------------------------------------- 1 | Brain: 2 | reward: supervised 3 | batch_size: 512 4 | buffer_size: 2048 5 | algorithm: PPO 6 | policy: CnnPolicy 7 | seed: 42 8 | encoder: 9 | train_encoder: True 10 | Body: 11 | type: basic 12 | dvs: True 13 | Environment: 14 | use_ship: True 15 | background: A 16 | executable_path: /home/mchivuku/projects/neurips/NewbornEmbodiedTuringTest/data/executables/parsing/parsing.x86_64 17 | record_agent: False 18 | record_chamber: False 19 | recording_frames: 1000 20 | Config: 21 | run_id: "parsing_ship_A_exp2" 22 | num_brains: 1 23 | mode: "full" 24 | train_eps: 1 25 | test_eps: 20 26 | output_dir: "/home/mchivuku/projects/neurips/NewbornEmbodiedTuringTest/data/runs/parsingdata" 27 | -------------------------------------------------------------------------------- /examples/run/configuration/viewinvariant.yaml: -------------------------------------------------------------------------------- 1 | Brain: 2 | reward: supervised 3 | batch_size: 512 4 | buffer_size: 2048 5 | algorithm: PPO 6 | policy: CnnPolicy 7 | seed: 42 8 | encoder: 9 | train_encoder: True 10 | Body: 11 | type: basic 12 | dvs: False 13 | Environment: 14 | use_ship: True 15 | side_view: False 16 | executable_path: '/home/mchivuku/projects/neurips/NewbornEmbodiedTuringTest/data/executables/viewinvariant/viewinvariant.x86_64' 17 | record_chamber: True 18 | record_agent: True 19 | recording_frames: 1000 20 | Config: 21 | run_id: "viewpt_ship_A_exp2" 22 | num_brains: 1 23 | mode: "full" 24 | train_eps: 1 25 | test_eps: 20 26 | output_dir: "/home/mchivuku/projects/neurips/NewbornEmbodiedTuringTest/data/runs/viewpt" 27 | 28 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. iOS] 28 | - Version [e.g. 22] 29 | 30 | **Additional context** 31 | Add any other context about the problem here. 
32 | -------------------------------------------------------------------------------- /examples/run/configuration/binding.yaml: -------------------------------------------------------------------------------- 1 | Brain: 2 | reward: supervised 3 | batch_size: 512 4 | buffer_size: 2048 5 | algorithm: PPO 6 | policy: CnnPolicy 7 | seed: 42 8 | encoder: large 9 | train_encoder: True 10 | Body: 11 | type: basic 12 | dvs: False 13 | Environment: 14 | # condition 15 | object: "object2" 16 | # generic config 17 | executable_path: /home/mchivuku/projects/embodied_pipeline/ICLR/newbornmain/data/executables/binding/binding.x86_64 18 | record_chamber: False 19 | record_agent: False 20 | recording_frames: 100 21 | Config: 22 | run_id: "binding_object1" 23 | num_brains: 1 24 | mode: "full" 25 | train_eps: 1000 26 | test_eps: 20 27 | output_dir: "/data/mchivuku/embodiedai/neurips_experiments/binding/base_agents_new/large/object2_exp4" 28 | 29 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /src/nett/utils/io.py: -------------------------------------------------------------------------------- 1 | """ 2 | io 3 | 4 | This module contains utility functions for input/output operations 5 | 6 | Functions: 7 | write_to_file: Write a dictionary to a file 8 | mute: Mute the standard output and standard error 9 | """ 10 | import os 11 | import sys 12 | import json 13 | 14 | def write_to_file(file_path, d) -> bool: 15 | """ 16 | write_to_file 17 | 18 | Write a dictionary to a file 19 | 20 | Args: 21 | file_path (str): The path to the file 22 | d (dict): The dictionary to write 23 | 24 | Returns: 25 | bool: True if the file was written, False otherwise 26 | """ 27 | with open(file_path, "w") as file: 28 | file.write(json.dumps(d)) 29 | return True 30 | 31 | def mute() -> None: 32 | """ 33 | mute 34 | 35 | Mute the standard output and standard error 36 | """ 37 | sys.stdout = open(os.devnull, "w") 38 | sys.stderr = open(os.devnull, "w") # TODO Should we be suppressing error messages? 
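
# Illustrative usage only (paths and values below are hypothetical):
#   write_to_file("/tmp/run_metadata.json", {"seed": 42})  # returns True, writes the dict as JSON text
#   mute()  # silences all further stdout/stderr output in the calling process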
39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | ======= 2 | MIT License 3 | 4 | Copyright (c) 2023 Building A Mind Lab 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. -------------------------------------------------------------------------------- /src/nett/utils/job.py: -------------------------------------------------------------------------------- 1 | """Job class for training and testing models""" 2 | from pathlib import Path 3 | 4 | class Job: 5 | """Holds information for a job 6 | 7 | Args: 8 | brain_id (int): id for the brain 9 | condition (str): condition for the job 10 | device (int): device to run the job on 11 | dir (Path): directory to store the job 12 | """ 13 | def __init__(self, brain_id: int, condition: str, device: int, dir: Path, index: int) -> None: 14 | """initialize job""" 15 | self.device: int = device 16 | self.condition: str = condition 17 | self.brain_id: int = brain_id 18 | self.dir: Path = dir 19 | self.paths: dict[str, Path] = self._configure_paths(brain_id, condition) 20 | self.index: int = index 21 | 22 | 23 | def _configure_paths(self, brain_id: int, condition: str) -> dict[str, Path]: 24 | """Configure Paths for the job 25 | 26 | Args: 27 | brain_id (int): id for the brain 28 | condition (str): condition for the job 29 | 30 | Returns: 31 | dict[str, Path]: dictionary of the paths 32 | """ 33 | SUBDIRS = ["model", "checkpoints", "plots", "logs", "env_recs", "env_logs"] 34 | job_dir = Path.joinpath(self.dir, condition, f"brain_{brain_id}") 35 | return {subdir: Path.joinpath(job_dir, subdir) for subdir in SUBDIRS} 36 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - dev 8 | pull_request: 9 | branches: 10 | - main 11 | - dev 12 | 13 | jobs: 14 | build: 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - name: Checkout code 19 | uses: actions/checkout@v4 20 | 21 | - name: Set up Python 22 | uses: actions/setup-python@v5 23 | with: 24 | python-version: "3.10.12" 25 | cache: "pip" # caching pip dependencies 26 | # Install your linters here 27 | 28 | # - name: Change setuptools version 29 | # run: pip install setuptools==65.5.0 30 | 31 | # - name: Change pip version 32 | # run: pip install 
pip==21 33 | 34 | - name: Install pylint 35 | run: pip install pylint 36 | 37 | # - name: Install dependencies 38 | # run: pip install -r requirements.txt 39 | 40 | - name: Run Linter 41 | uses: wearerequired/lint-action@v2 42 | with: 43 | pylint: true 44 | pylint_args: "--disable=all --enable=C0114 --enable=C0115 --enable=C0116 --enable=W0612 --enable=W0613" 45 | # checks for: 46 | ## missing module docstring, missing class docstring, missing function docstring, 47 | ## unused arguments, and unused variables 48 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | '''Configuration file for the Sphinx documentation builder.''' 2 | 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | 6 | import os 7 | import sys 8 | 9 | sys.path.insert(0, os.path.abspath('../../src/')) 10 | 11 | from nett import __version__ 12 | 13 | source_suffix = { 14 | '.rst': 'restructuredtext', 15 | '.md': 'markdown', 16 | } 17 | 18 | # -- Project information ----------------------------------------------------- 19 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 20 | 21 | project = 'NETT' 22 | copyright = '2024, Zachary Laborde' 23 | author = 'Zachary Laborde' 24 | release = __version__ 25 | 26 | # -- General configuration --------------------------------------------------- 27 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 28 | 29 | extensions = [ 30 | 'sphinx.ext.autodoc', 31 | 'sphinx.ext.napoleon', 32 | 'myst_parser', 33 | ] 34 | 35 | templates_path = ['_templates'] 36 | exclude_patterns = [] 37 | 38 | # -- Options for HTML output ------------------------------------------------- 39 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 40 | 41 | html_theme = "sphinx_rtd_theme" 42 | html_theme_options = { 43 | "body_max_width": "none", 44 | } 45 | 46 | html_static_path = ['_static'] 47 | 48 | html_baseurl = "/html/" 49 | 50 | # Example conf.py snippet 51 | html_css_files = [ 52 | 'custom_styles.css', # The name of your custom CSS file 53 | ] 54 | -------------------------------------------------------------------------------- /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: "Render Docs" 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - dev 8 | pull_request: 9 | branches: 10 | - main 11 | - dev 12 | 13 | jobs: 14 | docs: 15 | runs-on: ubuntu-latest 16 | permissions: 17 | contents: write 18 | steps: 19 | - name: Checkout 20 | uses: actions/checkout@v4 21 | 22 | - name: Set up Python 3.10.12 23 | uses: actions/setup-python@v5 24 | with: 25 | python-version: 3.10.12 26 | 27 | - name: Cache pip 28 | uses: actions/cache@v3 29 | id: cache-pip 30 | with: 31 | path: ${{ env.pythonLocation }} 32 | key: ${{ runner.os }}-pip-${{ hashFiles('docs/requirements.txt') }} 33 | restore-keys: | 34 | ${{ runner.os }}-pip 35 | - name: Install pip 36 | if: steps.cache-pip.outputs.cache-hit != 'true' 37 | run: | 38 | pip install setuptools==65.5.0 pip==21 39 | - name: Install Dependencies 40 | if: steps.cache-pip.outputs.cache-hit != 'true' 41 | run: | 42 | pip install -r docs/requirements.txt 43 | 44 | - name: Sphinx build 45 | run: | 46 | sphinx-build --jobs auto -b html docs/source/ docs/build/ 47 | 48 | - name: Deploy 
to GitHub Pages 49 | uses: peaceiris/actions-gh-pages@v3 50 | if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }} 51 | with: 52 | publish_branch: gh-pages 53 | github_token: ${{ secrets.GITHUB_TOKEN }} 54 | publish_dir: docs/build/ 55 | force_orphan: true 56 | -------------------------------------------------------------------------------- /src/nett/brain/encoders/frozensimclr.py: -------------------------------------------------------------------------------- 1 | """ 2 | Frozen SimCLR encoder for stable-baselines3 3 | 4 | This module provides a feature extractor based on the SimCLR model. It takes in observations from an environment and extracts features using the SimCLR model. 5 | """ 6 | from stable_baselines3.common.torch_layers import BaseFeaturesExtractor 7 | from nett.brain.encoders.disembodied_models.simclr import SimCLR 8 | 9 | import logging 10 | 11 | logger = logging.getLogger(__name__) 12 | logger.setLevel(logging.INFO) 13 | 14 | class FrozenSimCLR(BaseFeaturesExtractor): 15 | """ 16 | 17 | Frozen SimCLR encoder for stable-baselines3 18 | 19 | Args: 20 | observation_space (gym.spaces.Box): Observation space 21 | features_dim (int, optional): Output dimension of features extractor. Defaults to 512. 22 | checkpoint_path (str, optional): Path to the SimCLR checkpoint. Defaults to "simclr". 23 | """ 24 | 25 | def __init__(self, observation_space: "gym.spaces.Box", features_dim: int = 512, checkpoint_path: str = "simclr") -> None: 26 | super(FrozenSimCLR, self).__init__(observation_space, features_dim) 27 | self.n_input_channels = observation_space.shape[0] 28 | logger.info("FrozenSimCLR Encoder: ") 29 | logger.info(checkpoint_path) 30 | self.model = SimCLR.load_from_checkpoint(checkpoint_path) 31 | 32 | def forward(self, observations: "torch.Tensor") -> "torch.Tensor": 33 | """ 34 | Forward pass in the network 35 | 36 | Args: 37 | observations (torch.Tensor): input tensor 38 | 39 | Returns: 40 | torch.Tensor: output tensor 41 | """ 42 | return self.model(observations) 43 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools<=65.5.0", "pip<=21"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "nett-benchmarks" 7 | dynamic = ["version"] 8 | authors = [ 9 | { name="Bhargav Desai", email="desabh@iu.edu" }, 10 | { name="Zachary Laborde", email="zlaborde@iu.edu" }, 11 | { name="Manju Garimella", email="mchivuku@iu.edu" }, 12 | ] 13 | description = "A testbed for comparing the learning abilities of newborn animals and autonomous artificial agents." 
14 | readme = "README.md" 15 | license = {file = "LICENSE"} 16 | requires-python = "==3.10.12" 17 | dependencies = [ 18 | "mlagents==1.0.0", 19 | "stable-baselines3[extra]==1.8.0", 20 | "sb3-contrib==1.8.0", 21 | "torchvision", 22 | "timm", 23 | "nvidia-ml-py", 24 | "lightning==2.2.5", 25 | "lightning-bolts==0.7.0", 26 | "scikit-learn==1.5.0" 27 | ] 28 | keywords = ["nett","netts","newborn","embodied","turing test","benchmark","benchmarking","learning","animals","autonomous","artificial","agents","reinforcement","neuroml","AI","ML","machine learning","artificial intelligence"] 29 | classifiers = [ 30 | "Programming Language :: Python :: 3.10", 31 | "Programming Language :: R", 32 | "Environment :: GPU :: NVIDIA CUDA" 33 | ] 34 | 35 | [project.optional-dependencies] 36 | notebook = ["ipywidgets"] 37 | 38 | [project.urls] 39 | "Homepage" = "https://github.com/buildingamind/NewbornEmbodiedTuringTest" 40 | "Bug Tracker" = "https://github.com/buildingamind/NewbornEmbodiedTuringTest/issues" 41 | "Documentation" = "https://buildingamind.github.io/NewbornEmbodiedTuringTest/index.html" 42 | 43 | [tool.setuptools.dynamic] 44 | version = { attr = "nett._version.__version__" } 45 | 46 | [tool.setuptools.packages.find] 47 | where = ["src"] 48 | 49 | -------------------------------------------------------------------------------- /src/nett/brain/encoders/dinov1.py: -------------------------------------------------------------------------------- 1 | """DINO (Emerging Properties in Self-Supervised Vision Transformers) model""" 2 | import gym 3 | import torch 4 | import timm 5 | 6 | from torchvision.transforms import Compose, Resize, CenterCrop, Normalize, InterpolationMode 7 | from stable_baselines3.common.torch_layers import BaseFeaturesExtractor 8 | 9 | class DinoV1(BaseFeaturesExtractor): 10 | """ 11 | Initialize DinoV1 feature extractor. 12 | 13 | Args: 14 | observation_space (gym.spaces.Box): The observation space of the environment. 15 | features_dim (int, optional): Number of features extracted. This corresponds to the number of units for the last layer. Defaults to 384. 16 | """ 17 | def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 384) -> None: 18 | """Constructor method 19 | """ 20 | super(DinoV1, self).__init__(observation_space, features_dim) 21 | self.n_input_channels = observation_space.shape[0] 22 | self.transforms = Compose([Resize(size=248, 23 | interpolation=InterpolationMode.BICUBIC, 24 | max_size=None, 25 | antialias=True), 26 | CenterCrop(size=(224, 224)), 27 | Normalize(mean=torch.tensor([0.4850, 0.4560, 0.4060]), 28 | std=torch.tensor([0.2290, 0.2240, 0.2250]))]) 29 | self.model = timm.create_model('vit_small_patch8_224.dino', 30 | in_chans=self.n_input_channels, 31 | num_classes=0, 32 | pretrained=True) 33 | 34 | def forward(self, observations: torch.Tensor) -> torch.Tensor: 35 | """Forward pass of the DinoV1 model.""" 36 | return self.model(self.transforms(observations)) 37 | -------------------------------------------------------------------------------- /docs/dev/developer-notes.md: -------------------------------------------------------------------------------- 1 | # Developer Notes 2 | 3 | ## Introduction 4 | 5 | This document serves as a guide to the development process of the NETT toolkit. It provides an overview of the toolkit's development workflow. 
6 | 7 | ## Online Documentation Website 8 | 9 | The documentation website is hosted on GitHub Pages and is available at [https://buildingamind.github.io/NewbornEmbodiedTuringTest/](https://buildingamind.github.io/NewbornEmbodiedTuringTest/). It grabs the latest documentation from the `docs` folder in the repository. It creates the documentation using Sphinx. The documentation is written in reStructuredText format and is located in the `docs/source` folder. 10 | 11 | All assets used by Sphinx are located in the `docs/source/_static` folder. The `index.rst` file in the `docs/source` folder is the main entry point for the documentation. Each section of the documentation is pulled in from the `index.rst` file in each subdirectory in `docs/source`. 12 | 13 | The website is updated automatically when changes are pushed to the `main` branch. To see how this process works, please see the `.github/workflows/docs.yml` file. 14 | 15 | The configuration for the website is defined in `conf.py`. The website uses the `ReadTheDocs` theme. It is possible to change the theme by modifying the `html_theme` variable in `conf.py`. The website currently uses three extensions: `sphinx.ext.autodoc`, `sphinx.ext.napoleon`, and `myst-parser`. The `autodoc` extension is used to automatically generate documentation from the docstrings in the source code. The `napoleon` extension is used to parse the Google-style docstrings. The `myst-parser` extension is used to parse markdown files. 16 | 17 | To build the documentation locally, you can run the following command: 18 | 19 | ```bash 20 | sphinx-build -M html docs/source/ docs/build/ 21 | ``` 22 | 23 | Subsequent builds can be done by running the following commands: 24 | ```bash 25 | cd docs 26 | make html 27 | ``` 28 | 29 | The documentation can be viewed by opening the `index.html` file in the `docs/build` folder in a web browser. 30 | -------------------------------------------------------------------------------- /src/nett/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Wood Lab, Indiana University Bloomington. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
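
# Minimal usage sketch of the objects exported below. Constructor/run keyword
# names are taken from examples/run/configuration/*.yaml; treat the exact
# signatures as assumptions rather than a verified API:
#
#   from nett import Brain, Body, Environment, NETT
#   brain = Brain(policy="CnnPolicy", algorithm="PPO", reward="supervised", seed=42)
#   body = Body(type="basic", dvs=False)
#   env = Environment(executable_path="path/to/executable.x86_64")
#   NETT(brain=brain, body=body, environment=env).run(num_brains=1, train_eps=1000, test_eps=20, output_dir="path/to/output")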
14 | 15 | ''' 16 | Initialize the NETT library 17 | ''' 18 | 19 | import os 20 | import stat 21 | import logging 22 | from pathlib import Path 23 | # simplify imports 24 | from nett.brain.builder import Brain 25 | from nett.body.builder import Body 26 | from nett.environment.builder import Environment 27 | from nett.nett import NETT 28 | 29 | from nett.brain import list_encoders, list_algorithms, list_policies 30 | from nett.environment import list_configs 31 | 32 | # release version 33 | from ._version import __version__ 34 | 35 | # change permissions of the ml-agents binaries directory 36 | 37 | # path to store library cache (such as configs etc) 38 | cache_dir = Path.joinpath(Path.home(), ".cache", "nett") 39 | 40 | # set up logging 41 | logging.basicConfig(format="[%(name)s] %(levelname)s: %(message)s", level=logging.INFO) 42 | logger = logging.getLogger(__name__) 43 | 44 | # path to store ml-agents binaries 45 | for tmp_dir in ["/tmp/ml-agents-binaries", "/tmp/ml-agents-binaries/binaries", "/tmp/ml-agents-binaries/tmp"]: 46 | # TODO: May need to allow for permissions other than X777 47 | if stat.S_IMODE(os.stat(tmp_dir).st_mode) % 0o1000 != 0o777: 48 | # TODO: May need to check for permissions other than W_OK 49 | if os.stat(tmp_dir).st_uid == os.getuid() or os.access(tmp_dir, os.W_OK): 50 | os.chmod(tmp_dir, 0o1777) 51 | else: 52 | logger.warning(f"You do not have permission to change the necessary files in '{tmp_dir}'.") -------------------------------------------------------------------------------- /src/nett/utils/callbacks.py: -------------------------------------------------------------------------------- 1 | """ 2 | Callbacks for training the agents. 3 | 4 | Classes: 5 | HParamCallback(BaseCallback) 6 | """ 7 | from pathlib import Path 8 | from tqdm import tqdm 9 | import numpy as np 10 | from stable_baselines3.common.results_plotter import load_results, ts2xy 11 | from stable_baselines3.common.callbacks import BaseCallback, ProgressBarCallback 12 | from stable_baselines3.common.logger import HParam 13 | 14 | from nett.utils.train import compute_train_performance 15 | 16 | # TODO (v0.4): refactor needed, especially logging 17 | class HParamCallback(BaseCallback): 18 | """ 19 | Saves the hyperparameters and metrics at the start of the training, and logs them to TensorBoard. 20 | """ 21 | def _on_training_start(self) -> None: 22 | hparam_dict = { 23 | "algorithm": self.model.__class__.__name__, 24 | "learning rate": self.model.learning_rate, 25 | "gamma": self.model.gamma, 26 | "batch_size": self.model.batch_size, 27 | "n_steps": self.model.n_steps 28 | } 29 | # define the metrics that will appear in the `HPARAMS` Tensorboard tab by referencing their tag 30 | # Tensorbaord will find & display metrics from the `SCALARS` tab 31 | metric_dict = { 32 | "rollout/ep_len_mean": 0, 33 | "train/value_loss": 0.0, 34 | } 35 | self.logger.record( 36 | "hparams", 37 | HParam(hparam_dict, metric_dict), 38 | exclude=("stdout", "log", "json", "csv"), 39 | ) 40 | 41 | def _on_step(self) -> bool: 42 | return True 43 | 44 | class multiBarCallback(ProgressBarCallback): 45 | """ 46 | Display a progress bar when training SB3 agent 47 | using tqdm and rich packages. 
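    The `index` argument is used as the tqdm `position`, so each brain/job gets its
    own row and several progress bars can update in parallel without overwriting
    one another.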
48 | """ 49 | 50 | def __init__(self, index) -> None: #, num_steps 51 | super().__init__() 52 | self.index = index 53 | 54 | def _on_training_start(self) -> None: 55 | # Initialize progress bar 56 | # Remove timesteps that were done in previous training sessions 57 | self.pbar = tqdm(total=self.model.n_steps, position=self.index) 58 | # self.pbar = tqdm(total=self.locals["total_timesteps"] - self.model.num_timesteps, position=self.index) -------------------------------------------------------------------------------- /src/nett/brain/encoders/vit.py: -------------------------------------------------------------------------------- 1 | """ViT (Vision Transformer) encoder""" 2 | import gym 3 | import torch 4 | import timm 5 | 6 | from torchvision.transforms import Compose 7 | from torchvision.transforms import Resize, CenterCrop, Normalize, InterpolationMode 8 | from stable_baselines3.common.torch_layers import BaseFeaturesExtractor 9 | 10 | class ViT(BaseFeaturesExtractor): 11 | """ 12 | ViT is a feature extractor based on the Vision Transformer model. 13 | 14 | Args: 15 | observation_space (gym.spaces.Box): The observation space of the environment. 16 | features_dim (int, optional): The dimension of the extracted features. Defaults to 384. 17 | """ 18 | def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 384) -> None: 19 | """ 20 | Initializes the ViT (Vision Transformer) encoder. 21 | 22 | Args: 23 | observation_space (gym.spaces.Box): The observation space of the environment. 24 | features_dim (int, optional): The dimension of the extracted features. Defaults to 384. 25 | """ 26 | super(ViT, self).__init__(observation_space, features_dim) 27 | self.n_input_channels = observation_space.shape[0] 28 | self.transforms = Compose([Resize(size=248, 29 | interpolation=InterpolationMode.BICUBIC, 30 | max_size=None, 31 | antialias=True), 32 | CenterCrop(size=(224, 224)), 33 | Normalize(mean=torch.tensor([0.4850, 0.4560, 0.4060]), 34 | std=torch.tensor([0.2290, 0.2240, 0.2250]))]) 35 | 36 | self.model = timm.create_model("vit_small_patch8_224.dino", 37 | in_chans=self.n_input_channels, 38 | num_classes=0, 39 | pretrained=False) 40 | 41 | def forward(self, observations: torch.Tensor) -> torch.Tensor: 42 | """ 43 | Forward pass of the ViT encoder. 44 | 45 | Args: 46 | observations (torch.Tensor): The input observations. 47 | 48 | Returns: 49 | torch.Tensor: The extracted features. 50 | """ 51 | return self.model(self.transforms(observations)) 52 | -------------------------------------------------------------------------------- /src/nett/utils/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Logger 3 | 4 | Classes: 5 | Logger 6 | """ 7 | 8 | #Packages for making the environment 9 | import uuid #needed for the communicator 10 | import os #Files and directories 11 | 12 | from mlagents_envs.side_channel.side_channel import ( 13 | SideChannel, 14 | IncomingMessage, 15 | OutgoingMessage, 16 | ) 17 | 18 | # Create the StringLogChannel class. 19 | class Logger(SideChannel): 20 | """ 21 | This class is used to log information from the environment to a file. This is how logging info is communicated between python and unity. It is a subclass of the SideChannel class from the mlagents_envs package. 22 | 23 | Methods: 24 | on_message_received: Method from Sidechannel interface. This method gets a message from unity and writes it to the log file. 25 | send_string: Method from Sidechannel interface. This method send a message to unity. 
26 | log_str: This method is used to log a string to the file. 27 | """ 28 | def __init__(self, log_title, log_dir="./EnvLogs/") -> None: 29 | super().__init__(uuid.UUID("621f0a70-4f87-11ea-a6bf-784f4387d1f7")) # TODO why this UUID? 30 | if not os.path.exists(log_dir): 31 | os.makedirs(log_dir) 32 | self.log_dir = log_dir 33 | f_name = os.path.join(log_dir, f"{log_title}.csv") 34 | self.f = open(f_name, "w") 35 | 36 | def on_message_received(self, msg: IncomingMessage) -> None: 37 | """Method from Sidechannel interface. This method gets a message from unity and writes it to the log file.""" 38 | self.f.write(msg.read_string()) #Write message to log file 39 | self.f.write("\n") #add new line character 40 | 41 | #This is here because it is required and I currently don't use it. 42 | def send_string(self, data: str) -> None: 43 | """Method from Sidechannel interface. This method send a message to unity.""" 44 | msg = OutgoingMessage() 45 | msg.write_string(data) 46 | # We call this method to queue the data we want to send 47 | super().queue_message_to_send(msg) 48 | 49 | def log_str(self, msg: str) -> None: 50 | """This method is used to log a string to the file.""" 51 | self.f.write(msg) 52 | self.f.write("\n") 53 | 54 | def __del__(self) -> None: 55 | """This is called when the environment is shut down. It closes the log file.""" 56 | self.f.close() 57 | -------------------------------------------------------------------------------- /src/nett/brain/encoders/sam.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains the implementation of the SegmentAnything class, which is a custom feature extractor 3 | for image observations in a gym environment. It uses a pre-trained model from the timm library to extract 4 | features from the input images. 5 | """ 6 | import gym 7 | 8 | import torch as th 9 | import timm 10 | from torchvision.transforms import Compose 11 | from torchvision.transforms import Resize, CenterCrop, Normalize, InterpolationMode 12 | 13 | from stable_baselines3.common.torch_layers import BaseFeaturesExtractor 14 | 15 | class SegmentAnything(BaseFeaturesExtractor): 16 | """ 17 | Custom feature extractor for image observations in a gym environment. 18 | 19 | Args: 20 | observation_space (gym.spaces.Box): The observation space of the environment. 21 | features_dim (int, optional): Number of features extracted. This corresponds to the number of units for the last layer. Defaults to 384. 22 | """ 23 | 24 | def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 384) -> None: 25 | super(SegmentAnything, self).__init__(observation_space, features_dim) 26 | self.n_input_channels = observation_space.shape[0] 27 | self.transforms = Compose([Resize(size=256, 28 | interpolation=InterpolationMode.BICUBIC, 29 | max_size=None, 30 | antialias=True), 31 | CenterCrop(size=(224, 224)), 32 | Normalize(mean=th.tensor([0.485, 0.456, 0.406]), 33 | std=th.tensor([0.229, 0.224, 0.225]))]) 34 | 35 | n_input_channels = observation_space.shape[0] 36 | print("N_input_channels", n_input_channels) 37 | 38 | self.model = timm.create_model("samvit_base_patch16.sa1b", pretrained=True, 39 | num_classes=0) # remove classifier th.nn.Linear) 40 | 41 | def forward(self, observations: th.Tensor) -> th.Tensor: 42 | """ 43 | Forward pass of the feature extractor. 
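
        Args:
            observations (th.Tensor): Batch of image observations from the environment.

        Returns:
            th.Tensor: Feature embeddings from the pre-trained SAM ViT backbone.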
44 | """ 45 | # Cut off image 46 | # reshape to from vector to W*H 47 | # gray to color transform 48 | # application of ResNet 49 | # Concat features to the rest of observation vector 50 | # return 51 | return self.model(self.transforms(observations)) 52 | -------------------------------------------------------------------------------- /src/nett/brain/encoders/dinov2.py: -------------------------------------------------------------------------------- 1 | """DinoV2 feature extractor 2 | 3 | This module provides a feature extractor based on the DINOv2 model. It takes in observations from an environment and extracts features using the DINOv2 model. 4 | 5 | Example: 6 | 7 | >>> observation_space = gym.spaces.Box(low=0, high=255, shape=(3, 84, 84), dtype=np.uint8) 8 | >>> features_dim = 384 9 | >>> extractor = DinoV2(observation_space, features_dim) 10 | >>> observations = torch.randn(1, 3, 84, 84) 11 | >>> features = extractor.forward(observations) 12 | 13 | """ 14 | 15 | import gym 16 | import torch 17 | 18 | from torchvision.transforms import Compose, Resize, CenterCrop, Normalize, InterpolationMode 19 | from stable_baselines3.common.torch_layers import BaseFeaturesExtractor 20 | 21 | class DinoV2(BaseFeaturesExtractor): 22 | """ 23 | DinoV2 is a feature extractor based on the DINOv2 model. 24 | 25 | Args: 26 | observation_space (gym.spaces.Box): The observation space of the environment. 27 | features_dim (int, optional): Number of features extracted. This corresponds to the number of units for the last layer. Defaults to 384. 28 | 29 | Attributes: 30 | n_input_channels (int): Number of input channels in the observation space. 31 | transforms (torchvision.transforms.Compose): Preprocessing transforms applied to the input observations. 32 | model (torch.nn.Module): DINOv2 model loaded from the Facebook Research hub. 33 | """ 34 | def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 384) -> None: 35 | super(DinoV2, self).__init__(observation_space, features_dim) 36 | """Constructor method""" 37 | self.n_input_channels = observation_space.shape[0] 38 | self.transforms = Compose([Resize(size=256, 39 | interpolation=InterpolationMode.BICUBIC, 40 | max_size=None, 41 | antialias=True), 42 | CenterCrop(size=(224, 224)), 43 | Normalize(mean=torch.tensor([0.485, 0.456, 0.406]), 44 | std=torch.tensor([0.229, 0.224, 0.225]))]) 45 | self.model = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14", pretrained=True) 46 | 47 | def forward(self, observations: torch.Tensor) -> torch.Tensor: 48 | """Forward pass of the DinoV2 feature extractor.""" 49 | return self.model(self.transforms(observations)) 50 | -------------------------------------------------------------------------------- /scripts/publish.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Get the directory where the script is located 4 | SCRIPT_DIR=$(cd $(dirname "${BASH_SOURCE[0]}") && pwd) 5 | 6 | # Change directory to one level up from the script location and then into src/nett 7 | cd "$SCRIPT_DIR/../src/nett/" 8 | 9 | # Define the version variable by reading the version from the version.py file 10 | VERSION=$(python3 -c "from _version import __version__; print(__version__)") 11 | 12 | # Verify that VERSION is retrieved successfully 13 | if [[ -z "$VERSION" ]]; then 14 | echo "Failed to retrieve VERSION. Exiting..." 15 | exit 1 16 | fi 17 | 18 | # Confirm with the user 19 | echo -e "The retrieved version is \033[1m$VERSION\033[0m. Do you want to continue? 
(yes/no)" 20 | read user_confirmation 21 | if [[ "$user_confirmation" != "yes" ]]; then 22 | echo "User cancelled the operation. Exiting..." 23 | exit 1 24 | fi 25 | 26 | echo -e "\033[1m Publishing version $VERSION... \033[0m." 27 | 28 | # Return to base directory 29 | cd "$SCRIPT_DIR/../" 30 | 31 | echo -e "\033[1m Updating build package... \033[0m" 32 | 33 | python3 -m pip install --upgrade build 34 | 35 | echo -e "\033[1m Updating Twine package... \033[0m" 36 | 37 | python3 -m pip install --upgrade twine 38 | 39 | # Switch to the main branch and pull the latest changes 40 | git checkout main 41 | git pull origin main 42 | 43 | # Remove all files inside dist/ 44 | if [ "$(ls -A dist/)" ]; then 45 | rm -rf dist/* 46 | echo "Files and folders in dist/ have been deleted." 47 | else 48 | echo "dist/ is already empty." 49 | fi 50 | 51 | echo -e "\033[1m Tagging... \033[0m" 52 | # Tag the latest commit with the version number 53 | git tag v$VERSION 54 | 55 | # Push the tag to the remote repository 56 | git push origin v$VERSION 57 | 58 | echo -e "\033[1m Creating Release Branch... \033[0m" 59 | 60 | # Check if the branch exists and check it out if it does, otherwise create and push it 61 | if git show-ref --quiet refs/heads/release/$VERSION; then 62 | echo "Branch release/$VERSION exists. Checking it out..." 63 | git checkout release/$VERSION 64 | else 65 | echo "Branch release/$VERSION does not exist. Creating and pushing it..." 66 | git checkout -b release/$VERSION 67 | git push origin release/$VERSION 68 | fi 69 | 70 | echo -e "\033[1m Installing Dependencies... \033[0m" 71 | 72 | pip install -r ./docs/requirements.txt 73 | 74 | echo -e "\033[1m Building Distribution Archives... \033[0m" 75 | 76 | python3 -m build 77 | 78 | echo -e "\033[1m Uploading to PyPI... \033[0m" 79 | 80 | python3 -m twine upload --repository pypi dist/* 81 | -------------------------------------------------------------------------------- /src/nett/utils/environment.py: -------------------------------------------------------------------------------- 1 | """ 2 | environment.py 3 | 4 | Classes: 5 | Logger(SideChannel) 6 | 7 | Functions: 8 | port_in_use(port: int) -> bool 9 | """ 10 | 11 | import uuid 12 | import os 13 | import socket 14 | from mlagents_envs.side_channel.side_channel import ( 15 | SideChannel, 16 | IncomingMessage, 17 | OutgoingMessage, 18 | ) 19 | 20 | 21 | def port_in_use(port) -> bool: 22 | """This function checks if a port is in use. It returns True if the port is in use and False if it is not.""" 23 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 24 | try: 25 | sock.bind(("localhost", port)) 26 | except socket.error: 27 | return True 28 | return False 29 | 30 | # Create the StringLogChannel class. This is how logging info is communicated between python and unity 31 | class Logger(SideChannel): 32 | """ 33 | This class is used to log information from the environment to a file. It inherits from the SideChannel class 34 | 35 | Methods: 36 | on_message_received(msg: IncomingMessage) -> None 37 | send_string(data: str) -> None 38 | log_str(msg: str) -> None 39 | __del__() -> None 40 | """ 41 | def __init__(self, log_title, log_dir="./EnvLogs/") -> None: 42 | super().__init__(uuid.UUID("621f0a70-4f87-11ea-a6bf-784f4387d1f7")) 43 | if not os.path.exists(log_dir): 44 | os.makedirs(log_dir) 45 | self.log_dir = log_dir 46 | f_name = os.path.join(log_dir, f"{log_title}.csv") 47 | self.f = open(f_name, "w") 48 | 49 | #Method from Sidechannel interface. 
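    # Each string received from Unity is written out as a new row of the CSV log.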
50 | def on_message_received(self, msg: IncomingMessage) -> None: 51 | """This method is called when a message is received from unity.""" 52 | self.f.write(msg.read_string()) #Write message to log file 53 | self.f.write("\n") #add new line character 54 | 55 | #This is here because it is required and I currently don"t use it. 56 | def send_string(self, data: str) -> None: 57 | """Method from Sidechannel interface. This method send a message to unity.""" 58 | msg = OutgoingMessage() 59 | msg.write_string(data) 60 | # We call this method to queue the data we want to send 61 | super().queue_message_to_send(msg) 62 | 63 | def log_str(self, msg: str) -> None: 64 | """This method writes a custom string to the log file""" 65 | self.f.write(msg) 66 | self.f.write("\n") 67 | 68 | def __del__(self) -> None: 69 | """This is called when the environment is shut down""" 70 | self.f.close() 71 | -------------------------------------------------------------------------------- /src/nett/brain/__init__.py: -------------------------------------------------------------------------------- 1 | """Initializes the brain module.""" 2 | import ast 3 | from pathlib import Path 4 | 5 | import stable_baselines3 6 | import sb3_contrib 7 | from stable_baselines3.common.torch_layers import BaseFeaturesExtractor 8 | 9 | from nett.brain import encoders 10 | 11 | def list_encoders() -> list[str]: 12 | """ 13 | Returns a list of all available encoders. 14 | 15 | Returns: 16 | list[str]: List of encoder names. 17 | """ 18 | encoders_dir = Path.joinpath(Path(__file__).resolve().parent, 'encoders') 19 | encoders = [encoder.stem for encoder in list(encoders_dir.iterdir()) if "__" not in str(encoder)] 20 | return encoders 21 | 22 | encoders_list = list_encoders() 23 | 24 | def list_algorithms() -> list[str]: 25 | """ 26 | Returns a list of all available policy algorithms. 27 | 28 | Returns: 29 | list[str]: Set of algorithm names. 30 | """ 31 | sb3_policy_algorithms = [algorithm for algorithm in dir(stable_baselines3) if algorithm[0].isupper()] 32 | sb3_contrib_policy_algorithms = [algorithm for algorithm in dir(sb3_contrib) if algorithm[0].isupper()] 33 | available_policy_algorithms = sb3_policy_algorithms + sb3_contrib_policy_algorithms 34 | 35 | return available_policy_algorithms 36 | 37 | algorithms = list_algorithms() 38 | 39 | # TODO (v0.4) return all available policy models programmatically 40 | def list_policies() -> list[str]: 41 | """ 42 | Returns a list of all available policy models. 43 | 44 | Returns: 45 | list[str]: Set of policy names. 46 | """ 47 | return ['CnnPolicy', 'MlpPolicy', 'MultiInputPolicy', 'MultiInputLstmPolicy', 'CnnLstmPolicy'] 48 | 49 | policies = list_policies() 50 | 51 | # return encoder string to encoder class mapping 52 | # TODO (v0.4) optimized way to calculate and pass this dict around 53 | def get_encoder_dict() -> dict[str, str]: 54 | """ 55 | Returns a dictionary mapping encoder names to encoder class names. 56 | 57 | Returns: 58 | dict[str, str]: Dictionary mapping encoder names to encoder class names. 
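
        For example, with the bundled encoders this includes entries such as
        {"dinov2": "DinoV2", "frozensimclr": "FrozenSimCLR"}.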
59 | """ 60 | encoders_dict: dict[str, str] = {} 61 | encoders_dir = Path.joinpath(Path(__file__).resolve().parent, 'encoders') 62 | # iterate through all files in the directory 63 | for encoder_path in encoders_dir.iterdir(): 64 | if encoder_path.suffix == '.py' and "__" not in str(encoder_path): 65 | module_name = encoder_path.stem 66 | # read the source 67 | with open(encoder_path) as f: 68 | source = f.read() 69 | # parse it 70 | module = ast.parse(source) 71 | # get the first class definition 72 | encoder_class = [node for node in ast.walk(module) if isinstance(node, ast.ClassDef)][0] 73 | # add to the dictionary 74 | encoders_dict[module_name] = encoder_class.name 75 | return encoders_dict 76 | 77 | encoder_dict = get_encoder_dict() 78 | -------------------------------------------------------------------------------- /src/nett/brain/encoders/cnnlstm.py: -------------------------------------------------------------------------------- 1 | """CNNLSTM module for the brain""" 2 | 3 | ### DELETE WILL NOT BE USING #### 4 | import gym 5 | import torch as th 6 | from torch import nn 7 | from stable_baselines3.common.torch_layers import BaseFeaturesExtractor 8 | 9 | class CNNLSTM(BaseFeaturesExtractor): 10 | """ 11 | CNNLSTM is a class that represents a convolutional neural network (CNN) 12 | followed by a long short-term memory (LSTM) layer. It is used as a feature 13 | extractor in reinforcement learning algorithms. 14 | 15 | Args: 16 | observation_space (gym.Space): The observation space of the environment. 17 | features_dim (int, optional): Number of features extracted. This corresponds to the number of units for the last layer. Defaults to 256. 18 | """ 19 | def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 256) -> None: 20 | """Constructor method 21 | """ 22 | super(CNNLSTM, self).__init__(observation_space, features_dim) 23 | 24 | n_input_channels = observation_space.shape[0] 25 | self.cnn = nn.Sequential( 26 | nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4, padding=0), 27 | nn.ReLU(), 28 | nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0), 29 | nn.ReLU(), 30 | nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0), 31 | nn.ReLU(), 32 | nn.Flatten() 33 | ) 34 | 35 | # Compute shape by doing one forward pass 36 | with th.no_grad(): 37 | n_flatten = self.cnn(th.as_tensor(observation_space.sample()[None]).float()).shape[1] 38 | 39 | # define LSTM layer 40 | hidden_size = 512 41 | self.lstm = nn.LSTM(input_size=n_flatten, hidden_size=hidden_size, 42 | num_layers=2, batch_first=True) 43 | 44 | # outputs 45 | self.linear = nn.Sequential(nn.Linear(hidden_size, features_dim), nn.ReLU()) 46 | 47 | def forward(self, observations: th.Tensor): 48 | """ 49 | Forward pass of the CNNLSTM. 50 | 51 | Args: 52 | observations (torch.Tensor): The input observations. 53 | 54 | Returns: 55 | torch.Tensor: The extracted features. 56 | """ 57 | x = observations # original shape -> (length, batchsize, obs_size) 58 | # T,B, *_ = x.shape 59 | 60 | # Pass through CNN layers 61 | x = self.cnn(x) 62 | 63 | # Flatten the output for LSTM 64 | x = x.view(x.size(0), x.size(1), -1) 65 | 66 | # Pass through LSTM layer 67 | x, _ = self.lstm(x) 68 | 69 | # Get the last time step's output and apply the fully connected layer 70 | x = self.linear(x[:, -1, :]) 71 | 72 | return x 73 | 74 | class Identity(nn.Module): 75 | """Identity module 76 | 77 | This module is used to return the input tensor as is. 
78 | 
79 |     Args:
80 |         torch.nn.Module: PyTorch module
81 | 
82 |     Returns:
83 |         torch.nn.Module: Identity module
84 |     """
85 |     def __init__(self) -> None:
86 |         """Constructor method"""
87 |         super(Identity, self).__init__()
88 | 
89 |     def forward(self, x: th.Tensor) -> th.Tensor:
90 |         """Forward pass"""
91 |         return x
92 | 
--------------------------------------------------------------------------------
/src/nett/utils/train.py:
--------------------------------------------------------------------------------
1 | """
2 | Train performance
3 | 
4 | This module contains the functions to compute the train performance of the agent.
5 | 
6 | Functions:
7 |     compute_train_performance: Compute train performance
8 |     average_in_episode_three_region: Compute per-episode success rates
9 |     moving_average: Smooth values by doing a moving average
10 | """
11 | 
12 | import os
13 | import glob
14 | import numpy as np
15 | import pandas as pd
16 | 
17 | def compute_train_performance(path) -> tuple[list, np.ndarray | list]:
18 |     """
19 |     Compute train performance
20 | 
21 |     Args:
22 |         path (str | Path): path to the directory containing the training CSV files
23 | 
24 |     Returns:
25 |         x (list): list of the episode numbers
26 |         y (numpy.ndarray): the moving averages of the success rate
27 |     """
28 |     x, y = [], []
29 |     try:
30 |         training_files = glob.glob(os.path.join(path, "*.csv"))
31 | 
32 |         if len(training_files) == 0:
33 |             raise Exception(f"No training CSV files were found in {path}")
34 | 
35 | 
36 |         for file_name in training_files:
37 | 
38 |             log_df = pd.read_csv(file_name, skipinitialspace=True)
39 | 
40 |             _, _, values = average_in_episode_three_region(log_df, "agent.x") # percents, df,
41 |             y = moving_average(values, window=100)
42 |             x = list(range(len(y)))
43 | 
44 |             break
45 | 
46 | 
47 |         return x, y
48 |     except Exception as ex:
49 |         print(str(ex))
50 | 
51 |     return x, y
52 | 
53 | def average_in_episode_three_region(log: pd.DataFrame, column: str = 'agent.x', transient: int = 90) -> tuple[dict, pd.DataFrame, list]:
54 |     """
55 |     Compute the per-episode success rate: the fraction of post-transient steps spent in the Imprint region versus the Distractor region.
56 | 
57 |     Args:
58 |         log (pd.DataFrame): episode log containing the agent's position at every step
59 |         column (str, optional): name of the column holding the agent's x-position. Defaults to 'agent.x'.
60 |         transient (int, optional): number of initial steps per episode to discard before scoring. Defaults to 90.
61 | 
62 |     Returns:
63 |         tuple[dict, pd.DataFrame, list]: per-episode success rates, the binned log, and the success rates as a list
64 |     """
65 |     try:
66 |         log.loc[log.Episode % 2 == 1, column] *= -1
67 |         # Translate coordinates
68 |         log[column] += 10
69 |         # Bin into 3 sections
70 |         log[column] = pd.cut(log[column], [-0.1, 20/3, 40/3, 20.1], labels=["Distractor", "Null", "Imprint"])
71 |         episodes = log.Episode.unique()
72 |         percents = {}
73 |         for ep in episodes:
74 |             # Get success percentage
75 |             l = log[log["Episode"]==ep]
76 |             l = l[l["Step"]>transient]
77 |             total = l[l[column]=="Distractor"].count() + l[l[column]=="Imprint"].count()
78 |             success = l[l[column]=="Imprint"].count()/total
79 |             percents[ep] = success[column]
80 | 
81 |             if np.isnan(percents[ep]):
82 |                 percents[ep] = 0.5
83 | 
84 |         rv = list(percents.values())
85 | 
86 |         return (percents, log, rv)
87 |     except Exception as ex:
88 |         print(str(ex))
89 |         return (None, None, None)
90 | 
91 | def moving_average(values: list, window: int) -> np.ndarray:
92 |     """
93 |     Smooth values by doing a moving average.
94 | 
95 |     Args:
96 |         values (numpy.array): The input array of values.
97 |         window (int): The size of the moving window.
98 | 
99 |     Returns:
100 |         numpy.array: The smoothed array of values.
101 | """ 102 | weights: np.ndarray = np.repeat(1.0, window) / window 103 | return np.convolve(values, weights, 'valid') 104 | -------------------------------------------------------------------------------- /src/nett/analysis/NETT_train_viz.R: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env Rscript 2 | 3 | # NETT_train_viz.R 4 | 5 | # Before running this script, you need to run merge_csvs to merge all of the agents' 6 | # output into a single, standardized format dataframe for training and test data 7 | 8 | # Variables -------------------------------------------------------------------- 9 | 10 | # Read in the user-specified variables: 11 | library(argparse) 12 | parser <- ArgumentParser(description="An executable R script for the Newborn Embodied Turing Tests to analyze test trials") 13 | parser$add_argument("--data-loc", type="character", dest="data_loc", 14 | help="Full filename (inc working directory) of the merged R data", 15 | required=TRUE) 16 | parser$add_argument("--results-wd", type="character", dest="results_wd", 17 | help="Working directory to save the resulting visualizations", 18 | required=TRUE) 19 | parser$add_argument("--ep-bucket", type="integer", dest="ep_bucket_size", 20 | help="How many episodes to group the x-axis by", 21 | required=TRUE) 22 | parser$add_argument("--num-episodes", type="integer", dest="num_episodes", 23 | help="How many episodes should be included", 24 | required=TRUE) 25 | args <- parser$parse_args() 26 | data_loc <- args$data_loc; results_wd <- args$results_wd; ep_bucket_size <- args$ep_bucket_size; num_episodes <- args$num_episodes 27 | 28 | # Set Up ----------------------------------------------------------------------- 29 | 30 | library(tidyverse) 31 | 32 | load(data_loc) 33 | rm(test_data) 34 | setwd(results_wd) 35 | 36 | train_data_fixed <- train_data %>% 37 | filter(Episode < num_episodes) %>% 38 | # Create variables for correct/incorrect calculations 39 | mutate(correct_steps = if_else(correct.monitor == " left", left_steps, right_steps)) %>% 40 | mutate(incorrect_steps = if_else(correct.monitor == " left", right_steps, left_steps)) %>% 41 | mutate(percent_correct = correct_steps / (correct_steps + incorrect_steps)) %>% 42 | # Summarise data by condition, agent, and episode bucket for graphing 43 | mutate(episode_block = Episode%/%ep_bucket_size + 1) %>% 44 | group_by(imprint.cond, agent, episode_block) %>% 45 | summarise(avgs = mean(percent_correct, na.rm = TRUE), 46 | sd = sd(percent_correct, na.rm = TRUE), 47 | count = length(percent_correct)) %>% 48 | mutate(se = sd / sqrt(count)) %>% 49 | # Convert numerical variables into correct type 50 | mutate(episode_block = as.numeric(episode_block)) %>% 51 | mutate(agent = as.numeric(agent)) %>% 52 | ungroup() 53 | 54 | 55 | # Plot line graphs by imprinting condition ------------------------------------- 56 | 57 | for (cond in unique(train_data_fixed$imprint.cond)) 58 | { 59 | data <- train_data_fixed %>% 60 | filter(imprint.cond == cond) 61 | 62 | ggplot(data=data, aes(x=episode_block, y=avgs, color=as.factor(agent))) + 63 | geom_line() + 64 | theme_classic(base_size = 16) + 65 | geom_hline(yintercept = .5, linetype = 2) + 66 | xlab(sprintf("Groups of %d Episodes", ep_bucket_size)) + 67 | ylab("Average Time with Imprinted Object") + 68 | scale_y_continuous(expand = c(0, 0), limits = c(0, 1), 69 | breaks=seq(0,1,.1), labels = scales::percent) + 70 | scale_x_continuous(expand = c(0, 0), limits = c(0, num_episodes/ep_bucket_size), 71 | breaks = 
seq(0, num_episodes / ep_bucket_size, 1)) +
72 |     theme(legend.position="none")
73 | 
74 |   img_name <- paste0(cond, "_train.png")
75 |   ggsave(img_name)
76 | }
77 | 
--------------------------------------------------------------------------------
/docs/source/3papers/ViewInvariant.md:
--------------------------------------------------------------------------------
1 | # A Newborn Embodied Turing Test for View-Invariant Recognition
2 | 
3 | Denizhan Pak, Donsuk Lee, Samantha M. W. Wood & Justin N. Wood
4 | 
5 | 
8 | 
9 | ## Abstract
10 | 
11 | *Recent progress in artificial intelligence has renewed interest in building machines that learn like animals. Almost all of the work comparing learning across biological and artificial systems comes from studies where animals and machines received different training data, obscuring whether differences between animals and machines emerged from differences in learning mechanisms versus training data. We present an experimental approach—a “newborn embodied Turing Test”—that allows newborn animals and machines to be raised in the same environments and tested with the same tasks, permitting direct comparison of their learning abilities. To make this platform, we first collected controlled-rearing data from newborn chicks, then performed “digital twin” experiments in which machines were raised in virtual environments that mimicked the rearing conditions of the chicks. We found that (1) machines (deep reinforcement learning agents with intrinsic motivation) can spontaneously develop visually guided preference behavior, akin to imprinting in newborn chicks, and (2) machines are still far from newborn-level performance on object recognition tasks. Almost all of the chicks developed view-invariant object recognition, whereas the machines tended to develop view-dependent recognition. The learning outcomes were also far more constrained in the chicks versus machines. Ultimately, we anticipate that this approach will help researchers develop embodied AI systems that learn like newborn animals.*
12 | 
13 | ## Experiment Design
14 | 
15 | - VR chambers were equipped with two display walls (LCD monitors) for displaying object stimuli.
16 | - During the Training Phase, artificial chicks were reared in an environment containing a single 3D object rotating 15° around a vertical axis in front of a blank background scene. The object made a full rotation every 3s. Agents can be imprinted to one of four possible conditions: side and front views of the Fork object or side and front views of the Ship object.
17 | - During the Test Phase, the VR chambers measured the artificial chicks’ imprinting response and object recognition performance. The “imprinting trials” measured whether the chicks developed an imprinting response. The “test trials” measured the artificial chicks’ ability to visually discriminate their imprinted object. During these trials, the imprinted object, rotated at an alternate angle to the imprint condition, was presented on one display wall and an unfamiliar object was presented on the other display wall, the angle of which was either the same as the imprint condition (fixed trials) or matched to the viewpoint in the test condition (matched trials).
18 | 
19 | ## Arguments
20 | 
21 | ### Train configuration
22 | 
23 | ```
24 | agent_count: 1
25 | run_id: ship_front_exp
26 | log_path: data/ship_front_exp
27 | mode: full
28 | train_eps: 1000
29 | test_eps: 40
30 | cuda: 0
31 | Agent:
32 |   reward: supervised
33 |   encoder: small
34 | Environment:
35 |   use_ship: true
36 |   side_view: false
37 |   background: A
38 |   base_port: 5100
39 |   env_path: data/executables/viewpoint_benchmark/viewpoint.x86_64
40 |   log_path: data/ship_front_exp/Env_Logs
41 |   rec_path: data/ship_front_exp/Recordings/
42 |   record_chamber: false
43 |   record_agent: false
44 |   recording_frames: 0
45 | ```
46 | 
47 | ## Executables
48 | 
49 | [Executable can be found here](https://origins.luddy.indiana.edu/unity/executables/).
50 | 
--------------------------------------------------------------------------------
/docs/source/3papers/Parsing.md:
--------------------------------------------------------------------------------
1 | # A Newborn Embodied Turing Test for Visual Parsing
2 | 
3 | Manju Garimella, Denizhan Pak, Lalit Pandey, Justin N. Wood, & Samantha M. W. Wood
4 | 
5 | 
8 | 
9 | ## Abstract
10 | 
11 | *Newborn brains exhibit remarkable abilities in rapid and generative learning, including the ability to parse objects from backgrounds and recognize those objects across substantial changes to their appearance (i.e., novel backgrounds and novel viewing angles). How can we build machines that can learn as efficiently as newborns? To accurately compare biological and artificial intelligence, researchers need to provide machines with the same training data that an organism has experienced since birth. Here, we present an experimental benchmark that enables researchers to raise artificial agents in the same controlled-rearing environments as newborn chicks. First, we raised newborn chicks in controlled environments with visual access to only a single object on a single background and tested their ability to recognize their object across novel viewing conditions. Then, we performed “digital twin” experiments in which we reared a variety of artificial neural networks in virtual environments that mimicked the rearing conditions of the chicks and measured whether they exhibited the same object recognition behavior as the newborn chicks. We found that biological chicks developed background-invariant object recognition, while the artificial chicks developed background-dependent recognition. Our benchmark exposes the limitations of current unsupervised and supervised algorithms in achieving the learning abilities of newborn animals. Ultimately, we anticipate that this approach will contribute to the development of AI systems that can learn with the same efficiency as newborn animals.*
12 | 
13 | ## Experiment Design
14 | 
15 | - VR chambers were equipped with two display walls (LCD monitors) for displaying object stimuli.
16 | - During the Training Phase, artificial chicks were reared in an environment containing a single 3D object rotating a full 360° around a horizontal axis in front of a naturalistic background scene. The object made a full rotation every 15s.
17 | - During the Test Phase, the VR chambers measured the artificial chicks’ imprinting response and object recognition performance. The “imprinting trials” measured whether the chicks developed an imprinting response. The “test trials” measured the artificial chicks’ ability to visually parse and recognize their imprinted object.
During these trials, the imprinted object was presented on one display wall and an unfamiliar object was presented on the other display wall. Across the test trials, the objects were presented on all possible combinations of the three background scenes (Background 1 vs. Background 1, Background 1 vs. Background 2, Background 1 vs. Background 3, etc.).
18 | 
19 | ## Arguments
20 | 
21 | ### Train configuration
22 | 
23 | ```
24 | agent_count: 1
25 | run_id: exp1
26 | log_path: data/exp1
27 | mode: full
28 | train_eps: 1000
29 | test_eps: 40
30 | cuda: 0
31 | Agent:
32 |   reward: supervised
33 |   encoder: small
34 | Environment:
35 |   use_ship: true
36 |   side_view: false
37 |   background: A
38 |   base_port: 5100
39 |   env_path: data/executables/parsing_benchmark/parsing.x86_64
40 |   log_path: data/ship_backgroundA_exp/Env_Logs
41 |   rec_path: data/ship_backgroundA_exp/Recordings/
42 |   record_chamber: false
43 |   record_agent: false
44 |   recording_frames: 0
45 | ```
46 | #### Run script
47 | 
48 | ```bash
49 | python src/simulation/run_parsing_exp.py ++run_id=exp1 ++Environment.env_path=data/executables/parsing_benchmark/parsing_app.x86_64 ++mode=full ++train_eps=1000 ++test_eps=40 ++Agent.encoder="small" ++Environment.use_ship="true" ++Environment.background="A"
50 | ```
51 | where
52 | 
53 | `Environment.use_ship` = True or False (to choose between the Ship and Fork objects);
54 | `Environment.background` = A, B, or C (to choose between the three backgrounds);
55 | `mode` = full, train, or test (to choose between the three run modes);
56 | `Agent.encoder` = "small", "medium", or "large" (to choose between the three encoders: NatureCNN, resnet10, and resnet18);
57 | `Agent.reward` = "supervised" (default)
58 | 
59 | 
60 | #### Custom Configuration
61 | 
62 | - Update the train episode count and test episode count
63 | - Encoder types - small, medium, and large
64 | - Reward types - supervised or unsupervised
65 | 
66 | ## Links
67 | 
68 | [Executables can be found here](https://origins.luddy.indiana.edu/unity/executables/)
69 | 
70 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 | 
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | 161 | # R code 162 | 163 | # History files 164 | .Rhistory 165 | .Rapp.history 166 | 167 | # Session Data files 168 | .RData 169 | .RDataTmp 170 | 171 | # User-specific files 172 | .Ruserdata 173 | 174 | # Example code in package build process 175 | *-Ex.R 176 | 177 | # Output files from R CMD build 178 | /*.tar.gz 179 | 180 | # Output files from R CMD check 181 | /*.Rcheck/ 182 | 183 | # RStudio files 184 | .Rproj.user/ 185 | 186 | # produced vignettes 187 | vignettes/*.html 188 | vignettes/*.pdf 189 | 190 | # OAuth2 token, see https://github.com/hadley/httr/releases/tag/v0.3 191 | .httr-oauth 192 | 193 | # knitr and R markdown default cache directories 194 | *_cache/ 195 | /cache/ 196 | 197 | # Temporary files created by R markdown 198 | *.utf8.md 199 | *.knit.md 200 | 201 | # R Environment Variables 202 | .Renviron 203 | 204 | # translation temp files 205 | po/*~ 206 | 207 | # RStudio Connect folder 208 | rsconnect/ 209 | 210 | # MacOS .DS_Store files 211 | .DS_Store 212 | # Remove .vscode folder 213 | .vscode/ 214 | -------------------------------------------------------------------------------- /src/nett/analysis/NETT_merge_csvs.R: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env Rscript 2 | 3 | # merge_csvs.R 4 | # For a specified directory (see below), takes all of the csv files 5 | # and compiles them into a single data file 6 | 7 | # NOTE: For ease of use across many different experimental designs, 8 | # this script assumes that all files use a common naming scheme with the 9 | # following criteria: 10 | # 1) The agent ID number is the only number in the file name 11 | # 2) The filename ends with either train.csv or test.csv 12 | # for example "fork_side-agent3_train.csv" 13 | 14 | 15 | # Variables -------------------------------------------------------------------- 16 | 17 | # Read in the user-specified variables: 18 | library(argparse) 19 | parser <- ArgumentParser(description="An executable R script for the Newborn Embodied Turing Tests to merge the log files across many agents") 20 | parser$add_argument("--logs-dir", type="character", dest="logs_dir", 21 | help="Working directory of the agents' log files", 22 | required=TRUE) 23 | parser$add_argument("--results-dir", type="character", dest="results_dir", 24 | help="Working directory to store the merged output", 25 | required=TRUE) 26 | parser$add_argument("--results-name", type="character", dest="results_name", 27 | help="File name for the R file storing the results", 28 | required=TRUE) 29 | parser$add_argument("--csv-train", type="character", dest="csv_train_name", 30 | help="File name for the csv file storing the training results", 31 | required=FALSE) 32 | parser$add_argument("--csv-test", type="character", dest="csv_test_name", 33 | help="File name for the csv file storing the testing results", 34 | required=FALSE) 35 | args <- parser$parse_args() 36 | data_wd <- args$logs_dir; results_wd <- args$results_dir; results_name <- args$results_name 37 | csv_train_name <- args$csv_train_name; csv_test_name <- args$csv_test_name 38 | 39 | # Set Zones: 40 | upper_x_lim <- 10 41 | lower_x_lim <- -10 42 | one_third <- (upper_x_lim - lower_x_lim)/3 43 | lower_bound <- lower_x_lim + one_third 44 | upper_bound <- upper_x_lim - one_third 45 | 46 | 47 | # Set Up ----------------------------------------------------------------------- 48 | 49 | # Import libraries 50 | library(tidyverse) 51 | 52 | # Get all of the subdirectory csv filenames 53 | setwd(data_wd) 54 | train_files <- 
list.files(pattern="train.csv", recursive = TRUE)
55 | test_files <- list.files(pattern="test.csv", recursive = TRUE)
56 | 
57 | # Check if there are any train files
58 | if (length(train_files) == 0) {
59 |   stop("No train files found.")
60 | }
61 | # Check if there are any test files
62 | if (length(test_files) == 0) {
63 |   stop("No test files found.")
64 | }
65 | 
66 | # Main Function ----------------------------------------------------------------
67 | 
68 | # This function reads in a single csv (later we'll lapply it across all files)
69 | read_data <- function(filename)
70 | {
71 |   # Read the csv file
72 |   data <- read.csv(filename)
73 | 
74 |   # Summarize by zones
75 |   data <- data %>%
76 |     mutate(left = case_when( agent.x < lower_bound ~ 1, agent.x >= lower_bound ~ 0)) %>%
77 |     mutate(right = case_when( agent.x > upper_bound ~ 1, agent.x <= upper_bound ~ 0)) %>%
78 |     mutate(middle = 1 - left - right)
79 |   # Quick check to make sure that one and only one zone is chosen at each step
80 |   stopifnot(all( (data$left + data$right + data$middle == 1) ))
81 |   # Summarize at the episode level
82 |   data <- data %>%
83 |     group_by(Episode, left.monitor, right.monitor, correct.monitor, experiment.phase, imprint.cond, test.cond) %>%
84 |     summarise(.groups = "keep",
85 |               left_steps = sum(left),
86 |               right_steps = sum(right),
87 |               middle_steps = sum(middle)) %>%
88 |     mutate(Episode = as.numeric(Episode)) %>%
89 |     mutate(left.monitor = sub(" ", "", left.monitor)) %>%
90 |     mutate(right.monitor = sub(" ", "", right.monitor)) %>%
91 |     ungroup()
92 | 
93 |   # Add columns for original filename and agent ID number
94 |   data$filename <- basename(filename)
95 |   data$agent <- gsub("\\D", "", data$filename) # Only keep the number
96 | 
97 |   return(data)
98 | }
99 | 
100 | # Combine csv's and save results -----------------------------------------------
101 | 
102 | # Combine all the training
103 | train_data <- lapply(train_files, FUN = read_data)
104 | train_data <- bind_rows(train_data)
105 | 
106 | # Combine all the testing
107 | test_data <- lapply(test_files, FUN = read_data)
108 | test_data <- bind_rows(test_data)
109 | 
110 | # Save it
111 | setwd(results_wd)
112 | save(train_data, test_data, file=results_name)
113 | if( !is.null(csv_train_name) ) write.csv(train_data, csv_train_name)
114 | if( !is.null(csv_test_name) ) write.csv(test_data, csv_test_name)
115 | 
116 | 
--------------------------------------------------------------------------------
/examples/notebooks/Getting Started.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": 1,
6 |    "metadata": {},
7 |    "outputs": [],
8 |    "source": [
9 |     "# imports\n",
10 |     "from nett import Brain, Body, Environment\n",
11 |     "from nett import NETT"
12 |    ]
13 |   },
14 |   {
15 |    "cell_type": "code",
16 |    "execution_count": 2,
17 |    "metadata": {},
18 |    "outputs": [
19 |     {
20 |      "name": "stderr",
21 |      "output_type": "stream",
22 |      "text": [
23 |       "[netts.Environment] INFO: Executable permission is set\n",
24 |       "[netts.Environment] INFO: Display is set\n"
25 |      ]
26 |     }
27 |    ],
28 |    "source": [
29 |     "# define components (brain, body and environment)\n",
30 |     "# this example shows only a minimal setup, a LOT of customization is possible here\n",
31 |     "brain = Brain(policy='CnnPolicy', algorithm='PPO')\n",
32 |     "body = Body(type=\"basic\", dvs=False)\n",
33 |     "environment = Environment(config=\"identityandview\",
executable_path=\"/home/desabh/builds/identityandview/smoothnesswithbinding-linux.x86_64\")" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 1, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "# construct the NETT\n", 43 | "nett = NETT(brain=brain, body=body, environment=environment)\n", 44 | "# run the NETT\n", 45 | "job_sheet = nett.run(output_dir=\"./test_run\", num_brains=1, train_eps=10, test_eps=1)" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 6, 51 | "metadata": {}, 52 | "outputs": [ 53 | { 54 | "data": { 55 | "text/html": [ 56 | "
\n", 57 | "\n", 70 | "\n", 71 | " \n", 72 | " \n", 73 | " \n", 74 | " \n", 75 | " \n", 76 | " \n", 77 | " \n", 78 | " \n", 79 | " \n", 80 | " \n", 81 | " \n", 82 | " \n", 83 | " \n", 84 | " \n", 85 | " \n", 86 | " \n", 87 | " \n", 88 | " \n", 89 | " \n", 90 | " \n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | " \n", 97 | " \n", 98 | " \n", 99 | " \n", 100 | " \n", 101 | " \n", 102 | " \n", 103 | " \n", 104 | " \n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | "
runningspecification.devicespecification.modespecification.brain_id
0True0object1-horizontal1
1True0object2-horizontal1
2True1object1-vertical1
3True2object2-vertical1
\n", 111 | "
" 112 | ], 113 | "text/plain": [ 114 | " running specification.device specification.mode specification.brain_id\n", 115 | "0 True 0 object1-horizontal 1\n", 116 | "1 True 0 object2-horizontal 1\n", 117 | "2 True 1 object1-vertical 1\n", 118 | "3 True 2 object2-vertical 1" 119 | ] 120 | }, 121 | "execution_count": 6, 122 | "metadata": {}, 123 | "output_type": "execute_result" 124 | } 125 | ], 126 | "source": [ 127 | "# check status\n", 128 | "nett.status(job_sheet)" 129 | ] 130 | } 131 | ], 132 | "metadata": { 133 | "kernelspec": { 134 | "display_name": "testenv", 135 | "language": "python", 136 | "name": "python3" 137 | }, 138 | "language_info": { 139 | "codemirror_mode": { 140 | "name": "ipython", 141 | "version": 3 142 | }, 143 | "file_extension": ".py", 144 | "mimetype": "text/x-python", 145 | "name": "python", 146 | "nbconvert_exporter": "python", 147 | "pygments_lexer": "ipython3", 148 | "version": "3.10.12" 149 | } 150 | }, 151 | "nbformat": 4, 152 | "nbformat_minor": 2 153 | } 154 | -------------------------------------------------------------------------------- /src/nett/body/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Body module for nett 3 | 4 | .. module:: nett.body 5 | :synopsis: Body module for nett 6 | """ 7 | 8 | # all available body types 9 | types = ["basic", "two-eyed", "ragdoll"] 10 | 11 | # ASCII art 12 | ascii_basic = """ 13 | ---------------::::::::::::::::::::::::::::::::::::::::::::::::::--------------- 14 | ----------------::::::::::::::::::::::::::::::::::::::::::::::::---------------- 15 | ------------------::::::::::::::::..::::::::::...::::::::::::::----------------- 16 | --------------------::::::::.:::::---------------::::.:::::::------------------- 17 | ----------------------:::::-------------:--------------::::--------------------- 18 | ----------------------.------=====---:. 
..---========-----.-------------------- 19 | --------------------:--====+++++======----======++****+====--.------------------ 20 | ------------------:-===+***********==========*************+===------------------ 21 | -----------------:==+****%%%%%%%%#***+++++++***%%%%%%%%%#***+==-:--------------- 22 | -----------------=++**#%%%%%@@@%%%%**++++++**%%%@%%%%%@%%%%***+==:-------------- 23 | ----------------=+***%@@@@@@@@@@@@@%**++++**%@@@@@@@@@@@@@@%***++=:------------- 24 | ---------------+++**@@@@@@@@@@@@@@@@********@@@@@@@@@@@@@@@@@**+++=------------- 25 | --------------++++*%@@@@@@@@@@@@@@@@+******+@@@@@@@@@@@@@@@@@*+++++=------------ 26 | -------------=++*++@@@@@@@@@@@@@@@@++******++@@@@@@@@@@@@@@@@+++*++=:----------- 27 | -------------++**++%@@@@@@@@@@@@@@++********++@@@@@@@@@@@@@@@+++**++=----------- 28 | -----------.=+***+=+@@@@@@@@@@@@@++****:.****++@@@@@@@@@@@@@+==****+=:---------- 29 | ------------+*****=-==@@@@@@@@===***-=+=+==.***===@@@@@@@@===-*****++----------- 30 | -----------=+*******+--=======****=+********-****+=======--=*******++----------- 31 | -----------=+*********************+**********:*********************++=---------- 32 | -----------++********************-*##*****##*=**********************+=---------- 33 | -----------************************###***###*************************+---------- 34 | ====-------**######################################################**=---------= 35 | ========----*######################################################**------===== 36 | ==========--**#####################################################*----======== 37 | ===========:.:*###################################################*============= 38 | =====::::......:################################################*.........====== 39 | ===------**###+::::::....--###############################--.........+*---::.=== 40 | ====.-===**#####********************------------======**********#####**-=--:--== 41 | ==+======**#########******************************************#######**====--=== 42 | @*+++++=+**#################**************************###############*+=======+= 43 | @*++++++***##########################################################*+++++++++@ 44 | ***+++*****##########################################################***++++++** 45 | +#*++******##########################################################*****++++** 46 | @##********#########################################################******+++**@ 47 | @*###******#########################################################**********#@ 48 | +*####****+#########################################################********###* 49 | ++######****#######################################################*+*****####*+ 50 | %%#######***#######################################################****#######@@ 51 | ++#######***#######################################################****#######++ 52 | ##+######***######################################################*+**#######%@@ 53 | ===#######**######################################################***########++* 54 | ====######***#####################################################***#######@@@% 55 | @@@@######***####################################################***########++++ 56 | ===%%######***###################################################***#######====+ 57 | =%@===######***#################################################***#######@@@@@@ 58 | @@@@@@@######****##############################################***#######@%+==== 59 | 
@========#####**+###########################################*+***######*==*%%=== 60 | ==========*####*=*#########################################*@***######@@@@@@@@@@ 61 | ============%*====*#######################################*==*######=========@%= 62 | @@@@%%%%%%%%%%%%%##*######################################+==@####============@% 63 | =======----%%##@@@@@#####################################*=-==@%===============@ 64 | ==--------%@--------=###################################*@@@@@@@@@@@@@%%%%%%%%%% 65 | ---------%%----------=#################################=-------@%---=========**# 66 | --------%%-------------*#############################*----------@%------======== 67 | %%%%%%%%%+--------------*###########################*------------%%---------==== 68 | @@@@@@@@@@@@@@@@@@@@@@%%%%*#######################*--------------@%%-----------= 69 | ------%@------------------@@*###################*@%%%%%%%%%%%%%%%%@%------------ 70 | -----%%-------------------%%--=*#############*%%----**@@@@@@@@@@@@@@@@@@@@@@@@@@ 71 | ----%%-----------------===%@------=*******----%%-------------------@%%---------- 72 | """ 73 | ascii_art = {'basic': ascii_basic} 74 | -------------------------------------------------------------------------------- /src/nett/body/wrappers/dvs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import collections 4 | import gym 5 | import numpy as np 6 | import cv2 7 | import logging 8 | 9 | logger = logging.getLogger(__name__) 10 | logger.setLevel(logging.INFO) 11 | 12 | class DVSWrapper(gym.ObservationWrapper): 13 | """ 14 | A gym observation wrapper that performs Dynamic Vision Sensor (DVS) transformation on the environment observations. 15 | 16 | Args: 17 | env (gym.Env): The environment to wrap. 18 | change_threshold (int): The threshold value for detecting changes in pixel intensity. 19 | kernel_size (tuple): The size of the Gaussian kernel used for blurring. 20 | sigma (float): The standard deviation of the Gaussian kernel. 21 | 22 | Attributes: 23 | change_threshold (int): The threshold value for detecting changes in pixel intensity. 24 | kernel_size (tuple): The size of the Gaussian kernel used for blurring. 25 | sigma (float): The standard deviation of the Gaussian kernel. 26 | num_stack (int): The number of frames to stack. 27 | env (gym.Env): The wrapped environment. 28 | stack (collections.deque): A deque to store the stacked frames. 29 | shape (tuple): The shape of the observation space. 30 | observation_space (gym.spaces.Box): The modified observation space. 31 | 32 | Methods: 33 | create_grayscale(image): Converts an image to grayscale. 34 | gaussianDiff(previous, current): Computes the difference between two images using Gaussian blur. 35 | observation(obs): Performs the DVS transformation on the observation. 36 | threshold(change): Applies a threshold to the change map. 37 | reset(**kwargs): Resets the environment and returns the initial observation. 
38 | 39 | """ 40 | 41 | def __init__(self, env, change_threshold=60, kernel_size=(3, 3), sigma=1, is_color = True): 42 | super().__init__(env) 43 | 44 | self.change_threshold = change_threshold 45 | self.kernel_size = kernel_size 46 | self.sigma = sigma 47 | self.num_stack = 2 ## default 48 | self.env = gym.wrappers.FrameStack(env,self.num_stack) 49 | self.stack = collections.deque(maxlen=self.num_stack) 50 | self.is_color = is_color 51 | 52 | try: 53 | _, _, width, height = self.env.observation_space.shape # stack, channels, 54 | self.shape=(3, width, height) 55 | self.observation_space = gym.spaces.Box(shape=self.shape, low=0, high=255, dtype=np.uint8) 56 | logger.info("In dvs wrapper") 57 | except Exception as ex: 58 | print(str(ex)) 59 | 60 | 61 | def create_grayscale(self, image): 62 | """ 63 | Converts an image to grayscale. 64 | 65 | Args: 66 | image (numpy.ndarray): The input image. 67 | 68 | Returns: 69 | numpy.ndarray: The grayscale image. 70 | 71 | """ 72 | return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) 73 | 74 | 75 | def gaussianDiff(self, previous, current): 76 | """ 77 | Computes the difference between two images using Gaussian blur. 78 | 79 | Args: 80 | previous (numpy.ndarray): The previous image. 81 | current (numpy.ndarray): The current image. 82 | 83 | Returns: 84 | numpy.ndarray: The difference map. 85 | 86 | """ 87 | previous = cv2.GaussianBlur(previous, self.kernel_size, self.sigma) 88 | np_previous = np.asarray(previous, dtype=np.int64) 89 | 90 | current = cv2.GaussianBlur(current, self.kernel_size, self.sigma) 91 | np_current = np.asarray(current, dtype=np.int64) 92 | 93 | change = np_current - np_previous 94 | 95 | return change 96 | 97 | def observation(self, obs): 98 | """ 99 | Performs the DVS transformation on the observation. 100 | 101 | Args: 102 | obs (list): The list of stacked frames. 103 | 104 | Returns: 105 | numpy.ndarray: The transformed observation. 106 | 107 | """ 108 | 109 | if len(obs)>0: 110 | prev = np.transpose(obs[0], (1, 2, 0)) 111 | current = np.transpose(obs[1], (1, 2, 0)) 112 | 113 | if not self.is_color: 114 | prev = cv2.cvtColor(prev, cv2.COLOR_RGB2GRAY) 115 | current = cv2.cvtColor(current, cv2.COLOR_RGB2GRAY) 116 | 117 | change = self.gaussianDiff(prev, current) 118 | 119 | ## threshold 120 | dc = self.threshold(change) 121 | 122 | else: 123 | obs = np.transpose(obs, (1, 2, 0)) 124 | 125 | if not self.is_color: 126 | obs = self.create_grayscale(obs) 127 | 128 | obs = np.array(obs, dtype=np.float32) / 255.0 129 | dc = self.threshold(obs) 130 | 131 | # change to channel first, w, h 132 | dc = np.transpose(dc, (2, 0, 1)) 133 | 134 | return dc.astype(np.uint8) 135 | 136 | def threshold(self, change): 137 | """ 138 | Applies a threshold to the change map. 139 | 140 | Args: 141 | change (numpy.ndarray): The change map. 142 | 143 | Returns: 144 | numpy.ndarray: The thresholded change map. 
145 | 146 | """ 147 | if not self.is_color: 148 | ret_frame = np.ones(shape=change.shape) * 128 149 | ret_frame[change >= self.change_threshold] = 255 150 | ret_frame[change <= -self.change_threshold] = 0 151 | else: 152 | ret_frame = abs(change) 153 | ret_frame[ret_frame < self.change_threshold] = 0 154 | 155 | return ret_frame 156 | 157 | def reset(self, **kwargs): 158 | initial_obs = self.env.reset(**kwargs) 159 | return self.observation(initial_obs) 160 | -------------------------------------------------------------------------------- /examples/run/wrapper/dvs_wrapper.py: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | import collections 3 | import gym 4 | from scipy.ndimage import gaussian_filter 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | import pdb 8 | from PIL import Image 9 | import os 10 | import cv2 11 | import pdb 12 | 13 | import logging 14 | 15 | logger = logging.getLogger(__name__) 16 | logger.setLevel(logging.INFO) 17 | 18 | class DVSWrapper(gym.ObservationWrapper): 19 | """ 20 | A gym observation wrapper that performs Dynamic Vision Sensor (DVS) transformation on the environment observations. 21 | 22 | Args: 23 | env (gym.Env): The environment to wrap. 24 | change_threshold (int): The threshold value for detecting changes in pixel intensity. 25 | kernel_size (tuple): The size of the Gaussian kernel used for blurring. 26 | sigma (float): The standard deviation of the Gaussian kernel. 27 | 28 | Attributes: 29 | change_threshold (int): The threshold value for detecting changes in pixel intensity. 30 | kernel_size (tuple): The size of the Gaussian kernel used for blurring. 31 | sigma (float): The standard deviation of the Gaussian kernel. 32 | num_stack (int): The number of frames to stack. 33 | env (gym.Env): The wrapped environment. 34 | stack (collections.deque): A deque to store the stacked frames. 35 | shape (tuple): The shape of the observation space. 36 | observation_space (gym.spaces.Box): The modified observation space. 37 | 38 | Methods: 39 | create_grayscale(image): Converts an image to grayscale. 40 | gaussianDiff(previous, current): Computes the difference between two images using Gaussian blur. 41 | observation(obs): Performs the DVS transformation on the observation. 42 | threshold(change): Applies a threshold to the change map. 43 | reset(**kwargs): Resets the environment and returns the initial observation. 44 | 45 | """ 46 | 47 | def __init__(self, env, change_threshold=60, kernel_size=(3, 3), sigma=1, is_color = True ): 48 | super().__init__(env) 49 | 50 | self.change_threshold = change_threshold 51 | self.kernel_size = kernel_size 52 | self.sigma = sigma 53 | self.num_stack = 2 ## default 54 | self.env = gym.wrappers.FrameStack(env,self.num_stack) 55 | self.stack = collections.deque(maxlen=self.num_stack) 56 | self.is_color = is_color 57 | 58 | try: 59 | stack, channels, width, height = self.env.observation_space.shape 60 | self.shape=(3, width, height) 61 | self.observation_space = gym.spaces.Box(shape=self.shape, low=0, high=255, dtype=np.uint8) 62 | logger.info("In dvs wrapper") 63 | except Exception as ex: 64 | print(str(ex)) 65 | 66 | 67 | def create_grayscale(self, image): 68 | """ 69 | Converts an image to grayscale. 70 | 71 | Args: 72 | image (numpy.ndarray): The input image. 73 | 74 | Returns: 75 | numpy.ndarray: The grayscale image. 
76 | 
77 |         """
78 |         return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
79 | 
80 | 
81 |     def gaussianDiff(self, previous, current):
82 |         """
83 |         Computes the difference between two images using Gaussian blur.
84 | 
85 |         Args:
86 |             previous (numpy.ndarray): The previous image.
87 |             current (numpy.ndarray): The current image.
88 | 
89 |         Returns:
90 |             numpy.ndarray: The difference map.
91 | 
92 |         """
93 |         previous = cv2.GaussianBlur(previous, self.kernel_size, self.sigma)
94 |         np_previous = np.asarray(previous, dtype=np.int64)
95 | 
96 |         current = cv2.GaussianBlur(current, self.kernel_size, self.sigma)
97 |         np_current = np.asarray(current, dtype=np.int64)
98 | 
99 |         change = np_current - np_previous
100 | 
101 |         return change
102 | 
103 |     def observation(self, obs):
104 |         """
105 |         Performs the DVS transformation on the observation.
106 | 
107 |         Args:
108 |             obs (list): The list of stacked frames.
109 | 
110 |         Returns:
111 |             numpy.ndarray: The transformed observation.
112 | 
113 |         """
114 | 
115 |         if len(obs)>0:
116 |             prev = np.transpose(obs[0], (1, 2, 0))
117 |             current = np.transpose(obs[1], (1, 2, 0))
118 | 
119 |             if not self.is_color:
120 |                 prev = cv2.cvtColor(prev, cv2.COLOR_RGB2GRAY)
121 |                 current = cv2.cvtColor(current, cv2.COLOR_RGB2GRAY)
122 | 
123 |             change = self.gaussianDiff(prev, current)
124 | 
125 |             ## threshold
126 |             dc = self.threshold(change)
127 | 
128 |         else:
129 |             obs = np.transpose(obs, (1, 2, 0))
130 | 
131 |             if not self.is_color:
132 |                 obs = self.create_grayscale(obs)
133 | 
134 |             obs = np.array(obs, dtype=np.float32) / 255.0
135 |             dc = self.threshold(obs)
136 | 
137 |         # change to channel first, w, h
138 |         dc = np.transpose(dc, (2, 0, 1))
139 | 
140 |         return dc.astype(np.uint8)
141 | 
142 |     def threshold(self, change):
143 |         """
144 |         Applies a threshold to the change map.
145 | 
146 |         Args:
147 |             change (numpy.ndarray): The change map.
148 | 
149 |         Returns:
150 |             numpy.ndarray: The thresholded change map.
151 | 
152 |         """
153 |         if not self.is_color:
154 |             ret_frame = np.ones(shape=change.shape) * 128
155 |             ret_frame[change >= self.change_threshold] = 255
156 |             ret_frame[change <= -self.change_threshold] = 0
157 |         else:
158 |             ret_frame = abs(change)
159 |             ret_frame[ret_frame < self.change_threshold] = 0
160 | 
161 |         return ret_frame
162 | 
163 |     def reset(self, **kwargs):
164 |         initial_obs = self.env.reset(**kwargs)
165 |         return self.observation(initial_obs)
166 | 
167 | 
168 | 
169 | 
170 | 
171 | 
--------------------------------------------------------------------------------
/src/nett/body/builder.py:
--------------------------------------------------------------------------------
1 | """The body of the agent in the environment."""
2 | from gym import Env, Wrapper
3 | from stable_baselines3.common.env_checker import check_env
4 | 
5 | from nett.body import types
6 | from nett.body.wrappers.dvs import DVSWrapper
7 | # from nett.body import ascii_art
8 | 
9 | # this will have the necessary wrappers before the observations interact with the brain,
10 | # because it is the body that determines how the observations will be processed.
11 | # specifically, in the case of the two-eyed agent, because the agent has two eyes, the observations are stereo
12 | # and need to be processed differently before they make it to the brain.
13 | # the body is the medium through which information travels from the environment to the brain.
14 | # the brain is limited by what the body can perceive and no information is objective.
15 | # NO INFORMATION IS OBJECTIVE (!!!!!!)
16 | class Body: 17 | """Represents the body of an agent in an environment. 18 | 19 | The body determines how observations from the environment are processed before they reach the brain. 20 | It can apply wrappers to modify the observations and provide a different perception to the brain. 21 | 22 | Args: 23 | type (str, optional): The type of the agent's body. Defaults to "basic". 24 | wrappers (list[Wrapper], optional): List of wrappers to be applied to the environment. Defaults to []. 25 | dvs (bool, optional): Flag indicating whether the agent uses dynamic vision sensors. Defaults to False. 26 | 27 | Raises: 28 | ValueError: If the agent type is not valid. 29 | TypeError: If dvs is not a boolean. 30 | 31 | Example: 32 | 33 | >>> from nett import Body 34 | >>> body = Body(type="basic", wrappers=None, dvs=False) 35 | """ 36 | 37 | def __init__(self, type: str = "basic", 38 | wrappers: list[Wrapper] = [], 39 | dvs: bool = False) -> None: 40 | """ 41 | Constructor method 42 | """ 43 | from nett import logger 44 | self.logger = logger.getChild(__class__.__name__) 45 | self.type = self._validate_agent_type(type) 46 | self.wrappers = self._validate_wrappers(wrappers) 47 | self.dvs = self._validate_dvs(dvs) 48 | 49 | def _validate_agent_type(self, type: str) -> str: 50 | """ 51 | Validate the agent type. 52 | 53 | Args: 54 | type (str): The type of the agent's body. 55 | 56 | Returns: 57 | str: The validated agent type. 58 | 59 | Raises: 60 | ValueError: If the agent type is not valid. 61 | """ 62 | if type not in types: 63 | raise ValueError(f"agent type must be one of {types}") 64 | return type 65 | 66 | def _validate_dvs(self, dvs: bool) -> bool: 67 | """ 68 | Validate the dvs flag. 69 | 70 | Args: 71 | dvs (bool): The dvs flag. 72 | 73 | Returns: 74 | bool: The validated dvs flag. 75 | 76 | Raises: 77 | TypeError: If dvs is not a boolean. 78 | """ 79 | if not isinstance(dvs, bool): 80 | raise TypeError("dvs should be a boolean [True, False]") 81 | return dvs 82 | 83 | def _validate_wrappers(self, wrappers: list[Wrapper]) -> list[Wrapper]: 84 | """ 85 | Validate the wrappers. 86 | 87 | Args: 88 | wrappers (list[Wrapper]): The list of wrappers. 89 | 90 | Returns: 91 | list[Wrapper]: The validated list of wrappers. 92 | 93 | Raises: 94 | ValueError: If any wrapper is not an instance of gym.Wrapper. 95 | """ 96 | for wrapper in wrappers: 97 | if not issubclass(wrapper, Wrapper): 98 | raise ValueError("Wrappers must inherit from gym.Wrapper") 99 | return wrappers 100 | 101 | @staticmethod 102 | def _wrap(env: Env, wrapper: Wrapper) -> Env: 103 | """ 104 | Wraps the environment with the registered wrappers. 105 | 106 | Args: 107 | env (Env): The environment to wrap. 108 | wrapper (Wrapper): The wrapper to apply. 109 | 110 | Returns: 111 | Env: The wrapped environment. 112 | 113 | Raises: 114 | Exception: If the environment does not follow the Gym API. 115 | """ 116 | try: 117 | # wrap env 118 | env = wrapper(env) 119 | # check that the env follows Gym API 120 | env_check = check_env(env, warn=True) 121 | if env_check != None: 122 | raise Exception(f"Failed env check") 123 | 124 | return env 125 | 126 | except Exception as ex: 127 | print(str(ex)) 128 | 129 | def __call__(self, env: Env) -> Env: 130 | """ 131 | Apply the registered wrappers to the environment. 132 | 133 | Args: 134 | env (Env): The environment. 135 | 136 | Returns: 137 | Env: The modified environment. 138 | """ 139 | # apply DVS wrapper 140 | # TODO: Should this wrapper go in a different order? 
141 | if self.dvs: 142 | env = self._wrap(env, DVSWrapper) 143 | # apply all custom wrappers 144 | if self.wrappers: 145 | for wrapper in self.wrappers: 146 | env = self._wrap(env, wrapper) 147 | 148 | return env 149 | 150 | 151 | def __repr__(self) -> str: 152 | """ 153 | Return a string representation of the Body object. 154 | 155 | Returns: 156 | str: The string representation of the Body object. 157 | """ 158 | attrs = {k: v for k, v in vars(self).items() if k != "logger"} 159 | return f"{self.__class__.__name__}({attrs!r})" 160 | 161 | 162 | def __str__(self) -> str: 163 | """ 164 | Return a string representation of the Body object. 165 | 166 | Returns: 167 | str: The string representation of the Body object. 168 | """ 169 | attrs = {k: v for k, v in vars(self).items() if k != "logger"} 170 | return f"{self.__class__.__name__}({attrs!r})" 171 | 172 | 173 | def _register(self) -> None: 174 | """ 175 | Register the body with the environment. 176 | 177 | Raises: 178 | NotImplementedError: This method is not implemented. 179 | """ 180 | raise NotImplementedError 181 | -------------------------------------------------------------------------------- /src/nett/brain/encoders/resnet18.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module defines the Resnet18CNN class, which is a custom feature extractor 3 | based on the ResNet-18 architecture. It is used for encoding observations in 4 | the Newborn Embodied Turing Test project. 5 | 6 | The Resnet18CNN class inherits from the BaseFeaturesExtractor class provided 7 | by the stable_baselines3 library. It takes an observation space and the desired 8 | number of features as input and extracts features using the ResNet-18 model. 9 | 10 | The ResNet-18 architecture consists of several residual blocks, each containing 11 | two convolutional layers and a skip connection. The final features are obtained 12 | by applying a linear layer to the output of the last residual block. 13 | 14 | The ResBlock class and the ResNet_18 class are helper classes used by the 15 | Resnet18CNN class to define the residual blocks and the overall ResNet-18 16 | architecture, respectively. 17 | 18 | Example usage: 19 | 20 | observation_space = gym.spaces.Box(low=0, high=255, shape=(3, 84, 84), dtype=np.uint8) 21 | features_dim = 256 22 | encoder = Resnet18CNN(observation_space, features_dim) 23 | features = encoder(observation) 24 | 25 | """ 26 | 27 | #!/usr/bin/env python3 28 | 29 | # import pdb 30 | import gym 31 | 32 | import torch as th 33 | from torch import nn 34 | 35 | from stable_baselines3.common.torch_layers import BaseFeaturesExtractor 36 | import logging 37 | 38 | logger = logging.getLogger(__name__) 39 | 40 | class Resnet18CNN(BaseFeaturesExtractor): 41 | """ 42 | Custom feature extractor based on the ResNet-18 architecture. 43 | 44 | Args: 45 | observation_space (gym.Space): The observation space of the environment. 46 | features_dim (int): Number of features to be extracted. 
47 | """ 48 | 49 | def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 256) -> None: 50 | super(Resnet18CNN, self).__init__(observation_space, features_dim) 51 | # We assume CxHxW images (channels first) 52 | # Re-ordering will be done by pre-preprocessing or wrapper 53 | ## pretrain set false; 54 | #self.cnn = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT) 55 | n_input_channels = observation_space.shape[0] 56 | logger.info("Resnet18CNN Encoder: ") 57 | self.cnn = ResNet_18(n_input_channels, features_dim) 58 | with th.no_grad(): 59 | n_flatten = self.cnn( 60 | th.as_tensor(observation_space.sample()[None]).float() 61 | ).shape[1] 62 | 63 | self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU()) 64 | 65 | def forward(self, observations: th.Tensor) -> th.Tensor: 66 | """ 67 | Forward pass of the feature extractor. 68 | 69 | Args: 70 | observations (torch.Tensor): The input observations. 71 | 72 | Returns: 73 | torch.Tensor: The extracted features. 74 | """ 75 | # Cut off image 76 | # reshape to from vector to W*H 77 | # gray to color transform 78 | # application of ResNet 79 | # Concat features to the rest of observation vector 80 | # return 81 | return self.linear(self.cnn(observations)) 82 | 83 | ## reference - online 84 | class ResBlock(nn.Module): 85 | """ 86 | Residual block used in the ResNet-18 architecture. 87 | 88 | Args: 89 | in_channels (int): Number of input channels. 90 | out_channels (int): Number of output channels. 91 | identity_downsample (nn.Sequential): Downsample layer for the identity. 92 | stride (int): Stride of the convolutional layers. 93 | """ 94 | 95 | def __init__(self, in_channels, out_channels, identity_downsample=None, stride=1) -> None: 96 | super(ResBlock, self).__init__() 97 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1) 98 | self.bn1 = nn.BatchNorm2d(out_channels) 99 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 100 | self.bn2 = nn.BatchNorm2d(out_channels) 101 | self.relu = nn.ReLU() 102 | self.identity_downsample = identity_downsample 103 | 104 | def forward(self, x: th.Tensor) -> th.Tensor: 105 | """ 106 | Forward pass of the residual block. 107 | 108 | Args: 109 | x (torch.Tensor): The input tensor. 110 | 111 | Returns: 112 | torch.Tensor: The output tensor. 113 | """ 114 | identity = x 115 | x = self.conv1(x) 116 | x = self.bn1(x) 117 | x = self.relu(x) 118 | x = self.conv2(x) 119 | x = self.bn2(x) 120 | if self.identity_downsample is not None: 121 | identity = self.identity_downsample(identity) 122 | x += identity 123 | x = self.relu(x) 124 | return x 125 | 126 | class ResNet_18(nn.Module): 127 | """ 128 | ResNet-18 architecture used in the Resnet18CNN class. 129 | 130 | Args: 131 | image_channels (int): Number of input channels. 132 | num_classes (int): Number of output classes. 
133 | """ 134 | 135 | def __init__(self, image_channels, num_classes) -> None: 136 | super(ResNet_18, self).__init__() 137 | self.in_channels = 64 138 | self.conv1 = nn.Conv2d(image_channels, 64, kernel_size=7, stride=2, padding=3) 139 | self.bn1 = nn.BatchNorm2d(64) 140 | self.relu = nn.ReLU() 141 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 142 | 143 | #resnet layers 144 | self.layer1 = self.__make_layer(64, 64, stride=1) 145 | self.layer2 = self.__make_layer(64, 128, stride=2) 146 | self.layer3 = self.__make_layer(128, 256, stride=2) 147 | self.layer4 = self.__make_layer(256, 512, stride=2) 148 | 149 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 150 | self.fc = nn.Linear(512, num_classes) 151 | 152 | def __make_layer(self, in_channels, out_channels, stride) -> nn.Sequential: 153 | """ 154 | Helper function to create a residual layer. 155 | 156 | Args: 157 | in_channels (int): Number of input channels. 158 | out_channels (int): Number of output channels. 159 | stride (int): Stride of the convolutional layers. 160 | 161 | Returns: 162 | nn.Sequential: The residual layer. 163 | """ 164 | identity_downsample = None 165 | if stride != 1: 166 | identity_downsample = self.identity_downsample(in_channels, out_channels) 167 | 168 | return nn.Sequential( 169 | ResBlock(in_channels, out_channels, identity_downsample=identity_downsample, stride=stride), 170 | ResBlock(out_channels, out_channels) 171 | ) 172 | 173 | def forward(self, x: th.Tensor) -> th.Tensor: 174 | """ 175 | Forward pass of the ResNet-18 architecture. 176 | 177 | Args: 178 | x (torch.Tensor): The input tensor. 179 | 180 | Returns: 181 | torch.Tensor: The output tensor. 182 | """ 183 | x = self.conv1(x) 184 | x = self.bn1(x) 185 | x = self.relu(x) 186 | x = self.maxpool(x) 187 | 188 | x = self.layer1(x) 189 | x = self.layer2(x) 190 | x = self.layer3(x) 191 | x = self.layer4(x) 192 | 193 | x = self.avgpool(x) 194 | x = x.view(x.shape[0], -1) 195 | x = self.fc(x) 196 | return x 197 | 198 | def identity_downsample(self, in_channels, out_channels) -> nn.Sequential: 199 | """ 200 | Helper function to create an identity downsample layer. 201 | 202 | Args: 203 | in_channels (int): Number of input channels. 204 | out_channels (int): Number of output channels. 205 | 206 | Returns: 207 | nn.Sequential: The identity downsample layer. 208 | """ 209 | return nn.Sequential( 210 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1), 211 | nn.BatchNorm2d(out_channels) 212 | ) 213 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | Banner 3 | 4 | # **Newborn Embodied Turing Test** 5 | 6 | Benchmarking Virtual Agents in Controlled-Rearing Conditions 7 | 8 | ![PyPI - Version](https://img.shields.io/pypi/v/nett-benchmarks) 9 | ![Python Version from PEP 621 TOML](https://img.shields.io/python/required-version-toml?tomlFilePath=https%3A%2F%2Fraw.githubusercontent.com%2Fbuildingamind%2FNewbornEmbodiedTuringTest%2Fmain%2Fpyproject.toml) 10 | ![GitHub License](https://img.shields.io/github/license/buildingamind/NewbornEmbodiedTuringTest) 11 | ![GitHub Issues or Pull Requests](https://img.shields.io/github/issues/buildingamind/NewbornEmbodiedTuringTest) 12 | 13 | [Getting Started](#getting-started) • 14 | [Documentation](https://buildingamind.github.io/NewbornEmbodiedTuringTest/) • 15 | [Lab Website](http://buildingamind.com/) 16 | 17 |
18 | 19 | The Newborn Embodied Turing Test (NETT) is a cutting-edge toolkit designed to simulate virtual agents in controlled-rearing conditions. This innovative platform enables researchers to create, simulate, and analyze virtual agents, facilitating direct comparisons with real chicks as documented by the **[Building a Mind Lab](http://buildingamind.com/)**. Our comprehensive suite includes all necessary components for the simulation and analysis of embodied models, closely replicating laboratory conditions. 20 | 21 | Below is a visual representation of our experimental setup, showcasing the infrastructure for the three primary experiments discussed in this documentation. 22 | 23 |
24 | 25 | Digital Twin 26 |
27 | 28 | ## How to Use this Repository 29 | 30 | The NETT toolkit comprises three key components: 31 | 32 | 1. **Virtual Environment**: A dynamic environment that serves as the habitat for virtual agents. 33 | 2. **Experimental Simulation Programs**: Tools to initiate and conduct experiments within the virtual world. 34 | 3. **Data Visualization Programs**: Utilities for analyzing and visualizing experiment outcomes. 35 | 36 | ## Directory Structure 37 | 38 | The directory structure of the code is as follows: 39 | 40 | ``` 41 | ├── docs # Documentation and guides 42 | ├── examples 43 | │ ├── notebooks # Jupyter Notebooks for examples 44 | │ └── Getting Started.ipynb # Introduction and setup notebook 45 | │ └── run # Terminal script example 46 | ├── src/nett 47 | │ ├── analysis # Analysis scripts 48 | │ ├── body # Agent body configurations 49 | │ ├── brain # Neural network models and learning algorithms 50 | │ ├── environment # Simulation environments 51 | │ ├── utils # Utility functions 52 | │ ├── nett.py # Main library script 53 | │ └── __init__.py # Package initialization 54 | ├── tests # Unit tests 55 | ├── mkdocs.yml # MkDocs configuration 56 | ├── pyproject.toml # Project metadata 57 | └── README.md # This README file 58 | ``` 59 | 60 | ## Getting Started 61 | 62 | 63 | To begin benchmarking your first embodied agent with NETT, please be aware: 64 | 65 | **Important**: The `mlagents==1.0.0` dependency is incompatible with Apple Silicon (M1, M2, etc.) chips. Please utilize an alternate device to execute this codebase. 66 | 67 | ### Installation 68 | 69 | 1. **Virtual Environment Setup (Highly Recommended)** 70 | 71 | Create and activate a virtual environment to avoid dependency conflicts. 72 | ```bash 73 | conda create -y -n nett_env python=3.10.12 74 | conda activate nett_env 75 | ``` 76 | See [here](https://uoa-eresearch.github.io/eresearch-cookbook/recipe/2014/11/20/conda "Link for how to set-up a virtual env") for detailed instructions. 77 | 78 | 2. **Install Prerequisites** 79 | 80 | Install the needed versions of `setuptools` and `pip`: 81 | ```bash 82 | pip install setuptools==65.5.0 pip==21 wheel==0.38.4 83 | ``` 84 | **NOTE:** This is a result of incompatibilities with the subdependency `gym==0.21`. More information about this issue can be found [here](https://github.com/openai/gym/issues/3176#issuecomment-1560026649). 85 | 86 | 3. **Toolkit Installation** 87 | 88 | Install the toolkit using `pip`. 89 | ```bash 90 | pip install nett-benchmarks 91 | ``` 92 | 93 | **NOTE**: Installation outside a virtual environment may fail due to conflicting dependencies. Ensure compatibility, especially with `gym==0.21` and `numpy<=1.21.2`. 94 | 95 | ### Running a NETT 96 | 97 | 1. **Download or Create the Unity Executable** 98 | 99 | Obtain a pre-made Unity executable from [here](https://origins.luddy.indiana.edu/environments/). The executable is required to run the virtual environment. 100 | 101 | 2. **Import NETT Components** 102 | 103 | Start by importing the NETT framework components - `Brain`, `Body`, and `Environment`, alongside the main `NETT` class. 104 | ```python 105 | from nett import Brain, Body, Environment 106 | from nett import NETT 107 | ``` 108 | 109 | 3. **Component Configuration**: 110 | 111 | - **Brain** 112 | 113 | Configure the learning aspects, including the policy network (e.g. "CnnPolicy"), learning algorithm (e.g. "PPO"), the reward function, and the encoder.
114 | ```python 115 | brain = Brain(policy="CnnPolicy", algorithm="PPO") 116 | ``` 117 | To get a list of all available policies, algorithms, and encoders, run `nett.list_policies()`, `nett.list_algorithms()`, and `nett.list_encoders()` respectively. 118 | 119 | - **Body** 120 | 121 | Set up the agent's physical interface with the environment. It's possible to apply gym.Wrappers for data preprocessing. 122 | ```python 123 | body = Body(type="basic", dvs=False, wrappers=None) 124 | ``` 125 | Here, we do not pass any wrappers, letting information from the environment reach the brain "as is". Alternative body types (e.g. `two-eyed`, `rag-doll`) are planned in future updates. 126 | 127 | - **Environment** 128 | 129 | Create the simulation environment using the path to your Unity executable (see Step 1). 130 | ```python 131 | environment = Environment(config="identityandview", executable_path="path/to/executable.x86_64") 132 | ``` 133 | To get a list of all available configurations, run `nett.list_configs()`. 134 | 135 | 4. **Run the Benchmarking** 136 | 137 | Integrate all components into a NETT instance to facilitate experiment execution. 138 | ```python 139 | benchmarks = NETT(brain=brain, body=body, environment=environment) 140 | ``` 141 | The `NETT` instance has a `.run()` method that initiates the benchmarking process. The method accepts parameters such as the number of brains, training/testing episodes, and the output directory. 142 | ```python 143 | job_sheet = benchmarks.run(output_dir="path/to/run/output/directory/", num_brains=5, train_eps=10, test_eps=5) 144 | ``` 145 | The `run` method is asynchronous, returning a list of jobs that may or may not yet be complete. If you wish to display the Unity environments while they run, set the `batch_mode` parameter to `False`. 146 | 147 | 5. **Check Status**: 148 | 149 | To see the status of the benchmark processes, use the `.status()` method: 150 | ```python 151 | benchmarks.status(job_sheet) 152 | ``` 153 | 154 | ### Running Standard Analysis 155 | 156 | After running the experiments, the pipeline will generate a collection of datafiles in the defined output directory. 157 | 158 | 1. **Install R and dependencies** 159 | 160 | To run the analyses performed in previous experiments, this toolkit provides a set of analysis scripts. Prior to running them, you will need R and the packages `tidyverse`, `argparse`, and `scales` installed. To install these packages, run the following command in R: 161 | ```R 162 | install.packages(c("tidyverse", "argparse", "scales")) 163 | ``` 164 | Alternatively, if you are having difficulty installing R on your system, you can install these packages using conda. 165 | ```bash 166 | conda install -y r r-tidyverse r-argparse r-scales 167 | ``` 168 | 2. **Run the Analysis** 169 | 170 | To run the analysis, use the `analyze` method of the `NETT` class. This method will generate a set of plots and tables based on the datafiles in the output directory. 171 | ```python 172 | benchmarks.analyze(run_dir="path/to/run/output/directory/", output_dir="path/to/analysis/output/directory/") 173 | ``` 174 | 175 | 176 | 177 | 178 | ## Documentation 179 | The full documentation is available [here](https://buildingamind.github.io/NewbornEmbodiedTuringTest/). 180 | 181 | ## Experiment Configuration 182 | 183 | More information about each experiment can be found on the following pages.
184 | 185 | * [**Parsing Experiment**](https://buildingamind.github.io/NewbornEmbodiedTuringTest/papers/Parsing.html) 186 | * [**ViewPoint Experiment**](https://buildingamind.github.io/NewbornEmbodiedTuringTest/papers/ViewInvariant.html) 187 | 188 | [🔼 Back to top](#newborn-embodied-turing-test) 189 | -------------------------------------------------------------------------------- /src/nett/analysis/NETT_test_viz.R: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env Rscript 2 | 3 | # NETT_test_viz.R 4 | 5 | # Before running this script, you need to run merge_csvs to merge all of the agents' 6 | # output into a single, standardized format dataframe for training and test data 7 | 8 | # Variables -------------------------------------------------------------------- 9 | 10 | # Read in the user-specified variables: 11 | library(argparse) 12 | parser <- ArgumentParser(description="An executable R script for the Newborn Embodied Turing Tests to analyze test trials") 13 | parser$add_argument("--data-loc", type="character", dest="data_loc", 14 | help="Full filename (inc working directory) of the merged R data", 15 | required=TRUE) 16 | parser$add_argument("--chick-file", type="character", dest="chick_file", 17 | help="Full filename (inc working directory) of the chick data CSV file", 18 | required=TRUE) 19 | parser$add_argument("--results-wd", type="character", dest="results_wd", 20 | help="Working directory to save the resulting visualizations", 21 | required=TRUE) 22 | parser$add_argument("--bar-order", type = "character", default = "default", dest = "bar_order", 23 | help="Order of bars. Use 'default', 'asc', 'desc', or specify indices separated by commas (e.g., '3,2,1,4')", 24 | required=FALSE) 25 | parser$add_argument("--color-bars", type = "character", dest="color_bars", 26 | help="Should the bars be colored by test condition?", 27 | required=TRUE) 28 | 29 | # Set script variables based on user input 30 | args <- parser$parse_args() 31 | data_loc <- args$data_loc; chick_file <- args$chick_file; results_wd <- args$results_wd; bar_order <- args$bar_order 32 | if( args$color_bars %in% c("t", "T", "true", "TRUE", "True")) {color_bars <- TRUE} else { color_bars <- FALSE} 33 | 34 | # Set Up ----------------------------------------------------------------------- 35 | 36 | library(tidyverse) 37 | library(stringr) 38 | 39 | # Load the chick data 40 | chick_data <- read.csv(chick_file) 41 | 42 | # Load test data 43 | load(data_loc) 44 | rm(train_data) 45 | 46 | # Code each episode correct/incorrect 47 | test_data <- test_data %>% 48 | mutate(correct_steps = if_else(correct.monitor == " left", left_steps, right_steps)) %>% 49 | mutate(incorrect_steps = if_else(correct.monitor == " left", right_steps, left_steps)) %>% 50 | mutate(percent_correct = correct_steps / (correct_steps + incorrect_steps)) 51 | 52 | # Adjust bar order according to user input ------------------------------------- 53 | 54 | # Create a variable to store the final order 55 | order <- NULL 56 | if (bar_order == "default" || bar_order == "asc" || bar_order == "desc"){ 57 | order <- bar_order 58 | }else { 59 | order <- as.integer(strsplit(bar_order, ",")[[1]]) 60 | } 61 | 62 | # Conditionally reorder the dataframe based on user input 63 | if (!is.null(order)) { 64 | if (order == "desc") { 65 | test_data <- test_data %>% 66 | arrange(desc(percent_correct)) %>% 67 | mutate(test.cond = factor(test.cond, levels = unique(test.cond))) 68 | } else if (order == "asc"){ 69 | test_data <-
test_data %>% 70 | arrange(percent_correct) %>% 71 | mutate(test.cond = factor(test.cond, levels = unique(test.cond))) 72 | } else if (order != "default") { 73 | # Map numeric indices to factor levels 74 | current_order <- levels(factor(test_data$test.cond)) 75 | new_order <- current_order[order] 76 | test_data$test.cond <- factor(test_data$test.cond, levels = new_order) 77 | } 78 | # If order is "default", no need to change anything 79 | } 80 | 81 | 82 | # Plot aesthetic settings ------------------------------------------------------ 83 | custom_palette <- c("#3F8CB7", "#FCEF88", "#5D5797", "#62AC6B", "#B74779", "#2C4E98","#CCCCE7", "#08625B", "#D15056") 84 | chickred <- "#AF264A" 85 | 86 | p <- ggplot() + 87 | theme_classic() + 88 | theme(axis.text.x = element_text(size = 6)) + 89 | ylab("Percent Correct") + 90 | xlab("Test Condition") + 91 | scale_y_continuous(expand = c(0, 0), limits = c(0, 1), breaks=seq(0,1,.1), labels = scales::percent) + 92 | geom_hline(yintercept = .5, linetype = 2) + 93 | scale_fill_manual(values = custom_palette) + 94 | scale_colour_manual(values = custom_palette) + 95 | theme(axis.title = element_text(face="bold"), 96 | axis.text.x = element_text(face="bold", size=7.5), 97 | axis.text.y = element_text(face="bold", size=7.5)) 98 | 99 | 100 | # Bar Chart Function ----------------------------------------------------------- 101 | make_bar_charts <- function(data, dots, aes_y, error_min, error_max, img_name) 102 | { 103 | p + 104 | 105 | # Add chicken performance FIRST to sort the bars 106 | geom_errorbar(data=chick_data, width = 0.7, colour = chickred, 107 | aes(x=test.cond, ymin=avg, ymax=avg)) + 108 | 109 | # Model performance: bars 110 | {if(color_bars)geom_col(data = data, width = 0.7, aes(x=test.cond, y = {{aes_y}}, fill = test.cond))}+ 111 | {if(!color_bars)geom_col(data = data, width = 0.7, aes(x=test.cond, y = {{aes_y}}), fill = "gray45")}+ 112 | # Model performance: error bars 113 | geom_errorbar(data = data, width = 0.3, 114 | aes(x = test.cond, ymin = {{error_min}}, ymax = {{error_max}})) + 115 | # Model performance: dots 116 | {if(!is.null(dots))geom_jitter(data = dot_data, aes(x=test.cond, y = avgs), width = .3)}+ 117 | theme(legend.position="none") + 118 | 119 | # Add chicken performance again so that it shows up on top 120 | # Chick performance: lines (errorbar) with ribbons (crossbar) 121 | geom_errorbar(data=chick_data, width = 0.7, colour = chickred, 122 | aes(x=test.cond, ymin=avg, ymax=avg)) + 123 | geom_crossbar(data=chick_data, width = 0.7, 124 | linetype = 0, fill = chickred, alpha = 0.2, 125 | aes(x = test.cond, y = avg, 126 | ymin = avg - avg_dev, ymax = avg + avg_dev)) 127 | 128 | ggsave(img_name, width = 6, height = 6) 129 | } 130 | 131 | # Switch wd before we save the graphs 132 | setwd(results_wd) 133 | 134 | # Plot by agent ---------------------------------------------------------------- 135 | ## Leave rest data for agent-level graphs 136 | 137 | ## Group data by test conditions 138 | by_test_cond <- test_data %>% 139 | group_by(imprint.cond, agent, test.cond) %>% 140 | summarise(avgs = mean(percent_correct, na.rm = TRUE), 141 | sd = sd(percent_correct, na.rm = TRUE), 142 | count = length(percent_correct), 143 | tval = tryCatch({ (t.test(percent_correct, mu=0.5)$statistic)}, error = function(err){NA}), 144 | df = tryCatch({(t.test(percent_correct, mu=0.5)$parameter)},error = function(err){NA}), 145 | pval = tryCatch({(t.test(percent_correct, mu=0.5)$p.value)},error = function(err){NA}))%>% 146 | mutate(se = sd / sqrt(count)) %>% 147 | 
mutate(cohensd = (avgs - .5) / sd) %>% 148 | mutate(imp_agent = paste(imprint.cond, agent, sep="_")) 149 | 150 | write.csv(by_test_cond, "stats_by_agent.csv") 151 | 152 | for (i in unique(by_test_cond$imp_agent)) 153 | { 154 | bar_data <- by_test_cond %>% 155 | filter(imp_agent == i) 156 | 157 | img_name <- paste0(i, "_test.png") 158 | 159 | make_bar_charts(data = bar_data, 160 | dots = NULL, 161 | aes_y = avgs, 162 | error_min = avgs - se, 163 | error_max = avgs + se, 164 | img_name = img_name) 165 | } 166 | 167 | 168 | # Plot by imprinting condition ------------------------------------------------- 169 | ## Remove rest data once we start to group agents (for ease of presentation) 170 | 171 | by_imp_cond <- by_test_cond %>% 172 | ungroup() %>% 173 | group_by(imprint.cond, test.cond) %>% 174 | summarise(avgs_by_imp = mean(avgs, na.rm = TRUE), 175 | sd = sd(avgs, na.rm = TRUE), 176 | count = length(avgs), 177 | tval = tryCatch({ (t.test(avgs, mu=0.5)$statistic)}, error = function(err){NA}), 178 | df = tryCatch({ (t.test(avgs, mu=0.5)$parameter)}, error = function(err){NA}), 179 | pval = tryCatch({ (t.test(avgs, mu=0.5)$p.value)}, error = function(err){NA}))%>% 180 | mutate(se = sd / sqrt(count)) %>% 181 | mutate(cohensd = (avgs_by_imp - .5) / sd) 182 | 183 | write.csv(by_imp_cond, "stats_by_imp_cond.csv") 184 | 185 | for (i in unique(by_imp_cond$imprint.cond)) 186 | { 187 | bar_data <- by_imp_cond %>% 188 | filter(imprint.cond == i) %>% 189 | filter(test.cond != "Rest") 190 | 191 | dot_data <- by_test_cond %>% 192 | filter(imprint.cond == i) %>% 193 | filter(test.cond != "Rest") 194 | 195 | img_name <- paste0(i, "_test.png") 196 | 197 | make_bar_charts(data = bar_data, 198 | dots = dot_data, 199 | aes_y = avgs_by_imp, 200 | error_min = avgs_by_imp - se, 201 | error_max = avgs_by_imp + se, 202 | img_name = img_name) 203 | } 204 | 205 | 206 | # Plot across all imprinting conditions ---------------------------------------- 207 | across_imp_cond <- by_test_cond %>% 208 | ungroup() %>% 209 | filter(test.cond != "Rest") %>% 210 | group_by(test.cond) %>% 211 | summarise(all_avgs = mean(avgs, na.rm = TRUE), 212 | sd = sd(avgs, na.rm = TRUE), 213 | count = length(avgs), 214 | tval = tryCatch({ (t.test(avgs, mu=0.5)$statistic)}, error = function(err){NA}), 215 | df = tryCatch({ (t.test(avgs, mu=0.5)$parameter)}, error = function(err){NA}), 216 | pval = tryCatch({ (t.test(avgs, mu=0.5)$p.value)}, error = function(err){NA}))%>% 217 | mutate(se = sd / sqrt(count)) %>% 218 | mutate(cohensd = (all_avgs - .5) / sd) 219 | 220 | write.csv(across_imp_cond, "stats_across_all_agents.csv") 221 | 222 | dot_data <- filter(by_test_cond, test.cond != "Rest") 223 | 224 | make_bar_charts(data = across_imp_cond, 225 | dots = dot_data, 226 | aes_y = all_avgs, 227 | error_min = all_avgs - se, 228 | error_max = all_avgs + se, 229 | img_name = "all_imprinting_conds_test.png") 230 | -------------------------------------------------------------------------------- /src/nett/environment/builder.py: -------------------------------------------------------------------------------- 1 | """Module for the Environment class.""" 2 | 3 | from __future__ import annotations 4 | 5 | import os 6 | import subprocess 7 | from typing import Optional, Any 8 | 9 | import numpy as np 10 | from gym import Wrapper 11 | from mlagents_envs.environment import UnityEnvironment 12 | 13 | # checks to see if ml-agents tmp files have the proper permissions 14 | try : 15 | from mlagents_envs.envs.unity_gym_env import UnityToGymWrapper 16 | except 
PermissionError as _: 17 | raise PermissionError("Directory '/tmp/ml-agents-binaries' is not accessible. Please change permissions of the directory and its subdirectories ('tmp' and 'binaries') to 1777 or delete the entire directory and try again.") 18 | 19 | from nett.environment import configs 20 | from nett.environment.configs import NETTConfig, list_configs 21 | from nett.utils.environment import Logger, port_in_use 22 | 23 | class Environment(Wrapper): 24 | """ 25 | Represents the environment where the agent lives. 26 | 27 | The environment is the source of all input data streams to train the brain of the agent. 28 | It accepts a Unity Executable and wraps it around as a Gym environment by leveraging the UnityEnvironment 29 | class from the mlagents_envs library. 30 | 31 | It provides a convenient interface for interacting with the Unity environment and includes methods for initializing the environment, rendering frames, taking steps, resetting the environment, and logging messages. 32 | 33 | Args: 34 | config (str | NETTConfig): The configuration for the environment. It can be either a string representing the name of a pre-defined configuration, or an instance of the NETTConfig class. 35 | executable_path (str): The path to the Unity executable file. 36 | display (int, optional): The display number to use for the Unity environment. Defaults to 0. 37 | base_port (int, optional): The base port number to use for communication with the Unity environment. Defaults to 5004. 38 | record_chamber (bool, optional): Whether to record the chamber. Defaults to False. 39 | record_agent (bool, optional): Whether to record the agent. Defaults to False. 40 | recording_frames (int, optional): The number of frames to record. Defaults to 1000. 41 | 42 | Raises: 43 | ValueError: If the configuration is not a valid string or an instance of NETTConfig. 44 | 45 | Example: 46 | 47 | >>> from nett import Environment 48 | >>> env = Environment(config="identityandview", executable_path="path/to/executable") 49 | """ 50 | def __init__(self, 51 | config: str | NETTConfig, 52 | executable_path: str, 53 | display: int = 0, 54 | base_port: int = 5004, 55 | record_chamber: bool = False, 56 | record_agent: bool = False, 57 | recording_frames: int = 1000) -> None: 58 | """Constructor method 59 | """ 60 | from nett import logger 61 | self.logger = logger.getChild(__class__.__name__) 62 | self.config = self._validate_config(config) 63 | # TODO (v0.5) what might be a way to check if it is a valid executable path? 64 | self.executable_path = executable_path 65 | self.base_port = base_port 66 | self.record_chamber = record_chamber 67 | self.record_agent = record_agent 68 | self.recording_frames = recording_frames 69 | self.display = display 70 | 71 | # set the correct permissions on the executable 72 | self._set_executable_permission() 73 | # set the display for Unity environment 74 | self._set_display() 75 | 76 | def _validate_config(self, config: str | NETTConfig) -> NETTConfig: 77 | """ 78 | Validates the configuration for the environment. 79 | 80 | Args: 81 | config (str | NETTConfig): The configuration to validate. 82 | 83 | Returns: 84 | NETTConfig: The validated configuration. 85 | 86 | Raises: 87 | ValueError: If the configuration is not a valid string or an instance of NETTConfig. 
88 | """ 89 | # for when config is a str 90 | if isinstance(config, str): 91 | config_dict = {config_str.lower(): config_str for config_str in list_configs()} 92 | if config not in config_dict.keys(): 93 | raise ValueError(f"Should be one of {config_dict.keys()}") 94 | 95 | config = getattr(configs, config_dict[config])() 96 | 97 | # for when config is a NETTConfig 98 | elif isinstance(config, NETTConfig): 99 | pass 100 | 101 | else: 102 | raise ValueError(f"Should either be one of {list(config_dict.keys())} or a subclass of NETTConfig") 103 | 104 | return config 105 | 106 | def _set_executable_permission(self) -> None: 107 | """ 108 | Sets the executable permission for the Unity executable file. 109 | """ 110 | subprocess.run(["chmod", "-R", "755", self.executable_path], check=True) 111 | self.logger.info("Executable permission is set") 112 | 113 | def _set_display(self) -> None: 114 | """ 115 | Sets the display environment variable for the Unity environment. 116 | """ 117 | os.environ["DISPLAY"] = str(f":{self.display}") 118 | self.logger.info("Display is set") 119 | 120 | 121 | # copied from __init__() of chickai_env_wrapper.py (legacy) 122 | # TODO (v0.4) Critical refactor, don't like how this works, extremely error prone. 123 | # how can we build + constraint arguments better? something like an ArgumentParser sounds neat 124 | # TODO (v0.4) fix random_pos logic inside of Unity code 125 | def initialize(self, mode: str, **kwargs) -> None: 126 | """ 127 | Initializes the environment with the given mode and arguments. 128 | 129 | Args: 130 | mode (str): The mode to set the environment for training or testing or both. 131 | **kwargs: The arguments to pass to the environment. 132 | """ 133 | 134 | args = [] 135 | 136 | # from environment arguments 137 | if self.recording_frames: 138 | args.extend(["--recording-steps", str(self.recording_frames)]) 139 | if self.record_chamber: 140 | args.extend(["--record-chamber", "true"]) 141 | if self.record_agent: 142 | args.extend(["--record-agent", "true"]) 143 | 144 | # from runtime 145 | args.extend(["--mode", f"{mode}-{kwargs['condition']}"]) 146 | if kwargs.get("rec_path", None): 147 | args.extend(["--log-dir", f"{kwargs['rec_path']}/"]) 148 | # needs to fixed in Unity code where the default is always false 149 | if mode == "train": 150 | args.extend(["--random-pos", "true"]) 151 | if kwargs.get("rewarded", False): 152 | args.extend(["--rewarded", "true"]) 153 | self.step_per_episode = kwargs.get("episode_steps", 1000) 154 | args.extend(["--episode-steps", str(self.step_per_episode)]) 155 | 156 | 157 | # if kwargs["device_type"] == "cpu": 158 | # args.extend(["-batchmode", "-nographics"]) 159 | # elif kwargs["batch_mode"]: 160 | if kwargs["batch_mode"]: 161 | args.append("-batchmode") 162 | 163 | # TODO: Figure out a way to run on multiple GPUs 164 | # if ("device" in kwargs): 165 | # args.extend(["-force-device-index", str(kwargs["device"])]) 166 | # args.extend(["-gpu", str(kwargs["device"])]) 167 | 168 | # find unused port 169 | while port_in_use(self.base_port): 170 | self.base_port += 1 171 | 172 | # create logger 173 | self.log = Logger(f"{kwargs['condition'].replace('-', '_')}{kwargs['run_id']}-{mode}", 174 | log_dir=f"{kwargs['log_path']}/") 175 | 176 | # create environment and connect it to logger 177 | self.env = UnityEnvironment(self.executable_path, side_channels=[self.log], additional_args=args, base_port=self.base_port) 178 | self.env = UnityToGymWrapper(self.env, uint8_visual=True) 179 | 180 | # initialize the parent class 
(gym.Wrapper) 181 | super().__init__(self.env) 182 | 183 | # converts the (c, w, h) frame returned by mlagents v1.0.0 and Unity 2022.3 to (w, h, c) 184 | # as expected by gym==0.21.0 185 | # HACK: mode is not used, but is required by the gym.Wrapper class (might be unnecessary but keeping for now) 186 | def render(self, mode="rgb_array") -> np.ndarray: # pylint: disable=unused-argument 187 | """ 188 | Renders the current frame of the environment. 189 | 190 | Args: 191 | mode (str, optional): The mode to render the frame in. Defaults to "rgb_array". 192 | 193 | Returns: 194 | numpy.ndarray: The rendered frame of the environment. 195 | """ 196 | return np.moveaxis(self.env.render(), [0, 1, 2], [2, 0, 1]) 197 | 198 | def step(self, action: list[Any]) -> tuple[np.ndarray, float, bool, dict]: 199 | """ 200 | Takes a step in the environment with the given action. 201 | 202 | Args: 203 | action (list[Any]): The action to take in the environment. 204 | 205 | Returns: 206 | tuple[numpy.ndarray, float, bool, dict]: A tuple containing the next state, reward, done flag, and info dictionary. 207 | """ 208 | next_state, reward, done, info = self.env.step(action) 209 | return next_state, float(reward), done, info 210 | 211 | def log(self, msg: str) -> None: 212 | """ 213 | Logs a message to the environment. 214 | 215 | Args: 216 | msg (str): The message to log. 217 | """ 218 | self.log.log_str(msg) 219 | 220 | def reset(self, seed: Optional[int] = None, **kwargs) -> None | list[np.ndarray] | np.ndarray: # pylint: disable=unused-argument 221 | # nothing to do if the wrapped env does not accept `seed` 222 | """ 223 | Resets the environment with the given seed and arguments. 224 | 225 | Args: 226 | seed (int, optional): The seed to use for the environment. Defaults to None. 227 | **kwargs: The arguments to pass to the environment. 228 | 229 | Returns: 230 | numpy.ndarray: The initial state of the environment. 
231 | """ 232 | return self.env.reset(**kwargs) 233 | 234 | def __repr__(self) -> str: 235 | attrs = {k: v for k, v in vars(self).items() if k != "logger"} 236 | return f"{self.__class__.__name__}({attrs!r})" 237 | 238 | def __str__(self) -> str: 239 | attrs = {k: v for k, v in vars(self).items() if k != "logger"} 240 | return f"{self.__class__.__name__}({attrs!r})" 241 | -------------------------------------------------------------------------------- /examples/run/run.py: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | from abc import abstractmethod 3 | import os 4 | import logging 5 | import argparse 6 | import yaml 7 | from nett import Brain, Body, Environment 8 | from nett import NETT 9 | from nett.environment.configs import Binding, Parsing, ViewInvariant 10 | from wrapper.dvs_wrapper import DVSWrapper 11 | 12 | logger = logging.getLogger(__name__) 13 | logger.setLevel(logging.INFO) 14 | 15 | def load_configuration(config_path: str): 16 | with open(config_path, 'r') as f: 17 | return yaml.safe_load(f) 18 | 19 | class BodyConfiguration: 20 | def __init__(self, kwargs): 21 | for key, value in kwargs.items(): 22 | setattr(self, key, value) 23 | 24 | class BrainConfiguration: 25 | def __init__(self, kwargs): 26 | for key, value in kwargs.items(): 27 | setattr(self, key, value) 28 | 29 | class EnvironmentConfiguration: 30 | def __init__(self, kwargs): 31 | for key, value in kwargs.items(): 32 | setattr(self, key, value) 33 | 34 | class Experiment: 35 | """ 36 | Generic Experiment Class To Run 3 experiments - Parsing, Binding and ViewInvariant 37 | """ 38 | 39 | def __init__(self, **kwargs) -> None: 40 | ## initialize configurations 41 | self.brain_config = BrainConfiguration(kwargs.get('Brain')) 42 | self.body_config = BodyConfiguration(kwargs.get('Body')) 43 | self.env_config = EnvironmentConfiguration(kwargs.get('Environment')) 44 | self.base_simclr_checkpoint_path = os.path.join(os.getcwd(), "../data/checkpoints") 45 | 46 | self.encoder_config = { 47 | "small": { 48 | "feature_dimensions": 512, # replace with actual feature dimensions for 'small' 49 | "encoder": "", 50 | }, 51 | "medium": { 52 | "feature_dimensions": 128, # replace with actual feature dimensions for 'medium' 53 | "encoder": "resnet10", 54 | }, 55 | "large": { 56 | "feature_dimensions": 128, # replace with actual feature dimensions for 'large' 57 | "encoder": "resnet18", 58 | }, 59 | "dinov2": { 60 | "feature_dimensions": 384, # replace with actual feature dimensions for 'dinov2' 61 | "encoder": "dinvo2", 62 | }, 63 | "dinov1": { 64 | "feature_dimensions": 384, # replace with actual feature dimensions for 'dinov1', 65 | "encoder": "dinov1", 66 | }, 67 | "simclr": { 68 | "feature_dimensions": 512, # replace with actual feature dimensions for 'ego4d' 69 | "encoder": "frozensimclr", 70 | }, 71 | "sam": { 72 | "feature_dimensions": 256, # replace with actual feature dimensions for 'sam' 73 | "encoder": "sam", 74 | } 75 | } 76 | 77 | ## Environment 78 | self.env = self.initialize_environment() 79 | 80 | ## Body 81 | self.body = self.initialize_body() 82 | 83 | ## Brain 84 | self.brain = self.initialize_brain() 85 | 86 | ## configuration 87 | config = kwargs.get('Config') 88 | self.train_eps = config['train_eps'] 89 | self.test_eps = config['test_eps'] 90 | self.mode = config['mode'] 91 | self.num_brains = config['num_brains'] 92 | self.output_dir = config['output_dir'] 93 | self.run_id = config['run_id'] 94 | 95 | print(self.train_eps, self.test_eps, self.mode, 
self.num_brains, self.output_dir, self.run_id) 96 | 97 | def initialize_brain(self): 98 | """ 99 | Initialize Brain class with the attributes extracted from the brain_config 100 | 101 | Returns: 102 | _type_: _description_ 103 | """ 104 | 105 | # Extract attributes from brain_config 106 | brain_config_attrs = {attr: getattr(self.brain_config, attr) for attr in dir(self.brain_config) \ 107 | if not attr.startswith('__')} 108 | 109 | 110 | ## update encoder attr in brain_config 111 | if brain_config_attrs['encoder']: 112 | encoder_config = self.encoder_config[brain_config_attrs['encoder']] 113 | brain_config_attrs['encoder'] = encoder_config['encoder'] 114 | brain_config_attrs['embedding_dim'] = encoder_config['feature_dimensions'] 115 | 116 | 117 | 118 | ## Add checkpoint path 119 | if brain_config_attrs.get('encoder') == 'frozensimclr': 120 | checkpt_path = self.get_checkpoint_path() 121 | brain_config_attrs['custom_encoder_args'] = {'checkpoint_path':\ 122 | self.get_checkpoint_path()} 123 | 124 | # Initialize Brain class with the extracted attributes 125 | brain = Brain(**brain_config_attrs) 126 | return brain 127 | 128 | def initialize_body(self): 129 | wrappers = [] 130 | if self.body_config.dvs: 131 | wrappers = [DVSWrapper] 132 | 133 | return Body(type='basic', wrappers=wrappers) 134 | 135 | @abstractmethod 136 | def initialize_environment(self): 137 | pass## abstract method to be implemented by the child classes 138 | 139 | def run(self): 140 | benchmarks = NETT(brain=self.brain, body=self.body, environment=self.env) 141 | print(self.mode) 142 | benchmarks.run(num_brains=self.num_brains, \ 143 | train_eps=self.train_eps, \ 144 | test_eps=self.train_eps, \ 145 | mode=self.mode, \ 146 | job_memory=21, \ 147 | output_dir=self.output_dir,run_id=self.run_id) 148 | 149 | #logger.info("Experiment completed successfully") 150 | 151 | class ParsingExperiment(Experiment): 152 | def __init__(self, **kwargs) -> None: 153 | super().__init__(**kwargs) 154 | 155 | def get_checkpoint_path(self): 156 | ## compute simclr checkpoints 157 | checkpoint_dict = { 158 | "ship_a": "ship_A/checkpoints/epoch=97-step=14601.ckpt", 159 | "ship_b": "ship_B/checkpoints/epoch=97-step=14601.ckpt", 160 | "ship_c": "ship_C/checkpoints/epoch=96-step=14452.ckpt", 161 | "fork_b": "fork_B/checkpoints/epoch=95-step=14303.ckpt", 162 | "fork_a": "fork_A/checkpoints/epoch=97-step=14601.ckpt", 163 | "fork_c": "fork_C/checkpoints/epoch=97-step=14601.ckpt" 164 | } 165 | 166 | parsing_checkpoint = os.path.join(self.base_simclr_checkpoint_path, 'simclr_parsing') 167 | checkpoint_key = f"{self.object.lower()}_{self.background.lower()}" 168 | path = checkpoint_dict.get(checkpoint_key, '') 169 | return os.path.join(parsing_checkpoint, path) 170 | 171 | def initialize_environment(self): 172 | """ 173 | Initialize environment class with the attributes extracted from the env_config 174 | 175 | Returns: 176 | _type_: _description_ 177 | """ 178 | self.object = "ship" if getattr(self.env_config, 'use_ship', False) else "fork" 179 | self.background = getattr(self.env_config, 'background', '') 180 | 181 | # Extract attributes from brain_config 182 | env_config_attrs = {attr: getattr(self.env_config, attr) for attr in dir(self.env_config) \ 183 | if not attr.startswith('__')} 184 | 185 | del env_config_attrs['use_ship'] 186 | del env_config_attrs['background'] 187 | 188 | env_config_attrs['config'] = Parsing(background=self.background, object=self.object) 189 | return Environment(**env_config_attrs) 190 | 191 | class 
BindingExperiment(Experiment): 192 | def __init__(self, **kwargs) -> None: 193 | super().__init__(**kwargs) 194 | 195 | def get_checkpoint_path(self): 196 | ## compute simclr checkpoints 197 | checkpoint_dict = { 198 | "object_1": "object_1/checkpoints/epoch=97-step=14601.ckpt", 199 | "object_2": "object_2/checkpoints/epoch=97-step=14601.ckpt" 200 | } 201 | 202 | parsing_checkpoint = os.path.join(self.base_simclr_checkpoint_path, 'simclr_binding') 203 | checkpoint_key = f"{self.object.lower()}" 204 | path = checkpoint_dict.get(checkpoint_key, '') 205 | return os.path.join(parsing_checkpoint, path) 206 | 207 | def initialize_environment(self): 208 | """ 209 | Initialize environment class with the attributes extracted from the env_config 210 | 211 | Returns: 212 | _type_: _description_ 213 | """ 214 | # Extract attributes from brain_config 215 | env_config_attrs = {attr: getattr(self.env_config, attr) for attr in dir(self.env_config) \ 216 | if not attr.startswith('__')} 217 | del env_config_attrs['object'] 218 | env_config_attrs['config'] = Binding(object= self.env_config.object) 219 | return Environment(**env_config_attrs) 220 | 221 | class ViewInvariantExperiment(Experiment): 222 | def __init__(self, **kwargs) -> None: 223 | super().__init__(**kwargs) 224 | 225 | def get_checkpoint_path(self): 226 | ## compute simclr checkpoints 227 | checkpoint_dict = { 228 | "ship_side":"", 229 | "ship_front":"", 230 | "fork_side":"", 231 | "fork_front":"" 232 | } 233 | 234 | viewpt_checkpoint = os.path.join(self.base_simclr_checkpoint_path, 'simclr_viewpt') 235 | checkpoint_key = f"{self.object}_{self.view.lower()}" 236 | path = checkpoint_dict.get(checkpoint_key, '') 237 | return os.path.join(viewpt_checkpoint, path) 238 | 239 | def initialize_environment(self): 240 | """ 241 | Initialize environment class with the attributes extracted from the env_config 242 | 243 | Returns: 244 | _type_: _description_ 245 | """ 246 | self.object = "ship" if getattr(self.env_config, 'use_ship', False) else "fork" 247 | self.view = "side" if getattr(self.env_config, 'side_view', False) else "front" 248 | 249 | # Extract attributes from brain_config 250 | env_config_attrs = {attr: getattr(self.env_config, attr) for attr in dir(self.env_config) \ 251 | if not attr.startswith('__')} 252 | 253 | del env_config_attrs['use_ship'] 254 | del env_config_attrs['side_view'] 255 | 256 | env_config_attrs['config'] = ViewInvariant(object=self.object, view=self.view) 257 | return Environment(**env_config_attrs) 258 | 259 | def main(): 260 | args = parse_args() 261 | 262 | if args.exp_name: 263 | exp_name = args.exp_name 264 | config_path = f'configuration/{exp_name}.yaml' 265 | config = load_configuration(config_path) 266 | 267 | if exp_name == 'parsing': 268 | exp = ParsingExperiment(**config) 269 | exp.run() 270 | 271 | 272 | elif exp_name == 'binding': 273 | exp = BindingExperiment(**config) 274 | exp.run() 275 | 276 | elif exp_name == 'viewinvariant': 277 | exp = ViewInvariantExperiment(**config) 278 | exp.run() 279 | 280 | else: 281 | raise ValueError("Invalid Experiment Name") 282 | 283 | def parse_args(): 284 | parser = argparse.ArgumentParser(description='Run the NETT pipeline - NeurIPS 2021 submission') 285 | parser.add_argument('-exp_name', '--exp_name', type=str, required=True, default="binding", 286 | help='name of the experiment') 287 | return parser.parse_args() 288 | 289 | 290 | if __name__ == '__main__': 291 | main() 292 | -------------------------------------------------------------------------------- 
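The experiment classes above are configured entirely from the YAML files under `examples/run/configuration/` (e.g. `parsing.yaml`), which `load_configuration()` turns into the keyword arguments consumed by `Experiment`. As a rough sketch, the dictionary below mirrors the four top-level keys that `run.py` reads (`Brain`, `Body`, `Environment`, `Config`); the individual field names and values are assumptions inferred from the attribute accesses in the code, so the shipped YAML files remain the authoritative reference.

```python
# Hypothetical configuration for ParsingExperiment, approximating what
# load_configuration("configuration/parsing.yaml") might return.
# Field names are inferred from the attributes run.py accesses; they are
# illustrative and may not match the shipped YAML exactly.
config = {
    "Brain": {
        "policy": "CnnPolicy",
        "algorithm": "PPO",
        "encoder": "medium",        # one of the keys in Experiment.encoder_config
    },
    "Body": {
        "dvs": False,               # if True, observations pass through DVSWrapper
    },
    "Environment": {
        "executable_path": "path/to/executable.x86_64",
        "use_ship": True,           # consumed by ParsingExperiment, removed before Environment(...)
        "background": "A",
    },
    "Config": {
        "train_eps": 10,
        "test_eps": 5,
        "mode": "full",             # assumed value; use whatever NETT.run() accepts
        "num_brains": 3,
        "output_dir": "path/to/output",
        "run_id": "parsing_demo",
    },
}

# Roughly equivalent to `python run.py --exp_name parsing` with the dictionary
# above stored as configuration/parsing.yaml (run in the context of examples/run/run.py).
exp = ParsingExperiment(**config)
exp.run()
```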
/src/nett/environment/configs.py: -------------------------------------------------------------------------------- 1 | """This module contains the NETT configurations for different experiments.""" 2 | 3 | import sys 4 | import inspect 5 | from typing import Any 6 | from abc import ABC, abstractmethod 7 | from itertools import product 8 | 9 | # the naming is confusing since it is used for train or test too. 10 | class NETTConfig(ABC): 11 | """Abstract base class for NETT configurations. 12 | 13 | Args: 14 | param_defaults (dict[str, str]): A dictionary of parameter defaults. 15 | **params: Keyword arguments representing the configuration parameters. 16 | 17 | Raises: 18 | ValueError: If any parameter value is not a value or subset of the default values. 19 | """ 20 | 21 | def __init__(self, param_defaults: dict[str, str], **params) -> None: 22 | """Constructor method 23 | """ 24 | self.param_defaults = param_defaults 25 | self.params = self._validate_params(params) 26 | self.conditions = self._create_conditions_from_params(self.params) 27 | 28 | def _create_conditions_from_params(self, params: dict[str, str]) -> list[str]: 29 | """ 30 | Creates conditions from the configuration parameters. 31 | 32 | Args: 33 | params (dict[str, str]): The configuration parameters. 34 | 35 | Returns: 36 | list[str]: A list of conditions. 37 | """ 38 | combination_params = list(product(*params.values())) 39 | conditions = ["-".join(combination).lower() for combination in combination_params] 40 | return conditions 41 | 42 | def _normalize_params(self, params: dict[str, str | int | float]) -> dict[str, str]: 43 | """ 44 | Normalizes the configuration parameters. 45 | 46 | Args: 47 | params (dict[str, str | int | float]): The configuration parameters. 48 | 49 | Returns: 50 | dict[str, str]: The normalized configuration parameters. 51 | """ 52 | params = {param: (value if isinstance(value, list) else [value]) for param, value in params.items()} 53 | params = {param: [str(item) for item in value] for param, value in params.items()} 54 | return params 55 | 56 | def _validate_params(self, params: dict[str, str]) -> dict[str, str]: 57 | """ 58 | Validates the configuration parameters. 59 | 60 | Args: 61 | params (dict[str, str]): The configuration parameters. 62 | 63 | Returns: 64 | dict[str, str]: The validated configuration parameters. 65 | 66 | Raises: 67 | ValueError: If any parameter value is not a value or subset of the default values. 68 | """ 69 | params = self._normalize_params(params) 70 | for (values, default_values) in zip(params.values(), self.param_defaults.values()): 71 | if not set(values) <= set(default_values): 72 | raise ValueError(f"{values} should be a value or subset of {default_values}") 73 | return params 74 | 75 | @property 76 | def defaults(self) -> dict[str, Any]: 77 | """ 78 | Get the default values of the configuration parameters. 79 | 80 | Returns: 81 | dict[str, Any]: A dictionary of parameter defaults. 82 | """ 83 | signature = inspect.signature(self.__init__) 84 | return {param: value.default for param, value in signature.parameters.items() 85 | if value.default is not inspect.Parameter.empty} 86 | 87 | @property 88 | @abstractmethod 89 | def num_conditions(self) -> int: 90 | """ 91 | Get the number of conditions for the configuration. 92 | 93 | Returns: 94 | int: The number of conditions. 95 | """ 96 | pass 97 | 98 | 99 | class IdentityAndView(NETTConfig): 100 | """ 101 | NETT configuration for Identity and View. 
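Example (illustrative; the condition strings come from NETTConfig._create_conditions_from_params, i.e. the Cartesian product of the parameter values, joined with "-" and lowercased):

>>> IdentityAndView(object="object1", rotation=["horizontal", "vertical"]).conditions
['object1-horizontal', 'object1-vertical']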
102 | 103 | Args: 104 | object (str | list[str]): The object(s) to be used. Defaults to ["object1", "object2"]. 105 | rotation (str | list[str]): The rotation(s) to be used. Defaults to ["horizontal", "vertical"]. 106 | 107 | Raises: 108 | ValueError: If any parameter value is not a value or subset of the default values. 109 | """ 110 | 111 | def __init__(self, 112 | object: str | list[str] = ["object1", "object2"], 113 | rotation: str | list[str] = ["horizontal", "vertical"]) -> None: 114 | """Constructor method 115 | """ 116 | super().__init__(param_defaults=self.defaults, 117 | object=object, 118 | rotation=rotation) 119 | 120 | @property 121 | def num_conditions(self) -> int: 122 | """ 123 | Get the number of conditions for the configuration. 124 | 125 | Returns: 126 | int: The number of conditions. 127 | """ 128 | return 18 129 | 130 | 131 | class Binding(NETTConfig): 132 | """ 133 | NETT configuration for Binding. 134 | 135 | Args: 136 | object (str | list[str]): The object(s) to be used. Defaults to ["object1", "object2"]. 137 | 138 | Raises: 139 | ValueError: If any parameter value is not a value or subset of the default values. 140 | """ 141 | 142 | def __init__(self, 143 | object: str | list[str] = ["object1", "object2"]) -> None: 144 | """Constructor method 145 | """ 146 | super().__init__(param_defaults=self.defaults, 147 | object=object) 148 | 149 | @property 150 | def num_conditions(self) -> int: 151 | """ 152 | Get the number of conditions for the configuration. 153 | 154 | Returns: 155 | int: The number of conditions. 156 | """ 157 | return 50 158 | 159 | 160 | class Parsing(NETTConfig): 161 | """ 162 | NETT configuration for Parsing. 163 | 164 | Args: 165 | background (str | list[str], optional): The background(s) to be used. Defaults to ["A", "B", "C"]. 166 | object (str | list[str], optional): The object(s) to be used. Defaults to ["ship", "fork"]. 167 | """ 168 | 169 | def __init__(self, 170 | background: str | list[str] = ["A", "B", "C"], 171 | object: str | list[str] = ["ship", "fork"]) -> None: 172 | """Constructor method 173 | """ 174 | super().__init__(param_defaults=self.defaults, 175 | background=background, 176 | object=object) 177 | 178 | @property 179 | def num_conditions(self) -> int: 180 | """ 181 | Get the number of conditions for the configuration. 182 | 183 | Returns: 184 | int: The number of conditions. 185 | """ 186 | return 56 187 | 188 | 189 | class Slowness(NETTConfig): 190 | """ 191 | NETT configuration for Slowness. 192 | 193 | Args: 194 | experiment (str | list[int], optional): The experiment(s) to be used. Defaults to [1, 2]. 195 | object (str | list[str], optional): The object(s) to be used. Defaults to ["obj1", "obj2"]. 196 | speed (str | list[str], optional): The speed(s) to be used. Defaults to ["slow", "med", "fast"]. 197 | 198 | Raises: 199 | ValueError: If any parameter value is not a value or subset of the default values. 200 | """ 201 | 202 | def __init__(self, 203 | experiment: str | list[int] = [1, 2], 204 | object: str | list[str] = ["obj1", "obj2"], 205 | speed: str | list[str] = ["slow", "med", "fast"]) -> None: 206 | """Constructor method 207 | """ 208 | super().__init__(param_defaults=self.defaults, 209 | experiment=experiment, 210 | object=object, 211 | speed=speed) 212 | 213 | @property 214 | def num_conditions(self) -> int: 215 | """ 216 | Get the number of conditions for the configuration. 217 | 218 | Returns: 219 | int: The number of conditions. 
220 | """ 221 | if self.params["experiment"] == "1": 222 | return 5 223 | return 13 224 | 225 | 226 | class Smoothness(NETTConfig): 227 | """ 228 | NETT configuration for Smoothness. 229 | 230 | Args: 231 | object (str or list[str], optional): The object(s) to be used. Defaults to ["obj1"]. 232 | temporal (str or list[str], optional): The temporal condition(s) to be used. Defaults to ["norm", "scram"]. 233 | 234 | Attributes: 235 | num_conditions (int): The number of conditions for the configuration. 236 | """ 237 | 238 | def __init__(self, 239 | object: str | list[str] = ["obj1"], 240 | temporal: str | list[str] = ["norm", "scram"]) -> None: 241 | """Constructor method 242 | """ 243 | super().__init__(param_defaults=self.defaults, 244 | object=object, 245 | temporal=temporal) 246 | 247 | @property 248 | def num_conditions(self) -> int: 249 | """ 250 | Get the number of conditions for the configuration. 251 | 252 | Returns: 253 | int: The number of conditions. 254 | """ 255 | return 5 256 | 257 | 258 | class OneShotViewInvariant(NETTConfig): 259 | """ 260 | NETT configuration for One-Shot View Invariant. 261 | 262 | Args: 263 | object (str | list[str]): The object(s) to be used. Defaults to ["fork", "ship"]. 264 | range (str | list[str]): The range(s) to be used. Defaults to ["360", "small", "1"]. 265 | view (str | list[str]): The view(s) to be used. Defaults to ["front", "side"]. 266 | 267 | Raises: 268 | ValueError: If any parameter value is not a value or subset of the default values. 269 | """ 270 | 271 | def __init__(self, 272 | object: str | list[str] = ["fork", "ship"], 273 | range: str | list[str] = ["360", "small", "1"], 274 | view: str | list[str] = ["front", "side"]) -> None: 275 | """Constructor method 276 | """ 277 | super().__init__(param_defaults=self.defaults, 278 | object=object, 279 | range=range, 280 | view=view) 281 | 282 | @property 283 | def num_conditions(self) -> int: 284 | """ 285 | Get the number of conditions for the configuration. 286 | 287 | Returns: 288 | int: The number of conditions. 289 | """ 290 | return 50 291 | 292 | 293 | class ViewInvariant(NETTConfig): 294 | """ 295 | NETT configuration for Binding. 296 | 297 | Args: 298 | object (str | list[str]): The object(s) to be used. Defaults to ["object1", "object2"]. 299 | 300 | Raises: 301 | ValueError: If any parameter value is not a value or subset of the default values. 302 | """ 303 | 304 | def __init__(self, 305 | object: str | list[str] = ["ship", "fork"], 306 | view: str | list[str] = ["front", "side"]) -> None: 307 | """Constructor method 308 | """ 309 | super().__init__(param_defaults=self.defaults, 310 | object=object, view = view) 311 | 312 | @property 313 | def num_conditions(self) -> int: 314 | """ 315 | Get the number of conditions for the configuration. 316 | 317 | Returns: 318 | int: The number of conditions. 319 | """ 320 | if self.view.lower()=="front": 321 | return 50 322 | return 26 323 | 324 | 325 | def list_configs() -> list[str]: 326 | """ 327 | Lists all available NETT configurations. 328 | 329 | Returns: 330 | list[str]: A list of configuration names. 331 | """ 332 | #TODO: Are these really strings? 
333 | is_class_member = lambda member: inspect.isclass(member) and member.__module__ == __name__ 334 | clsmembers = inspect.getmembers(sys.modules[__name__], is_class_member) 335 | clsmembers = [clsmember[0] for clsmember in clsmembers if clsmember[0] != "NETTConfig"] 336 | return clsmembers 337 | -------------------------------------------------------------------------------- /src/nett/brain/encoders/disembodied_models/archs/resnet_1b.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | 4 | from pl_bolts.utils import _TORCHVISION_AVAILABLE 5 | from pl_bolts.utils.warnings import warn_missing_pkg 6 | 7 | __all__ = [ 8 | 'ResNet', 9 | 'resnet18', 10 | 'resnet34', 11 | 'resnet50', 12 | 'resnet101', 13 | 'resnet152', 14 | 'resnext50_32x4d', 15 | 'resnext101_32x8d', 16 | 'wide_resnet50_2', 17 | 'wide_resnet101_2', 18 | ] 19 | 20 | MODEL_URLS = { 21 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 22 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 23 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 24 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 25 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 26 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 27 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 28 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 29 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 30 | } 31 | 32 | 33 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1) -> nn.Conv2d: 34 | """ 35 | 3x3 convolution with padding 36 | 37 | Args: 38 | in_planes (int): number of input channels 39 | out_planes (int): number of output channels 40 | stride (int): stride for the convolution 41 | groups (int): number of groups for the convolution 42 | dilation (int): dilation rate for the convolution 43 | 44 | Returns: 45 | nn.Conv2d: 3x3 convolution layer 46 | """ 47 | return nn.Conv2d( 48 | in_planes, 49 | out_planes, 50 | kernel_size=3, 51 | stride=stride, 52 | padding=dilation, 53 | groups=groups, 54 | bias=False, 55 | dilation=dilation 56 | ) 57 | 58 | 59 | def conv1x1(in_planes, out_planes, stride=1) -> nn.Conv2d: 60 | """ 61 | 1x1 convolution 62 | 63 | Args: 64 | in_planes (int): number of input channels 65 | out_planes (int): number of output channels 66 | stride (int): stride for the convolution 67 | 68 | Returns: 69 | nn.Conv2d: 1x1 convolution layer 70 | """ 71 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 72 | 73 | 74 | class BasicBlock(nn.Module): 75 | """ 76 | Basic block for ResNet 77 | 78 | Args: 79 | inplanes (int): number of input channels 80 | planes (int): number of output channels 81 | stride (int): stride for the first convolution 82 | downsample (nn.Module): downsample layer 83 | groups (int): number of groups for the 3x3 convolution 84 | base_width (int): number of channels per group 85 | dilation (int): dilation rate for the 3x3 convolution 86 | norm_layer (nn.Module): normalization layer 87 | """ 88 | expansion = 1 89 | 90 | def __init__( 91 | self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1, norm_layer=None 92 | ): 93 | super(BasicBlock, self).__init__() 94 | if norm_layer is None: 95 | norm_layer = 
nn.BatchNorm2d 96 | if groups != 1 or base_width != 64: 97 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 98 | if dilation > 1: 99 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 100 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 101 | self.conv1 = conv3x3(inplanes, planes, stride) 102 | self.bn1 = norm_layer(planes) 103 | self.relu = nn.ReLU(inplace=True) 104 | self.conv2 = conv3x3(planes, planes) 105 | self.bn2 = norm_layer(planes) 106 | self.downsample = downsample 107 | self.stride = stride 108 | 109 | def forward(self, x: torch.Tensor) -> torch.Tensor: 110 | """ 111 | Forward pass in the network 112 | 113 | Args: 114 | x (torch.Tensor): input tensor 115 | 116 | Returns: 117 | torch.Tensor: output tensor 118 | """ 119 | identity = x 120 | 121 | out = self.conv1(x) 122 | out = self.bn1(out) 123 | out = self.relu(out) 124 | 125 | out = self.conv2(out) 126 | out = self.bn2(out) 127 | 128 | if self.downsample is not None: 129 | identity = self.downsample(x) 130 | 131 | out += identity 132 | out = self.relu(out) 133 | 134 | return out 135 | 136 | class ResNet(nn.Module): 137 | """ 138 | ResNet model 139 | 140 | Args: 141 | block (nn.Module): the block type to be used in the network 142 | layers (list of int): the number of layers for each block 143 | num_classes (int): number of classes 144 | zero_init_residual (bool): whether the residual block should be initialized to zero 145 | groups (int): number of groups for the 3x3 convolution 146 | width_per_group (int): number of channels per group 147 | replace_stride_with_dilation (tuple): replace stride with dilation 148 | norm_layer (nn.Module): normalization layer 149 | return_all_feature_maps (bool): whether to return all feature maps 150 | first_conv (bool): whether to use the first convolution 151 | maxpool1 (bool): whether to use the first maxpool 152 | """ 153 | 154 | def __init__( 155 | self, 156 | block, # what kind of block, for ex - basic block or bottleneck 157 | layers, # how many layers or basic blocks in each residual block 158 | num_classes=1000, 159 | zero_init_residual=False, 160 | groups=1, 161 | width_per_group=64, 162 | replace_stride_with_dilation=None, 163 | norm_layer=None, 164 | return_all_feature_maps=False, 165 | first_conv=False, 166 | maxpool1=False 167 | ): 168 | super(ResNet, self).__init__() 169 | if norm_layer is None: 170 | norm_layer = nn.BatchNorm2d 171 | self._norm_layer = norm_layer 172 | self.return_all_feature_maps = return_all_feature_maps 173 | 174 | 175 | 176 | self.inplanes = 3 # what is inplanes and planes in CNN ???? 
it should be 64 as per original implementation, does not work with 64 in case of block 1 177 | self.dilation = 1 178 | if replace_stride_with_dilation is None: 179 | # each element in the tuple indicates if we should replace 180 | # the 2x2 stride with a dilated convolution instead 181 | replace_stride_with_dilation = [False, False, False] 182 | if len(replace_stride_with_dilation) != 3: 183 | raise ValueError( 184 | "replace_stride_with_dilation should be None " 185 | f"or a 3-element tuple, got {replace_stride_with_dilation}" 186 | ) 187 | self.groups = groups 188 | self.base_width = width_per_group 189 | 190 | # ------ layers before first residual block --------------- 191 | 192 | if first_conv: 193 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) 194 | else: 195 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) 196 | 197 | self.bn1 = norm_layer(self.inplanes) 198 | self.relu = nn.ReLU(inplace=True) 199 | 200 | if maxpool1: 201 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 202 | else: 203 | self.maxpool = nn.MaxPool2d(kernel_size=1, stride=1) 204 | 205 | # ------ residual blocks start here ------------------------ 206 | 207 | self.layer1 = self._make_layer(block, 512, layers[0]) 208 | 209 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 210 | 211 | self.fc = nn.Linear(512 * block.expansion, num_classes) 212 | 213 | 214 | for m in self.modules(): 215 | if isinstance(m, nn.Conv2d): 216 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 217 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 218 | nn.init.constant_(m.weight, 1) 219 | nn.init.constant_(m.bias, 0) 220 | 221 | # Zero-initialize the last BN in each residual branch, 222 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 223 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 224 | if zero_init_residual: 225 | for m in self.modules(): 226 | if isinstance(m, Bottleneck): 227 | nn.init.constant_(m.bn3.weight, 0) 228 | elif isinstance(m, BasicBlock): 229 | nn.init.constant_(m.bn2.weight, 0) 230 | 231 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False) -> nn.Sequential: 232 | """ 233 | Make a layer of blocks. 234 | 235 | Args: 236 | block (nn.Module): the block type to be used in the layer 237 | planes (int): number of output channels for the layer 238 | blocks (int): number of blocks to be used 239 | stride (int): stride for the first block. Default: 1 240 | dilate (bool): whether to apply dilation strategy to the layer. 
Default: False 241 | 242 | Returns: 243 | nn.Sequential: a layer of blocks 244 | """ 245 | norm_layer = self._norm_layer 246 | downsample = None 247 | previous_dilation = self.dilation 248 | if dilate: 249 | self.dilation *= stride 250 | stride = 1 251 | if stride != 1 or self.inplanes != planes * block.expansion: 252 | downsample = nn.Sequential( 253 | conv1x1(self.inplanes, planes * block.expansion, stride), 254 | norm_layer(planes * block.expansion), 255 | ) 256 | 257 | layers = [] 258 | layers.append( 259 | block( 260 | self.inplanes, 261 | planes, 262 | stride, 263 | downsample, 264 | self.groups, 265 | self.base_width, 266 | previous_dilation, 267 | norm_layer, 268 | ) 269 | ) 270 | self.inplanes = planes * block.expansion 271 | for _ in range(1, blocks): 272 | layers.append( 273 | block( 274 | self.inplanes, 275 | planes, 276 | groups=self.groups, 277 | base_width=self.base_width, 278 | dilation=self.dilation, 279 | norm_layer=norm_layer 280 | ) 281 | ) 282 | 283 | return nn.Sequential(*layers) 284 | 285 | 286 | def forward(self, x: torch.Tensor) -> torch.Tensor: 287 | """ 288 | Forward pass in the network 289 | 290 | Args: 291 | x (torch.Tensor): input tensor 292 | 293 | Returns: 294 | torch.Tensor: output tensor 295 | """ 296 | x0 = self.conv1(x) 297 | x0 = self.bn1(x0) 298 | x0 = self.relu(x0) 299 | x0 = self.maxpool(x0) 300 | #print(x0.shape) 301 | if self.return_all_feature_maps: 302 | x1 = self.layer1(x) 303 | return [x0, x1] 304 | else: 305 | x0 = self.layer1(x) 306 | 307 | x0 = self.avgpool(x0) # output shape = [256X1X1] 308 | 309 | x0 = x0.reshape(x0.shape[0],-1) 310 | return x0 311 | 312 | 313 | def _resnet(arch, block, layers, pretrained, progress, **kwargs) -> ResNet: 314 | """ 315 | Constructs a ResNet model. 316 | 317 | Args: 318 | arch (str): model architecture 319 | block (nn.Module): the block type to be used in the network 320 | layers (list of int): the number of layers for each block 321 | pretrained (bool): if True, returns a model pre-trained on ImageNet 322 | progress (bool): if True, displays a progress bar of the download to stderr 323 | 324 | Returns: 325 | ResNet: model 326 | """ 327 | model = ResNet(block, layers, **kwargs) 328 | if pretrained: 329 | state_dict = load_state_dict_from_url(MODEL_URLS[arch], progress=progress) 330 | model.load_state_dict(state_dict) 331 | # Remove the last fc layer, since we only need the encoder part of resnet. 332 | #model.fc = nn.Identity() 333 | 334 | # we cannot remove the last fc layer in smaller version of the resnet because we are using the fc layer 335 | # to flatten the output and making it equal dimensions for testing and training the classifier. 
336 | return model 337 | 338 | 339 | def resnet_1block(pretrained: bool = False, progress: bool = True, **kwargs) -> ResNet: 340 | """ResNet-18 model from 341 | `"Deep Residual Learning for Image Recognition" ` 342 | 343 | Args: 344 | pretrained: If True, returns a model pre-trained on ImageNet 345 | progress: If True, displays a progress bar of the download to stderr 346 | 347 | Returns: 348 | ResNet: ResNet-18 model 349 | """ 350 | 351 | return _resnet('resnet18', BasicBlock, [1, 1, 1, 1], pretrained, progress, **kwargs) 352 | 353 | -------------------------------------------------------------------------------- /src/nett/brain/encoders/resnet10.py: -------------------------------------------------------------------------------- 1 | """ 2 | Resnet10CNN feature extractor for stable-baselines3 3 | """ 4 | import pdb 5 | import gym 6 | 7 | 8 | import torch as th 9 | import torch.nn as nn 10 | import torchvision 11 | 12 | from stable_baselines3.common.torch_layers import BaseFeaturesExtractor 13 | 14 | import logging 15 | logger = logging.getLogger(__name__) 16 | 17 | class Resnet10CNN(BaseFeaturesExtractor): 18 | """ 19 | Resnet10CNN feature extractor for stable-baselines3 20 | 21 | Args: 22 | observation_space (gym.spaces.Box): Observation space 23 | features_dim (int, optional): Output dimension of features extractor. Defaults to 256. 24 | """ 25 | 26 | def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 256): 27 | super(Resnet10CNN, self).__init__(observation_space, features_dim) 28 | # We assume CxHxW images (channels first) 29 | # Re-ordering will be done by pre-preprocessing or wrapper 30 | 31 | n_input_channels = observation_space.shape[0] 32 | 33 | self.cnn = _resnet(BasicBlock, [2, 2, 2, 2],num_channels = n_input_channels) 34 | logger.info(f"Resnet10CNN Encoder: {self.cnn}") 35 | 36 | with th.no_grad(): 37 | n_flatten = self.cnn( 38 | th.as_tensor(observation_space.sample()[None]).float() 39 | ).shape[1] 40 | 41 | self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU()) 42 | 43 | def forward(self, observations: th.Tensor) -> th.Tensor: 44 | """ 45 | Forward pass of the feature extractor. 46 | 47 | Args: 48 | observations (torch.Tensor): The input observations. 49 | 50 | Returns: 51 | torch.Tensor: The extracted features. 52 | """ 53 | # Cut off image 54 | # reshape to from vector to W*H 55 | # gray to color transform 56 | # application of ResNet 57 | # Concat features to the rest of observation vector 58 | # return 59 | return self.linear(self.cnn(observations)) 60 | 61 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1) -> nn.Conv2d: 62 | """ 63 | 3x3 convolution with padding 64 | 65 | Args: 66 | in_planes (int): Number of input channels 67 | out_planes (int): Number of output channels 68 | stride (int, optional): Stride of the convolution. Defaults to 1. 69 | groups (int, optional): Number of groups for the convolution. Defaults to 1. 70 | dilation (int, optional): Dilation rate of the convolution. Defaults to 1. 71 | 72 | Returns: 73 | nn.Conv2d: Convolutional layer 74 | """ 75 | return nn.Conv2d( 76 | in_planes, 77 | out_planes, 78 | kernel_size=3, 79 | stride=stride, 80 | padding=dilation, 81 | groups=groups, 82 | bias=False, 83 | dilation=dilation 84 | ) 85 | 86 | 87 | def conv1x1(in_planes, out_planes, stride=1): 88 | """ 89 | 1x1 convolution 90 | 91 | Args: 92 | in_planes (int): Number of input channels 93 | out_planes (int): Number of output channels 94 | stride (int, optional): Stride of the convolution. 
Defaults to 1. 95 | 96 | Returns: 97 | nn.Conv2d: Convolutional layer 98 | """ 99 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 100 | 101 | 102 | class BasicBlock(nn.Module): 103 | """ 104 | Basic block used in the ResNet-18 architecture. 105 | 106 | Args: 107 | inplanes (int): Number of input channels 108 | planes (int): Number of output channels 109 | stride (int, optional): Stride of the convolution. Defaults to 1. 110 | downsample (nn.Module, optional): Downsample layer. Defaults to None. 111 | groups (int, optional): Number of groups for the convolution. Defaults to 1. 112 | base_width (int, optional): Base width for the convolution. Defaults to 64. 113 | dilation (int, optional): Dilation rate of the convolution. Defaults to 1. 114 | norm_layer ([type], optional): Normalization layer. Defaults to None. 115 | """ 116 | expansion = 1 117 | 118 | def __init__( 119 | self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1, norm_layer=None 120 | ): 121 | super(BasicBlock, self).__init__() 122 | if norm_layer is None: 123 | norm_layer = nn.BatchNorm2d 124 | if groups != 1 or base_width != 64: 125 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 126 | if dilation > 1: 127 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 128 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 129 | 130 | # layers inside each basic block of a residual block 131 | self.conv1 = conv3x3(inplanes, planes, stride) 132 | self.bn1 = norm_layer(planes) 133 | self.relu = nn.ReLU(inplace=True) 134 | self.conv2 = conv3x3(planes, planes) 135 | self.bn2 = norm_layer(planes) 136 | # last two are operations and not layers 137 | self.downsample = downsample 138 | self.stride = stride 139 | 140 | def forward(self, x: th.Tensor) -> th.Tensor: 141 | """ 142 | Forward pass in the network 143 | 144 | Args: 145 | x (torch.Tensor): input tensor 146 | 147 | Returns: 148 | torch.Tensor: output tensor 149 | """ 150 | # saving x to pass over the bridge connection 151 | identity = x 152 | 153 | out = self.conv1(x) 154 | out = self.bn1(out) 155 | out = self.relu(out) 156 | 157 | out = self.conv2(out) 158 | out = self.bn2(out) 159 | 160 | if self.downsample is not None: 161 | identity = self.downsample(x) 162 | 163 | out += identity 164 | out = self.relu(out) 165 | 166 | return out 167 | 168 | class ResNet(nn.Module): 169 | """ 170 | ResNet architecture used in the Resnet10CNN class. 171 | 172 | Args: 173 | block (nn.Module): Residual block to use 174 | layers (list): Number of layers in each block 175 | num_channels (int): Number of input channels 176 | num_classes (int, optional): Number of classes. Defaults to 1000. 177 | zero_init_residual (bool, optional): Zero initialization for the residual block. Defaults to False. 178 | groups (int, optional): Number of groups for the convolution. Defaults to 1. 179 | width_per_group (int, optional): Base width for the convolution. Defaults to 64. 180 | replace_stride_with_dilation (tuple, optional): Replace stride with dilation. Defaults to None. 181 | norm_layer ([type], optional): Normalization layer. Defaults to None. 182 | return_all_feature_maps (bool, optional): Return all feature maps. Defaults to False. 183 | first_conv (bool, optional): Pre-processing layers which makes the image size half [64->32]. Defaults to True. 184 | maxpool1 (bool, optional): Used in pre-processing. Defaults to True. 
185 | """ 186 | 187 | def __init__( 188 | self, 189 | block, 190 | layers, 191 | num_channels, 192 | num_classes=1000, 193 | zero_init_residual=False, 194 | groups=1, 195 | width_per_group=64, 196 | replace_stride_with_dilation=None, 197 | norm_layer=None, 198 | return_all_feature_maps=False, 199 | first_conv=True, # pre-processing layers which makes the image size half [64->32] 200 | maxpool1=True # used in pre-processing 201 | ): 202 | super(ResNet, self).__init__() 203 | if norm_layer is None: 204 | norm_layer = nn.BatchNorm2d 205 | self._norm_layer = norm_layer 206 | self.return_all_feature_maps = return_all_feature_maps 207 | 208 | self.inplanes = 64 209 | self.dilation = 1 210 | if replace_stride_with_dilation is None: 211 | # each element in the tuple indicates if we should replace 212 | # the 2x2 stride with a dilated convolution instead 213 | replace_stride_with_dilation = [False, False, False] 214 | if len(replace_stride_with_dilation) != 3: 215 | raise ValueError( 216 | "replace_stride_with_dilation should be None " 217 | f"or a 3-element tuple, got {replace_stride_with_dilation}" 218 | ) 219 | self.groups = groups 220 | self.base_width = width_per_group 221 | 222 | # ------ layers before first residual block --------------- 223 | if first_conv: 224 | self.conv1 = nn.Conv2d(num_channels, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) 225 | 226 | else: 227 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) 228 | 229 | 230 | 231 | self.bn1 = norm_layer(self.inplanes) 232 | self.relu = nn.ReLU(inplace=True) 233 | 234 | if maxpool1: 235 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 236 | else: 237 | self.maxpool = nn.MaxPool2d(kernel_size=1, stride=1) 238 | 239 | # ------ residual blocks start here ------------------------ 240 | 241 | # BLOCK - 1 242 | self.layer1 = self._make_layer(block, 64, layers[0]) 243 | 244 | # BLOCK - 2 245 | self.layer2 = self._make_layer(block, 512, layers[1], stride=2, dilate=replace_stride_with_dilation[0]) 246 | 247 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 248 | self.fc = nn.Linear(512 * block.expansion, num_classes) 249 | 250 | for m in self.modules(): 251 | if isinstance(m, nn.Conv2d): 252 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 253 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 254 | nn.init.constant_(m.weight, 1) 255 | nn.init.constant_(m.bias, 0) 256 | 257 | # Zero-initialize the last BN in each residual branch, 258 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 259 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 260 | if zero_init_residual: 261 | for m in self.modules(): 262 | if isinstance(m, Bottleneck): 263 | nn.init.constant_(m.bn3.weight, 0) 264 | elif isinstance(m, BasicBlock): 265 | nn.init.constant_(m.bn2.weight, 0) 266 | 267 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False) -> nn.Sequential: 268 | """ 269 | Helper function to create a residual layer. 270 | 271 | Args: 272 | block (nn.Module): Residual block to use 273 | planes (int): Number of output channels 274 | blocks (int): Number of blocks 275 | stride (int, optional): Stride of the convolution. Defaults to 1. 276 | dilate (bool, optional): Use dilation. Defaults to False. 
277 | 278 | Returns: 279 | nn.Sequential: Residual layer 280 | """ 281 | norm_layer = self._norm_layer 282 | downsample = None 283 | previous_dilation = self.dilation 284 | if dilate: 285 | self.dilation *= stride 286 | stride = 1 287 | if stride != 1 or self.inplanes != planes * block.expansion: 288 | downsample = nn.Sequential( 289 | conv1x1(self.inplanes, planes * block.expansion, stride), 290 | norm_layer(planes * block.expansion), 291 | ) 292 | 293 | layers = [] 294 | layers.append( 295 | block( 296 | self.inplanes, 297 | planes, 298 | stride, 299 | downsample, 300 | self.groups, 301 | self.base_width, 302 | previous_dilation, 303 | norm_layer, 304 | ) 305 | ) 306 | self.inplanes = planes * block.expansion 307 | for _ in range(1, blocks): 308 | layers.append( 309 | block( 310 | self.inplanes, 311 | planes, 312 | groups=self.groups, 313 | base_width=self.base_width, 314 | dilation=self.dilation, 315 | norm_layer=norm_layer 316 | ) 317 | ) 318 | 319 | return nn.Sequential(*layers) 320 | 321 | def forward(self, x: th.Tensor) -> th.Tensor: 322 | """ 323 | Forward pass in the network 324 | 325 | Args: 326 | x (torch.Tensor): input tensor 327 | 328 | Returns: 329 | torch.Tensor: output tensor 330 | """ 331 | 332 | # passing input from pre-processing layers 333 | x0 = self.conv1(x) 334 | x0 = self.bn1(x0) 335 | x0 = self.relu(x0) 336 | x0 = self.maxpool(x0) 337 | 338 | # passing input from residual blocks 339 | if self.return_all_feature_maps: 340 | x1 = self.layer1(x0) # block1 341 | x2 = self.layer2(x1) # block2 342 | 343 | return [x0, x1, x2] 344 | else: 345 | x0 = self.layer1(x0) 346 | x0 = self.layer2(x0) 347 | 348 | x0 = self.avgpool(x0) 349 | x0 = th.flatten(x0, 1) 350 | 351 | return x0 352 | 353 | 354 | def _resnet(block, layers, **kwargs): 355 | """ 356 | ResNet architecture used in the Resnet10CNN class. 
357 | 358 | Args: 359 | block (nn.Module): Residual block to use 360 | layers (list): Number of layers in each block 361 | 362 | Returns: 363 | ResNet: ResNet model 364 | """ 365 | model = ResNet(block, layers, **kwargs) 366 | model.fc = nn.Identity() 367 | return model 368 | -------------------------------------------------------------------------------- /src/nett/brain/encoders/disembodied_models/archs/resnet_2b.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This file contains the implementation of ResNet-18 with 2 blocks 3 | Output from the second block now gives 512 channels instead of 128 4 | ''' 5 | 6 | 7 | import torch 8 | from torch import nn as nn 9 | 10 | from pl_bolts.utils import _TORCHVISION_AVAILABLE 11 | from pl_bolts.utils.warnings import warn_missing_pkg 12 | 13 | # if _TORCHVISION_AVAILABLE: 14 | # #from torchvision.models.utils import load_state_dict_from_url 15 | # from torch.hub import load_state_dict_from_url 16 | # else: # pragma: no cover 17 | # warn_missing_pkg('torchvision') 18 | 19 | __all__ = [ 20 | 'ResNet', 21 | 'resnet18', 22 | 'resnet34', 23 | 'resnet50', 24 | 'resnet101', 25 | 'resnet152', 26 | 'resnext50_32x4d', 27 | 'resnext101_32x8d', 28 | 'wide_resnet50_2', 29 | 'wide_resnet101_2', 30 | ] 31 | 32 | MODEL_URLS = { 33 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 34 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 35 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 36 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 37 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 38 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 39 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 40 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 41 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 42 | } 43 | 44 | 45 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1) -> nn.Conv2d: 46 | """ 47 | 3x3 convolution with padding 48 | 49 | Args: 50 | in_planes (int): Number of input channels 51 | out_planes (int): Number of output channels 52 | stride (int): Stride 53 | groups (int): Number of groups 54 | dilation (int): Dilation 55 | 56 | Returns: 57 | nn.Conv2d: Convolution layer 58 | """ 59 | return nn.Conv2d( 60 | in_planes, 61 | out_planes, 62 | kernel_size=3, 63 | stride=stride, 64 | padding=dilation, 65 | groups=groups, 66 | bias=False, 67 | dilation=dilation 68 | ) 69 | 70 | 71 | def conv1x1(in_planes, out_planes, stride=1) -> nn.Conv2d: 72 | """ 73 | 1x1 convolution 74 | 75 | Args: 76 | in_planes (int): Number of input channels 77 | out_planes (int): Number of output channels 78 | stride (int): Stride 79 | 80 | Returns: 81 | nn.Conv2d: Convolution layer""" 82 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 83 | 84 | 85 | class BasicBlock(nn.Module): 86 | """ 87 | BasicBlock for ResNet 88 | 89 | Args: 90 | inplanes (int): Number of input channels 91 | planes (int): Number of output channels 92 | stride (int): Stride 93 | downsample (nn.Module): Downsample layer 94 | groups (int): Number of groups 95 | base_width (int): Base width 96 | dilation (int): Dilation 97 | norm_layer (nn.Module): Normalization layer 98 | """ 99 | expansion = 1 100 | 101 | def __init__( 102 | self, inplanes, 
planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1, norm_layer=None 103 | ): 104 | super(BasicBlock, self).__init__() 105 | if norm_layer is None: 106 | norm_layer = nn.BatchNorm2d 107 | if groups != 1 or base_width != 64: 108 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 109 | if dilation > 1: 110 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 111 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 112 | 113 | # layers inside each basic block of a residual block 114 | self.conv1 = conv3x3(inplanes, planes, stride) 115 | self.bn1 = norm_layer(planes) 116 | self.relu = nn.ReLU(inplace=True) 117 | self.conv2 = conv3x3(planes, planes) 118 | self.bn2 = norm_layer(planes) 119 | # last two are operations and not layers 120 | self.downsample = downsample 121 | self.stride = stride 122 | 123 | def forward(self, x: torch.Tensor) -> torch.Tensor: 124 | """ 125 | Forward pass in the network 126 | 127 | Args: 128 | x (torch.Tensor): input tensor 129 | 130 | Returns: 131 | torch.Tensor: output tensor 132 | """ 133 | # saving x to pass over the bridge connection 134 | identity = x 135 | 136 | out = self.conv1(x) 137 | out = self.bn1(out) 138 | out = self.relu(out) 139 | 140 | out = self.conv2(out) 141 | out = self.bn2(out) 142 | 143 | if self.downsample is not None: 144 | identity = self.downsample(x) 145 | 146 | out += identity 147 | out = self.relu(out) 148 | 149 | return out 150 | 151 | 152 | class ResNet(nn.Module): 153 | """ 154 | ResNet model 155 | 156 | Args: 157 | block (nn.Module): ResNet block 158 | layers (list): List of layers 159 | num_classes (int): Number of classes 160 | zero_init_residual (bool): If True, zero-initialize the last BN in each residual branch 161 | groups (int): Number of groups 162 | width_per_group (int): Width per group 163 | replace_stride_with_dilation (tuple): Replace stride with dilation 164 | norm_layer (nn.Module): Normalization layer 165 | return_all_feature_maps (bool): If True, returns all feature maps 166 | first_conv (bool): If True, uses first conv layer 167 | maxpool1 (bool): If True, uses maxpool1 layer 168 | """ 169 | 170 | def __init__( 171 | self, 172 | block, 173 | layers, 174 | num_classes=1000, 175 | zero_init_residual=False, 176 | groups=1, 177 | width_per_group=64, 178 | replace_stride_with_dilation=None, 179 | norm_layer=None, 180 | return_all_feature_maps=False, 181 | first_conv=True, # pre-processing layers which makes the image size half [64->32] 182 | maxpool1=True # used in pre-processing 183 | ): 184 | super(ResNet, self).__init__() 185 | if norm_layer is None: 186 | norm_layer = nn.BatchNorm2d 187 | self._norm_layer = norm_layer 188 | self.return_all_feature_maps = return_all_feature_maps 189 | 190 | self.inplanes = 64 191 | self.dilation = 1 192 | if replace_stride_with_dilation is None: 193 | # each element in the tuple indicates if we should replace 194 | # the 2x2 stride with a dilated convolution instead 195 | replace_stride_with_dilation = [False, False, False] 196 | if len(replace_stride_with_dilation) != 3: 197 | raise ValueError( 198 | "replace_stride_with_dilation should be None " 199 | f"or a 3-element tuple, got {replace_stride_with_dilation}" 200 | ) 201 | self.groups = groups 202 | self.base_width = width_per_group 203 | 204 | # ------ layers before first residual block --------------- 205 | 206 | if first_conv: 207 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) 208 | else: 209 | 
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) 210 | 211 | self.bn1 = norm_layer(self.inplanes) 212 | self.relu = nn.ReLU(inplace=True) 213 | 214 | if maxpool1: 215 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 216 | else: 217 | self.maxpool = nn.MaxPool2d(kernel_size=1, stride=1) 218 | 219 | # ------ residual blocks start here ------------------------ 220 | 221 | # BLOCK - 1 222 | self.layer1 = self._make_layer(block, 64, layers[0]) 223 | 224 | # BLOCK - 2 225 | self.layer2 = self._make_layer(block, 512, layers[1], stride=2, dilate=replace_stride_with_dilation[0]) 226 | 227 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 228 | self.fc = nn.Linear(512 * block.expansion, num_classes) 229 | 230 | for m in self.modules(): 231 | if isinstance(m, nn.Conv2d): 232 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 233 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 234 | nn.init.constant_(m.weight, 1) 235 | nn.init.constant_(m.bias, 0) 236 | 237 | # Zero-initialize the last BN in each residual branch, 238 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 239 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 240 | if zero_init_residual: 241 | for m in self.modules(): 242 | if isinstance(m, Bottleneck): 243 | nn.init.constant_(m.bn3.weight, 0) 244 | elif isinstance(m, BasicBlock): 245 | nn.init.constant_(m.bn2.weight, 0) 246 | 247 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False) -> nn.Sequential: 248 | """ 249 | Creates a layer of residual blocks 250 | 251 | Args: 252 | block (nn.Module): ResNet block 253 | planes (int): Number of planes 254 | blocks (int): Number of blocks 255 | stride (int): Stride 256 | dilate (bool): If True, dilates the stride 257 | 258 | Returns: 259 | nn.Sequential: Residual block layer 260 | """ 261 | norm_layer = self._norm_layer 262 | downsample = None 263 | previous_dilation = self.dilation 264 | if dilate: 265 | self.dilation *= stride 266 | stride = 1 267 | if stride != 1 or self.inplanes != planes * block.expansion: 268 | downsample = nn.Sequential( 269 | conv1x1(self.inplanes, planes * block.expansion, stride), 270 | norm_layer(planes * block.expansion), 271 | ) 272 | 273 | layers = [] 274 | layers.append( 275 | block( 276 | self.inplanes, 277 | planes, 278 | stride, 279 | downsample, 280 | self.groups, 281 | self.base_width, 282 | previous_dilation, 283 | norm_layer, 284 | ) 285 | ) 286 | self.inplanes = planes * block.expansion 287 | for _ in range(1, blocks): 288 | layers.append( 289 | block( 290 | self.inplanes, 291 | planes, 292 | groups=self.groups, 293 | base_width=self.base_width, 294 | dilation=self.dilation, 295 | norm_layer=norm_layer 296 | ) 297 | ) 298 | 299 | return nn.Sequential(*layers) 300 | 301 | def forward(self, x: torch.Tensor) -> torch.Tensor: 302 | """ 303 | Forward pass in the network 304 | 305 | Args: 306 | x (torch.Tensor): input tensor 307 | 308 | Returns: 309 | torch.Tensor: output tensor 310 | """ 311 | # passing input from pre-processing layers 312 | x0 = self.conv1(x) 313 | x0 = self.bn1(x0) 314 | x0 = self.relu(x0) 315 | x0 = self.maxpool(x0) 316 | 317 | # passing input from residual blocks 318 | if self.return_all_feature_maps: 319 | x1 = self.layer1(x0) # block1 320 | x2 = self.layer2(x1) # block2 321 | 322 | return [x0, x1, x2] 323 | else: 324 | x0 = self.layer1(x0) 325 | x0 = self.layer2(x0) 326 | 327 | x0 = self.avgpool(x0) 328 | x0 = 
torch.flatten(x0, 1) 329 | 330 | return x0 331 | 332 | 333 | def _resnet(arch, block, layers, pretrained, progress, **kwargs) -> ResNet: 334 | """ 335 | Constructs a ResNet model. 336 | 337 | Args: 338 | arch (str): Architecture name from the URLs 339 | block (nn.Module): ResNet block 340 | layers (list): List of layers 341 | pretrained (bool): If True, returns a model pre-trained on ImageNet 342 | progress (bool): If True, displays a progress bar of the download to stderr 343 | **kwargs: Other arguments for the ResNet model 344 | 345 | Returns: 346 | ResNet: ResNet model 347 | """ 348 | model = ResNet(block, layers, **kwargs) 349 | if pretrained: 350 | state_dict = load_state_dict_from_url(MODEL_URLS[arch], progress=progress) 351 | model.load_state_dict(state_dict) 352 | # Remove the last fc layer, since we only need the encoder part of resnet. 353 | model.fc = nn.Identity() 354 | return model 355 | 356 | 357 | 358 | def resnet_2blocks(pretrained: bool = False, progress: bool = True, **kwargs) -> nn.Module: 359 | """ 360 | Constructs a ResNet-18 model with 2 blocks. 361 | 362 | Args: 363 | pretrained (bool): If True, returns a model pre-trained on ImageNet 364 | progress (bool): If True, displays a progress bar of the download to stderr 365 | **kwargs: Other arguments for the ResNet model 366 | 367 | Returns: 368 | nn.Module: ResNet-18 model with 2 blocks 369 | """ 370 | 371 | 372 | 373 | """ 374 | first argument in _resnet() : architecture name from the URLs 375 | since URL for resnet9 is not available, therefore resnet18 is used with modifications 376 | """ 377 | 378 | # to print this architecture, print the model from the evaluator/evaluate file 379 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs) 380 | -------------------------------------------------------------------------------- /src/nett/brain/encoders/disembodied_models/archs/resnet_3b.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This file contains the implementation of ResNet-18 with 3 blocks 3 | Output from the third block now gives 512 channels instead of 256 4 | ''' 5 | 6 | 7 | import torch 8 | from torch import nn as nn 9 | 10 | from pl_bolts.utils import _TORCHVISION_AVAILABLE 11 | from pl_bolts.utils.warnings import warn_missing_pkg 12 | 13 | # if _TORCHVISION_AVAILABLE: 14 | # from torchvision.models.utils import load_state_dict_from_url 15 | # else: # pragma: no cover 16 | # warn_missing_pkg('torchvision') 17 | 18 | __all__ = [ 19 | 'ResNet', 20 | 'resnet18', 21 | 'resnet34', 22 | 'resnet50', 23 | 'resnet101', 24 | 'resnet152', 25 | 'resnext50_32x4d', 26 | 'resnext101_32x8d', 27 | 'wide_resnet50_2', 28 | 'wide_resnet101_2', 29 | ] 30 | 31 | MODEL_URLS = { 32 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 33 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 34 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 35 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 36 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 37 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 38 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 39 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 40 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 41 | } 42 | 43 | 44 | def 
conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1) -> nn.Conv2d: 45 | """ 46 | 3x3 convolution with padding 47 | 48 | Args: 49 | in_planes (int): number of input planes 50 | out_planes (int): number of output planes 51 | stride (int, optional): stride. Defaults to 1. 52 | groups (int, optional): number of groups. Defaults to 1. 53 | dilation (int, optional): dilation. Defaults to 1. 54 | 55 | Returns: 56 | nn.Conv2d: convolution layer 57 | """ 58 | return nn.Conv2d( 59 | in_planes, 60 | out_planes, 61 | kernel_size=3, 62 | stride=stride, 63 | padding=dilation, 64 | groups=groups, 65 | bias=False, 66 | dilation=dilation 67 | ) 68 | 69 | 70 | def conv1x1(in_planes, out_planes, stride=1) -> nn.Conv2d: 71 | """ 72 | 1x1 convolution 73 | 74 | Args: 75 | in_planes (int): number of input planes 76 | out_planes (int): number of output planes 77 | stride (int, optional): stride. Defaults to 1. 78 | 79 | Returns: 80 | nn.Conv2d: convolution layer 81 | """ 82 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 83 | 84 | 85 | class BasicBlock(nn.Module): 86 | """ 87 | Basic block for ResNet 88 | 89 | Args: 90 | inplanes (int): number of input planes 91 | planes (int): number of planes 92 | stride (int, optional): stride. Defaults to 1. 93 | downsample (nn.Module, optional): downsample. Defaults to None. 94 | groups (int, optional): number of groups. Defaults to 1. 95 | base_width (int, optional): base width. Defaults to 64. 96 | dilation (int, optional): dilation. Defaults to 1. 97 | norm_layer (nn.Module, optional): normalization layer. Defaults to None. 98 | """ 99 | expansion = 1 100 | 101 | def __init__( 102 | self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1, norm_layer=None 103 | ): 104 | super(BasicBlock, self).__init__() 105 | if norm_layer is None: 106 | norm_layer = nn.BatchNorm2d 107 | if groups != 1 or base_width != 64: 108 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 109 | if dilation > 1: 110 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 111 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 112 | 113 | # layers inside each basic block of a residual block 114 | self.conv1 = conv3x3(inplanes, planes, stride) 115 | self.bn1 = norm_layer(planes) 116 | self.relu = nn.ReLU(inplace=True) 117 | self.conv2 = conv3x3(planes, planes) 118 | self.bn2 = norm_layer(planes) 119 | # last two are operations and not layers 120 | self.downsample = downsample 121 | self.stride = stride 122 | 123 | def forward(self, x) -> torch.Tensor: 124 | """ 125 | Forward pass in the network 126 | 127 | Args: 128 | x (torch.Tensor): input tensor 129 | 130 | Returns: 131 | torch.Tensor: output tensor 132 | """ 133 | # saving x to pass over the bridge connection 134 | identity = x 135 | 136 | out = self.conv1(x) 137 | out = self.bn1(out) 138 | out = self.relu(out) 139 | 140 | out = self.conv2(out) 141 | out = self.bn2(out) 142 | 143 | if self.downsample is not None: 144 | identity = self.downsample(x) 145 | 146 | out += identity 147 | out = self.relu(out) 148 | 149 | return out 150 | 151 | 152 | class ResNet(nn.Module): 153 | """ 154 | ResNet model 155 | 156 | Args: 157 | block (nn.Module): block type 158 | layers (list): list of layers 159 | num_classes (int, optional): number of classes. Defaults to 1000. 160 | zero_init_residual (bool, optional): If True, zero-initialize the last BN in each residual branch. Defaults to False. 
161 | groups (int, optional): number of groups. Defaults to 1. 162 | width_per_group (int, optional): width per group. Defaults to 64. 163 | replace_stride_with_dilation (tuple, optional): replace stride with dilation. Defaults to None. 164 | norm_layer (nn.Module, optional): normalization layer. Defaults to None. 165 | return_all_feature_maps (bool, optional): If True, returns all feature maps. Defaults to False. 166 | first_conv (bool, optional): If True, uses a 7x7 kernel for the first convolution. Defaults to True. 167 | maxpool1 (bool, optional): If True, uses a maxpool layer after the first convolution. Defaults to True. 168 | """ 169 | 170 | def __init__( 171 | self, 172 | block, 173 | layers, 174 | num_classes=1000, # what should be the right parameter? 175 | zero_init_residual=False, 176 | groups=1, 177 | width_per_group=64, 178 | replace_stride_with_dilation=None, 179 | norm_layer=None, 180 | return_all_feature_maps=False, 181 | first_conv=True, #pre-processing layers which makes the image size half [64->32] 182 | maxpool1=True, #pre-processing 183 | ): 184 | super(ResNet, self).__init__() 185 | if norm_layer is None: 186 | norm_layer = nn.BatchNorm2d 187 | self._norm_layer = norm_layer 188 | self.return_all_feature_maps = return_all_feature_maps 189 | 190 | self.inplanes = 64 191 | self.dilation = 1 192 | if replace_stride_with_dilation is None: 193 | # each element in the tuple indicates if we should replace 194 | # the 2x2 stride with a dilated convolution instead 195 | replace_stride_with_dilation = [False, False, False] 196 | if len(replace_stride_with_dilation) != 3: 197 | raise ValueError( 198 | "replace_stride_with_dilation should be None " 199 | f"or a 3-element tuple, got {replace_stride_with_dilation}" 200 | ) 201 | self.groups = groups 202 | self.base_width = width_per_group 203 | 204 | # ------ layers before first residual block --------------- 205 | 206 | if first_conv: 207 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) 208 | else: 209 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) 210 | 211 | self.bn1 = norm_layer(self.inplanes) 212 | self.relu = nn.ReLU(inplace=True) 213 | 214 | if maxpool1: 215 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 216 | else: 217 | self.maxpool = nn.MaxPool2d(kernel_size=1, stride=1) 218 | 219 | # ------ residual blocks start here ------------------------ 220 | 221 | self.layer1 = self._make_layer(block, 64, layers[0]) 222 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]) 223 | self.layer3 = self._make_layer(block, 512, layers[2], stride=2, dilate=replace_stride_with_dilation[1]) 224 | #self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]) 225 | 226 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 227 | self.fc = nn.Linear(512 * block.expansion, num_classes) 228 | 229 | for m in self.modules(): 230 | if isinstance(m, nn.Conv2d): 231 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 232 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 233 | nn.init.constant_(m.weight, 1) 234 | nn.init.constant_(m.bias, 0) 235 | 236 | # Zero-initialize the last BN in each residual branch, 237 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
238 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 239 | if zero_init_residual: 240 | for m in self.modules(): 241 | if isinstance(m, Bottleneck): 242 | nn.init.constant_(m.bn3.weight, 0) 243 | elif isinstance(m, BasicBlock): 244 | nn.init.constant_(m.bn2.weight, 0) 245 | 246 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False) -> nn.Sequential: 247 | """ 248 | Create a layer of residual blocks 249 | 250 | Args: 251 | block (nn.Module): block type 252 | planes (int): number of planes 253 | blocks (int): number of blocks 254 | stride (int): stride 255 | dilate (bool): If True, use dilation 256 | 257 | Returns: 258 | nn.Sequential: layer of residual blocks 259 | """ 260 | norm_layer = self._norm_layer 261 | downsample = None 262 | previous_dilation = self.dilation 263 | if dilate: 264 | self.dilation *= stride 265 | stride = 1 266 | if stride != 1 or self.inplanes != planes * block.expansion: 267 | downsample = nn.Sequential( 268 | conv1x1(self.inplanes, planes * block.expansion, stride), 269 | norm_layer(planes * block.expansion), 270 | ) 271 | 272 | layers = [] 273 | layers.append( 274 | block( 275 | self.inplanes, 276 | planes, 277 | stride, 278 | downsample, 279 | self.groups, 280 | self.base_width, 281 | previous_dilation, 282 | norm_layer, 283 | ) 284 | ) 285 | self.inplanes = planes * block.expansion 286 | for _ in range(1, blocks): 287 | layers.append( 288 | block( 289 | self.inplanes, 290 | planes, 291 | groups=self.groups, 292 | base_width=self.base_width, 293 | dilation=self.dilation, 294 | norm_layer=norm_layer 295 | ) 296 | ) 297 | 298 | return nn.Sequential(*layers) 299 | 300 | def forward(self, x: torch.Tensor) -> torch.Tensor: 301 | """ 302 | Forward pass in the network 303 | 304 | Args: 305 | x (torch.Tensor): input tensor 306 | 307 | Returns: 308 | torch.Tensor: output tensor 309 | """ 310 | 311 | # passing input from pre-processing layers 312 | x0 = self.conv1(x) 313 | x0 = self.bn1(x0) 314 | x0 = self.relu(x0) 315 | x0 = self.maxpool(x0) 316 | 317 | # passing input from residual blocks 318 | if self.return_all_feature_maps: 319 | x1 = self.layer1(x0) 320 | x2 = self.layer2(x1) 321 | x3 = self.layer3(x2) 322 | 323 | 324 | return [x0, x1, x2, x3] 325 | 326 | else: 327 | x0 = self.layer1(x0) 328 | x0 = self.layer2(x0) 329 | x0 = self.layer3(x0) 330 | 331 | 332 | x0 = self.avgpool(x0) 333 | x0 = torch.flatten(x0, 1) 334 | 335 | return x0 336 | 337 | 338 | def _resnet(arch, block, layers, pretrained, progress, **kwargs) -> ResNet: 339 | """ 340 | Constructs a ResNet model. 341 | 342 | Args: 343 | arch (str): model architecture 344 | block (nn.Module): block type 345 | layers (list): list of layers 346 | pretrained (bool): If True, returns a model pre-trained on ImageNet 347 | progress (bool): If True, displays a progress bar of the download to stderr 348 | 349 | Returns: 350 | ResNet: model 351 | """ 352 | model = ResNet(block, layers, **kwargs) 353 | if pretrained: 354 | state_dict = load_state_dict_from_url(MODEL_URLS[arch], progress=progress) 355 | model.load_state_dict(state_dict) 356 | # Remove the last fc layer, since we only need the encoder part of resnet. 
357 | model.fc = nn.Identity() 358 | return model 359 | 360 | 361 | def resnet_3blocks(pretrained: bool = False, progress: bool = True, **kwargs) -> ResNet: 362 | """ResNet-18 variant truncated to 3 residual blocks, from 363 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_ 364 | 365 | Args: 366 | pretrained: If True, returns a model pre-trained on ImageNet 367 | progress: If True, displays a progress bar of the download to stderr 368 | 369 | Returns: 370 | ResNet: ResNet-18 model with 3 residual blocks 371 | """ 372 | 373 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs) --------------------------------------------------------------------------------
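The three truncated architectures above (resnet_1b.py, resnet_2b.py and resnet_3b.py) are interchangeable as encoders because each one ends in adaptive average pooling followed by a flatten, so every variant emits a 512-dimensional feature vector regardless of how many residual blocks it keeps. The sketch below is a minimal, illustrative check of that property, not part of the library: it assumes the package is importable as nett with the module paths implied by the file headers above, that torch and pl_bolts (imported by these modules) are installed, and it passes pretrained=False so the models keep random weights and no checkpoint download is attempted. The 64x64 input size is an arbitrary choice for the example.

# Illustrative sketch (see assumptions above): build each truncated ResNet
# encoder and confirm that all of them produce 512-dimensional features.
import torch

from nett.brain.encoders.disembodied_models.archs.resnet_1b import resnet_1block
from nett.brain.encoders.disembodied_models.archs.resnet_2b import resnet_2blocks
from nett.brain.encoders.disembodied_models.archs.resnet_3b import resnet_3blocks

# pretrained=False: the stock ImageNet checkpoints would not load cleanly into
# these truncated layer lists, so the example stays with random initialization.
encoders = {
    "1 block":  resnet_1block(pretrained=False),
    "2 blocks": resnet_2blocks(pretrained=False),
    "3 blocks": resnet_3blocks(pretrained=False),
}

dummy = torch.randn(4, 3, 64, 64)  # arbitrary RGB batch for illustration

with torch.no_grad():
    for name, encoder in encoders.items():
        encoder.eval()
        feats = encoder(dummy)
        print(name, tuple(feats.shape))  # expected: (4, 512) for every variant

For reinforcement learning, the analogous encoder is Resnet10CNN in resnet10.py, which already subclasses stable-baselines3's BaseFeaturesExtractor. Rather than being instantiated by hand, it would typically be handed to a policy via policy_kwargs, along the lines of the hypothetical snippet below (env stands in for any image-observation environment and is not defined here):

# Hypothetical wiring into a stable-baselines3 policy; `env` is a placeholder.
from stable_baselines3 import PPO
from nett.brain.encoders.resnet10 import Resnet10CNN

policy_kwargs = dict(
    features_extractor_class=Resnet10CNN,
    features_extractor_kwargs=dict(features_dim=256),
)
model = PPO("CnnPolicy", env, policy_kwargs=policy_kwargs)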