├── typings.d.ts
├── src
├── images
│ ├── a.jpg
│ ├── 111.jpg
│ ├── 222.jpeg
│ ├── 333.jpg
│ ├── dog.jpg
│ ├── test.png
│ ├── horses.jpg
│ ├── people.jpg
│ ├── street.jpg
│ ├── street1.jpg
│ └── timg.jpeg
├── style.css
├── index.html
├── index.ts
└── yolo
│ ├── config.js
│ └── yolo-eval.ts
├── docs
├── img
│ └── demo1.jpg
└── index.html
├── .gitignore
├── tsconfig.json
├── .eslintrc
├── package.json
├── README.md
├── README_EN.md
└── webpack.config.js
/typings.d.ts:
--------------------------------------------------------------------------------
1 | declare module '@antv/g2'
2 | declare module '@antv/data-set'
--------------------------------------------------------------------------------
/src/images/a.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zqingr/tfjs-yolov3/HEAD/src/images/a.jpg
--------------------------------------------------------------------------------
/docs/img/demo1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zqingr/tfjs-yolov3/HEAD/docs/img/demo1.jpg
--------------------------------------------------------------------------------
/src/images/111.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zqingr/tfjs-yolov3/HEAD/src/images/111.jpg
--------------------------------------------------------------------------------
/src/images/222.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zqingr/tfjs-yolov3/HEAD/src/images/222.jpeg
--------------------------------------------------------------------------------
/src/images/333.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zqingr/tfjs-yolov3/HEAD/src/images/333.jpg
--------------------------------------------------------------------------------
/src/images/dog.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zqingr/tfjs-yolov3/HEAD/src/images/dog.jpg
--------------------------------------------------------------------------------
/src/images/test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zqingr/tfjs-yolov3/HEAD/src/images/test.png
--------------------------------------------------------------------------------
/src/images/horses.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zqingr/tfjs-yolov3/HEAD/src/images/horses.jpg
--------------------------------------------------------------------------------
/src/images/people.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zqingr/tfjs-yolov3/HEAD/src/images/people.jpg
--------------------------------------------------------------------------------
/src/images/street.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zqingr/tfjs-yolov3/HEAD/src/images/street.jpg
--------------------------------------------------------------------------------
/src/images/street1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zqingr/tfjs-yolov3/HEAD/src/images/street1.jpg
--------------------------------------------------------------------------------
/src/images/timg.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zqingr/tfjs-yolov3/HEAD/src/images/timg.jpeg
--------------------------------------------------------------------------------
/src/style.css:
--------------------------------------------------------------------------------
1 | #img-box {
2 | position: relative;
3 | }
4 | #img-box .rect {
5 | position: absolute;
6 | border: 2px solid #f00;
7 | }
8 | #img-box .rect .className {
9 | position: absolute;
10 | top: 0;
11 | background: #f00;
12 | color: #fff;
13 |
14 | }
15 | #test-canvas {
16 | }
--------------------------------------------------------------------------------
/docs/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 | Document
8 |
9 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # See http://help.github.com/ignore-files/ for more about ignoring files.
2 |
3 | # compiled output
4 | /tmp
5 | /dist
6 |
7 | # dependencies
8 | /node_modules
9 | /bower_components
10 | /src/model
11 |
12 | # IDEs and editors
13 | /.idea
14 | .vscode/*
15 | .project
16 | .classpath
17 | *.launch
18 | .settings/
19 |
20 | # misc
21 | /.sass-cache
22 | /connect.lock
23 | /coverage/*
24 | /libpeerconnection.log
25 | npm-debug.log
26 | testem.log
27 | /typings
28 |
29 | # e2e
30 | /e2e/*.js
31 | /e2e/*.map
32 |
33 | #System Files
34 | .DS_Store
35 | Thumbs.db
36 | node.test.js
37 |
38 | yarn.lock
39 |
40 | webpack.config.prod.js
--------------------------------------------------------------------------------
/src/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 | Document
8 |
9 |
10 |
11 |

12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
--------------------------------------------------------------------------------
/tsconfig.json:
--------------------------------------------------------------------------------
1 | {
2 | "compilerOptions": {
3 | "outDir": "./dist/",
4 | "sourceMap": true,
5 | "noImplicitAny": true,
6 | "module": "es6",
7 | "moduleResolution": "node",
8 | "target": "es5",
9 | "allowJs": true,
10 | "allowSyntheticDefaultImports": true,
11 | "experimentalDecorators": true,
12 | "isolatedModules": false,
13 | "noImplicitThis": true,
14 | "strictNullChecks": true,
15 | "removeComments": true,
16 | "suppressImplicitAnyIndexErrors": true,
17 | "lib" : ["es2015", "es2017", "dom"],
18 | "baseUrl": "types",
19 | "typeRoots": [
20 | "node_modules/@types"
21 | ],
22 | "paths": {
23 | "jquery-ui": [
24 | "node_modules/@types/jqueryui/index"
25 | ]
26 | }
27 | },
28 | "exclude": [
29 | "node_modules"
30 | ]
31 | }
--------------------------------------------------------------------------------
/.eslintrc:
--------------------------------------------------------------------------------
1 | {
2 | "extends": [
3 | "standard"
4 | ],
5 | "plugins": [
6 | "html",
7 | "import",
8 | "babel"
9 | ],
10 | "env": {
11 | "browser": true,
12 | "node": true,
13 | "es6": true,
14 | "jquery": true,
15 | "commonjs": true,
16 | "phantomjs": true,
17 | "mocha": true
18 | },
19 | "rules": {
20 | "no-confusing-arrow": "off",
21 | "babel/new-cap": 1,
22 | "no-undef": "off",
23 | "import/no-unresolved": "off",
24 | "import/extensions": "off",
25 | "import/no-webpack-loader-syntax": "off",
26 | "keyword-spacing": "off",
27 | "no-unused-vars": [ 0, { "varsIgnorePattern": "Component$", "args": "none" } ],
28 | "no-new": 0,
29 | "indent": [
30 | "error",
31 | 2,
32 | { "ignoredNodes": [ "Program" ] }
33 | ]
34 | },
35 | "parser": "typescript-eslint-parser",
36 | "parserOptions": {
37 | "sourceType": "module",
38 | "ecmaFeatures": {
39 | "jsx": true
40 | }
41 | }
42 | }
43 |
44 |
--------------------------------------------------------------------------------
/src/index.ts:
--------------------------------------------------------------------------------
1 | import { yolov3Tiny } from './yolo/yolo-eval'
2 | // import { yolov3Tiny } from '../dist/index.bundle'
3 |
4 | import './style.css'
5 | // import * as tf from '@tensorflow/tfjs'
6 |
7 | // async function load () {
8 | // const model = await tf.loadModel('/model/yolov3-tiny/model.json')
9 | // // const model = await tf.loadModel('https://zqingr.github.io/tfjs-yolov3-demo/model/yolov3-tiny/model.json')
10 |
11 | // model.summary()
12 | // }
13 | // load()
14 |
const $img = document.getElementById('img') as HTMLImageElement

/**
 * Resolve once `$image` has finished loading, so detection reads real pixels
 * and a non-zero displayed size. Rejects if the image failed to load.
 */
