├── .gitignore ├── CMakeLists.txt ├── INSTALL.md ├── LICENSE ├── Models ├── __init__.py ├── baseModel.py ├── blur │ ├── __init__.py │ └── model.py ├── classTemplateTF │ ├── README.md │ ├── __init__.py │ ├── model.py │ └── train_classification.py ├── common │ ├── __init__.py │ ├── model_builder.py │ └── util.py ├── mrcnn │ ├── __init__.py │ ├── model.py │ ├── utils.py │ └── vis.py ├── regressionTemplateTF │ ├── README.md │ ├── __init__.py │ ├── model.py │ └── train_regression.py └── trainingTemplateTF │ ├── README.md │ ├── __init__.py │ ├── data │ └── train │ │ ├── groundtruth │ │ └── alive00001.png │ │ └── input │ │ └── alive_snow00001.png │ ├── model.py │ └── train_model.py ├── Plugins ├── Client │ ├── CMakeLists.txt │ ├── MLClient.cpp │ ├── MLClient.h │ ├── MLClientComms.cpp │ ├── MLClientComms.h │ ├── MLClientModelManager.cpp │ ├── MLClientModelManager.h │ └── message.proto └── Server │ ├── .dockerignore │ ├── Dockerfile │ ├── __init__.py │ ├── message_pb2.py │ ├── py2.Dockerfile │ └── server.py └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore 2 | .vscode/ 3 | build/ 4 | build-Debug/ 5 | 6 | # Ignore Python compiled files 7 | *.py[co] 8 | # Ignore configuration and weights files 9 | *.yaml 10 | *.pkl 11 | 12 | # Ignore shared objects 13 | *.so 14 | 15 | # Ignore generated files 16 | *.os 17 | *.o 18 | 19 | # Ignore all directories named: 20 | summaries/ 21 | input/ 22 | groundtruth/ 23 | checkpoints/ 24 | data/ 25 | serverlocal/ 26 | densepose/ -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.10) 2 | project(MachineLearningPlugins VERSION 1.0.0) 3 | 4 | #===------------------------------------------------------------------------=== 5 | # Global settings some based on the external configuration settings 6 | set( CMAKE_CXX_STANDARD 11 ) 7 | set( CMAKE_CXX_EXTENSIONS OFF ) 8 | set( CMAKE_CXX_VISIBILITY_PRESET hidden ) 9 | set( CMAKE_POSITION_INDEPENDENT_CODE True ) 10 | if (CMAKE_CXX_COMPILER_ID MATCHES "Clang") 11 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fcolor-diagnostics") 12 | endif() 13 | 14 | #===------------------------------------------------------------------------=== 15 | # Build information 16 | string( TIMESTAMP BUILDDATE_YEAR_INTERNAL "%Y" ) 17 | string( TIMESTAMP BUILDDATE_MMDD_INTERNAL "%m%d" ) 18 | string( TIMESTAMP BUILDDATE_FULL_INTERNAL "%Y-%m-%dT%H:%M:%S" ) 19 | string( TIMESTAMP BUILDDATE_STAMP "%Y.%m%d" ) 20 | string( REGEX REPLACE "^0" "" BUILDDATE_MMDD_INTERNAL ${BUILDDATE_MMDD_INTERNAL} ) 21 | set( BUILDDATE_YEAR "${BUILDDATE_YEAR_INTERNAL}" CACHE STRING "Year of the build: It will default to the current year." ) 22 | set( BUILDDATE_MMDD "${BUILDDATE_MMDD_INTERNAL}" CACHE STRING "Month and day of the build: It will default to the calendar month and day." ) 23 | set( BUILDDATE_FULL "${BUILDDATE_FULL_INTERNAL}" CACHE STRING "Exact time of the build." 
) 24 | 25 | #===------------------------------------------------------------------------=== 26 | # Compile CMakeLists found in subdirectories 27 | add_subdirectory(Plugins/Client) -------------------------------------------------------------------------------- /INSTALL.md: -------------------------------------------------------------------------------- 1 | # Installing Nuke Machine Learning Plugin 2 | 3 | The Nuke Machine Learning (ML) installation can be divided into compiling the MLClient Nuke node and installing the MLServer using Docker. 4 | 5 | The MLClient plugin can be compiled on both Linux/MacOS and Windows systems. It communicates with the MLServer which needs to be run on a Linux machine with NVIDIA GPU. 6 | 7 | **Requirements:** 8 | - Linux with Nuke installed 9 | - NVIDIA GPU (Important: GPU memory must be at least 6GB) 10 | - CMake (minimum 3.10) 11 | - Protobuf (tested with 2.5.0 and 3.5.1) 12 | - Docker 13 | 14 | ## Installing the Client on Linux/MacOS 15 | 16 | ### Install Protobuf 17 | 18 | Protocol Buffers (aka Protobuf) are an efficient way of serializing structured data - similar to XML, but faster and simpler. We use it to define, write, and read the data for our client<->server communication. 19 | 20 | Following the [installation instructions](https://github.com/protocolbuffers/protobuf/blob/master/src/README.md) from the Protobuf GitHub repository, we recommend compiling Protobuf from source: 21 | 22 | First get Protobuf source file for C++, for instance version 3.5.1: 23 | ``` 24 | wget https://github.com/protocolbuffers/protobuf/releases/download/v3.5.1/protobuf-cpp-3.5.1.tar.gz 25 | # Extract file in current directory 26 | tar -xzf protobuf-cpp-3.5.1.tar.gz 27 | ``` 28 | Then build and install the C++ Protocol Buffer runtime and the Protocol Buffer compiler (protoc): 29 | ``` 30 | cd protobuf-3.5.1 31 | ./configure 32 | make 33 | make check 34 | sudo make install 35 | sudo ldconfig # refresh shared library cache. 36 | ``` 37 | 38 | Note: Instead of compiling it from source, Protobuf may alternatively be installed with a package manager, for example: 39 | ``` 40 | sudo yum install protobuf-devel 41 | ``` 42 | 43 | ### Compile MLClient Nuke Node 44 | 45 | If not already cloned, fetch the `nuke-ML-server` repository: 46 | ``` 47 | git clone https://github.com/TheFoundryVisionmongers/nuke-ML-server 48 | ``` 49 | Execute the commands below to compile the client MLClient.so plugin, setting the NUKE_INSTALL_PATH to point to the folder of the desired Nuke version: 50 | ``` 51 | cd nuke-ML-server/ 52 | mkdir build && cd build 53 | cmake -DNUKE_INSTALL_PATH=/path/to/Nuke11.3v1/ .. 54 | make 55 | ``` 56 | The MLClient.so plugin will now be in the `build/Plugins/Client` folder. Before it can be used, Nuke needs to know where it lives. One way to do this is to update the NUKE_PATH environment variable to point to the MLClient.so plugin (This can be skipped if it was moved to the root of your ~/.nuke folder, or the path was added in Nuke through Python): 57 | ``` 58 | export NUKE_PATH=/path/to/lib/:$NUKE_PATH 59 | ``` 60 | At that point, after opening Nuke and updating all plugins, the `MLClient` node should be available. To update all the plugins in Nuke, you can either use the Other > All Plugins > Update option (see [documentation](https://learn.foundry.com/nuke/developers/63/pythondevguide/installing_plugins.html)), or simply press `tab` in the Node Graph then write `Update [All plugins]`. 
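If you prefer registering the plugin location from Python rather than through NUKE_PATH, a minimal sketch for your `~/.nuke/init.py` is shown below (here `/path/to/lib/` is a placeholder for the folder containing MLClient.so):
```
# ~/.nuke/init.py: make Nuke pick up the MLClient plugin on startup
import nuke
nuke.pluginAddPath('/path/to/lib/')
```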
If the `MLClient` node is still missing, verify that the current NUKE_PATH is correctly pointing to the folder containing MLClient.so. 61 | 62 | ## Installing the Client on Windows 63 | 64 | This was tested on Windows 10. You need to have [cmake](https://cmake.org/) and [git](https://git-scm.com/) installed on your computer. 65 | 66 | Start by installing the Visual Studio Compiler "Build Tools for Visual Studio 2017" found at [this link](https://www.visualstudio.com/thank-you-downloading-visual-studio/?sku=BuildTools&rel=15). 67 | 68 | ### Install Protobuf 69 | 70 | We recommend building Protobuf locally as a static library. For reference this section partly follows the [installation instructions](https://github.com/protocolbuffers/protobuf/blob/master/cmake/README.md) from the Protobuf GitHub repository. 71 | 72 | First open “**x64** Native Tools Command Prompt for VS 2017” executable. Please note it has to be **x64** and not x86. 73 | 74 | If `cmake` or `git` commands are not available from Command Prompt, add them to the system PATH variable: 75 | ``` 76 | set PATH=%PATH%;C:\Program Files (x86)\CMake\bin 77 | set PATH=%PATH%;C:\Program Files\Git\cmd 78 | ``` 79 | Clone your chosen Protobuf branch release, for instance here version 3.5.1: 80 | ``` 81 | git clone -b v3.5.1 https://github.com/protocolbuffers/protobuf.git 82 | cd protobuf 83 | git submodule update --init --recursive 84 | cd cmake 85 | mkdir build & cd build 86 | mkdir release & cd release 87 | ``` 88 | Compile protobuf with dynamic VCRTLib (Visual Studio Code C++ Runtime Library): 89 | ``` 90 | cmake -G "NMake Makefiles" -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX= -Dprotobuf_MSVC_STATIC_RUNTIME=OFF -Dprotobuf_BUILD_TESTS=OFF ../.. 91 | ``` 92 | Install protobuf in the specified `` folder by running the following: 93 | ``` 94 | nmake install 95 | ``` 96 | Note: This last command will create the following folders under the `` location: 97 | - bin - that contains protobuf protoc.exe compiler; 98 | - include - that contains C++ headers and protobuf *.proto files; 99 | - lib - that contains linking libraries and CMake configuration files for protobuf package. 100 | 101 | ### Compile MLClient Nuke Node 102 | 103 | If not already done, clone the `nuke-ML-server` repository: 104 | ``` 105 | git clone https://github.com/TheFoundryVisionmongers/nuke-ML-server 106 | cd nuke-ml-server 107 | mkdir build & cd build 108 | mkdir x64-Release & cd x64-Release 109 | ``` 110 | Compile the MLClient and link your version of Nuke and Protobuf install path: 111 | ``` 112 | cmake -G "NMake Makefiles" -DCMAKE_BUILD_TYPE=Release -DNUKE_INSTALL_PATH=”/path/to//Nuke12.0v3” -DProtobuf_LIBRARIES=”/lib” -DProtobuf_INCLUDE_DIR=”/include” -DProtobuf_PROTOC_EXECUTABLE="/bin/protoc.exe" ../.. 113 | nmake 114 | ``` 115 | The MLClient.dll plugin should now be in the `build/x64-Release/Plugins/Client` folder. Before it can be used, Nuke needs to know where it lives. You can either copy it to your ~/.nuke folder or update the NUKE_PATH environment: 116 | ``` 117 | set NUKE_PATH=%NUKE_PATH%;path/to/lib 118 | ``` 119 | At that point, after opening Nuke and updating all plugins, the `MLClient` node should be available. To update all the plugins in Nuke, you can either use the Other > All Plugins > Update option (see [documentation](https://learn.foundry.com/nuke/developers/63/pythondevguide/installing_plugins.html)), or simply press `tab` in the Node Graph then write `Update [All plugins]`. 
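To quickly check that Nuke can see the plugin, you can run the following sketch in the Script Editor (it only assumes the plugin was compiled and installed as described above):
```
import nuke
# The folder containing MLClient.dll should appear in the plugin path list
print(nuke.pluginPath())
# This should create the node if the plugin was found
nuke.createNode('MLClient')
```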
If the `MLClient` node is still missing, verify that the current NUKE_PATH is correctly pointing to the folder containing MLClient.dll. 120 | 121 | As your client is on a Windows machine, you now need to run the server on a Linux machine with NVidia GPU (see [next section](https://github.com/TheFoundryVisionmongers/nuke-ML-server/blob/master/INSTALL.md#installing-the-server)) and connect your Windows machine to it following the [Connect to an External Server](https://github.com/TheFoundryVisionmongers/nuke-ML-server/blob/master/INSTALL.md#connect-to-an-external-server) section. 122 | 123 | ## Installing the Server 124 | 125 | ### Install Docker 126 | 127 | Docker provides a way to package and run an application in a securely isolated environment called a container. This container includes all the application dependencies and libraries. It ensures that the application works seamlessly inside the container in any system environment. We use docker to create a container that easily runs the MLServer. 128 | 129 | Install Docker: 130 | ``` 131 | # Install the official docker-ce package 132 | sudo curl -sSL https://get.docker.com/ | sh 133 | # Start Docker 134 | sudo systemctl start docker 135 | ``` 136 | Nvidia Docker is a necessary plugin that enables Nvidia GPU-accelerated applications to run in Docker. 137 | 138 | Install nvidia-container-toolkit for your Linux platform by following the [installation instructions](https://github.com/NVIDIA/nvidia-docker) of the nvidia-docker repository. On CentOS/RHEL, you should follow section "CentOS 7 (**docker-ce**), RHEL 7.4/7.5 (**docker-ce**), Amazon Linux 1/2" of the repository. 139 | 140 | Build the docker image from the [Dockerfile](/Plugins/Server/Dockerfile): 141 | ``` 142 | # Start by loading Ubuntu18.04 with cuda 10.0 and cudnn7 as the base image 143 | sudo docker pull nvidia/cuda:10.0-cudnn7-devel-ubuntu18.04 144 | # Build the docker image on top of the base image 145 | cd Plugins/Server/ 146 | # Choose your own label for , it must be lowercase. e.g. mlserver. 147 | sudo docker build -t -f Dockerfile . 148 | ``` 149 | 150 | ### Run Docker Container 151 | 152 | Create and run a docker container on top of the created docker image, referencing the `` from the previous step: 153 | 154 | ``` 155 | sudo docker run --gpus all -v /absolute/path/to/nuke-ML-server/Models/:/workspace/ml-server/models -it 156 | ``` 157 | 158 | Notes: 159 | - the `-v` (volume) option links your host machine Models/ folder with the models/ folder inside your container. You only need to modify `/absolute/path/to/nuke-ML-server/Models/`, leave the `/workspace/ml-server/models` unchanged as it already corresponds to the folder structure inside your Docker image. This option allows you to add models in Models/ that will be directly available and updated inside your container. 160 | - If your docker version doesn't recognise the `--gpus` flag, you can equally run the same docker container by replacing `sudo docker run --gpus all ` by `sudo nvidia-docker run` or `sudo docker run --runtime=nvidia`. 161 | 162 | ## Getting Started 163 | 164 | ### Download Configuration and Weights Files 165 | 166 | To be able to run inference on the Mask-RCNN model, you need to download its configuration and weight files. 167 | 168 | Depending on your GPU memory, you can use either a ResNet101 (GPU memory > 8GB) or a ResNet50 (GPU memory > 6GB) backbone. The results with ResNet101 are slightly better. 169 | - Mask-RCNN requires ~7GB GPU RAM with ResNet101 and ~4.6GB with ResNet50. 
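If you are unsure how much memory your GPU has, you can query it with nvidia-smi (installed with the NVIDIA driver) before choosing a backbone, for example:
```
nvidia-smi --query-gpu=name,memory.total --format=csv
```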
170 | 171 | Download your selected configuration and weight files: 172 | - Mask-RCNN ResNet50: 173 | - Configuration: [e2e_mask_rcnn_R-50-FPN_2x.yaml](https://raw.githubusercontent.com/facebookresearch/Detectron/master/configs/12_2017_baselines/e2e_mask_rcnn_R-50-FPN_2x.yaml) 174 | - Corresponding weights: [model_final.pkl](https://dl.fbaipublicfiles.com/detectron/35859007/12_2017_baselines/e2e_mask_rcnn_R-50-FPN_2x.yaml.01_49_07.By8nQcCH/output/train/coco_2014_train%3Acoco_2014_valminusminival/generalized_rcnn/model_final.pkl) (from the Detectron [Model Zoo](https://github.com/facebookresearch/Detectron/blob/master/MODEL_ZOO.md)) 175 | - OR Mask_RCNN ResNet101 176 | - Configuration: [e2e_mask_rcnn_X-101-64x4d-FPN_2x.yaml](https://raw.githubusercontent.com/facebookresearch/Detectron/master/configs/12_2017_baselines/e2e_mask_rcnn_X-101-64x4d-FPN_2x.yaml) 177 | - Correponding weights: [model_final.pkl](https://dl.fbaipublicfiles.com/detectron/35859745/12_2017_baselines/e2e_mask_rcnn_X-101-64x4d-FPN_2x.yaml.02_00_30.ESWbND2w/output/train/coco_2014_train%3Acoco_2014_valminusminival/generalized_rcnn/model_final.pkl) (from the Detectron [Model Zoo](https://github.com/facebookresearch/Detectron/blob/master/MODEL_ZOO.md)) 178 | 179 | And move them to `Models/mrcnn/` folder. 180 | 181 | ResNet50 is the default backbone. If you use ResNet101, you need to modify the config and weight file names in Models/mrcnn/model.py. 182 | 183 | ### Connect Client and Server 184 | 185 | This section explains how to connect the server and client when your docker container and Nuke instance are running on the same Linux machine: 186 | 187 | 0. (If you have stopped your container, follow the [Run Docker Container](https://github.com/TheFoundryVisionmongers/nuke-ML-server/blob/master/INSTALL.md#run-docker-container) section again) 188 | 1. In the running docker container, query the IP address: 189 | ``` 190 | hostname -I 191 | ``` 192 | 2. In Nuke, set the MLClient node `host` to the container IP address, 193 | 3. In the container, launch the server and start listening on port 55555: 194 | ``` 195 | python server.py 55555 196 | ``` 197 | 4. In Nuke, click on the MLClient connect button, you should have the three models available. 198 | 199 | ### Connect to an External Server 200 | 201 | This section explains how to connect server and client when your docker container (MLServer) and Nuke (MLClient) are running on two different machines, e.g. if you are using the MLClient on Windows. In that case, you have a Linux machine running the docker container and a Windows machine running Nuke. 202 | 203 | 1. On your **Linux machine** (not the docker container, not your Windows machine), query the IP adress: 204 | ``` 205 | hostname -I 206 | ``` 207 | 2. In Nuke, set the MLClient node `host` to the Linux machine IP address obtained. 208 | 3. On the Linux machine, run the docker container exporting a port of your choice (here port 7000 of the host is mapped to port 55555 of the container): 209 | ``` 210 | sudo docker run --gpus all -v /absolute/path/to/nuke-ML-server/Models/:/workspace/ml-server/models -p 7000:55555 -it 211 | ``` 212 | 4. In the container, launch the server and start listening on port 55555: 213 | ``` 214 | python server.py 55555 215 | ``` 216 | 5. In Nuke, set the MLClient node `port` to 7000 and click on the MLClient connect button. 217 | 218 | ### Add your own Model 219 | 220 | To implement your own model, you can create a new folder in the /Models directory with your model name. 
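As an illustration only, a hypothetical `Models/mymodel/model.py` (the folder name, model name and `gain` option are placeholders) could look like the sketch below, mirroring the structure of the blur example; the exact requirements it satisfies are summarised just after it:
```
# Models/mymodel/model.py (hypothetical example, placed next to an empty __init__.py)
from ..baseModel import BaseModel

class Model(BaseModel):
    def __init__(self):
        super(Model, self).__init__()
        self.name = 'My Model'        # name listed in the MLClient node
        self.gain = 1.0               # example knob exposed in Nuke
        self.options = ('gain',)
        self.inputs = {'input': 3}    # input name and channel count
        self.outputs = {'output': 3}  # output name and channel count

    def inference(self, image_list):
        """Return a list of output images computed from the input images."""
        image = image_list[0]
        return [image * self.gain]
```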
At the minimum, this folder needs to include an empty `__init__.py` file and a `model.py` file that contains a Model class inheriting from BaseModel. 221 | 222 | You can copy the simple [Models/blur/](Models/blur) model as a starting point, and implement your own model looking at the examples of blur and mrcnn. 223 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /Models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TheFoundryVisionmongers/nuke-ML-server/5dd4d04cd673c60de8093c600d6c54016fca92d6/Models/__init__.py -------------------------------------------------------------------------------- /Models/baseModel.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Foundry. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | ############################################################################## 15 | 16 | import sys 17 | if sys.version_info.major > 2: # python3 18 | unicode = str 19 | 20 | import numpy as np 21 | 22 | class BaseModel(object): 23 | def __init__(self): 24 | self.name = 'Base model' 25 | self.options = () # List of attribute names that should get exposed in Nuke 26 | self.buttons = () # List of button names that should get exposed in Nuke 27 | self.inputs = {'input': 3} # Define Inputs (name, #channels) 28 | self.outputs = {'output': 3} # And Outputs (name, #channels) 29 | pass 30 | 31 | def inference(self, *inputs): 32 | """Do an inference on the model with a set of inputs. 33 | 34 | # Arguments: 35 | inputs: A list of images 36 | 37 | # Return: 38 | The result of the inference as a list of images 39 | """ 40 | raise NotImplementedError 41 | 42 | def get_options(self): 43 | """Get a dictionary of exposed options from the model. 44 | 45 | To expose options, self.options has to be filled with attribute names. 46 | Return a dictionary of option names and values. 47 | """ 48 | opt = {} 49 | if hasattr(self, 'options'): 50 | for option in self.options: 51 | value = getattr(self, option) 52 | if isinstance(value, unicode): 53 | value = str(value) 54 | assert type(value) in [bool, int, float, str], \ 55 | 'Broadcasted options need to be one of bool, int, float, str.' 56 | opt[option] = value 57 | return opt 58 | 59 | def set_options(self, optionsDict): 60 | """Set the options of the model. 61 | 62 | # Arguments: 63 | optionsDict: A dictionary of attribute names and values 64 | """ 65 | for name, value in optionsDict.items(): 66 | setattr(self, name, value) 67 | 68 | def get_buttons(self): 69 | """Return the defined buttons of the model. 70 | 71 | To expose buttons in Nuke, self.buttons has to be filled with attribute names. 72 | """ 73 | btn = {} 74 | if hasattr(self, 'buttons'): 75 | for button in self.buttons: 76 | value = getattr(self, button) 77 | assert type(value) in [bool], 'Broadcasted buttons need to be bool.' 78 | btn[button] = value 79 | return btn 80 | 81 | def set_buttons(self, buttonsDict): 82 | """Set the buttons of the model. 
83 | 84 | # Arguments: 85 | buttonsDict: A dictionary of attribute names and values 86 | """ 87 | for name, value in buttonsDict.items(): 88 | setattr(self, name, value) 89 | 90 | def get_inputs(self): 91 | """Return the defined inputs of the model.""" 92 | return self.inputs 93 | 94 | def get_outputs(self): 95 | """Return the defined outputs of the model.""" 96 | return self.outputs 97 | 98 | def get_name(self): 99 | """Return the name of the model.""" 100 | return self.name -------------------------------------------------------------------------------- /Models/blur/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TheFoundryVisionmongers/nuke-ML-server/5dd4d04cd673c60de8093c600d6c54016fca92d6/Models/blur/__init__.py -------------------------------------------------------------------------------- /Models/blur/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Foundry. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | ############################################################################## 15 | 16 | from ..baseModel import BaseModel 17 | 18 | import cv2 19 | import numpy as np 20 | 21 | from ..common.util import linear_to_srgb, srgb_to_linear 22 | 23 | import message_pb2 24 | 25 | class Model(BaseModel): 26 | def __init__(self): 27 | super(Model, self).__init__() 28 | self.name = 'Gaussian Blur' 29 | 30 | self.kernel_size = 5 31 | self.make_blur = False 32 | 33 | # Define options 34 | self.options = ('kernel_size',) 35 | self.buttons = ('make_blur',) 36 | 37 | # Define inputs/outputs 38 | self.inputs = {'input': 3} 39 | self.outputs = {'output': 3} 40 | 41 | def inference(self, image_list): 42 | """Do an inference on the model with a set of inputs. 43 | 44 | # Arguments: 45 | image_list: The input image list 46 | 47 | Return the result of the inference. 48 | """ 49 | image = image_list[0] 50 | image = linear_to_srgb(image) 51 | image = (image * 255).astype(np.uint8) 52 | kernel = self.kernel_size * 2 + 1 53 | blur = cv2.GaussianBlur(image, (kernel, kernel), 0) 54 | blur = blur.astype(np.float32) / 255. 55 | blur = srgb_to_linear(blur) 56 | 57 | # If make_blur button is pressed in Nuke 58 | if self.make_blur: 59 | script_msg = message_pb2.FieldValuePairAttrib() 60 | script_msg.name = "PythonScript" 61 | # Create a Python script message to run in Nuke 62 | python_script = self.blur_script(blur) 63 | script_msg_val = script_msg.values.add() 64 | script_msg_str = script_msg_val.string_attributes.add() 65 | script_msg_str.values.extend([python_script]) 66 | return [blur, script_msg] 67 | 68 | return [blur] 69 | 70 | def blur_script(self, image): 71 | """Return the Python script function to create a pop up window in Nuke. 72 | 73 | The pop up window displays the brightest pixel position of the given image. 
74 | """ 75 | # Compute brightest pixel of the image 76 | gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) 77 | [min_val, max_val, min_loc, max_loc] = cv2.minMaxLoc(gray) 78 | # Y axis are inversed in Nuke 79 | max_loc = (max_loc[0], image.shape[0] - max_loc[1]) 80 | popup_msg = ( 81 | "Brightest pixel of the blurred image\\n" 82 | "Location: {}, Value: {:.3f}." 83 | ).format(max_loc, max_val) 84 | script = "nuke.message('{}')\n".format(popup_msg) 85 | return script -------------------------------------------------------------------------------- /Models/classTemplateTF/README.md: -------------------------------------------------------------------------------- 1 | # Classification Training Template 2 | 3 | The classTemplateTF is a training template written in TensorFlow. It aims at quickly enabling classification training. For instance, detecting the presence of a specific actor in a shot. When trained, the model can be tested and used directly in Nuke through the nuke-ML-server. 4 | 5 | Apart from the dataset structure, all other instructions are similar to the other [Training Template](https://github.com/TheFoundryVisionmongers/nuke-ML-server/tree/master/Models/trainingTemplateTF). 6 | 7 | ## Dataset 8 | 9 | To train the ML algorithm, you need to set-up your dataset in `classTemplateTF/data/train/`. This directory should contain one subdirectory per class. Any PNG, JPG, BMP, PPM or TIF images inside each of the subdirectories will be included. 10 | 11 | For example, if you want to train a classifier to differentiate between cats, dogs and foxes. The `data/train/` directory should have 3 subdirectories named `cats`, `dogs` and `foxes` with each directory containing images of the corresponding animal. 12 | 13 | Optionally, you can add a separate set of images in `classTemplateTF/data/validation/`. If available, it is periodically used to check that there is no overfitting on the training data. Please note that the validation dataset and training dataset must not intersect. 14 | 15 | If no validation dataset is found, 20% of the training data will be used as a validation split. 
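Following the cats/dogs/foxes example above, the expected directory layout is (image file names are only placeholders):
```
classTemplateTF/data/
├── train/
│   ├── cats/
│   │   ├── cat0001.jpg
│   │   └── ...
│   ├── dogs/
│   └── foxes/
└── validation/   (optional, same class subdirectories)
    ├── cats/
    ├── dogs/
    └── foxes/
```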
16 | 17 | -------------------------------------------------------------------------------- /Models/classTemplateTF/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TheFoundryVisionmongers/nuke-ML-server/5dd4d04cd673c60de8093c600d6c54016fca92d6/Models/classTemplateTF/__init__.py -------------------------------------------------------------------------------- /Models/classTemplateTF/model.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import sys 4 | import os 5 | import time 6 | 7 | import scipy.misc 8 | import numpy as np 9 | import cv2 10 | 11 | import tensorflow as tf 12 | tf.compat.v1.disable_eager_execution() # For TF 2.x compatibility 13 | 14 | from ..baseModel import BaseModel 15 | from ..common.util import print_, get_saved_model_list, linear_to_srgb 16 | 17 | import message_pb2 18 | 19 | class Model(BaseModel): 20 | """Load your trained model and do inference in Nuke""" 21 | 22 | def __init__(self): 23 | super(Model, self).__init__() 24 | self.name = 'Classification Template' 25 | dir_path = os.path.dirname(os.path.realpath(__file__)) 26 | self.checkpoints_dir = os.path.join(dir_path, 'checkpoints') 27 | self.batch_size = 1 28 | 29 | # Initialise checkpoint name to the most recent trained model 30 | ckpt_names = get_saved_model_list(self.checkpoints_dir) 31 | if not ckpt_names: # empty list 32 | self.checkpoint_name = '' 33 | else: 34 | self.checkpoint_name = ckpt_names[-1] 35 | self.prev_ckpt_name = self.checkpoint_name 36 | 37 | # Button to get classification label 38 | self.get_label = False 39 | 40 | # Define options 41 | self.options = ('checkpoint_name',) 42 | self.buttons = ('get_label',) 43 | # Define inputs/outputs 44 | self.inputs = {'input': 3} 45 | self.outputs = {'output': 3} 46 | 47 | def load_model(self): 48 | # Check if empty or invalid checkpoint name 49 | if self.checkpoint_name=='': 50 | ckpt_names = get_saved_model_list(self.checkpoints_dir) 51 | if not ckpt_names: 52 | raise ValueError("No checkpoints found in {}".format(self.checkpoints_dir)) 53 | else: 54 | raise ValueError("Empty checkpoint name, try an available checkpoint in {} (ex: {})" 55 | .format(self.checkpoints_dir, ckpt_names[-1])) 56 | print_("Loading trained model checkpoint...\n", 'm') 57 | # Load from given checkpoint file name 58 | model = tf.keras.models.load_model(os.path.join(self.checkpoints_dir, self.checkpoint_name)) 59 | model._make_predict_function() 60 | print_("...Checkpoint {} loaded\n".format(self.checkpoint_name), 'm') 61 | return model 62 | 63 | def inference(self, image_list): 64 | """Do an inference on the model with a set of inputs. 65 | 66 | # Arguments: 67 | image_list: The input image list 68 | 69 | Return the result of the inference. 
70 | """ 71 | image = image_list[0] 72 | image = linear_to_srgb(image).copy() 73 | image = (image * 255).astype(np.uint8) 74 | 75 | if not hasattr(self, 'model'): 76 | # Initialise tensorflow graph 77 | tf.compat.v1.reset_default_graph() 78 | config = tf.compat.v1.ConfigProto() 79 | config.gpu_options.allow_growth=True 80 | self.sess = tf.compat.v1.Session(config=config) 81 | # Necessary to switch / load_weights on different h5 file 82 | tf.compat.v1.keras.backend.set_session(self.sess) 83 | # Load most recent trained model 84 | self.model = self.load_model() 85 | self.graph = tf.compat.v1.get_default_graph() 86 | self.prev_ckpt_name = self.checkpoint_name 87 | self.class_labels = (self.checkpoint_name.split('.')[0]).split('_') 88 | else: 89 | tf.compat.v1.keras.backend.set_session(self.sess) 90 | 91 | # If checkpoint name has changed, load new checkpoint 92 | if self.prev_ckpt_name != self.checkpoint_name or self.checkpoint_name == '': 93 | self.model = self.load_model() 94 | self.graph = tf.compat.v1.get_default_graph() 95 | self.class_labels = (self.checkpoint_name.split('.')[0]).split('_') 96 | # If checkpoint correctly loaded, update previous checkpoint name 97 | self.prev_ckpt_name = self.checkpoint_name 98 | 99 | image = cv2.resize(image, dsize=(224, 224), interpolation=cv2.INTER_NEAREST) 100 | # Predict on new data 101 | image_batch = np.expand_dims(image, 0) 102 | # Preprocess a numpy array encoding a batch of images (RGB values within [0, 255]) 103 | image_batch = tf.keras.applications.mobilenet.preprocess_input(image_batch) 104 | start = time.time() 105 | 106 | with self.graph.as_default(): 107 | y_prob = self.model.predict(image_batch) 108 | 109 | y_class = y_prob.argmax(axis=-1)[0] 110 | duration = time.time() - start 111 | # Print results on server side 112 | print('Inference duration: {:4.3f}s'.format(duration)) 113 | class_scores = str(["{0:0.4f}".format(i) for i in y_prob[0]]).replace("'", "") 114 | print("Class scores: {} --> Label: {}".format(class_scores, self.class_labels[y_class])) 115 | 116 | # If get_label button is pressed in Nuke 117 | if self.get_label: 118 | # Send back which class was detected 119 | script_msg = message_pb2.FieldValuePairAttrib() 120 | script_msg.name = "PythonScript" 121 | # Create a Python script message to run in Nuke 122 | nuke_msg = "Class scores: {}\\nLabel: {}".format(class_scores, self.class_labels[y_class]) 123 | python_script = "nuke.message('{}')\n".format(nuke_msg) 124 | script_msg_val = script_msg.values.add() 125 | script_msg_str = script_msg_val.string_attributes.add() 126 | script_msg_str.values.extend([python_script]) 127 | return [image_list[0], script_msg] 128 | return [image_list[0]] -------------------------------------------------------------------------------- /Models/classTemplateTF/train_classification.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import sys 4 | import os 5 | import time 6 | import random 7 | import argparse 8 | from datetime import datetime 9 | 10 | import scipy.misc 11 | import numpy as np 12 | 13 | import tensorflow as tf 14 | tf.compat.v1.disable_eager_execution() # For TF 2.x compatibility 15 | 16 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 17 | from common.model_builder import mobilenet_transfer 18 | from common.util import im2uint8, get_filepaths_from_dir, get_saved_model_list, get_labels_from_dir, print_ 19 | 20 | class TrainModel(object): 21 | """Train the chosen model 
from the given input and groundtruth data""" 22 | 23 | def __init__(self, args): 24 | # Training hyperparameters 25 | self.learning_rate = args.learning_rate 26 | self.batch_size = args.batch_size 27 | self.epoch = args.epoch 28 | self.save_model_period = 1 # save model weights every N epochs 29 | # Training and validation dataset paths 30 | self.train_data_path = './data/train' 31 | self.val_data_path = './data/validation' 32 | # Where to save and load model weights (=checkpoints) 33 | self.checkpoints_dir = './checkpoints' 34 | if not os.path.exists(self.checkpoints_dir): 35 | os.makedirs(self.checkpoints_dir) 36 | self.ckpt_save_name = 'classTemplate' 37 | # Where to save tensorboard summaries 38 | self.summaries_dir = './summaries' 39 | if not os.path.exists(self.summaries_dir): 40 | os.makedirs(self.summaries_dir) 41 | 42 | # Get training dataset as lists of image paths 43 | self.train_gt_data_list = get_filepaths_from_dir(self.train_data_path) 44 | if len(self.train_gt_data_list) is 0: 45 | raise ValueError("No training data found in folder {}".format(self.train_data_path)) 46 | elif (len(self.train_gt_data_list) < self.batch_size): 47 | raise ValueError("Batch size must be smaller than the dataset (batch size = {}, number of training data = {})" 48 | .format(self.batch_size, len(self.train_gt_data_list))) 49 | 50 | # Get validation dataset if provided 51 | self.has_val_data = True 52 | self.val_gt_data_list = get_filepaths_from_dir(self.val_data_path) 53 | if len(self.val_gt_data_list) is 0: 54 | print("No validation data found in {}, 20% of training data will be used as validation data".format(self.val_data_path)) 55 | self.has_val_data = False 56 | self.validation_split = 0.2 57 | elif (len(self.val_gt_data_list) < self.batch_size): 58 | raise ValueError("Batch size must be smaller than the dataset (batch size = {}, number of validation data = {})" 59 | .format(self.batch_size, len(self.val_gt_data_list))) 60 | else: 61 | print_("Number of validation data: {}\n".format(len(self.val_gt_data_list)), 'm') 62 | self.validation_split = 0.0 63 | 64 | self.train_labels = get_labels_from_dir(self.train_data_path) 65 | # Check class labels are the same 66 | if self.has_val_data: 67 | self.val_labels = get_labels_from_dir(self.val_data_path) 68 | if self.train_labels != self.val_labels: 69 | if len(self.train_labels) != len(self.val_labels): 70 | raise ValueError("{} and {} should have the same number of subdirectories ({}!={})" 71 | .format(self.train_data_path, self.val_data_path, len(self.train_labels), len(self.val_labels))) 72 | raise ValueError("{} and {} should have the same subdirectory label names ({}!={})" 73 | .format(self.train_data_path, self.val_data_path, self.train_labels, self.val_labels)) 74 | 75 | # Compute and print training hyperparameters 76 | self.batch_per_epoch = int(np.ceil(len(self.train_gt_data_list) / float(self.batch_size))) 77 | self.max_steps = int(self.epoch * (self.batch_per_epoch)) 78 | print_("Number of training data: {}\nNumber of batches per epoch: {} (batch size = {})\nNumber of training steps for {} epochs: {}\n" 79 | .format(len(self.train_gt_data_list), self.batch_per_epoch, self.batch_size, self.epoch, self.max_steps), 'm') 80 | print("Class labels: {}".format(self.train_labels)) 81 | 82 | def train(self): 83 | # Build model 84 | self.model = mobilenet_transfer(len(self.train_labels)) 85 | # Configure the model for training 86 | self.model.compile(optimizer=tf.keras.optimizers.Adam(), 87 | loss='categorical_crossentropy', 88 | metrics=['accuracy']) 
89 | # Print current model layers 90 | # self.model.summary() 91 | 92 | # Set preprocessing function 93 | datagen = tf.keras.preprocessing.image.ImageDataGenerator( 94 | # scale pixels between -1 and 1, sample-wise 95 | preprocessing_function=tf.keras.applications.mobilenet.preprocess_input, 96 | validation_split=self.validation_split) 97 | # Get classification data 98 | train_generator=datagen.flow_from_directory( 99 | self.train_data_path, 100 | target_size=(224,224), 101 | color_mode='rgb', 102 | batch_size=self.batch_size, 103 | class_mode='categorical', 104 | shuffle=True, 105 | subset='training') 106 | if self.has_val_data: 107 | validation_generator=datagen.flow_from_directory( 108 | self.val_data_path, 109 | target_size=(224,224), 110 | color_mode='rgb', 111 | batch_size=self.batch_size, 112 | class_mode='categorical', 113 | shuffle=True) 114 | else: # Generate a split of the training data as validation data 115 | validation_generator=datagen.flow_from_directory( 116 | self.train_data_path, # subset from training data path 117 | target_size=(224,224), 118 | color_mode='rgb', 119 | batch_size=self.batch_size, 120 | class_mode='categorical', 121 | shuffle=True, 122 | subset='validation') 123 | 124 | # Callback for creating Tensorboard summary 125 | summary_name = "classif_data{}_bch{}_ep{}".format(len(self.train_gt_data_list), self.batch_size, self.epoch) 126 | tensorboard_callback = tf.keras.callbacks.TensorBoard( 127 | log_dir=os.path.join(self.summaries_dir, summary_name)) 128 | # Callback for saving models periodically 129 | class_labels_save = '_'.join(self.train_labels) + '.' 130 | # 'acc' is the training accuracy and 'val_acc' is the validation set accuracy 131 | self.ckpt_save_name = class_labels_save + self.ckpt_save_name + "-val_acc{val_acc:.2f}-acc{acc:.2f}-ep{epoch:04d}.h5" 132 | checkpoint_callback = tf.keras.callbacks.ModelCheckpoint( 133 | filepath=os.path.join(self.checkpoints_dir, self.ckpt_save_name), 134 | save_weights_only=False, 135 | period=self.save_model_period, 136 | save_best_only=True, monitor='val_acc', mode='max' 137 | ) 138 | 139 | # Check if there are intermediate trained model to load 140 | # Uncomment following lines if you want to resume from a previous saved model 141 | # if not self.load_model(): 142 | # print_("Starting training from scratch\n", 'm') 143 | 144 | # Train the model 145 | fit_history = self.model.fit_generator( 146 | generator=train_generator, 147 | steps_per_epoch=train_generator.n // self.batch_size, 148 | validation_data=validation_generator, 149 | validation_steps= validation_generator.n // self.batch_size, 150 | epochs=self.epoch, 151 | callbacks=[checkpoint_callback, tensorboard_callback]) 152 | 153 | print_("--------End of training--------\n", 'm') 154 | 155 | def load_model(self): 156 | """Ask user if start training from scratch or resume from a previous checkpoint 157 | 158 | If resume, load model in self.model and return True, else return False 159 | """ 160 | ckpt_names = get_saved_model_list(self.checkpoints_dir) 161 | if not ckpt_names: # list is empty 162 | print_("No checkpoints found in {}\n".format(self.checkpoint_dir), 'm') 163 | return False 164 | else: 165 | print_("Found checkpoints:\n", 'm') 166 | for name in ckpt_names: 167 | print(" {}".format(name)) 168 | # Ask user if they prefer to start training from scratch or resume training on a specific ckeckpoint 169 | while True: 170 | mode=str(raw_input('Start training from scratch (start) or resume training from a previous checkpoint (choose one of the above): ')) 
171 | if mode == 'start' or mode in ckpt_names: 172 | break 173 | else: 174 | print("Answer should be 'start' or one of the following checkpoints: {}".format(ckpt_names)) 175 | continue 176 | if mode == 'start': 177 | return False 178 | elif mode in ckpt_names: 179 | # Try to load given intermediate checkpoint 180 | print_("Loading trained model...\n", 'm') 181 | self.model = tf.keras.models.load_model(os.path.join(self.checkpoints_dir, mode)) 182 | print_("...Checkpoint {} loaded\n".format(mode), 'm') 183 | return True 184 | else: 185 | raise ValueError("User input is neither 'start' nor a valid checkpoint") 186 | 187 | def parse_args(): 188 | parser = argparse.ArgumentParser(description='Model training arguments') 189 | parser.add_argument('--bch', type=int, default=16, dest='batch_size', help='training batch size') 190 | parser.add_argument('--ep', type=int, default=100, dest='epoch', help='training epoch number') 191 | parser.add_argument('--lr', type=float, default=1e-4, dest='learning_rate', help='initial learning rate') 192 | args = parser.parse_args() 193 | return args 194 | 195 | if __name__ == '__main__': 196 | args = parse_args() 197 | # set up model to train 198 | model = TrainModel(args) 199 | model.train() -------------------------------------------------------------------------------- /Models/common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TheFoundryVisionmongers/nuke-ML-server/5dd4d04cd673c60de8093c600d6c54016fca92d6/Models/common/__init__.py -------------------------------------------------------------------------------- /Models/common/model_builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Foundry. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | ############################################################################## 15 | 16 | from builtins import range # python 2/3 forward-compatible (xrange) 17 | import tensorflow as tf 18 | 19 | class ResNetBlock(tf.keras.layers.Layer): 20 | """Classic ResNet residual block""" 21 | 22 | def __init__(self, new_dim=32, ksize=5, name='resblock'): 23 | super(ResNetBlock, self).__init__(name=name) 24 | self.conv2D_1 = tf.keras.layers.Conv2D( 25 | filters=new_dim, kernel_size=ksize, padding='SAME', 26 | activation=tf.nn.relu, name='conv1') 27 | self.conv2D_2 = tf.keras.layers.Conv2D( 28 | filters=new_dim, kernel_size=ksize, padding='SAME', 29 | activation=None, name='conv2') 30 | 31 | def call(self, inputs): 32 | x = self.conv2D_1(inputs) 33 | x = self.conv2D_2(x) 34 | return x + inputs 35 | 36 | class EncoderDecoder(tf.keras.Model): 37 | """Create an encoder decoder model""" 38 | 39 | def __init__(self, n_levels, scale, channels, name='g_net'): 40 | super(EncoderDecoder, self).__init__(name=name) 41 | self.n_levels = n_levels 42 | self.scale = scale 43 | 44 | # Encoder layers 45 | self.conv1_1 = tf.keras.layers.Conv2D( 46 | filters=32, kernel_size=5, padding='SAME', 47 | activation=tf.nn.relu, name='enc1_1') 48 | self.block1_2 = ResNetBlock(32, 5, name='enc1_2') 49 | self.block1_3 = ResNetBlock(32, 5, name='enc1_3') 50 | self.block1_4 = ResNetBlock(32, 5, name='enc1_4') 51 | self.conv2_1 = tf.keras.layers.Conv2D( 52 | filters=64, kernel_size=5, strides=2, 53 | padding='SAME', activation=tf.nn.relu, name='enc2_1') 54 | self.block2_2 = ResNetBlock(64, 5, name='enc2_2') 55 | self.block2_3 = ResNetBlock(64, 5, name='enc2_3') 56 | self.block2_4 = ResNetBlock(64, 5, name='enc2_4') 57 | self.conv3_1 = tf.keras.layers.Conv2D( 58 | filters=128, kernel_size=5, strides=2, 59 | padding='SAME', activation=tf.nn.relu, name='enc3_1') 60 | self.block3_2 = ResNetBlock(128, 5, name='enc3_2') 61 | self.block3_3 = ResNetBlock(128, 5, name='enc3_3') 62 | self.block3_4 = ResNetBlock(128, 5, name='enc3_4') 63 | # Decoder layers 64 | self.deblock3_3 = ResNetBlock(128, 5, name='dec3_3') 65 | self.deblock3_2 = ResNetBlock(128, 5, name='dec3_2') 66 | self.deblock3_1 = ResNetBlock(128, 5, name='dec3_1') 67 | self.deconv2_4 = tf.keras.layers.Conv2DTranspose( 68 | filters=64, kernel_size=4, strides=2, 69 | padding='SAME', activation=tf.nn.relu, name='dec2_4') 70 | self.deblock2_3 = ResNetBlock(64, 5, name='dec2_3') 71 | self.deblock2_2 = ResNetBlock(64, 5, name='dec2_2') 72 | self.deblock2_1 = ResNetBlock(64, 5, name='dec2_1') 73 | self.deconv1_4 = tf.keras.layers.Conv2DTranspose( 74 | filters=32, kernel_size=4, strides=2, 75 | padding='SAME', activation=tf.nn.relu, name='dec1_4') 76 | self.deblock1_3 = ResNetBlock(32, 5, name='dec1_3') 77 | self.deblock1_2 = ResNetBlock(32, 5, name='dec1_2') 78 | self.deblock1_1 = ResNetBlock(32, 5, name='dec1_1') 79 | self.deconv0_4 = tf.keras.layers.Conv2DTranspose( 80 | filters=channels, kernel_size=5, padding='SAME', 81 | activation=None, name='dec1_0') 82 | 83 | def call(self, inputs, reuse=False): 84 | # Apply encoder decoder 85 | n, h, w, c = inputs.get_shape().as_list() 86 | n_outputs = [] 87 | input_pred = inputs 88 | with tf.compat.v1.variable_scope('', reuse=reuse): 89 | for i in range(self.n_levels): 90 | scale = self.scale ** (self.n_levels - i - 1) 91 | hi = int(round(h * scale)) 92 | wi = int(round(w * scale)) 93 | input_init = tf.image.resize(inputs, [hi, wi], method='bilinear') 94 | input_pred = tf.stop_gradient(tf.image.resize(input_pred, [hi, wi], 
method='bilinear')) 95 | input_all = tf.concat([input_init, input_pred], axis=3, name='inp') 96 | 97 | # Encoder 98 | conv1_1 = self.conv1_1(input_all) 99 | conv1_2 = self.block1_2(conv1_1) 100 | conv1_3 = self.block1_3(conv1_2) 101 | conv1_4 = self.block1_4(conv1_3) 102 | conv2_1 = self.conv2_1(conv1_4) 103 | conv2_2 = self.block2_2(conv2_1) 104 | conv2_3 = self.block2_3(conv2_2) 105 | conv2_4 = self.block2_4(conv2_3) 106 | conv3_1 = self.conv3_1(conv2_4) 107 | conv3_2 = self.block3_2(conv3_1) 108 | conv3_3 = self.block3_3(conv3_2) 109 | encoded = self.block3_4(conv3_3) 110 | 111 | # Decoder 112 | deconv3_3 = self.deblock3_3(encoded) 113 | deconv3_2 = self.deblock3_2(deconv3_3) 114 | deconv3_1 = self.deblock3_1(deconv3_2) 115 | deconv2_4 = self.deconv2_4(deconv3_1) 116 | cat2 = deconv2_4 + conv2_4 # Skip connection 117 | deconv2_3 = self.deblock2_3(cat2) 118 | deconv2_2 = self.deblock2_2(deconv2_3) 119 | deconv2_1 = self.deblock2_1(deconv2_2) 120 | deconv1_4 = self.deconv1_4(deconv2_1) 121 | cat1 = deconv1_4 + conv1_4 # Skip connection 122 | deconv1_3 = self.deblock1_3(cat1) 123 | deconv1_2 = self.deblock1_2(deconv1_3) 124 | deconv1_1 = self.deblock1_1(deconv1_2) 125 | input_pred = self.deconv0_4(deconv1_1) 126 | 127 | if i >= 0: 128 | n_outputs.append(input_pred) 129 | if i == 0: 130 | tf.compat.v1.get_variable_scope().reuse_variables() 131 | return n_outputs 132 | 133 | def mobilenet_transfer(class_number): 134 | """Return a classification model with a mobilenet backbone pretrained on ImageNet 135 | 136 | # Arguments: 137 | class_number: Number of classes / labels to detect 138 | """ 139 | # Import the mobilenet model and discards the last 1000 neuron layer. 140 | base_model = tf.keras.applications.MobileNet(input_shape=(224,224,3), weights='imagenet',include_top=False, pooling='avg') 141 | 142 | x = base_model.output 143 | x = tf.keras.layers.Dense(1024,activation='relu')(x) 144 | x = tf.keras.layers.Dense(1024,activation='relu')(x) 145 | x = tf.keras.layers.Dense(512,activation='relu')(x) 146 | # Final layer with softmax activation 147 | preds = tf.keras.layers.Dense(class_number,activation='softmax')(x) 148 | # Build the model 149 | model = tf.keras.models.Model(inputs=base_model.input,outputs=preds) 150 | 151 | # Freeze base_model 152 | # for layer in base_model.layers: # <=> to [:86] 153 | # layer.trainable = False 154 | # Freeze the first 60 layers and fine-tune the rest 155 | for layer in model.layers[:60]: 156 | layer.trainable=False 157 | for layer in model.layers[60:]: 158 | layer.trainable=True 159 | 160 | return model 161 | 162 | def baseline_model(input_shape, output_param_number=1, hidden_layer_size=16): 163 | """Return a fully connected model with 1 hidden layer""" 164 | if hidden_layer_size < output_param_number: 165 | raise ValueError("Neurons in the hidden layer (={}) \ 166 | should be > output param number (={})".format( 167 | hidden_layer_size, output_param_number)) 168 | model = tf.keras.Sequential() 169 | model.add(tf.keras.layers.Flatten(input_shape=input_shape)) 170 | # Regular densely connected NN layer 171 | model.add(tf.keras.layers.Dense(hidden_layer_size, activation=tf.nn.relu)) 172 | model.add(tf.keras.layers.Dense(output_param_number, activation=None)) # linear activation 173 | return model -------------------------------------------------------------------------------- /Models/common/util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Foundry. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | ############################################################################## 15 | 16 | import sys 17 | import os 18 | import re 19 | 20 | import numpy as np 21 | import OpenEXR, Imath 22 | import cv2 23 | 24 | import tensorflow as tf 25 | 26 | def print_(str, colour='', bold=False): 27 | if colour == 'w': # yellow warning 28 | sys.stdout.write('\033[93m') 29 | elif colour == "e": # red error 30 | sys.stdout.write('\033[91m') 31 | elif colour == "m": # magenta info 32 | sys.stdout.write('\033[95m') 33 | if bold: 34 | sys.stdout.write('\033[1m') 35 | sys.stdout.write(str) 36 | sys.stdout.write('\033[0m') 37 | sys.stdout.flush() 38 | 39 | ## GET DATA ## 40 | 41 | def get_filepaths_from_dir(dir_path): 42 | """Recursively walk through the given directory and return a list of file paths""" 43 | data_list = [] 44 | for (root, directories, filenames) in os.walk(dir_path): 45 | directories.sort() 46 | filenames.sort() 47 | for filename in filenames: 48 | data_list += [os.path.join(root,filename)] 49 | return data_list 50 | 51 | def get_labels_from_dir(dir_path): 52 | """Return classification class labels (= first subdirectories names)""" 53 | labels_list = [] 54 | for (root, directories, filenames) in os.walk(dir_path): 55 | directories.sort() 56 | labels_list += directories 57 | # Break to only keep the top directory 58 | break 59 | # Remove '.' 
in folder names for label retrieval in model.py
60 |     labels_list = [''.join(label.split('.')) for label in labels_list]
61 |     return labels_list
62 | 
63 | def atoi(text):
64 |     return int(text) if text.isdigit() else text
65 | 
66 | def natural_keys(text):
67 |     """Use mylist.sort(key=natural_keys) to sort mylist in human order"""
68 |     return [atoi(c) for c in re.split(r'(\d+)', text)]
69 | 
70 | def get_ckpt_list(ckpt_dir):
71 |     filenames_list = []
72 |     for (root, directories, filenames) in os.walk(ckpt_dir):
73 |         filenames_list += filenames
74 |         # Break to only keep the top directory
75 |         break
76 |     ckpt_list = []
77 |     for filename in filenames_list:
78 |         split = filename.split('.')
79 |         if len(split) > 1 and split[-1] == 'index':
80 |             # remove .index to get the checkpoint name
81 |             ckpt_list += [filename[:-6]]
82 |     ckpt_list.sort(key=natural_keys)
83 |     return ckpt_list
84 | 
85 | def get_saved_model_list(ckpt_dir):
86 |     """Return a list of HDF5 models found in ckpt_dir"""
87 |     filenames_list = []
88 |     for (root, directories, filenames) in os.walk(ckpt_dir):
89 |         filenames_list += filenames
90 |         # Break to only keep the top directory
91 |         break
92 |     ckpt_list = []
93 |     for filename in filenames_list:
94 |         if filename.endswith(('.h5', '.hdf5')):
95 |             ckpt_list += [filename]
96 |     ckpt_list.sort(key=natural_keys)
97 |     return ckpt_list
98 | 
99 | ## PROCESS DATA ##
100 | 
101 | def im2uint8(x):
102 |     if x.__class__ == tf.Tensor:
103 |         return tf.cast(tf.clip_by_value(x, 0.0, 1.0) * 255.0, tf.uint8)
104 |     else:
105 |         t = np.clip(x, 0.0, 1.0) * 255.0
106 |         return t.astype(np.uint8)
107 | 
108 | def srgb_to_linear(x):
109 |     """Transform the image from sRGB to linear"""
110 |     a = 0.055
111 |     x = np.clip(x, 0, 1)
112 |     mask = x < 0.04045
113 |     x[mask] /= 12.92
114 |     x[mask!=True] = np.exp(2.4 * (np.log(x[mask!=True] + a) - np.log(1 + a)))
115 |     return x
116 | 
117 | def linear_to_srgb(x):
118 |     """Transform the image from linear to sRGB"""
119 |     a = 0.055
120 |     x = np.clip(x, 0, 1)
121 |     mask = x <= 0.0031308
122 |     x[mask] *= 12.92
123 |     x[mask!=True] = np.exp(np.log(1 + a) + (1/2.4) * np.log(x[mask!=True])) - a
124 |     return x
125 | 
126 | ## EXR DATA UTILS ##
127 | 
128 | """
129 | EXR utility functions have to be wrapped in a TensorFlow graph by using
130 | tf.numpy_function(). This function requires a specific fixed return type,
131 | which is why all EXR reading functions are of return type float32.
132 | """
133 | # Imath.PixelType can be UINT uint32, HALF float16 or FLOAT float32
134 | EXR_PIX_TYPE = Imath.PixelType(Imath.PixelType.FLOAT)
135 | EXR_NP_TYPE = np.float32
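# ---------------------------------------------------------------------------
# Editorial sketch (not part of the original file): example of how a reader
# such as read_resize_exr() below is typically pulled into a TensorFlow graph
# with tf.numpy_function(), declaring the fixed float32 return type described
# above. The function name and the patch size default are illustrative only.
def _load_exr_as_tensor_sketch(path, patch_size=256, channels=3):
    img = tf.numpy_function(read_resize_exr, [path, patch_size], [tf.float32])
    # tf.numpy_function drops static shape information, so restore it explicitly
    return tf.reshape(img, [patch_size, patch_size, channels])
# ---------------------------------------------------------------------------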
136 | 
137 | def is_exr(filename):
138 |     file_extension = os.path.splitext(filename)[1][1:]
139 |     if file_extension in ['exr', 'EXR']:
140 |         return True
141 |     elif file_extension in ['jpg', 'jpeg', 'png', 'bmp', 'JPG', 'JPEG', 'PNG', 'BMP']:
142 |         return False
143 |     else:
144 |         raise TypeError("Unhandled file extension '{}'. Should be one of "
145 |             "['jpg', 'jpeg', 'png', 'bmp', 'exr']".format(file_extension))
146 | 
147 | def check_exr(exr_files, channel_names=['R', 'G', 'B']):
148 |     """Check that exr_files (a list of EXR file(s)) have the requested channels
149 |     and have the same data window size. Return image width and height.
150 |     """
151 |     if not list(channel_names):
152 |         raise ValueError("channel_names is empty")
153 |     if isinstance(exr_files, OpenEXR.InputFile): # single exr file
154 |         exr_files = [exr_files]
155 |     elif not isinstance(exr_files, list):
156 |         raise TypeError("type(exr_files): {}, should be OpenEXR.InputFile or list".format(type(exr_files)))
157 |     # Check data window size
158 |     data_windows = [str(exr.header()['dataWindow']) for exr in exr_files]
159 |     if any(dw != data_windows[0] for dw in data_windows):
160 |         raise ValueError("input and groundtruth .exr images have different sizes")
161 |     # Check that the channels to read are present in the given exr file(s)
162 |     channels_headers = [exr.header()['channels'] for exr in exr_files]
163 |     for channels in channels_headers:
164 |         if any(c not in list(channels.keys()) for c in channel_names):
165 |             raise ValueError("Trying to read channels {} of an exr image with channels {}"
166 |                 .format(channel_names, list(channels.keys())))
167 |     # Compute the size
168 |     dw = exr_files[0].header()['dataWindow']
169 |     width = dw.max.x - dw.min.x + 1
170 |     height = dw.max.y - dw.min.y + 1
171 |     return width, height
172 | 
173 | def read_exr(exr_path, channel_names=['R', 'G', 'B']):
174 |     """Read requested channels of an exr and return them in a numpy array
175 |     """
176 |     # Open and check the input file
177 |     exr_file = OpenEXR.InputFile(exr_path)
178 |     width, height = check_exr(exr_file, channel_names)
179 |     # Copy channels from an exr file into a numpy array
180 |     exr_numpy = [np.frombuffer(exr_file.channel(c, EXR_PIX_TYPE), dtype=EXR_NP_TYPE)
181 |         .reshape(height, width) for c in channel_names]
182 |     exr_numpy = np.stack(exr_numpy, axis=-1)
183 |     return exr_numpy
184 | 
185 | def read_resize_exr(exr_path, patch_size, channel_names=['R', 'G', 'B']):
186 |     """Read requested channels of an exr as numpy array
187 |     and return them resized to (patch_size, patch_size)
188 |     """
189 |     exr = read_exr(exr_path, channel_names)
190 |     exr_resize = cv2.resize(exr, dsize=(patch_size, patch_size))
191 |     return exr_resize
192 | 
193 | def read_crop_exr(exr_file, size, crop_w, crop_h, crop_size=256, channel_names=['R', 'G', 'B']):
194 |     """Read requested channels of an exr file, crop it and return it as numpy array
195 | 
196 |     The cropping box has a size of crop_size and its bottom left point is (crop_h, crop_w)
197 |     """
198 |     # Read only the crop scanlines, not the full EXR image
199 |     cnames = ''.join(channel_names)
200 |     channels = exr_file.channels(cnames=cnames, pixel_type=EXR_PIX_TYPE,
201 |         scanLine1=crop_h, scanLine2=crop_h + crop_size - 1)
202 |     exr_crop = np.zeros([crop_size, crop_size, len(channel_names)], dtype=EXR_NP_TYPE)
203 |     for idx, c in enumerate(channel_names):
204 |         exr_crop[:,:,idx] = (np.frombuffer(channels[idx], dtype=EXR_NP_TYPE)
205 |             .reshape(crop_size, size[0])[:, crop_w:crop_w+crop_size])
206 |     return exr_crop
207 | 
208 | def read_crop_exr_pair(exr_path_in, exr_path_gt, crop_size=256, channel_names=['R', 'G', 'B']):
209 |     """Read requested channels of input and groundtruth .exr image paths
210 |     and return the same random crop of both
211 |     """
212 |     # Open the input file
213 |     exr_file_in = OpenEXR.InputFile(exr_path_in)
214 |     exr_file_gt = OpenEXR.InputFile(exr_path_gt)
215 |     width, height = check_exr([exr_file_in, exr_file_gt], channel_names)
216 |     # Check exr image width and height >= crop_size
217 |     if height < crop_size or width < crop_size:
218 |         raise ValueError("Input image size should be greater than or equal to crop_size: {} < ({},{})"
219 |             .format((width, height),
crop_size, crop_size)) 220 | # Get random crop value 221 | randw = np.random.randint(0, width-crop_size) if width-crop_size > 0 else 0 222 | randh = np.random.randint(0, height-crop_size) if height-crop_size > 0 else 0 223 | # Get the crop of input and groundtruth .exr images 224 | exr_crop_in = read_crop_exr(exr_file_in, (width, height), randw, randh, crop_size, channel_names) 225 | exr_crop_gt = read_crop_exr(exr_file_gt, (width, height), randw, randh, crop_size, channel_names) 226 | return [exr_crop_in, exr_crop_gt] -------------------------------------------------------------------------------- /Models/mrcnn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TheFoundryVisionmongers/nuke-ML-server/5dd4d04cd673c60de8093c600d6c54016fca92d6/Models/mrcnn/__init__.py -------------------------------------------------------------------------------- /Models/mrcnn/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Foundry. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | ############################################################################## 15 | 16 | # The inference method is based on: 17 | # -------------------------------------------------------- 18 | # Facebook infer_simple.py file: 19 | # https://github.com/facebookresearch/Detectron/blob/master/tools/infer_simple.py 20 | # Copyright (c) Facebook, Inc. and its affiliates. 21 | # Licensed under the Apache License, Version 2.0 22 | # -------------------------------------------------------- 23 | 24 | import copy 25 | import numpy as np 26 | 27 | from caffe2.python import workspace 28 | # import libcaffe2_detectron_ops_gpu.so 29 | import detectron.utils.c2 as c2_utils 30 | c2_utils.import_detectron_ops() 31 | 32 | from detectron.core.config import assert_and_infer_cfg 33 | from detectron.core.config import cfg 34 | from detectron.core.config import merge_cfg_from_file, merge_cfg_from_cfg 35 | from detectron.utils.collections import AttrDict 36 | from detectron.utils.io import cache_url 37 | from detectron.utils.logging import setup_logging 38 | from detectron.utils.timer import Timer 39 | import detectron.core.test_engine as infer_engine 40 | import detectron.datasets.dummy_datasets as dummy_datasets 41 | import detectron.utils.c2 as c2_utils 42 | 43 | from .vis import vis_one_image_binary, vis_one_image_opencv 44 | from .utils import dict_equal 45 | from ..common.util import linear_to_srgb, srgb_to_linear 46 | from ..baseModel import BaseModel 47 | 48 | class Model(BaseModel): 49 | def __init__(self): 50 | super(Model, self).__init__() 51 | self.name = 'Mask RCNN' 52 | 53 | # Configuration and weights options 54 | # By default, we use ResNet50 backbone architecture, you can switch to 55 | # ResNet101 to increase quality if your GPU memory is higher than 8GB. 
56 | # To do so, you will need to download both .yaml and .pkl ResNet101 files 57 | # then replace the below 'cfg_file' with the following: 58 | # self.cfg_file = 'models/mrcnn/e2e_mask_rcnn_X-101-64x4d-FPN_2x.yaml' 59 | self.cfg_file = 'models/mrcnn/e2e_mask_rcnn_R-50-FPN_2x.yaml' 60 | self.weights = 'models/mrcnn/model_final.pkl' 61 | self.default_cfg = copy.deepcopy(AttrDict(cfg)) # cfg from detectron.core.config 62 | self.mrcnn_cfg = AttrDict() 63 | self.dummy_coco_dataset = dummy_datasets.get_coco_dataset() 64 | 65 | # Inference options 66 | self.show_box = True 67 | self.show_class = True 68 | self.thresh = 0.7 69 | self.alpha = 0.4 70 | self.show_border = True 71 | self.border_thick = 1 72 | self.bbox_thick = 1 73 | self.font_scale = 0.35 74 | self.binary_masks = False 75 | 76 | # Define exposed options 77 | self.options = ( 78 | 'show_box', 'show_class', 'thresh', 'alpha', 'show_border', 79 | 'border_thick', 'bbox_thick', 'font_scale', 'binary_masks', 80 | ) 81 | # Define inputs/outputs 82 | self.inputs = {'input': 3} 83 | self.outputs = {'output': 3} 84 | 85 | def inference(self, image_list): 86 | """Do an inference on the model with a set of inputs. 87 | 88 | # Arguments: 89 | image_list: The input image list 90 | 91 | Return the result of the inference. 92 | """ 93 | image = image_list[0] 94 | image = linear_to_srgb(image)*255. 95 | imcpy = image.copy() 96 | 97 | # Initialize the model out of the configuration and weights files 98 | if not hasattr(self, 'model'): 99 | workspace.ResetWorkspace() 100 | # Reset to default config 101 | merge_cfg_from_cfg(self.default_cfg) 102 | # Load mask rcnn configuration file 103 | merge_cfg_from_file(self.cfg_file) 104 | assert_and_infer_cfg(cache_urls=False, make_immutable=False) 105 | self.model = infer_engine.initialize_model_from_cfg(self.weights) 106 | # Save mask rcnn full configuration file 107 | self.mrcnn_cfg = copy.deepcopy(AttrDict(cfg)) # cfg from detectron.core.config 108 | else: 109 | # There is a global config file for all detectron models (Densepose, Mask RCNN..) 110 | # Check if current global config file is correct for mask rcnn 111 | if not dict_equal(self.mrcnn_cfg, cfg): 112 | # Free memory of previous workspace 113 | workspace.ResetWorkspace() 114 | # Load mask rcnn configuration file 115 | merge_cfg_from_cfg(self.mrcnn_cfg) 116 | assert_and_infer_cfg(cache_urls=False, make_immutable=False) 117 | self.model = infer_engine.initialize_model_from_cfg(self.weights) 118 | 119 | with c2_utils.NamedCudaScope(0): 120 | # If using densepose/detectron GitHub, im_detect_all also returns cls_bodys 121 | # Only takes the first 3 elements of the list for compatibility 122 | cls_boxes, cls_segms, cls_keyps = infer_engine.im_detect_all( 123 | self.model, image[:, :, ::-1], None 124 | )[:3] 125 | 126 | if self.binary_masks: 127 | res = vis_one_image_binary( 128 | imcpy, 129 | cls_boxes, 130 | cls_segms, 131 | thresh=self.thresh 132 | ) 133 | else: 134 | res = vis_one_image_opencv( 135 | imcpy, 136 | cls_boxes, 137 | cls_segms, 138 | cls_keyps, 139 | thresh=self.thresh, 140 | show_box=self.show_box, 141 | show_class=self.show_class, 142 | dataset=self.dummy_coco_dataset, 143 | alpha=self.alpha, 144 | show_border=self.show_border, 145 | border_thick=self.border_thick, 146 | bbox_thick=self.bbox_thick, 147 | font_scale=self.font_scale 148 | ) 149 | 150 | res = srgb_to_linear(res.astype(np.float32) / 255.) 
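        # Editorial note (not in the original source): the frame received from Nuke
        # is linear; it was converted to 8-bit sRGB for Detectron at the start of
        # inference(), so the visualised result is scaled back to [0, 1] and
        # returned to linear before being sent back to the client.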
151 | 
152 |         return [res]
--------------------------------------------------------------------------------
/Models/mrcnn/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019 Foundry.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | ##############################################################################
15 | 
16 | """Utility functions for mask rcnn model"""
17 | 
18 | from detectron.utils.collections import AttrDict
19 | import numpy as np
20 | 
21 | def dict_equal(d1, d2):
22 |     """Recursively check whether two dictionaries are equal in both keys and values.
23 | 
24 |     # Arguments:
25 |         d1, d2: The two dictionaries to compare
26 | 
27 |     # Return:
28 |         False if any key or value differs, True otherwise
29 |     """
30 |     for k in d1:
31 |         if k not in d2:
32 |             return False
33 |     for k in d2:
34 |         if type(d2[k]) not in (dict, list, AttrDict, np.ndarray):
35 |             if d2[k] != d1[k]:
36 |                 return False
37 |         elif isinstance(d2[k], np.ndarray):
38 |             if np.any(d2[k] != d1[k]):
39 |                 return False
40 |         else: # d2[k] dictionary or list
41 |             if type(d1[k]) != type(d2[k]):
42 |                 return False
43 |             else:
44 |                 if type(d2[k]) == AttrDict or type(d2[k]) == dict:
45 |                     if(not dict_equal(d1[k], d2[k])):
46 |                         return False
47 |     return True
--------------------------------------------------------------------------------
/Models/mrcnn/vis.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | ############################################################################## 15 | 16 | """Detection output visualization module.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | from __future__ import unicode_literals 22 | 23 | import cv2 24 | import numpy as np 25 | import os 26 | 27 | import pycocotools.mask as mask_util 28 | 29 | from detectron.utils.colormap import colormap 30 | import detectron.utils.env as envu 31 | import detectron.utils.keypoints as keypoint_utils 32 | # Matplotlib requires certain adjustments in some environments 33 | # Must happen before importing matplotlib 34 | envu.set_up_matplotlib() 35 | import matplotlib.pyplot as plt 36 | from matplotlib.patches import Polygon 37 | 38 | plt.rcParams['pdf.fonttype'] = 42 # For editing in Adobe Illustrator 39 | 40 | 41 | _GRAY = (218, 227, 218) 42 | _GREEN = (18, 127, 15) 43 | _WHITE = (255, 255, 255) 44 | 45 | 46 | def kp_connections(keypoints): 47 | kp_lines = [ 48 | [keypoints.index('left_eye'), keypoints.index('right_eye')], 49 | [keypoints.index('left_eye'), keypoints.index('nose')], 50 | [keypoints.index('right_eye'), keypoints.index('nose')], 51 | [keypoints.index('right_eye'), keypoints.index('right_ear')], 52 | [keypoints.index('left_eye'), keypoints.index('left_ear')], 53 | [keypoints.index('right_shoulder'), keypoints.index('right_elbow')], 54 | [keypoints.index('right_elbow'), keypoints.index('right_wrist')], 55 | [keypoints.index('left_shoulder'), keypoints.index('left_elbow')], 56 | [keypoints.index('left_elbow'), keypoints.index('left_wrist')], 57 | [keypoints.index('right_hip'), keypoints.index('right_knee')], 58 | [keypoints.index('right_knee'), keypoints.index('right_ankle')], 59 | [keypoints.index('left_hip'), keypoints.index('left_knee')], 60 | [keypoints.index('left_knee'), keypoints.index('left_ankle')], 61 | [keypoints.index('right_shoulder'), keypoints.index('left_shoulder')], 62 | [keypoints.index('right_hip'), keypoints.index('left_hip')], 63 | ] 64 | return kp_lines 65 | 66 | 67 | def convert_from_cls_format(cls_boxes, cls_segms, cls_keyps): 68 | """Convert from the class boxes/segms/keyps format generated by the testing 69 | code. 
70 | """ 71 | box_list = [b for b in cls_boxes if len(b) > 0] 72 | if len(box_list) > 0: 73 | boxes = np.concatenate(box_list) 74 | else: 75 | boxes = None 76 | if cls_segms is not None: 77 | segms = [s for slist in cls_segms for s in slist] 78 | else: 79 | segms = None 80 | if cls_keyps is not None: 81 | keyps = [k for klist in cls_keyps for k in klist] 82 | else: 83 | keyps = None 84 | classes = [] 85 | for j in range(len(cls_boxes)): 86 | classes += [j] * len(cls_boxes[j]) 87 | return boxes, segms, keyps, classes 88 | 89 | 90 | def get_class_string(class_index, score, dataset): 91 | class_text = dataset.classes[class_index] if dataset is not None else \ 92 | 'id{:d}'.format(class_index) 93 | return class_text + ' {:0.2f}'.format(score).lstrip('0') 94 | 95 | 96 | def vis_mask(img, mask, col, alpha=0.4, show_border=True, border_thick=1): 97 | """Visualizes a single binary mask.""" 98 | 99 | img = img.astype(np.float32) 100 | idx = np.nonzero(mask) 101 | 102 | img[idx[0], idx[1], :] *= 1.0 - alpha 103 | img[idx[0], idx[1], :] += alpha * col 104 | 105 | if show_border: 106 | # cv2.findContours gives (image, contours, hierarchy) back in opencv 3.x 107 | # but gives back (contours, hierachy) in opencv 2.x and 4.x 108 | contours, _ = cv2.findContours( 109 | mask.copy(), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)[-2:] 110 | cv2.drawContours(img, contours, -1, _WHITE, border_thick, cv2.LINE_AA) 111 | 112 | return img.astype(np.uint8) 113 | 114 | 115 | def vis_class(img, pos, class_str, font_scale=0.35): 116 | """Visualizes the class.""" 117 | img = img.astype(np.uint8) 118 | x0, y0 = int(pos[0]), int(pos[1]) 119 | # Compute text size. 120 | txt = class_str 121 | font = cv2.FONT_HERSHEY_SIMPLEX 122 | ((txt_w, txt_h), _) = cv2.getTextSize(txt, font, font_scale, 1) 123 | # Place text background. 124 | back_tl = x0, y0 - int(1.3 * txt_h) 125 | back_br = x0 + txt_w, y0 126 | cv2.rectangle(img, back_tl, back_br, _GREEN, -1) 127 | # Show text. 128 | txt_tl = x0, y0 - int(0.3 * txt_h) 129 | cv2.putText(img, txt, txt_tl, font, font_scale, _GRAY, lineType=cv2.LINE_AA) 130 | return img 131 | 132 | 133 | def vis_bbox(img, bbox, thick=1): 134 | """Visualizes a bounding box.""" 135 | img = img.astype(np.uint8) 136 | (x0, y0, w, h) = bbox 137 | x1, y1 = int(x0 + w), int(y0 + h) 138 | x0, y0 = int(x0), int(y0) 139 | cv2.rectangle(img, (x0, y0), (x1, y1), _GREEN, thickness=thick) 140 | return img 141 | 142 | 143 | def vis_keypoints(img, kps, kp_thresh=2, alpha=0.7): 144 | """Visualizes keypoints (adapted from vis_one_image). 145 | kps has shape (4, #keypoints) where 4 rows are (x, y, logit, prob). 146 | """ 147 | dataset_keypoints, _ = keypoint_utils.get_keypoints() 148 | kp_lines = kp_connections(dataset_keypoints) 149 | 150 | # Convert from plt 0-1 RGBA colors to 0-255 BGR colors for opencv. 151 | cmap = plt.get_cmap('rainbow') 152 | colors = [cmap(i) for i in np.linspace(0, 1, len(kp_lines) + 2)] 153 | colors = [(c[2] * 255, c[1] * 255, c[0] * 255) for c in colors] 154 | 155 | # Perform the drawing on a copy of the image, to allow for blending. 156 | kp_mask = np.copy(img) 157 | 158 | # Draw mid shoulder / mid hip first for better visualization. 
159 | mid_shoulder = ( 160 | kps[:2, dataset_keypoints.index('right_shoulder')] + 161 | kps[:2, dataset_keypoints.index('left_shoulder')]) / 2.0 162 | sc_mid_shoulder = np.minimum( 163 | kps[2, dataset_keypoints.index('right_shoulder')], 164 | kps[2, dataset_keypoints.index('left_shoulder')]) 165 | mid_hip = ( 166 | kps[:2, dataset_keypoints.index('right_hip')] + 167 | kps[:2, dataset_keypoints.index('left_hip')]) / 2.0 168 | sc_mid_hip = np.minimum( 169 | kps[2, dataset_keypoints.index('right_hip')], 170 | kps[2, dataset_keypoints.index('left_hip')]) 171 | nose_idx = dataset_keypoints.index('nose') 172 | if sc_mid_shoulder > kp_thresh and kps[2, nose_idx] > kp_thresh: 173 | cv2.line( 174 | kp_mask, tuple(mid_shoulder), tuple(kps[:2, nose_idx]), 175 | color=colors[len(kp_lines)], thickness=2, lineType=cv2.LINE_AA) 176 | if sc_mid_shoulder > kp_thresh and sc_mid_hip > kp_thresh: 177 | cv2.line( 178 | kp_mask, tuple(mid_shoulder), tuple(mid_hip), 179 | color=colors[len(kp_lines) + 1], thickness=2, lineType=cv2.LINE_AA) 180 | 181 | # Draw the keypoints. 182 | for l in range(len(kp_lines)): 183 | i1 = kp_lines[l][0] 184 | i2 = kp_lines[l][1] 185 | p1 = kps[0, i1], kps[1, i1] 186 | p2 = kps[0, i2], kps[1, i2] 187 | if kps[2, i1] > kp_thresh and kps[2, i2] > kp_thresh: 188 | cv2.line( 189 | kp_mask, p1, p2, 190 | color=colors[l], thickness=2, lineType=cv2.LINE_AA) 191 | if kps[2, i1] > kp_thresh: 192 | cv2.circle( 193 | kp_mask, p1, 194 | radius=3, color=colors[l], thickness=-1, lineType=cv2.LINE_AA) 195 | if kps[2, i2] > kp_thresh: 196 | cv2.circle( 197 | kp_mask, p2, 198 | radius=3, color=colors[l], thickness=-1, lineType=cv2.LINE_AA) 199 | 200 | # Blend the keypoints. 201 | return cv2.addWeighted(img, 1.0 - alpha, kp_mask, alpha, 0) 202 | 203 | 204 | def vis_one_image_opencv( 205 | im, boxes, segms=None, keypoints=None, thresh=0.9, kp_thresh=2, 206 | show_box=False, dataset=None, show_class=False, 207 | alpha=0.4, show_border=True, border_thick=1, bbox_thick=1, font_scale=0.35): 208 | """Constructs a numpy array with the detections visualized.""" 209 | 210 | if isinstance(boxes, list): 211 | boxes, segms, keypoints, classes = convert_from_cls_format( 212 | boxes, segms, keypoints) 213 | 214 | if boxes is None or boxes.shape[0] == 0 or max(boxes[:, 4]) < thresh: 215 | return im 216 | 217 | if segms is not None and len(segms) > 0: 218 | masks = mask_util.decode(segms) 219 | color_list = colormap() 220 | mask_color_id = 0 221 | 222 | # Display in largest to smallest order to reduce occlusion 223 | areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) 224 | sorted_inds = np.argsort(-areas) 225 | 226 | for i in sorted_inds: 227 | bbox = boxes[i, :4] 228 | score = boxes[i, -1] 229 | if score < thresh: 230 | continue 231 | 232 | # show box (off by default) 233 | if show_box: 234 | im = vis_bbox( 235 | im, (bbox[0], bbox[1], bbox[2] - bbox[0], bbox[3] - bbox[1]), thick=bbox_thick) 236 | 237 | # show class (off by default) 238 | if show_class: 239 | class_str = get_class_string(classes[i], score, dataset) 240 | im = vis_class(im, (bbox[0], bbox[1] - 2), class_str, font_scale=font_scale) 241 | 242 | # show mask 243 | if segms is not None and len(segms) > i: 244 | color_mask = color_list[mask_color_id % len(color_list), 0:3] 245 | mask_color_id += 1 246 | im = vis_mask(im, masks[..., i], color_mask, alpha=alpha, 247 | show_border=show_border, border_thick=border_thick) 248 | 249 | # show keypoints 250 | if keypoints is not None and len(keypoints) > i: 251 | im = vis_keypoints(im, 
keypoints[i], kp_thresh) 252 | 253 | return im 254 | 255 | 256 | def vis_one_image_binary(im, boxes, segms, keypoints=None, thresh=0.9): 257 | im = np.zeros_like(im) 258 | if isinstance(boxes, list): 259 | boxes, segms, keypoints, classes = convert_from_cls_format( 260 | boxes, segms, keypoints) 261 | 262 | if boxes is None or boxes.shape[0] == 0 or max(boxes[:, 4]) < thresh: 263 | return im 264 | 265 | if segms is not None and len(segms) > 0: 266 | masks = mask_util.decode(segms) 267 | 268 | # Display in largest to smallest order to reduce occlusion 269 | areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) 270 | sorted_inds = np.argsort(-areas) 271 | 272 | for i in sorted_inds: 273 | bbox = boxes[i, :4] 274 | score = boxes[i, -1] 275 | if score < thresh: 276 | continue 277 | 278 | color_mask = np.array([1., 1., 1.]) * 255 279 | im = vis_mask(im, masks[..., i], color_mask, alpha=1., 280 | show_border=False) 281 | 282 | return im 283 | 284 | 285 | def vis_one_image( 286 | im, im_name, output_dir, boxes, segms=None, keypoints=None, thresh=0.9, 287 | kp_thresh=2, dpi=200, box_alpha=0.0, dataset=None, show_class=False, 288 | ext='pdf', out_when_no_box=False): 289 | """Visual debugging of detections.""" 290 | if not os.path.exists(output_dir): 291 | os.makedirs(output_dir) 292 | 293 | if isinstance(boxes, list): 294 | boxes, segms, keypoints, classes = convert_from_cls_format( 295 | boxes, segms, keypoints) 296 | 297 | if (boxes is None or boxes.shape[0] == 0 or max(boxes[:, 4]) < thresh) and not out_when_no_box: 298 | return 299 | 300 | dataset_keypoints, _ = keypoint_utils.get_keypoints() 301 | 302 | if segms is not None and len(segms) > 0: 303 | masks = mask_util.decode(segms) 304 | 305 | color_list = colormap(rgb=True) / 255 306 | 307 | kp_lines = kp_connections(dataset_keypoints) 308 | cmap = plt.get_cmap('rainbow') 309 | colors = [cmap(i) for i in np.linspace(0, 1, len(kp_lines) + 2)] 310 | 311 | fig = plt.figure(frameon=False) 312 | fig.set_size_inches(im.shape[1] / dpi, im.shape[0] / dpi) 313 | ax = plt.Axes(fig, [0., 0., 1., 1.]) 314 | ax.axis('off') 315 | fig.add_axes(ax) 316 | ax.imshow(im) 317 | 318 | if boxes is None: 319 | sorted_inds = [] # avoid crash when 'boxes' is None 320 | else: 321 | # Display in largest to smallest order to reduce occlusion 322 | areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) 323 | sorted_inds = np.argsort(-areas) 324 | 325 | mask_color_id = 0 326 | for i in sorted_inds: 327 | bbox = boxes[i, :4] 328 | score = boxes[i, -1] 329 | if score < thresh: 330 | continue 331 | 332 | # show box (off by default) 333 | ax.add_patch( 334 | plt.Rectangle((bbox[0], bbox[1]), 335 | bbox[2] - bbox[0], 336 | bbox[3] - bbox[1], 337 | fill=False, edgecolor='g', 338 | linewidth=0.5, alpha=box_alpha)) 339 | 340 | if show_class: 341 | ax.text( 342 | bbox[0], bbox[1] - 2, 343 | get_class_string(classes[i], score, dataset), 344 | fontsize=3, 345 | family='serif', 346 | bbox=dict( 347 | facecolor='g', alpha=0.4, pad=0, edgecolor='none'), 348 | color='white') 349 | 350 | # show mask 351 | if segms is not None and len(segms) > i: 352 | img = np.ones(im.shape) 353 | color_mask = color_list[mask_color_id % len(color_list), 0:3] 354 | mask_color_id += 1 355 | 356 | w_ratio = .4 357 | for c in range(3): 358 | color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio 359 | for c in range(3): 360 | img[:, :, c] = color_mask[c] 361 | e = masks[:, :, i] 362 | 363 | # cv2.findCountours gives (image, contours, hierarchy) back in opencv 3.x 364 | # but gives back 
(contours, hierachy) in opencv 2.x and 4.x 365 | contour, hier = cv2.findContours( 366 | e.copy(), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)[-2:] 367 | 368 | for c in contour: 369 | polygon = Polygon( 370 | c.reshape((-1, 2)), 371 | fill=True, facecolor=color_mask, 372 | edgecolor='w', linewidth=1.2, 373 | alpha=0.5) 374 | ax.add_patch(polygon) 375 | 376 | # show keypoints 377 | if keypoints is not None and len(keypoints) > i: 378 | kps = keypoints[i] 379 | plt.autoscale(False) 380 | for l in range(len(kp_lines)): 381 | i1 = kp_lines[l][0] 382 | i2 = kp_lines[l][1] 383 | if kps[2, i1] > kp_thresh and kps[2, i2] > kp_thresh: 384 | x = [kps[0, i1], kps[0, i2]] 385 | y = [kps[1, i1], kps[1, i2]] 386 | line = plt.plot(x, y) 387 | plt.setp(line, color=colors[l], linewidth=1.0, alpha=0.7) 388 | if kps[2, i1] > kp_thresh: 389 | plt.plot( 390 | kps[0, i1], kps[1, i1], '.', color=colors[l], 391 | markersize=3.0, alpha=0.7) 392 | 393 | if kps[2, i2] > kp_thresh: 394 | plt.plot( 395 | kps[0, i2], kps[1, i2], '.', color=colors[l], 396 | markersize=3.0, alpha=0.7) 397 | 398 | # add mid shoulder / mid hip for better visualization 399 | mid_shoulder = ( 400 | kps[:2, dataset_keypoints.index('right_shoulder')] + 401 | kps[:2, dataset_keypoints.index('left_shoulder')]) / 2.0 402 | sc_mid_shoulder = np.minimum( 403 | kps[2, dataset_keypoints.index('right_shoulder')], 404 | kps[2, dataset_keypoints.index('left_shoulder')]) 405 | mid_hip = ( 406 | kps[:2, dataset_keypoints.index('right_hip')] + 407 | kps[:2, dataset_keypoints.index('left_hip')]) / 2.0 408 | sc_mid_hip = np.minimum( 409 | kps[2, dataset_keypoints.index('right_hip')], 410 | kps[2, dataset_keypoints.index('left_hip')]) 411 | if (sc_mid_shoulder > kp_thresh and 412 | kps[2, dataset_keypoints.index('nose')] > kp_thresh): 413 | x = [mid_shoulder[0], kps[0, dataset_keypoints.index('nose')]] 414 | y = [mid_shoulder[1], kps[1, dataset_keypoints.index('nose')]] 415 | line = plt.plot(x, y) 416 | plt.setp( 417 | line, color=colors[len(kp_lines)], linewidth=1.0, alpha=0.7) 418 | if sc_mid_shoulder > kp_thresh and sc_mid_hip > kp_thresh: 419 | x = [mid_shoulder[0], mid_hip[0]] 420 | y = [mid_shoulder[1], mid_hip[1]] 421 | line = plt.plot(x, y) 422 | plt.setp( 423 | line, color=colors[len(kp_lines) + 1], linewidth=1.0, 424 | alpha=0.7) 425 | 426 | output_name = os.path.basename(im_name) + '.' + ext 427 | fig.savefig(os.path.join(output_dir, '{}'.format(output_name)), dpi=dpi) 428 | plt.close('all') 429 | -------------------------------------------------------------------------------- /Models/regressionTemplateTF/README.md: -------------------------------------------------------------------------------- 1 | # Regression Training Template 2 | 3 | The regressionTemplateTF is a training template written in TensorFlow. It aims at quickly enabling image-to-parameters training. For instance, finding the lens distortion parameters or gamma correction of an image. When trained, the model can be tested and used directly in Nuke through the nuke-ML-server. 
4 | 
5 | Compared to the image-to-image [Training Template](https://github.com/TheFoundryVisionmongers/nuke-ML-server/tree/master/Models/trainingTemplateTF) and the image-to-labels [Classification Template](https://github.com/TheFoundryVisionmongers/nuke-ML-server/tree/master/Models/classTemplateTF), this template will not work out-of-the-box: it requires you to implement some data preprocessing, as detailed in the [following section](https://github.com/TheFoundryVisionmongers/nuke-ML-server/tree/master/Models/regressionTemplateTF#data-preprocessing-implementation). This guide is based on the current template example: gamma-correction prediction.
6 | 
7 | For instructions on how to set up the training, on potential training issues or on TensorBoard visualisation, please refer to the [training template readme](https://github.com/TheFoundryVisionmongers/nuke-ML-server/blob/master/Models/trainingTemplateTF/README.md).
8 | 
9 | ## Data Preprocessing Implementation
10 | 
11 | To train the ML algorithm, you need to set up your dataset in `regressionTemplateTF/data/train/`. In addition to the training data, it is highly recommended to have validation data in `regressionTemplateTF/data/validation/`. This allows you to check that the model is not overfitting the training data. Please note that the validation and training datasets must not intersect.
12 | 
13 | Your training/validation dataset will be different depending on your task, i.e. depending on which parameter(s) you want to learn. In the current implementation, we do a regression on one parameter (gamma) with a specifically designed data preprocessing pipeline: the model training input is a stack of the histograms of the original and gamma-graded images.
14 | 
15 | Our preprocessing pipeline reads the original image (from `regressionTemplateTF/data/train/` or `regressionTemplateTF/data/validation/`), then applies gamma correction to it using a random gamma value. Both the original and gamma-graded images are converted to grayscale and resized, and their 100-bin histograms are computed. The model input (shape [2, 100]) is a stack of those two histograms (see the sketch at the end of this section).
16 | 
17 | The above data preprocessing is specific to the gamma-correction problem, which means that to predict other parameters (e.g. colour grading, lens distortion...), you will have to modify the data preprocessing functions found in `train_regression.py` and in `model.py` to match your task. The inference file `model.py` has to be changed as well, as the same data preprocessing used in training has to be applied before doing an inference in Nuke.
18 | 
19 | To summarise, for your specific regression task, you need to implement an appropriate data preprocessing and modify the code in both the training file `train_regression.py` and the inference file `model.py` accordingly.
20 | 
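As a reference, the snippet below is a minimal, illustrative sketch of that preprocessing (assuming TensorFlow 2.x eager execution and 8-bit PNG/JPEG inputs; the function name is ours, not part of the template). The actual implementation lives in `train_regression.py` (see `histogram()` and `gamma_correction()`):
```
import tensorflow as tf

def make_training_input(image_path, gamma, patch_size=50, nbins=100):
    # Read the image, normalise to [0, 1], convert to grayscale and resize
    img = tf.image.decode_png(tf.io.read_file(image_path), channels=3)
    img = tf.cast(img, tf.float32) / 255.0
    img = tf.image.resize(tf.image.rgb_to_grayscale(img), [patch_size, patch_size])
    # Grade the image with the groundtruth gamma value to be predicted
    img_graded = tf.math.pow(img, gamma)
    # Normalised nbins-bin histograms of the original and graded images
    def hist(x):
        h = tf.histogram_fixed_width(x, [0.0, 1.0], nbins=nbins)
        return tf.cast(h, tf.float32) / tf.cast(tf.size(x), tf.float32)
    # Model input of shape [2, nbins]; gamma is the regression target
    return tf.stack([hist(img), hist(img_graded)], axis=0), gamma
```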
21 | ## Training
22 | 
23 | Inside your docker container, go to the regressionTemplateTF folder:
24 | ```
25 | cd /workspace/ml-server/models/regressionTemplateTF
26 | ```
27 | Then directly train your model:
28 | ```
29 | python train_regression.py
30 | ```
31 | You can also specify the batch size, learning rate and number of epochs:
32 | ```
33 | python train_regression.py --bch=16 --lr=1e-3 --ep=1000
34 | ```
35 | It is now possible to have deterministic training: you can reproduce a training run (and obtain the same model weights) by setting the seed to an integer of your choice (here 77):
36 | ```
37 | python train_regression.py --seed=77
38 | ```
39 | We enable deterministic training in part by applying a GPU patch to stock TensorFlow; this patch slows down training significantly. By adding the `--no-gpu-patch` flag to the previous command, you get slightly less deterministic training but keep the original training time.
40 | 
41 | Note: the current gamma-correction task creates gamma-graded images on the fly using random gamma values, so for the training to succeed it is recommended to have more than 500 training images.
42 | 
43 | 
44 | 
45 | 
46 | 
--------------------------------------------------------------------------------
/Models/regressionTemplateTF/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TheFoundryVisionmongers/nuke-ML-server/5dd4d04cd673c60de8093c600d6c54016fca92d6/Models/regressionTemplateTF/__init__.py
--------------------------------------------------------------------------------
/Models/regressionTemplateTF/model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Foundry.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | ############################################################################## 15 | 16 | from __future__ import print_function 17 | 18 | import sys 19 | import os 20 | import time 21 | 22 | import scipy.misc 23 | import numpy as np 24 | import cv2 25 | 26 | import tensorflow as tf 27 | tf.compat.v1.disable_eager_execution() # For TF 2.x compatibility 28 | 29 | from models.baseModel import BaseModel 30 | from models.common.model_builder import baseline_model 31 | from models.common.util import print_, get_ckpt_list, linear_to_srgb, srgb_to_linear 32 | 33 | import message_pb2 34 | 35 | class Model(BaseModel): 36 | """Load your trained model and do inference in Nuke""" 37 | 38 | def __init__(self): 39 | super(Model, self).__init__() 40 | self.name = 'Regression Template TF' 41 | self.n_levels = 3 42 | self.scale = 0.5 43 | dir_path = os.path.dirname(os.path.realpath(__file__)) 44 | self.checkpoints_dir = os.path.join(dir_path, 'checkpoints') 45 | self.patch_size = 50 46 | self.output_param_number = 1 47 | 48 | # Initialise checkpoint name to the latest checkpoint 49 | ckpt_names = get_ckpt_list(self.checkpoints_dir) 50 | if not ckpt_names: # empty list 51 | self.checkpoint_name = '' 52 | else: 53 | latest_ckpt = tf.compat.v1.train.latest_checkpoint(self.checkpoints_dir) 54 | if latest_ckpt is not None: 55 | self.checkpoint_name = latest_ckpt.split('/')[-1] 56 | else: 57 | self.checkpoint_name = ckpt_names[-1] 58 | self.prev_ckpt_name = self.checkpoint_name 59 | 60 | # Silence TF log when creating tf.Session() 61 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 62 | 63 | # Define options 64 | self.gamma_to_predict = 1.0 65 | self.predict = False 66 | self.options = ('checkpoint_name', 'gamma_to_predict',) 67 | self.buttons = ('predict',) 68 | # Define inputs/outputs 69 | self.inputs = {'input': 3} 70 | self.outputs = {'output': 3} 71 | 72 | def load(self, model): 73 | # Check if empty or invalid checkpoint name 74 | if self.checkpoint_name=='': 75 | ckpt_names = get_ckpt_list(self.checkpoints_dir) 76 | if not ckpt_names: 77 | raise ValueError("No checkpoints found in {}".format(self.checkpoints_dir)) 78 | else: 79 | raise ValueError("Empty checkpoint name, try an available checkpoint in {} (ex: {})" 80 | .format(self.checkpoints_dir, ckpt_names[-1])) 81 | print_("Loading trained model checkpoint...\n", 'm') 82 | # Load from given checkpoint file name 83 | self.saver.restore(self.sess, os.path.join(self.checkpoints_dir, self.checkpoint_name)) 84 | print_("...Checkpoint {} loaded\n".format(self.checkpoint_name), 'm') 85 | 86 | def inference(self, image_list): 87 | """Do an inference on the model with a set of inputs. 88 | 89 | # Arguments: 90 | image_list: The input image list 91 | 92 | Return the result of the inference. 93 | """ 94 | image = image_list[0] 95 | image = linear_to_srgb(image).copy() 96 | 97 | if not hasattr(self, 'sess'): 98 | # Initialise tensorflow graph 99 | tf.compat.v1.reset_default_graph() 100 | config = tf.compat.v1.ConfigProto() 101 | config.gpu_options.allow_growth=True 102 | self.sess=tf.compat.v1.Session(config=config) 103 | # Input is stacked histograms of original and gamma-graded images. 
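            # Editorial note (not in the original source): the shape below is
            # [batch=1, 2 histograms (original + graded), nbins=100], mirroring
            # the training-time preprocessing in train_regression.py.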
104 | input_shape = [1, 2, 100] 105 | # Initialise input placeholder size 106 | self.input = tf.compat.v1.placeholder(tf.float32, shape=input_shape) 107 | self.model = baseline_model( 108 | input_shape=input_shape[1:], 109 | output_param_number=self.output_param_number) 110 | self.infer_op = self.model(self.input) 111 | # Load latest model checkpoint 112 | self.saver = tf.compat.v1.train.Saver() 113 | self.load(self.model) 114 | self.prev_ckpt_name = self.checkpoint_name 115 | 116 | # If checkpoint name has changed, load new checkpoint 117 | if self.prev_ckpt_name != self.checkpoint_name or self.checkpoint_name == '': 118 | self.load(self.model) 119 | # If checkpoint correctly loaded, update previous checkpoint name 120 | self.prev_ckpt_name = self.checkpoint_name 121 | 122 | # Preprocess image same way we preprocessed it for training 123 | # Here for gamma correction compute histograms 124 | def histogram(x, value_range=[0.0, 1.0], nbins=100): 125 | """Return histogram of tensor x""" 126 | h, w, c = x.shape 127 | hist = tf.histogram_fixed_width(x, value_range, nbins=nbins) 128 | hist = tf.divide(hist, h * w * c) 129 | return hist 130 | with tf.compat.v1.Session() as sess: 131 | # Convert to grayscale 132 | img_gray = tf.image.rgb_to_grayscale(image) 133 | img_gray = tf.image.resize(img_gray, [self.patch_size, self.patch_size]) 134 | # Apply gamma correction 135 | img_gray_grade = tf.math.pow(img_gray, self.gamma_to_predict) 136 | img_grade = tf.math.pow(image, self.gamma_to_predict) 137 | # Compute histograms 138 | img_hist = histogram(img_gray) 139 | img_grade_hist = histogram(img_gray_grade) 140 | hists_op = tf.stack([img_hist, img_grade_hist], axis=0) 141 | hists, img_grade = sess.run([hists_op, img_grade]) 142 | res_img = srgb_to_linear(img_grade) 143 | 144 | hists_batch = np.expand_dims(hists, 0) 145 | start = time.time() 146 | # Run model inference 147 | inference = self.sess.run(self.infer_op, feed_dict={self.input: hists_batch}) 148 | duration = time.time() - start 149 | print('Inference duration: {:4.3f}s'.format(duration)) 150 | res = inference[-1] 151 | print("Predicted gamma: {}".format(res)) 152 | 153 | # If predict button is pressed in Nuke 154 | if self.predict: 155 | script_msg = message_pb2.FieldValuePairAttrib() 156 | script_msg.name = "PythonScript" 157 | # Create a Python script message to run in Nuke 158 | python_script = self.nuke_script(res) 159 | script_msg_val = script_msg.values.add() 160 | script_msg_str = script_msg_val.string_attributes.add() 161 | script_msg_str.values.extend([python_script]) 162 | return [res_img, script_msg] 163 | 164 | return [res_img] 165 | 166 | def nuke_script(self, res): 167 | """Return the Python script function to create a pop up window in Nuke.""" 168 | popup_msg = "Predicted gamma: {}".format(res) 169 | script = "nuke.message('{}')\n".format(popup_msg) 170 | return script -------------------------------------------------------------------------------- /Models/regressionTemplateTF/train_regression.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Foundry. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | ############################################################################## 15 | 16 | from __future__ import division, print_function, absolute_import 17 | from builtins import input # python 2/3 forward-compatible (raw_input) 18 | 19 | import sys 20 | import os 21 | import time 22 | import random 23 | import argparse 24 | from datetime import datetime 25 | 26 | import numpy as np 27 | 28 | import tensorflow as tf 29 | print(tf.__version__) 30 | 31 | tf.compat.v1.enable_eager_execution() 32 | 33 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 34 | from common.model_builder import baseline_model 35 | from common.util import get_filepaths_from_dir, get_ckpt_list, print_ 36 | from common.util import is_exr, read_resize_exr, linear_to_srgb 37 | 38 | def enable_deterministic_training(seed, no_gpu_patch=False): 39 | """Set all seeds for deterministic training 40 | 41 | Args: 42 | no_gpu_patch (bool): if False, apply a patch to TensorFlow to have 43 | deterministic GPU operations, if True the training is much faster 44 | but slightly less deterministic. 45 | This function needs to be called before any TensorFlow code. 46 | """ 47 | import numpy as np 48 | import os 49 | import random 50 | import tfdeterminism 51 | if not no_gpu_patch: 52 | # Patch stock TensorFlow to have deterministic GPU operation 53 | tfdeterminism.patch() # then use tf as normal 54 | # If PYTHONHASHSEED environment variable is not set or set to random, 55 | # a random value is used to seed the hashes of str, bytes and datetime 56 | # objects. (Necessary for Python >= 3.2.3) 57 | os.environ['PYTHONHASHSEED']=str(seed) 58 | # Set python built-in pseudo-random generator at a fixed value 59 | random.seed(seed) 60 | # Set seed for random Numpy operation (e.g. np.random.randint) 61 | np.random.seed(seed) 62 | # Set seed for random TensorFlow operation (e.g. 
tf.image.random_crop) 63 | tf.compat.v1.random.set_random_seed(seed) 64 | 65 | ## DATA PROCESSING 66 | 67 | def histogram(tensor, value_range=[0.0, 1.0], nbins=100): 68 | """Return histogram of tensor""" 69 | h, w, c = tensor.shape 70 | hist = tf.histogram_fixed_width(tensor, value_range, nbins=nbins) 71 | hist = tf.divide(hist, h * w * c) 72 | return hist 73 | 74 | def gamma_correction(img, gamma): 75 | """Apply gamma correction to image img 76 | 77 | Returns: 78 | hists: stack of both original and graded image histograms 79 | """ 80 | # Check number of parameter is one 81 | if gamma.shape[0] != 1: 82 | raise ValueError("Parameter for gamma correction must be of " 83 | "size (1,), not {}.\n\tCheck your self.output_param_number, ".format(gamma.shape) 84 | + "you may need to implement your own input_data preprocessing.") 85 | # Create groundtruth graded image 86 | img_grade = tf.math.pow(img, gamma) 87 | # Compute histograms 88 | img_hist = histogram(img) 89 | img_grade_hist = histogram(img_grade) 90 | hists = tf.stack([img_hist, img_grade_hist], axis=0) 91 | return hists 92 | 93 | ## CUSTOM TRAINING METRICS 94 | 95 | def bin_acc(y_true, y_pred, delta=0.02): 96 | """Bin accuracy metric equals 1.0 if diff between true 97 | and predicted value is inferior to delta. 98 | """ 99 | diff = tf.keras.backend.abs(y_true - y_pred) 100 | # If diff is less that delta --> true (1.0), otherwise false (0.0) 101 | correct = tf.keras.backend.less(diff, delta) 102 | # Return percentage accuracy 103 | return tf.keras.backend.mean(correct) 104 | 105 | class TrainModel(object): 106 | """Train Regression model from the given data""" 107 | 108 | def __init__(self, args): 109 | # Training hyperparameters 110 | self.learning_rate = args.learning_rate 111 | self.batch_size = args.batch_size 112 | self.epoch = args.epoch 113 | self.patch_size = 50 114 | self.channels = 3 # input / output channels 115 | self.output_param_number = 1 116 | self.no_resume = args.no_resume 117 | # A random seed (!=None) allows you to reproduce your training results 118 | self.seed = args.seed 119 | if self.seed is not None: 120 | # Set all seeds necessary for deterministic training 121 | enable_deterministic_training(self.seed, args.no_gpu_patch) 122 | # Training and validation dataset paths 123 | train_data_path = './data/train/' 124 | val_data_path = './data/validation/' 125 | 126 | # Where to save and load model weights (=checkpoints) 127 | self.ckpt_dir = './checkpoints' 128 | if not os.path.exists(self.ckpt_dir): 129 | os.makedirs(self.ckpt_dir) 130 | self.ckpt_save_name = args.ckpt_save_name 131 | 132 | # Where to save tensorboard summaries 133 | self.summaries_dir = './summaries/' 134 | if not os.path.exists(self.summaries_dir): 135 | os.makedirs(self.summaries_dir) 136 | 137 | # Get training dataset as list of image paths 138 | self.train_data_list = get_filepaths_from_dir(train_data_path) 139 | if not self.train_data_list: 140 | raise ValueError("No training data found in folder {}".format(train_data_path)) 141 | elif (len(self.train_data_list) < self.batch_size): 142 | raise ValueError("Batch size must be smaller than the dataset (batch size = {}, number of training data = {})" 143 | .format(self.batch_size, len(self.train_data_list))) 144 | self.is_exr = is_exr(self.train_data_list[0]) 145 | 146 | # Compute and print training hyperparameters 147 | self.batch_per_epoch = (len(self.train_data_list)) // self.batch_size 148 | max_steps = int(self.epoch * (self.batch_per_epoch)) 149 | print_("Number of training data: {}\nNumber of 
batches per epoch: {} (batch size = {})\nNumber of training steps for {} epochs: {}\n" 150 | .format(len(self.train_data_list), self.batch_per_epoch, self.batch_size, self.epoch, max_steps), 'm') 151 | 152 | # Get validation dataset if provided 153 | self.has_val_data = True 154 | self.val_data_list = get_filepaths_from_dir(val_data_path) 155 | if not self.val_data_list: 156 | print("No validation data found in {}".format(val_data_path)) 157 | self.has_val_data = False 158 | elif (len(self.val_data_list) < self.batch_size): 159 | raise ValueError("Batch size must be smaller than the dataset (batch size = {}, number of validation data = {})" 160 | .format(self.batch_size, len(self.val_data_list))) 161 | else: 162 | val_is_exr = is_exr(self.val_data_list[0]) 163 | if (val_is_exr and not self.is_exr) or (not val_is_exr and self.is_exr): 164 | raise TypeError("Train and validation data should have the same file format") 165 | self.val_batch_per_epoch = (len(self.val_data_list)) // self.batch_size 166 | print("Number of validation data: {}\nNumber of validation batches per epoch: {} (batch size = {})" 167 | .format(len(self.val_data_list), self.val_batch_per_epoch, self.batch_size)) 168 | 169 | def get_data(self, data_list, batch_size=16, epoch=100, shuffle_buffer_size=1000): 170 | 171 | def read_and_preprocess_data(path_img, param): 172 | """Read image in path_img, resize it to patch_size, 173 | convert to grayscale and apply a random gamma grade to it 174 | 175 | Returns: 176 | input_data: stack of both original and graded image histograms 177 | param: groundtruth gamma value 178 | """ 179 | if self.is_exr: # ['exr', 'EXR'] 180 | img = tf.numpy_function(read_resize_exr, 181 | [path_img, self.patch_size], [tf.float32]) 182 | img = tf.numpy_function(linear_to_srgb, [img], [tf.float32]) 183 | img = tf.reshape(img, [self.patch_size, self.patch_size, self.channels]) 184 | img = tf.image.rgb_to_grayscale(img) 185 | else: # ['jpg', 'jpeg', 'png', 'bmp', 'JPG', 'JPEG', 'PNG', 'BMP'] 186 | img_raw = tf.io.read_file(path_img) 187 | img_tensor = tf.image.decode_png(img_raw, channels=3) 188 | img = tf.cast(img_tensor, tf.float32) / 255.0 189 | img = tf.image.rgb_to_grayscale(img) 190 | img = tf.image.resize(img, [self.patch_size, self.patch_size]) 191 | # Depending on what parameter(s) you want to learn, modify the training 192 | # input data. Here to learn gamma correction, our input data trainX is 193 | # a stack of both original and gamma-graded histograms. 194 | input_data = gamma_correction(img, param) 195 | return input_data, param 196 | 197 | with tf.compat.v1.variable_scope('input'): 198 | # Ensure preprocessing is done on the CPU (to let the GPU focus on training) 199 | with tf.device('/cpu:0'): 200 | data_tensor = tf.convert_to_tensor(data_list, dtype=tf.string) 201 | path_dataset = tf.data.Dataset.from_tensor_slices((data_tensor)) 202 | path_dataset = path_dataset.shuffle(shuffle_buffer_size).repeat(epoch) 203 | # Depending on what parameter(s) you want to learn, modify the random 204 | # uniform range. 
Here create random gamma values between 0.2 and 5 205 | param_tensor = tf.random.uniform( 206 | [len(data_list)*epoch, self.output_param_number], 0.2, 5.0) 207 | param_dataset = tf.data.Dataset.from_tensor_slices((param_tensor)) 208 | dataset = tf.data.Dataset.zip((path_dataset, param_dataset)) 209 | # Apply read_and_preprocess_data function to all input in the path_dataset 210 | dataset = dataset.map(read_and_preprocess_data, num_parallel_calls=4) 211 | dataset = dataset.batch(batch_size) 212 | # Always prefetch one batch and make sure there is always one ready 213 | dataset = dataset.prefetch(buffer_size=1) 214 | return dataset 215 | 216 | def tensorboard_callback(self, writer): 217 | """Return custom Tensorboard callback for logging main metrics""" 218 | 219 | def log_metrics(epoch, logs): 220 | """Log training/validation loss and accuracy to Tensorboard""" 221 | with writer.as_default(), tf.contrib.summary.always_record_summaries(): 222 | tf.contrib.summary.scalar('train_loss', logs['loss'], step=epoch) 223 | tf.contrib.summary.scalar('train_bin_acc', logs['bin_acc'], step=epoch) 224 | if self.has_val_data: 225 | tf.contrib.summary.scalar('val_loss', logs['val_loss'], step=epoch) 226 | tf.contrib.summary.scalar('val_bin_acc', logs['val_bin_acc'], step=epoch) 227 | tf.contrib.summary.flush() 228 | 229 | return tf.keras.callbacks.LambdaCallback(on_epoch_end=log_metrics) 230 | 231 | def get_compiled_model(self, input_shape): 232 | model = baseline_model( 233 | input_shape, 234 | output_param_number=self.output_param_number) 235 | adam = tf.keras.optimizers.Adam(lr=self.learning_rate) 236 | model.compile(optimizer=adam, 237 | loss='mean_squared_error', 238 | metrics=[bin_acc]) 239 | return model 240 | 241 | def train(self): 242 | # Create a session so that tf.keras don't allocate all GPU memory at once 243 | sess = tf.compat.v1.Session( 244 | config=tf.compat.v1.ConfigProto( 245 | gpu_options=tf.compat.v1.GPUOptions(allow_growth=True))) 246 | tf.compat.v1.keras.backend.set_session(sess) 247 | 248 | # Get training and validation dataset 249 | ds_train = self.get_data( 250 | self.train_data_list, 251 | self.batch_size, 252 | self.epoch) 253 | for x, y in ds_train.take(1): # take one batch from ds_train 254 | trainX, trainY = x, y 255 | print("Input shape {}, target shape: {}".format(trainX.shape, trainY.shape)) 256 | if self.has_val_data: 257 | ds_val = self.get_data( 258 | self.val_data_list, 259 | self.batch_size, 260 | self.epoch) 261 | print("********Data Created********") 262 | 263 | # Build model 264 | model = self.get_compiled_model(trainX.shape[1:]) 265 | 266 | # Check if there are intermediate trained model to load 267 | if self.no_resume or not self.load(model): 268 | print_("Starting training from scratch\n", 'm') 269 | 270 | # Callback for creating Tensorboard summary 271 | summary_name = ("data{}_bch{}_ep{}".format( 272 | len(self.train_data_list), self.batch_size, self.epoch)) 273 | summary_name += ("_seed{}".format(self.seed) if self.seed is not None else "") 274 | summary_writer = tf.contrib.summary.create_file_writer( 275 | os.path.join(self.summaries_dir, summary_name)) 276 | tb_callback = self.tensorboard_callback(summary_writer) 277 | 278 | # Callback for saving model's weights 279 | ckpt_path = os.path.join(self.ckpt_dir, self.ckpt_save_name + "-ep{epoch:02d}") 280 | ckpt_callback = tf.keras.callbacks.ModelCheckpoint( 281 | filepath=ckpt_path, 282 | # save best model based on monitor value 283 | monitor='val_loss' if self.has_val_data else 'loss', 284 | verbose=1, 285 | 
save_best_only=True, 286 | save_weights_only=True) 287 | 288 | # Evaluate the model before training 289 | if self.has_val_data: 290 | val_loss, val_bin_acc = model.evaluate(ds_val.take(20), verbose=1) 291 | print("Initial Loss on validation dataset: {:.4f}".format(val_loss)) 292 | 293 | # TRAIN model 294 | print_("--------Start of training--------\n", 'm') 295 | print("NOTE:\tDuring training, the latest model is saved only if its\n" 296 | "\t(validation) loss is better than the last best model.") 297 | train_start = time.time() 298 | model.fit( 299 | ds_train, 300 | validation_data=ds_val if self.has_val_data else None, 301 | epochs=self.epoch, 302 | steps_per_epoch=self.batch_per_epoch, 303 | validation_steps=self.val_batch_per_epoch if self.has_val_data else None, 304 | callbacks=[ckpt_callback, tb_callback], 305 | verbose=1) 306 | print_("Training duration: {:0.4f}s\n".format(time.time() - train_start), 'm') 307 | print_("--------End of training--------\n", 'm') 308 | 309 | # Show predictions on the first batch of training data 310 | print("Parameter prediction (PR) compared to groundtruth (GT) for first batch of training data:") 311 | preds_train = model.predict(trainX.numpy()) 312 | print("Train GT:", trainY.numpy().flatten()) 313 | print("Train PR:", preds_train.flatten()) 314 | # Make predictions on the first batch of validation data 315 | if self.has_val_data: 316 | print("For first batch of validation data:") 317 | for x, y in ds_val.take(1): # take one batch from ds_val 318 | valX, valY = x, y 319 | preds_val = model.predict(valX) 320 | print("Val GT:", valY.numpy().flatten()) 321 | print("Val PR:", preds_val.flatten()) 322 | # Free all resources associated with the session 323 | sess.close() 324 | 325 | def load(self, model): 326 | ckpt_names = get_ckpt_list(self.ckpt_dir) 327 | if not ckpt_names: # list is empty 328 | print_("No checkpoints found in {}\n".format(self.ckpt_dir), 'm') 329 | return False 330 | else: 331 | print_("Found checkpoints:\n", 'm') 332 | for name in ckpt_names: 333 | print(" {}".format(name)) 334 | # Ask user if they prefer to start training from scratch or resume training on a specific ckeckpoint 335 | while True: 336 | mode=str(input('Start training from scratch (start) or resume training from a previous checkpoint (choose one of the above): ')) 337 | if mode == 'start' or mode in ckpt_names: 338 | break 339 | else: 340 | print("Answer should be 'start' or one of the following checkpoints: {}".format(ckpt_names)) 341 | continue 342 | if mode == 'start': 343 | return False 344 | elif mode in ckpt_names: 345 | # Try to load given intermediate checkpoint 346 | print_("Loading trained model...\n", 'm') 347 | model.load_weights(os.path.join(self.ckpt_dir, mode)) 348 | print_("...Checkpoint {} loaded\n".format(mode), 'm') 349 | return True 350 | else: 351 | raise ValueError("User input is neither 'start' nor a valid checkpoint") 352 | 353 | def evaluate(self, test_data_path, weights): 354 | """Evaluate a trained model on the test dataset 355 | 356 | Args: 357 | test_data_path (str): path to directory containing images for testing 358 | weights (str): name of the tensorflow checkpoint (weights) to evaluate 359 | """ 360 | test_data_list = get_filepaths_from_dir(test_data_path) 361 | if not test_data_list: 362 | raise ValueError("No test data found in folder {}".format(test_data_path)) 363 | elif (len(self.train_data_list) < self.batch_size): 364 | raise ValueError("Batch size must be smaller than the dataset (batch size = {}, number of test data = {})" 365 | 
.format(self.batch_size, len(test_data_list))) 366 | self.is_exr = is_exr(test_data_list[0]) 367 | 368 | # Get and create test dataset 369 | ds_test = self.get_data( 370 | test_data_list, 371 | self.batch_size, 372 | 1) 373 | for x, y in ds_test.take(1): # take one batch from ds_test 374 | testX, testY = x, y 375 | print_("Number of test data: {}\n".format(len(test_data_list)), 'm') 376 | print("Input shape {}, target shape: {}".format(testX.shape, testY.shape)) 377 | 378 | # Build model 379 | model = self.get_compiled_model(testX.shape[1:]) 380 | 381 | # Load model weights 382 | print_("Loading trained model for testing...\n", 'm') 383 | model.load_weights(os.path.join(self.ckpt_dir, weights)).expect_partial() 384 | print_("...Checkpoint {} loaded\n".format(weights), 'm') 385 | 386 | # Test final model on this unseen dataset 387 | results = model.evaluate(ds_test) 388 | print("test loss, test acc:", results) 389 | print_("--------End of testing--------\n", 'm') 390 | 391 | def parse_args(): 392 | parser = argparse.ArgumentParser(description='Model training arguments') 393 | parser.add_argument('--bch', type=int, default=10, dest='batch_size', help='training batch size') 394 | parser.add_argument('--ep', type=int, default=15, dest='epoch', help='training epoch number') 395 | parser.add_argument('--lr', type=float, default=1e-3, dest='learning_rate', help='initial learning rate') 396 | parser.add_argument('--seed', type=int, default=None, dest='seed', help='set random seed for deterministic training') 397 | parser.add_argument('--no-gpu-patch', dest='no_gpu_patch', default=False, action='store_true', help='if seed is set, add this tag for much faster but slightly less deterministic training') 398 | parser.add_argument('--no-resume', dest='no_resume', default=False, action='store_true', help="start training from scratch") 399 | parser.add_argument('--name', type=str, default="regressionTemplateTF", dest='ckpt_save_name', help='name of saved checkpoints/model weights') 400 | args = parser.parse_args() 401 | return args 402 | 403 | if __name__ == '__main__': 404 | args = parse_args() 405 | # Set up model to train 406 | model = TrainModel(args) 407 | model.train() 408 | # To evaluate on the test dataset, uncomment next line and give the 409 | # test dataset directory and the model checkpoint name 410 | # model.evaluate('./data/test', 'regressionTemplateTF-ep35') -------------------------------------------------------------------------------- /Models/trainingTemplateTF/README.md: -------------------------------------------------------------------------------- 1 | # Training Template: Train and Infer Models in the nuke-ML-server 2 | 3 | The TrainingTemplateTF model is a training template written in TensorFlow. It aims at quickly enabling image-to-image training using a multi-scale encoder-decoder model. When trained, the model can be tested and used directly in Nuke through the nuke-ML-server. 4 | 5 | For instance, if you have a set of noisy / clear image pairs and would like to train a model to be able to denoise an image, you simply need to fill in your data in the `TrainingTemplateTF/data` and start the training with one command line. You can monitor the training using TensorBoard and eventually test the trained model on live images in Nuke. 6 | 7 | This page contains instructions on how to use this training template. The training happens in the Docker container, while the inference is done through the MLClient plugin. 
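For reference, here is a sketch of the directory layout that `train_model.py` expects (it reads from the hard-coded `./data/...` paths shown in the script; the comments are illustrative only):
```
trainingTemplateTF/
└── data/
    ├── train/
    │   ├── input/            # e.g. blurred or noisy images
    │   └── groundtruth/      # the matching clean images
    └── validation/           # optional, same structure as train/
        ├── input/
        └── groundtruth/
```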
8 | 9 | ## Set-up 10 | 11 | Start by installing the nuke-ML-server (see [INSTALL.md](https://github.com/TheFoundryVisionmongers/nuke-ML-server/blob/master/INSTALL.md)). If you had already installed the previous version, you will still have to rebuild the docker image once: 12 | ``` 13 | cd Plugins/Server/ 14 | sudo docker build -t -f Dockerfile . 15 | ``` 16 | 17 | To launch the [TensorBoard visualisation](https://github.com/TheFoundryVisionmongers/nuke-ML-server/tree/master/Models/trainingTemplateTF#tensorboard) from within Docker, you have to run the docker container (see the [Run Docker Container](https://github.com/TheFoundryVisionmongers/nuke-ML-server/blob/master/INSTALL.md#run-docker-container) section) with port 6006 exposed: 18 | ``` 19 | sudo docker run --gpus all -v /absolute/path/to/nuke-ML-server/Models/:/workspace/ml-server/models -p 6006:6006 -it 20 | ``` 21 | 22 | ## Train in Docker 23 | 24 | ### Dataset 25 | 26 | To train the ML algorithm, you need to provide a dataset of groundtruth & input image data pairs. For instance, the input data could be blurred images and the groundtruth the corresponding sharp images. In that case, you would like the model to learn to infer a sharp image from a blurred input image. 27 | 28 | Place your input and groundtruth data in the `trainingTemplateTF/data/train/input/` and `trainingTemplateTF/data/train/groundtruth/` folders respectively. 29 | 30 | Optionally, you can add a separate set of image pairs in `trainingTemplateTF/data/validation/input/` and `trainingTemplateTF/data/validation/groundtruth/`. If this validation dataset is available, it is periodically evaluated with the current model weights to check that there is no overfitting on the training data. Please note that the validation and training datasets must not intersect: no image pair should appear in both. 31 | 32 | Notes: 33 | - The preprocessing crop size is currently 256x256, so the dataset images are expected to be at least 256x256. 34 | - Supported image types are JPG, PNG, BMP and EXR. 35 | - Depending on the compression used, EXR images can be slower to read. In our experiments, the fastest EXR read is achieved with B44, B44A or no compression. 36 | 37 | ### Training 38 | 39 | Inside your docker container, go to the trainingTemplateTF folder: 40 | ``` 41 | cd /workspace/ml-server/models/trainingTemplateTF 42 | ``` 43 | Then directly train your model: 44 | ``` 45 | python train_model.py 46 | ``` 47 | You can also specify the batch size, learning rate and number of epochs: 48 | ``` 49 | python train_model.py --bch=16 --lr=1e-4 --ep=1000 50 | ``` 51 | Deterministic training is also supported: you can reproduce a training run (and obtain the same model weights) by setting the seed to an integer of your choice (here 77): 52 | ``` 53 | python train_model.py --seed=77 54 | ``` 55 | We enable deterministic training in part by applying a GPU patch to stock TensorFlow; this patch slows down training significantly. By adding the `--no-gpu-patch` flag to the previous command, you get slightly less deterministic training but keep the original training speed. 56 | 57 | ### Potential Training Issues 58 | 59 | The principal issue you may hit when training is a GPU out-of-memory (OOM) error. To train with the default values, your GPU memory should be at least 8GB.
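If you are unsure how much memory your card has, `nvidia-smi` can report it; for example (assuming the NVIDIA driver utilities are installed, this is one possible invocation):
```
nvidia-smi --query-gpu=name,memory.total --format=csv
```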
60 | 61 | If you reach an OOM error, you can consider reducing the GPU memory requirements (likely at the expense of final model performance) by: 62 | - Building a simplified version of the encoder-decoder model found in [`model_builder.py`](https://github.com/TheFoundryVisionmongers/nuke-ML-server/blob/master/Models/trainingTemplateTF/util/model_builder.py) (e.g. by removing layers), 63 | - Reducing the batch size (`--bch` argument), 64 | - Or lowering the preprocessing crop size (`crop_size` in [`train_model.py`](https://github.com/TheFoundryVisionmongers/nuke-ML-server/blob/master/Models/trainingTemplateTF/train_model.py)). 65 | 66 | During training, images are cropped as a preprocessing step before being fed to the network. Therefore, if you want your model to learn global image information (e.g. lens distortion), this cropping preprocessing should be changed in the code (e.g. use resize & padding instead) so that the whole image is preserved. 67 | 68 | ### TensorBoard 69 | 70 | [TensorBoard](https://www.tensorflow.org/guide/summaries_and_tensorboard) is a great way to visualise how your training is progressing. 71 | 72 | The TrainingTemplateTF automatically saves the learning rate and loss evolution as well as input, groundtruth and temporary output images in the `trainingTemplateTF/summaries/` folder. 73 | 74 | To view these TensorBoard summaries, first find which container is currently running your training (STATUS: Up, PORTS: 0.0.0.0:6006->6006/tcp, NAMES=``) among all the created docker containers: 75 | ``` 76 | sudo docker ps -a 77 | ``` 78 | Launch a second terminal connected to the same docker container, where `` is the name of your training container found above: 79 | ``` 80 | docker exec -it bash 81 | ``` 82 | Launch TensorBoard in this new docker terminal to view the progression in real time in your browser: 83 | ``` 84 | tensorboard --logdir models/trainingTemplateTF/summaries/ 85 | ``` 86 | From your host machine, you can now navigate to the following browser address to monitor your training: http://localhost:6006. 87 | 88 | ### Checkpoints 89 | 90 | During training, the model weights and graph are saved every N steps and put in the `trainingTemplateTF/checkpoints/` folder. A checkpoint name such as `trainingTemplateTF.model-375000` means that it contains the weights after 375,000 training steps of the trainingTemplateTF model. 91 | 92 | When launching a training run, you can choose to start from scratch or to resume from one of the previously saved checkpoints. 93 | 94 | ## Inference in Nuke 95 | 96 | After training your model inside the docker container, you can launch Nuke and select the `Training Template TF` model in the MLClient node. 97 | 98 | The plugin will automatically load the most advanced trained checkpoint found in `trainingTemplateTF/checkpoints/` (i.e. the one with the highest step number, see the short sketch at the end of this page) and run an inference using the loaded weights and graph. If you prefer to use older checkpoints, you can write the name of a previous checkpoint as an inference option in Nuke. 99 | 100 | This is a great way to verify on your own live data that the model weights converged correctly without overfitting on the training data. 101 | 102 | Note: the inference is done on saved checkpoints and not on a frozen graph, which implies that the saved checkpoint graph must correspond to the current graph. If you change the graph (by changing the preprocessing step, number of layers, variable names etc.), you won't directly be able to load older checkpoints built on a different graph.
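As mentioned above, the "most advanced" checkpoint is simply the one whose file name carries the highest training step. A minimal sketch of that selection logic (mirroring what `model.py` does with `get_ckpt_list`; the checkpoint names below are just examples):
```
# Pick the checkpoint with the highest step number,
# e.g. 'trainingTemplateTF.model-375000' -> step 375000
ckpt_names = ['trainingTemplateTF.model-370000', 'trainingTemplateTF.model-375000']
ckpt_steps = [int(name.split('-')[-1]) for name in ckpt_names]
latest = ckpt_names[ckpt_steps.index(max(ckpt_steps))]
print(latest)  # 'trainingTemplateTF.model-375000'
```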
-------------------------------------------------------------------------------- /Models/trainingTemplateTF/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TheFoundryVisionmongers/nuke-ML-server/5dd4d04cd673c60de8093c600d6c54016fca92d6/Models/trainingTemplateTF/__init__.py -------------------------------------------------------------------------------- /Models/trainingTemplateTF/data/train/groundtruth/alive00001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TheFoundryVisionmongers/nuke-ML-server/5dd4d04cd673c60de8093c600d6c54016fca92d6/Models/trainingTemplateTF/data/train/groundtruth/alive00001.png -------------------------------------------------------------------------------- /Models/trainingTemplateTF/data/train/input/alive_snow00001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TheFoundryVisionmongers/nuke-ML-server/5dd4d04cd673c60de8093c600d6c54016fca92d6/Models/trainingTemplateTF/data/train/input/alive_snow00001.png -------------------------------------------------------------------------------- /Models/trainingTemplateTF/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Foundry. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | ############################################################################## 15 | 16 | from __future__ import print_function 17 | 18 | import sys 19 | import os 20 | import time 21 | 22 | import scipy.misc 23 | import numpy as np 24 | import cv2 25 | 26 | import tensorflow as tf 27 | tf.compat.v1.disable_eager_execution() # For TF 2.x compatibility 28 | 29 | from ..baseModel import BaseModel 30 | from ..common.model_builder import EncoderDecoder 31 | from ..common.util import print_, get_ckpt_list, linear_to_srgb, srgb_to_linear 32 | 33 | class Model(BaseModel): 34 | """Load your trained model and do inference in Nuke""" 35 | 36 | def __init__(self): 37 | super(Model, self).__init__() 38 | self.name = 'Training Template TF' 39 | self.n_levels = 3 40 | self.scale = 0.5 41 | dir_path = os.path.dirname(os.path.realpath(__file__)) 42 | self.checkpoints_dir = os.path.join(dir_path, 'checkpoints') 43 | self.batch_size = 1 44 | 45 | # Initialise checkpoint name to the most advanced checkpoint (highest step) 46 | ckpt_names = get_ckpt_list(self.checkpoints_dir) 47 | if not ckpt_names: # empty list 48 | self.checkpoint_name = '' 49 | else: 50 | ckpt_steps = [int(name.split('-')[-1]) for name in ckpt_names] 51 | self.checkpoint_name = ckpt_names[ckpt_steps.index(max(ckpt_steps))] 52 | self.prev_ckpt_name = self.checkpoint_name 53 | 54 | self.options = ('checkpoint_name',) 55 | # Define inputs/outputs 56 | self.inputs = {'input': 3} 57 | self.outputs = {'output': 3} 58 | 59 | def load(self, sess, checkpoint_dir): 60 | # Check if empty or invalid checkpoint name 61 | if self.checkpoint_name=='': 62 | ckpt_names = get_ckpt_list(self.checkpoints_dir) 63 | if not ckpt_names: 64 | raise ValueError("No checkpoints found in {}".format(self.checkpoints_dir)) 65 | else: 66 | raise ValueError("Empty checkpoint name, try an available checkpoint in {} (ex: {})" 67 | .format(self.checkpoints_dir, ckpt_names[-1])) 68 | print_("Loading trained model checkpoint...\n", 'm') 69 | # Load from given checkpoint file name 70 | self.saver.restore(sess, os.path.join(checkpoint_dir, self.checkpoint_name)) 71 | print_("...Checkpoint {} loaded\n".format(self.checkpoint_name), 'm') 72 | 73 | def inference(self, image_list): 74 | """Do an inference on the model with a set of inputs. 75 | 76 | # Arguments: 77 | image_list: The input image list 78 | 79 | Return the result of the inference. 
80 | """ 81 | image = image_list[0] 82 | image = linear_to_srgb(image).copy() 83 | H, W, channels = image.shape 84 | 85 | # Add padding so that width and height of image are a multiple of 16 86 | new_H = int(H + 16 - H%16) if H%16!=0 else H 87 | new_W = int(W + 16 - W%16) if W%16!=0 else W 88 | img_pad = np.pad(image, ((0, new_H - H), (0, new_W - W), (0, 0)), 'reflect') 89 | 90 | if not hasattr(self, 'sess'): 91 | # Initialise input placeholder size 92 | self.curr_height = new_H; self.curr_width = new_W 93 | # Initialise tensorflow graph 94 | tf.compat.v1.reset_default_graph() 95 | config = tf.compat.v1.ConfigProto() 96 | config.gpu_options.allow_growth=True 97 | self.sess=tf.compat.v1.Session(config=config) 98 | self.input = tf.compat.v1.placeholder(tf.float32, shape=[self.batch_size, new_H, new_W, channels]) 99 | self.model = EncoderDecoder(self.n_levels, self.scale, channels) 100 | self.infer_op = self.model(self.input, reuse=False) 101 | # Load model checkpoint having the longest training (highest step) 102 | self.saver = tf.compat.v1.train.Saver() 103 | self.load(self.sess, self.checkpoints_dir) 104 | self.prev_ckpt_name = self.checkpoint_name 105 | 106 | elif self.curr_height != new_H or self.curr_width != new_W: 107 | # Modify input placeholder size 108 | self.input = tf.compat.v1.placeholder(tf.float32, shape=[self.batch_size, new_H, new_W, channels]) 109 | self.infer_op = self.model(self.input, reuse=False) 110 | # Update image height and width 111 | self.curr_height = new_H; self.curr_width = new_W 112 | 113 | # If checkpoint name has changed, load new checkpoint 114 | if self.prev_ckpt_name != self.checkpoint_name or self.checkpoint_name == '': 115 | self.load(self.sess, self.checkpoints_dir) 116 | # If checkpoint correctly loaded, update previous checkpoint name 117 | self.prev_ckpt_name = self.checkpoint_name 118 | 119 | # Apply current model to the padded input image 120 | image_batch = np.expand_dims(img_pad, 0) 121 | start = time.time() 122 | # The network is expecting image_batch to be of type tf.float32 123 | inference = self.sess.run(self.infer_op, feed_dict={self.input: image_batch}) 124 | duration = time.time() - start 125 | print('Inference duration: {:4.3f}s'.format(duration)) 126 | res = inference[-1] 127 | # Remove first dimension and padding 128 | res = res[0, :H, :W, :] 129 | 130 | output_image = srgb_to_linear(res) 131 | return [output_image] -------------------------------------------------------------------------------- /Models/trainingTemplateTF/train_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Foundry. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | ############################################################################## 15 | 16 | from __future__ import print_function 17 | from builtins import input, range # python 2/3 forward-compatible (input_raw, xrange) 18 | 19 | import sys 20 | import os 21 | import time 22 | import random 23 | import argparse 24 | from datetime import datetime 25 | 26 | import numpy as np 27 | import tensorflow as tf 28 | print(tf.__version__) 29 | tf.compat.v1.disable_eager_execution() # For TF 2.x compatibility 30 | 31 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 32 | from common.model_builder import EncoderDecoder 33 | from common.util import im2uint8, get_filepaths_from_dir, get_ckpt_list, print_ 34 | from common.util import is_exr, read_crop_exr_pair, linear_to_srgb 35 | 36 | def enable_deterministic_training(seed, no_gpu_patch=False): 37 | """Set all seeds for deterministic training 38 | This function needs to be called before any tensorflow code. 39 | """ 40 | import numpy as np 41 | import os 42 | import random 43 | import tfdeterminism 44 | if not no_gpu_patch: 45 | # Patch stock TensorFlow to have deterministic GPU operation 46 | tfdeterminism.patch() # then use tf as normal 47 | # If PYTHONHASHSEED environment variable is not set or set to random, 48 | # a random value is used to seed the hashes of str, bytes and datetime 49 | # objects. (Necessary for Python >= 3.2.3) 50 | os.environ['PYTHONHASHSEED']=str(seed) 51 | # Set python built-in pseudo-random generator at a fixed value 52 | random.seed(seed) 53 | # Set seed for random Numpy operation (e.g. np.random.randint) 54 | np.random.seed(seed) 55 | # Set seed for random TensorFlow operation (e.g. tf.image.random_crop) 56 | tf.compat.v1.random.set_random_seed(seed) 57 | 58 | class TrainModel(object): 59 | """Train the EncoderDecoder from the given input and groundtruth data""" 60 | 61 | def __init__(self, args): 62 | # Training hyperparameters 63 | self.learning_rate = args.learning_rate 64 | self.batch_size = args.batch_size 65 | self.epoch = args.epoch 66 | self.no_resume = args.no_resume 67 | # A random seed (!=None) allows you to reproduce your training results 68 | self.seed = args.seed 69 | if self.seed is not None: 70 | # Set all seeds necessary for deterministic training 71 | enable_deterministic_training(self.seed, args.no_gpu_patch) 72 | self.crop_size = 256 73 | self.n_levels = 3 74 | self.scale = 0.5 75 | self.channels = 3 # input / output channels 76 | # Training and validation dataset paths 77 | train_in_data_path = './data/train/input' 78 | train_gt_data_path = './data/train/groundtruth' 79 | val_in_data_path = './data/validation/input' 80 | val_gt_data_path = './data/validation/groundtruth' 81 | 82 | # Where to save and load model weights (=checkpoints) 83 | self.checkpoints_dir = './checkpoints' 84 | if not os.path.exists(self.checkpoints_dir): 85 | os.makedirs(self.checkpoints_dir) 86 | self.ckpt_save_name = args.ckpt_save_name 87 | # Maximum number of recent checkpoint files to keep 88 | self.max_ckpts_to_keep = 50 89 | # In addition keep one checkpoint file for every N hours of training 90 | self.keep_ckpt_every_n_hours = 1 91 | # How often, in training steps. we save model checkpoints 92 | self.ckpts_save_freq = 1000 93 | # How often, in training steps. 
we print training losses to bash 94 | self.training_print_freq = 10 95 | 96 | # Where to save tensorboard summaries 97 | self.summaries_dir = './summaries' 98 | if not os.path.exists(self.summaries_dir): 99 | os.makedirs(self.summaries_dir) 100 | # How often, in training steps. we save tensorboard summaries 101 | self.summaries_save_freq = 10 102 | # How often, in secs, we flush the pending tensorboard summaries to disk 103 | self.summary_flush_secs = 30 104 | 105 | # Get training dataset as lists of image paths 106 | self.train_in_data_list = get_filepaths_from_dir(train_in_data_path) 107 | self.train_gt_data_list = get_filepaths_from_dir(train_gt_data_path) 108 | if not self.train_in_data_list or not self.train_gt_data_list: 109 | raise ValueError("No training data found in folders {} or {}".format(train_in_data_path, train_gt_data_path)) 110 | elif len(self.train_in_data_list) != len(self.train_gt_data_list): 111 | raise ValueError("{} ({} data) and {} ({} data) should have the same number of input data" 112 | .format(train_in_data_path, len(self.train_in_data_list), train_gt_data_path, len(self.train_gt_data_list))) 113 | elif (len(self.train_in_data_list) < self.batch_size): 114 | raise ValueError("Batch size must be smaller than the dataset (batch size = {}, number of training data = {})" 115 | .format(self.batch_size, len(self.train_in_data_list))) 116 | self.is_exr = is_exr(self.train_in_data_list[0]) 117 | 118 | # Get validation dataset if provided 119 | self.has_val_data = True 120 | self.val_in_data_list = get_filepaths_from_dir(val_in_data_path) 121 | self.val_gt_data_list = get_filepaths_from_dir(val_gt_data_path) 122 | if not self.val_in_data_list or not self.val_gt_data_list: 123 | print("No validation data found in {} or {}".format(val_in_data_path, val_gt_data_path)) 124 | self.has_val_data = False 125 | elif len(self.val_in_data_list) != len(self.val_gt_data_list): 126 | raise ValueError("{} ({} data) and {} ({} data) should have the same number of input data" 127 | .format(val_in_data_path, len(self.val_in_data_list), val_gt_data_path, len(self.val_gt_data_list))) 128 | elif (len(self.val_in_data_list) < self.batch_size): 129 | raise ValueError("Batch size must be smaller than the dataset (batch size = {}, number of validation data = {})" 130 | .format(self.batch_size, len(self.val_in_data_list))) 131 | else: 132 | val_is_exr = is_exr(self.val_in_data_list[0]) 133 | if (val_is_exr and not self.is_exr) or (not val_is_exr and self.is_exr): 134 | raise TypeError("Train and validation data should have the same file format") 135 | print("Number of validation data: {}".format(len(self.val_in_data_list))) 136 | 137 | # Compute and print training hyperparameters 138 | batch_per_epoch = (len(self.train_in_data_list)) // self.batch_size 139 | self.max_steps = int(self.epoch * (batch_per_epoch)) 140 | print_("Number of training data: {}\nNumber of batches per epoch: {} (batch size = {})\nNumber of training steps for {} epochs: {}\n" 141 | .format(len(self.train_in_data_list), batch_per_epoch, self.batch_size, self.epoch, self.max_steps), 'm') 142 | 143 | def get_data(self, in_data_list, gt_data_list, batch_size=16, epoch=100): 144 | 145 | def read_and_preprocess(path_img_in, path_img_gt): 146 | if self.is_exr: # ['exr', 'EXR'] 147 | # Read and crop data 148 | img_crop = tf.numpy_function(read_crop_exr_pair, 149 | [path_img_in, path_img_gt, self.crop_size], [tf.float32, tf.float32]) 150 | img_crop = tf.numpy_function(linear_to_srgb, [img_crop], tf.float32) 151 | img_crop = 
tf.unstack(tf.reshape(img_crop, [2, self.crop_size, self.crop_size, self.channels])) 152 | else: # ['jpg', 'jpeg', 'png', 'bmp', 'JPG', 'JPEG', 'PNG', 'BMP'] 153 | # Read data 154 | img_in_raw = tf.io.read_file(path_img_in) 155 | img_gt_raw = tf.io.read_file(path_img_gt) 156 | img_in_tensor = tf.image.decode_image(img_in_raw, channels=self.channels) 157 | img_gt_tensor = tf.image.decode_image(img_gt_raw, channels=self.channels) 158 | # Normalise then crop data 159 | imgs = [tf.cast(img, tf.float32) / 255.0 for img in [img_in_tensor, img_gt_tensor]] 160 | img_crop = tf.unstack(tf.image.random_crop(tf.stack(imgs, axis=0), 161 | [2, self.crop_size, self.crop_size, self.channels], seed=self.seed), axis=0) 162 | return img_crop 163 | 164 | def multi_thread_preprocess(path_img_in, path_img_gt): 165 | """Non-random data preprocessing to be run in a multi-thread map 166 | Read image in path_img, and normalize it 167 | """ 168 | if self.is_exr: 169 | # Do nothing, all preprocessing done in single_thread_preprocess 170 | return path_img_in, path_img_gt 171 | else: 172 | img_in_raw = tf.io.read_file(path_img_in) 173 | img_gt_raw = tf.io.read_file(path_img_gt) 174 | img_in_tensor = tf.image.decode_image(img_in_raw, channels=self.channels) 175 | img_gt_tensor = tf.image.decode_image(img_gt_raw, channels=self.channels) 176 | # Normalise data 177 | imgs = [tf.cast(img, tf.float32) / 255.0 for img in [img_in_tensor, img_gt_tensor]] 178 | return imgs 179 | 180 | def single_thread_preprocess(img_in, img_gt): 181 | """Random data preprocessing to be run in a one thread map 182 | Crop image with deterministic TensorFlow (png) or Numpy (exr) seed 183 | """ 184 | if self.is_exr: 185 | img_crop = tf.numpy_function(read_crop_exr_pair, 186 | [img_in, img_gt, self.crop_size], [tf.float32, tf.float32]) 187 | img_crop = tf.numpy_function(linear_to_srgb, [img_crop], tf.float32) 188 | img_crop = tf.unstack(tf.reshape(img_crop, [2, self.crop_size, self.crop_size, self.channels])) 189 | else: 190 | img_crop = tf.unstack(tf.image.random_crop(tf.stack([img_in, img_gt], axis=0), 191 | [2, self.crop_size, self.crop_size, self.channels], seed=self.seed), axis=0) 192 | return img_crop 193 | 194 | with tf.compat.v1.variable_scope('input'): 195 | # Ensure preprocessing is done on the CPU (to let the GPU focus on training) 196 | with tf.device('/cpu:0'): 197 | in_list = tf.convert_to_tensor(in_data_list, dtype=tf.string) 198 | gt_list = tf.convert_to_tensor(gt_data_list, dtype=tf.string) 199 | 200 | path_dataset = tf.data.Dataset.from_tensor_slices((in_list, gt_list)) 201 | path_dataset = path_dataset.shuffle( 202 | buffer_size=len(in_data_list), seed=self.seed).repeat(epoch) 203 | # Apply read_and_preprocess function to all input in the path_dataset 204 | if self.seed is None: 205 | # Run all preprocessing in one dataset.map() 206 | num_parallel_calls = 1 if self.is_exr else 4 207 | dataset = path_dataset.map(read_and_preprocess, num_parallel_calls) 208 | else: 209 | # Perform the non-random ops in a multi-threaded map() 210 | dataset = path_dataset.map(multi_thread_preprocess, num_parallel_calls=4) 211 | # Perform the random ops in a single-threaded map() for 212 | # deterministic training when seed is not None 213 | dataset = dataset.map(single_thread_preprocess, num_parallel_calls=1) 214 | dataset = dataset.batch(batch_size) 215 | # Always prefetch one batch and make sure there is always one ready 216 | dataset = dataset.prefetch(buffer_size=1) 217 | # Create operator to iterate over the created dataset 218 | next_element = 
tf.compat.v1.data.make_one_shot_iterator(dataset).get_next() 219 | return next_element 220 | 221 | def loss(self, n_outputs, img_gt): 222 | """Compute multi-scale loss function""" 223 | loss_total = 0 224 | for i in range(self.n_levels): 225 | _, hi, wi, _ = n_outputs[i].shape 226 | gt_i = tf.image.resize(img_gt, [hi, wi], method='bilinear') 227 | loss = tf.reduce_mean(tf.square(gt_i - n_outputs[i])) 228 | loss_total += loss 229 | # Save out images and loss values to tensorboard 230 | tf.compat.v1.summary.image('out_' + str(i), im2uint8(n_outputs[i])) 231 | # Save total loss to tensorboard 232 | tf.compat.v1.summary.scalar('loss_total', loss_total) 233 | return loss_total 234 | 235 | def validate(self, model): 236 | total_val_loss = 0.0 237 | # Get next data from preprocessed validation dataset 238 | val_img_in, val_img_gt = self.get_data(self.val_in_data_list, self.val_gt_data_list, self.batch_size, -1) 239 | n_outputs = model(val_img_in, reuse=False) 240 | val_op = self.loss(n_outputs, val_img_gt) 241 | # Test results over one epoch 242 | batch_per_epoch = len(self.val_in_data_list) // self.batch_size 243 | for batch in range(batch_per_epoch): 244 | total_val_loss += val_op 245 | return total_val_loss / batch_per_epoch 246 | 247 | def train(self): 248 | # Build model 249 | model = EncoderDecoder(self.n_levels, self.scale, self.channels) 250 | 251 | # Learning rate decay 252 | global_step = tf.Variable(initial_value=0, dtype=tf.int32, trainable=False) 253 | self.lr = tf.compat.v1.train.polynomial_decay( 254 | self.learning_rate, global_step, 255 | decay_steps=self.max_steps, 256 | end_learning_rate=0.0, 257 | power=0.3) 258 | tf.compat.v1.summary.scalar('learning_rate', self.lr) 259 | # Training operator 260 | adam = tf.compat.v1.train.AdamOptimizer(self.lr) 261 | 262 | # Get next data from preprocessed training dataset 263 | img_in, img_gt = self.get_data( 264 | self.train_in_data_list, 265 | self.train_gt_data_list, 266 | self.batch_size, 267 | self.epoch) 268 | print('img_in, img_gt', img_in.shape, img_gt.shape) 269 | tf.compat.v1.summary.image('img_in', im2uint8(img_in)) 270 | tf.compat.v1.summary.image('img_gt', im2uint8(img_gt)) 271 | 272 | # Compute image loss 273 | n_outputs = model(img_in, reuse=False) 274 | loss_op = self.loss(n_outputs, img_gt) 275 | # By default, adam uses the current graph trainable_variables to optimise training, 276 | # thus train_op should be the last operation of the graph for training. 
277 | train_op = adam.minimize(loss_op, global_step) 278 | 279 | # Create session 280 | sess = tf.compat.v1.Session( 281 | config=tf.compat.v1.ConfigProto( 282 | gpu_options=tf.compat.v1.GPUOptions(allow_growth=True))) 283 | 284 | # Initialise all the variables in current session 285 | init = tf.compat.v1.global_variables_initializer() 286 | sess.run(init) 287 | self.saver = tf.compat.v1.train.Saver( 288 | max_to_keep=self.max_ckpts_to_keep, 289 | keep_checkpoint_every_n_hours=self.keep_ckpt_every_n_hours) 290 | 291 | # Check if there are intermediate trained model to load 292 | if self.no_resume or not self.load(sess, self.checkpoints_dir): 293 | print_("Starting training from scratch\n", 'm') 294 | 295 | # Tensorboard summary 296 | summary_op = tf.compat.v1.summary.merge_all() 297 | summary_name = ("data{}_bch{}_ep{}".format( 298 | len(self.train_in_data_list), self.batch_size, self.epoch)) 299 | summary_name += ("_seed{}".format(self.seed) if self.seed is not None else "") 300 | summary_writer = tf.compat.v1.summary.FileWriter( 301 | os.path.join(self.summaries_dir, summary_name), 302 | graph=sess.graph, 303 | flush_secs=self.summary_flush_secs) 304 | 305 | # Compute loss on validation dataset to check overfitting 306 | if self.has_val_data: 307 | val_loss_op = self.validate(model) 308 | # Save validation loss to tensorboard 309 | val_summary_op = tf.compat.v1.summary.scalar('val_loss', val_loss_op) 310 | # Compute initial loss 311 | val_loss, val_summary = sess.run([val_loss_op, val_summary_op]) 312 | summary_writer.add_summary(val_summary, global_step=0) 313 | print("Initial Loss on validation dataset: {:.6f}".format(val_loss)) 314 | 315 | ################ TRAINING ################ 316 | train_start = time.time() 317 | for step in range(sess.run(global_step), self.max_steps): 318 | start_time = time.time() 319 | val_str = '' 320 | if step % self.summaries_save_freq == 0 or step == self.max_steps - 1: 321 | # Train model and record summaries 322 | _, loss_total, summary = sess.run([train_op, loss_op, summary_op]) 323 | summary_writer.add_summary(summary, global_step=step) 324 | duration = time.time() - start_time 325 | if self.has_val_data and step != 0: 326 | # Compute validation loss 327 | val_loss, val_summary = sess.run([val_loss_op, val_summary_op]) 328 | summary_writer.add_summary(val_summary, global_step=step) 329 | val_str = ', val loss: {:.6f}'.format(val_loss) 330 | else: # Train only 331 | _, loss_total = sess.run([train_op, loss_op]) 332 | duration = time.time() - start_time 333 | assert not np.isnan(loss_total), 'Model diverged with loss = NaN' 334 | 335 | if step % self.training_print_freq == 0 or step == self.max_steps - 1: 336 | examples_per_sec = self.batch_size / duration 337 | sec_per_batch = float(duration) 338 | format_str = ('{}: step {}, loss: {:.6f} ({:.1f} data/s; {:.3f} s/bch)' 339 | .format(datetime.now().strftime('%Y-%m-%d %H:%M:%S'), step, loss_total, examples_per_sec, sec_per_batch)) 340 | print(format_str + val_str) 341 | 342 | if (step + 1) % self.ckpts_save_freq == 0 or step == self.max_steps - 1: 343 | # Save current model in a checkpoint 344 | self.save(sess, self.checkpoints_dir, step + 1) 345 | print_("Training duration: {:0.4f}s\n".format(time.time() - train_start), 'm') 346 | print_("--------End of training--------\n", 'm') 347 | # Free all resources associated with the session 348 | sess.close() 349 | 350 | def save(self, sess, checkpoint_dir, step): 351 | if not os.path.exists(checkpoint_dir): 352 | os.makedirs(checkpoint_dir) 353 | 
self.saver.save(sess, os.path.join(checkpoint_dir, self.ckpt_save_name), global_step=step) 354 | 355 | def load(self, sess, checkpoint_dir): 356 | ckpt_names = get_ckpt_list(checkpoint_dir) 357 | if not ckpt_names: # list is empty 358 | print_("No checkpoints found in {}\n".format(checkpoint_dir), 'm') 359 | return False 360 | else: 361 | print_("Found checkpoints:\n", 'm') 362 | for name in ckpt_names: 363 | print(" {}".format(name)) 364 | # Ask user if they prefer to start training from scratch or resume training on a specific ckeckpoint 365 | while True: 366 | mode=str(input('Start training from scratch (start) or resume training from a previous checkpoint (choose one of the above): ')) 367 | if mode == 'start' or mode in ckpt_names: 368 | break 369 | else: 370 | print("Answer should be 'start' or one of the following checkpoints: {}".format(ckpt_names)) 371 | continue 372 | if mode == 'start': 373 | return False 374 | elif mode in ckpt_names: 375 | # Try to load given intermediate checkpoint 376 | print_("Loading trained model...\n", 'm') 377 | self.saver.restore(sess, os.path.join(checkpoint_dir, mode)) 378 | print_("...Checkpoint {} loaded\n".format(mode), 'm') 379 | return True 380 | else: 381 | raise ValueError("User input is neither 'start' nor a valid checkpoint") 382 | 383 | def parse_args(): 384 | parser = argparse.ArgumentParser(description='Model training arguments') 385 | parser.add_argument('--bch', type=int, default=16, dest='batch_size', help='training batch size') 386 | parser.add_argument('--ep', type=int, default=10000, dest='epoch', help='training epoch number') 387 | parser.add_argument('--lr', type=float, default=1e-4, dest='learning_rate', help='initial learning rate') 388 | parser.add_argument('--seed', type=int, default=None, dest='seed', help='set random seed for deterministic training') 389 | parser.add_argument('--no-gpu-patch', dest='no_gpu_patch', default=False, action='store_true', help='if seed is set, add this tag for much faster but slightly less deterministic training') 390 | parser.add_argument('--no-resume', dest='no_resume', default=False, action='store_true', help="start training from scratch") 391 | parser.add_argument('--name', type=str, default="trainingTemplateTF", dest='ckpt_save_name', help='name of saved checkpoints/model weights') 392 | args = parser.parse_args() 393 | return args 394 | 395 | if __name__ == '__main__': 396 | args = parse_args() 397 | # set up model to train 398 | model = TrainModel(args) 399 | model.train() 400 | -------------------------------------------------------------------------------- /Plugins/Client/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # CMakeLists.txt for Machine Learning Plug-in: MLClient 2 | 3 | # Setting up MLClient sources and dependencies 4 | set (ML_CLIENT_SOURCES 5 | MLClient.cpp 6 | MLClientComms.cpp 7 | MLClientModelManager.cpp 8 | ) 9 | 10 | find_package(Protobuf REQUIRED) 11 | if (WIN32) 12 | find_library(PROTOBUF_LIBRARY NAME libprotobuf PATHS ${Protobuf_LIBRARIES}) 13 | endif() 14 | 15 | # Compile protobuf .cpp and .h files out of message.proto 16 | protobuf_generate_cpp(PROTO_SRCS PROTO_HDRS message.proto) 17 | list(APPEND ML_CLIENT_SOURCES ${PROTO_SRCS}) # add message.pb.cc 18 | 19 | if(NOT NUKE_INSTALL_PATH) 20 | message(FATAL_ERROR "Nuke install path not set.") 21 | endif() 22 | find_library(DDIMAGE_LIBRARY NAME DDImage libDDImage PATHS ${NUKE_INSTALL_PATH}) 23 | if(NOT DDIMAGE_LIBRARY) 24 | message(FATAL_ERROR "DDImage library not 
found.") 25 | endif() 26 | 27 | # Create MLClient.so shared library 28 | add_library(MLClient SHARED 29 | ${ML_CLIENT_SOURCES} 30 | ) 31 | 32 | set_target_properties (MLClient PROPERTIES PREFIX "") 33 | target_include_directories(MLClient PRIVATE 34 | ${NUKE_INSTALL_PATH}/include 35 | ${CMAKE_CURRENT_BINARY_DIR} # include message.pb.h 36 | ${Protobuf_INCLUDE_DIR} 37 | ) 38 | 39 | target_link_libraries(MLClient 40 | ${PROTOBUF_LIBRARY} 41 | ${DDIMAGE_LIBRARY} 42 | ) 43 | 44 | if (WIN32) 45 | target_link_libraries(MLClient 46 | ws2_32.lib # include windows socket library 47 | ) 48 | endif (WIN32) -------------------------------------------------------------------------------- /Plugins/Client/MLClient.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019 Foundry. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | //************************************************************************* 15 | 16 | #ifndef MLCLIENT_H 17 | #define MLCLIENT_H 18 | 19 | // Standard plug-in include files. 20 | #include "DDImage/PlanarIop.h" 21 | #include "DDImage/NukeWrapper.h" 22 | #include "DDImage/Row.h" 23 | #include "DDImage/Tile.h" 24 | #include "DDImage/Knobs.h" 25 | #include "DDImage/Thread.h" 26 | #include 27 | 28 | // Local include files 29 | #include "MLClientComms.h" 30 | #include "MLClientModelManager.h" 31 | 32 | //! The Machine Learning (ML) Client plug-in connects Nuke to a Python server to apply ML models to images. 33 | /*! This plug-in can connect to a server (given a host and port), which responds 34 | with a list of available Machine Learning (ML) models and options. 35 | On every /a renderStripe() call, the image and model options are sent from Nuke to the server, 36 | there the server can process the image by doing Machine Learning inference, 37 | finally the resulting image is sent back to Nuke. 38 | */ 39 | class MLClient : public DD::Image::PlanarIop 40 | { 41 | 42 | public: 43 | // Static consts 44 | static const char* const kClassName; 45 | static const char* const kHelpString; 46 | 47 | static const char* const kDefaultHostName; 48 | static const int kDefaultPortNumber; 49 | 50 | private: 51 | static const DD::Image::ChannelSet kDefaultChannels; 52 | static const int kDefaultNumberOfChannels; 53 | 54 | public: 55 | //! Constructor. Initialize user controls to their default values. 56 | MLClient(Node* node); 57 | virtual ~MLClient(); 58 | 59 | public: 60 | // DDImage::Iop overrides 61 | 62 | //! The maximum number of input connections the operator can have. 63 | int maximum_inputs() const; 64 | //! The minimum number of input connections the operator can have. 65 | int minimum_inputs() const; 66 | /*! Return the text Nuke should draw on the arrow head for input \a input 67 | in the DAG window. This should be a very short string, one letter 68 | ideally. Return null or an empty string to not label the arrow. 
69 | */ 70 | const char* input_label(int input, char* buffer) const; 71 | 72 | bool useStripes() const; 73 | bool renderFullPlanes() const; 74 | 75 | void _validate(bool); 76 | void getRequests(const DD::Image::Box& box, const DD::Image::ChannelSet& channels, int count, DD::Image::RequestOutput &reqData) const; 77 | 78 | /*! This function is called by Nuke for processing the current image. 79 | The image and model options are sent from Nuke to the server, 80 | there the server can process the image by doing Machine Learning inference, 81 | finally the resulting image is sent back to Nuke. 82 | The function tries to reconnect if no connection is set. 83 | */ 84 | void renderStripe(DD::Image::ImagePlane& imagePlane); 85 | 86 | //! Information to the plug-in manager of DDNewImage/Nuke. 87 | static const DD::Image::Iop::Description description; 88 | 89 | static void addDynamicKnobs(void*, DD::Image::Knob_Callback); 90 | void knobs(DD::Image::Knob_Callback f); 91 | int knob_changed(DD::Image::Knob*); 92 | 93 | //! Return the name of the class. 94 | const char* Class() const; 95 | const char* node_help() const; 96 | 97 | MLClientModelManager& getModelManager(); 98 | int getNumNewKnobs(); 99 | void setNumNewKnobs(int i); 100 | 101 | private: 102 | // Private functions for talking to the server 103 | //! Try connect to the server and set-up the relevant knobs. Return true on 104 | //! success, false otherwise and setting a descriptive error in errorMsg. 105 | bool refreshModelsAndKnobsFromServer(std::string& errorMsg); 106 | 107 | //! Return whether we successfully managed to pull model 108 | //! info from the server at some time in the past, and the selected model is 109 | //! valid. 110 | bool haveValidModelInfo() const; 111 | 112 | //! Connect to server, then send inference request and read inference response. 113 | //! Return true on success, false otherwise filling in the errorMsg. 114 | bool processImage(const std::string& hostStr, int port, mlserver::RespondWrapper& responseWrapper, std::string& errorMsg); 115 | 116 | //! Parse the response messge from the server, and if it contains 117 | //! an image, attempt to copy the image to the imagePlane. Return 118 | //! true on success, false otherwise and fill in the error string. 119 | bool renderOutputBuffer(mlserver::RespondWrapper& responseWrapper, DD::Image::ImagePlane& imagePlane, std::string& errorMsg); 120 | 121 | //! Return whether the dynamic knobs should be shown or not. 122 | bool getShowDynamic() const; 123 | 124 | //! Look for the knob with the given name. If found, restore its value 125 | //! from the given serialised value. 126 | void restoreKnobValue(const std::string& knobName, const std::string& value); 127 | 128 | private: 129 | std::string _host; 130 | bool _hostIsValid; 131 | int _port; 132 | bool _portIsValid; 133 | int _chosenModel; 134 | bool _modelSelected; 135 | 136 | DD::Image::Knob* _selectedModelknob; 137 | std::vector _serverModels; 138 | 139 | std::vector _numInputs; 140 | std::vector> _inputNames; 141 | 142 | MLClientModelManager _modelManager; 143 | 144 | bool _showDynamic; 145 | int _numNewKnobs; // Used to track the number of knobs created by the previous pass, so that the same number can be deleted next time. 146 | 147 | }; 148 | 149 | #endif // MLCLIENT_H 150 | -------------------------------------------------------------------------------- /Plugins/Client/MLClientComms.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019 Foundry. 
2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | //************************************************************************* 15 | 16 | #ifndef MLCLIENTCOMMS_H 17 | #define MLCLIENTCOMMS_H 18 | 19 | // Protobuf headers 20 | #include "message.pb.h" 21 | 22 | using byte = unsigned char; 23 | 24 | 25 | //! The Machine Learning (ML) Client plug-in connects Nuke to a Python server to apply ML models to images. 26 | /*! This plug-in can connect to a server (given a host and port), which responds 27 | with a list of available Machine Learning (ML) models and options. 28 | On every /a renderStripe() call, the image and model options are sent from Nuke to the server, 29 | there the server can process the image by doing Machine Learning inference, 30 | finally the resulting image is sent back to Nuke. 31 | */ 32 | class MLClientComms 33 | { 34 | public: 35 | // Static consts 36 | static const int kNumberOfBytesHeaderSize; 37 | 38 | static const int kTimeout; 39 | static const int kMaxNumberOfTry; 40 | 41 | // Static non-conts 42 | static bool Verbose; 43 | 44 | public: 45 | //! Constructor. Initialize user controls to their default values, then try to 46 | //! connect to the specified host / port. Following the c-tor, you can test for 47 | //! a valid connection by calling isConnected(). 48 | MLClientComms(const std::string& hostStr, int port); 49 | 50 | //! Destructor. Tear down any existing connection. 51 | virtual ~MLClientComms(); 52 | 53 | public: 54 | // Public static methods for client-server communication 55 | 56 | //! Test if a given hostname is valid, returning true if it is, false otherwise 57 | static bool ValidateHostName(const std::string& hostStr); 58 | 59 | //! Print debug related information to std::cout, when ::Verbose is set to true. 60 | static void Vprint(std::string msg); 61 | 62 | public: 63 | // Public methods for client-server communication 64 | 65 | //! Return whether this object is connected to the specified server. 66 | bool isConnected() const; 67 | 68 | //! Function for discovering & negotiating the available models and their parameters. 69 | //! Return true on success, false otherwise with the errorMsg filled in. 70 | bool sendInfoRequestAndReadInfoResponse(mlserver::RespondWrapper& responseWrapper, std::string& errorMsg); 71 | 72 | //! Function for performing the inference on a selected model. 73 | //! Return true on success, false otherwise with the errorMsg filled in. 74 | bool sendInferenceRequestAndReadInferenceResponse(mlserver::RequestInference& requestInference, mlserver::RespondWrapper& responseWrapper, std::string& errorMsg); 75 | 76 | private: 77 | // Private client / server comms functions 78 | 79 | //! Try to connect to the server with the specified hostStr & port, by repeatedly 80 | //! calling setupConnection() below until a connection is made or times out. After it 81 | //! returns, you can test if it was successful by calling isConnected(). 82 | void connectLoop(); 83 | 84 | //! 
Create a socket to connect to the server specified by hostStr and port. 85 | //! Return true on success, false otherwise with a message filled in errorStr. 86 | bool setupConnection(std::string& errorStr); 87 | 88 | //! Request the server to return a future message about its models. This is used 89 | //! to instruct the server that it should set itself up. 90 | //! Return true on success, false otherwise. 91 | bool sendInfoRequest(); 92 | 93 | //! Retrieve the response from the server and store it in responseWrapper, to be parsed 94 | //! elsewhere. Return true on success, false otherwise. 95 | bool readInfoResponse(mlserver::RespondWrapper& responseWrapper); 96 | 97 | //! Send a messaged image to to the server. Return true on success, false otherwise. 98 | bool sendInferenceRequest(mlserver::RequestInference& requestInference); 99 | 100 | //! Marshall the returned image into a float buffer of the original image size. Note, this 101 | //! expects the size of result to have been set to the same size as the image that was 102 | //! previously sent to the server. Return true on success, false otherwise. 103 | bool readInferenceResponse(mlserver::RespondWrapper& responseWrapper); 104 | 105 | //! Pull the data after determining the size 'siz' from the header. 106 | //! Helper to the above 'readInfoResponse' function. 107 | bool readInfoResponse(google::protobuf::uint32 siz, mlserver::RespondWrapper& responseWrapper); 108 | 109 | //! Pull the data after determining the size 'siz' from the header. 110 | //! Helper to the above 'readInferenceResponse' function. 111 | bool readInferenceResponse(google::protobuf::uint32 siz, mlserver::RespondWrapper& responseWrapper); 112 | 113 | //! Close the current connection if one is open. 114 | void closeConnection(); 115 | 116 | private: 117 | // Private helper functions 118 | google::protobuf::uint32 readHdr(char* buf); 119 | void* getInAddr(struct sockaddr* sa); 120 | 121 | private: 122 | // Private member variables 123 | std::string _hostStr; 124 | int _port; 125 | 126 | bool _isConnected; 127 | int _socket; 128 | }; 129 | 130 | #endif // MLCLIENTCOMMS_H 131 | -------------------------------------------------------------------------------- /Plugins/Client/MLClientModelManager.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019 Foundry. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | //************************************************************************* 15 | 16 | #include "MLClientModelManager.h" 17 | #include "DDImage/Knob.h" 18 | #include "MLClient.h" 19 | 20 | MLClientModelKnob::MLClientModelKnob(DD::Image::Knob_Closure* kc, DD::Image::Op* op, const char* name) 21 | : DD::Image::Knob(kc, name) 22 | , _op(op) 23 | , _model("") 24 | { } 25 | 26 | const char* MLClientModelKnob::Class() const 27 | { 28 | return "MLClientModelKnob"; 29 | } 30 | 31 | bool MLClientModelKnob::not_default () const 32 | { 33 | // Always flag as not default, so it's always serialised. 34 | return true; 35 | } 36 | 37 | std::string MLClientModelKnob::getModel() const 38 | { 39 | return _model; 40 | } 41 | 42 | const std::map<std::string, std::string>& MLClientModelKnob::getParameters() const 43 | { 44 | return _parameters; 45 | } 46 | 47 | void MLClientModelKnob::to_script (std::ostream &out, const DD::Image::OutputContext *, bool quote) const 48 | { 49 | std::string saveString; 50 | std::stringstream ss; 51 | if (_op != nullptr) { 52 | DD::Image::Knob* k = _op->knob("models"); 53 | if(k != nullptr) { 54 | const int modelIndex = k->get_value(); 55 | DD::Image::Enumeration_KnobI* eKnob = k->enumerationKnob(); 56 | if(eKnob != nullptr) { 57 | ss << "model:" << eKnob->getItemValueString(modelIndex) << ";"; 58 | MLClient* mlClient = dynamic_cast<MLClient*>(_op); 59 | if(mlClient != nullptr) { 60 | MLClientModelManager& mlManager = mlClient->getModelManager(); 61 | toScriptT(mlManager, ss, &MLClientModelManager::getNumOfInts, &MLClientModelManager::getDynamicIntName); 62 | toScriptT(mlManager, ss, &MLClientModelManager::getNumOfFloats, &MLClientModelManager::getDynamicFloatName); 63 | toScriptT(mlManager, ss, &MLClientModelManager::getNumOfBools, &MLClientModelManager::getDynamicBoolName); 64 | toScriptStrings(mlManager, ss); 65 | } 66 | } 67 | } 68 | } 69 | saveString = ss.str(); 70 | if(quote) { 71 | saveString.insert(saveString.begin(),'{'); 72 | saveString+='}'; 73 | } 74 | out << saveString; 75 | } 76 | 77 | bool MLClientModelKnob::from_script(const char * src) 78 | { 79 | std::string loadString(src); 80 | 81 | if ((_op != nullptr) && (loadString!="")) { 82 | bool success = false; 83 | 84 | // We parse the serialised string to extract the pairs of key:val; 85 | const std::string delimiter = ";"; 86 | const std::string keyValDelimiter = ":"; 87 | _parameters.clear(); 88 | size_t pos = 0; 89 | std::string token; 90 | while ((pos = loadString.find(delimiter)) != std::string::npos) { 91 | token = loadString.substr(0, pos); 92 | std::cout << token << std::endl; 93 | 94 | // We further split the key:value pair 95 | std::string key = token.substr(0, token.find(keyValDelimiter)); 96 | std::string val = token.substr(token.find(keyValDelimiter) + keyValDelimiter.length(), token.length() - key.length() - keyValDelimiter.length()); 97 | if(key == "model") { 98 | _model = val; 99 | } else { 100 | _parameters.insert(std::make_pair(key, val)); 101 | } 102 | 103 | loadString.erase(0, pos + delimiter.length()); 104 | } 105 | 106 | return success; 107 | } 108 | return true; 109 | } 110 | 111 | void MLClientModelKnob::toScriptT(MLClientModelManager& mlManager, std::ostream &out, 112 | int (MLClientModelManager::*getNum)() const, 113 | std::string (MLClientModelManager::*getDynamicName)(int)) const 114 | { 115 | const int num = (mlManager.*getNum)(); 116 | for(int i = 0; i < num; i++) { 117 | DD::Image::Knob* k = _op->knob((mlManager.*getDynamicName)(i).c_str()); 118 | if(k != nullptr) { 119 | std::stringstream ss; 120 | k->to_script(ss,
nullptr, false); 121 | out << (mlManager.*getDynamicName)(i) << ":" << ss.str() << ";"; 122 | } 123 | } 124 | } 125 | 126 | void MLClientModelKnob::toScriptStrings(MLClientModelManager& mlManager, std::ostream &out) const 127 | { 128 | const int numFloats = mlManager.getNumOfStrings(); 129 | for(int i = 0; i < numFloats; i++) { 130 | DD::Image::Knob* k = _op->knob(mlManager.getDynamicStringName(i).c_str()); 131 | if(k != nullptr) { 132 | out << mlManager.getDynamicStringName(i) << ":" << k->get_text() << ";"; 133 | } 134 | } 135 | } 136 | 137 | MLClientModelManager::MLClientModelManager(DD::Image::Op* parent) 138 | : _parent(parent) 139 | { } 140 | 141 | MLClientModelManager::~MLClientModelManager() 142 | { } 143 | 144 | //! Parse options from the server model /m to the MLClientModelManager 145 | void MLClientModelManager::parseOptions(const mlserver::Model& m) 146 | { 147 | clear(); 148 | 149 | for (int i = 0, endI = m.bool_options_size(); i < endI; i++) { 150 | mlserver::BoolAttrib option; 151 | option = m.bool_options(i); 152 | if (option.values(0)) { 153 | _dynamicBoolValues.push_back(1); 154 | } 155 | else { 156 | _dynamicBoolValues.push_back(0); 157 | } 158 | _dynamicBoolNames.push_back(option.name()); 159 | } 160 | for (int i = 0, endI = m.int_options_size(); i < endI; i++) { 161 | mlserver::IntAttrib option; 162 | option = m.int_options(i); 163 | _dynamicIntValues.push_back(option.values(0)); 164 | _dynamicIntNames.push_back(option.name()); 165 | } 166 | for (int i = 0, endI = m.float_options_size(); i < endI; i++) { 167 | mlserver::FloatAttrib option; 168 | option = m.float_options(i); 169 | _dynamicFloatValues.push_back(option.values(0)); 170 | _dynamicFloatNames.push_back(option.name()); 171 | } 172 | for (int i = 0, endI = m.string_options_size(); i < endI; i++) { 173 | mlserver::StringAttrib option; 174 | option = m.string_options(i); 175 | _dynamicStringValues.push_back(option.values(0)); 176 | _dynamicStringNames.push_back(option.name()); 177 | } 178 | for (int i = 0, endI = m.button_options_size(); i < endI; i++) { 179 | mlserver::BoolAttrib option; 180 | option = m.button_options(i); 181 | if (option.values(0)) { 182 | _dynamicButtonValues.push_back(1); 183 | } 184 | else { 185 | _dynamicButtonValues.push_back(0); 186 | } 187 | _dynamicButtonNames.push_back(option.name()); 188 | } 189 | } 190 | 191 | //! Use current knob values to update options on the server model /m 192 | //! 
in order to later request inference on this model 193 | void MLClientModelManager::updateOptions(mlserver::Model& m) 194 | { 195 | m.clear_bool_options(); 196 | for (int i = 0; i < _dynamicBoolValues.size(); i++) { 197 | mlserver::BoolAttrib* option = m.add_bool_options(); 198 | option->set_name(_dynamicBoolNames[i]); 199 | DD::Image::Knob* k = _parent->knob(_dynamicBoolNames[i].c_str()); 200 | bool val = false; 201 | if (k != nullptr) { 202 | val = k->get_value(); 203 | } 204 | option->add_values(val); 205 | } 206 | 207 | m.clear_int_options(); 208 | for (int i = 0; i < _dynamicIntValues.size(); i++) { 209 | mlserver::IntAttrib* option = m.add_int_options(); 210 | option->set_name(_dynamicIntNames[i]); 211 | DD::Image::Knob* k = _parent->knob(_dynamicIntNames[i].c_str()); 212 | int val = 0; 213 | if (k != nullptr) { 214 | val = k->get_value(); 215 | } 216 | option->add_values(val); 217 | } 218 | 219 | m.clear_float_options(); 220 | for (int i = 0; i < _dynamicFloatValues.size(); i++) { 221 | mlserver::FloatAttrib* option = m.add_float_options(); 222 | option->set_name(_dynamicFloatNames[i]); 223 | DD::Image::Knob* k = _parent->knob(_dynamicFloatNames[i].c_str()); 224 | float val = 0.0f; 225 | if (k != nullptr) { 226 | val = k->get_value(); 227 | } 228 | option->add_values(val); 229 | } 230 | 231 | m.clear_string_options(); 232 | for (int i = 0; i < _dynamicStringValues.size(); i++) { 233 | mlserver::StringAttrib* option = m.add_string_options(); 234 | option->set_name(_dynamicStringNames[i]); 235 | DD::Image::Knob* k = _parent->knob(_dynamicStringNames[i].c_str()); 236 | const char* val = ""; 237 | if(k != nullptr) { 238 | val = k->get_text(); 239 | if (val==nullptr) { 240 | val = ""; 241 | } 242 | } 243 | option->add_values(val); 244 | } 245 | 246 | m.clear_button_options(); 247 | for (int i = 0; i < _dynamicButtonValues.size(); i++) { 248 | mlserver::BoolAttrib* option = m.add_button_options(); 249 | option->set_name(_dynamicButtonNames[i]); 250 | // Get member value instead of knob value to catch button push 251 | option->add_values(_dynamicButtonValues[i]); 252 | } 253 | } 254 | 255 | int MLClientModelManager::getNumOfFloats() const 256 | { 257 | return _dynamicFloatValues.size(); 258 | } 259 | 260 | int MLClientModelManager::getNumOfInts() const 261 | { 262 | return _dynamicIntValues.size(); 263 | } 264 | 265 | int MLClientModelManager::getNumOfBools() const 266 | { 267 | return _dynamicBoolValues.size(); 268 | } 269 | 270 | int MLClientModelManager::getNumOfStrings() const 271 | { 272 | return _dynamicStringValues.size(); 273 | } 274 | 275 | int MLClientModelManager::getNumOfButtons() const 276 | { 277 | return _dynamicButtonValues.size(); 278 | } 279 | 280 | std::string MLClientModelManager::getDynamicBoolName(int idx) 281 | { 282 | return _dynamicBoolNames[idx]; 283 | } 284 | 285 | std::string MLClientModelManager::getDynamicFloatName(int idx) 286 | { 287 | return _dynamicFloatNames[idx]; 288 | } 289 | 290 | std::string MLClientModelManager::getDynamicIntName(int idx) 291 | { 292 | return _dynamicIntNames[idx]; 293 | } 294 | 295 | std::string MLClientModelManager::getDynamicStringName(int idx) 296 | { 297 | return _dynamicStringNames[idx]; 298 | } 299 | 300 | std::string MLClientModelManager::getDynamicButtonName(int idx) 301 | { 302 | return _dynamicButtonNames[idx]; 303 | } 304 | 305 | float* MLClientModelManager::getDynamicFloatValue(int idx) 306 | { 307 | return &_dynamicFloatValues[idx]; 308 | } 309 | 310 | int* MLClientModelManager::getDynamicIntValue(int idx) 311 | { 312 | 
return &_dynamicIntValues[idx]; 313 | } 314 | 315 | bool* MLClientModelManager::getDynamicBoolValue(int idx) 316 | { 317 | return (bool*)&_dynamicBoolValues[idx]; 318 | } 319 | 320 | std::string* MLClientModelManager::getDynamicStringValue(int idx) 321 | { 322 | return &_dynamicStringValues[idx]; 323 | } 324 | 325 | bool* MLClientModelManager::getDynamicButtonValue(int idx) 326 | { 327 | return (bool*)&_dynamicButtonValues[idx]; 328 | } 329 | 330 | void MLClientModelManager::setDynamicButtonValue(int idx, int value) 331 | { 332 | _dynamicButtonValues[idx] = value; 333 | } 334 | 335 | void MLClientModelManager::clear() 336 | { 337 | _dynamicBoolValues.clear(); 338 | _dynamicIntValues.clear(); 339 | _dynamicFloatValues.clear(); 340 | _dynamicStringValues.clear(); 341 | _dynamicButtonValues.clear(); 342 | 343 | _dynamicBoolNames.clear(); 344 | _dynamicIntNames.clear(); 345 | _dynamicFloatNames.clear(); 346 | _dynamicStringNames.clear(); 347 | _dynamicButtonNames.clear(); 348 | } 349 | -------------------------------------------------------------------------------- /Plugins/Client/MLClientModelManager.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019 Foundry. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | //************************************************************************* 15 | 16 | #ifndef MLClientModelManager_H 17 | #define MLClientModelManager_H 18 | 19 | #include <string> 20 | #include <vector> 21 | #include "message.pb.h" 22 | 23 | #include "DDImage/Op.h" 24 | #include "DDImage/Knobs.h" 25 | 26 | class MLClientModelManager; 27 | 28 | //! The role of this custom knob is to serialise and to store the selected model and its parameters. 29 | //! As these exist as dynamic knobs, this is to work around the fact that we would need to know about these knobs 30 | //! in advance to save them the usual way. 31 | class MLClientModelKnob : public DD::Image::Knob 32 | { 33 | public: 34 | MLClientModelKnob(DD::Image::Knob_Closure* kc, DD::Image::Op* op, const char* name); 35 | 36 | // Knob overrides. 37 | const char* Class() const override; 38 | bool not_default () const override; 39 | //! Serialises the currently selected model and its parameters as follows: 40 | //! {model:modelName;param1:value1;param2:value2} 41 | void to_script (std::ostream &out, const DD::Image::OutputContext *, bool quote) const override; 42 | //! Deserialises the saved model and its parameters. 43 | //! The model can then be retrieved with getModel() 44 | //! and the dictionary of parameters with getParameters(). 45 | bool from_script(const char * src) override; 46 | 47 | std::string getModel() const; 48 | const std::map<std::string, std::string>& getParameters() const; 49 | 50 | private: 51 | //! Serialises the dynamic knobs to the given output stream. 52 | //! This function is generic for the Ints, Floats and Bools knobs 53 | //! provided that the corresponding getNumOfT and getDynamicTName 54 | //! functions are given.
55 | void toScriptT(MLClientModelManager& mlManager, std::ostream &out, 56 | int (MLClientModelManager::*getNum)() const, 57 | std::string (MLClientModelManager::*getDynamicName)(int)) const; 58 | //! Serialises the dynamic knobs containing strings to the given output stream. 59 | void toScriptStrings(MLClientModelManager& mlManager, std::ostream &out) const; 60 | 61 | private: 62 | DD::Image::Op* _op; 63 | std::string _model; 64 | std::map<std::string, std::string> _parameters; 65 | }; 66 | 67 | //! Class to parse and store knobs for a given model. 68 | class MLClientModelManager 69 | { 70 | public: 71 | explicit MLClientModelManager(DD::Image::Op* parent); 72 | ~MLClientModelManager(); 73 | 74 | // Getters of the class 75 | int getNumOfFloats() const; 76 | int getNumOfInts() const; 77 | int getNumOfBools() const; 78 | int getNumOfStrings() const; 79 | int getNumOfButtons() const; 80 | 81 | std::string getDynamicBoolName(int idx); 82 | std::string getDynamicFloatName(int idx); 83 | std::string getDynamicIntName(int idx); 84 | std::string getDynamicStringName(int idx); 85 | std::string getDynamicButtonName(int idx); 86 | 87 | float* getDynamicFloatValue(int idx); 88 | int* getDynamicIntValue(int idx); 89 | bool* getDynamicBoolValue(int idx); 90 | std::string* getDynamicStringValue(int idx); 91 | bool* getDynamicButtonValue(int idx); 92 | void setDynamicButtonValue(int idx, int value); 93 | 94 | void clear(); 95 | //! Parse the model options from the ML server. 96 | void parseOptions(const mlserver::Model& m); 97 | //! Update any current options from any changes to the ML server. 98 | void updateOptions(mlserver::Model& m); 99 | 100 | private: 101 | DD::Image::Op* _parent; 102 | std::vector<int> _dynamicBoolValues; 103 | std::vector<int> _dynamicIntValues; 104 | std::vector<float> _dynamicFloatValues; 105 | std::vector<std::string> _dynamicStringValues; 106 | std::vector<int> _dynamicButtonValues; 107 | std::vector<std::string> _dynamicBoolNames; 108 | std::vector<std::string> _dynamicIntNames; 109 | std::vector<std::string> _dynamicFloatNames; 110 | std::vector<std::string> _dynamicStringNames; 111 | std::vector<std::string> _dynamicButtonNames; 112 | }; 113 | 114 | #endif 115 | -------------------------------------------------------------------------------- /Plugins/Client/message.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package mlserver; 4 | 5 | message RequestWrapper { 6 | optional bool info = 1; 7 | optional RequestInfo r1 = 2; 8 | optional RequestInference r2 = 3; 9 | } 10 | 11 | message RespondWrapper { 12 | optional bool info = 1; 13 | optional RespondInfo r1 = 2; 14 | optional RespondInference r2 = 3; 15 | optional Error error = 4; 16 | } 17 | 18 | message RequestInfo { 19 | optional bool info = 1; 20 | } 21 | 22 | message RespondInfo { 23 | optional int32 num_models = 1; 24 | repeated Model models = 2; 25 | } 26 | 27 | message Model { 28 | optional string name = 1; 29 | optional string label = 2; 30 | repeated ImagePrototype inputs = 3; 31 | repeated ImagePrototype outputs = 4; 32 | repeated BoolAttrib bool_options = 5; 33 | repeated IntAttrib int_options = 6; 34 | repeated FloatAttrib float_options = 7; 35 | repeated StringAttrib string_options = 8; 36 | repeated BoolAttrib button_options = 9; 37 | repeated MultipleChoiceOption mc_options = 10; 38 | } 39 | 40 | message MultipleChoiceOption { 41 | optional string name = 1; 42 | optional string value = 2; 43 | repeated string choices = 3; 44 | } 45 | 46 | message ImagePrototype { 47 | optional string name = 1; 48 | optional int32 channels = 2; 49 | } 50 | 51 | message Error { 52
| optional string msg = 1; 53 | } 54 | 55 | message RequestInference { 56 | optional Model model = 1; 57 | repeated Image images = 2; 58 | } 59 | 60 | message RespondInference { 61 | optional int32 num_images = 1; 62 | repeated Image images = 2; 63 | optional int32 num_objects = 3; 64 | repeated FieldValuePairAttrib objects = 4; 65 | } 66 | 67 | message Image { 68 | optional int32 width = 1; 69 | optional int32 height = 2; 70 | optional int32 channels = 3; 71 | optional bytes image = 4; 72 | } 73 | 74 | message BoolAttrib { 75 | optional string name = 1; 76 | repeated bool values = 2 [packed=true]; 77 | } 78 | 79 | message IntAttrib { 80 | optional string name = 1; 81 | repeated int32 values = 2 [packed=true]; 82 | } 83 | 84 | message FloatAttrib { 85 | optional string name = 1; 86 | repeated float values = 2 [packed=true]; 87 | } 88 | 89 | message StringAttrib { 90 | optional string name = 1; 91 | repeated string values = 2; 92 | } 93 | 94 | message FieldValuePairAttrib { 95 | optional string name = 1; 96 | repeated FieldValuePair values = 2; 97 | } 98 | 99 | message FieldValuePair { 100 | repeated IntAttrib int_attributes = 1; 101 | repeated FloatAttrib float_attributes = 2; 102 | repeated StringAttrib string_attributes = 3; 103 | repeated FieldValuePairAttrib children = 4; 104 | } -------------------------------------------------------------------------------- /Plugins/Server/.dockerignore: -------------------------------------------------------------------------------- 1 | # Ignore when building the docker image 2 | .dockerignore 3 | Dockerfile -------------------------------------------------------------------------------- /Plugins/Server/Dockerfile: -------------------------------------------------------------------------------- 1 | # Ubuntu 18.04 with CUDA 10.0, CuDNN 7.6 2 | # Python3.6, TensorFlow 1.15.0, PyTorch 1.4 3 | FROM nvidia/cuda:10.0-cudnn7-devel-ubuntu18.04 4 | 5 | ARG DEBIAN_FRONTEND=noninteractive 6 | RUN apt-get update && apt-get install -y --no-install-recommends \ 7 | build-essential \ 8 | ca-certificates \ 9 | cmake \ 10 | curl \ 11 | git \ 12 | libglib2.0-0 \ 13 | libjpeg-dev \ 14 | libopencv-dev \ 15 | libopenexr-dev \ 16 | libpng-dev \ 17 | libsm-dev \ 18 | vim && \ 19 | rm -rf /var/lib/apt/lists/* 20 | 21 | # Install Python 3.6 22 | RUN apt-get update && apt-get install -y --no-install-recommends \ 23 | python3-opencv \ 24 | python3-pip \ 25 | python3.6-dev && \ 26 | rm -rf /var/lib/apt/lists/* 27 | # Have aliases python3->python and pip3->pip 28 | RUN ln -s /usr/bin/python3 /usr/bin/python && \ 29 | ln -s /usr/bin/pip3 /usr/bin/pip 30 | RUN python -m pip install --upgrade pip 31 | 32 | RUN pip install --no-cache-dir setuptools wheel && \ 33 | pip install --no-cache-dir \ 34 | future \ 35 | gast==0.2.2 \ 36 | protobuf \ 37 | pyyaml==3.13 \ 38 | scikit-image \ 39 | typing \ 40 | imageio \ 41 | OpenEXR 42 | 43 | # Install TF 1.15.0 GPU for Python3.6 (no TensorRT) 44 | RUN pip install --no-cache-dir \ 45 | tensorflow-gpu==1.15.0 \ 46 | tensorflow-determinism 47 | 48 | # Install PyTorch (include Caffe2) for CUDA 10.0 49 | RUN pip install --no-cache-dir torch==1.4.0+cu100 torchvision==0.5.0+cu100 -f https://download.pytorch.org/whl/torch_stable.html 50 | RUN pip install --no-cache-dir cupy-cuda100 51 | RUN pip install --no-cache-dir cython 52 | 53 | WORKDIR /workspace 54 | # Install the COCO API 55 | RUN pip install 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI' 56 | 57 | # Install detectron for mask RCNN 58 | RUN git clone 
https://github.com/facebookresearch/detectron 59 | RUN sed -i 's/cythonize(ext_modules)/cythonize(ext_modules, language_level="3")/g' detectron/setup.py 60 | RUN cd detectron && pip install -r requirements.txt && make 61 | 62 | WORKDIR /workspace/ml-server 63 | # Copy your current folder to the docker image /workspace/ml-server/ folder 64 | COPY . . -------------------------------------------------------------------------------- /Plugins/Server/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TheFoundryVisionmongers/nuke-ML-server/5dd4d04cd673c60de8093c600d6c54016fca92d6/Plugins/Server/__init__.py -------------------------------------------------------------------------------- /Plugins/Server/py2.Dockerfile: -------------------------------------------------------------------------------- 1 | # Ubuntu 18.04 with CUDA 10.0, CuDNN 7.6 2 | # Python2.7, TensorFlow 1.15.0, PyTorch 1.4 3 | FROM nvidia/cuda:10.0-cudnn7-devel-ubuntu18.04 4 | 5 | ARG DEBIAN_FRONTEND=noninteractive 6 | RUN apt-get update && apt-get install -y --no-install-recommends \ 7 | build-essential \ 8 | ca-certificates \ 9 | cmake \ 10 | curl \ 11 | git \ 12 | libglib2.0-0 \ 13 | libjpeg-dev \ 14 | libopencv-dev \ 15 | libopenexr-dev \ 16 | libpng-dev \ 17 | libsm-dev \ 18 | vim && \ 19 | rm -rf /var/lib/apt/lists/* 20 | 21 | # Install Python 2.7 22 | RUN apt-get update && apt-get install -y --no-install-recommends \ 23 | python-pip \ 24 | python2.7-dev && \ 25 | rm -rf /var/lib/apt/lists/* 26 | 27 | # pip version 21.0 will drop support for Python 2.7 28 | RUN python -m pip install --upgrade pip==20.1 29 | RUN pip install --no-cache-dir setuptools wheel && \ 30 | pip install --no-cache-dir \ 31 | future \ 32 | gast==0.2.2 \ 33 | protobuf \ 34 | pyyaml==3.13 \ 35 | scikit-image \ 36 | typing \ 37 | imageio==2.6.1 \ 38 | OpenEXR==1.3.2 39 | 40 | # Install TF 1.15.0 GPU for Python2.7 41 | RUN pip install --no-cache-dir \ 42 | tensorflow-gpu==1.15.0 \ 43 | tensorflow-determinism 44 | 45 | # Install PyTorch (include Caffe2) for CUDA 10.0 46 | RUN pip install --no-cache-dir torch==1.4.0+cu100 torchvision==0.5.0+cu100 -f https://download.pytorch.org/whl/torch_stable.html 47 | RUN pip install --no-cache-dir cupy-cuda100 48 | RUN pip install --no-cache-dir cython 49 | 50 | WORKDIR /workspace 51 | # Install the COCO API 52 | RUN pip install 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI' 53 | 54 | # Install detectron for mask RCNN 55 | RUN git clone https://github.com/facebookresearch/detectron 56 | RUN sed -i 's/cythonize(ext_modules)/cythonize(ext_modules, language_level="2")/g' detectron/setup.py 57 | RUN cd detectron && pip install -r requirements.txt && make 58 | 59 | WORKDIR /workspace/ml-server 60 | # Copy your current folder to the docker image /workspace/ml-server/ folder 61 | COPY . . -------------------------------------------------------------------------------- /Plugins/Server/server.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018 Foundry. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | ############################################################################## 15 | 16 | import argparse 17 | import os 18 | import importlib 19 | import socket # to get machine hostname 20 | import traceback 21 | 22 | try: # python3 23 | import socketserver 24 | except ImportError: # python2 25 | import SocketServer as socketserver 26 | 27 | import numpy as np 28 | 29 | from message_pb2 import * 30 | 31 | class MLTCPServer(socketserver.TCPServer): 32 | def __init__(self, server_address, handler_class, auto_bind=True): 33 | self.verbose = True 34 | # Each directory in models/ containing a model.py file is an available ML model 35 | self.available_models = [name for name in next(os.walk('models'))[1] 36 | if os.path.isfile(os.path.join('models', name, 'model.py'))] 37 | self.available_models.sort() 38 | self.models = {} 39 | for model in self.available_models: 40 | print('Importing models.{}.model'.format(model)) 41 | self.models[model] = importlib.import_module('models.{}.model'.format(model)).Model() 42 | socketserver.TCPServer.__init__(self, server_address, handler_class, auto_bind) 43 | return 44 | 45 | class ImageProcessTCPHandler(socketserver.BaseRequestHandler): 46 | """This request handler is instantiated once per connection.""" 47 | 48 | def handle(self): 49 | # Read the data headers 50 | data_hdr = self.request.recv(12) 51 | sz = int(data_hdr) 52 | self.vprint('Receiving message of size: {}'.format(sz)) 53 | 54 | # Read data 55 | data = self.recvall(sz) 56 | self.vprint('{} bytes read'.format(len(data))) 57 | 58 | # Parse the message 59 | req_msg = RequestWrapper() 60 | req_msg.ParseFromString(data) 61 | self.vprint('Message parsed') 62 | 63 | # Process message 64 | resp_msg = self.process_message(req_msg) 65 | # Serialize response 66 | self.vprint('Serializing message') 67 | s = resp_msg.SerializeToString() 68 | msg_len = resp_msg.ByteSize() 69 | totallen = 12 + msg_len 70 | msg = bytes(str(totallen).zfill(12).encode('utf-8')) + s 71 | self.vprint('Sending response message of size: {}'.format(totallen)) 72 | self.sendmsg(msg, totallen) 73 | self.vprint('-----------------------------------------------') 74 | 75 | def process_message(self, message): 76 | if message.HasField('r1'): 77 | self.vprint('Received info request') 78 | return self.process_info(message) 79 | elif message.HasField('r2'): 80 | self.vprint('Received inference request') 81 | return self.process_inference(message) 82 | else: 83 | # Pass error message to the client 84 | return self.errormsg("Server received unidentified request from client.") 85 | 86 | def process_info(self, message): 87 | resp_msg = RespondWrapper() 88 | resp_msg.info = True 89 | resp_info = RespondInfo() 90 | resp_info.num_models = len(self.server.available_models) 91 | # Add all model info into the message 92 | for model in self.server.available_models: 93 | m = resp_info.models.add() 94 | m.name = model 95 | m.label = self.server.models[model].get_name() 96 | # Add inputs 97 | for inp_name, inp_channels in self.server.models[model].get_inputs().items(): 98 | inp = m.inputs.add() 99 | inp.name =
inp_name 100 | inp.channels = inp_channels 101 | # Add outputs 102 | for out_name, out_channels in self.server.models[model].get_outputs().items(): 103 | out = m.outputs.add() 104 | out.name = out_name 105 | out.channels = out_channels 106 | # Add options 107 | for opt_name, opt_value in self.server.models[model].get_options().items(): 108 | if type(opt_value) == int: 109 | opt = m.int_options.add() 110 | elif type(opt_value) == float: 111 | opt = m.float_options.add() 112 | elif type(opt_value) == bool: 113 | opt = m.bool_options.add() 114 | elif type(opt_value) == str: 115 | opt = m.string_options.add() 116 | # TODO: Implement multiple choice 117 | else: 118 | # Send an error response message to the Nuke Client 119 | option_error = ("Model option of type {} is not implemented. " 120 | "Broadcasted options need to be one of bool, int, float, str." 121 | ).format(type(opt_value)) 122 | return self.errormsg(option_error) 123 | opt.name = opt_name 124 | opt.values.extend([opt_value]) 125 | # Add buttons 126 | for button_name, button_value in self.server.models[model].get_buttons().items(): 127 | if type (button_value) == bool: 128 | button = m.button_options.add() 129 | else: 130 | return self.errormsg("Model button needs to be of type bool.") 131 | button.name = button_name 132 | button.values.extend([button_value]) 133 | 134 | # Add RespondInfo message to RespondWrapper 135 | resp_msg.r1.CopyFrom(resp_info) 136 | 137 | return resp_msg 138 | 139 | def process_inference(self, message): 140 | req = message.r2 141 | m = req.model 142 | self.vprint('Requesting inference on model: {}'.format(m.name)) 143 | 144 | # Parse model options 145 | opt = {} 146 | for options in [m.bool_options, m.int_options, m.float_options, m.string_options]: 147 | for option in options: 148 | opt[option.name] = option.values[0] 149 | # Set model options 150 | self.server.models[m.name].set_options(opt) 151 | # Parse model buttons 152 | btn = {} 153 | for button in m.button_options: 154 | btn[button.name] = button.values[0] 155 | self.server.models[m.name].set_buttons(btn) 156 | 157 | # Parse images 158 | img_list = [] 159 | for byte_img in req.images: 160 | img = np.fromstring(byte_img.image, dtype=' ' + string) 242 | 243 | if __name__ == "__main__": 244 | parser = argparse.ArgumentParser(description='Machine Learning inference server.') 245 | parser.add_argument('port', type=int, help='Port number for the server to listen to.') 246 | args = parser.parse_args() 247 | 248 | # Get the current hostname of the server 249 | server_hostname = socket.gethostbyname(socket.gethostname()) 250 | # Create the server 251 | server = MLTCPServer((server_hostname, args.port), ImageProcessTCPHandler, False) 252 | 253 | # Bind and activate the server 254 | server.allow_reuse_address = True 255 | server.server_bind() 256 | server.server_activate() 257 | print('Server -> Listening on port: {}'.format(args.port)) 258 | server.serve_forever() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Python-based Machine Learning Frame Server for Nuke 2 | 3 | This repository contains the client-server system enabling Machine Learning (ML) inference in Nuke. This work is split into two parts: a client Nuke plug-in [Plugins/Client/](Plugins/Client) and the Python frame server [Plugins/Server](Plugins/Server). 
4 | 5 | The following models are provided as examples: 6 | - blur: a simple Gaussian blur operation 7 | - [Mask-RCNN](https://github.com/facebookresearch/Detectron) 8 | - [trainingTemplateTF](https://github.com/TheFoundryVisionmongers/nuke-ML-server/tree/master/Models/trainingTemplateTF): a training template written in TensorFlow which enables simple image-to-image training. Instructions on how to use this template are found [here](https://github.com/TheFoundryVisionmongers/nuke-ML-server/tree/master/Models/trainingTemplateTF). 9 | 10 |
11 | 12 | *(Image: Example of Nuke doing DensePose inference.)* 13 |
14 | 15 | ## Introduction 16 | 17 | The Machine Learning (ML) plug-in connects Nuke to a Python server to apply ML models to images. 18 | The plug-in works as follows: 19 | - The Nuke node can connect to a server given an IP address and port, 20 | - The Python server responds with the list of available Machine Learning (ML) models and options, 21 | - The Nuke node displays the models in an enumeration knob, from which the user can choose, 22 | - On every renderStripe() call, the current image and model options are sent from the Nuke node to the server, 23 | - The server runs inference on the image using the chosen model and options. This can be an actual inference operation of a machine learning model, or any other image-processing code, 24 | - The resulting image is sent back to the Nuke node (a minimal Python sketch of this exchange is included at the end of this document). 25 | 26 | ## Installation 27 | 28 | Please find installation instructions in [INSTALL.md](INSTALL.md). 29 | 30 | ## Known Issues 31 | 32 | 1. The GPU can run out of memory when doing model inference. To run Mask-RCNN, at least 6GB of GPU memory is required. 33 | 2. If you get the following error: "The TensorFlow library was compiled to use AVX instructions, but these aren't available on your machine.", please refer to [issue#10](https://github.com/TheFoundryVisionmongers/nuke-ML-server/issues/10). [Thanks to [samhodge](https://github.com/samhodge)] 34 | 35 | ## License 36 | 37 | The source code is licensed under the Apache License, Version 2.0, found in [LICENSE](LICENSE). 38 | 39 | ## Contacts 40 | 41 | - Johanna Barbier (Johanna.Barbier@foundry.com) 42 | - Dan Ring (Dan.Ring@foundry.com) 43 | 44 | This plug-in was initially created by Sebastian Lutz (https://v-sense.scss.tcd.ie/?profile=sebastian-lutz). 45 | 46 | ## References 47 | 48 | - [Mask R-CNN](https://arxiv.org/abs/1703.06870). 49 | Kaiming He, Georgia Gkioxari, Piotr Dollár, and Ross Girshick. 50 | IEEE International Conference on Computer Vision (ICCV), 2017. 51 | - [DensePose: Dense Human Pose Estimation In The Wild](https://arxiv.org/abs/1802.00434). 52 | Riza Alp Güler, Natalia Neverova, Iasonas Kokkinos. 53 | IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2018. --------------------------------------------------------------------------------
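The sketch below illustrates, from the Python side, the client-server exchange described in the README's Introduction: it frames a protobuf `RequestWrapper` with the 12-byte zero-padded ASCII size header that `server.py` expects, sends an info request, and prints the models the server reports back. This is a minimal sketch rather than part of the plug-in (the MLClient node performs the same exchange in C++ via `MLClientComms`); the host and port below are placeholder values, and it assumes Python 3 with the generated `message_pb2` module importable. Note the framing visible in `server.py`: the request header carries the size of the serialized body only, while the response header carries the total size including the 12 header bytes.

```
# Minimal sketch of the MLServer wire protocol, mirroring how server.py reads and writes messages.
# Assumes Python 3 and that message_pb2 (generated from message.proto) is on the Python path.
import socket

from message_pb2 import RequestWrapper, RespondWrapper

HOST, PORT = '127.0.0.1', 55555  # placeholder host and port

# Build an info request: server.py dispatches on HasField('r1').
request = RequestWrapper()
request.info = True
request.r1.info = True
payload = request.SerializeToString()

# The request header is the body size as 12 zero-padded ASCII digits (see handle() in server.py).
header = str(len(payload)).zfill(12).encode('utf-8')

with socket.create_connection((HOST, PORT)) as sock:
    sock.sendall(header + payload)

    # The response header holds the total size (12 header bytes + body), as written by server.py.
    body_size = int(sock.recv(12)) - 12
    body = b''
    while len(body) < body_size:
        chunk = sock.recv(body_size - len(body))
        if not chunk:
            break
        body += chunk

response = RespondWrapper()
response.ParseFromString(body)
if response.HasField('error'):
    print('Server error: {}'.format(response.error.msg))
else:
    for model in response.r1.models:
        print('{} ({})'.format(model.name, model.label))
```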