├── README.md
├── _config.yml
├── main.c
├── model.py
└── p.patch

/README.md:
--------------------------------------------------------------------------------
# Deploying Tensorflow as a C/C++ executable

Here's a scenario that I believe some non-data engineers or data scientists are confronted with:

**How do I deliver a Tensorflow model that I trained in Python but deploy in pure C/C++ code on the client side, without setting up a Python environment on their side, and with all files delivered as binaries?**

The answer is to use the Tensorflow C or C++ API. In this article, we only look at how to use the C API (not the C++ API or TensorflowLite), running on the CPU only.

You would think that the *famous* Tensorflow would have documentation about how to compile a simple C program against it, but as of now (TF 2.1) there is little to no information about that. I'm here to share my findings.

This article explains how to run a plain C program using Tensorflow's C API 2.1. The environment I use throughout the article is as follows:

- OS: Linux (tested and working on a fresh Ubuntu 19.10 / OpenSuse Tumbleweed)
- Latest GCC
- Tensorflow from [Github](https://github.com/tensorflow/tensorflow) (master branch, 2.1)
- No GPU

Also, I would like to credit Vlad Dovgalecs and his [article](https://medium.com/@vladislavsd/undocumented-tensorflow-c-api-b527c0b4ef6) on Medium, as this tutorial is largely based on and improved upon his findings.

# Tutorial structure

This article will be a bit lengthy. But here is what we will do, step by step:

1. Clone the Tensorflow source code and compile it to get the C API headers and binaries.
2. Build the simplest model using Python and Tensorflow and export it as a SavedModel that can be read by the C API.
3. Build a simple C program, compile it with `gcc`, and run it like a normal executable.

So here we go:

# 1. Getting the Tensorflow C API

As far as I know, there are two ways to get those C API headers:

- Download the precompiled Tensorflow C API from the website (the binaries may not be up to date), **OR**
- Clone and compile from source code (a time-consuming process, but if things don't work we can debug and examine the API).

So I'm going to show how to compile their code and use their binaries.

## Step A: Clone their project

Create a folder and clone the project:

```
git clone https://github.com/tensorflow/tensorflow.git
```

## Step B: Install the tools required for the compilation (Bazel, Numpy)

You will need [Bazel](https://bazel.build/) to compile. Install it for your environment.

Ubuntu:
```
sudo apt update && sudo apt install bazel-1.2.1
```

OpenSuse:
```
sudo zypper install bazel
```

Whichever platform you use, make sure the Bazel version is 1.2.1, as this is what Tensorflow 2.1 currently uses. This could change in the future.

Next, we need to install the `Numpy` Python package (why would we need a Python package to build a C API?). You can install it however you want, as long as it can be referenced during compilation. But I prefer to install it through [Miniconda](https://docs.conda.io/en/latest/miniconda.html) and have a separate virtual environment for the build.
Here's how:

Install Miniconda:
```bash
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
sudo chmod 777 Miniconda3-latest-Linux-x86_64.sh
./Miniconda3-latest-Linux-x86_64.sh
# follow the default installation directions
```

Create a new environment named `tf-build` with Numpy:
```bash
conda create -n tf-build python=3.7 numpy
```
We use this environment later in step D.

## Step C: Apply a patch to the source code (IMPORTANT!)

The Tensorflow 2.1 source code has a bug that will make the build fail. Refer to this [issue](https://github.com/clearlinux/distribution/issues/1151). The fix is to apply the patch [here](https://github.com/clearlinux-pkgs/tensorflow/blob/master/Add-grpc-fix-for-gettid.patch). I included a file in this repository that can be used as the patch:
```bash
# copy/download the "p.patch" file from my repo and paste it at the root of the Tensorflow source code.
git apply p.patch
```
In the future this might be fixed and no longer relevant.

## Step D: Compile the code

We compile by referring to the Tensorflow [documentation](https://www.tensorflow.org/install/lang_c) and the github [Readme](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/lib_package/README.md). We need to activate our conda environment first so the build can find Numpy:

```bash
conda activate tf-build # skip this if you already have numpy installed globally

# make sure you're at the root of the Tensorflow source code.
bazel test -c opt tensorflow/tools/lib_package:libtensorflow_test # note that this will take very long to compile
bazel build -c opt tensorflow/tools/lib_package:libtensorflow
```

Let me **WARN** you again: it took 2 hours to compile on an Ubuntu VM with a 6-core configuration. My friend's 2-core laptop basically froze trying to compile this. My advice: run this on a server with a powerful CPU and plenty of RAM.

Copy the file at `bazel-bin/tensorflow/tools/lib_package/libtensorflow.tar.gz` and paste it into your desired folder. Untar it as follows:
```
tar -C /usr/local -xzf libtensorflow.tar.gz
```
I untarred it in my home folder instead of `/usr/local` as I was just trying it out.

CONGRATULATIONS!! YOU MADE IT. At least for compiling Tensorflow.
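Before moving on to the model, it's worth a quick sanity check that the library you just untarred actually links. The sketch below is my own addition, not part of the original repo, and the file name is hypothetical; `TF_Version()` is part of the C API and simply returns the version string of the linked library.

```cpp
// version_check.c : smoke test for the freshly built Tensorflow C library.
// Compile with the same -I/-L flags used in section 3, e.g.
//   gcc -I<path_of_tensorflow_api>/include -L<path_of_tensorflow_api>/lib version_check.c -ltensorflow
#include <stdio.h>
#include "tensorflow/c/c_api.h"

int main()
{
    // TF_Version() returns the library's version string, e.g. "2.1.0".
    printf("Tensorflow C library version: %s\n", TF_Version());
    return 0;
}
```

If this prints the expected version, the headers and binaries are wired up correctly.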
# 2. Simple model with Python

In this step, we build a model with the `tf.keras.layers` class and save it to be loaded later with the C API. Refer to the full code at `model.py` in the [repo](https://github.com/AmirulOm/tensorflow_capi_sample/blob/master/model.py).

## Step A: Write the model

Here is a simple model: a custom `tf.keras.Model` with a single `dense` layer whose kernel is initialized with ones. Hence the output of this model (from `def call()`) will be identical to its input.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

class testModel(tf.keras.Model):
    def __init__(self):
        super(testModel, self).__init__()
        self.dense1 = tf.keras.layers.Dense(1, kernel_initializer='Ones', activation=tf.nn.relu)

    def call(self, inputs):
        return self.dense1(inputs)

input_data = np.asarray([[10]])
module = testModel()
module._set_inputs(input_data)
print(module(input_data))

# Export the model to a SavedModel
module.save('model', save_format='tf')
```

Ever since Tensorflow 2.0, eager execution lets us run a model without drafting the graph and running it through a `session`. But in order to save the model (refer to the line `module.save('model', save_format='tf')`), the graph needs to be built before it can be saved. Hence, we need to call the model at least once for it to create the graph. Calling `print(module(input_data))` forces it to create the graph.

Next, run the code:
```
python model.py
```
You should get output like below:
```
2020-01-30 11:46:25.400334: I tensorflow/core/platform/cpu_feature_guard.cc:145] This TensorFlow binary is optimized with Intel(R) MKL-DNN to use the following CPU instructions in performance critical operations: SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in non-MKL-DNN operations, rebuild TensorFlow with the appropriate compiler flags.
2020-01-30 11:46:25.421717: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 3699495000 Hz
2020-01-30 11:46:25.422615: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x561bef5ac2a0 executing computations on platform Host. Devices:
2020-01-30 11:46:25.422655: I tensorflow/compiler/xla/service/service.cc:175]   StreamExecutor device (0): Host, Default Version
2020-01-30 11:46:25.422744: I tensorflow/core/common_runtime/process_util.cc:115] Creating new thread pool with default inter op setting: 2. Tune using inter_op_parallelism_threads for best performance.
tf.Tensor([[10.]], shape=(1, 1), dtype=float32)
```

A folder called `model` should also be created.

## Step B: Verify the saved model

When we save a model, it creates a folder with a bunch of files inside. It basically stores the weights and the graphs of the model. Tensorflow includes a tool for diving into these files and matching up the input and output tensors: `saved_model_cli`, a command line tool that comes with a Tensorflow installation.

BUT WAIT! We haven't installed Tensorflow! So basically there are two ways to get `saved_model_cli`:
- Install Tensorflow
- Build it from source code and look for `saved_model_cli`

For this, I will just install Tensorflow in a separate conda environment and call it from there; we only need to use it once anyway. So here we go.

Install Tensorflow in a separate conda environment:

```bash
conda create -n tf python=3.7 tensorflow
```

Activate the environment:
```bash
conda activate tf
```

By now you should be able to call `saved_model_cli` from the command line.

We need to extract the graph names of the input and output tensors, and we will use that information later when calling the C API.
Here's how:

```bash
saved_model_cli show --dir <path_to_saved_model_folder>
```

Running this (with the appropriate path substituted), you should get an output like below:
```
The given SavedModel contains the following tag-sets:
serve
```
Use this tag-set to drill further into the tensor graph:
```bash
saved_model_cli show --dir <path_to_saved_model_folder> --tag_set serve
```
and you should get an output like below:
```
The given SavedModel MetaGraphDef contains SignatureDefs with the following keys:
SignatureDef key: "__saved_model_init_op"
SignatureDef key: "serving_default"
```

Using the `serving_default` signature key in the command prints out the tensor nodes:
```bash
saved_model_cli show --dir <path_to_saved_model_folder> --tag_set serve --signature_def serving_default
```

and you should get an output like below:
```
The given SavedModel SignatureDef contains the following input(s):
  inputs['input_1'] tensor_info:
      dtype: DT_INT64
      shape: (-1, 1)
      name: serving_default_input_1:0
The given SavedModel SignatureDef contains the following output(s):
  outputs['output_1'] tensor_info:
      dtype: DT_FLOAT
      shape: (-1, 1)
      name: StatefulPartitionedCall:0
Method name is: tensorflow/serving/predict
```
We will need the names `serving_default_input_1` and `StatefulPartitionedCall` later in the C API.

# 3. Building the C/C++ code

The third part is to write the C code that uses the Tensorflow C API and imports the Python SavedModel. The full code can be found [here](https://github.com/AmirulOm/tensorflow_capi_sample/blob/master/main.c).

There is no proper documentation for the C API, so if something goes wrong, it's best to look at the C headers in the [source code](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/c_api.h). (You can also debug with GDB and learn step by step how the headers work.)

## Step A: Write the C code

In an empty file, import the Tensorflow C API as follows:

```cpp
#include <stdio.h>
#include <stdlib.h>
#include "tensorflow/c/c_api.h"

// Deallocator passed to TF_NewTensor below. It is a no-op because the tensor
// will wrap memory that we own ourselves (stack data in this example).
void NoOpDeallocator(void* data, size_t a, void* b) {}

int main()
{
}
```
Note the `NoOpDeallocator` function declared at the top; we will use it later.

Next, we need to load the SavedModel and the session using the `TF_LoadSessionFromSavedModel` API:

```cpp
//********* Read model
TF_Graph* Graph = TF_NewGraph();
TF_Status* Status = TF_NewStatus();

TF_SessionOptions* SessionOpts = TF_NewSessionOptions();
TF_Buffer* RunOpts = NULL;

const char* saved_model_dir = "model/"; // Path of the model
const char* tags = "serve"; // default model serving tag; can change in future
int ntags = 1;

TF_Session* Session = TF_LoadSessionFromSavedModel(SessionOpts, RunOpts, saved_model_dir, &tags, ntags, Graph, NULL, Status);
if(TF_GetCode(Status) == TF_OK)
{
    printf("TF_LoadSessionFromSavedModel OK\n");
}
else
{
    printf("%s",TF_Message(Status));
}
```

Next, we grab the tensor nodes from the graph by their names. Remember how we searched for the tensor names earlier using `saved_model_cli`? This is where we use them, in the calls to `TF_GraphOperationByName()`.
In this example, `serving_default_input_1` is our input tensor and `StatefulPartitionedCall` is our output tensor.

```cpp
//****** Get input tensor
int NumInputs = 1;
TF_Output* Input = malloc(sizeof(TF_Output) * NumInputs);

TF_Output t0 = {TF_GraphOperationByName(Graph, "serving_default_input_1"), 0};
if(t0.oper == NULL)
    printf("ERROR: Failed TF_GraphOperationByName serving_default_input_1\n");
else
    printf("TF_GraphOperationByName serving_default_input_1 is OK\n");

Input[0] = t0;

//********* Get Output tensor
int NumOutputs = 1;
TF_Output* Output = malloc(sizeof(TF_Output) * NumOutputs);

TF_Output t2 = {TF_GraphOperationByName(Graph, "StatefulPartitionedCall"), 0};
if(t2.oper == NULL)
    printf("ERROR: Failed TF_GraphOperationByName StatefulPartitionedCall\n");
else
    printf("TF_GraphOperationByName StatefulPartitionedCall is OK\n");

Output[0] = t2;
```

Next, we allocate a new tensor locally using `TF_NewTensor`, set the input value, and later pass it to the session run. *NOTE that `ndata` is the total byte size of your data, not the length of the array.*

Here we set the input tensor to the value 20, so we should see 20 as the output value as well.

```cpp
//********* Allocate data for inputs & outputs
TF_Tensor** InputValues = (TF_Tensor**)malloc(sizeof(TF_Tensor*)*NumInputs);
TF_Tensor** OutputValues = malloc(sizeof(TF_Tensor*)*NumOutputs);

int ndims = 2;
int64_t dims[] = {1,1};
int64_t data[] = {20};
int ndata = sizeof(int64_t); // This is tricky: it is the number of bytes, not the number of elements

TF_Tensor* int_tensor = TF_NewTensor(TF_INT64, dims, ndims, data, ndata, &NoOpDeallocator, 0);
if (int_tensor != NULL)
{
    printf("TF_NewTensor is OK\n");
}
else
    printf("ERROR: Failed TF_NewTensor\n");

InputValues[0] = int_tensor;
```

Next, we run the model by invoking the `TF_SessionRun` API. Here's how:

```cpp
// Run the session: feed InputValues into the Input nodes and collect
// the results for the Output nodes into OutputValues.
TF_SessionRun(Session, NULL, Input, InputValues, NumInputs, Output, OutputValues, NumOutputs, NULL, 0, NULL, Status);

if(TF_GetCode(Status) == TF_OK)
{
    printf("Session is OK\n");
}
else
{
    printf("%s",TF_Message(Status));
}

// Free memory
TF_DeleteGraph(Graph);
TF_DeleteSession(Session, Status);
TF_DeleteSessionOptions(SessionOpts);
TF_DeleteStatus(Status);
```

Lastly, we get the output value back from the output tensor using `TF_TensorData`, which extracts the raw data from the tensor object. Since we know the size of the output is 1, I can print it directly. Otherwise, use `TF_GraphGetTensorNumDims` or the other APIs available in `c_api.h` or `tf_tensor.h`; see the sketch after the compile step below.

```cpp
void* buff = TF_TensorData(OutputValues[0]);
float* offsets = buff;
printf("Result Tensor :\n");
printf("%f\n",offsets[0]);
return 0;
```

## Step B: Compile the code

Compile it as below:

```bash
gcc -I<path_of_tensorflow_api>/include/ -L<path_of_tensorflow_api>/lib main.c -ltensorflow -o main.out
```
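One more note before running: the sample above hard-codes the fact that the output holds a single float. If the output shape isn't known at compile time, you can interrogate the tensor that `TF_SessionRun` returned. The sketch below is my own addition, not part of the original sample; `TF_NumDims`, `TF_Dim`, and `TF_DeleteTensor` are declared in `tf_tensor.h`, the variable names reuse the walkthrough's, and it would slot in just before the final `return 0;`.

```cpp
// A sketch for reading an output tensor of unknown shape. Assumes
// OutputValues[0] was filled in by a successful TF_SessionRun and that
// the model's output dtype is float (TF_FLOAT per saved_model_cli).
TF_Tensor* out = OutputValues[0];

int64_t nelems = 1;
for (int i = 0; i < TF_NumDims(out); i++) // TF_NumDims: rank of the tensor
    nelems *= TF_Dim(out, i);             // TF_Dim: size of dimension i

float* vals = (float*)TF_TensorData(out);
for (int64_t i = 0; i < nelems; i++)
    printf("out[%lld] = %f\n", (long long)i, vals[i]);

// The walkthrough leaks its tensors and arrays for brevity; in a real
// program, release them once you are done with the data.
TF_DeleteTensor(InputValues[0]);
TF_DeleteTensor(OutputValues[0]);
free(Input); free(Output);
free(InputValues); free(OutputValues);
```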
## Step C: Run it

Before you run it, you'll need to make sure the C library is exported in your environment:

```bash
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:<path_of_tensorflow_api>/lib
```

RUN IT:

```
./main.out
```

You should get an output like below. Notice that the output value is 20, just like our input. You can change the model, initializing the kernel with weights of value 2 instead, and see whether the output reflects it.

```
2020-01-31 09:47:48.842680: I tensorflow/cc/saved_model/reader.cc:31] Reading SavedModel from: model/
2020-01-31 09:47:48.844252: I tensorflow/cc/saved_model/reader.cc:54] Reading meta graph with tags { serve }
2020-01-31 09:47:48.844295: I tensorflow/cc/saved_model/loader.cc:264] Reading SavedModel debug info (if present) from: model/
2020-01-31 09:47:48.844385: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE3 SSE4.1 SSE4.2 AVX AVX2 FMA
2020-01-31 09:47:48.859883: I tensorflow/cc/saved_model/loader.cc:203] Restoring SavedModel bundle.
2020-01-31 09:47:48.908997: I tensorflow/cc/saved_model/loader.cc:152] Running initialization op on SavedModel bundle at path: model/
2020-01-31 09:47:48.923127: I tensorflow/cc/saved_model/loader.cc:333] SavedModel load for tags { serve }; Status: success: OK. Took 80457 microseconds.
TF_LoadSessionFromSavedModel OK
TF_GraphOperationByName serving_default_input_1 is OK
TF_GraphOperationByName StatefulPartitionedCall is OK
TF_NewTensor is OK
Session is OK
Result Tensor :
20.000000
```

END

--------------------------------------------------------------------------------
/_config.yml:
--------------------------------------------------------------------------------
theme: jekyll-theme-minimal
--------------------------------------------------------------------------------
/main.c:
--------------------------------------------------------------------------------
#include <stdio.h>
#include <stdlib.h>
#include "tensorflow/c/c_api.h"

void NoOpDeallocator(void* data, size_t a, void* b) {}

int main()
{
    //********* Read model
    TF_Graph* Graph = TF_NewGraph();
    TF_Status* Status = TF_NewStatus();

    TF_SessionOptions* SessionOpts = TF_NewSessionOptions();
    TF_Buffer* RunOpts = NULL;

    const char* saved_model_dir = "lstm2/";
    const char* tags = "serve"; // default model serving tag; can change in future
    int ntags = 1;

    TF_Session* Session = TF_LoadSessionFromSavedModel(SessionOpts, RunOpts, saved_model_dir, &tags, ntags, Graph, NULL, Status);
    if(TF_GetCode(Status) == TF_OK)
    {
        printf("TF_LoadSessionFromSavedModel OK\n");
    }
    else
    {
        printf("%s",TF_Message(Status));
    }

    //****** Get input tensor
    //TODO : need to use saved_model_cli to read saved_model arch
    int NumInputs = 1;
    TF_Output* Input = (TF_Output*)malloc(sizeof(TF_Output) * NumInputs);

    TF_Output t0 = {TF_GraphOperationByName(Graph, "serving_default_input_1"), 0};
    if(t0.oper == NULL)
        printf("ERROR: Failed TF_GraphOperationByName serving_default_input_1\n");
    else
        printf("TF_GraphOperationByName serving_default_input_1 is OK\n");

    Input[0] = t0;

    //********* Get Output tensor
    int NumOutputs = 1;
    TF_Output* Output = (TF_Output*)malloc(sizeof(TF_Output) * NumOutputs);

    TF_Output t2 = {TF_GraphOperationByName(Graph, "StatefulPartitionedCall"), 0};
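    // NOTE: t0/t2 hold a {operation, output index} pair. If .oper comes back
    // NULL, re-check the tensor names against the saved_model_cli output
    // (section 2, Step B of the README); they come from the SavedModel
    // signature, not from the Keras layer names.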
    if(t2.oper == NULL)
        printf("ERROR: Failed TF_GraphOperationByName StatefulPartitionedCall\n");
    else
        printf("TF_GraphOperationByName StatefulPartitionedCall is OK\n");

    Output[0] = t2;

    //********* Allocate data for inputs & outputs
    TF_Tensor** InputValues = (TF_Tensor**)malloc(sizeof(TF_Tensor*)*NumInputs);
    TF_Tensor** OutputValues = (TF_Tensor**)malloc(sizeof(TF_Tensor*)*NumOutputs);

    int ndims = 2;
    int64_t dims[] = {1,30};
    float data[1*30];
    for(int i=0; i< (1*30); i++)
    {
        data[i] = 1.00;
    }
    int ndata = sizeof(float)*1*30; // This is tricky: it is the number of bytes, not the number of elements

    TF_Tensor* int_tensor = TF_NewTensor(TF_FLOAT, dims, ndims, data, ndata, &NoOpDeallocator, 0);
    if (int_tensor != NULL)
    {
        printf("TF_NewTensor is OK\n");
    }
    else
        printf("ERROR: Failed TF_NewTensor\n");

    InputValues[0] = int_tensor;

    // Run the Session
    TF_SessionRun(Session, NULL, Input, InputValues, NumInputs, Output, OutputValues, NumOutputs, NULL, 0, NULL, Status);

    if(TF_GetCode(Status) == TF_OK)
    {
        printf("Session is OK\n");
    }
    else
    {
        printf("%s",TF_Message(Status));
    }

    // Free memory
    TF_DeleteGraph(Graph);
    TF_DeleteSession(Session, Status);
    TF_DeleteSessionOptions(SessionOpts);
    TF_DeleteStatus(Status);

    void* buff = TF_TensorData(OutputValues[0]);
    float* offsets = (float*)buff;
    printf("Result Tensor :\n");
    for(int i=0;i<10;i++)
    {
        printf("%f\n",offsets[i]);
    }
}
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Sequential
from tensorflow.keras.layers import Dense, LSTM, InputLayer, Bidirectional, TimeDistributed, Embedding, Activation
from tensorflow.keras.optimizers import Adam

model = Sequential()
model.add(InputLayer(input_shape=(30,)))
model.add(Embedding(1000 + 1, 40))  # tunable output length
model.add(Bidirectional(LSTM(128, return_sequences=True)))
model.add(TimeDistributed(Dense(60 + 1)))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer=Adam(0.001), metrics=['accuracy'])
model.summary()
input_data = np.ones((1, 30))  # one batch of a length-30 sequence; calling the model builds the graph
print(model(input_data))
model.save("lstm2")
--------------------------------------------------------------------------------
/p.patch:
--------------------------------------------------------------------------------
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 55d7eb93..33e86087 100755
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -486,6 +486,7 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
     # WARNING: make sure ncteisen@ and vpai@ are cc-ed on any CL to change the below rule
     tf_http_archive(
         name = "grpc",
+        patch_file = clean_dep("//third_party:Rename-gettid-functions.patch"),
         sha256 = "67a6c26db56f345f7cee846e681db2c23f919eba46dd639b09462d1b6203d28c",
         strip_prefix = "grpc-4566c2a29ebec0835643b972eb99f4306c4234a3",
        system_build_file = clean_dep("//third_party/systemlibs:grpc.BUILD"),
diff --git a/third_party/Rename-gettid-functions.patch b/third_party/Rename-gettid-functions.patch
new file mode 100644
index 00000000..90bd9115
--- /dev/null
+++ b/third_party/Rename-gettid-functions.patch
@@ -0,0 +1,78 @@
+From d1d017390b799c59d6fdf7b8afa6136d218bdd61 Mon Sep 17 00:00:00 2001
+From: Benjamin Peterson
+Date: Fri, 3 May 2019 08:11:00 -0700
+Subject: [PATCH] Rename gettid() functions.
+
+glibc 2.30 will declare its own gettid; see https://sourceware.org/git/?p=glibc.git;a=commit;h=1d0fc213824eaa2a8f8c4385daaa698ee8fb7c92. Rename the grpc versions to avoid naming conflicts.
+---
+ src/core/lib/gpr/log_linux.cc          | 4 ++--
+ src/core/lib/gpr/log_posix.cc          | 4 ++--
+ src/core/lib/iomgr/ev_epollex_linux.cc | 4 ++--
+ 3 files changed, 6 insertions(+), 6 deletions(-)
+
+diff --git a/src/core/lib/gpr/log_linux.cc b/src/core/lib/gpr/log_linux.cc
+index 561276f0c20..8b597b4cf2f 100644
+--- a/src/core/lib/gpr/log_linux.cc
++++ b/src/core/lib/gpr/log_linux.cc
+@@ -40,7 +40,7 @@
+ #include <time.h>
+ #include <unistd.h>
+ 
+-static long gettid(void) { return syscall(__NR_gettid); }
++static long sys_gettid(void) { return syscall(__NR_gettid); }
+ 
+ void gpr_log(const char* file, int line, gpr_log_severity severity,
+              const char* format, ...) {
+@@ -70,7 +70,7 @@ void gpr_default_log(gpr_log_func_args* args) {
+   gpr_timespec now = gpr_now(GPR_CLOCK_REALTIME);
+   struct tm tm;
+   static __thread long tid = 0;
+-  if (tid == 0) tid = gettid();
++  if (tid == 0) tid = sys_gettid();
+ 
+   timer = static_cast<time_t>(now.tv_sec);
+   final_slash = strrchr(args->file, '/');
+diff --git a/src/core/lib/gpr/log_posix.cc b/src/core/lib/gpr/log_posix.cc
+index b6edc14ab6b..2f7c6ce3760 100644
+--- a/src/core/lib/gpr/log_posix.cc
++++ b/src/core/lib/gpr/log_posix.cc
+@@ -31,7 +31,7 @@
+ #include <string.h>
+ #include <time.h>
+ 
+-static intptr_t gettid(void) { return (intptr_t)pthread_self(); }
++static intptr_t sys_gettid(void) { return (intptr_t)pthread_self(); }
+ 
+ void gpr_log(const char* file, int line, gpr_log_severity severity,
+              const char* format, ...) {
+@@ -86,7 +86,7 @@ void gpr_default_log(gpr_log_func_args* args) {
+   char* prefix;
+   gpr_asprintf(&prefix, "%s%s.%09d %7" PRIdPTR " %s:%d]",
+                gpr_log_severity_string(args->severity), time_buffer,
+-               (int)(now.tv_nsec), gettid(), display_file, args->line);
++               (int)(now.tv_nsec), sys_gettid(), display_file, args->line);
+ 
+   fprintf(stderr, "%-70s %s\n", prefix, args->message);
+   gpr_free(prefix);
+diff --git a/src/core/lib/iomgr/ev_epollex_linux.cc b/src/core/lib/iomgr/ev_epollex_linux.cc
+index 08116b3ab53..76f59844312 100644
+--- a/src/core/lib/iomgr/ev_epollex_linux.cc
++++ b/src/core/lib/iomgr/ev_epollex_linux.cc
+@@ -1102,7 +1102,7 @@ static void end_worker(grpc_pollset* pollset, grpc_pollset_worker* worker,
+ }
+ 
+ #ifndef NDEBUG
+-static long gettid(void) { return syscall(__NR_gettid); }
++static long sys_gettid(void) { return syscall(__NR_gettid); }
+ #endif
+ 
+ /* pollset->mu lock must be held by the caller before calling this.
+@@ -1122,7 +1122,7 @@ static grpc_error* pollset_work(grpc_pollset* pollset,
+ #define WORKER_PTR (&worker)
+ #endif
+ #ifndef NDEBUG
+-  WORKER_PTR->originator = gettid();
++  WORKER_PTR->originator = sys_gettid();
+ #endif
+   if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) {
+     gpr_log(GPR_INFO,
--
--------------------------------------------------------------------------------