function imageReady ($image: HTMLImageElement): Promise<void> {
  const failure = () => new Error(`Image failed to load: ${$image.src}`)
  if ($image.complete) {
    // `complete` is also true for broken images; naturalWidth distinguishes them.
    return $image.naturalWidth > 0 ? Promise.resolve() : Promise.reject(failure())
  }
  return new Promise<void>((resolve, reject) => {
    $image.addEventListener('load', () => resolve(), { once: true })
    $image.addEventListener('error', () => reject(failure()), { once: true })
  })
}

/**
 * Demo entry point: load YOLOv3-tiny, detect objects in #img and add one
 * `.rect` div per detected box to #img-box.
 */
async function start () {
  // Model download and image loading are independent, so wait for both together.
  const [yolo] = await Promise.all([yolov3Tiny(), imageReady($img)])
  const boxes = await yolo($img)

  const $imgbox = document.getElementById('img-box') as HTMLElement

  boxes.forEach(box => {
    const $div = document.createElement('div')
    $div.className = 'rect'
    $div.style.top = box.top + 'px'
    $div.style.left = box.left + 'px'
    $div.style.width = box.width + 'px'
    $div.style.height = box.height + 'px'
    // textContent: the label is plain text and is never parsed as HTML.
    $div.textContent = `${box.classes} ${box.scores}`

    $imgbox.appendChild($div)
  })

  console.log(boxes)
}

