├── editing
│   ├── __init__.py
│   ├── config.py
│   └── util.py
├── third_party
│   └── __init__.py
├── tools
│   ├── inversion
│   │   └── __init__.py
│   ├── manipulation
│   │   ├── celeba_attributes.py
│   │   ├── analysis_channel.py
│   │   └── enhance.py
│   ├── transfer
│   │   ├── convert_stylegan2-ada_weights.py
│   │   └── convert_weight.py
│   ├── evaluate
│   │   ├── blur.py
│   │   └── fid.py
│   └── helper.py
├── documents
│   ├── images
│   │   └── tmp
│   │       ├── b0.jpg
│   │       ├── alter_paired_mdfc0_batch1.jpg
│   │       ├── before_Blode_mdfc0_batch1.jpg
│   │       ├── before_alter_Black_Hair_12_mdfc0_batch0.jpg
│   │       └── before_alter_Black_Hair_12_mdfc0_batch0_diff.jpg
│   └── environment.md
├── pretrained
│   └── modifications
│       ├── before_alter_Black_Hair_12.mdfc
│       ├── before_alter_Blond_Hair_12.mdfc
│       ├── before_single_channel_11_286.mdfc
│       └── before_Bushy_Eyebrows_s_123,315,325.mdfc
├── models
│   ├── __init__.py
│   ├── ada_ops
│   │   ├── __init__.py
│   │   ├── bias_act.h
│   │   ├── upfirdn2d.h
│   │   ├── fma.py
│   │   ├── grid_sample_gradfix.py
│   │   ├── bias_act.cpp
│   │   ├── upfirdn2d.cpp
│   │   ├── bias_act.cu
│   │   ├── conv2d_resample.py
│   │   ├── conv2d_gradfix.py
│   │   └── bias_act.py
│   ├── custom_ops.py
│   └── StyleGAN2_wrapper.py
├── raydl
│   ├── __init__.py
│   ├── handlers
│   │   ├── __init__.py
│   │   ├── util.py
│   │   ├── item_clock.py
│   │   ├── ema.py
│   │   └── running_statistics.py
│   ├── metrics
│   │   ├── __init__.py
│   │   ├── common.py
│   │   └── generation.py
│   ├── distributed.py
│   ├── misc.py
│   ├── engine.py
│   ├── collection.py
│   ├── fp16.py
│   ├── tensor.py
│   ├── information.py
│   ├── registry.py
│   └── io.py
├── .gitignore
├── environment.yml
└── README.md
--------------------------------------------------------------------------------
/editing/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/third_party/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tools/inversion/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/documents/images/tmp/b0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/budui/Control-Units-in-StyleGAN2/HEAD/documents/images/tmp/b0.jpg
--------------------------------------------------------------------------------
/documents/images/tmp/alter_paired_mdfc0_batch1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/budui/Control-Units-in-StyleGAN2/HEAD/documents/images/tmp/alter_paired_mdfc0_batch1.jpg
--------------------------------------------------------------------------------
/documents/images/tmp/before_Blode_mdfc0_batch1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/budui/Control-Units-in-StyleGAN2/HEAD/documents/images/tmp/before_Blode_mdfc0_batch1.jpg
--------------------------------------------------------------------------------
/pretrained/modifications/before_alter_Black_Hair_12.mdfc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/budui/Control-Units-in-StyleGAN2/HEAD/pretrained/modifications/before_alter_Black_Hair_12.mdfc
--------------------------------------------------------------------------------
/pretrained/modifications/before_alter_Blond_Hair_12.mdfc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/budui/Control-Units-in-StyleGAN2/HEAD/pretrained/modifications/before_alter_Blond_Hair_12.mdfc -------------------------------------------------------------------------------- /pretrained/modifications/before_single_channel_11_286.mdfc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/budui/Control-Units-in-StyleGAN2/HEAD/pretrained/modifications/before_single_channel_11_286.mdfc -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from raydl.registry import Registry 2 | 3 | MODEL = Registry("model") 4 | 5 | import models.StyleGAN2 6 | import models.StyleGAN2_wrapper 7 | import models.StyleGAN2_mine -------------------------------------------------------------------------------- /documents/images/tmp/before_alter_Black_Hair_12_mdfc0_batch0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/budui/Control-Units-in-StyleGAN2/HEAD/documents/images/tmp/before_alter_Black_Hair_12_mdfc0_batch0.jpg -------------------------------------------------------------------------------- /pretrained/modifications/before_Bushy_Eyebrows_s_123,315,325.mdfc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/budui/Control-Units-in-StyleGAN2/HEAD/pretrained/modifications/before_Bushy_Eyebrows_s_123,315,325.mdfc -------------------------------------------------------------------------------- /documents/images/tmp/before_alter_Black_Hair_12_mdfc0_batch0_diff.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/budui/Control-Units-in-StyleGAN2/HEAD/documents/images/tmp/before_alter_Black_Hair_12_mdfc0_batch0_diff.jpg -------------------------------------------------------------------------------- /raydl/__init__.py: -------------------------------------------------------------------------------- 1 | from .collection import * 2 | from .distributed import ddp_sync 3 | from .fp16 import auto_fp16 4 | from .handlers.util import disable_if_failed 5 | from .information import * 6 | from .io import * 7 | from .misc import * 8 | from .tensor import * 9 | -------------------------------------------------------------------------------- /raydl/handlers/__init__.py: -------------------------------------------------------------------------------- 1 | from .ema import ModelExponentialMovingAverage 2 | from .item_clock import TickEvent, ItemClock 3 | from .running_statistics import RunningStatistician 4 | from .time_profile import BasicTimeProfiler, HandlersTimeProfiler 5 | from .util import disable_if_failed 6 | -------------------------------------------------------------------------------- /models/ada_ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | # empty 10 |
--------------------------------------------------------------------------------
/raydl/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from ignite.engine.events import Events 2 | from ignite.metrics import MetricUsage 3 | 4 | from .common import Collector 5 | from .generation import FID 6 | 7 | 8 | class LifeWise(MetricUsage): 9 | usage_name: str = "life_wise" 10 | 11 | def __init__(self) -> None: 12 | super(LifeWise, self).__init__( 13 | started=Events.STARTED, 14 | completed=Events.COMPLETED, 15 | iteration_completed=Events.ITERATION_COMPLETED, 16 | ) 17 |
--------------------------------------------------------------------------------
/raydl/distributed.py:
--------------------------------------------------------------------------------
1 | import contextlib 2 | 3 | import torch 4 | 5 | 6 | @contextlib.contextmanager 7 | def ddp_sync(module: torch.nn.Module, sync: bool): 8 | """ 9 | Allow or suppress DDP gradient synchronization inside this context. 10 | :param module: if module is not a DDP-wrapped module, do nothing. 11 | :param sync: whether to synchronize across processes when module is DDP-wrapped 12 | :return: 13 | """ 14 | assert isinstance(module, torch.nn.Module) 15 | if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): 16 | yield 17 | else: 18 | with module.no_sync(): 19 | yield 20 |
--------------------------------------------------------------------------------
/raydl/handlers/util.py:
--------------------------------------------------------------------------------
1 | import functools 2 | 3 | from loguru import logger 4 | 5 | 6 | def disable_if_failed(func): 7 | """ 8 | Useful decorator for ignite handlers: if a handler raises an exception, it is disabled and skipped on subsequent trigger events. 9 | :param func: handler 10 | :return: handler 11 | """ 12 | _enable = True 13 | 14 | @functools.wraps(func) 15 | def wrapper(*args, **kwargs): 16 | nonlocal _enable 17 | if _enable: 18 | try: 19 | func(*args, **kwargs) 20 | except Exception as e: # pylint: disable=broad-except 21 | # Catch all Exception, log error and disable handler. 22 | _enable = False 23 | logger.warning(f"disable {func.__name__} due to error below") 24 | logger.exception(e) 25 | else: 26 | logger.debug(f"skip {func.__name__} because it failed in the last execution") 27 | 28 | return wrapper 29 |
--------------------------------------------------------------------------------
/models/ada_ops/bias_act.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | //------------------------------------------------------------------------ 10 | // CUDA kernel parameters.
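// (Descriptive note: the fused op behind these parameters computes, roughly, y = clamp(act(x + b, alpha) * gain);
// when grad > 0 the same kernel evaluates activation gradients, with xref/yref holding the saved forward
// input/output and dy the incoming gradient.)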
11 | 12 | struct bias_act_kernel_params 13 | { 14 | const void* x; // [sizeX] 15 | const void* b; // [sizeB] or NULL 16 | const void* xref; // [sizeX] or NULL 17 | const void* yref; // [sizeX] or NULL 18 | const void* dy; // [sizeX] or NULL 19 | void* y; // [sizeX] 20 | 21 | int grad; 22 | int act; 23 | float alpha; 24 | float gain; 25 | float clamp; 26 | 27 | int sizeX; 28 | int sizeB; 29 | int stepB; 30 | int loopX; 31 | }; 32 | 33 | //------------------------------------------------------------------------ 34 | // CUDA kernel selection. 35 | 36 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p); 37 | 38 | //------------------------------------------------------------------------ 39 |
--------------------------------------------------------------------------------
/documents/environment.md:
--------------------------------------------------------------------------------
1 | ## How to set up the environment? 2 | 3 | ### From scratch 4 | 5 | First, install the Python libraries: 6 | 7 | ```bash 8 | # CUDA 11 9 | conda install pytorch torchvision ignite cudatoolkit=11.1 -c pytorch -c nvidia 10 | # CUDA 10.1 11 | conda install -y pytorch ignite torchvision cudatoolkit=10.1 -c pytorch 12 | 13 | # with pip 14 | pip install opencv-python ipython loguru fire tqdm lmdb omegaconf numpy scipy matplotlib pandas 15 | # with conda 16 | conda install -c conda-forge ipython loguru fire tqdm python-lmdb omegaconf 17 | 18 | # opencv must be installed with pip 19 | pip install opencv-python 20 | ``` 21 | 22 | Then, install useful tools: 23 | 24 | ```bash 25 | conda install -c conda-forge ipdb tensorboard 26 | ``` 27 | 28 | Finally, update the gcc version if you use `centos/tlinux`: 29 | 30 | ```bash 31 | # install scl and gcc tools 32 | yum install tlinux-release-scl 33 | yum install devtoolset-7-gcc-c++.x86_64 34 | # enable the newer gcc with `scl` 35 | scl enable devtoolset-7 bash 36 | # replace c++ with g++ to avoid PyTorch warnings. 37 | CCPATH=$(which c++); mv $CCPATH "$CCPATH".backup; ln -s $(which g++) $CCPATH 38 | ``` 39 | 40 | ## Useful tools 41 | 42 | Check & kill zombie processes if multi-process tools exit unexpectedly. 43 | 44 | ```bash 45 | # check zombie processes 46 | ps aux | grep train_generator.py | grep -v grep | awk '{print $2}' 47 | # kill them 48 | ps aux | grep train_generator.py | grep -v grep | awk '{print $2}' | xargs kill -9 49 | ``` 50 | 51 | Show the process tree: `ps auxf` 52 | 53 | See where a process hangs: `strace -p <pid>` 54 | 55 | ```text 56 | ipdb> from IPython import embed 57 | ipdb> embed() # drop into an IPython session. 58 | # Any variables you define or modify here 59 | # will not affect program execution 60 | ```
--------------------------------------------------------------------------------
/models/ada_ops/upfirdn2d.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include <cuda_runtime.h> 10 | 11 | //------------------------------------------------------------------------ 12 | // CUDA kernel parameters.
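// (Descriptive note: upfirdn2d pads, upsamples by `up`, convolves with the 2-D FIR filter `f`, and downsamples
// by `down` in a single fused pass; the fields below carry the tensor geometry plus the loop/launch factors
// used to tile that work across CUDA blocks.)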
13 | 14 | struct upfirdn2d_kernel_params 15 | { 16 | const void* x; 17 | const float* f; 18 | void* y; 19 | 20 | int2 up; 21 | int2 down; 22 | int2 pad0; 23 | int flip; 24 | float gain; 25 | 26 | int4 inSize; // [width, height, channel, batch] 27 | int4 inStride; 28 | int2 filterSize; // [width, height] 29 | int2 filterStride; 30 | int4 outSize; // [width, height, channel, batch] 31 | int4 outStride; 32 | int sizeMinor; 33 | int sizeMajor; 34 | 35 | int loopMinor; 36 | int loopMajor; 37 | int loopX; 38 | int launchMinor; 39 | int launchMajor; 40 | }; 41 | 42 | //------------------------------------------------------------------------ 43 | // CUDA kernel specialization. 44 | 45 | struct upfirdn2d_kernel_spec 46 | { 47 | void* kernel; 48 | int tileOutW; 49 | int tileOutH; 50 | int loopMinor; 51 | int loopX; 52 | }; 53 | 54 | //------------------------------------------------------------------------ 55 | // CUDA kernel selection. 56 | 57 | template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); 58 | 59 | //------------------------------------------------------------------------ 60 | -------------------------------------------------------------------------------- /raydl/handlers/item_clock.py: -------------------------------------------------------------------------------- 1 | import ignite.distributed as idist 2 | from ignite.engine import Engine, Events, EventEnum 3 | 4 | 5 | class TickEvent(EventEnum): 6 | TICK_STARTED = "tick_started" 7 | TICK_COMPLETED = "tick_completed" 8 | 9 | 10 | class ItemClock: 11 | def __init__( 12 | self, 13 | num_items_per_tick=1024, 14 | batch_num_items_transform=lambda x: len(x), 15 | ): 16 | assert isinstance(num_items_per_tick, int) and num_items_per_tick > 0 17 | self.tick_size = num_items_per_tick 18 | self.batch_num_items_transform = batch_num_items_transform 19 | self.cur_tick = 1 20 | 21 | def _init(self, engine: Engine): 22 | if not hasattr(engine.state, "num_items"): 23 | engine.state.num_items = 0 24 | engine.state.tick = (engine.state.num_items + self.tick_size - 1) // self.tick_size 25 | 26 | def attach(self, engine: Engine): 27 | engine.state_dict_user_keys.append("num_items") 28 | # engine.state_dict_user_keys.append("tick") 29 | 30 | self._init(engine) 31 | engine.add_event_handler(Events.STARTED, self._init) 32 | 33 | @engine.on(Events.ITERATION_COMPLETED) 34 | def update_num_items(e: Engine): 35 | cur_items = self.batch_num_items_transform(e.state.batch) 36 | e.state.num_items += cur_items * idist.get_world_size() 37 | 38 | engine.register_events( 39 | *TickEvent, 40 | event_to_attr={ 41 | TickEvent.TICK_STARTED: "tick", 42 | TickEvent.TICK_COMPLETED: "tick", 43 | } 44 | ) 45 | 46 | @engine.on(Events.ITERATION_STARTED) 47 | def tick_start(e: Engine): 48 | s = e.state 49 | if (s.num_items + self.tick_size) // self.tick_size == s.tick + 1: 50 | s.tick += 1 51 | e.fire_event(TickEvent.TICK_STARTED) 52 | 53 | @engine.on(Events.ITERATION_COMPLETED) 54 | def tick_completed(e: Engine): 55 | s = e.state 56 | if s.num_items // self.tick_size == s.tick: 57 | e.fire_event(TickEvent.TICK_COMPLETED) 58 | -------------------------------------------------------------------------------- /tools/manipulation/celeba_attributes.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from pathlib import Path 3 | 4 | import fire 5 | import numpy as np 6 | import torch 7 | 8 | from editing.config import CELEBA_ATTRS 9 | 10 | 11 | class Worker: 12 | def 
__init__(self, record_path="./data/interfacegan_dataset.npz", device="cuda"): 13 | records = np.load(record_path) 14 | # 500000x40 15 | self.logit_bank = torch.from_numpy(records["logit"]).to(device, torch.float) 16 | # 500000x512 17 | self.w_bank = torch.from_numpy(records["w"]).to(device, torch.float) 18 | # 40 19 | self.std_bank = self.logit_bank.std(dim=0) 20 | 21 | def random_samples(self, attr_id, num_samples=1000, sample_range=(0.25, 0.75)): 22 | torch.manual_seed(0) 23 | num_total = len(self.logit_bank) 24 | index_pool = self.logit_bank[:, attr_id].topk(num_total, largest=False)[1][ 25 | int(num_total * sample_range[0]): int(num_total * sample_range[1])] 26 | torch.manual_seed(0) 27 | return index_pool[torch.randperm(len(index_pool))[:num_samples]] 28 | 29 | 30 | def generate_examples(): 31 | worker = Worker() 32 | examples = defaultdict(list) 33 | for attr_id in range(40): 34 | for level in range(10): 35 | examples[attr_id].append(worker.w_bank[worker.random_samples(attr_id, 100, (level / 10, level / 10 + 0.1))]) 36 | torch.save(examples, "./data/attribute_examples.pt") 37 | 38 | 39 | def sample(attr_id, num_samples, sample_range=(0.25, 0.75), save_path=None): 40 | assert num_samples < (sample_range[1] - sample_range[0]) * 500000 41 | worker = Worker() 42 | ws = worker.w_bank[worker.random_samples(attr_id, num_samples, sample_range)] 43 | 44 | default_save_name = f"{CELEBA_ATTRS[attr_id]}_{sample_range[0]}-{sample_range[1]}_{num_samples}.w" 45 | if save_path is None: 46 | save_path = default_save_name 47 | else: 48 | save_path = Path(save_path) 49 | if save_path.is_dir(): 50 | save_path = save_path / default_save_name 51 | 52 | torch.save(ws, save_path) 53 | 54 | 55 | if __name__ == '__main__': 56 | fire.Fire() 57 | -------------------------------------------------------------------------------- /models/ada_ops/fma.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" 10 | 11 | import torch 12 | 13 | 14 | # ---------------------------------------------------------------------------- 15 | 16 | def fma(a, b, c): # => a * b + c 17 | return _FusedMultiplyAdd.apply(a, b, c) 18 | 19 | 20 | # ---------------------------------------------------------------------------- 21 | 22 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c 23 | @staticmethod 24 | def forward(ctx, a, b, c): # pylint: disable=arguments-differ 25 | out = torch.addcmul(c, a, b) 26 | ctx.save_for_backward(a, b) 27 | ctx.c_shape = c.shape 28 | return out 29 | 30 | @staticmethod 31 | def backward(ctx, dout): # pylint: disable=arguments-differ 32 | a, b = ctx.saved_tensors 33 | c_shape = ctx.c_shape 34 | da = None 35 | db = None 36 | dc = None 37 | 38 | if ctx.needs_input_grad[0]: 39 | da = _unbroadcast(dout * b, a.shape) 40 | 41 | if ctx.needs_input_grad[1]: 42 | db = _unbroadcast(dout * a, b.shape) 43 | 44 | if ctx.needs_input_grad[2]: 45 | dc = _unbroadcast(dout, c_shape) 46 | 47 | return da, db, dc 48 | 49 | 50 | # ---------------------------------------------------------------------------- 51 | 52 | def _unbroadcast(x, shape): 53 | extra_dims = x.ndim - len(shape) 54 | assert extra_dims >= 0 55 | dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] 56 | if len(dim): 57 | x = x.sum(dim=dim, keepdim=True) 58 | if extra_dims: 59 | x = x.reshape(-1, *x.shape[extra_dims + 1:]) 60 | assert x.shape == shape 61 | return x 62 | 63 | # ---------------------------------------------------------------------------- 64 | -------------------------------------------------------------------------------- /raydl/handlers/ema.py: -------------------------------------------------------------------------------- 1 | import ignite.distributed as idist 2 | import torch 3 | from ignite.handlers import Engine, Events 4 | 5 | 6 | class ModelExponentialMovingAverage: 7 | def __init__(self, model, model_ema, num_items_of_half_life, batch_size_per_gpu, rampup=None, 8 | num_items_transform=None): 9 | assert isinstance(model, torch.nn.Module) 10 | assert isinstance(model_ema, torch.nn.Module) 11 | assert isinstance(num_items_of_half_life, int) and num_items_of_half_life > 0 12 | assert num_items_transform is None or callable(num_items_transform) 13 | 14 | self.model = model 15 | self.model_ema = model_ema 16 | self.num_items_of_half_life = num_items_of_half_life 17 | self.rampup = rampup 18 | self.batch_size_per_gpu = batch_size_per_gpu 19 | self.num_items_transform = num_items_transform 20 | 21 | @staticmethod 22 | @torch.no_grad() 23 | def accumulate(model_ema, model, decay=0.999): 24 | for p_ema, p in zip(model_ema.parameters(), model.parameters()): 25 | p_ema.copy_(p.lerp(p_ema, decay)) 26 | for b_ema, b in zip(model_ema.buffers(), model.buffers()): 27 | b_ema.copy_(b) 28 | 29 | def update(self, engine): 30 | if self.num_items_transform is None: 31 | num_items = engine.state.num_items 32 | else: 33 | num_items = self.num_items_transform(engine) 34 | world_size = engine.state.world_size if hasattr(engine.state, "world_size") else idist.get_world_size() 35 | ema_num_items = self.num_items_of_half_life 36 | if self.rampup is not None: 37 | ema_num_items = min(ema_num_items, num_items * self.rampup) 38 | # The half-life is the time lag at which the exponential weights decay by one half. 
39 | # (ema_beta)^(num_items)=0.5 40 | ema_beta = 0.5 ** (self.batch_size_per_gpu * world_size / max(ema_num_items, 1e-8)) 41 | self.accumulate(self.model_ema, self.model, decay=ema_beta) 42 | 43 | def init(self): 44 | # copy params from g to g_ema. 45 | # need to call this after load_state_dict 46 | # make sure g's params are not updated between this call and the first iteration. 47 | self.accumulate(self.model_ema, self.model, 0) 48 | 49 | def attach(self, engine: Engine, event=Events.ITERATION_COMPLETED): 50 | engine.add_event_handler(Events.STARTED, self.init) 51 | engine.add_event_handler(event, self.update) 52 |
--------------------------------------------------------------------------------
/editing/config.py:
--------------------------------------------------------------------------------
1 | CHECKPOINT_PATH = "./pretrained/stylegan2-ffhq-config-f.pt" 2 | FACE_PARSER_CKP = "pretrained/BiSetNet.pth" 3 | CORRECTION_PATH = "pretrained/correction.pt" 4 | STATISTICS_PATH = "" 5 | CLASSIFIER_CKP = "pretrained/Attribute_CelebAMask-HQ_40_classifier.pth" 6 | RECORD_PATH = "" 7 | 8 | MAIN_LAYER = [0, 2, 3, 5, 6, 8, 9, 11, 12, 14, 15, 17, 18, 20, 21, 23, 24] 9 | TO_RGB_LAYER = [1, 4, 7, 10, 13, 16, 19, 22, 25] 10 | NUM_MAIN_LAYER_CHANNEL = [512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 256, 256, 128, 128, 64, 64, 32] 11 | NUM_TO_RGB_CHANNEL = [512, 512, 512, 512, 512, 256, 128, 64, 32] 12 | CHANNEL_MAP = dict(zip(MAIN_LAYER + TO_RGB_LAYER, NUM_MAIN_LAYER_CHANNEL + NUM_TO_RGB_CHANNEL)) 13 | 14 | SEMANTIC_REGION = dict( 15 | [ 16 | ("background", (0,)), # 0 17 | ("brow", (1, 2)), # 1 18 | ("eye", (3, 4)), # 2 19 | ("glass", (5,)), # 3 20 | ("ear", (6, 7, 8)), # 4 21 | ("nose", (9,)), # 5 22 | ("mouth", (10,)), # 6 23 | ("lips", (11, 12)), # 7 24 | ("neck", (13, 14)), # 8 25 | ("cloth", (15,)), # 9 26 | ("hair", (16,)), # 10 27 | ("hat", (17,)), # 11 28 | ("face_up", (18,)), # 12 29 | ("face_middle", (19,)), # 13 30 | ("face_down", (20,)), # 14 31 | ] 32 | ) 33 | 34 | CELEBA_ATTRS = [ 35 | "5_o_Clock_Shadow", # 0 stubble (5 o'clock shadow) 36 | "Arched_Eyebrows", # 1 arched eyebrows 37 | "Attractive", # 2 attractive 38 | "Bags_Under_Eyes", # 3 bags under the eyes 39 | "Bald", # 4 bald 40 | "Bangs", # 5 bangs 41 | "Big_Lips", # 6 big lips 42 | "Big_Nose", # 7 big nose 43 | "Black_Hair", # 8 black hair 44 | "Blond_Hair", # 9 blond hair 45 | "Blurry", # 10 blurry 46 | "Brown_Hair", # 11 brown hair 47 | "Bushy_Eyebrows", # 12 bushy eyebrows 48 | "Chubby", # 13 chubby 49 | "Double_Chin", # 14 double chin 50 | "Eyeglasses", # 15 eyeglasses 51 | "Goatee", # 16 goatee 52 | "Gray_Hair", # 17 gray hair 53 | "Heavy_Makeup", # 18 heavy makeup 54 | "High_Cheekbones", # 19 high cheekbones 55 | "Male", # 20 male 56 | "Mouth_Slightly_Open", # 21 mouth slightly open 57 | "Mustache", # 22 mustache 58 | "Narrow_Eyes", # 23 narrow eyes 59 | "No_Beard", # 24 no beard 60 | "Oval_Face", # 25 oval face 61 | "Pale_Skin", # 26 pale skin 62 | "Pointy_Nose", # 27 pointy nose 63 | "Receding_Hairline", # 28 receding hairline 64 | "Rosy_Cheeks", # 29 rosy cheeks 65 | "Sideburns", # 30 sideburns 66 | "Smiling", # 31 smiling 67 | "Straight_Hair", # 32 straight hair 68 | "Wavy_Hair", # 33 wavy hair 69 | "Wearing_Earrings", # 34 wearing earrings 70 | "Wearing_Hat", # 35 wearing a hat 71 | "Wearing_Lipstick", # 36 wearing lipstick 72 | "Wearing_Necklace", # 37 wearing a necklace 73 | "Wearing_Necktie", # 38 wearing a necktie 74 | "Young", # 39 young 75 | ] 76 |
--------------------------------------------------------------------------------
/raydl/misc.py:
--------------------------------------------------------------------------------
1 | from contextlib import contextmanager 2 | from typing import Sequence 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from ignite.handlers import Timer 7 | 8 | __all__ = [ 9 |
"running_time", 10 | "empty_cuda_cache", 11 | "total_chunk", 12 | "factors_sequence", 13 | "residual_grid_sample", 14 | "assert_shape", 15 | "sequence_chunk" 16 | ] 17 | 18 | 19 | @contextmanager 20 | def running_time(name="this-block", verbose=False): 21 | t = Timer() 22 | yield t 23 | if verbose: 24 | print(f"{name} consume {t.value()}s") 25 | 26 | 27 | def empty_cuda_cache() -> None: 28 | torch.cuda.empty_cache() 29 | import gc 30 | 31 | gc.collect() 32 | 33 | 34 | def assert_shape(x: torch.Tensor, shape): 35 | assert x.ndim == len(shape), f"shape has dim {len(shape)}, but x has dim {x.ndim}" 36 | for s, es in zip(x.shape, shape): 37 | if es is None: 38 | continue 39 | assert s == es, f"x have shape {x.size()} but the expected shape is {shape}." 40 | 41 | 42 | def total_chunk(total, chunk_size, drop_last=False): 43 | seen_amount = 0 44 | while True: 45 | if seen_amount >= total: 46 | break 47 | cur_chunk_size = min(total - seen_amount, chunk_size) 48 | if drop_last and cur_chunk_size < chunk_size: 49 | break 50 | yield cur_chunk_size 51 | seen_amount += cur_chunk_size 52 | 53 | 54 | def sequence_chunk(sequence: Sequence, chunk_size, drop_last=False): 55 | total = len(sequence) 56 | i = 0 57 | for batch in total_chunk(total, chunk_size, drop_last): 58 | yield sequence[i:batch + i] 59 | i += batch 60 | 61 | 62 | def factors_sequence(end_factor, num_factors, start_factor=None, paired_factor=True): 63 | assert num_factors > 0 64 | if num_factors == 1: 65 | return [end_factor, ] 66 | start_factor = start_factor if start_factor is not None else (-end_factor if paired_factor else 0.0) 67 | delta = (end_factor - start_factor) / (num_factors - 1) 68 | return [start_factor + i * delta for i in range(num_factors)] 69 | 70 | 71 | def residual_grid_sample(image, residual_grid, mode='nearest', padding_mode='reflection', align_corners=False): 72 | n, c, h, w = image.size() 73 | assert_shape(residual_grid, (n, 2, h, w)) 74 | 75 | device = image.device 76 | dtype = image.dtype 77 | grid_h = torch.arange(start=-1, end=1 + 1e-10, step=2 / (h - 1), device=device, dtype=dtype) 78 | grid_w = torch.arange(start=-1, end=1 + 1e-10, step=2 / (w - 1), device=device, dtype=dtype) 79 | zero_grid = torch.stack(torch.meshgrid(grid_h, grid_w)[::-1]) 80 | grid = (residual_grid + zero_grid.unsqueeze(dim=0)).permute(0, 2, 3, 1) 81 | return F.grid_sample(image, grid, mode, padding_mode, align_corners) 82 | -------------------------------------------------------------------------------- /tools/manipulation/analysis_channel.py: -------------------------------------------------------------------------------- 1 | import fire 2 | import ignite.distributed as idist 3 | import torch 4 | from loguru import logger 5 | from tqdm import tqdm 6 | 7 | from editing.config import FACE_PARSER_CKP 8 | from models.StyleGAN2_wrapper import ImprovedStyleGAN2Generator 9 | from third_party.BiSetNet import FaceParser 10 | 11 | 12 | def main(checkpoint, save_path, num_samples=10000, batch_size=8, truncation=0.7): 13 | device = idist.device() 14 | G = ImprovedStyleGAN2Generator.load(checkpoint, device=device, default_truncation=truncation) 15 | G.manipulation_mode() 16 | 17 | logger.info("load generator over") 18 | 19 | face_parser = FaceParser(model_path=FACE_PARSER_CKP, device=device) 20 | 21 | logger.info("load face_parser over") 22 | num_batch = (num_samples + batch_size - 1) // batch_size 23 | batch_id = num_batch 24 | 25 | style_grads = None 26 | style_grad_num = None 27 | 28 | pbar = tqdm(total=num_batch, ncols=0) 29 | 30 | while batch_id 
> 0: 31 | z = torch.randn(batch_size, G.style_dim, device=device) 32 | w = G.z_to_w(z) 33 | styles = G.w_to_styles(w) 34 | 35 | styles = [s.detach().requires_grad_(True) for s in styles] 36 | images = G.styles_to_image(styles) 37 | 38 | with torch.no_grad(): 39 | parsing = face_parser.batch_run(images, pre_normalize=True, image_repr=False, compact_mask=True) 40 | if parsing is None: 41 | continue 42 | 43 | if style_grads is None: 44 | style_grads = [[torch.zeros(s.size(-1), device=device) for _ in range(parsing.size(1))] for s in styles] 45 | style_grad_num = [[0 for _ in range(parsing.size(1))] for _ in styles] 46 | 47 | for mask_id in range(parsing.size(1)): 48 | G.zero_grad() 49 | for s in styles: 50 | s.grad = None 51 | grad_map = parsing[:, [mask_id, ]].repeat(1, 3, 1, 1).float() 52 | grad_map /= grad_map.abs().sum(dim=[1, 2, 3], keepdim=True).clip_(1e-5) 53 | 54 | # some mask result may not contains any content, e.g. full of 0. 55 | num_valid = (grad_map.sum(dim=[-1, -2, -3]) > 0).sum() 56 | images.backward(grad_map, retain_graph=True) 57 | 58 | for i, s in enumerate(styles): 59 | style_grads[i][mask_id] += s.grad.abs().sum(dim=[0]) 60 | style_grad_num[i][mask_id] += num_valid 61 | 62 | batch_id -= 1 63 | pbar.update(1) 64 | pbar.close() 65 | 66 | channel_correction = [] 67 | print(','.join(map(str, [float(c) / (num_batch * batch_size) for c in style_grad_num[0]]))) 68 | for layer in range(len(style_grads)): 69 | channel_correction.append(torch.stack([c.div(n) for c, n in zip(style_grads[layer], style_grad_num[layer])])) 70 | torch.save(channel_correction, save_path) 71 | 72 | 73 | if __name__ == '__main__': 74 | fire.Fire(main) 75 | -------------------------------------------------------------------------------- /raydl/engine.py: -------------------------------------------------------------------------------- 1 | import time 2 | from typing import Any, Callable 3 | 4 | from ignite.engine import Engine, Events 5 | from loguru import logger 6 | 7 | 8 | class RayEngine(Engine): 9 | """ 10 | add debug method for ignite.engine.Engine. if set debug as True, will also print event log. 
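When debug is enabled, every fired event is logged as an indented tree (one level per nested event), each attached handler is listed as it runs, and the total handler time is reported when the event completes.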
11 | """ 12 | 13 | def __init__(self, process_function: Callable, debug=False): 14 | super().__init__(process_function) 15 | self._firing_event_depth = 0 16 | self.debug = debug 17 | if self.debug: 18 | self.logger = logger.bind(event=True) 19 | 20 | self.add_event_handler(Events.STARTED, self._reset_firing_event_depth) 21 | 22 | self.state.attrs_to_log = [] 23 | 24 | def _reset_firing_event_depth(self): 25 | self._firing_event_depth = 0 26 | 27 | @property 28 | def _current_status(self) -> str: 29 | if hasattr(self.state, "num_items"): 30 | return f"" 31 | 32 | last_event = self.last_event_name 33 | if last_event in self.state.event_to_attr: 34 | s = self.state 35 | name = self.state.event_to_attr[last_event] 36 | return f"<{name}={getattr(s, name)}>" 37 | else: 38 | return f"" 39 | 40 | def _fire_event(self, event_name: Any, *event_args: Any, **event_kwargs: Any) -> None: 41 | """ 42 | reset logger info 43 | :param event_name: 44 | :param event_args: 45 | :param event_kwargs: 46 | :return: 47 | """ 48 | if not self.debug: 49 | return super(RayEngine, self)._fire_event(event_name, *event_args, **event_kwargs) 50 | 51 | padding = " " * 6 52 | self.last_event_name = event_name 53 | self._firing_event_depth += 1 54 | start_time = None 55 | if len(self._event_handlers[event_name]): 56 | prefix = max(self._firing_event_depth - 2, 0) * f"│{padding}" + ( 57 | 1 if self._firing_event_depth >= 2 else 0) * f"{padding}├─ " 58 | self.logger.debug(f"{prefix}firing {event_name.name} at {self._current_status}") 59 | start_time = time.time() 60 | prefix = max(self._firing_event_depth - 1, 0) * f"{padding}│" + padding 61 | for func, args, kwargs in self._event_handlers[event_name]: 62 | kwargs.update(event_kwargs) 63 | first, others = ((args[0],), args[1:]) if (args and args[0] == self) else ((), args) 64 | self.logger.debug(f"{prefix}├─ {repr(func)}") 65 | func(*first, *(event_args + others), **kwargs) 66 | self._firing_event_depth -= 1 67 | if start_time is not None: 68 | overall_time = time.time() - start_time 69 | timer_str = f"{overall_time * 1000:.4f}ms" if overall_time * 1000 < 1 else f"{overall_time:.4f}s" 70 | self.logger.debug(f"{prefix}└── finish in {timer_str}") 71 | -------------------------------------------------------------------------------- /raydl/handlers/running_statistics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class RunningStatistician: 5 | _support_statistics = ("min/index", "max/index", "mean", "std") 6 | 7 | def __init__(self, statistic_names=_support_statistics): 8 | assert all(x in self._support_statistics for x in statistic_names), \ 9 | f"only support statistic in {self._support_statistics}, but got {statistic_names}" 10 | self.names = statistic_names 11 | self._dtype = torch.float64 12 | 13 | self._s_counter = 0 14 | self._s_min, self._s_min_index = None, None 15 | self._s_max, self._s_max_index = None, None 16 | self._s_sum = torch.as_tensor(0.0).to(self._dtype) 17 | self._s_square = torch.as_tensor(0.0).to(self._dtype) 18 | 19 | def sum(self): 20 | return self._s_sum 21 | 22 | def __len__(self): 23 | return self._s_counter 24 | 25 | def internal_status(self): 26 | return {k: v for k, v in self.__dict__.items() if k.startswith("_s_")} 27 | 28 | def reset(self): 29 | self._s_counter = 0 30 | self._s_min, self._s_min_index = None, None 31 | self._s_max, self._s_max_index = None, None 32 | self._s_sum = torch.as_tensor(0.0).to(self._dtype) 33 | self._s_square = torch.as_tensor(0.0).to(self._dtype) 34 | 35 
| def update(self, x): 36 | element = torch.as_tensor(x) 37 | if torch.numel(element) == 0: 38 | return x 39 | 40 | element = element.detach().flatten().to(self._dtype) 41 | self._s_counter += int(torch.ones_like(element).sum()) 42 | 43 | if "min/index" in self.names: 44 | min_element = element.min() 45 | self._s_min, self._s_min_index = (min_element, self._s_counter) if \ 46 | self._s_min is None or self._s_min > min_element else (self._s_min, self._s_min_index) 47 | 48 | if "max/index" in self.names: 49 | max_element = element.max() 50 | self._s_max, self._s_max_index = (max_element, self._s_counter) if \ 51 | self._s_max is None or self._s_max < max_element else (self._s_max, self._s_max_index) 52 | 53 | if "mean" in self.names or "std" in self.names: 54 | self._s_sum += element.sum() 55 | 56 | if "std" in self.names: 57 | self._s_square += element.square().sum() 58 | 59 | def compute(self): 60 | amount = int(self._s_counter) 61 | output = dict() 62 | for name in self.names: 63 | if name == "min/index": 64 | output["min/index"] = (float(self._s_min.cpu()), self._s_min_index) 65 | if name == "max/index": 66 | output["max/index"] = (float(self._s_max.cpu()), self._s_max_index) 67 | if name == "mean": 68 | if amount == 0: 69 | output["mean"] = float("nan") 70 | else: 71 | output["mean"] = float(self._s_sum.cpu()) / amount 72 | if name == "std": 73 | if amount == 0: 74 | output["std"] = float("nan") 75 | elif amount == 1: 76 | output["std"] = 0.0 77 | else: 78 | mean = output.get("mean", float(self._s_sum.cpu()) / amount) 79 | output["std"] = float(((self._s_square - amount * mean ** 2) / (amount - 1)).sqrt().cpu()) 80 | return output 81 | -------------------------------------------------------------------------------- /models/ada_ops/grid_sample_gradfix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom replacement for `torch.nn.functional.grid_sample` that 10 | supports arbitrarily high order gradients between the input and output. 11 | Only works on 2D images and assumes 12 | `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" 13 | 14 | import warnings 15 | 16 | import torch 17 | import torch.nn.functional 18 | 19 | # pylint: disable=redefined-builtin 20 | # pylint: disable=arguments-differ 21 | # pylint: disable=protected-access 22 | 23 | # ---------------------------------------------------------------------------- 24 | 25 | enabled = False # Enable the custom op by setting this to true. 
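# Usage sketch (illustrative, with the import path taken from this repository's layout): callers opt in by
# flipping the module-level flag before sampling, e.g.
#   from models.ada_ops import grid_sample_gradfix
#   grid_sample_gradfix.enabled = True  # use the custom op with higher-order gradient support
#   warped = grid_sample_gradfix.grid_sample(images, grid)  # images: NxCxHxW, grid: NxHxWx2
# If `enabled` is False or the PyTorch version is unsupported, grid_sample() falls back to
# torch.nn.functional.grid_sample with mode='bilinear', padding_mode='zeros', align_corners=False.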
26 | 27 | 28 | # ---------------------------------------------------------------------------- 29 | 30 | def grid_sample(input, grid): 31 | if _should_use_custom_op(): 32 | return _GridSample2dForward.apply(input, grid) 33 | return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', 34 | align_corners=False) 35 | 36 | 37 | # ---------------------------------------------------------------------------- 38 | 39 | def _should_use_custom_op(): 40 | if not enabled: 41 | return False 42 | if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']): 43 | return True 44 | warnings.warn( 45 | f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().') 46 | return False 47 | 48 | 49 | # ---------------------------------------------------------------------------- 50 | 51 | class _GridSample2dForward(torch.autograd.Function): 52 | @staticmethod 53 | def forward(ctx, input, grid): 54 | assert input.ndim == 4 55 | assert grid.ndim == 4 56 | output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', 57 | align_corners=False) 58 | ctx.save_for_backward(input, grid) 59 | return output 60 | 61 | @staticmethod 62 | def backward(ctx, grad_output): 63 | input, grid = ctx.saved_tensors 64 | grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid) 65 | return grad_input, grad_grid 66 | 67 | 68 | # ---------------------------------------------------------------------------- 69 | 70 | class _GridSample2dBackward(torch.autograd.Function): 71 | @staticmethod 72 | def forward(ctx, grad_output, input, grid): 73 | op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') 74 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) 75 | ctx.save_for_backward(grid) 76 | return grad_input, grad_grid 77 | 78 | @staticmethod 79 | def backward(ctx, grad2_grad_input, grad2_grad_grid): 80 | _ = grad2_grad_grid # unused 81 | grid, = ctx.saved_tensors 82 | grad2_grad_output = None 83 | grad2_input = None 84 | grad2_grid = None 85 | 86 | if ctx.needs_input_grad[0]: 87 | grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid) 88 | 89 | assert not ctx.needs_input_grad[2] 90 | return grad2_grad_output, grad2_input, grad2_grid 91 | 92 | # ---------------------------------------------------------------------------- 93 | -------------------------------------------------------------------------------- /tools/transfer/convert_stylegan2-ada_weights.py: -------------------------------------------------------------------------------- 1 | import re 2 | import sys 3 | 4 | import fire 5 | import numpy as np 6 | import torch 7 | 8 | 9 | def convert_to_rgb(state_ros, state_nv, ros_name, nv_name): 10 | state_ros[f"{ros_name}.conv.weight"] = state_nv[f"{nv_name}.torgb.weight"].unsqueeze(0) 11 | state_ros[f"{ros_name}.bias"] = state_nv[f"{nv_name}.torgb.bias"].unsqueeze(0).unsqueeze(-1).unsqueeze(-1) 12 | state_ros[f"{ros_name}.conv.modulation.weight"] = state_nv[f"{nv_name}.torgb.affine.weight"] 13 | state_ros[f"{ros_name}.conv.modulation.bias"] = state_nv[f"{nv_name}.torgb.affine.bias"] 14 | 15 | 16 | def convert_conv(state_ros, state_nv, ros_name, nv_name): 17 | state_ros[f"{ros_name}.conv.weight"] = state_nv[f"{nv_name}.weight"].unsqueeze(0) 18 | state_ros[f"{ros_name}.activate.bias"] = state_nv[f"{nv_name}.bias"] 19 | state_ros[f"{ros_name}.conv.modulation.weight"] = state_nv[f"{nv_name}.affine.weight"] 20 | 
state_ros[f"{ros_name}.conv.modulation.bias"] = state_nv[f"{nv_name}.affine.bias"] 21 | state_ros[f"{ros_name}.noise.weight"] = state_nv[f"{nv_name}.noise_strength"].unsqueeze(0) 22 | 23 | 24 | def convert_blur_kernel(state_ros, state_nv, level): 25 | """Not quite sure why there is a factor of 4 here""" 26 | # They are all the same 27 | state_ros[f"convs.{2 * level}.conv.blur.kernel"] = 4 * state_nv["synthesis.b4.resample_filter"] 28 | state_ros[f"to_rgbs.{level}.upsample.kernel"] = 4 * state_nv["synthesis.b4.resample_filter"] 29 | 30 | 31 | def determine_config(state_nv): 32 | mapping_names = [name for name in state_nv.keys() if "mapping.fc" in name] 33 | sythesis_names = [name for name in state_nv.keys() if "synthesis.b" in name] 34 | 35 | n_mapping = max([int(re.findall("(\d+)", n)[0]) for n in mapping_names]) + 1 36 | resolution = max([int(re.findall("(\d+)", n)[0]) for n in sythesis_names]) 37 | n_layers = np.log(resolution / 2) / np.log(2) 38 | 39 | return n_mapping, n_layers 40 | 41 | 42 | def convert(ada_path, network_pkl, output_file): 43 | """ 44 | convert weights from stylegan2-ada-pytorch 45 | :param ada_path: https://github.com/dvschultz/stylegan2-ada-pytorch local path 46 | :param network_pkl: path to original network weights 47 | :param output_file: path to converted network weights 48 | :return: 49 | """ 50 | sys.path.append(ada_path) 51 | 52 | import legacy # noqa 53 | import dnnlib # noqa 54 | 55 | with dnnlib.util.open_url(network_pkl) as f: 56 | ckp = legacy.load_network_pkl(f) 57 | G_nvidia = ckp["G_ema"] 58 | 59 | state_nv = G_nvidia.state_dict() 60 | n_mapping, n_layers = determine_config(state_nv) 61 | 62 | state_ros = {} 63 | 64 | for i in range(n_mapping): 65 | state_ros[f"style.{i + 1}.weight"] = state_nv[f"mapping.fc{i}.weight"] 66 | state_ros[f"style.{i + 1}.bias"] = state_nv[f"mapping.fc{i}.bias"] 67 | 68 | for i in range(int(n_layers)): 69 | if i > 0: 70 | for conv_level in range(2): 71 | convert_conv(state_ros, state_nv, f"convs.{2 * i - 2 + conv_level}", 72 | f"synthesis.b{4 * (2 ** i)}.conv{conv_level}") 73 | state_ros[f"noises.noise_{2 * i - 1 + conv_level}"] = state_nv[ 74 | f"synthesis.b{4 * (2 ** i)}.conv{conv_level}.noise_const"].unsqueeze(0).unsqueeze(0) 75 | 76 | convert_to_rgb(state_ros, state_nv, f"to_rgbs.{i - 1}", f"synthesis.b{4 * (2 ** i)}") 77 | convert_blur_kernel(state_ros, state_nv, i - 1) 78 | 79 | else: 80 | state_ros[f"input.input"] = state_nv[f"synthesis.b{4 * (2 ** i)}.const"].unsqueeze(0) 81 | convert_conv(state_ros, state_nv, "conv1", f"synthesis.b{4 * (2 ** i)}.conv1") 82 | state_ros[f"noises.noise_{2 * i}"] = state_nv[f"synthesis.b{4 * (2 ** i)}.conv1.noise_const"].unsqueeze( 83 | 0).unsqueeze(0) 84 | convert_to_rgb(state_ros, state_nv, "to_rgb1", f"synthesis.b{4 * (2 ** i)}") 85 | 86 | # https://github.com/yuval-alaluf/restyle-encoder/issues/1#issuecomment-828354736 87 | latent_avg = state_nv['mapping.w_avg'] 88 | state_dict = {"g_ema": state_ros, "latent_avg": latent_avg, "d": ckp["D"].state_dict()} 89 | torch.save(state_dict, output_file) 90 | 91 | 92 | if __name__ == "__main__": 93 | fire.Fire(convert) 94 | -------------------------------------------------------------------------------- /models/ada_ops/bias_act.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 
2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | #include 11 | #include 12 | #include "bias_act.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static bool has_same_layout(torch::Tensor x, torch::Tensor y) 17 | { 18 | if (x.dim() != y.dim()) 19 | return false; 20 | for (int64_t i = 0; i < x.dim(); i++) 21 | { 22 | if (x.size(i) != y.size(i)) 23 | return false; 24 | if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) 25 | return false; 26 | } 27 | return true; 28 | } 29 | 30 | //------------------------------------------------------------------------ 31 | 32 | static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp) 33 | { 34 | // Validate arguments. 35 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 36 | TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x"); 37 | TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x"); 38 | TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x"); 39 | TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x"); 40 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 41 | TORCH_CHECK(b.dim() == 1, "b must have rank 1"); 42 | TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds"); 43 | TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements"); 44 | TORCH_CHECK(grad >= 0, "grad must be non-negative"); 45 | 46 | // Validate layout. 47 | TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense"); 48 | TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); 49 | TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x"); 50 | TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x"); 51 | TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x"); 52 | 53 | // Create output tensor. 54 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 55 | torch::Tensor y = torch::empty_like(x); 56 | TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); 57 | 58 | // Initialize CUDA kernel parameters. 59 | bias_act_kernel_params p; 60 | p.x = x.data_ptr(); 61 | p.b = (b.numel()) ? b.data_ptr() : NULL; 62 | p.xref = (xref.numel()) ? xref.data_ptr() : NULL; 63 | p.yref = (yref.numel()) ? yref.data_ptr() : NULL; 64 | p.dy = (dy.numel()) ? dy.data_ptr() : NULL; 65 | p.y = y.data_ptr(); 66 | p.grad = grad; 67 | p.act = act; 68 | p.alpha = alpha; 69 | p.gain = gain; 70 | p.clamp = clamp; 71 | p.sizeX = (int)x.numel(); 72 | p.sizeB = (int)b.numel(); 73 | p.stepB = (b.numel()) ? 
(int)x.stride(dim) : 1; 74 | 75 | // Choose CUDA kernel. 76 | void* kernel; 77 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 78 | { 79 | kernel = choose_bias_act_kernel<scalar_t>(p); 80 | }); 81 | TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); 82 | 83 | // Launch CUDA kernel. 84 | p.loopX = 4; 85 | int blockSize = 4 * 32; 86 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; 87 | void* args[] = {&p}; 88 | AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 89 | return y; 90 | } 91 | 92 | //------------------------------------------------------------------------ 93 | 94 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 95 | { 96 | m.def("bias_act", &bias_act); 97 | } 98 | 99 | //------------------------------------------------------------------------ 100 |
--------------------------------------------------------------------------------
/tools/evaluate/blur.py:
--------------------------------------------------------------------------------
1 | import math 2 | from collections import defaultdict 3 | from pathlib import Path 4 | 5 | import cv2 6 | import fire 7 | import numpy as np 8 | from skimage.metrics import structural_similarity 9 | from torchvision.datasets.folder import is_image_file 10 | from tqdm import tqdm 11 | 12 | 13 | def Vollath(img): 14 | ''' 15 | :param img: ndarray, 2-D grayscale image 16 | :return: float, the sharper the image, the larger the value 17 | ''' 18 | shape = np.shape(img) 19 | u = np.mean(img) 20 | out = -shape[0] * shape[1] * (u ** 2) 21 | for x in range(0, shape[0] - 1): 22 | for y in range(0, shape[1]): 23 | out += int(img[x, y]) * int(img[x + 1, y]) 24 | return out 25 | 26 | 27 | def entropy(img): 28 | ''' 29 | :param img: ndarray, 2-D grayscale image 30 | :return: float, the sharper the image, the larger the value 31 | ''' 32 | out = 0 33 | count = np.shape(img)[0] * np.shape(img)[1] 34 | p = np.bincount(np.array(img).flatten()) 35 | for i in range(0, len(p)): 36 | if p[i] != 0: 37 | out -= p[i] * math.log(p[i] / count) / count 38 | return out 39 | 40 | 41 | def sobel(img): 42 | x = cv2.Sobel(img, cv2.CV_16S, 1, 0) 43 | y = cv2.Sobel(img, cv2.CV_16S, 0, 1) 44 | absX = cv2.convertScaleAbs(x) # convert back to uint8 45 | absY = cv2.convertScaleAbs(y) 46 | dst = cv2.addWeighted(absX, 0.5, absY, 0.5, 0) 47 | return dst 48 | 49 | 50 | def nrss(image): 51 | image_blur = cv2.GaussianBlur(image, (7, 7), 0) 52 | 53 | G, Gr = sobel(image), sobel(image_blur) 54 | 55 | (h, w) = G.shape 56 | G_blk_list = [] 57 | Gr_blk_list = [] 58 | sp = 6 59 | for i in range(sp): 60 | for j in range(sp): 61 | G_blk = G[int((i / sp) * h):int(((i + 1) / sp) * h), 62 | int((j / sp) * w):int(((j + 1) / sp) * w)] 63 | Gr_blk = Gr[int((i / sp) * h):int(((i + 1) / sp) * h), 64 | int((j / sp) * w):int(((j + 1) / sp) * w)] 65 | G_blk_list.append(G_blk) 66 | Gr_blk_list.append(Gr_blk) 67 | sum = 0 68 | for i in range(sp * sp): 69 | mssim = structural_similarity(G_blk_list[i], Gr_blk_list[i]) 70 | sum = mssim + sum 71 | return 1 - sum / (sp * sp * 1.0) 72 | 73 | 74 | def reblur(img): 75 | image_blur = cv2.GaussianBlur(img, (7, 7), 0) 76 | 77 | img = img.astype(np.float64) 78 | image_blur = image_blur.astype(np.float64) 79 | 80 | s_Vver = np.clip(np.abs(img[1:, :-1] - img[:-1, :-1]) - np.abs( 81 | image_blur[1:, :-1] - image_blur[:-1, :-1]), 0, None).sum() 82 | s_Fver = np.abs(img[1:, :-1] - img[:-1, :-1]).sum() 83 | 84 | s_Fhor = np.abs(img[:-1, 1:] - img[:-1, :-1]).sum() 85 | s_Vhor = np.clip(np.abs(img[:-1, 1:] - img[:-1, :-1]) - np.abs( 86 | image_blur[:-1, 1:] - image_blur[:-1, :-1]), 0, None).sum() 87
| 88 | return 1 - max((s_Fver - s_Vver) / s_Fver, (s_Fhor - s_Vhor) / s_Fhor) 89 | 90 | 91 | def energy(img): 92 | img = img.astype(np.float64) 93 | return np.power(img[1:, :-1] - img[:-1, :-1], 2).sum() / img.size + np.power(img[:-1, 1:] - img[:-1, :-1], 94 | 2).sum() / img.size 95 | 96 | 97 | def SMD2(gray): 98 | gray = gray.astype(np.float64) 99 | return (np.abs(gray[1:, :-1] - gray[:-1, :-1]) * np.abs(gray[:-1, 1:] - gray[:-1, :-1])).sum() / gray.size 100 | 101 | 102 | def brenner(img): 103 | ''' 104 | :param img: ndarray, 2-D grayscale image 105 | :return: float, the sharper the image, the larger the value 106 | ''' 107 | # shape = np.shape(img) 108 | # out = 0 109 | # for x in range(0, shape[0]-2): 110 | # for y in range(0, shape[1]): 111 | # out += (int(img[x+2, y])-int(img[x, y]))**2 112 | img = img.astype(np.float64) 113 | return np.power((img[2:] - img[:-2]), 2).sum() / img.size 114 | 115 | 116 | def loop(image_path: Path): 117 | gray = cv2.imread(image_path.as_posix(), cv2.IMREAD_GRAYSCALE) 118 | return dict( 119 | SMD2=SMD2(gray), 120 | brenner=brenner(gray), 121 | nrss=nrss(gray), 122 | energy=energy(gray), 123 | reblur=reblur(gray), 124 | ) 125 | 126 | 127 | def main(root): 128 | """ 129 | calculate image 'blur-ness' with various metrics 130 | :param root: image folder 131 | :return: 132 | """ 133 | metrics = defaultdict(list) 134 | pbar = tqdm(filter(lambda path: is_image_file(path.name), Path(root).glob("*"))) 135 | for img in pbar: 136 | r = loop(img) 137 | desc = ", ".join(f"{k}:{v:.2f}" for k, v in r.items()) 138 | pbar.set_description(f"{desc} - {img.name}") 139 | for k in r: 140 | metrics[k].append(r[k]) 141 | 142 | print("\n______________________________\n") 143 | for k in metrics: 144 | print(k, f"{sum(metrics[k]) / len(metrics[k]):.4f}") 145 | 146 | 147 | if __name__ == "__main__": 148 | fire.Fire(main) 149 |
--------------------------------------------------------------------------------
/raydl/collection.py:
--------------------------------------------------------------------------------
1 | import pathlib 2 | import re 3 | from typing import Iterable, Any, Union, Tuple, Sequence, Optional 4 | 5 | import torch 6 | 7 | __all__ = [ 8 | "tuple_of_type", 9 | "tuple_of_indices", 10 | "parse_indices_str", 11 | "AttributeDict", 12 | "describe_dict", 13 | "paired_indexes" 14 | ] 15 | 16 | 17 | def tuple_of_type(data: Union[Iterable, Any], target_types: Union[type, Sequence[type]] = (int, float), 18 | skip_unwanted_item: bool = True) -> Tuple: 19 | """ 20 | return a tuple of target type from data 21 | :param data: 22 | :param target_types: 23 | :param skip_unwanted_item: 24 | :return: 25 | """ 26 | if isinstance(data, target_types): 27 | return (data,) 28 | if isinstance(data, Iterable): 29 | values = [] 30 | for i, d in enumerate(data): 31 | if not isinstance(d, target_types): 32 | if skip_unwanted_item: 33 | continue 34 | else: 35 | raise ValueError(f"expect types: {target_types}, but the type of the {i}-th items is {type(d)}") 36 | values.append(d) 37 | return tuple(values) 38 | raise ValueError(f"expect data have type {target_types} or Iterable of {target_types} but got {type(data)}") 39 | 40 | 41 | def parse_indices_str(indices_str: str) -> Tuple: 42 | """ 43 | parse string that contains indices 44 | :param indices_str: a string that only contains indices, i.e.
only contains digit[0-9], `-`, and `,` 45 | :return: a tuple of indices, keep in order of indices_str 46 | """ 47 | assert isinstance(indices_str, str), f"indices_str must be string, but got {type(indices_str)}" 48 | assert re.match(r"^(((\d+-\d+)|\d+),?)+$", indices_str) is not None, "invalid list string" 49 | 50 | if "," not in indices_str and "-" not in indices_str: 51 | return (int(indices_str),) 52 | result = [] 53 | for sub in indices_str.split(","): 54 | if "-" in sub: 55 | s_left, s_right = tuple(map(int, sub.split("-"))) 56 | delta = -1 if s_right < s_left else 1 57 | result += list(range(s_left, s_right + delta, delta)) 58 | else: 59 | result.append(int(sub)) 60 | return tuple(result) 61 | 62 | 63 | def distinct_sequence(seq: Sequence): 64 | # if with python > 3.7, below must be faster. 65 | # although faster make a little a little sense. 66 | # https://stackoverflow.com/questions/480214/how-do-you-remove-duplicates-from-a-list-whilst-preserving-order 67 | # return list(dict.fromkeys(items)) 68 | seen = set() 69 | seen_add = seen.add 70 | return (x for x in seq if not (x in seen or seen_add(x))) 71 | 72 | 73 | def tuple_of_indices(data: Union[int, Sequence[int]], need_distinct: bool = False): 74 | indices = parse_indices_str(data) if isinstance(data, str) else tuple_of_type(data, int) 75 | if not need_distinct: 76 | return indices 77 | return distinct_sequence(indices) 78 | 79 | 80 | def paired_indexes(arg1, arg2): 81 | index1 = tuple_of_indices(arg1) 82 | index2 = tuple_of_indices(arg2) 83 | 84 | if len(index1) > 1: 85 | assert len(index2) == len(index1) 86 | 87 | if len(index1) == 1: 88 | index1 = index1 * len(index2) 89 | return index1, index2 90 | 91 | 92 | class AttributeDict(dict): 93 | """Convenience class that behaves like a dict but allows access with the attribute syntax.""" 94 | 95 | def __getattr__(self, name: str) -> Any: 96 | try: 97 | return self[name] 98 | except KeyError: 99 | raise AttributeError(name) 100 | 101 | def __setattr__(self, name: str, value: Any) -> None: 102 | self[name] = value 103 | 104 | def __delattr__(self, name: str) -> None: 105 | del self[name] 106 | 107 | 108 | def describe_dict(state_dict: dict, separator: str = "\n\t", head: Optional[str] = "dict content:") -> str: 109 | """ 110 | return dict description str. 111 | """ 112 | if not isinstance(state_dict, dict): 113 | raise ValueError(f"describe_dict only accept dict type, but got {type(state_dict)}") 114 | strings = [head] if head is not None else [] 115 | for k, v in state_dict.items(): 116 | if isinstance(v, (int, float, str, pathlib.Path)) or v is None: 117 | value_str = str(v) 118 | elif torch.is_tensor(v): 119 | if len(v.size()) == 0 or torch.numel(v) <= 16 and v.size(0) == 1: 120 | value_str = str(v) 121 | else: 122 | value_str = f"{type(v)}(dtype={v.dtype}, device={v.device}, shape={v.shape})" 123 | elif hasattr(v, "__len__"): 124 | value_str = f"{type(v)}(length={len(v)})" 125 | else: 126 | value_str = f"{type(v)}" 127 | strings.append(f"{k}: {value_str}") 128 | return separator.join(strings) 129 | -------------------------------------------------------------------------------- /models/ada_ops/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. 
Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | #include 11 | #include 12 | #include "upfirdn2d.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain) 17 | { 18 | // Validate arguments. 19 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 20 | TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x"); 21 | TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); 22 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 23 | TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); 24 | TORCH_CHECK(x.dim() == 4, "x must be rank 4"); 25 | TORCH_CHECK(f.dim() == 2, "f must be rank 2"); 26 | TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); 27 | TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); 28 | TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1"); 29 | 30 | // Create output tensor. 31 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 32 | int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; 33 | int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; 34 | TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); 35 | torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format()); 36 | TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); 37 | 38 | // Initialize CUDA kernel parameters. 39 | upfirdn2d_kernel_params p; 40 | p.x = x.data_ptr(); 41 | p.f = f.data_ptr(); 42 | p.y = y.data_ptr(); 43 | p.up = make_int2(upx, upy); 44 | p.down = make_int2(downx, downy); 45 | p.pad0 = make_int2(padx0, pady0); 46 | p.flip = (flip) ? 1 : 0; 47 | p.gain = gain; 48 | p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); 49 | p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0)); 50 | p.filterSize = make_int2((int)f.size(1), (int)f.size(0)); 51 | p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0)); 52 | p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); 53 | p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0)); 54 | p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; 55 | p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1; 56 | 57 | // Choose CUDA kernel. 58 | upfirdn2d_kernel_spec spec; 59 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 60 | { 61 | spec = choose_upfirdn2d_kernel(p); 62 | }); 63 | 64 | // Set looping options. 65 | p.loopMajor = (p.sizeMajor - 1) / 16384 + 1; 66 | p.loopMinor = spec.loopMinor; 67 | p.loopX = spec.loopX; 68 | p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1; 69 | p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1; 70 | 71 | // Compute grid size. 
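// Note: spec comes from choose_upfirdn2d_kernel(p); a negative tileOutW selects the
// "large" branch below, whose grid is sized directly from the output dimensions with a
// fixed 4x32 block, while the "small" branch launches 256-thread blocks over
// tileOutH x tileOutW output tiles.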
72 | dim3 blockSize, gridSize; 73 | if (spec.tileOutW < 0) // large 74 | { 75 | blockSize = dim3(4, 32, 1); 76 | gridSize = dim3( 77 | ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, 78 | (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, 79 | p.launchMajor); 80 | } 81 | else // small 82 | { 83 | blockSize = dim3(256, 1, 1); 84 | gridSize = dim3( 85 | ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, 86 | (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, 87 | p.launchMajor); 88 | } 89 | 90 | // Launch CUDA kernel. 91 | void* args[] = {&p}; 92 | AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 93 | return y; 94 | } 95 | 96 | //------------------------------------------------------------------------ 97 | 98 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 99 | { 100 | m.def("upfirdn2d", &upfirdn2d); 101 | } 102 | 103 | //------------------------------------------------------------------------ 104 | -------------------------------------------------------------------------------- /raydl/metrics/common.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | from collections import OrderedDict 3 | from typing import Callable, Union 4 | 5 | import torch 6 | from ignite.engine import Engine, Events 7 | from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce 8 | 9 | 10 | class Collector(Metric): 11 | required_output_keys = None 12 | 13 | def __init__( 14 | self, 15 | output_transform: Callable = lambda x: x, 16 | reset_after_computing=True, 17 | device: Union[str, torch.device] = torch.device("cpu"), 18 | ): 19 | self._reset_after_computing = reset_after_computing 20 | 21 | super().__init__(output_transform=output_transform, device=device) 22 | 23 | @staticmethod 24 | def _clean_v(v): 25 | if not isinstance(v, (numbers.Number, torch.Tensor)): 26 | raise TypeError(f"Output should be a number or torch.Tensor, but given {type(v)}") 27 | if torch.is_tensor(v): 28 | v = v.detach() 29 | if v.ndim == 0 or (v.ndim == 1 and len(v) == 1): 30 | return v 31 | if v.ndim == 2 and v.size(-1) == 1: 32 | return v.sum() 33 | if v.ndim == 1 and len(v) > 1: 34 | return v.sum() 35 | raise TypeError(f"if output is tensor, must have size Nx1 or 1 or N or zero dim," 36 | f" but given size: {v.size()} dim: {v.dim()}") 37 | return v 38 | 39 | def _len_v(self, v): 40 | if not torch.is_tensor(v) or v.ndim == 0: 41 | return torch.ones([], device=self._device) 42 | return torch.as_tensor([len(v)], device=self._device) 43 | 44 | def _break_combined_tensor(self): 45 | if not (self.accumulator.dim() == self.num_examples.dim() == 1) \ 46 | or not (len(self.accumulator) == len(self.num_examples) == len(self.keys)): 47 | raise RuntimeError( 48 | f"internal variables do not match in size: {self.accumulator.size()} {self.num_examples.size()}") 49 | 50 | return OrderedDict(zip(self.keys, self.accumulator)), OrderedDict(zip(self.keys, self.num_examples)) 51 | 52 | @reinit__is_reduced 53 | def reset(self) -> None: 54 | self.accumulator = torch.tensor([], dtype=torch.float64, device=self._device) 55 | self.num_examples = torch.tensor([], dtype=torch.int32, device=self._device) 56 | self.keys = [] 57 | 58 | @reinit__is_reduced 59 | def update(self, output: OrderedDict) -> None: 60 | cur_keys = OrderedDict([(k, None) for k in output.keys()]) 61 | _cur_num_examples = [] 62 | _cur_examples = [] 63 | for k in self.keys: 64 | if k in cur_keys: 65 | cur_keys.pop(k) 66 | v = output.get(k, None) 67 | 
_cur_num_examples.append(torch.zeros([], device=self._device) if v is None else self._len_v(v)) 68 | _cur_examples.append(torch.zeros([], device=self._device) if v is None else self._clean_v(v)) 69 | if len(cur_keys) > 0: 70 | # new key-value pair 71 | for k in cur_keys: 72 | v = output[k] 73 | try: 74 | _cur_examples.append(self._clean_v(v)) 75 | _cur_num_examples.append(self._len_v(v)) 76 | except TypeError as e: 77 | raise ValueError(f"Invalid value: {e}. key-value pair: {k}: {v}") 78 | self.keys.append(k) 79 | self.accumulator = torch.cat( 80 | [self.accumulator, torch.zeros([len(cur_keys)], dtype=torch.float64, device=self._device)]) 81 | self.num_examples = torch.cat( 82 | [self.num_examples, torch.zeros([len(cur_keys)], dtype=torch.int32, device=self._device)]) 83 | self.accumulator.add_(torch.as_tensor(_cur_examples, dtype=torch.float64, device=self._device)) 84 | self.num_examples.add_(torch.as_tensor(_cur_num_examples, dtype=torch.int32, device=self._device)) 85 | 86 | @torch.no_grad() 87 | def iteration_completed(self, engine: Engine) -> None: 88 | output = self._output_transform(engine.state.output) 89 | self.update(output) 90 | 91 | @sync_all_reduce("num_examples", "accumulator") 92 | def compute(self): 93 | _accumulator, _num_examples = self._break_combined_tensor() 94 | if self._reset_after_computing: 95 | self.reset() 96 | return {k: float(_accumulator[k] / _num_examples[k]) for k in _accumulator} 97 | 98 | def attach( 99 | self, engine: Engine, name: str, event_name: Events = Events.ITERATION_COMPLETED 100 | ) -> None: 101 | engine.add_event_handler(Events.STARTED, self.started) 102 | engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed) 103 | engine.add_event_handler(event_name, self.completed, name) 104 | 105 | @property 106 | def description(self): 107 | s = "Collector State:\n" 108 | s += f"\tMetric(num_examples):accumulator\n" 109 | for i, k in enumerate(self.keys): 110 | s += f"\t{k}({int(self.num_examples[i])}): {float(self.accumulator[i]):.4f}\n" 111 | return s 112 | -------------------------------------------------------------------------------- /tools/manipulation/enhance.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import fire 4 | import torch 5 | import torch.nn.functional as F 6 | import torch.optim as optim 7 | from loguru import logger 8 | from tqdm import tqdm 9 | 10 | import raydl 11 | from editing import util 12 | from editing.config import CHECKPOINT_PATH, FACE_PARSER_CKP 13 | from editing.modification import Manipulator 14 | from models.StyleGAN2_wrapper import ImprovedStyleGAN2Generator 15 | from third_party.BiSetNet import FaceParser 16 | 17 | 18 | def crop_vector(vector, rate=0.8, exponent=2): 19 | assert 1 >= rate >= 0 20 | vector = vector.flatten().abs() 21 | vector_norm_sum = vector.pow(exponent).sum() 22 | min_norm = vector_norm_sum * rate 23 | values, indexes = vector.pow(exponent).topk(len(vector)) 24 | norm = 0 25 | i = 0 26 | while norm <= min_norm: 27 | norm += values[i] 28 | i += 1 29 | main_indexes = indexes[:i] 30 | mask = torch.zeros_like(vector) 31 | mask[main_indexes] = 1 32 | return vector * mask 33 | 34 | 35 | def prev( 36 | modification_path, 37 | mask_ids, 38 | checkpoint=CHECKPOINT_PATH, 39 | num_batch=500, 40 | batch_size=16, 41 | save_path=None, 42 | device="cuda", 43 | seed=None, 44 | lr=0.1, 45 | delta_s_min_rate=0.6, 46 | min_resolution=32, 47 | lambda_neg=1, 48 | replace_factor=1, 49 | compact_mask=True, 50 | truncation=0.7, 
51 | ): 52 | print(locals()) 53 | G = ImprovedStyleGAN2Generator.load(checkpoint, device=device, default_truncation=truncation) 54 | G.manipulation_mode() 55 | G.eval() 56 | G.requires_grad_(False) 57 | 58 | if seed is not None: 59 | torch.manual_seed(seed) 60 | 61 | mask_ids = raydl.tuple_of_indices(mask_ids) 62 | 63 | face_parser = FaceParser(model_path=FACE_PARSER_CKP) 64 | 65 | if save_path is None: 66 | save_path = Path(modification_path).parent / f"before_{Path(modification_path).stem}.mdfc" 67 | else: 68 | save_path = Path(save_path) 69 | 70 | modification = Manipulator.load(modification_path, device) 71 | 72 | with torch.no_grad(): 73 | w = G.z_to_w(torch.randn(4096, G.style_dim, device=device), truncation=truncation) 74 | std_styles = [s.std(dim=0, keepdim=True) for s in G.w_to_styles(w)] 75 | mean_styles = G.w_to_styles(w.mean(dim=0, keepdim=True)) 76 | 77 | min_layer = min(modification.data.keys()) 78 | prev_layer = util.prev_layer(min_layer) 79 | 80 | delta_s_indices = torch.flatten(torch.nonzero(crop_vector( 81 | vector=modification.data[min_layer]["move"]["direction"] / std_styles[min_layer].to(device), 82 | rate=delta_s_min_rate, 83 | ))).tolist() 84 | logger.info(f"optimize for {len(delta_s_indices)} dims: {delta_s_indices}") 85 | 86 | assert "move" in modification.data[min_layer] 87 | 88 | modification.add_replacement(prev_layer, mean_styles[prev_layer], replace_indexes=delta_s_indices) 89 | modification.to(device) 90 | modification.data[prev_layer]["replace"]["style"].requires_grad_(True) 91 | 92 | optimizer = optim.Adam([modification.data[prev_layer]["replace"]["style"]], lr=lr) 93 | 94 | for i in tqdm(range(num_batch), ncols=0): 95 | z = torch.randn(batch_size, 512, device=device) 96 | w = G.z_to_w(z) 97 | zero_styles = G.w_to_styles(w) 98 | zero_image = G.styles_to_image(zero_styles) 99 | 100 | _, feature = G.styles_to_image_and_features( 101 | modification.apply(styles=zero_styles, move_factor=0, replace_factor=replace_factor), 102 | min_layer, 103 | ) 104 | feature = feature[0] 105 | 106 | with torch.no_grad(): 107 | parsing = face_parser.batch_run(zero_image, pre_normalize=True, image_repr=False, compact_mask=compact_mask) 108 | if parsing is None: 109 | continue 110 | mask = torch.zeros(parsing.size(0), 1, *zero_image.size()[-2:]).to(device) 111 | for mi in mask_ids: 112 | mask = torch.logical_or(mask, parsing[:, mi: mi + 1]) 113 | mask = mask.float() 114 | 115 | if feature.size(-1) < min_resolution: 116 | feature = F.interpolate(feature, min_resolution) 117 | mask = F.interpolate(mask, min_resolution) 118 | else: 119 | mask = F.interpolate(mask, feature.size()[-2:]) 120 | 121 | mask_factor = mask.sum() / torch.numel(mask) 122 | 123 | target_feature = feature[:, delta_s_indices].abs() 124 | feature_pos_loss = -(mask * target_feature).mean() / mask_factor 125 | feature_neg_loss = ((1 - mask) * target_feature).mean() / (1 - mask_factor) 126 | 127 | loss = feature_pos_loss + lambda_neg * feature_neg_loss 128 | 129 | optimizer.zero_grad() 130 | loss.backward() 131 | optimizer.step() 132 | 133 | if (i + 1) % 10 == 0: 134 | print(f"total: {loss.item():.2f} disentangle: {feature_neg_loss.item():.2f} " 135 | f"effect: {feature_pos_loss.item():.2f} ") 136 | 137 | modification.save(save_path) 138 | 139 | 140 | if __name__ == '__main__': 141 | fire.Fire() 142 | -------------------------------------------------------------------------------- /raydl/fp16.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from inspect 
import getfullargspec 3 | from typing import Iterable, Mapping 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | from torch.cuda.amp import autocast 9 | 10 | 11 | def cast_tensor_type(inputs, src_type, dst_type): 12 | """Recursively convert Tensor in inputs from src_type to dst_type. 13 | 14 | Args: 15 | inputs: Inputs that to be casted. 16 | src_type (torch.dtype): Source type.. 17 | dst_type (torch.dtype): Destination type. 18 | 19 | Returns: 20 | The same type with inputs, but all contained Tensors have been cast. 21 | """ 22 | if isinstance(inputs, torch.Tensor): 23 | return inputs.to(dst_type) 24 | if isinstance(inputs, nn.Module): 25 | return inputs 26 | elif isinstance(inputs, str): 27 | return inputs 28 | elif isinstance(inputs, np.ndarray): 29 | return inputs 30 | elif isinstance(inputs, Mapping): 31 | return type(inputs)({ 32 | k: cast_tensor_type(v, src_type, dst_type) 33 | for k, v in inputs.items() 34 | }) 35 | elif isinstance(inputs, Iterable): 36 | return type(inputs)(cast_tensor_type(item, src_type, dst_type) for item in inputs) 37 | else: 38 | return inputs 39 | 40 | 41 | def auto_fp16(apply_to=None, out_fp32=False): 42 | """Decorator to enable fp16 training automatically. 43 | 44 | This decorator is useful when you write custom modules and want to support 45 | mixed precision training. If inputs arguments are fp32 tensors, they will 46 | be converted to fp16 automatically. Arguments other than fp32 tensors are 47 | ignored. If you are using PyTorch >= 1.6, torch.cuda.amp is used as the 48 | backend, otherwise, original mmcv implementation will be adopted. 49 | 50 | Args: 51 | apply_to (Iterable, optional): The argument names to be converted. 52 | `None` indicates all arguments. 53 | out_fp32 (bool): Whether to convert the output back to fp32. 54 | 55 | Example: 56 | 57 | >>> import torch.nn as nn 58 | >>> class MyModule1(nn.Module): 59 | >>> 60 | >>> # Convert x and y to fp16 61 | >>> @auto_fp16() 62 | >>> def forward(self, x, y): 63 | >>> pass 64 | 65 | >>> import torch.nn as nn 66 | >>> class MyModule2(nn.Module): 67 | >>> 68 | >>> # convert pred to fp16 69 | >>> @auto_fp16(apply_to=('pred', )) 70 | >>> def do_something(self, pred, others): 71 | >>> pass 72 | """ 73 | 74 | def auto_fp16_wrapper(old_func): 75 | 76 | @functools.wraps(old_func) 77 | def new_func(*args, **kwargs): 78 | # check if the module has set the attribute `fp16_enabled`, if not, 79 | # just fallback to the original method. 
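# Note: a module opts in by setting `self.fp16_enabled = True` (e.g. in its __init__);
# without that flag the decorated method simply runs in full precision, unchanged.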
80 | if not isinstance(args[0], torch.nn.Module): 81 | raise TypeError('@auto_fp16 can only be used to decorate the ' 82 | 'method of nn.Module') 83 | if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled): 84 | return old_func(*args, **kwargs) 85 | 86 | # define output type by class itself 87 | if hasattr(args[0], 'out_fp32') and args[0].out_fp32: 88 | _out_fp32 = True 89 | else: 90 | _out_fp32 = False 91 | 92 | # get the arg spec of the decorated method 93 | args_info = getfullargspec(old_func) 94 | # get the argument names to be casted 95 | # Here, we change the default behaviour with Yu Xiong's 96 | # implementation 97 | args_to_cast = [] if apply_to is None else apply_to 98 | # convert the args that need to be processed 99 | new_args = [] 100 | # NOTE: default args are not taken into consideration 101 | if args: 102 | arg_names = args_info.args[:len(args)] 103 | for i, arg_name in enumerate(arg_names): 104 | if arg_name in args_to_cast: 105 | new_args.append( 106 | cast_tensor_type(args[i], torch.float, torch.half)) 107 | else: 108 | new_args.append(args[i]) 109 | # convert the kwargs that need to be processed 110 | new_kwargs = {} 111 | if kwargs: 112 | for arg_name, arg_value in kwargs.items(): 113 | if arg_name in args_to_cast: 114 | new_kwargs[arg_name] = cast_tensor_type( 115 | arg_value, torch.float, torch.half) 116 | else: 117 | new_kwargs[arg_name] = arg_value 118 | # apply converted arguments to the decorated method 119 | output = autocast(enabled=True)(old_func)(*new_args, **new_kwargs) 120 | # cast the results back to fp32 if necessary 121 | if out_fp32 or _out_fp32: 122 | output = cast_tensor_type(output, torch.half, torch.float) 123 | return output 124 | 125 | return new_func 126 | 127 | return auto_fp16_wrapper 128 | -------------------------------------------------------------------------------- /raydl/tensor.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | from typing import Iterable, Union, Optional, Tuple 3 | 4 | import cv2 5 | import numpy as np 6 | import torch 7 | 8 | __all__ = ["grid_transpose", "create_heatmap", "constant", "is_scalar", "nan_to_num"] 9 | 10 | _constant_cache = dict() 11 | 12 | 13 | def is_scalar(v): 14 | if isinstance(v, numbers.Number): 15 | return True 16 | if torch.is_tensor(v) and torch.numel(v) == 1: 17 | return True 18 | return False 19 | 20 | 21 | try: 22 | nan_to_num = torch.nan_to_num # 1.8.0a0 23 | except AttributeError: 24 | def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin 25 | # Replace NaN/Inf with specified numerical values. 26 | assert isinstance(input, torch.Tensor) 27 | if posinf is None: 28 | posinf = torch.finfo(input.dtype).max 29 | if neginf is None: 30 | neginf = torch.finfo(input.dtype).min 31 | assert nan == 0 32 | return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) 33 | 34 | 35 | def constant(value, shape=None, dtype=None, device=None, memory_format=None): 36 | """ 37 | Cached construction of constant tensors. Avoids CPU=>GPU copy when the 38 | same constant is used multiple times. 
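Illustrative example (CPU default; the second call returns the cached tensor):
    >>> a = constant(1.0, shape=[4])
    >>> b = constant(1.0, shape=[4])
    >>> a is b
    True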
39 | :param value: 40 | :param shape: 41 | :param dtype: 42 | :param device: 43 | :param memory_format: 44 | :return: 45 | """ 46 | value = np.asarray(value) 47 | if shape is not None: 48 | shape = tuple(shape) 49 | if dtype is None: 50 | dtype = torch.get_default_dtype() 51 | if device is None: 52 | device = torch.device('cpu') 53 | if memory_format is None: 54 | memory_format = torch.contiguous_format 55 | 56 | key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format) 57 | tensor = _constant_cache.get(key, None) 58 | if tensor is None: 59 | tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) 60 | if shape is not None: 61 | tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) 62 | tensor = tensor.contiguous(memory_format=memory_format) 63 | _constant_cache[key] = tensor 64 | return tensor 65 | 66 | 67 | def create_heatmap( 68 | images: torch.Tensor, 69 | range_min: Optional[float] = None, 70 | range_max: Optional[float] = None, 71 | scale_each: bool = False, 72 | color_map: str = "jet", 73 | return_tensor: bool = True 74 | ) -> Union[torch.Tensor, Tuple[np.array, ...]]: 75 | """ 76 | create heatmap from BxHxW tensor. 77 | :param images: Tensor[BxHxW] 78 | :param range_min: max value used to normalize the image. By default, min and max are computed from the tensor. 79 | :param range_max: min value used to normalize the image. By default, min and max are computed from the tensor. 80 | :param scale_each: If True, scale each image in the batch of images separately rather 81 | than the (min, max) over all images. Default: False. 82 | :param color_map: The colormap to apply, colormap from 83 | https://docs.opencv.org/3.4/d3/d50/group__imgproc__colormap.html#ga9a805d8262bcbe273f16be9ea2055a65 84 | :param return_tensor: if True, return Tensor[Bx3xHxW], otherwise return tuple of numpy.array(0-255) 85 | :return: 86 | """ 87 | device = images.device 88 | assert images.dim() == 3 89 | try: 90 | color_map = getattr(cv2, f"COLORMAP_{color_map.upper()}") 91 | except AttributeError: 92 | raise ValueError(f"invalid color_map {color_map}") 93 | 94 | with torch.no_grad(): 95 | images = images.detach().clone().to(dtype=torch.float32, device=torch.device("cpu")) 96 | 97 | if range_min is None: 98 | range_min = images.amin(dim=[-1, -2], keepdim=True) if scale_each else images.amin() 99 | if range_max is None: 100 | range_max = images.amax(dim=[-1, -2], keepdim=True) if scale_each else images.amax() 101 | 102 | heatmaps = [] 103 | for m in images.add_(-range_min).div_(range_max - range_min + 1e-5).clip_(0.0, 1.0): 104 | heatmaps.append(cv2.applyColorMap(np.uint8(m.numpy() * 255), color_map)) 105 | 106 | if return_tensor: 107 | heatmaps = torch.from_numpy(np.stack(heatmaps)).permute(0, 3, 1, 2) 108 | # BGR -> RGB & [0, 255] -> [-1, 1] 109 | heatmaps = (heatmaps[:, [2, 1, 0], :, :].contiguous().float().to(device) / 255 - 0.5) * 2 110 | return heatmaps 111 | return tuple(heatmaps) 112 | 113 | 114 | def grid_transpose(tensors: Union[torch.Tensor, Iterable], original_nrow: Optional[int] = None) -> torch.Tensor: 115 | """ 116 | batch tensors transpose. 117 | :param tensors: Tensor[(ROW*COL)*D1*...], or Iterable of same size tensors. 118 | :param original_nrow: original ROW 119 | :return: Tensor[(COL*ROW)*D1*...] 
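Illustrative example (two rows of three 1x1x1 cells, stored row-major):
    >>> x = torch.arange(6).reshape(6, 1, 1, 1)
    >>> grid_transpose(x, original_nrow=3).flatten().tolist()
    [0, 3, 1, 4, 2, 5]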
120 | """ 121 | assert torch.is_tensor(tensors) or isinstance(tensors, Iterable) 122 | if not torch.is_tensor(tensors) and isinstance(tensors, Iterable): 123 | seen_size = None 124 | grid = [] 125 | for tensor in tensors: 126 | if seen_size is None: 127 | seen_size = tensor.size() 128 | original_nrow = original_nrow or len(tensor) 129 | elif tensor.size() != seen_size: 130 | raise ValueError("expect all tensor in images have the same size.") 131 | grid.append(tensor) 132 | tensors = torch.cat(grid) 133 | 134 | assert original_nrow is not None 135 | 136 | cell_size = tensors.size()[1:] 137 | 138 | tensors = tensors.reshape(-1, original_nrow, *cell_size) 139 | tensors = torch.transpose(tensors, 0, 1) 140 | return torch.reshape(tensors, (-1, *cell_size)) 141 | -------------------------------------------------------------------------------- /models/custom_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import glob 10 | import hashlib 11 | import importlib 12 | import os 13 | import shutil 14 | from pathlib import Path 15 | 16 | import torch 17 | import torch.utils.cpp_extension 18 | from loguru import logger 19 | from torch.utils.file_baton import FileBaton 20 | 21 | verbosity = 'info' # Verbosity level: 'none', 'debug', 'info' 22 | 23 | 24 | def _find_compiler_bindir(): 25 | patterns = [ 26 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64', 27 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64', 28 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64', 29 | 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin', 30 | ] 31 | for pattern in patterns: 32 | matches = sorted(glob.glob(pattern)) 33 | if len(matches): 34 | return matches[-1] 35 | return None 36 | 37 | 38 | # Main entry point for compiling and loading C++/CUDA plugins. 39 | _cached_plugins = dict() 40 | 41 | 42 | def get_plugin(module_name, sources, **build_kwargs): 43 | assert verbosity in ['none', 'debug', 'info'] 44 | 45 | # Already cached? 46 | if module_name in _cached_plugins: 47 | return _cached_plugins[module_name] 48 | 49 | # Print status. 50 | if verbosity == "debug": 51 | logger.debug(f'Setting up PyTorch plugin "{module_name}"...') 52 | if verbosity == "info": 53 | logger.info(f'Setting up PyTorch plugin "{module_name}"...') 54 | 55 | try: # pylint: disable=too-many-nested-blocks 56 | # Make sure we can find the necessary compiler binaries. 57 | if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0: 58 | compiler_bindir = _find_compiler_bindir() 59 | if compiler_bindir is None: 60 | raise RuntimeError( 61 | f'Could not find MSVC/GCC/CLANG installation on this computer. ' 62 | f'Check _find_compiler_bindir() in "{__file__}".') 63 | os.environ['PATH'] += ';' + compiler_bindir 64 | 65 | # Compile and load. 66 | verbose_build = (verbosity == 'full') 67 | 68 | # Incremental build md5sum trickery. 
Copies all the input source files 69 | # into a cached build directory under a combined md5 digest of the input 70 | # source files. Copying is done only if the combined digest has changed. 71 | # This keeps input file timestamps and filenames the same as in previous 72 | # extension builds, allowing for fast incremental rebuilds. 73 | # 74 | # This optimization is done only in case all the source files reside in 75 | # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR 76 | # environment variable is set (we take this as a signal that the user 77 | # actually cares about this.) 78 | source_dirs_set = set(os.path.dirname(source) for source in sources) 79 | if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ): 80 | all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file())) 81 | 82 | # Compute a combined hash digest for all source files in the same 83 | # custom op directory (usually .cu, .cpp, .py and .h files). 84 | hash_md5 = hashlib.md5() 85 | for src in all_source_files: 86 | with open(src, 'rb') as f: 87 | hash_md5.update(f.read()) 88 | build_dir = torch.utils.cpp_extension._get_build_directory(module_name, 89 | verbose=verbose_build) # pylint: disable=protected-access 90 | digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest()) 91 | 92 | if not os.path.isdir(digest_build_dir): 93 | os.makedirs(digest_build_dir, exist_ok=True) 94 | baton = FileBaton(os.path.join(digest_build_dir, 'lock')) 95 | if baton.try_acquire(): 96 | try: 97 | for src in all_source_files: 98 | shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src))) 99 | finally: 100 | baton.release() 101 | else: 102 | # Someone else is copying source files under the digest dir, 103 | # wait until done and continue. 104 | baton.wait() 105 | digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources] 106 | torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir, 107 | verbose=verbose_build, sources=digest_sources, **build_kwargs) 108 | else: 109 | torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) 110 | module = importlib.import_module(module_name) 111 | except: 112 | if verbosity == 'brief': 113 | logger.debug('Failed!') 114 | raise 115 | 116 | # Print status and add to cache. 
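# Note: the imported module is memoized in _cached_plugins below, so later get_plugin()
# calls with the same module_name skip compilation entirely.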
117 | if verbosity == "debug": 118 | logger.debug(f'Done setting up PyTorch plugin "{module_name}".') 119 | if verbosity == "info": 120 | logger.info(f'Done setting up PyTorch plugin "{module_name}".') 121 | 122 | _cached_plugins[module_name] = module 123 | return module 124 | -------------------------------------------------------------------------------- /tools/helper.py: -------------------------------------------------------------------------------- 1 | import os 2 | from itertools import chain 3 | from pathlib import Path 4 | 5 | import cv2 6 | import fire 7 | import torch 8 | from loguru import logger 9 | from torchvision.datasets.folder import is_image_file 10 | from tqdm import tqdm 11 | 12 | import raydl 13 | 14 | 15 | def split(*image_path, save_path, n_height, n_width, resize=None): 16 | image = raydl.load_images(image_path) 17 | base_h, base_w = image.size()[-2] // n_height, image.size()[-1] // n_width 18 | print(f"original image size: {image.size()} base crop size: {(base_h, base_w)}") 19 | split_images = [] 20 | for h in range(n_height): 21 | for w in range(n_width): 22 | split_images.append(image[:, :, h * base_h:(h + 1) * base_h, w * base_w:(w + 1) * base_w]) 23 | split_images = torch.cat(split_images) 24 | raydl.save_images(split_images, save_path, resize=resize, separately=True) 25 | 26 | 27 | def video(image_folder, save_path, fps=4, crop_width=None, crop_height=None, crop_pixel_base=1): 28 | files = [file for file in Path(image_folder).glob("*") if is_image_file(file.name)] 29 | files = sorted(files, key=os.path.getmtime) 30 | print([f.name for f in files]) 31 | 32 | out_stream = None 33 | 34 | crop_width = crop_width * crop_pixel_base if crop_width is not None else None 35 | crop_height = crop_height * crop_pixel_base if crop_height is not None else None 36 | 37 | for f in tqdm(files): 38 | img = cv2.imread(f.as_posix()) 39 | height, width, _ = img.shape 40 | height = crop_height or height 41 | width = crop_width or width 42 | size = (width, height) 43 | 44 | if out_stream is None: 45 | out_stream = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc('m', 'p', '4', 'v'), fps, size) 46 | out_stream.write(img[:height, :width, :]) 47 | out_stream.release() 48 | 49 | 50 | def diff(path1, path2, save_path=None, ada=False): 51 | path1 = Path(path1) 52 | path2 = Path(path2) 53 | 54 | if path1.is_file() and path2.is_file(): 55 | save_path = Path(save_path) if save_path is not None else path1.parent 56 | images_iterator = [(path1, path2, save_path / f"{path1.stem}_diff_{path2.stem}{path1.suffix}")] 57 | elif path1.is_dir() and path2.is_dir(): 58 | save_path = Path(save_path) if save_path is not None else path1.parent / f"{path1.name}_diff_{path2.name}" 59 | if not save_path.exists(): 60 | print(f"mkdir {save_path}") 61 | save_path.mkdir() 62 | files1 = [f for f in Path(path1).glob("*") if is_image_file(f.name) and (path2 / f.name).exists()] 63 | files2 = [path2 / f.name for f in files1] 64 | images_iterator = zip(files1, files2, [save_path / f.name for f in files1]) 65 | else: 66 | raise ValueError(f"{path1} and {path2} are not both files or both folders") 67 | 68 | for image_path1, image_path2, path in images_iterator: 69 | image1 = raydl.load_images(image_path1) 70 | image2 = raydl.load_images(image_path2) 71 | assert image1.size() == image2.size() 72 | range_max = (2 ** 2 * 3) ** 0.5 if not ada else None 73 | 74 | image_diff = torch.norm(image1 - image2, p=2, dim=1) 75 | print(f"{image_path1} and {image_path2} " 76 | f"max difference is {image_diff.max():.4f} ({100 * 
image_diff.max() / ((2 ** 2 * 3) ** 0.5):.2f}%)") 77 | diff_heatmap = raydl.create_heatmap(image_diff, scale_each=False, range_min=0, range_max=range_max) 78 | 79 | if path.exists(): 80 | _path = path.parent / f"{path.stem}_diff{path.suffix}" 81 | print(f"{path} existed! rename save path to {_path}") 82 | path = _path 83 | raydl.save_images(diff_heatmap, path) 84 | 85 | 86 | def concat(*image_folders, save_path, captions=None, resize=None, nrow=None, transpose=True, batch_size=None): 87 | image_folders = [Path(f) for f in image_folders] 88 | images_to_compare = list(set.intersection(*[ 89 | set([p.name for p in chain( 90 | folder.glob("*.jpg"), 91 | folder.glob("*.png") 92 | )]) for folder in image_folders 93 | ])) 94 | if len(images_to_compare) == 0: 95 | print("can not found the same name jpg or png images in these image_folders") 96 | return 97 | images_to_compare = sorted(images_to_compare) 98 | logger.info(f"have total {len(images_to_compare)} images") 99 | logger.info(f"image_folders:\n" + '\n\t'.join([str(f) for f in image_folders])) 100 | 101 | batch_size = batch_size or len(images_to_compare) 102 | i = 0 103 | save_path = Path(save_path) 104 | 105 | for batch_id, batch in enumerate(raydl.total_chunk(len(images_to_compare), batch_size, drop_last=False)): 106 | images = torch.cat([ 107 | raydl.load_images([folder / name for name in images_to_compare[i:i + batch]], resize=resize) 108 | for folder in image_folders 109 | ]) 110 | i += batch 111 | 112 | if captions is not None: 113 | captions = list(map(str, captions)) 114 | assert len(captions) == len(image_folders), f"{captions} v.s len(image_folders)={len(image_folders)}" 115 | if not transpose: 116 | captions = sum([[c, *[None] * (batch - 1)] for c in captions], []) 117 | else: 118 | captions.extend([None] * (len(captions) * (batch - 1))) 119 | default_nrow = batch 120 | if transpose: 121 | images = raydl.grid_transpose(images, batch) 122 | default_nrow = images.size(0) // default_nrow 123 | 124 | if batch_size == 1: 125 | sp = save_path.parent / images_to_compare[batch_id] 126 | elif batch_size != len(images_to_compare): 127 | sp = save_path.parent / f"{save_path.stem}_{batch_id}{save_path.suffix}" 128 | else: 129 | sp = save_path 130 | logger.info(sp) 131 | raydl.save_images(images, save_path=sp, captions=captions, resize=resize, nrow=nrow or default_nrow) 132 | 133 | 134 | if __name__ == '__main__': 135 | fire.Fire() 136 | -------------------------------------------------------------------------------- /tools/evaluate/fid.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | 4 | import fire 5 | import ignite.distributed as idist 6 | import torch 7 | from ignite.contrib.handlers import ProgressBar 8 | from ignite.engine import Engine, Events 9 | from ignite.handlers import Timer 10 | from loguru import logger 11 | from torch.utils.data import Dataset 12 | from torchvision import transforms 13 | 14 | import raydl 15 | import training.distributed 16 | from data.dataset import ImageDataset 17 | from raydl.metrics import FID, LifeWise 18 | 19 | 20 | class FlipDataset(Dataset): 21 | def __init__(self, dataset): 22 | self.original = dataset 23 | 24 | def __repr__(self): 25 | return "Flip" + repr(self.original) 26 | 27 | def __len__(self): 28 | return len(self.original) * 2 29 | 30 | def __getitem__(self, index): 31 | if index < len(self.original): 32 | return self.original[index] 33 | image = self.original[index - len(self.original)] 34 | return 
torch.flip(image, [-1]) 35 | 36 | 37 | def get_dataset(path=None, resolution=None): 38 | if path is None: 39 | return None 40 | 41 | transform = transforms.Compose( 42 | [ 43 | transforms.ToTensor(), 44 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True), 45 | ] 46 | ) 47 | 48 | dataset = ImageDataset(path, transform, resolution, archive_type=None) 49 | return dataset 50 | 51 | 52 | def get_dataloader(dataset=None, xflip=True, batch_size=None, num_workers=None): 53 | if dataset is None: 54 | def infinite_iterator(): 55 | while True: 56 | yield 57 | 58 | return infinite_iterator() 59 | 60 | if xflip: 61 | logger.info("add xflip argument") 62 | dataset = FlipDataset(dataset) 63 | 64 | logger.info(f"use dataset: \n{dataset},\n have {len(dataset)} images") 65 | 66 | loader = training.distributed.auto_dataloader( 67 | dataset, batch_size=batch_size, num_workers=num_workers, 68 | shuffle=False, drop_last=False, seed=0, 69 | ) 70 | return loader 71 | 72 | 73 | def create_fid_evaluator(pkl_path, inception_path, num_images=None, step_fn=None): 74 | if not Path(pkl_path).exists(): 75 | logger.info(f"{pkl_path} do not exist. Computed result pkl will be saved at there.") 76 | computed_pkl_save_path = pkl_path 77 | precomputed_pkl = None 78 | else: 79 | precomputed_pkl = pkl_path 80 | computed_pkl_save_path = None 81 | 82 | evaluator = Engine(step_fn or (lambda engine, batch: batch)) 83 | 84 | fider = FID(precomputed_pkl, computed_pkl_save_path=computed_pkl_save_path, inception_path=inception_path, 85 | max_num_examples=num_images) 86 | fider.attach(evaluator, "fid", LifeWise()) 87 | 88 | if idist.get_rank() == 0: 89 | pbar = ProgressBar(ncols=80, ) 90 | pbar.attach(evaluator) 91 | 92 | @evaluator.on(Events.ITERATION_COMPLETED) 93 | def update(engine: Engine): 94 | if fider.is_full: 95 | engine.terminate() 96 | 97 | @evaluator.on(Events.COMPLETED) 98 | def logg(engine: Engine): 99 | logger.info("over") 100 | if precomputed_pkl is not None: 101 | logger.info(f"metric fid: {engine.state.metrics['fid']:.4f}") 102 | else: 103 | logger.info(f"save computed pkl at {computed_pkl_save_path}") 104 | 105 | return evaluator 106 | 107 | 108 | def generator_step_fn(checkpoint, batch_size=32, truncation=1): 109 | from models.StyleGAN2_wrapper import ImprovedStyleGAN2Generator 110 | 111 | g = ImprovedStyleGAN2Generator.load(checkpoint, device=idist.device(), default_truncation=truncation) 112 | g.manipulation_mode() 113 | g.eval() 114 | g = training.distributed.auto_model(g) 115 | 116 | @torch.no_grad() 117 | def step(engine, _): 118 | z = torch.randn(batch_size, g.style_dim, device=idist.device()) 119 | images = g(z=z) 120 | return images 121 | 122 | return step 123 | 124 | 125 | def running( 126 | local_rank, 127 | pkl_path, 128 | num_images=None, 129 | path=None, 130 | resolution=None, 131 | batch_size=32, 132 | num_workers=4, 133 | xflip=True, 134 | truncation=1, 135 | inception_path="./pretrained_models/stylegan2-ada-fid-inception.pt", 136 | ): 137 | logger.remove() 138 | if local_rank == 0: 139 | logger.add(sys.stderr, level="DEBUG") 140 | 141 | if Path(path).suffix == ".pt": 142 | logger.info(f"will load generator from checkpoint {path}") 143 | dataset = get_dataset() 144 | dataloader = get_dataloader(dataset) 145 | step_fn = generator_step_fn(path, batch_size, truncation=truncation) 146 | num_images = num_images or 50_000 147 | else: 148 | logger.info(f"will load images from {path}") 149 | dataset = get_dataset(path, resolution) 150 | dataloader = get_dataloader(dataset, xflip=xflip, 
batch_size=batch_size, num_workers=num_workers) 151 | step_fn = None 152 | 153 | timer = Timer() 154 | logger.debug(raydl.describe_dict(dict( 155 | pkl_path=pkl_path, 156 | inception_path=inception_path, 157 | num_images=num_images, 158 | )), head="evaluator dict:") 159 | evaluator = create_fid_evaluator(pkl_path, inception_path, num_images, step_fn) 160 | evaluator.run(dataloader) 161 | logger.info(f"total running time: {timer.value():.2f}s") 162 | 163 | 164 | def run( 165 | pkl_path, 166 | path=None, 167 | resolution=None, 168 | num_images=None, 169 | batch_size=32, 170 | num_workers=4, 171 | xflip=True, 172 | truncation=1, 173 | inception_path="./pretrained_models/stylegan2-ada-fid-inception.pt", 174 | backend="nccl", 175 | ): 176 | kwargs = locals() 177 | logger.info(raydl.describe_dict(kwargs, head="configs:")) 178 | with idist.Parallel(backend=kwargs.pop("backend")) as parallel: 179 | parallel.run(running, **kwargs) 180 | 181 | 182 | if __name__ == '__main__': 183 | torch.set_grad_enabled(False) 184 | fire.Fire(run) 185 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.toptal.com/developers/gitignore/api/pycharm,python,visualstudiocode 3 | # Edit at https://www.toptal.com/developers/gitignore?templates=pycharm,python,visualstudiocode 4 | 5 | ### PyCharm ### 6 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 7 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 8 | 9 | # User-specific stuff 10 | .idea/**/workspace.xml 11 | .idea/**/tasks.xml 12 | .idea/**/usage.statistics.xml 13 | .idea/**/dictionaries 14 | .idea/**/shelf 15 | 16 | # Generated files 17 | .idea/**/contentModel.xml 18 | 19 | # Sensitive or high-churn files 20 | .idea/**/dataSources/ 21 | .idea/**/dataSources.ids 22 | .idea/**/dataSources.local.xml 23 | .idea/**/sqlDataSources.xml 24 | .idea/**/dynamic.xml 25 | .idea/**/uiDesigner.xml 26 | .idea/**/dbnavigator.xml 27 | 28 | # Gradle 29 | .idea/**/gradle.xml 30 | .idea/**/libraries 31 | 32 | # Gradle and Maven with auto-import 33 | # When using Gradle or Maven with auto-import, you should exclude module files, 34 | # since they will be recreated, and may cause churn. Uncomment if using 35 | # auto-import. 
36 | # .idea/artifacts 37 | # .idea/compiler.xml 38 | # .idea/jarRepositories.xml 39 | # .idea/modules.xml 40 | # .idea/*.iml 41 | # .idea/modules 42 | # *.iml 43 | # *.ipr 44 | 45 | # CMake 46 | cmake-build-*/ 47 | 48 | # Mongo Explorer plugin 49 | .idea/**/mongoSettings.xml 50 | 51 | # File-based project format 52 | *.iws 53 | 54 | # IntelliJ 55 | out/ 56 | 57 | # mpeltonen/sbt-idea plugin 58 | .idea_modules/ 59 | 60 | # JIRA plugin 61 | atlassian-ide-plugin.xml 62 | 63 | # Cursive Clojure plugin 64 | .idea/replstate.xml 65 | 66 | # Crashlytics plugin (for Android Studio and IntelliJ) 67 | com_crashlytics_export_strings.xml 68 | crashlytics.properties 69 | crashlytics-build.properties 70 | fabric.properties 71 | 72 | # Editor-based Rest Client 73 | .idea/httpRequests 74 | 75 | # Android studio 3.1+ serialized cache file 76 | .idea/caches/build_file_checksums.ser 77 | 78 | ### PyCharm Patch ### 79 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 80 | 81 | # *.iml 82 | # modules.xml 83 | # .idea/misc.xml 84 | # *.ipr 85 | 86 | # Sonarlint plugin 87 | # https://plugins.jetbrains.com/plugin/7973-sonarlint 88 | .idea/**/sonarlint/ 89 | 90 | # SonarQube Plugin 91 | # https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin 92 | .idea/**/sonarIssues.xml 93 | 94 | # Markdown Navigator plugin 95 | # https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced 96 | .idea/**/markdown-navigator.xml 97 | .idea/**/markdown-navigator-enh.xml 98 | .idea/**/markdown-navigator/ 99 | 100 | # Cache file creation bug 101 | # See https://youtrack.jetbrains.com/issue/JBR-2257 102 | .idea/$CACHE_FILE$ 103 | 104 | # CodeStream plugin 105 | # https://plugins.jetbrains.com/plugin/12206-codestream 106 | .idea/codestream.xml 107 | 108 | ### Python ### 109 | # Byte-compiled / optimized / DLL files 110 | __pycache__/ 111 | *.py[cod] 112 | *$py.class 113 | 114 | # C extensions 115 | *.so 116 | 117 | # Distribution / packaging 118 | .Python 119 | build/ 120 | develop-eggs/ 121 | dist/ 122 | downloads/ 123 | eggs/ 124 | .eggs/ 125 | parts/ 126 | sdist/ 127 | var/ 128 | wheels/ 129 | pip-wheel-metadata/ 130 | share/python-wheels/ 131 | *.egg-info/ 132 | .installed.cfg 133 | *.egg 134 | MANIFEST 135 | 136 | # PyInstaller 137 | # Usually these files are written by a python script from a template 138 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 139 | *.manifest 140 | *.spec 141 | 142 | # Installer logs 143 | pip-log.txt 144 | pip-delete-this-directory.txt 145 | 146 | # Unit test / coverage reports 147 | htmlcov/ 148 | .tox/ 149 | .nox/ 150 | .coverage 151 | .coverage.* 152 | .cache 153 | nosetests.xml 154 | coverage.xml 155 | *.cover 156 | *.py,cover 157 | .hypothesis/ 158 | .pytest_cache/ 159 | pytestdebug.log 160 | 161 | # Translations 162 | *.mo 163 | *.pot 164 | 165 | # Django stuff: 166 | *.log 167 | local_settings.py 168 | db.sqlite3 169 | db.sqlite3-journal 170 | 171 | # Flask stuff: 172 | instance/ 173 | .webassets-cache 174 | 175 | # Scrapy stuff: 176 | .scrapy 177 | 178 | # Sphinx documentation 179 | docs/_build/ 180 | doc/_build/ 181 | 182 | # PyBuilder 183 | target/ 184 | 185 | # Jupyter Notebook 186 | .ipynb_checkpoints 187 | 188 | # IPython 189 | profile_default/ 190 | ipython_config.py 191 | 192 | # pyenv 193 | .python-version 194 | 195 | # pipenv 196 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
197 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 198 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 199 | # install all needed dependencies. 200 | #Pipfile.lock 201 | 202 | # poetry 203 | #poetry.lock 204 | 205 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 206 | __pypackages__/ 207 | 208 | # Celery stuff 209 | celerybeat-schedule 210 | celerybeat.pid 211 | 212 | # SageMath parsed files 213 | *.sage.py 214 | 215 | # Environments 216 | # .env 217 | .env/ 218 | .venv/ 219 | env/ 220 | venv/ 221 | ENV/ 222 | env.bak/ 223 | venv.bak/ 224 | pythonenv* 225 | 226 | # Spyder project settings 227 | .spyderproject 228 | .spyproject 229 | 230 | # Rope project settings 231 | .ropeproject 232 | 233 | # mkdocs documentation 234 | /site 235 | 236 | # mypy 237 | .mypy_cache/ 238 | .dmypy.json 239 | dmypy.json 240 | 241 | # Pyre type checker 242 | .pyre/ 243 | 244 | # pytype static type analyzer 245 | .pytype/ 246 | 247 | # operating system-related files 248 | # file properties cache/storage on macOS 249 | *.DS_Store 250 | # thumbnail cache on Windows 251 | Thumbs.db 252 | 253 | # profiling data 254 | .prof 255 | 256 | 257 | ### VisualStudioCode ### 258 | .vscode/* 259 | !.vscode/settings.json 260 | !.vscode/tasks.json 261 | !.vscode/launch.json 262 | !.vscode/extensions.json 263 | *.code-workspace 264 | 265 | ### VisualStudioCode Patch ### 266 | # Ignore all local history of files 267 | .history 268 | .ionide 269 | 270 | # End of https://www.toptal.com/developers/gitignore/api/pycharm,python,visualstudiocode 271 | 272 | 273 | /runs/ 274 | training/engine/Tester.py 275 | configs/test.yml 276 | pretrained -------------------------------------------------------------------------------- /raydl/information.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Sequence 3 | 4 | import ignite 5 | import ignite.distributed as idist 6 | import torch 7 | import torch.nn as nn 8 | 9 | import raydl 10 | from raydl import AttributeDict 11 | from raydl.collection import describe_dict 12 | 13 | __all__ = [ 14 | "distributed_configure", 15 | "memory_usage", 16 | "module_summary", 17 | "library_version", 18 | "suppress_tracer_warnings", 19 | "class_repr", 20 | ] 21 | 22 | 23 | def class_repr(module, attrs: Sequence[str], additional_items: dict = None): 24 | r_dict = {k: getattr(module, k) for k in attrs} 25 | if additional_items is not None: 26 | r_dict.update(additional_items) 27 | separator = "\n\t" 28 | r = raydl.describe_dict(r_dict, head=f"{module.__class__.__name__}(", separator=separator) 29 | return f"{r}\n)" 30 | 31 | 32 | def module_summary(module: nn.Module, inputs: Sequence, max_nesting: int = 3, skip_redundant: bool = True) -> str: 33 | assert isinstance(module, torch.nn.Module) 34 | assert not isinstance(module, torch.jit.ScriptModule) 35 | assert isinstance(inputs, Sequence) 36 | 37 | # Register hooks. 
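# Note: a pre-hook and a post-hook are attached to every submodule so that one dry run of
# module(*inputs) can record each submodule's nesting depth and output tensors; the hooks
# are removed right after that forward pass.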
38 | entries = [] 39 | nesting = [0] # use list to keep all module use the same nesting 40 | 41 | def pre_hook(_mod, _inputs): 42 | nesting[0] += 1 43 | 44 | def post_hook(mod, _inputs, outputs): 45 | nesting[0] -= 1 46 | if nesting[0] <= max_nesting: 47 | outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] 48 | outputs = [t for t in outputs if isinstance(t, torch.Tensor)] 49 | entries.append(AttributeDict(mod=mod, outputs=outputs)) 50 | 51 | hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] 52 | hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] 53 | 54 | # Run module. 55 | outputs = module(*inputs) 56 | for hook in hooks: 57 | hook.remove() 58 | 59 | # Identify unique outputs, parameters, and buffers. 60 | tensors_seen = set() 61 | for e in entries: 62 | e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen] 63 | e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] 64 | e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] 65 | tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs} 66 | 67 | # Filter out redundant entries. 68 | if skip_redundant: 69 | entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)] 70 | 71 | # Construct table. 72 | rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']] 73 | rows += [['---'] * len(rows[0])] 74 | param_total = 0 75 | buffer_total = 0 76 | submodule_names = {mod: name for name, mod in module.named_modules()} 77 | for e in entries: 78 | name = '' if e.mod is module else submodule_names[e.mod] 79 | param_size = sum(t.numel() for t in e.unique_params) 80 | buffer_size = sum(t.numel() for t in e.unique_buffers) 81 | output_shapes = [str(list(e.outputs[0].shape)) for t in e.outputs] 82 | output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs] 83 | rows += [[ 84 | name + (':0' if len(e.outputs) >= 2 else ''), 85 | str(param_size) if param_size else '-', 86 | str(buffer_size) if buffer_size else '-', 87 | (output_shapes + ['-'])[0], 88 | (output_dtypes + ['-'])[0], 89 | ]] 90 | for idx in range(1, len(e.outputs)): 91 | rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]] 92 | param_total += param_size 93 | buffer_total += buffer_size 94 | 95 | rows += [['---'] * len(rows[0])] 96 | rows += [['Total', str(param_total), str(buffer_total), '-', '-']] 97 | 98 | # Print table. 
99 | widths = [max(len(cell) for cell in column) for column in zip(*rows)] 100 | summary_text = "" 101 | for row in rows: 102 | summary_text += ' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)) 103 | summary_text += '\n' 104 | return summary_text 105 | 106 | 107 | def distributed_configure(): 108 | return describe_dict({ 109 | "distributed configuration": idist.model_name(), 110 | "backend": idist.backend(), 111 | "hostname": idist.hostname(), 112 | "world size": idist.get_world_size(), 113 | "rank": idist.get_rank(), 114 | "local rank": idist.get_local_rank(), 115 | "num nodes": idist.get_nnodes(), 116 | "num processes per_node": idist.get_nproc_per_node(), 117 | "node rank": idist.get_node_rank(), 118 | }, head="distributed_configure:") 119 | 120 | 121 | def library_version(): 122 | import torchvision 123 | versions = dict( 124 | PyTorch=torch.__version__, 125 | torchvision=torchvision.__version__, 126 | ignite=ignite.__version__, 127 | ) 128 | if torch.cuda.is_available(): 129 | # explicitly import cudnn as 130 | # torch.backends.cudnn can not be pickled with hvd spawning procs 131 | from torch.backends import cudnn 132 | 133 | versions["GPU"] = torch.cuda.get_device_name(idist.get_local_rank()) 134 | versions["CUDA"] = torch.version.cuda 135 | versions["CUDNN"] = cudnn.version() 136 | return describe_dict(versions, head="version:") 137 | 138 | 139 | def memory_usage(verbose=True): 140 | usage = "memory_usage:\n" 141 | memory = dict( 142 | max_reserverd_memory=torch.cuda.max_memory_reserved(), 143 | max_allocated_memory=torch.cuda.max_memory_allocated(), 144 | ) 145 | torch.cuda.reset_peak_memory_stats() 146 | for k, v in memory.items(): 147 | if v > 1024 ** 2: 148 | usage += f"{k}: {v / 1024 ** 2} MB\n" 149 | elif v > 1024: 150 | usage += f"{k}: {v / 1024} KB\n" 151 | else: 152 | usage += f"{k}: {v} B\n" 153 | if verbose: 154 | usage += torch.cuda.memory_summary() 155 | return usage 156 | 157 | 158 | class suppress_tracer_warnings(warnings.catch_warnings): 159 | def __enter__(self): 160 | super().__enter__() 161 | warnings.simplefilter('ignore', category=torch.jit.TracerWarning) 162 | return self 163 | -------------------------------------------------------------------------------- /raydl/registry.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import warnings 3 | from copy import deepcopy 4 | from typing import MutableMapping 5 | 6 | from loguru import logger 7 | 8 | 9 | class _Registry: 10 | def __init__(self, name): 11 | self._name = name 12 | 13 | def get(self, key): 14 | raise NotImplemented 15 | 16 | # 17 | # def keys(self): 18 | # raise NotImplemented 19 | # 20 | # def __len__(self): 21 | # len(self.keys()) 22 | 23 | def __contains__(self, key): 24 | return self.get(key) is not None 25 | 26 | @property 27 | def name(self): 28 | return self._name 29 | 30 | def __repr__(self): 31 | return f"{self.__class__.__name__}(name={self._name})" 32 | 33 | def build_with(self, cfg, default_args=None): 34 | """Build a module from config dict. 35 | Args: 36 | cfg (MutableMapping, or type str): Config dict. It should at least contain the key "_type". 37 | default_args (dict, optional): Default initialization arguments. 38 | Returns: 39 | object: The constructed object. 
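Illustrative example (`MODEL` and `ResBlock` are placeholder names for a registry and a
registered class):
    MODEL.build_with({"_type": "ResBlock", "channels": 64})
    MODEL.build_with({"ResBlock": {"channels": 64}})  # equivalent single-key form
    MODEL.build_with("ResBlock")  # bare string: build with no arguments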
40 | """ 41 | if isinstance(cfg, MutableMapping): 42 | if '_type' in cfg: 43 | # {"_type": "CLASS_NAME", ...arguments} 44 | args = deepcopy(cfg) 45 | obj_type = args.pop('_type') 46 | args = dict(args.items()) 47 | elif len(cfg) == 1: 48 | # {"CLASS_NAME": {...arguments}} 49 | obj_type, args = list(cfg.items())[0] 50 | args = dict(args.items()) 51 | else: 52 | raise ValueError(f"Invalid cfg. the cfg dict must contain the type info, but got {cfg}") 53 | elif isinstance(cfg, str): 54 | obj_type = cfg 55 | args = dict() 56 | else: 57 | raise TypeError(f'cfg must be `MutableMapping` or a str, but got {type(cfg)}') 58 | 59 | for invalid_key in [k for k in args.keys() if k.startswith("_")]: 60 | warnings.warn(f"got param start with `_`: {invalid_key}, will remove it") 61 | args.pop(invalid_key) 62 | 63 | if not (isinstance(default_args, dict) or default_args is None): 64 | raise TypeError('default_args must be a dict or None, ' 65 | f'but got {type(default_args)}') 66 | 67 | if isinstance(obj_type, str): 68 | obj_cls = self.get(obj_type) 69 | if obj_cls is None: 70 | raise KeyError(f'{obj_type} is not in the {self.name} registry') 71 | elif inspect.isclass(obj_type): 72 | obj_cls = obj_type 73 | else: 74 | raise TypeError(f'type must be a str or valid type, but got {type(obj_type)}') 75 | 76 | if default_args is not None: 77 | for name, value in default_args.items(): 78 | args.setdefault(name, value) 79 | try: 80 | obj = obj_cls(**args) 81 | except TypeError as e: 82 | logger.error(e) 83 | raise TypeError(f"invalid argument in {args} when try to build {obj_cls}\n build with {cfg}\n") from e 84 | return obj 85 | 86 | 87 | class Registry(_Registry): 88 | """A registry to map strings to classes. 89 | Args: 90 | name (str): Registry name. 91 | """ 92 | 93 | def __init__(self, name): 94 | super().__init__(name) 95 | self._module_dict = dict() 96 | 97 | def keys(self): 98 | return tuple(self._module_dict.keys()) 99 | 100 | def __len__(self): 101 | len(self.keys()) 102 | 103 | def get(self, key): 104 | """ 105 | Get the registry record. 106 | :param key: The class name in string format. 107 | :return: The corresponding class. 108 | """ 109 | return self._module_dict.get(key, None) 110 | 111 | def _register_module(self, module_class, module_name=None, force=False): 112 | if module_name is None: 113 | module_name = module_class.__name__ 114 | if not force and module_name in self._module_dict: 115 | if self._module_dict[module_name] == module_class: 116 | warnings.warn(f'{module_name} is already registered in {self.name}, but is the same class') 117 | return 118 | if module_class.__module__ == "__main__": 119 | warnings.warn(f"{module_name} is already registered in {self.name}, but registered again in __main__") 120 | return 121 | raise KeyError(f'{module_name}:{self._module_dict[module_name]} is already registered in {self.name}' 122 | f'so {module_class} can not be registered') 123 | self._module_dict[module_name] = module_class 124 | 125 | def register_module(self, name=None, force=False, module=None): 126 | """Register a module. 127 | A record will be added to `self._module_dict`, whose key is the class 128 | name or the specified name, and value is the class itself. 129 | It can be used as a decorator or a normal function. 130 | Args: 131 | name (str | None): The module name to be registered. If not 132 | specified, the class name will be used. 133 | force (bool, optional): Whether to override an existing class with 134 | the same name. Default: False. 
135 | module (type): Module class to be registered. 136 | """ 137 | if not isinstance(force, bool): 138 | raise TypeError(f'force must be a boolean, but got {type(force)}') 139 | 140 | # use it as a normal method: x.register_module(module=SomeClass) 141 | if module is not None: 142 | self._register_module( 143 | module_class=module, module_name=name, force=force) 144 | return module 145 | 146 | # raise the error ahead of time 147 | if not (name is None or isinstance(name, str)): 148 | raise TypeError(f'name must be a str, but got {type(name)}') 149 | 150 | # use it as a decorator: @x.register_module() 151 | def _register(cls): 152 | self._register_module(module_class=cls, module_name=name, force=force) 153 | return cls 154 | 155 | return _register 156 | 157 | 158 | class ModuleRegistry(_Registry): 159 | def __init__(self, module): 160 | super(ModuleRegistry, self).__init__(module.__name__) 161 | self.module = module 162 | 163 | def get(self, key): 164 | return getattr(self.module, key, None) 165 | -------------------------------------------------------------------------------- /editing/util.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | from pathlib import Path 3 | 4 | import torch 5 | from loguru import logger 6 | 7 | import raydl 8 | from editing.config import CORRECTION_PATH 9 | 10 | 11 | def convert_images_to_video(image_pattern, save_path, fps=24, overwrite=True, crf=18): 12 | # install ffmpeg last version first. https://johnvansickle.com/ffmpeg/ 13 | # will construct shell command like: 14 | # ffmpeg -r 15 -f image2 -i tmp/sequence_%d.jpg -vcodec libx264 -crf 18 -pix_fmt yuv420p tmp/t4.mp4 15 | args = ["ffmpeg"] 16 | if overwrite: 17 | args.append("-y") 18 | args += [ 19 | "-r", 20 | f"{fps}", 21 | "-f", 22 | "image2", 23 | "-i", 24 | f"{image_pattern}", 25 | "-vcodec", 26 | "libx264", 27 | "-crf", 28 | f"{crf}", 29 | "-pix_fmt", 30 | "yuv420p", 31 | save_path, 32 | "-hide_banner", 33 | "-loglevel", 34 | "warning", 35 | ] 36 | logger.info(" ".join(args)) 37 | subprocess.run(args) 38 | 39 | 40 | # StyleGAN2 specific utilize functions 41 | 42 | def prev_layer(layer: int): 43 | """ 44 | previous main layer of input layer in StyleGAN Generator 45 | :param layer: layer index 46 | :return: previous layer index 47 | """ 48 | assert layer >= 0 49 | 50 | # to_rgb layer 51 | if layer % 3 == 1: 52 | return layer - 1 53 | if layer % 3 == 0: 54 | return layer - 1 55 | if layer % 3 == 2: 56 | return layer - 2 57 | 58 | 59 | def next_layer(layer): 60 | """ 61 | next main layer of input layer in StyleGAN Generator 62 | :param layer: layer index 63 | :return: next layer index 64 | """ 65 | if layer == 0: 66 | return 1 67 | # to_rgb layer 68 | if layer % 3 == 1: 69 | return layer + 1 70 | if layer % 3 == 0: 71 | return layer + 2 72 | if layer % 3 == 2: 73 | return layer + 1 74 | 75 | 76 | def channel_selector(layers, rules="all", is_indexes_rule=False, correction_path=CORRECTION_PATH) -> dict: 77 | """ 78 | select channels according to rules. 79 | :param layers: selected layers. 80 | :param rules: python command that will be run with eval(). something like "ric[10]>0.1" 81 | or "(ric[10]>0.1)&(ric[8]>0.1)" 82 | :param is_indexes_rule: if set as True, directly convert rules as channel indices, like "12,1,511" 83 | :param correction_path: channel-region corrections. 
84 | :return: dict(layer1=mask1, layer2=mask2) 85 | """ 86 | correction = torch.load(correction_path) 87 | 88 | layers = raydl.tuple_of_indices(layers) 89 | if isinstance(rules, int) and is_indexes_rule: 90 | rules = f"{rules}" 91 | rules = raydl.tuple_of_type(rules, str) 92 | if len(rules) == 1: 93 | rules = rules * len(layers) 94 | assert len(layers) == len(rules) 95 | 96 | result = {} 97 | for layer, rule in zip(layers, rules): 98 | corr = correction[layer].abs() 99 | # relative correction by mask 100 | rcm = corr / corr.amax(dim=1, keepdim=True) # noqa 101 | # relative correction by channel 102 | rcc = corr / corr.amax(dim=0, keepdim=True) # noqa 103 | # relative importance by channel 104 | ric = corr / corr.sum(dim=0, keepdim=True) # noqa 105 | if is_indexes_rule: 106 | mask = torch.zeros_like(corr[0]) 107 | for c in raydl.parse_indices_str(rule): 108 | mask[c] = 1 109 | else: 110 | if rule != "all": 111 | # rule from user, very very dangerous!! 112 | mask = eval(rule) 113 | else: 114 | mask = torch.ones_like(corr[0]) 115 | logger.info(f"layer {layer} {int(mask.sum())} dims: {torch.nonzero(mask.float()).flatten().tolist()[:50]}") 116 | assert mask.size() == torch.Size([corr.size(1)]) 117 | result[layer] = mask 118 | 119 | return result 120 | 121 | 122 | def load_latent( 123 | latent_path, ids=None, load_real=False, real_path=None, 124 | latent_type=None, real_images_resize=1024, 125 | device=torch.device("cuda") 126 | ) -> dict: 127 | """ 128 | load latent code of StyleGAN, may load real images too. 129 | :param latent_path: 130 | :param ids: latent indices. something like "1,2,9" 131 | :param load_real: whether load real images while real images path can be found. 132 | :param real_path: real images path 133 | :param latent_type: if None, will infer from suffix, or treat as "w" 134 | :param real_images_resize: will resize loaded real images. 
135 | :param device:
136 | :return: latent dict: dict(latent_type:latent)
137 | """
138 | latent_path = Path(latent_path)
139 | ids = raydl.tuple_of_indices(ids) if ids is not None else None
140 | real_images = None
141 | if latent_type is None:
142 | latent_type = latent_path.suffix[1:]
143 | 
144 | if load_real:
145 | real_path = Path(real_path or latent_path.with_suffix(".realpath"))
146 | if not real_path.exists():
147 | logger.warning("cannot find the associated real images path")
148 | else:
149 | reals = torch.load(real_path)
150 | if ids is not None:
151 | reals = [reals[i] for i in ids if -len(reals) <= i < len(reals)]
152 | try:
153 | if (latent_path.parent / "real").exists():
154 | reals = [(latent_path.parent / "real") / Path(r).name for r in reals]
155 | real_images = raydl.load_images(reals, resize=real_images_resize)
156 | logger.debug(f"load {len(real_images)} real images")
157 | except FileNotFoundError as e:
158 | logger.error(e)
159 | real_images = None
160 | 
161 | latent = torch.load(latent_path)
162 | if isinstance(latent, (list, tuple)):
163 | latent_type = "styles"
164 | if ids is not None:
165 | ids = [i for i in ids if -len(latent[0]) <= i < len(latent[0])]
166 | latent = [
167 | s[ids, :] if s.dim() == 2 else s[ids, :, :]
168 | for s in latent
169 | ]
170 | latent = [s.to(device) for s in latent]
171 | else:
172 | if ids is not None:
173 | ids = [i for i in ids if -len(latent) <= i < len(latent)]
174 | latent = latent[ids, :, :] if latent.dim() == 3 else latent[ids, :]
175 | latent = latent.to(device)
176 | if latent_type not in ["w", "z", "styles"]:
177 | latent_type = "w" if latent.dim() == 3 else "z"
178 | logger.warning(f"latent type was not set correctly, falling back to inferred type: {latent_type}")
179 | return {latent_type: latent} if real_images is None else {latent_type: latent, "image": real_images}
--------------------------------------------------------------------------------
/models/ada_ops/bias_act.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 | 
9 | #include <c10/util/Half.h>
10 | #include "bias_act.h"
11 | 
12 | //------------------------------------------------------------------------
13 | // Helpers.
14 | 
15 | template <class T> struct InternalType;
16 | template <> struct InternalType<double> { typedef double scalar_t; };
17 | template <> struct InternalType<float> { typedef float scalar_t; };
18 | template <> struct InternalType<c10::Half> { typedef float scalar_t; };
19 | 
20 | //------------------------------------------------------------------------
21 | // CUDA kernel.
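// The kernel below evaluates y = act(x + b) element-wise. The activation is selected at compile
// time by the template parameter A (1=linear, 2=relu, 3=lrelu, 4=tanh, 5=sigmoid, 6=elu, 7=selu,
// 8=softplus, 9=swish, matching choose_bias_act_kernel() further down), while p.grad selects what
// is computed at run time: G==0 gives the forward value, G==1 the first-order gradient and G==2
// the second-order gradient, reusing the saved reference tensors p.xref / p.yref where needed.
// The incoming gradient p.dy, the gain and the optional clamp are applied at the end of the loop body.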
22 | 23 | template 24 | __global__ void bias_act_kernel(bias_act_kernel_params p) 25 | { 26 | typedef typename InternalType::scalar_t scalar_t; 27 | int G = p.grad; 28 | scalar_t alpha = (scalar_t)p.alpha; 29 | scalar_t gain = (scalar_t)p.gain; 30 | scalar_t clamp = (scalar_t)p.clamp; 31 | scalar_t one = (scalar_t)1; 32 | scalar_t two = (scalar_t)2; 33 | scalar_t expRange = (scalar_t)80; 34 | scalar_t halfExpRange = (scalar_t)40; 35 | scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946; 36 | scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717; 37 | 38 | // Loop over elements. 39 | int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; 40 | for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) 41 | { 42 | // Load. 43 | scalar_t x = (scalar_t)((const T*)p.x)[xi]; 44 | scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0; 45 | scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0; 46 | scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0; 47 | scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one; 48 | scalar_t yy = (gain != 0) ? yref / gain : 0; 49 | scalar_t y = 0; 50 | 51 | // Apply bias. 52 | ((G == 0) ? x : xref) += b; 53 | 54 | // linear 55 | if (A == 1) 56 | { 57 | if (G == 0) y = x; 58 | if (G == 1) y = x; 59 | } 60 | 61 | // relu 62 | if (A == 2) 63 | { 64 | if (G == 0) y = (x > 0) ? x : 0; 65 | if (G == 1) y = (yy > 0) ? x : 0; 66 | } 67 | 68 | // lrelu 69 | if (A == 3) 70 | { 71 | if (G == 0) y = (x > 0) ? x : x * alpha; 72 | if (G == 1) y = (yy > 0) ? x : x * alpha; 73 | } 74 | 75 | // tanh 76 | if (A == 4) 77 | { 78 | if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); } 79 | if (G == 1) y = x * (one - yy * yy); 80 | if (G == 2) y = x * (one - yy * yy) * (-two * yy); 81 | } 82 | 83 | // sigmoid 84 | if (A == 5) 85 | { 86 | if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one); 87 | if (G == 1) y = x * yy * (one - yy); 88 | if (G == 2) y = x * yy * (one - yy) * (one - two * yy); 89 | } 90 | 91 | // elu 92 | if (A == 6) 93 | { 94 | if (G == 0) y = (x >= 0) ? x : exp(x) - one; 95 | if (G == 1) y = (yy >= 0) ? x : x * (yy + one); 96 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one); 97 | } 98 | 99 | // selu 100 | if (A == 7) 101 | { 102 | if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one); 103 | if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha); 104 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha); 105 | } 106 | 107 | // softplus 108 | if (A == 8) 109 | { 110 | if (G == 0) y = (x > expRange) ? x : log(exp(x) + one); 111 | if (G == 1) y = x * (one - exp(-yy)); 112 | if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); } 113 | } 114 | 115 | // swish 116 | if (A == 9) 117 | { 118 | if (G == 0) 119 | y = (x < -expRange) ? 0 : x / (exp(-x) + one); 120 | else 121 | { 122 | scalar_t c = exp(xref); 123 | scalar_t d = c + one; 124 | if (G == 1) 125 | y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d); 126 | else 127 | y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d); 128 | yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain; 129 | } 130 | } 131 | 132 | // Apply gain. 133 | y *= gain * dy; 134 | 135 | // Clamp. 136 | if (clamp >= 0) 137 | { 138 | if (G == 0) 139 | y = (y > -clamp & y < clamp) ? y : (y >= 0) ? 
clamp : -clamp; 140 | else 141 | y = (yref > -clamp & yref < clamp) ? y : 0; 142 | } 143 | 144 | // Store. 145 | ((T*)p.y)[xi] = (T)y; 146 | } 147 | } 148 | 149 | //------------------------------------------------------------------------ 150 | // CUDA kernel selection. 151 | 152 | template void* choose_bias_act_kernel(const bias_act_kernel_params& p) 153 | { 154 | if (p.act == 1) return (void*)bias_act_kernel; 155 | if (p.act == 2) return (void*)bias_act_kernel; 156 | if (p.act == 3) return (void*)bias_act_kernel; 157 | if (p.act == 4) return (void*)bias_act_kernel; 158 | if (p.act == 5) return (void*)bias_act_kernel; 159 | if (p.act == 6) return (void*)bias_act_kernel; 160 | if (p.act == 7) return (void*)bias_act_kernel; 161 | if (p.act == 8) return (void*)bias_act_kernel; 162 | if (p.act == 9) return (void*)bias_act_kernel; 163 | return NULL; 164 | } 165 | 166 | //------------------------------------------------------------------------ 167 | // Template specializations. 168 | 169 | template void* choose_bias_act_kernel (const bias_act_kernel_params& p); 170 | template void* choose_bias_act_kernel (const bias_act_kernel_params& p); 171 | template void* choose_bias_act_kernel (const bias_act_kernel_params& p); 172 | 173 | //------------------------------------------------------------------------ 174 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: base 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - absl-py=0.14.1=pyhd8ed1ab_0 9 | - argon2-cffi=20.1.0=py38h1e0a361_1 10 | - attrs=20.1.0=pyh9f0ad1d_0 11 | - backcall=0.2.0=pyh9f0ad1d_0 12 | - backports=1.0=py_2 13 | - backports.functools_lru_cache=1.6.1=py_0 14 | - blas=1.0=mkl 15 | - bleach=3.1.5=pyh9f0ad1d_0 16 | - bzip2=1.0.8=h516909a_3 17 | - c-ares=1.17.1=h27cfd23_0 18 | - ca-certificates=2021.5.30=ha878542_0 19 | - cairo=1.16.0=h3fc0475_1005 20 | - certifi=2021.5.30=py38h578d9bd_0 21 | - cffi=1.14.1=py38h5bae8af_0 22 | - chardet=3.0.4=py38_1003 23 | - conda=4.10.3=py38h578d9bd_2 24 | - conda-package-handling=1.6.1=py38h7b6447c_0 25 | - cryptography=2.9.2=py38h1ba5d50_0 26 | - cudatoolkit=10.1.243=h6bb024c_0 27 | - dataclasses=0.8=pyhc8e2a94_3 28 | - dbus=1.13.6=he372182_0 29 | - decorator=4.4.2=py_0 30 | - defusedxml=0.6.0=py_0 31 | - entrypoints=0.3=py38h32f6830_1001 32 | - expat=2.2.9=he1b5a44_2 33 | - ffmpeg=4.3.1=h167e202_0 34 | - fontconfig=2.13.1=h1056068_1002 35 | - freetype=2.10.2=h5ab3b9f_0 36 | - gettext=0.19.8.1=hc5be6a0_1002 37 | - glib=2.65.0=h6f030ca_0 38 | - gmp=6.2.0=he1b5a44_2 39 | - gnutls=3.6.13=h79a8f9a_0 40 | - graphite2=1.3.13=he1b5a44_1001 41 | - grpcio=1.33.2=py38heead2fc_2 42 | - gst-plugins-base=1.14.5=h0935bb2_2 43 | - gstreamer=1.14.5=h36ae1b5_2 44 | - harfbuzz=2.7.2=hee91db6_0 45 | - hdf5=1.10.6=nompi_h3c11f04_101 46 | - icu=67.1=he1b5a44_0 47 | - idna=2.9=py_1 48 | - ignite=0.4.6=py_0 49 | - importlib-metadata=1.7.0=py38h32f6830_0 50 | - importlib_metadata=1.7.0=0 51 | - intel-openmp=2020.2=254 52 | - ipdb=0.13.9=pyhd8ed1ab_0 53 | - ipykernel=5.3.4=py38h23f93f0_0 54 | - ipython=7.18.1=py38h1cdfbd6_0 55 | - ipython_genutils=0.2.0=py_1 56 | - jasper=1.900.1=h07fcdf6_1006 57 | - jedi=0.15.2=py38_0 58 | - jinja2=2.11.2=pyh9f0ad1d_0 59 | - jpeg=9d=h516909a_0 60 | - json5=0.9.4=pyh9f0ad1d_0 61 | - jsonschema=3.2.0=py38h32f6830_1 62 | - jupyter_client=6.1.7=py_0 63 | - jupyter_core=4.6.3=py38h32f6830_1 64 | - 
jupyterlab=2.2.6=py_0 65 | - jupyterlab_server=1.2.0=py_0 66 | - krb5=1.17.1=h2fd8d38_0 67 | - lame=3.100=h14c3975_1001 68 | - lcms2=2.11=h396b838_0 69 | - ld_impl_linux-64=2.34=hc38a660_9 70 | - libblas=3.8.0=16_mkl 71 | - libcblas=3.8.0=16_mkl 72 | - libclang=10.0.1=default_hde54327_1 73 | - libedit=3.1.20181209=hc058e9b_0 74 | - libevent=2.1.10=hcdb4288_2 75 | - libffi=3.2.1=he1b5a44_1007 76 | - libgcc-ng=9.1.0=hdf63c60_0 77 | - libgfortran-ng=7.5.0=hdf63c60_16 78 | - libiconv=1.16=h516909a_0 79 | - liblapack=3.8.0=16_mkl 80 | - liblapacke=3.8.0=16_mkl 81 | - libllvm10=10.0.1=he513fc3_3 82 | - libopencv=4.4.0=py38_2 83 | - libpng=1.6.37=hbc83047_0 84 | - libpq=12.3=h5513abc_0 85 | - libprotobuf=3.17.2=h4ff587b_1 86 | - libsodium=1.0.18=h516909a_0 87 | - libstdcxx-ng=9.1.0=hdf63c60_0 88 | - libtiff=4.1.0=h2733197_1 89 | - libuuid=2.32.1=h14c3975_1000 90 | - libuv=1.40.0=h7b6447c_0 91 | - libwebp-base=1.1.0=h516909a_3 92 | - libxcb=1.13=h14c3975_1002 93 | - libxkbcommon=0.10.0=he1b5a44_0 94 | - libxml2=2.9.10=h68273f3_2 95 | - lz4-c=1.9.2=he6710b0_1 96 | - markdown=3.3.4=pyhd8ed1ab_0 97 | - markupsafe=1.1.1=py38h1e0a361_1 98 | - mistune=0.8.4=py38h1e0a361_1001 99 | - mkl=2020.2=256 100 | - mkl-service=2.3.0=py38he904b0f_0 101 | - mkl_fft=1.1.0=py38h23d657b_0 102 | - mkl_random=1.1.1=py38h0573a6f_0 103 | - mysql-common=8.0.21=0 104 | - mysql-libs=8.0.21=hf3661c5_0 105 | - nbconvert=5.6.1=py38h32f6830_1 106 | - nbformat=5.0.7=py_0 107 | - ncurses=6.2=he6710b0_1 108 | - nettle=3.4.1=h1bed415_1002 109 | - ninja=1.10.0=py38hfd86e86_0 110 | - notebook=6.1.3=py38h32f6830_0 111 | - nspr=4.28=he1b5a44_0 112 | - nss=3.56=he751ad9_0 113 | - numpy=1.19.1=py38hbc911f0_0 114 | - numpy-base=1.19.1=py38hfa32c7d_0 115 | - olefile=0.46=py_0 116 | - opencv=4.4.0=py38_2 117 | - openh264=2.1.1=h8b12597_0 118 | - openssl=1.1.1l=h7f8727e_0 119 | - packaging=20.4=pyh9f0ad1d_0 120 | - pandoc=2.10.1=h516909a_0 121 | - pandocfilters=1.4.2=py_1 122 | - parso=0.5.2=py_0 123 | - pcre=8.44=he1b5a44_0 124 | - pexpect=4.8.0=py38h32f6830_1 125 | - pickleshare=0.7.5=py38h32f6830_1001 126 | - pillow=7.2.0=py38hb39fc2d_0 127 | - pip=20.0.2=py38_3 128 | - pixman=0.38.0=h516909a_1003 129 | - prometheus_client=0.8.0=pyh9f0ad1d_0 130 | - prompt-toolkit=3.0.6=py_0 131 | - protobuf=3.17.2=py38h295c915_0 132 | - pthread-stubs=0.4=h14c3975_1001 133 | - ptyprocess=0.6.0=py_1001 134 | - py-opencv=4.4.0=py38h23f93f0_2 135 | - pycosat=0.6.3=py38h7b6447c_1 136 | - pycparser=2.20=py_0 137 | - pygments=2.6.1=py_0 138 | - pyopenssl=19.1.0=py38_0 139 | - pyparsing=2.4.7=pyh9f0ad1d_0 140 | - pyrsistent=0.16.0=py38h1e0a361_0 141 | - pysocks=1.7.1=py38_0 142 | - python=3.8.5=h1103e12_7_cpython 143 | - python-dateutil=2.8.1=py_0 144 | - python_abi=3.8=1_cp38 145 | - pytorch=1.7.0=py3.8_cuda10.1.243_cudnn7.6.3_0 146 | - pyzmq=19.0.2=py38ha71036d_0 147 | - qt=5.12.6=h0c8506f_0 148 | - readline=8.0=h7b6447c_0 149 | - requests=2.23.0=py38_0 150 | - ruamel_yaml=0.15.87=py38h7b6447c_0 151 | - send2trash=1.5.0=py_0 152 | - setuptools=46.4.0=py38_0 153 | - six=1.14.0=py38_0 154 | - sqlite=3.33.0=h4cf870e_0 155 | - tensorboard=1.15.0=py38_0 156 | - terminado=0.8.3=py38h32f6830_1 157 | - testpath=0.4.4=py_0 158 | - tk=8.6.10=hbc83047_0 159 | - tornado=6.0.4=py38h1e0a361_1 160 | - tqdm=4.46.0=py_0 161 | - traitlets=4.3.3=py38h32f6830_1 162 | - typing_extensions=3.10.0.2=pyh06a4308_0 163 | - urllib3=1.25.8=py38_0 164 | - wcwidth=0.2.5=pyh9f0ad1d_1 165 | - webencodings=0.5.1=py_1 166 | - werkzeug=2.0.1=pyhd8ed1ab_0 167 | - wheel=0.34.2=py38_0 168 | - 
x264=1!152.20180806=h14c3975_0 169 | - xorg-kbproto=1.0.7=h14c3975_1002 170 | - xorg-libice=1.0.10=h516909a_0 171 | - xorg-libsm=1.2.3=h84519dc_1000 172 | - xorg-libx11=1.6.12=h516909a_0 173 | - xorg-libxau=1.0.9=h14c3975_0 174 | - xorg-libxdmcp=1.1.3=h516909a_0 175 | - xorg-libxext=1.3.4=h516909a_0 176 | - xorg-libxrender=0.9.10=h516909a_1002 177 | - xorg-renderproto=0.11.1=h14c3975_1002 178 | - xorg-xextproto=7.3.0=h14c3975_1002 179 | - xorg-xproto=7.0.31=h14c3975_1007 180 | - xz=5.2.5=h7b6447c_0 181 | - yaml=0.1.7=had09818_2 182 | - zeromq=4.3.2=he1b5a44_3 183 | - zipp=3.1.0=py_0 184 | - zlib=1.2.11=h7b6447c_3 185 | - zstd=1.4.5=h9ceee32_0 186 | - pip: 187 | - adabelief-pytorch==0.2.1 188 | - antlr4-python3-runtime==4.8 189 | - colorama==0.4.4 190 | - cycler==0.10.0 191 | - einops==0.3.2 192 | - fire==0.4.0 193 | - kiwisolver==1.3.2 194 | - kornia==0.5.11 195 | - lightweight-gan==0.20.4 196 | - lmdb==1.2.1 197 | - loguru==0.5.3 198 | - matplotlib==3.4.3 199 | - omegaconf==2.1.1 200 | - opencv-python==4.5.3.56 201 | - pandas==1.3.3 202 | - py==1.10.0 203 | - pytz==2021.3 204 | - pyyaml==5.4.1 205 | - retry==0.9.2 206 | - scipy==1.7.1 207 | - tabulate==0.8.9 208 | - termcolor==1.1.0 209 | - torch==1.9.1+cu111 210 | - torchaudio==0.9.1 211 | - torchvision==0.10.1+cu111 212 | prefix: /opt/conda 213 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Control-Units-in-StyleGAN2 2 | 3 | [Project](https://wrong.wang/x/Control-Units-in-StyleGAN2/) | [Paper](https://dl.acm.org/doi/10.1145/3474085.3475274) 4 | 5 | The official PyTorch implementation for MM'21 paper 'Attribute-specific Control Units in StyleGAN for Fine-grained Image Manipulation' 6 | 7 | 8 | ## Pretrained Models 9 | 10 | We provide the pretrained StyleGAN2 generator, face parser, attribute classifier and e4e encoder in the following link. 11 | 12 | [Google Drive](https://drive.google.com/drive/folders/1g-ukOZ_KZXHSroLXq87jTx7iTlIinhzf?usp=sharing) 13 | 14 | Please download the pre-trained StyleGAN2 generator at least. Put the models in `./pretrained`. 15 | 16 | Now, the structure folder `pretrained` maybe looks like: 17 | 18 | 19 | ```text 20 | $ tree ./pretrained 21 | ./pretrained/ 22 | |-- Attribute_CelebAMask-HQ_40_classifier.pth 23 | |-- BiSetNet.pth 24 | |-- correction.pt 25 | |-- e4e_ffhq_encode.pt 26 | |-- modifications 27 | | |-- before_Bushy_Eyebrows_s_123,315,325.mdfc 28 | | |-- before_alter_Black_Hair_12.mdfc 29 | | |-- before_alter_Blond_Hair_12.mdfc 30 | | `-- before_single_channel_11_286.mdfc 31 | `-- stylegan2-ffhq-config-f.pt 32 | ``` 33 | 34 | 35 | ## Set up the environment 36 | 37 | Detailed setup information can be found in [environment](documents/environment.md). 38 | 39 | We also provide the [environment.yml](environment.yml) as a reference. In general, make sure the gcc version is new enough, and all other packages can be installed via conda or pip. 
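If you prefer to start from the pinned file directly, one possible way is sketched below (this is not the authors' prescribed setup: the yml pins `name: base`, so we override the name here, and the environment name `control-units` is arbitrary; adjust it and the CUDA-related pins to your machine):

```bash
# create a fresh environment from the pinned specification, overriding the "name: base" in the yml
conda env create -n control-units -f environment.yml
conda activate control-units
```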
40 | 
41 | ## Test
42 | 
43 | Let's get started quickly:
44 | 
45 | ```bash
46 | python3 manipulation.py test ./pretrained/modifications/before_alter_Black_Hair_12.mdfc --max_factor 20
47 | ```
48 | 
49 | This command will save the editing results in `./tmp`:
50 | 
51 | ```text
52 | $ tree ./tmp
53 | tmp
54 | |-- before_alter_Black_Hair_12_mdfc0_batch0.jpg
55 | `-- before_alter_Black_Hair_12_mdfc0_batch0_diff.jpg
56 | ```
57 | 
58 | The image named `{mdfc_file_name}_{mdfc_id}_{batch_id}.jpg` is the grid of manipulated results, and `{mdfc_file_name}_{mdfc_id}_{batch_id}_diff.jpg` is the corresponding difference image.
59 | 
60 | `{mdfc_file_name}_{mdfc_id}_{batch_id}.jpg`:
61 | 
62 | ![before_alter_Black_Hair_12_mdfc0_batch0.jpg](documents/images/tmp/before_alter_Black_Hair_12_mdfc0_batch0.jpg)
63 | 
64 | `{mdfc_file_name}_{mdfc_id}_{batch_id}_diff.jpg`:
65 | 
66 | ![before_alter_Black_Hair_12_mdfc0_batch0_diff.jpg](documents/images/tmp/before_alter_Black_Hair_12_mdfc0_batch0_diff.jpg "{mdfc_file_name}_{mdfc_id}_{batch_id}_diff.jpg")
67 | 
68 | Then, you can test the other `*.mdfc` files in `./pretrained/modifications`.
69 | 
70 | ## Create a modification for an attribute
71 | 
72 | For simplicity, let us create a new modification using a small number of positive and negative samples.
73 | 
74 | Take the attribute `Blond_Hair` as an example. We will try to find the control units for this attribute and a modification that changes the hair color.
75 | 
76 | ### open a python shell and run the following commands to create a latent bank.
77 | 
78 | ```python
79 | import torch
80 | torch.manual_seed(0)
81 | torch.save(torch.randn(8, 512), "seed.z")
82 | ```
83 | 
84 | ### generate the images corresponding to the latents.
85 | 
86 | ```bash
87 | python3 manipulation.py generate --num_samples 8 --batch_size 8 --separately 0 --latent_path ./seed.z --captions
88 | ```
89 | 
90 | This command will save the image `b0.jpg` in `./tmp`.
91 | 
92 | ### determine which ones to use as positive or negative samples.
93 | 
94 | The generated images are saved as `./tmp/b0.jpg`; check it and decide which images can be used as positive samples.
95 | 
96 | ![b0.jpg](documents/images/tmp/b0.jpg)
97 | 
98 | We select `0,1,3` as the positive samples and `4,2,6,7` as the negative ones.
99 | 
100 | ### generate the manipulated result only with the style vector movement.
101 | 
102 | We now calculate the means of the positive and negative latents (in the S space), then move the modulation styles along the resulting direction:
103 | 
104 | ```bash
105 | python3 manipulation.py test --batch_size 8 --seed 1 --num_samples 16 --resize 256 --save_type grid \
106 | --start_factor -3 --max_factor 6 --num_factors 5 \
107 | --build_mdfc_way paired_delta \
108 | --latent1 ./seed.z --ids1 0,1,3 \
109 | --latent2 ./seed.z --ids2 4,2,6,7 \
110 | --layers 12 --rules "(ric[10]>0.2)"
111 | ```
112 | 
113 | `latent1` & `ids1` specify the positive latents, and `latent2` & `ids2` the negative ones.
114 | 
115 | `layers` specifies the layer index, and `--rules "(ric[10]>0.2)"` selects the channels whose relative importance for the `hair` region (region id 10, see the mapping below) is greater than 0.2. A short sketch of how such a rule is evaluated follows.
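For intuition, the sketch below mirrors what `channel_selector` in `editing/util.py` does with such a rule. The path `pretrained/correction.pt` is an assumption for illustration only; the function actually loads the matrix from `editing.config.CORRECTION_PATH`, and the CLI passes the `--rules` string verbatim and evaluates it with `eval()` against `ric`/`rcc`/`rcm` for each selected layer.

```python
import torch

correction = torch.load("pretrained/correction.pt")  # per-layer channel-region correction matrices
corr = correction[12].abs()                          # layer 12: [num_regions, num_channels]
ric = corr / corr.sum(dim=0, keepdim=True)           # relative importance by channel
mask = ric[10] > 0.2                                 # region 10 is "hair" (see the mapping below)
print(int(mask.sum()), "channels selected in layer 12")
```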
116 | 
117 | ```python
118 | # region id mapping
119 | # ("region name", (face-parsing label ids))  # region_id
120 | SEMANTIC_REGION = dict(
121 | [
122 | ("background", (0,)), # 0
123 | ("brow", (1, 2)), # 1
124 | ("eye", (3, 4)), # 2
125 | ("glass", (5,)), # 3
126 | ("ear", (6, 7, 8)), # 4
127 | ("nose", (9,)), # 5
128 | ("mouth", (10,)), # 6
129 | ("lips", (11, 12)), # 7
130 | ("neck", (13, 14)), # 8
131 | ("cloth", (15,)), # 9
132 | ("hair", (16,)), # 10
133 | ("hat", (17,)), # 11
134 | ("face_up", (18,)), # 12
135 | ("face_middle", (19,)), # 13
136 | ("face_down", (20,)), # 14
137 | ]
138 | )
139 | ```
140 | 
141 | Here is the result:
142 | 
143 | ![](documents/images/tmp/alter_paired_mdfc0_batch1.jpg)
144 | 
145 | ### save the modification as an mdfc file
146 | 
147 | ```bash
148 | # just copy the CLI arguments from the above command
149 | PYTHONPATH=. python3 editing/modification.py paired_delta \
150 | --latent1 ./seed.z --ids1 0,1,3 \
151 | --latent2 ./seed.z --ids2 4,2,6,7 \
152 | --layers 12 --rules "(ric[10]>0.2)"
153 | ```
154 | 
155 | Then rename the unwieldy default file name:
156 | 
157 | ```bash
158 | mv "tmp/alter_paired_delta_seed_(0, 1, 3)_seed_(4, 2, 6, 7).mdfc" tmp/Blond_Hair.mdfc
159 | ```
160 | 
161 | ### Optimize to find the style vector for manipulating feature maps
162 | 
163 | ```bash
164 | PYTHONPATH=. python3 ./tools/manipulation/enhance.py prev tmp/Blond_Hair.mdfc 10 \
165 | --checkpoint pretrained/stylegan2-ffhq-config-f.pt \
166 | --batch_size 4 --num_batch 1000
167 | 
168 | # test the new mdfc
169 | python3 manipulation.py test ./tmp/before_Blond_Hair.mdfc --batch_size 8 --seed 1 --num_samples 16 --resize 256 --save_type grid --start_factor -3 --max_factor 3 --num_factors 5 --another_factor 0.8
170 | ```
171 | 
172 | The result of the full modification to the control units:
173 | 
174 | ![](documents/images/tmp/before_Blode_mdfc0_batch1.jpg)
175 | 
176 | 
177 | ## Citation
178 | 
179 | If you use this code for your research, please cite our paper [Attribute-specific Control Units in StyleGAN for Fine-grained Image Manipulation](https://dl.acm.org/doi/10.1145/3474085.3475274).
180 | 
181 | 
182 | ```text
183 | @inproceedings{10.1145/3474085.3475274,
184 | author = {Wang, Rui and Chen, Jian and Yu, Gang and Sun, Li and Yu, Changqian and Gao, Changxin and Sang, Nong},
185 | title = {Attribute-Specific Control Units in StyleGAN for Fine-Grained Image Manipulation},
186 | year = {2021},
187 | isbn = {9781450386517},
188 | publisher = {Association for Computing Machinery},
189 | address = {New York, NY, USA},
190 | url = {https://doi.org/10.1145/3474085.3475274},
191 | doi = {10.1145/3474085.3475274},
192 | booktitle = {Proceedings of the 29th ACM International Conference on Multimedia},
193 | pages = {926–934},
194 | numpages = {9},
195 | keywords = {generative adversarial networks(GANs), control unit, image manipulation},
196 | location = {Virtual Event, China},
197 | series = {MM '21}
198 | }
199 | 
200 | ```
201 | 
--------------------------------------------------------------------------------
/models/ada_ops/conv2d_resample.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto.
Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """2D convolution with optional up/downsampling.""" 10 | 11 | from numbers import Integral 12 | from typing import Sequence, Union 13 | 14 | import torch 15 | 16 | import raydl 17 | from . import conv2d_gradfix 18 | from . import upfirdn2d 19 | from .upfirdn2d import _get_filter_size 20 | from .upfirdn2d import _parse_padding 21 | 22 | 23 | # ---------------------------------------------------------------------------- 24 | 25 | def _get_weight_shape(w): 26 | with raydl.suppress_tracer_warnings(): # this value will be treated as a constant 27 | shape = [int(sz) for sz in w.shape] 28 | return shape 29 | 30 | 31 | # ---------------------------------------------------------------------------- 32 | 33 | def _conv2d_wrapper(x, w, stride=1, padding: Union[Integral, Sequence] = 0, groups=1, transpose=False, 34 | flip_weight=True): 35 | """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. 36 | """ 37 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) 38 | 39 | # Flip weight if requested. 40 | if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). 41 | w = w.flip([2, 3]) 42 | 43 | # Workaround performance pitfall in cuDNN 8.0.5, triggered when using 44 | # 1x1 kernel + memory_format=channels_last + less than 64 channels. 45 | if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose: 46 | if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64: 47 | if out_channels <= 4 and groups == 1: 48 | in_shape = x.shape 49 | x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1]) 50 | x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]]) 51 | else: 52 | x = x.to(memory_format=torch.contiguous_format) 53 | w = w.to(memory_format=torch.contiguous_format) 54 | x = conv2d_gradfix.conv2d(x, w, groups=groups) 55 | return x.to(memory_format=torch.channels_last) 56 | 57 | # Otherwise => execute using conv2d_gradfix. 58 | op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d 59 | return op(x, w, stride=stride, padding=padding, groups=groups) 60 | 61 | 62 | # ---------------------------------------------------------------------------- 63 | 64 | def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False): 65 | r"""2D convolution with optional up/downsampling. 66 | 67 | Padding is performed only once at the beginning, not between the operations. 68 | 69 | Args: 70 | x: Input tensor of shape 71 | `[batch_size, in_channels, in_height, in_width]`. 72 | w: Weight tensor of shape 73 | `[out_channels, in_channels//groups, kernel_height, kernel_width]`. 74 | f: Low-pass filter for up/downsampling. Must be prepared beforehand by 75 | calling upfirdn2d.setup_filter(). None = identity (default). 76 | up: Integer upsampling factor (default: 1). 77 | down: Integer downsampling factor (default: 1). 78 | padding: Padding with respect to the upsampled image. Can be a single number 79 | or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` 80 | (default: 0). 81 | groups: Split input channels into N groups (default: 1). 82 | flip_weight: False = convolution, True = correlation (default: True). 
83 | flip_filter: False = convolution, True = correlation (default: False). 84 | 85 | Returns: 86 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 87 | """ 88 | # Validate arguments. 89 | assert isinstance(x, torch.Tensor) and (x.ndim == 4) 90 | assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) 91 | assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) 92 | assert isinstance(up, int) and (up >= 1) 93 | assert isinstance(down, int) and (down >= 1) 94 | assert isinstance(groups, int) and (groups >= 1) 95 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) 96 | fw, fh = _get_filter_size(f) 97 | px0, px1, py0, py1 = _parse_padding(padding) 98 | 99 | # Adjust padding to account for up/downsampling. 100 | if up > 1: 101 | px0 += (fw + up - 1) // 2 102 | px1 += (fw - up) // 2 103 | py0 += (fh + up - 1) // 2 104 | py1 += (fh - up) // 2 105 | if down > 1: 106 | px0 += (fw - down + 1) // 2 107 | px1 += (fw - down) // 2 108 | py0 += (fh - down + 1) // 2 109 | py1 += (fh - down) // 2 110 | 111 | # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. 112 | if kw == 1 and kh == 1 and (down > 1 and up == 1): 113 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0, px1, py0, py1], flip_filter=flip_filter) 114 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 115 | return x 116 | 117 | # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. 118 | if kw == 1 and kh == 1 and (up > 1 and down == 1): 119 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 120 | x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0, px1, py0, py1], gain=up ** 2, flip_filter=flip_filter) 121 | return x 122 | 123 | # Fast path: downsampling only => use strided convolution. 124 | if down > 1 and up == 1: 125 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0, px1, py0, py1], flip_filter=flip_filter) 126 | x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight) 127 | return x 128 | 129 | # Fast path: upsampling with optional downsampling => use transpose strided convolution. 130 | if up > 1: 131 | if groups == 1: 132 | w = w.transpose(0, 1) 133 | else: 134 | w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) 135 | w = w.transpose(1, 2) 136 | w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw) 137 | px0 -= kw - 1 138 | px1 -= kw - up 139 | py0 -= kh - 1 140 | py1 -= kh - up 141 | pxt = max(min(-px0, -px1), 0) 142 | pyt = max(min(-py0, -py1), 0) 143 | x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt, pxt], groups=groups, transpose=True, 144 | flip_weight=(not flip_weight)) 145 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0 + pxt, px1 + pxt, py0 + pyt, py1 + pyt], gain=up ** 2, 146 | flip_filter=flip_filter) 147 | if down > 1: 148 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) 149 | return x 150 | 151 | # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. 152 | if up == 1 and down == 1: 153 | if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: 154 | return _conv2d_wrapper(x=x, w=w, padding=[py0, px0], groups=groups, flip_weight=flip_weight) 155 | 156 | # Fallback: Generic reference implementation. 
157 | x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0, px1, py0, py1], gain=up ** 2, 158 | flip_filter=flip_filter) 159 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 160 | if down > 1: 161 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) 162 | return x 163 | 164 | # ---------------------------------------------------------------------------- 165 | -------------------------------------------------------------------------------- /models/ada_ops/conv2d_gradfix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom replacement for `torch.nn.functional.conv2d` that supports 10 | arbitrarily high order gradients with zero performance penalty.""" 11 | 12 | import contextlib 13 | import warnings 14 | 15 | import torch 16 | import torch.nn.functional 17 | import torch.backends.cudnn 18 | # pylint: disable=redefined-builtin 19 | # pylint: disable=arguments-differ 20 | # pylint: disable=protected-access 21 | 22 | # ---------------------------------------------------------------------------- 23 | 24 | enabled = False # Enable the custom op by setting this to true. 25 | weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights. 26 | 27 | 28 | @contextlib.contextmanager 29 | def no_weight_gradients(): 30 | global weight_gradients_disabled 31 | old = weight_gradients_disabled 32 | weight_gradients_disabled = True 33 | yield 34 | weight_gradients_disabled = old 35 | 36 | 37 | # ---------------------------------------------------------------------------- 38 | 39 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): 40 | if _should_use_custom_op(input): 41 | return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, 42 | output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias) 43 | return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, 44 | dilation=dilation, groups=groups) 45 | 46 | 47 | def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): 48 | if _should_use_custom_op(input): 49 | return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, 50 | output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, 51 | bias) 52 | return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, 53 | output_padding=output_padding, groups=groups, dilation=dilation) 54 | 55 | 56 | # ---------------------------------------------------------------------------- 57 | 58 | def _should_use_custom_op(input): 59 | assert isinstance(input, torch.Tensor) 60 | if (not enabled) or (not torch.backends.cudnn.enabled): 61 | return False 62 | if input.device.type != 'cuda': 63 | return False 64 | if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']): 65 | return True 66 | warnings.warn( 67 | 
f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().') 68 | return False 69 | 70 | 71 | def _tuple_of_ints(xs, ndim): 72 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim 73 | assert len(xs) == ndim 74 | assert all(isinstance(x, int) for x in xs) 75 | return xs 76 | 77 | 78 | # ---------------------------------------------------------------------------- 79 | 80 | _conv2d_gradfix_cache = dict() 81 | 82 | 83 | def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups): 84 | # Parse arguments. 85 | ndim = 2 86 | weight_shape = tuple(weight_shape) 87 | stride = _tuple_of_ints(stride, ndim) 88 | padding = _tuple_of_ints(padding, ndim) 89 | output_padding = _tuple_of_ints(output_padding, ndim) 90 | dilation = _tuple_of_ints(dilation, ndim) 91 | 92 | # Lookup from cache. 93 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) 94 | if key in _conv2d_gradfix_cache: 95 | return _conv2d_gradfix_cache[key] 96 | 97 | # Validate arguments. 98 | assert groups >= 1 99 | assert len(weight_shape) == ndim + 2 100 | assert all(stride[i] >= 1 for i in range(ndim)) 101 | assert all(padding[i] >= 0 for i in range(ndim)) 102 | assert all(dilation[i] >= 0 for i in range(ndim)) 103 | if not transpose: 104 | assert all(output_padding[i] == 0 for i in range(ndim)) 105 | else: # transpose 106 | assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim)) 107 | 108 | # Helpers. 109 | common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups) 110 | 111 | def calc_output_padding(input_shape, output_shape): 112 | if transpose: 113 | return [0, 0] 114 | return [ 115 | input_shape[i + 2] 116 | - (output_shape[i + 2] - 1) * stride[i] 117 | - (1 - 2 * padding[i]) 118 | - dilation[i] * (weight_shape[i + 2] - 1) 119 | for i in range(ndim) 120 | ] 121 | 122 | # Forward & backward. 123 | class Conv2d(torch.autograd.Function): 124 | @staticmethod 125 | def forward(ctx, input, weight, bias): 126 | assert weight.shape == weight_shape 127 | if not transpose: 128 | output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) 129 | else: # transpose 130 | output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, 131 | output_padding=output_padding, **common_kwargs) 132 | ctx.save_for_backward(input, weight) 133 | return output 134 | 135 | @staticmethod 136 | def backward(ctx, grad_output): 137 | input, weight = ctx.saved_tensors 138 | grad_input = None 139 | grad_weight = None 140 | grad_bias = None 141 | 142 | if ctx.needs_input_grad[0]: 143 | p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) 144 | grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, 145 | **common_kwargs).apply(grad_output, weight, None) 146 | assert grad_input.shape == input.shape 147 | 148 | if ctx.needs_input_grad[1] and not weight_gradients_disabled: 149 | grad_weight = Conv2dGradWeight.apply(grad_output, input) 150 | assert grad_weight.shape == weight_shape 151 | 152 | if ctx.needs_input_grad[2]: 153 | grad_bias = grad_output.sum([0, 2, 3]) 154 | 155 | return grad_input, grad_weight, grad_bias 156 | 157 | # Gradient with respect to the weights. 
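    # Conv2dGradWeight below wraps cuDNN's backward-weight kernel as an autograd Function of its
    # own: forward() turns (grad_output, input) into grad_weight, and backward() re-expresses the
    # second-order terms through Conv2d.apply / _conv2d_gradfix again, which is what allows the
    # arbitrarily high-order gradients promised in the module docstring without a performance penalty.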
158 | class Conv2dGradWeight(torch.autograd.Function): 159 | @staticmethod 160 | def forward(ctx, grad_output, input): 161 | op = torch._C._jit_get_operation( 162 | 'aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight') 163 | flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, 164 | torch.backends.cudnn.allow_tf32] 165 | grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags) 166 | assert grad_weight.shape == weight_shape 167 | ctx.save_for_backward(grad_output, input) 168 | return grad_weight 169 | 170 | @staticmethod 171 | def backward(ctx, grad2_grad_weight): 172 | grad_output, input = ctx.saved_tensors 173 | grad2_grad_output = None 174 | grad2_input = None 175 | 176 | if ctx.needs_input_grad[0]: 177 | grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None) 178 | assert grad2_grad_output.shape == grad_output.shape 179 | 180 | if ctx.needs_input_grad[1]: 181 | p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) 182 | grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, 183 | **common_kwargs).apply(grad_output, grad2_grad_weight, None) 184 | assert grad2_input.shape == input.shape 185 | 186 | return grad2_grad_output, grad2_input 187 | 188 | _conv2d_gradfix_cache[key] = Conv2d 189 | return Conv2d 190 | 191 | # ---------------------------------------------------------------------------- 192 | -------------------------------------------------------------------------------- /models/StyleGAN2_wrapper.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import torch 4 | from loguru import logger 5 | from torch.nn import Identity 6 | 7 | from models.StyleGAN2 import Generator as StyleGAN2Generator 8 | 9 | 10 | class ImprovedStyleGAN2Generator(StyleGAN2Generator): 11 | """ 12 | wrap original StyleGAN Generator for manipulation. 
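    Compared with the plain Generator, this wrapper exposes the intermediate latent spaces
    explicitly: z_to_w() applies the mapping network and truncation, w_to_styles() turns W/W+
    codes into per-layer modulation styles (the S space), styles_to_image() synthesizes an image
    from those styles, and manipulation_mode() swaps each layer's modulation for an Identity so
    that precomputed style vectors can be fed in (and edited) directly. forward() optionally
    applies a list of (modification, apply_params) pairs to the styles before synthesis.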
13 | """ 14 | 15 | def __init__( 16 | self, 17 | size, 18 | style_dim, 19 | n_mlp, 20 | channel_multiplier=2, 21 | blur_kernel=(1, 3, 3, 1), 22 | lr_mlp=0.01, 23 | default_truncation=0.7, 24 | num_fp16_res=0, 25 | conv_clamp=None, 26 | ): 27 | super().__init__(size, style_dim=style_dim, n_mlp=n_mlp, channel_multiplier=channel_multiplier, 28 | blur_kernel=blur_kernel, lr_mlp=lr_mlp, num_fp16_res=num_fp16_res, conv_clamp=conv_clamp) 29 | 30 | self.default_truncation = default_truncation 31 | self.truncation_latent = None 32 | 33 | layers = [self.conv1, self.to_rgb1] 34 | for conv1, conv2, to_rgb in zip( 35 | self.convs[::2], self.convs[1::2], self.to_rgbs 36 | ): 37 | layers += [conv1, conv2, to_rgb] 38 | 39 | self.layers = layers 40 | 41 | def manipulation_mode(self, flag=True): 42 | with torch.no_grad(): 43 | if self.truncation_latent is None: 44 | self.truncation_latent = self.mean_latent(4096) 45 | for layer in self.layers: 46 | if flag: 47 | layer.conv.modulation, layer.conv.modulation_ = Identity(), layer.conv.modulation 48 | else: 49 | layer.conv.modulation = layer.conv.modulation_ 50 | del layer.conv.modulation_ 51 | logger.debug(f"enable manipulation mode: {flag}") 52 | 53 | def z_to_w(self, z, truncation=None, truncation_latent=None): 54 | w = self.style(z) 55 | truncation = truncation if truncation is not None else self.default_truncation 56 | truncation_latent = self.truncation_latent if truncation_latent is None else truncation_latent 57 | if truncation_latent is None: 58 | truncation_latent = self.mean_latent(4096) 59 | self.truncation_latent = truncation_latent 60 | if truncation < 1: 61 | w = truncation_latent + truncation * (w - truncation_latent) 62 | return w 63 | 64 | def w_to_styles(self, w): 65 | if w.dim() == 2: 66 | w = w.unsqueeze(dim=1).expand(-1, self.n_latent, self.style_dim) 67 | 68 | styles = [self.conv1.conv.modulation_(w[:, 0]), self.to_rgb1.conv.modulation_(w[:, 1])] 69 | 70 | i = 1 71 | for conv1, conv2, to_rgb in zip(self.convs[::2], self.convs[1::2], self.to_rgbs): 72 | styles.append(conv1.conv.modulation_(w[:, i])) 73 | styles.append(conv2.conv.modulation_(w[:, i + 1])) 74 | styles.append(to_rgb.conv.modulation_(w[:, i + 2])) 75 | i += 2 76 | 77 | return styles 78 | 79 | def styles_to_image(self, styles, noise=None, randomize_noise=False): 80 | if noise is None: 81 | if randomize_noise: 82 | noise = [None] * self.num_layers 83 | else: 84 | noise = [getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)] 85 | 86 | out = self.input(styles[0]) 87 | out = self.conv1(out, styles[0], noise=noise[0]) 88 | 89 | skip = self.to_rgb1(out, styles[1]) 90 | 91 | i = 2 92 | for conv1, conv2, noise1, noise2, to_rgb in zip( 93 | self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs 94 | ): 95 | out = conv1(out, styles[i], noise=noise1) 96 | out = conv2(out, styles[i + 1], noise=noise2) 97 | skip = to_rgb(out, styles[i + 2], skip) 98 | i += 3 99 | 100 | image = skip 101 | 102 | return image 103 | 104 | def styles_to_image_and_features(self, styles, layer_indexes, noise=None, randomize_noise=False): 105 | if noise is None: 106 | if randomize_noise: 107 | noise = [None] * self.num_layers 108 | else: 109 | noise = [getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)] 110 | 111 | layer_indexes = [layer_indexes, ] if isinstance(layer_indexes, int) else layer_indexes 112 | features = [] 113 | out = self.input(styles[0]) 114 | if 0 in layer_indexes: 115 | features.append(out) 116 | out = self.conv1(out, styles[0], noise=noise[0]) 117 
| 118 | if 1 in layer_indexes: 119 | features.append(out) 120 | skip = self.to_rgb1(out, styles[1]) 121 | 122 | i = 2 123 | for conv1, conv2, noise1, noise2, to_rgb in zip( 124 | self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs 125 | ): 126 | if i in layer_indexes: 127 | features.append(out) 128 | out = conv1(out, styles[i], noise=noise1) 129 | if i + 1 in layer_indexes: 130 | features.append(out) 131 | out = conv2(out, styles[i + 1], noise=noise2) 132 | if i + 2 in layer_indexes: 133 | features.append(out) 134 | skip = to_rgb(out, styles[i + 2], skip) 135 | i += 3 136 | 137 | image = skip 138 | return image, features 139 | 140 | def w_to_image(self, w, noise=None, randomize_noise=False): 141 | styles = self.w_to_styles(w) 142 | return self.styles_to_image(styles, noise, randomize_noise) 143 | 144 | def z_to_image(self, z, truncation=None, truncation_latent=None, noise=None, randomize_noise=True): 145 | w = self.z_to_w(z, truncation, truncation_latent) 146 | return self.w_to_image(w, noise, randomize_noise) 147 | 148 | def forward(self, w=None, styles=None, z=None, modifications=None): # noqa 149 | assert w is not None or styles is not None or z is not None 150 | styles = [s.clone() for s in styles] if styles is not None \ 151 | else self.w_to_styles(w if w is not None else self.z_to_w(z)) 152 | if modifications is not None: 153 | assert isinstance(modifications, (list, tuple)) 154 | for mdfc, apply_param in modifications: 155 | styles = mdfc.apply(styles, **apply_param) 156 | image = self.styles_to_image(styles) 157 | return image 158 | 159 | @staticmethod 160 | def conv_layer_name(layer): 161 | if layer == 0: 162 | return "conv1" 163 | elif layer == 1: 164 | return "to_rgb1" 165 | elif layer % 3 == 1: 166 | return f"to_rgbs.{layer // 3 - 1}" 167 | else: 168 | return f"convs.{2 * ((layer - 2) // 3) + (layer + 1) % 3}" 169 | 170 | @staticmethod 171 | def parameter_tag(name: str): 172 | if name.startswith("style."): 173 | return "mapping" 174 | if "noise" in name: 175 | return "noise" 176 | if name.startswith("input."): 177 | return "input" 178 | if "modulation" in name: 179 | return "modulation" 180 | if "blur" in name: 181 | return "conv.blur" 182 | if "bias" in name: 183 | return "conv.bias" 184 | if "upsample" in name: 185 | return "to_rgb.upsample" 186 | if "weight" in name: 187 | return "conv.weight" 188 | raise ValueError("invalid param name: " + name) 189 | 190 | @staticmethod 191 | def infer_arguments(state_dict: dict): 192 | num_layers = len(list(filter(lambda x: x.startswith("noise"), state_dict.keys()))) 193 | arguments = dict() 194 | arguments["size"] = 2 ** ((num_layers - 1) / 2 + 2) 195 | arguments["style_dim"] = state_dict["style.1.bias"].size()[0] 196 | arguments["n_mlp"] = len(list(filter(lambda x: x.startswith("style"), state_dict.keys()))) // 2 197 | arguments["channel_multiplier"] = state_dict["convs.8.conv.modulation.bias"].size()[0] // 256 198 | return arguments 199 | 200 | @staticmethod 201 | def load(model_path, lr_mlp=0.01, default_truncation=0.7, conv_clamp=256.0, num_fp16_res=4, device="cuda"): 202 | """ 203 | load checkpoint. 
Will return a Generator instance 204 | """ 205 | model_path = Path(model_path) 206 | checkpoint = torch.load(model_path, map_location=torch.device("cpu")) 207 | 208 | if (model_path.parent.parent / "config.yaml").exists(): 209 | # find saved config yaml for this checkpoint 210 | # create generator use arguments in config.yaml 211 | from omegaconf import OmegaConf 212 | conf = OmegaConf.load(model_path.parent.parent / "config.yaml") 213 | arguments = conf.model.generator 214 | for key in ["_type", ]: 215 | arguments.pop(key) 216 | logger.info("use arguments from config.yaml") 217 | logger.info(arguments) 218 | else: 219 | arguments = ImprovedStyleGAN2Generator.infer_arguments(checkpoint["g_ema"]) 220 | arguments["lr_mlp"] = lr_mlp 221 | arguments["num_fp16_res"] = num_fp16_res 222 | arguments["conv_clamp"] = conv_clamp 223 | 224 | arguments["default_truncation"] = default_truncation 225 | g_ema = ImprovedStyleGAN2Generator(**arguments).to(device) 226 | for k in list(checkpoint["g_ema"].keys()): 227 | if "noises" in k: 228 | checkpoint["g_ema"].pop(k) 229 | logger.warning(g_ema.load_state_dict(checkpoint["g_ema"], strict=False)) 230 | 231 | if "latent_avg" in checkpoint: 232 | g_ema.truncation_latent = checkpoint["latent_avg"].unsqueeze(dim=0).to(device) 233 | return g_ema 234 | -------------------------------------------------------------------------------- /tools/transfer/convert_weight.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import os 4 | import pickle 5 | import sys 6 | 7 | import numpy as np 8 | import torch 9 | from torchvision import utils 10 | 11 | from models.StyleGAN2 import Generator, Discriminator 12 | 13 | 14 | def convert_modconv(vars, source_name, target_name, flip=False): 15 | weight = vars[source_name + "/weight"].value().eval() 16 | mod_weight = vars[source_name + "/mod_weight"].value().eval() 17 | mod_bias = vars[source_name + "/mod_bias"].value().eval() 18 | noise = vars[source_name + "/noise_strength"].value().eval() 19 | bias = vars[source_name + "/bias"].value().eval() 20 | 21 | dic = { 22 | "conv.weight": np.expand_dims(weight.transpose((3, 2, 0, 1)), 0), 23 | "conv.modulation.weight": mod_weight.transpose((1, 0)), 24 | "conv.modulation.bias": mod_bias + 1, 25 | "noise.weight": np.array([noise]), 26 | "activate.bias": bias, 27 | } 28 | 29 | dic_torch = {} 30 | 31 | for k, v in dic.items(): 32 | dic_torch[target_name + "." 
+ k] = torch.from_numpy(v) 33 | 34 | if flip: 35 | dic_torch[target_name + ".conv.weight"] = torch.flip( 36 | dic_torch[target_name + ".conv.weight"], [3, 4] 37 | ) 38 | 39 | return dic_torch 40 | 41 | 42 | def convert_conv(vars, source_name, target_name, bias=True, start=0): 43 | weight = vars[source_name + "/weight"].value().eval() 44 | 45 | dic = {"weight": weight.transpose((3, 2, 0, 1))} 46 | 47 | if bias: 48 | dic["bias"] = vars[source_name + "/bias"].value().eval() 49 | 50 | dic_torch = {} 51 | 52 | dic_torch[target_name + f".{start}.weight"] = torch.from_numpy(dic["weight"]) 53 | 54 | if bias: 55 | dic_torch[target_name + f".{start + 1}.bias"] = torch.from_numpy(dic["bias"]) 56 | 57 | return dic_torch 58 | 59 | 60 | def convert_torgb(vars, source_name, target_name): 61 | weight = vars[source_name + "/weight"].value().eval() 62 | mod_weight = vars[source_name + "/mod_weight"].value().eval() 63 | mod_bias = vars[source_name + "/mod_bias"].value().eval() 64 | bias = vars[source_name + "/bias"].value().eval() 65 | 66 | dic = { 67 | "conv.weight": np.expand_dims(weight.transpose((3, 2, 0, 1)), 0), 68 | "conv.modulation.weight": mod_weight.transpose((1, 0)), 69 | "conv.modulation.bias": mod_bias + 1, 70 | "bias": bias.reshape((1, 3, 1, 1)), 71 | } 72 | 73 | dic_torch = {} 74 | 75 | for k, v in dic.items(): 76 | dic_torch[target_name + "." + k] = torch.from_numpy(v) 77 | 78 | return dic_torch 79 | 80 | 81 | def convert_dense(vars, source_name, target_name): 82 | weight = vars[source_name + "/weight"].value().eval() 83 | bias = vars[source_name + "/bias"].value().eval() 84 | 85 | dic = {"weight": weight.transpose((1, 0)), "bias": bias} 86 | 87 | dic_torch = {} 88 | 89 | for k, v in dic.items(): 90 | dic_torch[target_name + "." + k] = torch.from_numpy(v) 91 | 92 | return dic_torch 93 | 94 | 95 | def update(state_dict, new): 96 | for k, v in new.items(): 97 | if k not in state_dict: 98 | raise KeyError(k + " is not found") 99 | 100 | if v.shape != state_dict[k].shape: 101 | raise ValueError(f"Shape mismatch: {v.shape} vs {state_dict[k].shape}") 102 | 103 | state_dict[k] = v 104 | 105 | 106 | def discriminator_fill_statedict(statedict, vars, size): 107 | log_size = int(math.log(size, 2)) 108 | 109 | update(statedict, convert_conv(vars, f"{size}x{size}/FromRGB", "convs.0")) 110 | 111 | conv_i = 1 112 | 113 | for i in range(log_size - 2, 0, -1): 114 | reso = 4 * 2 ** i 115 | update( 116 | statedict, 117 | convert_conv(vars, f"{reso}x{reso}/Conv0", f"convs.{conv_i}.conv1"), 118 | ) 119 | update( 120 | statedict, 121 | convert_conv( 122 | vars, f"{reso}x{reso}/Conv1_down", f"convs.{conv_i}.conv2", start=1 123 | ), 124 | ) 125 | update( 126 | statedict, 127 | convert_conv( 128 | vars, f"{reso}x{reso}/Skip", f"convs.{conv_i}.skip", start=1, bias=False 129 | ), 130 | ) 131 | conv_i += 1 132 | 133 | update(statedict, convert_conv(vars, f"4x4/Conv", "final_conv")) 134 | update(statedict, convert_dense(vars, f"4x4/Dense0", "final_linear.0")) 135 | update(statedict, convert_dense(vars, f"Output", "final_linear.1")) 136 | 137 | return statedict 138 | 139 | 140 | def fill_statedict(state_dict, vars, size, n_mlp): 141 | log_size = int(math.log(size, 2)) 142 | 143 | for i in range(n_mlp): 144 | update(state_dict, convert_dense(vars, f"G_mapping/Dense{i}", f"style.{i + 1}")) 145 | 146 | update( 147 | state_dict, 148 | { 149 | "input.input": torch.from_numpy( 150 | vars["G_synthesis/4x4/Const/const"].value().eval() 151 | ) 152 | }, 153 | ) 154 | 155 | update(state_dict, convert_torgb(vars, 
"G_synthesis/4x4/ToRGB", "to_rgb1")) 156 | 157 | for i in range(log_size - 2): 158 | reso = 4 * 2 ** (i + 1) 159 | update( 160 | state_dict, 161 | convert_torgb(vars, f"G_synthesis/{reso}x{reso}/ToRGB", f"to_rgbs.{i}"), 162 | ) 163 | 164 | update(state_dict, convert_modconv(vars, "G_synthesis/4x4/Conv", "conv1")) 165 | 166 | conv_i = 0 167 | 168 | for i in range(log_size - 2): 169 | reso = 4 * 2 ** (i + 1) 170 | update( 171 | state_dict, 172 | convert_modconv( 173 | vars, 174 | f"G_synthesis/{reso}x{reso}/Conv0_up", 175 | f"convs.{conv_i}", 176 | flip=True, 177 | ), 178 | ) 179 | update( 180 | state_dict, 181 | convert_modconv( 182 | vars, f"G_synthesis/{reso}x{reso}/Conv1", f"convs.{conv_i + 1}" 183 | ), 184 | ) 185 | conv_i += 2 186 | 187 | for i in range(0, (log_size - 2) * 2 + 1): 188 | update( 189 | state_dict, 190 | { 191 | f"noises.noise_{i}": torch.from_numpy( 192 | vars[f"G_synthesis/noise{i}"].value().eval() 193 | ) 194 | }, 195 | ) 196 | 197 | return state_dict 198 | 199 | 200 | if __name__ == "__main__": 201 | device = "cuda" 202 | 203 | parser = argparse.ArgumentParser( 204 | description="Tensorflow to pytorch model checkpoint converter" 205 | ) 206 | parser.add_argument( 207 | "--repo", 208 | type=str, 209 | required=True, 210 | help="path to the offical StyleGAN2 repository with dnnlib/ folder", 211 | ) 212 | parser.add_argument( 213 | "--gen", action="store_true", help="convert the generator weights" 214 | ) 215 | parser.add_argument( 216 | "--disc", action="store_true", help="convert the discriminator weights" 217 | ) 218 | parser.add_argument( 219 | "--generate_sample", action="store_false", help="generate samples and compare it with tf version" 220 | ) 221 | parser.add_argument( 222 | "--channel_multiplier", 223 | type=int, 224 | default=2, 225 | help="channel multiplier factor. 
config-f = 2, else = 1", 226 | ) 227 | parser.add_argument("path", metavar="PATH", help="path to the tensorflow weights") 228 | 229 | args = parser.parse_args() 230 | 231 | sys.path.append(args.repo) 232 | 233 | import dnnlib 234 | from dnnlib import tflib 235 | 236 | tflib.init_tf() 237 | 238 | with open(args.path, "rb") as f: 239 | generator, discriminator, g_ema = pickle.load(f) 240 | 241 | size = g_ema.output_shape[2] 242 | 243 | n_mlp = 0 244 | mapping_layers_names = g_ema.__getstate__()['components']['mapping'].list_layers() 245 | for layer in mapping_layers_names: 246 | if layer[0].startswith('Dense'): 247 | n_mlp += 1 248 | 249 | g = Generator(size, 512, n_mlp, channel_multiplier=args.channel_multiplier) 250 | state_dict = g.state_dict() 251 | state_dict = fill_statedict(state_dict, g_ema.vars, size, n_mlp) 252 | 253 | g.load_state_dict(state_dict) 254 | 255 | latent_avg = torch.from_numpy(g_ema.vars["dlatent_avg"].value().eval()) 256 | 257 | ckpt = {"g_ema": state_dict, "latent_avg": latent_avg} 258 | 259 | if args.gen: 260 | g_train = Generator(size, 512, n_mlp, channel_multiplier=args.channel_multiplier) 261 | g_train_state = g_train.state_dict() 262 | g_train_state = fill_statedict(g_train_state, generator.vars, size, n_mlp) 263 | ckpt["g"] = g_train_state 264 | 265 | if args.disc: 266 | disc = Discriminator(size, channel_multiplier=args.channel_multiplier) 267 | d_state = disc.state_dict() 268 | d_state = discriminator_fill_statedict(d_state, discriminator.vars, size) 269 | ckpt["d"] = d_state 270 | 271 | name = os.path.splitext(os.path.basename(args.path))[0] 272 | torch.save(ckpt, name + ".pt") 273 | 274 | if args.generate_sample: 275 | batch_size = {256: 16, 512: 9, 1024: 4} 276 | n_sample = batch_size.get(size, 25) 277 | 278 | g = g.to(device) 279 | 280 | z = np.random.RandomState(0).randn(n_sample, 512).astype("float32") 281 | 282 | with torch.no_grad(): 283 | img_pt, _ = g( 284 | [torch.from_numpy(z).to(device)], 285 | truncation=0.5, 286 | truncation_latent=latent_avg.to(device), 287 | randomize_noise=False, 288 | ) 289 | 290 | Gs_kwargs = dnnlib.EasyDict() 291 | Gs_kwargs.randomize_noise = False 292 | img_tf = g_ema.run(z, None, **Gs_kwargs) 293 | img_tf = torch.from_numpy(img_tf).to(device) 294 | 295 | img_diff = ((img_pt + 1) / 2).clamp(0.0, 1.0) - ((img_tf.to(device) + 1) / 2).clamp( 296 | 0.0, 1.0 297 | ) 298 | 299 | img_concat = torch.cat((img_tf, img_pt, img_diff), dim=0) 300 | 301 | print(img_diff.abs().max()) 302 | 303 | utils.save_image( 304 | img_concat, name + ".png", nrow=n_sample, normalize=True, range=(-1, 1) 305 | ) 306 | -------------------------------------------------------------------------------- /raydl/metrics/generation.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import warnings 3 | from pathlib import Path 4 | from typing import Callable, Optional, Union 5 | 6 | import ignite.distributed as idist 7 | import numpy as np 8 | import torch 9 | from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce 10 | from torch.hub import download_url_to_file 11 | 12 | STYLEGAN2_ADA_FID_WEIGHTS_URL = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/' \ 13 | 'pretrained/metrics/inception-2015-12-05.pt' 14 | 15 | 16 | def fid_score_(sample_mean, sample_cov, real_mean, real_conv, eps=1e-6): 17 | try: 18 | import scipy 19 | import scipy.linalg 20 | except ImportError: 21 | raise RuntimeError("fid_score requires scipy to be installed.") 22 | m = np.square(sample_mean - 
real_mean).sum() 23 | s, _ = scipy.linalg.sqrtm(np.dot(sample_cov, real_conv), disp=False) # pylint: disable=no-member 24 | fid = np.real(m + np.trace(sample_cov + real_conv - s * 2)) 25 | return float(fid) 26 | 27 | 28 | def fid_score(sample_mean, sample_cov, real_mean, real_cov, eps=1e-6): 29 | """Refer to the implementation from: 30 | 31 | https://github.com/rosinality/stylegan2-pytorch/blob/master/fid.py#L34 32 | """ 33 | try: 34 | import scipy 35 | import scipy.linalg 36 | except ImportError: 37 | raise RuntimeError("fid_score requires scipy to be installed.") 38 | 39 | cov_sqrt, _ = scipy.linalg.sqrtm(sample_cov @ real_cov, disp=False) 40 | 41 | if not np.isfinite(cov_sqrt).all(): 42 | print('product of cov matrices is singular') 43 | offset = np.eye(sample_cov.shape[0]) * eps 44 | cov_sqrt = scipy.linalg.sqrtm( 45 | (sample_cov + offset) @ (real_cov + offset)) 46 | 47 | if np.iscomplexobj(cov_sqrt): 48 | if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3): 49 | m = np.max(np.abs(cov_sqrt.imag)) 50 | 51 | raise ValueError(f'Imaginary component {m}') 52 | 53 | cov_sqrt = cov_sqrt.real 54 | 55 | mean_diff = sample_mean - real_mean 56 | mean_norm = mean_diff @ mean_diff 57 | 58 | trace = np.trace(sample_cov) + np.trace(real_cov) - 2 * np.trace(cov_sqrt) 59 | 60 | fid = mean_norm + trace 61 | 62 | return float(fid) 63 | 64 | 65 | class StyleGAN2InceptionExtractor: 66 | def __init__(self, inception_path="./pretrained_models/stylegan2-ada-fid-inception.pt") -> None: 67 | self.device = idist.device() 68 | if not Path(inception_path).exists(): 69 | if idist.get_local_rank() > 0: 70 | # Ensure that only local rank 0 download the checkpoint 71 | # Thus each node will download a copy of the checkpoint 72 | idist.barrier() 73 | if not Path(inception_path).parent.exists(): 74 | Path(inception_path).parent.mkdir() 75 | download_url_to_file(STYLEGAN2_ADA_FID_WEIGHTS_URL, inception_path) 76 | if idist.get_local_rank() == 0: 77 | # Ensure that only local rank 0 download the dataset 78 | idist.barrier() 79 | self.inception = torch.jit.load(inception_path, map_location=self.device) 80 | self.inception.to(self.device) 81 | self.inception.eval() 82 | 83 | @torch.no_grad() 84 | def __call__(self, data: torch.Tensor) -> torch.Tensor: 85 | if data.dim() != 4: 86 | raise ValueError(f"Inputs should be a tensor of dim 4, got {data.dim()}") 87 | if data.shape[1] != 3: 88 | raise ValueError(f"Inputs should be a tensor with 3 channels, got {data.shape}") 89 | data = data.to(self.device) 90 | data = (data * 127.5 + 128).clamp(0, 255).to(torch.uint8) 91 | return self.inception(data, return_features=True) 92 | 93 | 94 | class InceptionExtractor: 95 | def __init__(self) -> None: 96 | try: 97 | from torchvision import models 98 | except ImportError: 99 | raise RuntimeError("This module requires torchvision to be installed.") 100 | self.model = models.inception_v3(pretrained=True) 101 | self.model.fc = torch.nn.Identity() 102 | self.model.eval() 103 | 104 | @torch.no_grad() 105 | def __call__(self, data: torch.Tensor) -> torch.Tensor: 106 | if data.dim() != 4: 107 | raise ValueError(f"Inputs should be a tensor of dim 4, got {data.dim()}") 108 | if data.shape[1] != 3: 109 | raise ValueError(f"Inputs should be a tensor with 3 channels, got {data.shape}") 110 | return self.model(data) 111 | 112 | 113 | class FID(Metric): 114 | def __init__( 115 | self, 116 | precomputed_pkl: Optional[Union[str, Path]] = None, 117 | computed_pkl_save_path: Optional[Union[str, Path]] = None, 118 | output_transform: Callable = 
lambda x: x, 119 | max_num_examples: Optional[int] = None, 120 | num_features: Optional[int] = None, 121 | inception_path="./pretrained_models/stylegan2-ada-fid-inception.pt", 122 | feature_extractor: Optional[Callable] = None, 123 | device: Union[str, torch.device] = torch.device("cpu"), 124 | ) -> None: 125 | 126 | try: 127 | import scipy # noqa: F401 128 | except ImportError: 129 | raise RuntimeError("This module requires scipy to be installed.") 130 | 131 | # default is inception 132 | if num_features is None and feature_extractor is None: 133 | num_features = 2048 134 | feature_extractor = StyleGAN2InceptionExtractor(inception_path) 135 | elif num_features is None: 136 | raise ValueError("Argument num_features should be defined") 137 | elif feature_extractor is None: 138 | self._feature_extractor = lambda x: x 139 | feature_extractor = self._feature_extractor 140 | 141 | if num_features <= 0: 142 | raise ValueError(f"Argument num_features must be greater to zero, got: {num_features}") 143 | self._num_features = num_features 144 | self._feature_extractor = feature_extractor 145 | self._eps = 1e-6 146 | 147 | if precomputed_pkl is None and computed_pkl_save_path is None: 148 | raise ValueError(f"must set one of precomputed_pkl, computed_pkl_save_path") 149 | 150 | self.computed_pkl_save_path = computed_pkl_save_path 151 | if precomputed_pkl is None: 152 | self._mean, self._cov = None, None 153 | else: 154 | self._mean, self._cov = self._load_precomputed_pkl(precomputed_pkl) 155 | 156 | self.max_num_examples = max_num_examples 157 | super(FID, self).__init__(output_transform=output_transform, device=device) 158 | 159 | @staticmethod 160 | def _load_precomputed_pkl(precomputed_pkl): 161 | assert Path(precomputed_pkl).exists(), f"{precomputed_pkl} do not exist" 162 | with open(precomputed_pkl, "rb") as f: 163 | reference = pickle.load(f) 164 | mean = reference['mean'] 165 | cov = reference['cov'] 166 | return mean, cov 167 | 168 | def _online_update(self, features: torch.Tensor) -> None: 169 | features = features.to(torch.float64) 170 | 171 | if self.raw_mean is None or self.raw_cov is None: 172 | self.raw_mean = features.sum(dim=0) 173 | self.raw_cov = features.T @ features 174 | return 175 | 176 | self.raw_mean += features.sum(dim=0) 177 | self.raw_cov += features.T @ features 178 | return 179 | 180 | @staticmethod 181 | def _check_feature_input(feature: torch.Tensor) -> None: 182 | if feature.dim() != 2: 183 | raise ValueError(f"Features must be a tensor of dim 2, got: {feature.dim()}") 184 | if feature.shape[0] == 0: 185 | raise ValueError(f"Batch size should be greater than one, got: {feature.shape[0]}") 186 | if feature.shape[1] == 0: 187 | raise ValueError(f"Feature size should be greater than one, got: {feature.shape[1]}") 188 | 189 | @reinit__is_reduced 190 | def reset(self) -> None: 191 | self.raw_mean, self.raw_cov = None, None 192 | self._num_examples = 0 193 | self._last_features = None 194 | super(FID, self).reset() 195 | 196 | def _update(self, features): 197 | # Updates the mean and covariance for features 198 | self._online_update(features) 199 | self._num_examples += features.shape[0] 200 | 201 | @property 202 | def is_full(self): 203 | return self._last_features is not None 204 | 205 | @reinit__is_reduced 206 | def update(self, output: torch.Tensor) -> None: 207 | if self._last_features is not None: 208 | return 209 | 210 | # Extract the features from the outputs 211 | features = self._feature_extractor(output.detach()).to(self._device) 212 | cur_num_examples = 
features.shape[0] 213 | # Check the feature shapes 214 | self._check_feature_input(features) 215 | if self.max_num_examples is None or self.max_num_examples > ( 216 | self._num_examples + cur_num_examples) * idist.get_world_size(): 217 | self._update(features) 218 | else: 219 | self._last_features = idist.all_gather(features) 220 | 221 | @sync_all_reduce("_num_examples", "raw_mean", "raw_cov") 222 | def compute(self): 223 | if self._last_features is not None: 224 | _num = self.max_num_examples - self._num_examples 225 | if _num < 0: 226 | raise RuntimeError(f"max items: {self.max_num_examples} but now we have: {self._num_examples}") 227 | self._update(self._last_features[:_num]) 228 | 229 | if self.max_num_examples is not None: 230 | assert self._num_examples == self.max_num_examples, \ 231 | f"num_examples: {self._num_examples} != {self.max_num_examples}" 232 | 233 | cur_mean = (self.raw_mean / self._num_examples).cpu().numpy() 234 | cur_cov = (self.raw_cov / self._num_examples).cpu().numpy() 235 | cur_cov = cur_cov - np.outer(cur_mean, cur_mean) 236 | 237 | if self.computed_pkl_save_path is not None and idist.get_rank() == 0: 238 | with open(self.computed_pkl_save_path, "wb") as f: 239 | embeds = dict(mean=cur_mean, cov=cur_cov) 240 | pickle.dump(embeds, f) 241 | 242 | if self._mean is not None and self._cov is not None: 243 | fid = fid_score_(cur_mean, cur_cov, self._mean, self._cov, self._eps) 244 | if fid == float("inf"): 245 | warnings.warn("The product of covariance of train and test features is out of bounds.") 246 | return fid 247 | else: 248 | return float("inf") 249 | -------------------------------------------------------------------------------- /models/ada_ops/bias_act.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom PyTorch ops for efficient bias and activation.""" 10 | 11 | import os 12 | import traceback 13 | import warnings 14 | 15 | import numpy as np 16 | import torch 17 | import torch.nn.functional 18 | 19 | import raydl 20 | from .. 
import custom_ops 21 | 22 | # ---------------------------------------------------------------------------- 23 | 24 | activation_funcs = { 25 | 'linear': raydl.AttributeDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', 26 | has_2nd_grad=False), 27 | 'relu': raydl.AttributeDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), 28 | cuda_idx=2, ref='y', has_2nd_grad=False), 29 | 'lrelu': raydl.AttributeDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, 30 | def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False), 31 | 'tanh': raydl.AttributeDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', 32 | has_2nd_grad=True), 33 | 'sigmoid': raydl.AttributeDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', 34 | has_2nd_grad=True), 35 | 'elu': raydl.AttributeDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, 36 | ref='y', 37 | has_2nd_grad=True), 38 | 'selu': raydl.AttributeDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, 39 | ref='y', has_2nd_grad=True), 40 | 'softplus': raydl.AttributeDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, 41 | cuda_idx=8, ref='y', has_2nd_grad=True), 42 | 'swish': raydl.AttributeDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, 43 | ref='x', has_2nd_grad=True), 44 | } 45 | 46 | # ---------------------------------------------------------------------------- 47 | 48 | _inited = False 49 | _plugin = None 50 | _null_tensor = torch.empty([0]) 51 | 52 | 53 | def _init(): 54 | global _inited, _plugin 55 | if not _inited: 56 | _inited = True 57 | sources = ['bias_act.cpp', 'bias_act.cu'] 58 | sources = [os.path.join(os.path.dirname(__file__), s) for s in sources] 59 | try: 60 | _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math']) 61 | except: 62 | warnings.warn( 63 | 'Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc()) 64 | return _plugin is not None 65 | 66 | 67 | # ---------------------------------------------------------------------------- 68 | 69 | def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'): 70 | r"""Fused bias and activation function. 71 | 72 | Adds bias `b` to activation tensor `x`, evaluates activation function `act`, 73 | and scales the result by `gain`. Each of the steps is optional. In most cases, 74 | the fused op is considerably more efficient than performing the same calculation 75 | using standard PyTorch ops. It supports first and second order gradients, 76 | but not third order gradients. 77 | 78 | Args: 79 | x: Input activation tensor. Can be of any shape. 80 | b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type 81 | as `x`. The shape must be known, and it must match the dimension of `x` 82 | corresponding to `dim`. 83 | dim: The dimension in `x` corresponding to the elements of `b`. 84 | The value of `dim` is ignored if `b` is not specified. 85 | act: Name of the activation function to evaluate, or `"linear"` to disable. 86 | Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc. 87 | See `activation_funcs` for a full list. `None` is not allowed. 
88 | alpha: Shape parameter for the activation function, or `None` to use the default. 89 | gain: Scaling factor for the output tensor, or `None` to use default. 90 | See `activation_funcs` for the default scaling of each activation function. 91 | If unsure, consider specifying 1. 92 | clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable 93 | the clamping (default). 94 | impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). 95 | 96 | Returns: 97 | Tensor of the same shape and datatype as `x`. 98 | """ 99 | assert isinstance(x, torch.Tensor) 100 | assert impl in ['ref', 'cuda'] 101 | if impl == 'cuda' and x.device.type == 'cuda' and _init(): 102 | return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b) 103 | return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp) 104 | 105 | 106 | # ---------------------------------------------------------------------------- 107 | 108 | def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None): 109 | """Slow reference implementation of `bias_act()` using standard TensorFlow ops. 110 | """ 111 | assert isinstance(x, torch.Tensor) 112 | assert clamp is None or clamp >= 0 113 | spec = activation_funcs[act] 114 | alpha = float(alpha if alpha is not None else spec.def_alpha) 115 | gain = float(gain if gain is not None else spec.def_gain) 116 | clamp = float(clamp if clamp is not None else -1) 117 | 118 | # Add bias. 119 | if b is not None: 120 | assert isinstance(b, torch.Tensor) and b.ndim == 1 121 | assert 0 <= dim < x.ndim 122 | assert b.shape[0] == x.shape[dim] 123 | x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)]) 124 | 125 | # Evaluate activation function. 126 | alpha = float(alpha) 127 | x = spec.func(x, alpha=alpha) 128 | 129 | # Scale by gain. 130 | gain = float(gain) 131 | if gain != 1: 132 | x = x * gain 133 | 134 | # Clamp. 135 | if clamp >= 0: 136 | x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type 137 | return x 138 | 139 | 140 | # ---------------------------------------------------------------------------- 141 | 142 | _bias_act_cuda_cache = dict() 143 | 144 | 145 | def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None): 146 | """Fast CUDA implementation of `bias_act()` using custom ops. 147 | """ 148 | # Parse arguments. 149 | assert clamp is None or clamp >= 0 150 | spec = activation_funcs[act] 151 | alpha = float(alpha if alpha is not None else spec.def_alpha) 152 | gain = float(gain if gain is not None else spec.def_gain) 153 | clamp = float(clamp if clamp is not None else -1) 154 | 155 | # Lookup from cache. 156 | key = (dim, act, alpha, gain, clamp) 157 | if key in _bias_act_cuda_cache: 158 | return _bias_act_cuda_cache[key] 159 | 160 | # Forward op. 
161 | class BiasActCuda(torch.autograd.Function): 162 | @staticmethod 163 | def forward(ctx, x, b): # pylint: disable=arguments-differ 164 | ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format 165 | x = x.contiguous(memory_format=ctx.memory_format) 166 | b = b.contiguous() if b is not None else _null_tensor 167 | y = x 168 | if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor: 169 | y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, 170 | clamp) 171 | ctx.save_for_backward( 172 | x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, 173 | b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, 174 | y if 'y' in spec.ref else _null_tensor) 175 | return y 176 | 177 | @staticmethod 178 | def backward(ctx, dy): # pylint: disable=arguments-differ 179 | dy = dy.contiguous(memory_format=ctx.memory_format) 180 | x, b, y = ctx.saved_tensors 181 | dx = None 182 | db = None 183 | 184 | if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: 185 | dx = dy 186 | if act != 'linear' or gain != 1 or clamp >= 0: 187 | dx = BiasActCudaGrad.apply(dy, x, b, y) 188 | 189 | if ctx.needs_input_grad[1]: 190 | db = dx.sum([i for i in range(dx.ndim) if i != dim]) 191 | 192 | return dx, db 193 | 194 | # Backward op. 195 | class BiasActCudaGrad(torch.autograd.Function): 196 | @staticmethod 197 | def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ 198 | ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format 199 | dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp) 200 | ctx.save_for_backward( 201 | dy if spec.has_2nd_grad else _null_tensor, 202 | x, b, y) 203 | return dx 204 | 205 | @staticmethod 206 | def backward(ctx, d_dx): # pylint: disable=arguments-differ 207 | d_dx = d_dx.contiguous(memory_format=ctx.memory_format) 208 | dy, x, b, y = ctx.saved_tensors 209 | d_dy = None 210 | d_x = None 211 | d_b = None 212 | d_y = None 213 | 214 | if ctx.needs_input_grad[0]: 215 | d_dy = BiasActCudaGrad.apply(d_dx, x, b, y) 216 | 217 | if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]): 218 | d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp) 219 | 220 | if spec.has_2nd_grad and ctx.needs_input_grad[2]: 221 | d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim]) 222 | 223 | return d_dy, d_x, d_b, d_y 224 | 225 | # Add to cache. 
226 | _bias_act_cuda_cache[key] = BiasActCuda 227 | return BiasActCuda 228 | 229 | # ---------------------------------------------------------------------------- 230 | -------------------------------------------------------------------------------- /raydl/io.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import pickle 4 | import warnings 5 | from io import BytesIO 6 | from pathlib import Path 7 | from typing import Union, Optional, List, Tuple, Text, Iterable, Sequence 8 | 9 | import lmdb 10 | import torch 11 | import torch.nn.functional as F 12 | import torchvision.transforms.functional 13 | from PIL import Image, ImageDraw, ImageFont 14 | from torchvision.datasets.folder import is_image_file, default_loader 15 | from torchvision.utils import make_grid 16 | 17 | __all__ = [ 18 | "load_images", 19 | "save_images", 20 | "draw_captions_over_image", 21 | "resize_images", 22 | "images_files", 23 | "pil_loader", 24 | "LMDBCacheLoader" 25 | ] 26 | 27 | 28 | def resize_images(images: torch.Tensor, resize=None, resize_mode: str = "bilinear") -> torch.Tensor: 29 | """ 30 | resize images, when resize is not None. 31 | :param images: torch.Tensor[NxCxHxW] 32 | :param resize: None means do nothing, or target_size[int]. target_size will be convert to (target_size, target_size) 33 | :param resize_mode: interpolate mode. 34 | :return: resized images. 35 | """ 36 | if resize is None: 37 | return images 38 | if isinstance(resize, (int, float)): 39 | resize = (int(resize), int(resize * images.shape[-1] / images.shape[-2])) 40 | resize = (resize, resize) if isinstance(resize, int) else resize 41 | if resize[0] != images[0].size(-2) or resize[1] != images[0].size(-1): 42 | align_corners = False if resize_mode in ["linear", "bilinear", "bicubic", "trilinear"] else None 43 | images = F.interpolate(images, size=resize, mode=resize_mode, align_corners=align_corners) 44 | return images 45 | 46 | 47 | def pil_loader(path: str, mode="RGB") -> Image.Image: 48 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 49 | with open(path, 'rb') as f: 50 | img = Image.open(f) 51 | return img.convert(mode) 52 | 53 | 54 | def load_images( 55 | images_path: Union[str, Path, Iterable], 56 | resize: Optional[Union[int, Tuple]] = None, 57 | value_range: Tuple[int, int] = (-1, 1), 58 | device: Union[torch.device, str] = torch.device("cuda"), 59 | image_mode: str = "RGB", 60 | resize_mode: str = "bilinear", 61 | ) -> torch.Tensor: 62 | """ 63 | read images into tensor 64 | :param images_path: 65 | :param resize: 66 | :param value_range: 67 | :param device: 68 | :param image_mode: accept a string to specify the mode of image, must in 69 | https://pillow.readthedocs.io/en/latest/handbook/concepts.html#modes 70 | :param resize_mode: accept a string to specify the mode of resize(interpolate), must in 71 | https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.interpolate 72 | :return: images (Tensor[num_images, image_channels, image_height, image_width]) 73 | """ 74 | if isinstance(images_path, (str, Path)): 75 | images_path = [images_path, ] 76 | images = [] 77 | for image_path in images_path: 78 | pil_image = pil_loader(image_path, image_mode) 79 | # 1xCxHxW, value_range: [0, 1] 80 | image = torchvision.transforms.functional.to_tensor(pil_image).unsqueeze_(0) 81 | images.append(resize_images(image, resize=resize, resize_mode=resize_mode)) 82 | images = torch.cat(images).to(device) 83 | images = images 
* (value_range[1] - value_range[0]) + value_range[0] 84 | return images 85 | 86 | 87 | def draw_captions_over_image( 88 | pil_image: Image, 89 | captions: Sequence, 90 | grid_cell_size: Tuple[int, int], 91 | grid_cell_padding: int, 92 | caption_color: str = "#ff0000", 93 | caption_font="DejaVuSans.ttf" 94 | ) -> Image: 95 | """ 96 | draw captions over grid image. use grid_cell_size to specify minimal cell size in grid. 97 | :param pil_image: grid image 98 | :param captions: a sequence of value in (None, str, tuple). 99 | value can be None, which means skip this grid cell; 100 | can be str, which is captions; 101 | can be tuple of (caption, color), for specify color for this caption. 102 | :param grid_cell_size: tuple (height, width) 103 | :param grid_cell_padding: padding when make grid 104 | :param caption_color: the color of the captions, default is red "#ff0000" 105 | :param caption_font: the font of the captions, default is DejaVuSans.ttf, 106 | will find font as https://pillow.readthedocs.io/en/latest/reference/ImageFont.html#PIL.ImageFont.truetype 107 | :return: 108 | """ 109 | h, w = grid_cell_size 110 | padding = grid_cell_padding 111 | nrow = pil_image.width // w 112 | im_draw = ImageDraw.Draw(pil_image) 113 | try: 114 | im_font = ImageFont.truetype(caption_font, size=max(h // 10, 12)) 115 | except OSError: 116 | warnings.warn(f"can not find {caption_font}, so use the default font, better than nothing") 117 | im_font = ImageFont.load_default() 118 | 119 | for i, cap in enumerate(captions): 120 | if cap is None: 121 | continue 122 | cap, fill_color = cap, caption_color if not isinstance(cap, (tuple, list)) else cap 123 | im_draw.text( 124 | ((i % nrow) * (w + padding) - padding, (i // nrow) * (h + padding) - padding), 125 | cap, 126 | fill=fill_color, 127 | font=im_font 128 | ) 129 | return pil_image 130 | 131 | 132 | def infer_pleasant_nrow(length: int): 133 | sqrt_nrow_candidate = int(math.sqrt(length)) 134 | if sqrt_nrow_candidate ** 2 == length: 135 | return sqrt_nrow_candidate 136 | return 2 ** int(math.log2(math.sqrt(length)) + 1) 137 | 138 | 139 | def save_images( 140 | images: Union[torch.Tensor, List[torch.Tensor]], 141 | save_path: Union[Text, Path, Sequence], 142 | captions: Optional[Union[bool, Sequence]] = None, 143 | resize: Optional[Union[int, Tuple[int, int]]] = None, 144 | separately: bool = False, 145 | nrow: Optional[int] = None, 146 | normalize: bool = True, 147 | value_range: Optional[Tuple[int, int]] = (-1, 1), 148 | scale_each: bool = False, 149 | padding: int = 0, 150 | pad_value: int = 0, 151 | caption_color: str = "#ff0000", 152 | caption_font="DejaVuSans.ttf" 153 | ): 154 | """ 155 | save images 156 | :param images: (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W) 157 | or a list of images all of the same size. 158 | :param save_path: path to save images 159 | :param captions: 160 | :param resize: if not None, resize images. 161 | :param separately: if True, save images separately rather make grid 162 | :param nrow: Number of images displayed in each row of the grid. 163 | The final grid size is ``(B / nrow, nrow)``. 164 | :param normalize: If True, shift the image to the range (0, 1), 165 | by the min and max values specified by :attr:`range`. Default: ``False``. 166 | :param value_range: tuple (min, max) where min and max are numbers, 167 | then these numbers are used to normalize the image. 168 | :param scale_each: If ``True``, scale each image in the batch of 169 | images separately rather than the (min, max) over all images. 
Default: ``False``. 170 | :param padding: amount of padding. Default: ``0``. 171 | :param pad_value: Value for the padded pixels. Default: ``0``. 172 | :param caption_color: the color of the captions, default is red "#ff0000" 173 | :param caption_font: the font of the captions, default is DejaVuSans.ttf, 174 | will find font as https://pillow.readthedocs.io/en/latest/reference/ImageFont.html#PIL.ImageFont.truetype 175 | :return: None 176 | """ 177 | assert not (isinstance(save_path, (list, tuple)) and not separately), \ 178 | f"{save_path} separately: {separately}" 179 | if isinstance(captions, bool) and captions: 180 | captions = list(map(str, range(len(images)))) 181 | if not torch.is_tensor(images): 182 | images = torch.cat(images) 183 | images = resize_images(images, resize=resize) 184 | 185 | if not separately: 186 | if nrow is None: 187 | nrow = infer_pleasant_nrow(len(images)) 188 | grid_image = make_grid(images, nrow, padding, normalize, value_range, scale_each, pad_value) 189 | pil_image = Image.fromarray( 190 | grid_image.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()) 191 | if captions is not None: 192 | pil_image = draw_captions_over_image( 193 | pil_image, captions, 194 | grid_cell_size=(images[0].size(-2), images[0].size(-1)), 195 | grid_cell_padding=padding, 196 | caption_color=caption_color, 197 | caption_font=caption_font 198 | ) 199 | pil_image.save(save_path) 200 | return 201 | 202 | if isinstance(save_path, (str, Path)): 203 | save_path = Path(save_path) 204 | save_path = [save_path.with_name(f"{save_path.stem}_{i}{save_path.suffix}") for i in range(len(images))] 205 | assert len(save_path) >= len(images) 206 | for i in range(len(images)): 207 | image = images[i:i + 1] 208 | caption = None if captions is None else captions[i:i + 1] 209 | image = make_grid(image, nrow, padding, normalize, value_range, scale_each, pad_value) 210 | pil_image = Image.fromarray( 211 | image.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()) 212 | if caption is not None: 213 | pil_image = draw_captions_over_image( 214 | pil_image, captions, 215 | grid_cell_size=(image.size(-2), image.size(-1)), 216 | grid_cell_padding=padding, 217 | caption_color=caption_color, 218 | caption_font=caption_font 219 | ) 220 | pil_image.save(save_path[i]) 221 | 222 | 223 | def images_files(image_folder, recursive=False): 224 | pattern = "**/*" if recursive else "*" 225 | root = Path(image_folder).resolve() 226 | if not root.exists(): 227 | return [] 228 | files = [file for file in root.glob(pattern) if is_image_file(file.name)] 229 | files = sorted(files, key=os.path.getmtime) 230 | return files 231 | 232 | 233 | class LMDBCacheLoader: 234 | def __init__(self, lmdb_cache_path, loader=default_loader): 235 | self.loader = loader 236 | self.env = lmdb.open( 237 | lmdb_cache_path, 238 | map_size=1024 ** 4, 239 | readahead=False, 240 | ) 241 | if not self.env: 242 | raise IOError('Cannot open lmdb dataset', lmdb_cache_path) 243 | self.txn = self.env.begin(write=True) 244 | 245 | def __call__(self, file_path): 246 | assert isinstance(file_path, (str, Path)) 247 | file_path = file_path if isinstance(file_path, str) else str(file_path) 248 | 249 | key = file_path.encode('utf-8') 250 | result_bytes = self.txn.get(key) 251 | if result_bytes is None: 252 | result = self.loader(file_path) 253 | # save loaded result to lmdb dataset 254 | buffer = BytesIO() 255 | pickle.dump(result, buffer) 256 | self.txn.put(key, buffer.getvalue()) 257 | return result 258 | 
return pickle.load(BytesIO(result_bytes)) 259 | --------------------------------------------------------------------------------
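A few minimal usage sketches for the modules above; any file names or helper calls not present in the repository are illustrative assumptions. fid_score in raydl/metrics/generation.py takes the per-set feature mean and covariance, so a sketch assuming two pre-extracted Inception feature arrays of shape (N, 2048) (any extractor will do) looks like this:

    import numpy as np
    from raydl.metrics.generation import fid_score

    # Hypothetical pre-extracted features for real and generated images, shape (N, 2048).
    real_feats = np.random.randn(5000, 2048)
    fake_feats = np.random.randn(5000, 2048)

    # Per-set statistics: mean vector and feature covariance.
    real_mean, real_cov = real_feats.mean(axis=0), np.cov(real_feats, rowvar=False)
    fake_mean, fake_cov = fake_feats.mean(axis=0), np.cov(fake_feats, rowvar=False)

    # Requires scipy; returns a plain float.
    print(fid_score(fake_mean, fake_cov, real_mean, real_cov))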
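The FID metric class follows the ignite Metric protocol (reset / update / compute) and, with its default extractor, downloads the StyleGAN2-ADA Inception checkpoint on first use. A minimal sketch, assuming a precomputed statistics pickle ffhq_stats.pkl (a dict with "mean" and "cov" keys) and a hypothetical sample_batch() that returns generated images as an N x 3 x H x W tensor in [-1, 1]:

    import torch
    from raydl.metrics.generation import FID

    fid = FID(precomputed_pkl="ffhq_stats.pkl", max_num_examples=50_000, device="cpu")
    fid.reset()
    # update() stops accumulating once max_num_examples is reached (is_full becomes True).
    while not fid.is_full:
        with torch.no_grad():
            fake = sample_batch()  # hypothetical generator call, images in [-1, 1]
        fid.update(fake)
    print(fid.compute())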
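bias_act() falls back to the pure-PyTorch reference path when impl="ref" (or whenever the CUDA plugin fails to build), so the fused bias/activation/clamp behaviour can be checked without compiling the custom ops. A minimal sketch:

    import torch
    from models.ada_ops.bias_act import bias_act

    x = torch.randn(4, 512, 8, 8)  # activation tensor
    b = torch.zeros(512)           # per-channel bias along dim=1
    # Leaky ReLU with the default gain, clamped to [-256, 256], reference implementation.
    y = bias_act(x, b, dim=1, act="lrelu", gain=None, clamp=256, impl="ref")
    assert y.shape == x.shape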
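load_images and save_images are re-exported at the package root through raydl/__init__.py. A minimal round trip, assuming two hypothetical image files a.jpg and b.jpg on disk:

    import raydl

    # N x C x H x W tensor in [-1, 1]; a single int resize keeps the aspect ratio.
    images = raydl.load_images(["a.jpg", "b.jpg"], resize=256, device="cpu")
    # Saves one grid image with index captions drawn in the top-left of each cell.
    raydl.save_images(images, "grid.jpg", captions=True, nrow=2, value_range=(-1, 1))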