├── .flake8
├── .github
│   └── workflows
│       ├── CI.yml
│       └── Upload.yml
├── .gitignore
├── .style.yapf
├── LICENSE
├── README.md
├── setup.py
├── tests
│   ├── README.md
│   ├── __init__.py
│   ├── test_autoaugment.py
│   ├── test_data.py
│   ├── test_lr_scheduler.py
│   ├── test_metric.py
│   ├── test_nn.py
│   ├── test_obj.py
│   ├── test_op.py
│   ├── test_summary.py
│   └── test_transform.py
└── torchtoolbox
    ├── __init__.py
    ├── data
    │   ├── __init__.py
    │   ├── dataprefetcher.py
    │   ├── datasets.py
    │   ├── dynamic_data_provider.py
    │   ├── lmdb_dataset.py
    │   ├── sampler.py
    │   └── utils.py
    ├── io
    │   ├── __init__.py
    │   └── structured_data_io.py
    ├── metric
    │   ├── __init__.py
    │   ├── feature_verification.py
    │   ├── map.py
    │   └── metric.py
    ├── nn
    │   ├── __init__.py
    │   ├── activation.py
    │   ├── conv.py
    │   ├── functional.py
    │   ├── init.py
    │   ├── loss.py
    │   ├── metric_loss.py
    │   ├── modules.py
    │   ├── norm.py
    │   ├── operators.py
    │   ├── parallel
    │   │   ├── EncodingDataParallel.py
    │   │   └── __init__.py
    │   ├── sequential.py
    │   └── transformer.py
    ├── objects
    │   ├── __init__.py
    │   └── bbox.py
    ├── optimizer
    │   ├── __init__.py
    │   ├── drop_block_scheduler.py
    │   ├── lookahead.py
    │   ├── lr_scheduler.py
    │   └── sgd_gc.py
    ├── tools
    │   ├── __init__.py
    │   ├── config_parser.py
    │   ├── convert_lmdb.py
    │   ├── distribute.py
    │   ├── dotdict.py
    │   ├── mixup.py
    │   ├── registry.py
    │   ├── reset_model_setting.py
    │   ├── summary.py
    │   ├── tensor_transfer.py
    │   └── utils.py
    └── transform
        ├── __init__.py
        ├── autoaugment.py
        ├── dynamic_transform.py
        ├── functional.py
        ├── hybrid.py
        └── transforms.py

/.flake8:
--------------------------------------------------------------------------------
1 | # Add "python.linting.flake8Args": ["--config=.flake8"] to your vscode settings.json
2 | # See issues https://github.com/Microsoft/vscode-python/issues/1884
3 | [flake8]
4 | ignore = W503, E203, E221, C901, E501
5 | max-line-length = 120
6 | max-complexity = 18
7 | select = B,C,E,F,W,T4,B9
8 | exclude = build,__init__.py
--------------------------------------------------------------------------------
/.github/workflows/CI.yml:
--------------------------------------------------------------------------------
1 | name: Torch-Toolbox-CI
2 | 
3 | on: [push, pull_request]
4 | 
5 | jobs:
6 |   build:
7 |     name: Test task.
8 |     runs-on: ubuntu-18.04
9 |     steps:
10 |       - name: Install PyTorch
11 |         run: |
12 |           conda install -y pytorch torchvision cpuonly -c pytorch \
13 |           && conda clean --all --yes
14 |       - name: Install dependencies
15 |         run: |
16 |           conda install -y tqdm pyarrow six python-lmdb pytest scikit-learn \
17 |           && conda clean --all --yes \
18 |           && /usr/share/miniconda/bin/pip install -q opencv-python tensorboard pyyaml prettytable transformers
19 |       - name: Checkout
20 |         uses: actions/checkout@master
21 | 
22 |       - name: Run test
23 |         run: |
24 |           /usr/share/miniconda/bin/pytest tests/
25 | 
--------------------------------------------------------------------------------
/.github/workflows/Upload.yml:
--------------------------------------------------------------------------------
1 | name: Automatically upload package to PyPI
2 | 
3 | on: [push]
4 | 
5 | jobs:
6 |   build:
7 |     name: Publish package
8 |     runs-on: ubuntu-latest
9 |     steps:
10 |       - name: Checkout
11 |         uses: actions/checkout@master
12 |       - name: Set up Python 3.7
13 |         uses: actions/setup-python@v1
14 |         with:
15 |           python-version: 3.7
16 |       - name: Install setuptools
17 |         run: >-
18 |           python -m
19 |           pip install
20 |           setuptools wheel
21 |           --user
22 |       - name: Generate Publish File
23 |         run: |
24 |           python setup.py sdist bdist_wheel
25 |       - name: Automatic Publish
26 |         if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags')
27 |         uses: pypa/gh-action-pypi-publish@master
28 |         with:
29 |           user: ${{ secrets.PYPI_USER_NAME }}
30 |           password: ${{ secrets.PYPI_PASSWORD }}
31 | 
32 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by .ignore support plugin (hsz.mobi)
2 | ### JetBrains template
3 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm
4 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
5 | 
6 | # User-specific stuff
7 | .idea/**/workspace.xml
8 | .idea/**/tasks.xml
9 | .idea/**/usage.statistics.xml
10 | .idea/**/dictionaries
11 | .idea/**/shelf
12 | 
13 | # Generated files
14 | .idea/**/contentModel.xml
15 | 
16 | # Sensitive or high-churn files
17 | .idea/**/dataSources/
18 | .idea/**/dataSources.ids
19 | .idea/**/dataSources.local.xml
20 | .idea/**/sqlDataSources.xml
21 | .idea/**/dynamic.xml
22 | .idea/**/uiDesigner.xml
23 | .idea/**/dbnavigator.xml
24 | 
25 | # Gradle
26 | .idea/**/gradle.xml
27 | .idea/**/libraries
28 | 
29 | # Gradle and Maven with auto-import
30 | # When using Gradle or Maven with auto-import, you should exclude module files,
31 | # since they will be recreated, and may cause churn. Uncomment if using
32 | # auto-import.
33 | # .idea/modules.xml
34 | # .idea/*.iml
35 | # .idea/modules
36 | # *.iml
37 | # *.ipr
38 | 
39 | # CMake
40 | cmake-build-*/
41 | 
42 | # Mongo Explorer plugin
43 | .idea/**/mongoSettings.xml
44 | 
45 | # File-based project format
46 | *.iws
47 | 
48 | # IntelliJ
49 | out/
50 | 
51 | # mpeltonen/sbt-idea plugin
52 | .idea_modules/
53 | 
54 | # JIRA plugin
55 | atlassian-ide-plugin.xml
56 | 
57 | # Cursive Clojure plugin
58 | .idea/replstate.xml
59 | 
60 | # Crashlytics plugin (for Android Studio and IntelliJ)
61 | com_crashlytics_export_strings.xml
62 | crashlytics.properties
63 | crashlytics-build.properties
64 | fabric.properties
65 | 
66 | # Editor-based Rest Client
67 | .idea/httpRequests
68 | 
69 | # Android studio 3.1+ serialized cache file
70 | .idea/caches/build_file_checksums.ser
71 | 
72 | ### Python template
73 | # Byte-compiled / optimized / DLL files
74 | __pycache__/
75 | *.py[cod]
76 | *$py.class
77 | 
78 | # C extensions
79 | *.so
80 | 
81 | # Distribution / packaging
82 | .Python
83 | build/
84 | develop-eggs/
85 | dist/
86 | downloads/
87 | eggs/
88 | .eggs/
89 | lib/
90 | lib64/
91 | parts/
92 | sdist/
93 | var/
94 | wheels/
95 | pip-wheel-metadata/
96 | share/python-wheels/
97 | *.egg-info/
98 | .installed.cfg
99 | *.egg
100 | MANIFEST
101 | 
102 | # PyInstaller
103 | # Usually these files are written by a python script from a template
104 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
105 | *.manifest
106 | *.spec
107 | 
108 | # Installer logs
109 | pip-log.txt
110 | pip-delete-this-directory.txt
111 | 
112 | # Unit test / coverage reports
113 | htmlcov/
114 | .tox/
115 | .nox/
116 | .coverage
117 | .coverage.*
118 | .cache
119 | nosetests.xml
120 | coverage.xml
121 | *.cover
122 | .hypothesis/
123 | .pytest_cache/
124 | 
125 | # Translations
126 | *.mo
127 | *.pot
128 | 
129 | # Django stuff:
130 | *.log
131 | local_settings.py
132 | db.sqlite3
133 | db.sqlite3-journal
134 | 
135 | # Flask stuff:
136 | instance/
137 | .webassets-cache
138 | 
139 | # Scrapy stuff:
140 | .scrapy
141 | 
142 | # Sphinx documentation
143 | docs/_build/
144 | 
145 | # PyBuilder
146 | target/
147 | 
148 | # Jupyter Notebook
149 | .ipynb_checkpoints
150 | *.ipynb
151 | 
152 | # IPython
153 | profile_default/
154 | ipython_config.py
155 | 
156 | # pyenv
157 | .python-version
158 | 
159 | # pipenv
160 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
161 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
162 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
163 | # install all needed dependencies.
164 | #Pipfile.lock
165 | 
166 | # celery beat schedule file
167 | celerybeat-schedule
168 | 
169 | # SageMath parsed files
170 | *.sage.py
171 | 
172 | # Environments
173 | .env
174 | .venv
175 | env/
176 | venv/
177 | ENV/
178 | env.bak/
179 | venv.bak/
180 | 
181 | # Spyder project settings
182 | .spyderproject
183 | .spyproject
184 | 
185 | # Rope project settings
186 | .ropeproject
187 | 
188 | # mkdocs documentation
189 | /site
190 | 
191 | # mypy
192 | .mypy_cache/
193 | .dmypy.json
194 | dmypy.json
195 | 
196 | # Pyre type checker
197 | .pyre/
198 | 
199 | ### Example user template template
200 | ### Example user template
201 | 
202 | # IntelliJ project files
203 | .idea
204 | *.iml
205 | out
206 | gen
207 | ### Linux template
208 | *~
209 | 
210 | # temporary files which can be created if a process still has a handle open of a deleted file
211 | .fuse_hidden*
212 | 
213 | # KDE directory preferences
214 | .directory
215 | 
216 | # Linux trash folder which might appear on any partition or disk
217 | .Trash-*
218 | 
219 | # .nfs files are created when an open file is removed but is still being accessed
220 | .nfs*
221 | 
222 | .vscode/
223 | dirty_code/*
--------------------------------------------------------------------------------
/.style.yapf:
--------------------------------------------------------------------------------
1 | [style]
2 | # Align closing bracket with visual indentation.
3 | align_closing_bracket_with_visual_indent=True
4 | 
5 | # Allow dictionary keys to exist on multiple lines. For example:
6 | #
7 | #   x = {
8 | #       ('this is the first element of a tuple',
9 | #        'this is the second element of a tuple'):
10 | #           value,
11 | #   }
12 | allow_multiline_dictionary_keys=False
13 | 
14 | # Allow lambdas to be formatted on more than one line.
15 | allow_multiline_lambdas=False
16 | 
17 | # Allow splitting before a default / named assignment in an argument list.
18 | allow_split_before_default_or_named_assigns=True
19 | 
20 | # Allow splits before the dictionary value.
21 | allow_split_before_dict_value=True
22 | 
23 | # Let spacing indicate operator precedence. For example:
24 | #
25 | #   a = 1 * 2 + 3 / 4
26 | #   b = 1 / 2 - 3 * 4
27 | #   c = (1 + 2) * (3 - 4)
28 | #   d = (1 - 2) / (3 + 4)
29 | #   e = 1 * 2 - 3
30 | #   f = 1 + 2 + 3 + 4
31 | #
32 | # will be formatted as follows to indicate precedence:
33 | #
34 | #   a = 1*2 + 3/4
35 | #   b = 1/2 - 3*4
36 | #   c = (1+2) * (3-4)
37 | #   d = (1-2) / (3+4)
38 | #   e = 1*2 - 3
39 | #   f = 1 + 2 + 3 + 4
40 | #
41 | arithmetic_precedence_indication=False
42 | 
43 | # Number of blank lines surrounding top-level function and class
44 | # definitions.
45 | blank_lines_around_top_level_definition=2
46 | 
47 | # Insert a blank line before a class-level docstring.
48 | blank_line_before_class_docstring=False
49 | 
50 | # Insert a blank line before a module docstring.
51 | blank_line_before_module_docstring=False
52 | 
53 | # Insert a blank line before a 'def' or 'class' immediately nested
54 | # within another 'def' or 'class'. For example:
55 | #
56 | #   class Foo:
57 | #                      # <------ this blank line
58 | #     def method():
59 | #       ...
60 | blank_line_before_nested_class_or_def=False
61 | 
62 | # Do not split consecutive brackets. Only relevant when
63 | # dedent_closing_brackets is set. For example:
64 | #
65 | #   call_func_that_takes_a_dict(
66 | #       {
67 | #           'key1': 'value1',
68 | #           'key2': 'value2',
69 | #       }
70 | #   )
71 | #
72 | # would reformat to:
73 | #
74 | #   call_func_that_takes_a_dict({
75 | #       'key1': 'value1',
76 | #       'key2': 'value2',
77 | #   })
78 | coalesce_brackets=False
79 | 
80 | # The column limit.
81 | column_limit=128
82 | 
83 | # The style for continuation alignment. Possible values are:
84 | #
85 | # - SPACE: Use spaces for continuation alignment. This is default behavior.
86 | # - FIXED: Use fixed number (CONTINUATION_INDENT_WIDTH) of columns
87 | #   (ie: CONTINUATION_INDENT_WIDTH/INDENT_WIDTH tabs or
88 | #   CONTINUATION_INDENT_WIDTH spaces) for continuation alignment.
89 | # - VALIGN-RIGHT: Vertically align continuation lines to multiple of
90 | #   INDENT_WIDTH columns. Slightly right (one tab or a few spaces) if
91 | #   cannot vertically align continuation lines with indent characters.
92 | continuation_align_style=SPACE
93 | 
94 | # Indent width used for line continuations.
95 | continuation_indent_width=4
96 | 
97 | # Put closing brackets on a separate line, dedented, if the bracketed
98 | # expression can't fit in a single line. Applies to all kinds of brackets,
99 | # including function definitions and calls. For example:
100 | #
101 | #   config = {
102 | #       'key1': 'value1',
103 | #       'key2': 'value2',
104 | #   }  # <--- this bracket is dedented and on a separate line
105 | #
106 | #   time_series = self.remote_client.query_entity_counters(
107 | #       entity='dev3246.region1',
108 | #       key='dns.query_latency_tcp',
109 | #       transform=Transformation.AVERAGE(window=timedelta(seconds=60)),
110 | #       start_ts=now()-timedelta(days=3),
111 | #       end_ts=now(),
112 | #   )  # <--- this bracket is dedented and on a separate line
113 | dedent_closing_brackets=False
114 | 
115 | # Disable the heuristic which places each list element on a separate line
116 | # if the list is comma-terminated.
117 | disable_ending_comma_heuristic=False
118 | 
119 | # Place each dictionary entry onto its own line.
120 | each_dict_entry_on_separate_line=True
121 | 
122 | # Require multiline dictionary even if it would normally fit on one line.
123 | # For example:
124 | #
125 | #   config = {
126 | #       'key1': 'value1'
127 | #   }
128 | force_multiline_dict=False
129 | 
130 | # The regex for an i18n comment. The presence of this comment stops
131 | # reformatting of that line, because the comments are required to be
132 | # next to the string they translate.
133 | i18n_comment=
134 | 
135 | # The i18n function call names. The presence of this function stops
136 | # reformatting on that line, because the string it has cannot be moved
137 | # away from the i18n comment.
138 | i18n_function_call=
139 | 
140 | # Indent blank lines.
141 | indent_blank_lines=False
142 | 
143 | # Put closing brackets on a separate line, indented, if the bracketed
144 | # expression can't fit in a single line. Applies to all kinds of brackets,
145 | # including function definitions and calls. For example:
146 | #
147 | #   config = {
148 | #       'key1': 'value1',
149 | #       'key2': 'value2',
150 | #   }  # <--- this bracket is indented and on a separate line
151 | #
152 | #   time_series = self.remote_client.query_entity_counters(
153 | #       entity='dev3246.region1',
154 | #       key='dns.query_latency_tcp',
155 | #       transform=Transformation.AVERAGE(window=timedelta(seconds=60)),
156 | #       start_ts=now()-timedelta(days=3),
157 | #       end_ts=now(),
158 | #   )  # <--- this bracket is indented and on a separate line
159 | indent_closing_brackets=False
160 | 
161 | # Indent the dictionary value if it cannot fit on the same line as the
162 | # dictionary key. For example:
163 | #
164 | #   config = {
165 | #       'key1':
166 | #           'value1',
167 | #       'key2': value1 +
168 | #               value2,
169 | #   }
170 | indent_dictionary_value=False
171 | 
172 | # The number of columns to use for indentation.
173 | indent_width=4
174 | 
175 | # Join short lines into one line. E.g., single line 'if' statements.
176 | join_multiple_lines=False
177 | 
178 | # Do not include spaces around selected binary operators. For example:
179 | #
180 | #   1 + 2 * 3 - 4 / 5
181 | #
182 | # will be formatted as follows when configured with "*,/":
183 | #
184 | #   1 + 2*3 - 4/5
185 | no_spaces_around_selected_binary_operators=
186 | 
187 | # Use spaces around default or named assigns.
188 | spaces_around_default_or_named_assign=False
189 | 
190 | # Adds a space after the opening '{' and before the ending '}' dict delimiters.
191 | #
192 | #   {1: 2}
193 | #
194 | # will be formatted as:
195 | #
196 | #   { 1: 2 }
197 | spaces_around_dict_delimiters=False
198 | 
199 | # Adds a space after the opening '[' and before the ending ']' list delimiters.
200 | #
201 | #   [1, 2]
202 | #
203 | # will be formatted as:
204 | #
205 | #   [ 1, 2 ]
206 | spaces_around_list_delimiters=False
207 | 
208 | # Use spaces around the power operator.
209 | spaces_around_power_operator=False
210 | 
211 | # Use spaces around the subscript / slice operator. For example:
212 | #
213 | #   my_list[1 : 10 : 2]
214 | spaces_around_subscript_colon=False
215 | 
216 | # Adds a space after the opening '(' and before the ending ')' tuple delimiters.
217 | #
218 | #   (1, 2, 3)
219 | #
220 | # will be formatted as:
221 | #
222 | #   ( 1, 2, 3 )
223 | spaces_around_tuple_delimiters=False
224 | 
225 | # The number of spaces required before a trailing comment.
226 | # This can be a single value (representing the number of spaces
227 | # before each trailing comment) or list of values (representing
228 | # alignment column values; trailing comments within a block will
229 | # be aligned to the first column value that is greater than the maximum
230 | # line length within the block). For example:
231 | #
232 | # With spaces_before_comment=5:
233 | #
234 | #   1 + 1 # Adding values
235 | #
236 | # will be formatted as:
237 | #
238 | #   1 + 1     # Adding values <-- 5 spaces between the end of the statement and comment
239 | #
240 | # With spaces_before_comment=15, 20:
241 | #
242 | #   1 + 1 # Adding values
243 | #   two + two # More adding
244 | #
245 | #   longer_statement # This is a longer statement
246 | #   short # This is a shorter statement
247 | #
248 | #   a_very_long_statement_that_extends_beyond_the_final_column # Comment
249 | #   short # This is a shorter statement
250 | #
251 | # will be formatted as:
252 | #
253 | #   1 + 1          # Adding values <-- end of line comments in block aligned to col 15
254 | #   two + two      # More adding
255 | #
256 | #   longer_statement    # This is a longer statement <-- end of line comments in block aligned to col 20
257 | #   short               # This is a shorter statement
258 | #
259 | #   a_very_long_statement_that_extends_beyond_the_final_column # Comment <-- the end of line comments are aligned based on the line length
260 | #   short # This is a shorter statement
261 | #
262 | spaces_before_comment=2
263 | 
264 | # Insert a space between the ending comma and closing bracket of a list,
265 | # etc.
266 | space_between_ending_comma_and_closing_bracket=True
267 | 
268 | # Use spaces inside brackets, braces, and parentheses. For example:
269 | #
270 | #   method_call( 1 )
271 | #   my_dict[ 3 ][ 1 ][ get_index( *args, **kwargs ) ]
272 | #   my_set = { 1, 2, 3 }
273 | space_inside_brackets=False
274 | 
275 | # Split before arguments
276 | split_all_comma_separated_values=False
277 | 
278 | # Split before arguments, but do not split all subexpressions recursively
279 | # (unless needed).
280 | split_all_top_level_comma_separated_values=False
281 | 
282 | # Split before arguments if the argument list is terminated by a
283 | # comma.
284 | split_arguments_when_comma_terminated=False
285 | 
286 | # Set to True to prefer splitting before '+', '-', '*', '/', '//', or '@'
287 | # rather than after.
288 | split_before_arithmetic_operator=False
289 | 
290 | # Set to True to prefer splitting before '&', '|' or '^' rather than
291 | # after.
292 | split_before_bitwise_operator=True
293 | 
294 | # Split before the closing bracket if a list or dict literal doesn't fit on
295 | # a single line.
296 | split_before_closing_bracket=True
297 | 
298 | # Split before a dictionary or set generator (comp_for). For example, note
299 | # the split before the 'for':
300 | #
301 | #   foo = {
302 | #       variable: 'Hello world, have a nice day!'
303 | #       for variable in bar if variable != 42
304 | #   }
305 | split_before_dict_set_generator=True
306 | 
307 | # Split before the '.' if we need to split a longer expression:
308 | #
309 | #   foo = ('This is a really long string: {}, {}, {}, {}'.format(a, b, c, d))
310 | #
311 | # would reformat to something like:
312 | #
313 | #   foo = ('This is a really long string: {}, {}, {}, {}'
314 | #          .format(a, b, c, d))
315 | split_before_dot=False
316 | 
317 | # Split after the opening paren which surrounds an expression if it doesn't
318 | # fit on a single line.
319 | split_before_expression_after_opening_paren=False
320 | 
321 | # If an argument / parameter list is going to be split, then split before
322 | # the first argument.
323 | split_before_first_argument=False
324 | 
325 | # Set to True to prefer splitting before 'and' or 'or' rather than
326 | # after.
327 | split_before_logical_operator=True
328 | 
329 | # Split named assignments onto individual lines.
330 | split_before_named_assigns=True
331 | 
332 | # Set to True to split list comprehensions and generators that have
333 | # non-trivial expressions and multiple clauses before each of these
334 | # clauses. For example:
335 | #
336 | #   result = [
337 | #       a_long_var + 100 for a_long_var in xrange(1000)
338 | #       if a_long_var % 10]
339 | #
340 | # would reformat to something like:
341 | #
342 | #   result = [
343 | #       a_long_var + 100
344 | #       for a_long_var in xrange(1000)
345 | #       if a_long_var % 10]
346 | split_complex_comprehension=False
347 | 
348 | # The penalty for splitting right after the opening bracket.
349 | split_penalty_after_opening_bracket=300
350 | 
351 | # The penalty for splitting the line after a unary operator.
352 | split_penalty_after_unary_operator=10000
353 | 
354 | # The penalty of splitting the line around the '+', '-', '*', '/', '//',
355 | # ``%``, and '@' operators.
356 | split_penalty_arithmetic_operator=300
357 | 
358 | # The penalty for splitting right before an if expression.
359 | split_penalty_before_if_expr=0
360 | 
361 | # The penalty of splitting the line around the '&', '|', and '^'
362 | # operators.
363 | split_penalty_bitwise_operator=300
364 | 
365 | # The penalty for splitting a list comprehension or generator
366 | # expression.
367 | split_penalty_comprehension=80
368 | 
369 | # The penalty for characters over the column limit.
370 | split_penalty_excess_character=7000
371 | 
372 | # The penalty incurred by adding a line split to the unwrapped line. The
373 | # more line splits added the higher the penalty.
374 | split_penalty_for_added_line_split=30
375 | 
376 | # The penalty of splitting a list of "import as" names. For example:
377 | #
378 | #   from a_very_long_or_indented_module_name_yada_yad import (long_argument_1,
379 | #                                                             long_argument_2,
380 | #                                                             long_argument_3)
381 | #
382 | # would reformat to something like:
383 | #
384 | #   from a_very_long_or_indented_module_name_yada_yad import (
385 | #       long_argument_1, long_argument_2, long_argument_3)
386 | split_penalty_import_names=0
387 | 
388 | # The penalty of splitting the line around the 'and' and 'or'
389 | # operators.
390 | split_penalty_logical_operator=300
391 | 
392 | # Use the Tab character for indentation.
393 | use_tabs=False
394 | 
395 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 | 
3 | Copyright (c) X.Yang 2019-present,
4 | All rights reserved.
5 | 
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 | 
9 | * Redistributions of source code must retain the above copyright notice, this
10 |   list of conditions and the following disclaimer.
11 | 
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 |   this list of conditions and the following disclaimer in the documentation
14 |   and/or other materials provided with the distribution.
15 | 
16 | * Neither the name of the copyright holder nor the names of its
17 |   contributors may be used to endorse or promote products derived from
18 |   this software without specific prior written permission.
19 | 
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Author: pistonyang@gmail.com
3 | 
4 | from setuptools import setup, find_packages
5 | from torchtoolbox import VERSION
6 | 
7 | setup(name='torchtoolbox',
8 |       version=VERSION,
9 |       author='X.Yang',
10 |       author_email='pistonyang@gmail.com',
11 |       url='https://github.com/PistonY/torch-toolbox',
12 |       description='ToolBox to make using PyTorch much easier.',
13 |       long_description_content_type="text/markdown",
14 |       long_description=open('README.md').read(),
15 |       license='BSD 3-Clause',
16 |       packages=find_packages(exclude=('*tests*', 'dirty_code')),
17 |       zip_safe=True,
18 |       classifiers=[
19 |           'Programming Language :: Python :: 3',
20 |       ],
21 |       install_requires=[
22 |           'numpy', 'tqdm', 'pyarrow', 'six', 'lmdb', 'scikit-learn', 'scipy', 'opencv-python', 'pyyaml', 'tensorboard',
23 |           'prettytable', 'transformers'
24 |       ])
25 | 
--------------------------------------------------------------------------------
/tests/README.md:
--------------------------------------------------------------------------------
1 | # Test Module
2 | 
3 | We use pytest for testing.
4 | 
5 | CI doesn't support CUDA, so please don't write any CUDA-dependent code in the test module.
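6 | 
7 | A GPU-free test looks like this (an illustrative sketch, not one of the existing tests; `Swish` is one of the toolbox's CPU-friendly modules):
8 | 
9 | ```python
10 | import torch
11 | from torchtoolbox.nn import Swish
12 | 
13 | 
14 | def test_swish_cpu():
15 |     x = torch.rand(2, 8)  # tensors stay on the CPU -- no .cuda() calls
16 |     y = Swish(beta=1.0)(x)
17 |     assert y.shape == x.shape
18 | ```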
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author : DevinYang(pistonyang@gmail.com)
3 | 
--------------------------------------------------------------------------------
/tests/test_autoaugment.py:
--------------------------------------------------------------------------------
1 | from torchtoolbox.transform.autoaugment import ImageNetPolicy, RandAugment
2 | import numpy as np
3 | from PIL import Image
4 | 
5 | autoaugment = ImageNetPolicy()
6 | randaugment = RandAugment(n=2, m=9)
7 | 
8 | 
9 | def _gen_fake_img(size=None):
10 |     if size is None:
11 |         size = (224, 224, 3)
12 |     img = np.random.randint(0, 255, size=size, dtype='uint8')
13 |     return Image.fromarray(img)
14 | 
15 | 
16 | def test_augment():
17 |     img = _gen_fake_img()
18 |     for _ in range(1000):
19 |         autoaugment(img)
20 |         randaugment(img)
21 |     # img = Image.open('/media/devin/data/720p/rzdf/0058.png')
22 |     # ra = randaugment(img)
23 |     # aa = autoaugment(img)
24 |     # Image._show(aa)
25 | 
--------------------------------------------------------------------------------
/tests/test_data.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author : DevinYang(pistonyang@gmail.com)
3 | from torchtoolbox.data.datasets import NonLabelDataset
4 | 
5 | 
6 | def test_nonlabeldataset(root_dir='/media/piston/data/FFHQ/train'):
7 |     try:
8 |         dt = NonLabelDataset(root_dir)
9 |     except FileNotFoundError:
10 |         return
11 |     _ = len(dt)
12 |     for i, _ in enumerate(dt):
13 |         if i == 10:
14 |             break
15 |     return
--------------------------------------------------------------------------------
/tests/test_lr_scheduler.py:
--------------------------------------------------------------------------------
1 | from torch import optim
2 | from transformers import AlbertModel
3 | from math import cos, pi, isclose
4 | from torchtoolbox.optimizer.lr_scheduler import *
5 | 
6 | model = AlbertModel.from_pretrained("albert-base-v2")
7 | 
8 | def run_scheduler(lr_scheduler, steps):
9 |     if len(lr_scheduler.optimizer.param_groups) == 1:
10 |         lrs = []
11 |         for _ in range(steps):
12 |             current_lr = lr_scheduler.optimizer.param_groups[0]['lr']
13 |             lrs.append(current_lr)
14 |             lr_scheduler.step()
15 |     else:
16 |         lrs = [[] for _ in range(len(lr_scheduler.optimizer.param_groups))]
17 |         for _ in range(steps):
18 |             for i, param_group in enumerate(lr_scheduler.optimizer.param_groups):
19 |                 lrs[i].append(param_group['lr'])
20 |             lr_scheduler.step()
21 |     return lrs
22 | 
23 | def test_CosineWarmupLr():
24 |     optimizer = optim.SGD(model.parameters(), lr=3e-5)
25 |     batches_per_epoch = 10
26 |     epochs = 5
27 |     base_lr = 3e-5
28 |     warmup_epochs = 1
29 |     lr_scheduler = CosineWarmupLr(optimizer, batches_per_epoch, epochs,
30 |                                   base_lr=base_lr, warmup_epochs=warmup_epochs)
31 |     lrs = run_scheduler(lr_scheduler, batches_per_epoch * epochs)
32 |     # test warmup
33 |     warmup_steps = batches_per_epoch * warmup_epochs
34 |     lr_increase_per_warmup_step = base_lr / warmup_steps
35 |     assert isclose(lrs[0], lr_increase_per_warmup_step)
36 |     assert isclose(lrs[3], lr_increase_per_warmup_step * 4)
37 |     assert isclose(lrs[9], lr_increase_per_warmup_step * 10)
38 |     assert isclose(lrs[warmup_steps-1], base_lr)
39 |     assert lrs[warmup_steps-1] > lrs[warmup_steps]
40 |     # test cosine decay
41 |     assert isclose(lrs[warmup_steps], lr_scheduler.total_lr_decay * (1+cos(pi * (1/40)))/2)
42 |     assert isclose(lrs[warmup_steps+10], lr_scheduler.total_lr_decay * (1+cos(pi * (11/40)))/2)
43 |     assert isclose(lrs[warmup_steps+19], lr_scheduler.total_lr_decay * (1+cos(pi * (20/40)))/2)
44 |     assert isclose(lrs[-1], 0)
45 |     print("✅ passed test_CosineWarmupLr")
46 | 
47 | def test_get_cosine_warmup_lr_scheduler():
48 |     params = get_layerwise_decay_params_for_bert(model)
49 |     optimizer = optim.SGD(params, lr=3e-5)
50 | 
51 |     batches_per_epoch = 10
52 |     epochs = 5
53 |     warmup_epochs = 1
54 | 
55 |     base_lrs = []
56 |     for param_group in params:
57 |         base_lrs.append(param_group['lr'])
58 | 
59 |     lr_scheduler = get_cosine_warmup_lr_scheduler(optimizer, batches_per_epoch, epochs, warmup_epochs=warmup_epochs)
60 | 
61 |     lrs = run_scheduler(lr_scheduler, batches_per_epoch * epochs)
62 | 
63 |     def test_schedule_for_one_base_lr(base_lr, lrs):
64 |         # test warmup
65 |         warmup_steps = batches_per_epoch * warmup_epochs
66 |         lr_increase_per_warmup_step = base_lr / warmup_steps
67 |         assert isclose(lrs[1], lr_increase_per_warmup_step)
68 |         assert isclose(lrs[4], lr_increase_per_warmup_step * 4)
69 |         assert isclose(lrs[10], lr_increase_per_warmup_step * 10)
70 |         assert isclose(lrs[warmup_steps], base_lr)
71 |         assert lrs[warmup_steps] > lrs[warmup_steps+1]
72 |         # test cosine decay
73 |         total_lr_decay = base_lr - 0
74 |         assert isclose(lrs[warmup_steps+1], total_lr_decay * (1+cos(pi * (1/40)))/2)
75 |         assert isclose(lrs[warmup_steps+10], total_lr_decay * (1+cos(pi * (10/40)))/2)
76 |         assert isclose(lrs[warmup_steps+19], total_lr_decay * (1+cos(pi * (19/40)))/2)
77 |         assert isclose(lrs[warmup_steps+39], total_lr_decay * (1+cos(pi * (39/40)))/2)
78 | 
79 |     # test the schedule for every param group
80 |     for i in range(len(base_lrs)):
81 |         test_schedule_for_one_base_lr(base_lrs[i], lrs[i])
82 |     print("✅ passed test_get_cosine_warmup_lr_scheduler")
83 | 
84 | 
85 | test_CosineWarmupLr()
86 | test_get_cosine_warmup_lr_scheduler()
--------------------------------------------------------------------------------
/tests/test_metric.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author : DevinYang(pistonyang@gmail.com)
3 | 
4 | import torch
5 | import numpy as np
6 | from numpy.testing import assert_allclose
7 | from torchtoolbox.metric import Accuracy, NumericalCost, TopKAccuracy
8 | 
9 | numerical_test_data = np.random.uniform(0, 1, size=(10, ))
10 | # Assume we have a batch size of 10 and 5 classes.
11 | acc_test_label = np.random.randint(0, 5, size=(10, ))
12 | acc_test_pred = np.random.uniform(0, 1, size=(10, 5))
13 | 
14 | 
15 | def get_true_numerical_result(test_data):
16 |     return np.mean(test_data)
17 | 
18 | 
19 | # reference for top-1 accuracy
20 | def get_true_acc(label, pred):
21 |     pred = np.argmax(pred, axis=1)
22 |     acc = (pred == label).mean()
23 |     return acc
24 | 
25 | 
26 | # reference for top-3 accuracy
27 | def get_true_top3(label, pred):
28 |     pred = np.argpartition(pred, -3)[:, -3:]
29 |     num_true_idx = 0
30 |     for l, p in zip(label, pred):
31 |         if l in p:
32 |             num_true_idx += 1
33 |     return num_true_idx / len(label)
34 | 
35 | 
36 | @torch.no_grad()
37 | def test_top1_acc():
38 |     true_acc = get_true_acc(acc_test_label, acc_test_pred)
39 |     top1_acc = Accuracy()
40 |     top1_acc.update(torch.Tensor(acc_test_pred), torch.Tensor(acc_test_label))
41 |     acc = top1_acc.get()
42 |     assert true_acc == acc
43 | 
44 | 
45 | @torch.no_grad()
46 | def test_top_acc():
47 |     top3_true = get_true_top3(acc_test_label, acc_test_pred)
48 |     top3_acc = TopKAccuracy(top=3)
49 |     top3_acc.update(torch.Tensor(acc_test_pred), torch.Tensor(acc_test_label))
50 |     top3 = top3_acc.get()
51 |     assert top3_true == top3
52 | 
53 | 
54 | @torch.no_grad()
55 | def test_numerical_cost():
56 |     true_cost = get_true_numerical_result(numerical_test_data)
57 |     nc = NumericalCost()
58 |     for c in numerical_test_data:
59 |         nc.update(torch.Tensor([
60 |             c,
61 |         ]))
62 |     cost = float(nc.get())
63 |     # NumericalCost accumulates in float32, so it cannot match the float64
64 |     # reference bit-for-bit; compare with a loosened tolerance instead.
65 |     assert_allclose(true_cost, cost, rtol=1e-5)
66 | 
--------------------------------------------------------------------------------
/tests/test_nn.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author : DevinYang(pistonyang@gmail.com)
3 | from torchtoolbox.nn.functional import class_balanced_weight
4 | from torchtoolbox.nn import *
5 | import torch
6 | from torch import nn
7 | import numpy as np
8 | 
9 | 
10 | @torch.no_grad()
11 | def test_lsloss():
12 |     pred = torch.rand(3, 10)
13 |     label = torch.randint(0, 10, size=(3, ))
14 |     Loss = LabelSmoothingLoss(10, 0.1)
15 | 
16 |     Loss1 = nn.CrossEntropyLoss()
17 | 
18 |     cost = Loss(pred, label)
19 |     cost1 = Loss1(pred, label)
20 | 
21 |     assert cost.shape == cost1.shape
22 | 
23 | 
24 | @torch.no_grad()
25 | def test_logits_loss():
26 |     pred = torch.rand(3, 10)
27 |     label = torch.randint(0, 10, size=(3, ))
28 |     weight = class_balanced_weight(0.9999, np.random.randint(0, 100, size=(10, )).tolist())
29 | 
30 |     Loss = SigmoidCrossEntropy(classes=10, weight=weight)
31 |     Loss1 = FocalLoss(classes=10, weight=weight, gamma=0.5)
32 | 
33 |     cost = Loss(pred, label)
34 |     cost1 = Loss1(pred, label)
35 | 
36 |     # print(cost, cost1)
37 | 
38 | 
39 | class n_to_n(nn.Module):
40 |     def __init__(self):
41 |         super().__init__()
42 |         self.conv1 = nn.Conv2d(3, 3, 1, 1, bias=False)
43 |         self.conv2 = nn.Conv2d(3, 3, 1, 1, bias=False)
44 | 
45 |     def forward(self, x1, x2):
46 |         y1 = self.conv1(x1)
47 |         y2 = self.conv2(x2)
48 |         return y1, y2
49 | 
50 | 
51 | class n_to_one(nn.Module):
52 |     def __init__(self):
53 |         super().__init__()
54 |         self.conv1 = nn.Conv2d(3, 3, 1, 1, bias=False)
55 |         self.conv2 = nn.Conv2d(3, 3, 1, 1, bias=False)
56 | 
57 |     def forward(self, x1, x2):
58 |         y1 = self.conv1(x1)
59 |         y2 = self.conv2(x2)
60 |         return y1 + y2
61 | 
62 | 
63 | class one_to_n(nn.Module):
64 |     def __init__(self):
65 |         super().__init__()
66 |         self.conv1 = nn.Conv2d(3, 3, 1, 1, bias=False)
67 |         self.conv2 = nn.Conv2d(3, 3, 1, 1, bias=False)
68 | 
69 |     def forward(self, x):
70 |         y1 = self.conv1(x)
71 |         y2 = self.conv2(x)
72 |         return y1, y2
73 | 
74 | 
75 | @torch.no_grad()
76 | def test_ad_sequential():
77 |     seq = AdaptiveSequential(one_to_n(), n_to_n(), n_to_one())
78 |     td = torch.rand(1, 3, 32, 32)
79 |     out = seq(td)
80 | 
81 |     assert out.size() == torch.Size([1, 3, 32, 32])
82 | 
83 | 
84 | @torch.no_grad()
85 | def test_switch_norm():
86 |     td2 = torch.rand(1, 3, 32, 32)
87 |     td3 = torch.rand(1, 3, 32, 32, 3)
88 |     norm2 = SwitchNorm2d(3)
89 |     norm3 = SwitchNorm3d(3)
90 |     out2 = norm2(td2)
91 |     out3 = norm3(td3)
92 | 
93 |     assert out2.size() == td2.size() and out3.size() == td3.size()
94 | 
95 | 
96 | def test_swish():
97 |     td = torch.rand(1, 16, 32, 32)
98 |     swish = Swish(beta=10.0)
99 |     swish(td)
--------------------------------------------------------------------------------
/tests/test_obj.py:
--------------------------------------------------------------------------------
1 | # import sys
2 | # sys.path.insert(0, './')
3 | 
4 | from torchtoolbox.objects import BBox
5 | import numpy as np
6 | 
7 | 
8 | def test_bbox():
9 |     test_data = np.random.rand(100, 4)
10 |     bbox = BBox(test_data, mode='XYWH', category=['1' for _ in range(100)], name='test_bbox')
11 |     len(bbox)
12 |     for box, cat in bbox:
13 |         pass
14 |     ids = [1, 2, 3]
15 |     bbox[ids]
16 | 
17 | 
18 | if __name__ == "__main__":
19 |     test_bbox()
20 | 
--------------------------------------------------------------------------------
/tests/test_op.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author : DevinYang(pistonyang@gmail.com)
3 | 
4 | from torch.autograd import gradcheck
5 | from torchtoolbox.nn.operators import *
6 | import torch
7 | 
8 | 
9 | def test_swish():
10 |     swish = SwishOP.apply
11 |     td = torch.rand(size=(2, 2), dtype=torch.double, requires_grad=True)
12 |     test = gradcheck(swish, td, eps=1e-6, atol=1e-4)
13 |     assert test
--------------------------------------------------------------------------------
/tests/test_summary.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author : DevinYang(pistonyang@gmail.com)
3 | 
4 | import torch
5 | from torchtoolbox.tools import summary
6 | from torchvision.models.resnet import resnet50
7 | from torchvision.models.mobilenet import mobilenet_v2
8 | 
9 | model1 = resnet50()
10 | model2 = mobilenet_v2()
11 | 
12 | 
13 | def test_summary():
14 |     summary(model1, torch.rand((1, 3, 224, 224)), True)
15 |     print(summary(model2, torch.rand((1, 3, 224, 224))))
--------------------------------------------------------------------------------
/tests/test_transform.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author : DevinYang(pistonyang@gmail.com)
3 | 
4 | import numpy as np
5 | from torchtoolbox.transform import *
6 | 
7 | trans = Compose([
8 |     # CV2 transforms
9 |     Resize(500),
10 |     CenterCrop(300),
11 |     Pad(4),
12 |     RandomCrop(255, 255),
13 |     RandomHorizontalFlip(p=1),
14 |     RandomVerticalFlip(p=1),
15 |     RandomResizedCrop(100),
16 |     ColorJitter(0.2, 0.2, 0.2),
17 |     RandomRotation(15),
18 |     RandomAffine(0),
19 |     RandomPerspective(p=1),
20 |     RandomGaussianNoise(p=1),
21 |     RandomPoissonNoise(p=1),
22 |     RandomSPNoise(p=1),
23 |     Cutout(p=1),
24 |     ToTensor(),
25 |     # Tensor transforms
26 |     Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
27 |     RandomErasing(p=1),
28 | ])
29 | 
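30 | 
31 | # A minimal end-to-end sanity check: the composed pipeline is a plain callable on
32 | # a uint8 HWC image. The expected (3, 100, 100) output shape is an assumption that
33 | # follows from RandomResizedCrop(100) and the size-preserving transforms after it.
34 | def test_transform_output_shape():
35 |     img = np.random.randint(0, 255, size=(400, 400, 3), dtype='uint8')
36 |     out = trans(img)
37 |     assert tuple(out.shape) == (3, 100, 100)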
38 | 
39 | 
40 | def _gen_fake_img(size=None):
41 |     if size is None:
42 |         size = (400, 400, 3)
43 |     return np.random.randint(0, 255, size=size, dtype='uint8')
44 | 
45 | 
46 | def test_transform():
47 |     img = _gen_fake_img()
48 |     trans(img)
--------------------------------------------------------------------------------
/torchtoolbox/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Author: pistonyang@gmail.com
3 | 
4 | VERSION = '0.1.8.3'
5 | 
--------------------------------------------------------------------------------
/torchtoolbox/data/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author : DevinYang(pistonyang@gmail.com)
3 | from .utils import *
4 | from .lmdb_dataset import *
5 | from .datasets import *
6 | from .dataprefetcher import DataPreFetcher
7 | from .dynamic_data_provider import *
8 | from .sampler import *
9 | 
--------------------------------------------------------------------------------
/torchtoolbox/data/dataprefetcher.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author : DevinYang(pistonyang@gmail.com)
3 | import torch
4 | 
5 | 
6 | class DataPreFetcher(object):
7 |     def __init__(self, loader):
8 |         self.loader = iter(loader)
9 |         self.stream = torch.cuda.Stream()
10 |         # init
11 |         self.next_data = None
12 |         self.preload()
13 | 
14 |     def preload(self):
15 |         try:
16 |             self.next_data = next(self.loader)
17 |         except StopIteration:
18 |             self.next_data = None
19 |             # without this early return, the code below would call .cuda() on None
20 |             return
21 |         with torch.cuda.stream(self.stream):
22 |             if not isinstance(self.next_data, (tuple, list)):
23 |                 self.next_data = self.next_data.cuda(non_blocking=True)
24 |             else:
25 |                 self.next_data = tuple([d.cuda(non_blocking=True) for d in self.next_data])
26 | 
27 |     def __next__(self):
28 |         torch.cuda.current_stream().wait_stream(self.stream)
29 |         data = self.next_data
30 |         if data is None:
31 |             raise StopIteration
32 |         self.preload()
33 |         return data
34 | 
35 |     def __iter__(self):
36 |         return self
37 | 
38 |     def __len__(self):
39 |         return len(self.loader)
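40 | 
41 | 
42 | # A minimal usage sketch (assumes a CUDA device, since batches are moved to the
43 | # GPU on a side stream; `loader` is any DataLoader yielding tensors or tuples of
44 | # tensors):
45 | #
46 | #   for data, target in DataPreFetcher(loader):
47 | #       ...  # `data` and `target` are already on the GPU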
--------------------------------------------------------------------------------
/torchtoolbox/data/datasets.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author : DevinYang(pistonyang@gmail.com)
3 | __all__ = ['NonLabelDataset', 'FeaturePairDataset', 'FeaturePairBin']
4 | 
5 | import glob
6 | import os
7 | import pickle
8 | import numpy as np
9 | from torch.utils.data import Dataset
10 | from .utils import decode_img_from_buf, cv_loader, pil_loader
11 | 
12 | 
13 | class NonLabelDataset(Dataset):
14 |     """This is used for label-free training like GAN, VAE...
15 | 
16 |     root/xxx.jpg
17 |     root/xxy.jpg
18 |     root/xxz.jpg
19 | 
20 |     Args:
21 |         root_dir (str): root dir of data.
22 |         transform (callable): transform func.
23 |     """
24 |     def __init__(self, root_dir, transform=None, loader=pil_loader):
25 |         self.transform = transform
26 |         self.items = sorted(os.listdir(root_dir))
27 |         self.items = [os.path.join(root_dir, f) for f in self.items]
28 |         self.loader = loader
29 | 
30 |     def __len__(self):
31 |         return len(self.items)
32 | 
33 |     def __getitem__(self, item):
34 |         img = self.loader(self.items[item])
35 |         if self.transform:
36 |             img = self.transform(img)
37 |         return img
38 | 
39 | 
40 | class FeaturePairDataset(Dataset):
41 |     """File structure should be like this.
42 | 
43 |     root/
44 |         xxx/
45 |             aaa.jpg
46 |             bbb.jpg
47 |         yyy/
48 |             ...
49 |         zzz/
50 |             ...
51 |         is_same(.txt)
52 | 
53 |     is_same file structure should be like this.
54 | 
55 |     is_same.txt
56 |         xxx 1
57 |         yyy 0
58 |         zzz 0
59 |     """
60 |     def __init__(self, root, is_same_file=None, transform=None, loader=pil_loader):
61 |         self.root = root
62 |         is_same = os.path.join(root, 'is_same.txt' if is_same_file is None else is_same_file)
63 |         is_same_list = []
64 |         with open(is_same) as f:
65 |             for line in f.readlines():
66 |                 is_same_list.append(line.replace('\n', '').split(' '))
67 |         self.file_list = is_same_list
68 |         self.transform = transform
69 |         self.loader = loader
70 |         self.pre_check()
71 | 
72 |     def pre_check(self):
73 |         self.file_list = [[glob.glob(os.path.join(self.root, dir_name, '*.jpg')),
74 |                            int(is_same)] for dir_name, is_same in self.file_list]
75 |         for files, _ in self.file_list:
76 |             assert len(files) == 2
77 | 
78 |     def __len__(self):
79 |         return len(self.file_list)
80 | 
81 |     def __getitem__(self, item):
82 |         pair, is_same = self.file_list[item]
83 |         imgs = list(map(self.loader, pair))
84 |         if self.transform:
85 |             imgs = list(map(self.transform, imgs))
86 |         return imgs, is_same
87 | 
88 | 
89 | class FeaturePairBin(Dataset):
90 |     """A dataset wrapping over a pickle serialized (.bin) file provided by InsightFace Repo.
91 | 
92 |     Parameters
93 |     ----------
94 |     name : str. Name of val dataset.
95 |     root : str. Path to face folder.
96 |     transform : callable, default None
97 |         A function that takes data and transforms them.
98 | 
99 |     """
100 |     def __init__(self, name, root, transform=None, backend='cv2'):
101 |         self._transform = transform
102 |         self.name = name
103 |         with open(os.path.join(root, "{}.bin".format(name)), 'rb') as f:
104 |             self.bins, self.issame_list = pickle.load(f, encoding='iso-8859-1')
105 | 
106 |         self._do_encode = not isinstance(self.bins[0], np.ndarray)
107 |         self.backend = backend
108 | 
109 |     def _decode(self, im):
110 |         if self._do_encode:
111 |             im = im.encode("iso-8859-1")
112 |         im = decode_img_from_buf(im, self.backend)
113 |         return im
114 | 
115 |     def __getitem__(self, idx):
116 |         img0 = self._decode(self.bins[2 * idx])
117 |         img1 = self._decode(self.bins[2 * idx + 1])
118 | 
119 |         issame = 1 if self.issame_list[idx] else 0
120 | 
121 |         if self._transform is not None:
122 |             img0 = self._transform(img0)
123 |             img1 = self._transform(img1)
124 | 
125 |         return (img0, img1), issame
126 | 
127 |     def __len__(self):
128 |         return len(self.issame_list)
129 | 
--------------------------------------------------------------------------------
/torchtoolbox/data/dynamic_data_provider.py:
--------------------------------------------------------------------------------
1 | __all__ = ['DynamicBatchSampler', 'DynamicSizeImageFolder', 'DistributedDynamicBatchSampler']
2 | import torch
3 | from torch import distributed
4 | from torch.utils.data import BatchSampler, Dataset
5 | from torchvision.datasets import ImageFolder
6 | 
7 | 
8 | class DynamicBatchSampler(BatchSampler):
9 |     """DynamicBatchSampler
10 | 
11 |     Args:
12 |         info_generate_fn (callable): a callable that attaches extra info to each batch of samples.
13 |     """
14 |     def __init__(self, sampler, batch_size: int, drop_last: bool, info_generate_fn=None) -> None:
15 |         super().__init__(sampler, batch_size, drop_last)
16 |         self.info_generate_fn = info_generate_fn if info_generate_fn is not None else lambda: None
17 | 
18 |     def set_batch_size(self, batch_size: int):
19 |         if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \
20 |                 batch_size <= 0:
21 |             raise ValueError("batch_size should be a positive integer value, " "but got batch_size={}".format(batch_size))
22 |         self.batch_size = batch_size
23 | 
24 |     def set_info_generate_fn(self, info_generate_fn):
25 |         self.info_generate_fn = info_generate_fn
26 | 
27 |     def __iter__(self):
28 |         batch = []
29 |         info = self.info_generate_fn()
30 |         for idx in self.sampler:
31 |             batch.append([idx, info])
32 |             if len(batch) == self.batch_size:
33 |                 yield batch
34 |                 batch = []
35 |                 info = self.info_generate_fn()
36 |         if len(batch) > 0 and not self.drop_last:
37 |             yield batch
38 | 
39 | 
40 | class DistributedDynamicBatchSampler(BatchSampler):
41 |     """DistributedDynamicBatchSampler to sync the batch info across all ranks.
42 | 
43 |     Args:
44 |         info_generate_fn (callable): a callable that attaches extra info to each batch of samples.
45 |         main_rank: the rank that generates and broadcasts the info.
46 |         rank: current rank.
47 | 
48 |     The result of info_generate_fn must be convertible to a tensor; currently only integers are supported.
49 |     """
50 |     def __init__(self, sampler, batch_size: int, drop_last: bool, main_rank: int, rank, info_generate_fn=None) -> None:
51 |         super().__init__(sampler, batch_size, drop_last)
52 |         self.info_generate_fn = info_generate_fn if info_generate_fn is not None else lambda: None
53 |         self.main_rank = main_rank
54 |         self.rank = rank
55 |         self.epoch_info = None
56 |         self.reset_and_sync_info()
57 | 
58 |     def reset_and_sync_info(self):
59 |         epoch_info = [self.info_generate_fn() for _ in range(len(self) + 1)]
60 |         epoch_info = torch.as_tensor(epoch_info, dtype=torch.int, device=torch.device('cuda', self.rank))
61 |         distributed.broadcast(epoch_info, self.main_rank)
62 |         self.epoch_info = epoch_info.cpu().numpy().tolist()
63 | 
64 |     def set_batch_size(self, batch_size: int):
65 |         if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \
66 |                 batch_size <= 0:
67 |             raise ValueError("batch_size should be a positive integer value, " "but got batch_size={}".format(batch_size))
68 |         self.batch_size = batch_size
69 | 
70 |     def set_info_generate_fn(self, info_generate_fn):
71 |         self.info_generate_fn = info_generate_fn
72 | 
73 |     def __iter__(self):
74 |         batch = []
75 |         info = self.epoch_info.pop(0)
76 |         for idx in self.sampler:
77 |             batch.append([idx, info])
78 |             if len(batch) == self.batch_size:
79 |                 yield batch
80 |                 batch = []
81 |                 info = self.epoch_info.pop(0)
82 |         if len(batch) > 0 and not self.drop_last:
83 |             yield batch
84 | 
85 | 
86 | class DynamicSizeImageFolder(ImageFolder):
87 |     def __getitem__(self, index_info):
88 |         """
89 |         Args:
90 |             index_info (tuple): (index, size) pair; `index` selects the sample.
91 | 
92 |         Returns:
93 |             tuple: (sample, target) where target is class_index of the target class.
94 |         """
95 |         index, size = index_info
96 |         path, target = self.samples[index]
97 |         sample = self.loader(path)
98 |         if self.transform is not None:
99 |             sample = self.transform(sample, size)
100 |         if self.target_transform is not None:
101 |             target = self.target_transform(target)
102 | 
103 |         return sample, target
104 | 
105 | 
106 | class DynamicSubset(Dataset):
107 |     r"""
108 |     Subset of a dynamic dataset at specified indices.
109 | 
110 |     Arguments:
111 |         dataset (Dataset): The whole Dataset
112 |         indices (sequence): Indices in the whole set selected for subset
113 |     """
114 |     def __init__(self, dataset: Dataset, indices) -> None:
115 |         self.dataset = dataset
116 |         self.indices = indices
117 | 
118 |     def __getitem__(self, index_info):
119 |         index, size = index_info
120 |         index = (self.indices[index], size)
121 |         return self.dataset[index]
122 | 
123 |     def __len__(self):
124 |         return len(self.indices)
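125 | 
126 | 
127 | # A minimal wiring sketch (names and sizes here are illustrative assumptions, not
128 | # part of this module): pair the batch sampler with the dataset through
129 | # DataLoader's `batch_sampler` argument, so each index reaches __getitem__
130 | # together with the per-batch info (e.g. a randomly chosen input size):
131 | #
132 | #   sampler = DynamicBatchSampler(RandomSampler(dataset), batch_size=32, drop_last=False,
133 | #                                 info_generate_fn=lambda: random.choice([160, 192, 224]))
134 | #   loader = DataLoader(DynamicSizeImageFolder(root, transform=resize_fn), batch_sampler=sampler)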
--------------------------------------------------------------------------------
/torchtoolbox/data/lmdb_dataset.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author : DevinYang(pistonyang@gmail.com)
3 | __all__ = ['ImageLMDB']
4 | 
5 | import os
6 | import lmdb
7 | from ..tools.convert_lmdb import get_key, load_pyarrow
8 | from .utils import decode_img_from_buf
9 | from torch.utils.data import Dataset
10 | 
11 | 
12 | class ImageLMDB(Dataset):
13 |     """
14 |     LMDB format for image folder.
15 |     """
16 |     def __init__(self, db_path, db_name, transform=None, target_transform=None, backend='cv2'):
17 |         self.env = lmdb.open(os.path.join(db_path, '{}.lmdb'.format(db_name)),
18 |                              subdir=False,
19 |                              readonly=True,
20 |                              lock=False,
21 |                              readahead=False,
22 |                              meminit=False)
23 |         with self.env.begin() as txn:
24 |             self.length = load_pyarrow(txn.get(b'__len__'))
25 |             try:
26 |                 self.classes = load_pyarrow(txn.get(b'classes'))
27 |                 self.class_to_idx = load_pyarrow(txn.get(b'class_to_idx'))
28 |             except AssertionError:
29 |                 pass
30 | 
31 |         self.map_list = [get_key(i) for i in range(self.length)]
32 |         self.transform = transform
33 |         self.target_transform = target_transform
34 |         self.backend = backend
35 | 
36 |     def __len__(self):
37 |         return self.length
38 | 
39 |     def __getitem__(self, item):
40 |         with self.env.begin() as txn:
41 |             byteflow = txn.get(self.map_list[item])
42 |         unpacked = load_pyarrow(byteflow)
43 |         imgbuf, target = unpacked
44 |         img = decode_img_from_buf(imgbuf, self.backend)
45 | 
46 |         if self.transform is not None:
47 |             img = self.transform(img)
48 |         if self.target_transform is not None:
49 |             target = self.target_transform(target)
50 | 
51 |         return img, target
52 | 
--------------------------------------------------------------------------------
/torchtoolbox/data/sampler.py:
--------------------------------------------------------------------------------
1 | __all__ = ['RepeatedAugmentSampler']
2 | 
3 | import math
4 | 
5 | import torch
6 | from torch import distributed as dist
7 | from torch.utils.data import Sampler
8 | 
9 | 
10 | class RepeatedAugmentSampler(Sampler):
11 |     """Sampler that restricts data loading to a subset of the dataset for distributed training,
12 |     with repeated augmentation.
13 |     It ensures that each augmented version of a sample will be visible to a
14 |     different process (GPU).
15 |     Heavily based on torch.utils.data.DistributedSampler
16 |     """
17 |     def __init__(self, dataset, m=3, num_replicas=None, rank=None, shuffle=True):
18 |         if num_replicas is None:
19 |             if not dist.is_available():
20 |                 raise RuntimeError("Requires distributed package to be available")
21 |             num_replicas = dist.get_world_size()
22 |         if rank is None:
23 |             if not dist.is_available():
24 |                 raise RuntimeError("Requires distributed package to be available")
25 |             rank = dist.get_rank()
26 |         self.dataset = dataset
27 |         self.num_replicas = num_replicas
28 |         self.rank = rank
29 |         self.epoch = 0
30 |         self.m = m
31 |         self.num_samples = int(math.ceil(len(self.dataset) * m / self.num_replicas))
32 |         self.total_size = self.num_samples * self.num_replicas
33 |         # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
34 |         self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
35 |         self.shuffle = shuffle
36 | 
37 |     def __iter__(self):
38 |         # deterministically shuffle based on epoch
39 |         g = torch.Generator()
40 |         g.manual_seed(self.epoch)
41 |         if self.shuffle:
42 |             indices = torch.randperm(len(self.dataset), generator=g).tolist()
43 |         else:
44 |             indices = list(range(len(self.dataset)))
45 | 
46 |         # add extra samples to make it evenly divisible
47 |         indices = [ele for ele in indices for i in range(self.m)]
48 |         indices += indices[:(self.total_size - len(indices))]
49 |         assert len(indices) == self.total_size
50 | 
51 |         # subsample
52 |         indices = indices[self.rank:self.total_size:self.num_replicas]
53 |         assert len(indices) == self.num_samples
54 | 
55 |         return iter(indices[:self.num_selected_samples])
56 | 
57 |     def __len__(self):
58 |         return self.num_selected_samples
59 | 
60 |     def set_epoch(self, epoch):
61 |         self.epoch = epoch
62 | 
--------------------------------------------------------------------------------
/torchtoolbox/data/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | __all__ = ['decode_img_from_buf', 'pil_loader', 'cv_loader']
3 | 
4 | import cv2
5 | import six
6 | import numpy as np
7 | from PIL import Image
8 | 
9 | 
10 | def decode_img_from_buf(buf, backend='cv2'):
11 |     if backend == 'pil':
12 |         buf = six.BytesIO(buf)
13 |         img = Image.open(buf).convert('RGB')
14 |     elif backend == 'cv2':
15 |         buf = np.frombuffer(buf, np.uint8)
16 |         img = cv2.imdecode(buf, 1)[..., ::-1]
17 |     else:
18 |         raise NotImplementedError
19 |     return img
20 | 
21 | 
22 | def pil_loader(path):
23 |     # open path as file to avoid ResourceWarning
24 |     # (https://github.com/python-pillow/Pillow/issues/835)
25 |     with open(path, 'rb') as f:
26 |         img = Image.open(f)
27 |         return img.convert('RGB')
28 | 
29 | 
30 | def cv_loader(path):
31 |     return cv2.imread(path)[..., ::-1]
32 | 
--------------------------------------------------------------------------------
/torchtoolbox/io/__init__.py:
--------------------------------------------------------------------------------
1 | from .structured_data_io import *
2 | 
--------------------------------------------------------------------------------
/torchtoolbox/io/structured_data_io.py:
--------------------------------------------------------------------------------
1 | __all__ = ['load_data', 'save_data']
2 | 
3 | import json
4 | import yaml
5 | from ..tools import DotDict
6 | 
7 | 
8 | def load_data(file_path: str, format: str, to_dot_dict: bool = False, load_kwargs: dict = dict(mode='r'), **kwargs) -> dict:
9 |     """Use this func to load file easily.
10 | 
11 |     Args:
12 |         file_path (str): full file path
13 |         format (str): parse engine
14 |         to_dot_dict (bool, optional): whether convert to dot_dict. Defaults to False.
15 |         load_kwargs (dict, optional): keyword arguments passed through to `open()`. Defaults to dict(mode='r').
16 | 
17 |     Raises:
18 |         NotImplementedError: if `format` is not a supported parse engine.
19 | 
20 |     Returns:
21 |         dict: the parsed data (a DotDict if `to_dot_dict` is True).
22 |     """
23 |     assert format in ('json', 'yaml'), 'Now only json and yaml format are supported.'
24 |     with open(file_path, **load_kwargs) as f:
25 |         if format == 'json':
26 |             file = json.load(f, **kwargs)
27 |         elif format == 'yaml':
28 |             file = yaml.load(f, Loader=yaml.SafeLoader)
29 |         else:
30 |             raise NotImplementedError
31 |     if to_dot_dict:
32 |         file = DotDict(file)
33 |     return file
34 | 
35 | 
36 | def save_data(data, file_path, format, load_kwargs: dict = dict(mode='w'), **kwargs) -> None:
37 |     assert format in ('json', 'yaml'), 'Now only json and yaml format are supported.'
38 |     with open(file_path, **load_kwargs) as f:
39 |         if format == 'json':
40 |             json.dump(data, f, indent=2, **kwargs)
41 |         elif format == 'yaml':
42 |             yaml.dump(data, f, Dumper=yaml.SafeDumper, **kwargs)
43 |         else:
44 |             raise NotImplementedError
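45 | 
46 | 
47 | # Usage sketch (file names are hypothetical):
48 | #
49 | #   config = load_data('config.yaml', format='yaml', to_dot_dict=True)
50 | #   save_data(config, 'config.json', format='json')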
--------------------------------------------------------------------------------
/torchtoolbox/metric/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Author: pistonyang@gmail.com
3 | 
4 | from .metric import *
5 | from .feature_verification import *
6 | from .map import MeanAveragePrecision
7 | 
--------------------------------------------------------------------------------
/torchtoolbox/metric/feature_verification.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | __all__ = ['FeatureVerification']
3 | 
4 | import numpy as np
5 | from .metric import Metric
6 | from ..tools.utils import to_numpy
7 | from sklearn.model_selection import KFold
8 | from scipy import interpolate
9 | 
10 | 
11 | class FeatureVerification(Metric):
12 |     """ Compute confusion matrix of 1:1 problem in feature verification or other fields.
13 |     Use update() to collect the outputs and compute distance in each batch, then use get() to compute the
14 |     confusion matrix and accuracy of the val dataset.
15 | 
16 |     Parameters
17 |     ----------
18 |     nfolds: int, default is 10
19 | 
20 |     thresholds: ndarray, default is None.
21 |         Use np.arange to generate thresholds. If thresholds=None, np.arange(0, 2, 0.01) will be used for
22 |         euclidean distance.
23 | 
24 |     far_target: float, default is 1e-3.
25 |         This is used to get the verification accuracy at the expected FAR.
26 | 
27 |     dist_type: str, default is euclidean.
28 |         Options are {euclidean, cosine}, for euclidean distance and cosine similarity respectively.
29 |         Here for cosine distance, we use `1 - cosine` as the final distances.
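30 | 
31 |     Example (an illustrative sketch; `emb0`/`emb1` are feature batches and
32 |     `labels` marks genuine pairs)::
33 | 
34 |         fv = FeatureVerification(nfolds=10)
35 |         fv.update(emb0, emb1, labels)
36 |         tpr, fpr, acc, threshold, val, val_std, far, acc_std = fv.get()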
30 | 31 | """ 32 | def __init__(self, nfolds=10, far_target=1e-3, thresholds=None, dist_type='euclidean', **kwargs): 33 | super(FeatureVerification, self).__init__(**kwargs) 34 | assert dist_type in ('euclidean', 'cosine') 35 | self.nfolds = nfolds 36 | self.far_target = far_target 37 | default_thresholds = np.arange(0, 2, 0.01) if dist_type == 'euclidean' else np.arange(0, 1, 0.005) 38 | self.thresholds = default_thresholds if thresholds is None else thresholds 39 | self.dist_type = dist_type 40 | 41 | self.dists = [] 42 | self.issame = [] 43 | 44 | def reset(self): 45 | self.dists = [] 46 | self.issame = [] 47 | 48 | def update(self, embeddings0, embeddings1, labels): 49 | embeddings0, embeddings1, labels = map(to_numpy, (embeddings0, embeddings1, labels)) 50 | if self.dist_type == 'euclidean': 51 | diff = np.subtract(embeddings0, embeddings1) 52 | dists = np.sqrt(np.sum(np.square(diff), 1)) 53 | else: 54 | dists = 1 - np.sum(np.multiply(embeddings0, embeddings1), 55 | axis=1) / (np.linalg.norm(embeddings0, axis=1) * np.linalg.norm(embeddings1, axis=1)) 56 | 57 | self.dists.extend(dists) 58 | self.issame.extend(labels) 59 | 60 | def get(self): 61 | tpr, fpr, accuracy, threshold = calculate_roc(self.thresholds, np.asarray(self.dists), np.asarray(self.issame), 62 | self.nfolds) 63 | 64 | val, val_std, far = calculate_val(self.thresholds, np.asarray(self.dists), np.asarray(self.issame), self.far_target, 65 | self.nfolds) 66 | 67 | acc, acc_std = np.mean(accuracy), np.std(accuracy) 68 | threshold = (1 - threshold) if self.dist_type == 'cosine' else threshold 69 | return tpr, fpr, acc, threshold, val, val_std, far, acc_std 70 | 71 | 72 | # code below is modified from project and 73 | # 74 | class LFold: 75 | def __init__(self, n_splits=2, shuffle=False): 76 | self.n_splits = n_splits 77 | if self.n_splits > 1: 78 | self.k_fold = KFold(n_splits=n_splits, shuffle=shuffle) 79 | 80 | def split(self, indices): 81 | if self.n_splits > 1: 82 | return self.k_fold.split(indices) 83 | else: 84 | return [(indices, indices)] 85 | 86 | 87 | def calculate_roc(thresholds, dist, actual_issame, nrof_folds=10): 88 | assert len(dist) == len(actual_issame) 89 | 90 | nrof_pairs = len(dist) 91 | nrof_thresholds = len(thresholds) 92 | k_fold = LFold(n_splits=nrof_folds, shuffle=False) 93 | 94 | tprs = np.zeros((nrof_folds, nrof_thresholds)) 95 | fprs = np.zeros((nrof_folds, nrof_thresholds)) 96 | avg_thresholds = [] 97 | accuracy = np.zeros((nrof_folds, )) 98 | indices = np.arange(nrof_pairs) 99 | dist = np.array(dist) 100 | 101 | for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)): 102 | acc_train = np.zeros((nrof_thresholds, )) 103 | for threshold_idx, threshold in enumerate(thresholds): 104 | _, _, acc_train[threshold_idx] = calculate_accuracy(threshold, dist[train_set], actual_issame[train_set]) 105 | best_threshold_index = np.argmax(acc_train) 106 | for threshold_idx, threshold in enumerate(thresholds): 107 | tprs[fold_idx, threshold_idx], \ 108 | fprs[fold_idx, threshold_idx], _ = calculate_accuracy(threshold, dist[test_set], 109 | actual_issame[test_set]) 110 | avg_thresholds.append(thresholds[best_threshold_index]) 111 | _, _, accuracy[fold_idx] = calculate_accuracy(thresholds[best_threshold_index], dist[test_set], actual_issame[test_set]) 112 | avg_thresholds = np.mean(avg_thresholds) 113 | tpr = np.mean(tprs, 0) 114 | fpr = np.mean(fprs, 0) 115 | return tpr, fpr, accuracy, avg_thresholds 116 | 117 | 118 | def calculate_accuracy(threshold, dist, actual_issame): 119 | predict_issame = 
np.less(dist, threshold) 120 | tp = np.sum(np.logical_and(predict_issame, actual_issame)) 121 | fp = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame))) 122 | tn = np.sum(np.logical_and(np.logical_not(predict_issame), np.logical_not(actual_issame))) 123 | fn = np.sum(np.logical_and(np.logical_not(predict_issame), actual_issame)) 124 | 125 | tpr = 0 if (tp + fn == 0) else float(tp) / float(tp + fn) 126 | fpr = 0 if (fp + tn == 0) else float(fp) / float(fp + tn) 127 | acc = float(tp + tn) / dist.size 128 | return tpr, fpr, acc 129 | 130 | 131 | def calculate_val(thresholds, dist, actual_issame, far_target, nrof_folds=10): 132 | assert len(dist) == len(actual_issame), "Shape of predicts and labels mismatch!" 133 | 134 | nrof_pairs = len(dist) 135 | nrof_thresholds = len(thresholds) 136 | k_fold = LFold(n_splits=nrof_folds, shuffle=False) 137 | 138 | val = np.zeros(nrof_folds) 139 | far = np.zeros(nrof_folds) 140 | indices = np.arange(nrof_pairs) 141 | dist = np.array(dist) 142 | 143 | for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)): 144 | # Find the threshold that gives FAR = far_target 145 | far_train = np.zeros(nrof_thresholds) 146 | for threshold_idx, threshold in enumerate(thresholds): 147 | _, far_train[threshold_idx] = calculate_val_far(threshold, dist[train_set], actual_issame[train_set]) 148 | 149 | if np.max(far_train) >= far_target: 150 | f = interpolate.interp1d(far_train, thresholds, kind='slinear') 151 | threshold = f(far_target) 152 | else: 153 | threshold = 0.0 154 | val[fold_idx], far[fold_idx] = calculate_val_far(threshold, dist[test_set], actual_issame[test_set]) 155 | 156 | val_mean = np.mean(val) 157 | val_std = np.std(val) 158 | far_mean = np.mean(far) 159 | return val_mean, val_std, far_mean 160 | 161 | 162 | def calculate_val_far(threshold, dist, actual_issame): 163 | predict_issame = np.less(dist, threshold) 164 | true_accept = np.sum(np.logical_and(predict_issame, actual_issame)) 165 | false_accept = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame))) 166 | n_same = np.sum(actual_issame) 167 | n_diff = np.sum(np.logical_not(actual_issame)) 168 | 169 | val = float(true_accept) / float(n_same) 170 | far = float(false_accept) / float(n_diff) 171 | return val, far 172 | -------------------------------------------------------------------------------- /torchtoolbox/metric/map.py: -------------------------------------------------------------------------------- 1 | from torch.utils.tensorboard import SummaryWriter 2 | from typing import List, Union 3 | from prettytable import PrettyTable 4 | import numpy as np 5 | 6 | from .metric import Metric 7 | from ..objects import BBox 8 | from ..tools import to_list, get_value_from_dicts 9 | 10 | 11 | class MeanAveragePrecision(Metric): 12 | def __init__(self, 13 | name: str, 14 | iou_threshold=np.arange(50, 100, 5), 15 | iou_interested=(50, 75), 16 | object_type='bbox', 17 | writer: Union[SummaryWriter, None] = None): 18 | super().__init__(name=name, writer=writer) 19 | self.iou_threshold = iou_threshold / 100 20 | self.iou_interested = [iou / 100 for iou in sorted(iou_interested)] 21 | self.val_type = object_type 22 | 23 | self.coll = {iou: {} for iou in self.iou_threshold} 24 | 25 | def calculate_iou(self, predict: np.ndarray, gt: np.ndarray): 26 | """calculate one to one image iou on category. 27 | 28 | Args: 29 | predict (BBox): predict bbox. 30 | gt (BBox): gt bbox. 31 | category (str): category. 
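Returns:
    np.ndarray: pairwise IoU matrix of shape (num_pred, num_gt).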
32 | """ 33 | 34 | pred_bbox_iou = [] 35 | for pcb in predict: 36 | inter_xmin = np.maximum(pcb[0], gt[:, 0]) 37 | inter_ymin = np.maximum(pcb[1], gt[:, 1]) 38 | inter_xmax = np.minimum(pcb[2], gt[:, 2]) 39 | inter_ymax = np.minimum(pcb[3], gt[:, 3]) 40 | inter_width = np.maximum(inter_xmax - inter_xmin, 0) 41 | inter_height = np.maximum(inter_ymax - inter_ymin, 0) 42 | 43 | inter_area = inter_width * inter_height 44 | union_area = (pcb[2] - pcb[0]) * (pcb[3] - pcb[1]) + (gt[:, 2] - gt[:, 0]) * (gt[:, 3] - gt[:, 1]) - inter_area 45 | iou = inter_area / union_area 46 | pred_bbox_iou.append(iou) 47 | return np.stack(pred_bbox_iou) 48 | 49 | def calculate_rank(self, iou, iou_threshold): 50 | valid_tp = [] # [[rank, bbox_to_gt, confidence, valid], ...] 51 | first_valid_bbox = [] 52 | for rank_iou in iou: 53 | max_iou_bbox = np.argmax(rank_iou) 54 | if max_iou_bbox not in first_valid_bbox and rank_iou[max_iou_bbox] > iou_threshold: 55 | first_valid_bbox.append(max_iou_bbox) 56 | tp = 1 57 | else: 58 | tp = 0 59 | valid_tp.append(tp) 60 | return valid_tp 61 | 62 | def calculate_pr(self, tp_list, gt_num, smooth=True): 63 | precision_list = [sum_tp / (idx + 1) for idx, sum_tp in enumerate(np.cumsum(tp_list))] 64 | recall_list = [sum_tp / gt_num for sum_tp in np.cumsum(tp_list)] 65 | precision = np.array([0.0] + precision_list + [0.0]) 66 | recall = np.array([0.0] + recall_list + [1.0]) 67 | if smooth: 68 | for i in range(precision.size - 1, 0, -1): 69 | precision[i - 1] = np.maximum(precision[i - 1], precision[i]) 70 | return precision, recall 71 | 72 | def calculate_ap(self, precision, recall): 73 | i = np.where(recall[1:] != recall[:-1])[0] 74 | ap = np.sum((recall[i + 1] - recall[i]) * precision[i + 1]) 75 | return ap 76 | 77 | def update(self, preds: List[BBox], gt: List[BBox], record_tb=False): 78 | preds, gt = to_list(preds), to_list(gt) 79 | assert len(preds) == len(gt), "num of two values must be same." 
80 | 81 | for p, g in zip(preds, gt): 82 | if p.empty_bbox and g.empty_bbox: 83 | continue 84 | for iou in self.iou_threshold: 85 | pg_cats = list(set(p.contain_category + g.contain_category)) 86 | for pgc in pg_cats: 87 | if pgc not in self.coll[iou].keys(): 88 | self.coll[iou][pgc] = dict(tp_list=[], gt_num=0) 89 | pred_cat_bbox = p.get_category_bboxes(pgc) # shape: Pcat x 4 90 | gt_cat_bbox = g.get_category_bboxes(pgc) # shape: Gcat x 4 91 | if len(gt_cat_bbox) == 0: 92 | self.coll[iou][pgc]['tp_list'] += [0 for _ in range(len(pred_cat_bbox))] 93 | elif len(pred_cat_bbox) == 0: 94 | self.coll[iou][pgc]['gt_num'] += len(gt_cat_bbox) 95 | else: 96 | p_g_iou = self.calculate_iou(pred_cat_bbox, gt_cat_bbox) 97 | tp = self.calculate_rank(p_g_iou, iou) 98 | self.coll[iou][pgc]['tp_list'] += tp 99 | self.coll[iou][pgc]['gt_num'] += len(gt_cat_bbox) 100 | 101 | def reset(self): 102 | self.coll = {iou: {} for iou in self.iou_threshold} 103 | 104 | def get(self): 105 | interested_aps = {} 106 | ap_dicts = {} 107 | for iou in self.iou_threshold: 108 | iou_ap_list = [] 109 | for cate, tp_gt in self.coll[iou].items(): 110 | tp_list = tp_gt['tp_list'] 111 | gt_num = tp_gt['gt_num'] 112 | gt_num += 1 if gt_num == 0 else 0 113 | precision, recall = self.calculate_pr(tp_list, gt_num) 114 | ap = self.calculate_ap(precision, recall) 115 | precision = 0 if ap == 0 else precision[-2] 116 | recall = 0 if ap == 0 else recall[-2] 117 | iou_ap_list.append(dict(ap=ap, precision=precision, recall=recall, category=cate)) 118 | iou_ap, iou_precision, iou_recall = get_value_from_dicts(iou_ap_list, ('ap', "precision", "recall"), 119 | post_process='mean') 120 | ap_dicts[iou] = dict(ap=iou_ap, precision=iou_precision, recall=iou_recall) 121 | if iou in self.iou_interested: 122 | interested_aps[iou] = dict(ap=iou_ap, precision=iou_precision, recall=iou_recall, cate_info=iou_ap_list) 123 | mAP = get_value_from_dicts(ap_dicts, 'ap', post_process='mean')[0] 124 | rlt_dict = dict(map=mAP) 125 | rlt_dict.update(interested_aps) 126 | return rlt_dict 127 | 128 | def report(self): 129 | rlt_dict = self.get() 130 | map = rlt_dict['map'] 131 | print(f"mAP: {map}") 132 | for iou in self.iou_interested: 133 | table = PrettyTable() 134 | table.field_names = ["field", "AP", "precision", "recall"] 135 | table.add_row([f"AP{int(iou*100)}", rlt_dict[iou]['ap'], rlt_dict[iou]['precision'], rlt_dict[iou]['recall']]) 136 | for cate_info in rlt_dict[iou]['cate_info']: 137 | table.add_row([cate_info['category'], cate_info['ap'], cate_info['precision'], cate_info['recall']]) 138 | print(table) 139 | -------------------------------------------------------------------------------- /torchtoolbox/metric/metric.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Author: pistonyang@gmail.com 3 | 4 | __all__ = ['Accuracy', 'TopKAccuracy', 'NumericalCost'] 5 | 6 | import numpy as np 7 | import torch 8 | from torch.utils.tensorboard import SummaryWriter 9 | 10 | from ..tools import to_numpy 11 | 12 | 13 | class Metric(object): 14 | """This is an abstract class for all metric classes.
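Subclasses are expected to implement reset(), update() and get().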
15 | 16 | Args: 17 | name (string or dict, optional): metric name 18 | 19 | """ 20 | def __init__(self, name: str = None, writer: SummaryWriter = None): 21 | if name is not None: 22 | assert isinstance(name, (str, dict)) 23 | if writer is not None: 24 | assert isinstance(writer, SummaryWriter) and name is not None 25 | self._iteration = 0 26 | 27 | self._writer = writer 28 | self._name = name 29 | 30 | @property 31 | def name(self): 32 | return self._name 33 | 34 | def reset(self): 35 | """Reset metric to init state. 36 | """ 37 | raise NotImplementedError 38 | 39 | def update(self, record_tb=False): 40 | """Update status. 41 | 42 | """ 43 | raise NotImplementedError 44 | 45 | def get(self): 46 | """Get metric recorded. 47 | 48 | """ 49 | raise NotImplementedError 50 | 51 | 52 | class Accuracy(Metric): 53 | """Record and calculate accuracy. 54 | 55 | Args: 56 | name (string or dict, optional): Acc name. e.g. name='Class1 Acc' 57 | writer (SummaryWriter, optional): TensorBoard writer. 58 | Attributes: 59 | num_metric (int): Number of pred == label 60 | num_inst (int): All samples 61 | """ 62 | def __init__(self, name=None, writer=None): 63 | super(Accuracy, self).__init__(name, writer) 64 | self.num_metric = 0 65 | self.num_inst = 0 66 | 67 | def reset(self): 68 | """Reset status.""" 69 | 70 | self.num_metric = 0 71 | self.num_inst = 0 72 | 73 | @torch.no_grad() 74 | def update(self, preds, labels, record_tb=False): 75 | """Update status. 76 | 77 | Args: 78 | preds (Tensor): Model outputs 79 | labels (Tensor): True label 80 | record_tb (bool): If a writer is set, tensorboard will not be 81 | updated when this is set to True. 82 | """ 83 | _, pred = torch.max(preds, dim=1) 84 | pred = to_numpy(pred.view(-1)).astype('int32') 85 | lbs = to_numpy(labels.view(-1)).astype('int32') 86 | self.num_metric += int((pred == lbs).sum()) 87 | self.num_inst += len(lbs) 88 | 89 | def get(self): 90 | """Get accuracy recorded. 91 | 92 | You should call update() before get() at least once. 93 | 94 | Returns: 95 | A float number of accuracy. 96 | 97 | """ 98 | if self.num_inst == 0: 99 | return 0. 100 | return self.num_metric / self.num_inst 101 | 102 | 103 | class TopKAccuracy(Metric): 104 | """Record and calculate top k accuracy. e.g. top5 acc 105 | 106 | Args: 107 | top (int): top k accuracy to calculate. 108 | name (string or dict, optional): Acc name. e.g. name='Top5 Acc' 109 | writer (SummaryWriter, optional): TensorBoard writer. 110 | 111 | Attributes: 112 | num_metric (int): Number of pred == label 113 | num_inst (int): All samples 114 | """ 115 | def __init__(self, top=1, name=None, writer=None): 116 | super(TopKAccuracy, self).__init__(name, writer) 117 | assert top > 1, 'Please use Accuracy if top_k is no more than 1' 118 | self.topK = top 119 | self.num_metric = 0 120 | self.num_inst = 0 121 | 122 | def reset(self): 123 | """Reset status.""" 124 | self.num_metric = 0 125 | self.num_inst = 0 126 | 127 | @torch.no_grad() 128 | def update(self, preds, labels, record_tb=False): 129 | """Update status. 130 | 131 | Args: 132 | preds (Tensor): Model outputs 133 | labels (Tensor): True label 134 | record_tb (bool): If a writer is set, tensorboard will not be 135 | updated when this is set to True. 136 | """ 137 | 138 | preds = to_numpy(preds).astype('float32') 139 | labels = to_numpy(labels).astype('float32') 140 | preds = np.argpartition(preds, -self.topK)[:, -self.topK:] 141 | # TODO: Is there a quicker way?
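# A vectorized sketch (hypothetical alternative; `preds` holds the top-k class
# indices from np.argpartition above, `labels` one class id per sample):
#   self.num_metric += int((preds == labels[:, None]).any(axis=1).sum())
#   self.num_inst += len(labels)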
142 | for l, p in zip(labels, preds): 143 | self.num_metric += 1 if l in p else 0 144 | self.num_inst += 1 145 | 146 | def get(self): 147 | """Get top k accuracy recorded. 148 | 149 | You should call update() before get() at least once. 150 | 151 | Returns: 152 | A float number of accuracy. 153 | 154 | """ 155 | if self.num_inst == 0: 156 | return 0. 157 | return self.num_metric / self.num_inst 158 | 159 | 160 | class NumericalCost(Metric): 161 | """Record and calculate numerical (scalar) cost. e.g. loss 162 | 163 | Args: 164 | name (string or dict, optional): Cost name. e.g. name='Loss' 165 | record_type (string, optional): how to calculate this, 166 | only 'mean', 'max', 'min' supported. 167 | writer (SummaryWriter, optional): TensorBoard writer. 168 | Attributes: 169 | coll (list): elements to be calculated. 170 | """ 171 | def __init__(self, name=None, record_type='mean', writer=None): 172 | super(NumericalCost, self).__init__(name, writer) 173 | self.coll = [] 174 | self.type = record_type 175 | assert record_type in ('mean', 'max', 'min') 176 | 177 | def reset(self): 178 | """Reset status.""" 179 | self.coll = [] 180 | 181 | @torch.no_grad() 182 | def update(self, cost, record_tb=False): 183 | """Update status. 184 | 185 | Args: 186 | cost (Tensor): cost to record. 187 | record_tb (bool): If a writer is set, tensorboard will not be 188 | updated when this is set to True. 189 | """ 190 | self.coll.append(to_numpy(cost)) 191 | 192 | def get(self): 193 | """Get the cost recorded. 194 | 195 | You should call update() before get() at least once. 196 | 197 | Returns: 198 | A float number of cost. 199 | 200 | """ 201 | assert len(self.coll) != 0, 'Please call update before get' 202 | if self.type == 'mean': 203 | ret = np.mean(self.coll) 204 | elif self.type == 'max': 205 | ret = np.max(self.coll) 206 | else: 207 | ret = np.min(self.coll) 208 | return ret.item() 209 | 210 | 211 | # class DistributedCollector(Metric): 212 | # """Collect distributed tensors across ranks. 213 | 214 | # Args: 215 | # rank: local rank. 216 | # dst: main worker. 217 | # record_type: how to calculate this, 218 | # only 'SUM', 'PRODUCT', 'MAX', 'MIN', 'BAND', 'BOR', 'BXOR' supported. 219 | # dis_coll_type: 220 | # post_process: process after reduce. 221 | # name: collector name. 222 | # writer: TensorBoard writer. 223 | 224 | # Attributes: 225 | 226 | # """ 227 | # def __init__(self, 228 | # rank=None, 229 | # dst=None, 230 | # record_type='sum', 231 | # dis_coll_type='reduce', 232 | # post_process=None, 233 | # name=None, 234 | # writer=None): 235 | 236 | # super(DistributedCollector, self).__init__(name, writer) 237 | # record_type = record_type.lower() 238 | # assert record_type in ('sum', 'product', 'min', 'max', 'band', 'bor', 'bxor') 239 | # assert dis_coll_type in ('reduce', 'all_reduce') 240 | # if dis_coll_type == 'reduce' or writer is not None: 241 | # assert dst is not None, 'please select dst device to reduce if use reduce OP.' \ 242 | # 'please select dst device to write tensorboard if use tensorboard.'
243 | 244 | # if rank is None: 245 | # rank = distributed.get_rank() 246 | # type_encode = { 247 | # 'sum': distributed.ReduceOp.SUM, 248 | # 'product': distributed.ReduceOp.PRODUCT, 249 | # 'max': distributed.ReduceOp.MAX, 250 | # 'min': distributed.ReduceOp.MIN, 251 | # 'band': distributed.ReduceOp.BAND, 252 | # 'bor': distributed.ReduceOp.BOR, 253 | # 'bxor': distributed.ReduceOp.BXOR 254 | # } 255 | 256 | # self.dst = dst 257 | # self.rank = rank 258 | # self.device = torch.device(rank) 259 | # self.dct = dis_coll_type 260 | # self.record_type = record_type 261 | # self.post_process = post_process 262 | # self.dist_op = type_encode[record_type] 263 | 264 | # self.last_rlt = 0. 265 | 266 | # def reset(self): 267 | # self.last_rlt = 0. 268 | 269 | # @torch.no_grad() 270 | # def update(self, item, record_tb=False): 271 | # """ 272 | 273 | # Args: 274 | # item: could be a Python scalar, Numpy ndarray, Pytorch tensor. 275 | # record_tb: stop write to tensorboard in this time. 276 | 277 | # Returns: 278 | # Reduced result. If dis_coll_type=='reduce' only main rank will do post_process. 279 | # """ 280 | # item = reduce_tensor(item, self.rank, self.dist_op, self.dst, self.dct) 281 | 282 | # if self.post_process is not None: 283 | # if self.dct == 'all_reduce': 284 | # item = self.post_process(item) 285 | # elif self.rank == self.dst: 286 | # item = self.post_process(item) 287 | 288 | # self.last_rlt = item 289 | 290 | # if self._writer is not None and self.rank == self.dst: 291 | # if not isinstance(self.last_rlt, (int, float)): 292 | # try: 293 | # self.last_rlt = self.last_rlt.item() 294 | # except Exception as e: 295 | # print("If you want to write to tensorboard, " 296 | # "you need to convert to a scalar in post_process " 297 | # "when target tensor is not a pytorch tensor. 
" 298 | # "Got error {}".format(e)) 299 | 300 | # self.write_tb(record_tb) 301 | 302 | # def get(self): 303 | # return self.last_rlt 304 | -------------------------------------------------------------------------------- /torchtoolbox/nn/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Author: pistonyang@gmail.com 3 | 4 | from .loss import * 5 | from .sequential import * 6 | from .norm import * 7 | from .activation import * 8 | from .conv import * 9 | from .modules import * 10 | from .metric_loss import * 11 | from .transformer import * 12 | 13 | try: 14 | from .parallel import * 15 | except ImportError: 16 | pass 17 | -------------------------------------------------------------------------------- /torchtoolbox/nn/activation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : DevinYang(pistonyang@gmail.com) 3 | __all__ = ['Activation', 'Swish', 'Mish'] 4 | 5 | import torch 6 | from torch import nn 7 | 8 | from .functional import mish, swish 9 | 10 | 11 | class Swish(nn.Module): 12 | """Switch activation from 'SEARCHING FOR ACTIVATION FUNCTIONS' 13 | https://arxiv.org/pdf/1710.05941.pdf 14 | 15 | swish = x / (1 + e^-beta*x) 16 | d_swish = (1 + (1+beta*x)) / ((1 + e^-beta*x)^2) 17 | 18 | """ 19 | def __init__(self, beta=1.0): 20 | super(Swish, self).__init__() 21 | self.beta = beta 22 | 23 | def forward(self, x): 24 | return swish(x, self.beta) 25 | 26 | 27 | class Mish(nn.Module): 28 | """Mish activation from 'Mish: A Self Regularized Non-Monotonic Activation Function' 29 | https://www.bmvc2020-conference.com/assets/papers/0928.pdf 30 | 31 | mish = x*tanh(softplus(x)) 32 | d_mish = delta(x)swish(x, beta=1) + mish(x)/x 33 | 34 | """ 35 | def __init__(self): 36 | super(Mish, self).__init__() 37 | 38 | def forward(self, x): 39 | return mish(x) 40 | 41 | 42 | class QuickGELU(nn.Module): 43 | """QuickGELU refers to OpenAI-CLIP 44 | """ 45 | def forward(self, x): 46 | return x * torch.sigmoid(1.702 * x) 47 | 48 | 49 | class Activation(nn.Module): 50 | def __init__(self, act_type, auto_optimize=True, **kwargs): 51 | super(Activation, self).__init__() 52 | if act_type == 'relu': 53 | self.act = nn.ReLU(inplace=True) if auto_optimize else nn.ReLU(**kwargs) 54 | elif act_type == 'relu6': 55 | self.act = nn.ReLU6(inplace=True) if auto_optimize else nn.ReLU6(**kwargs) 56 | elif act_type == 'h_swish': 57 | self.act = nn.Hardswish(inplace=True) if auto_optimize else nn.Hardswish(**kwargs) 58 | elif act_type == 'h_sigmoid': 59 | self.act = nn.Hardsigmoid(inplace=True) if auto_optimize else nn.Hardsigmoid(**kwargs) 60 | elif act_type == 'swish': 61 | self.act = nn.SiLU(inplace=True) if auto_optimize else nn.SiLU(**kwargs) 62 | elif act_type == 'gelu': 63 | self.act = nn.GELU() 64 | elif act_type == 'quick_gelu': 65 | self.act = QuickGELU() 66 | elif act_type == 'elu': 67 | self.act = nn.ELU(inplace=True, **kwargs) if auto_optimize else nn.ELU(**kwargs) 68 | elif act_type == 'mish': 69 | self.act = Mish() 70 | elif act_type == 'sigmoid': 71 | self.act = nn.Sigmoid() 72 | elif act_type == 'lrelu': 73 | self.act = nn.LeakyReLU(inplace=True, **kwargs) if auto_optimize else nn.LeakyReLU(**kwargs) 74 | elif act_type == 'prelu': 75 | self.act = nn.PReLU(**kwargs) 76 | else: 77 | raise NotImplementedError('{} activation is not implemented.'.format(act_type)) 78 | 79 | def forward(self, x): 80 | return self.act(x) 81 | 
-------------------------------------------------------------------------------- /torchtoolbox/nn/conv.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : DevinYang(pistonyang@gmail.com) 3 | 4 | __all__ = ['DeformConv2d'] 5 | 6 | import torch 7 | from torch import nn 8 | 9 | 10 | class DeformConv2d(nn.Module): 11 | def __init__(self, inc, outc, kernel_size=3, padding=1, stride=1, bias=None, modulation=False): 12 | """ 13 | Args: 14 | modulation (bool, optional): If True, Modulated Deformable Convolution (Deformable ConvNets v2). 15 | """ 16 | super(DeformConv2d, self).__init__() 17 | self.kernel_size = kernel_size 18 | self.padding = padding 19 | self.stride = stride 20 | self.zero_padding = nn.ZeroPad2d(padding) 21 | self.conv = nn.Conv2d(inc, outc, kernel_size=kernel_size, stride=kernel_size, bias=bias) 22 | 23 | self.p_conv = nn.Conv2d(inc, 2 * kernel_size * kernel_size, kernel_size=3, padding=1, stride=stride) 24 | nn.init.constant_(self.p_conv.weight, 0) 25 | self.p_conv.register_backward_hook(self._set_lr) 26 | 27 | self.modulation = modulation 28 | if modulation: 29 | self.m_conv = nn.Conv2d(inc, kernel_size * kernel_size, kernel_size=3, padding=1, stride=stride) 30 | nn.init.constant_(self.m_conv.weight, 0) 31 | self.m_conv.register_backward_hook(self._set_lr) 32 | 33 | @staticmethod 34 | def _set_lr(module, grad_input, grad_output): 35 | # scale the offset/mask conv gradients by 0.1x; a backward hook only takes effect if it returns the new grad_input 36 | return tuple(g * 0.1 if g is not None else g for g in grad_input) 37 | 38 | def forward(self, x): 39 | offset = self.p_conv(x) 40 | if self.modulation: 41 | m = torch.sigmoid(self.m_conv(x)) 42 | 43 | dtype = offset.data.type() 44 | ks = self.kernel_size 45 | N = offset.size(1) // 2 46 | 47 | if self.padding: 48 | x = self.zero_padding(x) 49 | 50 | # (b, 2N, h, w) 51 | p = self._get_p(offset, dtype) 52 | 53 | # (b, h, w, 2N) 54 | p = p.contiguous().permute(0, 2, 3, 1) 55 | q_lt = p.detach().floor() 56 | q_rb = q_lt + 1 57 | 58 | q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, 59 | x.size(2) - 1), 60 | torch.clamp(q_lt[..., N:], 0, 61 | x.size(3) - 1)], dim=-1).long() 62 | q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, 63 | x.size(2) - 1), 64 | torch.clamp(q_rb[..., N:], 0, 65 | x.size(3) - 1)], dim=-1).long() 66 | q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], dim=-1) 67 | q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], dim=-1) 68 | 69 | # clip p 70 | p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2) - 1), torch.clamp(p[..., N:], 0, x.size(3) - 1)], dim=-1) 71 | 72 | # bilinear kernel (b, h, w, N) 73 | g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * \ 74 | (1 + (q_lt[..., N:].type_as(p) - p[..., N:])) 75 | g_rb = (1 - (q_rb[..., :N].type_as(p) - p[..., :N])) * \ 76 | (1 - (q_rb[..., N:].type_as(p) - p[..., N:])) 77 | g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * \ 78 | (1 - (q_lb[..., N:].type_as(p) - p[..., N:])) 79 | g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * \ 80 | (1 + (q_rt[..., N:].type_as(p) - p[..., N:])) 81 | 82 | # (b, c, h, w, N) 83 | x_q_lt = self._get_x_q(x, q_lt, N) 84 | x_q_rb = self._get_x_q(x, q_rb, N) 85 | x_q_lb = self._get_x_q(x, q_lb, N) 86 | x_q_rt = self._get_x_q(x, q_rt, N) 87 | 88 | # (b, c, h, w, N) 89 | x_offset = g_lt.unsqueeze(dim=1) * x_q_lt + \ 90 | g_rb.unsqueeze(dim=1) * x_q_rb + \ 91 | g_lb.unsqueeze(dim=1) * x_q_lb + \ 92 | g_rt.unsqueeze(dim=1) * x_q_rt 93 | 94 | # modulation 95 | if self.modulation: 96 | m =
m.contiguous().permute(0, 2, 3, 1) 97 | m = m.unsqueeze(dim=1) 98 | m = torch.cat([m for _ in range(x_offset.size(1))], dim=1) 99 | x_offset *= m 100 | 101 | x_offset = self._reshape_x_offset(x_offset, ks) 102 | out = self.conv(x_offset) 103 | 104 | return out 105 | 106 | def _get_p_n(self, N, dtype): 107 | p_n_x, p_n_y = torch.meshgrid(torch.arange(-(self.kernel_size - 1) // 2, (self.kernel_size - 1) // 2 + 1), 108 | torch.arange(-(self.kernel_size - 1) // 2, (self.kernel_size - 1) // 2 + 1)) 109 | # (2N, 1) 110 | p_n = torch.cat([torch.flatten(p_n_x), torch.flatten(p_n_y)], 0) 111 | p_n = p_n.view(1, 2 * N, 1, 1).type(dtype) 112 | 113 | return p_n 114 | 115 | def _get_p_0(self, h, w, N, dtype): 116 | p_0_x, p_0_y = torch.meshgrid(torch.arange(1, h * self.stride + 1, self.stride), 117 | torch.arange(1, w * self.stride + 1, self.stride)) 118 | p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1) 119 | p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1) 120 | p_0 = torch.cat([p_0_x, p_0_y], 1).type(dtype) 121 | 122 | return p_0 123 | 124 | def _get_p(self, offset, dtype): 125 | N, h, w = offset.size(1) // 2, offset.size(2), offset.size(3) 126 | 127 | # (1, 2N, 1, 1) 128 | p_n = self._get_p_n(N, dtype) 129 | # (1, 2N, h, w) 130 | p_0 = self._get_p_0(h, w, N, dtype) 131 | p = p_0 + p_n + offset 132 | return p 133 | 134 | def _get_x_q(self, x, q, N): 135 | b, h, w, _ = q.size() 136 | padded_w = x.size(3) 137 | c = x.size(1) 138 | # (b, c, h*w) 139 | x = x.contiguous().view(b, c, -1) 140 | 141 | # (b, h, w, N) 142 | index = q[..., :N] * padded_w + q[..., N:] # offset_x*w + offset_y 143 | # (b, c, h*w*N) 144 | index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1) 145 | 146 | x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N) 147 | 148 | return x_offset 149 | 150 | @staticmethod 151 | def _reshape_x_offset(x_offset, ks): 152 | b, c, h, w, N = x_offset.size() 153 | x_offset = torch.cat([x_offset[..., s:s + ks].contiguous().view(b, c, h, w * ks) for s in range(0, N, ks)], dim=-1) 154 | x_offset = x_offset.contiguous().view(b, c, h * ks, w * ks) 155 | 156 | return x_offset 157 | -------------------------------------------------------------------------------- /torchtoolbox/nn/functional.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : DevinYang(pistonyang@gmail.com) 3 | """This file should not be used by 'from functional import *'""" 4 | 5 | import torch 6 | import numpy as np 7 | import numbers 8 | 9 | from torch import nn 10 | from .operators import * 11 | from torch.nn import functional as F 12 | 13 | 14 | def logits_distribution(pred, target, classes): 15 | one_hot = F.one_hot(target, num_classes=classes).bool() 16 | return torch.where(one_hot, pred, -1 * pred) 17 | 18 | 19 | def reducing(ret, reduction='mean'): 20 | if reduction == 'mean': 21 | ret = torch.mean(ret) 22 | elif reduction == 'sum': 23 | ret = torch.sum(ret) 24 | elif reduction == 'none': 25 | pass 26 | else: 27 | raise NotImplementedError 28 | return ret 29 | 30 | 31 | def _batch_weight(weight, target): 32 | return weight.gather(dim=0, index=target) 33 | 34 | 35 | def logits_nll_loss(input, target, weight=None, reduction='mean'): 36 | """logits_nll_loss 37 | Different from nll loss, this is for sigmoid based loss. 38 | The difference is that this sums along the C (class) dim. 39 | """ 40 | 41 | assert input.dim() == 2, 'Input shape should be (B, C).'
42 | if input.size(0) != target.size(0): 43 | raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'.format( 44 | input.size(0), target.size(0))) 45 | 46 | ret = input.sum(dim=-1) 47 | if weight is not None: 48 | ret = _batch_weight(weight, target) * ret 49 | return reducing(ret, reduction) 50 | 51 | 52 | def class_balanced_weight(beta, samples_per_class): 53 | assert 0 <= beta < 1, 'Wrong rang of beta {}'.format(beta) 54 | if not isinstance(samples_per_class, np.ndarray): 55 | if isinstance(samples_per_class, (list, tuple)): 56 | samples_per_class = np.array(samples_per_class) 57 | elif torch.is_tensor(samples_per_class): 58 | samples_per_class = samples_per_class.numpy() 59 | else: 60 | raise NotImplementedError('Type of samples_per_class should be {}, {} or {} but got {}'.format( 61 | (list, tuple), np.ndarray, torch.Tensor, type(samples_per_class))) 62 | assert isinstance(samples_per_class, np.ndarray) \ 63 | and isinstance(beta, numbers.Number) 64 | 65 | balanced_matrix = (1 - beta) / (1 - np.power(beta, samples_per_class)) 66 | return torch.Tensor(balanced_matrix) 67 | 68 | 69 | def swish(x, beta=1.0): 70 | """Swish activation. 71 | 'https://arxiv.org/pdf/1710.05941.pdf' 72 | Args: 73 | x: Input tensor. 74 | beta: 75 | """ 76 | return SwishOP.apply(x, beta) 77 | 78 | 79 | def mish(x): 80 | """Mish activation. 81 | 'https://www.bmvc2020-conference.com/assets/papers/0928.pdf' 82 | Args: 83 | x: Input tensor. 84 | """ 85 | return x * torch.tanh(F.softplus(x)) 86 | 87 | 88 | @torch.no_grad() 89 | def smooth_one_hot(true_labels: torch.Tensor, classes: int, smoothing=0.0): 90 | """ 91 | if smoothing == 0, it's one-hot method 92 | if 0 < smoothing < 1, it's smooth method 93 | Warning: This function has no grad. 94 | """ 95 | # assert 0 <= smoothing < 1 96 | confidence = 1.0 - smoothing 97 | label_shape = torch.Size((true_labels.size(0), classes)) 98 | 99 | smooth_label = torch.empty(size=label_shape, device=true_labels.device) 100 | smooth_label.fill_(smoothing / (classes - 1)) 101 | smooth_label.scatter_(1, true_labels.data.unsqueeze(1), confidence) 102 | return smooth_label 103 | 104 | 105 | def switch_norm(x, 106 | running_mean, 107 | running_var, 108 | weight, 109 | bias, 110 | mean_weight, 111 | var_weight, 112 | training=False, 113 | momentum=0.9, 114 | eps=0.1, 115 | moving_average=True): 116 | size = x.size() 117 | x = x.view(size[0], size[1], -1) 118 | 119 | mean_instance = x.mean(-1, keepdim=True) 120 | var_instance = x.var(-1, keepdim=True) 121 | 122 | mean_layer = x.mean((1, -1), keepdim=True) 123 | var_layer = x.var((1, -1), keepdim=True) 124 | 125 | if training: 126 | mean_batch = x.mean((0, -1)) 127 | var_batch = x.var((0, -1)) 128 | if moving_average: 129 | running_mean.mul_(momentum) 130 | running_mean.add_((1 - momentum) * mean_batch.data) 131 | running_var.mul_(momentum) 132 | running_var.add_((1 - momentum) * var_batch.data) 133 | else: 134 | running_mean.add_(mean_batch.data) 135 | running_var.add_(mean_batch.data**2 + var_batch.data) 136 | else: 137 | mean_batch = running_mean 138 | var_batch = running_var 139 | 140 | mean_weight = mean_weight.softmax(0) 141 | var_weight = var_weight.softmax(0) 142 | 143 | mean = mean_weight[0] * mean_instance + \ 144 | mean_weight[1] * mean_layer + \ 145 | mean_weight[2] * mean_batch.unsqueeze(1) # noqa:E127 146 | 147 | var = var_weight[0] * var_instance + \ 148 | var_weight[1] * var_layer + \ 149 | var_weight[2] * var_batch.unsqueeze(1) # noqa:E127 150 | 151 | x = (x - mean) / (var + eps).sqrt() 152 | x = 
x * weight.unsqueeze(1) + bias.unsqueeze(1) 153 | x = x.view(size) 154 | return x 155 | 156 | 157 | def instance_std(x, eps=1e-5): 158 | var = torch.var(x, dim=(2, 3), keepdim=True) 159 | std = torch.sqrt(var + eps) 160 | return std 161 | 162 | 163 | def group_std(x: torch.Tensor, groups=32, eps=1e-5): 164 | n, c, h, w = x.size() 165 | x = torch.reshape(x, (n, groups, c // groups, h, w)) 166 | var = torch.var(x, dim=(2, 3, 4), keepdim=True) 167 | std = torch.sqrt(var + eps) 168 | return torch.reshape(std, (n, c, h, w)) 169 | 170 | 171 | def evo_norm(x, prefix, running_var, v, weight, bias, training, momentum, eps=0.1, groups=32): 172 | if prefix == 'b0': 173 | if training: 174 | var = torch.var(x, dim=(0, 2, 3), keepdim=True) 175 | running_var.mul_(momentum) 176 | running_var.add_((1 - momentum) * var) 177 | else: 178 | var = running_var 179 | if v is not None: 180 | den = torch.max((var + eps).sqrt(), v * x + instance_std(x, eps)) 181 | x = x / den * weight + bias 182 | else: 183 | x = x * weight + bias 184 | else: 185 | if v is not None: 186 | x = x * torch.sigmoid(v * x) / group_std(x, groups, eps) * weight + bias 187 | else: 188 | x = x * weight + bias 189 | 190 | return x 191 | 192 | 193 | def drop_block(x, mask): 194 | return x * mask * mask.numel() / mask.sum() 195 | 196 | 197 | def channel_shuffle(x, groups): 198 | batchsize, num_channels, height, width = x.data.size() 199 | channels_per_group = num_channels // groups 200 | 201 | # reshape 202 | x = x.view(batchsize, groups, channels_per_group, height, width) 203 | 204 | x = torch.transpose(x, 1, 2).contiguous() 205 | 206 | # flatten 207 | x = x.view(batchsize, -1, height, width) 208 | 209 | return x 210 | 211 | 212 | def channel_shift(x, shift): 213 | x = torch.cat([x[:, shift:, ...], x[:, :shift, ...]], dim=1) 214 | return x 215 | -------------------------------------------------------------------------------- /torchtoolbox/nn/init.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : DevinYang(pistonyang@gmail.com) 3 | __all__ = ['XavierInitializer', 'KaimingInitializer', 'MSRAPrelu', 'TruncNormInitializer', 'ZeroLastGamma'] 4 | 5 | import abc 6 | import math 7 | 8 | from torch import nn 9 | from torch.nn.init import (_calculate_fan_in_and_fan_out, _no_grad_normal_, kaiming_normal_, kaiming_uniform_, xavier_normal_, 10 | xavier_uniform_, zeros_) 11 | 12 | from ..tools import to_list 13 | 14 | 15 | class Initializer(abc.ABC): 16 | def __init__(self, extra_conv=(), extra_norm=(), extra_linear=()) -> None: 17 | self.extra_conv = to_list(extra_conv) 18 | self.extra_norm = to_list(extra_norm) 19 | self.extra_linear = to_list(extra_linear) 20 | 21 | def is_conv(self, module): 22 | return isinstance(module, (nn.Conv2d, nn.Conv3d, *self.extra_conv)) 23 | 24 | def is_norm(self, module): 25 | return isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm, *self.extra_norm)) 26 | 27 | def is_linear(self, module): 28 | return isinstance(module, (nn.Linear, *self.extra_linear)) 29 | 30 | def is_msa(self, module): 31 | return isinstance(module, nn.MultiheadAttention) 32 | 33 | def init_norm(self, module): 34 | if module.weight is not None: 35 | module.weight.data.fill_(1) 36 | if module.bias is not None: 37 | module.bias.data.zero_() 38 | 39 | @abc.abstractmethod 40 | def __call__(self, module): 41 | pass 42 | 43 | 44 | class XavierInitializer(Initializer): 45 | """Initialize a model params by Xavier. 
46 | 47 | Fills the input `Tensor` with values according to the method 48 | described in `Understanding the difficulty of training deep feedforward 49 | neural networks` - Glorot, X. & Bengio, Y. (2010) 50 | 51 | Args: 52 | random_type (string): 'uniform' or 'normal'. 53 | gain (float): an optional scaling factor, default is sqrt(2.0). 54 | Apply to a model via `model.apply(XavierInitializer(...))`. 55 | 56 | """ 57 | def __init__(self, random_type='uniform', gain=math.sqrt(2.0), **kwargs): 58 | super().__init__(**kwargs) 59 | assert random_type in ('uniform', 'normal') 60 | self.random_type = random_type 61 | # note: do not assign self.initializer here; an instance attribute would shadow the method below and ignore `gain`. 62 | self.gain = gain 63 | 64 | def initializer(self, tensor): 65 | initializer = xavier_uniform_ if self.random_type == 'uniform' else xavier_normal_ 66 | initializer(tensor, gain=self.gain) 67 | 68 | def __call__(self, module): 69 | if self.is_conv(module): 70 | self.initializer(module.weight.data) 71 | if module.bias is not None: 72 | module.bias.data.zero_() 73 | 74 | elif self.is_norm(module): 75 | self.init_norm(module) 76 | 77 | elif self.is_linear(module): 78 | self.initializer(module.weight.data) 79 | if module.bias is not None: 80 | module.bias.data.zero_() 81 | 82 | elif self.is_msa(module): 83 | if module.q_proj_weight is not None: 84 | self.initializer(module.q_proj_weight.data) 85 | if module.k_proj_weight is not None: 86 | self.initializer(module.k_proj_weight.data) 87 | if module.v_proj_weight is not None: 88 | self.initializer(module.v_proj_weight.data) 89 | if module.in_proj_weight is not None: 90 | self.initializer(module.in_proj_weight.data) 91 | if module.in_proj_bias is not None: 92 | module.in_proj_bias.data.zero_() 93 | if module.bias_k is not None: 94 | module.bias_k.data.zero_() 95 | if module.bias_v is not None: 96 | module.bias_v.data.zero_() 97 | 98 | 99 | class KaimingInitializer(Initializer): 100 | def __init__(self, slope=0, mode='fan_out', nonlinearity='relu', random_type='normal', **kwargs): 101 | super().__init__(**kwargs) 102 | assert random_type in ('uniform', 'normal') 103 | self.random_type = random_type 104 | self.slope = slope 105 | self.mode = mode 106 | self.nonlinearity = nonlinearity 107 | 108 | def initializer(self, tensor): 109 | initializer = kaiming_uniform_ if self.random_type == 'uniform' else kaiming_normal_ 110 | initializer(tensor, self.slope, self.mode, self.nonlinearity) 111 | 112 | def __call__(self, module): 113 | if self.is_conv(module): 114 | self.initializer(module.weight.data) 115 | if module.bias is not None: 116 | module.bias.data.zero_() 117 | 118 | elif self.is_norm(module): 119 | self.init_norm(module) 120 | 121 | elif self.is_linear(module): 122 | self.initializer(module.weight.data) 123 | if module.bias is not None: 124 | module.bias.data.zero_() 125 | 126 | elif self.is_msa(module): 127 | if module.q_proj_weight is not None: 128 | self.initializer(module.q_proj_weight.data) 129 | if module.k_proj_weight is not None: 130 | self.initializer(module.k_proj_weight.data) 131 | if module.v_proj_weight is not None: 132 | self.initializer(module.v_proj_weight.data) 133 | if module.in_proj_weight is not None: 134 | self.initializer(module.in_proj_weight.data) 135 | if module.in_proj_bias is not None: 136 | module.in_proj_bias.data.zero_() 137 | if module.bias_k is not None: 138 | module.bias_k.data.zero_() 139 | if module.bias_v is not None: 140 | module.bias_v.data.zero_() 141 | 142 | 143 | class MSRAPrelu(Initializer): 144 | """Initialize the weight according to a
MSRA paper. 145 | This initializer implements *Delving Deep into Rectifiers: Surpassing 146 | Human-Level Performance on ImageNet Classification*, available at 147 | https://arxiv.org/abs/1502.01852. 148 | """ 149 | def __init__(self, slope=0.25, **kwargs): 150 | super().__init__(**kwargs) 151 | self.magnitude = 2. / (1 + slope**2) 152 | 153 | def initializer(self, tensor): 154 | fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) 155 | factor = (fan_in + fan_out) / 2.0 156 | scale = math.sqrt(self.magnitude / factor) 157 | _no_grad_normal_(tensor, 0, scale) 158 | 159 | def __call__(self, module): 160 | if self.is_conv(module): 161 | self.initializer(module.weight.data) 162 | if module.bias is not None: 163 | module.bias.data.zero_() 164 | 165 | elif self.is_norm(module): 166 | self.init_norm(module) 167 | 168 | elif self.is_linear(module): 169 | self.initializer(module.weight.data) 170 | if module.bias is not None: 171 | module.bias.data.zero_() 172 | 173 | elif self.is_msa(module): 174 | if module.q_proj_weight is not None: 175 | self.initializer(module.q_proj_weight.data) 176 | if module.k_proj_weight is not None: 177 | self.initializer(module.k_proj_weight.data) 178 | if module.v_proj_weight is not None: 179 | self.initializer(module.v_proj_weight.data) 180 | if module.in_proj_weight is not None: 181 | self.initializer(module.in_proj_weight.data) 182 | if module.in_proj_bias is not None: 183 | module.in_proj_bias.data.zero_() 184 | if module.bias_k is not None: 185 | module.bias_k.data.zero_() 186 | if module.bias_v is not None: 187 | module.bias_v.data.zero_() 188 | 189 | 190 | class TruncNormInitializer(Initializer): 191 | def __init__(self, mean=0., std=1, a=-2., b=2., **kwargs): 192 | super().__init__(**kwargs) 193 | self.mean = mean 194 | self.std = std 195 | self.a = a 196 | self.b = b 197 | 198 | def initializer(self, tensor): 199 | nn.init.trunc_normal_(tensor, self.mean, self.std, self.a, self.b) 200 | 201 | def __call__(self, module): 202 | if self.is_conv(module): 203 | self.initializer(module.weight.data) 204 | if module.bias is not None: 205 | module.bias.data.zero_() 206 | 207 | elif self.is_norm(module): 208 | self.init_norm(module) 209 | 210 | elif self.is_linear(module): 211 | self.initializer(module.weight.data) 212 | if module.bias is not None: 213 | module.bias.data.zero_() 214 | 215 | elif self.is_msa(module): 216 | if module.q_proj_weight is not None: 217 | self.initializer(module.q_proj_weight.data) 218 | if module.k_proj_weight is not None: 219 | self.initializer(module.k_proj_weight.data) 220 | if module.v_proj_weight is not None: 221 | self.initializer(module.v_proj_weight.data) 222 | if module.in_proj_weight is not None: 223 | self.initializer(module.in_proj_weight.data) 224 | if module.in_proj_bias is not None: 225 | module.in_proj_bias.data.zero_() 226 | if module.bias_k is not None: 227 | module.bias_k.data.zero_() 228 | if module.bias_v is not None: 229 | module.bias_v.data.zero_() 230 | 231 | 232 | class ZeroLastGamma(object): 233 | """Notice that this need to put after other initializer. 
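That is, run a weight initializer over the whole model first, then apply this to
zero the affine weight (gamma) of the last BN in each matching block, e.g.
(block/bn names depend on your model)::

    model.apply(KaimingInitializer())
    model.apply(ZeroLastGamma(block_name='Bottleneck', bn_name='bn3'))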
234 | """ 235 | def __init__(self, block_name='Bottleneck', bn_name='bn3'): 236 | self.block_name = block_name 237 | self.bn_name = bn_name 238 | 239 | def __call__(self, module): 240 | if module.__class__.__name__ == self.block_name: 241 | target_bn = module.__getattr__(self.bn_name) 242 | zeros_(target_bn.weight) 243 | -------------------------------------------------------------------------------- /torchtoolbox/nn/loss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : DevinYang(pistonyang@gmail.com) 3 | __all__ = ['LabelSmoothingLoss', 'SigmoidCrossEntropy', 'FocalLoss', 'L0Loss', 'RingLoss', 'CenterLoss', 'CircleLoss'] 4 | 5 | from . import functional as BF 6 | from torch import nn 7 | from torch.nn import functional as F 8 | from torch.nn.modules.loss import _WeightedLoss 9 | import torch 10 | 11 | 12 | class SigmoidCrossEntropy(_WeightedLoss): 13 | def __init__(self, classes, weight=None, reduction='mean'): 14 | super(SigmoidCrossEntropy, self).__init__(weight=weight, reduction=reduction) 15 | self.classes = classes 16 | 17 | def forward(self, pred, target): 18 | zt = BF.logits_distribution(pred, target, self.classes) 19 | return BF.logits_nll_loss(-F.logsigmoid(zt), target, self.weight, self.reduction) 20 | 21 | 22 | class FocalLoss(_WeightedLoss): 23 | def __init__(self, classes, gamma, weight=None, reduction='mean'): 24 | super(FocalLoss, self).__init__(weight=weight, reduction=reduction) 25 | self.classes = classes 26 | self.gamma = gamma 27 | 28 | def forward(self, pred, target): 29 | zt = BF.logits_distribution(pred, target, self.classes) 30 | ret = -(1 - torch.sigmoid(zt)).pow(self.gamma) * F.logsigmoid(zt) 31 | return BF.logits_nll_loss(ret, target, self.weight, self.reduction) 32 | 33 | 34 | class L0Loss(nn.Module): 35 | """L0loss from 36 | "Noise2Noise: Learning Image Restoration without Clean Data" 37 | `_ paper. 38 | 39 | """ 40 | def __init__(self, gamma=2, eps=1e-8): 41 | super(L0Loss, self).__init__() 42 | self.gamma = gamma 43 | self.eps = eps 44 | 45 | def forward(self, pred, target): 46 | loss = (torch.abs(pred - target) + self.eps).pow(self.gamma) 47 | return torch.mean(loss) 48 | 49 | 50 | class LabelSmoothingLoss(nn.Module): 51 | """This is label smoothing loss function. 52 | """ 53 | def __init__(self, classes, smoothing=0.0, dim=-1): 54 | super(LabelSmoothingLoss, self).__init__() 55 | self.confidence = 1.0 - smoothing 56 | self.smoothing = smoothing 57 | self.cls = classes 58 | self.dim = dim 59 | 60 | def forward(self, pred, target): 61 | pred = pred.log_softmax(dim=self.dim) 62 | true_dist = BF.smooth_one_hot(target, self.cls, self.smoothing) 63 | return torch.mean(torch.sum(-true_dist * pred, dim=self.dim)) 64 | 65 | 66 | class CircleLoss(nn.Module): 67 | r"""CircleLoss from 68 | `"Circle Loss: A Unified Perspective of Pair Similarity Optimization" 69 | `_ paper. 70 | 71 | Parameters 72 | ---------- 73 | m: float. 74 | Margin parameter for loss. 75 | gamma: int. 76 | Scale parameter for loss. 77 | 78 | Outputs: 79 | - **loss**: scalar. 
80 | """ 81 | def __init__(self, m, gamma): 82 | super(CircleLoss, self).__init__() 83 | self.m = m 84 | self.gamma = gamma 85 | self.dp = 1 - m 86 | self.dn = m 87 | 88 | def forward(self, x, target): 89 | similarity_matrix = x @ x.T # need gard here 90 | label_matrix = target.unsqueeze(1) == target.unsqueeze(0) 91 | negative_matrix = label_matrix.logical_not() 92 | positive_matrix = label_matrix.fill_diagonal_(False) 93 | 94 | sp = torch.where(positive_matrix, similarity_matrix, torch.zeros_like(similarity_matrix)) 95 | sn = torch.where(negative_matrix, similarity_matrix, torch.zeros_like(similarity_matrix)) 96 | 97 | ap = torch.clamp_min(1 + self.m - sp.detach(), min=0.) 98 | an = torch.clamp_min(sn.detach() + self.m, min=0.) 99 | 100 | logit_p = -self.gamma * ap * (sp - self.dp) 101 | logit_n = self.gamma * an * (sn - self.dn) 102 | 103 | logit_p = torch.where(positive_matrix, logit_p, torch.zeros_like(logit_p)) 104 | logit_n = torch.where(negative_matrix, logit_n, torch.zeros_like(logit_n)) 105 | 106 | loss = F.softplus(torch.logsumexp(logit_p, dim=1) + torch.logsumexp(logit_n, dim=1)).mean() 107 | return loss 108 | 109 | 110 | class RingLoss(nn.Module): 111 | """Computes the Ring Loss from 112 | `"Ring loss: Convex Feature Normalization for Face Recognition" 113 | 114 | Parameters 115 | ---------- 116 | lamda: float 117 | The loss weight enforcing a trade-off between the softmax loss and ring loss. 118 | l2_norm: bool 119 | Whether use l2 norm to embedding. 120 | weight_initializer (None or torch.Tensor): If not None a torch.Tensor should be provided. 121 | 122 | Outputs: 123 | - **loss**: scalar. 124 | """ 125 | def __init__(self, lamda, l2_norm=True, weight_initializer=None): 126 | super(RingLoss, self).__init__() 127 | self.lamda = lamda 128 | self.l2_norm = l2_norm 129 | if weight_initializer is None: 130 | self.R = self.parameters(torch.rand(1)) 131 | else: 132 | assert torch.is_tensor(weight_initializer), 'weight_initializer should be a Tensor.' 133 | self.R = self.parameters(weight_initializer) 134 | 135 | def forward(self, embedding): 136 | if self.l2_norm: 137 | embedding = F.normalize(embedding, 2, dim=-1) 138 | loss = (embedding - self.R).pow(2).sum(1).mean(0) * self.lamda * 0.5 139 | return loss 140 | 141 | 142 | class CenterLoss(nn.Module): 143 | """Computes the Center Loss from 144 | `"A Discriminative Feature Learning Approach for Deep Face Recognition" 145 | `_paper. 146 | Implementation is refer to 147 | 'https://github.com/lyakaap/image-feature-learning-pytorch/blob/master/code/center_loss.py' 148 | 149 | Parameters 150 | ---------- 151 | classes: int. 152 | Number of classes. 153 | embedding_dim: int 154 | embedding_dim. 155 | lamda: float 156 | The loss weight enforcing a trade-off between the softmax loss and center loss. 157 | 158 | Outputs: 159 | - **loss**: loss tensor with shape (batch_size,). Dimensions other than 160 | batch_axis are averaged out. 
161 | """ 162 | def __init__(self, classes, embedding_dim, lamda): 163 | super(CenterLoss, self).__init__() 164 | self.lamda = lamda 165 | self.centers = nn.Parameter(torch.randn(classes, embedding_dim)) 166 | 167 | def forward(self, embedding, target): 168 | expanded_centers = self.centers.index_select(0, target) 169 | intra_distances = embedding.dist(expanded_centers) 170 | loss = self.lamda * 0.5 * intra_distances / target.size()[0] 171 | return loss 172 | 173 | 174 | class KnowledgeDistillationLoss(nn.Module): 175 | def __init__(self, temperature=1): 176 | super().__init__() 177 | self.temperature = temperature 178 | 179 | def forward(self, student_output, teacher_output): 180 | return self.temperature**2 * torch.mean( 181 | torch.sum(-F.softmax(teacher_output / self.temperature) * F.log_softmax(student_output / self.temperature), dim=1)) 182 | -------------------------------------------------------------------------------- /torchtoolbox/nn/metric_loss.py: -------------------------------------------------------------------------------- 1 | __all__ = ['L2Softmax', 'ArcLoss', 'AMSoftmax', 'CircleLossFC'] 2 | 3 | from torch import nn 4 | from torch.nn import functional as F 5 | import torch 6 | import math 7 | 8 | 9 | class NormLinear(nn.Module): 10 | def __init__(self, in_features, classes, weight_norm=False, feature_norm=False): 11 | super(NormLinear, self).__init__() 12 | self.weight_norm = weight_norm 13 | self.feature_norm = feature_norm 14 | 15 | self.classes = classes 16 | self.in_features = in_features 17 | 18 | self.weight = nn.Parameter(torch.Tensor(classes, in_features)) 19 | nn.init.normal_(self.weight, std=0.01) 20 | 21 | def forward(self, x): 22 | weight = F.normalize(self.weight, 2, dim=-1) if self.weight_norm else self.weight 23 | if self.feature_norm: 24 | x = F.normalize(x, 2, dim=-1) 25 | 26 | return F.linear(x, weight) 27 | 28 | def extra_repr(self): 29 | return 'in_features={}, out_features={}'.format(self.in_features, self.classes) 30 | 31 | 32 | class L2Softmax(nn.Module): 33 | r"""L2Softmax from 34 | `"L2-constrained Softmax Loss for Discriminative Face Verification" 35 | `_ paper. 36 | 37 | Parameters 38 | ---------- 39 | classes: int. 40 | Number of classes. 41 | alpha: float. 42 | The scaling parameter, a hypersphere with small alpha 43 | will limit surface area for embedding features. 44 | p: float, default is 0.9. 45 | The expected average softmax probability for correctly 46 | classifying a feature. 47 | from_normx: bool, default is False. 48 | Whether input has already been normalized. 49 | 50 | Outputs: 51 | - **loss**: loss tensor with shape (1,). Dimensions other than 52 | batch_axis are averaged out. 53 | """ 54 | def __init__(self, embedding_size, classes, alpha=64, p=0.9): 55 | super(L2Softmax, self).__init__() 56 | alpha_low = math.log(p * (classes - 2) / (1 - p)) 57 | assert alpha > alpha_low, "For given probability of p={}, alpha should higher than {}.".format(p, alpha_low) 58 | self.alpha = alpha 59 | self.linear = NormLinear(embedding_size, classes, True, True) 60 | 61 | def forward(self, x, target): 62 | x = self.linear(x) 63 | x = x * self.alpha 64 | return x 65 | 66 | 67 | class ArcLoss(nn.Module): 68 | r"""ArcLoss from 69 | `"ArcFace: Additive Angular Margin Loss for Deep Face Recognition" 70 | `_ paper. 71 | 72 | Parameters 73 | ---------- 74 | classes: int. 75 | Number of classes. 76 | m: float. 77 | Margin parameter for loss. 78 | s: int. 79 | Scale parameter for loss. 
80 | 81 | Outputs: 82 | - **loss**: loss tensor with shape (batch_size,). Dimensions other than 83 | batch_axis are averaged out. 84 | """ 85 | def __init__(self, embedding_size, classes, m=0.5, s=64, easy_margin=True): 86 | super(ArcLoss, self).__init__() 87 | assert s > 0. 88 | assert 0 <= m <= (math.pi / 2) 89 | self.s = s 90 | self.m = m 91 | self.cos_m = math.cos(m) 92 | self.sin_m = math.sin(m) 93 | self.mm = math.sin(math.pi - m) * m 94 | self.threshold = math.cos(math.pi - m) 95 | self.classes = classes 96 | self.easy_margin = easy_margin 97 | self.linear = NormLinear(embedding_size, classes, True, True) 98 | 99 | @torch.no_grad() 100 | def _get_body(self, x, target): 101 | cos_t = torch.gather(x, 1, target.unsqueeze(1)) # cos(theta_yi) 102 | if self.easy_margin: 103 | # cond = torch.relu(cos_t) 104 | cond = torch.clamp_min(cos_t, min=0.) 105 | else: 106 | cond_v = cos_t - self.threshold 107 | # cond = torch.relu(cond_v) 108 | cond = torch.clamp_min(cond_v, min=0.) 109 | cond = cond.bool() 110 | new_zy = torch.cos(torch.acos(cos_t) + self.m).type(cos_t.dtype) # cos(theta_yi + m), use `.type()` to fix FP16 111 | if self.easy_margin: 112 | zy_keep = cos_t 113 | else: 114 | zy_keep = cos_t - self.mm # (cos(theta_yi) - sin(pi - m)*m) 115 | new_zy = torch.where(cond, new_zy, zy_keep) 116 | diff = new_zy - cos_t # cos(theta_yi + m) - cos(theta_yi) 117 | gt_one_hot = F.one_hot(target, num_classes=self.classes) 118 | body = gt_one_hot * diff 119 | return body 120 | 121 | def forward(self, x, target): 122 | x = self.linear(x) 123 | body = self._get_body(x, target) 124 | x = x + body 125 | x = x * self.s 126 | return x 127 | 128 | 129 | class AMSoftmax(nn.Module): 130 | r"""CosLoss from 131 | `"CosFace: Large Margin Cosine Loss for Deep Face Recognition" 132 | `_ paper. 133 | 134 | It is also AM-Softmax from 135 | `"Additive Margin Softmax for Face Verification" 136 | `_ paper. 137 | 138 | Parameters 139 | ---------- 140 | classes: int. 141 | Number of classes. 142 | m: float, default 0.4 143 | Margin parameter for loss. 144 | s: int, default 64 145 | Scale parameter for loss. 146 | 147 | 148 | Outputs: 149 | - **loss**: loss tensor with shape (batch_size,). Dimensions other than 150 | batch_axis are averaged out. 
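A minimal sketch (sizes illustrative; the head returns scaled logits, so pair
it with a softmax-style loss)::

    head = AMSoftmax(embedding_size=512, classes=1000)
    logits = head(embedding, target)
    loss = F.cross_entropy(logits, target)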
151 | """ 152 | def __init__(self, embedding_size, classes, m=0.35, s=64): 153 | super(AMSoftmax, self).__init__() 154 | assert m > 0 and s > 0 155 | self.classes = classes 156 | self.scale = s 157 | self.margin = m 158 | self.linear = NormLinear(embedding_size, classes, True, True) 159 | 160 | def forward(self, x, target): 161 | x = self.linear(x) 162 | sparse_target = F.one_hot(target, num_classes=self.classes) 163 | x = x - sparse_target * self.margin 164 | x = x * self.scale 165 | return x 166 | 167 | 168 | class CircleLossFC(nn.Module): 169 | def __init__(self, embedding_size, classes, m=0.25, gamma=256): 170 | super(CircleLossFC, self).__init__() 171 | self.m = m 172 | self.gamma = gamma 173 | self.dp = 1 - m 174 | self.dn = m 175 | self.classes = classes 176 | self.linear = NormLinear(embedding_size, classes, True, True) 177 | 178 | @torch.no_grad() 179 | def get_param(self, x): 180 | ap = torch.relu(1 + self.m - x.detach()) 181 | an = torch.relu(x.detach() + self.m) 182 | return ap, an 183 | 184 | def forward(self, x, target): 185 | x = self.linear(x) 186 | ap, an = self.get_param(x) 187 | gt_one_hot = F.one_hot(target, num_classes=self.classes) 188 | x = gt_one_hot * ap * (x - self.dp) + (1 - gt_one_hot) * an * (x - self.dn) 189 | x = x * self.gamma 190 | return x 191 | -------------------------------------------------------------------------------- /torchtoolbox/nn/modules.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: UTF-8 -*- 3 | # @Author:X.Yang(xuyang@deepmotion.ai) 4 | 5 | __all__ = ['ChannelShuffle', 'ChannelCircularShift'] 6 | from . import functional as F 7 | from torch import nn 8 | 9 | 10 | class TBModule: 11 | """This is a template interface, do not inherit this. 12 | This class will provide some specific features which used for toolbox to call. 13 | You should write them in you specific class not inherit this. 14 | (For now this is a wise idea, inherit this will raise some another issues. 15 | For instance, if I only need no_wd, inherit this will bring other func to your class and how to deal with unused func is a issue.) 16 | Do not change func 17 | """ 18 | def __init__(self): 19 | raise NotImplementedError 20 | 21 | def no_wd(self, decay: list, no_decay: list): 22 | """This is a interface call by `tools.no_decay_bias` 23 | 24 | Args: 25 | decay ([type]): param use weight decay. 26 | no_decay ([type]): param do not use weight decay. 27 | 28 | Returns: 29 | None 30 | """ 31 | raise NotImplementedError 32 | 33 | def num_param(self, input, output): 34 | """This is interface call by 'tools.summary' 35 | 36 | Returns: 37 | [int, int]: module num params.(learnable, not learnable) 38 | """ 39 | raise NotImplementedError 40 | 41 | def flops(self, input, output): 42 | """This is a interface call by 'tools.summary' 43 | 44 | Returns: 45 | [int]: module flops. 
46 | """ 47 | raise NotImplementedError 48 | 49 | 50 | class ChannelShuffle(nn.Module): 51 | def __init__(self, groups: int): 52 | super().__init__() 53 | self.groups = groups 54 | 55 | def forward(self, x): 56 | x = F.channel_shuffle(x, self.groups) 57 | return x 58 | 59 | 60 | class ChannelCircularShift(nn.Module): 61 | def __init__(self, num_shift): 62 | super().__init__() 63 | self.shift = num_shift 64 | 65 | def forward(self, x): 66 | return F.channel_shift(x, self.shift) 67 | -------------------------------------------------------------------------------- /torchtoolbox/nn/norm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : DevinYang(pistonyang@gmail.com) 3 | __all__ = ['SwitchNorm2d', 'SwitchNorm3d', 'EvoNormB0', 'EvoNormS0', 'DropBlock2d', 'DropPath'] 4 | 5 | import torch 6 | from torch import nn 7 | from . import functional as F 8 | 9 | 10 | class _SwitchNorm(nn.Module): 11 | """ 12 | Avoid to feed 1xCxHxW and NxCx1x1 data to this. 13 | """ 14 | _version = 2 15 | 16 | def __init__(self, num_features, eps=1e-5, momentum=0.9, affine=True): 17 | super(_SwitchNorm, self).__init__() 18 | self.num_features = num_features 19 | self.eps = eps 20 | self.momentum = momentum 21 | self.affine = affine 22 | if self.affine: 23 | self.weight = nn.Parameter(torch.Tensor(num_features)) 24 | self.bias = nn.Parameter(torch.Tensor(num_features)) 25 | else: 26 | self.register_parameter('weight', None) 27 | self.register_parameter('bias', None) 28 | 29 | self.mean_weight = nn.Parameter(torch.ones(3)) 30 | self.var_weight = nn.Parameter(torch.ones(3)) 31 | 32 | self.register_buffer('running_mean', torch.zeros(num_features)) 33 | self.register_buffer('running_var', torch.ones(num_features)) 34 | 35 | def _check_input_dim(self, x): 36 | raise NotImplementedError 37 | 38 | def forward(self, x): 39 | self._check_input_dim(x) 40 | return F.switch_norm(x, self.running_mean, self.running_var, self.weight, self.bias, self.mean_weight, self.var_weight, 41 | self.training, self.momentum, self.eps) 42 | 43 | 44 | class SwitchNorm2d(_SwitchNorm): 45 | def _check_input_dim(self, x): 46 | if x.dim() != 4: 47 | raise ValueError('expected 4D input (got {}D input)'.format(x.dim())) 48 | 49 | 50 | class SwitchNorm3d(_SwitchNorm): 51 | def _check_input_dim(self, x): 52 | if x.dim() != 5: 53 | raise ValueError('expected 5D input (got {}D input)'.format(x.dim())) 54 | 55 | 56 | class _EvoNorm(nn.Module): 57 | def __init__(self, prefix, num_features, eps=1e-5, momentum=0.9, groups=32, affine=True): 58 | super(_EvoNorm, self).__init__() 59 | assert prefix in ('s0', 'b0') 60 | self.prefix = prefix 61 | self.groups = groups 62 | self.num_features = num_features 63 | self.eps = eps 64 | self.momentum = momentum 65 | self.affine = affine 66 | if self.affine: 67 | self.weight = nn.Parameter(torch.Tensor(1, num_features, 1, 1)) 68 | self.bias = nn.Parameter(torch.Tensor(1, num_features, 1, 1)) 69 | self.v = nn.Parameter(torch.Tensor(1, num_features, 1, 1)) 70 | else: 71 | self.register_parameter('weight', None) 72 | self.register_parameter('bias', None) 73 | self.register_parameter('v', None) 74 | self.register_buffer('running_var', torch.ones(1, num_features, 1, 1)) 75 | self.reset_parameters() 76 | 77 | def reset_parameters(self): 78 | if self.affine: 79 | torch.nn.init.ones_(self.weight) 80 | torch.nn.init.zeros_(self.bias) 81 | torch.nn.init.ones_(self.v) 82 | 83 | def _check_input_dim(self, x): 84 | if x.dim() != 4: 85 | raise ValueError('expected 
4D input (got {}D input)'.format(x.dim())) 86 | 87 | def forward(self, x): 88 | self._check_input_dim(x) 89 | return F.evo_norm(x, self.prefix, self.running_var, self.v, self.weight, self.bias, self.training, self.momentum, 90 | self.eps, self.groups) 91 | 92 | 93 | class EvoNormB0(_EvoNorm): 94 | def __init__(self, num_features, eps=1e-5, momentum=0.9, affine=True): 95 | super(EvoNormB0, self).__init__('b0', num_features, eps, momentum, affine=affine) 96 | 97 | 98 | class EvoNormS0(_EvoNorm): 99 | def __init__(self, num_features, groups=32, affine=True): 100 | super(EvoNormS0, self).__init__('s0', num_features, groups=groups, affine=affine) 101 | 102 | 103 | class DropBlock2d(nn.Module): 104 | r"""Randomly zeroes 2D spatial blocks of the input tensor. 105 | As described in the paper 106 | `DropBlock: A regularization method for convolutional networks`_ , 107 | dropping whole blocks of a feature map removes more semantic 108 | information than regular dropout does. 109 | Args: 110 | p (float): probability of an element to be dropped. 111 | block_size (int): size of the block to drop 112 | Shape: 113 | - Input: `(N, C, H, W)` 114 | - Output: `(N, C, H, W)` 115 | .. _DropBlock: A regularization method for convolutional networks: 116 | https://arxiv.org/abs/1810.12890 117 | """ 118 | def __init__(self, p=0.1, block_size=7): 119 | super(DropBlock2d, self).__init__() 120 | assert 0 <= p <= 1 121 | self.p = p 122 | self.block_size = block_size 123 | 124 | def forward(self, x): 125 | if not self.training or self.p == 0: 126 | return x 127 | _, _, h, w = x.size() 128 | gamma = self.get_gamma(h, w) 129 | mask = self.get_mask(x, gamma) 130 | y = F.drop_block(x, mask) 131 | return y 132 | 133 | @torch.no_grad() 134 | def get_mask(self, x, gamma): 135 | mask = torch.bernoulli(torch.ones_like(x) * gamma) 136 | mask = 1 - torch.max_pool2d(mask, kernel_size=self.block_size, stride=1, padding=self.block_size // 2) 137 | return mask 138 | 139 | def get_gamma(self, h, w): 140 | return self.p * (h * w) / (self.block_size**2) / ((w - self.block_size + 1) * (h - self.block_size + 1)) 141 | 142 | 143 | class DropPath(nn.Module): 144 | """DropPath method. 145 | 146 | Args: 147 | drop_rate (float, optional): probability of dropping the residual path. Defaults to 0. 148 | batch_axis (int, optional): batch dim axis. Defaults to 0. 149 |
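Example:
    A minimal sketch of wrapping a residual branch (`block` is assumed)::

        >>> drop_path = DropPath(drop_rate=0.1)
        >>> x = x + drop_path(block(x))   # the whole branch is dropped per sample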
150 | """ 151 | def __init__(self, drop_rate=0., batch_axis=0): 152 | super().__init__() 153 | self.drop_rate = drop_rate 154 | self.batch_axis = batch_axis 155 | 156 | @torch.no_grad() 157 | def get_param(self, x): 158 | keep_prob = 1 - self.drop_rate 159 | shape = [ 160 | 1, 161 | ] * x.ndim 162 | shape[self.batch_axis] *= x.size(self.batch_axis) 163 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 164 | random_tensor.floor_() 165 | return keep_prob, random_tensor 166 | 167 | def forward(self, x): 168 | if self.drop_rate == 0 or not self.training: 169 | return x 170 | keep_prob, random_tensor = self.get_param(x) 171 | return x.div(keep_prob) * random_tensor 172 | -------------------------------------------------------------------------------- /torchtoolbox/nn/operators.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : DevinYang(pistonyang@gmail.com) 3 | __all__ = ['SwishOP'] 4 | 5 | from torch.autograd import Function 6 | import torch 7 | 8 | 9 | class SwishOP(Function): 10 | @staticmethod 11 | def forward(ctx, tensor, beta=1.0): 12 | ctx.save_for_backward(tensor) 13 | ctx.beta = beta 14 | swish = tensor / (1 + torch.exp(-beta * tensor)) 15 | return swish 16 | 17 | @staticmethod 18 | def backward(ctx, grad_outputs): 19 | tensor = ctx.saved_tensors[0] 20 | beta = ctx.beta 21 | grad_swish = (torch.exp(-beta * tensor) * (1 + beta * tensor) + 1) / \ 22 | (1 + torch.exp(-beta * tensor)) ** 2 23 | grad_swish = grad_outputs * grad_swish 24 | return grad_swish, None 25 | -------------------------------------------------------------------------------- /torchtoolbox/nn/parallel/EncodingDataParallel.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : DevinYang(pistonyang@gmail.com) 3 | """Refers to 'https://github.com/zhanghang1989/PyTorch-Encoding/blob/master/encoding/parallel.py'""" 4 | __all__ = ['EncodingDataParallel', 'EncodingCriterionParallel'] 5 | import threading 6 | import torch 7 | import functools 8 | import torch.cuda.comm as comm 9 | from torch.nn import Module 10 | from itertools import chain 11 | from torch.autograd import Function 12 | from torch.nn.parallel.parallel_apply import get_a_var, parallel_apply 13 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast 14 | from torch.nn.parallel.scatter_gather import scatter_kwargs, gather 15 | from torch.nn.parallel.replicate import replicate 16 | from torch.nn.parallel.data_parallel import _check_balance 17 | from torch.cuda._utils import _get_device_index 18 | from torch._utils import ExceptionWrapper 19 | 20 | 21 | class EncodingParallel(Module): 22 | def __init__(self, module, device_ids=None, output_device=None, dim=0): 23 | super(EncodingParallel, self).__init__() 24 | 25 | if not torch.cuda.is_available(): 26 | self.module = module 27 | self.device_ids = [] 28 | return 29 | if device_ids is None: 30 | device_ids = list(range(torch.cuda.device_count())) 31 | if output_device is None: 32 | output_device = device_ids[0] 33 | 34 | self.dim = dim 35 | self.module = module 36 | self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids)) 37 | self.output_device = _get_device_index(output_device, True) 38 | self.src_device_obj = torch.device("cuda:{}".format(self.device_ids[0])) 39 | 40 | _check_balance(self.device_ids) 41 | 42 | if len(self.device_ids) == 1: 43 | self.module.cuda(device_ids[0]) 44 | 45 | def
replicate(self, module, device_ids): 46 | return replicate(module, device_ids, not torch.is_grad_enabled()) 47 | 48 | def scatter(self, inputs, kwargs, device_ids): 49 | return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim) 50 | 51 | 52 | class EncodingDataParallel(EncodingParallel): 53 | """Implements data parallelism at the module level. 54 | This container parallelizes the application of the given module by 55 | splitting the input across the specified devices by chunking in the 56 | batch dimension. 57 | In the forward pass, the module is replicated on each device, 58 | and each replica handles a portion of the input. During the backwards pass, gradients from each replica are summed into the original module. 59 | Note that the outputs are not gathered, please use compatible 60 | :class:`encoding.parallel.DataParallelCriterion`. 61 | The batch size should be larger than the number of GPUs used. It should 62 | also be an integer multiple of the number of GPUs so that each chunk is 63 | the same size (so that each GPU processes the same number of samples). 64 | Args: 65 | module: module to be parallelized 66 | device_ids: CUDA devices (default: all devices) 67 | Reference: 68 | Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, 69 | Amit Agrawal. “Context Encoding for Semantic Segmentation. 70 | *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018* 71 | Example:: 72 | >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2]) 73 | >>> y = net(x) 74 | """ 75 | def forward(self, *inputs, **kwargs): 76 | if not self.device_ids: 77 | return self.module(*inputs, **kwargs) 78 | 79 | for t in chain(self.module.parameters(), self.module.buffers()): 80 | if t.device != self.src_device_obj: 81 | raise RuntimeError("module must have its parameters and buffers " 82 | "on device {} (device_ids[0]) but found one of " 83 | "them on device: {}".format(self.src_device_obj, t.device)) 84 | inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) 85 | if len(self.device_ids) == 1: 86 | return self.module(*inputs, **kwargs) 87 | replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) 88 | outputs = self.parallel_apply(replicas, inputs, kwargs) 89 | return outputs 90 | 91 | def parallel_apply(self, replicas, inputs, kwargs): 92 | return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)]) 93 | 94 | 95 | class EncodingCriterionParallel(EncodingParallel): 96 | def forward(self, inputs, *targets, **kwargs): 97 | # input should be already scatterd 98 | # scattering the targets instead 99 | 100 | if not self.device_ids: 101 | return self.module(inputs, *targets, **kwargs) 102 | 103 | targets, kwargs = self.scatter(targets, kwargs, self.device_ids) 104 | if len(self.device_ids) == 1: 105 | return self.module(inputs, *targets, **kwargs) 106 | replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) 107 | outputs = self.criterion_apply(replicas, inputs, targets, kwargs) 108 | return ReduceAddCoalesced.apply(self.device_ids[0], len(outputs), *outputs) / len(outputs) 109 | 110 | def criterion_apply(self, replicas, inputs, targets, kwargs): 111 | return criterion_parallel_apply(replicas, inputs, targets, kwargs, self.device_ids[:len(replicas)]) 112 | 113 | 114 | def criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None): 115 | assert len(modules) == len(inputs) 116 | assert len(targets) == len(inputs) 117 | if kwargs_tup is not None: 118 | assert len(modules) == 
len(kwargs_tup) 119 | else: 120 | kwargs_tup = ({}, ) * len(modules) 121 | if devices is not None: 122 | assert len(modules) == len(devices) 123 | else: 124 | devices = [None] * len(modules) 125 | devices = list(map(lambda x: _get_device_index(x, True), devices)) 126 | lock = threading.Lock() 127 | results = {} 128 | grad_enabled = torch.is_grad_enabled() 129 | 130 | def _worker(i, module, input, target, kwargs, device=None): 131 | torch.set_grad_enabled(grad_enabled) 132 | if device is None: 133 | device = get_a_var(input).get_device() 134 | try: 135 | with torch.cuda.device(device): 136 | if not isinstance(input, (list, tuple)): 137 | input = (input, ) 138 | if not isinstance(target, (list, tuple)): 139 | target = (target, ) 140 | output = module(*input, *target, **kwargs) 141 | with lock: 142 | results[i] = output 143 | except Exception: 144 | with lock: 145 | results[i] = ExceptionWrapper(where="in replica {} on device {}".format(i, device)) 146 | 147 | if len(modules) > 1: 148 | threads = [ 149 | threading.Thread(target=_worker, args=(i, module, input, target, kwargs, device)) 150 | for i, (module, input, target, kwargs, device) in enumerate(zip(modules, inputs, targets, kwargs_tup, devices)) 151 | ] 152 | 153 | for thread in threads: 154 | thread.start() 155 | for thread in threads: 156 | thread.join() 157 | else: 158 | _worker(0, modules[0], inputs[0], targets[0], kwargs_tup[0], devices[0]) 159 | 160 | outputs = [] 161 | for i in range(len(inputs)): 162 | output = results[i] 163 | if isinstance(output, ExceptionWrapper): 164 | output.reraise() 165 | outputs.append(output) 166 | return outputs 167 | -------------------------------------------------------------------------------- /torchtoolbox/nn/parallel/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : DevinYang(pistonyang@gmail.com) 3 | 4 | from .EncodingDataParallel import * 5 | -------------------------------------------------------------------------------- /torchtoolbox/nn/sequential.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Author: pistonyang@gmail.com 3 | 4 | __all__ = ['AdaptiveSequential'] 5 | from torch import nn 6 | 7 | 8 | class AdaptiveSequential(nn.Sequential): 9 | """Make Sequential able to handle layers with multiple inputs/outputs.
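If a module returns a tuple, it is unpacked as the positional arguments of the next module; a single tensor is passed through unchanged (see `forward` below).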
10 | 11 | Example: 12 | class n_to_n(nn.Module): 13 | def __init__(self): 14 | super().__init__() 15 | self.conv1 = nn.Conv2d(3, 3, 1, 1, bias=False) 16 | self.conv2 = nn.Conv2d(3, 3, 1, 1, bias=False) 17 | 18 | def forward(self, x1, x2): 19 | y1 = self.conv1(x1) 20 | y2 = self.conv2(x2) 21 | return y1, y2 22 | 23 | 24 | class n_to_one(nn.Module): 25 | def __init__(self): 26 | super().__init__() 27 | self.conv1 = nn.Conv2d(3, 3, 1, 1, bias=False) 28 | self.conv2 = nn.Conv2d(3, 3, 1, 1, bias=False) 29 | 30 | def forward(self, x1, x2): 31 | y1 = self.conv1(x1) 32 | y2 = self.conv2(x2) 33 | return y1 + y2 34 | 35 | 36 | class one_to_n(nn.Module): 37 | def __init__(self): 38 | super().__init__() 39 | self.conv1 = nn.Conv2d(3, 3, 1, 1, bias=False) 40 | self.conv2 = nn.Conv2d(3, 3, 1, 1, bias=False) 41 | 42 | def forward(self, x): 43 | y1 = self.conv1(x) 44 | y2 = self.conv2(x) 45 | return y1, y2 46 | 47 | seq = AdaptiveSequential(one_to_n(), n_to_n(), n_to_one()).cuda() 48 | td = torch.rand(1, 3, 32, 32).cuda() 49 | 50 | out = seq(td) 51 | print(out.size()) 52 | # torch.Size([1, 3, 32, 32]) 53 | 54 | """ 55 | def forward(self, *inputs): 56 | for module in self: 57 | if isinstance(inputs, tuple): 58 | inputs = module(*inputs) 59 | else: 60 | inputs = module(inputs) 61 | return inputs 62 | -------------------------------------------------------------------------------- /torchtoolbox/nn/transformer.py: -------------------------------------------------------------------------------- 1 | __all__ = ['PatchEmbedding', 'PositionEncoding', 'FeedForward', 'Token'] 2 | 3 | import math 4 | 5 | import torch 6 | from torch import nn 7 | 8 | from ..tools import check_twin 9 | from .activation import Activation 10 | 11 | 12 | class PatchEmbedding(nn.Module): 13 | def __init__(self, img_size, patch_size, dim, in_channels=3, norm_layer=None, out_order=('B', 'SL', 'D')): 14 | super().__init__() 15 | assert out_order in (('B', 'SL', 'D'), ('SL', 'B', 'D')) 16 | self.batch_first = True if out_order[0] == 'B' else False 17 | 18 | img_size = check_twin(img_size) 19 | patch_size = check_twin(patch_size) 20 | 21 | patch_grid = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 22 | self.num_patches = patch_grid[0] * patch_grid[1] 23 | 24 | self.dump_patch = nn.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size) 25 | self.norm = norm_layer(dim) if norm_layer is not None else nn.Identity() 26 | 27 | def forward(self, x): 28 | x = self.dump_patch(x).flatten(2) 29 | if self.batch_first: 30 | x = x.transpose(1, 2) 31 | else: 32 | x = x.permute(2, 0, 1) 33 | x = self.norm(x) 34 | return x 35 | 36 | 37 | class PositionEncoding(nn.Module): 38 | def __init__(self, sequence_length, dim, dropout=0., learnable=False, batch_axis=0): 39 | super().__init__() 40 | self.learnable = learnable 41 | 42 | if not learnable: 43 | pe = torch.zeros(sequence_length, dim) 44 | position = torch.arange(0, sequence_length).unsqueeze(1) 45 | div_term = torch.exp(torch.arange(0, dim, 2) * -(math.log(10000.) 
/ dim)) 46 | pe[:, 0::2] = torch.sin(position * div_term) 47 | pe[:, 1::2] = torch.cos(position * div_term) 48 | pe.unsqueeze_(batch_axis) 49 | self.register_buffer('pe', pe) 50 | else: 51 | self.pe = nn.Parameter(torch.Tensor(sequence_length, dim).unsqueeze_(batch_axis)) 52 | nn.init.trunc_normal_(self.pe.data, std=0.02) 53 | 54 | self.dropout = nn.Dropout(dropout) 55 | 56 | def forward(self, x): 57 | x = self.dropout(x + self.pe) 58 | return x 59 | 60 | def no_wd(self, decay: list, no_decay: list): 61 | if self.learnable: 62 | no_decay.append(self.pe) 63 | 64 | def num_param(self, input, output): 65 | if self.learnable: 66 | return self.pe.numel(), 0 67 | else: 68 | return 0, 0 69 | 70 | 71 | class Token(nn.Module): 72 | def __init__(self, num, dim, token_order='first', in_order=('B', 'SL', 'D')): 73 | super().__init__() 74 | assert in_order in (('B', 'SL', 'D'), ('SL', 'B', 'D')) 75 | assert token_order == 'first', "I think we don't need last order, just remember we will add token at first of other data on sl dim." 76 | self.batch_first = True if in_order[0] == 'B' else False 77 | self.token = nn.Parameter(torch.Tensor(num, dim)) 78 | self.num = num 79 | self.reset_parameters() 80 | 81 | def reset_parameters(self): 82 | nn.init.zeros_(self.token) 83 | 84 | def forward(self, x): 85 | if self.batch_first: 86 | b, _, d = x.size() 87 | token = self.token.expand(b, self.num, d) 88 | x = torch.cat([token, x], dim=1) 89 | else: 90 | _, b, d = x.size() 91 | token = self.token.expand(self.num, b, d) 92 | x = torch.cat([token, x], dim=0) 93 | return x 94 | 95 | def no_wd(self, decay: list, no_decay: list): 96 | no_decay.append(self.token) 97 | 98 | 99 | class FeedForward(nn.Module): 100 | def __init__(self, dim, hidden_dim, activation='gelu', dropout=0.): 101 | super().__init__() 102 | # do not add last dropout, if need add after this. 103 | self.ffn = nn.Sequential(nn.Linear(dim, hidden_dim), 104 | Activation(activation), 105 | nn.Dropout(dropout), 106 | nn.Linear(hidden_dim, dim)) # yapf:disable 107 | 108 | def forward(self, x): 109 | return self.ffn(x) 110 | -------------------------------------------------------------------------------- /torchtoolbox/objects/__init__.py: -------------------------------------------------------------------------------- 1 | from .bbox import * 2 | -------------------------------------------------------------------------------- /torchtoolbox/objects/bbox.py: -------------------------------------------------------------------------------- 1 | __all__ = ['BBox'] 2 | 3 | from typing import Union, List, Optional 4 | 5 | import numpy as np 6 | 7 | from ..tools import get_list_index, get_list_value, to_numpy 8 | 9 | # def bbox_center_to_coordinate(center, wh): 10 | # """Convert center, wh to array([x_min, y_min, x_max, y_max]). 11 | 12 | # Args: 13 | # center (Any): bbox center 14 | # wh (Any): bbox wh 15 | # """ 16 | 17 | # center, wh = to_numpy(center).squeeze(), to_numpy(wh).squeeze() 18 | # return np.concatenate([center - wh, center + wh], axis=-1) 19 | 20 | # def bbox_coordinate_to_center(cord: Any): 21 | # """Convert array([x_min, y_min, x_max, y_max]) to center, wh. 22 | 23 | # Args: 24 | # cord (Any): [[x_min, y_min, x_max, y_max], ...] 
25 | 26 | # Returns: 27 | # center (np.ndarray): center 28 | # wh (np.ndarray): wh 29 | # """ 30 | # cord = to_numpy(cord).squeeze().reshape((-1, 2, 2)) 31 | # center = np.mean(cord, axis=1) 32 | # wh = np.max(cord, axis=1) - center 33 | # return center.squeeze(), wh.squeeze() 34 | 35 | # def bbox_clip_boundary(bbox, boundary, mode='keep'): 36 | # """Only support coordinate input, clip bbox by boundary. 37 | 38 | # Args: 39 | # bbox (Any): bbox 40 | # boundary (Any): boundary(x_min, y_min, x_max, y_max) 41 | # mode (str): keep or drop illegal box 42 | # Returns: 43 | # Any: bbox 44 | # """ 45 | # assert mode in ('keep', 'drop') 46 | # if isinstance(bbox, BBox): 47 | # BBox.clip_boundary() 48 | # else: 49 | # bbox = to_numpy(bbox).reshape(-1, 4) 50 | # boundary = to_numpy(boundary).squeeze() 51 | # if mode == 'keep': 52 | # bbox = np.clip(bbox, np.tile(boundary[:2], 2), np.tile(boundary[2:], 2)) 53 | # return bbox 54 | # else: 55 | # ind = np.where((bbox[:, 0] >= boundary[0]) & (bbox[:, 1] >= boundary[1]) & (bbox[:, 2] <= boundary[2]) 56 | # & (bbox[:, 3] <= boundary[3])) 57 | # bbox = bbox[ind] 58 | # return bbox, ind 59 | 60 | 61 | class BBox(object): 62 | """Holds all bboxes of one image. 63 | 64 | Args: 65 | bbox (List[np.ndarray]): bboxes. 66 | mode (str): XYXY or XYWH 67 | category (List[str]): bbox category. 68 | name (str, optional): name. Defaults to None. 69 | """ 70 | def __init__(self, bbox: List[np.ndarray], mode, category: List[str], name: Optional[str] = None) -> None: 71 | super().__init__() 72 | assert mode in ('XYXY', 'XYWH') 73 | if isinstance(bbox[0], (list, tuple)): 74 | self.bbox = np.array(bbox) 75 | elif isinstance(bbox[0], np.ndarray): 76 | self.bbox = np.stack(bbox) 77 | else: 78 | raise ValueError('bbox should be a list of (list or np.ndarray).') 79 | self.category = category 80 | self.mode = mode 81 | 82 | if self.mode != 'XYXY': 83 | self.bbox = self.get_xyxy() 84 | self.mode = 'XYXY' 85 | 86 | self.contain_category = list(set(category)) 87 | self.name = name 88 | 89 | assert self.bbox.shape[0] == len(category), "num of bbox and category must be same." 90 | 91 | def get_xyxy(self): 92 | bbox = self.bbox.copy() 93 | if self.mode == 'XYXY': 94 | return bbox 95 | bbox[:, 2] += bbox[:, 0] 96 | bbox[:, 3] += bbox[:, 1] 97 | return bbox 98 | 99 | def get_xywh(self): 100 | bbox = self.bbox.copy() 101 | assert self.mode == 'XYXY', 'Wrong BBox mode.' 102 | bbox[:, 2] -= bbox[:, 0] 103 | bbox[:, 3] -= bbox[:, 1] 104 | return bbox 105 | 106 | def area(self): 107 | """ 108 | Computes the area of all the boxes. 109 | 110 | Returns: 111 | np.ndarray: a vector with the area of each box. 112 | """ 113 | bbox = self.bbox 114 | area = (bbox[:, 2] - bbox[:, 0]) * (bbox[:, 3] - bbox[:, 1]) 115 | return area 116 | 117 | @property 118 | def empty_bbox(self) -> bool: 119 | return True if self.bbox.shape[0] == 0 else False 120 | 121 | def clip_boundary(self, boundary, mode='keep'): 122 | """Clip bboxes by a boundary (XYXY coordinates only).
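With mode='keep', coordinates are clamped into the boundary; with mode='drop', boxes not fully inside the boundary are removed together with their categories.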
123 | 124 | Args: 125 | boundary (Any): boundary (x_min, y_min, x_max, y_max) 126 | mode (str): 'keep' or 'drop' illegal boxes 127 | 128 | Returns: 129 | BBox: self 130 | """ 131 | assert mode in ('keep', 'drop') 132 | boundary = to_numpy(boundary).squeeze() 133 | if mode == 'keep': 134 | self.bbox = np.clip(self.bbox, np.tile(boundary[:2], 2), np.tile(boundary[2:], 2)) 135 | return self 136 | else: 137 | ind = np.where((self.bbox[:, 0] >= boundary[0]) & (self.bbox[:, 1] >= boundary[1]) 138 | & (self.bbox[:, 2] <= boundary[2]) 139 | & (self.bbox[:, 3] <= boundary[3])) 140 | self.bbox = self.bbox[ind] 141 | self.category = [self.category[i] for i in ind[0]]  # category is a plain list, index it element-wise 142 | return self 143 | 144 | def get_category_bboxes(self, category: str): 145 | index = get_list_index(self.category, category) 146 | return self.bbox[index] 147 | 148 | def get_centers(self): 149 | return (self.bbox[:, :2] + self.bbox[:, 2:]) / 2 150 | 151 | def resize(self): 152 | pass 153 | 154 | def scale(self): 155 | pass 156 | 157 | def crop(self): 158 | pass 159 | 160 | def horizontal_flip(self): 161 | pass 162 | 163 | def vertical_flip(self): 164 | pass 165 | 166 | def __str__(self) -> str: 167 | bbox_str = f"bbox: {self.bbox}\n category: {self.category}\n " 168 | return f"name: {self.name}\n" + bbox_str 169 | 170 | __repr__ = __str__ 171 | 172 | def __len__(self): 173 | return self.bbox.shape[0] 174 | 175 | def __iter__(self): 176 | for box, category in zip(self.bbox, self.category): 177 | yield box, category 178 | 179 | def __getitem__(self, inds: Union[list, tuple, int], name: str = None): 180 | if isinstance(inds, int): 181 | return self.bbox[inds], self.category[inds] 182 | elif isinstance(inds, (list, tuple)): 183 | category = get_list_value(self.category, inds) 184 | inds = to_numpy(inds) 185 | bbox = self.bbox[inds] 186 | return BBox(bbox, 'XYXY', category, self.name if name is None else name) 187 | else: 188 | raise ValueError('Wrong value of inds, only support int, List[int], Tuple[int]') 189 | -------------------------------------------------------------------------------- /torchtoolbox/optimizer/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : DevinYang(pistonyang@gmail.com) 3 | 4 | from .lookahead import * 5 | from .lr_scheduler import * 6 | from .sgd_gc import * 7 | -------------------------------------------------------------------------------- /torchtoolbox/optimizer/drop_block_scheduler.py: -------------------------------------------------------------------------------- 1 | from ..nn.norm import DropBlock2d 2 | 3 | 4 | class DropBlockScheduler(object): 5 | def __init__(self, model, batches: int, num_epochs: int, start_value=0.1, stop_value=1.): 6 | self.model = model 7 | self.iter = 0 8 | self.start_value = start_value 9 | self.num_iter = batches * num_epochs 10 | self.st_line = (stop_value - start_value) / self.num_iter 11 | self.groups = [] 12 | self.value = start_value 13 | 14 | def coll_dbs(md): 15 | if isinstance(md, DropBlock2d): 16 | self.groups.append(md) 17 | 18 | model.apply(coll_dbs) 19 | 20 | def update_values(self, value): 21 | for db in self.groups: 22 | db.p = value 23 | 24 | def load_state_dict(self, state_dict): 25 | """Loads the schedulers state. 26 | 27 | Arguments: 28 | state_dict (dict): scheduler state. Should be an object returned 29 | from a call to :meth:`state_dict`.
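Note that 'model' and 'groups' are deliberately absent from the saved state (see :meth:`state_dict`); they are rebuilt from the model when the scheduler is constructed.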
30 | """ 31 | self.__dict__.update(state_dict) 32 | 33 | def get_value(self): 34 | self.value = self.st_line * self.iter + self.start_value 35 | 36 | def state_dict(self): 37 | return {key: value for key, value in self.__dict__.items() if (key != 'model' and key != 'groups')} 38 | 39 | def step(self): 40 | self.get_value() 41 | self.update_values(self.value) 42 | self.iter += 1 43 | 44 | 45 | class ObjectSchedule(object): 46 | def __init__(self, object, adjust_param, start_epoch, stop_epoch, batches, start_value, stop_value, mode='linear'): 47 | super().__init__() 48 | self.start_iter = start_epoch * batches 49 | self.end_iter = stop_epoch * batches 50 | self.start_value = start_value 51 | self.adjust_param = adjust_param 52 | self.object = object 53 | self.st_base = (stop_value - start_value) / \ 54 | (self.end_iter - self.start_iter) 55 | self.iter = 0 56 | self._value = start_value 57 | 58 | def get_value(self): 59 | self._value = self.st_base * \ 60 | (self.iter - self.start_iter) + self.start_value 61 | 62 | def update_value(self): 63 | setattr(self.object, self.adjust_param, self.value) 64 | 65 | def step(self): 66 | if not (self.iter < self.start_iter or self.iter > self.end_iter): 67 | self.get_value() 68 | self.update_value() 69 | self.iter += 1 70 | 71 | @property 72 | def value(self): 73 | return self._value 74 | 75 | def state_dict(self): 76 | return {key: value for key, value in self.__dict__.items() if key != 'object'} 77 | 78 | def load_state_dict(self, state_dict): 79 | """Loads the schedulers state. 80 | 81 | Arguments: 82 | state_dict (dict): scheduler state. Should be an object returned 83 | from a call to :meth:`state_dict`. 84 | """ 85 | self.__dict__.update(state_dict) 86 | -------------------------------------------------------------------------------- /torchtoolbox/optimizer/lookahead.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : DevinYang(pistonyang@gmail.com) 3 | """Refers to 'https://github.com/alphadl/lookahead.pytorch'.""" 4 | __all__ = ['Lookahead'] 5 | 6 | from collections import defaultdict 7 | from torch.optim.optimizer import Optimizer 8 | import torch 9 | 10 | 11 | class Lookahead(Optimizer): 12 | def __init__(self, optimizer, k=5, alpha=0.5): 13 | self.optimizer = optimizer 14 | self.k = k 15 | self.alpha = alpha 16 | self.param_groups = self.optimizer.param_groups 17 | self.state = defaultdict(dict) 18 | self.fast_state = self.optimizer.state 19 | for group in self.param_groups: 20 | group["counter"] = 0 21 | 22 | def update(self, group): 23 | for fast in group["params"]: 24 | param_state = self.state[fast] 25 | if "slow_param" not in param_state: 26 | param_state["slow_param"] = torch.empty_like(fast.data) 27 | param_state["slow_param"].copy_(fast.data) 28 | slow = param_state["slow_param"] 29 | slow += (fast.data - slow) * self.alpha 30 | fast.data.copy_(slow) 31 | 32 | def update_lookahead(self): 33 | for group in self.param_groups: 34 | self.update(group) 35 | 36 | def step(self, closure=None): 37 | loss = self.optimizer.step(closure) 38 | for group in self.param_groups: 39 | if group["counter"] == 0: 40 | self.update(group) 41 | group["counter"] += 1 42 | if group["counter"] >= self.k: 43 | group["counter"] = 0 44 | return loss 45 | 46 | def state_dict(self): 47 | fast_state_dict = self.optimizer.state_dict() 48 | slow_state = {(id(k) if isinstance(k, torch.Tensor) else k): v for k, v in self.state.items()} 49 | fast_state = fast_state_dict["state"] 50 | param_groups 
= fast_state_dict["param_groups"] 51 | return { 52 | "fast_state": fast_state, 53 | "slow_state": slow_state, 54 | "param_groups": param_groups, 55 | } 56 | 57 | def load_state_dict(self, state_dict): 58 | slow_state_dict = { 59 | "state": state_dict["slow_state"], 60 | "param_groups": state_dict["param_groups"], 61 | } 62 | fast_state_dict = { 63 | "state": state_dict["fast_state"], 64 | "param_groups": state_dict["param_groups"], 65 | } 66 | super(Lookahead, self).load_state_dict(slow_state_dict) 67 | self.optimizer.load_state_dict(fast_state_dict) 68 | self.fast_state = self.optimizer.state 69 | 70 | def add_param_group(self, param_group): 71 | param_group["counter"] = 0 72 | self.optimizer.add_param_group(param_group) 73 | -------------------------------------------------------------------------------- /torchtoolbox/optimizer/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : Devin Yang(pistonyang@gmail.com), Gary Lai (glai9665@gmail.com) 3 | __all__ = ['CosineWarmupLr', 'get_cosine_warmup_lr_scheduler', 'get_layerwise_decay_params_for_bert'] 4 | 5 | from math import pi, cos 6 | from torch.optim.optimizer import Optimizer 7 | from torch.optim.lr_scheduler import LambdaLR 8 | 9 | class Scheduler(object): 10 | def __init__(self): 11 | raise NotImplementedError 12 | 13 | def get_lr(self): 14 | raise NotImplementedError 15 | 16 | def state_dict(self): 17 | """Returns the state of the scheduler as a :class:`dict`. 18 | 19 | It contains an entry for every variable in self.__dict__ which 20 | is not the optimizer. 21 | """ 22 | return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} 23 | 24 | def load_state_dict(self, state_dict): 25 | """Loads the schedulers state. 26 | 27 | Arguments: 28 | state_dict (dict): scheduler state. Should be an object returned 29 | from a call to :meth:`state_dict`. 30 | """ 31 | self.__dict__.update(state_dict) 32 | 33 | class CosineWarmupLr(Scheduler): 34 | """Cosine lr decay function with warmup. 35 | 36 | Lr warmup is proposed by ` 37 | Accurate, Large Minibatch SGD:Training ImageNet in 1 Hour` 38 | `https://arxiv.org/pdf/1706.02677.pdf` 39 | 40 | Cosine decay is proposed by ` 41 | Stochastic Gradient Descent with Warm Restarts` 42 | `https://arxiv.org/abs/1608.03983` 43 | 44 | Args: 45 | optimizer (Optimizer): optimizer of a model. 46 | batches (int): batches per epoch. 47 | epochs (int): epochs to train. 48 | base_lr (float): init lr. 49 | target_lr (float): minimum(final) lr. 50 | warmup_epochs (int): warmup epochs before cosine decay. 51 | warmup_lr (float): warmup starting lr. 52 | last_iter (int): init iteration. 53 | 54 | Attributes: 55 | total_iters (int): number of iterations of all epochs. 56 | total_warmup_iters (int): number of iterations of all warmup epochs.
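Example:
    A minimal sketch (hyper-parameters assumed)::

        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        >>> scheduler = CosineWarmupLr(optimizer, batches=100, epochs=90,
        ...                            base_lr=0.1, warmup_epochs=5)
        >>> for data, target in loader:
        ...     ...
        ...     optimizer.step()
        ...     scheduler.step()  # call once per iteration, not per epoch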
57 | 58 | """ 59 | def __init__(self, 60 | optimizer, 61 | batches: int, 62 | epochs: int, 63 | base_lr: float, 64 | target_lr: float = 0, 65 | warmup_epochs: int = 0, 66 | warmup_lr: float = 0, 67 | last_iter: int = -1): 68 | if not isinstance(optimizer, Optimizer): 69 | raise TypeError('{} is not an Optimizer'.format(type(optimizer).__name__)) 70 | self.optimizer = optimizer 71 | if last_iter == -1: 72 | for group in optimizer.param_groups: 73 | group.setdefault('initial_lr', group['lr']) 74 | last_iter = 0 75 | else: 76 | for i, group in enumerate(optimizer.param_groups): 77 | if 'initial_lr' not in group: 78 | raise KeyError("param 'initial_lr' is not specified " 79 | "in param_groups[{}] when resuming an optimizer".format(i)) 80 | self.baselr = base_lr 81 | self.learning_rate = base_lr 82 | self.total_iters = epochs * batches 83 | self.targetlr = target_lr 84 | self.total_warmup_iters = batches * warmup_epochs 85 | self.total_cosine_iters = self.total_iters - self.total_warmup_iters 86 | self.total_lr_decay = self.baselr - self.targetlr 87 | self.warmup_lr = warmup_lr 88 | self.last_iter = last_iter 89 | self.step() 90 | 91 | def get_lr(self): 92 | if self.last_iter < self.total_warmup_iters: 93 | return self.warmup_lr + \ 94 | (self.baselr - self.warmup_lr) * self.last_iter / self.total_warmup_iters 95 | else: 96 | cosine_iter = self.last_iter - self.total_warmup_iters 97 | cosine_progress = cosine_iter / self.total_cosine_iters 98 | return self.targetlr + self.total_lr_decay * \ 99 | (1 + cos(pi * cosine_progress)) / 2 100 | 101 | def step(self, iteration=None): 102 | """Update status of lr. 103 | 104 | Args: 105 | iteration(int, optional): now training iteration of all epochs. 106 | Usually no need to set it manually. 107 | """ 108 | if iteration is None: 109 | iteration = self.last_iter + 1 110 | self.last_iter = iteration 111 | self.learning_rate = self.get_lr() 112 | for param_group in self.optimizer.param_groups: 113 | param_group['lr'] = self.learning_rate 114 | 115 | def get_cosine_warmup_lr_scheduler(optimizer : Optimizer, 116 | batches_per_epoch: int, 117 | epochs: int, 118 | warmup_epochs: int = 0, 119 | last_epoch: int = -1): 120 | """Similar to CosineWarmupLr, with support for different learning rate for different parameter groups as well as better compatibility with current PyTorch API 121 | 122 | Args: 123 | optimizer (Optimizer): optimizer of a model. 124 | batches_per_epoch (int): batches per epoch. 125 | epochs (int): epochs to train. 126 | warmup_epochs (int): warmup epochs before cosine decay. 127 | last_epoch (int): the index of the last epoch when resuming training. 
128 | 129 | Example: 130 | ``` 131 | batches_per_epoch = 10 132 | epochs = 5 133 | warmup_epochs = 1 134 | params = get_layerwise_decay_params_for_bert(model) 135 | optimizer = optim.SGD(params, lr=3e-5) 136 | lr_scheduler = get_cosine_warmup_lr_scheduler(optimizer, batches_per_epoch, epochs, warmup_epochs=warmup_epochs) 137 | ``` 138 | """ 139 | total_steps = epochs * batches_per_epoch 140 | # warmup params 141 | total_warmup_steps = batches_per_epoch * warmup_epochs 142 | # cosine params 143 | total_cosine_steps = total_steps - total_warmup_steps 144 | 145 | def lr_lambda(current_step): 146 | # lr_lambda should return current lr / top learning rate 147 | if current_step < total_warmup_steps: 148 | warmup_progress = current_step / total_warmup_steps 149 | return warmup_progress 150 | else: 151 | cosine_step = current_step - total_warmup_steps 152 | cosine_progress = cosine_step / total_cosine_steps 153 | return (1 + cos(pi * cosine_progress)) / 2 154 | return LambdaLR(optimizer, lr_lambda, last_epoch) 155 | 156 | 157 | def get_differential_lr_param_group(param_groups, lrs): 158 | """Assigns different learning rates to different parameter groups. 159 | 160 | Discriminative fine-tuning, where different layers of the network have different learning rates, is first proposed in 161 | `Jeremy Howard and Sebastian Ruder. 2018. Universal language model fine-tuning for text classification. 162 | https://arxiv.org/pdf/1801.06146.pdf.` It has been found to stabilize training and speed up convergence. 163 | 164 | Args: 165 | param_groups: a list of parameter groups (each of which is a list of parameters) 166 | param group should look like: 167 | [ 168 | [param1a, param1b, ..] <-- parameter group 1 169 | [param2a, param2b, ..] <-- parameter group 2 170 | ... 171 | ] 172 | lrs: a list of learning rates you want to assign to each of the parameter groups 173 | lrs should look like 174 | [ 175 | lr1, <-- learning rate for parameter group 1 176 | lr2, <-- learning rate for parameter group 2 177 | ... 178 | ] 179 | 180 | Returns: 181 | parameter groups with different learning rates that you can then pass into an optimizer 182 | """ 183 | assert len(param_groups) == len(lrs), f"expect the learning rates to have the same lengths as the param_group length, instead got {len(param_groups)} and {len(lrs)} as lengths respectively" 184 | 185 | param_groups_for_optimizer = [] 186 | for i in range(len(param_groups)): 187 | param_groups_for_optimizer.append({ 188 | 'params': param_groups[i], 189 | 'lr': lrs[i] 190 | }) 191 | return param_groups_for_optimizer 192 | 193 | 194 | def get_layerwise_decay_param_group(param_groups, top_lr=2e-5, decay=0.95): 195 | """Assign layerwise decay learning rates to parameter groups. 196 | 197 | Layer-wise decay learning rate is used in `Chi Sun, Xipeng Qiu, Yige Xu, and Xuanjing Huang. 2019. 198 | How to fine-tune BERT for text classification? https://arxiv.org/abs/1905.05583` to improve convergence 199 | and prevent catastrophic forgetting. 200 | 201 | Args: 202 | param_groups: a list of parameter groups 203 | param group should look like: 204 | [ 205 | [param1a, param1b, ..] <-- parameter group 1 206 | [param2a, param2b, ..] <-- parameter group 2 207 | .. 208 | ] 209 | top_lr: learning rate of the top layer 210 | decay: decay factor. 
When decay < 1, lower layers have lower learning rates; when decay == 1, all layers have the same learning rate 211 | 212 | Returns: 213 | parameter groups with layerwise decay learning rates that you can then pass into an optimizer 214 | 215 | Examples: 216 | ``` 217 | param_groups = get_layerwise_decay_params_group(model_param_groups, top_lr=2e-5, decay=0.95) 218 | optimizer = AdamW(param_groups, lr = 2e-5) 219 | ``` 220 | """ 221 | lrs = [top_lr * pow(decay, len(param_groups)-1-i) for i in range(len(param_groups))] 222 | return get_differential_lr_param_group(param_groups, lrs) 223 | 224 | 225 | def get_layerwise_decay_params_for_bert(model, number_of_layer=12, top_lr=2e-5, decay=0.95): 226 | """Assign layerwise decay learning rates to parameter groups of BERT. 227 | 228 | Layer-wise decay learning rate is used in `Chi Sun, Xipeng Qiu, Yige Xu, and Xuanjing Huang. 2019. 229 | How to fine-tune BERT for text classification? https://arxiv.org/abs/1905.05583` to improve convergence 230 | and prevent catastrophic forgetting. 231 | 232 | Args: 233 | model: your BERT model 234 | number_of_layer: number of layers your BERT has 235 | top_lr: learning rate of the top layer 236 | decay: decay factor. When decay < 1, lower layers have lower learning rates; when decay == 1, all layers have the same learning rate 237 | 238 | Returns: 239 | BERT parameter groups with different learning rates that you can then pass into an optimizer 240 | 241 | Example: 242 | ``` 243 | param_groups = get_layerwise_decay_params_for_bert(model, number_of_layer=12, top_lr=2e-5, decay=0.95) 244 | optimizer = AdamW(param_groups, lr = 2e-5) 245 | ``` 246 | """ 247 | param_groups = get_param_group_for_bert(model, number_of_layer=number_of_layer, top_lr=top_lr, decay=decay) 248 | param_groups_for_optimizer = get_layerwise_decay_param_group(param_groups, top_lr=top_lr, decay=decay) 249 | return param_groups_for_optimizer 250 | 251 | def get_param_group_for_bert(model, number_of_layer=12, top_lr=2e-5, decay=0.95): 252 | """separate each layer of a BERT models into a parameter group 253 | 254 | Args: 255 | model: your BERT model 256 | number_of_layer: number of layers your BERT has 257 | top_lr: learning rate of the top layer 258 | decay: decay factor. When decay < 1, lower layers have lower learning rates; when decay == 1, all layers have the same learning rate 259 | 260 | Returns: 261 | a param group that should look like: 262 | [ 263 | ... 264 | [param1a, param1b, ..] <-- parameter group 1, layer 1 of BERT 265 | [param2a, param2b, ..] <-- parameter group 2, layer 2 of BERT 266 | ... 267 | ] 268 | """ 269 | param_groups_for_optimizer = [[] for _ in range(number_of_layer+2)] # tail, layer0, layer1 ...., layer11, head 270 | head = {'pooler', 'norm', 'relative_attention_bias'} 271 | tail = {'embeddings',} 272 | layers = [f'layer.{i}.' 
for i in range(number_of_layer)] 273 | 274 | for name, param in model.named_parameters(): 275 | if belongs(name, tail): 276 | param_groups_for_optimizer[0].append(param) 277 | elif belongs(name, head): 278 | param_groups_for_optimizer[-1].append(param) 279 | else: 280 | for i, layer in enumerate(layers): 281 | if layer in name: 282 | param_groups_for_optimizer[i+1].append(param) 283 | return param_groups_for_optimizer 284 | 285 | 286 | def belongs(name, groups): 287 | """ checks if name belongs to any of the group 288 | """ 289 | for group in groups: 290 | if group in name: 291 | return True 292 | return False 293 | -------------------------------------------------------------------------------- /torchtoolbox/optimizer/sgd_gc.py: -------------------------------------------------------------------------------- 1 | __all__ = ['SGD_GC'] 2 | 3 | import torch 4 | from torch.optim.optimizer import Optimizer, required 5 | 6 | 7 | class SGD_GC(Optimizer): 8 | r"""Implements stochastic gradient descent (optionally with momentum). 9 | 10 | Nesterov momentum is based on the formula from 11 | `On the importance of initialization and momentum in deep learning`__. 12 | 13 | Args: 14 | params (iterable): iterable of parameters to optimize or dicts defining 15 | parameter groups 16 | lr (float): learning rate 17 | momentum (float, optional): momentum factor (default: 0) 18 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 19 | dampening (float, optional): dampening for momentum (default: 0) 20 | nesterov (bool, optional): enables Nesterov momentum (default: False) 21 | 22 | Example: 23 | >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) 24 | >>> optimizer.zero_grad() 25 | >>> loss_fn(model(input), target).backward() 26 | >>> optimizer.step() 27 | 28 | __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf 29 | 30 | .. note:: 31 | The implementation of SGD with Momentum/Nesterov subtly differs from 32 | Sutskever et. al. and implementations in some other frameworks. 33 | 34 | Considering the specific case of Momentum, the update can be written as 35 | 36 | .. math:: 37 | \begin{aligned} 38 | v_{t+1} & = \mu * v_{t} + g_{t+1}, \\ 39 | p_{t+1} & = p_{t} - \text{lr} * v_{t+1}, 40 | \end{aligned} 41 | 42 | where :math:`p`, :math:`g`, :math:`v` and :math:`\mu` denote the 43 | parameters, gradient, velocity, and momentum respectively. 44 | 45 | This is in contrast to Sutskever et. al. and 46 | other frameworks which employ an update of the form 47 | 48 | .. math:: 49 | \begin{aligned} 50 | v_{t+1} & = \mu * v_{t} + \text{lr} * g_{t+1}, \\ 51 | p_{t+1} & = p_{t} - v_{t+1}. 52 | \end{aligned} 53 | 54 | The Nesterov version is analogously modified. 
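    Compared with plain SGD, this optimizer additionally applies Gradient
    Centralization (assumed to follow https://arxiv.org/abs/2004.01461): for
    gradients with more than 3 dimensions (i.e. conv kernels), the per-filter
    gradient mean is subtracted before the momentum update,

    .. math::
        g_{t+1} \leftarrow g_{t+1} - \mathrm{mean}(g_{t+1})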
55 | """ 56 | def __init__(self, params, lr=required, momentum=0, dampening=0, weight_decay=0, nesterov=False): 57 | if lr is not required and lr < 0.0: 58 | raise ValueError("Invalid learning rate: {}".format(lr)) 59 | if momentum < 0.0: 60 | raise ValueError("Invalid momentum value: {}".format(momentum)) 61 | if weight_decay < 0.0: 62 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 63 | 64 | defaults = dict(lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay, nesterov=nesterov) 65 | if nesterov and (momentum <= 0 or dampening != 0): 66 | raise ValueError("Nesterov momentum requires a momentum and zero dampening") 67 | super(SGD_GC, self).__init__(params, defaults) 68 | 69 | def __setstate__(self, state): 70 | super(SGD_GC, self).__setstate__(state) 71 | for group in self.param_groups: 72 | group.setdefault('nesterov', False) 73 | 74 | @torch.no_grad() 75 | def step(self, closure=None): 76 | """Performs a single optimization step. 77 | 78 | Arguments: 79 | closure (callable, optional): A closure that reevaluates the model 80 | and returns the loss. 81 | """ 82 | loss = None 83 | if closure is not None: 84 | with torch.enable_grad(): 85 | loss = closure() 86 | 87 | for group in self.param_groups: 88 | weight_decay = group['weight_decay'] 89 | momentum = group['momentum'] 90 | dampening = group['dampening'] 91 | nesterov = group['nesterov'] 92 | 93 | for p in group['params']: 94 | if p.grad is None: 95 | continue 96 | d_p = p.grad 97 | if weight_decay != 0: 98 | d_p = d_p.add(p, alpha=weight_decay) 99 | if len(d_p.size()) > 3: 100 | d_p.add_(-d_p.mean(dim=tuple(range(1, len(d_p.size()))), keepdim=True)) 101 | 102 | if momentum != 0: 103 | param_state = self.state[p] 104 | if 'momentum_buffer' not in param_state: 105 | buf = param_state['momentum_buffer'] = torch.clone(d_p).detach() 106 | else: 107 | buf = param_state['momentum_buffer'] 108 | buf.mul_(momentum).add_(d_p, alpha=1 - dampening) 109 | if nesterov: 110 | d_p = d_p.add(buf, alpha=momentum) 111 | else: 112 | d_p = buf 113 | 114 | p.add_(d_p, alpha=-group['lr']) 115 | 116 | return loss 117 | -------------------------------------------------------------------------------- /torchtoolbox/tools/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Author: pistonyang@gmail.com 3 | 4 | from .reset_model_setting import * 5 | from .mixup import * 6 | from .summary import * 7 | from .utils import * 8 | from .distribute import * 9 | from .dotdict import DotDict 10 | from .config_parser import parse_config 11 | from .registry import Registry 12 | -------------------------------------------------------------------------------- /torchtoolbox/tools/config_parser.py: -------------------------------------------------------------------------------- 1 | from .dotdict import DotDict 2 | import yaml 3 | import pathlib 4 | 5 | 6 | def parse_config(config_file: str): 7 | config_file = pathlib.Path(config_file) 8 | assert config_file.suffix in ('.yml', '.yaml'), "Only support yaml files." 9 | with open(config_file, 'r') as f: 10 | cfg = yaml.load(f, Loader=yaml.SafeLoader) 11 | cfg = DotDict(cfg) 12 | circulate_parse(cfg, config_file.parent) 13 | return cfg 14 | 15 | 16 | def merge_dict(target_dict: dict, sub_dict: dict, key: str, replace=False): 17 | """merge sub dict to target dict. 18 | 19 | Now special key is only `__base__`, this can be added if needed. 
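    With key == '__base__', every entry of sub_dict is merged into target_dict
    (existing keys in target_dict win unless replace=True) and the '__base__'
    key itself is removed; any other key simply stores sub_dict under that key.

    Example (a hypothetical pair of configs)::

        # train.yaml                # base.yaml
        __base__: base.yaml         lr: 0.1
        epochs: 90                  epochs: 100

        parse_config('train.yaml')  # -> {'epochs': 90, 'lr': 0.1}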
20 | 21 | Args: 22 | target_dict (dict): target to merge 23 | sub_dict (dict): merge this to target 24 | key (str): key word 25 | replace (bool, optional): Whether replace if target sub key is not None. Defaults to False. 26 | """ 27 | if key == '__base__': 28 | for sub_key, sub_value in sub_dict.items(): 29 | if sub_key not in target_dict.keys() or replace: 30 | target_dict[sub_key] = sub_value 31 | target_dict.pop(key) 32 | 33 | else: 34 | target_dict[key] = sub_dict 35 | 36 | 37 | def circulate_parse(parse_dict, base_path: pathlib.Path, parse_target='yaml'): 38 | for key, value in parse_dict.copy().items(): 39 | if isinstance(value, str) and value.endswith(f".{parse_target}"): 40 | sub_config_path = base_path.joinpath(value).resolve() 41 | config = parse_config(sub_config_path) 42 | merge_dict(parse_dict, config, key) 43 | elif isinstance(value, (dict, DotDict)): 44 | circulate_parse(value, base_path) 45 | -------------------------------------------------------------------------------- /torchtoolbox/tools/convert_lmdb.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : DevinYang(pistonyang@gmail.com) 3 | """This file should not be included in __init__""" 4 | 5 | __all__ = ['get_key', 'load_pyarrow', 'dumps_pyarrow', 'generate_lmdb_dataset', 'raw_reader'] 6 | 7 | import lmdb 8 | import os 9 | import pyarrow 10 | from torch.utils.data import DataLoader, Dataset 11 | from tqdm import tqdm 12 | from .utils import check_dir 13 | 14 | 15 | def get_key(index): 16 | return u'{}'.format(index).encode('ascii') 17 | 18 | 19 | def raw_reader(path): 20 | with open(path, 'rb') as f: 21 | bin_data = f.read() 22 | return bin_data 23 | 24 | 25 | def dumps_pyarrow(obj): 26 | return pyarrow.serialize(obj).to_buffer() 27 | 28 | 29 | def load_pyarrow(buf): 30 | assert buf is not None, 'buf should not be None.' 31 | return pyarrow.deserialize(buf) 32 | 33 | 34 | def generate_lmdb_dataset(data_set: Dataset, save_dir: str, name: str, num_workers=0, max_size_rate=1.0, write_frequency=5000): 35 | data_loader = DataLoader(data_set, num_workers=num_workers, collate_fn=lambda x: x) 36 | num_samples = len(data_set) 37 | check_dir(save_dir) 38 | lmdb_path = os.path.join(save_dir, '{}.lmdb'.format(name)) 39 | db = lmdb.open(lmdb_path, 40 | subdir=False, 41 | map_size=int(1099511627776 * max_size_rate), 42 | readonly=False, 43 | meminit=True, 44 | map_async=True) 45 | txn = db.begin(write=True) 46 | for idx, data in enumerate(tqdm(data_loader)): 47 | txn.put(get_key(idx), dumps_pyarrow(data[0])) 48 | if idx % write_frequency == 0 and idx > 0: 49 | txn.commit() 50 | txn = db.begin(write=True) 51 | txn.put(b'__len__', dumps_pyarrow(num_samples)) 52 | try: 53 | classes = data_set.classes 54 | class_to_idx = data_set.class_to_idx 55 | txn.put(b'classes', dumps_pyarrow(classes)) 56 | txn.put(b'class_to_idx', dumps_pyarrow(class_to_idx)) 57 | except AttributeError: 58 | pass 59 | 60 | txn.commit() 61 | db.sync() 62 | db.close() 63 | -------------------------------------------------------------------------------- /torchtoolbox/tools/distribute.py: -------------------------------------------------------------------------------- 1 | __all__ = ['reduce_tensor'] 2 | 3 | from .utils import to_numpy 4 | from torch import distributed 5 | 6 | import torch 7 | import numpy as np 8 | 9 | 10 | def reduce_tensor(tensor, rank, op=distributed.ReduceOp.SUM, dst=0, reduce_type='reduce'): 11 | """Reduce tensor cross ranks. 
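    Accepts a python number, a torch.Tensor or a np.ndarray; the value is moved
    to the device of `rank`, reduced via torch.distributed, and converted back
    to its original type where applicable.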
12 | 13 | Args: 14 | tensor: tensor need to be reduced. 15 | rank(int): rank where tensor at. 16 | op: reduce op, use `sum` by default. 17 | dst(int): only used for reduce_type=='reduce' 18 | reduce_type(str): only support reduce or all_reduce. 19 | 20 | Returns: 21 | tensor after reduced. 22 | """ 23 | post_process = None 24 | device = torch.device(rank) 25 | if isinstance(tensor, (int, float)): 26 | tensor = torch.tensor(tensor, device=device) 27 | post_process = lambda x: x.item() 28 | elif torch.is_tensor(tensor): 29 | if tensor.device.index != rank: 30 | tensor = tensor.to(device) 31 | elif isinstance(tensor, np.ndarray): 32 | tensor = torch.from_numpy(tensor).to(device=device) 33 | post_process = lambda x: to_numpy(x) 34 | else: 35 | raise NotImplementedError(f'Only Pytorch Tensor, Python(float, int) and' 36 | f' Numpy ndarray are supported. But got {type(tensor)}') 37 | 38 | if reduce_type == 'reduce': 39 | distributed.reduce(tensor, dst, op=op) 40 | else: 41 | distributed.all_reduce(tensor, op=op) 42 | 43 | # if not all_reduce only process dst tensor 44 | if post_process is not None: 45 | if reduce_type == 'all_reduce': 46 | tensor = post_process(tensor) 47 | elif rank == dst: 48 | tensor = post_process(tensor) 49 | 50 | return tensor 51 | -------------------------------------------------------------------------------- /torchtoolbox/tools/dotdict.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional 2 | 3 | 4 | class DotDict(dict): 5 | def __init__(self, data_map: Optional[Dict] = None): 6 | if data_map is not None: 7 | super(DotDict, self).__init__(data_map) 8 | if isinstance(data_map, dict): 9 | for k, v in data_map.items(): 10 | if not isinstance(v, dict): 11 | self[k] = v 12 | else: 13 | self.__setattr__(k, DotDict(v)) 14 | else: 15 | super().__init__() 16 | 17 | def __getattr__(self, attr): 18 | return self.get(attr) 19 | 20 | def __setattr__(self, key, value): 21 | self.__setitem__(key, value) 22 | 23 | def __setitem__(self, key, value): 24 | super().__setitem__(key, value) 25 | self.__dict__.update({key: value}) 26 | 27 | def __delattr__(self, item): 28 | self.__delitem__(item) 29 | 30 | def __delitem__(self, key): 31 | super().__delitem__(key) 32 | del self.__dict__[key] 33 | 34 | def __str__(self) -> str: 35 | dump_str = '(' 36 | for k, v in self.__dict__.items(): 37 | dump_str += f"{k}={v}; " 38 | dump_str = dump_str[:-2] + ")" 39 | return dump_str 40 | 41 | def pop(self, key): 42 | value = super().pop(key) 43 | del self.__dict__[key] 44 | return value 45 | 46 | @classmethod 47 | def to_dict(cls, dot_dict): 48 | new_dict = {} 49 | for key, value in dot_dict.items(): 50 | if isinstance(value, cls): 51 | new_dict[key] = cls.as_dict(value) 52 | else: 53 | new_dict[key] = value 54 | return new_dict 55 | 56 | def as_dict(self): 57 | return self.to_dict(self) 58 | -------------------------------------------------------------------------------- /torchtoolbox/tools/mixup.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : PistonYang(pistonyang@gmail.com) 3 | 4 | from torch import nn 5 | import numpy as np 6 | import torch 7 | import random 8 | 9 | __all__ = ['mixup_data', 'mixup_criterion', 'cutmix_data', 'MixingDataController'] 10 | 11 | 12 | def cut_mix_rand_bbox(size, lam): 13 | H = size[2] 14 | W = size[3] 15 | cut_rat = np.sqrt(1. 
- lam) 16 | cut_h = int(H * cut_rat) 17 | cut_w = int(W * cut_rat) 18 | 19 | # uniform 20 | cy = np.random.randint(H) 21 | cx = np.random.randint(W) 22 | 23 | bby1 = np.clip(cy - cut_h // 2, 0, H) 24 | bbx1 = np.clip(cx - cut_w // 2, 0, W) 25 | bby2 = np.clip(cy + cut_h // 2, 0, H) 26 | bbx2 = np.clip(cx + cut_w // 2, 0, W) 27 | 28 | return bby1, bbx1, bby2, bbx2 29 | 30 | 31 | @torch.no_grad() 32 | def mixup_data(x, y, alpha=0.2): 33 | """Returns mixed inputs, pairs of targets, and lambda 34 | """ 35 | if alpha > 0: 36 | lam = np.random.beta(alpha, alpha) 37 | else: 38 | lam = 1 39 | 40 | mixed_x = lam * x + (1 - lam) * x.flip(dims=(0, )) 41 | y_a, y_b = y, y.flip(dims=(0, )) 42 | return mixed_x, y_a, y_b, lam 43 | 44 | 45 | @torch.no_grad() 46 | def cutmix_data(x, y, alpha=0.2): 47 | if alpha > 0: 48 | lam = np.random.beta(alpha, alpha) 49 | else: 50 | lam = 1 51 | rand_index = torch.randperm(x.size(0)) 52 | y_a, y_b = y, y[rand_index] 53 | bby1, bbx1, bby2, bbx2 = cut_mix_rand_bbox(x.size(), lam) 54 | x[:, :, bby1:bby2, bbx1:bbx2] = x[rand_index, :, bby1:bby2, bbx1:bbx2] 55 | return x, y_a, y_b, lam 56 | 57 | 58 | def mixup_criterion(criterion, pred, y_a, y_b, lam): 59 | return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b) 60 | 61 | 62 | class MixingDataController(nn.Module): 63 | def __init__(self, mixup=False, cutmix=False, mixup_alpha=0.2, cutmix_alpha=1.0, mixup_prob=1.0, cutmix_prob=1.0): 64 | super().__init__() 65 | self.mixup = mixup 66 | self.cutmix = cutmix 67 | self.mixup_alpha = mixup_alpha 68 | self.cutmix_alpha = cutmix_alpha 69 | self.mixup_prob = mixup_prob 70 | self.cutmix_prob = cutmix_prob 71 | 72 | def setup_mixup(self, enable, alpha, probability): 73 | self.mixup = enable 74 | self.mixup_alpha = alpha 75 | self.mixup_prob = probability 76 | 77 | def setup_cutmix(self, enable, alpha, probability): 78 | self.cutmix = enable 79 | self.cutmix_alpha = alpha 80 | self.cutmix_prob = probability 81 | 82 | def get_method(self): 83 | mu_w = self.mixup_prob if self.mixup else 0. 84 | cm_w = self.cutmix_prob if self.cutmix else 0. 85 | if self.mixup and self.cutmix: 86 | mu_w *= 0.5 87 | cm_w *= 0.5 88 | no_w = 1 - mu_w - cm_w 89 | return random.choices(['mixup', 'cutmix', None], weights=[mu_w, cm_w, no_w], k=1)[0] 90 | 91 | def get_loss(self, Loss, data, labels, preds): 92 | md = self.get_method() 93 | if md == 'mixup': 94 | data, labels_a, labels_b, lam = mixup_data(data, labels, self.mixup_alpha) 95 | loss = mixup_criterion(Loss, preds, labels_a, labels_b, lam) 96 | elif md == 'cutmix': 97 | data, labels_a, labels_b, lam = cutmix_data(data, labels, self.cutmix_alpha) 98 | loss = mixup_criterion(Loss, preds, labels_a, labels_b, lam) 99 | else: 100 | loss = Loss(preds, labels) 101 | return loss 102 | -------------------------------------------------------------------------------- /torchtoolbox/tools/registry.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional 2 | 3 | 4 | class Registry: 5 | """Provide a string-to-object mapping. 6 | Args: 7 | name (str): name of this registry. 8 | """ 9 | def __init__(self, name: str) -> None: 10 | self.name = name 11 | self._obj_map: Dict[str, object] = {} 12 | 13 | def _do_register(self, name: str, obj: object = None) -> None: 14 | assert name not in self._obj_map, \ 15 | f"An object named '{name}' was already registered in '{self.name}' registry!"
-------------------------------------------------------------------------------- /torchtoolbox/tools/registry.py: --------------------------------------------------------------------------------
1 | from typing import Dict, Optional
2 | 
3 | 
4 | class Registry:
5 |     """Provide a string-to-object mapping.
6 | 
7 |     Args:
8 |         name (str): name of this registry.
9 |     """
10 |     def __init__(self, name: str) -> None:
11 |         self.name = name
12 |         self._obj_map: Dict[str, object] = {}
13 | 
14 |     def _do_register(self, name: str, obj: object = None) -> None:
15 |         assert name not in self._obj_map, \
16 |             f"An object named '{name}' was already registered in '{self.name}' registry!"
17 |         self._obj_map[name] = obj
18 | 
19 |     def register(self, obj: object = None, name: str = None) -> Optional[object]:
20 |         """Register the given object.
21 | 
22 |         Args:
23 |             obj (object, optional): obj to register. Defaults to None, in which case a decorator is returned.
24 |             name (str, optional): specific name for this obj.
25 | 
26 |         Returns:
27 |             Optional[object]: the original obj.
28 |         """
29 |         if obj is None:
30 | 
31 |             def deco(func_or_class: object) -> object:
32 |                 _name = func_or_class.__name__ if name is None else name
33 |                 self._do_register(_name, func_or_class)
34 |                 return func_or_class
35 | 
36 |             return deco
37 | 
38 |         _name = obj.__name__ if name is None else name
39 |         self._do_register(_name, obj)
40 |         return obj
41 | 
42 |     def get(self, name: str) -> object:
43 |         ret = self._obj_map.get(name)
44 |         if ret is None:
45 |             raise KeyError(f'No object named {name} found in {self.name} registry.')
46 |         return ret
47 | 
48 |     def __contains__(self, name: str) -> bool:
49 |         return name in self._obj_map
50 | 
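A hedged usage sketch for `Registry`; the `BACKBONES` registry and `resnet_like` factory are invented names:

    from torchtoolbox.tools.registry import Registry

    BACKBONES = Registry('backbone')

    @BACKBONES.register()            # obj is None here, so a decorator is returned
    def resnet_like(**kwargs):
        ...

    assert 'resnet_like' in BACKBONES
    build_fn = BACKBONES.get('resnet_like')   # look the callable up by its name
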
-------------------------------------------------------------------------------- /torchtoolbox/tools/reset_model_setting.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | __all__ = ['no_decay_bias', 'reset_model_setting', 'ZeroLastGamma']
3 | 
4 | from .utils import to_list
5 | from torch import nn
6 | 
7 | 
8 | def no_decay_bias(net, extra_conv=()):
9 |     """Split network weights into two categories:
10 |     one contains the weights of conv and linear layers,
11 |     the other contains the remaining learnable parameters
12 |     (conv bias, bn weight, bn bias, linear bias).
13 |     Args:
14 |         net: network architecture.
15 |         extra_conv: extra conv-like module types whose weights should also decay.
16 |     Returns:
17 |         a list of two param groups split as described above.
18 |     """
19 |     extra_conv = to_list(extra_conv)
20 | 
21 |     decay = []
22 |     no_decay = []
23 | 
24 |     for m in net.modules():
25 |         if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear, *extra_conv)):
26 |             decay.append(m.weight)
27 |             if m.bias is not None:
28 |                 no_decay.append(m.bias)
29 |         elif isinstance(m, nn.MultiheadAttention):
30 |             decay.append(m.in_proj_weight)
31 |             if m.in_proj_bias is not None:
32 |                 no_decay.append(m.in_proj_bias)
33 |         elif hasattr(m, 'no_wd') and callable(getattr(m, 'no_wd')):
34 |             m.no_wd(decay, no_decay)
35 |         else:
36 |             if hasattr(m, 'weight') and m.weight is not None:
37 |                 no_decay.append(m.weight)
38 |             if hasattr(m, 'bias') and m.bias is not None:
39 |                 no_decay.append(m.bias)
40 | 
41 |     assert len(list(net.parameters())) == len(decay) + len(no_decay)
42 | 
43 |     return [dict(params=decay), dict(params=no_decay, weight_decay=0)]
44 | 
45 | 
46 | def reset_model_setting(model, layer_names, setting_names, values):
47 |     """Split model params into two parts: one keeps the normal settings,
48 |     the other has some settings overridden manually.
49 | 
50 |     Args:
51 |         model: model to control.
52 |         layer_names: layers whose settings should change.
53 |         setting_names: param group keys to reset (e.g. 'lr').
54 |         values: the new values.
55 | 
56 |     Returns: new params dict
57 | 
58 |     For example:
59 |         parameters = reset_model_setting(model, 'output', 'lr', 0.1)
60 |     """
61 |     layer_names, setting_names, values = map(to_list, (layer_names, setting_names, values))
62 |     assert len(setting_names) == len(values)
63 |     ignore_params = []
64 |     for name in layer_names:
65 |         ignore_params.extend(list(map(id, getattr(model, name).parameters())))
66 | 
67 |     base_param = filter(lambda p: id(p) not in ignore_params, model.parameters())
68 |     reset_param = filter(lambda p: id(p) in ignore_params, model.parameters())
69 | 
70 |     # note: dict.update returns None, so merge the overrides into the group directly
71 |     parameters = [{'params': base_param}, {'params': reset_param, **dict(zip(setting_names, values))}]
72 |     return parameters
73 | 
74 | 
75 | class ZeroLastGamma(object):
76 |     def __init__(self, block_name='Bottleneck', bn_name='bn3'):
77 |         self.block_name = block_name
78 |         self.bn_name = bn_name
79 | 
80 |     def __call__(self, module):
81 |         if module.__class__.__name__ == self.block_name:
82 |             target_bn = module.__getattr__(self.bn_name)
83 |             nn.init.zeros_(target_bn.weight)
84 | 
85 | 
86 | class SchedulerCollector(object):
87 |     def __init__(self):
88 |         self.schedulers = []
89 | 
90 |     def register(self, scheduler):
91 |         self.schedulers.append(scheduler)
92 | 
93 |     def step(self):
94 |         for shd in self.schedulers:
95 |             shd.step()
96 | 
97 |     def state_dict(self):
98 |         return {str(idx): value.__dict__ for idx, value in enumerate(self.schedulers)}
99 | 
100 |     def load_state_dict(self, state_dict):
101 |         for key, values in state_dict.items():
102 |             self.schedulers[int(key)].__dict__.update(values.items())
103 | 
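A hedged sketch of feeding the `no_decay_bias` param groups to an optimizer; `model` and the hyper-parameters are placeholders:

    import torch
    from torchtoolbox.tools.reset_model_setting import no_decay_bias

    param_groups = no_decay_bias(model)   # [decayed weights, weight_decay=0 group]
    optimizer = torch.optim.SGD(param_groups, lr=0.1, momentum=0.9, weight_decay=1e-4)
    # weight decay now only touches conv/linear weights; biases and norm
    # parameters sit in the second group with weight_decay=0.
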
-------------------------------------------------------------------------------- /torchtoolbox/tools/summary.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author  : DevinYang(pistonyang@gmail.com)
3 | __all__ = ['summary']
4 | 
5 | from collections import OrderedDict
6 | import torch
7 | import torch.nn as nn
8 | import numpy as np
9 | 
10 | 
11 | def _flops_str(flops):
12 |     preset = [(1e12, 'T'), (1e9, 'G'), (1e6, 'M'), (1e3, 'K')]
13 | 
14 |     for p in preset:
15 |         if flops // p[0] > 0:
16 |             N = flops / p[0]
17 |             ret = "%.1f%s" % (N, p[1])
18 |             return ret
19 |     ret = "%.1f" % flops
20 |     return ret
21 | 
22 | 
23 | def _cac_grad_params(p, w):
24 |     t, n = 0, 0
25 |     if w.requires_grad:
26 |         t += p
27 |     else:
28 |         n += p
29 |     return t, n
30 | 
31 | 
32 | def _cac_msa(layer, input, output):
33 |     sl, b, dim = output[0].size()
34 |     assert b == 1, 'Only support batch size of 1.'
35 |     tb_params = 0
36 |     ntb__params = 0
37 |     flops = 0
38 | 
39 |     if layer._qkv_same_embed_dim is False:
40 |         tb_params += layer.q_proj_weight.numel()
41 |         tb_params += layer.k_proj_weight.numel()
42 |         tb_params += layer.v_proj_weight.numel()
43 |     else:
44 |         tb_params += layer.in_proj_weight.numel()
45 | 
46 |     if hasattr(layer, 'in_proj_bias'):
47 |         tb_params += layer.in_proj_bias.numel()
48 | 
49 |     tb_params += layer.embed_dim**2
50 | 
51 |     # the flops of this layer are fixed
52 |     # first get KQV
53 |     flops += sl * dim * 3 * (2 * dim - 1)
54 |     if hasattr(layer, 'in_proj_bias'):
55 |         flops += dim * 3
56 |     # then cac sa
57 |     num_heads = layer.num_heads
58 |     head_dim = layer.head_dim
59 |     flops += (num_heads * sl * sl * (2 * head_dim - 1) + num_heads * sl * head_dim * (2 * sl - 1))
60 |     # last linear
61 |     flops += sl * dim * (2 * dim - 1) + dim
62 |     return tb_params, ntb__params, flops
63 | 
64 | 
65 | def _cac_conv(layer, input, output):
66 |     # bs, ic, ih, iw = input[0].shape
67 |     oh, ow = output.shape[-2:]
68 |     kh, kw = layer.kernel_size
69 |     ic, oc = layer.in_channels, layer.out_channels
70 |     g = layer.groups
71 | 
72 |     tb_params = 0
73 |     ntb__params = 0
74 |     flops = 0
75 |     if hasattr(layer, 'weight') and hasattr(layer.weight, 'shape'):
76 |         params = np.prod(layer.weight.shape)
77 |         t, n = _cac_grad_params(params, layer.weight)
78 |         tb_params += t
79 |         ntb__params += n
80 |         flops += (2 * ic * kh * kw - 1) * oh * ow * (oc // g)
81 |     if hasattr(layer, 'bias') and hasattr(layer.bias, 'shape'):
82 |         params = np.prod(layer.bias.shape)
83 |         t, n = _cac_grad_params(params, layer.bias)
84 |         tb_params += t
85 |         ntb__params += n
86 |         flops += oh * ow * (oc // g)
87 |     return tb_params, ntb__params, flops
88 | 
89 | 
90 | def _cac_xx_norm(layer, input, output):
91 |     tb_params = 0
92 |     ntb__params = 0
93 |     if hasattr(layer, 'weight') and hasattr(layer.weight, 'shape'):
94 |         params = np.prod(layer.weight.shape)
95 |         t, n = _cac_grad_params(params, layer.weight)
96 |         tb_params += t
97 |         ntb__params += n
98 |     if hasattr(layer, 'bias') and hasattr(layer.bias, 'shape'):
99 |         params = np.prod(layer.bias.shape)
100 |         t, n = _cac_grad_params(params, layer.bias)
101 |         tb_params += t
102 |         ntb__params += n
103 |     if hasattr(layer, 'running_mean') and hasattr(layer.running_mean, 'shape'):
104 |         params = np.prod(layer.running_mean.shape)
105 |         ntb__params += params
106 |     if hasattr(layer, 'running_var') and hasattr(layer.running_var, 'shape'):
107 |         params = np.prod(layer.running_var.shape)
108 |         ntb__params += params
109 |     in_shape = input[0]
110 |     flops = np.prod(in_shape.shape)
111 |     if layer.affine:
112 |         flops *= 2
113 |     return tb_params, ntb__params, flops
114 | 
115 | 
116 | def _cac_linear(layer, input, output):
117 |     ic, oc = layer.in_features, layer.out_features
118 | 
119 |     tb_params = 0
120 |     ntb__params = 0
121 |     flops = 0
122 | 
123 |     input = input[0]
124 |     in_len = len(input.size())
125 |     if in_len == 2:
126 |         if hasattr(layer, 'weight') and hasattr(layer.weight, 'shape'):
127 |             params = np.prod(layer.weight.shape)
128 |             t, n = _cac_grad_params(params, layer.weight)
129 |             tb_params += t
130 |             ntb__params += n
131 |             flops += (2 * ic - 1) * oc
132 |         if hasattr(layer, 'bias') and hasattr(layer.bias, 'shape'):
133 |             params = np.prod(layer.bias.shape)
134 |             t, n = _cac_grad_params(params, layer.bias)
135 |             tb_params += t
136 |             ntb__params += n
137 |             flops += oc
138 |         return tb_params, ntb__params, flops
139 |     elif in_len == 3:
140 |         if input.size(0) == 1:
141 |             sl, dim = input.shape[1:]
142 |         elif input.size(1) == 1:
143 |             sl, _, dim = input.shape
144 |         else:
145 |             raise ValueError('Only support batch size of 1.')
146 | 
147 |         if hasattr(layer, 'weight') and hasattr(layer.weight, 'shape'):
148 |             params = np.prod(layer.weight.shape)
149 |             t, n = _cac_grad_params(params, layer.weight)
150 |             tb_params += t
151 |             ntb__params += n
152 |             flops += sl * (2 * ic - 1) * oc
153 |         if hasattr(layer, 'bias') and hasattr(layer.bias, 'shape'):
154 |             params = np.prod(layer.bias.shape)
155 |             t, n = _cac_grad_params(params, layer.bias)
156 |             tb_params += t
157 |             ntb__params += n
158 |             flops += oc
159 |         return tb_params, ntb__params, flops
160 | 
161 |     else:
162 |         raise NotImplementedError
163 | 
164 | 
165 | @torch.no_grad()
166 | def summary(model, x, return_results=False, extra_conv=(), extra_norm=(), extra_linear=()):
167 |     """Summarize a model: per-layer output shape, parameter count and FLOPs.
168 | 
169 |     Args:
170 |         model (nn.Module): model to summarize.
171 |         x (torch.Tensor): input data (use batch size 1 for attention layers).
172 |         return_results (bool): if True, return (total_params, total_flops) instead of the report string.
173 |         extra_conv, extra_norm, extra_linear: extra module types counted by the conv/norm/linear rules.
174 | 
175 |     Returns: the formatted report string, or (total_params, total_flops).
176 |     """
177 |     # eval mode: make BatchNorm/Dropout deterministic for the forward pass
178 |     model.eval()
179 | 
180 |     def register_hook(layer):
181 |         def hook(layer, input, output):
182 |             model_name = str(layer.__class__.__name__)
183 |             module_idx = len(model_summary)
184 |             s_key = '{}-{}'.format(model_name, module_idx + 1)
185 |             model_summary[s_key] = OrderedDict()
186 |             model_summary[s_key]['input_shape'] = list(input[0].shape)
187 |             if isinstance(output, (tuple, list)):
188 |                 model_summary[s_key]['output_shape'] = [list(o.shape) for o in output]
189 |             else:
190 |                 model_summary[s_key]['output_shape'] = list(output.shape)
191 |             tb_params = 0
192 |             ntb__params = 0
193 |             flops = 0
194 | 
195 |             if isinstance(layer, (nn.Conv2d, *extra_conv)):
196 |                 tb_params, ntb__params, flops = _cac_conv(layer, input, output)
197 |             elif isinstance(layer, (nn.BatchNorm2d, nn.GroupNorm, *extra_norm)):
198 |                 tb_params, ntb__params, flops = _cac_xx_norm(layer, input, output)
199 |             elif isinstance(layer, (nn.Linear, *extra_linear)):
200 |                 tb_params, ntb__params, flops = _cac_linear(layer, input, output)
201 |             elif isinstance(layer, nn.MultiheadAttention):
202 |                 tb_params, ntb__params, flops = _cac_msa(layer, input, output)
203 | 
204 |             if hasattr(layer, 'num_param') and callable(getattr(layer, 'num_param')):
205 |                 assert tb_params == 0 and ntb__params == 0, 'params has been calculated by default func.'
206 |                 tb_params, ntb__params = layer.num_param(input, output)
207 | 
208 |             if hasattr(layer, 'flops') and callable(getattr(layer, 'flops')):
209 |                 assert flops == 0, 'flops has been calculated by default func.'
210 |                 flops = layer.flops(input, output)
211 | 
212 |             model_summary[s_key]['trainable_params'] = tb_params
213 |             model_summary[s_key]['non_trainable_params'] = ntb__params
214 |             model_summary[s_key]['params'] = tb_params + ntb__params
215 |             model_summary[s_key]['flops'] = flops
216 | 
217 |         if not isinstance(layer, (nn.Sequential, nn.ModuleList, nn.Identity, nn.ModuleDict)):
218 |             hooks.append(layer.register_forward_hook(hook))
219 | 
220 |     model_summary = OrderedDict()
221 |     hooks = []
222 |     model.apply(register_hook)
223 |     try:
224 |         model(x)
225 |     except Exception as e:
226 |         raise e
227 |     finally:
228 |         for h in hooks:
229 |             h.remove()
230 | 
231 |     summary_str = ''
232 |     summary_str += '-' * 80 + '\n'
233 |     line_new = "{:>20} {:>25} {:>15} {:>15}\n".format("Layer (type)", "Output Shape", "Params", "FLOPs(M+A) #")
234 |     summary_str += line_new
235 |     summary_str += '=' * 80 + '\n'
236 |     total_params = 0
237 |     trainable_params = 0
238 |     total_flops = 0
239 |     for layer in model_summary:
240 |         line_new = "{:>20} {:>25} {:>15} {:>15}\n".format(
241 |             layer,
242 |             str(model_summary[layer]['output_shape']),
243 |             model_summary[layer]['params'],
244 |             model_summary[layer]['flops'],
245 |         )
246 |         summary_str += line_new
247 |         total_params += model_summary[layer]['params']
248 |         trainable_params += model_summary[layer]['trainable_params']
249 |         total_flops += model_summary[layer]['flops']
250 | 
251 |     param_str = _flops_str(total_params)
252 |     flop_str = _flops_str(total_flops)
253 |     flop_str_m = _flops_str(total_flops // 2)
254 |     param_size = total_params * 4 / (1024**2)
255 |     if return_results:
256 |         return total_params, total_flops
257 | 
258 |     summary_str += '=' * 80 + '\n'
259 |     summary_str += '        Total parameters: {:,} {}\n'.format(total_params, param_str)
260 |     summary_str += '    Trainable parameters: {:,}\n'.format(trainable_params)
261 |     summary_str += 'Non-trainable parameters: {:,}\n'.format(total_params - trainable_params)
262 |     summary_str += 'Total flops(M)  : {:,} {}\n'.format(total_flops // 2, flop_str_m)
263 |     summary_str += 'Total flops(M+A): {:,} {}\n'.format(total_flops, flop_str)
264 |     summary_str += '-' * 80 + '\n'
265 |     summary_str += 'Parameters size (MB): {:.2f}'.format(param_size)
266 |     return summary_str
267 | 
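A hedged usage sketch for `summary`; torchvision's `resnet18` is used only as a convenient example model (batch size 1, as the attention/linear counters above require):

    import torch
    from torchvision.models import resnet18
    from torchtoolbox.tools.summary import summary

    model = resnet18()
    print(summary(model, torch.rand(1, 3, 224, 224)))
    # or: total_params, total_flops = summary(model, torch.rand(1, 3, 224, 224), return_results=True)
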
-------------------------------------------------------------------------------- /torchtoolbox/tools/tensor_transfer.py: --------------------------------------------------------------------------------
1 | """Intentionally left empty; placeholder module.
2 | """
3 | 
-------------------------------------------------------------------------------- /torchtoolbox/tools/utils.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author  : DevinYang(pistonyang@gmail.com)
3 | __all__ = [
4 |     'check_dir', 'to_list', 'to_value', 'remove_file', 'make_divisible', 'apply_ratio', 'to_numpy', 'get_list_index',
5 |     'get_value_from_dicts', 'seconds_to_time', 'encode_one_hot', 'decode_one_hot', 'list_step_slice', 'convert_module',
6 |     'check_twin', 'get_list_value'
7 | ]
8 | 
9 | import hashlib
10 | import json
11 | import os
12 | from typing import List, Tuple, Union
13 | 
14 | import numpy as np
15 | import torch
16 | 
17 | 
18 | def to_list(value):
19 |     if not isinstance(value, (list, tuple)):
20 |         value = [value]
21 |     return value
22 | 
23 | 
24 | def to_value(container: Union[List, Tuple], check_same=False):
25 |     if isinstance(container, (list, tuple)):
26 |         if check_same:
27 |             for bef, aft in zip(container[:-1], container[1:]):
28 |                 assert bef == aft
29 |             return container[0]
30 |         return container[0]
31 |     else:
32 |         return container
33 | 
34 | 
35 | def check_twin(value, length=2):
36 |     if not isinstance(value, (list, tuple)):
37 |         return [value for _ in range(length)]
38 |     else:
39 |         assert len(value) == length, f'length of {value} should be {length} but {len(value)}'
40 |         return value
41 | 
42 | 
43 | def check_dir(*path):
44 |     """Check whether dir(s) exist, and create them if not.
45 |     Args:
46 |         path: full path(s) to check.
47 |     """
48 |     for p in path:
49 |         os.makedirs(p, exist_ok=True)
50 | 
51 | 
52 | def remove_file(file_path: str, show_detail=False):
53 |     if not os.path.exists(file_path):
54 |         if show_detail:
55 |             print(f'File {file_path} does not exist.')
56 |         return
57 |     os.remove(file_path)
58 | 
59 | 
60 | def make_divisible(v: Union[int, float], divisible_by: int, min_value: Union[int, None] = None):
61 |     """
62 |     This function is taken from the original tf repo.
63 |     https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
64 |     """
65 |     if min_value is None:
66 |         min_value = divisible_by
67 |     new_v = max(min_value, int(v + divisible_by / 2) // divisible_by * divisible_by)
68 |     # Make sure that round down does not go down by more than 10%.
69 |     if new_v < 0.9 * v:
70 |         new_v += divisible_by
71 |     return new_v
72 | 
73 | 
74 | def apply_ratio(src: Union[List, int], ratio: float, **kwargs):
75 |     if isinstance(src, int):
76 |         src = [
77 |             src,
78 |         ]
79 |     elif isinstance(src, list):
80 |         pass
81 |     else:
82 |         raise NotImplementedError(f'{type(src)} of src is not supported.')
83 |     src = [make_divisible(s * ratio, **kwargs) for s in src]
84 |     if len(src) == 1:
85 |         return src[0]
86 |     else:
87 |         return src
88 | 
89 | 
90 | @torch.no_grad()
91 | def to_numpy(tensor):
92 |     if isinstance(tensor, np.ndarray):
93 |         return tensor
94 |     elif torch.is_tensor(tensor):
95 |         if tensor.get_device() == -1:  # cpu tensor
96 |             return tensor.numpy()
97 |         else:
98 |             return tensor.cpu().numpy()
99 |     elif isinstance(tensor, (list, tuple)):
100 |         return np.array(tensor)
101 |     else:
102 |         raise NotImplementedError(f"The type {type(tensor)} is not supported for converting to numpy."
103 |                                   " torch.Tensor, list and tuple are supported now.")
104 | 
105 | 
106 | def get_list_index(lst: Union[list, tuple], value):
107 |     """Get not only the first but all indices of a value in a list or tuple.
108 | 
109 |     Args:
110 |         lst (Union[list, tuple]): target list.
111 |         value (Any): value to look up.
112 | 
113 |     Returns:
114 |         list: result
115 |     """
116 |     return [i for i, v in enumerate(lst) if v == value]
117 | 
118 | 
119 | def get_list_value(lst: Union[list, tuple], inds):
120 |     """Get values from indices.
121 | 
122 |     Args:
123 |         lst (Union[list, tuple]): target list.
124 |         inds: indices of the values to gather.
125 | 
126 |     Returns:
127 |         list: result
128 | 
129 |     """
130 |     return [lst[i] for i in inds]
131 | 
132 | 
133 | def get_value_from_dicts(dicts, keys, post_process=None):
134 |     assert isinstance(dicts, (list, tuple, dict))
135 |     assert post_process in (None, 'max', 'min', 'mean')
136 |     keys = to_list(keys)
137 |     if isinstance(dicts, dict):
138 |         dicts = dicts.values()
139 |     value_list = [[value[key] for value in dicts if isinstance(value, dict)] for key in keys]
140 |     if post_process is not None:
141 |         if post_process == 'mean':
142 |             value_list = [np.mean(v) for v in value_list]
143 |         elif post_process == 'max':
144 |             value_list = [np.max(v) for v in value_list]
145 |         elif post_process == 'min':
146 |             value_list = [np.min(v) for v in value_list]
147 |         else:
148 |             raise NotImplementedError
149 |     return value_list
150 | 
151 | 
152 | def seconds_to_time(seconds: int):
153 |     m, s = divmod(seconds, 60)
154 |     h, m = divmod(m, 60)
155 |     return h, m, s
156 | 
157 | 
158 | def encode_one_hot(cls: int, num_classes: int):
159 |     # encode to a one-hot list.
160 |     assert isinstance(cls, int) and isinstance(num_classes, int)
161 |     assert -1 <= cls < num_classes
162 |     return [0 if cls != c else 1 for c in range(num_classes)]
163 | 
164 | 
165 | def decode_one_hot(one_hot_list):
166 |     assert isinstance(one_hot_list, (list, tuple))
167 |     num_classes = len(one_hot_list)
168 |     cls = [i for i, c in enumerate(one_hot_list) if c == 1]
169 |     assert len(cls) in (0, 1), "a one-hot list should have one or zero classes."
170 |     cls = -1 if len(cls) == 0 else cls[0]
171 |     return cls, num_classes
172 | 
173 | 
174 | def list_step_slice(lst: list, step: int = 1):
175 |     """Slice a list by step.
176 | 
177 |     Args:
178 |         lst (list, tuple): lst to slice.
179 |         step (int, optional): step. Defaults to 1.
180 | 
181 |     Yields:
182 |         [list]: sub list.
183 |     """
184 |     assert isinstance(lst, (list, tuple))
185 |     for i in range(0, len(lst), step):
186 |         yield lst[i:i + step]
187 | 
188 | 
189 | def convert_module(model, old_module, new_module, **kwargs):
190 |     for child_name, child in model.named_children():
191 |         if isinstance(child, old_module):
192 |             setattr(model, child_name, new_module(**kwargs))
193 |         else:
194 |             convert_module(child, old_module, new_module, **kwargs)
195 | 
196 | 
197 | def remove_module_from_checkpoint(cp_dict):
198 |     return {k.replace('module.', ''): v for k, v in cp_dict.items()}
199 | 
200 | 
201 | def get_md5(obj, trans_func=None):
202 |     """Get an object's md5; if the obj is not supported by `json.dumps`, please provide a trans_func.
203 | 
204 |     Args:
205 |         obj (object): obj to get the md5 of.
206 |         trans_func (function, optional): used to convert obj to str. Defaults to None.
207 | """ 208 | if trans_func is None: 209 | trans_func = json.dumps 210 | obj_str = trans_func(obj) 211 | hl = hashlib.md5() 212 | hl.update(obj_str.encode(encoding='utf-8')) 213 | return hl.hexdigest() 214 | -------------------------------------------------------------------------------- /torchtoolbox/transform/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : PistonYang(pistonyang@gmail.com) 3 | 4 | from .transforms import * 5 | from .autoaugment import * 6 | from .dynamic_transform import * 7 | -------------------------------------------------------------------------------- /torchtoolbox/transform/autoaugment.py: -------------------------------------------------------------------------------- 1 | __all__ = ['ImageNetPolicy', 'CIFAR10Policy', 'SVHNPolicy', 'RandAugment'] 2 | 3 | from .transforms import RandomChoice, Compose 4 | from PIL import Image, ImageEnhance, ImageOps 5 | import numpy as np 6 | import random 7 | 8 | 9 | def rotate_with_fill(img, magnitude): 10 | rot = img.convert("RGBA").rotate(magnitude) 11 | return Image.composite(rot, Image.new("RGBA", rot.size, (128, ) * 4), rot).convert(img.mode) 12 | 13 | 14 | trans_value = lambda maxval, minval, m: (float(m) / 30) * float(maxval - minval) + minval 15 | 16 | 17 | class SubPolicy(object): 18 | def __init__(self, p, magnitude=None, ranges=None): 19 | self.p = p 20 | if magnitude is not None and ranges is not None: 21 | self.magnitude = ranges[magnitude] 22 | 23 | def do_process(self, img): 24 | raise NotImplementedError 25 | 26 | def __call__(self, img, magnitude=None): 27 | if magnitude is not None: 28 | self.magnitude = magnitude 29 | if random.random() < self.p: 30 | img = self.do_process(img) 31 | return img 32 | 33 | 34 | class ShearX(SubPolicy): 35 | def __init__(self, p, magnitude=None, fillcolor=(128, 128, 128)): 36 | ranges = np.linspace(0, 0.3, 10) 37 | super(ShearX, self).__init__(p, magnitude, ranges) 38 | self.fillcolor = fillcolor 39 | 40 | def do_process(self, img): 41 | return img.transform(img.size, 42 | Image.AFFINE, (1, self.magnitude * random.choice([-1, 1]), 0, 0, 1, 0), 43 | Image.BICUBIC, 44 | fillcolor=self.fillcolor) 45 | 46 | 47 | class ShearY(SubPolicy): 48 | def __init__(self, p, magnitude=None, fillcolor=(128, 128, 128)): 49 | ranges = np.linspace(0, 0.3, 10) 50 | super(ShearY, self).__init__(p, magnitude, ranges) 51 | self.fillcolor = fillcolor 52 | 53 | def do_process(self, img): 54 | return img.transform(img.size, 55 | Image.AFFINE, (1, 0, 0, self.magnitude * random.choice([-1, 1]), 1, 0), 56 | Image.BICUBIC, 57 | fillcolor=self.fillcolor) 58 | 59 | 60 | class TranslateX(SubPolicy): 61 | def __init__(self, p, magnitude=None, fillcolor=(128, 128, 128)): 62 | ranges = np.linspace(0, 150 / 331, 10) 63 | super(TranslateX, self).__init__(p, magnitude, ranges) 64 | self.fillcolor = fillcolor 65 | 66 | def do_process(self, img): 67 | return img.transform(img.size, 68 | Image.AFFINE, (1, 0, self.magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0), 69 | fillcolor=self.fillcolor) 70 | 71 | 72 | class TranslateY(SubPolicy): 73 | def __init__(self, p, magnitude=None, fillcolor=(128, 128, 128)): 74 | ranges = np.linspace(0, 150 / 331, 10) 75 | super(TranslateY, self).__init__(p, magnitude, ranges) 76 | self.fillcolor = fillcolor 77 | 78 | def do_process(self, img): 79 | return img.transform(img.size, 80 | Image.AFFINE, (1, 0, 0, 0, 1, self.magnitude * img.size[1] * random.choice([-1, 1])), 81 | 
fillcolor=self.fillcolor) 82 | 83 | 84 | class Rotate(SubPolicy): 85 | def __init__(self, p, magnitude=None): 86 | ranges = np.linspace(0, 30, 10) 87 | super(Rotate, self).__init__(p, magnitude, ranges) 88 | 89 | def do_process(self, img): 90 | return rotate_with_fill(img, self.magnitude) 91 | 92 | 93 | class Color(SubPolicy): 94 | def __init__(self, p, magnitude=None): 95 | ranges = np.linspace(0.0, 0.9, 10) 96 | super(Color, self).__init__(p, magnitude, ranges) 97 | 98 | def do_process(self, img): 99 | return ImageEnhance.Color(img).enhance(1 + self.magnitude * random.choice([-1, 1])) 100 | 101 | 102 | class Posterize(SubPolicy): 103 | def __init__(self, p, magnitude=None): 104 | ranges = np.round(np.linspace(8, 4, 10), 0).astype(int) 105 | super(Posterize, self).__init__(p, magnitude, ranges) 106 | 107 | def do_process(self, img): 108 | return ImageOps.posterize(img, int(self.magnitude)) 109 | 110 | 111 | class Solarize(SubPolicy): 112 | def __init__(self, p, magnitude=None): 113 | ranges = np.linspace(256, 0, 10) 114 | super(Solarize, self).__init__(p, magnitude, ranges) 115 | 116 | def do_process(self, img): 117 | return ImageOps.solarize(img, self.magnitude) 118 | 119 | 120 | class SolarizeAdd(SubPolicy): 121 | def __init__(self, p, addition=None, threshold=128): 122 | super(SolarizeAdd, self).__init__(p, addition) 123 | self.threshold = threshold 124 | 125 | def do_process(self, img): 126 | img_np = np.array(img).astype(int) 127 | img_np = img_np + self.magnitude 128 | img_np = np.clip(img_np, 0, 255) 129 | img_np = img_np.astype(np.uint8) 130 | img = Image.fromarray(img_np) 131 | return ImageOps.solarize(img, self.threshold) 132 | 133 | 134 | class Contrast(SubPolicy): 135 | def __init__(self, p, magnitude=None): 136 | ranges = np.linspace(0.0, 0.9, 10) 137 | super(Contrast, self).__init__(p, magnitude, ranges) 138 | 139 | def do_process(self, img): 140 | return ImageEnhance.Contrast(img).enhance(1 + self.magnitude * random.choice([-1, 1])) 141 | 142 | 143 | class Sharpness(SubPolicy): 144 | def __init__(self, p, magnitude=None): 145 | ranges = np.linspace(0.0, 0.9, 10) 146 | super(Sharpness, self).__init__(p, magnitude, ranges) 147 | 148 | def do_process(self, img): 149 | return ImageEnhance.Sharpness(img).enhance(1 + self.magnitude * random.choice([-1, 1])) 150 | 151 | 152 | class Brightness(SubPolicy): 153 | def __init__(self, p, magnitude=None): 154 | ranges = np.linspace(0.0, 0.9, 10) 155 | super(Brightness, self).__init__(p, magnitude, ranges) 156 | 157 | def do_process(self, img): 158 | return ImageEnhance.Brightness(img).enhance(1 + self.magnitude * random.choice([-1, 1])) 159 | 160 | 161 | class AutoContrast(SubPolicy): 162 | def __init__(self, p): 163 | super(AutoContrast, self).__init__(p) 164 | 165 | def do_process(self, img): 166 | return ImageOps.autocontrast(img) 167 | 168 | 169 | class Equalize(SubPolicy): 170 | def __init__(self, p): 171 | super(Equalize, self).__init__(p) 172 | 173 | def do_process(self, img): 174 | return ImageOps.equalize(img) 175 | 176 | 177 | class Invert(SubPolicy): 178 | def __init__(self, p): 179 | super(Invert, self).__init__(p) 180 | 181 | def do_process(self, img): 182 | return ImageOps.invert(img) 183 | 184 | 185 | class Identity(SubPolicy): 186 | def __init__(self, p): 187 | super(Identity, self).__init__(1., ) 188 | 189 | def do_process(self, img): 190 | return img 191 | 192 | 193 | ImageNetPolicy = RandomChoice([ 194 | Compose([Posterize(0.4, 8), Rotate(0.6, 9)]), 195 | Compose([Solarize(0.6, 5), AutoContrast(0.6)]), 196 | 
Compose([Equalize(0.8), Equalize(0.6)]), 197 | Compose([Posterize(0.6, 7), Posterize(0.6, 6)]), 198 | Compose([Equalize(0.4), Solarize(0.2, 4)]), 199 | Compose([Equalize(0.4), Rotate(0.8, 8)]), 200 | Compose([Solarize(0.6, 3), Equalize(0.6)]), 201 | Compose([Posterize(0.8, 5), Equalize(1.0)]), 202 | Compose([Rotate(0.2, 3), Solarize(0.6, 8)]), 203 | Compose([Equalize(0.6), Posterize(0.4, 6)]), 204 | Compose([Rotate(0.8, 8), Color(0.4, 0)]), 205 | Compose([Rotate(0.4, 9), Equalize(0.6)]), 206 | Compose([Equalize(0.0), Equalize(0.8)]), 207 | Compose([Invert(0.6), Equalize(1.0)]), 208 | Compose([Color(0.6, 4), Contrast(1.0, 8)]), 209 | Compose([Rotate(0.8, 8), Color(1.0, 2)]), 210 | Compose([Color(0.8, 8), Solarize(0.8, 7)]), 211 | Compose([Sharpness(0.4, 7), Invert(0.6)]), 212 | Compose([ShearX(0.6, 5), Equalize(1.0)]), 213 | Compose([Color(0.4, 0), Equalize(0.6)]), 214 | Compose([Equalize(0.4), Solarize(0.2, 4)]), 215 | Compose([Solarize(0.6, 5), AutoContrast(0.6)]), 216 | Compose([Invert(0.6), Equalize(1.0)]), 217 | Compose([Color(0.6, 4), Contrast(1.0, 8)]), 218 | Compose([Equalize(0.8), Equalize(0.6)]) 219 | ]) 220 | 221 | CIFAR10Policy = RandomChoice([ 222 | Compose([Invert(0.1), Contrast(0.2, 6)]), 223 | Compose([Rotate(0.7, 2), TranslateX(0.3, 9)]), 224 | Compose([Sharpness(0.8, 1), Sharpness(0.9, 3)]), 225 | Compose([ShearY(0.5, 8), TranslateY(0.7, 9)]), 226 | Compose([AutoContrast(0.5), Equalize(0.9)]), 227 | Compose([ShearY(0.2, 7), Posterize(0.3, 7)]), 228 | Compose([Color(0.4, 3), Brightness(0.6, 7)]), 229 | Compose([Sharpness(0.3, 9), Brightness(0.7, 9)]), 230 | Compose([Equalize(0.6), Equalize(0.5)]), 231 | Compose([Contrast(0.6, 7), Sharpness(0.6, 5)]), 232 | Compose([Color(0.7, 7), TranslateX(0.5, 8)]), 233 | Compose([Equalize(0.3), AutoContrast(0.4)]), 234 | Compose([TranslateY(0.4, 3), Sharpness(0.2, 6)]), 235 | Compose([Brightness(0.9, 6), Color(0.2, 8)]), 236 | Compose([Solarize(0.5, 2), Invert(0.0)]), 237 | Compose([Equalize(0.2), AutoContrast(0.6)]), 238 | Compose([Equalize(0.2), Equalize(0.6)]), 239 | Compose([Color(0.9, 9), Equalize(0.6)]), 240 | Compose([AutoContrast(0.8), Solarize(0.2, 8)]), 241 | Compose([Brightness(0.1, 3), Color(0.7, 0)]), 242 | Compose([Solarize(0.4, 5), AutoContrast(0.9)]), 243 | Compose([TranslateY(0.9, 9), TranslateY(0.7, 9)]), 244 | Compose([AutoContrast(0.9), Solarize(0.8, 3)]), 245 | Compose([Equalize(0.8), Invert(0.1)]), 246 | Compose([TranslateY(0.7, 9), AutoContrast(0.9)]) 247 | ]) 248 | 249 | SVHNPolicy = RandomChoice([ 250 | Compose([ShearX(0.9, 4), Invert(0.2)]), 251 | Compose([ShearY(0.9, 8), Invert(0.7)]), 252 | Compose([Equalize(0.6), Solarize(0.6, 6)]), 253 | Compose([Invert(0.9), Equalize(0.6)]), 254 | Compose([Equalize(0.6), Rotate(0.9, 3)]), 255 | Compose([ShearX(0.9, 4), AutoContrast(0.8)]), 256 | Compose([ShearY(0.9, 8), Invert(0.4)]), 257 | Compose([ShearY(0.9, 5), Solarize(0.2, 6)]), 258 | Compose([Invert(0.9), AutoContrast(0.8)]), 259 | Compose([Equalize(0.6), Rotate(0.9, 3)]), 260 | Compose([ShearX(0.9, 4), Solarize(0.3, 3)]), 261 | Compose([ShearY(0.8, 8), Invert(0.7)]), 262 | Compose([Equalize(0.9), TranslateY(0.6, 6)]), 263 | Compose([Invert(0.9), Equalize(0.6)]), 264 | Compose([Contrast(0.3, 3), Rotate(0.8, 4)]), 265 | Compose([Invert(0.8), TranslateY(0.0, 2)]), 266 | Compose([ShearY(0.7, 6), Solarize(0.4, 8)]), 267 | Compose([Invert(0.6), Rotate(0.8, 4)]), 268 | Compose([ShearY(0.3, 7), TranslateX(0.9, 3)]), 269 | Compose([ShearX(0.1, 6), Invert(0.6)]), 270 | Compose([Solarize(0.7, 2), TranslateY(0.6, 7)]), 271 | 
Compose([ShearY(0.8, 4), Invert(0.8)]),
272 |     Compose([ShearX(0.7, 9), TranslateY(0.8, 3)]),
273 |     Compose([ShearY(0.8, 5), AutoContrast(0.7)]),
274 |     Compose([ShearX(0.7, 2), Invert(0.1)])
275 | ])
276 | 
277 | 
278 | class RandAugment(object):
279 |     def __init__(self, n, m, p=1.0):
280 |         self.n = n
281 |         self.m = m
282 |         self.p = p
283 | 
284 |         self.augment_list = [
285 |             (Identity(1), 0, 1),
286 |             (AutoContrast(1), 0, 1),
287 |             (Equalize(1), 0, 1),
288 |             # (Invert(1, m, True), 0, 1),
289 |             (Rotate(1), 0, 30),
290 |             (Posterize(1), 0, 4),
291 |             (Solarize(1), 0, 256),
292 |             (Color(1), 0.1, 1.9),
293 |             (Contrast(1), 0.1, 1.9),
294 |             (Brightness(1), 0.1, 1.9),
295 |             (Sharpness(1), 0.1, 1.9),
296 |             (ShearX(1), 0., 0.3),
297 |             (ShearY(1), 0., 0.3),
298 |             (TranslateX(1), 0., 0.33),
299 |             (TranslateY(1), 0., 0.33),
300 |             # (SolarizeAdd(1, m, True), 0, 110)
301 |         ]
302 | 
303 |     def __call__(self, img):
304 |         if self.p > random.random():
305 |             ops = random.choices(self.augment_list, k=self.n)
306 |             for op, minval, maxval in ops:
307 |                 val = trans_value(maxval, minval, self.m)
308 |                 img = op(img, val)
309 |         return img
310 | 
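A hedged sketch of `RandAugment` in a torchvision pipeline; the ops are PIL-based, so it must run before `ToTensor`, and the crop size is a placeholder:

    from torchvision import transforms
    from torchtoolbox.transform.autoaugment import RandAugment

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        RandAugment(n=2, m=9),   # 2 random ops per image, magnitude 9 on the 0-30 scale
        transforms.ToTensor(),
    ])
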
53 | """ 54 | return F.resize(img, self._active_size, self.interpolation) 55 | 56 | 57 | class DynamicCenterCrop(CenterCrop, DynamicSize): 58 | def __init__(self, size): 59 | CenterCrop.__init__(self, size) 60 | DynamicSize.__init__(self, self.size) 61 | 62 | def forward(self, img): 63 | """ 64 | Args: 65 | img (PIL Image or Tensor): Image to be cropped. 66 | 67 | Returns: 68 | PIL Image or Tensor: Cropped image. 69 | """ 70 | return F.center_crop(img, self._active_size) 71 | 72 | 73 | class DynamicSizeCompose(Compose): 74 | def __call__(self, img, size): 75 | for t in self.transforms: 76 | if hasattr(t, 'active_size') and size is not None: 77 | t.active_size = size 78 | img = t(img) 79 | return img 80 | -------------------------------------------------------------------------------- /torchtoolbox/transform/hybrid.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : DevinYang(pistonyang@gmail.com) 3 | # This file defind the HybridTransform to handle complex objects. 4 | 5 | import abc 6 | from PIL import Image 7 | from torch import nn 8 | from torchvision.transforms import Compose 9 | from torchvision.transforms import functional as pil_functional 10 | 11 | from . import functional as cv2_functional 12 | from ..objects import BBox 13 | from ..tools import check_twin 14 | 15 | PIL_INTER_MODE = {'NEAREST': Image.NEAREST, 'BILINEAR': Image.BILINEAR, 'BICUBIC': Image.BICUBIC} 16 | 17 | 18 | class HybridCompose(Compose): 19 | def __call__(self, img, obj): 20 | for t in self.transforms: 21 | img, obj = t(img, obj) 22 | return img, obj 23 | 24 | 25 | class HybridTransform(nn.Module, abc.ABC): 26 | def __init__(self, interpolation='BILINEAR', backend='cv2'): 27 | super().__init__() 28 | assert backend in ('cv2', 'pil'), 'Only support cv2 or pil backend.' 29 | self.interpolation = interpolation if backend == 'cv2' else PIL_INTER_MODE[interpolation] 30 | self.backend = backend 31 | 32 | def forward(self, img, obj): 33 | raise NotImplementedError 34 | 35 | 36 | class HybridResize(HybridTransform): 37 | def __init__(self, size, interpolation='BILINEAR', backend='cv2'): 38 | super().__init__(interpolation, backend) 39 | self.size = size 40 | 41 | def resize_bbox(self, bbox: BBox): 42 | pass 43 | 44 | def resize_mask(self, mask): 45 | raise NotImplementedError 46 | 47 | def forward(self, img, obj): 48 | obj = check_twin(obj) 49 | for o in obj: 50 | if isinstance(obj, BBox): 51 | o = self.resize_bbox(o) # TODO:This o won't be returned 52 | else: 53 | raise NotImplementedError(f"{type(o)} is not supported by HybridResize now.") 54 | 55 | if self.backend == 'cv2': 56 | img = cv2_functional.resize(img, self.size, self.interpolation) 57 | else: 58 | img = pil_functional.resize(img, self.size, self.interpolation) 59 | return img, obj 60 | 61 | class HybridCrop(HybridTransform) 62 | 63 | class HybridScale(HybridTransform) 64 | 65 | class HybridHorizontalFlip(HybridTransform) 66 | 67 | class HybridVerticalFlip(HybridTransform) 68 | --------------------------------------------------------------------------------