start().catch(err => console.error('Object detection failed:', err))
39 |
40 | // import yolov3 from './yolo/yolo-eval'
41 |
42 | // export default yolov3
43 |
--------------------------------------------------------------------------------
/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "random-data",
3 | "version": "1.0.0",
4 | "description": "",
5 | "main": "webpack.config.js",
6 | "scripts": {
7 | "test": "echo \"Error: no test specified\" && exit 1",
8 | "start": "cross-env NODE_ENV=dev PLATFORM=web webpack-dev-server --progress --host 0.0.0.0 --disableHostCheck=true",
9 | "build": "cross-env NODE_ENV=pro PLATFORM=web webpack --progress --config ./webpack.config.prod.js --optimize-minimize"
10 | },
11 | "author": "",
12 | "license": "ISC",
13 | "devDependencies": {
14 | "@types/d3": "^5.7.1",
15 | "cross-env": "^5.2.0",
16 | "eslint": "^5.15.1",
17 | "eslint-config-standard": "^12.0.0",
18 | "eslint-loader": "^2.1.2",
19 | "eslint-plugin-babel": "^5.3.0",
20 | "eslint-plugin-html": "^5.0.3",
21 | "eslint-plugin-import": "^2.16.0",
22 | "eslint-plugin-node": "^8.0.1",
23 | "eslint-plugin-promise": "^4.0.1",
24 | "eslint-plugin-standard": "^4.0.0",
25 | "file-loader": "^3.0.1",
26 | "ts-loader": "^5.3.3",
27 | "ts-node": "^8.0.3",
28 | "typescript": "^3.3.3333",
29 | "typescript-eslint-parser": "^22.0.0",
30 | "url-loader": "^1.1.2",
31 | "webpack": "^4.29.6",
32 | "webpack-cli": "^3.2.3",
33 | "webpack-dev-server": "^3.2.1"
34 | },
35 | "dependencies": {
36 | "@antv/data-set": "^0.10.2",
37 | "@antv/g2": "^3.4.10",
38 | "@tensorflow/tfjs": "1.0.0",
39 | "@types/echarts": "^4.1.5",
40 | "css-loader": "^2.1.1",
41 | "echarts": "^4.1.0",
42 | "node-fetch": "^2.3.0",
43 | "style-loader": "^0.23.1"
44 | }
45 | }
46 |
--------------------------------------------------------------------------------
/src/yolo/config.js:
--------------------------------------------------------------------------------
// Anchor sizes, flattened as consecutive (w, h) pairs; the detector reshapes
// them to [-1, 2]. Nine pairs for YOLOv3 (three output layers).
export const ANCHORS = [
  10, 13, 16, 30, 33, 23,
  30, 61, 62, 45, 59, 119,
  116, 90, 156, 198, 373, 326
]

// Six (w, h) pairs for YOLOv3-tiny (two output layers).
export const ANCHORS_TINY = [
  10, 14, 23, 27, 37, 58,
  81, 82, 135, 169, 344, 319
]

