├── typings.d.ts ├── src ├── images │ ├── a.jpg │ ├── 111.jpg │ ├── 222.jpeg │ ├── 333.jpg │ ├── dog.jpg │ ├── test.png │ ├── horses.jpg │ ├── people.jpg │ ├── street.jpg │ ├── street1.jpg │ └── timg.jpeg ├── style.css ├── index.html ├── index.ts └── yolo │ ├── config.js │ └── yolo-eval.ts ├── docs ├── img │ └── demo1.jpg └── index.html ├── .gitignore ├── tsconfig.json ├── .eslintrc ├── package.json ├── README.md ├── README_EN.md └── webpack.config.js /typings.d.ts: -------------------------------------------------------------------------------- 1 | declare module '@antv/g2' 2 | declare module '@antv/data-set' -------------------------------------------------------------------------------- /src/images/a.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zqingr/tfjs-yolov3/HEAD/src/images/a.jpg -------------------------------------------------------------------------------- /docs/img/demo1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zqingr/tfjs-yolov3/HEAD/docs/img/demo1.jpg -------------------------------------------------------------------------------- /src/images/111.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zqingr/tfjs-yolov3/HEAD/src/images/111.jpg -------------------------------------------------------------------------------- /src/images/222.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zqingr/tfjs-yolov3/HEAD/src/images/222.jpeg -------------------------------------------------------------------------------- /src/images/333.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zqingr/tfjs-yolov3/HEAD/src/images/333.jpg -------------------------------------------------------------------------------- 
/src/images/dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zqingr/tfjs-yolov3/HEAD/src/images/dog.jpg -------------------------------------------------------------------------------- /src/images/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zqingr/tfjs-yolov3/HEAD/src/images/test.png -------------------------------------------------------------------------------- /src/images/horses.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zqingr/tfjs-yolov3/HEAD/src/images/horses.jpg -------------------------------------------------------------------------------- /src/images/people.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zqingr/tfjs-yolov3/HEAD/src/images/people.jpg -------------------------------------------------------------------------------- /src/images/street.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zqingr/tfjs-yolov3/HEAD/src/images/street.jpg -------------------------------------------------------------------------------- /src/images/street1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zqingr/tfjs-yolov3/HEAD/src/images/street1.jpg -------------------------------------------------------------------------------- /src/images/timg.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zqingr/tfjs-yolov3/HEAD/src/images/timg.jpeg -------------------------------------------------------------------------------- /src/style.css: -------------------------------------------------------------------------------- 1 | #img-box { 2 | position: relative; 3 | } 4 | 
#img-box .rect { 5 | position: absolute; 6 | border: 2px solid #f00; 7 | } 8 | #img-box .rect .className { 9 | position: absolute; 10 | top: 0; 11 | background: #f00; 12 | color: #fff; 13 | 14 | } 15 | #test-canvas { 16 | } -------------------------------------------------------------------------------- /docs/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | Document 8 | 9 | 10 | 11 | 12 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # See http://help.github.com/ignore-files/ for more about ignoring files. 2 | 3 | # compiled output 4 | /tmp 5 | /dist 6 | 7 | # dependencies 8 | /node_modules 9 | /bower_components 10 | /src/model 11 | 12 | # IDEs and editors 13 | /.idea 14 | .vscode/* 15 | .project 16 | .classpath 17 | *.launch 18 | .settings/ 19 | 20 | # misc 21 | /.sass-cache 22 | /connect.lock 23 | /coverage/* 24 | /libpeerconnection.log 25 | npm-debug.log 26 | testem.log 27 | /typings 28 | 29 | # e2e 30 | /e2e/*.js 31 | /e2e/*.map 32 | 33 | #System Files 34 | .DS_Store 35 | Thumbs.db 36 | node.test.js 37 | 38 | yarn.lock 39 | 40 | webpack.config.prod.js -------------------------------------------------------------------------------- /src/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | Document 8 | 9 | 10 |
11 | 12 |
13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | -------------------------------------------------------------------------------- /tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "outDir": "./dist/", 4 | "sourceMap": true, 5 | "noImplicitAny": true, 6 | "module": "es6", 7 | "moduleResolution": "node", 8 | "target": "es5", 9 | "allowJs": true, 10 | "allowSyntheticDefaultImports": true, 11 | "experimentalDecorators": true, 12 | "isolatedModules": false, 13 | "noImplicitThis": true, 14 | "strictNullChecks": true, 15 | "removeComments": true, 16 | "suppressImplicitAnyIndexErrors": true, 17 | "lib" : ["es2015", "es2017", "dom"], 18 | "baseUrl": "types", 19 | "typeRoots": [ 20 | "node_modules/@types" 21 | ], 22 | "paths": { 23 | "jquery-ui": [ 24 | "node_modules/@types/jqueryui/index" 25 | ] 26 | } 27 | }, 28 | "exclude": [ 29 | "node_modules" 30 | ] 31 | } -------------------------------------------------------------------------------- /.eslintrc: -------------------------------------------------------------------------------- 1 | { 2 | "extends": [ 3 | "standard" 4 | ], 5 | "plugins": [ 6 | "html", 7 | "import", 8 | "babel" 9 | ], 10 | "env": { 11 | "browser": true, 12 | "node": true, 13 | "es6": true, 14 | "jquery": true, 15 | "commonjs": true, 16 | "phantomjs": true, 17 | "mocha": true 18 | }, 19 | "rules": { 20 | "no-confusing-arrow": "off", 21 | "babel/new-cap": 1, 22 | "no-undef": "off", 23 | "import/no-unresolved": "off", 24 | "import/extensions": "off", 25 | "import/no-webpack-loader-syntax": "off", 26 | "keyword-spacing": "off", 27 | "no-unused-vars": [ 0, { "varsIgnorePattern": "Component$", "args": "none" } ], 28 | "no-new": 0, 29 | "indent": [ 30 | "error", 31 | 2, 32 | { "ignoredNodes": [ "Program" ] } 33 | ] 34 | }, 35 | "parser": "typescript-eslint-parser", 36 | "parserOptions": { 37 | "sourceType": "module", 38 | "ecmaFeatures": { 39 | "jsx": true 40 | } 41 | } 42 | } 43 | 44 
| -------------------------------------------------------------------------------- /src/index.ts: --------------------------------------------------------------------------------
import { yolov3Tiny } from './yolo/yolo-eval'
// import { yolov3Tiny } from '../dist/index.bundle'

