├── README.md
├── pyproject.toml
├── src
│   └── nude2
│       ├── static
│       │   ├── favicon.ico
│       │   └── index.html
│       ├── utils.py
│       ├── generate.py
│       ├── benchmark.py
│       ├── train_constants.py
│       ├── splashes
│       │   └── splash.txt
│       ├── browse.py
│       ├── progress.py
│       ├── model_v1.py
│       ├── cli.py
│       ├── sample.py
│       ├── model.py
│       ├── onnxify.py
│       ├── train.py
│       └── data.py
├── pytest.ini
├── .gitignore
├── test
│   ├── data_test.py
│   └── train_test.py
├── setup.cfg
└── samples
    └── index.html
/README.md:
--------------------------------------------------------------------------------
1 | :cherry_blossom:
2 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools"]
3 | build-backend = "setuptools.build_meta"
4 |
--------------------------------------------------------------------------------
/src/nude2/static/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/whatever/nude-2.0/main/src/nude2/static/favicon.ico
--------------------------------------------------------------------------------
/pytest.ini:
--------------------------------------------------------------------------------
1 | [pytest]
2 | minversion = 6.0
3 | addopts = -ra -q
4 | testpaths =
5 |     test
6 |
7 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.avi
2 | *.egg-info
3 | *.jpg
4 | *.json
5 | *.mp4
6 | *.png
7 | *.pt
8 | *.swp
9 | *.tar.gz
10 | .DS_Store
11 | .ipynb_checkpoints/
12 | __pycache__
13 |
--------------------------------------------------------------------------------
/src/nude2/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Feels lazy to name something "utils.py", but here we are.
3 | """
4 |
5 |
6 | import os.path
7 |
8 |
9 | def splash(splash_name):
10 | """Print a splash message."""
11 | fname = os.path.join(os.path.dirname(__file__), "splashes", f"{splash_name}.txt")
12 | with open(fname, "r") as f:
13 | txt = f.read()
14 | print(f"\033[95m{txt}\033[00m")
15 |
--------------------------------------------------------------------------------
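Note: a minimal usage sketch of splash() (assuming the package is installed so
the bundled splashes/splash.txt resolves next to utils.py):

    from nude2.utils import splash

    splash("splash")  # prints the bundled ASCII art in magenta
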
/src/nude2/generate.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | from nude2.train import Generator, weights_init
5 | from nude2.data import MetCenterCroppedDataset
6 |
7 |
8 | def main(checkpoint):
9 |     """Generate a single image from a trained generator checkpoint."""
10 |     states = torch.load(checkpoint, map_location="cpu")
11 |
12 |     g = Generator()
13 |     g.apply(weights_init)
14 |     g.load_state_dict(states["g"])
15 |     g.eval()
16 |
17 |     with torch.no_grad():
18 |         fixed_noise = torch.randn(1, 100, 1, 1)
19 |         vec = g(fixed_noise)
20 |
21 |     img = MetCenterCroppedDataset.pilify(vec[0])
22 |     img.save("generated.png")
--------------------------------------------------------------------------------
/test/data_test.py:
--------------------------------------------------------------------------------
1 | import PIL
2 | import torch
3 | import unittest
4 |
5 |
6 | from nude2.data import CachedDataset
7 | from nude2.train import Generator, Discriminator
8 |
9 |
10 | from train_test import WIDTH, HEIGHT
11 |
12 |
13 | class DataTest(unittest.TestCase):
14 | def test_resize(self):
15 |
16 | dataset = CachedDataset("./whatever", "./no-cache")
17 | img = PIL.Image.new("RGB", (8*WIDTH, 8*HEIGHT))
18 |
19 | for t in dataset.transforms:
20 | img = t(img)
21 | w, h = img.size
22 | self.assertEqual(w, WIDTH)
23 | self.assertEqual(h, HEIGHT)
24 |
--------------------------------------------------------------------------------
/src/nude2/benchmark.py:
--------------------------------------------------------------------------------
1 | import nude2.data
2 | from datetime import datetime, timedelta
3 | import time
4 |
5 |
6 | def main():
7 |
8 | db = nude2.data.MetData(nude2.data.DB)
9 |
10 | tags = ["Sphinx", "Crucifixion", "Male Nudes", "Female Nudes", "asdasd", ""]
11 |
12 | mediums = ["oil", ""]
13 |
14 | cases = sorted(
15 | (tag, medium)
16 | for tag in tags
17 | for medium in mediums
18 | )
19 |
20 | for tag, medium in cases:
21 | start = datetime.now()
22 | results = db.fetch_tag(tag, medium)
23 | duration = datetime.now() - start
24 | pad = 30 - len(tag) - len(medium)
25 | print(f"tag={tag}, medium={medium} {'.'*pad} {duration}, {len(results)}")
26 |
--------------------------------------------------------------------------------
/src/nude2/train_constants.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | # Batch size during training
4 | batch_size = 1024
5 |
6 | # Spatial size of training images. All images will be resized to this
7 | #   size by the dataset transforms.
8 | image_size = 64
9 |
10 | # Number of channels in the training images. For color images this is 3
11 | nc = 3
12 |
13 | # Size of z latent vector (i.e. size of generator input)
14 | nz = 100
15 |
16 | # Size of feature maps in generator
17 | ngf = 64
18 |
19 | # Size of feature maps in discriminator
20 | ndf = 64
21 |
22 | # Learning rate for optimizers
23 | lr = 0.0002
24 |
25 | # Beta1 hyperparameter for Adam optimizers
26 | beta1 = 0.5
27 |
28 | # Number of GPUs available. Use 0 for CPU mode.
29 | ngpu = 1
30 |
31 | # Set device
32 | device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")
33 |
--------------------------------------------------------------------------------
/test/train_test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import unittest
3 |
4 |
5 | from nude2.train import Generator, Discriminator
6 |
7 |
8 | WIDTH = HEIGHT = 64
9 |
10 |
11 | class ModuleShapeTest(unittest.TestCase):
12 | """Test that the generator and discrimator compute the correct shapes"""
13 |
14 | def test_generator_shape(self):
15 | """Ensure that g(Nx100x1x1) -> Nx3x256, 256)"""
16 |
17 | g = Generator()
18 |
19 | r = g(torch.rand(10, 100, 1, 1))
20 |
21 | self.assertEqual(
22 | r.shape,
23 | (10, 3, WIDTH, HEIGHT),
24 | )
25 |
26 | def test_discriminator_shape(self):
27 | """Ensure that d(Nx3x256, 256) -> Nx1x1x1"""
28 |
29 | d = Discriminator()
30 |
31 | r = d(torch.rand(8, 3, WIDTH, HEIGHT))
32 |
33 | self.assertEqual(
34 | r.shape,
35 | (8, 1, 1, 1),
36 | )
37 |
--------------------------------------------------------------------------------
/src/nude2/splashes/splash.txt:
--------------------------------------------------------------------------------
1 | :::!~!!!!!:.
2 | .xUHWH!! !!?M88WHX:.
3 | .X*#M@$!! !X!M$$$$$$WWx:.
4 | :!!!!!!?H! :!$!$$$$$$$$$$8X:
5 | !!~ ~:~!! :~!$!#$$$$$$$$$$8X:
6 | :!~::!H!< ~.U$X!?R$$$$$$$$MM!
7 | ~!~!!!!~~ .:XW$$$U!!?$$$$$$RMM!
8 | !:~~~ .:!M"T#$$$$WX??#MRRMMM!
9 | ~?WuxiW*` `"#$$$$8!!!!??!!!
10 | :X- M$$$$ `"T#$T~!8$WUXU~
11 | :%` ~#$$$m: ~!~ ?$$$$$$
12 | :!`.- ~T$$$$8xx. .xWW- ~""##*"
13 | ..... -~~:<` ! ~?T#$$@@W@*?$$ /`
14 | W$@@M!!! .!~~ !! .:XUW$W!~ `"~: :
15 | #"~~`.:x%`!! !H: !WM$$$$Ti.: .!WUn+!`
16 | :::~:!!`:X~ .: ?H.!u "$$$B$$$!W:U!T$$M~
17 | .~~ :X@!.-~ ?@WTWo("*$$$W$TH$! `
18 | Wi.~!X$?!-~ : ?$$$B$Wu("**$RM!
19 | $R@i.~~ ! : ~$$$$$B$$en:``
20 | ?MXT@Wx.~ : ~"##*$$$$M~
21 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | name = nude2
3 | version = 0.0.1
4 | author = Matt <3
5 | author_email = matt@worldshadowgovernment.com
6 | url = https://github.com/whatever/nude-2.0
7 | description = Attempted recreation of Robbie Barrat's iconic "The Nude"
8 | long_description = file: README.md
9 | long_description_content_type = text/markdown
10 | keywords = nude
11 | license = UNLICENSE
12 | classifiers =
13 | License :: OSI Approved :: BSD License
14 | Programming Language :: Python :: 3
15 |
16 | [options]
17 | package_dir =
18 | = src
19 | packages = find:
20 | install_requires =
21 | onnxruntime==1.16.3
22 | Pillow==10.0.1
23 | pytest==7.4.3
24 | torch==2.1.0
25 | torchvision==0.16.0
26 | # albumentations==1.3.1
27 | # matplotlib==3.8.0
28 | # opencv-python==4.8.1.78
29 |
30 | [options.packages.find]
31 | where = src
32 | exclude =
33 | examples*
34 | # tools*
35 | # docs*
36 |
37 | [options.entry_points]
38 | console_scripts =
39 | nude = nude2.cli:main
40 |
41 |
42 | [options.package_data]
43 | nude2 = splashes/*.txt
44 | * = README.md
45 |
--------------------------------------------------------------------------------
/samples/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | samples
5 |
6 |
7 |
18 |
44 |
45 |
46 |
47 |
48 |
49 |
--------------------------------------------------------------------------------
/src/nude2/browse.py:
--------------------------------------------------------------------------------
1 | import http.server
2 | import json
3 | import nude2.data
4 | import os
5 | import threading
6 | import time
7 | import urllib.parse
8 | import webbrowser
9 |
10 |
11 | from http.server import HTTPServer, SimpleHTTPRequestHandler
12 |
13 |
14 | HOST = "127.0.0.1"
15 | """Open server locally"""
16 |
17 |
18 | PORT = 8181
19 | """Default to port 8181"""
20 |
21 |
22 | DIR = os.path.realpath(os.path.join(os.path.dirname(__file__), "static"))
23 | """Use /static/ directory"""
24 |
25 |
26 | class StaticDirServer(SimpleHTTPRequestHandler):
27 | """Serve a directory of static files"""
28 |
29 | def __init__(self, *args, **kwargs):
30 | """Construct"""
31 | super().__init__(*args, directory=DIR, **kwargs)
32 |
33 | def do_GET(self):
34 | """Respond to get requests with static files or api response"""
35 | if not self.path.startswith("/api/v0/"):
36 | return super().do_GET()
37 |
38 |         pieces = [
39 |             urllib.parse.unquote(v)
40 |             for v in self.path.split("/", 4)
41 |         ]
42 |
43 | medium, tag = pieces[-2:]
44 |
45 | db = nude2.data.MetData(nude2.data.DB)
46 | print("Searching for:", medium, tag)
47 | res = db.fetch_tag(tag, medium)
48 |
49 | self.send_response(200)
50 | self.send_header("Content-type", "application/json")
51 | self.end_headers()
52 |         self.wfile.write(json.dumps({"results": res}).encode("ascii"))
53 |
54 |
55 | def serve():
56 | """Start the static server"""
57 | server_address = (HOST, PORT)
58 | httpd = HTTPServer(server_address, StaticDirServer)
59 | httpd.serve_forever()
60 |
61 |
62 | def view():
63 |     """Start the server and open a browser tab"""
64 |     start_server()
65 |
66 |
67 | def start_server():
68 | print("Starting server...")
69 | threading.Thread(target=serve).start()
70 |
71 | print("Sleeping for 1 second...")
72 | time.sleep(1)
73 |
74 | print("Opening browser...")
75 | webbrowser.open(f"http://{HOST}:{PORT}")
76 |
--------------------------------------------------------------------------------
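Note: do_GET treats paths under /api/v0/ as /api/v0/<medium>/<tag>; a small
client sketch against a locally running `nude browse` server (the tag and
medium values here are just examples):

    import json
    import urllib.request

    url = "http://127.0.0.1:8181/api/v0/oil/Female%20Nudes"
    with urllib.request.urlopen(url) as resp:
        data = json.load(resp)

    print(len(data["results"]))
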
/src/nude2/progress.py:
--------------------------------------------------------------------------------
1 | import random
2 | import time
3 | import sys
4 |
5 | from datetime import datetime
6 |
7 | class ProgressBar(object):
8 | """..."""
9 |
10 | def __init__(self, x, prefix="", size=100, clear=True):
11 |
12 | if isinstance(x, int):
13 | self.iter = range(x)
14 | else:
15 |             self.iter = list(x)  # materialize so len(self.iter) works in show()
16 |
17 | self.start_time = datetime.now()
18 | self.clear = clear
19 | self.prefix = prefix
20 | self.size = size
21 | self.current = 0
22 | self.line = self.prefix + "[" + " "*self.size + "]"
23 | self.show()
24 |
25 | def x_x(self):
26 | eye = ["x", "-", "^", "o", "O"]
27 | return f"{random.choice(eye)}_{random.choice(eye)}"
28 |
29 | def show(self):
30 |
31 | # Percentage complete
32 | pct = self.current / len(self.iter)
33 | pct = round(pct * self.size)
34 |
35 | # Time elapsed
36 | dur = datetime.now() - self.start_time
37 |
38 |         t = int(dur.total_seconds())
39 |
40 |         # Decompose elapsed seconds into h:mm:ss
41 |         h = t // 3600
42 |         m = (t - h*3600) // 60
43 |         s = t - h*3600 - m*60
44 |
45 |
46 | sys.stdout.write("\b"*len(self.line))
47 | self.line = self.prefix + "[" + "*"*pct + " "*(self.size-pct) + "] " + self.x_x()
48 | self.line += f" {100*self.current/len(self.iter):0.2f}% [{h}:{m:02}:{s:02}] @ {datetime.now().strftime('%H:%M:%S')}\r"
49 | sys.stdout.write(self.line)
50 | sys.stdout.flush()
51 |
52 | def inc(self):
53 | self.current += 1
54 | self.show()
55 |
56 | def set(self, value):
57 | self.current = value
58 | self.show()
59 |
60 | def __enter__(self):
61 | return self
62 |
63 | def __exit__(self, type, value, traceback):
64 | print(self.line)
65 | if self.clear:
66 | sys.stdout.write("\b"*len(self.line))
67 |
68 | class I(object):
69 | def __iter__(self):
70 | return self
71 | def __next__(self):
72 | return 1
73 |
74 | def main():
75 | for _ in range(5):
76 | with ProgressBar(66) as progress:
77 | for _ in range(66):
78 | time.sleep(0.02)
79 | progress.inc()
80 |
--------------------------------------------------------------------------------
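Note: ProgressBar is used as a context manager, as in main() above; a sketch
with an explicit iterable (any sized collection works, since the constructor
materializes it; the prefix string is just an example):

    from nude2.progress import ProgressBar

    items = list(range(200))
    with ProgressBar(items, prefix="resize ", size=40) as bar:
        for item in items:
            pass  # do work here
            bar.inc()
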
/src/nude2/model_v1.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from nude2.train_constants import *
5 |
6 |
7 | class Discriminator(nn.Module):
8 | """Module to discriminate between real and fake images"""
9 |
10 | def __init__(self):
11 | super(Discriminator, self).__init__()
12 | self.main = nn.Sequential(
13 | # input is ``(nc) x 64 x 64``
14 | nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
15 | nn.LeakyReLU(0.2, inplace=True),
16 | # state size. ``(ndf) x 32 x 32``
17 | nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
18 | nn.BatchNorm2d(ndf * 2),
19 | nn.LeakyReLU(0.2, inplace=True),
20 | # state size. ``(ndf*2) x 16 x 16``
21 | nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
22 | nn.BatchNorm2d(ndf * 4),
23 | nn.LeakyReLU(0.2, inplace=True),
24 | # state size. ``(ndf*4) x 8 x 8``
25 | nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
26 | nn.BatchNorm2d(ndf * 8),
27 | nn.LeakyReLU(0.2, inplace=True),
28 | # state size. ``(ndf*8) x 4 x 4``
29 | nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
30 | nn.Sigmoid(),
31 | )
32 |
33 | def forward(self, input):
34 | return self.main(input)
35 |
36 |
37 | class Generator(nn.Module):
38 | """Module to generate an image from a feature vector"""
39 |
40 | def __init__(self):
41 | super(Generator, self).__init__()
42 | self.main = nn.Sequential(
43 | # input is Z, going into a convolution
44 |             nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
45 |             nn.BatchNorm2d(ngf * 8),
46 |             nn.ReLU(True),
47 |             # state size. ``(ngf*8) x 4 x 4``
48 |             nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
49 |             nn.BatchNorm2d(ngf * 4),
50 |             nn.ReLU(True),
51 |             # state size. ``(ngf*4) x 8 x 8``
52 |             nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
53 |             nn.BatchNorm2d(ngf * 2),
54 |             nn.ReLU(True),
55 |             # state size. ``(ngf*2) x 16 x 16``
56 |             nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
57 |             nn.BatchNorm2d(ngf),
58 |             nn.ReLU(True),
59 |             # state size. ``(ngf) x 32 x 32``
60 |             nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
61 |             nn.Tanh(),
62 |             # state size. ``(nc) x 64 x 64``
63 | )
64 |
65 | def forward(self, input):
66 | return self.main(input)
67 |
--------------------------------------------------------------------------------
/src/nude2/cli.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import nude2.benchmark
4 | import nude2.browse
5 | import nude2.data
6 | import nude2.generate
7 | import nude2.onnxify
8 | import nude2.progress
9 | import nude2.sample
10 | import nude2.train
11 |
12 | import nude2
13 |
14 | def main():
15 | parser = argparse.ArgumentParser()
16 |
17 | subparsers = parser.add_subparsers(dest="command")
18 |
19 | browse_parser = subparsers.add_parser("browse")
20 |
21 | data_parser = subparsers.add_parser("data")
22 | data_parser.add_argument("--concurrency", type=int, default=8)
23 | data_parser.add_argument("--limit", type=int, default=1000)
24 |
25 | benchmark_parser = subparsers.add_parser("benchmark")
26 |
27 | progress_parser = subparsers.add_parser("progress")
28 |
29 | train_parser = subparsers.add_parser("train")
30 | train_parser.add_argument("--data", type=str, required=True)
31 | train_parser.add_argument("--epochs", type=int, default=10)
32 | train_parser.add_argument("--batch-size", type=int, default=8)
33 | train_parser.add_argument("--checkpoint", type=str, required=True)
34 | train_parser.add_argument("--samples-path", type=str)
35 | train_parser.add_argument("--seed", type=int, default=42069)
36 |
37 | generate_parser = subparsers.add_parser("generate")
38 | generate_parser.add_argument("--checkpoint", type=str)
39 | generate_parser.add_argument("--samples", type=str)
40 |
41 | video_parser = subparsers.add_parser("video")
42 | video_parser.add_argument("--checkpoint", type=str)
43 | video_parser.add_argument("--samples", type=str)
44 | video_parser.add_argument("-o", "--out", type=str)
45 |
46 | sample_parser = subparsers.add_parser("sample")
47 | sample_parser.add_argument("--checkpoint", type=str)
48 | sample_parser.add_argument("--samples", type=str)
49 |
50 | onnxify_parser = subparsers.add_parser("onnxify")
51 | onnxify_parser.add_argument("--checkpoint", type=str)
52 | onnxify_parser.add_argument("-o", "--output", type=str)
53 |
54 | args = parser.parse_args()
55 |
56 | if args.command == "browse":
57 | nude2.browse.view()
58 | elif args.command == "data":
59 | nude2.data.main(args.concurrency, args.limit)
60 | elif args.command == "benchmark":
61 | nude2.benchmark.main()
62 | elif args.command == "progress":
63 | nude2.progress.main()
64 | elif args.command == "train":
65 | nude2.train.main(
66 | args.data,
67 | args.epochs,
68 | args.batch_size,
69 | args.checkpoint,
70 | args.samples_path,
71 | seed=args.seed,
72 | )
73 | elif args.command == "onnxify":
74 | nude2.onnxify.main(
75 | args.checkpoint,
76 | args.output,
77 | )
78 | elif args.command == "generate":
79 | nude2.generate.main(args.checkpoint)
80 | elif args.command == "video":
81 | nude2.sample.main(args.samples, args.out)
82 | elif args.command == "sample":
83 | nude2.sample.main(args.checkpoint, args.samples)
84 |
--------------------------------------------------------------------------------
/src/nude2/sample.py:
--------------------------------------------------------------------------------
1 | import PIL
2 | import os.path
3 | import torch
4 |
5 | import torch.nn as nn
6 |
7 | from glob import glob
8 |
9 | from collections import defaultdict
10 | import itertools
11 | import subprocess
12 | import random
13 |
14 | from nude2.model import Generator198x198
15 | import nude2.data
16 |
17 |
18 | def gridify(images, nrows, ncols):
19 | WIDTH = HEIGHT = 64
20 | PADDING = 4
21 |
22 | w = (ncols+1)*PADDING + ncols*WIDTH
23 | h = (nrows+1)*PADDING + nrows*HEIGHT
24 |
25 | elpapa = PIL.Image.new("RGB", (w, h))
26 | elpapa.paste((255, 255, 255), (0, 0, w, h))
27 |
28 |     for n, img in enumerate(images):
29 |         col = n % ncols
30 |         row = n // ncols
31 |         x = PADDING + col*(WIDTH+PADDING)
32 |         y = PADDING + row*(HEIGHT+PADDING)
33 |         elpapa.paste(img, (x, y))
34 |
35 | return elpapa
36 |
37 |
38 | def color_grid(checkpoint_path):
39 |     """Save a color-gradient sanity grid (renamed: was a duplicate ``main2``)."""
40 | NCOLS = NROWS = 8
41 |
42 | images = [
43 | PIL.Image.new("RGB", (64, 64))
44 | for i in range(NROWS*NCOLS)
45 | ]
46 |
47 | for i, img in enumerate(images):
48 | x = i % NCOLS
49 | y = i // NCOLS
50 |
51 | r = int(x/(NCOLS-1)*255)
52 | g = 0
53 | b = int(y/(NROWS-1)*255)
54 |
55 | img.paste(
56 | (r, g, b),
57 | (0, 0, 64, 64),
58 | )
59 |
60 | elpapa = gridify(images, NROWS, NCOLS)
61 |     os.makedirs("colsamps", exist_ok=True)
62 |     elpapa.save("colsamps/yikes.png")
63 |
64 |
65 | def main2(samples_path, out):
66 |     """Tile per-epoch sample images into 8x8 grids and render them into a video."""
67 | fnames = sorted(glob(os.path.join(samples_path, "*.jpg")))
68 |     os.makedirs("colsamps", exist_ok=True)
69 |
70 |     image_set = defaultdict(list)
71 |     images = defaultdict(list)
72 |
73 | for fname in sorted(fnames):
74 | fname = os.path.basename(fname)
75 | g = fname.split("-")[1]
76 | image_set[g].append(fname)
77 |
78 | for epoch in sorted(image_set.keys()):
79 | fnames = sorted(image_set[epoch])[0:64]
80 | for fname in fnames:
81 | with PIL.Image.open(os.path.join(samples_path, fname)) as img:
82 | images[epoch].append(img.resize((64, 64)))
83 |
84 | for epoch in images.keys():
85 | elpapa = gridify(images[epoch], 8, 8)
86 | print(f"saving epoch={epoch}")
87 | elpapa.save(f"colsamps/{epoch}.png")
88 |
89 |
90 | subprocess.run([
91 | "ffmpeg",
92 | "-framerate", "30",
93 | "-pattern_type", "glob",
94 | "-i", "colsamps/*.png",
95 | "-c:v", "libx264",
96 | "-pix_fmt", "yuv420p",
97 | out,
98 | ])
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 | # Useful with:
107 | """
108 | ffmpeg -framerate 30 \
109 | -pattern_type glob \
110 | -i '*.jpg' \
111 | -c:v libx264 \
112 | -pix_fmt yuv420p \
113 | out.mp4
114 | """
115 |
116 |
117 | def main(checkpoint_path, samples_path):
118 | states = torch.load(
119 | checkpoint_path,
120 | map_location=torch.device('cpu'),
121 | )
122 |
123 | g = Generator198x198()
124 | g.load_state_dict(states["g"])
125 | g.eval()
126 |
127 | batch_size = 64
128 |
129 |     noise = torch.rand(batch_size, 100, 1, 1)
130 |
131 |     # One forward pass for the whole batch instead of one per image
132 |     with torch.no_grad():
133 |         out = g(noise)
134 |
135 |     for i in range(batch_size):
136 |         img = nude2.data.MetCenterCroppedDataset.pilify(out[i])
137 |         img.save(os.path.join(samples_path, f"sample-{i}.jpg"))
--------------------------------------------------------------------------------
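Note: gridify() pads and tiles equally sized PIL images; a standalone sketch
(the tile color is arbitrary):

    import PIL.Image

    from nude2.sample import gridify

    tiles = [PIL.Image.new("RGB", (64, 64), (200, 30, 30)) for _ in range(64)]
    grid = gridify(tiles, 8, 8)  # 8 rows x 8 cols with 4px gutters
    grid.save("grid.png")
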
/src/nude2/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | ndf = ngf = 64
5 | nz = 100
6 | nc = 3
7 |
8 |
9 | class Discriminator(nn.Module):
10 | """Module to discriminate between real and fake images"""
11 |
12 | def __init__(self):
13 | super(Discriminator, self).__init__()
14 |
15 | self.main = nn.Sequential(
16 | # input is ``(nc) x 64 x 64``
17 | nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
18 | nn.LeakyReLU(0.2, inplace=True),
19 | # state size. ``(ndf) x 32 x 32``
20 | nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
21 | nn.BatchNorm2d(ndf * 2),
22 | nn.LeakyReLU(0.2, inplace=True),
23 | # state size. ``(ndf*2) x 16 x 16``
24 | nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
25 | nn.BatchNorm2d(ndf * 4),
26 | nn.LeakyReLU(0.2, inplace=True),
27 | # state size. ``(ndf*4) x 8 x 8``
28 | nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
29 | nn.BatchNorm2d(ndf * 8),
30 | nn.LeakyReLU(0.2, inplace=True),
31 | # state size. ``(ndf*8) x 4 x 4``
32 | nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
33 | nn.Sigmoid(),
34 | )
35 |
36 | def forward(self, input):
37 | return self.main(input)
38 |
39 |
40 | class Generator(nn.Module):
41 | """Module to generate an image from a feature vector"""
42 |
43 | def __init__(self):
44 | super(Generator, self).__init__()
45 | self.main = nn.Sequential(
46 | # input is Z, going into a convolution
47 | nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
48 | nn.BatchNorm2d(ngf * 8),
49 | nn.ReLU(True),
50 | # state size. ``(ngf*8) x 4 x 4``
51 | nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
52 | nn.BatchNorm2d(ngf * 4),
53 | nn.ReLU(True),
54 | # state size. ``(ngf*4) x 8 x 8``
55 |             nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
56 |             nn.BatchNorm2d(ngf * 2),
57 |             nn.ReLU(True),
58 |             # state size. ``(ngf*2) x 16 x 16``
59 |             nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
60 |             nn.BatchNorm2d(ngf),
61 |             nn.ReLU(True),
62 |             # state size. ``(ngf) x 32 x 32``
63 |             nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
64 |             nn.Tanh(),
65 |             # state size. ``(nc) x 64 x 64``
66 | )
67 |
68 |
69 | def forward(self, input):
70 | return self.main(input)
71 |
72 |
73 | class Generator198x198(nn.Module):
74 | """Module to generate an image from a feature vector"""
75 |
76 | def __init__(self):
77 | super(Generator198x198, self).__init__()
78 |
79 | self.main = nn.Sequential(
80 |             nn.ConvTranspose2d(nz, ngf * 8, 5, 2, 0, bias=False),
81 |             nn.BatchNorm2d(ngf * 8),
82 |             nn.ReLU(True),
83 |
84 |             nn.ConvTranspose2d(ngf * 8, ngf * 4, 6, 3, 1, bias=False),
85 |             nn.BatchNorm2d(ngf * 4),
86 |             nn.ReLU(True),
87 |
88 |             nn.ConvTranspose2d(ngf * 4, ngf * 2, 6, 3, 1, bias=False),
89 |             nn.BatchNorm2d(ngf * 2),
90 |             nn.ReLU(True),
91 |
92 |             nn.ConvTranspose2d(ngf * 2, ngf * 1, 5, 2, 1, bias=False),
93 |             nn.BatchNorm2d(ngf * 1),
94 |             nn.ReLU(True),
95 |
96 |             nn.ConvTranspose2d(ngf * 1, nc, 4, 2, 1, bias=False),
97 |             nn.Tanh(),
98 | )
99 |
100 |
101 | def forward(self, input):
102 | return self.main(input)
103 |
--------------------------------------------------------------------------------
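Note: the layer sizes in Generator198x198 can be sanity-checked with the
transposed-convolution size formula out = (in - 1)*stride - 2*padding + kernel
(dilation 1, no output padding); a quick check:

    def convt_out(size, k, s, p):
        """Spatial output size of ConvTranspose2d with dilation=1, output_padding=0."""
        return (size - 1) * s - 2 * p + k

    size = 1
    for k, s, p in [(5, 2, 0), (6, 3, 1), (6, 3, 1), (5, 2, 1), (4, 2, 1)]:
        size = convt_out(size, k, s, p)

    print(size)  # 198
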
/src/nude2/onnxify.py:
--------------------------------------------------------------------------------
1 | import tempfile
2 | import PIL
3 | import onnxruntime as ort
4 | import torch
5 | import torch.nn as nn
6 | import torchvision.transforms as transforms
7 | import numpy as np
8 |
9 | import onnx
10 | from onnx import compose
11 |
12 | from nude2.model import Generator198x198
13 | import nude2.data
14 |
15 | class ImageGenerator(Generator198x198):
16 |
17 |
18 | def __init__(self):
19 | super().__init__()
20 | s = 1/255.
21 | self.pilify = transforms.Compose([
22 | transforms.Normalize(mean=[0., 0., 0.], std=[1/0.5, 1/0.5, 1/0.5]),
23 | transforms.Normalize(mean=[-0.5, -0.5, -0.5], std=[s, s, s]),
24 | ])
25 |
26 | def forward(self, x):
27 | return self.pilify(super().forward(x))
28 |
29 |
30 | class Transform(nn.Module):
31 | def __init__(self):
32 | super().__init__()
33 | s = 1.0
34 | self.normalize1 = transforms.Normalize(mean=[0., 0., 0.], std=[1/0.5, 1/0.5, 1/0.5])
35 | self.normalize2 = transforms.Normalize(mean=[-0.5, -0.5, -0.5], std=[s, s, s])
36 |
37 | def forward(self, x):
38 | x = self.normalize1(x)
39 | x = self.normalize2(x)
40 | # x = self.to_pil(x)
41 | return x
42 |
43 |
44 |
45 | def main(checkpoint, output):
46 |
47 |     # Load generator weights from the checkpoint passed on the command line
48 |     fname = checkpoint
49 |
50 | states = torch.load(fname, map_location="cpu")
51 |
52 | with tempfile.NamedTemporaryFile("w") as f:
53 | t = Transform()
54 | x = torch.randn(1, 3, 1, 1)
55 | torch.onnx.export(t, x, f.name, opset_version=10, input_names=["raw"], output_names=["output"])
56 | t_onnx = onnx.load(f.name)
57 |
58 | with tempfile.NamedTemporaryFile("w") as f:
59 | g = Generator198x198()
60 | g.load_state_dict(states["g"])
61 | g.cpu()
62 |         # Export in eval mode so BatchNorm uses its running statistics
63 |         g.eval()
64 |         # Fixed dummy input used to trace the graph
65 |         x = torch.ones(1, 100, 1, 1)
66 | torch.onnx.export(
67 | g,
68 | x,
69 | f.name,
70 | export_params=True,
71 | opset_version=10,
72 | do_constant_folding=True,
73 | input_names=["vector"],
74 | output_names=["unnormalized_output"],
75 | )
76 |
77 | g_onnx = onnx.load(f.name)
78 | onnx.checker.check_model(g_onnx)
79 |
80 | # Session
81 | ort_sess = ort.InferenceSession(f.name)
82 | ort_outs = ort_sess.run(None, {"vector": x.numpy()})[0]
83 | g_outs = g(x).detach().numpy()
84 |
85 | # print(g_outs[0, 0, :10, :10])
86 |
87 | if not np.allclose(ort_outs, g_outs, rtol=1e-6, atol=1e-6):
88 | print("WARNING: predicted matrices aren't close!")
89 |
90 | model = compose.merge_models(g_onnx, t_onnx, io_map=[("unnormalized_output", "raw")])
91 | onnx.save_model(model, output)
92 |
93 | # XXX: GENERATE SAMPLE
94 | # session = ort.InferenceSession("checkpoints/nice.onnx")
95 | # res = session.run(None, {"vector": x.numpy()})
96 | # res = res[0][0]
97 | # res = res.astype(np.uint8)
98 | # res = np.moveaxis(res, 0, -1)
99 | # print(res.shape)
100 | # img = PIL.Image.fromarray(res, mode="RGB")
101 | # img.save("cmon.png")
102 | # print(img)
103 |
104 |
105 | # XXX: GENERATE PYTORCH SAMPLE
106 | # batch_size = 64
107 | # x = torch.randn(batch_size, 100, 1, 1)
108 | # y = g(x)
109 | # y = t(y)
110 | # img = nude2.data.MetCenterCroppedDataset.pilify(y[0])
111 | # img.save("yea.png")
112 |
113 | # print((ort_outs - g_outs).mean())
114 | # print((ort_outs - g_outs).max())
115 |
--------------------------------------------------------------------------------
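Note: the merged graph keeps the generator's "vector" input and ends at the
"output" of Transform; a hedged inference sketch with onnxruntime (the .onnx
path is whatever was passed as --output and is hypothetical here):

    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession("generator.onnx")  # hypothetical output path
    z = np.random.randn(1, 100, 1, 1).astype(np.float32)
    img = sess.run(None, {"vector": z})[0]

    print(img.shape)  # (1, 3, 198, 198), rescaled by Transform's normalization
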
/src/nude2/static/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | lol
5 |
6 |
7 |
30 |
31 |
32 |
33 |
34 |
60 |
61 |
62 |
63 |
64 |
136 |
137 |
138 |
139 |
--------------------------------------------------------------------------------
/src/nude2/train.py:
--------------------------------------------------------------------------------
1 | """
2 | DCGAN stolen from
3 | """
4 |
5 |
6 | import json
7 | import logging
8 | import nude2.data
9 | import os
10 | import os.path
11 | import random
12 | import torch
13 | import torch.nn as nn
14 | import torch.optim
15 | import torch.utils.data as data
16 | import torchvision
17 |
18 | from datetime import datetime
19 | from nude2.progress import ProgressBar
20 | from nude2.utils import splash
21 |
22 | from nude2.model import Generator, Discriminator
23 | from nude2.train_constants import *
24 |
25 | import nude2.model_v1 as model_v1
26 |
27 |
28 | LOG_FORMAT = "\033[95m%(asctime)s\033[00m [%(levelname)s] %(message)s"
29 | logging.basicConfig(format=LOG_FORMAT)
30 | logger = logging.getLogger(__name__)
31 | logger.setLevel(logging.INFO)
32 |
33 |
34 | def weights_init(m):
35 |     """Initialize Conv and BatchNorm weights per the DCGAN paper (N(0, 0.02))"""
36 | classname = m.__class__.__name__
37 | if classname.find('Conv') != -1:
38 | nn.init.normal_(m.weight.data, 0.0, 0.02)
39 | elif classname.find('BatchNorm') != -1:
40 | nn.init.normal_(m.weight.data, 1.0, 0.02)
41 | nn.init.constant_(m.bias.data, 0)
42 |
43 |
44 |
45 |
46 | def main(data_folder, num_epochs, batch_size, checkpoint_path, samples_path, seed=None):
47 |
48 |     if seed is not None:
49 |         # Seed with the value that was actually passed in (was hard-coded to 999)
50 |         random.seed(seed)
51 |         torch.manual_seed(seed)
52 |         torch.use_deterministic_algorithms(True)
53 |
54 | data_dir = os.path.expanduser(data_folder)
55 |
56 | splash("splash")
57 | print("\n\n")
58 | print("NUDE 2.0")
59 | print("========")
60 | print(f"data .......... \033[95m{data_dir}\033[00m")
61 | print(f"epochs ........ \033[96m{num_epochs}\033[00m")
62 | print(f"batch size .... \033[95m{batch_size}\033[00m")
63 | print(f"device ........ \033[95m{device}\033[00m")
64 | print(f"checkpoint .... \033[95m{checkpoint_path}\033[00m")
65 | print(f"samples path .. \033[95m{samples_path}\033[00m")
66 | print()
67 |
68 | if samples_path is not None:
69 | os.makedirs(samples_path, exist_ok=True)
70 | with open(os.path.join(samples_path, "meta.json"), "w") as f:
71 | json.dump({
72 | "data_folder": data_folder,
73 | "num_epochs": num_epochs,
74 | "batch_size": batch_size,
75 | "checkpoint_path": checkpoint_path,
76 | "samples_path": samples_path,
77 | "lr": lr,
78 | "beta1": beta1,
79 | }, f)
80 |
81 | # dataset = nude2.data.MetCenterCroppedDataset(data_dir)
82 |
83 | dataset = nude2.data.CachedDataset(
84 | data_dir,
85 | "~/.cache/nude2/images-random-crop-256x256",
86 | )
87 |
88 | dataloader = data.DataLoader(
89 | dataset,
90 | batch_size=batch_size,
91 | shuffle=True,
92 | num_workers=8,
93 | )
94 |
95 | g = Generator().to(device)
96 | g.apply(weights_init)
97 |
98 | d = Discriminator().to(device)
99 | d.apply(weights_init)
100 |
101 | total_params = sum(
102 | param.numel()
103 | for param in g.parameters()
104 | )
105 |
106 | print("Total params =", total_params)
107 |
108 | criterion = nn.BCELoss()
109 |
110 | fixed_noise = torch.randn(64, nz, 1, 1, device=device)
111 |
112 | real_label = 1.0
113 | fake_label = 0.0
114 |
115 | optimizerD = torch.optim.Adam(d.parameters(), lr=lr, betas=(beta1, 0.999))
116 | optimizerG = torch.optim.Adam(g.parameters(), lr=lr, betas=(beta1, 0.999))
117 |
118 | try:
119 | states = torch.load(checkpoint_path)
120 | g.load_state_dict(states["g"])
121 | d.load_state_dict(states["d"])
122 | epoch = states["epoch"] + 1
123 |         logger.info(f"Loaded checkpoint from epoch {epoch-1}")
124 | except FileNotFoundError:
125 |         logger.warning("Could not find specified checkpoint... starting GAN at epoch=0")
126 | epoch = 0
127 |
128 | last_epoch = epoch + num_epochs
129 |
130 | for epoch in range(epoch, epoch + num_epochs):
131 |
132 | start = datetime.utcnow()
133 |
134 | with ProgressBar(len(dataloader), prefix=f"[epoch={epoch:04d}/{last_epoch}] ", size=40) as progress:
135 | for i, imgs in enumerate(dataloader):
136 |
137 | # Discriminate against real data
138 | d.zero_grad()
139 | real_cpu = imgs.to(device)
140 | b_size = real_cpu.size(0)
141 | label = torch.full(
142 | (b_size,),
143 | real_label,
144 | dtype=torch.float,
145 | device=device,
146 | )
147 | output = d(real_cpu).view(-1)
148 | errD_real = criterion(output, label)
149 | errD_real.backward()
150 | D_x = output.mean().item()
151 |
152 | # Discriminate against fake data
153 | noise = torch.randn(b_size, nz, 1, 1, device=device)
154 | fake = g(noise)
155 | label.fill_(fake_label)
156 |
157 | output = d(fake.detach()).view(-1)
158 | errD_fake = criterion(output, label)
159 |
160 | errD_fake.backward()
161 | D_G_z1 = output.mean().item()
162 | errD = errD_real + errD_fake
163 | optimizerD.step()
164 |
165 | # Increment generator
166 | g.zero_grad()
167 | label.fill_(real_label) # fake labels are real for generator cost
168 | output = d(fake).view(-1)
169 | errG = criterion(output, label)
170 | errG.backward()
171 | D_G_z2 = output.mean().item()
172 | optimizerG.step()
173 |
174 | # inc
175 | progress.inc()
176 |
177 | sample = g(fixed_noise).detach().cpu()
178 |
179 | if samples_path is not None:
180 | for i in range(sample.size(0)):
181 | img = nude2.data.MetCenterCroppedDataset.pilify(sample[i])
182 | fname = f"sample-{epoch:04d}-{i:02d}.jpg"
183 | img.save(os.path.join(samples_path, fname))
184 |
185 | dur = datetime.utcnow() - start
186 |
187 |         torch.save({
188 | "g": g.state_dict(),
189 | "d": d.state_dict(),
190 | "epoch": epoch,
191 | }, checkpoint_path)
192 |
--------------------------------------------------------------------------------
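Note: checkpoints written at the end of each epoch are plain dicts; a sketch
for inspecting or resuming one on a CPU-only machine (the path is hypothetical):

    import torch

    state = torch.load("checkpoint.pt", map_location="cpu")

    print(sorted(state))   # ['d', 'epoch', 'g']
    print(state["epoch"])  # last completed epoch
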
/src/nude2/data.py:
--------------------------------------------------------------------------------
1 | import PIL.Image
2 | import csv
3 | import ctypes
4 | import hashlib
5 | import json
6 | import logging
7 | import multiprocessing as mp
8 | import multiprocessing.dummy
9 | import numpy as np
10 | import os.path
11 | import requests
12 | import shutil
13 | import sqlite3
14 | import tempfile
15 | import torch
16 | import torchvision
17 | import urllib.request
18 |
19 | from collections import OrderedDict
20 | from datetime import datetime
21 | from glob import glob
22 | from torch.utils.data import Dataset, DataLoader
23 |
24 |
25 | logger = logging.getLogger(__name__)
26 |
27 |
28 | CACHE_DIR = os.path.expanduser(os.path.join("~", ".cache", "nude2", "data"))
29 |
30 |
31 | IMAGES_CSV = os.path.join(CACHE_DIR, "met-images.csv")
32 |
33 |
34 | MET_CSV = os.path.join(CACHE_DIR, "MetObjects.csv")
35 |
36 |
37 | DB = os.path.join(CACHE_DIR, "met.db")
38 |
39 |
40 | MET_COLS = OrderedDict([
41 | ("Object Number", "VARCHAR(256)"),
42 | ("Is Highlight", "BOOL"),
43 | ("Is Timeline Work", "BOOL"),
44 | ("Is Public Domain", "BOOL"),
45 | ("Object ID", "NUMBER"),
46 | ("Gallery Number", "NUMBER"),
47 | ("Department", "VARCHAR(255)"),
48 | ("AccessionYear", "VARCHAR(255)"),
49 | ("Object Name", "VARCHAR(255)"),
50 | ("Title", "VARCHAR(255)"),
51 | ("Culture", "VARCHAR(255)"),
52 | ("Period", "VARCHAR(255)"),
53 | ("Dynasty", "VARCHAR(255)"),
54 | ("Reign", "VARCHAR(255)"),
55 | ("Portfolio", "VARCHAR(255)"),
56 | ("Constituent ID", "VARCHAR(255)"),
57 | ("Artist Role", "VARCHAR(255)"),
58 | ("Artist Prefix", "VARCHAR(255)"),
59 | ("Artist Display Name", "VARCHAR(255)"),
60 | ("Artist Display Bio", "VARCHAR(255)"),
61 | ("Artist Suffix", "VARCHAR(255)"),
62 | ("Artist Alpha Sort", "VARCHAR(255)"),
63 | ("Artist Nationality", "VARCHAR(255)"),
64 | ("Artist Begin Date", "VARCHAR(255)"),
65 | ("Artist End Date", "VARCHAR(255)"),
66 | ("Artist Gender", "VARCHAR(255)"),
67 | ("Artist ULAN URL", "VARCHAR(255)"),
68 | ("Artist Wikidata URL", "VARCHAR(255)"),
69 | ("Object Date", "VARCHAR(255)"),
70 | ("Object Begin Date", "VARCHAR(255)"),
71 | ("Object End Date", "VARCHAR(255)"),
72 | ("Medium", "VARCHAR(255)"),
73 | ("Dimensions", "VARCHAR(255)"),
74 | ("Credit Line", "VARCHAR(255)"),
75 | ("Geography Type", "VARCHAR(255)"),
76 | ("City", "VARCHAR(255)"),
77 | ("State", "VARCHAR(255)"),
78 | ("County", "VARCHAR(255)"),
79 | ("Country", "VARCHAR(255)"),
80 | ("Region", "VARCHAR(255)"),
81 | ("Subregion", "VARCHAR(255)"),
82 | ("Locale", "VARCHAR(255)"),
83 | ("Locus", "VARCHAR(255)"),
84 | ("Excavation", "VARCHAR(255)"),
85 | ("River", "VARCHAR(255)"),
86 | ("Classification", "VARCHAR(255)"),
87 | ("Rights and Reproduction", "VARCHAR(255)"),
88 | ("Link Resource", "VARCHAR(255)"),
89 | ("Object Wikidata URL", "VARCHAR(255)"),
90 | ("Metadata Date", "VARCHAR(255)"),
91 | ("Repository", "VARCHAR(255)"),
92 | ("Tags", "VARCHAR(255)"),
93 | ("Tags AAT URL", "VARCHAR(255)"),
94 | ("Tags Wikidata URL", "VARCHAR(255)"),
95 | ])
96 |
97 |
98 | CREATE_MET_TABLE = f"""
99 | CREATE TABLE IF NOT EXISTS met ({", ".join(
100 | " `{k}` {v}".format(k=k, v=v)
101 | for k, v in MET_COLS.items()
102 | )},
103 | CONSTRAINT pk_met PRIMARY KEY (`Object ID`)
104 | )"""
105 |
106 |
107 |
108 | CREATE_MET_IMAGES_TABLE = """
109 | CREATE TABLE IF NOT EXISTS met_images (
110 | `Object ID` NUMBER,
111 | `Image URL` VARCHAR(255),
112 | CONSTRAINT pk_met_images PRIMARY KEY(`Object ID`, `Image URL`)
113 | )
114 | """
115 |
116 |
117 | CREATE_MET_TAG_TABLE = """
118 | CREATE TABLE IF NOT EXISTS met_tags (
119 | `Object ID` NUMBER,
120 | `Tag` VARCHAR(255),
121 | CONSTRAINT pk_met_tag PRIMARY KEY (`Object ID`, `Tag`)
122 | )
123 | """
124 |
125 | SELECT_MATCHING_TAGS_COLS = [
126 | "Object ID",
127 | "Title",
128 | "Object Number",
129 | "Object Name",
130 | "Is Highlight",
131 | "Is Timeline Work",
132 | "Is Public Domain",
133 | "Image URL",
134 | "Medium",
135 | "Tags",
136 | "Tag",
137 | "Image URL",
138 | ]
139 |
140 |
141 | SELECT_MATCHING_TAGS = f"""
142 | CREATE VIEW IF NOT EXISTS tagged_images AS
143 | SELECT
144 | {", ".join(f"`{c}`" for c in SELECT_MATCHING_TAGS_COLS)}
145 | FROM met
146 | INNER JOIN met_tags USING (`Object ID`)
147 | INNER JOIN met_images USING (`Object ID`)
148 | """
149 |
150 |
151 | SELECT_TAGGED_MET_IMAGES_SQL = """
152 | WITH matched_object_ids AS (
153 | SELECT DISTINCT
154 | `Object ID`
155 | FROM met_tags
156 | WHERE `Tag` IN ?
157 | )
158 | SELECT
159 | *
160 | FROM met_images
161 | INNER JOIN matched_object_ids USING (`Object ID`)
162 | """
163 |
164 | def create_index_sql(table_name, cols):
165 | sanitized_cols = [c.lower().replace(" ", "") for c in cols]
166 | cols = [f"`{c}`" for c in cols]
167 | return f"""
168 | CREATE INDEX IF NOT EXISTS `idx__{table_name}__{'_'.join(sanitized_cols)}`
169 | ON {table_name} ({", ".join(cols)})
170 | """
171 |
172 |
173 | class MetData(object):
174 | """Metroplitan Museum of Art Data"""
175 |
176 | def __init__(self, loc=DB):
177 | self.loc = loc
178 | self.conn = sqlite3.connect(self.loc)
179 |
180 | if not self.is_bootstrapped():
181 | logger.info("Bootstrapping MET data")
182 | self.bootstrap()
183 |
184 | def is_bootstrapped(self):
185 | """Return whether database has been bootstrapped"""
186 | with self.conn as conn:
187 | curs = conn.cursor()
188 | curs.execute("SELECT name FROM sqlite_master WHERE type='table'")
189 | tables = set(row[0] for row in curs.fetchall())
190 | return tables == {"met", "met_images", "met_tags"}
191 |
192 | def bootstrap(self):
193 | with self.conn as conn:
194 | curs = conn.cursor()
195 |
196 | curs.execute(CREATE_MET_TABLE)
197 | curs.execute(create_index_sql("met", ["Object ID"]))
198 |
199 | with open(MET_CSV, "r") as fi:
200 | reader = csv.DictReader(fi)
201 | curs.executemany(
202 | f"INSERT INTO met VALUES ({', '.join('?' * len(MET_COLS))})",
203 | (tuple(row.values()) for row in reader),
204 | )
205 |
206 | curs.execute(CREATE_MET_IMAGES_TABLE)
207 | curs.execute(create_index_sql("met_images", ["Object ID"]))
208 |
209 | # XXX: executemany receives a `set` because there are dupes in the csv.
210 | # Just fix the csv here first
211 | with open(IMAGES_CSV, "r") as fi:
212 | reader = csv.DictReader(fi)
213 | curs.executemany(
214 | "INSERT INTO met_images VALUES (?, ?)",
215 | {tuple(row.values()) for row in reader},
216 | )
217 |
218 | curs.execute(CREATE_MET_TAG_TABLE)
219 | curs.execute(create_index_sql("met_tags", ["Object ID"]))
220 | curs.execute(create_index_sql("met_tags", ["Object ID", "Tag"]))
221 |
222 | curs.execute("SELECT `Object ID`, `Tags` FROM met WHERE `Tags` IS NOT NULL AND `Tags` != ''")
223 |
224 | rows = curs.fetchall()
225 |
226 | curs.executemany(
227 | "INSERT INTO met_tags VALUES (?, ?)",
228 | ((row[0], tag) for row in rows for tag in row[1].split("|")),
229 | )
230 |
231 | curs.execute(SELECT_MATCHING_TAGS)
232 |
233 | with self.conn as conn:
234 | curs = conn.cursor()
235 | curs.execute("VACUUM")
236 |
237 |         logger.info("Finished bootstrapping")
238 |
239 | def __len__(self):
240 | with self.conn:
241 | curs = self.conn.cursor()
242 | curs.execute("SELECT COUNT(*) FROM met_images")
243 | return curs.fetchone()[0]
244 |
245 | def __getitem__(self, idx):
246 | with self.conn:
247 | curs = self.conn.cursor()
248 | curs.execute(
249 | "SELECT * FROM met_images LIMIT 1 OFFSET ?",
250 | (idx, ),
251 | )
252 | return curs.fetchone()
253 |
254 | def select(self, sql, params=tuple()):
255 | assert sql.upper().startswith("SELECT")
256 | with self.conn:
257 | curs = self.conn.cursor()
258 | curs.execute(sql, params)
259 | return curs.fetchall()
260 |
261 | def fetch_tag(self, tag, medium):
262 | """Return artworks with a given tag"""
263 |
264 | tag = tag or ""
265 |
266 | medium = medium or ""
267 |
268 | Q = """
269 | WITH filtered_images AS (
270 | SELECT *
271 | FROM tagged_images
272 | WHERE
273 | COALESCE(?, '') = ''
274 | OR
275 | `Tag` = ?
276 | )
277 | SELECT *
278 | FROM filtered_images
279 | WHERE
280 | `Medium` LIKE '%' || ? || '%'
281 | """
282 | with self.conn:
283 | curs = self.conn.cursor()
284 | curs.execute(Q, (tag, tag, medium, ))
285 | return [
286 | dict(zip(SELECT_MATCHING_TAGS_COLS, row))
287 | for row in curs.fetchall()
288 | ]
289 |
290 |
291 | def retrieve_csv_data():
292 | """Store CSV's locally"""
293 |
294 | if not os.path.exists(CACHE_DIR):
295 | logger.info("Creating cache directory: %s", CACHE_DIR)
296 | os.makedirs(CACHE_DIR)
297 |
298 | if not os.path.exists(MET_CSV):
299 | logger.info("Downloading MET CSV: %s", MET_CSV)
300 | MET_CSV_URL = "https://github.com/metmuseum/openaccess/raw/master/MetObjects.csv"
301 | urllib.request.urlretrieve(MET_CSV_URL, MET_CSV)
302 |
303 | if not os.path.exists(IMAGES_CSV):
304 | logger.info("Using MET Image CSV: %s", IMAGES_CSV)
305 | MET_IMAGES_CSV_URL = "https://github.com/gregsadetsky/open-access-is-great-but-where-are-the-images/raw/main/1.data/met-images.csv"
306 | urllib.request.urlretrieve(MET_IMAGES_CSV_URL, IMAGES_CSV)
307 |
308 |
309 | def get_image_file_name(image_url):
310 | """..."""
311 | _, suf = os.path.splitext(image_url)
312 | sha = hashlib.sha256()
313 | sha.update(image_url.encode())
314 | return f"met-image-{sha.hexdigest().lower()}{suf.lower()}"
315 |
316 |
317 | class MetImageGetter(object):
318 | """Loading Cache for Met Images"""
319 |
320 | def __init__(self, cache_dir="~/.cache/nude2/images/", download=True):
321 | """Construct..."""
322 | self.cache_dir = os.path.expanduser(cache_dir)
323 | self.temp_dir = tempfile.mkdtemp(prefix="met-images-")
324 | self.download = download
325 |         logger.debug("Using temp dir: %s", self.temp_dir)
326 |
327 | def __del__(self):
328 | shutil.rmtree(self.temp_dir)
329 |
330 | def fetch(self, image_url):
331 | """Return a PIL image """
332 |
333 | image_url = image_url.replace(" ", "%20")
334 |
335 | if not os.path.exists(self.cache_dir):
336 | logger.info("Creating image cache directory: %s", self.cache_dir)
337 | os.makedirs(self.cache_dir)
338 |
339 | _, suf = os.path.splitext(image_url)
340 |
341 | sha = hashlib.sha256()
342 | sha.update(image_url.encode())
343 |
344 | fname = f"met-image-{sha.hexdigest().lower()}{suf.lower()}"
345 | ftemp = os.path.join(self.temp_dir, fname)
346 | fpath = os.path.join(self.cache_dir, fname)
347 |
348 | if not os.path.exists(fpath):
349 | if not self.download:
350 | return None
351 | logger.debug("Downloading image: %s", image_url)
352 |             try:
353 |                 urllib.request.urlretrieve(image_url, ftemp)
354 |                 shutil.move(ftemp, fpath)
355 |             except Exception:
356 |                 logger.error("Failed to download image: %s", image_url)
357 |                 return None
358 |
359 |
360 | return fpath
361 |
362 |
363 | class MetDataset(Dataset):
364 | """Configurable dataset"""
365 |
366 | crop = torchvision.transforms.Compose([
367 | torchvision.transforms.Resize((256, 256)),
368 | torchvision.transforms.CenterCrop(224),
369 | ])
370 |
371 | tensorify = torchvision.transforms.Compose([
372 | torchvision.transforms.ToTensor(),
373 | ])
374 |
375 | SELECT_QUERY = """
376 | SELECT
377 | `Object ID`,
378 | `Object Name`,
379 | `Title`,
380 | `Tags`,
381 | `Image URL`
382 | FROM met_images
383 | INNER JOIN met
384 | USING (`Object ID`)
385 | WHERE `Object Name` = 'Painting'
386 | """.strip()
387 |
388 | COUNT_QUERY = """
389 | SELECT
390 | COUNT(*)
391 | FROM met_images
392 | INNER JOIN met
393 | USING (`Object ID`)
394 | WHERE `Object Name` = 'Painting'
395 | """.strip()
396 |
397 | def __init__(self, tags, cache_dir="~/.cache/nude2/", tag_db="", force_refresh=False):
398 | self.cache_dir = os.path.expanduser(cache_dir)
399 | self.base_images_dir = os.path.join(self.cache_dir, "images")
400 | self.augmented_images_dir = os.path.join(self.cache_dir, "images-augmented")
401 |
402 | self.db = MetData(DB)
403 | self.fetcher = MetImageGetter(self.base_images_dir, download=False)
404 |
405 | base_rows = self.db.select(self.SELECT_QUERY)
406 | self.rows = list(self.augmentations(base_rows))
407 |
408 |         # Bound to self: fetch_crops() calls self.five_crop and self.tensorify
409 |         self.five_crop = torchvision.transforms.Compose([
410 |             torchvision.transforms.Resize(336),
411 |             torchvision.transforms.FiveCrop(224),
412 |         ])
413 |         self.tensorify = torchvision.transforms.Compose([
414 |             torchvision.transforms.ToTensor(),
415 |             torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
416 |         ])
417 |
418 | def fetch_crops(self, row):
419 | image_path = self.fetcher.fetch(row[-1])
420 | prefix, suffix = os.path.splitext(os.path.basename(image_path))
421 |         # Expected cache paths for the five crops of this image
422 |
423 | crop_fnames = [
424 | os.path.join(self.augmented_images_dir, f"{prefix}-{i}-0.jpg")
425 | for i in range(5)
426 | ]
427 |
428 | crops = []
429 |
430 | if all(os.path.exists(fname) for fname in crop_fnames):
431 | for fname in crop_fnames:
432 | with PIL.Image.open(fname) as i:
433 |                     crops.append(i.copy())  # copy: the file closes with the block
434 |
435 | else:
436 |             logger.warning("Generating new images for %s", prefix)
437 | try:
438 | with PIL.Image.open(image_path) as i:
439 | img = i.convert("RGB").copy()
440 | crops = self.five_crop(img)
441 | for i, c in enumerate(crops):
442 | fpath = os.path.join(self.augmented_images_dir, f"{prefix}-{i}-0.jpg")
443 | c.save(fpath)
444 | except PIL.Image.DecompressionBombError:
445 | logger.error("Decompression bomb error on image: %s", image_path)
446 |
447 | return crops
448 |
449 | def augmentations(self, rows):
450 |
451 | if not os.path.exists(self.augmented_images_dir):
452 | os.makedirs(self.augmented_images_dir)
453 |
454 | for row in rows:
455 |
456 | image_path = self.fetcher.fetch(row[-1])
457 |
458 | if image_path is None:
459 | continue
460 |
461 | crops = self.fetch_crops(row)
462 |
463 | object_id = row[0]
464 | original_path = image_path
465 |
466 | fname, suffix = os.path.splitext(os.path.basename(image_path))
467 |
468 | for i, c in enumerate(crops):
469 | aug_fname = f"{fname}-{i}-0{suffix}"
470 | aug_path = os.path.join(self.augmented_images_dir, aug_fname)
471 |
472 |             if not os.path.exists(aug_path):
473 |                 logger.debug("Saving augmentation: %s", aug_path)
474 |                 c.save(aug_path)
475 |
476 | res = {
477 | "fname": aug_fname,
478 | "path": aug_path,
479 | "nude": "Nude" in row[-2],
480 | "base_image_path": row[-1],
481 | }
482 |
483 | yield res
484 |
485 | def save_augmentation_labels(self):
486 | augmented_labels_dir = os.path.join(self.cache_dir, "labels-augmented")
487 |
488 | if not os.path.exists(augmented_labels_dir):
489 | os.makedirs(augmented_labels_dir)
490 |
491 | for row in self.rows:
492 | image_prefix, _ = os.path.splitext(row["fname"])
493 | label_path = os.path.join(augmented_labels_dir, f"{image_prefix}.json")
494 | row["label_path"] = label_path
495 |
496 | with open(label_path, "w") as fo:
497 | json.dump(row, fo)
498 |
499 |
500 | def __len__(self):
501 | return len(self.rows)
502 |
503 | def __getitem__(self, idx):
504 | with PIL.Image.open(self.rows[idx]["path"]) as i:
505 | row = self.rows[idx]
506 | row["image"] = self.tensorify(np.array(i))
507 | return row
508 |
509 |
510 | class MetCenterCroppedDataset(Dataset):
511 |
512 | image_size = 64
513 |
514 | tensorify = torchvision.transforms.Compose([
515 | torchvision.transforms.Resize(image_size),
516 | torchvision.transforms.CenterCrop(image_size),
517 | torchvision.transforms.ToTensor(),
518 | torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
519 | ])
520 |
521 | pilify = torchvision.transforms.Compose([
522 | torchvision.transforms.Normalize(mean=[0., 0., 0.], std=[1/0.5, 1/0.5, 1/0.5]),
523 | torchvision.transforms.Normalize(mean=[-0.5, -0.5, -0.5], std=[1., 1., 1.]),
524 | torchvision.transforms.ToPILImage(),
525 | ])
526 |
527 | def __init__(self, cache_dir="~/.cache/nude2/"):
528 | self.cache_dir = os.path.expanduser(cache_dir)
529 | self.image_dir = self.cache_dir
530 | self.fnames = glob(os.path.join(self.image_dir, "*.jpg"))
531 | self.cache = {}
532 |
533 | def __len__(self):
534 | return len(self.fnames)
535 |
536 | def __getitem__(self, idx):
537 | if idx not in self.cache:
538 | try:
539 | with PIL.Image.open(self.fnames[idx]) as i:
540 | self.cache[idx] = self.tensorify(i.convert("RGB"))
541 |             except PIL.Image.DecompressionBombError:
542 |                 logger.error("Decompression bomb, skipping: %s", self.fnames[idx])
543 |                 return None
544 |             except PIL.Image.DecompressionBombWarning:
545 |                 logger.warning("Decompression bomb warning: %s", self.fnames[idx])
546 | return self.cache[idx]
547 |
548 |
549 | class MetFiveCornerDataset(Dataset):
550 | """Five corners + flip"""
551 |
552 | uncropped_size = 256
553 |
554 | cropped_size = 128
555 |
556 | tencrop = torchvision.transforms.Compose([
557 | torchvision.transforms.Resize(uncropped_size),
558 | torchvision.transforms.TenCrop(cropped_size),
559 | ])
560 |
561 | tensorify = torchvision.transforms.Compose([
562 | torchvision.transforms.ToTensor(),
563 | torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
564 | ])
565 |
566 | pilify = torchvision.transforms.Compose([
567 | torchvision.transforms.Normalize(mean=[0., 0., 0.], std=[1/0.5, 1/0.5, 1/0.5]),
568 | torchvision.transforms.Normalize(mean=[-0.5, -0.5, -0.5], std=[1., 1., 1.]),
569 | torchvision.transforms.ToPILImage(),
570 | ])
571 |
572 | def __init__(self, cache_dir="~/.cache/nude2/"):
573 | """..."""
574 |
575 | with sqlite3.connect(DB) as conn:
576 | curs = conn.cursor()
577 | curs.execute("""
578 | SELECT
579 | `Image URL`
580 | FROM met_images
581 | INNER JOIN met_tags USING (`Object ID`)
582 | WHERE `Tag` LIKE '%Nude%'
583 | """.strip())
584 | self.nude_file_names = {
585 | get_image_file_name(row[0])
586 | for row in curs.fetchall()
587 | }
588 |
589 |
590 | self.cache_dir = os.path.expanduser(cache_dir)
591 | self.image_dir = self.cache_dir
592 |         self.fnames = glob(os.path.join(self.image_dir, "*.jpg"))
593 |
594 |         # self.fnames = self.fnames[0:100]  # uncomment to smoke-test on a subset
595 | self.hits = mp.Array(ctypes.c_bool, len(self.fnames))
596 |
597 |
598 | n = 10*len(self.fnames)
599 | c = 3
600 | w = h = self.cropped_size
601 |
602 | shared_array_base = mp.Array(ctypes.c_float, n * c * w * h)
603 | shared_array = np.ctypeslib.as_array(shared_array_base.get_obj())
604 | shared_array = shared_array.reshape(n, c, h, w)
605 | self.cache = torch.from_numpy(shared_array)
606 |
607 | def __len__(self):
608 | """Return the length of the dataset"""
609 | return 10 * len(self.fnames)
610 |
611 | def __getitem__(self, idx):
612 | """Return the image at the given point"""
613 | i = idx // 10
614 |
615 | hit = self.hits[i]
616 |
617 | fname = self.fnames[i]
618 | fname = os.path.basename(fname)
619 |
620 | is_nude = fname in self.nude_file_names
621 |
622 |         if not self.hits[i]:
623 |             # Raise PIL's decompression-bomb ceiling for very large scans
624 |             PIL.Image.MAX_IMAGE_PIXELS = 757164160 + 1
625 | with PIL.Image.open(self.fnames[i]) as raw_img:
626 | for j, cropped_image in enumerate(self.tencrop(raw_img.convert("RGB"))):
627 | assert 0 <= j < 10
628 |
629 | f, s = os.path.splitext(os.path.basename(self.fnames[i]))
630 |
631 | foutname = os.path.join(
632 | os.path.expanduser("~/.cache/nude2/images-tencrop"),
633 | f"{f}-{i}-{j}{s}",
634 | )
635 |
636 |                     os.makedirs(os.path.dirname(foutname), exist_ok=True)
637 |
638 | cropped_image.save(foutname)
639 | self.cache[10*i+j] = self.tensorify(cropped_image)
640 | self.hits[i] = True
641 |
642 | return is_nude, self.cache[idx]
643 |
644 |
645 | class CachedDataset(Dataset):
646 |
647 | cropped_size = 64
648 |
649 | uncropped_size = cropped_size // 2 * 3
650 |
651 | tensorify = torchvision.transforms.Compose([
652 | torchvision.transforms.ToTensor(),
653 | torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
654 | ])
655 |
656 | pilify = torchvision.transforms.Compose([
657 | torchvision.transforms.Normalize(mean=[0., 0., 0.], std=[1/0.5, 1/0.5, 1/0.5]),
658 | torchvision.transforms.Normalize(mean=[-0.5, -0.5, -0.5], std=[1., 1., 1.]),
659 | torchvision.transforms.ToPILImage(),
660 | ])
661 |
662 | def __init__(self, image_dir, cache_dir):
663 |
664 | self.image_dir = os.path.expanduser(image_dir)
665 |
666 | self.cache_dir = os.path.expanduser(cache_dir)
667 |
668 | self.image_fnames = [
669 | fname
670 | for suffix in ["jpg"]
671 | for fname in glob(os.path.join(self.image_dir, f"*.{suffix}"))
672 | ]
673 |
674 | self.transforms = [
675 |
676 | # Standard center-crop
677 | torchvision.transforms.Compose([
678 | torchvision.transforms.Resize(self.cropped_size),
679 | torchvision.transforms.CenterCrop(self.cropped_size),
680 | ]),
681 |
682 | # Horizontal flip
683 | torchvision.transforms.Compose([
684 | torchvision.transforms.Resize(self.cropped_size),
685 | torchvision.transforms.CenterCrop(self.cropped_size),
686 | torchvision.transforms.RandomHorizontalFlip(1.0),
687 | ]),
688 | ]
689 |
690 | self.transforms += [
691 | # Random crop and jitter
692 | torchvision.transforms.Compose([
693 | torchvision.transforms.Resize(self.uncropped_size),
694 | torchvision.transforms.RandomCrop(self.cropped_size),
695 | torchvision.transforms.RandomHorizontalFlip(0.5),
696 | ]),
697 | ] * 3
698 |
699 | self.transforms += [
700 | # Random crop and jitter
701 | torchvision.transforms.Compose([
702 | torchvision.transforms.Resize(self.uncropped_size),
703 | torchvision.transforms.RandomCrop(self.cropped_size),
704 |                 torchvision.transforms.ColorJitter(brightness=(0.5, 1.5), contrast=1, saturation=(0.5, 1.5), hue=(-0.1, 0.1)),
705 | torchvision.transforms.RandomAdjustSharpness(1.5),
706 | torchvision.transforms.RandomHorizontalFlip(0.5),
707 | ]),
708 | ] * 2
709 |
710 |
711 | os.makedirs(self.cache_dir, exist_ok=True)
712 |
713 |
714 | def get_filename(self, idx):
715 | """Return the filename"""
716 | i = idx // len(self.transforms)
717 | return self.image_fnames[i]
718 |
719 |
720 | def get_source_image(self, idx):
721 | """Return the source image for an index"""
722 | with PIL.Image.open(self.get_filename(idx)) as i:
723 |             return i.copy()  # copy so the image outlives the context manager
724 |
725 | def __len__(self):
726 | """Return the total number of augmented images"""
727 | return len(self.transforms) * len(self.image_fnames)
728 |
729 | def __getitem__(self, idx):
730 | """Return an image as a tensor from the dataset"""
731 |
732 | i = idx // len(self.transforms)
733 | r = idx % len(self.transforms)
734 |
735 | cache_fname = f"image-{i}-{r}.jpg"
736 | cache_fpath = os.path.join(self.cache_dir, cache_fname)
737 |
738 | if not os.path.exists(cache_fpath):
739 | fname = self.get_filename(idx)
740 |
741 | PIL.Image.MAX_IMAGE_PIXELS = 757164160 + 1
742 |
743 |             with PIL.Image.open(fname) as img:
744 |                 src = img.convert("RGB")
745 |
746 |             # Derive every cached variant from the source image, not the previous crop
747 |             for j, t in enumerate(self.transforms):
748 |                 out = t(src)
749 |
750 |                 with tempfile.NamedTemporaryFile(prefix="image-", suffix=".jpg", delete=False) as fo:
751 |                     out.save(fo.name)
752 |
753 |                 dest_fname = f"image-{i}-{j}.jpg"
754 |                 dest_fpath = os.path.join(self.cache_dir, dest_fname)
755 |
756 |                 shutil.move(fo.name, dest_fpath)
757 |
758 | with PIL.Image.open(cache_fpath) as img:
759 | return self.tensorify(img)
760 |
761 |
762 |
763 | def main(concurrency, limit):
764 |
765 | dataset = CachedDataset("~/images/", "~/images-random-crops")
766 | dataloader = DataLoader(dataset, batch_size=100, num_workers=8, shuffle=True)
767 |
768 | for row in dataloader:
769 | pass
770 |
771 |     return  # NOTE: the MetFiveCornerDataset benchmark below is currently disabled
772 |
773 | dataset = MetFiveCornerDataset(cache_dir="~/.cache/nude2/images/")
774 |
775 | dataloader = DataLoader(dataset, batch_size=100, num_workers=0, shuffle=False)
776 |
777 | for epoch in range(10):
778 | print(f"epoch = {epoch}")
779 |
780 | start = datetime.now()
781 |
782 | for hit, x in dataloader:
783 | pass
784 |
785 | dur = (datetime.now() - start).total_seconds()
786 |
787 | print("Duration:", dur)
788 |
--------------------------------------------------------------------------------
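Note: a minimal query sketch against the bootstrapped database (assumes the
two MET CSVs are already cached under ~/.cache/nude2/data, e.g. via
retrieve_csv_data()):

    from nude2.data import DB, MetData, retrieve_csv_data

    retrieve_csv_data()  # no-op when the CSVs are already cached
    db = MetData(DB)     # bootstraps the sqlite tables on first run

    rows = db.fetch_tag("Female Nudes", "oil")
    print(len(rows), rows[0]["Title"] if rows else None)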