├── src ├── App.css ├── react-app-env.d.ts ├── sample-images │ ├── cat_text.png │ ├── car_inpaint.png │ └── lenna_noisy.png ├── App.test.tsx ├── index.css ├── index.tsx ├── LabeledCheckbox.tsx ├── LabeledSlider.tsx ├── models │ ├── ConvNet.ts │ └── UNet.ts ├── logo.svg ├── DrawableCanvas.tsx ├── AppState.ts ├── serviceWorker.ts ├── Painter.tsx └── App.tsx ├── public ├── favicon.ico ├── manifest.json └── index.html ├── .gitignore ├── tsconfig.json ├── LICENSE ├── package.json └── README.md /src/App.css: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /public/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RobinKa/web-deep-image-prior/HEAD/public/favicon.ico -------------------------------------------------------------------------------- /src/react-app-env.d.ts: -------------------------------------------------------------------------------- 1 | /// 2 | // Shorthand ambient module declaration: everything imported from "react-compare-image" is typed as any (presumably because the package ships no type definitions; confirm before relying on it). 3 | declare module "react-compare-image" -------------------------------------------------------------------------------- /src/sample-images/cat_text.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RobinKa/web-deep-image-prior/HEAD/src/sample-images/cat_text.png -------------------------------------------------------------------------------- /src/sample-images/car_inpaint.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RobinKa/web-deep-image-prior/HEAD/src/sample-images/car_inpaint.png -------------------------------------------------------------------------------- /src/sample-images/lenna_noisy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RobinKa/web-deep-image-prior/HEAD/src/sample-images/lenna_noisy.png 
-------------------------------------------------------------------------------- /src/App.test.tsx: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import ReactDOM from 'react-dom'; 3 | import App from './App'; 4 | // Smoke test: rendering <App /> into a detached div and unmounting it again must not throw. 5 | it('renders without crashing', () => { 6 | const div = document.createElement('div'); 7 | ReactDOM.render(<App />, div); 8 | ReactDOM.unmountComponentAtNode(div); 9 | }); 10 | -------------------------------------------------------------------------------- /public/manifest.json: -------------------------------------------------------------------------------- 1 | { 2 | "short_name": "Deep Image Prior", 3 | "name": "Client-side Deep Image Prior", 4 | "icons": [ 5 | { 6 | "src": "favicon.ico", 7 | "sizes": "64x64 32x32 24x24 16x16", 8 | "type": "image/x-icon" 9 | } 10 | ], 11 | "start_url": ".", 12 | "display": "standalone", 13 | "theme_color": "#000000", 14 | "background_color": "#ffffff" 15 | } 16 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # See https://help.github.com/articles/ignoring-files/ for more about ignoring files.
2 | 3 | # dependencies 4 | /node_modules 5 | /.pnp 6 | .pnp.js 7 | 8 | # testing 9 | /coverage 10 | 11 | # production 12 | /build 13 | 14 | # misc 15 | .DS_Store 16 | .env.local 17 | .env.development.local 18 | .env.test.local 19 | .env.production.local 20 | 21 | npm-debug.log* 22 | yarn-debug.log* 23 | yarn-error.log* 24 | -------------------------------------------------------------------------------- /src/index.css: -------------------------------------------------------------------------------- 1 | body { 2 | margin: 0; 3 | font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', 'Oxygen', 4 | 'Ubuntu', 'Cantarell', 'Fira Sans', 'Droid Sans', 'Helvetica Neue', 5 | sans-serif; 6 | -webkit-font-smoothing: antialiased; 7 | -moz-osx-font-smoothing: grayscale; 8 | } 9 | 10 | code { 11 | font-family: source-code-pro, Menlo, Monaco, Consolas, 'Courier New', 12 | monospace; 13 | } 14 | -------------------------------------------------------------------------------- /src/index.tsx: -------------------------------------------------------------------------------- 1 | import 'react-app-polyfill/ie9' 2 | import 'react-app-polyfill/stable' 3 | import "fast-text-encoding/text.min.js" 4 | 5 | import React from 'react' 6 | import ReactDOM from 'react-dom' 7 | import './index.css' 8 | import App from './App' 9 | import * as serviceWorker from './serviceWorker' 10 | 11 | ReactDOM.render(<App />, document.getElementById('root')) 12 | 13 | // register() below opts in to a service worker (it only acts in production builds) so the app can work offline and load 14 | // faster on repeat visits; change it to unregister() to opt out. Note this comes with some pitfalls.
15 | // Learn more about service workers: https://bit.ly/CRA-PWA 16 | serviceWorker.register() 17 | -------------------------------------------------------------------------------- /src/LabeledCheckbox.tsx: -------------------------------------------------------------------------------- 1 | import React, { CSSProperties } from "react" 2 | 3 | type LabeledCheckboxProps = { 4 | disabled: boolean, 5 | setValue: (value: boolean) => void 6 | value: boolean 7 | label: string 8 | style: CSSProperties 9 | } 10 | /** Checkbox whose checked-state changes are forwarded to props.setValue. */ 11 | export default function LabeledCheckbox(props: LabeledCheckboxProps) { 12 | return ( 13 |
14 | props.setValue(evt.target.checked)} /> 15 | 16 |
17 | ) 18 | } -------------------------------------------------------------------------------- /tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "target": "es5", 4 | "lib": [ 5 | "dom", 6 | "dom.iterable", 7 | "esnext" 8 | ], 9 | "allowJs": true, 10 | "skipLibCheck": true, 11 | "esModuleInterop": true, 12 | "allowSyntheticDefaultImports": true, 13 | "strict": true, 14 | "forceConsistentCasingInFileNames": true, 15 | "module": "esnext", 16 | "moduleResolution": "node", 17 | "resolveJsonModule": true, 18 | "isolatedModules": true, 19 | "noEmit": true, 20 | "jsx": "preserve" 21 | }, 22 | "include": [ 23 | "src" 24 | ] 25 | } 26 | -------------------------------------------------------------------------------- /src/LabeledSlider.tsx: -------------------------------------------------------------------------------- 1 | import React, { CSSProperties } from "react" 2 | import Slider from 'rc-slider' 3 | 4 | type LabeledSliderProps = { 5 | disabled: boolean, 6 | setValue: (value: number) => void 7 | value: number 8 | label: string 9 | min: number 10 | max: number 11 | step: number 12 | style: CSSProperties 13 | } 14 | /** rc-slider Slider whose value changes are forwarded to props.setValue. */ 15 | export default function LabeledSlider(props: LabeledSliderProps) { 16 | return ( 17 |
18 | props.setValue(value)} /> 19 | 20 |
21 | ) 22 | } -------------------------------------------------------------------------------- /src/models/ConvNet.ts: -------------------------------------------------------------------------------- 1 | import * as tf from "@tensorflow/tfjs" 2 | /** Builds a tf.Sequential stack of 3x3 "same"-padded convolutions (default stride 1, so the spatial size of inputShape is kept): (layers - 1) ReLU conv + batch-norm pairs with `filters` channels, then a final tanh conv with `outputFilters` channels. inputShape is the per-sample shape without the batch dimension. `layers` counts every conv including the final one; values below 1 yield just the final conv. */ 3 | export function createConvNet(inputShape: [number, number, number], outputFilters: number, layers: number, filters: number) { 4 | const model = tf.sequential() 5 | 6 | const hiddenShape = [inputShape[0], inputShape[1], filters] 7 | 8 | for (let layerIndex = 0; layerIndex < layers - 1; layerIndex++) { 9 | model.add(tf.layers.conv2d({ 10 | inputShape: layerIndex === 0 ? inputShape : hiddenShape, 11 | kernelSize: [3, 3], 12 | padding: "same", 13 | filters: filters, 14 | activation: "relu" 15 | })) 16 | 17 | model.add(tf.layers.batchNormalization()) 18 | } 19 | // With no hidden layers (layers <= 1) the final conv is the first layer, so it must carry inputShape. 20 | model.add(tf.layers.conv2d({ 21 | inputShape: layers <= 1 ? inputShape : hiddenShape, 22 | kernelSize: [3, 3], 23 | padding: "same", 24 | filters: outputFilters, 25 | activation: "tanh" 26 | })) 27 | 28 | return model 29 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Robin Kahlow 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software.
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "deep-image-prior", 3 | "version": "0.1.1", 4 | "private": true, 5 | "homepage": "https://warlock.ai/deepimageprior", 6 | "dependencies": { 7 | "@tensorflow/tfjs": "^1.7.3", 8 | "@types/file-saver": "^2.0.1", 9 | "@types/jest": "24.0.16", 10 | "@types/node": "12.6.9", 11 | "@types/rc-slider": "^8.6.5", 12 | "@types/react": "16.8.24", 13 | "@types/react-dom": "16.8.5", 14 | "bootstrap": "^4.4.1", 15 | "fast-text-encoding": "^1.0.2", 16 | "rc-slider": "^8.7.1", 17 | "react": "^16.13.1", 18 | "react-app-polyfill": "^1.0.6", 19 | "react-bootstrap": "^1.0.1", 20 | "react-compare-image": "^1.4.1", 21 | "react-dom": "^16.13.1", 22 | "react-dropzone": "^10.2.2", 23 | "react-scripts": "^3.4.1", 24 | "typescript": "3.5.3" 25 | }, 26 | "scripts": { 27 | "start": "react-scripts start", 28 | "build": "react-scripts build", 29 | "test": "react-scripts test", 30 | "eject": "react-scripts eject" 31 | }, 32 | "eslintConfig": { 33 | "extends": "react-app" 34 | }, 35 | "browserslist": { 36 | "production": [ 37 | ">0.2%", 38 | "not dead", 39 | "not op_mini all" 40 | ], 41 | "development": [ 42 | "last 1 chrome version", 43 | "last 1 firefox version", 44 | "last 1 safari version" 45 | ] 46 | } 47 | } 48 | -------------------------------------------------------------------------------- 
/public/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 12 | 13 | 22 | Deep Image Prior 23 | 24 | 25 | 26 |
27 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /src/models/UNet.ts: -------------------------------------------------------------------------------- 1 | import * as tf from "@tensorflow/tfjs" 2 | /** Builds an encoder-decoder (U-Net style) tf.LayersModel. The encoder applies `layers` 4x4 stride-2 "same" elu convs, the i-th (i = 0..layers-1) with min(256, 2^i * filters) filters. The decoder mirrors it: each stage is a 4x4 stride-2 elu transposed conv followed by a 4x4 stride-1 conv; the last stage's conv outputs `outputFilters` channels with tanh, the others use elu. When `skip` is true, each decoder stage concatenates the encoder feature map of matching resolution before its stride-1 conv. NOTE(review): with skip, each spatial dim presumably must be divisible by 2^layers or the concatenation shapes will not match; confirm. */ 3 | export function createUNet(inputShape: [number, number, number], outputFilters: number, layers: number, filters: number, skip: boolean) { 4 | const input = tf.input({ shape: inputShape }) 5 | // downs[k] is the feature map after k stride-2 convs; downs[0] is the raw input. 6 | const downs = [input] 7 | for (let i = 0; i < layers; i++) { 8 | downs.push(tf.layers.conv2d({ 9 | filters: Math.min(256, Math.pow(2, i) * filters), 10 | kernelSize: [4, 4], 11 | padding: "same", 12 | strides: 2, 13 | activation: "elu", 14 | }).apply(downs[downs.length - 1]) as tf.SymbolicTensor) 15 | } 16 | // Decoder: start from the deepest encoder output; ups grows by one stage per level. 17 | const ups = [downs[downs.length - 1]] 18 | for (let i = 0; i < layers; i++) { 19 | const last = i === layers - 1 20 | // Upsample by 2 with a transposed conv. 21 | const upsampled = tf.layers.conv2dTranspose({ 22 | filters: Math.min(256, Math.pow(2, layers - i - 1) * filters), 23 | kernelSize: [4, 4], 24 | padding: "same", 25 | strides: 2, 26 | activation: "elu", 27 | }).apply(ups[ups.length - 1]) as tf.SymbolicTensor 28 | // downs[layers - i - 1] has gone through as many stride-2 steps as upsampled is away from full resolution. 29 | const concatenated = skip ? tf.layers.concatenate({axis: -1}).apply([upsampled, downs[layers - i - 1]]) : upsampled 30 | // The final stage produces the model output: outputFilters channels squashed by tanh. 31 | const processed = tf.layers.conv2d({ 32 | filters: last ? outputFilters : Math.min(256, Math.pow(2, layers - i - 1) * filters), 33 | kernelSize: [4, 4], 34 | padding: "same", 35 | strides: 1, 36 | activation: last ? "tanh" : "elu", 37 | }).apply(concatenated) as tf.SymbolicTensor 38 | 39 | ups.push(processed) 40 | } 41 | // NOTE(review): with layers === 0 no layers are built and the model returns its input unchanged; confirm callers never pass 0. 42 | return tf.model({inputs: input, outputs: ups[ups.length - 1]}) 43 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This project was bootstrapped with [Create React App](https://github.com/facebook/create-react-app).
2 | 3 | ## Available Scripts 4 | 5 | In the project directory, you can run: 6 | 7 | ### `npm start` 8 | 9 | Runs the app in development mode.
10 | Open [http://localhost:3000](http://localhost:3000) to view it in the browser. 11 | 12 | The page will reload if you make edits.
13 | You will also see any lint errors in the console. 14 | 15 | ### `npm test` 16 | 17 | Launches the test runner in the interactive watch mode.
18 | See the section about [running tests](https://facebook.github.io/create-react-app/docs/running-tests) for more information. 19 | 20 | ### `npm run build` 21 | 22 | Builds the app for production to the `build` folder.
23 | It correctly bundles React in production mode and optimizes the build for the best performance. 24 | 25 | The build is minified and the filenames include the hashes.
26 | Your app is ready to be deployed! 27 | 28 | See the section about [deployment](https://facebook.github.io/create-react-app/docs/deployment) for more information. 29 | 30 | ### `npm run eject` 31 | 32 | **Note: this is a one-way operation. Once you `eject`, you can’t go back!** 33 | 34 | If you aren’t satisfied with the build tool and configuration choices, you can `eject` at any time. This command will remove the single build dependency from your project. 35 | 36 | Instead, it will copy all the configuration files and the transitive dependencies (Webpack, Babel, ESLint, etc.) right into your project so you have full control over them. All of the commands except `eject` will still work, but they will point to the copied scripts so you can tweak them. At this point you’re on your own. 37 | 38 | You don’t ever have to use `eject`. The curated feature set is suitable for small and medium-sized deployments, and you shouldn’t feel obligated to use this feature. However, we understand that this tool wouldn’t be useful if you couldn’t customize it when you are ready for it. 39 | 40 | ## Learn More 41 | 42 | You can learn more in the [Create React App documentation](https://facebook.github.io/create-react-app/docs/getting-started). 43 | 44 | To learn React, check out the [React documentation](https://reactjs.org/).
45 | -------------------------------------------------------------------------------- /src/logo.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /src/DrawableCanvas.tsx: -------------------------------------------------------------------------------- 1 | import React, { useState, useRef, Dispatch, useEffect } from "react" 2 | import { AppState, AppUpdateAction } from "./AppState" 3 | 4 | type DrawableCanvasProps = { 5 | state: AppState 6 | dispatchState: Dispatch<AppUpdateAction> 7 | backgroundImage: string 8 | } 9 | /** Paint surface for the inpainting mask. It is filled white; while drawing is active (tryStartDraw only enables it when the app is idle and not running) pointer moves stamp black dots of radius 5, and releasing the right mouse button clears it back to white. The canvas element is published to app state via "setMaskCanvas". */ 10 | export default function DrawableCanvas(props: DrawableCanvasProps) { 11 | const [width, height] = [props.state.algorithmSettings.width, props.state.algorithmSettings.height] 12 | const maskCanvas = props.state.maskCanvas 13 | const dispatchState = props.dispatchState 14 | 15 | const [drawing, setDrawing] = useState(false) 16 | 17 | const canvas = useRef<HTMLCanvasElement>(null) 18 | 19 | function tryStartDraw() { 20 | if (props.state.step === "idle" && !props.state.shouldRun) { 21 | setDrawing(true) 22 | } 23 | } 24 | 25 | function endDraw() { 26 | setDrawing(false) 27 | } 28 | 29 | function reset() { 30 | const cnv = canvas.current! 31 | const ctx = cnv.getContext("2d")! 32 | 33 | ctx.fillStyle = "white" 34 | ctx.fillRect(0, 0, width, height) 35 | } 36 | 37 | function onMouseUp(event: React.MouseEvent) { 38 | endDraw() 39 | if (event.button === 2) { 40 | reset() 41 | } 42 | } 43 | // clientPos must be viewport-relative (clientX/clientY) so it lines up with getBoundingClientRect(). 44 | function onMove(clientPos: [number, number]) { 45 | if (drawing) { 46 | const cnv = canvas.current! 47 | const ctx = cnv.getContext("2d")!
48 | const bounds = cnv.getBoundingClientRect() 49 | 50 | const mousePos = [clientPos[0] - bounds.left, clientPos[1] - bounds.top] 51 | const radius = 5 52 | 53 | ctx.beginPath() 54 | ctx.arc(mousePos[0], mousePos[1], radius, 0, 2 * Math.PI) 55 | ctx.fillStyle = "black" 56 | ctx.fill() 57 | } 58 | } 59 | 60 | function onTouchMove(e: React.TouchEvent) { 61 | // Use client (viewport) coordinates like the mouse path; pageX/pageY include the page scroll offset and would misplace strokes once the page is scrolled. 62 | const touch = e.targetTouches[0] || e.changedTouches[e.changedTouches.length - 1] 63 | onMove([touch.clientX, touch.clientY]) 64 | 65 | } 66 | // Fill the canvas white on mount and whenever width/height change. 67 | useEffect(() => { 68 | const cnv = canvas.current! 69 | const ctx = cnv.getContext("2d")! 70 | ctx.fillStyle = "white" 71 | ctx.fillRect(0, 0, width, height) 72 | }, [canvas, width, height]) 73 | // Publish the canvas element to app state (Painter reads the inpainting mask from state.maskCanvas). 74 | useEffect(() => { 75 | if (canvas.current !== maskCanvas) { 76 | dispatchState({ 77 | type: "setMaskCanvas", 78 | maskCanvas: canvas.current 79 | }) 80 | } 81 | }, [canvas, maskCanvas, dispatchState]) 82 | 83 | return ( 84 | onMove([e.clientX, e.clientY])} /> 88 | ) 89 | } -------------------------------------------------------------------------------- /src/AppState.ts: -------------------------------------------------------------------------------- 1 | import { useReducer } from "react" 2 | 3 | export type ImageData = { 4 | uri: string 5 | iteration: number 6 | } 7 | 8 | export type AlgorithmSettings = { 9 | filters: number 10 | layers: number 11 | width: number 12 | height: number 13 | inpaint: boolean 14 | } 15 | // step moves to "runIter" on "startIter", to "finishedIter" on "finishIter", and back to "idle" on "stopped". 
16 | export type AppState = { 17 | step: "idle" | "runIter" | "finishedIter" 18 | shouldRun: boolean 19 | 20 | images: ImageData[] 21 | maskCanvas: HTMLCanvasElement | null 22 | algorithmSettings: AlgorithmSettings 23 | sourceImage: number[] | null 24 | iteration: number 25 | } 26 | 27 | export type AppUpdateReset = { type: "reset" } 28 | export type AppUpdateStart = { type: "start" } 29 | export type AppUpdatePause = { type: "pause" } 30 | 31 | export type
AppUpdateAlgorithmSettings = { 32 | type: "algorithmSettings" 33 | newSettings: AlgorithmSettings 34 | } 35 | export type AppUpdateSetSourceImage = { 36 | type: "setSourceImage", 37 | image: number[] 38 | } 39 | export type AppUpdateFinishIter = { 40 | type: "finishIter", 41 | imageData: ImageData | undefined 42 | } 43 | export type AppUpdateSetMaskCanvas = { 44 | type: "setMaskCanvas" 45 | maskCanvas: HTMLCanvasElement | null 46 | } 47 | export type AppUpdateStartIter = { 48 | type: "startIter" 49 | } 50 | export type AppUpdateStopped = { 51 | type: "stopped" 52 | } 53 | 54 | export type AppUpdateAction = AppUpdateReset | AppUpdateStart | AppUpdatePause | 55 | AppUpdateAlgorithmSettings | AppUpdateSetSourceImage | AppUpdateFinishIter | AppUpdateSetMaskCanvas | AppUpdateStartIter | AppUpdateStopped 56 | /** Reducer behind useAppState: applies `action` to a shallow copy of `state` and returns it. It must never mutate `state` or the arrays it references (React reducers must be pure). Throws on an unhandled action type. */ 57 | function updateAppState(state: AppState, action: AppUpdateAction) { 58 | const newState = { ...state } 59 | 60 | switch (action.type) { 61 | case "reset": 62 | newState.images = [] 63 | newState.shouldRun = false 64 | newState.iteration = 0 65 | newState.algorithmSettings = { 66 | filters: 8, 67 | layers: 5, 68 | width: 256, 69 | height: 256, 70 | inpaint: false, 71 | } 72 | break 73 | case "start": 74 | newState.shouldRun = true 75 | break 76 | case "pause": 77 | newState.shouldRun = false 78 | break 79 | case "algorithmSettings": 80 | newState.algorithmSettings = action.newSettings 81 | newState.images = [] 82 | newState.iteration = 0 83 | break 84 | case "setSourceImage": 85 | newState.sourceImage = action.image 86 | newState.images = [] 87 | newState.iteration = 0 88 | break 89 | case "finishIter": 90 | if (newState.shouldRun) { 91 | newState.iteration += 1 92 | if (action.imageData) { 93 | newState.images = [...newState.images, action.imageData] // copy: after the shallow spread, push() would mutate the previous state's array 94 | } 95 | } 96 | newState.step = "finishedIter" 97 | break 98 | case "stopped": 99 | newState.step = "idle" 100 | newState.iteration = 0 101 | break 102 | case "startIter": 103 | newState.step = "runIter" 104 | break 105 | case
"setMaskCanvas": 106 | newState.maskCanvas = action.maskCanvas 107 | break 108 | default: 109 | throw new Error("Unhandled action in state update: " + JSON.stringify(action)) 110 | } 111 | 112 | return newState 113 | } 114 | /** App state hook: returns useReducer's [state, dispatch] pair. It starts idle and not running, with no images, source image or mask canvas, and default settings of 8 filters, 5 layers, 256x256 and inpainting off. */ 115 | export function useAppState() { 116 | return useReducer(updateAppState, { 117 | step: "idle", 118 | shouldRun: false, 119 | iteration: 0, 120 | images: [], 121 | algorithmSettings: { 122 | filters: 8, 123 | layers: 5, 124 | width: 256, 125 | height: 256, 126 | inpaint: false, 127 | }, 128 | sourceImage: null, 129 | maskCanvas: null, 130 | }) 131 | } 132 | -------------------------------------------------------------------------------- /src/serviceWorker.ts: -------------------------------------------------------------------------------- 1 | // This optional code is used to register a service worker. 2 | // This app calls register() from src/index.tsx; it only takes effect in production builds. 3 | 4 | // This lets the app load faster on subsequent visits in production, and gives 5 | // it offline capabilities. However, it also means that developers (and users) 6 | // will only see deployed updates on subsequent visits to a page, after all the 7 | // existing tabs open on the page have been closed, since previously cached 8 | // resources are updated in the background. 9 | 10 | // To learn more about the benefits of this model and instructions on how to 11 | // opt-in, read https://bit.ly/CRA-PWA 12 | 13 | const isLocalhost = Boolean( 14 | window.location.hostname === 'localhost' || 15 | // [::1] is the IPv6 localhost address. 16 | window.location.hostname === '[::1]' || 17 | // 127.0.0.1/8 is considered localhost for IPv4.
18 | window.location.hostname.match( 19 | /^127(?:\.(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)){3}$/ 20 | ) 21 | ); 22 | 23 | type Config = { 24 | onSuccess?: (registration: ServiceWorkerRegistration) => void; 25 | onUpdate?: (registration: ServiceWorkerRegistration) => void; 26 | }; 27 | 28 | export function register(config?: Config) { 29 | if (process.env.NODE_ENV === 'production' && 'serviceWorker' in navigator) { 30 | // The URL constructor is available in all browsers that support SW. 31 | const publicUrl = new URL( 32 | (process as { env: { [key: string]: string } }).env.PUBLIC_URL, 33 | window.location.href 34 | ); 35 | if (publicUrl.origin !== window.location.origin) { 36 | // Our service worker won't work if PUBLIC_URL is on a different origin 37 | // from what our page is served on. This might happen if a CDN is used to 38 | // serve assets; see https://github.com/facebook/create-react-app/issues/2374 39 | return; 40 | } 41 | 42 | window.addEventListener('load', () => { 43 | const swUrl = `${process.env.PUBLIC_URL}/service-worker.js`; 44 | 45 | if (isLocalhost) { 46 | // This is running on localhost. Let's check if a service worker still exists or not. 47 | checkValidServiceWorker(swUrl, config); 48 | 49 | // Add some additional logging to localhost, pointing developers to the 50 | // service worker/PWA documentation. 51 | navigator.serviceWorker.ready.then(() => { 52 | console.log( 53 | 'This web app is being served cache-first by a service ' + 54 | 'worker. To learn more, visit https://bit.ly/CRA-PWA' 55 | ); 56 | }); 57 | } else { 58 | // Is not localhost.
Just register service worker 59 | registerValidSW(swUrl, config); 60 | } 61 | }); 62 | } 63 | } 64 | 65 | function registerValidSW(swUrl: string, config?: Config) { 66 | navigator.serviceWorker 67 | .register(swUrl) 68 | .then(registration => { 69 | registration.onupdatefound = () => { 70 | const installingWorker = registration.installing; 71 | if (installingWorker == null) { 72 | return; 73 | } 74 | installingWorker.onstatechange = () => { 75 | if (installingWorker.state === 'installed') { 76 | if (navigator.serviceWorker.controller) { 77 | // At this point, the updated precached content has been fetched, 78 | // but the previous service worker will still serve the older 79 | // content until all client tabs are closed. 80 | console.log( 81 | 'New content is available and will be used when all ' + 82 | 'tabs for this page are closed. See https://bit.ly/CRA-PWA.' 83 | ); 84 | 85 | // Execute callback 86 | if (config && config.onUpdate) { 87 | config.onUpdate(registration); 88 | } 89 | } else { 90 | // At this point, everything has been precached. 91 | // It's the perfect time to display a 92 | // "Content is cached for offline use." message. 93 | console.log('Content is cached for offline use.'); 94 | 95 | // Execute callback 96 | if (config && config.onSuccess) { 97 | config.onSuccess(registration); 98 | } 99 | } 100 | } 101 | }; 102 | }; 103 | }) 104 | .catch(error => { 105 | console.error('Error during service worker registration:', error); 106 | }); 107 | } 108 | 109 | function checkValidServiceWorker(swUrl: string, config?: Config) { 110 | // Check if the service worker can be found. If it can't, reload the page. 111 | fetch(swUrl) 112 | .then(response => { 113 | // Ensure service worker exists, and that we really are getting a JS file. 114 | const contentType = response.headers.get('content-type'); 115 | if ( 116 | response.status === 404 || 117 | (contentType != null && contentType.indexOf('javascript') === -1) 118 | ) { 119 | // No service worker found.
Probably a different app. Reload the page. 120 | navigator.serviceWorker.ready.then(registration => { 121 | registration.unregister().then(() => { 122 | window.location.reload(); 123 | }); 124 | }); 125 | } else { 126 | // Service worker found. Proceed as normal. 127 | registerValidSW(swUrl, config); 128 | } 129 | }) 130 | .catch(() => { 131 | console.log( 132 | 'No internet connection found. App is running in offline mode.' 133 | ); 134 | }); 135 | } 136 | 137 | export function unregister() { 138 | if ('serviceWorker' in navigator) { 139 | navigator.serviceWorker.ready.then(registration => { 140 | registration.unregister(); 141 | }); 142 | } 143 | } 144 | -------------------------------------------------------------------------------- /src/Painter.tsx: -------------------------------------------------------------------------------- 1 | import * as tf from "@tensorflow/tfjs" 2 | import React, { Dispatch, useEffect, useState, useMemo } from "react" 3 | import { createUNet } from "./models/UNet" 4 | import { AppState, AppUpdateAction } from "./AppState" 5 | import { LossOrMetricFn } from "@tensorflow/tfjs-layers/dist/types" 6 | 7 | tf.enableProdMode() 8 | 9 | type PainterProps = { 10 | state: AppState, 11 | dispatchState: Dispatch<AppUpdateAction> 12 | } 13 | /** Converts a flat RGBA pixel array laid out row by row (height x width x 4, as returned by getImageData) into a [1, width, height, 3] tensor, dropping alpha. tf.tidy frees the intermediate tensors. */ 14 | function imageTensorFromFlatArray(flat: number[], width: number, height: number) { 15 | return tf.tidy(() => tf.transpose(tf.tensor1d(flat).reshape([1, height, width, 4]).slice([0, 0, 0, 0], [1, height, width, 3]), [0, 2, 1, 3])) 16 | } 17 | /** Maps pixel values from [0, 255] to [-1, 1], the range of the network's tanh output; tf.tidy frees the intermediate quotient. */ 18 | function scaleImageTensor(x: tf.Tensor) { 19 | return tf.tidy(() => tf.sub(tf.div(x, 127.5), 1)) 20 | } 21 | /** Creates a canvas element of the given size without attaching it to the document. */ 22 | function createMemoryCanvas(width: number, height: number) { 23 | const canvas = document.createElement("canvas") 24 | canvas.width = width 25 | canvas.height = height 26 | return canvas 27 | } 28 | /** Writes a [1, width, height, 3] array into ctx as opaque pixels, mapping each value v to 127.5 * (1 + v) clamped to [0, 255] (the inverse of scaleImageTensor). */ 29 | function drawImageTensor(ctx: CanvasRenderingContext2D, imageTensor: number[][][][]) { 30 | const [width, height] = [ctx.canvas.width, ctx.canvas.height] 31 | 32 | const imageData = 
ctx.createImageData(width, height) 33 | 34 | for (let x = 0; x < width; x++) { 35 | for (let y = 0; y < height; y++) { 36 | const i = x + y * width 37 | imageData.data[i * 4 + 0] = Math.min(255, Math.max(0, 127.5 * (1 + imageTensor[0][x][y][0]))) 38 | imageData.data[i * 4 + 1] = Math.min(255, Math.max(0, 127.5 * (1 + imageTensor[0][x][y][1]))) 39 | imageData.data[i * 4 + 2] = Math.min(255, Math.max(0, 127.5 * (1 + imageTensor[0][x][y][2]))) 40 | imageData.data[i * 4 + 3] = 255 41 | } 42 | } 43 | 44 | ctx.putImageData(imageData, 0, 0) 45 | } 46 | /** Runs the deep-image-prior optimisation. On each "runIter" step it fits the network for 20 epochs to map fixed random noise onto the source image (with a mask-weighted loss when inpainting), draws the prediction and reports it via "finishIter". The model and tensors persist across iterations and are disposed once an iteration finishes with shouldRun false. */ 47 | export function Painter(props: PainterProps) { 48 | const { state, dispatchState } = props 49 | 50 | const [model, setModel] = useState<tf.LayersModel | null>(null) 51 | const [noise, setNoise] = useState<tf.Tensor | null>(null) 52 | const [imageTensor, setImageTensor] = useState<tf.Tensor | null>(null) 53 | const [maskTensor, setMaskTensor] = useState<tf.Tensor | null>(null) // inpainting loss weights, kept so they can be disposed with the model 54 | const canvas = useMemo(() => { 55 | return createMemoryCanvas(state.algorithmSettings.width, state.algorithmSettings.height) 56 | }, [state.algorithmSettings.width, state.algorithmSettings.height]) 57 | // One training iteration per "runIter" step; model, noise and target are built when missing (the first iteration after a stop). 58 | useEffect(() => { 59 | if (state.step === "runIter") { 60 | let m = model 61 | let n = noise 62 | let it = imageTensor 63 | 64 | if (m === null || n === null || it === null) { 65 | const noiseShape: [number, number, number] = [state.algorithmSettings.width, state.algorithmSettings.height, 1] 66 | const outputFilters = 3 67 | 68 | m = createUNet(noiseShape, outputFilters, state.algorithmSettings.layers, state.algorithmSettings.filters, !state.algorithmSettings.inpaint) 69 | 70 | n = tf.randomNormal([1].concat(noiseShape)) 71 | // The outer tidy also frees the unscaled image tensor, which scaleImageTensor does not own. 72 | it = tf.tidy(() => scaleImageTensor(imageTensorFromFlatArray(state.sourceImage!, state.algorithmSettings.width, state.algorithmSettings.height))) 73 | 74 | setModel(m) 75 | setNoise(n) 76 | setImageTensor(it) 77 | 78 | let loss: string | LossOrMetricFn = "meanAbsoluteError" 79 | 80 | if (state.algorithmSettings.inpaint) { 81 | const mask = Array.from(state.maskCanvas!.getContext("2d")!.getImageData(0, 0, state.algorithmSettings.width, 
state.algorithmSettings.height).data) 82 | const mt = tf.tidy(() => tf.div(imageTensorFromFlatArray(mask, state.algorithmSettings.width, state.algorithmSettings.height), 255)) 83 | setMaskTensor(mt) 84 | // Mask values in [0, 1] weight the per-pixel error, so black (0) strokes are excluded from the loss. 85 | loss = (x: tf.Tensor, y: tf.Tensor) => tf.losses.absoluteDifference(x, y, mt) 86 | } 87 | 88 | m.compile({ 89 | optimizer: "adam", 90 | loss: loss, 91 | }) 92 | } 93 | // Fire-and-forget: every failure is caught and still reported as "finishIter" so the step cycle continues. 94 | (async () => { 95 | try { 96 | await m.fit(n, it, { 97 | batchSize: 1, 98 | epochs: 20, 99 | }) 100 | // predict() allocates a new tensor every iteration; dispose it once its values are copied out, or backend (e.g. WebGL) memory leaks on every iteration. 101 | const prediction = m.predict(n) as tf.Tensor 102 | let output: number[][][][] 103 | try { output = await prediction.array() as number[][][][] } finally { prediction.dispose() } 104 | drawImageTensor(canvas.getContext("2d")!, output) 105 | dispatchState({ 106 | type: "finishIter", 107 | imageData: { 108 | iteration: state.iteration, 109 | uri: canvas.toDataURL("image/png") 110 | } 111 | }) 112 | } 113 | catch (e) { 114 | console.log(`Exception when running model: ${e}`) 115 | 116 | dispatchState({ 117 | type: "finishIter", 118 | imageData: undefined 119 | }) 120 | } 121 | })() 122 | } 123 | // eslint-disable-next-line react-hooks/exhaustive-deps 124 | }, [state.step]) 125 | // Once an iteration finishes after the user stopped, release the model and the tensors held in state. 126 | useEffect(() => { 127 | if (state.step === "finishedIter" && !state.shouldRun) { 128 | if (model !== null) { 129 | model.dispose() 130 | } 131 | 132 | if (imageTensor !== null) { 133 | imageTensor.dispose() 134 | } 135 | 136 | if (noise !== null) { 137 | noise.dispose() 138 | } 139 | if (maskTensor !== null) { maskTensor.dispose() } 140 | setModel(null) 141 | setImageTensor(null) 142 | setNoise(null) 143 | setMaskTensor(null) 144 | dispatchState({ 145 | type: "stopped" 146 | }) 147 | } 148 | // eslint-disable-next-line react-hooks/exhaustive-deps 149 | }, [state.shouldRun, state.step]) 150 | 151 | return 
152 | } -------------------------------------------------------------------------------- /src/App.tsx: -------------------------------------------------------------------------------- 1 | import React, { useEffect, useState, useMemo, CSSProperties } from 'react' 2 | import './App.css' 3 | import { Painter } from './Painter' 4 | import { Row, Col, Container, Button, Navbar, Nav, Image as BSImage } from "react-bootstrap" 5 | import 'bootstrap/dist/css/bootstrap.css' 6 | import 'rc-slider/assets/index.css' 7 | import { useAppState, ImageData } from './AppState' 8 | import { useDropzone } from 'react-dropzone' 9 | import ReactCompareImage from "react-compare-image" 10 | import DrawableCanvas from './DrawableCanvas' 11 | import LabeledSlider from './LabeledSlider' 12 | import LabeledCheckbox from './LabeledCheckbox' 13 | 14 | import sampleImage1 from "./sample-images/car_inpaint.png" 15 | import sampleImage2 from "./sample-images/lenna_noisy.png" 16 | import sampleImage3 from "./sample-images/cat_text.png" 17 | 18 | const App: React.FC = () => { 19 | const [state, dispatchState] = useAppState() 20 | 21 | function setWidth(value: number) { 22 | dispatchState({ 23 | type: "algorithmSettings", 24 | newSettings: { 25 | ...state.algorithmSettings, 26 | width: value 27 | } 28 | }) 29 | } 30 | 31 | function setHeight(value: number) { 32 | dispatchState({ 33 | type: "algorithmSettings", 34 | newSettings: { 35 | ...state.algorithmSettings, 36 | height: value 37 | } 38 | }) 39 | } 40 | 41 | function setLayers(value: number) { 42 | dispatchState({ 43 | type: "algorithmSettings", 44 | newSettings: { 45 | ...state.algorithmSettings, 46 | layers: value 47 | } 48 | }) 49 | } 50 | 51 | function setFilters(value: number) { 52 | dispatchState({ 53 | type: "algorithmSettings", 54 | newSettings: { 55 | ...state.algorithmSettings, 56 | filters: value 57 | } 58 | }) 59 | } 60 | 61 | function setInpaint(value: boolean) { 62 | dispatchState({ 63 | type: "algorithmSettings", 64 | newSettings: {
65 | ...state.algorithmSettings, 66 | inpaint: value 67 | } 68 | }) 69 | } 70 | 71 | useEffect(() => { 72 | if (state.shouldRun && state.step !== "runIter") { 73 | dispatchState({ type: "startIter" }) 74 | } 75 | }, [state.step, state.shouldRun, dispatchState]) 76 | 77 | const canvas = useMemo(() => { 78 | const cnv = document.createElement("canvas") 79 | cnv.width = state.algorithmSettings.width 80 | cnv.height = state.algorithmSettings.height 81 | return cnv 82 | }, [state.algorithmSettings.width, state.algorithmSettings.height]) 83 | 84 | const [selectedImage, setSelectedImage] = useState(null) 85 | 86 | useEffect(() => { 87 | if (selectedImage !== null) { 88 | const context = canvas.getContext("2d")! 89 | context.drawImage(selectedImage, 0, 0, state.algorithmSettings.width, state.algorithmSettings.height) 90 | const imageData = context.getImageData(0, 0, state.algorithmSettings.width, state.algorithmSettings.height).data 91 | dispatchState({ 92 | type: "setSourceImage", 93 | image: Array.from(imageData) 94 | }) 95 | } 96 | }, [selectedImage, state.algorithmSettings.width, state.algorithmSettings.height, dispatchState, canvas]) 97 | 98 | function onImageSelected(files: File[]) { 99 | setSelectedImage(null) 100 | 101 | const image = new Image() 102 | 103 | image.onload = function (evt: any) { 104 | setSelectedImage(image) 105 | } 106 | 107 | const file = files[0] 108 | const reader = new FileReader() 109 | 110 | reader.onload = function (evt: any) { 111 | if (evt.target.readyState === FileReader.DONE) { 112 | image.src = evt.target.result 113 | } 114 | } 115 | 116 | reader.readAsDataURL(file) 117 | } 118 | 119 | const statusText = useMemo(() => { 120 | if (state.step !== "idle" && state.shouldRun) { 121 | return `Running, iteration: ${state.iteration}.` 122 | } else if (state.step === "idle" && state.shouldRun) { 123 | return "Starting, your browser might freeze for a while..." 
124 | } else if (state.step !== "idle" && !state.shouldRun) { 125 | return "Stopping..." 126 | } else if (state.step === "idle" && !state.shouldRun && !state.sourceImage) { 127 | return "Choose an image." 128 | } else if (state.step === "idle" && !state.shouldRun) { 129 | return "Click start" 130 | } 131 | 132 | }, [state.step, state.shouldRun, state.sourceImage, state.iteration]) 133 | 134 | const displayedImage = selectedImage ? selectedImage.src : "" 135 | 136 | type SettingsProps = { 137 | style: CSSProperties, 138 | disabled: boolean 139 | } 140 | 141 | const settingsProps: SettingsProps = { 142 | style: { 143 | textAlign: "center" 144 | }, 145 | disabled: state.shouldRun || state.step !== "idle" 146 | } 147 | 148 | const [comparisonImageUri, setComparisonImageUri] = useState("") 149 | 150 | const selectImage = (evt: React.MouseEvent) => { 151 | if (!settingsProps.disabled) { 152 | setSelectedImage(evt.target as HTMLImageElement) 153 | } 154 | } 155 | 156 | const { getRootProps, getInputProps } = useDropzone({ 157 | accept: "image/*", 158 | onDrop: onImageSelected, 159 | disabled: settingsProps.disabled 160 | }) 161 | 162 | const columnSizes = { 163 | xl: 3, 164 | l: 3, 165 | md: 4, 166 | sm: 6, 167 | xs: 8 168 | } 169 | 170 | const runButtonEnabled = !state.shouldRun && state.step === "idle" && state.maskCanvas !== null && state.sourceImage !== null 171 | const stopButtonEnabled = state.shouldRun && state.step !== "idle" 172 | 173 | return ( 174 |
175 | 176 | 177 | 178 | 179 | 180 |
181 | Deep Image Prior 182 |
183 |
184 | Implementation by Tora 185 |
186 |
187 | Source code 188 | Original project page 189 | Paper 190 |
191 |
192 | 193 | 194 | 195 | 196 | 197 |

{statusText}

198 |
199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 |
217 | 218 |

Click to select an image

219 |
220 | 221 |
222 | 223 | 224 |

Settings

225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 |
234 | 235 | 236 | 237 |
238 | 239 |
240 | 241 |
242 |
243 | 244 | 245 |
246 | 247 | 248 | 249 | 250 | 251 | 252 | 253 | 254 | {state.images.map((image: ImageData) => 255 | {image.uri} setComparisonImageUri(image.uri)} /> 256 | )} 257 | 258 |
259 |
260 | ); 261 | } 262 | 263 | export default App 264 | --------------------------------------------------------------------------------