import './style.css'
// import * as tf from '@tensorflow/tfjs'

// async function load () {
//   const model = await tf.loadModel('/model/yolov3-tiny/model.json')
//   // const model = await tf.loadModel('https://zqingr.github.io/tfjs-yolov3-demo/model/yolov3-tiny/model.json')

//   model.summary()
// }
// load()

const $img = document.getElementById('img') as HTMLImageElement

/**
 * Resolves once `$el` has finished loading (immediately if it already has),
 * so detection never runs on an image that has not been decoded yet.
 * Rejects if the image fires an `error` event while we are waiting.
 */
function whenLoaded ($el: HTMLImageElement): Promise<void> {
  if ($el.complete) return Promise.resolve()
  return new Promise<void>((resolve, reject) => {
    $el.addEventListener('load', () => resolve(), { once: true })
    $el.addEventListener(
      'error',
      () => reject(new Error(`Failed to load image: ${$el.src}`)),
      { once: true }
    )
  })
}

/**
 * Demo entry point: loads yolov3-tiny, runs it on #img and appends one
 * absolutely positioned `.rect` element per detection to #img-box.
 */
async function start () {
  const yolo = await yolov3Tiny()
  await whenLoaded($img)
  const boxes = await yolo($img)

  const $imgbox = document.getElementById('img-box') as HTMLElement

  boxes.forEach(box => {
    const $div = document.createElement('div')
    $div.className = 'rect'
    $div.style.top = box.top + 'px'
    $div.style.left = box.left + 'px'
    $div.style.width = box.width + 'px'
    $div.style.height = box.height + 'px'
    $div.innerHTML = `${box.classes} ${box.scores}`

    $imgbox.appendChild($div)
  })

  console.log(boxes)
}

// Report failures (model download, IndexedDB, image load) instead of leaving
// an unhandled promise rejection.
start().catch(err => {
  console.error('tfjs-yolov3 demo failed:', err)
})

// import yolov3 from './yolo/yolo-eval'

