18 |
19 | The Newborn Embodied Turing Test (NETT) is a cutting-edge toolkit designed to simulate virtual agents in controlled-rearing conditions. This innovative platform enables researchers to create, simulate, and analyze virtual agents, facilitating direct comparisons with real chicks as documented by the **[Building a Mind Lab](http://buildingamind.com/)**. Our comprehensive suite includes all necessary components for the simulation and analysis of embodied models, closely replicating laboratory conditions.
20 |
21 | Below is a visual representation of our experimental setup, showcasing the infrastructure for the three primary experiments discussed in this documentation.
22 |
23 |
24 |
25 |
26 |
27 |
28 | ## How to Use this Repository
29 |
30 | The NETT toolkit comprises three key components:
31 |
32 | 1. **Virtual Environment**: A dynamic environment that serves as the habitat for virtual agents.
33 | 2. **Experimental Simulation Programs**: Tools to initiate and conduct experiments within the virtual world.
34 | 3. **Data Visualization Programs**: Utilities for analyzing and visualizing experiment outcomes.
35 |
36 | ## Directory Structure
37 |
38 | The directory structure of the code is as follows:
39 |
40 | ```
41 | ├── docs # Documentation and guides
42 | ├── examples
43 | │   ├── notebooks                    # Jupyter Notebooks for examples
44 | │   │   └── Getting Started.ipynb    # Introduction and setup notebook
45 | │   └── run                          # Terminal script example
46 | ├── src/nett
47 | │ ├── analysis # Analysis scripts
48 | │ ├── body # Agent body configurations
49 | │ ├── brain # Neural network models and learning algorithms
50 | │ ├── environment # Simulation environments
51 | │ ├── utils # Utility functions
52 | │ ├── nett.py # Main library script
53 | │ └── __init__.py # Package initialization
54 | ├── tests # Unit tests
55 | ├── mkdocs.yml # MkDocs configuration
56 | ├── pyproject.toml # Project metadata
57 | └── README.md # This README file
58 | ```
59 |
60 | ## Getting Started
61 |
62 |
63 | To begin benchmarking your first embodied agent with NETT, please be aware:
64 |
65 | **Important**: The `mlagents==1.0.0` dependency is incompatible with Apple Silicon (M1, M2, etc.) chips. Please use a different machine to run this codebase.
66 |
67 | ### Installation
68 |
69 | 1. **Virtual Environment Setup (Highly Recommended)**
70 |
71 | Create and activate a virtual environment to avoid dependency conflicts.
72 | ```bash
73 | conda create -y -n nett_env python=3.10.12
74 | conda activate nett_env
75 | ```
76 | See [here](https://uoa-eresearch.github.io/eresearch-cookbook/recipe/2014/11/20/conda "Link for how to set-up a virtual env") for detailed instructions.
77 |
78 | 2. **Install Prerequisites**
79 |
80 | Install the needed versions of `setuptools` and `pip`:
81 | ```bash
82 | pip install setuptools==65.5.0 pip==21 wheel==0.38.4
83 | ```
84 | **NOTE:** This is a result of incompatibilities with the subdependency `gym==0.21`. More information about this issue can be found [here](https://github.com/openai/gym/issues/3176#issuecomment-1560026649).
85 |
86 | 3. **Toolkit Installation**
87 |
88 | Install the toolkit using `pip`.
89 | ```bash
90 | pip install nett-benchmarks
91 | ```
92 |
93 | **NOTE:** Installation outside a virtual environment may fail due to conflicting dependencies. Ensure compatibility, especially with `gym==0.21` and `numpy<=1.21.2`.
94 |
95 | ### Running a NETT
96 |
97 | 1. **Download or Create the Unity Executable**
98 |
99 | Obtain a pre-made Unity executable from [here](https://origins.luddy.indiana.edu/environments/). The executable is required to run the virtual environment.
100 |
101 | 2. **Import NETT Components**
102 |
103 | Start by importing the NETT framework components - `Brain`, `Body`, and `Environment`, alongside the main `NETT` class.
104 | ```python
105 | from nett import Brain, Body, Environment
106 | from nett import NETT
107 | ```
108 |
109 | 3. **Component Configuration**:
110 |
111 | - **Brain**
112 |
113 | Configure the learning aspects, including the policy network (e.g. "CnnPolicy"), learning algorithm (e.g. "PPO"), the reward function, and the encoder.
114 | ```python
115 | brain = Brain(policy="CnnPolicy", algorithm="PPO")
116 | ```
117 | To get a list of all available policies, algorithms, and encoders, run `nett.list_policies()`, `nett.list_algorithms()`, and `nett.list_encoders()` respectively.
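For example (the printed values are illustrative):
```python
import nett

print(nett.list_policies())    # e.g. ["CnnPolicy", ...]
print(nett.list_algorithms())  # e.g. ["PPO", ...]
print(nett.list_encoders())
```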
118 |
119 | - **Body**
120 |
121 | Set up the agent's physical interface with the environment. It is possible to apply `gym.Wrapper` classes for data preprocessing (see the sketch below).
122 | ```python
123 | body = Body(type="basic", dvs=False, wrappers=None)
124 | ```
125 | Here, we do not pass any wrappers, letting information from the environment reach the brain "as is". Alternative body types (e.g. `two-eyed`, `rag-doll`) are planned in future updates.
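As an illustration, a custom `gym.ObservationWrapper` can preprocess observations before they reach the brain. The `GrayscaleWrapper` below is a hypothetical example; as in the bundled run script, wrapper classes (not instances) are passed to `Body`:
```python
import gym

class GrayscaleWrapper(gym.ObservationWrapper):
    """Hypothetical wrapper: averages the color channels to grayscale."""
    def observation(self, observation):
        return observation.mean(axis=-1, keepdims=True)

body = Body(type="basic", dvs=False, wrappers=[GrayscaleWrapper])
```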
126 |
127 | - **Environment**
128 |
129 | Create the simulation environment using the path to your Unity executable (see Step 1).
130 | ```python
131 | environment = Environment(config="identityandview", executable_path="path/to/executable.x86_64")
132 | ```
133 | To get a list of all available configurations, run `nett.list_configs()`.
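For example:
```python
import nett

print(nett.list_configs())  # e.g. ["Binding", "IdentityAndView", "Parsing", ...]
```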
134 |
135 | 4. **Run the Benchmarking**
136 |
137 | Integrate all components into a NETT instance to facilitate experiment execution.
138 | ```python
139 | benchmarks = NETT(brain=brain, body=body, environment=environment)
140 | ```
141 | The `NETT` instance has a `.run()` method that initiates the benchmarking process. The method accepts parameters such as the number of brains, training/testing episodes, and the output directory.
142 | ```python
143 | job_sheet = benchmarks.run(output_dir="path/to/run/output/directory/", num_brains=5, train_eps=10, test_eps=5)
144 | ```
145 | The `run` function is asynchronous and returns a list of jobs that may or may not yet be complete. To display the Unity environments while they run, set the `batch_mode` parameter to `False`, as shown below.
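For example (all other parameters as above):
```python
job_sheet = benchmarks.run(output_dir="path/to/run/output/directory/", num_brains=5, train_eps=10, test_eps=5, batch_mode=False)
```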
146 |
147 | 5. **Check Status**:
148 |
149 | To see the status of the benchmark processes, use the `.status()` method:
150 | ```python
151 | benchmarks.status(job_sheet)
152 | ```
153 |
154 | ### Running Standard Analysis
155 |
156 | After running the experiments, the pipeline will generate a collection of datafiles in the defined output directory.
157 |
158 | 1. **Install R and dependencies**
159 |
160 | To run the analyses performed in previous experiments, this toolkit provides a set of analysis scripts. Before running them, you will need R and the packages `tidyverse`, `argparse`, and `scales` installed. To install these packages, run the following command in R:
161 | ```R
162 | install.packages(c("tidyverse", "argparse", "scales"))
163 | ```
164 | Alternatively, if you have difficulty installing R on your system, you can install both R and the required packages using conda.
165 | ```bash
166 | conda install -y r r-tidyverse r-argparse r-scales
167 | ```
168 | 2. **Run the Analysis**
169 |
170 | To run the analysis, use the `analyze` method of the `NETT` class. This method will generate a set of plots and tables based on the datafiles in the output directory.
171 | ```python
172 | benchmarks.analyze(run_dir="path/to/run/output/directory/", output_dir="path/to/analysis/output/directory/")
173 | ```
174 |
175 |
176 |
177 |
178 | ## Documentation
179 | The full documentation is available [here](https://buildingamind.github.io/NewbornEmbodiedTuringTest/).
180 |
181 | ## Experiment Configuration
182 |
183 | More details on the experiments can be found on the following pages:
184 |
185 | * [**Parsing Experiment**](https://buildingamind.github.io/NewbornEmbodiedTuringTest/papers/Parsing.html)
186 | * [**ViewPoint Experiment**](https://buildingamind.github.io/NewbornEmbodiedTuringTest/papers/ViewInvariant.html)
187 |
188 | [🔼 Back to top](#newborn-embodied-turing-test)
189 |
--------------------------------------------------------------------------------
/src/nett/analysis/NETT_test_viz.R:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env Rscript
2 |
3 | # NETT_test_viz.R
4 |
5 | # Before running this script, you need to run merge_csvs to merge all of the agents'
6 | # output into a single, standardized format dataframe for training and test data
7 |
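# Example invocation (paths are illustrative):
#   Rscript NETT_test_viz.R --data-loc /path/to/merged_data.RData \
#     --chick-file /path/to/chick_data.csv --results-wd /path/to/results \
#     --bar-order default --color-bars true
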
8 | # Variables --------------------------------------------------------------------
9 |
10 | # Read in the user-specified variables:
11 | library(argparse)
12 | parser <- ArgumentParser(description="An executable R script for the Newborn Embodied Turing Tests to analyze test trials")
13 | parser$add_argument("--data-loc", type="character", dest="data_loc",
14 | help="Full filename (inc working directory) of the merged R data",
15 | required=TRUE)
16 | parser$add_argument("--chick-file", type="character", dest="chick_file",
17 | help="Full filename (inc working directory) of the chick data CSV file",
18 | required=TRUE)
19 | parser$add_argument("--results-wd", type="character", dest="results_wd",
20 | help="Working directory to save the resulting visualizations",
21 | required=TRUE)
22 | parser$add_argument("--bar-order", type = "character", default = "default", dest = "bar_order",
23 | help="Order of bars. Use 'default', 'asc', 'desc', or specify indices separated by commas (e.g., '3,2,1,4')",
24 | required=FALSE)
25 | parser$add_argument("--color-bars", type = "character", dest="color_bars",
26 | help="Should the bars be colored by test condition?",
27 | required=TRUE)
28 |
29 | # Set script variables based on user input
30 | args <- parser$parse_args()
31 | data_loc <- args$data_loc; chick_file <- args$chick_file; results_wd <- args$results_wd; bar_order <- args$bar_order
32 | color_bars <- tolower(args$color_bars) %in% c("t", "true")
33 |
34 | # Set Up -----------------------------------------------------------------------
35 |
36 | library(tidyverse)
37 | library(stringr)
38 |
39 | # Load the chick data
40 | chick_data <- read.csv(chick_file)
41 |
42 | # Load test data
43 | load(data_loc)
44 | rm(train_data)
45 |
46 | # Code each episode correct/incorrect
47 | test_data <- test_data %>%
48 | mutate(correct_steps = if_else(correct.monitor == " left", left_steps, right_steps)) %>%
49 | mutate(incorrect_steps = if_else(correct.monitor == " left", right_steps, left_steps)) %>%
50 | mutate(percent_correct = correct_steps / (correct_steps + incorrect_steps))
51 |
52 | # Adjust bar order according to user input -------------------------------------
53 |
54 | # Create a variable to store the final order
55 | order <- NULL
56 | if (bar_order == "default" || bar_order == "asc" || bar_order == "desc") {
57 |   order <- bar_order
58 | } else {
59 |   order <- as.integer(strsplit(bar_order, ",")[[1]])
60 | }
61 |
62 | # Conditionally reorder the dataframe based on user input
63 | if (!is.null(order)) {
64 |   if (is.character(order) && order == "desc") {
65 | test_data <- test_data %>%
66 | arrange(desc(percent_correct)) %>%
67 | mutate(test.cond = factor(test.cond, levels = unique(test.cond)))
68 |   } else if (is.character(order) && order == "asc") {
69 | test_data <- test_data %>%
70 | arrange(percent_correct) %>%
71 | mutate(test.cond = factor(test.cond, levels = unique(test.cond)))
72 |   } else if (is.numeric(order)) {
73 | # Map numeric indices to factor levels
74 | current_order <- levels(factor(test_data$test.cond))
75 | new_order <- current_order[order]
76 | test_data$test.cond <- factor(test_data$test.cond, levels = new_order)
77 | }
78 | # If order is "default", no need to change anything
79 | }
80 |
81 |
82 | # Plot aesthetic settings ------------------------------------------------------
83 | custom_palette <- c("#3F8CB7", "#FCEF88", "#5D5797", "#62AC6B", "#B74779", "#2C4E98","#CCCCE7", "#08625B", "#D15056")
84 | chickred <- "#AF264A"
85 |
86 | p <- ggplot() +
87 | theme_classic() +
88 | theme(axis.text.x = element_text(size = 6)) +
89 | ylab("Percent Correct") +
90 | xlab("Test Condition") +
91 | scale_y_continuous(expand = c(0, 0), limits = c(0, 1), breaks=seq(0,1,.1), labels = scales::percent) +
92 | geom_hline(yintercept = .5, linetype = 2) +
93 | scale_fill_manual(values = custom_palette) +
94 | scale_colour_manual(values = custom_palette) +
95 | theme(axis.title = element_text(face="bold"),
96 | axis.text.x = element_text(face="bold", size=7.5),
97 | axis.text.y = element_text(face="bold", size=7.5))
98 |
99 |
100 | # Bar Chart Function -----------------------------------------------------------
101 | make_bar_charts <- function(data, dots, aes_y, error_min, error_max, img_name)
102 | {
103 |   chart <- p +
104 |
105 | # Add chicken performance FIRST to sort the bars
106 | geom_errorbar(data=chick_data, width = 0.7, colour = chickred,
107 | aes(x=test.cond, ymin=avg, ymax=avg)) +
108 |
109 | # Model performance: bars
110 | {if(color_bars)geom_col(data = data, width = 0.7, aes(x=test.cond, y = {{aes_y}}, fill = test.cond))}+
111 | {if(!color_bars)geom_col(data = data, width = 0.7, aes(x=test.cond, y = {{aes_y}}), fill = "gray45")}+
112 | # Model performance: error bars
113 | geom_errorbar(data = data, width = 0.3,
114 | aes(x = test.cond, ymin = {{error_min}}, ymax = {{error_max}})) +
115 | # Model performance: dots
116 |     {if(!is.null(dots)) geom_jitter(data = dots, aes(x = test.cond, y = avgs), width = .3)}+
117 | theme(legend.position="none") +
118 |
119 | # Add chicken performance again so that it shows up on top
120 | # Chick performance: lines (errorbar) with ribbons (crossbar)
121 | geom_errorbar(data=chick_data, width = 0.7, colour = chickred,
122 | aes(x=test.cond, ymin=avg, ymax=avg)) +
123 | geom_crossbar(data=chick_data, width = 0.7,
124 | linetype = 0, fill = chickred, alpha = 0.2,
125 | aes(x = test.cond, y = avg,
126 | ymin = avg - avg_dev, ymax = avg + avg_dev))
127 |
128 |   ggsave(img_name, plot = chart, width = 6, height = 6)
129 | }
130 |
131 | # Switch wd before we save the graphs
132 | setwd(results_wd)
133 |
134 | # Plot by agent ----------------------------------------------------------------
135 | ## Leave rest data for agent-level graphs
136 |
137 | ## Group data by test conditions
138 | by_test_cond <- test_data %>%
139 | group_by(imprint.cond, agent, test.cond) %>%
140 | summarise(avgs = mean(percent_correct, na.rm = TRUE),
141 | sd = sd(percent_correct, na.rm = TRUE),
142 | count = length(percent_correct),
143 | tval = tryCatch({ (t.test(percent_correct, mu=0.5)$statistic)}, error = function(err){NA}),
144 | df = tryCatch({(t.test(percent_correct, mu=0.5)$parameter)},error = function(err){NA}),
145 | pval = tryCatch({(t.test(percent_correct, mu=0.5)$p.value)},error = function(err){NA}))%>%
146 | mutate(se = sd / sqrt(count)) %>%
147 | mutate(cohensd = (avgs - .5) / sd) %>%
148 | mutate(imp_agent = paste(imprint.cond, agent, sep="_"))
149 |
150 | write.csv(by_test_cond, "stats_by_agent.csv")
151 |
152 | for (i in unique(by_test_cond$imp_agent))
153 | {
154 | bar_data <- by_test_cond %>%
155 | filter(imp_agent == i)
156 |
157 | img_name <- paste0(i, "_test.png")
158 |
159 | make_bar_charts(data = bar_data,
160 | dots = NULL,
161 | aes_y = avgs,
162 | error_min = avgs - se,
163 | error_max = avgs + se,
164 | img_name = img_name)
165 | }
166 |
167 |
168 | # Plot by imprinting condition -------------------------------------------------
169 | ## Remove rest data once we start to group agents (for ease of presentation)
170 |
171 | by_imp_cond <- by_test_cond %>%
172 | ungroup() %>%
173 | group_by(imprint.cond, test.cond) %>%
174 | summarise(avgs_by_imp = mean(avgs, na.rm = TRUE),
175 | sd = sd(avgs, na.rm = TRUE),
176 | count = length(avgs),
177 | tval = tryCatch({ (t.test(avgs, mu=0.5)$statistic)}, error = function(err){NA}),
178 | df = tryCatch({ (t.test(avgs, mu=0.5)$parameter)}, error = function(err){NA}),
179 | pval = tryCatch({ (t.test(avgs, mu=0.5)$p.value)}, error = function(err){NA}))%>%
180 | mutate(se = sd / sqrt(count)) %>%
181 | mutate(cohensd = (avgs_by_imp - .5) / sd)
182 |
183 | write.csv(by_imp_cond, "stats_by_imp_cond.csv")
184 |
185 | for (i in unique(by_imp_cond$imprint.cond))
186 | {
187 | bar_data <- by_imp_cond %>%
188 | filter(imprint.cond == i) %>%
189 | filter(test.cond != "Rest")
190 |
191 | dot_data <- by_test_cond %>%
192 | filter(imprint.cond == i) %>%
193 | filter(test.cond != "Rest")
194 |
195 | img_name <- paste0(i, "_test.png")
196 |
197 | make_bar_charts(data = bar_data,
198 | dots = dot_data,
199 | aes_y = avgs_by_imp,
200 | error_min = avgs_by_imp - se,
201 | error_max = avgs_by_imp + se,
202 | img_name = img_name)
203 | }
204 |
205 |
206 | # Plot across all imprinting conditions ----------------------------------------
207 | across_imp_cond <- by_test_cond %>%
208 | ungroup() %>%
209 | filter(test.cond != "Rest") %>%
210 | group_by(test.cond) %>%
211 | summarise(all_avgs = mean(avgs, na.rm = TRUE),
212 | sd = sd(avgs, na.rm = TRUE),
213 | count = length(avgs),
214 | tval = tryCatch({ (t.test(avgs, mu=0.5)$statistic)}, error = function(err){NA}),
215 | df = tryCatch({ (t.test(avgs, mu=0.5)$parameter)}, error = function(err){NA}),
216 | pval = tryCatch({ (t.test(avgs, mu=0.5)$p.value)}, error = function(err){NA}))%>%
217 | mutate(se = sd / sqrt(count)) %>%
218 | mutate(cohensd = (all_avgs - .5) / sd)
219 |
220 | write.csv(across_imp_cond, "stats_across_all_agents.csv")
221 |
222 | dot_data <- filter(by_test_cond, test.cond != "Rest")
223 |
224 | make_bar_charts(data = across_imp_cond,
225 | dots = dot_data,
226 | aes_y = all_avgs,
227 | error_min = all_avgs - se,
228 | error_max = all_avgs + se,
229 | img_name = "all_imprinting_conds_test.png")
230 |
--------------------------------------------------------------------------------
/src/nett/environment/builder.py:
--------------------------------------------------------------------------------
1 | """Module for the Environment class."""
2 |
3 | from __future__ import annotations
4 |
5 | import os
6 | import subprocess
7 | from typing import Optional, Any
8 |
9 | import numpy as np
10 | from gym import Wrapper
11 | from mlagents_envs.environment import UnityEnvironment
12 |
13 | # checks to see if ml-agents tmp files have the proper permissions
14 | try:
15 | from mlagents_envs.envs.unity_gym_env import UnityToGymWrapper
16 | except PermissionError as _:
17 | raise PermissionError("Directory '/tmp/ml-agents-binaries' is not accessible. Please change permissions of the directory and its subdirectories ('tmp' and 'binaries') to 1777 or delete the entire directory and try again.")
18 |
19 | from nett.environment import configs
20 | from nett.environment.configs import NETTConfig, list_configs
21 | from nett.utils.environment import Logger, port_in_use
22 |
23 | class Environment(Wrapper):
24 | """
25 | Represents the environment where the agent lives.
26 |
27 | The environment is the source of all input data streams to train the brain of the agent.
28 | It accepts a Unity Executable and wraps it around as a Gym environment by leveraging the UnityEnvironment
29 | class from the mlagents_envs library.
30 |
31 | It provides a convenient interface for interacting with the Unity environment and includes methods for initializing the environment, rendering frames, taking steps, resetting the environment, and logging messages.
32 |
33 | Args:
34 | config (str | NETTConfig): The configuration for the environment. It can be either a string representing the name of a pre-defined configuration, or an instance of the NETTConfig class.
35 | executable_path (str): The path to the Unity executable file.
36 | display (int, optional): The display number to use for the Unity environment. Defaults to 0.
37 | base_port (int, optional): The base port number to use for communication with the Unity environment. Defaults to 5004.
38 | record_chamber (bool, optional): Whether to record the chamber. Defaults to False.
39 | record_agent (bool, optional): Whether to record the agent. Defaults to False.
40 | recording_frames (int, optional): The number of frames to record. Defaults to 1000.
41 |
42 | Raises:
43 | ValueError: If the configuration is not a valid string or an instance of NETTConfig.
44 |
45 | Example:
46 |
47 | >>> from nett import Environment
48 | >>> env = Environment(config="identityandview", executable_path="path/to/executable")
49 | """
50 | def __init__(self,
51 | config: str | NETTConfig,
52 | executable_path: str,
53 | display: int = 0,
54 | base_port: int = 5004,
55 | record_chamber: bool = False,
56 | record_agent: bool = False,
57 | recording_frames: int = 1000) -> None:
58 | """Constructor method
59 | """
60 | from nett import logger
61 | self.logger = logger.getChild(__class__.__name__)
62 | self.config = self._validate_config(config)
63 | # TODO (v0.5) what might be a way to check if it is a valid executable path?
64 | self.executable_path = executable_path
65 | self.base_port = base_port
66 | self.record_chamber = record_chamber
67 | self.record_agent = record_agent
68 | self.recording_frames = recording_frames
69 | self.display = display
70 |
71 | # set the correct permissions on the executable
72 | self._set_executable_permission()
73 | # set the display for Unity environment
74 | self._set_display()
75 |
76 | def _validate_config(self, config: str | NETTConfig) -> NETTConfig:
77 | """
78 | Validates the configuration for the environment.
79 |
80 | Args:
81 | config (str | NETTConfig): The configuration to validate.
82 |
83 | Returns:
84 | NETTConfig: The validated configuration.
85 |
86 | Raises:
87 | ValueError: If the configuration is not a valid string or an instance of NETTConfig.
88 | """
89 | # for when config is a str
90 | if isinstance(config, str):
91 | config_dict = {config_str.lower(): config_str for config_str in list_configs()}
92 | if config not in config_dict.keys():
93 |                 raise ValueError(f"Should be one of {list(config_dict.keys())}")
94 |
95 | config = getattr(configs, config_dict[config])()
96 |
97 | # for when config is a NETTConfig
98 | elif isinstance(config, NETTConfig):
99 | pass
100 |
101 | else:
102 |             raise ValueError(f"Should either be one of {[c.lower() for c in list_configs()]} or a subclass of NETTConfig")
103 |
104 | return config
105 |
106 | def _set_executable_permission(self) -> None:
107 | """
108 | Sets the executable permission for the Unity executable file.
109 | """
110 | subprocess.run(["chmod", "-R", "755", self.executable_path], check=True)
111 | self.logger.info("Executable permission is set")
112 |
113 | def _set_display(self) -> None:
114 | """
115 | Sets the display environment variable for the Unity environment.
116 | """
117 |         os.environ["DISPLAY"] = f":{self.display}"
118 | self.logger.info("Display is set")
119 |
120 |
121 | # copied from __init__() of chickai_env_wrapper.py (legacy)
122 | # TODO (v0.4) Critical refactor, don't like how this works, extremely error prone.
123 | # how can we build + constraint arguments better? something like an ArgumentParser sounds neat
124 | # TODO (v0.4) fix random_pos logic inside of Unity code
125 | def initialize(self, mode: str, **kwargs) -> None:
126 | """
127 | Initializes the environment with the given mode and arguments.
128 |
129 | Args:
130 | mode (str): The mode to set the environment for training or testing or both.
131 | **kwargs: The arguments to pass to the environment.
132 | """
133 |
134 | args = []
135 |
136 | # from environment arguments
137 | if self.recording_frames:
138 | args.extend(["--recording-steps", str(self.recording_frames)])
139 | if self.record_chamber:
140 | args.extend(["--record-chamber", "true"])
141 | if self.record_agent:
142 | args.extend(["--record-agent", "true"])
143 |
144 | # from runtime
145 | args.extend(["--mode", f"{mode}-{kwargs['condition']}"])
146 | if kwargs.get("rec_path", None):
147 | args.extend(["--log-dir", f"{kwargs['rec_path']}/"])
148 |         # needs to be fixed in Unity code where the default is always false
149 | if mode == "train":
150 | args.extend(["--random-pos", "true"])
151 | if kwargs.get("rewarded", False):
152 | args.extend(["--rewarded", "true"])
153 | self.step_per_episode = kwargs.get("episode_steps", 1000)
154 | args.extend(["--episode-steps", str(self.step_per_episode)])
155 |
156 |
157 | # if kwargs["device_type"] == "cpu":
158 | # args.extend(["-batchmode", "-nographics"])
159 | # elif kwargs["batch_mode"]:
160 | if kwargs["batch_mode"]:
161 | args.append("-batchmode")
162 |
163 | # TODO: Figure out a way to run on multiple GPUs
164 | # if ("device" in kwargs):
165 | # args.extend(["-force-device-index", str(kwargs["device"])])
166 | # args.extend(["-gpu", str(kwargs["device"])])
167 |
168 | # find unused port
169 | while port_in_use(self.base_port):
170 | self.base_port += 1
171 |
172 | # create logger
173 | self.log = Logger(f"{kwargs['condition'].replace('-', '_')}{kwargs['run_id']}-{mode}",
174 | log_dir=f"{kwargs['log_path']}/")
175 |
176 | # create environment and connect it to logger
177 | self.env = UnityEnvironment(self.executable_path, side_channels=[self.log], additional_args=args, base_port=self.base_port)
178 | self.env = UnityToGymWrapper(self.env, uint8_visual=True)
179 |
180 | # initialize the parent class (gym.Wrapper)
181 | super().__init__(self.env)
182 |
183 | # converts the (c, w, h) frame returned by mlagents v1.0.0 and Unity 2022.3 to (w, h, c)
184 | # as expected by gym==0.21.0
185 | # HACK: mode is not used, but is required by the gym.Wrapper class (might be unnecessary but keeping for now)
186 | def render(self, mode="rgb_array") -> np.ndarray: # pylint: disable=unused-argument
187 | """
188 | Renders the current frame of the environment.
189 |
190 | Args:
191 | mode (str, optional): The mode to render the frame in. Defaults to "rgb_array".
192 |
193 | Returns:
194 | numpy.ndarray: The rendered frame of the environment.
195 | """
196 | return np.moveaxis(self.env.render(), [0, 1, 2], [2, 0, 1])
197 |
198 | def step(self, action: list[Any]) -> tuple[np.ndarray, float, bool, dict]:
199 | """
200 | Takes a step in the environment with the given action.
201 |
202 | Args:
203 | action (list[Any]): The action to take in the environment.
204 |
205 | Returns:
206 | tuple[numpy.ndarray, float, bool, dict]: A tuple containing the next state, reward, done flag, and info dictionary.
207 | """
208 | next_state, reward, done, info = self.env.step(action)
209 | return next_state, float(reward), done, info
210 |
211 | def log(self, msg: str) -> None:
212 | """
213 | Logs a message to the environment.
214 |
215 | Args:
216 | msg (str): The message to log.
217 | """
218 | self.log.log_str(msg)
219 |
220 | def reset(self, seed: Optional[int] = None, **kwargs) -> None | list[np.ndarray] | np.ndarray: # pylint: disable=unused-argument
221 | # nothing to do if the wrapped env does not accept `seed`
222 | """
223 | Resets the environment with the given seed and arguments.
224 |
225 | Args:
226 | seed (int, optional): The seed to use for the environment. Defaults to None.
227 | **kwargs: The arguments to pass to the environment.
228 |
229 | Returns:
230 | numpy.ndarray: The initial state of the environment.
231 | """
232 | return self.env.reset(**kwargs)
233 |
234 | def __repr__(self) -> str:
235 | attrs = {k: v for k, v in vars(self).items() if k != "logger"}
236 | return f"{self.__class__.__name__}({attrs!r})"
237 |
238 | def __str__(self) -> str:
239 | attrs = {k: v for k, v in vars(self).items() if k != "logger"}
240 | return f"{self.__class__.__name__}({attrs!r})"
241 |
--------------------------------------------------------------------------------
/examples/run/run.py:
--------------------------------------------------------------------------------
1 | # pylint: skip-file
2 | from abc import abstractmethod
3 | import os
4 | import logging
5 | import argparse
6 | import yaml
7 | from nett import Brain, Body, Environment
8 | from nett import NETT
9 | from nett.environment.configs import Binding, Parsing, ViewInvariant
10 | from wrapper.dvs_wrapper import DVSWrapper
11 |
12 | logger = logging.getLogger(__name__)
13 | logger.setLevel(logging.INFO)
14 |
15 | def load_configuration(config_path: str):
16 | with open(config_path, 'r') as f:
17 | return yaml.safe_load(f)
18 |
19 | class BodyConfiguration:
20 | def __init__(self, kwargs):
21 | for key, value in kwargs.items():
22 | setattr(self, key, value)
23 |
24 | class BrainConfiguration:
25 | def __init__(self, kwargs):
26 | for key, value in kwargs.items():
27 | setattr(self, key, value)
28 |
29 | class EnvironmentConfiguration:
30 | def __init__(self, kwargs):
31 | for key, value in kwargs.items():
32 | setattr(self, key, value)
33 |
34 | class Experiment:
35 | """
36 | Generic Experiment Class To Run 3 experiments - Parsing, Binding and ViewInvariant
37 | """
38 |
39 | def __init__(self, **kwargs) -> None:
40 | ## initialize configurations
41 | self.brain_config = BrainConfiguration(kwargs.get('Brain'))
42 | self.body_config = BodyConfiguration(kwargs.get('Body'))
43 | self.env_config = EnvironmentConfiguration(kwargs.get('Environment'))
44 | self.base_simclr_checkpoint_path = os.path.join(os.getcwd(), "../data/checkpoints")
45 |
46 | self.encoder_config = {
47 | "small": {
48 | "feature_dimensions": 512, # replace with actual feature dimensions for 'small'
49 | "encoder": "",
50 | },
51 | "medium": {
52 | "feature_dimensions": 128, # replace with actual feature dimensions for 'medium'
53 | "encoder": "resnet10",
54 | },
55 | "large": {
56 | "feature_dimensions": 128, # replace with actual feature dimensions for 'large'
57 | "encoder": "resnet18",
58 | },
59 | "dinov2": {
60 | "feature_dimensions": 384, # replace with actual feature dimensions for 'dinov2'
61 |                 "encoder": "dinov2",
62 | },
63 | "dinov1": {
64 | "feature_dimensions": 384, # replace with actual feature dimensions for 'dinov1',
65 | "encoder": "dinov1",
66 | },
67 | "simclr": {
68 |                 "feature_dimensions": 512, # replace with actual feature dimensions for 'simclr'
69 | "encoder": "frozensimclr",
70 | },
71 | "sam": {
72 | "feature_dimensions": 256, # replace with actual feature dimensions for 'sam'
73 | "encoder": "sam",
74 | }
75 | }
76 |
77 | ## Environment
78 | self.env = self.initialize_environment()
79 |
80 | ## Body
81 | self.body = self.initialize_body()
82 |
83 | ## Brain
84 | self.brain = self.initialize_brain()
85 |
86 | ## configuration
87 | config = kwargs.get('Config')
88 | self.train_eps = config['train_eps']
89 | self.test_eps = config['test_eps']
90 | self.mode = config['mode']
91 | self.num_brains = config['num_brains']
92 | self.output_dir = config['output_dir']
93 | self.run_id = config['run_id']
94 |
95 | print(self.train_eps, self.test_eps, self.mode, self.num_brains, self.output_dir, self.run_id)
96 |
97 | def initialize_brain(self):
98 | """
99 | Initialize Brain class with the attributes extracted from the brain_config
100 |
101 | Returns:
102 |             Brain: the initialized Brain instance
103 | """
104 |
105 | # Extract attributes from brain_config
106 | brain_config_attrs = {attr: getattr(self.brain_config, attr) for attr in dir(self.brain_config) \
107 | if not attr.startswith('__')}
108 |
109 |
110 | ## update encoder attr in brain_config
111 | if brain_config_attrs['encoder']:
112 | encoder_config = self.encoder_config[brain_config_attrs['encoder']]
113 | brain_config_attrs['encoder'] = encoder_config['encoder']
114 | brain_config_attrs['embedding_dim'] = encoder_config['feature_dimensions']
115 |
116 |
117 |
118 | ## Add checkpoint path
119 | if brain_config_attrs.get('encoder') == 'frozensimclr':
120 |             brain_config_attrs['custom_encoder_args'] = {'checkpoint_path': self.get_checkpoint_path()}
123 |
124 | # Initialize Brain class with the extracted attributes
125 | brain = Brain(**brain_config_attrs)
126 | return brain
127 |
128 | def initialize_body(self):
129 | wrappers = []
130 | if self.body_config.dvs:
131 | wrappers = [DVSWrapper]
132 |
133 | return Body(type='basic', wrappers=wrappers)
134 |
135 | @abstractmethod
136 | def initialize_environment(self):
137 |         pass  # abstract method to be implemented by the child classes
138 |
139 | def run(self):
140 | benchmarks = NETT(brain=self.brain, body=self.body, environment=self.env)
141 | print(self.mode)
142 |         benchmarks.run(num_brains=self.num_brains,
143 |                        train_eps=self.train_eps,
144 |                        test_eps=self.test_eps,
145 |                        mode=self.mode,
146 |                        job_memory=21,
147 |                        output_dir=self.output_dir, run_id=self.run_id)
148 |
149 | #logger.info("Experiment completed successfully")
150 |
151 | class ParsingExperiment(Experiment):
152 | def __init__(self, **kwargs) -> None:
153 | super().__init__(**kwargs)
154 |
155 | def get_checkpoint_path(self):
156 | ## compute simclr checkpoints
157 | checkpoint_dict = {
158 | "ship_a": "ship_A/checkpoints/epoch=97-step=14601.ckpt",
159 | "ship_b": "ship_B/checkpoints/epoch=97-step=14601.ckpt",
160 | "ship_c": "ship_C/checkpoints/epoch=96-step=14452.ckpt",
161 | "fork_b": "fork_B/checkpoints/epoch=95-step=14303.ckpt",
162 | "fork_a": "fork_A/checkpoints/epoch=97-step=14601.ckpt",
163 | "fork_c": "fork_C/checkpoints/epoch=97-step=14601.ckpt"
164 | }
165 |
166 | parsing_checkpoint = os.path.join(self.base_simclr_checkpoint_path, 'simclr_parsing')
167 | checkpoint_key = f"{self.object.lower()}_{self.background.lower()}"
168 | path = checkpoint_dict.get(checkpoint_key, '')
169 | return os.path.join(parsing_checkpoint, path)
170 |
171 | def initialize_environment(self):
172 | """
173 | Initialize environment class with the attributes extracted from the env_config
174 |
175 | Returns:
176 |             Environment: the initialized Environment instance
177 | """
178 | self.object = "ship" if getattr(self.env_config, 'use_ship', False) else "fork"
179 | self.background = getattr(self.env_config, 'background', '')
180 |
181 | # Extract attributes from brain_config
182 | env_config_attrs = {attr: getattr(self.env_config, attr) for attr in dir(self.env_config) \
183 | if not attr.startswith('__')}
184 |
185 | del env_config_attrs['use_ship']
186 | del env_config_attrs['background']
187 |
188 | env_config_attrs['config'] = Parsing(background=self.background, object=self.object)
189 | return Environment(**env_config_attrs)
190 |
191 | class BindingExperiment(Experiment):
192 | def __init__(self, **kwargs) -> None:
193 | super().__init__(**kwargs)
194 |
195 | def get_checkpoint_path(self):
196 | ## compute simclr checkpoints
197 | checkpoint_dict = {
198 | "object_1": "object_1/checkpoints/epoch=97-step=14601.ckpt",
199 | "object_2": "object_2/checkpoints/epoch=97-step=14601.ckpt"
200 | }
201 |
202 | parsing_checkpoint = os.path.join(self.base_simclr_checkpoint_path, 'simclr_binding')
203 | checkpoint_key = f"{self.object.lower()}"
204 | path = checkpoint_dict.get(checkpoint_key, '')
205 | return os.path.join(parsing_checkpoint, path)
206 |
207 | def initialize_environment(self):
208 | """
209 | Initialize environment class with the attributes extracted from the env_config
210 |
211 | Returns:
212 |             Environment: the initialized Environment instance
213 | """
214 | # Extract attributes from brain_config
215 | env_config_attrs = {attr: getattr(self.env_config, attr) for attr in dir(self.env_config) \
216 | if not attr.startswith('__')}
217 | del env_config_attrs['object']
218 |         env_config_attrs['config'] = Binding(object=self.env_config.object)
219 | return Environment(**env_config_attrs)
220 |
221 | class ViewInvariantExperiment(Experiment):
222 | def __init__(self, **kwargs) -> None:
223 | super().__init__(**kwargs)
224 |
225 | def get_checkpoint_path(self):
226 | ## compute simclr checkpoints
227 | checkpoint_dict = {
228 | "ship_side":"",
229 | "ship_front":"",
230 | "fork_side":"",
231 | "fork_front":""
232 | }
233 |
234 | viewpt_checkpoint = os.path.join(self.base_simclr_checkpoint_path, 'simclr_viewpt')
235 | checkpoint_key = f"{self.object}_{self.view.lower()}"
236 | path = checkpoint_dict.get(checkpoint_key, '')
237 | return os.path.join(viewpt_checkpoint, path)
238 |
239 | def initialize_environment(self):
240 | """
241 | Initialize environment class with the attributes extracted from the env_config
242 |
243 | Returns:
244 |             Environment: the initialized Environment instance
245 | """
246 | self.object = "ship" if getattr(self.env_config, 'use_ship', False) else "fork"
247 | self.view = "side" if getattr(self.env_config, 'side_view', False) else "front"
248 |
249 | # Extract attributes from brain_config
250 | env_config_attrs = {attr: getattr(self.env_config, attr) for attr in dir(self.env_config) \
251 | if not attr.startswith('__')}
252 |
253 | del env_config_attrs['use_ship']
254 | del env_config_attrs['side_view']
255 |
256 | env_config_attrs['config'] = ViewInvariant(object=self.object, view=self.view)
257 | return Environment(**env_config_attrs)
258 |
259 | def main():
260 | args = parse_args()
261 |
262 | if args.exp_name:
263 | exp_name = args.exp_name
264 | config_path = f'configuration/{exp_name}.yaml'
265 | config = load_configuration(config_path)
266 |
267 | if exp_name == 'parsing':
268 | exp = ParsingExperiment(**config)
269 | exp.run()
270 |
271 |
272 | elif exp_name == 'binding':
273 | exp = BindingExperiment(**config)
274 | exp.run()
275 |
276 | elif exp_name == 'viewinvariant':
277 | exp = ViewInvariantExperiment(**config)
278 | exp.run()
279 |
280 | else:
281 | raise ValueError("Invalid Experiment Name")
282 |
283 | def parse_args():
284 | parser = argparse.ArgumentParser(description='Run the NETT pipeline - NeurIPS 2021 submission')
285 |     parser.add_argument('-exp_name', '--exp_name', type=str, required=True,
286 | help='name of the experiment')
287 | return parser.parse_args()
288 |
289 |
290 | if __name__ == '__main__':
291 | main()
292 |
--------------------------------------------------------------------------------
/src/nett/environment/configs.py:
--------------------------------------------------------------------------------
1 | """This module contains the NETT configurations for different experiments."""
2 |
3 | import sys
4 | import inspect
5 | from typing import Any
6 | from abc import ABC, abstractmethod
7 | from itertools import product
8 |
9 | # the naming is confusing since it is used for train or test too.
10 | class NETTConfig(ABC):
11 | """Abstract base class for NETT configurations.
12 |
13 | Args:
14 | param_defaults (dict[str, str]): A dictionary of parameter defaults.
15 | **params: Keyword arguments representing the configuration parameters.
16 |
17 | Raises:
18 | ValueError: If any parameter value is not a value or subset of the default values.
19 | """
20 |
21 | def __init__(self, param_defaults: dict[str, str], **params) -> None:
22 | """Constructor method
23 | """
24 | self.param_defaults = param_defaults
25 | self.params = self._validate_params(params)
26 | self.conditions = self._create_conditions_from_params(self.params)
27 |
28 | def _create_conditions_from_params(self, params: dict[str, str]) -> list[str]:
29 | """
30 | Creates conditions from the configuration parameters.
31 |
32 | Args:
33 | params (dict[str, str]): The configuration parameters.
34 |
35 | Returns:
36 | list[str]: A list of conditions.
37 | """
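        # Example (illustrative): params == {"object": ["ship"], "view": ["front", "side"]}
        # gives product [("ship", "front"), ("ship", "side")] and hence
        # conditions == ["ship-front", "ship-side"]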
38 | combination_params = list(product(*params.values()))
39 | conditions = ["-".join(combination).lower() for combination in combination_params]
40 | return conditions
41 |
42 | def _normalize_params(self, params: dict[str, str | int | float]) -> dict[str, str]:
43 | """
44 | Normalizes the configuration parameters.
45 |
46 | Args:
47 | params (dict[str, str | int | float]): The configuration parameters.
48 |
49 | Returns:
50 | dict[str, str]: The normalized configuration parameters.
51 | """
52 | params = {param: (value if isinstance(value, list) else [value]) for param, value in params.items()}
53 | params = {param: [str(item) for item in value] for param, value in params.items()}
54 | return params
55 |
56 | def _validate_params(self, params: dict[str, str]) -> dict[str, str]:
57 | """
58 | Validates the configuration parameters.
59 |
60 | Args:
61 | params (dict[str, str]): The configuration parameters.
62 |
63 | Returns:
64 | dict[str, str]: The validated configuration parameters.
65 |
66 | Raises:
67 | ValueError: If any parameter value is not a value or subset of the default values.
68 | """
69 | params = self._normalize_params(params)
70 | for (values, default_values) in zip(params.values(), self.param_defaults.values()):
71 | if not set(values) <= set(default_values):
72 | raise ValueError(f"{values} should be a value or subset of {default_values}")
73 | return params
74 |
75 | @property
76 | def defaults(self) -> dict[str, Any]:
77 | """
78 | Get the default values of the configuration parameters.
79 |
80 | Returns:
81 | dict[str, Any]: A dictionary of parameter defaults.
82 | """
83 | signature = inspect.signature(self.__init__)
84 | return {param: value.default for param, value in signature.parameters.items()
85 | if value.default is not inspect.Parameter.empty}
86 |
87 | @property
88 | @abstractmethod
89 | def num_conditions(self) -> int:
90 | """
91 | Get the number of conditions for the configuration.
92 |
93 | Returns:
94 | int: The number of conditions.
95 | """
96 | pass
97 |
98 |
99 | class IdentityAndView(NETTConfig):
100 | """
101 | NETT configuration for Identity and View.
102 |
103 | Args:
104 | object (str | list[str]): The object(s) to be used. Defaults to ["object1", "object2"].
105 | rotation (str | list[str]): The rotation(s) to be used. Defaults to ["horizontal", "vertical"].
106 |
107 | Raises:
108 | ValueError: If any parameter value is not a value or subset of the default values.
109 | """
110 |
111 | def __init__(self,
112 | object: str | list[str] = ["object1", "object2"],
113 | rotation: str | list[str] = ["horizontal", "vertical"]) -> None:
114 | """Constructor method
115 | """
116 | super().__init__(param_defaults=self.defaults,
117 | object=object,
118 | rotation=rotation)
119 |
120 | @property
121 | def num_conditions(self) -> int:
122 | """
123 | Get the number of conditions for the configuration.
124 |
125 | Returns:
126 | int: The number of conditions.
127 | """
128 | return 18
129 |
130 |
131 | class Binding(NETTConfig):
132 | """
133 | NETT configuration for Binding.
134 |
135 | Args:
136 | object (str | list[str]): The object(s) to be used. Defaults to ["object1", "object2"].
137 |
138 | Raises:
139 | ValueError: If any parameter value is not a value or subset of the default values.
140 | """
141 |
142 | def __init__(self,
143 | object: str | list[str] = ["object1", "object2"]) -> None:
144 | """Constructor method
145 | """
146 | super().__init__(param_defaults=self.defaults,
147 | object=object)
148 |
149 | @property
150 | def num_conditions(self) -> int:
151 | """
152 | Get the number of conditions for the configuration.
153 |
154 | Returns:
155 | int: The number of conditions.
156 | """
157 | return 50
158 |
159 |
160 | class Parsing(NETTConfig):
161 | """
162 | NETT configuration for Parsing.
163 |
164 | Args:
165 | background (str | list[str], optional): The background(s) to be used. Defaults to ["A", "B", "C"].
166 | object (str | list[str], optional): The object(s) to be used. Defaults to ["ship", "fork"].
167 | """
168 |
169 | def __init__(self,
170 | background: str | list[str] = ["A", "B", "C"],
171 | object: str | list[str] = ["ship", "fork"]) -> None:
172 | """Constructor method
173 | """
174 | super().__init__(param_defaults=self.defaults,
175 | background=background,
176 | object=object)
177 |
178 | @property
179 | def num_conditions(self) -> int:
180 | """
181 | Get the number of conditions for the configuration.
182 |
183 | Returns:
184 | int: The number of conditions.
185 | """
186 | return 56
187 |
188 |
189 | class Slowness(NETTConfig):
190 | """
191 | NETT configuration for Slowness.
192 |
193 | Args:
194 | experiment (str | list[int], optional): The experiment(s) to be used. Defaults to [1, 2].
195 | object (str | list[str], optional): The object(s) to be used. Defaults to ["obj1", "obj2"].
196 | speed (str | list[str], optional): The speed(s) to be used. Defaults to ["slow", "med", "fast"].
197 |
198 | Raises:
199 | ValueError: If any parameter value is not a value or subset of the default values.
200 | """
201 |
202 | def __init__(self,
203 | experiment: str | list[int] = [1, 2],
204 | object: str | list[str] = ["obj1", "obj2"],
205 | speed: str | list[str] = ["slow", "med", "fast"]) -> None:
206 | """Constructor method
207 | """
208 | super().__init__(param_defaults=self.defaults,
209 | experiment=experiment,
210 | object=object,
211 | speed=speed)
212 |
213 | @property
214 | def num_conditions(self) -> int:
215 | """
216 | Get the number of conditions for the configuration.
217 |
218 | Returns:
219 | int: The number of conditions.
220 | """
221 |         if self.params["experiment"] == ["1"]:  # param values are normalized to lists of strings
222 | return 5
223 | return 13
224 |
225 |
226 | class Smoothness(NETTConfig):
227 | """
228 | NETT configuration for Smoothness.
229 |
230 | Args:
231 | object (str or list[str], optional): The object(s) to be used. Defaults to ["obj1"].
232 | temporal (str or list[str], optional): The temporal condition(s) to be used. Defaults to ["norm", "scram"].
233 |
234 | Attributes:
235 | num_conditions (int): The number of conditions for the configuration.
236 | """
237 |
238 | def __init__(self,
239 | object: str | list[str] = ["obj1"],
240 | temporal: str | list[str] = ["norm", "scram"]) -> None:
241 | """Constructor method
242 | """
243 | super().__init__(param_defaults=self.defaults,
244 | object=object,
245 | temporal=temporal)
246 |
247 | @property
248 | def num_conditions(self) -> int:
249 | """
250 | Get the number of conditions for the configuration.
251 |
252 | Returns:
253 | int: The number of conditions.
254 | """
255 | return 5
256 |
257 |
258 | class OneShotViewInvariant(NETTConfig):
259 | """
260 | NETT configuration for One-Shot View Invariant.
261 |
262 | Args:
263 | object (str | list[str]): The object(s) to be used. Defaults to ["fork", "ship"].
264 | range (str | list[str]): The range(s) to be used. Defaults to ["360", "small", "1"].
265 | view (str | list[str]): The view(s) to be used. Defaults to ["front", "side"].
266 |
267 | Raises:
268 | ValueError: If any parameter value is not a value or subset of the default values.
269 | """
270 |
271 | def __init__(self,
272 | object: str | list[str] = ["fork", "ship"],
273 | range: str | list[str] = ["360", "small", "1"],
274 | view: str | list[str] = ["front", "side"]) -> None:
275 | """Constructor method
276 | """
277 | super().__init__(param_defaults=self.defaults,
278 | object=object,
279 | range=range,
280 | view=view)
281 |
282 | @property
283 | def num_conditions(self) -> int:
284 | """
285 | Get the number of conditions for the configuration.
286 |
287 | Returns:
288 | int: The number of conditions.
289 | """
290 | return 50
291 |
292 |
293 | class ViewInvariant(NETTConfig):
294 | """
295 |     NETT configuration for View Invariant.
296 |
297 |     Args:
298 |         object (str | list[str]): The object(s) to be used. Defaults to ["ship", "fork"].
299 |         view (str | list[str]): The view(s) to be used. Defaults to ["front", "side"].
299 |
300 | Raises:
301 | ValueError: If any parameter value is not a value or subset of the default values.
302 | """
303 |
304 | def __init__(self,
305 | object: str | list[str] = ["ship", "fork"],
306 | view: str | list[str] = ["front", "side"]) -> None:
307 | """Constructor method
308 | """
309 | super().__init__(param_defaults=self.defaults,
310 |                          object=object, view=view)
311 |
312 | @property
313 | def num_conditions(self) -> int:
314 | """
315 | Get the number of conditions for the configuration.
316 |
317 | Returns:
318 | int: The number of conditions.
319 | """
320 |         if self.params["view"] == ["front"]:
321 | return 50
322 | return 26
323 |
324 |
325 | def list_configs() -> list[str]:
326 | """
327 | Lists all available NETT configurations.
328 |
329 | Returns:
330 | list[str]: A list of configuration names.
331 | """
332 | #TODO: Are these really strings?
333 | is_class_member = lambda member: inspect.isclass(member) and member.__module__ == __name__
334 | clsmembers = inspect.getmembers(sys.modules[__name__], is_class_member)
335 | clsmembers = [clsmember[0] for clsmember in clsmembers if clsmember[0] != "NETTConfig"]
336 | return clsmembers
337 |
--------------------------------------------------------------------------------
/src/nett/brain/encoders/disembodied_models/archs/resnet_1b.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn as nn
3 | from torch.hub import load_state_dict_from_url
3 |
4 | from pl_bolts.utils import _TORCHVISION_AVAILABLE
5 | from pl_bolts.utils.warnings import warn_missing_pkg
6 |
7 | __all__ = [
8 | 'ResNet',
9 | 'resnet18',
10 | 'resnet34',
11 | 'resnet50',
12 | 'resnet101',
13 | 'resnet152',
14 | 'resnext50_32x4d',
15 | 'resnext101_32x8d',
16 | 'wide_resnet50_2',
17 | 'wide_resnet101_2',
18 | ]
19 |
20 | MODEL_URLS = {
21 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
22 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
23 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
24 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
25 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
26 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
27 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
28 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
29 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
30 | }
31 |
32 |
33 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1) -> nn.Conv2d:
34 | """
35 | 3x3 convolution with padding
36 |
37 | Args:
38 | in_planes (int): number of input channels
39 | out_planes (int): number of output channels
40 | stride (int): stride for the convolution
41 | groups (int): number of groups for the convolution
42 | dilation (int): dilation rate for the convolution
43 |
44 | Returns:
45 | nn.Conv2d: 3x3 convolution layer
46 | """
47 | return nn.Conv2d(
48 | in_planes,
49 | out_planes,
50 | kernel_size=3,
51 | stride=stride,
52 | padding=dilation,
53 | groups=groups,
54 | bias=False,
55 | dilation=dilation
56 | )
57 |
58 |
59 | def conv1x1(in_planes, out_planes, stride=1) -> nn.Conv2d:
60 | """
61 | 1x1 convolution
62 |
63 | Args:
64 | in_planes (int): number of input channels
65 | out_planes (int): number of output channels
66 | stride (int): stride for the convolution
67 |
68 | Returns:
69 | nn.Conv2d: 1x1 convolution layer
70 | """
71 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
72 |
73 |
74 | class BasicBlock(nn.Module):
75 | """
76 | Basic block for ResNet
77 |
78 | Args:
79 | inplanes (int): number of input channels
80 | planes (int): number of output channels
81 | stride (int): stride for the first convolution
82 | downsample (nn.Module): downsample layer
83 | groups (int): number of groups for the 3x3 convolution
84 | base_width (int): number of channels per group
85 | dilation (int): dilation rate for the 3x3 convolution
86 | norm_layer (nn.Module): normalization layer
87 | """
88 | expansion = 1
89 |
90 | def __init__(
91 | self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1, norm_layer=None
92 | ):
93 | super(BasicBlock, self).__init__()
94 | if norm_layer is None:
95 | norm_layer = nn.BatchNorm2d
96 | if groups != 1 or base_width != 64:
97 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
98 | if dilation > 1:
99 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
100 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
101 | self.conv1 = conv3x3(inplanes, planes, stride)
102 | self.bn1 = norm_layer(planes)
103 | self.relu = nn.ReLU(inplace=True)
104 | self.conv2 = conv3x3(planes, planes)
105 | self.bn2 = norm_layer(planes)
106 | self.downsample = downsample
107 | self.stride = stride
108 |
109 | def forward(self, x: torch.Tensor) -> torch.Tensor:
110 | """
111 | Forward pass in the network
112 |
113 | Args:
114 | x (torch.Tensor): input tensor
115 |
116 | Returns:
117 | torch.Tensor: output tensor
118 | """
119 | identity = x
120 |
121 | out = self.conv1(x)
122 | out = self.bn1(out)
123 | out = self.relu(out)
124 |
125 | out = self.conv2(out)
126 | out = self.bn2(out)
127 |
128 | if self.downsample is not None:
129 | identity = self.downsample(x)
130 |
131 | out += identity
132 | out = self.relu(out)
133 |
134 | return out
135 |
136 | class ResNet(nn.Module):
137 | """
138 | ResNet model
139 |
140 | Args:
141 | block (nn.Module): the block type to be used in the network
142 | layers (list of int): the number of layers for each block
143 | num_classes (int): number of classes
144 | zero_init_residual (bool): whether the residual block should be initialized to zero
145 | groups (int): number of groups for the 3x3 convolution
146 | width_per_group (int): number of channels per group
147 | replace_stride_with_dilation (tuple): replace stride with dilation
148 | norm_layer (nn.Module): normalization layer
149 | return_all_feature_maps (bool): whether to return all feature maps
150 | first_conv (bool): whether to use the first convolution
151 | maxpool1 (bool): whether to use the first maxpool
152 | """
153 |
154 | def __init__(
155 | self,
156 | block, # what kind of block, for ex - basic block or bottleneck
157 | layers, # how many layers or basic blocks in each residual block
158 | num_classes=1000,
159 | zero_init_residual=False,
160 | groups=1,
161 | width_per_group=64,
162 | replace_stride_with_dilation=None,
163 | norm_layer=None,
164 | return_all_feature_maps=False,
165 | first_conv=False,
166 | maxpool1=False
167 | ):
168 | super(ResNet, self).__init__()
169 | if norm_layer is None:
170 | norm_layer = nn.BatchNorm2d
171 | self._norm_layer = norm_layer
172 | self.return_all_feature_maps = return_all_feature_maps
173 |
174 |
175 |
176 |         self.inplanes = 3  # the original ResNet uses 64 here, but 64 does not work with a single residual block in this variant
177 | self.dilation = 1
178 | if replace_stride_with_dilation is None:
179 | # each element in the tuple indicates if we should replace
180 | # the 2x2 stride with a dilated convolution instead
181 | replace_stride_with_dilation = [False, False, False]
182 | if len(replace_stride_with_dilation) != 3:
183 | raise ValueError(
184 | "replace_stride_with_dilation should be None "
185 | f"or a 3-element tuple, got {replace_stride_with_dilation}"
186 | )
187 | self.groups = groups
188 | self.base_width = width_per_group
189 |
190 | # ------ layers before first residual block ---------------
191 |
192 | if first_conv:
193 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
194 | else:
195 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
196 |
197 | self.bn1 = norm_layer(self.inplanes)
198 | self.relu = nn.ReLU(inplace=True)
199 |
200 | if maxpool1:
201 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
202 | else:
203 | self.maxpool = nn.MaxPool2d(kernel_size=1, stride=1)
204 |
205 | # ------ residual blocks start here ------------------------
206 |
207 | self.layer1 = self._make_layer(block, 512, layers[0])
208 |
209 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
210 |
211 | self.fc = nn.Linear(512 * block.expansion, num_classes)
212 |
213 |
214 | for m in self.modules():
215 | if isinstance(m, nn.Conv2d):
216 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
217 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
218 | nn.init.constant_(m.weight, 1)
219 | nn.init.constant_(m.bias, 0)
220 |
221 | # Zero-initialize the last BN in each residual branch,
222 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
223 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
224 | if zero_init_residual:
225 | for m in self.modules():
226 |                 # only BasicBlock is defined in this module; the full torchvision
227 |                 # implementation also zero-initializes Bottleneck's bn3 here
228 |                 if isinstance(m, BasicBlock):
229 |                     nn.init.constant_(m.bn2.weight, 0)
230 |
231 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False) -> nn.Sequential:
232 | """
233 | Make a layer of blocks.
234 |
235 | Args:
236 | block (nn.Module): the block type to be used in the layer
237 | planes (int): number of output channels for the layer
238 | blocks (int): number of blocks to be used
239 | stride (int): stride for the first block. Default: 1
240 | dilate (bool): whether to apply dilation strategy to the layer. Default: False
241 |
242 | Returns:
243 | nn.Sequential: a layer of blocks
244 | """
245 | norm_layer = self._norm_layer
246 | downsample = None
247 | previous_dilation = self.dilation
248 | if dilate:
249 | self.dilation *= stride
250 | stride = 1
251 | if stride != 1 or self.inplanes != planes * block.expansion:
252 | downsample = nn.Sequential(
253 | conv1x1(self.inplanes, planes * block.expansion, stride),
254 | norm_layer(planes * block.expansion),
255 | )
256 |
257 | layers = []
258 | layers.append(
259 | block(
260 | self.inplanes,
261 | planes,
262 | stride,
263 | downsample,
264 | self.groups,
265 | self.base_width,
266 | previous_dilation,
267 | norm_layer,
268 | )
269 | )
270 | self.inplanes = planes * block.expansion
271 | for _ in range(1, blocks):
272 | layers.append(
273 | block(
274 | self.inplanes,
275 | planes,
276 | groups=self.groups,
277 | base_width=self.base_width,
278 | dilation=self.dilation,
279 | norm_layer=norm_layer
280 | )
281 | )
282 |
283 | return nn.Sequential(*layers)
284 |
285 |
286 | def forward(self, x: torch.Tensor) -> torch.Tensor:
287 | """
288 | Forward pass in the network
289 |
290 | Args:
291 | x (torch.Tensor): input tensor
292 |
293 | Returns:
294 | torch.Tensor: output tensor
295 | """
296 | x0 = self.conv1(x)
297 | x0 = self.bn1(x0)
298 | x0 = self.relu(x0)
299 | x0 = self.maxpool(x0)
300 |         # passing input through the residual block
301 |         if self.return_all_feature_maps:
302 |             x1 = self.layer1(x0)
303 |             return [x0, x1]
304 |         else:
305 |             x0 = self.layer1(x0)
306 |
307 |             x0 = self.avgpool(x0)  # output shape: [N, 512, 1, 1]
308 |
309 |             x0 = x0.reshape(x0.shape[0], -1)  # flatten to [N, 512]
310 | return x0
311 |
312 |
313 | def _resnet(arch, block, layers, pretrained, progress, **kwargs) -> ResNet:
314 | """
315 | Constructs a ResNet model.
316 |
317 | Args:
318 | arch (str): model architecture
319 | block (nn.Module): the block type to be used in the network
320 | layers (list of int): the number of layers for each block
321 | pretrained (bool): if True, returns a model pre-trained on ImageNet
322 | progress (bool): if True, displays a progress bar of the download to stderr
323 |
324 | Returns:
325 | ResNet: model
326 | """
327 | model = ResNet(block, layers, **kwargs)
328 | if pretrained:
329 | state_dict = load_state_dict_from_url(MODEL_URLS[arch], progress=progress)
330 | model.load_state_dict(state_dict)
331 |     # In the full-size ResNets the fc layer would be removed here, since
332 |     # only the encoder part is needed:
333 |     # model.fc = nn.Identity()
334 |     # In the smaller variants the fc layer is kept, because it is used to
335 |     # project the flattened output to a common dimension for training and testing the classifier.
336 | return model
337 |
338 |
339 | def resnet_1block(pretrained: bool = False, progress: bool = True, **kwargs) -> ResNet:
340 | """ResNet-18 model from
341 | `"Deep Residual Learning for Image Recognition" `
342 |
343 | Args:
344 | pretrained: If True, returns a model pre-trained on ImageNet
345 | progress: If True, displays a progress bar of the download to stderr
346 |
347 | Returns:
348 |         ResNet: single-block ResNet model
349 | """
350 |
351 | return _resnet('resnet18', BasicBlock, [1, 1, 1, 1], pretrained, progress, **kwargs)
352 |
353 |
--------------------------------------------------------------------------------
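
As a quick sanity check, the single-block encoder above can be exercised on a dummy batch. A minimal sketch, assuming the file is importable as `resnet_1b` alongside its sibling modules (the import path below is an assumption; only `resnet_1block` comes from the file above):

```python
import torch

# assumed import path, mirroring the sibling modules shown below
from nett.brain.encoders.disembodied_models.archs.resnet_1b import resnet_1block

model = resnet_1block(pretrained=False)
model.eval()

# four dummy 64x64 RGB images
images = torch.randn(4, 3, 64, 64)

with torch.no_grad():
    features = model(images)

# layer1 is built with 512 planes, so average pooling and flattening
# yield one 512-dimensional feature vector per image
print(features.shape)  # torch.Size([4, 512])
```
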
/src/nett/brain/encoders/resnet10.py:
--------------------------------------------------------------------------------
1 | """
2 | Resnet10CNN feature extractor for stable-baselines3
3 | """
4 |
5 | import gym
6 |
7 |
8 | import torch as th
9 | import torch.nn as nn
10 |
11 |
12 | from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
13 |
14 | import logging
15 | logger = logging.getLogger(__name__)
16 |
17 | class Resnet10CNN(BaseFeaturesExtractor):
18 | """
19 | Resnet10CNN feature extractor for stable-baselines3
20 |
21 | Args:
22 | observation_space (gym.spaces.Box): Observation space
23 | features_dim (int, optional): Output dimension of features extractor. Defaults to 256.
24 | """
25 |
26 | def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 256):
27 | super(Resnet10CNN, self).__init__(observation_space, features_dim)
28 | # We assume CxHxW images (channels first)
29 |         # re-ordering will be done by pre-processing or a wrapper
30 |
31 | n_input_channels = observation_space.shape[0]
32 |
33 |         self.cnn = _resnet(BasicBlock, [2, 2, 2, 2], num_channels=n_input_channels)
34 | logger.info(f"Resnet10CNN Encoder: {self.cnn}")
35 |
36 | with th.no_grad():
37 | n_flatten = self.cnn(
38 | th.as_tensor(observation_space.sample()[None]).float()
39 | ).shape[1]
40 |
41 | self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())
42 |
43 | def forward(self, observations: th.Tensor) -> th.Tensor:
44 | """
45 | Forward pass of the feature extractor.
46 |
47 | Args:
48 | observations (torch.Tensor): The input observations.
49 |
50 | Returns:
51 | torch.Tensor: The extracted features.
52 | """
53 |         # observations arrive as channels-first image tensors;
54 |         # re-ordering is handled upstream by a pre-processing wrapper.
55 |         # the ResNet encoder pools and flattens the images, and the
56 |         # linear head projects the result to `features_dim`, followed
57 |         # by a ReLU non-linearity.
58 |
59 | return self.linear(self.cnn(observations))
60 |
61 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1) -> nn.Conv2d:
62 | """
63 | 3x3 convolution with padding
64 |
65 | Args:
66 | in_planes (int): Number of input channels
67 | out_planes (int): Number of output channels
68 | stride (int, optional): Stride of the convolution. Defaults to 1.
69 | groups (int, optional): Number of groups for the convolution. Defaults to 1.
70 | dilation (int, optional): Dilation rate of the convolution. Defaults to 1.
71 |
72 | Returns:
73 | nn.Conv2d: Convolutional layer
74 | """
75 | return nn.Conv2d(
76 | in_planes,
77 | out_planes,
78 | kernel_size=3,
79 | stride=stride,
80 | padding=dilation,
81 | groups=groups,
82 | bias=False,
83 | dilation=dilation
84 | )
85 |
86 |
87 | def conv1x1(in_planes, out_planes, stride=1):
88 | """
89 | 1x1 convolution
90 |
91 | Args:
92 | in_planes (int): Number of input channels
93 | out_planes (int): Number of output channels
94 | stride (int, optional): Stride of the convolution. Defaults to 1.
95 |
96 | Returns:
97 | nn.Conv2d: Convolutional layer
98 | """
99 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
100 |
101 |
102 | class BasicBlock(nn.Module):
103 | """
104 | Basic block used in the ResNet-18 architecture.
105 |
106 | Args:
107 | inplanes (int): Number of input channels
108 | planes (int): Number of output channels
109 | stride (int, optional): Stride of the convolution. Defaults to 1.
110 | downsample (nn.Module, optional): Downsample layer. Defaults to None.
111 | groups (int, optional): Number of groups for the convolution. Defaults to 1.
112 | base_width (int, optional): Base width for the convolution. Defaults to 64.
113 | dilation (int, optional): Dilation rate of the convolution. Defaults to 1.
114 | norm_layer ([type], optional): Normalization layer. Defaults to None.
115 | """
116 | expansion = 1
117 |
118 | def __init__(
119 | self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1, norm_layer=None
120 | ):
121 | super(BasicBlock, self).__init__()
122 | if norm_layer is None:
123 | norm_layer = nn.BatchNorm2d
124 | if groups != 1 or base_width != 64:
125 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
126 | if dilation > 1:
127 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
128 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
129 |
130 | # layers inside each basic block of a residual block
131 | self.conv1 = conv3x3(inplanes, planes, stride)
132 | self.bn1 = norm_layer(planes)
133 | self.relu = nn.ReLU(inplace=True)
134 | self.conv2 = conv3x3(planes, planes)
135 | self.bn2 = norm_layer(planes)
136 | # last two are operations and not layers
137 | self.downsample = downsample
138 | self.stride = stride
139 |
140 | def forward(self, x: th.Tensor) -> th.Tensor:
141 | """
142 | Forward pass in the network
143 |
144 | Args:
145 | x (torch.Tensor): input tensor
146 |
147 | Returns:
148 | torch.Tensor: output tensor
149 | """
150 | # saving x to pass over the bridge connection
151 | identity = x
152 |
153 | out = self.conv1(x)
154 | out = self.bn1(out)
155 | out = self.relu(out)
156 |
157 | out = self.conv2(out)
158 | out = self.bn2(out)
159 |
160 | if self.downsample is not None:
161 | identity = self.downsample(x)
162 |
163 | out += identity
164 | out = self.relu(out)
165 |
166 | return out
167 |
168 | class ResNet(nn.Module):
169 | """
170 | ResNet architecture used in the Resnet10CNN class.
171 |
172 | Args:
173 | block (nn.Module): Residual block to use
174 | layers (list): Number of layers in each block
175 | num_channels (int): Number of input channels
176 | num_classes (int, optional): Number of classes. Defaults to 1000.
177 | zero_init_residual (bool, optional): Zero initialization for the residual block. Defaults to False.
178 | groups (int, optional): Number of groups for the convolution. Defaults to 1.
179 | width_per_group (int, optional): Base width for the convolution. Defaults to 64.
180 | replace_stride_with_dilation (tuple, optional): Replace stride with dilation. Defaults to None.
181 | norm_layer ([type], optional): Normalization layer. Defaults to None.
182 | return_all_feature_maps (bool, optional): Return all feature maps. Defaults to False.
183 | first_conv (bool, optional): Pre-processing layers which makes the image size half [64->32]. Defaults to True.
184 | maxpool1 (bool, optional): Used in pre-processing. Defaults to True.
185 | """
186 |
187 | def __init__(
188 | self,
189 | block,
190 | layers,
191 | num_channels,
192 | num_classes=1000,
193 | zero_init_residual=False,
194 | groups=1,
195 | width_per_group=64,
196 | replace_stride_with_dilation=None,
197 | norm_layer=None,
198 | return_all_feature_maps=False,
199 |         first_conv=True,  # pre-processing layer that halves the image size [64->32]
200 |         maxpool1=True  # also pre-processing; halves the spatial size again
201 | ):
202 | super(ResNet, self).__init__()
203 | if norm_layer is None:
204 | norm_layer = nn.BatchNorm2d
205 | self._norm_layer = norm_layer
206 | self.return_all_feature_maps = return_all_feature_maps
207 |
208 | self.inplanes = 64
209 | self.dilation = 1
210 | if replace_stride_with_dilation is None:
211 | # each element in the tuple indicates if we should replace
212 | # the 2x2 stride with a dilated convolution instead
213 | replace_stride_with_dilation = [False, False, False]
214 | if len(replace_stride_with_dilation) != 3:
215 | raise ValueError(
216 | "replace_stride_with_dilation should be None "
217 | f"or a 3-element tuple, got {replace_stride_with_dilation}"
218 | )
219 | self.groups = groups
220 | self.base_width = width_per_group
221 |
222 | # ------ layers before first residual block ---------------
223 | if first_conv:
224 | self.conv1 = nn.Conv2d(num_channels, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
225 |
226 | else:
227 |             self.conv1 = nn.Conv2d(num_channels, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
228 |
229 |
230 |
231 | self.bn1 = norm_layer(self.inplanes)
232 | self.relu = nn.ReLU(inplace=True)
233 |
234 | if maxpool1:
235 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
236 | else:
237 | self.maxpool = nn.MaxPool2d(kernel_size=1, stride=1)
238 |
239 | # ------ residual blocks start here ------------------------
240 |
241 | # BLOCK - 1
242 | self.layer1 = self._make_layer(block, 64, layers[0])
243 |
244 | # BLOCK - 2
245 | self.layer2 = self._make_layer(block, 512, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
246 |
247 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
248 | self.fc = nn.Linear(512 * block.expansion, num_classes)
249 |
250 | for m in self.modules():
251 | if isinstance(m, nn.Conv2d):
252 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
253 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
254 | nn.init.constant_(m.weight, 1)
255 | nn.init.constant_(m.bias, 0)
256 |
257 | # Zero-initialize the last BN in each residual branch,
258 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
259 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
260 | if zero_init_residual:
261 | for m in self.modules():
262 |                 # only BasicBlock is defined in this file; the Bottleneck
263 |                 # branch of the upstream torchvision code is omitted
264 |                 if isinstance(m, BasicBlock):
265 |                     nn.init.constant_(m.bn2.weight, 0)
266 |
267 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False) -> nn.Sequential:
268 | """
269 | Helper function to create a residual layer.
270 |
271 | Args:
272 | block (nn.Module): Residual block to use
273 | planes (int): Number of output channels
274 | blocks (int): Number of blocks
275 | stride (int, optional): Stride of the convolution. Defaults to 1.
276 | dilate (bool, optional): Use dilation. Defaults to False.
277 |
278 | Returns:
279 | nn.Sequential: Residual layer
280 | """
281 | norm_layer = self._norm_layer
282 | downsample = None
283 | previous_dilation = self.dilation
284 | if dilate:
285 | self.dilation *= stride
286 | stride = 1
287 | if stride != 1 or self.inplanes != planes * block.expansion:
288 | downsample = nn.Sequential(
289 | conv1x1(self.inplanes, planes * block.expansion, stride),
290 | norm_layer(planes * block.expansion),
291 | )
292 |
293 | layers = []
294 | layers.append(
295 | block(
296 | self.inplanes,
297 | planes,
298 | stride,
299 | downsample,
300 | self.groups,
301 | self.base_width,
302 | previous_dilation,
303 | norm_layer,
304 | )
305 | )
306 | self.inplanes = planes * block.expansion
307 | for _ in range(1, blocks):
308 | layers.append(
309 | block(
310 | self.inplanes,
311 | planes,
312 | groups=self.groups,
313 | base_width=self.base_width,
314 | dilation=self.dilation,
315 | norm_layer=norm_layer
316 | )
317 | )
318 |
319 | return nn.Sequential(*layers)
320 |
321 | def forward(self, x: th.Tensor) -> th.Tensor:
322 | """
323 | Forward pass in the network
324 |
325 | Args:
326 | x (torch.Tensor): input tensor
327 |
328 | Returns:
329 | torch.Tensor: output tensor
330 | """
331 |
332 | # passing input from pre-processing layers
333 | x0 = self.conv1(x)
334 | x0 = self.bn1(x0)
335 | x0 = self.relu(x0)
336 | x0 = self.maxpool(x0)
337 |
338 | # passing input from residual blocks
339 | if self.return_all_feature_maps:
340 | x1 = self.layer1(x0) # block1
341 | x2 = self.layer2(x1) # block2
342 |
343 | return [x0, x1, x2]
344 | else:
345 | x0 = self.layer1(x0)
346 | x0 = self.layer2(x0)
347 |
348 | x0 = self.avgpool(x0)
349 | x0 = th.flatten(x0, 1)
350 |
351 | return x0
352 |
353 |
354 | def _resnet(block, layers, **kwargs) -> ResNet:
355 |     """
356 |     Builds the ResNet encoder used by the Resnet10CNN feature extractor.
357 |
358 | Args:
359 | block (nn.Module): Residual block to use
360 | layers (list): Number of layers in each block
361 |
362 | Returns:
363 | ResNet: ResNet model
364 | """
365 | model = ResNet(block, layers, **kwargs)
366 | model.fc = nn.Identity()
367 | return model
368 |
--------------------------------------------------------------------------------
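
Because `Resnet10CNN` subclasses stable-baselines3's `BaseFeaturesExtractor`, it can be plugged into an on-policy algorithm through `policy_kwargs`. A minimal sketch, assuming any Gym environment with image observations (the environment id below is a placeholder, not part of the toolkit, and depends on the installed gym version):

```python
from stable_baselines3 import PPO

from nett.brain.encoders.resnet10 import Resnet10CNN

# route image observations through the ResNet encoder defined above;
# features_dim is the size of the vector handed to the policy/value heads
policy_kwargs = dict(
    features_extractor_class=Resnet10CNN,
    features_extractor_kwargs=dict(features_dim=256),
)

# "CarRacing-v0" stands in for any image-observation environment
model = PPO("CnnPolicy", "CarRacing-v0", policy_kwargs=policy_kwargs, verbose=1)
model.learn(total_timesteps=10_000)
```
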
/src/nett/brain/encoders/disembodied_models/archs/resnet_2b.py:
--------------------------------------------------------------------------------
1 | '''
2 | This file implements a ResNet-18 variant truncated to 2 residual blocks.
3 | The second block outputs 512 channels instead of the usual 128.
4 | '''
5 |
6 |
7 | import torch
8 | from torch import nn as nn
9 |
10 | from pl_bolts.utils import _TORCHVISION_AVAILABLE
11 | from pl_bolts.utils.warnings import warn_missing_pkg
12 |
13 | # torch.hub provides load_state_dict_from_url, needed when pretrained=True
14 | if _TORCHVISION_AVAILABLE:
15 |     from torch.hub import load_state_dict_from_url
16 | else:  # pragma: no cover
17 |     warn_missing_pkg('torchvision')
18 |
19 | # NOTE: this file was adapted from the upstream torchvision/pl_bolts
20 | # resnet implementation, whose export list named variants (resnet34,
21 | # resnet50, resnext, ...) that are not defined in this truncated file.
22 | # Only the locally defined symbols are exported:
23 | __all__ = [
24 |     'ResNet',
25 |     'resnet_2blocks',
26 | ]
27 |
28 |
29 |
30 |
31 |
32 | MODEL_URLS = {
33 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
34 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
35 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
36 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
37 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
38 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
39 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
40 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
41 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
42 | }
43 |
44 |
45 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1) -> nn.Conv2d:
46 | """
47 | 3x3 convolution with padding
48 |
49 | Args:
50 | in_planes (int): Number of input channels
51 | out_planes (int): Number of output channels
52 | stride (int): Stride
53 | groups (int): Number of groups
54 | dilation (int): Dilation
55 |
56 | Returns:
57 | nn.Conv2d: Convolution layer
58 | """
59 | return nn.Conv2d(
60 | in_planes,
61 | out_planes,
62 | kernel_size=3,
63 | stride=stride,
64 | padding=dilation,
65 | groups=groups,
66 | bias=False,
67 | dilation=dilation
68 | )
69 |
70 |
71 | def conv1x1(in_planes, out_planes, stride=1) -> nn.Conv2d:
72 | """
73 | 1x1 convolution
74 |
75 | Args:
76 | in_planes (int): Number of input channels
77 | out_planes (int): Number of output channels
78 | stride (int): Stride
79 |
80 | Returns:
81 | nn.Conv2d: Convolution layer"""
82 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
83 |
84 |
85 | class BasicBlock(nn.Module):
86 | """
87 | BasicBlock for ResNet
88 |
89 | Args:
90 | inplanes (int): Number of input channels
91 | planes (int): Number of output channels
92 | stride (int): Stride
93 | downsample (nn.Module): Downsample layer
94 | groups (int): Number of groups
95 | base_width (int): Base width
96 | dilation (int): Dilation
97 | norm_layer (nn.Module): Normalization layer
98 | """
99 | expansion = 1
100 |
101 | def __init__(
102 | self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1, norm_layer=None
103 | ):
104 | super(BasicBlock, self).__init__()
105 | if norm_layer is None:
106 | norm_layer = nn.BatchNorm2d
107 | if groups != 1 or base_width != 64:
108 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
109 | if dilation > 1:
110 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
111 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
112 |
113 | # layers inside each basic block of a residual block
114 | self.conv1 = conv3x3(inplanes, planes, stride)
115 | self.bn1 = norm_layer(planes)
116 | self.relu = nn.ReLU(inplace=True)
117 | self.conv2 = conv3x3(planes, planes)
118 | self.bn2 = norm_layer(planes)
119 | # last two are operations and not layers
120 | self.downsample = downsample
121 | self.stride = stride
122 |
123 | def forward(self, x: torch.Tensor) -> torch.Tensor:
124 | """
125 | Forward pass in the network
126 |
127 | Args:
128 | x (torch.Tensor): input tensor
129 |
130 | Returns:
131 | torch.Tensor: output tensor
132 | """
133 | # saving x to pass over the bridge connection
134 | identity = x
135 |
136 | out = self.conv1(x)
137 | out = self.bn1(out)
138 | out = self.relu(out)
139 |
140 | out = self.conv2(out)
141 | out = self.bn2(out)
142 |
143 | if self.downsample is not None:
144 | identity = self.downsample(x)
145 |
146 | out += identity
147 | out = self.relu(out)
148 |
149 | return out
150 |
151 |
152 | class ResNet(nn.Module):
153 | """
154 | ResNet model
155 |
156 | Args:
157 | block (nn.Module): ResNet block
158 | layers (list): List of layers
159 | num_classes (int): Number of classes
160 | zero_init_residual (bool): If True, zero-initialize the last BN in each residual branch
161 | groups (int): Number of groups
162 | width_per_group (int): Width per group
163 | replace_stride_with_dilation (tuple): Replace stride with dilation
164 | norm_layer (nn.Module): Normalization layer
165 | return_all_feature_maps (bool): If True, returns all feature maps
166 |         first_conv (bool): If True, uses a 7x7 stride-2 first convolution
167 |         maxpool1 (bool): If True, uses a maxpool layer after the first convolution
168 | """
169 |
170 | def __init__(
171 | self,
172 | block,
173 | layers,
174 | num_classes=1000,
175 | zero_init_residual=False,
176 | groups=1,
177 | width_per_group=64,
178 | replace_stride_with_dilation=None,
179 | norm_layer=None,
180 | return_all_feature_maps=False,
181 |         first_conv=True,  # pre-processing layer that halves the image size [64->32]
182 |         maxpool1=True  # also pre-processing; halves the spatial size again
183 | ):
184 | super(ResNet, self).__init__()
185 | if norm_layer is None:
186 | norm_layer = nn.BatchNorm2d
187 | self._norm_layer = norm_layer
188 | self.return_all_feature_maps = return_all_feature_maps
189 |
190 | self.inplanes = 64
191 | self.dilation = 1
192 | if replace_stride_with_dilation is None:
193 | # each element in the tuple indicates if we should replace
194 | # the 2x2 stride with a dilated convolution instead
195 | replace_stride_with_dilation = [False, False, False]
196 | if len(replace_stride_with_dilation) != 3:
197 | raise ValueError(
198 | "replace_stride_with_dilation should be None "
199 | f"or a 3-element tuple, got {replace_stride_with_dilation}"
200 | )
201 | self.groups = groups
202 | self.base_width = width_per_group
203 |
204 | # ------ layers before first residual block ---------------
205 |
206 | if first_conv:
207 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
208 | else:
209 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
210 |
211 | self.bn1 = norm_layer(self.inplanes)
212 | self.relu = nn.ReLU(inplace=True)
213 |
214 | if maxpool1:
215 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
216 | else:
217 | self.maxpool = nn.MaxPool2d(kernel_size=1, stride=1)
218 |
219 | # ------ residual blocks start here ------------------------
220 |
221 | # BLOCK - 1
222 | self.layer1 = self._make_layer(block, 64, layers[0])
223 |
224 | # BLOCK - 2
225 | self.layer2 = self._make_layer(block, 512, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
226 |
227 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
228 | self.fc = nn.Linear(512 * block.expansion, num_classes)
229 |
230 | for m in self.modules():
231 | if isinstance(m, nn.Conv2d):
232 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
233 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
234 | nn.init.constant_(m.weight, 1)
235 | nn.init.constant_(m.bias, 0)
236 |
237 | # Zero-initialize the last BN in each residual branch,
238 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
239 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
240 | if zero_init_residual:
241 | for m in self.modules():
242 |                 # only BasicBlock is defined in this file; the Bottleneck
243 |                 # branch of the upstream torchvision code is omitted
244 |                 if isinstance(m, BasicBlock):
245 |                     nn.init.constant_(m.bn2.weight, 0)
246 |
247 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False) -> nn.Sequential:
248 | """
249 | Creates a layer of residual blocks
250 |
251 | Args:
252 | block (nn.Module): ResNet block
253 | planes (int): Number of planes
254 | blocks (int): Number of blocks
255 | stride (int): Stride
256 | dilate (bool): If True, dilates the stride
257 |
258 | Returns:
259 | nn.Sequential: Residual block layer
260 | """
261 | norm_layer = self._norm_layer
262 | downsample = None
263 | previous_dilation = self.dilation
264 | if dilate:
265 | self.dilation *= stride
266 | stride = 1
267 | if stride != 1 or self.inplanes != planes * block.expansion:
268 | downsample = nn.Sequential(
269 | conv1x1(self.inplanes, planes * block.expansion, stride),
270 | norm_layer(planes * block.expansion),
271 | )
272 |
273 | layers = []
274 | layers.append(
275 | block(
276 | self.inplanes,
277 | planes,
278 | stride,
279 | downsample,
280 | self.groups,
281 | self.base_width,
282 | previous_dilation,
283 | norm_layer,
284 | )
285 | )
286 | self.inplanes = planes * block.expansion
287 | for _ in range(1, blocks):
288 | layers.append(
289 | block(
290 | self.inplanes,
291 | planes,
292 | groups=self.groups,
293 | base_width=self.base_width,
294 | dilation=self.dilation,
295 | norm_layer=norm_layer
296 | )
297 | )
298 |
299 | return nn.Sequential(*layers)
300 |
301 | def forward(self, x: torch.Tensor) -> torch.Tensor:
302 | """
303 | Forward pass in the network
304 |
305 | Args:
306 | x (torch.Tensor): input tensor
307 |
308 | Returns:
309 | torch.Tensor: output tensor
310 | """
311 | # passing input from pre-processing layers
312 | x0 = self.conv1(x)
313 | x0 = self.bn1(x0)
314 | x0 = self.relu(x0)
315 | x0 = self.maxpool(x0)
316 |
317 | # passing input from residual blocks
318 | if self.return_all_feature_maps:
319 | x1 = self.layer1(x0) # block1
320 | x2 = self.layer2(x1) # block2
321 |
322 | return [x0, x1, x2]
323 | else:
324 | x0 = self.layer1(x0)
325 | x0 = self.layer2(x0)
326 |
327 | x0 = self.avgpool(x0)
328 | x0 = torch.flatten(x0, 1)
329 |
330 | return x0
331 |
332 |
333 | def _resnet(arch, block, layers, pretrained, progress, **kwargs) -> ResNet:
334 | """
335 | Constructs a ResNet model.
336 |
337 | Args:
338 | arch (str): Architecture name from the URLs
339 | block (nn.Module): ResNet block
340 | layers (list): List of layers
341 | pretrained (bool): If True, returns a model pre-trained on ImageNet
342 | progress (bool): If True, displays a progress bar of the download to stderr
343 | **kwargs: Other arguments for the ResNet model
344 |
345 | Returns:
346 | ResNet: ResNet model
347 | """
348 | model = ResNet(block, layers, **kwargs)
349 | if pretrained:
350 | state_dict = load_state_dict_from_url(MODEL_URLS[arch], progress=progress)
351 | model.load_state_dict(state_dict)
352 | # Remove the last fc layer, since we only need the encoder part of resnet.
353 | model.fc = nn.Identity()
354 | return model
355 |
356 |
357 |
358 | def resnet_2blocks(pretrained: bool = False, progress: bool = True, **kwargs) -> nn.Module:
359 | """
360 | Constructs a ResNet-18 model with 2 blocks.
361 |
362 | Args:
363 | pretrained (bool): If True, returns a model pre-trained on ImageNet
364 | progress (bool): If True, displays a progress bar of the download to stderr
365 | **kwargs: Other arguments for the ResNet model
366 |
367 | Returns:
368 | nn.Module: ResNet-18 model with 2 blocks
369 | """
370 |
371 |
372 |
373 | """
374 | first argument in _resnet() : architecture name from the URLs
375 | since URL for resnet9 is not available, therefore resnet18 is used with modifications
376 | """
377 |
378 | # to print this architecture, print the model from the evaluator/evaluate file
379 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)
380 |
--------------------------------------------------------------------------------
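
With `return_all_feature_maps=True`, the forward pass above returns the pre-processing output plus each block's feature map rather than a pooled vector. A small sketch of the resulting shapes, assuming 64x64 RGB input and the defaults `first_conv=True`/`maxpool1=True` (each of which halves the spatial size):

```python
import torch

from nett.brain.encoders.disembodied_models.archs.resnet_2b import resnet_2blocks

model = resnet_2blocks(pretrained=False, return_all_feature_maps=True)
model.eval()

with torch.no_grad():
    x0, x1, x2 = model(torch.randn(1, 3, 64, 64))

print(x0.shape)  # torch.Size([1, 64, 16, 16])  after conv1 + maxpool
print(x1.shape)  # torch.Size([1, 64, 16, 16])  after block 1 (stride 1)
print(x2.shape)  # torch.Size([1, 512, 8, 8])   after block 2 (stride 2, widened to 512)
```
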
/src/nett/brain/encoders/disembodied_models/archs/resnet_3b.py:
--------------------------------------------------------------------------------
1 | '''
2 | This file implements a ResNet-18 variant truncated to 3 residual blocks.
3 | The third block outputs 512 channels instead of the usual 256.
4 | '''
5 |
6 |
7 | import torch
8 | from torch import nn as nn
9 |
10 | from pl_bolts.utils import _TORCHVISION_AVAILABLE
11 | from pl_bolts.utils.warnings import warn_missing_pkg
12 |
13 | if _TORCHVISION_AVAILABLE:
14 |     from torch.hub import load_state_dict_from_url
15 | else:  # pragma: no cover
16 |     warn_missing_pkg('torchvision')
17 |
18 | # NOTE: this file was adapted from the upstream torchvision/pl_bolts
19 | # resnet implementation, whose export list named variants (resnet34,
20 | # resnet50, resnext, ...) that are not defined in this truncated file.
21 | # Only the locally defined symbols are exported:
22 | __all__ = [
23 |     'ResNet',
24 |     'resnet_3blocks',
25 | ]
26 |
27 |
28 |
29 |
30 |
31 | MODEL_URLS = {
32 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
33 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
34 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
35 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
36 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
37 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
38 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
39 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
40 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
41 | }
42 |
43 |
44 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1) -> nn.Conv2d:
45 | """
46 | 3x3 convolution with padding
47 |
48 | Args:
49 | in_planes (int): number of input planes
50 | out_planes (int): number of output planes
51 | stride (int, optional): stride. Defaults to 1.
52 | groups (int, optional): number of groups. Defaults to 1.
53 | dilation (int, optional): dilation. Defaults to 1.
54 |
55 | Returns:
56 | nn.Conv2d: convolution layer
57 | """
58 | return nn.Conv2d(
59 | in_planes,
60 | out_planes,
61 | kernel_size=3,
62 | stride=stride,
63 | padding=dilation,
64 | groups=groups,
65 | bias=False,
66 | dilation=dilation
67 | )
68 |
69 |
70 | def conv1x1(in_planes, out_planes, stride=1) -> nn.Conv2d:
71 | """
72 | 1x1 convolution
73 |
74 | Args:
75 | in_planes (int): number of input planes
76 | out_planes (int): number of output planes
77 | stride (int, optional): stride. Defaults to 1.
78 |
79 | Returns:
80 | nn.Conv2d: convolution layer
81 | """
82 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
83 |
84 |
85 | class BasicBlock(nn.Module):
86 | """
87 | Basic block for ResNet
88 |
89 | Args:
90 | inplanes (int): number of input planes
91 | planes (int): number of planes
92 | stride (int, optional): stride. Defaults to 1.
93 | downsample (nn.Module, optional): downsample. Defaults to None.
94 | groups (int, optional): number of groups. Defaults to 1.
95 | base_width (int, optional): base width. Defaults to 64.
96 | dilation (int, optional): dilation. Defaults to 1.
97 | norm_layer (nn.Module, optional): normalization layer. Defaults to None.
98 | """
99 | expansion = 1
100 |
101 | def __init__(
102 | self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1, norm_layer=None
103 | ):
104 | super(BasicBlock, self).__init__()
105 | if norm_layer is None:
106 | norm_layer = nn.BatchNorm2d
107 | if groups != 1 or base_width != 64:
108 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
109 | if dilation > 1:
110 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
111 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
112 |
113 | # layers inside each basic block of a residual block
114 | self.conv1 = conv3x3(inplanes, planes, stride)
115 | self.bn1 = norm_layer(planes)
116 | self.relu = nn.ReLU(inplace=True)
117 | self.conv2 = conv3x3(planes, planes)
118 | self.bn2 = norm_layer(planes)
119 | # last two are operations and not layers
120 | self.downsample = downsample
121 | self.stride = stride
122 |
123 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
124 | """
125 | Forward pass in the network
126 |
127 | Args:
128 | x (torch.Tensor): input tensor
129 |
130 | Returns:
131 | torch.Tensor: output tensor
132 | """
133 | # saving x to pass over the bridge connection
134 | identity = x
135 |
136 | out = self.conv1(x)
137 | out = self.bn1(out)
138 | out = self.relu(out)
139 |
140 | out = self.conv2(out)
141 | out = self.bn2(out)
142 |
143 | if self.downsample is not None:
144 | identity = self.downsample(x)
145 |
146 | out += identity
147 | out = self.relu(out)
148 |
149 | return out
150 |
151 |
152 | class ResNet(nn.Module):
153 | """
154 | ResNet model
155 |
156 | Args:
157 | block (nn.Module): block type
158 | layers (list): list of layers
159 | num_classes (int, optional): number of classes. Defaults to 1000.
160 | zero_init_residual (bool, optional): If True, zero-initialize the last BN in each residual branch. Defaults to False.
161 | groups (int, optional): number of groups. Defaults to 1.
162 | width_per_group (int, optional): width per group. Defaults to 64.
163 | replace_stride_with_dilation (tuple, optional): replace stride with dilation. Defaults to None.
164 | norm_layer (nn.Module, optional): normalization layer. Defaults to None.
165 | return_all_feature_maps (bool, optional): If True, returns all feature maps. Defaults to False.
166 | first_conv (bool, optional): If True, uses a 7x7 kernel for the first convolution. Defaults to True.
167 | maxpool1 (bool, optional): If True, uses a maxpool layer after the first convolution. Defaults to True.
168 | """
169 |
170 | def __init__(
171 | self,
172 | block,
173 | layers,
174 |         num_classes=1000,  # torchvision default; the classifier head is replaced by nn.Identity in _resnet
175 | zero_init_residual=False,
176 | groups=1,
177 | width_per_group=64,
178 | replace_stride_with_dilation=None,
179 | norm_layer=None,
180 | return_all_feature_maps=False,
181 |         first_conv=True,  # pre-processing layer that halves the image size [64->32]
182 |         maxpool1=True,  # also pre-processing; halves the spatial size again
183 | ):
184 | super(ResNet, self).__init__()
185 | if norm_layer is None:
186 | norm_layer = nn.BatchNorm2d
187 | self._norm_layer = norm_layer
188 | self.return_all_feature_maps = return_all_feature_maps
189 |
190 | self.inplanes = 64
191 | self.dilation = 1
192 | if replace_stride_with_dilation is None:
193 | # each element in the tuple indicates if we should replace
194 | # the 2x2 stride with a dilated convolution instead
195 | replace_stride_with_dilation = [False, False, False]
196 | if len(replace_stride_with_dilation) != 3:
197 | raise ValueError(
198 | "replace_stride_with_dilation should be None "
199 | f"or a 3-element tuple, got {replace_stride_with_dilation}"
200 | )
201 | self.groups = groups
202 | self.base_width = width_per_group
203 |
204 | # ------ layers before first residual block ---------------
205 |
206 | if first_conv:
207 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
208 | else:
209 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
210 |
211 | self.bn1 = norm_layer(self.inplanes)
212 | self.relu = nn.ReLU(inplace=True)
213 |
214 | if maxpool1:
215 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
216 | else:
217 | self.maxpool = nn.MaxPool2d(kernel_size=1, stride=1)
218 |
219 | # ------ residual blocks start here ------------------------
220 |
221 | self.layer1 = self._make_layer(block, 64, layers[0])
222 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
223 | self.layer3 = self._make_layer(block, 512, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
224 | #self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
225 |
226 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
227 | self.fc = nn.Linear(512 * block.expansion, num_classes)
228 |
229 | for m in self.modules():
230 | if isinstance(m, nn.Conv2d):
231 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
232 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
233 | nn.init.constant_(m.weight, 1)
234 | nn.init.constant_(m.bias, 0)
235 |
236 | # Zero-initialize the last BN in each residual branch,
237 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
238 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
239 | if zero_init_residual:
240 | for m in self.modules():
241 |                 # only BasicBlock is defined in this file; the Bottleneck
242 |                 # branch of the upstream torchvision code is omitted
243 |                 if isinstance(m, BasicBlock):
244 |                     nn.init.constant_(m.bn2.weight, 0)
245 |
246 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False) -> nn.Sequential:
247 | """
248 | Create a layer of residual blocks
249 |
250 | Args:
251 | block (nn.Module): block type
252 | planes (int): number of planes
253 | blocks (int): number of blocks
254 | stride (int): stride
255 | dilate (bool): If True, use dilation
256 |
257 | Returns:
258 | nn.Sequential: layer of residual blocks
259 | """
260 | norm_layer = self._norm_layer
261 | downsample = None
262 | previous_dilation = self.dilation
263 | if dilate:
264 | self.dilation *= stride
265 | stride = 1
266 | if stride != 1 or self.inplanes != planes * block.expansion:
267 | downsample = nn.Sequential(
268 | conv1x1(self.inplanes, planes * block.expansion, stride),
269 | norm_layer(planes * block.expansion),
270 | )
271 |
272 | layers = []
273 | layers.append(
274 | block(
275 | self.inplanes,
276 | planes,
277 | stride,
278 | downsample,
279 | self.groups,
280 | self.base_width,
281 | previous_dilation,
282 | norm_layer,
283 | )
284 | )
285 | self.inplanes = planes * block.expansion
286 | for _ in range(1, blocks):
287 | layers.append(
288 | block(
289 | self.inplanes,
290 | planes,
291 | groups=self.groups,
292 | base_width=self.base_width,
293 | dilation=self.dilation,
294 | norm_layer=norm_layer
295 | )
296 | )
297 |
298 | return nn.Sequential(*layers)
299 |
300 | def forward(self, x: torch.Tensor) -> torch.Tensor:
301 | """
302 | Forward pass in the network
303 |
304 | Args:
305 | x (torch.Tensor): input tensor
306 |
307 | Returns:
308 | torch.Tensor: output tensor
309 | """
310 |
311 | # passing input from pre-processing layers
312 | x0 = self.conv1(x)
313 | x0 = self.bn1(x0)
314 | x0 = self.relu(x0)
315 | x0 = self.maxpool(x0)
316 |
317 | # passing input from residual blocks
318 | if self.return_all_feature_maps:
319 | x1 = self.layer1(x0)
320 | x2 = self.layer2(x1)
321 | x3 = self.layer3(x2)
322 |
323 |
324 | return [x0, x1, x2, x3]
325 |
326 | else:
327 | x0 = self.layer1(x0)
328 | x0 = self.layer2(x0)
329 | x0 = self.layer3(x0)
330 |
331 |
332 | x0 = self.avgpool(x0)
333 | x0 = torch.flatten(x0, 1)
334 |
335 | return x0
336 |
337 |
338 | def _resnet(arch, block, layers, pretrained, progress, **kwargs) -> ResNet:
339 | """
340 | Constructs a ResNet model.
341 |
342 | Args:
343 | arch (str): model architecture
344 | block (nn.Module): block type
345 | layers (list): list of layers
346 | pretrained (bool): If True, returns a model pre-trained on ImageNet
347 | progress (bool): If True, displays a progress bar of the download to stderr
348 |
349 | Returns:
350 | ResNet: model
351 | """
352 | model = ResNet(block, layers, **kwargs)
353 | if pretrained:
354 | state_dict = load_state_dict_from_url(MODEL_URLS[arch], progress=progress)
355 | model.load_state_dict(state_dict)
356 | # Remove the last fc layer, since we only need the encoder part of resnet.
357 | model.fc = nn.Identity()
358 | return model
359 |
360 |
361 | def resnet_3blocks(pretrained: bool = False, progress: bool = True, **kwargs) -> ResNet:
362 | """ResNet-18 model from
363 | `"Deep Residual Learning for Image Recognition" `
364 |
365 | Args:
366 | pretrained: If True, returns a model pre-trained on ImageNet
367 | progress: If True, displays a progress bar of the download to stderr
368 |
369 | Returns:
370 |         ResNet: three-block ResNet model
371 | """
372 |
373 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)
--------------------------------------------------------------------------------
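
The three-block variant behaves the same way; by default it returns a pooled 512-dimensional vector, since `layer3` is widened to 512 channels before the adaptive average pooling. A minimal sketch:

```python
import torch

from nett.brain.encoders.disembodied_models.archs.resnet_3b import resnet_3blocks

model = resnet_3blocks(pretrained=False)
model.eval()

with torch.no_grad():
    feats = model(torch.randn(2, 3, 64, 64))

# adaptive average pooling collapses the spatial dimensions, so every
# image becomes a single 512-dimensional feature vector
print(feats.shape)  # torch.Size([2, 512])
```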