├── .gitignore ├── BatchResize.py ├── DataTools ├── DataSets.py ├── FileTools.py ├── Loaders.py ├── Prepro.py ├── README.md └── __init__.py ├── Functions ├── Metrics.py ├── README.md ├── SRMeasure.py ├── TestTools.py ├── __init__.py └── functional.py ├── LogTools ├── README.md ├── __init__.py └── logger.py ├── ModelTools ├── BaseModel.py ├── ModelZoo │ └── __init__.py └── __init__.py ├── README.md ├── TorchNet ├── Calligraphy.py ├── ClassicSRNet.py ├── Deblur.py ├── Denoising.py ├── Discriminators.py ├── FaceDetection.py ├── FaceHallucination.py ├── GANInverse.py ├── GANmodel.py ├── Img2Img.py ├── LSTMBurstImage.py ├── Losses.py ├── OpticalFlow.py ├── Optim.py ├── PGGAN.py ├── README.md ├── RLSR.py ├── SRMini.py ├── StyleGAN.py ├── VGG.py ├── Visualizing.py ├── __init__.py ├── activation.py ├── modules.py ├── tSNE.py └── tools.py ├── __init__.py ├── batch_PSNR_SSIM.py ├── caffemodel_to_t7.lua ├── convert_torch.py └── frameCutter.py /.gitignore: -------------------------------------------------------------------------------- 1 | led / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | .idea/ 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | .static_storage/ 58 | .media/ 59 | local_settings.py 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # pyenv 78 | .python-version 79 | 80 | # celery beat schedule file 81 | celerybeat-schedule 82 | 83 | # SageMath parsed files 84 | *.sage.py 85 | 86 | # Environments 87 | .env 88 | .venv 89 | env/ 90 | venv/ 91 | ENV/ 92 | env.bak/ 93 | venv.bak/ 94 | 95 | # Spyder project settings 96 | .spyderproject 97 | .spyproject 98 | 99 | # Rope project settings 100 | .ropeproject 101 | 102 | # mkdocs documentation 103 | /site 104 | 105 | # mypy 106 | .mypy_cache/ 107 | 108 | # test script 109 | __test_script.py 110 | -------------------------------------------------------------------------------- /BatchResize.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | from PIL import Image 5 | 6 | from DataTools.FileTools import _video_image_file, _image_file 7 | from Functions.functional import resize 8 | 9 | parser = argparse.ArgumentParser() 10 | 11 | parser.add_argument('-i', '--input', type=str, default='/home/sensetime/Documents/NTIRE2019/SRdata/Train_GT_down') 12 | parser.add_argument('-o', '--output', type=str, default='/home/sensetime/Documents/NTIRE2019/SRdata/Train_GT_bicubic') 13 | parser.add_argument('-s', '--scala', type=int, default=4) 14 | parser.add_argument('-v', type=bool, 
default=False) 15 | parser.add_argument('--up', type=bool, default=True) 16 | args = parser.parse_args() 17 | 18 | input_path = os.path.abspath(args.input) 19 | output_path = os.path.abspath(args.output) 20 | if args.up: 21 | scala = 1 / args.scala 22 | else: 23 | scala = args.scala 24 | V = args.v 25 | 26 | 27 | def make_dir(input, output): 28 | dirs = os.listdir(input) 29 | os.mkdir(output) 30 | for i in dirs: 31 | os.mkdir(os.path.join(output, i)) 32 | 33 | 34 | def resize_and_save(file_org, output, scala): 35 | im = Image.open(file_org) 36 | w, h = im.size 37 | im = resize(im, (int(h // scala), int(w // scala))) 38 | name = os.path.split(file_org) 39 | vdir = os.path.split(name[0]) 40 | if V: 41 | save_name = os.path.join(output, vdir[1]) 42 | save_name = os.path.join(save_name, name[1]) 43 | else: 44 | save_name = os.path.join(output, name[1]) 45 | im.save(save_name) 46 | print(save_name) 47 | 48 | if V: 49 | file_list = _video_image_file(input_path) 50 | make_dir(input_path, output_path) 51 | for j in file_list: 52 | for i in j: 53 | resize_and_save(i, output_path, scala) 54 | else: 55 | file_list = _image_file(input_path) 56 | os.mkdir(output_path) 57 | for i in file_list: 58 | resize_and_save(i, output_path, scala) 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | -------------------------------------------------------------------------------- /DataTools/FileTools.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import time 5 | from math import log10 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.autograd import Variable 10 | import torchvision.transforms as transforms 11 | from PIL import Image 12 | 13 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm'] 14 | 15 | 16 | def mkdirs(paths): 17 | if isinstance(paths, list) and not isinstance(paths, str): 18 | for path in paths: 19 | mkdir(path) 20 | else: 21 | mkdir(paths) 22 | 23 | 24 | def mkdir(path): 25 | if not os.path.exists(path): 26 | os.makedirs(path) 27 | 28 | 29 | def _is_image_file(filename): 30 | """ 31 | judge if the file is an image file 32 | :param filename: path 33 | :return: bool of judgement 34 | """ 35 | filename_lower = filename.lower() 36 | return any(filename_lower.endswith(ext) for ext in IMG_EXTENSIONS) 37 | 38 | 39 | def _video_image_file(path): 40 | """ 41 | Data Store Format: 42 | 43 | Data Folder 44 | | 45 | |-Video Folder 46 | | |-Video Frames (images) 47 | | 48 | |-Video Folder 49 | |-Video Frames (images) 50 | 51 | ... 52 | 53 | |-Video Folder 54 | |-Video Frames (images) 55 | 56 | :param path: path to Data Folder, absolute path 57 | :return: 2D list of str, the path is absolute path 58 | [[Video Frames], [Video Frames], ... 
, [Video Frames]] 59 | """ 60 | abs_path = os.path.abspath(path) 61 | video_list = os.listdir(abs_path) 62 | video_list.sort() 63 | frame_list = [None] * len(video_list) 64 | for i in range(len(video_list)): 65 | video_list[i] = os.path.join(path, video_list[i]) 66 | frame_list[i] = os.listdir(video_list[i]) 67 | for j in range(len(os.listdir(video_list[i]))): 68 | frame_list[i][j] = os.path.join(video_list[i], frame_list[i][j]) 69 | frame_list[i].sort() 70 | return frame_list 71 | 72 | 73 | def video_frame_names(path): 74 | video, frame = os.path.split(path) 75 | _, video = os.path.split(video) 76 | return video, frame 77 | 78 | 79 | def sample_info_video(video_frames, time_window, time_stride): 80 | samples = [0] * len(video_frames) 81 | area_sum_samples = [0] * len(video_frames) 82 | for i, video in enumerate(video_frames): 83 | samples[i] = (len(video) - time_window) // time_stride 84 | if i != 0: 85 | area_sum_samples[i] = sum(samples[:i]) 86 | return samples, area_sum_samples 87 | 88 | 89 | def _sample_from_videos_frames(path, time_window, time_stride): 90 | """ 91 | Sample from video frames files 92 | :param path: path to Data Folder, absolute path 93 | :param time_window: number of frames in one sample 94 | :param time_stride: strides when sample frames 95 | :return: 2D list of str, absolute path to each frames 96 | [[Sample Frames], [Sample Frames], ... , [Sample Frames]] 97 | """ 98 | video_frame_list = _video_image_file(path) 99 | sample_list = list() 100 | for video in video_frame_list: 101 | assert isinstance(video, list), "Plz check video_frame_list = _video_image_file(path) should be 2D list" 102 | for i in range(0, len(video), time_stride): 103 | sample = video[i:i + time_window] 104 | if len(sample) != time_window: 105 | break 106 | sample.append(video[i + (time_window // 2)]) 107 | sample_list.append(sample) 108 | return sample_list 109 | 110 | 111 | # TODO: large sample number function 112 | def _sample_from_videos_frames_large(path, time_window, time_stride): 113 | """ 114 | write to a file, return one sample once. 
use pointer 115 | :param path: 116 | :param time_window: 117 | :param time_stride: 118 | :return: 119 | """ 120 | pass 121 | 122 | 123 | def _image_file(path): # TODO: wrong function 124 | """ 125 | return list of images in the path 126 | :param path: path to Data Folder, absolute path 127 | :return: 1D list of image files absolute path 128 | """ 129 | abs_path = os.path.abspath(path) 130 | image_files = os.listdir(abs_path) 131 | for i in range(len(image_files)): 132 | if (not os.path.isdir(image_files[i])) and (_is_image_file(image_files[i])): 133 | image_files[i] = os.path.join(abs_path, image_files[i]) 134 | return image_files 135 | 136 | 137 | def _all_images(path): 138 | """ 139 | return all images in the folder 140 | :param path: path to Data Folder, absolute path 141 | :return: 1D list of image files absolute path 142 | """ 143 | # TODO: Tail Call Elimination 144 | abs_path = os.path.abspath(path) 145 | image_files = list() 146 | for subpath in os.listdir(abs_path): 147 | if os.path.isdir(os.path.join(abs_path, subpath)): 148 | image_files = image_files + _all_images(os.path.join(abs_path, subpath)) 149 | else: 150 | if _is_image_file(subpath): 151 | image_files.append(os.path.join(abs_path, subpath)) 152 | image_files.sort() 153 | return image_files 154 | 155 | 156 | 157 | -------------------------------------------------------------------------------- /DataTools/Loaders.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import Variable 2 | 3 | from .FileTools import _is_image_file 4 | from .Prepro import _id 5 | from ..Functions.functional import * 6 | 7 | 8 | def _add_batch_one(tensor): 9 | """ 10 | Return a tensor with size (1, ) + tensor.size 11 | :param tensor: 2D or 3D tensor 12 | :return: 3D or 4D tensor 13 | """ 14 | return tensor.view((1, ) + tensor.size()) 15 | 16 | 17 | def _remove_batch(tensor): 18 | """ 19 | Return a tensor with size tensor.size()[1:] 20 | :param tensor: 3D or 4D tensor 21 | :return: 2D or 3D tensor 22 | """ 23 | return tensor.view(tensor.size()[1:]) 24 | 25 | 26 | def _add_channel_one(tensor): 27 | return tensor.view(tensor.size()[:1] + (1, ) + tensor.size()[-2:]) 28 | 29 | 30 | def _remove_channel(tensor): 31 | return tensor.view((tensor.size()[0] * tensor.size()[1],) + tensor.size()[-2:]) 32 | 33 | 34 | def PIL2Tensor(img): 35 | """ 36 | Converts a PIL Image or numpy.ndarray (H, W, C) in the range [0, 255] to a torch.FloatTensor of shape (1, C, H, W) in the range [0.0, 1.0]. 37 | 38 | :param img: PIL.Image or numpy.ndarray (H, W, C) in the range [0, 255] 39 | :return: 4D tensor with size [1, C, H, W] in range [0, 1.] 40 | """ 41 | return _add_batch_one(to_tensor(img)) 42 | 43 | 44 | def Tensor2PIL(tensor, mode=None): 45 | """ 46 | :param tensor: 4D tensor with size [1, C, H, W] in range [0, 1.] 47 | :param mode: (`PIL.Image mode`_): color space and pixel depth of input data (optional). 
48 | PIL.Image mode: http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#modes 49 | :return: PIL.Image 50 | """ 51 | if len(tensor.size()) == 3: 52 | return to_pil_image(tensor, mode=mode) 53 | elif len(tensor.size()) == 4: 54 | return to_pil_image(_remove_batch(tensor)) 55 | 56 | 57 | def PIL2VAR(img, norm_function=_id, volatile=False): 58 | """ 59 | Convert a PIL.Image to Variable directly 60 | :param img: PIL.Image 61 | :param norm_function: The normalization to the tensor 62 | :return: Variable 63 | """ 64 | return Variable(norm_function(PIL2Tensor(img)), volatile=volatile) 65 | 66 | 67 | def VAR2PIL(img, non_norm_function=_id): 68 | """ 69 | Convert a Variable to PIL.Image 70 | :param img: Variable 71 | :param non_norm_function: according to the normalization function, the `inverse` normalization 72 | :return: PIL.Image 73 | """ 74 | return Tensor2PIL(non_norm_function(img.data)) 75 | 76 | 77 | def pil_loader(path, mode='RGB'): 78 | """ 79 | open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 80 | :param path: image path 81 | :return: PIL.Image 82 | """ 83 | assert _is_image_file(path), "%s is not an image" % path 84 | with open(path, 'rb') as f: 85 | with Image.open(f) as img: 86 | return img.convert(mode) 87 | 88 | 89 | def load_to_tensor(path, mode='RGB'): 90 | """ 91 | Load image to tensor 92 | :param path: image path 93 | :param mode: 'Y' returns 1 channel tensor, 'RGB' returns 3 channels, 'RGBA' returns 4 channels, 'YCbCr' returns 3 channels 94 | :return: 3D tensor 95 | """ 96 | if mode != 'Y': 97 | return to_tensor(pil_loader(path, mode=mode)) 98 | else: 99 | return to_tensor(pil_loader(path, mode='YCbCr'))[:1] 100 | 101 | 102 | def YChannel(img_tensor_rgb, mode='yuv'): 103 | """ 104 | :param img_tensor_rgb: RGB 4D tensor, [0, 1] 105 | :param mode: yuv, ycbcr, edsr 106 | :return: Y 107 | """ 108 | if mode == 'yuv': 109 | y = img_tensor_rgb[:, :1, :, :] * 0.299 + img_tensor_rgb[:, 1:2, :, :] * 0.587 + img_tensor_rgb[:, 2:3, :, :] * 0.114 110 | return y 111 | elif mode == 'ycbcr': 112 | y = img_tensor_rgb[:, :1, :, :] * 0.257 + img_tensor_rgb[:, 1:2, :, :] * 0.504 + img_tensor_rgb[:, 2:3, :, :] * 0.098 + 0.0625 113 | return y 114 | elif mode == 'edsr': 115 | y = img_tensor_rgb[:, :1, :, :] * 0.257 + img_tensor_rgb[:, 1:2, :, :] * 0.504 + img_tensor_rgb[:, 2:3, :, :] * 0.098 116 | return y 117 | else: 118 | raise Exception('Check para mode') 119 | -------------------------------------------------------------------------------- /DataTools/Prepro.py: -------------------------------------------------------------------------------- 1 | from ..Functions.functional import * 2 | 3 | 4 | def _id(x): 5 | """ 6 | return x 7 | :param x: 8 | :return: 9 | """ 10 | return x 11 | 12 | 13 | def _sigmoid_to_tanh(x): 14 | """ 15 | range [0, 1] to range [-1, 1] 16 | :param x: tensor type 17 | :return: tensor 18 | """ 19 | return (x - 0.5) * 2. 20 | 21 | 22 | def _tanh_to_sigmoid(x): 23 | """ 24 | range [-1, 1] to range [0, 1] 25 | :param x: 26 | :return: 27 | """ 28 | return x * 0.5 + 0.5 29 | 30 | 31 | def _255_to_tanh(x): 32 | """ 33 | range [0, 255] to range [-1, 1] 34 | :param x: 35 | :return: 36 | """ 37 | return (x - 127.5) / 127.5 38 | 39 | 40 | def _tanh_to_255(x): 41 | """ 42 | range [-1. 
1] to range [0, 255] 43 | :param x: 44 | :return: 45 | """ 46 | return x * 127.5 + 127.5 47 | 48 | 49 | # TODO: _sigmoid_to_255(x), _255_to_sigmoid(x) 50 | # def _sigmoid_to_255(x): 51 | # def _255_to_sigmoid(x): 52 | 53 | 54 | def random_pre_process(img): 55 | """ 56 | Randomly pre-process the input image 57 | :param img: PIL.Image 58 | :return: PIL.Image 59 | """ 60 | if bool(random.getrandbits(1)): 61 | img = hflip(img) 62 | if bool(random.getrandbits(1)): 63 | img = vflip(img) 64 | angle = random.randrange(-15, 15) 65 | return rotate(img, angle) 66 | 67 | 68 | 69 | -------------------------------------------------------------------------------- /DataTools/README.md: -------------------------------------------------------------------------------- 1 | # DataTools 2 | ###### JasonGUTU 3 | This package provides useful functions and Dataset classes for image processing and low-level computer vision. 4 | ### Structure 5 | - `DataSets` contains Dataset classes, all subclasses of torch.utils.data.Dataset. 6 | - `FileTools` contains tools for file management 7 | - `Loaders` contains image loaders 8 | - `Prepro` contains self-customized pre-processing functions or classes 9 | ### Docs 10 | #### DataSets.py 11 | All the classes inherited from torch.utils.data.Dataset are self-customized Dataset classes. 12 | ```[Python] 13 | # `TestDataset` is a Dataset class 14 | # Instantiation 15 | dataset = TestDataset(*args, **kwargs) 16 | # Use index to retrieve 17 | first_data = dataset[0] 18 | # Number of samples 19 | length = len(dataset) 20 | ``` 21 | In this file, the Datasets are: 22 | ``` 23 | class SRDataSet(torch.utils.data.Dataset) 24 | """ 25 | :param data_path: Path to data root 26 | :param lr_patch_size: the low-resolution patch size; by default, the patch is square 27 | :param scala: SR scale, default is 4 28 | :param interp: interpolation for resize, default is Image.BICUBIC, optional [Image.BILINEAR, Image.BICUBIC] 29 | :param mode: 'RGB' or 'Y' 30 | :param sub_dir: if True, all the images in the `data_path` directory AND its child directories will be used 31 | :param prepro: function applied to the ``PIL.Image`` before crop and resize 32 | """ 33 | ``` 34 | This Dataset is for loading small images like image-91 and image-191. 35 | The images are small, so loading them directly has little effect on performance. 36 | In this dataset, every image is returned once per epoch. 37 | Each time an image is returned, it is pre-processed and then a patch is randomly cropped. 38 | If the patch size is bigger than the image size, the image is resized to a croppable size before the random crop. 39 | 40 | ``` 41 | class SRDataLarge(torch.utils.data.Dataset) 42 | """ 43 | :param data_path: Path to data root 44 | :param lr_patch_size: the low-resolution patch size; by default, the patch is square 45 | :param scala: SR scale, default is 4 46 | :param interp: interpolation for resize, default is Image.BICUBIC, optional [Image.BILINEAR, Image.BICUBIC] 47 | :param mode: 'RGB' or 'Y' 48 | :param sub_dir: if True, all the images in the `data_path` directory AND its child directories will be used 49 | :param prepro: function applied to the ``PIL.Image`` before crop and resize 50 | :param buffer: how many patches are cut from one image 51 | """ 52 | ``` 53 | This Dataset is for loading large images like DIV2K. 54 | The images are large, so loading them directly hurts performance. 55 | In this dataset, every image yields `buffer` patches per epoch.
56 | Every time one image is return will be pre-processing and then random crop a patch. 57 | 58 | #### FileTools.py 59 | ```[Python] 60 | def _is_image_file(filename): 61 | """ 62 | judge if the file is an image file 63 | :param filename: path 64 | :return: bool of judgement 65 | """ 66 | ``` 67 | ```[Python] 68 | def _video_image_file(path): 69 | """ 70 | Data Store Format: 71 | 72 | Data Folder 73 | | 74 | |-Video Folder 75 | | |-Video Frames (images) 76 | | 77 | |-Video Folder 78 | |-Video Frames (images) 79 | 80 | ... 81 | 82 | |-Video Folder 83 | |-Video Frames (images) 84 | 85 | :param path: path to Data Folder, absolute path 86 | :return: 2D list of str, the path is absolute path 87 | [[Video Frames], [Video Frames], ... , [Video Frames]] 88 | """ 89 | ``` 90 | ```[Python] 91 | def _sample_from_videos_frames(path, time_window, time_stride): 92 | """ 93 | Sample from video frames files 94 | :param path: path to Data Folder, absolute path 95 | :param time_window: number of frames in one sample 96 | :param time_stride: strides when sample frames 97 | :return: 2D list of str, absolute path to each frames 98 | [[Sample Frames], [Sample Frames], ... , [Sample Frames]] 99 | """ 100 | ``` 101 | ```[Python] 102 | def _image_file(path): 103 | """ 104 | return list of images in the path 105 | :param path: path to Data Folder, absolute path 106 | :return: 1D list of image files absolute path 107 | """ 108 | ``` 109 | ```[Python] 110 | def _all_images(path): 111 | """ 112 | return all images in the folder, include child folder. 113 | :param path: path to Data Folder, absolute path 114 | :return: 1D list of image files absolute path 115 | """ 116 | ``` 117 | 118 | #### Loaders.py 119 | 120 | ```[Python] 121 | def _add_batch_one(tensor): 122 | """ 123 | Return a tensor with size (1, ) + tensor.size 124 | :param tensor: 2D or 3D tensor 125 | :return: 3D or 4D tensor 126 | """ 127 | ``` 128 | ```[Python] 129 | def _remove_batch(tensor): 130 | """ 131 | Return a tensor with size tensor.size()[1:] 132 | :param tensor: 3D or 4D tensor 133 | :return: 2D or 3D tensor 134 | """ 135 | ``` 136 | ```[Python] 137 | def PIL2Tensor(img): 138 | """ 139 | Converts a PIL Image or numpy.ndarray (H, W, C) in the range [0, 255] to a torch.FloatTensor of shape (1, C, H, W) in the range [0.0, 1.0]. 140 | 141 | :param img: PIL.Image or numpy.ndarray (H, W, C) in the range [0, 255] 142 | :return: 4D tensor with size [1, C, H, W] in range [0, 1.] 143 | """ 144 | ``` 145 | ```[Python] 146 | def Tensor2PIL(tensor, mode=None): 147 | """ 148 | Convert a 4D tensor with size [1, C, H, W] to PIL Image 149 | :param tensor: 4D tensor with size [1, C, H, W] in range [0, 1.] 150 | :param mode: (`PIL.Image mode`_): color space and pixel depth of input data (optional). 
151 | PIL.Image mode: http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#modes 152 | :return: PIL.Image 153 | """ 154 | ``` 155 | ```[Python] 156 | def PIL2VAR(img, norm_function=_id): 157 | """ 158 | Convert a PIL.Image to Variable directly, add batch dimension (can be use to test directly) 159 | :param img: PIL.Image 160 | :param norm_function: The normalization to the tensor 161 | :return: Variable 162 | """ 163 | ``` 164 | ```[Python] 165 | def VAR2PIL(img, non_norm_function=_id): 166 | """ 167 | Convert a Variable to PIL.Image 168 | :param img: Variable 169 | :param non_norm_function: according to the normalization function, the `inverse` normalization 170 | :return: PIL.Image 171 | """ 172 | ``` 173 | ```[Python] 174 | def pil_loader(path, mode='RGB'): 175 | """ 176 | open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 177 | :param path: image path 178 | :return: PIL.Image 179 | """ 180 | ``` 181 | ```[Python] 182 | def load_to_tensor(path, mode='RGB'): 183 | """ 184 | Load image to tensor, 3D tensor 185 | :param path: image path 186 | :param mode: 'Y' returns 1 channel tensor, 'RGB' returns 3 channels, 'RGBA' returns 4 channels, 'YCbCr' returns 3 channels 187 | :return: 3D tensor 188 | """ 189 | ``` 190 | 191 | 192 | 193 | -------------------------------------------------------------------------------- /DataTools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JasonGUTU/TorchTools/b7025ab8469eec87830b109a344a7af551a5d464/DataTools/__init__.py -------------------------------------------------------------------------------- /Functions/Metrics.py: -------------------------------------------------------------------------------- 1 | try: 2 | from math import log10 3 | except ImportError: 4 | from math import log 5 | def log10(x): 6 | return log(x) / log(10.) 7 | 8 | import torch 9 | 10 | from .functional import to_tensor 11 | 12 | 13 | def mse(x, y): 14 | """ 15 | MSE Error 16 | :param x: tensor 17 | :param y: tensor 18 | :return: float 19 | """ 20 | diff = x - y 21 | diff = diff * diff 22 | return torch.mean(diff) 23 | 24 | 25 | def psnr(x, y, peak=1.): 26 | """ 27 | psnr from tensor 28 | :param x: tensor 29 | :param y: tensor 30 | :return: float (mse, psnr) 31 | """ 32 | _mse = mse(x, y) 33 | return _mse, 10 * log10((peak ** 2) / _mse) 34 | 35 | 36 | def PSNR(x, y): 37 | """ 38 | PSNR from PIL.Image 39 | :param x: PIL.Image 40 | :param y: PIL.Image 41 | :return: float (mse, psnr) 42 | """ 43 | return psnr(to_tensor(x), to_tensor(y), peak=1.) 44 | 45 | 46 | 47 | -------------------------------------------------------------------------------- /Functions/README.md: -------------------------------------------------------------------------------- 1 | # Functions 2 | ###### JasonGUTU 3 | This package provides some useful functions and DataSet classes for Image processing and Low-level Computer Vision. 4 | ### Structure 5 | - `Metrics` contains some useful metrics 6 | - `functional` contains tools for PIL.Image transfer, similar to torchvision.transforms.functional but you can use it in GPU server. 
7 | - `TestTools` contains tools for easy testing -------------------------------------------------------------------------------- /Functions/SRMeasure.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import math 3 | from scipy.ndimage import gaussian_filter 4 | from numpy.lib.stride_tricks import as_strided as ast 5 | import numpy.linalg 6 | from scipy.special import gamma 7 | import scipy.misc 8 | import scipy.io 9 | import skimage.transform 10 | 11 | 12 | 13 | def psnr(img1, img2): 14 | mse = numpy.mean( (img1 - img2) ** 2 ) 15 | if mse == 0: 16 | return 100 17 | PIXEL_MAX = 255.0 18 | return 20 * math.log10(PIXEL_MAX / math.sqrt(mse)) 19 | 20 | """ 21 | Hat tip: http://stackoverflow.com/a/5078155/1828289 22 | """ 23 | def block_view(A, block=(3, 3)): 24 | """Provide a 2D block view to 2D array. No error checking made. 25 | Therefore meaningful (as implemented) only for blocks strictly 26 | compatible with the shape of A.""" 27 | # simple shape and strides computations may seem at first strange 28 | # unless one is able to recognize the 'tuple additions' involved ;-) 29 | shape = (A.shape[0] // block[0], A.shape[1] // block[1]) + block # integer division: as_strided needs an integral shape 30 | strides = (block[0] * A.strides[0], block[1] * A.strides[1]) + A.strides 31 | return ast(A, shape=shape, strides=strides) 32 | 33 | 34 | def ssim(img1, img2, C1=0.01**2, C2=0.03**2): 35 | 36 | bimg1 = block_view(img1, (4,4)) 37 | bimg2 = block_view(img2, (4,4)) 38 | s1 = numpy.sum(bimg1, (-1, -2)) 39 | s2 = numpy.sum(bimg2, (-1, -2)) 40 | ss = numpy.sum(bimg1*bimg1, (-1, -2)) + numpy.sum(bimg2*bimg2, (-1, -2)) 41 | s12 = numpy.sum(bimg1*bimg2, (-1, -2)) 42 | 43 | vari = ss - s1*s1 - s2*s2 44 | covar = s12 - s1*s2 45 | 46 | ssim_map = (2*s1*s2 + C1) * (2*covar + C2) / ((s1*s1 + s2*s2 + C1) * (vari + C2)) 47 | return numpy.mean(ssim_map) 48 | 49 | # FIXME there seems to be a problem with this code 50 | def ssim_exact(img1, img2, sd=1.5, C1=0.01**2, C2=0.03**2): 51 | 52 | mu1 = gaussian_filter(img1, sd) 53 | mu2 = gaussian_filter(img2, sd) 54 | mu1_sq = mu1 * mu1 55 | mu2_sq = mu2 * mu2 56 | mu1_mu2 = mu1 * mu2 57 | sigma1_sq = gaussian_filter(img1 * img1, sd) - mu1_sq 58 | sigma2_sq = gaussian_filter(img2 * img2, sd) - mu2_sq 59 | sigma12 = gaussian_filter(img1 * img2, sd) - mu1_mu2 60 | 61 | ssim_num = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) 62 | 63 | ssim_den = ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 64 | 65 | ssim_map = ssim_num / ssim_den 66 | return numpy.mean(ssim_map) 67 | 68 | """ 69 | Generalized Gaussian distribution estimation. 70 | Cite: 71 | Dominguez-Molina, J. Armando, et al. "A practical procedure to estimate the shape parameter in the generalized Gaussian distribution.", 72 | available through http://www.cimat.mx/reportes/enlinea/I-01-18_eng.pdf 1 (2001).
73 | """ 74 | 75 | """ 76 | Generalized Gaussian ratio function 77 | Cite: Dominguez-Molina 2001, pg 7, eq (8) 78 | """ 79 | 80 | 81 | def generalized_gaussian_ratio(alpha): 82 | return (gamma(2.0 / alpha) ** 2) / (gamma(1.0 / alpha) * gamma(3.0 / alpha)) 83 | 84 | 85 | """ 86 | Generalized Gaussian ratio function inverse (numerical approximation) 87 | Cite: Dominguez-Molina 2001, pg 13 88 | """ 89 | 90 | 91 | def generalized_gaussian_ratio_inverse(k): 92 | a1 = -0.535707356 93 | a2 = 1.168939911 94 | a3 = -0.1516189217 95 | b1 = 0.9694429 96 | b2 = 0.8727534 97 | b3 = 0.07350824 98 | c1 = 0.3655157 99 | c2 = 0.6723532 100 | c3 = 0.033834 101 | 102 | if k < 0.131246: 103 | return 2 * math.log(27.0 / 16.0) / math.log(3.0 / (4 * k ** 2)) 104 | elif k < 0.448994: 105 | return (1 / (2 * a1)) * (-a2 + math.sqrt(a2 ** 2 - 4 * a1 * a3 + 4 * a1 * k)) 106 | elif k < 0.671256: 107 | return (1 / (2 * b3 * k)) * (b1 - b2 * k - math.sqrt((b1 - b2 * k) ** 2 - 4 * b3 * (k ** 2))) 108 | elif k < 0.75: 109 | # print "%f %f %f" % (k, ((3-4*k)/(4*c1)), c2**2 + 4*c3*log((3-4*k)/(4*c1)) ) 110 | return (1 / (2 * c3)) * (c2 - math.sqrt(c2 ** 2 + 4 * c3 * math.log((3 - 4 * k) / (4 * c1)))) 111 | else: 112 | print("warning: GGRF inverse of %f is not defined" % (k)) 113 | return numpy.nan 114 | 115 | 116 | """ 117 | Estimate the parameters of an asymmetric generalized Gaussian distribution 118 | """ 119 | 120 | 121 | def estimate_aggd_params(x): 122 | x_left = x[x < 0] 123 | x_right = x[x >= 0] 124 | stddev_left = math.sqrt((1.0 / (x_left.size - 1)) * numpy.sum(x_left ** 2)) 125 | stddev_right = math.sqrt((1.0 / (x_right.size - 1)) * numpy.sum(x_right ** 2)) 126 | if stddev_right == 0: 127 | return 1, 0, 0 # TODO check this 128 | r_hat = numpy.mean(numpy.abs(x)) ** 2 / numpy.mean(x ** 2) 129 | y_hat = stddev_left / stddev_right 130 | R_hat = r_hat * (y_hat ** 3 + 1) * (y_hat + 1) / ((y_hat ** 2 + 1) ** 2) 131 | alpha = generalized_gaussian_ratio_inverse(R_hat) 132 | beta_left = stddev_left * math.sqrt(gamma(3.0 / alpha) / gamma(1.0 / alpha)) 133 | beta_right = stddev_right * math.sqrt(gamma(3.0 / alpha) / gamma(1.0 / alpha)) 134 | return alpha, beta_left, beta_right 135 | 136 | 137 | def compute_features(img_norm): 138 | features = [] 139 | alpha, beta_left, beta_right = estimate_aggd_params(img_norm) 140 | 141 | features.extend([alpha, (beta_left + beta_right) / 2]) 142 | 143 | for x_shift, y_shift in ((0, 1), (1, 0), (1, 1), (1, -1)): 144 | img_pair_products = img_norm * numpy.roll(numpy.roll(img_norm, y_shift, axis=0), x_shift, axis=1) 145 | alpha, beta_left, beta_right = estimate_aggd_params(img_pair_products) 146 | eta = (beta_right - beta_left) * (gamma(2.0 / alpha) / gamma(1.0 / alpha)) 147 | features.extend([alpha, eta, beta_left, beta_right]) 148 | 149 | return features 150 | 151 | 152 | def normalize_image(img, sigma=7 / 6): 153 | mu = gaussian_filter(img, sigma, mode='nearest') 154 | mu_sq = mu * mu 155 | sigma = numpy.sqrt(numpy.abs(gaussian_filter(img * img, sigma, mode='nearest') - mu_sq)) 156 | img_norm = (img - mu) / (sigma + 1) 157 | return img_norm 158 | 159 | 160 | def niqe(img): 161 | model_mat = scipy.io.loadmat('modelparameters.mat') 162 | model_mu = model_mat['mu_prisparam'] 163 | model_cov = model_mat['cov_prisparam'] 164 | 165 | features = None 166 | img_scaled = img 167 | for scale in [1, 2]: 168 | 169 | if scale != 1: 170 | img_scaled = skimage.transform.rescale(img, 1 / scale) 171 | # img_scaled = scipy.misc.imresize(img_norm, 0.5) 172 | 173 | # print img_scaled 174 | img_norm = 
normalize_image(img_scaled) 175 | 176 | scale_features = [] 177 | block_size = 96 // scale 178 | for block_col in range(img_norm.shape[0] // block_size): 179 | for block_row in range(img_norm.shape[1] // block_size): 180 | block_features = compute_features(img_norm[block_col * block_size:(block_col + 1) * block_size, 181 | block_row * block_size:(block_row + 1) * block_size]) 182 | scale_features.append(block_features) 183 | # print "len(scale_features)=%f" %(len(scale_features)) 184 | if features == None: 185 | features = numpy.vstack(scale_features) 186 | # print features.shape 187 | else: 188 | features = numpy.hstack([features, numpy.vstack(scale_features)]) 189 | # print features.shape 190 | 191 | features_mu = numpy.mean(features, axis=0) 192 | features_cov = numpy.cov(features.T) 193 | 194 | pseudoinv_of_avg_cov = numpy.linalg.pinv((model_cov + features_cov) / 2) 195 | niqe_quality = math.sqrt((model_mu - features_mu).dot(pseudoinv_of_avg_cov.dot((model_mu - features_mu).T))) 196 | 197 | return niqe_quality -------------------------------------------------------------------------------- /Functions/TestTools.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import skimage.io as io 5 | from scipy import misc 6 | from PIL import Image, ImageDraw 7 | 8 | import torch 9 | from torch.autograd import Variable 10 | 11 | from ..DataTools.FileTools import _image_file 12 | from ..DataTools.Loaders import pil_loader, load_to_tensor 13 | 14 | 15 | RY = 15 16 | YG = 6 17 | GC = 4 18 | CB = 11 19 | BM = 13 20 | MR = 6 21 | ncols = sum([RY, YG, GC, CB, BM, MR]) 22 | 23 | 24 | LEFT_EYE = [36, 37, 38, 39, 40, 41] 25 | LEFT_EYEBROW = [17, 18, 19, 20, 21] 26 | RIGHT_EYE = [42, 43, 44, 45, 46, 47] 27 | RIGHT_EYEBROW = [22, 23, 24, 25, 26] 28 | MOUTH = [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67] 29 | LEFT_MOUTH = [48, 60] 30 | RIGHT_MOUTH = [54, 64] 31 | LEFT_MOST = [0, 1, 2] 32 | RIGHT_MOST = [14, 15, 16] 33 | TOP_MOST = [18, 19, 20, 23, 24, 25] 34 | DOWN_MOST = [7, 8, 9] 35 | NOSE_TIP = [31, 32, 33, 34, 35] 36 | 37 | 38 | def _id(x): 39 | """ 40 | return x 41 | :param x: 42 | :return: 43 | """ 44 | return x 45 | 46 | 47 | def _add_batch_one(tensor): 48 | """ 49 | Return a tensor with size (1, ) + tensor.size 50 | :param tensor: 2D or 3D tensor 51 | :return: 3D or 4D tensor 52 | """ 53 | return tensor.view((1, ) + tensor.size()) 54 | 55 | 56 | def _remove_batch(tensor): 57 | """ 58 | Return a tensor with size tensor.size()[1:] 59 | :param tensor: 3D or 4D tensor 60 | :return: 2D or 3D tensor 61 | """ 62 | return tensor.view(tensor.size()[1:]) 63 | 64 | 65 | def _sigmoid_to_tanh(x): 66 | """ 67 | range [0, 1] to range [-1, 1] 68 | :param x: tensor type 69 | :return: tensor 70 | """ 71 | return (x - 0.5) * 2. 
72 | 73 | 74 | def _tanh_to_sigmoid(x): 75 | """ 76 | range [-1, 1] to range [0, 1] 77 | :param x: 78 | :return: 79 | """ 80 | return x * 0.5 + 0.5 81 | 82 | 83 | def test_pool(path, cuda=True, mode='Y', normalization_func=_sigmoid_to_tanh): 84 | img_path_list = _image_file(path) 85 | test_img_pool = list() 86 | for i in range(len(img_path_list)): 87 | pic = Variable(_add_batch_one(normalization_func(load_to_tensor(img_path_list[i], mode=mode)))) 88 | pic.volatile = True 89 | if cuda: 90 | test_img_pool.append(pic.cuda()) 91 | else: 92 | test_img_pool.append(pic) 93 | return test_img_pool 94 | 95 | 96 | def high_point_to_low_point(point_set, size_h, size_l): 97 | """ 98 | The three parameters will be (w, h) 99 | :param point: the point location 100 | :param size_h: high resolution image size 101 | :param size_l: low resolution image size 102 | :return: 103 | """ 104 | h_1, h_2 = size_h 105 | l_1, l_2 = size_l 106 | lr_points = list() 107 | for point in point_set: 108 | x, y = point 109 | lr_points.append([int(round(x * (l_1 / h_1))), int(round(y * (l_2 / h_2)))]) 110 | return np.array(lr_points) 111 | 112 | 113 | def _center_point(point_1, point_2): 114 | return (np.array(point_1, dtype=np.float32) + np.array(point_2, dtype=np.float32)) // 2 115 | 116 | 117 | def _euclid_distance(point_1, point_2): 118 | return np.sqrt(np.sum((point_1 - point_2) ** 2)) 119 | 120 | 121 | def _angle_point(center, point_1, point_2): 122 | a = _euclid_distance(point_1, point_2) 123 | b = _euclid_distance(center, point_2) 124 | c = _euclid_distance(center, point_1) 125 | return np.arccos((b ** 2 + c ** 2 - a ** 2) / (2 * b * c)) 126 | 127 | 128 | def _rotate_affine_matrix(center, theta, high): 129 | """ 130 | :param center: rotate center 131 | :param theta: clock wise, rad 132 | :return: 133 | """ 134 | x, y = center 135 | # y = high - y_p 136 | sin = np.sin(theta) 137 | cos = np.cos(theta) 138 | matrix = np.array( 139 | [[cos, -sin, x - x * cos + y * sin], 140 | [sin, cos, y - x * sin - y * cos]] 141 | ) 142 | return torch.from_numpy(matrix) 143 | 144 | 145 | def get_landmarks(img, detector, predictor): 146 | """ 147 | Return landmark martix 148 | :param img: img read by skimage.io.imread 149 | :param detector: dlib.get_frontal_face_detector() instance 150 | :param predictor: dlib.shape_predictor('..?./shape_predictor_68_face_landmarks.dat') 151 | :return: landmark matrix 152 | """ 153 | rects = detector(img, 1) 154 | return np.array([[p.x, p.y] for p in predictor(img, rects[0]).parts()]) 155 | 156 | 157 | def _centroid(landmarks, point_list): 158 | """ 159 | Return the centroid of given points 160 | :param point_list: point list 161 | :return: array(centroid) (x, y) 162 | """ 163 | x = np.zeros((len(point_list),)) 164 | y = np.zeros((len(point_list),)) 165 | for i, p in enumerate(point_list): 166 | x[i] = landmarks[p][0] 167 | y[i] = landmarks[p][1] 168 | x_mean = int(x.mean()) 169 | y_mean = int(y.mean()) 170 | return np.array([x_mean, y_mean]) 171 | 172 | 173 | def make_color_wheel(): 174 | """A color wheel or color circle is an abstract illustrative 175 | organization of color hues around a circle. 176 | This is for making output image easy to distinguish every 177 | part. 178 | """ 179 | # These are chosen based on perceptual similarity 180 | # e.g. 
one can distinguish more shades between red and yellow 181 | # than between yellow and green 182 | 183 | if ncols > 60: 184 | exit(1) 185 | 186 | color_wheel = np.zeros((ncols, 3)) 187 | i = 0 188 | # RY: (255, 255*i/RY, 0) 189 | color_wheel[i: i + RY, 0] = 255 190 | color_wheel[i: i + RY, 1] = np.arange(RY) * 255 / RY 191 | i += RY 192 | # YG: (255-255*i/YG, 255, 0) 193 | color_wheel[i: i + YG, 0] = 255 - np.arange(YG) * 255 / YG 194 | color_wheel[i: i + YG, 1] = 255 195 | i += YG 196 | # GC: (0, 255, 255*i/GC) 197 | color_wheel[i: i + GC, 1] = 255 198 | color_wheel[i: i + GC, 2] = np.arange(GC) * 255 / GC 199 | i += GC 200 | # CB: (0, 255-255*i/CB, 255) 201 | color_wheel[i: i + CB, 1] = 255 - np.arange(CB) * 255 / CB 202 | color_wheel[i: i + CB, 2] = 255 203 | i += CB 204 | # BM: (255*i/BM, 0, 255) 205 | color_wheel[i: i + BM, 0] = np.arange(BM) * 255 / BM 206 | color_wheel[i: i + BM, 2] = 255 207 | i += BM 208 | # MR: (255, 0, 255-255*i/MR) 209 | color_wheel[i: i + MR, 0] = 255 210 | color_wheel[i: i + MR, 2] = 255 - np.arange(MR) * 255 / MR 211 | 212 | return color_wheel 213 | 214 | 215 | def mapping_to_indices(coords): 216 | """numpy advanced indexing is like x[, , ...] 217 | this function convert coords of shape (h, w, 2) to advanced indices 218 | 219 | # Arguments 220 | coords: shape of (h, w) 221 | # Returns 222 | indices: [, , ...] 223 | """ 224 | h, w = coords.shape[:2] 225 | indices_axis_2 = list(np.tile(coords[:, :, 0].reshape(-1), 2)) 226 | indices_axis_3 = list(np.tile(coords[:, :, 1].reshape(-1), 1)) 227 | return [indices_axis_2, indices_axis_3] 228 | 229 | 230 | def flow_to_color(flow, normalized=True): 231 | """ 232 | # Arguments 233 | flow: (h, w, 2) flow[u, v] is (y_offset, x_offset) 234 | normalized: if is True, element in flow is between -1 and 1, which 235 | present to 236 | """ 237 | color_wheel = make_color_wheel() # (55, 3) 238 | h, w = flow.shape[:2] 239 | rad = np.sum(flow ** 2, axis=2) ** 0.5 # shape: (h, w) 240 | rad = np.concatenate([rad.reshape(h, w, 1)] * 3, axis=-1) 241 | a = np.arctan2(-flow[:, :, 1], -flow[:, :, 0]) / np.pi # shape: (h, w) range: (-1, 1) 242 | fk = (a + 1.0) / 2.0 * (ncols - 1) # -1~1 mapped to 1~ncols 243 | k0 = np.floor(fk).astype(np.int) 244 | k1 = (k0 + 1) % ncols 245 | f = (fk - k0).reshape((-1, 1)) 246 | f = np.concatenate([f] * 3, axis=1) 247 | color0 = color_wheel[list(k0.reshape(-1))] / 255.0 248 | color1 = color_wheel[list(k1.reshape(-1))] / 255.0 249 | res = (1 - f) * color0 + f * color1 250 | res = np.reshape(res, (h, w, 3)) # flatten to h*w 251 | mask = rad <= 1 252 | res[mask] = (1 - rad * (1 - res))[mask] # increase saturation with radius 253 | res[~mask] *= .75 # out of range 254 | 255 | return res 256 | 257 | 258 | def DrawBoxAndCrop(img, leftup_point, box_size, upsample=1, line_width=2, color='red', resample=Image.NEAREST): 259 | leftup_x, leftup_y = leftup_point 260 | if isinstance(box_size, int): 261 | crop_x = box_size 262 | crop_y = box_size 263 | else: 264 | crop_x, crop_y = box_size 265 | imgc = img.copy() 266 | img_crop = img.crop((leftup_x, leftup_y, leftup_x + crop_x, leftup_y + crop_y)) 267 | draw = ImageDraw.Draw(imgc) 268 | draw.line([(leftup_x, leftup_y),(leftup_x+crop_x, leftup_y),(leftup_x+crop_x, leftup_y+crop_y),(leftup_x, leftup_y+crop_y),(leftup_x, leftup_y)], fill=color, width=line_width) 269 | img_crop = img_crop.resize((crop_x*upsample, crop_y*upsample), resample=resample) 270 | return imgc, img_crop -------------------------------------------------------------------------------- 
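A minimal usage sketch for `DrawBoxAndCrop` from `Functions/TestTools.py` above: it returns a copy of the image with a colored box drawn on it plus the boxed region cropped and upscaled, which is convenient for zoom-in comparison figures. The import path assumes TorchTools sits at the project root as described in the top-level README; the file names and box coordinates are placeholders, not part of the repository.
```[Python]
from PIL import Image

from TorchTools.Functions.TestTools import DrawBoxAndCrop

img = Image.open('sr_result.png')  # placeholder input image

# Draw a 64x64 red box with its top-left corner at (50, 80) and also return
# the crop enlarged 4x (64x64 -> 256x256, nearest-neighbour by default).
boxed, crop = DrawBoxAndCrop(img, (50, 80), 64, upsample=4, line_width=2, color='red')

boxed.save('sr_result_boxed.png')    # full image with the red box drawn on it
crop.save('sr_result_crop_x4.png')   # upscaled crop of the boxed region
```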
/Functions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JasonGUTU/TorchTools/b7025ab8469eec87830b109a344a7af551a5d464/Functions/__init__.py -------------------------------------------------------------------------------- /LogTools/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JasonGUTU/TorchTools/b7025ab8469eec87830b109a344a7af551a5d464/LogTools/README.md -------------------------------------------------------------------------------- /LogTools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JasonGUTU/TorchTools/b7025ab8469eec87830b109a344a7af551a5d464/LogTools/__init__.py -------------------------------------------------------------------------------- /LogTools/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import dominate 4 | from dominate.tags import * 5 | from collections import OrderedDict 6 | import torch 7 | 8 | 9 | class _FileLogger(object): 10 | """Logger for losses files""" 11 | def __init__(self, logger, log_name, title_list): 12 | """ 13 | Init a new log term 14 | :param log_name: The name of the log 15 | :param title_list: list, titles to store 16 | :return: 17 | """ 18 | assert isinstance(logger, Logger), "logger should be instance of Logger" 19 | self.titles = title_list 20 | self.len = len(title_list) 21 | self.log_file_name = os.path.join(logger.log_dir, log_name + '.csv') 22 | with open(self.log_file_name, 'w') as f: 23 | f.write(','.join(title_list) + '\n') 24 | 25 | def add_log(self, value_list): 26 | assert len(value_list) == self.len, "Log Value doesn't match" 27 | for i in range(self.len): 28 | if not isinstance(value_list[i], str): 29 | value_list[i] = str(value_list[i]) 30 | with open(self.log_file_name, 'a') as f: 31 | f.write(','.join(value_list) + '\n') 32 | 33 | 34 | class HTML: 35 | def __init__(self, logger, reflesh=0): 36 | self.logger = logger 37 | self.title = logger.opt.exp_name 38 | self.web_dir = logger.web 39 | self.img_dir = logger.img 40 | if not os.path.exists(self.web_dir): 41 | os.makedirs(self.web_dir) 42 | if not os.path.exists(self.img_dir): 43 | os.makedirs(self.img_dir) 44 | # print(self.img_dir) 45 | 46 | self.doc = dominate.document(title=self.title) 47 | if reflesh > 0: 48 | with self.doc.head: 49 | meta(http_equiv="reflesh", content=str(reflesh)) 50 | 51 | def get_image_dir(self): 52 | return self.img_dir 53 | 54 | def add_header(self, str): 55 | with self.doc: 56 | h3(str) 57 | 58 | def add_table(self, border=1): 59 | self.t = table(border=border, style="table-layout: fixed;") 60 | self.doc.add(self.t) 61 | 62 | def add_images(self, ims, txts, links, width=400): 63 | self.add_table() 64 | with self.t: 65 | with tr(): 66 | for im, txt, link in zip(ims, txts, links): 67 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 68 | with p(): 69 | with a(href=os.path.join('images', link)): #TODO:image 70 | img(style="width:%dpx" % width, src=os.path.join('images', im)) 71 | br() 72 | p(txt) 73 | 74 | def save(self): 75 | html_file = '%s/index.html' % self.web_dir 76 | f = open(html_file, 'wt') 77 | f.write(self.doc.render()) 78 | f.close() 79 | 80 | 81 | class Logger(object): 82 | """Logger for easy log the training process.""" 83 | def __init__(self, name, exp_dir, opt, commend='', HTML_doc=False, 
log_dir='log', checkpoint_dir='checkpoint', sample='samples', web='web'): 84 | """ 85 | Init the exp dirs and generate readme file 86 | :param name: experiment name 87 | :param exp_dir: dir name to store exp 88 | :param opt: argparse namespace 89 | :param log_dir: 90 | :param checkpoint_dir: 91 | :param sample: 92 | """ 93 | self.name = name 94 | self.exp_dir = os.path.abspath(exp_dir) 95 | self.log_dir = os.path.join(self.exp_dir, log_dir) 96 | self.sample = os.path.join(self.exp_dir, sample) 97 | self.web = os.path.join(self.exp_dir, web) 98 | self.img = os.path.join(self.web, 'images') #TODO:image 99 | self.checkpoint_dir = os.path.join(self.exp_dir, checkpoint_dir) 100 | self.opt = opt 101 | try: 102 | os.mkdir(self.exp_dir) 103 | os.mkdir(self.log_dir) 104 | os.mkdir(self.checkpoint_dir) 105 | os.mkdir(self.sample) 106 | os.mkdir(self.web) 107 | os.mkdir(self.img) 108 | print('Creating: %s\n %s\n %s\n %s' % (self.exp_dir, self.log_dir, self.sample, self.checkpoint_dir)) 109 | except NotImplementedError: 110 | raise Exception('Check your dir.') 111 | except FileExistsError: 112 | pass 113 | with open(os.path.join(self.exp_dir, 'run_commend.txt'), 'w') as f: 114 | f.write(commend) 115 | self.html_tag = HTML_doc 116 | if HTML_doc: 117 | self.html = HTML(self) 118 | self.html.add_header(opt.exp_name) 119 | self.html.save() 120 | 121 | self._parse() 122 | 123 | def _parse(self): 124 | """ 125 | print parameters and generate readme file 126 | :return: 127 | """ 128 | attr_list = list() 129 | exp_readme = os.path.join(self.exp_dir, 'README.txt') 130 | 131 | for attr in dir(self.opt): 132 | if not attr.startswith('_'): 133 | attr_list.append(attr) 134 | print('Init parameters...') 135 | with open(exp_readme, 'w') as readme: 136 | readme.write(self.name + '\n') 137 | for attr in attr_list: 138 | line = '%s : %s' % (attr, self.opt.__getattribute__(attr)) 139 | print(line) 140 | readme.write(line) 141 | readme.write('\n') 142 | 143 | def init_scala_log(self, log_name, title_list): 144 | """ 145 | Init a new log term 146 | :param log_name: The name of the log 147 | :param title_list: list, titles to store 148 | :return: 149 | """ 150 | return _FileLogger(self, log_name, title_list) 151 | 152 | def _parse_save_name(self, tag, epoch, step='_', type='.pth'): 153 | return str(epoch) + step + tag + type 154 | 155 | def save_epoch(self, epoch, name, state_dict): 156 | """ 157 | Torch save 158 | :param name: 159 | :param state_dict: 160 | :return: 161 | """ 162 | torch.save(state_dict, os.path.join(self.checkpoint_dir, self._parse_save_name(name, epoch))) 163 | 164 | def save(self, name, state_dict): 165 | """ 166 | Torch save 167 | :param name: 168 | :param state_dict: 169 | :return: 170 | """ 171 | torch.save(state_dict, os.path.join(self.checkpoint_dir, name)) 172 | 173 | def load_epoch(self, name, epoch): 174 | return torch.load(os.path.join(self.checkpoint_dir, self._parse_save_name(name, epoch))) 175 | 176 | def print_log(self, string): 177 | print(string) 178 | with open(os.path.join(self.log_dir, 'output.log'), 'a') as f: 179 | f.write(string) 180 | 181 | def _parse_web_image_name(self, Nom, tag, step='_', type='.png'): 182 | return 'No.' 
+ str(Nom) + step + tag + type 183 | 184 | def _save_web_images(self, pil, name): 185 | save_path = os.path.join(self.img, name) 186 | pil.save(save_path) 187 | 188 | def _add_image_table(self, img_list, tag_list): 189 | assert len(img_list) == len(tag_list), 'check input' 190 | self.html.add_images(img_list, tag_list, img_list) 191 | self.html.save() 192 | 193 | def save_image_record(self, epoch, image_dict): 194 | img_list = list() 195 | tag_list = list() 196 | for tag, image in image_dict.items(): 197 | image_name = self._parse_web_image_name(epoch, tag) 198 | img_list.append(image_name) 199 | tag_list.append(tag) 200 | image.save(os.path.join(self.img, image_name)) 201 | self.html.add_header('Epoch: %d' % epoch) 202 | self._add_image_table(img_list, tag_list) 203 | 204 | def save_logger(self): 205 | with open(os.path.join(self.exp_dir, 'logger.pkl'), 'wb') as f: # binary mode is required by pickle 206 | pickle.dump(self, f) 207 | -------------------------------------------------------------------------------- /ModelTools/BaseModel.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | 5 | class BaseModel(object): 6 | """ 7 | The base class of one experiment 8 | """ 9 | def __init__(self, logger): 10 | self.logger = logger 11 | 12 | def name(self): 13 | return 'BaseModel' 14 | 15 | def initialize(self): 16 | pass 17 | 18 | def set_input(self): 19 | pass 20 | 21 | def backward(self): 22 | pass 23 | 24 | def optimize_parameters(self): 25 | pass 26 | 27 | def train_step(self): 28 | pass 29 | 30 | def test(self): 31 | pass 32 | 33 | def save_model(self): 34 | pass 35 | 36 | def load_model(self): 37 | pass 38 | 39 | def save_filename(self, network_label, epoch_label): 40 | return '%s_net_%s.pth' % (epoch_label, network_label) 41 | 42 | # helper saving function that can be used by subclasses 43 | def save_network(self, network, network_label, epoch_label): 44 | save_path = os.path.join(self.logger.checkpoint_dir, self.save_filename(network_label, epoch_label)) 45 | torch.save(network.cpu().state_dict(), save_path) 46 | network.cuda() 47 | 48 | # helper loading function that can be used by subclasses 49 | def load_network(self, network, network_label, epoch_label): 50 | save_path = os.path.join(self.logger.checkpoint_dir, self.save_filename(network_label, epoch_label)) 51 | network.load_state_dict(torch.load(save_path)) 52 | 53 | def get_current_errors(self): 54 | return {} 55 | 56 | 57 | -------------------------------------------------------------------------------- /ModelTools/ModelZoo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JasonGUTU/TorchTools/b7025ab8469eec87830b109a344a7af551a5d464/ModelTools/ModelZoo/__init__.py -------------------------------------------------------------------------------- /ModelTools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JasonGUTU/TorchTools/b7025ab8469eec87830b109a344a7af551a5d464/ModelTools/__init__.py -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TorchTools 2 | ###### Jason GUTU 3 | ### About 4 | This is a tool package for PyTorch.
5 | - `DataTools` contains tools for easy data management 6 | - `LogTools` contains tools for easy logging 7 | - `TorchNet` contains modules and some classical Models for Image to Image or Classification 8 | - `Functions` contains functions that useful for pre-processing and testing 9 | Docs are available in each folders, check it before first use. 10 | ### Package Usage 11 | Download the repository, and put it at the root of your project: 12 | ``` 13 | - Your Project Folder 14 | | 15 | |- TorchTools 16 | | 17 | |- Your Package 18 | | 19 | |- your python files 20 | | 21 | . 22 | . 23 | . 24 | | 25 | |- python files 26 | ``` 27 | Import TorchTools directly in your python files 28 | 29 | The functions or class start with `_` or `__` are Internal implementation, I don't recommend you to use them directly unless you know how it works. 30 | ### Future Needs 31 | Create issue if you want some functions or tools. 32 | Bug reports are **NEEDED**. 33 | ### Change Logs 34 | - 18.01.12 First commit 35 | -------------------------------------------------------------------------------- /TorchNet/Calligraphy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.functional import relu 5 | 6 | try: 7 | from math import log2 8 | except: 9 | from math import log 10 | def log2(x): 11 | return log(x) / log(2) 12 | 13 | from .activation import swish 14 | from .modules import Flatten, upsampleBlock 15 | 16 | 17 | class CalligraphyDiscriminator(nn.Module): 18 | def __init__(self, n_maps=64, activation=swish): 19 | super(CalligraphyDiscriminator, self).__init__() 20 | self.act = activation 21 | # B*1*32*32 22 | self.conv1 = nn.Conv2d(1, n_maps, kernel_size=3, stride=2, padding=1) 23 | # B*N*16*16 24 | self.conv2 = nn.Conv2d(n_maps*1, n_maps*2, kernel_size=3, stride=2, padding=1) 25 | # B*2N*8*8 26 | self.conv3 = nn.Conv2d(n_maps*2, n_maps*4, kernel_size=3, stride=2, padding=1) 27 | # B*4N*4*4 28 | self.conv4 = nn.Conv2d(n_maps*4, n_maps*8, kernel_size=3, stride=2, padding=1) 29 | # B*8N*2*2 30 | self.flat = Flatten() 31 | self.dense = nn.Linear(2048, 128) 32 | self.out = nn.Linear(128, 1) 33 | 34 | def forward(self, input): 35 | return F.sigmoid( 36 | self.out( 37 | self.act(self.dense( 38 | self.flat( 39 | self.act(self.conv4( 40 | self.act(self.conv3( 41 | self.act(self.conv2( 42 | self.act(self.conv1(input)) 43 | )) 44 | )) 45 | )) 46 | ) 47 | )) 48 | ) 49 | ) 50 | 51 | 52 | class CalligraphyUpsampleDiscriminator(nn.Module): 53 | def __init__(self, n_maps=32, activation=swish): 54 | super(CalligraphyUpsampleDiscriminator, self).__init__() 55 | self.act = activation 56 | # B*1*128*128 57 | self.conv1 = nn.Conv2d(1, n_maps, kernel_size=5, stride=2, padding=2) 58 | # B*N*64*64 59 | self.conv2 = nn.Conv2d(n_maps*1, n_maps*1, kernel_size=5, stride=2, padding=2) 60 | # B*N*32*32 61 | self.conv3 = nn.Conv2d(n_maps*1, n_maps*2, kernel_size=3, stride=2, padding=1) 62 | # B*2N*16*16 63 | self.conv4 = nn.Conv2d(n_maps*2, n_maps*4, kernel_size=3, stride=2, padding=1) 64 | # B*4N*8*8 65 | self.conv5 = nn.Conv2d(n_maps*4, n_maps*4, kernel_size=3, stride=2, padding=1) 66 | # B*4N*4*4 67 | self.conv6 = nn.Conv2d(n_maps*4, n_maps*8, kernel_size=3, stride=2, padding=1) 68 | # B*8N*2*2 69 | self.flat = Flatten() 70 | self.dense = nn.Linear(2048, 128) 71 | self.out = nn.Linear(128, 1) 72 | 73 | def forward(self, input): 74 | return F.sigmoid( 75 | self.out( 76 | self.act(self.dense( 77 | 
self.flat(self.act(self.conv6( 78 | self.act(self.conv5( 79 | self.act(self.conv4( 80 | self.act(self.conv3( 81 | self.act(self.conv2( 82 | self.act(self.conv1(input)) 83 | )) 84 | )) 85 | )) 86 | )) 87 | ))) 88 | )) 89 | ) 90 | ) 91 | 92 | 93 | class CalliUpsampleNet(nn.Module): 94 | def __init__(self, n_maps=64, activation=swish): 95 | super(CalliUpsampleNet, self).__init__() 96 | self.act = activation 97 | 98 | 99 | class EDCalliTransferNet(nn.Module): 100 | """ 101 | Encoder-Decoder Transfer Net 102 | Use Pixel-shuffle as Upsample method 103 | Input 32*32 104 | """ 105 | def __init__(self, n_maps=64, activation=swish): 106 | super(EDCalliTransferNet, self).__init__() 107 | self.act = activation 108 | self.flat = Flatten() 109 | # B*1*32*32 110 | self.E_Conv = nn.Conv2d(in_channels=1, out_channels=n_maps*1, kernel_size=5, stride=1, padding=2) 111 | # B*N*32*32 112 | self.E_conv1 = nn.Conv2d(in_channels=n_maps*1, out_channels=n_maps*1, kernel_size=3, stride=2, padding=1) 113 | # B*N*16*16 114 | self.E_conv2 = nn.Conv2d(in_channels=n_maps*1, out_channels=n_maps*2, kernel_size=3, stride=2, padding=1) 115 | # B*2N*8*8 116 | self.E_conv3 = nn.Conv2d(in_channels=n_maps*2, out_channels=n_maps*4, kernel_size=3, stride=2, padding=1) 117 | # B*4N*4*4 118 | self.E_conv4 = nn.Conv2d(in_channels=n_maps*4, out_channels=n_maps*8, kernel_size=3, stride=2, padding=1) 119 | # B*8N*2*2 120 | self.E_conv5 = nn.Conv2d(in_channels=n_maps*8, out_channels=n_maps*16, kernel_size=2) 121 | # B*16N*1*1 122 | # if N == 64, 16*N == 1024 123 | self.pixel = nn.PixelShuffle(2) 124 | # pixel 125 | # B*4N*2*2 126 | self.D_conv1 = nn.Conv2d(in_channels=n_maps*4, out_channels=n_maps*8, kernel_size=3, stride=1, padding=1) 127 | # B*8N*2*2 128 | # pixel 129 | # B*2N*4*4 130 | self.D_conv2 = nn.Conv2d(in_channels=n_maps*2, out_channels=n_maps*4, kernel_size=3, stride=1, padding=1) 131 | # B*4N*4*4 132 | # pixel 133 | # B*N*8*8 134 | self.D_conv3 = nn.Conv2d(in_channels=n_maps*1, out_channels=n_maps*4, kernel_size=5, stride=1, padding=2) 135 | # B*4N*8*8 136 | # pixel 137 | # B*N*16*16 138 | self.D_conv4 = nn.Conv2d(in_channels=n_maps*1, out_channels=n_maps*4, kernel_size=5, stride=1, padding=2) 139 | # B*4N*16*16 140 | # pixel 141 | # B*N*32*32 142 | self.D_conv5 = nn.Conv2d(in_channels=n_maps*1, out_channels=n_maps, kernel_size=3, stride=1, padding=1) 143 | # B*N*32*32 144 | self.D_Conv = nn.Conv2d(in_channels=n_maps, out_channels=1, kernel_size=3, stride=1, padding=1) 145 | # B*1*32*32 146 | 147 | def _decoder(self, code): 148 | return F.tanh( 149 | self.D_Conv( 150 | self.act(self.D_conv5( 151 | self.pixel(self.act(self.D_conv4( 152 | self.pixel(self.act(self.D_conv3( 153 | self.pixel(self.act(self.D_conv2( 154 | self.pixel(self.act(self.D_conv1( 155 | self.pixel(code) 156 | ))) 157 | ))) 158 | ))) 159 | ))) 160 | )) 161 | ) 162 | ) 163 | 164 | def _enocder(self, input): 165 | """ 166 | :param input: 32*32 167 | :return: B*16N*1*1 168 | """ 169 | return self.act(self.E_conv5( 170 | self.act(self.E_conv4( 171 | self.act(self.E_conv3( 172 | self.act(self.E_conv2( 173 | self.act(self.E_conv1( 174 | self.act(self.E_Conv(input)) 175 | )) 176 | )) 177 | )) 178 | )) 179 | )) 180 | 181 | def forward(self, input): 182 | code = self._enocder(input) 183 | return self._decoder(code) 184 | 185 | 186 | class CaliNet(nn.Module): 187 | def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.InstanceNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'): 188 | super(CaliNet, self).__init__() 189 | self.conv = nn.Sequential( 
190 | *[ 191 | nn.Conv2d(input_nc, ngf, kernel_size=5, stride=1, padding=2), 192 | nn.ReLU(), 193 | nn.Conv2d(ngf, ngf * 2, kernel_size=5, stride=2, padding=2), 194 | norm_layer(ngf * 2), 195 | nn.ReLU(), 196 | nn.Conv2d(ngf * 2, ngf * 2, kernel_size=5, stride=1, padding=2), 197 | norm_layer(ngf * 2), 198 | nn.ReLU(), 199 | nn.Conv2d(ngf * 2, ngf * 4, kernel_size=5, stride=2, padding=2), 200 | norm_layer(ngf * 4), 201 | nn.ReLU(), 202 | nn.Conv2d(ngf * 4, ngf * 4, kernel_size=5, stride=1, padding=2), 203 | norm_layer(ngf * 4), 204 | nn.ReLU(), 205 | nn.Conv2d(ngf * 4, ngf * 4, kernel_size=5, stride=2, padding=2), 206 | norm_layer(ngf * 4), 207 | nn.ReLU(), 208 | nn.Conv2d(ngf * 4, ngf * 4, kernel_size=5, stride=1, padding=2), 209 | norm_layer(ngf * 4), 210 | nn.ReLU(), 211 | nn.Conv2d(ngf * 4, ngf * 4, kernel_size=5, stride=1, padding=2), 212 | norm_layer(ngf * 4), 213 | nn.ReLU(), 214 | ] 215 | ) 216 | self.decv = nn.Sequential( 217 | *[ 218 | nn.ConvTranspose2d(ngf * 4, ngf * 4, kernel_size=4, stride=2, padding=1), 219 | norm_layer(ngf * 4), 220 | nn.ReLU(), 221 | nn.ConvTranspose2d(ngf * 4, ngf * 4, kernel_size=5, stride=1, padding=2), 222 | norm_layer(ngf * 4), 223 | nn.ReLU(), 224 | nn.ConvTranspose2d(ngf * 4, ngf * 2, kernel_size=4, stride=2, padding=1), 225 | norm_layer(ngf * 2), 226 | nn.ReLU(), 227 | nn.ConvTranspose2d(ngf * 2, ngf * 2, kernel_size=5, stride=1, padding=2), 228 | norm_layer(ngf * 2), 229 | nn.ReLU(), 230 | nn.ConvTranspose2d(ngf * 2, ngf * 1, kernel_size=4, stride=2, padding=1), 231 | nn.ReLU(), 232 | nn.ConvTranspose2d(ngf * 1, output_nc, kernel_size=5, stride=1, padding=2) 233 | ] 234 | ) 235 | 236 | def forward(self, input): 237 | x = self.conv(input) 238 | x = self.decv(x) 239 | return x 240 | 241 | 242 | -------------------------------------------------------------------------------- /TorchNet/ClassicSRNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | try: 6 | from math import log2 7 | except: 8 | from math import log 9 | def log2(x): 10 | return log(x) / log(2) 11 | 12 | import math 13 | 14 | from .modules import residualBlock, upsampleBlock, DownsamplingShuffle 15 | 16 | 17 | class SRResNet_Residual_Block(nn.Module): 18 | def __init__(self): 19 | super(SRResNet_Residual_Block, self).__init__() 20 | 21 | self.conv1 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False) 22 | self.in1 = nn.InstanceNorm2d(64, affine=True) 23 | self.relu = nn.LeakyReLU(0.2, inplace=True) 24 | self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False) 25 | self.in2 = nn.InstanceNorm2d(64, affine=True) 26 | 27 | def forward(self, x): 28 | identity_data = x 29 | output = self.relu(self.in1(self.conv1(x))) 30 | output = self.in2(self.conv2(output)) 31 | output = torch.add(output, identity_data) 32 | return output 33 | 34 | 35 | class SRResNetRGBX4(nn.Module): 36 | def __init__(self, min=0.0, max=1.0, tanh=False): 37 | super(SRResNetRGBX4, self).__init__() 38 | self.min = min 39 | self.max = max 40 | self.tanh = tanh 41 | 42 | self.conv_input = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=9, stride=1, padding=4, bias=False) 43 | self.relu = nn.LeakyReLU(0.2, inplace=True) 44 | 45 | self.residual = self.make_layer(SRResNet_Residual_Block, 16) 46 | 47 | self.conv_mid = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False) 48 | 
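        # Descriptive note (added): in upscale4x below, each Conv2d(64 -> 256) + PixelShuffle(2)
        # stage trades 4x channels for 2x spatial resolution (256 / 2^2 = 64 channels out),
        # so the two stages together give the overall 4x upscale.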
self.bn_mid = nn.InstanceNorm2d(64, affine=True) 49 | 50 | self.upscale4x = nn.Sequential( 51 | nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False), 52 | nn.PixelShuffle(2), 53 | nn.LeakyReLU(0.2, inplace=True), 54 | nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False), 55 | nn.PixelShuffle(2), 56 | nn.LeakyReLU(0.2, inplace=True), 57 | ) 58 | 59 | self.conv_output = nn.Conv2d(in_channels=64, out_channels=3, kernel_size=9, stride=1, padding=4, bias=False) 60 | 61 | for m in self.modules(): 62 | if isinstance(m, nn.Conv2d): 63 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 64 | m.weight.data.normal_(0, math.sqrt(2. / n)) 65 | if m.bias is not None: 66 | m.bias.data.zero_() 67 | 68 | def make_layer(self, block, num_of_layer): 69 | layers = [] 70 | for _ in range(num_of_layer): 71 | layers.append(block()) 72 | return nn.Sequential(*layers) 73 | 74 | def forward(self, x, clip=True): 75 | out = self.relu(self.conv_input(x)) 76 | residual = out 77 | out = self.residual(out) 78 | out = self.bn_mid(self.conv_mid(out)) 79 | out = torch.add(out, residual) 80 | out = self.upscale4x(out) 81 | out = self.conv_output(out) 82 | if self.tanh: 83 | return F.tanh(out) 84 | else: 85 | return torch.clamp(out, min=self.min, max=self.max) if clip else out 86 | 87 | 88 | class SRResNetYX4(nn.Module): 89 | def __init__(self, min=0.0, max=1.0, tanh=True): 90 | super(SRResNetYX4, self).__init__() 91 | self.min = min 92 | self.max = max 93 | self.tanh = tanh 94 | 95 | self.conv_input = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=9, stride=1, padding=4, bias=False) 96 | self.relu = nn.LeakyReLU(0.2, inplace=True) 97 | 98 | self.residual = self.make_layer(SRResNet_Residual_Block, 16) 99 | 100 | self.conv_mid = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False) 101 | self.bn_mid = nn.InstanceNorm2d(64, affine=True) 102 | 103 | self.upscale4x = nn.Sequential( 104 | nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False), 105 | nn.PixelShuffle(2), 106 | nn.LeakyReLU(0.2, inplace=True), 107 | nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False), 108 | nn.PixelShuffle(2), 109 | nn.LeakyReLU(0.2, inplace=True), 110 | ) 111 | 112 | self.conv_output = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=9, stride=1, padding=4, bias=False) 113 | 114 | for m in self.modules(): 115 | if isinstance(m, nn.Conv2d): 116 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 117 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 118 | if m.bias is not None: 119 | m.bias.data.zero_() 120 | 121 | def make_layer(self, block, num_of_layer): 122 | layers = [] 123 | for _ in range(num_of_layer): 124 | layers.append(block()) 125 | return nn.Sequential(*layers) 126 | 127 | def forward(self, x, clip=True): 128 | out = self.relu(self.conv_input(x)) 129 | residual = out 130 | out = self.residual(out) 131 | out = self.bn_mid(self.conv_mid(out)) 132 | out = torch.add(out, residual) 133 | out = self.upscale4x(out) 134 | out = self.conv_output(out) 135 | if self.tanh: 136 | return F.tanh(out) 137 | else: 138 | return torch.clamp(out, min=self.min, max=self.max) if clip else out 139 | 140 | 141 | class SRResNetYX2(nn.Module): 142 | def __init__(self, min=0.0, max=1.0, tanh=True): 143 | super(SRResNetYX2, self).__init__() 144 | self.min = min 145 | self.max = max 146 | self.tanh = tanh 147 | 148 | self.conv_input = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=9, stride=1, padding=4, bias=False) 149 | self.relu = nn.LeakyReLU(0.2, inplace=True) 150 | 151 | self.residual = self.make_layer(SRResNet_Residual_Block, 16) 152 | 153 | self.conv_mid = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False) 154 | self.bn_mid = nn.InstanceNorm2d(64, affine=True) 155 | 156 | self.upscale2x = nn.Sequential( 157 | nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False), 158 | nn.PixelShuffle(2), 159 | nn.LeakyReLU(0.2, inplace=True), 160 | ) 161 | 162 | self.conv_output = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=9, stride=1, padding=4, bias=False) 163 | 164 | for m in self.modules(): 165 | if isinstance(m, nn.Conv2d): 166 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 167 | m.weight.data.normal_(0, math.sqrt(2. / n)) 168 | if m.bias is not None: 169 | m.bias.data.zero_() 170 | 171 | def make_layer(self, block, num_of_layer): 172 | layers = [] 173 | for _ in range(num_of_layer): 174 | layers.append(block()) 175 | return nn.Sequential(*layers) 176 | 177 | def forward(self, x): 178 | out = self.relu(self.conv_input(x)) 179 | residual = out 180 | out = self.residual(out) 181 | out = self.bn_mid(self.conv_mid(out)) 182 | out = torch.add(out, residual) 183 | out = self.upscale2x(out) 184 | out = self.conv_output(out) 185 | if self.tanh: 186 | return F.tanh(out) 187 | else: 188 | return torch.clamp(out, min=self.min, max=self.max) 189 | 190 | 191 | class DownSampleResNetYX4(nn.Module): 192 | def __init__(self, min=0.0, max=1.0): 193 | super(DownSampleResNetYX4, self).__init__() 194 | self.min = min 195 | self.max = max 196 | self.down_shuffle = DownsamplingShuffle(4) 197 | 198 | self.conv_input = nn.Conv2d(in_channels=16, out_channels=64, kernel_size=5, stride=1, padding=2, bias=False) 199 | self.relu = nn.LeakyReLU(0.2, inplace=True) 200 | 201 | self.residual = self.make_layer(SRResNet_Residual_Block, 6) 202 | 203 | self.conv_mid = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False) 204 | self.bn_mid = nn.InstanceNorm2d(64, affine=True) 205 | 206 | self.conv_output = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=5, stride=1, padding=2, bias=False) 207 | 208 | for m in self.modules(): 209 | if isinstance(m, nn.Conv2d): 210 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 211 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 212 | if m.bias is not None: 213 | m.bias.data.zero_() 214 | 215 | def make_layer(self, block, num_of_layer): 216 | layers = [] 217 | for _ in range(num_of_layer): 218 | layers.append(block()) 219 | return nn.Sequential(*layers) 220 | 221 | def forward(self, x): 222 | out = self.relu(self.conv_input(self.down_shuffle(x))) 223 | residual = out 224 | out = self.residual(out) 225 | out = self.bn_mid(self.conv_mid(out)) 226 | out = torch.add(out, residual) 227 | out = self.conv_output(out) 228 | return torch.clamp(out, min=self.min, max=self.max) 229 | 230 | 231 | class FSRCNNY(nn.Module): 232 | """ 233 | Sequential( 234 | (0): Conv2d (1, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False) 235 | (1): LeakyReLU(0.2, inplace) 236 | (2): Conv2d (1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 237 | (3): LeakyReLU(0.2, inplace) 238 | (4): Conv2d (1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 239 | (5): LeakyReLU(0.2, inplace) 240 | (6): Conv2d (1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 241 | (7): LeakyReLU(0.2, inplace) 242 | (8): Conv2d (1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 243 | (9): LeakyReLU(0.2, inplace) 244 | (10): Conv2d (1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 245 | (11): LeakyReLU(0.2, inplace) 246 | (12): upsampleBlock( 247 | (act): LeakyReLU(0.2, inplace) 248 | (pad): ReflectionPad2d((1, 1, 1, 1)) 249 | (conv): Conv2d (64, 256, kernel_size=(3, 3), stride=(1, 1)) 250 | (shuffler): PixelShuffle(upscale_factor=2) 251 | ) 252 | (13): upsampleBlock( 253 | (act): LeakyReLU(0.2, inplace) 254 | (pad): ReflectionPad2d((1, 1, 1, 1)) 255 | (conv): Conv2d (64, 256, kernel_size=(3, 3), stride=(1, 1)) 256 | (shuffler): PixelShuffle(upscale_factor=2) 257 | ) 258 | (14): Conv2d (64, 1, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False) 259 | ) 260 | """ 261 | def __init__(self, scala=4): 262 | super(FSRCNNY, self).__init__() 263 | self.scala = scala 264 | self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=7, stride=1, padding=3, bias=False) 265 | self.relu1 = nn.LeakyReLU(0.2, inplace=True) 266 | self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False) 267 | self.relu2 = nn.LeakyReLU(0.2, inplace=True) 268 | self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False) 269 | self.relu3 = nn.LeakyReLU(0.2, inplace=True) 270 | self.conv4 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False) 271 | self.relu4 = nn.LeakyReLU(0.2, inplace=True) 272 | self.conv5 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False) 273 | self.relu5 = nn.LeakyReLU(0.2, inplace=True) 274 | self.conv6 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False) 275 | self.relu6 = nn.LeakyReLU(0.2, inplace=True) 276 | 277 | for i in range(int(log2(self.scala))): 278 | self.add_module('upsample' + str(i + 1), upsampleBlock(64, 64 * 4, activation=nn.LeakyReLU(0.2, inplace=True))) 279 | 280 | self.convf = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=7, stride=1, padding=3, bias=False) 281 | 282 | def forward(self, input): 283 | out = self.relu1(self.conv1(input)) 284 | out = self.relu2(self.conv2(out)) 285 | out = self.relu3(self.conv3(out)) 286 | out = self.relu4(self.conv4(out)) 287 | out = self.relu5(self.conv5(out)) 288 | x = 
self.relu6(self.conv6(out)) 289 | 290 | for i in range(int(log2(self.scala))): 291 | x = self.__getattr__('upsample' + str(i + 1))(x) 292 | 293 | return F.tanh(self.convf(x)) 294 | 295 | -------------------------------------------------------------------------------- /TorchNet/Deblur.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Function, Variable 5 | 6 | try: 7 | from pytorch_fft.fft.autograd import Ifft2d, Fft2d 8 | import pytorch_fft.fft as fft 9 | except: 10 | pass 11 | 12 | import numpy as np 13 | import torch 14 | 15 | fft_cuda = Fft2d() 16 | ifft_cuda = Ifft2d() 17 | 18 | 19 | def complex_multi(real_1, real_2, imag_1, imag_2): 20 | real = real_1 * real_2 - imag_1 * imag_2 21 | imag = imag_2 * real_1 + imag_1 * real_2 22 | return real, imag 23 | 24 | 25 | def complex_div(real_1, real_2, imag_1, imag_2): 26 | down = real_2 ** 2 + imag_2 ** 2 27 | up_real, up_imag = complex_multi(real_1, real_2, imag_1, -imag_2) 28 | return up_real / down, up_imag / down 29 | 30 | 31 | def complex_abs(real, imag): 32 | return torch.sqrt(real ** 2 + imag ** 2) 33 | 34 | 35 | def fftshift_cpu(fft_map, axes=None): 36 | ndim = len(fft_map.size()) 37 | if axes is None: 38 | axes = list(range(ndim)) 39 | elif isinstance(axes, int): 40 | axes = (axes,) 41 | y = fft_map 42 | for k in axes: 43 | n = fft_map.size()[k] 44 | p2 = (n + 1) // 2 45 | mylist = np.concatenate((np.arange(p2, n), np.arange(p2))) 46 | y = torch.index_select(y, k, Variable(torch.LongTensor(mylist))) 47 | return y 48 | 49 | 50 | def fftshift_cuda(fft_map, axes=None): 51 | ndim = len(fft_map.size()) 52 | if axes is None: 53 | axes = list(range(ndim)) 54 | elif isinstance(axes, int): 55 | axes = (axes,) 56 | y = fft_map 57 | for k in axes: 58 | n = fft_map.size()[k] 59 | p2 = (n + 1) // 2 60 | mylist = np.concatenate((np.arange(p2, n), np.arange(p2))) 61 | y = torch.index_select(y, k, Variable(torch.LongTensor(mylist)).cuda()) 62 | return y 63 | 64 | 65 | def center_padding(canvas, H, W): 66 | """ 67 | (pad_l, pad_r, pad_t, pad_b ) 68 | """ 69 | canvas_p = canvas // 2 70 | H_p = H // 2 71 | W_p = W // 2 72 | pad_t = H_p - canvas_p 73 | pad_b = H - pad_t - canvas 74 | pad_l = W_p - canvas_p 75 | pad_r = W - pad_l - canvas 76 | return pad_l, pad_r, pad_t, pad_b 77 | 78 | 79 | def kernel_filp(kernel): 80 | canvas_1 = kernel.size()[1] 81 | canvas_2 = kernel.size()[2] 82 | kernel_h = torch.index_select(kernel, 1, Variable(torch.LongTensor(np.arange(canvas_1-1, -1, -1)))) 83 | kernel_w = torch.index_select(kernel_h, 2, Variable(torch.LongTensor(np.arange(canvas_2-1, -1, -1)))) 84 | return kernel_w 85 | 86 | 87 | def convolution_fft_cuda(image, psf, canvas=64, eps=1e-8): 88 | img_real = image 89 | _, H, W = image.size() 90 | img_imag = torch.zeros_like(img_real).cuda() 91 | psf_pad_real = F.pad(psf, center_padding(canvas, H, W)) 92 | psf_pad_imag = torch.zeros_like(psf_pad_real).cuda() 93 | input_fft_r, input_fft_i = fft_cuda(img_real, img_imag) 94 | psf_fft_r, psf_fft_i = fft_cuda(psf_pad_real, psf_pad_imag) 95 | conv_real, conv_imag = complex_multi(input_fft_r, psf_fft_r + eps, input_fft_i, psf_fft_i) 96 | conv = ifft_cuda(conv_real, conv_imag) 97 | conv_mage = torch.sqrt(conv[0] ** 2 + conv[1] ** 2) 98 | conv_img = fftshift_cuda(conv_mage) 99 | return conv_img 100 | 101 | 102 | def inverse_convolution_cuda(image, psf, canvas=64, eps=1e-8): 103 | img_real = image 104 | _, H, W = image.size() 105 | 
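    # Descriptive note (added): naive inverse filtering in the frequency domain. The PSF is
    # padded to the image size, both are transformed with the FFT, the image spectrum is
    # divided by the PSF spectrum (eps is added to its real part to avoid division by zero),
    # and the magnitude of the inverse FFT is fftshifted to undo the centre-padding of the kernel.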
img_imag = torch.zeros_like(img_real).cuda() 106 | psf_pad_real = F.pad(psf, center_padding(canvas, H, W)) 107 | psf_pad_imag = torch.zeros_like(psf_pad_real).cuda() 108 | input_fft_r, input_fft_i = fft_cuda(img_real, img_imag) 109 | psf_fft_r, psf_fft_i = fft_cuda(psf_pad_real, psf_pad_imag) 110 | deconv_real, deconv_imag = complex_div(input_fft_r, psf_fft_r + eps, input_fft_i, psf_fft_i) 111 | deconv = ifft_cuda(deconv_real, deconv_imag) 112 | deconv_mage = torch.sqrt(deconv[0] ** 2 + deconv[1] ** 2) 113 | deconv_img = fftshift_cuda(deconv_mage) 114 | return deconv_img 115 | 116 | 117 | def wiener_filter_cuda(image, psf, canvas=64, eps=1e-8, K=1e-8): 118 | img_real = image 119 | _, H, W = image.size() 120 | img_imag = torch.zeros_like(img_real).cuda() 121 | psf_pad_real = F.pad(psf, center_padding(canvas, H, W)) 122 | psf_pad_imag = torch.zeros_like(psf_pad_real).cuda() 123 | input_fft_r, input_fft_i = fft_cuda(img_real, img_imag) 124 | psf_fft_r, psf_fft_i = fft_cuda(psf_pad_real, psf_pad_imag) 125 | psf_fft_r += eps 126 | psf_fft_prime_r, psf_fft_prime_i = complex_div(psf_fft_r, psf_fft_r**2 + psf_fft_i**2 + K, -psf_fft_i, 0.0) 127 | deconv_r, deconv_i = complex_multi(input_fft_r, psf_fft_prime_r, input_fft_i, psf_fft_prime_i) 128 | deconv = ifft_cuda(deconv_r, deconv_i) 129 | deconv_mage = torch.sqrt(deconv[0] ** 2 + deconv[1] ** 2) 130 | deconv_img = fftshift_cuda(deconv_mage) 131 | return deconv_img 132 | 133 | -------------------------------------------------------------------------------- /TorchNet/Discriminators.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .activation import swish 6 | 7 | 8 | class Discriminator(nn.Module): 9 | 10 | def __init__(self, activation=swish): 11 | super(Discriminator, self).__init__() 12 | self.act = activation 13 | 14 | self.conv1 = nn.Conv2d(3, 64, 3, stride=1, padding=1) 15 | 16 | self.conv2 = nn.Conv2d(64, 64, 3, stride=2, padding=1) 17 | self.bn2 = nn.BatchNorm2d(64) 18 | self.conv3 = nn.Conv2d(64, 128, 3, stride=1, padding=1) 19 | self.bn3 = nn.BatchNorm2d(128) 20 | self.conv4 = nn.Conv2d(128, 128, 3, stride=2, padding=1) 21 | self.bn4 = nn.BatchNorm2d(128) 22 | self.conv5 = nn.Conv2d(128, 256, 3, stride=1, padding=1) 23 | self.bn5 = nn.BatchNorm2d(256) 24 | self.conv6 = nn.Conv2d(256, 256, 3, stride=2, padding=1) 25 | self.bn6 = nn.BatchNorm2d(256) 26 | self.conv7 = nn.Conv2d(256, 512, 3, stride=1, padding=1) 27 | self.bn7 = nn.BatchNorm2d(512) 28 | self.conv8 = nn.Conv2d(512, 512, 3, stride=2, padding=1) 29 | self.bn8 = nn.BatchNorm2d(512) 30 | self.conv9 = nn.Conv2d(512, 1, 1, stride=1, padding=1) 31 | 32 | def forward(self, x): 33 | x = self.act(self.conv1(x)) 34 | 35 | x = self.act(self.bn2(self.conv2(x))) 36 | x = self.act(self.bn3(self.conv3(x))) 37 | x = self.act(self.bn4(self.conv4(x))) 38 | x = self.act(self.bn5(self.conv5(x))) 39 | x = self.act(self.bn6(self.conv6(x))) 40 | x = self.act(self.bn7(self.conv7(x))) 41 | x = self.act(self.bn8(self.conv8(x))) 42 | 43 | x = self.conv9(x) 44 | return F.sigmoid(F.avg_pool2d(x, x.size()[2:])).view(x.size()[0], -1) 45 | 46 | 47 | class DiscriminatorY(nn.Module): 48 | 49 | def __init__(self, activation=swish): 50 | super(DiscriminatorY, self).__init__() 51 | self.act = activation 52 | self.conv1 = nn.Conv2d(1, 64, 3, stride=1, padding=1) 53 | 54 | self.conv2 = nn.Conv2d(64, 64, 3, stride=2, padding=1) 55 | self.bn2 = nn.BatchNorm2d(64) 56 | self.conv3 = nn.Conv2d(64, 
64, 3, stride=1, padding=1) 57 | self.bn3 = nn.BatchNorm2d(64) 58 | self.conv4 = nn.Conv2d(64, 128, 3, stride=2, padding=1) 59 | self.bn4 = nn.BatchNorm2d(128) 60 | self.conv5 = nn.Conv2d(128, 128, 3, stride=1, padding=1) 61 | self.bn5 = nn.BatchNorm2d(128) 62 | self.conv6 = nn.Conv2d(128, 128, 3, stride=2, padding=1) 63 | self.bn6 = nn.BatchNorm2d(128) 64 | self.conv7 = nn.Conv2d(128, 256, 3, stride=1, padding=1) 65 | self.bn7 = nn.BatchNorm2d(256) 66 | self.conv8 = nn.Conv2d(256, 256, 3, stride=2, padding=1) 67 | self.bn8 = nn.BatchNorm2d(256) 68 | self.conv9 = nn.Conv2d(256, 1, 1, stride=1, padding=1) 69 | 70 | def forward(self, x): 71 | x = self.act(self.conv1(x)) 72 | 73 | x = self.act(self.bn2(self.conv2(x))) 74 | x = self.act(self.bn3(self.conv3(x))) 75 | x = self.act(self.bn4(self.conv4(x))) 76 | x = self.act(self.bn5(self.conv5(x))) 77 | x = self.act(self.bn6(self.conv6(x))) 78 | x = self.act(self.bn7(self.conv7(x))) 79 | x = self.act(self.bn8(self.conv8(x))) 80 | 81 | x = self.conv9(x) 82 | return F.sigmoid(F.avg_pool2d(x, x.size()[2:])).view(x.size()[0], -1) 83 | 84 | -------------------------------------------------------------------------------- /TorchNet/FaceDetection.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | 8 | from .modules import upsampleBlock 9 | from .activation import swish 10 | 11 | 12 | class PNet(nn.Module): 13 | def __init__(self): 14 | super(PNet, self).__init__() 15 | self.conv1 = nn.Conv2d(3, 10, 3) 16 | self.pool = nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True) 17 | self.conv2 = nn.Conv2d(10, 16, 3) 18 | self.conv3 = nn.Conv2d(16, 32, 3) 19 | self.conv4_1 = nn.Conv2d(32, 2, 1) 20 | self.conv4_2 = nn.Conv2d(32, 4, 1) 21 | 22 | def forward(self, input): 23 | conv3 = F.relu(self.conv3( 24 | F.relu(self.conv2( 25 | self.pool(F.relu(self.conv1(input))) 26 | )) 27 | )) 28 | conv4_1 = self.conv4_1(conv3) 29 | conv4_2 = self.conv4_2(conv3) 30 | return torch.cat([conv4_1, conv4_2], dim=1) 31 | 32 | 33 | class RNet(nn.Module): 34 | def __init__(self): 35 | super(RNet, self).__init__() 36 | self.conv1 = nn.Conv2d(3, 28, 3) 37 | self.pool1 = nn.MaxPool2d((3, 3), (2, 2), (0, 0), ceil_mode=True) 38 | self.conv2 = nn.Conv2d(28, 48, 3) 39 | self.pool2 = nn.MaxPool2d((3, 3), (2, 2), (0, 0), ceil_mode=True) 40 | self.conv3 = nn.Conv2d(48, 64, 2) 41 | self.linear1 = nn.Linear(576, 128) 42 | self.linear2_1 = nn.Linear(128, 2) 43 | self.linear2_2 = nn.Linear(128, 4) 44 | self.linear2_3 = nn.Linear(128, 10) 45 | 46 | def forward(self, input): 47 | conv3 = F.relu(self.conv3( 48 | self.pool2(F.relu(self.conv2( 49 | self.pool1(F.relu(self.conv1(input))) 50 | ))) 51 | )) 52 | flat = conv3.view(conv3.size(0), -1) 53 | linear1 = F.relu(self.linear1(flat)) 54 | classification = self.linear2_1() 55 | return None 56 | 57 | 58 | def _detector_parser(image_size, ): 59 | pass 60 | 61 | 62 | def _2detector_label(): 63 | pass 64 | 65 | 66 | class LightFaceDetector(nn.Module): 67 | """ 68 | This is a lightweight face landmark detector using fully convolution layer. 
detect m points 69 | |-------------------------------| 70 | | Input Y Channel Image | 71 | |-------------------------------| 72 | | Convolution_1 n32k5s1, ReLu | 73 | |-------------------------------| 74 | | MaxPooling 2x | 75 | |-------------------------------| 76 | | Convolution_2 n32k3s1, ReLu | 77 | |-------------------------------| 78 | | Convolution_3 n32k3s1, ReLu | 79 | |-------------------------------| 80 | | Pixel Shuffle 2x | 81 | |-------------------------------| 82 | | Convolution_4 nmk3s1, Sigmoid | 83 | |-------------------------------| 84 | """ 85 | def __init__(self, landmarks=5, activation=swish, in_channel=1): 86 | super(LightFaceDetector, self).__init__() 87 | self.act = activation 88 | self.pad1 = nn.ReflectionPad2d(2) 89 | self.conv1 = nn.Conv2d(in_channel, 32, 5, stride=1, padding=0) 90 | self.pool = nn.MaxPool2d(2) 91 | self.pad2 = nn.ReflectionPad2d(1) 92 | self.conv2 = nn.Conv2d(32, 32, 3, stride=1, padding=0) 93 | self.pad3 = nn.ReflectionPad2d(1) 94 | self.conv3 = nn.Conv2d(32, 32, 3, stride=1, padding=0) 95 | self.upsample = upsampleBlock(32, 128) 96 | self.conv4 = nn.Conv2d(32, landmarks, 3, stride=1, padding=1) 97 | 98 | def forward(self, input): 99 | return F.sigmoid(self.conv4(self.upsample(self.act(self.conv3(self.act(self.conv2(self.pool(self.act(self.conv1(input)))))))))) 100 | 101 | 102 | 103 | 104 | 105 | 106 | -------------------------------------------------------------------------------- /TorchNet/FaceHallucination.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.functional import relu 5 | 6 | try: 7 | from math import log2 8 | except: 9 | from math import log 10 | def log2(x): 11 | return log(x) / log(2) 12 | 13 | from .modules import Features4Layer, Features3Layer, residualBlock, upsampleBlock, LateUpsamplingBlock, LateUpsamplingBlockNoBN 14 | from .activation import swish 15 | 16 | 17 | class MaxActivationFusion(nn.Module): 18 | """ 19 | model implementation of the Maximum-activation Detail Fusion 20 | This is not a complete SR model, just **Fusion Part** 21 | """ 22 | def __init__(self, features=64, feature_extractor=Features4Layer, activation=relu): 23 | """ 24 | :param features: the number of final feature maps 25 | """ 26 | super(MaxActivationFusion, self).__init__() 27 | self.features = feature_extractor(features, activation=activation) 28 | 29 | def forward(self, frame_1, frame_2, frame_3, frame_4, frame_5): 30 | """ 31 | :param frame_1: frame t-2 32 | :param frame_2: frame t-1 33 | :param frame_3: frame t 34 | :param frame_4: frame t+1 35 | :param frame_5: frame t+2 36 | :return: features 37 | """ 38 | frame_1_feature = self.features(frame_1) 39 | frame_2_feature = self.features(frame_2) 40 | frame_3_feature = self.features(frame_3) 41 | frame_4_feature = self.features(frame_4) 42 | frame_5_feature = self.features(frame_5) 43 | 44 | frame_1_feature = frame_1_feature.view((1, ) + frame_1_feature.size()) 45 | frame_2_feature = frame_2_feature.view((1, ) + frame_2_feature.size()) 46 | frame_3_feature = frame_3_feature.view((1, ) + frame_3_feature.size()) 47 | frame_4_feature = frame_4_feature.view((1, ) + frame_4_feature.size()) 48 | frame_5_feature = frame_5_feature.view((1, ) + frame_5_feature.size()) 49 | 50 | cat = torch.cat((frame_1_feature, frame_2_feature, frame_3_feature, frame_4_feature, frame_5_feature), dim=0) 51 | return torch.max(cat, 0)[0] 52 | 53 | 54 | class MeanActivationFusion(nn.Module): 55 | 
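    # Descriptive note (added): same interface as MaxActivationFusion above, but the five
    # per-frame feature maps are fused with an element-wise mean (torch.mean) instead of the maximum.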
""" 56 | model implementation of the Mean-activation Detail Fusion 57 | This is not a complete SR model, just **Fusion Part** 58 | """ 59 | def __init__(self, features=64, feature_extractor=Features4Layer, activation=relu): 60 | """ 61 | :param features: the number of final feature maps 62 | """ 63 | super(MeanActivationFusion, self).__init__() 64 | self.features = feature_extractor(features, activation=activation) 65 | 66 | def forward(self, frame_1, frame_2, frame_3, frame_4, frame_5): 67 | """ 68 | :param frame_1: frame t-2 69 | :param frame_2: frame t-1 70 | :param frame_3: frame t 71 | :param frame_4: frame t+1 72 | :param frame_5: frame t+2 73 | :return: features 74 | """ 75 | frame_1_feature = self.features(frame_1) 76 | frame_2_feature = self.features(frame_2) 77 | frame_3_feature = self.features(frame_3) 78 | frame_4_feature = self.features(frame_4) 79 | frame_5_feature = self.features(frame_5) 80 | 81 | frame_1_feature = frame_1_feature.view((1, ) + frame_1_feature.size()) 82 | frame_2_feature = frame_2_feature.view((1, ) + frame_2_feature.size()) 83 | frame_3_feature = frame_3_feature.view((1, ) + frame_3_feature.size()) 84 | frame_4_feature = frame_4_feature.view((1, ) + frame_4_feature.size()) 85 | frame_5_feature = frame_5_feature.view((1, ) + frame_5_feature.size()) 86 | 87 | cat = torch.cat((frame_1_feature, frame_2_feature, frame_3_feature, frame_4_feature, frame_5_feature), dim=0) 88 | return torch.mean(cat, 0) 89 | 90 | 91 | class EarlyMean(nn.Module): 92 | """ 93 | model implementation of the Early Mean Fusion 94 | This is not a complete SR model, just **Fusion Part** 95 | """ 96 | def __init__(self, features=64, feature_extractor=Features4Layer, activation=relu): 97 | """ 98 | :param features: the number of final feature maps 99 | """ 100 | super(EarlyMean, self).__init__() 101 | self.features = feature_extractor(features, activation=activation) 102 | 103 | def forward(self, frame_1, frame_2, frame_3, frame_4, frame_5): 104 | """ 105 | :param frame_1: frame t-2 106 | :param frame_2: frame t-1 107 | :param frame_3: frame t 108 | :param frame_4: frame t+1 109 | :param frame_5: frame t+2 110 | :return: features 111 | """ 112 | frame_1 = frame_1.view((1, ) + frame_1.size()) 113 | frame_2 = frame_2.view((1, ) + frame_2.size()) 114 | frame_3 = frame_3.view((1, ) + frame_3.size()) 115 | frame_4 = frame_4.view((1, ) + frame_4.size()) 116 | frame_5 = frame_5.view((1, ) + frame_5.size()) 117 | frames_mean = torch.mean(torch.cat((frame_1, frame_2, frame_3, frame_4, frame_5), dim=0), 0) 118 | return self.features(frames_mean) 119 | 120 | 121 | class EaryFusion(nn.Module): 122 | """ 123 | model implementation of the Early Fusion 124 | This is not a complete SR model, just **Fusion Part** 125 | """ 126 | def __init__(self, features=64, feature_extractor=Features3Layer, activation=relu): 127 | """ 128 | :param features: the number of final feature maps 129 | """ 130 | super(EaryFusion, self).__init__() 131 | self.act = activation 132 | self.features = feature_extractor(features, activation=activation) 133 | self.conv = nn.Conv2d(features * 5, features, 3, stride=1, padding=1) 134 | 135 | def forward(self, frame_1, frame_2, frame_3, frame_4, frame_5): 136 | """ 137 | :param frame_1: frame t-2 138 | :param frame_2: frame t-1 139 | :param frame_3: frame t 140 | :param frame_4: frame t+1 141 | :param frame_5: frame t+2 142 | :return: features 143 | """ 144 | return self.act(self.conv( 145 | torch.cat( 146 | (self.features(frame_1), 147 | self.features(frame_2), 148 | 
self.features(frame_3), 149 | self.features(frame_4), 150 | self.features(frame_5)), 151 | dim=1) 152 | )) 153 | 154 | 155 | class EarlyEarly(nn.Module): 156 | """ 157 | model implementation of the Early Fusion 158 | This is not a complete SR model, just **Fusion Part** 159 | """ 160 | def __init__(self, features=64, activation=relu): 161 | """ 162 | :param features: the number of final feature maps 163 | """ 164 | super(EarlyEarly, self).__init__() 165 | self.act = activation 166 | self.conv = nn.Conv2d(5, features, 3, stride=1, padding=1) 167 | self.c1 = nn.Conv2d(features, features, 3, padding=1) 168 | self.c2 = nn.Conv2d(features, features, 3, padding=1) 169 | self.c3 = nn.Conv2d(features, features, 3, padding=1) 170 | 171 | def forward(self, frame_1, frame_2, frame_3, frame_4, frame_5): 172 | """ 173 | :param frame_1: frame t-2 174 | :param frame_2: frame t-1 175 | :param frame_3: frame t 176 | :param frame_4: frame t+1 177 | :param frame_5: frame t+2 178 | :return: features 179 | """ 180 | return self.act(self.c3(self.act(self.c2(self.act(self.c1(self.act(self.conv(torch.cat( 181 | (frame_1, 182 | frame_2, 183 | frame_3, 184 | frame_4, 185 | frame_5), 186 | dim=1))))))))) 187 | 188 | 189 | class _3DConv(nn.Module): 190 | def __init__(self,features=64, activation=relu): 191 | super(_3DConv, self).__init__() 192 | self.act = activation 193 | self.n_features = features 194 | self.conv3d = nn.Conv3d(1, features, (5, 3, 3), padding=(0, 1, 1)) 195 | self.c1 = nn.Conv2d(features, features, 3, padding=1) 196 | self.c2 = nn.Conv2d(features, features, 3, padding=1) 197 | self.c3 = nn.Conv2d(features, features, 3, padding=1) 198 | 199 | def forward(self, frame_1, frame_2, frame_3, frame_4, frame_5): 200 | batch, C, H, W = frame_1.size() 201 | frame_1 = frame_1.view((batch, C, 1, H, W)) 202 | frame_2 = frame_2.view((batch, C, 1, H, W)) 203 | frame_3 = frame_3.view((batch, C, 1, H, W)) 204 | frame_4 = frame_4.view((batch, C, 1, H, W)) 205 | frame_5 = frame_5.view((batch, C, 1, H, W)) 206 | frame_block = torch.cat( 207 | (frame_1, 208 | frame_2, 209 | frame_3, 210 | frame_4, 211 | frame_5), dim=2) 212 | _3dfeatures = self.act(self.conv3d(frame_block)) 213 | _3dfeatures = _3dfeatures.view((batch, self.n_features, H, W)) 214 | return self.act(self.c3( 215 | self.act(self.c2( 216 | self.act(self.c1(_3dfeatures)) 217 | )) 218 | )) 219 | 220 | 221 | class HallucinationOrigin(nn.Module): 222 | """ 223 | Original Video Face Hallucination Net 224 | |---------------------------------| 225 | | Input features | 226 | |---------------------------------| 227 | | n | Residual blocks | 228 | |---------------------------------| 229 | | Big short connect from features | 230 | |---------------------------------| 231 | | Convolution and BN | 232 | |---------------------------------| 233 | | Pixel Shuffle Up-sampling | 234 | |---------------------------------| 235 | | Final Convolution | 236 | |---------------------------------| 237 | | Tanh | 238 | |---------------------------------| 239 | """ 240 | def __init__(self, scala=8, features=64, n_residual_blocks=9, big_short_connect=False, output_channel=1): 241 | """ 242 | :param scala: scala factor 243 | :param n_residual_blocks: The number of residual blocks 244 | :param Big_short_connect: Weather the short connect between the input features and the Conv&BN 245 | """ 246 | super(HallucinationOrigin, self).__init__() 247 | self.n_residual_blocks = n_residual_blocks 248 | self.scala = scala 249 | self.connect = big_short_connect 250 | 251 | for i in 
range(self.n_residual_blocks): 252 | self.add_module('residual_block' + str(i + 1), residualBlock(features)) 253 | 254 | self.pad = nn.ReflectionPad2d(1) 255 | self.conv = nn.Conv2d(features, features, 3, stride=1, padding=0) 256 | self.bn = nn.BatchNorm2d(features) 257 | 258 | for i in range(int(log2(self.scala))): 259 | self.add_module('upsample' + str(i + 1), upsampleBlock(features, features * 4)) 260 | 261 | self.pad2 = nn.ReflectionPad2d(3) 262 | self.conv2 = nn.Conv2d(features, output_channel, 7, stride=1, padding=0) 263 | 264 | def forward(self, features): 265 | y = features.clone() 266 | for i in range(self.n_residual_blocks): 267 | y = self.__getattr__('residual_block' + str(i + 1))(y) 268 | 269 | if self.connect: 270 | x = self.bn(self.conv(self.pad(y))) + features 271 | else: 272 | x = self.bn(self.conv(self.pad(y))) 273 | 274 | for i in range(int(log2(self.scala))): 275 | x = self.__getattr__('upsample' + str(i + 1))(x) 276 | 277 | return F.tanh(self.conv2(self.pad2(x))) 278 | 279 | 280 | class StepHallucinationNet(nn.Module): 281 | """ 282 | |-----------------------------------| 283 | | features | 284 | |-----------------------------------| 285 | | log2(scala) | LateUpsamplingBlock | 286 | |-----------------------------------| 287 | | Convolution and Tanh | 288 | |-----------------------------------| 289 | """ 290 | def __init__(self, scala=8, features=64, little_res_blocks=3, output_channel=1): 291 | """ 292 | :param scala: scala factor 293 | :param features: 294 | :param little_res_blocks: The number of residual blocks in every late upsample blocks 295 | :param output_channel: default to be 1 for Y channel 296 | """ 297 | super(StepHallucinationNet, self).__init__() 298 | self.scala = scala 299 | self.features = features 300 | self.n_res_blocks = little_res_blocks 301 | 302 | for i in range(int(log2(self.scala))): 303 | self.add_module('lateUpsampling' + str(i + 1), LateUpsamplingBlock(features, n_res_block=little_res_blocks)) 304 | 305 | self.pad = nn.ReflectionPad2d(3) 306 | self.conv = nn.Conv2d(features, output_channel, 7, stride=1, padding=0) 307 | 308 | def forward(self, features): 309 | for i in range(int(log2(self.scala))): 310 | features = self.__getattr__('lateUpsampling' + str(i + 1))(features) 311 | return F.tanh(self.conv(self.pad(features))) 312 | 313 | 314 | class StepHallucinationNoBN(nn.Module): 315 | """ 316 | |-----------------------------------| 317 | | features | 318 | |-----------------------------------| 319 | | log2(scala) | LateUpsamplingBlock | 320 | |-----------------------------------| 321 | | Convolution and Tanh | 322 | |-----------------------------------| 323 | """ 324 | def __init__(self, scala=8, features=64, little_res_blocks=3, output_channel=1): 325 | """ 326 | :param scala: scala factor 327 | :param features: 328 | :param little_res_blocks: The number of residual blocks in every late upsample blocks 329 | :param output_channel: default to be 1 for Y channel 330 | """ 331 | super(StepHallucinationNoBN, self).__init__() 332 | self.scala = scala 333 | self.features = features 334 | self.n_res_blocks = little_res_blocks 335 | 336 | for i in range(int(log2(self.scala))): 337 | self.add_module('lateUpsampling' + str(i + 1), LateUpsamplingBlockNoBN(features, n_res_block=little_res_blocks)) 338 | 339 | self.pad = nn.ReflectionPad2d(3) 340 | self.conv = nn.Conv2d(features, output_channel, 7, stride=1, padding=0) 341 | 342 | def forward(self, features): 343 | for i in range(int(log2(self.scala))): 344 | features = self.__getattr__('lateUpsampling' 
+ str(i + 1))(features) 345 | return F.tanh(self.conv(self.pad(features))) 346 | 347 | 348 | class FusionModel(nn.Module): 349 | """ 350 | The Multi-frame face hallucination model 351 | """ 352 | def __init__(self, scala=8, fusion='mdf', upsample='org'): 353 | """ 354 | :param fusion: 'mdf'=MaxActivationFusion, 355 | 'early'=EaryFusion, 356 | 'mef'=MeanActivationFusion, 357 | 'earlyMean'=EarlyMean, 358 | 'ee'=EarlyEarly, 359 | '3d'=_3DConv 360 | :param upsample: 'org'=HallucinationOrigin, 361 | 'step'=StepHallucinationNet, 362 | 'no'=StepHallucinationNoBN 363 | """ 364 | super(FusionModel, self).__init__() 365 | if fusion == 'mdf': 366 | self.features = MaxActivationFusion() 367 | elif fusion == 'early': 368 | self.features = EaryFusion() 369 | elif fusion == 'mef': 370 | self.features = MeanActivationFusion() 371 | elif fusion == 'earlyMean': 372 | self.features = EarlyMean() 373 | elif fusion == 'ee': 374 | self.features = EarlyEarly() 375 | elif fusion == '3d': 376 | self.features = _3DConv() 377 | else: 378 | raise Exception('Wrong Parameter: fusion') 379 | 380 | if upsample == 'org': 381 | self.upsample = HallucinationOrigin(scala=scala) 382 | elif upsample == 'step': 383 | self.upsample = StepHallucinationNet(scala=scala) 384 | elif upsample == 'no': 385 | self.upsample = StepHallucinationNoBN(scala=scala) 386 | else: 387 | raise Exception('Wrong Parameter: upsample') 388 | 389 | def forward(self, frame_1, frame_2, frame_3, frame_4, frame_5): 390 | return self.upsample(self.features(frame_1, frame_2, frame_3, frame_4, frame_5)) 391 | 392 | 393 | class SingleImageBaseline(nn.Module): 394 | """ 395 | Single Image Baseline 396 | """ 397 | def __init__(self, scala=8, feature='f3', upsample='org'): 398 | """ 399 | :param feature: 'f3'=Features3Layer 400 | 'f4'=Features4Layer 401 | :param upsample: 'org'=HallucinationOrigin, 402 | 'step'=StepHallucinationNet, 403 | 'no'=StepHallucinationNoBN 404 | """ 405 | super(SingleImageBaseline, self).__init__() 406 | if feature == 'f3': 407 | self.features = Features3Layer() 408 | elif feature == 'f4': 409 | self.features = Features4Layer() 410 | else: 411 | raise Exception('Wrong Parameter: feature') 412 | 413 | if upsample == 'org': 414 | self.upsample = HallucinationOrigin(scala=scala) 415 | elif upsample == 'step': 416 | self.upsample = StepHallucinationNet(scala=scala) 417 | elif upsample == 'no': 418 | self.upsample = StepHallucinationNoBN(scala=scala) 419 | else: 420 | raise Exception('Wrong Parameter: upsample') 421 | 422 | def forward(self, frame): 423 | return self.upsample(self.features(frame)) 424 | 425 | 426 | def multi_to_single_step_upsample(Multi_state_dict): 427 | for key in dict(Multi_state_dict).keys(): 428 | if key.startswith('features.features.'): 429 | Multi_state_dict[key[9:]] = Multi_state_dict.pop(key) 430 | return Multi_state_dict -------------------------------------------------------------------------------- /TorchNet/GANInverse.py: -------------------------------------------------------------------------------- 1 | # pytorch>=1.0 2 | import numpy as np 3 | import random 4 | import functools 5 | from math import ceil 6 | import itertools 7 | import os 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.optim as optim 12 | from collections import OrderedDict 13 | from torch.autograd import Variable 14 | import torch.nn.functional as F 15 | from torch.utils.data import TensorDataset 16 | from PIL import Image 17 | 18 | try: 19 | from math import log2 20 | except: 21 | from math import log 22 | def log2(x): 23 | 
return log(x) / log(2) 24 | 25 | import math 26 | 27 | from ..DataTools.Loaders import PIL2VAR, VAR2PIL 28 | from ..DataTools.Prepro import _tanh_to_sigmoid, _sigmoid_to_tanh 29 | from ..DataTools.Loaders import _add_batch_one, _remove_batch 30 | from ..DataTools.FileTools import _image_file 31 | from ..Functions import functional as Func 32 | from .modules import residualBlock, upsampleBlock, DownsamplingShuffle, Attention, Flatten, BatchBlur, b_GaussianNoising, b_GPUVar_Bicubic, b_CPUVar_Bicubic 33 | from .activation import swish 34 | from .ClassicSRNet import SRResNet_Residual_Block 35 | 36 | 37 | PGGAN_LATENT = [(512, 1, 1), 38 | (512, 4, 4), (512, 4, 4), 39 | (512, 8, 8), (512, 8, 8), 40 | (512, 16, 16), (512, 16, 16), 41 | (512, 32, 32), (512, 32, 32), 42 | (256, 64, 64), (256, 64, 64), 43 | (128, 128, 128), (128, 128, 128), 44 | (64, 256, 256), (64, 256, 256), 45 | (32, 512, 512), (32, 512, 512), 46 | (16, 1024, 1024), (16, 1024, 1024), 47 | (3, 1024, 1024)] 48 | 49 | 50 | def PGGAN_parser(layer_number, gan_model): 51 | # layer_number count from 0 52 | rest_model = nn.Sequential(*list(gan_model.children())[layer_number:]).cuda() 53 | before_model = nn.Sequential(*list(gan_model.children())[:layer_number]).cuda() 54 | input_size = PGGAN_LATENT[layer_number] 55 | return rest_model, before_model, input_size 56 | 57 | 58 | class Naive_Inverser(object): 59 | def __init__(self, model, input_size, output_size): 60 | self.model = model 61 | self.model.eval() 62 | self.in_size = input_size 63 | self.out_size = output_size 64 | 65 | def __call__(self, gt, iterations=2000, learning_rate=0.01, criterion=nn.MSELoss(reduction='sum'), init=None, intermediate_step=-1, intermediate_stop=1000): 66 | assert list(gt.size()[1:]) == self.out_size, "check output size" 67 | batch_size = gt.size()[0] 68 | if init == None: 69 | z_estimate = torch.randn((batch_size,) + self.in_size).cuda() # our estimate, initialized randomly 70 | else: 71 | z_estimate = init 72 | z_estimate.requires_grad = True 73 | 74 | optimizer = optim.Adam([z_estimate], lr=learning_rate) 75 | 76 | # Opt 77 | z_middle = list() 78 | for i in range(iterations): 79 | y_estimate = self.model(z_estimate) 80 | optimizer.zero_grad() 81 | loss = criterion(y_estimate, gt.detach()) 82 | if intermediate_step >= 1 and i <= intermediate_stop: 83 | if i % intermediate_step == 0: 84 | z_middle.append(z_estimate.cpu()) 85 | print("iter {:04d}: y_error = {:03g}".format(i, loss.item())) 86 | loss.backward() 87 | optimizer.step() 88 | return z_estimate, z_middle if intermediate_step >= 1 else z_estimate 89 | 90 | 91 | class Convex_Sphere_Inverser(Naive_Inverser): 92 | def __init__(self, model, input_size, output_size): 93 | super(Convex_Sphere_Inverser, self).__init__(model, input_size, output_size) 94 | 95 | def __ceil__(self, gt, iterations=2000, learning_rate=0.01, criterion=nn.MSELoss(reduction='sum'), intermediate_step=-1, init=None, norm=1): 96 | assert list(gt.size()[1:]) == self.out_size, "check output size" 97 | if init == None: 98 | z_estimate = torch.randn(self.in_size).cuda() # our estimate, initialized randomly 99 | else: 100 | z_estimate = init 101 | z_estimate.requires_grad = True 102 | 103 | optimizer = optim.Adam([z_estimate], lr=learning_rate) 104 | 105 | # Opt 106 | for i in range(iterations): 107 | y_estimate = self.model(z_estimate) 108 | optimizer.zero_grad() 109 | loss = criterion(y_estimate, gt.detach()) 110 | if intermediate_step >= 1: 111 | if i % intermediate_step == 0: 112 | print("iter {:04d}: y_error = {:03g}".format(i, 
loss.item())) 113 | loss.backward() 114 | optimizer.step() 115 | z_estimate = z_estimate / torch.sqrt(torch.sum(torch.pow(z_estimate, 2))) 116 | return z_estimate 117 | 118 | 119 | class LBFGS_Inverser(Naive_Inverser): 120 | def __init__(self, model, input_size, output_size): 121 | super(LBFGS_Inverser, self).__init__(model, input_size, output_size) 122 | 123 | def __ceil__(self, gt, iterations=1000, learning_rate=0.1, history_size=100, criterion=nn.MSELoss(reduction='sum'), init=None, intermediate_step=-1, intermediate_stop=1000): 124 | assert list(gt.size()[1:]) == self.out_size, "check output size" 125 | if init == None: 126 | z_estimate = torch.randn(self.in_size).cuda() # our estimate, initialized randomly 127 | else: 128 | z_estimate = init 129 | z_estimate.requires_grad = True 130 | 131 | optimizer = optim.LBFGS([z_estimate], lr=learning_rate, history_size=history_size) 132 | 133 | def closure(): 134 | y_estimate = self.model(z_estimate) 135 | loss = criterion(y_estimate, gt.detach()) 136 | optimizer.zero_grad() 137 | loss.backward() 138 | return loss 139 | 140 | z_middle = list() 141 | for i in range(iterations): 142 | y_estimate = self.model(z_estimate) 143 | optimizer.zero_grad() 144 | loss = criterion(y_estimate, gt.detach()) 145 | if intermediate_step >= 1 and i <= intermediate_stop: 146 | if i % intermediate_step == 0: 147 | z_middle.append(z_estimate.cpu()) 148 | print("iter {:04d}: y_error = {:03g}".format(i, loss.item())) 149 | loss.backward() 150 | optimizer.step(closure) 151 | return z_estimate, z_middle if intermediate_step >= 1 else z_estimate 152 | 153 | 154 | class Encoder_Inverser(object): 155 | def __init__(self, model, input_size, output_size, encoder): 156 | self.model = model 157 | self.model.eval() 158 | self.encoder = encoder 159 | self.encoder.eval() 160 | self.in_size = input_size 161 | self.out_size = output_size 162 | 163 | 164 | def __call__(self, gt, iterations=500, learning_rate=0.01, criterion=nn.MSELoss(reduction='sum'), intermediate_step=-1, intermediate_stop=1000): 165 | assert list(gt.size()[1:]) == self.out_size, "check output size" 166 | batch_size = gt.size()[0] 167 | 168 | z_estimate = self.encoder(gt).detach() 169 | z_estimate.requires_grad = True 170 | 171 | optimizer = optim.Adam([z_estimate], lr=learning_rate) 172 | 173 | # Opt 174 | z_middle = list() 175 | for i in range(iterations): 176 | y_estimate = self.model(z_estimate) 177 | optimizer.zero_grad() 178 | loss = criterion(y_estimate, gt.detach()) 179 | if intermediate_step >= 1 and i <= intermediate_stop: 180 | if i % intermediate_step == 0: 181 | z_middle.append(z_estimate.cpu()) 182 | print("iter {:04d}: y_error = {:03g}".format(i, loss.item())) 183 | loss.backward() 184 | optimizer.step() 185 | return z_estimate, z_middle if intermediate_step >= 1 else z_estimate 186 | 187 | 188 | class StyleGAN_w_prime_Inverser(object): 189 | def __init__(self, model, input_size, output_size=1024, layer_number=18): 190 | self.model = model 191 | self.model.eval() 192 | self.in_size = input_size 193 | self.layer_number = layer_number 194 | self.out_size = output_size 195 | 196 | def __call__(self, gt, iterations=1000, learning_rate=0.01, criterion=nn.MSELoss(reduction='sum'), init=None, intermediate_step=-1, intermediate_stop=1000): 197 | assert list(gt.size()[1:]) == self.out_size, "check output size" 198 | batch_size = gt.size()[0] 199 | if init == None: 200 | z_estimate = [] 201 | for i in range(self.layer_number): 202 | z_estimate.append(torch.randn((batch_size,) + self.in_size).cuda()) # our 
estimate, initialized randomly 203 | else: 204 | z_estimate = init 205 | for z in z_estimate: 206 | z.requires_grad = True 207 | 208 | optimizer = optim.Adam(z_estimate, lr=learning_rate) 209 | 210 | # Opt 211 | z_middle = list() 212 | for i in range(iterations): 213 | y_estimate = self.model(z_estimate) 214 | optimizer.zero_grad() 215 | loss = criterion(y_estimate, gt.detach()) 216 | if intermediate_step >= 1 and i <= intermediate_stop: 217 | if i % intermediate_step == 0: 218 | z_middle.append(z_estimate.cpu()) 219 | print("iter {:04d}: y_error = {:03g}".format(i, loss.item())) 220 | loss.backward() 221 | optimizer.step() 222 | return z_estimate, z_middle if intermediate_step >= 1 else z_estimate 223 | -------------------------------------------------------------------------------- /TorchNet/Img2Img.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | try: 6 | from math import log2 7 | except: 8 | from math import log 9 | def log2(x): 10 | return log(x) / log(2) 11 | 12 | from .activation import swish 13 | from .modules import residualBlock, upsampleBlock 14 | 15 | 16 | class Generator(nn.Module): 17 | """ 18 | General Generator with pixel shuffle upsample 19 | """ 20 | def __init__(self, n_residual_blocks, scala, activation=swish): 21 | """ 22 | :param n_residual_blocks: Number of residual blocks 23 | :param scala: factor of upsample 24 | :param activation: function of activation 25 | """ 26 | super(Generator, self).__init__() 27 | self.n_residual_blocks = n_residual_blocks 28 | self.scala = scala 29 | self.act = activation 30 | 31 | self.conv1 = nn.Conv2d(3, 64, 9, stride=1, padding=4) 32 | 33 | for i in range(self.n_residual_blocks): 34 | self.add_module('residual_block' + str(i + 1), residualBlock()) 35 | 36 | self.conv2 = nn.Conv2d(64, 64, 3, stride=1, padding=1) 37 | self.bn2 = nn.BatchNorm2d(64) 38 | 39 | for i in range(int(log2(self.scala))): 40 | self.add_module('upsample' + str(i + 1), upsampleBlock(64, 256)) 41 | 42 | self.conv3 = nn.Conv2d(64, 3, 9, stride=1, padding=4) 43 | 44 | def forward(self, x): 45 | x = self.act(self.conv1(x)) 46 | 47 | y = x.clone() 48 | for i in range(self.n_residual_blocks): 49 | y = self.__getattr__('residual_block' + str(i+1))(y) 50 | 51 | x = self.bn2(self.conv2(y)) + x 52 | 53 | for i in range(int(log2(self.scala))): 54 | x = self.__getattr__('upsample' + str(i+1))(x) 55 | 56 | return F.tanh(self.conv3(x)) 57 | 58 | 59 | class GeneratorY(nn.Module): 60 | """ 61 | General Generator with pixel shuffle upsample 62 | """ 63 | def __init__(self, n_residual_blocks, scala, activation=swish): 64 | """ 65 | :param n_residual_blocks: Number of residual blocks 66 | :param scala: factor of upsample 67 | :param activation: function of activation 68 | """ 69 | super(GeneratorY, self).__init__() 70 | self.n_residual_blocks = n_residual_blocks 71 | self.scala = scala 72 | self.act = activation 73 | 74 | self.conv1 = nn.Conv2d(1, 64, 5, stride=1, padding=2) 75 | 76 | for i in range(self.n_residual_blocks): 77 | self.add_module('residual_block' + str(i + 1), residualBlock()) 78 | 79 | self.conv2 = nn.Conv2d(64, 64, 3, stride=1, padding=1) 80 | self.bn2 = nn.BatchNorm2d(64) 81 | 82 | for i in range(int(log2(self.scala))): 83 | self.add_module('upsample' + str(i + 1), upsampleBlock(64, 256)) 84 | 85 | self.conv3 = nn.Conv2d(64, 1, 7, stride=1, padding=3) 86 | 87 | def forward(self, x): 88 | x = self.act(self.conv1(x)) 89 | 90 | y = x.clone() 91 | 
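        # Descriptive note (added): residual trunk. y passes through n_residual_blocks, the
        # conv2/bn2 output is added back to the pre-residual features x (global skip connection),
        # then log2(scala) pixel-shuffle upsample blocks and a final conv + tanh produce the output.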
for i in range(self.n_residual_blocks): 92 | y = self.__getattr__('residual_block' + str(i+1))(y) 93 | 94 | x = self.bn2(self.conv2(y)) + x 95 | 96 | for i in range(int(log2(self.scala))): 97 | x = self.__getattr__('upsample' + str(i+1))(x) 98 | 99 | return F.tanh(self.conv3(x)) 100 | 101 | 102 | class VideoSRGAN_5(nn.Module): 103 | 104 | def __init__(self, n_residual_blocks, scala, activation=swish): 105 | super(VideoSRGAN_5, self).__init__() 106 | self.n_residual_blocks = n_residual_blocks 107 | self.scala = scala 108 | self.time_window = 5 109 | self.act = activation 110 | 111 | self.conv_1_center = nn.Conv2d(1, 32, 5, stride=1, padding=2) 112 | self.conv_1_1 = nn.Conv2d(1, 32, 5, stride=1, padding=2) 113 | self.conv_1_2 = nn.Conv2d(1, 32, 5, stride=1, padding=2) 114 | 115 | self.shrink_conv = nn.Conv2d(32 * 5, 64, 1, stride=1, padding=0) 116 | 117 | for i in range(self.n_residual_blocks): 118 | self.add_module('residual_block' + str(i + 1), residualBlock()) 119 | 120 | self.conv2 = nn.Conv2d(64, 64, 3, stride=1, padding=1) 121 | self.bn2 = nn.BatchNorm2d(64) 122 | 123 | for i in range(int(log2(self.scala))): 124 | self.add_module('upsample' + str(i + 1), upsampleBlock(64, 256)) 125 | 126 | self.final_pad = nn.ReflectionPad2d(3) 127 | self.conv3 = nn.Conv2d(64, 1, 7, stride=1, padding=0) 128 | 129 | def forward(self, frame_1, frame_2, frame_center, frame_4, frame_5): 130 | center_conv = self.conv_1_center(frame_center) 131 | conv_frame_1 = self.conv_1_2(frame_1) 132 | conv_frame_2 = self.conv_1_1(frame_2) 133 | conv_frame_4 = self.conv_1_1(frame_4) 134 | conv_frame_5 = self.conv_1_2(frame_5) 135 | conv_1 = self.act(torch.cat((conv_frame_1, conv_frame_2, center_conv, conv_frame_4, conv_frame_5), dim=1)) 136 | x = self.act(self.shrink_conv(conv_1)) 137 | 138 | y = x.clone() 139 | for i in range(self.n_residual_blocks): 140 | y = self.__getattr__('residual_block' + str(i + 1))(y) 141 | 142 | x = self.bn2(self.conv2(y)) + x 143 | 144 | for i in range(int(log2(self.scala))): 145 | x = self.__getattr__('upsample' + str(i + 1))(x) 146 | 147 | return F.tanh(self.conv3(self.final_pad(x))) 148 | 149 | 150 | -------------------------------------------------------------------------------- /TorchNet/Losses.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.init as init 7 | import torch.nn.functional as F 8 | from torch.autograd import Variable 9 | 10 | from .VGG import vgg19 11 | 12 | 13 | class MCTHuberLoss(nn.Module): 14 | """ 15 | The Huber Loss used in MCT 16 | """ 17 | def __init__(self, hpyer_lambda, epsilon=0.01): 18 | super(MCTHuberLoss, self).__init__() 19 | self.epsilon = epsilon 20 | self.lamb = hpyer_lambda 21 | self.sobel = nn.Conv2d(2, 4, 3, stride=1, padding=0, groups=2) 22 | weight = np.array([[[[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]]], 23 | [[[1.0, 2.0, 1.0], [0.0, 0.0, 0.0], [-1.0, -2.0, -1.0]]], 24 | [[[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]]], 25 | [[[1.0, 2.0, 1.0], [0.0, 0.0, 0.0], [-1.0, -2.0, -1.0]]]], dtype=np.float32) 26 | bias = np.array([0.0, 0.0, 0.0, 0.0], dtype=np.float32) 27 | self.sobel.weight.data = torch.from_numpy(weight) 28 | self.sobel.bias.data = torch.from_numpy(bias) 29 | 30 | def forward(self, flows): 31 | Grad_Flow = self.sobel(flows) 32 | return torch.sqrt(torch.sum(Grad_Flow * Grad_Flow) + self.epsilon) * self.lamb 33 | 34 | def _sobel(self, flows): 35 | return self.sobel(flows) 36 | 37 | 
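Editor's note: `MCTHuberLoss` above is a Huber-style smoothness regularizer on the Sobel gradients of a two-channel (u, v) flow field, of the form lambda * sqrt(sum |grad flow|^2 + eps); the `TVLoss` defined next plays a similar smoothing role directly on images. A minimal usage sketch follows; the batch and spatial sizes are assumed example values, and the constructor argument keeps the `hpyer_lambda` spelling used by the class.

```
import torch

# Hypothetical shapes: a 4-sample batch of 2-channel (u, v) flow maps.
# In practice `flows` would be the output of a flow-prediction network.
flows = torch.randn(4, 2, 64, 64, requires_grad=True)

smoothness = MCTHuberLoss(hpyer_lambda=0.01)  # class defined above
loss = smoothness(flows)                      # lambda * sqrt(sum(|sobel(flows)|^2) + eps)
loss.backward()                               # gradients flow back into `flows`
```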
38 | class TVLoss(nn.Module): 39 | def __init__(self): 40 | super(TVLoss,self).__init__() 41 | 42 | def forward(self, x): 43 | batch_size = x.size()[0] 44 | h_x = x.size()[2] 45 | w_x = x.size()[3] 46 | count_h = self._tensor_size(x[:, :, 1:, :]) 47 | count_w = self._tensor_size(x[:, :, :, 1:]) 48 | h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1,:]), 2).sum() 49 | w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum() 50 | return 2 * (h_tv / count_h + w_tv / count_w) / batch_size 51 | 52 | def _tensor_size(self,t): 53 | return t.size()[1] * t.size()[2] * t.size()[3] 54 | 55 | 56 | class CropMarginLoss(nn.Module): 57 | def __init__(self, loss=nn.MSELoss, crop=5): 58 | super(CropMarginLoss, self).__init__() 59 | self.loss = loss() 60 | self.crop = crop 61 | 62 | def forward(self, input, target): 63 | return self.loss(input[:, :, self.crop: -self.crop, self.crop: -self.crop], target[:, :, self.crop: -self.crop, self.crop: -self.crop]) 64 | 65 | 66 | class GANLoss(nn.Module): 67 | def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0, 68 | tensor=torch.FloatTensor): 69 | super(GANLoss, self).__init__() 70 | self.real_label = target_real_label 71 | self.fake_label = target_fake_label 72 | self.real_label_var = None 73 | self.fake_label_var = None 74 | self.Tensor = tensor 75 | if use_lsgan: 76 | self.loss = nn.MSELoss() 77 | else: 78 | self.loss = nn.BCEWithLogitsLoss() 79 | 80 | def get_target_tensor(self, input, target_is_real): 81 | if target_is_real: 82 | create_label = ((self.real_label_var is None) or 83 | (self.real_label_var.numel() != input.numel())) 84 | if create_label: 85 | real_tensor = self.Tensor(input.size()).fill_(self.real_label) 86 | self.real_label_var = Variable(real_tensor, requires_grad=False) 87 | target_tensor = self.real_label_var 88 | else: 89 | create_label = ((self.fake_label_var is None) or 90 | (self.fake_label_var.numel() != input.numel())) 91 | if create_label: 92 | fake_tensor = self.Tensor(input.size()).fill_(self.fake_label) 93 | self.fake_label_var = Variable(fake_tensor, requires_grad=False) 94 | target_tensor = self.fake_label_var 95 | return target_tensor 96 | 97 | def forward(self, input, target_is_real): 98 | target_tensor = self.get_target_tensor(input, target_is_real) 99 | return self.loss(input, target_tensor) 100 | 101 | 102 | class VGGLoss(nn.Module): 103 | """ 104 | VGG( 105 | (features): Sequential( 106 | (0): Conv2d (3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 107 | (1): ReLU(inplace) 108 | (2): Conv2d (64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 109 | (3): ReLU(inplace) # 5 x 5 110 | (4): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1)) 111 | 112 | (5): Conv2d (64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 113 | (6): ReLU(inplace) 114 | (7): Conv2d (128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 115 | (8): ReLU(inplace) # 14 x 14 116 | (9): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1)) 117 | 118 | (10): Conv2d (128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 119 | (11): ReLU(inplace) 120 | (12): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 121 | (13): ReLU(inplace) 122 | (14): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 123 | (15): ReLU(inplace) 124 | (16): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 125 | (17): ReLU(inplace) # 48 x 48 126 | (18): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1)) 127 | 128 | (19): Conv2d 
(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 129 | (20): ReLU(inplace) 130 | (21): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 131 | (22): ReLU(inplace) 132 | (23): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 133 | (24): ReLU(inplace) 134 | (25): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 135 | (26): ReLU(inplace) # 116 x 116 136 | (27): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1)) 137 | 138 | (28): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 139 | (29): ReLU(inplace) 140 | (30): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 141 | (31): ReLU(inplace) 142 | (32): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 143 | (33): ReLU(inplace) 144 | (34): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 145 | (35): ReLU(inplace) 146 | (36): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1)) 147 | ) 148 | """ 149 | def __init__(self, vgg_path, layers='5', input='RGB', loss='l1', activate='before'): 150 | super(VGGLoss, self).__init__() 151 | self.input = input 152 | vgg = vgg19() 153 | vgg.load_state_dict(torch.load(vgg_path)) 154 | self.layers = [int(l) for l in layers] 155 | layers_dict = [0, 4, 9, 18, 27, 36] if activate == 'after' else [0, 3, 8, 17, 26, 35] 156 | self.vgg = [] 157 | if loss == 'l1': 158 | self.loss_model = nn.L1Loss() 159 | elif loss == 'l2': 160 | self.loss_model = nn.MSELoss() 161 | else: 162 | raise Exception('Do not support this loss.') 163 | 164 | i = 0 165 | for j in self.layers: 166 | self.vgg.append(nn.Sequential(*list(vgg.features.children())[layers_dict[i]:layers_dict[j]])) 167 | i = j 168 | 169 | def cuda(self, device=None): 170 | for Seq in self.vgg: 171 | Seq.cuda() 172 | self.loss_model.cuda() 173 | 174 | def forward(self, input, target): 175 | if self.input == 'RGB': 176 | input_R, input_G, input_B = torch.split(input, 1, dim=1) 177 | target_R, target_G, target_B = torch.split(target, 1, dim=1) 178 | input_BGR = torch.cat([input_B, input_G, input_R], dim=1) 179 | target_BGR = torch.cat([target_B, target_G, target_R], dim=1) 180 | else: 181 | input_BGR = input 182 | target_BGR = target 183 | 184 | input_list = [input_BGR] 185 | target_list = [target_BGR] 186 | 187 | for Sequential in self.vgg: 188 | input_list.append(Sequential(input_list[-1])) 189 | target_list.append(Sequential(target_list[-1])) 190 | 191 | loss = [] 192 | for i in range(len(self.layers)): 193 | loss.append(self.loss_model(input_list[i + 1], target_list[i + 1].detach())) 194 | if len(loss) != 1: 195 | return sum(loss) 196 | else: 197 | return loss[0] 198 | 199 | 200 | class ContextualLoss(nn.Module): 201 | def __init__(self, sigma=0.1, b=1.0, epsilon=1e-5, similarity='cos'): 202 | super(ContextualLoss, self).__init__() 203 | self.sigma = sigma 204 | self.similarity = similarity 205 | self.b = b 206 | self.e = epsilon 207 | 208 | def cos_similarity(self, image_features, target_features): 209 | # N, V, C 210 | if_vec = image_features.view((image_features.size()[0], image_features.size()[1], -1)).permute(0, 2, 1) 211 | tf_vec = target_features.view((target_features.size()[0], target_features.size()[1], -1)).permute(0, 2, 1) 212 | # Centre by T 213 | tf_mean = torch.mean(tf_vec, dim=1, keepdim=True) 214 | ifc_vec = if_vec - tf_mean 215 | tfc_vec = tf_vec - tf_mean 216 | # L2-norm normalization 217 | ifc_vec_l2 = torch.div(ifc_vec, torch.sqrt(torch.sum(ifc_vec * ifc_vec, dim=2, keepdim=True))) 
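        # The target-side features below get the same unit-L2 normalization; the
        # batched matrix product of the two sets then gives pairwise cosine
        # similarities, and 1 - similarity is returned as a distance matrix of
        # shape (N, HW_image, HW_target).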
218 | tfc_vec_l2 = torch.div(tfc_vec, torch.sqrt(torch.sum(tfc_vec * tfc_vec, dim=2, keepdim=True))) 219 | # cross dot 220 | feature_cos_similarity_matrix = 1 - torch.bmm(ifc_vec_l2, tfc_vec_l2.permute(0, 2, 1)) 221 | return feature_cos_similarity_matrix 222 | 223 | def L2_similarity(self, image_features, target_features): 224 | pass 225 | 226 | def relative_distances(self, feature_similarity_matrix): 227 | relative_dist = feature_similarity_matrix / (torch.min(feature_similarity_matrix, dim=2, keepdim=True)[0] + self.e) 228 | return relative_dist 229 | 230 | def weighted_average_distances(self, relative_distances_matrix): 231 | weights_before_normalization = torch.exp((self.b - relative_distances_matrix) / self.sigma) 232 | weights_sum = torch.sum(weights_before_normalization, dim=2, keepdim=True) 233 | weights_normalized = torch.div(weights_before_normalization, weights_sum) 234 | return weights_normalized 235 | 236 | def CX(self, feature_similarity_matrix): 237 | CX_i_j = self.weighted_average_distances(self.relative_distances(feature_similarity_matrix)) 238 | CX_j_i = CX_i_j.permute(0, 2, 1) 239 | max_i_on_j = torch.max(CX_j_i, dim=1)[0] 240 | CS = torch.mean(max_i_on_j, dim=1) 241 | CX = - torch.log(CS) 242 | CX_loss = torch.mean(CX) 243 | return CX_loss 244 | 245 | def forward(self, image_features, target_features): 246 | if self.similarity == 'cos': 247 | feature_similarity_matrix = self.cos_similarity(image_features, target_features) 248 | elif self.similarity == 'l2': 249 | feature_similarity_matrix = self.L2_similarity(image_features, target_features) 250 | else: 251 | feature_similarity_matrix = self.cos_similarity(image_features, target_features) 252 | return self.CX(feature_similarity_matrix) 253 | 254 | 255 | class VGGContextualLoss(nn.Module): 256 | def __init__(self, vgg_path, layer=3, input='RGB', sigma=0.1, b=1.0, epsilon=1e-5, similarity='cos', activate='before'): 257 | super(VGGContextualLoss, self).__init__() 258 | self.input = input 259 | vgg = vgg19() 260 | vgg.load_state_dict(torch.load(vgg_path)) 261 | self.layer = layer # layers in [1, 2, 3, 4, 5] 262 | self.CXLoss = ContextualLoss(sigma=sigma, b=b, epsilon=epsilon, similarity=similarity) 263 | layers_dict = [0, 4, 9, 18, 27, 36] if activate == 'after' else [0, 3, 8, 17, 26, 35] 264 | self.sequential = nn.Sequential(*list(vgg.features.children())[0:layers_dict[layer-1]]) 265 | 266 | def forward(self, image, target): 267 | if self.input == 'RGB': 268 | input_R, input_G, input_B = torch.split(image, 1, dim=1) 269 | target_R, target_G, target_B = torch.split(target, 1, dim=1) 270 | image_BGR = torch.cat([input_B, input_G, input_R], dim=1) 271 | target_BGR = torch.cat([target_B, target_G, target_R], dim=1) 272 | else: 273 | image_BGR = image 274 | target_BGR = target 275 | 276 | image_features = self.sequential(image_BGR) 277 | target_features = self.sequential(target_BGR) 278 | return self.CXLoss(image_features, target_features) 279 | 280 | def cuda(self, device=None): 281 | self.sequential.cuda() 282 | 283 | 284 | 285 | 286 | -------------------------------------------------------------------------------- /TorchNet/OpticalFlow.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import functools 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | 9 | from ..DataTools.Loaders import _add_batch_one 10 | from .modules import Attention 11 | 12 | 13 | class _CoarseFlow(nn.Module): 14 | """ 
15 | Coarse Flow Network in MCT 16 | |----------------------| 17 | | Input two frame | 18 | |----------------------| 19 | | Conv k5-n24-s2, ReLu | 20 | |----------------------| 21 | | Conv k3-n24-s1, ReLu | 22 | |----------------------| 23 | | Conv k5-n24-s2, ReLu | 24 | |----------------------| 25 | | Conv k3-n24-s1, ReLu | 26 | |----------------------| 27 | | Conv k3-n32-s1, Tanh | 28 | |----------------------| 29 | | Pixel Shuffle x4 | 30 | |----------------------| 31 | """ 32 | def __init__(self, input_channel=1): 33 | super(_CoarseFlow, self).__init__() 34 | self.channel = input_channel 35 | self.conv1 = nn.Conv2d(input_channel * 2, 24, 5, stride=2, padding=2) 36 | self.conv2 = nn.Conv2d(24, 24, 3, stride=1, padding=1) 37 | self.conv3 = nn.Conv2d(24, 24, 5, stride=2, padding=2) 38 | self.conv4 = nn.Conv2d(24, 24, 3, stride=1, padding=1) 39 | self.conv5 = nn.Conv2d(24, 32, 3, stride=1, padding=1) 40 | self.pix = nn.PixelShuffle(4) 41 | 42 | def forward(self, frame_t, frame_tp1): 43 | input = torch.cat([frame_t, frame_tp1], dim=1) 44 | return self.pix( 45 | self.conv5( 46 | F.relu(self.conv4( 47 | F.relu(self.conv3( 48 | F.relu(self.conv2( 49 | F.relu(self.conv1(input)) 50 | )) 51 | )) 52 | )) 53 | ) 54 | ) 55 | 56 | 57 | class _FineFlow(nn.Module): 58 | """ 59 | Fine Flow Network in MCT 60 | |----------------------| 61 | | Input two frame | 62 | |----------------------| 63 | | Conv k5-n24-s2, ReLu | 64 | |----------------------| 65 | | Conv k3-n24-s1, ReLu | 66 | |----------------------| 67 | | Conv k3-n24-s1, ReLu | 68 | |----------------------| 69 | | Conv k3-n24-s1, ReLu | 70 | |----------------------| 71 | | Conv k3-n8-s1, Tanh | 72 | |----------------------| 73 | | Pixel Shuffle x2 | 74 | |----------------------| 75 | """ 76 | def __init__(self, input_channel=1): 77 | super(_FineFlow, self).__init__() 78 | self.channel = input_channel 79 | self.conv1 = nn.Conv2d(input_channel * 3 + 2, 24, 5, stride=2, padding=2) 80 | self.conv2 = nn.Conv2d(24, 24, 3, stride=1, padding=1) 81 | self.conv3 = nn.Conv2d(24, 24, 3, stride=1, padding=1) 82 | self.conv4 = nn.Conv2d(24, 24, 3, stride=1, padding=1) 83 | self.conv5 = nn.Conv2d(24, 8, 3, stride=1, padding=1) 84 | self.pix = nn.PixelShuffle(2) 85 | 86 | def forward(self, frame_t, frame_tp1, flow, coarse_frame_tp1): 87 | input = torch.cat([frame_t, frame_tp1, flow, coarse_frame_tp1], dim=1) 88 | return self.pix( 89 | self.conv5( 90 | F.relu(self.conv4( 91 | F.relu(self.conv3( 92 | F.relu(self.conv2( 93 | F.relu(self.conv1(input)) 94 | )) 95 | )) 96 | )) 97 | ) 98 | ) 99 | 100 | 101 | class Advector(nn.Module): 102 | """ 103 | According to the PyTorch documentation 104 | The X axis is positively toward left and the Y-axis is positively toward upward 105 | """ 106 | def __init__(self): 107 | super(Advector, self).__init__() 108 | self.std_theta = np.array([[[1, 0, 0], [0, 1, 0]]], dtype=np.float32) 109 | 110 | def forward(self, frame_t, vectors): 111 | N, C, H, W = frame_t.size() 112 | vectors[:, 0] = 2 * vectors[:, 0] / W 113 | vectors[:, 1] = 2 * vectors[:, 1] / H 114 | if isinstance(frame_t.data, torch.cuda.FloatTensor): 115 | std_theta = Variable(torch.from_numpy(np.repeat(self.std_theta, N, axis=0))).cuda() 116 | else: 117 | std_theta = Variable(torch.from_numpy(np.repeat(self.std_theta, N, axis=0))) 118 | for i in range(N): 119 | std_theta[i, :, 2] = vectors[i] 120 | if isinstance(frame_t.data, torch.cuda.FloatTensor): 121 | affine = F.affine_grid(std_theta, frame_t.size()).cuda() 122 | else: 123 | affine = F.affine_grid(std_theta, 
frame_t.size()) 124 | return F.grid_sample(frame_t, affine) 125 | 126 | 127 | class Warp(nn.Module): 128 | """ 129 | Warp Using Optical Flow 130 | """ 131 | def __init__(self): 132 | super(Warp, self).__init__() 133 | self.std_theta = np.eye(2, 3, dtype=np.float32).reshape((1, 2, 3)) 134 | 135 | def forward(self, frame_t, flow_field): 136 | """ 137 | :param frame_t: input batch of images (N x C x IH x IW) 138 | :param flow_field: flow_field with shape(N x 2 x OH x OW) 139 | :return: output Tensor 140 | """ 141 | N, C, H, W = frame_t.size() 142 | std_theta = torch.from_numpy(np.repeat(self.std_theta, N, axis=0)) 143 | if isinstance(frame_t.data, torch.cuda.FloatTensor): 144 | std = F.affine_grid(std_theta, frame_t.size()).cuda() 145 | else: 146 | std = F.affine_grid(std_theta, frame_t.size()) 147 | flow_field[:, 0, :, :] = flow_field[:, 0, :, :] / W 148 | flow_field[:, 1, :, :] = flow_field[:, 1, :, :] / H 149 | return F.grid_sample(frame_t, std + flow_field.permute(0, 2, 3, 1)) 150 | 151 | 152 | class FlowField(nn.Module): 153 | """ 154 | The final Fine Flow 155 | """ 156 | def __init__(self): 157 | super(FlowField, self).__init__() 158 | self.coarse_net = _CoarseFlow() 159 | self.fine_net = _FineFlow() 160 | self.warp = Warp() 161 | 162 | def forward(self, frame_t, frame_tp1): 163 | coarse_flow = self.coarse_net(frame_t, frame_tp1) 164 | coarse_frame_tp1 = self.warp(frame_t, coarse_flow) 165 | return self.fine_net(frame_t, frame_tp1, coarse_flow, coarse_frame_tp1) + coarse_flow 166 | 167 | 168 | class _CoarseFlowNoStride(nn.Module): 169 | """ 170 | Coarse Flow Network in MCT without Stride 171 | |----------------------| 172 | | Input two frame | 173 | |----------------------| 174 | | Conv k5-n32-s1, ReLu | 175 | |----------------------| 176 | | Conv k3-n32-s1, ReLu | 177 | |----------------------| 178 | | Conv k5-n32-s1, ReLu | 179 | |----------------------| 180 | | Conv k3-n32-s1, ReLu | 181 | |----------------------| 182 | | Conv k3-n32-s1, Tanh | 183 | |----------------------| 184 | | Pixel Shuffle x4 | 185 | |----------------------| 186 | """ 187 | def __init__(self, input_channel=1): 188 | super(_CoarseFlowNoStride, self).__init__() 189 | self.channel = input_channel 190 | self.conv1 = nn.Conv2d(input_channel * 2, 32, 7, stride=1, padding=3) 191 | self.conv2 = nn.Conv2d(32, 32, 5, stride=1, padding=2) 192 | self.conv3 = nn.Conv2d(32, 32, 5, stride=1, padding=2) 193 | self.conv4 = nn.Conv2d(32, 32, 5, stride=1, padding=2) 194 | self.conv5 = nn.Conv2d(32, 2, 5, stride=1, padding=2) 195 | 196 | def forward(self, frame_t, frame_tp1): 197 | input = torch.cat([frame_t, frame_tp1], dim=1) 198 | return self.conv5( 199 | F.relu(self.conv4( 200 | F.relu(self.conv3( 201 | F.relu(self.conv2( 202 | F.relu(self.conv1(input)) 203 | )) 204 | )) 205 | )) 206 | ) 207 | 208 | 209 | class _FineFlowNoStride(nn.Module): 210 | """ 211 | Fine Flow Network in MCT without Stride 212 | |----------------------| 213 | | Input two frame | 214 | |----------------------| 215 | | Conv k5-n24-s1, ReLu | 216 | |----------------------| 217 | | Conv k3-n24-s1, ReLu | 218 | |----------------------| 219 | | Conv k3-n24-s1, ReLu | 220 | |----------------------| 221 | | Conv k3-n24-s1, ReLu | 222 | |----------------------| 223 | | Conv k3-n8-s1, Tanh | 224 | |----------------------| 225 | | Pixel Shuffle x2 | 226 | |----------------------| 227 | """ 228 | def __init__(self, input_channel=1): 229 | super(_FineFlowNoStride, self).__init__() 230 | self.channel = input_channel 231 | self.conv1 = nn.Conv2d(input_channel * 3 + 2, 
32, 5, stride=1, padding=2) 232 | self.conv2 = nn.Conv2d(32, 32, 5, stride=1, padding=2) 233 | self.conv3 = nn.Conv2d(32, 32, 3, stride=1, padding=1) 234 | self.conv4 = nn.Conv2d(32, 32, 3, stride=1, padding=1) 235 | self.conv5 = nn.Conv2d(32, 2, 3, stride=1, padding=1) 236 | 237 | def forward(self, frame_t, frame_tp1, flow, coarse_frame_tp1): 238 | input = torch.cat([frame_t, frame_tp1, flow, coarse_frame_tp1], dim=1) 239 | return self.conv5( 240 | F.relu(self.conv4( 241 | F.relu(self.conv3( 242 | F.relu(self.conv2( 243 | F.relu(self.conv1(input)) 244 | )) 245 | )) 246 | )) 247 | ) 248 | 249 | 250 | class FlowFielsNoStride(nn.Module): 251 | """ 252 | The final Fine Flow 253 | """ 254 | def __init__(self): 255 | super(FlowFielsNoStride, self).__init__() 256 | self.coarse_net = _CoarseFlowNoStride() 257 | self.fine_net = _FineFlowNoStride() 258 | self.warp = Warp() 259 | 260 | def forward(self, frame_t, frame_tp1): 261 | coarse_flow = self.coarse_net(frame_t, frame_tp1) 262 | coarse_frame_tp1 = self.warp(frame_t, coarse_flow) 263 | return self.fine_net(frame_t, frame_tp1, coarse_flow, coarse_frame_tp1) + coarse_flow 264 | 265 | 266 | class Affine_Translation(nn.Module): 267 | def __init__(self, input_channel=1, freedom_degree=2): 268 | super(Affine_Translation, self).__init__() 269 | self.channel = input_channel 270 | self.conv1 = nn.Conv2d(input_channel * 2, 16, 5, stride=2, padding=2) 271 | self.conv2 = nn.Conv2d(16, 16, 5, stride=1, padding=2) 272 | self.conv3 = nn.Conv2d(16, 16, 5, stride=2, padding=2) 273 | self.conv4 = nn.Conv2d(16, 16, 5, stride=1, padding=2) 274 | self.conv5 = nn.Conv2d(16, 32, 3, stride=2, padding=1) 275 | self.linear = nn.Linear(32, freedom_degree) 276 | 277 | def forward(self, frame_t, frame_tp1): 278 | input = torch.cat([frame_t, frame_tp1], dim=1) 279 | x = F.relu(self.conv1(input)) 280 | x = F.relu(self.conv2(x)) 281 | x = F.relu(self.conv3(x)) 282 | x = F.relu(self.conv4(x)) 283 | x = F.relu(self.conv5(x)) 284 | x = F.avg_pool2d(x, kernel_size=x.size()[2:]) 285 | x = self.linear(x.view(x.size()[:2])) 286 | return x 287 | 288 | 289 | class AdvectorFCN(nn.Module): 290 | """ 291 | Discriminator with attention module 292 | """ 293 | 294 | def __init__(self, input_nc=1, ndf=16, output_nc=2, norm_layer=nn.InstanceNorm2d, use_sigmoid=False, feature_channels=16, 295 | down_samples=2): 296 | super(AdvectorFCN, self).__init__() 297 | 298 | if type(norm_layer) == functools.partial: 299 | use_bias = norm_layer.func == nn.InstanceNorm2d 300 | else: 301 | use_bias = norm_layer == nn.InstanceNorm2d 302 | 303 | pixelS_net = [ 304 | nn.Conv2d(input_nc, ndf, kernel_size=3, stride=1, padding=1), 305 | nn.LeakyReLU(0.2, True), 306 | 307 | nn.Conv2d(ndf, ndf * 2, kernel_size=3, stride=1, padding=1, bias=use_bias), 308 | norm_layer(ndf * 2), 309 | nn.LeakyReLU(0.2, True), 310 | 311 | nn.Conv2d(ndf * 2, ndf * 2, kernel_size=3, stride=1, padding=1, bias=use_bias), 312 | norm_layer(ndf * 2), 313 | nn.LeakyReLU(0.2, True), 314 | 315 | nn.Conv2d(ndf * 2, ndf * 2, kernel_size=3, stride=1, padding=1, bias=use_bias), 316 | norm_layer(ndf * 2), 317 | nn.LeakyReLU(0.2, True), 318 | 319 | nn.Conv2d(ndf * 2, output_nc, kernel_size=1, stride=1, padding=0, bias=use_bias), 320 | norm_layer(ndf * 2), 321 | nn.LeakyReLU(0.2, True), 322 | ] 323 | 324 | self.pixelD = nn.Sequential(*pixelS_net) 325 | self.attention = Attention(input_channel=input_nc, feature_channels=feature_channels, down_samples=down_samples) 326 | 327 | def forward(self, input, train=False, is_attention=False): 328 | 
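        # Attention-weighted pooling: pixelD yields a per-pixel prediction map that
        # is multiplied by the attention map, summed spatially and divided by the
        # attention mass, giving one prediction vector per sample. Outside training
        # the attention map is detached, so gradients only reach pixelD.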
pixel_level_pred = self.pixelD(input) 329 | if train: 330 | attention_map = self.attention(input) 331 | else: 332 | attention_map = self.attention(input).detach() 333 | weighted_pred = torch.mul(pixel_level_pred, attention_map) 334 | weighted_sum = torch.sum(weighted_pred.view(weighted_pred.size(0), weighted_pred.size(1), -1), dim=2) 335 | attention_sum = torch.sum(attention_map.view(attention_map.size(0), attention_map.size(1), -1), dim=2) 336 | final_pred = torch.div(weighted_sum, attention_sum) 337 | return_tuple = final_pred 338 | if train: 339 | return final_pred, pixel_level_pred 340 | elif is_attention: 341 | return final_pred, pixel_level_pred, attention_map 342 | else: 343 | return return_tuple 344 | 345 | 346 | -------------------------------------------------------------------------------- /TorchNet/Optim.py: -------------------------------------------------------------------------------- 1 | import math 2 | from torch.optim.optimizer import Optimizer 3 | 4 | 5 | class AdamPre(Optimizer): 6 | """Implements Adam algorithm with prediction step. 7 | 8 | This class implements lookahead version of Adam Optimizer. 9 | The structure of class is similar to Adam class in Pytorch. 10 | 11 | Arguments: 12 | params (iterable): iterable of parameters to optimize or dicts defining 13 | parameter groups 14 | lr (float, optional): learning rate (default: 1e-3) 15 | betas (Tuple[float, float], optional): coefficients used for computing 16 | running averages of gradient and its square (default: (0.9, 0.999)) 17 | eps (float, optional): term added to the denominator to improve 18 | numerical stability (default: 1e-8) 19 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 20 | """ 21 | 22 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 23 | weight_decay=0, name='NotGiven'): 24 | self.first_time = True 25 | self.name = name 26 | defaults = dict(lr=lr, betas=betas, eps=eps, 27 | weight_decay=weight_decay) 28 | super(AdamPre, self).__init__(params, defaults) 29 | 30 | def step(self, closure=None): 31 | """Performs a single optimization step. 32 | 33 | Arguments: 34 | closure (callable, optional): A closure that reevaluates the model 35 | and returns the loss. 
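        Note (usage sketch, not prescribed by this class): in adversarial or
        other minimax training one typically calls step() for the regular
        update, stepLookAhead() to move to the extrapolated weights
        2*w_t - w_{t-1} before evaluating the opponent, and then
        restoreStepLookAhead() to return to w_t.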
36 | """ 37 | self.first_time = False 38 | loss = None 39 | if closure is not None: 40 | loss = closure() 41 | 42 | for group in self.param_groups: 43 | for p in group['params']: 44 | if p.grad is None: 45 | continue 46 | grad = p.grad.data 47 | 48 | state = self.state[p] 49 | 50 | # State initialization 51 | if len(state) == 0: 52 | state['step'] = 0 53 | # Exponential moving average of gradient values 54 | state['exp_avg'] = grad.new().resize_as_(grad).zero_() 55 | # Exponential moving average of squared gradient values 56 | state['exp_avg_sq'] = grad.new().resize_as_(grad).zero_() 57 | 58 | state['oldWeights'] = p.data.clone() 59 | 60 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 61 | beta1, beta2 = group['betas'] 62 | 63 | state['step'] += 1 64 | 65 | if group['weight_decay'] != 0: 66 | grad = grad.add(group['weight_decay'], p.data) 67 | 68 | # Decay the first and second moment running average coefficient 69 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 70 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 71 | 72 | denom = exp_avg_sq.sqrt().add_(group['eps']) 73 | 74 | bias_correction1 = 1 - beta1 ** min(state['step'],1022) 75 | bias_correction2 = 1 - beta2 ** min(state['step'],1022) 76 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 77 | 78 | p.data.addcdiv_(-step_size, exp_avg, denom) 79 | return loss 80 | 81 | def stepLookAhead(self, closure=None): 82 | """Performs a single optimization step. 83 | 84 | Arguments: 85 | closure (callable, optional): A closure that reevaluates the model 86 | and returns the loss. 87 | """ 88 | if self.first_time: 89 | return None 90 | loss = None 91 | if closure is not None: 92 | loss = closure() 93 | 94 | for group in self.param_groups: 95 | for p in group['params']: 96 | if p.grad is None: 97 | continue 98 | state = self.state[p] 99 | temp_grad = p.data.sub(state['oldWeights']) 100 | state['oldWeights'].copy_(p.data) 101 | p.data.add_(temp_grad) 102 | return loss 103 | 104 | 105 | def restoreStepLookAhead(self, closure=None): 106 | """Performs a single optimization step. 107 | 108 | Arguments: 109 | closure (callable, optional): A closure that reevaluates the model 110 | and returns the loss. 111 | """ 112 | if self.first_time: 113 | return None 114 | loss = None 115 | if closure is not None: 116 | loss = closure() 117 | 118 | for group in self.param_groups: 119 | for p in group['params']: 120 | if p.grad is None: 121 | continue 122 | state = self.state[p] 123 | p.data.copy_(state['oldWeights']) 124 | return loss 125 | -------------------------------------------------------------------------------- /TorchNet/PGGAN.py: -------------------------------------------------------------------------------- 1 | class ProgressiveGenerator(nn.Sequential): 2 | def __init__(self, resolution=None, sizes=None, modify_sequence=None, 3 | output_tanh=False): 4 | ''' 5 | A pytorch progessive GAN generator that can be converted directly 6 | from either a tensorflow model or a theano model. It consists of 7 | a sequence of convolutional layers, organized in pairs, with an 8 | upsampling and reduction of channels at every other layer; and 9 | then finally followed by an output layer that reduces it to an 10 | RGB [-1..1] image. 11 | 12 | The network can be given more layers to increase the output 13 | resolution. The sizes argument indicates the fieature depth at 14 | each upsampling, starting with the input z: [input-dim, 4x4-depth, 15 | 8x8-depth, 16x16-depth...]. 
The output dimension is 2 * 2**len(sizes) 16 | 17 | Some default architectures can be selected by supplying the 18 | resolution argument instead. 19 | 20 | The optional modify_sequence function can be used to transform the 21 | sequence of layers before the network is constructed. 22 | 23 | If output_tanh is set to True, the network applies a tanh to clamp 24 | the output to [-1,1] before output; otherwise the output is unclamped. 25 | ''' 26 | assert (resolution is None) != (sizes is None) 27 | if sizes is None: 28 | sizes = { 29 | 8: [512, 512, 512], 30 | 16: [512, 512, 512, 512], 31 | 32: [512, 512, 512, 512, 256], 32 | 64: [512, 512, 512, 512, 256, 128], 33 | 128: [512, 512, 512, 512, 256, 128, 64], 34 | 256: [512, 512, 512, 512, 256, 128, 64, 32], 35 | 1024: [512, 512, 512, 512, 512, 256, 128, 64, 32, 16] 36 | }[resolution] 37 | # Follow the schedule of upsampling given by sizes. 38 | # layers are called: layer1, layer2, etc; then output_128x128 39 | sequence = [] 40 | def add_d(layer, name=None): 41 | if name is None: 42 | name = 'layer%d' % (len(sequence) + 1) 43 | sequence.append((name, layer)) 44 | add_d(NormConvBlock(sizes[0], sizes[1], kernel_size=4, padding=3)) 45 | add_d(NormConvBlock(sizes[1], sizes[1], kernel_size=3, padding=1)) 46 | for i, (si, so) in enumerate(zip(sizes[1:-1], sizes[2:])): 47 | add_d(NormUpscaleConvBlock(si, so, kernel_size=3, padding=1)) 48 | add_d(NormConvBlock(so, so, kernel_size=3, padding=1)) 49 | # Create an output layer. During training, the progressive GAN 50 | # learns several such output layers for various resolutions; we 51 | # just include the last (highest resolution) one. 52 | dim = 4 * (2 ** (len(sequence) // 2 - 1)) 53 | add_d(OutputConvBlock(sizes[-1], tanh=output_tanh), 54 | name='output_%dx%d' % (dim, dim)) 55 | # Allow the sequence to be modified 56 | if modify_sequence is not None: 57 | sequence = modify_sequence(sequence) 58 | super().__init__(OrderedDict(sequence)) 59 | 60 | def forward(self, x): 61 | # Convert vector input to 1x1 featuremap. 62 | x = x.view(x.shape[0], x.shape[1], 1, 1) 63 | return super().forward(x) 64 | 65 | 66 | class ProgressiveGeneratorEncoder(nn.Sequential): 67 | def __init__(self, resolution=None, sizes=None, modify_sequence=None, 68 | output_tanh=False): 69 | ''' 70 | A pytorch progessive GAN generator that can be converted directly 71 | from either a tensorflow model or a theano model. It consists of 72 | a sequence of convolutional layers, organized in pairs, with an 73 | upsampling and reduction of channels at every other layer; and 74 | then finally followed by an output layer that reduces it to an 75 | RGB [-1..1] image. 76 | 77 | The network can be given more layers to increase the output 78 | resolution. The sizes argument indicates the fieature depth at 79 | each upsampling, starting with the input z: [input-dim, 4x4-depth, 80 | 8x8-depth, 16x16-depth...]. The output dimension is 2 * 2**len(sizes) 81 | 82 | Some default architectures can be selected by supplying the 83 | resolution argument instead. 84 | 85 | The optional modify_sequence function can be used to transform the 86 | sequence of layers before the network is constructed. 87 | 88 | If output_tanh is set to True, the network applies a tanh to clamp 89 | the output to [-1,1] before output; otherwise the output is unclamped. 
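        Note: this encoder mirrors the generator above. The sizes schedule is
        reversed with a 3-channel RGB entry prepended, NormUpscaleConvBlock is
        replaced by NormDownConvBlock, and a final 4x4 valid convolution collapses
        the feature map into the latent vector; no ToRGB output layer is attached.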
90 | ''' 91 | assert (resolution is None) != (sizes is None) 92 | if sizes is None: 93 | sizes = { 94 | 8: [512, 512, 512], 95 | 16: [512, 512, 512, 512], 96 | 32: [512, 512, 512, 512, 256], 97 | 64: [512, 512, 512, 512, 256, 128], 98 | 128: [512, 512, 512, 512, 256, 128, 64], 99 | 256: [512, 512, 512, 512, 256, 128, 64, 32], 100 | 1024: [512, 512, 512, 512, 512, 256, 128, 64, 32, 16] 101 | }[resolution] 102 | sizes += [3] 103 | sizes.reverse() 104 | # Follow the schedule of upsampling given by sizes. 105 | # layers are called: layer1, layer2, etc; then output_128x128 106 | sequence = [] 107 | def add_d(layer, name=None): 108 | if name is None: 109 | name = 'layer%d' % (len(sequence) + 1) 110 | sequence.append((name, layer)) 111 | 112 | add_d(NormConvBlock(sizes[0], sizes[1], kernel_size=3, padding=1)) 113 | add_d(NormConvBlock(sizes[1], sizes[1], kernel_size=3, padding=1)) 114 | for i, (si, so) in enumerate(zip(sizes[1:-2], sizes[2:-1])): 115 | add_d(NormDownConvBlock(si, so, kernel_size=3, padding=1)) 116 | add_d(NormConvBlock(so, so, kernel_size=3, padding=1)) 117 | # Create an output layer. During training, the progressive GAN 118 | # learns several such output layers for various resolutions; we 119 | # just include the last (highest resolution) one. 120 | # dim = 4 * (2 ** (len(sequence) // 2 - 1)) 121 | # add_d(OutputConvBlock(sizes[-1], tanh=output_tanh), 122 | # name='output_%dx%d' % (dim, dim)) 123 | # Allow the sequence to be modified 124 | add_d(NormConvBlock(sizes[-2], sizes[-1], kernel_size=4, padding=0)) 125 | if modify_sequence is not None: 126 | sequence = modify_sequence(sequence) 127 | super().__init__(OrderedDict(sequence)) 128 | 129 | def forward(self, x): 130 | # Convert vector input to 1x1 featuremap. 131 | # x = x.view(x.shape[0], x.shape[1], 1, 1) 132 | return super().forward(x) 133 | 134 | 135 | class DoubleResolutionLayer(nn.Module): 136 | def forward(self, x): 137 | x = F.interpolate(x, scale_factor=2, mode='nearest') 138 | return x 139 | 140 | class HalfResolutionLayer(nn.Module): 141 | def forward(self, x): 142 | x = F.avg_pool2d(x, kernel_size=2) 143 | return x 144 | 145 | 146 | class WScaleLayer(nn.Module): 147 | def __init__(self, size, fan_in, gain=np.sqrt(2)): 148 | super(WScaleLayer, self).__init__() 149 | self.scale = gain / np.sqrt(fan_in) # No longer a parameter 150 | self.b = nn.Parameter(torch.randn(size)) 151 | self.size = size 152 | 153 | def forward(self, x): 154 | x_size = x.size() 155 | x = x * self.scale + self.b.view(1, -1, 1, 1).expand( 156 | x_size[0], self.size, x_size[2], x_size[3]) 157 | return x 158 | 159 | 160 | class NormConvBlock(nn.Module): 161 | def __init__(self, in_channels, out_channels, kernel_size, padding): 162 | super(NormConvBlock, self).__init__() 163 | self.norm = PixelNormLayer() 164 | self.conv = nn.Conv2d( 165 | in_channels, out_channels, kernel_size, 1, padding, bias=False) 166 | self.wscale = WScaleLayer(out_channels, in_channels, 167 | gain=np.sqrt(2) / kernel_size) 168 | self.relu = nn.LeakyReLU(inplace=True, negative_slope=0.2) 169 | 170 | def forward(self, x): 171 | x = self.norm(x) 172 | x = self.conv(x) 173 | x = self.relu(self.wscale(x)) 174 | return x 175 | 176 | 177 | class NormUpscaleConvBlock(nn.Module): 178 | def __init__(self, in_channels, out_channels, kernel_size, padding): 179 | super(NormUpscaleConvBlock, self).__init__() 180 | self.norm = PixelNormLayer() 181 | self.up = DoubleResolutionLayer() 182 | self.conv = nn.Conv2d( 183 | in_channels, out_channels, kernel_size, 1, padding, bias=False) 184 | 
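        # As in NormConvBlock, the convolution has no bias: WScaleLayer rescales
        # activations by gain / sqrt(fan_in) at run time (equalized learning rate)
        # and carries the per-channel bias instead.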
self.wscale = WScaleLayer(out_channels, in_channels, 185 | gain=np.sqrt(2) / kernel_size) 186 | self.relu = nn.LeakyReLU(inplace=True, negative_slope=0.2) 187 | 188 | def forward(self, x): 189 | x = self.norm(x) 190 | x = self.up(x) 191 | x = self.conv(x) 192 | x = self.relu(self.wscale(x)) 193 | return x 194 | 195 | 196 | class NormDownConvBlock(nn.Module): 197 | def __init__(self, in_channels, out_channels, kernel_size, padding): 198 | super(NormDownConvBlock, self).__init__() 199 | self.norm = PixelNormLayer() 200 | self.down = HalfResolutionLayer() 201 | self.conv = nn.Conv2d( 202 | in_channels, out_channels, kernel_size, 1, padding, bias=False) 203 | self.wscale = WScaleLayer(out_channels, in_channels, 204 | gain=np.sqrt(2) / kernel_size) 205 | self.relu = nn.LeakyReLU(inplace=True, negative_slope=0.2) 206 | 207 | def forward(self, x): 208 | x = self.norm(x) 209 | x = self.down(x) 210 | x = self.conv(x) 211 | x = self.relu(self.wscale(x)) 212 | return x 213 | 214 | 215 | class OutputConvBlock(nn.Module): 216 | def __init__(self, in_channels, tanh=False): 217 | super().__init__() 218 | self.norm = PixelNormLayer() 219 | self.conv = nn.Conv2d( 220 | in_channels, 3, kernel_size=1, padding=0, bias=False) 221 | self.wscale = WScaleLayer(3, in_channels, gain=1) 222 | self.clamp = nn.Hardtanh() if tanh else (lambda x: x) 223 | 224 | def forward(self, x): 225 | x = self.norm(x) 226 | x = self.conv(x) 227 | x = self.wscale(x) 228 | x = self.clamp(x) 229 | return x 230 | 231 | ############################################################################### 232 | # Conversion 233 | ############################################################################### 234 | 235 | def from_tf_parameters(parameters): 236 | ''' 237 | Instantiate from tensorflow variables. 238 | ''' 239 | state_dict = state_dict_from_tf_parameters(parameters) 240 | sizes = sizes_from_state_dict(state_dict) 241 | result = ProgressiveGenerator(sizes=sizes) 242 | result.load_state_dict(state_dict) 243 | return result 244 | 245 | 246 | def from_old_pt_dict(parameters): 247 | ''' 248 | Instantiate from old pytorch state dict. 249 | ''' 250 | state_dict = state_dict_from_old_pt_dict(parameters) 251 | sizes = sizes_from_state_dict(state_dict) 252 | result = ProgressiveGenerator(sizes=sizes) 253 | result.load_state_dict(state_dict) 254 | return result 255 | 256 | 257 | def sizes_from_state_dict(params): 258 | ''' 259 | In a progressive GAN, the number of channels can change after each 260 | upsampling. This function reads the state dict to figure the 261 | number of upsamplings and the channel depth of each filter. 262 | ''' 263 | sizes = [] 264 | for i in itertools.count(): 265 | pt_layername = 'layer%d' % (i + 1) 266 | try: 267 | weight = params['%s.conv.weight' % pt_layername] 268 | except KeyError: 269 | break 270 | if i == 0: 271 | sizes.append(weight.shape[1]) 272 | if i % 2 == 0: 273 | sizes.append(weight.shape[0]) 274 | return sizes 275 | 276 | 277 | def state_dict_from_tf_parameters(parameters): 278 | ''' 279 | Conversion from tensorflow parameters 280 | ''' 281 | params = dict(parameters) 282 | result = {} 283 | sizes = [] 284 | for i in itertools.count(): 285 | resolution = 4 * (2 ** (i // 2)) 286 | # Translate parameter names. 
For example: 287 | # 4x4/Dense/weight -> layer1.conv.weight 288 | # 32x32/Conv0_up/weight -> layer7.conv.weight 289 | # 32x32/Conv1/weight -> layer8.conv.weight 290 | tf_layername = '%dx%d/%s' % (resolution, resolution, 291 | 'Dense' if i == 0 else 'Conv' if i == 1 else 292 | 'Conv0_up' if i % 2 == 0 else 'Conv1') 293 | pt_layername = 'layer%d' % (i + 1) 294 | # Stop looping when we run out of parameters. 295 | try: 296 | weight = torch.from_numpy(params['%s/weight' % tf_layername]) 297 | except KeyError: 298 | break 299 | # Transpose convolution weights into pytorch format. 300 | if i == 0: 301 | # Convert dense layer to 4x4 convolution 302 | weight = weight.view(weight.shape[0], weight.shape[1] // 16, 303 | 4, 4).permute(1, 0, 2, 3).flip(2, 3) 304 | sizes.append(weight.shape[0]) 305 | elif i % 2 == 0: 306 | # Convert inverse convolution to convolution 307 | weight = weight.permute(2, 3, 0, 1).flip(2, 3) 308 | else: 309 | # Ordinary Conv2d conversion. 310 | weight = weight.permute(3, 2, 0, 1) 311 | sizes.append(weight.shape[1]) 312 | result['%s.conv.weight' % (pt_layername)] = weight 313 | # Copy bias vector. 314 | bias = torch.from_numpy(params['%s/bias' % tf_layername]) 315 | result['%s.wscale.b' % (pt_layername)] = bias 316 | # Copy just finest-grained ToRGB output layers. For example: 317 | # ToRGB_lod0/weight -> output.conv.weight 318 | i -= 1 319 | resolution = 4 * (2 ** (i // 2)) 320 | tf_layername = 'ToRGB_lod0' 321 | pt_layername = 'output_%dx%d' % (resolution, resolution) 322 | result['%s.conv.weight' % pt_layername] = torch.from_numpy( 323 | params['%s/weight' % tf_layername]).permute(3, 2, 0, 1) 324 | result['%s.wscale.b' % pt_layername] = torch.from_numpy( 325 | params['%s/bias' % tf_layername]) 326 | # Return parameters 327 | return result 328 | 329 | 330 | def state_dict_from_old_pt_dict(params): 331 | ''' 332 | Conversion from the old pytorch model layer names. 333 | ''' 334 | result = {} 335 | sizes = [] 336 | for i in itertools.count(): 337 | old_layername = 'features.%d' % i 338 | pt_layername = 'layer%d' % (i + 1) 339 | try: 340 | weight = params['%s.conv.weight' % (old_layername)] 341 | except KeyError: 342 | break 343 | if i == 0: 344 | sizes.append(weight.shape[0]) 345 | if i % 2 == 0: 346 | sizes.append(weight.shape[1]) 347 | result['%s.conv.weight' % (pt_layername)] = weight 348 | result['%s.wscale.b' % (pt_layername)] = params[ 349 | '%s.wscale.b' % (old_layername)] 350 | # Copy the output layers. 351 | i -= 1 352 | resolution = 4 * (2 ** (i // 2)) 353 | pt_layername = 'output_%dx%d' % (resolution, resolution) 354 | result['%s.conv.weight' % pt_layername] = params['output.conv.weight'] 355 | result['%s.wscale.b' % pt_layername] = params['output.wscale.b'] 356 | # Return parameters and also network architecture sizes. 
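    # Note: `sizes` is collected above but only the state dict is returned;
    # callers such as from_old_pt_dict() recover the layer widths again with
    # sizes_from_state_dict().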
357 | return result -------------------------------------------------------------------------------- /TorchNet/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JasonGUTU/TorchTools/b7025ab8469eec87830b109a344a7af551a5d464/TorchNet/README.md -------------------------------------------------------------------------------- /TorchNet/SRMini.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | try: 6 | from math import log2 7 | except: 8 | from math import log 9 | def log2(x): 10 | return log(x) / log(2) 11 | 12 | import math 13 | 14 | from .modules import residualBlock, upsampleBlock 15 | 16 | 17 | class SRSubPixOneLayer(nn.Module): 18 | def __init__(self): 19 | super(SRSubPixOneLayer, self).__init__() 20 | self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1) 21 | self.relu = nn.LeakyReLU(0.2) 22 | self.conv2 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1) 23 | self.pix = nn.PixelShuffle(4) 24 | 25 | def forward(self, input): 26 | return F.tanh(self.pix(self.conv2(self.relu(self.conv1(input))))) 27 | 28 | 29 | -------------------------------------------------------------------------------- /TorchNet/VGG.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.utils.model_zoo as model_zoo 3 | import math 4 | 5 | 6 | __all__ = [ 7 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 8 | 'vgg19_bn', 'vgg19', 9 | ] 10 | 11 | 12 | model_urls = { 13 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', 14 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', 15 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 16 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 17 | 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth', 18 | 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth', 19 | 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth', 20 | 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth', 21 | } 22 | 23 | 24 | class VGG(nn.Module): 25 | 26 | def __init__(self, features, num_classes=1000): 27 | super(VGG, self).__init__() 28 | self.features = features 29 | self.classifier = nn.Sequential( 30 | nn.Linear(512 * 7 * 7, 4096), 31 | nn.ReLU(True), 32 | nn.Dropout(), 33 | nn.Linear(4096, 4096), 34 | nn.ReLU(True), 35 | nn.Dropout(), 36 | nn.Linear(4096, num_classes), 37 | ) 38 | self._initialize_weights() 39 | 40 | def forward(self, x): 41 | x = self.features(x) 42 | x = x.view(x.size(0), -1) 43 | x = self.classifier(x) 44 | return x 45 | 46 | def _initialize_weights(self): 47 | for m in self.modules(): 48 | if isinstance(m, nn.Conv2d): 49 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 50 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 51 | if m.bias is not None: 52 | m.bias.data.zero_() 53 | elif isinstance(m, nn.BatchNorm2d): 54 | m.weight.data.fill_(1) 55 | m.bias.data.zero_() 56 | elif isinstance(m, nn.Linear): 57 | m.weight.data.normal_(0, 0.01) 58 | m.bias.data.zero_() 59 | 60 | 61 | def make_layers(cfg, batch_norm=False): 62 | layers = [] 63 | in_channels = 3 64 | for v in cfg: 65 | if v == 'M': 66 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 67 | else: 68 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 69 | if batch_norm: 70 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 71 | else: 72 | layers += [conv2d, nn.ReLU(inplace=True)] 73 | in_channels = v 74 | return nn.Sequential(*layers) 75 | 76 | 77 | cfg = { 78 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 79 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 80 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 81 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 82 | } 83 | 84 | 85 | def vgg11(pretrained=False, **kwargs): 86 | """VGG 11-layer model (configuration "A") 87 | 88 | Args: 89 | pretrained (bool): If True, returns a model pre-trained on ImageNet 90 | """ 91 | model = VGG(make_layers(cfg['A']), **kwargs) 92 | if pretrained: 93 | model.load_state_dict(model_zoo.load_url(model_urls['vgg11'])) 94 | return model 95 | 96 | 97 | def vgg11_bn(pretrained=False, **kwargs): 98 | """VGG 11-layer model (configuration "A") with batch normalization 99 | 100 | Args: 101 | pretrained (bool): If True, returns a model pre-trained on ImageNet 102 | """ 103 | model = VGG(make_layers(cfg['A'], batch_norm=True), **kwargs) 104 | if pretrained: 105 | model.load_state_dict(model_zoo.load_url(model_urls['vgg11_bn'])) 106 | return model 107 | 108 | 109 | def vgg13(pretrained=False, **kwargs): 110 | """VGG 13-layer model (configuration "B") 111 | 112 | Args: 113 | pretrained (bool): If True, returns a model pre-trained on ImageNet 114 | """ 115 | model = VGG(make_layers(cfg['B']), **kwargs) 116 | if pretrained: 117 | model.load_state_dict(model_zoo.load_url(model_urls['vgg13'])) 118 | return model 119 | 120 | 121 | def vgg13_bn(pretrained=False, **kwargs): 122 | """VGG 13-layer model (configuration "B") with batch normalization 123 | 124 | Args: 125 | pretrained (bool): If True, returns a model pre-trained on ImageNet 126 | """ 127 | model = VGG(make_layers(cfg['B'], batch_norm=True), **kwargs) 128 | if pretrained: 129 | model.load_state_dict(model_zoo.load_url(model_urls['vgg13_bn'])) 130 | return model 131 | 132 | 133 | def vgg16(pretrained=False, **kwargs): 134 | """VGG 16-layer model (configuration "D") 135 | 136 | Args: 137 | pretrained (bool): If True, returns a model pre-trained on ImageNet 138 | """ 139 | model = VGG(make_layers(cfg['D']), **kwargs) 140 | if pretrained: 141 | model.load_state_dict(model_zoo.load_url(model_urls['vgg16'])) 142 | return model 143 | 144 | 145 | def vgg16_bn(pretrained=False, **kwargs): 146 | """VGG 16-layer model (configuration "D") with batch normalization 147 | 148 | Args: 149 | pretrained (bool): If True, returns a model pre-trained on ImageNet 150 | """ 151 | model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs) 152 | if pretrained: 153 | model.load_state_dict(model_zoo.load_url(model_urls['vgg16_bn'])) 154 | return model 155 | 156 | 157 | def vgg19(pretrained=False, **kwargs): 158 | """VGG 19-layer model (configuration "E") 
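    Used by VGGLoss and VGGContextualLoss above, which wrap slices of
    vgg19().features and load weights from a local state dict (vgg_path)
    rather than from the URL in model_urls.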
159 | 160 | Args: 161 | pretrained (bool): If True, returns a model pre-trained on ImageNet 162 | """ 163 | model = VGG(make_layers(cfg['E']), **kwargs) 164 | if pretrained: 165 | model.load_state_dict(model_zoo.load_url(model_urls['vgg19'])) 166 | return model 167 | 168 | 169 | def vgg19_bn(pretrained=False, **kwargs): 170 | """VGG 19-layer model (configuration 'E') with batch normalization 171 | 172 | Args: 173 | pretrained (bool): If True, returns a model pre-trained on ImageNet 174 | """ 175 | model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs) 176 | if pretrained: 177 | model.load_state_dict(model_zoo.load_url(model_urls['vgg19_bn'])) 178 | return model 179 | -------------------------------------------------------------------------------- /TorchNet/Visualizing.py: -------------------------------------------------------------------------------- 1 | from mpl_toolkits.mplot3d import Axes3D 2 | from matplotlib import pyplot as plt 3 | from matplotlib import cm 4 | import h5py 5 | import argparse 6 | import numpy as np 7 | from os.path import exists 8 | import seaborn as sns 9 | 10 | 11 | def plot_2d_contour(surf_file, surf_name='train_loss', vmin=0.1, vmax=10, vlevel=0.5, show=False): 12 | """Plot 2D contour map and 3D surface.""" 13 | 14 | f = h5py.File(surf_file, 'r') 15 | x = np.array(f['xcoordinates'][:]) 16 | y = np.array(f['ycoordinates'][:]) 17 | X, Y = np.meshgrid(x, y) 18 | 19 | if surf_name in f.keys(): 20 | Z = np.array(f[surf_name][:]) 21 | elif surf_name == 'train_err' or surf_name == 'test_err' : 22 | Z = 100 - np.array(f[surf_name][:]) 23 | else: 24 | print ('%s is not found in %s' % (surf_name, surf_file)) 25 | 26 | print('------------------------------------------------------------------') 27 | print('plot_2d_contour') 28 | print('------------------------------------------------------------------') 29 | print("loading surface file: " + surf_file) 30 | print('len(xcoordinates): %d len(ycoordinates): %d' % (len(x), len(y))) 31 | print('max(%s) = %f \t min(%s) = %f' % (surf_name, np.max(Z), surf_name, np.min(Z))) 32 | print(Z) 33 | 34 | if (len(x) <= 1 or len(y) <= 1): 35 | print('The length of coordinates is not enough for plotting contours') 36 | return 37 | 38 | # -------------------------------------------------------------------- 39 | # Plot 2D contours 40 | # -------------------------------------------------------------------- 41 | fig = plt.figure() 42 | CS = plt.contour(X, Y, Z, cmap='summer', levels=np.arange(vmin, vmax, vlevel)) 43 | plt.clabel(CS, inline=1, fontsize=8) 44 | fig.savefig(surf_file + '_' + surf_name + '_2dcontour' + '.pdf', dpi=300, 45 | bbox_inches='tight', format='pdf') 46 | 47 | fig = plt.figure() 48 | print(surf_file + '_' + surf_name + '_2dcontourf' + '.pdf') 49 | CS = plt.contourf(X, Y, Z, cmap='summer', levels=np.arange(vmin, vmax, vlevel)) 50 | fig.savefig(surf_file + '_' + surf_name + '_2dcontourf' + '.pdf', dpi=300, 51 | bbox_inches='tight', format='pdf') 52 | 53 | # -------------------------------------------------------------------- 54 | # Plot 2D heatmaps 55 | # -------------------------------------------------------------------- 56 | fig = plt.figure() 57 | sns_plot = sns.heatmap(Z, cmap='viridis', cbar=True, vmin=vmin, vmax=vmax, 58 | xticklabels=False, yticklabels=False) 59 | sns_plot.invert_yaxis() 60 | sns_plot.get_figure().savefig(surf_file + '_' + surf_name + '_2dheat.pdf', 61 | dpi=300, bbox_inches='tight', format='pdf') 62 | 63 | # -------------------------------------------------------------------- 64 | # Plot 3D 
surface 65 | # -------------------------------------------------------------------- 66 | fig = plt.figure() 67 | ax = Axes3D(fig) 68 | surf = ax.plot_surface(X, Y, Z, cmap=cm.coolwarm, linewidth=0, antialiased=False) 69 | fig.colorbar(surf, shrink=0.5, aspect=5) 70 | fig.savefig(surf_file + '_' + surf_name + '_3dsurface.pdf', dpi=300, 71 | bbox_inches='tight', format='pdf') 72 | 73 | f.close() 74 | if show: plt.show() -------------------------------------------------------------------------------- /TorchNet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JasonGUTU/TorchTools/b7025ab8469eec87830b109a344a7af551a5d464/TorchNet/__init__.py -------------------------------------------------------------------------------- /TorchNet/activation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | def swish(x): 6 | return x * F.sigmoid(x) -------------------------------------------------------------------------------- /TorchNet/modules.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import PIL.Image as Image 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.nn.functional import relu 9 | from torch.autograd import Variable 10 | 11 | from .activation import swish 12 | from ..Functions import functional as Func 13 | 14 | def GaussianNoising(tensor, sigma, mean=0.0, noise_size=None, min=0.0, max=1.0): 15 | if noise_size is None: 16 | size = tensor.size() 17 | else: 18 | size = noise_size 19 | noise = torch.FloatTensor(np.random.normal(loc=mean, scale=sigma, size=size)) 20 | return torch.clamp(noise + tensor, min=min, max=max) 21 | 22 | 23 | def b_GaussianNoising(tensor, sigma, mean=0.0, noise_size=None, min=0.0, max=1.0): 24 | if noise_size is None: 25 | size = tensor.size() 26 | else: 27 | size = noise_size 28 | noise = torch.mul(torch.FloatTensor(np.random.normal(loc=mean, scale=1.0, size=size)), sigma.view(sigma.size() + (1, 1))) 29 | return torch.clamp(noise + tensor, min=min, max=max) 30 | 31 | 32 | def PoissonNoising(tensor, lamb, noise_size=None, min=0.0, max=1.0): 33 | if noise_size is None: 34 | size = tensor.size() 35 | else: 36 | size = noise_size 37 | noise = torch.FloatTensor(np.random.poisson(lam=lamb, size=size)) 38 | return torch.clamp(noise + tensor, min=min, max=max) 39 | 40 | 41 | class Blur(nn.Module): 42 | def __init__(self, l=15, kernel=None): 43 | super(Blur, self).__init__() 44 | self.l = l 45 | self.pad = nn.ReflectionPad2d(l // 2) 46 | self.kernel = Variable(torch.FloatTensor(kernel).view((1, 1, self.l, self.l))) 47 | 48 | def cuda(self, device=None): 49 | self.kernel = self.kernel.cuda() 50 | 51 | def forward(self, input): 52 | B, C, H, W = input.size() 53 | pad = self.pad(input) 54 | H_p, W_p = pad.size()[-2:] 55 | input_CBHW = pad.view((C * B, 1, H_p, W_p)) 56 | 57 | return F.conv2d(input_CBHW, self.kernel).view(B, C, H, W) 58 | 59 | 60 | class BatchBlur(nn.Module): 61 | def __init__(self, l=15): 62 | super(BatchBlur, self).__init__() 63 | self.l = l 64 | if l % 2 == 1: 65 | self.pad = nn.ReflectionPad2d(l // 2) 66 | else: 67 | self.pad = nn.ReflectionPad2d((l // 2, l // 2 - 1, l // 2, l // 2 - 1)) 68 | # self.pad = nn.ZeroPad2d(l // 2) 69 | 70 | def forward(self, input, kernel): 71 | B, C, H, W = input.size() 72 | pad = self.pad(input) 73 | H_p, W_p = 
pad.size()[-2:] 74 | 75 | if len(kernel.size()) == 2: 76 | input_CBHW = pad.view((C * B, 1, H_p, W_p)) 77 | kernel_var = kernel.contiguous().view((1, 1, self.l, self.l)) 78 | 79 | return F.conv2d(input_CBHW, kernel_var, padding=0).view((B, C, H, W)) 80 | else: 81 | input_CBHW = pad.view((1, C * B, H_p, W_p)) 82 | kernel_var = kernel.contiguous().view((B, 1, self.l, self.l)).repeat(1, C, 1, 1).view((B * C, 1, self.l, self.l)) 83 | return F.conv2d(input_CBHW, kernel_var, groups=B*C).view((B, C, H, W)) 84 | 85 | 86 | def b_GPUVar_Bicubic(Var, scale): 87 | tensor = Var.cpu().data 88 | B, C, H, W = tensor.size() 89 | H_new = int(H / scale) 90 | W_new = int(W / scale) 91 | tensor_v = tensor.view((B*C, 1, H, W)) 92 | re_tensor = torch.zeros((B*C, 1, H_new, W_new)) 93 | for i in range(B*C): 94 | img = Func.to_pil_image(tensor_v[i]) 95 | re_tensor[i] = Func.to_tensor(Func.resize(img, (H_new, W_new), interpolation=Image.BICUBIC)) 96 | re_tensor_v = re_tensor.view((B, C, H_new, W_new)) 97 | return re_tensor_v 98 | 99 | 100 | def b_CPUVar_Bicubic(Var, scale): 101 | tensor = Var.data 102 | B, C, H, W = tensor.size() 103 | H_new = int(H / scale) 104 | W_new = int(W / scale) 105 | tensor_v = tensor.view((B*C, 1, H, W)) 106 | re_tensor = torch.zeros((B*C, 1, H_new, W_new)) 107 | for i in range(B*C): 108 | img = Func.to_pil_image(tensor_v[i]) 109 | re_tensor[i] = Func.to_tensor(Func.resize(img, (H_new, W_new), interpolation=Image.BICUBIC)) 110 | re_tensor_v = re_tensor.view((B, C, H_new, W_new)) 111 | return re_tensor_v 112 | 113 | 114 | class GaussianBlur(nn.Module): 115 | def __init__(self, input_channel, l=3, sigma=0.6): 116 | super(GaussianBlur, self).__init__() 117 | self.l = l 118 | self.sig = sigma 119 | self.kernel = Variable(self._g_kernel().view((1, input_channel, l, l))) 120 | self.pad = nn.ReflectionPad2d(l // 2) 121 | 122 | def _g_kernel(self): 123 | ax = np.arange(-self.l // 2 + 1., self.l // 2 + 1.) 124 | xx, yy = np.meshgrid(ax, ax) 125 | kernel = np.exp(-(xx ** 2 + yy ** 2) / (2. * self.sig ** 2)) 126 | return torch.FloatTensor(kernel / np.sum(kernel)) 127 | 128 | def forward(self, input): 129 | return F.conv2d(self.pad(input), self.kernel) 130 | 131 | 132 | class RandomBlur(nn.Module): 133 | def __init__(self, input_channel=1, kernel_size=15, sigma_min=0.0, sigma_max=1.4): 134 | super(RandomBlur, self).__init__() 135 | self.min = sigma_min 136 | self.max = sigma_max 137 | self.l = kernel_size 138 | self.input_channel = input_channel 139 | self.pad = nn.ReflectionPad2d(kernel_size // 2) 140 | 141 | def _g_kernel(self, sigma): 142 | ax = np.arange(-self.l // 2 + 1., self.l // 2 + 1.) 143 | xx, yy = np.meshgrid(ax, ax) 144 | kernel = np.exp(-(xx ** 2 + yy ** 2) / (2. * sigma ** 2)) 145 | return torch.FloatTensor(kernel / np.sum(kernel)) 146 | 147 | def forward(self, input): 148 | sigma_random = random.uniform(self.min, self.max) 149 | kernel = Variable(self._g_kernel(sigma_random).view((1, self.input_channel, self.l, self.l))) 150 | return F.conv2d(self.pad(input), kernel) 151 | 152 | 153 | class RandomNoisedBlur(nn.Module): 154 | def __init__(self, input_channel=1, kernel_size=15, sigma_min=0.0, sigma_max=1.4, noise_sigma=0.02): 155 | super(RandomNoisedBlur, self).__init__() 156 | self.min = sigma_min 157 | self.max = sigma_max 158 | self.l = kernel_size 159 | self.noise = noise_sigma 160 | self.input_channel = input_channel 161 | self.pad = nn.ReflectionPad2d(kernel_size // 2) 162 | 163 | def _g_kernel(self, sigma): 164 | ax = np.arange(-self.l // 2 + 1., self.l // 2 + 1.) 
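        # Offsets of the l x l kernel taps around the centre: the isotropic
        # Gaussian evaluated on this grid is normalized to sum to 1 and then
        # perturbed by multiplicative per-tap noise drawn from N(1, noise_sigma).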
165 | xx, yy = np.meshgrid(ax, ax) 166 | kernel = np.exp(-(xx ** 2 + yy ** 2) / (2. * sigma ** 2)) 167 | noise = np.random.normal(1.0, self.noise, kernel.shape) 168 | gaussian = kernel / np.sum(kernel) 169 | return torch.FloatTensor(gaussian * noise) 170 | 171 | def forward(self, input): 172 | sigma_random = random.uniform(self.min, self.max) 173 | kernel = Variable(self._g_kernel(sigma_random).view((1, self.input_channel, self.l, self.l))) 174 | return F.conv2d(self.pad(input), kernel) 175 | 176 | 177 | class Flatten(nn.Module): 178 | def forward(self, x): 179 | x = x.view(x.size()[0], -1) 180 | return x 181 | 182 | 183 | class FeatureExtractor(nn.Module): 184 | 185 | def __init__(self, cnn, feature_layer=11): 186 | super(FeatureExtractor, self).__init__() 187 | self.features = nn.Sequential(*list(cnn.features.children())[:(feature_layer+1)]) 188 | 189 | def forward(self, x): 190 | # TODO convert x: RGB to BGR 191 | return self.features(x) 192 | 193 | 194 | class residualBlock(nn.Module): 195 | 196 | def __init__(self, in_channels=64, kernel=3, mid_channels=64, out_channels=64, stride=1, activation=relu): 197 | super(residualBlock, self).__init__() 198 | self.act = activation 199 | self.pad1 = nn.ReflectionPad2d((kernel // 2)) 200 | self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=mid_channels, kernel_size=kernel, stride=stride, padding=0) 201 | self.bn1 = nn.BatchNorm2d(mid_channels) 202 | self.pad2 = nn.ReflectionPad2d((kernel // 2)) 203 | self.conv2 = nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=kernel, stride=stride, padding=0) 204 | self.bn2 = nn.BatchNorm2d(out_channels) 205 | 206 | def forward(self, x): 207 | y = self.act(self.bn1(self.conv1(self.pad1(x)))) 208 | return self.bn2(self.conv2(self.pad2(y))) + x 209 | 210 | 211 | class residualBlockNoBN(nn.Module): 212 | 213 | def __init__(self, in_channels=64, kernel=3, mid_channels=64, out_channels=64, stride=1, activation=relu): 214 | super(residualBlockNoBN, self).__init__() 215 | self.act = activation 216 | self.pad1 = nn.ReflectionPad2d((kernel // 2)) 217 | self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=mid_channels, kernel_size=kernel, stride=stride, padding=0) 218 | self.pad2 = nn.ReflectionPad2d((kernel // 2)) 219 | self.conv2 = nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=kernel, stride=stride, padding=0) 220 | 221 | def forward(self, x): 222 | y = self.act(self.conv1(self.pad1(x))) 223 | return self.conv2(self.pad2(y)) + x 224 | 225 | 226 | class residualBlockIN(nn.Module): 227 | def __init__(self, in_channels=64, kernel=3, mid_channels=64, out_channels=64, stride=1, activation=relu): 228 | super(residualBlockIN, self).__init__() 229 | 230 | self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=mid_channels, kernel_size=kernel, stride=stride, padding=kernel // 2, bias=False) 231 | self.in1 = nn.InstanceNorm2d(mid_channels, affine=True) 232 | self.act = activation 233 | self.conv2 = nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=kernel, stride=stride, padding=kernel // 2, bias=False) 234 | self.in2 = nn.InstanceNorm2d(64, affine=True) 235 | 236 | def forward(self, x): 237 | identity_data = x 238 | output = self.act(self.in1(self.conv1(x))) 239 | output = self.in2(self.conv2(output)) 240 | output = torch.add(output, identity_data) 241 | return output 242 | 243 | 244 | class upsampleBlock(nn.Module): 245 | 246 | def __init__(self, in_channels, out_channels, activation=relu): 247 | super(upsampleBlock, 
self).__init__() 248 | self.act = activation 249 | self.pad = nn.ReflectionPad2d(1) 250 | self.conv = nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=0) 251 | self.shuffler = nn.PixelShuffle(2) 252 | 253 | def forward(self, x): 254 | return self.act(self.shuffler(self.conv(self.pad(x)))) 255 | 256 | 257 | class deconvUpsampleBlock(nn.Module): 258 | 259 | def __init__(self, in_channels, mid_channels, out_channels, kernel_1=5, kernel_2=3, activation=relu): 260 | self.act = activation 261 | super(deconvUpsampleBlock, self).__init__() 262 | self.deconv_1 = nn.ConvTranspose2d(in_channels=in_channels, out_channels=mid_channels, kernel_size=kernel_1, stride=2, padding=kernel_1 // 2) 263 | # self.deconv_2 = nn.ConvTranspose2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=3, stride=kernel_2, padding=kernel_2 // 2) 264 | 265 | def forward(self, x): 266 | return self.act(self.deconv_1(x)) 267 | 268 | 269 | class Features4Layer(nn.Module): 270 | """ 271 | Basic feature extractor, 4 layer version 272 | """ 273 | def __init__(self, features=64, activation=relu): 274 | """ 275 | :param frame: The input frame image 276 | :param features: feature maps per layer 277 | """ 278 | super(Features4Layer, self).__init__() 279 | self.act = activation 280 | 281 | self.pad1 = nn.ReflectionPad2d(2) 282 | self.conv1 = nn.Conv2d(1, features, 5, stride=1, padding=0) 283 | 284 | self.pad2 = nn.ReflectionPad2d(1) 285 | self.conv2 = nn.Conv2d(features, features, 3, stride=1, padding=0) 286 | self.bn2 = nn.BatchNorm2d(features) 287 | 288 | self.pad3 = nn.ReflectionPad2d(1) 289 | self.conv3 = nn.Conv2d(features, features, 3, stride=1, padding=0) 290 | self.bn3 = nn.BatchNorm2d(features) 291 | 292 | self.pad4 = nn.ReflectionPad2d(1) 293 | self.conv4 = nn.Conv2d(features, features, 3, stride=1, padding=0) 294 | 295 | def forward(self, frame): 296 | return self.act(self.conv4(self.pad4( 297 | self.act(self.bn3(self.conv3(self.pad3( 298 | self.act(self.bn2(self.conv2(self.pad2( 299 | self.act(self.conv1(self.pad1(frame))) 300 | )))) 301 | )))) 302 | ))) 303 | 304 | 305 | class Features3Layer(nn.Module): 306 | """ 307 | Basic feature extractor, 4 layer version 308 | """ 309 | def __init__(self, features=64, activation=relu): 310 | """ 311 | :param frame: The input frame image 312 | :param features: feature maps per layer 313 | """ 314 | super(Features3Layer, self).__init__() 315 | self.act = activation 316 | 317 | self.pad1 = nn.ReflectionPad2d(2) 318 | self.conv1 = nn.Conv2d(1, features, 5, stride=1, padding=0) 319 | 320 | self.pad2 = nn.ReflectionPad2d(1) 321 | self.conv2 = nn.Conv2d(features, features, 3, stride=1, padding=0) 322 | self.bn2 = nn.BatchNorm2d(features) 323 | 324 | self.pad3 = nn.ReflectionPad2d(1) 325 | self.conv3 = nn.Conv2d(features, features, 3, stride=1, padding=0) 326 | 327 | def forward(self, frame): 328 | return self.act(self.conv3(self.pad3( 329 | self.act(self.bn2(self.conv2(self.pad2( 330 | self.act(self.conv1(self.pad1(frame))) 331 | )))) 332 | ))) 333 | 334 | 335 | class LateUpsamplingBlock(nn.Module): 336 | """ 337 | this is another up-sample block for step upsample 338 | |------------------------------| 339 | | features | 340 | |------------------------------| 341 | | n | residual blocks | 342 | |------------------------------| 343 | | Pixel shuffle up-sampling x2 | 344 | |------------------------------| 345 | """ 346 | def __init__(self, features=64, n_res_block=3): 347 | """ 348 | :param features: number of feature maps input 349 | :param n_res_block: number of 
residual blocks 350 | """ 351 | super(LateUpsamplingBlock, self).__init__() 352 | self.n_residual_blocks = n_res_block 353 | 354 | for i in range(self.n_residual_blocks): 355 | self.add_module('residual_block' + str(i + 1), residualBlock(features)) 356 | 357 | self.upsample = upsampleBlock(features, features * 4) 358 | 359 | def forward(self, features): 360 | for i in range(self.n_residual_blocks): 361 | features = self.__getattr__('residual_block' + str(i + 1))(features) 362 | return self.upsample(features) 363 | 364 | 365 | class LateUpsamplingBlockNoBN(nn.Module): 366 | """ 367 | this is another up-sample block for step upsample 368 | |------------------------------| 369 | | features | 370 | |------------------------------| 371 | | n | residual blocks | 372 | |------------------------------| 373 | | Pixel shuffle up-sampling x2 | 374 | |------------------------------| 375 | """ 376 | def __init__(self, features=64, n_res_block=3): 377 | """ 378 | :param features: number of feature maps input 379 | :param n_res_block: number of residual blocks 380 | """ 381 | super(LateUpsamplingBlockNoBN, self).__init__() 382 | self.n_residual_blocks = n_res_block 383 | 384 | for i in range(self.n_residual_blocks): 385 | self.add_module('residual_block' + str(i + 1), residualBlockNoBN(features)) 386 | 387 | self.upsample = upsampleBlock(features, features * 4) 388 | 389 | def forward(self, features): 390 | for i in range(self.n_residual_blocks): 391 | features = self.__getattr__('residual_block' + str(i + 1))(features) 392 | return self.upsample(features) 393 | 394 | 395 | class DownsamplingShuffle(nn.Module): 396 | 397 | def __init__(self, scala): 398 | super(DownsamplingShuffle, self).__init__() 399 | self.scala = scala 400 | 401 | def forward(self, input): 402 | """ 403 | input should be 4D tensor N, C, H, W 404 | :param input: 405 | :return: 406 | """ 407 | N, C, H, W = input.size() 408 | assert H % self.scala == 0, 'Plz Check input and scala' 409 | assert W % self.scala == 0, 'Plz Check input and scala' 410 | map_channels = self.scala ** 2 411 | channels = C * map_channels 412 | out_height = H // self.scala 413 | out_width = W // self.scala 414 | 415 | input_view = input.contiguous().view( 416 | N, C, out_height, self.scala, out_width, self.scala) 417 | 418 | shuffle_out = input_view.permute(0, 1, 3, 5, 2, 4).contiguous() 419 | 420 | return shuffle_out.view(N, channels, out_height, out_width) 421 | 422 | 423 | class _AttentionDownConv(nn.Module): 424 | def __init__(self, features=16): 425 | super(_AttentionDownConv, self).__init__() 426 | self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1) 427 | self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1) 428 | self.downsample = nn.Conv2d(features, features, kernel_size=3, stride=2, padding=1) 429 | 430 | def forward(self, input): 431 | return F.relu(self.downsample( 432 | F.relu(self.conv2( 433 | F.relu(self.conv1(input)) 434 | )) 435 | )) 436 | 437 | 438 | class _AttentionUpConv(nn.Module): 439 | def __init__(self, features=16): 440 | super(_AttentionUpConv, self).__init__() 441 | self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1) 442 | self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1) 443 | self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1) 444 | self.upsample = nn.UpsamplingBilinear2d(scale_factor=2) 445 | 446 | def forward(self, input): 447 | return F.relu(self.upsample( 448 | F.relu(self.conv2( 449 | 
F.relu(self.conv1(input)) 450 | )) 451 | )) 452 | 453 | 454 | class Attention(nn.Module): 455 | """ 456 | Attention Module, output with sigmoid 457 | """ 458 | def __init__(self, input_channel=1, feature_channels=16, down_samples=2): 459 | super(Attention, self).__init__() 460 | self.input = input_channel 461 | self.ngf = feature_channels 462 | self.down = down_samples 463 | self.down_square = 2 ** down_samples 464 | 465 | self.input_conv = nn.Conv2d(input_channel, feature_channels, kernel_size=5, stride=1, padding=2) 466 | 467 | self.final_conv = nn.Conv2d(feature_channels, 1, kernel_size=5, stride=1, padding=2) 468 | 469 | for i in range(down_samples): 470 | self.add_module('down_sample_' + str(i + 1), _AttentionDownConv(features=feature_channels)) 471 | 472 | for i in range(down_samples): 473 | self.add_module('up_sample_' + str(i + 1), _AttentionUpConv(features=feature_channels)) 474 | 475 | def forward(self, input): 476 | B, C, H, W = input.size() 477 | pad_H = self.down_square - (H % self.down_square) if H % self.down_square != 0 else 0 478 | pad_W = self.down_square - (W % self.down_square) if W % self.down_square != 0 else 0 479 | 480 | input_pad = F.pad(input, (0, pad_W, 0, pad_H), 'reflect') 481 | 482 | 483 | output = F.relu(self.input_conv(input_pad)) 484 | 485 | for i in range(self.down): 486 | output = self.__getattr__('down_sample_' + str(i + 1))(output) 487 | 488 | for i in range(self.down): 489 | output = self.__getattr__('up_sample_' + str(i + 1))(output) 490 | 491 | output = self.final_conv(output) 492 | output_pad = F.pad(output, (0, -pad_W, 0, -pad_H)) 493 | 494 | return F.sigmoid(output_pad) 495 | 496 | -------------------------------------------------------------------------------- /TorchNet/tSNE.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JasonGUTU/TorchTools/b7025ab8469eec87830b109a344a7af551a5d464/TorchNet/tSNE.py -------------------------------------------------------------------------------- /TorchNet/tools.py: -------------------------------------------------------------------------------- 1 | from math import sqrt, ceil 2 | import random 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | 9 | import PIL 10 | from PIL import Image 11 | 12 | from ..Functions.functional import to_pil_image 13 | from collections import OrderedDict 14 | import numpy as np 15 | 16 | class ImagePool(): 17 | def __init__(self, pool_size): 18 | self.pool_size = pool_size 19 | if self.pool_size > 0: 20 | self.num_imgs = 0 21 | self.images = [] 22 | 23 | def query(self, images): 24 | if self.pool_size == 0: 25 | return Variable(images) 26 | return_images = [] 27 | for image in images: 28 | image = torch.unsqueeze(image, 0) 29 | if self.num_imgs < self.pool_size: 30 | self.num_imgs = self.num_imgs + 1 31 | self.images.append(image) 32 | return_images.append(image) 33 | else: 34 | p = random.uniform(0, 1) 35 | if p > 0.5: 36 | random_id = random.randint(0, self.pool_size - 1) 37 | tmp = self.images[random_id].clone() 38 | self.images[random_id] = image 39 | return_images.append(tmp) 40 | else: 41 | return_images.append(image) 42 | return_images = Variable(torch.cat(return_images, 0)) 43 | return return_images 44 | 45 | 46 | def calculate_parameters(model): 47 | parameters = 0 48 | for weight in model.parameters(): 49 | p = 1 50 | for dim in weight.size(): 51 | p *= dim 52 | parameters += p 53 | return parameters 54 | 55 | 56 | def print_network(net): 57 | 
num_params = 0 58 | for param in net.parameters(): 59 | num_params += param.numel() 60 | print(net) 61 | print('Total number of parameters: %d' % num_params) 62 | 63 | 64 | def FeatureMapsVisualization(feature_maps, instan_norm=False): 65 | """ 66 | visualize feature maps 67 | :param feature_maps: must be 4D tensor with B equals to 1 or 3D tensor N * H * W 68 | :return: PIL.Image of feature maps 69 | """ 70 | if len(feature_maps.size()) == 4: 71 | feature_maps = feature_maps.view(feature_maps.size()[1:]) 72 | if not instan_norm: 73 | feature_maps = (feature_maps - feature_maps.min()) / (feature_maps.max() - feature_maps.min()) 74 | maps_number = feature_maps.size()[0] 75 | feature_H = feature_maps.size()[1] 76 | feature_W = feature_maps.size()[2] 77 | W_n = ceil(sqrt(maps_number)) 78 | H_n = ceil(maps_number / W_n) 79 | map_W = W_n * feature_W 80 | map_H = H_n * feature_H 81 | MAP = Image.new('L', (map_W, map_H)) 82 | for i in range(maps_number): 83 | map_t = feature_maps[i] 84 | if instan_norm: 85 | map_t = (map_t - map_t.min()) / (map_t.max() - map_t.min()) 86 | map_t = map_t.view((1, ) + map_t.size()) 87 | map_pil = to_pil_image(map_t) 88 | n_row = i % W_n 89 | n_col = i // W_n 90 | MAP.paste(map_pil, (n_row * feature_W, n_col * feature_H)) 91 | return MAP 92 | 93 | 94 | def ModelToSequential(model, seq_output=True): 95 | Sequential_list = list() 96 | for sub in model.children(): 97 | if isinstance(sub, torch.nn.modules.container.Sequential): 98 | Sequential_list.extend(ModelToSequential(sub, seq_output=False)) 99 | else: 100 | Sequential_list.append(sub) 101 | if seq_output: 102 | return nn.Sequential(*Sequential_list) 103 | else: 104 | return Sequential_list 105 | 106 | 107 | # def KernelsVisualization(kernels, instan_norm=False): 108 | # """ 109 | # visualize feature maps 110 | # :param feature_maps: must be 4D tensor 111 | # :return: PIL.Image of feature maps 112 | # """ 113 | # if not instan_norm: 114 | # feature_maps = (kernels - kernels.min()) / (kernels.max() - kernels.min()) 115 | # kernels_out = kernels.size()[0] 116 | # kernels_in = kernels.size()[1] 117 | # feature_H = kernels.size()[2] 118 | # feature_W = kernels.size()[3] 119 | # W_n = ceil(sqrt(kernels_in)) 120 | # H_n = ceil(kernels_in / W_n) 121 | # big_W_n = ceil(sqrt(kernels_out)) 122 | # big_H_n = ceil(kernels_out / W_n) 123 | # map_W = W_n * feature_W 124 | # map_H = H_n * feature_H 125 | # MAP = Image.new('L', (map_W, map_H)) 126 | # for i in range(maps_number): 127 | # map_t = feature_maps[i] 128 | # if instan_norm: 129 | # map_t = (map_t - map_t.min()) / (map_t.max() - map_t.min()) 130 | # map_t = map_t.view((1, ) + map_t.size()) 131 | # map_pil = to_pil_image(map_t) 132 | # n_row = i % W_n 133 | # n_col = i // W_n 134 | # MAP.paste(map_pil, (n_row * feature_W, n_col * feature_H)) 135 | # return MAP 136 | 137 | def summary(model, input_size, batch_size=-1, device="cuda"): 138 | 139 | def register_hook(module): 140 | 141 | def hook(module, input, output): 142 | class_name = str(module.__class__).split(".")[-1].split("'")[0] 143 | module_idx = len(summary) 144 | 145 | m_key = "%s-%i" % (class_name, module_idx + 1) 146 | summary[m_key] = OrderedDict() 147 | summary[m_key]["input_shape"] = list(input[0].size()) 148 | summary[m_key]["input_shape"][0] = batch_size 149 | if isinstance(output, (list, tuple)): 150 | summary[m_key]["output_shape"] = [ 151 | [-1] + list(o.size())[1:] for o in output 152 | ] 153 | else: 154 | summary[m_key]["output_shape"] = list(output.size()) 155 | summary[m_key]["output_shape"][0] = 
batch_size 156 | 157 | params = 0 158 | if hasattr(module, "weight") and hasattr(module.weight, "size"): 159 | params += torch.prod(torch.LongTensor(list(module.weight.size()))) 160 | summary[m_key]["trainable"] = module.weight.requires_grad 161 | if hasattr(module, "bias") and hasattr(module.bias, "size"): 162 | params += torch.prod(torch.LongTensor(list(module.bias.size()))) 163 | summary[m_key]["nb_params"] = params 164 | 165 | if ( 166 | not isinstance(module, nn.Sequential) 167 | and not isinstance(module, nn.ModuleList) 168 | and not (module == model) 169 | ): 170 | hooks.append(module.register_forward_hook(hook)) 171 | 172 | device = device.lower() 173 | assert device in [ 174 | "cuda", 175 | "cpu", 176 | ], "Input device is not valid, please specify 'cuda' or 'cpu'" 177 | 178 | if device == "cuda" and torch.cuda.is_available(): 179 | dtype = torch.cuda.FloatTensor 180 | else: 181 | dtype = torch.FloatTensor 182 | 183 | # multiple inputs to the network 184 | if isinstance(input_size, tuple): 185 | input_size = [input_size] 186 | 187 | # batch_size of 2 for batchnorm 188 | x = [torch.rand(2, *in_size).type(dtype) for in_size in input_size] 189 | # print(type(x[0])) 190 | 191 | # create properties 192 | summary = OrderedDict() 193 | hooks = [] 194 | 195 | # register hook 196 | model.apply(register_hook) 197 | 198 | # make a forward pass 199 | # print(x.shape) 200 | model(*x) 201 | 202 | # remove these hooks 203 | for h in hooks: 204 | h.remove() 205 | 206 | print("----------------------------------------------------------------") 207 | line_new = "{:>20} {:>25} {:>15}".format("Layer (type)", "Output Shape", "Param #") 208 | print(line_new) 209 | print("================================================================") 210 | total_params = 0 211 | total_output = 0 212 | trainable_params = 0 213 | for layer in summary: 214 | # input_shape, output_shape, trainable, nb_params 215 | line_new = "{:>20} {:>25} {:>15}".format( 216 | layer, 217 | str(summary[layer]["output_shape"]), 218 | "{0:,}".format(summary[layer]["nb_params"]), 219 | ) 220 | total_params += summary[layer]["nb_params"] 221 | total_output += np.prod(summary[layer]["output_shape"]) 222 | if "trainable" in summary[layer]: 223 | if summary[layer]["trainable"] == True: 224 | trainable_params += summary[layer]["nb_params"] 225 | print(line_new) 226 | 227 | # assume 4 bytes/number (float on cuda). 228 | total_input_size = abs(np.prod(input_size) * batch_size * 4. / (1024 ** 2.)) 229 | total_output_size = abs(2. * total_output * 4. / (1024 ** 2.)) # x2 for gradients 230 | total_params_size = abs(total_params.numpy() * 4. 
/ (1024 ** 2.)) 231 | total_size = total_params_size + total_output_size + total_input_size 232 | 233 | print("================================================================") 234 | print("Total params: {0:,}".format(total_params)) 235 | print("Trainable params: {0:,}".format(trainable_params)) 236 | print("Non-trainable params: {0:,}".format(total_params - trainable_params)) 237 | print("----------------------------------------------------------------") 238 | print("Input size (MB): %0.2f" % total_input_size) 239 | print("Forward/backward pass size (MB): %0.2f" % total_output_size) 240 | print("Params size (MB): %0.2f" % total_params_size) 241 | print("Estimated Total Size (MB): %0.2f" % total_size) 242 | print("----------------------------------------------------------------") 243 | # return summary 244 | 245 | 246 | 247 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JasonGUTU/TorchTools/b7025ab8469eec87830b109a344a7af551a5d464/__init__.py -------------------------------------------------------------------------------- /batch_PSNR_SSIM.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import argparse 3 | import scipy.misc 4 | import os.path 5 | 6 | from TorchTools.Functions.SRMeasure import psnr, ssim_exact, niqe 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--gt', type=str, default='', help='Ground Truth image path') 10 | parser.add_argument('--img', type=str, default='', help='Image dir path') 11 | opt = parser.parse_args() 12 | 13 | abs_img_path = os.path.abspath(opt.img) 14 | img_list = os.listdir(abs_img_path) 15 | abs_img_list = [] 16 | 17 | for i in range(len(img_list)): 18 | abs_img_list.append(os.path.join(abs_img_path, img_list[i])) 19 | 20 | gt = os.path.abspath(opt.gt) 21 | ref = scipy.misc.imread(gt, flatten=True).astype(numpy.float32) 22 | 23 | PSNR_list = [] 24 | SSIM_list = [] 25 | NIQE_list = [] 26 | 27 | return_str = 'GT: %s\n' % gt 28 | for i in range(len(img_list)): 29 | return_str += '%d. %s: ' % (i + 1, img_list[i]) 30 | img = scipy.misc.imread(abs_img_list[i], flatten=True).astype(numpy.float32) 31 | img_psnr = psnr(img, ref) 32 | img_ssim = ssim_exact(img, ref) 33 | img_niqe = niqe(img) 34 | PSNR_list.append(img_psnr) 35 | SSIM_list.append(img_ssim) 36 | NIQE_list.append(img_niqe) 37 | print('%d. 
%s: PSNR: %.4f, SSIM: %.4f, NIQE: %.4f' % (i + 1, img_list[i], img_psnr, img_ssim, img_niqe)) 38 | return_str += 'PSNR: %.4f, SSIM: %.4f, NIQE: %.4f\n' % (img_psnr, img_ssim, img_niqe) 39 | PSNR = sum(PSNR_list) / len(PSNR_list) 40 | SSIM = sum(SSIM_list) / len(SSIM_list) 41 | NIQE = sum(NIQE_list) / len(NIQE_list) 42 | return_str += 'AVG: PSNR: %.4f, SSIM: %.4f, NIQE: %.4f' % (PSNR, SSIM, NIQE) 43 | 44 | with open('test_%s.txt' % opt.gt, 'w') as f: 45 | f.write(return_str) 46 | 47 | -------------------------------------------------------------------------------- /caffemodel_to_t7.lua: -------------------------------------------------------------------------------- 1 | require 'loadcaffe' 2 | require 'xlua' 3 | require 'optim' 4 | 5 | -- modify the path 6 | 7 | prototxt = 'MTCNNv1/model/' 8 | binary = '/home/fanq15/pconvert_caffe_to_pytorch/vgg16.caffemodel' 9 | 10 | net = loadcaffe.load(prototxt, binary, 'cudnn') 11 | net = net:float() -- essential reference https://github.com/clcarwin/convert_torch_to_pytorch/issues/8 12 | print(net) 13 | 14 | torch.save('/home/fanq15/convert_caffe_to_pytorch/vgg16_torch.t7', net) 15 | 16 | 17 | 18 | 19 | -------------------------------------------------------------------------------- /convert_torch.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.optim as optim 7 | from torch.autograd import Variable 8 | from torch.utils.serialization import load_lua 9 | 10 | import numpy as np 11 | import os 12 | import math 13 | from functools import reduce 14 | 15 | class LambdaBase(nn.Sequential): 16 | def __init__(self, fn, *args): 17 | super(LambdaBase, self).__init__(*args) 18 | self.lambda_func = fn 19 | 20 | def forward_prepare(self, input): 21 | output = [] 22 | for module in self._modules.values(): 23 | output.append(module(input)) 24 | return output if output else input 25 | 26 | class Lambda(LambdaBase): 27 | def forward(self, input): 28 | return self.lambda_func(self.forward_prepare(input)) 29 | 30 | class LambdaMap(LambdaBase): 31 | def forward(self, input): 32 | # result is Variables list [Variable1, Variable2, ...] 
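        # i.e. apply lambda_func to every branch output returned by forward_prepare (LambdaReduce below folds those outputs into a single value instead)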
33 | return list(map(self.lambda_func,self.forward_prepare(input))) 34 | 35 | class LambdaReduce(LambdaBase): 36 | def forward(self, input): 37 | # result is a Variable 38 | return reduce(self.lambda_func,self.forward_prepare(input)) 39 | 40 | 41 | def copy_param(m,n): 42 | if m.weight is not None: n.weight.data.copy_(m.weight) 43 | if m.bias is not None: n.bias.data.copy_(m.bias) 44 | if hasattr(n,'running_mean'): n.running_mean.copy_(m.running_mean) 45 | if hasattr(n,'running_var'): n.running_var.copy_(m.running_var) 46 | 47 | def add_submodule(seq, *args): 48 | for n in args: 49 | seq.add_module(str(len(seq._modules)),n) 50 | 51 | def lua_recursive_model(module,seq): 52 | for m in module.modules: 53 | name = type(m).__name__ 54 | real = m 55 | if name == 'TorchObject': 56 | name = m._typename.replace('cudnn.','') 57 | m = m._obj 58 | 59 | if name == 'SpatialConvolution': 60 | if not hasattr(m,'groups'): m.groups=1 61 | n = nn.Conv2d(m.nInputPlane,m.nOutputPlane,(m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH),1,m.groups,bias=(m.bias is not None)) 62 | copy_param(m,n) 63 | add_submodule(seq,n) 64 | elif name == 'SpatialBatchNormalization': 65 | n = nn.BatchNorm2d(m.running_mean.size(0), m.eps, m.momentum, m.affine) 66 | copy_param(m,n) 67 | add_submodule(seq,n) 68 | elif name == 'ReLU': 69 | n = nn.ReLU() 70 | add_submodule(seq,n) 71 | elif name == 'SpatialMaxPooling': 72 | n = nn.MaxPool2d((m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH),ceil_mode=m.ceil_mode) 73 | add_submodule(seq,n) 74 | elif name == 'SpatialAveragePooling': 75 | n = nn.AvgPool2d((m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH),ceil_mode=m.ceil_mode) 76 | add_submodule(seq,n) 77 | elif name == 'SpatialUpSamplingNearest': 78 | n = nn.UpsamplingNearest2d(scale_factor=m.scale_factor) 79 | add_submodule(seq,n) 80 | elif name == 'View': 81 | n = Lambda(lambda x: x.view(x.size(0),-1)) 82 | add_submodule(seq,n) 83 | elif name == 'Linear': 84 | # Linear in pytorch only accept 2D input 85 | n1 = Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x ) 86 | n2 = nn.Linear(m.weight.size(1),m.weight.size(0),bias=(m.bias is not None)) 87 | copy_param(m,n2) 88 | n = nn.Sequential(n1,n2) 89 | add_submodule(seq,n) 90 | elif name == 'Dropout': 91 | m.inplace = False 92 | n = nn.Dropout(m.p) 93 | add_submodule(seq,n) 94 | elif name == 'SoftMax': 95 | n = nn.Softmax() 96 | add_submodule(seq,n) 97 | elif name == 'Identity': 98 | n = Lambda(lambda x: x) # do nothing 99 | add_submodule(seq,n) 100 | elif name == 'SpatialFullConvolution': 101 | n = nn.ConvTranspose2d(m.nInputPlane,m.nOutputPlane,(m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH)) 102 | add_submodule(seq,n) 103 | elif name == 'SpatialReplicationPadding': 104 | n = nn.ReplicationPad2d((m.pad_l,m.pad_r,m.pad_t,m.pad_b)) 105 | add_submodule(seq,n) 106 | elif name == 'SpatialReflectionPadding': 107 | n = nn.ReflectionPad2d((m.pad_l,m.pad_r,m.pad_t,m.pad_b)) 108 | add_submodule(seq,n) 109 | elif name == 'Copy': 110 | n = Lambda(lambda x: x) # do nothing 111 | add_submodule(seq,n) 112 | elif name == 'Narrow': 113 | n = Lambda(lambda x,a=(m.dimension,m.index,m.length): x.narrow(*a)) 114 | add_submodule(seq,n) 115 | elif name == 'SpatialCrossMapLRN': 116 | lrn = torch.legacy.nn.SpatialCrossMapLRN(m.size,m.alpha,m.beta,m.k) 117 | n = Lambda(lambda x,lrn=lrn: Variable(lrn.forward(x.data))) 118 | add_submodule(seq,n) 119 | elif name == 'Sequential': 120 | n = nn.Sequential() 121 | lua_recursive_model(m,n) 122 | add_submodule(seq,n) 123 | elif name == 'ConcatTable': # output is list 124 | n = LambdaMap(lambda x: x) 125 | 
lua_recursive_model(m,n) 126 | add_submodule(seq,n) 127 | elif name == 'CAddTable': # input is list 128 | n = LambdaReduce(lambda x,y: x+y) 129 | add_submodule(seq,n) 130 | elif name == 'Concat': 131 | dim = m.dimension 132 | n = LambdaReduce(lambda x,y,dim=dim: torch.cat((x,y),dim)) 133 | lua_recursive_model(m,n) 134 | add_submodule(seq,n) 135 | elif name == 'TorchObject': 136 | print('Not Implement',name,real._typename) 137 | else: 138 | print('Not Implement',name) 139 | 140 | 141 | def lua_recursive_source(module): 142 | s = [] 143 | for m in module.modules: 144 | name = type(m).__name__ 145 | real = m 146 | if name == 'TorchObject': 147 | name = m._typename.replace('cudnn.','') 148 | m = m._obj 149 | 150 | if name == 'SpatialConvolution': 151 | if not hasattr(m,'groups'): m.groups=1 152 | s += ['nn.Conv2d({},{},{},{},{},{},{},bias={}),#Conv2d'.format(m.nInputPlane, 153 | m.nOutputPlane,(m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH),1,m.groups,m.bias is not None)] 154 | elif name == 'SpatialBatchNormalization': 155 | s += ['nn.BatchNorm2d({},{},{},{}),#BatchNorm2d'.format(m.running_mean.size(0), m.eps, m.momentum, m.affine)] 156 | elif name == 'ReLU': 157 | s += ['nn.ReLU()'] 158 | elif name == 'SpatialMaxPooling': 159 | s += ['nn.MaxPool2d({},{},{},ceil_mode={}),#MaxPool2d'.format((m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH),m.ceil_mode)] 160 | elif name == 'SpatialAveragePooling': 161 | s += ['nn.AvgPool2d({},{},{},ceil_mode={}),#AvgPool2d'.format((m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH),m.ceil_mode)] 162 | elif name == 'SpatialUpSamplingNearest': 163 | s += ['nn.UpsamplingNearest2d(scale_factor={})'.format(m.scale_factor)] 164 | elif name == 'View': 165 | s += ['Lambda(lambda x: x.view(x.size(0),-1)), # View'] 166 | elif name == 'Linear': 167 | s1 = 'Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x )' 168 | s2 = 'nn.Linear({},{},bias={})'.format(m.weight.size(1),m.weight.size(0),(m.bias is not None)) 169 | s += ['nn.Sequential({},{}),#Linear'.format(s1,s2)] 170 | elif name == 'Dropout': 171 | s += ['nn.Dropout({})'.format(m.p)] 172 | elif name == 'SoftMax': 173 | s += ['nn.Softmax()'] 174 | elif name == 'Identity': 175 | s += ['Lambda(lambda x: x), # Identity'] 176 | elif name == 'SpatialFullConvolution': 177 | s += ['nn.ConvTranspose2d({},{},{},{},{})'.format(m.nInputPlane, 178 | m.nOutputPlane,(m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH))] 179 | elif name == 'SpatialReplicationPadding': 180 | s += ['nn.ReplicationPad2d({})'.format((m.pad_l,m.pad_r,m.pad_t,m.pad_b))] 181 | elif name == 'SpatialReflectionPadding': 182 | s += ['nn.ReflectionPad2d({})'.format((m.pad_l,m.pad_r,m.pad_t,m.pad_b))] 183 | elif name == 'Copy': 184 | s += ['Lambda(lambda x: x), # Copy'] 185 | elif name == 'Narrow': 186 | s += ['Lambda(lambda x,a={}: x.narrow(*a))'.format((m.dimension,m.index,m.length))] 187 | elif name == 'SpatialCrossMapLRN': 188 | lrn = 'torch.legacy.nn.SpatialCrossMapLRN(*{})'.format((m.size,m.alpha,m.beta,m.k)) 189 | s += ['Lambda(lambda x,lrn={}: Variable(lrn.forward(x.data)))'.format(lrn)] 190 | 191 | elif name == 'Sequential': 192 | s += ['nn.Sequential( # Sequential'] 193 | s += lua_recursive_source(m) 194 | s += [')'] 195 | elif name == 'ConcatTable': 196 | s += ['LambdaMap(lambda x: x, # ConcatTable'] 197 | s += lua_recursive_source(m) 198 | s += [')'] 199 | elif name == 'CAddTable': 200 | s += ['LambdaReduce(lambda x,y: x+y), # CAddTable'] 201 | elif name == 'Concat': 202 | dim = m.dimension 203 | s += ['LambdaReduce(lambda x,y,dim={}: torch.cat((x,y),dim), # Concat'.format(m.dimension)] 204 
| s += lua_recursive_source(m) 205 | s += [')'] 206 | else: 207 | s += '# ' + name + ' Not Implement,\n' 208 | s = map(lambda x: '\t{}'.format(x),s) 209 | return s 210 | 211 | def simplify_source(s): 212 | s = map(lambda x: x.replace(',(1, 1),(0, 0),1,1,bias=True),#Conv2d',')'),s) 213 | s = map(lambda x: x.replace(',(0, 0),1,1,bias=True),#Conv2d',')'),s) 214 | s = map(lambda x: x.replace(',1,1,bias=True),#Conv2d',')'),s) 215 | s = map(lambda x: x.replace(',bias=True),#Conv2d',')'),s) 216 | s = map(lambda x: x.replace('),#Conv2d',')'),s) 217 | s = map(lambda x: x.replace(',1e-05,0.1,True),#BatchNorm2d',')'),s) 218 | s = map(lambda x: x.replace('),#BatchNorm2d',')'),s) 219 | s = map(lambda x: x.replace(',(0, 0),ceil_mode=False),#MaxPool2d',')'),s) 220 | s = map(lambda x: x.replace(',ceil_mode=False),#MaxPool2d',')'),s) 221 | s = map(lambda x: x.replace('),#MaxPool2d',')'),s) 222 | s = map(lambda x: x.replace(',(0, 0),ceil_mode=False),#AvgPool2d',')'),s) 223 | s = map(lambda x: x.replace(',ceil_mode=False),#AvgPool2d',')'),s) 224 | s = map(lambda x: x.replace(',bias=True)),#Linear',')), # Linear'),s) 225 | s = map(lambda x: x.replace(')),#Linear',')), # Linear'),s) 226 | 227 | s = map(lambda x: '{},\n'.format(x),s) 228 | s = map(lambda x: x[1:],s) 229 | s = reduce(lambda x,y: x+y, s) 230 | return s 231 | 232 | def torch_to_pytorch(t7_filename,outputname=None): 233 | model = load_lua(t7_filename,unknown_classes=True) 234 | if type(model).__name__=='hashable_uniq_dict': model=model.model 235 | model.gradInput = None 236 | slist = lua_recursive_source(torch.legacy.nn.Sequential().add(model)) 237 | s = simplify_source(slist) 238 | header = ''' 239 | import torch 240 | import torch.nn as nn 241 | from torch.autograd import Variable 242 | from functools import reduce 243 | 244 | class LambdaBase(nn.Sequential): 245 | def __init__(self, fn, *args): 246 | super(LambdaBase, self).__init__(*args) 247 | self.lambda_func = fn 248 | 249 | def forward_prepare(self, input): 250 | output = [] 251 | for module in self._modules.values(): 252 | output.append(module(input)) 253 | return output if output else input 254 | 255 | class Lambda(LambdaBase): 256 | def forward(self, input): 257 | return self.lambda_func(self.forward_prepare(input)) 258 | 259 | class LambdaMap(LambdaBase): 260 | def forward(self, input): 261 | return list(map(self.lambda_func,self.forward_prepare(input))) 262 | 263 | class LambdaReduce(LambdaBase): 264 | def forward(self, input): 265 | return reduce(self.lambda_func,self.forward_prepare(input)) 266 | ''' 267 | varname = t7_filename.replace('.t7','').replace('.','_').replace('-','_') 268 | s = '{}\n\n{} = {}'.format(header,varname,s[:-2]) 269 | 270 | if outputname is None: outputname=varname 271 | with open(outputname+'.py', "w") as pyfile: 272 | pyfile.write(s) 273 | 274 | n = nn.Sequential() 275 | lua_recursive_model(model,n) 276 | torch.save(n.state_dict(),outputname+'.pth') 277 | 278 | 279 | parser = argparse.ArgumentParser(description='Convert torch t7 model to pytorch') 280 | parser.add_argument('--model','-m', type=str, required=True, 281 | help='torch model file in t7 format') 282 | parser.add_argument('--output', '-o', type=str, default=None, 283 | help='output file name prefix, xxx.py xxx.pth') 284 | args = parser.parse_args() 285 | 286 | torch_to_pytorch(args.model,args.output) 287 | -------------------------------------------------------------------------------- /frameCutter.py: -------------------------------------------------------------------------------- 1 | import os 2 | 
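# frameCutter.py: for every video under --input, create a matching folder under --output
# and call ffmpeg (must be on PATH) to dump frames as numbered PNGs, seeking to --ss seconds,
# grabbing --t seconds, at frame rate --fps.
# Hypothetical invocation: python frameCutter.py -i ./videos -o ./frames -f 30 --ss 1 --t 8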
3 | import argparse 4 | 5 | parser = argparse.ArgumentParser() 6 | 7 | parser.add_argument('-i', '--input', type=str, default='../../000_src_data/vid215_201712271749') 8 | parser.add_argument('-o', '--output', type=str, default='01_cut/cut_vid215_201801031127') 9 | parser.add_argument('-f', '--fps', type=int, default=30) 10 | parser.add_argument('--ss', type=int, default=1) 11 | parser.add_argument('--t', type=int, default=8, help='negative value, means how long to the eof ') 12 | 13 | args = parser.parse_args() 14 | 15 | # in_path = '../../000_src_data/data_vid10_1218' 16 | 17 | # out_path = '../../011_aug_data/cutdata_vid10_1218' 18 | 19 | in_path = args.input 20 | out_path = args.output 21 | fps = args.fps 22 | 23 | if not os.path.exists(out_path): 24 | os.makedirs(out_path) 25 | 26 | files = os.listdir(in_path) 27 | files.sort() 28 | 29 | for file in files: 30 | os.mkdir('{}/{}'.format(out_path, file)) 31 | os.system( 32 | 'ffmpeg -ss {} -t {} -r {} -i {}/{} {}/{}/%4d.png'.format(args.ss, args.t, fps, in_path, file, out_path, file) 33 | ) 34 | 35 | --------------------------------------------------------------------------------
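Minimal usage sketch for the blocks dumped above (not a file in the repository): it assumes the checkout is importable as TorchTools, so the blur and space-to-depth modules from TorchNet/modules.py can be imported as shown; adjust the package path to your own layout.

import torch
from TorchTools.TorchNet.modules import GaussianBlur, DownsamplingShuffle  # assumed package path

hr = torch.rand(1, 1, 64, 64)                  # single-channel HR patch (GaussianBlur's kernel here is built for 1 channel)
blur = GaussianBlur(input_channel=1, l=3, sigma=0.6)
shuffle = DownsamplingShuffle(scala=2)         # space-to-depth: packs 2x2 sub-pixel phases into channels

blurred = blur(hr)                             # ReflectionPad2d keeps the 64x64 spatial size
packed = shuffle(blurred)                      # -> torch.Size([1, 4, 32, 32])
print(blurred.shape, packed.shape)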