parseInt(i, 10)),
42 | y: values,
43 | type: 'scatter',
44 | mode: 'lines+markers',
45 | marker: { color: '#1a9afc' },
46 | },
47 | ]}
48 | layout={{ width: 420, height: 340, title: metric }}
49 | />
50 | );
51 | });
52 |
53 | return (
54 |
55 |
59 | {'<'}Train{' />'}
60 |
61 | {metricElems}
62 | {this.state.modelElement}
63 |
64 | );
65 | }
66 |
67 | shouldComponentUpdate(nextProps, nextState) {
68 | return (
69 | this.props.train != nextProps.train ||
70 | this.props.display != nextProps.display ||
71 | this.state.modelElement != nextState.modelElement ||
72 | this.state.metrics != nextState.metrics
73 | );
74 | }
75 |
76 | componentDidUpdate(prevProps) {
77 | // Resume training if train state was changed to true
78 | if (this.props.train && !prevProps.train && this.trainer != null) {
79 | this.trainer.next();
80 | }
81 | }
82 |
83 | async * _train(model) {
84 | const {
85 | trainData,
86 | samples,
87 | validationData,
88 | epochs,
89 | batchSize,
90 | display,
91 | } = this.props;
92 |
93 |
94 | // TODO: Switch to PropTypes
95 | const onBatchEnd = typeof this.props.onBatchEnd === 'function' ? this.props.onBatchEnd : () => { };
96 |
97 | for (let epoch = 0; epoch < epochs; epoch++) {
98 | const trainGenerator = trainData();
99 | for (let batch = 0; batch * batchSize < samples; batch++) {
100 | // Pause training when train prop is false
101 | const train = this.props.train;
102 | if (!train) {
103 | yield;
104 | }
105 |
106 | const trainBatch = this._getBatch(trainGenerator, batchSize);
107 | const history = await model.fit(trainBatch.xs, trainBatch.ys, { batchSize: trainBatch.xs.shape[0], epochs: 1 });
108 | onBatchEnd(history.history, model);
109 | tf.dispose(trainBatch);
110 |
111 | if (display) {
112 | const fitMetrics = history.history;
113 | this._pushMetrics(fitMetrics);
114 | await tf.nextFrame();
115 | }
116 | }
117 |
118 | if (validationData) {
119 | const valGenerator = validationData();
120 | // Just get all the validation data at once
121 | const valBatch = this._getBatch(valGenerator, Infinity);
122 | const valMetrics = model.evaluate(valBatch.xs, valBatch.ys, { batchSize });
123 | const history = {};
124 |
125 | for (let i = 0; i < valMetrics.length; i++) {
126 | const metric = model.metricsNames[i];
127 | history[`validation-${metric}`] = await valMetrics[i].data();
128 | }
129 |
130 | this._pushMetrics(history);
131 |
132 | tf.dispose(valMetrics);
133 | tf.dispose(valBatch);
134 | }
135 | }
136 |
137 | this.props.onTrainEnd(model);
138 | }
139 |
140 | _getBatch(generator, batchSize = 32) {
141 | const xs = [];
142 | const ys = [];
143 |
144 | for (let i = 0; i < batchSize; i++) {
145 | const sample = generator.next().value;
146 |
147 | if (sample == null) {
148 | break;
149 | }
150 |
151 | xs.push(sample.x);
152 | ys.push(sample.y);
153 | }
154 |
155 | if (xs.length === 0) {
156 | throw new Error('No data returned from data generator for batch, check sample length');
157 | }
158 |
159 | // Either stack if it's a generator of tensors, or convert to tensor if
160 | // it's a generator of JS arrs
161 | const stack = arr => arr[0] instanceof tf.Tensor ?
162 | tf.stack(arr) : tf.tensor(arr);
163 |
164 | return {
165 | xs: stack(xs),
166 | ys: stack(ys),
167 | };
168 | }
169 |
170 | _pushMetrics(metrics) {
171 | const updatedMetrics = { ...this.state.metrics };
172 | Object.keys(metrics).forEach(metric => {
173 | if (updatedMetrics[metric] == null) {
174 | updatedMetrics[metric] = [];
175 | }
176 | updatedMetrics[metric].push(metrics[metric][0]);
177 | });
178 | this.setState({
179 | metrics: updatedMetrics,
180 | });
181 | }
182 | }
183 |
--------------------------------------------------------------------------------
/src/index.js:
--------------------------------------------------------------------------------
1 | import React from 'react';
2 | import * as tf from '@tensorflow/tfjs';
3 |
4 | export { Train } from './Train';
5 |
/**
 * Map a layer marker element (`<Conv2D>`, `<Dense>`, `<Flatten>`,
 * `<MaxPooling2D>`) to the matching `tf.layers.*` layer, using the element's
 * props as the layer config.
 *
 * @param element - a child from `React.Children.toArray`.
 * @returns the constructed tf.js layer.
 * @throws {Error} `Invalid Layer: <name>` for any other child.
 */