// export default yolov3
-------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "random-data", 3 | "version": "1.0.0", 4 | "description": "", 5 | "main": "webpack.config.js", 6 | "scripts": { 7 | "test": "echo \"Error: no test specified\" && exit 1", 8 | "start": "cross-env NODE_ENV=dev PLATFORM=web webpack-dev-server --progress --host 0.0.0.0
--disableHostCheck=true", 9 | "build": "cross-env NODE_ENV=pro PLATFORM=web webpack --progress --config ./webpack.config.prod.js --optimize-minimize" 10 | }, 11 | "author": "", 12 | "license": "ISC", 13 | "devDependencies": { 14 | "@types/d3": "^5.7.1", 15 | "cross-env": "^5.2.0", 16 | "eslint": "^5.15.1", 17 | "eslint-config-standard": "^12.0.0", 18 | "eslint-loader": "^2.1.2", 19 | "eslint-plugin-babel": "^5.3.0", 20 | "eslint-plugin-html": "^5.0.3", 21 | "eslint-plugin-import": "^2.16.0", 22 | "eslint-plugin-node": "^8.0.1", 23 | "eslint-plugin-promise": "^4.0.1", 24 | "eslint-plugin-standard": "^4.0.0", 25 | "file-loader": "^3.0.1", 26 | "ts-loader": "^5.3.3", 27 | "ts-node": "^8.0.3", 28 | "typescript": "^3.3.3333", 29 | "typescript-eslint-parser": "^22.0.0", 30 | "url-loader": "^1.1.2", 31 | "webpack": "^4.29.6", 32 | "webpack-cli": "^3.2.3", 33 | "webpack-dev-server": "^3.2.1" 34 | }, 35 | "dependencies": { 36 | "@antv/data-set": "^0.10.2", 37 | "@antv/g2": "^3.4.10", 38 | "@tensorflow/tfjs": "1.0.0", 39 | "@types/echarts": "^4.1.5", 40 | "css-loader": "^2.1.1", 41 | "echarts": "^4.1.0", 42 | "node-fetch": "^2.3.0", 43 | "style-loader": "^0.23.1" 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /src/yolo/config.js: -------------------------------------------------------------------------------- 1 | export const ANCHORS = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 373, 326] 2 | export const ANCHORS_TINY = [10,14, 23,27, 37,58, 81,82, 135,169, 344,319] 3 | export const COCO_CLASSESS = [ 4 | 'person', 5 | 'bicycle', 6 | 'car', 7 | 'motorbike', 8 | 'aeroplane', 9 | 'bus', 10 | 'train', 11 | 'truck', 12 | 'boat', 13 | 'traffic light', 14 | 'fire hydrant', 15 | 'stop sign', 16 | 'parking meter', 17 | 'bench', 18 | 'bird', 19 | 'cat', 20 | 'dog', 21 | 'horse', 22 | 'sheep', 23 | 'cow', 24 | 'elephant', 25 | 'bear', 26 | 'zebra', 27 | 'giraffe', 28 | 'backpack', 29 | 'umbrella', 30 | 
'handbag', 31 | 'tie', 32 | 'suitcase', 33 | 'frisbee', 34 | 'skis', 35 | 'snowboard', 36 | 'sports ball', 37 | 'kite', 38 | 'baseball bat', 39 | 'baseball glove', 40 | 'skateboard', 41 | 'surfboard', 42 | 'tennis racket', 43 | 'bottle', 44 | 'wine glass', 45 | 'cup', 46 | 'fork', 47 | 'knife', 48 | 'spoon', 49 | 'bowl', 50 | 'banana', 51 | 'apple', 52 | 'sandwich', 53 | 'orange', 54 | 'broccoli', 55 | 'carrot', 56 | 'hot dog', 57 | 'pizza', 58 | 'donut', 59 | 'cake', 60 | 'chair', 61 | 'sofa', 62 | 'pottedplant', 63 | 'bed', 64 | 'diningtable', 65 | 'toilet', 66 | 'tvmonitor', 67 | 'laptop', 68 | 'mouse', 69 | 'remote', 70 | 'keyboard', 71 | 'cell phone', 72 | 'microwave', 73 | 'oven', 74 | 'toaster', 75 | 'sink', 76 | 'refrigerator', 77 | 'book', 78 | 'clock', 79 | 'vase', 80 | 'scissors', 81 | 'teddy bear', 82 | 'hair drier', 83 | 'toothbrush', 84 | ] -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 中文 | [English](./README_EN.md) 2 | 3 | # tfjs-yolov3 4 | 5 | ## 介绍 6 | 7 | 完全用js来实现图片中的目标检测 8 | 基于yolov3算法和Tensorflow.js库 9 | 用tensorflow.js实现yolov3和yolov3-tiny 10 | 11 | 需要注意的是: 需要TensorFlow.js v1.0.0及以上版本 12 | 13 | ## 特点 14 | - 可以识别**任意尺寸**的图片 15 | - 同时支持yolov3和yolov3-tiny 16 | 17 | ## 快速开始 18 | 19 | ### 安装 20 | 21 | ``` 22 | npm install tfjs-yolov3 23 | ``` 24 | 25 | ### 用法示例 26 | 27 | ```javascript 28 | import { yolov3, yolov3Tiny } from 'tfjs-yolov3' 29 | 30 | async function start () { 31 | const yolo = await yolov3Tiny() // pre-load model (35M) 32 | // or 33 | // const yolo = await yolov3() // pre-load model (245M) 34 | 35 | const $img = document.getElementById('img') 36 | const boxes = await yolo($img) 37 | draw(boxes) // Some draw function 38 | } 39 | start() 40 | ``` 41 | 42 | ## API 文档 43 | 44 | yolov3和yolov3Tiny函数接受一个options对象,并返回一个Promise,resolve后得到一个识别函数 45 | 46 | ```typescript 47 | export declare function yolov3 ( 48 | { modelUrl, anchors }? 
: 49 | { modelUrl?: string, anchors?: number[] } 50 | ): Promise 51 | 52 | export declare function yolov3Tiny ( 53 | { modelUrl, anchors }? : 54 | { modelUrl?: string, anchors?: number[] } 55 | ): Promise 56 | ``` 57 | 58 | | 参数 | 说明 | 59 | | ------------ | ------------ | 60 | | modelUrl | 可选,预训练model的url,可把model下载到本地,加快预训练model的加载速度,[点我下载](https://github.com/zqingr/tfjs-yolov3/releases/tag/v1.0) | 61 | | anchors | 可选,可自定义anchors,格式参考[config](https://github.com/zqingr/tfjs-yolov3/blob/master/src/yolo/config.js) | 62 | 63 | 这两个函数调用后会加载预训练model,并返回一个函数,可用这个函数去识别图片,并返回识别后的box列表,类型定义如下: 64 | 65 | ```typescript 66 | type yolo = ($img: HTMLImageElement) => Promise 67 | 68 | interface Box { 69 | top: number 70 | left: number 71 | bottom: number 72 | right: number 73 | width: number 74 | height: number 75 | scores: number 76 | classes: string 77 | } 78 | ``` 79 | 80 | 81 | 82 | ## DEMO 83 | 84 | [点击查看在线DEMO](https://zqingr.github.io/tfjs-yolov3-demo/) 85 | 86 | ![demo](./docs/img/demo1.jpg) 87 | 88 | 89 | -------------------------------------------------------------------------------- /README_EN.md: -------------------------------------------------------------------------------- 1 | # tfjs-yolov3 2 | 3 | ## Introduction 4 | 5 | A TensorFlow.js implementation of YOLOv3 and YOLOv3-tiny 6 | 7 | Note: requires TensorFlow.js v1.0.0 or later 8 | 9 | ## Features 10 | - Can recognize images of **any size** 11 | - Supports both **yolov3** and **yolov3-tiny** 12 | 13 | ## Quick Start 14 | 15 | ### Install 16 | 17 | ``` 18 | npm install tfjs-yolov3 19 | ``` 20 | 21 | ### Usage Example 22 | 23 | ```javascript 24 | import { yolov3, yolov3Tiny } from 'tfjs-yolov3' 25 | 26 | async function start () { 27 | const yolo = await yolov3Tiny() // pre-load model (35M) 28 | // or 29 | // const yolo = await yolov3() // pre-load model (245M) 30 | 31 | const $img = document.getElementById('img') 32 | const boxes = await yolo($img) 33 | draw(boxes) // Some draw function 34 | } 35 | start() 36 | ``` 37 | 38 | 39 | 
## API DOC 40 | 41 | The yolov3 and yolov3Tiny functions accept an options object and return a Promise that resolves to a detection function 42 | 43 | ```typescript 44 | export declare function yolov3 ( 45 | { modelUrl, anchors }? : 46 | { modelUrl?: string, anchors?: number[] } 47 | ): Promise 48 | 49 | export declare function yolov3Tiny ( 50 | { modelUrl, anchors }? : 51 | { modelUrl?: string, anchors?: number[] } 52 | ): Promise 53 | ``` 54 | 55 | | Parameters | Description | 56 | | ------------ | ------------ | 57 | | modelUrl | Optional. URL of the pre-trained model. You can [download](https://github.com/zqingr/tfjs-yolov3/releases/tag/v1.0) the model and serve it locally to speed up loading. | 58 | | anchors | Optional. Custom anchors; see [config](https://github.com/zqingr/tfjs-yolov3/blob/master/src/yolo/config.js) for the format. | 59 | 60 | Calling either function loads the pre-trained model and returns a detection function. Pass an image to that function to detect objects; it resolves to the list of detected boxes. 
The parameters are as follows: 61 | 62 | ```typescript 63 | type yolo = ($img: HTMLImageElement) => Promise 64 | 65 | interface Box { 66 | top: number 67 | left: number 68 | bottom: number 69 | right: number 70 | width: number 71 | height: number 72 | scores: number 73 | classes: string 74 | } 75 | ``` 76 | 77 | 78 | 79 | 80 | ## DEMO 81 | 82 | [Check out the Live Demo](https://zqingr.github.io/tfjs-yolov3-demo/) 83 | 84 | ![demo](./docs/img/demo1.jpg) 85 | 86 | 87 | -------------------------------------------------------------------------------- /webpack.config.js: -------------------------------------------------------------------------------- 1 | const path = require('path'); 2 | 3 | module.exports = function (env) { 4 | return { 5 | mode: 'development', 6 | context: path.join(process.cwd(), 'src'), 7 | entry: { 8 | 'index': `./index.ts` 9 | }, 10 | output: { 11 | path: path.resolve(__dirname, 'dist'), 12 | filename: '[name].bundle.js', 13 | libraryTarget: "umd" 14 | }, 15 | resolve: { 16 | // Add `.ts` and `.tsx` as a resolvable extension. 
17 | extensions: [".ts", ".tsx", ".js"] 18 | }, 19 | module: { 20 | rules: [ 21 | { 22 | test: /\.ts$/, 23 | enforce: "pre", 24 | exclude: /node_modules/, 25 | loader: "eslint-loader", 26 | options: { 27 | fix: true 28 | } 29 | }, 30 | { 31 | type: 'javascript/auto', 32 | test: /\.(json)$/, 33 | exclude: /node_modules/, 34 | loader: [ 35 | `file-loader?publicPath=/&name=[name].[ext]` 36 | ] 37 | }, 38 | { 39 | test: /\.ts$/, 40 | exclude: /node_modules/, 41 | loader: 'ts-loader', 42 | options: { 43 | appendTsSuffixTo: [/\.vue$/, /\.ts$/], 44 | transpileOnly: true 45 | } 46 | }, 47 | { 48 | test: /\.html$/, 49 | loader: 'raw-loader', 50 | exclude: path.join(__dirname, './src/index.html') 51 | }, 52 | { 53 | test: /\.(jpg|jpeg|gif|png)$/, 54 | loader: [ 55 | `url-loader?limit=4112&publicPath=/&name=[name].[ext]` 56 | ] 57 | }, 58 | { 59 | test: /\.css$/, 60 | exclude: /node_modules/, 61 | use: [ 62 | 'style-loader', 63 | 'css-loader' 64 | ] 65 | }] 66 | }, 67 | externals: { 68 | "tf": "@tensorflow/tfjs" 69 | }, 70 | devServer: { 71 | port: 8000, 72 | host: 'localhost', 73 | historyApiFallback: true, 74 | watchOptions: { 75 | aggregateTimeout: 300, 76 | poll: 1000 77 | }, 78 | contentBase: ['src'], 79 | open: false, 80 | stats: { 81 | assets: true, 82 | children: false, 83 | chunks: false, 84 | hash: false, 85 | modules: false, 86 | publicPath: false, 87 | timings: true, 88 | version: false, 89 | warnings: true, 90 | colors: { 91 | green: '\u001b[32m', 92 | } 93 | } 94 | } 95 | } 96 | } -------------------------------------------------------------------------------- /src/yolo/yolo-eval.ts: --------------------------------------------------------------------------------
import * as tf from '@tensorflow/tfjs'
import { COCO_CLASSESS, ANCHORS, ANCHORS_TINY } from './config'

