├── README.md ├── pyproject.toml ├── src └── nude2 │ ├── static │ ├── favicon.ico │ └── index.html │ ├── utils.py │ ├── generate.py │ ├── benchmark.py │ ├── train_constants.py │ ├── splashes │ └── splash.txt │ ├── browse.py │ ├── progress.py │ ├── model_v1.py │ ├── cli.py │ ├── sample.py │ ├── model.py │ ├── onnxify.py │ ├── train.py │ └── data.py ├── pytest.ini ├── .gitignore ├── test ├── data_test.py └── train_test.py ├── setup.cfg └── samples └── index.html /README.md: -------------------------------------------------------------------------------- 1 | :cherry_blossom: 2 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools"] 3 | build-backend = "setuptools.build_meta" 4 | -------------------------------------------------------------------------------- /src/nude2/static/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whatever/nude-2.0/main/src/nude2/static/favicon.ico -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | minversion = 6.0 3 | addopts = -ra -q 4 | testpaths = 5 | test 6 | integration 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.avi 2 | *.egg-info 3 | *.jpg 4 | *.json 5 | *.mp4 6 | *.png 7 | *.pt 8 | *.swp 9 | *.tar.gz 10 | .DS_Store 11 | .ipynb_checkpoints/ 12 | __pycache__ 13 | -------------------------------------------------------------------------------- /src/nude2/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Feels lazy to name something "utils.py", but here we are. 
3 | """ 4 | 5 | 6 | import os.path 7 | 8 | 9 | def splash(splash_name): 10 | """Print a splash message.""" 11 | fname = os.path.join(os.path.dirname(__file__), "splashes", f"{splash_name}.txt") 12 | with open(fname, "r") as f: 13 | txt = f.read() 14 | print(f"\033[95m{txt}\033[00m") 15 | -------------------------------------------------------------------------------- /src/nude2/generate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | 4 | 5 | from nude2.train import Generator, weights_init 6 | from nude2.data import MetCenterCroppedDataset 7 | 8 | def main(checkpoint): 9 | 10 | states = torch.load(checkpoint) 11 | 12 | g = Generator() 13 | g.apply(weights_init) 14 | g.load_state_dict(states["g"]) 15 | 16 | fixed_noise = torch.randn(1, 100, 1, 1) 17 | vec = g(fixed_noise) 18 | img = MetCenterCroppedDataset.pilify(vec[0]) 19 | img.save("generated.png") 20 | -------------------------------------------------------------------------------- /test/data_test.py: -------------------------------------------------------------------------------- 1 | import PIL 2 | import torch 3 | import unittest 4 | 5 | 6 | from nude2.data import CachedDataset 7 | from nude2.train import Generator, Discriminator 8 | 9 | 10 | from train_test import WIDTH, HEIGHT 11 | 12 | 13 | class DataTest(unittest.TestCase): 14 | def test_resize(self): 15 | 16 | dataset = CachedDataset("./whatever", "./no-cache") 17 | img = PIL.Image.new("RGB", (8*WIDTH, 8*HEIGHT)) 18 | 19 | for t in dataset.transforms: 20 | img = t(img) 21 | w, h = img.size 22 | self.assertEqual(w, WIDTH) 23 | self.assertEqual(h, HEIGHT) 24 | -------------------------------------------------------------------------------- /src/nude2/benchmark.py: -------------------------------------------------------------------------------- 1 | import nude2.data 2 | from datetime import datetime, timedelta 3 | import time 4 | 5 | 6 | def main(): 7 | 8 | db = nude2.data.MetData(nude2.data.DB) 9 | 10 | tags = ["Sphinx", "Crucifixion", "Male Nudes", "Female Nudes", "asdasd", ""] 11 | 12 | mediums = ["oil", ""] 13 | 14 | cases = sorted( 15 | (tag, medium) 16 | for tag in tags 17 | for medium in mediums 18 | ) 19 | 20 | for tag, medium in cases: 21 | start = datetime.now() 22 | results = db.fetch_tag(tag, medium) 23 | duration = datetime.now() - start 24 | pad = 30 - len(tag) - len(medium) 25 | print(f"tag={tag}, medium={medium} {'.'*pad} {duration}, {len(results)}") 26 | -------------------------------------------------------------------------------- /src/nude2/train_constants.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | # Batch size during training 4 | batch_size = 1024 5 | 6 | # Spatial size of training images. All images will be resized to this 7 | # size using a transformer. 8 | image_size = 64 9 | 10 | # Number of channels in the training images. For color images this is 3 11 | nc = 3 12 | 13 | # Size of z latent vector (i.e. size of generator input) 14 | nz = 100 15 | 16 | # Size of feature maps in generator 17 | ngf = 64 18 | 19 | # Size of feature maps in discriminator 20 | ndf = 64 21 | 22 | # Learning rate for optimizers 23 | lr = 0.0002 24 | 25 | # Beta1 hyperparameter for Adam optimizers 26 | beta1 = 0.5 27 | 28 | # Number of GPUs available. Use 0 for CPU mode. 
29 | ngpu = 1 30 | 31 | # Set device 32 | device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu") 33 | -------------------------------------------------------------------------------- /test/train_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import unittest 3 | 4 | 5 | from nude2.train import Generator, Discriminator 6 | 7 | 8 | WIDTH = HEIGHT = 64 9 | 10 | 11 | class ModuleShapeTest(unittest.TestCase): 12 | """Test that the generator and discriminator compute the correct shapes""" 13 | 14 | def test_generator_shape(self): 15 | """Ensure that g(Nx100x1x1) -> Nx3x64x64""" 16 | 17 | g = Generator() 18 | 19 | r = g(torch.rand(10, 100, 1, 1)) 20 | 21 | self.assertEqual( 22 | r.shape, 23 | (10, 3, WIDTH, HEIGHT), 24 | ) 25 | 26 | def test_discriminator_shape(self): 27 | """Ensure that d(Nx3x64x64) -> Nx1x1x1""" 28 | 29 | d = Discriminator() 30 | 31 | r = d(torch.rand(8, 3, WIDTH, HEIGHT)) 32 | 33 | self.assertEqual( 34 | r.shape, 35 | (8, 1, 1, 1), 36 | ) 37 | -------------------------------------------------------------------------------- /src/nude2/splashes/splash.txt: -------------------------------------------------------------------------------- 1 | :::!~!!!!!:. 2 | .xUHWH!! !!?M88WHX:. 3 | .X*#M@$!! !X!M$$$$$$WWx:. 4 | :!!!!!!?H! :!$!$$$$$$$$$$8X: 5 | !!~ ~:~!! :~!$!#$$$$$$$$$$8X: 6 | :!~::!H!< ~.U$X!?R$$$$$$$$MM! 7 | ~!~!!!!~~ .:XW$$$U!!?$$$$$$RMM! 8 | !:~~~ .:!M"T#$$$$WX??#MRRMMM! 9 | ~?WuxiW*` `"#$$$$8!!!!??!!! 10 | :X- M$$$$ `"T#$T~!8$WUXU~ 11 | :%` ~#$$$m: ~!~ ?$$$$$$ 12 | :!`.- ~T$$$$8xx. .xWW- ~""##*" 13 | ..... -~~:<` ! ~?T#$$@@W@*?$$ /` 14 | W$@@M!!! .!~~ !! .:XUW$W!~ `"~: : 15 | #"~~`.:x%`!! !H: !WM$$$$Ti.: .!WUn+!` 16 | :::~:!!`:X~ .: ?H.!u "$$$B$$$!W:U!T$$M~ 17 | .~~ :X@!.-~ ?@WTWo("*$$$W$TH$! ` 18 | Wi.~!X$?!-~ : ?$$$B$Wu("**$RM! 19 | $R@i.~~ ! : ~$$$$$B$$en:`` 20 | ?MXT@Wx.~ : ~"##*$$$$M~ 21 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = nude2 3 | version = 0.0.1 4 | author = Matt <3 5 | author_email = matt@worldshadowgovernment.com 6 | url = https://github.com/whatever/nude-2.0 7 | description = Attempted recreation of Robbie Barrat's iconic "The Nude" 8 | long_description = file: README.md 9 | long_description_content_type = text/markdown 10 | keywords = nude 11 | license = UNLICENSE 12 | classifiers = 13 | License :: OSI Approved :: BSD License 14 | Programming Language :: Python :: 3 15 | 16 | [options] 17 | package_dir = 18 | = src 19 | packages = find: 20 | install_requires = 21 | onnxruntime==1.16.3 22 | Pillow==10.0.1 23 | pytest==7.4.3 24 | torch==2.1.0 25 | torchvision==0.16.0 26 | # albumentations==1.3.1 27 | # matplotlib==3.8.0 28 | # opencv-python==4.8.1.78 29 | 30 | [options.packages.find] 31 | where = src 32 | exclude = 33 | examples* 34 | # tools* 35 | # docs* 36 | 37 | [options.entry_points] 38 | console_scripts = 39 | nude = nude2.cli:main 40 | 41 | 42 | [options.package_data] 43 | example = data/schema.json, *.txt 44 | * = README.md 45 | -------------------------------------------------------------------------------- /samples/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | samples 5 | 6 | 7 | 18 | 44 | 45 | 46 |
47 | 48 | 49 | -------------------------------------------------------------------------------- /src/nude2/browse.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import http.server 3 | import json 4 | import nude2.data 5 | import os 6 | import threading 7 | import time 8 | import webbrowser 9 | 10 | 11 | from http.server import HTTPServer, SimpleHTTPRequestHandler 12 | 13 | 14 | HOST = "127.0.0.1" 15 | """Open server locally""" 16 | 17 | 18 | PORT = 8181 19 | """Default to port 8181""" 20 | 21 | 22 | DIR = os.path.realpath(os.path.join(os.path.dirname(__file__), "static")) 23 | """Use /static/ directory""" 24 | 25 | 26 | class StaticDirServer(SimpleHTTPRequestHandler): 27 | """Serve a directory of static files""" 28 | 29 | def __init__(self, *args, **kwargs): 30 | """Construct""" 31 | super().__init__(*args, directory=DIR, **kwargs) 32 | 33 | def do_GET(self): 34 | """Respond to get requests with static files or api response""" 35 | if not self.path.startswith("/api/v0/"): 36 | return super().do_GET() 37 | 38 | pieces = [ 39 | v.replace("%20", " ") 40 | for v in self.path.split("/", 4) 41 | ] 42 | 43 | medium, tag = pieces[-2:] 44 | 45 | db = nude2.data.MetData(nude2.data.DB) 46 | print("Searching for:", medium, tag) 47 | res = db.fetch_tag(tag, medium) 48 | 49 | self.send_response(200) 50 | self.send_header("Content-type", "application/json") 51 | self.end_headers() 52 | self.wfile.write(bytes(json.dumps({"results": res}).encode("ascii"))) 53 | 54 | 55 | def serve(): 56 | """Start the static server""" 57 | server_address = (HOST, PORT) 58 | httpd = HTTPServer(server_address, StaticDirServer) 59 | httpd.serve_forever() 60 | 61 | 62 | def view(): 63 | threading.Thread(target=start_server).start() 64 | time.sleep(10) 65 | 66 | 67 | def start_server(): 68 | print("Starting server...") 69 | threading.Thread(target=serve).start() 70 | 71 | print("Sleeping for 1 second...") 72 | time.sleep(1) 73 | 74 | print("Opening browser...") 75 | webbrowser.open(f"http://{HOST}:{PORT}") 76 | -------------------------------------------------------------------------------- /src/nude2/progress.py: -------------------------------------------------------------------------------- 1 | import random 2 | import time 3 | import sys 4 | 5 | from datetime import datetime 6 | 7 | class ProgressBar(object): 8 | """...""" 9 | 10 | def __init__(self, x, prefix="", size=100, clear=True): 11 | 12 | if isinstance(x, int): 13 | self.iter = range(x) 14 | else: 15 | self.iter = iter(x) 16 | 17 | self.start_time = datetime.now() 18 | self.clear = clear 19 | self.prefix = prefix 20 | self.size = size 21 | self.current = 0 22 | self.line = self.prefix + "[" + " "*self.size + "]" 23 | self.show() 24 | 25 | def x_x(self): 26 | eye = ["x", "-", "^", "o", "O"] 27 | return f"{random.choice(eye)}_{random.choice(eye)}" 28 | 29 | def show(self): 30 | 31 | # Percentage complete 32 | pct = self.current / len(self.iter) 33 | pct = round(pct * self.size) 34 | 35 | # Time elapsed 36 | dur = datetime.now() - self.start_time 37 | 38 | t = int(dur.total_seconds()) 39 | r = dur.total_seconds() - t 40 | 41 | h = t // 3600 42 | m = (t - h*3600) // 60 43 | s = t - h*3600 - m*60 44 | f = t - h*3600 - m*60 - s 45 | 46 | sys.stdout.write("\b"*len(self.line)) 47 | self.line = self.prefix + "[" + "*"*pct + " "*(self.size-pct) + "] " + self.x_x() 48 | self.line += f" {100*self.current/len(self.iter):0.2f}% [{h}:{m:02}:{s:02}] @ {datetime.now().strftime('%H:%M:%S')}\r" 49 | 
sys.stdout.write(self.line) 50 | sys.stdout.flush() 51 | 52 | def inc(self): 53 | self.current += 1 54 | self.show() 55 | 56 | def set(self, value): 57 | self.current = value 58 | self.show() 59 | 60 | def __enter__(self): 61 | return self 62 | 63 | def __exit__(self, type, value, traceback): 64 | print(self.line) 65 | if self.clear: 66 | sys.stdout.write("\b"*len(self.line)) 67 | 68 | class I(object): 69 | def __iter__(self): 70 | return self 71 | def __next__(self): 72 | return 1 73 | 74 | def main(): 75 | for _ in range(5): 76 | with ProgressBar(66) as progress: 77 | for _ in range(66): 78 | time.sleep(0.02) 79 | progress.inc() 80 | -------------------------------------------------------------------------------- /src/nude2/model_v1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from nude2.train_constants import * 5 | 6 | 7 | class Discriminator(nn.Module): 8 | """Module to discriminate between real and fake images""" 9 | 10 | def __init__(self): 11 | super(Discriminator, self).__init__() 12 | self.main = nn.Sequential( 13 | # input is ``(nc) x 64 x 64`` 14 | nn.Conv2d(nc, ndf, 4, 2, 1, bias=False), 15 | nn.LeakyReLU(0.2, inplace=True), 16 | # state size. ``(ndf) x 32 x 32`` 17 | nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False), 18 | nn.BatchNorm2d(ndf * 2), 19 | nn.LeakyReLU(0.2, inplace=True), 20 | # state size. ``(ndf*2) x 16 x 16`` 21 | nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False), 22 | nn.BatchNorm2d(ndf * 4), 23 | nn.LeakyReLU(0.2, inplace=True), 24 | # state size. ``(ndf*4) x 8 x 8`` 25 | nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False), 26 | nn.BatchNorm2d(ndf * 8), 27 | nn.LeakyReLU(0.2, inplace=True), 28 | # state size. ``(ndf*8) x 4 x 4`` 29 | nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False), 30 | nn.Sigmoid(), 31 | ) 32 | 33 | def forward(self, input): 34 | return self.main(input) 35 | 36 | 37 | class Generator(nn.Module): 38 | """Module to generate an image from a feature vector""" 39 | 40 | def __init__(self): 41 | super(Generator, self).__init__() 42 | self.main = nn.Sequential( 43 | # input is Z, going into a convolution 44 | nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False), 45 | nn.BatchNorm2d(ngf * 8), 46 | nn.ReLU(True), 47 | # state size. ``(ngf*8) x 4 x 4`` 48 | nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False), 49 | nn.BatchNorm2d(ngf * 4), 50 | nn.ReLU(True), 51 | # state size. ``(ngf*4) x 8 x 8`` 52 | nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False), 53 | nn.BatchNorm2d(ngf * 2), 54 | nn.ReLU(True), 55 | # state size. ``(ngf*2) x 16 x 16`` 56 | nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False), 57 | nn.BatchNorm2d(ngf), 58 | nn.ReLU(True), 59 | # state size. ``(ngf) x 32 x 32`` 60 | nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False), 61 | nn.Tanh() 62 | # state size. 
``(nc) x 64 x 64``, 63 | ) 64 | 65 | def forward(self, input): 66 | return self.main(input) 67 | -------------------------------------------------------------------------------- /src/nude2/cli.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import nude2.benchmark 4 | import nude2.browse 5 | import nude2.data 6 | import nude2.generate 7 | import nude2.onnxify 8 | import nude2.progress 9 | import nude2.sample 10 | import nude2.train 11 | 12 | import nude2 13 | 14 | def main(): 15 | parser = argparse.ArgumentParser() 16 | 17 | subparsers = parser.add_subparsers(dest="command") 18 | 19 | browse_parser = subparsers.add_parser("browse") 20 | 21 | data_parser = subparsers.add_parser("data") 22 | data_parser.add_argument("--concurrency", type=int, default=8) 23 | data_parser.add_argument("--limit", type=int, default=1000) 24 | 25 | benchmark_parser = subparsers.add_parser("benchmark") 26 | 27 | progress_parser = subparsers.add_parser("progress") 28 | 29 | train_parser = subparsers.add_parser("train") 30 | train_parser.add_argument("--data", type=str, required=True) 31 | train_parser.add_argument("--epochs", type=int, default=10) 32 | train_parser.add_argument("--batch-size", type=int, default=8) 33 | train_parser.add_argument("--checkpoint", type=str, required=True) 34 | train_parser.add_argument("--samples-path", type=str) 35 | train_parser.add_argument("--seed", type=int, default=42069) 36 | 37 | generate_parser = subparsers.add_parser("generate") 38 | generate_parser.add_argument("--checkpoint", type=str) 39 | generate_parser.add_argument("--samples", type=str) 40 | 41 | video_parser = subparsers.add_parser("video") 42 | video_parser.add_argument("--checkpoint", type=str) 43 | video_parser.add_argument("--samples", type=str) 44 | video_parser.add_argument("-o", "--out", type=str) 45 | 46 | sample_parser = subparsers.add_parser("sample") 47 | sample_parser.add_argument("--checkpoint", type=str) 48 | sample_parser.add_argument("--samples", type=str) 49 | 50 | onnxify_parser = subparsers.add_parser("onnxify") 51 | onnxify_parser.add_argument("--checkpoint", type=str) 52 | onnxify_parser.add_argument("-o", "--output", type=str) 53 | 54 | args = parser.parse_args() 55 | 56 | if args.command == "browse": 57 | nude2.browse.view() 58 | elif args.command == "data": 59 | nude2.data.main(args.concurrency, args.limit) 60 | elif args.command == "benchmark": 61 | nude2.benchmark.main() 62 | elif args.command == "progress": 63 | nude2.progress.main() 64 | elif args.command == "train": 65 | nude2.train.main( 66 | args.data, 67 | args.epochs, 68 | args.batch_size, 69 | args.checkpoint, 70 | args.samples_path, 71 | seed=args.seed, 72 | ) 73 | elif args.command == "onnxify": 74 | nude2.onnxify.main( 75 | args.checkpoint, 76 | args.output, 77 | ) 78 | elif args.command == "generate": 79 | nude2.generate.main(args.checkpoint) 80 | elif args.command == "video": 81 | nude2.sample.main2(args.samples, args.out) 82 | elif args.command == "sample": 83 | nude2.sample.main(args.checkpoint, args.samples) 84 | -------------------------------------------------------------------------------- /src/nude2/sample.py: -------------------------------------------------------------------------------- 1 | import PIL 2 | import os.path 3 | import torch 4 | 5 | import torch.nn as nn 6 | 7 | from glob import glob 8 | 9 | from collections import defaultdict 10 | import itertools 11 | import subprocess 12 | import random 13 | 14 | from nude2.model import Generator198x198 15 | 
import nude2.data 16 | 17 | 18 | def gridify(images, nrows, ncols): 19 | WIDTH = HEIGHT = 64 20 | PADDING = 4 21 | 22 | w = (ncols+1)*PADDING + ncols*WIDTH 23 | h = (nrows+1)*PADDING + nrows*HEIGHT 24 | 25 | elpapa = PIL.Image.new("RGB", (w, h)) 26 | elpapa.paste((255, 255, 255), (0, 0, w, h)) 27 | 28 | for n, img in enumerate(images): 29 | i = n // nrows 30 | j = n % ncols 31 | x = PADDING + i*(WIDTH+PADDING) 32 | y = PADDING + j*(HEIGHT+PADDING) 33 | elpapa.paste(img, (x, y)) 34 | 35 | return elpapa 36 | 37 | 38 | def main2(checkpoint_path): 39 | 40 | NCOLS = NROWS = 8 41 | 42 | images = [ 43 | PIL.Image.new("RGB", (64, 64)) 44 | for i in range(NROWS*NCOLS) 45 | ] 46 | 47 | for i, img in enumerate(images): 48 | x = i % NCOLS 49 | y = i // NCOLS 50 | 51 | r = int(x/(NCOLS-1)*255) 52 | g = 0 53 | b = int(y/(NROWS-1)*255) 54 | 55 | img.paste( 56 | (r, g, b), 57 | (0, 0, 64, 64), 58 | ) 59 | 60 | elpapa = gridify(images, NROWS, NCOLS) 61 | 62 | elpapa.save("colsamps/yikes.png") 63 | 64 | 65 | def main2(samples_path, out): 66 | 67 | fnames = sorted(glob(os.path.join(samples_path, "*.jpg"))) 68 | 69 | image_set = defaultdict(list) 70 | 71 | images = defaultdict(list) 72 | 73 | for fname in sorted(fnames): 74 | fname = os.path.basename(fname) 75 | g = fname.split("-")[1] 76 | image_set[g].append(fname) 77 | 78 | for epoch in sorted(image_set.keys()): 79 | fnames = sorted(image_set[epoch])[0:64] 80 | for fname in fnames: 81 | with PIL.Image.open(os.path.join(samples_path, fname)) as img: 82 | images[epoch].append(img.resize((64, 64))) 83 | 84 | for epoch in images.keys(): 85 | elpapa = gridify(images[epoch], 8, 8) 86 | print(f"saving epoch={epoch}") 87 | elpapa.save(f"colsamps/{epoch}.png") 88 | 89 | 90 | subprocess.run([ 91 | "ffmpeg", 92 | "-framerate", "30", 93 | "-pattern_type", "glob", 94 | "-i", "colsamps/*.png", 95 | "-c:v", "libx264", 96 | "-pix_fmt", "yuv420p", 97 | out, 98 | ]) 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | # Useful with: 107 | """ 108 | ffmpeg -framerate 30 \ 109 | -pattern_type glob \ 110 | -i '*.jpg' \ 111 | -c:v libx264 \ 112 | -pix_fmt yuv420p \ 113 | out.mp4 114 | """ 115 | 116 | 117 | def main(checkpoint_path, samples_path): 118 | states = torch.load( 119 | checkpoint_path, 120 | map_location=torch.device('cpu'), 121 | ) 122 | 123 | g = Generator198x198() 124 | g.load_state_dict(states["g"]) 125 | g.eval() 126 | 127 | batch_size = 64 128 | 129 | # noise = torch.rand(batch_size, 100, 1, 1) 130 | noise = torch.ones(batch_size, 100, 1, 1) 131 | noise = torch.rand(batch_size, 100, 1, 1) 132 | 133 | for i in range(batch_size): 134 | out = g(noise) 135 | img = nude2.data.MetCenterCroppedDataset.pilify(out[i]) 136 | img.save(os.path.join(samples_path, f"sample-{i}.jpg")) 137 | -------------------------------------------------------------------------------- /src/nude2/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | ndf = ngf = 64 5 | nz = 100 6 | nc = 3 7 | 8 | 9 | class Discriminator(nn.Module): 10 | """Module to discriminate between real and fake images""" 11 | 12 | def __init__(self): 13 | super(Discriminator, self).__init__() 14 | 15 | self.main = nn.Sequential( 16 | # input is ``(nc) x 64 x 64`` 17 | nn.Conv2d(nc, ndf, 4, 2, 1, bias=False), 18 | nn.LeakyReLU(0.2, inplace=True), 19 | # state size. ``(ndf) x 32 x 32`` 20 | nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False), 21 | nn.BatchNorm2d(ndf * 2), 22 | nn.LeakyReLU(0.2, inplace=True), 23 | # state size. 
``(ndf*2) x 16 x 16`` 24 | nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False), 25 | nn.BatchNorm2d(ndf * 4), 26 | nn.LeakyReLU(0.2, inplace=True), 27 | # state size. ``(ndf*4) x 8 x 8`` 28 | nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False), 29 | nn.BatchNorm2d(ndf * 8), 30 | nn.LeakyReLU(0.2, inplace=True), 31 | # state size. ``(ndf*8) x 4 x 4`` 32 | nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False), 33 | nn.Sigmoid(), 34 | ) 35 | 36 | def forward(self, input): 37 | return self.main(input) 38 | 39 | 40 | class Generator(nn.Module): 41 | """Module to generate an image from a feature vector""" 42 | 43 | def __init__(self): 44 | super(Generator, self).__init__() 45 | self.main = nn.Sequential( 46 | # input is Z, going into a convolution 47 | nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False), 48 | nn.BatchNorm2d(ngf * 8), 49 | nn.ReLU(True), 50 | # state size. ``(ngf*8) x 4 x 4`` 51 | nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False), 52 | nn.BatchNorm2d(ngf * 4), 53 | nn.ReLU(True), 54 | # state size. ``(ngf*4) x 8 x 8`` 55 | nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False), 56 | nn.BatchNorm2d(ngf * 2), 57 | nn.ReLU(True), 58 | # state size. ``(ngf*2) x 16 x 16`` 59 | nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False), 60 | nn.BatchNorm2d(ngf), 61 | nn.ReLU(True), 62 | # state size. ``(ngf) x 32 x 32`` 63 | nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False), 64 | nn.Tanh() 65 | # state size. ``(nc) x 64 x 64``, 66 | ) 67 | 68 | 69 | def forward(self, input): 70 | return self.main(input) 71 | 72 | 73 | class Generator198x198(nn.Module): 74 | """Module to generate an image from a feature vector""" 75 | 76 | def __init__(self): 77 | super(Generator198x198, self).__init__() 78 | 79 | self.main = nn.Sequential( 80 | nn.ConvTranspose2d( nz, ngf * 8, 5, 2, 0, bias=False), 81 | nn.BatchNorm2d(ngf * 8), 82 | nn.ReLU(True), 83 | 84 | nn.ConvTranspose2d( ngf * 8, ngf * 4, 6, 3, 1, bias=False), 85 | nn.BatchNorm2d(ngf * 4), 86 | nn.ReLU(True), 87 | 88 | nn.ConvTranspose2d( ngf * 4, ngf * 2, 6, 3, 1, bias=False), 89 | nn.BatchNorm2d(ngf * 2), 90 | nn.ReLU(True), 91 | 92 | nn.ConvTranspose2d( ngf * 2, ngf * 1, 5, 2, 1, bias=False), 93 | nn.BatchNorm2d(ngf * 1), 94 | nn.ReLU(True), 95 | 96 | nn.ConvTranspose2d( ngf * 1, nc, 4, 2, 1, bias=False), 97 | nn.Tanh(), 98 | ) 99 | 100 | 101 | def forward(self, input): 102 | return self.main(input) 103 | -------------------------------------------------------------------------------- /src/nude2/onnxify.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | import PIL 3 | import onnxruntime as ort 4 | import torch 5 | import torch.nn as nn 6 | import torchvision.transforms as transforms 7 | import numpy as np 8 | 9 | import onnx 10 | from onnx import compose 11 | 12 | from nude2.model import Generator198x198 13 | import nude2.data 14 | 15 | class ImageGenerator(Generator198x198): 16 | 17 | 18 | def __init__(self): 19 | super().__init__() 20 | s = 1/255. 
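        # The two Normalize calls below undo the dataset-side
        # Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)): the first multiplies by
        # 0.5 (std=1/0.5), taking the generator's tanh output from [-1, 1] to
        # [-0.5, 0.5]; the second adds 0.5 and divides by s = 1/255,
        # rescaling to [0, 255] so the result can be cast straight to uint8.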
21 | self.pilify = transforms.Compose([ 22 | transforms.Normalize(mean=[0., 0., 0.], std=[1/0.5, 1/0.5, 1/0.5]), 23 | transforms.Normalize(mean=[-0.5, -0.5, -0.5], std=[s, s, s]), 24 | ]) 25 | 26 | def forward(self, x): 27 | return self.pilify(super().forward(x)) 28 | 29 | 30 | class Transform(nn.Module): 31 | def __init__(self): 32 | super().__init__() 33 | s = 1.0 34 | self.normalize1 = transforms.Normalize(mean=[0., 0., 0.], std=[1/0.5, 1/0.5, 1/0.5]) 35 | self.normalize2 = transforms.Normalize(mean=[-0.5, -0.5, -0.5], std=[s, s, s]) 36 | 37 | def forward(self, x): 38 | x = self.normalize1(x) 39 | x = self.normalize2(x) 40 | # x = self.to_pil(x) 41 | return x 42 | 43 | 44 | 45 | def main(checkpoint, output): 46 | 47 | fname = "checkpoints/nude2-dcgan-met-random-crop-198x198.pt" 48 | fname = checkpoint 49 | 50 | states = torch.load(fname, map_location="cpu") 51 | 52 | with tempfile.NamedTemporaryFile("w") as f: 53 | t = Transform() 54 | x = torch.randn(1, 3, 1, 1) 55 | torch.onnx.export(t, x, f.name, opset_version=10, input_names=["raw"], output_names=["output"]) 56 | t_onnx = onnx.load(f.name) 57 | 58 | with tempfile.NamedTemporaryFile("w") as f: 59 | g = Generator198x198() 60 | g.load_state_dict(states["g"]) 61 | g.cpu() 62 | g.eval() 63 | g.train() 64 | x = torch.randn(1, 100, 1, 1) 65 | x = torch.ones(1, 100, 1, 1) 66 | torch.onnx.export( 67 | g, 68 | x, 69 | f.name, 70 | export_params=True, 71 | opset_version=10, 72 | do_constant_folding=True, 73 | input_names=["vector"], 74 | output_names=["unnormalized_output"], 75 | ) 76 | 77 | g_onnx = onnx.load(f.name) 78 | onnx.checker.check_model(g_onnx) 79 | 80 | # Session 81 | ort_sess = ort.InferenceSession(f.name) 82 | ort_outs = ort_sess.run(None, {"vector": x.numpy()})[0] 83 | g_outs = g(x).detach().numpy() 84 | 85 | # print(g_outs[0, 0, :10, :10]) 86 | 87 | if not np.allclose(ort_outs, g_outs, rtol=1e-6, atol=1e-6): 88 | print("WARNING: predicted matrices aren't close!") 89 | 90 | model = compose.merge_models(g_onnx, t_onnx, io_map=[("unnormalized_output", "raw")]) 91 | onnx.save_model(model, output) 92 | 93 | # XXX: GENERATE SAMPLE 94 | # session = ort.InferenceSession("checkpoints/nice.onnx") 95 | # res = session.run(None, {"vector": x.numpy()}) 96 | # res = res[0][0] 97 | # res = res.astype(np.uint8) 98 | # res = np.moveaxis(res, 0, -1) 99 | # print(res.shape) 100 | # img = PIL.Image.fromarray(res, mode="RGB") 101 | # img.save("cmon.png") 102 | # print(img) 103 | 104 | 105 | # XXX: GENERATE PYTORCH SAMPLE 106 | # batch_size = 64 107 | # x = torch.randn(batch_size, 100, 1, 1) 108 | # y = g(x) 109 | # y = t(y) 110 | # img = nude2.data.MetCenterCroppedDataset.pilify(y[0]) 111 | # img.save("yea.png") 112 | 113 | # print((ort_outs - g_outs).mean()) 114 | # print((ort_outs - g_outs).max()) 115 | -------------------------------------------------------------------------------- /src/nude2/static/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | lol 5 | 6 | 7 | 30 | 31 | 32 | 33 | 34 |
-------------------------------------------------------------------------------- /src/nude2/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | DCGAN stolen from the PyTorch DCGAN tutorial 3 | """ 4 | 5 | 6 | import json 7 | import logging 8 | import nude2.data 9 | import os 10 | import os.path 11 | import random 12 | import torch 13 | import torch.nn as nn 14 | import torch.optim 15 | import torch.utils.data as data 16 | import torchvision 17 | 18 | from datetime import datetime 19 | from nude2.progress import ProgressBar 20 | from nude2.utils import splash 21 | 22 | from nude2.model import Generator, Discriminator 23 | from nude2.train_constants import * 24 | 25 | import nude2.model_v1 as model_v1 26 | 27 | 28 | LOG_FORMAT = "\033[95m%(asctime)s\033[00m [%(levelname)s] %(message)s" 29 | logging.basicConfig(format=LOG_FORMAT) 30 | logger = logging.getLogger(__name__) 31 | logger.setLevel(logging.INFO) 32 | 33 | 34 | def weights_init(m): 35 | """Initialize Conv weights from N(0.0, 0.02) and BatchNorm weights from N(1.0, 0.02), per the DCGAN paper""" 36 | classname = m.__class__.__name__ 37 | if classname.find('Conv') != -1: 38 | nn.init.normal_(m.weight.data, 0.0, 0.02) 39 | elif classname.find('BatchNorm') != -1: 40 | nn.init.normal_(m.weight.data, 1.0, 0.02) 41 | nn.init.constant_(m.bias.data, 0) 42 | 43 | 44 | 45 | 46 | def main(data_folder, num_epochs, batch_size, checkpoint_path, samples_path, seed=None): 47 | 48 | if seed is not None: 49 | manualSeed = seed 50 | random.seed(manualSeed) 51 | torch.manual_seed(manualSeed) 52 | torch.use_deterministic_algorithms(True) 53 | 54 | data_dir = os.path.expanduser(data_folder) 55 | 56 | splash("splash") 57 | print("\n\n") 58 | print("NUDE 2.0") 59 | print("========") 60 | print(f"data .......... \033[95m{data_dir}\033[00m") 61 | print(f"epochs ........ \033[96m{num_epochs}\033[00m") 62 | print(f"batch size .... \033[95m{batch_size}\033[00m") 63 | print(f"device ........ \033[95m{device}\033[00m") 64 | print(f"checkpoint .... \033[95m{checkpoint_path}\033[00m") 65 | print(f"samples path .. 
\033[95m{samples_path}\033[00m") 66 | print() 67 | 68 | if samples_path is not None: 69 | os.makedirs(samples_path, exist_ok=True) 70 | with open(os.path.join(samples_path, "meta.json"), "w") as f: 71 | json.dump({ 72 | "data_folder": data_folder, 73 | "num_epochs": num_epochs, 74 | "batch_size": batch_size, 75 | "checkpoint_path": checkpoint_path, 76 | "samples_path": samples_path, 77 | "lr": lr, 78 | "beta1": beta1, 79 | }, f) 80 | 81 | # dataset = nude2.data.MetCenterCroppedDataset(data_dir) 82 | 83 | dataset = nude2.data.CachedDataset( 84 | data_dir, 85 | "~/.cache/nude2/images-random-crop-256x256", 86 | ) 87 | 88 | dataloader = data.DataLoader( 89 | dataset, 90 | batch_size=batch_size, 91 | shuffle=True, 92 | num_workers=8, 93 | ) 94 | 95 | g = Generator().to(device) 96 | g.apply(weights_init) 97 | 98 | d = Discriminator().to(device) 99 | d.apply(weights_init) 100 | 101 | total_params = sum( 102 | param.numel() 103 | for param in g.parameters() 104 | ) 105 | 106 | print("Total params =", total_params) 107 | 108 | criterion = nn.BCELoss() 109 | 110 | fixed_noise = torch.randn(64, nz, 1, 1, device=device) 111 | 112 | real_label = 1.0 113 | fake_label = 0.0 114 | 115 | optimizerD = torch.optim.Adam(d.parameters(), lr=lr, betas=(beta1, 0.999)) 116 | optimizerG = torch.optim.Adam(g.parameters(), lr=lr, betas=(beta1, 0.999)) 117 | 118 | try: 119 | states = torch.load(checkpoint_path) 120 | g.load_state_dict(states["g"]) 121 | d.load_state_dict(states["d"]) 122 | epoch = states["epoch"] + 1 123 | logger.info(f"Loaded at {epoch-1}") 124 | except FileNotFoundError: 125 | logger.warn("Could not find specified checkpoint... starting GAN at epoch=0") 126 | epoch = 0 127 | 128 | last_epoch = epoch + num_epochs 129 | 130 | for epoch in range(epoch, epoch + num_epochs): 131 | 132 | start = datetime.utcnow() 133 | 134 | with ProgressBar(len(dataloader), prefix=f"[epoch={epoch:04d}/{last_epoch}] ", size=40) as progress: 135 | for i, imgs in enumerate(dataloader): 136 | 137 | # Discriminate against real data 138 | d.zero_grad() 139 | real_cpu = imgs.to(device) 140 | b_size = real_cpu.size(0) 141 | label = torch.full( 142 | (b_size,), 143 | real_label, 144 | dtype=torch.float, 145 | device=device, 146 | ) 147 | output = d(real_cpu).view(-1) 148 | errD_real = criterion(output, label) 149 | errD_real.backward() 150 | D_x = output.mean().item() 151 | 152 | # Discriminate against fake data 153 | noise = torch.randn(b_size, nz, 1, 1, device=device) 154 | fake = g(noise) 155 | label.fill_(fake_label) 156 | 157 | output = d(fake.detach()).view(-1) 158 | errD_fake = criterion(output, label) 159 | 160 | errD_fake.backward() 161 | D_G_z1 = output.mean().item() 162 | errD = errD_real + errD_fake 163 | optimizerD.step() 164 | 165 | # Increment generator 166 | g.zero_grad() 167 | label.fill_(real_label) # fake labels are real for generator cost 168 | output = d(fake).view(-1) 169 | errG = criterion(output, label) 170 | errG.backward() 171 | D_G_z2 = output.mean().item() 172 | optimizerG.step() 173 | 174 | # inc 175 | progress.inc() 176 | 177 | sample = g(fixed_noise).detach().cpu() 178 | 179 | if samples_path is not None: 180 | for i in range(sample.size(0)): 181 | img = nude2.data.MetCenterCroppedDataset.pilify(sample[i]) 182 | fname = f"sample-{epoch:04d}-{i:02d}.jpg" 183 | img.save(os.path.join(samples_path, fname)) 184 | 185 | dur = datetime.utcnow() - start 186 | 187 | res = torch.save({ 188 | "g": g.state_dict(), 189 | "d": d.state_dict(), 190 | "epoch": epoch, 191 | }, checkpoint_path) 192 | 
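# Sanity check for the ``state size`` comments in model.py / model_v1.py; a
# sketch added for illustration, not part of the original script. Each
# nn.ConvTranspose2d grows the spatial size as
#     out = (in - 1) * stride - 2 * padding + kernel
# so the generator expands its 1x1 latent 1 -> 4 -> 8 -> 16 -> 32 -> 64,
# matching image_size in train_constants.py.
def _deconv_out(size, kernel, stride, padding):
    """Output spatial size of nn.ConvTranspose2d (dilation=1, output_padding=0)."""
    return (size - 1) * stride - 2 * padding + kernel


assert _deconv_out(1, 4, 1, 0) == 4
assert _deconv_out(4, 4, 2, 1) == 8  # and 8 -> 16 -> 32 -> 64 likewise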
-------------------------------------------------------------------------------- /src/nude2/data.py: -------------------------------------------------------------------------------- 1 | import PIL.Image 2 | import csv 3 | import ctypes 4 | import hashlib 5 | import json 6 | import logging 7 | import multiprocessing as mp 8 | import multiprocessing.dummy 9 | import numpy as np 10 | import os.path 11 | import requests 12 | import shutil 13 | import sqlite3 14 | import tempfile 15 | import torch 16 | import torchvision 17 | import urllib.request 18 | 19 | from collections import OrderedDict 20 | from datetime import datetime 21 | from glob import glob 22 | from torch.utils.data import Dataset, DataLoader 23 | 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | 28 | CACHE_DIR = os.path.expanduser(os.path.join("~", ".cache", "nude2", "data")) 29 | 30 | 31 | IMAGES_CSV = os.path.join(CACHE_DIR, "met-images.csv") 32 | 33 | 34 | MET_CSV = os.path.join(CACHE_DIR, "MetObjects.csv") 35 | 36 | 37 | DB = os.path.join(CACHE_DIR, "met.db") 38 | 39 | 40 | MET_COLS = OrderedDict([ 41 | ("Object Number", "VARCHAR(256)"), 42 | ("Is Highlight", "BOOL"), 43 | ("Is Timeline Work", "BOOL"), 44 | ("Is Public Domain", "BOOL"), 45 | ("Object ID", "NUMBER"), 46 | ("Gallery Number", "NUMBER"), 47 | ("Department", "VARCHAR(255)"), 48 | ("AccessionYear", "VARCHAR(255)"), 49 | ("Object Name", "VARCHAR(255)"), 50 | ("Title", "VARCHAR(255)"), 51 | ("Culture", "VARCHAR(255)"), 52 | ("Period", "VARCHAR(255)"), 53 | ("Dynasty", "VARCHAR(255)"), 54 | ("Reign", "VARCHAR(255)"), 55 | ("Portfolio", "VARCHAR(255)"), 56 | ("Constituent ID", "VARCHAR(255)"), 57 | ("Artist Role", "VARCHAR(255)"), 58 | ("Artist Prefix", "VARCHAR(255)"), 59 | ("Artist Display Name", "VARCHAR(255)"), 60 | ("Artist Display Bio", "VARCHAR(255)"), 61 | ("Artist Suffix", "VARCHAR(255)"), 62 | ("Artist Alpha Sort", "VARCHAR(255)"), 63 | ("Artist Nationality", "VARCHAR(255)"), 64 | ("Artist Begin Date", "VARCHAR(255)"), 65 | ("Artist End Date", "VARCHAR(255)"), 66 | ("Artist Gender", "VARCHAR(255)"), 67 | ("Artist ULAN URL", "VARCHAR(255)"), 68 | ("Artist Wikidata URL", "VARCHAR(255)"), 69 | ("Object Date", "VARCHAR(255)"), 70 | ("Object Begin Date", "VARCHAR(255)"), 71 | ("Object End Date", "VARCHAR(255)"), 72 | ("Medium", "VARCHAR(255)"), 73 | ("Dimensions", "VARCHAR(255)"), 74 | ("Credit Line", "VARCHAR(255)"), 75 | ("Geography Type", "VARCHAR(255)"), 76 | ("City", "VARCHAR(255)"), 77 | ("State", "VARCHAR(255)"), 78 | ("County", "VARCHAR(255)"), 79 | ("Country", "VARCHAR(255)"), 80 | ("Region", "VARCHAR(255)"), 81 | ("Subregion", "VARCHAR(255)"), 82 | ("Locale", "VARCHAR(255)"), 83 | ("Locus", "VARCHAR(255)"), 84 | ("Excavation", "VARCHAR(255)"), 85 | ("River", "VARCHAR(255)"), 86 | ("Classification", "VARCHAR(255)"), 87 | ("Rights and Reproduction", "VARCHAR(255)"), 88 | ("Link Resource", "VARCHAR(255)"), 89 | ("Object Wikidata URL", "VARCHAR(255)"), 90 | ("Metadata Date", "VARCHAR(255)"), 91 | ("Repository", "VARCHAR(255)"), 92 | ("Tags", "VARCHAR(255)"), 93 | ("Tags AAT URL", "VARCHAR(255)"), 94 | ("Tags Wikidata URL", "VARCHAR(255)"), 95 | ]) 96 | 97 | 98 | CREATE_MET_TABLE = f""" 99 | CREATE TABLE IF NOT EXISTS met ({", ".join( 100 | " `{k}` {v}".format(k=k, v=v) 101 | for k, v in MET_COLS.items() 102 | )}, 103 | CONSTRAINT pk_met PRIMARY KEY (`Object ID`) 104 | )""" 105 | 106 | 107 | 108 | CREATE_MET_IMAGES_TABLE = """ 109 | CREATE TABLE IF NOT EXISTS met_images ( 110 | `Object ID` NUMBER, 111 | `Image URL` VARCHAR(255), 112 | CONSTRAINT 
pk_met_images PRIMARY KEY(`Object ID`, `Image URL`) 113 | ) 114 | """ 115 | 116 | 117 | CREATE_MET_TAG_TABLE = """ 118 | CREATE TABLE IF NOT EXISTS met_tags ( 119 | `Object ID` NUMBER, 120 | `Tag` VARCHAR(255), 121 | CONSTRAINT pk_met_tag PRIMARY KEY (`Object ID`, `Tag`) 122 | ) 123 | """ 124 | 125 | SELECT_MATCHING_TAGS_COLS = [ 126 | "Object ID", 127 | "Title", 128 | "Object Number", 129 | "Object Name", 130 | "Is Highlight", 131 | "Is Timeline Work", 132 | "Is Public Domain", 133 | "Image URL", 134 | "Medium", 135 | "Tags", 136 | "Tag", 137 | "Image URL", 138 | ] 139 | 140 | 141 | SELECT_MATCHING_TAGS = f""" 142 | CREATE VIEW IF NOT EXISTS tagged_images AS 143 | SELECT 144 | {", ".join(f"`{c}`" for c in SELECT_MATCHING_TAGS_COLS)} 145 | FROM met 146 | INNER JOIN met_tags USING (`Object ID`) 147 | INNER JOIN met_images USING (`Object ID`) 148 | """ 149 | 150 | 151 | SELECT_TAGGED_MET_IMAGES_SQL = """ 152 | WITH matched_object_ids AS ( 153 | SELECT DISTINCT 154 | `Object ID` 155 | FROM met_tags 156 | WHERE `Tag` IN ? 157 | ) 158 | SELECT 159 | * 160 | FROM met_images 161 | INNER JOIN matched_object_ids USING (`Object ID`) 162 | """ 163 | 164 | def create_index_sql(table_name, cols): 165 | sanitized_cols = [c.lower().replace(" ", "") for c in cols] 166 | cols = [f"`{c}`" for c in cols] 167 | return f""" 168 | CREATE INDEX IF NOT EXISTS `idx__{table_name}__{'_'.join(sanitized_cols)}` 169 | ON {table_name} ({", ".join(cols)}) 170 | """ 171 | 172 | 173 | class MetData(object): 174 | """Metropolitan Museum of Art Data""" 175 | 176 | def __init__(self, loc=DB): 177 | self.loc = loc 178 | self.conn = sqlite3.connect(self.loc) 179 | 180 | if not self.is_bootstrapped(): 181 | logger.info("Bootstrapping MET data") 182 | self.bootstrap() 183 | 184 | def is_bootstrapped(self): 185 | """Return whether database has been bootstrapped""" 186 | with self.conn as conn: 187 | curs = conn.cursor() 188 | curs.execute("SELECT name FROM sqlite_master WHERE type='table'") 189 | tables = set(row[0] for row in curs.fetchall()) 190 | return tables == {"met", "met_images", "met_tags"} 191 | 192 | def bootstrap(self): 193 | with self.conn as conn: 194 | curs = conn.cursor() 195 | 196 | curs.execute(CREATE_MET_TABLE) 197 | curs.execute(create_index_sql("met", ["Object ID"])) 198 | 199 | with open(MET_CSV, "r") as fi: 200 | reader = csv.DictReader(fi) 201 | curs.executemany( 202 | f"INSERT INTO met VALUES ({', '.join('?' * len(MET_COLS))})", 203 | (tuple(row.values()) for row in reader), 204 | ) 205 | 206 | curs.execute(CREATE_MET_IMAGES_TABLE) 207 | curs.execute(create_index_sql("met_images", ["Object ID"])) 208 | 209 | # XXX: executemany receives a `set` because there are dupes in the csv. 
210 | # Just fix the csv here first 211 | with open(IMAGES_CSV, "r") as fi: 212 | reader = csv.DictReader(fi) 213 | curs.executemany( 214 | "INSERT INTO met_images VALUES (?, ?)", 215 | {tuple(row.values()) for row in reader}, 216 | ) 217 | 218 | curs.execute(CREATE_MET_TAG_TABLE) 219 | curs.execute(create_index_sql("met_tags", ["Object ID"])) 220 | curs.execute(create_index_sql("met_tags", ["Object ID", "Tag"])) 221 | 222 | curs.execute("SELECT `Object ID`, `Tags` FROM met WHERE `Tags` IS NOT NULL AND `Tags` != ''") 223 | 224 | rows = curs.fetchall() 225 | 226 | curs.executemany( 227 | "INSERT INTO met_tags VALUES (?, ?)", 228 | ((row[0], tag) for row in rows for tag in row[1].split("|")), 229 | ) 230 | 231 | curs.execute(SELECT_MATCHING_TAGS) 232 | 233 | with self.conn as conn: 234 | curs = conn.cursor() 235 | curs.execute("VACUUM") 236 | 237 | logger.info("Finished bootstrapping"); 238 | 239 | def __len__(self): 240 | with self.conn: 241 | curs = self.conn.cursor() 242 | curs.execute("SELECT COUNT(*) FROM met_images") 243 | return curs.fetchone()[0] 244 | 245 | def __getitem__(self, idx): 246 | with self.conn: 247 | curs = self.conn.cursor() 248 | curs.execute( 249 | "SELECT * FROM met_images LIMIT 1 OFFSET ?", 250 | (idx, ), 251 | ) 252 | return curs.fetchone() 253 | 254 | def select(self, sql, params=tuple()): 255 | assert sql.upper().startswith("SELECT") 256 | with self.conn: 257 | curs = self.conn.cursor() 258 | curs.execute(sql, params) 259 | return curs.fetchall() 260 | 261 | def fetch_tag(self, tag, medium): 262 | """Return artworks with a given tag""" 263 | 264 | tag = tag or "" 265 | 266 | medium = medium or "" 267 | 268 | Q = """ 269 | WITH filtered_images AS ( 270 | SELECT * 271 | FROM tagged_images 272 | WHERE 273 | COALESCE(?, '') = '' 274 | OR 275 | `Tag` = ? 276 | ) 277 | SELECT * 278 | FROM filtered_images 279 | WHERE 280 | `Medium` LIKE '%' || ? 
|| '%' 281 | """ 282 | with self.conn: 283 | curs = self.conn.cursor() 284 | curs.execute(Q, (tag, tag, medium, )) 285 | return [ 286 | dict(zip(SELECT_MATCHING_TAGS_COLS, row)) 287 | for row in curs.fetchall() 288 | ] 289 | 290 | 291 | def retrieve_csv_data(): 292 | """Store CSV's locally""" 293 | 294 | if not os.path.exists(CACHE_DIR): 295 | logger.info("Creating cache directory: %s", CACHE_DIR) 296 | os.makedirs(CACHE_DIR) 297 | 298 | if not os.path.exists(MET_CSV): 299 | logger.info("Downloading MET CSV: %s", MET_CSV) 300 | MET_CSV_URL = "https://github.com/metmuseum/openaccess/raw/master/MetObjects.csv" 301 | urllib.request.urlretrieve(MET_CSV_URL, MET_CSV) 302 | 303 | if not os.path.exists(IMAGES_CSV): 304 | logger.info("Using MET Image CSV: %s", IMAGES_CSV) 305 | MET_IMAGES_CSV_URL = "https://github.com/gregsadetsky/open-access-is-great-but-where-are-the-images/raw/main/1.data/met-images.csv" 306 | urllib.request.urlretrieve(MET_IMAGES_CSV_URL, IMAGES_CSV) 307 | 308 | 309 | def get_image_file_name(image_url): 310 | """...""" 311 | _, suf = os.path.splitext(image_url) 312 | sha = hashlib.sha256() 313 | sha.update(image_url.encode()) 314 | return f"met-image-{sha.hexdigest().lower()}{suf.lower()}" 315 | 316 | 317 | class MetImageGetter(object): 318 | """Loading Cache for Met Images""" 319 | 320 | def __init__(self, cache_dir="~/.cache/nude2/images/", download=True): 321 | """Construct...""" 322 | self.cache_dir = os.path.expanduser(cache_dir) 323 | self.temp_dir = tempfile.mkdtemp(prefix="met-images-") 324 | self.download = download 325 | print(self.temp_dir) 326 | 327 | def __del__(self): 328 | shutil.rmtree(self.temp_dir) 329 | 330 | def fetch(self, image_url): 331 | """Return a PIL image """ 332 | 333 | image_url = image_url.replace(" ", "%20") 334 | 335 | if not os.path.exists(self.cache_dir): 336 | logger.info("Creating image cache directory: %s", self.cache_dir) 337 | os.makedirs(self.cache_dir) 338 | 339 | _, suf = os.path.splitext(image_url) 340 | 341 | sha = hashlib.sha256() 342 | sha.update(image_url.encode()) 343 | 344 | fname = f"met-image-{sha.hexdigest().lower()}{suf.lower()}" 345 | ftemp = os.path.join(self.temp_dir, fname) 346 | fpath = os.path.join(self.cache_dir, fname) 347 | 348 | if not os.path.exists(fpath): 349 | if not self.download: 350 | return None 351 | logger.debug("Downloading image: %s", image_url) 352 | try: 353 | print("Retrieving!") 354 | urllib.request.urlretrieve(image_url, ftemp) 355 | shutil.move(ftemp, fpath) 356 | except: 357 | logger.error("Failed to download image: %s", image_url) 358 | return None 359 | 360 | return fpath 361 | 362 | 363 | class MetDataset(Dataset): 364 | """Configurable dataset""" 365 | 366 | crop = torchvision.transforms.Compose([ 367 | torchvision.transforms.Resize((256, 256)), 368 | torchvision.transforms.CenterCrop(224), 369 | ]) 370 | 371 | tensorify = torchvision.transforms.Compose([ 372 | torchvision.transforms.ToTensor(), 373 | ]) 374 | 375 | SELECT_QUERY = """ 376 | SELECT 377 | `Object ID`, 378 | `Object Name`, 379 | `Title`, 380 | `Tags`, 381 | `Image URL` 382 | FROM met_images 383 | INNER JOIN met 384 | USING (`Object ID`) 385 | WHERE `Object Name` = 'Painting' 386 | """.strip() 387 | 388 | COUNT_QUERY = """ 389 | SELECT 390 | COUNT(*) 391 | FROM met_images 392 | INNER JOIN met 393 | USING (`Object ID`) 394 | WHERE `Object Name` = 'Painting' 395 | """.strip() 396 | 397 | def __init__(self, tags, cache_dir="~/.cache/nude2/", tag_db="", force_refresh=False): 398 | self.cache_dir = os.path.expanduser(cache_dir) 399 | 
self.base_images_dir = os.path.join(self.cache_dir, "images") 400 | self.augmented_images_dir = os.path.join(self.cache_dir, "images-augmented") 401 | 402 | self.db = MetData(DB) 403 | self.fetcher = MetImageGetter(self.base_images_dir, download=False) 404 | 405 | base_rows = self.db.select(self.SELECT_QUERY) 406 | self.rows = list(self.augmentations(base_rows)) 407 | 408 | self.five_crop = torchvision.transforms.Compose([ 409 | torchvision.transforms.Resize(336), 410 | torchvision.transforms.FiveCrop(224), 411 | ]) 412 | 413 | self.tensorify = torchvision.transforms.Compose([ 414 | torchvision.transforms.ToTensor(), 415 | torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 416 | ]) 417 | 418 | def fetch_crops(self, row): 419 | image_path = self.fetcher.fetch(row[-1]) 420 | prefix, suffix = os.path.splitext(os.path.basename(image_path)) 421 | crop_fnames = glob(os.path.join(self.augmented_images_dir, f"{prefix}-[0-4]-0.jpg")) 422 | 423 | crop_fnames = [ 424 | os.path.join(self.augmented_images_dir, f"{prefix}-{i}-0.jpg") 425 | for i in range(5) 426 | ] 427 | 428 | crops = [] 429 | 430 | if all(os.path.exists(fname) for fname in crop_fnames): 431 | for fname in crop_fnames: 432 | with PIL.Image.open(fname) as i: 433 | crops.append(i.copy()) 434 | 435 | else: 436 | logger.warning(f"Generating new images for {prefix}") 437 | try: 438 | with PIL.Image.open(image_path) as i: 439 | img = i.convert("RGB").copy() 440 | crops = self.five_crop(img) 441 | for i, c in enumerate(crops): 442 | fpath = os.path.join(self.augmented_images_dir, f"{prefix}-{i}-0.jpg") 443 | c.save(fpath) 444 | except PIL.Image.DecompressionBombError: 445 | logger.error("Decompression bomb error on image: %s", image_path) 446 | 447 | return crops 448 | 449 | def augmentations(self, rows): 450 | 451 | if not os.path.exists(self.augmented_images_dir): 452 | os.makedirs(self.augmented_images_dir) 453 | 454 | for row in rows: 455 | 456 | image_path = self.fetcher.fetch(row[-1]) 457 | 458 | if image_path is None: 459 | continue 460 | 461 | crops = self.fetch_crops(row) 462 | 463 | object_id = row[0] 464 | original_path = image_path 465 | 466 | fname, suffix = os.path.splitext(os.path.basename(image_path)) 467 | 468 | for i, c in enumerate(crops): 469 | aug_fname = f"{fname}-{i}-0{suffix}" 470 | aug_path = os.path.join(self.augmented_images_dir, aug_fname) 471 | 472 | if not os.path.exists(aug_path): 473 | print("SAVE!") 474 | c.save(aug_path) 475 | 476 | res = { 477 | "fname": aug_fname, 478 | "path": aug_path, 479 | "nude": "Nude" in row[-2], 480 | "base_image_path": row[-1], 481 | } 482 | 483 | yield res 484 | 485 | def save_augmentation_labels(self): 486 | augmented_labels_dir = os.path.join(self.cache_dir, "labels-augmented") 487 | 488 | if not os.path.exists(augmented_labels_dir): 489 | os.makedirs(augmented_labels_dir) 490 | 491 | for row in self.rows: 492 | image_prefix, _ = os.path.splitext(row["fname"]) 493 | label_path = os.path.join(augmented_labels_dir, f"{image_prefix}.json") 494 | row["label_path"] = label_path 495 | 496 | with open(label_path, "w") as fo: 497 | json.dump(row, fo) 498 | 499 | 500 | def __len__(self): 501 | return len(self.rows) 502 | 503 | def __getitem__(self, idx): 504 | with PIL.Image.open(self.rows[idx]["path"]) as i: 505 | row = self.rows[idx] 506 | row["image"] = self.tensorify(np.array(i)) 507 | return row 508 | 509 | 510 | class MetCenterCroppedDataset(Dataset): 511 | 512 | image_size = 64 513 | 514 | tensorify = torchvision.transforms.Compose([ 515 | 
torchvision.transforms.Resize(image_size), 516 | torchvision.transforms.CenterCrop(image_size), 517 | torchvision.transforms.ToTensor(), 518 | torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 519 | ]) 520 | 521 | pilify = torchvision.transforms.Compose([ 522 | torchvision.transforms.Normalize(mean=[0., 0., 0.], std=[1/0.5, 1/0.5, 1/0.5]), 523 | torchvision.transforms.Normalize(mean=[-0.5, -0.5, -0.5], std=[1., 1., 1.]), 524 | torchvision.transforms.ToPILImage(), 525 | ]) 526 | 527 | def __init__(self, cache_dir="~/.cache/nude2/"): 528 | self.cache_dir = os.path.expanduser(cache_dir) 529 | self.image_dir = self.cache_dir 530 | self.fnames = glob(os.path.join(self.image_dir, "*.jpg")) 531 | self.cache = {} 532 | 533 | def __len__(self): 534 | return len(self.fnames) 535 | 536 | def __getitem__(self, idx): 537 | if idx not in self.cache: 538 | try: 539 | with PIL.Image.open(self.fnames[idx]) as i: 540 | self.cache[idx] = self.tensorify(i.convert("RGB")) 541 | except PIL.Image.DecompressionBombError: 542 | print("fname =", self.fnames[idx]) 543 | return None 544 | except PIL.Image.DecompressionBombWarning: 545 | print("!!") 546 | return self.cache[idx] 547 | 548 | 549 | class MetFiveCornerDataset(Dataset): 550 | """Five corners + flip""" 551 | 552 | uncropped_size = 256 553 | 554 | cropped_size = 128 555 | 556 | tencrop = torchvision.transforms.Compose([ 557 | torchvision.transforms.Resize(uncropped_size), 558 | torchvision.transforms.TenCrop(cropped_size), 559 | ]) 560 | 561 | tensorify = torchvision.transforms.Compose([ 562 | torchvision.transforms.ToTensor(), 563 | torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 564 | ]) 565 | 566 | pilify = torchvision.transforms.Compose([ 567 | torchvision.transforms.Normalize(mean=[0., 0., 0.], std=[1/0.5, 1/0.5, 1/0.5]), 568 | torchvision.transforms.Normalize(mean=[-0.5, -0.5, -0.5], std=[1., 1., 1.]), 569 | torchvision.transforms.ToPILImage(), 570 | ]) 571 | 572 | def __init__(self, cache_dir="~/.cache/nude2/"): 573 | """...""" 574 | 575 | with sqlite3.connect(DB) as conn: 576 | curs = conn.cursor() 577 | curs.execute(""" 578 | SELECT 579 | `Image URL` 580 | FROM met_images 581 | INNER JOIN met_tags USING (`Object ID`) 582 | WHERE `Tag` LIKE '%Nude%' 583 | """.strip()) 584 | self.nude_file_names = { 585 | get_image_file_name(row[0]) 586 | for row in curs.fetchall() 587 | } 588 | 589 | 590 | self.cache_dir = os.path.expanduser(cache_dir) 591 | self.image_dir = self.cache_dir 592 | self.fnames = glob(os.path.join(self.image_dir, "*.jpg")) 593 | self.fnames = self.fnames 594 | # self.fnames = self.fnames[0:100] 595 | self.hits = mp.Array(ctypes.c_bool, len(self.fnames)) 596 | 597 | 598 | n = 10*len(self.fnames) 599 | c = 3 600 | w = h = self.cropped_size 601 | 602 | shared_array_base = mp.Array(ctypes.c_float, n * c * w * h) 603 | shared_array = np.ctypeslib.as_array(shared_array_base.get_obj()) 604 | shared_array = shared_array.reshape(n, c, h, w) 605 | self.cache = torch.from_numpy(shared_array) 606 | 607 | def __len__(self): 608 | """Return the length of the dataset""" 609 | return 10 * len(self.fnames) 610 | 611 | def __getitem__(self, idx): 612 | """Return the image at the given point""" 613 | i = idx // 10 614 | 615 | hit = self.hits[i] 616 | 617 | fname = self.fnames[i] 618 | fname = os.path.basename(fname) 619 | 620 | is_nude = fname in self.nude_file_names 621 | 622 | if self.hits[i] == False: 623 | PIL.Image.MAX_IMAGE_PIXELS = 343934400 + 1 624 | PIL.Image.MAX_IMAGE_PIXELS = 757164160 + 1 625 | with 
PIL.Image.open(self.fnames[i]) as raw_img: 626 | for j, cropped_image in enumerate(self.tencrop(raw_img.convert("RGB"))): 627 | assert 0 <= j < 10 628 | 629 | f, s = os.path.splitext(os.path.basename(self.fnames[i])) 630 | 631 | foutname = os.path.join( 632 | os.path.expanduser("~/.cache/nude2/images-tencrop"), 633 | f"{f}-{i}-{j}{s}", 634 | ) 635 | 636 | print(foutname) 637 | 638 | cropped_image.save(foutname) 639 | self.cache[10*i+j] = self.tensorify(cropped_image) 640 | self.hits[i] = True 641 | 642 | return is_nude, self.cache[idx] 643 | 644 | 645 | class CachedDataset(Dataset): 646 | 647 | cropped_size = 64 648 | 649 | uncropped_size = cropped_size // 2 * 3 650 | 651 | tensorify = torchvision.transforms.Compose([ 652 | torchvision.transforms.ToTensor(), 653 | torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 654 | ]) 655 | 656 | pilify = torchvision.transforms.Compose([ 657 | torchvision.transforms.Normalize(mean=[0., 0., 0.], std=[1/0.5, 1/0.5, 1/0.5]), 658 | torchvision.transforms.Normalize(mean=[-0.5, -0.5, -0.5], std=[1., 1., 1.]), 659 | torchvision.transforms.ToPILImage(), 660 | ]) 661 | 662 | def __init__(self, image_dir, cache_dir): 663 | 664 | self.image_dir = os.path.expanduser(image_dir) 665 | 666 | self.cache_dir = os.path.expanduser(cache_dir) 667 | 668 | self.image_fnames = [ 669 | fname 670 | for suffix in ["jpg"] 671 | for fname in glob(os.path.join(self.image_dir, f"*.{suffix}")) 672 | ] 673 | 674 | self.transforms = [ 675 | 676 | # Standard center-crop 677 | torchvision.transforms.Compose([ 678 | torchvision.transforms.Resize(self.cropped_size), 679 | torchvision.transforms.CenterCrop(self.cropped_size), 680 | ]), 681 | 682 | # Horizontal flip 683 | torchvision.transforms.Compose([ 684 | torchvision.transforms.Resize(self.cropped_size), 685 | torchvision.transforms.CenterCrop(self.cropped_size), 686 | torchvision.transforms.RandomHorizontalFlip(1.0), 687 | ]), 688 | ] 689 | 690 | self.transforms += [ 691 | # Random crop and jitter 692 | torchvision.transforms.Compose([ 693 | torchvision.transforms.Resize(self.uncropped_size), 694 | torchvision.transforms.RandomCrop(self.cropped_size), 695 | torchvision.transforms.RandomHorizontalFlip(0.5), 696 | ]), 697 | ] * 3 698 | 699 | self.transforms += [ 700 | # Random crop and jitter 701 | torchvision.transforms.Compose([ 702 | torchvision.transforms.Resize(self.uncropped_size), 703 | torchvision.transforms.RandomCrop(self.cropped_size), 704 | torchvision.transforms.ColorJitter(brightness=(0.5,1.5), contrast=(1), saturation=(0.5,1.5), hue=(-0.1,0.1)), 705 | torchvision.transforms.RandomAdjustSharpness(1.5), 706 | torchvision.transforms.RandomHorizontalFlip(0.5), 707 | ]), 708 | ] * 2 709 | 710 | 711 | os.makedirs(self.cache_dir, exist_ok=True) 712 | 713 | 714 | def get_filename(self, idx): 715 | """Return the filename""" 716 | i = idx // len(self.transforms) 717 | return self.image_fnames[i] 718 | 719 | 720 | def get_source_image(self, idx): 721 | """Return the source image for an index""" 722 | with PIL.Image.open(self.get_filename(idx)) as i: 723 | return i 724 | 725 | def __len__(self): 726 | """Return the total number of augmented images""" 727 | return len(self.transforms) * len(self.image_fnames) 728 | 729 | def __getitem__(self, idx): 730 | """Return an image as a tensor from the dataset""" 731 | 732 | i = idx // len(self.transforms) 733 | r = idx % len(self.transforms) 734 | 735 | cache_fname = f"image-{i}-{r}.jpg" 736 | cache_fpath = os.path.join(self.cache_dir, cache_fname) 737 | 738 | if not 
os.path.exists(cache_fpath): 739 | fname = self.get_filename(idx) 740 | 741 | PIL.Image.MAX_IMAGE_PIXELS = 757164160 + 1 742 | 743 | with PIL.Image.open(fname) as img: 744 | for j, t in enumerate(self.transforms): 745 | out = t(img.convert("RGB")) 746 | 747 | with tempfile.NamedTemporaryFile(prefix="image-", suffix=".jpg", delete=False) as fo: 748 | out.save(fo.name) 749 | 750 | dest_fname = f"image-{i}-{j}.jpg" 751 | dest_fpath = os.path.join(self.cache_dir, dest_fname) 752 | 753 | shutil.move( 754 | fo.name, 755 | dest_fpath, 756 | ) 757 | 758 | with PIL.Image.open(cache_fpath) as img: 759 | return self.tensorify(img) 760 | 761 | 762 | 763 | def main(concurrency, limit): 764 | 765 | dataset = CachedDataset("~/images/", "~/images-random-crops") 766 | dataloader = DataLoader(dataset, batch_size=100, num_workers=8, shuffle=True) 767 | 768 | for row in dataloader: 769 | pass 770 | 771 | return 772 | 773 | dataset = MetFiveCornerDataset(cache_dir="~/.cache/nude2/images/") 774 | 775 | dataloader = DataLoader(dataset, batch_size=100, num_workers=0, shuffle=False) 776 | 777 | for epoch in range(10): 778 | print(f"epoch = {epoch}") 779 | 780 | start = datetime.now() 781 | 782 | for hit, x in dataloader: 783 | pass 784 | 785 | dur = (datetime.now() - start).total_seconds() 786 | 787 | print("Duration:", dur) 788 | --------------------------------------------------------------------------------
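For completeness, the commented-out ``XXX: GENERATE SAMPLE`` block in onnxify.py can be rounded out into a small driver for the exported model. A sketch under stated assumptions: the path ``checkpoints/nice.onnx`` is the one from that comment, the input/output names match onnxify.main's io_map, and the final rescale assumes the merged s = 1.0 Transform leaves values in roughly [0, 1].

import numpy as np
import onnxruntime as ort
import PIL.Image

# Input "vector" and output "output" are the names wired up in onnxify.main.
session = ort.InferenceSession("checkpoints/nice.onnx")
z = np.random.randn(1, 100, 1, 1).astype(np.float32)  # nz = 100 latent vector
out = session.run(None, {"vector": z})[0][0]           # (3, 198, 198) floats
arr = (np.moveaxis(out, 0, -1).clip(0.0, 1.0) * 255).astype(np.uint8)
PIL.Image.fromarray(arr, mode="RGB").save("sample.png")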