41 | );
42 | }
43 | }
44 |
--------------------------------------------------------------------------------
/web/client/src/pages/App.js:
--------------------------------------------------------------------------------
1 | import React, { Component } from "react";
2 | import AuthPage from "./auth/AuthPage";
3 | import SelectPage from "./select/SelectPage";
4 | import EncodePage from "./encode/EncodePage";
5 | import DecodePage from "./decode/DecodePage";
6 | import "./global.css";
7 | import Router from "react-router-dom/BrowserRouter";
8 | import { AnimatedSwitch, spring } from "react-router-transition";
9 | import Route from "react-router-dom/Route";
10 | import ImageTrace from "../components/ImageTrace/ImageTrace";
11 |
12 |
13 | function mapStyles(styles) {
14 | return {
15 | opacity: styles.opacity,
16 | transform: `scale(${styles.scale})`
17 | };
18 | }
19 |
20 | // wrap the `spring` helper to use a bouncy config
21 | function bounce(val) {
22 | return spring(val, {
23 | stiffness: 330,
24 | damping: 22
25 | });
26 | }
27 |
28 | class App extends Component {
29 | state = {
30 | page: "select"
31 | };
32 |
33 | render() {
34 | return (
35 |
36 |
37 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 | );
62 | }
63 | }
64 |
65 | export default App;
66 |
--------------------------------------------------------------------------------
/optimizers.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import random, sys, os, json, glob, math
4 |
5 | import IPython
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import torch.optim as optim
11 | from torch.autograd import Variable
12 |
13 |
14 | class Optimizer(nn.Module):
15 |     def __init__(self, parameters):
16 |         super().__init__()
17 |         self.parameters = list(parameters)  # materialize, in case a generator is passed
18 |
19 | def step(self, loss):
20 | loss.backward(create_graph=True, retain_graph=True)
21 | for param in self.parameters:
22 | update = self.forward(param.grad, param)
23 | yield (param + update)
24 |
25 | def forward(self, grad, param=None):
26 | raise NotImplementedError()
27 |
28 |
29 | class Adam(Optimizer):
30 | def __init__(self, parameters, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, differentiable=False):
31 |
32 | super().__init__(parameters)
33 |
34 | self.lr = lr
35 | self.betas = betas
36 | self.eps = eps
37 |         self.state = {param: {} for param in self.parameters}
38 |         self.differentiable = differentiable  # if True, keep the autograd graph on updates (unrolled optimization)
39 |
40 | def forward(self, grad, param=None):
41 |
42 | state = self.state[param]
43 | step = state["step"] = state.get("step", 0) + 1
44 | exp_avg = state["exp_avg"] = state.get("exp_avg", torch.zeros_like(grad.data))
45 | exp_avg_sq = state["exp_avg_sq"] = state.get("exp_avg_sq", torch.zeros_like(grad.data))
46 | beta1, beta2 = self.betas
47 |
48 | exp_avg = exp_avg * beta1 + (1 - beta1) * (grad)
49 | exp_avg_sq = exp_avg_sq * (beta2) + (1 - beta2) * (grad) * (grad)
50 | denom = exp_avg_sq.sqrt() + self.eps
51 |
52 | bias_correction1 = 1 - beta1 ** state["step"]
53 | bias_correction2 = 1 - beta2 ** state["step"]
54 |
55 | step_size = self.lr * math.sqrt(bias_correction2) / bias_correction1
56 | update = -step_size * exp_avg / denom
57 |
58 |         state["exp_avg"] = exp_avg.detach()
59 |         state["exp_avg_sq"] = exp_avg_sq.detach()
60 |
61 | if not self.differentiable:
62 | update = update.data
63 |
64 | return update
65 |
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 |
2 | import random, sys, os, glob, yaml, time
3 | import argparse, subprocess, shutil, shlex
4 | from fire import Fire
5 | from utils import elapsed
6 |
7 | import IPython
8 |
9 |
10 | def execute(cmd, mode="experiment", config="default", shutdown=False, debug=False):
11 |
12 | elapsed()
13 |     try:
14 |         run_log = yaml.safe_load(open("jobs/runlog.yml")) or {}
15 |     except (FileNotFoundError, yaml.YAMLError):
16 |         run_log = {}
17 |
18 | run_data = run_log[mode] = run_log.get(mode, {})
19 | run_data["runs"] = run_data.get("runs", 0) + 1
20 | run_name = mode + str(run_data["runs"])
21 | run_data[run_name] = run_data.get(run_name, {"config": config, "cmd": cmd, "status": "Running"})
22 | run_data = run_data[run_name]
23 |
24 | print(f"Running job: {run_name}")
25 |
26 | shutil.rmtree("output/", ignore_errors=True)
27 | os.makedirs("output/")
28 | os.makedirs(f"jobs/{run_name}", exist_ok=True)
29 |
30 | with open("jobs/jobinfo.txt", "w") as config_file:
31 | print(config, file=config_file)
32 |
33 | cmd = shlex.split(cmd)
34 | if cmd[0] == "python" and debug:
35 | cmd[0] = "ipython"
36 | cmd.insert(1, "-i")
37 | elif cmd[0] == "python":
38 | cmd.insert(1, "-u")
39 |
40 | print(" ".join(cmd))
41 | process = subprocess.Popen(cmd, shell=False, stdout=subprocess.PIPE, universal_newlines=True)
42 |
43 | try:
44 | with open(f"jobs/{run_name}/stdout.txt", "w") as outfile:
45 | for stdout_line in iter(process.stdout.readline, ""):
46 | print(stdout_line, end="")
47 | outfile.write(stdout_line)
48 |
49 | return_code = process.wait()
50 | run_data["status"] = "Error" if return_code else "Complete"
51 | except KeyboardInterrupt:
52 | print("\nKilled by user.")
53 | process.kill()
54 | run_data["status"] = "Killed"
55 | except OSError:
56 | print("\nSystem error.")
57 | process.kill()
58 | run_data["status"] = "Error"
59 |
60 | process.kill()
61 |
62 | if debug and run_data["status"] != "Complete":
63 | return
64 |
65 | shutil.copytree("output", f"jobs/{run_name}/output")
66 | for file in glob.glob("*.py"):
67 | shutil.copy(file, f"jobs/{run_name}")
68 |
69 | yaml.safe_dump(run_log, open("jobs/runlog.yml", "w"), allow_unicode=True, default_flow_style=False)
70 | yaml.safe_dump(run_data, open(f"jobs/{run_name}/comments.yml", "w"), allow_unicode=True, default_flow_style=False)
71 |
72 | interval = elapsed()
73 | print(f"Program ended after {interval:0.4f} seconds.")
74 | if shutdown and run_data["status"] != "Killed" and interval > 60:
75 |         print("Shutting down in 1 minute.")
76 | time.sleep(60)
77 | subprocess.call("sudo shutdown -h now", shell=True)
78 |
79 |
80 | def run(cmd, mode="experiment", config="default", shutdown=False, debug=False):
81 | cmd = f""" screen -S {config} bash -c "python run.py execute \\"{cmd}\\" --mode {mode} --config {config} --shutdown {shutdown} --debug {debug}" """
82 | subprocess.call(shlex.split(cmd))
83 |
84 |
85 | if __name__ == "__main__":
86 | Fire({"run": run, "execute": execute})
87 |
--------------------------------------------------------------------------------
/old/train_dsc.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import random, sys, os, json, glob
4 | import tqdm, itertools, shutil
5 |
6 | import matplotlib as mpl
7 |
8 | mpl.use("Agg")
9 | import matplotlib.pyplot as plt
10 |
11 | import torch
12 |
13 | torch.backends.cudnn.benchmark = True
14 | import torch.nn as nn
15 | import torch.nn.functional as F
16 | import torch.optim as optim
17 | from torch.autograd import Variable
18 |
19 | from utils import *
20 | import transforms
21 | from modules import UNet
22 | from logger import Logger
23 |
24 | from sklearn.metrics import roc_auc_score
25 | from scipy.stats import pearsonr
26 | import IPython
27 |
28 | DATA_PATH = "data/encode_120"
29 | logger = Logger("train_dsc", ("loss", "corr"), print_every=5, plot_every=20)
30 |
31 |
32 | def loss_func(model, x, y):
33 | cleaned = model.forward(x)
34 | corr, p = pearsonr(
35 | cleaned.data.cpu().numpy().flatten(), y.data.cpu().numpy().flatten()
36 | )
37 | return (cleaned - y).pow(2).sum(), corr
38 |
39 |
40 | def data_gen(files, batch_size=64):
41 | while True:
42 | enc_files = random.sample(files, batch_size)
43 | orig_files = [f.replace("encoded", "original") for f in enc_files]
45 | encoded_ims = [im.load(image) for image in enc_files]
46 | original_ims = [im.load(image) for image in orig_files]
47 | encoded, original = im.stack(encoded_ims), im.stack(original_ims)
48 |
49 | yield encoded, (encoded - original)
50 |
51 |
52 | def viz_preds(model, x, y):
53 | preds = model(x)
54 | for i, (pred, truth, enc) in enumerate(zip(preds, y, x)):
55 | im.save(im.numpy(enc), f"{OUTPUT_DIR}{i}_encoded.jpg")
56 | im.save(3 * np.abs(im.numpy(pred)), f"{OUTPUT_DIR}{i}_pred.jpg")
57 | im.save(3 * np.abs(im.numpy(truth)), f"{OUTPUT_DIR}{i}_truth.jpg")
58 |
59 |
60 | if __name__ == "__main__":
61 |
62 | model = nn.DataParallel(UNet())
63 | model.train()
64 |
65 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
66 |
67 | # optimizer.load_state_dict('output/unet_opt.pth')
68 | model.module.load("jobs/experiment_unet/output/train_unet.pth")
69 |
70 | logger.add_hook(
71 | lambda: [
72 | print(f"Saving model/opt to {OUTPUT_DIR}train_unet.pth"),
73 | model.module.save(OUTPUT_DIR + "train_unet.pth"),
74 | torch.save(optimizer.state_dict(), OUTPUT_DIR + "unet_opt.pth"),
75 | ],
76 | freq=100,
77 | )
78 |
79 | files = glob.glob(f"{DATA_PATH}/*encoded*.jpg")
80 | train_files, val_files = files[:-128], files[-128:]
81 | x_val, y_val = next(data_gen(val_files, 128))
82 |
83 | for i, (x, y) in enumerate(data_gen(train_files, 128)):
84 | loss, corr = loss_func(model, x, y)
85 |
86 | logger.step("loss", min(5000, loss))
87 | logger.step("corr", corr)
88 |
89 | optimizer.zero_grad()
90 | loss.backward()
91 | optimizer.step()
92 |
93 | if i % 20 == 0:
94 | model.eval()
95 |             val_loss, val_corr = loss_func(model, x_val, y_val)  # loss_func returns (loss, corr)
96 |             model.train()
97 |             print(f"val_loss = {val_loss:.4f}, val_corr = {val_corr:.4f}")
98 |
99 | if i == 2000:
100 | break
101 |
--------------------------------------------------------------------------------
/old/train_unet.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import random, sys, os, json, glob
4 | import tqdm, itertools, shutil
5 |
6 | import matplotlib as mpl
7 |
8 | mpl.use("Agg")
9 | import matplotlib.pyplot as plt
10 |
11 | import torch
12 |
13 | torch.backends.cudnn.benchmark = True
14 | import torch.nn as nn
15 | import torch.nn.functional as F
16 | import torch.optim as optim
17 | from torch.autograd import Variable
18 |
19 | from utils import *
20 | import transforms
21 | from models import UNet
22 | from logger import Logger
23 |
24 | from sklearn.metrics import roc_auc_score
25 | from scipy.stats import pearsonr
26 | import IPython
27 |
28 | DATA_PATH = "data/encode_120"
29 | logger = Logger("train_dsc", ("loss", "corr"), print_every=5, plot_every=20)
30 |
31 |
32 | def loss_func(model, x, y):
33 | cleaned = model.forward(x)
34 | corr, p = pearsonr(
35 | cleaned.data.cpu().numpy().flatten(), y.data.cpu().numpy().flatten()
36 | )
37 | return (cleaned - y).pow(2).sum(), corr
38 |
39 |
40 | def data_gen(files, batch_size=64):
41 | while True:
42 | enc_files = random.sample(files, batch_size)
43 | orig_files = [f.replace("encoded", "original") for f in enc_files]
44 | encoded_ims = [im.load(image) for image in enc_files]
45 | original_ims = [im.load(image) for image in orig_files]
46 | encoded, original = im.stack(encoded_ims), im.stack(original_ims)
47 |
48 | yield original, (encoded - original)
49 |
50 |
51 | def viz_preds(model, x, y):
52 | preds = model(x)
53 | for i, (pred, truth, enc) in enumerate(zip(preds, y, x)):
54 | im.save(im.numpy(enc + truth), f"{OUTPUT_DIR}{i}_encoded.jpg")
55 | im.save(3 * np.abs(im.numpy(pred)), f"{OUTPUT_DIR}{i}_pred.jpg")
56 | im.save(3 * np.abs(im.numpy(truth)), f"{OUTPUT_DIR}{i}_truth.jpg")
57 |
58 |
59 | if __name__ == "__main__":
60 |
61 | model = nn.DataParallel(UNet())
62 | model.train()
63 |
64 | optimizer = torch.optim.Adam(model.parameters(), lr=3e-3)
65 |
66 | # optimizer.load_state_dict('output/unet_opt.pth')
67 | model.module.load("output/train_unet.pth")
68 |
69 | logger.add_hook(
70 | lambda: [
71 | print(f"Saving model/opt to {OUTPUT_DIR}train_unet.pth"),
72 | model.module.save(OUTPUT_DIR + "train_unet.pth"),
73 | torch.save(optimizer.state_dict(), OUTPUT_DIR + "unet_opt.pth"),
74 | ],
75 | freq=100,
76 | )
77 |
78 | files = glob.glob(f"{DATA_PATH}/*encoded*.jpg")
79 | train_files, val_files = files[:-142], files[-142:]
80 | x_val, y_val = next(data_gen(val_files, 142))
81 |
82 | for i, (x, y) in enumerate(data_gen(train_files, 142)):
83 | loss, corr = loss_func(model, x, y)
84 |
85 | logger.step("loss", min(5000, loss))
86 | logger.step("corr", corr)
87 |
88 | optimizer.zero_grad()
89 | loss.backward()
90 | optimizer.step()
91 |
92 | if i % 50 == 0:
93 | model.eval()
94 |             val_loss, val_corr = loss_func(model, x_val, y_val)  # loss_func returns (loss, corr)
95 |             viz_preds(model, x_val[:8], y_val[:8])
96 |             model.train()
97 |             print(f"val_loss = {val_loss:.4f}, val_corr = {val_corr:.4f}")
98 |
99 | if i == 2000:
100 | break
101 |
--------------------------------------------------------------------------------
/old/losses.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import numpy as np
4 | import random, sys, os, json
5 |
6 | import matplotlib as mpl
7 |
8 | mpl.use("Agg")
9 | import matplotlib.pyplot as plt
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | import torch.optim as optim
15 | from torch.autograd import Variable
16 |
17 | from torchvision import models
18 | from utils import *
19 |
20 | from skimage import filters
21 | from skimage.morphology import binary_dilation
22 |
23 | import IPython
24 |
25 | import transforms
26 |
27 |
28 | def getFeatures(model, x):
29 | x = torch.cat(
30 | [
31 | ((x[0] - 0.485) / (0.229)).unsqueeze(0),
32 | ((x[1] - 0.456) / (0.224)).unsqueeze(0),
33 | ((x[2] - 0.406) / (0.225)).unsqueeze(0),
34 | ],
35 | dim=0,
36 | )
37 | x = transforms.identity(x).unsqueeze(0)
38 |
39 | features = []
40 | prev_feat = x
41 |
42 | for i, module in enumerate(model.features.features._modules.values()):
43 | next_feat = module(prev_feat)
44 | features.append(next_feat)
45 | prev_feat = next_feat
46 |
47 | return features
48 |
49 |
50 | def gram_matrix(features, normalize=True):
51 | N, C, H, W = features.shape
52 | featuresReshaped = features.reshape((N, C, H * W))
53 | featuresTranspose = featuresReshaped.permute(0, 2, 1)
54 | ans = torch.matmul(featuresReshaped, featuresTranspose)
55 | if normalize:
56 | ans /= H * W * C
57 | return ans
58 |
59 |
60 | # features_list is a list of layers of model to calculate content at
61 | # weights_list is a list of corresponding weights for features_list
62 | def content_loss(model, features_list, weights_list, changed_image, original_image):
63 |
64 |     original_features = getFeatures(model, original_image)
65 |     changed_features = getFeatures(model, changed_image)
66 |
67 | loss = 0
68 |
69 | for i, layer_number in enumerate(features_list):
70 | activation_changed = changed_features[layer_number]
71 | activation_original = original_features[layer_number]
72 |
73 | N, C_1, H_1, W_1 = activation_changed.shape
74 | F_ij = activation_changed.reshape((C_1, H_1 * W_1))
75 | P_ij = activation_original.reshape((C_1, H_1 * W_1))
76 | loss += weights_list[i] * (((F_ij - P_ij).norm(2)) ** 2)
77 |
78 | return loss
79 |
90 |
91 | # features_list is a list of layers of model to calculate style at
92 | # weights_list is a list of corresponding weights for features_list
93 | def style_loss(model, features_list, weights_list, changed_image, original_image):
94 |
95 |     original_features = getFeatures(model, original_image)
96 |     changed_features = getFeatures(model, changed_image)
97 |
98 | loss = 0
99 |
100 | for i, layer_number in enumerate(features_list):
101 | activation_changed = changed_features[layer_number]
102 | activation_original = original_features[layer_number]
103 | changed_gram = gram_matrix(activation_changed)
104 | original_gram = gram_matrix(activation_original)
105 | loss += weights_list[i] * (((changed_gram - original_gram).norm(2)) ** 2)
106 |
107 | return loss
108 |
--------------------------------------------------------------------------------
/train_amnesia.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import random, sys, os, json, glob
4 | import tqdm, itertools, shutil
5 |
6 | import matplotlib as mpl
7 |
8 | mpl.use("Agg")
9 | import matplotlib.pyplot as plt
10 |
11 | import torch
12 |
13 | torch.backends.cudnn.benchmark = True
14 | import torch.nn as nn
15 | import torch.nn.functional as F
16 | import torch.optim as optim
17 | from torch.autograd import Variable
18 |
19 | from utils import *
20 | import transforms
21 | from encoding import encode_binary
22 | from models import DecodingModel, DataParallelModel
23 | from logger import Logger, VisdomLogger
24 |
25 | from skimage.morphology import binary_dilation
26 | import IPython
27 |
28 | from testing import test_transforms
29 |
30 |
31 | def loss_func(model, x, targets):
32 | scores = model.forward(x)
33 | predictions = scores.mean(dim=1)
34 | score_targets = binary.target(targets).unsqueeze(1).expand_as(scores)
35 |
36 | return (F.binary_cross_entropy(scores, score_targets), predictions.cpu().data.numpy().round(2))
37 |
38 |
39 | def init_data(output_path, n=None):
40 |
41 | shutil.rmtree(output_path)
42 | os.makedirs(output_path)
43 |
44 | image_files = TRAIN_FILES
45 | if n is not None:
46 | image_files = image_files[0:n]
47 |
48 | for k, files in tqdm.tqdm(list(enumerate(batch(image_files, batch_size=BATCH_SIZE))), ncols=50):
49 |
50 | images = im.stack([im.load(img_file) for img_file in files]).detach()
51 | perturbation = nn.Parameter(0.03 * torch.randn(images.size()).to(DEVICE) + 0.0)
52 | targets = [binary.random(n=TARGET_SIZE) for i in range(len(images))]
53 | torch.save((perturbation.data, images.data, targets), f"{output_path}/{k}.pth")
54 |
55 |
56 | if __name__ == "__main__":
57 |
58 | model = DataParallelModel(DecodingModel(n=DIST_SIZE, distribution=transforms.training))
59 | params = itertools.chain(model.module.classifier.parameters(), model.module.features[-1].parameters())
60 | optimizer = torch.optim.Adam(params, lr=2.5e-3)
61 | init_data("data/amnesia")
62 |
63 | logger = VisdomLogger("train", server="35.230.67.129", port=8000, env=JOB)
64 | logger.add_hook(lambda x: logger.step(), feature="epoch", freq=20)
65 | logger.add_hook(lambda data: logger.plot(data, "train_loss"), feature="loss", freq=50)
66 | logger.add_hook(lambda data: logger.plot(data, "train_bits"), feature="bits", freq=50)
67 | logger.add_hook(lambda x: model.save("output/train_test.pth", verbose=True), feature="epoch", freq=100)
68 | model.save("output/train_test.pth", verbose=True)
69 |
70 |     files = glob.glob("data/amnesia/*.pth")
71 |     for i, save_file in enumerate(random.choice(files) for _ in range(0, 2701)):
72 |
73 | perturbation, images, targets = torch.load(save_file)
74 | perturbation = perturbation.requires_grad_()
75 |
77 | encoded_ims, perturbation = encode_binary(
78 | images, targets, model, max_iter=1, perturbation=perturbation, use_weighting=True
79 | )
80 |
81 | loss, predictions = loss_func(model, encoded_ims, targets)
82 | error = np.mean([binary.distance(x, y) for x, y in zip(predictions, targets)])
83 |
84 | logger.update("epoch", i)
85 | logger.update("loss", loss)
86 | logger.update("bits", error)
87 |
88 | loss.backward()
89 | optimizer.step()
90 | optimizer.zero_grad()
91 |
92 | torch.save((perturbation.data, images.data, targets), save_file)
93 |
94 | if i != 0 and i % 300 == 0:
95 |
96 | model.save("output/train_test.pth")
97 | model2 = DataParallelModel(
98 | DecodingModel.load(distribution=transforms.training, n=DIST_SIZE, weights_file="output/train_test.pth")
99 | )
100 | # test_transforms(model, random.sample(TRAIN_FILES, 16), name=f'iter{i}_train')
101 | test_transforms(model2, VAL_FILES, name=f"iter{i}_test", max_iter=300)
102 |
--------------------------------------------------------------------------------
/old/modules.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import print_function
3 |
4 | import numpy as np
5 | import random, sys, os, json
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import torch.optim as optim
11 | from torch.autograd import Variable
12 |
13 | from torchvision import models
14 | from utils import *
15 | import transforms
16 |
17 | import IPython
18 |
19 |
20 | class UNet_down_block(nn.Module):
21 | def __init__(self, input_channel, output_channel, down_size=True):
22 | super(UNet_down_block, self).__init__()
23 | self.conv1 = nn.Conv2d(input_channel, output_channel, 3, padding=1)
24 | self.bn1 = nn.BatchNorm2d(output_channel)
25 | self.conv2 = nn.Conv2d(output_channel, output_channel, 3, padding=1)
26 | self.bn2 = nn.BatchNorm2d(output_channel)
27 | self.max_pool = nn.MaxPool2d(2, 2)
28 | self.relu = nn.ReLU()
29 | self.down_size = down_size
30 |
31 | def forward(self, x):
32 |
33 | x = self.relu(self.bn1(self.conv1(x)))
34 | x = self.relu(self.bn2(self.conv2(x)))
35 | if self.down_size:
36 | x = self.max_pool(x)
37 | return x
38 |
39 |
40 | class UNet_up_block(nn.Module):
41 | def __init__(self, prev_channel, input_channel, output_channel, up_sample=True):
42 | super(UNet_up_block, self).__init__()
43 |         self.up_sampling = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
44 | self.conv1 = nn.Conv2d(
45 | prev_channel + input_channel, output_channel, 3, padding=1
46 | )
47 | self.bn1 = nn.BatchNorm2d(output_channel)
48 | self.conv2 = nn.Conv2d(output_channel, output_channel, 3, padding=1)
49 | self.bn2 = nn.BatchNorm2d(output_channel)
50 | self.relu = torch.nn.ReLU()
51 | self.up_sample = up_sample
52 |
53 | def forward(self, prev_feature_map, x):
54 | if self.up_sample:
55 | x = self.up_sampling(x)
56 | x = torch.cat((x, prev_feature_map), dim=1)
57 | x = self.relu(self.bn1(self.conv1(x)))
58 | x = self.relu(self.bn2(self.conv2(x)))
59 | return x
60 |
61 |
62 | class UNet(nn.Module):
63 | def __init__(self):
64 | super(UNet, self).__init__()
65 |
66 | self.down_block1 = UNet_down_block(3, 16, False)
67 | self.down_block2 = UNet_down_block(16, 32, True)
68 | self.down_block3 = UNet_down_block(32, 64, True)
69 | self.down_block4 = UNet_down_block(64, 128, True)
70 | self.down_block5 = UNet_down_block(128, 256, True)
71 | # self.down_block6 = UNet_down_block(256, 512, True)
72 |
73 | self.mid_conv1 = nn.Conv2d(256, 256, 3, padding=1)
74 | self.bn1 = nn.BatchNorm2d(256)
75 | # self.mid_conv2 = nn.Conv2d(512, 512, 3, padding=1)
76 | # self.bn2 = nn.BatchNorm2d(512)
77 | # self.mid_conv3 = torch.nn.Conv2d(512, 512, 3, padding=1)
78 | # self.bn3 = torch.nn.BatchNorm2d(512)
79 |
80 | # self.up_block1 = UNet_up_block(256, 512, 256)
81 | self.up_block2 = UNet_up_block(128, 256, 128)
82 | self.up_block3 = UNet_up_block(64, 128, 64)
83 | self.up_block4 = UNet_up_block(32, 64, 32)
84 | self.up_block5 = UNet_up_block(16, 32, 16)
85 |
86 | self.last_conv1 = nn.Conv2d(16, 16, 3, padding=1)
87 | self.last_bn = nn.BatchNorm2d(16)
88 | self.last_conv2 = nn.Conv2d(16, 3, 1, padding=0)
89 | self.relu = nn.ReLU()
90 | self.to(DEVICE)
91 |
92 |     def forward(self, x):
93 |         # keep activations in locals (storing them on self retains them across calls)
94 |         x1 = self.down_block1(x)
95 |         x2 = self.down_block2(x1)
96 |         x3 = self.down_block3(x2)
97 |         x4 = self.down_block4(x3)
98 |         x5 = self.down_block5(x4)
99 |         # x6 = self.down_block6(x5)
100 |
101 |         x5 = self.relu(self.bn1(self.mid_conv1(x5)))
102 |         # x6 = self.relu(self.bn2(self.mid_conv2(x6)))
103 |         # x6 = self.relu(self.bn3(self.mid_conv3(x6)))
104 |
105 |         # x = self.up_block1(x5, x6)
106 |         x = self.up_block2(x4, x5)
107 |         x = self.up_block3(x3, x)
108 |         x = self.up_block4(x2, x)
109 |         x = self.up_block5(x1, x)
110 |         x = self.relu(self.last_bn(self.last_conv1(x)))
111 |         x = self.last_conv2(x)
112 |         return x
112 |
113 | def load(self, file_path):
114 | self.load_state_dict(torch.load(file_path))
115 |
116 | def save(self, file_path):
117 | torch.save(self.state_dict(), file_path)
118 |
--------------------------------------------------------------------------------
/encoding.py:
--------------------------------------------------------------------------------
1 |
2 | import random, sys, os, json, glob, argparse
3 |
4 | import numpy as np
5 | import matplotlib as mpl
6 |
7 | mpl.use("Agg")
8 | import matplotlib.pyplot as plt
9 | from fire import Fire
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | import torch.optim as optim
15 | from torch.autograd import Variable
16 |
17 | from models import DecodingModel, DataParallelModel
18 | from torchvision import models
19 | from logger import Logger, VisdomLogger
20 | from utils import *
21 |
22 | import IPython
23 |
24 | import transforms
25 |
26 |
27 | # LOGGING
28 | logger = VisdomLogger("encoding", server="35.230.67.129", port=8000, env=JOB)
29 | logger.add_hook(lambda x: logger.step(), feature="epoch", freq=20)
30 | logger.add_hook(lambda x: logger.plot(x, "Encoding Loss", opts=dict(ymin=0)), feature="loss", freq=50)
31 |
32 |
33 | """
34 | Computes the changed images, given a specified perturbation, standard deviation weighting,
35 | and epsilon.
36 | """
37 |
38 |
39 | def compute_changed_images(images, perturbation, std_weights, epsilon=EPSILON):
40 |
41 | perturbation_w2 = perturbation * std_weights
42 | perturbation_zc = (
43 | perturbation_w2
44 | / perturbation_w2.view(perturbation_w2.shape[0], -1)
45 | .norm(2, dim=1, keepdim=True)
46 | .unsqueeze(2)
47 | .unsqueeze(2)
48 | .expand_as(perturbation_w2)
49 | * epsilon
50 | * (perturbation_w2[0].nelement() ** 0.5)
51 | )
52 |
53 | changed_images = (images + perturbation_zc).clamp(min=0.0, max=1.0)
54 | return changed_images
55 |
56 |
57 | """
58 | Computes the cross entropy loss of a set of encoded images, given the model and targets.
59 | """
60 |
61 |
62 | def loss_func(model, x, targets):
63 | scores = model.forward(x)
64 | predictions = scores.mean(dim=1)
65 | score_targets = binary.target(targets).unsqueeze(1).expand_as(scores)
66 |
67 | return (F.binary_cross_entropy(scores, score_targets), predictions.cpu().data.numpy().round(2))
68 |
69 |
70 | """
71 | Encodes a set of images with the specified binary targets, for a given number of iterations.
72 | """
73 |
74 |
75 | def encode_binary(
76 | images, targets, model=DecodingModel(), n=None, max_iter=500, verbose=False, perturbation=None, use_weighting=False
77 | ):
78 |
79 | if n is not None:
80 | if verbose:
81 | print(f"Changing distribution size: {model.n} -> {n}")
82 | n, model.n = (model.n, n)
83 |
84 | returnPerturbation = True
85 | if perturbation is None:
86 | perturbation = nn.Parameter(0.03 * torch.randn(images.size()).to(DEVICE) + 0.0)
87 | returnPerturbation = False
88 |
89 | changed_images = images.detach()
90 | optimizer = torch.optim.Adam([perturbation], lr=ENCODING_LR)
91 | std_weights = get_std_weight(images, alpha=PERT_ALPHA) if use_weighting else 1
92 |
93 | for i in range(0, max_iter):
94 |
95 | changed_images = compute_changed_images(images, perturbation, std_weights)
96 | loss, predictions = loss_func(model, changed_images, targets)
97 |
98 | loss.backward()
99 | optimizer.step()
100 | optimizer.zero_grad()
101 |
102 | error = np.mean([binary.distance(x, y) for x, y in zip(predictions, targets)])
103 |
104 | if verbose:
105 | logger.update("epoch", i)
106 | logger.update("loss", loss)
107 | logger.update("bits", error)
108 |
109 | changed_images = compute_changed_images(images, perturbation, std_weights)
110 |
111 | if n is not None:
112 | if verbose:
113 | print(f"Fixing distribution size: {model.n} -> {n}")
114 | n, model.n = (model.n, n)
115 |
116 | if returnPerturbation:
117 | return changed_images.detach(), perturbation.detach()
118 |
119 | return changed_images.detach()
120 |
121 |
122 | """
123 | Command-line interface for encoding a single image.
124 | """
125 |
126 |
127 | def encode(
128 | image,
129 | out,
130 |     target=None,
131 | n=96,
132 | model=None,
133 | max_iter=500,
134 | use_weighting=True,
135 | perturbation_out=None,
136 | ):
137 |
138 | if not isinstance(model, DecodingModel):
139 | model = DataParallelModel(DecodingModel.load(distribution=transforms.encoding, n=n, weights_file=model))
140 | image = im.torch(im.load(image)).unsqueeze(0)
141 |     if target is None:
142 |         target = binary.str(binary.random(TARGET_SIZE))  # fresh random signature per call
143 |     print("Target: ", target)
144 |     target = binary.parse(str(target))
143 | encoded = encode_binary(image, [target], model, n=n, verbose=True, max_iter=max_iter, use_weighting=use_weighting)
144 | im.save(im.numpy(encoded.squeeze()), file=out)
145 |     if perturbation_out is not None:
146 | im.save(im.numpy((image - encoded).squeeze()), file=perturbation_out)
147 |
148 |
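# Example invocation via Fire (hypothetical paths; flags mirror encode()'s parameters):
#   python encoding.py --image scream.jpg --out encoded.jpg --max_iter 500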
149 | if __name__ == "__main__":
150 | Fire(encode)
151 |
--------------------------------------------------------------------------------
/web/client/src/registerServiceWorker.js:
--------------------------------------------------------------------------------
1 | // In production, we register a service worker to serve assets from local cache.
2 |
3 | // This lets the app load faster on subsequent visits in production, and gives
4 | // it offline capabilities. However, it also means that developers (and users)
5 | // will only see deployed updates on the "N+1" visit to a page, since previously
6 | // cached resources are updated in the background.
7 |
8 | // To learn more about the benefits of this model, read https://goo.gl/KwvDNy.
9 | // This link also includes instructions on opting out of this behavior.
10 |
11 | const isLocalhost = Boolean(
12 | window.location.hostname === 'localhost' ||
13 | // [::1] is the IPv6 localhost address.
14 | window.location.hostname === '[::1]' ||
15 | // 127.0.0.1/8 is considered localhost for IPv4.
16 | window.location.hostname.match(
17 | /^127(?:\.(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)){3}$/
18 | )
19 | );
20 |
21 | export default function register() {
22 | if (process.env.NODE_ENV === 'production' && 'serviceWorker' in navigator) {
23 | // The URL constructor is available in all browsers that support SW.
24 | const publicUrl = new URL(process.env.PUBLIC_URL, window.location);
25 | if (publicUrl.origin !== window.location.origin) {
26 | // Our service worker won't work if PUBLIC_URL is on a different origin
27 | // from what our page is served on. This might happen if a CDN is used to
28 | // serve assets; see https://github.com/facebookincubator/create-react-app/issues/2374
29 | return;
30 | }
31 |
32 | window.addEventListener('load', () => {
33 | const swUrl = `${process.env.PUBLIC_URL}/service-worker.js`;
34 |
35 | if (isLocalhost) {
36 |           // This is running on localhost. Let's check whether a service worker still exists.
37 | checkValidServiceWorker(swUrl);
38 |
39 | // Add some additional logging to localhost, pointing developers to the
40 | // service worker/PWA documentation.
41 | navigator.serviceWorker.ready.then(() => {
42 | console.log(
43 | 'This web app is being served cache-first by a service ' +
44 | 'worker. To learn more, visit https://goo.gl/SC7cgQ'
45 | );
46 | });
47 | } else {
48 |           // Not localhost. Just register the service worker.
49 | registerValidSW(swUrl);
50 | }
51 | });
52 | }
53 | }
54 |
55 | function registerValidSW(swUrl) {
56 | navigator.serviceWorker
57 | .register(swUrl)
58 | .then(registration => {
59 | registration.onupdatefound = () => {
60 | const installingWorker = registration.installing;
61 | installingWorker.onstatechange = () => {
62 | if (installingWorker.state === 'installed') {
63 | if (navigator.serviceWorker.controller) {
64 | // At this point, the old content will have been purged and
65 | // the fresh content will have been added to the cache.
66 | // It's the perfect time to display a "New content is
67 | // available; please refresh." message in your web app.
68 | console.log('New content is available; please refresh.');
69 | } else {
70 | // At this point, everything has been precached.
71 | // It's the perfect time to display a
72 | // "Content is cached for offline use." message.
73 | console.log('Content is cached for offline use.');
74 | }
75 | }
76 | };
77 | };
78 | })
79 | .catch(error => {
80 | console.error('Error during service worker registration:', error);
81 | });
82 | }
83 |
84 | function checkValidServiceWorker(swUrl) {
85 |   // Check if the service worker can be found. If it can't, reload the page.
86 | fetch(swUrl)
87 | .then(response => {
88 | // Ensure service worker exists, and that we really are getting a JS file.
89 | if (
90 | response.status === 404 ||
91 | response.headers.get('content-type').indexOf('javascript') === -1
92 | ) {
93 | // No service worker found. Probably a different app. Reload the page.
94 | navigator.serviceWorker.ready.then(registration => {
95 | registration.unregister().then(() => {
96 | window.location.reload();
97 | });
98 | });
99 | } else {
100 | // Service worker found. Proceed as normal.
101 | registerValidSW(swUrl);
102 | }
103 | })
104 | .catch(() => {
105 | console.log(
106 | 'No internet connection found. App is running in offline mode.'
107 | );
108 | });
109 | }
110 |
111 | export function unregister() {
112 | if ('serviceWorker' in navigator) {
113 | navigator.serviceWorker.ready.then(registration => {
114 | registration.unregister();
115 | });
116 | }
117 | }
118 |
--------------------------------------------------------------------------------
/logger.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import matplotlib as mpl
4 |
5 | mpl.use("Agg")
6 | import matplotlib.pyplot as plt
7 | import random, sys, os, json, math
8 |
9 | import torch
10 | from torchvision import datasets, transforms
11 | import visdom
12 |
13 | from utils import *
14 | import IPython
15 |
16 |
17 | class BaseLogger(object):
18 | def __init__(self, name, verbose=True):
19 |
20 | self.name = name
21 | self.data = {}
22 | self.running_data = {}
23 | self.reset_running = {}
24 | self.verbose = verbose
25 | self.hooks = []
26 |
27 | def add_hook(self, hook, feature="epoch", freq=40):
28 | self.hooks.append((hook, feature, freq))
29 |
30 | def update(self, feature, x):
31 | if isinstance(x, torch.Tensor):
32 | x = x.data.cpu().numpy().mean()
33 |
34 | self.data[feature] = self.data.get(feature, [])
35 | self.data[feature].append(x)
36 | if feature not in self.running_data or self.reset_running.pop(feature, False):
37 | self.running_data[feature] = []
38 | self.running_data[feature].append(x)
39 |
40 | for hook, hook_feature, freq in self.hooks:
41 | if feature == hook_feature and len(self.data[feature]) % freq == 0:
42 | hook(self.data[feature])
43 |
44 | def step(self):
45 | self.text(f"({self.name}) ", end="")
46 | for feature in self.running_data.keys():
47 | if len(self.running_data[feature]) == 0:
48 | continue
49 | val = np.mean(self.running_data[feature])
50 | if float(val).is_integer():
51 | self.text(f"{feature}: {int(val)}", end=", ")
52 | else:
53 | self.text(f"{feature}: {val:0.4f}", end=", ")
54 | self.reset_running[feature] = True
55 | self.text(f" ... {elapsed():0.2f} sec")
56 |
57 | def text(self, text, end="\n"):
58 | raise NotImplementedError()
59 |
60 | def plot(self, data, plot_name, opts={}):
61 | raise NotImplementedError()
62 |
63 | def images(self, data, image_name):
64 | raise NotImplementedError()
65 |
66 |
67 | class Logger(BaseLogger):
68 | def __init__(self, *args, **kwargs):
69 | self.results = kwargs.pop("results", "output")
70 | super().__init__(*args, **kwargs)
71 |
72 | def text(self, text, end="\n"):
73 | print(text, end=end, flush=True)
74 |
75 | def plot(self, data, plot_name, opts={}):
76 | np.savez_compressed(f"{self.results}/{plot_name}.npz", data)
77 | plt.plot(data)
78 | plt.savefig(f"{self.results}/{plot_name}.jpg")
79 | plt.clf()
80 |
81 |
82 | class VisdomLogger(BaseLogger):
83 | def __init__(self, *args, **kwargs):
84 | self.port = kwargs.pop("port", 7000)
85 | self.server = kwargs.pop("server", "35.230.67.129")
86 | self.env = kwargs.pop("env", "main")
87 | print(f"Logging to environment {self.env}")
88 | self.visdom = visdom.Visdom(
89 | server="http://" + self.server, port=self.port, env=self.env, use_incoming_socket=False
90 | )
91 | self.visdom.delete_env(self.env)
92 | self.windows = {}
93 | super().__init__(*args, **kwargs)
94 |
95 | def text(self, text, end="\n"):
96 | print(text, end=end)
97 | window, old_text = self.windows.get("text", (None, ""))
98 | if end == "\n":
99 | end = " "
100 | display = old_text + text + end
101 |
102 | if window is not None:
103 | window = self.visdom.text(display, win=window, append=False)
104 | else:
105 | window = self.visdom.text(display)
106 |
107 | self.windows["text"] = window, display
108 |
109 |     def viz(self, viz_name, method, *args, **kwargs):
110 |         # visdom creates a new window when win is None, otherwise updates it in place
111 |         window = getattr(self.visdom, method)(*args, **kwargs, win=self.windows.get(viz_name, None))
112 |         self.windows[viz_name] = window
116 |
117 |     def plot(self, data, plot_name, opts={}):
118 |         opts = dict(opts, title=plot_name)  # copy to avoid mutating the shared default dict
119 |         self.viz(plot_name, "line", np.array(data), opts=opts)
123 |
124 |     def images(self, data, image_name, opts={}, resize=64):
125 |
126 |         transform = transforms.Compose([transforms.ToPILImage(), transforms.Resize(resize), transforms.ToTensor()])
127 |         data = torch.stack([transform(x) for x in data.cpu()])
128 |         data = data.data.cpu().numpy()
129 |
130 |         opts = dict(opts, title=image_name)  # copy to avoid mutating the shared default dict
131 |         self.viz(image_name, "images", np.array(data), opts=opts)
134 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # NeuralHash: An Adversarial Steganographic Method For Robust, Imperceptible Watermarking
2 | Building the next-gen watermark with deep learning: imperceptibly encoding images with un-erasable patterns to verify content ownership.
3 |
4 | ## What it does:
5 | Given an image (like Scream), NeuralHash makes small perturbations to visually encode a unique signature of the author:
6 |
7 |
8 |
9 | The signature can still be decoded even after extreme transformations (like a cellphone photo of the encoded image):
10 |
11 |
12 |
13 |
14 |
15 | Our secure watermarking scheme represents a significant advance in protecting content ownership and preventing piracy on the Internet.
16 | ## Harnessing Adversarial Examples
17 |
18 | Our key insight is that we can use adversarial example techniques on a Decoder Network (that maps input images to 32-bit signatures) to generate perturbations that decode to the desired signature. We perform projected gradient descent under the Expectation over Transformation framework to do this as follows:
19 |
20 |
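
A minimal sketch of that inner loop in PyTorch: here `decoder` (a model mapping a batch of images to per-bit probabilities) and `sample_transform` (a differentiable attack sampler) are hypothetical stand-ins for this repo's `DecodingModel` and `transforms`, and we project onto an L-infinity ball for simplicity, whereas `encode_binary` renormalizes the perturbation's L2 norm instead:

```python
import torch
import torch.nn.functional as F

def encode_signature(images, target_bits, decoder, sample_transform,
                     steps=500, lr=1e-2, epsilon=0.05):
    """PGD under Expectation over Transformation (sketch, not the repo's code).

    target_bits: float {0,1} tensor of shape (N, 32), one signature per image.
    """
    delta = torch.zeros_like(images, requires_grad=True)  # the perturbation
    optimizer = torch.optim.Adam([delta], lr=lr)

    for _ in range(steps):
        # EoT: decode a randomly transformed copy, so the signature survives attacks
        transformed = sample_transform((images + delta).clamp(0, 1))
        scores = decoder(transformed)                     # (N, 32) bit probabilities
        loss = F.binary_cross_entropy(scores, target_bits)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        with torch.no_grad():                             # projection step
            delta.clamp_(-epsilon, epsilon)

    return (images + delta).clamp(0, 1).detach()
```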
21 |
22 | We simulate an attack distribution using a set of differentiable transformations over which we train. Here are some sample transforms:
23 |
24 |
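
Each transform must be differentiable so that gradients flow from the decoded bits back to the pixels. As one illustrative example (an assumption about the kind of transform used, not code from `transforms.py`), a random rotation can be written with `affine_grid`/`grid_sample`:

```python
import math
import torch
import torch.nn.functional as F

def random_rotate(x, max_deg=30.0):
    """Differentiable random rotation of an image batch (illustrative sketch)."""
    theta = math.radians(torch.empty(1).uniform_(-max_deg, max_deg).item())
    rot = torch.tensor([[math.cos(theta), -math.sin(theta), 0.0],
                        [math.sin(theta),  math.cos(theta), 0.0]],
                       dtype=x.dtype, device=x.device).repeat(x.size(0), 1, 1)
    grid = F.affine_grid(rot, list(x.size()), align_corners=False)
    return F.grid_sample(x, grid, align_corners=False)
```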
25 |
26 |
27 | ## Training the Network
28 | We also propose a method to train our decoder network under the Expectation-Maximization (EM) framework to learn feature transformations that are more resilient to the threat space of attacks. As shown below, we alternate between encoding images using the network and then updating the network's weights to be more robust to attacks.
29 |
30 |
31 |
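
A minimal sketch of that alternation, reusing the hypothetical `encode_signature` and `random_rotate` helpers from the sketches above (the repo's actual schedule lives in `train_amnesia.py`):

```python
import torch
import torch.nn.functional as F

def train_em(decoder, images, rounds=10, bits=32):
    """Alternate (E) encoding against a frozen decoder with (M) decoder updates."""
    opt = torch.optim.Adam(decoder.parameters(), lr=1e-3)
    for _ in range(rounds):
        # E-step: embed fresh random signatures into the images
        targets = torch.randint(0, 2, (images.size(0), bits)).float()
        encoded = encode_signature(images, targets, decoder, random_rotate)
        # M-step: update the decoder to read signatures through attacks
        scores = decoder(random_rotate(encoded))
        loss = F.binary_cross_entropy(scores, targets)
        opt.zero_grad()
        loss.backward()
        opt.step()
```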
32 |
33 | The plots below show the robustness of our encoded images during training. Over many iterations the curves become flatter, indicating robustness to rotation and scaling; as shown later, our approach generalizes to more extreme transformations.
34 |
35 |
36 |
37 |
38 | ## Sample Encodings
39 | Here are some sample original images (top row) and the corresponding watermarked images (bottom row):
40 |
41 |
42 |
43 | ## Example Attacks
44 | Some examples where our approach successfully decodes the correct signature, and examples where it fails:
45 |
46 |
47 |
48 | ## Final Thoughts:
49 |
50 | The development of a secure watermarking scheme is an important problem that has applications in content ownership and piracy prevention. Current state-of-the-art techniques are unable to demonstrate robustness across a variety of affine transformations. We propose a method that harnesses the expressiveness of deep neural networks to covertly embed imperceptible, transformation-resilient binary signatures into images. Given a decoder network, our key insight is that adversarial example generation techniques can be used to encode images by performing projected gradient descent on the image to embed a chosen signature.
51 |
52 | By performing projected gradient descent on the decoder model with respect to a given image, we can use it to “sign” images robustly (think of a more advanced watermark). We start with the original image, then repeatedly tweak the pixel values such that the image (and all transformations of it, including scaling, rotation, added noise, blurring, random cropping, and more) decodes to a specified 32-bit code. The resulting image is nearly indistinguishable from the original, yet contains an easily decodable signature that cannot be removed even by the most dedicated of adversaries.
53 |
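Verification then amounts to decoding the image and thresholding the Hamming distance to the claimed signature. A sketch under the same assumptions as above, with an arbitrary example threshold:

```python
import torch

def verify(decoder, image, target_bits, max_bit_errors=4):
    """Accept the watermark if the decoded 32-bit code is close to the claim."""
    with torch.no_grad():
        bits = (decoder(image) > 0.5).float()            # (1, 32) hard bit decisions
    errors = (bits != target_bits).float().sum().item()  # Hamming distance
    return errors <= max_bit_errors, errors
```
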
54 | We also propose a method to train our decoder network under the Expectation-Maximization (EM) framework to learn feature transformations that are more resilient to the threat space of attacks. Experimental results indicate that our model achieves robustness across different transformations such as scaling and rotating, with improved results over the length of EM training. Furthermore, we show an inherent trade-off between robustness and imperceptibility, which allows the user of the model flexibility in adjusting parameters to fit a particular task.
55 |
56 | Paper and more details coming soon.
57 |
--------------------------------------------------------------------------------
/web/client/src/pages/background.js:
--------------------------------------------------------------------------------
1 |
2 |
3 | var colorPicker = (function() {
4 | var colors = ["#FF6138", "#FFBE53", "#2980B9", "#282741"];
5 | var index = 0;
6 | function next() {
7 |     index = (index + 1) % colors.length;
8 | return colors[index];
9 | }
10 | function current() {
11 | return colors[index];
12 | }
13 | return {
14 | next: next,
15 | current: current
16 | };
17 | })();
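
// Shared canvas state used throughout this file; assumes the page provides a
// full-viewport <canvas>. These globals were previously undeclared here.
var c = document.querySelector("canvas");
var ctx = c.getContext("2d");
var cW, cH;
var bgColor = colorPicker.current();
var animations = [];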
18 |
19 | function removeAnimation(animation) {
20 | var index = animations.indexOf(animation);
21 | if (index > -1) animations.splice(index, 1);
22 | }
23 |
24 | function calcPageFillRadius(x, y) {
25 | var l = Math.max(x - 0, cW - x);
26 | var h = Math.max(y - 0, cH - y);
27 | return Math.sqrt(Math.pow(l, 2) + Math.pow(h, 2));
28 | }
29 |
30 | function addClickListeners() {
31 | document.addEventListener("touchstart", handleEvent);
32 | document.addEventListener("mousedown", handleEvent);
33 | }
34 |
35 | function handleEvent(e) {
36 | if (e.touches) {
37 | e.preventDefault();
38 | e = e.touches[0];
39 | }
40 | var currentColor = colorPicker.current();
41 | var nextColor = colorPicker.next();
42 | var targetR = calcPageFillRadius(e.pageX, e.pageY);
43 | var rippleSize = Math.min(200, cW * 0.4);
44 | var minCoverDuration = 750;
45 |
46 | var pageFill = new Circle({
47 | x: e.pageX,
48 | y: e.pageY,
49 | r: 0,
50 | fill: nextColor
51 | });
52 | var fillAnimation = anime({
53 | targets: pageFill,
54 | r: targetR,
55 | duration: Math.max(targetR / 2, minCoverDuration),
56 | easing: "easeOutQuart",
57 | complete: function() {
58 | bgColor = pageFill.fill;
59 | removeAnimation(fillAnimation);
60 | }
61 | });
62 |
63 | var ripple = new Circle({
64 | x: e.pageX,
65 | y: e.pageY,
66 | r: 0,
67 | fill: currentColor,
68 | stroke: {
69 | width: 3,
70 | color: currentColor
71 | },
72 | opacity: 1
73 | });
74 | var rippleAnimation = anime({
75 | targets: ripple,
76 | r: rippleSize,
77 | opacity: 0,
78 | easing: "easeOutExpo",
79 | duration: 900,
80 | complete: removeAnimation
81 | });
82 |
83 | var particles = [];
84 | for (var i = 0; i < 32; i++) {
85 | var particle = new Circle({
86 | x: e.pageX,
87 | y: e.pageY,
88 | fill: currentColor,
89 | r: anime.random(24, 48)
90 | });
91 | particles.push(particle);
92 | }
93 | var particlesAnimation = anime({
94 | targets: particles,
95 | x: function(particle) {
96 | return particle.x + anime.random(rippleSize, -rippleSize);
97 | },
98 | y: function(particle) {
99 | return particle.y + anime.random(rippleSize * 1.15, -rippleSize * 1.15);
100 | },
101 | r: 0,
102 | easing: "easeOutExpo",
103 | duration: anime.random(1000, 1300),
104 | complete: removeAnimation
105 | });
106 | animations.push(fillAnimation, rippleAnimation, particlesAnimation);
107 | }
108 |
109 | function extend(a, b) {
110 | for (var key in b) {
111 | if (b.hasOwnProperty(key)) {
112 | a[key] = b[key];
113 | }
114 | }
115 | return a;
116 | }
117 |
118 | var Circle = function(opts) {
119 | extend(this, opts);
120 | };
121 |
122 | Circle.prototype.draw = function() {
123 | ctx.globalAlpha = this.opacity || 1;
124 | ctx.beginPath();
125 | ctx.arc(this.x, this.y, this.r, 0, 2 * Math.PI, false);
126 | if (this.stroke) {
127 | ctx.strokeStyle = this.stroke.color;
128 | ctx.lineWidth = this.stroke.width;
129 | ctx.stroke();
130 | }
131 | if (this.fill) {
132 | ctx.fillStyle = this.fill;
133 | ctx.fill();
134 | }
135 | ctx.closePath();
136 | ctx.globalAlpha = 1;
137 | };
138 |
139 | var animate = anime({
140 | duration: Infinity,
141 | update: function() {
142 | ctx.fillStyle = bgColor;
143 | ctx.fillRect(0, 0, cW, cH);
144 | animations.forEach(function(anim) {
145 | anim.animatables.forEach(function(animatable) {
146 | animatable.target.draw();
147 | });
148 | });
149 | }
150 | });
151 |
152 | var resizeCanvas = function() {
153 | cW = window.innerWidth;
154 | cH = window.innerHeight;
155 | c.width = cW * devicePixelRatio;
156 | c.height = cH * devicePixelRatio;
157 | ctx.scale(devicePixelRatio, devicePixelRatio);
158 | };
159 |
160 | (function init() {
161 | resizeCanvas();
162 | if (window.CP) {
163 | // CodePen's loop detection was causin' problems
164 | // and I have no idea why, so...
165 | window.CP.PenTimer.MAX_TIME_IN_LOOP_WO_EXIT = 6000;
166 | }
167 | window.addEventListener("resize", resizeCanvas);
168 | addClickListeners();
169 | if (!!window.location.pathname.match(/fullcpgrid/)) {
170 | startFauxClicking();
171 | }
172 | handleInactiveUser();
173 | })();
174 |
175 | function handleInactiveUser() {
176 | var inactive = setTimeout(function() {
177 | fauxClick(cW / 2, cH / 2);
178 | }, 2000);
179 |
180 | function clearInactiveTimeout() {
181 | clearTimeout(inactive);
182 | document.removeEventListener("mousedown", clearInactiveTimeout);
183 | document.removeEventListener("touchstart", clearInactiveTimeout);
184 | }
185 |
186 | document.addEventListener("mousedown", clearInactiveTimeout);
187 | document.addEventListener("touchstart", clearInactiveTimeout);
188 | }
189 |
190 | function startFauxClicking() {
191 | setTimeout(function() {
192 | fauxClick(
193 | anime.random(cW * 0.2, cW * 0.8),
194 | anime.random(cH * 0.2, cH * 0.8)
195 | );
196 | startFauxClicking();
197 | }, anime.random(200, 900));
198 | }
199 |
200 | function fauxClick(x, y) {
201 | var fauxClick = new Event("mousedown");
202 | fauxClick.pageX = x;
203 | fauxClick.pageY = y;
204 | document.dispatchEvent(fauxClick);
205 | }
206 |
--------------------------------------------------------------------------------
/web/client/src/components/ImageTrace/ImageTrace.js:
--------------------------------------------------------------------------------
1 | import React from "react";
2 | import "./ImageTrace.css";
3 |
4 | import Image from "image-js";
5 | import { toPath, toPoints } from "svg-points";
6 | import * as simplify from "simplify-js";
7 | import * as ImageTracer from "imagetracerjs";
8 | import HtmlToReact, { Parser } from "html-to-react";
9 | import { TweenLite, morphSVG, TimelineLite, SlowMo, CustomEase } from "gsap";
10 |
11 | export default class ImageTrace extends React.Component {
12 | state = {
13 |     path: ""
14 | };
15 |
16 | componentDidMount() {
17 | Image.load("protected.jpeg").then(async img => {
18 | let data = ImageTracer.imagedataToSVG(img, {
19 | ltres: 30,
20 | qtres: 30,
21 | numberofcolors: 2,
22 | pal: [{ r: 255, b: 255, g: 255, a: 0 }, { r: 0, b: 0, g: 0, a: 1 }],
23 | colorsampling: 0,
24 | linefilter: true
25 | });
26 |
27 | let svgData = this.convertSvgGroupToPath(data);
28 |
29 | this.setState({
30 |         path: svgData
31 | });
32 |
33 | let tl = new TimelineLite();
34 |
35 | let orig = document.querySelector("#data");
36 | let obj = {
37 | length: 0,
38 | pathLength: orig.getTotalLength()
39 | };
40 |
41 | tl.to(obj, 500, {
42 | length: obj.pathLength,
43 | onUpdate: drawLine,
44 | ease: SlowMo.ease.config(0.1, 0.7, false)
45 | });
46 |
47 | function drawLine() {
48 | orig.style.strokeDasharray = [obj.length, obj.pathLength].join(" ");
49 | }
50 |
51 | tl.to(
52 | "#data",
53 | 1,
54 | {
55 | morphSVG: { shape: "#lock", shapeIndex: 20 }
56 | },
57 | "-=498"
58 | );
59 | tl.play();
60 | });
61 | }
62 |
63 | convertSvgGroupToPath(data) {
64 | let finalD = "";
65 | var processNodeDefinitions = new HtmlToReact.ProcessNodeDefinitions(React);
66 | let component = new Parser().parseWithInstructions(
67 | data,
68 | node => {
69 | return true;
70 | },
71 | [
72 | {
73 | shouldProcessNode: function(node) {
74 | return node.name === "svg";
75 | },
76 | processNode: function(node, children) {
77 | return children;
78 | }
79 | },
80 | {
81 | shouldProcessNode: function(node) {
82 | return node.name === "path";
83 | },
84 | processNode: function(node, children, index) {
85 | finalD += " " + node.attribs.d;
86 | return React.createElement("path", {
87 | key: index,
88 | d: node.attribs.d
89 | });
90 | }
91 | },
92 | {
93 | // Anything else
94 | shouldProcessNode: function(node) {
95 | return true;
96 | },
97 | processNode: processNodeDefinitions.processDefaultNode
98 | }
99 | ]
100 | );
101 | return finalD;
102 | }
103 | /*
104 | let svgData = MSQR(edge.getCanvas(), {
105 | width: edge.width,
106 | tolerance: 50,
107 | align: false,
108 | alpha: 1,
109 | bleed: 5, // width of bleed mask (used with multiple shapes only)
110 | maxShapes: 5,
111 | height: edge.height,
112 | path2D: false
113 | });
114 |
115 | let a = document.createElement("a");
116 | document.body.appendChild(a);
117 | a.style = "display: none";
118 | let url = window.URL.createObjectURL(await edge.toBlob());
119 | a.href = url;
120 | a.download = "path.png";
121 | a.click();
122 | window.URL.revokeObjectURL(url);
123 |
124 | console.log(svgData);
125 | this.setState({ path: toPath(svgData) });
126 | });
127 | }
128 |
129 | trace = image => {
130 | var point, nextpoint;
131 | let data = [];
132 |
133 | for (var i = 0; i <= image.data.length; i++) {
134 | if (image.data[i] === 255) {
135 | // start pathfinding
136 | point = { x: i % image.width, y: (i / image.width) | 0 };
137 |
138 | image.data[i] = 0;
139 |
140 | // start a line
141 | var line = [];
142 | line.push(point);
143 | while ((nextpoint = this.lineGobble(image, point))) {
144 | line.push(nextpoint);
145 | point = nextpoint;
146 | }
147 | data.push(line);
148 | }
149 | }
150 | return data;
151 | };
152 |
153 | lineGobble = (image, point) => {
154 | var neighbor = [
155 | [0, -1], // n
156 | [1, 0], // s
157 | [0, 1], // e
158 | [-1, 0], // w
159 | [-1, -1], // nw
160 | [1, -1], // ne
161 | [1, 1], // se
162 | [-1, 1] // sw
163 | ];
164 | var checkpoint = {};
165 |
166 | for (var i = 0; i < neighbor.length; i++) {
167 | checkpoint.x = point.x + neighbor[i][0];
168 | checkpoint.y = point.y + neighbor[i][1];
169 |
170 | var result = this.checkpixel(image, checkpoint);
171 | if (result) {
172 | return checkpoint;
173 | }
174 | }
175 | return false;
176 | };
177 |
178 | checkpixel = (image, point) => {
179 | if (0 <= point.x < image.width) {
180 | if (0 <= point.y < image.height) {
181 | // point is "in bounds"
182 | var index = point.y * image.width + point.x;
183 | if (image.data[index] === 255) {
184 | image.data[index] = 0;
185 | return true;
186 | }
187 | }
188 | }
189 | return false;
190 | };*/
191 |
192 | render() {
193 | return (
194 |