├── serving ├── src │ ├── test │ │ ├── resources │ │ │ └── application.yml │ │ └── java │ │ │ └── com │ │ │ └── mimacom │ │ │ ├── tensorflowdemo │ │ │ ├── TensorflowdemoApplicationTests.java │ │ │ └── WebModuleTestConfig.java │ │ │ └── irisml │ │ │ └── web │ │ │ └── controller │ │ │ ├── BaseControllerTest.java │ │ │ └── IrisControllerTest.java │ └── main │ │ ├── resources │ │ └── application.yml │ │ └── java │ │ └── com │ │ └── mimacom │ │ ├── irisml │ │ ├── domain │ │ │ ├── IrisType.java │ │ │ └── Iris.java │ │ ├── service │ │ │ ├── IrisClassifierService.java │ │ │ └── impl │ │ │ │ └── IrisTensorflowClassifierService.java │ │ └── web │ │ │ └── controller │ │ │ └── IrisController.java │ │ └── tensorflowdemo │ │ └── TensorflowDemoApplication.java ├── .gitignore └── pom.xml ├── training ├── stored_model │ └── 1530093489 │ │ └── variables │ │ ├── variables.index │ │ └── variables.data-00000-of-00001 ├── visualization │ └── tensorboard │ │ ├── 1530088196 │ │ └── events.out.tfevents.1530088948.DEMUJKHMPC2 │ │ ├── 1530089139 │ │ └── events.out.tfevents.1530089145.DEMUJKHMPC2 │ │ ├── 1530091803 │ │ └── events.out.tfevents.1530091808.DEMUJKHMPC2 │ │ ├── 1530091804 │ │ └── events.out.tfevents.1530091939.DEMUJKHMPC2 │ │ └── 1530092040 │ │ └── events.out.tfevents.1530092069.DEMUJKHMPC2 ├── .ipynb_checkpoints │ ├── TF_iris_data - working backup-checkpoint.ipynb │ └── TF_iris_data-checkpoint.ipynb └── TF_iris_data.ipynb └── README.md /serving/src/test/resources/application.yml: -------------------------------------------------------------------------------- 1 | irisml: 2 | savedModel: 3 | path: ../training/stored_model/1530093489 4 | tags: serve 5 | -------------------------------------------------------------------------------- /serving/src/main/resources/application.yml: -------------------------------------------------------------------------------- 1 | server: 2 | port: 7373 3 | 4 | 5 | irisml: 6 | savedModel: 7 | path: ../training/stored_model/1530093489 8 | tags: serve 9 
| -------------------------------------------------------------------------------- /serving/src/main/java/com/mimacom/irisml/domain/IrisType.java: -------------------------------------------------------------------------------- 1 | package com.mimacom.irisml.domain; 2 | 3 | public enum IrisType { 4 | SETOSA, 5 | VERSICOLOUR, 6 | VIRGINICA 7 | } 8 | -------------------------------------------------------------------------------- /training/stored_model/1530093489/variables/variables.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ellerenad/Getting-started-Tensorflow-Java-Spring/HEAD/training/stored_model/1530093489/variables/variables.index -------------------------------------------------------------------------------- /training/stored_model/1530093489/variables/variables.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ellerenad/Getting-started-Tensorflow-Java-Spring/HEAD/training/stored_model/1530093489/variables/variables.data-00000-of-00001 -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Repo storing the code for the blog post **Getting started with Tensorflow and Java-Spring**, published at [https://blog.mimacom.com/getting-started-tensorflow-spring/](https://blog.mimacom.com/getting-started-tensorflow-spring/) -------------------------------------------------------------------------------- /training/visualization/tensorboard/1530088196/events.out.tfevents.1530088948.DEMUJKHMPC2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ellerenad/Getting-started-Tensorflow-Java-Spring/HEAD/training/visualization/tensorboard/1530088196/events.out.tfevents.1530088948.DEMUJKHMPC2 
-------------------------------------------------------------------------------- /training/visualization/tensorboard/1530089139/events.out.tfevents.1530089145.DEMUJKHMPC2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ellerenad/Getting-started-Tensorflow-Java-Spring/HEAD/training/visualization/tensorboard/1530089139/events.out.tfevents.1530089145.DEMUJKHMPC2 -------------------------------------------------------------------------------- /training/visualization/tensorboard/1530091803/events.out.tfevents.1530091808.DEMUJKHMPC2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ellerenad/Getting-started-Tensorflow-Java-Spring/HEAD/training/visualization/tensorboard/1530091803/events.out.tfevents.1530091808.DEMUJKHMPC2 -------------------------------------------------------------------------------- /training/visualization/tensorboard/1530091804/events.out.tfevents.1530091939.DEMUJKHMPC2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ellerenad/Getting-started-Tensorflow-Java-Spring/HEAD/training/visualization/tensorboard/1530091804/events.out.tfevents.1530091939.DEMUJKHMPC2 -------------------------------------------------------------------------------- /training/visualization/tensorboard/1530092040/events.out.tfevents.1530092069.DEMUJKHMPC2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ellerenad/Getting-started-Tensorflow-Java-Spring/HEAD/training/visualization/tensorboard/1530092040/events.out.tfevents.1530092069.DEMUJKHMPC2 -------------------------------------------------------------------------------- /serving/.gitignore: -------------------------------------------------------------------------------- 1 | /target/ 2 | .mvn 3 | mvnw 4 | mvnw.cmd 5 | 6 | ### STS ### 7 | .apt_generated 8 | 
.classpath 9 | .factorypath 10 | .project 11 | .settings 12 | .springBeans 13 | .sts4-cache 14 | 15 | ### IntelliJ IDEA ### 16 | .idea 17 | *.iws 18 | *.iml 19 | *.ipr 20 | 21 | ### NetBeans ### 22 | /nbproject/private/ 23 | /build/ 24 | /nbbuild/ 25 | /dist/ 26 | /nbdist/ 27 | /.nb-gradle/ -------------------------------------------------------------------------------- /serving/src/test/java/com/mimacom/tensorflowdemo/TensorflowdemoApplicationTests.java: -------------------------------------------------------------------------------- 1 | package com.mimacom.tensorflowdemo; 2 | 3 | import org.junit.Test; 4 | import org.junit.runner.RunWith; 5 | import org.springframework.boot.test.context.SpringBootTest; 6 | import org.springframework.test.context.junit4.SpringRunner; 7 | 8 | @RunWith(SpringRunner.class) 9 | @SpringBootTest 10 | public class TensorflowdemoApplicationTests { 11 | 12 | @Test 13 | public void contextLoads() { 14 | } 15 | 16 | } 17 | -------------------------------------------------------------------------------- /serving/src/main/java/com/mimacom/tensorflowdemo/TensorflowDemoApplication.java: -------------------------------------------------------------------------------- 1 | package com.mimacom.tensorflowdemo; 2 | 3 | import org.springframework.boot.SpringApplication; 4 | import org.springframework.boot.autoconfigure.SpringBootApplication; 5 | import org.springframework.context.annotation.ComponentScan; 6 | 7 | @ComponentScan(basePackages = {"com.mimacom.irisml"}) 8 | @SpringBootApplication 9 | public class TensorflowDemoApplication { 10 | 11 | public static void main(String[] args) throws Exception { 12 | SpringApplication.run(TensorflowDemoApplication.class, args); 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /serving/src/main/java/com/mimacom/irisml/service/IrisClassifierService.java: -------------------------------------------------------------------------------- 1 | package 
com.mimacom.irisml.service; 2 | 3 | import com.mimacom.irisml.domain.Iris; 4 | import com.mimacom.irisml.domain.IrisType; 5 | 6 | import java.util.Map; 7 | 8 | 9 | public interface IrisClassifierService { 10 | 11 | /** 12 | * Method to fetch a classification from the model 13 | * @param iris the data to classify 14 | * @return the predicted type 15 | */ 16 | IrisType classify(Iris iris); 17 | 18 | /** 19 | * Method to fetch from the model the probabilities of all the types 20 | * @param iris the data to classify 21 | * @return A map relating the type with its predicted probabilities 22 | */ 23 | Map classificationProbabilities(Iris iris); 24 | } 25 | -------------------------------------------------------------------------------- /serving/src/test/java/com/mimacom/tensorflowdemo/WebModuleTestConfig.java: -------------------------------------------------------------------------------- 1 | package com.mimacom.tensorflowdemo; 2 | 3 | import com.fasterxml.jackson.databind.ObjectMapper; 4 | import org.springframework.boot.autoconfigure.EnableAutoConfiguration; 5 | import org.springframework.context.annotation.Bean; 6 | import org.springframework.context.annotation.ComponentScan; 7 | import org.springframework.context.annotation.Configuration; 8 | import org.springframework.http.converter.json.MappingJackson2HttpMessageConverter; 9 | import org.springframework.web.servlet.config.annotation.EnableWebMvc; 10 | 11 | @Configuration 12 | @EnableAutoConfiguration 13 | @EnableWebMvc 14 | @ComponentScan({"com.mimacom.irisml"}) 15 | /** 16 | * Configuration class for the tests of the controllers 17 | */ 18 | public class WebModuleTestConfig { 19 | 20 | @Bean 21 | public MappingJackson2HttpMessageConverter mappingJackson2HttpMessageConverter() { 22 | ObjectMapper objectMapper = new ObjectMapper(); 23 | return new MappingJackson2HttpMessageConverter(objectMapper); 24 | } 25 | 26 | } -------------------------------------------------------------------------------- 
/serving/src/main/java/com/mimacom/irisml/web/controller/IrisController.java: -------------------------------------------------------------------------------- 1 | package com.mimacom.irisml.web.controller; 2 | 3 | import com.mimacom.irisml.domain.Iris; 4 | import com.mimacom.irisml.domain.IrisType; 5 | import com.mimacom.irisml.service.IrisClassifierService; 6 | import org.springframework.beans.factory.annotation.Autowired; 7 | import org.springframework.web.bind.annotation.GetMapping; 8 | import org.springframework.web.bind.annotation.RequestParam; 9 | import org.springframework.web.bind.annotation.RestController; 10 | 11 | import java.util.Map; 12 | 13 | @RestController 14 | public class IrisController { 15 | 16 | @Autowired 17 | IrisClassifierService irisClassifierService; 18 | 19 | @GetMapping(value = "/iris/classify/class") 20 | public IrisType classify(Iris iris) { 21 | return irisClassifierService.classify(iris); 22 | } 23 | 24 | @GetMapping(value = "/iris/classify/probabilities") 25 | public Map classificationProbabilities(Iris iris) { 26 | return irisClassifierService.classificationProbabilities(iris); 27 | } 28 | 29 | } 30 | -------------------------------------------------------------------------------- /serving/src/main/java/com/mimacom/irisml/domain/Iris.java: -------------------------------------------------------------------------------- 1 | package com.mimacom.irisml.domain; 2 | 3 | public class Iris { 4 | 5 | private float petalLength; 6 | private float petalWidth; 7 | private float sepalLength; 8 | private float sepalWidth; 9 | 10 | public Iris() { 11 | } 12 | 13 | public Iris(float petalLength, float petalWidth, float sepalLength, float sepalWidth) { 14 | this.petalLength = petalLength; 15 | this.petalWidth = petalWidth; 16 | this.sepalLength = sepalLength; 17 | this.sepalWidth = sepalWidth; 18 | } 19 | 20 | public float getPetalLength() { 21 | return petalLength; 22 | } 23 | 24 | public void setPetalLength(float petalLength) { 25 | 
this.petalLength = petalLength; 26 | } 27 | 28 | public float getPetalWidth() { 29 | return petalWidth; 30 | } 31 | 32 | public void setPetalWidth(float petalWidth) { 33 | this.petalWidth = petalWidth; 34 | } 35 | 36 | public float getSepalLength() { 37 | return sepalLength; 38 | } 39 | 40 | public void setSepalLength(float sepalLength) { 41 | this.sepalLength = sepalLength; 42 | } 43 | 44 | public float getSepalWidth() { 45 | return sepalWidth; 46 | } 47 | 48 | public void setSepalWidth(float sepalWidth) { 49 | this.sepalWidth = sepalWidth; 50 | } 51 | 52 | 53 | } 54 | -------------------------------------------------------------------------------- /serving/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 4 | 4.0.0 5 | 6 | com.mimacom 7 | tensorflowdemo 8 | 0.0.1-SNAPSHOT 9 | jar 10 | 11 | tensorflowdemo 12 | Demo project for Spring Boot and Tensorflow 13 | 14 | 15 | org.springframework.cloud 16 | spring-cloud-starter-parent 17 | Camden.SR7 18 | 19 | 20 | 21 | UTF-8 22 | UTF-8 23 | 1.8 24 | 25 | 26 | 27 | 28 | org.springframework.boot 29 | spring-boot-starter-web 30 | 31 | 32 | 33 | org.tensorflow 34 | tensorflow 35 | 1.8.0 36 | 37 | 38 | 39 | org.springframework.boot 40 | spring-boot-starter-test 41 | test 42 | 43 | 44 | 45 | 46 | 47 | com.google.code.gson 48 | gson 49 | test 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | org.springframework.boot 58 | spring-boot-maven-plugin 59 | 60 | 61 | 62 | 63 | 64 | 65 | -------------------------------------------------------------------------------- /serving/src/test/java/com/mimacom/irisml/web/controller/BaseControllerTest.java: -------------------------------------------------------------------------------- 1 | package com.mimacom.irisml.web.controller; 2 | 3 | import com.mimacom.tensorflowdemo.WebModuleTestConfig; 4 | import org.junit.Before; 5 | import org.junit.Test; 6 | import org.junit.runner.RunWith; 7 | import org.springframework.beans.factory.annotation.Autowired; 8 | 
import org.springframework.boot.test.context.SpringBootTest; 9 | import org.springframework.http.MediaType; 10 | import org.springframework.http.converter.HttpMessageConverter; 11 | import org.springframework.mock.http.MockHttpOutputMessage; 12 | import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; 13 | import org.springframework.test.context.web.WebAppConfiguration; 14 | import org.springframework.test.web.servlet.MockMvc; 15 | import org.springframework.test.web.servlet.setup.MockMvcBuilders; 16 | import org.springframework.web.context.WebApplicationContext; 17 | 18 | import java.io.IOException; 19 | 20 | import static org.junit.Assert.assertTrue; 21 | 22 | /** 23 | * Base class for the test of the controllers 24 | */ 25 | @SpringBootTest(classes = {WebModuleTestConfig.class}) 26 | @WebAppConfiguration 27 | @RunWith(SpringJUnit4ClassRunner.class) 28 | public class BaseControllerTest { 29 | 30 | @SuppressWarnings("SpringJavaAutowiringInspection") 31 | @Autowired 32 | protected HttpMessageConverter mappingJackson2HttpMessageConverter; 33 | 34 | 35 | protected MockMvc mockMvc; 36 | 37 | @Autowired 38 | private WebApplicationContext wac; 39 | 40 | protected String json(Object o) throws IOException { 41 | MockHttpOutputMessage mockHttpOutputMessage = new MockHttpOutputMessage(); 42 | this.mappingJackson2HttpMessageConverter.write(o, MediaType.APPLICATION_JSON, mockHttpOutputMessage); 43 | return mockHttpOutputMessage.getBodyAsString(); 44 | } 45 | 46 | @Before 47 | public void setUp() throws Exception { 48 | 49 | this.mockMvc = MockMvcBuilders 50 | .webAppContextSetup(wac) 51 | .build(); 52 | } 53 | 54 | @Test 55 | public void testInit() throws Exception { 56 | assertTrue(true); 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /serving/src/main/java/com/mimacom/irisml/service/impl/IrisTensorflowClassifierService.java: -------------------------------------------------------------------------------- 1 
| package com.mimacom.irisml.service.impl; 2 | 3 | import com.mimacom.irisml.domain.Iris; 4 | import com.mimacom.irisml.domain.IrisType; 5 | import com.mimacom.irisml.service.IrisClassifierService; 6 | import org.springframework.beans.factory.annotation.Autowired; 7 | import org.springframework.beans.factory.annotation.Value; 8 | import org.springframework.stereotype.Service; 9 | import org.tensorflow.SavedModelBundle; 10 | import org.tensorflow.Session; 11 | import org.tensorflow.Tensor; 12 | 13 | import java.util.HashMap; 14 | import java.util.Map; 15 | 16 | @Service 17 | public class IrisTensorflowClassifierService implements IrisClassifierService { 18 | 19 | private final Session modelBundleSession; 20 | private final IrisType[] irisTypes; 21 | 22 | private final static String FEED_OPERATION = "dnn/input_from_feature_columns/input_layer/concat"; 23 | private final static String FETCH_OPERATION_PROBABILITIES = "dnn/head/predictions/probabilities"; 24 | private final static String FETCH_OPERATION_CLASS_ID = "dnn/head/predictions/class_ids"; 25 | 26 | @Autowired 27 | public IrisTensorflowClassifierService(@Value("${irisml.savedModel.path}") String savedModelPath, 28 | @Value("${irisml.savedModel.tags}") String savedModelTags) { 29 | this.modelBundleSession = SavedModelBundle.load(savedModelPath, savedModelTags).session(); 30 | this.irisTypes = IrisType.values(); 31 | } 32 | 33 | @Override 34 | public IrisType classify(Iris iris) { 35 | int category = this.fetchClassFromModel(iris); 36 | return this.irisTypes[category]; 37 | } 38 | 39 | @Override 40 | public Map classificationProbabilities(Iris iris){ 41 | Map results = new HashMap<>(irisTypes.length); 42 | float[][] vector = this.fetchProbabilitiesFromModel(iris); 43 | int resultsCount = vector[0].length; 44 | for (int i=0; i < resultsCount; i++){ 45 | results.put(irisTypes[i],vector[0][i]); 46 | } 47 | return results; 48 | } 49 | 50 | private float[][] fetchProbabilitiesFromModel(Iris iris) { 51 | Tensor 
inputTensor = IrisTensorflowClassifierService.createInputTensor(iris); 52 | 53 | Tensor result = this.modelBundleSession.runner() 54 | .feed(IrisTensorflowClassifierService.FEED_OPERATION, inputTensor) 55 | .fetch(IrisTensorflowClassifierService.FETCH_OPERATION_PROBABILITIES) 56 | .run().get(0); 57 | 58 | float[][] buffer = new float[1][3]; 59 | result.copyTo(buffer); 60 | return buffer; 61 | 62 | } 63 | 64 | private int fetchClassFromModel(Iris iris){ 65 | Tensor inputTensor = IrisTensorflowClassifierService.createInputTensor(iris); 66 | 67 | Tensor result = this.modelBundleSession.runner() 68 | .feed(IrisTensorflowClassifierService.FEED_OPERATION, inputTensor) 69 | .fetch(IrisTensorflowClassifierService.FETCH_OPERATION_CLASS_ID) 70 | .run().get(0); 71 | 72 | long[] buffer = new long[1]; 73 | result.copyTo(buffer); 74 | 75 | return (int)buffer[0]; 76 | } 77 | 78 | private static Tensor createInputTensor(Iris iris){ 79 | // order of the data on the input: PetalLength, PetalWidth, SepalLength, SepalWidth 80 | // (taken from the saved_model, node dnn/input_from_feature_columns/input_layer/concat) 81 | float[] input = {iris.getPetalLength(), iris.getPetalWidth(), iris.getSepalLength(), iris.getSepalWidth()}; 82 | float[][] data = new float[1][4]; 83 | data[0] = input; 84 | return Tensor.create(data); 85 | } 86 | } 87 | -------------------------------------------------------------------------------- /serving/src/test/java/com/mimacom/irisml/web/controller/IrisControllerTest.java: -------------------------------------------------------------------------------- 1 | package com.mimacom.irisml.web.controller; 2 | 3 | import com.google.gson.Gson; 4 | import com.google.gson.internal.LinkedTreeMap; 5 | import com.mimacom.irisml.domain.IrisType; 6 | import org.junit.Test; 7 | import org.springframework.mock.web.MockHttpServletRequest; 8 | import org.springframework.mock.web.MockHttpServletResponse; 9 | import org.springframework.test.web.servlet.MvcResult; 10 | 11 | import 
java.io.UnsupportedEncodingException; 12 | import java.util.Locale; 13 | import java.util.Map; 14 | 15 | import static org.junit.Assert.assertEquals; 16 | import static org.junit.Assert.assertNull; 17 | import static org.junit.Assert.assertTrue; 18 | import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; 19 | import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; 20 | 21 | /** 22 | * order of the data on the input: PetalLength, PetalWidth, SepalLength, SepalWidth 23 | * (taken from the saved_model, node dnn/input_from_feature_columns/input_layer/concat) 24 | * new Iris(1.3f, 0.3f,5.0f, 3.5f); // expected 0 Setosa 25 | * new Iris(4.4f, 1.4f, 6.7f, 3.1f); // expected 1 Versicolour 26 | * new Iris(6.1f, 1.9f,7.4f, 2.8f); // expected 2 Virginica 27 | */ 28 | public class IrisControllerTest extends BaseControllerTest { 29 | 30 | 31 | @Test 32 | public void classify() throws Exception { 33 | 34 | String urlTemplate = "/iris/classify/class?petalLength=%.1f&petalWidth=%.1f&sepalLength=%.1f&sepalWidth=%.1f"; 35 | 36 | // Locale.US to make sure the numbers are with period instead of comma. 
37 | String urlRequest = String.format(Locale.US, urlTemplate, 1.3f, 0.3f, 5.0f, 3.5f); 38 | MvcResult mvcResult = this.mockMvc.perform(get(urlRequest)) 39 | .andExpect(status().isOk()).andReturn(); 40 | 41 | assertEquals(IrisType.SETOSA.toString(), mvcResult.getResponse().getContentAsString().replace("\"", "")); 42 | 43 | 44 | urlRequest = String.format(Locale.US, urlTemplate, 4.4f, 1.4f, 6.7f, 3.1f); 45 | mvcResult = this.mockMvc.perform(get(urlRequest)) 46 | .andExpect(status().isOk()).andReturn(); 47 | 48 | assertEquals(IrisType.VERSICOLOUR.toString(), mvcResult.getResponse().getContentAsString().replace("\"", "")); 49 | 50 | 51 | urlRequest = String.format(Locale.US, urlTemplate, 6.1f, 1.9f, 7.4f, 2.8f); 52 | mvcResult = this.mockMvc.perform(get(urlRequest)) 53 | .andExpect(status().isOk()).andReturn(); 54 | 55 | assertEquals(IrisType.VIRGINICA.toString(), mvcResult.getResponse().getContentAsString().replace("\"", "")); 56 | 57 | } 58 | 59 | @Test 60 | public void classificationProbabilities() throws Exception { 61 | 62 | String urlTemplate = "/iris/classify/probabilities?petalLength=%.1f&petalWidth=%.1f&sepalLength=%.1f&sepalWidth=%.1f"; 63 | 64 | // Locale.US to make sure the numbers are with period instead of comma. 
65 | String urlRequest = String.format(Locale.US, urlTemplate, 1.3f, 0.3f, 5.0f, 3.5f); 66 | MvcResult mvcResult = this.mockMvc.perform(get(urlRequest)) 67 | .andExpect(status().isOk()).andReturn(); 68 | 69 | assertProbabilitiesResponse(mvcResult.getResponse(), IrisType.SETOSA); 70 | 71 | urlRequest = String.format(Locale.US, urlTemplate, 4.4f, 1.4f, 6.7f, 3.1f); 72 | mvcResult = this.mockMvc.perform(get(urlRequest)) 73 | .andExpect(status().isOk()).andReturn(); 74 | 75 | assertProbabilitiesResponse(mvcResult.getResponse(), IrisType.VERSICOLOUR); 76 | 77 | urlRequest = String.format(Locale.US, urlTemplate, 6.1f, 1.9f, 7.4f, 2.8f); 78 | mvcResult = this.mockMvc.perform(get(urlRequest)) 79 | .andExpect(status().isOk()).andReturn(); 80 | 81 | assertProbabilitiesResponse(mvcResult.getResponse(), IrisType.VIRGINICA); 82 | } 83 | 84 | 85 | private void assertProbabilitiesResponse(MockHttpServletResponse mockHttpServletResponse, IrisType expectedType) throws UnsupportedEncodingException { 86 | // Extract the probabilities response 87 | Gson gson = new Gson(); 88 | LinkedTreeMap probabilities; 89 | probabilities = (LinkedTreeMap) gson.fromJson(mockHttpServletResponse.getContentAsString(), Map.class); 90 | // Assert 91 | assertEquals(expectedType.toString(), getPredictedType(probabilities)); 92 | assertProbabilities(probabilities); 93 | } 94 | 95 | private String getPredictedType(LinkedTreeMap probabilities) { 96 | // The predicted type is the one with the highest probabilities 97 | String predictedType = probabilities.entrySet().stream().max(Map.Entry.comparingByValue()).get().getKey(); 98 | return predictedType; 99 | } 100 | 101 | private void assertProbabilities(LinkedTreeMap probabilities) { 102 | // The same amount of entries in the map as the possible values 103 | assertEquals(probabilities.size(), IrisType.values().length); 104 | 105 | // All the types have a probability value 106 | for(IrisType irisType: IrisType.values()){ 107 | 
assertTrue(probabilities.containsKey(irisType.toString())); 108 | } 109 | 110 | // All the entries have a value 111 | probabilities.entrySet().stream().forEach(probabilityEntry -> { 112 | assertTrue(probabilityEntry.getKey() != null); 113 | assertTrue(probabilityEntry.getValue() != null); 114 | }); 115 | } 116 | 117 | 118 | } -------------------------------------------------------------------------------- /training/.ipynb_checkpoints/TF_iris_data - working backup-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import tensorflow as tf\n", 10 | "import pandas as pd\n", 11 | "import numpy as numpy\n", 12 | "from sklearn import datasets\n" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 2, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "iris = datasets.load_iris()\n", 22 | "iris_data_w_target = [];\n", 23 | "# prepare data\n", 24 | "# the data comes originally with all the examples ordered, so to get a meaningful test sample, we need to shuffle it\n", 25 | "# add the target to the data to be able to shuffle and partion later\n", 26 | "for i in range(len(iris.data)):\n", 27 | " value = numpy.append(iris.data[i], iris.target[i])\n", 28 | " iris_data_w_target.append(value)\n", 29 | "\n", 30 | "#print(iris_data_w_target)\n", 31 | "feature_names = ['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth', 'label']\n", 32 | "\n" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 3, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "\n", 42 | "df = pd.DataFrame(data = iris_data_w_target, columns = feature_names )\n", 43 | "\n", 44 | "# shuffle our data\n", 45 | "df = df.sample(frac=1).reset_index(drop=True)\n", 46 | "\n", 47 | "# partition of our data\n", 48 | "# we reserve the 20% of the total records for testing\n", 
49 | "test_len = (len(df) * 20)//100;\n", 50 | "training_df = df[test_len:]\n", 51 | "test_df = df[:test_len]\n", 52 | "\n", 53 | "\n", 54 | "#print(test_df)\n" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 4, 60 | "metadata": {}, 61 | "outputs": [ 62 | { 63 | "name": "stdout", 64 | "output_type": "stream", 65 | "text": [ 66 | "INFO:tensorflow:Using default config.\n", 67 | "WARNING:tensorflow:Using temporary folder as model directory: C:\\Users\\_domine3\\AppData\\Local\\Temp\\tmp388xdcy3\n", 68 | "INFO:tensorflow:Using config: {'_log_step_count_steps': 100, '_train_distribute': None, '_keep_checkpoint_max': 5, '_save_checkpoints_steps': None, '_model_dir': 'C:\\\\Users\\\\_domine3\\\\AppData\\\\Local\\\\Temp\\\\tmp388xdcy3', '_tf_random_seed': None, '_global_id_in_cluster': 0, '_cluster_spec': , '_task_id': 0, '_save_summary_steps': 100, '_keep_checkpoint_every_n_hours': 10000, '_service': None, '_master': '', '_session_config': None, '_num_worker_replicas': 1, '_is_chief': True, '_num_ps_replicas': 0, '_task_type': 'worker', '_evaluation_master': '', '_save_checkpoints_secs': 600}\n", 69 | "INFO:tensorflow:Calling model_fn.\n", 70 | "INFO:tensorflow:Done calling model_fn.\n", 71 | "INFO:tensorflow:Create CheckpointSaverHook.\n", 72 | "INFO:tensorflow:Graph was finalized.\n", 73 | "INFO:tensorflow:Running local_init_op.\n", 74 | "INFO:tensorflow:Done running local_init_op.\n", 75 | "INFO:tensorflow:Saving checkpoints for 1 into C:\\Users\\_domine3\\AppData\\Local\\Temp\\tmp388xdcy3\\model.ckpt.\n", 76 | "INFO:tensorflow:loss = 139.2976, step = 1\n", 77 | "INFO:tensorflow:global_step/sec: 386.1\n", 78 | "INFO:tensorflow:loss = 15.095531, step = 101 (0.261 sec)\n", 79 | "INFO:tensorflow:global_step/sec: 512.821\n", 80 | "INFO:tensorflow:loss = 8.675394, step = 201 (0.195 sec)\n", 81 | "INFO:tensorflow:global_step/sec: 317.46\n", 82 | "INFO:tensorflow:loss = 9.85504, step = 301 (0.317 sec)\n", 83 | "INFO:tensorflow:global_step/sec: 
343.643\n", 84 | "INFO:tensorflow:loss = 6.7131195, step = 401 (0.292 sec)\n", 85 | "INFO:tensorflow:global_step/sec: 305.81\n", 86 | "INFO:tensorflow:loss = 4.9423385, step = 501 (0.325 sec)\n", 87 | "INFO:tensorflow:global_step/sec: 362.319\n", 88 | "INFO:tensorflow:loss = 5.3761263, step = 601 (0.278 sec)\n", 89 | "INFO:tensorflow:global_step/sec: 326.797\n", 90 | "INFO:tensorflow:loss = 3.7728348, step = 701 (0.307 sec)\n", 91 | "INFO:tensorflow:global_step/sec: 591.715\n", 92 | "INFO:tensorflow:loss = 3.9594324, step = 801 (0.167 sec)\n", 93 | "INFO:tensorflow:global_step/sec: 358.423\n", 94 | "INFO:tensorflow:loss = 3.5536861, step = 901 (0.278 sec)\n", 95 | "INFO:tensorflow:global_step/sec: 348.432\n", 96 | "INFO:tensorflow:loss = 4.591949, step = 1001 (0.285 sec)\n", 97 | "INFO:tensorflow:global_step/sec: 534.759\n", 98 | "INFO:tensorflow:loss = 4.2310295, step = 1101 (0.191 sec)\n", 99 | "INFO:tensorflow:global_step/sec: 423.729\n", 100 | "INFO:tensorflow:loss = 2.5031085, step = 1201 (0.234 sec)\n", 101 | "INFO:tensorflow:global_step/sec: 416.667\n", 102 | "INFO:tensorflow:loss = 1.7853639, step = 1301 (0.238 sec)\n", 103 | "INFO:tensorflow:global_step/sec: 343.642\n", 104 | "INFO:tensorflow:loss = 1.6998769, step = 1401 (0.294 sec)\n", 105 | "INFO:tensorflow:global_step/sec: 317.46\n", 106 | "INFO:tensorflow:loss = 3.4077053, step = 1501 (0.314 sec)\n", 107 | "INFO:tensorflow:global_step/sec: 390.625\n", 108 | "INFO:tensorflow:loss = 6.1948853, step = 1601 (0.256 sec)\n", 109 | "INFO:tensorflow:global_step/sec: 552.486\n", 110 | "INFO:tensorflow:loss = 2.1856198, step = 1701 (0.179 sec)\n", 111 | "INFO:tensorflow:global_step/sec: 537.634\n", 112 | "INFO:tensorflow:loss = 2.9113278, step = 1801 (0.186 sec)\n", 113 | "INFO:tensorflow:global_step/sec: 473.934\n", 114 | "INFO:tensorflow:loss = 2.9161034, step = 1901 (0.214 sec)\n", 115 | "INFO:tensorflow:Saving checkpoints for 2000 into C:\\Users\\_domine3\\AppData\\Local\\Temp\\tmp388xdcy3\\model.ckpt.\n", 
116 | "INFO:tensorflow:Loss for final step: 2.8931274.\n" 117 | ] 118 | }, 119 | { 120 | "data": { 121 | "text/plain": [ 122 | "" 123 | ] 124 | }, 125 | "execution_count": 4, 126 | "metadata": {}, 127 | "output_type": "execute_result" 128 | } 129 | ], 130 | "source": [ 131 | "# train the classifier\n", 132 | "\n", 133 | "my_feature_columns = [\n", 134 | " tf.contrib.layers.real_valued_column('SepalLength', dimension=1, dtype=tf.float32),\n", 135 | " tf.contrib.layers.real_valued_column('SepalWidth', dimension=1, dtype=tf.float32),\n", 136 | " tf.contrib.layers.real_valued_column('PetalLength', dimension=1, dtype=tf.float32),\n", 137 | " tf.contrib.layers.real_valued_column('PetalWidth', dimension=1, dtype=tf.float32)\n", 138 | "]\n", 139 | "\n", 140 | "tf.contrib.layers.real_valued_column\n", 141 | "\n", 142 | "# format data as required by the tensorflow framework\n", 143 | "x = {\n", 144 | " 'SepalLength': numpy.array(training_df['SepalLength']),\n", 145 | " 'SepalWidth': numpy.array(training_df['SepalWidth']),\n", 146 | " 'PetalLength': numpy.array(training_df['PetalLength']),\n", 147 | " 'PetalWidth': numpy.array(training_df['PetalWidth'])\n", 148 | "}\n", 149 | "\n", 150 | "#neural network\n", 151 | "classifier = tf.estimator.DNNClassifier(\n", 152 | " feature_columns=my_feature_columns,\n", 153 | " hidden_units=[10, 10],\n", 154 | " n_classes=3)\n", 155 | "\n", 156 | "\n", 157 | "# Define the training inputs\n", 158 | "train_input_fn = tf.estimator.inputs.numpy_input_fn(\n", 159 | " x=x,\n", 160 | " y=numpy.array(training_df['label']).astype(int),\n", 161 | " num_epochs=None,\n", 162 | " shuffle=True)\n", 163 | "\n", 164 | "# Train model.\n", 165 | "classifier.train(input_fn=train_input_fn, steps=2000)" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": 5, 171 | "metadata": {}, 172 | "outputs": [ 173 | { 174 | "name": "stdout", 175 | "output_type": "stream", 176 | "text": [ 177 | "INFO:tensorflow:Calling model_fn.\n", 178 | 
"INFO:tensorflow:Done calling model_fn.\n", 179 | "INFO:tensorflow:Starting evaluation at 2018-06-04-17:11:52\n", 180 | "INFO:tensorflow:Graph was finalized.\n", 181 | "INFO:tensorflow:Restoring parameters from C:\\Users\\_domine3\\AppData\\Local\\Temp\\tmp388xdcy3\\model.ckpt-2000\n", 182 | "INFO:tensorflow:Running local_init_op.\n", 183 | "INFO:tensorflow:Done running local_init_op.\n", 184 | "INFO:tensorflow:Finished evaluation at 2018-06-04-17:11:53\n", 185 | "INFO:tensorflow:Saving dict for global step 2000: accuracy = 0.93333334, average_loss = 0.21500692, global_step = 2000, loss = 6.4502077\n", 186 | "Test Accuracy: 0.93333334\n" 187 | ] 188 | } 189 | ], 190 | "source": [ 191 | "\n", 192 | "# test accuracy of the model\n", 193 | "# format data as required by the tensorflow framework\n", 194 | "x = {\n", 195 | " 'SepalLength': numpy.array(test_df['SepalLength']),\n", 196 | " 'SepalWidth': numpy.array(test_df['SepalWidth']),\n", 197 | " 'PetalLength': numpy.array(test_df['PetalLength']),\n", 198 | " 'PetalWidth': numpy.array(test_df['PetalWidth'])\n", 199 | "}\n", 200 | "\n", 201 | "# Define the training inputs\n", 202 | "test_input_fn = tf.estimator.inputs.numpy_input_fn(\n", 203 | " x=x,\n", 204 | " y=numpy.array(test_df['label']).astype(int),\n", 205 | " num_epochs=1,\n", 206 | " shuffle=False)\n", 207 | "\n", 208 | "# Evaluate accuracy.\n", 209 | "accuracy_score = classifier.evaluate(input_fn=test_input_fn)[\"accuracy\"]\n", 210 | "\n", 211 | "print(\"Test Accuracy: \", accuracy_score)\n" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": 1, 217 | "metadata": {}, 218 | "outputs": [ 219 | { 220 | "ename": "NameError", 221 | "evalue": "name 'numpy' is not defined", 222 | "output_type": "error", 223 | "traceback": [ 224 | "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", 225 | "\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)", 226 | "\u001b[1;32m\u001b[0m in 
\u001b[0;36m\u001b[1;34m()\u001b[0m\n\u001b[0;32m 2\u001b[0m \u001b[1;31m# format data as required by the tensorflow framework\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 3\u001b[0m x = {\n\u001b[1;32m----> 4\u001b[1;33m \u001b[1;34m'SepalLength'\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mnumpy\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0marray\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m5.0\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m6.7\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m7.4\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 5\u001b[0m \u001b[1;34m'SepalWidth'\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mnumpy\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0marray\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m3.5\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m3.1\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m2.8\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 6\u001b[0m \u001b[1;34m'PetalLength'\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mnumpy\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0marray\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m1.3\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m4.4\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m6.1\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 227 | "\u001b[1;31mNameError\u001b[0m: name 'numpy' is not defined" 228 | ] 229 | } 230 | ], 231 | "source": [ 232 | "# We do some more manual testing\n", 233 | "# format data as required by the tensorflow framework\n", 234 | "x = {\n", 235 | " 'SepalLength': numpy.array([5.0, 6.7, 7.4]),\n", 236 | " 'SepalWidth': numpy.array([3.5, 3.1, 2.8]),\n", 237 | " 'PetalLength': numpy.array([1.3, 4.4, 6.1]),\n", 238 | " 'PetalWidth': numpy.array([0.3, 1.4, 1.9])\n", 239 | "}\n", 240 | "\n", 241 | "expected = numpy.array([0, 1, 2])\n", 
242 | "\n", 243 | "# Define the training inputs\n", 244 | "predict_input_fn = tf.estimator.inputs.numpy_input_fn(\n", 245 | " x=x,\n", 246 | " num_epochs=1,\n", 247 | " shuffle=False)\n", 248 | "\n", 249 | "predictions = classifier.predict(input_fn=predict_input_fn)\n", 250 | "\n", 251 | "for pred_dict, expec in zip(predictions, expected):\n", 252 | " class_id = pred_dict['class_ids'][0]\n", 253 | " probability = pred_dict['probabilities'][class_id]\n", 254 | "\n", 255 | " print('\\nPrediction is \"{}\" ({:.1f}%), expected \"{}\"'.format(class_id, 100 * probability, expec))\n", 256 | " \n" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": 7, 262 | "metadata": {}, 263 | "outputs": [ 264 | { 265 | "name": "stdout", 266 | "output_type": "stream", 267 | "text": [ 268 | "INFO:tensorflow:Calling model_fn.\n", 269 | "INFO:tensorflow:Done calling model_fn.\n", 270 | "INFO:tensorflow:Signatures INCLUDED in export for Regress: None\n", 271 | "INFO:tensorflow:Signatures INCLUDED in export for Predict: ['predict']\n", 272 | "INFO:tensorflow:Signatures INCLUDED in export for Classify: ['serving_default', 'classification']\n", 273 | "INFO:tensorflow:Restoring parameters from C:\\Users\\_domine3\\AppData\\Local\\Temp\\tmp388xdcy3\\model.ckpt-2000\n", 274 | "INFO:tensorflow:Assets added to graph.\n", 275 | "INFO:tensorflow:No assets to write.\n", 276 | "INFO:tensorflow:SavedModel written to: b\"stored_model\\\\temp-b'1528132320'\\\\saved_model.pbtxt\"\n" 277 | ] 278 | }, 279 | { 280 | "data": { 281 | "text/plain": [ 282 | "b'stored_model\\\\1528132320'" 283 | ] 284 | }, 285 | "execution_count": 7, 286 | "metadata": {}, 287 | "output_type": "execute_result" 288 | } 289 | ], 290 | "source": [ 291 | "# tfrecord_serving_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(tf.contrib.layers.create_feature_spec_for_parsing(my_feature_columns)) \n", 292 | "\n", 293 | "def serving_input_receiver_fn():\n", 294 | " serialized_tf_example = 
tf.placeholder(dtype=tf.string, shape=[None], name='input_tensors')\n", 295 | " receiver_tensors = {\"predictor_inputs\": serialized_tf_example}\n", 296 | " feature_spec = {'SepalLength': tf.FixedLenFeature([25],tf.float32),\n", 297 | " 'SepalWidth': tf.FixedLenFeature([25],tf.float32),\n", 298 | " 'PetalLength': tf.FixedLenFeature([25],tf.float32),\n", 299 | " 'PetalWidth': tf.FixedLenFeature([25],tf.float32)}\n", 300 | " features = tf.parse_example(serialized_tf_example, feature_spec)\n", 301 | " return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)\n", 302 | "\n", 303 | "\n", 304 | "classifier.export_savedmodel(export_dir_base=\"stored_model\", serving_input_receiver_fn = serving_input_receiver_fn,as_text=True)\n", 305 | "\n" 306 | ] 307 | }, 308 | { 309 | "cell_type": "code", 310 | "execution_count": null, 311 | "metadata": {}, 312 | "outputs": [], 313 | "source": [] 314 | } 315 | ], 316 | "metadata": { 317 | "kernelspec": { 318 | "display_name": "Python 3", 319 | "language": "python", 320 | "name": "python3" 321 | }, 322 | "language_info": { 323 | "codemirror_mode": { 324 | "name": "ipython", 325 | "version": 3 326 | }, 327 | "file_extension": ".py", 328 | "mimetype": "text/x-python", 329 | "name": "python", 330 | "nbconvert_exporter": "python", 331 | "pygments_lexer": "ipython3", 332 | "version": "3.5.5" 333 | } 334 | }, 335 | "nbformat": 4, 336 | "nbformat_minor": 2 337 | } 338 | -------------------------------------------------------------------------------- /training/TF_iris_data.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 17, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import tensorflow as tf\n", 10 | "import pandas as pd\n", 11 | "import numpy as np\n", 12 | "from sklearn import datasets\n", 13 | "\n", 14 | "FEATURE_SEPAL_LENGTH = 'SepalLength'\n", 15 | "FEATURE_SEPAL_WIDTH = 'SepalWidth'\n", 16 | 
"FEATURE_PETAL_LENGTH = 'PetalLength'\n", 17 | "FEATURE_PETAL_WIDTH = 'PetalWidth'\n", 18 | "LABEL = 'label'" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 18, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "# load the data set\n", 28 | "iris = datasets.load_iris()\n", 29 | "iris_data_w_target = [];\n", 30 | "\n", 31 | "# prepare data\n", 32 | "\n", 33 | "# the data comes originally with all the examples ordered, so to get a meaningful test sample, we need to shuffle it\n", 34 | "\n", 35 | "# add the target to the data to be able to shuffle and partition it later\n", 36 | "for i in range(len(iris.data)):\n", 37 | " value = np.append(iris.data[i], iris.target[i])\n", 38 | " iris_data_w_target.append(value)\n", 39 | "\n", 40 | "#print(iris_data_w_target)\n", 41 | "columns_names = [FEATURE_SEPAL_LENGTH, FEATURE_SEPAL_WIDTH, FEATURE_PETAL_LENGTH, FEATURE_PETAL_WIDTH, LABEL]\n", 42 | "\n", 43 | "df = pd.DataFrame(data = iris_data_w_target, columns = columns_names )\n", 44 | "\n", 45 | "# shuffle our data\n", 46 | "df = df.sample(frac=1).reset_index(drop=True)\n" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 19, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "# partition of our data\n", 56 | "# we reserve the 20% of the total records for testing\n", 57 | "test_len = (len(df) * 20)//100;\n", 58 | "training_df = df[test_len:]\n", 59 | "test_df = df[:test_len]\n", 60 | "\n", 61 | "#print(test_df)" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 31, 67 | "metadata": {}, 68 | "outputs": [ 69 | { 70 | "name": "stdout", 71 | "output_type": "stream", 72 | "text": [ 73 | "INFO:tensorflow:Using default config.\n", 74 | "WARNING:tensorflow:Using temporary folder as model directory: C:\\Users\\_domine3\\AppData\\Local\\Temp\\tmpi2at4f5r\n", 75 | "INFO:tensorflow:Using config: {'_cluster_spec': , '_num_ps_replicas': 0, '_is_chief': True, '_save_checkpoints_secs': 600, 
'_num_worker_replicas': 1, '_task_id': 0, '_save_summary_steps': 100, '_evaluation_master': '', '_train_distribute': None, '_tf_random_seed': None, '_session_config': None, '_model_dir': 'C:\\\\Users\\\\_domine3\\\\AppData\\\\Local\\\\Temp\\\\tmpi2at4f5r', '_service': None, '_master': '', '_global_id_in_cluster': 0, '_task_type': 'worker', '_log_step_count_steps': 100, '_keep_checkpoint_every_n_hours': 10000, '_keep_checkpoint_max': 5, '_save_checkpoints_steps': None}\n", 76 | "INFO:tensorflow:Calling model_fn.\n", 77 | "INFO:tensorflow:Done calling model_fn.\n", 78 | "INFO:tensorflow:Create CheckpointSaverHook.\n", 79 | "INFO:tensorflow:Graph was finalized.\n", 80 | "INFO:tensorflow:Running local_init_op.\n", 81 | "INFO:tensorflow:Done running local_init_op.\n", 82 | "INFO:tensorflow:Saving checkpoints for 1 into C:\\Users\\_domine3\\AppData\\Local\\Temp\\tmpi2at4f5r\\model.ckpt.\n", 83 | "INFO:tensorflow:loss = 146.32745, step = 1\n", 84 | "INFO:tensorflow:global_step/sec: 543.478\n", 85 | "INFO:tensorflow:loss = 30.012009, step = 101 (0.187 sec)\n", 86 | "INFO:tensorflow:global_step/sec: 729.927\n", 87 | "INFO:tensorflow:loss = 15.369042, step = 201 (0.136 sec)\n", 88 | "INFO:tensorflow:global_step/sec: 775.194\n", 89 | "INFO:tensorflow:loss = 10.820699, step = 301 (0.130 sec)\n", 90 | "INFO:tensorflow:global_step/sec: 719.425\n", 91 | "INFO:tensorflow:loss = 13.417875, step = 401 (0.139 sec)\n", 92 | "INFO:tensorflow:global_step/sec: 719.425\n", 93 | "INFO:tensorflow:loss = 10.081392, step = 501 (0.139 sec)\n", 94 | "INFO:tensorflow:global_step/sec: 729.927\n", 95 | "INFO:tensorflow:loss = 7.979366, step = 601 (0.136 sec)\n", 96 | "INFO:tensorflow:global_step/sec: 763.359\n", 97 | "INFO:tensorflow:loss = 10.715974, step = 701 (0.133 sec)\n", 98 | "INFO:tensorflow:global_step/sec: 709.22\n", 99 | "INFO:tensorflow:loss = 9.723078, step = 801 (0.139 sec)\n", 100 | "INFO:tensorflow:global_step/sec: 704.226\n", 101 | "INFO:tensorflow:loss = 11.645809, step = 901 
(0.143 sec)\n", 102 | "INFO:tensorflow:global_step/sec: 740.741\n", 103 | "INFO:tensorflow:loss = 15.186651, step = 1001 (0.135 sec)\n", 104 | "INFO:tensorflow:global_step/sec: 751.88\n", 105 | "INFO:tensorflow:loss = 8.266764, step = 1101 (0.134 sec)\n", 106 | "INFO:tensorflow:global_step/sec: 769.23\n", 107 | "INFO:tensorflow:loss = 9.726784, step = 1201 (0.129 sec)\n", 108 | "INFO:tensorflow:global_step/sec: 657.895\n", 109 | "INFO:tensorflow:loss = 12.806259, step = 1301 (0.151 sec)\n", 110 | "INFO:tensorflow:global_step/sec: 584.795\n", 111 | "INFO:tensorflow:loss = 7.054368, step = 1401 (0.175 sec)\n", 112 | "INFO:tensorflow:global_step/sec: 561.798\n", 113 | "INFO:tensorflow:loss = 6.6224527, step = 1501 (0.176 sec)\n", 114 | "INFO:tensorflow:global_step/sec: 581.396\n", 115 | "INFO:tensorflow:loss = 3.9045315, step = 1601 (0.171 sec)\n", 116 | "INFO:tensorflow:global_step/sec: 751.879\n", 117 | "INFO:tensorflow:loss = 5.649955, step = 1701 (0.133 sec)\n", 118 | "INFO:tensorflow:global_step/sec: 724.638\n", 119 | "INFO:tensorflow:loss = 4.859914, step = 1801 (0.139 sec)\n", 120 | "INFO:tensorflow:global_step/sec: 641.026\n", 121 | "INFO:tensorflow:loss = 7.9570336, step = 1901 (0.153 sec)\n", 122 | "INFO:tensorflow:global_step/sec: 704.226\n", 123 | "INFO:tensorflow:loss = 4.841754, step = 2001 (0.144 sec)\n", 124 | "INFO:tensorflow:global_step/sec: 751.879\n", 125 | "INFO:tensorflow:loss = 7.0097547, step = 2101 (0.133 sec)\n", 126 | "INFO:tensorflow:global_step/sec: 684.932\n", 127 | "INFO:tensorflow:loss = 8.0823965, step = 2201 (0.146 sec)\n", 128 | "INFO:tensorflow:global_step/sec: 543.478\n", 129 | "INFO:tensorflow:loss = 5.6107244, step = 2301 (0.186 sec)\n", 130 | "INFO:tensorflow:global_step/sec: 537.634\n", 131 | "INFO:tensorflow:loss = 13.281709, step = 2401 (0.188 sec)\n", 132 | "INFO:tensorflow:global_step/sec: 373.134\n", 133 | "INFO:tensorflow:loss = 14.834627, step = 2501 (0.266 sec)\n", 134 | "INFO:tensorflow:global_step/sec: 450.45\n", 135 
| "INFO:tensorflow:loss = 5.130001, step = 2601 (0.223 sec)\n", 136 | "INFO:tensorflow:global_step/sec: 444.445\n", 137 | "INFO:tensorflow:loss = 9.435574, step = 2701 (0.222 sec)\n", 138 | "INFO:tensorflow:global_step/sec: 436.681\n", 139 | "INFO:tensorflow:loss = 1.4045376, step = 2801 (0.231 sec)\n", 140 | "INFO:tensorflow:global_step/sec: 403.226\n", 141 | "INFO:tensorflow:loss = 7.711853, step = 2901 (0.245 sec)\n", 142 | "INFO:tensorflow:global_step/sec: 421.941\n", 143 | "INFO:tensorflow:loss = 6.3117566, step = 3001 (0.239 sec)\n", 144 | "INFO:tensorflow:global_step/sec: 537.634\n", 145 | "INFO:tensorflow:loss = 2.9975328, step = 3101 (0.185 sec)\n", 146 | "INFO:tensorflow:global_step/sec: 581.396\n", 147 | "INFO:tensorflow:loss = 6.0504866, step = 3201 (0.174 sec)\n", 148 | "INFO:tensorflow:global_step/sec: 625\n", 149 | "INFO:tensorflow:loss = 3.089339, step = 3301 (0.161 sec)\n", 150 | "INFO:tensorflow:global_step/sec: 574.713\n", 151 | "INFO:tensorflow:loss = 3.684587, step = 3401 (0.171 sec)\n", 152 | "INFO:tensorflow:global_step/sec: 546.448\n", 153 | "INFO:tensorflow:loss = 6.9830103, step = 3501 (0.190 sec)\n", 154 | "INFO:tensorflow:global_step/sec: 680.272\n", 155 | "INFO:tensorflow:loss = 4.2849407, step = 3601 (0.141 sec)\n", 156 | "INFO:tensorflow:global_step/sec: 704.226\n", 157 | "INFO:tensorflow:loss = 3.649962, step = 3701 (0.140 sec)\n", 158 | "INFO:tensorflow:global_step/sec: 621.118\n", 159 | "INFO:tensorflow:loss = 6.5003743, step = 3801 (0.165 sec)\n", 160 | "INFO:tensorflow:global_step/sec: 613.497\n", 161 | "INFO:tensorflow:loss = 2.2862856, step = 3901 (0.160 sec)\n", 162 | "INFO:tensorflow:Saving checkpoints for 4000 into C:\\Users\\_domine3\\AppData\\Local\\Temp\\tmpi2at4f5r\\model.ckpt.\n", 163 | "INFO:tensorflow:Loss for final step: 6.831412.\n" 164 | ] 165 | }, 166 | { 167 | "data": { 168 | "text/plain": [ 169 | "" 170 | ] 171 | }, 172 | "execution_count": 31, 173 | "metadata": {}, 174 | "output_type": "execute_result" 175 | } 
176 | ], 177 | "source": [ 178 | "# train the classifier\n", 179 | "\n", 180 | "iris_feature_columns = [\n", 181 | " tf.contrib.layers.real_valued_column(FEATURE_SEPAL_LENGTH, dimension=1, dtype=tf.float32),\n", 182 | " tf.contrib.layers.real_valued_column(FEATURE_SEPAL_WIDTH, dimension=1, dtype=tf.float32),\n", 183 | " tf.contrib.layers.real_valued_column(FEATURE_PETAL_LENGTH, dimension=1, dtype=tf.float32),\n", 184 | " tf.contrib.layers.real_valued_column(FEATURE_PETAL_WIDTH, dimension=1, dtype=tf.float32)\n", 185 | "]\n", 186 | "\n", 187 | "\n", 188 | "# format data as required by the tensorflow framework\n", 189 | "x = {\n", 190 | " FEATURE_SEPAL_LENGTH : np.array(training_df[FEATURE_SEPAL_LENGTH]),\n", 191 | " FEATURE_SEPAL_WIDTH : np.array(training_df[FEATURE_SEPAL_WIDTH]),\n", 192 | " FEATURE_PETAL_LENGTH : np.array(training_df[FEATURE_PETAL_LENGTH]),\n", 193 | " FEATURE_PETAL_WIDTH : np.array(training_df[FEATURE_PETAL_WIDTH])\n", 194 | "}\n", 195 | "\n", 196 | "#neural network\n", 197 | "classifier = tf.estimator.DNNClassifier(\n", 198 | " feature_columns = iris_feature_columns,\n", 199 | " hidden_units = [5, 5],\n", 200 | " n_classes = 3)\n", 201 | "\n", 202 | "\n", 203 | "# Define the training inputs\n", 204 | "train_input_fn = tf.estimator.inputs.numpy_input_fn(\n", 205 | " x = x,\n", 206 | " y = np.array(training_df[LABEL]).astype(int),\n", 207 | " num_epochs = None,\n", 208 | " shuffle = True)\n", 209 | "\n", 210 | "# Train model.\n", 211 | "classifier.train(input_fn = train_input_fn, steps = 4000)" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": 32, 217 | "metadata": {}, 218 | "outputs": [ 219 | { 220 | "name": "stdout", 221 | "output_type": "stream", 222 | "text": [ 223 | "INFO:tensorflow:Calling model_fn.\n", 224 | "INFO:tensorflow:Done calling model_fn.\n", 225 | "INFO:tensorflow:Starting evaluation at 2018-06-27-09:33:54\n", 226 | "INFO:tensorflow:Graph was finalized.\n", 227 | "INFO:tensorflow:Restoring parameters 
from C:\\Users\\_domine3\\AppData\\Local\\Temp\\tmpi2at4f5r\\model.ckpt-4000\n", 228 | "INFO:tensorflow:Running local_init_op.\n", 229 | "INFO:tensorflow:Done running local_init_op.\n", 230 | "INFO:tensorflow:Finished evaluation at 2018-06-27-09:33:54\n", 231 | "INFO:tensorflow:Saving dict for global step 4000: accuracy = 0.96666664, average_loss = 0.08968501, global_step = 4000, loss = 2.6905503\n", 232 | "Test Accuracy: 0.96666664\n" 233 | ] 234 | } 235 | ], 236 | "source": [ 237 | "\n", 238 | "# test accuracy of the model\n", 239 | "# format data as required by the tensorflow framework\n", 240 | "x = {\n", 241 | " FEATURE_SEPAL_LENGTH : np.array(test_df[FEATURE_SEPAL_LENGTH]),\n", 242 | " FEATURE_SEPAL_WIDTH : np.array(test_df[FEATURE_SEPAL_WIDTH]),\n", 243 | " FEATURE_PETAL_LENGTH : np.array(test_df[FEATURE_PETAL_LENGTH]),\n", 244 | " FEATURE_PETAL_WIDTH : np.array(test_df[FEATURE_PETAL_WIDTH])\n", 245 | "}\n", 246 | "\n", 247 | "# Define the training inputs\n", 248 | "test_input_fn = tf.estimator.inputs.numpy_input_fn(\n", 249 | " x = x,\n", 250 | " y = np.array(test_df[LABEL]).astype(int),\n", 251 | " num_epochs = 1,\n", 252 | " shuffle = False)\n", 253 | "\n", 254 | "# Evaluate accuracy.\n", 255 | "accuracy_score = classifier.evaluate(input_fn=test_input_fn)[\"accuracy\"]\n", 256 | "\n", 257 | "print(\"Test Accuracy: \", accuracy_score)\n" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": 33, 263 | "metadata": {}, 264 | "outputs": [ 265 | { 266 | "name": "stdout", 267 | "output_type": "stream", 268 | "text": [ 269 | "INFO:tensorflow:Calling model_fn.\n", 270 | "INFO:tensorflow:Done calling model_fn.\n", 271 | "INFO:tensorflow:Graph was finalized.\n", 272 | "INFO:tensorflow:Restoring parameters from C:\\Users\\_domine3\\AppData\\Local\\Temp\\tmpi2at4f5r\\model.ckpt-4000\n", 273 | "INFO:tensorflow:Running local_init_op.\n", 274 | "INFO:tensorflow:Done running local_init_op.\n", 275 | "\n", 276 | "Prediction is \"0\" (certainity 
100.0%), expected \"0\"\n", 277 | "\n", 278 | "Prediction is \"1\" (certainity 100.0%), expected \"1\"\n", 279 | "\n", 280 | "Prediction is \"2\" (certainity 99.4%), expected \"2\"\n" 281 | ] 282 | } 283 | ], 284 | "source": [ 285 | "# We do some more manual testing\n", 286 | "# format data as required by the tensorflow framework\n", 287 | "x = {\n", 288 | " FEATURE_SEPAL_LENGTH : np.array([5.0, 6.7, 7.4]),\n", 289 | " FEATURE_SEPAL_WIDTH : np.array([3.5, 3.1, 2.8]),\n", 290 | " FEATURE_PETAL_LENGTH : np.array([1.3, 4.4, 6.1]),\n", 291 | " FEATURE_PETAL_WIDTH : np.array([0.3, 1.4, 1.9])\n", 292 | "}\n", 293 | "\n", 294 | "expected = np.array([0, 1, 2])\n", 295 | "\n", 296 | "# Define the training inputs\n", 297 | "predict_input_fn = tf.estimator.inputs.numpy_input_fn(\n", 298 | " x = x,\n", 299 | " num_epochs = 1,\n", 300 | " shuffle = False)\n", 301 | "\n", 302 | "predictions = classifier.predict(input_fn = predict_input_fn)\n", 303 | "\n", 304 | "for pred_dict, expec in zip(predictions, expected):\n", 305 | " class_id = pred_dict['class_ids'][0]\n", 306 | " probability = pred_dict['probabilities'][class_id]\n", 307 | "\n", 308 | " print('\\nPrediction is \"{}\" (certainity {:.1f}%), expected \"{}\"'.format(class_id, 100 * probability, expec))\n", 309 | " \n" 310 | ] 311 | }, 312 | { 313 | "cell_type": "code", 314 | "execution_count": 36, 315 | "metadata": {}, 316 | "outputs": [ 317 | { 318 | "name": "stdout", 319 | "output_type": "stream", 320 | "text": [ 321 | "INFO:tensorflow:Calling model_fn.\n", 322 | "INFO:tensorflow:Done calling model_fn.\n", 323 | "INFO:tensorflow:Signatures INCLUDED in export for Regress: None\n", 324 | "INFO:tensorflow:Signatures INCLUDED in export for Predict: ['predict']\n", 325 | "INFO:tensorflow:Signatures INCLUDED in export for Classify: ['classification', 'serving_default']\n", 326 | "INFO:tensorflow:Restoring parameters from C:\\Users\\_domine3\\AppData\\Local\\Temp\\tmpi2at4f5r\\model.ckpt-4000\n", 327 | "INFO:tensorflow:Assets 
added to graph.\n", 328 | "INFO:tensorflow:No assets to write.\n", 329 | "INFO:tensorflow:SavedModel written to: b\"stored_model\\\\temp-b'1530093489'\\\\saved_model.pbtxt\"\n", 330 | "Model exported to stored_model\\1530093489\n" 331 | ] 332 | } 333 | ], 334 | "source": [ 335 | "# export the model\n", 336 | "\n", 337 | "def serving_input_receiver_fn():\n", 338 | " serialized_tf_example = tf.placeholder(dtype = tf.string, shape = [None], name = 'input_tensors')\n", 339 | " receiver_tensors = {'predictor_inputs' : serialized_tf_example}\n", 340 | " feature_spec = {FEATURE_SEPAL_LENGTH : tf.FixedLenFeature([25], tf.float32),\n", 341 | " FEATURE_SEPAL_WIDTH : tf.FixedLenFeature([25], tf.float32),\n", 342 | " FEATURE_PETAL_LENGTH : tf.FixedLenFeature([25], tf.float32),\n", 343 | " FEATURE_PETAL_WIDTH : tf.FixedLenFeature([25], tf.float32)}\n", 344 | " features = tf.parse_example(serialized_tf_example, feature_spec)\n", 345 | " return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)\n", 346 | "\n", 347 | "\n", 348 | "model_dir = classifier.export_savedmodel(export_dir_base = \"stored_model\", \n", 349 | " serving_input_receiver_fn = serving_input_receiver_fn,\n", 350 | " as_text = True)\n", 351 | "print('Model exported to '+ model_dir.decode())" 352 | ] 353 | }, 354 | { 355 | "cell_type": "code", 356 | "execution_count": 35, 357 | "metadata": {}, 358 | "outputs": [ 359 | { 360 | "name": "stdout", 361 | "output_type": "stream", 362 | "text": [ 363 | "INFO:tensorflow:Restoring parameters from b'./stored_model/1530092040\\\\variables\\\\variables'\n", 364 | "Model ready for visualization. 
Execute: tensorboard --logdir=./visualization/tensorboard/1530092040\n" 365 | ] 366 | } 367 | ], 368 | "source": [ 369 | "# Export the saved model so it can be visualized on the tensorboard \n", 370 | "# Code inspired from https://github.com/tensorflow/tensorflow/issues/8854\n", 371 | "\n", 372 | "# Work in progress: The model is shown in the tensorboard, but it is not showing the right data.\n", 373 | "\n", 374 | "from tensorflow.python.summary import summary\n", 375 | "\n", 376 | "log_dir = './visualization/tensorboard/1530093489' # + model_dir.decode().replace('stored_model\\\\','')\n", 377 | "model_dir= './stored_model/1530093489'\n", 378 | "with tf.Session(graph = tf.Graph()) as sess:\n", 379 | " tf.saved_model.loader.load(\n", 380 | " sess,\n", 381 | " [tf.saved_model.tag_constants.SERVING],\n", 382 | " model_dir)\n", 383 | "\n", 384 | " pb_visual_writer = summary.FileWriter(log_dir)\n", 385 | " pb_visual_writer.add_graph(sess.graph)\n", 386 | " print(\"Model ready for visualization. 
Execute: \"\n", 387 | " \"tensorboard --logdir={}\".format(log_dir))" 388 | ] 389 | }, 390 | { 391 | "cell_type": "code", 392 | "execution_count": null, 393 | "metadata": {}, 394 | "outputs": [], 395 | "source": [] 396 | } 397 | ], 398 | "metadata": { 399 | "kernelspec": { 400 | "display_name": "Python 3", 401 | "language": "python", 402 | "name": "python3" 403 | }, 404 | "language_info": { 405 | "codemirror_mode": { 406 | "name": "ipython", 407 | "version": 3 408 | }, 409 | "file_extension": ".py", 410 | "mimetype": "text/x-python", 411 | "name": "python", 412 | "nbconvert_exporter": "python", 413 | "pygments_lexer": "ipython3", 414 | "version": "3.5.5" 415 | } 416 | }, 417 | "nbformat": 4, 418 | "nbformat_minor": 2 419 | } 420 | -------------------------------------------------------------------------------- /training/.ipynb_checkpoints/TF_iris_data-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 17, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import tensorflow as tf\n", 10 | "import pandas as pd\n", 11 | "import numpy as np\n", 12 | "from sklearn import datasets\n", 13 | "\n", 14 | "FEATURE_SEPAL_LENGTH = 'SepalLength'\n", 15 | "FEATURE_SEPAL_WIDTH = 'SepalWidth'\n", 16 | "FEATURE_PETAL_LENGTH = 'PetalLength'\n", 17 | "FEATURE_PETAL_WIDTH = 'PetalWidth'\n", 18 | "LABEL = 'label'" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 18, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "# load the data set\n", 28 | "iris = datasets.load_iris()\n", 29 | "iris_data_w_target = [];\n", 30 | "\n", 31 | "# prepare data\n", 32 | "\n", 33 | "# the data comes originally with all the examples ordered, so to get a meaningful test sample, we need to shuffle it\n", 34 | "\n", 35 | "# add the target to the data to be able to shuffle and partition it later\n", 36 | "for i in range(len(iris.data)):\n", 37 | 
" value = np.append(iris.data[i], iris.target[i])\n", 38 | " iris_data_w_target.append(value)\n", 39 | "\n", 40 | "#print(iris_data_w_target)\n", 41 | "columns_names = [FEATURE_SEPAL_LENGTH, FEATURE_SEPAL_WIDTH, FEATURE_PETAL_LENGTH, FEATURE_PETAL_WIDTH, LABEL]\n", 42 | "\n", 43 | "df = pd.DataFrame(data = iris_data_w_target, columns = columns_names )\n", 44 | "\n", 45 | "# shuffle our data\n", 46 | "df = df.sample(frac=1).reset_index(drop=True)\n" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 19, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "# partition of our data\n", 56 | "# we reserve the 20% of the total records for testing\n", 57 | "test_len = (len(df) * 20)//100;\n", 58 | "training_df = df[test_len:]\n", 59 | "test_df = df[:test_len]\n", 60 | "\n", 61 | "#print(test_df)" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 31, 67 | "metadata": {}, 68 | "outputs": [ 69 | { 70 | "name": "stdout", 71 | "output_type": "stream", 72 | "text": [ 73 | "INFO:tensorflow:Using default config.\n", 74 | "WARNING:tensorflow:Using temporary folder as model directory: C:\\Users\\_domine3\\AppData\\Local\\Temp\\tmpi2at4f5r\n", 75 | "INFO:tensorflow:Using config: {'_cluster_spec': , '_num_ps_replicas': 0, '_is_chief': True, '_save_checkpoints_secs': 600, '_num_worker_replicas': 1, '_task_id': 0, '_save_summary_steps': 100, '_evaluation_master': '', '_train_distribute': None, '_tf_random_seed': None, '_session_config': None, '_model_dir': 'C:\\\\Users\\\\_domine3\\\\AppData\\\\Local\\\\Temp\\\\tmpi2at4f5r', '_service': None, '_master': '', '_global_id_in_cluster': 0, '_task_type': 'worker', '_log_step_count_steps': 100, '_keep_checkpoint_every_n_hours': 10000, '_keep_checkpoint_max': 5, '_save_checkpoints_steps': None}\n", 76 | "INFO:tensorflow:Calling model_fn.\n", 77 | "INFO:tensorflow:Done calling model_fn.\n", 78 | "INFO:tensorflow:Create CheckpointSaverHook.\n", 79 | "INFO:tensorflow:Graph was finalized.\n", 
80 | "INFO:tensorflow:Running local_init_op.\n", 81 | "INFO:tensorflow:Done running local_init_op.\n", 82 | "INFO:tensorflow:Saving checkpoints for 1 into C:\\Users\\_domine3\\AppData\\Local\\Temp\\tmpi2at4f5r\\model.ckpt.\n", 83 | "INFO:tensorflow:loss = 146.32745, step = 1\n", 84 | "INFO:tensorflow:global_step/sec: 543.478\n", 85 | "INFO:tensorflow:loss = 30.012009, step = 101 (0.187 sec)\n", 86 | "INFO:tensorflow:global_step/sec: 729.927\n", 87 | "INFO:tensorflow:loss = 15.369042, step = 201 (0.136 sec)\n", 88 | "INFO:tensorflow:global_step/sec: 775.194\n", 89 | "INFO:tensorflow:loss = 10.820699, step = 301 (0.130 sec)\n", 90 | "INFO:tensorflow:global_step/sec: 719.425\n", 91 | "INFO:tensorflow:loss = 13.417875, step = 401 (0.139 sec)\n", 92 | "INFO:tensorflow:global_step/sec: 719.425\n", 93 | "INFO:tensorflow:loss = 10.081392, step = 501 (0.139 sec)\n", 94 | "INFO:tensorflow:global_step/sec: 729.927\n", 95 | "INFO:tensorflow:loss = 7.979366, step = 601 (0.136 sec)\n", 96 | "INFO:tensorflow:global_step/sec: 763.359\n", 97 | "INFO:tensorflow:loss = 10.715974, step = 701 (0.133 sec)\n", 98 | "INFO:tensorflow:global_step/sec: 709.22\n", 99 | "INFO:tensorflow:loss = 9.723078, step = 801 (0.139 sec)\n", 100 | "INFO:tensorflow:global_step/sec: 704.226\n", 101 | "INFO:tensorflow:loss = 11.645809, step = 901 (0.143 sec)\n", 102 | "INFO:tensorflow:global_step/sec: 740.741\n", 103 | "INFO:tensorflow:loss = 15.186651, step = 1001 (0.135 sec)\n", 104 | "INFO:tensorflow:global_step/sec: 751.88\n", 105 | "INFO:tensorflow:loss = 8.266764, step = 1101 (0.134 sec)\n", 106 | "INFO:tensorflow:global_step/sec: 769.23\n", 107 | "INFO:tensorflow:loss = 9.726784, step = 1201 (0.129 sec)\n", 108 | "INFO:tensorflow:global_step/sec: 657.895\n", 109 | "INFO:tensorflow:loss = 12.806259, step = 1301 (0.151 sec)\n", 110 | "INFO:tensorflow:global_step/sec: 584.795\n", 111 | "INFO:tensorflow:loss = 7.054368, step = 1401 (0.175 sec)\n", 112 | "INFO:tensorflow:global_step/sec: 561.798\n", 113 | 
"INFO:tensorflow:loss = 6.6224527, step = 1501 (0.176 sec)\n", 114 | "INFO:tensorflow:global_step/sec: 581.396\n", 115 | "INFO:tensorflow:loss = 3.9045315, step = 1601 (0.171 sec)\n", 116 | "INFO:tensorflow:global_step/sec: 751.879\n", 117 | "INFO:tensorflow:loss = 5.649955, step = 1701 (0.133 sec)\n", 118 | "INFO:tensorflow:global_step/sec: 724.638\n", 119 | "INFO:tensorflow:loss = 4.859914, step = 1801 (0.139 sec)\n", 120 | "INFO:tensorflow:global_step/sec: 641.026\n", 121 | "INFO:tensorflow:loss = 7.9570336, step = 1901 (0.153 sec)\n", 122 | "INFO:tensorflow:global_step/sec: 704.226\n", 123 | "INFO:tensorflow:loss = 4.841754, step = 2001 (0.144 sec)\n", 124 | "INFO:tensorflow:global_step/sec: 751.879\n", 125 | "INFO:tensorflow:loss = 7.0097547, step = 2101 (0.133 sec)\n", 126 | "INFO:tensorflow:global_step/sec: 684.932\n", 127 | "INFO:tensorflow:loss = 8.0823965, step = 2201 (0.146 sec)\n", 128 | "INFO:tensorflow:global_step/sec: 543.478\n", 129 | "INFO:tensorflow:loss = 5.6107244, step = 2301 (0.186 sec)\n", 130 | "INFO:tensorflow:global_step/sec: 537.634\n", 131 | "INFO:tensorflow:loss = 13.281709, step = 2401 (0.188 sec)\n", 132 | "INFO:tensorflow:global_step/sec: 373.134\n", 133 | "INFO:tensorflow:loss = 14.834627, step = 2501 (0.266 sec)\n", 134 | "INFO:tensorflow:global_step/sec: 450.45\n", 135 | "INFO:tensorflow:loss = 5.130001, step = 2601 (0.223 sec)\n", 136 | "INFO:tensorflow:global_step/sec: 444.445\n", 137 | "INFO:tensorflow:loss = 9.435574, step = 2701 (0.222 sec)\n", 138 | "INFO:tensorflow:global_step/sec: 436.681\n", 139 | "INFO:tensorflow:loss = 1.4045376, step = 2801 (0.231 sec)\n", 140 | "INFO:tensorflow:global_step/sec: 403.226\n", 141 | "INFO:tensorflow:loss = 7.711853, step = 2901 (0.245 sec)\n", 142 | "INFO:tensorflow:global_step/sec: 421.941\n", 143 | "INFO:tensorflow:loss = 6.3117566, step = 3001 (0.239 sec)\n", 144 | "INFO:tensorflow:global_step/sec: 537.634\n", 145 | "INFO:tensorflow:loss = 2.9975328, step = 3101 (0.185 sec)\n", 146 | 
"INFO:tensorflow:global_step/sec: 581.396\n", 147 | "INFO:tensorflow:loss = 6.0504866, step = 3201 (0.174 sec)\n", 148 | "INFO:tensorflow:global_step/sec: 625\n", 149 | "INFO:tensorflow:loss = 3.089339, step = 3301 (0.161 sec)\n", 150 | "INFO:tensorflow:global_step/sec: 574.713\n", 151 | "INFO:tensorflow:loss = 3.684587, step = 3401 (0.171 sec)\n", 152 | "INFO:tensorflow:global_step/sec: 546.448\n", 153 | "INFO:tensorflow:loss = 6.9830103, step = 3501 (0.190 sec)\n", 154 | "INFO:tensorflow:global_step/sec: 680.272\n", 155 | "INFO:tensorflow:loss = 4.2849407, step = 3601 (0.141 sec)\n", 156 | "INFO:tensorflow:global_step/sec: 704.226\n", 157 | "INFO:tensorflow:loss = 3.649962, step = 3701 (0.140 sec)\n", 158 | "INFO:tensorflow:global_step/sec: 621.118\n", 159 | "INFO:tensorflow:loss = 6.5003743, step = 3801 (0.165 sec)\n", 160 | "INFO:tensorflow:global_step/sec: 613.497\n", 161 | "INFO:tensorflow:loss = 2.2862856, step = 3901 (0.160 sec)\n", 162 | "INFO:tensorflow:Saving checkpoints for 4000 into C:\\Users\\_domine3\\AppData\\Local\\Temp\\tmpi2at4f5r\\model.ckpt.\n", 163 | "INFO:tensorflow:Loss for final step: 6.831412.\n" 164 | ] 165 | }, 166 | { 167 | "data": { 168 | "text/plain": [ 169 | "" 170 | ] 171 | }, 172 | "execution_count": 31, 173 | "metadata": {}, 174 | "output_type": "execute_result" 175 | } 176 | ], 177 | "source": [ 178 | "# train the classifier\n", 179 | "\n", 180 | "iris_feature_columns = [\n", 181 | " tf.contrib.layers.real_valued_column(FEATURE_SEPAL_LENGTH, dimension=1, dtype=tf.float32),\n", 182 | " tf.contrib.layers.real_valued_column(FEATURE_SEPAL_WIDTH, dimension=1, dtype=tf.float32),\n", 183 | " tf.contrib.layers.real_valued_column(FEATURE_PETAL_LENGTH, dimension=1, dtype=tf.float32),\n", 184 | " tf.contrib.layers.real_valued_column(FEATURE_PETAL_WIDTH, dimension=1, dtype=tf.float32)\n", 185 | "]\n", 186 | "\n", 187 | "\n", 188 | "# format data as required by the tensorflow framework\n", 189 | "x = {\n", 190 | " FEATURE_SEPAL_LENGTH : 
np.array(training_df[FEATURE_SEPAL_LENGTH]),\n", 191 | " FEATURE_SEPAL_WIDTH : np.array(training_df[FEATURE_SEPAL_WIDTH]),\n", 192 | " FEATURE_PETAL_LENGTH : np.array(training_df[FEATURE_PETAL_LENGTH]),\n", 193 | " FEATURE_PETAL_WIDTH : np.array(training_df[FEATURE_PETAL_WIDTH])\n", 194 | "}\n", 195 | "\n", 196 | "#neural network\n", 197 | "classifier = tf.estimator.DNNClassifier(\n", 198 | " feature_columns = iris_feature_columns,\n", 199 | " hidden_units = [5, 5],\n", 200 | " n_classes = 3)\n", 201 | "\n", 202 | "\n", 203 | "# Define the training inputs\n", 204 | "train_input_fn = tf.estimator.inputs.numpy_input_fn(\n", 205 | " x = x,\n", 206 | " y = np.array(training_df[LABEL]).astype(int),\n", 207 | " num_epochs = None,\n", 208 | " shuffle = True)\n", 209 | "\n", 210 | "# Train model.\n", 211 | "classifier.train(input_fn = train_input_fn, steps = 4000)" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": 32, 217 | "metadata": {}, 218 | "outputs": [ 219 | { 220 | "name": "stdout", 221 | "output_type": "stream", 222 | "text": [ 223 | "INFO:tensorflow:Calling model_fn.\n", 224 | "INFO:tensorflow:Done calling model_fn.\n", 225 | "INFO:tensorflow:Starting evaluation at 2018-06-27-09:33:54\n", 226 | "INFO:tensorflow:Graph was finalized.\n", 227 | "INFO:tensorflow:Restoring parameters from C:\\Users\\_domine3\\AppData\\Local\\Temp\\tmpi2at4f5r\\model.ckpt-4000\n", 228 | "INFO:tensorflow:Running local_init_op.\n", 229 | "INFO:tensorflow:Done running local_init_op.\n", 230 | "INFO:tensorflow:Finished evaluation at 2018-06-27-09:33:54\n", 231 | "INFO:tensorflow:Saving dict for global step 4000: accuracy = 0.96666664, average_loss = 0.08968501, global_step = 4000, loss = 2.6905503\n", 232 | "Test Accuracy: 0.96666664\n" 233 | ] 234 | } 235 | ], 236 | "source": [ 237 | "\n", 238 | "# test accuracy of the model\n", 239 | "# format data as required by the tensorflow framework\n", 240 | "x = {\n", 241 | " FEATURE_SEPAL_LENGTH : 
np.array(test_df[FEATURE_SEPAL_LENGTH]),\n", 242 | " FEATURE_SEPAL_WIDTH : np.array(test_df[FEATURE_SEPAL_WIDTH]),\n", 243 | " FEATURE_PETAL_LENGTH : np.array(test_df[FEATURE_PETAL_LENGTH]),\n", 244 | " FEATURE_PETAL_WIDTH : np.array(test_df[FEATURE_PETAL_WIDTH])\n", 245 | "}\n", 246 | "\n", 247 | "# Define the training inputs\n", 248 | "test_input_fn = tf.estimator.inputs.numpy_input_fn(\n", 249 | " x = x,\n", 250 | " y = np.array(test_df[LABEL]).astype(int),\n", 251 | " num_epochs = 1,\n", 252 | " shuffle = False)\n", 253 | "\n", 254 | "# Evaluate accuracy.\n", 255 | "accuracy_score = classifier.evaluate(input_fn=test_input_fn)[\"accuracy\"]\n", 256 | "\n", 257 | "print(\"Test Accuracy: \", accuracy_score)\n" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": 33, 263 | "metadata": {}, 264 | "outputs": [ 265 | { 266 | "name": "stdout", 267 | "output_type": "stream", 268 | "text": [ 269 | "INFO:tensorflow:Calling model_fn.\n", 270 | "INFO:tensorflow:Done calling model_fn.\n", 271 | "INFO:tensorflow:Graph was finalized.\n", 272 | "INFO:tensorflow:Restoring parameters from C:\\Users\\_domine3\\AppData\\Local\\Temp\\tmpi2at4f5r\\model.ckpt-4000\n", 273 | "INFO:tensorflow:Running local_init_op.\n", 274 | "INFO:tensorflow:Done running local_init_op.\n", 275 | "\n", 276 | "Prediction is \"0\" (certainity 100.0%), expected \"0\"\n", 277 | "\n", 278 | "Prediction is \"1\" (certainity 100.0%), expected \"1\"\n", 279 | "\n", 280 | "Prediction is \"2\" (certainity 99.4%), expected \"2\"\n" 281 | ] 282 | } 283 | ], 284 | "source": [ 285 | "# We do some more manual testing\n", 286 | "# format data as required by the tensorflow framework\n", 287 | "x = {\n", 288 | " FEATURE_SEPAL_LENGTH : np.array([5.0, 6.7, 7.4]),\n", 289 | " FEATURE_SEPAL_WIDTH : np.array([3.5, 3.1, 2.8]),\n", 290 | " FEATURE_PETAL_LENGTH : np.array([1.3, 4.4, 6.1]),\n", 291 | " FEATURE_PETAL_WIDTH : np.array([0.3, 1.4, 1.9])\n", 292 | "}\n", 293 | "\n", 294 | "expected = np.array([0, 
1, 2])\n", 295 | "\n", 296 | "# Define the training inputs\n", 297 | "predict_input_fn = tf.estimator.inputs.numpy_input_fn(\n", 298 | "    x = x,\n", 299 | "    num_epochs = 1,\n", 300 | "    shuffle = False)\n", 301 | "\n", 302 | "predictions = classifier.predict(input_fn = predict_input_fn)\n", 303 | "\n", 304 | "for pred_dict, expec in zip(predictions, expected):\n", 305 | "    class_id = pred_dict['class_ids'][0]\n", 306 | "    probability = pred_dict['probabilities'][class_id]\n", 307 | "\n", 308 | "    print('\\nPrediction is \"{}\" (certainty {:.1f}%), expected \"{}\"'.format(class_id, 100 * probability, expec))\n", 309 | "    \n" 310 | ] 311 | }, 312 | { 313 | "cell_type": "code", 314 | "execution_count": 36, 315 | "metadata": {}, 316 | "outputs": [ 317 | { 318 | "name": "stdout", 319 | "output_type": "stream", 320 | "text": [ 321 | "INFO:tensorflow:Calling model_fn.\n", 322 | "INFO:tensorflow:Done calling model_fn.\n", 323 | "INFO:tensorflow:Signatures INCLUDED in export for Regress: None\n", 324 | "INFO:tensorflow:Signatures INCLUDED in export for Predict: ['predict']\n", 325 | "INFO:tensorflow:Signatures INCLUDED in export for Classify: ['classification', 'serving_default']\n", 326 | "INFO:tensorflow:Restoring parameters from C:\\Users\\_domine3\\AppData\\Local\\Temp\\tmpi2at4f5r\\model.ckpt-4000\n", 327 | "INFO:tensorflow:Assets added to graph.\n", 328 | "INFO:tensorflow:No assets to write.\n", 329 | "INFO:tensorflow:SavedModel written to: b\"stored_model\\\\temp-b'1530093489'\\\\saved_model.pbtxt\"\n", 330 | "Model exported to stored_model\\1530093489\n" 331 | ] 332 | }, 333 | ], 334 | "source": [ 335 | "# export the model\n", 336 | "\n", 337 | "def serving_input_receiver_fn():\n", 338 | "    serialized_tf_example = tf.placeholder(dtype = tf.string, shape = [None], name = 'input_tensors')\n", 339 | "    receiver_tensors = {'predictor_inputs' : serialized_tf_example}\n", 340 | "    feature_spec = {FEATURE_SEPAL_LENGTH : tf.FixedLenFeature([25], tf.float32),\n", 341 | "                    
FEATURE_SEPAL_WIDTH : tf.FixedLenFeature([25], tf.float32),\n", 342 | " FEATURE_PETAL_LENGTH : tf.FixedLenFeature([25], tf.float32),\n", 343 | " FEATURE_PETAL_WIDTH : tf.FixedLenFeature([25], tf.float32)}\n", 344 | " features = tf.parse_example(serialized_tf_example, feature_spec)\n", 345 | " return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)\n", 346 | "\n", 347 | "\n", 348 | "model_dir = classifier.export_savedmodel(export_dir_base = \"stored_model\", \n", 349 | " serving_input_receiver_fn = serving_input_receiver_fn,\n", 350 | " as_text = True)\n", 351 | "print('Model exported to '+ model_dir.decode())" 352 | ] 353 | }, 354 | { 355 | "cell_type": "code", 356 | "execution_count": 35, 357 | "metadata": {}, 358 | "outputs": [ 359 | { 360 | "name": "stdout", 361 | "output_type": "stream", 362 | "text": [ 363 | "INFO:tensorflow:Restoring parameters from b'./stored_model/1530092040\\\\variables\\\\variables'\n", 364 | "Model ready for visualization. Execute: tensorboard --logdir=./visualization/tensorboard/1530092040\n" 365 | ] 366 | } 367 | ], 368 | "source": [ 369 | "# Export the saved model so it can be visualized on the tensorboard \n", 370 | "# Code inspired from https://github.com/tensorflow/tensorflow/issues/8854\n", 371 | "\n", 372 | "# Work in progress: The model is shown in the tensorboard, but it is not showing the right data.\n", 373 | "\n", 374 | "from tensorflow.python.summary import summary\n", 375 | "\n", 376 | "log_dir = './visualization/tensorboard/1530093489' # + model_dir.decode().replace('stored_model\\\\','')\n", 377 | "model_dir= './stored_model/1530093489'\n", 378 | "with tf.Session(graph = tf.Graph()) as sess:\n", 379 | " tf.saved_model.loader.load(\n", 380 | " sess,\n", 381 | " [tf.saved_model.tag_constants.SERVING],\n", 382 | " model_dir)\n", 383 | "\n", 384 | " pb_visual_writer = summary.FileWriter(log_dir)\n", 385 | " pb_visual_writer.add_graph(sess.graph)\n", 386 | " print(\"Model ready for visualization. 
Execute: \"\n", 387 | " \"tensorboard --logdir={}\".format(log_dir))" 388 | ] 389 | }, 390 | { 391 | "cell_type": "code", 392 | "execution_count": null, 393 | "metadata": {}, 394 | "outputs": [], 395 | "source": [] 396 | } 397 | ], 398 | "metadata": { 399 | "kernelspec": { 400 | "display_name": "Python 3", 401 | "language": "python", 402 | "name": "python3" 403 | }, 404 | "language_info": { 405 | "codemirror_mode": { 406 | "name": "ipython", 407 | "version": 3 408 | }, 409 | "file_extension": ".py", 410 | "mimetype": "text/x-python", 411 | "name": "python", 412 | "nbconvert_exporter": "python", 413 | "pygments_lexer": "ipython3", 414 | "version": "3.5.5" 415 | } 416 | }, 417 | "nbformat": 4, 418 | "nbformat_minor": 2 419 | } 420 | --------------------------------------------------------------------------------