/** A single detection, in CSS pixels of the image as displayed on the page. */
export interface Box {
  top: number
  left: number
  bottom: number
  right: number
  width: number
  height: number
  /** Objectness multiplied by the class probability. */
  scores: number
  /** Class name looked up in COCO_CLASSESS by class index. */
  classes: string
}

/** Side length (px) of the square canvas every image is stretched onto before inference. */
const INPUT_SIZE = 416

/** Returns [0, 1, ..., num - 1]. */
const range = (num: number): number[] =>
  Array(num)
    .fill(0)
    .map((_, i) => i)

/**
 * Decodes one YOLO output layer into box centres, box sizes, objectness and
 * class probabilities.
 *
 * @param feats      one output layer with the batch dimension removed; its first
 *                   two dims are the grid (height, width) and it must hold
 *                   numAnchors * (numClasses + 5) values per grid cell
 * @param anchors    [numAnchors, 2] anchor (width, height) pairs in input pixels
 * @param numClasses number of class scores predicted per anchor
 * @param inputShape [height, width] of the network input in pixels; NOT mutated
 * @returns [boxXY, boxWH, boxConfidence, boxClassProbs], each shaped
 *          [gridH, gridW, numAnchors, k]; xy/wh are relative to the input size
 */
function yoloHead (
  feats: tf.Tensor,
  anchors: tf.Tensor2D,
  numClasses: number,
  inputShape: number[]
): tf.Tensor[] {
  return tf.tidy(() => {
    const numAnchors = anchors.shape[0]
    const anchorsTensor = anchors.reshape([1, numAnchors, 2])

    const [gridH, gridW] = feats.shape.slice(0, 2)
    const gridY = tf.tile(tf.reshape(range(gridH), [-1, 1, 1, 1]), [1, gridW, 1, 1])
    const gridX = tf.tile(tf.reshape(range(gridW), [1, -1, 1, 1]), [gridH, 1, 1, 1])
    // Per-cell (x, y) offsets, broadcast over anchors.
    const grid = tf.cast(tf.concat([gridX, gridY], 3), feats.dtype)

    const newFeats = tf.reshape(feats, [gridH, gridW, numAnchors, numClasses + 5])
    const [xy, wh, con, probs] = tf.split(newFeats, [2, 2, 1, numClasses], 3)

    // Divide by (width, height). Build fresh arrays instead of reverse()-ing in
    // place so the caller's inputShape is left untouched.
    const boxXY = tf.div(tf.add(tf.sigmoid(xy), grid), [gridW, gridH])
    const boxWH = tf.div(
      tf.mul(tf.exp(wh), anchorsTensor),
      inputShape.slice().reverse()
    )

    const boxConfidence = tf.sigmoid(con)
    const boxClassProbs = tf.sigmoid(probs)

    return [boxXY, boxWH, boxConfidence, boxClassProbs]
  })
}

