├── src
│   ├── __init__.py
│   ├── image_utils.py
│   ├── features.py
│   ├── gans.py
│   ├── pipeline.py
│   └── search.py
├── data
│   └── download_data.sh
├── pip_requirements.txt
├── .gitignore
├── README.md
├── conda_requirements.txt
├── networks
│   ├── pix2pix.py
│   ├── cyclegan.py
│   └── stargan.py
└── processing
    └── feature_vectors_download.ipynb
/src/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data/download_data.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
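# Run this script from the data/ directory; it fetches the image, feature,
# model and clustering archives and unpacks them in place (see the README Setup section).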
2 | echo "Downloading images..."
3 | curl -Lo images.zip https://s3.eu-central-1.amazonaws.com/fashion-gan/images.zip
4 | echo "Downloading features..."
5 | curl -Lo features.zip https://s3.eu-central-1.amazonaws.com/fashion-gan/features.zip
6 | echo "Downloading models..."
7 | curl -Lo models.zip https://s3.eu-central-1.amazonaws.com/fashion-gan/models.zip
8 | echo "Downloading clustering data..."
9 | curl -Lo clustering.zip https://s3.eu-central-1.amazonaws.com/fashion-gan/clustering.zip
10 |
11 | echo "Unzipping files..."
12 | unzip -q images.zip
13 | unzip -q features.zip
14 | unzip -q models.zip
15 | unzip -q clustering.zip
16 |
17 | rm *.zip
18 |
--------------------------------------------------------------------------------
/pip_requirements.txt:
--------------------------------------------------------------------------------
1 | anaconda-client==1.7.2
2 | appdirs==1.4.3
3 | appnope==0.1.0
4 | asn1crypto==0.24.0
5 | attrs==18.2.0
6 | Automat==0.7.0
7 | backcall==0.1.0
8 | bleach==2.1.4
9 | certifi==2018.8.24
10 | cffi==1.11.5
11 | chardet==3.0.4
12 | clyent==1.2.2
13 | constantly==15.1.0
14 | cryptography==2.3.1
15 | cycler==0.10.0
16 | Cython==0.28.5
17 | decorator==4.3.0
18 | entrypoints==0.2.3
19 | html5lib==1.0.1
20 | hyperlink==18.0.0
21 | idna==2.7
22 | incremental==17.5.0
23 | ipykernel==4.9.0
24 | ipython==6.5.0
25 | ipython-genutils==0.2.0
26 | ipywidgets==7.4.1
27 | jedi==0.12.1
28 | Jinja2==2.10
29 | jsonschema==2.6.0
30 | jupyter-client==5.2.3
31 | jupyter-core==4.4.0
32 | kiwisolver==1.0.1
33 | MarkupSafe==1.0
34 | matplotlib==2.2.3
35 | mistune==0.8.3
36 | mkl-fft==1.0.4
37 | mkl-random==1.0.1
38 | nb-anacondacloud==1.4.0
39 | nb-conda==2.2.1
40 | nb-conda-kernels==2.1.0
41 | nbconvert==5.3.1
42 | nbformat==4.4.0
43 | nbpresent==3.0.2
44 | notebook==5.6.0
45 | numpy==1.15.1
46 | olefile==0.46
47 | pandas==0.23.4
48 | pandocfilters==1.4.2
49 | parso==0.3.1
50 | pexpect==4.6.0
51 | pickleshare==0.7.4
52 | Pillow==5.2.0
53 | prometheus-client==0.3.1
54 | prompt-toolkit==1.0.15
55 | ptyprocess==0.6.0
56 | pyasn1==0.4.4
57 | pyasn1-modules==0.2.2
58 | pycparser==2.18
59 | Pygments==2.2.0
60 | pyOpenSSL==18.0.0
61 | pyparsing==2.2.0
62 | PySocks==1.6.8
63 | python-dateutil==2.7.3
64 | pytz==2018.5
65 | PyYAML==3.13
66 | pyzmq==17.1.2
67 | requests==2.20.0
68 | scikit-learn==0.19.2
69 | scipy==1.1.0
70 | Send2Trash==1.5.0
71 | service-identity==17.0.0
72 | simplegeneric==0.8.1
73 | six==1.11.0
74 | terminado==0.8.1
75 | testpath==0.3.1
76 | torch==0.4.1
77 | torchvision==0.2.1
78 | tornado==5.1
79 | traitlets==4.3.2
80 | Twisted==17.5.0
81 | urllib3==1.23
82 | wcwidth==0.1.7
83 | webencodings==0.5.1
84 | widgetsnbextension==3.4.1
85 | zope.interface==4.5.0
86 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | data
2 | !data/download_data.sh
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 | MANIFEST
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .nox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 |
63 | # Flask stuff:
64 | instance/
65 | .webassets-cache
66 |
67 | # Scrapy stuff:
68 | .scrapy
69 |
70 | # Sphinx documentation
71 | docs/_build/
72 |
73 | # PyBuilder
74 | target/
75 |
76 | # Jupyter Notebook
77 | .ipynb_checkpoints
78 |
79 | # IPython
80 | profile_default/
81 | ipython_config.py
82 |
83 | # pyenv
84 | .python-version
85 |
86 | # celery beat schedule file
87 | celerybeat-schedule
88 |
89 | # SageMath parsed files
90 | *.sage.py
91 |
92 | # Environments
93 | .env
94 | .venv
95 | env/
96 | venv/
97 | ENV/
98 | env.bak/
99 | venv.bak/
100 |
101 | # Spyder project settings
102 | .spyderproject
103 | .spyproject
104 |
105 | # Rope project settings
106 | .ropeproject
107 |
108 | # mkdocs documentation
109 | /site
110 |
111 | # mypy
112 | .mypy_cache/
113 | .dmypy.json
114 | dmypy.json
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FashionGAN Search
2 |
 3 | This project is part of an evaluation of generative adversarial networks for improving image retrieval systems. It uses a fashion dataset to synthesize new images of fashion products based on user input, and to trigger a search for similar existing products.
 4 | The main application allows the user to modify the shape and pattern of a dress, and then choose the best match from the retrieved products. The user can then further modify the chosen product.
5 |
6 | 
7 |
8 | ## Usage
9 |
10 | #### App
11 | To use the project, run the *FashionGAN_search.ipynb* notebook in the provided conda environment (see Requirements). The application started in the notebook prompts the user to control the image modifications and the search via console input, as sketched below.
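A minimal sketch of how the app is wired together, following the example in *src/pipeline.py* (the paths and sample file name are assumptions and presume the data has been downloaded into *data/* and the notebook is run from the repository root):

```python
from PIL import Image

from src.gans import Modifier
from src.features import AkiwiFeatureGenerator, ResnetFeatureGenerator
from src.search import Search, CombinedSearch
from src.pipeline import FashionGANApp

dress_imgs = 'data/images/fashion/dresses/'
model_imgs = 'data/images/fashion_models/dresses_clustered/'
dress_feats = 'data/features/fashion/dresses/'
model_feats = 'data/features/fashion_models/dresses/'

# one Search per feature type; CombinedSearch weights the Akiwi features 2:1
dress_search = CombinedSearch(
    [Search(dress_imgs, dress_feats + 'akiwi_50', AkiwiFeatureGenerator(50)),
     Search(dress_imgs, dress_feats + 'resnet', ResnetFeatureGenerator())],
    factors=[2, 1])
model_search = CombinedSearch(
    [Search(model_imgs, model_feats + 'akiwi_50', AkiwiFeatureGenerator(50)),
     Search(model_imgs, model_feats + 'resnet', ResnetFeatureGenerator())],
    factors=[2, 1])

# GAN modifiers + retrieval, driven by console input
app = FashionGANApp(Modifier('data/models/'), dress_search, model_search)
app.start(Image.open(dress_imgs + '1DR21C07A-Q11.jpg'))
```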
12 |
13 | #### Processing
14 | The notebooks in the processing folder were used to download the feature vectors for image retrieval and to cluster the model images. All the data they produce is already provided in the data folder, but the notebooks can be run to better understand these processing steps.
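A minimal sketch of the feature-download step performed in *processing/feature_vectors_download.ipynb* (paths are assumptions; one .npy vector is saved per image, named after the image file):

```python
import glob
from src.features import AkiwiFeatureGenerator, download_feature_vectors

files = sorted(glob.glob('data/images/fashion_models/dresses_clustered/*.jpg'))
download_feature_vectors(files,
                         save_dir='data/features/fashion_models/dresses/akiwi_114/',
                         feature_generator=AkiwiFeatureGenerator(114))
```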
15 |
16 | #### Networks
17 | The networks folder contains the three generators used in the final model:
18 | - **StarGAN** originally from: https://github.com/yunjey/StarGAN
19 | - **CycleGAN** originally from: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
20 | - **Pix2Pix** originally from: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
21 |
22 | The networks were trained on the fashion dataset, and the best models are provided in the data folder.
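A minimal sketch of applying the trained generators through the *Modifier* wrapper in *src/gans.py* (paths and the sample file name are assumptions; attribute names come from the LABELS defined in that module):

```python
from PIL import Image
from src.gans import Modifier

modifier = Modifier('data/models/')  # expects stargan/, cyclegan/ and pix2pix weights inside

img = Image.open('data/images/fashion/dresses/1DR21C07A-Q11.jpg')
img = modifier.modify_shape(img, 'sleeve_length', 'long')   # StarGAN
img = modifier.modify_pattern(img, 'floral', 'add')         # CycleGAN
model_img = modifier.product_to_model(img)                  # Pix2Pix: product -> model image
```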
23 |
24 | ## Setup
25 |
26 | #### Data
27 | All data necessary for running this project can be downloaded by running the following script:
28 |
29 | ```bash
30 | cd data
31 | ./download_data.sh
32 | ```
33 |
34 | The script will download the following folders:
35 | - **images**: images of dresses and models wearing those dresses that the networks were trained on (approx. 15,000 product images and 60,000 model images).
36 | - **clustering**: models and data for clustering the available model images to create a paired dataset of one product image and one model image.
37 | - **features**: feature vectors for both product images and clustered model images, used for retrieval.
38 | - **models**: trained GAN models to modify attributes of images (the models were trained using the following GAN repositories: **Pix2Pix** and **CycleGAN** from https://github.com/sonynka/pytorch-CycleGAN-and-pix2pix, and **StarGAN** from https://github.com/sonynka/StarGAN).
39 |
40 | **Note**: The original dataset was scraped from various online fashion stores and contains approx. 90,000 images. For the purpose of this project, I only used the dresses category. Code for scraping and the whole dataset can be found here: https://github.com/sonynka/fashion_scraper.
41 |
42 | #### Requirements
43 | To download the Anaconda package manager, go to: https://www.continuum.io/downloads.
44 | After installing conda locally, proceed to set up the project environment.
45 |
46 | Install all dependencies from the conda_requirements.txt and pip_requirements.txt files:
47 | ```bash
48 | conda create -n fashion_gan python=3.6
49 | source activate fashion_gan
50 | conda install --file conda_requirements.txt
51 | pip install -r pip_requirements.txt
52 | ```
53 |
54 | To start a jupyter notebook in the environment:
55 | ```bash
56 | source activate fashion_gan
57 | jupyter notebook
58 | ```
59 |
60 |
61 | To deactivate this specific virtual environment:
62 | ```bash
63 | source deactivate
64 | ```
65 |
66 | If you need to completely remove this conda env, you can use the following command:
67 |
68 | ```bash
69 | conda env remove --name fashion_gan
70 | ```
71 |
--------------------------------------------------------------------------------
/src/image_utils.py:
--------------------------------------------------------------------------------
1 | from PIL import Image, ImageOps, ImageFilter
2 | import numpy as np
3 | import cv2
4 | import matplotlib.pyplot as plt
5 |
6 |
7 | def pad_image(img):
8 | width, height = img.size
9 |
10 | max_size = max(width, height)
11 |
12 | pad_height = max_size - height
13 | pad_width = max_size - width
14 |
15 | padding = (pad_width // 2,
16 | pad_height // 2,
17 | pad_width - (pad_width // 2),
18 | pad_height - (pad_height // 2))
19 |
20 | padded_img = ImageOps.expand(img, padding, fill=(255, 255, 255))
21 | return padded_img
22 |
23 |
24 | def remove_alpha(img):
25 | if img.mode == 'RGBA':
26 | img.load() # required for png.split()
27 | image_jpeg = Image.new("RGB", img.size, (255, 255, 255))
28 | image_jpeg.paste(img, mask=img.split()[3]) # 3 is the alpha channel
29 | img = image_jpeg
30 |
31 | return img
32 |
33 |
34 | def resize_image(img, size):
35 | return img.resize(size, Image.ANTIALIAS)
36 |
37 |
38 | def process_image(img: Image, size=None):
39 | if not size:
40 | size = [256, 256]
41 |
42 | img = remove_alpha(img)
43 | img = pad_image(img)
44 | img = resize_image(img, size=size)
45 |
46 | return img
47 |
48 |
49 | def get_edges(img_path, img_size, crop=False):
50 | """ Get black-white canny edge detection for the image """
51 |
52 | img = Image.open(img_path)
53 | if crop:
54 | w, h = img.size
55 | img = img.crop((70, 0, w - 70, h))
56 | img = img.resize(img_size, Image.ANTIALIAS)
57 | img = img.filter(ImageFilter.FIND_EDGES)
58 | img = img.convert('L')
59 | img_arr = np.array(img)
60 | img.close()
61 |
62 | return img_arr
63 |
64 |
65 | def get_mask(img_path, img_size, crop=False):
66 |
67 | img = threshold_mask(img_path)
68 |
69 | if crop:
70 | w, h = img.size
71 | img = img.crop((70, 0, w - 70, h))
72 | img = img.resize(img_size)
73 | img_arr = np.array(img)
74 | img.close()
75 |
76 | return img_arr
77 |
78 |
79 | def threshold_mask(img_path, thresh=200):
80 | """
 81 |     Build a binary foreground mask: grayscale and blur, threshold, fill contours, then smooth and dilate.
82 | """
83 |
84 | # Read image
85 | im_in = cv2.imread(img_path)
86 |
87 | # grayscale and blur
88 | im_in = cv2.cvtColor(im_in, cv2.COLOR_BGR2GRAY)
89 | im_in = cv2.GaussianBlur(im_in, (5, 5), 0)
90 |
91 | # get threshold mask
92 | _, thresh_mask = cv2.threshold(im_in, thresh, 255, cv2.THRESH_BINARY)
93 | im_masked = cv2.bitwise_not(thresh_mask)
94 |
95 | # fill contours of threshold mask
96 | _, contours, _ = cv2.findContours(im_masked, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
97 | for cnt in contours:
98 | cv2.drawContours(im_masked, [cnt], 0, 255, -1)
99 |
100 | # smooth and dilate
101 | smooth = cv2.GaussianBlur(im_masked, (9, 9), 0)
102 | _, thresh_mask2 = cv2.threshold(smooth, 0, 255,
103 | cv2.THRESH_BINARY + cv2.THRESH_OTSU)
104 | kernel = np.ones((2, 2), np.uint8)
105 | im_out = cv2.dilate(thresh_mask2, kernel, iterations=1)
106 |
107 | return Image.fromarray(im_out)
108 |
109 |
110 | def plot_img_row(images, img_labels=None):
111 | fig, axarr = plt.subplots(nrows=1, ncols=len(images),
112 | figsize=(len(images) * 3, 4))
113 |
114 | for i, img in enumerate(images):
115 | ax = axarr[i]
116 | img = img.resize([256, 256])
117 | img = img.crop((40, 0, 216, 256))
118 | ax.imshow(img)
119 | ax.set_xticks([])
120 | ax.set_yticks([])
121 |
122 | for spine in ax.spines.keys():
123 | ax.spines[spine].set_visible(False)
124 |
125 | if img_labels is not None:
126 | ax.set_title(img_labels[i])
127 |
128 | plt.show()
129 |
130 |
131 | def plot_img(img):
132 | plt.imshow(img)
133 | plt.axis('off')
134 | plt.show()
135 |
136 | return img
--------------------------------------------------------------------------------
/conda_requirements.txt:
--------------------------------------------------------------------------------
1 | # This file may be used to create an environment using:
  2 | # $ conda create --name <env> --file <this file>
3 | # platform: osx-64
4 | _nb_ext_conf=0.4.0=py36_1
5 | anaconda-client=1.7.2=py36_0
6 | appdirs=1.4.3=py36h28b3542_0
7 | appnope=0.1.0=py36hf537a9a_0
8 | asn1crypto=0.24.0=py36_0
9 | attrs=18.2.0=py36h28b3542_0
10 | automat=0.7.0=py36_0
11 | backcall=0.1.0=py36_0
12 | blas=1.0=mkl
13 | bleach=2.1.4=py36_0
14 | bzip2=1.0.6=h1de35cc_5
15 | ca-certificates=2018.03.07=0
16 | cairo=1.14.12=hc4e6be7_4
17 | certifi=2018.8.24=py36_1
18 | cffi=1.11.5=py36h6174b99_1
19 | chardet=3.0.4=py36_1
20 | clyent=1.2.2=py36_1
21 | constantly=15.1.0=py36h28b3542_0
22 | cryptography=2.3.1=py36hdbc3d79_0
23 | cycler=0.10.0=py36hfc81398_0
24 | decorator=4.3.0=py36_0
25 | entrypoints=0.2.3=py36_2
26 | ffmpeg=4.0=h01ea3c9_0
27 | fontconfig=2.13.0=h5d5b041_1
28 | freetype=2.9.1=hb4e5f40_0
29 | gettext=0.19.8.1=h15daf44_3
30 | glib=2.56.2=hd9629dc_0
31 | graphite2=1.3.12=h2098e52_2
32 | harfbuzz=1.8.8=hb8d4a28_0
33 | hdf5=1.10.2=hfa1e0ec_1
34 | html5lib=1.0.1=py36_0
35 | hyperlink=18.0.0=py36_0
36 | icu=58.2=h4b95b61_1
37 | idna=2.7=py36_0
38 | incremental=17.5.0=py36_0
39 | intel-openmp=2019.0=118
40 | ipykernel=4.9.0=py36_1
41 | ipython=6.5.0=py36_0
42 | ipython_genutils=0.2.0=py36h241746c_0
43 | ipywidgets=7.4.1=py36_0
44 | jasper=2.0.14=h636a363_1
45 | jedi=0.12.1=py36_0
46 | jinja2=2.10=py36_0
47 | jpeg=9b=he5867d9_2
48 | jsonschema=2.6.0=py36hb385e00_0
49 | jupyter_client=5.2.3=py36_0
50 | jupyter_core=4.4.0=py36_0
51 | kiwisolver=1.0.1=py36h0a44026_0
52 | libcxx=4.0.1=h579ed51_0
53 | libcxxabi=4.0.1=hebd6815_0
54 | libedit=3.1.20170329=hb402a30_2
55 | libffi=3.2.1=h475c297_4
56 | libgfortran=3.0.1=h93005f0_2
57 | libiconv=1.15=hdd342a3_7
58 | libopencv=3.4.2=h7c891bd_1
59 | libopus=1.2.1=h169cedb_0
60 | libpng=1.6.34=he12f830_0
61 | libsodium=1.0.16=h3efe00b_0
62 | libtiff=4.0.9=hcb84e12_2
63 | libvpx=1.7.0=h378b8a2_0
64 | libxml2=2.9.8=hab757c2_1
65 | markupsafe=1.0=py36h1de35cc_1
66 | matplotlib=2.2.3=py36h54f8f79_0
67 | mistune=0.8.3=py36h1de35cc_1
68 | mkl=2019.0=118
69 | mkl_fft=1.0.4=py36h5d10147_1
70 | mkl_random=1.0.1=py36h5d10147_1
71 | nb_anacondacloud=1.4.0=py36_0
72 | nb_conda=2.2.1=py36_0
73 | nb_conda_kernels=2.1.0=py36_0
74 | nbconvert=5.3.1=py36_0
75 | nbformat=4.4.0=py36h827af21_0
76 | nbpresent=3.0.2=py36_1
77 | ncurses=6.1=h0a44026_0
78 | ninja=1.8.2=py36h04f5b5a_1
79 | notebook=5.6.0=py36_0
80 | numpy=1.15.1=py36h6a91979_0
81 | numpy-base=1.15.1=py36h8a80b8c_0
82 | olefile=0.46=py36_0
83 | opencv=3.4.2=py36h6fd60c2_1
84 | openssl=1.0.2p=h1de35cc_0
85 | pandas=0.23.4=py36h6440ff4_0
86 | pandoc=2.2.3.2=0
87 | pandocfilters=1.4.2=py36_1
88 | parso=0.3.1=py36_0
89 | pcre=8.42=h378b8a2_0
90 | pexpect=4.6.0=py36_0
91 | pickleshare=0.7.4=py36hf512f8e_0
92 | pillow=5.2.0=py36hb68e598_0
93 | pip=10.0.1=py36_0
94 | pixman=0.34.0=hca0a616_3
95 | prometheus_client=0.3.1=py36h28b3542_0
96 | prompt_toolkit=1.0.15=py36haeda067_0
97 | ptyprocess=0.6.0=py36_0
98 | py-opencv=3.4.2=py36h7c891bd_1
99 | pyasn1=0.4.4=py36h28b3542_0
100 | pyasn1-modules=0.2.2=py36_0
101 | pycparser=2.18=py36_1
102 | pygments=2.2.0=py36h240cd3f_0
103 | pyopenssl=18.0.0=py36_0
104 | pyparsing=2.2.0=py36_1
105 | pysocks=1.6.8=py36_0
106 | python=3.6.6=hc167b69_0
107 | python-dateutil=2.7.3=py36_0
108 | pytz=2018.5=py36_0
109 | pyyaml=3.13=py36h1de35cc_0
110 | pyzmq=17.1.2=py36h1de35cc_0
111 | readline=7.0=h1de35cc_5
112 | requests=2.20.0=py36_0
113 | scikit-learn=0.19.2=py36h4f467ca_0
114 | scipy=1.1.0=py36h28f7352_1
115 | send2trash=1.5.0=py36_0
116 | service_identity=17.0.0=py36h28b3542_0
117 | setuptools=40.2.0=py36_0
118 | simplegeneric=0.8.1=py36_2
119 | six=1.11.0=py36_1
120 | sqlite=3.24.0=ha441bb4_0
121 | terminado=0.8.1=py36_1
122 | testpath=0.3.1=py36h625a49b_0
123 | tk=8.6.8=ha441bb4_0
124 | tornado=5.1=py36h1de35cc_0
125 | traitlets=4.3.2=py36h65bd3ce_0
126 | twisted=17.5.0=py36_0
127 | urllib3=1.23=py36_0
128 | wcwidth=0.1.7=py36h8c6ec74_0
129 | webencodings=0.5.1=py36_1
130 | wheel=0.31.1=py36_0
131 | widgetsnbextension=3.4.1=py36_0
132 | xz=5.2.4=h1de35cc_4
133 | yaml=0.1.7=hc338f04_2
134 | zeromq=4.2.5=h0a44026_1
135 | zlib=1.2.11=hf3cbc9b_2
136 | zope=1.0=py36_1
137 | zope.interface=4.5.0=py36h1de35cc_0
138 |
--------------------------------------------------------------------------------
/networks/pix2pix.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 | import functools
5 | from torch.optim import lr_scheduler
6 |
7 | ###############################################################################
8 | # Helper Functions
9 | ###############################################################################
10 |
11 |
12 | # Defines the Unet generator.
13 | # |num_downs|: number of downsamplings in UNet. For example,
14 | # if |num_downs| == 7, image of size 128x128 will become of size 1x1
15 | # at the bottleneck
16 | class Generator(nn.Module):
17 | def __init__(self, input_nc=3, output_nc=3, num_downs=8, ngf=64,
18 | norm_layer=nn.BatchNorm2d, use_dropout=False):
19 | super(Generator, self).__init__()
20 |
21 | # construct unet structure
22 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
23 | for i in range(num_downs - 5):
24 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
25 | unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
26 | unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
27 | unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
28 | unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)
29 |
30 | self.model = unet_block
31 |
32 | def forward(self, input):
33 | return self.model(input)
34 |
35 |
36 | # Defines the submodule with skip connection.
37 | # X -------------------identity---------------------- X
38 | # |-- downsampling -- |submodule| -- upsampling --|
39 | class UnetSkipConnectionBlock(nn.Module):
40 | def __init__(self, outer_nc, inner_nc, input_nc=None,
41 | submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
42 | super(UnetSkipConnectionBlock, self).__init__()
43 | self.outermost = outermost
44 | if type(norm_layer) == functools.partial:
45 | use_bias = norm_layer.func == nn.InstanceNorm2d
46 | else:
47 | use_bias = norm_layer == nn.InstanceNorm2d
48 | if input_nc is None:
49 | input_nc = outer_nc
50 | downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
51 | stride=2, padding=1, bias=use_bias)
52 | downrelu = nn.LeakyReLU(0.2, True)
53 | downnorm = norm_layer(inner_nc)
54 | uprelu = nn.ReLU(True)
55 | upnorm = norm_layer(outer_nc)
56 |
57 | if outermost:
58 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
59 | kernel_size=4, stride=2,
60 | padding=1)
61 | down = [downconv]
62 | up = [uprelu, upconv, nn.Tanh()]
63 | model = down + [submodule] + up
64 | elif innermost:
65 | upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
66 | kernel_size=4, stride=2,
67 | padding=1, bias=use_bias)
68 | down = [downrelu, downconv]
69 | up = [uprelu, upconv, upnorm]
70 | model = down + up
71 | else:
72 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
73 | kernel_size=4, stride=2,
74 | padding=1, bias=use_bias)
75 | down = [downrelu, downconv, downnorm]
76 | up = [uprelu, upconv, upnorm]
77 |
78 | if use_dropout:
79 | model = down + [submodule] + up + [nn.Dropout(0.5)]
80 | else:
81 | model = down + [submodule] + up
82 |
83 | self.model = nn.Sequential(*model)
84 |
85 | def forward(self, x):
86 | if self.outermost:
87 | return self.model(x)
88 | else:
89 | return torch.cat([x, self.model(x)], 1)
90 |
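# Shape sanity check (illustrative, not part of the original code): with the
# default num_downs=8, a 256x256 input reaches 1x1 at the bottleneck and is
# reconstructed back to 256x256:
#     G = Generator(input_nc=3, output_nc=3, num_downs=8)
#     out = G(torch.randn(1, 3, 256, 256))   # -> torch.Size([1, 3, 256, 256])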
--------------------------------------------------------------------------------
/networks/cyclegan.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import functools
3 |
4 | # Defines the generator that consists of Resnet blocks between a few
5 | # downsampling/upsampling operations.
6 | # Code and idea originally from Justin Johnson's architecture.
7 | # https://github.com/jcjohnson/fast-neural-style/
8 | class Generator(nn.Module):
9 | def __init__(self, input_nc=3, output_nc=3, ngf=64, use_dropout=False, n_blocks=9, padding_type='reflect'):
10 |
11 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False,
12 | track_running_stats=True)
13 | assert(n_blocks >= 0)
14 | super(Generator, self).__init__()
15 | self.input_nc = input_nc
16 | self.output_nc = output_nc
17 | self.ngf = ngf
18 | if type(norm_layer) == functools.partial:
19 | use_bias = norm_layer.func == nn.InstanceNorm2d
20 | else:
21 | use_bias = norm_layer == nn.InstanceNorm2d
22 |
23 | model = [nn.ReflectionPad2d(3),
24 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0,
25 | bias=use_bias),
26 | norm_layer(ngf),
27 | nn.ReLU(True)]
28 |
29 | n_downsampling = 2
30 | for i in range(n_downsampling):
31 | mult = 2**i
32 | model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
33 | stride=2, padding=1, bias=use_bias),
34 | norm_layer(ngf * mult * 2),
35 | nn.ReLU(True)]
36 |
37 | mult = 2**n_downsampling
38 | for i in range(n_blocks):
39 | model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
40 |
41 | for i in range(n_downsampling):
42 | mult = 2**(n_downsampling - i)
43 | model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
44 | kernel_size=3, stride=2,
45 | padding=1, output_padding=1,
46 | bias=use_bias),
47 | norm_layer(int(ngf * mult / 2)),
48 | nn.ReLU(True)]
49 | model += [nn.ReflectionPad2d(3)]
50 | model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
51 | model += [nn.Tanh()]
52 |
53 | self.model = nn.Sequential(*model)
54 |
55 | def forward(self, input):
56 | return self.model(input)
57 |
58 |
59 | # Define a resnet block
60 | class ResnetBlock(nn.Module):
61 | def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
62 | super(ResnetBlock, self).__init__()
63 | self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
64 |
65 | def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
66 | conv_block = []
67 | p = 0
68 | if padding_type == 'reflect':
69 | conv_block += [nn.ReflectionPad2d(1)]
70 | elif padding_type == 'replicate':
71 | conv_block += [nn.ReplicationPad2d(1)]
72 | elif padding_type == 'zero':
73 | p = 1
74 | else:
75 | raise NotImplementedError('padding [%s] is not implemented' % padding_type)
76 |
77 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
78 | norm_layer(dim),
79 | nn.ReLU(True)]
80 | if use_dropout:
81 | conv_block += [nn.Dropout(0.5)]
82 |
83 | p = 0
84 | if padding_type == 'reflect':
85 | conv_block += [nn.ReflectionPad2d(1)]
86 | elif padding_type == 'replicate':
87 | conv_block += [nn.ReplicationPad2d(1)]
88 | elif padding_type == 'zero':
89 | p = 1
90 | else:
91 | raise NotImplementedError('padding [%s] is not implemented' % padding_type)
92 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
93 | norm_layer(dim)]
94 |
95 | return nn.Sequential(*conv_block)
96 |
97 | def forward(self, x):
98 | out = x + self.conv_block(x)
99 | return out
100 |
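# Illustrative usage (not part of the original code): the generator is fully
# convolutional and preserves the input resolution, e.g. for 256x256 crops:
#     G = Generator(n_blocks=9)
#     out = G(torch.randn(1, 3, 256, 256))   # -> torch.Size([1, 3, 256, 256])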
--------------------------------------------------------------------------------
/networks/stargan.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from torchvision import transforms
6 |
7 |
8 | class ResidualBlock(nn.Module):
9 | """Residual Block."""
10 | def __init__(self, dim_in, dim_out):
11 | super(ResidualBlock, self).__init__()
12 | self.main = nn.Sequential(
13 | nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
14 | nn.InstanceNorm2d(dim_out, affine=True),
15 | nn.ReLU(inplace=True),
16 | nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
17 | nn.InstanceNorm2d(dim_out, affine=True))
18 |
19 | def forward(self, x):
20 | return x + self.main(x)
21 |
22 |
23 | class Generator(nn.Module):
24 | """Generator. Encoder-Decoder Architecture."""
25 | def __init__(self, conv_dim=64, c_dim=5, repeat_num=6, image_size=128):
26 | super(Generator, self).__init__()
27 |
28 | def conv2d_output_size(image_size, kernel, stride, pad):
29 | return ((image_size + 2 * pad - kernel)//stride) + 1
30 |
31 | def conv2dtranspose_output_size(image_size, kernel, stride, pad):
32 | return (image_size - 1) * stride - 2 * pad + kernel
33 |
34 | layers = []
35 | layers.append(nn.Conv2d(3+c_dim, conv_dim, kernel_size=7, stride=1, padding=3, bias=False))
36 | layers.append(nn.InstanceNorm2d(conv_dim, affine=True))
37 | layers.append(nn.ReLU(inplace=True))
38 |
39 | image_size = conv2d_output_size(image_size=image_size, kernel=7, stride=1, pad=3)
40 |
41 | # Down-Sampling
42 | curr_dim = conv_dim
43 | for i in range(2):
44 | layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1, bias=False))
45 | layers.append(nn.InstanceNorm2d(curr_dim*2, affine=True))
46 | layers.append(nn.ReLU(inplace=True))
47 | curr_dim = curr_dim * 2
48 |
49 | image_size = conv2d_output_size(image_size=image_size, kernel=4, stride=2, pad=1)
50 |
51 | # Bottleneck
52 | for i in range(repeat_num):
53 | layers.append(ResidualBlock(dim_in=curr_dim, dim_out=curr_dim))
54 | image_size = conv2d_output_size(image_size=image_size, kernel=3, stride=1, pad=1)
55 |
56 | # Up-Sampling
57 | for i in range(2):
58 | image_size = conv2dtranspose_output_size(image_size=image_size, kernel=4, stride=2, pad=1)
59 | layers.append(nn.Upsample(size=image_size, mode='bilinear'))
60 | layers.append(nn.ConvTranspose2d(curr_dim, curr_dim//2, kernel_size=7, stride=1, padding=3, bias=False))
61 | layers.append(nn.InstanceNorm2d(curr_dim//2, affine=True))
62 | layers.append(nn.ReLU(inplace=True))
63 | curr_dim = curr_dim // 2
64 |
65 | layers.append(nn.Conv2d(curr_dim, 3, kernel_size=7, stride=1, padding=3, bias=False))
66 | layers.append(nn.Tanh())
67 | self.main = nn.Sequential(*layers)
68 |
69 | def forward(self, x, c):
70 | # replicate spatially and concatenate domain information
71 | c = c.unsqueeze(2).unsqueeze(3)
72 | c = c.expand(c.size(0), c.size(1), x.size(2), x.size(3))
73 | x = torch.cat([x, c], dim=1)
74 | return self.main(x)
75 |
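# Illustrative usage (not part of the original code): the target attribute is a
# one-hot vector that forward() broadcasts spatially and concatenates to the
# image as extra input channels:
#     G = Generator(conv_dim=64, c_dim=4, repeat_num=6, image_size=128)
#     x = torch.randn(1, 3, 128, 128)
#     c = torch.FloatTensor([[0., 1., 0., 0.]])
#     fake = G(x, c)   # -> torch.Size([1, 3, 128, 128])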
76 |
77 | class Discriminator(nn.Module):
78 | """Discriminator. PatchGAN."""
79 | def __init__(self, image_size=128, conv_dim=64, c_dim=5, repeat_num=6):
80 | super(Discriminator, self).__init__()
81 |
82 | layers = []
83 | layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1))
84 | layers.append(nn.LeakyReLU(0.01, inplace=True))
85 |
86 | curr_dim = conv_dim
87 | for i in range(1, repeat_num):
88 | layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1))
89 | layers.append(nn.LeakyReLU(0.01, inplace=True))
90 | curr_dim = curr_dim * 2
91 |
92 | k_size = int(image_size / np.power(2, repeat_num))
93 | self.main = nn.Sequential(*layers)
94 | self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=3, stride=1, padding=1, bias=False)
95 | self.conv2 = nn.Conv2d(curr_dim, c_dim, kernel_size=k_size, bias=False)
96 |
97 | def forward(self, x):
98 | h = self.main(x)
99 | out_real = self.conv1(h)
100 | out_aux = self.conv2(h)
101 | return out_real.squeeze(), out_aux.squeeze()
--------------------------------------------------------------------------------
/src/features.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import io
4 | import requests
5 | from PIL import Image
6 | import torch
7 | import torchvision
8 | import torch.nn as nn
9 | from torchvision import transforms
10 |
11 |
12 | class ResnetFeatureGenerator():
13 | """
14 | Loads a ResNet152 model to generate feature vectors from the last hidden
15 | layer of the network. If no model path is provided, the pretrained ImageNet
16 | model from PyTorch is loaded. Otherwise, the model's weights are loaded
17 | from the model path.
18 | """
19 |
20 | _DATA_TRANSFORMS = transforms.Compose([
21 | transforms.Resize(224),
22 | transforms.ToTensor(),
23 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
24 | ])
25 |
26 | def __init__(self, retrained_model_path=None):
27 | """
28 | :param retrained_model_path: path to a retrained resnet152 model
29 | (if not provided, the original ImageNet pretrained model from PyTorch
30 | is loaded)
31 | """
32 |
33 | self.model_path = retrained_model_path
34 | self.model = self._load_resnet152()
35 |
36 | def get_feature(self, img: Image):
37 | feature = self.model(self._DATA_TRANSFORMS(img).unsqueeze(0))
38 | feature = feature.squeeze().data.numpy()
39 |
40 | return feature
41 |
42 | def _load_resnet152(self):
43 | """Loads original ResNet152 model with pretrained weights on ImageNet
44 | dataset. If a model_path is provided, it loads the weights from the
45 | model_path. Returns the model excluding the last layer.
46 | """
47 |
48 | model = torchvision.models.resnet152(pretrained=True)
49 |
50 | if self.model_path:
51 | num_ftrs = model.fc.in_features
52 | model.fc = nn.Linear(num_ftrs, 7)
53 | model.load_state_dict(
54 | torch.load(self.model_path, map_location='cpu'))
55 |
56 | # only load layers until the last hidden layer to extract features
57 | modules = list(model.children())[:-1]
58 | model = nn.Sequential(*modules)
59 |
60 | return model
61 |
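# Example (illustrative; 'dress.jpg' is a placeholder path): with no checkpoint,
# the ImageNet-pretrained backbone is used and get_feature() returns the
# 2048-dimensional activation of the last hidden layer:
#     gen = ResnetFeatureGenerator()
#     vec = gen.get_feature(Image.open('dress.jpg').convert('RGB'))   # shape (2048,)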
62 |
63 | class AkiwiFeatureGenerator():
64 | """
65 | Queries the Akiwi API to get feature vectors for an image.
66 | """
67 |
68 | _SIZE_URLS = {
69 | 64: ['http://akiwi.eu/mxnet/feature/'],
70 | 50: ['http://akiwi.eu/feature/fv50/'],
71 | 114: ['http://akiwi.eu/mxnet/feature/',
72 | 'http://akiwi.eu/feature/fv50/']
73 | }
74 |
75 | def __init__(self, feature_size):
76 | """
77 | Based on the feature size, queries the corresponding API endpoint.
 78 |         Feature size 114 is the 64- and 50-dimensional features concatenated.
79 | :param feature_size: the size of the feature vector to download
80 | """
81 | self.urls = self._SIZE_URLS[feature_size]
82 |
83 | def get_feature(self, img: Image):
84 |
85 | features = []
86 | for url in self.urls:
87 | f = self._get_url_feature(url, img)
88 | features.append(f)
89 |
90 | feature = np.concatenate(features)
91 | return feature
92 |
93 | @staticmethod
94 | def _get_url_feature(url, img: Image):
95 | """
96 | Query the Akiwi API to get an image feature vector
97 | :param url: URL of the API to query
98 | :param img: image for which to get feature vector
99 | :return: feature numpy array
100 | """
101 |
102 | img_bytes = io.BytesIO()
103 | img.save(img_bytes, format='JPEG')
104 | files = {'file': ('img.jpg', img_bytes.getvalue(), 'image/jpeg')}
105 | response = requests.post(url, files=files, timeout=10)
106 |
107 | retries = 0
108 | while (response.status_code != 200) & (retries < 3):
109 | retries += 1
110 | response = requests.post(url, files=files, timeout=10)
111 |
112 | if response.status_code == 200:
113 | response_feature = response.content
114 | feature = np.frombuffer(response_feature, dtype=np.uint8)
115 | return feature
116 |
117 | print("Couldn't get feature. Response: ", response)
118 |
119 |
120 | def download_feature_vectors(files, save_dir, feature_generator=None):
121 | """
122 | Downloads feature vectors for a list of files.
123 | :param files: list of files
124 | :param save_dir: directory to save feature vectors
125 | :param feature_generator: feature generator to create features
126 | """
127 |
128 | if not feature_generator:
129 | feature_generator = AkiwiFeatureGenerator(114)
130 |
131 | if not os.path.exists(save_dir):
132 | os.makedirs(save_dir)
133 |
134 | for idx, file in enumerate(files):
135 | if idx % 100 == 0:
136 | print('Downloaded {} / {}'.format(idx, len(files)))
137 |
138 | # name feature with the same basename as file
139 | save_path = os.path.join(save_dir, os.path.basename(file).split('.jpg')[0] + '.npy')
140 | if os.path.exists(save_path):
141 | continue
142 |
143 | feature = feature_generator.get_feature(Image.open(file))
144 | np.save(save_path, feature)
145 |
146 |
147 | def main():
148 | print('done')
149 |
150 |
151 | if __name__ == '__main__':
152 | main()
--------------------------------------------------------------------------------
/src/gans.py:
--------------------------------------------------------------------------------
1 | from networks import stargan, pix2pix, cyclegan
2 | from torchvision import transforms
3 | import torch
4 | from PIL import Image
5 | import os
6 | import glob
7 | import random
8 |
9 | class Modifier():
10 | """
 11 |     Modifier class that contains all GAN generators to modify attributes of
12 | an input image.
13 | """
14 |
15 | def __init__(self, models_root):
16 | """
 17 |         :param models_root: path to folder containing all models
18 | """
19 | self._shape_modifier = _StarGANModifier(os.path.join(models_root, 'stargan'))
20 | self._pattern_modifier = _CycleGANModifier(os.path.join(models_root, 'cyclegan'))
21 | self._model_generator = _Pix2PixModifier(os.path.join(models_root, 'pix2pix_models.pth'))
22 |
23 | def modify_shape(self, image: Image, attribute: str, value: str):
24 | return self._shape_modifier.modify_image(image, attribute, value)
25 |
26 | def modify_pattern(self, image: Image, attribute: str, value: str):
27 | return self._pattern_modifier.modify_image(image, attribute, value)
28 |
29 | def product_to_model(self, image: Image):
30 | return self._model_generator.generate_image(image)
31 |
32 | def get_shape_labels(self):
33 | return self._shape_modifier.LABELS
34 |
35 | def get_pattern_labels(self):
36 | return self._pattern_modifier.LABELS
37 |
38 |
39 | class _BaseModifier():
40 | def __init__(self, img_size):
41 |
42 | self.TRANSFORMS = transforms.Compose([
43 | transforms.Resize(img_size, interpolation=Image.ANTIALIAS),
44 | transforms.ToTensor(),
45 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
46 |
47 | @staticmethod
48 | def denorm_tensor(img_tensor):
49 |
50 | img_d = (img_tensor + 1) / 2
51 | img_d = img_d.clamp_(0, 1)
52 | img_d = img_d.data.mul(255).clamp(0, 255).byte()
53 | img_d = img_d.permute(1, 2, 0).cpu().numpy()
54 | return Image.fromarray(img_d)
55 |
56 |
57 | class _StarGANModifier(_BaseModifier):
58 |
59 | LABELS = {
60 | 'sleeve_length': ['3/4', 'long', 'short', 'sleeveless'],
61 | 'fit': ['loose', 'normal', 'tight'],
62 | 'neckline': ['round', 'v', 'wide']
63 | }
64 |
65 | IMAGE_SIZE = 128
66 |
67 | def __init__(self, G_path_root):
68 |
69 | self.G_models = {}
70 |
71 | for attr, label_map in self.LABELS.items():
72 | num_classes = len(self.LABELS[attr])
73 |
74 | G_path = os.path.join(G_path_root, attr + '.pth')
75 | self.G_models[attr] = self.load_model(G_path, num_classes)
76 |
77 | super().__init__(self.IMAGE_SIZE)
78 |
79 | @staticmethod
80 | def load_model(G_path, num_classes):
81 | try:
82 | G = stargan.Generator(c_dim=num_classes)
83 | G.load_state_dict(torch.load(G_path, map_location='cpu'))
84 | return G
85 |
 86 |         except Exception:
87 | print("Couldn't find model", G_path)
88 |
89 | def modify_image(self, image: Image, attribute: str, value: str):
90 | assert attribute in self.LABELS.keys()
91 | assert value in self.LABELS[attribute]
92 |
93 | img_tensor = self.TRANSFORMS(image).unsqueeze(0)
94 | img_label = self.get_label(attribute, value)
95 | label_tensor = torch.FloatTensor(img_label).unsqueeze(0)
96 |
97 | G = self.G_models[attribute]
98 | fake_tensor = G(img_tensor, label_tensor).squeeze(0)
99 | fake_img = self.denorm_tensor(fake_tensor)
100 |
101 | return fake_img
102 |
103 | def get_label(self, attr, value):
104 | attr_len = len(self.LABELS[attr])
105 | label_idx = self.LABELS[attr].index(value)
106 |
107 | label = [0] * attr_len
108 | label[label_idx] = 1
109 |
110 | return label
111 |
112 |
113 | class _CycleGANModifier(_BaseModifier):
114 |
115 | LABELS = {
116 | 'floral': ['add', 'remove'],
117 | 'stripes': ['add', 'remove']
118 | }
119 |
120 | IMAGE_SIZE = 256
121 |
122 | def __init__(self, G_path_root):
123 |
124 | self.G_models = {}
125 |
126 | for attr, values in self.LABELS.items():
127 | self.G_models[attr] = {}
128 | attr_path = glob.glob(os.path.join(G_path_root, attr, '*.pth'))
129 |
130 | for value in values:
131 | G_paths = [p for p in attr_path if value in p]
132 | G_list = [self.load_model(p) for p in G_paths]
133 | self.G_models[attr][value] = G_list
134 |
135 | super().__init__(self.IMAGE_SIZE)
136 |
137 | @staticmethod
138 | def load_model(G_path):
139 | try:
140 | G = cyclegan.Generator()
141 | G.load_state_dict(torch.load(G_path, map_location='cpu'))
142 | return G
143 |
144 |         except Exception:
145 | print("Couldn't find model", G_path)
146 |
147 | def modify_image(self, image: Image, attribute: str, value: str):
148 | assert attribute in self.LABELS.keys()
149 | assert value in self.LABELS[attribute]
150 |
151 | img_tensor = self.TRANSFORMS(image).unsqueeze(0)
152 |
153 | G = random.choice(self.G_models[attribute][value])
154 | fake_tensor = G(img_tensor).squeeze(0)
155 |
156 | fake_img = self.denorm_tensor(fake_tensor)
157 |
158 | return fake_img
159 |
160 |
161 | class _Pix2PixModifier(_BaseModifier):
162 |
163 | IMAGE_SIZE = 256
164 |
165 | def __init__(self, G_path):
166 |
167 | self.G = pix2pix.Generator()
168 | self.G.load_state_dict(torch.load(G_path, map_location='cpu'))
169 |
170 | super().__init__(self.IMAGE_SIZE)
171 |
172 | def generate_image(self, image: Image):
173 | img_tensor = self.TRANSFORMS(image).unsqueeze(0)
174 | fake_tensor = self.G(img_tensor).squeeze(0)
175 | fake_img = self.denorm_tensor(fake_tensor)
176 |
177 | return fake_img
178 |
179 |
180 | def main():
181 | modifier = Modifier('../data/models/')
182 |
183 | test_img = Image.open('../data/images/test_images/dresses/NEW1407001000004.jpg')
184 | mod_img = modifier.modify_pattern(test_img, 'floral', 'add')
185 | print('done')
186 |
187 | if __name__ == '__main__':
188 | main()
--------------------------------------------------------------------------------
/src/pipeline.py:
--------------------------------------------------------------------------------
1 | from src.image_utils import plot_img, plot_img_row
2 | from PIL import Image
3 | from src.gans import Modifier
4 |
5 |
6 | class FashionGANApp():
7 | """
  8 |     App that allows the user to modify a dress image and search for similar
9 | products via console input.
10 | """
11 |
12 | def __init__(self, modifier: Modifier, dress_search, model_search,
13 | num_imgs=10, metric='l1'):
14 | """
15 | :param modifier: GAN modifier object to generate modified images
16 | :param dress_search: Search or CombinedSearch object for products
17 | :param model_search: Search or CombinedSearch object for model images
18 | :param num_imgs: Number of similar images to retrieve
19 | """
20 | self._modifier = modifier
21 | self._dress_search = dress_search
22 | self._model_search = model_search
23 | self._num_similar_imgs = num_imgs
24 | self._search_metric = metric
25 |
26 | def start(self, input_img: Image):
27 | """
28 | Start the application - modify input image with console input.
29 | :param input_img: image to modify
30 | """
31 |
32 | product_img = input_img
33 | plot_img(product_img)
34 |
35 | # SHAPE MODIFICATION
36 | product_img = self._shape_modification(product_img)
37 |
38 | # PATTERN MODIFICATION
39 | product_img = self._pattern_modification(product_img)
40 |
41 | # MODEL IMAGE
42 | model_img = self._generate_model_image(product_img)
43 | mod_sim_imgs = self._search_models(model_img)
44 |
45 | # SEARCH
46 | prod_sim_imgs = self._search_products(product_img)
47 |
48 | # BEST IMAGE
49 | img_idx = self._select_best_image()
50 |
51 | # CONTINUE
52 | if img_idx != '':
53 | best_img = Image.open(prod_sim_imgs[int(img_idx)])
54 | self.start(best_img)
55 |
56 | def _shape_modification(self, img):
57 | self._print_title('SHAPE MODIFICATION')
58 | shape_labels = self._modifier.get_shape_labels()
59 | attr = self._ask_user_input(list(shape_labels.keys()), skip_option=True)
60 |
61 | if attr != '':
62 | value = self._ask_user_input(shape_labels[attr], skip_option=False)
63 | img = self._modifier.modify_shape(img, attr, value)
64 | plot_img(img)
65 |
66 | return img
67 |
68 | def _pattern_modification(self, img):
69 | self._print_title('PATTERN MODIFICATION')
70 | pattern_labels = self._modifier.get_pattern_labels()
71 | attr = self._ask_user_input(list(pattern_labels.keys()),
72 | skip_option=True)
73 |
74 | if attr != '':
75 | value = self._ask_user_input(pattern_labels[attr],
76 | skip_option=False)
77 | img = self._modifier.modify_pattern(img, attr, value)
78 | plot_img(img)
79 |
80 | return img
81 |
82 | def _generate_model_image(self, img):
83 | self._print_title('MODEL IMAGE')
84 | model_img = self._modifier.product_to_model(img)
85 | plot_img(model_img)
86 |
87 | return model_img
88 |
89 | def _search_products(self, img):
90 | self._print_title('SIMILAR PRODUCTS FROM PRODUCT SEARCH')
91 | prod_sim_imgs = self._dress_search.get_similar_images(
92 | img, self._num_similar_imgs, metric=self._search_metric)
93 | plot_img_row([Image.open(i) for i in prod_sim_imgs],
94 | img_labels=range(self._num_similar_imgs))
95 |
96 | return prod_sim_imgs
97 |
98 | def _search_models(self, img):
99 | self._print_title('MODEL SEARCH')
100 | mod_sim_imgs = self._model_search.get_similar_images(
101 | img, num_imgs=self._num_similar_imgs, metric=self._search_metric)
102 |
103 | plot_img_row([Image.open(i) for i in mod_sim_imgs])
104 |
105 | return mod_sim_imgs
106 |
107 | def _select_best_image(self):
108 | self._print_title("SELECT BEST IMAGE")
109 | img_options = list(map(str, range(self._num_similar_imgs)))
110 | img_idx = self._ask_user_input(img_options, skip_option=True)
111 |
112 | return img_idx
113 |
114 | def _ask_user_input(self, options, skip_option=True):
115 | while True:
116 | print("Choose from the following: {}".format(options))
117 | if skip_option:
118 | print("or press ENTER to skip")
119 | options.append('')
120 | attr = input()
121 |
122 | if attr not in options:
123 | print("Invalid input, try again.")
124 | else:
125 | print()
126 | return attr
127 |
128 | @staticmethod
129 | def _print_title(title):
130 | print(title)
131 | print('-' * 30)
132 |
133 |
134 | def main():
135 | from PIL import Image
136 | import os
137 |
138 | from src.gans import Modifier
139 | from src.features import AkiwiFeatureGenerator, ResnetFeatureGenerator
140 | from src.search import Search, CombinedSearch
141 |
142 | import warnings
143 | warnings.filterwarnings('ignore')
144 |
145 | folder_gens = {'akiwi_50': AkiwiFeatureGenerator(50),
146 | 'resnet': ResnetFeatureGenerator()}
147 |
148 | dress_imgs = '../data/images/fashion/dresses/'
149 | model_imgs = '../data/images/fashion_models/dresses_clustered/'
150 |
151 | dress_feats = '../data/features/fashion/dresses/'
152 | model_feats = '../data/features/fashion_models/dresses/'
153 |
154 | dress_search = {}
155 | for dir_name, gen in folder_gens.items():
156 | dress_search[dir_name] = Search(dress_imgs,
157 | os.path.join(dress_feats, dir_name),
158 | gen)
159 |
160 | model_search = {}
161 | for dir_name, gen in folder_gens.items():
162 | model_search[dir_name] = Search(model_imgs,
163 | os.path.join(model_feats, dir_name),
164 | gen)
165 |
166 | # combined search
167 | dress_resnet50 = CombinedSearch(
168 | [dress_search['akiwi_50'], dress_search['resnet']], factors=[2, 1])
169 | model_resnet50 = CombinedSearch(
170 | [model_search['akiwi_50'], model_search['resnet']], factors=[2, 1])
171 |
172 | modifier = Modifier('../data/models/')
173 | app = FashionGANApp(modifier, dress_resnet50, model_resnet50)
174 |
175 | test_img = Image.open('../data/images/fashion/dresses/1DR21C07A-Q11.jpg')
176 | app.start(test_img)
177 |
178 | print('done')
179 |
180 |
181 | if __name__ == '__main__':
182 | main()
--------------------------------------------------------------------------------
/src/search.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import glob
4 | from PIL import Image
5 | from sklearn.metrics import pairwise_distances
6 | from sklearn.preprocessing import StandardScaler
7 | from src import image_utils
8 | from src.features import ResnetFeatureGenerator, AkiwiFeatureGenerator
9 |
10 |
11 | class Search():
12 | """
13 | Loads all features from a folder and searches best images by comparing
14 | the distances between the features.
15 | """
16 |
17 | def __init__(self, images_root, features_root, feature_generator=None):
18 | """
19 | :param features_root: path to folder containing .npy features
 20 |         :param images_root: path to folder containing images with the same names as
21 | in features_root
22 | :param feature_generator: generator to create features for new images
23 | """
24 |
25 | if not feature_generator:
26 | feature_generator = ResnetFeatureGenerator()
27 |
28 | self.features_root = features_root
29 | self.feature_generator = feature_generator
30 | self.images_root = images_root
31 |
32 | self.feature_names, self.features = self._load_features()
33 | self.scaler = self._scale_features()
34 |
35 | def _load_features(self):
36 | """
37 | Load features from the feature root into a numpy array
38 | """
39 |
40 | print('Loading features from:', self.features_root)
41 | feat_files = glob.glob(os.path.join(self.features_root, '*.npy'))
42 |
43 | if len(feat_files) == 0:
44 | raise ValueError(
45 | 'Features root {} is empty.'.format(self.features_root))
46 |
47 | feat_files = sorted(feat_files)
48 | feat_names = [os.path.basename(f).rsplit('.', 1)[0]
49 | for f in feat_files]
50 | features = np.array([np.load(f) for f in feat_files])
51 |
52 | return feat_names, features
53 |
54 | def _scale_features(self):
55 | """
56 | Scale all loaded features to mean=0 and std=1
57 | """
58 | scaler = StandardScaler()
59 | self.features = scaler.fit_transform(self.features)
60 |
61 | return scaler
62 |
63 | def _get_img_feature(self, img):
64 | img = image_utils.process_image(img)
65 |
66 | img_feature = self.feature_generator.get_feature(img)
67 | img_feature = img_feature.reshape(1, -1)
68 | img_feature = self.scaler.transform(img_feature)
69 |
70 | return img_feature
71 |
72 | def _search_similar_images(self, img, num_imgs, metric):
73 | """
74 | Calculate pairwise distances from the given image to the loaded
75 | features and return the image paths with the smallest distances
76 | """
77 |
78 | img_feature = self._get_img_feature(img)
79 |
80 | dist = pairwise_distances(img_feature, self.features, metric=metric)
81 | best_img_idxs = np.argsort(dist)[0].tolist()[:num_imgs]
82 | best_img_dist = dist[0][best_img_idxs]
83 |
84 | best_img_paths = [
85 | os.path.join(self.images_root, self.feature_names[i] + '.jpg')
86 | for i in best_img_idxs]
87 |
88 | return best_img_paths, best_img_dist
89 |
90 | def get_similar_images_with_distances(self, img: Image, num_imgs=8, metric='l1'):
91 | """
92 | Retrieve similar images and their distances to the given image compared
93 | with the given metric
94 | :param img: image for which to find similar images
95 | :param num_imgs: number of similar images to retrieve
96 | :param metric: distance metric to use when comparing features ['l1','l2']
97 | :return: list of best image paths and list of their distances
98 | to the original image
99 | """
100 | return self._search_similar_images(img, num_imgs, metric)
101 |
102 | def get_similar_images(self, img: Image, num_imgs=8, metric='l1'):
103 | """
104 | Retrieve similar images to the given image compared with the given metric
105 | :param img: image for which to find similar images
106 | :param num_imgs: number of similar images to retrieve
107 | :param metric: distance metric to use when comparing features ['l1','l2']
108 | :return: list of best image paths
109 | """
110 |
111 | best_img_paths, _ = self._search_similar_images(img, num_imgs, metric)
112 | return best_img_paths
113 |
114 |
115 | class CombinedSearch():
116 | """
117 | Combines a list of Search classes to enable search with combined features
118 | """
119 |
120 | def __init__(self, search_list, factors=None):
121 | """
122 | :param search_list: list of Search objects
123 | :param factors: list of factors with which to multiply the
124 | respective Search features
125 | """
126 |
127 | # take all features as equal
128 | if not factors:
129 | factors = [1] * len(search_list)
130 |
131 | self.factors = factors
132 | self.search_list = search_list
133 | self.features = []
134 |
135 | for search, factor in zip(search_list, factors):
136 | self.features.append(factor * search.features)
137 | self.features = np.concatenate(self.features, axis=1)
138 |
139 | def _get_img_feature(self, img):
140 | img = image_utils.process_image(img)
141 |
142 | img_feature = []
143 | for search, factor in zip(self.search_list, self.factors):
144 | feature = search.feature_generator.get_feature(img)
145 | feature = feature.reshape(1, -1)
146 | feature = search.scaler.transform(feature)
147 | img_feature.append(factor * feature.squeeze())
148 |
149 | img_feature = np.concatenate(img_feature)
150 | img_feature = img_feature.reshape(1, -1)
151 |
152 | return img_feature
153 |
154 | def _search_similar_images(self, img, num_imgs, metric):
155 | """
156 | Calculate pairwise distances from the given image to the loaded
157 | features and return the image paths with the smallest distances
158 | """
159 |
160 | img_feature = self._get_img_feature(img)
161 |
162 | dist = pairwise_distances(img_feature, self.features, metric=metric)
163 | best_img_idxs = np.argsort(dist)[0].tolist()[:num_imgs]
164 | best_img_dist = dist[0][best_img_idxs]
165 |
166 | best_img_paths = [
167 | os.path.join(self.search_list[0].images_root,
168 | self.search_list[0].feature_names[i] + '.jpg')
169 | for i in best_img_idxs]
170 |
171 | return best_img_paths, best_img_dist
172 |
173 | def get_similar_images_with_distances(self, img: Image, num_imgs=8, metric='l1'):
174 | """
175 | Retrieve similar images and their distances to the given image compared
176 | with the given metric
177 | :param img: image for which to find similar images
178 | :param num_imgs: number of similar images to retrieve
179 | :param metric: distance metric to use when comparing features ['l1','l2']
180 | :return: list of best image paths and list of their distances
181 | to the original image
182 | """
183 | return self._search_similar_images(img, num_imgs, metric)
184 |
185 | def get_similar_images(self, img: Image, num_imgs=8, metric='l1'):
186 | """
187 | Retrieve similar images to the given image compared with the given metric
188 | :param img: image for which to find similar images
189 | :param num_imgs: number of similar images to retrieve
190 | :param metric: distance metric to use when comparing features ['l1','l2']
191 | :return: list of best image paths
192 | """
193 |
194 | best_img_paths, _ = self._search_similar_images(img, num_imgs, metric)
195 | return best_img_paths
196 |
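# Example (illustrative; the directory layout mirrors src/pipeline.py and is an
# assumption): weight the Akiwi features twice as strongly as the ResNet ones:
#     akiwi = Search(imgs_dir, os.path.join(feats_dir, 'akiwi_50'), AkiwiFeatureGenerator(50))
#     resnet = Search(imgs_dir, os.path.join(feats_dir, 'resnet'), ResnetFeatureGenerator())
#     combined = CombinedSearch([akiwi, resnet], factors=[2, 1])
#     paths = combined.get_similar_images(img, num_imgs=8, metric='l1')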
197 |
198 | def main():
199 | print('done')
200 |
201 |
202 | if __name__ == '__main__':
203 | main()
--------------------------------------------------------------------------------
/processing/feature_vectors_download.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Features Download\n",
8 | "This notebook downloads various image features for all listed images. The available image features to be downloaded are:\n",
9 | "* Akiwi50\n",
10 | "* Akiwi64\n",
11 | "* Akiwi114\n",
12 | "* Resnet152 original\n",
13 | "* Resnet152 retrained for fashion dataset"
14 | ]
15 | },
16 | {
17 | "cell_type": "code",
18 | "execution_count": 2,
19 | "metadata": {},
20 | "outputs": [],
21 | "source": [
22 | "import os\n",
23 | "import glob\n",
24 | "import numpy as np\n",
25 | "import sys\n",
26 | "sys.path.append('..')\n",
27 | "\n",
28 | "from src import features, search"
29 | ]
30 | },
31 | {
32 | "cell_type": "markdown",
33 | "metadata": {},
34 | "source": [
35 | "## Load Data"
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "execution_count": 3,
41 | "metadata": {},
42 | "outputs": [],
43 | "source": [
44 | "# path with all image files\n",
45 | "data_path = '../data/images/fashion_models/dresses_clustered/*.jpg'\n",
46 | "# path to save features - subfolders for each feature name will be created\n",
47 | "feature_root = '../data/features/fashion_models/dresses/'"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": 4,
53 | "metadata": {},
54 | "outputs": [
55 | {
56 | "name": "stdout",
57 | "output_type": "stream",
58 | "text": [
59 | "num images: 15304\n"
60 | ]
61 | }
62 | ],
63 | "source": [
64 | "# read data\n",
65 | "filelist = glob.glob(data_path)\n",
66 | "filelist = sorted(filelist)\n",
67 | "print('num images: ', len(filelist))"
68 | ]
69 | },
70 | {
71 | "cell_type": "markdown",
72 | "metadata": {},
73 | "source": [
74 | "# Akiwi Features"
75 | ]
76 | },
77 | {
78 | "cell_type": "markdown",
79 | "metadata": {},
80 | "source": [
81 | "## 114 Features"
82 | ]
83 | },
84 | {
85 | "cell_type": "code",
86 | "execution_count": 5,
87 | "metadata": {},
88 | "outputs": [],
89 | "source": [
90 | "feature_path = os.path.join(feature_root, 'akiwi_114')\n",
91 | "feature_gen = search.AkiwiFeatureGenerator(114)"
92 | ]
93 | },
94 | {
95 | "cell_type": "code",
96 | "execution_count": 6,
97 | "metadata": {},
98 | "outputs": [
99 | {
100 | "name": "stdout",
101 | "output_type": "stream",
102 | "text": [
103 | "Downloaded 0 / 15304\n",
104 | "Downloaded 100 / 15304\n",
105 | "Downloaded 200 / 15304\n",
106 | "Downloaded 300 / 15304\n",
107 | "Downloaded 400 / 15304\n",
108 | "Downloaded 500 / 15304\n",
109 | "Downloaded 600 / 15304\n",
110 | "Downloaded 700 / 15304\n",
111 | "Downloaded 800 / 15304\n",
112 | "Downloaded 900 / 15304\n",
113 | "Downloaded 1000 / 15304\n",
114 | "Downloaded 1100 / 15304\n",
115 | "Downloaded 1200 / 15304\n",
116 | "Downloaded 1300 / 15304\n",
117 | "Downloaded 1400 / 15304\n",
118 | "Downloaded 1500 / 15304\n",
119 | "Downloaded 1600 / 15304\n",
120 | "Downloaded 1700 / 15304\n",
121 | "Downloaded 1800 / 15304\n",
122 | "Downloaded 1900 / 15304\n",
123 | "Downloaded 2000 / 15304\n",
124 | "Downloaded 2100 / 15304\n",
125 | "Downloaded 2200 / 15304\n",
126 | "Downloaded 2300 / 15304\n",
127 | "Downloaded 2400 / 15304\n",
128 | "Downloaded 2500 / 15304\n",
129 | "Downloaded 2600 / 15304\n",
130 | "Downloaded 2700 / 15304\n",
131 | "Downloaded 2800 / 15304\n",
132 | "Downloaded 2900 / 15304\n",
133 | "Downloaded 3000 / 15304\n",
134 | "Downloaded 3100 / 15304\n",
135 | "Downloaded 3200 / 15304\n",
136 | "Downloaded 3300 / 15304\n",
137 | "Downloaded 3400 / 15304\n",
138 | "Downloaded 3500 / 15304\n",
139 | "Downloaded 3600 / 15304\n",
140 | "Downloaded 3700 / 15304\n",
141 | "Downloaded 3800 / 15304\n",
142 | "Downloaded 3900 / 15304\n",
143 | "Downloaded 4000 / 15304\n",
144 | "Downloaded 4100 / 15304\n",
145 | "Downloaded 4200 / 15304\n",
146 | "Downloaded 4300 / 15304\n",
147 | "Downloaded 4400 / 15304\n",
148 | "Downloaded 4500 / 15304\n",
149 | "Downloaded 4600 / 15304\n",
150 | "Downloaded 4700 / 15304\n",
151 | "Downloaded 4800 / 15304\n",
152 | "Downloaded 4900 / 15304\n",
153 | "Downloaded 5000 / 15304\n",
154 | "Downloaded 5100 / 15304\n",
155 | "Downloaded 5200 / 15304\n",
156 | "Downloaded 5300 / 15304\n",
157 | "Downloaded 5400 / 15304\n",
158 | "Downloaded 5500 / 15304\n",
159 | "Downloaded 5600 / 15304\n",
160 | "Downloaded 5700 / 15304\n",
161 | "Downloaded 5800 / 15304\n",
162 | "Downloaded 5900 / 15304\n",
163 | "Downloaded 6000 / 15304\n",
164 | "Downloaded 6100 / 15304\n",
165 | "Downloaded 6200 / 15304\n",
166 | "Downloaded 6300 / 15304\n",
167 | "Downloaded 6400 / 15304\n",
168 | "Downloaded 6500 / 15304\n",
169 | "Downloaded 6600 / 15304\n",
170 | "Downloaded 6700 / 15304\n",
171 | "Downloaded 6800 / 15304\n",
172 | "Downloaded 6900 / 15304\n",
173 | "Downloaded 7000 / 15304\n",
174 | "Downloaded 7100 / 15304\n",
175 | "Downloaded 7200 / 15304\n",
176 | "Downloaded 7300 / 15304\n",
177 | "Downloaded 7400 / 15304\n",
178 | "Downloaded 7500 / 15304\n",
179 | "Downloaded 7600 / 15304\n",
180 | "Downloaded 7700 / 15304\n",
181 | "Downloaded 7800 / 15304\n",
182 | "Downloaded 7900 / 15304\n",
183 | "Downloaded 8000 / 15304\n",
184 | "Downloaded 8100 / 15304\n",
185 | "Downloaded 8200 / 15304\n",
186 | "Downloaded 8300 / 15304\n",
187 | "Downloaded 8400 / 15304\n",
188 | "Downloaded 8500 / 15304\n",
189 | "Downloaded 8600 / 15304\n",
190 | "Downloaded 8700 / 15304\n",
191 | "Downloaded 8800 / 15304\n",
192 | "Downloaded 8900 / 15304\n",
193 | "Downloaded 9000 / 15304\n",
194 | "Downloaded 9100 / 15304\n",
195 | "Downloaded 9200 / 15304\n",
196 | "Downloaded 9300 / 15304\n",
197 | "Downloaded 9400 / 15304\n",
198 | "Downloaded 9500 / 15304\n",
199 | "Downloaded 9600 / 15304\n",
200 | "Downloaded 9700 / 15304\n",
201 | "Downloaded 9800 / 15304\n",
202 | "Downloaded 9900 / 15304\n",
203 | "Downloaded 10000 / 15304\n",
204 | "Downloaded 10100 / 15304\n",
205 | "Downloaded 10200 / 15304\n",
206 | "Downloaded 10300 / 15304\n",
207 | "Downloaded 10400 / 15304\n",
208 | "Downloaded 10500 / 15304\n",
209 | "Downloaded 10600 / 15304\n",
210 | "Downloaded 10700 / 15304\n",
211 | "Downloaded 10800 / 15304\n",
212 | "Downloaded 10900 / 15304\n",
213 | "Downloaded 11000 / 15304\n",
214 | "Downloaded 11100 / 15304\n",
215 | "Downloaded 11200 / 15304\n",
216 | "Downloaded 11300 / 15304\n",
217 | "Downloaded 11400 / 15304\n",
218 | "Downloaded 11500 / 15304\n",
219 | "Downloaded 11600 / 15304\n",
220 | "Downloaded 11700 / 15304\n",
221 | "Downloaded 11800 / 15304\n",
222 | "Downloaded 11900 / 15304\n",
223 | "Downloaded 12000 / 15304\n",
224 | "Downloaded 12100 / 15304\n",
225 | "Downloaded 12200 / 15304\n",
226 | "Downloaded 12300 / 15304\n",
227 | "Downloaded 12400 / 15304\n",
228 | "Downloaded 12500 / 15304\n",
229 | "Downloaded 12600 / 15304\n",
230 | "Downloaded 12700 / 15304\n",
231 | "Downloaded 12800 / 15304\n",
232 | "Downloaded 12900 / 15304\n",
233 | "Downloaded 13000 / 15304\n",
234 | "Downloaded 13100 / 15304\n",
235 | "Downloaded 13200 / 15304\n",
236 | "Downloaded 13300 / 15304\n",
237 | "Downloaded 13400 / 15304\n",
238 | "Downloaded 13500 / 15304\n",
239 | "Downloaded 13600 / 15304\n",
240 | "Downloaded 13700 / 15304\n",
241 | "Downloaded 13800 / 15304\n",
242 | "Downloaded 13900 / 15304\n",
243 | "Downloaded 14000 / 15304\n",
244 | "Downloaded 14100 / 15304\n",
245 | "Downloaded 14200 / 15304\n",
246 | "Downloaded 14300 / 15304\n",
247 | "Downloaded 14400 / 15304\n",
248 | "Downloaded 14500 / 15304\n",
249 | "Downloaded 14600 / 15304\n",
250 | "Downloaded 14700 / 15304\n",
251 | "Downloaded 14800 / 15304\n",
252 | "Downloaded 14900 / 15304\n",
253 | "Downloaded 15000 / 15304\n",
254 | "Downloaded 15100 / 15304\n",
255 | "Downloaded 15200 / 15304\n",
256 | "Downloaded 15300 / 15304\n"
257 | ]
258 | }
259 | ],
260 | "source": [
261 | "features.download_feature_vectors(filelist, feature_path, feature_gen)"
262 | ]
263 | },
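{
"cell_type": "markdown",
"metadata": {},
"source": [
"`features.download_feature_vectors` lives in `src/features.py`. Below is a minimal sketch of its assumed behaviour, inferred from the progress log above and from the per-image `.npy` files read further down; the generator call `feature_gen.get_feature(...)` and the output file naming are guesses, not the confirmed API.\n",
"\n",
"```python\n",
"import os\n",
"import numpy as np\n",
"\n",
"def download_feature_vectors(filelist, feature_path, feature_gen):\n",
"    # hypothetical sketch, not the repository implementation\n",
"    os.makedirs(feature_path, exist_ok=True)\n",
"    for idx, image_file in enumerate(filelist):\n",
"        if idx % 100 == 0:\n",
"            print('Downloaded', idx, '/', len(filelist))\n",
"        fv = feature_gen.get_feature(image_file)  # assumed generator API\n",
"        out_name = os.path.splitext(os.path.basename(image_file))[0] + '.npy'\n",
"        np.save(os.path.join(feature_path, out_name), fv)\n",
"```"
]
},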
264 | {
265 | "cell_type": "markdown",
266 | "metadata": {},
267 | "source": [
268 | "## Split 114 features into 64 and 50"
269 | ]
270 | },
271 | {
272 | "cell_type": "code",
273 | "execution_count": 7,
274 | "metadata": {},
275 | "outputs": [],
276 | "source": [
277 | "feat_path64 = os.path.join(feature_root, 'akiwi_64')\n",
278 | "if not os.path.exists(feat_path64):\n",
279 | " os.makedirs(feat_path64)\n",
280 | " \n",
281 | "feat_path50 = os.path.join(feature_root, 'akiwi_50')\n",
282 | "if not os.path.exists(feat_path50):\n",
283 | " os.makedirs(feat_path50)"
284 | ]
285 | },
286 | {
287 | "cell_type": "code",
288 | "execution_count": 9,
289 | "metadata": {},
290 | "outputs": [
291 | {
292 | "name": "stdout",
293 | "output_type": "stream",
294 | "text": [
295 | "0\n",
296 | "1000\n",
297 | "2000\n",
298 | "3000\n",
299 | "4000\n",
300 | "5000\n",
301 | "6000\n",
302 | "7000\n",
303 | "8000\n",
304 | "9000\n",
305 | "10000\n",
306 | "11000\n",
307 | "12000\n",
308 | "13000\n",
309 | "14000\n",
310 | "15000\n"
311 | ]
312 | }
313 | ],
314 | "source": [
315 | "feats114 = glob.glob(os.path.join(feature_path,'*.npy'))\n",
316 | "for idx, file in enumerate(feats114):\n",
317 | " if idx % 1000 == 0:\n",
318 | " print(idx)\n",
319 | " \n",
320 | " fv = np.load(file)\n",
321 | " \n",
322 | " feat64 = fv[:64]\n",
323 | " np.save(os.path.join(feat_path64, os.path.basename(file)), feat64)\n",
324 | " \n",
325 | " feat50 = fv[64:]\n",
326 | " np.save(os.path.join(feat_path50, os.path.basename(file)), feat50)"
327 | ]
328 | },
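{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optional sanity check (a sketch only, not executed here; it relies on the variables defined in the cells above): the two sub-vectors should concatenate back to the original 114-dimensional vector.\n",
"\n",
"```python\n",
"sample = os.path.basename(feats114[0])\n",
"full = np.load(os.path.join(feature_path, sample))\n",
"part64 = np.load(os.path.join(feat_path64, sample))\n",
"part50 = np.load(os.path.join(feat_path50, sample))\n",
"assert np.array_equal(np.concatenate([part64, part50]), full)\n",
"```"
]
},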
329 | {
330 | "cell_type": "markdown",
331 | "metadata": {},
332 | "source": [
333 | "# Original ResNet152 Features"
334 | ]
335 | },
336 | {
337 | "cell_type": "code",
338 | "execution_count": 10,
339 | "metadata": {},
340 | "outputs": [],
341 | "source": [
342 | "feature_path = os.path.join(feature_root, 'resnet')\n",
343 | "feature_gen = search.ResnetFeatureGenerator()"
344 | ]
345 | },
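{
"cell_type": "markdown",
"metadata": {},
"source": [
"`search.ResnetFeatureGenerator` is defined in `src/search.py`. The sketch below shows the assumed idea: an ImageNet-pretrained ResNet152 with the classification head removed, so each image maps to a 2048-dimensional pooled feature vector, plus an optional checkpoint path (used for the retrained features further down) to load different weights. Class internals and method names here are assumptions, not the confirmed implementation.\n",
"\n",
"```python\n",
"import torch\n",
"import torch.nn as nn\n",
"from torchvision import models, transforms\n",
"from PIL import Image\n",
"\n",
"class ResnetFeatureGenerator:\n",
"    # hypothetical sketch, not the repository implementation\n",
"    def __init__(self, weights_path=None):\n",
"        backbone = models.resnet152(pretrained=weights_path is None)\n",
"        if weights_path is not None:\n",
"            backbone.load_state_dict(torch.load(weights_path, map_location='cpu'))\n",
"        # drop the final fc layer, keep everything up to the global average pool\n",
"        self.model = nn.Sequential(*list(backbone.children())[:-1])\n",
"        self.model.eval()\n",
"        self.tf = transforms.Compose([\n",
"            transforms.Resize((224, 224)),\n",
"            transforms.ToTensor(),\n",
"            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
"        ])\n",
"\n",
"    def get_feature(self, image_file):\n",
"        img = self.tf(Image.open(image_file).convert('RGB')).unsqueeze(0)\n",
"        with torch.no_grad():\n",
"            return self.model(img).view(-1).numpy()\n",
"```"
]
},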
346 | {
347 | "cell_type": "code",
348 | "execution_count": 11,
349 | "metadata": {},
350 | "outputs": [
351 | {
352 | "name": "stdout",
353 | "output_type": "stream",
354 | "text": [
355 | "Downloaded 0 / 15304\n",
356 | "Downloaded 100 / 15304\n",
357 | "Downloaded 200 / 15304\n",
358 | "Downloaded 300 / 15304\n",
359 | "Downloaded 400 / 15304\n",
360 | "Downloaded 500 / 15304\n",
361 | "Downloaded 600 / 15304\n",
362 | "Downloaded 700 / 15304\n",
363 | "Downloaded 800 / 15304\n",
364 | "Downloaded 900 / 15304\n",
365 | "Downloaded 1000 / 15304\n",
366 | "Downloaded 1100 / 15304\n",
367 | "Downloaded 1200 / 15304\n",
368 | "Downloaded 1300 / 15304\n",
369 | "Downloaded 1400 / 15304\n",
370 | "Downloaded 1500 / 15304\n",
371 | "Downloaded 1600 / 15304\n",
372 | "Downloaded 1700 / 15304\n",
373 | "Downloaded 1800 / 15304\n",
374 | "Downloaded 1900 / 15304\n",
375 | "Downloaded 2000 / 15304\n",
376 | "Downloaded 2100 / 15304\n",
377 | "Downloaded 2200 / 15304\n",
378 | "Downloaded 2300 / 15304\n",
379 | "Downloaded 2400 / 15304\n",
380 | "Downloaded 2500 / 15304\n",
381 | "Downloaded 2600 / 15304\n",
382 | "Downloaded 2700 / 15304\n",
383 | "Downloaded 2800 / 15304\n",
384 | "Downloaded 2900 / 15304\n",
385 | "Downloaded 3000 / 15304\n",
386 | "Downloaded 3100 / 15304\n",
387 | "Downloaded 3200 / 15304\n",
388 | "Downloaded 3300 / 15304\n",
389 | "Downloaded 3400 / 15304\n",
390 | "Downloaded 3500 / 15304\n",
391 | "Downloaded 3600 / 15304\n",
392 | "Downloaded 3700 / 15304\n",
393 | "Downloaded 3800 / 15304\n",
394 | "Downloaded 3900 / 15304\n",
395 | "Downloaded 4000 / 15304\n",
396 | "Downloaded 4100 / 15304\n",
397 | "Downloaded 4200 / 15304\n",
398 | "Downloaded 4300 / 15304\n",
399 | "Downloaded 4400 / 15304\n",
400 | "Downloaded 4500 / 15304\n",
401 | "Downloaded 4600 / 15304\n",
402 | "Downloaded 4700 / 15304\n",
403 | "Downloaded 4800 / 15304\n",
404 | "Downloaded 4900 / 15304\n",
405 | "Downloaded 5000 / 15304\n",
406 | "Downloaded 5100 / 15304\n",
407 | "Downloaded 5200 / 15304\n",
408 | "Downloaded 5300 / 15304\n",
409 | "Downloaded 5400 / 15304\n",
410 | "Downloaded 5500 / 15304\n",
411 | "Downloaded 5600 / 15304\n",
412 | "Downloaded 5700 / 15304\n",
413 | "Downloaded 5800 / 15304\n",
414 | "Downloaded 5900 / 15304\n",
415 | "Downloaded 6000 / 15304\n",
416 | "Downloaded 6100 / 15304\n",
417 | "Downloaded 6200 / 15304\n",
418 | "Downloaded 6300 / 15304\n",
419 | "Downloaded 6400 / 15304\n",
420 | "Downloaded 6500 / 15304\n",
421 | "Downloaded 6600 / 15304\n",
422 | "Downloaded 6700 / 15304\n",
423 | "Downloaded 6800 / 15304\n",
424 | "Downloaded 6900 / 15304\n",
425 | "Downloaded 7000 / 15304\n",
426 | "Downloaded 7100 / 15304\n",
427 | "Downloaded 7200 / 15304\n",
428 | "Downloaded 7300 / 15304\n",
429 | "Downloaded 7400 / 15304\n",
430 | "Downloaded 7500 / 15304\n",
431 | "Downloaded 7600 / 15304\n",
432 | "Downloaded 7700 / 15304\n",
433 | "Downloaded 7800 / 15304\n",
434 | "Downloaded 7900 / 15304\n",
435 | "Downloaded 8000 / 15304\n",
436 | "Downloaded 8100 / 15304\n",
437 | "Downloaded 8200 / 15304\n",
438 | "Downloaded 8300 / 15304\n",
439 | "Downloaded 8400 / 15304\n",
440 | "Downloaded 8500 / 15304\n",
441 | "Downloaded 8600 / 15304\n",
442 | "Downloaded 8700 / 15304\n",
443 | "Downloaded 8800 / 15304\n",
444 | "Downloaded 8900 / 15304\n",
445 | "Downloaded 9000 / 15304\n",
446 | "Downloaded 9100 / 15304\n",
447 | "Downloaded 9200 / 15304\n",
448 | "Downloaded 9300 / 15304\n",
449 | "Downloaded 9400 / 15304\n",
450 | "Downloaded 9500 / 15304\n",
451 | "Downloaded 9600 / 15304\n",
452 | "Downloaded 9700 / 15304\n",
453 | "Downloaded 9800 / 15304\n",
454 | "Downloaded 9900 / 15304\n",
455 | "Downloaded 10000 / 15304\n",
456 | "Downloaded 10100 / 15304\n",
457 | "Downloaded 10200 / 15304\n",
458 | "Downloaded 10300 / 15304\n",
459 | "Downloaded 10400 / 15304\n",
460 | "Downloaded 10500 / 15304\n",
461 | "Downloaded 10600 / 15304\n",
462 | "Downloaded 10700 / 15304\n",
463 | "Downloaded 10800 / 15304\n",
464 | "Downloaded 10900 / 15304\n",
465 | "Downloaded 11000 / 15304\n",
466 | "Downloaded 11100 / 15304\n",
467 | "Downloaded 11200 / 15304\n",
468 | "Downloaded 11300 / 15304\n",
469 | "Downloaded 11400 / 15304\n",
470 | "Downloaded 11500 / 15304\n",
471 | "Downloaded 11600 / 15304\n",
472 | "Downloaded 11700 / 15304\n",
473 | "Downloaded 11800 / 15304\n",
474 | "Downloaded 11900 / 15304\n",
475 | "Downloaded 12000 / 15304\n",
476 | "Downloaded 12100 / 15304\n",
477 | "Downloaded 12200 / 15304\n",
478 | "Downloaded 12300 / 15304\n",
479 | "Downloaded 12400 / 15304\n",
480 | "Downloaded 12500 / 15304\n",
481 | "Downloaded 12600 / 15304\n",
482 | "Downloaded 12700 / 15304\n",
483 | "Downloaded 12800 / 15304\n",
484 | "Downloaded 12900 / 15304\n",
485 | "Downloaded 13000 / 15304\n",
486 | "Downloaded 13100 / 15304\n",
487 | "Downloaded 13200 / 15304\n",
488 | "Downloaded 13300 / 15304\n",
489 | "Downloaded 13400 / 15304\n",
490 | "Downloaded 13500 / 15304\n",
491 | "Downloaded 13600 / 15304\n",
492 | "Downloaded 13700 / 15304\n",
493 | "Downloaded 13800 / 15304\n",
494 | "Downloaded 13900 / 15304\n",
495 | "Downloaded 14000 / 15304\n",
496 | "Downloaded 14100 / 15304\n",
497 | "Downloaded 14200 / 15304\n",
498 | "Downloaded 14300 / 15304\n",
499 | "Downloaded 14400 / 15304\n",
500 | "Downloaded 14500 / 15304\n",
501 | "Downloaded 14600 / 15304\n",
502 | "Downloaded 14700 / 15304\n",
503 | "Downloaded 14800 / 15304\n",
504 | "Downloaded 14900 / 15304\n",
505 | "Downloaded 15000 / 15304\n",
506 | "Downloaded 15100 / 15304\n",
507 | "Downloaded 15200 / 15304\n",
508 | "Downloaded 15300 / 15304\n"
509 | ]
510 | }
511 | ],
512 | "source": [
513 | "features.download_feature_vectors(filelist, feature_path, feature_gen)"
514 | ]
515 | },
516 | {
517 | "cell_type": "markdown",
518 | "metadata": {},
519 | "source": [
520 | "# Retrained ResNet152 Features"
521 | ]
522 | },
523 | {
524 | "cell_type": "code",
525 | "execution_count": 12,
526 | "metadata": {},
527 | "outputs": [],
528 | "source": [
529 | "feature_path = os.path.join(feature_root, 'resnet_retrained')\n",
530 | "feature_gen = search.ResnetFeatureGenerator('../data/models/resnet152_retrained.pth')"
531 | ]
532 | },
533 | {
534 | "cell_type": "code",
535 | "execution_count": 13,
536 | "metadata": {},
537 | "outputs": [
538 | {
539 | "name": "stdout",
540 | "output_type": "stream",
541 | "text": [
542 | "Downloaded 0 / 15304\n",
543 | "Downloaded 100 / 15304\n",
544 | "Downloaded 200 / 15304\n",
545 | "Downloaded 300 / 15304\n",
546 | "Downloaded 400 / 15304\n",
547 | "Downloaded 500 / 15304\n",
548 | "Downloaded 600 / 15304\n",
549 | "Downloaded 700 / 15304\n",
550 | "Downloaded 800 / 15304\n",
551 | "Downloaded 900 / 15304\n",
552 | "Downloaded 1000 / 15304\n",
553 | "Downloaded 1100 / 15304\n",
554 | "Downloaded 1200 / 15304\n",
555 | "Downloaded 1300 / 15304\n",
556 | "Downloaded 1400 / 15304\n",
557 | "Downloaded 1500 / 15304\n",
558 | "Downloaded 1600 / 15304\n",
559 | "Downloaded 1700 / 15304\n",
560 | "Downloaded 1800 / 15304\n",
561 | "Downloaded 1900 / 15304\n",
562 | "Downloaded 2000 / 15304\n",
563 | "Downloaded 2100 / 15304\n",
564 | "Downloaded 2200 / 15304\n",
565 | "Downloaded 2300 / 15304\n",
566 | "Downloaded 2400 / 15304\n",
567 | "Downloaded 2500 / 15304\n",
568 | "Downloaded 2600 / 15304\n",
569 | "Downloaded 2700 / 15304\n",
570 | "Downloaded 2800 / 15304\n",
571 | "Downloaded 2900 / 15304\n",
572 | "Downloaded 3000 / 15304\n",
573 | "Downloaded 3100 / 15304\n",
574 | "Downloaded 3200 / 15304\n",
575 | "Downloaded 3300 / 15304\n",
576 | "Downloaded 3400 / 15304\n",
577 | "Downloaded 3500 / 15304\n",
578 | "Downloaded 3600 / 15304\n",
579 | "Downloaded 3700 / 15304\n",
580 | "Downloaded 3800 / 15304\n",
581 | "Downloaded 3900 / 15304\n",
582 | "Downloaded 4000 / 15304\n",
583 | "Downloaded 4100 / 15304\n",
584 | "Downloaded 4200 / 15304\n",
585 | "Downloaded 4300 / 15304\n",
586 | "Downloaded 4400 / 15304\n",
587 | "Downloaded 4500 / 15304\n",
588 | "Downloaded 4600 / 15304\n",
589 | "Downloaded 4700 / 15304\n",
590 | "Downloaded 4800 / 15304\n",
591 | "Downloaded 4900 / 15304\n",
592 | "Downloaded 5000 / 15304\n",
593 | "Downloaded 5100 / 15304\n",
594 | "Downloaded 5200 / 15304\n",
595 | "Downloaded 5300 / 15304\n",
596 | "Downloaded 5400 / 15304\n",
597 | "Downloaded 5500 / 15304\n",
598 | "Downloaded 5600 / 15304\n",
599 | "Downloaded 5700 / 15304\n",
600 | "Downloaded 5800 / 15304\n",
601 | "Downloaded 5900 / 15304\n",
602 | "Downloaded 6000 / 15304\n",
603 | "Downloaded 6100 / 15304\n",
604 | "Downloaded 6200 / 15304\n",
605 | "Downloaded 6300 / 15304\n",
606 | "Downloaded 6400 / 15304\n",
607 | "Downloaded 6500 / 15304\n",
608 | "Downloaded 6600 / 15304\n",
609 | "Downloaded 6700 / 15304\n",
610 | "Downloaded 6800 / 15304\n",
611 | "Downloaded 6900 / 15304\n",
612 | "Downloaded 7000 / 15304\n",
613 | "Downloaded 7100 / 15304\n",
614 | "Downloaded 7200 / 15304\n",
615 | "Downloaded 7300 / 15304\n",
616 | "Downloaded 7400 / 15304\n",
617 | "Downloaded 7500 / 15304\n",
618 | "Downloaded 7600 / 15304\n",
619 | "Downloaded 7700 / 15304\n",
620 | "Downloaded 7800 / 15304\n",
621 | "Downloaded 7900 / 15304\n",
622 | "Downloaded 8000 / 15304\n",
623 | "Downloaded 8100 / 15304\n",
624 | "Downloaded 8200 / 15304\n",
625 | "Downloaded 8300 / 15304\n",
626 | "Downloaded 8400 / 15304\n",
627 | "Downloaded 8500 / 15304\n",
628 | "Downloaded 8600 / 15304\n",
629 | "Downloaded 8700 / 15304\n",
630 | "Downloaded 8800 / 15304\n",
631 | "Downloaded 8900 / 15304\n",
632 | "Downloaded 9000 / 15304\n",
633 | "Downloaded 9100 / 15304\n",
634 | "Downloaded 9200 / 15304\n",
635 | "Downloaded 9300 / 15304\n",
636 | "Downloaded 9400 / 15304\n",
637 | "Downloaded 9500 / 15304\n",
638 | "Downloaded 9600 / 15304\n",
639 | "Downloaded 9700 / 15304\n",
640 | "Downloaded 9800 / 15304\n",
641 | "Downloaded 9900 / 15304\n",
642 | "Downloaded 10000 / 15304\n",
643 | "Downloaded 10100 / 15304\n",
644 | "Downloaded 10200 / 15304\n",
645 | "Downloaded 10300 / 15304\n",
646 | "Downloaded 10400 / 15304\n",
647 | "Downloaded 10500 / 15304\n",
648 | "Downloaded 10600 / 15304\n",
649 | "Downloaded 10700 / 15304\n",
650 | "Downloaded 10800 / 15304\n",
651 | "Downloaded 10900 / 15304\n",
652 | "Downloaded 11000 / 15304\n",
653 | "Downloaded 11100 / 15304\n",
654 | "Downloaded 11200 / 15304\n",
655 | "Downloaded 11300 / 15304\n",
656 | "Downloaded 11400 / 15304\n",
657 | "Downloaded 11500 / 15304\n",
658 | "Downloaded 11600 / 15304\n",
659 | "Downloaded 11700 / 15304\n",
660 | "Downloaded 11800 / 15304\n",
661 | "Downloaded 11900 / 15304\n",
662 | "Downloaded 12000 / 15304\n",
663 | "Downloaded 12100 / 15304\n",
664 | "Downloaded 12200 / 15304\n",
665 | "Downloaded 12300 / 15304\n",
666 | "Downloaded 12400 / 15304\n",
667 | "Downloaded 12500 / 15304\n",
668 | "Downloaded 12600 / 15304\n",
669 | "Downloaded 12700 / 15304\n",
670 | "Downloaded 12800 / 15304\n",
671 | "Downloaded 12900 / 15304\n",
672 | "Downloaded 13000 / 15304\n",
673 | "Downloaded 13100 / 15304\n",
674 | "Downloaded 13200 / 15304\n",
675 | "Downloaded 13300 / 15304\n",
676 | "Downloaded 13400 / 15304\n",
677 | "Downloaded 13500 / 15304\n",
678 | "Downloaded 13600 / 15304\n",
679 | "Downloaded 13700 / 15304\n",
680 | "Downloaded 13800 / 15304\n",
681 | "Downloaded 13900 / 15304\n",
682 | "Downloaded 14000 / 15304\n",
683 | "Downloaded 14100 / 15304\n",
684 | "Downloaded 14200 / 15304\n",
685 | "Downloaded 14300 / 15304\n",
686 | "Downloaded 14400 / 15304\n",
687 | "Downloaded 14500 / 15304\n",
688 | "Downloaded 14600 / 15304\n",
689 | "Downloaded 14700 / 15304\n",
690 | "Downloaded 14800 / 15304\n",
691 | "Downloaded 14900 / 15304\n",
692 | "Downloaded 15000 / 15304\n",
693 | "Downloaded 15100 / 15304\n",
694 | "Downloaded 15200 / 15304\n",
695 | "Downloaded 15300 / 15304\n"
696 | ]
697 | }
698 | ],
699 | "source": [
700 | "features.download_feature_vectors(filelist, feature_path, feature_gen)"
701 | ]
702 | }
712 | ],
713 | "metadata": {
714 | "kernelspec": {
715 | "display_name": "Python [default]",
716 | "language": "python",
717 | "name": "python3"
718 | },
719 | "language_info": {
720 | "codemirror_mode": {
721 | "name": "ipython",
722 | "version": 3
723 | },
724 | "file_extension": ".py",
725 | "mimetype": "text/x-python",
726 | "name": "python",
727 | "nbconvert_exporter": "python",
728 | "pygments_lexer": "ipython3",
729 | "version": "3.6.6"
730 | }
731 | },
732 | "nbformat": 4,
733 | "nbformat_minor": 2
734 | }
735 |
--------------------------------------------------------------------------------