├── main.py ├── webpage ├── __init__.py ├── pix2pix │ ├── __init__.py │ ├── data │ │ ├── __init__.py │ │ ├── dataloader.py │ │ ├── image_manipulation.py │ │ └── image_manipulation.ipynb │ ├── model.py │ ├── generator_new_data.csv │ ├── discriminator_new_data.csv │ ├── pix2pix.py │ ├── train.py │ └── test.py ├── pix2pix_transpose │ ├── __init__.py │ └── model.py ├── tests.py ├── admin.py ├── apps.py ├── urls.py ├── serializers.py ├── models.py ├── views.py ├── test.ipynb ├── convert_image.py └── microsoft_model.py ├── Old_image_reconstruction ├── __init__.py ├── asgi.py ├── wsgi.py ├── urls.py └── settings.py ├── .gitignore ├── manage.py ├── README.md ├── setup.sh ├── requirements.txt └── setup.bat /main.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /webpage/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /webpage/pix2pix/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /webpage/pix2pix/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Old_image_reconstruction/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /webpage/pix2pix_transpose/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /webpage/tests.py: -------------------------------------------------------------------------------- 1 | from django.test import TestCase 2 | 3 | # Create your tests here. 4 | -------------------------------------------------------------------------------- /webpage/admin.py: -------------------------------------------------------------------------------- 1 | from django.contrib import admin 2 | 3 | # Register your models here. 4 | -------------------------------------------------------------------------------- /webpage/apps.py: -------------------------------------------------------------------------------- 1 | from django.apps import AppConfig 2 | 3 | 4 | class WebpageConfig(AppConfig): 5 | default_auto_field = 'django.db.models.BigAutoField' 6 | name = 'webpage' 7 | -------------------------------------------------------------------------------- /webpage/urls.py: -------------------------------------------------------------------------------- 1 | from django.urls import path 2 | from . 
import views 3 | 4 | urlpatterns = [ 5 | path('', views.home, name='home'), 6 | 7 | path('api/old_image/',views.ImageList.as_view()), 8 | 9 | ] -------------------------------------------------------------------------------- /webpage/serializers.py: -------------------------------------------------------------------------------- 1 | from rest_framework import serializers 2 | from .models import Image 3 | 4 | class ImageSerializer(serializers.ModelSerializer): 5 | class Meta: 6 | 7 | model=Image 8 | fields=['id','image','n_image','method'] 9 | 10 | -------------------------------------------------------------------------------- /webpage/models.py: -------------------------------------------------------------------------------- 1 | from django.db import models 2 | from django.utils.translation import gettext_lazy as _ 3 | 4 | 5 | def upload_to(instance,filename): 6 | return '{filename}'.format(filename=filename) 7 | 8 | class Image(models.Model): 9 | image = models.ImageField(_("Image"), upload_to=upload_to, height_field=None, width_field=None, max_length=None) 10 | n_image = models.ImageField("New_Image",default='Default.jpg') 11 | method = models.IntegerField("Method",default=0) -------------------------------------------------------------------------------- /Old_image_reconstruction/asgi.py: -------------------------------------------------------------------------------- 1 | """ 2 | ASGI config for Old_image_reconstruction project. 3 | 4 | It exposes the ASGI callable as a module-level variable named ``application``. 5 | 6 | For more information on this file, see 7 | https://docs.djangoproject.com/en/4.0/howto/deployment/asgi/ 8 | """ 9 | 10 | import os 11 | 12 | from django.core.asgi import get_asgi_application 13 | 14 | os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'Old_image_reconstruction.settings') 15 | 16 | application = get_asgi_application() 17 | -------------------------------------------------------------------------------- /Old_image_reconstruction/wsgi.py: -------------------------------------------------------------------------------- 1 | """ 2 | WSGI config for Old_image_reconstruction project. 3 | 4 | It exposes the WSGI callable as a module-level variable named ``application``. 
5 | 6 | For more information on this file, see 7 | https://docs.djangoproject.com/en/4.0/howto/deployment/wsgi/ 8 | """ 9 | 10 | import os 11 | 12 | from django.core.wsgi import get_wsgi_application 13 | 14 | os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'Old_image_reconstruction.settings') 15 | 16 | application = get_wsgi_application() 17 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *env*/ 2 | *.org 3 | notebooks/ 4 | high_resoution_image/ 5 | .ipynb_checkpoints/ 6 | blend_images/ 7 | images/ 8 | old_images/ 9 | __pycache__/ 10 | .ipynb_checkpoints/ 11 | saved_models/ 12 | test_images/ 13 | img_align_celeba/ 14 | data.zip/ 15 | ISR_saved/ 16 | bringing_old_photos_back_to_life_2/ 17 | bringing_old_photos_back_to_life/ 18 | input_folder/ 19 | output_folder/ 20 | test.ipynb/ 21 | val/ 22 | train/ 23 | train_2/ 24 | val_2/ 25 | *.pth 26 | archive/ 27 | .vscode/ 28 | migrations/ 29 | db.sqlite3 -------------------------------------------------------------------------------- /manage.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Django's command-line utility for administrative tasks.""" 3 | import os 4 | import sys 5 | 6 | 7 | def main(): 8 | """Run administrative tasks.""" 9 | os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'Old_image_reconstruction.settings') 10 | try: 11 | from django.core.management import execute_from_command_line 12 | except ImportError as exc: 13 | raise ImportError( 14 | "Couldn't import Django. Are you sure it's installed and " 15 | "available on your PYTHONPATH environment variable? Did you " 16 | "forget to activate a virtual environment?" 17 | ) from exc 18 | execute_from_command_line(sys.argv) 19 | 20 | 21 | if __name__ == '__main__': 22 | main() 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Old-image-restoration-minor 2 | Restoration of old images using [Pix2pix](https://arxiv.org/pdf/1611.07004.pdf) architecture. 3 | 4 | ### Prerequisite 5 | * Python version: 3.9.5 6 | 7 | ### Installing Dependencies 8 | ``` 9 | python -m pip install -r requirements.txt 10 | ``` 11 | 12 | ### How to Run 13 | * Clone the repo 14 | ``` 15 | git clone git@github.com:Imsanskar/Old-image-restoration-minor.git 16 | ``` 17 | * Run the script: `./setup.sh` 18 | * Run Django Server: 19 | ``` 20 | python manage.py makemigrations 21 | python manage.py migrate 22 | python manage.py runserver 23 | ``` 24 | 25 | ### Contributors 26 | * [Sanskar Amgain](https://github.com/Imsanskar/) 27 | * [Sagar Timalsina](https://github.com/Sgr45/) 28 | * [Sandip Puri](https://github.com/Sandippuri/) 29 | * [Tilak Chad](https://github.com/TilakChad/) 30 | -------------------------------------------------------------------------------- /Old_image_reconstruction/urls.py: -------------------------------------------------------------------------------- 1 | """Old_image_reconstruction URL Configuration 2 | 3 | The `urlpatterns` list routes URLs to views. For more information please see: 4 | https://docs.djangoproject.com/en/4.0/topics/http/urls/ 5 | Examples: 6 | Function views 7 | 1. Add an import: from my_app import views 8 | 2. Add a URL to urlpatterns: path('', views.home, name='home') 9 | Class-based views 10 | 1. 
Add an import: from other_app.views import Home
11 | 2. Add a URL to urlpatterns: path('', Home.as_view(), name='home')
12 | Including another URLconf
13 | 1. Import the include() function: from django.urls import include, path
14 | 2. Add a URL to urlpatterns: path('blog/', include('blog.urls'))
15 | """
16 | from django.contrib import admin
17 | from django.urls import path, include
18 | 
19 | from django.conf import settings
20 | from django.conf.urls.static import static
21 | 
22 | urlpatterns = [
23 |     path('admin/', admin.site.urls),
24 |     path('', include('webpage.urls')),
25 | ]
26 | 
27 | urlpatterns += static(settings.MEDIA_URL, document_root=settings.MEDIA_ROOT)
28 | 
--------------------------------------------------------------------------------
/webpage/pix2pix/data/dataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | from torch.utils.data import Dataset, DataLoader
3 | from torchvision import transforms, utils
4 | from PIL import Image
5 | import glob
6 | 
7 | 
8 | class ImageDataset(Dataset):
9 |     def __init__(self, old_image_path:str, reconstructed_image_path:str, transform = None):
10 |         # set the transforms
11 |         self.transform = transform
12 |         self.old_image_path = old_image_path
13 |         self.reconstructed_image_path = reconstructed_image_path
14 | 
15 |         # get all the file names
16 |         self.old_image_files = sorted(glob.glob(old_image_path + "/*.*"))
17 |         self.reconstructed_image_files = sorted(glob.glob(reconstructed_image_path + "/*.*"))
18 | 
19 |     def __getitem__(self, index):
20 |         reconstructed_image = Image.open(self.reconstructed_image_files[index % len(self.reconstructed_image_files)])
21 |         old_image = Image.open(self.old_image_files[index % len(self.old_image_files)])
22 | 
23 |         transform = self.transform if self.transform else transforms.ToTensor()  # fall back to a plain tensor conversion so img_A/img_B are always defined
24 |         img_A = transform(old_image)
25 |         img_B = transform(reconstructed_image)
26 | 
27 |         return {
28 |             "A": img_A,
29 |             "B": img_B
30 |         }
31 | 
32 |     def __len__(self):
33 |         return len(self.old_image_files)
34 | 
35 | 
--------------------------------------------------------------------------------
/setup.sh:
--------------------------------------------------------------------------------
1 | cd webpage/
2 | if [ ! -d "./bringing_old_photos_back_to_life" ]
3 | then
4 |     # clone the repo
5 |     git clone https://github.com/microsoft/Bringing-Old-Photos-Back-to-Life.git
6 | 
7 | 
8 |     # rename folder
9 |     mv Bringing-Old-Photos-Back-to-Life/ bringing_old_photos_back_to_life/
10 | 
11 |     # remove the original clone directory if the rename left it behind
12 |     rm -rf Bringing-Old-Photos-Back-to-Life/
13 | 
14 |     # copy model.py
15 |     cp microsoft_model.py bringing_old_photos_back_to_life/model.py
16 | 
17 |     cd bringing_old_photos_back_to_life/
18 | 
19 |     # Clone the Synchronized-BatchNorm-PyTorch repository
20 |     cd Face_Enhancement/models/networks/
21 |     git clone https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
22 |     cp -rf Synchronized-BatchNorm-PyTorch/sync_batchnorm .
23 |     cd ../../../
24 | 
25 |     cd Global/detection_models
26 |     git clone https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
27 |     cp -rf Synchronized-BatchNorm-PyTorch/sync_batchnorm .
28 | cd ../../ 29 | 30 | echo "Downloading the landmark detection pretrained model" 31 | cd Face_Detection/ 32 | wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2 33 | bzip2 -d shape_predictor_68_face_landmarks.dat.bz2 34 | cd ../ 35 | 36 | # Downloading the pretrained model from Azure Blob Storage 37 | cd Face_Enhancement/ 38 | wget https://github.com/microsoft/Bringing-Old-Photos-Back-to-Life/releases/download/v1.0/face_checkpoints.zip 39 | unzip face_checkpoints.zip 40 | 41 | # clean the zip file 42 | rm face_checkpoints.zip 43 | cd ../ 44 | cd Global/ 45 | wget https://github.com/microsoft/Bringing-Old-Photos-Back-to-Life/releases/download/v1.0/global_checkpoints.zip 46 | unzip global_checkpoints.zip 47 | 48 | # clean zip file 49 | rm global_checkpoints.zip 50 | cd ../ 51 | else 52 | echo "Old repo found" 53 | fi -------------------------------------------------------------------------------- /webpage/views.py: -------------------------------------------------------------------------------- 1 | from genericpath import exists 2 | from time import time 3 | from django.http import HttpResponse 4 | from django.shortcuts import render 5 | from .bringing_old_photos_back_to_life import model 6 | import os 7 | import torch 8 | from rest_framework.decorators import api_view 9 | from django.shortcuts import render 10 | from .models import Image 11 | from rest_framework import generics 12 | from .serializers import ImageSerializer 13 | from rest_framework.parsers import MultiPartParser, FormParser 14 | from .convert_image import convert 15 | from rest_framework.response import Response 16 | from rest_framework.views import APIView 17 | 18 | 19 | # Create your views here. 20 | # TODO: Remove this 21 | def home(request): 22 | torch.cuda.empty_cache() 23 | current_working_directory = os.getcwd().split('/')[-1] 24 | path = "webpage/pix2pix_super_res/outputs/input.jpg" 25 | curr_time = time() 26 | if not os.path.exists('webpage/input_folder'): 27 | os.makedirs("webpage/input_folder/") 28 | # shutil.copy(path, "webpage/input_folder/input.jpg") 29 | # im = Image.open("./pix2pix/data/train/old_images/001000.jpg") 30 | model.modify("webpage/input_folder", True) 31 | print(f"Time taken: {time() - curr_time}") 32 | return HttpResponse("Hello there") 33 | 34 | class ImageList(APIView): 35 | parser_classes = (MultiPartParser, FormParser) 36 | 37 | def get(self, request, *args, **kwargs): 38 | images = Image.objects.all() 39 | serializer = ImageSerializer(images, many=True) 40 | return Response(serializer.data) 41 | 42 | def post(self, request,*args, **kwargs): 43 | # delete all the images from the database 44 | for image in Image.objects.all(): 45 | image.delete() 46 | images_serializer = ImageSerializer(data=request.data) 47 | if images_serializer.is_valid(): 48 | images_serializer.save() 49 | convert(images_serializer['method'].value) 50 | return Response(images_serializer.data,) 51 | else: 52 | print(f"error, {images_serializer.errors}") 53 | return Response(images_serializer.errors) 54 | 55 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | aio-pika==6.7.1 2 | aiofiles==0.8.0 3 | aiohttp==3.8.1 4 | aiormq==3.3.1 5 | aiosignal==1.2.0 6 | apply-defaults==0.1.6 7 | asn1crypto==1.4.0 8 | async-timeout==4.0.2 9 | asynctest==0.13.0 10 | attrs==21.4.0 11 | backcall==0.2.0 12 | certifi==2021.10.8 13 | cffi==1.15.0 14 | chardet==3.0.4 15 | 
charset-normalizer==2.0.10 16 | click==7.1.2 17 | coincurve==13.0.0 18 | cycler==0.10.0 19 | cytoolz==0.11.2 20 | debugpy==1.5.1 21 | decorator==5.1.1 22 | earlgrey==0.2.2 23 | entrypoints==0.3 24 | eth-hash==0.3.2 25 | eth-keyfile==0.5.1 26 | eth-keys==0.4.0 27 | eth-typing==3.0.0 28 | eth-utils==1.10.0 29 | frozenlist==1.3.0 30 | gunicorn==20.0.4 31 | h11==0.8.1 32 | h2==3.2.0 33 | hpack==3.0.0 34 | hstspreload==2021.12.1 35 | httptools==0.3.0 36 | httpx==0.9.3 37 | hyperframe==5.2.0 38 | iconcommons==1.1.3 39 | iconrpcserver==1.6.2 40 | iconsdk==1.3.6 41 | iconservice==1.8.10 42 | idna==3.3 43 | importlib-metadata==4.10.1 44 | ipykernel==6.7.0 45 | ipython==7.31.0 46 | iso3166==1.0.1 47 | jedi==0.18.1 48 | jsonrpcclient==3.3.6 49 | jsonrpcserver==4.1.3 50 | jsonschema==3.2.0 51 | jupyter-client==7.1.0 52 | jupyter-core==4.9.1 53 | keras==2.8.0 54 | kiwisolver==1.3.1 55 | matplotlib==3.4.2 56 | matplotlib-inline==0.1.3 57 | msgpack==1.0.3 58 | multidict==5.0.0 59 | multipledispatch==0.6.0 60 | nest-asyncio==1.5.4 61 | numpy==1.21.1 62 | pamqp==2.3.0 63 | parso==0.8.3 64 | pexpect==4.8.0 65 | pickle5==0.0.12 66 | pickleshare==0.7.5 67 | pika==1.1.0 68 | Pillow==8.3.1 69 | plyvel==1.2.0 70 | prompt-toolkit==3.0.24 71 | ptyprocess==0.7.0 72 | pycparser==2.21 73 | pycryptodome==3.13.0 74 | pygame==2.0.1 75 | Pygments==2.11.2 76 | pyparsing==2.4.7 77 | pyrsistent==0.18.1 78 | python-dateutil==2.8.2 79 | pyzmq==22.3.0 80 | requests==2.27.1 81 | requests-mock==1.8.0 82 | rfc3986==1.5.0 83 | sanic==19.12.5 84 | Sanic-Cors==0.10.0.post3 85 | Sanic-Plugins-Framework==0.9.5 86 | setproctitle==1.1.10 87 | six==1.16.0 88 | sniffio==1.2.0 89 | tbears==1.8.0 90 | toolz==0.11.2 91 | tornado==6.1 92 | tqdm==4.62.3 93 | traitlets==5.1.1 94 | typing-extensions==3.7.4.3 95 | ujson==5.1.0 96 | urllib3==1.26.8 97 | uvloop==0.14.0 98 | wcwidth==0.2.5 99 | websockets==8.1 100 | yarl==1.7.2 101 | zipp==3.7.0 102 | -------------------------------------------------------------------------------- /setup.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | cd webpage\ 3 | 4 | if not exist ".\bringing_old_photos_back_to_life" goto SETUP 5 | if exist ".\Bringing-Old-Photos-Back-to-Life" goto RENAME_FOLDER 6 | 7 | echo Root directory already exist, skipping cloning 8 | goto SkipRootClone 9 | 10 | :SETUP 11 | echo Starting Setup 12 | git clone https://github.com/microsoft/Bringing-Old-Photos-Back-to-Life.git 13 | 14 | :RENAME_FOLDER 15 | ren Bringing-Old-Photos-Back-to-Life bringing_old_photos_back_to_life 16 | 17 | 18 | 19 | 20 | :SkipRootClone 21 | cd bringing_old_photos_back_to_life\ 22 | if exist ".\Face_Enhancement\models\networks\Synchronized-BatchNorm-PyTorch\" goto SkipFaceEnhancementNetwork 23 | echo Downloading Synchronized BatchNorm pytorch in Face_Enhancement 24 | cd Face_Enhancement\models\networks\ 25 | git clone https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 26 | xcopy /s/e Synchronized-BatchNorm-PyTorch\sync_batchnorm . 27 | cd ..\..\..\ 28 | 29 | :SkipFaceEnhancementNetwork 30 | if exist ".\Global\detecion_models\Synchronized-BatchNorm-PyTorch\" goto SkipDetectionNetwork 31 | echo Downloading Synchronized BatchNorm pytorch in Face_Detection 32 | cd Global\detection_models 33 | git clone https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 34 | xcopy /s/e Synchronized-BatchNorm-PyTorch\sync_batchnorm . 
35 | cd ..\..\ 36 | 37 | :SkipDetectionNetwork 38 | if exist ".\Face_Detection\shape_predictor_68_face_landmarks.dat.bz2" goto SkipShapePredictor 39 | echo Downloading the landmark detection pretrained model 40 | cd Face_Detection/ 41 | curl http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2 --output shape_predictor_68_face_landmarks.dat.bz2 42 | where 7z >nul 2>nul 43 | IF %ERRORLEVEL% NEQ 0 goto NotFound7z 44 | 7z x shape_predictor_68_face_landmarks.dat.bz2 45 | cd ../ 46 | 47 | :SkipShapePredictor 48 | if exist ".\Face_Enhancement\face_checkpoints" goto SkipFaceCheckPoint 49 | echo Downloading the pretrained model from Azure Blob Storage 50 | cd Face_Enhancement/ 51 | curl https://github.com/microsoft/Bringing-Old-Photos-Back-to-Life/releases/download/v1.0/face_checkpoints.zip --output face_checkpoints.zip 52 | echo Extract face_checkpoints.zip from Face_Enhancement/face_checkpoints directory 53 | @REM tar -xvf face_checkpoints.zip -C face_checkpoints 54 | @REM del face_checkpoints.zip 55 | cd ../ 56 | 57 | :SkipFaceCheckPoint 58 | if exist ".\Global\global_checkpoints" goto END 59 | cd Global/ 60 | curl https://github.com/microsoft/Bringing-Old-Photos-Back-to-Life/releases/download/v1.0/global_checkpoints.zip --output global_checkpoints.zip 61 | @REM tar -xvf global_checkpoints.zip -C global_checkpoints 62 | 63 | @REM del global_checkpoints.zip 64 | echo Extract global_checkpoints.zip from Global/global_checkpoints directory 65 | goto END 66 | 67 | :NotFound7z 68 | echo 7z required to for extracting files, install 7z from here https://www.7-zip.org/ 69 | goto END 70 | 71 | :END -------------------------------------------------------------------------------- /Old_image_reconstruction/settings.py: -------------------------------------------------------------------------------- 1 | """ 2 | Django settings for Old_image_reconstruction project. 3 | 4 | Generated by 'django-admin startproject' using Django 4.0. 5 | 6 | For more information on this file, see 7 | https://docs.djangoproject.com/en/4.0/topics/settings/ 8 | 9 | For the full list of settings and their values, see 10 | https://docs.djangoproject.com/en/4.0/ref/settings/ 11 | """ 12 | import os 13 | from pathlib import Path 14 | 15 | # Build paths inside the project like this: BASE_DIR / 'subdir'. 16 | BASE_DIR = Path(__file__).resolve().parent.parent 17 | 18 | 19 | # Quick-start development settings - unsuitable for production 20 | # See https://docs.djangoproject.com/en/4.0/howto/deployment/checklist/ 21 | 22 | # SECURITY WARNING: keep the secret key used in production secret! 23 | SECRET_KEY = 'django-insecure-uaai_s3*n0mfip15al-3x41xaz@=zlg=6rj@q@i6==u9=ey!7c' 24 | 25 | # SECURITY WARNING: don't run with debug turned on in production! 
26 | DEBUG = True 27 | 28 | ALLOWED_HOSTS = [] 29 | 30 | 31 | # Application definition 32 | 33 | INSTALLED_APPS = [ 34 | 'django.contrib.admin', 35 | 'django.contrib.auth', 36 | 'django.contrib.contenttypes', 37 | 'django.contrib.sessions', 38 | 'django.contrib.messages', 39 | 'django.contrib.staticfiles', 40 | 'webpage', 41 | 'rest_framework', 42 | "corsheaders", 43 | ] 44 | 45 | MIDDLEWARE = [ 46 | 'django.middleware.security.SecurityMiddleware', 47 | 'django.contrib.sessions.middleware.SessionMiddleware', 48 | "corsheaders.middleware.CorsMiddleware", 49 | 'django.middleware.common.CommonMiddleware', 50 | 'django.middleware.csrf.CsrfViewMiddleware', 51 | 'django.contrib.auth.middleware.AuthenticationMiddleware', 52 | 'django.contrib.messages.middleware.MessageMiddleware', 53 | 'django.middleware.clickjacking.XFrameOptionsMiddleware', 54 | ] 55 | 56 | ROOT_URLCONF = 'Old_image_reconstruction.urls' 57 | 58 | TEMPLATES = [ 59 | { 60 | 'BACKEND': 'django.template.backends.django.DjangoTemplates', 61 | 'DIRS': [], 62 | 'APP_DIRS': True, 63 | 'OPTIONS': { 64 | 'context_processors': [ 65 | 'django.template.context_processors.debug', 66 | 'django.template.context_processors.request', 67 | 'django.contrib.auth.context_processors.auth', 68 | 'django.contrib.messages.context_processors.messages', 69 | ], 70 | }, 71 | }, 72 | ] 73 | 74 | WSGI_APPLICATION = 'Old_image_reconstruction.wsgi.application' 75 | 76 | 77 | # Database 78 | # https://docs.djangoproject.com/en/4.0/ref/settings/#databases 79 | 80 | DATABASES = { 81 | 'default': { 82 | 'ENGINE': 'django.db.backends.sqlite3', 83 | 'NAME': BASE_DIR / 'db.sqlite3', 84 | } 85 | } 86 | 87 | 88 | # Password validation 89 | # https://docs.djangoproject.com/en/4.0/ref/settings/#auth-password-validators 90 | 91 | AUTH_PASSWORD_VALIDATORS = [ 92 | { 93 | 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator', 94 | }, 95 | { 96 | 'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator', 97 | }, 98 | { 99 | 'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator', 100 | }, 101 | { 102 | 'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator', 103 | }, 104 | ] 105 | 106 | 107 | # Internationalization 108 | # https://docs.djangoproject.com/en/4.0/topics/i18n/ 109 | 110 | LANGUAGE_CODE = 'en-us' 111 | 112 | TIME_ZONE = 'UTC' 113 | 114 | USE_I18N = True 115 | 116 | USE_TZ = True 117 | 118 | 119 | # Static files (CSS, JavaScript, Images) 120 | # https://docs.djangoproject.com/en/4.0/howto/static-files/ 121 | 122 | STATIC_URL = 'static/' 123 | CORS_ALLOWED_ORIGINS = [ 124 | "http://localhost:3000", 125 | "http://127.0.0.1:3000", 126 | ] 127 | 128 | # Default primary key field type 129 | # https://docs.djangoproject.com/en/4.0/ref/settings/#default-auto-field 130 | 131 | DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField' 132 | MEDIA_ROOT = os.path.join(BASE_DIR,'webpage','input_folder') 133 | MEDIA_URL='/input_folder/' 134 | -------------------------------------------------------------------------------- /webpage/test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "bb750d3f", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stderr", 11 | "output_type": "stream", 12 | "text": [ 13 | "2022-01-24 16:27:31.740846: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: 
libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /media/imsanskar/My files/Projects/Minor/env/lib/python3.9/site-packages/cv2/../../lib64:\n", 14 | "2022-01-24 16:27:31.740881: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.\n" 15 | ] 16 | } 17 | ], 18 | "source": [ 19 | "import pix2pix.pix2pix as p2p\n", 20 | "from PIL import Image\n", 21 | "import torchvision.transforms as transforms\n", 22 | "import torch\n", 23 | "from pix2pix.data.image_manipulation import *\n", 24 | "from pix2pix.data.dataloader import *\n", 25 | "import SRGAN.SRGAN_pretrained as SR_GAN\n", 26 | "from ISR.models import RDN" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 3, 32 | "id": "b46e6db0", 33 | "metadata": {}, 34 | "outputs": [ 35 | { 36 | "data": { 37 | "text/plain": [ 38 | "" 39 | ] 40 | }, 41 | "execution_count": 3, 42 | "metadata": {}, 43 | "output_type": "execute_result" 44 | } 45 | ], 46 | "source": [ 47 | "generator = p2p.GeneratorUNet()\n", 48 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 49 | "generator.load_state_dict(torch.load(\"pix2pix/saved_models/generator.pth\", map_location = device))" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "id": "c18ede2d", 56 | "metadata": {}, 57 | "outputs": [ 58 | { 59 | "name": "stdout", 60 | "output_type": "stream", 61 | "text": [ 62 | "Downloading data from https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/ISR/rdn-C6-D20-G64-G064-x2/PSNR-driven/rdn-C6-D20-G64-G064-x2_PSNR_epoch086.hdf5\n", 63 | "30588928/66071288 [============>.................] - ETA: 22:01" 64 | ] 65 | } 66 | ], 67 | "source": [ 68 | "generator_srgan = SR_GAN.GeneratorSRGAN().to(device)\n", 69 | "generator_srgan.load_state_dict(torch.load(\"./SRGAN/saved_models/srresnet.pth\", map_location=device))\n", 70 | "generator_srgan.eval()\n", 71 | "rdn = RDN(weights='psnr-large')" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 14, 77 | "id": "277e957e", 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "image = Image.open(\"./pix2pix/data/train/old_images/002000.jpg\")\n", 82 | "image_width = 256\n", 83 | "image_height = 256\n", 84 | "transform = transforms.Compose([\n", 85 | " transforms.ToTensor(), # transform to tensor\n", 86 | " transforms.Resize((image_width, image_height)) # Resize the image to constant size\n", 87 | "])\n", 88 | "\n", 89 | "im = transform(image)\n", 90 | "output = np_to_pil(generator(im.unsqueeze(0))[0].detach().cpu().numpy())\n", 91 | "transform = transforms.Compose([\n", 92 | " transforms.ToTensor(),\n", 93 | "# transforms.Resize((256, 256), Image.BICUBIC),\n", 94 | "])\n", 95 | "input_image = transform(output).to(device)\n", 96 | "output_image = np_to_pil(generator_srgan(input_image.unsqueeze(0))[0].detach().cpu().numpy())\n", 97 | "output_image = rdn.predict(np.array(image))" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": null, 103 | "id": "0976f7e8", 104 | "metadata": {}, 105 | "outputs": [], 106 | "source": [ 107 | "Image.fromarray(output_image)" 108 | ] 109 | } 110 | ], 111 | "metadata": { 112 | "kernelspec": { 113 | "display_name": "minor", 114 | "language": "python", 115 | "name": "minor" 116 | }, 117 | "language_info": { 118 | "codemirror_mode": { 119 | "name": "ipython", 120 | "version": 3 121 | }, 122 | "file_extension": ".py", 123 | "mimetype": "text/x-python", 124 | "name": "python", 125 
| "nbconvert_exporter": "python", 126 | "pygments_lexer": "ipython3", 127 | "version": "3.9.5" 128 | } 129 | }, 130 | "nbformat": 4, 131 | "nbformat_minor": 5 132 | } 133 | -------------------------------------------------------------------------------- /webpage/convert_image.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | from cv2 import transform 3 | import torch 4 | from .models import Image 5 | from .serializers import ImageSerializer 6 | import copy,io 7 | from PIL import Image as Photo 8 | from django.core.files.images import ImageFile 9 | import os 10 | from .bringing_old_photos_back_to_life import model 11 | import glob 12 | from .pix2pix_transpose import model as model_transpose 13 | from .pix2pix import model as pix2pix_model 14 | from torchvision import transforms 15 | from .pix2pix.data.image_manipulation import np_to_pil 16 | 17 | def get_file_name(file_path_list:str): 18 | file_path_list = list(reversed(list(file_path_list))) 19 | 20 | file_name = list(reversed(file_path_list)) 21 | for i, char in enumerate(file_path_list): 22 | if char == '/' or char == '\\': 23 | file_name = list(reversed(file_path_list[:i])) 24 | 25 | return "".join(file_name).split(".")[0] 26 | 27 | return "".join(file_name).split(".")[0] 28 | 29 | 30 | def restore_image(image_file:str, method): 31 | file_name = get_file_name(image_file) 32 | if os.path.exists("webpage/input_folder/output.png"): 33 | os.remove("webpage/input_folder/output.png") 34 | files = glob.glob('webpage/input_folder/*', recursive=True) 35 | for f in files: 36 | if get_file_name(f) != file_name: 37 | print("LOG: Old files removed") 38 | os.remove(f) 39 | if method == 3: 40 | torch.cuda.empty_cache() 41 | # cleanup intermediate directory 42 | # os.remove(image_file) 43 | 44 | files = glob.glob('webpage/output_folder/final_output/*', recursive=True) 45 | for f in files: 46 | os.remove(f) 47 | 48 | shutil.rmtree("webpage/output_folder/final_output/", ignore_errors=True) 49 | shutil.rmtree("webpage/output_folder/stage_1_restore_output/input_image/", ignore_errors=True) 50 | shutil.rmtree("webpage/output_folder/stage_1_restore_output/origin/", ignore_errors=True) 51 | shutil.rmtree("webpage/output_folder/stage_1_restore_output/restored_image/", ignore_errors=True) 52 | shutil.rmtree("webpage/output_folder/stage_1_restore_output/masks/input/", ignore_errors=True) 53 | shutil.rmtree("webpage/output_folder/stage_1_restore_output/masks/mask/", ignore_errors=True) 54 | 55 | model.modify("webpage/input_folder", with_scratch=True, image_filename = image_file) 56 | y = Photo.open(f"webpage/output_folder/stage_1_restore_output/restored_image/{file_name}.png") 57 | return y 58 | elif method == 1: 59 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 60 | generator_transpose = model_transpose.get_generator("webpage/pix2pix_transpose/saved_models/generator_mine.pth").to(device) 61 | 62 | image = Photo.open(image_file) 63 | tensor_transform = transforms.Compose([ 64 | transforms.ToTensor(), 65 | transforms.Resize((512, 512)) 66 | ]) 67 | 68 | output_image = generator_transpose(tensor_transform(image).unsqueeze(0).to(device)) 69 | output_image = np_to_pil( 70 | output_image.detach().cpu().numpy()[0] 71 | ) 72 | 73 | 74 | return output_image 75 | 76 | 77 | else: 78 | #TODO: Implement this branch 79 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 80 | generator = pix2pix_model.get_generator("webpage/pix2pix/saved_models_new_data/generator_263.pth").to(device) 81 | from 
torchsummary import summary
82 | 
83 |         print(generator)  # log the generator architecture for debugging
84 |         image = Photo.open(image_file)
85 |         tensor_transform = transforms.Compose([
86 |             transforms.ToTensor(),
87 |             transforms.Resize((512, 512))
88 |         ])
89 | 
90 |         output_image = generator(tensor_transform(image).unsqueeze(0).to(device))
91 | 
92 | 
93 |         output_image = np_to_pil(
94 |             output_image.detach().cpu().numpy()[0]
95 |         )
96 | 
97 | 
98 |         return output_image
99 | 
100 | 
101 | 
102 | # restores the image from input_folder and saves it in the database
103 | # method describes which method to use for restoration
104 | def convert(method):
105 |     image1 = Image.objects.filter().order_by('-pk')[0] #takes the latest data
106 |     input = image1.image
107 |     method = image1.method
108 |     im = Photo.open(input)
109 |     filename = "output.png"
110 |     o_im = restore_image(image1.image.path, method=method)
111 |     f = io.BytesIO()
112 |     o_im.save(f,'PNG')
113 |     outputimage = ImageFile(f,name = filename)
114 |     image1.n_image = outputimage
115 |     image1.save()
116 | 
117 | 
118 |     files = os.listdir('./webpage/input_folder/')
119 |     for f in files:
120 |         if f != "output.png" and f != 'Default.jpg':
121 |             print(f"Removed stale file {f}")
122 |             os.remove(f"webpage/input_folder/{f}")
123 | 
--------------------------------------------------------------------------------
/webpage/pix2pix/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision.transforms as transforms
4 | import numpy as np
5 | import matplotlib.pyplot as plt
6 | import os
7 | 
8 | 
9 | def weights_init_normal(m):
10 |     classname = m.__class__.__name__
11 | 
12 |     if classname.find("Conv") != -1 and classname.find("DoubleConv") == -1:  # skip the DoubleConv wrapper, which has no .weight of its own
13 |         torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
14 |     elif classname.find("BatchNorm2d") != -1:
15 |         torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
16 |         torch.nn.init.constant_(m.bias.data, 0.0)
17 | 
18 | 
19 | class UNetDown(nn.Module):
20 |     def __init__(self, in_size, out_size, normalize = True, dropout = 0.0):
21 |         super(UNetDown, self).__init__()
22 |         layers = [
23 |             nn.Conv2d(in_size, out_size, 4, 2, 1, bias = False)
24 |         ]
25 |         if normalize:
26 |             layers.append(nn.InstanceNorm2d(out_size))
27 | 
28 |         layers.append(nn.LeakyReLU(0.2))
29 | 
30 |         if dropout:
31 |             layers.append(nn.Dropout(dropout))
32 | 
33 |         self.model = nn.Sequential(*layers)
34 | 
35 |     def forward(self, x):
36 |         return self.model(x)
37 | 
38 | 
39 | class DoubleConv(nn.Module):
40 |     """(convolution => [BN] => ReLU) * 2"""
41 | 
42 |     def __init__(self, in_channels, out_channels, mid_channels=None):
43 |         super().__init__()
44 |         if not mid_channels:
45 |             mid_channels = out_channels
46 |         self.double_conv = nn.Sequential(
47 |             nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
48 |             nn.BatchNorm2d(mid_channels),
49 |             nn.ReLU(inplace=True),
50 |             nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
51 |             nn.BatchNorm2d(out_channels),
52 |             nn.ReLU(inplace=True)
53 |         )
54 | 
55 |     def forward(self, x):
56 |         return self.double_conv(x)
57 | 
58 | class UNetUp(nn.Module):
59 |     def __init__(self, in_size, out_size, dropout = 0.0):
60 |         super(UNetUp, self).__init__()
61 | 
62 |         layers = [
63 |             nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
64 |             DoubleConv(in_size, out_size, in_size // 2),
65 |             nn.InstanceNorm2d(out_size),
66 |             nn.ReLU(inplace=True),
67 |         ]
68 |         if dropout:
69 |             layers.append(nn.Dropout(dropout))
70 | 
71 |         self.model = nn.Sequential(*layers)
72 | 73 | 74 | def forward(self, x, skip_input): 75 | x = self.model(x) 76 | x = torch.cat((x, skip_input), 1) 77 | 78 | return x 79 | 80 | 81 | 82 | class GeneratorUNet(nn.Module): 83 | def __init__(self, in_channels=3, out_channels=3): 84 | super(GeneratorUNet, self).__init__() 85 | 86 | self.down1 = UNetDown(in_channels, 64, normalize=False) 87 | self.down2 = UNetDown(64, 128) 88 | self.down3 = UNetDown(128, 256) 89 | self.down4 = UNetDown(256, 512, dropout=0.5) 90 | self.down5 = UNetDown(512, 512, dropout=0.5) 91 | self.down6 = UNetDown(512, 512, dropout=0.5) 92 | self.down7 = UNetDown(512, 512, dropout=0.5) 93 | self.down8 = UNetDown(512, 512, normalize=False, dropout=0.5) 94 | 95 | self.up1 = UNetUp(512, 512, dropout=0.5) 96 | self.up2 = UNetUp(1024, 512, dropout=0.5) 97 | self.up3 = UNetUp(1024, 512, dropout=0.5) 98 | self.up4 = UNetUp(1024, 512, dropout=0.5) 99 | self.up5 = UNetUp(1024, 256) 100 | self.up6 = UNetUp(512, 128) 101 | self.up7 = UNetUp(256, 64) 102 | 103 | self.final = nn.Sequential( 104 | nn.Upsample(scale_factor=2), 105 | nn.ZeroPad2d((1, 0, 1, 0)), 106 | nn.Conv2d(128, out_channels, 4, padding=1), 107 | nn.Tanh(), 108 | ) 109 | 110 | def forward(self, x): 111 | # U-Net generator with skip connections from encoder to decoder 112 | d1 = self.down1(x) 113 | d2 = self.down2(d1) 114 | d3 = self.down3(d2) 115 | d4 = self.down4(d3) 116 | d5 = self.down5(d4) 117 | d6 = self.down6(d5) 118 | d7 = self.down7(d6) 119 | d8 = self.down8(d7) 120 | 121 | # unet connections 122 | u1 = self.up1(d8, d7) 123 | u2 = self.up2(u1, d6) 124 | u3 = self.up3(u2, d5) 125 | u4 = self.up4(u3, d4) 126 | u5 = self.up5(u4, d3) 127 | u6 = self.up6(u5, d2) 128 | u7 = self.up7(u6, d1) 129 | 130 | return self.final(u7) 131 | 132 | 133 | class Discriminator(nn.Module): 134 | def __init__(self, in_channels=3): 135 | super(Discriminator, self).__init__() 136 | 137 | def discriminator_block(in_filters, out_filters, normalization=True): 138 | """Returns downsampling layers of each discriminator block""" 139 | layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)] 140 | if normalization: 141 | layers.append(nn.InstanceNorm2d(out_filters)) 142 | layers.append(nn.LeakyReLU(0.2, inplace=True)) 143 | return layers 144 | 145 | self.model = nn.Sequential( 146 | *discriminator_block(in_channels * 2, 64, normalization=False), 147 | *discriminator_block(64, 128), 148 | *discriminator_block(128, 256), 149 | *discriminator_block(256, 512), 150 | nn.ZeroPad2d((1, 0, 1, 0)), 151 | nn.Conv2d(512, 1, 4, padding=1, bias=False), 152 | nn.Sigmoid() 153 | ) 154 | 155 | def forward(self, img_A, img_B): 156 | # Concatenate image and condition image by channels to produce input 157 | img_input = torch.cat((img_A, img_B), 1) 158 | return self.model(img_input) 159 | 160 | def get_generator(path:str) -> GeneratorUNet : 161 | generator = GeneratorUNet() 162 | generator_state = None 163 | 164 | try: 165 | generator_state = torch.load(path) 166 | generator.load_state_dict(generator_state) 167 | except FileNotFoundError: 168 | print("Model path not found") 169 | 170 | 171 | return generator 172 | -------------------------------------------------------------------------------- /webpage/pix2pix_transpose/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.transforms as transforms 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | import os 7 | import torchvision.datasets as dset 8 | from torch.autograd import 
Variable 9 | from tqdm import tqdm 10 | from PIL import Image 11 | from torch.utils.data import DataLoader 12 | 13 | def weights_init_normal(m): 14 | classname = m.__class__.__name__ 15 | 16 | if classname.find("Conv") != -1 and classname.find("DoubleConv") == 1: 17 | torch.nn.init.normal_(m.weight.data, 0.0, 0.02) 18 | elif classname.find("BatchNorm2d") != -1: 19 | torch.nn.init.normal_(m.weight.data, 1.0, 0.02) 20 | torch.nn.init.constant_(m.bias.data, 0.0) 21 | 22 | 23 | class UNetDown(nn.Module): 24 | def __init__(self, in_size, out_size, normalize = True, dropout = 0.0): 25 | super(UNetDown, self).__init__() 26 | layers = [ 27 | nn.Conv2d(in_size, out_size, 4, 2, 1, bias = False) 28 | ] 29 | if normalize: 30 | layers.append(nn.InstanceNorm2d(out_size)) 31 | 32 | layers.append(nn.LeakyReLU(0.2)) 33 | 34 | if dropout: 35 | layers.append(nn.Dropout(dropout)) 36 | 37 | self.model = nn.Sequential(*layers) 38 | 39 | def forward(self, x): 40 | return self.model(x) 41 | 42 | 43 | class DoubleConv(nn.Module): 44 | """(convolution => [BN] => ReLU) * 2""" 45 | 46 | def __init__(self, in_channels, out_channels, mid_channels=None): 47 | super().__init__() 48 | if not mid_channels: 49 | mid_channels = out_channels 50 | self.double_conv = nn.Sequential( 51 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False), 52 | nn.BatchNorm2d(mid_channels), 53 | nn.ReLU(inplace=True), 54 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False), 55 | nn.BatchNorm2d(out_channels), 56 | nn.ReLU(inplace=True) 57 | ) 58 | 59 | def forward(self, x): 60 | return self.double_conv(x) 61 | 62 | class UNetUp(nn.Module): 63 | def __init__(self, in_size, out_size, dropout = 0.0): 64 | super(UNetUp, self).__init__() 65 | 66 | layers = [ 67 | nn.ConvTranspose2d(in_size, out_size, 4, 2, 1, bias=False), 68 | nn.InstanceNorm2d(out_size), 69 | nn.ReLU(inplace=True), 70 | ] 71 | if dropout: 72 | layers.append(nn.Dropout(dropout)) 73 | 74 | self.model = nn.Sequential(*layers) 75 | 76 | 77 | def forward(self, x, skip_input): 78 | x = self.model(x) 79 | x = torch.cat((x, skip_input), 1) 80 | 81 | return x 82 | 83 | 84 | 85 | class GeneratorUNet(nn.Module): 86 | def __init__(self, in_channels=3, out_channels=3): 87 | super(GeneratorUNet, self).__init__() 88 | 89 | self.down1 = UNetDown(in_channels, 64, normalize=False) 90 | self.down2 = UNetDown(64, 128) 91 | self.down3 = UNetDown(128, 256) 92 | self.down4 = UNetDown(256, 512, dropout=0.5) 93 | self.down5 = UNetDown(512, 512, dropout=0.5) 94 | self.down6 = UNetDown(512, 512, dropout=0.5) 95 | self.down7 = UNetDown(512, 512, dropout=0.5) 96 | self.down8 = UNetDown(512, 512, normalize=False, dropout=0.5) 97 | 98 | self.up1 = UNetUp(512, 512, dropout=0.5) 99 | self.up2 = UNetUp(1024, 512, dropout=0.5) 100 | self.up3 = UNetUp(1024, 512, dropout=0.5) 101 | self.up4 = UNetUp(1024, 512, dropout=0.5) 102 | self.up5 = UNetUp(1024, 256) 103 | self.up6 = UNetUp(512, 128) 104 | self.up7 = UNetUp(256, 64) 105 | 106 | self.final = nn.Sequential( 107 | nn.Upsample(scale_factor=2), 108 | nn.ZeroPad2d((1, 0, 1, 0)), 109 | nn.Conv2d(128, out_channels, 4, padding=1), 110 | nn.Tanh(), 111 | ) 112 | 113 | def forward(self, x): 114 | # U-Net generator with skip connections from encoder to decoder 115 | d1 = self.down1(x) 116 | d2 = self.down2(d1) 117 | d3 = self.down3(d2) 118 | d4 = self.down4(d3) 119 | d5 = self.down5(d4) 120 | d6 = self.down6(d5) 121 | d7 = self.down7(d6) 122 | d8 = self.down8(d7) 123 | 124 | # unet connections 125 | u1 = self.up1(d8, d7) 126 | 
u2 = self.up2(u1, d6) 127 | u3 = self.up3(u2, d5) 128 | u4 = self.up4(u3, d4) 129 | u5 = self.up5(u4, d3) 130 | u6 = self.up6(u5, d2) 131 | u7 = self.up7(u6, d1) 132 | 133 | return self.final(u7) 134 | 135 | 136 | class Discriminator(nn.Module): 137 | def __init__(self, in_channels=3): 138 | super(Discriminator, self).__init__() 139 | 140 | def discriminator_block(in_filters, out_filters, normalization=True): 141 | """Returns downsampling layers of each discriminator block""" 142 | layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)] 143 | if normalization: 144 | layers.append(nn.InstanceNorm2d(out_filters)) 145 | layers.append(nn.LeakyReLU(0.2, inplace=True)) 146 | return layers 147 | 148 | self.model = nn.Sequential( 149 | *discriminator_block(in_channels * 2, 64, normalization=False), 150 | *discriminator_block(64, 128), 151 | *discriminator_block(128, 256), 152 | *discriminator_block(256, 512), 153 | nn.ZeroPad2d((1, 0, 1, 0)), 154 | nn.Conv2d(512, 1, 4, padding=1, bias=False), 155 | nn.Sigmoid() 156 | ) 157 | 158 | def forward(self, img_A, img_B): 159 | # Concatenate image and condition image by channels to produce input 160 | img_input = torch.cat((img_A, img_B), 1) 161 | return self.model(img_input) 162 | 163 | 164 | def get_generator(path:str) -> GeneratorUNet : 165 | generator = GeneratorUNet() 166 | generator_state = None 167 | 168 | try: 169 | generator_state = torch.load(path) 170 | generator.load_state_dict(generator_state) 171 | except FileNotFoundError: 172 | print("Model path not found") 173 | 174 | 175 | return generator 176 | -------------------------------------------------------------------------------- /webpage/microsoft_model.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from collections import OrderedDict 4 | from torch.autograd import Variable 5 | from PIL import Image 6 | import torch 7 | import torchvision.utils as vutils 8 | import torchvision.transforms as transforms 9 | import numpy as np 10 | import cv2 11 | from subprocess import call 12 | import sys 13 | import shutil 14 | 15 | 16 | def modify(input_folder, with_scratch:bool, image_filename=None, cv2_frame=None): 17 | 18 | def run_cmd(command): 19 | try: 20 | call(command, shell=True) 21 | except KeyboardInterrupt: 22 | print("Process interrupted") 23 | sys.exit(1) 24 | 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument("--input_folder", type=str, 27 | default= image_filename, help="Test images") 28 | parser.add_argument( 29 | "--output_folder", 30 | type=str, 31 | default="./output", 32 | help="Restored images, please use the absolute path", 33 | ) 34 | parser.add_argument("--GPU", type=str, default="-1", help="0,1,2") 35 | parser.add_argument( 36 | "--checkpoint_name", type=str, default="Setting_9_epoch_100", help="choose which checkpoint" 37 | ) 38 | parser.add_argument("--with_scratch",default="--with_scratch" ,action="store_true") 39 | # opts = parser.parse_args() 40 | 41 | gpu1 = str(-1) 42 | output_folder = "webpage/output_folder" 43 | checkpoint_name = "Setting_9_epoch_100" 44 | 45 | # resolve relative paths before changing directory 46 | input_folder = os.path.abspath(input_folder) 47 | output_folder = os.path.abspath(output_folder) 48 | if not os.path.exists(output_folder): 49 | os.makedirs(output_folder) 50 | 51 | main_environment = os.getcwd() 52 | 53 | # Stage 1: Overall Quality Improve 54 | print("Running Stage 1: Overall restoration") 55 | 
os.chdir("webpage/bringing_old_photos_back_to_life/Global") 56 | stage_1_input_dir = input_folder 57 | stage_1_output_dir = os.path.join( 58 | output_folder, "stage_1_restore_output") 59 | if not os.path.exists(stage_1_output_dir): 60 | os.makedirs(stage_1_output_dir) 61 | 62 | if not with_scratch: 63 | stage_1_command = ( 64 | "python test.py --test_mode Full --Quality_restore --test_input " 65 | + stage_1_input_dir 66 | + " --outputs_dir " 67 | + stage_1_output_dir 68 | + " --gpu_ids " 69 | + gpu1 70 | ) 71 | run_cmd(stage_1_command) 72 | else: 73 | 74 | mask_dir = os.path.join(stage_1_output_dir, "masks") 75 | new_input = os.path.join(mask_dir, "input") 76 | new_mask = os.path.join(mask_dir, "mask") 77 | stage_1_command_1 = ( 78 | "python detection.py --test_path " 79 | + stage_1_input_dir 80 | + " --output_dir " 81 | + mask_dir 82 | + " --input_size full_size" 83 | + " --GPU " 84 | + str(gpu1) 85 | ) 86 | stage_1_command_2 = ( 87 | "python test.py --Scratch_and_Quality_restore --test_input " 88 | + new_input 89 | + " --test_mask " 90 | + new_mask 91 | + " --outputs_dir " 92 | + stage_1_output_dir 93 | + " --gpu_ids " 94 | + gpu1 95 | ) 96 | run_cmd(stage_1_command_1) 97 | run_cmd(stage_1_command_2) 98 | 99 | # Solve the case when there is no face in the old photo 100 | stage_1_results = os.path.join(stage_1_output_dir, "restored_image") 101 | stage_4_output_dir = os.path.join(output_folder, "final_output") 102 | if not os.path.exists(stage_4_output_dir): 103 | os.makedirs(stage_4_output_dir) 104 | for x in os.listdir(stage_1_results): 105 | img_dir = os.path.join(stage_1_results, x) 106 | shutil.copy(img_dir, stage_4_output_dir) 107 | 108 | print("Finish Stage 1 ...") 109 | print("\n") 110 | 111 | # Stage 2: Face Detection 112 | 113 | print("Running Stage 2: Face Detection") 114 | # set the current working directory to the root 115 | os.chdir("../../../") 116 | os.chdir("webpage/bringing_old_photos_back_to_life/Face_Detection") 117 | stage_2_input_dir = os.path.join(stage_1_output_dir, "restored_image") 118 | stage_2_output_dir = os.path.join( 119 | output_folder, "stage_2_detection_output") 120 | if not os.path.exists(stage_2_output_dir): 121 | os.makedirs(stage_2_output_dir) 122 | stage_2_command = ( 123 | "python detect_all_dlib.py --url " + stage_2_input_dir + 124 | " --save_url " + stage_2_output_dir 125 | ) 126 | run_cmd(stage_2_command) 127 | print("Finish Stage 2 ...") 128 | print("\n") 129 | 130 | # Stage 3: Face Restore 131 | print("Running Stage 3: Face Enhancement") 132 | os.chdir(".././Face_Enhancement") 133 | stage_3_input_mask = "./" 134 | stage_3_input_face = stage_2_output_dir 135 | stage_3_output_dir = os.path.join( 136 | output_folder, "stage_3_face_output") 137 | if not os.path.exists(stage_3_output_dir): 138 | os.makedirs(stage_3_output_dir) 139 | stage_3_command = ( 140 | "python test_face.py --old_face_folder " 141 | + stage_3_input_face 142 | + " --old_face_label_folder " 143 | + stage_3_input_mask 144 | + " --tensorboard_log --name " 145 | + checkpoint_name 146 | + " --gpu_ids " 147 | + gpu1 148 | + " --load_size 256 --label_nc 18 --no_instance --preprocess_mode resize --batchSize 4 --results_dir " 149 | + stage_3_output_dir 150 | + " --no_parsing_map" 151 | ) 152 | run_cmd(stage_3_command) 153 | print("Finish Stage 3 ...") 154 | print("\n") 155 | 156 | # Stage 4: Warp back 157 | print("Running Stage 4: Blending") 158 | os.chdir(".././Face_Detection") 159 | stage_4_input_image_dir = os.path.join( 160 | stage_1_output_dir, "restored_image") 161 | 
stage_4_input_face_dir = os.path.join(stage_3_output_dir, "each_img") 162 | stage_4_output_dir = os.path.join(output_folder, "final_output") 163 | if not os.path.exists(stage_4_output_dir): 164 | os.makedirs(stage_4_output_dir) 165 | stage_4_command = ( 166 | "python align_warp_back_multiple_dlib.py --origin_url " 167 | + stage_4_input_image_dir 168 | + " --replace_url " 169 | + stage_4_input_face_dir 170 | + " --save_url " 171 | + stage_4_output_dir 172 | ) 173 | run_cmd(stage_4_command) 174 | print("Finish Stage 4 ...") 175 | print("\n") 176 | 177 | print("All the processing is done. Please check the results.") 178 | os.chdir("../../..") 179 | -------------------------------------------------------------------------------- /webpage/pix2pix/data/image_manipulation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # In[1]: 5 | 6 | 7 | import numpy as np 8 | import random 9 | import cv2 10 | from PIL import Image 11 | import os 12 | from io import BytesIO 13 | from tqdm.notebook import tqdm, trange 14 | import time 15 | 16 | def blend_image(original_image: Image.Image, blend_image_source: Image.Image, intensity = 0.4): 17 | """ 18 | Blends the original image with blend_image with intensity 19 | """ 20 | 21 | # converts the blend_image to the format of original image 22 | # because both the image needs to be in the same format 23 | # same goes for size 24 | blend_image_source = blend_image_source.convert(original_image.mode) 25 | blend_image_source = blend_image_source.resize(original_image.size) 26 | 27 | new_image = Image.new(original_image.mode, original_image.size) 28 | new_image = Image.blend(original_image, blend_image_source, intensity) 29 | 30 | return new_image 31 | 32 | 33 | # In[3]: 34 | 35 | 36 | def pil_to_np(img_pil): 37 | """ 38 | Converts image from pil Image to numpy array 39 | """ 40 | ar: np.ndarray = np.array(img_pil) 41 | if len(ar.shape) == 3: 42 | """ 43 | Tensor transpose, since in this case tensor is 3D the order of transpose can be different 44 | In 2D matrix the transpose is only i,j-> j,i but in more than 2D matrix different permutation can be 45 | applied 46 | """ 47 | ar = ar.transpose(2, 0, 1) 48 | else: 49 | ar = ar[None, ...] 50 | 51 | return ar.astype(np.float32) / 255. 52 | 53 | 54 | # In[4]: 55 | 56 | 57 | def np_to_pil(img_np): 58 | """ 59 | Converts np.ndarray to Image.Image object 60 | """ 61 | ar = np.clip(img_np * 255, 0, 255).astype(np.uint8) 62 | 63 | if img_np.shape[0] == 1: 64 | ar = ar[0] 65 | else: 66 | ar = ar.transpose(1, 2, 0) 67 | 68 | return Image.fromarray(ar) 69 | 70 | 71 | # In[5]: 72 | 73 | 74 | def synthesize_salt_pepper(image: Image.Image, amount, salt_vs_pepper): 75 | """ 76 | Salt and pepper noise is also known as an impulse noise, this noise can be caused by sharp and sudden 77 | disturbances in the image signal. gives the appearance of scattered white or black(or both) pixel over 78 | the image 79 | """ 80 | img_pil=pil_to_np(image) 81 | 82 | out = img_pil.copy() 83 | p = amount 84 | q = salt_vs_pepper 85 | flipped = np.random.choice([True, False], size=img_pil.shape, 86 | p=[p, 1 - p]) 87 | salted = np.random.choice([True, False], size=img_pil.shape, 88 | p=[q, 1 - q]) 89 | peppered = ~salted 90 | out[flipped & salted] = 1 91 | out[flipped & peppered] = 0. 
92 | noisy = np.clip(out, 0, 1).astype(np.float32) 93 | return np_to_pil(noisy) 94 | 95 | 96 | # In[6]: 97 | 98 | 99 | def synthesize_speckle(image,std_l,std_r): 100 | 101 | ## Give PIL, return the noisy PIL 102 | 103 | img_pil=pil_to_np(image) 104 | 105 | mean=0 106 | std=random.uniform(std_l/255.,std_r/255.) 107 | gauss=np.random.normal(loc=mean,scale=std,size=img_pil.shape) 108 | noisy=img_pil+gauss*img_pil 109 | noisy=np.clip(noisy,0,1).astype(np.float32) 110 | 111 | return np_to_pil(noisy) 112 | 113 | 114 | # In[7]: 115 | 116 | 117 | def blur_image_v2(img): 118 | x=np.array(img) 119 | kernel_size_candidate=[(3,3),(5,5),(7,7)] 120 | kernel_size=random.sample(kernel_size_candidate,1)[0] 121 | std=random.uniform(1.,5.) 122 | 123 | #print("The gaussian kernel size: (%d,%d) std: %.2f"%(kernel_size[0],kernel_size[1],std)) 124 | blur=cv2.GaussianBlur(x,kernel_size,std) 125 | 126 | return Image.fromarray(blur.astype(np.uint8)) 127 | 128 | 129 | # In[8]: 130 | 131 | 132 | def synthesize_low_resolution(image: Image.Image): 133 | """ 134 | Creates a low resolution image from high resolution image 135 | """ 136 | width, height = image.size 137 | 138 | new_width = np.random.randint(int(width / 2), width - int(width / 5)) 139 | new_height = np.random.randint(int(height / 2), height - int(height / 5)) 140 | 141 | image = image.resize((new_width, new_height), Image.BICUBIC) 142 | 143 | if random.uniform(0, 1) < 0.5: 144 | image = image.resize((width, height), Image.NEAREST) 145 | else: 146 | image = image.resize((width, height), Image.BILINEAR) 147 | 148 | return image 149 | 150 | 151 | # In[9]: 152 | 153 | 154 | def online_add_degradation_v2(img): 155 | task_id = np.random.permutation(4) 156 | 157 | for x in task_id: 158 | if x == 0 and random.uniform(0,1)<0.7: 159 | img = blur_image_v2(img) 160 | if x == 1 and random.uniform(0,1)<0.7: 161 | flag = random.choice([1, 2, 3]) 162 | if flag == 1: 163 | pass 164 | # img = synthesize_gaussian(img, 5, 50) 165 | if flag == 2: 166 | img = synthesize_speckle(img, 5, 50) 167 | if flag == 3: 168 | img = synthesize_salt_pepper(img, random.uniform(0, 0.01), random.uniform(0.3, 0.8)) 169 | if x == 2 and random.uniform(0,1)<0.7: 170 | img=synthesize_low_resolution(img) 171 | 172 | if x==3 and random.uniform(0,1)<0.7: 173 | img=convertToJpeg(img,random.randint(40,100)) 174 | 175 | return img 176 | 177 | 178 | # In[10]: 179 | 180 | 181 | def zero_mask(row, col): 182 | x = np.zeros((row, col, 3)) 183 | mask=Image.fromarray(x).convert("RGB") 184 | return mask 185 | 186 | def irregular_hole_synthesize(img, mask): 187 | """ 188 | Create holes using scrach paper textures 189 | Args: 190 | img: Original Image 191 | mask: scratch paper texture 192 | """ 193 | img_np = np.array(img).astype('uint8') 194 | mask = mask.resize(img.size) 195 | mask = mask.convert(img.mode) 196 | mask_np = np.array(mask).astype('uint8') 197 | mask_np = mask_np / 255 198 | img_new=img_np * (1 - mask_np) + mask_np * 255 199 | 200 | 201 | hole_img=Image.fromarray(img_new.astype('uint8')).convert("RGB") 202 | 203 | 204 | return hole_img,mask.convert("L") 205 | 206 | 207 | # In[11]: 208 | 209 | 210 | def convertToJpeg(im,quality): 211 | with BytesIO() as f: 212 | im.save(f, format='JPEG',quality=quality) 213 | f.seek(0) 214 | return Image.open(f).convert('RGB') 215 | 216 | 217 | # In[12]: 218 | 219 | -------------------------------------------------------------------------------- /webpage/pix2pix/generator_new_data.csv: -------------------------------------------------------------------------------- 1 
| Generator Error, 6.799430091118178, 0 2 | Generator Error, 6.568183241688254, 1 3 | Generator Error, 6.424796258542021, 2 4 | Generator Error, 6.1330551684129375, 3 5 | Generator Error, 5.986139562193432, 4 6 | Generator Error, 5.920468651296521, 5 7 | Generator Error, 5.80306859977345, 6 8 | Generator Error, 5.750150122116727, 7 9 | Generator Error, 5.67178101412697, 8 10 | Generator Error, 5.587790112984951, 9 11 | Generator Error, 5.553457332654598, 10 12 | Generator Error, 5.458462151284453, 11 13 | Generator Error, 5.373598050708553, 12 14 | Generator Error, 5.376967481787214, 13 15 | Generator Error, 5.303945089927645, 14 16 | Generator Error, 5.295398744793446, 15 17 | Generator Error, 5.239869391510242, 16 18 | Generator Error, 5.178962830808226, 17 19 | Generator Error, 5.216716329407783, 18 20 | Generator Error, 5.122097533012071, 19 21 | Generator Error, 5.062556581352147, 20 22 | Generator Error, 5.044870180775457, 21 23 | Generator Error, 5.004301670386311, 22 24 | Generator Error, 4.971274247187625, 23 25 | Generator Error, 4.923108525149269, 24 26 | Generator Error, 4.839897621720462, 25 27 | Generator Error, 4.869583673803525, 26 28 | Generator Error, 4.894023786479529, 27 29 | Generator Error, 4.77692180532013, 28 30 | Generator Error, 4.787149373116149, 29 31 | Generator Error, 4.659107276242042, 30 32 | Generator Error, 4.626132471026576, 31 33 | Generator Error, 4.601359848740437, 32 34 | Generator Error, 4.58438978630327, 33 35 | Generator Error, 4.510311632555247, 34 36 | Generator Error, 4.442683311469654, 35 37 | Generator Error, 4.458588139639154, 36 38 | Generator Error, 4.5036555161494265, 37 39 | Generator Error, 4.447549068429171, 38 40 | Generator Error, 4.415379429044832, 39 41 | Generator Error, 4.539547005533719, 40 42 | Generator Error, 4.422808963536309, 41 43 | Generator Error, 4.405371332349886, 42 44 | Generator Error, 4.421134810030687, 43 45 | Generator Error, 4.327863224558957, 44 46 | Generator Error, 4.365889167604338, 45 47 | Generator Error, 4.4326573805210705, 46 48 | Generator Error, 4.247371026318336, 47 49 | Generator Error, 4.287095420714114, 48 50 | Generator Error, 4.235578075561233, 49 51 | Generator Error, 4.256990065592777, 50 52 | Generator Error, 4.200384576964288, 51 53 | Generator Error, 4.204652032924695, 52 54 | Generator Error, 4.202172133405853, 53 55 | Generator Error, 4.194275317536561, 54 56 | Generator Error, 4.118708263331946, 55 57 | Generator Error, 4.157135047840075, 56 58 | Generator Error, 4.156949504699997, 57 59 | Generator Error, 4.120126344405199, 58 60 | Generator Error, 4.0969506969016765, 59 61 | Error Generator, 3.889726855743553, 60 62 | Error Generator, 3.833918946614198, 61 63 | Error Generator, 3.815539152136346, 62 64 | Error Generator, 3.8060858181867556, 63 65 | Error Generator, 3.788152533119889, 64 66 | Error Generator, 3.768929883767078, 65 67 | Error Generator, 3.7782745429124875, 66 68 | Error Generator, 3.755999118795892, 67 69 | Error Generator, 3.744878339541467, 68 70 | Error Generator, 3.7482743918613237, 69 71 | Error Generator, 3.7121499115821877, 70 72 | Error Generator, 3.731086032650482, 71 73 | Error Generator, 3.702971134140593, 72 74 | Error Generator, 3.7050233520037756, 73 75 | Error Generator, 3.7084619931135134, 74 76 | Error Generator, 3.694056890587106, 75 77 | Error Generator, 3.684533758751024, 76 78 | Error Generator, 3.6811046261358036, 77 79 | Error Generator, 3.6707167975710466, 78 80 | Error Generator, 3.6680691253517477, 79 81 | Generator Error, 3.9067787190353918, 80 
82 | Generator Error, 3.987543861222358, 81 83 | Generator Error, 3.922849831925599, 82 84 | Generator Error, 4.02415139140285, 83 85 | Generator Error, 3.9275327880119644, 84 86 | Generator Error, 3.912382847455971, 85 87 | Generator Error, 4.154399633407593, 86 88 | Generator Error, 3.9098625681699457, 87 89 | Generator Error, 3.920629666332056, 88 90 | Generator Error, 3.9652383236830677, 89 91 | Generator Error, 3.874138241031777, 90 92 | Generator Error, 3.8871624406299663, 91 93 | Generator Error, 3.862695861678613, 92 94 | Generator Error, 3.8828822315419127, 93 95 | Generator Error, 3.9688979636580317, 94 96 | Generator Error, 3.808995756359608, 95 97 | Generator Error, 3.777002418902437, 96 98 | Generator Error, 3.8370921430478986, 97 99 | Generator Error, 3.8569612530247794, 98 100 | Generator Error, 3.8054505146954902, 99 101 | Generator Error, 3.852648493908204, 100 102 | Generator Error, 3.789691459090084, 101 103 | Generator Error, 3.807946267690042, 102 104 | Generator Error, 3.7914500018942947, 103 105 | Generator Error, 3.8624822558558938, 104 106 | Generator Error, 3.768465640427042, 105 107 | Error Generator, 3.591369274103246, 106 108 | Error Generator, 3.547187601785524, 107 109 | Error Generator, 3.541561314279999, 108 110 | Error Generator, 3.529248766424532, 109 111 | Error Generator, 3.5320867158790334, 110 112 | Error Generator, 3.529770949440545, 111 113 | Error Generator, 3.5005355898237904, 112 114 | Error Generator, 3.4898030385021914, 113 115 | Error Generator, 3.495631077843255, 114 116 | Error Generator, 3.489176648487977, 115 117 | Error Generator, 3.487669780921032, 116 118 | Error Generator, 3.477573949578814, 117 119 | Error Generator, 3.4818776094518, 118 120 | Error Generator, 3.45974994383717, 119 121 | Error Generator, 3.4527430240576864, 120 122 | Error Generator, 3.469422352822471, 121 123 | Error Generator, 3.454811902972759, 122 124 | Error Generator, 3.451493571719852, 123 125 | Error Generator, 3.451590960624659, 124 126 | Error Generator, 3.4500773755295016, 125 127 | Generator Error, 3.6459970809660938, 126 128 | Generator Error, 3.7012045546629584, 127 129 | Generator Error, 3.690802244179149, 128 130 | Generator Error, 3.7069921067459046, 129 131 | Generator Error, 3.8283904994848563, 130 132 | Generator Error, 3.728886988679719, 131 133 | Generator Error, 3.692484015294354, 132 134 | Generator Error, 3.6783829179553478, 133 135 | Error Generator, 3.491557195853283, 134 136 | Error Generator, 3.4499836103611083, 135 137 | Error Generator, 3.4274557857151846, 136 138 | Error Generator, 3.4314365601652606, 137 139 | Error Generator, 3.4227817691332922, 138 140 | Error Generator, 3.4323016557648285, 139 141 | Error Generator, 3.4082265668570715, 140 142 | Error Generator, 3.4213536881722546, 141 143 | Error Generator, 3.4236193584604853, 142 144 | Error Generator, 3.4035427570343018, 143 145 | Error Generator, 3.4024931435336434, 144 146 | Error Generator, 3.401032151769123, 145 147 | Error Generator, 3.396521564908502, 146 148 | Error Generator, 3.3895647864771115, 147 149 | Error Generator, 3.381002408068327, 148 150 | Error Generator, 3.3897696535734205, 149 151 | Error Generator, 3.379003467153034, 150 152 | Error Generator, 3.356737688254406, 151 153 | Error Generator, 3.3722420428036513, 152 154 | Error Generator, 3.361182751813771, 153 155 | Generator Error, 3.5345533329271093, 154 156 | Generator Error, 3.597825402995933, 155 157 | Generator Error, 3.587405228342847, 156 158 | Generator Error, 3.59052830775881, 157 159 | Generator 
Error, 3.596191847732312, 158 160 | Generator Error, 3.6021588582956294, 159 161 | Generator Error, 3.7916152341284226, 160 162 | Generator Error, 3.581287878094517, 161 163 | Generator Error, 3.6144641587942727, 162 164 | Error Generator, 3.425344358688282, 163 165 | Error Generator, 3.3955704239307423, 164 166 | Error Generator, 3.3808310088388165, 165 167 | Error Generator, 3.368860540796795, 166 168 | Error Generator, 3.3425106199996732, 167 169 | Error Generator, 3.359740217715078, 168 170 | Error Generator, 3.339936003300816, 169 171 | Error Generator, 3.3453277897495792, 170 172 | Error Generator, 3.3509799846540695, 171 173 | Error Generator, 3.3337088715973624, 172 174 | Error Generator, 3.3440170389781065, 173 175 | Error Generator, 3.3248392021486546, 174 176 | Error Generator, 3.306080936820586, 175 177 | Error Generator, 3.3082236486588608, 176 178 | Error Generator, 3.296269559182262, 177 179 | Error Generator, 3.3140091997752257, 178 180 | Error Generator, 3.3111735655798165, 179 181 | Error Generator, 3.293981317095282, 180 182 | Error Generator, 3.3052177463097596, 181 183 | Error Generator, 3.3249623154011947, 182 184 | Error Generator, 3.318954863254493, 183 185 | Error Generator, 3.3063021429342117, 184 186 | Error Generator, 3.3010321217125624, 185 187 | Error Generator, 3.286198899644246, 186 188 | Error Generator, 3.2808820206972094, 187 189 | Error Generator, 3.3030062919544383, 188 190 | Error Generator, 3.294364359706499, 189 191 | Error Generator, 3.2654629543127602, 190 192 | Error Generator, 3.2599405073961676, 191 193 | Error Generator, 3.2622655701163588, 192 194 | Error Generator, 3.2767279337573525, 193 195 | Error Generator, 3.2568603493519968, 194 196 | Error Generator, 3.242287435278987, 195 197 | Error Generator, 3.239465666133047, 196 198 | Error Generator, 3.228658444044606, 197 199 | Error Generator, 3.246905445263086, 198 200 | Error Generator, 3.240637485554676, 199 201 | Error Generator, 3.239376709161215, 200 202 | Error Generator, 3.2307998786698904, 201 203 | Error Generator, 3.242731672249093, 202 204 | Error Generator, 3.246264882435072, 203 205 | Error Generator, 3.2274485708072485, 204 206 | Error Generator, 3.2266689594218274, 200 207 | Error Generator, 3.230928016814175, 201 208 | Error Generator, 3.2473561921656526, 202 209 | Error Generator, 3.264895219676542, 203 210 | Error Generator, 3.2445632085105442, 204 211 | Error Generator, 3.2293283134106767, 200 212 | Error Generator, 3.2263952533140876, 201 213 | Error Generator, 3.217344825630946, 202 214 | Error Generator, 3.2261077868227925, 203 215 | Error Generator, 3.229645482751707, 204 216 | Error Generator, 3.221687719521933, 200 217 | Error Generator, 3.220171371043123, 201 218 | Error Generator, 3.2122639867643645, 202 219 | Generator Error, 3.451730440777971, 203 220 | Generator Error, 3.516746714087947, 204 221 | -------------------------------------------------------------------------------- /webpage/pix2pix/discriminator_new_data.csv: -------------------------------------------------------------------------------- 1 | Discriminator Error, 0.1636957909669302, 0 2 | Discriminator Error, 0.0247244372395772, 1 3 | Discriminator Error, 0.029819785001034917, 2 4 | Discriminator Error, 0.027094710068915808, 3 5 | Discriminator Error, 0.020468627905248525, 4 6 | Discriminator Error, 0.008369045325999917, 5 7 | Discriminator Error, 0.020362799186315946, 6 8 | Discriminator Error, 0.023936884426982687, 7 9 | Discriminator Error, 0.01531627328209503, 8 10 | Discriminator Error, 
0.022489982332009197, 9 11 | Discriminator Error, 0.0161271960701525, 10 12 | Discriminator Error, 0.028012227567415192, 11 13 | Discriminator Error, 0.019439485368916428, 12 14 | Discriminator Error, 0.01258276984511134, 13 15 | Discriminator Error, 0.018101139030820286, 14 16 | Discriminator Error, 0.011892171175869047, 15 17 | Discriminator Error, 0.019467066261297888, 16 18 | Discriminator Error, 0.008437896227890014, 17 19 | Discriminator Error, 0.0058646649511756065, 18 20 | Discriminator Error, 0.02707087063300474, 19 21 | Discriminator Error, 0.014680346719341202, 20 22 | Discriminator Error, 0.0018944458123365912, 21 23 | Discriminator Error, 0.022078127758387182, 22 24 | Discriminator Error, 0.013436762078608437, 23 25 | Discriminator Error, 0.02946412957956881, 24 26 | Discriminator Error, 0.004054017344773224, 25 27 | Discriminator Error, 0.012269061793997973, 26 28 | Discriminator Error, 0.015664149043786595, 27 29 | Discriminator Error, 0.006292782985565962, 28 30 | Discriminator Error, 1.1889080283714227e-05, 29 31 | Discriminator Error, 0.020176293542782196, 30 32 | Discriminator Error, 0.01975508892711059, 31 33 | Discriminator Error, 0.012961544006710636, 32 34 | Discriminator Error, 0.009520235178021976, 33 35 | Discriminator Error, 0.015974175151477726, 34 36 | Discriminator Error, 0.020026790864568607, 35 37 | Discriminator Error, 0.015145051998486864, 36 38 | Discriminator Error, 0.00010882138574179398, 37 39 | Discriminator Error, 0.011842663106487692, 38 40 | Discriminator Error, 0.0037273257731606506, 39 41 | Discriminator Error, 0.02600713590701608, 40 42 | Discriminator Error, 0.001240329756835838, 41 43 | Discriminator Error, 0.0034682251411274874, 42 44 | Discriminator Error, 0.012294893634667977, 43 45 | Discriminator Error, 0.007926536737710242, 44 46 | Discriminator Error, 0.001911827350461456, 45 47 | Discriminator Error, 0.01214886361057139, 46 48 | Discriminator Error, 0.015329277038518408, 47 49 | Discriminator Error, 0.009536089788255367, 48 50 | Discriminator Error, 0.01613865657485189, 49 51 | Discriminator Error, 0.00303499045768486, 50 52 | Discriminator Error, 0.018029607512101455, 51 53 | Discriminator Error, 0.010264516207204335, 52 54 | Discriminator Error, 0.011863380912619494, 53 55 | Discriminator Error, 0.005669155097448805, 54 56 | Discriminator Error, 0.011601125870404145, 55 57 | Discriminator Error, 0.005289249244501695, 56 58 | Discriminator Error, 0.024797408231042737, 57 59 | Discriminator Error, 0.01322961824828175, 58 60 | Discriminator Error, 0.0007661512566268793, 59 61 | Error Generator, 0.0017096178234547504, 60 62 | Error Generator, 5.034875490990258e-06, 61 63 | Error Generator, 7.956948129752815e-06, 62 64 | Error Generator, 1.706688851138271e-06, 63 65 | Error Generator, 4.699985263165477e-06, 64 66 | Error Generator, 5.835982324003978e-06, 65 67 | Error Generator, 2.855739878774231e-06, 66 68 | Error Generator, 1.3522647606049764e-06, 67 69 | Error Generator, 2.866080824037036e-06, 68 70 | Error Generator, 6.931706502322326e-07, 69 71 | Error Generator, 1.6215068995006172e-06, 70 72 | Error Generator, 6.514777195077239e-07, 71 73 | Error Generator, 8.043631625704314e-07, 72 74 | Error Generator, 6.006279733788215e-07, 73 75 | Error Generator, 0.014338978910797036, 74 76 | Error Generator, 0.002150866955361727, 75 77 | Error Generator, 0.0068945366962601925, 76 78 | Error Generator, 0.003025524857496896, 77 79 | Error Generator, 0.0010866643630071653, 78 80 | Error Generator, 0.0013457160235327854, 79 81 | Discriminator 
Error, 0.006450916799357465, 80 82 | Discriminator Error, 0.0002803738102327531, 81 83 | Discriminator Error, 0.020137342034874984, 82 84 | Discriminator Error, 0.004984511587083992, 83 85 | Discriminator Error, 0.01627713974778171, 84 86 | Discriminator Error, 0.0038846248145421288, 85 87 | Discriminator Error, 0.0003632981248226814, 86 88 | Discriminator Error, 7.711619253921082e-06, 87 89 | Discriminator Error, 0.010715467355056775, 88 90 | Discriminator Error, 0.010907993478716037, 89 91 | Discriminator Error, 5.874320827999073e-05, 90 92 | Discriminator Error, 0.01031196084083905, 91 93 | Discriminator Error, 0.010514290943641626, 92 94 | Discriminator Error, 1.0749312463946364e-05, 93 95 | Discriminator Error, 0.007550165921446628, 94 96 | Discriminator Error, 0.0005106016056323569, 95 97 | Discriminator Error, 2.5078064139559833e-05, 96 98 | Discriminator Error, 0.004141725663646721, 97 99 | Discriminator Error, 0.02880163643598681, 98 100 | Discriminator Error, 0.0007017043445108779, 99 101 | Discriminator Error, 0.03291082113507696, 100 102 | Discriminator Error, 0.0008145567083021017, 101 103 | Discriminator Error, 0.0078473423126278, 102 104 | Discriminator Error, 0.004082521993429954, 103 105 | Discriminator Error, 0.025626596581317152, 104 106 | Discriminator Error, 0.01246223538280446, 105 107 | Error Generator, 0.004227198561254808, 106 108 | Error Generator, 0.00027570282426964685, 107 109 | Error Generator, 0.000862823335444827, 108 110 | Error Generator, 0.0015386307774054564, 109 111 | Error Generator, 0.00019774551598314362, 110 112 | Error Generator, 0.0014890850143695383, 111 113 | Error Generator, 0.004258690872853429, 112 114 | Error Generator, 0.0018219588617738914, 113 115 | Error Generator, 0.0017767260255236983, 114 116 | Error Generator, 0.0016090054955843688, 115 117 | Error Generator, 2.1710020074766448e-05, 116 118 | Error Generator, 0.0014084077139756247, 117 119 | Error Generator, 0.0013971816553831832, 118 120 | Error Generator, 0.00017998490549403737, 119 121 | Error Generator, 3.391050646509112e-05, 120 122 | Error Generator, 0.001714194599615786, 121 123 | Error Generator, 0.0014821173897843504, 122 124 | Error Generator, 0.004661631874854993, 123 125 | Error Generator, 0.0031180511948742215, 124 126 | Error Generator, 0.0026355096633976388, 125 127 | Discriminator Error, 0.010556636609937229, 126 128 | Discriminator Error, 0.009543748753229813, 127 129 | Discriminator Error, 0.0007758908175711853, 128 130 | Discriminator Error, 0.016520816193560002, 129 131 | Discriminator Error, 0.006124662224695524, 130 132 | Discriminator Error, 0.002914256874504223, 131 133 | Discriminator Error, 0.007566204301684339, 132 134 | Discriminator Error, 0.01633244731429131, 133 135 | Error Generator, 0.0018276535184288773, 134 136 | Error Generator, 0.0031788918803276134, 135 137 | Error Generator, 0.003044590045821311, 136 138 | Error Generator, 0.0017465172788422802, 137 139 | Error Generator, 0.0016023907663223324, 138 140 | Error Generator, 0.0009126838333557272, 139 141 | Error Generator, 0.002294249237611066, 140 142 | Error Generator, 0.0010067418944474575, 141 143 | Error Generator, 0.005225907479412175, 142 144 | Error Generator, 0.004118384853296162, 143 145 | Error Generator, 0.0008594426858121161, 144 146 | Error Generator, 0.00015557708534560152, 145 147 | Error Generator, 0.00023550425447583888, 146 148 | Error Generator, 0.0011918149490881757, 147 149 | Error Generator, 0.0017106313296783716, 148 150 | Error Generator, 0.0005648238553632302, 149 151 | 
Error Generator, 1.5107930709556014e-05, 150 152 | Error Generator, 0.0028077621390373395, 151 153 | Error Generator, 0.005645933331465965, 152 154 | Error Generator, 0.0016567445484079967, 153 155 | Discriminator Error, 0.009816257661421107, 154 156 | Discriminator Error, 0.008165603218976357, 155 157 | Discriminator Error, 0.008426300556401508, 156 158 | Discriminator Error, 0.007587060573275067, 157 159 | Discriminator Error, 0.009455372259729205, 158 160 | Discriminator Error, 0.00757393253938887, 159 161 | Discriminator Error, 0.004592133120737779, 160 162 | Discriminator Error, 0.010307996371456633, 161 163 | Discriminator Error, 0.007193993400402873, 162 164 | Error Generator, 0.0025945229827103954, 163 165 | Error Generator, 0.0018146596081556816, 164 166 | Error Generator, 0.0014071291925412497, 165 167 | Error Generator, 0.001028582442159514, 166 168 | Error Generator, 0.0037837975100448195, 167 169 | Error Generator, 0.0010577817741822652, 168 170 | Error Generator, 0.0019473018520209024, 169 171 | Error Generator, 0.0007754038508321713, 170 172 | Error Generator, 0.00015861460462541035, 171 173 | Error Generator, 0.0012566628939655396, 172 174 | Error Generator, 0.002208607117874268, 173 175 | Error Generator, 0.0015970652068386375, 174 176 | Error Generator, 0.002521302971318352, 175 177 | Error Generator, 0.0017706180694739063, 176 178 | Error Generator, 0.0017350977570044493, 177 179 | Error Generator, 0.0015713284573835545, 178 180 | Error Generator, 0.0029022772106233642, 179 181 | Error Generator, 0.005000024933748933, 180 182 | Error Generator, 0.002802764863976385, 181 183 | Error Generator, 0.0005673081513520666, 182 184 | Error Generator, 0.002863177639401655, 183 185 | Error Generator, 0.005074260216913496, 184 186 | Error Generator, 0.0020840722442299685, 185 187 | Error Generator, 0.0018976665245299207, 186 188 | Error Generator, 0.001949973811781059, 187 189 | Error Generator, 0.0005108897981967193, 188 190 | Error Generator, 0.00015602336188333965, 189 191 | Error Generator, 0.0006085177197125892, 190 192 | Error Generator, 0.0014615266580733375, 191 193 | Error Generator, 0.00022743264836207173, 192 194 | Error Generator, 0.0026033203350031464, 193 195 | Error Generator, 0.0012906662755888952, 194 196 | Error Generator, 0.002480837879836327, 195 197 | Error Generator, 0.002569136458368491, 196 198 | Error Generator, 0.002489244662305403, 197 199 | Error Generator, 0.0015900169543911296, 198 200 | Error Generator, 0.001399487088624506, 199 201 | Error Generator, 0.0015340080899654753, 200 202 | Error Generator, 0.00027228564288874315, 201 203 | Error Generator, 0.00022029949304237215, 202 204 | Error Generator, 0.0031045419446822185, 203 205 | Error Generator, 0.0003763901171686337, 204 206 | Error Generator, 0.0018582279744081115, 200 207 | Error Generator, 0.00029183454060117346, 201 208 | Error Generator, 0.002411676044135595, 202 209 | Error Generator, 0.00267801285085113, 203 210 | Error Generator, 0.0019338618063910033, 204 211 | Error Generator, 0.002655433293293313, 200 212 | Error Generator, 0.0019320536541115915, 201 213 | Error Generator, 0.0003277269656157814, 202 214 | Error Generator, 0.00021465219379591957, 203 215 | Error Generator, 0.0001217429470291149, 204 216 | Error Generator, 0.002682537802892918, 200 217 | Error Generator, 0.0005553295091380613, 201 218 | Error Generator, 0.00010615298986281258, 202 219 | Discriminator Error, 0.009145303484117111, 203 220 | Discriminator Error, 0.010413979109413423, 204 221 | 
-------------------------------------------------------------------------------- /webpage/pix2pix/pix2pix.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.transforms as transforms 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | import os 7 | import torchvision.datasets as dset 8 | from .data import image_manipulation 9 | from .data import dataloader as img_dataloader 10 | from torch.autograd import Variable 11 | from tqdm import tqdm 12 | from PIL import Image 13 | 14 | 15 | def weights_init_normal(m): 16 | classname = m.__class__.__name__ 17 | if classname.find("Conv") != -1: 18 | torch.nn.init.normal_(m.weight.data, 0.0, 0.02) 19 | elif classname.find("BatchNorm2d") != -1: 20 | torch.nn.init.normal_(m.weight.data, 1.0, 0.02) 21 | torch.nn.init.constant_(m.bias.data, 0.0) 22 | 23 | class UNetDown(nn.Module): 24 | def __init__(self, in_size, out_size, normalize = True, dropout = 0.0): 25 | super(UNetDown, self).__init__() 26 | layers = [ 27 | nn.Conv2d(in_size, out_size, 4, 2, 1, bias = False) 28 | ] 29 | if normalize: 30 | layers.append(nn.InstanceNorm2d(out_size)) 31 | 32 | layers.append(nn.LeakyReLU(0.2)) 33 | 34 | if dropout: 35 | layers.append(nn.Dropout(dropout)) 36 | 37 | self.model = nn.Sequential(*layers) 38 | 39 | def forward(self, x): 40 | return self.model(x) 41 | 42 | class UNetUp(nn.Module): 43 | def __init__(self, in_size, out_size, dropout = 0.0): 44 | super(UNetUp, self).__init__() 45 | 46 | layers = [ 47 | nn.ConvTranspose2d(in_size, out_size, 4, 2, 1, bias=False), 48 | nn.InstanceNorm2d(out_size), 49 | nn.ReLU(inplace=True), 50 | ] 51 | if dropout: 52 | layers.append(nn.Dropout(dropout)) 53 | 54 | self.model = nn.Sequential(*layers) 55 | 56 | 57 | def forward(self, x, skip_input): 58 | x = self.model(x) 59 | x = torch.cat((x, skip_input), 1) 60 | 61 | return x 62 | 63 | class GeneratorUNet(nn.Module): 64 | def __init__(self, in_channels=3, out_channels=3): 65 | super(GeneratorUNet, self).__init__() 66 | 67 | self.down1 = UNetDown(in_channels, 64, normalize=False) 68 | self.down2 = UNetDown(64, 128) 69 | self.down3 = UNetDown(128, 256) 70 | self.down4 = UNetDown(256, 512, dropout=0.5) 71 | self.down5 = UNetDown(512, 512, dropout=0.5) 72 | self.down6 = UNetDown(512, 512, dropout=0.5) 73 | self.down7 = UNetDown(512, 512, dropout=0.5) 74 | self.down8 = UNetDown(512, 512, normalize=False, dropout=0.5) 75 | 76 | self.up1 = UNetUp(512, 512, dropout=0.5) 77 | self.up2 = UNetUp(1024, 512, dropout=0.5) 78 | self.up3 = UNetUp(1024, 512, dropout=0.5) 79 | self.up4 = UNetUp(1024, 512, dropout=0.5) 80 | self.up5 = UNetUp(1024, 256) 81 | self.up6 = UNetUp(512, 128) 82 | self.up7 = UNetUp(256, 64) 83 | 84 | self.final = nn.Sequential( 85 | nn.Upsample(scale_factor=2), 86 | nn.ZeroPad2d((1, 0, 1, 0)), 87 | nn.Conv2d(128, out_channels, 4, padding=1), 88 | nn.Tanh(), 89 | ) 90 | 91 | def forward(self, x): 92 | # U-Net generator with skip connections from encoder to decoder 93 | d1 = self.down1(x) 94 | d2 = self.down2(d1) 95 | d3 = self.down3(d2) 96 | d4 = self.down4(d3) 97 | d5 = self.down5(d4) 98 | d6 = self.down6(d5) 99 | d7 = self.down7(d6) 100 | d8 = self.down8(d7) 101 | 102 | # unet connections 103 | u1 = self.up1(d8, d7) 104 | u2 = self.up2(u1, d6) 105 | u3 = self.up3(u2, d5) 106 | u4 = self.up4(u3, d4) 107 | u5 = self.up5(u4, d3) 108 | u6 = self.up6(u5, d2) 109 | u7 = self.up7(u6, d1) 110 | 111 | return self.final(u7) 112 | 113 | class Discriminator(nn.Module): 114 | 
def __init__(self, in_channels=3): 115 | super(Discriminator, self).__init__() 116 | 117 | def discriminator_block(in_filters, out_filters, normalization=True): 118 | """Returns downsampling layers of each discriminator block""" 119 | layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)] 120 | if normalization: 121 | layers.append(nn.InstanceNorm2d(out_filters)) 122 | layers.append(nn.LeakyReLU(0.2, inplace=True)) 123 | return layers 124 | 125 | self.model = nn.Sequential( 126 | *discriminator_block(in_channels * 2, 64, normalization=False), 127 | *discriminator_block(64, 128), 128 | *discriminator_block(128, 256), 129 | *discriminator_block(256, 512), 130 | nn.ZeroPad2d((1, 0, 1, 0)), 131 | nn.Conv2d(512, 1, 4, padding=1, bias=False), 132 | nn.Sigmoid() 133 | ) 134 | 135 | def forward(self, img_A, img_B): 136 | # Concatenate image and condition image by channels to produce input 137 | img_input = torch.cat((img_A, img_B), 1) 138 | return self.model(img_input) 139 | 140 | # ## Model Train 141 | 142 | if __name__ == "__main__": 143 | torch.cuda.is_available() 144 | 145 | # random seed for reproducibility 146 | random_seed = 69 147 | 148 | np.random.seed(random_seed) 149 | 150 | # no of workers for dataloader 151 | no_of_workers = 4 152 | 153 | # root of the data 154 | data_root = "data/train/" 155 | 156 | # batch size 157 | batch_size = 1 158 | 159 | #no of epochs 160 | n_epochs = 10 161 | 162 | # learning rate 163 | lr = 0.0002 164 | 165 | # betas for adam 166 | beta_1 = 0.5 167 | beta_2 = 0.999 168 | 169 | # image size 170 | image_height = 256 171 | image_width = 256 172 | 173 | # We can use an image folder dataset the way we have it setup. 174 | # Create the dataset 175 | dataset = dset.ImageFolder(root=data_root, 176 | transform=transforms.Compose([ 177 | transforms.ToTensor(), 178 | ])) 179 | # Create the dataloader 180 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, 181 | num_workers = no_of_workers) 182 | #initialize model classes 183 | generator = GeneratorUNet() 184 | discriminator = Discriminator() 185 | 186 | 187 | # check if cuda is avialbale 188 | cuda = True if torch.cuda.is_available() else False 189 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 190 | print(cuda) 191 | 192 | # initialize weights if the model is not found in the paths 193 | if os.path.exists("saved_models/generator_49.pth"): 194 | print("Generator Found") 195 | generator.load_state_dict(torch.load("saved_models/generator_49.pth", map_location = device)) 196 | else: 197 | generator.apply(weights_init_normal) 198 | 199 | if os.path.exists("saved_models/discriminator_49.pth"): 200 | print("Discriminator Found") 201 | discriminator.load_state_dict(torch.load("saved_models/discriminator_49.pth", map_location = device)) 202 | else: 203 | discriminator.apply(weights_init_normal) 204 | 205 | # model loss functions 206 | loss_fn_generator = torch.nn.MSELoss() # mean squared loss 207 | loss_fn_disc = torch.nn.L1Loss() #pixel wise loss 208 | 209 | # to cuda if cuda is avaiable 210 | generator.to(device) 211 | discriminator.to(device) 212 | loss_fn_disc.to(device) 213 | loss_fn_generator.to(device) 214 | 215 | # optimizers 216 | optimier_G = torch.optim.Adam(generator.parameters(), betas=(beta_1, beta_2), lr=lr) 217 | optimier_D = torch.optim.Adam(discriminator.parameters(), betas=(beta_1, beta_2), lr=lr) 218 | 219 | # Loss weight of L1 pixel-wise loss between translated image and real image 220 | lambda_pixel = 100 221 | 222 | # Calculate output of 
image discriminator (PatchGAN) 223 | patch = (1, image_height // 2 ** 4, image_width // 2 ** 4) 224 | 225 | # Tensor type 226 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor 227 | 228 | transform = transforms.Compose([ 229 | transforms.ToTensor(), # transform to tensor 230 | transforms.Resize((image_width, image_height)) # Resize the image to constant size 231 | ]) 232 | 233 | # create a dataloader 234 | pair_image_dataloader = img_dataloader.ImageDataset("./data/train_2/old_images", "./data/train_2/reconstructed_images", transform) 235 | 236 | for epoch in range(1): 237 | for i, batch in tqdm(enumerate(pair_image_dataloader)): 238 | real_A = batch['A'].unsqueeze(0) # old image 239 | real_B = batch['B'].unsqueeze(0) # new image 240 | 241 | # train generator 242 | optimier_G.zero_grad() 243 | 244 | # Adversarial ground truths 245 | valid = Variable(Tensor(np.ones((real_A.size(0), *patch))), requires_grad=False) # ground truth for valid 246 | fake = Variable(Tensor(np.zeros((real_A.size(0), *patch))), requires_grad=False) # ground truth for invalid 247 | 248 | 249 | # GAN loss 250 | fake_B = generator(real_A.to(device)) # fake sample generated by generator 251 | pred_fake = discriminator(fake_B.to(device), real_B.to(device)) # prediction using discriminator 252 | loss_generator = loss_fn_generator(pred_fake.to(device), valid.to(device)) # check if the sample is valid or not 253 | 254 | loss_pixel = loss_fn_disc(fake_B.to(device), real_B.to(device)) # calculate the pixel wise loss 255 | 256 | # total loss 257 | loss_G = loss_generator + lambda_pixel * loss_pixel # total loss of the generator 258 | 259 | loss_G.backward() 260 | optimier_G.step() 261 | 262 | ## Train discriminator 263 | optimier_D.zero_grad() 264 | 265 | # Real loss 266 | pred_real = discriminator(real_B.to(device), real_A.to(device)) # loss to check real or not 267 | loss_real = loss_fn_generator(pred_real, valid) 268 | 269 | # Fake loss 270 | pred_fake = discriminator(fake_B.detach().to(device), real_A.to(device)) # loss to check fake or not 271 | loss_fake = loss_fn_generator(pred_fake.to(device), fake.to(device)) 272 | 273 | # Total loss 274 | loss_D = 0.5 * (loss_real + loss_fake) # total loss of the discriminator 275 | 276 | loss_D.backward() 277 | optimier_D.step() 278 | 279 | # for logging 280 | if i % 100 == 0 and i: 281 | print(f"Generator Error: {torch.linalg.norm(loss_G).item()}, epoch: {epoch}, itr: {i}") 282 | print(f"Discriminator Error: {torch.linalg.norm(loss_D).item()}, epoch: {epoch}, itr: {i}") 283 | 284 | # train with only 5000 images 285 | if i % 500 == 0 and i > 0: 286 | break 287 | 288 | 289 | torch.save(generator.state_dict(), "saved_models/generator.pth") 290 | torch.save(discriminator.state_dict(), "saved_models/discriminator.pth") 291 | -------------------------------------------------------------------------------- /webpage/pix2pix/data/image_manipulation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "0aab5a32", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import numpy as np\n", 11 | "import random\n", 12 | "import cv2\n", 13 | "from PIL import Image\n", 14 | "import os\n", 15 | "from io import BytesIO\n", 16 | "from tqdm.notebook import tqdm, trange\n", 17 | "import time" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 2, 23 | "id": "23a5007e", 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "def 
blend_image(original_image: Image.Image, blend_image_source: Image.Image, intensity = 0.6):\n", 28 | " \"\"\"\n", 29 | " Blends the original image with blend_image with intensity\n", 30 | " \"\"\"\n", 31 | " \n", 32 | " # converts the blend_image to the format of original image\n", 33 | " # because both the image needs to be in the same format\n", 34 | " # same goes for size\n", 35 | " blend_image_source = blend_image_source.convert(original_image.mode) \n", 36 | " blend_image_source = blend_image_source.resize(original_image.size)\n", 37 | " \n", 38 | " new_image = Image.new(original_image.mode, original_image.size)\n", 39 | " new_image = Image.blend(original_image, blend_image_source, intensity)\n", 40 | " \n", 41 | " return new_image" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 3, 47 | "id": "b9feef2a", 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "def pil_to_np(img_pil):\n", 52 | " \"\"\"\n", 53 | " Converts image from pil Image to numpy array\n", 54 | " \"\"\"\n", 55 | " ar: np.ndarray = np.array(img_pil)\n", 56 | " print(ar.shape)\n", 57 | " if len(ar.shape) == 3:\n", 58 | " \"\"\"\n", 59 | " Tensor transpose, since in this case tensor is 3D the order of transpose can be different\n", 60 | " In 2D matrix the transpose is only i,j-> j,i but in more than 2D matrix different permutation can be \n", 61 | " applied\n", 62 | " \"\"\"\n", 63 | " ar = ar.transpose(2, 0, 1)\n", 64 | " else:\n", 65 | " ar = ar[None, ...]\n", 66 | "\n", 67 | " return ar.astype(np.float32) / 255." 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 4, 73 | "id": "11162123", 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "def np_to_pil(img_np):\n", 78 | " \"\"\"\n", 79 | " Converts np.ndarray to Image.Image object\n", 80 | " \"\"\"\n", 81 | " ar = np.clip(img_np * 255, 0, 255).astype(np.uint8)\n", 82 | "\n", 83 | " if img_np.shape[0] == 1:\n", 84 | " ar = ar[0]\n", 85 | " else:\n", 86 | " ar = ar.transpose(1, 2, 0)\n", 87 | "\n", 88 | " return Image.fromarray(ar)" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 5, 94 | "id": "8703eebd", 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "def synthesize_salt_pepper(image: Image.Image, amount, salt_vs_pepper):\n", 99 | " \"\"\"\n", 100 | " Salt and pepper noise is also known as an impulse noise, this noise can be caused by sharp and sudden \n", 101 | " disturbances in the image signal. 
gives the appearance of scattered white or black(or both) pixel over\n", 102 | " the image\n", 103 | " \"\"\"\n", 104 | " img_pil=pil_to_np(image)\n", 105 | "\n", 106 | " out = img_pil.copy()\n", 107 | " p = amount\n", 108 | " q = salt_vs_pepper\n", 109 | " flipped = np.random.choice([True, False], size=img_pil.shape,\n", 110 | " p=[p, 1 - p])\n", 111 | " salted = np.random.choice([True, False], size=img_pil.shape,\n", 112 | " p=[q, 1 - q])\n", 113 | " peppered = ~salted\n", 114 | " out[flipped & salted] = 1\n", 115 | " out[flipped & peppered] = 0.\n", 116 | " noisy = np.clip(out, 0, 1).astype(np.float32)\n", 117 | " return np_to_pil(noisy)" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": 6, 123 | "id": "97268ca6", 124 | "metadata": {}, 125 | "outputs": [], 126 | "source": [ 127 | "def synthesize_speckle(image,std_l,std_r):\n", 128 | "\n", 129 | " ## Give PIL, return the noisy PIL\n", 130 | "\n", 131 | " img_pil=pil_to_np(image)\n", 132 | "\n", 133 | " mean=0\n", 134 | " std=random.uniform(std_l/255.,std_r/255.)\n", 135 | " gauss=np.random.normal(loc=mean,scale=std,size=img_pil.shape)\n", 136 | " noisy=img_pil+gauss*img_pil\n", 137 | " noisy=np.clip(noisy,0,1).astype(np.float32)\n", 138 | "\n", 139 | " return np_to_pil(noisy)" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": 7, 145 | "id": "eb4a58e9", 146 | "metadata": {}, 147 | "outputs": [], 148 | "source": [ 149 | "def blur_image_v2(img):\n", 150 | " x=np.array(img)\n", 151 | " kernel_size_candidate=[(3,3),(5,5),(7,7)]\n", 152 | " kernel_size=random.sample(kernel_size_candidate,1)[0]\n", 153 | " std=random.uniform(1.,5.)\n", 154 | "\n", 155 | " #print(\"The gaussian kernel size: (%d,%d) std: %.2f\"%(kernel_size[0],kernel_size[1],std))\n", 156 | " blur=cv2.GaussianBlur(x,kernel_size,std)\n", 157 | "\n", 158 | " return Image.fromarray(blur.astype(np.uint8))" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": 8, 164 | "id": "50ac819c", 165 | "metadata": {}, 166 | "outputs": [], 167 | "source": [ 168 | "def synthesize_low_resolution(image: Image.Image):\n", 169 | " \"\"\"\n", 170 | " Creates a low resolution image from high resolution image\n", 171 | " \"\"\"\n", 172 | " width, height = image.size\n", 173 | " \n", 174 | " new_width = np.random.randint(int(width / 2), width - int(width / 5))\n", 175 | " new_height = np.random.randint(int(height / 2), height - int(height / 5))\n", 176 | " \n", 177 | " image = image.resize((new_width, new_height), Image.BICUBIC)\n", 178 | " \n", 179 | " if random.uniform(0, 1) < 0.5:\n", 180 | " image = image.resize((width, height), Image.NEAREST)\n", 181 | " else:\n", 182 | " image = image.resize((width, height), Image.BILINEAR)\n", 183 | " \n", 184 | " return image" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": 9, 190 | "id": "018a8a72", 191 | "metadata": {}, 192 | "outputs": [], 193 | "source": [ 194 | "def online_add_degradation_v2(img):\n", 195 | " task_id = np.random.permutation(4)\n", 196 | "\n", 197 | " for x in task_id:\n", 198 | " if x == 0 and random.uniform(0,1)<0.7:\n", 199 | " img = blur_image_v2(img)\n", 200 | " if x == 1 and random.uniform(0,1)<0.7:\n", 201 | " flag = random.choice([1, 2, 3])\n", 202 | " if flag == 1:\n", 203 | " img = synthesize_gaussian(img, 5, 50)\n", 204 | " if flag == 2:\n", 205 | " img = synthesize_speckle(img, 5, 50)\n", 206 | " if flag == 3:\n", 207 | " img = synthesize_salt_pepper(img, random.uniform(0, 0.01), random.uniform(0.3, 0.8))\n", 208 | " if x == 2 
and random.uniform(0,1)<0.7:\n", 209 | " img=synthesize_low_resolution(img)\n", 210 | "\n", 211 | " if x==3 and random.uniform(0,1)<0.7:\n", 212 | " img=convertToJpeg(img,random.randint(40,100))\n", 213 | "\n", 214 | " return img" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": 10, 220 | "id": "15e2d28e", 221 | "metadata": {}, 222 | "outputs": [], 223 | "source": [ 224 | "def zero_mask(row, col):\n", 225 | " x = np.zeros((row, col, 3))\n", 226 | " mask=Image.fromarray(x).convert(\"RGB\")\n", 227 | " return mask\n", 228 | "\n", 229 | "def irregular_hole_synthesize(img, mask):\n", 230 | " \"\"\"\n", 231 | " Create holes using scrach paper textures\n", 232 | " Args:\n", 233 | " img: Original Image\n", 234 | " mask: scratch paper texture\n", 235 | " \"\"\"\n", 236 | " img_np = np.array(img).astype('uint8')\n", 237 | " mask = mask.resize(img.size)\n", 238 | " mask = mask.convert(img.mode) \n", 239 | " mask_np = np.array(mask).astype('uint8')\n", 240 | " mask_np = mask_np / 255\n", 241 | " img_new=img_np * (1 - mask_np) + mask_np * 255\n", 242 | "\n", 243 | "\n", 244 | " hole_img=Image.fromarray(img_new.astype('uint8')).convert(\"RGB\")\n", 245 | "\n", 246 | "\n", 247 | " return hole_img,mask.convert(\"L\")" 248 | ] 249 | }, 250 | { 251 | "cell_type": "code", 252 | "execution_count": 11, 253 | "id": "9faf16bd", 254 | "metadata": {}, 255 | "outputs": [], 256 | "source": [ 257 | "def convertToJpeg(im,quality):\n", 258 | " with BytesIO() as f:\n", 259 | " im.save(f, format='JPEG',quality=quality)\n", 260 | " f.seek(0)\n", 261 | " return Image.open(f).convert('RGB')" 262 | ] 263 | }, 264 | { 265 | "cell_type": "code", 266 | "execution_count": 12, 267 | "id": "acd2721f", 268 | "metadata": {}, 269 | "outputs": [ 270 | { 271 | "data": { 272 | "application/vnd.jupyter.widget-view+json": { 273 | "model_id": "8da6a073f3164d7fb99c943bd8836259", 274 | "version_major": 2, 275 | "version_minor": 0 276 | }, 277 | "text/plain": [ 278 | " 0%| | 0/1051 [00:00 [BN] => ReLU) * 2""" 80 | 81 | def __init__(self, in_channels, out_channels, mid_channels=None): 82 | super().__init__() 83 | if not mid_channels: 84 | mid_channels = out_channels 85 | self.double_conv = nn.Sequential( 86 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False), 87 | nn.BatchNorm2d(mid_channels), 88 | nn.ReLU(inplace=True), 89 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False), 90 | nn.BatchNorm2d(out_channels), 91 | nn.ReLU(inplace=True) 92 | ) 93 | 94 | def forward(self, x): 95 | return self.double_conv(x) 96 | 97 | class UNetUp(nn.Module): 98 | def __init__(self, in_size, out_size, dropout = 0.0): 99 | super(UNetUp, self).__init__() 100 | 101 | layers = [ 102 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), 103 | DoubleConv(in_size, out_size, in_size // 2), 104 | nn.InstanceNorm2d(out_size), 105 | nn.ReLU(inplace=True), 106 | ] 107 | if dropout: 108 | layers.append(nn.Dropout(dropout)) 109 | 110 | self.model = nn.Sequential(*layers) 111 | 112 | 113 | def forward(self, x, skip_input): 114 | x = self.model(x) 115 | x = torch.cat((x, skip_input), 1) 116 | 117 | return x 118 | 119 | 120 | 121 | class GeneratorUNet(nn.Module): 122 | def __init__(self, in_channels=3, out_channels=3): 123 | super(GeneratorUNet, self).__init__() 124 | 125 | self.down1 = UNetDown(in_channels, 64, normalize=False) 126 | self.down2 = UNetDown(64, 128) 127 | self.down3 = UNetDown(128, 256) 128 | self.down4 = UNetDown(256, 512, dropout=0.5) 129 | self.down5 = UNetDown(512, 
512, dropout=0.5) 130 | self.down6 = UNetDown(512, 512, dropout=0.5) 131 | self.down7 = UNetDown(512, 512, dropout=0.5) 132 | self.down8 = UNetDown(512, 512, normalize=False, dropout=0.5) 133 | 134 | self.up1 = UNetUp(512, 512, dropout=0.5) 135 | self.up2 = UNetUp(1024, 512, dropout=0.5) 136 | self.up3 = UNetUp(1024, 512, dropout=0.5) 137 | self.up4 = UNetUp(1024, 512, dropout=0.5) 138 | self.up5 = UNetUp(1024, 256) 139 | self.up6 = UNetUp(512, 128) 140 | self.up7 = UNetUp(256, 64) 141 | 142 | self.final = nn.Sequential( 143 | nn.Upsample(scale_factor=2), 144 | nn.ZeroPad2d((1, 0, 1, 0)), 145 | nn.Conv2d(128, out_channels, 4, padding=1), 146 | nn.Tanh(), 147 | ) 148 | 149 | def forward(self, x): 150 | # U-Net generator with skip connections from encoder to decoder 151 | d1 = self.down1(x) 152 | d2 = self.down2(d1) 153 | d3 = self.down3(d2) 154 | d4 = self.down4(d3) 155 | d5 = self.down5(d4) 156 | d6 = self.down6(d5) 157 | d7 = self.down7(d6) 158 | d8 = self.down8(d7) 159 | 160 | # unet connections 161 | u1 = self.up1(d8, d7) 162 | u2 = self.up2(u1, d6) 163 | u3 = self.up3(u2, d5) 164 | u4 = self.up4(u3, d4) 165 | u5 = self.up5(u4, d3) 166 | u6 = self.up6(u5, d2) 167 | u7 = self.up7(u6, d1) 168 | 169 | return self.final(u7) 170 | 171 | 172 | class Discriminator(nn.Module): 173 | def __init__(self, in_channels=3): 174 | super(Discriminator, self).__init__() 175 | 176 | def discriminator_block(in_filters, out_filters, normalization=True): 177 | """Returns downsampling layers of each discriminator block""" 178 | layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)] 179 | if normalization: 180 | layers.append(nn.InstanceNorm2d(out_filters)) 181 | layers.append(nn.LeakyReLU(0.2, inplace=True)) 182 | return layers 183 | 184 | self.model = nn.Sequential( 185 | *discriminator_block(in_channels * 2, 64, normalization=False), 186 | *discriminator_block(64, 128), 187 | *discriminator_block(128, 256), 188 | *discriminator_block(256, 512), 189 | nn.ZeroPad2d((1, 0, 1, 0)), 190 | nn.Conv2d(512, 1, 4, padding=1, bias=False), 191 | nn.Sigmoid() 192 | ) 193 | 194 | def forward(self, img_A, img_B): 195 | # Concatenate image and condition image by channels to produce input 196 | img_input = torch.cat((img_A, img_B), 1) 197 | return self.model(img_input) 198 | 199 | 200 | import requests 201 | from io import BytesIO 202 | transform = transforms.Compose([ 203 | transforms.ToTensor(), # transform to tensor 204 | transforms.Resize((image_width, image_height)) # Resize the image to constant size 205 | ]) 206 | 207 | 208 | 209 | 210 | # count the number of trainable parameters 211 | def count_parameters(model): 212 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 213 | 214 | 215 | 216 | 217 | if __name__ == "__main__": 218 | # no of workers for dataloader 219 | no_of_workers = 4 220 | 221 | # root of the data 222 | data_root = "data/train/" 223 | 224 | # batch size 225 | batch_size = 10 226 | 227 | #no of epochs 228 | n_epochs = 10 229 | 230 | # learning rate 231 | lr = 0.0005 232 | 233 | # betas for adam 234 | beta_1 = 0.5 235 | beta_2 = 0.999 236 | 237 | # image size 238 | image_height = 512 239 | image_width = 512 240 | 241 | 242 | #initialize model classes 243 | generator = GeneratorUNet() 244 | discriminator = Discriminator() 245 | 246 | 247 | # check if cuda is avialbale 248 | cuda = True if torch.cuda.is_available() else False 249 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 250 | print(cuda) 251 | 252 | 253 | 254 | # Loss weight of L1 pixel-wise loss 
between translated image and real image 255 | lambda_pixel = 60 256 | 257 | # Calculate output of image discriminator (PatchGAN) 258 | patch = (1, image_height // 2 ** 4, image_width // 2 ** 4) 259 | 260 | # Tensor type 261 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor 262 | 263 | 264 | 265 | # load dataset 266 | transform = transforms.Compose([ 267 | transforms.ToTensor(), # transform to tensor 268 | transforms.Resize((image_width, image_height)) # Resize the image to constant size 269 | ]) 270 | 271 | # create a dataloader 272 | pair_image_dataloader = img_dataloader.ImageDataset("./data/train_2/old_images", "./data/train_2/reconstructed_images", transform) 273 | 274 | dataloader = DataLoader( 275 | pair_image_dataloader, 276 | batch_size = 4, 277 | shuffle = True, 278 | ) 279 | 280 | # val_image_dataloader = img_dataloader.ImageDataset("./data/val/old_image", "./data/val/reconstructed_image", transform) 281 | # val_dataloader = DataLoader( 282 | # val_image_dataloader, 283 | # batch_size = 5, 284 | # shuffle = True 285 | # ) 286 | 287 | # load csv for data logging 288 | generator_error_file = "generator_new_data.csv" 289 | disc_error_file = "discriminator_new_data.csv" 290 | 291 | total_generator_epochs = 0 292 | try: 293 | with open(generator_error_file, "r") as f: 294 | last_line = None 295 | for last_line in f: 296 | pass 297 | if last_line != None: 298 | print("CSV file found") 299 | total_generator_epochs = int(last_line.split(',')[-1]) + 1 300 | else: 301 | total_generator_epochs = 0 302 | 303 | f.close() 304 | except FileNotFoundError: 305 | with open(generator_error_file, "w") as f: 306 | total_generator_epochs = 0 307 | 308 | f.close() 309 | 310 | 311 | total_discriminator_epochs = 0 312 | try: 313 | with open(disc_error_file, "r") as f: 314 | last_line = None 315 | for last_line in f: 316 | pass 317 | if last_line != None: 318 | total_discriminator_epochs = int(last_line.split(',')[-1]) + 1 319 | print(f"CSV file found: {total_discriminator_epochs}") 320 | else: 321 | total_discriminator_epochs = 0 322 | 323 | f.close() 324 | except FileNotFoundError: 325 | with open(disc_error_file, "w") as f: 326 | total_generator_epochs = 0 327 | 328 | f.close() 329 | 330 | generator_file = f"saved_models_new_data/generator_{total_generator_epochs - 1}.pth" 331 | discriminator_file = f"saved_models_new_data/discriminator_{total_discriminator_epochs - 1}.pth" 332 | # initialize weights if the model is not found in the paths 333 | if os.path.exists(generator_file): 334 | print("Generator Found") 335 | generator.load_state_dict(torch.load(generator_file, map_location = device)) 336 | else: 337 | generator.apply(weights_init_normal) 338 | 339 | 340 | if os.path.exists(discriminator_file): 341 | print("Discriminator Found") 342 | discriminator.load_state_dict(torch.load(discriminator_file, map_location = device)) 343 | else: 344 | discriminator.apply(weights_init_normal) 345 | 346 | # model loss functions 347 | loss_fn_generator = torch.nn.MSELoss() # mean squared loss 348 | loss_fn_disc = torch.nn.L1Loss() #pixel wise loss 349 | 350 | # to cuda if cuda is avaiable 351 | generator.to(device) 352 | discriminator.to(device) 353 | loss_fn_disc.to(device) 354 | loss_fn_generator.to(device) 355 | 356 | # optimizers 357 | optimier_G = torch.optim.Adam(generator.parameters(), betas=(beta_1, beta_2), lr=lr) 358 | optimier_D = torch.optim.Adam(discriminator.parameters(), betas=(beta_1, beta_2), lr=lr) 359 | 360 | for epoch in range(total_discriminator_epochs, total_discriminator_epochs 
+ 2): 361 | # break 362 | loss_G_list = np.array([]) 363 | loss_D_list = np.array([]) 364 | for i, batch in tqdm(enumerate(dataloader)): 365 | real_A = batch['A'] # old image 366 | real_B = batch['B'] # new image 367 | 368 | # train generator 369 | optimier_G.zero_grad() 370 | 371 | # Adversarial ground truths 372 | valid = Variable(Tensor(np.ones((real_A.size(0), *patch))), requires_grad=False) # ground truth for valid 373 | fake = Variable(Tensor(np.zeros((real_A.size(0), *patch))), requires_grad=False) # ground truth for invalid 374 | 375 | 376 | # GAN loss 377 | fake_B = generator(real_A.to(device)) # fake sample generated by generator 378 | pred_fake = discriminator(fake_B.to(device), real_B.to(device)) # prediction using discriminator 379 | loss_generator = loss_fn_generator(pred_fake.to(device), valid.to(device)) # check if the sample is valid or not 380 | 381 | loss_pixel = loss_fn_disc(fake_B.to(device), real_B.to(device)) # calculate the pixel wise loss 382 | 383 | # total loss 384 | loss_G = loss_generator + lambda_pixel * loss_pixel # total loss of the generator 385 | 386 | loss_G.backward() 387 | optimier_G.step() 388 | 389 | ## Train discriminator 390 | optimier_D.zero_grad() 391 | 392 | # Real loss 393 | pred_real = discriminator(real_B.to(device), real_A.to(device)) # loss to check real or not 394 | loss_real = loss_fn_generator(pred_real, valid) 395 | 396 | # Fake loss 397 | pred_fake = discriminator(fake_B.detach().to(device), real_A.to(device)) # loss to check fake or not 398 | loss_fake = loss_fn_generator(pred_fake.to(device), fake.to(device)) 399 | 400 | # Total loss 401 | loss_D = 0.5 * (loss_real + loss_fake) # total loss of the discriminator 402 | 403 | loss_D.backward() 404 | optimier_D.step() 405 | 406 | # for logging 407 | if i % 10 == 0 and i != 0: 408 | print(f"Generator Error: {torch.linalg.norm(loss_G).item()}, epoch: {epoch}, itr: {i}") 409 | print(f"Discriminator Error: {torch.linalg.norm(loss_D).item()}, epoch: {epoch}, itr: {i}") 410 | 411 | loss_G_list = np.append(loss_G_list, torch.linalg.norm(loss_G).item()) 412 | loss_D_list = np.append(loss_D_list, torch.linalg.norm(loss_D).item()) 413 | 414 | # log into a file 415 | with open(generator_error_file, "a") as f: 416 | f.write(f"Generator Error, {np.mean(loss_G_list)}, {epoch}\n") 417 | with open(disc_error_file, "a") as f: 418 | f.write(f"Discriminator Error, {np.mean(loss_D_list)}, {epoch}\n") 419 | 420 | torch.save(generator.state_dict(), f"saved_models_new_data/generator_{epoch}.pth") 421 | torch.save(discriminator.state_dict(), f"saved_models_new_data/discriminator_{epoch}.pth") 422 | 423 | # torch.cuda.empty_cache() 424 | 425 | # # image load 426 | # original_image = Image.open("./outputs/input.jpg") 427 | # # original_image = test_list[0]["A"] 428 | # original_image.save("outputs/input.jpg") 429 | # output_image = image_manipulation.np_to_pil( 430 | # generator(transform(original_image).unsqueeze(0).to(device)).detach().cpu().numpy()[0] 431 | # ) 432 | 433 | # width, height = original_image.size 434 | # output_image.resize((width, height)) 435 | 436 | # output_image.save(f"outputs/output_{total_generator_epochs}.jpg") 437 | -------------------------------------------------------------------------------- /webpage/pix2pix/test.py: -------------------------------------------------------------------------------- 1 | from pyrsistent import b 2 | import torch 3 | import torch.nn as nn 4 | import torchvision.transforms as transforms 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | import 
os 8 | import torchvision.datasets as dset 9 | from data import image_manipulation 10 | from data import dataloader as img_dataloader 11 | from torch.autograd import Variable 12 | from tqdm import tqdm 13 | from PIL import Image 14 | from torch.utils.data import DataLoader 15 | from torchsummary import summary 16 | from torch.utils.tensorboard import SummaryWriter 17 | 18 | 19 | def weights_init_normal(m): 20 | classname = m.__class__.__name__ 21 | 22 | if classname.find("Conv") != -1 and classname.find("DoubleConv") == 1: 23 | torch.nn.init.normal_(m.weight.data, 0.0, 0.02) 24 | elif classname.find("BatchNorm2d") != -1: 25 | torch.nn.init.normal_(m.weight.data, 1.0, 0.02) 26 | torch.nn.init.constant_(m.bias.data, 0.0) 27 | 28 | 29 | 30 | 31 | class UNetDown(nn.Module): 32 | def __init__(self, in_size, out_size, normalize = True, dropout = 0.0): 33 | super(UNetDown, self).__init__() 34 | layers = [ 35 | nn.Conv2d(in_size, out_size, 4, 2, 1, bias = False) 36 | ] 37 | if normalize: 38 | layers.append(nn.InstanceNorm2d(out_size)) 39 | 40 | layers.append(nn.LeakyReLU(0.2)) 41 | 42 | if dropout: 43 | layers.append(nn.Dropout(dropout)) 44 | 45 | self.model = nn.Sequential(*layers) 46 | 47 | def forward(self, x): 48 | return self.model(x) 49 | 50 | 51 | 52 | 53 | class DoubleConv(nn.Module): 54 | """(convolution => [BN] => ReLU) * 2""" 55 | 56 | def __init__(self, in_channels, out_channels, mid_channels=None): 57 | super().__init__() 58 | if not mid_channels: 59 | mid_channels = out_channels 60 | self.double_conv = nn.Sequential( 61 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False), 62 | nn.BatchNorm2d(mid_channels), 63 | nn.ReLU(inplace=True), 64 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False), 65 | nn.BatchNorm2d(out_channels), 66 | nn.ReLU(inplace=True) 67 | ) 68 | 69 | def forward(self, x): 70 | return self.double_conv(x) 71 | 72 | 73 | class UNetUp(nn.Module): 74 | def __init__(self, in_size, out_size, dropout = 0.0): 75 | super(UNetUp, self).__init__() 76 | 77 | layers = [ 78 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), 79 | DoubleConv(in_size, out_size, in_size // 2), 80 | nn.InstanceNorm2d(out_size), 81 | nn.ReLU(inplace=True), 82 | ] 83 | if dropout: 84 | layers.append(nn.Dropout(dropout)) 85 | 86 | self.model = nn.Sequential(*layers) 87 | 88 | 89 | def forward(self, x, skip_input): 90 | x = self.model(x) 91 | x = torch.cat((x, skip_input), 1) 92 | 93 | return x 94 | 95 | 96 | class GeneratorUNet(nn.Module): 97 | def __init__(self, in_channels=3, out_channels=3): 98 | super(GeneratorUNet, self).__init__() 99 | 100 | self.down1 = UNetDown(in_channels, 64, normalize=False) 101 | self.down2 = UNetDown(64, 128) 102 | self.down3 = UNetDown(128, 256) 103 | self.down4 = UNetDown(256, 512, dropout=0.5) 104 | self.down5 = UNetDown(512, 512, dropout=0.5) 105 | self.down6 = UNetDown(512, 512, dropout=0.5) 106 | self.down7 = UNetDown(512, 512, dropout=0.5) 107 | self.down8 = UNetDown(512, 512, normalize=False, dropout=0.5) 108 | 109 | self.up1 = UNetUp(512, 512, dropout=0.5) 110 | self.up2 = UNetUp(1024, 512, dropout=0.5) 111 | self.up3 = UNetUp(1024, 512, dropout=0.5) 112 | self.up4 = UNetUp(1024, 512, dropout=0.5) 113 | self.up5 = UNetUp(1024, 256) 114 | self.up6 = UNetUp(512, 128) 115 | self.up7 = UNetUp(256, 64) 116 | 117 | self.final = nn.Sequential( 118 | nn.Upsample(scale_factor=2), 119 | nn.ZeroPad2d((1, 0, 1, 0)), 120 | nn.Conv2d(128, out_channels, 4, padding=1), 121 | nn.Tanh(), 122 | ) 123 | 124 | def 
forward(self, x): 125 | # U-Net generator with skip connections from encoder to decoder 126 | d1 = self.down1(x) 127 | d2 = self.down2(d1) 128 | d3 = self.down3(d2) 129 | d4 = self.down4(d3) 130 | d5 = self.down5(d4) 131 | d6 = self.down6(d5) 132 | d7 = self.down7(d6) 133 | d8 = self.down8(d7) 134 | 135 | # unet connections 136 | u1 = self.up1(d8, d7) 137 | u2 = self.up2(u1, d6) 138 | u3 = self.up3(u2, d5) 139 | u4 = self.up4(u3, d4) 140 | u5 = self.up5(u4, d3) 141 | u6 = self.up6(u5, d2) 142 | u7 = self.up7(u6, d1) 143 | 144 | return self.final(u7) 145 | 146 | 147 | import torch.nn.functional as F 148 | import math 149 | def gaussian(window_size, sigma): 150 | """ 151 | Generates a list of Tensor values drawn from a gaussian distribution with standard 152 | diviation = sigma and sum of all elements = 1. 153 | 154 | Length of list = window_size 155 | """ 156 | gauss = torch.Tensor([math.exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 157 | return gauss/gauss.sum() 158 | 159 | def create_window(window_size, channel=1): 160 | 161 | # Generate an 1D tensor containing values sampled from a gaussian distribution 162 | _1d_window = gaussian(window_size=window_size, sigma=1.5).unsqueeze(1) 163 | 164 | # Converting to 2D 165 | _2d_window = _1d_window.mm(_1d_window.t()).float().unsqueeze(0).unsqueeze(0) 166 | 167 | window = torch.Tensor(_2d_window.expand(channel, 1, window_size, window_size).contiguous()) 168 | 169 | return window 170 | 171 | 172 | def ssim(img1, img2, val_range, window_size=11, window=None, size_average=True, full=False): 173 | L = val_range # L is the dynamic range of the pixel values (255 for 8-bit grayscale images), 174 | 175 | pad = window_size // 2 176 | 177 | try: 178 | _, channels, height, width = img1.size() 179 | except: 180 | channels, height, width = img1.size() 181 | 182 | # if window is not provided, init one 183 | if window is None: 184 | real_size = min(window_size, height, width) # window should be atleast 11x11 185 | window = create_window(real_size, channel=channels).to(img1.device) 186 | 187 | # calculating the mu parameter (locally) for both images using a gaussian filter 188 | # calculates the luminosity params 189 | mu1 = F.conv2d(img1, window, padding=pad, groups=channels) 190 | mu2 = F.conv2d(img2, window, padding=pad, groups=channels) 191 | 192 | mu1_sq = mu1 ** 2 193 | mu2_sq = mu2 ** 2 194 | mu12 = mu1 * mu2 195 | 196 | # now we calculate the sigma square parameter 197 | # Sigma deals with the contrast component 198 | sigma1_sq = F.conv2d(img1 * img1, window, padding=pad, groups=channels) - mu1_sq 199 | sigma2_sq = F.conv2d(img2 * img2, window, padding=pad, groups=channels) - mu2_sq 200 | sigma12 = F.conv2d(img1 * img2, window, padding=pad, groups=channels) - mu12 201 | 202 | # Some constants for stability 203 | C1 = (0.01 ) ** 2 # NOTE: Removed L from here (ref PT implementation) 204 | C2 = (0.03 ) ** 2 205 | 206 | contrast_metric = (2.0 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2) 207 | contrast_metric = torch.mean(contrast_metric) 208 | 209 | numerator1 = 2 * mu12 + C1 210 | numerator2 = 2 * sigma12 + C2 211 | denominator1 = mu1_sq + mu2_sq + C1 212 | denominator2 = sigma1_sq + sigma2_sq + C2 213 | 214 | ssim_score = (numerator1 * numerator2) / (denominator1 * denominator2) 215 | 216 | if size_average: 217 | ret = ssim_score.mean() 218 | else: 219 | ret = ssim_score.mean(1).mean(1).mean(1) 220 | 221 | if full: 222 | return ret, contrast_metric 223 | 224 | return ret 225 | 226 | 227 | 228 | class 
Discriminator(nn.Module): 229 | def __init__(self, in_channels=3): 230 | super(Discriminator, self).__init__() 231 | 232 | def discriminator_block(in_filters, out_filters, normalization=True): 233 | """Returns downsampling layers of each discriminator block""" 234 | layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)] 235 | if normalization: 236 | layers.append(nn.InstanceNorm2d(out_filters)) 237 | layers.append(nn.LeakyReLU(0.2, inplace=True)) 238 | return layers 239 | 240 | self.model = nn.Sequential( 241 | *discriminator_block(in_channels * 2, 64, normalization=False), 242 | *discriminator_block(64, 128), 243 | *discriminator_block(128, 256), 244 | *discriminator_block(256, 512), 245 | nn.ZeroPad2d((1, 0, 1, 0)), 246 | nn.Conv2d(512, 1, 4, padding=1, bias=False), 247 | nn.Sigmoid() 248 | ) 249 | 250 | def forward(self, img_A, img_B): 251 | # Concatenate image and condition image by channels to produce input 252 | img_input = torch.cat((img_A, img_B), 1) 253 | return self.model(img_input) 254 | 255 | 256 | from scipy.linalg import sqrtm 257 | def fid(img1_vec, img2_vec) -> float: 258 | #calculate mean 259 | mu1, C1 = img1_vec.mean(axis = 0), np.cov(img1_vec, rowvar = False) 260 | mu2, C2 = img2_vec.mean(axis = 0), np.cov(img2_vec, rowvar = False) 261 | 262 | # sum of squared difference 263 | msdiff = np.sum((mu1 - mu2) ** 2) 264 | 265 | # sqrt of products 266 | product_covariance = sqrtm(C1.dot(C2)) 267 | if np.iscomplexobj(product_covariance): 268 | product_covariance = product_covariance.real 269 | 270 | sqrt_product_covariance = np.trace(C1 + C2 - 2 * product_covariance) 271 | #return the result 272 | return msdiff + sqrt_product_covariance 273 | 274 | def calculate_fid(model, images_1, images_2): 275 | preprocess = transforms.Compose([ 276 | transforms.Resize(299), 277 | # transforms.CenterCrop(299), 278 | # transforms.ToTensor(), 279 | ]) 280 | images_1 = preprocess(images_1) 281 | images_2 = preprocess(images_2) 282 | img1_vec = model(preprocess(images_1)).detach().cpu().numpy() 283 | img2_vec = model(preprocess(images_2)).detach().cpu().numpy() 284 | return fid(img1_vec, img2_vec) 285 | 286 | 287 | if __name__ == "__main__": 288 | # batch size 289 | batch_size = 10 290 | 291 | # image size 292 | image_height = 512 293 | image_width = 512 294 | 295 | 296 | #initialize model classes 297 | generator = GeneratorUNet() 298 | discriminator = Discriminator() 299 | 300 | 301 | # check if cuda is avialbale 302 | cuda = True if torch.cuda.is_available() else False 303 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 304 | 305 | generator_file = "saved_models_new_data/generator_131.pth" 306 | # initialize weights if the model is not found in the paths 307 | if os.path.exists(generator_file): 308 | print("Generator Found") 309 | generator.load_state_dict(torch.load(generator_file, map_location = device)) 310 | else: 311 | generator.apply(weights_init_normal) 312 | 313 | 314 | 315 | # to cuda if cuda is avaiable 316 | generator.to(device) 317 | 318 | # Tensor type 319 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor 320 | 321 | 322 | transform = transforms.Compose([ 323 | transforms.ToTensor(), # transform to tensor 324 | transforms.Resize((image_width, image_height)) # Resize the image to constant size 325 | ]) 326 | 327 | # create a dataloader 328 | pair_image_dataloader = img_dataloader.ImageDataset("./data/train/old_images", "./data/train/reconstructed_images", transform) 329 | 330 | dataloader = DataLoader( 331 | pair_image_dataloader, 332 | 
batch_size = 2, 333 | shuffle = True, 334 | ) 335 | 336 | val_image_dataloader = img_dataloader.ImageDataset("./data/val/old_image", "./data/val/reconstructed_image", transform) 337 | val_dataloader = DataLoader( 338 | val_image_dataloader, 339 | batch_size = 2, 340 | shuffle = True 341 | ) 342 | 343 | # FID calculation 344 | # preprocess = transforms.Compose([ 345 | # transforms.Resize(299), 346 | # ]) 347 | 348 | # model = torch.hub.load('pytorch/vision:v0.10.0', 'inception_v3', pretrained=True, progress=False) 349 | # model.eval() 350 | # model.to(device) 351 | 352 | # generator.eval() 353 | # # score for training set 354 | # dataloader_list = next(iter(dataloader)) 355 | # images_1 = preprocess(generator(dataloader_list['A'].to(device))) 356 | # images_2 = preprocess(dataloader_list['B']).to(device) 357 | 358 | # # score of the dataser 359 | # img1_vec = model(preprocess(dataloader_list['A'].to(device))).detach().cpu().numpy() 360 | # img2_vec = model(preprocess(images_2)).detach().cpu().numpy() 361 | # print(f"Dataset score: {fid(np.transpose(img1_vec), np.transpose(img2_vec))}") 362 | 363 | 364 | # img1_vec = model(preprocess(images_1)).detach().cpu().numpy() 365 | # img2_vec = model(preprocess(images_2)).detach().cpu().numpy() 366 | # print(f"Training score: {fid(np.transpose(img1_vec), np.transpose(img2_vec))}") 367 | # torch.cuda.empty_cache() 368 | 369 | # dataloader_list = next(iter(val_dataloader)) 370 | # images_1 = preprocess(generator(dataloader_list['A'].to(device))) 371 | # images_2 = preprocess(dataloader_list['B']).to(device) 372 | # img1_vec = model(preprocess(images_1)).detach().cpu().numpy() 373 | # img2_vec = model(preprocess(images_2)).detach().cpu().numpy() 374 | # print(f"Validation score: {fid(np.transpose(img1_vec), np.transpose(img2_vec))}") 375 | 376 | # with torch.no_grad(): 377 | # for i, batch in enumerate(dataloader): 378 | # generate_image = list(iter(generator(batch['A'].to(device)))) 379 | # original_image_list = list(iter(batch['A'])) 380 | # pil_transform = transforms.Compose([ 381 | # transforms.ToPILImage() 382 | # ]) 383 | 384 | # pil_transform(generate_image[0].detach().cpu()).save(f"outputs/{2 * i}.jpg") 385 | # pil_transform(generate_image[1].detach().cpu()).save(f"outputs/{2 * i + 1}.jpg") 386 | 387 | # pil_transform(original_image_list[0].detach().cpu()).save(f"inputs/{2 * i}.jpg") 388 | # pil_transform(original_image_list[1].detach().cpu()).save(f"inputs/{2 * i + 1}.jpg") 389 | 390 | # if i > 10: 391 | # break 392 | 393 | ssim_value_list = [] 394 | ssim_value_list_dataset = [] 395 | with torch.no_grad(): 396 | for i , batch in enumerate(dataloader): 397 | generated_image = generator(batch['A'].to(device)) 398 | original_image = batch['B'].to(device) 399 | 400 | ssim_value_list.append(ssim(original_image, generated_image, 255).detach().cpu().item()) 401 | ssim_value_list_dataset.append(ssim(batch['A'].to(device), original_image, 255).detach().cpu().item()) 402 | 403 | 404 | if i > 100: 405 | break 406 | 407 | ssim_value_np = np.array(ssim_value_list) 408 | ssim_value_dataset_np = np.array(ssim_value_list_dataset) 409 | 410 | print(f"SSIM Value: {np.mean((ssim_value_list))}") 411 | 412 | 413 | 414 | --------------------------------------------------------------------------------
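Editor's note: the repository dump ends above. The following is a minimal inference sketch, not a file from the repository, showing how the trained pix2pix generator could be used to restore a single photograph. It follows the same steps as the commented-out block at the end of webpage/pix2pix/model.py, but uses the GeneratorUNet and the 256x256 input size from webpage/pix2pix/pix2pix.py together with the saved_models/generator.pth checkpoint that script writes. The file names old_photo.jpg and restored.jpg are hypothetical, and the imports assume the script is run from the repository root so that the webpage package resolves.

# restore_one_photo.py -- illustrative sketch only, not part of this repository
import torch
import torchvision.transforms as transforms
from PIL import Image

from webpage.pix2pix.pix2pix import GeneratorUNet              # U-Net generator defined in pix2pix.py
from webpage.pix2pix.data.image_manipulation import np_to_pil  # CHW float array -> PIL image

device = "cuda" if torch.cuda.is_available() else "cpu"

# load the weights written by pix2pix.py at the end of training
generator = GeneratorUNet()
generator.load_state_dict(torch.load("saved_models/generator.pth", map_location=device))
generator.to(device).eval()

# same preprocessing as the training script: to tensor, then resize to 256x256
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((256, 256)),
])

old_photo = Image.open("old_photo.jpg").convert("RGB")    # hypothetical input path
with torch.no_grad():
    fake = generator(transform(old_photo).unsqueeze(0).to(device))

# np_to_pil clips the generator output to [0, 255] and rebuilds a PIL image
restored = np_to_pil(fake[0].cpu().numpy())
restored = restored.resize(old_photo.size)                # back to the original resolution
restored.save("restored.jpg")                             # hypothetical output path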