/**
 * Converts relative (x, y) centres and (w, h) sizes into
 * [ymin, xmin, ymax, xmax] boxes in image pixels.
 *
 * The image is stretched (not letterboxed) onto the input canvas, so relative
 * coordinates map directly onto imageShape.
 *
 * @param inputShape [height, width] of the network input
 * @param imageShape [height, width] the boxes should be expressed in
 */
function yoloCorrectBoxes (
  boxXY: tf.Tensor,
  boxWH: tf.Tensor,
  inputShape: number[],
  imageShape: number[]
): tf.Tensor {
  return tf.tidy(() => {
    // Swap last-axis order from (x, y)/(w, h) to (y, x)/(h, w).
    const boxYX = tf.concat(tf.split(boxXY, [1, 1], 3).reverse(), 3)
    const boxHW = tf.concat(tf.split(boxWH, [1, 1], 3).reverse(), 3)

    const scale = tf.div(inputShape, imageShape)
    const centres = tf.div(tf.mul(boxYX, inputShape), scale)
    const halfSizes = tf.div(tf.div(tf.mul(boxHW, inputShape), scale), 2)

    const boxMins = tf.sub(centres, halfSizes) // (ymin, xmin)
    const boxMaxes = tf.add(centres, halfSizes) // (ymax, xmax)

    return tf.concat([boxMins, boxMaxes], 3)
  })
}

