├── topn_sorter_test.go ├── server ├── Dockerfile └── server.go ├── vendor ├── vendor.json └── github.com │ └── tensorflow │ └── tensorflow │ ├── tensorflow │ └── go │ │ ├── BUILD │ │ ├── session.cpp │ │ ├── version.go │ │ ├── doc.go │ │ ├── lib.go │ │ ├── status.go │ │ ├── test.sh │ │ ├── saved_model.go │ │ ├── shape.go │ │ ├── operation.go │ │ ├── README.md │ │ ├── graph.go │ │ ├── session.go │ │ └── tensor.go │ └── LICENSE ├── README.md ├── topn_sorter.go ├── beam.go ├── vocabulary.go ├── tool └── benchmark.go └── beam_search.go /topn_sorter_test.go: -------------------------------------------------------------------------------- 1 | package gotalk 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | ) 7 | 8 | // TestSort checks that topNSort returns the indices of the n largest values, 9 | // largest first, and every index when n exceeds len(array). 10 | func TestSort(t *testing.T) { 11 | array := []float32{1.0, 3.0, 2.0, 1.5, 7.2, 1.1} 12 | cases := []struct { 13 | n int 14 | want []int 15 | }{ 16 | {3, []int{4, 1, 2}}, 17 | {4, []int{4, 1, 2, 3}}, 18 | {1, []int{4}}, 19 | {7, []int{4, 1, 2, 3, 5, 0}}, 20 | {0, []int{}}, 21 | } 22 | for _, c := range cases { 23 | if got := topNSort(array, c.n); !reflect.DeepEqual(got, c.want) { 24 | t.Errorf("topNSort(%v, %d) = %v, want %v", array, c.n, got, c.want) 25 | } 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /server/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:16.04 2 | 3 | RUN mkdir /gotalk 4 | 5 | ADD libtensorflow.so /usr/lib/libtensorflow.so 6 | ADD frozen_model.pb /gotalk/frozen_model.pb 7 | ADD vocabulary.txt /gotalk/vocabulary.txt 8 | ADD server /gotalk/server 9 | 10 | EXPOSE 80 11 | 12 | CMD /gotalk/server --vocab /gotalk/vocabulary.txt --model /gotalk/frozen_model.pb --port 80 13 | -------------------------------------------------------------------------------- /vendor/vendor.json: -------------------------------------------------------------------------------- 1 | { 2 | "comment": "", 3 | "ignore": "test", 4 | "package": [ 5 | { 6 | "checksumSHA1": "T3AENIPUcQ8V/ylYC9PqsUb6hHM=", 7 | "path": "github.com/tensorflow/tensorflow/tensorflow/go", 8 | "revision": "468eb47a90e301201eb9ab4ff111c16ef535d66f", 9 | "revisionTime": "2017-04-10T01:52:43Z" 10 | } 11 | 
], 12 | "rootPath": "github.com/huichen/gotalk" 13 | } 14 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # gotalk 2 | 3 | 这是我在 [QCon 北京 2017 演讲](http://2017.qconbeijing.com/presentation/872) 的配套代码。 4 | 5 | 基于 Go 实现了一个深度学习看图说话服务,即机器学习的 serving(inference) 部分,Tensorflow 的 Python 训练代码在 [这里](https://github.com/tensorflow/models/tree/master/im2txt)。 6 | 7 | 代码实现了 tensorflow 模型导入、输入输出和 LSTM beam search 功能,并实现了 web 服务。 8 | 9 | 我打包的 Docker 镜像中已经包含了所有模型文件,可以直接运行,见 [docker hub](https://hub.docker.com/r/unmerged/gotalk/)。 10 | 11 | 模型文件由 [freeze_tf_model 工具](https://github.com/huichen/freeze_tf_model) 从 tf.train.Saver 保存的 checkpoint 生成,对于图说这个项目,训练代码保存的 checkpoint 无法直接用,需要通过 [这个代码](https://github.com/huichen/im2txt) 中的 save_model.py 工具转化成 inference 模型后再调用 freeze_tf_model。 12 | -------------------------------------------------------------------------------- /topn_sorter.go: -------------------------------------------------------------------------------- 1 | package gotalk 2 | 3 | // 得到 array 中 top n 的元素的 index 4 | func topNSort(array []float32, n int) []int { 5 | result := make([]int, 0, n) 6 | for i, number := range array { 7 | j := 0 8 | for ; j < len(result) && number <= array[result[j]]; j++ { 9 | } 10 | 11 | if j == len(result) { 12 | if j < n { 13 | result = append(result, i) 14 | } 15 | continue 16 | } 17 | 18 | if len(result) == n { 19 | for k := n - 1; k > j; k-- { 20 | result[k] = result[k-1] 21 | } 22 | } else { 23 | result = append(result, 0) 24 | for k := len(result) - 1; k > j; k-- { 25 | result[k] = result[k-1] 26 | } 27 | } 28 | result[j] = i 29 | } 30 | 31 | return result 32 | } 33 | -------------------------------------------------------------------------------- /vendor/github.com/tensorflow/tensorflow/tensorflow/go/BUILD: -------------------------------------------------------------------------------- 1 | # Description: 2 | # Go API 
for TensorFlow. 3 | 4 | package( 5 | default_visibility = ["//visibility:private"], 6 | ) 7 | 8 | licenses(["notice"]) # Apache 2.0 9 | 10 | exports_files(["LICENSE"]) 11 | 12 | sh_test( 13 | name = "test", 14 | size = "small", 15 | srcs = ["test.sh"], 16 | data = [ 17 | ":all_files", # Go sources 18 | "//tensorflow:libtensorflow.so", # C library 19 | "//tensorflow/c:headers", # C library header 20 | "//tensorflow/cc/saved_model:saved_model_half_plus_two", # Testdata for LoadSavedModel 21 | ], 22 | ) 23 | 24 | filegroup( 25 | name = "all_files", 26 | srcs = glob( 27 | ["**/*"], 28 | exclude = [ 29 | "**/METADATA", 30 | "**/OWNERS", 31 | ], 32 | ), 33 | visibility = ["//tensorflow:__subpackages__"], 34 | ) 35 | -------------------------------------------------------------------------------- /vendor/github.com/tensorflow/tensorflow/tensorflow/go/session.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | // TODO(ashankar): Remove this file when TensorFlow 1.1 is released. 18 | // See lib.go for details. 
19 | 20 | extern "C" { 21 | extern void tfDeletePRunHandle(const char* h); 22 | } 23 | 24 | void tfDeletePRunHandle(const char* h) { 25 | delete[] h; 26 | } 27 | -------------------------------------------------------------------------------- /vendor/github.com/tensorflow/tensorflow/tensorflow/go/version.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package tensorflow 18 | 19 | // #include 20 | // #include "tensorflow/c/c_api.h" 21 | import "C" 22 | 23 | // Version returns a string describing the version of the underlying TensorFlow 24 | // runtime. 
25 | func Version() string { return C.GoString(C.TF_Version()) } 26 | -------------------------------------------------------------------------------- /beam.go: -------------------------------------------------------------------------------- 1 | package gotalk 2 | 3 | // Beam is one partial caption hypothesis tracked by the beam search. 4 | type Beam struct { 5 | sentence []int64 // word ids generated so far 6 | logProb float64 // accumulated log-probability of sentence 7 | stateFeed []float32 // LSTM state to feed when extending this beam 8 | isClosed bool // set once the beam has ended and must not be extended 9 | } 10 | 11 | // TopNBeams keeps at most maxBeamSize beams sorted by logProb, highest first. 12 | type TopNBeams struct { 13 | beams []Beam 14 | size int // number of valid entries at the front of beams 15 | maxBeamSize int 16 | } 17 | 18 | // Init must be called before use: it allocates room for maxBeamSize beams 19 | // and empties the container. 20 | func (b *TopNBeams) Init(maxBeamSize int) { 21 | b.beams = make([]Beam, maxBeamSize) 22 | b.size = 0 23 | b.maxBeamSize = maxBeamSize 24 | } 25 | 26 | // Push inserts beam so that beams stays sorted by logProb, highest first; a 27 | // beam tying an existing logProb goes after it. When the container is full, 28 | // the lowest beam is evicted, or beam itself is dropped if it would rank last. 29 | func (b *TopNBeams) Push(beam Beam) { 30 | // Find the insertion point. 31 | iInsert := 0 32 | for iInsert < b.size && beam.logProb <= b.beams[iInsert].logProb { 33 | iInsert++ 34 | } 35 | 36 | // Nowhere to insert: the container is full and beam ranks last. 37 | if iInsert == b.maxBeamSize { 38 | return 39 | } 40 | 41 | // Grow unless full; when full, the last beam falls off during the shift. 42 | if b.size < b.maxBeamSize { 43 | b.size++ 44 | } 45 | copy(b.beams[iInsert+1:b.size], b.beams[iInsert:b.size-1]) 46 | b.beams[iInsert] = beam 47 | } 48 | -------------------------------------------------------------------------------- /vendor/github.com/tensorflow/tensorflow/tensorflow/go/doc.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | // Package tensorflow is a Go binding to TensorFlow. 18 | // 19 | // The API is subject to change and may break at any time. 20 | // 21 | // TensorFlow (www.tensorflow.org) is an open source software library for 22 | // numerical computation using data flow graphs. This package provides 23 | // functionality to build and execute such graphs and depends on 24 | // TensorFlow being available. For installation instructions see 25 | // https://www.tensorflow.org/code/tensorflow/go/README.md 26 | package tensorflow 27 | -------------------------------------------------------------------------------- /vendor/github.com/tensorflow/tensorflow/tensorflow/go/lib.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | 17 | package tensorflow 18 | 19 | // #cgo LDFLAGS: -ltensorflow 20 | // #cgo CFLAGS: -I${SRCDIR}/../../ 21 | // 22 | // // TODO(ashankar): Remove this after TensorFlow 1.1 has been released. 23 | // // Till then, the TensorFlow C API binary releases do not contain 24 | // // the TF_DeletePRunHandle symbol. We work around that by 25 | // // implementing the equivalent in session.cpp 26 | // extern void tfDeletePRunHandle(const char*); 27 | import "C" 28 | 29 | func deletePRunHandle(h *C.char) { 30 | C.tfDeletePRunHandle(h) 31 | } 32 | -------------------------------------------------------------------------------- /vocabulary.go: -------------------------------------------------------------------------------- 1 | package gotalk 2 | 3 | import ( 4 | "bufio" 5 | "os" 6 | "strings" 7 | ) 8 | 9 | // Special tokens of the im2txt vocabulary file. They must be non-empty and 10 | // distinct, otherwise StartId, EndId and UnkId collide. 11 | const ( 12 | startWord = "<S>" 13 | endWord = "</S>" 14 | unkWord = "<UNK>" 15 | ) 16 | 17 | // Vocabulary maps between words and the integer ids used for one-hot encoding. 18 | type Vocabulary struct { 19 | idToWord map[int64]string 20 | wordToId map[string]int64 21 | 22 | StartId int64 23 | EndId int64 24 | UnkId int64 25 | } 26 | 27 | // LoadFromFile reads one entry per line: the first space-separated field is the 28 | // word and its 0-based line number is the id. If the file has no unkWord entry, 29 | // one is appended with the next free id. StartId/EndId are 0 when the start/end 30 | // words are missing from the file. 31 | func (v *Vocabulary) LoadFromFile(filename string) error { 32 | v.idToWord = make(map[int64]string) 33 | v.wordToId = make(map[string]int64) 34 | file, err := os.Open(filename) 35 | if err != nil { 36 | return err 37 | } 38 | defer file.Close() 39 | 40 | scanner := bufio.NewScanner(file) 41 | id := int64(0) 42 | for scanner.Scan() { 43 | fields := strings.Split(scanner.Text(), " ") 44 | word := fields[0] 45 | v.idToWord[id] = word 46 | v.wordToId[word] = id 47 | id++ 48 | } 49 | if err := scanner.Err(); err != nil { 50 | return err 51 | } 52 | 53 | if _, found := v.wordToId[unkWord]; !found { 54 | v.wordToId[unkWord] = id 55 | v.idToWord[id] = unkWord 56 | } 57 | 58 | v.StartId = v.wordToId[startWord] 59 | v.EndId = v.wordToId[endWord] 60 | v.UnkId = v.wordToId[unkWord] 61 | 62 | return nil 63 | } 64 | 65 | // GetId returns the id of word, or UnkId if word is not in the vocabulary. 66 | func (v *Vocabulary) GetId(word string) int64 { 67 | if id, found := v.wordToId[word]; found { 68 | return id 69 | } 70 | return v.UnkId 71 | } 72 | 73 | // GetWord returns the word for id, or the unknown-word token if id is not in 74 | // the vocabulary. 75 | func (v *Vocabulary) GetWord(id int64) string { 76 | if word, found := v.idToWord[id]; found { 77 | return word 78 | } 79 | return unkWord 80 | } 81 | 
-------------------------------------------------------------------------------- /server/server.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/json" 5 | "flag" 6 | "io/ioutil" 7 | "log" 8 | "net/http" 9 | "time" 10 | 11 | "github.com/huichen/gotalk" 12 | tf "github.com/tensorflow/tensorflow/tensorflow/go" 13 | ) 14 | 15 | var ( 16 | modelFile = flag.String("model", "", "模型文件") 17 | vocabFile = flag.String("vocab", "", "词典文件") 18 | port = flag.String("port", "", "服务端口") 19 | 20 | session *tf.Session 21 | graph *tf.Graph 22 | vocab *gotalk.Vocabulary 23 | 24 | // fetchClient downloads the images named by the url query parameter; the 25 | // timeout keeps an unresponsive remote host from pinning a handler forever. 26 | fetchClient = &http.Client{Timeout: 30 * time.Second} 27 | ) 28 | 29 | // main loads the vocabulary and the frozen model named by the flags, then 30 | // serves caption requests on /im2txt. 31 | func main() { 32 | flag.Parse() 33 | 34 | // Load the vocabulary. 35 | vocab = &gotalk.Vocabulary{} 36 | err := vocab.LoadFromFile(*vocabFile) 37 | if err != nil { 38 | log.Fatalf("无法载入词典文件:%s", err) 39 | } 40 | 41 | // Read the frozen model file. 42 | model, err := ioutil.ReadFile(*modelFile) 43 | if err != nil { 44 | log.Fatalf("模型文件读取错误:%s", err) 45 | } 46 | 47 | // Build the graph from the serialized model. 48 | graph = tf.NewGraph() 49 | if err := graph.Import(model, ""); err != nil { 50 | log.Fatalf("Graph import 错误:%s", err) 51 | } 52 | 53 | // Create the TF session shared by all requests. 54 | session, err = tf.NewSession(graph, nil) 55 | if err != nil { 56 | log.Fatalf("TF session 创建错误:%s", err) 57 | } 58 | defer session.Close() 59 | 60 | http.HandleFunc("/im2txt", process) 61 | 62 | log.Fatal(http.ListenAndServe(":"+*port, nil)) 63 | } 64 | 65 | // process serves /im2txt?url=<image url>: it downloads the image, generates 66 | // captions and writes them as JSON. Failures are reported with an HTTP error 67 | // status instead of an empty 200 response. 68 | // 69 | // NOTE(review): any caller-supplied URL is fetched (SSRF risk); restrict the 70 | // allowed schemes/hosts before exposing this service publicly. 71 | func process(w http.ResponseWriter, r *http.Request) { 72 | url := r.URL.Query()["url"] 73 | if len(url) != 1 { 74 | http.Error(w, "exactly one url parameter is required", http.StatusBadRequest) 75 | return 76 | } 77 | log.Printf("%s", url[0]) 78 | 79 | // Fetch the image bytes; the body must be closed to release the connection. 80 | response, err := fetchClient.Get(url[0]) 81 | if err != nil { 82 | http.Error(w, "fetching image: "+err.Error(), http.StatusBadGateway) 83 | return 84 | } 85 | defer response.Body.Close() 86 | if response.StatusCode != http.StatusOK { 87 | http.Error(w, "fetching image: "+response.Status, http.StatusBadGateway) 88 | return 89 | } 90 | image, err := ioutil.ReadAll(response.Body) 91 | if err != nil { 92 | http.Error(w, "reading image: "+err.Error(), http.StatusBadGateway) 93 | return 94 | } 95 | tensor, err := tf.NewTensor(string(image)) 96 | if err != nil { 97 | http.Error(w, "building image tensor: "+err.Error(), http.StatusInternalServerError) 98 | return 99 | } 100 | 101 | // Generate the captions. 
102 | captions, err := gotalk.GenerateCaption(session, graph, vocab, tensor) 103 | if err != nil { 104 | log.Printf("GenerateCaption: %s", err) 105 | http.Error(w, "generating caption failed", http.StatusInternalServerError) 106 | return 107 | } 108 | log.Printf("%+v", captions) 109 | 110 | w.Header().Set("Content-Type", "application/json") 111 | if err := json.NewEncoder(w).Encode(captions); err != nil { 112 | log.Printf("encoding response: %s", err) 113 | } 114 | } 115 | 
-------------------------------------------------------------------------------- /tool/benchmark.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "io/ioutil" 6 | "log" 7 | "os" 8 | "runtime/pprof" 9 | "time" 10 | 11 | "github.com/huichen/gotalk" 12 | tf "github.com/tensorflow/tensorflow/tensorflow/go" 13 | ) 14 | 15 | var ( 16 | modelFile = flag.String("model", "", "模型文件") 17 | imageFile = flag.String("image", "", "图片文件") 18 | vocabFile = flag.String("vocab", "", "词典文件") 19 | cpuprofile = flag.String("cpuprofile", "", "CPU profile 文件") 20 | ) 21 | 22 | // main times a single GenerateCaption call on the image given by -image, 23 | // optionally writing a CPU profile to -cpuprofile. 24 | func main() { 25 | flag.Parse() 26 | 27 | if *cpuprofile != "" { 28 | f, err := os.Create(*cpuprofile) 29 | if err != nil { 30 | log.Fatal(err) 31 | } 32 | if err := pprof.StartCPUProfile(f); err != nil { 33 | log.Fatal(err) 34 | } 35 | defer pprof.StopCPUProfile() 36 | } 37 | 38 | // Load the vocabulary. 39 | vocab := gotalk.Vocabulary{} 40 | err := vocab.LoadFromFile(*vocabFile) 41 | if err != nil { 42 | log.Fatalf("无法载入词典文件:%s", err) 43 | } 44 | 45 | // Read the frozen model file. 46 | model, err := ioutil.ReadFile(*modelFile) 47 | if err != nil { 48 | log.Fatalf("模型文件读取错误:%s", err) 49 | } 50 | 51 | // Build the graph from the serialized model. 52 | graph := tf.NewGraph() 53 | if err := graph.Import(model, ""); err != nil { 54 | log.Fatalf("Graph import 错误:%s", err) 55 | } 56 | 57 | // Create the TF session. 58 | session, err := tf.NewSession(graph, nil) 59 | if err != nil { 60 | log.Fatalf("TF session 创建错误:%s", err) 61 | } 62 | defer session.Close() 63 | 64 | // Build the image tensor. 65 | image, err := makeTensorFromImage(*imageFile) 66 | if err != nil { 67 | log.Fatalf("无法创建图像 tensor:%s", err) 68 | } 69 | 70 | start := time.Now() 71 | _, err = gotalk.GenerateCaption(session, graph, &vocab, image) 72 | elapsed := time.Since(start) 73 | if err != nil { 74 | // Return rather than log.Fatal so the deferred profile/session cleanup runs. 75 | log.Printf("GenerateCaption 错误:%s", err) 76 | return 77 | } 78 | log.Printf("GenerateCaption 花费时间 %s", elapsed) 79 | } 80 | 81 | // makeTensorFromImage reads filename and wraps its raw bytes in a scalar 82 | // string tensor. 83 | func makeTensorFromImage(filename string) (*tf.Tensor, error) { 84 | data, err := ioutil.ReadFile(filename) 85 | if err != nil { 86 | return nil, err 87 | } 88 | return tf.NewTensor(string(data)) 89 | } 90 | 
-------------------------------------------------------------------------------- /vendor/github.com/tensorflow/tensorflow/tensorflow/go/status.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package tensorflow 18 | 19 | // #include "tensorflow/c/c_api.h" 20 | import "C" 21 | 22 | import "runtime" 23 | 24 | type code C.TF_Code 25 | 26 | // status holds error information returned by TensorFlow. We convert all 27 | // TF statuses to Go errors.
28 | type status struct { 29 | c *C.TF_Status 30 | } 31 | 32 | func newStatus() *status { 33 | s := &status{C.TF_NewStatus()} 34 | runtime.SetFinalizer(s, (*status).finalizer) 35 | return s 36 | } 37 | 38 | func (s *status) finalizer() { 39 | C.TF_DeleteStatus(s.c) 40 | } 41 | 42 | func (s *status) Code() code { 43 | return code(C.TF_GetCode(s.c)) 44 | } 45 | 46 | func (s *status) String() string { 47 | return C.GoString(C.TF_Message(s.c)) 48 | } 49 | 50 | // Err converts the status to a Go error and returns nil if the status is OK. 51 | func (s *status) Err() error { 52 | if s == nil || s.Code() == C.TF_OK { 53 | return nil 54 | } 55 | return (*statusError)(s) 56 | } 57 | 58 | // statusError is distinct from status because it fulfills the error interface. 59 | // status itself may have a TF_OK code and is not always considered an error. 60 | // 61 | // TODO(jhseu): Make public, rename to Error, and provide a way for users to 62 | // check status codes. 63 | type statusError status 64 | 65 | func (s *statusError) Error() string { 66 | return (*status)(s).String() 67 | } 68 | -------------------------------------------------------------------------------- /vendor/github.com/tensorflow/tensorflow/tensorflow/go/test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | # TensorFlow uses 'bazel' for builds and tests. 18 | # The TensorFlow Go API aims to be usable with the 'go' tool 19 | # (using 'go get' etc.) and thus without bazel. 20 | # 21 | # This script acts as a bridge between bazel and go so that: 22 | # bazel test :test 23 | # succeeds iff 24 | # go test github.com/tensorflow/tensorflow/tensorflow/go 25 | # succeeds. 26 | 27 | set -ex 28 | 29 | # Find the 'go' tool 30 | if [[ ! -x "go" && -z $(which go) ]] 31 | then 32 | if [[ -x "/usr/local/go/bin/go" ]] 33 | then 34 | export PATH="${PATH}:/usr/local/go/bin" 35 | else 36 | echo "Could not find the 'go' tool in PATH or /usr/local/go" 37 | exit 1 38 | fi 39 | fi 40 | 41 | # Setup a GOPATH that includes just the TensorFlow Go API. 42 | export GOPATH="${TEST_TMPDIR}/go" 43 | mkdir -p "${GOPATH}/src/github.com/tensorflow" 44 | ln -s "${PWD}" "${GOPATH}/src/github.com/tensorflow/tensorflow" 45 | 46 | # Ensure that the TensorFlow C library is accessible to the 47 | # linker at build and run time. 
48 | export LIBRARY_PATH="${PWD}/tensorflow" 49 | OS=$(uname -s) 50 | if [[ "${OS}" = "Linux" ]] 51 | then 52 | if [[ -z "${LD_LIBRARY_PATH}" ]] 53 | then 54 | export LD_LIBRARY_PATH="${PWD}/tensorflow" 55 | else 56 | export LD_LIBRARY_PATH="${PWD}/tensorflow:${LD_LIBRARY_PATH}" 57 | fi 58 | elif [[ "${OS}" = "Darwin" ]] 59 | then 60 | if [[ -z "${DYLD_LIBRARY_PATH}" ]] 61 | then 62 | export DYLD_LIBRARY_PATH="${PWD}/tensorflow" 63 | else 64 | export DYLD_LIBRARY_PATH="${PWD}/tensorflow:${DYLD_LIBRARY_PATH}" 65 | fi 66 | fi 67 | 68 | # Document the Go version and run tests 69 | echo "Go version: $(go version)" 70 | go test \ 71 | github.com/tensorflow/tensorflow/tensorflow/go \ 72 | github.com/tensorflow/tensorflow/tensorflow/go/op 73 | -------------------------------------------------------------------------------- /vendor/github.com/tensorflow/tensorflow/tensorflow/go/saved_model.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package tensorflow 18 | 19 | import ( 20 | "runtime" 21 | "unsafe" 22 | ) 23 | 24 | // #include <stdlib.h> 25 | // #include "tensorflow/c/c_api.h" 26 | import "C" 27 | 28 | // SavedModel represents the contents of loaded SavedModel. 29 | // TODO(jhseu): Add and document metagraphdef when we pregenerate protobufs. 
30 | type SavedModel struct { 31 | Session *Session 32 | Graph *Graph 33 | } 34 | 35 | // LoadSavedModel creates a new SavedModel from a model previously 36 | // exported to a directory on disk. 37 | // 38 | // Exported models contain a set of graphs and, optionally, variable values. 39 | // Tags in the model identify a single graph. LoadSavedModel initializes a 40 | // session with the identified graph and with variables initialized from the 41 | // checkpoints on disk. 42 | // 43 | // The tensorflow package currently does not have the ability to export a model 44 | // to a directory from Go. This function thus currently targets loading models 45 | // exported in other languages, such as using tf.saved_model.builder in Python. 46 | // See: 47 | // https://www.tensorflow.org/code/tensorflow/python/saved_model/ 48 | func LoadSavedModel(exportDir string, tags []string, options *SessionOptions) (*SavedModel, error) { 49 | status := newStatus() 50 | cOpt, doneOpt, err := options.c() 51 | defer doneOpt() 52 | if err != nil { 53 | return nil, err 54 | } 55 | cExportDir := C.CString(exportDir) 56 | cTags := make([]*C.char, len(tags)) 57 | for i := range tags { 58 | cTags[i] = C.CString(tags[i]) 59 | } 60 | graph := NewGraph() 61 | // &cTags[0] panics on an empty slice, so pass NULL when there are no tags. 62 | var cTagsPtr **C.char 63 | if len(cTags) > 0 { 64 | cTagsPtr = (**C.char)(unsafe.Pointer(&cTags[0])) 65 | } 66 | // TODO(jhseu): Add support for run_options and meta_graph_def. 
67 | cSess := C.TF_LoadSessionFromSavedModel(cOpt, nil, cExportDir, cTagsPtr, C.int(len(cTags)), graph.c, nil, status.c) 68 | for i := range cTags { 69 | C.free(unsafe.Pointer(cTags[i])) 70 | } 71 | C.free(unsafe.Pointer(cExportDir)) 72 | 73 | if err := status.Err(); err != nil { 74 | return nil, err 75 | } 76 | s := &Session{c: cSess} 77 | runtime.SetFinalizer(s, func(s *Session) { s.Close() }) 78 | return &SavedModel{Session: s, Graph: graph}, nil 79 | } 80 | -------------------------------------------------------------------------------- /vendor/github.com/tensorflow/tensorflow/tensorflow/go/shape.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package tensorflow 18 | 19 | import ( 20 | "fmt" 21 | "strings" 22 | ) 23 | 24 | // Shape represents the (possibly partially known) shape of a tensor that will 25 | // be produced by an operation. 26 | // 27 | // The zero-value of a Shape represents a shape with an unknown number of 28 | // dimensions. 29 | type Shape struct { 30 | dims []int64 31 | } 32 | 33 | // ScalarShape returns a Shape representing a scalar. 34 | func ScalarShape() Shape { 35 | return Shape{dims: make([]int64, 0)} 36 | } 37 | 38 | // MakeShape returns a Shape with the provided size of each dimension. 
39 | // 40 | // A value of -1 implies that the size of the corresponding dimension is not 41 | // known. 42 | func MakeShape(shape ...int64) Shape { 43 | cpy := make([]int64, len(shape)) 44 | copy(cpy, shape) 45 | return Shape{dims: cpy} 46 | } 47 | 48 | // NumDimensions returns the number of dimensions represented by s, or -1 if 49 | // unknown. 50 | func (s Shape) NumDimensions() int { 51 | if s.dims == nil { 52 | return -1 53 | } 54 | return len(s.dims) 55 | } 56 | 57 | // Size returns the size of the dim-th dimension of the shape, or -1 if it 58 | // is unknown. 59 | // 60 | // REQUIRES: 0 <= dim < s.NumDimensions() 61 | func (s Shape) Size(dim int) int64 { 62 | if dim < 0 || dim >= s.NumDimensions() { 63 | return -1 64 | } 65 | return s.dims[dim] 66 | } 67 | 68 | // IsFullySpecified returns true iff the size of all the dimensions of s are 69 | // known. 70 | func (s Shape) IsFullySpecified() bool { 71 | if s.dims == nil { 72 | return false 73 | } 74 | for _, size := range s.dims { 75 | if size < 0 { 76 | return false 77 | } 78 | } 79 | return true 80 | } 81 | 82 | // ToSlice returns the (possibly partially known) shape represented by s as a 83 | // slice, or an error if the number of dimensions is not known. 84 | func (s Shape) ToSlice() ([]int64, error) { 85 | if s.dims == nil { 86 | return nil, fmt.Errorf("cannot create a slice for a Shape with an unknown number of dimensions") 87 | } 88 | cpy := make([]int64, len(s.dims)) 89 | copy(cpy, s.dims) 90 | return cpy, nil 91 | } 92 | 93 | func (s Shape) String() string { 94 | if s.dims == nil { 95 | return "?" 
96 | } 97 | ret := fmt.Sprint(s.dims) 98 | for _, size := range s.dims { 99 | if size < 0 { 100 | ret = strings.Replace(ret, fmt.Sprint(size), "?", 1) 101 | } 102 | } 103 | return strings.Replace(ret, " ", ", ", -1) 104 | } 105 | 
-------------------------------------------------------------------------------- /beam_search.go: -------------------------------------------------------------------------------- 1 | package gotalk 2 | 3 | import ( 4 | "flag" 5 | "math" 6 | "strings" 7 | 8 | tf "github.com/tensorflow/tensorflow/tensorflow/go" 9 | ) 10 | 11 | var ( 12 | sentenceSize = flag.Int("sentence_size", 20, "最大句子长度") 13 | beamSize = flag.Int("beam_size", 3, "beam search 搜索宽度") 14 | ) 15 | 16 | // Captions is the JSON payload returned to callers: candidate captions with 17 | // their probabilities. 18 | type Captions struct { 19 | Results []CaptionResult `json:"results"` 20 | } 21 | 22 | // CaptionResult is one generated caption and the product of its word probabilities. 23 | type CaptionResult struct { 24 | Probability float32 `json:"probability"` 25 | Sentence string `json:"sentence"` 26 | } 27 | 28 | // GenerateCaption runs an LSTM beam search over the im2txt graph for image 29 | // (fed to "image_feed") and returns up to -beam_size captions, best first, 30 | // after at most -sentence_size decoding steps. 31 | // 32 | // Per the original author it may be called from multiple goroutines: all search 33 | // state is local, and session, graph and vocab are only read here. 34 | func GenerateCaption(session *tf.Session, graph *tf.Graph, vocab *Vocabulary, image *tf.Tensor) (Captions, error) { 35 | caps := Captions{} 36 | 37 | // Compute the initial LSTM state from the image. 38 | initialState, err := session.Run( 39 | map[tf.Output]*tf.Tensor{ 40 | graph.Operation("image_feed").Output(0): image, 41 | }, 42 | []tf.Output{ 43 | graph.Operation("lstm/initial_state").Output(0), 44 | }, 45 | nil) 46 | if err != nil { 47 | return caps, err 48 | } 49 | 50 | // Seed the search with a single beam holding only the start word. 51 | tnb := TopNBeams{} 52 | tnb.Init(*beamSize) 53 | tnb.Push(Beam{ 54 | logProb: 0, 55 | sentence: []int64{vocab.StartId}, 56 | stateFeed: initialState[0].Value().([][]float32)[0], 57 | }) 58 | 59 | for iSentence := 0; iSentence < *sentenceSize; iSentence++ { 60 | // Collect the open beams. Row i of the LSTM batch below belongs to open[i], 61 | // so the three slices are built in lockstep. 62 | open := make([]Beam, 0, tnb.size) 63 | stateSeq := make([][]float32, 0, tnb.size) 64 | inputSeq := make([]int64, 0, tnb.size) 65 | for iBeam := 0; iBeam < tnb.size; iBeam++ { 66 | b := tnb.beams[iBeam] 67 | // Closed beams are not extended, which keeps the batch small. 68 | if b.isClosed { 69 | continue 70 | } 71 | open = append(open, b) 72 | stateSeq = append(stateSeq, b.stateFeed) 73 | inputSeq = append(inputSeq, b.sentence[len(b.sentence)-1]) 74 | } 75 | 76 | // Every beam has finished. 77 | if len(open) == 0 { 78 | break 79 | } 80 | 81 | // Build the graph input tensors. 82 | stateFeed, err := tf.NewTensor(stateSeq) 83 | if err != nil { 84 | return caps, err 85 | } 86 | inputFeed, err := tf.NewTensor(inputSeq) 87 | if err != nil { 88 | return caps, err 89 | } 90 | 91 | // Run one LSTM step; the batch size is len(open). 
92 | output, err := session.Run( 93 | map[tf.Output]*tf.Tensor{ 94 | graph.Operation("input_feed").Output(0): inputFeed, 95 | graph.Operation("lstm/state_feed").Output(0): stateFeed, 96 | }, 97 | []tf.Output{ 98 | graph.Operation("softmax").Output(0), 99 | graph.Operation("lstm/state").Output(0), 100 | }, 101 | nil) 102 | if err != nil { 103 | return caps, err 104 | } 105 | softmax := output[0].Value().([][]float32) // per row: probability of each vocabulary id 106 | lstmState := output[1].Value().([][]float32) // per row: LSTM state for the next step 107 | 108 | // newTnb collects the candidates for the next step. 109 | newTnb := TopNBeams{} 110 | newTnb.Init(*beamSize) 111 | 112 | // Closed beams compete with the new candidates unchanged. 113 | for iBatch := 0; iBatch < tnb.size; iBatch++ { 114 | if tnb.beams[iBatch].isClosed { 115 | newTnb.Push(tnb.beams[iBatch]) 116 | } 117 | } 118 | 119 | // Extend each open beam with its beam_size most probable next words. 120 | for iBatch, parent := range open { 121 | for _, idx := range topNSort(softmax[iBatch], *beamSize) { 122 | id := int64(idx) 123 | 124 | se := make([]int64, len(parent.sentence), len(parent.sentence)+1) 125 | copy(se, parent.sentence) 126 | beam := Beam{ 127 | logProb: parent.logProb + math.Log(float64(softmax[iBatch][id])), 128 | sentence: append(se, id), 129 | stateFeed: lstmState[iBatch], 130 | } 131 | 132 | // The beam ends at the end word or at id 3. 133 | // NOTE(review): 3 is presumably "." in the im2txt vocabulary (the output 134 | // filter below also drops "."); confirm and replace the magic number. 135 | if id == vocab.EndId || id == 3 { 136 | beam.isClosed = true 137 | } 138 | newTnb.Push(beam) 139 | } 140 | } 141 | tnb = newTnb 142 | } 143 | 144 | // Convert the surviving beams into the JSON result. 145 | for iBatch := 0; iBatch < tnb.size; iBatch++ { 146 | sentence := tnb.beams[iBatch].sentence 147 | words := make([]string, 0, len(sentence)) 148 | for _, id := range sentence { 149 | word := vocab.GetWord(id) 150 | // Drop the special tokens and the period. 151 | if id != vocab.StartId && id != vocab.EndId && id != vocab.UnkId && word != "." { 152 | words = append(words, word) 153 | } 154 | } 155 | caps.Results = append(caps.Results, CaptionResult{ 156 | Probability: float32(math.Exp(tnb.beams[iBatch].logProb)), 157 | Sentence: strings.TrimSpace(strings.Join(words, " ")), 158 | }) 159 | } 160 | return caps, nil 161 | } 162 | 
-------------------------------------------------------------------------------- /vendor/github.com/tensorflow/tensorflow/tensorflow/go/operation.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package tensorflow 18 | 19 | // #include <stdlib.h> 20 | // #include "tensorflow/c/c_api.h" 21 | import "C" 22 | 23 | // Operation that has been added to the graph. 
26 | type Operation struct { 27 | c *C.TF_Operation 28 | // A reference to the Graph to prevent it from 29 | // being GCed while the Operation is still alive. 30 | g *Graph 31 | } 32 | 33 | // Name returns the name of the operation. 34 | func (op *Operation) Name() string { 35 | return C.GoString(C.TF_OperationName(op.c)) 36 | } 37 | 38 | // Type returns the name of the operator used by this operation. 39 | func (op *Operation) Type() string { 40 | return C.GoString(C.TF_OperationOpType(op.c)) 41 | } 42 | 43 | // NumOutputs returns the number of outputs of op. 44 | func (op *Operation) NumOutputs() int { 45 | return int(C.TF_OperationNumOutputs(op.c)) 46 | } 47 | 48 | // OutputListSize returns the size of the list of Outputs that is produced by a 49 | // named output of op. 50 | // 51 | // An Operation has multiple named outputs, each of which produces either 52 | // a single tensor or a list of tensors. This method returns the size of 53 | // the list of tensors for a specific output of the operation, identified 54 | // by its name. 55 | func (op *Operation) OutputListSize(output string) (int, error) { 56 | cname := C.CString(output) 57 | defer C.free(unsafe.Pointer(cname)) 58 | status := newStatus() 59 | n := C.TF_OperationOutputListLength(op.c, cname, status.c) 60 | return int(n), status.Err() 61 | } 62 | 63 | // Output returns the i-th output of op. 64 | func (op *Operation) Output(i int) Output { 65 | return Output{op, i} 66 | } 67 | 68 | // Output represents one of the outputs of an operation in the graph. Has a 69 | // DataType (and eventually a Shape). May be passed as an input argument to a 70 | // function for adding operations to a graph, or to a Session's Run() method to 71 | // fetch that output as a tensor. 72 | type Output struct { 73 | // Op is the Operation that produces this Output. 74 | Op *Operation 75 | 76 | // Index specifies the index of the output within the Operation. 
77 | Index int 78 | } 79 | 80 | // DataType returns the type of elements in the tensor produced by p. 81 | func (p Output) DataType() DataType { 82 | return DataType(C.TF_OperationOutputType(p.c())) 83 | } 84 | 85 | // Shape returns the (possibly incomplete) shape of the tensor produced p. 86 | func (p Output) Shape() Shape { 87 | status := newStatus() 88 | port := p.c() 89 | ndims := C.TF_GraphGetTensorNumDims(p.Op.g.c, port, status.c) 90 | if err := status.Err(); err != nil { 91 | // This should not be possible since an error only occurs if 92 | // the operation does not belong to the graph. It should not 93 | // be possible to construct such an Operation object. 94 | return Shape{} 95 | } 96 | if ndims < 0 { 97 | return Shape{} 98 | } 99 | if ndims == 0 { 100 | return ScalarShape() 101 | } 102 | dims := make([]C.int64_t, ndims) 103 | C.TF_GraphGetTensorShape(p.Op.g.c, port, &dims[0], ndims, status.c) 104 | if err := status.Err(); err != nil { 105 | // Same as above, should not be possible. 106 | return Shape{} 107 | } 108 | ret := Shape{dims: make([]int64, ndims)} 109 | for i := 0; i < int(ndims); i++ { 110 | ret.dims[i] = int64(dims[i]) 111 | } 112 | return ret 113 | } 114 | 115 | func (p Output) c() C.TF_Output { 116 | return C.TF_Output{oper: p.Op.c, index: C.int(p.Index)} 117 | } 118 | 119 | func (p Output) canBeAnInput() {} 120 | 121 | // Input is the interface for specifying inputs to an operation being added to 122 | // a Graph. 123 | // 124 | // Operations can have multiple inputs, each of which could be either a tensor 125 | // produced by another operation (an Output object), or a list of tensors 126 | // produced by other operations (an OutputList). Thus, this interface is 127 | // implemented by both Output and OutputList. 128 | // 129 | // See OpSpec.Input for more information. 130 | type Input interface { 131 | // Unexported to preclude implementations outside this package. 
132 | canBeAnInput() 133 | } 134 | 135 | // OutputList represents a list of Outputs that can be provided as input to 136 | // another operation. 137 | type OutputList []Output 138 | 139 | func (l OutputList) canBeAnInput() {} 140 | -------------------------------------------------------------------------------- /vendor/github.com/tensorflow/tensorflow/tensorflow/go/README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow in Go 2 | 3 | Construct and execute TensorFlow graphs in Go. 4 | 5 | [![GoDoc](https://godoc.org/github.com/tensorflow/tensorflow/tensorflow/go?status.svg)](https://godoc.org/github.com/tensorflow/tensorflow/tensorflow/go) 6 | 7 | > *WARNING*: The API defined in this package is not stable and can change 8 | > without notice. The same goes for the awkward package path 9 | > (`github.com/tensorflow/tensorflow/tensorflow/go`). 10 | 11 | ## Quickstart 12 | 13 | 1. Download and extract the TensorFlow C library, preferably into `/usr/local`. 14 | GPU-enabled versions require CUDA 8.0 and cuDNN 5.1. For other versions, the 15 | TensorFlow C library will have to be built from source (see below). 
16 | 17 | - Linux: 18 | [CPU-only](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-1.0.0.tar.gz), 19 | [GPU-enabled](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-linux-x86_64-1.0.0.tar.gz) 20 | - OS X 21 | [CPU-only](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-darwin-x86_64-1.0.0.tar.gz), 22 | [GPU-enabled](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-darwin-x86_64-1.0.0.tar.gz) 23 | 24 | The following shell snippet downloads and extracts into `/usr/local`: 25 | 26 | ```sh 27 | TF_TYPE="cpu" # Set to "gpu" for GPU support 28 | curl -L \ 29 | "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-$(go env GOOS)-x86_64-1.0.0.tar.gz" | 30 | sudo tar -C /usr/local -xz 31 | ``` 32 | 33 | 2. `go get` this package (and run tests): 34 | 35 | ```sh 36 | go get github.com/tensorflow/tensorflow/tensorflow/go 37 | go test github.com/tensorflow/tensorflow/tensorflow/go 38 | ``` 39 | 40 | 3. Done! 41 | 42 | ### Installing into locations other than `/usr/local` 43 | 44 | The TensorFlow C library (`libtensorflow.so`) needs to be available at build 45 | time (e.g., `go build`) and run time (`go test` or executing binaries). If the 46 | library has not been extracted into `/usr/local`, then it needs to be made 47 | available through the `LIBRARY_PATH` environment variable at build time and the 48 | `LD_LIBRARY_PATH` environment variable (`DYLD_LIBRARY_PATH` on OS X) at run 49 | time. 
50 | 51 | For example, if the TensorFlow C library was extracted into `/dir`, then: 52 | 53 | ```sh 54 | export LIBRARY_PATH=/dir/lib 55 | export LD_LIBRARY_PATH=/dir/lib # For Linux 56 | export DYLD_LIBRARY_PATH=/dir/lib # For OS X 57 | ``` 58 | 59 | ## Building the TensorFlow C library from source 60 | 61 | If the "Quickstart" instructions above do not work (perhaps the release archives 62 | are not available for your operating system or architecture, or you're using a 63 | different version of CUDA/cuDNN), then the TensorFlow C library must be built 64 | from source. 65 | 66 | ### Prerequisites 67 | 68 | - [bazel](https://www.bazel.build/versions/master/docs/install.html) 69 | - Environment to build TensorFlow from source code 70 | ([Linux](https://www.tensorflow.org/versions/master/get_started/os_setup.html#prepare-environment-for-linux) 71 | or [OS 72 | X](https://www.tensorflow.org/versions/master/get_started/os_setup.html#prepare-environment-for-mac-os-x)). 73 | If you don't need GPU support, then try the following: on Linux, run 74 | `sudo apt-get install python swig python-numpy`; on OS X with Homebrew, 75 | run `brew install swig`. 76 | 77 | ### Build 78 | 79 | 1. Download the source code: 80 | 81 | ```sh 82 | go get -d github.com/tensorflow/tensorflow/tensorflow/go 83 | ``` 84 | 85 | 2. Build the TensorFlow C library: 86 | 87 | ```sh 88 | cd ${GOPATH}/src/github.com/tensorflow/tensorflow 89 | ./configure 90 | bazel build --config opt //tensorflow:libtensorflow.so 91 | ``` 92 | 93 | This can take a while (tens of minutes, more if also building for GPU). 94 | 95 | 3. Make `libtensorflow.so` available to the linker. This can be done by either: 96 | 97 | a. Copying it to a system location, e.g., 98 | 99 | ```sh 100 | sudo cp ${GOPATH}/src/github.com/tensorflow/tensorflow/bazel-bin/tensorflow/libtensorflow.so /usr/local/lib 101 | ``` 102 | 103 | OR 104 | 105 | b.
Setting environment variables: 106 | 107 | ```sh 108 | export LIBRARY_PATH=${GOPATH}/src/github.com/tensorflow/tensorflow/bazel-bin/tensorflow 109 | # Linux 110 | export LD_LIBRARY_PATH=${GOPATH}/src/github.com/tensorflow/tensorflow/bazel-bin/tensorflow 111 | # OS X 112 | export DYLD_LIBRARY_PATH=${GOPATH}/src/github.com/tensorflow/tensorflow/bazel-bin/tensorflow 113 | ``` 114 | 115 | 4. Build and test: 116 | 117 | ```sh 118 | go test github.com/tensorflow/tensorflow/tensorflow/go 119 | ``` 120 | 121 | ### Generate wrapper functions for ops 122 | 123 | Go functions corresponding to TensorFlow operations are generated in `op/wrappers.go`. To regenerate them: 124 | 125 | Prerequisites: 126 | - [Protocol buffer compiler (protoc) 3.x](https://github.com/google/protobuf/releases/) 127 | - The TensorFlow repository under GOPATH 128 | 129 | ```sh 130 | go generate github.com/tensorflow/tensorflow/tensorflow/go/op 131 | ``` 132 | 133 | ## Support 134 | 135 | Use [Stack Overflow](http://stackoverflow.com/questions/tagged/tensorflow) and/or 136 | [GitHub issues](https://github.com/tensorflow/tensorflow/issues). 137 | 138 | ## Contributions 139 | 140 | Contributions are welcome. If you are making any significant changes, it is probably best to 141 | discuss them on a [GitHub issue](https://github.com/tensorflow/tensorflow/issues) 142 | before investing too much time. GitHub pull requests are used for contributions. 143 | -------------------------------------------------------------------------------- /vendor/github.com/tensorflow/tensorflow/tensorflow/go/graph.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package tensorflow 18 | 19 | // #include "tensorflow/c/c_api.h" 20 | // 21 | // #include 22 | // #include 23 | import "C" 24 | 25 | import ( 26 | "fmt" 27 | "io" 28 | "runtime" 29 | "unsafe" 30 | ) 31 | 32 | // Graph represents a computation graph. Graphs may be shared between sessions. 33 | type Graph struct { 34 | c *C.TF_Graph 35 | } 36 | 37 | // NewGraph returns a new Graph. 38 | func NewGraph() *Graph { 39 | g := &Graph{C.TF_NewGraph()} 40 | runtime.SetFinalizer(g, (*Graph).finalizer) 41 | return g 42 | } 43 | 44 | func (g *Graph) finalizer() { 45 | C.TF_DeleteGraph(g.c) 46 | } 47 | 48 | // WriteTo writes out a serialized representation of g to w. 49 | // 50 | // Implements the io.WriterTo interface. 51 | func (g *Graph) WriteTo(w io.Writer) (int64, error) { 52 | buf := C.TF_NewBuffer() 53 | defer C.TF_DeleteBuffer(buf) 54 | status := newStatus() 55 | C.TF_GraphToGraphDef(g.c, buf, status.c) 56 | if err := status.Err(); err != nil { 57 | return 0, err 58 | } 59 | if buf.length > (1 << 30) { 60 | // For very large graphs, the writes can be chunked. 61 | // Punt on that for now. 62 | return 0, fmt.Errorf("Graph is too large to write out, Graph.WriteTo needs to be updated") 63 | } 64 | // A []byte slice backed by C memory. 
65 | // See: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices 66 | length := int(buf.length) 67 | slice := (*[1 << 30]byte)(unsafe.Pointer(buf.data))[:length:length] 68 | n, err := w.Write(slice) 69 | return int64(n), err 70 | } 71 | 72 | // Import imports the nodes and edges from a serialized representation of 73 | // another Graph into g. 74 | // 75 | // Names of imported nodes will be prefixed with prefix. 76 | func (g *Graph) Import(def []byte, prefix string) error { 77 | cprefix := C.CString(prefix) 78 | defer C.free(unsafe.Pointer(cprefix)) 79 | 80 | opts := C.TF_NewImportGraphDefOptions() 81 | defer C.TF_DeleteImportGraphDefOptions(opts) 82 | C.TF_ImportGraphDefOptionsSetPrefix(opts, cprefix) 83 | 84 | buf := C.TF_NewBuffer() 85 | defer C.TF_DeleteBuffer(buf) 86 | // Would have preferred to use C.CBytes, but that does not play well 87 | // with "go vet" till https://github.com/golang/go/issues/17201 is 88 | // resolved. 89 | buf.length = C.size_t(len(def)) 90 | buf.data = C.malloc(buf.length) 91 | if buf.data == nil { 92 | return fmt.Errorf("unable to allocate memory") 93 | } 94 | defer C.free(buf.data) 95 | C.memcpy(buf.data, unsafe.Pointer(&def[0]), buf.length) 96 | 97 | status := newStatus() 98 | C.TF_GraphImportGraphDef(g.c, buf, opts, status.c) 99 | if err := status.Err(); err != nil { 100 | return err 101 | } 102 | return nil 103 | } 104 | 105 | // Operation returns the Operation named name in the Graph, or nil if no such 106 | // operation is present. 107 | func (g *Graph) Operation(name string) *Operation { 108 | cname := C.CString(name) 109 | defer C.free(unsafe.Pointer(cname)) 110 | cop := C.TF_GraphOperationByName(g.c, cname) 111 | if cop == nil { 112 | return nil 113 | } 114 | return &Operation{cop, g} 115 | } 116 | 117 | // OpSpec is the specification of an Operation to be added to a Graph 118 | // (using Graph.AddOperation). 119 | type OpSpec struct { 120 | // Type of the operation (e.g., "Add", "MatMul"). 
121 | Type string 122 | 123 | // Name by which the added operation will be referred to in the Graph. 124 | // If omitted, defaults to Type. 125 | Name string 126 | 127 | // Inputs to this operation, which in turn must be outputs 128 | // of other operations already added to the Graph. 129 | // 130 | // An operation may have multiple inputs with individual inputs being 131 | // either a single tensor produced by another operation or a list of 132 | // tensors produced by multiple operations. For example, the "Concat" 133 | // operation takes two inputs: (1) the dimension along which to 134 | // concatenate and (2) a list of tensors to concatenate. Thus, for 135 | // Concat, len(Input) must be 2, with the first element being an Output 136 | // and the second being an OutputList. 137 | Input []Input 138 | 139 | // Map from attribute name to its value that will be attached to this 140 | // operation. 141 | Attrs map[string]interface{} 142 | 143 | // Other possible fields: Device, ColocateWith, ControlInputs. 144 | } 145 | 146 | // AddOperation adds an operation to g. 
147 | func (g *Graph) AddOperation(args OpSpec) (*Operation, error) { 148 | if args.Name == "" { 149 | args.Name = args.Type 150 | } 151 | cname := C.CString(args.Name) 152 | ctype := C.CString(args.Type) 153 | cdesc := C.TF_NewOperation(g.c, ctype, cname) 154 | C.free(unsafe.Pointer(cname)) 155 | C.free(unsafe.Pointer(ctype)) 156 | 157 | for _, in := range args.Input { 158 | switch in := in.(type) { 159 | case Output: 160 | C.TF_AddInput(cdesc, in.c()) 161 | case OutputList: 162 | size := len(in) 163 | list := make([]C.TF_Output, size) 164 | for i, v := range in { 165 | list[i] = v.c() 166 | } 167 | if size > 0 { 168 | C.TF_AddInputList(cdesc, &list[0], C.int(size)) 169 | } else { 170 | C.TF_AddInputList(cdesc, nil, 0) 171 | } 172 | } 173 | } 174 | status := newStatus() 175 | for name, value := range args.Attrs { 176 | if err := setAttr(cdesc, status, name, value); err != nil { 177 | // Memory leak here as the TF_OperationDescription 178 | // object will not be cleaned up. At the time of this 179 | // writing, this was next to impossible since it 180 | // required value to be a string tensor with 181 | // incorrectly encoded strings. Given this rarity, live 182 | // with the memory leak. If it becomes a real problem, 183 | // consider adding a TF_DeleteOperationDescription 184 | // function to the C API. 
185 | return nil, fmt.Errorf("%v (memory will be leaked)", err) 186 | } 187 | } 188 | op := &Operation{ 189 | c: C.TF_FinishOperation(cdesc, status.c), 190 | g: g, 191 | } 192 | return op, status.Err() 193 | } 194 | 195 | func setAttr(cdesc *C.TF_OperationDescription, status *status, name string, value interface{}) error { 196 | cAttrName := C.CString(name) 197 | defer C.free(unsafe.Pointer(cAttrName)) 198 | switch value := value.(type) { 199 | case string: 200 | cstr := C.CString(value) 201 | C.TF_SetAttrString(cdesc, cAttrName, unsafe.Pointer(cstr), C.size_t(len(value))) 202 | C.free(unsafe.Pointer(cstr)) 203 | case []string: 204 | size := len(value) 205 | list := make([]unsafe.Pointer, size) 206 | lens := make([]C.size_t, size) 207 | for i, s := range value { 208 | list[i] = unsafe.Pointer(C.CString(s)) 209 | lens[i] = C.size_t(len(s)) 210 | } 211 | if size > 0 { 212 | C.TF_SetAttrStringList(cdesc, cAttrName, &list[0], &lens[0], C.int(size)) 213 | } else { 214 | C.TF_SetAttrStringList(cdesc, cAttrName, nil, nil, 0) 215 | } 216 | for _, s := range list { 217 | C.free(s) 218 | } 219 | case int64: 220 | C.TF_SetAttrInt(cdesc, cAttrName, C.int64_t(value)) 221 | case []int64: 222 | size := len(value) 223 | list := make([]C.int64_t, size) 224 | for i, v := range value { 225 | list[i] = C.int64_t(v) 226 | } 227 | if size > 0 { 228 | C.TF_SetAttrIntList(cdesc, cAttrName, &list[0], C.int(size)) 229 | } else { 230 | C.TF_SetAttrIntList(cdesc, cAttrName, nil, 0) 231 | } 232 | case float32: 233 | C.TF_SetAttrFloat(cdesc, cAttrName, C.float(value)) 234 | case []float32: 235 | size := len(value) 236 | list := make([]C.float, size) 237 | for i, v := range value { 238 | list[i] = C.float(v) 239 | } 240 | if size > 0 { 241 | C.TF_SetAttrFloatList(cdesc, cAttrName, &list[0], C.int(size)) 242 | } else { 243 | C.TF_SetAttrFloatList(cdesc, cAttrName, nil, 0) 244 | } 245 | case bool: 246 | v := C.uchar(0) 247 | if value { 248 | v = 1 249 | } 250 | C.TF_SetAttrBool(cdesc, cAttrName, 
v) 251 | case []bool: 252 | size := len(value) 253 | list := make([]C.uchar, size) 254 | for i, v := range value { 255 | if v { 256 | list[i] = 1 257 | } 258 | } 259 | if size > 0 { 260 | C.TF_SetAttrBoolList(cdesc, cAttrName, &list[0], C.int(size)) 261 | } else { 262 | C.TF_SetAttrBoolList(cdesc, cAttrName, nil, 0) 263 | } 264 | case DataType: 265 | C.TF_SetAttrType(cdesc, cAttrName, C.TF_DataType(value)) 266 | case []DataType: 267 | var list *C.TF_DataType 268 | if len(value) > 0 { 269 | list = (*C.TF_DataType)(&value[0]) 270 | } 271 | C.TF_SetAttrTypeList(cdesc, cAttrName, list, C.int(len(value))) 272 | case *Tensor: 273 | C.TF_SetAttrTensor(cdesc, cAttrName, value.c, status.c) 274 | if err := status.Err(); err != nil { 275 | return fmt.Errorf("bad value for attribute %q: %v", name, err) 276 | } 277 | case []*Tensor: 278 | size := len(value) 279 | list := make([]*C.TF_Tensor, size) 280 | for i, v := range value { 281 | list[i] = v.c 282 | } 283 | var plist **C.TF_Tensor 284 | if size > 0 { 285 | plist = &list[0] 286 | } 287 | C.TF_SetAttrTensorList(cdesc, cAttrName, plist, C.int(size), status.c) 288 | if err := status.Err(); err != nil { 289 | return fmt.Errorf("bad value for attribute %q: %v", name, err) 290 | } 291 | case Shape: 292 | ndims, dims := cshape(value) 293 | var dimsp *C.int64_t 294 | if ndims > 0 { 295 | dimsp = &dims[0] 296 | } 297 | C.TF_SetAttrShape(cdesc, cAttrName, dimsp, ndims) 298 | case []Shape: 299 | ndims := make([]C.int, len(value)) 300 | dims := make([][]C.int64_t, len(value)) 301 | dimsp := make([]*C.int64_t, len(value)) 302 | for i, s := range value { 303 | ndims[i], dims[i] = cshape(s) 304 | if ndims[i] > 0 { 305 | dimsp[i] = &dims[i][0] 306 | } 307 | } 308 | if len(value) > 0 { 309 | C.TF_SetAttrShapeList(cdesc, cAttrName, &dimsp[0], &ndims[0], C.int(len(value))) 310 | } else { 311 | C.TF_SetAttrShapeList(cdesc, cAttrName, nil, nil, 0) 312 | } 313 | default: 314 | return fmt.Errorf("attribute %q has a type (%T) which is not valid 
for operation attributes", name, value) 315 | } 316 | return nil 317 | } 318 | 319 | func cshape(s Shape) (C.int, []C.int64_t) { 320 | ndims := C.int(s.NumDimensions()) 321 | if ndims < 0 { 322 | return -1, nil 323 | } 324 | dims := make([]C.int64_t, ndims) 325 | for i, s := range s.dims { 326 | dims[i] = C.int64_t(s) 327 | } 328 | return ndims, dims 329 | } 330 | -------------------------------------------------------------------------------- /vendor/github.com/tensorflow/tensorflow/tensorflow/go/session.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package tensorflow 18 | 19 | // #include 20 | // #include "tensorflow/c/c_api.h" 21 | import "C" 22 | 23 | import ( 24 | "errors" 25 | "fmt" 26 | "runtime" 27 | "sync" 28 | "unsafe" 29 | ) 30 | 31 | // Session drives a TensorFlow graph computation. 32 | // 33 | // When a Session is created with a given target, a new Session object is bound 34 | // to the universe of resources specified by that target. Those resources are 35 | // available to this session to perform computation described in the GraphDef. 36 | // After creating the session with a graph, the caller uses the Run() API to 37 | // perform the computation and potentially fetch outputs as Tensors. 38 | // A Session allows concurrent calls to Run(). 
39 | type Session struct { 40 | c *C.TF_Session 41 | 42 | // For ensuring that: 43 | // - Close() blocks on all Run() calls to complete. 44 | // - Close() can be called multiple times. 45 | wg sync.WaitGroup 46 | mu sync.Mutex 47 | } 48 | 49 | // NewSession creates a new execution session with the associated graph. 50 | // options may be nil to use the default options. 51 | func NewSession(graph *Graph, options *SessionOptions) (*Session, error) { 52 | status := newStatus() 53 | cOpt, doneOpt, err := options.c() 54 | defer doneOpt() 55 | if err != nil { 56 | return nil, err 57 | } 58 | cSess := C.TF_NewSession(graph.c, cOpt, status.c) 59 | if err := status.Err(); err != nil { 60 | return nil, err 61 | } 62 | 63 | s := &Session{c: cSess} 64 | runtime.SetFinalizer(s, func(s *Session) { s.Close() }) 65 | return s, nil 66 | } 67 | 68 | // Run the graph with the associated session starting with the supplied feeds 69 | // to compute the value of the requested fetches. Runs, but does not return 70 | // Tensors for operations specified in targets. 71 | // 72 | // On success, returns the fetched Tensors in the same order as supplied in 73 | // the fetches argument. If fetches is set to nil, the returned Tensor fetches 74 | // is empty. 
75 | func (s *Session) Run(feeds map[Output]*Tensor, fetches []Output, targets []*Operation) ([]*Tensor, error) { 76 | s.mu.Lock() 77 | if s.c == nil { 78 | s.mu.Unlock() 79 | return nil, errors.New("session is closed") 80 | } 81 | s.wg.Add(1) 82 | s.mu.Unlock() 83 | defer s.wg.Done() 84 | 85 | c := newCRunArgs(feeds, fetches, targets) 86 | status := newStatus() 87 | C.TF_SessionRun(s.c, nil, 88 | ptrOutput(c.feeds), ptrTensor(c.feedTensors), C.int(len(feeds)), 89 | ptrOutput(c.fetches), ptrTensor(c.fetchTensors), C.int(len(fetches)), 90 | ptrOperation(c.targets), C.int(len(targets)), 91 | nil, status.c) 92 | if err := status.Err(); err != nil { 93 | return nil, err 94 | } 95 | return c.toGo(), nil 96 | } 97 | 98 | // PartialRun enables incremental evaluation of graphs. 99 | // 100 | // PartialRun allows the caller to pause the evaluation of a graph, run 101 | // arbitrary code that depends on the intermediate computation of the graph, 102 | // and then resume graph execution. The results of the arbitrary code can be 103 | // fed into the graph when resuming execution. In contrast, Session.Run 104 | // executes the graph to compute the requested fetches using the provided feeds 105 | // and discards all intermediate state (e.g., value of intermediate tensors) 106 | // when it returns. 107 | // 108 | // For example, consider a graph for unsupervised training of a neural network 109 | // model. PartialRun can be used to pause execution after the forward pass of 110 | // the network, let the caller actuate the output (e.g., play a game, actuate a 111 | // robot etc.), determine the error/loss and then feed this calculated loss 112 | // when resuming the backward pass of the graph. 113 | type PartialRun struct { 114 | session *Session 115 | handle *C.char 116 | } 117 | 118 | // Run resumes execution of the graph to compute the requested fetches and 119 | // targets with the provided feeds. 
120 | func (pr *PartialRun) Run(feeds map[Output]*Tensor, fetches []Output, targets []*Operation) ([]*Tensor, error) { 121 | var ( 122 | c = newCRunArgs(feeds, fetches, targets) 123 | status = newStatus() 124 | s = pr.session 125 | ) 126 | s.mu.Lock() 127 | if s.c == nil { 128 | s.mu.Unlock() 129 | return nil, errors.New("session is closed") 130 | } 131 | s.wg.Add(1) 132 | s.mu.Unlock() 133 | defer s.wg.Done() 134 | 135 | C.TF_SessionPRun(s.c, pr.handle, 136 | ptrOutput(c.feeds), ptrTensor(c.feedTensors), C.int(len(feeds)), 137 | ptrOutput(c.fetches), ptrTensor(c.fetchTensors), C.int(len(fetches)), 138 | ptrOperation(c.targets), C.int(len(targets)), 139 | status.c) 140 | if err := status.Err(); err != nil { 141 | return nil, err 142 | } 143 | return c.toGo(), nil 144 | } 145 | 146 | // NewPartialRun sets up the graph for incremental evaluation. 147 | // 148 | // All values of feeds, fetches and targets that may be provided to Run calls 149 | // on the returned PartialRun need to be provided to NewPartialRun. 150 | // 151 | // See documentation for the PartialRun type. 
152 | func (s *Session) NewPartialRun(feeds, fetches []Output, targets []*Operation) (*PartialRun, error) { 153 | var ( 154 | cfeeds = make([]C.TF_Output, len(feeds)) 155 | cfetches = make([]C.TF_Output, len(fetches)) 156 | ctargets = make([]*C.TF_Operation, len(targets)) 157 | 158 | pcfeeds *C.TF_Output 159 | pcfetches *C.TF_Output 160 | pctargets **C.TF_Operation 161 | 162 | status = newStatus() 163 | ) 164 | if len(feeds) > 0 { 165 | pcfeeds = &cfeeds[0] 166 | for i, o := range feeds { 167 | cfeeds[i] = o.c() 168 | } 169 | } 170 | if len(fetches) > 0 { 171 | pcfetches = &cfetches[0] 172 | for i, o := range fetches { 173 | cfetches[i] = o.c() 174 | } 175 | } 176 | if len(targets) > 0 { 177 | pctargets = &ctargets[0] 178 | for i, o := range targets { 179 | ctargets[i] = o.c 180 | } 181 | } 182 | 183 | s.mu.Lock() 184 | if s.c == nil { 185 | s.mu.Unlock() 186 | return nil, errors.New("session is closed") 187 | } 188 | s.wg.Add(1) 189 | s.mu.Unlock() 190 | defer s.wg.Done() 191 | 192 | pr := &PartialRun{session: s} 193 | C.TF_SessionPRunSetup(s.c, 194 | pcfeeds, C.int(len(feeds)), 195 | pcfetches, C.int(len(fetches)), 196 | pctargets, C.int(len(targets)), 197 | &pr.handle, status.c) 198 | if err := status.Err(); err != nil { 199 | return nil, err 200 | } 201 | runtime.SetFinalizer(pr, func(pr *PartialRun) { 202 | deletePRunHandle(pr.handle) 203 | }) 204 | return pr, nil 205 | } 206 | 207 | // Close a session. This contacts any other processes associated with this 208 | // session, if applicable. Blocks until all previous calls to Run have returned. 
209 | func (s *Session) Close() error { 210 | s.mu.Lock() 211 | defer s.mu.Unlock() 212 | s.wg.Wait() 213 | if s.c == nil { 214 | return nil 215 | } 216 | status := newStatus() 217 | C.TF_CloseSession(s.c, status.c) 218 | if err := status.Err(); err != nil { 219 | return err 220 | } 221 | C.TF_DeleteSession(s.c, status.c) 222 | s.c = nil 223 | return status.Err() 224 | } 225 | 226 | // SessionOptions contains configuration information for a session. 227 | type SessionOptions struct { 228 | // Target indicates the TensorFlow runtime to connect to. 229 | // 230 | // If 'target' is empty or unspecified, the local TensorFlow runtime 231 | // implementation will be used. Otherwise, the TensorFlow engine 232 | // defined by 'target' will be used to perform all computations. 233 | // 234 | // "target" can be either a single entry or a comma separated list 235 | // of entries. Each entry is a resolvable address of one of the 236 | // following formats: 237 | // local 238 | // ip:port 239 | // host:port 240 | // ... other system-specific formats to identify tasks and jobs ... 241 | // 242 | // NOTE: at the moment 'local' maps to an in-process service-based 243 | // runtime. 244 | // 245 | // Upon creation, a single session affines itself to one of the 246 | // remote processes, with possible load balancing choices when the 247 | // "target" resolves to a list of possible processes. 248 | // 249 | // If the session disconnects from the remote process during its 250 | // lifetime, session calls may fail immediately. 251 | Target string 252 | 253 | // Config is a binary-serialized representation of the 254 | // tensorflow.ConfigProto protocol message 255 | // (https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto). 256 | Config []byte 257 | } 258 | 259 | // c converts the SessionOptions to the C API's TF_SessionOptions. Callers must 260 | // deallocate by calling the returned done() closure. 
261 | func (o *SessionOptions) c() (ret *C.TF_SessionOptions, done func(), err error) { 262 | opt := C.TF_NewSessionOptions() 263 | if o == nil { 264 | return opt, func() { C.TF_DeleteSessionOptions(opt) }, nil 265 | } 266 | t := C.CString(o.Target) 267 | C.TF_SetTarget(opt, t) 268 | C.free(unsafe.Pointer(t)) 269 | 270 | var cConfig unsafe.Pointer 271 | if sz := len(o.Config); sz > 0 { 272 | status := newStatus() 273 | // Copying into C-memory is the simplest thing to do in terms 274 | // of memory safety and cgo rules ("C code may not keep a copy 275 | // of a Go pointer after the call returns" from 276 | // https://golang.org/cmd/cgo/#hdr-Passing_pointers). 277 | cConfig = C.CBytes(o.Config) 278 | C.TF_SetConfig(opt, cConfig, C.size_t(sz), status.c) 279 | if err := status.Err(); err != nil { 280 | C.TF_DeleteSessionOptions(opt) 281 | return nil, func() {}, fmt.Errorf("invalid SessionOptions.Config: %v", err) 282 | } 283 | } 284 | return opt, func() { 285 | C.TF_DeleteSessionOptions(opt) 286 | C.free(cConfig) 287 | }, nil 288 | } 289 | 290 | // cRunArgs translates the arguments to Session.Run and PartialRun.Run into 291 | // values suitable for C library calls. 
292 | type cRunArgs struct { 293 | feeds []C.TF_Output 294 | feedTensors []*C.TF_Tensor 295 | fetches []C.TF_Output 296 | fetchTensors []*C.TF_Tensor 297 | targets []*C.TF_Operation 298 | } 299 | 300 | func newCRunArgs(feeds map[Output]*Tensor, fetches []Output, targets []*Operation) *cRunArgs { 301 | c := &cRunArgs{ 302 | fetches: make([]C.TF_Output, len(fetches)), 303 | fetchTensors: make([]*C.TF_Tensor, len(fetches)), 304 | targets: make([]*C.TF_Operation, len(targets)), 305 | } 306 | for o, t := range feeds { 307 | c.feeds = append(c.feeds, o.c()) 308 | c.feedTensors = append(c.feedTensors, t.c) 309 | } 310 | for i, o := range fetches { 311 | c.fetches[i] = o.c() 312 | } 313 | for i, t := range targets { 314 | c.targets[i] = t.c 315 | } 316 | return c 317 | } 318 | 319 | func (c *cRunArgs) toGo() []*Tensor { 320 | ret := make([]*Tensor, len(c.fetchTensors)) 321 | for i, ct := range c.fetchTensors { 322 | ret[i] = newTensorFromC(ct) 323 | } 324 | return ret 325 | } 326 | 327 | func ptrOutput(l []C.TF_Output) *C.TF_Output { 328 | if len(l) == 0 { 329 | return nil 330 | } 331 | return &l[0] 332 | } 333 | 334 | func ptrTensor(l []*C.TF_Tensor) **C.TF_Tensor { 335 | if len(l) == 0 { 336 | return nil 337 | } 338 | return &l[0] 339 | } 340 | 341 | func ptrOperation(l []*C.TF_Operation) **C.TF_Operation { 342 | if len(l) == 0 { 343 | return nil 344 | } 345 | return &l[0] 346 | } 347 | -------------------------------------------------------------------------------- /vendor/github.com/tensorflow/tensorflow/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2017 The TensorFlow Authors. All rights reserved. 2 | 3 | Apache License 4 | Version 2.0, January 2004 5 | http://www.apache.org/licenses/ 6 | 7 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 8 | 9 | 1. Definitions. 
10 | 11 | "License" shall mean the terms and conditions for use, reproduction, 12 | and distribution as defined by Sections 1 through 9 of this document. 13 | 14 | "Licensor" shall mean the copyright owner or entity authorized by 15 | the copyright owner that is granting the License. 16 | 17 | "Legal Entity" shall mean the union of the acting entity and all 18 | other entities that control, are controlled by, or are under common 19 | control with that entity. For the purposes of this definition, 20 | "control" means (i) the power, direct or indirect, to cause the 21 | direction or management of such entity, whether by contract or 22 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 23 | outstanding shares, or (iii) beneficial ownership of such entity. 24 | 25 | "You" (or "Your") shall mean an individual or Legal Entity 26 | exercising permissions granted by this License. 27 | 28 | "Source" form shall mean the preferred form for making modifications, 29 | including but not limited to software source code, documentation 30 | source, and configuration files. 31 | 32 | "Object" form shall mean any form resulting from mechanical 33 | transformation or translation of a Source form, including but 34 | not limited to compiled object code, generated documentation, 35 | and conversions to other media types. 36 | 37 | "Work" shall mean the work of authorship, whether in Source or 38 | Object form, made available under the License, as indicated by a 39 | copyright notice that is included in or attached to the work 40 | (an example is provided in the Appendix below). 41 | 42 | "Derivative Works" shall mean any work, whether in Source or Object 43 | form, that is based on (or derived from) the Work and for which the 44 | editorial revisions, annotations, elaborations, or other modifications 45 | represent, as a whole, an original work of authorship. 
For the purposes 46 | of this License, Derivative Works shall not include works that remain 47 | separable from, or merely link (or bind by name) to the interfaces of, 48 | the Work and Derivative Works thereof. 49 | 50 | "Contribution" shall mean any work of authorship, including 51 | the original version of the Work and any modifications or additions 52 | to that Work or Derivative Works thereof, that is intentionally 53 | submitted to Licensor for inclusion in the Work by the copyright owner 54 | or by an individual or Legal Entity authorized to submit on behalf of 55 | the copyright owner. For the purposes of this definition, "submitted" 56 | means any form of electronic, verbal, or written communication sent 57 | to the Licensor or its representatives, including but not limited to 58 | communication on electronic mailing lists, source code control systems, 59 | and issue tracking systems that are managed by, or on behalf of, the 60 | Licensor for the purpose of discussing and improving the Work, but 61 | excluding communication that is conspicuously marked or otherwise 62 | designated in writing by the copyright owner as "Not a Contribution." 63 | 64 | "Contributor" shall mean Licensor and any individual or Legal Entity 65 | on behalf of whom a Contribution has been received by Licensor and 66 | subsequently incorporated within the Work. 67 | 68 | 2. Grant of Copyright License. Subject to the terms and conditions of 69 | this License, each Contributor hereby grants to You a perpetual, 70 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 71 | copyright license to reproduce, prepare Derivative Works of, 72 | publicly display, publicly perform, sublicense, and distribute the 73 | Work and such Derivative Works in Source or Object form. 74 | 75 | 3. Grant of Patent License. 
Subject to the terms and conditions of 76 | this License, each Contributor hereby grants to You a perpetual, 77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 78 | (except as stated in this section) patent license to make, have made, 79 | use, offer to sell, sell, import, and otherwise transfer the Work, 80 | where such license applies only to those patent claims licensable 81 | by such Contributor that are necessarily infringed by their 82 | Contribution(s) alone or by combination of their Contribution(s) 83 | with the Work to which such Contribution(s) was submitted. If You 84 | institute patent litigation against any entity (including a 85 | cross-claim or counterclaim in a lawsuit) alleging that the Work 86 | or a Contribution incorporated within the Work constitutes direct 87 | or contributory patent infringement, then any patent licenses 88 | granted to You under this License for that Work shall terminate 89 | as of the date such litigation is filed. 90 | 91 | 4. Redistribution. 
You may reproduce and distribute copies of the 92 | Work or Derivative Works thereof in any medium, with or without 93 | modifications, and in Source or Object form, provided that You 94 | meet the following conditions: 95 | 96 | (a) You must give any other recipients of the Work or 97 | Derivative Works a copy of this License; and 98 | 99 | (b) You must cause any modified files to carry prominent notices 100 | stating that You changed the files; and 101 | 102 | (c) You must retain, in the Source form of any Derivative Works 103 | that You distribute, all copyright, patent, trademark, and 104 | attribution notices from the Source form of the Work, 105 | excluding those notices that do not pertain to any part of 106 | the Derivative Works; and 107 | 108 | (d) If the Work includes a "NOTICE" text file as part of its 109 | distribution, then any Derivative Works that You distribute must 110 | include a readable copy of the attribution notices contained 111 | within such NOTICE file, excluding those notices that do not 112 | pertain to any part of the Derivative Works, in at least one 113 | of the following places: within a NOTICE text file distributed 114 | as part of the Derivative Works; within the Source form or 115 | documentation, if provided along with the Derivative Works; or, 116 | within a display generated by the Derivative Works, if and 117 | wherever such third-party notices normally appear. The contents 118 | of the NOTICE file are for informational purposes only and 119 | do not modify the License. You may add Your own attribution 120 | notices within Derivative Works that You distribute, alongside 121 | or as an addendum to the NOTICE text from the Work, provided 122 | that such additional attribution notices cannot be construed 123 | as modifying the License. 
124 | 125 | You may add Your own copyright statement to Your modifications and 126 | may provide additional or different license terms and conditions 127 | for use, reproduction, or distribution of Your modifications, or 128 | for any such Derivative Works as a whole, provided Your use, 129 | reproduction, and distribution of the Work otherwise complies with 130 | the conditions stated in this License. 131 | 132 | 5. Submission of Contributions. Unless You explicitly state otherwise, 133 | any Contribution intentionally submitted for inclusion in the Work 134 | by You to the Licensor shall be under the terms and conditions of 135 | this License, without any additional terms or conditions. 136 | Notwithstanding the above, nothing herein shall supersede or modify 137 | the terms of any separate license agreement you may have executed 138 | with Licensor regarding such Contributions. 139 | 140 | 6. Trademarks. This License does not grant permission to use the trade 141 | names, trademarks, service marks, or product names of the Licensor, 142 | except as required for reasonable and customary use in describing the 143 | origin of the Work and reproducing the content of the NOTICE file. 144 | 145 | 7. Disclaimer of Warranty. Unless required by applicable law or 146 | agreed to in writing, Licensor provides the Work (and each 147 | Contributor provides its Contributions) on an "AS IS" BASIS, 148 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 149 | implied, including, without limitation, any warranties or conditions 150 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 151 | PARTICULAR PURPOSE. You are solely responsible for determining the 152 | appropriateness of using or redistributing the Work and assume any 153 | risks associated with Your exercise of permissions under this License. 154 | 155 | 8. Limitation of Liability. 
In no event and under no legal theory, 156 | whether in tort (including negligence), contract, or otherwise, 157 | unless required by applicable law (such as deliberate and grossly 158 | negligent acts) or agreed to in writing, shall any Contributor be 159 | liable to You for damages, including any direct, indirect, special, 160 | incidental, or consequential damages of any character arising as a 161 | result of this License or out of the use or inability to use the 162 | Work (including but not limited to damages for loss of goodwill, 163 | work stoppage, computer failure or malfunction, or any and all 164 | other commercial damages or losses), even if such Contributor 165 | has been advised of the possibility of such damages. 166 | 167 | 9. Accepting Warranty or Additional Liability. While redistributing 168 | the Work or Derivative Works thereof, You may choose to offer, 169 | and charge a fee for, acceptance of support, warranty, indemnity, 170 | or other liability obligations and/or rights consistent with this 171 | License. However, in accepting such obligations, You may act only 172 | on Your own behalf and on Your sole responsibility, not on behalf 173 | of any other Contributor, and only if You agree to indemnify, 174 | defend, and hold each Contributor harmless for any liability 175 | incurred by, or claims asserted against, such Contributor by reason 176 | of your accepting any such warranty or additional liability. 177 | 178 | END OF TERMS AND CONDITIONS 179 | 180 | APPENDIX: How to apply the Apache License to your work. 181 | 182 | To apply the Apache License to your work, attach the following 183 | boilerplate notice, with the fields enclosed by brackets "[]" 184 | replaced with your own identifying information. (Don't include 185 | the brackets!) The text should be enclosed in the appropriate 186 | comment syntax for the file format. 
We also recommend that a 187 | file or class name and description of purpose be included on the 188 | same "printed page" as the copyright notice for easier 189 | identification within third-party archives. 190 | 191 | Copyright 2017, The TensorFlow Authors. 192 | 193 | Licensed under the Apache License, Version 2.0 (the "License"); 194 | you may not use this file except in compliance with the License. 195 | You may obtain a copy of the License at 196 | 197 | http://www.apache.org/licenses/LICENSE-2.0 198 | 199 | Unless required by applicable law or agreed to in writing, software 200 | distributed under the License is distributed on an "AS IS" BASIS, 201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 202 | See the License for the specific language governing permissions and 203 | limitations under the License. 204 | -------------------------------------------------------------------------------- /vendor/github.com/tensorflow/tensorflow/tensorflow/go/tensor.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | 17 | package tensorflow 18 | 19 | // #include 20 | // #include 21 | // #include "tensorflow/c/c_api.h" 22 | import "C" 23 | 24 | import ( 25 | "bytes" 26 | "encoding/binary" 27 | "fmt" 28 | "io" 29 | "reflect" 30 | "runtime" 31 | "unsafe" 32 | ) 33 | 34 | // DataType holds the type for a scalar value. E.g., one slot in a tensor. 35 | type DataType C.TF_DataType 36 | 37 | // Types of scalar values in the TensorFlow type system. 38 | const ( 39 | Float DataType = C.TF_FLOAT 40 | Double DataType = C.TF_DOUBLE 41 | Int32 DataType = C.TF_INT32 42 | Uint8 DataType = C.TF_UINT8 43 | Int16 DataType = C.TF_INT16 44 | Int8 DataType = C.TF_INT8 45 | String DataType = C.TF_STRING 46 | Complex64 DataType = C.TF_COMPLEX64 47 | Complex DataType = C.TF_COMPLEX 48 | Int64 DataType = C.TF_INT64 49 | Bool DataType = C.TF_BOOL 50 | Qint8 DataType = C.TF_QINT8 51 | Quint8 DataType = C.TF_QUINT8 52 | Qint32 DataType = C.TF_QINT32 53 | Bfloat16 DataType = C.TF_BFLOAT16 54 | Qint16 DataType = C.TF_QINT16 55 | Quint16 DataType = C.TF_QUINT16 56 | Uint16 DataType = C.TF_UINT16 57 | Complex128 DataType = C.TF_COMPLEX128 58 | Half DataType = C.TF_HALF 59 | ) 60 | 61 | // Tensor holds a multi-dimensional array of elements of a single data type. 62 | type Tensor struct { 63 | c *C.TF_Tensor 64 | shape []int64 65 | } 66 | 67 | // NewTensor converts from a Go value to a Tensor. Valid values are scalars, 68 | // slices, and arrays. Every element of a slice must have the same length so 69 | // that the resulting Tensor has a valid shape. 70 | func NewTensor(value interface{}) (*Tensor, error) { 71 | val := reflect.ValueOf(value) 72 | shape, dataType, err := shapeAndDataTypeOf(val) 73 | if err != nil { 74 | return nil, err 75 | } 76 | nflattened := numElements(shape) 77 | nbytes := typeOf(dataType, nil).Size() * uintptr(nflattened) 78 | if dataType == String { 79 | // TF_STRING tensors are encoded as an array of 8-byte offsets 80 | // followed by string data. See c_api.h. 
81 | nbytes = uintptr(nflattened*8) + byteSizeOfEncodedStrings(value) 82 | } 83 | var shapePtr *C.int64_t 84 | if len(shape) > 0 { 85 | shapePtr = (*C.int64_t)(unsafe.Pointer(&shape[0])) 86 | } 87 | t := &Tensor{ 88 | c: C.TF_AllocateTensor(C.TF_DataType(dataType), shapePtr, C.int(len(shape)), C.size_t(nbytes)), 89 | shape: shape, 90 | } 91 | runtime.SetFinalizer(t, (*Tensor).finalize) 92 | raw := tensorData(t.c) 93 | buf := bytes.NewBuffer(raw[:0:len(raw)]) 94 | if dataType != String { 95 | if err := encodeTensor(buf, val); err != nil { 96 | return nil, err 97 | } 98 | if uintptr(buf.Len()) != nbytes { 99 | return nil, bug("NewTensor incorrectly calculated the size of a tensor with type %v and shape %v as %v bytes instead of %v", dataType, shape, nbytes, buf.Len()) 100 | } 101 | } else { 102 | e := stringEncoder{offsets: buf, data: raw[nflattened*8 : len(raw)], status: newStatus()} 103 | if e.encode(reflect.ValueOf(value)); err != nil { 104 | return nil, err 105 | } 106 | if int64(buf.Len()) != nflattened*8 { 107 | return nil, bug("invalid offset encoding for TF_STRING tensor with shape %v (got %v, want %v)", shape, buf.Len(), nflattened*8) 108 | } 109 | } 110 | return t, nil 111 | } 112 | 113 | // ReadTensor constructs a Tensor with the provided type and shape from the 114 | // serialized tensor contents in r. 115 | // 116 | // See also WriteContentsTo. 
117 | func ReadTensor(dataType DataType, shape []int64, r io.Reader) (*Tensor, error) { 118 | if err := isTensorSerializable(dataType); err != nil { 119 | return nil, err 120 | } 121 | nbytes := typeOf(dataType, nil).Size() * uintptr(numElements(shape)) 122 | var shapePtr *C.int64_t 123 | if len(shape) > 0 { 124 | shapePtr = (*C.int64_t)(unsafe.Pointer(&shape[0])) 125 | } 126 | t := &Tensor{ 127 | c: C.TF_AllocateTensor(C.TF_DataType(dataType), shapePtr, C.int(len(shape)), C.size_t(nbytes)), 128 | shape: shape, 129 | } 130 | runtime.SetFinalizer(t, (*Tensor).finalize) 131 | raw := tensorData(t.c) 132 | n, err := r.Read(raw) 133 | if err != nil { 134 | return nil, err 135 | } 136 | if uintptr(n) != nbytes { 137 | return nil, fmt.Errorf("expected serialized tensor to be %v bytes, read %v", nbytes, n) 138 | } 139 | return t, nil 140 | } 141 | 142 | // newTensorFromC takes ownership of c and returns the owning Tensor. 143 | func newTensorFromC(c *C.TF_Tensor) *Tensor { 144 | var shape []int64 145 | if ndims := int(C.TF_NumDims(c)); ndims > 0 { 146 | shape = make([]int64, ndims) 147 | } 148 | for i := range shape { 149 | shape[i] = int64(C.TF_Dim(c, C.int(i))) 150 | } 151 | t := &Tensor{c: c, shape: shape} 152 | runtime.SetFinalizer(t, (*Tensor).finalize) 153 | return t 154 | } 155 | 156 | func (t *Tensor) finalize() { C.TF_DeleteTensor(t.c) } 157 | 158 | // DataType returns the scalar datatype of the Tensor. 159 | func (t *Tensor) DataType() DataType { return DataType(C.TF_TensorType(t.c)) } 160 | 161 | // Shape returns the shape of the Tensor. 162 | func (t *Tensor) Shape() []int64 { return t.shape } 163 | 164 | // Value converts the Tensor to a Go value. For now, not all Tensor types are 165 | // supported, and this function may panic if it encounters an unsupported 166 | // DataType. 167 | // 168 | // The type of the output depends on the Tensor type and dimensions. 
169 | // For example: 170 | // Tensor(int64, 0): int64 171 | // Tensor(float64, 3): [][][]float64 172 | func (t *Tensor) Value() interface{} { 173 | typ := typeOf(t.DataType(), t.Shape()) 174 | val := reflect.New(typ) 175 | raw := tensorData(t.c) 176 | if t.DataType() != String { 177 | if err := decodeTensor(bytes.NewReader(raw), t.Shape(), typ, val); err != nil { 178 | panic(bug("unable to decode Tensor of type %v and shape %v - %v", t.DataType(), t.Shape(), err)) 179 | } 180 | } else { 181 | nflattened := numElements(t.Shape()) 182 | d := stringDecoder{offsets: bytes.NewReader(raw[0 : 8*nflattened]), data: raw[8*nflattened:], status: newStatus()} 183 | if err := d.decode(val, t.Shape()); err != nil { 184 | panic(bug("unable to decode String tensor with shape %v - %v", t.Shape(), err)) 185 | } 186 | } 187 | return reflect.Indirect(val).Interface() 188 | } 189 | 190 | // WriteContentsTo writes the serialized contents of t to w. 191 | // 192 | // Returns the number of bytes written. See ReadTensor for 193 | // reconstructing a Tensor from the serialized form. 194 | // 195 | // WARNING: WriteContentsTo is not comprehensive and will fail 196 | // if t.DataType() is non-numeric (e.g., String). See 197 | // https://github.com/tensorflow/tensorflow/issues/6003. 
198 | func (t *Tensor) WriteContentsTo(w io.Writer) (int64, error) { 199 | if err := isTensorSerializable(t.DataType()); err != nil { 200 | return 0, err 201 | } 202 | return io.Copy(w, bytes.NewReader(tensorData(t.c))) 203 | } 204 | 205 | func tensorData(c *C.TF_Tensor) []byte { 206 | // See: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices 207 | cbytes := C.TF_TensorData(c) 208 | length := int(C.TF_TensorByteSize(c)) 209 | slice := (*[1 << 30]byte)(unsafe.Pointer(cbytes))[:length:length] 210 | return slice 211 | } 212 | 213 | var types = []struct { 214 | typ reflect.Type 215 | dataType C.TF_DataType 216 | }{ 217 | {reflect.TypeOf(float32(0)), C.TF_FLOAT}, 218 | {reflect.TypeOf(float64(0)), C.TF_DOUBLE}, 219 | {reflect.TypeOf(int32(0)), C.TF_INT32}, 220 | {reflect.TypeOf(uint8(0)), C.TF_UINT8}, 221 | {reflect.TypeOf(int16(0)), C.TF_INT16}, 222 | {reflect.TypeOf(int8(0)), C.TF_INT8}, 223 | {reflect.TypeOf(""), C.TF_STRING}, 224 | {reflect.TypeOf(complex(float32(0), float32(0))), C.TF_COMPLEX64}, 225 | {reflect.TypeOf(int64(0)), C.TF_INT64}, 226 | {reflect.TypeOf(false), C.TF_BOOL}, 227 | {reflect.TypeOf(uint16(0)), C.TF_UINT16}, 228 | {reflect.TypeOf(complex(float64(0), float64(0))), C.TF_COMPLEX128}, 229 | // TODO(apassos): support DT_RESOURCE representation in go. 230 | } 231 | 232 | // shapeAndDataTypeOf returns the data type and shape of the Tensor 233 | // corresponding to a Go type. 234 | func shapeAndDataTypeOf(val reflect.Value) (shape []int64, dt DataType, err error) { 235 | typ := val.Type() 236 | for typ.Kind() == reflect.Array || typ.Kind() == reflect.Slice { 237 | shape = append(shape, int64(val.Len())) 238 | // If slice elements are slices, verify that all of them have the same size. 239 | // Go's type system makes that guarantee for arrays. 
240 | if val.Len() > 0 { 241 | if val.Type().Elem().Kind() == reflect.Slice { 242 | expected := val.Index(0).Len() 243 | for i := 1; i < val.Len(); i++ { 244 | if val.Index(i).Len() != expected { 245 | return shape, dt, fmt.Errorf("mismatched slice lengths: %d and %d", val.Index(i).Len(), expected) 246 | } 247 | } 248 | } 249 | val = val.Index(0) 250 | } 251 | typ = typ.Elem() 252 | } 253 | for _, t := range types { 254 | if typ.Kind() == t.typ.Kind() { 255 | return shape, DataType(t.dataType), nil 256 | } 257 | } 258 | return shape, dt, fmt.Errorf("unsupported type %v", typ) 259 | } 260 | 261 | // typeOf converts from a DataType and Shape to the equivalent Go type. 262 | func typeOf(dt DataType, shape []int64) reflect.Type { 263 | var ret reflect.Type 264 | for _, t := range types { 265 | if dt == DataType(t.dataType) { 266 | ret = t.typ 267 | break 268 | } 269 | } 270 | if ret == nil { 271 | panic(bug("DataType %v is not supported", dt)) 272 | } 273 | for _ = range shape { 274 | ret = reflect.SliceOf(ret) 275 | } 276 | return ret 277 | } 278 | 279 | func numElements(shape []int64) int64 { 280 | n := int64(1) 281 | for _, d := range shape { 282 | n *= d 283 | } 284 | return n 285 | } 286 | 287 | // byteSizeOfEncodedStrings returns the size of the encoded strings in val. 288 | // val MUST be a string, or a container (array/slice etc.) of strings. 289 | func byteSizeOfEncodedStrings(val interface{}) uintptr { 290 | if s, ok := val.(string); ok { 291 | return uintptr(C.TF_StringEncodedSize(C.size_t(len(s)))) 292 | } 293 | // Otherwise must be an array or slice. 294 | var size uintptr 295 | v := reflect.ValueOf(val) 296 | for i := 0; i < v.Len(); i++ { 297 | size += byteSizeOfEncodedStrings(v.Index(i).Interface()) 298 | } 299 | return size 300 | } 301 | 302 | // encodeTensor writes v to the specified buffer using the format specified in 303 | // c_api.h. Use stringEncoder for String tensors. 
304 | func encodeTensor(w *bytes.Buffer, v reflect.Value) error { 305 | switch v.Kind() { 306 | case reflect.Bool: 307 | b := byte(0) 308 | if v.Bool() { 309 | b = 1 310 | } 311 | if err := w.WriteByte(b); err != nil { 312 | return err 313 | } 314 | case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128: 315 | if err := binary.Write(w, nativeEndian, v.Interface()); err != nil { 316 | return err 317 | } 318 | 319 | case reflect.Array, reflect.Slice: 320 | // If slice elements are slices, verify that all of them have the same size. 321 | // Go's type system makes that guarantee for arrays. 322 | if v.Len() > 0 && v.Type().Elem().Kind() == reflect.Slice { 323 | expected := v.Index(0).Len() 324 | for i := 1; i < v.Len(); i++ { 325 | if v.Index(i).Len() != expected { 326 | return fmt.Errorf("mismatched slice lengths: %d and %d", v.Index(i).Len(), expected) 327 | } 328 | } 329 | } 330 | 331 | for i := 0; i < v.Len(); i++ { 332 | err := encodeTensor(w, v.Index(i)) 333 | if err != nil { 334 | return err 335 | } 336 | } 337 | 338 | default: 339 | return fmt.Errorf("unsupported type %v", v.Type()) 340 | } 341 | return nil 342 | } 343 | 344 | // decodeTensor decodes the Tensor from the buffer to ptr using the format 345 | // specified in c_api.h. Use stringDecoder for String tensors. 
346 | func decodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect.Value) error { 347 | switch typ.Kind() { 348 | case reflect.Bool: 349 | b, err := r.ReadByte() 350 | if err != nil { 351 | return err 352 | } 353 | ptr.Elem().SetBool(b == 1) 354 | case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128: 355 | if err := binary.Read(r, nativeEndian, ptr.Interface()); err != nil { 356 | return err 357 | } 358 | 359 | case reflect.Slice: 360 | val := reflect.Indirect(ptr) 361 | val.Set(reflect.MakeSlice(typ, int(shape[0]), int(shape[0]))) 362 | for i := 0; i < val.Len(); i++ { 363 | if err := decodeTensor(r, shape[1:], typ.Elem(), val.Index(i).Addr()); err != nil { 364 | return err 365 | } 366 | } 367 | 368 | default: 369 | return fmt.Errorf("unsupported type %v", typ) 370 | } 371 | return nil 372 | } 373 | 374 | type stringEncoder struct { 375 | offsets io.Writer 376 | data []byte 377 | offset uint64 378 | status *status 379 | } 380 | 381 | func (e *stringEncoder) encode(v reflect.Value) error { 382 | if v.Kind() == reflect.String { 383 | if err := binary.Write(e.offsets, nativeEndian, e.offset); err != nil { 384 | return err 385 | } 386 | var ( 387 | s = v.Interface().(string) 388 | src = C.CString(s) 389 | srcLen = C.size_t(len(s)) 390 | dst = (*C.char)(unsafe.Pointer(&e.data[e.offset])) 391 | dstLen = C.size_t(uint64(len(e.data)) - e.offset) 392 | ) 393 | e.offset += uint64(C.TF_StringEncode(src, srcLen, dst, dstLen, e.status.c)) 394 | C.free(unsafe.Pointer(src)) 395 | return e.status.Err() 396 | } 397 | for i := 0; i < v.Len(); i++ { 398 | if err := e.encode(v.Index(i)); err != nil { 399 | return err 400 | } 401 | } 402 | return nil 403 | } 404 | 405 | type stringDecoder struct { 406 | offsets io.Reader 407 | data []byte 408 | status *status 409 | } 410 | 411 | func (d *stringDecoder) decode(ptr reflect.Value, shape []int64) error { 412 
| if len(shape) == 0 { 413 | var offset uint64 414 | if err := binary.Read(d.offsets, nativeEndian, &offset); err != nil { 415 | return err 416 | } 417 | var ( 418 | src = (*C.char)(unsafe.Pointer(&d.data[offset])) 419 | srcLen = C.size_t(len(d.data)) - C.size_t(offset) 420 | dst *C.char 421 | dstLen C.size_t 422 | ) 423 | if offset > uint64(len(d.data)) { 424 | return fmt.Errorf("invalid offsets in String Tensor") 425 | } 426 | C.TF_StringDecode(src, srcLen, &dst, &dstLen, d.status.c) 427 | if err := d.status.Err(); err != nil { 428 | return err 429 | } 430 | s := ptr.Interface().(*string) 431 | *s = C.GoStringN(dst, C.int(dstLen)) 432 | return nil 433 | } 434 | val := reflect.Indirect(ptr) 435 | val.Set(reflect.MakeSlice(typeOf(String, shape), int(shape[0]), int(shape[0]))) 436 | for i := 0; i < val.Len(); i++ { 437 | if err := d.decode(val.Index(i).Addr(), shape[1:]); err != nil { 438 | return err 439 | } 440 | } 441 | return nil 442 | } 443 | 444 | func bug(format string, args ...interface{}) error { 445 | return fmt.Errorf("BUG: Please report at https://github.com/tensorflow/tensorflow/issues with the note: Go TensorFlow %v: %v", Version(), fmt.Sprintf(format, args...)) 446 | } 447 | 448 | func isTensorSerializable(dataType DataType) error { 449 | // For numeric types, the serialized Tensor matches the in-memory 450 | // representation. See the implementation of Tensor::AsProtoContent in 451 | // https://www.tensorflow.org/code/tensorflow/core/framework/tensor.cc 452 | // 453 | // The more appropriate way to be in sync with Tensor::AsProtoContent 454 | // would be to have the TensorFlow C library export functions for 455 | // serialization and deserialization of Tensors. Till then capitalize 456 | // on knowledge of the implementation for numeric types. 
457 | switch dataType { 458 | case Float, Double, Int32, Uint8, Int16, Int8, Complex, Int64, Bool, Quint8, Qint32, Bfloat16, Qint16, Quint16, Uint16, Complex128, Half: 459 | return nil 460 | default: 461 | return fmt.Errorf("serialization of tensors with the DataType %d is not yet supported, see https://github.com/tensorflow/tensorflow/issues/6003", dataType) 462 | } 463 | } 464 | 465 | // nativeEndian is the byte order for the local platform. Used to send back and 466 | // forth Tensors with the C API. We test for endianness at runtime because 467 | // some architectures can be booted into different endian modes. 468 | var nativeEndian binary.ByteOrder 469 | 470 | func init() { 471 | buf := [2]byte{} 472 | *(*uint16)(unsafe.Pointer(&buf[0])) = uint16(0xABCD) 473 | 474 | switch buf { 475 | case [2]byte{0xCD, 0xAB}: 476 | nativeEndian = binary.LittleEndian 477 | case [2]byte{0xAB, 0xCD}: 478 | nativeEndian = binary.BigEndian 479 | default: 480 | panic("Could not determine native endianness.") 481 | } 482 | } 483 | --------------------------------------------------------------------------------