├── .gitignore
├── README.md
├── images
│   ├── eraser.png
│   ├── input.png
│   ├── input1.png
│   ├── input2.png
│   ├── input3.png
│   ├── input4.png
│   ├── input5.png
│   ├── input6.png
│   ├── input7.png
│   ├── output.png
│   ├── output1.png
│   ├── output2.png
│   ├── output3.png
│   ├── output4.png
│   ├── output5.png
│   ├── output6.png
│   ├── output7.png
│   ├── pencil.png
│   └── pokeball.png
├── index.html
├── index.js
├── models
│   └── edges2pikachu_AtoB.pict
├── pix2pix.js
├── style.css
└── utils.js

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.DS_Store
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# pix2pix in tensorflow.js

# This repo has moved to [https://github.com/yining1023/pix2pix_tensorflowjs_lite](https://github.com/yining1023/pix2pix_tensorflowjs_lite)

See a live demo here: [https://yining1023.github.io/pix2pix_tensorflowjs/](https://yining1023.github.io/pix2pix_tensorflowjs/)

Screen_Shot_2018_06_17_at_11_06_09_PM

Try it yourself: download/clone the repository and run it locally:

```bash
git clone https://github.com/yining1023/pix2pix_tensorflowjs.git
cd pix2pix_tensorflowjs
python3 -m http.server
```

Credits: This project is based on [affinelayer](https://github.com/affinelayer)'s [pix2pix-tensorflow](https://github.com/affinelayer/pix2pix-tensorflow). I want to thank [christopherhesse](https://github.com/christopherhesse), [nsthorat](https://github.com/nsthorat), and [dsmilkov](https://github.com/dsmilkov) for their help and suggestions in this GitHub [issue](https://github.com/tensorflow/tfjs/issues/79).

## How to train a pix2pix (edges2xxx) model from scratch
1. Prepare the data
2. Train the model
3. Test the model
4. Export the model
5. Port the model to tensorflow.js
6. Create an interactive interface in the browser

### 1. Prepare the data

- 1.1 Scrape images from Google search
- 1.2 Remove the background of the images
- 1.3 Resize all images to 256x256 px
- 1.4 Detect edges of all images
- 1.5 Combine input images and target images
- 1.6 Split all combined images into two folders: `train` and `val`

Before we start, check out [affinelayer](https://github.com/affinelayer)'s [Create your own dataset](https://github.com/affinelayer/pix2pix-tensorflow#creating-your-own-dataset). I followed his instructions for steps 1.3, 1.5 and 1.6.

#### 1.1 Scrape images from Google search
We can create our own target images, but for this edges2pikachu project I downloaded a lot of images from Google, using [google_image_downloader](https://github.com/atif93/google_image_downloader).
After downloading the repo above, run:

```
$ python image_download.py
```

It will download images and save them to the current directory.
#### 1.2 Remove the background of the images
Some images have backgrounds that need to be removed. I'm using [grabcut](https://docs.opencv.org/trunk/d8/d83/tutorial_py_grabcut.html) with OpenCV to remove the background.
Check out the script here: [https://github.com/yining1023/pix2pix-tensorflow/blob/master/tools/grabcut.py](https://github.com/yining1023/pix2pix-tensorflow/blob/master/tools/grabcut.py)
To run the script:

```
$ python grabcut.py
```

It will open an interactive interface; here are some instructions for it: [https://github.com/symao/InteractiveImageSegmentation](https://github.com/symao/InteractiveImageSegmentation)
Here's an example of removing a background with grabcut:

Screen Shot 2018 03 13 at 7 03 28 AM

#### 1.3 Resize all images to 256x256 px
Download the [pix2pix-tensorflow](https://github.com/affinelayer/pix2pix-tensorflow) repo.
Put all the images into the `photos/original` folder and run:

```
$ python tools/process.py --input_dir photos/original --operation resize --output_dir photos/resized
```

We should now see a new folder called `resized` with all the resized images in it.

#### 1.4 Detect edges of all images
The script I use to detect the edges of every image in a folder at once is here: [https://github.com/yining1023/pix2pix-tensorflow/blob/master/tools/edge-detection.py](https://github.com/yining1023/pix2pix-tensorflow/blob/master/tools/edge-detection.py). We need to change the path of the input images directory on [line 31](https://github.com/yining1023/pix2pix-tensorflow/blob/3e0d6c8613b3aa69adffe5484989bbe2c82b2c57/tools/edge-detection.py#L31) and create a new empty folder called `edges` in the same directory.
Run:

```
$ python edge-detection.py
```

We should now see the edge-detected images in the `edges` folder.
Here's an example of edge detection: left (original), right (edge detected).

0_batch2
0_batch2_2

#### 1.5 Combine input images and target images

```
python tools/process.py --input_dir photos/resized --b_dir photos/blank --operation combine --output_dir photos/combined
```

Here is an example of a combined image:
Notice that the size of the combined image is 512x256 px. This size is important for training the model successfully (a small sketch after this section shows the layout concretely).

0_batch2

Read more here: [affinelayer](https://github.com/affinelayer)'s [Create your own dataset](https://github.com/affinelayer/pix2pix-tensorflow#creating-your-own-dataset)

#### 1.6 Split all combined images into two folders: `train` and `val`

```
python tools/split.py --dir photos/combined
```

Read more here: [affinelayer](https://github.com/affinelayer)'s [Create your own dataset](https://github.com/affinelayer/pix2pix-tensorflow#creating-your-own-dataset)

I collected 305 images for training and 78 images for testing.
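Since the 512x256 layout is literally the input and target images placed side by side, here is a small browser-side sketch that splits a combined image back into its two 256x256 halves with a canvas. This is illustration only (the training pipeline does this in Python), and the file path is hypothetical:

```javascript
// Split a 512x256 pix2pix "combined" image into its two 256x256 halves.
// Which half is the input and which is the target depends on --which_direction.
function splitCombined(img) { // img: a loaded HTMLImageElement
  return ['left', 'right'].map((side, i) => {
    const canvas = document.createElement('canvas');
    canvas.width = 256;
    canvas.height = 256;
    // Copy the left (i = 0) or right (i = 1) half into its own canvas
    canvas.getContext('2d').drawImage(img, i * 256, 0, 256, 256, 0, 0, 256, 256);
    return { side, canvas };
  });
}

const combined = new Image();
combined.onload = () => console.log(splitCombined(combined));
combined.src = 'photos/combined/0_batch2.png'; // hypothetical path
```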
### 2. Train the model

```
# train the model
python pix2pix.py --mode train --output_dir pikachu_train --max_epochs 200 --input_dir pikachu/train --which_direction BtoA
```

Read more here: [https://github.com/affinelayer/pix2pix-tensorflow#getting-started](https://github.com/affinelayer/pix2pix-tensorflow#getting-started)

I used the High Performance Computing (HPC) cluster at NYU to train the model. You can find more instructions here: [https://github.com/cvalenzuela/hpc](https://github.com/cvalenzuela/hpc). You can request a GPU, submit a job to the HPC, and use tunnels to transfer large files between the HPC and your computer.

Training took me 4 hours and 16 minutes. After training, there should be a `pikachu_train` folder with a `checkpoint` in it.
If you add `--ngf 32 --ndf 32` when training the model (`python pix2pix.py --mode train --output_dir pikachu_train --max_epochs 200 --input_dir pikachu/train --which_direction BtoA --ngf 32 --ndf 32`), the model will be smaller (13.6 MB) and will take less time to train.

### 3. Test the model

```
# test the model
python pix2pix.py --mode test --output_dir pikachu_test --input_dir pikachu/val --checkpoint pikachu_train
```

After testing, there should be a new folder called `pikachu_test`. If you open the `index.html` inside it, you should see something like this in your browser:

Screen_Shot_2018_03_15_at_8_42_48_AM

Read more here: [https://github.com/affinelayer/pix2pix-tensorflow#getting-started](https://github.com/affinelayer/pix2pix-tensorflow#getting-started)

### 4. Export the model

```
python pix2pix.py --mode export --output_dir /export/ --checkpoint /pikachu_train/ --which_direction BtoA
```

It will create a new `export` folder.

### 5. Port the model to tensorflow.js
I followed [affinelayer](https://github.com/affinelayer)'s instructions here: [https://github.com/affinelayer/pix2pix-tensorflow/tree/master/server#exporting](https://github.com/affinelayer/pix2pix-tensorflow/tree/master/server#exporting)

```
cd server
python tools/export-checkpoint.py --checkpoint ../export --output_file static/models/pikachu_BtoA.pict
```

We should get a file named `pikachu_BtoA.pict`, which is 54.4 MB. (As noted in step 2, a model trained with `--ngf 32 --ndf 32` exports to a smaller 13.6 MB file.)

### 6. Create an interactive interface in the browser
Copy the model file from step 5 into the `models` folder; `pix2pix.js` loads it from there, roughly as sketched below.
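For reference, this is roughly how `index.js` uses the `pix2pix` wrapper defined in `pix2pix.js` once the model file is in place. It's a minimal sketch, assuming a 256x256 `<canvas id="sketch">` and an `<img id="result">` already exist on the page (those ids are illustrative, not the ones this repo uses):

```javascript
// Load the ported weights; the callback fires once they are fetched and decoded.
const edges2pikachu = pix2pix('./models/edges2pikachu_AtoB.pict', () => {
  console.log('Model loaded');

  // Run the generator on whatever is drawn on the canvas.
  const canvas = document.getElementById('sketch');
  edges2pikachu.transfer(canvas, (outputImg) => {
    // outputImg is a ready-made <img> element (see array3DToImage in utils.js)
    document.getElementById('result').src = outputImg.src;
  });
});
```

In the actual app, the canvas comes from p5.js and `transfer()` is also triggered automatically on every `mouseReleased` event.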

--------------------------------------------------------------------------------
/images/eraser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yining1023/pix2pix_tensorflowjs/2eb2e726e33d3aeda441ac1bb6d01a3289275029/images/eraser.png
--------------------------------------------------------------------------------
/images/input.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yining1023/pix2pix_tensorflowjs/2eb2e726e33d3aeda441ac1bb6d01a3289275029/images/input.png
--------------------------------------------------------------------------------
/images/input1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yining1023/pix2pix_tensorflowjs/2eb2e726e33d3aeda441ac1bb6d01a3289275029/images/input1.png
--------------------------------------------------------------------------------
/images/input2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yining1023/pix2pix_tensorflowjs/2eb2e726e33d3aeda441ac1bb6d01a3289275029/images/input2.png
--------------------------------------------------------------------------------
/images/input3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yining1023/pix2pix_tensorflowjs/2eb2e726e33d3aeda441ac1bb6d01a3289275029/images/input3.png
--------------------------------------------------------------------------------
/images/input4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yining1023/pix2pix_tensorflowjs/2eb2e726e33d3aeda441ac1bb6d01a3289275029/images/input4.png
--------------------------------------------------------------------------------
/images/input5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yining1023/pix2pix_tensorflowjs/2eb2e726e33d3aeda441ac1bb6d01a3289275029/images/input5.png
--------------------------------------------------------------------------------
/images/input6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yining1023/pix2pix_tensorflowjs/2eb2e726e33d3aeda441ac1bb6d01a3289275029/images/input6.png
--------------------------------------------------------------------------------
/images/input7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yining1023/pix2pix_tensorflowjs/2eb2e726e33d3aeda441ac1bb6d01a3289275029/images/input7.png
--------------------------------------------------------------------------------
/images/output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yining1023/pix2pix_tensorflowjs/2eb2e726e33d3aeda441ac1bb6d01a3289275029/images/output.png
--------------------------------------------------------------------------------
/images/output1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yining1023/pix2pix_tensorflowjs/2eb2e726e33d3aeda441ac1bb6d01a3289275029/images/output1.png
--------------------------------------------------------------------------------
/images/output2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yining1023/pix2pix_tensorflowjs/2eb2e726e33d3aeda441ac1bb6d01a3289275029/images/output2.png
--------------------------------------------------------------------------------
/images/output3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yining1023/pix2pix_tensorflowjs/2eb2e726e33d3aeda441ac1bb6d01a3289275029/images/output3.png
--------------------------------------------------------------------------------
/images/output4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yining1023/pix2pix_tensorflowjs/2eb2e726e33d3aeda441ac1bb6d01a3289275029/images/output4.png
--------------------------------------------------------------------------------
/images/output5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yining1023/pix2pix_tensorflowjs/2eb2e726e33d3aeda441ac1bb6d01a3289275029/images/output5.png
--------------------------------------------------------------------------------
/images/output6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yining1023/pix2pix_tensorflowjs/2eb2e726e33d3aeda441ac1bb6d01a3289275029/images/output6.png
--------------------------------------------------------------------------------
/images/output7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yining1023/pix2pix_tensorflowjs/2eb2e726e33d3aeda441ac1bb6d01a3289275029/images/output7.png
--------------------------------------------------------------------------------
/images/pencil.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yining1023/pix2pix_tensorflowjs/2eb2e726e33d3aeda441ac1bb6d01a3289275029/images/pencil.png
--------------------------------------------------------------------------------
/images/pokeball.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yining1023/pix2pix_tensorflowjs/2eb2e726e33d3aeda441ac1bb6d01a3289275029/images/pokeball.png
--------------------------------------------------------------------------------
/index.html:
--------------------------------------------------------------------------------
<!DOCTYPE html>
<!-- NOTE: the original tags of this file were lost when the repo was flattened; the markup below is
     a best-effort reconstruction from the surviving text and from the ids, classes, and handlers
     referenced in index.js and style.css. The CDN script URLs in particular are assumptions. -->
<html>
<head>
  <meta charset="utf-8">
  <title>Pix2Pix Edges2Pikachu</title>
  <link rel="stylesheet" href="style.css">
  <script src="https://cdnjs.cloudflare.com/ajax/libs/p5.js/0.6.1/p5.min.js"></script>
  <script src="https://cdnjs.cloudflare.com/ajax/libs/p5.js/0.6.1/addons/p5.dom.min.js"></script>
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
  <script src="utils.js"></script>
  <script src="pix2pix.js"></script>
  <script src="index.js"></script>
</head>
<body>
  <h1>Pix2Pix Edges2Pikachu <img class="pokeball" src="images/pokeball.png" alt="pokeball"></h1>
  <div id="textContainer" class="line-height">
    <div>1. Select 'Pencil' tool to draw a Pikachu on the canvas below.</div>
    <div>2. Click the 'Transfer' button.</div>
    <div>3. A colored Pikachu image will appear on the right side in ~5s.</div>
    <div>4. Click the 'Clear' button to clear the canvas and draw again.</div>
  </div>
  <div class="flex">
    <div>
      <!-- p5.js attaches the 256x256 drawing canvas here (index.js: setup()) -->
      <div id="canvasContainer"></div>
      <div id="toolContainer" class="flex flex-space-between">
        <button onclick="usePencil()">Pencil</button>
        <button onclick="useEraser()">Eraser</button>
      </div>
      <div id="btnContainer" class="flex flex-space-between">
        <button onclick="clearCanvas()">Clear</button>
        <button onclick="getRandomOutput()">Random</button>
      </div>
    </div>
    <div id="transferContainer">
      <button id="transferBtn" onclick="transfer()">Transfer</button>
      <div id="status">Loading Model...</div>
    </div>
    <!-- index.js inserts the generated image into this container -->
    <div id="output"></div>
  </div>
</body>
</html>
--------------------------------------------------------------------------------
/index.js:
--------------------------------------------------------------------------------
const SIZE = 256, sampleNum = 7;
let inputCanvas, outputContainer, statusMsg, transferBtn, sampleIndex = 0, modelReady = false, isTransfering = false;
const inputImgs = [], outputImgs = [];

const edges2pikachu = pix2pix('./models/edges2pikachu_AtoB.pict', modelLoaded);

function setup() {
  // Create canvas
  inputCanvas = createCanvas(SIZE, SIZE);
  inputCanvas.class('border-box pencil').parent('canvasContainer');

  // Select output div container
  outputContainer = select('#output');
  statusMsg = select('#status');
  transferBtn = select('#transferBtn').hide();

  // Display initial input image
  loadImage('./images/input.png', inputImg => image(inputImg, 0, 0));

  // Display initial output image
  let out = createImg('./images/output.png');
  outputContainer.html('');
  out.class('border-box').parent('output');

  // Load other sample input/output images
  for (let i = 1; i <= sampleNum; i += 1) {
    loadImage(`./images/input${i}.png`, inImg => {
      inputImgs.push(inImg);
      let outImg = createImg(`./images/output${i}.png`);
      outImg.hide().class('border-box');
      outputImgs.push(outImg);
    });
  }

  // Set stroke to black
  stroke(0);
  pixelDensity(1);
}

// Draw on the canvas when the mouse is pressed
function draw() {
  if (mouseIsPressed) {
    line(mouseX, mouseY, pmouseX, pmouseY);
  }
}

// Re-run the transfer every time a stroke is finished
function mouseReleased() {
  if (modelReady && !isTransfering) {
    transfer();
  }
}

function transfer() {
  isTransfering = true;
  // Update status message
  statusMsg.html('Applying Style Transfer...!');

  // Select the canvas DOM element ('defaultCanvas0' is the id p5.js assigns to its first canvas)
  let canvasElement = document.getElementById('defaultCanvas0');
  // Apply the pix2pix transformation
  edges2pikachu.transfer(canvasElement, result => {
    // Clear output container
    outputContainer.html('');
    // Create an image based on the result
    createImg(result.src).class('border-box').parent('output');
    statusMsg.html('Done!');
    isTransfering = false;
  });
}

// A function to be called when the model has loaded
function modelLoaded() {
  if (!statusMsg) statusMsg = select('#status');
  statusMsg.html('Model Loaded!');
  transferBtn.show();
  modelReady = true;
}

// Clear the canvas
function clearCanvas() {
  background(255);
}

// Cycle through the preloaded sample input/output pairs
function getRandomOutput() {
  image(inputImgs[sampleIndex], 0, 0);
  outputContainer.html('');
  outputImgs[sampleIndex].show().parent('output');
  sampleIndex += 1;
  if (sampleIndex >= sampleNum) sampleIndex = 0;
}

function usePencil() {
  stroke(0);
  strokeWeight(1);
  inputCanvas.removeClass('eraser');
  inputCanvas.addClass('pencil');
}

function useEraser() {
  stroke(255);
  strokeWeight(15);
  inputCanvas.removeClass('pencil');
  inputCanvas.addClass('eraser');
}
--------------------------------------------------------------------------------
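One detail worth flagging in `transfer()` above: `'defaultCanvas0'` is simply the id that p5.js auto-assigns to the first canvas it creates, so the lookup would break if another canvas were ever added. A slightly more robust variant (a sketch, relying on p5's documented `.elt` property, which exposes the underlying DOM element):

```javascript
function transfer() {
  isTransfering = true;
  statusMsg.html('Applying Style Transfer...!');

  // inputCanvas.elt is the underlying <canvas> DOM element --
  // no need to know the id that p5.js generated for it.
  edges2pikachu.transfer(inputCanvas.elt, result => {
    outputContainer.html('');
    createImg(result.src).class('border-box').parent('output');
    statusMsg.html('Done!');
    isTransfering = false;
  });
}
```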
/models/edges2pikachu_AtoB.pict:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yining1023/pix2pix_tensorflowjs/2eb2e726e33d3aeda441ac1bb6d01a3289275029/models/edges2pikachu_AtoB.pict
--------------------------------------------------------------------------------
/pix2pix.js:
--------------------------------------------------------------------------------
class Pix2pix {
  constructor(model, callback) {
    this.ready = false;

    this.loadCheckpoints(model).then(() => {
      this.ready = true;
      if (callback) {
        callback();
      }
    });
  }

  async loadCheckpoints(path) {
    this.weights = await fetchWeights(path);
  }

  async transfer(inputElement, callback = () => { }) {
    // Read the canvas pixels and scale them from [0, 255] to [0, 1]
    const input = tf.browser.fromPixels(inputElement);
    const inputData = input.dataSync();
    const floatInput = tf.tensor3d(inputData, input.shape, 'float32');
    const normalizedInput = tf.div(floatInput, tf.scalar(255));

    // Map [0, 1] to the generator's expected input range [-1, 1]
    function preprocess(inputPreproc) {
      return tf.sub(tf.mul(inputPreproc, tf.scalar(2)), tf.scalar(1));
    }

    // Map the generator's tanh output range [-1, 1] back to [0, 1]
    function deprocess(inputDeproc) {
      return tf.div(tf.add(inputDeproc, tf.scalar(1)), tf.scalar(2));
    }

    function batchnorm(inputBat, scale, offset) {
      const moments = tf.moments(inputBat, [0, 1]);
      const varianceEpsilon = 1e-5;
      return tf.batchNorm(inputBat, moments.mean, moments.variance, offset, scale, varianceEpsilon);
    }

    // Stride-2 convolution: halves the spatial resolution
    function conv2d(inputCon, filterCon) {
      return tf.conv2d(inputCon, filterCon, [2, 2], 'same');
    }

    // Stride-2 transposed convolution: doubles the spatial resolution
    function deconv2d(inputDeconv, filterDeconv, biasDecon) {
      const convolved = tf.conv2dTranspose(inputDeconv, filterDeconv, [inputDeconv.shape[0] * 2, inputDeconv.shape[1] * 2, filterDeconv.shape[2]], [2, 2], 'same');
      const biased = tf.add(convolved, biasDecon);
      return biased;
    }

    const result = tf.tidy(() => {
      const preprocessedInput = preprocess(normalizedInput);
      const layers = [];

      // U-Net encoder: 8 stride-2 convolutions take 256x256 down to 1x1.
      // Note: this port does not apply the exported conv biases in the encoder;
      // for encoder_2..8 the batchnorm offset (beta) plays that role.
      let filter = this.weights['generator/encoder_1/conv2d/kernel'];
      let bias;
      let convolved = conv2d(preprocessedInput, filter);
      layers.push(convolved);

      for (let i = 2; i <= 8; i += 1) {
        const scope = `generator/encoder_${i.toString()}`;
        filter = this.weights[`${scope}/conv2d/kernel`];
        const layerInput = layers[layers.length - 1];
        const rectified = tf.leakyRelu(layerInput, 0.2);
        convolved = conv2d(rectified, filter);
        const scale = this.weights[`${scope}/batch_normalization/gamma`];
        const offset = this.weights[`${scope}/batch_normalization/beta`];
        const normalized = batchnorm(convolved, scale, offset);
        layers.push(normalized);
      }

      // U-Net decoder: transposed convolutions back up to 256x256; each decoder
      // input is concatenated with the matching encoder layer (skip connection)
      // along the channel axis.
      for (let i = 8; i >= 2; i -= 1) {
        let layerInput;
        if (i === 8) {
          layerInput = layers[layers.length - 1];
        } else {
          const skipLayer = i - 1;
          layerInput = tf.concat([layers[layers.length - 1], layers[skipLayer]], 2);
        }
        const rectified = tf.relu(layerInput);
        const scope = `generator/decoder_${i.toString()}`;
        filter = this.weights[`${scope}/conv2d_transpose/kernel`];
        bias = this.weights[`${scope}/conv2d_transpose/bias`];
        convolved = deconv2d(rectified, filter, bias);
        const scale = this.weights[`${scope}/batch_normalization/gamma`];
        const offset = this.weights[`${scope}/batch_normalization/beta`];
        const normalized = batchnorm(convolved, scale, offset);
        layers.push(normalized);
      }

      // decoder_1: final transposed convolution to 3 channels, then tanh
      const layerInput = tf.concat([layers[layers.length - 1], layers[0]], 2);
      let rectified2 = tf.relu(layerInput);
      filter = this.weights['generator/decoder_1/conv2d_transpose/kernel'];
      const bias3 = this.weights['generator/decoder_1/conv2d_transpose/bias'];
      convolved = deconv2d(rectified2, filter, bias3);
      rectified2 = tf.tanh(convolved);
      layers.push(rectified2);

      const output = layers[layers.length - 1];
      const deprocessedOutput = deprocess(output);
      return deprocessedOutput;
    });

    await tf.nextFrame();
    callback(array3DToImage(result));
  }
}

const pix2pix = (model, callback = () => { }) => new Pix2pix(model, callback);
--------------------------------------------------------------------------------
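The value ranges in `transfer()` are easy to get wrong, so to make them concrete: a canvas byte p in [0, 255] becomes p/255, then `preprocess` maps it to 2*(p/255) - 1 in [-1, 1]; the generator ends in `tanh`, and `deprocess` maps that output back to [0, 1] before `array3DToImage` (utils.js) rescales it to bytes. A quick stand-alone check of the round trip (plain arithmetic, no tf.js needed):

```javascript
// Round trip of pix2pix.js's value mappings for a single pixel value.
const preprocess = x => 2 * x - 1;   // [0, 1] -> [-1, 1], as in pix2pix.js
const deprocess = y => (y + 1) / 2;  // [-1, 1] -> [0, 1]

const p = 200;                       // a canvas byte value
const net = preprocess(p / 255);     // ~0.569: what the generator sees
console.log(deprocess(net) * 255);   // 200: back to the original byte value
```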
/style.css:
--------------------------------------------------------------------------------
body {
  font-family: monospace;
  background: #eee;
}

.border-box {
  background: white;
  border-radius: 7px;
  margin: 20px;
}

.flex {
  display: flex;
}

.line-height {
  line-height: 2.5;
}

.flex-space-between {
  justify-content: space-between;
}

.pencil {
  cursor: url("images/pencil.png") 0 30, auto;
}

.eraser {
  cursor: url("images/eraser.png") 15 15, auto;
}

button {
  cursor: pointer;
  height: 40px;
  background: #FFD524;
  border-radius: 20px;
  border: none;
  color: white;
  font-family: monospace;
  font-size: 13px;
}

button:hover {
  background-color: #FC9A24;
}

#textContainer {
  margin: 20px;
}

#transferContainer {
  margin-top: 80px;
}

#btnContainer {
  margin: 0 20px;
}

#toolContainer {
  width: 256px;
  margin: 0 20px;
}

.pokeball {
  animation: pokeball-spin infinite 5s linear;
  height: 60px;
  margin-left: 20px;
  margin-bottom: 10px;
}

@keyframes pokeball-spin {
  0% {
    transform: rotate(0deg);
  }
  25% {
    transform: rotate(-75deg);
  }
  75% {
    transform: rotate(75deg);
  }
  100% {
    transform: rotate(0deg);
  }
}
--------------------------------------------------------------------------------
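`utils.js` below decodes the quantized `.pict` weight file: the exporter stores a float32 codebook (at most 256 entries, since codes are uint8) plus one byte per weight, so each weight costs one byte instead of four. The decoding step in `fetchWeights` is just a table lookup; here it is in isolation with toy values:

```javascript
// Toy dequantization in the style of fetchWeights() in utils.js:
// a small float32 codebook plus uint8 codes that index into it.
const codebook = new Float32Array([-0.5, 0.0, 0.25, 0.5]); // real files carry up to 256 entries
const codes = new Uint8Array([2, 0, 1, 3, 2]);             // one byte per weight

const weights = new Float32Array(codes.length);
for (let i = 0; i < codes.length; i += 1) {
  weights[i] = codebook[codes[i]];
}
console.log(weights); // Float32Array [0.25, -0.5, 0, 0.5, 0.25]
```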
/utils.js:
--------------------------------------------------------------------------------
// Cache parsed weight files by URL so a model is only fetched and decoded once
const weightsCache = {};

// Fetch weights from path
const fetchWeights = (urlPath) => {
  return new Promise((resolve, reject) => {
    if (urlPath in weightsCache) {
      resolve(weightsCache[urlPath]);
      return;
    }

    const xhr = new XMLHttpRequest();
    xhr.open('GET', urlPath, true);
    xhr.responseType = 'arraybuffer';
    xhr.onload = () => {
      if (xhr.status !== 200) {
        reject(new Error('missing model'));
        return;
      }
      const buf = xhr.response;
      if (!buf) {
        reject(new Error('invalid arraybuffer'));
        return;
      }

      // The file is a sequence of parts, each prefixed with its byte length
      // as a 4-byte big-endian integer
      const parts = [];
      let offset = 0;
      while (offset < buf.byteLength) {
        const b = new Uint8Array(buf.slice(offset, offset + 4));
        offset += 4;
        const len = (b[0] << 24) + (b[1] << 16) + (b[2] << 8) + b[3]; // eslint-disable-line no-bitwise
        parts.push(buf.slice(offset, offset + len));
        offset += len;
      }

      // parts[0]: JSON array of {name, shape} entries;
      // parts[1]: float32 codebook; parts[2]: one uint8 code per weight
      const shapes = JSON.parse((new TextDecoder('utf8')).decode(parts[0]));
      const index = new Float32Array(parts[1]);
      const encoded = new Uint8Array(parts[2]);

      // Dequantize: look every uint8 code up in the codebook
      const arr = new Float32Array(encoded.length);
      for (let i = 0; i < arr.length; i += 1) {
        arr[i] = index[encoded[i]];
      }

      // Carve the flat array into named tensors using the recorded shapes
      const weights = {};
      offset = 0;
      for (let i = 0; i < shapes.length; i += 1) {
        const { shape } = shapes[i];
        const size = shape.reduce((total, num) => total * num);
        const values = arr.slice(offset, offset + size);
        const tfarr = tf.tensor1d(values, 'float32');
        weights[shapes[i].name] = tfarr.reshape(shape);
        offset += size;
      }
      weightsCache[urlPath] = weights;
      resolve(weights);
    };
    xhr.send(null);
  });
}

// Converts a [height, width, 3] tensor with values in [0, 1] to a DOM img element
const array3DToImage = (tensor) => {
  const [imgHeight, imgWidth] = tensor.shape;
  const data = tensor.dataSync();
  const canvas = document.createElement('canvas');
  canvas.width = imgWidth;
  canvas.height = imgHeight;
  const ctx = canvas.getContext('2d');
  const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);

  // Copy RGB values (rescaled to bytes) and set alpha to fully opaque
  for (let i = 0; i < imgWidth * imgHeight; i += 1) {
    const j = i * 4;
    const k = i * 3;
    imageData.data[j + 0] = Math.floor(256 * data[k + 0]);
    imageData.data[j + 1] = Math.floor(256 * data[k + 1]);
    imageData.data[j + 2] = Math.floor(256 * data[k + 2]);
    imageData.data[j + 3] = 255;
  }
  ctx.putImageData(imageData, 0, 0);

  // Create img HTML element from canvas
  const dataUrl = canvas.toDataURL();
  const outputImg = document.createElement('img');
  outputImg.src = dataUrl;
  outputImg.style.width = `${imgWidth}px`;
  outputImg.style.height = `${imgHeight}px`;
  return outputImg;
};
--------------------------------------------------------------------------------