/**
 * Decodes one output layer into flat [N, 4] boxes (image pixels) and
 * [N, numClasses] scores (objectness * class probability).
 * All intermediate tensors are released by the enclosing tidy.
 */
function yoloBoxesAndScores (
  feats: tf.Tensor,
  anchors: tf.Tensor2D,
  numClasses: number,
  inputShape: number[],
  imageShape: number[]
): tf.Tensor[] {
  return tf.tidy(() => {
    const [boxXY, boxWH, boxConfidence, boxClassProbs] = yoloHead(
      feats,
      anchors,
      numClasses,
      inputShape
    )
    const boxes = yoloCorrectBoxes(boxXY, boxWH, inputShape, imageShape).reshape([-1, 4])
    const boxScores = tf.mul(boxConfidence, boxClassProbs).reshape([-1, numClasses])
    return [boxes, boxScores]
  })
}

/**
 * Evaluates raw YOLO outputs and returns the boxes that survive per-class
 * non-max suppression, ordered by class index and then by NMS order.
 *
 * @param output         output layers with the batch dimension removed; output[0]
 *                       is assumed to be the stride-32 (coarsest) layer
 * @param anchors        [numAnchors, 2] anchor sizes for all layers
 * @param numberClasses  number of classes the model predicts
 * @param imageShape     [height, width] the returned boxes are expressed in
 * @param maxBoxs        maximum boxes kept per class
 * @param scoreThreshold minimum score for a box to be considered
 * @param iouThreshold   IoU above which overlapping boxes are suppressed
 *
 * Does not dispose `output` or `anchors`; the caller owns them.
 */
