├── trident ├── optims │ ├── __init__.py │ ├── tensorflow_regularizers.py │ ├── pytorch_regularizers.py │ ├── metrics.py │ ├── pytorch_constraints.py │ ├── pytorch_profiling.py │ └── tensorflow_constraints.py ├── layers │ ├── __init__.py │ └── jax_layers.py ├── misc │ ├── __init__.py │ └── ipython_utils.py ├── reinforcement │ ├── __init__.py │ └── envs │ │ └── __init__.py ├── loggers │ ├── __init__.py │ ├── BaseLogger.py │ ├── logger.py │ ├── mlflow.py │ ├── history.py │ └── mlflow_logger.py ├── data │ ├── __init__.py │ ├── label_common.py │ ├── text_common.py │ ├── image_reader.py │ ├── preprocess_policy.py │ └── augment_policy.py ├── callbacks │ ├── __init__.py │ ├── saving_strategies.py │ └── dataflow_callbacks.py ├── __init__.py ├── models │ ├── coco_classes.txt │ ├── __init__.py │ ├── pretrained_utils.py │ ├── README.md │ ├── pytorch_arcfacenet.py │ ├── pytorch_deeplab.py │ ├── tensorflow_deeplab.py │ ├── tensorflow_mobilenet.py │ ├── pytorch_mobilenet.py │ ├── tensorflow_vgg.py │ └── tensorflow_resnet.py └── backend │ ├── __init__.py │ ├── dtype.py │ ├── decorators.py │ └── load_backend.py ├── images ├── cat.jpg ├── tensorboard.png ├── text_process.png ├── visualization.jpg ├── vision_transform.png ├── tensorboard_graph.png ├── tensorboard_images.png ├── tensorboard_scalar.png └── tensorboard_histogram.png ├── trident_logo.png ├── docs ├── images │ ├── shortcut.png │ ├── trident_logo.png │ ├── quickstart_ouput.png │ ├── quickstart_summary.png │ └── training_snapshot.png └── zh-tw │ ├── main.md │ ├── backend.md │ └── quickstart.md ├── sphinx_docs ├── source │ ├── modules.rst │ ├── trident.rst │ ├── trident.misc.rst │ ├── trident.loggers.rst │ ├── index.rst │ ├── trident.callbacks.rst │ ├── trident.backend.rst │ ├── trident.layers.rst │ ├── trident.data.rst │ ├── trident.models.rst │ ├── trident.optims.rst │ └── conf.py ├── Makefile └── make.bat ├── MANIFEST.in ├── requirements.txt ├── setup.cfg ├── LICENSE ├── LICENSE.txt ├── setup.py ├── .gitignore └── 
README.md /trident/optims/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /trident/layers/__init__.py: -------------------------------------------------------------------------------- 1 | """trident layers""" 2 | -------------------------------------------------------------------------------- /images/cat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AllanYiin/trident/HEAD/images/cat.jpg -------------------------------------------------------------------------------- /trident_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AllanYiin/trident/HEAD/trident_logo.png -------------------------------------------------------------------------------- /images/tensorboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AllanYiin/trident/HEAD/images/tensorboard.png -------------------------------------------------------------------------------- /docs/images/shortcut.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AllanYiin/trident/HEAD/docs/images/shortcut.png -------------------------------------------------------------------------------- /images/text_process.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AllanYiin/trident/HEAD/images/text_process.png -------------------------------------------------------------------------------- /images/visualization.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AllanYiin/trident/HEAD/images/visualization.jpg 
-------------------------------------------------------------------------------- /images/vision_transform.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AllanYiin/trident/HEAD/images/vision_transform.png -------------------------------------------------------------------------------- /docs/images/trident_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AllanYiin/trident/HEAD/docs/images/trident_logo.png -------------------------------------------------------------------------------- /images/tensorboard_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AllanYiin/trident/HEAD/images/tensorboard_graph.png -------------------------------------------------------------------------------- /images/tensorboard_images.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AllanYiin/trident/HEAD/images/tensorboard_images.png -------------------------------------------------------------------------------- /images/tensorboard_scalar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AllanYiin/trident/HEAD/images/tensorboard_scalar.png -------------------------------------------------------------------------------- /docs/images/quickstart_ouput.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AllanYiin/trident/HEAD/docs/images/quickstart_ouput.png -------------------------------------------------------------------------------- /docs/zh-tw/main.md: -------------------------------------------------------------------------------- 1 | ![trident](../images/trident_logo.png) 2 | Multiverses for Deep Learning Developers without Pitfall 
-------------------------------------------------------------------------------- /images/tensorboard_histogram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AllanYiin/trident/HEAD/images/tensorboard_histogram.png -------------------------------------------------------------------------------- /sphinx_docs/source/modules.rst: -------------------------------------------------------------------------------- 1 | trident 2 | ======= 3 | 4 | .. toctree:: 5 | :maxdepth: 4 6 | 7 | trident 8 | -------------------------------------------------------------------------------- /docs/images/quickstart_summary.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AllanYiin/trident/HEAD/docs/images/quickstart_summary.png -------------------------------------------------------------------------------- /docs/images/training_snapshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AllanYiin/trident/HEAD/docs/images/training_snapshot.png -------------------------------------------------------------------------------- /trident/misc/__init__.py: -------------------------------------------------------------------------------- 1 | from trident.misc.ipython_utils import * 2 | from trident.misc.visualization_utils import * 3 | -------------------------------------------------------------------------------- /docs/zh-tw/backend.md: -------------------------------------------------------------------------------- 1 | # Backend configuration and switching 2 | 3 | 4 | 5 | trident currently supports two backends, pytorch and tensorflow; the corresponding version requirements are as follows: 6 | 7 | 8 | 9 | ## Declaring the backend 10 | 11 | 12 | 13 | ## Datasets 14 | 15 | 16 | 17 | ## Pretrained models 18 | 19 | 20 | 21 | 22 | 23 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include
*.md 2 | 3 | include LICENSE 4 | recursive-exclude * __pycache__ 5 | recursive-exclude optimizers *.pth *.h5 6 | recursive-include models *.txt 7 | recursive-exclude * onnx_*.py 8 | recursive-exclude * graph_tools.py 9 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.18 2 | opencv-python 3 | six>=1.13.0 4 | scikit-image>=0.19.0 5 | scipy 6 | beautifulsoup4 7 | scikit-learn 8 | matplotlib>=3.0.2 9 | ipython>=7.10.2 10 | pillow>=4.1.1 11 | setuptools>=42.0.1 12 | onnx>=1.4.1 13 | tensorboard>=1.15.0 14 | dill>=0.3.1 15 | tqdm 16 | requests 17 | h5py -------------------------------------------------------------------------------- /trident/reinforcement/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from trident.backend.common import get_backend 6 | if get_backend()=='pytorch': 7 | from . import pytorch_policies as policies 8 | 9 | elif get_backend()=='tensorflow': 10 | from . import tensorflow_policies as policies 11 | 12 | from . import utils 13 | 14 | -------------------------------------------------------------------------------- /trident/loggers/__init__.py: -------------------------------------------------------------------------------- 1 | """trident loggers""" 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | from trident.backend.common import get_backend 7 | # if get_backend()=='pytorch': 8 | # from . import pytorch_tensorboard as tensorboard 9 | # 10 | # elif get_backend()=='tensorflow': 11 | # from . import tensorflow_tensorboard as tensorboard 12 | # 13 | # from . 
import mlflow_logger -------------------------------------------------------------------------------- /sphinx_docs/source/trident.rst: -------------------------------------------------------------------------------- 1 | trident package 2 | =============== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | trident.backend 11 | trident.callbacks 12 | trident.data 13 | trident.layers 14 | trident.loggers 15 | trident.misc 16 | trident.models 17 | trident.optims 18 | 19 | Module contents 20 | --------------- 21 | 22 | .. automodule:: trident 23 | :members: 24 | :undoc-members: 25 | :show-inheritance: 26 | -------------------------------------------------------------------------------- /sphinx_docs/source/trident.misc.rst: -------------------------------------------------------------------------------- 1 | trident.misc package 2 | ==================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | trident.misc.ipython\_utils module 8 | ---------------------------------- 9 | 10 | .. automodule:: trident.misc.ipython_utils 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | trident.misc.visualization\_utils module 16 | ---------------------------------------- 17 | 18 | .. automodule:: trident.misc.visualization_utils 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | 24 | Module contents 25 | --------------- 26 | 27 | .. automodule:: trident.misc 28 | :members: 29 | :undoc-members: 30 | :show-inheritance: 31 | -------------------------------------------------------------------------------- /sphinx_docs/source/trident.loggers.rst: -------------------------------------------------------------------------------- 1 | trident.loggers package 2 | ======================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | trident.loggers.BaseLogger module 8 | --------------------------------- 9 | 10 | .. 
automodule:: trident.loggers.BaseLogger 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | trident.loggers.pytorch\_tensorboard module 16 | ------------------------------------------- 17 | 18 | .. automodule:: trident.loggers.pytorch_tensorboard 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | 24 | Module contents 25 | --------------- 26 | 27 | .. automodule:: trident.loggers 28 | :members: 29 | :undoc-members: 30 | :show-inheritance: 31 | -------------------------------------------------------------------------------- /sphinx_docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /trident/data/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import threading 7 | 8 | #from trident.data.image_common import * 9 | 10 | from trident.data.utils import * 11 | from trident.data.samplers import * 12 | from trident.data.dataset import * 13 | from trident.data.data_provider import * 14 | 15 | 16 | from trident.data.preprocess_policy import * 17 | from trident.data.augment_policy import * 18 | 19 | from . import label_common 20 | from . import mask_common 21 | from . import bbox_common 22 | from . import text_common 23 | 24 | 25 | 26 | 27 | 28 | from trident.data.image_reader import ImageReader,ImageThread 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | # This includes the license file(s) in the wheel. 3 | # https://wheel.readthedocs.io/en/stable/user_guide.html#including-license-files-in-the-generated-wheel-file 4 | license_files = LICENSE 5 | 6 | [bdist_wheel] 7 | # This flag says to generate wheels that support both Python 2 and Python 8 | # 3. If your code will not run unchanged on both Python 2 and 3, you will 9 | # need to generate separate wheels for each Python version that you 10 | # support. Removing this line (or setting universal to 0) will prevent 11 | # bdist_wheel from trying to make a universal wheel. 
For more see: 12 | # https://packaging.python.org/guides/distributing-packages-using-setuptools/#wheels 13 | universal=0 14 | 15 | [flake8] 16 | max-line-length = 179 17 | exclude =.git,__pycache__,build,dist,examples,tests -------------------------------------------------------------------------------- /trident/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | """trident callbacks""" 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | 7 | from trident.callbacks.callback_base import * 8 | from trident.callbacks.lr_schedulers import AdjustLRCallback,ReduceLROnPlateau,reduce_lr_on_plateau,lambda_lr,LambdaLR,RandomCosineLR,random_cosine_lr,CosineLR,cosine_lr,StepLR,PolyLR 9 | # from trident.callbacks.saving_strategies import * 10 | from trident.callbacks.visualization_callbacks import * 11 | # 12 | from trident.callbacks.regularization_callbacks import RegularizationCallbacksBase, MixupCallback, CutMixCallback,GradientClippingCallback 13 | # from trident.callbacks.data_flow_callbacks import DataProcessCallback 14 | # from trident.callbacks import gan_callbacks 15 | # 16 | -------------------------------------------------------------------------------- /trident/__init__.py: -------------------------------------------------------------------------------- 1 | """trident api""" 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | import sys 6 | import os 7 | from importlib import reload 8 | from sys import stderr 9 | 10 | defaultencoding = 'utf-8' 11 | if sys.getdefaultencoding() != defaultencoding: 12 | reload(sys) 13 | sys.setdefaultencoding(defaultencoding) 14 | 15 | PACKAGE_ROOT = os.path.dirname(__file__) 16 | PROJECT_ROOT = os.path.dirname(PACKAGE_ROOT) 17 | 18 | __version__ = '0.7.9' 19 | stderr.write('trident {0}\n'.format(__version__)) 20 | 21 | from trident 
import context 22 | from trident import readable_errors 23 | from trident.backend import * 24 | import threading 25 | import random 26 | import cv2 27 | import glob 28 | from tqdm import tqdm 29 | import numpy as np 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /sphinx_docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. trident documentation master file, created by 2 | sphinx-quickstart on Sat May 9 14:04:30 2020. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to trident's documentation! 7 | =================================== 8 | 9 | .. toctree:: 10 | :maxdepth: 2 11 | :caption: Contents: 12 | 13 | .. image:: _static/trident_logo.png 14 | :width: 500px 15 | :scale: 50 % 16 | :alt: alternate text 17 | :align: left 18 | 19 | #. *Classify cancer using simulated data (Logistic Regression)* 20 | Quick Start:`Logistic Regression `_ with NumPy (:tridenttw:`source `) 21 | 22 | 23 | 24 | Indices and tables 25 | ================== 26 | 27 | * :ref:`genindex` 28 | * :ref:`modindex` 29 | * :ref:`search` 30 | -------------------------------------------------------------------------------- /sphinx_docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 
22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /trident/models/coco_classes.txt: -------------------------------------------------------------------------------- 1 | person 2 | bicycle 3 | car 4 | motorbike 5 | aeroplane 6 | bus 7 | train 8 | truck 9 | boat 10 | traffic_light 11 | fire hydrant 12 | stop sign 13 | parking meter 14 | bench 15 | bird 16 | cat 17 | dog 18 | horse 19 | sheep 20 | cow 21 | elephant 22 | bear 23 | zebra 24 | giraffe 25 | backpack 26 | umbrella 27 | handbag 28 | tie 29 | suitcase 30 | frisbee 31 | skis 32 | snowboard 33 | sports ball 34 | kite 35 | baseball bat 36 | baseball glove 37 | skateboard 38 | surfboard 39 | tennis racket 40 | bottle 41 | wine glass 42 | cup 43 | fork 44 | knife 45 | spoon 46 | bowl 47 | banana 48 | apple 49 | sandwich 50 | orange 51 | broccoli 52 | carrot 53 | hot dog 54 | pizza 55 | donut 56 | cake 57 | chair 58 | sofa 59 | pottedplant 60 | bed 61 | diningtable 62 | toilet 63 | tvmonitor 64 | laptop 65 | mouse 66 | remote 67 | keyboard 68 | cell phone 69 | microwave 70 | oven 71 | toaster 72 | sink 73 | refrigerator 74 | book 75 | clock 76 | vase 77 | scissors 78 | teddy bear 79 | hair drier 80 | toothbrush -------------------------------------------------------------------------------- /trident/misc/ipython_utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import gc 3 | import os 4 | import sys 5 | import traceback 6 | import warnings 7 | 8 | 9 | 10 | def is_in_ipython(): 11 | "Is the code running in the ipython environment (jupyter including)" 12 | program_name = os.path.basename(os.getenv('_', 
'')) 13 | 14 | if ('jupyter-notebook' in program_name or # jupyter-notebook 15 | 'ipython' in program_name or # ipython 16 | 'jupyter' in program_name or # jupyter 17 | 'JPY_PARENT_PID' in os.environ): # ipython-notebook 18 | return True 19 | else: 20 | return False 21 | IS_IN_IPYTHON = is_in_ipython() 22 | 23 | 24 | def is_in_colab(): 25 | if not is_in_ipython(): return False 26 | try: 27 | from google import colab 28 | return True 29 | except ImportError: 30 | return False 31 | IS_IN_COLAB = is_in_colab() 32 | 33 | def is_in_kaggle_kernel(): 34 | if 'kaggle' in os.environ.get('PYTHONPATH',''): 35 | return True 36 | else: 37 | return False 38 | 39 | 40 | 41 | 42 | 43 | 44 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 AllanYiin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /trident/loggers/BaseLogger.py: -------------------------------------------------------------------------------- 1 | 2 | class BaseLogger(object): 3 | """ 4 | Base logger handler. See implementations: TensorboardLogger, VisdomLogger, PolyaxonLogger, MLflowLogger, ... 5 | 6 | """ 7 | def log_metrics(self, metrics, step): 8 | 9 | """Record metrics. 10 | :param dict metrics: Dictionary with metric names as keys and measured quantities as values 11 | :param int|None step: Step number at which the metrics should be recorded 12 | """ 13 | raise NotImplementedError() 14 | 15 | def log_aggregate(self, agg, step): 16 | 17 | """Record aggregated metrics. 18 | :param dict agg: Dictionary with metric names as keys and measured quantities as values 19 | :param int|None step: Step number at which the metrics should be recorded 20 | """ 21 | raise NotImplementedError() 22 | 23 | 24 | 25 | def save(self): 26 | 27 | """Save log data.""" 28 | pass 29 | 30 | def __enter__(self): 31 | return self 32 | 33 | def __exit__(self, type, value, traceback): 34 | self.close() 35 | 36 | def close(self): 37 | pass 38 | 39 | 40 | 41 | 42 | 43 | pass 44 | 45 | -------------------------------------------------------------------------------- /trident/loggers/logger.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import operator 3 | from abc import ABC, abstractmethod 4 | 5 | class BaseLogger(ABC): 6 | """ 7 | Base logger handler. See implementations: TensorboardLogger, VisdomLogger, PolyaxonLogger, MLflowLogger, ... 8 | 9 | """ 10 | 11 | def log_metrics(self, metric_name,value, step): 12 | 13 | """Record metrics.
14 | :param str metric_name: Name of the metric to record; value is the measured quantity 15 | :param int|None step: Step number at which the metrics should be recorded 16 | """ 17 | raise NotImplementedError() 18 | 19 | def log_aggregate(self, agg, step): 20 | 21 | """Record aggregated metrics. 22 | :param dict agg: Dictionary with metric names as keys and measured quantities as values 23 | :param int|None step: Step number at which the metrics should be recorded 24 | """ 25 | raise NotImplementedError() 26 | 27 | 28 | 29 | def save(self): 30 | 31 | """Save log data.""" 32 | pass 33 | 34 | def __enter__(self): 35 | return self 36 | 37 | def __exit__(self, type, value, traceback): 38 | self.close() 39 | 40 | def close(self): 41 | pass 42 | 43 | 44 | 45 | 46 | 47 | pass 48 | 49 | -------------------------------------------------------------------------------- /trident/loggers/mlflow.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import operator 3 | from abc import ABC, abstractmethod 4 | from trident.loggers.logger import BaseLogger 5 | 6 | class MLFlowLogger(BaseLogger): 7 | """ 8 | MLflow logger handler (stub implementation of BaseLogger). 9 | 10 | """ 11 | 12 | def log_metrics(self, metrics, step): 13 | 14 | """Record metrics. 15 | :param dict metrics: Dictionary with metric names as keys and measured quantities as values 16 | :param int|None step: Step number at which the metrics should be recorded 17 | """ 18 | raise NotImplementedError() 19 | 20 | def log_aggregate(self, agg, step): 21 | 22 | """Record aggregated metrics.
23 | :param dict agg: Dictionary with metric names as keys and measured quantities as values 24 | :param int|None step: Step number at which the metrics should be recorded 25 | """ 26 | raise NotImplementedError() 27 | 28 | 29 | 30 | def save(self): 31 | 32 | """Save log data.""" 33 | pass 34 | 35 | def __enter__(self): 36 | return self 37 | 38 | def __exit__(self, type, value, traceback): 39 | self.close() 40 | 41 | def close(self): 42 | pass 43 | 44 | 45 | 46 | 47 | 48 | pass 49 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | COPYRIGHT 2 | All contributions by Allan Yiin: 3 | Copyright (c) 2019 - 2020, Allan Yiin. 4 | All rights reserved. 5 | 6 | Each contributor holds copyright over their respective contributions. 7 | The project versioning (Git) records all such contribution source information. 8 | 9 | LICENSE 10 | The MIT License (MIT) 11 | 12 | Permission is hereby granted, free of charge, to any person obtaining a copy 13 | of this software and associated documentation files (the "Software"), to deal 14 | in the Software without restriction, including without limitation the rights 15 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 16 | copies of the Software, and to permit persons to whom the Software is 17 | furnished to do so, subject to the following conditions: 18 | 19 | The above copyright notice and this permission notice shall be included in all 20 | copies or substantial portions of the Software. 21 | 22 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 23 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 24 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 25 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 26 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 27 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 28 | SOFTWARE. -------------------------------------------------------------------------------- /sphinx_docs/source/trident.callbacks.rst: -------------------------------------------------------------------------------- 1 | trident.callbacks package 2 | ========================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | trident.callbacks.data\_flow\_callbacks module 8 | ---------------------------------------------- 9 | 10 | .. automodule:: trident.callbacks.data_flow_callbacks 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | trident.callbacks.gan\_callbacks module 16 | --------------------------------------- 17 | 18 | .. automodule:: trident.callbacks.gan_callbacks 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | trident.callbacks.lr\_schedulers module 24 | --------------------------------------- 25 | 26 | .. automodule:: trident.callbacks.lr_schedulers 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | trident.callbacks.regularization\_callbacks module 32 | -------------------------------------------------- 33 | 34 | .. automodule:: trident.callbacks.regularization_callbacks 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | trident.callbacks.saving\_strategies module 40 | ------------------------------------------- 41 | 42 | .. automodule:: trident.callbacks.saving_strategies 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | trident.callbacks.visualization\_callbacks module 48 | ------------------------------------------------- 49 | 50 | .. 
automodule:: trident.callbacks.visualization_callbacks 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | 56 | Module contents 57 | --------------- 58 | 59 | .. automodule:: trident.callbacks 60 | :members: 61 | :undoc-members: 62 | :show-inheritance: 63 | -------------------------------------------------------------------------------- /trident/models/__init__.py: -------------------------------------------------------------------------------- 1 | """trident models""" 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | from trident.backend.common import get_backend,compile_and_install_module 7 | 8 | if get_backend()=='pytorch': 9 | 10 | from . import pytorch_efficientnet as efficientnet 11 | from . import pytorch_resnet as resnet 12 | from . import pytorch_vgg as vgg 13 | from . import pytorch_deeplab as deeplab 14 | from . import pytorch_senet as senet 15 | from . import pytorch_densenet as densenet 16 | from . import pytorch_bisenet as bisenet 17 | 18 | #from . import pytorch_efficientnetv2 as efficientnet_v2 19 | from . import pytorch_mobilenet as mobilenet 20 | from . import pytorch_yolo as yolo 21 | from . import pytorch_arcfacenet as arcfacenet 22 | from . import pytorch_mtcnn as mtcnn 23 | from . import pytorch_rfbnet as rfbnet 24 | from . import pytorch_ssd as ssd 25 | 26 | from . import pytorch_embedded as embedded 27 | from . import pytorch_inception as inception 28 | from . import pytorch_visual_transformer as visual_transformer 29 | elif get_backend()=='tensorflow': 30 | from . import tensorflow_efficientnet as efficientnet 31 | from . import tensorflow_resnet as resnet 32 | from . import tensorflow_vgg as vgg 33 | 34 | from . import tensorflow_densenet as densenet 35 | from . import tensorflow_mobilenet as mobilenet 36 | from . import tensorflow_deeplab as deeplab 37 | from . import tensorflow_mtcnn as mtcnn 38 | from . 
import tensorflow_arcfacenet as arcfacenet 39 | 40 | #__all__ = ['vgg','resnet','densenet','efficientnet','mobilenet','gan','deeplab','arcfacenet','mtcnn','rfbnet','ssd','yolo'] 41 | 42 | 43 | -------------------------------------------------------------------------------- /trident/callbacks/saving_strategies.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | import os 5 | import math 6 | import warnings 7 | 8 | import numpy as np 9 | 10 | from trident.backend.common import * 11 | from trident.backend.common import get_backend 12 | from trident.callbacks.callback_base import * 13 | from trident.context import split_path, make_dir_if_need, sanitize_path 14 | if get_backend()=='pytorch': 15 | from trident.backend.pytorch_ops import to_numpy,to_tensor 16 | elif get_backend()=='tensorflow': 17 | from trident.backend.tensorflow_ops import to_numpy,to_tensor 18 | elif get_backend()=='jax': 19 | from trident.backend.jax_ops import to_numpy,to_tensor 20 | 21 | __all__ = ['SavingStrategyCallback','CyclicSavingStrategyCallback'] 22 | 23 | class SavingStrategyCallback(CallbackBase): 24 | def __init__(self, **kwargs): 25 | super(SavingStrategyCallback, self).__init__() 26 | 27 | def on_model_saving_start(self, training_context): 28 | pass 29 | 30 | def on_model_saving_end(self, training_context): 31 | pass 32 | 33 | 34 | 35 | class CyclicSavingStrategyCallback(SavingStrategyCallback): 36 | def __init__(self, repeat_period=5,**kwargs): 37 | super(CyclicSavingStrategyCallback, self).__init__() 38 | self.repeat_period=repeat_period 39 | self.counter=0 40 | self.origin_save_path =None 41 | 42 | 43 | def on_model_saving_start(self, training_context): 44 | if 'save_path' in training_context and self.origin_save_path is None: 45 | self.origin_save_path = training_context['save_path'] 46 | 
folder,filename,ext=split_path(self.origin_save_path) 47 | training_context['save_path']=os.path.join(folder,filename+'_{0}'.format(self.counter)+ext) 48 | self.counter=(self.counter+1)%self.repeat_period 49 | 50 | def on_model_saving_end(self, training_context): 51 | training_context['save_path']= self.origin_save_path 52 | self.origin_save_path =None 53 | 54 | -------------------------------------------------------------------------------- /sphinx_docs/source/trident.backend.rst: -------------------------------------------------------------------------------- 1 | trident.backend package 2 | ======================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | trident.backend.common module 8 | ----------------------------- 9 | 10 | .. automodule:: trident.backend.common 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | 16 | 17 | trident.backend.iteration\_tools module 18 | --------------------------------------- 19 | 20 | .. automodule:: trident.backend.iteration_tools 21 | :members: 22 | :undoc-members: 23 | :show-inheritance: 24 | 25 | trident.backend.load\_backend module 26 | ------------------------------------ 27 | 28 | .. automodule:: trident.backend.load_backend 29 | :members: 30 | :undoc-members: 31 | :show-inheritance: 32 | 33 | trident.backend.opencv\_backend module 34 | -------------------------------------- 35 | 36 | .. automodule:: trident.backend.opencv_backend 37 | :members: 38 | :undoc-members: 39 | :show-inheritance: 40 | 41 | 42 | 43 | trident.backend.pillow\_backend module 44 | -------------------------------------- 45 | 46 | .. automodule:: trident.backend.pillow_backend 47 | :members: 48 | :undoc-members: 49 | :show-inheritance: 50 | 51 | trident.backend.pytorch\_backend module 52 | --------------------------------------- 53 | 54 | .. automodule:: trident.backend.pytorch_backend 55 | :members: 56 | :undoc-members: 57 | :show-inheritance: 58 | 59 | trident.backend.pytorch\_ops module 60 | ----------------------------------- 61 | 62 | ..
automodule:: trident.backend.pytorch_ops 63 | :members: 64 | :undoc-members: 65 | :show-inheritance: 66 | 67 | trident.backend.tensorflow\_backend module 68 | ------------------------------------------ 69 | 70 | .. automodule:: trident.backend.tensorflow_backend 71 | :members: 72 | :undoc-members: 73 | :show-inheritance: 74 | 75 | trident.backend.tensorflow\_ops module 76 | -------------------------------------- 77 | 78 | .. automodule:: trident.backend.tensorflow_ops 79 | :members: 80 | :undoc-members: 81 | :show-inheritance: 82 | 83 | trident.backend.tensorflow\_serialization module 84 | ------------------------------------------------ 85 | 86 | .. automodule:: trident.backend.tensorflow_serialization 87 | :members: 88 | :undoc-members: 89 | :show-inheritance: 90 | 91 | 92 | Module contents 93 | --------------- 94 | 95 | .. automodule:: trident.backend 96 | :members: 97 | :undoc-members: 98 | :show-inheritance: 99 | -------------------------------------------------------------------------------- /trident/optims/tensorflow_regularizers.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | import tensorflow as tf 5 | 6 | from trident.backend.common import get_session, get_function 7 | from trident.backend.tensorflow_backend import Layer 8 | from trident.backend.tensorflow_ops import * 9 | __all__ = ['l1_reg','l2_reg','orth_reg','get_reg','total_variation_norm_reg'] 10 | 11 | 12 | _session=get_session() 13 | _epsilon=_session.epsilon 14 | 15 | 16 | def l1_reg(model:Layer,reg_weight=1e-6): 17 | loss=0.0 18 | for name, param in model.named_parameters(): 19 | if param.trainable: 20 | param_ = where(is_abnormal_number(param.value()), zeros_like(param), param.value()) 21 | if 'bias' not in name : 22 | loss = loss + (reg_weight * reduce_sum(abs(param_))) 23 | return loss 24 | 25 | 26 | def l2_reg(model:Layer 
,reg_weight=1e-6): 27 | loss = 0.0 28 | for name, param in model.named_parameters(): 29 | 30 | if param.trainable: 31 | param_ = where(is_abnormal_number(param.value()), zeros_like(param), param.value()) 32 | if 'bias' not in name : 33 | loss = loss + reg_weight *reduce_sum(square(param_)) 34 | return loss 35 | 36 | 37 | def orth_reg(model:tf.Module,reg_weight=1e-6): 38 | loss = 0.0 39 | for param in model.trainable_weights: 40 | if param.trainable: 41 | if not any_abnormal_number(param): 42 | param_flat =tf.reshape(param,(param.int_shape[0], -1)) 43 | sym =tf.linalg.matmul(param_flat, tf.transpose(param_flat)) 44 | sym -= tf.eye(param_flat.int_shape[0]) 45 | loss = loss +reduce_sum(reg_weight * abs(sym)) 46 | return loss 47 | 48 | def total_variation_norm_reg(output:tf.Tensor,reg_weight=1e-6): 49 | diff_i = tf.math.reduce_sum(tf.math.pow(output[:, :, 1:,:] - output[:, :, :-1,:], 2)) 50 | diff_j = tf.math.reduce_sum(tf.math.pow(output[:, 1:, :,:] - output[:, :-1, :,:], 2)) 51 | tv_loss = (diff_i + diff_j) 52 | loss = reg_weight * tv_loss 53 | return loss 54 | 55 | 56 | def get_reg(reg_name): 57 | if reg_name is None: 58 | return None 59 | if '_reg' not in reg_name: 60 | reg_name=reg_name+'_reg' 61 | reg_modules = ['trident.optims.tensorflow_regularizers'] 62 | reg_fn = get_function(reg_name, reg_modules) 63 | return reg_fn -------------------------------------------------------------------------------- /sphinx_docs/source/trident.layers.rst: -------------------------------------------------------------------------------- 1 | trident.layers package 2 | ====================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | trident.layers.pytorch\_activations module 8 | ------------------------------------------ 9 | 10 | .. automodule:: trident.layers.pytorch_activations 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | trident.layers.pytorch\_blocks module 16 | ------------------------------------- 17 | 18 | ..
automodule:: trident.layers.pytorch_blocks 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | trident.layers.pytorch\_layers module 24 | ------------------------------------- 25 | 26 | .. automodule:: trident.layers.pytorch_layers 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | trident.layers.pytorch\_normalizations module 32 | --------------------------------------------- 33 | 34 | .. automodule:: trident.layers.pytorch_normalizations 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | trident.layers.pytorch\_pooling module 40 | -------------------------------------- 41 | 42 | .. automodule:: trident.layers.pytorch_pooling 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | trident.layers.tensorflow\_activations module 48 | --------------------------------------------- 49 | 50 | .. automodule:: trident.layers.tensorflow_activations 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | trident.layers.tensorflow\_blocks module 56 | ---------------------------------------- 57 | 58 | .. automodule:: trident.layers.tensorflow_blocks 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | trident.layers.tensorflow\_layers module 64 | ---------------------------------------- 65 | 66 | .. automodule:: trident.layers.tensorflow_layers 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | 71 | trident.layers.tensorflow\_normalizations module 72 | ------------------------------------------------ 73 | 74 | .. automodule:: trident.layers.tensorflow_normalizations 75 | :members: 76 | :undoc-members: 77 | :show-inheritance: 78 | 79 | trident.layers.tensorflow\_pooling module 80 | ----------------------------------------- 81 | 82 | .. automodule:: trident.layers.tensorflow_pooling 83 | :members: 84 | :undoc-members: 85 | :show-inheritance: 86 | 87 | 88 | Module contents 89 | --------------- 90 | 91 | .. 
automodule:: trident.layers 92 | :members: 93 | :undoc-members: 94 | :show-inheritance: 95 | -------------------------------------------------------------------------------- /trident/callbacks/dataflow_callbacks.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import math 6 | import warnings 7 | 8 | import numpy as np 9 | 10 | from trident.backend.common import * 11 | from trident.backend.common import get_backend 12 | from trident.callbacks.callback_base import CallbackBase,_valid_when 13 | from trident.data.image_common import * 14 | from trident.context import split_path, make_dir_if_need, sanitize_path 15 | if get_backend()=='pytorch': 16 | import torch 17 | import torch.nn as nn 18 | from trident.backend.pytorch_ops import to_numpy,to_tensor 19 | 20 | elif get_backend()=='tensorflow': 21 | from trident.backend.tensorflow_ops import to_numpy,to_tensor 22 | elif get_backend()=='jax': 23 | from trident.backend.jax_ops import to_numpy,to_tensor 24 | 25 | 26 | __all__ = ['DataProcessCallback'] 27 | 28 | class DataProcessCallback(CallbackBase): 29 | def __init__(self, when='on_data_received',policy=None, **kwargs): 30 | super(DataProcessCallback, self).__init__() 31 | if when in _valid_when: 32 | self.when = when 33 | else: 34 | raise ValueError("{0} is not a valid event trigger.".format(when)) 35 | self.policy = policy 36 | 37 | def on_batch_start(self, training_context): 38 | try: 39 | train_data = training_context['train_data'] 40 | test_data = training_context['test_data'] 41 | 42 | 43 | input = train_data[train_data.key_list[0]] 44 | new_input = [] 45 | for i in range(input.shape[0]): 46 | try: 47 | new_input.append(self.policy(input[i])) 48 | except: 49 | new_input.append(input[i]) 50 | 51 | new_input = np.array(new_input).astype(np.float32) 52 | 53 | train_data[train_data.key_list[0]]
= new_input 54 | 55 | if test_data is not None and len(test_data) > 0: 56 | input = test_data[test_data.key_list[0]] 57 | new_input = [] 58 | for i in range(input.shape[0]): 59 | try: 60 | new_input.append(self.policy(input[i])) 61 | except: 62 | new_input.append(input[i]) 63 | 64 | new_input = np.array(new_input).astype(np.float32) 65 | 66 | test_data[test_data.key_list[0]] = new_input 67 | except Exception as e: 68 | print(e) 69 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import pathlib 3 | import os 4 | import pkg_resources 5 | from setuptools import setup, find_packages 6 | 7 | # with open("README.md", "r",encoding='utf-8-sig') as fh: 8 | # long_description = fh.read() 9 | 10 | 11 | 12 | NAME = "tridentx" 13 | DIR = '.' 14 | 15 | PACKAGES = find_packages(exclude= ["tests","tests.*","sphinx_docs","sphinx_docs.*", "examples","examples.*","internal_tool","internal_tool.*"]) 16 | print(PACKAGES) 17 | 18 | 19 | 20 | 21 | with pathlib.Path('requirements.txt').open() as requirements_txt: 22 | install_requires = [ 23 | str(requirement) 24 | for requirement 25 | in pkg_resources.parse_requirements(requirements_txt) 26 | ] 27 | 28 | 29 | 30 | 31 | setup(name=NAME, 32 | version='0.7.9', 33 | description='Make pytorch and tensorflow two become one.', 34 | # long_description=long_description, 35 | # long_description_content_type="text/markdown", 36 | long_description=open("README.md", encoding="utf-8").read(), 37 | long_description_content_type="text/markdown", 38 | author='Allan Yiin', 39 | author_email='allanyiin.ai@gmail.com', 40 | download_url='https://test.pypi.org/project/tridentx', 41 | license='MIT', 42 | install_requires=install_requires, 43 | extras_require={ 44 | 'visualize': ['pydot>=1.2.4',], 45 | 'tests': ['pytest', 46 | 'pytest-pep8', 47 | 'pytest-xdist', 48 | 'flaky', 49 | 'pytest-cov', 50 | 'requests', 51 
| 'markdown'], 52 | }, 53 | classifiers=[ 54 | 'Development Status :: 4 - Beta', 55 | 'Intended Audience :: Developers', 56 | 'Intended Audience :: Education', 57 | 'Intended Audience :: Science/Research', 58 | 'License :: OSI Approved :: MIT License', 59 | 'Programming Language :: Python :: 3.5', 60 | 'Programming Language :: Python :: 3.6', 61 | 'Programming Language :: Python :: 3.7', 62 | 'Programming Language :: Python :: 3.8', 63 | 'Programming Language :: Python :: 3.9', 64 | 'Topic :: Software Development :: Libraries', 65 | 'Topic :: Software Development :: Libraries :: Python Modules' 66 | ], 67 | python_requires='>=3.5', 68 | keywords=['deep learning', 'machine learning', 'pytorch', 'tensorflow', 'AI'], 69 | packages= find_packages(exclude= ["tests","tests.*","sphinx_docs","sphinx_docs.*", "examples","examples.*","internal_tool","internal_tool.*"]), 70 | package_data={ 71 | 'trident': ['data/*.txt','models/*.txt'], 72 | }, 73 | include_package_data=True, 74 | scripts=[], 75 | 76 | ) 77 | 78 | -------------------------------------------------------------------------------- /sphinx_docs/source/trident.data.rst: -------------------------------------------------------------------------------- 1 | trident.data package 2 | ==================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | trident.data.augment\_policy module 8 | ----------------------------------- 9 | 10 | .. automodule:: trident.data.augment_policy 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | trident.data.bbox module 16 | ------------------------ 17 | 18 | .. automodule:: trident.data.bbox 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | trident.data.bbox\_common module 24 | -------------------------------- 25 | 26 | .. automodule:: trident.data.bbox_common 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | trident.data.data\_loaders module 32 | --------------------------------- 33 | 34 | .. 
automodule:: trident.data.data_loaders 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | trident.data.data\_provider module 40 | ---------------------------------- 41 | 42 | .. automodule:: trident.data.data_provider 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | trident.data.dataset module 48 | --------------------------- 49 | 50 | .. automodule:: trident.data.dataset 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | trident.data.image\_common module 56 | --------------------------------- 57 | 58 | .. automodule:: trident.data.image_common 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | trident.data.image\_reader module 64 | --------------------------------- 65 | 66 | .. automodule:: trident.data.image_reader 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | 71 | trident.data.label\_common module 72 | --------------------------------- 73 | 74 | .. automodule:: trident.data.label_common 75 | :members: 76 | :undoc-members: 77 | :show-inheritance: 78 | 79 | trident.data.mask\_common module 80 | -------------------------------- 81 | 82 | .. automodule:: trident.data.mask_common 83 | :members: 84 | :undoc-members: 85 | :show-inheritance: 86 | 87 | trident.data.preprocess\_policy module 88 | -------------------------------------- 89 | 90 | .. automodule:: trident.data.preprocess_policy 91 | :members: 92 | :undoc-members: 93 | :show-inheritance: 94 | 95 | trident.data.samplers module 96 | ---------------------------- 97 | 98 | .. automodule:: trident.data.samplers 99 | :members: 100 | :undoc-members: 101 | :show-inheritance: 102 | 103 | trident.data.utils module 104 | ------------------------- 105 | 106 | .. automodule:: trident.data.utils 107 | :members: 108 | :undoc-members: 109 | :show-inheritance: 110 | 111 | 112 | Module contents 113 | --------------- 114 | 115 | .. 
automodule:: trident.data 116 | :members: 117 | :undoc-members: 118 | :show-inheritance: 119 | -------------------------------------------------------------------------------- /trident/optims/pytorch_regularizers.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torch 6 | import torch.nn as nn 7 | from trident.backend.common import get_session,addindent,get_time_suffix,get_class,format_time,get_terminal_size,snake2camel,camel2snake,get_function 8 | from trident.backend.pytorch_ops import * 9 | 10 | _session = get_session() 11 | _epsilon = _session.epsilon 12 | 13 | __all__ = ['l1_reg','l2_reg','orth_reg','get_reg','total_variation_norm_reg'] 14 | 15 | 16 | def l1_reg(model:nn.Module,reg_weight=1e-7): 17 | #with torch.enable_grad(): 18 | loss =to_tensor(0.0,requires_grad=True) 19 | for name, param in model.named_parameters(): 20 | param_ = where(is_abnormal_number(param.data), zeros_like(param), param.data) 21 | if 'bias' not in name and param.requires_grad==True: 22 | loss = loss + (reg_weight * abs(param_).sum()) 23 | return loss 24 | 25 | 26 | def l2_reg(model:nn.Module,reg_weight=1e-6): 27 | loss =to_tensor(0.0,requires_grad=True) 28 | for name, param in model.named_parameters(): 29 | param_ = where(is_abnormal_number(param.data), zeros_like(param), param.data) 30 | if 'bias' not in name and param.requires_grad==True: 31 | loss=loss+ reg_weight *(param_**2).sum() 32 | return loss 33 | 34 | 35 | def orth_reg(model:nn.Module,reg_weight=1e-6): 36 | loss =to_tensor(0.0,requires_grad=True) 37 | for name, param in model.named_parameters(): 38 | param_ = where(is_abnormal_number(param.data), zeros_like(param), param.data) 39 | if 'bias' not in name and param.requires_grad==True: 40 | param_flat = param_.view(param_.shape[0], -1) 41 | sym = torch.mm(param_flat, torch.t(param_flat)) 42 | sym -=
torch.eye(param_flat.shape[0]) 43 | loss = loss + (reg_weight * sym.abs().sum()) 44 | return loss 45 | 46 | def total_variation_norm_reg(output:torch.Tensor,reg_weight=1): 47 | b,c,h,w=int_shape(output) 48 | total_variation_norm_reg.reg_weight=reg_weight 49 | assert len(output.size())==4 50 | count_w = output.size()[2] * (output.size()[3] - 1) 51 | count_h = (output.size()[2] - 1) * output.size()[3] 52 | 53 | diff_i = torch.sum(torch.pow(output[:, :, :, 1:] - output[:, :, :, :-1], 2)) 54 | diff_j = torch.sum(torch.pow(output[:, :, 1:, :] - output[:, :, :-1, :], 2)) 55 | tv_loss = true_divide(2*(true_divide(diff_i,count_w) + true_divide(diff_j,count_h)),b) 56 | return reg_weight * tv_loss 57 | 58 | 59 | def get_reg(reg_name): 60 | if reg_name is None: 61 | return None 62 | if reg_name=='l1': 63 | return l1_reg 64 | elif reg_name=='l2': 65 | return l2_reg 66 | else: 67 | if '_reg' not in reg_name: 68 | reg_name=reg_name+'_reg' 69 | 70 | reg_modules = ['trident.optims.pytorch_regularizers'] 71 | reg_fn = get_function(camel2snake(reg_name), reg_modules) 72 | return reg_fn -------------------------------------------------------------------------------- /trident/data/label_common.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import math 6 | import os 7 | import random 8 | import re 9 | 10 | import numpy as np 11 | import six 12 | import numbers 13 | from trident.backend.common import * 14 | from trident.backend.tensorspec import * 15 | from trident.backend.load_backend import * 16 | 17 | __all__ = ['label_backend_adaptive','get_onehot','check_is_onehot'] 18 | 19 | 20 | 21 | if get_backend()== 'pytorch': 22 | from trident.backend.pytorch_backend import to_numpy, to_tensor 23 | from trident.backend.pytorch_ops import int_shape 24 | import torch 25 | elif get_backend()== 'tensorflow': 26 | from 
trident.backend.tensorflow_backend import to_numpy, to_tensor 27 | from trident.backend.tensorflow_ops import int_shape 28 | from trident.backend.tensorspec import * 29 | 30 | def get_onehot(idx,len): 31 | if idx>=len: 32 | raise ValueError('') 33 | arr=np.zeros(len,dtype=np.float32) 34 | arr[idx]=1 35 | return arr 36 | 37 | 38 | def check_is_onehot(label): 39 | label1=label.copy() 40 | label1[label1 > 0] = 1 41 | mean_lable1=label1.mean() 42 | if mean_lable1 < 2 * 1 / float(label.shape[-1]): 43 | return True 44 | else: 45 | return False 46 | 47 | def label_backend_adaptive(label,label_mapping=None,object_type=None,**kwargs): 48 | if get_backend()== 'pytorch': 49 | if isinstance(label,np.ndarray): 50 | # binary mask 51 | if object_type == ObjectType.binary_mask: 52 | if label.ndim==2 : 53 | label[label > 0] = 1 54 | return label.astype(np.int64) 55 | elif label.ndim==3 and label.shape[-1] in [1,2]: 56 | if label.shape[-1] ==2: 57 | label=label[:,:,1] 58 | elif label.shape[-1] ==1: 59 | label = label[:, :,0] 60 | label[label > 0] = 1 61 | return label.astype(np.int64) 62 | elif object_type == ObjectType.label_mask: 63 | if label.ndim==2 : 64 | return label.astype(np.int64) 65 | if label.ndim == 3 and label.shape[-1] >2: 66 | if check_is_onehot(label): 67 | label=np.argmax(label,-1).astype(np.int64) 68 | return label 69 | label = label.astype(np.int64) 70 | elif isinstance(label, int): 71 | return label 72 | return label 73 | elif get_backend()== 'tensorflow': 74 | if isinstance(label, numbers.Integral): 75 | if isinstance(label_mapping, dict) and len(label_mapping) > 0: 76 | label_mapping = list(label_mapping.values())[0] 77 | label = get_onehot(label, len(label_mapping)) 78 | return label 79 | elif label_mapping is None: 80 | return label 81 | return label 82 | 83 | 84 | -------------------------------------------------------------------------------- /trident/optims/metrics.py: -------------------------------------------------------------------------------- 1 | 
from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | import copy 5 | import inspect 6 | import os 7 | import sys 8 | import builtins 9 | import random 10 | import shutil 11 | import string 12 | import sys 13 | import time 14 | import uuid 15 | import json 16 | from typing import Callable, Any 17 | 18 | import numpy as np 19 | from trident.context import split_path, make_dir_if_need, sanitize_path 20 | 21 | from trident.backend.common import to_list, addindent, get_time_suffix, format_time, get_terminal_size, get_session, \ 22 | snake2camel, PrintException, unpack_singleton, enforce_singleton, OrderedDict, if_none 23 | from trident.backend.tensorspec import * 24 | _session = get_session() 25 | _backend = _session.backend 26 | if _backend == 'pytorch': 27 | from trident.backend.pytorch_backend import * 28 | from trident.backend.pytorch_ops import * 29 | 30 | elif _backend == 'tensorflow': 31 | from trident.backend.tensorflow_backend import * 32 | from trident.backend.tensorflow_ops import * 33 | 34 | 35 | 36 | class Metric(object): 37 | def __init__(self, *args, **kwargs): 38 | super().__init__() 39 | self.name=kwargs.get('name') 40 | self.format=kwargs.get('format') 41 | self.aggregate = kwargs.get('aggregate') 42 | 43 | 44 | def _forward_unimplemented(self, output: Any, target: Any, **kwargs) -> None: 45 | raise NotImplementedError 46 | calculate_metric=_forward_unimplemented 47 | 48 | 49 | 50 | def as_metric(format=None,aggregate='mean',name=None): 51 | def _f(fun): 52 | m = Metric(name=if_none(name, fun.__name__), format=format, aggregate=aggregate) 53 | m.calculate_metric = fun 54 | m.__name__ = m.__qualname__ = if_none(name, fun.__name__) 55 | m.__doc__ = fun.__doc__ 56 | m._signature = inspect.signature(fun) 57 | return m 58 | return _f 59 | 60 | 61 | 62 | 63 | class RunningMeanStd(object): 64 | # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm 65 | def __init__(self, epsilon=1e-4, shape=()): 66 |
self.mean = np.zeros(shape, 'float64') 67 | self.var = np.ones(shape, 'float64') 68 | self.count = epsilon 69 | 70 | def update(self, x): 71 | batch_mean = np.mean(x, axis=0) 72 | batch_var = np.var(x, axis=0) 73 | batch_count = x.shape[0] 74 | self.update_from_moments(batch_mean, batch_var, batch_count) 75 | 76 | def update_from_moments(self, batch_mean, batch_var, batch_count): 77 | self.mean, self.var, self.count = update_mean_var_count_from_moments( 78 | self.mean, self.var, self.count, batch_mean, batch_var, batch_count) 79 | 80 | def update_mean_var_count_from_moments(mean, var, count, batch_mean, batch_var, batch_count): 81 | delta = batch_mean - mean 82 | tot_count = count + batch_count 83 | 84 | new_mean = mean + delta * batch_count / tot_count 85 | m_a = var * count 86 | m_b = batch_var * batch_count 87 | M2 = m_a + m_b + np.square(delta) * count * batch_count / tot_count 88 | new_var = M2 / tot_count 89 | new_count = tot_count 90 | 91 | return new_mean, new_var, new_count -------------------------------------------------------------------------------- /sphinx_docs/source/trident.models.rst: -------------------------------------------------------------------------------- 1 | trident.models package 2 | ====================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | trident.models.pytorch\_arcfacenet module 8 | ----------------------------------------- 9 | 10 | .. automodule:: trident.models.pytorch_arcfacenet 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | trident.models.pytorch\_deeplab module 16 | -------------------------------------- 17 | 18 | .. automodule:: trident.models.pytorch_deeplab 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | trident.models.pytorch\_densenet module 24 | --------------------------------------- 25 | 26 | .. 
automodule:: trident.models.pytorch_densenet 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | trident.models.pytorch\_efficientnet module 32 | ------------------------------------------- 33 | 34 | .. automodule:: trident.models.pytorch_efficientnet 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | trident.models.pytorch\_gan module 40 | ---------------------------------- 41 | 42 | .. automodule:: trident.models.pytorch_gan 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | trident.models.pytorch\_mobilenet module 48 | ---------------------------------------- 49 | 50 | .. automodule:: trident.models.pytorch_mobilenet 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | trident.models.pytorch\_mtcnn module 56 | ------------------------------------ 57 | 58 | .. automodule:: trident.models.pytorch_mtcnn 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | trident.models.pytorch\_resnet module 64 | ------------------------------------- 65 | 66 | .. automodule:: trident.models.pytorch_resnet 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | 71 | trident.models.pytorch\_rfbnet module 72 | ------------------------------------- 73 | 74 | .. automodule:: trident.models.pytorch_rfbnet 75 | :members: 76 | :undoc-members: 77 | :show-inheritance: 78 | 79 | trident.models.pytorch\_ssd module 80 | ---------------------------------- 81 | 82 | .. automodule:: trident.models.pytorch_ssd 83 | :members: 84 | :undoc-members: 85 | :show-inheritance: 86 | 87 | trident.models.pytorch\_vgg module 88 | ---------------------------------- 89 | 90 | .. automodule:: trident.models.pytorch_vgg 91 | :members: 92 | :undoc-members: 93 | :show-inheritance: 94 | 95 | trident.models.pytorch\_yolo module 96 | ----------------------------------- 97 | 98 | .. 
automodule:: trident.models.pytorch_yolo 99 | :members: 100 | :undoc-members: 101 | :show-inheritance: 102 | 103 | trident.models.tensorflow\_efficientnet module 104 | ---------------------------------------------- 105 | 106 | .. automodule:: trident.models.tensorflow_efficientnet 107 | :members: 108 | :undoc-members: 109 | :show-inheritance: 110 | 111 | trident.models.tensorflow\_resnet module 112 | ---------------------------------------- 113 | 114 | .. automodule:: trident.models.tensorflow_resnet 115 | :members: 116 | :undoc-members: 117 | :show-inheritance: 118 | 119 | 120 | Module contents 121 | --------------- 122 | 123 | .. automodule:: trident.models 124 | :members: 125 | :undoc-members: 126 | :show-inheritance: 127 | -------------------------------------------------------------------------------- /sphinx_docs/source/trident.optims.rst: -------------------------------------------------------------------------------- 1 | trident.optims package 2 | ====================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | trident.optims.pytorch\_constraints module 8 | ------------------------------------------ 9 | 10 | .. automodule:: trident.optims.pytorch_constraints 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | trident.optims.pytorch\_losses module 16 | ------------------------------------- 17 | 18 | .. automodule:: trident.optims.pytorch_losses 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | trident.optims.pytorch\_metrics module 24 | -------------------------------------- 25 | 26 | .. automodule:: trident.optims.pytorch_metrics 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | trident.optims.pytorch\_optimizers module 32 | ----------------------------------------- 33 | 34 | .. automodule:: trident.optims.pytorch_optimizers 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | trident.optims.pytorch\_profiling module 40 | ---------------------------------------- 41 | 42 | .. 
automodule:: trident.optims.pytorch_profiling 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | trident.optims.pytorch\_regularizers module 48 | ------------------------------------------- 49 | 50 | .. automodule:: trident.optims.pytorch_regularizers 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | trident.optims.pytorch\_trainer module 56 | -------------------------------------- 57 | 58 | .. automodule:: trident.optims.pytorch_trainer 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | trident.optims.tensorflow\_constraints module 64 | --------------------------------------------- 65 | 66 | .. automodule:: trident.optims.tensorflow_constraints 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | 71 | trident.optims.tensorflow\_losses module 72 | ---------------------------------------- 73 | 74 | .. automodule:: trident.optims.tensorflow_losses 75 | :members: 76 | :undoc-members: 77 | :show-inheritance: 78 | 79 | trident.optims.tensorflow\_metrics module 80 | ----------------------------------------- 81 | 82 | .. automodule:: trident.optims.tensorflow_metrics 83 | :members: 84 | :undoc-members: 85 | :show-inheritance: 86 | 87 | trident.optims.tensorflow\_optimizers module 88 | -------------------------------------------- 89 | 90 | .. automodule:: trident.optims.tensorflow_optimizers 91 | :members: 92 | :undoc-members: 93 | :show-inheritance: 94 | 95 | trident.optims.tensorflow\_regularizers module 96 | ---------------------------------------------- 97 | 98 | .. automodule:: trident.optims.tensorflow_regularizers 99 | :members: 100 | :undoc-members: 101 | :show-inheritance: 102 | 103 | trident.optims.tensorflow\_trainer module 104 | ----------------------------------------- 105 | 106 | .. automodule:: trident.optims.tensorflow_trainer 107 | :members: 108 | :undoc-members: 109 | :show-inheritance: 110 | 111 | trident.optims.trainers module 112 | ------------------------------ 113 | 114 | .. 
automodule:: trident.optims.trainers 115 | :members: 116 | :undoc-members: 117 | :show-inheritance: 118 | 119 | 120 | Module contents 121 | --------------- 122 | 123 | .. automodule:: trident.optims 124 | :members: 125 | :undoc-members: 126 | :show-inheritance: 127 | -------------------------------------------------------------------------------- /docs/zh-tw/quickstart.md: -------------------------------------------------------------------------------- 1 | # How to Build an A/B Test for Image Recognition Models 2 | 3 | ## Scenario 4 | 5 | The hello world of deep learning is almost always the MNIST handwritten digits dataset, and we are no exception here. Through this hands-on example you will come to understand trident's main features and usage logic. 6 | 7 | ## Requirements 8 | 9 | Based on the MNIST handwritten digits dataset, build a classification model that recognizes the 10 digits. However, we want to know which works better: the traditional relu activation function or parametric relu. This is where an A/B test helps us compare the two. 10 | 11 | ## Code 12 | 13 | ```python 14 | import os 15 | os.environ['TRIDENT_BACKEND'] = 'pytorch' 16 | import trident as T 17 | from trident import * 18 | 19 | data_provider = T.load_mnist('mnist') 20 | data_provider.image_transform_funcs = [add_noise(0.1),normalize(127.5, 127.5)] 21 | 22 | def convnet(activation='relu'): 23 | return Sequential( 24 | Conv2d((5,5),16,strides=2,auto_pad=True,activation=activation), 25 | Conv2d((3, 3),32, strides=2, auto_pad=True,activation=activation), 26 | Conv2d((3, 3), 64, strides=1, auto_pad=True,activation=activation), 27 | Flatten(), 28 | Dense(10) 29 | ) 30 | 31 | net1=Model(input_shape=(1,28,28),output=convnet('relu')) 32 | net2=Model(input_shape=(1,28,28),output=convnet('p_relu')) 33 | net1.summary() 34 | net2.summary() 35 | 36 | 37 | net1.with_optimizer(optimizer='Adam',lr=1e-3) \ 38 | .with_loss(CrossEntropyLoss)\ 39 | .with_metric(accuracy)\ 40 | .with_regularizer('l2') 41 | 42 | 43 | net2.with_optimizer(optimizer='Adam',lr=1e-3)\ 44 | .with_loss(CrossEntropyLoss)\ 45 | .with_metric(accuracy)\ 46 | .with_regularizer('l2') 47 | 48 | 49 | plan=TrainingPlan()\ 50 | .add_training_item(net1,name='net1')\ 51 | .add_training_item(net2,name='net2')\ 52 | .with_data_loader(data_provider)\ 53 | .repeat_epochs(10)\ 54
| .within_minibatch_size(64)\ 55 | .print_progress_scheduling(10,unit='batch')\ 56 | .display_loss_metric_curve_scheduling(100,unit='batch',imshow=True) 57 | 58 | plan.start_now() 59 | ``` 60 | 61 | 62 | 63 | ## 成果 64 | 65 | ```python 66 | import os 67 | os.environ['TRIDENT_BACKEND'] = 'pytorch' 68 | import trident as T 69 | from trident import * 70 | ``` 71 | 72 | ```python 73 | 74 | data_provider = T.load_mnist('mnist') 75 | data_provider.image_transform_funcs = [add_noise(0.1),normalize(127.5, 127.5)] 76 | ``` 77 | 78 | 79 | 80 | ```python 81 | def convnet(activation='relu'): 82 | return Sequential( 83 | Conv2d((5,5),16,strides=2,auto_pad=True,activation=activation), 84 | Conv2d((3, 3),32, strides=2, auto_pad=True,activation=activation), 85 | Conv2d((3, 3), 64, strides=1, auto_pad=True,activation=activation), 86 | Flatten(), 87 | Dense(10) 88 | ) 89 | 90 | net1=Model(input_shape=(1,28,28),output=convnet('relu')) 91 | net2=Model(input_shape=(1,28,28),output=convnet('p_relu')) 92 | net1.summary() 93 | net2.summary() 94 | ``` 95 | 96 | ![](../images/quickstart_summary.png) 97 | 98 | ``` 99 | 100 | net1.with_optimizer(optimizer='Adam',lr=1e-3) \ 101 | .with_loss(CrossEntropyLoss)\ 102 | .with_metric(accuracy)\ 103 | .with_regularizer('l2')\ 104 | ``` 105 | 106 | 107 | 108 | ```python 109 | 110 | plan=TrainingPlan()\ 111 | .add_training_item(net1,name='net1')\ 112 | .add_training_item(net2,name='net2')\ 113 | .with_data_loader(data_provider)\ 114 | .repeat_epochs(10)\ 115 | .within_minibatch_size(64)\ 116 | .print_progress_scheduling(10,unit='batch')\ 117 | .display_loss_metric_curve_scheduling(100,unit='batch',imshow=True) 118 | 119 | plan.start_now() 120 | ``` 121 | 122 | 123 | 124 | ![](../images/quickstart_ouput.png) 125 | 126 | 127 | 128 | 129 | 130 | -------------------------------------------------------------------------------- /trident/reinforcement/envs/__init__.py: -------------------------------------------------------------------------------- 1 | import 
threading 2 | from multiprocessing import Pipe, Process 3 | import numpy as np 4 | import subprocess as sp 5 | import torch.multiprocessing as mp 6 | from trident import context 7 | 8 | ctx=context._context() 9 | 10 | def worker(remote, parent_remote, env_fn): 11 | parent_remote.close() 12 | env = env_fn() 13 | while True: 14 | cmd, data = remote.recv() 15 | 16 | if cmd == 'step': 17 | ob, reward, done, info = env.step(data) 18 | if done: 19 | ob = env.reset() 20 | remote.send((ob, reward, done, info)) 21 | 22 | elif cmd == 'render': 23 | remote.send(env.render()) 24 | 25 | elif cmd == 'close': 26 | remote.close() 27 | break 28 | elif cmd == 'reset': 29 | remote.send(env.reset()) 30 | else: raise NotImplementedError(cmd) 31 | 32 | class SubprocVecEnv(): 33 | def __init__(self, env_fns): 34 | self.waiting = False 35 | 36 | self.closed = False 37 | no_of_envs = len(env_fns) 38 | self.remotes, self.work_remotes = \ 39 | zip(*[Pipe() for _ in range(no_of_envs)]) 40 | self.ps = [] 41 | 42 | for wrk, rem, fn in zip(self.work_remotes, self.remotes, env_fns): 43 | proc = Process(target=worker, 44 | args=(wrk, rem, fn))  # env_fn must be a picklable (top-level) factory 45 | self.ps.append(proc) 46 | 47 | for p in self.ps: 48 | p.daemon = True 49 | p.start() 50 | 51 | for remote in self.work_remotes: 52 | remote.close() 53 | 54 | def step_async(self, actions): 55 | if self.waiting: 56 | raise RuntimeError('step_async called while already stepping') 57 | self.waiting = True 58 | 59 | for remote, action in zip(self.remotes, actions): 60 | remote.send(('step', action)) 61 | 62 | def step_wait(self): 63 | if not self.waiting: 64 | raise RuntimeError('step_wait called without a pending step_async') 65 | self.waiting = False 66 | 67 | results = [remote.recv() for remote in self.remotes] 68 | obs, rews, dones, infos = zip(*results) 69 | return np.stack(obs), np.stack(rews), np.stack(dones), infos 70 | 71 | def step(self, actions): 72 | self.step_async(actions) 73 | return self.step_wait() 74 | 75 | def reset(self): 76 | for remote in self.remotes: 77 | remote.send(('reset', None)) 78 | 79 | return np.stack([remote.recv() for remote in self.remotes]) 80 | 81 | 
def close(self): 82 | if self.closed: 83 | return 84 | if self.waiting: 85 | for remote in self.remotes: 86 | remote.recv() 87 | for remote in self.remotes: 88 | remote.send(('close', None)) 89 | for p in self.ps: 90 | p.join() 91 | self.closed = True 92 | 93 | 94 | class MultipleEnvironments: 95 | def __init__(self, *envs, num_envs=None, **kwargs): 96 | if num_envs is None: 97 | num_envs=len(envs) 98 | self.agent_conns, self.env_conns = zip(*[mp.Pipe() for _ in range(num_envs)]) 99 | 100 | self.envs = list(envs) 101 | self.num_states = self.envs[0].observation_space.shape[0] 102 | 103 | for index in range(num_envs): 104 | process = mp.Process(target=self.run, args=(index,)) 105 | process.start() 106 | self.env_conns[index].close() 107 | 108 | def run(self, index): 109 | self.agent_conns[index].close() 110 | while True: 111 | request, action = self.env_conns[index].recv() 112 | if request == "step": 113 | self.env_conns[index].send(self.envs[index].step(action.item())) 114 | elif request == "reset": 115 | self.env_conns[index].send(self.envs[index].reset()) 116 | else: 117 | raise NotImplementedError -------------------------------------------------------------------------------- /trident/backend/__init__.py: -------------------------------------------------------------------------------- 1 | """trident backend""" 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | import six 6 | import os 7 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 8 | from trident.context import * 9 | from trident.backend.common import * 10 | from trident.backend.decorators import * 11 | #from trident.backend import dtype 12 | #import trident.backend.numpy_ops 13 | 14 | import trident.backend.load_backend 15 | 16 | if get_backend()=='pytorch': 17 | from trident.backend.pytorch_ops import * 18 | 19 | 20 | elif get_backend()=='tensorflow': 21 | from trident.backend.tensorflow_ops import * 22 | 23 | elif get_backend()=='jax': 24 | 
from trident.backend.jax_ops import * 25 | 26 | 27 | 28 | 29 | if get_backend()=='pytorch': 30 | from trident.backend.pytorch_backend import * 31 | 32 | elif get_backend()=='tensorflow': 33 | from trident.backend.tensorflow_backend import * 34 | 35 | elif get_backend()=='jax': 36 | from trident.backend.jax_backend import * 37 | 38 | from trident.backend.tensorspec import * 39 | from trident.loggers.history import * 40 | from trident.backend import model 41 | from trident.backend.pillow_backend import * 42 | 43 | from trident.data.dataset import * 44 | from trident.data.data_provider import * 45 | from trident.callbacks import * 46 | from trident.misc import * 47 | 48 | 49 | if get_backend()=='pytorch': 50 | from trident.optims.pytorch_optimizers import * 51 | from trident.layers.pytorch_activations import * 52 | from trident.layers.pytorch_initializers import * 53 | from trident.layers.pytorch_layers import * 54 | from trident.layers.pytorch_pooling import * 55 | from trident.layers.pytorch_blocks import * 56 | from trident.layers.pytorch_normalizations import * 57 | from trident.layers.pytorch_rnn import * 58 | from trident.layers.pytorch_transformers import * 59 | 60 | from trident.optims.pytorch_constraints import * 61 | from trident.optims.pytorch_regularizers import * 62 | from trident.optims.pytorch_losses import * 63 | from trident.optims.pytorch_metrics import * 64 | from trident.optims.pytorch_trainer import * 65 | 66 | elif get_backend()=='tensorflow': 67 | from trident.backend.tensorflow_serialization import * 68 | from trident.optims.tensorflow_optimizers import * 69 | 70 | from trident.layers.tensorflow_activations import * 71 | from trident.layers.tensorflow_initializers import * 72 | from trident.layers.tensorflow_layers import * 73 | from trident.layers.tensorflow_pooling import * 74 | from trident.layers.tensorflow_blocks import * 75 | from trident.layers.tensorflow_normalizations import * 76 | from trident.layers.tensorflow_rnn import * 77 | 78 | 
from trident.optims.tensorflow_constraints import * 79 | from trident.optims.tensorflow_regularizers import * 80 | from trident.optims.tensorflow_losses import * 81 | from trident.optims.tensorflow_metrics import * 82 | from trident.optims.tensorflow_trainer import * 83 | 84 | elif get_backend()=='jax': 85 | from trident.backend.jax_serialization import * 86 | from trident.layers.jax_activations import * 87 | from trident.layers.jax_layers import * 88 | 89 | elif get_backend()=='onnx': 90 | import_or_install('onnx_runtime') 91 | pass 92 | 93 | from trident.data.utils import * 94 | from trident.data.image_common import * 95 | from trident.data.bbox_common import * 96 | from trident.data.label_common import * 97 | from trident.data.mask_common import * 98 | from trident.data.transform import * 99 | from trident.data.vision_transforms import * 100 | from trident.data.text_transforms import * 101 | from trident.data.data_loaders import * 102 | 103 | from trident.optims.trainers import TrainingPlan 104 | from trident.misc.ipython_utils import * 105 | from trident.misc.visualization_utils import * 106 | from trident.backend.iteration_tools import * 107 | #import trident.models 108 | 109 | 110 | -------------------------------------------------------------------------------- /trident/loggers/history.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import builtins 4 | import numbers 5 | import time 6 | import math 7 | import numpy as np 8 | from trident import context 9 | 10 | from trident.backend.common import get_backend,to_list, addindent, get_time_suffix, format_time, get_terminal_size, get_session,get_backend, \ 11 | snake2camel, PrintException, unpack_singleton, enforce_singleton, OrderedDict, Signature 12 | from trident.context import split_path, make_dir_if_need, sanitize_path 13 | 14 | 15 | ctx = context._context() 16 | _backend =get_backend() 17 | working_directory=ctx.working_directory 18 | 19 
| 20 | if _backend == 'pytorch': 21 | import torch 22 | import torch.nn as nn 23 | from trident.backend.pytorch_backend import Tensor 24 | from trident.backend.pytorch_ops import * 25 | 26 | elif _backend == 'tensorflow': 27 | import tensorflow as tf 28 | from trident.backend.tensorflow_backend import Tensor 29 | from trident.backend.tensorflow_ops import * 30 | 31 | 32 | 33 | class HistoryBase(OrderedDict): 34 | def __init__(self, prevent_redundant=True,name='', **kwargs): 35 | super().__init__(**kwargs) 36 | self.name=name 37 | self.training_name=None 38 | self.prevent_redundant=prevent_redundant 39 | 40 | self.summary_writer=ctx.summary_writer 41 | 42 | @property 43 | def enable_tensorboard(self): 44 | return ctx.enable_tensorboard 45 | 46 | @property 47 | def enable_mlflow(self): 48 | return ctx.enable_mlflow 49 | 50 | def regist(self,data_name:str): 51 | if data_name not in self: 52 | self[data_name]=[] 53 | 54 | def collect(self, data_name: str, step: int, value: (float,np.ndarray)): 55 | if data_name not in self: 56 | self.regist(data_name) 57 | if any_abnormal_number(value): 58 | pass 59 | else: 60 | value=to_scalar(value) 61 | is_redundant_skip=False 62 | if self.prevent_redundant: 63 | if (step, value) in self[data_name]: 64 | is_redundant_skip=True 65 | 66 | if not is_redundant_skip: 67 | self[data_name].append((step, value)) 68 | if ctx.enable_tensorboard: 69 | if self.training_name is None: 70 | ctx.summary_writer.add_scalar( self.name+"/"+data_name, value, global_step=step, walltime=time.time()) 71 | else: 72 | ctx.summary_writer.add_scalar(self.training_name+ "/"+self.name + "/" + data_name, value, global_step=step, walltime=time.time()) 73 | if ctx.enable_mlflow: 74 | ctx.mlflow_logger.add_scalar( data_name, value, global_step=step, walltime=time.time()) 75 | 76 | def reset(self): 77 | for data_name in self: 78 | self[data_name].clear() 79 | def get_keys(self): 80 | return list(self.keys()) 81 | 82 | def get_series(self,data_name): 83 | 84 | if
data_name in self and self[data_name] is not None and len(self[data_name])>=1: 85 | steps,values=zip(*self[data_name].copy()) 86 | return list(steps),list(values) 87 | else: 88 | #sys.stderr.write('{0} is not in this history.'.format(data_name)) 89 | return [], [] 90 | 91 | def get_last(self,data_name): 92 | if data_name in self and len(self[data_name])>0: 93 | return self[data_name][-1] 94 | else: 95 | return [] 96 | #raise ValueError('{0} is not in this History.'.format(data_name)) 97 | 98 | def get_best(self,data_name,is_larger_better=True): 99 | if data_name in self: 100 | steps,values=zip(*self[data_name].copy()) 101 | if is_larger_better: 102 | return builtins.max(values) 103 | else: 104 | return builtins.min(values) 105 | else: 106 | raise ValueError('{0} is not in this History.'.format(data_name)) -------------------------------------------------------------------------------- /trident/layers/jax_layers.py: -------------------------------------------------------------------------------- 1 | """Basic jax layers""" 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | import builtins 6 | import inspect 7 | import math 8 | import collections.abc 9 | import string 10 | from functools import partial 11 | from itertools import repeat 12 | from pydoc import locate 13 | import six 14 | import jax 15 | import jax.numpy as jnp 16 | import jaxlib 17 | from trident.backend.common import get_function, get_class, camel2snake, snake2camel, enforce_singleton, TensorShape, addindent, unpack_singleton 18 | from trident.backend.jax_backend import Layer,Sequential,Parameter 19 | from trident.backend.jax_ops import * 20 | from trident.layers.jax_activations import get_activation 21 | 22 | __all__ = ['Dense'] 23 | 24 | 25 | def _ntuple(n): 26 | def parse(x): 27 | if isinstance(x, collections.abc.Iterable): 28 | return x 29 | return tuple(repeat(x, n)) 30 | 31 | return parse 32 | 33 | 34 | _single = _ntuple(1) 35 | _pair = _ntuple(2) 36 | _triple = _ntuple(3) 37 | _quadruple = _ntuple(4) 38 | def
get_layer_repr(layer): 39 | # We treat the extra repr like the sub-module, one item per line 40 | extra_lines = [] 41 | if hasattr(layer, 'extra_repr') and callable(layer.extra_repr): 42 | extra_repr = layer.extra_repr() 43 | # empty string will be split into list [''] 44 | if extra_repr: 45 | extra_lines = extra_repr.split('\n') 46 | child_lines = [] 47 | if isinstance(layer,Layer) and layer.layers is not None: 48 | for module in layer.layers: 49 | mod_str = repr(module) 50 | mod_str = addindent(mod_str, 2) 51 | child_lines.append('(' + module.name + '): ' + mod_str) 52 | lines = extra_lines + child_lines 53 | 54 | main_str = layer.__class__.__name__ + '(' 55 | if lines: 56 | # simple one-liner info, which most builtin Modules will use 57 | if len(extra_lines) == 1 and not child_lines: 58 | main_str += extra_lines[0] 59 | else: 60 | main_str += '\n ' + '\n '.join(lines) + '\n' 61 | 62 | main_str += ')' 63 | return main_str 64 | class Dense(Layer): 65 | def __init__(self, num_filters, use_bias=True, activation=None,kernel_regularizer=None, keep_output=False, name=None, **kwargs): 66 | super(Dense, self).__init__(name=name,keep_output=keep_output) 67 | self.rank = 0 68 | if isinstance(num_filters, int): 69 | self.num_filters = num_filters 70 | elif isinstance(num_filters, tuple): 71 | self.num_filters = unpack_singleton(num_filters) 72 | else: 73 | raise ValueError('output_shape should be integer, list of integer or tuple of integer...') 74 | 75 | self.use_bias = use_bias 76 | if kernel_regularizer == 'l2': 77 | self.kernel_regularizer = l2_normalize 78 | else: 79 | self.kernel_regularizer = None 80 | 81 | self.activation = get_activation(activation) 82 | 83 | 84 | def build(self, input_shape:TensorShape): 85 | if not self._built: 86 | if len(input_shape.dims) == 1: 87 | self.input_filters = input_shape.dims[0] 88 | else: 89 | self.input_filters = input_shape[self.filter_index] 90 | 
self.register_parameter('weight',Parameter(data=random_normal(shape=(self.input_filters,self.num_filters), mean=0., std=0.2) , name='weight')) 91 | kaiming_uniform(self.weight, a=math.sqrt(5)) 92 | if self.use_bias: 93 | self.register_parameter('bias',Parameter(data=random_normal(shape=(self.num_filters,), mean=0., std=0.002) , name='bias')) 94 | self._built = True 95 | 96 | def forward(self, x, **kwargs) : 97 | 98 | if hasattr(self, 'kernel_regularizer') and self.kernel_regularizer is not None: 99 | x = x @ self.kernel_regularizer(self.weight) 100 | else: 101 | x = x @ self.weight 102 | if self.use_bias: 103 | x = x + self.bias 104 | 105 | if self.activation is not None: 106 | x = self.activation(x) 107 | return x -------------------------------------------------------------------------------- /trident/data/text_common.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | import functools 5 | import math 6 | import os 7 | import random 8 | import re 9 | import builtins 10 | import string 11 | import time 12 | from itertools import repeat 13 | from functools import partial 14 | import inspect 15 | import cv2 16 | import numpy as np 17 | import six 18 | from trident.backend.common import * 19 | 20 | __all__ = ['text_backend_adaption','reverse_text_backend_adaption','chinese_full2half','chinese_half2full'] 21 | 22 | 23 | if get_backend()== 'pytorch': 24 | from trident.backend.pytorch_backend import to_numpy, to_tensor, ObjectType 25 | from trident.backend.pytorch_ops import int_shape 26 | from trident.layers.pytorch_layers import Embedding 27 | import torch 28 | elif get_backend()== 'tensorflow': 29 | from trident.backend.tensorflow_backend import to_numpy, to_tensor,ObjectType 30 | from trident.backend.tensorflow_ops import int_shape 31 | 32 | 33 | 34 | def chinese_full2half(): 35 | """Convert all fullwidth Chinese characters to halfwidth.
36 | 37 | Returns: 38 | A string op that converts fullwidth characters in the input string to halfwidth. 39 | """ 40 | def string_op(input_str:str): 41 | rstring = "" 42 | for uchar in input_str: 43 | u_code = ord(uchar) 44 | if u_code == 0x3000 or uchar in string.whitespace:  # 0x3000 (12288) is the fullwidth space 45 | u_code = 32 46 | elif 65281 <= u_code <= 65374: 47 | u_code -= 65248 48 | rstring += chr(u_code) 49 | return rstring 50 | return string_op 51 | 52 | def chinese_half2full(): 53 | """Convert all halfwidth Chinese characters to fullwidth. 54 | 55 | Returns: 56 | A string op that converts halfwidth characters in the input string to fullwidth. 57 | """ 58 | def string_op(input_str:str): 59 | rstring = "" 60 | for uchar in input_str: 61 | u_code = ord(uchar) 62 | if u_code == 32: 63 | u_code = 12288 64 | elif 33 <= u_code <= 126: 65 | u_code += 65248 66 | rstring += chr(u_code) 67 | return rstring 68 | return string_op 69 | 70 | def text_backend_adaption(text): 71 | if get_backend() == 'tensorflow': 72 | if text.dtype==np.int64 and text.ndim ==1: 73 | pass 74 | elif text.ndim ==2: 75 | text=text.astype(np.float32) 76 | else: 77 | if text.dtype == np.int64: 78 | pass 79 | elif text.ndim ==2: 80 | text=text.astype(np.float32) 81 | return text 82 | 83 | 84 | def reverse_text_backend_adaption(text): 85 | # if get_backend() == 'tensorflow': 86 | # if text.dtype == np.int64 and text.ndim == 1: 87 | # pass 88 | # elif text.ndim == 2: 89 | # text =argmax(text,-1) 90 | # else: 91 | # if text.dtype == np.int64: 92 | # pass 93 | # elif text.ndim == 2: 94 | # text = argmax(text, -1) 95 | return text 96 | 97 | 98 | 99 | # def char2embedding(embedding:Embedding): 100 | # def img_op(sentence:str,**kwargs): 101 | # sentence = reverse_image_backend_adaption(image) 102 | # norm_mean = mean 103 | # norm_std = std 104 | # if isinstance(norm_mean, tuple): 105 | # norm_mean = list(norm_mean) 106 | # 107 | # if isinstance(norm_std, tuple): 108 | # norm_std = list(norm_std) 109 | # 110 | # if isinstance(norm_mean, (float, int)) and isinstance(norm_std, (float, int)) and image.ndim == 3: 111 | # return image * float(norm_std) + float(norm_mean) 112 | # elif
isinstance(norm_mean, list) and isinstance(norm_std, list) and len(norm_mean) == 1 and len(norm_std) == 1: 113 | # return image * float(norm_std[0]) + float(norm_mean[0]) 114 | # elif isinstance(norm_mean, list) and isinstance(norm_std, list) and len(norm_mean) == 3 and len(norm_std) == 3: 115 | # norm_mean = np.reshape(np.array(norm_mean), (1, 1, 3)) 116 | # norm_std = np.reshape(np.array(norm_std), (1, 1, 3)) 117 | # return image * norm_std + norm_mean 118 | # return image 119 | # 120 | # return img_op 121 | -------------------------------------------------------------------------------- /trident/backend/dtype.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import typing 3 | 4 | from trident.backend.common import get_backend 5 | 6 | 7 | __all__=[ 8 | "int8", "byte", 9 | "int16", "short", 10 | "int32", "intc", 11 | "int64", "intp", 12 | "uint8", "ubyte", 13 | "float16", "half", 14 | "float32", "single", 15 | "float64", "double","long","float", 16 | "bool"] 17 | 18 | 19 | if get_backend() == 'pytorch': 20 | import torch 21 | # type definition 22 | bool = torch.bool 23 | 24 | int8 = torch.int8 25 | byte = torch.int8 26 | int16 = torch.int16 27 | short = torch.int16 28 | int32 = torch.int32 29 | intc = torch.int32 30 | int64 = torch.int64 31 | intp = torch.int64 32 | 33 | uint8 = torch.uint8 34 | ubyte = torch.uint8 35 | float16 = torch.float16 36 | half = torch.float16 37 | float32 = torch.float32 38 | single = torch.float32 39 | float64 = torch.float64 40 | double = torch.float64 41 | long = torch.int64 42 | float= torch.float32 43 | complex64 = torch.complex64 44 | complex128 = torch.complex128 45 | cfloat = torch.cfloat 46 | 47 | 48 | elif get_backend() == 'tensorflow': 49 | import tensorflow as tf 50 | bool = tf.bool 51 | 52 | int8 = tf.int8 53 | byte = tf.int8 54 | int16 = tf.int16 55 | short = tf.int16 56 | int32 = tf.int32 57 | intc = tf.int32 58 | int64 = tf.int64 59 | intp = tf.int64 60 | 61 | 
uint8 = tf.uint8 62 | ubyte = tf.uint8 63 | float16 = tf.float16 64 | half = tf.float16 65 | float32 = tf.float32 66 | single = tf.float32 67 | float64 = tf.float64 68 | double = tf.float64 69 | long = tf.int64 70 | float = tf.float32 71 | complex64 = tf.complex64 72 | complex128 = tf.complex128 73 | cfloat = tf.complex64 74 | 75 | elif get_backend() == 'jax': 76 | import jax.numpy as jnp 77 | # type definition 78 | bool = jnp.bool_ 79 | 80 | int8 = jnp.int8 81 | byte = jnp.int8 82 | int16 = jnp.int16 83 | short = jnp.int16 84 | int32 = jnp.int32 85 | intc = jnp.int32 86 | int64 = jnp.int64 87 | intp = jnp.int64 88 | 89 | uint8 = jnp.uint8 90 | ubyte = jnp.uint8 91 | float16 = jnp.float16 92 | half = jnp.float16 93 | float32 = jnp.float32 94 | single = jnp.float32 95 | float64 = jnp.float64 96 | double = jnp.float64 97 | long = jnp.int64 98 | float = jnp.float32 99 | complex64 = jnp.complex64 100 | complex128 = jnp.complex128 101 | cfloat = None 102 | 103 | elif get_backend() == 'onnx': 104 | from onnx import onnx_pb 105 | bool = onnx_pb.TensorProto.BOOL 106 | int8 = onnx_pb.TensorProto.INT8 107 | byte = onnx_pb.TensorProto.INT8 108 | int16 = onnx_pb.TensorProto.INT16 109 | short = onnx_pb.TensorProto.INT16 110 | int32 = onnx_pb.TensorProto.INT32 111 | intc = onnx_pb.TensorProto.INT32 112 | int64 = onnx_pb.TensorProto.INT64 113 | intp = onnx_pb.TensorProto.INT64 114 | 115 | uint8 = onnx_pb.TensorProto.UINT8 116 | ubyte = onnx_pb.TensorProto.UINT8 117 | float16 = onnx_pb.TensorProto.FLOAT16 118 | half = onnx_pb.TensorProto.FLOAT16 119 | float32 = onnx_pb.TensorProto.FLOAT 120 | single = onnx_pb.TensorProto.FLOAT 121 | float64 = onnx_pb.TensorProto.DOUBLE 122 | double = onnx_pb.TensorProto.DOUBLE 123 | long = onnx_pb.TensorProto.INT64 124 | float = onnx_pb.TensorProto.FLOAT 125 | complex64 = None 126 | complex128 = None 127 | cfloat = None 128 | else: 129 | bool = np.bool_ 130 | 131 | int8 = np.int8 132 | byte = np.int8 133 | int16 = np.int16 134 | short = np.int16 135 | int32
= np.int32 136 | intc = np.int32 137 | int64 = np.int64 138 | intp = np.int64 139 | 140 | uint8 = np.uint8 141 | ubyte = np.uint8 142 | float16 = np.float16 143 | half = np.float16 144 | float32 = np.float32 145 | single = np.float32 146 | float64 = np.float64 147 | double = np.float64 148 | long = np.int64 149 | float = np.float32 150 | complex64 = np.complex64 151 | complex128 = np.complex128 152 | cfloat = np.complex64 153 | 154 | -------------------------------------------------------------------------------- /trident/data/image_reader.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | try: 6 | import Queue 7 | except ImportError: 8 | import queue as Queue 9 | import threading 10 | import time 11 | 12 | from trident.misc.ipython_utils import is_in_ipython, is_in_colab 13 | 14 | if is_in_ipython(): 15 | from IPython import display 16 | 17 | if not is_in_colab(): 18 | import matplotlib 19 | matplotlib.use('NbAgg')  # notebook backend when running inside IPython 20 | else: 21 | import matplotlib 22 | import matplotlib.pyplot as plt 23 | import itertools 24 | from trident.data.image_common import list_pictures 25 | 26 | 27 | 28 | class ImageThread(threading.Thread): 29 | """Image Thread""" 30 | def __init__(self, queue, out_queue): 31 | threading.Thread.__init__(self) 32 | self.queue = queue 33 | self.out_queue = out_queue 34 | 35 | def run(self): 36 | while True: 37 | # Grabs image path from queue 38 | image_path_group = self.queue.get() 39 | # Grab image 40 | image_group = [plt.imread(i) for i in image_path_group] 41 | # Place image in out queue 42 | self.out_queue.put(image_group) 43 | # Signals to queue job is done 44 | self.queue.task_done() 45 | 46 | class ImageReader(object): 47 | r"""Base class for all Samplers. 
48 | 49 | Every Sampler subclass has to provide an __iter__ method, providing a way 50 | to iterate over indices of dataset elements, and a __len__ method that 51 | returns the length of the returned iterators. 52 | """ 53 | 54 | def __init__(self,images=None): 55 | self.image_paths =None 56 | if images is not None : 57 | if hasattr(images,'__iter__') and all(isinstance(img, str) for img in images): 58 | self.image_paths=images 59 | else: 60 | raise TypeError('images must be a list of one or more path strings.') 61 | 62 | self.workers=2 63 | self.itr = 0 64 | self.statistics=[] 65 | self.buffer_size = 5 66 | self._minibatch_size = 32 67 | self.input_qsize = 50 68 | self.min_input_qsize = 10 69 | self.n_minibatches_to_run = float('inf') 70 | self.queue = Queue.Queue() 71 | self.out_queue = Queue.Queue(maxsize=self.buffer_size) 72 | 73 | self.prepare_queue() 74 | 75 | def prepare_queue(self): 76 | if self.image_paths is not None and len(self.image_paths)>0: 77 | self.grouped_image_paths = zip(*[iter(self.image_paths[:len(self.image_paths) - (len(self.image_paths) % self._minibatch_size)])] * self._minibatch_size) 78 | self.grouped_image_paths = itertools.cycle(self.grouped_image_paths) 79 | 80 | self.threadPool=[] 81 | for i in range(self.workers): 82 | t = ImageThread(self.queue, self.out_queue) 83 | t.daemon = True 84 | t.start() 85 | self.threadPool.append(t) 86 | for _ in range(self.input_qsize): 87 | image_path_group = next(self.grouped_image_paths) 88 | self.queue.put(image_path_group) 89 | 90 | @property 91 | def minibatch_size(self): 92 | return self._minibatch_size 93 | 94 | @minibatch_size.setter 95 | def minibatch_size(self, minibatch_size): 96 | if (isinstance(minibatch_size, str)): 97 | self._minibatch_size = int(minibatch_size) 98 | elif (isinstance(minibatch_size, int)): 99 | self._minibatch_size = minibatch_size 100 | self.grouped_image_paths = zip(*[iter(self.image_paths[:len(self.image_paths) - (len(self.image_paths) % self._minibatch_size)])] * self._minibatch_size) 101 | 
self.grouped_image_paths = itertools.cycle(self.grouped_image_paths) 102 | 103 | def get_all_images(self,base_folder): 104 | self.image_paths=list_pictures(base_folder) 105 | self.prepare_queue() 106 | 107 | def __iter__(self): 108 | while self.itr <= self.n_minibatches_to_run: 109 | start = time.time() 110 | image_group = self.out_queue.get() 111 | stop = time.time() 112 | self.statistics.append(stop - start) 113 | self.itr += 1 114 | if self.queue.qsize() <= self.min_input_qsize: 115 | for _ in range(self.input_qsize): 116 | image_path_group = next(self.grouped_image_paths) 117 | self.queue.put(image_path_group) 118 | yield image_group 119 | 120 | 121 | 122 | 123 | def __len__(self): 124 | return len(self.image_paths) - (len(self.image_paths) % self._minibatch_size) 125 | 126 | 127 | 128 | 129 | 130 | -------------------------------------------------------------------------------- /trident/optims/pytorch_constraints.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from trident.backend.pytorch_backend import Layer 10 | 11 | from trident.backend.common import get_session, epsilon 12 | from trident.backend.pytorch_ops import * 13 | 14 | __all__ = ['max_norm', 'non_neg_norm', 'unit_norm', 'min_max_norm', 'maxnorm', 'nonnegnorm', 'unitnorm', 'minmaxnorm', 'get_constraint'] 15 | 16 | _session=get_session() 17 | _epsilon=_session.epsilon 18 | @torch.no_grad() 19 | def max_norm(model, max_value=2,axis=0): 20 | """ 21 | MaxNorm weight constraint. 22 | Constrains the weights incident to each hidden unit to have a norm less than or equal to a desired value. 23 | Args: 24 | model : the model whose weights the constraint is applied to. 
25 | max_value (float): the maximum norm value for the incoming weights 26 | axis (int): axis along which to calculate weight norms. 27 | 28 | """ 29 | with torch.no_grad(): 30 | for name, param in model.named_parameters(): 31 | if 'bias' not in name and param is not None and param.requires_grad: 32 | norm = param.data.norm(2, dim=axis, keepdim=True) 33 | desired = torch.clamp(norm, 0, max_value) 34 | param.data.copy_(param.data * (desired / (epsilon() + norm))) 35 | 36 | @torch.no_grad() 37 | def non_neg_norm(model): 38 | """ 39 | Constrains the weights to be non-negative. 40 | Args: 41 | model : the model whose weights the constraint is applied to. 42 | 43 | """ 44 | with torch.no_grad(): 45 | for name, param in model.named_parameters(): 46 | if 'bias' not in name and param is not None and param.requires_grad: 47 | param.data.copy_(clip(param.data, 0, np.inf)) 48 | 49 | @torch.no_grad() 50 | def unit_norm(model,axis=0): 51 | """ 52 | Constrains the weights incident to each hidden unit to have unit norm. 53 | Args: 54 | axis (int): axis along which to calculate weight norms. 55 | model : the model whose weights the constraint is applied to. 56 | 57 | """ 58 | with torch.no_grad(): 59 | if isinstance(model,Layer): 60 | for name, param in model.named_parameters(): 61 | if 'bias' not in name and param is not None and param.requires_grad: 62 | norm = param.data.norm(2, dim=axis, keepdim=True) 63 | param.data.copy_(param.data / (epsilon() + norm)) 64 | elif is_tensor(model): 65 | if model is not None and model.requires_grad: 66 | norm = model.data.norm(2, dim=axis, keepdim=True) 67 | model.data.copy_(model.data / (epsilon() + norm)) 68 | 69 | @torch.no_grad() 70 | def min_max_norm(model,min_value=0, max_value=1, rate=1.0, axis=0): 71 | """ 72 | MinMaxNorm weight constraint. 73 | Constrains the weights incident to each hidden unit to have the norm between a lower bound and an upper bound. 
74 | 75 | Args: 76 | model : the model whose weights the constraint is applied to. 77 | min_value (float): the minimum norm for the incoming weights. 78 | max_value (float): the maximum norm for the incoming weights. 79 | rate (float): rate for enforcing the constraint: weights will be rescaled to yield rate * norm.clip(min_value, max_value) + (1 - rate) * norm. Effectively, rate=1.0 stands for strict enforcement of the constraint, while rate<1.0 means that weights will be rescaled at each step to slowly move towards a value inside the desired interval. 80 | axis (int): axis along which to calculate weight norms 81 | 82 | 83 | """ 84 | with torch.no_grad(): 85 | for name, param in model.named_parameters(): 86 | if 'bias' not in name and param is not None and param.requires_grad: 87 | norm = param.data.norm(2, dim=axis, keepdim=True) 88 | desired = rate * clip(norm, min_value, max_value) + (1 - rate) * norm 89 | param.data.copy_(param.data * (desired / (epsilon() + norm))) 90 | 91 | 92 | 93 | 94 | maxnorm = max_norm 95 | nonnegnorm = non_neg_norm 96 | unitnorm = unit_norm 97 | minmaxnorm=min_max_norm 98 | 99 | 100 | def get_constraint(constraint): 101 | if constraint in ['maxnorm','max_norm']: 102 | return max_norm 103 | elif constraint in ['non_neg','nonneg']: 104 | return non_neg_norm 105 | elif constraint in ['unit_norm','unitnorm']: 106 | return unit_norm 107 | elif constraint in ['min_max_norm', 'minmaxnorm']: 108 | return min_max_norm 109 | else: 110 | return None -------------------------------------------------------------------------------- /trident/optims/pytorch_profiling.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | from torch.autograd import Function 5 | from torch.autograd import Variable 6 | 7 | """ 8 | https://raw.githubusercontent.com/zhuwenxi/pytorch-profiling-tool/master/profiling.py 9 | """ 10 | 11 | class Profiling(object): 12 | def 
__init__(self, model): 13 | if not isinstance(model, torch.nn.Module): 14 | raise TypeError("Not a valid model, please provide a 'nn.Module' instance.") 15 | 16 | self.model = model 17 | self.record = {'forward': [], 'backward': []} 18 | self.profiling_on = True 19 | self.origin_call = {} 20 | self.hook_done = False 21 | self.layer_num = 0 22 | 23 | def __enter__(self): 24 | self.start() 25 | 26 | return self 27 | 28 | def __exit__(self, *args): 29 | self.stop() 30 | 31 | def __str__(self): 32 | ret = "" 33 | 34 | n_iter = len(self.record['forward']) // self.layer_num 35 | 36 | for i in range(n_iter): 37 | ret += "\n================================= Iteration {} =================================\n".format(i + 1) 38 | 39 | ret += "\nFORWARD TIME:\n\n" 40 | for j in range(self.layer_num): 41 | record_item = self.record['forward'][i * self.layer_num + j] 42 | ret += "layer{:3d}: {:.6f} ms ({})\n".format(j + 1, (record_item[2] - record_item[1]) * 1000, 43 | record_item[0]) 44 | 45 | ret += "\nBACKWARD TIME:\n\n" 46 | for j in range(self.layer_num): 47 | record_item = self.record['backward'][i * self.layer_num + self.layer_num - j - 1] 48 | try: 49 | ret += "layer{:3d}: {:.6f} ms ({})\n".format(j + 1, 50 | (record_item[2] - record_item[1]) * 1000, 51 | record_item[0]) 52 | except IndexError: 53 | # This layer didn't run its backward post-hook, so no stop time was recorded. 54 | pass 55 | 56 | return ret 57 | 58 | def start(self): 59 | if self.hook_done is False: 60 | self.hook_done = True 61 | self.hook_modules(self.model) 62 | 63 | self.profiling_on = True 64 | 65 | return self 66 | 67 | def stop(self): 68 | self.profiling_on = False 69 | 70 | return self 71 | 72 | def hook_modules(self, module): 73 | 74 | this_profiler = self 75 | 76 | sub_modules = module.__dict__['_modules'] 77 | 78 | for name, sub_module in sub_modules.items(): 79 | # nn.Module is the only thing we care about.
80 | if sub_module is None or not isinstance(sub_module, torch.nn.Module): 81 | continue 82 | 83 | if len(sub_module._modules) > 0: 84 | # 85 | # torch.nn.Container was removed from PyTorch; treat any module with sub nn.Modules as a container. Recursively visit and hook their descendants. 86 | # 87 | self.hook_modules(sub_module) 88 | else: 89 | 90 | self.layer_num += 1 91 | 92 | # 93 | # A leaf nn.Module without sub nn.Modules: hook it. 94 | # 95 | 96 | # Wrapper function for "__call__", with a time counter in it. 97 | def wrapper_call(self, *input, **kwargs): 98 | start_time = time.time() 99 | result = this_profiler.origin_call[self.__class__](self, *input, **kwargs) 100 | stop_time = time.time() 101 | 102 | that = self 103 | 104 | def backward_pre_hook(*args): 105 | if this_profiler.profiling_on: 106 | this_profiler.record['backward'].append((that, time.time())) 107 | 108 | if result.grad_fn is not None: result.grad_fn.register_prehook(backward_pre_hook) 109 | 110 | if this_profiler.profiling_on: 111 | this_profiler.record['forward'].append((self, start_time, stop_time)) 112 | 113 | 114 | return result 115 | 116 | # Replace "__call__" with "wrapper_call". 117 | if sub_module.__class__ not in this_profiler.origin_call: 118 | this_profiler.origin_call.update({sub_module.__class__: sub_module.__class__.__call__}) 119 | sub_module.__class__.__call__ = wrapper_call 120 | 121 | def backward_post_hook(*args): 122 | if this_profiler.profiling_on: 123 | this_profiler.record['backward'][-1] = ( 124 | this_profiler.record['backward'][-1][0], this_profiler.record['backward'][-1][1], time.time()) 125 | 126 | sub_module.register_full_backward_hook(backward_post_hook) -------------------------------------------------------------------------------- /sphinx_docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder.
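The Profiling class above times each layer by swapping the module class's `__call__` for a wrapper that records start/stop timestamps. The same technique can be sketched framework-free; here the `TimedCall` and `Square` names are illustrative stand-ins, not part of trident or PyTorch:

```python
import time


class TimedCall:
    """Context manager that records per-call wall time by temporarily
    replacing a class's __call__, mirroring the hook trick in Profiling."""

    def __init__(self, cls):
        self.cls = cls
        self.records = []        # list of (class name, elapsed seconds)
        self._orig_call = None

    def __enter__(self):
        orig = self.cls.__call__
        self._orig_call = orig
        recorder = self

        def wrapper_call(obj, *args, **kwargs):
            start = time.perf_counter()
            result = orig(obj, *args, **kwargs)
            recorder.records.append((type(obj).__name__, time.perf_counter() - start))
            return result

        self.cls.__call__ = wrapper_call
        return self

    def __exit__(self, *exc):
        # Restore the original __call__ so later calls are not timed.
        self.cls.__call__ = self._orig_call


class Square:
    def __call__(self, x):
        return x * x


with TimedCall(Square) as prof:
    y = Square()(3)

print(y, len(prof.records))   # prints: 9 1
```

Restoring `__call__` on exit is worth noting: without it, every instance of the patched class stays wrapped after profiling ends, which is a caveat of the patching approach used by the tool above.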
2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | import trident 16 | import re 17 | #sys.path.insert(0, os.path.abspath('D:/PycharmProjects/trident')) 18 | 19 | # -- Project information ----------------------------------------------------- 20 | 21 | project = 'trident' 22 | copyright = '2022, AllanYiin' 23 | author = 'AllanYiin' 24 | 25 | # The full version, including alpha/beta/rc tags 26 | release = '0.7.5' 27 | 28 | 29 | # -- General configuration --------------------------------------------------- 30 | 31 | # Add any Sphinx extension module names here, as strings. They can be 32 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 33 | # ones. 34 | extensions = [ 35 | 'sphinx.ext.autodoc', 36 | 'sphinx.ext.viewcode', 37 | 'sphinx.ext.extlinks', 38 | 'sphinx.ext.napoleon', 39 | 'nbsphinx', 40 | 'recommonmark', 41 | 'IPython.sphinxext.ipython_console_highlighting' 42 | ] 43 | 44 | # Suppress warnings 45 | suppress_warnings = ['image.nonlocal_uri'] 46 | # Define source suffix 47 | source_suffix = ['.rst', '.ipynb','.md'] 48 | 49 | master_doc = 'index' 50 | # Add any paths that contain templates here, relative to this directory. 51 | templates_path = ['_templates'] 52 | 53 | # List of patterns, relative to source directory, that match files and 54 | # directories to ignore when looking for source files. 55 | # This pattern also affects html_static_path and html_extra_path. 
56 | exclude_patterns = [] 57 | 58 | autodoc_mock_imports = [ 59 | 'tensorflow', 'torch', 'cntk', 'cv2', 'numpy' 60 | ] 61 | 62 | 63 | # sphinx.ext.napoleon options 64 | napoleon_google_docstring = True 65 | napoleon_numpy_docstring = False 66 | 67 | add_module_names = False 68 | 69 | # Linkcheck builder options 70 | 71 | 72 | source_prefix = 'https://github.com/AllanYiin/trident' 73 | 74 | 75 | 76 | # sphinx.ext.extlinks options 77 | extlinks = { 78 | 'tridenttw': (source_prefix + '/README.md', ''), 79 | 'tridentcn': (source_prefix + '/README.md', ''), 80 | 'tridenten': (source_prefix + '/README.md', '') 81 | } 82 | 83 | # -- Options for HTML output ------------------------------------------------- 84 | 85 | # The theme to use for HTML and HTML Help pages. See the documentation for 86 | # a list of builtin themes. 87 | 88 | html_theme = 'sphinx_materialdesign_theme' 89 | 90 | # Html logo in drawer. 91 | # Fits in the drawer when the image width is 240 px. 92 | html_logo = '_static/trident_logo.png' 93 | 94 | html_theme_options = { 95 | # Specify a list of menus in the header. 96 | # Tuple form: 97 | # ('Name', 'external url or path of pages in the document', boolean, 'icon name') 98 | # 99 | # Third argument: 100 | # True indicates an external link. 101 | # False indicates a path to pages in the document. 102 | # 103 | # Fourth argument: 104 | # Specify the icon name. 105 | # For details see link. 106 | # https://material.io/icons/ 107 | 'header_links': [ 108 | ('Home', 'index', False, 'home'), 109 | ("ExternalLink", "http://example.com", True, 'launch'), 110 | ("NoIconLink", "http://example.com", True, ''), 111 | ("GitHub", "https://github.com/AllanYiin/trident", True, 'link') 112 | ], 113 | 114 | # Customize css colors. 115 | # For details see link.
116 | # https://getmdl.io/customize/index.html 117 | # 118 | # Values: amber, blue, brown, cyan, deep_orange, deep_purple, green, grey, indigo, light_blue, 119 | # light_green, lime, orange, pink, purple, red, teal, yellow (Default: indigo) 120 | 'primary_color': 'deep_orange', 121 | # Values: Same as primary_color. (Default: pink) 122 | 'accent_color': 'indigo', 123 | 124 | # Customize layout. 125 | # For details see link. 126 | # https://getmdl.io/components/index.html#layout-section 127 | 'fixed_drawer': True, 128 | 'fixed_header': True, 129 | 'header_waterfall': True, 130 | 'header_scroll': False, 131 | 132 | # Render title in header. 133 | # Values: True, False (Default: False) 134 | 'show_header_title': True, 135 | # Render title in drawer. 136 | # Values: True, False (Default: True) 137 | 'show_drawer_title': True, 138 | # Render footer. 139 | # Values: True, False (Default: True) 140 | 'show_footer': True 141 | } 142 | 143 | 144 | # Add any paths that contain custom static files (such as style sheets) here, 145 | # relative to this directory. They are copied after the builtin static files, 146 | # so a file named "default.css" will overwrite the builtin "default.css". 147 | html_static_path = ['_static'] -------------------------------------------------------------------------------- /trident/backend/decorators.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | import time 3 | import functools 4 | 5 | # Global variable to control the state of the performance timing 6 | PERFORMANCE_TIMING_ENABLED = False 7 | 8 | def measure_perf(func): 9 | """Decorator to measure function execution time.""" 10 | @functools.wraps(func) 11 | def wrapper(*args, **kwargs): 12 | if PERFORMANCE_TIMING_ENABLED: 13 | start_time = time.time() 14 | result = func(*args, **kwargs) 15 | end_time = time.time() 16 | # Check whether this is a plain function or a class method 17 | if '.'
in func.__qualname__: 18 | # Get the class name and the method name 19 | class_name, method_name = func.__qualname__.rsplit('.', 1) 20 | # Check whether the instance has a 'training_name' attribute 21 | if hasattr(args[0], 'training_name'): 22 | class_info = f"{class_name}({getattr(args[0], 'training_name')})" 23 | else: 24 | class_info = class_name 25 | name = f"{class_info}.{method_name}" 26 | else: 27 | name = func.__name__ 28 | print(f"Execution time of {name}: {end_time - start_time} seconds") 29 | return result 30 | else: 31 | # Execute the function normally without timing 32 | return func(*args, **kwargs) 33 | return wrapper 34 | 35 | class PerformanceTimerContext: 36 | """Context manager to enable or disable performance timing.""" 37 | def __enter__(self): 38 | global PERFORMANCE_TIMING_ENABLED 39 | PERFORMANCE_TIMING_ENABLED = True 40 | 41 | def __exit__(self, exc_type, exc_val, exc_tb): 42 | global PERFORMANCE_TIMING_ENABLED 43 | PERFORMANCE_TIMING_ENABLED = False 44 | 45 | 46 | def trident_export(*api_names, **kwargs): 47 | """A decorator function for exporting Trident APIs. 48 | 49 | Args: 50 | *api_names: Variable length argument list of API names to be exported. 51 | 52 | Returns: 53 | The decorated function or class. 54 | 55 | Example usage: 56 | @trident_export("api_name1", "api_name2") 57 | def my_function(): 58 | pass""" 59 | 60 | def Decorator(func_or_class): 61 | """Decorator function. 62 | 63 | Args: 64 | func_or_class: The function or class to be decorated. 65 | 66 | Returns: 67 | The decorated function or class.""" 68 | func_or_class._TRIDENT_API = api_names 69 | return func_or_class 70 | 71 | return Decorator 72 | 73 | 74 | def deprecated(version, substitute): 75 | """Deprecation warning. 76 | Args: 77 | version (str): the version in which the operator or function was deprecated. 78 | substitute (str): the substitute name for the deprecated operator or function. 79 | """ 80 | 81 | def decorate(func): 82 | """Decorator function to print a warning message for deprecated functions.
83 | 84 | Args: 85 | func: The function to be decorated. 86 | 87 | Returns: 88 | The decorated function.""" 89 | 90 | def wrapper(*args, **kwargs): 91 | """A wrapper function to deprecate a function. 92 | 93 | Args: 94 | *args: Variable length argument list. 95 | **kwargs: Arbitrary keyword arguments. 96 | 97 | Returns: 98 | The return value of the wrapped function.""" 99 | cls = getattr(args[0], "__class__", None) if args else None 100 | name = cls.__name__ if cls else func.__name__ 101 | print( 102 | f"WARNING: '{func.__name__}' is deprecated from version {version} and will be removed in a future version, " 103 | f"use '{substitute}' instead.") 104 | ret = func(*args, **kwargs) 105 | return ret 106 | 107 | return wrapper 108 | 109 | return decorate 110 | 111 | 112 | def compact(fun: Callable) -> Callable: 113 | """Marks the given module method allowing inlined submodules. 114 | Methods wrapped in @compact can define submodules directly within the method. 115 | For instance:: 116 | @compact 117 | def __call__(self, x, features): 118 | x = nn.Dense(features)(x) 119 | ... 120 | At most one method in each Module may be wrapped with @compact. 121 | Args: 122 | fun: The Module method to mark as compact. 123 | Returns: 124 | The given function `fun` marked as compact. 125 | """ 126 | fun.compact = True # type: ignore[attr-defined] 127 | return fun 128 | 129 | 130 | def signature(fun: Callable) -> Callable: 131 | """Alias of :func:`compact`, kept for backward compatibility. 132 | Like :func:`compact`, it marks the given module method so that submodules may be 133 | defined inline within it. For instance:: 134 | @compact 135 | def __call__(self, x, features): 136 | x = nn.Dense(features)(x) 137 | ... 138 | At most one method in each Module may be wrapped with @compact. 139 | Args: 140 | fun: The Module method to mark as compact. 141 | Returns: 142 | The given function `fun` marked as compact.
143 | """ 144 | fun.compact = True # type: ignore[attr-defined] 145 | return fun 146 | -------------------------------------------------------------------------------- /trident/models/pretrained_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | import copy 5 | import inspect 6 | import json 7 | import numbers 8 | import os 9 | import shutil 10 | import sys 11 | import time 12 | import uuid 13 | from functools import partial 14 | import builtins 15 | 16 | import numpy as np 17 | from trident.callbacks.lr_schedulers import AdjustLRCallback 18 | 19 | from trident.backend import iteration_tools 20 | from trident.data.dataset import ZipDataset 21 | from trident.backend.common import get_backend, to_list, addindent, get_time_suffix, format_time, get_terminal_size, get_session, \ 22 | snake2camel, PrintException, unpack_singleton, enforce_singleton, OrderedDict, Signature, TensorShape,is_instance 23 | from trident.backend.model import ModelBase, progress_bar 24 | from trident.callbacks.visualization_callbacks import * 25 | from trident.data.data_provider import * 26 | from trident.misc.ipython_utils import * 27 | from trident.misc.visualization_utils import tile_rgb_images, loss_metric_curve 28 | from trident.backend.tensorspec import TensorSpec, assert_spec_compatibility, ObjectType 29 | from trident.loggers.history import HistoryBase 30 | from trident.context import split_path, make_dir_if_need, sanitize_path 31 | 32 | 33 | _session = get_session() 34 | 35 | working_directory=_session.working_directory 36 | 37 | 38 | if get_backend() == 'pytorch': 39 | import torch 40 | import torch.nn as nn 41 | from trident.backend.pytorch_backend import * 42 | from trident.backend.pytorch_ops import * 43 | from trident.optims.pytorch_optimizers import * 44 | from trident.layers.pytorch_layers import * 45 | elif 
get_backend() == 'tensorflow': 46 | import tensorflow as tf 47 | from trident.backend.tensorflow_backend import * 48 | from trident.backend.tensorflow_ops import * 49 | from trident.optims.tensorflow_optimizers import * 50 | from trident.layers.tensorflow_layers import * 51 | 52 | 53 | def _make_recovery_model_include_top(recovery_model:Layer,default_shape=None,input_shape=None, include_top=True, classes=1000, freeze_features=True): 54 | size_change=False 55 | if default_shape is None: 56 | if recovery_model.built: 57 | default_shape = tuple(recovery_model._input_shape.dims[1:] if isinstance(recovery_model._input_shape, TensorShape) else recovery_model._input_shape) 58 | else: 59 | default_shape = (3, 224, 224) if get_backend() == 'pytorch' else (224, 224, 3) 60 | if input_shape is not None and input_shape !=default_shape: 61 | size_change=True 62 | 63 | if freeze_features: 64 | recovery_model.trainable = False 65 | idx = -1 66 | is_last_dense=True 67 | if isinstance(recovery_model,Sequential): 68 | while (len(recovery_model[idx]._parameters) == 0 or isinstance(recovery_model[idx], Dense)) and len(recovery_model[idx].output_shape) >= 2: 69 | layer = recovery_model[idx] 70 | if layer.output_shape.rank > 2: 71 | break 72 | elif len(recovery_model[idx]._parameters) > 0: 73 | if not include_top: 74 | recovery_model.remove_at(idx) 75 | idx+=1 76 | elif size_change or (is_last_dense and classes != 1000 and isinstance(recovery_model[idx], Dense)): 77 | if hasattr(recovery_model[idx],'num_filters') and recovery_model[idx].num_filters!=classes: 78 | recovery_model[idx].num_filters=classes 79 | recovery_model[idx]._built=False 80 | recovery_model[idx]._parameters.clear() 81 | 82 | else: 83 | recovery_model[idx].trainable = True 84 | else: 85 | if not include_top: 86 | recovery_model.remove_at(idx) 87 | idx+=1 88 | idx -= 1 89 | elif is_instance(recovery_model, 'VisionTransformer'): 90 | recovery_model.trainable=False 91 | if hasattr(recovery_model,'head'): 92 | 
recovery_model.head.trainable = True 93 | if hasattr(recovery_model, 'fc'): 94 | recovery_model.fc.trainable = True 95 | if hasattr(recovery_model, 'norm'): 96 | recovery_model.norm.trainable = True 97 | 98 | 99 | 100 | 101 | 102 | dims = list(default_shape) 103 | dims.insert(0, None) 104 | new_tensorshape = TensorShape(dims) 105 | if size_change: 106 | dims = list(input_shape) 107 | dims.insert(0, None) 108 | new_tensorshape = TensorShape(dims) 109 | for module in recovery_model.modules(): 110 | module._input_shape = None 111 | module._output_shape = None 112 | 113 | recovery_model.to(get_device()) 114 | dummy_input = to_tensor(new_tensorshape.get_dummy_tensor()).to(recovery_model.device) 115 | 116 | out = recovery_model(dummy_input) 117 | if isinstance(recovery_model.signature, Signature): 118 | recovery_model.signature.inputs.value_list[0].shape = TensorShape(dims) 119 | recovery_model.signature.inputs.value_list[0].object_type = ObjectType.rgb 120 | 121 | return recovery_model 122 | 123 | 124 | -------------------------------------------------------------------------------- /trident/models/README.md: -------------------------------------------------------------------------------- 1 | ## Pretrained Models 預訓練模型 2 | 3 | 4 | 5 | 6 | 7 | ### Image Classification 圖像識別 8 | 9 | | | PyTorch | PyTorch | Tensorflow | Tensorflow | 10 | |:--------------:|:-----------------------------------------------------------------------------------------:|:-------------------:| :----------: | :--------: | 11 | | | google drive | onedrive | google drive | onedrive | 12 | | ResNet18 | [resnet18.pth](https://drive.google.com/open?id=156C4a0_nts8QbjCE8YWbA-QbvCnTrfb5) | resnet18.pth | - | - | 13 | | ResNet50 | [resnet50.pth](https://drive.google.com/open?id=1dYlgpFtqi87KDG54_db4ALWKLARxCWMS) | resnet50.pth |[resnet50_tf.pth](https://drive.google.com/open?id=1vReSW_l8fldyYQ6ay5HCYFGoMaGbdW2T)|resnet50_tf.pth| 14 | | ResNet101 |
[resnet101.pth](https://drive.google.com/open?id=17moUOsGynsWALLHyv3yprHWbbDMrdiOP) | resnet101.pth |[resnet101_tf.pth](https://drive.google.com/open?id=13QYdFX3CvsNiegi-iUX1PUC0KKKgPNwr)|resnet101_tf.pth| 15 | | ResNet152 | [resnet152.pth](https://drive.google.com/open?id=1BIaHb7_qunUVvt4TDAwonSKI2jYg4Ybj) | resnet152.pth |[resnet152_tf.pth](https://drive.google.com/open?id=1TeVBB5ynW9E4_EgxIdjugLT8oaXnQH_c)|resnet152_tf.pth| 16 | | EfficientNetB0 | [efficientnet-b0.pth](https://drive.google.com/open?id=1bxnoDerzoNfiZZLft4ocD3DAgx4v6aTN) | efficientnet-b0.pth |[efficientnet-b0_tf.pth](https://drive.google.com/open?id=1pO4wRWY6N4e7U_7E2H-NhBPEF4MlR4ru)|efficientnet-b0_tf.pth| 17 | | EfficientNetB1 | [efficientnet-b1.pth](https://drive.google.com/open?id=1F3BtnAjmDz4G9RS9Q0hqU_K7WWXCni1G) | efficientnet-b1.pth |[efficientnet-b1_tf.pth](https://drive.google.com/open?id=1zCWDn4lwHCn4exAnGfBSPh9YHYTGdIYt)|efficientnet-b1_tf.pth| 18 | | EfficientNetB2 | [efficientnet-b2.pth](https://drive.google.com/open?id=1PjqhB7WJasF_hqOwYtSBNSXSGBY-cRLU) | efficientnet-b2.pth |efficientnet-b2_tf.pth|efficientnet-b2_tf.pth| 19 | | EfficientNetB3 | [efficientnet-b3.pth](https://drive.google.com/open?id=11tMxdYdFfaEREwnESO4cwjtcoEB42zB_) | efficientnet-b3.pth |efficientnet-b3_tf.pth|efficientnet-b3_tf.pth| 20 | | EfficientNetB4 | [efficientnet-b4.pth](https://drive.google.com/open?id=1X4ZOBR_ETRHZJeffJHvCmWTTy9_aW8SP) | efficientnet-b4.pth |efficientnet-b4_tf.pth|efficientnet-b4_tf.pth| 21 | | EfficientNetB5 | [efficientnet-b5.pth](https://drive.google.com/open?id=17iTD12G9oW3jYAui84MKtdY4gjd9vpgG) | efficientnet-b5.pth |efficientnet-b5_tf.pth|efficientnet-b5_tf.pth| 22 | | EfficientNetB6 | [efficientnet-b6.pth](https://drive.google.com/open?id=1XJrKmcmMObN_nnjP2Z-YH_BQ3img58qF) | efficientnet-b6.pth |efficientnet-b6_tf.pth|efficientnet-b6_tf.pth| 23 | | EfficientNetB7 | [efficientnet-b7.pth](https://drive.google.com/open?id=1M2DfvsNPRCWSo_CeXnUCQOR46rvOrhLl) | efficientnet-b7.pth 
|efficientnet-b7_tf.pth|efficientnet-b7_tf.pth| 24 | | SE_ResNet50 | [se_resnet50.pth](https://drive.google.com/open?id=1vq0uueiHXuHSEFhb02GoEuPzLwrDW_Mb) | se_resnet50.pth |se_resnet50_tf.pth|se_resnet50_tf.pth| 25 | | SE_ResNet101 | [se_resnet101.pth](https://drive.google.com/open?id=17moUOsGynsWALLHyv3yprHWbbDMrdiOP) | se_resnet101.pth |-|-| 26 | | SE_ResNet152 | [se_resnet152.pth](https://drive.google.com/open?id=1L9eGvwVOcH40_lCgCadZQtCrBdeVrsYi) | se_resnet152.pth |-|-| 27 | | VGG11 | [vgg11.pth](https://drive.google.com/open?id=1PV9-AwgD1v-JxDRzduOjjGduIR7MDhPW) | vgg11.pth |-|-| 28 | | VGG13 | [vgg13.pth](https://drive.google.com/open?id=1wx67gmQ8eHWXs2mhJmNl-t-cFNw7dJ7O) | vgg13.pth |-|-| 29 | | VGG16 | [vgg16.pth](https://drive.google.com/open?id=1uXiH5MSy1rvxrHjW4uB9E2BHMM8b0Fwr) | vgg16.pth |vgg16_tf.pth|vgg16_tf.pth| 30 | | VGG19 | [vgg19.pth](https://drive.google.com/open?id=1nqQJLYMzeiUX9hji39-rrBUG42YyjhYg) | vgg19.pth |vgg19_tf.pth|vgg19_tf.pth| 31 | | DenseNet121 | [densenet121.pth](https://drive.google.com/open?id=16N2BECErDMRTV5JqESEBWyylXbQmKAIk) | densenet121.pth |densenet121_tf.pth|densenet121_tf.pth| 32 | | DenseNet161 | [densenet161.pth](https://drive.google.com/open?id=1n3HRkdPbxKrLVua9gOCY6iJnzM8JnBau) | densenet161.pth |densenet161_tf.pth|densenet161_tf.pth| 33 | | DenseNet169 | [densenet169.pth](https://drive.google.com/open?id=1QV73Th0Wo4SCq9AFPVEKqnzs7BUvIG5B) | densenet169.pth |densenet169_tf.pth|densenet169_tf.pth| 34 | | DenseNet201 | [densenet201.pth](https://drive.google.com/open?id=1V2JazzdnrU64lDfE-O4bVIgFNQJ38q3J) | densenet201.pth |densenet201_tf.pth|densenet201_tf.pth| 35 | | MobileNetV2 | [mobilenet_v2.pth](https://drive.google.com/open?id=1ULenXTjOO5PdT3fHv6N8bPXEfoJAn5yL) | mobilenet_v2.pth |mobilenet_v2_tf.pth|mobilenet_v2_tf.pth| 36 | 37 | 38 | 39 | ### Object Detection 目標檢測 40 | 41 | | | PyTorch | PyTorch | Tensorflow | Tensorflow | 42 | | :--------: | :----------: | :------: | :----------: | :--------: | 43 | | | 
google drive | onedrive | google drive | onedrive | 44 | | MTCNN-PNet | pnet.pth | pnet.pth | - | - | 45 | | MTCNN-RNet | rnet.pth | rnet.pth | | | 46 | | MTCNN-ONet | onet.pth | onet.pth | | | 47 | | | | | | | 48 | 49 | -------------------------------------------------------------------------------- /trident/backend/load_backend.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | import json 5 | import os 6 | import random 7 | from sys import stderr,stdout 8 | from trident.backend.common import * 9 | from trident.backend.decorators import * 10 | 11 | __all__ = [] 12 | 13 | _session=get_session() 14 | _trident_dir=get_trident_dir() 15 | _config_path = os.path.expanduser(os.path.join(_trident_dir, 'trident.json')) 16 | 17 | 18 | def write_config(_config_path): 19 | # _config = { 20 | # 'floatx': _session.floatx, 21 | # 'epsilon': _session.epsilon, 22 | # 'backend': _session.backend , 23 | # 'image_backend':_session.image_backend 24 | # } 25 | try: 26 | with open(_config_path, 'w') as f: 27 | session=_session.__dict__.copy() 28 | session.pop('_thread_local_info') 29 | session.pop('_context_handle') 30 | session.pop('_module_dict') 31 | session.pop('print') 32 | f.write(json.dumps(session, indent=4)) 33 | except IOError: 34 | # Except permission denied. 35 | pass 36 | 37 | 38 | # Save config file, if possible. 39 | if not os.path.exists(_trident_dir): 40 | try: 41 | os.makedirs(_trident_dir) 42 | except OSError: 43 | # Except permission denied and potential race conditions 44 | # in multi-threaded environments. 45 | pass 46 | 47 | 48 | # Set backend based on TRIDENT_BACKEND flag, if applicable. 
49 | if 'TRIDENT_BACKEND' in os.environ: 50 | if _session.backend!=os.environ['TRIDENT_BACKEND']: 51 | _session.backend = os.environ['TRIDENT_BACKEND'] 52 | if 'TRIDENT_WORKING_DIR' in os.environ: 53 | _session.working_directory = os.environ['TRIDENT_WORKING_DIR'] 54 | os.chdir(os.environ['TRIDENT_WORKING_DIR']) 55 | write_config(_config_path) 56 | 57 | 58 | 59 | if _session.backend== 'pytorch': 60 | stdout.write('Using Pytorch backend.\n') 61 | stdout.write('Image Data Format: channels_first.\n') 62 | stdout.write('Image Channel Order: rgb.\n') 63 | _session.image_data_format='channels_first' 64 | _session.image_channel_order='rgb' 65 | import torch 66 | from trident.backend.pytorch_ops import * 67 | 68 | if torch.cuda.is_available(): 69 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # see issue #152 70 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 71 | set_session('device','cuda') 72 | elif is_tpu_available(): 73 | import torch_xla.core.xla_model as xm 74 | 75 | # os.environ['XLA_USE_BF16'] = '1' 76 | # os.environ['XLA_TENSOR_ALLOCATOR_MAXSIZE'] = '1000000000' 77 | set_session('device', 'tpu') 78 | set_session('print', xm.master_print) 79 | 80 | 81 | 82 | 83 | from trident.backend.pytorch_backend import * 84 | 85 | elif _session.backend == 'tensorflow': 86 | stdout.write('Using TensorFlow backend.\n') 87 | stdout.write('Image Data Format: channels_last.\n') 88 | stdout.write('Image Channel Order: rgb.\n') 89 | _session.image_data_format = 'channels_last' 90 | _session.image_channel_order = 'rgb' 91 | 92 | 93 | import tensorflow as tf 94 | 95 | gpus = tf.config.list_physical_devices('GPU') 96 | if gpus: 97 | try: 98 | tf.config.experimental.set_memory_growth(gpus[0], False) 99 | # tf.config.experimental.set_virtual_device_configuration( 100 | # gpus[0], 101 | # [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=1024 * 2)]) 102 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # see issue #152 103 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 104 | 
set_session('device', '/gpu:0') 105 | tf.config.set_visible_devices(gpus[0], 'GPU') 106 | 107 | logical_gpus = tf.config.list_logical_devices('GPU') 108 | print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs") 109 | 110 | except RuntimeError as e: 111 | # Virtual devices must be set before GPUs have been initialized 112 | print(e) 113 | 114 | else: 115 | if "CUDA_VISIBLE_DEVICES" in os.environ: 116 | os.environ.pop("CUDA_VISIBLE_DEVICES") 117 | set_session('device', '/cpu:0') 118 | 119 | from trident.backend.tensorflow_ops import * 120 | from trident.backend.tensorflow_backend import * 121 | 122 | elif _session.backend == 'jax': 123 | stdout.write('Using Jax backend.\n') 124 | stdout.write('Image Data Format: channels_last.\n') 125 | stdout.write('Image Channel Order: rgb.\n') 126 | _session.image_data_format = 'channels_last' 127 | _session.image_channel_order = 'rgb' 128 | 129 | import jax 130 | from trident.backend.jax_ops import * 131 | 132 | if is_gpu_available(): 133 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # see issue #152 134 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 135 | set_session('device', 'cuda') 136 | elif is_tpu_available(): 137 | set_session('device', 'tpu') 138 | else: 139 | set_session('device', 'cpu') 140 | 141 | from trident.backend.jax_backend import * 142 | 143 | 144 | elif _session.backend == 'onnx': 145 | stdout.write('Using ONNX backend.\n') 146 | stdout.write('Image Data Format: channels_first.\n') 147 | stdout.write('Image Channel Order: rgb.\n') 148 | _session.image_data_format = 'channels_first' 149 | _session.image_channel_order = 'rgb' 150 | 151 | 152 | 153 | from trident.backend.opencv_backend import * 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | if not os.path.exists(_config_path): 164 | write_config(_config_path) 165 | 166 | -------------------------------------------------------------------------------- /.gitignore:
-------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### Python template 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/lib 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | tools/* 18 | experiments/* 19 | examples/Results/* 20 | datav3/* 21 | internal_tool/* 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | test/ 26 | tools/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | pip-wheel-metadata/ 33 | share/python-wheels/ 34 | *.egg-info/ 35 | .installed.cfg 36 | *.egg 37 | MANIFEST 38 | 39 | # PyInstaller 40 | # Usually these files are written by a python script from a template 41 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 42 | *.manifest 43 | *.spec 44 | 45 | # Installer logs 46 | pip-log.txt 47 | pip-delete-this-directory.txt 48 | 49 | # Unit tests / coverage reports 50 | htmlcov/ 51 | .tox/ 52 | .nox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | *.py,cover 60 | .hypothesis/ 61 | .pytest_cache/ 62 | cover/ 63 | 64 | # Translations 65 | *.mo 66 | *.pot 67 | 68 | # Django stuff: 69 | *.log 70 | local_settings.py 71 | db.sqlite3 72 | db.sqlite3-journal 73 | 74 | # Flask stuff: 75 | instance/ 76 | .webassets-cache 77 | 78 | # Scrapy stuff: 79 | .scrapy 80 | 81 | # Sphinx documentation 82 | docs/_build/ 83 | 84 | # PyBuilder 85 | .pybuilder/ 86 | target/ 87 | 88 | # Jupyter Notebook 89 | .ipynb_checkpoints 90 | 91 | # IPython 92 | profile_default/ 93 | ipython_config.py 94 | 95 | #unused 96 | *_1.py 97 | _*.py 98 | *.pypirc 99 | .pypirc 100 | 101 | 102 | # pyenv 103 | # For a library or package, you might want to ignore these files since the code is 104 | # intended to run in multiple environments; otherwise, check them in: 105 | # .python-version 106 | 107 | # 
pipenv 108 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 109 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 110 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 111 | # install all needed dependencies. 112 | #Pipfile.lock 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | ### JetBrains template 158 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 159 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 160 | 161 | # User-specific stuff 162 | .idea/**/workspace.xml 163 | .idea/**/tasks.xml 164 | .idea/**/usage.statistics.xml 165 | .idea/**/dictionaries 166 | .idea/**/shelf 167 | 168 | # Generated files 169 | .idea/**/contentModel.xml 170 | 171 | # Sensitive or high-churn files 172 | .idea/**/dataSources/ 173 | .idea/**/dataSources.ids 174 | .idea/**/dataSources.local.xml 175 | .idea/**/sqlDataSources.xml 176 | .idea/**/dynamic.xml 177 | .idea/**/uiDesigner.xml 178 | .idea/**/dbnavigator.xml 179 | 180 | # Gradle 181 | .idea/**/gradle.xml 182 | .idea/**/libraries 183 | 184 | # Gradle and Maven with auto-import 185 | # When using Gradle or Maven 
with auto-import, you should exclude module files, 186 | # since they will be recreated, and may cause churn. Uncomment if using 187 | # auto-import. 188 | # .idea/artifacts 189 | # .idea/compiler.xml 190 | # .idea/jarRepositories.xml 191 | # .idea/modules.xml 192 | # .idea/*.iml 193 | # .idea/modules 194 | # *.iml 195 | # *.ipr 196 | 197 | # CMake 198 | cmake-build-*/ 199 | 200 | # Mongo Explorer plugin 201 | .idea/**/mongoSettings.xml 202 | 203 | # File-based project format 204 | *.iws 205 | 206 | # IntelliJ 207 | out/ 208 | 209 | # mpeltonen/sbt-idea plugin 210 | .idea_modules/ 211 | 212 | # JIRA plugin 213 | atlassian-ide-plugin.xml 214 | 215 | # Cursive Clojure plugin 216 | .idea/replstate.xml 217 | 218 | # Crashlytics plugin (for Android Studio and IntelliJ) 219 | com_crashlytics_export_strings.xml 220 | crashlytics.properties 221 | crashlytics-build.properties 222 | fabric.properties 223 | 224 | # Editor-based Rest Client 225 | .idea/httpRequests 226 | 227 | # Android studio 3.1+ serialized cache file 228 | .idea/caches/build_file_checksums.ser 229 | 230 | 231 | .idea/.gitignore 232 | .idea/DeepTrident.iml 233 | .idea/codeStyles/ 234 | .idea/encodings.xml 235 | .idea/inspectionProfiles/ 236 | .idea/misc.xml 237 | .idea/modules.xml 238 | .idea/other.xml 239 | .idea/vcs.xml 240 | 241 | trident/experiments/tfe_layers.py 242 | .idea/.name 243 | .idea/Trident.iml 244 | 245 | 246 | docs/en-us/ 247 | docs/zh-cn/ 248 | 249 | 250 | trident.png 251 | sphinx_docs/ 252 | .idea/markdown-navigator.xml 253 | .idea/markdown-navigator/ 254 | .idea/vagrant.xml 255 | .travis/ 256 | build/bdist.win-amd64/ 257 | docs/zh-tw/advanced/ 258 | docs/zh-tw/callbacks/ 259 | docs/zh-tw/custom/ 260 | docs/zh-tw/data/ 261 | docs/zh-tw/layers/ 262 | docs/zh-tw/models/ 263 | docs/zh-tw/ops/ 264 | docs/zh-tw/optims/ 265 | sphinx_docs/source/Models/ 266 | 267 | trident/backend/graph_tools.py 268 | trident/backend/onnx_ops.py 269 | trident/layers/pytorch_heads.py 270 | 
trident/layers/pytorch_necks.py 271 | trident/models/pytorch_xception.py 272 | trident/tools/model_converter.py 273 | tests/data/transform_results/random_adjust_brightness1.png 274 | tests/data/transform_results/random_adjust_contrast.png 275 | tests/data/transform_results/random_adjust_hue.png 276 | tests/data/transform_results/random_adjust_saturation.png 277 | tests/data/transform_results/random_color_jitter.png 278 | tests/data/transform_results/random_lighting.png 279 | -------------------------------------------------------------------------------- /trident/data/preprocess_policy.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import inspect 6 | import random 7 | import time 8 | import builtins 9 | import numpy as np 10 | from PIL import Image, ImageEnhance, ImageOps 11 | from trident.backend.common import * 12 | from trident.backend.tensorspec import TensorSpec, assert_input_compatibility, ObjectType,object_type_inference 13 | from trident.data.image_common import image_backend_adaption 14 | 15 | if get_backend() == 'pytorch': 16 | from trident.backend.pytorch_ops import * 17 | elif get_backend() == 'tensorflow': 18 | from trident.backend.tensorflow_ops import * 19 | 20 | __all__ = ['PreprocessPolicy', 'PreprocessPolicyItem'] 21 | 22 | 23 | class PreprocessPolicy(object): 24 | """ Randomly choose one of the best 24 Sub-policies on ImageNet. 25 | 26 | 27 | 28 | 29 | """ 30 | 31 | def __init__(self,*args): 32 | self.policies = [] 33 | for arg in args: 34 | self.add(arg) 35 | self.pass_cnt=0 36 | self.pass_time_spend=0. 
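The condition/then/else dispatch with hit-rate bookkeeping that `PreprocessPolicyItem` implements further down can be sketched standalone (a hypothetical `ConditionalStep` class for illustration, not part of the trident API):

```python
class ConditionalStep:
    """Simplified sketch: apply then_fn when cond(x) holds, otherwise
    else_fn (or pass x through), keeping per-branch hit statistics."""
    def __init__(self, cond, then_fn, else_fn=None):
        self.cond, self.then_fn, self.else_fn = cond, then_fn, else_fn
        self.count_true = 0
        self.count_false = 0

    @property
    def hit_rate(self):
        # Guard against division by zero before the first call.
        total = max(self.count_true + self.count_false, 1)
        return self.count_true / float(total)

    def __call__(self, x):
        if self.cond(x):
            self.count_true += 1
            return self.then_fn(x)
        self.count_false += 1
        return self.else_fn(x) if self.else_fn is not None else x

step = ConditionalStep(lambda v: v < 0, then_fn=abs)
results = [step(v) for v in (-2, 3, -1, 5)]
```

After the four calls, `step.hit_rate` is 0.5 because the condition held for two of the four inputs.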
37 | 38 | def add(self, item): 39 | if isinstance(item, PreprocessPolicyItem): 40 | if len(item.name) == 0: 41 | item.name = 'item_{0}'.format(len(self.policies)) 42 | self.policies.append(item) 43 | def reset_statistics(self): 44 | self.pass_cnt=0 45 | self.pass_time_spend=0 46 | for item in self.policies: 47 | if isinstance(item, PreprocessPolicyItem): 48 | item.reset_statistics() 49 | 50 | def print_statistics(self): 51 | print('avg. process time: {0:.5f}'.format(self.pass_time_spend/float(builtins.max(1,self.pass_cnt)))) 52 | 53 | for item in self.policies: 54 | if isinstance(item,PreprocessPolicyItem): 55 | print(' policy {0} hit-rate={1:.3%}'.format(item.name, item.hit_rate)) 56 | print(' avg. time spent (true): {0}'.format(item.time_spend_true/float(builtins.max(1,item.count_true)))) 57 | print(' avg. time spent (false):{0}'.format(item.time_spend_false/float(builtins.max(1,item.count_false)))) 58 | 59 | 60 | def __call__(self, img,spec:TensorSpec=None,**kwargs): 61 | if isinstance(img, np.ndarray): 62 | start_time = time.time() 63 | if spec is None: 64 | spec = TensorSpec(shape=to_tensor(img.shape), object_type=object_type_inference(img)) 65 | if spec.object_type==ObjectType.rgb or spec.object_type==ObjectType.gray: 66 | if (img.ndim==3 and img.shape[0] in [1,3,4]): 67 | img=img.transpose(1,2,0) 68 | 69 | for i in range(len(self.policies)): 70 | try: 71 | item = self.policies[i] 72 | img = item(img,spec=spec) 73 | except Exception as e: 74 | print(e) 75 | img=image_backend_adaption(img) 76 | self.pass_cnt+=1 77 | self.pass_time_spend+=float(time.time()-start_time) 78 | return img 79 | 80 | def __repr__(self): 81 | return "PreprocessPolicy" 82 | 83 | 84 | class PreprocessPolicyItem(object): 85 | def __init__(self, condition_if, then_process, else_process=None, name=''): 86 | # lambda or function 87 | if inspect.isfunction(condition_if) or callable(condition_if) or isinstance(condition_if, bool): 88 | self.condition_if = 
condition_if 89 | else: 90 | print('PreprocessPolicyItem {0} condition_if is not callable'.format(name)) 91 | if callable(then_process) or (isinstance(then_process, list) and callable(then_process[0])): 92 | self.then_process = then_process 93 | else: 94 | print('PreprocessPolicyItem {0} then_process is not callable'.format(name)) 95 | 96 | if else_process is None or callable(else_process) or ( 97 | isinstance(else_process, list) and callable(else_process[0])): 98 | self.else_process = else_process 99 | else: 100 | print('PreprocessPolicyItem {0} else_process is not callable'.format(name)) 101 | self.name = name 102 | self.count_true = 0 103 | self.count_false = 0 104 | self.time_spend_true = 0. 105 | self.time_spend_false = 0. 106 | 107 | @property 108 | def hit_rate(self): 109 | return self.count_true / float(max((self.count_true + self.count_false), 1)) 110 | 111 | def reset_statistics(self): 112 | self.count_true = 0 113 | self.count_false = 0 114 | self.time_spend_true = 0 115 | self.time_spend_false = 0 116 | 117 | def __call__(self, img,spec:TensorSpec=None,**kwargs): 118 | start_time=time.time() 119 | bool_if=None 120 | if isinstance(self.condition_if,bool): 121 | bool_if=self.condition_if 122 | elif inspect.isfunction(self.condition_if) or callable(self.condition_if) : 123 | argspec=inspect.getfullargspec(self.condition_if) 124 | if "spec" in argspec.args: 125 | bool_if =self.condition_if(img,spec=spec) 126 | else: 127 | bool_if = self.condition_if(img) 128 | 129 | if bool_if==True: 130 | if isinstance(self.then_process, list): 131 | for proc in self.then_process: 132 | img = proc(img,spec=spec) 133 | elif callable(self.then_process) or inspect.isfunction(self.then_process): 134 | img = self.then_process(img,spec=spec) 135 | self.count_true += 1 136 | self.time_spend_true+=float(time.time()-start_time) 137 | elif bool_if==False: 138 | if self.else_process is not None: 139 | if isinstance(self.else_process, list): 140 | for proc in self.else_process: 141 | 
img = proc(img,spec=spec) 142 | elif callable(self.else_process) or inspect.isfunction(self.else_process): 143 | img = self.else_process(img,spec=spec) 144 | self.count_false += 1 145 | self.time_spend_false+= float(time.time() - start_time) 146 | else: 147 | self.count_false += 1 148 | self.time_spend_false += float(time.time() - start_time) 149 | pass 150 | return img -------------------------------------------------------------------------------- /trident/models/pytorch_arcfacenet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | 6 | import inspect 7 | import math 8 | import os 9 | import uuid 10 | from collections import * 11 | from collections import deque 12 | from copy import copy, deepcopy 13 | from functools import partial 14 | from itertools import repeat 15 | 16 | import numpy as np 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | from collections import abc 21 | from torch.nn import init 22 | from trident.models.pretrained_utils import _make_recovery_model_include_top 23 | from trident.data.vision_transforms import Resize,Normalize 24 | from trident.backend.common import * 25 | from trident.backend.pytorch_backend import to_numpy, to_tensor, Layer, Sequential, summary, fix_layer, load,get_device 26 | from trident.data.image_common import * 27 | from trident.data.utils import download_model_from_google_drive 28 | from trident.layers.pytorch_activations import get_activation, Identity, Relu, PRelu 29 | from trident.layers.pytorch_blocks import * 30 | from trident.layers.pytorch_layers import * 31 | from trident.layers.pytorch_normalizations import get_normalization, BatchNorm2d, BatchNorm 32 | from trident.layers.pytorch_pooling import * 33 | from trident.optims.pytorch_trainer import * 34 | 35 | __all__ = 
['SEResNet_IR','load','BottleNeck_IR_SE','BottleNeck_IR','SEResNet_IR_50_512'] 36 | 37 | _session = get_session() 38 | _device = get_device() 39 | _epsilon=_session.epsilon 40 | _trident_dir=_session.trident_dir 41 | _backend = get_backend() 42 | 43 | dirname = os.path.join(_trident_dir, 'models') 44 | if not os.path.exists(dirname): 45 | try: 46 | os.makedirs(dirname) 47 | except OSError: 48 | # Except permission denied and potential race conditions 49 | # in multi-threaded environments. 50 | pass 51 | 52 | 53 | 54 | 55 | 56 | 57 | def BottleNeck_IR(num_filters, strides,keep_filter=True): 58 | blocks = OrderedDict() 59 | blocks['res_layer'] = Sequential(BatchNorm2d(), 60 | Conv2d_Block((3, 3), num_filters=num_filters, strides=1, auto_pad=True, use_bias=False, activation=PRelu(num_filters)), 61 | Conv2d_Block((3, 3), num_filters, strides=strides, use_bias=False, activation=None, normalization='batch')) 62 | if keep_filter: 63 | blocks['shortcut_layer']=MaxPool2d(1, strides=strides,name='shortcut_layer') 64 | else: 65 | blocks['shortcut_layer'] = Conv2d_Block((1, 1), num_filters, strides=strides, use_bias=False, activation=None, normalization='batch') 66 | return ShortCut2d(blocks,mode='add') 67 | 68 | def BottleNeck_IR_SE( num_filters, strides,keep_filter=True): 69 | blocks = OrderedDict() 70 | blocks['res_layer'] = Sequential(BatchNorm2d(), 71 | Conv2d_Block((3, 3), num_filters=num_filters, strides=1, auto_pad=True, use_bias=False, activation=PRelu(num_filters)), 72 | Conv2d_Block((3, 3), num_filters, strides=strides, use_bias=False, activation=None, normalization='batch'), 73 | SqueezeExcite(num_filters//16,num_filters),name='res_layer') 74 | if keep_filter: 75 | blocks['shortcut_layer'] = MaxPool2d(1, strides=strides, name='shortcut_layer') 76 | 77 | else: 78 | blocks['shortcut_layer'] =Conv2d_Block((1, 1), num_filters, strides=strides, use_bias=False, activation=None, normalization='batch',name='shortcut_layer') 79 | return ShortCut2d(blocks,mode='add') 80 | 81 
| 82 | def get_block(Bottleneck, out_channel, num_units, strides=2,keep_filter=True): 83 | blocks=[Bottleneck(out_channel, strides,keep_filter)] 84 | for i in range(num_units - 1): 85 | blocks.append(Bottleneck(out_channel, 1,True)) 86 | return blocks 87 | 88 | 89 | 90 | def SEResNet_IR(include_top=True,num_layers=50,Bottleneck=BottleNeck_IR_SE,drop_ratio=0.4,feature_dim=128,input_shape=(3,112,112)): 91 | blocks=OrderedDict() 92 | blocks['input_layer']=Conv2d_Block((3,3),64,strides=1,auto_pad=True,use_bias=False,activation=PRelu(64),normalization='batch',name='input_layer') 93 | blocks['body']=Sequential( 94 | get_block(Bottleneck, out_channel=64, num_units=3,keep_filter=True)+ 95 | get_block(Bottleneck, out_channel=128, num_units=4,keep_filter=False)+ 96 | get_block(Bottleneck, out_channel=256, num_units=14,keep_filter=False)+ 97 | get_block(Bottleneck, out_channel=512, num_units=3,keep_filter=False) 98 | ) 99 | blocks['output_layer']=Sequential( 100 | BatchNorm2d(), 101 | Dropout(drop_ratio), 102 | Flatten(), 103 | Dense(feature_dim), 104 | BatchNorm(), 105 | name='output_layer' 106 | ) 107 | facenet=Sequential(blocks).to(_device) 108 | facenet.name=camel2snake('SEResNet_IR') 109 | model=FaceRecognitionModel(input_shape=input_shape,output=facenet) 110 | model.preprocess_flow=[Resize((input_shape[1],input_shape[2]),keep_aspect=True),Normalize(0,255),Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])] 111 | #model.summary() 112 | return model 113 | 114 | 115 | def SEResNet_IR_50_512(include_top=True, 116 | pretrained=True, 117 | freeze_features=True, 118 | input_shape=(3,112,112), 119 | classes=1000, 120 | **kwargs): 121 | if input_shape is not None and len(input_shape)==3: 122 | input_shape=tuple(input_shape) 123 | else: 124 | input_shape=(3, 112, 112) 125 | seresnet = SEResNet_IR(include_top=include_top,num_layers=50,Bottleneck=BottleNeck_IR_SE,drop_ratio=0.4,feature_dim=512,input_shape=input_shape) 126 | if pretrained: 127 | 
download_model_from_google_drive('1aLYbFvtvsV2gQ16D_vwzrKdbCgij7IoZ', dirname, 'arcface_se_50_512.pth') 128 | recovery_model = load(os.path.join(dirname, 'arcface_se_50_512.pth')) 129 | recovery_model = fix_layer(recovery_model) 130 | recovery_model.name = 'arcface_se_50_512' 131 | recovery_model = _make_recovery_model_include_top(recovery_model, include_top=include_top, classes=classes, freeze_features=freeze_features) 132 | seresnet.model = recovery_model 133 | else: 134 | seresnet.model = _make_recovery_model_include_top(seresnet.model, include_top=include_top, classes=classes, freeze_features=True) 135 | seresnet.model.input_shape = input_shape 136 | seresnet.model.to(_device) 137 | return seresnet 138 | 139 | 140 | 141 | 142 | -------------------------------------------------------------------------------- /trident/models/pytorch_deeplab.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import inspect 6 | import math 7 | import os 8 | import uuid 9 | from collections import * 10 | from collections import deque 11 | from copy import copy, deepcopy 12 | from functools import partial 13 | from itertools import repeat 14 | 15 | import numpy as np 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | from collections import abc 20 | from torch.nn import init 21 | 22 | from trident.backend.common import * 23 | from trident.backend.pytorch_backend import to_numpy, to_tensor, Layer, Sequential, summary,get_device 24 | from trident.data.image_common import * 25 | from trident.data.utils import download_model_from_google_drive 26 | from trident.layers.pytorch_activations import get_activation, Identity, Relu 27 | from trident.layers.pytorch_blocks import * 28 | from trident.layers.pytorch_layers import * 29 | from trident.layers.pytorch_normalizations import get_normalization, 
BatchNorm2d 30 | from trident.layers.pytorch_pooling import * 31 | from trident.optims.pytorch_trainer import * 32 | from trident.data.vision_transforms import Resize,Normalize 33 | __all__ = ['DeeplabV3_plus','DeeplabV3'] 34 | 35 | _session = get_session() 36 | _device = get_device() 37 | _epsilon=_session.epsilon 38 | _trident_dir=_session.trident_dir 39 | _backend = _session.backend 40 | 41 | dirname = os.path.join(_trident_dir, 'models') 42 | if not os.path.exists(dirname): 43 | try: 44 | os.makedirs(dirname) 45 | except OSError: 46 | # Except permission denied and potential race conditions 47 | # in multi-threaded environments. 48 | pass 49 | 50 | 51 | 52 | 53 | def DeepLabHead(classes=20, atrous_rates=(6, 12, 18,24),num_filters=256): 54 | return Sequential( 55 | ASPP(atrous_rates,num_filters=num_filters), 56 | Conv2d_Block((3,3),num_filters,auto_pad=True,use_bias=False,activation='relu',normalization='batch'), 57 | Conv2d((1,1),num_filters=classes,strides=1,auto_pad=True,activation=None,name='classifier'), 58 | SoftMax() 59 | ) 60 | 61 | 62 | 63 | def ASPPPooling(num_filters): 64 | return Sequential(AdaptiveAvgPool2d((1,1)), 65 | Conv2d((1,1),num_filters,strides=1,use_bias=False,activation=None), 66 | Upsampling2d(scale_factor=14,mode='bilinear', align_corners=False)) 67 | 68 | 69 | 70 | def ASPP(atrous_rates=(6,12,18),num_filters=256): 71 | layers=OrderedDict() 72 | layers['conv1']=Conv2d_Block((1,1),num_filters=num_filters,strides=1,use_bias=False,activation=None,normalization='batch') 73 | for i in range(len(atrous_rates)): 74 | layers['aspp_dilation{0}'.format(i)]=Conv2d_Block((3,3),num_filters=num_filters,strides=1,use_bias=False,activation=None,normalization='batch',dilation=atrous_rates[i]) 75 | layers['aspp_pooling'] =ASPPPooling(num_filters) 76 | return Sequential( 77 | ShortCut2d(layers,mode='concate'), 78 | Conv2d_Block((1, 1), num_filters, strides=1, use_bias=False, bias=False, activation='relu', normalization='batch', dilation=1, 
dropout_rate=0.5, name='project') 79 | ) 80 | 81 | 82 | 83 | 84 | 85 | 86 | def DeeplabV3(backbond, 87 | input_shape=(3,224,224), 88 | classes=20, 89 | **kwargs): 90 | input_shape=tuple(input_shape) 91 | deeplab=Sequential(name='deeplabv3') 92 | 93 | deeplab.add_module('backbond',backbond) 94 | deeplab.add_module('classifier', DeepLabHead(classes=classes,num_filters=128)) 95 | deeplab.add_module('upsample', Upsampling2d(scale_factor=16, mode='bilinear', align_corners=False)) 96 | model = ImageSegmentationModel(input_shape=input_shape, output=deeplab) 97 | return model 98 | 99 | 100 | 101 | class _DeeplabV3_plus(Layer): 102 | def __init__(self, backbond, input_shape=(3,224,224), atrous_rates=(6, 12, 18, 24), num_filters=256, classes=20): 103 | super(_DeeplabV3_plus, self).__init__() 104 | moduals=list(backbond.children()) 105 | low_level_idx=-1 106 | high_level_idx=-1 107 | for i in range(len(moduals)): 108 | if low_level_idx<0 and moduals[i].output_shape[-1]==backbond.input_shape[-1]//8: 109 | low_level_idx=i 110 | 111 | if high_level_idx<0 and moduals[i].output_shape[-1]==backbond.input_shape[-1]//32: 112 | high_level_idx=i 113 | break 114 | self.num_filters=num_filters 115 | self.classes=classes 116 | self.atrous_rates=atrous_rates 117 | self.backbond1=Sequential(*backbond[:low_level_idx]) 118 | self.backbond2 = Sequential(*backbond[low_level_idx:high_level_idx]) 119 | self.aspp=ASPP(atrous_rates=self.atrous_rates,num_filters=self.num_filters) 120 | self.low_level_conv=Conv2d_Block((1,1),num_filters=int(48*self.num_filters/256),strides=1,use_bias=False,activation='leaky_relu',normalization='batch') 121 | self.decoder=Sequential( 122 | DepthwiseConv2d_Block((3,3),depth_multiplier=0.5,strides=1,use_bias=False,activation='leaky_relu',normalization='batch',dropout_rate=0.5), 123 | DepthwiseConv2d_Block((3,3),depth_multiplier=1,strides=1,use_bias=False,activation='leaky_relu',normalization='batch',dropout_rate=0.1), 124 | Conv2d((1, 1), num_filters=self.classes, 
strides=1, use_bias=False, activation=None), 125 | SoftMax() 126 | 127 | ) 128 | 129 | def forward(self, x,**kwargs): 130 | low_level_feature=self.backbond1(x) 131 | high_level_feature = self.backbond2(low_level_feature) 132 | x=self.aspp(high_level_feature) 133 | x=F.interpolate(x, None, (4.0,4.0),mode='bilinear', align_corners=True) 134 | low_level_feature=self.low_level_conv(low_level_feature) 135 | x=torch.cat([x,low_level_feature],dim=1) 136 | x=self.decoder(x) 137 | x = F.interpolate(x, None, (4.0, 4.0), mode='bilinear', align_corners=True) 138 | return x 139 | 140 | 141 | 142 | def DeeplabV3_plus(backbond, 143 | input_shape=(3,224,224), 144 | atrous_rates = (6, 12, 18, 24), 145 | num_filters = 256, 146 | classes=20, 147 | **kwargs): 148 | deeplab=_DeeplabV3_plus(backbond=backbond,input_shape=input_shape,atrous_rates=atrous_rates,num_filters=num_filters,classes=classes) 149 | deeplab.name='DeeplabV3_plus' 150 | model = ImageSegmentationModel(input_shape=input_shape, output=deeplab) 151 | return model 152 | -------------------------------------------------------------------------------- /trident/data/augment_policy.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | from PIL import Image, ImageEnhance, ImageOps 5 | 6 | __all__ = ['AugmentPolicy','SubPolicy'] 7 | 8 | class AugmentPolicy(object): 9 | """ Randomly choose one of the best 24 Sub-policies on ImageNet. 
10 | 11 | Examples: 12 | >>> policy = AugmentPolicy() 13 | >>> transformed = policy(image) 14 | 15 | Example as a PyTorch Transform: 16 | >>> transform=transforms.Compose([ 17 | >>> transforms.Resize(256), 18 | >>> AugmentPolicy(), 19 | >>> transforms.ToTensor()]) 20 | """ 21 | def __init__(self, fillcolor=(128, 128, 128)): 22 | self.policies = [ 23 | SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor), 24 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), 25 | SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor), 26 | SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor), 27 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), 28 | 29 | SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor), 30 | SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor), 31 | SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor), 32 | SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor), 33 | SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor), 34 | 35 | SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor), 36 | SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor), 37 | SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor), 38 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), 39 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), 40 | 41 | SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor), 42 | SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor), 43 | SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor), 44 | SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor), 45 | SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor), 46 | 47 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), 48 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), 49 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), 50 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), 51 | SubPolicy(0.8, "equalize", 8, 0.6, 
"equalize", 3, fillcolor) 52 | ] 53 | 54 | 55 | def __call__(self, img): 56 | policy_idx = random.randint(0, len(self.policies) - 1) 57 | return self.policies[policy_idx](img) 58 | 59 | def __repr__(self): 60 | return "AutoAugment ImageNet Policy" 61 | 62 | 63 | 64 | 65 | 66 | class SubPolicy(object): 67 | def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)): 68 | ranges = { 69 | "shearX": np.linspace(0, 0.3, 10), 70 | "shearY": np.linspace(0, 0.3, 10), 71 | "translateX": np.linspace(0, 150 / 331, 10), 72 | "translateY": np.linspace(0, 150 / 331, 10), 73 | "rotate": np.linspace(0, 30, 10), 74 | "color": np.linspace(0.0, 0.9, 10), 75 | "posterize": np.round(np.linspace(8, 4, 10), 0).astype(np.int), 76 | "solarize": np.linspace(256, 0, 10), 77 | "contrast": np.linspace(0.0, 0.9, 10), 78 | "sharpness": np.linspace(0.0, 0.9, 10), 79 | "brightness": np.linspace(0.0, 0.9, 10), 80 | "autocontrast": [0] * 10, 81 | "equalize": [0] * 10, 82 | "invert": [0] * 10 83 | } 84 | 85 | # from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand 86 | def rotate_with_fill(img, magnitude): 87 | rot = img.convert("RGBA").rotate(magnitude) 88 | return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode) 89 | 90 | func = { 91 | "shearX": lambda img, magnitude: img.transform( 92 | img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0), 93 | Image.BICUBIC, fillcolor=fillcolor), 94 | "shearY": lambda img, magnitude: img.transform( 95 | img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0), 96 | Image.BICUBIC, fillcolor=fillcolor), 97 | "translateX": lambda img, magnitude: img.transform( 98 | img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0), 99 | fillcolor=fillcolor), 100 | "translateY": lambda img, magnitude: img.transform( 101 | img.size, 
Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])), 102 | fillcolor=fillcolor), 103 | "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude), 104 | "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])), 105 | "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude), 106 | "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude), 107 | "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance( 108 | 1 + magnitude * random.choice([-1, 1])), 109 | "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance( 110 | 1 + magnitude * random.choice([-1, 1])), 111 | "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance( 112 | 1 + magnitude * random.choice([-1, 1])), 113 | "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img), 114 | "equalize": lambda img, magnitude: ImageOps.equalize(img), 115 | "invert": lambda img, magnitude: ImageOps.invert(img) 116 | } 117 | 118 | self.p1 = p1 119 | self.operation1 = func[operation1] 120 | self.magnitude1 = ranges[operation1][magnitude_idx1] 121 | self.p2 = p2 122 | self.operation2 = func[operation2] 123 | self.magnitude2 = ranges[operation2][magnitude_idx2] 124 | 125 | 126 | def __call__(self, img): 127 | if random.random() < self.p1: img = self.operation1(img, self.magnitude1) 128 | if random.random() < self.p2: img = self.operation2(img, self.magnitude2) 129 | return img -------------------------------------------------------------------------------- /trident/models/tensorflow_deeplab.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import inspect 6 | import math 7 | import os 8 | import uuid 9 | from collections import * 10 | from collections import deque 11 | from copy import copy, deepcopy 12 | from 
functools import partial 13 | from itertools import repeat 14 | import tensorflow as tf 15 | import numpy as np 16 | from tensorflow.python.ops import image_ops 17 | from tensorflow.python.ops.image_ops_impl import ResizeMethod 18 | 19 | from trident.backend.common import * 20 | from trident.backend.tensorflow_ops import * 21 | from trident.backend.tensorflow_backend import to_numpy, to_tensor, Layer, Sequential, summary 22 | from trident.data.image_common import * 23 | from trident.data.utils import download_model_from_google_drive 24 | from trident.layers.tensorflow_activations import get_activation, Identity, Relu 25 | from trident.layers.tensorflow_blocks import * 26 | from trident.layers.tensorflow_layers import * 27 | from trident.layers.tensorflow_normalizations import get_normalization, BatchNorm2d 28 | from trident.layers.tensorflow_pooling import * 29 | from trident.optims.tensorflow_trainer import * 30 | from trident.data.vision_transforms import Resize,Normalize 31 | __all__ = ['DeeplabV3_plus','DeeplabV3'] 32 | 33 | _session = get_session() 34 | _epsilon=_session.epsilon 35 | _trident_dir=_session.trident_dir 36 | _backend = get_backend() 37 | 38 | dirname = os.path.join(_trident_dir, 'models') 39 | if not os.path.exists(dirname): 40 | try: 41 | os.makedirs(dirname) 42 | except OSError: 43 | # Except permission denied and potential race conditions 44 | # in multi-threaded environments. 
45 | pass 46 | 47 | 48 | 49 | 50 | 51 | def DeepLabHead(classes=20, atrous_rates=(6, 12, 18,24),num_filters=256): 52 | return Sequential( 53 | ASPP(atrous_rates,num_filters=num_filters), 54 | Conv2d_Block((3,3),num_filters,auto_pad=True,use_bias=False,activation='relu',normalization='batch'), 55 | Conv2d((1,1),num_filters=classes,strides=1,auto_pad=True,activation=None,name='classifier'), 56 | SoftMax() 57 | ) 58 | 59 | 60 | 61 | def ASPPPooling(num_filters): 62 | return Sequential(AdaptiveAvgPool2d((1,1)), 63 | Conv2d((1,1),num_filters,strides=1,use_bias=False,activation=None), 64 | Upsampling2d(scale_factor=14,mode='bilinear', align_corners=False)) 65 | 66 | 67 | 68 | def ASPP(atrous_rates=(6,12,18),num_filters=256): 69 | layers=OrderedDict() 70 | layers['conv1']=Conv2d_Block((1,1),num_filters=num_filters,strides=1,use_bias=False,activation=None,normalization='batch') 71 | for i in range(len(atrous_rates)): 72 | layers['aspp_dilation{0}'.format(i)]=Conv2d_Block((3,3),num_filters=num_filters,strides=1,use_bias=False,activation=None,normalization='batch',dilation=atrous_rates[i]) 73 | layers['aspp_pooling'] =ASPPPooling(num_filters) 74 | return Sequential( 75 | ShortCut2d(layers,mode='concate'), 76 | Conv2d_Block((1, 1), num_filters, strides=1, use_bias=False, bias=False, activation='relu', normalization='batch', dilation=1, dropout_rate=0.5, name='project') 77 | ) 78 | 79 | 80 | 81 | 82 | 83 | 84 | def DeeplabV3(backbond, 85 | input_shape=(224,224,3), 86 | classes=20, 87 | **kwargs): 88 | input_shape=tuple(input_shape) 89 | deeplab=Sequential(name='deeplabv3') 90 | 91 | deeplab.add_module('backbond',backbond) 92 | deeplab.add_module('classifier', DeepLabHead(classes=classes,num_filters=128)) 93 | deeplab.add_module('upsample', Upsampling2d(scale_factor=16, mode='bilinear', align_corners=False)) 94 | model = ImageSegmentationModel(input_shape=input_shape, output=deeplab) 95 | return model 96 | 97 | 98 | 99 | class _DeeplabV3_plus(Layer): 100 | def __init__(self, 
backbond, input_shape=(224,224,3), atrous_rates=(6, 12, 18, 24), num_filters=256, classes=20): 101 | super(_DeeplabV3_plus, self).__init__() 102 | moduals=list(backbond.children()) 103 | low_level_idx=-1 104 | high_level_idx=-1 105 | for i in range(len(moduals)): 106 | if low_level_idx<0 and moduals[i].output_shape[1]==backbond.input_shape[1]//8: 107 | low_level_idx=i 108 | 109 | if high_level_idx<0 and moduals[i].output_shape[1]==backbond.input_shape[1]//32: 110 | high_level_idx=i 111 | break 112 | self.num_filters=num_filters 113 | self.classes=classes 114 | self.atrous_rates=atrous_rates 115 | self.backbond1=Sequential(*backbond[:low_level_idx]) 116 | self.backbond2 = Sequential(*backbond[low_level_idx:high_level_idx]) 117 | self.aspp=ASPP(atrous_rates=self.atrous_rates,num_filters=self.num_filters) 118 | self.low_level_conv=Conv2d_Block((1,1),num_filters=int(48*self.num_filters/256),strides=1,use_bias=False,activation='leaky_relu',normalization='batch') 119 | self.decoder=Sequential( 120 | DepthwiseConv2d_Block((3,3),depth_multiplier=0.5,strides=1,use_bias=False,activation='leaky_relu',normalization='batch',dropout_rate=0.5), 121 | DepthwiseConv2d_Block((3,3),depth_multiplier=1,strides=1,use_bias=False,activation='leaky_relu',normalization='batch',dropout_rate=0.1), 122 | Conv2d((1, 1), num_filters=self.classes, strides=1, use_bias=False, activation=None), 123 | SoftMax(axis=-1) 124 | 125 | ) 126 | 127 | def forward(self, x,**kwargs): 128 | low_level_feature=self.backbond1(x) 129 | high_level_feature = self.backbond2(low_level_feature) 130 | x=self.aspp(high_level_feature) 131 | x=tf.image.resize(x, [x.shape[2]*4, x.shape[1]*4], method=ResizeMethod.BILINEAR, preserve_aspect_ratio=True,antialias=True) 132 | low_level_feature=self.low_level_conv(low_level_feature) 133 | x=tf.concat([x,low_level_feature],axis=-1) 134 | x=self.decoder(x) 135 | x=tf.image.resize(x, [x.shape[2]*4, x.shape[1]*4], method=ResizeMethod.BILINEAR, preserve_aspect_ratio=True,antialias=True) 
136 | return x 137 | 138 | 139 | 140 | def DeeplabV3_plus(backbond, 141 | input_shape=(224,224,3), 142 | atrous_rates = (6, 12, 18, 24), 143 | num_filters = 256, 144 | classes=20, 145 | **kwargs): 146 | deeplab=_DeeplabV3_plus(backbond=backbond,input_shape=input_shape,atrous_rates=atrous_rates,num_filters=num_filters,classes=classes) 147 | deeplab.name='DeeplabV3_plus' 148 | model = ImageSegmentationModel(input_shape=input_shape, output=deeplab) 149 | return model 150 | -------------------------------------------------------------------------------- /trident/optims/tensorflow_constraints.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import functools 6 | import numpy as np 7 | import tensorflow as tf 8 | from trident.backend.tensorflow_backend import Parameter, Layer 9 | 10 | from trident.backend.common import get_session,epsilon 11 | from trident.backend.tensorflow_ops import * 12 | 13 | __all__ = ['max_norm', 'non_neg_norm', 'unit_norm', 'min_max_norm', 'maxnorm', 'nonnegnorm', 'unitnorm', 'minmaxnorm', 'get_constraint'] 14 | 15 | _session=get_session() 16 | _epsilon=_session.epsilon 17 | 18 | 19 | def max_norm(model,max_value=2, axis=0): 20 | """ 21 | MaxNorm weight constraint. 22 | Constrains the weights incident to each hidden unit to have a norm less than or equal to a desired value. 23 | Args: 24 | model : the model whose weights the constraint should be applied to. 25 | max_value (float): the maximum norm value for the incoming weights. 26 | axis (int): axis along which to calculate weight norms.
27 | 28 | """ 29 | 30 | def apply_constraint(t: Tensor): 31 | w_data = None 32 | if isinstance(t, tf.Variable): 33 | w_data = t.value().detach() 34 | else: 35 | w_data = t.copy().detach() 36 | norms = sqrt(reduce_sum(square(w_data), axis=axis, keepdims=True)) 37 | desired = clip(norms, 0, max_value) 38 | param_applied = w_data * (desired / (epsilon() + norms)) 39 | param_applied = param_applied.detach() 40 | return param_applied 41 | 42 | if is_tensor(model): 43 | model = apply_constraint(model) 44 | elif isinstance(model, Layer): 45 | for name, param in model.named_parameters(): 46 | if 'bias' not in name and param is not None and param.trainable == True: 47 | param.assign(apply_constraint(param)) 48 | 49 | 50 | def non_neg_norm(model): 51 | """ 52 | Constrains the weights to be non-negative. 53 | Args: 54 | model : the model whose weights the constraint should be applied to. 55 | 56 | """ 57 | 58 | 59 | def apply_constraint(t: Tensor): 60 | w_data = None 61 | if isinstance(t, tf.Variable): 62 | w_data = t.value().detach() 63 | else: 64 | w_data = t.copy().detach() 65 | param_applied = w_data * tf.cast(greater_equal(w_data, 0.), tf.float32) 66 | param_applied = param_applied.detach() 67 | return param_applied 68 | 69 | if is_tensor(model): 70 | model = apply_constraint(model) 71 | elif isinstance(model, Layer): 72 | for name, param in model.named_parameters(): 73 | if 'bias' not in name and param is not None and param.trainable == True: 74 | param.assign(apply_constraint(param)) 75 | 76 | 77 | def unit_norm(model,axis=0): 78 | """ 79 | Constrains the weights incident to each hidden unit to have unit norm. 80 | Args: 81 | axis (int): axis along which to calculate weight norms. 82 | model : the model whose weights the constraint should be applied to.
83 | 84 | """ 85 | def apply_constraint(t: Tensor): 86 | w_data = None 87 | if isinstance(t, tf.Variable): 88 | w_data = t.value().detach() 89 | else: 90 | w_data = t.copy().detach() 91 | param_applied = w_data/ (epsilon() +sqrt(reduce_sum(square(w_data),axis=axis,keepdims=True))) 92 | param_applied = param_applied.detach() 93 | return param_applied 94 | 95 | if is_tensor(model): 96 | model = apply_constraint(model) 97 | elif isinstance(model, Layer): 98 | for name, param in model.named_parameters(): 99 | if 'bias' not in name and param is not None and param.trainable == True: 100 | param.assign(apply_constraint(param)) 101 | 102 | 103 | def min_max_norm(model,min_value=0.0, max_value=1.0, rate=1.0, axis=0): 104 | """ 105 | MinMaxNorm weight constraint. 106 | Constrains the weights incident to each hidden unit to have the norm between a lower bound and an upper bound. 107 | 108 | Args: 109 | model : the model whose weights the constraint should be applied to. 110 | min_value (float): the minimum norm for the incoming weights. 111 | max_value (float): the maximum norm for the incoming weights. 112 | rate (float): rate for enforcing the constraint: weights will be rescaled to yield (1 - rate) * norm + rate * norm.clip(min_value, max_value). Effectively, this means that rate=1.0 stands for strict enforcement of the constraint, while rate<1.0 means that weights will be rescaled at each step to slowly move towards a value inside the desired interval.
113 | axis (int): axis along which to calculate weight norms 114 | 115 | Examples: 116 | >>> t=random_normal((2,64,64,32),mean=5,std=10) 117 | >>> print(t) 118 | >>> t1=tf.keras.constraints.MinMaxNorm(min_value=0.0, max_value=1.0, rate=1.0, axis=0)(t) 119 | >>> print(t1) 120 | >>> t2=min_max_norm(t,min_value=0.0, max_value=1.0, rate=1.0, axis=0) 121 | >>> print(t2) 122 | >>> np.testing.assert_almost_equal(to_numpy(t1),to_numpy(t2),decimal=6,verbose=True) 123 | 124 | 125 | 126 | """ 127 | def apply_constraint(t:Tensor): 128 | w_data=None 129 | if isinstance(t, tf.Variable): 130 | w_data = t.value().detach() 131 | else: 132 | w_data=t.copy().detach() 133 | norms = sqrt(reduce_sum(square(w_data), axis=axis, keepdims=True)) 134 | desired = (rate * clip(norms, min_value, max_value) + (1 - rate) * norms) 135 | param_applied = w_data * (desired / (epsilon() + norms)) 136 | return param_applied 137 | 138 | if is_tensor(model): 139 | model=apply_constraint(model) 140 | elif isinstance(model,Layer): 141 | for name, param in model.named_parameters(): 142 | if 'bias' not in name and param is not None and param.trainable==True: 143 | param.assign(apply_constraint(param)) 144 | 145 | 146 | # Legacy aliases. 
147 | 148 | maxnorm = max_norm 149 | nonnegnorm = non_neg_norm 150 | unitnorm = unit_norm 151 | minmaxnorm = min_max_norm 152 | default_constrains = functools.partial(min_max_norm)  # min_max_norm with its default bounds 153 | 154 | def get_constraint(constraint): 155 | if constraint in ['maxnorm','max_norm']: 156 | return max_norm 157 | elif constraint in ['non_neg','nonneg','non_neg_norm','nonnegnorm']: 158 | return non_neg_norm 159 | elif constraint in ['unit_norm','unitnorm']: 160 | return unit_norm 161 | elif constraint in ['min_max_norm', 'minmaxnorm']: 162 | return min_max_norm 163 | else: 164 | return None 165 | 166 | -------------------------------------------------------------------------------- /trident/models/tensorflow_mobilenet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import inspect 6 | import math 7 | import os 8 | import uuid 9 | from collections import * 10 | from collections import deque 11 | from copy import copy, deepcopy 12 | from functools import partial 13 | from itertools import repeat 14 | 15 | import numpy as np 16 | import tensorflow as tf 17 | 18 | from trident.backend.common import * 19 | from trident.backend.tensorspec import * 20 | from trident.backend.tensorflow_backend import to_numpy, to_tensor, Layer, Sequential,load,get_device,fix_layer 21 | from trident.data.image_common import * 22 | from trident.data.utils import download_model_from_google_drive 23 | from trident.layers.tensorflow_activations import get_activation, Identity 24 | from trident.layers.tensorflow_blocks import * 25 | from trident.layers.tensorflow_layers import * 26 | from trident.layers.tensorflow_normalizations import get_normalization 27 | from trident.layers.tensorflow_pooling import * 28 | from trident.optims.tensorflow_trainer import * 29 | from trident.data.vision_transforms import Resize,Normalize 30 | __all__ = ['MobileNet','MobileNetV2']
31 | 32 | _session = get_session() 33 | 34 | _epsilon=_session.epsilon 35 | _trident_dir=_session.trident_dir 36 | 37 | 38 | dirname = os.path.join(_trident_dir, 'models') 39 | if not os.path.exists(dirname): 40 | try: 41 | os.makedirs(dirname) 42 | except OSError: 43 | # Except permission denied and potential race conditions 44 | # in multi-threaded environments. 45 | pass 46 | 47 | 48 | def _make_divisible(v, divisor, min_value=None): 49 | """ 50 | This function is taken from the original tf repo. 51 | It ensures that all layers have a channel number that is divisible by 8. 52 | It can be seen here: 53 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 54 | Args: 55 | v (float): the raw channel count to round. 56 | divisor (int): the result will be a multiple of this value. 57 | min_value (int, optional): lower bound for the result; defaults to `divisor`. 58 | Returns: the channel count rounded to the nearest multiple of `divisor`, never more than 10% below `v`. 59 | """ 60 | if min_value is None: 61 | min_value = divisor 62 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 63 | # Make sure that round down does not go down by more than 10%. 64 | if new_v < 0.9 * v: 65 | new_v += divisor 66 | return new_v 67 | 68 | 69 | def inverted_residual(in_filters,num_filters=64,strides=1,expansion = 4,name=''): 70 | mid_filters= int(round(in_filters * expansion)) 71 | layers=[] 72 | if expansion!=1 : 73 | layers.append(Conv2d_Block((1,1),num_filters=mid_filters,strides=1,auto_pad=True,padding_mode='zero',normalization='batch',activation='relu6',name=name + '_{0}_conv'.format(len(layers)))) 74 | 75 | layers.append(DepthwiseConv2d_Block((3, 3), depth_multiplier=1, strides=strides, auto_pad=True,padding_mode='zero', normalization='batch', activation='relu6', name=name + '_{0}_conv'.format(len(layers)))) 76 | layers.append(Conv2d_Block((1, 1), num_filters=num_filters, strides=1, auto_pad=False, padding_mode='zero', normalization='batch', activation=None, name=name + '_{0}_conv'.format(len(layers)))) 77 | if strides == 1 and in_filters==num_filters: 78 | return ShortCut2d(Sequential(*layers), Identity(), activation=None) 79 | else: 80 | return Sequential(*layers) 81 | 82
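Because `_make_divisible` drives every channel count in the MobileNet definitions (both the stem filters and each entry of `inverted_residual_setting` pass through it after scaling by `width_mult`), it helps to see the rounding rule in isolation. A standalone sketch of the same arithmetic, runnable without trident:

```python
def make_divisible(v, divisor, min_value=None):
    # Same rule as trident's _make_divisible: round to the nearest
    # multiple of `divisor`, but never end up more than 10% below v.
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

# width_mult scales the raw channel counts; rounding keeps them divisible by 8
for width_mult in (0.75, 1.0, 1.4):
    print([make_divisible(c * width_mult, 8) for c in (32, 96, 320)])
# → [24, 72, 240], [32, 96, 320], [48, 136, 448]
```

The `min_value` floor and the 10% bump matter mostly at small widths, e.g. `make_divisible(16, 12)` rounds down to 12, which is below `0.9 * 16`, so it is bumped to 24.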
| def MobileNet( input_shape=(224, 224,3), classes=1000, use_bias=False, width_mult=1.0,round_nearest=8, include_top=True, model_name='', 83 | **kwargs): 84 | input_filters = 32 85 | last_filters = 1280 86 | mobilenet=Sequential(name='mobilenet') 87 | inverted_residual_setting = [ 88 | # t, c, n, s 89 | [1, 16, 1, 1], [6, 24, 2, 2], [6, 32, 3, 2], [6, 64, 4, 2], [6, 96, 3, 1], [6, 160, 3, 2], [6, 320, 1, 1], 90 | ] 91 | input_filters = _make_divisible(input_filters * width_mult, round_nearest) 92 | last_filters = _make_divisible(last_filters * max(1.0, width_mult), round_nearest) 93 | features = [] 94 | features.append(Conv2d_Block((3,3),num_filters=input_filters,strides=2,auto_pad=True,padding_mode='zero',normalization='batch',activation='relu6',name='first_layer')) 95 | for t, c, n, s in inverted_residual_setting: 96 | output_filters = _make_divisible(c * width_mult, round_nearest) 97 | for i in range(n): 98 | strides = s if i == 0 else 1 99 | features.append( inverted_residual(input_filters,num_filters=output_filters, strides=strides, expansion=t,name='irb_{0}'.format(i))) 100 | input_filters = output_filters 101 | features.append(Conv2d_Block((1,1), last_filters,auto_pad=True,padding_mode='zero',normalization='batch',activation='relu6',name='last_layer')) 102 | mobilenet.add_module('features',Sequential(*features,name='features')) 103 | 104 | if include_top: 105 | mobilenet.add_module('gap', GlobalAvgPool2d()) 106 | mobilenet.add_module('drop', Dropout(0.2)) 107 | mobilenet.add_module('fc',Dense((classes),activation=None)) 108 | mobilenet.add_module('softmax', SoftMax(name='softmax')) 109 | model = ImageClassificationModel(input_shape=input_shape, output=mobilenet) 110 | 111 | if os.path.exists(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'imagenet_labels1.txt')): 112 | with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'imagenet_labels1.txt'), 'r', 113 | encoding='utf-8-sig') as f: 114 | labels = [l.rstrip() for l in f] 115 | 
model.class_names = labels 116 | model.preprocess_flow = [Resize((224, 224), keep_aspect=True), Normalize(127.5, 127.5)] 117 | # model.summary() 118 | return model 119 | 120 | 121 | 122 | 123 | def MobileNetV2(include_top=True, 124 | pretrained=True, 125 | input_shape=(224,224,3), 126 | classes=1000, 127 | **kwargs): 128 | if input_shape is not None and len(input_shape)==3: 129 | input_shape=tuple(input_shape) 130 | else: 131 | input_shape=(224, 224,3) 132 | mob =MobileNet(input_shape=(224, 224,3), classes=classes, use_bias=False, width_mult=1.0,round_nearest=8, include_top=include_top, model_name='mobilenet') 133 | if pretrained==True: 134 | download_model_from_google_drive('15LtLJHpvimV6cFGqAwJ4QALNEjeATrKe',dirname,'mobilenet_v2_tf.pth') 135 | recovery_model=fix_layer(load(os.path.join(dirname,'mobilenet_v2_tf.pth'))) 136 | recovery_model.eval() 137 | with tf.device(get_device()): 138 | if include_top==False: 139 | recovery_model.remove_at(-1) 140 | recovery_model.remove_at(-1) 141 | recovery_model.remove_at(-1) 142 | else: 143 | if classes!=1000: 144 | new_fc = Dense(classes, activation=None, name='fc') 145 | new_fc.input_shape=recovery_model.fc.input_shape 146 | recovery_model.fc=new_fc 147 | mob.model=recovery_model 148 | return mob -------------------------------------------------------------------------------- /trident/models/pytorch_mobilenet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import inspect 6 | import math 7 | import os 8 | import uuid 9 | from collections import * 10 | from collections import deque 11 | from copy import copy, deepcopy 12 | from functools import partial 13 | from itertools import repeat 14 | 15 | import numpy as np 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | from collections import abc 20 | from torch.nn import init 21 | 
from torch.nn.parameter import Parameter 22 | 23 | from trident.backend.common import * 24 | from trident.backend.pytorch_backend import to_numpy, to_tensor, Layer, Sequential, fix_layer, load,get_device 25 | from trident.data.image_common import * 26 | from trident.data.utils import download_model_from_google_drive 27 | from trident.layers.pytorch_activations import get_activation, Identity 28 | from trident.layers.pytorch_blocks import * 29 | from trident.layers.pytorch_layers import * 30 | from trident.layers.pytorch_normalizations import get_normalization 31 | from trident.layers.pytorch_pooling import * 32 | from trident.optims.pytorch_trainer import * 33 | from trident.models.pretrained_utils import _make_recovery_model_include_top 34 | from trident.data.vision_transforms import Resize,Normalize 35 | __all__ = ['MobileNet','MobileNetV2'] 36 | 37 | _session = get_session() 38 | _device =get_device() 39 | _epsilon=_session.epsilon 40 | _trident_dir=_session.trident_dir 41 | 42 | 43 | dirname = os.path.join(_trident_dir, 'models') 44 | if not os.path.exists(dirname): 45 | try: 46 | os.makedirs(dirname) 47 | except OSError: 48 | # Except permission denied and potential race conditions 49 | # in multi-threaded environments. 50 | pass 51 | 52 | 53 | def _make_divisible(v, divisor, min_value=None): 54 | """ 55 | This function is taken from the original tf repo. 56 | It ensures that all layers have a channel number that is divisible by 8. 57 | It can be seen here: 58 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 59 | Args: 60 | v (float): the raw channel count to round. 61 | divisor (int): the result will be a multiple of this value. 62 | min_value (int, optional): lower bound for the result; defaults to `divisor`. 63 | Returns: the channel count rounded to the nearest multiple of `divisor`, never more than 10% below `v`. 64 | """ 65 | if min_value is None: 66 | min_value = divisor 67 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 68 | # Make sure that round down does not go down by more than 10%.
69 | if new_v < 0.9 * v: 70 | new_v += divisor 71 | return new_v 72 | 73 | 74 | def inverted_residual(in_filters,num_filters=64,strides=1,expansion = 4,name=''): 75 | mid_filters= int(round(in_filters * expansion)) 76 | layers=[] 77 | if expansion!=1 : 78 | layers.append(Conv2d_Block((1,1),num_filters=mid_filters,strides=1,auto_pad=True,padding_mode='zero',normalization='batch',activation='relu6')) 79 | 80 | layers.append(DepthwiseConv2d_Block((3, 3), depth_multiplier=1, strides=strides, auto_pad=True,padding_mode='zero', normalization='batch', activation='relu6')) 81 | layers.append(Conv2d_Block((1, 1), num_filters=num_filters, strides=1, auto_pad=False, padding_mode='zero', normalization='batch', activation=None)) 82 | if strides == 1 and in_filters==num_filters: 83 | return ShortCut2d(Sequential(*layers), Identity(), activation=None) 84 | else: 85 | return Sequential(*layers) 86 | 87 | def MobileNet( input_shape=(3, 224, 224), classes=1000, use_bias=False, width_mult=1.0,round_nearest=8, include_top=True, model_name='', 88 | **kwargs): 89 | input_filters = 32 90 | last_filters = 1280 91 | mobilenet=Sequential(name='mobilenet') 92 | inverted_residual_setting = [ 93 | # t, c, n, s 94 | [1, 16, 1, 1], [6, 24, 2, 2], [6, 32, 3, 2], [6, 64, 4, 2], [6, 96, 3, 1], [6, 160, 3, 2], [6, 320, 1, 1], 95 | ] 96 | input_filters = _make_divisible(input_filters * width_mult, round_nearest) 97 | last_filters = _make_divisible(last_filters * max(1.0, width_mult), round_nearest) 98 | features = [] 99 | features.append(Conv2d_Block((3,3),num_filters=input_filters,strides=2,auto_pad=True,padding_mode='zero',normalization='batch',activation='relu6')) 100 | for t, c, n, s in inverted_residual_setting: 101 | output_filters = _make_divisible(c * width_mult, round_nearest) 102 | for i in range(n): 103 | strides = s if i == 0 else 1 104 | features.append( inverted_residual(input_filters,num_filters=output_filters, strides=strides, expansion=t)) 105 | input_filters = output_filters 106 | 
features.append(Conv2d_Block((1,1), last_filters,auto_pad=True,padding_mode='zero',normalization='batch',activation='relu6')) 107 | mobilenet.add_module('features',Sequential(*features,name='features')) 108 | 109 | if include_top: 110 | mobilenet.add_module('gap', GlobalAvgPool2d()) 111 | mobilenet.add_module('drop', Dropout(0.2)) 112 | mobilenet.add_module('fc',Dense((classes),activation=None)) 113 | mobilenet.add_module('softmax', SoftMax(name='softmax')) 114 | model = ImageClassificationModel(input_shape=input_shape, output=mobilenet) 115 | if os.path.exists(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'imagenet_labels1.txt')): 116 | with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'imagenet_labels1.txt'), 'r', 117 | encoding='utf-8-sig') as f: 118 | labels = [l.rstrip() for l in f] 119 | model.class_names = labels 120 | model.preprocess_flow = [Resize((224, 224), keep_aspect=True), Normalize(0, 255), 121 | Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])] 122 | # model.summary() 123 | return model 124 | 125 | 126 | 127 | 128 | def MobileNetV2(include_top=True, 129 | pretrained=True, 130 | freeze_features=True, 131 | input_shape=(3,224,224), 132 | classes=1000, 133 | **kwargs): 134 | if input_shape is not None and len(input_shape)==3: 135 | input_shape=tuple(input_shape) 136 | else: 137 | input_shape=(3, 224, 224) 138 | mob =MobileNet(input_shape=(3, 224, 224), classes=classes, use_bias=False, width_mult=1.0,round_nearest=8, include_top=include_top, model_name='mobilenet') 139 | if pretrained==True: 140 | download_model_from_google_drive('1ULenXTjOO5PdT3fHv6N8bPXEfoJAn5yL',dirname,'mobilenet_v2.pth') 141 | recovery_model=load(os.path.join(dirname,'mobilenet_v2.pth')) 142 | recovery_model = fix_layer(recovery_model) 143 | recovery_model = _make_recovery_model_include_top(recovery_model,input_shape=input_shape, include_top=include_top, classes=classes, freeze_features=freeze_features) 144 | mob.model = recovery_model 145 | 146 
| else: 147 | mob.model = _make_recovery_model_include_top(mob.model, include_top=include_top, classes=classes, freeze_features=True) 148 | 149 | mob.model.input_shape = input_shape 150 | mob.model.to(_device) 151 | return mob -------------------------------------------------------------------------------- /trident/loggers/mlflow_logger.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import operator 3 | import threading 4 | from abc import ABC, abstractmethod 5 | 6 | from trident.backend.common import import_or_install, get_session 7 | 8 | import logging 9 | import os 10 | import re 11 | from argparse import Namespace 12 | import time 13 | from typing import Any, Dict, Optional, Union 14 | 15 | from trident.loggers.logger import BaseLogger 16 | from trident.context import split_path, make_dir_if_need, sanitize_path 17 | ctx=get_session() 18 | if ctx.get_backend() == 'pytorch': 19 | import torch 20 | import torch.nn as nn 21 | from trident.backend.pytorch_backend import Tensor 22 | from trident.backend.pytorch_ops import * 23 | 24 | elif ctx.get_backend() == 'tensorflow': 25 | import tensorflow as tf 26 | from trident.backend.tensorflow_backend import Tensor 27 | from trident.backend.tensorflow_ops import * 28 | 29 | 30 | 31 | 32 | 33 | _MLFLOW_AVAILABLE =True 34 | try: 35 | import mlflow 36 | from mlflow.tracking import context, MlflowClient 37 | from mlflow.utils.mlflow_tags import MLFLOW_RUN_NAME 38 | 39 | except ImportError: 40 | try: 41 | import_or_install('mlflow','mlflow') 42 | import mlflow 43 | from mlflow.tracking import context, MlflowClient 44 | from mlflow.utils.mlflow_tags import MLFLOW_RUN_NAME 45 | except ImportError: 46 | _MLFLOW_AVAILABLE=False 47 | raise ImportError( 'You want to use `mlflow` logger which is not installed yet,' 48 | ' install it with `pip install mlflow`.') 49 | 50 | # before v1.1.0 51 | if hasattr(context, 'resolve_tags'): 52 | from mlflow.tracking.context import 
resolve_tags 53 | 54 | 55 | # since v1.1.0 56 | elif hasattr(context, 'registry'): 57 | from mlflow.tracking.context.registry import resolve_tags 58 | else: 59 | def resolve_tags(tags=None): 60 | return tags 61 | LOCAL_FILE_URI_PREFIX = "file:" 62 | 63 | 64 | 65 | 66 | class MLFlowLogger(object): 67 | """Logs entries directly to an MLflow tracking server so they can be 68 | browsed in the MLflow UI. 69 | 70 | The class mirrors part of the TensorBoard `SummaryWriter` API, so a 71 | training program can call the same methods (`add_scalar`, `add_image`, 72 | `add_figure`) to record metrics, images and figures to the active 73 | MLflow run directly from the training loop, without slowing down 74 | training. 75 | """ 76 | __singleton_lock = threading.Lock() 77 | __singleton_instance = None 78 | 79 | # define the classmethod 80 | @classmethod 81 | def instance(cls): 82 | 83 | # check for the singleton instance 84 | if not cls.__singleton_instance: 85 | with cls.__singleton_lock: 86 | if not cls.__singleton_instance: 87 | cls.__singleton_instance = cls() 88 | 89 | # return the singleton instance 90 | return cls.__singleton_instance 91 | def __init__(self,experiment_name="my-experiment"): 92 | mlflow.set_tracking_uri('http://{0}:{1}'.format(ctx.mlflow_server, 4040)) 93 | self.client= MlflowClient() 94 | make_dir_if_need('Log/images') 95 | self.run=None 96 | self.file_writer=None 97 | self.all_writers=None 98 | self.experiment_name=experiment_name 99 | experiments = self.client.list_experiments() 100 | experiment_ids=[e.experiment_id for e in experiments if e.name==self.experiment_name] 101 | if len(experiment_ids)>0: 102 | self.experiment_id =experiment_ids[-1] 103 | else: 104 | self.experiment_id = self.client.create_experiment(self.experiment_name,) 105 | 106 | def start_run(self,run_id=None,experiment_id=None,run_name=None): 107 | if experiment_id is None: 108 | experiment_id=self.experiment_id 109 | self.run=
self.client.create_run(experiment_id) 110 | mlflow.start_run(run_id=self.run.info.run_id,experiment_id=experiment_id,run_name=run_name) 111 | 112 | def add_hparams( 113 | self, hparam_dict, metric_dict, hparam_domain_discrete=None, run_name=None 114 | ): 115 | pass 116 | 117 | def add_scalar(self, tag, scalar_value, global_step=None, walltime=None): 118 | """Add scalar data to the active MLflow run. 119 | 120 | Args: 121 | tag (string): Data identifier 122 | scalar_value (float): Value to save 123 | global_step (int): Global step value to record 124 | walltime (float): Optional override default walltime (time.time()) 125 | with seconds after epoch of event 126 | 127 | Examples:: 128 | 129 | logger = MLFlowLogger.instance() 130 | logger.start_run() 131 | x = range(100) 132 | for i in x: 133 | logger.add_scalar('y=2x', i * 2, i) 134 | logger.close() 135 | 136 | Expected result: 137 | 138 | the metric `y=2x` appears as a 100-step curve in the MLflow run. 139 | 140 | 141 | """ 142 | self.client.log_metric(run_id=self.run.info.run_id,key=tag, value=to_scalar(scalar_value),timestamp=int((walltime if walltime is not None else time.time())*1000), step=global_step)  # mlflow timestamps are in milliseconds 143 | #log_metric(run_id: str, key: str, value: float, timestamp: Optional[int] = None, step: Optional[int] = None) → None 144 | 145 | 146 | def add_scalars(self, main_tag, tag_scalar_dict, global_step=None, walltime=None): 147 | pass 148 | 149 | def add_histogram(self, tag, values, global_step=None, bins='tensorflow', walltime=None): 150 | pass 151 | 152 | def add_histogram_raw(self, tag, min, max, num, sum, sum_squares, 153 | bucket_limits, bucket_counts, global_step=None, 154 | walltime=None): 155 | pass 156 | 157 | def add_image(self, img_path): 158 | 159 | self.client.log_image(run_id=self.run.info.run_id,image=image2array(img_path).astype(np.uint8),artifact_file='Log/images') 160 | 161 | 162 | def add_images(self, tag, img_tensor, global_step=None, walltime=None): 163 | pass 164 | 165 |
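The `instance()` classmethod above is a double-checked-locking singleton: the outer `if` skips the lock on the hot path once the logger exists, and the re-check inside the lock stops two racing threads from each constructing one. A minimal standalone sketch of the same pattern (generic names, no MLflow required):

```python
import threading

class Singleton:
    _lock = threading.Lock()
    _instance = None

    @classmethod
    def instance(cls):
        # Fast path: once the instance exists, no lock is taken.
        if cls._instance is None:
            with cls._lock:
                # Re-check inside the lock so two threads that both
                # passed the outer check cannot both construct one.
                if cls._instance is None:
                    cls._instance = cls()
        return cls._instance

a = Singleton.instance()
b = Singleton.instance()
assert a is b  # every caller shares the same object
```

This is why `MLFlowLogger.instance()` can be called from any callback without coordinating who creates the logger first.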
def add_figure(self, tag, figure, artifact_file=None): 166 | self.client.log_figure(run_id=self.run.info.run_id,figure=figure,artifact_file=artifact_file) 167 | #log_figure(run_id: str, figure: Union[matplotlib.figure.Figure, plotly.graph_objects.Figure], artifact_file: str) 168 | 169 | 170 | 171 | def add_onnx_graph(self, prototxt): 172 | pass 173 | 174 | 175 | def flush(self): 176 | """Flushes the event file to disk. 177 | Call this method to make sure that all pending events have been written to 178 | disk. 179 | """ 180 | if self.all_writers is None: 181 | return 182 | for writer in self.all_writers.values(): 183 | writer.flush() 184 | 185 | def close(self): 186 | if self.all_writers is None: 187 | return # ignore double close 188 | for writer in self.all_writers.values(): 189 | writer.flush() 190 | writer.close() 191 | self.file_writer = self.all_writers = None 192 | self.client.set_terminated(self.run.info.run_id,end_time=time.time()) 193 | 194 | def __enter__(self): 195 | return self 196 | 197 | def __exit__(self, exc_type, exc_val, exc_tb): 198 | self.close() -------------------------------------------------------------------------------- /trident/models/tensorflow_vgg.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import inspect 6 | import math 7 | import os 8 | import uuid 9 | from collections import * 10 | from collections import deque 11 | from copy import copy, deepcopy 12 | from functools import partial 13 | from itertools import repeat 14 | 15 | import numpy as np 16 | import tensorflow as tf 17 | from trident.backend.common import * 18 | from trident.backend.tensorspec import * 19 | from trident.backend.model import * 20 | from trident.backend.tensorflow_backend import to_numpy, to_tensor, Layer, Sequential,load,get_device 21 | from trident.data.image_common import * 22 | from 
trident.data.utils import download_model_from_google_drive 23 | from trident.layers.tensorflow_activations import get_activation, Identity, Relu 24 | from trident.layers.tensorflow_blocks import * 25 | from trident.layers.tensorflow_layers import * 26 | from trident.layers.tensorflow_normalizations import get_normalization 27 | from trident.layers.tensorflow_pooling import * 28 | from trident.optims.tensorflow_trainer import ImageClassificationModel 29 | from trident.data.vision_transforms import Resize,Normalize 30 | __all__ = ['VGG19','VGG11','VGG13','VGG16'] 31 | 32 | _session = get_session() 33 | 34 | _epsilon=_session.epsilon 35 | _trident_dir=_session.trident_dir 36 | 37 | 38 | dirname = os.path.join(_trident_dir, 'models') 39 | if not os.path.exists(dirname): 40 | try: 41 | os.makedirs(dirname) 42 | except OSError: 43 | # Except permission denied and potential race conditions 44 | # in multi-threaded environments. 45 | pass 46 | 47 | 48 | 49 | 50 | cfgs = { 51 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 52 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 53 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 54 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 55 | } 56 | 57 | 58 | def make_vgg_layers(cfg, num_classes=1000,input_shape=(224,224,3),include_top=True): 59 | layers = [] 60 | in_channels = 3 61 | block=1 62 | conv=1 63 | vgg=Sequential() 64 | for v in cfg: 65 | if v == 'M': 66 | vgg.add_module('block{0}_pool'.format(block),MaxPool2d(kernel_size=2, strides=2,use_bias=True,name='block{0}_pool'.format(block))) 67 | block += 1 68 | conv = 1 69 | else: 70 | if len(vgg)==0: 71 | vgg.add_module('block{0}_conv{1}'.format(block,conv),Conv2d((3,3),v,auto_pad=True,activation=None,use_bias=True,name='block{0}_conv{1}'.format(block,conv))) 72 | else: 73 | 
vgg.add_module('block{0}_conv{1}'.format(block, conv), Conv2d((3, 3), v, auto_pad=True, activation=None, use_bias=True,name='block{0}_conv{1}'.format(block, conv))) 74 | 75 | vgg.add_module('block{0}_relu{1}'.format(block, conv),Relu(name='block{0}_relu{1}'.format(block, conv))) 76 | conv+=1 77 | in_channels = v 78 | if include_top: 79 | vgg.add_module('flattened', Flatten()) 80 | vgg.add_module('fc1',Dense(4096,use_bias=True, activation='relu')) 81 | vgg.add_module('drop1', Dropout(0.5)) 82 | vgg.add_module('fc2', Dense(4096, use_bias=True,activation='relu')) 83 | vgg.add_module('drop2', Dropout(0.5)) 84 | vgg.add_module('fc3', Dense(num_classes,use_bias=True,activation='softmax')) 85 | 86 | 87 | model = ImageClassificationModel(input_shape=input_shape, output=vgg) 88 | if os.path.exists(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'imagenet_labels1.txt')): 89 | with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'imagenet_labels1.txt'), 'r', 90 | encoding='utf-8-sig') as f: 91 | labels = [l.rstrip() for l in f] 92 | model.class_names = labels 93 | model.preprocess_flow = [Resize((input_shape[0], input_shape[1]), keep_aspect=True),to_bgr(), 94 | Normalize([103.939, 116.779, 123.68], [1, 1, 1])] 95 | 96 | 97 | # model.summary() 98 | 99 | return model 100 | 101 | 102 | #vgg11 =make_vgg_layers(cfgs['A'], 1000) 103 | def VGG11(include_top=True, 104 | pretrained=True, 105 | input_shape=None, 106 | classes=1000, 107 | **kwargs): 108 | if input_shape is not None and len(input_shape)==3: 109 | input_shape=tuple(input_shape) 110 | else: 111 | input_shape=(224, 224,3) 112 | vgg11 =make_vgg_layers(cfgs['A'], classes) 113 | vgg11.input_shape =input_shape 114 | if pretrained==True: 115 | print('There is no pretrained Vgg11 in tensorflow backend') 116 | return vgg11 117 | 118 | 119 | 120 | 121 | 122 | #vgg13 =make_vgg_layers(cfgs['B'], 1000) 123 | def VGG13(include_top=True, 124 | pretrained=True, 125 | input_shape=None, 126 | classes=1000, 127 | 
**kwargs): 128 | if input_shape is not None and len(input_shape)==3: 129 | input_shape=tuple(input_shape) 130 | else: 131 | input_shape=(224, 224, 3) 132 | vgg13 =make_vgg_layers(cfgs['B'], classes) 133 | vgg13.input_shape =input_shape 134 | if pretrained==True: 135 | print('There is no pretrained Vgg13 in tensorflow backend') 136 | return vgg13 137 | 138 | 139 | #vgg16 =make_vgg_layers(cfgs['D'], 1000) 140 | def VGG16(include_top=True, 141 | pretrained=True, 142 | input_shape=None, 143 | classes=1000, 144 | **kwargs): 145 | if input_shape is not None and len(input_shape)==3: 146 | input_shape=tuple(input_shape) 147 | else: 148 | input_shape=(224, 224,3) 149 | vgg16 =make_vgg_layers(cfgs['D'], classes) 150 | vgg16.input_shape =input_shape 151 | if pretrained==True: 152 | download_model_from_google_drive('1fozCY4Yv_ud5UGpv7q4M9tcxZ2ryDCTb',dirname,'vgg16_tf.pth') 153 | recovery_model=load(os.path.join(dirname,'vgg16_tf.pth')) 154 | recovery_model.name = 'vgg16' 155 | recovery_model.eval() 156 | with tf.device(get_device()): 157 | if include_top==False: 158 | [recovery_model.remove_at(-1) for i in range(7)] 159 | else: 160 | if classes!=1000: 161 | recovery_model.remove_at(-1) 162 | recovery_model.add_module('fc3', Dense(classes, activation=None, name='fc3')) 163 | recovery_model.add_module('softmax', SoftMax(name='softmax')) 164 | 165 | vgg16.model=recovery_model 166 | return vgg16 167 | 168 | #vgg19 =make_vgg_layers(cfgs['E'], 1000) 169 | def VGG19(include_top=True, 170 | pretrained=True, 171 | input_shape=None, 172 | classes=1000, 173 | **kwargs): 174 | if input_shape is not None and len(input_shape)==3: 175 | input_shape=tuple(input_shape) 176 | else: 177 | input_shape=(224, 224, 3) 178 | vgg19 =make_vgg_layers(cfgs['E'], classes) 179 | vgg19.input_shape =input_shape 180 | if pretrained==True: 181 | download_model_from_google_drive('1nXKMsYklBimtqs7ZRv0dQ-RIqNvgopVh',dirname,'vgg19_tf.pth') 182 | recovery_model=load(os.path.join(dirname,'vgg19_tf.pth')) 183 | recovery_model.name = 'vgg19' 184 |
recovery_model.eval() 185 | with tf.device(get_device()): 186 | if include_top==False: 187 | [recovery_model.remove_at(-1) for i in range(7)] 188 | else: 189 | if classes!=1000: 190 | recovery_model.remove_at(-1) 191 | recovery_model.add_module('fc3', Dense(classes, activation=None, name='fc3')) 192 | recovery_model.add_module('softmax', SoftMax(name='softmax')) 193 | 194 | vgg19.model=recovery_model 195 | return vgg19 -------------------------------------------------------------------------------- /trident/models/tensorflow_resnet.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import cv2 4 | os.environ['TRIDENT_BACKEND'] = 'tensorflow' 5 | #!pip install tridentx --upgrade 6 | import trident as T 7 | from trident import * 8 | import tensorflow as tf 9 | from tensorflow.keras.applications.resnet import ResNet50 10 | 11 | from trident.models import resnet 12 | 13 | res50=resnet.ResNet50(pretrained=False,input_shape=(224,224,3),classes=1000) 14 | #res50.save_model('resnet50.pth.tar') 15 | #tf.saved_model.save(res50.model.state_dict(),'resnet50') 16 | 17 | 18 | 19 | 20 | res50_old=ResNet50(include_top=True,input_shape=(224,224,3),weights='imagenet',classes=1000) 21 | 22 | 23 | layers_old=[m for m in res50_old.layers if len(list(m.variables))>0] 24 | # for i in range(len(layers_old)): 25 | # layer=layers_old[i] 26 | # if '_0_conv' in layer.name: 27 | # if '_0_bn' in layers_old[i+2].name: 28 | # bn=layers_old.pop(i+2) 29 | # conv=layers_old.pop(i) 30 | # layers_old.insert(i-4,conv) 31 | # layers_old.insert(i - 3, bn) 32 | # 33 | 34 | print(len(layers_old)) 35 | 36 | 37 | layers_old_dict=OrderedDict() 38 | for layer in layers_old: 39 | layers_old_dict[layer.name]=str(layer.get_config()) 40 | 41 | def is_same_layer(m_new,m_old): 42 | if 'Conv' in m_new.__class__.__name__ and 'Conv' in m_old.__class__.__name__: 43 | return True 44 | elif 'Batch' in m_new.__class__.__name__ and 'Batch' in m_old.__class__.__name__: 
45 | return True 46 | elif 'Dense' in m_new.__class__.__name__ and 'Dense' in m_old.__class__.__name__: 47 | return True 48 | else: 49 | return False 50 | 51 | 52 | 53 | layers_new=[m for m in res50.model.named_modules() if len(m[1]._parameters)+len(m[1]._buffers)>0] 54 | layers_new_dict=OrderedDict() 55 | for k,v in layers_new: 56 | layers_new_dict[k]=str(v) 57 | 58 | print(len(layers_new)) 59 | weights_new_dict=OrderedDict() 60 | mappings=OrderedDict() 61 | 62 | # for i in range(len(layers_new)): 63 | # m_new=layers_new[i][1] 64 | # m_old = layers_old[i] 65 | # if is_same_layer(m_new,m_old): 66 | # pass 67 | # else: 68 | # if is_same_layer(m_new, layers_old[i+1]) and is_same_layer(m_old, layers_new[i+1][1]) : 69 | # item= layers_old.pop(i) 70 | # layers_old.insert(i+1,item) 71 | # m_old = layers_old[i] 72 | # if is_same_layer(m_new, m_old): 73 | # pass 74 | # else: 75 | # print('') 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | for item,m_old in zip(layers_new,layers_old): 84 | name_new,m_new =item 85 | for k,v in m_new._parameters.item_list: 86 | key=name_new + '/' + k+':0' 87 | mappings[key]=None 88 | for k,v in m_new._buffers.item_list: 89 | key=name_new + '/' + k+':0' 90 | mappings[key]=None 91 | if 'Conv' in m_new.__class__.__name__ and 'Conv' in m_old.__class__.__name__: 92 | if m_new.get_weights()[0].shape==m_old.get_weights()[0].shape: 93 | mappings[name_new+'/'+m_new.weight.name]= m_old.get_weights()[0] 94 | 95 | mappings[name_new+'/'+m_new.bias.name] = m_old.get_weights()[1] 96 | else: 97 | mappings[name_new+'/'+m_new.weight.name] = None 98 | mappings[name_new+'/'+m_new.bias.name] = None 99 | elif 'Batch' in m_new.__class__.__name__ and 'Batch' in m_old.__class__.__name__: 100 | for w in m_old.weights: 101 | if 'gamma' in w.name and m_new.weight.shape==w.shape: 102 | mappings[name_new+'/'+m_new.weight.name] = to_numpy(w) 103 | elif 'beta' in w.name and m_new.bias.shape==w.shape: 104 | mappings[name_new+'/'+m_new.bias.name] =to_numpy(w) 105 | elif 'mean' in 
w.name and m_new.running_mean.shape==w.shape: 106 | mappings[name_new+'/'+m_new.running_mean.name] =to_numpy(w) 107 | elif 'variance' in w.name and m_new.running_var.shape==w.shape: 108 | mappings[name_new+'/'+m_new.running_var.name] = to_numpy(w) 109 | elif 'gamma' in w.name or 'beta' in w.name or 'mean' in w.name or 'variance' in w.name: 110 | mappings[name_new+'/'+m_new.weight.name] = None 111 | mappings[name_new + '/' + m_new.bias.name] = None 112 | mappings[name_new + '/' + m_new.running_mean.name] = None 113 | mappings[name_new + '/' + m_new.running_var.name] = None 114 | else: 115 | print(w.name) 116 | 117 | 118 | elif 'Dense' in m_new.__class__.__name__ and 'Dense' in m_old.__class__.__name__: 119 | if m_new.get_weights()[0].shape==m_old.get_weights()[0].shape: 120 | mappings[name_new+'/'+m_new.weight.name] = m_old.get_weights()[0] 121 | mappings[name_new+'/'+m_new.bias.name] = m_old.get_weights()[1] 122 | else: 123 | mappings[name_new+'/'+m_new.weight.name] = None 124 | mappings[name_new+'/'+m_new.bias.name] = None 125 | else: 126 | print(m_old) 127 | 128 | print(len(mappings)) 129 | nn=0 130 | for name,module in res50.model.named_modules(): 131 | 132 | for k,v in module._parameters.item_list: 133 | key=name + '/' + k+':0' 134 | if key in mappings and mappings[key] is not None: 135 | if v.shape==mappings[key].shape: 136 | #v.assign(tf.Variable(mappings[key])) 137 | v.assign(tf.Variable(np.reshape(mappings[key], (v.shape.as_list())).astype(np.float32))) 138 | nn+=1 139 | for k,v in module._buffers.item_list: 140 | key=name + '/' + k+':0' 141 | if key in mappings and mappings[key] is not None: 142 | if v.shape == mappings[key].shape: 143 | #v.assign(tf.constant(mappings[key])) 144 | v.assign(tf.Variable(np.reshape( mappings[key],(v.shape.as_list())).astype(np.float32))) 145 | nn+=1 146 | 147 | 148 | print(res50.infer_single_image('cat.jpg',3)) 149 | 150 | 151 | 152 | # for name,module in res50.model.named_modules(): 153 | # for k,v in module._parameters.item_list: 154
| # key=name + '/' + k+':0' 155 | # if key in mappings and mappings[key] is not None: 156 | # if v.shape==mappings[key].shape: 157 | # if (v!=mappings[key]).numpy().any(): 158 | # print(key) 159 | # for k,v in module._buffers.item_list: 160 | # key=name + '/' + k+':0' 161 | # if key in mappings and mappings[key] is not None: 162 | # if v.shape == mappings[key].shape: 163 | # if (v != mappings[key]).numpy().any(): 164 | # print(key) 165 | #cat_array=np.expand_dims((resize((224,224),True)(image2array('cat.jpg'))-127.5)/127.5,0) 166 | #print('cat',cat_array.shape) 167 | # new_list=list(res50.model.modules()) 168 | # conv1_new=new_list[2] 169 | # bn_new=new_list[3] 170 | # 171 | # padding_old=res50_old.layers[1] 172 | # conv1_old=res50_old.layers[2] 173 | # bn_old=res50_old.layers[3] 174 | # print('is weight the same? {0}'.format(np.array_equal(conv1_new.get_weights()[0],conv1_old.get_weights()[0]))) 175 | # print('is bias the same? {0}'.format(np.array_equal(conv1_new.get_weights()[1],conv1_old.get_weights()[1]))) 176 | # conv1_weights_new=conv1_new.get_weights() 177 | # conv1_weights_old=conv1_old.get_weights() 178 | # print(np.abs(conv1_weights_new[0]-conv1_weights_old[0]).sum()) 179 | # print(np.abs(conv1_weights_new[1]-conv1_weights_old[1]).sum()) 180 | # 181 | # results_new_conv1=bn_new(conv1_new(to_tensor(cat_array.copy())))[0].numpy() 182 | # results_old_conv1=bn_old(conv1_old(padding_old(to_tensor(cat_array.copy()))))[0].numpy() 183 | # is_ok=np.array_equal(results_new_conv1,results_old_conv1) 184 | # print('is result the same? 
{0}'.format(is_ok)) 185 | # print(results_new_conv1.shape,results_old_conv1.shape) 186 | # print(results_new_conv1.mean(),results_old_conv1.mean()) 187 | # print(results_new_conv1[0,0,:].mean(),results_old_conv1[0,0,:].mean()) 188 | # print(results_new_conv1[-1,-1,:].mean(),results_old_conv1[-1,-1,:].mean()) 189 | # if not is_ok: 190 | # print(results_new_conv1[0,0,:]-results_old_conv1[0,0,:]) 191 | 192 | 193 | 194 | res50.model.eval() 195 | print(res50.infer_single_image('cat.jpg',5)) 196 | save(res50.model,'resnet50_tf.pth') 197 | res50.save_model('resnet50_tf.pth.tar') 198 | 199 | print('finish') 200 | 201 | 202 | 203 | 204 | 205 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![trident](trident_logo.png) 2 | **Make PyTorch and TensorFlow two become one.** 3 | 4 | 5 | | version | pytorch | tensorflow | 6 | | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------ | 7 | | [![version](https://img.shields.io/static/v1?label=&message=0.7.6&color=377EF0&style=for-the-badge)](https://img.shields.io/static/v1?label=&message=0.7.5&color=377EF0&style=for-the-badge) | ![pytorch](https://img.shields.io/static/v1?label=&message=>1.4&color=377EF0&style=for-the-badge) | ![tensorflow](https://img.shields.io/static/v1?label=&message=>2.2.0&color=377EF0&style=for-the-badge) | 8 | 9 | **Trident** is a deep learning dynamic calculation graph api based on PyTorch and TensorFlow (pure Eager mode, no Keras dependency). 
Through Trident you can not only share the same developer experience (more than 99% identical code) across PyTorch and TensorFlow; it is also designed to simplify a deep learning developer's routine work. Its functions cover computer vision, natural language understanding, and reinforcement learning, and also include simpler network-structure declaration, more powerful yet easier training-process control, and intuitive data access and data augmentation. 10 | 11 | 12 | 13 | ## Key Features 14 | 15 | - Integrated PyTorch and TensorFlow experience (ops operations, neural network structure declaration, loss and metric function calls, ...) 16 | - Automatically transposes tensor layout according to the backend (PyTorch: CHW, TensorFlow: HWC) 17 | - A single block primitive serves multiple modeling needs; for example, Conv2d_Block integrates five functions (convolution, normalization, activation, dropout, and noise), and layer fusion can be performed inside a block. 18 | - Padding can be calculated automatically via Autopad when designing layers; even in PyTorch, shape inference can be deferred, and Summary shows the model structure and compute consumption. 19 | - Rich built-in visualizations, metrics, and internal information that can be inserted into a training plan. 20 | - A TrainingPlan can be assembled as flexibly as building blocks to design the training process you want, with a fluent-style syntax that keeps the code readable and manageable. 21 | - Provides the latest optimizers (Ranger, Lars, RangerLars, AdaBelief, ...) and optimization techniques (gradient centralization). 22 |
23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | ## New Release version 0.7.4 33 | 34 | - Experimental: Keras models (TensorFlow backend) and native PyTorch models (PyTorch backend) supported in TrainingPlan 35 | - print_gpu_utilization in TrainingPlan 36 | - Experimental: layer fusion (conv + norm => conv) in ConvXd_Block and FullConnect_Block 37 | - Experimental: automatic in-place Relu and LeakyRelu, switched back to False when detected in a leaf layer. 38 | - Experimental: MLFlow support 39 | - New optimizers: LAMB, Ranger_AdaBelief 40 | - Rewrote many loss functions. 41 | - List[String] (image paths) as output of ImageDataset 42 | - More stable and reliable. 43 | 44 | 45 | ## New Release version 0.7.3 46 | 47 | ![Alt text](images/text_process.png) 48 | 49 | - New with_accumulate_grad for gradient accumulation. 50 | - Enhancements for TextSequenceDataset and TextSequenceDataprovider. 51 | - New TextTransforms: RandomMask, BopomofoConvert, ChineseConvert, RandomHomophonicTypo, RandomHomomorphicTypo 52 | - New VisionTransforms: ImageMosaic, SaltPepperNoise 53 | - Transformer, Bert, and Vit support in the PyTorch backend. 54 | - New layers and blocks: FullConnect_Block, TemporalConv1d_Block 55 | - Differentiable color-space conversion functions: rgb2hsv, rgb2xyz, rgb2lab, ... 56 | - Enhancements for GANBuilder; conditional GANs and skip-connection networks are now supported. 57 | - LSTM supports attention in the PyTorch backend, and LSTM is now available in the TensorFlow backend.
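Several of the optimizers above support gradient centralization, which the quick-start example later in this README enables via `gradient_centralization='all'`. Stripped of any framework, the operation simply subtracts, from each output-channel slice of a weight gradient, that slice's mean. A minimal stand-alone sketch (the function name is hypothetical; trident's own implementation lives inside its optimizers and may differ):

```python
def centralize_gradient(grad):
    """Gradient centralization sketch.

    grad: a weight gradient given as a list of rows, one row per
    output channel. Returns a centralized copy in which each row
    has had its own mean subtracted (so each row sums to zero).
    """
    centralized = []
    for row in grad:
        mean = sum(row) / len(row)          # per-channel mean
        centralized.append([g - mean for g in row])
    return centralized


# Example: a 2-channel gradient; each centralized row becomes zero-mean.
grad = [[1.0, 2.0, 3.0], [4.0, 4.0, 4.0]]
print(centralize_gradient(grad))
```

Keeping each channel's gradient zero-mean in this way has been reported to stabilize and speed up training, which is why it ships as an optimizer option rather than a separate layer.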
58 | 59 | ## New Release version 0.7.1 60 | 61 | ![Alt text](images/vision_transform.png) 62 | 63 | - New vision transforms. 64 | 65 | ## New Release version 0.7.0 66 | 67 | ![Alt text](images/tensorboard.png) 68 | 69 | - Tensorboard support. 70 | - New optimizers: AdaBelief, DiffGrad 71 | - Initializers support. 72 | 73 | ## How To Use 74 | 75 | #### Step 0: Install 76 | 77 | Simple installation from PyPI: 78 | 79 | ```bash 80 | pip install tridentx --upgrade 81 | ``` 82 | 83 | #### Step 1: Add these imports 84 | 85 | ```python 86 | import os 87 | os.environ['TRIDENT_BACKEND'] = 'pytorch' 88 | import trident as T 89 | from trident import * 90 | from trident.models.pytorch_densenet import DenseNetFcn 91 | ``` 92 | 93 | #### Step 2: A simple case in both PyTorch and TensorFlow 94 | 95 | ```python 96 | data_provider=load_examples_data('dogs-vs-cats') 97 | data_provider.image_transform_funcs=[ 98 | random_rescale_crop(224,224,scale=(0.9,1.1)), 99 | random_adjust_gamma(gamma=(0.9,1.1)), 100 | normalize(0,255), 101 | normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])] 102 | 103 | model=resnet.ResNet50(include_top=True,pretrained=True,freeze_features=True,classes=2)\ 104 | .with_optimizer(optimizer=Ranger,lr=1e-3,betas=(0.9, 0.999),gradient_centralization='all')\ 105 | .with_loss(CrossEntropyLoss)\ 106 | .with_metric(accuracy,name='accuracy')\ 107 | .unfreeze_model_scheduling(200,'batch',5,None) \ 108 | .unfreeze_model_scheduling(1, 'epoch', 4, None) \ 109 | .summary() 110 | 111 | plan=TrainingPlan()\ 112 | .add_training_item(model)\ 113 | .with_data_loader(data_provider)\ 114 | .repeat_epochs(10)\ 115 | .within_minibatch_size(32)\ 116 | .print_progress_scheduling(10,unit='batch')\ 117 | .display_loss_metric_curve_scheduling(200,'batch')\ 118 | .print_gradients_scheduling(200,'batch')\ 119 | .start_now() 120 | ``` 121 | 122 | #### Step 3: Examples 123 | 124 | - mnist classification
[pytorch](https://github.com/AllanYiin/DeepBelief_Course5_Examples/blob/master/epoch001_%E5%8F%A6%E4%B8%80%E7%A8%AE%E8%A7%92%E5%BA%A6%E7%9C%8Bmnist/HelloWorld_mnist_pytorch.ipynb) [tensorflow](https://github.com/AllanYiin/DeepBelief_Course5_Examples/blob/master/epoch001_%E5%8F%A6%E4%B8%80%E7%A8%AE%E8%A7%92%E5%BA%A6%E7%9C%8Bmnist/HelloWorld_mnist_tf.ipynb) 125 | - activation function [pytorch](https://github.com/AllanYiin/DeepBelief_Course5_Examples/blob/master/epoch002_%E6%B4%BB%E5%8C%96%E5%87%BD%E6%95%B8%E5%A4%A7%E6%B8%85%E9%BB%9E/%20Activation_Function_AllStar_Pytorch.ipynb) [tensorflow](https://github.com/AllanYiin/DeepBelief_Course5_Examples/blob/master/epoch002_%E6%B4%BB%E5%8C%96%E5%87%BD%E6%95%B8%E5%A4%A7%E6%B8%85%E9%BB%9E/Activation_Function_AllStar_tf.ipynb) 126 | - auto-encoder [pytorch](https://github.com/AllanYiin/DeepBelief_Course5_Examples/blob/master/epoch003_%E8%87%AA%E5%8B%95%E5%AF%B6%E5%8F%AF%E5%A4%A2%E7%B7%A8%E7%A2%BC%E5%99%A8/Pokemon_Autoencoder_pytorch.ipynb) [tensorflow](https://github.com/AllanYiin/DeepBelief_Course5_Examples/blob/master/epoch003_%E8%87%AA%E5%8B%95%E5%AF%B6%E5%8F%AF%E5%A4%A2%E7%B7%A8%E7%A2%BC%E5%99%A8/Pokemon_Autoencoder_tf.ipynb) 127 | 128 | ## BibTeX 129 | 130 | If you want to cite the framework feel free to use this: 131 | 132 | ```bibtex 133 | @article{AllanYiin2020Trident, 134 | title={Trident}, 135 | author={AllanYiin, Taiwan}, 136 | journal={GitHub. Note: https://github.com/AllanYiin/trident}, 137 | volume={1}, 138 | year={2020} 139 | } 140 | ``` 141 | --------------------------------------------------------------------------------
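As an aside on the weight-porting script `trident/models/tensorflow_resnet.py` included above: it pairs the layers of the trident model and the Keras model positionally and copies weights only when the shapes agree. Stripped of both frameworks, the matching rule amounts to something like the sketch below (`match_weights` is a hypothetical helper written for illustration, not part of trident):

```python
def match_weights(new_layers, old_layers):
    """Pair layers positionally; keep a mapping only when shapes agree.

    new_layers / old_layers: lists of (name, shape) tuples, in model order.
    Returns a dict mapping each new layer name to the matching old layer
    name, or to None when the shapes differ (mirroring the script, which
    leaves mismatched layers at their fresh initialization).
    """
    mapping = {}
    for (new_name, new_shape), (old_name, old_shape) in zip(new_layers, old_layers):
        mapping[new_name] = old_name if new_shape == old_shape else None
    return mapping


# Illustrative layer lists: the conv shapes agree, the dense shapes do not.
new = [('conv1/weight', (7, 7, 3, 64)), ('fc/weight', (2048, 1000))]
old = [('conv1_conv', (7, 7, 3, 64)), ('predictions', (512, 1000))]
print(match_weights(new, old))
# {'conv1/weight': 'conv1_conv', 'fc/weight': None}
```

In the real script the `None` entries are exactly the layers that later get skipped during assignment, which is why the final sanity check reruns inference on `cat.jpg` to confirm the ported weights behave like the originals.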