├── src ├── __init__.py ├── image_utils.py ├── features.py ├── gans.py ├── pipeline.py └── search.py ├── data └── download_data.sh ├── pip_requirements.txt ├── .gitignore ├── README.md ├── conda_requirements.txt ├── networks ├── pix2pix.py ├── cyclegan.py └── stargan.py └── processing └── feature_vectors_download.ipynb /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/download_data.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | echo "Downloading images..." 3 | curl -Lo images.zip https://s3.eu-central-1.amazonaws.com/fashion-gan/images.zip 4 | echo "Downloading features..." 5 | curl -Lo features.zip https://s3.eu-central-1.amazonaws.com/fashion-gan/features.zip 6 | echo "Downloading models..." 7 | curl -Lo models.zip https://s3.eu-central-1.amazonaws.com/fashion-gan/models.zip 8 | echo "Downloading clustering data..." 9 | curl -Lo clustering.zip https://s3.eu-central-1.amazonaws.com/fashion-gan/clustering.zip 10 | 11 | echo "Unzipping files..." 12 | unzip -q images.zip 13 | unzip -q features.zip 14 | unzip -q models.zip 15 | unzip -q clustering.zip 16 | 17 | rm *.zip 18 | -------------------------------------------------------------------------------- /pip_requirements.txt: -------------------------------------------------------------------------------- 1 | anaconda-client==1.7.2 2 | appdirs==1.4.3 3 | appnope==0.1.0 4 | asn1crypto==0.24.0 5 | attrs==18.2.0 6 | Automat==0.7.0 7 | backcall==0.1.0 8 | bleach==2.1.4 9 | certifi==2018.8.24 10 | cffi==1.11.5 11 | chardet==3.0.4 12 | clyent==1.2.2 13 | constantly==15.1.0 14 | cryptography==2.3.1 15 | cycler==0.10.0 16 | Cython==0.28.5 17 | decorator==4.3.0 18 | entrypoints==0.2.3 19 | html5lib==1.0.1 20 | hyperlink==18.0.0 21 | idna==2.7 22 | incremental==17.5.0 23 | ipykernel==4.9.0 24 | ipython==6.5.0 25 | ipython-genutils==0.2.0 26 | ipywidgets==7.4.1 27 | jedi==0.12.1 28 | Jinja2==2.10 29 | jsonschema==2.6.0 30 | jupyter-client==5.2.3 31 | jupyter-core==4.4.0 32 | kiwisolver==1.0.1 33 | MarkupSafe==1.0 34 | matplotlib==2.2.3 35 | mistune==0.8.3 36 | mkl-fft==1.0.4 37 | mkl-random==1.0.1 38 | nb-anacondacloud==1.4.0 39 | nb-conda==2.2.1 40 | nb-conda-kernels==2.1.0 41 | nbconvert==5.3.1 42 | nbformat==4.4.0 43 | nbpresent==3.0.2 44 | notebook==5.6.0 45 | numpy==1.15.1 46 | olefile==0.46 47 | pandas==0.23.4 48 | pandocfilters==1.4.2 49 | parso==0.3.1 50 | pexpect==4.6.0 51 | pickleshare==0.7.4 52 | Pillow==5.2.0 53 | prometheus-client==0.3.1 54 | prompt-toolkit==1.0.15 55 | ptyprocess==0.6.0 56 | pyasn1==0.4.4 57 | pyasn1-modules==0.2.2 58 | pycparser==2.18 59 | Pygments==2.2.0 60 | pyOpenSSL==18.0.0 61 | pyparsing==2.2.0 62 | PySocks==1.6.8 63 | python-dateutil==2.7.3 64 | pytz==2018.5 65 | PyYAML==3.13 66 | pyzmq==17.1.2 67 | requests==2.20.0 68 | scikit-learn==0.19.2 69 | scipy==1.1.0 70 | Send2Trash==1.5.0 71 | service-identity==17.0.0 72 | simplegeneric==0.8.1 73 | six==1.11.0 74 | terminado==0.8.1 75 | testpath==0.3.1 76 | torch==0.4.1 77 | torchvision==0.2.1 78 | tornado==5.1 79 | traitlets==4.3.2 80 | Twisted==17.5.0 81 | urllib3==1.23 82 | wcwidth==0.1.7 83 | webencodings==0.5.1 84 | widgetsnbextension==3.4.1 85 | zope.interface==4.5.0 86 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data 2 | 
!data/download_data.sh 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # IPython 80 | profile_default/ 81 | ipython_config.py 82 | 83 | # pyenv 84 | .python-version 85 | 86 | # celery beat schedule file 87 | celerybeat-schedule 88 | 89 | # SageMath parsed files 90 | *.sage.py 91 | 92 | # Environments 93 | .env 94 | .venv 95 | env/ 96 | venv/ 97 | ENV/ 98 | env.bak/ 99 | venv.bak/ 100 | 101 | # Spyder project settings 102 | .spyderproject 103 | .spyproject 104 | 105 | # Rope project settings 106 | .ropeproject 107 | 108 | # mkdocs documentation 109 | /site 110 | 111 | # mypy 112 | .mypy_cache/ 113 | .dmypy.json 114 | dmypy.json -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FashionGAN Search 2 | 3 | This project is part of an evaluation of generative adversarial networks for improving image retrieval systems. It uses a fashion dataset to synthesize new images of fashion products based on user input, and to trigger a search for similar existing products. 4 | The main application allows the user to modify the shape and pattern of a dress, and then choose the best match from the retrieved products. The user can then further modify the chosen product. 5 | 6 | ![results](https://raw.githubusercontent.com/sonynka/fashion_gan/images/results.png) 7 | 8 | ## Usage 9 | 10 | #### App 11 | To use the project, run the *FashionGAN_search.ipynb* notebook in the provided conda environment (see Requirements below). The application started in the notebook prompts the user to control the image modifications and the search via text input. 12 | 13 | #### Processing 14 | The notebooks in the processing folder were used to download the feature vectors for image retrieval and to cluster the model images. All the data they produce is already provided in the data folder. However, the notebooks can be run to better understand these processing steps.
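Under the hood, these notebooks use the helpers in `src/features.py`. As a rough sketch (paths are illustrative and assume the repository root as the working directory), the feature-vector download amounts to:

```python
import glob
from src.features import AkiwiFeatureGenerator, download_feature_vectors

# one .npy feature vector is saved per image; files that already exist are skipped
files = sorted(glob.glob('data/images/fashion_models/dresses_clustered/*.jpg'))
generator = AkiwiFeatureGenerator(114)  # 64- and 50-dimensional Akiwi features concatenated
download_feature_vectors(files, 'data/features/fashion_models/dresses/akiwi_114', generator)
```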
15 | 16 | #### Networks 17 | The networks folder contains the three generators used in the final model: 18 | - **StarGAN** originally from: https://github.com/yunjey/StarGAN 19 | - **CycleGAN** originally from: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix 20 | - **Pix2Pix** originally from: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix 21 | 22 | The networks were trained on the fashion dataset, and the best models are provided in the data folder. 23 | 24 | ## Setup 25 | 26 | #### Data 27 | All data necessary for running this project can be downloaded by running the following script: 28 | 29 | ```bash 30 | cd data 31 | ./download_data.sh 32 | ``` 33 | 34 | The script will download the following folders: 35 | - **images**: images of dresses and of models wearing those dresses, which the networks were trained on (approx. 15,000 product images and 60,000 model images). 36 | - **clustering**: models and data for clustering the available model images to create a paired dataset of one product image + one model image. 37 | - **features**: feature vectors for both product images and clustered model images, used for retrieval. 38 | - **models**: trained GAN models for modifying image attributes (the models were trained using several GAN repositories: **Pix2Pix** and **CycleGAN** on https://github.com/sonynka/pytorch-CycleGAN-and-pix2pix and **StarGAN** on https://github.com/sonynka/StarGAN). 39 | 40 | **Note**: The original dataset was scraped from various online fashion stores and contains approx. 90,000 images. For the purpose of this project, I only used the dresses category. Code for scraping and the whole dataset can be found here: https://github.com/sonynka/fashion_scraper. 41 | 42 | #### Requirements 43 | To download the Anaconda package manager, go to: https://www.continuum.io/downloads. 44 | After installing conda locally, proceed to set up this project's environment. 45 | 46 | Install all dependencies from the conda_requirements.txt and pip_requirements.txt files:
47 | ```bash 48 | conda create -n fashion_gan python=3.6 49 | source activate fashion_gan 50 | conda install --file conda_requirements.txt 51 | pip install -r pip_requirements.txt 52 | ``` 53 | 54 | To start a jupyter notebook in the environment: 55 | ```bash 56 | source activate fashion_gan 57 | jupyter notebook 58 | ``` 59 | 60 | 61 | To deactivate this specific virtual environment: 62 | ```bash 63 | source deactivate 64 | ``` 65 | 66 | If you need to completely remove this conda env, you can use the following command: 67 | 68 | ```bash 69 | conda env remove --name fashion_gan 70 | ``` 71 | -------------------------------------------------------------------------------- /src/image_utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageOps, ImageFilter 2 | import numpy as np 3 | import cv2 4 | import matplotlib.pyplot as plt 5 | 6 | 7 | def pad_image(img): 8 | width, height = img.size 9 | 10 | max_size = max(width, height) 11 | 12 | pad_height = max_size - height 13 | pad_width = max_size - width 14 | 15 | padding = (pad_width // 2, 16 | pad_height // 2, 17 | pad_width - (pad_width // 2), 18 | pad_height - (pad_height // 2)) 19 | 20 | padded_img = ImageOps.expand(img, padding, fill=(255, 255, 255)) 21 | return padded_img 22 | 23 | 24 | def remove_alpha(img): 25 | if img.mode == 'RGBA': 26 | img.load() # required for png.split() 27 | image_jpeg = Image.new("RGB", img.size, (255, 255, 255)) 28 | image_jpeg.paste(img, mask=img.split()[3]) # 3 is the alpha channel 29 | img = image_jpeg 30 | 31 | return img 32 | 33 | 34 | def resize_image(img, size): 35 | return img.resize(size, Image.ANTIALIAS) 36 | 37 | 38 | def process_image(img: Image, size=None): 39 | if not size: 40 | size = [256, 256] 41 | 42 | img = remove_alpha(img) 43 | img = pad_image(img) 44 | img = resize_image(img, size=size) 45 | 46 | return img 47 | 48 | 49 | def get_edges(img_path, img_size, crop=False): 50 | """ Get black-white canny edge detection for the image """ 51 | 52 | img = Image.open(img_path) 53 | if crop: 54 | w, h = img.size 55 | img = img.crop((70, 0, w - 70, h)) 56 | img = img.resize(img_size, Image.ANTIALIAS) 57 | img = img.filter(ImageFilter.FIND_EDGES) 58 | img = img.convert('L') 59 | img_arr = np.array(img) 60 | img.close() 61 | 62 | return img_arr 63 | 64 | 65 | def get_mask(img_path, img_size, crop=False): 66 | 67 | img = threshold_mask(img_path) 68 | 69 | if crop: 70 | w, h = img.size 71 | img = img.crop((70, 0, w - 70, h)) 72 | img = img.resize(img_size) 73 | img_arr = np.array(img) 74 | img.close() 75 | 76 | return img_arr 77 | 78 | 79 | def threshold_mask(img_path, thresh=200): 80 | """ 81 | Get image threshold mask. 
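The mask is obtained by thresholding a blurred grayscale version of the image, filling the detected contours, and smoothing/dilating the result. :param img_path: path to the input image file :param thresh: grayscale value above which pixels are treated as background (default 200) :return: binary mask as a PIL Image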
82 | """ 83 | 84 | # Read image 85 | im_in = cv2.imread(img_path) 86 | 87 | # grayscale and blur 88 | im_in = cv2.cvtColor(im_in, cv2.COLOR_BGR2GRAY) 89 | im_in = cv2.GaussianBlur(im_in, (5, 5), 0) 90 | 91 | # get threshold mask 92 | _, thresh_mask = cv2.threshold(im_in, thresh, 255, cv2.THRESH_BINARY) 93 | im_masked = cv2.bitwise_not(thresh_mask) 94 | 95 | # fill contours of threshold mask 96 | _, contours, _ = cv2.findContours(im_masked, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 97 | for cnt in contours: 98 | cv2.drawContours(im_masked, [cnt], 0, 255, -1) 99 | 100 | # smooth and dilate 101 | smooth = cv2.GaussianBlur(im_masked, (9, 9), 0) 102 | _, thresh_mask2 = cv2.threshold(smooth, 0, 255, 103 | cv2.THRESH_BINARY + cv2.THRESH_OTSU) 104 | kernel = np.ones((2, 2), np.uint8) 105 | im_out = cv2.dilate(thresh_mask2, kernel, iterations=1) 106 | 107 | return Image.fromarray(im_out) 108 | 109 | 110 | def plot_img_row(images, img_labels=None): 111 | fig, axarr = plt.subplots(nrows=1, ncols=len(images), 112 | figsize=(len(images) * 3, 4)) 113 | 114 | for i, img in enumerate(images): 115 | ax = axarr[i] 116 | img = img.resize([256, 256]) 117 | img = img.crop((40, 0, 216, 256)) 118 | ax.imshow(img) 119 | ax.set_xticks([]) 120 | ax.set_yticks([]) 121 | 122 | for spine in ax.spines.keys(): 123 | ax.spines[spine].set_visible(False) 124 | 125 | if img_labels is not None: 126 | ax.set_title(img_labels[i]) 127 | 128 | plt.show() 129 | 130 | 131 | def plot_img(img): 132 | plt.imshow(img) 133 | plt.axis('off') 134 | plt.show() 135 | 136 | return img -------------------------------------------------------------------------------- /conda_requirements.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: osx-64 4 | _nb_ext_conf=0.4.0=py36_1 5 | anaconda-client=1.7.2=py36_0 6 | appdirs=1.4.3=py36h28b3542_0 7 | appnope=0.1.0=py36hf537a9a_0 8 | asn1crypto=0.24.0=py36_0 9 | attrs=18.2.0=py36h28b3542_0 10 | automat=0.7.0=py36_0 11 | backcall=0.1.0=py36_0 12 | blas=1.0=mkl 13 | bleach=2.1.4=py36_0 14 | bzip2=1.0.6=h1de35cc_5 15 | ca-certificates=2018.03.07=0 16 | cairo=1.14.12=hc4e6be7_4 17 | certifi=2018.8.24=py36_1 18 | cffi=1.11.5=py36h6174b99_1 19 | chardet=3.0.4=py36_1 20 | clyent=1.2.2=py36_1 21 | constantly=15.1.0=py36h28b3542_0 22 | cryptography=2.3.1=py36hdbc3d79_0 23 | cycler=0.10.0=py36hfc81398_0 24 | decorator=4.3.0=py36_0 25 | entrypoints=0.2.3=py36_2 26 | ffmpeg=4.0=h01ea3c9_0 27 | fontconfig=2.13.0=h5d5b041_1 28 | freetype=2.9.1=hb4e5f40_0 29 | gettext=0.19.8.1=h15daf44_3 30 | glib=2.56.2=hd9629dc_0 31 | graphite2=1.3.12=h2098e52_2 32 | harfbuzz=1.8.8=hb8d4a28_0 33 | hdf5=1.10.2=hfa1e0ec_1 34 | html5lib=1.0.1=py36_0 35 | hyperlink=18.0.0=py36_0 36 | icu=58.2=h4b95b61_1 37 | idna=2.7=py36_0 38 | incremental=17.5.0=py36_0 39 | intel-openmp=2019.0=118 40 | ipykernel=4.9.0=py36_1 41 | ipython=6.5.0=py36_0 42 | ipython_genutils=0.2.0=py36h241746c_0 43 | ipywidgets=7.4.1=py36_0 44 | jasper=2.0.14=h636a363_1 45 | jedi=0.12.1=py36_0 46 | jinja2=2.10=py36_0 47 | jpeg=9b=he5867d9_2 48 | jsonschema=2.6.0=py36hb385e00_0 49 | jupyter_client=5.2.3=py36_0 50 | jupyter_core=4.4.0=py36_0 51 | kiwisolver=1.0.1=py36h0a44026_0 52 | libcxx=4.0.1=h579ed51_0 53 | libcxxabi=4.0.1=hebd6815_0 54 | libedit=3.1.20170329=hb402a30_2 55 | libffi=3.2.1=h475c297_4 56 | libgfortran=3.0.1=h93005f0_2 57 | libiconv=1.15=hdd342a3_7 58 | libopencv=3.4.2=h7c891bd_1 59 | libopus=1.2.1=h169cedb_0 60 | 
libpng=1.6.34=he12f830_0 61 | libsodium=1.0.16=h3efe00b_0 62 | libtiff=4.0.9=hcb84e12_2 63 | libvpx=1.7.0=h378b8a2_0 64 | libxml2=2.9.8=hab757c2_1 65 | markupsafe=1.0=py36h1de35cc_1 66 | matplotlib=2.2.3=py36h54f8f79_0 67 | mistune=0.8.3=py36h1de35cc_1 68 | mkl=2019.0=118 69 | mkl_fft=1.0.4=py36h5d10147_1 70 | mkl_random=1.0.1=py36h5d10147_1 71 | nb_anacondacloud=1.4.0=py36_0 72 | nb_conda=2.2.1=py36_0 73 | nb_conda_kernels=2.1.0=py36_0 74 | nbconvert=5.3.1=py36_0 75 | nbformat=4.4.0=py36h827af21_0 76 | nbpresent=3.0.2=py36_1 77 | ncurses=6.1=h0a44026_0 78 | ninja=1.8.2=py36h04f5b5a_1 79 | notebook=5.6.0=py36_0 80 | numpy=1.15.1=py36h6a91979_0 81 | numpy-base=1.15.1=py36h8a80b8c_0 82 | olefile=0.46=py36_0 83 | opencv=3.4.2=py36h6fd60c2_1 84 | openssl=1.0.2p=h1de35cc_0 85 | pandas=0.23.4=py36h6440ff4_0 86 | pandoc=2.2.3.2=0 87 | pandocfilters=1.4.2=py36_1 88 | parso=0.3.1=py36_0 89 | pcre=8.42=h378b8a2_0 90 | pexpect=4.6.0=py36_0 91 | pickleshare=0.7.4=py36hf512f8e_0 92 | pillow=5.2.0=py36hb68e598_0 93 | pip=10.0.1=py36_0 94 | pixman=0.34.0=hca0a616_3 95 | prometheus_client=0.3.1=py36h28b3542_0 96 | prompt_toolkit=1.0.15=py36haeda067_0 97 | ptyprocess=0.6.0=py36_0 98 | py-opencv=3.4.2=py36h7c891bd_1 99 | pyasn1=0.4.4=py36h28b3542_0 100 | pyasn1-modules=0.2.2=py36_0 101 | pycparser=2.18=py36_1 102 | pygments=2.2.0=py36h240cd3f_0 103 | pyopenssl=18.0.0=py36_0 104 | pyparsing=2.2.0=py36_1 105 | pysocks=1.6.8=py36_0 106 | python=3.6.6=hc167b69_0 107 | python-dateutil=2.7.3=py36_0 108 | pytz=2018.5=py36_0 109 | pyyaml=3.13=py36h1de35cc_0 110 | pyzmq=17.1.2=py36h1de35cc_0 111 | readline=7.0=h1de35cc_5 112 | requests=2.20.0=py36_0 113 | scikit-learn=0.19.2=py36h4f467ca_0 114 | scipy=1.1.0=py36h28f7352_1 115 | send2trash=1.5.0=py36_0 116 | service_identity=17.0.0=py36h28b3542_0 117 | setuptools=40.2.0=py36_0 118 | simplegeneric=0.8.1=py36_2 119 | six=1.11.0=py36_1 120 | sqlite=3.24.0=ha441bb4_0 121 | terminado=0.8.1=py36_1 122 | testpath=0.3.1=py36h625a49b_0 123 | tk=8.6.8=ha441bb4_0 124 | tornado=5.1=py36h1de35cc_0 125 | traitlets=4.3.2=py36h65bd3ce_0 126 | twisted=17.5.0=py36_0 127 | urllib3=1.23=py36_0 128 | wcwidth=0.1.7=py36h8c6ec74_0 129 | webencodings=0.5.1=py36_1 130 | wheel=0.31.1=py36_0 131 | widgetsnbextension=3.4.1=py36_0 132 | xz=5.2.4=h1de35cc_4 133 | yaml=0.1.7=hc338f04_2 134 | zeromq=4.2.5=h0a44026_1 135 | zlib=1.2.11=hf3cbc9b_2 136 | zope=1.0=py36_1 137 | zope.interface=4.5.0=py36h1de35cc_0 138 | -------------------------------------------------------------------------------- /networks/pix2pix.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | import functools 5 | from torch.optim import lr_scheduler 6 | 7 | ############################################################################### 8 | # Helper Functions 9 | ############################################################################### 10 | 11 | 12 | # Defines the Unet generator. 13 | # |num_downs|: number of downsamplings in UNet. 
For example, 14 | # if |num_downs| == 7, image of size 128x128 will become of size 1x1 15 | # at the bottleneck 16 | class Generator(nn.Module): 17 | def __init__(self, input_nc=3, output_nc=3, num_downs=8, ngf=64, 18 | norm_layer=nn.BatchNorm2d, use_dropout=False): 19 | super(Generator, self).__init__() 20 | 21 | # construct unet structure 22 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) 23 | for i in range(num_downs - 5): 24 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout) 25 | unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 26 | unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 27 | unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 28 | unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) 29 | 30 | self.model = unet_block 31 | 32 | def forward(self, input): 33 | return self.model(input) 34 | 35 | 36 | # Defines the submodule with skip connection. 37 | # X -------------------identity---------------------- X 38 | # |-- downsampling -- |submodule| -- upsampling --| 39 | class UnetSkipConnectionBlock(nn.Module): 40 | def __init__(self, outer_nc, inner_nc, input_nc=None, 41 | submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False): 42 | super(UnetSkipConnectionBlock, self).__init__() 43 | self.outermost = outermost 44 | if type(norm_layer) == functools.partial: 45 | use_bias = norm_layer.func == nn.InstanceNorm2d 46 | else: 47 | use_bias = norm_layer == nn.InstanceNorm2d 48 | if input_nc is None: 49 | input_nc = outer_nc 50 | downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, 51 | stride=2, padding=1, bias=use_bias) 52 | downrelu = nn.LeakyReLU(0.2, True) 53 | downnorm = norm_layer(inner_nc) 54 | uprelu = nn.ReLU(True) 55 | upnorm = norm_layer(outer_nc) 56 | 57 | if outermost: 58 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, 59 | kernel_size=4, stride=2, 60 | padding=1) 61 | down = [downconv] 62 | up = [uprelu, upconv, nn.Tanh()] 63 | model = down + [submodule] + up 64 | elif innermost: 65 | upconv = nn.ConvTranspose2d(inner_nc, outer_nc, 66 | kernel_size=4, stride=2, 67 | padding=1, bias=use_bias) 68 | down = [downrelu, downconv] 69 | up = [uprelu, upconv, upnorm] 70 | model = down + up 71 | else: 72 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, 73 | kernel_size=4, stride=2, 74 | padding=1, bias=use_bias) 75 | down = [downrelu, downconv, downnorm] 76 | up = [uprelu, upconv, upnorm] 77 | 78 | if use_dropout: 79 | model = down + [submodule] + up + [nn.Dropout(0.5)] 80 | else: 81 | model = down + [submodule] + up 82 | 83 | self.model = nn.Sequential(*model) 84 | 85 | def forward(self, x): 86 | if self.outermost: 87 | return self.model(x) 88 | else: 89 | return torch.cat([x, self.model(x)], 1) 90 | -------------------------------------------------------------------------------- /networks/cyclegan.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import functools 3 | 4 | # Defines the generator that consists of Resnet blocks between a few 5 | # downsampling/upsampling operations. 
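# The layout is: a reflection-padded 7x7 conv, two stride-2 downsampling convs, n_blocks residual blocks, two transposed-conv upsampling stages, and a final 7x7 conv with Tanh output.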
6 | # Code and idea originally from Justin Johnson's architecture. 7 | # https://github.com/jcjohnson/fast-neural-style/ 8 | class Generator(nn.Module): 9 | def __init__(self, input_nc=3, output_nc=3, ngf=64, use_dropout=False, n_blocks=9, padding_type='reflect'): 10 | 11 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, 12 | track_running_stats=True) 13 | assert(n_blocks >= 0) 14 | super(Generator, self).__init__() 15 | self.input_nc = input_nc 16 | self.output_nc = output_nc 17 | self.ngf = ngf 18 | if type(norm_layer) == functools.partial: 19 | use_bias = norm_layer.func == nn.InstanceNorm2d 20 | else: 21 | use_bias = norm_layer == nn.InstanceNorm2d 22 | 23 | model = [nn.ReflectionPad2d(3), 24 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, 25 | bias=use_bias), 26 | norm_layer(ngf), 27 | nn.ReLU(True)] 28 | 29 | n_downsampling = 2 30 | for i in range(n_downsampling): 31 | mult = 2**i 32 | model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, 33 | stride=2, padding=1, bias=use_bias), 34 | norm_layer(ngf * mult * 2), 35 | nn.ReLU(True)] 36 | 37 | mult = 2**n_downsampling 38 | for i in range(n_blocks): 39 | model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)] 40 | 41 | for i in range(n_downsampling): 42 | mult = 2**(n_downsampling - i) 43 | model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), 44 | kernel_size=3, stride=2, 45 | padding=1, output_padding=1, 46 | bias=use_bias), 47 | norm_layer(int(ngf * mult / 2)), 48 | nn.ReLU(True)] 49 | model += [nn.ReflectionPad2d(3)] 50 | model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] 51 | model += [nn.Tanh()] 52 | 53 | self.model = nn.Sequential(*model) 54 | 55 | def forward(self, input): 56 | return self.model(input) 57 | 58 | 59 | # Define a resnet block 60 | class ResnetBlock(nn.Module): 61 | def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): 62 | super(ResnetBlock, self).__init__() 63 | self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias) 64 | 65 | def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias): 66 | conv_block = [] 67 | p = 0 68 | if padding_type == 'reflect': 69 | conv_block += [nn.ReflectionPad2d(1)] 70 | elif padding_type == 'replicate': 71 | conv_block += [nn.ReplicationPad2d(1)] 72 | elif padding_type == 'zero': 73 | p = 1 74 | else: 75 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 76 | 77 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), 78 | norm_layer(dim), 79 | nn.ReLU(True)] 80 | if use_dropout: 81 | conv_block += [nn.Dropout(0.5)] 82 | 83 | p = 0 84 | if padding_type == 'reflect': 85 | conv_block += [nn.ReflectionPad2d(1)] 86 | elif padding_type == 'replicate': 87 | conv_block += [nn.ReplicationPad2d(1)] 88 | elif padding_type == 'zero': 89 | p = 1 90 | else: 91 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 92 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), 93 | norm_layer(dim)] 94 | 95 | return nn.Sequential(*conv_block) 96 | 97 | def forward(self, x): 98 | out = x + self.conv_block(x) 99 | return out 100 | -------------------------------------------------------------------------------- /networks/stargan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 
| from torchvision import transforms 6 | 7 | 8 | class ResidualBlock(nn.Module): 9 | """Residual Block.""" 10 | def __init__(self, dim_in, dim_out): 11 | super(ResidualBlock, self).__init__() 12 | self.main = nn.Sequential( 13 | nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1, bias=False), 14 | nn.InstanceNorm2d(dim_out, affine=True), 15 | nn.ReLU(inplace=True), 16 | nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1, bias=False), 17 | nn.InstanceNorm2d(dim_out, affine=True)) 18 | 19 | def forward(self, x): 20 | return x + self.main(x) 21 | 22 | 23 | class Generator(nn.Module): 24 | """Generator. Encoder-Decoder Architecture.""" 25 | def __init__(self, conv_dim=64, c_dim=5, repeat_num=6, image_size=128): 26 | super(Generator, self).__init__() 27 | 28 | def conv2d_output_size(image_size, kernel, stride, pad): 29 | return ((image_size + 2 * pad - kernel)//stride) + 1 30 | 31 | def conv2dtranspose_output_size(image_size, kernel, stride, pad): 32 | return (image_size - 1) * stride - 2 * pad + kernel 33 | 34 | layers = [] 35 | layers.append(nn.Conv2d(3+c_dim, conv_dim, kernel_size=7, stride=1, padding=3, bias=False)) 36 | layers.append(nn.InstanceNorm2d(conv_dim, affine=True)) 37 | layers.append(nn.ReLU(inplace=True)) 38 | 39 | image_size = conv2d_output_size(image_size=image_size, kernel=7, stride=1, pad=3) 40 | 41 | # Down-Sampling 42 | curr_dim = conv_dim 43 | for i in range(2): 44 | layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1, bias=False)) 45 | layers.append(nn.InstanceNorm2d(curr_dim*2, affine=True)) 46 | layers.append(nn.ReLU(inplace=True)) 47 | curr_dim = curr_dim * 2 48 | 49 | image_size = conv2d_output_size(image_size=image_size, kernel=4, stride=2, pad=1) 50 | 51 | # Bottleneck 52 | for i in range(repeat_num): 53 | layers.append(ResidualBlock(dim_in=curr_dim, dim_out=curr_dim)) 54 | image_size = conv2d_output_size(image_size=image_size, kernel=3, stride=1, pad=1) 55 | 56 | # Up-Sampling 57 | for i in range(2): 58 | image_size = conv2dtranspose_output_size(image_size=image_size, kernel=4, stride=2, pad=1) 59 | layers.append(nn.Upsample(size=image_size, mode='bilinear')) 60 | layers.append(nn.ConvTranspose2d(curr_dim, curr_dim//2, kernel_size=7, stride=1, padding=3, bias=False)) 61 | layers.append(nn.InstanceNorm2d(curr_dim//2, affine=True)) 62 | layers.append(nn.ReLU(inplace=True)) 63 | curr_dim = curr_dim // 2 64 | 65 | layers.append(nn.Conv2d(curr_dim, 3, kernel_size=7, stride=1, padding=3, bias=False)) 66 | layers.append(nn.Tanh()) 67 | self.main = nn.Sequential(*layers) 68 | 69 | def forward(self, x, c): 70 | # replicate spatially and concatenate domain information 71 | c = c.unsqueeze(2).unsqueeze(3) 72 | c = c.expand(c.size(0), c.size(1), x.size(2), x.size(3)) 73 | x = torch.cat([x, c], dim=1) 74 | return self.main(x) 75 | 76 | 77 | class Discriminator(nn.Module): 78 | """Discriminator. 
PatchGAN.""" 79 | def __init__(self, image_size=128, conv_dim=64, c_dim=5, repeat_num=6): 80 | super(Discriminator, self).__init__() 81 | 82 | layers = [] 83 | layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1)) 84 | layers.append(nn.LeakyReLU(0.01, inplace=True)) 85 | 86 | curr_dim = conv_dim 87 | for i in range(1, repeat_num): 88 | layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1)) 89 | layers.append(nn.LeakyReLU(0.01, inplace=True)) 90 | curr_dim = curr_dim * 2 91 | 92 | k_size = int(image_size / np.power(2, repeat_num)) 93 | self.main = nn.Sequential(*layers) 94 | self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=3, stride=1, padding=1, bias=False) 95 | self.conv2 = nn.Conv2d(curr_dim, c_dim, kernel_size=k_size, bias=False) 96 | 97 | def forward(self, x): 98 | h = self.main(x) 99 | out_real = self.conv1(h) 100 | out_aux = self.conv2(h) 101 | return out_real.squeeze(), out_aux.squeeze() -------------------------------------------------------------------------------- /src/features.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import io 4 | import requests 5 | from PIL import Image 6 | import torch 7 | import torchvision 8 | import torch.nn as nn 9 | from torchvision import transforms 10 | 11 | 12 | class ResnetFeatureGenerator(): 13 | """ 14 | Loads a ResNet152 model to generate feature vectors from the last hidden 15 | layer of the network. If no model path is provided, the pretrained ImageNet 16 | model from PyTorch is loaded. Otherwise, the model's weights are loaded 17 | from the model path. 18 | """ 19 | 20 | _DATA_TRANSFORMS = transforms.Compose([ 21 | transforms.Resize(224), 22 | transforms.ToTensor(), 23 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 24 | ]) 25 | 26 | def __init__(self, retrained_model_path=None): 27 | """ 28 | :param retrained_model_path: path to a retrained resnet152 model 29 | (if not provided, the original ImageNet pretrained model from PyTorch 30 | is loaded) 31 | """ 32 | 33 | self.model_path = retrained_model_path 34 | self.model = self._load_resnet152() 35 | 36 | def get_feature(self, img: Image): 37 | feature = self.model(self._DATA_TRANSFORMS(img).unsqueeze(0)) 38 | feature = feature.squeeze().data.numpy() 39 | 40 | return feature 41 | 42 | def _load_resnet152(self): 43 | """Loads original ResNet152 model with pretrained weights on ImageNet 44 | dataset. If a model_path is provided, it loads the weights from the 45 | model_path. Returns the model excluding the last layer. 46 | """ 47 | 48 | model = torchvision.models.resnet152(pretrained=True) 49 | 50 | if self.model_path: 51 | num_ftrs = model.fc.in_features 52 | model.fc = nn.Linear(num_ftrs, 7) 53 | model.load_state_dict( 54 | torch.load(self.model_path, map_location='cpu')) 55 | 56 | # only load layers until the last hidden layer to extract features 57 | modules = list(model.children())[:-1] 58 | model = nn.Sequential(*modules) 59 | 60 | return model 61 | 62 | 63 | class AkiwiFeatureGenerator(): 64 | """ 65 | Queries the Akiwi API to get feature vectors for an image. 66 | """ 67 | 68 | _SIZE_URLS = { 69 | 64: ['http://akiwi.eu/mxnet/feature/'], 70 | 50: ['http://akiwi.eu/feature/fv50/'], 71 | 114: ['http://akiwi.eu/mxnet/feature/', 72 | 'http://akiwi.eu/feature/fv50/'] 73 | } 74 | 75 | def __init__(self, feature_size): 76 | """ 77 | Based on the feature size, queries the corresponding API endpoint. 
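Each endpoint returns the raw feature bytes, which are parsed into a uint8 numpy array; failed requests are retried up to three times.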
78 | Feature size 114 are sizes 64 and 50 appended. 79 | :param feature_size: the size of the feature vector to download 80 | """ 81 | self.urls = self._SIZE_URLS[feature_size] 82 | 83 | def get_feature(self, img: Image): 84 | 85 | features = [] 86 | for url in self.urls: 87 | f = self._get_url_feature(url, img) 88 | features.append(f) 89 | 90 | feature = np.concatenate(features) 91 | return feature 92 | 93 | @staticmethod 94 | def _get_url_feature(url, img: Image): 95 | """ 96 | Query the Akiwi API to get an image feature vector 97 | :param url: URL of the API to query 98 | :param img: image for which to get feature vector 99 | :return: feature numpy array 100 | """ 101 | 102 | img_bytes = io.BytesIO() 103 | img.save(img_bytes, format='JPEG') 104 | files = {'file': ('img.jpg', img_bytes.getvalue(), 'image/jpeg')} 105 | response = requests.post(url, files=files, timeout=10) 106 | 107 | retries = 0 108 | while (response.status_code != 200) & (retries < 3): 109 | retries += 1 110 | response = requests.post(url, files=files, timeout=10) 111 | 112 | if response.status_code == 200: 113 | response_feature = response.content 114 | feature = np.frombuffer(response_feature, dtype=np.uint8) 115 | return feature 116 | 117 | print("Couldn't get feature. Response: ", response) 118 | 119 | 120 | def download_feature_vectors(files, save_dir, feature_generator=None): 121 | """ 122 | Downloads feature vectors for a list of files. 123 | :param files: list of files 124 | :param save_dir: directory to save feature vectors 125 | :param feature_generator: feature generator to create features 126 | """ 127 | 128 | if not feature_generator: 129 | feature_generator = AkiwiFeatureGenerator(114) 130 | 131 | if not os.path.exists(save_dir): 132 | os.makedirs(save_dir) 133 | 134 | for idx, file in enumerate(files): 135 | if idx % 100 == 0: 136 | print('Downloaded {} / {}'.format(idx, len(files))) 137 | 138 | # name feature with the same basename as file 139 | save_path = os.path.join(save_dir, os.path.basename(file).split('.jpg')[0] + '.npy') 140 | if os.path.exists(save_path): 141 | continue 142 | 143 | feature = feature_generator.get_feature(Image.open(file)) 144 | np.save(save_path, feature) 145 | 146 | 147 | def main(): 148 | print('done') 149 | 150 | 151 | if __name__ == '__main__': 152 | main() -------------------------------------------------------------------------------- /src/gans.py: -------------------------------------------------------------------------------- 1 | from networks import stargan, pix2pix, cyclegan 2 | from torchvision import transforms 3 | import torch 4 | from PIL import Image 5 | import os 6 | import glob 7 | import random 8 | 9 | class Modifier(): 10 | """ 11 | Modifier class that contains all GAN generator to modify attributes of 12 | an input image. 
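Shape attributes (sleeve length, fit, neckline) are modified by a StarGAN generator, patterns (floral, stripes) by CycleGAN generators, and the product-to-model translation is done by a Pix2Pix generator. Example usage (path is illustrative): modifier = Modifier('data/models/'); img = modifier.modify_shape(Image.open('dress.jpg'), 'sleeve_length', 'long')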
13 | """ 14 | 15 | def __init__(self, models_root): 16 | """ 17 | :param models_root: path to folder containg all models 18 | """ 19 | self._shape_modifier = _StarGANModifier(os.path.join(models_root, 'stargan')) 20 | self._pattern_modifier = _CycleGANModifier(os.path.join(models_root, 'cyclegan')) 21 | self._model_generator = _Pix2PixModifier(os.path.join(models_root, 'pix2pix_models.pth')) 22 | 23 | def modify_shape(self, image: Image, attribute: str, value: str): 24 | return self._shape_modifier.modify_image(image, attribute, value) 25 | 26 | def modify_pattern(self, image: Image, attribute: str, value: str): 27 | return self._pattern_modifier.modify_image(image, attribute, value) 28 | 29 | def product_to_model(self, image: Image): 30 | return self._model_generator.generate_image(image) 31 | 32 | def get_shape_labels(self): 33 | return self._shape_modifier.LABELS 34 | 35 | def get_pattern_labels(self): 36 | return self._pattern_modifier.LABELS 37 | 38 | 39 | class _BaseModifier(): 40 | def __init__(self, img_size): 41 | 42 | self.TRANSFORMS = transforms.Compose([ 43 | transforms.Resize(img_size, interpolation=Image.ANTIALIAS), 44 | transforms.ToTensor(), 45 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 46 | 47 | @staticmethod 48 | def denorm_tensor(img_tensor): 49 | 50 | img_d = (img_tensor + 1) / 2 51 | img_d = img_d.clamp_(0, 1) 52 | img_d = img_d.data.mul(255).clamp(0, 255).byte() 53 | img_d = img_d.permute(1, 2, 0).cpu().numpy() 54 | return Image.fromarray(img_d) 55 | 56 | 57 | class _StarGANModifier(_BaseModifier): 58 | 59 | LABELS = { 60 | 'sleeve_length': ['3/4', 'long', 'short', 'sleeveless'], 61 | 'fit': ['loose', 'normal', 'tight'], 62 | 'neckline': ['round', 'v', 'wide'] 63 | } 64 | 65 | IMAGE_SIZE = 128 66 | 67 | def __init__(self, G_path_root): 68 | 69 | self.G_models = {} 70 | 71 | for attr, label_map in self.LABELS.items(): 72 | num_classes = len(self.LABELS[attr]) 73 | 74 | G_path = os.path.join(G_path_root, attr + '.pth') 75 | self.G_models[attr] = self.load_model(G_path, num_classes) 76 | 77 | super().__init__(self.IMAGE_SIZE) 78 | 79 | @staticmethod 80 | def load_model(G_path, num_classes): 81 | try: 82 | G = stargan.Generator(c_dim=num_classes) 83 | G.load_state_dict(torch.load(G_path, map_location='cpu')) 84 | return G 85 | 86 | except: 87 | print("Couldn't find model", G_path) 88 | 89 | def modify_image(self, image: Image, attribute: str, value: str): 90 | assert attribute in self.LABELS.keys() 91 | assert value in self.LABELS[attribute] 92 | 93 | img_tensor = self.TRANSFORMS(image).unsqueeze(0) 94 | img_label = self.get_label(attribute, value) 95 | label_tensor = torch.FloatTensor(img_label).unsqueeze(0) 96 | 97 | G = self.G_models[attribute] 98 | fake_tensor = G(img_tensor, label_tensor).squeeze(0) 99 | fake_img = self.denorm_tensor(fake_tensor) 100 | 101 | return fake_img 102 | 103 | def get_label(self, attr, value): 104 | attr_len = len(self.LABELS[attr]) 105 | label_idx = self.LABELS[attr].index(value) 106 | 107 | label = [0] * attr_len 108 | label[label_idx] = 1 109 | 110 | return label 111 | 112 | 113 | class _CycleGANModifier(_BaseModifier): 114 | 115 | LABELS = { 116 | 'floral': ['add', 'remove'], 117 | 'stripes': ['add', 'remove'] 118 | } 119 | 120 | IMAGE_SIZE = 256 121 | 122 | def __init__(self, G_path_root): 123 | 124 | self.G_models = {} 125 | 126 | for attr, values in self.LABELS.items(): 127 | self.G_models[attr] = {} 128 | attr_path = glob.glob(os.path.join(G_path_root, attr, '*.pth')) 129 | 130 | for value in values: 131 | G_paths = 
[p for p in attr_path if value in p] 132 | G_list = [self.load_model(p) for p in G_paths] 133 | self.G_models[attr][value] = G_list 134 | 135 | super().__init__(self.IMAGE_SIZE) 136 | 137 | @staticmethod 138 | def load_model(G_path): 139 | try: 140 | G = cyclegan.Generator() 141 | G.load_state_dict(torch.load(G_path, map_location='cpu')) 142 | return G 143 | 144 | except: 145 | print("Couldn't find model", G_path) 146 | 147 | def modify_image(self, image: Image, attribute: str, value: str): 148 | assert attribute in self.LABELS.keys() 149 | assert value in self.LABELS[attribute] 150 | 151 | img_tensor = self.TRANSFORMS(image).unsqueeze(0) 152 | 153 | G = random.choice(self.G_models[attribute][value]) 154 | fake_tensor = G(img_tensor).squeeze(0) 155 | 156 | fake_img = self.denorm_tensor(fake_tensor) 157 | 158 | return fake_img 159 | 160 | 161 | class _Pix2PixModifier(_BaseModifier): 162 | 163 | IMAGE_SIZE = 256 164 | 165 | def __init__(self, G_path): 166 | 167 | self.G = pix2pix.Generator() 168 | self.G.load_state_dict(torch.load(G_path, map_location='cpu')) 169 | 170 | super().__init__(self.IMAGE_SIZE) 171 | 172 | def generate_image(self, image: Image): 173 | img_tensor = self.TRANSFORMS(image).unsqueeze(0) 174 | fake_tensor = self.G(img_tensor).squeeze(0) 175 | fake_img = self.denorm_tensor(fake_tensor) 176 | 177 | return fake_img 178 | 179 | 180 | def main(): 181 | modifier = Modifier('../data/models/') 182 | 183 | test_img = Image.open('../data/images/test_images/dresses/NEW1407001000004.jpg') 184 | mod_img = modifier.modify_pattern(test_img, 'floral', 'add') 185 | print('done') 186 | 187 | if __name__ == '__main__': 188 | main() -------------------------------------------------------------------------------- /src/pipeline.py: -------------------------------------------------------------------------------- 1 | from src.image_utils import plot_img, plot_img_row 2 | from PIL import Image 3 | from src.gans import Modifier 4 | 5 | 6 | class FashionGANApp(): 7 | """ 8 | App that allows user to modify a dress image and search for similar 9 | products via console input. 10 | """ 11 | 12 | def __init__(self, modifier: Modifier, dress_search, model_search, 13 | num_imgs=10, metric='l1'): 14 | """ 15 | :param modifier: GAN modifier object to generate modified images 16 | :param dress_search: Search or CombinedSearch object for products 17 | :param model_search: Search or CombinedSearch object for model images 18 | :param num_imgs: Number of similar images to retrieve 19 | """ 20 | self._modifier = modifier 21 | self._dress_search = dress_search 22 | self._model_search = model_search 23 | self._num_similar_imgs = num_imgs 24 | self._search_metric = metric 25 | 26 | def start(self, input_img: Image): 27 | """ 28 | Start the application - modify input image with console input. 
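The interactive flow is: shape modification, pattern modification, generation of a model (person) image followed by a model-image search, a product search, and finally selection of the best match, which is passed back into start() for further modification.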
29 | :param input_img: image to modify 30 | """ 31 | 32 | product_img = input_img 33 | plot_img(product_img) 34 | 35 | # SHAPE MODIFICATION 36 | product_img = self._shape_modification(product_img) 37 | 38 | # PATTERN MODIFICATION 39 | product_img = self._pattern_modification(product_img) 40 | 41 | # MODEL IMAGE 42 | model_img = self._generate_model_image(product_img) 43 | mod_sim_imgs = self._search_models(model_img) 44 | 45 | # SEARCH 46 | prod_sim_imgs = self._search_products(product_img) 47 | 48 | # BEST IMAGE 49 | img_idx = self._select_best_image() 50 | 51 | # CONTINUE 52 | if img_idx != '': 53 | best_img = Image.open(prod_sim_imgs[int(img_idx)]) 54 | self.start(best_img) 55 | 56 | def _shape_modification(self, img): 57 | self._print_title('SHAPE MODIFICATION') 58 | shape_labels = self._modifier.get_shape_labels() 59 | attr = self._ask_user_input(list(shape_labels.keys()), skip_option=True) 60 | 61 | if attr != '': 62 | value = self._ask_user_input(shape_labels[attr], skip_option=False) 63 | img = self._modifier.modify_shape(img, attr, value) 64 | plot_img(img) 65 | 66 | return img 67 | 68 | def _pattern_modification(self, img): 69 | self._print_title('PATTERN MODIFICATION') 70 | pattern_labels = self._modifier.get_pattern_labels() 71 | attr = self._ask_user_input(list(pattern_labels.keys()), 72 | skip_option=True) 73 | 74 | if attr != '': 75 | value = self._ask_user_input(pattern_labels[attr], 76 | skip_option=False) 77 | img = self._modifier.modify_pattern(img, attr, value) 78 | plot_img(img) 79 | 80 | return img 81 | 82 | def _generate_model_image(self, img): 83 | self._print_title('MODEL IMAGE') 84 | model_img = self._modifier.product_to_model(img) 85 | plot_img(model_img) 86 | 87 | return model_img 88 | 89 | def _search_products(self, img): 90 | self._print_title('SIMILAR PRODUCTS FROM PRODUCT SEARCH') 91 | prod_sim_imgs = self._dress_search.get_similar_images( 92 | img, self._num_similar_imgs, metric=self._search_metric) 93 | plot_img_row([Image.open(i) for i in prod_sim_imgs], 94 | img_labels=range(self._num_similar_imgs)) 95 | 96 | return prod_sim_imgs 97 | 98 | def _search_models(self, img): 99 | self._print_title('MODEL SEARCH') 100 | mod_sim_imgs = self._model_search.get_similar_images( 101 | img, num_imgs=self._num_similar_imgs, metric=self._search_metric) 102 | 103 | plot_img_row([Image.open(i) for i in mod_sim_imgs]) 104 | 105 | return mod_sim_imgs 106 | 107 | def _select_best_image(self): 108 | self._print_title("SELECT BEST IMAGE") 109 | img_options = list(map(str, range(self._num_similar_imgs))) 110 | img_idx = self._ask_user_input(img_options, skip_option=True) 111 | 112 | return img_idx 113 | 114 | def _ask_user_input(self, options, skip_option=True): 115 | while True: 116 | print("Choose from the following: {}".format(options)) 117 | if skip_option: 118 | print("or press ENTER to skip") 119 | options.append('') 120 | attr = input() 121 | 122 | if attr not in options: 123 | print("Invalid input, try again.") 124 | else: 125 | print() 126 | return attr 127 | 128 | @staticmethod 129 | def _print_title(title): 130 | print(title) 131 | print('-' * 30) 132 | 133 | 134 | def main(): 135 | from PIL import Image 136 | import os 137 | 138 | from src.gans import Modifier 139 | from src.features import AkiwiFeatureGenerator, ResnetFeatureGenerator 140 | from src.search import Search, CombinedSearch 141 | 142 | import warnings 143 | warnings.filterwarnings('ignore') 144 | 145 | folder_gens = {'akiwi_50': AkiwiFeatureGenerator(50), 146 | 'resnet': ResnetFeatureGenerator()} 147 | 
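# paths into the images/ and features/ folders unpacked by data/download_data.sh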
148 | dress_imgs = '../data/images/fashion/dresses/' 149 | model_imgs = '../data/images/fashion_models/dresses_clustered/' 150 | 151 | dress_feats = '../data/features/fashion/dresses/' 152 | model_feats = '../data/features/fashion_models/dresses/' 153 | 154 | dress_search = {} 155 | for dir_name, gen in folder_gens.items(): 156 | dress_search[dir_name] = Search(dress_imgs, 157 | os.path.join(dress_feats, dir_name), 158 | gen) 159 | 160 | model_search = {} 161 | for dir_name, gen in folder_gens.items(): 162 | model_search[dir_name] = Search(model_imgs, 163 | os.path.join(model_feats, dir_name), 164 | gen) 165 | 166 | # combined search 167 | dress_resnet50 = CombinedSearch( 168 | [dress_search['akiwi_50'], dress_search['resnet']], factors=[2, 1]) 169 | model_resnet50 = CombinedSearch( 170 | [model_search['akiwi_50'], model_search['resnet']], factors=[2, 1]) 171 | 172 | modifier = Modifier('../data/models/') 173 | app = FashionGANApp(modifier, dress_resnet50, model_resnet50) 174 | 175 | test_img = Image.open('../data/images/fashion/dresses/1DR21C07A-Q11.jpg') 176 | app.start(test_img) 177 | 178 | print('done') 179 | 180 | 181 | if __name__ == '__main__': 182 | main() -------------------------------------------------------------------------------- /src/search.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import glob 4 | from PIL import Image 5 | from sklearn.metrics import pairwise_distances 6 | from sklearn.preprocessing import StandardScaler 7 | from src import image_utils 8 | from src.features import ResnetFeatureGenerator, AkiwiFeatureGenerator 9 | 10 | 11 | class Search(): 12 | """ 13 | Loads all features from a folder and searches best images by comparing 14 | the distances between the features. 
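Feature vectors are expected as .npy files whose basenames match the corresponding image files; they are standardized to zero mean and unit variance at load time, and queries are ranked by pairwise distance (e.g. 'l1' or 'l2').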
15 | """ 16 | 17 | def __init__(self, images_root, features_root, feature_generator=None): 18 | """ 19 | :param features_root: path to folder containing .npy features 20 | :param images_root: path to folder containg images with same names as 21 | in features_root 22 | :param feature_generator: generator to create features for new images 23 | """ 24 | 25 | if not feature_generator: 26 | feature_generator = ResnetFeatureGenerator() 27 | 28 | self.features_root = features_root 29 | self.feature_generator = feature_generator 30 | self.images_root = images_root 31 | 32 | self.feature_names, self.features = self._load_features() 33 | self.scaler = self._scale_features() 34 | 35 | def _load_features(self): 36 | """ 37 | Load features from the feature root into a numpy array 38 | """ 39 | 40 | print('Loading features from:', self.features_root) 41 | feat_files = glob.glob(os.path.join(self.features_root, '*.npy')) 42 | 43 | if len(feat_files) == 0: 44 | raise ValueError( 45 | 'Features root {} is empty.'.format(self.features_root)) 46 | 47 | feat_files = sorted(feat_files) 48 | feat_names = [os.path.basename(f).rsplit('.', 1)[0] 49 | for f in feat_files] 50 | features = np.array([np.load(f) for f in feat_files]) 51 | 52 | return feat_names, features 53 | 54 | def _scale_features(self): 55 | """ 56 | Scale all loaded features to mean=0 and std=1 57 | """ 58 | scaler = StandardScaler() 59 | self.features = scaler.fit_transform(self.features) 60 | 61 | return scaler 62 | 63 | def _get_img_feature(self, img): 64 | img = image_utils.process_image(img) 65 | 66 | img_feature = self.feature_generator.get_feature(img) 67 | img_feature = img_feature.reshape(1, -1) 68 | img_feature = self.scaler.transform(img_feature) 69 | 70 | return img_feature 71 | 72 | def _search_similar_images(self, img, num_imgs, metric): 73 | """ 74 | Calculate pairwise distances from the given image to the loaded 75 | features and return the image paths with the smallest distances 76 | """ 77 | 78 | img_feature = self._get_img_feature(img) 79 | 80 | dist = pairwise_distances(img_feature, self.features, metric=metric) 81 | best_img_idxs = np.argsort(dist)[0].tolist()[:num_imgs] 82 | best_img_dist = dist[0][best_img_idxs] 83 | 84 | best_img_paths = [ 85 | os.path.join(self.images_root, self.feature_names[i] + '.jpg') 86 | for i in best_img_idxs] 87 | 88 | return best_img_paths, best_img_dist 89 | 90 | def get_similar_images_with_distances(self, img: Image, num_imgs=8, metric='l1'): 91 | """ 92 | Retrieve similar images and their distances to the given image compared 93 | with the given metric 94 | :param img: image for which to find similar images 95 | :param num_imgs: number of similar images to retrieve 96 | :param metric: distance metric to use when comparing features ['l1','l2'] 97 | :return: list of best image paths and list of their distances 98 | to the original image 99 | """ 100 | return self._search_similar_images(img, num_imgs, metric) 101 | 102 | def get_similar_images(self, img: Image, num_imgs=8, metric='l1'): 103 | """ 104 | Retrieve similar images to the given image compared with the given metric 105 | :param img: image for which to find similar images 106 | :param num_imgs: number of similar images to retrieve 107 | :param metric: distance metric to use when comparing features ['l1','l2'] 108 | :return: list of best image paths 109 | """ 110 | 111 | best_img_paths, _ = self._search_similar_images(img, num_imgs, metric) 112 | return best_img_paths 113 | 114 | 115 | class CombinedSearch(): 116 | """ 117 | Combines a 
list of Search classes to enable search with combined features 118 | """ 119 | 120 | def __init__(self, search_list, factors=None): 121 | """ 122 | :param search_list: list of Search objects 123 | :param factors: list of factors with which to multiply the 124 | respective Search features 125 | """ 126 | 127 | # take all features as equal 128 | if not factors: 129 | factors = [1] * len(search_list) 130 | 131 | self.factors = factors 132 | self.search_list = search_list 133 | self.features = [] 134 | 135 | for search, factor in zip(search_list, factors): 136 | self.features.append(factor * search.features) 137 | self.features = np.concatenate(self.features, axis=1) 138 | 139 | def _get_img_feature(self, img): 140 | img = image_utils.process_image(img) 141 | 142 | img_feature = [] 143 | for search, factor in zip(self.search_list, self.factors): 144 | feature = search.feature_generator.get_feature(img) 145 | feature = feature.reshape(1, -1) 146 | feature = search.scaler.transform(feature) 147 | img_feature.append(factor * feature.squeeze()) 148 | 149 | img_feature = np.concatenate(img_feature) 150 | img_feature = img_feature.reshape(1, -1) 151 | 152 | return img_feature 153 | 154 | def _search_similar_images(self, img, num_imgs, metric): 155 | """ 156 | Calculate pairwise distances from the given image to the loaded 157 | features and return the image paths with the smallest distances 158 | """ 159 | 160 | img_feature = self._get_img_feature(img) 161 | 162 | dist = pairwise_distances(img_feature, self.features, metric=metric) 163 | best_img_idxs = np.argsort(dist)[0].tolist()[:num_imgs] 164 | best_img_dist = dist[0][best_img_idxs] 165 | 166 | best_img_paths = [ 167 | os.path.join(self.search_list[0].images_root, 168 | self.search_list[0].feature_names[i] + '.jpg') 169 | for i in best_img_idxs] 170 | 171 | return best_img_paths, best_img_dist 172 | 173 | def get_similar_images_with_distances(self, img: Image, num_imgs=8, metric='l1'): 174 | """ 175 | Retrieve similar images and their distances to the given image compared 176 | with the given metric 177 | :param img: image for which to find similar images 178 | :param num_imgs: number of similar images to retrieve 179 | :param metric: distance metric to use when comparing features ['l1','l2'] 180 | :return: list of best image paths and list of their distances 181 | to the original image 182 | """ 183 | return self._search_similar_images(img, num_imgs, metric) 184 | 185 | def get_similar_images(self, img: Image, num_imgs=8, metric='l1'): 186 | """ 187 | Retrieve similar images to the given image compared with the given metric 188 | :param img: image for which to find similar images 189 | :param num_imgs: number of similar images to retrieve 190 | :param metric: distance metric to use when comparing features ['l1','l2'] 191 | :return: list of best image paths 192 | """ 193 | 194 | best_img_paths, _ = self._search_similar_images(img, num_imgs, metric) 195 | return best_img_paths 196 | 197 | 198 | def main(): 199 | print('done') 200 | 201 | 202 | if __name__ == '__main__': 203 | main() -------------------------------------------------------------------------------- /processing/feature_vectors_download.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Features Download\n", 8 | "This notebook downloads various image features for all listed images. 
The available image features to be downloaded are:\n", 9 | "* Akiwi50\n", 10 | "* Akiwi64\n", 11 | "* Akiwi114\n", 12 | "* Resnet152 original\n", 13 | "* Resnet152 retrained for fashion dataset" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 2, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "import os\n", 23 | "import glob\n", 24 | "import numpy as np\n", 25 | "import sys\n", 26 | "sys.path.append('..')\n", 27 | "\n", 28 | "from src import features, search" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": {}, 34 | "source": [ 35 | "## Load Data" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 3, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "# path with all image files\n", 45 | "data_path = '../data/images/fashion_models/dresses_clustered/*.jpg'\n", 46 | "# path to save features - subfolders for each feature name will be created\n", 47 | "feature_root = '../data/features/fashion_models/dresses/'" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 4, 53 | "metadata": {}, 54 | "outputs": [ 55 | { 56 | "name": "stdout", 57 | "output_type": "stream", 58 | "text": [ 59 | "num images: 15304\n" 60 | ] 61 | } 62 | ], 63 | "source": [ 64 | "# read data\n", 65 | "filelist = glob.glob(data_path)\n", 66 | "filelist = sorted(filelist)\n", 67 | "print('num images: ', len(filelist))" 68 | ] 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "metadata": {}, 73 | "source": [ 74 | "# Akiwi Features" 75 | ] 76 | }, 77 | { 78 | "cell_type": "markdown", 79 | "metadata": {}, 80 | "source": [ 81 | "## 114 Features" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 5, 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "feature_path = os.path.join(feature_root, 'akiwi_114')\n", 91 | "feature_gen = search.AkiwiFeatureGenerator(114)" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": 6, 97 | "metadata": {}, 98 | "outputs": [ 99 | { 100 | "name": "stdout", 101 | "output_type": "stream", 102 | "text": [ 103 | "Downloaded 0 / 15304\n", 104 | "Downloaded 100 / 15304\n", 105 | "Downloaded 200 / 15304\n", 106 | "Downloaded 300 / 15304\n", 107 | "Downloaded 400 / 15304\n", 108 | "Downloaded 500 / 15304\n", 109 | "Downloaded 600 / 15304\n", 110 | "Downloaded 700 / 15304\n", 111 | "Downloaded 800 / 15304\n", 112 | "Downloaded 900 / 15304\n", 113 | "Downloaded 1000 / 15304\n", 114 | "Downloaded 1100 / 15304\n", 115 | "Downloaded 1200 / 15304\n", 116 | "Downloaded 1300 / 15304\n", 117 | "Downloaded 1400 / 15304\n", 118 | "Downloaded 1500 / 15304\n", 119 | "Downloaded 1600 / 15304\n", 120 | "Downloaded 1700 / 15304\n", 121 | "Downloaded 1800 / 15304\n", 122 | "Downloaded 1900 / 15304\n", 123 | "Downloaded 2000 / 15304\n", 124 | "Downloaded 2100 / 15304\n", 125 | "Downloaded 2200 / 15304\n", 126 | "Downloaded 2300 / 15304\n", 127 | "Downloaded 2400 / 15304\n", 128 | "Downloaded 2500 / 15304\n", 129 | "Downloaded 2600 / 15304\n", 130 | "Downloaded 2700 / 15304\n", 131 | "Downloaded 2800 / 15304\n", 132 | "Downloaded 2900 / 15304\n", 133 | "Downloaded 3000 / 15304\n", 134 | "Downloaded 3100 / 15304\n", 135 | "Downloaded 3200 / 15304\n", 136 | "Downloaded 3300 / 15304\n", 137 | "Downloaded 3400 / 15304\n", 138 | "Downloaded 3500 / 15304\n", 139 | "Downloaded 3600 / 15304\n", 140 | "Downloaded 3700 / 15304\n", 141 | "Downloaded 3800 / 15304\n", 142 | "Downloaded 3900 / 15304\n", 143 | "Downloaded 4000 / 15304\n", 144 | "Downloaded 4100 / 15304\n", 145 | 
"Downloaded 4200 / 15304\n", 146 | "Downloaded 4300 / 15304\n", 147 | "Downloaded 4400 / 15304\n", 148 | "Downloaded 4500 / 15304\n", 149 | "Downloaded 4600 / 15304\n", 150 | "Downloaded 4700 / 15304\n", 151 | "Downloaded 4800 / 15304\n", 152 | "Downloaded 4900 / 15304\n", 153 | "Downloaded 5000 / 15304\n", 154 | "Downloaded 5100 / 15304\n", 155 | "Downloaded 5200 / 15304\n", 156 | "Downloaded 5300 / 15304\n", 157 | "Downloaded 5400 / 15304\n", 158 | "Downloaded 5500 / 15304\n", 159 | "Downloaded 5600 / 15304\n", 160 | "Downloaded 5700 / 15304\n", 161 | "Downloaded 5800 / 15304\n", 162 | "Downloaded 5900 / 15304\n", 163 | "Downloaded 6000 / 15304\n", 164 | "Downloaded 6100 / 15304\n", 165 | "Downloaded 6200 / 15304\n", 166 | "Downloaded 6300 / 15304\n", 167 | "Downloaded 6400 / 15304\n", 168 | "Downloaded 6500 / 15304\n", 169 | "Downloaded 6600 / 15304\n", 170 | "Downloaded 6700 / 15304\n", 171 | "Downloaded 6800 / 15304\n", 172 | "Downloaded 6900 / 15304\n", 173 | "Downloaded 7000 / 15304\n", 174 | "Downloaded 7100 / 15304\n", 175 | "Downloaded 7200 / 15304\n", 176 | "Downloaded 7300 / 15304\n", 177 | "Downloaded 7400 / 15304\n", 178 | "Downloaded 7500 / 15304\n", 179 | "Downloaded 7600 / 15304\n", 180 | "Downloaded 7700 / 15304\n", 181 | "Downloaded 7800 / 15304\n", 182 | "Downloaded 7900 / 15304\n", 183 | "Downloaded 8000 / 15304\n", 184 | "Downloaded 8100 / 15304\n", 185 | "Downloaded 8200 / 15304\n", 186 | "Downloaded 8300 / 15304\n", 187 | "Downloaded 8400 / 15304\n", 188 | "Downloaded 8500 / 15304\n", 189 | "Downloaded 8600 / 15304\n", 190 | "Downloaded 8700 / 15304\n", 191 | "Downloaded 8800 / 15304\n", 192 | "Downloaded 8900 / 15304\n", 193 | "Downloaded 9000 / 15304\n", 194 | "Downloaded 9100 / 15304\n", 195 | "Downloaded 9200 / 15304\n", 196 | "Downloaded 9300 / 15304\n", 197 | "Downloaded 9400 / 15304\n", 198 | "Downloaded 9500 / 15304\n", 199 | "Downloaded 9600 / 15304\n", 200 | "Downloaded 9700 / 15304\n", 201 | "Downloaded 9800 / 15304\n", 202 | "Downloaded 9900 / 15304\n", 203 | "Downloaded 10000 / 15304\n", 204 | "Downloaded 10100 / 15304\n", 205 | "Downloaded 10200 / 15304\n", 206 | "Downloaded 10300 / 15304\n", 207 | "Downloaded 10400 / 15304\n", 208 | "Downloaded 10500 / 15304\n", 209 | "Downloaded 10600 / 15304\n", 210 | "Downloaded 10700 / 15304\n", 211 | "Downloaded 10800 / 15304\n", 212 | "Downloaded 10900 / 15304\n", 213 | "Downloaded 11000 / 15304\n", 214 | "Downloaded 11100 / 15304\n", 215 | "Downloaded 11200 / 15304\n", 216 | "Downloaded 11300 / 15304\n", 217 | "Downloaded 11400 / 15304\n", 218 | "Downloaded 11500 / 15304\n", 219 | "Downloaded 11600 / 15304\n", 220 | "Downloaded 11700 / 15304\n", 221 | "Downloaded 11800 / 15304\n", 222 | "Downloaded 11900 / 15304\n", 223 | "Downloaded 12000 / 15304\n", 224 | "Downloaded 12100 / 15304\n", 225 | "Downloaded 12200 / 15304\n", 226 | "Downloaded 12300 / 15304\n", 227 | "Downloaded 12400 / 15304\n", 228 | "Downloaded 12500 / 15304\n", 229 | "Downloaded 12600 / 15304\n", 230 | "Downloaded 12700 / 15304\n", 231 | "Downloaded 12800 / 15304\n", 232 | "Downloaded 12900 / 15304\n", 233 | "Downloaded 13000 / 15304\n", 234 | "Downloaded 13100 / 15304\n", 235 | "Downloaded 13200 / 15304\n", 236 | "Downloaded 13300 / 15304\n", 237 | "Downloaded 13400 / 15304\n", 238 | "Downloaded 13500 / 15304\n", 239 | "Downloaded 13600 / 15304\n", 240 | "Downloaded 13700 / 15304\n", 241 | "Downloaded 13800 / 15304\n", 242 | "Downloaded 13900 / 15304\n", 243 | "Downloaded 14000 / 15304\n", 244 | "Downloaded 14100 / 15304\n", 245 | "Downloaded 
14200 / 15304\n", 246 | "Downloaded 14300 / 15304\n", 247 | "Downloaded 14400 / 15304\n", 248 | "Downloaded 14500 / 15304\n", 249 | "Downloaded 14600 / 15304\n", 250 | "Downloaded 14700 / 15304\n", 251 | "Downloaded 14800 / 15304\n", 252 | "Downloaded 14900 / 15304\n", 253 | "Downloaded 15000 / 15304\n", 254 | "Downloaded 15100 / 15304\n", 255 | "Downloaded 15200 / 15304\n", 256 | "Downloaded 15300 / 15304\n" 257 | ] 258 | } 259 | ], 260 | "source": [ 261 | "features.download_feature_vectors(filelist, feature_path, feature_gen)" 262 | ] 263 | }, 264 | { 265 | "cell_type": "markdown", 266 | "metadata": {}, 267 | "source": [ 268 | "## Split 114 features into 64 and 50" 269 | ] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "execution_count": 7, 274 | "metadata": {}, 275 | "outputs": [], 276 | "source": [ 277 | "feat_path64 = os.path.join(feature_root, 'akiwi_64')\n", 278 | "if not os.path.exists(feat_path64):\n", 279 | " os.makedirs(feat_path64)\n", 280 | " \n", 281 | "feat_path50 = os.path.join(feature_root, 'akiwi_50')\n", 282 | "if not os.path.exists(feat_path50):\n", 283 | " os.makedirs(feat_path50)" 284 | ] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "execution_count": 9, 289 | "metadata": {}, 290 | "outputs": [ 291 | { 292 | "name": "stdout", 293 | "output_type": "stream", 294 | "text": [ 295 | "0\n", 296 | "1000\n", 297 | "2000\n", 298 | "3000\n", 299 | "4000\n", 300 | "5000\n", 301 | "6000\n", 302 | "7000\n", 303 | "8000\n", 304 | "9000\n", 305 | "10000\n", 306 | "11000\n", 307 | "12000\n", 308 | "13000\n", 309 | "14000\n", 310 | "15000\n" 311 | ] 312 | } 313 | ], 314 | "source": [ 315 | "feats114 = glob.glob(os.path.join(feature_path,'*.npy'))\n", 316 | "for idx, file in enumerate(feats114):\n", 317 | " if idx % 1000 == 0:\n", 318 | " print(idx)\n", 319 | " \n", 320 | " fv = np.load(file)\n", 321 | " \n", 322 | " feat64 = fv[:64]\n", 323 | " np.save(os.path.join(feat_path64, os.path.basename(file)), feat64)\n", 324 | " \n", 325 | " feat50 = fv[64:]\n", 326 | " np.save(os.path.join(feat_path50, os.path.basename(file)), feat50)" 327 | ] 328 | }, 329 | { 330 | "cell_type": "markdown", 331 | "metadata": {}, 332 | "source": [ 333 | "# Original ResNet152 Features" 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": 10, 339 | "metadata": {}, 340 | "outputs": [], 341 | "source": [ 342 | "feature_path = os.path.join(feature_root, 'resnet')\n", 343 | "feature_gen = search.ResnetFeatureGenerator()" 344 | ] 345 | }, 346 | { 347 | "cell_type": "code", 348 | "execution_count": 11, 349 | "metadata": {}, 350 | "outputs": [ 351 | { 352 | "name": "stdout", 353 | "output_type": "stream", 354 | "text": [ 355 | "Downloaded 0 / 15304\n", 356 | "Downloaded 100 / 15304\n", 357 | "Downloaded 200 / 15304\n", 358 | "Downloaded 300 / 15304\n", 359 | "Downloaded 400 / 15304\n", 360 | "Downloaded 500 / 15304\n", 361 | "Downloaded 600 / 15304\n", 362 | "Downloaded 700 / 15304\n", 363 | "Downloaded 800 / 15304\n", 364 | "Downloaded 900 / 15304\n", 365 | "Downloaded 1000 / 15304\n", 366 | "Downloaded 1100 / 15304\n", 367 | "Downloaded 1200 / 15304\n", 368 | "Downloaded 1300 / 15304\n", 369 | "Downloaded 1400 / 15304\n", 370 | "Downloaded 1500 / 15304\n", 371 | "Downloaded 1600 / 15304\n", 372 | "Downloaded 1700 / 15304\n", 373 | "Downloaded 1800 / 15304\n", 374 | "Downloaded 1900 / 15304\n", 375 | "Downloaded 2000 / 15304\n", 376 | "Downloaded 2100 / 15304\n", 377 | "Downloaded 2200 / 15304\n", 378 | "Downloaded 2300 / 15304\n", 379 | "Downloaded 2400 / 15304\n", 380 | "Downloaded 
2500 / 15304\n", 481 | 
"Downloaded 12600 / 15304\n", 482 | "Downloaded 12700 / 15304\n", 483 | "Downloaded 12800 / 15304\n", 484 | "Downloaded 12900 / 15304\n", 485 | "Downloaded 13000 / 15304\n", 486 | "Downloaded 13100 / 15304\n", 487 | "Downloaded 13200 / 15304\n", 488 | "Downloaded 13300 / 15304\n", 489 | "Downloaded 13400 / 15304\n", 490 | "Downloaded 13500 / 15304\n", 491 | "Downloaded 13600 / 15304\n", 492 | "Downloaded 13700 / 15304\n", 493 | "Downloaded 13800 / 15304\n", 494 | "Downloaded 13900 / 15304\n", 495 | "Downloaded 14000 / 15304\n", 496 | "Downloaded 14100 / 15304\n", 497 | "Downloaded 14200 / 15304\n", 498 | "Downloaded 14300 / 15304\n", 499 | "Downloaded 14400 / 15304\n", 500 | "Downloaded 14500 / 15304\n", 501 | "Downloaded 14600 / 15304\n", 502 | "Downloaded 14700 / 15304\n", 503 | "Downloaded 14800 / 15304\n", 504 | "Downloaded 14900 / 15304\n", 505 | "Downloaded 15000 / 15304\n", 506 | "Downloaded 15100 / 15304\n", 507 | "Downloaded 15200 / 15304\n", 508 | "Downloaded 15300 / 15304\n" 509 | ] 510 | } 511 | ], 512 | "source": [ 513 | "features.download_feature_vectors(filelist, feature_path, feature_gen)" 514 | ] 515 | }, 516 | { 517 | "cell_type": "markdown", 518 | "metadata": {}, 519 | "source": [ 520 | "# Retrained ResNet152 Features" 521 | ] 522 | }, 523 | { 524 | "cell_type": "code", 525 | "execution_count": 12, 526 | "metadata": {}, 527 | "outputs": [], 528 | "source": [ 529 | "feature_path = os.path.join(feature_root, 'resnet_retrained')\n", 530 | "feature_gen = search.ResnetFeatureGenerator('../data/models/resnet152_retrained.pth')" 531 | ] 532 | }, 533 | { 534 | "cell_type": "code", 535 | "execution_count": 13, 536 | "metadata": {}, 537 | "outputs": [ 538 | { 539 | "name": "stdout", 540 | "output_type": "stream", 541 | "text": [ 542 | "Downloaded 0 / 15304\n", 543 | "Downloaded 100 / 15304\n", 544 | "Downloaded 200 / 15304\n", 545 | "Downloaded 300 / 15304\n", 546 | "Downloaded 400 / 15304\n", 547 | "Downloaded 500 / 15304\n", 548 | "Downloaded 600 / 15304\n", 549 | "Downloaded 700 / 15304\n", 550 | "Downloaded 800 / 15304\n", 551 | "Downloaded 900 / 15304\n", 552 | "Downloaded 1000 / 15304\n", 553 | "Downloaded 1100 / 15304\n", 554 | "Downloaded 1200 / 15304\n", 555 | "Downloaded 1300 / 15304\n", 556 | "Downloaded 1400 / 15304\n", 557 | "Downloaded 1500 / 15304\n", 558 | "Downloaded 1600 / 15304\n", 559 | "Downloaded 1700 / 15304\n", 560 | "Downloaded 1800 / 15304\n", 561 | "Downloaded 1900 / 15304\n", 562 | "Downloaded 2000 / 15304\n", 563 | "Downloaded 2100 / 15304\n", 564 | "Downloaded 2200 / 15304\n", 565 | "Downloaded 2300 / 15304\n", 566 | "Downloaded 2400 / 15304\n", 567 | "Downloaded 2500 / 15304\n", 568 | "Downloaded 2600 / 15304\n", 569 | "Downloaded 2700 / 15304\n", 570 | "Downloaded 2800 / 15304\n", 571 | "Downloaded 2900 / 15304\n", 572 | "Downloaded 3000 / 15304\n", 573 | "Downloaded 3100 / 15304\n", 574 | "Downloaded 3200 / 15304\n", 575 | "Downloaded 3300 / 15304\n", 576 | "Downloaded 3400 / 15304\n", 577 | "Downloaded 3500 / 15304\n", 578 | "Downloaded 3600 / 15304\n", 579 | "Downloaded 3700 / 15304\n", 580 | "Downloaded 3800 / 15304\n", 581 | "Downloaded 3900 / 15304\n", 582 | "Downloaded 4000 / 15304\n", 583 | "Downloaded 4100 / 15304\n", 584 | "Downloaded 4200 / 15304\n", 585 | "Downloaded 4300 / 15304\n", 586 | "Downloaded 4400 / 15304\n", 587 | "Downloaded 4500 / 15304\n", 588 | "Downloaded 4600 / 15304\n", 589 | "Downloaded 4700 / 15304\n", 590 | "Downloaded 4800 / 15304\n", 591 | "Downloaded 4900 / 15304\n", 592 | "Downloaded 5000 / 15304\n", 593 | 
"Downloaded 5100 / 15304\n", 594 | "Downloaded 5200 / 15304\n", 595 | "Downloaded 5300 / 15304\n", 596 | "Downloaded 5400 / 15304\n", 597 | "Downloaded 5500 / 15304\n", 598 | "Downloaded 5600 / 15304\n", 599 | "Downloaded 5700 / 15304\n", 600 | "Downloaded 5800 / 15304\n", 601 | "Downloaded 5900 / 15304\n", 602 | "Downloaded 6000 / 15304\n", 603 | "Downloaded 6100 / 15304\n", 604 | "Downloaded 6200 / 15304\n", 605 | "Downloaded 6300 / 15304\n", 606 | "Downloaded 6400 / 15304\n", 607 | "Downloaded 6500 / 15304\n", 608 | "Downloaded 6600 / 15304\n", 609 | "Downloaded 6700 / 15304\n", 610 | "Downloaded 6800 / 15304\n", 611 | "Downloaded 6900 / 15304\n", 612 | "Downloaded 7000 / 15304\n", 613 | "Downloaded 7100 / 15304\n", 614 | "Downloaded 7200 / 15304\n", 615 | "Downloaded 7300 / 15304\n", 616 | "Downloaded 7400 / 15304\n", 617 | "Downloaded 7500 / 15304\n", 618 | "Downloaded 7600 / 15304\n", 619 | "Downloaded 7700 / 15304\n", 620 | "Downloaded 7800 / 15304\n", 621 | "Downloaded 7900 / 15304\n", 622 | "Downloaded 8000 / 15304\n", 623 | "Downloaded 8100 / 15304\n", 624 | "Downloaded 8200 / 15304\n", 625 | "Downloaded 8300 / 15304\n", 626 | "Downloaded 8400 / 15304\n", 627 | "Downloaded 8500 / 15304\n", 628 | "Downloaded 8600 / 15304\n", 629 | "Downloaded 8700 / 15304\n", 630 | "Downloaded 8800 / 15304\n", 631 | "Downloaded 8900 / 15304\n", 632 | "Downloaded 9000 / 15304\n", 633 | "Downloaded 9100 / 15304\n", 634 | "Downloaded 9200 / 15304\n", 635 | "Downloaded 9300 / 15304\n", 636 | "Downloaded 9400 / 15304\n", 637 | "Downloaded 9500 / 15304\n", 638 | "Downloaded 9600 / 15304\n", 639 | "Downloaded 9700 / 15304\n", 640 | "Downloaded 9800 / 15304\n", 641 | "Downloaded 9900 / 15304\n", 642 | "Downloaded 10000 / 15304\n", 643 | "Downloaded 10100 / 15304\n", 644 | "Downloaded 10200 / 15304\n", 645 | "Downloaded 10300 / 15304\n", 646 | "Downloaded 10400 / 15304\n", 647 | "Downloaded 10500 / 15304\n", 648 | "Downloaded 10600 / 15304\n", 649 | "Downloaded 10700 / 15304\n", 650 | "Downloaded 10800 / 15304\n", 651 | "Downloaded 10900 / 15304\n", 652 | "Downloaded 11000 / 15304\n", 653 | "Downloaded 11100 / 15304\n", 654 | "Downloaded 11200 / 15304\n", 655 | "Downloaded 11300 / 15304\n", 656 | "Downloaded 11400 / 15304\n", 657 | "Downloaded 11500 / 15304\n", 658 | "Downloaded 11600 / 15304\n", 659 | "Downloaded 11700 / 15304\n", 660 | "Downloaded 11800 / 15304\n", 661 | "Downloaded 11900 / 15304\n", 662 | "Downloaded 12000 / 15304\n", 663 | "Downloaded 12100 / 15304\n", 664 | "Downloaded 12200 / 15304\n", 665 | "Downloaded 12300 / 15304\n", 666 | "Downloaded 12400 / 15304\n", 667 | "Downloaded 12500 / 15304\n", 668 | "Downloaded 12600 / 15304\n", 669 | "Downloaded 12700 / 15304\n", 670 | "Downloaded 12800 / 15304\n", 671 | "Downloaded 12900 / 15304\n", 672 | "Downloaded 13000 / 15304\n", 673 | "Downloaded 13100 / 15304\n", 674 | "Downloaded 13200 / 15304\n", 675 | "Downloaded 13300 / 15304\n", 676 | "Downloaded 13400 / 15304\n", 677 | "Downloaded 13500 / 15304\n", 678 | "Downloaded 13600 / 15304\n", 679 | "Downloaded 13700 / 15304\n", 680 | "Downloaded 13800 / 15304\n", 681 | "Downloaded 13900 / 15304\n", 682 | "Downloaded 14000 / 15304\n", 683 | "Downloaded 14100 / 15304\n", 684 | "Downloaded 14200 / 15304\n", 685 | "Downloaded 14300 / 15304\n", 686 | "Downloaded 14400 / 15304\n", 687 | "Downloaded 14500 / 15304\n", 688 | "Downloaded 14600 / 15304\n", 689 | "Downloaded 14700 / 15304\n", 690 | "Downloaded 14800 / 15304\n", 691 | "Downloaded 14900 / 15304\n", 692 | "Downloaded 15000 / 15304\n", 693 | 
"Downloaded 15100 / 15304\n", 694 | "Downloaded 15200 / 15304\n", 695 | "Downloaded 15300 / 15304\n" 696 | ] 697 | } 698 | ], 699 | "source": [ 700 | "features.download_feature_vectors(filelist, feature_path, feature_gen)" 701 | ] 702 | }, 703 | { 704 | "cell_type": "code", 705 | "execution_count": null, 706 | "metadata": { 707 | "collapsed": true 708 | }, 709 | "outputs": [], 710 | "source": [] 711 | } 712 | ], 713 | "metadata": { 714 | "kernelspec": { 715 | "display_name": "Python [default]", 716 | "language": "python", 717 | "name": "python3" 718 | }, 719 | "language_info": { 720 | "codemirror_mode": { 721 | "name": "ipython", 722 | "version": 3 723 | }, 724 | "file_extension": ".py", 725 | "mimetype": "text/x-python", 726 | "name": "python", 727 | "nbconvert_exporter": "python", 728 | "pygments_lexer": "ipython3", 729 | "version": "3.6.6" 730 | } 731 | }, 732 | "nbformat": 4, 733 | "nbformat_minor": 2 734 | } 735 | --------------------------------------------------------------------------------