async function yoloEval (
  output: tf.Tensor[],
  anchors: tf.Tensor2D,
  numberClasses: number,
  imageShape: number[],
  maxBoxs: number = 20,
  scoreThreshold: number = 0.3,
  iouThreshold: number = 0.45
): Promise<Box[]> {
  const numLayers = output.length
  // Anchor indices used by each layer, largest anchors first. These are the
  // keras-yolo3 yolo_eval defaults (including [1, 2, 3] for the tiny model).
  const anchorMask =
    numLayers === 3 ? [[6, 7, 8], [3, 4, 5], [0, 1, 2]] : [[3, 4, 5], [1, 2, 3]]

  const inputShape = output[0].shape.slice(0, 2).map(num => num * 32)

  // Decode every layer and stack into [N, 4] boxes and [N, numberClasses] scores.
  const [boxes, boxScores] = tf.tidy(() => {
    const boxesPerLayer: tf.Tensor[] = []
    const scoresPerLayer: tf.Tensor[] = []
    output.forEach((feats, layer) => {
      const layerAnchors = anchors.gather(tf.tensor1d(anchorMask[layer], 'int32'))
      const [layerBoxes, layerScores] = yoloBoxesAndScores(
        feats,
        layerAnchors,
        numberClasses,
        inputShape,
        imageShape
      )
      boxesPerLayer.push(layerBoxes)
      scoresPerLayer.push(layerScores)
    })
    return [tf.concat(boxesPerLayer, 0), tf.concat(scoresPerLayer, 0)]
  })

  const detections: Box[] = []
  // One [N, 1] score column per class; NMS runs independently for each class.
  const scoresPerClass = tf.split(boxScores, numberClasses, 1)

  for (let classIndex = 0; classIndex < numberClasses; classIndex++) {
    const classScores = scoresPerClass[classIndex].reshape([-1]) as tf.Tensor1D
    const keep = await tf.image.nonMaxSuppressionAsync(
      boxes as tf.Tensor2D,
      classScores,
      maxBoxs,
      iouThreshold,
      scoreThreshold
    )

    if (keep.size > 0) {
      const keptBoxes = tf.gather(boxes, keep)
      const keptScores = tf.gather(classScores, keep)
      // One device->host read per class instead of one per box.
      const boxData = keptBoxes.dataSync()
      const scoreData = keptScores.dataSync()
      tf.dispose([keptBoxes, keptScores])

      for (let i = 0; i < keep.size; i++) {
        const offset = i * 4
        const top = boxData[offset]
        const left = boxData[offset + 1]
        const bottom = boxData[offset + 2]
        const right = boxData[offset + 3]
        detections.push({
          top,
          left,
          bottom,
          right,
          width: right - left,
          height: bottom - top,
          scores: scoreData[i],
          classes: COCO_CLASSESS[classIndex]
        })
      }
    }

    tf.dispose([classScores, keep])
  }

  tf.dispose([boxes, boxScores, ...scoresPerClass])
  return detections
}

