├── Models
├── __init__.py
├── blur
│ ├── __init__.py
│ └── model.py
├── common
│ ├── __init__.py
│ ├── model_builder.py
│ └── util.py
├── mrcnn
│ ├── __init__.py
│ ├── utils.py
│ ├── model.py
│ └── vis.py
├── classTemplateTF
│ ├── __init__.py
│ ├── README.md
│ ├── model.py
│ └── train_classification.py
├── trainingTemplateTF
│ ├── __init__.py
│ ├── data
│ │ └── train
│ │   ├── input
│ │   │ └── alive_snow00001.png
│ │   └── groundtruth
│ │     └── alive00001.png
│ ├── model.py
│ ├── README.md
│ └── train_model.py
├── regressionTemplateTF
│ ├── __init__.py
│ ├── README.md
│ ├── model.py
│ └── train_regression.py
└── baseModel.py
├── Plugins
├── Server
│ ├── __init__.py
│ ├── .dockerignore
│ ├── py2.Dockerfile
│ ├── Dockerfile
│ └── server.py
└── Client
  ├── CMakeLists.txt
  ├── message.proto
  ├── MLClientModelManager.h
  ├── MLClientComms.h
  ├── MLClient.h
  └── MLClientModelManager.cpp
├── .gitignore
├── CMakeLists.txt
├── README.md
├── LICENSE
└── INSTALL.md
/Models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/Models/blur/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/Models/common/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/Models/mrcnn/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/Plugins/Server/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/Models/classTemplateTF/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/Models/trainingTemplateTF/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/Models/regressionTemplateTF/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/Plugins/Server/.dockerignore:
--------------------------------------------------------------------------------
1 | # Ignore when building the docker image
2 | .dockerignore
3 | Dockerfile
--------------------------------------------------------------------------------
/Models/trainingTemplateTF/data/train/input/alive_snow00001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TheFoundryVisionmongers/nuke-ML-server/HEAD/Models/trainingTemplateTF/data/train/input/alive_snow00001.png
--------------------------------------------------------------------------------
/Models/trainingTemplateTF/data/train/groundtruth/alive00001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TheFoundryVisionmongers/nuke-ML-server/HEAD/Models/trainingTemplateTF/data/train/groundtruth/alive00001.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore
2 | .vscode/
3 | build/
4 | build-Debug/
5 |
6 | # Ignore Python compiled files
7 | *.py[co]
8 | # Ignore configuration and weights files
9 | *.yaml
10 | *.pkl
11 |
12 | # Ignore shared objects
13 | *.so
14 |
15 | # Ignore generated files
16 | *.os
17 | *.o
18 |
19 | # Ignore all directories named:
20 | summaries/
21 | input/
22 | groundtruth/
23 | checkpoints/
24 | data/
25 | serverlocal/
26 | densepose/
--------------------------------------------------------------------------------
/Models/classTemplateTF/README.md:
--------------------------------------------------------------------------------
1 | # Classification Training Template
2 |
3 | The classTemplateTF is a training template written in TensorFlow. It aims at quickly enabling classification training. For instance, detecting the presence of a specific actor in a shot. When trained, the model can be tested and used directly in Nuke through the nuke-ML-server.
4 |
5 | Apart from the dataset structure, all other instructions are similar to the other [Training Template](https://github.com/TheFoundryVisionmongers/nuke-ML-server/tree/master/Models/trainingTemplateTF).
6 |
7 | ## Dataset
8 |
9 | To train the ML algorithm, you need to set up your dataset in `classTemplateTF/data/train/`. This directory should contain one subdirectory per class. Any PNG, JPG, BMP, PPM or TIF images inside each of the subdirectories will be included.
10 |
11 | For example, if you want to train a classifier to differentiate between cats, dogs and foxes, the `data/train/` directory should have 3 subdirectories named `cats`, `dogs` and `foxes`, each containing images of the corresponding animal.
12 |
13 | Optionally, you can add a separate set of images in `classTemplateTF/data/validation/`. If available, it is periodically used to check that there is no overfitting on the training data. Please note that the validation dataset and training dataset must not intersect.
14 |
15 | If no validation dataset is found, 20% of the training data will be used as a validation split.
16 |
17 |
--------------------------------------------------------------------------------
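The class-per-subdirectory layout described in the README above can be sanity-checked with a few lines of Python. This is an illustrative sketch only; the `data/train/` path and the listing loop are assumptions, not part of the template:

```python
import os

# Hypothetical path: adjust to wherever classTemplateTF/data/train/ lives.
train_dir = 'Models/classTemplateTF/data/train'
extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.ppm', '.tif', '.tiff')

# One subdirectory per class, each containing the images of that class.
classes = sorted(d for d in os.listdir(train_dir)
                 if os.path.isdir(os.path.join(train_dir, d)))
for cls in classes:
    images = [f for f in os.listdir(os.path.join(train_dir, cls))
              if f.lower().endswith(extensions)]
    print('{}: {} images'.format(cls, len(images)))
```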
/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.10)
2 | project(MachineLearningPlugins VERSION 1.0.0)
3 |
4 | #===------------------------------------------------------------------------===
5 | # Global settings, some based on the external configuration settings
6 | set( CMAKE_CXX_STANDARD 11 )
7 | set( CMAKE_CXX_EXTENSIONS OFF )
8 | set( CMAKE_CXX_VISIBILITY_PRESET hidden )
9 | set( CMAKE_POSITION_INDEPENDENT_CODE True )
10 | if (CMAKE_CXX_COMPILER_ID MATCHES "Clang")
11 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fcolor-diagnostics")
12 | endif()
13 |
14 | #===------------------------------------------------------------------------===
15 | # Build information
16 | string( TIMESTAMP BUILDDATE_YEAR_INTERNAL "%Y" )
17 | string( TIMESTAMP BUILDDATE_MMDD_INTERNAL "%m%d" )
18 | string( TIMESTAMP BUILDDATE_FULL_INTERNAL "%Y-%m-%dT%H:%M:%S" )
19 | string( TIMESTAMP BUILDDATE_STAMP "%Y.%m%d" )
20 | string( REGEX REPLACE "^0" "" BUILDDATE_MMDD_INTERNAL ${BUILDDATE_MMDD_INTERNAL} )
21 | set( BUILDDATE_YEAR "${BUILDDATE_YEAR_INTERNAL}" CACHE STRING "Year of the build: It will default to the current year." )
22 | set( BUILDDATE_MMDD "${BUILDDATE_MMDD_INTERNAL}" CACHE STRING "Month and day of the build: It will default to the calendar month and day." )
23 | set( BUILDDATE_FULL "${BUILDDATE_FULL_INTERNAL}" CACHE STRING "Exact time of the build." )
24 |
25 | #===------------------------------------------------------------------------===
26 | # Compile CMakeLists found in subdirectories
27 | add_subdirectory(Plugins/Client)
--------------------------------------------------------------------------------
/Plugins/Client/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | # CMakeLists.txt for Machine Learning Plug-in: MLClient
2 |
3 | # Setting up MLClient sources and dependencies
4 | set (ML_CLIENT_SOURCES
5 | MLClient.cpp
6 | MLClientComms.cpp
7 | MLClientModelManager.cpp
8 | )
9 |
10 | find_package(Protobuf REQUIRED)
11 | if (WIN32)
12 | find_library(PROTOBUF_LIBRARY NAME libprotobuf PATHS ${Protobuf_LIBRARIES})
13 | endif()
14 |
15 | # Compile protobuf .cpp and .h files out of message.proto
16 | protobuf_generate_cpp(PROTO_SRCS PROTO_HDRS message.proto)
17 | list(APPEND ML_CLIENT_SOURCES ${PROTO_SRCS}) # add message.pb.cc
18 |
19 | if(NOT NUKE_INSTALL_PATH)
20 | message(FATAL_ERROR "Nuke install path not set.")
21 | endif()
22 | find_library(DDIMAGE_LIBRARY NAME DDImage libDDImage PATHS ${NUKE_INSTALL_PATH})
23 | if(NOT DDIMAGE_LIBRARY)
24 | message(FATAL_ERROR "DDImage library not found.")
25 | endif()
26 |
27 | # Create MLClient.so shared library
28 | add_library(MLClient SHARED
29 | ${ML_CLIENT_SOURCES}
30 | )
31 |
32 | set_target_properties (MLClient PROPERTIES PREFIX "")
33 | target_include_directories(MLClient PRIVATE
34 | ${NUKE_INSTALL_PATH}/include
35 | ${CMAKE_CURRENT_BINARY_DIR} # include message.pb.h
36 | ${Protobuf_INCLUDE_DIR}
37 | )
38 |
39 | target_link_libraries(MLClient
40 | ${PROTOBUF_LIBRARY}
41 | ${DDIMAGE_LIBRARY}
42 | )
43 |
44 | if (WIN32)
45 | target_link_libraries(MLClient
46 | ws2_32.lib # include windows socket library
47 | )
48 | endif (WIN32)
--------------------------------------------------------------------------------
/Models/mrcnn/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019 Foundry.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | ##############################################################################
15 |
16 | """Utility functions for mask rcnn model"""
17 |
18 | from detectron.utils.collections import AttrDict
19 | import numpy as np
20 |
21 | def dict_equal(d1, d2):
22 |     """Recursively compute whether two dictionaries are equal in both keys and values.
23 |
24 | # Arguments:
25 | d1, d2: The two dictionaries to compare
26 |
27 | # Return:
28 | False if any key or value are different, True otherwise
29 | """
30 | for k in d1:
31 | if k not in d2:
32 | return False
33 | for k in d2:
34 | if type(d2[k]) not in (dict, list, AttrDict, np.ndarray):
35 | if d2[k] != d1[k]:
36 | return False
37 |         elif type(d2[k]) == np.ndarray:
38 | if any(d2[k] != d1[k]):
39 | return False
40 | else: # d2[k] dictionary or list
41 | if type(d1[k]) != type(d2[k]):
42 | return False
43 | else:
44 | if type(d2[k]) == AttrDict or type(d2[k]) == dict:
45 | if(not dict_equal(d1[k], d2[k])):
46 | return False
47 | return True
--------------------------------------------------------------------------------
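A hypothetical usage of the `dict_equal` helper above, assuming the module imports cleanly (the top-level `detectron` import requires Detectron to be installed); the example configs are made up for illustration:

```python
import numpy as np
from Models.mrcnn.utils import dict_equal  # assumes Detectron is installed

# Two made-up configuration dictionaries with a nested dict and a numpy array.
cfg_a = {'lr': 0.01, 'sizes': np.array([224, 224]), 'aug': {'flip': True}}
cfg_b = {'lr': 0.01, 'sizes': np.array([224, 224]), 'aug': {'flip': True}}
print(dict_equal(cfg_a, cfg_b))   # True: same keys and values everywhere

cfg_b['aug']['flip'] = False
print(dict_equal(cfg_a, cfg_b))   # False: the nested value differs
```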
/Plugins/Server/py2.Dockerfile:
--------------------------------------------------------------------------------
1 | # Ubuntu 18.04 with CUDA 10.0, CuDNN 7.6
2 | # Python2.7, TensorFlow 1.15.0, PyTorch 1.4
3 | FROM nvidia/cuda:10.0-cudnn7-devel-ubuntu18.04
4 |
5 | ARG DEBIAN_FRONTEND=noninteractive
6 | RUN apt-get update && apt-get install -y --no-install-recommends \
7 | build-essential \
8 | ca-certificates \
9 | cmake \
10 | curl \
11 | git \
12 | libglib2.0-0 \
13 | libjpeg-dev \
14 | libopencv-dev \
15 | libopenexr-dev \
16 | libpng-dev \
17 | libsm-dev \
18 | vim && \
19 | rm -rf /var/lib/apt/lists/*
20 |
21 | # Install Python 2.7
22 | RUN apt-get update && apt-get install -y --no-install-recommends \
23 | python-pip \
24 | python2.7-dev && \
25 | rm -rf /var/lib/apt/lists/*
26 |
27 | # pip version 21.0 will drop support for Python 2.7
28 | RUN python -m pip install --upgrade pip==20.1
29 | RUN pip install --no-cache-dir setuptools wheel && \
30 | pip install --no-cache-dir \
31 | future \
32 | gast==0.2.2 \
33 | protobuf \
34 | pyyaml==3.13 \
35 | scikit-image \
36 | typing \
37 | imageio==2.6.1 \
38 | OpenEXR==1.3.2
39 |
40 | # Install TF 1.15.0 GPU for Python2.7
41 | RUN pip install --no-cache-dir \
42 | tensorflow-gpu==1.15.0 \
43 | tensorflow-determinism
44 |
45 | # Install PyTorch (includes Caffe2) for CUDA 10.0
46 | RUN pip install --no-cache-dir torch==1.4.0+cu100 torchvision==0.5.0+cu100 -f https://download.pytorch.org/whl/torch_stable.html
47 | RUN pip install --no-cache-dir cupy-cuda100
48 | RUN pip install --no-cache-dir cython
49 |
50 | WORKDIR /workspace
51 | # Install the COCO API
52 | RUN pip install 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'
53 |
54 | # Install detectron for mask RCNN
55 | RUN git clone https://github.com/facebookresearch/detectron
56 | RUN sed -i 's/cythonize(ext_modules)/cythonize(ext_modules, language_level="2")/g' detectron/setup.py
57 | RUN cd detectron && pip install -r requirements.txt && make
58 |
59 | WORKDIR /workspace/ml-server
60 | # Copy your current folder to the docker image /workspace/ml-server/ folder
61 | COPY . .
--------------------------------------------------------------------------------
/Plugins/Server/Dockerfile:
--------------------------------------------------------------------------------
1 | # Ubuntu 18.04 with CUDA 10.0, CuDNN 7.6
2 | # Python3.6, TensorFlow 1.15.0, PyTorch 1.4
3 | FROM nvidia/cuda:10.0-cudnn7-devel-ubuntu18.04
4 |
5 | ARG DEBIAN_FRONTEND=noninteractive
6 | RUN apt-get update && apt-get install -y --no-install-recommends \
7 | build-essential \
8 | ca-certificates \
9 | cmake \
10 | curl \
11 | git \
12 | libglib2.0-0 \
13 | libjpeg-dev \
14 | libopencv-dev \
15 | libopenexr-dev \
16 | libpng-dev \
17 | libsm-dev \
18 | vim && \
19 | rm -rf /var/lib/apt/lists/*
20 |
21 | # Install Python 3.6
22 | RUN apt-get update && apt-get install -y --no-install-recommends \
23 | python3-opencv \
24 | python3-pip \
25 | python3.6-dev && \
26 | rm -rf /var/lib/apt/lists/*
27 | # Have aliases python3->python and pip3->pip
28 | RUN ln -s /usr/bin/python3 /usr/bin/python && \
29 | ln -s /usr/bin/pip3 /usr/bin/pip
30 | RUN python -m pip install --upgrade pip
31 |
32 | RUN pip install --no-cache-dir setuptools wheel && \
33 | pip install --no-cache-dir \
34 | future \
35 | gast==0.2.2 \
36 | protobuf \
37 | pyyaml==3.13 \
38 | scikit-image \
39 | typing \
40 | imageio \
41 | OpenEXR
42 |
43 | # Install TF 1.15.0 GPU for Python3.6 (no TensorRT)
44 | RUN pip install --no-cache-dir \
45 | tensorflow-gpu==1.15.0 \
46 | tensorflow-determinism
47 |
48 | # Install PyTorch (includes Caffe2) for CUDA 10.0
49 | RUN pip install --no-cache-dir torch==1.4.0+cu100 torchvision==0.5.0+cu100 -f https://download.pytorch.org/whl/torch_stable.html
50 | RUN pip install --no-cache-dir cupy-cuda100
51 | RUN pip install --no-cache-dir cython
52 |
53 | WORKDIR /workspace
54 | # Install the COCO API
55 | RUN pip install 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'
56 |
57 | # Install detectron for mask RCNN
58 | RUN git clone https://github.com/facebookresearch/detectron
59 | RUN sed -i 's/cythonize(ext_modules)/cythonize(ext_modules, language_level="3")/g' detectron/setup.py
60 | RUN cd detectron && pip install -r requirements.txt && make
61 |
62 | WORKDIR /workspace/ml-server
63 | # Copy your current folder to the docker image /workspace/ml-server/ folder
64 | COPY . .
--------------------------------------------------------------------------------
/Plugins/Client/message.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto2";
2 |
3 | package mlserver;
4 |
5 | message RequestWrapper {
6 | optional bool info = 1;
7 | optional RequestInfo r1 = 2;
8 | optional RequestInference r2 = 3;
9 | }
10 |
11 | message RespondWrapper {
12 | optional bool info = 1;
13 | optional RespondInfo r1 = 2;
14 | optional RespondInference r2 = 3;
15 | optional Error error = 4;
16 | }
17 |
18 | message RequestInfo {
19 | optional bool info = 1;
20 | }
21 |
22 | message RespondInfo {
23 | optional int32 num_models = 1;
24 | repeated Model models = 2;
25 | }
26 |
27 | message Model {
28 | optional string name = 1;
29 | optional string label = 2;
30 | repeated ImagePrototype inputs = 3;
31 | repeated ImagePrototype outputs = 4;
32 | repeated BoolAttrib bool_options = 5;
33 | repeated IntAttrib int_options = 6;
34 | repeated FloatAttrib float_options = 7;
35 | repeated StringAttrib string_options = 8;
36 | repeated BoolAttrib button_options = 9;
37 | repeated MultipleChoiceOption mc_options = 10;
38 | }
39 |
40 | message MultipleChoiceOption {
41 | optional string name = 1;
42 | optional string value = 2;
43 | repeated string choices = 3;
44 | }
45 |
46 | message ImagePrototype {
47 | optional string name = 1;
48 | optional int32 channels = 2;
49 | }
50 |
51 | message Error {
52 | optional string msg = 1;
53 | }
54 |
55 | message RequestInference {
56 | optional Model model = 1;
57 | repeated Image images = 2;
58 | }
59 |
60 | message RespondInference {
61 | optional int32 num_images = 1;
62 | repeated Image images = 2;
63 | optional int32 num_objects = 3;
64 | repeated FieldValuePairAttrib objects = 4;
65 | }
66 |
67 | message Image {
68 | optional int32 width = 1;
69 | optional int32 height = 2;
70 | optional int32 channels = 3;
71 | optional bytes image = 4;
72 | }
73 |
74 | message BoolAttrib {
75 | optional string name = 1;
76 | repeated bool values = 2 [packed=true];
77 | }
78 |
79 | message IntAttrib {
80 | optional string name = 1;
81 | repeated int32 values = 2 [packed=true];
82 | }
83 |
84 | message FloatAttrib {
85 | optional string name = 1;
86 | repeated float values = 2 [packed=true];
87 | }
88 |
89 | message StringAttrib {
90 | optional string name = 1;
91 | repeated string values = 2;
92 | }
93 |
94 | message FieldValuePairAttrib {
95 | optional string name = 1;
96 | repeated FieldValuePair values = 2;
97 | }
98 |
99 | message FieldValuePair {
100 | repeated IntAttrib int_attributes = 1;
101 | repeated FloatAttrib float_attributes = 2;
102 | repeated StringAttrib string_attributes = 3;
103 | repeated FieldValuePairAttrib children = 4;
104 | }
--------------------------------------------------------------------------------
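A hedged sketch (not part of the repository) of how a float image could be packed into the `RequestInference` message defined above, using the `message_pb2` module that `protoc` generates from this file. The raw float32 byte layout of `Image.image` is an assumption for illustration; the real client and server agree on the exact pixel encoding elsewhere:

```python
import numpy as np
import message_pb2  # generated by protoc from message.proto

img = np.random.rand(32, 64, 3).astype(np.float32)  # hypothetical RGB image

req = message_pb2.RequestInference()
req.model.name = 'blur'            # a model name advertised in RespondInfo.models
image_msg = req.images.add()       # repeated Image field
image_msg.width = img.shape[1]
image_msg.height = img.shape[0]
image_msg.channels = img.shape[2]
image_msg.image = img.tobytes()    # raw float32 pixels (assumed encoding)

payload = req.SerializeToString()  # bytes ready to be framed and sent
print(len(payload))
```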
/README.md:
--------------------------------------------------------------------------------
1 | # Python-based Machine Learning Frame Server for Nuke
2 |
3 | This repository contains the client-server system enabling Machine Learning (ML) inference in Nuke. This work is split into two parts: a client Nuke plug-in [Plugins/Client/](Plugins/Client) and the Python frame server [Plugins/Server](Plugins/Server).
4 |
5 | The following models are provided as examples:
6 | - blur: a simple Gaussian blur operation
7 | - [Mask-RCNN](https://github.com/facebookresearch/Detectron)
8 | - [trainingTemplateTF](https://github.com/TheFoundryVisionmongers/nuke-ML-server/tree/master/Models/trainingTemplateTF): a training template written in TensorFlow which enables simple image-to-image training. Instructions on how to use this template are found [here](https://github.com/TheFoundryVisionmongers/nuke-ML-server/tree/master/Models/trainingTemplateTF).
9 |
10 |
11 |
12 | Example of Nuke doing DensePose inference.
13 |
14 |
15 | ## Introduction
16 |
17 | The Machine Learning (ML) plug-in connects Nuke to a Python server to apply ML models to images.
18 | The plug-in works as follows:
19 | - The Nuke node can connect to a server given an IP address and port,
20 | - The Python server responds with the list of available Machine Learning (ML) models and options,
21 | - The Nuke node displays the models in an enumeration knob, from which the user can choose,
22 | - On every renderStripe call, the current image and model options are sent from the Nuke node to the server,
23 | - The server does an inference on the image using the chosen model/options. This inference can be an actual inference operation of a machine learning model, or just some other image processing code,
24 | - The resulting image is sent back to the Nuke node.
25 |
26 | ## Installation
27 |
28 | Please find installation instructions in [INSTALL.md](INSTALL.md).
29 |
30 | ## Known Issues
31 |
32 | 1. The GPU can run out of memory when doing model inference. To run Mask-RCNN, at least 6 GB of GPU memory is required.
33 | 2. If you get the following error: "The TensorFlow library was compiled to use AVX instructions, but these aren't available on your machine.", please refer to [issue#10](https://github.com/TheFoundryVisionmongers/nuke-ML-server/issues/10) [Thanks to [samhodge](https://github.com/samhodge)].
34 |
35 | ## License
36 |
37 | The source code is licensed under the Apache License, Version 2.0, found in [LICENSE](LICENSE).
38 |
39 | ## Contacts
40 |
41 | - Johanna Barbier (Johanna.Barbier@foundry.com)
42 | - Dan Ring (Dan.Ring@foundry.com)
43 |
44 | This plug-in was initially created by Sebastian Lutz (https://v-sense.scss.tcd.ie/?profile=sebastian-lutz).
45 |
46 | ## References
47 |
48 | - [Mask R-CNN](https://arxiv.org/abs/1703.06870).
49 | Kaiming He, Georgia Gkioxari, Piotr Dollár, and Ross Girshick.
50 | IEEE International Conference on Computer Vision (ICCV), 2017.
51 | - [DensePose: Dense Human Pose Estimation In The Wild](https://arxiv.org/abs/1802.00434).
52 | Riza Alp Güler, Natalia Neverova, Iasonas Kokkinos.
53 | IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2018.
--------------------------------------------------------------------------------
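To make the "works as follows" list in the README above concrete, here is a minimal, self-contained sketch of the first exchange (model discovery) built only from the `message.proto` definitions. The server reply is simulated in-process; the real length-prefixed framing and sockets are handled by `server.py` and `MLClientComms` and are not shown:

```python
import message_pb2  # generated from Plugins/Client/message.proto

# Client side: a RequestWrapper carrying a RequestInfo ("what models do you have?").
request = message_pb2.RequestWrapper()
request.info = True
request.r1.info = True
wire_bytes = request.SerializeToString()  # would be framed and sent over the socket

# Server side (simulated here): reply with one available model.
reply = message_pb2.RespondWrapper()
reply.info = True
reply.r1.num_models = 1
model = reply.r1.models.add()
model.name = 'blur'
model.label = 'Gaussian Blur'

# Client side again: parse the reply and list the models, as the Nuke node does.
parsed = message_pb2.RespondWrapper()
parsed.ParseFromString(reply.SerializeToString())
for m in parsed.r1.models:
    print(m.name, m.label)
```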
/Models/blur/model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019 Foundry.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | ##############################################################################
15 |
16 | from ..baseModel import BaseModel
17 |
18 | import cv2
19 | import numpy as np
20 |
21 | from ..common.util import linear_to_srgb, srgb_to_linear
22 |
23 | import message_pb2
24 |
25 | class Model(BaseModel):
26 | def __init__(self):
27 | super(Model, self).__init__()
28 | self.name = 'Gaussian Blur'
29 |
30 | self.kernel_size = 5
31 | self.make_blur = False
32 |
33 | # Define options
34 | self.options = ('kernel_size',)
35 | self.buttons = ('make_blur',)
36 |
37 | # Define inputs/outputs
38 | self.inputs = {'input': 3}
39 | self.outputs = {'output': 3}
40 |
41 | def inference(self, image_list):
42 | """Do an inference on the model with a set of inputs.
43 |
44 | # Arguments:
45 | image_list: The input image list
46 |
47 | Return the result of the inference.
48 | """
49 | image = image_list[0]
50 | image = linear_to_srgb(image)
51 | image = (image * 255).astype(np.uint8)
52 | kernel = self.kernel_size * 2 + 1
53 | blur = cv2.GaussianBlur(image, (kernel, kernel), 0)
54 | blur = blur.astype(np.float32) / 255.
55 | blur = srgb_to_linear(blur)
56 |
57 | # If make_blur button is pressed in Nuke
58 | if self.make_blur:
59 | script_msg = message_pb2.FieldValuePairAttrib()
60 | script_msg.name = "PythonScript"
61 | # Create a Python script message to run in Nuke
62 | python_script = self.blur_script(blur)
63 | script_msg_val = script_msg.values.add()
64 | script_msg_str = script_msg_val.string_attributes.add()
65 | script_msg_str.values.extend([python_script])
66 | return [blur, script_msg]
67 |
68 | return [blur]
69 |
70 | def blur_script(self, image):
71 | """Return the Python script function to create a pop up window in Nuke.
72 |
73 | The pop up window displays the brightest pixel position of the given image.
74 | """
75 | # Compute brightest pixel of the image
76 | gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
77 | [min_val, max_val, min_loc, max_loc] = cv2.minMaxLoc(gray)
78 |         # The Y axis is inverted in Nuke
79 | max_loc = (max_loc[0], image.shape[0] - max_loc[1])
80 | popup_msg = (
81 | "Brightest pixel of the blurred image\\n"
82 | "Location: {}, Value: {:.3f}."
83 | ).format(max_loc, max_val)
84 | script = "nuke.message('{}')\n".format(popup_msg)
85 | return script
--------------------------------------------------------------------------------
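`linear_to_srgb` and `srgb_to_linear` are imported from `Models/common/util.py`, which is not included in this dump. As a point of reference only, a typical implementation is the standard piecewise sRGB transfer function; the repository's actual version may differ:

```python
import numpy as np

def linear_to_srgb(x):
    """Standard piecewise sRGB encoding of linear values in [0, 1]."""
    x = np.clip(x, 0.0, 1.0)
    return np.where(x <= 0.0031308,
                    12.92 * x,
                    1.055 * np.power(x, 1.0 / 2.4) - 0.055)

def srgb_to_linear(x):
    """Inverse of linear_to_srgb."""
    x = np.clip(x, 0.0, 1.0)
    return np.where(x <= 0.04045,
                    x / 12.92,
                    np.power((x + 0.055) / 1.055, 2.4))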
/Models/baseModel.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019 Foundry.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | ##############################################################################
15 |
16 | import sys
17 | if sys.version_info.major > 2: # python3
18 | unicode = str
19 |
20 | import numpy as np
21 |
22 | class BaseModel(object):
23 | def __init__(self):
24 | self.name = 'Base model'
25 | self.options = () # List of attribute names that should get exposed in Nuke
26 | self.buttons = () # List of button names that should get exposed in Nuke
27 | self.inputs = {'input': 3} # Define Inputs (name, #channels)
28 | self.outputs = {'output': 3} # And Outputs (name, #channels)
29 | pass
30 |
31 | def inference(self, *inputs):
32 | """Do an inference on the model with a set of inputs.
33 |
34 | # Arguments:
35 | inputs: A list of images
36 |
37 | # Return:
38 | The result of the inference as a list of images
39 | """
40 | raise NotImplementedError
41 |
42 | def get_options(self):
43 | """Get a dictionary of exposed options from the model.
44 |
45 | To expose options, self.options has to be filled with attribute names.
46 | Return a dictionary of option names and values.
47 | """
48 | opt = {}
49 | if hasattr(self, 'options'):
50 | for option in self.options:
51 | value = getattr(self, option)
52 | if isinstance(value, unicode):
53 | value = str(value)
54 | assert type(value) in [bool, int, float, str], \
55 | 'Broadcasted options need to be one of bool, int, float, str.'
56 | opt[option] = value
57 | return opt
58 |
59 | def set_options(self, optionsDict):
60 | """Set the options of the model.
61 |
62 | # Arguments:
63 | optionsDict: A dictionary of attribute names and values
64 | """
65 | for name, value in optionsDict.items():
66 | setattr(self, name, value)
67 |
68 | def get_buttons(self):
69 | """Return the defined buttons of the model.
70 |
71 | To expose buttons in Nuke, self.buttons has to be filled with attribute names.
72 | """
73 | btn = {}
74 | if hasattr(self, 'buttons'):
75 | for button in self.buttons:
76 | value = getattr(self, button)
77 | assert type(value) in [bool], 'Broadcasted buttons need to be bool.'
78 | btn[button] = value
79 | return btn
80 |
81 | def set_buttons(self, buttonsDict):
82 | """Set the buttons of the model.
83 |
84 | # Arguments:
85 | buttonsDict: A dictionary of attribute names and values
86 | """
87 | for name, value in buttonsDict.items():
88 | setattr(self, name, value)
89 |
90 | def get_inputs(self):
91 | """Return the defined inputs of the model."""
92 | return self.inputs
93 |
94 | def get_outputs(self):
95 | """Return the defined outputs of the model."""
96 | return self.outputs
97 |
98 | def get_name(self):
99 | """Return the name of the model."""
100 | return self.name
--------------------------------------------------------------------------------
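A minimal sketch of a custom model following the `BaseModel` contract above (see `Models/blur/model.py` for the bundled reference). The "Gain" model and its knob are hypothetical, and it is assumed the server discovers a `Model` class placed in `Models/<name>/model.py` the same way it does for the bundled examples:

```python
import numpy as np

from ..baseModel import BaseModel

class Model(BaseModel):
    def __init__(self):
        super(Model, self).__init__()
        self.name = 'Gain'            # hypothetical example model
        self.gain = 1.5               # exposed in Nuke as a float knob
        self.options = ('gain',)
        self.inputs = {'input': 3}    # one RGB input
        self.outputs = {'output': 3}  # one RGB output

    def inference(self, image_list):
        # Scale the incoming image by the current knob value and return it.
        image = image_list[0].astype(np.float32)
        return [image * float(self.gain)]
```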
/Models/regressionTemplateTF/README.md:
--------------------------------------------------------------------------------
1 | # Regression Training Template
2 |
3 | The regressionTemplateTF is a training template written in TensorFlow. It aims at quickly enabling image-to-parameters training. For instance, finding the lens distortion parameters or gamma correction of an image. When trained, the model can be tested and used directly in Nuke through the nuke-ML-server.
4 |
5 | Compared to the image-to-image [Training Template](https://github.com/TheFoundryVisionmongers/nuke-ML-server/tree/master/Models/trainingTemplateTF) and the image-to-labels [Classification Template](https://github.com/TheFoundryVisionmongers/nuke-ML-server/tree/master/Models/classTemplateTF), this template will not work out-of-the-box and will require some data preprocessing implementation, as detailed in the [following section](https://github.com/TheFoundryVisionmongers/nuke-ML-server/tree/master/Models/regressionTemplateTF#data-preprocessing-implementation). This guide will be based on the current template example: gamma-correction prediction.
6 |
7 | For instructions on how to set-up the training, on potential training issues or on TensorBoard visualisation, please refer to the [training template readme](https://github.com/TheFoundryVisionmongers/nuke-ML-server/blob/master/Models/trainingTemplateTF/README.md).
8 |
9 | ## Data Preprocessing Implementation
10 |
11 | To train the ML algorithm, you need to set up your dataset in `regressionTemplateTF/data/train/`. In addition to the training data, it is highly recommended to have validation data in `regressionTemplateTF/data/validation/`. This allows you to check that there is no overfitting on the training data. Please note that the validation dataset and training dataset must not intersect.
12 |
13 | Your training/validation dataset will be different depending on your task, i.e. depending on which parameter(s) you want to learn. In the current implementation, we are doing a regression on one parameter (gamma) with a specifically designed data preprocessing pipeline. Namely, our model's training input is a stack of the histograms of the original and gamma-graded images.
14 |
15 | Our preprocessing pipeline reads the original image (from `regressionTemplateTF/data/train/` or `regressionTemplateTF/data/validation/`), then applies gamma correction to that image using a random gamma value. Both the original and the resulting gamma-graded image are grayscaled and resized, and their 100-bin histograms are computed. The model input (shape [2, 100]) is a stack of those two histograms.
16 |
17 | The above data preprocessing is specific to the gamma-correction problem, which means that to predict other parameters (e.g. colour grading, lens distortion..), you will have to modify the data preprocessing functions found in `train_regression.py` and in `model.py` to match your task. The inference file `model.py` has to be changed as well, since the same data preprocessing used in training must be applied before doing an inference in Nuke.
18 |
19 | To summarise, for your specific regression task, you need to implement an appropriate data preprocessing and modify the code in both the training file `train_regression.py` and the inference file `model.py` accordingly.
20 |
21 | ## Training
22 |
23 | Inside your docker container, go to the regressionTemplateTF folder:
24 | ```
25 | cd /workspace/ml-server/models/regressionTemplateTF
26 | ```
27 | Then directly train your model:
28 | ```
29 | python train_regression.py
30 | ```
31 | You can also specify the batch size, learning rate and number of epochs:
32 | ```
33 | python train_regression.py --bch=16 --lr=1e-3 --ep=1000
34 | ```
35 | It is now possible to have deterministic training. You will be able to reproduce your training (get the same model weights) by setting the seed to an integer of your choice (here 77):
36 | ```
37 | python train_regression.py --seed=77
38 | ```
39 | We enable deterministic training in part by applying a GPU patch to stock TensorFlow; this patch slows down training significantly. By adding the `--no-gpu-patch` flag to the previous command, you get slightly less deterministic training but keep the original training time.
40 |
41 | Note: the current gamma-correction task creates gamma-graded images on the fly using random gamma values, so for training to succeed it is recommended to have more than 500 training images.
42 |
43 |
44 |
45 |
46 |
--------------------------------------------------------------------------------
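A hedged sketch of the histogram preprocessing described in the "Data Preprocessing Implementation" section above. The actual code lives in `train_regression.py` and `model.py` and may differ in details such as the resize dimensions, histogram normalisation and how the random gamma is applied:

```python
import cv2
import numpy as np

def histogram_pair(image, gamma, bins=100):
    """Stack the histograms of an image and its gamma-graded version.

    image: float32 RGB image with values in [0, 1]; gamma: the random gamma value.
    Returns an array of shape (2, bins), the model input described above.
    """
    graded = np.power(image, gamma)  # assumed form of the gamma grade
    histograms = []
    for img in (image, graded):
        gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        gray = cv2.resize(gray, (200, 200))  # resize target is an assumption
        hist, _ = np.histogram(gray, bins=bins, range=(0.0, 1.0), density=True)
        histograms.append(hist.astype(np.float32))
    return np.stack(histograms)

example = histogram_pair(np.random.rand(256, 256, 3).astype(np.float32), gamma=1.8)
print(example.shape)  # (2, 100)
```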
/Plugins/Client/MLClientModelManager.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2019 Foundry.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 | //*************************************************************************
15 |
16 | #ifndef MLClientModelManager_H
17 | #define MLClientModelManager_H
18 |
19 | #include <string>
20 | #include <vector>
21 | #include "message.pb.h"
22 |
23 | #include "DDImage/Op.h"
24 | #include "DDImage/Knobs.h"
25 |
26 | class MLClientModelManager;
27 |
28 | //! The role of this custom knob is to serialise and to store the selected model and its parameters.
29 | //! As these exist as dynamic knobs, this is to work around the fact that we would need to know about these knobs
30 | //! in advance to save them the usual way.
31 | class MLClientModelKnob : public DD::Image::Knob
32 | {
33 | public:
34 | MLClientModelKnob(DD::Image::Knob_Closure* kc, DD::Image::Op* op, const char* name);
35 |
36 | // Knob overrides.
37 | const char* Class() const override;
38 | bool not_default () const override;
39 | //! Serialises the currently selected model and its parameters as follows:
40 | //! {model:modelName;param1:value1;param2:value2}
41 | void to_script (std::ostream &out, const DD::Image::OutputContext *, bool quote) const override;
42 | //! Deserialises the saved model and its parameters.
43 |   //! The model can then be retrieved with getModel()
44 | //! and the dictionary of parameters with getParameters().
45 | bool from_script(const char * src) override;
46 |
47 | std::string getModel() const;
48 |   const std::map<std::string, std::string>& getParameters() const;
49 |
50 | private:
51 | //! Serialises the dynamic knobs to the given output stream.
52 | //! This function is generic for the Ints, Floats and Bools knobs
53 | //! provided that the corresponding getNumOfT and getDynamicTName
54 | //! functions are given.
55 | void toScriptT(MLClientModelManager& mlManager, std::ostream &out,
56 | int (MLClientModelManager::*getNum)() const,
57 | std::string (MLClientModelManager::*getDynamicName)(int)) const;
58 | //! Serialises the dynamic knobs containing strings to the given output stream.
59 | void toScriptStrings(MLClientModelManager& mlManager, std::ostream &out) const;
60 |
61 | private:
62 | DD::Image::Op* _op;
63 | std::string _model;
64 |   std::map<std::string, std::string> _parameters;
65 | };
66 |
67 | //! Class to parse and store knobs for a given model.
68 | class MLClientModelManager
69 | {
70 | public:
71 | explicit MLClientModelManager(DD::Image::Op* parent);
72 | ~MLClientModelManager();
73 |
74 | // Getters of the class
75 | int getNumOfFloats() const;
76 | int getNumOfInts() const;
77 | int getNumOfBools() const;
78 | int getNumOfStrings() const;
79 | int getNumOfButtons() const;
80 |
81 | std::string getDynamicBoolName(int idx);
82 | std::string getDynamicFloatName(int idx);
83 | std::string getDynamicIntName(int idx);
84 | std::string getDynamicStringName(int idx);
85 | std::string getDynamicButtonName(int idx);
86 |
87 | float* getDynamicFloatValue(int idx);
88 | int* getDynamicIntValue(int idx);
89 | bool* getDynamicBoolValue(int idx);
90 | std::string* getDynamicStringValue(int idx);
91 | bool* getDynamicButtonValue(int idx);
92 | void setDynamicButtonValue(int idx, int value);
93 |
94 | void clear();
95 | //! Parse the model options from the ML server.
96 | void parseOptions(const mlserver::Model& m);
97 | //! Update any current options from any changes to the ML server.
98 | void updateOptions(mlserver::Model& m);
99 |
100 | private:
101 | DD::Image::Op* _parent;
102 |   std::vector<int> _dynamicBoolValues;
103 |   std::vector<int> _dynamicIntValues;
104 |   std::vector<float> _dynamicFloatValues;
105 |   std::vector<std::string> _dynamicStringValues;
106 |   std::vector<int> _dynamicButtonValues;
107 |   std::vector<std::string> _dynamicBoolNames;
108 |   std::vector<std::string> _dynamicIntNames;
109 |   std::vector<std::string> _dynamicFloatNames;
110 |   std::vector<std::string> _dynamicStringNames;
111 |   std::vector<std::string> _dynamicButtonNames;
112 | };
113 |
114 | #endif
115 |
--------------------------------------------------------------------------------
/Plugins/Client/MLClientComms.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2019 Foundry.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 | //*************************************************************************
15 |
16 | #ifndef MLCLIENTCOMMS_H
17 | #define MLCLIENTCOMMS_H
18 |
19 | // Protobuf headers
20 | #include "message.pb.h"
21 |
22 | using byte = unsigned char;
23 |
24 |
25 | //! The Machine Learning (ML) Client plug-in connects Nuke to a Python server to apply ML models to images.
26 | /*! This plug-in can connect to a server (given a host and port), which responds
27 | with a list of available Machine Learning (ML) models and options.
28 |     On every \a renderStripe() call, the image and model options are sent from Nuke to the server,
29 | there the server can process the image by doing Machine Learning inference,
30 | finally the resulting image is sent back to Nuke.
31 | */
32 | class MLClientComms
33 | {
34 | public:
35 | // Static consts
36 | static const int kNumberOfBytesHeaderSize;
37 |
38 | static const int kTimeout;
39 | static const int kMaxNumberOfTry;
40 |
41 |   // Static non-consts
42 | static bool Verbose;
43 |
44 | public:
45 | //! Constructor. Initialize user controls to their default values, then try to
46 | //! connect to the specified host / port. Following the c-tor, you can test for
47 | //! a valid connection by calling isConnected().
48 | MLClientComms(const std::string& hostStr, int port);
49 |
50 | //! Destructor. Tear down any existing connection.
51 | virtual ~MLClientComms();
52 |
53 | public:
54 | // Public static methods for client-server communication
55 |
56 | //! Test if a given hostname is valid, returning true if it is, false otherwise
57 | static bool ValidateHostName(const std::string& hostStr);
58 |
59 | //! Print debug related information to std::cout, when ::Verbose is set to true.
60 | static void Vprint(std::string msg);
61 |
62 | public:
63 | // Public methods for client-server communication
64 |
65 | //! Return whether this object is connected to the specified server.
66 | bool isConnected() const;
67 |
68 | //! Function for discovering & negotiating the available models and their parameters.
69 | //! Return true on success, false otherwise with the errorMsg filled in.
70 | bool sendInfoRequestAndReadInfoResponse(mlserver::RespondWrapper& responseWrapper, std::string& errorMsg);
71 |
72 | //! Function for performing the inference on a selected model.
73 | //! Return true on success, false otherwise with the errorMsg filled in.
74 | bool sendInferenceRequestAndReadInferenceResponse(mlserver::RequestInference& requestInference, mlserver::RespondWrapper& responseWrapper, std::string& errorMsg);
75 |
76 | private:
77 | // Private client / server comms functions
78 |
79 | //! Try to connect to the server with the specified hostStr & port, by repeatedly
80 | //! calling setupConnection() below until a connection is made or times out. After it
81 | //! returns, you can test if it was successful by calling isConnected().
82 | void connectLoop();
83 |
84 | //! Create a socket to connect to the server specified by hostStr and port.
85 | //! Return true on success, false otherwise with a message filled in errorStr.
86 | bool setupConnection(std::string& errorStr);
87 |
88 | //! Request the server to return a future message about its models. This is used
89 | //! to instruct the server that it should set itself up.
90 | //! Return true on success, false otherwise.
91 | bool sendInfoRequest();
92 |
93 | //! Retrieve the response from the server and store it in responseWrapper, to be parsed
94 | //! elsewhere. Return true on success, false otherwise.
95 | bool readInfoResponse(mlserver::RespondWrapper& responseWrapper);
96 |
97 |   //! Send a messaged image to the server. Return true on success, false otherwise.
98 | bool sendInferenceRequest(mlserver::RequestInference& requestInference);
99 |
100 | //! Marshall the returned image into a float buffer of the original image size. Note, this
101 | //! expects the size of result to have been set to the same size as the image that was
102 | //! previously sent to the server. Return true on success, false otherwise.
103 | bool readInferenceResponse(mlserver::RespondWrapper& responseWrapper);
104 |
105 | //! Pull the data after determining the size 'siz' from the header.
106 | //! Helper to the above 'readInfoResponse' function.
107 | bool readInfoResponse(google::protobuf::uint32 siz, mlserver::RespondWrapper& responseWrapper);
108 |
109 | //! Pull the data after determining the size 'siz' from the header.
110 | //! Helper to the above 'readInferenceResponse' function.
111 | bool readInferenceResponse(google::protobuf::uint32 siz, mlserver::RespondWrapper& responseWrapper);
112 |
113 | //! Close the current connection if one is open.
114 | void closeConnection();
115 |
116 | private:
117 | // Private helper functions
118 | google::protobuf::uint32 readHdr(char* buf);
119 | void* getInAddr(struct sockaddr* sa);
120 |
121 | private:
122 | // Private member variables
123 | std::string _hostStr;
124 | int _port;
125 |
126 | bool _isConnected;
127 | int _socket;
128 | };
129 |
130 | #endif // MLCLIENTCOMMS_H
131 |
--------------------------------------------------------------------------------
/Models/classTemplateTF/model.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import sys
4 | import os
5 | import time
6 |
7 | import scipy.misc
8 | import numpy as np
9 | import cv2
10 |
11 | import tensorflow as tf
12 | tf.compat.v1.disable_eager_execution() # For TF 2.x compatibility
13 |
14 | from ..baseModel import BaseModel
15 | from ..common.util import print_, get_saved_model_list, linear_to_srgb
16 |
17 | import message_pb2
18 |
19 | class Model(BaseModel):
20 | """Load your trained model and do inference in Nuke"""
21 |
22 | def __init__(self):
23 | super(Model, self).__init__()
24 | self.name = 'Classification Template'
25 | dir_path = os.path.dirname(os.path.realpath(__file__))
26 | self.checkpoints_dir = os.path.join(dir_path, 'checkpoints')
27 | self.batch_size = 1
28 |
29 | # Initialise checkpoint name to the most recent trained model
30 | ckpt_names = get_saved_model_list(self.checkpoints_dir)
31 | if not ckpt_names: # empty list
32 | self.checkpoint_name = ''
33 | else:
34 | self.checkpoint_name = ckpt_names[-1]
35 | self.prev_ckpt_name = self.checkpoint_name
36 |
37 | # Button to get classification label
38 | self.get_label = False
39 |
40 | # Define options
41 | self.options = ('checkpoint_name',)
42 | self.buttons = ('get_label',)
43 | # Define inputs/outputs
44 | self.inputs = {'input': 3}
45 | self.outputs = {'output': 3}
46 |
47 | def load_model(self):
48 | # Check if empty or invalid checkpoint name
49 | if self.checkpoint_name=='':
50 | ckpt_names = get_saved_model_list(self.checkpoints_dir)
51 | if not ckpt_names:
52 | raise ValueError("No checkpoints found in {}".format(self.checkpoints_dir))
53 | else:
54 | raise ValueError("Empty checkpoint name, try an available checkpoint in {} (ex: {})"
55 | .format(self.checkpoints_dir, ckpt_names[-1]))
56 | print_("Loading trained model checkpoint...\n", 'm')
57 | # Load from given checkpoint file name
58 | model = tf.keras.models.load_model(os.path.join(self.checkpoints_dir, self.checkpoint_name))
59 | model._make_predict_function()
60 | print_("...Checkpoint {} loaded\n".format(self.checkpoint_name), 'm')
61 | return model
62 |
63 | def inference(self, image_list):
64 | """Do an inference on the model with a set of inputs.
65 |
66 | # Arguments:
67 | image_list: The input image list
68 |
69 | Return the result of the inference.
70 | """
71 | image = image_list[0]
72 | image = linear_to_srgb(image).copy()
73 | image = (image * 255).astype(np.uint8)
74 |
75 | if not hasattr(self, 'model'):
76 | # Initialise tensorflow graph
77 | tf.compat.v1.reset_default_graph()
78 | config = tf.compat.v1.ConfigProto()
79 | config.gpu_options.allow_growth=True
80 | self.sess = tf.compat.v1.Session(config=config)
81 | # Necessary to switch / load_weights on different h5 file
82 | tf.compat.v1.keras.backend.set_session(self.sess)
83 | # Load most recent trained model
84 | self.model = self.load_model()
85 | self.graph = tf.compat.v1.get_default_graph()
86 | self.prev_ckpt_name = self.checkpoint_name
87 | self.class_labels = (self.checkpoint_name.split('.')[0]).split('_')
88 | else:
89 | tf.compat.v1.keras.backend.set_session(self.sess)
90 |
91 | # If checkpoint name has changed, load new checkpoint
92 | if self.prev_ckpt_name != self.checkpoint_name or self.checkpoint_name == '':
93 | self.model = self.load_model()
94 | self.graph = tf.compat.v1.get_default_graph()
95 | self.class_labels = (self.checkpoint_name.split('.')[0]).split('_')
96 | # If checkpoint correctly loaded, update previous checkpoint name
97 | self.prev_ckpt_name = self.checkpoint_name
98 |
99 | image = cv2.resize(image, dsize=(224, 224), interpolation=cv2.INTER_NEAREST)
100 | # Predict on new data
101 | image_batch = np.expand_dims(image, 0)
102 | # Preprocess a numpy array encoding a batch of images (RGB values within [0, 255])
103 | image_batch = tf.keras.applications.mobilenet.preprocess_input(image_batch)
104 | start = time.time()
105 |
106 | with self.graph.as_default():
107 | y_prob = self.model.predict(image_batch)
108 |
109 | y_class = y_prob.argmax(axis=-1)[0]
110 | duration = time.time() - start
111 | # Print results on server side
112 | print('Inference duration: {:4.3f}s'.format(duration))
113 | class_scores = str(["{0:0.4f}".format(i) for i in y_prob[0]]).replace("'", "")
114 | print("Class scores: {} --> Label: {}".format(class_scores, self.class_labels[y_class]))
115 |
116 | # If get_label button is pressed in Nuke
117 | if self.get_label:
118 | # Send back which class was detected
119 | script_msg = message_pb2.FieldValuePairAttrib()
120 | script_msg.name = "PythonScript"
121 | # Create a Python script message to run in Nuke
122 | nuke_msg = "Class scores: {}\\nLabel: {}".format(class_scores, self.class_labels[y_class])
123 | python_script = "nuke.message('{}')\n".format(nuke_msg)
124 | script_msg_val = script_msg.values.add()
125 | script_msg_str = script_msg_val.string_attributes.add()
126 | script_msg_str.values.extend([python_script])
127 | return [image_list[0], script_msg]
128 | return [image_list[0]]
--------------------------------------------------------------------------------
/Models/trainingTemplateTF/model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019 Foundry.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | ##############################################################################
15 |
16 | from __future__ import print_function
17 |
18 | import sys
19 | import os
20 | import time
21 |
22 | import scipy.misc
23 | import numpy as np
24 | import cv2
25 |
26 | import tensorflow as tf
27 | tf.compat.v1.disable_eager_execution() # For TF 2.x compatibility
28 |
29 | from ..baseModel import BaseModel
30 | from ..common.model_builder import EncoderDecoder
31 | from ..common.util import print_, get_ckpt_list, linear_to_srgb, srgb_to_linear
32 |
33 | class Model(BaseModel):
34 | """Load your trained model and do inference in Nuke"""
35 |
36 | def __init__(self):
37 | super(Model, self).__init__()
38 | self.name = 'Training Template TF'
39 | self.n_levels = 3
40 | self.scale = 0.5
41 | dir_path = os.path.dirname(os.path.realpath(__file__))
42 | self.checkpoints_dir = os.path.join(dir_path, 'checkpoints')
43 | self.batch_size = 1
44 |
45 | # Initialise checkpoint name to the most advanced checkpoint (highest step)
46 | ckpt_names = get_ckpt_list(self.checkpoints_dir)
47 | if not ckpt_names: # empty list
48 | self.checkpoint_name = ''
49 | else:
50 | ckpt_steps = [int(name.split('-')[-1]) for name in ckpt_names]
51 | self.checkpoint_name = ckpt_names[ckpt_steps.index(max(ckpt_steps))]
52 | self.prev_ckpt_name = self.checkpoint_name
53 |
54 | self.options = ('checkpoint_name',)
55 | # Define inputs/outputs
56 | self.inputs = {'input': 3}
57 | self.outputs = {'output': 3}
58 |
59 | def load(self, sess, checkpoint_dir):
60 | # Check if empty or invalid checkpoint name
61 | if self.checkpoint_name=='':
62 | ckpt_names = get_ckpt_list(self.checkpoints_dir)
63 | if not ckpt_names:
64 | raise ValueError("No checkpoints found in {}".format(self.checkpoints_dir))
65 | else:
66 | raise ValueError("Empty checkpoint name, try an available checkpoint in {} (ex: {})"
67 | .format(self.checkpoints_dir, ckpt_names[-1]))
68 | print_("Loading trained model checkpoint...\n", 'm')
69 | # Load from given checkpoint file name
70 | self.saver.restore(sess, os.path.join(checkpoint_dir, self.checkpoint_name))
71 | print_("...Checkpoint {} loaded\n".format(self.checkpoint_name), 'm')
72 |
73 | def inference(self, image_list):
74 | """Do an inference on the model with a set of inputs.
75 |
76 | # Arguments:
77 | image_list: The input image list
78 |
79 | Return the result of the inference.
80 | """
81 | image = image_list[0]
82 | image = linear_to_srgb(image).copy()
83 | H, W, channels = image.shape
84 |
85 | # Add padding so that width and height of image are a multiple of 16
86 | new_H = int(H + 16 - H%16) if H%16!=0 else H
87 | new_W = int(W + 16 - W%16) if W%16!=0 else W
88 | img_pad = np.pad(image, ((0, new_H - H), (0, new_W - W), (0, 0)), 'reflect')
89 |
90 | if not hasattr(self, 'sess'):
91 | # Initialise input placeholder size
92 | self.curr_height = new_H; self.curr_width = new_W
93 | # Initialise tensorflow graph
94 | tf.compat.v1.reset_default_graph()
95 | config = tf.compat.v1.ConfigProto()
96 | config.gpu_options.allow_growth=True
97 | self.sess=tf.compat.v1.Session(config=config)
98 | self.input = tf.compat.v1.placeholder(tf.float32, shape=[self.batch_size, new_H, new_W, channels])
99 | self.model = EncoderDecoder(self.n_levels, self.scale, channels)
100 | self.infer_op = self.model(self.input, reuse=False)
101 | # Load model checkpoint having the longest training (highest step)
102 | self.saver = tf.compat.v1.train.Saver()
103 | self.load(self.sess, self.checkpoints_dir)
104 | self.prev_ckpt_name = self.checkpoint_name
105 |
106 | elif self.curr_height != new_H or self.curr_width != new_W:
107 | # Modify input placeholder size
108 | self.input = tf.compat.v1.placeholder(tf.float32, shape=[self.batch_size, new_H, new_W, channels])
109 | self.infer_op = self.model(self.input, reuse=False)
110 | # Update image height and width
111 | self.curr_height = new_H; self.curr_width = new_W
112 |
113 | # If checkpoint name has changed, load new checkpoint
114 | if self.prev_ckpt_name != self.checkpoint_name or self.checkpoint_name == '':
115 | self.load(self.sess, self.checkpoints_dir)
116 | # If checkpoint correctly loaded, update previous checkpoint name
117 | self.prev_ckpt_name = self.checkpoint_name
118 |
119 | # Apply current model to the padded input image
120 | image_batch = np.expand_dims(img_pad, 0)
121 | start = time.time()
122 | # The network is expecting image_batch to be of type tf.float32
123 | inference = self.sess.run(self.infer_op, feed_dict={self.input: image_batch})
124 | duration = time.time() - start
125 | print('Inference duration: {:4.3f}s'.format(duration))
126 | res = inference[-1]
127 | # Remove first dimension and padding
128 | res = res[0, :H, :W, :]
129 |
130 | output_image = srgb_to_linear(res)
131 | return [output_image]
--------------------------------------------------------------------------------
/Plugins/Client/MLClient.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2019 Foundry.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 | //*************************************************************************
15 |
16 | #ifndef MLCLIENT_H
17 | #define MLCLIENT_H
18 |
19 | // Standard plug-in include files.
20 | #include "DDImage/PlanarIop.h"
21 | #include "DDImage/NukeWrapper.h"
22 | #include "DDImage/Row.h"
23 | #include "DDImage/Tile.h"
24 | #include "DDImage/Knobs.h"
25 | #include "DDImage/Thread.h"
26 | #include <string>
27 |
28 | // Local include files
29 | #include "MLClientComms.h"
30 | #include "MLClientModelManager.h"
31 |
32 | //! The Machine Learning (ML) Client plug-in connects Nuke to a Python server to apply ML models to images.
33 | /*! This plug-in can connect to a server (given a host and port), which responds
34 | with a list of available Machine Learning (ML) models and options.
35 |     On every \a renderStripe() call, the image and model options are sent from Nuke to the server,
36 | there the server can process the image by doing Machine Learning inference,
37 | finally the resulting image is sent back to Nuke.
38 | */
39 | class MLClient : public DD::Image::PlanarIop
40 | {
41 |
42 | public:
43 | // Static consts
44 | static const char* const kClassName;
45 | static const char* const kHelpString;
46 |
47 | static const char* const kDefaultHostName;
48 | static const int kDefaultPortNumber;
49 |
50 | private:
51 | static const DD::Image::ChannelSet kDefaultChannels;
52 | static const int kDefaultNumberOfChannels;
53 |
54 | public:
55 | //! Constructor. Initialize user controls to their default values.
56 | MLClient(Node* node);
57 | virtual ~MLClient();
58 |
59 | public:
60 | // DDImage::Iop overrides
61 |
62 | //! The maximum number of input connections the operator can have.
63 | int maximum_inputs() const;
64 | //! The minimum number of input connections the operator can have.
65 | int minimum_inputs() const;
66 | /*! Return the text Nuke should draw on the arrow head for input \a input
67 | in the DAG window. This should be a very short string, one letter
68 | ideally. Return null or an empty string to not label the arrow.
69 | */
70 | const char* input_label(int input, char* buffer) const;
71 |
72 | bool useStripes() const;
73 | bool renderFullPlanes() const;
74 |
75 | void _validate(bool);
76 | void getRequests(const DD::Image::Box& box, const DD::Image::ChannelSet& channels, int count, DD::Image::RequestOutput &reqData) const;
77 |
78 | /*! This function is called by Nuke for processing the current image.
79 | The image and model options are sent from Nuke to the server,
80 | where the server can process the image by running Machine Learning inference;
81 | finally, the resulting image is sent back to Nuke.
82 | The function tries to reconnect if no connection is set.
83 | */
84 | void renderStripe(DD::Image::ImagePlane& imagePlane);
85 |
86 | //! Information to the plug-in manager of DDNewImage/Nuke.
87 | static const DD::Image::Iop::Description description;
88 |
89 | static void addDynamicKnobs(void*, DD::Image::Knob_Callback);
90 | void knobs(DD::Image::Knob_Callback f);
91 | int knob_changed(DD::Image::Knob*);
92 |
93 | //! Return the name of the class.
94 | const char* Class() const;
95 | const char* node_help() const;
96 |
97 | MLClientModelManager& getModelManager();
98 | int getNumNewKnobs();
99 | void setNumNewKnobs(int i);
100 |
101 | private:
102 | // Private functions for talking to the server
103 | //! Try to connect to the server and set up the relevant knobs. Return true on
104 | //! success; otherwise return false and set a descriptive error in errorMsg.
105 | bool refreshModelsAndKnobsFromServer(std::string& errorMsg);
106 |
107 | //! Return whether we successfully managed to pull model
108 | //! info from the server at some time in the past, and the selected model is
109 | //! valid.
110 | bool haveValidModelInfo() const;
111 |
112 | //! Connect to the server, then send the inference request and read the inference response.
113 | //! Return true on success; otherwise return false and fill in errorMsg.
114 | bool processImage(const std::string& hostStr, int port, mlserver::RespondWrapper& responseWrapper, std::string& errorMsg);
115 |
116 | //! Parse the response message from the server and, if it contains
117 | //! an image, attempt to copy the image to the imagePlane. Return
118 | //! true on success; otherwise return false and fill in the error string.
119 | bool renderOutputBuffer(mlserver::RespondWrapper& responseWrapper, DD::Image::ImagePlane& imagePlane, std::string& errorMsg);
120 |
121 | //! Return whether the dynamic knobs should be shown or not.
122 | bool getShowDynamic() const;
123 |
124 | //! Look for the knob with the given name. If found, restore its value
125 | //! from the given serialised value.
126 | void restoreKnobValue(const std::string& knobName, const std::string& value);
127 |
128 | private:
129 | std::string _host;
130 | bool _hostIsValid;
131 | int _port;
132 | bool _portIsValid;
133 | int _chosenModel;
134 | bool _modelSelected;
135 |
136 | DD::Image::Knob* _selectedModelknob;
137 | std::vector<mlserver::Model> _serverModels;
138 |
139 | std::vector<int> _numInputs;
140 | std::vector<std::vector<std::string>> _inputNames;
141 |
142 | MLClientModelManager _modelManager;
143 |
144 | bool _showDynamic;
145 | int _numNewKnobs; // Used to track the number of knobs created by the previous pass, so that the same number can be deleted next time.
146 |
147 | };
148 |
149 | #endif // MLCLIENT_H
150 |
--------------------------------------------------------------------------------
/Models/mrcnn/model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019 Foundry.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | ##############################################################################
15 |
16 | # The inference method is based on:
17 | # --------------------------------------------------------
18 | # Facebook infer_simple.py file:
19 | # https://github.com/facebookresearch/Detectron/blob/master/tools/infer_simple.py
20 | # Copyright (c) Facebook, Inc. and its affiliates.
21 | # Licensed under the Apache License, Version 2.0
22 | # --------------------------------------------------------
23 |
24 | import copy
25 | import numpy as np
26 |
27 | from caffe2.python import workspace
28 | # import libcaffe2_detectron_ops_gpu.so
29 | import detectron.utils.c2 as c2_utils
30 | c2_utils.import_detectron_ops()
31 |
32 | from detectron.core.config import assert_and_infer_cfg
33 | from detectron.core.config import cfg
34 | from detectron.core.config import merge_cfg_from_file, merge_cfg_from_cfg
35 | from detectron.utils.collections import AttrDict
36 | from detectron.utils.io import cache_url
37 | from detectron.utils.logging import setup_logging
38 | from detectron.utils.timer import Timer
39 | import detectron.core.test_engine as infer_engine
40 | import detectron.datasets.dummy_datasets as dummy_datasets
41 | import detectron.utils.c2 as c2_utils
42 |
43 | from .vis import vis_one_image_binary, vis_one_image_opencv
44 | from .utils import dict_equal
45 | from ..common.util import linear_to_srgb, srgb_to_linear
46 | from ..baseModel import BaseModel
47 |
48 | class Model(BaseModel):
49 | def __init__(self):
50 | super(Model, self).__init__()
51 | self.name = 'Mask RCNN'
52 |
53 | # Configuration and weights options
54 | # By default, we use ResNet50 backbone architecture, you can switch to
55 | # ResNet101 to increase quality if your GPU memory is higher than 8GB.
56 | # To do so, you will need to download both .yaml and .pkl ResNet101 files
57 | # then replace the below 'cfg_file' with the following:
58 | # self.cfg_file = 'models/mrcnn/e2e_mask_rcnn_X-101-64x4d-FPN_2x.yaml'
59 | self.cfg_file = 'models/mrcnn/e2e_mask_rcnn_R-50-FPN_2x.yaml'
60 | self.weights = 'models/mrcnn/model_final.pkl'
61 | self.default_cfg = copy.deepcopy(AttrDict(cfg)) # cfg from detectron.core.config
62 | self.mrcnn_cfg = AttrDict()
63 | self.dummy_coco_dataset = dummy_datasets.get_coco_dataset()
64 |
65 | # Inference options
66 | self.show_box = True
67 | self.show_class = True
68 | self.thresh = 0.7
69 | self.alpha = 0.4
70 | self.show_border = True
71 | self.border_thick = 1
72 | self.bbox_thick = 1
73 | self.font_scale = 0.35
74 | self.binary_masks = False
75 |
76 | # Define exposed options
77 | self.options = (
78 | 'show_box', 'show_class', 'thresh', 'alpha', 'show_border',
79 | 'border_thick', 'bbox_thick', 'font_scale', 'binary_masks',
80 | )
81 | # Define inputs/outputs
82 | self.inputs = {'input': 3}
83 | self.outputs = {'output': 3}
84 |
85 | def inference(self, image_list):
86 | """Do an inference on the model with a set of inputs.
87 |
88 | # Arguments:
89 | image_list: The input image list
90 |
91 | Return the result of the inference.
92 | """
93 | image = image_list[0]
94 | image = linear_to_srgb(image)*255.
95 | imcpy = image.copy()
96 |
97 | # Initialize the model out of the configuration and weights files
98 | if not hasattr(self, 'model'):
99 | workspace.ResetWorkspace()
100 | # Reset to default config
101 | merge_cfg_from_cfg(self.default_cfg)
102 | # Load mask rcnn configuration file
103 | merge_cfg_from_file(self.cfg_file)
104 | assert_and_infer_cfg(cache_urls=False, make_immutable=False)
105 | self.model = infer_engine.initialize_model_from_cfg(self.weights)
106 | # Save mask rcnn full configuration file
107 | self.mrcnn_cfg = copy.deepcopy(AttrDict(cfg)) # cfg from detectron.core.config
108 | else:
109 | # There is a global config file for all detectron models (Densepose, Mask RCNN..)
110 | # Check if current global config file is correct for mask rcnn
111 | if not dict_equal(self.mrcnn_cfg, cfg):
112 | # Free memory of previous workspace
113 | workspace.ResetWorkspace()
114 | # Load mask rcnn configuration file
115 | merge_cfg_from_cfg(self.mrcnn_cfg)
116 | assert_and_infer_cfg(cache_urls=False, make_immutable=False)
117 | self.model = infer_engine.initialize_model_from_cfg(self.weights)
118 |
119 | with c2_utils.NamedCudaScope(0):
120 | # If using densepose/detectron GitHub, im_detect_all also returns cls_bodys
121 | # Only takes the first 3 elements of the list for compatibility
122 | cls_boxes, cls_segms, cls_keyps = infer_engine.im_detect_all(
123 | self.model, image[:, :, ::-1], None
124 | )[:3]
125 |
126 | if self.binary_masks:
127 | res = vis_one_image_binary(
128 | imcpy,
129 | cls_boxes,
130 | cls_segms,
131 | thresh=self.thresh
132 | )
133 | else:
134 | res = vis_one_image_opencv(
135 | imcpy,
136 | cls_boxes,
137 | cls_segms,
138 | cls_keyps,
139 | thresh=self.thresh,
140 | show_box=self.show_box,
141 | show_class=self.show_class,
142 | dataset=self.dummy_coco_dataset,
143 | alpha=self.alpha,
144 | show_border=self.show_border,
145 | border_thick=self.border_thick,
146 | bbox_thick=self.bbox_thick,
147 | font_scale=self.font_scale
148 | )
149 |
150 | res = srgb_to_linear(res.astype(np.float32) / 255.)
151 |
152 | return [res]
--------------------------------------------------------------------------------
/Models/trainingTemplateTF/README.md:
--------------------------------------------------------------------------------
1 | # Training Template: Train and Infer Models in the nuke-ML-server
2 |
3 | The TrainingTemplateTF model is a training template written in TensorFlow. It aims at quickly enabling image-to-image training using a multi-scale encoder-decoder model. When trained, the model can be tested and used directly in Nuke through the nuke-ML-server.
4 |
5 | For instance, if you have a set of noisy / clear image pairs and would like to train a model to be able to denoise an image, you simply need to place your data in the `trainingTemplateTF/data` folder and start the training with a single command line. You can monitor the training using TensorBoard and eventually test the trained model on live images in Nuke.
6 |
7 | This page contains instructions on how to use this training template. The training happens in the Docker container, while the inference is done through the MLClient plugin.
8 |
9 | ## Set-up
10 |
11 | Start by installing the nuke-ML-server (see [INSTALL.md](https://github.com/TheFoundryVisionmongers/nuke-ML-server/blob/master/INSTALL.md)). If you have already installed a previous version, you will still need to rebuild the docker image once:
12 | ```
13 | cd Plugins/Server/
14 | sudo docker build -t <docker_image_name> -f Dockerfile .
15 | ```
16 |
17 | To launch the [TensorBoard Visualisation](https://github.com/TheFoundryVisionmongers/nuke-ML-server/tree/master/Models/trainingTemplateTF#tensorboard) from within the Docker container, you have to run the docker container ([Run Docker Container](https://github.com/TheFoundryVisionmongers/nuke-ML-server/blob/master/INSTALL.md#run-docker-container) section) with port 6006 exported:
18 | ```
19 | sudo docker run --gpus all -v /absolute/path/to/nuke-ML-server/Models/:/workspace/ml-server/models -p 6006:6006 -it <docker_image_name>
20 | ```
21 |
22 | ## Train in Docker
23 |
24 | ### Dataset
25 |
26 | To train the ML algorithm, you need to provide a dataset of groundtruth & input image data pairs. For instance, the input data could be blurred images and the groundtruth the corresponding sharp images. In that case, you would like the model to learn to infer a sharp image from a blurred input image.
27 |
28 | Respectively place your input and groundtruth data in `trainingTemplateTF/data/train/input/` and `trainingTemplateTF/data/train/groundtruth/` folders.
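As a sanity check, you can quickly verify that the two folders pair up. The sketch below assumes (this is an assumption, not a guarantee of `train_model.py`) that input and groundtruth images correspond by sorted filename order, and that it is run from the `trainingTemplateTF` folder inside the container:
```
# Hedged sketch: check input/groundtruth pairing by sorted filename order.
import os, sys
sys.path.append(os.path.abspath('..'))        # make Models/common importable
from common.util import get_filepaths_from_dir

inputs = get_filepaths_from_dir('data/train/input')
targets = get_filepaths_from_dir('data/train/groundtruth')
assert len(inputs) == len(targets), "input and groundtruth counts differ"
for inp, gt in zip(inputs[:5], targets[:5]):
    print(os.path.basename(inp), '<->', os.path.basename(gt))
```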
29 |
30 | Optionally, you can add a separate set of image pairs in `trainingTemplateTF/data/val/input/` and `trainingTemplateTF/data/val/groundtruth/`. If this validation dataset is available, it is periodically evaluated with the current model weights to check that there is no overfitting on the training data. Please note that the validation dataset and training dataset must not intersect: no image pair should be found in both datasets.
31 |
32 | Notes:
33 | - The preprocessing cropping size is currently 256x256, therefore the dataset images are expected to be at least 256x256.
34 | - Supported image types are JPG, PNG, BMP and EXR.
35 | - Depending on the compression used, EXR images can be slower to read. In our experiments, the fastest EXR read is achieved with B44, B44A or no compression.
36 |
37 | ### Training
38 |
39 | Inside your docker container, go to the trainingTemplateTF folder:
40 | ```
41 | cd /workspace/ml-server/models/trainingTemplateTF
42 | ```
43 | Then directly train your model:
44 | ```
45 | python train_model.py
46 | ```
47 | You can also specify the batch size, learning rate and number of epochs:
48 | ```
49 | python train_model.py --bch=16 --lr=1e-4 --ep=1000
50 | ```
51 | It is now possible to have deterministic training. You will be able to reproduce your training (get the same model weights) by setting the seed to an integer of your choice (here 77):
52 | ```
53 | python train_model.py --seed=77
54 | ```
55 | We enable deterministic training in part by applying a GPU patch to stock TensorFlow; this GPU patch slows down training significantly. By adding the `--no-gpu-patch` flag to the previous command, you get slightly less deterministic training but keep the same training time.
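For reference, the kind of seeding involved typically looks like the following (a minimal sketch; the exact calls made by `train_model.py` may differ):
```
# Hedged sketch of a reproducible set-up; not the actual train_model.py code.
import random
import numpy as np
import tensorflow as tf

seed = 77
random.seed(seed)
np.random.seed(seed)
tf.compat.v1.set_random_seed(seed)  # seed the TensorFlow graph-level RNG
```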
56 |
57 | ### Potential Training Issues
58 |
59 | The principal issue you may hit when training is a GPU out-of-memory (OOM) error. To train with the default values, your GPU memory should be at least 8GB.
60 |
61 | If you hit an OOM error, you can consider reducing the GPU memory requirements (likely at the expense of the final model performance) by:
62 | - Building a simplified version of the encoder-decoder model found in [`model_builder.py`](https://github.com/TheFoundryVisionmongers/nuke-ML-server/blob/master/Models/common/model_builder.py) (e.g. by removing layers),
63 | - Reducing the batch size (`--bch` argument),
64 | - Or lowering the preprocessing cropping size (`crop_size` in [`train_model.py`](https://github.com/TheFoundryVisionmongers/nuke-ML-server/blob/master/Models/trainingTemplateTF/train_model.py)).
65 |
66 | During training, images are cropped as a preprocessing step before being fed to the network. Therefore, if you want your model to learn global image information (e.g. lens distortion), this cropping preprocessing should be changed in the code (e.g. use resize & padding instead, as sketched below) so as to keep the whole image information.
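For example, a minimal sketch of a resize-and-pad preprocessing step in TensorFlow (illustrative only, not the template's actual code):
```
# Hedged sketch: keep the whole frame by resizing with padding instead of cropping.
import tensorflow as tf

def resize_and_pad(image, target_size=256):
    # Resize so the image fits inside target_size x target_size (preserving
    # aspect ratio), then zero-pad to reach the exact output shape.
    return tf.image.resize_with_pad(image, target_size, target_size)
```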
67 |
68 | ### TensorBoard
69 |
70 | [TensorBoard](https://www.tensorflow.org/guide/summaries_and_tensorboard) is a great way to visualise how your training is progressing.
71 |
72 | The TrainingTemplateTF automatically saves learning rate and loss evolution as well as input, groundtruth and temporary output images in the `trainingTemplateTF/summaries/` folder.
73 |
74 | To view these TensorBoard summaries, first find which container is currently running your training (STATUS: Up, PORTS: 0.0.0.0:6006->6006/tcp, NAMES: `<container_name>`) from all the created docker containers:
75 | ```
76 | sudo docker ps -a
77 | ```
78 | Launch a second terminal connected to the same docker container, where `<container_name>` is the name of your training container found above:
79 | ```
80 | docker exec -it <container_name> bash
81 | ```
82 | Launch TensorBoard in this new docker terminal to view the progression in real-time in your browser:
83 | ```
84 | tensorboard --logdir models/trainingTemplateTF/summaries/
85 | ```
86 | From your host machine, you can now navigate to the following browser address to monitor your training: http://localhost:6006.
87 |
88 | ### Checkpoints
89 |
90 | During training, the model weights and graph are saved every N steps in the `trainingTemplateTF/checkpoints/` folder. A checkpoint name such as `trainingTemplateTF.model-375000` means that it contains the weights after 375,000 training steps of the trainingTemplateTF model.
91 |
92 | When launching a training, you can decide to start from scratch or to resume training from one of the previous checkpoints.
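For instance, the most recent checkpoint can be resolved programmatically (a sketch, not the exact resume logic of `train_model.py`):
```
# Hedged sketch: find the newest checkpoint in trainingTemplateTF/checkpoints/.
import tensorflow as tf

latest = tf.compat.v1.train.latest_checkpoint('checkpoints')
print(latest)  # e.g. 'checkpoints/trainingTemplateTF.model-375000', or None if empty
```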
93 |
94 | ## Inference in Nuke
95 |
96 | After training your model inside the docker container, you can launch Nuke and select the `Training Template TF` model in the MLClient node.
97 |
98 | The plugin will automatically load the most recent trained checkpoint found in `trainingTemplateTF/checkpoints/`, and run inference using the loaded weights and graph. If you prefer to use an older checkpoint, you can write the name of a previous checkpoint as an inference option in Nuke.
99 |
100 | This is a great way to verify, on your own live data, that the model weights converged correctly without overfitting on the training data.
101 |
102 | Note: the inference is done on saved checkpoints and not on a frozen graph, which implies that the saved checkpoint graph must correspond to the current graph. If you change the graph (by changing the preprocessing step, number of layers, variable names etc.), you won't directly be able to load older checkpoints built on a different graph.
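If loading fails after such a change, one way to diagnose the mismatch is to list the variables stored in a checkpoint and compare them with the current graph (a diagnostic sketch only; the checkpoint name below is the example from the previous section):
```
# Hedged sketch: list variable names and shapes saved in a checkpoint.
import tensorflow as tf

reader = tf.train.load_checkpoint('checkpoints/trainingTemplateTF.model-375000')
for name, shape in sorted(reader.get_variable_to_shape_map().items()):
    print(name, shape)
```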
--------------------------------------------------------------------------------
/Models/regressionTemplateTF/model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Foundry.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | ##############################################################################
15 |
16 | from __future__ import print_function
17 |
18 | import sys
19 | import os
20 | import time
21 |
22 | import scipy.misc
23 | import numpy as np
24 | import cv2
25 |
26 | import tensorflow as tf
27 | tf.compat.v1.disable_eager_execution() # For TF 2.x compatibility
28 |
29 | from models.baseModel import BaseModel
30 | from models.common.model_builder import baseline_model
31 | from models.common.util import print_, get_ckpt_list, linear_to_srgb, srgb_to_linear
32 |
33 | import message_pb2
34 |
35 | class Model(BaseModel):
36 | """Load your trained model and do inference in Nuke"""
37 |
38 | def __init__(self):
39 | super(Model, self).__init__()
40 | self.name = 'Regression Template TF'
41 | self.n_levels = 3
42 | self.scale = 0.5
43 | dir_path = os.path.dirname(os.path.realpath(__file__))
44 | self.checkpoints_dir = os.path.join(dir_path, 'checkpoints')
45 | self.patch_size = 50
46 | self.output_param_number = 1
47 |
48 | # Initialise checkpoint name to the latest checkpoint
49 | ckpt_names = get_ckpt_list(self.checkpoints_dir)
50 | if not ckpt_names: # empty list
51 | self.checkpoint_name = ''
52 | else:
53 | latest_ckpt = tf.compat.v1.train.latest_checkpoint(self.checkpoints_dir)
54 | if latest_ckpt is not None:
55 | self.checkpoint_name = latest_ckpt.split('/')[-1]
56 | else:
57 | self.checkpoint_name = ckpt_names[-1]
58 | self.prev_ckpt_name = self.checkpoint_name
59 |
60 | # Silence TF log when creating tf.Session()
61 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
62 |
63 | # Define options
64 | self.gamma_to_predict = 1.0
65 | self.predict = False
66 | self.options = ('checkpoint_name', 'gamma_to_predict',)
67 | self.buttons = ('predict',)
68 | # Define inputs/outputs
69 | self.inputs = {'input': 3}
70 | self.outputs = {'output': 3}
71 |
72 | def load(self, model):
73 | # Check if empty or invalid checkpoint name
74 | if self.checkpoint_name=='':
75 | ckpt_names = get_ckpt_list(self.checkpoints_dir)
76 | if not ckpt_names:
77 | raise ValueError("No checkpoints found in {}".format(self.checkpoints_dir))
78 | else:
79 | raise ValueError("Empty checkpoint name, try an available checkpoint in {} (ex: {})"
80 | .format(self.checkpoints_dir, ckpt_names[-1]))
81 | print_("Loading trained model checkpoint...\n", 'm')
82 | # Load from given checkpoint file name
83 | self.saver.restore(self.sess, os.path.join(self.checkpoints_dir, self.checkpoint_name))
84 | print_("...Checkpoint {} loaded\n".format(self.checkpoint_name), 'm')
85 |
86 | def inference(self, image_list):
87 | """Do an inference on the model with a set of inputs.
88 |
89 | # Arguments:
90 | image_list: The input image list
91 |
92 | Return the result of the inference.
93 | """
94 | image = image_list[0]
95 | image = linear_to_srgb(image).copy()
96 |
97 | if not hasattr(self, 'sess'):
98 | # Initialise tensorflow graph
99 | tf.compat.v1.reset_default_graph()
100 | config = tf.compat.v1.ConfigProto()
101 | config.gpu_options.allow_growth=True
102 | self.sess=tf.compat.v1.Session(config=config)
103 | # Input is stacked histograms of original and gamma-graded images.
104 | input_shape = [1, 2, 100]
105 | # Initialise input placeholder size
106 | self.input = tf.compat.v1.placeholder(tf.float32, shape=input_shape)
107 | self.model = baseline_model(
108 | input_shape=input_shape[1:],
109 | output_param_number=self.output_param_number)
110 | self.infer_op = self.model(self.input)
111 | # Load latest model checkpoint
112 | self.saver = tf.compat.v1.train.Saver()
113 | self.load(self.model)
114 | self.prev_ckpt_name = self.checkpoint_name
115 |
116 | # If checkpoint name has changed, load new checkpoint
117 | if self.prev_ckpt_name != self.checkpoint_name or self.checkpoint_name == '':
118 | self.load(self.model)
119 | # If checkpoint correctly loaded, update previous checkpoint name
120 | self.prev_ckpt_name = self.checkpoint_name
121 |
122 | # Preprocess image same way we preprocessed it for training
123 | # Here for gamma correction compute histograms
124 | def histogram(x, value_range=[0.0, 1.0], nbins=100):
125 | """Return histogram of tensor x"""
126 | h, w, c = x.shape
127 | hist = tf.histogram_fixed_width(x, value_range, nbins=nbins)
128 | hist = tf.divide(hist, h * w * c)
129 | return hist
130 | with tf.compat.v1.Session() as sess:
131 | # Convert to grayscale
132 | img_gray = tf.image.rgb_to_grayscale(image)
133 | img_gray = tf.image.resize(img_gray, [self.patch_size, self.patch_size])
134 | # Apply gamma correction
135 | img_gray_grade = tf.math.pow(img_gray, self.gamma_to_predict)
136 | img_grade = tf.math.pow(image, self.gamma_to_predict)
137 | # Compute histograms
138 | img_hist = histogram(img_gray)
139 | img_grade_hist = histogram(img_gray_grade)
140 | hists_op = tf.stack([img_hist, img_grade_hist], axis=0)
141 | hists, img_grade = sess.run([hists_op, img_grade])
142 | res_img = srgb_to_linear(img_grade)
143 |
144 | hists_batch = np.expand_dims(hists, 0)
145 | start = time.time()
146 | # Run model inference
147 | inference = self.sess.run(self.infer_op, feed_dict={self.input: hists_batch})
148 | duration = time.time() - start
149 | print('Inference duration: {:4.3f}s'.format(duration))
150 | res = inference[-1]
151 | print("Predicted gamma: {}".format(res))
152 |
153 | # If predict button is pressed in Nuke
154 | if self.predict:
155 | script_msg = message_pb2.FieldValuePairAttrib()
156 | script_msg.name = "PythonScript"
157 | # Create a Python script message to run in Nuke
158 | python_script = self.nuke_script(res)
159 | script_msg_val = script_msg.values.add()
160 | script_msg_str = script_msg_val.string_attributes.add()
161 | script_msg_str.values.extend([python_script])
162 | return [res_img, script_msg]
163 |
164 | return [res_img]
165 |
166 | def nuke_script(self, res):
167 | """Return the Python script function to create a pop up window in Nuke."""
168 | popup_msg = "Predicted gamma: {}".format(res)
169 | script = "nuke.message('{}')\n".format(popup_msg)
170 | return script
--------------------------------------------------------------------------------
/Models/common/model_builder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019 Foundry.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | ##############################################################################
15 |
16 | from builtins import range # python 2/3 forward-compatible (xrange)
17 | import tensorflow as tf
18 |
19 | class ResNetBlock(tf.keras.layers.Layer):
20 | """Classic ResNet residual block"""
21 |
22 | def __init__(self, new_dim=32, ksize=5, name='resblock'):
23 | super(ResNetBlock, self).__init__(name=name)
24 | self.conv2D_1 = tf.keras.layers.Conv2D(
25 | filters=new_dim, kernel_size=ksize, padding='SAME',
26 | activation=tf.nn.relu, name='conv1')
27 | self.conv2D_2 = tf.keras.layers.Conv2D(
28 | filters=new_dim, kernel_size=ksize, padding='SAME',
29 | activation=None, name='conv2')
30 |
31 | def call(self, inputs):
32 | x = self.conv2D_1(inputs)
33 | x = self.conv2D_2(x)
34 | return x + inputs
35 |
36 | class EncoderDecoder(tf.keras.Model):
37 | """Create an encoder decoder model"""
38 |
39 | def __init__(self, n_levels, scale, channels, name='g_net'):
40 | super(EncoderDecoder, self).__init__(name=name)
41 | self.n_levels = n_levels
42 | self.scale = scale
43 |
44 | # Encoder layers
45 | self.conv1_1 = tf.keras.layers.Conv2D(
46 | filters=32, kernel_size=5, padding='SAME',
47 | activation=tf.nn.relu, name='enc1_1')
48 | self.block1_2 = ResNetBlock(32, 5, name='enc1_2')
49 | self.block1_3 = ResNetBlock(32, 5, name='enc1_3')
50 | self.block1_4 = ResNetBlock(32, 5, name='enc1_4')
51 | self.conv2_1 = tf.keras.layers.Conv2D(
52 | filters=64, kernel_size=5, strides=2,
53 | padding='SAME', activation=tf.nn.relu, name='enc2_1')
54 | self.block2_2 = ResNetBlock(64, 5, name='enc2_2')
55 | self.block2_3 = ResNetBlock(64, 5, name='enc2_3')
56 | self.block2_4 = ResNetBlock(64, 5, name='enc2_4')
57 | self.conv3_1 = tf.keras.layers.Conv2D(
58 | filters=128, kernel_size=5, strides=2,
59 | padding='SAME', activation=tf.nn.relu, name='enc3_1')
60 | self.block3_2 = ResNetBlock(128, 5, name='enc3_2')
61 | self.block3_3 = ResNetBlock(128, 5, name='enc3_3')
62 | self.block3_4 = ResNetBlock(128, 5, name='enc3_4')
63 | # Decoder layers
64 | self.deblock3_3 = ResNetBlock(128, 5, name='dec3_3')
65 | self.deblock3_2 = ResNetBlock(128, 5, name='dec3_2')
66 | self.deblock3_1 = ResNetBlock(128, 5, name='dec3_1')
67 | self.deconv2_4 = tf.keras.layers.Conv2DTranspose(
68 | filters=64, kernel_size=4, strides=2,
69 | padding='SAME', activation=tf.nn.relu, name='dec2_4')
70 | self.deblock2_3 = ResNetBlock(64, 5, name='dec2_3')
71 | self.deblock2_2 = ResNetBlock(64, 5, name='dec2_2')
72 | self.deblock2_1 = ResNetBlock(64, 5, name='dec2_1')
73 | self.deconv1_4 = tf.keras.layers.Conv2DTranspose(
74 | filters=32, kernel_size=4, strides=2,
75 | padding='SAME', activation=tf.nn.relu, name='dec1_4')
76 | self.deblock1_3 = ResNetBlock(32, 5, name='dec1_3')
77 | self.deblock1_2 = ResNetBlock(32, 5, name='dec1_2')
78 | self.deblock1_1 = ResNetBlock(32, 5, name='dec1_1')
79 | self.deconv0_4 = tf.keras.layers.Conv2DTranspose(
80 | filters=channels, kernel_size=5, padding='SAME',
81 | activation=None, name='dec1_0')
82 |
83 | def call(self, inputs, reuse=False):
84 | # Apply encoder decoder
85 | n, h, w, c = inputs.get_shape().as_list()
86 | n_outputs = []
87 | input_pred = inputs
88 | with tf.compat.v1.variable_scope('', reuse=reuse):
89 | for i in range(self.n_levels):
90 | scale = self.scale ** (self.n_levels - i - 1)
91 | hi = int(round(h * scale))
92 | wi = int(round(w * scale))
93 | input_init = tf.image.resize(inputs, [hi, wi], method='bilinear')
94 | input_pred = tf.stop_gradient(tf.image.resize(input_pred, [hi, wi], method='bilinear'))
95 | input_all = tf.concat([input_init, input_pred], axis=3, name='inp')
96 |
97 | # Encoder
98 | conv1_1 = self.conv1_1(input_all)
99 | conv1_2 = self.block1_2(conv1_1)
100 | conv1_3 = self.block1_3(conv1_2)
101 | conv1_4 = self.block1_4(conv1_3)
102 | conv2_1 = self.conv2_1(conv1_4)
103 | conv2_2 = self.block2_2(conv2_1)
104 | conv2_3 = self.block2_3(conv2_2)
105 | conv2_4 = self.block2_4(conv2_3)
106 | conv3_1 = self.conv3_1(conv2_4)
107 | conv3_2 = self.block3_2(conv3_1)
108 | conv3_3 = self.block3_3(conv3_2)
109 | encoded = self.block3_4(conv3_3)
110 |
111 | # Decoder
112 | deconv3_3 = self.deblock3_3(encoded)
113 | deconv3_2 = self.deblock3_2(deconv3_3)
114 | deconv3_1 = self.deblock3_1(deconv3_2)
115 | deconv2_4 = self.deconv2_4(deconv3_1)
116 | cat2 = deconv2_4 + conv2_4 # Skip connection
117 | deconv2_3 = self.deblock2_3(cat2)
118 | deconv2_2 = self.deblock2_2(deconv2_3)
119 | deconv2_1 = self.deblock2_1(deconv2_2)
120 | deconv1_4 = self.deconv1_4(deconv2_1)
121 | cat1 = deconv1_4 + conv1_4 # Skip connection
122 | deconv1_3 = self.deblock1_3(cat1)
123 | deconv1_2 = self.deblock1_2(deconv1_3)
124 | deconv1_1 = self.deblock1_1(deconv1_2)
125 | input_pred = self.deconv0_4(deconv1_1)
126 |
127 | if i >= 0:
128 | n_outputs.append(input_pred)
129 | if i == 0:
130 | tf.compat.v1.get_variable_scope().reuse_variables()
131 | return n_outputs
132 |
133 | def mobilenet_transfer(class_number):
134 | """Return a classification model with a mobilenet backbone pretrained on ImageNet
135 |
136 | # Arguments:
137 | class_number: Number of classes / labels to detect
138 | """
139 | # Import the mobilenet model and discard the last 1000-neuron layer.
140 | base_model = tf.keras.applications.MobileNet(input_shape=(224,224,3), weights='imagenet',include_top=False, pooling='avg')
141 |
142 | x = base_model.output
143 | x = tf.keras.layers.Dense(1024,activation='relu')(x)
144 | x = tf.keras.layers.Dense(1024,activation='relu')(x)
145 | x = tf.keras.layers.Dense(512,activation='relu')(x)
146 | # Final layer with softmax activation
147 | preds = tf.keras.layers.Dense(class_number,activation='softmax')(x)
148 | # Build the model
149 | model = tf.keras.models.Model(inputs=base_model.input,outputs=preds)
150 |
151 | # Freeze base_model
152 | # for layer in base_model.layers: # <=> to [:86]
153 | # layer.trainable = False
154 | # Freeze the first 60 layers and fine-tune the rest
155 | for layer in model.layers[:60]:
156 | layer.trainable=False
157 | for layer in model.layers[60:]:
158 | layer.trainable=True
159 |
160 | return model
161 |
162 | def baseline_model(input_shape, output_param_number=1, hidden_layer_size=16):
163 | """Return a fully connected model with 1 hidden layer"""
164 | if hidden_layer_size < output_param_number:
165 | raise ValueError("Neurons in the hidden layer (={}) \
166 | should be > output param number (={})".format(
167 | hidden_layer_size, output_param_number))
168 | model = tf.keras.Sequential()
169 | model.add(tf.keras.layers.Flatten(input_shape=input_shape))
170 | # Regular densely connected NN layer
171 | model.add(tf.keras.layers.Dense(hidden_layer_size, activation=tf.nn.relu))
172 | model.add(tf.keras.layers.Dense(output_param_number, activation=None)) # linear activation
173 | return model
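# --- Hedged usage sketch (illustration only; not part of the original file) ---
# Build the regression baseline for the stacked-histogram input used by
# regressionTemplateTF: two 100-bin histograms per sample, one output parameter.
if __name__ == '__main__':
    example_model = baseline_model(input_shape=(2, 100), output_param_number=1)
    example_model.summary()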
--------------------------------------------------------------------------------
/Models/common/util.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019 Foundry.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | ##############################################################################
15 |
16 | import sys
17 | import os
18 | import re
19 |
20 | import numpy as np
21 | import OpenEXR, Imath
22 | import cv2
23 |
24 | import tensorflow as tf
25 |
26 | def print_(str, colour='', bold=False):
27 | if colour == 'w': # yellow warning
28 | sys.stdout.write('\033[93m')
29 | elif colour == "e": # red error
30 | sys.stdout.write('\033[91m')
31 | elif colour == "m": # magenta info
32 | sys.stdout.write('\033[95m')
33 | if bold:
34 | sys.stdout.write('\033[1m')
35 | sys.stdout.write(str)
36 | sys.stdout.write('\033[0m')
37 | sys.stdout.flush()
38 |
39 | ## GET DATA ##
40 |
41 | def get_filepaths_from_dir(dir_path):
42 | """Recursively walk through the given directory and return a list of file paths"""
43 | data_list = []
44 | for (root, directories, filenames) in os.walk(dir_path):
45 | directories.sort()
46 | filenames.sort()
47 | for filename in filenames:
48 | data_list += [os.path.join(root,filename)]
49 | return data_list
50 |
51 | def get_labels_from_dir(dir_path):
52 | """Return classification class labels (= first subdirectories names)"""
53 | labels_list = []
54 | for (root, directories, filenames) in os.walk(dir_path):
55 | directories.sort()
56 | labels_list += directories
57 | # Break to only keep the top directory
58 | break
59 | # Remove '.' in folder names for label retrieval in model.py
60 | labels_list = [''.join(label.split('.')) for label in labels_list]
61 | return labels_list
62 |
63 | def atoi(text):
64 | return int(text) if text.isdigit() else text
65 |
66 | def natural_keys(text):
67 | """Use mylist.sort(key=natural_keys) to sort mylist in human order"""
68 | return [atoi(c) for c in re.split(r'(\d+)', text)]
69 |
70 | def get_ckpt_list(ckpt_dir):
71 | filenames_list = []
72 | for (root, directories, filenames) in os.walk(ckpt_dir):
73 | filenames_list += filenames
74 | # Break to only keep the top directory
75 | break
76 | ckpt_list = []
77 | for filename in filenames_list:
78 | split = filename.split('.')
79 | if len(split) > 1 and split[-1] == 'index':
80 | # remove .index to get the checkpoint name
81 | ckpt_list += [filename[:-6]]
82 | ckpt_list.sort(key=natural_keys)
83 | return ckpt_list
84 |
85 | def get_saved_model_list(ckpt_dir):
86 | """Return a list of HDF5 models found in ckpt_dir"""
87 | filenames_list = []
88 | for (root, directories, filenames) in os.walk(ckpt_dir):
89 | filenames_list += filenames
90 | # Break to only keep the top directory
91 | break
92 | ckpt_list = []
93 | for filename in filenames_list:
94 | if filename.endswith(('.h5', '.hdf5')):
95 | ckpt_list += [filename]
96 | ckpt_list.sort(key=natural_keys)
97 | return ckpt_list
98 |
99 | ## PROCESS DATA ##
100 |
101 | def im2uint8(x):
102 | if x.__class__ == tf.Tensor:
103 | return tf.cast(tf.clip_by_value(x, 0.0, 1.0) * 255.0, tf.uint8)
104 | else:
105 | t = np.clip(x, 0.0, 1.0) * 255.0
106 | return t.astype(np.uint8)
107 |
108 | def srgb_to_linear(x):
109 | """Transform the image from sRGB to linear"""
110 | a = 0.055
111 | x = np.clip(x, 0, 1)
112 | mask = x < 0.04045
113 | x[mask] /= 12.92
114 | x[mask!=True] = np.exp(2.4 * (np.log(x[mask!=True] + a) - np.log(1 + a)))
115 | return x
116 |
117 | def linear_to_srgb(x):
118 | """Transform the image from linear to sRGB"""
119 | a = 0.055
120 | x = np.clip(x, 0, 1)
121 | mask = x <= 0.0031308
122 | x[mask] *= 12.92
123 | x[mask!=True] = np.exp(np.log(1 + a) + (1/2.4) * np.log(x[mask!=True])) - a
124 | return x
125 |
126 | ## EXR DATA UTILS ##
127 |
128 | """
129 | EXR utility functions have to be wrapped in a TensorFlow graph by using
130 | tf.numpy_function(). This function requires a specific fixed return type,
131 | which is why all EXR reading functions are of return type float32.
132 | """
133 | # Imath.PixelType can have UINT unint32, HALF float16, FLOAT float32
134 | EXR_PIX_TYPE = Imath.PixelType(Imath.PixelType.FLOAT)
135 | EXR_NP_TYPE = np.float32
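# Hedged illustration (not in the original file): one way the EXR readers below
# could be wrapped with tf.numpy_function for a tf.data pipeline, assuming
# 3-channel RGB images; `read_exr` is defined further down in this module.
def tf_read_exr(exr_path):
    def _read(path):
        # tf.numpy_function passes the path tensor as bytes; decode if needed
        return read_exr(path.decode() if isinstance(path, bytes) else path)
    image = tf.numpy_function(_read, [exr_path], tf.float32)
    image.set_shape([None, None, 3])  # fixed float32 dtype, unknown height/width
    return image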
136 |
137 | def is_exr(filename):
138 | file_extension = os.path.splitext(filename)[1][1:]
139 | if file_extension in ['exr', 'EXR']:
140 | return True
141 | elif file_extension in ['jpg', 'jpeg', 'png', 'bmp', 'JPG', 'JPEG', 'PNG', 'BMP']:
142 | return False
143 | else:
144 | raise TypeError("Unhandled file extension: {}. Should be one of "
145 | "['jpg', 'jpeg', 'png', 'bmp', 'exr']".format(file_extension))
146 |
147 | def check_exr(exr_files, channel_names=['R', 'G', 'B']):
148 | """Check that exr_files (a list of EXR file(s)) have the requested channels
149 | and have the same data window size. Return image width and height.
150 | """
151 | if not list(channel_names):
152 | raise ValueError("channel_names is empty")
153 | if isinstance(exr_files, OpenEXR.InputFile): # single exr file
154 | exr_files = [exr_files]
155 | elif not isinstance(exr_files, list):
156 | raise TypeError("type(exr_files): {}, should be str or list".format(type(exr_files)))
157 | # Check data window size
158 | data_windows = [str(exr.header()['dataWindow']) for exr in exr_files]
159 | if any(dw != data_windows[0] for dw in data_windows):
160 | raise ValueError("input and groundtruth .exr images have different size")
161 | # Check channel to read are present in given exr file(s)
162 | channels_headers = [exr.header()['channels'] for exr in exr_files]
163 | for channels in channels_headers:
164 | if any(c not in list(channels.keys()) for c in channel_names):
165 | raise ValueError("Try to read channels {} of an exr image with channels {}"
166 | .format(channel_names, list(channels.keys())))
167 | # Compute the size
168 | dw = exr_files[0].header()['dataWindow']
169 | width = dw.max.x - dw.min.x + 1
170 | height = dw.max.y - dw.min.y + 1
171 | return width, height
172 |
173 | def read_exr(exr_path, channel_names=['R', 'G', 'B']):
174 | """Read requested channels of an exr and return them in a numpy array
175 | """
176 | # Open and check the input file
177 | exr_file = OpenEXR.InputFile(exr_path)
178 | width, height = check_exr(exr_file, channel_names)
179 | # Copy channels from an exr file into a numpy array
180 | exr_numpy = [np.frombuffer(exr_file.channel(c, EXR_PIX_TYPE), dtype=EXR_NP_TYPE)
181 | .reshape(height, width) for c in channel_names]
182 | exr_numpy = np.stack(exr_numpy, axis=-1)
183 | return exr_numpy
184 |
185 | def read_resize_exr(exr_path, patch_size, channel_names=['R', 'G', 'B']):
186 | """Read requested channels of an exr as numpy array
187 | and return them resized to (patch_size, patch_size)
188 | """
189 | exr = read_exr(exr_path, channel_names)
190 | exr_resize = cv2.resize(exr, dsize=(patch_size, patch_size))
191 | return exr_resize
192 |
193 | def read_crop_exr(exr_file, size, crop_w, crop_h, crop_size=256, channel_names=['R', 'G', 'B']):
194 | """Read requested channels of an exr file, crop it and return it as numpy array
195 |
196 | The cropping box has a size of crop_size and its bottom left point is (crop_h, crop_w)
197 | """
198 | # Read only the crop scanlines, not the full EXR image
199 | cnames = ''.join(channel_names)
200 | channels = exr_file.channels(cnames=cnames, pixel_type=EXR_PIX_TYPE,
201 | scanLine1=crop_h, scanLine2=crop_h + crop_size - 1)
202 | exr_crop = np.zeros([crop_size, crop_size, len(channel_names)], dtype=EXR_NP_TYPE)
203 | for idx, c in enumerate(channel_names):
204 | exr_crop[:,:,idx] = (np.frombuffer(channels[idx], dtype=EXR_NP_TYPE)
205 | .reshape(crop_size, size[0])[:, crop_w:crop_w+crop_size])
206 | return exr_crop
207 |
208 | def read_crop_exr_pair(exr_path_in, exr_path_gt, crop_size=256, channel_names=['R', 'G', 'B']):
209 | """Read requested channels of input and groundtruth .exr image paths
210 | and return the same random crop of both
211 | """
212 | # Open the input file
213 | exr_file_in = OpenEXR.InputFile(exr_path_in)
214 | exr_file_gt = OpenEXR.InputFile(exr_path_gt)
215 | width, height = check_exr([exr_file_in, exr_file_gt], channel_names)
216 | # Check exr image width and height >= crop_size
217 | if height < crop_size or width < crop_size:
218 | raise ValueError("Input images size should be superior or equal to crop_size: {} < ({},{})"
219 | .format((width, height), crop_size, crop_size))
220 | # Get random crop value
221 | randw = np.random.randint(0, width-crop_size) if width-crop_size > 0 else 0
222 | randh = np.random.randint(0, height-crop_size) if height-crop_size > 0 else 0
223 | # Get the crop of input and groundtruth .exr images
224 | exr_crop_in = read_crop_exr(exr_file_in, (width, height), randw, randh, crop_size, channel_names)
225 | exr_crop_gt = read_crop_exr(exr_file_gt, (width, height), randw, randh, crop_size, channel_names)
226 | return [exr_crop_in, exr_crop_gt]
--------------------------------------------------------------------------------
/Models/classTemplateTF/train_classification.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import sys
4 | import os
5 | import time
6 | import random
7 | import argparse
8 | from datetime import datetime
9 |
10 | import scipy.misc
11 | import numpy as np
12 |
13 | import tensorflow as tf
14 | tf.compat.v1.disable_eager_execution() # For TF 2.x compatibility
15 |
16 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
17 | from common.model_builder import mobilenet_transfer
18 | from common.util import im2uint8, get_filepaths_from_dir, get_saved_model_list, get_labels_from_dir, print_
19 |
20 | class TrainModel(object):
21 | """Train the chosen model from the given input and groundtruth data"""
22 |
23 | def __init__(self, args):
24 | # Training hyperparameters
25 | self.learning_rate = args.learning_rate
26 | self.batch_size = args.batch_size
27 | self.epoch = args.epoch
28 | self.save_model_period = 1 # save model weights every N epochs
29 | # Training and validation dataset paths
30 | self.train_data_path = './data/train'
31 | self.val_data_path = './data/validation'
32 | # Where to save and load model weights (=checkpoints)
33 | self.checkpoints_dir = './checkpoints'
34 | if not os.path.exists(self.checkpoints_dir):
35 | os.makedirs(self.checkpoints_dir)
36 | self.ckpt_save_name = 'classTemplate'
37 | # Where to save tensorboard summaries
38 | self.summaries_dir = './summaries'
39 | if not os.path.exists(self.summaries_dir):
40 | os.makedirs(self.summaries_dir)
41 |
42 | # Get training dataset as lists of image paths
43 | self.train_gt_data_list = get_filepaths_from_dir(self.train_data_path)
44 | if len(self.train_gt_data_list) == 0:
45 | raise ValueError("No training data found in folder {}".format(self.train_data_path))
46 | elif (len(self.train_gt_data_list) < self.batch_size):
47 | raise ValueError("Batch size must be smaller than the dataset (batch size = {}, number of training data = {})"
48 | .format(self.batch_size, len(self.train_gt_data_list)))
49 |
50 | # Get validation dataset if provided
51 | self.has_val_data = True
52 | self.val_gt_data_list = get_filepaths_from_dir(self.val_data_path)
53 | if len(self.val_gt_data_list) == 0:
54 | print("No validation data found in {}, 20% of training data will be used as validation data".format(self.val_data_path))
55 | self.has_val_data = False
56 | self.validation_split = 0.2
57 | elif (len(self.val_gt_data_list) < self.batch_size):
58 | raise ValueError("Batch size must be smaller than the dataset (batch size = {}, number of validation data = {})"
59 | .format(self.batch_size, len(self.val_gt_data_list)))
60 | else:
61 | print_("Number of validation data: {}\n".format(len(self.val_gt_data_list)), 'm')
62 | self.validation_split = 0.0
63 |
64 | self.train_labels = get_labels_from_dir(self.train_data_path)
65 | # Check class labels are the same
66 | if self.has_val_data:
67 | self.val_labels = get_labels_from_dir(self.val_data_path)
68 | if self.train_labels != self.val_labels:
69 | if len(self.train_labels) != len(self.val_labels):
70 | raise ValueError("{} and {} should have the same number of subdirectories ({}!={})"
71 | .format(self.train_data_path, self.val_data_path, len(self.train_labels), len(self.val_labels)))
72 | raise ValueError("{} and {} should have the same subdirectory label names ({}!={})"
73 | .format(self.train_data_path, self.val_data_path, self.train_labels, self.val_labels))
74 |
75 | # Compute and print training hyperparameters
76 | self.batch_per_epoch = int(np.ceil(len(self.train_gt_data_list) / float(self.batch_size)))
77 | self.max_steps = int(self.epoch * (self.batch_per_epoch))
78 | print_("Number of training data: {}\nNumber of batches per epoch: {} (batch size = {})\nNumber of training steps for {} epochs: {}\n"
79 | .format(len(self.train_gt_data_list), self.batch_per_epoch, self.batch_size, self.epoch, self.max_steps), 'm')
80 | print("Class labels: {}".format(self.train_labels))
81 |
82 | def train(self):
83 | # Build model
84 | self.model = mobilenet_transfer(len(self.train_labels))
85 | # Configure the model for training
86 | self.model.compile(optimizer=tf.keras.optimizers.Adam(),
87 | loss='categorical_crossentropy',
88 | metrics=['accuracy'])
89 | # Print current model layers
90 | # self.model.summary()
91 |
92 | # Set preprocessing function
93 | datagen = tf.keras.preprocessing.image.ImageDataGenerator(
94 | # scale pixels between -1 and 1, sample-wise
95 | preprocessing_function=tf.keras.applications.mobilenet.preprocess_input,
96 | validation_split=self.validation_split)
97 | # Get classification data
98 | train_generator=datagen.flow_from_directory(
99 | self.train_data_path,
100 | target_size=(224,224),
101 | color_mode='rgb',
102 | batch_size=self.batch_size,
103 | class_mode='categorical',
104 | shuffle=True,
105 | subset='training')
106 | if self.has_val_data:
107 | validation_generator=datagen.flow_from_directory(
108 | self.val_data_path,
109 | target_size=(224,224),
110 | color_mode='rgb',
111 | batch_size=self.batch_size,
112 | class_mode='categorical',
113 | shuffle=True)
114 | else: # Generate a split of the training data as validation data
115 | validation_generator=datagen.flow_from_directory(
116 | self.train_data_path, # subset from training data path
117 | target_size=(224,224),
118 | color_mode='rgb',
119 | batch_size=self.batch_size,
120 | class_mode='categorical',
121 | shuffle=True,
122 | subset='validation')
123 |
124 | # Callback for creating Tensorboard summary
125 | summary_name = "classif_data{}_bch{}_ep{}".format(len(self.train_gt_data_list), self.batch_size, self.epoch)
126 | tensorboard_callback = tf.keras.callbacks.TensorBoard(
127 | log_dir=os.path.join(self.summaries_dir, summary_name))
128 | # Callback for saving models periodically
129 | class_labels_save = '_'.join(self.train_labels) + '.'
130 | # 'acc' is the training accuracy and 'val_acc' is the validation set accuracy
131 | self.ckpt_save_name = class_labels_save + self.ckpt_save_name + "-val_acc{val_acc:.2f}-acc{acc:.2f}-ep{epoch:04d}.h5"
132 | checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
133 | filepath=os.path.join(self.checkpoints_dir, self.ckpt_save_name),
134 | save_weights_only=False,
135 | period=self.save_model_period,
136 | save_best_only=True, monitor='val_acc', mode='max'
137 | )
138 |
139 | # Check if there is an intermediate trained model to load
140 | # Uncomment following lines if you want to resume from a previous saved model
141 | # if not self.load_model():
142 | # print_("Starting training from scratch\n", 'm')
143 |
144 | # Train the model
145 | fit_history = self.model.fit_generator(
146 | generator=train_generator,
147 | steps_per_epoch=train_generator.n // self.batch_size,
148 | validation_data=validation_generator,
149 | validation_steps= validation_generator.n // self.batch_size,
150 | epochs=self.epoch,
151 | callbacks=[checkpoint_callback, tensorboard_callback])
152 |
153 | print_("--------End of training--------\n", 'm')
154 |
155 | def load_model(self):
156 | """Ask user if start training from scratch or resume from a previous checkpoint
157 |
158 | If resume, load model in self.model and return True, else return False
159 | """
160 | ckpt_names = get_saved_model_list(self.checkpoints_dir)
161 | if not ckpt_names: # list is empty
162 | print_("No checkpoints found in {}\n".format(self.checkpoint_dir), 'm')
163 | return False
164 | else:
165 | print_("Found checkpoints:\n", 'm')
166 | for name in ckpt_names:
167 | print(" {}".format(name))
168 | # Ask the user whether to start training from scratch or resume training from a specific checkpoint
169 | while True:
170 | mode = str(input('Start training from scratch (start) or resume training from a previous checkpoint (choose one of the above): '))
171 | if mode == 'start' or mode in ckpt_names:
172 | break
173 | else:
174 | print("Answer should be 'start' or one of the following checkpoints: {}".format(ckpt_names))
175 | continue
176 | if mode == 'start':
177 | return False
178 | elif mode in ckpt_names:
179 | # Try to load given intermediate checkpoint
180 | print_("Loading trained model...\n", 'm')
181 | self.model = tf.keras.models.load_model(os.path.join(self.checkpoints_dir, mode))
182 | print_("...Checkpoint {} loaded\n".format(mode), 'm')
183 | return True
184 | else:
185 | raise ValueError("User input is neither 'start' nor a valid checkpoint")
186 |
187 | def parse_args():
188 | parser = argparse.ArgumentParser(description='Model training arguments')
189 | parser.add_argument('--bch', type=int, default=16, dest='batch_size', help='training batch size')
190 | parser.add_argument('--ep', type=int, default=100, dest='epoch', help='training epoch number')
191 | parser.add_argument('--lr', type=float, default=1e-4, dest='learning_rate', help='initial learning rate')
192 | args = parser.parse_args()
193 | return args
194 |
195 | if __name__ == '__main__':
196 | args = parse_args()
197 | # set up model to train
198 | model = TrainModel(args)
199 | model.train()
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/Plugins/Server/server.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018 Foundry.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | ##############################################################################
15 |
16 | import argparse
17 | import os
18 | import importlib
19 | import socket # to get machine hostname
20 | import traceback
21 |
22 | try: # python3
23 | import socketserver
24 | except ImportError: # python2
25 | import SocketServer as socketserver
26 |
27 | import numpy as np
28 |
29 | from message_pb2 import *
30 |
31 | class MLTCPServer(socketserver.TCPServer):
32 | def __init__(self, server_address, handler_class, auto_bind=True):
33 | self.verbose = True
34 | # Each directory in models/ containing a model.py file is an available ML model
35 | self.available_models = [name for name in next(os.walk('models'))[1]
36 | if os.path.isfile(os.path.join('models', name, 'model.py'))]
37 | self.available_models.sort()
38 | self.models = {}
39 | for model in self.available_models:
40 | print('Importing models.{}.model'.format(model))
41 | self.models[model] = importlib.import_module('models.{}.model'.format(model)).Model()
42 | socketserver.TCPServer.__init__(self, server_address, handler_class, auto_bind)
43 | return
44 |
45 | class ImageProcessTCPHandler(socketserver.BaseRequestHandler):
46 | """This request handler is instantiated once per connection."""
47 |
48 | def handle(self):
49 | # Read the data headers
50 | data_hdr = self.request.recv(12)
51 | sz = int(data_hdr)
52 | self.vprint('Receiving message of size: {}'.format(sz))
53 |
54 | # Read data
55 | data = self.recvall(sz)
56 | self.vprint('{} bytes read'.format(len(data)))
57 |
58 | # Parse the message
59 | req_msg = RequestWrapper()
60 | req_msg.ParseFromString(data)
61 | self.vprint('Message parsed')
62 |
63 | # Process message
64 | resp_msg = self.process_message(req_msg)
65 | # Serialize response
66 | self.vprint('Serializing message')
67 | s = resp_msg.SerializeToString()
68 | msg_len = resp_msg.ByteSize()
69 | totallen = 12 + msg_len
70 | msg = bytes(str(totallen).zfill(12).encode('utf-8')) + s
71 | self.vprint('Sending response message of size: {}'.format(totallen))
72 | self.sendmsg(msg, totallen)
73 | self.vprint('-----------------------------------------------')
74 |
75 | def process_message(self, message):
76 | if message.HasField('r1'):
77 | self.vprint('Received info request')
78 | return self.process_info(message)
79 | elif message.HasField('r2'):
80 | self.vprint('Received inference request')
81 | return self.process_inference(message)
82 | else:
83 | # Pass error message to the client
84 | return self.errormsg("Server received unidentified request from client.")
85 |
86 | def process_info(self, message):
87 | resp_msg = RespondWrapper()
88 | resp_msg.info = True
89 | resp_info = RespondInfo()
90 | resp_info.num_models = len(self.server.available_models)
91 | # Add all model info into the message
92 | for model in self.server.available_models:
93 | m = resp_info.models.add()
94 | m.name = model
95 | m.label = self.server.models[model].get_name()
96 | # Add inputs
97 | for inp_name, inp_channels in self.server.models[model].get_inputs().items():
98 | inp = m.inputs.add()
99 | inp.name = inp_name
100 | inp.channels = inp_channels
101 | # Add outputs
102 | for out_name, out_channels in self.server.models[model].get_outputs().items():
103 | out = m.outputs.add()
104 | out.name = out_name
105 | out.channels = out_channels
106 | # Add options
107 | for opt_name, opt_value in self.server.models[model].get_options().items():
108 | if type(opt_value) == int:
109 | opt = m.int_options.add()
110 | elif type(opt_value) == float:
111 | opt = m.float_options.add()
112 | elif type(opt_value) == bool:
113 | opt = m.bool_options.add()
114 | elif type(opt_value) == str:
115 | opt = m.string_options.add()
116 | # TODO: Implement multiple choice
117 | else:
118 | # Send an error response message to the Nuke Client
119 | option_error = ("Model option of type {} is not implemented. "
120 | "Broadcasted options need to be one of bool, int, float, str."
121 | ).format(type(opt_value))
122 | return self.errormsg(option_error)
123 | opt.name = opt_name
124 | opt.values.extend([opt_value])
125 | # Add buttons
126 | for button_name, button_value in self.server.models[model].get_buttons().items():
127 | if type (button_value) == bool:
128 | button = m.button_options.add()
129 | else:
130 | return self.errormsg("Model button needs to be of type bool.")
131 | button.name = button_name
132 | button.values.extend([button_value])
133 |
134 | # Add RespondInfo message to RespondWrapper
135 | resp_msg.r1.CopyFrom(resp_info)
136 |
137 | return resp_msg
138 |
139 | def process_inference(self, message):
140 | req = message.r2
141 | m = req.model
142 | self.vprint('Requesting inference on model: {}'.format(m.name))
143 |
144 | # Parse model options
145 | opt = {}
146 | for options in [m.bool_options, m.int_options, m.float_options, m.string_options]:
147 | for option in options:
148 | opt[option.name] = option.values[0]
149 | # Set model options
150 | self.server.models[m.name].set_options(opt)
151 | # Parse model buttons
152 | btn = {}
153 | for button in m.button_options:
154 | btn[button.name] = button.values[0]
155 | self.server.models[m.name].set_buttons(btn)
156 |
157 | # Parse images
158 | img_list = []
159 | for byte_img in req.images:
160 | img = np.fromstring(byte_img.image, dtype='<f4')
242 |
243 | if __name__ == "__main__":
244 | parser = argparse.ArgumentParser(description='Machine Learning inference server.')
245 | parser.add_argument('port', type=int, help='Port number for the server to listen to.')
246 | args = parser.parse_args()
247 |
248 | # Get the current hostname of the server
249 | server_hostname = socket.gethostbyname(socket.gethostname())
250 | # Create the server
251 | server = MLTCPServer((server_hostname, args.port), ImageProcessTCPHandler, False)
252 |
253 | # Bind and activate the server
254 | server.allow_reuse_address = True
255 | server.server_bind()
256 | server.server_activate()
257 | print('Server -> Listening on port: {}'.format(args.port))
258 | server.serve_forever()
--------------------------------------------------------------------------------
/Plugins/Client/MLClientModelManager.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2019 Foundry.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 | //*************************************************************************
15 |
16 | #include "MLClientModelManager.h"
17 | #include "DDImage/Knob.h"
18 | #include "MLClient.h"
19 |
20 | MLClientModelKnob::MLClientModelKnob(DD::Image::Knob_Closure* kc, DD::Image::Op* op, const char* name)
21 | : DD::Image::Knob(kc, name)
22 | , _op(op)
23 | , _model("")
24 | { }
25 |
26 | const char* MLClientModelKnob::Class() const
27 | {
28 | return "MLClientModelKnob";
29 | }
30 |
31 | bool MLClientModelKnob::not_default () const
32 | {
33 | // Always flag as not default, so it's always serialised.
34 | return true;
35 | }
36 |
37 | std::string MLClientModelKnob::getModel() const
38 | {
39 | return _model;
40 | }
41 |
42 | const std::map<std::string, std::string>& MLClientModelKnob::getParameters() const
43 | {
44 | return _parameters;
45 | }
46 |
47 | void MLClientModelKnob::to_script (std::ostream &out, const DD::Image::OutputContext *, bool quote) const
48 | {
49 | std::string saveString;
50 | std::stringstream ss;
51 | if (_op != nullptr) {
52 | DD::Image::Knob* k = _op->knob("models");
53 | if(k != nullptr) {
54 | const int modelIndex = k->get_value();
55 | DD::Image::Enumeration_KnobI* eKnob = k->enumerationKnob();
56 | if(eKnob != nullptr) {
57 | ss << "model:" << eKnob->getItemValueString(modelIndex) << ";";
58 | MLClient* mlClient = dynamic_cast<MLClient*>(_op);
59 | if(mlClient != nullptr) {
60 | MLClientModelManager& mlManager = mlClient->getModelManager();
61 | toScriptT(mlManager, ss, &MLClientModelManager::getNumOfInts, &MLClientModelManager::getDynamicIntName);
62 | toScriptT(mlManager, ss, &MLClientModelManager::getNumOfFloats, &MLClientModelManager::getDynamicFloatName);
63 | toScriptT(mlManager, ss, &MLClientModelManager::getNumOfBools, &MLClientModelManager::getDynamicBoolName);
64 | toScriptStrings(mlManager, ss);
65 | }
66 | }
67 | }
68 | }
69 | saveString = ss.str();
70 | if(quote) {
71 | saveString.insert(saveString.begin(),'{');
72 | saveString+='}';
73 | }
74 | out << saveString;
75 | }
76 |
77 | bool MLClientModelKnob::from_script(const char * src)
78 | {
79 | std::string loadString(src);
80 |
81 | if ((_op != nullptr) && (loadString!="")) {
82 | bool success = false;
83 |
84 | // We parse the serialised string to extract the pairs of key:val;
85 | const std::string delimiter = ";";
86 | const std::string keyValDelimiter = ":";
87 | _parameters.clear();
88 | size_t pos = 0;
89 | std::string token;
90 | while ((pos = loadString.find(delimiter)) != std::string::npos) {
91 | token = loadString.substr(0, pos);
92 | std::cout << token << std::endl;
93 |
94 | // We further split the key:value pair
95 | std::string key = token.substr(0, token.find(keyValDelimiter));
96 | std::string val = token.substr(token.find(keyValDelimiter) + keyValDelimiter.length(), token.length() - key.length() - keyValDelimiter.length());
97 | if(key == "model") {
98 | _model = val;
99 | } else {
100 | _parameters.insert(std::make_pair(key, val));
101 | }
102 |
103 | loadString.erase(0, pos + delimiter.length());
104 | }
105 |
106 | return success;
107 | }
108 | return true;
109 | }
110 |
111 | void MLClientModelKnob::toScriptT(MLClientModelManager& mlManager, std::ostream &out,
112 | int (MLClientModelManager::*getNum)() const,
113 | std::string (MLClientModelManager::*getDynamicName)(int)) const
114 | {
115 | const int num = (mlManager.*getNum)();
116 | for(int i = 0; i < num; i++) {
117 | DD::Image::Knob* k = _op->knob((mlManager.*getDynamicName)(i).c_str());
118 | if(k != nullptr) {
119 | std::stringstream ss;
120 | k->to_script(ss, nullptr, false);
121 | out << (mlManager.*getDynamicName)(i) << ":" << ss.str() << ";";
122 | }
123 | }
124 | }
125 |
126 | void MLClientModelKnob::toScriptStrings(MLClientModelManager& mlManager, std::ostream &out) const
127 | {
128 | const int numStrings = mlManager.getNumOfStrings();
129 | for(int i = 0; i < numStrings; i++) {
130 | DD::Image::Knob* k = _op->knob(mlManager.getDynamicStringName(i).c_str());
131 | if(k != nullptr) {
132 | out << mlManager.getDynamicStringName(i) << ":" << k->get_text() << ";";
133 | }
134 | }
135 | }
136 |
137 | MLClientModelManager::MLClientModelManager(DD::Image::Op* parent)
138 | : _parent(parent)
139 | { }
140 |
141 | MLClientModelManager::~MLClientModelManager()
142 | { }
143 |
144 | //! Parse options from the server model /m to the MLClientModelManager
145 | void MLClientModelManager::parseOptions(const mlserver::Model& m)
146 | {
147 | clear();
148 |
149 | for (int i = 0, endI = m.bool_options_size(); i < endI; i++) {
150 | mlserver::BoolAttrib option;
151 | option = m.bool_options(i);
152 | if (option.values(0)) {
153 | _dynamicBoolValues.push_back(1);
154 | }
155 | else {
156 | _dynamicBoolValues.push_back(0);
157 | }
158 | _dynamicBoolNames.push_back(option.name());
159 | }
160 | for (int i = 0, endI = m.int_options_size(); i < endI; i++) {
161 | mlserver::IntAttrib option;
162 | option = m.int_options(i);
163 | _dynamicIntValues.push_back(option.values(0));
164 | _dynamicIntNames.push_back(option.name());
165 | }
166 | for (int i = 0, endI = m.float_options_size(); i < endI; i++) {
167 | mlserver::FloatAttrib option;
168 | option = m.float_options(i);
169 | _dynamicFloatValues.push_back(option.values(0));
170 | _dynamicFloatNames.push_back(option.name());
171 | }
172 | for (int i = 0, endI = m.string_options_size(); i < endI; i++) {
173 | mlserver::StringAttrib option;
174 | option = m.string_options(i);
175 | _dynamicStringValues.push_back(option.values(0));
176 | _dynamicStringNames.push_back(option.name());
177 | }
178 | for (int i = 0, endI = m.button_options_size(); i < endI; i++) {
179 | mlserver::BoolAttrib option;
180 | option = m.button_options(i);
181 | if (option.values(0)) {
182 | _dynamicButtonValues.push_back(1);
183 | }
184 | else {
185 | _dynamicButtonValues.push_back(0);
186 | }
187 | _dynamicButtonNames.push_back(option.name());
188 | }
189 | }
190 |
191 | //! Use current knob values to update options on the server model /m
192 | //! in order to later request inference on this model
193 | void MLClientModelManager::updateOptions(mlserver::Model& m)
194 | {
195 | m.clear_bool_options();
196 | for (int i = 0; i < _dynamicBoolValues.size(); i++) {
197 | mlserver::BoolAttrib* option = m.add_bool_options();
198 | option->set_name(_dynamicBoolNames[i]);
199 | DD::Image::Knob* k = _parent->knob(_dynamicBoolNames[i].c_str());
200 | bool val = false;
201 | if (k != nullptr) {
202 | val = k->get_value();
203 | }
204 | option->add_values(val);
205 | }
206 |
207 | m.clear_int_options();
208 | for (int i = 0; i < _dynamicIntValues.size(); i++) {
209 | mlserver::IntAttrib* option = m.add_int_options();
210 | option->set_name(_dynamicIntNames[i]);
211 | DD::Image::Knob* k = _parent->knob(_dynamicIntNames[i].c_str());
212 | int val = 0;
213 | if (k != nullptr) {
214 | val = k->get_value();
215 | }
216 | option->add_values(val);
217 | }
218 |
219 | m.clear_float_options();
220 | for (int i = 0; i < _dynamicFloatValues.size(); i++) {
221 | mlserver::FloatAttrib* option = m.add_float_options();
222 | option->set_name(_dynamicFloatNames[i]);
223 | DD::Image::Knob* k = _parent->knob(_dynamicFloatNames[i].c_str());
224 | float val = 0.0f;
225 | if (k != nullptr) {
226 | val = k->get_value();
227 | }
228 | option->add_values(val);
229 | }
230 |
231 | m.clear_string_options();
232 | for (int i = 0; i < _dynamicStringValues.size(); i++) {
233 | mlserver::StringAttrib* option = m.add_string_options();
234 | option->set_name(_dynamicStringNames[i]);
235 | DD::Image::Knob* k = _parent->knob(_dynamicStringNames[i].c_str());
236 | const char* val = "";
237 | if(k != nullptr) {
238 | val = k->get_text();
239 | if (val==nullptr) {
240 | val = "";
241 | }
242 | }
243 | option->add_values(val);
244 | }
245 |
246 | m.clear_button_options();
247 | for (int i = 0; i < _dynamicButtonValues.size(); i++) {
248 | mlserver::BoolAttrib* option = m.add_button_options();
249 | option->set_name(_dynamicButtonNames[i]);
250 | // Get member value instead of knob value to catch button push
251 | option->add_values(_dynamicButtonValues[i]);
252 | }
253 | }
254 |
255 | int MLClientModelManager::getNumOfFloats() const
256 | {
257 | return _dynamicFloatValues.size();
258 | }
259 |
260 | int MLClientModelManager::getNumOfInts() const
261 | {
262 | return _dynamicIntValues.size();
263 | }
264 |
265 | int MLClientModelManager::getNumOfBools() const
266 | {
267 | return _dynamicBoolValues.size();
268 | }
269 |
270 | int MLClientModelManager::getNumOfStrings() const
271 | {
272 | return _dynamicStringValues.size();
273 | }
274 |
275 | int MLClientModelManager::getNumOfButtons() const
276 | {
277 | return _dynamicButtonValues.size();
278 | }
279 |
280 | std::string MLClientModelManager::getDynamicBoolName(int idx)
281 | {
282 | return _dynamicBoolNames[idx];
283 | }
284 |
285 | std::string MLClientModelManager::getDynamicFloatName(int idx)
286 | {
287 | return _dynamicFloatNames[idx];
288 | }
289 |
290 | std::string MLClientModelManager::getDynamicIntName(int idx)
291 | {
292 | return _dynamicIntNames[idx];
293 | }
294 |
295 | std::string MLClientModelManager::getDynamicStringName(int idx)
296 | {
297 | return _dynamicStringNames[idx];
298 | }
299 |
300 | std::string MLClientModelManager::getDynamicButtonName(int idx)
301 | {
302 | return _dynamicButtonNames[idx];
303 | }
304 |
305 | float* MLClientModelManager::getDynamicFloatValue(int idx)
306 | {
307 | return &_dynamicFloatValues[idx];
308 | }
309 |
310 | int* MLClientModelManager::getDynamicIntValue(int idx)
311 | {
312 | return &_dynamicIntValues[idx];
313 | }
314 |
315 | bool* MLClientModelManager::getDynamicBoolValue(int idx)
316 | {
317 | return (bool*)&_dynamicBoolValues[idx];
318 | }
319 |
320 | std::string* MLClientModelManager::getDynamicStringValue(int idx)
321 | {
322 | return &_dynamicStringValues[idx];
323 | }
324 |
325 | bool* MLClientModelManager::getDynamicButtonValue(int idx)
326 | {
327 | return (bool*)&_dynamicButtonValues[idx];
328 | }
329 |
330 | void MLClientModelManager::setDynamicButtonValue(int idx, int value)
331 | {
332 | _dynamicButtonValues[idx] = value;
333 | }
334 |
335 | void MLClientModelManager::clear()
336 | {
337 | _dynamicBoolValues.clear();
338 | _dynamicIntValues.clear();
339 | _dynamicFloatValues.clear();
340 | _dynamicStringValues.clear();
341 | _dynamicButtonValues.clear();
342 |
343 | _dynamicBoolNames.clear();
344 | _dynamicIntNames.clear();
345 | _dynamicFloatNames.clear();
346 | _dynamicStringNames.clear();
347 | _dynamicButtonNames.clear();
348 | }
349 |
--------------------------------------------------------------------------------
/INSTALL.md:
--------------------------------------------------------------------------------
1 | # Installing Nuke Machine Learning Plugin
2 |
3 | The Nuke Machine Learning (ML) installation can be divided into compiling the MLClient Nuke node and installing the MLServer using Docker.
4 |
5 | The MLClient plugin can be compiled on both Linux/MacOS and Windows systems. It communicates with the MLServer, which needs to run on a Linux machine with an NVIDIA GPU.
6 |
7 | **Requirements:**
8 | - Linux with Nuke installed
9 | - NVIDIA GPU (Important: GPU memory must be at least 6GB)
10 | - CMake (minimum 3.10)
11 | - Protobuf (tested with 2.5.0 and 3.5.1)
12 | - Docker
13 |
14 | ## Installing the Client on Linux/MacOS
15 |
16 | ### Install Protobuf
17 |
18 | Protocol Buffers (Protobuf) is an efficient way of serializing structured data - similar to XML, but faster and simpler. We use it to define, write, and read the data exchanged in our client<->server communication.
19 |
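As a rough illustration of that workflow in Python (a sketch assuming the `message_pb2` module generated from `message.proto`, and the `RespondWrapper` message used in `Plugins/Server/server.py`):
```
# Hedged sketch: serialize/parse round trip with a protoc-generated module
from message_pb2 import RespondWrapper

msg = RespondWrapper()
msg.info = True                    # flag the server sets on info responses
payload = msg.SerializeToString()  # bytes, ready to be framed and sent over the socket

decoded = RespondWrapper()
decoded.ParseFromString(payload)   # the receiver reads the same structured data back
```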
20 | Following the [installation instructions](https://github.com/protocolbuffers/protobuf/blob/master/src/README.md) from the Protobuf GitHub repository, we recommend compiling Protobuf from source:
21 |
22 | First, get the Protobuf source for C++, for instance version 3.5.1:
23 | ```
24 | wget https://github.com/protocolbuffers/protobuf/releases/download/v3.5.1/protobuf-cpp-3.5.1.tar.gz
25 | # Extract file in current directory
26 | tar -xzf protobuf-cpp-3.5.1.tar.gz
27 | ```
28 | Then build and install the C++ Protocol Buffer runtime and the Protocol Buffer compiler (protoc):
29 | ```
30 | cd protobuf-3.5.1
31 | ./configure
32 | make
33 | make check
34 | sudo make install
35 | sudo ldconfig # refresh shared library cache.
36 | ```
37 |
38 | Note: Instead of compiling it from source, Protobuf may alternatively be installed with a package manager, for example:
39 | ```
40 | sudo yum install protobuf-devel
41 | ```
42 |
43 | ### Compile MLClient Nuke Node
44 |
45 | If not already cloned, fetch the `nuke-ML-server` repository:
46 | ```
47 | git clone https://github.com/TheFoundryVisionmongers/nuke-ML-server
48 | ```
49 | Execute the commands below to compile the client MLClient.so plugin, setting the NUKE_INSTALL_PATH to point to the folder of the desired Nuke version:
50 | ```
51 | cd nuke-ML-server/
52 | mkdir build && cd build
53 | cmake -DNUKE_INSTALL_PATH=/path/to/Nuke11.3v1/ ..
54 | make
55 | ```
56 | The MLClient.so plugin will now be in the `build/Plugins/Client` folder. Before it can be used, Nuke needs to know where it lives. One way to do this is to update the NUKE_PATH environment variable to point to the folder containing MLClient.so (this can be skipped if the plugin was moved to the root of your ~/.nuke folder, or if the path was added in Nuke through Python, as sketched below):
57 | ```
58 | export NUKE_PATH=/path/to/lib/:$NUKE_PATH
59 | ```
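Alternatively (a sketch, not part of the original instructions), the plugin path can be added from Python by creating or editing `~/.nuke/init.py`:
```
# ~/.nuke/init.py -- executed by Nuke at startup
import nuke
nuke.pluginAddPath('/path/to/lib')  # folder containing MLClient.so
```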
60 | At that point, after opening Nuke and updating all plugins, the `MLClient` node should be available. To update all the plugins in Nuke, you can either use the Other > All Plugins > Update option (see [documentation](https://learn.foundry.com/nuke/developers/63/pythondevguide/installing_plugins.html)), or simply press `tab` in the Node Graph then write `Update [All plugins]`. If the `MLClient` node is still missing, verify that the current NUKE_PATH is correctly pointing to the folder containing MLClient.so.
61 |
62 | ## Installing the Client on Windows
63 |
64 | This was tested on Windows 10. You need to have [cmake](https://cmake.org/) and [git](https://git-scm.com/) installed on your computer.
65 |
66 | Start by installing the Visual Studio Compiler "Build Tools for Visual Studio 2017" found at [this link](https://www.visualstudio.com/thank-you-downloading-visual-studio/?sku=BuildTools&rel=15).
67 |
68 | ### Install Protobuf
69 |
70 | We recommend building Protobuf locally as a static library. For reference this section partly follows the [installation instructions](https://github.com/protocolbuffers/protobuf/blob/master/cmake/README.md) from the Protobuf GitHub repository.
71 |
72 | First, open the “**x64** Native Tools Command Prompt for VS 2017”. Please note it has to be **x64** and not x86.
73 |
74 | If `cmake` or `git` commands are not available from Command Prompt, add them to the system PATH variable:
75 | ```
76 | set PATH=%PATH%;C:\Program Files (x86)\CMake\bin
77 | set PATH=%PATH%;C:\Program Files\Git\cmd
78 | ```
79 | Clone your chosen Protobuf release branch, for instance version 3.5.1:
80 | ```
81 | git clone -b v3.5.1 https://github.com/protocolbuffers/protobuf.git
82 | cd protobuf
83 | git submodule update --init --recursive
84 | cd cmake
85 | mkdir build & cd build
86 | mkdir release & cd release
87 | ```
88 | Compile protobuf with the dynamic VC runtime (Visual C++ Runtime Library):
89 | ```
90 | cmake -G "NMake Makefiles" -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=<install_path> -Dprotobuf_MSVC_STATIC_RUNTIME=OFF -Dprotobuf_BUILD_TESTS=OFF ../..
91 | ```
92 | Install protobuf in the specified `<install_path>` folder by running the following:
93 | ```
94 | nmake install
95 | ```
96 | Note: This last command will create the following folders under the `<install_path>` location:
97 | - bin - contains the protoc.exe protobuf compiler;
98 | - include - contains C++ headers and protobuf *.proto files;
99 | - lib - contains linking libraries and CMake configuration files for the protobuf package.
100 |
101 | ### Compile MLClient Nuke Node
102 |
103 | If not already done, clone the `nuke-ML-server` repository:
104 | ```
105 | git clone https://github.com/TheFoundryVisionmongers/nuke-ML-server
106 | cd nuke-ML-server
107 | mkdir build & cd build
108 | mkdir x64-Release & cd x64-Release
109 | ```
110 | Compile the MLClient and link your version of Nuke and Protobuf install path:
111 | ```
112 | cmake -G "NMake Makefiles" -DCMAKE_BUILD_TYPE=Release -DNUKE_INSTALL_PATH="/path/to/Nuke12.0v3" -DProtobuf_LIBRARIES="<install_path>/lib" -DProtobuf_INCLUDE_DIR="<install_path>/include" -DProtobuf_PROTOC_EXECUTABLE="<install_path>/bin/protoc.exe" ../..
113 | nmake
114 | ```
115 | The MLClient.dll plugin should now be in the `build/x64-Release/Plugins/Client` folder. Before it can be used, Nuke needs to know where it lives. You can either copy it to your ~/.nuke folder or update the NUKE_PATH environment variable:
116 | ```
117 | set NUKE_PATH=%NUKE_PATH%;path/to/lib
118 | ```
119 | At that point, after opening Nuke and updating all plugins, the `MLClient` node should be available. To update all the plugins in Nuke, you can either use the Other > All Plugins > Update option (see [documentation](https://learn.foundry.com/nuke/developers/63/pythondevguide/installing_plugins.html)), or simply press `tab` in the Node Graph then write `Update [All plugins]`. If the `MLClient` node is still missing, verify that the current NUKE_PATH is correctly pointing to the folder containing MLClient.dll.
120 |
121 | As your client is on a Windows machine, you now need to run the server on a Linux machine with an NVIDIA GPU (see [next section](https://github.com/TheFoundryVisionmongers/nuke-ML-server/blob/master/INSTALL.md#installing-the-server)) and connect your Windows machine to it following the [Connect to an External Server](https://github.com/TheFoundryVisionmongers/nuke-ML-server/blob/master/INSTALL.md#connect-to-an-external-server) section.
122 |
123 | ## Installing the Server
124 |
125 | ### Install Docker
126 |
127 | Docker provides a way to package and run an application in a securely isolated environment called a container. This container includes all the application dependencies and libraries. It ensures that the application works seamlessly inside the container in any system environment. We use docker to create a container that easily runs the MLServer.
128 |
129 | Install Docker:
130 | ```
131 | # Install the official docker-ce package
132 | sudo curl -sSL https://get.docker.com/ | sh
133 | # Start Docker
134 | sudo systemctl start docker
135 | ```
136 | Nvidia Docker is a necessary plugin that enables Nvidia GPU-accelerated applications to run in Docker.
137 |
138 | Install nvidia-container-toolkit for your Linux platform by following the [installation instructions](https://github.com/NVIDIA/nvidia-docker) of the nvidia-docker repository. On CentOS/RHEL, you should follow section "CentOS 7 (**docker-ce**), RHEL 7.4/7.5 (**docker-ce**), Amazon Linux 1/2" of the repository.
139 |
140 | Build the docker image from the [Dockerfile](/Plugins/Server/Dockerfile):
141 | ```
142 | # Start by loading Ubuntu18.04 with cuda 10.0 and cudnn7 as the base image
143 | sudo docker pull nvidia/cuda:10.0-cudnn7-devel-ubuntu18.04
144 | # Build the docker image on top of the base image
145 | cd Plugins/Server/
146 | # Choose your own label for <docker_image_name>; it must be lowercase, e.g. mlserver.
147 | sudo docker build -t <docker_image_name> -f Dockerfile .
148 | ```
149 |
150 | ### Run Docker Container
151 |
152 | Create and run a docker container on top of the created docker image, referencing the `<docker_image_name>` from the previous step:
153 |
154 | ```
155 | sudo docker run --gpus all -v /absolute/path/to/nuke-ML-server/Models/:/workspace/ml-server/models -it <docker_image_name>
156 | ```
157 |
158 | Notes:
159 | - the `-v` (volume) option links your host machine Models/ folder with the models/ folder inside your container. You only need to modify `/absolute/path/to/nuke-ML-server/Models/`, leave the `/workspace/ml-server/models` unchanged as it already corresponds to the folder structure inside your Docker image. This option allows you to add models in Models/ that will be directly available and updated inside your container.
160 | - If your docker version doesn't recognise the `--gpus` flag, you can run the same docker container by replacing `sudo docker run --gpus all` with `sudo nvidia-docker run` or `sudo docker run --runtime=nvidia`.
161 |
162 | ## Getting Started
163 |
164 | ### Download Configuration and Weights Files
165 |
166 | To be able to run inference on the Mask-RCNN model, you need to download its configuration and weight files.
167 |
168 | Depending on your GPU memory, you can use either a ResNet101 (GPU memory > 8GB) or a ResNet50 (GPU memory > 6GB) backbone. The results with ResNet101 are slightly better.
169 | - Mask-RCNN requires ~7GB GPU RAM with ResNet101 and ~4.6GB with ResNet50.
170 |
171 | Download your selected configuration and weight files:
172 | - Mask-RCNN ResNet50:
173 | - Configuration: [e2e_mask_rcnn_R-50-FPN_2x.yaml](https://raw.githubusercontent.com/facebookresearch/Detectron/master/configs/12_2017_baselines/e2e_mask_rcnn_R-50-FPN_2x.yaml)
174 | - Corresponding weights: [model_final.pkl](https://dl.fbaipublicfiles.com/detectron/35859007/12_2017_baselines/e2e_mask_rcnn_R-50-FPN_2x.yaml.01_49_07.By8nQcCH/output/train/coco_2014_train%3Acoco_2014_valminusminival/generalized_rcnn/model_final.pkl) (from the Detectron [Model Zoo](https://github.com/facebookresearch/Detectron/blob/master/MODEL_ZOO.md))
175 | - OR Mask-RCNN ResNet101:
176 | - Configuration: [e2e_mask_rcnn_X-101-64x4d-FPN_2x.yaml](https://raw.githubusercontent.com/facebookresearch/Detectron/master/configs/12_2017_baselines/e2e_mask_rcnn_X-101-64x4d-FPN_2x.yaml)
177 | - Corresponding weights: [model_final.pkl](https://dl.fbaipublicfiles.com/detectron/35859745/12_2017_baselines/e2e_mask_rcnn_X-101-64x4d-FPN_2x.yaml.02_00_30.ESWbND2w/output/train/coco_2014_train%3Acoco_2014_valminusminival/generalized_rcnn/model_final.pkl) (from the Detectron [Model Zoo](https://github.com/facebookresearch/Detectron/blob/master/MODEL_ZOO.md))
178 |
179 | Then move them to the `Models/mrcnn/` folder.
180 |
181 | ResNet50 is the default backbone. If you use ResNet101, you need to modify the config and weight file names in Models/mrcnn/model.py.
182 |
183 | ### Connect Client and Server
184 |
185 | This section explains how to connect the server and client when your docker container and Nuke instance are running on the same Linux machine:
186 |
187 | 0. (If you have stopped your container, follow the [Run Docker Container](https://github.com/TheFoundryVisionmongers/nuke-ML-server/blob/master/INSTALL.md#run-docker-container) section again)
188 | 1. In the running docker container, query the IP address:
189 | ```
190 | hostname -I
191 | ```
192 | 2. In Nuke, set the MLClient node `host` to the container IP address.
193 | 3. In the container, launch the server and start listening on port 55555:
194 | ```
195 | python server.py 55555
196 | ```
197 | 4. In Nuke, click on the MLClient connect button; the three models should now be available.
198 |
199 | ### Connect to an External Server
200 |
201 | This section explains how to connect server and client when your docker container (MLServer) and Nuke (MLClient) are running on two different machines, e.g. if you are using the MLClient on Windows. In that case, you have a Linux machine running the docker container and a Windows machine running Nuke.
202 |
203 | 1. On your **Linux machine** (not the docker container, not your Windows machine), query the IP address:
204 | ```
205 | hostname -I
206 | ```
207 | 2. In Nuke, set the MLClient node `host` to the Linux machine IP address obtained.
208 | 3. On the Linux machine, run the docker container exporting a port of your choice (here port 7000 of the host is mapped to port 55555 of the container):
209 | ```
210 | sudo docker run --gpus all -v /absolute/path/to/nuke-ML-server/Models/:/workspace/ml-server/models -p 7000:55555 -it <docker_image_name>
211 | ```
212 | 4. In the container, launch the server and start listening on port 55555:
213 | ```
214 | python server.py 55555
215 | ```
216 | 5. In Nuke, set the MLClient node `port` to 7000 and click on the MLClient connect button.
217 |
218 | ### Add your own Model
219 |
220 | To implement your own model, create a new folder in the /Models directory named after your model. At a minimum, this folder needs to include an empty `__init__.py` file and a `model.py` file that contains a Model class inheriting from BaseModel.
221 |
222 | You can copy the simple [Models/blur/](Models/blur) model as a starting point, and implement your own model using the blur and mrcnn examples as references.
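For orientation, a minimal hedged sketch of such a `model.py` is shown below. The attribute and method names mirror what `Plugins/Server/server.py` queries on a model (its name, inputs and outputs); the relative import and the `inference` entry point are assumptions, so check `Models/baseModel.py` and `Models/blur/model.py` for the exact interface.
```
# Models/myModel/model.py -- minimal sketch, not a drop-in implementation
from ..baseModel import BaseModel  # assumed relative import, as in the blur example

class Model(BaseModel):
    def __init__(self):
        super(Model, self).__init__()
        self.name = 'My Model'       # label the server returns to the MLClient node
        # Dictionaries of image name -> channel count, broadcast to the client
        self.inputs = {'input': 3}
        self.outputs = {'output': 3}

    def inference(self, image_list):
        # Assumed entry point: return the processed image(s) as float32 numpy arrays
        return [image.copy() for image in image_list]
```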
223 |
--------------------------------------------------------------------------------
/Models/mrcnn/vis.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | ##############################################################################
15 |
16 | """Detection output visualization module."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | from __future__ import unicode_literals
22 |
23 | import cv2
24 | import numpy as np
25 | import os
26 |
27 | import pycocotools.mask as mask_util
28 |
29 | from detectron.utils.colormap import colormap
30 | import detectron.utils.env as envu
31 | import detectron.utils.keypoints as keypoint_utils
32 | # Matplotlib requires certain adjustments in some environments
33 | # Must happen before importing matplotlib
34 | envu.set_up_matplotlib()
35 | import matplotlib.pyplot as plt
36 | from matplotlib.patches import Polygon
37 |
38 | plt.rcParams['pdf.fonttype'] = 42 # For editing in Adobe Illustrator
39 |
40 |
41 | _GRAY = (218, 227, 218)
42 | _GREEN = (18, 127, 15)
43 | _WHITE = (255, 255, 255)
44 |
45 |
46 | def kp_connections(keypoints):
47 | kp_lines = [
48 | [keypoints.index('left_eye'), keypoints.index('right_eye')],
49 | [keypoints.index('left_eye'), keypoints.index('nose')],
50 | [keypoints.index('right_eye'), keypoints.index('nose')],
51 | [keypoints.index('right_eye'), keypoints.index('right_ear')],
52 | [keypoints.index('left_eye'), keypoints.index('left_ear')],
53 | [keypoints.index('right_shoulder'), keypoints.index('right_elbow')],
54 | [keypoints.index('right_elbow'), keypoints.index('right_wrist')],
55 | [keypoints.index('left_shoulder'), keypoints.index('left_elbow')],
56 | [keypoints.index('left_elbow'), keypoints.index('left_wrist')],
57 | [keypoints.index('right_hip'), keypoints.index('right_knee')],
58 | [keypoints.index('right_knee'), keypoints.index('right_ankle')],
59 | [keypoints.index('left_hip'), keypoints.index('left_knee')],
60 | [keypoints.index('left_knee'), keypoints.index('left_ankle')],
61 | [keypoints.index('right_shoulder'), keypoints.index('left_shoulder')],
62 | [keypoints.index('right_hip'), keypoints.index('left_hip')],
63 | ]
64 | return kp_lines
65 |
66 |
67 | def convert_from_cls_format(cls_boxes, cls_segms, cls_keyps):
68 | """Convert from the class boxes/segms/keyps format generated by the testing
69 | code.
70 | """
71 | box_list = [b for b in cls_boxes if len(b) > 0]
72 | if len(box_list) > 0:
73 | boxes = np.concatenate(box_list)
74 | else:
75 | boxes = None
76 | if cls_segms is not None:
77 | segms = [s for slist in cls_segms for s in slist]
78 | else:
79 | segms = None
80 | if cls_keyps is not None:
81 | keyps = [k for klist in cls_keyps for k in klist]
82 | else:
83 | keyps = None
84 | classes = []
85 | for j in range(len(cls_boxes)):
86 | classes += [j] * len(cls_boxes[j])
87 | return boxes, segms, keyps, classes
88 |
89 |
90 | def get_class_string(class_index, score, dataset):
91 | class_text = dataset.classes[class_index] if dataset is not None else \
92 | 'id{:d}'.format(class_index)
93 | return class_text + ' {:0.2f}'.format(score).lstrip('0')
94 |
95 |
96 | def vis_mask(img, mask, col, alpha=0.4, show_border=True, border_thick=1):
97 | """Visualizes a single binary mask."""
98 |
99 | img = img.astype(np.float32)
100 | idx = np.nonzero(mask)
101 |
102 | img[idx[0], idx[1], :] *= 1.0 - alpha
103 | img[idx[0], idx[1], :] += alpha * col
104 |
105 | if show_border:
106 | # cv2.findContours gives (image, contours, hierarchy) back in opencv 3.x
107 | # but gives back (contours, hierarchy) in opencv 2.x and 4.x
108 | contours, _ = cv2.findContours(
109 | mask.copy(), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)[-2:]
110 | cv2.drawContours(img, contours, -1, _WHITE, border_thick, cv2.LINE_AA)
111 |
112 | return img.astype(np.uint8)
113 |
114 |
115 | def vis_class(img, pos, class_str, font_scale=0.35):
116 | """Visualizes the class."""
117 | img = img.astype(np.uint8)
118 | x0, y0 = int(pos[0]), int(pos[1])
119 | # Compute text size.
120 | txt = class_str
121 | font = cv2.FONT_HERSHEY_SIMPLEX
122 | ((txt_w, txt_h), _) = cv2.getTextSize(txt, font, font_scale, 1)
123 | # Place text background.
124 | back_tl = x0, y0 - int(1.3 * txt_h)
125 | back_br = x0 + txt_w, y0
126 | cv2.rectangle(img, back_tl, back_br, _GREEN, -1)
127 | # Show text.
128 | txt_tl = x0, y0 - int(0.3 * txt_h)
129 | cv2.putText(img, txt, txt_tl, font, font_scale, _GRAY, lineType=cv2.LINE_AA)
130 | return img
131 |
132 |
133 | def vis_bbox(img, bbox, thick=1):
134 | """Visualizes a bounding box."""
135 | img = img.astype(np.uint8)
136 | (x0, y0, w, h) = bbox
137 | x1, y1 = int(x0 + w), int(y0 + h)
138 | x0, y0 = int(x0), int(y0)
139 | cv2.rectangle(img, (x0, y0), (x1, y1), _GREEN, thickness=thick)
140 | return img
141 |
142 |
143 | def vis_keypoints(img, kps, kp_thresh=2, alpha=0.7):
144 | """Visualizes keypoints (adapted from vis_one_image).
145 | kps has shape (4, #keypoints) where 4 rows are (x, y, logit, prob).
146 | """
147 | dataset_keypoints, _ = keypoint_utils.get_keypoints()
148 | kp_lines = kp_connections(dataset_keypoints)
149 |
150 | # Convert from plt 0-1 RGBA colors to 0-255 BGR colors for opencv.
151 | cmap = plt.get_cmap('rainbow')
152 | colors = [cmap(i) for i in np.linspace(0, 1, len(kp_lines) + 2)]
153 | colors = [(c[2] * 255, c[1] * 255, c[0] * 255) for c in colors]
154 |
155 | # Perform the drawing on a copy of the image, to allow for blending.
156 | kp_mask = np.copy(img)
157 |
158 | # Draw mid shoulder / mid hip first for better visualization.
159 | mid_shoulder = (
160 | kps[:2, dataset_keypoints.index('right_shoulder')] +
161 | kps[:2, dataset_keypoints.index('left_shoulder')]) / 2.0
162 | sc_mid_shoulder = np.minimum(
163 | kps[2, dataset_keypoints.index('right_shoulder')],
164 | kps[2, dataset_keypoints.index('left_shoulder')])
165 | mid_hip = (
166 | kps[:2, dataset_keypoints.index('right_hip')] +
167 | kps[:2, dataset_keypoints.index('left_hip')]) / 2.0
168 | sc_mid_hip = np.minimum(
169 | kps[2, dataset_keypoints.index('right_hip')],
170 | kps[2, dataset_keypoints.index('left_hip')])
171 | nose_idx = dataset_keypoints.index('nose')
172 | if sc_mid_shoulder > kp_thresh and kps[2, nose_idx] > kp_thresh:
173 | cv2.line(
174 | kp_mask, tuple(mid_shoulder), tuple(kps[:2, nose_idx]),
175 | color=colors[len(kp_lines)], thickness=2, lineType=cv2.LINE_AA)
176 | if sc_mid_shoulder > kp_thresh and sc_mid_hip > kp_thresh:
177 | cv2.line(
178 | kp_mask, tuple(mid_shoulder), tuple(mid_hip),
179 | color=colors[len(kp_lines) + 1], thickness=2, lineType=cv2.LINE_AA)
180 |
181 | # Draw the keypoints.
182 | for l in range(len(kp_lines)):
183 | i1 = kp_lines[l][0]
184 | i2 = kp_lines[l][1]
185 | p1 = kps[0, i1], kps[1, i1]
186 | p2 = kps[0, i2], kps[1, i2]
187 | if kps[2, i1] > kp_thresh and kps[2, i2] > kp_thresh:
188 | cv2.line(
189 | kp_mask, p1, p2,
190 | color=colors[l], thickness=2, lineType=cv2.LINE_AA)
191 | if kps[2, i1] > kp_thresh:
192 | cv2.circle(
193 | kp_mask, p1,
194 | radius=3, color=colors[l], thickness=-1, lineType=cv2.LINE_AA)
195 | if kps[2, i2] > kp_thresh:
196 | cv2.circle(
197 | kp_mask, p2,
198 | radius=3, color=colors[l], thickness=-1, lineType=cv2.LINE_AA)
199 |
200 | # Blend the keypoints.
201 | return cv2.addWeighted(img, 1.0 - alpha, kp_mask, alpha, 0)
202 |
203 |
204 | def vis_one_image_opencv(
205 | im, boxes, segms=None, keypoints=None, thresh=0.9, kp_thresh=2,
206 | show_box=False, dataset=None, show_class=False,
207 | alpha=0.4, show_border=True, border_thick=1, bbox_thick=1, font_scale=0.35):
208 | """Constructs a numpy array with the detections visualized."""
209 |
210 | if isinstance(boxes, list):
211 | boxes, segms, keypoints, classes = convert_from_cls_format(
212 | boxes, segms, keypoints)
213 |
214 | if boxes is None or boxes.shape[0] == 0 or max(boxes[:, 4]) < thresh:
215 | return im
216 |
217 | if segms is not None and len(segms) > 0:
218 | masks = mask_util.decode(segms)
219 | color_list = colormap()
220 | mask_color_id = 0
221 |
222 | # Display in largest to smallest order to reduce occlusion
223 | areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
224 | sorted_inds = np.argsort(-areas)
225 |
226 | for i in sorted_inds:
227 | bbox = boxes[i, :4]
228 | score = boxes[i, -1]
229 | if score < thresh:
230 | continue
231 |
232 | # show box (off by default)
233 | if show_box:
234 | im = vis_bbox(
235 | im, (bbox[0], bbox[1], bbox[2] - bbox[0], bbox[3] - bbox[1]), thick=bbox_thick)
236 |
237 | # show class (off by default)
238 | if show_class:
239 | class_str = get_class_string(classes[i], score, dataset)
240 | im = vis_class(im, (bbox[0], bbox[1] - 2), class_str, font_scale=font_scale)
241 |
242 | # show mask
243 | if segms is not None and len(segms) > i:
244 | color_mask = color_list[mask_color_id % len(color_list), 0:3]
245 | mask_color_id += 1
246 | im = vis_mask(im, masks[..., i], color_mask, alpha=alpha,
247 | show_border=show_border, border_thick=border_thick)
248 |
249 | # show keypoints
250 | if keypoints is not None and len(keypoints) > i:
251 | im = vis_keypoints(im, keypoints[i], kp_thresh)
252 |
253 | return im
254 |
255 |
256 | def vis_one_image_binary(im, boxes, segms, keypoints=None, thresh=0.9):
257 | im = np.zeros_like(im)
258 | if isinstance(boxes, list):
259 | boxes, segms, keypoints, classes = convert_from_cls_format(
260 | boxes, segms, keypoints)
261 |
262 | if boxes is None or boxes.shape[0] == 0 or max(boxes[:, 4]) < thresh:
263 | return im
264 |
265 | if segms is not None and len(segms) > 0:
266 | masks = mask_util.decode(segms)
267 |
268 | # Display in largest to smallest order to reduce occlusion
269 | areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
270 | sorted_inds = np.argsort(-areas)
271 |
272 | for i in sorted_inds:
273 | bbox = boxes[i, :4]
274 | score = boxes[i, -1]
275 | if score < thresh:
276 | continue
277 |
278 | color_mask = np.array([1., 1., 1.]) * 255
279 | im = vis_mask(im, masks[..., i], color_mask, alpha=1.,
280 | show_border=False)
281 |
282 | return im
283 |
284 |
285 | def vis_one_image(
286 | im, im_name, output_dir, boxes, segms=None, keypoints=None, thresh=0.9,
287 | kp_thresh=2, dpi=200, box_alpha=0.0, dataset=None, show_class=False,
288 | ext='pdf', out_when_no_box=False):
289 | """Visual debugging of detections."""
290 | if not os.path.exists(output_dir):
291 | os.makedirs(output_dir)
292 |
293 | if isinstance(boxes, list):
294 | boxes, segms, keypoints, classes = convert_from_cls_format(
295 | boxes, segms, keypoints)
296 |
297 | if (boxes is None or boxes.shape[0] == 0 or max(boxes[:, 4]) < thresh) and not out_when_no_box:
298 | return
299 |
300 | dataset_keypoints, _ = keypoint_utils.get_keypoints()
301 |
302 | if segms is not None and len(segms) > 0:
303 | masks = mask_util.decode(segms)
304 |
305 | color_list = colormap(rgb=True) / 255
306 |
307 | kp_lines = kp_connections(dataset_keypoints)
308 | cmap = plt.get_cmap('rainbow')
309 | colors = [cmap(i) for i in np.linspace(0, 1, len(kp_lines) + 2)]
310 |
311 | fig = plt.figure(frameon=False)
312 | fig.set_size_inches(im.shape[1] / dpi, im.shape[0] / dpi)
313 | ax = plt.Axes(fig, [0., 0., 1., 1.])
314 | ax.axis('off')
315 | fig.add_axes(ax)
316 | ax.imshow(im)
317 |
318 | if boxes is None:
319 | sorted_inds = [] # avoid crash when 'boxes' is None
320 | else:
321 | # Display in largest to smallest order to reduce occlusion
322 | areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
323 | sorted_inds = np.argsort(-areas)
324 |
325 | mask_color_id = 0
326 | for i in sorted_inds:
327 | bbox = boxes[i, :4]
328 | score = boxes[i, -1]
329 | if score < thresh:
330 | continue
331 |
332 | # show box (off by default)
333 | ax.add_patch(
334 | plt.Rectangle((bbox[0], bbox[1]),
335 | bbox[2] - bbox[0],
336 | bbox[3] - bbox[1],
337 | fill=False, edgecolor='g',
338 | linewidth=0.5, alpha=box_alpha))
339 |
340 | if show_class:
341 | ax.text(
342 | bbox[0], bbox[1] - 2,
343 | get_class_string(classes[i], score, dataset),
344 | fontsize=3,
345 | family='serif',
346 | bbox=dict(
347 | facecolor='g', alpha=0.4, pad=0, edgecolor='none'),
348 | color='white')
349 |
350 | # show mask
351 | if segms is not None and len(segms) > i:
352 | img = np.ones(im.shape)
353 | color_mask = color_list[mask_color_id % len(color_list), 0:3]
354 | mask_color_id += 1
355 |
356 | w_ratio = .4
357 | for c in range(3):
358 | color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio
359 | for c in range(3):
360 | img[:, :, c] = color_mask[c]
361 | e = masks[:, :, i]
362 |
363 | # cv2.findContours gives (image, contours, hierarchy) back in opencv 3.x
364 | # but gives back (contours, hierarchy) in opencv 2.x and 4.x
365 | contour, hier = cv2.findContours(
366 | e.copy(), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)[-2:]
367 |
368 | for c in contour:
369 | polygon = Polygon(
370 | c.reshape((-1, 2)),
371 | fill=True, facecolor=color_mask,
372 | edgecolor='w', linewidth=1.2,
373 | alpha=0.5)
374 | ax.add_patch(polygon)
375 |
376 | # show keypoints
377 | if keypoints is not None and len(keypoints) > i:
378 | kps = keypoints[i]
379 | plt.autoscale(False)
380 | for l in range(len(kp_lines)):
381 | i1 = kp_lines[l][0]
382 | i2 = kp_lines[l][1]
383 | if kps[2, i1] > kp_thresh and kps[2, i2] > kp_thresh:
384 | x = [kps[0, i1], kps[0, i2]]
385 | y = [kps[1, i1], kps[1, i2]]
386 | line = plt.plot(x, y)
387 | plt.setp(line, color=colors[l], linewidth=1.0, alpha=0.7)
388 | if kps[2, i1] > kp_thresh:
389 | plt.plot(
390 | kps[0, i1], kps[1, i1], '.', color=colors[l],
391 | markersize=3.0, alpha=0.7)
392 |
393 | if kps[2, i2] > kp_thresh:
394 | plt.plot(
395 | kps[0, i2], kps[1, i2], '.', color=colors[l],
396 | markersize=3.0, alpha=0.7)
397 |
398 | # add mid shoulder / mid hip for better visualization
399 | mid_shoulder = (
400 | kps[:2, dataset_keypoints.index('right_shoulder')] +
401 | kps[:2, dataset_keypoints.index('left_shoulder')]) / 2.0
402 | sc_mid_shoulder = np.minimum(
403 | kps[2, dataset_keypoints.index('right_shoulder')],
404 | kps[2, dataset_keypoints.index('left_shoulder')])
405 | mid_hip = (
406 | kps[:2, dataset_keypoints.index('right_hip')] +
407 | kps[:2, dataset_keypoints.index('left_hip')]) / 2.0
408 | sc_mid_hip = np.minimum(
409 | kps[2, dataset_keypoints.index('right_hip')],
410 | kps[2, dataset_keypoints.index('left_hip')])
411 | if (sc_mid_shoulder > kp_thresh and
412 | kps[2, dataset_keypoints.index('nose')] > kp_thresh):
413 | x = [mid_shoulder[0], kps[0, dataset_keypoints.index('nose')]]
414 | y = [mid_shoulder[1], kps[1, dataset_keypoints.index('nose')]]
415 | line = plt.plot(x, y)
416 | plt.setp(
417 | line, color=colors[len(kp_lines)], linewidth=1.0, alpha=0.7)
418 | if sc_mid_shoulder > kp_thresh and sc_mid_hip > kp_thresh:
419 | x = [mid_shoulder[0], mid_hip[0]]
420 | y = [mid_shoulder[1], mid_hip[1]]
421 | line = plt.plot(x, y)
422 | plt.setp(
423 | line, color=colors[len(kp_lines) + 1], linewidth=1.0,
424 | alpha=0.7)
425 |
426 | output_name = os.path.basename(im_name) + '.' + ext
427 | fig.savefig(os.path.join(output_dir, '{}'.format(output_name)), dpi=dpi)
428 | plt.close('all')
429 |
--------------------------------------------------------------------------------
/Models/regressionTemplateTF/train_regression.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Foundry.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | ##############################################################################
15 |
16 | from __future__ import division, print_function, absolute_import
17 | from builtins import input # python 2/3 forward-compatible (raw_input)
18 |
19 | import sys
20 | import os
21 | import time
22 | import random
23 | import argparse
24 | from datetime import datetime
25 |
26 | import numpy as np
27 |
28 | import tensorflow as tf
29 | print(tf.__version__)
30 |
31 | tf.compat.v1.enable_eager_execution()
32 |
33 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
34 | from common.model_builder import baseline_model
35 | from common.util import get_filepaths_from_dir, get_ckpt_list, print_
36 | from common.util import is_exr, read_resize_exr, linear_to_srgb
37 |
38 | def enable_deterministic_training(seed, no_gpu_patch=False):
39 | """Set all seeds for deterministic training.
40 | This function needs to be called before any TensorFlow code.
41 | 
42 | Args:
43 | seed (int): seed for the Python, NumPy and TensorFlow pseudo-random generators.
44 | no_gpu_patch (bool): if False, apply a patch to TensorFlow to have
45 | deterministic GPU operations; if True, the training is much faster
46 | but slightly less deterministic.
47 | """
47 | import numpy as np
48 | import os
49 | import random
50 | import tfdeterminism
51 | if not no_gpu_patch:
52 | # Patch stock TensorFlow to have deterministic GPU operation
53 | tfdeterminism.patch() # then use tf as normal
54 | # If PYTHONHASHSEED environment variable is not set or set to random,
55 | # a random value is used to seed the hashes of str, bytes and datetime
56 | # objects. (Necessary for Python >= 3.2.3)
57 | os.environ['PYTHONHASHSEED']=str(seed)
58 | # Set python built-in pseudo-random generator at a fixed value
59 | random.seed(seed)
60 | # Set seed for random Numpy operation (e.g. np.random.randint)
61 | np.random.seed(seed)
62 | # Set seed for random TensorFlow operation (e.g. tf.image.random_crop)
63 | tf.compat.v1.random.set_random_seed(seed)
64 |
65 | ## DATA PROCESSING
66 |
67 | def histogram(tensor, value_range=[0.0, 1.0], nbins=100):
68 | """Return histogram of tensor"""
69 | h, w, c = tensor.shape
70 | hist = tf.histogram_fixed_width(tensor, value_range, nbins=nbins)
71 | hist = tf.divide(hist, h * w * c)
72 | return hist
73 |
74 | def gamma_correction(img, gamma):
75 | """Apply gamma correction to image img and compute histograms
76 |
77 | Returns:
78 | hists: stack of both original and graded image histograms
79 | """
80 | # Check that the number of parameters is one
81 | if gamma.shape[0] != 1:
82 | raise ValueError("Parameter for gamma correction must be of "
83 | "size (1,), not {}.\n\tCheck your self.output_param_number, ".format(gamma.shape)
84 | + "you may need to implement your own input_data preprocessing.")
85 | # Create groundtruth graded image
86 | img_grade = tf.math.pow(img, gamma)
87 | # Compute histograms
88 | img_hist = histogram(img)
89 | img_grade_hist = histogram(img_grade)
90 | hists = tf.stack([img_hist, img_grade_hist], axis=0)
91 | return hists
92 |
93 | ## CUSTOM TRAINING METRICS
94 |
95 | def bin_acc(y_true, y_pred, delta=0.02):
 96 |     """Bin accuracy metric: 1.0 if the difference between the true
 97 |     and predicted values is less than delta, 0.0 otherwise.
98 | """
99 | diff = tf.keras.backend.abs(y_true - y_pred)
 100 |     # If diff is less than delta --> true (1.0), otherwise false (0.0)
 101 |     correct = tf.keras.backend.less(diff, delta)
 102 |     # Return the mean accuracy over the batch
103 | return tf.keras.backend.mean(correct)
104 |
105 | class TrainModel(object):
106 | """Train Regression model from the given data"""
107 |
108 | def __init__(self, args):
109 | # Training hyperparameters
110 | self.learning_rate = args.learning_rate
111 | self.batch_size = args.batch_size
112 | self.epoch = args.epoch
113 | self.patch_size = 50
114 | self.channels = 3 # input / output channels
115 | self.output_param_number = 1
116 | self.no_resume = args.no_resume
117 | # A random seed (!=None) allows you to reproduce your training results
118 | self.seed = args.seed
119 | if self.seed is not None:
120 | # Set all seeds necessary for deterministic training
121 | enable_deterministic_training(self.seed, args.no_gpu_patch)
122 | # Training and validation dataset paths
123 | train_data_path = './data/train/'
124 | val_data_path = './data/validation/'
125 |
126 | # Where to save and load model weights (=checkpoints)
127 | self.ckpt_dir = './checkpoints'
128 | if not os.path.exists(self.ckpt_dir):
129 | os.makedirs(self.ckpt_dir)
130 | self.ckpt_save_name = args.ckpt_save_name
131 |
132 | # Where to save tensorboard summaries
133 | self.summaries_dir = './summaries/'
134 | if not os.path.exists(self.summaries_dir):
135 | os.makedirs(self.summaries_dir)
136 |
137 | # Get training dataset as list of image paths
138 | self.train_data_list = get_filepaths_from_dir(train_data_path)
139 | if not self.train_data_list:
140 | raise ValueError("No training data found in folder {}".format(train_data_path))
141 | elif (len(self.train_data_list) < self.batch_size):
142 | raise ValueError("Batch size must be smaller than the dataset (batch size = {}, number of training data = {})"
143 | .format(self.batch_size, len(self.train_data_list)))
144 | self.is_exr = is_exr(self.train_data_list[0])
145 |
146 | # Compute and print training hyperparameters
147 | self.batch_per_epoch = (len(self.train_data_list)) // self.batch_size
148 | max_steps = int(self.epoch * (self.batch_per_epoch))
149 | print_("Number of training data: {}\nNumber of batches per epoch: {} (batch size = {})\nNumber of training steps for {} epochs: {}\n"
150 | .format(len(self.train_data_list), self.batch_per_epoch, self.batch_size, self.epoch, max_steps), 'm')
151 |
152 | # Get validation dataset if provided
153 | self.has_val_data = True
154 | self.val_data_list = get_filepaths_from_dir(val_data_path)
155 | if not self.val_data_list:
156 | print("No validation data found in {}".format(val_data_path))
157 | self.has_val_data = False
158 | elif (len(self.val_data_list) < self.batch_size):
159 | raise ValueError("Batch size must be smaller than the dataset (batch size = {}, number of validation data = {})"
160 | .format(self.batch_size, len(self.val_data_list)))
161 | else:
162 | val_is_exr = is_exr(self.val_data_list[0])
163 | if (val_is_exr and not self.is_exr) or (not val_is_exr and self.is_exr):
164 | raise TypeError("Train and validation data should have the same file format")
165 | self.val_batch_per_epoch = (len(self.val_data_list)) // self.batch_size
166 | print("Number of validation data: {}\nNumber of validation batches per epoch: {} (batch size = {})"
167 | .format(len(self.val_data_list), self.val_batch_per_epoch, self.batch_size))
168 |
169 | def get_data(self, data_list, batch_size=16, epoch=100, shuffle_buffer_size=1000):
170 |
171 | def read_and_preprocess_data(path_img, param):
172 | """Read image in path_img, resize it to patch_size,
173 | convert to grayscale and apply a random gamma grade to it
174 |
175 | Returns:
176 | input_data: stack of both original and graded image histograms
177 | param: groundtruth gamma value
178 | """
179 | if self.is_exr: # ['exr', 'EXR']
180 | img = tf.numpy_function(read_resize_exr,
181 | [path_img, self.patch_size], [tf.float32])
182 | img = tf.numpy_function(linear_to_srgb, [img], [tf.float32])
183 | img = tf.reshape(img, [self.patch_size, self.patch_size, self.channels])
184 | img = tf.image.rgb_to_grayscale(img)
185 | else: # ['jpg', 'jpeg', 'png', 'bmp', 'JPG', 'JPEG', 'PNG', 'BMP']
186 | img_raw = tf.io.read_file(path_img)
187 | img_tensor = tf.image.decode_png(img_raw, channels=3)
188 | img = tf.cast(img_tensor, tf.float32) / 255.0
189 | img = tf.image.rgb_to_grayscale(img)
190 | img = tf.image.resize(img, [self.patch_size, self.patch_size])
191 | # Depending on what parameter(s) you want to learn, modify the training
192 | # input data. Here to learn gamma correction, our input data trainX is
193 | # a stack of both original and gamma-graded histograms.
194 | input_data = gamma_correction(img, param)
195 | return input_data, param
196 |
197 | with tf.compat.v1.variable_scope('input'):
198 | # Ensure preprocessing is done on the CPU (to let the GPU focus on training)
199 | with tf.device('/cpu:0'):
200 | data_tensor = tf.convert_to_tensor(data_list, dtype=tf.string)
201 | path_dataset = tf.data.Dataset.from_tensor_slices((data_tensor))
202 | path_dataset = path_dataset.shuffle(shuffle_buffer_size).repeat(epoch)
203 | # Depending on what parameter(s) you want to learn, modify the random
204 | # uniform range. Here create random gamma values between 0.2 and 5
205 | param_tensor = tf.random.uniform(
206 | [len(data_list)*epoch, self.output_param_number], 0.2, 5.0)
207 | param_dataset = tf.data.Dataset.from_tensor_slices((param_tensor))
208 | dataset = tf.data.Dataset.zip((path_dataset, param_dataset))
 209 |             # Apply read_and_preprocess_data to every (path, param) pair in the dataset
210 | dataset = dataset.map(read_and_preprocess_data, num_parallel_calls=4)
211 | dataset = dataset.batch(batch_size)
212 | # Always prefetch one batch and make sure there is always one ready
213 | dataset = dataset.prefetch(buffer_size=1)
214 | return dataset
215 |
216 | def tensorboard_callback(self, writer):
217 | """Return custom Tensorboard callback for logging main metrics"""
218 |
219 | def log_metrics(epoch, logs):
220 | """Log training/validation loss and accuracy to Tensorboard"""
221 | with writer.as_default(), tf.contrib.summary.always_record_summaries():
222 | tf.contrib.summary.scalar('train_loss', logs['loss'], step=epoch)
223 | tf.contrib.summary.scalar('train_bin_acc', logs['bin_acc'], step=epoch)
224 | if self.has_val_data:
225 | tf.contrib.summary.scalar('val_loss', logs['val_loss'], step=epoch)
226 | tf.contrib.summary.scalar('val_bin_acc', logs['val_bin_acc'], step=epoch)
227 | tf.contrib.summary.flush()
228 |
229 | return tf.keras.callbacks.LambdaCallback(on_epoch_end=log_metrics)
230 |
231 | def get_compiled_model(self, input_shape):
232 | model = baseline_model(
233 | input_shape,
234 | output_param_number=self.output_param_number)
235 | adam = tf.keras.optimizers.Adam(lr=self.learning_rate)
236 | model.compile(optimizer=adam,
237 | loss='mean_squared_error',
238 | metrics=[bin_acc])
239 | return model
240 |
241 | def train(self):
 242 |         # Create a session so that tf.keras doesn't allocate all GPU memory at once
243 | sess = tf.compat.v1.Session(
244 | config=tf.compat.v1.ConfigProto(
245 | gpu_options=tf.compat.v1.GPUOptions(allow_growth=True)))
246 | tf.compat.v1.keras.backend.set_session(sess)
247 |
248 | # Get training and validation dataset
249 | ds_train = self.get_data(
250 | self.train_data_list,
251 | self.batch_size,
252 | self.epoch)
253 | for x, y in ds_train.take(1): # take one batch from ds_train
254 | trainX, trainY = x, y
255 | print("Input shape {}, target shape: {}".format(trainX.shape, trainY.shape))
256 | if self.has_val_data:
257 | ds_val = self.get_data(
258 | self.val_data_list,
259 | self.batch_size,
260 | self.epoch)
261 | print("********Data Created********")
262 |
263 | # Build model
264 | model = self.get_compiled_model(trainX.shape[1:])
265 |
 266 |         # Check if there are intermediate trained models to load
267 | if self.no_resume or not self.load(model):
268 | print_("Starting training from scratch\n", 'm')
269 |
270 | # Callback for creating Tensorboard summary
271 | summary_name = ("data{}_bch{}_ep{}".format(
272 | len(self.train_data_list), self.batch_size, self.epoch))
273 | summary_name += ("_seed{}".format(self.seed) if self.seed is not None else "")
274 | summary_writer = tf.contrib.summary.create_file_writer(
275 | os.path.join(self.summaries_dir, summary_name))
276 | tb_callback = self.tensorboard_callback(summary_writer)
277 |
278 | # Callback for saving model's weights
279 | ckpt_path = os.path.join(self.ckpt_dir, self.ckpt_save_name + "-ep{epoch:02d}")
280 | ckpt_callback = tf.keras.callbacks.ModelCheckpoint(
281 | filepath=ckpt_path,
282 | # save best model based on monitor value
283 | monitor='val_loss' if self.has_val_data else 'loss',
284 | verbose=1,
285 | save_best_only=True,
286 | save_weights_only=True)
287 |
288 | # Evaluate the model before training
289 | if self.has_val_data:
290 | val_loss, val_bin_acc = model.evaluate(ds_val.take(20), verbose=1)
291 | print("Initial Loss on validation dataset: {:.4f}".format(val_loss))
292 |
293 | # TRAIN model
294 | print_("--------Start of training--------\n", 'm')
295 | print("NOTE:\tDuring training, the latest model is saved only if its\n"
296 | "\t(validation) loss is better than the last best model.")
297 | train_start = time.time()
298 | model.fit(
299 | ds_train,
300 | validation_data=ds_val if self.has_val_data else None,
301 | epochs=self.epoch,
302 | steps_per_epoch=self.batch_per_epoch,
303 | validation_steps=self.val_batch_per_epoch if self.has_val_data else None,
304 | callbacks=[ckpt_callback, tb_callback],
305 | verbose=1)
306 | print_("Training duration: {:0.4f}s\n".format(time.time() - train_start), 'm')
307 | print_("--------End of training--------\n", 'm')
308 |
309 | # Show predictions on the first batch of training data
310 | print("Parameter prediction (PR) compared to groundtruth (GT) for first batch of training data:")
311 | preds_train = model.predict(trainX.numpy())
312 | print("Train GT:", trainY.numpy().flatten())
313 | print("Train PR:", preds_train.flatten())
314 | # Make predictions on the first batch of validation data
315 | if self.has_val_data:
316 | print("For first batch of validation data:")
317 | for x, y in ds_val.take(1): # take one batch from ds_val
318 | valX, valY = x, y
319 | preds_val = model.predict(valX)
320 | print("Val GT:", valY.numpy().flatten())
321 | print("Val PR:", preds_val.flatten())
322 | # Free all resources associated with the session
323 | sess.close()
324 |
325 | def load(self, model):
326 | ckpt_names = get_ckpt_list(self.ckpt_dir)
327 | if not ckpt_names: # list is empty
328 | print_("No checkpoints found in {}\n".format(self.ckpt_dir), 'm')
329 | return False
330 | else:
331 | print_("Found checkpoints:\n", 'm')
332 | for name in ckpt_names:
333 | print(" {}".format(name))
 334 |             # Ask user whether to start training from scratch or resume training from a specific checkpoint
335 | while True:
336 | mode=str(input('Start training from scratch (start) or resume training from a previous checkpoint (choose one of the above): '))
337 | if mode == 'start' or mode in ckpt_names:
338 | break
339 | else:
340 | print("Answer should be 'start' or one of the following checkpoints: {}".format(ckpt_names))
341 | continue
342 | if mode == 'start':
343 | return False
344 | elif mode in ckpt_names:
345 | # Try to load given intermediate checkpoint
346 | print_("Loading trained model...\n", 'm')
347 | model.load_weights(os.path.join(self.ckpt_dir, mode))
348 | print_("...Checkpoint {} loaded\n".format(mode), 'm')
349 | return True
350 | else:
351 | raise ValueError("User input is neither 'start' nor a valid checkpoint")
352 |
353 | def evaluate(self, test_data_path, weights):
354 | """Evaluate a trained model on the test dataset
355 |
356 | Args:
357 | test_data_path (str): path to directory containing images for testing
358 | weights (str): name of the tensorflow checkpoint (weights) to evaluate
359 | """
360 | test_data_list = get_filepaths_from_dir(test_data_path)
361 | if not test_data_list:
362 | raise ValueError("No test data found in folder {}".format(test_data_path))
 363 |         elif (len(test_data_list) < self.batch_size):
364 | raise ValueError("Batch size must be smaller than the dataset (batch size = {}, number of test data = {})"
365 | .format(self.batch_size, len(test_data_list)))
366 | self.is_exr = is_exr(test_data_list[0])
367 |
368 | # Get and create test dataset
369 | ds_test = self.get_data(
370 | test_data_list,
371 | self.batch_size,
372 | 1)
373 | for x, y in ds_test.take(1): # take one batch from ds_test
374 | testX, testY = x, y
375 | print_("Number of test data: {}\n".format(len(test_data_list)), 'm')
376 | print("Input shape {}, target shape: {}".format(testX.shape, testY.shape))
377 |
378 | # Build model
379 | model = self.get_compiled_model(testX.shape[1:])
380 |
381 | # Load model weights
382 | print_("Loading trained model for testing...\n", 'm')
383 | model.load_weights(os.path.join(self.ckpt_dir, weights)).expect_partial()
384 | print_("...Checkpoint {} loaded\n".format(weights), 'm')
385 |
386 | # Test final model on this unseen dataset
387 | results = model.evaluate(ds_test)
388 | print("test loss, test acc:", results)
389 | print_("--------End of testing--------\n", 'm')
390 |
391 | def parse_args():
392 | parser = argparse.ArgumentParser(description='Model training arguments')
393 | parser.add_argument('--bch', type=int, default=10, dest='batch_size', help='training batch size')
394 | parser.add_argument('--ep', type=int, default=15, dest='epoch', help='training epoch number')
395 | parser.add_argument('--lr', type=float, default=1e-3, dest='learning_rate', help='initial learning rate')
396 | parser.add_argument('--seed', type=int, default=None, dest='seed', help='set random seed for deterministic training')
397 | parser.add_argument('--no-gpu-patch', dest='no_gpu_patch', default=False, action='store_true', help='if seed is set, add this tag for much faster but slightly less deterministic training')
398 | parser.add_argument('--no-resume', dest='no_resume', default=False, action='store_true', help="start training from scratch")
399 | parser.add_argument('--name', type=str, default="regressionTemplateTF", dest='ckpt_save_name', help='name of saved checkpoints/model weights')
400 | args = parser.parse_args()
401 | return args
402 |
403 | if __name__ == '__main__':
404 | args = parse_args()
405 | # Set up model to train
406 | model = TrainModel(args)
407 | model.train()
408 | # To evaluate on the test dataset, uncomment next line and give the
409 | # test dataset directory and the model checkpoint name
410 | # model.evaluate('./data/test', 'regressionTemplateTF-ep35')
--------------------------------------------------------------------------------
/Models/trainingTemplateTF/train_model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019 Foundry.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | ##############################################################################
15 |
16 | from __future__ import print_function
 17 | from builtins import input, range # python 2/3 forward-compatible (raw_input, xrange)
18 |
19 | import sys
20 | import os
21 | import time
22 | import random
23 | import argparse
24 | from datetime import datetime
25 |
26 | import numpy as np
27 | import tensorflow as tf
28 | print(tf.__version__)
29 | tf.compat.v1.disable_eager_execution() # For TF 2.x compatibility
30 |
31 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
32 | from common.model_builder import EncoderDecoder
33 | from common.util import im2uint8, get_filepaths_from_dir, get_ckpt_list, print_
34 | from common.util import is_exr, read_crop_exr_pair, linear_to_srgb
35 |
36 | def enable_deterministic_training(seed, no_gpu_patch=False):
37 | """Set all seeds for deterministic training
 38 |     This function needs to be called before any TensorFlow code.
39 | """
40 | import numpy as np
41 | import os
42 | import random
43 | import tfdeterminism
44 | if not no_gpu_patch:
45 | # Patch stock TensorFlow to have deterministic GPU operation
46 | tfdeterminism.patch() # then use tf as normal
47 | # If PYTHONHASHSEED environment variable is not set or set to random,
48 | # a random value is used to seed the hashes of str, bytes and datetime
49 | # objects. (Necessary for Python >= 3.2.3)
50 | os.environ['PYTHONHASHSEED']=str(seed)
51 | # Set python built-in pseudo-random generator at a fixed value
52 | random.seed(seed)
53 | # Set seed for random Numpy operation (e.g. np.random.randint)
54 | np.random.seed(seed)
55 | # Set seed for random TensorFlow operation (e.g. tf.image.random_crop)
56 | tf.compat.v1.random.set_random_seed(seed)
57 |
58 | class TrainModel(object):
59 | """Train the EncoderDecoder from the given input and groundtruth data"""
60 |
61 | def __init__(self, args):
62 | # Training hyperparameters
63 | self.learning_rate = args.learning_rate
64 | self.batch_size = args.batch_size
65 | self.epoch = args.epoch
66 | self.no_resume = args.no_resume
67 | # A random seed (!=None) allows you to reproduce your training results
68 | self.seed = args.seed
69 | if self.seed is not None:
70 | # Set all seeds necessary for deterministic training
71 | enable_deterministic_training(self.seed, args.no_gpu_patch)
72 | self.crop_size = 256
73 | self.n_levels = 3
74 | self.scale = 0.5
75 | self.channels = 3 # input / output channels
76 | # Training and validation dataset paths
77 | train_in_data_path = './data/train/input'
78 | train_gt_data_path = './data/train/groundtruth'
79 | val_in_data_path = './data/validation/input'
80 | val_gt_data_path = './data/validation/groundtruth'
81 |
82 | # Where to save and load model weights (=checkpoints)
83 | self.checkpoints_dir = './checkpoints'
84 | if not os.path.exists(self.checkpoints_dir):
85 | os.makedirs(self.checkpoints_dir)
86 | self.ckpt_save_name = args.ckpt_save_name
87 | # Maximum number of recent checkpoint files to keep
88 | self.max_ckpts_to_keep = 50
89 | # In addition keep one checkpoint file for every N hours of training
90 | self.keep_ckpt_every_n_hours = 1
 91 |         # How often, in training steps, we save model checkpoints
 92 |         self.ckpts_save_freq = 1000
 93 |         # How often, in training steps, we print training losses to the console
94 | self.training_print_freq = 10
95 |
96 | # Where to save tensorboard summaries
97 | self.summaries_dir = './summaries'
98 | if not os.path.exists(self.summaries_dir):
99 | os.makedirs(self.summaries_dir)
 100 |         # How often, in training steps, we save tensorboard summaries
101 | self.summaries_save_freq = 10
102 | # How often, in secs, we flush the pending tensorboard summaries to disk
103 | self.summary_flush_secs = 30
104 |
105 | # Get training dataset as lists of image paths
106 | self.train_in_data_list = get_filepaths_from_dir(train_in_data_path)
107 | self.train_gt_data_list = get_filepaths_from_dir(train_gt_data_path)
108 | if not self.train_in_data_list or not self.train_gt_data_list:
109 | raise ValueError("No training data found in folders {} or {}".format(train_in_data_path, train_gt_data_path))
110 | elif len(self.train_in_data_list) != len(self.train_gt_data_list):
111 | raise ValueError("{} ({} data) and {} ({} data) should have the same number of input data"
112 | .format(train_in_data_path, len(self.train_in_data_list), train_gt_data_path, len(self.train_gt_data_list)))
113 | elif (len(self.train_in_data_list) < self.batch_size):
114 | raise ValueError("Batch size must be smaller than the dataset (batch size = {}, number of training data = {})"
115 | .format(self.batch_size, len(self.train_in_data_list)))
116 | self.is_exr = is_exr(self.train_in_data_list[0])
117 |
118 | # Get validation dataset if provided
119 | self.has_val_data = True
120 | self.val_in_data_list = get_filepaths_from_dir(val_in_data_path)
121 | self.val_gt_data_list = get_filepaths_from_dir(val_gt_data_path)
122 | if not self.val_in_data_list or not self.val_gt_data_list:
123 | print("No validation data found in {} or {}".format(val_in_data_path, val_gt_data_path))
124 | self.has_val_data = False
125 | elif len(self.val_in_data_list) != len(self.val_gt_data_list):
126 | raise ValueError("{} ({} data) and {} ({} data) should have the same number of input data"
127 | .format(val_in_data_path, len(self.val_in_data_list), val_gt_data_path, len(self.val_gt_data_list)))
128 | elif (len(self.val_in_data_list) < self.batch_size):
129 | raise ValueError("Batch size must be smaller than the dataset (batch size = {}, number of validation data = {})"
130 | .format(self.batch_size, len(self.val_in_data_list)))
131 | else:
132 | val_is_exr = is_exr(self.val_in_data_list[0])
133 | if (val_is_exr and not self.is_exr) or (not val_is_exr and self.is_exr):
134 | raise TypeError("Train and validation data should have the same file format")
135 | print("Number of validation data: {}".format(len(self.val_in_data_list)))
136 |
137 | # Compute and print training hyperparameters
138 | batch_per_epoch = (len(self.train_in_data_list)) // self.batch_size
139 | self.max_steps = int(self.epoch * (batch_per_epoch))
140 | print_("Number of training data: {}\nNumber of batches per epoch: {} (batch size = {})\nNumber of training steps for {} epochs: {}\n"
141 | .format(len(self.train_in_data_list), batch_per_epoch, self.batch_size, self.epoch, self.max_steps), 'm')
142 |
143 | def get_data(self, in_data_list, gt_data_list, batch_size=16, epoch=100):
144 |
145 | def read_and_preprocess(path_img_in, path_img_gt):
146 | if self.is_exr: # ['exr', 'EXR']
147 | # Read and crop data
148 | img_crop = tf.numpy_function(read_crop_exr_pair,
149 | [path_img_in, path_img_gt, self.crop_size], [tf.float32, tf.float32])
150 | img_crop = tf.numpy_function(linear_to_srgb, [img_crop], tf.float32)
151 | img_crop = tf.unstack(tf.reshape(img_crop, [2, self.crop_size, self.crop_size, self.channels]))
152 | else: # ['jpg', 'jpeg', 'png', 'bmp', 'JPG', 'JPEG', 'PNG', 'BMP']
153 | # Read data
154 | img_in_raw = tf.io.read_file(path_img_in)
155 | img_gt_raw = tf.io.read_file(path_img_gt)
156 | img_in_tensor = tf.image.decode_image(img_in_raw, channels=self.channels)
157 | img_gt_tensor = tf.image.decode_image(img_gt_raw, channels=self.channels)
158 | # Normalise then crop data
159 | imgs = [tf.cast(img, tf.float32) / 255.0 for img in [img_in_tensor, img_gt_tensor]]
160 | img_crop = tf.unstack(tf.image.random_crop(tf.stack(imgs, axis=0),
161 | [2, self.crop_size, self.crop_size, self.channels], seed=self.seed), axis=0)
162 | return img_crop
163 |
164 | def multi_thread_preprocess(path_img_in, path_img_gt):
 165 |             """Non-random data preprocessing to be run in a multi-threaded map:
 166 |             read the images at path_img_in and path_img_gt and normalise them
167 | """
168 | if self.is_exr:
169 | # Do nothing, all preprocessing done in single_thread_preprocess
170 | return path_img_in, path_img_gt
171 | else:
172 | img_in_raw = tf.io.read_file(path_img_in)
173 | img_gt_raw = tf.io.read_file(path_img_gt)
174 | img_in_tensor = tf.image.decode_image(img_in_raw, channels=self.channels)
175 | img_gt_tensor = tf.image.decode_image(img_gt_raw, channels=self.channels)
176 | # Normalise data
177 | imgs = [tf.cast(img, tf.float32) / 255.0 for img in [img_in_tensor, img_gt_tensor]]
178 | return imgs
179 |
180 | def single_thread_preprocess(img_in, img_gt):
 181 |             """Random data preprocessing to be run in a single-threaded map:
 182 |             crop the images with a deterministic TensorFlow (png) or NumPy (exr) seed
183 | """
184 | if self.is_exr:
185 | img_crop = tf.numpy_function(read_crop_exr_pair,
186 | [img_in, img_gt, self.crop_size], [tf.float32, tf.float32])
187 | img_crop = tf.numpy_function(linear_to_srgb, [img_crop], tf.float32)
188 | img_crop = tf.unstack(tf.reshape(img_crop, [2, self.crop_size, self.crop_size, self.channels]))
189 | else:
190 | img_crop = tf.unstack(tf.image.random_crop(tf.stack([img_in, img_gt], axis=0),
191 | [2, self.crop_size, self.crop_size, self.channels], seed=self.seed), axis=0)
192 | return img_crop
193 |
194 | with tf.compat.v1.variable_scope('input'):
195 | # Ensure preprocessing is done on the CPU (to let the GPU focus on training)
196 | with tf.device('/cpu:0'):
197 | in_list = tf.convert_to_tensor(in_data_list, dtype=tf.string)
198 | gt_list = tf.convert_to_tensor(gt_data_list, dtype=tf.string)
199 |
200 | path_dataset = tf.data.Dataset.from_tensor_slices((in_list, gt_list))
201 | path_dataset = path_dataset.shuffle(
202 | buffer_size=len(in_data_list), seed=self.seed).repeat(epoch)
 203 |             # Apply preprocessing to every (input, groundtruth) path pair in path_dataset
204 | if self.seed is None:
205 | # Run all preprocessing in one dataset.map()
206 | num_parallel_calls = 1 if self.is_exr else 4
207 | dataset = path_dataset.map(read_and_preprocess, num_parallel_calls)
208 | else:
209 | # Perform the non-random ops in a multi-threaded map()
210 | dataset = path_dataset.map(multi_thread_preprocess, num_parallel_calls=4)
211 | # Perform the random ops in a single-threaded map() for
212 | # deterministic training when seed is not None
213 | dataset = dataset.map(single_thread_preprocess, num_parallel_calls=1)
214 | dataset = dataset.batch(batch_size)
215 | # Always prefetch one batch and make sure there is always one ready
216 | dataset = dataset.prefetch(buffer_size=1)
217 | # Create operator to iterate over the created dataset
218 | next_element = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
219 | return next_element
220 |
221 | def loss(self, n_outputs, img_gt):
222 | """Compute multi-scale loss function"""
223 | loss_total = 0
224 | for i in range(self.n_levels):
225 | _, hi, wi, _ = n_outputs[i].shape
226 | gt_i = tf.image.resize(img_gt, [hi, wi], method='bilinear')
227 | loss = tf.reduce_mean(tf.square(gt_i - n_outputs[i]))
228 | loss_total += loss
 229 |             # Save each level's output image to tensorboard
230 | tf.compat.v1.summary.image('out_' + str(i), im2uint8(n_outputs[i]))
231 | # Save total loss to tensorboard
232 | tf.compat.v1.summary.scalar('loss_total', loss_total)
233 | return loss_total
234 |
235 | def validate(self, model):
236 | total_val_loss = 0.0
237 | # Get next data from preprocessed validation dataset
238 | val_img_in, val_img_gt = self.get_data(self.val_in_data_list, self.val_gt_data_list, self.batch_size, -1)
239 | n_outputs = model(val_img_in, reuse=False)
240 | val_op = self.loss(n_outputs, val_img_gt)
241 | # Test results over one epoch
242 | batch_per_epoch = len(self.val_in_data_list) // self.batch_size
243 | for batch in range(batch_per_epoch):
244 | total_val_loss += val_op
245 | return total_val_loss / batch_per_epoch
246 |
247 | def train(self):
248 | # Build model
249 | model = EncoderDecoder(self.n_levels, self.scale, self.channels)
250 |
251 | # Learning rate decay
252 | global_step = tf.Variable(initial_value=0, dtype=tf.int32, trainable=False)
253 | self.lr = tf.compat.v1.train.polynomial_decay(
254 | self.learning_rate, global_step,
255 | decay_steps=self.max_steps,
256 | end_learning_rate=0.0,
257 | power=0.3)
258 | tf.compat.v1.summary.scalar('learning_rate', self.lr)
259 | # Training operator
260 | adam = tf.compat.v1.train.AdamOptimizer(self.lr)
261 |
262 | # Get next data from preprocessed training dataset
263 | img_in, img_gt = self.get_data(
264 | self.train_in_data_list,
265 | self.train_gt_data_list,
266 | self.batch_size,
267 | self.epoch)
268 | print('img_in, img_gt', img_in.shape, img_gt.shape)
269 | tf.compat.v1.summary.image('img_in', im2uint8(img_in))
270 | tf.compat.v1.summary.image('img_gt', im2uint8(img_gt))
271 |
272 | # Compute image loss
273 | n_outputs = model(img_in, reuse=False)
274 | loss_op = self.loss(n_outputs, img_gt)
 275 |         # By default, Adam optimises all trainable_variables currently in the graph,
 276 |         # so train_op should be created last, once the full training graph is built.
277 | train_op = adam.minimize(loss_op, global_step)
278 |
279 | # Create session
280 | sess = tf.compat.v1.Session(
281 | config=tf.compat.v1.ConfigProto(
282 | gpu_options=tf.compat.v1.GPUOptions(allow_growth=True)))
283 |
284 | # Initialise all the variables in current session
285 | init = tf.compat.v1.global_variables_initializer()
286 | sess.run(init)
287 | self.saver = tf.compat.v1.train.Saver(
288 | max_to_keep=self.max_ckpts_to_keep,
289 | keep_checkpoint_every_n_hours=self.keep_ckpt_every_n_hours)
290 |
 291 |         # Check if there are intermediate trained models to load
292 | if self.no_resume or not self.load(sess, self.checkpoints_dir):
293 | print_("Starting training from scratch\n", 'm')
294 |
295 | # Tensorboard summary
296 | summary_op = tf.compat.v1.summary.merge_all()
297 | summary_name = ("data{}_bch{}_ep{}".format(
298 | len(self.train_in_data_list), self.batch_size, self.epoch))
299 | summary_name += ("_seed{}".format(self.seed) if self.seed is not None else "")
300 | summary_writer = tf.compat.v1.summary.FileWriter(
301 | os.path.join(self.summaries_dir, summary_name),
302 | graph=sess.graph,
303 | flush_secs=self.summary_flush_secs)
304 |
305 | # Compute loss on validation dataset to check overfitting
306 | if self.has_val_data:
307 | val_loss_op = self.validate(model)
308 | # Save validation loss to tensorboard
309 | val_summary_op = tf.compat.v1.summary.scalar('val_loss', val_loss_op)
310 | # Compute initial loss
311 | val_loss, val_summary = sess.run([val_loss_op, val_summary_op])
312 | summary_writer.add_summary(val_summary, global_step=0)
313 | print("Initial Loss on validation dataset: {:.6f}".format(val_loss))
314 |
315 | ################ TRAINING ################
316 | train_start = time.time()
317 | for step in range(sess.run(global_step), self.max_steps):
318 | start_time = time.time()
319 | val_str = ''
320 | if step % self.summaries_save_freq == 0 or step == self.max_steps - 1:
321 | # Train model and record summaries
322 | _, loss_total, summary = sess.run([train_op, loss_op, summary_op])
323 | summary_writer.add_summary(summary, global_step=step)
324 | duration = time.time() - start_time
325 | if self.has_val_data and step != 0:
326 | # Compute validation loss
327 | val_loss, val_summary = sess.run([val_loss_op, val_summary_op])
328 | summary_writer.add_summary(val_summary, global_step=step)
329 | val_str = ', val loss: {:.6f}'.format(val_loss)
330 | else: # Train only
331 | _, loss_total = sess.run([train_op, loss_op])
332 | duration = time.time() - start_time
333 | assert not np.isnan(loss_total), 'Model diverged with loss = NaN'
334 |
335 | if step % self.training_print_freq == 0 or step == self.max_steps - 1:
336 | examples_per_sec = self.batch_size / duration
337 | sec_per_batch = float(duration)
338 | format_str = ('{}: step {}, loss: {:.6f} ({:.1f} data/s; {:.3f} s/bch)'
339 | .format(datetime.now().strftime('%Y-%m-%d %H:%M:%S'), step, loss_total, examples_per_sec, sec_per_batch))
340 | print(format_str + val_str)
341 |
342 | if (step + 1) % self.ckpts_save_freq == 0 or step == self.max_steps - 1:
343 | # Save current model in a checkpoint
344 | self.save(sess, self.checkpoints_dir, step + 1)
345 | print_("Training duration: {:0.4f}s\n".format(time.time() - train_start), 'm')
346 | print_("--------End of training--------\n", 'm')
347 | # Free all resources associated with the session
348 | sess.close()
349 |
350 | def save(self, sess, checkpoint_dir, step):
351 | if not os.path.exists(checkpoint_dir):
352 | os.makedirs(checkpoint_dir)
353 | self.saver.save(sess, os.path.join(checkpoint_dir, self.ckpt_save_name), global_step=step)
354 |
355 | def load(self, sess, checkpoint_dir):
356 | ckpt_names = get_ckpt_list(checkpoint_dir)
357 | if not ckpt_names: # list is empty
358 | print_("No checkpoints found in {}\n".format(checkpoint_dir), 'm')
359 | return False
360 | else:
361 | print_("Found checkpoints:\n", 'm')
362 | for name in ckpt_names:
363 | print(" {}".format(name))
 364 |             # Ask user whether to start training from scratch or resume training from a specific checkpoint
365 | while True:
366 | mode=str(input('Start training from scratch (start) or resume training from a previous checkpoint (choose one of the above): '))
367 | if mode == 'start' or mode in ckpt_names:
368 | break
369 | else:
370 | print("Answer should be 'start' or one of the following checkpoints: {}".format(ckpt_names))
371 | continue
372 | if mode == 'start':
373 | return False
374 | elif mode in ckpt_names:
375 | # Try to load given intermediate checkpoint
376 | print_("Loading trained model...\n", 'm')
377 | self.saver.restore(sess, os.path.join(checkpoint_dir, mode))
378 | print_("...Checkpoint {} loaded\n".format(mode), 'm')
379 | return True
380 | else:
381 | raise ValueError("User input is neither 'start' nor a valid checkpoint")
382 |
383 | def parse_args():
384 | parser = argparse.ArgumentParser(description='Model training arguments')
385 | parser.add_argument('--bch', type=int, default=16, dest='batch_size', help='training batch size')
386 | parser.add_argument('--ep', type=int, default=10000, dest='epoch', help='training epoch number')
387 | parser.add_argument('--lr', type=float, default=1e-4, dest='learning_rate', help='initial learning rate')
388 | parser.add_argument('--seed', type=int, default=None, dest='seed', help='set random seed for deterministic training')
389 | parser.add_argument('--no-gpu-patch', dest='no_gpu_patch', default=False, action='store_true', help='if seed is set, add this tag for much faster but slightly less deterministic training')
390 | parser.add_argument('--no-resume', dest='no_resume', default=False, action='store_true', help="start training from scratch")
391 | parser.add_argument('--name', type=str, default="trainingTemplateTF", dest='ckpt_save_name', help='name of saved checkpoints/model weights')
392 | args = parser.parse_args()
393 | return args
394 |
395 | if __name__ == '__main__':
396 | args = parse_args()
397 | # set up model to train
398 | model = TrainModel(args)
399 | model.train()
400 |
--------------------------------------------------------------------------------