├── src └── nest │ ├── __main__.py │ ├── __init__.py │ ├── logger.py │ ├── settings.py │ ├── parser.py │ ├── utils.py │ ├── cli.py │ └── modules.py ├── .gitignore ├── setup.py ├── LICENSE └── README.md /src/nest/__main__.py: -------------------------------------------------------------------------------- 1 | from nest.cli import CLI 2 | 3 | def main(): 4 | CLI() 5 | 6 | 7 | if __name__ == '__main__': 8 | main() 9 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled python modules. 2 | *.pyc 3 | 4 | # Setuptools distribution folder. 5 | /dist/ 6 | 7 | # Python egg metadata, regenerated from source files by setuptools. 8 | /**/*.egg-info 9 | 10 | # vscode 11 | .vscode -------------------------------------------------------------------------------- /src/nest/__init__.py: -------------------------------------------------------------------------------- 1 | from nest.parser import run_tasks 2 | from nest.modules import Context, ModuleManager, module_manager 3 | 4 | 5 | # alias 6 | modules = module_manager 7 | register = ModuleManager._register 8 | 9 | __all__ = ['Context', 'modules', 'register', 'run_tasks'] 10 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from setuptools import setup, find_packages 3 | 4 | 5 | # check python version 6 | if sys.version_info < (3, 5, 4): 7 | sys.exit('Python < 3.5.4 is not supported.') 8 | 9 | setup( 10 | name='nest', 11 | version='0.1.1', 12 | description='Nest - A flexible tool for building and sharing deep learning modules', 13 | url='https://github.com/ZhouYanzhao/Nest', 14 | author='Zhou, Yanzhao', 15 | author_email='yzhou.work@outlook.com', 16 | license='MIT', 17 | packages=find_packages('src'), 18 | package_dir={'': 'src'}, 19 | install_requires=[ 20 
| 'PyYAML', 21 | 'python-dateutil' 22 | ], 23 | entry_points={ 24 | 'console_scripts': ['nest=nest.__main__:main'], 25 | }, 26 | ) 27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018-present Yanzhao Zhou 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/nest/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import logging 4 | import warnings 5 | from typing import Callable 6 | 7 | from nest.settings import settings 8 | 9 | 10 | class ExceptionFilter(logging.Filter): 11 | """Avoid showing exceptions twice. 
12 | """ 13 | 14 | def filter(self, record: object) -> bool: 15 | return record.levelno != logging.ERROR 16 | 17 | 18 | def setup_logger() -> logging.RootLogger: 19 | """Initialize logger. 20 | 21 | Returns: 22 | The global logger 23 | """ 24 | 25 | # set up logger 26 | logger = logging.getLogger('Nest') 27 | logger.setLevel(logging.DEBUG) 28 | # create a formatter and add it to the handlers 29 | screen_formatter = logging.Formatter('%(message)s') 30 | # create a console handler 31 | screen_handler = logging.StreamHandler() 32 | screen_handler.setLevel(logging.INFO) 33 | screen_handler.setFormatter(screen_formatter) 34 | screen_handler.addFilter(ExceptionFilter()) 35 | logger.addHandler(screen_handler) 36 | if settings['LOGGING_TO_FILE']: 37 | # create a file handler which logs warning and error messages 38 | file_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') 39 | file_handler = logging.FileHandler(settings['LOGGING_PATH'], encoding='utf8') 40 | file_handler.setLevel(logging.WARNING) 41 | file_handler.setFormatter(file_formatter) 42 | logger.addHandler(file_handler) 43 | 44 | return logger 45 | 46 | 47 | def exception(func: Callable) -> Callable: 48 | """Decorator for logging errors and warnings of function. 
49 | 50 | Parameters: 51 | func: 52 | The decorated function 53 | """ 54 | 55 | def wrapper(*args, **kwargs): 56 | try: 57 | with warnings.catch_warnings(record=True) as warning_list: 58 | warnings.simplefilter('always') 59 | res = func(*args, **kwargs) 60 | for w in warning_list: 61 | logger.warning(w.message) 62 | return res 63 | except Exception as exc_info: 64 | logger.exception(exc_info) 65 | raise 66 | 67 | return wrapper 68 | 69 | 70 | # create global logger 71 | logger = setup_logger() 72 | -------------------------------------------------------------------------------- /src/nest/settings.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | from typing import Union, Dict, Any 4 | 5 | 6 | SETTINGS_DIR = os.path.join(str(os.path.expanduser('~')), '.nest') 7 | TEMPLATE_FILE = os.path.join(SETTINGS_DIR, 'template.yml') 8 | SETTINGS_FILE = os.path.join(SETTINGS_DIR, 'settings.yml') 9 | 10 | DEFAULT_SETTINGS = """\ 11 | # Nest - A flexible tool for building and sharing deep learning modules 12 | # Settings Template (User custom settings should be placed in "settings.yml") 13 | 14 | # Enable logging to file 15 | LOGGING_TO_FILE: true 16 | 17 | # Logging path (default: /.nest/nest.log) 18 | LOGGING_PATH: null 19 | 20 | # User defined search paths {namespace: path} for Nest module auto-discover (default: {}) 21 | SEARCH_PATHS: null 22 | 23 | # Seperator between namespace and Nest module name 24 | NAMESPACE_SEP: '.' 25 | 26 | # Put namespace behind module name if set to true 27 | NAMESPACE_ORDER_REVERSE: false 28 | 29 | # Module manager update interval (seconds) 30 | UPDATE_INTERVAL: 1.5 31 | 32 | # Varaible prefix in Nest config syntax 33 | VARIABLE_PREFIX: '@' 34 | 35 | # Enable strict mode for config parser 36 | # User must specify 'delay_resolve' for Nest modules 37 | # in the config file if set to true. 
38 | PARSER_STRICT: false 39 | 40 | # Namespace config file name 41 | NAMESPACE_CONFIG_FILENAME: 'nest.yml' 42 | 43 | # Automatically install requirements when install Nest modules 44 | AUTO_INSTALL_REQUIREMENTS: false 45 | 46 | # Threshold of missing dependency matching 47 | INSTALL_TIP_THRESHOLD: 0.15 48 | 49 | # Internel debug flags 50 | # Raises errors instead of warnings 51 | RAISES_ERROR: false 52 | """ 53 | 54 | 55 | class SettingManager(object): 56 | 57 | @staticmethod 58 | def save_settings(path: str, settings: Union[str, Dict[str, Any]]) -> None: 59 | """Save Nest settings. 60 | 61 | Parameters: 62 | path: 63 | Path to the setting file 64 | settings: 65 | The settings dict or string 66 | """ 67 | 68 | with open(path, 'w') as f: 69 | if isinstance(settings, str): 70 | f.write(settings) 71 | elif isinstance(settings, dict): 72 | yaml.dump(settings, f, default_flow_style=False) 73 | else: 74 | raise TypeError('The settings should have a type of "str" or "dict".') 75 | 76 | @staticmethod 77 | def load_settings() -> Dict[str, Any]: 78 | """Load Nest settings. 
79 | 80 | Returns: 81 | The settings dict 82 | """ 83 | 84 | # create if not exists 85 | if not os.path.exists(SETTINGS_DIR): 86 | os.mkdir(SETTINGS_DIR) 87 | if not os.path.exists(TEMPLATE_FILE): 88 | SettingManager.save_settings(TEMPLATE_FILE, DEFAULT_SETTINGS) 89 | if not os.path.exists(SETTINGS_FILE): 90 | SettingManager.save_settings(SETTINGS_FILE, '# User custom settings') 91 | 92 | # load settings 93 | settings = yaml.load(DEFAULT_SETTINGS) 94 | with open(SETTINGS_FILE, 'r') as f: 95 | user_settings = yaml.load(f) or dict() 96 | settings.update(user_settings) 97 | 98 | # handle defaults 99 | if settings['LOGGING_PATH'] is None: 100 | settings['LOGGING_PATH'] = os.path.join(SETTINGS_DIR, 'nest.log') 101 | if settings['SEARCH_PATHS'] is None: 102 | settings['SEARCH_PATHS'] = dict() 103 | 104 | return settings, user_settings 105 | 106 | def __init__(self): 107 | self.load() 108 | 109 | def __getitem__(self, key: str): 110 | return self.settings[key] 111 | 112 | def __setitem__(self, key: str, val: str): 113 | self.user_settings[key] = val 114 | 115 | def __contains__(self, key): 116 | return key in self.settings.keys() 117 | 118 | def load(self): 119 | self.settings, self.user_settings = SettingManager.load_settings() 120 | 121 | def save(self): 122 | SettingManager.save_settings(SETTINGS_FILE, self.user_settings) 123 | 124 | # load global settings 125 | try: 126 | settings = SettingManager() 127 | except Exception as exc_info: 128 | raise RuntimeError('Unable to load global settings of Nest. 
%s' % exc_info) 129 | -------------------------------------------------------------------------------- /src/nest/parser.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | from typing import Any, Dict, Union, Optional 4 | from datetime import datetime 5 | from copy import deepcopy 6 | 7 | import nest.utils as U 8 | from nest.modules import module_manager 9 | from nest.settings import settings 10 | from nest.logger import logger 11 | 12 | 13 | def parse_config( 14 | config: Union[list, dict], 15 | env_vars: Dict[str, str] = dict(), 16 | global_vars: Dict[str, str] = dict()) -> Union[list, dict]: 17 | """Parse experiment config. 18 | 19 | Parameters: 20 | config: 21 | The configuration of Nest modules, which specifies initial parameters, topologies, etc. 22 | env_vars: 23 | The environment variables 24 | global_vars: 25 | The global variables 26 | 27 | Returns: 28 | The resolved config 29 | """ 30 | 31 | def is_variable(name: str) -> bool: 32 | return isinstance(name, str) and name.startswith(settings['VARIABLE_PREFIX']) 33 | 34 | def resolve_variable(name: str) -> Any: 35 | if name[1:] in global_vars.keys(): 36 | return global_vars[name[1:]] 37 | elif name[1:] in env_vars.keys(): 38 | return env_vars[name[1:]] 39 | else: 40 | raise TypeError('Could not resolve variable "%s".' 
% name) 41 | 42 | if isinstance(config, list): 43 | for idx, val in enumerate(config): 44 | if is_variable(val): 45 | config[idx] = resolve_variable(val) 46 | elif isinstance(val, dict): 47 | config[idx] = parse_config(val, env_vars=env_vars, global_vars=global_vars) 48 | elif isinstance(config, dict): 49 | for key, val in config.items(): 50 | if is_variable(val): 51 | config[key] = resolve_variable(val) 52 | elif isinstance(val, list): 53 | for sub_idx, sub_val in enumerate(val): 54 | if is_variable(sub_val): 55 | val[sub_idx] = resolve_variable(sub_val) 56 | elif isinstance(sub_val, dict): 57 | val[sub_idx] = parse_config(sub_val, env_vars=env_vars, global_vars=global_vars) 58 | elif isinstance(val, dict): 59 | config[key] = parse_config( 60 | val, env_vars=env_vars, global_vars=global_vars) 61 | if key == '_var': 62 | U.merge_dict(global_vars, config[key], union=True) 63 | 64 | nest_module_name = config.pop('_name', None) 65 | if nest_module_name: 66 | nest_module = module_manager[nest_module_name] 67 | if settings['PARSER_STRICT']: 68 | return nest_module(**config) 69 | else: 70 | return nest_module(**config, delay_resolve=True) 71 | 72 | return config 73 | 74 | 75 | def run_tasks( 76 | config_file: str, 77 | param_file: Optional[str] = None, 78 | verbose: bool = False) -> None: 79 | """Run experiment tasks by resolving config. 
80 | 81 | Parameters: 82 | config_file: 83 | The path to the config file 84 | param_file: 85 | The path to the parameter file 86 | verbose: 87 | Show verbose information 88 | """ 89 | 90 | # helper function 91 | def check_all_resolved(resolved_config: Any) -> None: 92 | if isinstance(resolved_config, list): 93 | for v in resolved_config: 94 | check_all_resolved(v) 95 | elif isinstance(resolved_config, dict): 96 | for v in resolved_config.values(): 97 | check_all_resolved(v) 98 | elif type(resolved_config).__name__ == 'NestModule': 99 | raise RuntimeError('Unresolved Nest module found in the result.\n%s' % ( 100 | U.indent_text(str(resolved_config), 4))) 101 | 102 | # start resolving config 103 | try: 104 | start_time = datetime.now() 105 | # load config file 106 | config, raw = U.load_yaml(config_file) 107 | # load environment variables 108 | env_vars = {k: v for k, v in os.environ.items()} 109 | # record raw config 110 | env_vars['CONFIG'] = re.sub(r'\{(.*?)\}', r'{{\1}}', raw).replace('\\', '\\\\') 111 | env_vars['PARAMS'] = '' 112 | 113 | if param_file is not None: 114 | # initial global variables 115 | global_vars = dict() 116 | # iterate over params 117 | param_list, _ = U.load_yaml(param_file) 118 | if not isinstance(param_list, list): 119 | param_list = [param_list] 120 | for idx, param in enumerate(param_list): 121 | param_start_time = datetime.now() 122 | if isinstance(param, dict): 123 | U.merge_dict(global_vars, param, union=True) 124 | else: 125 | raise TypeError('Parameter file should define a list of Dict[str, Any]. Got "%s" in it.' 
% param) 126 | # record parameters 127 | env_vars['PARAMS'] = U.yaml_format(global_vars) 128 | if verbose: 129 | logger.info('(%d/%d) Resolving with parameters: \n' % (idx + 1, len(param_list)) + env_vars['PARAMS']) 130 | # parse config with updated vars 131 | resolved_config = parse_config(deepcopy(config), env_vars=env_vars, global_vars=global_vars) 132 | check_all_resolved(resolved_config) 133 | if verbose: 134 | end_time = datetime.now() 135 | logger.info('Finished (%s).' % (U.format_elapse(seconds=(end_time - param_start_time).total_seconds()))) 136 | else: 137 | resolved_config = parse_config(config, env_vars=env_vars) 138 | check_all_resolved(resolved_config) 139 | 140 | end_time = datetime.now() 141 | logger.info('All finished. (%s)' % U.format_elapse(seconds=(end_time - start_time).total_seconds())) 142 | 143 | except KeyboardInterrupt: 144 | logger.info('Processing is canceled by user.') 145 | -------------------------------------------------------------------------------- /src/nest/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import string 4 | import inspect 5 | import collections 6 | import warnings 7 | from typing import List, Set, Dict, Tuple, Callable, Any, Union, Iterable, Iterator 8 | 9 | import yaml 10 | from dateutil.relativedelta import relativedelta 11 | 12 | from nest.logger import exception 13 | from nest.settings import settings 14 | 15 | 16 | # helper functions 17 | def yaml_format(obj: object) -> str: 18 | """Format dict using YAML. 19 | 20 | Parameters: 21 | obj: 22 | The input dict object 23 | 24 | Returns: 25 | formated string 26 | """ 27 | 28 | return yaml.dump(obj, default_flow_style=False) 29 | 30 | 31 | def format_elapse(**elapse) -> str: 32 | """Format an elapse to human readable string. 
33 | 34 | Parameters: 35 | elapse: 36 | The elapse dict, e.g., dict(seconds=10) 37 | 38 | Returns: 39 | Human readable string 40 | """ 41 | 42 | attrs = ['years', 'months', 'days', 'hours', 'minutes', 'seconds'] 43 | elapse = relativedelta(**elapse) 44 | return ', '.join(['%d %s' % (getattr(elapse, attr), getattr(elapse, attr) > 1 and attr or attr[:-1]) for attr in attrs if getattr(elapse, attr)]) 45 | 46 | 47 | def load_yaml(path: str) -> Tuple[dict, str]: 48 | """Load yaml file. 49 | 50 | Parameters: 51 | path: 52 | The path to the file 53 | 54 | Returns: 55 | The dict 56 | Raw string 57 | """ 58 | 59 | with open(path, 'r') as f: 60 | raw = ''.join(f.readlines()) 61 | return yaml.load(raw), raw 62 | 63 | 64 | def indent_text(text: str, indent: int) -> str: 65 | """Indent multi-line text. 66 | 67 | Parameters: 68 | text: 69 | The input text 70 | indent: 71 | Number of spaces 72 | """ 73 | 74 | return '\n'.join([' ' * indent + v for v in text.split('\n')]) 75 | 76 | 77 | def encode_id(namespace: str, key: str) -> str: 78 | """Encode unique id from namespace and key. 79 | 80 | Parameters: 81 | namespace: 82 | The given namespace 83 | key: 84 | The given key 85 | 86 | Returns: 87 | a unique id 88 | """ 89 | 90 | if settings['NAMESPACE_ORDER_REVERSE']: 91 | return key + settings['NAMESPACE_SEP'] + namespace 92 | else: 93 | return namespace + settings['NAMESPACE_SEP'] + key 94 | 95 | 96 | def decode_id(unique_id: str) -> Tuple[str, str]: 97 | """Decode unique id to namespace and key. 98 | 99 | Parameters: 100 | unique_id: 101 | The given unique id 102 | 103 | Returns: 104 | namespace and key 105 | """ 106 | 107 | if settings['NAMESPACE_ORDER_REVERSE']: 108 | return unique_id.split(settings['NAMESPACE_SEP'])[::-1] 109 | else: 110 | return unique_id.split(settings['NAMESPACE_SEP']) 111 | 112 | 113 | @exception 114 | def alert_msg(msg: str) -> None: 115 | """Show alert message. 
116 | 117 | Parameters: 118 | msg: 119 | The message 120 | """ 121 | 122 | if settings['RAISES_ERROR']: 123 | raise RuntimeError(msg) 124 | else: 125 | warnings.warn(msg) 126 | 127 | 128 | def merge_dict( 129 | src: dict, 130 | diff: dict, 131 | union: bool = False, 132 | _path: List[str] = None) -> dict: 133 | """Recursively merges two dicts. 134 | 135 | Parameters: 136 | src: 137 | The source dict (will be modified) 138 | diff: 139 | Differences to be merged 140 | union: 141 | Whether to keep all keys 142 | _path: 143 | The internal falg that should not be used by the user 144 | 145 | Returns: 146 | The merged dict 147 | """ 148 | 149 | if _path is None: 150 | _path = [] 151 | for key in diff: 152 | if key in src: 153 | if isinstance(src[key], dict) and isinstance(diff[key], dict): 154 | merge_dict(src[key], diff[key], _path+[str(key)]) 155 | else: 156 | src[key] = diff[key] 157 | elif union: 158 | src[key] = diff[key] 159 | return src 160 | 161 | 162 | @exception 163 | def is_annotation_matched(var: object, annotation: object) -> bool: 164 | """Return True if annotation is matched with the given variable. 165 | {Any, List, Set, Tuple, Dict, Union, Callable, Iterable, Iterator} from "typing" are supported. 166 | 167 | Parameters: 168 | var: 169 | The variable 170 | annotation: 171 | The annotation 172 | 173 | Returns: 174 | True if matched, otherwise False. 
175 | """ 176 | 177 | var_type = type(var) 178 | anno_str = str(annotation).split('[')[0] 179 | 180 | if var is None and annotation is None: 181 | return True 182 | elif type(annotation) == type: 183 | return issubclass(var_type, annotation) 184 | elif anno_str.startswith('typing.'): 185 | anno_type = anno_str[7:] 186 | if anno_type == 'Any': 187 | return True 188 | elif anno_type == 'List': 189 | sub_annotation = annotation.__args__ 190 | if var_type == list: 191 | if sub_annotation is None: 192 | return True 193 | else: 194 | return all(map(lambda x: is_annotation_matched(x, sub_annotation[0]), var)) 195 | else: 196 | return False 197 | elif anno_type == 'Set': 198 | sub_annotation = annotation.__args__ 199 | if var_type == set: 200 | if sub_annotation is None: 201 | return True 202 | else: 203 | return all(map(lambda x: is_annotation_matched(x, sub_annotation[0]), var)) 204 | else: 205 | return False 206 | elif anno_type == 'Iterable': 207 | # currently we can't check the type of items 208 | return issubclass(var_type, collections.abc.Iterable) 209 | elif anno_type == 'Iterator': 210 | # currently we can't check the type of items 211 | return issubclass(var_type, collections.abc.Iterator) 212 | elif anno_type == 'Tuple': 213 | sub_annotation = annotation.__args__ 214 | if var_type == tuple: 215 | if sub_annotation is None: 216 | return True 217 | if len(sub_annotation) != len(var): 218 | return False 219 | else: 220 | return all(map(lambda x, y: is_annotation_matched(x, y), var, sub_annotation)) 221 | else: 222 | return False 223 | elif anno_type == 'Dict': 224 | sub_annotation = annotation.__args__ 225 | if var_type == dict: 226 | if sub_annotation is None: 227 | return True 228 | else: 229 | key_anno, val_anno = sub_annotation 230 | return all(map( 231 | lambda x: is_annotation_matched(x[0], key_anno) and \ 232 | is_annotation_matched(x[1], val_anno), var.items())) 233 | else: 234 | return False 235 | elif anno_type == 'Union': 236 | sub_annotation = 
annotation.__args__ 237 | if sub_annotation is None: 238 | return False 239 | else: 240 | return any(map(lambda y: is_annotation_matched(var, y), sub_annotation)) 241 | elif anno_type == 'Callable': 242 | sub_annotation = annotation.__args__ 243 | if callable(var): 244 | if sub_annotation is None: 245 | return True 246 | else: 247 | if type(var).__name__ == 'NestModule': 248 | # Nest module 249 | # filter out resolved / optional params 250 | sig = var.sig 251 | func_annos = [v.annotation for k, v in sig.parameters.items() \ 252 | if not k in var.params.keys() and v.default is inspect.Parameter.empty] 253 | else: 254 | # regular callable object 255 | sig = inspect.signature(var) 256 | func_annos = [v.annotation for v in sig.parameters.values()] 257 | if len(func_annos) == len(sub_annotation) - 1: 258 | return all(map(lambda x, y: x == y, func_annos, sub_annotation)) and \ 259 | (sub_annotation[-1] == Any or \ 260 | sub_annotation[-1] == object or \ 261 | sig.return_annotation == sub_annotation[-1] or 262 | (sig.return_annotation is None and sub_annotation[-1] == type(None))) 263 | else: 264 | return False 265 | else: 266 | return False 267 | 268 | raise NotImplementedError('The annotation type %s is not supported' % inspect.formatannotation(annotation)) 269 | -------------------------------------------------------------------------------- /src/nest/cli.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import traceback 4 | import logging 5 | import argparse 6 | from glob import glob 7 | from shutil import rmtree 8 | from typing import Any, Dict 9 | 10 | from nest import utils as U 11 | from nest.logger import logger 12 | from nest.modules import module_manager 13 | from nest.parser import run_tasks 14 | from nest.settings import settings, SETTINGS_DIR, SETTINGS_FILE 15 | 16 | 17 | ROOT_HELP = """Nest - A flexible tool for building and sharing deep learning modules 18 | usage: nest [] 19 | 20 | The 
currently available commands are 21 | {} 22 | 23 | """ 24 | 25 | class Parser(argparse.ArgumentParser): 26 | """Custom parser for multi-level argparse. 27 | """ 28 | 29 | def format_help(self) -> str: 30 | """Allow to use custom help. 31 | """ 32 | 33 | if self.add_help: 34 | return super(Parser, self).format_help() 35 | else: 36 | return self.usage 37 | 38 | def error(self, message) -> None: 39 | """Show help when parse failed. 40 | """ 41 | 42 | if self.add_help: 43 | logger.error('error: {}\n'.format(message)) 44 | self.print_help() 45 | sys.exit(2) 46 | 47 | def add_argument(self, *args, **kwargs) -> None: 48 | """Hide metavar by default. 49 | """ 50 | 51 | keys = kwargs.keys() 52 | if ('metavar' not in keys) and ('action' not in keys): 53 | kwargs['metavar'] = '' 54 | super(Parser, self).add_argument(*args, **kwargs) 55 | 56 | 57 | class CLI(object): 58 | """Command line interface. 59 | """ 60 | 61 | def __init__(self) -> None: 62 | # find available commands 63 | commands = [att for att in dir(self) if att.startswith('cmd_')][::-1] 64 | commands_help = '\n'.join( 65 | [getattr(self, cmd).__doc__.strip() for cmd in commands]) 66 | # root parser 67 | parser = Parser(usage=ROOT_HELP.format(commands_help), add_help=False) 68 | parser.add_argument('command', choices=[cmd[4:] for cmd in commands]) 69 | args = parser.parse_args(sys.argv[1:2]) 70 | # dispatch 71 | getattr(self, 'cmd_' + args.command)('nest ' + args.command, sys.argv[2:]) 72 | 73 | def hook_exceptions(self, logger: logging.RootLogger) -> None: 74 | """Format excetion traceback. 75 | 76 | Parameters: 77 | logger: 78 | The logger for logging exceptions. 
79 | """ 80 | 81 | def _hook(exc_type, value, exc_tb) -> None: 82 | nest_dir = os.path.dirname(os.path.abspath(__file__)) 83 | traceback_str = '' 84 | idx = 0 85 | for file_name, line_number, func_name, text in traceback.extract_tb(exc_tb)[1:]: 86 | # skip Nest-related tracebacks to make it more readable 87 | if os.path.dirname(os.path.abspath(file_name)) == nest_dir: 88 | continue 89 | idx += 1 90 | traceback_str += '\n [%d] File "%s", line %d, in function "%s"\n %s' % \ 91 | (idx, file_name, line_number, func_name, text) 92 | if traceback_str != '': 93 | traceback_str = 'Traceback: ' + traceback_str 94 | logger.critical('Exception occurred during resolving:\nType: %s\nMessage: %s\n%s' % \ 95 | (exc_type.__name__, value, traceback_str)) 96 | 97 | sys.excepthook = _hook 98 | 99 | def cmd_task(self, prog: str, arguments: str) -> None: 100 | """task Task runner. 101 | """ 102 | 103 | parser = Parser(prog=prog) 104 | subparsers = parser.add_subparsers(metavar="", dest='command') 105 | 106 | # excute tasks 107 | parser_run = subparsers.add_parser('run', help='Execute tasks.') 108 | parser_run.add_argument('config', metavar='CONFIG', nargs='?', default='config.yml', 109 | help='Path to the config file (default: config.yml).') 110 | parser_run.add_argument('-p', '--param', default=None, 111 | help='Path to the parameter file that can be used for hyper-params tuning.') 112 | parser_run.add_argument('-v', '--verbose', action='store_true', help='Show verbose information.') 113 | args = parser.parse_args(arguments) 114 | 115 | # exception formatter 116 | self.hook_exceptions(logger) 117 | 118 | if args.command == 'run': 119 | run_tasks(args.config, args.param, args.verbose) 120 | else: 121 | parser.print_help() 122 | 123 | def cmd_module(self, prog: str, arguments: str) -> None: 124 | """module Nest module manager. 
125 | """ 126 | 127 | parser = Parser(prog=prog) 128 | subparsers = parser.add_subparsers(metavar="", dest='command') 129 | 130 | # show modules 131 | parser_list = subparsers.add_parser('list', help='Show module information.') 132 | parser_list.add_argument('-f', '--filter', help='Keyword for filtering module list.') 133 | parser_list.add_argument('-v', '--verbose', action='store_true', help='Show verbose information.') 134 | # install modules 135 | parser_install = subparsers.add_parser('install', help='Install modules.') 136 | parser_install.add_argument('src', metavar='SRC', help='URL or path of the modules.') 137 | parser_install.add_argument('namespace', metavar='NAMESPACE', nargs='?', 138 | help='The namespace for local path installation, use directory name if not specified.') 139 | parser_install.add_argument('-y', '--yes', action='store_true', help='Skip confirmation.') 140 | # remove modules 141 | parser_remove = subparsers.add_parser('remove', help='Remove modules.') 142 | parser_remove.add_argument('src', metavar='SRC', help='Namespace or path.') 143 | parser_remove.add_argument('-d', '--delete', action='store_true', help='Delete the namespace folder.') 144 | parser_remove.add_argument('-y', '--yes', action='store_true', help='Skip confirmation.') 145 | # pack modules 146 | parser_pack = subparsers.add_parser('pack', help='Pack modules.') 147 | parser_pack.add_argument('path', metavar='PATH', nargs='+', help='Path to namespaces.') 148 | parser_pack.add_argument('-s', '--save', default='./nest_modules.zip', help='Save path (default: ./nest_modules.zip).') 149 | parser_pack.add_argument('-y', '--yes', action='store_true', help='Skip confirmation.') 150 | # check modules 151 | parser_check = subparsers.add_parser('check', help='Check modules.') 152 | parser_check.add_argument('src', metavar='SRC', nargs='*', 153 | help='Path to the namespaces or python files (check all available modules if not specified).') 154 | args = parser.parse_args(arguments) 155 | 
156 | if args.command == 'list': 157 | # list available Nest modules 158 | if args.verbose: 159 | module_info = ['%s (%s) by "%s":\n%s' % \ 160 | (k, v.meta.get('version', 'version'), v.meta.get('author', 'author'), U.indent_text(str(v), 4)) for k, v in module_manager] 161 | else: 162 | module_info = ['%s (%s)' % (k, v.meta.get('version', 'version')) for k, v in module_manager] 163 | 164 | # filtering 165 | if args.filter: 166 | module_info = list(filter(lambda v: args.filter in v, module_info)) 167 | module_info = ['[%d] ' % idx + v for idx, v in enumerate(sorted(module_info))] 168 | 169 | num_module = len(module_info) 170 | if num_module > 1: 171 | logger.info('%d Nest modules found.\n' % num_module + '\n'.join(module_info)) 172 | elif num_module == 1: 173 | module_doc = U.indent_text(type(module_manager[module_info[0].split()[1]]).__doc__, 4) 174 | logger.info('1 Nest module found.\n' + module_info[0] + '\n\nDocumentation:\n' + module_doc) 175 | else: 176 | logger.info( 177 | 'No available Nest modules found. You can install build-in PyTorch modules by executing ' 178 | '"nest module install github@ZhouYanzhao/Nest:pytorch".') 179 | 180 | elif args.command == 'install': 181 | if os.path.isdir(args.src): 182 | # install Nest modules from path 183 | confirm = 'y' if args.yes else input('Install "%s" -> Search paths. Continue? (Y/n)' % (args.src,)).lower() 184 | if confirm == '' or confirm == 'y': 185 | module_manager._install_namespaces_from_path(args.src, args.namespace) 186 | else: 187 | # install Nest modules from url 188 | confirm = 'y' if args.yes else input('Install "%s" --> "%s". Continue? (Y/n)' % (args.src, './')).lower() 189 | if confirm == '' or confirm == 'y': 190 | module_manager._install_namespaces_from_url(args.src, args.namespace) 191 | 192 | elif args.command == 'remove': 193 | # remove Nest modules from paths 194 | confirm = 'y' if args.yes else input('Remove "%s" from paths. Continue? 
(Y/n)' % (args.src,)).lower() 195 | if confirm == '' or confirm == 'y': 196 | path = module_manager._remove_namespaces_from_path(args.src) 197 | if args.delete and path is not None and os.path.isdir(path): 198 | del_confirm = 'y' if args.yes else input('Delete the namespace directory "%s". Continue? (Y/n)' % (path,)).lower() 199 | if del_confirm == '' or del_confirm == 'y': 200 | # error handler 201 | def onerror(func, path, exc_info): 202 | import stat 203 | if not os.access(path, os.W_OK): 204 | os.chmod(path, stat.S_IWUSR) 205 | func(path) 206 | else: 207 | logger.warning('Failed to delete the namespace directory "%s".' % path) 208 | rmtree(path, onerror=onerror) 209 | 210 | elif args.command == 'pack': 211 | # pack Nest modules to a zip file 212 | confirm = 'y' if args.yes else input('Pack "%s" --> "%s". Continue? (Y/n)' % (','.join(args.path), args.save)).lower() 213 | if confirm == '' or confirm == 'y': 214 | save_list = module_manager._pack_namespaces(args.path, args.save) 215 | logger.info('Packed list: \n%s', U.indent_text(U.yaml_format(save_list), 4)) 216 | 217 | elif args.command == 'check': 218 | if len(args.src) == 0: 219 | logger.info('Checking all available modules') 220 | # check all 221 | module_manager._update_modules() 222 | else: 223 | for idx, path in enumerate(args.src): 224 | logger.info('[%d/%d] Checking "%s"' % (idx + 1, len(args.src), path)) 225 | if os.path.isfile(path): 226 | module_manager._import_nest_modules_from_file(path, 'nest_check', dict(), dict()) 227 | elif os.path.isdir(path): 228 | module_manager._import_nest_modules_from_dir(path, 'nest_check', dict(), dict()) 229 | else: 230 | logger.warning('Skipped as it does not exist.') 231 | logger.info('Done.') 232 | 233 | else: 234 | parser.print_help() 235 | 236 | def cmd_setting(self, prog: str, arguments: str) -> None: 237 | """setting Settings configuration. 
238 | """ 239 | 240 | parser = Parser(prog=prog) 241 | subparsers = parser.add_subparsers(metavar="", dest='command') 242 | 243 | # display settings 244 | parser_show = subparsers.add_parser('show', help='Show settings.') 245 | parser_show.add_argument('-d', '--directory', action='store_true', help='Locate setting file in directory.') 246 | parser_show.add_argument('-e', '--editor', action='store_true', help='Show settings in external editor.') 247 | # define settings 248 | parser_set = subparsers.add_parser('set', help='Modify settings.') 249 | parser_set.add_argument('key', metavar='KEY', help='The key.') 250 | parser_set.add_argument('val', metavar='VAL', help='The value (python expression).') 251 | args = parser.parse_args(arguments) 252 | 253 | # exception formatter 254 | self.hook_exceptions(logger) 255 | 256 | if args.command == 'show': 257 | import webbrowser 258 | if args.directory: 259 | webbrowser.open(SETTINGS_DIR) 260 | elif args.editor: 261 | webbrowser.open(SETTINGS_FILE) 262 | else: 263 | logger.info(U.yaml_format(settings.settings)) 264 | elif args.command == 'set': 265 | if args.key in settings: 266 | settings[args.key] = eval(args.val) 267 | settings.save() 268 | else: 269 | raise KeyError('Invalid setting key "%s". Use "nest setting show" to check the supported settings.' % args.key) 270 | else: 271 | parser.print_help() 272 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

Nest - A flexible tool for building and sharing deep learning modules

2 | 3 | [![](https://img.shields.io/badge/core-0.1-blue.svg)](https://github.com/ZhouYanzhao/Nest) 4 | [![](https://img.shields.io/badge/pytorch-0.1-red.svg)](https://github.com/ZhouYanzhao/Nest/tree/pytorch) 5 | [![](https://img.shields.io/badge/mxnet-scheduled-green.svg)](#) 6 | [![](https://img.shields.io/badge/tensorflow-scheduled-green.svg)](#) 7 | 8 | Nest is a flexible deep learning module manager that aims to encourage code reuse and sharing. It ships with a number of useful features, such as CLI-based module management, runtime type checking, and an experimental task runner. You can integrate Nest with PyTorch, TensorFlow, MXNet, or any other deep learning framework that provides a Python interface. 9 | 10 | Moreover, a set of [PyTorch-backend Nest modules](https://github.com/ZhouYanzhao/Nest/tree/pytorch), e.g., network trainer, data loader, optimizer, dataset, and visdom logging, is already provided. More modules and framework support will be added later. 11 | 12 | --- 13 | 14 | - [Prerequisites](#prerequisites) 15 | - [Installation](#installation) 16 | - [Basic Usage](#basic-usage) 17 | - [Create your first Nest module](#create-your-first-nest-module) 18 | - [Use your Nest module in Python](#use-your-nest-module-in-python) 19 | - [Debug your Nest modules](#debug-your-nest-modules) 20 | - [Install Nest modules from local path](#install-nest-modules-from-local-path) 21 | - [Install Nest modules from URL](#install-nest-modules-from-url) 22 | - [Uninstall Nest modules](#uninstall-nest-modules) 23 | - [Version control Nest modules](#version-control-nest-modules) 24 | - [Use Nest to manage your experiments](#use-nest-to-manage-your-experiments) 25 | - [Contact](#contact) 26 | - [Issues](#issues) 27 | - [Contribution](#contribution) 28 | - [License](#license) 29 | 30 | ## Prerequisites 31 | * System (tested on Ubuntu 14.04 LTS, Windows 10, and macOS *High Sierra*) 32 | * [Python](https://www.python.org) >= 3.5.4 33 | * [Git](https://git-scm.com) 34 | 35 | ## 
Installation 36 | ```bash 37 | # directly install via pip 38 | pip install git+https://github.com/ZhouYanzhao/Nest.git 39 | 40 | # manually download and install 41 | git clone https://github.com/ZhouYanzhao/Nest.git 42 | pip install ./Nest 43 | ``` 44 | 45 | ## Basic Usage 46 | > The official website and documentation are under construction. 47 | 48 | ### Create your first Nest module 49 | 1. Create "hello.py" under your current path with the following content: 50 | 51 | ```python 52 | from nest import register 53 | 54 | @register(author='Yanzhao', version='1.0.0') 55 | def hello_nest(name: str) -> str: 56 | """My first Nest module!""" 57 | 58 | return 'Hello ' + name 59 | ``` 60 | 61 | > Note that the type of module parameters and return values must be clearly defined. This helps the user to better understand the module, and at runtime Nest automatically checks whether each module receives and outputs as expected, thus helping you to identify potential bugs earlier. 62 | 63 | 2. Execute the following command in your shell to verify the module: 64 | 65 | ```bash 66 | $ nest module list -v 67 | # Output: 68 | # 69 | # 1 Nest module found. 70 | # [0] main.hello_nest (1.0.0) by "Yanzhao": 71 | # hello_nest( 72 | # name:str) -> str 73 | 74 | # Documentation: 75 | # My first Nest module! 76 | # author: Yanzhao 77 | # module_path: /Users/yanzhao/Workspace/Nest.doc 78 | # version: 1.0.0 79 | ``` 80 | 81 | > Note that all modules under current path are registered under the "**main**" namespace. 82 | 83 | > With the CLI tool, you can easily manage Nest modules. Execute `nest -h` for more details. 84 | 85 | 3. That's it. You just created a simple Nest module! 86 | 87 | 88 | ### Use your Nest module in Python 89 | 1. 
Open an interactive Python interpreter in the same directory as "hello.py" and run the following commands: 90 | 91 | ```python 92 | >>> from nest import modules 93 | >>> print(modules.hello_nest) # access the module 94 | # Output: 95 | # 96 | # hello_nest( 97 | # name:str) -> str 98 | >>> print(modules['*_nes?']) # wildcard search 99 | # Output: 100 | # 101 | # hello_nest( 102 | # name:str) -> str 103 | >>> print(modules['r/main.\w+_nest']) # regex search 104 | # Output: 105 | # 106 | # hello_nest( 107 | # name:str) -> str 108 | >>> modules.hello_nest('Yanzhao') # use the module 109 | # Output: 110 | # 111 | # 'Hello Yanzhao' 112 | >>> modules.hello_nest(123) # runtime type checking 113 | # Output: 114 | # 115 | # TypeError: The param "name" of Nest module "hello_nest" should be type of "str". Got "123". 116 | >>> modules.hello_nest('Yanzhao', wrong=True) 117 | # Output: 118 | # 119 | # Unexpected param(s) "wrong" for Nest module: 120 | # hello_nest( 121 | # name:str) -> str 122 | ``` 123 | 124 | > Note that Nest automatically imports modules and checks them as they are used to make sure everything is as expected. 125 | 126 | 2. You can also directly import modules like this: 127 | 128 | ```python 129 | >>> from nest.main.hello import hello_nest 130 | >>> hello_nest('World') 131 | # Output: 132 | # 133 | # 'Hello World' 134 | ``` 135 | 136 | > The import syntax is `from nest.<namespace>.<python_module> import <nest_module>` 137 | 138 | 3. Accessing Nest modules from code is flexible and easy. 139 | 140 | ### Debug your Nest modules 141 | 1. Open an interactive Python interpreter in the same directory as "hello.py" and run the following commands: 142 | 143 | ```python 144 | >>> from nest import modules 145 | >>> modules.hello_nest('Yanzhao') 146 | # Output: 147 | # 148 | # 'Hello Yanzhao' 149 | ``` 150 | 151 | 2. 
Keep the interpreter **OPEN** and use an external editor to modify "hello.py": 152 | 153 | ```python 154 | # change Line7 from "return 'Hello ' + name" to 155 | return 'Nice to meet you, ' + name 156 | ``` 157 | 158 | 3. Go back to the interpreter and rerun the same command: 159 | 160 | ```python 161 | >>> modules.hello_nest('Yanzhao') 162 | # Output: 163 | # 164 | # 'Nice to meet you, Yanzhao' 165 | ``` 166 | 167 | > Note that Nest detects source file modifications and automatically reloads the module. 168 | 169 | 4. You can use this feature to develop and debug your Nest modules efficiently. 170 | 171 | ### Install Nest modules from local path 172 | 1. Create a folder `my_namespace` and move `hello.py` into it: 173 | 174 | ```bash 175 | $ mkdir my_namespace 176 | $ mv hello.py ./my_namespace/ 177 | ``` 178 | 179 | 2. Create a new file `more.py` under the folder `my_namespace` with the following content: 180 | 181 | ```python 182 | from nest import register 183 | 184 | @register(author='Yanzhao', version='1.0.0') 185 | def sum(a: int, b: int) -> int: 186 | """Sum two numbers.""" 187 | 188 | return a + b 189 | 190 | # There is no need to repeatedly declare meta information 191 | # as modules within the same file automatically reuse the 192 | # previous information. But overriding is also supported. 193 | @register(version='2.0.0') 194 | def mul(a: float, b: float) -> float: 195 | """Multiply two numbers.""" 196 | 197 | return a * b 198 | ``` 199 | 200 | > Now we have: 201 | ``` 202 | current path/ 203 | ├── my_namespace/ 204 | │ ├── hello.py 205 | │ ├── more.py 206 | ``` 207 | 208 | 3. Run the following command in the shell: 209 | 210 | ```bash 211 | $ nest module install ./my_namespace hello_world 212 | # Output: 213 | # 214 | # Install "./my_namespace/" -> Search paths. Continue? 
(Y/n) [Press <Enter>] 215 | ``` 216 | 217 | > This command will add the "**my_namespace**" folder to Nest's search path and register all Nest modules in it under the namespace "**hello_world**". If the last argument is omitted, the directory name, "my_namespace" in this case, is used as the namespace. 218 | 219 | 4. Verify the installation via CLI: 220 | 221 | ```bash 222 | $ nest module list 223 | # Output: 224 | # 225 | # 3 Nest modules found. 226 | # [0] hello_world.hello_nest (1.0.0) 227 | # [1] hello_world.mul (2.0.0) 228 | # [2] hello_world.sum (1.0.0) 229 | ``` 230 | 231 | > Note that those Nest modules can now be accessed regardless of your working path. 232 | 233 | 5. Verify the installation via the Python interpreter: 234 | 235 | ```bash 236 | $ ipython # open IPython interpreter 237 | ``` 238 | ```python 239 | >>> from nest import modules 240 | >>> print(len(modules)) 241 | # Output: 242 | # 243 | # 3 244 | >>> modules.[Press <Tab>] # IPython Auto-completion 245 | # Output: 246 | # 247 | # hello_nest 248 | # mul 249 | # sum 250 | >>> modules.sum(3, 2) 251 | # Output: 252 | # 253 | # 5 254 | >>> modules.mul(2.5, 4.0) 255 | # Output: 256 | # 257 | # 10.0 258 | ``` 259 | 260 | 6. Thanks to the auto-import feature of Nest, you can easily share modules between different local projects. 261 | 262 | ### Install Nest modules from URL 263 | 1. You can use the CLI tool to install modules from a URL: 264 | 265 | ```bash 266 | # select one of the following commands to execute 267 | # 0. install from GitHub repo via short URL (GitLab, Bitbucket are also supported) 268 | $ nest module install github@ZhouYanzhao/Nest:pytorch pytorch 269 | # 1. install from Git repo 270 | $ nest module install "-b pytorch https://github.com/ZhouYanzhao/Nest.git" pytorch 271 | # 2. 
install from zip file URL 272 | $ nest module install "https://github.com/ZhouYanzhao/Nest/archive/pytorch.zip" pytorch 273 | ``` 274 | 275 | > The last optional argument is used to specify the namespace, "**pytorch**" in this case. 276 | 277 | 2. Verify the installation: 278 | 279 | ```bash 280 | $ nest module list 281 | # Output: 282 | # 283 | # 26 Nest modules found. 284 | # [0] hello_world.hello_nest (1.0.0) 285 | # [1] hello_world.mul (2.0.0) 286 | # [2] hello_world.sum (1.0.0) 287 | # [3] pytorch.adadelta_optimizer (0.1.0) 288 | # [4] pytorch.checkpoint (0.1.0) 289 | # [5] pytorch.cross_entropy_loss (0.1.0) 290 | # [6] pytorch.fetch_data (0.1.0) 291 | # [7] pytorch.finetune (0.1.0) 292 | # [8] pytorch.image_transform (0.1.0) 293 | # ... 294 | ``` 295 | 296 | ### Uninstall Nest modules 297 | 1. You can remove modules from Nest's search path by executing: 298 | 299 | ```bash 300 | # given namespace 301 | $ nest module remove hello_world 302 | # given path to the namespace 303 | $ nest module remove ./my_namespace/ 304 | ``` 305 | 306 | 2. You can also delete the corresponding files by appending a `--delete` or `-d` flag: 307 | 308 | ```bash 309 | $ nest module remove hello_world --delete 310 | ``` 311 | 312 | ### Version control Nest modules 313 | 314 | 1. When installing modules, Nest adds the namespace to its search path without modifying or moving the original files. So you can use any version control system you like, e.g., Git, to manage modules. For example: 315 | 316 | ```bash 317 | $ cd <path/to/namespace> 318 | # update modules 319 | $ git pull 320 | # specify version 321 | $ git checkout v1.0 322 | ``` 323 | 324 | 2. When developing a Nest module, it is recommended to define meta information for the module, such as the author, version, requirements, etc. This information is used by Nest's CLI tool. 
There are two ways to set meta information: 325 | 326 | * define meta information in code 327 | 328 | ```python 329 | from nest import register 330 | 331 | @register(author='Yanzhao', version='1.0') 332 | def my_module() -> None: 333 | """My Module""" 334 | pass 335 | ``` 336 | 337 | * define meta information in a `nest.yml` under the path of namespace 338 | 339 | ```YAML 340 | author: Yanzhao 341 | version: 1.0 342 | requirements: 343 | - {url: opencv, tool: conda} 344 | # default tool is pip 345 | - torch>=0.4 346 | ``` 347 | 348 | > Note that you can use both ways at the same time. 349 | 350 | ### Use Nest to manage your experiments 351 | 1. Make sure you have Pytorch-backend modules installed, and if not, execute the following command: 352 | 353 | ```bash 354 | $ nest module install github@ZhouYanzhao/Nest:pytorch pytorch 355 | ``` 356 | 357 | 2. Create "**train_mnist.yml**" with the following content: 358 | 359 | ```YAML 360 | _name: network_trainer 361 | data_loaders: 362 | _name: fetch_data 363 | dataset: 364 | _name: mnist 365 | data_dir: ./data 366 | batch_size: 128 367 | num_workers: 4 368 | transform: 369 | _name: image_transform 370 | image_size: 28 371 | mean: [0.1307] 372 | std: [0.3081] 373 | train_splits: [train] 374 | test_splits: [test] 375 | model: 376 | _name: lenet5 377 | criterion: 378 | _name: cross_entropy_loss 379 | optimizer: 380 | _name: adadelta_optimizer 381 | meters: 382 | top1: 383 | _name: topk_meter 384 | k: 1 385 | max_epoch: 10 386 | device: cpu 387 | hooks: 388 | on_end_epoch: 389 | - 390 | _name: print_state 391 | formats: 392 | - 'epoch: {epoch_idx}' 393 | - 'train_acc: {metrics[train_top1]:.1f}%' 394 | - 'test_acc: {metrics[test_top1]:.1f}%' 395 | ``` 396 | 397 | > Check [HERE](https://github.com/ZhouYanzhao/Nest/tree/pytorch/demo) for more comprehensive demos. 398 | 399 | 3. Run your experiments through CLI: 400 | 401 | ```bash 402 | $ nest task run ./train_mnist.yml 403 | ``` 404 | 405 | 4. 
You can also use Nest's task runner in your code: 406 | 407 | ```python 408 | >>> from nest import run_tasks 409 | >>> run_tasks('./train_mnist.yml') 410 | ``` 411 | 412 | 5. Based on the task runner feature, Nest modules can be flexibly replaced and assembled to create your desired experiment settings. 413 | 414 | ## Contact 415 | Yanzhao Zhou 416 | 417 | ## Issues 418 | Feel free to submit bug reports and feature requests. 419 | 420 | ## Contribution 421 | Pull requests are welcome. 422 | 423 | ## License 424 | [MIT](https://opensource.org/licenses/MIT) 425 | 426 | Copyright © 2018-present, Yanzhao Zhou 427 | -------------------------------------------------------------------------------- /src/nest/modules.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import sys 4 | import fnmatch 5 | import inspect 6 | import importlib 7 | import importlib.abc 8 | import importlib.util 9 | import importlib.machinery 10 | import warnings 11 | import subprocess 12 | from types import ModuleType 13 | from typing import Any, List, Dict, Iterator, Callable, Optional 14 | from difflib import SequenceMatcher 15 | from datetime import datetime 16 | from argparse import Namespace as BaseNamespace 17 | from inspect import formatannotation as format_anno 18 | 19 | from nest import utils as U 20 | from nest.logger import exception 21 | from nest.settings import settings 22 | 23 | 24 | class Context(BaseNamespace): 25 | """Helper class for storing module context. 
26 | """ 27 | 28 | def __getitem__(self, key: str) -> Any: 29 | return getattr(self, key) 30 | 31 | def __setitem__(self, key: str, val: str) -> Any: 32 | return setattr(self, key, val) 33 | 34 | def __iter__(self) -> Iterator: 35 | return iter(self.__dict__.items()) 36 | 37 | def items(self) -> Iterator: 38 | return self.__dict__.items() 39 | 40 | def keys(self) -> Iterator: 41 | return self.__dict__.keys() 42 | 43 | def values(self) -> Iterator: 44 | return self.__dict__.values() 45 | 46 | def clear(self): 47 | self.__dict__.clear() 48 | 49 | 50 | class NestModule(object): 51 | """Base Nest module class. 52 | """ 53 | 54 | __slots__ = ('__name__', 'func', 'sig', 'meta', 'params') 55 | 56 | def __init__(self, func: Callable, meta: Dict[str, object], params: dict = {}) -> None: 57 | # module func 58 | self.func = func 59 | self.__name__ = func.__name__ 60 | # module signature 61 | self.sig = inspect.signature(func) 62 | # meta information 63 | self.meta = U.merge_dict(dict(), meta, union=True) 64 | # record module params 65 | self.params = U.merge_dict(dict(), params, union=True) 66 | # init module context 67 | for k, v in self.sig.parameters.items(): 68 | if k =='ctx' and issubclass(v.annotation, Context): 69 | self.params[k] = v.annotation() 70 | break 71 | # check module 72 | self._check_definition() 73 | 74 | def _check_definition(self) -> None: 75 | """Raise errors if the module definition is invalid. 76 | """ 77 | 78 | for v in self.sig.parameters.values(): 79 | # type of parameters must be annotated 80 | if v.annotation is inspect.Parameter.empty: 81 | raise TypeError('The param "%s" of Nest module "%s" is not explicitly annotated.' % (v, self.__name__)) 82 | # type of defaults must match annotations 83 | if v.default is not inspect.Parameter.empty and not U.is_annotation_matched(v.default, v.annotation): 84 | raise TypeError('The param "%s" of Nest module "%s" has an incompatible default value of type "%s".' 
% 85 | (v, self.__name__, format_anno(type(v.default)))) 86 | 87 | # type of returns must be annotated 88 | if self.sig.return_annotation is inspect.Parameter.empty: 89 | raise TypeError('The returns of Nest module "%s" is not explicitly annotated.' % self.__name__) 90 | 91 | # important meta data must be provided 92 | if getattr(self, '__doc__', None) is None: 93 | raise KeyError('Documentation of module "%s" is missing.' % self.__name__) 94 | 95 | def _check_params(self, params: dict) -> None: 96 | """Raise errors if invalid params are provided to the Nest module. 97 | 98 | Parameters: 99 | params: 100 | The provided params 101 | """ 102 | 103 | unexpected_params = ', '.join(set(params.keys()) - set(self.sig.parameters.keys())) 104 | if len(unexpected_params) > 0: 105 | raise TypeError('Unexpected param(s) "%s" for Nest module: \n%s' % \ 106 | (unexpected_params, self)) 107 | 108 | for k, v in self.sig.parameters.items(): 109 | resolved = params.get(k) 110 | if resolved is None: 111 | if v.default is inspect.Parameter.empty: 112 | raise KeyError('The required param "%s" of Nest module "%s" is missing.' % \ 113 | (v, self.__name__)) 114 | elif not U.is_annotation_matched(resolved, v.annotation): 115 | if issubclass(type(resolved), NestModule): 116 | detailed_msg = 'The param "%s" of Nest module "%s" should be type of "%s". Got \n%s\n' + \ 117 | 'Please check if some important params of Nest module "%s" have been forgotten in use.' 118 | raise TypeError(detailed_msg % \ 119 | (k, self.__name__, format_anno(v.annotation), U.indent_text(str(resolved), 4), resolved.__name__)) 120 | else: 121 | raise TypeError('The param "%s" of Nest module "%s" should be type of "%s". Got "%s".' % \ 122 | (k, self.__name__, format_anno(v.annotation), resolved)) 123 | 124 | def _check_returns(self, returns: Any) -> None: 125 | """Raise errors if invalid returns are generated by the Nest module. 
126 | 127 | Parameters: 128 | returns: 129 | The generated returns 130 | """ 131 | 132 | if not U.is_annotation_matched(returns, self.sig.return_annotation): 133 | raise TypeError('The returns of Nest module "%s" should be type of "%s". Got "%s".' % \ 134 | (self.__name__, format_anno(self.sig.return_annotation), returns)) 135 | 136 | def __call__(self, *args, **kwargs): 137 | # handle positional params 138 | num_args = len(args) 139 | if num_args > 0: 140 | # positional params should not be optional or resolved 141 | expected_param_names = [k for k, v in self.sig.parameters.items() 142 | if not k in self.params.keys() and v.default is inspect.Parameter.empty] 143 | num_expected_params = len(expected_param_names) 144 | if num_args != num_expected_params: 145 | raise TypeError('Nest module "%s" expects %d positional param(s) "%s". Got "%s".' % 146 | (self.__name__, num_expected_params, ', '.join(expected_param_names), ', '.join([str(v) for v in args]))) 147 | for idx, val in enumerate(args): 148 | key = expected_param_names[idx] 149 | if key in kwargs.keys(): 150 | raise TypeError('Nest module "%s" got multiple values for param "%s".' 
% (self.__name__, key)) 151 | else: 152 | kwargs[key] = val 153 | 154 | # resolve params 155 | resolved_params = dict() 156 | U.merge_dict(resolved_params, self.params, union=True) 157 | U.merge_dict(resolved_params, kwargs, union=True) 158 | 159 | if resolved_params.pop('delay_resolve', None): 160 | try: 161 | self._check_params(resolved_params) 162 | returns = self.func(**resolved_params) 163 | except KeyError as exc_info: 164 | if 'Nest module' in str(exc_info): 165 | # wait for next call 166 | return self.clone(resolved_params) 167 | else: 168 | raise 169 | else: 170 | # parameters must be fulfilled 171 | self._check_params(resolved_params) 172 | returns = self.func(**resolved_params) 173 | # check returns 174 | self._check_returns(returns) 175 | return returns 176 | 177 | def __str__(self) -> str: 178 | param_string = ', \n'.join(['[✓] ' + str(v) 179 | if k in self.params.keys() else ' ' + str(v) 180 | for k, v in self.sig.parameters.items()]) 181 | return_string = ' -> ' + format_anno(self.sig.return_annotation) 182 | return self.__name__ + '(\n' + param_string + ')' + return_string 183 | 184 | def __repr__(self) -> str: 185 | return "nest.modules['%s']" % self.__name__ 186 | 187 | def clone(self, params: dict = {}) -> Callable: 188 | """Clone the Nest module. 189 | 190 | Parameters: 191 | params: 192 | Module parameters 193 | """ 194 | 195 | return type(self)(self.func, self.meta, params) 196 | 197 | 198 | class ModuleManager(object): 199 | """Helper class for easy access to Nest modules. 200 | """ 201 | 202 | def __init__(self) -> None: 203 | self.namespaces = dict() 204 | self.py_modules = dict() 205 | self.nest_modules = dict() 206 | self.update_timestamp = 0.0 207 | self.namespace_regex = re.compile(r'^[a-z][a-z0-9\_]*\Z') 208 | # get available namespaces 209 | self._update_namespaces() 210 | # import syntax 211 | self._add_module_finder() 212 | 213 | @staticmethod 214 | def _format_namespace(src: str) -> str: 215 | """Format namespace. 
216 | 217 | Parameters: 218 | src: 219 | The original namespace 220 | 221 | Returns: 222 | Formatted namespace. 223 | """ 224 | 225 | return src.lower().replace('-', '_').replace('.', '_') 226 | 227 | @staticmethod 228 | def _register(*args, **kwargs) -> Callable: 229 | """Decorator for Nest modules registration. 230 | 231 | Parameters: 232 | ignored: 233 | Ignore the module 234 | 235 | module meta information which could be utilized by CLI and UI. For example: 236 | author: 237 | Module author(s), e.g., 'Zhou, Yanzhao' 238 | version: 239 | Module version, e.g., '1.2.0' 240 | backend: 241 | Module backend, e.g., 'pytorch' 242 | tags: 243 | Searchable tags, e.g., ['loss', 'cuda_only'] 244 | etc. 245 | """ 246 | 247 | # ignore the Nest module (could be used for debuging) 248 | if kwargs.pop('ignored', False): 249 | return lambda x: x 250 | 251 | # use the rest of kwargs to update metadata 252 | frame = inspect.stack()[1] 253 | current_py_module = inspect.getmodule(frame[0]) 254 | nest_meta = U.merge_dict(getattr(current_py_module, '__nest_meta__', dict()), kwargs, union=True) 255 | if current_py_module is not None: 256 | setattr(current_py_module, '__nest_meta__', nest_meta) 257 | 258 | def create_module(func): 259 | # append meta to doc 260 | doc = (func.__doc__ + '\n' + (U.yaml_format(nest_meta) if len(nest_meta) > 0 else '')) \ 261 | if isinstance(func.__doc__, str) else None 262 | return type('NestModule', (NestModule,), dict(__slots__=(), __doc__=doc))(func, nest_meta) 263 | 264 | if len(args) == 1 and inspect.isfunction(args[0]): 265 | return create_module(args[0]) 266 | else: 267 | return create_module 268 | 269 | @staticmethod 270 | def _import_nest_modules_from_py_module( 271 | namespace: str, 272 | py_module: object, 273 | nest_modules: Dict[str, object]) -> bool: 274 | """Import registered Nest modules from a given python module. 
275 | 276 | Parameters: 277 | namespace: 278 | A namespace that is used to avoid name conflicts 279 | py_module: 280 | The python module 281 | nest_modules: 282 | The dict for storing Nest modules 283 | 284 | Returns: 285 | The id of imported Nest modules 286 | """ 287 | 288 | imported_ids = [] 289 | # search for Nest modules 290 | for key, val in py_module.__dict__.items(): 291 | module_id = U.encode_id(namespace, key) 292 | if not key.startswith('_') and type(val).__name__ == 'NestModule': 293 | if module_id in nest_modules.keys(): 294 | U.alert_msg('There are duplicate "%s" modules under namespace "%s".' % \ 295 | (key, namespace)) 296 | else: 297 | nest_modules[module_id] = val 298 | imported_ids.append(module_id) 299 | return imported_ids 300 | 301 | @staticmethod 302 | def _import_nest_modules_from_file( 303 | path: str, 304 | namespace: str, 305 | py_modules: Dict[str, float], 306 | nest_modules: Dict[str, object], 307 | meta: Dict[str, object] = dict()) -> None: 308 | """Import registered Nest modules form a given file. 309 | 310 | Parameters: 311 | path: 312 | The path to the file 313 | namespace: 314 | A namespace that is used to avoid name conflicts 315 | py_modules: 316 | The dict for storing python modules information 317 | nest_modules: 318 | The dict for storing Nest modules 319 | meta: 320 | Global meta information 321 | """ 322 | 323 | py_module_name = os.path.basename(path).split('.')[0] 324 | py_module_id = U.encode_id(namespace, py_module_name) 325 | timestamp = os.path.getmtime(path) 326 | # check whether the python module have already been imported 327 | is_reload = False 328 | if py_module_id in py_modules.keys(): 329 | if timestamp <= py_modules[py_module_id][0]: 330 | # skip 331 | return 332 | else: 333 | is_reload = True 334 | # import the python module 335 | # note that a python module could contain multiple Nest modules. 336 | ref_id = 'nest.' + namespace + '.' 
+ py_module_name 337 | spec = importlib.util.spec_from_file_location(ref_id, path) 338 | if spec is not None: 339 | py_module = importlib.util.module_from_spec(spec) 340 | py_module.__nest_meta__ = U.merge_dict(dict(), meta, union=True) 341 | # no need to bind global requirements to individual Nest modules. 342 | requirements = py_module.__nest_meta__.pop('requirements', None) 343 | if requirements is not None: 344 | requirements = [dict(url=v, tool='pip') if isinstance(v, str) else v for v in requirements] 345 | sys.modules[ref_id] = py_module 346 | try: 347 | with warnings.catch_warnings(): 348 | warnings.simplefilter("ignore") 349 | spec.loader.exec_module(py_module) 350 | except Exception as exc_info: 351 | # helper function 352 | def find_requirement(name): 353 | if isinstance(requirements, list) and len(requirements) > 0: 354 | scores = [(SequenceMatcher(None, name, v['url']).ratio(), v) for v in requirements] 355 | return max(scores, key=lambda x: x[0]) 356 | 357 | # install tip 358 | tip = '' 359 | if (type(exc_info) is ImportError or type(exc_info) is ModuleNotFoundError) and exc_info.name is not None: 360 | match = find_requirement(exc_info.name) 361 | if match and match[0] > settings['INSTALL_TIP_THRESHOLD']: 362 | tip = 'Try to execute "%s install %s" to install the missing dependency.' % \ 363 | (match[1]['tool'], match[1]['url']) 364 | 365 | exc_info = str(exc_info) 366 | exc_info = exc_info if exc_info.endswith('.') else exc_info + '.' 367 | U.alert_msg('%s The package "%s" under namespace "%s" could not be imported. 
%s' % 368 | (exc_info, py_module_name, namespace, tip)) 369 | else: 370 | # remove old Nest modules 371 | if is_reload: 372 | for key in py_modules[py_module_id][1]: 373 | if key in nest_modules.keys(): 374 | del nest_modules[key] 375 | # import all Nest modules within the python module 376 | imported_ids = ModuleManager._import_nest_modules_from_py_module(namespace, py_module, nest_modules) 377 | if len(imported_ids) > 0: 378 | # record modified time, id, and spec of imported Nest modules 379 | py_modules[py_module_id] = (timestamp, imported_ids, py_module.__spec__) 380 | 381 | @staticmethod 382 | def _import_nest_modules_from_dir( 383 | path: str, 384 | namespace: str, 385 | py_modules: Dict[str, float], 386 | nest_modules: Dict[str, object], 387 | meta: Dict[str, object] = dict()) -> None: 388 | """Import registered Nest modules form a given directory. 389 | 390 | Parameters: 391 | path: 392 | The path to the directory 393 | namespace: 394 | A namespace that is used to avoid name conflicts 395 | py_modules: 396 | The dict for storing modified timestamp of python modules 397 | nest_modules: 398 | The dict for storing Nest modules 399 | meta: 400 | Global meta information 401 | 402 | Returns: 403 | The Nest modules 404 | The set of python modules 405 | """ 406 | 407 | for entry in os.listdir(path): 408 | file_path = os.path.join(path, entry) 409 | if entry.endswith('.py') and os.path.isfile(file_path): 410 | ModuleManager._import_nest_modules_from_file(file_path, namespace, py_modules, nest_modules, meta) 411 | 412 | @staticmethod 413 | def _fetch_nest_modules_from_url(url: str, dst: str) -> None: 414 | """Fetch and unzip Nest modules from url. 
415 | 416 | Parameters: 417 | url: 418 | URL of the zip file or git repo 419 | dst: 420 | Save dir path 421 | """ 422 | 423 | def _hook(count, block_size, total_size): 424 | size = float(count * block_size) / (1024.0 * 1024.0) 425 | total_size = float(total_size / (1024.0 * 1024.0)) 426 | if total_size > 0: 427 | size = min(size, total_size) 428 | percent = 100.0 * size / total_size 429 | sys.stdout.write("\rFetching...%d%%, %.2f MB / %.2f MB" % (percent, size, total_size)) 430 | else: 431 | sys.stdout.write("\rFetching...%.2f MB" % size) 432 | sys.stdout.flush() 433 | 434 | # extract 435 | if url.endswith('zip'): 436 | import random 437 | import string 438 | import zipfile 439 | from urllib import request, error 440 | 441 | cache_name = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(6)) + '.cache' 442 | cache_path = os.path.join(dst, cache_name) 443 | 444 | try: 445 | # download 446 | request.urlretrieve(url, cache_path, _hook) 447 | sys.stdout.write('\n') 448 | # unzip 449 | with zipfile.ZipFile(cache_path, 'r') as f: 450 | file_list = f.namelist() 451 | namespaces = set([v.split('/')[0] for v in file_list]) 452 | members = [v for v in file_list if '/' in v] 453 | f.extractall(dst, members) 454 | return namespaces 455 | except error.URLError as exc_info: 456 | U.alert_msg('Could not fetch "%s". %s' % (url, exc_info)) 457 | return [] 458 | except Exception as exc_info: 459 | U.alert_msg('Error occurs during extraction. 
%s' % exc_info) 460 | return [] 461 | finally: 462 | # remove cache 463 | if os.path.exists(cache_path): 464 | os.remove(cache_path) 465 | elif url.endswith('.git'): 466 | try: 467 | repo_name = url[url.rfind('/')+1: -4] 468 | match = re.search(r'(?:\s|^)(?:-b|--branch) (\w+)', url) 469 | if match: 470 | repo_name += '-' + match.group(1) 471 | subprocess.check_call(['git', 'clone'] + url.split() + [repo_name]) 472 | return [repo_name] 473 | except subprocess.CalledProcessError as exc_info: 474 | U.alert_msg('Failed to clone "%s".' % url) 475 | return [] 476 | else: 477 | raise NotImplementedError('Only supports zip file and git repo for now. Got "%s".' % url) 478 | 479 |
    @staticmethod
    def _install_namespaces_from_url(url: str, namespace: Optional[str] = None) -> None:
        """Install namespaces from url.

        Two shorthands are expanded before fetching:
            <host>@<owner>/<repo>[:<branch>]
                ``host`` is ``github``, ``gitlab`` or ``bitbucket``. The URL becomes
                ``-b <branch> https://<domain>/<owner>/<repo>.git``, and the branch
                defaults to ``master``.
            file@<path>
                Becomes ``file:///`` followed by the absolute path.

        Every directory fetched into the current working directory is registered
        with ``_install_namespaces_from_path``. When
        ``settings['AUTO_INSTALL_REQUIREMENTS']`` is enabled, the ``requirements``
        listed in the namespace config file are installed too. Each entry is
        either a string (installed with pip) or a dict with ``url`` and ``tool``.

        Parameters:
            url:
                URL of the zip file or git repo
            namespace:
                Specified namespace
        """

        def install_dep(target, tool, dirname):
            """Run ``python -m <tool> install <target>``; alert instead of raising on failure."""
            # Only plain requirement specifiers (letters, digits, '<', '=', '>', '.', '-')
            # are passed on. URLs and anything else are rejected, and the user is told.
            # NOTE(review): ``tool`` comes from the fetched namespace config and is run via
            # ``python -m``, so auto-install should only be enabled for trusted sources.
            if not (isinstance(target, str) and isinstance(tool, str)
                    and re.match(r'^[a-zA-Z0-9<=>.-]+$', target)):
                U.alert_msg('Skipped unsupported requirement "%s" for "%s". Please manually install it.'
                            % (target, dirname))
                return
            try:
                subprocess.check_call([sys.executable, '-m', tool, 'install', target])
            except subprocess.CalledProcessError:
                U.alert_msg('Failed to install "%s" for "%s". Please manually install it.' % (target, dirname))

        # pre-process short URL
        short_url_hosts = (('github@', 'github.com'), ('gitlab@', 'gitlab.com'), ('bitbucket@', 'bitbucket.org'))
        for prefix, host in short_url_hosts:
            if url.startswith(prefix):
                m = re.match(r'^([\w\-]+)/([\w\-]+)(?::([\w\-]+))?$', url[len(prefix):])
                if m is None:
                    U.alert_msg('Invalid short URL "%s". Expected "%s<owner>/<repo>[:<branch>]".' % (url, prefix))
                    return
                owner, repo, branch = m.groups()
                url = '-b %s https://%s/%s/%s.git' % (branch or 'master', host, owner, repo)
                break
        else:
            # no hosted short URL matched; check for the local-file shorthand
            if url.startswith('file@'):
                url = 'file:///' + os.path.abspath(url[5:])

        for dirname in ModuleManager._fetch_nest_modules_from_url(url, './'):
            module_path = os.path.join('./', dirname)
            # zip archives also report root-level file names, which are never extracted
            if not os.path.isdir(module_path):
                continue
            ModuleManager._install_namespaces_from_path(module_path, namespace)
            # parse config
            meta_path = os.path.join(module_path, settings['NAMESPACE_CONFIG_FILENAME'])
            meta = U.load_yaml(meta_path)[0] if os.path.exists(meta_path) else dict()
            if settings['AUTO_INSTALL_REQUIREMENTS']:
                # auto install deps
                for dep in meta.get('requirements', []):
                    if isinstance(dep, str):
                        # use pip by default
                        install_dep(dep, 'pip', dirname)
                    elif isinstance(dep, dict) and 'url' in dep and 'tool' in dep:
                        install_dep(dep['url'], dep['tool'], dirname)
                    else:
                        U.alert_msg('Invalid install requirement "%s".' % dep)

    @staticmethod
    def _install_namespaces_from_path(path: str, namespace: Optional[str] = None) -> None:
        """Install namespaces from path.

        Binds ``namespace`` to the absolute ``path`` in ``settings['SEARCH_PATHS']``
        and saves the settings. If ``namespace`` is not given, it is derived from
        the directory name via ``_format_namespace``. If the namespace or the path
        is already registered, nothing changes and an alert is shown.

        Parameters:
            path:
                Path to the directory
            namespace:
                Specified namespace
        """

        path = os.path.abspath(path)
        namespace = namespace or ModuleManager._format_namespace(os.path.basename(path))
        search_paths = settings['SEARCH_PATHS']
        for k, v in search_paths.items():
            if namespace == k:
                U.alert_msg('Namespace "%s" is already bound to the path "%s".' % (k, v))
                return
            if path == v:
                U.alert_msg('"%s" is already installed under the namespace "%s".' % (v, k))
                return
        search_paths[namespace] = path
        settings['SEARCH_PATHS'] = search_paths
        settings.save()

    @staticmethod
    def _remove_namespaces_from_path(src: str) -> Optional[str]:
        """Remove namespaces from path.

        ``src`` is treated as a path when it names an existing directory, and as a
        namespace otherwise. The first matching entry of ``settings['SEARCH_PATHS']``
        is removed and the settings are saved.

        Parameters:
            src:
                Namespace or path

        Returns:
            The path the removed namespace was bound to, or None (after an alert)
            if nothing matched
        """

        if os.path.isdir(src):
            path, namespace = os.path.abspath(src), None
        else:
            path, namespace = None, src

        search_paths = settings['SEARCH_PATHS']
        delete_key = next((k for k, v in search_paths.items() if namespace == k or path == v), None)

        if delete_key is None:
            if namespace:
                U.alert_msg('The namespace "%s" is not installed.' % namespace)
            if path:
                U.alert_msg('The path "%s" is not installed.' % path)
            return None

        path = search_paths.pop(delete_key)
        settings['SEARCH_PATHS'] = search_paths
        settings.save()
        return path

    @staticmethod
    def _pack_namespaces(srcs: List[str], dst: str) -> dict:
        """Pack namespaces to a zip file.

        Each source directory is stored in the archive under a top-level folder
        named after the directory. The following are skipped:
            - hidden files and directories (names starting with '.')
            - directories whose names start with '__'
            - files whose extension is not in the allow-list below
        Files without an extension are kept.

        Parameters:
            srcs:
                Paths to the namespaces
            dst:
                Save path for the resulting zip file

        Returns:
            A dict mapping each namespace name to its list of archived file paths
        """

        import zipfile

        # Python file, YAML config, Plain text, Markdown file, Image, and IPython Notebook
        allowed_extensions = {'py', 'yml', 'txt', 'md', 'jpg', 'png', 'gif', 'ipynb'}

        def check_extension(filename):
            splits = filename.split('.')
            return len(splits) == 1 or splits[-1] in allowed_extensions

        save_list = dict()
        # (file on disk, path inside the archive) for every namespace, not just the last one
        archive_entries = []
        for src in srcs:
            namespace = os.path.basename(os.path.normpath(src))
            # scan files
            file_list = []
            for root, dirs, files in os.walk(src):
                # prune in place so os.walk does not descend into hidden/dunder directories
                dirs[:] = [v for v in dirs if not (v.startswith('.') or v.startswith('__'))]
                file_list += [os.path.join(root, v) for v in files if not v.startswith('.') and check_extension(v)]
            save_list[namespace] = file_list
            archive_entries += [(v, os.path.join(namespace, os.path.relpath(v, src))) for v in file_list]

        # save to the zip file
        with zipfile.ZipFile(dst, 'w', zipfile.ZIP_DEFLATED) as f:
            for file_path, arcname in archive_entries:
                f.write(file_path, arcname)

        return save_list

    def _add_module_finder(self) -> None:
        """Add a custom finder to support Nest module import syntax.

        Inserts a meta path finder at the front of ``sys.meta_path``. It handles
        names of the form ``nest.<namespace>`` (exactly two dotted parts), unless
        ``<namespace>`` matches a ``.py`` file in this package's directory. Such
        names become an otherwise empty module whose ``__path__`` points at the
        namespace's ``module_path``.
        """

        module_manager = self

        class NamespaceLoader(importlib.abc.Loader):
            def create_module(self, spec):
                _, namespace = spec.name.split('.')
                module = ModuleType(spec.name)
                # refresh so namespaces registered after start-up are found
                module_manager._update_namespaces()
                meta = module_manager.namespaces.get(namespace)
                # unknown namespaces get an empty search path
                module.__path__ = [meta['module_path']] if meta else []
                return module

            def exec_module(self, module):
                pass

        class NestModuleFinder(importlib.abc.MetaPathFinder):
            def __init__(self):
                super(NestModuleFinder, self).__init__()
                # names of the package's own .py files must keep resolving to the real modules
                self.reserved_namespaces = [
                    v[:-3] for v in os.listdir(os.path.dirname(os.path.realpath(__file__))) if v.endswith('.py')]

            def find_spec(self, fullname, path, target=None):
                if fullname.startswith('nest.'):
                    name = fullname.split('.')
                    if len(name) == 2 and name[1] not in self.reserved_namespaces:
                        return importlib.machinery.ModuleSpec(fullname, NamespaceLoader())
                return None

        sys.meta_path.insert(0, NestModuleFinder())

    def _update_namespaces(self) -> None:
        """Get the available namespaces.

        Rebuilds ``self.namespaces`` (namespace name -> meta dict) from
        ``settings['SEARCH_PATHS']``. For each registered directory, the optional
        namespace config file is loaded and its ``module_path`` (default './',
        relative to the directory) is resolved to an absolute path. Entries whose
        module path is not a directory are reported and skipped. The current
        working directory is added as the 'main' namespace unless it already is
        one of the module paths.
        """

        # build into a local dict so a failure does not leave a half-built mapping
        namespaces = dict()
        dir_list = set()
        # user defined search paths
        for k, v in settings['SEARCH_PATHS'].items():
            if os.path.isdir(v):
                meta_path = os.path.join(v, settings['NAMESPACE_CONFIG_FILENAME'])
                meta = U.load_yaml(meta_path)[0] if os.path.exists(meta_path) else dict()
                meta['module_path'] = os.path.abspath(os.path.join(v, meta.get('module_path', './')))
                if os.path.isdir(meta['module_path']):
                    namespaces[k] = meta
                    dir_list.add(meta['module_path'])
                else:
                    U.alert_msg('Namespace "%s" has an invalid module path "%s".' % (k, meta['module_path']))

        # current path
        current_path = os.path.abspath(os.curdir)
        if current_path not in dir_list:
            namespaces['main'] = dict(module_path=current_path)
        self.namespaces = namespaces

    def _update_modules(self) -> None:
        """Automatically import all available Nest modules.

        Rescans at most once every ``settings['UPDATE_INTERVAL']`` seconds. Each
        namespace package is imported, and its module directory is handed to
        ``_import_nest_modules_from_dir`` along with ``self.py_modules`` and
        ``self.nest_modules`` (presumably populated in place).
        """

        timestamp = datetime.now().timestamp()
        if timestamp - self.update_timestamp > settings['UPDATE_INTERVAL']:
            for namespace, meta in self.namespaces.items():
                importlib.import_module('nest.' + namespace)
                ModuleManager._import_nest_modules_from_dir(
                    meta['module_path'], namespace, self.py_modules, self.nest_modules, meta)
            self.update_timestamp = timestamp

    def __iter__(self) -> Iterator:
        """Iterator for Nest modules.

        Returns:
            An iterator over (module id, Nest module) pairs
        """

        self._update_modules()
        return iter(self.nest_modules.items())

    def __len__(self) -> int:
        """Number of Nest modules

        Returns:
            The number of Nest modules
        """

        self._update_modules()
        return len(self.nest_modules)

    def _ipython_key_completions_(self) -> List[str]:
        """Support IPython key completion.

        Returns:
            A list of module ids
        """

        self._update_modules()
        return list(self.nest_modules.keys())

    def __dir__(self) -> List[str]:
        """Support IDE auto-completion

        Returns:
            A list of module names
        """

        self._update_modules()
        return [U.decode_id(uid)[1] for uid in self.nest_modules.keys()]

    def _format_matches(self, matches: List[str]) -> str:
        """Render candidate module ids as numbered lines followed by their signatures."""

        return '\n'.join(['[%d] %s %s' % (k, v, self.nest_modules[v].sig) for k, v in enumerate(matches)])

    @exception
    def __getattr__(self, key: str) -> object:
        """Get a Nest module by name.

        ``key`` is compared with the name part of every module id (as returned by
        ``U.decode_id``). If several modules share the name, a warning lists them
        and the first one is returned.

        Parameters:
            key:
                Name of the Nest module

        Returns:
            A clone of the Nest module

        Raises:
            KeyError: if no module has this name
        """

        self._update_modules()
        matches = [uid for uid in self.nest_modules.keys() if U.decode_id(uid)[1] == key]
        if len(matches) == 0:
            raise KeyError('Could not find the Nest module "%s".' % key)
        if len(matches) > 1:
            warnings.warn('Multiple Nest modules with this name have been found. \n'
                          'The returned module is "%s", but you can use nest.modules[regex] to specify others: \n%s' %
                          (matches[0], self._format_matches(matches)))

        return self.nest_modules[matches[0]].clone()

    @exception
    def __getitem__(self, key: str) -> object:
        r"""Get a Nest module by a query string.

        There are three match modes:
            1. Exact match if the query string starts with '$':
                E.g., nest.modules['$nest/optimizer']
            2. Regex match if the query string starts with 'r/':
                E.g., nest.modules['r/.*optim\w+']
            3. Wildcard match if otherwise:
                E.g., nest.modules['optim*er'].
                Note that a wildcard is automatically added to the beginning of the string.

        If a regex or wildcard query matches several modules, a warning lists them
        and the first match is returned.

        Parameters:
            key:
                The query string

        Returns:
            A clone of the matched Nest module

        Raises:
            KeyError: if nothing matches
            NotImplementedError: if ``key`` is not a string
        """

        self._update_modules()
        if not isinstance(key, str):
            raise NotImplementedError

        if key.startswith('$'):
            # exact match
            key = key[1:]
            if key in self.nest_modules:
                return self.nest_modules[key].clone()
            raise KeyError('Could not find Nest module "%s".' % key)

        if key.startswith('r/'):
            # regex match (anchored at the start of the module id)
            key = key[2:]
            matches = list(filter(re.compile(key).match, self.nest_modules.keys()))
            if len(matches) == 0:
                raise KeyError('Could not find a Nest module matches regex "%s".' % key)
            if len(matches) > 1:
                warnings.warn('Multiple Nest modules match the given regex have been found. \n'
                              'The returned module is "%s", but you can adjust regex to specify others: \n%s' %
                              (matches[0], self._format_matches(matches)))
            return self.nest_modules[matches[0]].clone()

        # wildcard match; startswith() also copes with an empty query
        if not key.startswith('*'):
            key = '*' + key
        matches = fnmatch.filter(self.nest_modules.keys(), key)
        if len(matches) == 0:
            raise KeyError('Could not find a Nest module matches query "%s".' % key)
        if len(matches) > 1:
            warnings.warn('Multiple Nest modules match the given query have been found. \n'
                          'The returned module is "%s", but you can adjust the query to specify others: \n%s' %
                          (matches[0], self._format_matches(matches)))
        return self.nest_modules[matches[0]].clone()

    def __repr__(self) -> str:
        return 'nest.modules'

    def __str__(self) -> str:
        num = len(self)
        if num == 0:
            return 'No Nest module found.'
        if num == 1:
            return 'Found 1 Nest module.'
        return '%d Nest modules are available.' % num

# global manager
module_manager = ModuleManager()
--------------------------------------------------------------------------------