// Shared off-screen canvas that every input image is drawn onto.
const $canvas = document.createElement('canvas')
$canvas.width = INPUT_SIZE
$canvas.height = INPUT_SIZE
const ctx = $canvas.getContext('2d') as CanvasRenderingContext2D

/**
 * Loads yolov3-tiny and resolves to a detection function.
 * @param modelUrl URL of the model.json to load (cached in IndexedDB per URL)
 * @param anchors  flat list of anchor (width, height) pairs
 */
export async function yolov3Tiny ({
  modelUrl = 'https://zqingr.github.io/tfjs-yolov3-demo/model/yolov3-tiny/model.json',
  anchors = ANCHORS_TINY
}: { modelUrl?: string; anchors?: number[] } = {}) {
  const yoloTinyData = await yolo({ modelUrl, anchors })
  return yoloTinyData
}

/**
 * Loads full yolov3 and resolves to a detection function.
 * @param modelUrl URL of the model.json to load (cached in IndexedDB per URL)
 * @param anchors  flat list of anchor (width, height) pairs
 */
export async function yolov3 ({
  modelUrl = 'https://zqingr.github.io/tfjs-yolov3-demo/model/yolov3/model.json',
  anchors = ANCHORS
}: { modelUrl?: string; anchors?: number[] } = {}) {
  const yoloData = await yolo({ modelUrl, anchors })
  return yoloData
}

/**
 * Loads a LayersModel, preferring a copy cached in IndexedDB.
 *
 * The cache key is derived from modelUrl, so different models (yolov3 vs.
 * yolov3-tiny, or custom URLs) never shadow one another. Caching is
 * best-effort: if IndexedDB cannot be read or written, the model is still
 * loaded from the network. Entries written under the old shared key
 * 'indexeddb://yolov3-1' are no longer read.
 */
async function loadModelWithCache (modelUrl: string): Promise<tf.LayersModel> {
  // Keep the name free of '://' and '/' so tf.io model-management helpers can parse it.
  const cacheKey = 'indexeddb://yolov3-' + modelUrl.replace(/[^a-zA-Z0-9._-]/g, '_')

  try {
    const cachedModels = await tf.io.listModels()
    if (cachedModels[cacheKey]) {
      return await tf.loadLayersModel(cacheKey)
    }
  } catch (err) {
    console.warn(`Could not load cached model ${cacheKey}; downloading instead.`, err)
  }

  const model = await tf.loadLayersModel(modelUrl)
  try {
    await model.save(cacheKey)
  } catch (err) {
    console.warn(`Could not cache model at ${cacheKey}.`, err)
  }
  return model
}

/**
 * Loads the model and returns a function that detects objects in an image.
 * The returned function resolves to boxes in the image's displayed
 * (clientWidth x clientHeight) pixel space.
 */
async function yolo ({
  modelUrl,
  anchors
}: {
  modelUrl: string
  anchors: number[]
}): Promise<($img: HTMLImageElement) => Promise<Box[]>> {
  const model = await loadModelWithCache(modelUrl)
  // Anchors are fixed for this detector, so build the tensor once.
  const anchorsTensor = tf.tidy(() => tf.tensor1d(anchors).reshape([-1, 2]) as tf.Tensor2D)

  return async ($img: HTMLImageElement) => {
    // Stretch the image onto the square input canvas (no letterboxing).
    ctx.drawImage($img, 0, 0, INPUT_SIZE, INPUT_SIZE)

    // [1, INPUT_SIZE, INPUT_SIZE, 3] float input scaled to [0, 1].
    const sample = tf.tidy(() =>
      tf.stack([tf.div(tf.cast(tf.browser.fromPixels($canvas), 'float32'), 255)])
    )
    const prediction = model.predict(sample)
    const rawOutputs = Array.isArray(prediction) ? prediction : [prediction]
    // Drop the batch dimension from each output layer.
    const outputs = rawOutputs.map(feats => feats.reshape(feats.shape.slice(1)))

    try {
      return await yoloEval(
        outputs,
        anchorsTensor,
        COCO_CLASSESS.length,
        [$img.clientHeight, $img.clientWidth]
      )
    } finally {
      tf.dispose([sample, ...rawOutputs, ...outputs])
    }
  }
}
--------------------------------------------------------------------------------