├── .gitignore
├── README.md
├── animegan
├── __init__.py
├── configs
│   ├── __init__.py
│   └── defaults.py
├── data
│   ├── __init__.py
│   ├── build.py
│   ├── collate_batch.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── animegan.py
│   │   └── concat_dataset.py
│   ├── samplers
│   │   ├── __init__.py
│   │   ├── distributed.py
│   │   ├── grouped_batch_sampler.py
│   │   └── iteration_based_batch_sampler.py
│   └── transforms
│   │   ├── __init__.py
│   │   ├── build.py
│   │   └── transforms.py
├── engine
│   ├── __init__.py
│   ├── inference.py
│   └── trainer.py
├── lib
│   ├── __init__.py
│   └── trainer.py
├── modeling
│   ├── __init__.py
│   ├── backbone.py
│   ├── build.py
│   ├── discriminator.py
│   ├── generator.py
│   ├── layers.py
│   ├── loss.py
│   ├── registry.py
│   └── utils.py
├── solver
│   ├── __init__.py
│   ├── build.py
│   └── lr_scheduler.py
└── utils
│   ├── __init__.py
│   ├── checkpoint.py
│   ├── comm.py
│   ├── datasetInfo.py
│   ├── env.py
│   ├── logger.py
│   ├── model_serialization.py
│   ├── model_zoo.py
│   ├── registry.py
│   └── tm.py
├── configs
├── e2e_hayao.yaml
└── e2e_shinkai.yaml
├── scripts
├── data_mean.py
├── gramEmbedding.py
├── image2anime.py
├── modelTensorboard.py
├── tf2torch.py
└── video2anime.py
├── src
├── hayao
│   ├── anime_1.jpg
│   ├── anime_2.jpg
│   └── anime_3.jpg
└── shinkai
│   ├── anime_1.jpg
│   ├── anime_2.jpg
│   └── anime_3.jpg
└── tools
└── train_net.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.local
3 | .idea
4 | .cache
5 | graph
6 | outputs
7 | local_src
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AnimeGAN
2 | 「Open Source」. The PyTorch implementation of AnimeGAN.
3 | 「Anime style generation: use AI to create an anime of your own.」
4 | 
5 | ## Dataset
6 | [Download link](https://github.com/TachibanaYoshino/AnimeGAN/releases/tag/dataset-1)
7 | Expected layout of the training data:
8 | ```
9 | |_datasets
10 | |__animegan
11 | |___Shinkai (artist style)
12 | |_____train (training data)
13 | |________real (real-world photos)
14 | |________style (anime-style images)
15 | |________smooth (edge-smoothed style images)
16 | |_____test (test data)
17 | |________real (real-world photos)
18 | ```
19 | 
20 | ## Training
21 | ```
22 | # single GPU
23 | python tools/train_net.py \
24 |     --config-file "path/to/config" \
25 |     SOLVER.IMS_PER_BATCH 8
26 | 
27 | # multiple GPUs
28 | python -m torch.distributed.launch --nproc_per_node=8 \
29 |     tools/train_net.py \
30 |     --config-file "path/to/config" \
31 |     SOLVER.IMS_PER_BATCH 8
32 | ```
33 | 
34 | ## Model Download
35 | * Pre-trained model
36 | Link: https://pan.baidu.com/s/12mCSoACTE4sXA4ycQN0YmA
37 | Password: 9tau
38 | * Shinkai style model
39 | Link: https://pan.baidu.com/s/1yG_BxCGrBqsVITqE5Vx6dw
40 | Password: 6tp4
41 | * Hayao style model
42 | Link: https://pan.baidu.com/s/19PMLNO-lQ0tH0DwQjz_XMg
43 | Password: cs27
44 | 
45 | ## Image Conversion
46 | ```
47 | # --config-file  path to the config file
48 | # --image        path to the input image
49 | # MODEL.WEIGHT   path to the model weights
50 | python scripts/image2anime.py \
51 |     --config-file "your_config_path" \
52 |     --image "your_image_path" \
53 |     MODEL.WEIGHT "your_model_path"
54 | ```
55 | 
56 | ## Video Conversion
57 | ```
58 | # --config-file  path to the config file
59 | # --video        path to the input video
60 | # MODEL.WEIGHT   path to the model weights
61 | python scripts/video2anime.py \
62 |     --config-file "your_config_path" \
63 |     --video "your_video_path" \
64 |     MODEL.WEIGHT "your_model_path"
65 | ```
66 | 
67 | ## Demo
68 | :heart_eyes: Photo to Shinkai Style
69 | ![](src/shinkai/anime_1.jpg)
70 | ![](src/shinkai/anime_2.jpg)
71 | ![](src/shinkai/anime_3.jpg)
72 | :heart_eyes: Photo to Hayao Style
73 | ![](src/hayao/anime_1.jpg)
74 | ![](src/hayao/anime_2.jpg)
75 | ![](src/hayao/anime_3.jpg)
76 | 
77 | ## Thanks
78 | Thanks
for [TachibanaYoshino](https://tachibanayoshino.github.io/AnimeGANv2/) -------------------------------------------------------------------------------- /animegan/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Author: wanhui0729@gmail.com 3 | -------------------------------------------------------------------------------- /animegan/configs/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Author: wanhui0729@gmail.com 3 | 4 | from .defaults import _C as cfg -------------------------------------------------------------------------------- /animegan/configs/defaults.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Author: wanhui0729@gmail.com 3 | 4 | from yacs.config import CfgNode as CN 5 | _C = CN() 6 | 7 | # ----------------------------------------------------------------------------- 8 | # Config definition 9 | # ----------------------------------------------------------------------------- 10 | _C.MODEL = CN() 11 | # 训练设备类型 12 | _C.MODEL.DEVICE = "cuda" 13 | # 是否迁移学习 14 | _C.MODEL.TRANSFER_LEARNING = False 15 | _C.MODEL.WEIGHT = "" 16 | 17 | # ----------------------------------------------------------------------------- 18 | # BACKBONE 19 | # ----------------------------------------------------------------------------- 20 | _C.MODEL.BACKBONE = CN() 21 | _C.MODEL.BACKBONE.BODY = "VGG19" 22 | _C.MODEL.BACKBONE.WEIGHT = "" 23 | 24 | # ----------------------------------------------------------------------------- 25 | # INPUT 26 | # ----------------------------------------------------------------------------- 27 | _C.INPUT = CN() 28 | _C.INPUT.IMG_SIZE = (256, 256) 29 | # RGB 30 | _C.INPUT.PIXEL_MEAN = [-4.4661, -8.6698, 13.1360] 31 | # _C.INPUT.PIXEL_STD = [127.5, 127.5, 127.5] 32 | 33 | # ----------------------------------------------------------------------------- 34 | # DISCRIMINATOR 35 | # ----------------------------------------------------------------------------- 36 | _C.MODEL.DISCRIMINATOR = CN() 37 | _C.MODEL.DISCRIMINATOR.BODY = "Base-256" 38 | _C.MODEL.DISCRIMINATOR.IN_CHANNELS = 3 39 | _C.MODEL.DISCRIMINATOR.CHANNELS = 64 40 | _C.MODEL.DISCRIMINATOR.N_DIS = 2 41 | 42 | # ----------------------------------------------------------------------------- 43 | # GENERATOR 44 | # ----------------------------------------------------------------------------- 45 | _C.MODEL.GENERATOR = CN() 46 | _C.MODEL.GENERATOR.BODY = "Base-256" 47 | _C.MODEL.GENERATOR.IN_CHANNELS = 3 48 | 49 | # ---------------------------------------------------------------------------- # 50 | # Loss options 51 | # ---------------------------------------------------------------------------- # 52 | _C.MODEL.LOSS = CN() 53 | _C.MODEL.LOSS.S_LOSS_COLOR2GRAY = True 54 | 55 | # ---------------------------------------------------------------------------- # 56 | # Common options 57 | # ---------------------------------------------------------------------------- # 58 | _C.MODEL.COMMON = CN() 59 | _C.MODEL.COMMON.GAN_TYPE = 'lsgan' 60 | _C.MODEL.COMMON.TRAINING_RATE = 1 61 | 62 | _C.MODEL.COMMON.LD = 10.0 63 | _C.MODEL.COMMON.WEIGHT_ADV_G = 300.0 64 | _C.MODEL.COMMON.WEIGHT_ADV_D = 300.0 65 | _C.MODEL.COMMON.WEIGHT_G_CON = 1.5 66 | _C.MODEL.COMMON.WEIGHT_G_STYLE = 2.5 67 | _C.MODEL.COMMON.WEIGHT_G_COLOR = 10.0 68 | _C.MODEL.COMMON.WEIGHT_G_TV = 1.0 69 | _C.MODEL.COMMON.WEIGHT_D_LOSS_REAL = 1.7 70 | 
_C.MODEL.COMMON.WEIGHT_D_LOSS_FAKE = 1.7 71 | _C.MODEL.COMMON.WEIGHT_D_LOSS_GRAY = 1.7 72 | _C.MODEL.COMMON.WEIGHT_D_LOSS_BLUR = 1.0 73 | 74 | # ----------------------------------------------------------------------------- 75 | # Dataset 76 | # ----------------------------------------------------------------------------- 77 | _C.DATASETS = CN() 78 | # List of the dataset Info for training 79 | _C.DATASETS.TRAIN = [] 80 | # List of the dataset Info for testing 81 | _C.DATASETS.TEST = [] 82 | 83 | # ----------------------------------------------------------------------------- 84 | # DataLoader 85 | # ----------------------------------------------------------------------------- 86 | _C.DATALOADER = CN() 87 | # Number of data loading threads 88 | _C.DATALOADER.NUM_WORKERS = 1 89 | 90 | # ---------------------------------------------------------------------------- # 91 | # Solver 92 | # ---------------------------------------------------------------------------- # 93 | _C.SOLVER = CN() 94 | 95 | _C.SOLVER.MAX_EPOCH = 100 96 | # 单位epoch 97 | _C.SOLVER.CHECKPOINT_PERIOD = 20 98 | # 单位iteration 99 | _C.SOLVER.PRINT_PERIOD = 20 100 | # 单位epoch 101 | _C.SOLVER.TEST_PERIOD = 10 102 | #每个batch处理图片,基于GPU数量定 103 | _C.SOLVER.IMS_PER_BATCH = 32 104 | 105 | 106 | _C.SOLVER.GENERATOR = CN() 107 | _C.SOLVER.GENERATOR.INIT_EPOCH = 10 108 | _C.SOLVER.GENERATOR.BASE_LR = 0.0002 109 | _C.SOLVER.GENERATOR.STEPS = (10,) 110 | _C.SOLVER.GENERATOR.GAMMA = 0.1 111 | _C.SOLVER.GENERATOR.WARMUP_FACTOR = 1.0 / 3 112 | _C.SOLVER.GENERATOR.WARMUP_ITERS = 0 113 | _C.SOLVER.GENERATOR.WARMUP_METHOD = "constant" 114 | 115 | _C.SOLVER.DISCRIMINATOR = CN() 116 | _C.SOLVER.DISCRIMINATOR.BASE_LR = 0.00004 117 | _C.SOLVER.DISCRIMINATOR.STEPS = (100,) 118 | _C.SOLVER.DISCRIMINATOR.GAMMA = 0.1 119 | _C.SOLVER.DISCRIMINATOR.WARMUP_FACTOR = 1.0 / 3 120 | _C.SOLVER.DISCRIMINATOR.WARMUP_ITERS = 0 121 | _C.SOLVER.DISCRIMINATOR.WARMUP_METHOD = "linear" 122 | 123 | # ---------------------------------------------------------------------------- # 124 | # TEST 125 | # ---------------------------------------------------------------------------- # 126 | _C.TEST = CN() 127 | # Number of images per batch 128 | _C.TEST.IMS_PER_BATCH = 32 129 | 130 | # ---------------------------------------------------------------------------- # 131 | # Misc options 132 | # ---------------------------------------------------------------------------- # 133 | _C.OUTPUT_DIR = "./outputs/" -------------------------------------------------------------------------------- /animegan/data/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Author: wanhui0729@gmail.com 3 | -------------------------------------------------------------------------------- /animegan/data/build.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Author: wanhui0729@gmail.com 3 | 4 | # coding: utf-8 5 | # Author: wanhui0729@gmail.com 6 | 7 | import bisect 8 | import copy 9 | import torch.utils.data 10 | from . import samplers 11 | from . 
import datasets as dataInterface 12 | from .collate_batch import ImageBatchCollator 13 | from .transforms.build import build_transforms 14 | from .datasets.concat_dataset import ConcatDataset 15 | from animegan.utils.comm import get_world_size 16 | from animegan.utils.datasetInfo import DatasetInfo 17 | 18 | def build_dataset(dataInterface, datasetsInfo, transforms, is_train=True): 19 | if not isinstance(datasetsInfo, (list, tuple)): 20 | raise RuntimeError( 21 | "datasetsInfo should be a list" 22 | ) 23 | datasets = [] 24 | for datasetInfo in datasetsInfo: 25 | data = datasetInfo.get() 26 | factory = getattr(dataInterface, data.get("factory")) 27 | args = data.get("args") 28 | args["transforms"] = transforms 29 | # make dataset from factory 30 | try: 31 | dataset = factory(**args) 32 | except: 33 | raise ValueError("Please check dataset factory, the parameters are: {}".format(args)) 34 | datasets.append(dataset) 35 | 36 | # for testing, return a list of datasets 37 | if not is_train: 38 | return datasets 39 | 40 | # for training, concatenate all datasets into a single one 41 | dataset = datasets[0] 42 | if len(datasets) > 1: 43 | dataset = ConcatDataset(datasets) 44 | 45 | return [dataset] 46 | 47 | 48 | def make_data_sampler(dataset, shuffle, distributed): 49 | if distributed: 50 | return samplers.DistributedSampler(dataset, shuffle=shuffle) 51 | if shuffle: 52 | sampler = torch.utils.data.sampler.RandomSampler(dataset) 53 | else: 54 | sampler = torch.utils.data.sampler.SequentialSampler(dataset) 55 | return sampler 56 | 57 | 58 | def _quantize(x, bins): 59 | bins = copy.copy(bins) 60 | bins = sorted(bins) 61 | quantized = list(map(lambda y: bisect.bisect_right(bins, y), x)) 62 | return quantized 63 | 64 | 65 | def make_batch_data_sampler( 66 | dataset, sampler, images_per_batch, num_iters=None, start_iter=0, is_train=False 67 | ): 68 | # 防止batchsize为1时导致模型batchnorm报错 69 | drop = True if is_train else False 70 | batch_sampler = torch.utils.data.sampler.BatchSampler( 71 | sampler, images_per_batch, drop_last=drop 72 | ) 73 | if num_iters is not None: 74 | batch_sampler = samplers.IterationBasedBatchSampler( 75 | batch_sampler, num_iters, start_iter 76 | ) 77 | return batch_sampler 78 | 79 | def make_datasetsInfo(datasetsDictInfo): 80 | datasetsInfo = list() 81 | for datasetDictInfo in datasetsDictInfo: 82 | datasetsInfo.append(DatasetInfo(**datasetDictInfo)) 83 | return datasetsInfo 84 | 85 | def make_datasets(cfg, is_train=True, dataEntrance=None): 86 | ''' 87 | Arguments: 88 | cfg: 配置文件 89 | is_train: bool,是否处于训练阶段 90 | return: 91 | datasets: list of dataset 92 | epoch_sizes: list of size of the dataset's ecpoch 93 | ''' 94 | datasetsInfo = make_datasetsInfo(cfg.DATASETS.TRAIN) if is_train \ 95 | else make_datasetsInfo(cfg.DATASETS.TEST) 96 | 97 | transforms = build_transforms(cfg, is_train) 98 | dataEntrance = dataEntrance or dataInterface 99 | datasets = build_dataset(dataEntrance, datasetsInfo, transforms, is_train=is_train) 100 | 101 | epoch_sizes = [] 102 | for dataset in datasets: 103 | epoch_size = len(dataset) // cfg.SOLVER.IMS_PER_BATCH 104 | epoch_sizes.append(epoch_size) 105 | if is_train: 106 | # during training, a single (possibly concatenated) datasets is returned 107 | assert len(datasets) == 1 108 | return datasets, epoch_sizes 109 | 110 | def make_data_loader(cfg, datasets, epoch_sizes, is_train=True, is_distributed=False, start_iter=0): 111 | num_gpus = get_world_size() 112 | if is_train: 113 | images_per_batch = cfg.SOLVER.IMS_PER_BATCH 114 | assert ( 115 | 
images_per_batch % num_gpus == 0 116 | ), "SOLVER.IMS_PER_BATCH ({}) must be divisible by the number " 117 | "of GPUs ({}) used.".format(images_per_batch, num_gpus) 118 | images_per_gpu = images_per_batch // num_gpus 119 | shuffle = True 120 | num_iters = cfg.SOLVER.MAX_EPOCH * epoch_sizes[0] 121 | else: 122 | images_per_batch = cfg.TEST.IMS_PER_BATCH 123 | assert ( 124 | images_per_batch % num_gpus == 0 125 | ), "TEST.IMS_PER_BATCH ({}) must be divisible by the number " 126 | "of GPUs ({}) used.".format(images_per_batch, num_gpus) 127 | images_per_gpu = images_per_batch // num_gpus 128 | shuffle = False if not is_distributed else True 129 | start_iter = 0 130 | num_iters = None 131 | 132 | data_loaders = [] 133 | for dataset in datasets: 134 | sampler = make_data_sampler(dataset, shuffle, is_distributed) 135 | batch_sampler = make_batch_data_sampler( 136 | dataset, sampler, images_per_gpu, num_iters, start_iter, is_train 137 | ) 138 | collator = ImageBatchCollator() 139 | num_workers = cfg.DATALOADER.NUM_WORKERS 140 | data_loader = torch.utils.data.DataLoader( 141 | dataset, 142 | num_workers=num_workers, 143 | batch_sampler=batch_sampler, 144 | collate_fn=collator, 145 | ) 146 | data_loaders.append(data_loader) 147 | if is_train: 148 | # during training, a single (possibly concatenated) data_loader is returned 149 | assert len(data_loaders) == 1 150 | return data_loaders[0] 151 | return data_loaders 152 | -------------------------------------------------------------------------------- /animegan/data/collate_batch.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Author: wanhui0729@gmail.com 3 | 4 | import torch 5 | class ImageBatchCollator(object): 6 | def _prepare(self, images): 7 | ''' 8 | 准备训练数据 9 | ''' 10 | if not all(images): 11 | return None 12 | color_images = [image[0] for image in images] 13 | gray_images = [image[1] for image in images] 14 | return [torch.stack(color_images), torch.stack(gray_images)] 15 | 16 | def __call__(self, batch): 17 | transposed_batch = list(zip(*batch)) 18 | real_images = transposed_batch[0] 19 | style_images = transposed_batch[1] 20 | smooth_images = transposed_batch[2] 21 | img_ids = transposed_batch[3] 22 | return self._prepare(real_images), self._prepare(style_images), self._prepare(smooth_images), img_ids 23 | -------------------------------------------------------------------------------- /animegan/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Author: wanhui0729@gmail.com 3 | 4 | from .animegan import AnimeGanDataset -------------------------------------------------------------------------------- /animegan/data/datasets/animegan.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Author: wanhui0729@gmail.com 3 | 4 | import os 5 | import cv2 6 | import random 7 | from tqdm import tqdm 8 | import numpy as np 9 | import torch.utils.data 10 | from PIL import Image 11 | from animegan.utils.comm import is_main_process, synchronize 12 | 13 | class AnimeGanDataset(torch.utils.data.Dataset): 14 | def __init__(self, dataDir, split, transforms=None): 15 | assert split in ('train', 'test'), "Please check split supported." 
16 | self.transforms = transforms 17 | self.split = split 18 | dataFolder = os.path.join(dataDir, self.split) 19 | if self.split == 'train': 20 | real_path = os.path.join(dataFolder, 'real') 21 | style_path = os.path.join(dataFolder, 'style') 22 | smooth_path = os.path.join(dataFolder, 'smooth') 23 | # 初始化做smooth处理 24 | if not os.path.exists(smooth_path) and is_main_process(): 25 | self._gen_smooth(style_path, smooth_path) 26 | synchronize() 27 | self.real = [os.path.join(real_path, name) for name in os.listdir(real_path)] 28 | self.style = [os.path.join(style_path, name) for name in os.listdir(style_path)] 29 | self.smooth_path = smooth_path 30 | self._init_real_producer() 31 | self._init_style_producer() 32 | else: 33 | real_path = os.path.join(dataFolder, 'real') 34 | self.real = [os.path.join(real_path, name) for name in os.listdir(real_path)] 35 | 36 | def _gen_smooth(self, style_path, smooth_path): 37 | os.makedirs(smooth_path) 38 | for image in tqdm(os.listdir(style_path)): 39 | image_path = os.path.join(style_path, image) 40 | bgr_img = cv2.imread(image_path) 41 | gray_img = cv2.imread(image_path, 0) 42 | 43 | kernel_size = 5 44 | kernel = np.ones((kernel_size, kernel_size), np.uint8) 45 | gauss = cv2.getGaussianKernel(kernel_size, 0) 46 | gauss = gauss * gauss.transpose(1, 0) 47 | 48 | pad_img = np.pad(bgr_img, ((2, 2), (2, 2), (0, 0)), mode='reflect') 49 | edges = cv2.Canny(gray_img, 100, 200) 50 | dilation = cv2.dilate(edges, kernel) 51 | gauss_img = np.copy(bgr_img) 52 | idx = np.where(dilation != 0) 53 | for i in range(np.sum(dilation != 0)): 54 | gauss_img[idx[0][i], idx[1][i], 0] = np.sum( 55 | np.multiply(pad_img[idx[0][i]:idx[0][i] + kernel_size, idx[1][i]:idx[1][i] + kernel_size, 0], 56 | gauss)) 57 | gauss_img[idx[0][i], idx[1][i], 1] = np.sum( 58 | np.multiply(pad_img[idx[0][i]:idx[0][i] + kernel_size, idx[1][i]:idx[1][i] + kernel_size, 1], 59 | gauss)) 60 | gauss_img[idx[0][i], idx[1][i], 2] = np.sum( 61 | np.multiply(pad_img[idx[0][i]:idx[0][i] + kernel_size, idx[1][i]:idx[1][i] + kernel_size, 2], 62 | gauss)) 63 | 64 | cv2.imwrite(os.path.join(smooth_path, image), gauss_img) 65 | 66 | def _init_real_producer(self): 67 | self.real_producer = self.real.copy() 68 | random.shuffle(self.real_producer) 69 | 70 | def _init_style_producer(self): 71 | self.style_producer = self.style.copy() 72 | random.shuffle(self.style_producer) 73 | 74 | def _real_consumer(self): 75 | if len(self.real_producer) == 0: 76 | self._init_real_producer() 77 | real = self.real_producer.pop() 78 | return real 79 | 80 | def _style_consumer(self): 81 | if len(self.style_producer) == 0: 82 | self._init_style_producer() 83 | style = self.style_producer.pop() 84 | return style 85 | 86 | def __getitem__(self, index): 87 | if self.split == 'train': 88 | real, style = self._real_consumer(), self._style_consumer() 89 | # 同名 90 | smooth = os.path.join(self.smooth_path, os.path.basename(style)) 91 | real = Image.open(real).convert("RGB") 92 | style = Image.open(style).convert("RGB") 93 | smooth = Image.open(smooth).convert("RGB") 94 | if self.transforms: 95 | [real, style, smooth] = self.transforms([real, style, smooth]) 96 | return real, style, smooth, index 97 | else: 98 | real = self.real[index] 99 | real = Image.open(real).convert("RGB") 100 | if self.transforms: 101 | [real] = self.transforms([real]) 102 | return real, None, None, index 103 | 104 | def __len__(self): 105 | if self.split == 'train': 106 | return max(len(self.real), len(self.style)) 107 | else: 108 | return len(self.real) 109 | 110 | # 
def __iter__(self): 111 | # self.iternum = self.__len__() 112 | # return self 113 | # 114 | # def __next__(self): 115 | # self.iternum -= 1 116 | # if self.iternum < 0: 117 | # raise StopIteration 118 | # else: 119 | # return self.__getitem__(self.iternum) -------------------------------------------------------------------------------- /animegan/data/datasets/concat_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | import bisect 3 | 4 | from torch.utils.data.dataset import ConcatDataset as _ConcatDataset 5 | 6 | 7 | class ConcatDataset(_ConcatDataset): 8 | """ 9 | Same as torch.utils.data.dataset.ConcatDataset, but exposes an extra 10 | method for querying the sizes of the image 11 | """ 12 | 13 | def get_idxs(self, idx): 14 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 15 | if dataset_idx == 0: 16 | sample_idx = idx 17 | else: 18 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 19 | return dataset_idx, sample_idx 20 | 21 | def get_img_info(self, idx): 22 | dataset_idx, sample_idx = self.get_idxs(idx) 23 | return self.datasets[dataset_idx].get_img_info(sample_idx) 24 | -------------------------------------------------------------------------------- /animegan/data/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | from .distributed import DistributedSampler 3 | from .grouped_batch_sampler import GroupedBatchSampler 4 | from .iteration_based_batch_sampler import IterationBasedBatchSampler 5 | 6 | __all__ = ["DistributedSampler", "GroupedBatchSampler", "IterationBasedBatchSampler"] 7 | -------------------------------------------------------------------------------- /animegan/data/samplers/distributed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # Code is copy-pasted exactly as in torch.utils.data.distributed. 3 | # FIXME remove this once c10d fixes the bug it has 4 | import math 5 | import torch 6 | import torch.distributed as dist 7 | from torch.utils.data.sampler import Sampler 8 | 9 | 10 | class DistributedSampler(Sampler): 11 | """Sampler that restricts data loading to a subset of the dataset. 12 | It is especially useful in conjunction with 13 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 14 | process can pass a DistributedSampler instance as a DataLoader sampler, 15 | and load a subset of the original dataset that is exclusive to it. 16 | .. note:: 17 | Dataset is assumed to be of constant size. 18 | Arguments: 19 | dataset: Dataset used for sampling. 20 | num_replicas (optional): Number of processes participating in 21 | distributed training. 22 | rank (optional): Rank of the current process within num_replicas. 
23 | """ 24 | 25 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): 26 | if num_replicas is None: 27 | if not dist.is_available(): 28 | raise RuntimeError("Requires distributed package to be available") 29 | num_replicas = dist.get_world_size() 30 | if rank is None: 31 | if not dist.is_available(): 32 | raise RuntimeError("Requires distributed package to be available") 33 | rank = dist.get_rank() 34 | self.dataset = dataset 35 | self.num_replicas = num_replicas 36 | self.rank = rank 37 | self.epoch = 0 38 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 39 | self.total_size = self.num_samples * self.num_replicas 40 | self.shuffle = shuffle 41 | 42 | def __iter__(self): 43 | if self.shuffle: 44 | # deterministically shuffle based on epoch 45 | g = torch.Generator() 46 | g.manual_seed(self.epoch) 47 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 48 | else: 49 | indices = torch.arange(len(self.dataset)).tolist() 50 | 51 | # add extra samples to make it evenly divisible 52 | indices += indices[: (self.total_size - len(indices))] 53 | assert len(indices) == self.total_size 54 | 55 | # subsample 56 | offset = self.num_samples * self.rank 57 | indices = indices[offset: offset + self.num_samples] 58 | assert len(indices) == self.num_samples 59 | 60 | return iter(indices) 61 | 62 | def __len__(self): 63 | return self.num_samples 64 | 65 | def set_epoch(self, epoch): 66 | self.epoch = epoch 67 | -------------------------------------------------------------------------------- /animegan/data/samplers/grouped_batch_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | import itertools 3 | 4 | import torch 5 | from torch.utils.data.sampler import BatchSampler 6 | from torch.utils.data.sampler import Sampler 7 | 8 | 9 | class GroupedBatchSampler(BatchSampler): 10 | """ 11 | Wraps another sampler to yield a mini-batch of indices. 12 | It enforces that elements from the same group should appear in groups of batch_size. 13 | It also tries to provide mini-batches which follows an ordering which is 14 | as close as possible to the ordering from the original sampler. 15 | 16 | Arguments: 17 | sampler (Sampler): Base sampler. 18 | batch_size (int): Size of mini-batch. 19 | drop_uneven (bool): If ``True``, the sampler will drop the batches whose 20 | size is less than ``batch_size`` 21 | 22 | """ 23 | 24 | def __init__(self, sampler, group_ids, batch_size, drop_uneven=False): 25 | if not isinstance(sampler, Sampler): 26 | raise ValueError( 27 | "sampler should be an instance of " 28 | "torch.utils.data.Sampler, but got sampler={}".format(sampler) 29 | ) 30 | self.sampler = sampler 31 | self.group_ids = torch.as_tensor(group_ids) 32 | assert self.group_ids.dim() == 1 33 | self.batch_size = batch_size 34 | self.drop_uneven = drop_uneven 35 | 36 | self.groups = torch.unique(self.group_ids).sort(0)[0] 37 | 38 | self._can_reuse_batches = False 39 | 40 | def _prepare_batches(self): 41 | dataset_size = len(self.group_ids) 42 | # get the sampled indices from the sampler 43 | sampled_ids = torch.as_tensor(list(self.sampler)) 44 | # potentially not all elements of the dataset were sampled 45 | # by the sampler (e.g., DistributedSampler). 46 | # construct a tensor which contains -1 if the element was 47 | # not sampled, and a non-negative number indicating the 48 | # order where the element was sampled. 49 | # for example. 
if sampled_ids = [3, 1] and dataset_size = 5, 50 | # the order is [-1, 1, -1, 0, -1] 51 | order = torch.full((dataset_size,), -1, dtype=torch.int64) 52 | order[sampled_ids] = torch.arange(len(sampled_ids)) 53 | 54 | # get a mask with the elements that were sampled 55 | mask = order >= 0 56 | 57 | # find the elements that belong to each individual cluster 58 | clusters = [(self.group_ids == i) & mask for i in self.groups] 59 | # get relative order of the elements inside each cluster 60 | # that follows the order from the sampler 61 | relative_order = [order[cluster] for cluster in clusters] 62 | # with the relative order, find the absolute order in the 63 | # sampled space 64 | permutation_ids = [s[s.sort()[1]] for s in relative_order] 65 | # permute each cluster so that they follow the order from 66 | # the sampler 67 | permuted_clusters = [sampled_ids[idx] for idx in permutation_ids] 68 | 69 | # splits each cluster in batch_size, and merge as a list of tensors 70 | splits = [c.split(self.batch_size) for c in permuted_clusters] 71 | merged = tuple(itertools.chain.from_iterable(splits)) 72 | 73 | # now each batch internally has the right order, but 74 | # they are grouped by clusters. Find the permutation between 75 | # different batches that brings them as close as possible to 76 | # the order that we have in the sampler. For that, we will consider the 77 | # ordering as coming from the first element of each batch, and sort 78 | # correspondingly 79 | first_element_of_batch = [t[0].item() for t in merged] 80 | # get and inverse mapping from sampled indices and the position where 81 | # they occur (as returned by the sampler) 82 | inv_sampled_ids_map = {v: k for k, v in enumerate(sampled_ids.tolist())} 83 | # from the first element in each batch, get a relative ordering 84 | first_index_of_batch = torch.as_tensor( 85 | [inv_sampled_ids_map[s] for s in first_element_of_batch] 86 | ) 87 | 88 | # permute the batches so that they approximately follow the order 89 | # from the sampler 90 | permutation_order = first_index_of_batch.sort(0)[1].tolist() 91 | # finally, permute the batches 92 | batches = [merged[i].tolist() for i in permutation_order] 93 | 94 | if self.drop_uneven: 95 | kept = [] 96 | for batch in batches: 97 | if len(batch) == self.batch_size: 98 | kept.append(batch) 99 | batches = kept 100 | return batches 101 | 102 | def __iter__(self): 103 | if self._can_reuse_batches: 104 | batches = self._batches 105 | self._can_reuse_batches = False 106 | else: 107 | batches = self._prepare_batches() 108 | self._batches = batches 109 | return iter(batches) 110 | 111 | def __len__(self): 112 | if not hasattr(self, "_batches"): 113 | self._batches = self._prepare_batches() 114 | self._can_reuse_batches = True 115 | return len(self._batches) 116 | -------------------------------------------------------------------------------- /animegan/data/samplers/iteration_based_batch_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
2 | from torch.utils.data.sampler import BatchSampler 3 | 4 | 5 | class IterationBasedBatchSampler(BatchSampler): 6 | """ 7 | Wraps a BatchSampler, resampling from it until 8 | a specified number of iterations have been sampled 9 | """ 10 | 11 | def __init__(self, batch_sampler, num_iterations, start_iter=0): 12 | self.batch_sampler = batch_sampler 13 | self.num_iterations = num_iterations 14 | self.start_iter = start_iter 15 | 16 | def __iter__(self): 17 | iteration = self.start_iter 18 | while iteration <= self.num_iterations: 19 | # if the underlying sampler has a set_epoch method, like 20 | # DistributedSampler, used for making each process see 21 | # a different split of the dataset, then set it 22 | if hasattr(self.batch_sampler.sampler, "set_epoch"): 23 | self.batch_sampler.sampler.set_epoch(iteration) 24 | for batch in self.batch_sampler: 25 | iteration += 1 26 | if iteration > self.num_iterations: 27 | break 28 | yield batch 29 | 30 | def __len__(self): 31 | return self.num_iterations 32 | -------------------------------------------------------------------------------- /animegan/data/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Author: wanhui0729@gmail.com -------------------------------------------------------------------------------- /animegan/data/transforms/build.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Author: wanhui0729@gmail.com 3 | 4 | from . import transforms as T 5 | 6 | def build_transforms(cfg, is_train=True): 7 | tansform = T.Compose(cfg, is_train) 8 | return tansform -------------------------------------------------------------------------------- /animegan/data/transforms/transforms.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Author: wanhui0729@gmail.com 3 | 4 | import torch 5 | import random 6 | import numpy as np 7 | 8 | class Compose(): 9 | def __init__(self, cfg, is_train): 10 | self.cfg = cfg 11 | self.is_train = is_train 12 | 13 | def __call__(self, images): 14 | outputs = [] 15 | if self.is_train: 16 | assert len(images) == 3 17 | # real 18 | grayAndNormlize_real = GrayAndNormlize(mean=self.cfg.INPUT.PIXEL_MEAN) 19 | randomCrop_style = RandomCrop(self.cfg.INPUT.IMG_SIZE[0], self.cfg.INPUT.IMG_SIZE[1]) 20 | crop_images = randomCrop_style(images[0]) 21 | outputs += grayAndNormlize_real(crop_images, use_norm=False) 22 | # style and smooth 23 | randomCrop_style = RandomCrop(self.cfg.INPUT.IMG_SIZE[0], self.cfg.INPUT.IMG_SIZE[1]) 24 | grayAndNormlize_style = GrayAndNormlize(mean=self.cfg.INPUT.PIXEL_MEAN) 25 | crop_images = randomCrop_style(images[1:]) 26 | outputs += grayAndNormlize_style(crop_images) 27 | else: 28 | assert len(images) == 1 29 | resize_test = Resize(self.cfg.INPUT.IMG_SIZE) 30 | grayAndNormlize_test = GrayAndNormlize(mean=self.cfg.INPUT.PIXEL_MEAN) 31 | resize_images = resize_test(images) 32 | outputs += grayAndNormlize_test(resize_images) 33 | return outputs 34 | 35 | class Resize(): 36 | def __init__(self, size): 37 | self.size = size 38 | 39 | def get_size(self, image_size): 40 | w, h = image_size 41 | if h <= self.size[1]: 42 | h = self.size[1] 43 | else: 44 | x = h % 32 45 | h = h - x 46 | if w < self.size[0]: 47 | w = self.size[0] 48 | else: 49 | y = w % 32 50 | w = w - y 51 | return (w, h) 52 | 53 | def __call__(self, images): 54 | if not isinstance(images, list): 55 | images = [images] 56 | output = [] 57 | for image in 
images: 58 | size = self.get_size(image.size) 59 | image = image.resize(size) 60 | output.append(image) 61 | return output 62 | 63 | class RandomCrop(): 64 | def __init__(self, min_size, max_size): 65 | self.min_size = min_size 66 | self.max_size = max_size 67 | 68 | def get_size(self): 69 | size = random.randint(self.min_size, self.max_size) 70 | # h = w 71 | return (size, size) 72 | 73 | def __call__(self, images): 74 | if not isinstance(images, list): 75 | images = [images] 76 | w_crop, h_crop = self.get_size() 77 | w, h = images[0].size 78 | x = random.randint(0, w - w_crop) 79 | y = random.randint(0, h - h_crop) 80 | output = [] 81 | for image in images: 82 | image = image.crop((x, y, x + w_crop, y + h_crop)) 83 | output.append(image) 84 | return output 85 | 86 | class GrayAndNormlize(): 87 | def __init__(self, mean): 88 | self.mean = mean 89 | 90 | def __call__(self, images, use_norm=True): 91 | if not isinstance(images, list): 92 | images = [images] 93 | output = [] 94 | for image in images: 95 | image_color = np.array(image).astype(np.float32) 96 | image_gray = np.array(image.convert('L')).astype(np.float32) 97 | if use_norm: 98 | image_color[..., 0] += self.mean[0] 99 | image_color[..., 1] += self.mean[1] 100 | image_color[..., 2] += self.mean[2] 101 | image_color = torch.from_numpy(image_color.transpose((2, 0, 1))) 102 | image_gray = torch.from_numpy(np.asarray([image_gray, image_gray, image_gray])) 103 | output.append([image_color / 127.5 - 1.0, image_gray / 127.5 - 1.0]) 104 | return output -------------------------------------------------------------------------------- /animegan/engine/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Author: wanhui0729@gmail.com 3 | -------------------------------------------------------------------------------- /animegan/engine/inference.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Author: wanhui0729@gmail.com 3 | 4 | import os 5 | import cv2 6 | import torch 7 | import logging 8 | import datetime 9 | import numpy as np 10 | from tqdm import tqdm 11 | from animegan.utils.tm import Timer 12 | from animegan.modeling.utils import adjust_brightness_from_src_to_dst 13 | from animegan.utils.comm import is_main_process, get_world_size, synchronize 14 | 15 | def compute_on_dataset(model, data_loader, device): 16 | model.eval() 17 | results_dict = {} 18 | cpu_device = torch.device("cpu") 19 | cuda_device = torch.device("cuda") 20 | inference_timer = Timer() 21 | for i, batch in enumerate(tqdm(data_loader)): 22 | real_images, _, _, image_ids = batch 23 | # color images 24 | images = real_images[0] 25 | images = images.to(device) 26 | with torch.no_grad(): 27 | inference_timer.tic() 28 | output = model(images) 29 | if device == cuda_device: 30 | torch.cuda.synchronize() 31 | inference_timer.toc() 32 | output = [(img.to(cpu_device), o.to(cpu_device)) for img, o in zip(images, output)] 33 | results_dict.update( 34 | {img_id: result for img_id, result in zip(image_ids, output)} 35 | ) 36 | return results_dict, inference_timer.total_time 37 | 38 | def _save_prediction_images(predictions, output_folder, epoch): 39 | save_path = os.path.join(output_folder, str(epoch)) 40 | if not os.path.exists(save_path): 41 | os.makedirs(save_path, exist_ok=True) 42 | for img_id, (img, pred) in tqdm(predictions.items()): 43 | ori_img = (img.squeeze() + 1.) 
/ 2 * 255 44 | ori_img = ori_img.permute(1, 2, 0).numpy().clip(0, 255).astype(np.uint8) 45 | fake_img = (pred.squeeze() + 1.) / 2 * 255 46 | fake_img = fake_img.permute(1, 2, 0).numpy().clip(0, 255).astype(np.uint8) 47 | fake_img = adjust_brightness_from_src_to_dst(fake_img, ori_img) 48 | cv2.imwrite(os.path.join(save_path, f'{img_id}_a.jpg'), cv2.cvtColor(ori_img, cv2.COLOR_RGB2BGR)) 49 | cv2.imwrite(os.path.join(save_path, f'{img_id}_b.jpg'), cv2.cvtColor(fake_img, cv2.COLOR_RGB2BGR)) 50 | 51 | class Evaluator(object): 52 | def __init__(self, data_loader, device="cuda", output_folder=None, logger_name=None): 53 | self.data_loader = data_loader 54 | self.device = torch.device(device) 55 | self.logger = logging.getLogger(logger_name + ".inference") 56 | self.output_folder = output_folder 57 | 58 | def do_inference(self, model, epoch): 59 | num_devices = get_world_size() 60 | dataset = self.data_loader.dataset 61 | self.logger.info("Start evaluation on {} dataset({} images).".format(dataset.__class__.__name__, len(dataset))) 62 | predictions, total_time = compute_on_dataset(model, self.data_loader, self.device) 63 | # wait for all processes to complete before measuring the time 64 | synchronize() 65 | total_time_str = str(datetime.timedelta(seconds=total_time)) 66 | self.logger.info( 67 | "Total inference time: {} ({} s / img per device, on {} devices)".format( 68 | total_time_str, total_time * num_devices / len(dataset), num_devices 69 | ) 70 | ) 71 | 72 | if self.output_folder: 73 | self.logger.info("Start save generated images on {} dataset({} images).".format(dataset.__class__.__name__, len(dataset))) 74 | _save_prediction_images(predictions, self.output_folder, epoch-1) 75 | -------------------------------------------------------------------------------- /animegan/engine/trainer.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Author: wanhui0729@gmail.com 3 | 4 | import os 5 | import math 6 | import logging 7 | import torch.distributed as dist 8 | from torch.utils.tensorboard import SummaryWriter 9 | from animegan.utils.tm import * 10 | from animegan.utils.comm import synchronize 11 | from animegan.utils.logger import MetricLogger 12 | from animegan.utils.comm import get_world_size, is_main_process 13 | from animegan.modeling.loss import * 14 | from animegan.modeling.utils import rgbScaled 15 | 16 | def reduce_loss_dict(loss_dict): 17 | """ 18 | Reduce the loss dictionary from all processes so that process with rank 19 | 0 has the averaged results. Returns a dict with the same fields as 20 | loss_dict, after reduction. 
21 | """ 22 | world_size = get_world_size() 23 | if world_size < 2: 24 | return loss_dict 25 | with torch.no_grad(): 26 | loss_names = [] 27 | loss_values = [] 28 | # for k, v in loss_dict.items(): 29 | for k in sorted(loss_dict.keys()): 30 | loss_names.append(k) 31 | # all_losses.append(v) 32 | loss_values.append(loss_dict[k]) 33 | loss_values = torch.stack(loss_values, dim=0) 34 | dist.reduce(loss_values, dst=0) 35 | if is_main_process(): 36 | # only main process gets accumulated, so only divide by world_size in this case 37 | loss_values /= world_size 38 | reduced_losses = {k: v for k, v in zip(loss_names, loss_values)} 39 | return reduced_losses 40 | 41 | 42 | def do_train( 43 | models, 44 | cfg, 45 | data_loader, 46 | optimizers, 47 | schedulers, 48 | checkpointer, 49 | arguments, 50 | logger_name, 51 | epoch_size, 52 | evaluators=None, 53 | ): 54 | checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD 55 | print_period = cfg.SOLVER.PRINT_PERIOD 56 | test_period = cfg.SOLVER.TEST_PERIOD 57 | device = torch.device(cfg.MODEL.DEVICE) 58 | 59 | if is_main_process(): 60 | save_dir = checkpointer.save_dir 61 | out_dir = os.path.dirname(save_dir) 62 | tensorboard_dir = os.path.join(out_dir, 'tensorboard') 63 | writer = SummaryWriter(tensorboard_dir) 64 | 65 | start_iter = arguments["iteration"] 66 | model_backbone = models['backbone'] 67 | model_generator = models['generator'] 68 | model_discriminator = models['discriminator'] 69 | optimizer_generator = optimizers['generator'] 70 | optimizer_discriminator = optimizers['discriminator'] 71 | scheduler_generator = schedulers['generator'] 72 | scheduler_discriminator = schedulers['discriminator'] 73 | 74 | model_backbone.eval() 75 | model_generator.train() 76 | model_discriminator.train() 77 | 78 | logger = logging.getLogger(logger_name + ".trainer") 79 | logger.info("Start training") 80 | meters = MetricLogger(delimiter=" ") 81 | max_iter = len(data_loader) 82 | 83 | training_timer = Timer() 84 | training_timer.tic() 85 | batch_timer = Timer() 86 | batch_timer.tic() 87 | data_load_timer = Timer() 88 | data_load_timer.tic() 89 | _t = Timer() 90 | 91 | # epoch方式防止重复计算 92 | save_epoch = 0 93 | test_epoch = 0 94 | image_epoch = 0 95 | 96 | for iteration, (real_images, style_images, smooth_images, _) in enumerate(data_loader, start_iter): 97 | data_load_time = data_load_timer.toc() 98 | 99 | real_images_color = real_images[0] 100 | real_images_gray = real_images[1] 101 | style_images_color = style_images[0] 102 | style_images_gray = style_images[1] 103 | smooth_images_color = smooth_images[0] 104 | smooth_images_gray = smooth_images[1] 105 | arguments["iteration"] = iteration 106 | # 当前epoch度量 107 | epoch_current = math.ceil((iteration + 1) / epoch_size) 108 | 109 | # init阶段 110 | if epoch_current <= cfg.SOLVER.GENERATOR.INIT_EPOCH: 111 | # FP 112 | _t.tic() 113 | real_images_color = real_images_color.to(device) 114 | generated = model_generator(real_images_color) 115 | loss_init = init_loss(model_backbone, real_images_color, generated) 116 | INIT_FP_time = _t.toc() 117 | loss_dict = {"Init_loss": loss_init} 118 | # BP 119 | _t.tic() 120 | optimizer_generator.zero_grad() 121 | loss_init.backward() 122 | optimizer_generator.step() 123 | scheduler_generator.step() 124 | scheduler_discriminator.step() 125 | INIT_BP_time = _t.toc() 126 | meters.update(INIT_FP_time=INIT_FP_time, INIT_BP_time=INIT_BP_time) 127 | # 正常训练阶段 128 | else: 129 | real_images_color = real_images_color.to(device) 130 | style_images_color = style_images_color.to(device) 131 | 
style_images_gray = style_images_gray.to(device) 132 | smooth_images_gray = smooth_images_gray.to(device) 133 | generated = model_generator(real_images_color) 134 | loss_dict = {} 135 | if iteration % cfg.MODEL.COMMON.TRAINING_RATE == 0: 136 | # FP D 137 | _t.tic() 138 | generated_logit = model_discriminator(generated.detach()) 139 | anime_logit = model_discriminator(style_images_color) 140 | anime_gray_logit = model_discriminator(style_images_gray) 141 | smooth_logit = model_discriminator(smooth_images_gray) 142 | gp = gradient_panalty(model_discriminator, style_images_color, generated.detach()) 143 | loss_d = d_loss( 144 | generated_logit, 145 | anime_logit, 146 | anime_gray_logit, 147 | smooth_logit 148 | ) + gp 149 | D_FP_time = _t.toc() 150 | loss_dict.update({"D_loss": loss_d}) 151 | # BP D 152 | _t.tic() 153 | optimizer_discriminator.zero_grad() 154 | loss_d.backward() 155 | optimizer_discriminator.step() 156 | D_BP_time = _t.toc() 157 | meters.update(D_FP_time=D_FP_time, D_BP_time=D_BP_time) 158 | 159 | # FP G 160 | _t.tic() 161 | generated_logit = model_discriminator(generated) 162 | loss_g = g_loss( 163 | model_backbone, 164 | real_images_color, 165 | style_images_gray, 166 | generated, 167 | generated_logit 168 | ) 169 | G_FP_time = _t.toc() 170 | loss_dict.update({"G_loss": loss_g}) 171 | # BP G 172 | _t.tic() 173 | optimizer_generator.zero_grad() 174 | loss_g.backward() 175 | optimizer_generator.step() 176 | scheduler_generator.step() 177 | scheduler_discriminator.step() 178 | G_BP_time = _t.toc() 179 | meters.update(G_FP_time=G_FP_time, G_BP_time=G_BP_time) 180 | 181 | 182 | batch_time = batch_timer.toc() 183 | # loss记录 184 | if is_main_process(): 185 | # if iteration == 0: 186 | # writer.add_graph(model_backbone, real_images_color) 187 | # writer.add_graph(model_generator, real_images_color) 188 | # writer.add_graph(model_discriminator, real_images_color) 189 | if epoch_current != image_epoch: 190 | writer.add_image("images/real_color/{}".format(epoch_current), rgbScaled(real_images_color[0]).clamp(0, 1)) 191 | writer.add_image("images/real_gray/{}".format(epoch_current), rgbScaled(real_images_gray[0]).clamp(0, 1)) 192 | writer.add_image("images/style_color/{}".format(epoch_current), rgbScaled(style_images_color[0]).clamp(0, 1)) 193 | writer.add_image("images/style_gray/{}".format(epoch_current), rgbScaled(style_images_gray[0]).clamp(0, 1)) 194 | writer.add_image("images/smooth_color/{}".format(epoch_current), rgbScaled(smooth_images_color[0]).clamp(0, 1)) 195 | writer.add_image("images/smooth_gray/{}".format(epoch_current), rgbScaled(smooth_images_gray[0]).clamp(0, 1)) 196 | image_epoch = epoch_current 197 | writer.add_scalars('train/loss', loss_dict, iteration) 198 | # logger 199 | meters.update(batch_time=batch_time, data_time=data_load_time, **loss_dict) 200 | eta_seconds = meters.batch_time.global_avg * (max_iter - iteration) 201 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 202 | if ((iteration - 1) % print_period == 0) or iteration == max_iter - 1: 203 | epoch_string = "{epoch} | {epoch_iter:>4}/{epoch_size:<4}". 
\ 204 | format(epoch=epoch_current, epoch_iter=iteration % epoch_size, epoch_size=epoch_size) 205 | 206 | logger.info( 207 | meters.delimiter.join( 208 | [ 209 | "eta: {eta}", 210 | # "iter: {iter}", 211 | "epoch: {epoch}", 212 | "{meters}", 213 | "lr(G|D): {G_lr:.6f}|{D_lr:.6f}", 214 | "max mem: {memory:.0f}", 215 | ] 216 | ).format( 217 | eta=eta_string, 218 | # iter=iteration, 219 | epoch=epoch_string, 220 | meters=str(meters), 221 | G_lr=optimizer_generator.param_groups[0]["lr"], 222 | D_lr=optimizer_discriminator.param_groups[0]["lr"], 223 | memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0 if torch.cuda.is_available() else 0, 224 | ) 225 | ) 226 | 227 | # checkpointer 228 | if (epoch_current - 1) % checkpoint_period == 0 and epoch_current != save_epoch and epoch_current > 1: 229 | checkpointer.save("model_{:05d}".format(epoch_current - 1), **arguments) 230 | save_epoch = epoch_current 231 | 232 | # test 233 | if iteration == (max_iter - 1) or \ 234 | ((epoch_current - 1) % test_period == 0 and epoch_current != test_epoch and epoch_current > 1): 235 | # 最后一个epoch训练完成后做正确显示 236 | epoch_current_show = epoch_current if iteration == (max_iter - 1) else (epoch_current - 1) 237 | if evaluators: 238 | model_generator.eval() 239 | for evaluator in evaluators: 240 | result = evaluator.do_inference(model_generator, epoch_current) 241 | # 只有主线程返回 242 | if result: 243 | # 用于解析日志标志 244 | logger.info("(*^_^*)") 245 | logger.info("Test model at {} dataset at {} epoch". 246 | format(evaluator.data_loader.dataset.__class__.__name__, epoch_current_show)) 247 | # synchronize after test 248 | synchronize() 249 | model_generator.train() 250 | test_epoch = epoch_current 251 | 252 | batch_timer.tic() 253 | data_load_timer.tic() 254 | checkpointer.save("model_final", **arguments) 255 | total_training_time = training_timer.toc() 256 | total_time_str = str(datetime.timedelta(seconds=total_training_time)) 257 | logger.info( 258 | "Total training time: {} ({:.4f} s / it)".format( 259 | total_time_str, total_training_time / (max_iter) 260 | ) 261 | ) 262 | if is_main_process(): 263 | writer.close() 264 | # synchronize after trainer 265 | synchronize() 266 | -------------------------------------------------------------------------------- /animegan/lib/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Author: wanhui0729@gmail.com -------------------------------------------------------------------------------- /animegan/lib/trainer.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Author: wanhui0729@gmail.com 3 | 4 | import os 5 | import torch 6 | import logging 7 | from animegan.engine.inference import Evaluator 8 | from animegan.engine.trainer import do_train 9 | from animegan.modeling.build import build_model 10 | from animegan.utils.comm import synchronize, get_rank 11 | from animegan.utils.checkpoint import ModelCheckpointer 12 | from animegan.data.build import make_datasets, make_data_loader 13 | from animegan.solver.build import make_optimizer_generator, make_optimizer_discriminator 14 | from animegan.solver.build import make_lr_scheduler_generator, make_lr_scheduler_discriminator 15 | 16 | 17 | def train(cfg, local_rank, distributed, logger_name, output_dir): 18 | model_backbone, model_generator, model_discriminator = build_model(cfg) 19 | logger = logging.getLogger(logger_name) 20 | logger.info(f"model backbone:\n{model_backbone}") 21 | logger.info(f"model 
generator:\n{model_generator}") 22 | logger.info(f"model discriminator:\n{model_discriminator}") 23 | 24 | device = torch.device(cfg.MODEL.DEVICE) 25 | model_backbone.to(device) 26 | model_generator.to(device) 27 | model_discriminator.to(device) 28 | 29 | arguments = {} 30 | arguments["iteration"] = 0 31 | 32 | datasets, epoch_sizes = make_datasets(cfg, is_train=True) 33 | # train阶段dataset合并成一个 34 | epoch_size = epoch_sizes[0] 35 | 36 | optimizer_generator = make_optimizer_generator(cfg, model_generator) 37 | optimizer_discriminator = make_optimizer_discriminator(cfg, model_discriminator) 38 | # TODO: epoch_size优化 39 | scheduler_generator = make_lr_scheduler_generator(cfg, optimizer_generator, epoch_size) 40 | scheduler_discriminator = make_lr_scheduler_discriminator(cfg, optimizer_discriminator, epoch_size) 41 | 42 | if distributed: 43 | model_backbone = torch.nn.parallel.DistributedDataParallel( 44 | model_backbone, device_ids=[local_rank], output_device=local_rank, 45 | broadcast_buffers=True, 46 | ) 47 | model_generator = torch.nn.parallel.DistributedDataParallel( 48 | model_generator, device_ids=[local_rank], output_device=local_rank, 49 | broadcast_buffers=True, 50 | ) 51 | model_discriminator = torch.nn.parallel.DistributedDataParallel( 52 | model_discriminator, device_ids=[local_rank], output_device=local_rank, 53 | broadcast_buffers=True, 54 | ) 55 | synchronize() 56 | models = { 57 | "generator": model_generator, 58 | "discriminator": model_discriminator 59 | } 60 | optimizers = { 61 | "generator": optimizer_generator, 62 | "discriminator": optimizer_discriminator 63 | } 64 | schedulers = { 65 | "generator": scheduler_generator, 66 | "discriminator": scheduler_discriminator 67 | } 68 | 69 | checkpointer = ModelCheckpointer( 70 | models=models, 71 | optimizers=optimizers, 72 | schedulers=schedulers, 73 | save_dir=output_dir, 74 | logger_name=logger_name 75 | ) 76 | 77 | extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT, cfg.MODEL.TRANSFER_LEARNING) 78 | arguments.update(extra_checkpoint_data) 79 | 80 | data_loader = make_data_loader( 81 | cfg=cfg, 82 | datasets=datasets, 83 | epoch_sizes=epoch_sizes, 84 | is_train=True, 85 | is_distributed=distributed, 86 | start_iter=arguments["iteration"] 87 | ) 88 | 89 | evaluators = get_evaluator(cfg, distributed, logger_name, output_dir=output_dir) 90 | models.update({"backbone": model_backbone}) 91 | 92 | do_train( 93 | models=models, 94 | cfg=cfg, 95 | data_loader=data_loader, 96 | optimizers=optimizers, 97 | schedulers=schedulers, 98 | checkpointer=checkpointer, 99 | arguments=arguments, 100 | logger_name=logger_name, 101 | epoch_size=epoch_size, 102 | evaluators=evaluators, 103 | ) 104 | 105 | def get_evaluator(cfg, distributed, logger_name, dataEntrance=None, output_dir=None): 106 | torch.cuda.empty_cache() 107 | 108 | output_folders = list() 109 | datasetsInfo = cfg.DATASETS.TEST 110 | 111 | if output_dir: 112 | for datasetInfo in datasetsInfo: 113 | _output_folder = os.path.join(output_dir, 114 | "inference", 115 | datasetInfo.get('factory')+'_'+datasetInfo.get('split')) 116 | if get_rank() == 0: 117 | os.makedirs(_output_folder, exist_ok=True) 118 | output_folders.append(_output_folder) 119 | datasets_test, epoch_sizes = make_datasets(cfg, is_train=False, dataEntrance=dataEntrance) 120 | data_loaders_test = make_data_loader(cfg, datasets=datasets_test, epoch_sizes=epoch_sizes, is_train=False, is_distributed=distributed) 121 | evaluators = list() 122 | for output_folder, data_loader_test in zip(output_folders, 
data_loaders_test):
123 |         evaluators.append(
124 |             Evaluator(
125 |                 data_loader=data_loader_test,
126 |                 logger_name=logger_name,
127 |                 device=cfg.MODEL.DEVICE,
128 |                 output_folder=output_folder,
129 |             )
130 |         )
131 |     return evaluators
--------------------------------------------------------------------------------
/animegan/modeling/__init__.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | # Author: wanhui0729@gmail.com
3 | 
--------------------------------------------------------------------------------
/animegan/modeling/backbone.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | # Author: wanhui0729@gmail.com
3 | 
4 | import torch
5 | import torch.nn as nn
6 | from torchvision.models.utils import load_state_dict_from_url
7 | from animegan.modeling import registry
8 | 
9 | model_urls = {
10 |     'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
11 |     'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
12 |     'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
13 |     'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
14 |     'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
15 |     'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
16 |     'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
17 |     'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
18 | }
19 | 
20 | cfgs = {
21 |     'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
22 |     'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
23 |     'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
24 |     'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
25 | }
26 | 
27 | class VGG(nn.Module):
28 |     '''
29 |     Keep only the feature-extraction layers (no classifier head).
30 |     '''
31 |     def __init__(self, features):
32 |         super().__init__()
33 |         self.features = features
34 | 
35 |     def forward(self, x):
36 |         x = self.features(x)
37 |         return x
38 | 
39 | # Customized: output conv4_4 without its activation (conv4_4_no_activation)
40 | def make_layers(cfg, batch_norm=False):
41 |     layers = []
42 |     in_channels = 3
43 |     conv_stage = 1
44 |     inner_stage = 0
45 |     for v in cfg:
46 |         if v == 'M':
47 |             conv_stage += 1
48 |             inner_stage = 0
49 |             layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
50 |         else:
51 |             inner_stage += 1
52 |             conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
53 |             if conv_stage == 4 and inner_stage == 4:
54 |                 layers += [conv2d]
55 |                 # feature return point
56 |                 break
57 |             else:
58 |                 if batch_norm:
59 |                     layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
60 |                 else:
61 |                     layers += [conv2d, nn.ReLU(inplace=True)]
62 |             in_channels = v
63 |     return nn.Sequential(*layers)
64 | 
65 | @registry.BACKBONES.register("VGG19")
66 | def build_vgg_backbones(cfg):
67 |     backbone_weight = cfg.MODEL.BACKBONE.WEIGHT
68 |     model = VGG(make_layers(cfgs['E']))
69 |     if not backbone_weight:
70 |         state_dict = load_state_dict_from_url(model_urls['vgg19'], progress=True)
71 |     else:
72 |         state_dict = torch.load(backbone_weight)
73 |     model.load_state_dict(state_dict, strict=False)
74 |     return model
75 | 
76 | def build_backbone(cfg):
77 |     assert cfg.MODEL.BACKBONE.BODY in registry.BACKBONES, \
78 |         f"cfg.MODEL.BACKBONE.BODY: {cfg.MODEL.BACKBONE.BODY} is not registered in registry"
79 |     return registry.BACKBONES[cfg.MODEL.BACKBONE.BODY](cfg)
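Note on the backbone: `build_backbone` resolves the backbone purely through the registry key `cfg.MODEL.BACKBONE.BODY`, and `make_layers` truncates VGG19 at conv4_4 before its activation. A minimal, hypothetical usage sketch (not a file in this repository; it assumes the default config object `animegan.configs.cfg`, network access to download the torchvision VGG19 weights, and a 256x256 input) might look like this. On newer torchvision releases the `load_state_dict_from_url` helper imported above has moved from `torchvision.models.utils` to `torch.hub`.
```
# Hypothetical sketch: build the registry-selected VGG19 feature extractor and
# inspect the conv4_4 feature map it returns (no ReLU applied to the last conv).
import torch
from animegan.configs import cfg
from animegan.modeling.backbone import build_backbone

backbone = build_backbone(cfg)   # default cfg.MODEL.BACKBONE.BODY is "VGG19"
backbone.eval()
with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 256, 256))
# Three 2x2 max-pools precede conv4_4, so a 256x256 input should give
# feats.shape == torch.Size([1, 512, 32, 32]).
```
This is the same backbone output that `do_train` in `animegan/engine/trainer.py` passes into `init_loss` and `g_loss`.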
-------------------------------------------------------------------------------- /animegan/modeling/build.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Author: wanhui0729@gmail.com 3 | 4 | from .backbone import build_backbone 5 | from .generator import build_generator 6 | from .discriminator import build_discriminator 7 | 8 | 9 | def build_model(cfg): 10 | backbone = build_backbone(cfg) 11 | generator = build_generator(cfg) 12 | discriminator = build_discriminator(cfg) 13 | return backbone, generator, discriminator -------------------------------------------------------------------------------- /animegan/modeling/discriminator.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Author: wanhui0729@gmail.com 3 | 4 | import torch 5 | from torch import nn as nn 6 | from torch.nn.utils import spectral_norm 7 | from animegan.modeling import registry 8 | from animegan.modeling.layers import Layer_Norm 9 | 10 | def conv_sn(in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True): 11 | return spectral_norm( 12 | nn.Conv2d( 13 | in_channels=in_channels, 14 | out_channels=out_channels, 15 | kernel_size=kernel_size, 16 | stride=stride, 17 | padding=padding, 18 | bias=bias 19 | ) 20 | ) 21 | 22 | 23 | class D_Net(nn.Module): 24 | def __init__(self, in_channels, channels, n_dis): 25 | super().__init__() 26 | channels = channels // 2 27 | self.first = nn.Sequential( 28 | conv_sn(in_channels, channels, kernel_size=3, stride=1, padding=1, bias=False), 29 | nn.LeakyReLU(0.2) 30 | ) 31 | 32 | second_list = [] 33 | channels_in = channels 34 | for _ in range(n_dis): 35 | second_list += [ 36 | conv_sn(channels_in, channels * 2, kernel_size=3, stride=2, padding=1, bias=False), 37 | nn.LeakyReLU(0.2) 38 | ] 39 | second_list += [ 40 | conv_sn(channels * 2, channels * 4, kernel_size=3, stride=1, padding=1, bias=False), 41 | Layer_Norm(), 42 | nn.LeakyReLU(0.2) 43 | ] 44 | channels_in = channels * 4 45 | channels *= 2 46 | self.second = nn.Sequential(*second_list) 47 | 48 | self.third = nn.Sequential( 49 | conv_sn(channels_in, channels * 2, kernel_size=3, stride=1, padding=1), 50 | Layer_Norm(), 51 | nn.LeakyReLU(0.2), 52 | conv_sn(channels * 2, 1, kernel_size=3, stride=1, padding=1), 53 | ) 54 | self._initialize_weights() 55 | 56 | def _initialize_weights(self): 57 | for m in self.modules(): 58 | if isinstance(m, nn.Conv2d): 59 | nn.init.normal_(m.weight, mean=0., std=0.02) 60 | if m.bias is not None: 61 | nn.init.constant_(m.bias, 0.0) 62 | 63 | def forward(self, x): 64 | x = self.first(x) 65 | x = self.second(x) 66 | x = self.third(x) 67 | return x 68 | 69 | @registry.DISCRIMINATOR.register("Base-256") 70 | def build_base_discriminator(cfg): 71 | in_channels = cfg.MODEL.DISCRIMINATOR.IN_CHANNELS 72 | channels = cfg.MODEL.DISCRIMINATOR.CHANNELS 73 | n_dis = cfg.MODEL.DISCRIMINATOR.N_DIS 74 | return D_Net(in_channels, channels, n_dis) 75 | 76 | def build_discriminator(cfg): 77 | assert cfg.MODEL.DISCRIMINATOR.BODY in registry.DISCRIMINATOR, \ 78 | f"cfg.MODEL.DISCRIMINATOR.BODY: {cfg.MODEL.DISCRIMINATOR.CONV_BODY} are not registered in registry" 79 | return registry.DISCRIMINATOR[cfg.MODEL.DISCRIMINATOR.BODY](cfg) -------------------------------------------------------------------------------- /animegan/modeling/generator.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Author: wanhui0729@gmail.com 3 | 4 | import math 
5 | from torch import nn as nn 6 | from animegan.modeling import registry 7 | from animegan.modeling.layers import Layer_Norm 8 | 9 | def truncated_normal_(tensor, mean=0., std=0.1): 10 | size = tensor.shape 11 | tmp = tensor.new_empty(size + (4,)).normal_() 12 | valid = (tmp < 2) & (tmp > -2) 13 | ind = valid.max(-1, keepdim=True)[1] 14 | tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1)) 15 | tensor.data.mul_(std).add_(mean) 16 | return tensor 17 | 18 | class Conv2DNormLReLU(nn.Module): 19 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=False): 20 | super().__init__() 21 | self.Conv = nn.Conv2d(in_channels, 22 | out_channels, 23 | kernel_size=kernel_size, 24 | stride=stride, 25 | padding=padding, 26 | padding_mode='reflect', 27 | bias=bias) 28 | self.LayerNorm = Layer_Norm() 29 | self.LRelu = nn.LeakyReLU(0.2) 30 | 31 | def forward(self, x): 32 | x = self.Conv(x) 33 | x = self.LayerNorm(x) 34 | x = self.LRelu(x) 35 | return x 36 | 37 | class InvertedRes_Block(nn.Module): 38 | def __init__(self, in_channels, out_channels, expansion_ratio, stride): 39 | super().__init__() 40 | self.add_op = (in_channels == out_channels and stride == 1) 41 | bottleneck_dim = round(expansion_ratio * in_channels) 42 | # pw 43 | self.pw = Conv2DNormLReLU(in_channels, bottleneck_dim, kernel_size=1) 44 | # dw 45 | self.dw = nn.Sequential( 46 | nn.Conv2d( 47 | bottleneck_dim, 48 | bottleneck_dim, 49 | kernel_size=3, 50 | stride=stride, 51 | padding=1, 52 | groups=bottleneck_dim, 53 | padding_mode='reflect' 54 | ), 55 | Layer_Norm(), 56 | nn.LeakyReLU(0.2) 57 | ) 58 | # pw & linear 59 | self.pw_linear = nn.Sequential( 60 | nn.Conv2d(bottleneck_dim, out_channels, kernel_size=1, bias=False, padding_mode='reflect'), 61 | Layer_Norm() 62 | ) 63 | 64 | def forward(self, x): 65 | out = self.pw(x) 66 | out = self.dw(out) 67 | out = self.pw_linear(out) 68 | if self.add_op: 69 | out += x 70 | return out 71 | 72 | class G_Net(nn.Module): 73 | def __init__(self, in_channels): 74 | super().__init__() 75 | self.A = nn.Sequential( 76 | Conv2DNormLReLU(in_channels, 32, kernel_size=7, padding=3), 77 | Conv2DNormLReLU(32, 64, kernel_size=3, stride=2, padding=1), 78 | Conv2DNormLReLU(64, 64, kernel_size=3, padding=1) 79 | ) 80 | self.B = nn.Sequential( 81 | Conv2DNormLReLU(64, 128, kernel_size=3, stride=2, padding=1), 82 | Conv2DNormLReLU(128, 128, kernel_size=3, padding=1), 83 | ) 84 | self.C = nn.Sequential( 85 | Conv2DNormLReLU(128, 128, kernel_size=3, padding=1), 86 | InvertedRes_Block(128, 256, 2, 1), 87 | InvertedRes_Block(256, 256, 2, 1), 88 | InvertedRes_Block(256, 256, 2, 1), 89 | InvertedRes_Block(256, 256, 2, 1), 90 | Conv2DNormLReLU(256, 128, kernel_size=3, padding=1) 91 | ) 92 | self.D = nn.Sequential( 93 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), 94 | Conv2DNormLReLU(128, 128, kernel_size=3, padding=1), 95 | Conv2DNormLReLU(128, 128, kernel_size=3, padding=1) 96 | ) 97 | self.E = nn.Sequential( 98 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), 99 | Conv2DNormLReLU(128, 64, kernel_size=3, padding=1), 100 | Conv2DNormLReLU(64, 64, kernel_size=3, padding=1), 101 | Conv2DNormLReLU(64, 32, kernel_size=7, padding=3) 102 | ) 103 | self.F = nn.Sequential( 104 | nn.Conv2d(32, 3, kernel_size=1, stride=1, bias=False, padding_mode='reflect'), 105 | nn.Tanh() 106 | ) 107 | # self._initialize_weights() 108 | 109 | def _initialize_weights(self): 110 | for m in self.modules(): 111 | if isinstance(m, nn.Conv2d): 112 | # variance_scaling_initializer 
113 | # https://docs.w3cub.com/tensorflow~python/tf/contrib/layers/variance_scaling_initializer 114 | truncated_normal_(m.weight, mean=0., std=math.sqrt(1.3 * 2.0 / m.in_channels)) 115 | if m.bias is not None: 116 | nn.init.constant_(m.bias, 0) 117 | 118 | def forward(self, x): 119 | x = self.A(x) 120 | x = self.B(x) 121 | x = self.C(x) 122 | x = self.D(x) 123 | x = self.E(x) 124 | x = self.F(x) 125 | return x 126 | 127 | @registry.GENERATOR.register("Base-256") 128 | def build_base_generator(cfg): 129 | in_channels = cfg.MODEL.GENERATOR.IN_CHANNELS 130 | return G_Net(in_channels) 131 | 132 | def build_generator(cfg): 133 | assert cfg.MODEL.GENERATOR.BODY in registry.GENERATOR, \ 134 | f"cfg.MODEL.GENERATOR.BODY: {cfg.MODEL.BACKBONE.CONV_BODY} are not registered in registry" 135 | return registry.GENERATOR[cfg.MODEL.GENERATOR.BODY](cfg) -------------------------------------------------------------------------------- /animegan/modeling/layers.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Author: wanhui0729@gmail.com 3 | 4 | from torch import nn as nn 5 | from torch.nn import functional as F 6 | 7 | class Layer_Norm(nn.Module): 8 | def __init__(self): 9 | super().__init__() 10 | 11 | def forward(self, x): 12 | return F.layer_norm(x, x.size()[1:]) -------------------------------------------------------------------------------- /animegan/modeling/loss.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Author: wanhui0729@gmail.com 3 | 4 | import torch 5 | from animegan.configs import cfg 6 | from torch.nn import functional as F 7 | from .utils import gram, rgb2yuv, prepare_feature_extract, color_2_gray 8 | 9 | def init_loss(model_backbone, real_images_color, generated): 10 | fake = generated 11 | real_feature_map = model_backbone(prepare_feature_extract(real_images_color)) 12 | fake_feature_map = model_backbone(prepare_feature_extract(fake)) 13 | loss = F.l1_loss(real_feature_map, fake_feature_map, reduction='mean') 14 | return loss * cfg.MODEL.COMMON.WEIGHT_G_CON 15 | 16 | def g_loss(model_backbone, real_images_color, style_images_gray, generated, generated_logit): 17 | fake = generated 18 | real_feature_map = model_backbone(prepare_feature_extract(real_images_color)) 19 | fake_feature_map = model_backbone(prepare_feature_extract(fake)) 20 | fake_feature_map_gray = model_backbone(prepare_feature_extract(color_2_gray(fake))) 21 | anime_feature_map = model_backbone(prepare_feature_extract(style_images_gray)) 22 | 23 | c_loss = F.l1_loss(real_feature_map, fake_feature_map, reduction='mean') 24 | 25 | if cfg.MODEL.LOSS.S_LOSS_COLOR2GRAY: 26 | s_loss = F.l1_loss(gram(anime_feature_map), gram(fake_feature_map_gray), reduction='mean') 27 | else: 28 | s_loss = F.l1_loss(gram(anime_feature_map), gram(fake_feature_map), reduction='mean') 29 | 30 | real_images_color_yuv = rgb2yuv(real_images_color) 31 | fake_yuv = rgb2yuv(fake) 32 | color_loss = F.l1_loss(real_images_color_yuv[..., 0], fake_yuv[..., 0], reduction='mean') + \ 33 | F.smooth_l1_loss(real_images_color_yuv[..., 1], fake_yuv[..., 1], reduction='mean') + \ 34 | F.smooth_l1_loss(real_images_color_yuv[..., 2], fake_yuv[..., 2], reduction='mean') 35 | 36 | 37 | dh_input, dh_target = fake[:, :, :-1, :], fake[:, :, 1:, :] 38 | dw_input, dw_target = fake[:, :, :, :-1], fake[:, :, :, 1:] 39 | tv_loss = F.mse_loss(dh_input, dh_target, reduction='mean') + \ 40 | F.mse_loss(dw_input, dw_target, reduction='mean') 41 | 42 | loss_func 
= cfg.MODEL.COMMON.GAN_TYPE 43 | if loss_func == 'wgan-gp' or loss_func == 'wgan-lp': 44 | fake_loss = - torch.mean(generated_logit) 45 | elif loss_func == 'lsgan': 46 | fake_loss = torch.mean(torch.square(generated_logit - 1.0)) 47 | elif loss_func == 'gan' or loss_func == 'dragan': 48 | fake_loss = F.binary_cross_entropy_with_logits(generated_logit, torch.ones_like(generated_logit), reduction='mean') 49 | elif loss_func == 'hinge': 50 | fake_loss = - torch.mean(generated_logit) 51 | else: 52 | raise NotImplementedError 53 | return cfg.MODEL.COMMON.WEIGHT_G_CON * c_loss + \ 54 | cfg.MODEL.COMMON.WEIGHT_G_STYLE * s_loss + \ 55 | cfg.MODEL.COMMON.WEIGHT_G_COLOR * color_loss + \ 56 | cfg.MODEL.COMMON.WEIGHT_G_TV * tv_loss + \ 57 | cfg.MODEL.COMMON.WEIGHT_ADV_G * fake_loss 58 | 59 | from torch import autograd 60 | from torch.autograd import Variable 61 | def gradient_panalty(model_discriminator, style_images_color, generted): 62 | loss_func = cfg.MODEL.COMMON.GAN_TYPE 63 | if loss_func not in ['dragan', 'wgan-gp', 'wgan-lp']: 64 | return 0 65 | if loss_func == 'dragan': 66 | eps = torch.empty_like(style_images_color).uniform_(0, 1) 67 | x_var = style_images_color.var() 68 | x_std = torch.sqrt(x_var) 69 | generted = style_images_color + 0.5 * x_std * eps 70 | b, c, h, w = style_images_color.shape 71 | device = style_images_color.device 72 | alpha = torch.Tensor(b, 1, 1, 1).uniform_(0, 1) 73 | alpha = alpha.expand(b, c, h, w).to(device) 74 | interpolated = style_images_color + alpha * (generted - style_images_color) 75 | 76 | # define it to calculate gradient 77 | interpolated = Variable(interpolated, requires_grad=True) 78 | # calculate probability of interpolated examples 79 | prob_interpolated = model_discriminator(interpolated) 80 | # calculate gradients of probabilities with respect to examples 81 | gradients = autograd.grad( 82 | outputs=prob_interpolated, 83 | inputs=interpolated, 84 | grad_outputs=torch.ones(prob_interpolated.size()).to(device), 85 | create_graph=True, 86 | retain_graph=True 87 | )[0] 88 | GP = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * cfg.MODEL.COMMON.LD 89 | return GP 90 | 91 | def d_loss(generated_logit, anime_logit, anime_gray_logit, smooth_logit): 92 | loss_func = cfg.MODEL.COMMON.GAN_TYPE 93 | if loss_func == 'wgan-gp' or loss_func == 'wgan-lp': 94 | real_loss = - torch.mean(anime_logit) 95 | gray_loss = torch.mean(anime_gray_logit) 96 | fake_loss = torch.mean(generated_logit) 97 | real_blur_loss = torch.mean(smooth_logit) 98 | elif loss_func == 'lsgan': 99 | real_loss = torch.mean(torch.square(anime_logit - 1.0)) 100 | gray_loss = torch.mean(torch.square(anime_gray_logit)) 101 | fake_loss = torch.mean(torch.square(generated_logit)) 102 | real_blur_loss = torch.mean(torch.square(smooth_logit)) 103 | elif loss_func == 'gan' or loss_func == 'dragan': 104 | real_loss = F.binary_cross_entropy_with_logits(anime_logit, torch.ones_like(anime_logit), reduction='mean') 105 | gray_loss = F.binary_cross_entropy_with_logits(anime_gray_logit, torch.zeros_like(anime_gray_logit), reduction='mean') 106 | fake_loss = F.binary_cross_entropy_with_logits(generated_logit, torch.zeros_like(generated_logit), reduction='mean') 107 | real_blur_loss = F.binary_cross_entropy_with_logits(smooth_logit, torch.zeros_like(smooth_logit), reduction='mean') 108 | elif loss_func == 'hinge': 109 | real_loss = torch.mean(torch.relu(1.0 - anime_logit)) 110 | gray_loss = torch.mean(torch.relu(1.0 + anime_gray_logit)) 111 | fake_loss = torch.mean(torch.relu(1.0 + generated_logit)) 112 | 
real_blur_loss = torch.mean(torch.relu(1.0 + smooth_logit)) 113 | else: 114 | raise NotImplementedError 115 | return cfg.MODEL.COMMON.WEIGHT_ADV_D * ( 116 | cfg.MODEL.COMMON.WEIGHT_D_LOSS_REAL * real_loss + 117 | cfg.MODEL.COMMON.WEIGHT_D_LOSS_FAKE * fake_loss + 118 | cfg.MODEL.COMMON.WEIGHT_D_LOSS_GRAY * gray_loss + 119 | cfg.MODEL.COMMON.WEIGHT_D_LOSS_BLUR * real_blur_loss 120 | ) -------------------------------------------------------------------------------- /animegan/modeling/registry.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Author: wanhui0729@gmail.com 3 | 4 | from animegan.utils.registry import Registry 5 | 6 | BACKBONES = Registry() 7 | GENERATOR = Registry() 8 | DISCRIMINATOR = Registry() -------------------------------------------------------------------------------- /animegan/modeling/utils.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Author: wanhui0729@gmail.com 3 | 4 | import torch 5 | import numpy as np 6 | 7 | yuv_from_rgb = np.array([[0.299, 0.587, 0.114], 8 | [-0.14714119, -0.28886916, 0.43601035], 9 | [0.61497538, -0.51496512, -0.10001026]]) 10 | 11 | # Pretrained 12 | # rgb 13 | # feature_extract_mean = [0.485, 0.456, 0.406] 14 | # feature_extract_std = [0.229, 0.224, 0.225] 15 | feature_extract_mean = [123.68, 116.779, 103.939] 16 | 17 | def rgbScaled(x): 18 | # [-1, 1] ~ [0, 1] 19 | return (x + 1.0) / 2.0 20 | 21 | def rgb2yuv(x): 22 | x = rgbScaled(x) 23 | x = x.permute([0, 2, 3, 1]) 24 | k_yuv_from_rgb = torch.from_numpy(yuv_from_rgb.T).to(x.dtype).to(x.device) 25 | yuv = torch.matmul(x, k_yuv_from_rgb) 26 | # yuv = yuv.permute([0, 3, 1, 2]) 27 | return yuv 28 | 29 | def color_2_gray(x): 30 | x = x.permute([0, 2, 3, 1]) 31 | k_color_2_gray = torch.Tensor([[0.299], [0.587], [0.114]]).to(x.dtype).to(x.device) 32 | gray = torch.matmul(x, k_color_2_gray) 33 | gray = torch.cat([gray, gray, gray], dim=-1) 34 | gray = gray.permute([0, 3, 1, 2]) 35 | return gray 36 | 37 | def gram(x): 38 | # [b, c, h, w] -> [b, h, w, c] 39 | x = x.permute([0, 2, 3, 1]) 40 | shape = x.shape 41 | b = shape[0] 42 | c = shape[3] 43 | x = torch.reshape(x, [b, -1, c]) 44 | return torch.matmul(x.permute(0, 2, 1), x) / (x.numel() // b) 45 | 46 | def prepare_feature_extract(rgb): 47 | # [-1, 1] ~ [0, 255] 48 | rgb_scaled = rgbScaled(rgb) * 255.0 49 | R, G, B = torch.chunk(rgb_scaled, 3, 1) 50 | feature_extract_input = torch.cat( 51 | [ 52 | (B - feature_extract_mean[2]), 53 | (G - feature_extract_mean[1]), 54 | (R - feature_extract_mean[0]), 55 | ], 56 | dim=1 57 | ) 58 | return feature_extract_input 59 | 60 | # Calculates the average brightness in the specified irregular image 61 | def calculate_average_brightness(img): 62 | # Average value of three color channels 63 | R = img[..., 0].mean() 64 | G = img[..., 1].mean() 65 | B = img[..., 2].mean() 66 | 67 | brightness = 0.299 * R + 0.587 * G + 0.114 * B 68 | return brightness, B, G, R 69 | 70 | # Adjusting the average brightness of the target image to the average brightness of the source image 71 | def adjust_brightness_from_src_to_dst(dst, src): 72 | brightness1, B1, G1, R1 = calculate_average_brightness(src) 73 | brightness2, B2, G2, R2 = calculate_average_brightness(dst) 74 | brightness_difference = brightness1 / brightness2 75 | 76 | # According to the average display brightness 77 | dstf = dst * brightness_difference 78 | 79 | # According to the average value of the three-color channel 80 | # dstf = 
dst.copy().astype(np.float32) 81 | # dstf[..., 0] = dst[..., 0] * (B1 / B2) 82 | # dstf[..., 1] = dst[..., 1] * (G1 / G2) 83 | # dstf[..., 2] = dst[..., 2] * (R1 / R2) 84 | 85 | # To limit the results and prevent crossing the border, 86 | # it must be converted to uint8, otherwise the default result is float32, and errors will occur. 87 | dstf = np.clip(dstf, 0, 255) 88 | dstf = np.uint8(dstf) 89 | 90 | return dstf -------------------------------------------------------------------------------- /animegan/solver/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Author: wanhui0729@gmail.com 3 | -------------------------------------------------------------------------------- /animegan/solver/build.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Author: wanhui0729@gmail.com 3 | 4 | import torch 5 | from .lr_scheduler import WarmupMultiStepLR 6 | 7 | def make_optimizer_generator(cfg, model): 8 | lr = cfg.SOLVER.GENERATOR.BASE_LR 9 | optimizer = torch.optim.Adam(model.parameters(), lr, betas=(0.5, 0.999)) 10 | return optimizer 11 | 12 | def make_optimizer_discriminator(cfg, model): 13 | lr = cfg.SOLVER.DISCRIMINATOR.BASE_LR 14 | optimizer = torch.optim.Adam(model.parameters(), lr, betas=(0.5, 0.999)) 15 | return optimizer 16 | 17 | 18 | def make_lr_scheduler_generator(cfg, optimizer, epoch_size): 19 | return WarmupMultiStepLR( 20 | optimizer, 21 | [step * epoch_size for step in cfg.SOLVER.GENERATOR.STEPS], 22 | cfg.SOLVER.GENERATOR.GAMMA, 23 | warmup_factor=cfg.SOLVER.GENERATOR.WARMUP_FACTOR, 24 | warmup_iters=cfg.SOLVER.GENERATOR.WARMUP_ITERS, 25 | warmup_method=cfg.SOLVER.GENERATOR.WARMUP_METHOD, 26 | ) 27 | 28 | def make_lr_scheduler_discriminator(cfg, optimizer, epoch_size): 29 | return WarmupMultiStepLR( 30 | optimizer, 31 | [step * epoch_size for step in cfg.SOLVER.DISCRIMINATOR.STEPS], 32 | cfg.SOLVER.DISCRIMINATOR.GAMMA, 33 | warmup_factor=cfg.SOLVER.DISCRIMINATOR.WARMUP_FACTOR, 34 | warmup_iters=cfg.SOLVER.DISCRIMINATOR.WARMUP_ITERS, 35 | warmup_method=cfg.SOLVER.DISCRIMINATOR.WARMUP_METHOD, 36 | ) -------------------------------------------------------------------------------- /animegan/solver/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Author: wanhui0729@gmail.com 3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 4 | 5 | import torch 6 | from bisect import bisect_right 7 | 8 | # FIXME ideally this would be achieved with a CombinedLRScheduler, 9 | # separating MultiStepLR with WarmupLR 10 | # but the current LRScheduler design doesn't allow it 11 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 12 | def __init__( 13 | self, 14 | optimizer, 15 | milestones, 16 | gamma=0.1, 17 | warmup_factor=1.0 / 3, 18 | warmup_iters=500, 19 | warmup_method="linear", 20 | last_epoch=-1, 21 | ): 22 | if not list(milestones) == sorted(milestones): 23 | raise ValueError( 24 | "Milestones should be a list of" " increasing integers. 
Got {}", 25 | milestones, 26 | ) 27 | 28 | if warmup_method not in ("constant", "linear"): 29 | raise ValueError( 30 | "Only 'constant' or 'linear' warmup_method accepted" 31 | "got {}".format(warmup_method) 32 | ) 33 | self.milestones = milestones 34 | self.gamma = gamma 35 | self.warmup_factor = warmup_factor 36 | self.warmup_iters = warmup_iters 37 | self.warmup_method = warmup_method 38 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 39 | 40 | def get_lr(self): 41 | warmup_factor = 1 42 | if self.last_epoch < self.warmup_iters: 43 | if self.warmup_method == "constant": 44 | warmup_factor = self.warmup_factor 45 | elif self.warmup_method == "linear": 46 | # alpha = self.last_epoch / self.warmup_iters 47 | alpha = float(self.last_epoch) / self.warmup_iters 48 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 49 | return [ 50 | base_lr 51 | * warmup_factor 52 | * self.gamma ** bisect_right(self.milestones, self.last_epoch) 53 | for base_lr in self.base_lrs 54 | ] 55 | -------------------------------------------------------------------------------- /animegan/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Author: wanhui0729@gmail.com 3 | -------------------------------------------------------------------------------- /animegan/utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Author: wanhui0729@gmail.com 3 | 4 | import os 5 | import torch 6 | import logging 7 | from animegan.utils.comm import get_rank 8 | from animegan.utils.model_serialization import load_state_dict 9 | 10 | class ModelCheckpointer(): 11 | def __init__( 12 | self, 13 | models, 14 | optimizers=None, 15 | schedulers=None, 16 | save_dir=None, 17 | logger_name=None, 18 | ): 19 | if get_rank() != 0: 20 | save_dir = None 21 | self.models = models 22 | self.optimizers = optimizers 23 | self.schedulers = schedulers 24 | self.save_dir = save_dir 25 | if save_dir is not None: 26 | self.save_dir = os.path.join(save_dir, "model_record") 27 | if logger_name: 28 | logger = logging.getLogger(logger_name) 29 | else: 30 | logger = logging.getLogger(__name__) 31 | self.logger = logger 32 | 33 | def save(self, name, **kwargs): 34 | if not self.save_dir: 35 | return 36 | os.makedirs(self.save_dir, exist_ok=True) 37 | 38 | data = {"models": {}, "optimizers": {}, "schedulers": {}} 39 | data["models"]["generator"] = self.models['generator'].state_dict() 40 | data["models"]["discriminator"] = self.models['discriminator'].state_dict() 41 | if self.optimizers is not None: 42 | data["optimizers"]["generator"] = self.optimizers['generator'].state_dict() 43 | data["optimizers"]["discriminator"] = self.optimizers['discriminator'].state_dict() 44 | if self.schedulers is not None: 45 | data["schedulers"]["generator"] = self.schedulers['generator'].state_dict() 46 | data["schedulers"]["discriminator"] = self.schedulers['discriminator'].state_dict() 47 | 48 | data.update(kwargs) 49 | 50 | save_file = os.path.join(self.save_dir, "{}.pth".format(name)) 51 | self.logger.info("Saving checkpoint to {}".format(save_file)) 52 | torch.save(data, save_file) 53 | self.tag_last_checkpoint(save_file) 54 | 55 | def load(self, f=None, transfer_learning=False): 56 | ''' 57 | Arguments: 58 | f: checkpoint file 59 | transfer_learning: bool, continue train if False 60 | return: 61 | checkpoint 62 | ''' 63 | if not f: 64 | if self.has_checkpoint(): 65 | # override argument with existing 
checkpoint 66 | f = self.get_checkpoint_file() 67 | if not f: 68 | # no checkpoint could be found 69 | self.logger.info("No checkpoint found. Initializing model from scratch") 70 | return {} 71 | 72 | self.logger.info("Loading checkpoint from {}".format(f)) 73 | 74 | checkpoint = self._load_file(f) 75 | self._load_model(checkpoint) 76 | 77 | if transfer_learning: 78 | checkpoint["iteration"] = 0 79 | return checkpoint 80 | 81 | if "optimizers" in checkpoint and self.optimizers: 82 | self.logger.info("Loading optimizer from {}".format(f)) 83 | checkpoint_optimizers = checkpoint.pop("optimizers") 84 | self.optimizers["generator"].load_state_dict(checkpoint_optimizers.pop("generator")) 85 | self.optimizers["discriminator"].load_state_dict(checkpoint_optimizers.pop("discriminator")) 86 | if "schedulers" in checkpoint and self.schedulers: 87 | self.logger.info("Loading scheduler from {}".format(f)) 88 | checkpoint_schedulers = checkpoint.pop("schedulers") 89 | self.schedulers["generator"].load_state_dict(checkpoint_schedulers.pop("generator")) 90 | self.schedulers["discriminator"].load_state_dict(checkpoint_schedulers.pop("discriminator")) 91 | # return any further checkpoint data 92 | return checkpoint 93 | 94 | def has_checkpoint(self): 95 | if self.save_dir is None: 96 | return False 97 | save_file = os.path.join(self.save_dir, "last_checkpoint") 98 | return os.path.exists(save_file) 99 | 100 | def get_checkpoint_file(self): 101 | assert self.save_dir is not None 102 | save_file = os.path.join(self.save_dir, "last_checkpoint") 103 | try: 104 | with open(save_file, "r") as f: 105 | last_saved = f.read() 106 | last_saved = last_saved.strip() 107 | except IOError: 108 | # if file doesn't exist, maybe because it has just been 109 | # deleted by a separate process 110 | last_saved = "" 111 | return last_saved 112 | 113 | def tag_last_checkpoint(self, last_filename): 114 | assert self.save_dir is not None 115 | save_file = os.path.join(self.save_dir, "last_checkpoint") 116 | with open(save_file, "w") as f: 117 | f.write(last_filename) 118 | 119 | def _load_file(self, f): 120 | return torch.load(f, map_location=torch.device("cpu")) 121 | 122 | def _load_model(self, checkpoint): 123 | checkpoint_models = checkpoint.pop("models") 124 | load_state_dict(self.models["generator"], checkpoint_models.pop("generator")) 125 | load_state_dict(self.models["discriminator"], checkpoint_models.pop("discriminator")) 126 | -------------------------------------------------------------------------------- /animegan/utils/comm.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Author: wanhui0729@gmail.com 3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 4 | 5 | """ 6 | This file contains primitives for multi-gpu communication. 7 | This is useful when doing distributed training. 
8 | """ 9 | 10 | import pickle 11 | import time 12 | 13 | import torch 14 | import torch.distributed as dist 15 | 16 | 17 | def get_world_size(): 18 | if not dist.is_available(): 19 | return 1 20 | if not dist.is_initialized(): 21 | return 1 22 | return dist.get_world_size() 23 | 24 | 25 | def get_rank(): 26 | if not dist.is_available(): 27 | return 0 28 | if not dist.is_initialized(): 29 | return 0 30 | return dist.get_rank() 31 | 32 | 33 | def is_main_process(): 34 | return get_rank() == 0 35 | 36 | 37 | def synchronize(): 38 | """ 39 | Helper function to synchronize (barrier) among all processes when 40 | using distributed training 41 | """ 42 | if not dist.is_available(): 43 | return 44 | if not dist.is_initialized(): 45 | return 46 | world_size = dist.get_world_size() 47 | if world_size == 1: 48 | return 49 | dist.barrier() 50 | 51 | 52 | def all_gather(data): 53 | """ 54 | Run all_gather on arbitrary picklable data (not necessarily tensors) 55 | Args: 56 | data: any picklable object 57 | Returns: 58 | list[data]: list of data gathered from each rank 59 | """ 60 | world_size = get_world_size() 61 | if world_size == 1: 62 | return [data] 63 | 64 | # serialized to a Tensor 65 | buffer = pickle.dumps(data) 66 | storage = torch.ByteStorage.from_buffer(buffer) 67 | tensor = torch.ByteTensor(storage).to("cuda") 68 | 69 | # obtain Tensor size of each rank 70 | local_size = torch.LongTensor([tensor.numel()]).to("cuda") 71 | size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)] 72 | dist.all_gather(size_list, local_size) 73 | size_list = [int(size.item()) for size in size_list] 74 | max_size = max(size_list) 75 | 76 | # receiving Tensor from all ranks 77 | # we pad the tensor because torch all_gather does not support 78 | # gathering tensors of different shapes 79 | tensor_list = [] 80 | for _ in size_list: 81 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda")) 82 | if local_size != max_size: 83 | padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda") 84 | tensor = torch.cat((tensor, padding), dim=0) 85 | dist.all_gather(tensor_list, tensor) 86 | 87 | data_list = [] 88 | for size, tensor in zip(size_list, tensor_list): 89 | buffer = tensor.cpu().numpy().tobytes()[:size] 90 | data_list.append(pickle.loads(buffer)) 91 | 92 | return data_list 93 | 94 | 95 | def reduce_dict(input_dict, average=True): 96 | """ 97 | Args: 98 | input_dict (dict): all the values will be reduced 99 | average (bool): whether to do average or sum 100 | Reduce the values in the dictionary from all processes so that process with rank 101 | 0 has the averaged results. Returns a dict with the same fields as 102 | input_dict, after reduction. 
103 | """ 104 | world_size = get_world_size() 105 | if world_size < 2: 106 | return input_dict 107 | with torch.no_grad(): 108 | names = [] 109 | values = [] 110 | # sort the keys so that they are consistent across processes 111 | for k in sorted(input_dict.keys()): 112 | names.append(k) 113 | values.append(input_dict[k]) 114 | values = torch.stack(values, dim=0) 115 | dist.reduce(values, dst=0) 116 | if dist.get_rank() == 0 and average: 117 | # only main process gets accumulated, so only divide by 118 | # world_size in this case 119 | values /= world_size 120 | reduced_dict = {k: v for k, v in zip(names, values)} 121 | return reduced_dict 122 | -------------------------------------------------------------------------------- /animegan/utils/datasetInfo.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Author: wanhui0729@gmail.com 3 | 4 | ''' 5 | DatasetInfo class 6 | ''' 7 | 8 | class DatasetInfo(object): 9 | ''' 10 | manage dataset 11 | ''' 12 | def __init__(self, factory, **kwargs): 13 | ''' 14 | Arguments: 15 | dataDir: dataset dir 16 | factory: factory of data interface 17 | split: train or test 18 | kwargs: other args for data interface 19 | ''' 20 | self.factory = factory 21 | self.kwargs = kwargs 22 | 23 | def get(self): 24 | ''' 25 | return: 26 | dict contain the infomation of dataset 27 | ''' 28 | return dict( 29 | factory=self.factory, 30 | args=self.kwargs 31 | ) 32 | -------------------------------------------------------------------------------- /animegan/utils/env.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Author: wanhui0729@gmail.com 3 | 4 | import sys 5 | from torch.utils.collect_env import get_pretty_env_info 6 | import PIL 7 | 8 | SUPPORTED_DENY = ['win32'] 9 | 10 | def get_pil_version(): 11 | return "\n Pillow ({})".format(PIL.__version__) 12 | 13 | def collect_env_info(): 14 | if sys.platform.lower() in SUPPORTED_DENY: 15 | return "Warning: collect_env_info not supported on {}.".format(sys.platform.lower()) 16 | env_str = get_pretty_env_info() 17 | env_str += get_pil_version() 18 | return env_str 19 | -------------------------------------------------------------------------------- /animegan/utils/logger.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Author: wanhui0729@gmail.com 3 | 4 | import os 5 | import sys 6 | import torch 7 | import logging 8 | from collections import deque 9 | from collections import defaultdict 10 | 11 | 12 | class SmoothedValue(object): 13 | """Track a series of values and provide access to smoothed values over a 14 | window or the global series average. 
15 | """ 16 | 17 | def __init__(self, window_size=20): 18 | self.deque = deque(maxlen=window_size) 19 | self.series = [] 20 | self.total = 0.0 21 | self.count = 0 22 | 23 | def update(self, value): 24 | self.deque.append(value) 25 | self.series.append(value) 26 | self.count += 1 27 | self.total += value 28 | 29 | @property 30 | def median(self): 31 | d = torch.tensor(list(self.deque)) 32 | return d.median().item() 33 | 34 | @property 35 | def avg(self): 36 | d = torch.tensor(list(self.deque)) 37 | return d.mean().item() 38 | 39 | @property 40 | def global_avg(self): 41 | return self.total / self.count 42 | 43 | 44 | class MetricLogger(object): 45 | def __init__(self, delimiter="\t"): 46 | self.meters = defaultdict(SmoothedValue) 47 | self.show = [] 48 | self.delimiter = delimiter 49 | 50 | def update(self, **kwargs): 51 | for k, v in kwargs.items(): 52 | if isinstance(v, torch.Tensor): 53 | v = v.item() 54 | assert isinstance(v, (float, int)) 55 | self.meters[k].update(v) 56 | self.show = kwargs.keys() 57 | 58 | def __getattr__(self, attr): 59 | if attr in self.meters: 60 | return self.meters[attr] 61 | # return object.__getattr__(self, attr) 62 | if attr in self.__dict__: 63 | return self.__dict__[attr] 64 | raise AttributeError("'{}' object has no attribute '{}'".format( 65 | type(self).__name__, attr)) 66 | 67 | def __str__(self): 68 | loss_str = [] 69 | for name, meter in self.meters.items(): 70 | # 本轮未更新的参数信息不展示 71 | if name not in self.show: 72 | continue 73 | loss_str.append( 74 | "{}: {:.4f} ({:.4f})".format(name, meter.avg, meter.median) 75 | ) 76 | return self.delimiter.join(loss_str) 77 | 78 | def setup_logger(name, distributed_rank, logFile="log.txt"): 79 | logger = logging.getLogger(name) 80 | logger.setLevel(logging.DEBUG) 81 | # don't log results for the non-master process 82 | if distributed_rank > 0: 83 | return logger 84 | ch = logging.StreamHandler(stream=sys.stdout) 85 | ch.setLevel(logging.DEBUG) 86 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s") 87 | ch.setFormatter(formatter) 88 | logger.addHandler(ch) 89 | 90 | if logFile: 91 | fh = logging.FileHandler(logFile) 92 | fh.setLevel(logging.DEBUG) 93 | fh.setFormatter(formatter) 94 | logger.addHandler(fh) 95 | 96 | return logger -------------------------------------------------------------------------------- /animegan/utils/model_serialization.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Author: wanhui0729@gmail.com 3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 4 | 5 | import logging 6 | import torch 7 | from collections import OrderedDict 8 | 9 | def align_and_update_state_dicts(model_state_dict, loaded_state_dict): 10 | """ 11 | Strategy: suppose that the models that we will create will have prefixes appended 12 | to each of its keys, for example due to an extra level of nesting that the original 13 | pre-trained weights from ImageNet won't contain. For example, model.state_dict() 14 | might return backbone[0].body.res2.conv1.weight, while the pre-trained model contains 15 | res2.conv1.weight. We thus want to match both parameters together. 16 | For that, we look for each model weight, look among all loaded keys if there is one 17 | that is a suffix of the current weight name, and use it if that's the case. 18 | If multiple matches exist, take the one with longest size 19 | of the corresponding name. 
For example, for the same model as before, the pretrained 20 | weight file can contain both res2.conv1.weight, as well as conv1.weight. In this case, 21 | we want to match backbone[0].body.conv1.weight to conv1.weight, and 22 | backbone[0].body.res2.conv1.weight to res2.conv1.weight. 23 | """ 24 | current_keys = sorted(list(model_state_dict.keys())) 25 | loaded_keys = sorted(list(loaded_state_dict.keys())) 26 | # get a matrix of string matches, where each (i, j) entry correspond to the size of the 27 | # loaded_key string, if it matches 28 | match_matrix = [ 29 | len(j) if i.endswith(j) else 0 for i in current_keys for j in loaded_keys 30 | ] 31 | match_matrix = torch.as_tensor(match_matrix).view( 32 | len(current_keys), len(loaded_keys) 33 | ) 34 | max_match_size, idxs = match_matrix.max(1) 35 | # remove indices that correspond to no-match 36 | idxs[max_match_size == 0] = -1 37 | 38 | # used for logging 39 | max_size = max([len(key) for key in current_keys]) if current_keys else 1 40 | max_size_loaded = max([len(key) for key in loaded_keys]) if loaded_keys else 1 41 | log_str_template = "{: <{}} loaded from {: <{}} of shape {}" 42 | logger = logging.getLogger(__name__) 43 | for idx_new, idx_old in enumerate(idxs.tolist()): 44 | if idx_old == -1: 45 | continue 46 | key = current_keys[idx_new] 47 | key_old = loaded_keys[idx_old] 48 | model_state_dict[key] = loaded_state_dict[key_old] 49 | logger.info( 50 | log_str_template.format( 51 | key, 52 | max_size, 53 | key_old, 54 | max_size_loaded, 55 | tuple(loaded_state_dict[key_old].shape), 56 | ) 57 | ) 58 | 59 | 60 | def strip_prefix_if_present(state_dict, prefix): 61 | keys = sorted(state_dict.keys()) 62 | if not all(key.startswith(prefix) for key in keys): 63 | return state_dict 64 | stripped_state_dict = OrderedDict() 65 | for key, value in state_dict.items(): 66 | stripped_state_dict[key.replace(prefix, "")] = value 67 | return stripped_state_dict 68 | 69 | 70 | def load_state_dict(model, loaded_state_dict): 71 | model_state_dict = model.state_dict() 72 | # if the state_dict comes from a model that was wrapped in a 73 | # DataParallel or DistributedDataParallel during serialization, 74 | # remove the "module" prefix before performing the matching 75 | 76 | loaded_state_dict = strip_prefix_if_present(loaded_state_dict, prefix="module.") 77 | 78 | align_and_update_state_dicts(model_state_dict, loaded_state_dict) 79 | 80 | # use strict loading 81 | 82 | model.load_state_dict(model_state_dict) 83 | -------------------------------------------------------------------------------- /animegan/utils/model_zoo.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Author: wanhui0729@gmail.com 3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 4 | 5 | import os 6 | import sys 7 | from torch.hub import urlparse, _download_url_to_file 8 | import torch.utils.model_zoo as model_zoo 9 | try: 10 | from torch.hub import HASH_REGEX 11 | except: 12 | import re 13 | HASH_REGEX = re.compile(r'-([a-f0-9]*)\.') 14 | from animegan.utils.comm import is_main_process, synchronize, get_world_size 15 | 16 | # very similar to https://github.com/pytorch/pytorch/blob/master/torch/utils/model_zoo.py 17 | # but with a few improvements and modifications 18 | def cache_url(url, model_dir=None, progress=True): 19 | r"""Loads the Torch serialized object at the given URL. 20 | If the object is already present in `model_dir`, it's deserialized and 21 | returned. 
The filename part of the URL should follow the naming convention 22 | ``filename-.ext`` where ```` is the first eight or more 23 | digits of the SHA256 hash of the contents of the file. The hash is used to 24 | ensure unique names and to verify the contents of the file. 25 | The default value of `model_dir` is ``$TORCH_HOME/models`` where 26 | ``$TORCH_HOME`` defaults to ``~/.torch``. The default directory can be 27 | overridden with the ``$TORCH_MODEL_ZOO`` environment variable. 28 | Args: 29 | url (string): URL of the object to download 30 | model_dir (string, optional): directory in which to save the object 31 | progress (bool, optional): whether or not to display a progress bar to stderr 32 | Example: 33 | >>> cached_file = rock.dl.utils.model_zoo.cache_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth') 34 | """ 35 | if model_dir is None: 36 | torch_home = os.path.expanduser(os.getenv('TORCH_HOME', '~/.torch')) 37 | model_dir = os.getenv('TORCH_MODEL_ZOO', os.path.join(torch_home, 'models')) 38 | if not os.path.exists(model_dir): 39 | os.makedirs(model_dir) 40 | parts = urlparse(url) 41 | filename = os.path.basename(parts.path) 42 | if filename == "model_final.pkl": 43 | # workaround as pre-trained Caffe2 models from Detectron have all the same filename 44 | # so make the full path the filename by replacing / with _ 45 | filename = parts.path.replace("/", "_") 46 | cached_file = os.path.join(model_dir, filename) 47 | if not os.path.exists(cached_file) and is_main_process(): 48 | sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) 49 | hash_prefix = HASH_REGEX.search(filename) 50 | if hash_prefix is not None: 51 | hash_prefix = hash_prefix.group(1) 52 | # workaround: Caffe2 models don't have a hash, but follow the R-50 convention, 53 | # which matches the hash PyTorch uses. So we skip the hash matching 54 | # if the hash_prefix is less than 6 characters 55 | if len(hash_prefix) < 6: 56 | hash_prefix = None 57 | _download_url_to_file(url, cached_file, hash_prefix, progress=progress) 58 | synchronize() 59 | return cached_file 60 | 61 | # 防止多个进程同时下载模型预训练文件 62 | def distributed_load_url(model, url): 63 | if is_main_process(): 64 | model.load_state_dict(model_zoo.load_url(url), strict=False) 65 | if get_world_size() == 1: 66 | return model 67 | synchronize() 68 | if not is_main_process(): 69 | model.load_state_dict(model_zoo.load_url(url), strict=False) 70 | synchronize() 71 | return model -------------------------------------------------------------------------------- /animegan/utils/registry.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Author: wanhui0729@gmail.com 3 | 4 | def _register_generic(module_dict, module_name, module): 5 | assert module_name not in module_dict 6 | module_dict[module_name] = module 7 | 8 | 9 | class Registry(dict): 10 | ''' 11 | A helper class for managing registering modules, it extends a dictionary 12 | and provides a register functions. 13 | Eg. creeting a registry: 14 | some_registry = Registry({"default": default_module}) 15 | There're two ways of registering new modules: 16 | 1): normal way is just calling register function: 17 | def foo(): 18 | ... 19 | some_registry.register("foo_module", foo) 20 | 2): used as decorator when declaring the module: 21 | @some_registry.register("foo_module") 22 | @some_registry.register("foo_modeul_nickname") 23 | def foo(): 24 | ... 
25 | Access of module is just like using a dictionary, eg: 26 | f = some_registry["foo_modeul"] 27 | ''' 28 | def __init__(self, *args, **kwargs): 29 | super(Registry, self).__init__(*args, **kwargs) 30 | 31 | def register(self, module_name, module=None): 32 | # used as function call 33 | if module is not None: 34 | _register_generic(self, module_name, module) 35 | return 36 | 37 | # used as decorator 38 | def register_fn(fn): 39 | _register_generic(self, module_name, fn) 40 | return fn 41 | 42 | return register_fn -------------------------------------------------------------------------------- /animegan/utils/tm.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Author: wanhui0729@gmail.com 3 | 4 | import time 5 | import datetime 6 | 7 | class Timer(object): 8 | """A simple timer.""" 9 | 10 | def __init__(self): 11 | self.total_time = 0. 12 | self.calls = 0 13 | self.start_time = 0. 14 | self.diff = 0. 15 | self.average_time = 0. 16 | 17 | def tic(self): 18 | '''time start''' 19 | # using time.time instead of time.clock because time time.clock 20 | # does not normalize for multithreading 21 | self.start_time = time.time() 22 | 23 | def toc(self, average=False): 24 | '''time stop''' 25 | self.diff = time.time() - self.start_time 26 | self.total_time += self.diff 27 | self.calls += 1 28 | if average: 29 | self.average_time = self.total_time / self.calls 30 | return self.average_time 31 | else: 32 | return self.diff 33 | 34 | def clear(self): 35 | self.total_time = 0. 36 | self.calls = 0 37 | self.start_time = 0. 38 | self.diff = 0. 39 | self.average_time = 0. 40 | 41 | def generate_time_str(the_time=None, tag=''): 42 | ''' 43 | generate time string 44 | :param the_time: time.time() 45 | :return: time string 46 | ''' 47 | the_time = the_time or time.time() 48 | time_str = str(int(the_time * 10000000)) 49 | if tag: 50 | time_str = time_str + '_' + tag 51 | return time_str 52 | 53 | def generate_datetime_str(the_time=None, formate='%Y-%m-%d %H:%M:%S', tag=''): 54 | ''' 55 | generate datetime string 56 | :param the_time: datetime.datetime() 57 | :param formate: datetime string format 58 | :return: datetime string 59 | ''' 60 | the_time = the_time or datetime.datetime.now() 61 | time_str = the_time.strftime(formate) 62 | if tag: 63 | time_str = time_str + '_' + tag 64 | return time_str -------------------------------------------------------------------------------- /configs/e2e_hayao.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | WEIGHT: "" 3 | TRANSFER_LEARNING: False 4 | COMMON: 5 | GAN_TYPE: 'lsgan' 6 | TRAINING_RATE: 1 7 | LD: 10.0 8 | WEIGHT_ADV_G: 300.0 9 | WEIGHT_ADV_D: 300.0 10 | WEIGHT_G_CON: 1.5 11 | WEIGHT_G_STYLE: 2.5 12 | WEIGHT_G_COLOR: 15.0 13 | WEIGHT_G_TV: 1.0 14 | WEIGHT_D_LOSS_REAL: 1.2 15 | WEIGHT_D_LOSS_FAKE: 1.2 16 | WEIGHT_D_LOSS_GRAY: 1.2 17 | WEIGHT_D_LOSS_BLUR: 0.8 18 | BACKBONE: 19 | BODY: "VGG19" 20 | WEIGHT: "/data/datasets/animegan/vgg_pretrained.pth" 21 | GENERATOR: 22 | BODY: "Base-256" 23 | IN_CHANNELS: 3 24 | DISCRIMINATOR: 25 | BODY: "Base-256" 26 | IN_CHANNELS: 3 27 | CHANNELS: 64 28 | N_DIS: 2 29 | DATASETS: 30 | TRAIN: [ 31 | { 32 | 'factory': 'AnimeGanDataset', 33 | 'dataDir': '/data/datasets/animegan/Hayao', 34 | 'split': 'train' 35 | } 36 | ] 37 | TEST: [ 38 | { 39 | 'factory': 'AnimeGanDataset', 40 | 'dataDir': '/data/datasets/animegan/Hayao', 41 | 'split': 'test' 42 | } 43 | ] 44 | INPUT: 45 | IMG_SIZE: (256, 256) 46 | PIXEL_MEAN: [-4.4661, 
-8.6698, 13.1360] 47 | DATALOADER: 48 | NUM_WORKERS: 1 49 | SOLVER: 50 | MAX_EPOCH: 100 51 | PRINT_PERIOD: 20 52 | CHECKPOINT_PERIOD: 1 53 | TEST_PERIOD: 1 54 | IMS_PER_BATCH: 8 55 | GENERATOR: 56 | BASE_LR: 0.0002 57 | INIT_EPOCH: 10 58 | STEPS: (10,) 59 | WARMUP_FACTOR: 0.0 60 | WARMUP_ITERS: 0 61 | WARMUP_METHOD: 'constant' 62 | DISCRIMINATOR: 63 | BASE_LR: 0.00004 64 | STEPS: (100,) 65 | WARMUP_FACTOR: 0.0 66 | WARMUP_ITERS: 0 67 | WARMUP_METHOD: 'constant' 68 | TEST: 69 | IMS_PER_BATCH: 1 70 | -------------------------------------------------------------------------------- /configs/e2e_shinkai.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | WEIGHT: "" 3 | TRANSFER_LEARNING: False 4 | COMMON: 5 | GAN_TYPE: 'lsgan' 6 | TRAINING_RATE: 1 7 | WEIGHT_ADV_G: 300.0 8 | WEIGHT_ADV_D: 300.0 9 | WEIGHT_G_CON: 1.2 10 | WEIGHT_G_STYLE: 2.0 11 | WEIGHT_G_COLOR: 10.0 12 | WEIGHT_G_TV: 1.0 13 | WEIGHT_D_LOSS_REAL: 1.7 14 | WEIGHT_D_LOSS_FAKE: 1.7 15 | WEIGHT_D_LOSS_GRAY: 1.7 16 | WEIGHT_D_LOSS_BLUR: 1.0 17 | BACKBONE: 18 | BODY: "VGG19" 19 | WEIGHT: "/data/datasets/animegan/vgg_pretrained.pth" 20 | GENERATOR: 21 | BODY: "Base-256" 22 | IN_CHANNELS: 3 23 | DISCRIMINATOR: 24 | BODY: "Base-256" 25 | IN_CHANNELS: 3 26 | CHANNELS: 64 27 | N_DIS: 2 28 | DATASETS: 29 | TRAIN: [ 30 | { 31 | 'factory': 'AnimeGanDataset', 32 | 'dataDir': '/data/datasets/animegan/Shinkai', 33 | 'split': 'train' 34 | } 35 | ] 36 | TEST: [ 37 | { 38 | 'factory': 'AnimeGanDataset', 39 | 'dataDir': '/data/datasets/animegan/Shinkai', 40 | 'split': 'test' 41 | } 42 | ] 43 | INPUT: 44 | IMG_SIZE: (256, 256) 45 | PIXEL_MEAN: [-4.4661, -8.6698, 13.1360] 46 | DATALOADER: 47 | NUM_WORKERS: 1 48 | SOLVER: 49 | MAX_EPOCH: 100 50 | PRINT_PERIOD: 20 51 | CHECKPOINT_PERIOD: 1 52 | TEST_PERIOD: 1 53 | IMS_PER_BATCH: 8 54 | GENERATOR: 55 | BASE_LR: 0.0002 56 | INIT_EPOCH: 10 57 | STEPS: (10,) 58 | WARMUP_FACTOR: 0.0 59 | WARMUP_ITERS: 0 60 | WARMUP_METHOD: 'constant' 61 | DISCRIMINATOR: 62 | BASE_LR: 0.00004 63 | STEPS: (100,) 64 | WARMUP_FACTOR: 0.0 65 | WARMUP_ITERS: 0 66 | WARMUP_METHOD: 'constant' 67 | TEST: 68 | IMS_PER_BATCH: 1 69 | -------------------------------------------------------------------------------- /scripts/data_mean.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Author: wanhui0729@gmail.com 3 | 4 | import os 5 | import cv2 6 | from tqdm import tqdm 7 | 8 | def read_img(image_path): 9 | img = cv2.imread(image_path) 10 | assert len(img.shape) == 3 11 | B = img[..., 0].mean() 12 | G = img[..., 1].mean() 13 | R = img[..., 2].mean() 14 | return B, G, R 15 | 16 | def get_mean(images_path): 17 | images = os.listdir(images_path) 18 | image_num = len(images) 19 | B_total = 0 20 | G_total = 0 21 | R_total = 0 22 | for image in tqdm(images) : 23 | image_path = os.path.join(images_path, image) 24 | bgr = read_img(image_path) 25 | B_total += bgr[0] 26 | G_total += bgr[1] 27 | R_total += bgr[2] 28 | 29 | B_mean, G_mean, R_mean = B_total / image_num, G_total / image_num, R_total / image_num 30 | mean = (B_mean + G_mean + R_mean)/3 31 | 32 | return mean-B_mean, mean-G_mean, mean-R_mean 33 | 34 | if __name__ == '__main__': 35 | images_path = "/data/datasets/animegan/your_name/train/style" 36 | B_mean, G_mean, R_mean = get_mean(images_path) 37 | print("B_mean: {}\nG_mean: {}\nR_mean: {}".format(B_mean, G_mean, R_mean)) -------------------------------------------------------------------------------- /scripts/gramEmbedding.py: 
-------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Author: wanhui0729@gmail.com 3 | 4 | ''' 5 | gram 特征可视化 6 | 相同场景的gram特征距离较近,灰度图聚类更好,颜色收到一定颜色影响 7 | 同一张图片的gram特征受颜色的影响 8 | ''' 9 | 10 | import os, sys 11 | project_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 12 | sys.path.append(project_path) 13 | import cv2 14 | import torch 15 | import argparse 16 | from tqdm import tqdm 17 | from torch.utils.tensorboard import SummaryWriter 18 | from animegan.configs import cfg 19 | from animegan.modeling.backbone import build_backbone 20 | from animegan.data.transforms.build import build_transforms 21 | from animegan.modeling.utils import gram, prepare_feature_extract 22 | 23 | 24 | def get_model(device): 25 | model = build_backbone(cfg) 26 | model.to(device) 27 | return model 28 | 29 | def main(): 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument( 32 | "--config-file", 33 | default="", 34 | metavar="FILE", 35 | help="path to config file", 36 | type=str, 37 | ) 38 | parser.add_argument( 39 | "--imagesPath", 40 | type=str, 41 | required=True 42 | ) 43 | parser.add_argument( 44 | "--num", 45 | type=int, 46 | default=100 47 | ) 48 | 49 | parser.add_argument( 50 | "opts", 51 | help="Modify config options using the command-line", 52 | default=None, 53 | nargs=argparse.REMAINDER, 54 | ) 55 | 56 | args = parser.parse_args() 57 | cfg.merge_from_file(args.config_file) 58 | cfg.merge_from_list(args.opts) 59 | cfg.freeze() 60 | 61 | imagesPath = args.imagesPath 62 | num = args.num 63 | device = torch.device(cfg.MODEL.DEVICE) 64 | model = get_model(device) 65 | model.eval() 66 | transform = build_transforms(cfg, False) 67 | 68 | writer = SummaryWriter(log_dir="gramEmbedding") 69 | mat_list = [] 70 | label_image_list = [] 71 | metadata_list = [] 72 | count = 0 73 | for image in tqdm(os.listdir(imagesPath)): 74 | count += 1 75 | if count > num: break 76 | image_input_ori = cv2.imread(os.path.join(imagesPath, image)) 77 | image_tarnsform = transform([image_input_ori])[0] 78 | image_input_gray = image_tarnsform[1].unsqueeze(0).to(device) 79 | image_input_color = image_tarnsform[0].unsqueeze(0).to(device) 80 | with torch.no_grad(): 81 | # color 82 | backbone_feature_map = model(prepare_feature_extract(image_input_color)) 83 | gram_feature_map = gram(backbone_feature_map) 84 | gram_feature = gram_feature_map.flatten() 85 | mat_list.append(gram_feature) 86 | image_input_ori_show = cv2.cvtColor(image_input_ori, cv2.COLOR_BGR2RGB).transpose((2, 0, 1)) 87 | label_image_list.append(torch.from_numpy(image_input_ori_show) / 255.0) 88 | metadata_list.append(torch.zeros([1])) 89 | 90 | #gray 91 | backbone_feature_map = model(prepare_feature_extract(image_input_gray)) 92 | gram_feature_map = gram(backbone_feature_map) 93 | gram_feature = gram_feature_map.flatten() 94 | mat_list.append(gram_feature) 95 | image_input_ori_show = cv2.cvtColor(image_input_ori, cv2.COLOR_BGR2RGB).transpose((2, 0, 1)) 96 | label_image_list.append(torch.from_numpy(image_input_ori_show) / 255.0) 97 | metadata_list.append(torch.ones([1])) 98 | writer.add_embedding( 99 | mat=torch.stack(mat_list), 100 | metadata=metadata_list, 101 | label_img=torch.stack(label_image_list) 102 | ) 103 | 104 | if __name__ == '__main__': 105 | main() -------------------------------------------------------------------------------- /scripts/image2anime.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Author: 
wanhui0729@gmail.com 3 | 4 | import os, sys 5 | project_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 6 | sys.path.append(project_path) 7 | import cv2 8 | import torch 9 | import argparse 10 | import numpy as np 11 | from PIL import Image 12 | from animegan.configs import cfg 13 | from animegan.modeling.generator import build_generator 14 | from animegan.data.transforms.build import build_transforms 15 | from animegan.utils.model_serialization import load_state_dict 16 | from animegan.modeling.utils import adjust_brightness_from_src_to_dst 17 | 18 | 19 | def get_model(model_weight, device): 20 | model = build_generator(cfg) 21 | checkpoint = torch.load(model_weight, map_location=torch.device("cpu")) 22 | load_state_dict(model, checkpoint.pop("models").pop("generator")) 23 | model.to(device) 24 | return model 25 | 26 | def main(): 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument( 29 | "--config-file", 30 | required=True, 31 | metavar="FILE", 32 | help="path to config file", 33 | type=str, 34 | ) 35 | parser.add_argument( 36 | "--image", 37 | type=str, 38 | required=True 39 | ) 40 | 41 | parser.add_argument( 42 | "opts", 43 | help="Modify config options using the command-line", 44 | default=None, 45 | nargs=argparse.REMAINDER, 46 | ) 47 | 48 | args = parser.parse_args() 49 | cfg.merge_from_file(args.config_file) 50 | cfg.merge_from_list(args.opts) 51 | cfg.freeze() 52 | 53 | image_path = args.image 54 | model_weight = cfg.MODEL.WEIGHT 55 | device = torch.device(cfg.MODEL.DEVICE) 56 | 57 | model = get_model(model_weight, device) 58 | model.eval() 59 | 60 | transform = build_transforms(cfg, False) 61 | 62 | image = cv2.imread(image_path) 63 | 64 | input = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 65 | input = Image.fromarray(input) 66 | input = transform([input])[0][0].unsqueeze(0) 67 | input = input.to(device) 68 | with torch.no_grad(): 69 | pred = model(input).cpu() 70 | pred_img = (pred.squeeze() + 1.) 
/ 2 * 255 71 | pred_img = pred_img.permute(1, 2, 0).numpy().clip(0, 255).astype(np.uint8) 72 | pred_img = adjust_brightness_from_src_to_dst(pred_img, image) 73 | pred_img = cv2.cvtColor(pred_img, cv2.COLOR_RGB2BGR) 74 | cv2.imshow("", pred_img) 75 | cv2.waitKey() 76 | image = cv2.resize(image, pred_img.shape[:-1][::-1]) 77 | concat_img = np.concatenate((image, pred_img), 1) 78 | cv2.imwrite(f"anime_{os.path.basename(image_path)}", concat_img) 79 | 80 | if __name__ == '__main__': 81 | main() -------------------------------------------------------------------------------- /scripts/modelTensorboard.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Author: wanhui0729@gmail.com 3 | 4 | import os, sys 5 | project_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 6 | sys.path.append(project_path) 7 | import torch 8 | from animegan.configs import cfg 9 | from animegan.modeling.build import build_model 10 | from torch.utils.tensorboard import SummaryWriter 11 | 12 | if __name__ == '__main__': 13 | tensorboard_dir = os.path.join(project_path, 'graph') 14 | writer_b = SummaryWriter(os.path.join(tensorboard_dir, 'backbone')) 15 | writer_g = SummaryWriter(os.path.join(tensorboard_dir, 'generator')) 16 | writer_d = SummaryWriter(os.path.join(tensorboard_dir, 'discriminator')) 17 | model_backbone, model_generator, model_discriminator = build_model(cfg) 18 | input_dummy = torch.ones((8, 3, 256, 256)) 19 | writer_b.add_graph(model_backbone, input_dummy) 20 | writer_g.add_graph(model_generator, input_dummy) 21 | writer_d.add_graph(model_discriminator, input_dummy) -------------------------------------------------------------------------------- /scripts/tf2torch.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Author: wanhui0729@gmail.com 3 | 4 | # tensorflow 5 | # conv1_1 (3, 3, 3, 64) (64,) 6 | # conv1_2 (3, 3, 64, 64) (64,) 7 | # conv2_1 (3, 3, 64, 128) (128,) 8 | # conv2_2 (3, 3, 128, 128) (128,) 9 | # conv3_1 (3, 3, 128, 256) (256,) 10 | # conv3_2 (3, 3, 256, 256) (256,) 11 | # conv3_3 (3, 3, 256, 256) (256,) 12 | # conv3_4 (3, 3, 256, 256) (256,) 13 | # conv4_1 (3, 3, 256, 512) (512,) 14 | # conv4_2 (3, 3, 512, 512) (512,) 15 | # conv4_3 (3, 3, 512, 512) (512,) 16 | # conv4_4 (3, 3, 512, 512) (512,) 17 | # conv5_1 (3, 3, 512, 512) (512,) 18 | # conv5_2 (3, 3, 512, 512) (512,) 19 | # conv5_3 (3, 3, 512, 512) (512,) 20 | # conv5_4 (3, 3, 512, 512) (512,) 21 | # fc6 (25088, 4096) (4096,) 22 | # fc7 (4096, 4096) (4096,) 23 | # fc8 (4096, 1000) (1000,) 24 | 25 | # pytorch 26 | # features.0.weight torch.Size([64, 3, 3, 3]) 27 | # features.0.bias torch.Size([64]) 28 | # features.2.weight torch.Size([64, 64, 3, 3]) 29 | # features.2.bias torch.Size([64]) 30 | # features.5.weight torch.Size([128, 64, 3, 3]) 31 | # features.5.bias torch.Size([128]) 32 | # features.7.weight torch.Size([128, 128, 3, 3]) 33 | # features.7.bias torch.Size([128]) 34 | # features.10.weight torch.Size([256, 128, 3, 3]) 35 | # features.10.bias torch.Size([256]) 36 | # features.12.weight torch.Size([256, 256, 3, 3]) 37 | # features.12.bias torch.Size([256]) 38 | # features.14.weight torch.Size([256, 256, 3, 3]) 39 | # features.14.bias torch.Size([256]) 40 | # features.16.weight torch.Size([256, 256, 3, 3]) 41 | # features.16.bias torch.Size([256]) 42 | # features.19.weight torch.Size([512, 256, 3, 3]) 43 | # features.19.bias torch.Size([512]) 44 | # features.21.weight torch.Size([512, 512, 3, 3]) 45 
| # features.21.bias torch.Size([512]) 46 | # features.23.weight torch.Size([512, 512, 3, 3]) 47 | # features.23.bias torch.Size([512]) 48 | # features.25.weight torch.Size([512, 512, 3, 3]) 49 | # features.25.bias torch.Size([512]) 50 | # features.28.weight torch.Size([512, 512, 3, 3]) 51 | # features.28.bias torch.Size([512]) 52 | # features.30.weight torch.Size([512, 512, 3, 3]) 53 | # features.30.bias torch.Size([512]) 54 | # features.32.weight torch.Size([512, 512, 3, 3]) 55 | # features.32.bias torch.Size([512]) 56 | # features.34.weight torch.Size([512, 512, 3, 3]) 57 | # features.34.bias torch.Size([512]) 58 | # classifier.0.weight torch.Size([4096, 25088]) 59 | # classifier.0.bias torch.Size([4096]) 60 | # classifier.3.weight torch.Size([4096, 4096]) 61 | # classifier.3.bias torch.Size([4096]) 62 | # classifier.6.weight torch.Size([1000, 4096]) 63 | # classifier.6.bias torch.Size([1000]) 64 | 65 | exchange_map = { 66 | "features.0.weight": ("conv1_1", 0), 67 | "features.0.bias": ("conv1_1", 1), 68 | "features.2.weight": ("conv1_2", 0), 69 | "features.2.bias": ("conv1_2", 1), 70 | "features.5.weight": ("conv2_1", 0), 71 | "features.5.bias": ("conv2_1", 1), 72 | "features.7.weight": ("conv2_2", 0), 73 | "features.7.bias": ("conv2_2", 1), 74 | "features.10.weight": ("conv3_1", 0), 75 | "features.10.bias": ("conv3_1", 1), 76 | "features.12.weight": ("conv3_2", 0), 77 | "features.12.bias": ("conv3_2", 1), 78 | "features.14.weight": ("conv3_3", 0), 79 | "features.14.bias": ("conv3_3", 1), 80 | "features.16.weight": ("conv3_4", 0), 81 | "features.16.bias": ("conv3_4", 1), 82 | "features.19.weight": ("conv4_1", 0), 83 | "features.19.bias": ("conv4_1", 1), 84 | "features.21.weight": ("conv4_2", 0), 85 | "features.21.bias": ("conv4_2", 1), 86 | "features.23.weight": ("conv4_3", 0), 87 | "features.23.bias": ("conv4_3", 1), 88 | "features.25.weight": ("conv4_4", 0), 89 | "features.25.bias": ("conv4_4", 1), 90 | "features.28.weight": ("conv5_1", 0), 91 | "features.28.bias": ("conv5_1", 1), 92 | "features.30.weight": ("conv5_2", 0), 93 | "features.30.bias": ("conv5_2", 1), 94 | "features.32.weight": ("conv5_3", 0), 95 | "features.32.bias": ("conv5_3", 1), 96 | "features.34.weight": ("conv5_4", 0), 97 | "features.34.bias": ("conv5_4", 1), 98 | "classifier.0.weight": ("fc6", 0), 99 | "classifier.0.bias": ("fc6", 1), 100 | "classifier.3.weight": ("fc7", 0), 101 | "classifier.3.bias": ("fc7", 1), 102 | "classifier.6.weight": ("fc8", 0), 103 | "classifier.6.bias": ("fc8", 1), 104 | } 105 | 106 | import torch 107 | import numpy as np 108 | from torchvision.models import vgg19 109 | 110 | def get_tf_dict(path): 111 | state_dict = np.load(path, encoding='latin1', allow_pickle=True).item() 112 | return state_dict 113 | 114 | def get_torch_dict(): 115 | state_dict = vgg19(pretrained=True).state_dict() 116 | return state_dict 117 | 118 | def exchange(): 119 | tf_state_dict = get_tf_dict('/data/datasets/animegan/vgg19.npy') 120 | torch_state_dict = get_torch_dict() 121 | for k, v in torch_state_dict.items(): 122 | tf_k, tf_ind = exchange_map[k] 123 | tf_map_v = tf_state_dict[tf_k][tf_ind] 124 | # transpose the TF weight layout to the PyTorch layout, then check that shapes match 125 | if 'weight' in k: 126 | if tf_map_v.ndim == 4: 127 | tf_map_v = np.ascontiguousarray(tf_map_v.transpose(3, 2, 0, 1)) 128 | elif tf_map_v.ndim == 2: 129 | tf_map_v = np.ascontiguousarray(tf_map_v.transpose(1, 0)) 130 | assert tuple(v.shape) == tf_map_v.shape 131 | 132 | # exchange 133 | print("Exchange tf [{}:{}] to torch [{}]".format(tf_k, tf_ind, k)) 134 | torch_state_dict[k] = 
torch.from_numpy(tf_map_v) 135 | torch.save(torch_state_dict, "vgg_tf_2_torch.pth") 136 | 137 | if __name__ == '__main__': 138 | exchange() -------------------------------------------------------------------------------- /scripts/video2anime.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Author: wanhui0729@gmail.com 3 | 4 | import os, sys 5 | project_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 6 | sys.path.append(project_path) 7 | import cv2 8 | import torch 9 | import argparse 10 | import numpy as np 11 | from tqdm import tqdm 12 | from PIL import Image 13 | from animegan.configs import cfg 14 | from animegan.modeling.generator import build_generator 15 | from animegan.data.transforms.build import build_transforms 16 | from animegan.utils.model_serialization import load_state_dict 17 | from animegan.modeling.utils import adjust_brightness_from_src_to_dst 18 | 19 | 20 | def get_model(model_weight, device): 21 | model = build_generator(cfg) 22 | checkpoint = torch.load(model_weight, map_location=torch.device("cpu")) 23 | load_state_dict(model, checkpoint.pop("models").pop("generator")) 24 | model.to(device) 25 | return model 26 | 27 | def main(): 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument( 30 | "--config-file", 31 | default="", 32 | metavar="FILE", 33 | help="path to config file", 34 | type=str, 35 | ) 36 | parser.add_argument( 37 | "--video", 38 | type=str, 39 | required=True 40 | ) 41 | parser.add_argument( 42 | "opts", 43 | help="Modify config options using the command-line", 44 | default=None, 45 | nargs=argparse.REMAINDER, 46 | ) 47 | 48 | args = parser.parse_args() 49 | cfg.merge_from_file(args.config_file) 50 | cfg.merge_from_list(args.opts) 51 | cfg.freeze() 52 | video_path = args.video 53 | model_weight = cfg.MODEL.WEIGHT 54 | device = torch.device(cfg.MODEL.DEVICE) 55 | 56 | model = get_model(model_weight, device) 57 | model.eval() 58 | transform = build_transforms(cfg, False) 59 | 60 | videoCapture = cv2.VideoCapture(video_path) 61 | fps = int(videoCapture.get(cv2.CAP_PROP_FPS)) 62 | # w, h = int(videoCapture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(videoCapture.get(cv2.CAP_PROP_FRAME_HEIGHT)) 63 | # size = (int(w - w % 32), int(h - h % 32)) 64 | size = (1920, 1080) 65 | # fourcc = int(videoCapture.get(cv2.CAP_PROP_FOURCC)) 66 | fourcc = cv2.VideoWriter_fourcc(*'mp4v') 67 | frame_num = int(videoCapture.get(cv2.CAP_PROP_FRAME_COUNT)) 68 | videoWriter = cv2.VideoWriter(f"anime_{os.path.basename(video_path)}.mp4", fourcc, fps, size) 69 | for index in tqdm(range(frame_num)): 70 | success, frame = videoCapture.read() 71 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 72 | frame = cv2.resize(frame, size) 73 | input = Image.fromarray(frame) 74 | input = transform([input])[0][0].unsqueeze(0) 75 | input = input.to(device) 76 | with torch.no_grad(): 77 | pred = model(input).cpu() 78 | pred_img = (pred.squeeze() + 1.) 
/ 2 * 255 79 | pred_img = pred_img.permute(1, 2, 0).numpy().clip(0, 255).astype(np.uint8) 80 | pred_img = adjust_brightness_from_src_to_dst(pred_img, frame) 81 | video_frame = cv2.cvtColor(pred_img, cv2.COLOR_RGB2BGR) 82 | video_frame = cv2.resize(video_frame, size) 83 | # cv2.imshow("", video_frame) 84 | # cv2.waitKey() 85 | videoWriter.write(video_frame) 86 | videoCapture.release() 87 | videoWriter.release() 88 | 89 | if __name__ == '__main__': 90 | main() -------------------------------------------------------------------------------- /src/hayao/anime_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wan-h/AnimeGAN_pytorch/90c868853339194c35bc29295670a5437ba0adfc/src/hayao/anime_1.jpg -------------------------------------------------------------------------------- /src/hayao/anime_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wan-h/AnimeGAN_pytorch/90c868853339194c35bc29295670a5437ba0adfc/src/hayao/anime_2.jpg -------------------------------------------------------------------------------- /src/hayao/anime_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wan-h/AnimeGAN_pytorch/90c868853339194c35bc29295670a5437ba0adfc/src/hayao/anime_3.jpg -------------------------------------------------------------------------------- /src/shinkai/anime_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wan-h/AnimeGAN_pytorch/90c868853339194c35bc29295670a5437ba0adfc/src/shinkai/anime_1.jpg -------------------------------------------------------------------------------- /src/shinkai/anime_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wan-h/AnimeGAN_pytorch/90c868853339194c35bc29295670a5437ba0adfc/src/shinkai/anime_2.jpg -------------------------------------------------------------------------------- /src/shinkai/anime_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wan-h/AnimeGAN_pytorch/90c868853339194c35bc29295670a5437ba0adfc/src/shinkai/anime_3.jpg -------------------------------------------------------------------------------- /tools/train_net.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Author: wanhui0729@gmail.com 3 | 4 | # coding: utf-8 5 | # Author: wanhui0729@gmail.com 6 | 7 | import os 8 | import sys 9 | root_path = os.path.dirname(os.path.abspath(os.path.dirname(__file__))) 10 | sys.path.append(root_path) 11 | 12 | import argparse 13 | import torch 14 | # see https://github.com/pytorch/pytorch/issues/973 15 | # torch.multiprocessing.set_sharing_strategy('file_system') 16 | from animegan.configs import cfg 17 | from animegan.lib.trainer import train 18 | from animegan.utils.env import collect_env_info 19 | from animegan.utils.tm import generate_datetime_str 20 | from animegan.utils.comm import get_rank, synchronize 21 | from animegan.utils.logger import setup_logger 22 | 23 | 24 | def main(): 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument( 27 | "--config-file", 28 | default="", 29 | metavar="FILE", 30 | help="path to config file", 31 | type=str, 32 | ) 33 | parser.add_argument("--local_rank", type=int, default=0) 34 | # all remaining command-line arguments are collected into the opts list 35 | parser.add_argument( 36 | 
"opts", 37 | help="Modify config options using the command-line", 38 | default=None, 39 | nargs=argparse.REMAINDER, 40 | ) 41 | 42 | args = parser.parse_args() 43 | num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 44 | args.distributed = num_gpus > 1 45 | # init distributed 46 | if args.distributed: 47 | torch.cuda.set_device(args.local_rank) 48 | torch.distributed.init_process_group( 49 | backend="nccl", init_method="env://" 50 | ) 51 | 52 | cfg.merge_from_file(args.config_file) 53 | cfg.merge_from_list(args.opts) 54 | cfg.freeze() 55 | # 继续训练使用相同的目录结构 56 | if cfg.MODEL.WEIGHT and not cfg.MODEL.TRANSFER_LEARNING: 57 | weight_dir = cfg.MODEL.WEIGHT 58 | model_record_dir = os.path.dirname(weight_dir) 59 | output_dir = os.path.dirname(model_record_dir) 60 | else: 61 | output_dir = os.path.join(os.path.abspath(cfg.OUTPUT_DIR), generate_datetime_str(formate='%Y%m%d-%H%M%S')) 62 | if get_rank() == 0: 63 | os.makedirs(output_dir, exist_ok=True) 64 | synchronize() 65 | logger_name = "AnimeGan" 66 | logFile = os.path.join(output_dir, 'log.txt') 67 | logger = setup_logger(name=logger_name, distributed_rank=get_rank(), logFile=logFile) 68 | logger.info("Using {} GPUs".format(num_gpus)) 69 | logger.info(args) 70 | logger.info("Collecting env info (might take some time)") 71 | logger.info("\n" + collect_env_info()) 72 | logger.info("Loaded configuration file {}".format(args.config_file)) 73 | 74 | with open(args.config_file, "r") as cf: 75 | config_str = "\n" + cf.read() 76 | logger.info(config_str) 77 | logger.info("Running with config:\n{}".format(cfg)) 78 | 79 | train(cfg, args.local_rank, args.distributed, logger_name, output_dir) 80 | 81 | if __name__ == '__main__': 82 | main() --------------------------------------------------------------------------------