// The 80 class labels, indexed by the model's class channel.
export const COCO_CLASSESS = [
  'person', 'bicycle', 'car', 'motorbike', 'aeroplane',
  'bus', 'train', 'truck', 'boat', 'traffic light',
  'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird',
  'cat', 'dog', 'horse', 'sheep', 'cow',
  'elephant', 'bear', 'zebra', 'giraffe', 'backpack',
  'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
  'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
  'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle',
  'wine glass', 'cup', 'fork', 'knife', 'spoon',
  'bowl', 'banana', 'apple', 'sandwich', 'orange',
  'broccoli', 'carrot', 'hot dog', 'pizza', 'donut',
  'cake', 'chair', 'sofa', 'pottedplant', 'bed',
  'diningtable', 'toilet', 'tvmonitor', 'laptop', 'mouse',
  'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
  'toaster', 'sink', 'refrigerator', 'book', 'clock',
  'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 中文 | [English](./README_EN.md)
2 |
3 | # tfjs-yolov3
4 |
5 | ## 介绍
6 |
7 | 完全用js来实现图片中的目标检测
8 | 基于yolov3算法和Tensorflow.js库
9 | 用tensorflow.js实现yolov3和yolov3-tiny
10 |
11 | 需要注意的是: 必须使用Tensorflow.js@v1.0.0及以上版本
12 |
13 | ## 特点
14 | - 可以识别**任意尺寸**的图片
15 | - 同时支持yolov3和yolov3-tiny
16 |
17 | ## 快速开始
18 |
19 | ### 安装
20 |
21 | ```
22 | npm install tfjs-yolov3
23 | ```
24 |
25 | ### 用法示例
26 |
27 | ```javascript
28 | import { yolov3, yolov3Tiny } from 'tfjs-yolov3'
29 |
30 | async function start () {
31 | const yolo = await yolov3Tiny() // pre-load model (35M)
32 | // or
33 | // const yolo = await yolov3() // pre-load model (245M)
34 |
35 | const $img = document.getElementById('img')
36 | const boxes = await yolo($img)
37 | draw(boxes) // Some draw function
38 | }
39 | start()
40 | ```
41 |
42 | ## API 文档
43 |
44 | yolov3和yolov3Tiny函数接受一个options对象,并返回一个Promise,其结果是一个识别函数
45 |
46 | ```typescript
47 | export declare function yolov3 (
48 | { modelUrl, anchors }? :
49 | { modelUrl?: string, anchors?: number[] }
50 | ): Promise<yolo>
51 |
52 | export declare function yolov3Tiny (
53 | { modelUrl, anchors }? :
54 | { modelUrl?: string, anchors?: number[] }
55 | ): Promise<yolo>
56 | ```
57 |
58 | | 参数 | 说明 |
59 | | ------------ | ------------ |
60 | | modelUrl | 可选,预训练model的url,可把model下载到本地,加快预训练model的加载速度,[点我下载](https://github.com/zqingr/tfjs-yolov3/releases/tag/v1.0) |
61 | anchors | 可选,可自定义anchors,格式参考[config](https://github.com/zqingr/tfjs-yolov3/blob/master/src/yolo/config.js) |
62 |
63 | 这两个函数调用后会加载预训练model,并返回一个函数,可用这个函数去识别图片,并返回识别后的box列表,参数如下:
64 |
65 | ```typescript
66 | type yolo = ($img: HTMLImageElement) => Promise<Box[]>
67 |
68 | interface Box {
69 | top: number
70 | left: number
71 | bottom: number
72 | right: number
73 | width: number
74 | height: number
75 | scores: number
76 | classes: string
77 | }
78 | ```
79 |
80 |
81 |
82 | ## DEMO
83 |
84 | [点击查看在线DEMO](https://zqingr.github.io/tfjs-yolov3-demo/)
85 |
86 | 
87 |
88 |
89 |
--------------------------------------------------------------------------------
/README_EN.md:
--------------------------------------------------------------------------------
1 | # tfjs-yolov3
2 |
3 | ## Introduction
4 |
5 | A Tensorflow.js implementation of YOLOv3 and YOLOv3-tiny
6 |
7 | Note: Requires Tensorflow.js v1.0.0 or later.
8 |
9 | ## Features
10 | - Recognizes images of **any size**
11 | - Supports both **yolov3** and **yolov3-tiny**
12 |
13 | ## Quick Start
14 |
15 | ### Install
16 |
17 | ```
18 | npm install tfjs-yolov3
19 | ```
20 |
21 | ### Usage Example
22 |
23 | ```javascript
24 | import { yolov3, yolov3Tiny } from 'tfjs-yolov3'
25 |
26 | async function start () {
27 | const yolo = await yolov3Tiny() // pre-load model (35M)
28 | // or
29 | // const yolo = await yolov3() // pre-load model (245M)
30 |
31 | const $img = document.getElementById('img')
32 | const boxes = await yolo($img)
33 | draw(boxes) // Some draw function
34 | }
35 | start()
36 | ```
37 |
38 |
39 | ## API DOC
40 |
41 | The yolov3 and yolov3Tiny functions accept an options object and return a Promise that resolves to a detector function
42 |
43 | ```typescript
44 | export declare function yolov3 (
45 | { modelUrl, anchors }? :
46 | { modelUrl?: string, anchors?: number[] }
47 | ): Promise<yolo>
48 |
49 | export declare function yolov3Tiny (
50 | { modelUrl, anchors }? :
51 | { modelUrl?: string, anchors?: number[] }
52 | ): Promise<yolo>
53 | ```
54 |
55 | | Parameters | Description |
56 | | ------------ | ------------ |
57 | modelUrl | Optional. URL of the pre-trained model. You can [download](https://github.com/zqingr/tfjs-yolov3/releases/tag/v1.0) the model and serve it locally to speed up loading. |
58 | anchors | Optional. Custom anchors; see [config](https://github.com/zqingr/tfjs-yolov3/blob/master/src/yolo/config.js) for the format. |
59 |
60 | Calling either function loads the pre-trained model and resolves to a detector function. Pass an image to the detector to get the list of detected boxes. Its signature is as follows:
61 |
62 | ```typescript
63 | type yolo = ($img: HTMLImageElement) => Promise<Box[]>
64 |
65 | interface Box {
66 | top: number
67 | left: number
68 | bottom: number
69 | right: number
70 | width: number
71 | height: number
72 | scores: number
73 | classes: string
74 | }
75 | ```
76 |
77 |
78 |
79 |
80 | ## DEMO
81 |
82 | [Check out the Live Demo](https://zqingr.github.io/tfjs-yolov3-demo/)
83 |
84 | 
85 |
86 |
87 |
--------------------------------------------------------------------------------
/webpack.config.js:
--------------------------------------------------------------------------------
1 | const path = require('path');
2 |
3 | module.exports = function (env) {
4 | return {
5 | mode: 'development',
6 | context: path.join(process.cwd(), 'src'),
7 | entry: {
8 | 'index': `./index.ts`
9 | },
10 | output: {
11 | path: path.resolve(__dirname, 'dist'),
12 | filename: '[name].bundle.js',
13 | libraryTarget: "umd"
14 | },
15 | resolve: {
16 | // Add `.ts` and `.tsx` as a resolvable extension.
17 | extensions: [".ts", ".tsx", ".js"]
18 | },
19 | module: {
20 | rules: [
21 | {
22 | test: /\.ts$/,
23 | enforce: "pre",
24 | exclude: /node_modules/,
25 | loader: "eslint-loader",
26 | options: {
27 | fix: true
28 | }
29 | },
30 | {
31 | type: 'javascript/auto',
32 | test: /\.(json)$/,
33 | exclude: /node_modules/,
34 | loader: [
35 | `file-loader?publicPath=/&name=[name].[ext]`
36 | ]
37 | },
38 | {
39 | test: /\.ts$/,
40 | exclude: /node_modules/,
41 | loader: 'ts-loader',
42 | options: {
43 | appendTsSuffixTo: [/\.vue$/, /\.ts$/],
44 | transpileOnly: true
45 | }
46 | },
47 | {
48 | test: /\.html$/,
49 | loader: 'raw-loader',
50 | exclude: path.join(__dirname, './src/index.html')
51 | },
52 | {
53 | test: /\.(jpg|jpeg|gif|png)$/,
54 | loader: [
55 | `url-loader?limit=4112&publicPath=/&name=[name].[ext]`
56 | ]
57 | },
58 | {
59 | test: /\.css$/,
60 | exclude: /node_modules/,
61 | use: [
62 | 'style-loader',
63 | 'css-loader'
64 | ]
65 | }]
66 | },
67 | externals: {
68 | "tf": "@tensorflow/tfjs"
69 | },
70 | devServer: {
71 | port: 8000,
72 | host: 'localhost',
73 | historyApiFallback: true,
74 | watchOptions: {
75 | aggregateTimeout: 300,
76 | poll: 1000
77 | },
78 | contentBase: ['src'],
79 | open: false,
80 | stats: {
81 | assets: true,
82 | children: false,
83 | chunks: false,
84 | hash: false,
85 | modules: false,
86 | publicPath: false,
87 | timings: true,
88 | version: false,
89 | warnings: true,
90 | colors: {
91 | green: '\u001b[32m',
92 | }
93 | }
94 | }
95 | }
96 | }
--------------------------------------------------------------------------------
/src/yolo/yolo-eval.ts:
--------------------------------------------------------------------------------
1 | import * as tf from '@tensorflow/tfjs'
2 | import { COCO_CLASSESS, ANCHORS, ANCHORS_TINY } from './config'
3 |
/**
 * Build the index list [0, 1, ..., num - 1].
 * Like `Array(num)`, throws a RangeError for a negative or non-integer `num`.
 */
const grenerateArr = (num: number): number[] => {
  const indices: number[] = new Array(num)
  for (let i = 0; i < num; i++) {
    indices[i] = i
  }
  return indices
}
8 |
/**
 * Decode one raw YOLO output layer.
 *
 * `feats` is reshaped to [gridH, gridW, numAnchors, numClasses + 5]. The last
 * axis is split into (x, y), (w, h), objectness and `numClasses` class logits.
 *
 * @param feats raw output of one detection layer; its first two dims are the grid height and width
 * @param anchors [numAnchors, 2] anchor sizes for this layer, in network-input pixels
 * @param numClasses number of class logits per anchor
 * @param inputShape network input size as [height, width]; read only, never mutated
 * @returns [boxXY, boxWH, boxConfidence, boxClassProbs]. boxXY / boxWH are
 *   normalised to the input size and ordered (x, y) / (w, h).
 */
function yoloHead (
  feats: tf.Tensor,
  anchors: tf.Tensor2D,
  numClasses: number,
  inputShape: number[]
) {
  return tf.tidy(() => {
    const numAnchors = anchors.shape[0]

    // Broadcastable against [gridH, gridW, numAnchors, 2].
    const anchorsTensor = anchors.reshape([1, numAnchors, 2])

    const [gridH, gridW] = feats.shape.slice(0, 2)
    // Cell index grids, each shaped [gridH, gridW, 1, 1].
    const gridY = tf.tile(tf.reshape(grenerateArr(gridH), [-1, 1, 1, 1]), [1, gridW, 1, 1])
    const gridX = tf.tile(tf.reshape(grenerateArr(gridW), [1, -1, 1, 1]), [gridH, 1, 1, 1])
    // (x, y) cell offsets, in the same order as the xy predictions.
    const grid = tf.cast(tf.concat([gridX, gridY], 3), feats.dtype)

    const newfeats = tf.reshape(feats, [gridH, gridW, numAnchors, numClasses + 5])

    // Split sizes must sum to numClasses + 5; hard-coding 80 classes made this
    // throw for any model whose class count differs from COCO's.
    const [xy, wh, con, probs] = tf.split(newfeats, [2, 2, 1, numClasses], 3)

    // Divisors are fresh (width, height) arrays. Calling .reverse() on
    // inputShape would mutate the caller's array, which is shared across layers.
    const boxXY = tf.div(tf.add(tf.sigmoid(xy), grid), [gridW, gridH])
    const boxWH = tf.div(tf.mul(tf.exp(wh), anchorsTensor), [inputShape[1], inputShape[0]])

    const boxConfidence = tf.sigmoid(con)
    const boxClassProbs = tf.sigmoid(probs)

    return [boxXY, boxWH, boxConfidence, boxClassProbs]
  })
}
56 |
/**
 * Convert normalised box centres/sizes into corner coordinates in image pixels.
 *
 * @param boxXY box centres, last axis (x, y), normalised to the network input
 * @param boxWH box sizes, last axis (w, h), normalised to the network input
 * @param inputShape network input size as [height, width]
 * @param imageShape target image size as [height, width]
 * @returns boxes whose last axis is [minY, minX, maxY, maxX] in imageShape units
 */
function yoloCorrectBoxes (
  boxXY: tf.Tensor,
  boxWH: tf.Tensor,
  inputShape: number[],
  imageShape: number[]
) {
  return tf.tidy(() => {
    // Swap to (y, x) / (h, w) so the last axis lines up with [height, width] shapes.
    const [x, y] = tf.split(boxXY, [1, 1], 3)
    const [w, h] = tf.split(boxWH, [1, 1], 3)

    // Normalised -> input pixels -> image pixels.
    const scale = tf.div(inputShape, imageShape)
    const centerYX = tf.div(tf.mul(tf.concat([y, x], 3), inputShape), scale)
    const sizeHW = tf.div(tf.mul(tf.concat([h, w], 3), inputShape), scale)

    const halfHW = tf.div(sizeHW, 2)
    const boxMins = tf.sub(centerYX, halfHW)
    const boxMaxes = tf.add(centerYX, halfHW)

    return tf.concat([boxMins, boxMaxes], 3)
  })
}
85 |
/**
 * Process one Conv layer output into flat box and score tensors.
 *
 * All intermediates are created inside `tf.tidy`. This covers the four tensors
 * from `yoloHead` and the un-reshaped box/score tensors, which previously leaked
 * on every inference. Only the two returned tensors survive.
 *
 * @returns [boxes, boxScores] with shapes [N, 4] ([minY, minX, maxY, maxX])
 *   and [N, numClasses], where N = gridH * gridW * numAnchors
 */
function yoloBoxesAndScores (
  feats: tf.Tensor,
  anchors: tf.Tensor2D,
  numClasses: number,
  inputShape: number[],
  imageShape: number[]
) {
  return tf.tidy(() => {
    const [boxXY, boxWH, boxConfidence, boxClassProbs] = yoloHead(
      feats,
      anchors,
      numClasses,
      inputShape
    )
    const boxes = yoloCorrectBoxes(boxXY, boxWH, inputShape, imageShape).reshape([-1, 4])

    // Class score = objectness * class probability.
    const boxScores = tf.mul(boxConfidence, boxClassProbs).reshape([-1, numClasses])

    return [boxes, boxScores]
  })
}
110 |
/** One detection, in the pixel space described by `imageShape`. */
interface YoloBox {
  top: number
  left: number
  bottom: number
  right: number
  width: number
  height: number
  scores: number
  classes: string
}

/**
 * Evaluate YOLO model output and return filtered boxes.
 *
 * Decodes every output layer, runs per-class non-max suppression and returns
 * the surviving boxes, grouped by ascending class index. Every tensor created
 * here is disposed; the caller keeps ownership of `output` and `anchors`.
 *
 * @param output per-layer raw outputs (batch dimension already removed)
 * @param anchors [numAnchors, 2] anchor pairs; rows are picked per layer by the anchor mask
 * @param numberClasses number of class scores per anchor
 * @param imageShape [height, width] that box coordinates are scaled to
 * @param maxBoxs maximum boxes kept per class by NMS
 * @param scoreThreshold boxes whose class score is below this are dropped
 * @param iouThreshold same-class boxes overlapping more than this IoU are suppressed
 * @returns detections; empty when `output` is empty or nothing passes the thresholds
 */
async function yoloEval (
  output: tf.Tensor[],
  anchors: tf.Tensor2D,
  numberClasses: number,
  imageShape: number[],
  maxBoxs: number = 20,
  scoreThreshold: number = 0.3,
  iouThreshold: number = 0.45
): Promise<YoloBox[]> {
  const numLayers = output.length
  if (numLayers === 0) return []

  // 3 layers index into a 9-pair anchor list; otherwise a 6-pair list is assumed.
  const anchorMask =
    numLayers === 3 ? [[6, 7, 8], [3, 4, 5], [0, 1, 2]] : [[3, 4, 5], [1, 2, 3]] // default setting

  // NOTE(review): assumes output[0] is the coarsest layer with a stride of 32 input pixels.
  const inputShape = output[0].shape.slice(0, 2).map(num => num * 32)

  // Decode all layers; tidy frees every per-layer tensor except the two concatenated results.
  const [boxes, boxScores] = tf.tidy(() => {
    const boxesArr: tf.Tensor[] = []
    const boxScoresArr: tf.Tensor[] = []
    for (let index = 0; index < numLayers; index++) {
      const [layerBoxes, layerScores] = yoloBoxesAndScores(
        output[index],
        anchors.gather(tf.tensor1d(anchorMask[index], 'int32')),
        numberClasses,
        inputShape,
        imageShape
      )
      boxesArr.push(layerBoxes)
      boxScoresArr.push(layerScores)
    }
    return [tf.concat(boxesArr, 0), tf.concat(boxScoresArr, 0)]
  }) as tf.Tensor2D[]

  // One [N, 1] score column per class.
  const splitBoxScores = tf.split(boxScores, Array(numberClasses).fill(1), 1)
  try {
    let detections: YoloBox[] = []
    for (let classIndex = 0; classIndex < numberClasses; classIndex++) {
      const classDetections = await nmsForClass(
        boxes,
        splitBoxScores[classIndex],
        classIndex,
        maxBoxs,
        iouThreshold,
        scoreThreshold
      )
      detections = detections.concat(classDetections)
    }
    return detections
  } finally {
    tf.dispose(splitBoxScores)
    boxes.dispose()
    boxScores.dispose()
  }
}

/**
 * Run NMS over all boxes for one class and read the survivors back into JS objects.
 * Disposes every tensor it creates; `boxes` and `classScoreColumn` stay owned by the caller.
 */
async function nmsForClass (
  boxes: tf.Tensor2D,
  classScoreColumn: tf.Tensor,
  classIndex: number,
  maxBoxs: number,
  iouThreshold: number,
  scoreThreshold: number
): Promise<YoloBox[]> {
  const scores = classScoreColumn.reshape([-1]) as tf.Tensor1D
  const scratch: tf.Tensor[] = [scores]
  try {
    const nmsIndex = await tf.image.nonMaxSuppressionAsync(
      boxes,
      scores,
      maxBoxs,
      iouThreshold,
      scoreThreshold
    )
    scratch.push(nmsIndex)
    const count = nmsIndex.size
    if (count === 0) return []

    const keptBoxes = tf.gather(boxes, nmsIndex)
    const keptScores = tf.gather(scores, nmsIndex)
    scratch.push(keptBoxes, keptScores)

    // One bulk read per tensor instead of a split + dataSync() per row.
    const boxData = await keptBoxes.data()
    const scoreData = await keptScores.data()

    const result: YoloBox[] = []
    for (let i = 0; i < count; i++) {
      // Each row is [minY, minX, maxY, maxX].
      const top = boxData[i * 4]
      const left = boxData[i * 4 + 1]
      const bottom = boxData[i * 4 + 2]
      const right = boxData[i * 4 + 3]
      result.push({
        top,
        left,
        bottom,
        right,
        width: right - left,
        height: bottom - top,
        scores: scoreData[i],
        classes: COCO_CLASSESS[classIndex]
      })
    }
    return result
  } finally {
    tf.dispose(scratch)
  }
}
197 |
// Shared 416x416 scratch canvas; each input image is drawn (resized) onto it
// before its pixels are read into a tensor.
const $canvas = document.createElement('canvas')
$canvas.width = $canvas.height = 416
const ctx = $canvas.getContext('2d') as CanvasRenderingContext2D
202 |
/**
 * Load the YOLOv3-tiny model and resolve to a detector function.
 *
 * @param options.modelUrl URL of the model.json to load (defaults to the hosted demo model)
 * @param options.anchors flattened anchor list (defaults to ANCHORS_TINY)
 */
export async function yolov3Tiny ({
  modelUrl = 'https://zqingr.github.io/tfjs-yolov3-demo/model/yolov3-tiny/model.json',
  anchors = ANCHORS_TINY
}: { modelUrl?: string; anchors?: number[] } = {}) {
  return yolo({ modelUrl, anchors })
}
/**
 * Load the full YOLOv3 model and resolve to a detector function.
 *
 * @param options.modelUrl URL of the model.json to load (defaults to the hosted demo model)
 * @param options.anchors flattened anchor list (defaults to ANCHORS)
 */
export async function yolov3 ({
  modelUrl = 'https://zqingr.github.io/tfjs-yolov3-demo/model/yolov3/model.json',
  anchors = ANCHORS
}: { modelUrl?: string; anchors?: number[] } = {}) {
  return yolo({ modelUrl, anchors })
}
217 |
/**
 * Load (and cache) a YOLO model, then return a detector for <img> elements.
 *
 * The model is cached in IndexedDB under a key derived from `modelUrl`, so
 * models loaded from different URLs (yolov3 vs yolov3-tiny) no longer share
 * one cache entry. Previously, whichever was loaded first was served for both.
 *
 * @param modelUrl URL of the Layers-model model.json
 * @param anchors flattened anchor list, reshaped to [-1, 2] pairs
 * @returns async detector; boxes are scaled to the image's displayed
 *   (clientHeight x clientWidth) size
 */
async function yolo ({
  modelUrl,
  anchors
}: {
  modelUrl: string
  anchors: number[]
}) {
  // Declared at function scope: the original if/else-scoped `const model`
  // was not visible to the detector closure below.
  const model = await loadModelWithCache(modelUrl)

  return async ($img: HTMLImageElement) => {
    // Stretch the image onto the 416x416 canvas (no letterboxing).
    ctx.drawImage($img, 0, 0, 416, 416)

    // Only the batch-squeezed output layers survive this tidy; the input
    // tensors and raw predictions are freed.
    const output = tf.tidy(() => {
      const sample = tf.stack([
        tf.div(tf.cast(tf.browser.fromPixels($canvas), 'float32'), 255)
      ])
      const prediction = model.predict(sample)
      const layers = Array.isArray(prediction) ? prediction : [prediction]
      return layers.map(feats => feats.reshape(feats.shape.slice(1)))
    })
    const anchorsTensor = tf.tensor1d(anchors).reshape([-1, 2]) as tf.Tensor2D

    try {
      return await yoloEval(
        output,
        anchorsTensor,
        COCO_CLASSESS.length,
        [$img.clientHeight, $img.clientWidth]
      )
    } finally {
      tf.dispose(output)
      anchorsTensor.dispose()
    }
  }
}

/**
 * Load a Layers model, preferring an IndexedDB copy keyed by its URL.
 * Caching is best-effort. If IndexedDB lookup, load or save fails, the model
 * is still loaded from (or returned after loading from) `modelUrl`.
 */
async function loadModelWithCache (modelUrl: string): Promise<tf.LayersModel> {
  const cacheKey = 'indexeddb://yolov3-' + encodeURIComponent(modelUrl)

  try {
    const localModels = await tf.io.listModels()
    if (localModels[cacheKey]) {
      return await tf.loadLayersModel(cacheKey)
    }
  } catch (err) {
    console.warn('tfjs-yolov3: cached model unavailable, loading from network', err)
  }

  const model = await tf.loadLayersModel(modelUrl)
  try {
    await model.save(cacheKey)
  } catch (err) {
    console.warn('tfjs-yolov3: could not cache model in IndexedDB', err)
  }
  return model
}
253 |
--------------------------------------------------------------------------------