function parseLayerElement(element) {
  switch (element.type) {
    case Conv2D:
      return tf.layers.conv2d(element.props);
    case Dense:
      return tf.layers.dense(element.props);
    case Flatten:
      return tf.layers.flatten(element.props);
    case MaxPooling2D:
      return tf.layers.maxPooling2d(element.props);
    default:
      // The old `new Error('Invalid Layer', element)` dropped `element`,
      // because Error's second argument is an options bag. Name the
      // offending child in the message instead.
      throw new Error(`Invalid Layer: ${describeLayerType(element)}`);
  }
}

// Readable name for a child that is not one of the known layer components.
function describeLayerType(element) {
  const { type } = element;
  if (typeof type === 'function') {
    return type.displayName || type.name || 'anonymous component';
  }
  if (typeof type === 'string') {
    return `<${type}>`;
  }
  if (type === undefined) {
    // e.g. a bare text or number child
    return String(element);
  }
  return String(type);
}
20 |
/**
 * Declarative wrapper that builds and compiles a `tf.sequential` model.
 *
 * Renders nothing. Its children (`<Conv2D>`, `<Dense>`, `<Flatten>`,
 * `<MaxPooling2D>`) describe the layers, in order. The model is built and
 * compiled with the `optimizer`, `loss` and `metrics` props. This happens on
 * mount and again whenever the props object is a different object. The result
 * is passed to `onCompile(model)` when that prop is a function.
 *
 * NOTE(review): a parent re-render presumably always supplies a new props
 * object, so a fresh model would be compiled on every re-render, and the
 * previous model is never disposed. If the parent stores the model through
 * `onCompile` + setState, confirm this cannot loop or leak.
 */
export class Model extends React.Component {
  componentDidMount() {
    this._compile();
  }

  componentDidUpdate(prevProps) {
    // Identity comparison of the whole props object.
    if (prevProps != this.props) {
      this._compile();
    }
  }

  render() {
    return null;
  }

  // Build a sequential model from the children, compile it, report it.
  _compile() {
    const { children, optimizer, loss, metrics, onCompile } = this.props;
    const layerChildren = React.Children.toArray(children);
    const model = tf.sequential();

    // Layers are created and added one at a time, in child order;
    // parseLayerElement throws on an unknown child.
    for (const child of layerChildren) {
      model.add(parseLayerElement(child));
    }

    model.compile({ optimizer, loss, metrics });

    if (typeof onCompile === 'function') {
      onCompile(model);
    }
  }
}
63 |
64 | // Layer Types
/** Marker for a `tf.layers.conv2d` layer; its props are the layer config. Renders nothing. */
export function Conv2D() {
  return null;
}

/** Marker for a `tf.layers.dense` layer; its props are the layer config. Renders nothing. */
export function Dense() {
  return null;
}

/** Marker for a `tf.layers.flatten` layer; its props are the layer config. Renders nothing. */
export function Flatten() {
  return null;
}

/** Marker for a `tf.layers.maxPooling2d` layer; its props are the layer config. Renders nothing. */
export function MaxPooling2D() {
  return null;
}
69 |
--------------------------------------------------------------------------------
/webpack.config.js:
--------------------------------------------------------------------------------
1 | const path = require('path');
2 | const webpack = require('webpack');
3 |
4 | module.exports = {
5 | mode: 'development',
6 | context: path.resolve(__dirname, 'src'),
7 | entry: {
8 | index: ['./index.js'],
9 | },
10 | output: {
11 | path: path.resolve(__dirname, 'dist'),
12 | filename: '[name].bundle.js',
13 | library: 'remoteRequire',
14 | libraryTarget: 'umd',
15 | },
16 | module: {
17 | rules: [
18 | {
19 | test: /\.js$/i,
20 | exclude: [/node_modules/],
21 | use: [{
22 | loader: 'babel-loader',
23 | }],
24 | },
25 | ],
26 | },
27 | externals: {
28 | '@tensorflow/tfjs': '@tensorflow/tfjs',
29 | 'react': 'react',
30 | },
31 | };
32 |
--------------------------------------------------------------------------------