├── .gitignore ├── README.md ├── lightweight_real_esrgan_anime ├── README.md ├── app.py ├── requirements.txt └── screenshot.png ├── real_esrgan ├── README.md ├── realesrgan │ ├── __init__.py │ ├── archs │ │ ├── __init__.py │ │ └── discriminator_arch.py │ ├── data │ │ ├── __init__.py │ │ ├── realesrgan_dataset.py │ │ └── realesrgan_paired_dataset.py │ ├── models │ │ ├── __init__.py │ │ ├── realesrgan_model.py │ │ └── realesrnet_model.py │ ├── train.py │ ├── utils.py │ ├── version.py │ └── weights │ │ └── README.md ├── requirements.txt ├── upscale_image.py └── upscale_image_rgba.py ├── realtime_srgan_anime ├── LICENSE ├── README.md ├── app.py ├── app_camera.py ├── demo.png └── requirements.txt └── yamnet ├── README.md ├── human_voice_example.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ml-examples 2 | -------------------------------------------------------------------------------- /lightweight_real_esrgan_anime/README.md: -------------------------------------------------------------------------------- 1 | ![](./screenshot.png) 2 | 3 | ``` 4 | pip install -r requirements.txt 5 | 6 | python app.py 7 | 8 | # Access shown URL. 9 | ``` -------------------------------------------------------------------------------- /lightweight_real_esrgan_anime/app.py: -------------------------------------------------------------------------------- 1 | import gradio 2 | from huggingface_hub import hf_hub_download 3 | import onnxruntime 4 | from PIL import Image 5 | import numpy as np 6 | 7 | path = hf_hub_download("xiongjie/lightweight-real-ESRGAN-anime", filename="RealESRGAN_x4plus_anime_4B32F.onnx") 8 | session = onnxruntime.InferenceSession(path, providers=["CPUExecutionProvider"]) 9 | 10 | def upscale(np_image_rgb): 11 | # From RGB to BGR 12 | np_image_bgr = np_image_rgb[:, :, ::-1] 13 | np_image_bgr = np_image_bgr.astype(np.float32) 14 | np_image_bgr /= 255 15 | np_image_bgr = np.transpose(np_image_bgr, (2, 0, 1)) 16 | np_image_bgr = np.expand_dims(np_image_bgr, axis=0) 17 | output_img = session.run([], {"image.1": np_image_bgr})[0] 18 | output_img = np.squeeze(output_img, axis=0).astype(np.float32).clip(0, 1) 19 | output_img = np.transpose(output_img, (1, 2, 0)) 20 | output = (output_img * 255.0).astype(np.uint8) 21 | # From BGR to RGB 22 | output = output[:, :, ::-1] 23 | 24 | return output 25 | 26 | css = ".output_image {height: 100% !important; width: 100% !important;}" 27 | inputs = gradio.inputs.Image() 28 | outputs = gradio.outputs.Image() 29 | gradio.Interface(fn=upscale, inputs=inputs, outputs=outputs, css=css).launch() -------------------------------------------------------------------------------- /lightweight_real_esrgan_anime/requirements.txt: -------------------------------------------------------------------------------- 1 | gradio 2 | huggingface-hub 3 | onnxruntime 4 | Pillow 5 | numpy -------------------------------------------------------------------------------- /lightweight_real_esrgan_anime/screenshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiong-jie-y/ml-examples/947e1fb8690a75f8a34db38aee4635d0319044d7/lightweight_real_esrgan_anime/screenshot.png -------------------------------------------------------------------------------- /real_esrgan/README.md: -------------------------------------------------------------------------------- 1 | place test.py in this folder. 
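Both sample scripts shown further below (`upscale_image.py` and `upscale_image_rgba.py`) read an input image named `test.png` from this folder, so place a test image here as well. If you would rather call the downloaded anime ONNX model directly, a minimal onnxruntime sketch follows; the model file name, the `image.1` input name and the BGR float32 NCHW layout are taken from `upscale_image_rgba.py` and are assumptions for any other model file.

```
import numpy as np
import onnxruntime
from PIL import Image

# Assumes the anime ONNX model fetched with the gdown command below sits in this folder.
session = onnxruntime.InferenceSession("RealESRGAN_x4plus_anime_6B.onnx",
                                       providers=["CPUExecutionProvider"])

img = np.array(Image.open("test.png").convert("RGB"), dtype=np.float32) / 255.0
img = np.transpose(img[:, :, ::-1], (2, 0, 1))[None]   # RGB -> BGR, HWC -> NCHW

out = session.run(None, {"image.1": img})[0][0].clip(0, 1)   # (3, 4H, 4W)
out = (np.transpose(out, (1, 2, 0))[:, :, ::-1] * 255.0).round().astype(np.uint8)
Image.fromarray(out).save("test_upscaled.png")
```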
2 | 3 | ``` 4 | pip install -r requirements.txt 5 | 6 | # Download pth model for anime 7 | gdown https://drive.google.com/uc?id=1cExySdxIOh0mw7XK_P_LZiijDKMx9m-p 8 | 9 | # Download onnx model for anime 10 | gdown https://drive.google.com/file/d/1mm7xflWKdCKvCtWDGwh8TUrGItNs-V3I/view?usp=sharing 11 | 12 | # There's also onnx model for wide range of images. 13 | # https://github.com/PINTO0309/PINTO_model_zoo/tree/main/133_Real-ESRGAN 14 | 15 | # (1) No dependency to realsrgan package. 16 | python upscale_image_rgba.py 17 | 18 | # (2) shorter code 19 | python upscale_image.py 20 | ``` -------------------------------------------------------------------------------- /real_esrgan/realesrgan/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from .archs import * 3 | from .data import * 4 | from .models import * 5 | from .utils import * 6 | from .version import __gitsha__, __version__ 7 | -------------------------------------------------------------------------------- /real_esrgan/realesrgan/archs/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from basicsr.utils import scandir 3 | from os import path as osp 4 | 5 | # automatically scan and import arch modules for registry 6 | # scan all the files that end with '_arch.py' under the archs folder 7 | arch_folder = osp.dirname(osp.abspath(__file__)) 8 | arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] 9 | # import all the arch modules 10 | _arch_modules = [importlib.import_module(f'realesrgan.archs.{file_name}') for file_name in arch_filenames] 11 | -------------------------------------------------------------------------------- /real_esrgan/realesrgan/archs/discriminator_arch.py: -------------------------------------------------------------------------------- 1 | from basicsr.utils.registry import ARCH_REGISTRY 2 | from torch import nn as nn 3 | from torch.nn import functional as F 4 | from torch.nn.utils import spectral_norm 5 | 6 | 7 | @ARCH_REGISTRY.register() 8 | class UNetDiscriminatorSN(nn.Module): 9 | """Defines a U-Net discriminator with spectral normalization (SN)""" 10 | 11 | def __init__(self, num_in_ch, num_feat=64, skip_connection=True): 12 | super(UNetDiscriminatorSN, self).__init__() 13 | self.skip_connection = skip_connection 14 | norm = spectral_norm 15 | 16 | self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1) 17 | 18 | self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False)) 19 | self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False)) 20 | self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False)) 21 | # upsample 22 | self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False)) 23 | self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False)) 24 | self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False)) 25 | 26 | # extra 27 | self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False)) 28 | self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False)) 29 | 30 | self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1) 31 | 32 | def forward(self, x): 33 | x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True) 34 | x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True) 35 | x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True) 36 | x3 = 
F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True) 37 | 38 | # upsample 39 | x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False) 40 | x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True) 41 | 42 | if self.skip_connection: 43 | x4 = x4 + x2 44 | x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False) 45 | x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True) 46 | 47 | if self.skip_connection: 48 | x5 = x5 + x1 49 | x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False) 50 | x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True) 51 | 52 | if self.skip_connection: 53 | x6 = x6 + x0 54 | 55 | # extra 56 | out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True) 57 | out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True) 58 | out = self.conv9(out) 59 | 60 | return out 61 | -------------------------------------------------------------------------------- /real_esrgan/realesrgan/data/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from basicsr.utils import scandir 3 | from os import path as osp 4 | 5 | # automatically scan and import dataset modules for registry 6 | # scan all the files that end with '_dataset.py' under the data folder 7 | data_folder = osp.dirname(osp.abspath(__file__)) 8 | dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')] 9 | # import all the dataset modules 10 | _dataset_modules = [importlib.import_module(f'realesrgan.data.{file_name}') for file_name in dataset_filenames] 11 | -------------------------------------------------------------------------------- /real_esrgan/realesrgan/data/realesrgan_dataset.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import math 3 | import numpy as np 4 | import os 5 | import os.path as osp 6 | import random 7 | import time 8 | import torch 9 | from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels 10 | from basicsr.data.transforms import augment 11 | from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor 12 | from basicsr.utils.registry import DATASET_REGISTRY 13 | from torch.utils import data as data 14 | 15 | 16 | @DATASET_REGISTRY.register() 17 | class RealESRGANDataset(data.Dataset): 18 | """ 19 | Dataset used for Real-ESRGAN model. 
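    It reads only GT images and synthesizes the degradation inputs on the fly.
    Each item contains the augmented and cropped/padded GT image ('gt', returned
    as an RGB, CHW float32 tensor in [0, 1]), two randomly generated blur kernels
    for the first and second degradation stages ('kernel1', 'kernel2'), a final
    sinc kernel ('sinc_kernel'), and the source image path ('gt_path').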
20 | """ 21 | 22 | def __init__(self, opt): 23 | super(RealESRGANDataset, self).__init__() 24 | self.opt = opt 25 | # file client (io backend) 26 | self.file_client = None 27 | self.io_backend_opt = opt['io_backend'] 28 | self.gt_folder = opt['dataroot_gt'] 29 | 30 | if self.io_backend_opt['type'] == 'lmdb': 31 | self.io_backend_opt['db_paths'] = [self.gt_folder] 32 | self.io_backend_opt['client_keys'] = ['gt'] 33 | if not self.gt_folder.endswith('.lmdb'): 34 | raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}") 35 | with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin: 36 | self.paths = [line.split('.')[0] for line in fin] 37 | else: 38 | with open(self.opt['meta_info']) as fin: 39 | paths = [line.strip() for line in fin] 40 | self.paths = [os.path.join(self.gt_folder, v) for v in paths] 41 | 42 | # blur settings for the first degradation 43 | self.blur_kernel_size = opt['blur_kernel_size'] 44 | self.kernel_list = opt['kernel_list'] 45 | self.kernel_prob = opt['kernel_prob'] 46 | self.blur_sigma = opt['blur_sigma'] 47 | self.betag_range = opt['betag_range'] 48 | self.betap_range = opt['betap_range'] 49 | self.sinc_prob = opt['sinc_prob'] 50 | 51 | # blur settings for the second degradation 52 | self.blur_kernel_size2 = opt['blur_kernel_size2'] 53 | self.kernel_list2 = opt['kernel_list2'] 54 | self.kernel_prob2 = opt['kernel_prob2'] 55 | self.blur_sigma2 = opt['blur_sigma2'] 56 | self.betag_range2 = opt['betag_range2'] 57 | self.betap_range2 = opt['betap_range2'] 58 | self.sinc_prob2 = opt['sinc_prob2'] 59 | 60 | # a final sinc filter 61 | self.final_sinc_prob = opt['final_sinc_prob'] 62 | 63 | self.kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21 64 | self.pulse_tensor = torch.zeros(21, 21).float() # convolving with pulse tensor brings no blurry effect 65 | self.pulse_tensor[10, 10] = 1 66 | 67 | def __getitem__(self, index): 68 | if self.file_client is None: 69 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) 70 | 71 | # -------------------------------- Load gt images -------------------------------- # 72 | # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32. 73 | gt_path = self.paths[index] 74 | # avoid errors caused by high latency in reading files 75 | retry = 3 76 | while retry > 0: 77 | try: 78 | img_bytes = self.file_client.get(gt_path, 'gt') 79 | except Exception as e: 80 | logger = get_root_logger() 81 | logger.warn(f'File client error: {e}, remaining retry times: {retry - 1}') 82 | # change another file to read 83 | index = random.randint(0, self.__len__()) 84 | gt_path = self.paths[index] 85 | time.sleep(1) # sleep 1s for occasional server congestion 86 | else: 87 | break 88 | finally: 89 | retry -= 1 90 | img_gt = imfrombytes(img_bytes, float32=True) 91 | 92 | # -------------------- augmentation for training: flip, rotation -------------------- # 93 | img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot']) 94 | 95 | # crop or pad to 400: 400 is hard-coded. 
You may change it accordingly 96 | h, w = img_gt.shape[0:2] 97 | crop_pad_size = 400 98 | # pad 99 | if h < crop_pad_size or w < crop_pad_size: 100 | pad_h = max(0, crop_pad_size - h) 101 | pad_w = max(0, crop_pad_size - w) 102 | img_gt = cv2.copyMakeBorder(img_gt, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101) 103 | # crop 104 | if img_gt.shape[0] > crop_pad_size or img_gt.shape[1] > crop_pad_size: 105 | h, w = img_gt.shape[0:2] 106 | # randomly choose top and left coordinates 107 | top = random.randint(0, h - crop_pad_size) 108 | left = random.randint(0, w - crop_pad_size) 109 | img_gt = img_gt[top:top + crop_pad_size, left:left + crop_pad_size, ...] 110 | 111 | # ------------------------ Generate kernels (used in the first degradation) ------------------------ # 112 | kernel_size = random.choice(self.kernel_range) 113 | if np.random.uniform() < self.opt['sinc_prob']: 114 | # this sinc filter setting is for kernels ranging from [7, 21] 115 | if kernel_size < 13: 116 | omega_c = np.random.uniform(np.pi / 3, np.pi) 117 | else: 118 | omega_c = np.random.uniform(np.pi / 5, np.pi) 119 | kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False) 120 | else: 121 | kernel = random_mixed_kernels( 122 | self.kernel_list, 123 | self.kernel_prob, 124 | kernel_size, 125 | self.blur_sigma, 126 | self.blur_sigma, [-math.pi, math.pi], 127 | self.betag_range, 128 | self.betap_range, 129 | noise_range=None) 130 | # pad kernel 131 | pad_size = (21 - kernel_size) // 2 132 | kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size))) 133 | 134 | # ------------------------ Generate kernels (used in the second degradation) ------------------------ # 135 | kernel_size = random.choice(self.kernel_range) 136 | if np.random.uniform() < self.opt['sinc_prob2']: 137 | if kernel_size < 13: 138 | omega_c = np.random.uniform(np.pi / 3, np.pi) 139 | else: 140 | omega_c = np.random.uniform(np.pi / 5, np.pi) 141 | kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False) 142 | else: 143 | kernel2 = random_mixed_kernels( 144 | self.kernel_list2, 145 | self.kernel_prob2, 146 | kernel_size, 147 | self.blur_sigma2, 148 | self.blur_sigma2, [-math.pi, math.pi], 149 | self.betag_range2, 150 | self.betap_range2, 151 | noise_range=None) 152 | 153 | # pad kernel 154 | pad_size = (21 - kernel_size) // 2 155 | kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size))) 156 | 157 | # ------------------------------------- sinc kernel ------------------------------------- # 158 | if np.random.uniform() < self.opt['final_sinc_prob']: 159 | kernel_size = random.choice(self.kernel_range) 160 | omega_c = np.random.uniform(np.pi / 3, np.pi) 161 | sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21) 162 | sinc_kernel = torch.FloatTensor(sinc_kernel) 163 | else: 164 | sinc_kernel = self.pulse_tensor 165 | 166 | # BGR to RGB, HWC to CHW, numpy to tensor 167 | img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0] 168 | kernel = torch.FloatTensor(kernel) 169 | kernel2 = torch.FloatTensor(kernel2) 170 | 171 | return_d = {'gt': img_gt, 'kernel1': kernel, 'kernel2': kernel2, 'sinc_kernel': sinc_kernel, 'gt_path': gt_path} 172 | return return_d 173 | 174 | def __len__(self): 175 | return len(self.paths) 176 | -------------------------------------------------------------------------------- /real_esrgan/realesrgan/data/realesrgan_paired_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from basicsr.data.data_util import 
paired_paths_from_folder, paired_paths_from_lmdb 3 | from basicsr.data.transforms import augment, paired_random_crop 4 | from basicsr.utils import FileClient, imfrombytes, img2tensor 5 | from basicsr.utils.registry import DATASET_REGISTRY 6 | from torch.utils import data as data 7 | from torchvision.transforms.functional import normalize 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class RealESRGANPairedDataset(data.Dataset): 12 | """Paired image dataset for image restoration. 13 | 14 | Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and 15 | GT image pairs. 16 | 17 | There are three modes: 18 | 1. 'lmdb': Use lmdb files. 19 | If opt['io_backend'] == lmdb. 20 | 2. 'meta_info': Use meta information file to generate paths. 21 | If opt['io_backend'] != lmdb and opt['meta_info'] is not None. 22 | 3. 'folder': Scan folders to generate paths. 23 | The rest. 24 | 25 | Args: 26 | opt (dict): Config for train datasets. It contains the following keys: 27 | dataroot_gt (str): Data root path for gt. 28 | dataroot_lq (str): Data root path for lq. 29 | meta_info (str): Path for meta information file. 30 | io_backend (dict): IO backend type and other kwarg. 31 | filename_tmpl (str): Template for each filename. Note that the 32 | template excludes the file extension. Default: '{}'. 33 | gt_size (int): Cropped patched size for gt patches. 34 | use_hflip (bool): Use horizontal flips. 35 | use_rot (bool): Use rotation (use vertical flip and transposing h 36 | and w for implementation). 37 | 38 | scale (bool): Scale, which will be added automatically. 39 | phase (str): 'train' or 'val'. 40 | """ 41 | 42 | def __init__(self, opt): 43 | super(RealESRGANPairedDataset, self).__init__() 44 | self.opt = opt 45 | # file client (io backend) 46 | self.file_client = None 47 | self.io_backend_opt = opt['io_backend'] 48 | self.mean = opt['mean'] if 'mean' in opt else None 49 | self.std = opt['std'] if 'std' in opt else None 50 | 51 | self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq'] 52 | if 'filename_tmpl' in opt: 53 | self.filename_tmpl = opt['filename_tmpl'] 54 | else: 55 | self.filename_tmpl = '{}' 56 | 57 | if self.io_backend_opt['type'] == 'lmdb': 58 | self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder] 59 | self.io_backend_opt['client_keys'] = ['lq', 'gt'] 60 | self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt']) 61 | elif 'meta_info' in self.opt and self.opt['meta_info'] is not None: 62 | with open(self.opt['meta_info']) as fin: 63 | paths = [line.strip() for line in fin] 64 | self.paths = [] 65 | for path in paths: 66 | gt_path, lq_path = path.split(', ') 67 | gt_path = os.path.join(self.gt_folder, gt_path) 68 | lq_path = os.path.join(self.lq_folder, lq_path) 69 | self.paths.append(dict([('gt_path', gt_path), ('lq_path', lq_path)])) 70 | else: 71 | self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl) 72 | 73 | def __getitem__(self, index): 74 | if self.file_client is None: 75 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) 76 | 77 | scale = self.opt['scale'] 78 | 79 | # Load gt and lq images. Dimension order: HWC; channel order: BGR; 80 | # image range: [0, 1], float32. 
81 | gt_path = self.paths[index]['gt_path'] 82 | img_bytes = self.file_client.get(gt_path, 'gt') 83 | img_gt = imfrombytes(img_bytes, float32=True) 84 | lq_path = self.paths[index]['lq_path'] 85 | img_bytes = self.file_client.get(lq_path, 'lq') 86 | img_lq = imfrombytes(img_bytes, float32=True) 87 | 88 | # augmentation for training 89 | if self.opt['phase'] == 'train': 90 | gt_size = self.opt['gt_size'] 91 | # random crop 92 | img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path) 93 | # flip, rotation 94 | img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot']) 95 | 96 | # BGR to RGB, HWC to CHW, numpy to tensor 97 | img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True) 98 | # normalize 99 | if self.mean is not None or self.std is not None: 100 | normalize(img_lq, self.mean, self.std, inplace=True) 101 | normalize(img_gt, self.mean, self.std, inplace=True) 102 | 103 | return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path} 104 | 105 | def __len__(self): 106 | return len(self.paths) 107 | -------------------------------------------------------------------------------- /real_esrgan/realesrgan/models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from basicsr.utils import scandir 3 | from os import path as osp 4 | 5 | # automatically scan and import model modules for registry 6 | # scan all the files that end with '_model.py' under the model folder 7 | model_folder = osp.dirname(osp.abspath(__file__)) 8 | model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')] 9 | # import all the model modules 10 | _model_modules = [importlib.import_module(f'realesrgan.models.{file_name}') for file_name in model_filenames] 11 | -------------------------------------------------------------------------------- /real_esrgan/realesrgan/models/realesrgan_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import torch 4 | from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt 5 | from basicsr.data.transforms import paired_random_crop 6 | from basicsr.models.srgan_model import SRGANModel 7 | from basicsr.utils import DiffJPEG, USMSharp 8 | from basicsr.utils.img_process_util import filter2D 9 | from basicsr.utils.registry import MODEL_REGISTRY 10 | from collections import OrderedDict 11 | from torch.nn import functional as F 12 | 13 | 14 | @MODEL_REGISTRY.register() 15 | class RealESRGANModel(SRGANModel): 16 | """RealESRGAN Model""" 17 | 18 | def __init__(self, opt): 19 | super(RealESRGANModel, self).__init__(opt) 20 | self.jpeger = DiffJPEG(differentiable=False).cuda() 21 | self.usm_sharpener = USMSharp().cuda() 22 | self.queue_size = opt.get('queue_size', 180) 23 | 24 | @torch.no_grad() 25 | def _dequeue_and_enqueue(self): 26 | # training pair pool 27 | # initialize 28 | b, c, h, w = self.lq.size() 29 | if not hasattr(self, 'queue_lr'): 30 | assert self.queue_size % b == 0, 'queue size should be divisible by batch size' 31 | self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda() 32 | _, c, h, w = self.gt.size() 33 | self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda() 34 | self.queue_ptr = 0 35 | if self.queue_ptr == self.queue_size: # full 36 | # do dequeue and enqueue 37 | # shuffle 38 | idx = torch.randperm(self.queue_size) 39 | self.queue_lr = 
self.queue_lr[idx] 40 | self.queue_gt = self.queue_gt[idx] 41 | # get 42 | lq_dequeue = self.queue_lr[0:b, :, :, :].clone() 43 | gt_dequeue = self.queue_gt[0:b, :, :, :].clone() 44 | # update 45 | self.queue_lr[0:b, :, :, :] = self.lq.clone() 46 | self.queue_gt[0:b, :, :, :] = self.gt.clone() 47 | 48 | self.lq = lq_dequeue 49 | self.gt = gt_dequeue 50 | else: 51 | # only do enqueue 52 | self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone() 53 | self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone() 54 | self.queue_ptr = self.queue_ptr + b 55 | 56 | @torch.no_grad() 57 | def feed_data(self, data): 58 | if self.is_train and self.opt.get('high_order_degradation', True): 59 | # training data synthesis 60 | self.gt = data['gt'].to(self.device) 61 | self.gt_usm = self.usm_sharpener(self.gt) 62 | 63 | self.kernel1 = data['kernel1'].to(self.device) 64 | self.kernel2 = data['kernel2'].to(self.device) 65 | self.sinc_kernel = data['sinc_kernel'].to(self.device) 66 | 67 | ori_h, ori_w = self.gt.size()[2:4] 68 | 69 | # ----------------------- The first degradation process ----------------------- # 70 | # blur 71 | out = filter2D(self.gt_usm, self.kernel1) 72 | # random resize 73 | updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0] 74 | if updown_type == 'up': 75 | scale = np.random.uniform(1, self.opt['resize_range'][1]) 76 | elif updown_type == 'down': 77 | scale = np.random.uniform(self.opt['resize_range'][0], 1) 78 | else: 79 | scale = 1 80 | mode = random.choice(['area', 'bilinear', 'bicubic']) 81 | out = F.interpolate(out, scale_factor=scale, mode=mode) 82 | # noise 83 | gray_noise_prob = self.opt['gray_noise_prob'] 84 | if np.random.uniform() < self.opt['gaussian_noise_prob']: 85 | out = random_add_gaussian_noise_pt( 86 | out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob) 87 | else: 88 | out = random_add_poisson_noise_pt( 89 | out, 90 | scale_range=self.opt['poisson_scale_range'], 91 | gray_prob=gray_noise_prob, 92 | clip=True, 93 | rounds=False) 94 | # JPEG compression 95 | jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range']) 96 | out = torch.clamp(out, 0, 1) 97 | out = self.jpeger(out, quality=jpeg_p) 98 | 99 | # ----------------------- The second degradation process ----------------------- # 100 | # blur 101 | if np.random.uniform() < self.opt['second_blur_prob']: 102 | out = filter2D(out, self.kernel2) 103 | # random resize 104 | updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0] 105 | if updown_type == 'up': 106 | scale = np.random.uniform(1, self.opt['resize_range2'][1]) 107 | elif updown_type == 'down': 108 | scale = np.random.uniform(self.opt['resize_range2'][0], 1) 109 | else: 110 | scale = 1 111 | mode = random.choice(['area', 'bilinear', 'bicubic']) 112 | out = F.interpolate( 113 | out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode) 114 | # noise 115 | gray_noise_prob = self.opt['gray_noise_prob2'] 116 | if np.random.uniform() < self.opt['gaussian_noise_prob2']: 117 | out = random_add_gaussian_noise_pt( 118 | out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob) 119 | else: 120 | out = random_add_poisson_noise_pt( 121 | out, 122 | scale_range=self.opt['poisson_scale_range2'], 123 | gray_prob=gray_noise_prob, 124 | clip=True, 125 | rounds=False) 126 | 127 | # JPEG compression + the final sinc filter 128 | # We also need to resize 
images to desired sizes. We group [resize back + sinc filter] together 129 | # as one operation. 130 | # We consider two orders: 131 | # 1. [resize back + sinc filter] + JPEG compression 132 | # 2. JPEG compression + [resize back + sinc filter] 133 | # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines. 134 | if np.random.uniform() < 0.5: 135 | # resize back + the final sinc filter 136 | mode = random.choice(['area', 'bilinear', 'bicubic']) 137 | out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode) 138 | out = filter2D(out, self.sinc_kernel) 139 | # JPEG compression 140 | jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2']) 141 | out = torch.clamp(out, 0, 1) 142 | out = self.jpeger(out, quality=jpeg_p) 143 | else: 144 | # JPEG compression 145 | jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2']) 146 | out = torch.clamp(out, 0, 1) 147 | out = self.jpeger(out, quality=jpeg_p) 148 | # resize back + the final sinc filter 149 | mode = random.choice(['area', 'bilinear', 'bicubic']) 150 | out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode) 151 | out = filter2D(out, self.sinc_kernel) 152 | 153 | # clamp and round 154 | self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255. 155 | 156 | # random crop 157 | gt_size = self.opt['gt_size'] 158 | (self.gt, self.gt_usm), self.lq = paired_random_crop([self.gt, self.gt_usm], self.lq, gt_size, 159 | self.opt['scale']) 160 | 161 | # training pair pool 162 | self._dequeue_and_enqueue() 163 | # sharpen self.gt again, as we have changed the self.gt with self._dequeue_and_enqueue 164 | self.gt_usm = self.usm_sharpener(self.gt) 165 | else: 166 | self.lq = data['lq'].to(self.device) 167 | if 'gt' in data: 168 | self.gt = data['gt'].to(self.device) 169 | self.gt_usm = self.usm_sharpener(self.gt) 170 | 171 | def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): 172 | # do not use the synthetic process during validation 173 | self.is_train = False 174 | super(RealESRGANModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img) 175 | self.is_train = True 176 | 177 | def optimize_parameters(self, current_iter): 178 | l1_gt = self.gt_usm 179 | percep_gt = self.gt_usm 180 | gan_gt = self.gt_usm 181 | if self.opt['l1_gt_usm'] is False: 182 | l1_gt = self.gt 183 | if self.opt['percep_gt_usm'] is False: 184 | percep_gt = self.gt 185 | if self.opt['gan_gt_usm'] is False: 186 | gan_gt = self.gt 187 | 188 | # optimize net_g 189 | for p in self.net_d.parameters(): 190 | p.requires_grad = False 191 | 192 | self.optimizer_g.zero_grad() 193 | self.output = self.net_g(self.lq) 194 | 195 | l_g_total = 0 196 | loss_dict = OrderedDict() 197 | if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters): 198 | # pixel loss 199 | if self.cri_pix: 200 | l_g_pix = self.cri_pix(self.output, l1_gt) 201 | l_g_total += l_g_pix 202 | loss_dict['l_g_pix'] = l_g_pix 203 | # perceptual loss 204 | if self.cri_perceptual: 205 | l_g_percep, l_g_style = self.cri_perceptual(self.output, percep_gt) 206 | if l_g_percep is not None: 207 | l_g_total += l_g_percep 208 | loss_dict['l_g_percep'] = l_g_percep 209 | if l_g_style is not None: 210 | l_g_total += l_g_style 211 | loss_dict['l_g_style'] = l_g_style 212 | # gan loss 213 | fake_g_pred = self.net_d(self.output) 214 | l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False) 215 | l_g_total += l_g_gan 216 | 
loss_dict['l_g_gan'] = l_g_gan 217 | 218 | l_g_total.backward() 219 | self.optimizer_g.step() 220 | 221 | # optimize net_d 222 | for p in self.net_d.parameters(): 223 | p.requires_grad = True 224 | 225 | self.optimizer_d.zero_grad() 226 | # real 227 | real_d_pred = self.net_d(gan_gt) 228 | l_d_real = self.cri_gan(real_d_pred, True, is_disc=True) 229 | loss_dict['l_d_real'] = l_d_real 230 | loss_dict['out_d_real'] = torch.mean(real_d_pred.detach()) 231 | l_d_real.backward() 232 | # fake 233 | fake_d_pred = self.net_d(self.output.detach().clone()) # clone for pt1.9 234 | l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True) 235 | loss_dict['l_d_fake'] = l_d_fake 236 | loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach()) 237 | l_d_fake.backward() 238 | self.optimizer_d.step() 239 | 240 | if self.ema_decay > 0: 241 | self.model_ema(decay=self.ema_decay) 242 | 243 | self.log_dict = self.reduce_loss_dict(loss_dict) 244 | -------------------------------------------------------------------------------- /real_esrgan/realesrgan/models/realesrnet_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import torch 4 | from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt 5 | from basicsr.data.transforms import paired_random_crop 6 | from basicsr.models.sr_model import SRModel 7 | from basicsr.utils import DiffJPEG, USMSharp 8 | from basicsr.utils.img_process_util import filter2D 9 | from basicsr.utils.registry import MODEL_REGISTRY 10 | from torch.nn import functional as F 11 | 12 | 13 | @MODEL_REGISTRY.register() 14 | class RealESRNetModel(SRModel): 15 | """RealESRNet Model""" 16 | 17 | def __init__(self, opt): 18 | super(RealESRNetModel, self).__init__(opt) 19 | self.jpeger = DiffJPEG(differentiable=False).cuda() 20 | self.usm_sharpener = USMSharp().cuda() 21 | self.queue_size = opt.get('queue_size', 180) 22 | 23 | @torch.no_grad() 24 | def _dequeue_and_enqueue(self): 25 | # training pair pool 26 | # initialize 27 | b, c, h, w = self.lq.size() 28 | if not hasattr(self, 'queue_lr'): 29 | assert self.queue_size % b == 0, 'queue size should be divisible by batch size' 30 | self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda() 31 | _, c, h, w = self.gt.size() 32 | self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda() 33 | self.queue_ptr = 0 34 | if self.queue_ptr == self.queue_size: # full 35 | # do dequeue and enqueue 36 | # shuffle 37 | idx = torch.randperm(self.queue_size) 38 | self.queue_lr = self.queue_lr[idx] 39 | self.queue_gt = self.queue_gt[idx] 40 | # get 41 | lq_dequeue = self.queue_lr[0:b, :, :, :].clone() 42 | gt_dequeue = self.queue_gt[0:b, :, :, :].clone() 43 | # update 44 | self.queue_lr[0:b, :, :, :] = self.lq.clone() 45 | self.queue_gt[0:b, :, :, :] = self.gt.clone() 46 | 47 | self.lq = lq_dequeue 48 | self.gt = gt_dequeue 49 | else: 50 | # only do enqueue 51 | self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone() 52 | self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone() 53 | self.queue_ptr = self.queue_ptr + b 54 | 55 | @torch.no_grad() 56 | def feed_data(self, data): 57 | if self.is_train and self.opt.get('high_order_degradation', True): 58 | # training data synthesis 59 | self.gt = data['gt'].to(self.device) 60 | # USM the GT images 61 | if self.opt['gt_usm'] is True: 62 | self.gt = self.usm_sharpener(self.gt) 63 | 64 | self.kernel1 = data['kernel1'].to(self.device) 65 | self.kernel2 = 
data['kernel2'].to(self.device) 66 | self.sinc_kernel = data['sinc_kernel'].to(self.device) 67 | 68 | ori_h, ori_w = self.gt.size()[2:4] 69 | 70 | # ----------------------- The first degradation process ----------------------- # 71 | # blur 72 | out = filter2D(self.gt, self.kernel1) 73 | # random resize 74 | updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0] 75 | if updown_type == 'up': 76 | scale = np.random.uniform(1, self.opt['resize_range'][1]) 77 | elif updown_type == 'down': 78 | scale = np.random.uniform(self.opt['resize_range'][0], 1) 79 | else: 80 | scale = 1 81 | mode = random.choice(['area', 'bilinear', 'bicubic']) 82 | out = F.interpolate(out, scale_factor=scale, mode=mode) 83 | # noise 84 | gray_noise_prob = self.opt['gray_noise_prob'] 85 | if np.random.uniform() < self.opt['gaussian_noise_prob']: 86 | out = random_add_gaussian_noise_pt( 87 | out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob) 88 | else: 89 | out = random_add_poisson_noise_pt( 90 | out, 91 | scale_range=self.opt['poisson_scale_range'], 92 | gray_prob=gray_noise_prob, 93 | clip=True, 94 | rounds=False) 95 | # JPEG compression 96 | jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range']) 97 | out = torch.clamp(out, 0, 1) 98 | out = self.jpeger(out, quality=jpeg_p) 99 | 100 | # ----------------------- The second degradation process ----------------------- # 101 | # blur 102 | if np.random.uniform() < self.opt['second_blur_prob']: 103 | out = filter2D(out, self.kernel2) 104 | # random resize 105 | updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0] 106 | if updown_type == 'up': 107 | scale = np.random.uniform(1, self.opt['resize_range2'][1]) 108 | elif updown_type == 'down': 109 | scale = np.random.uniform(self.opt['resize_range2'][0], 1) 110 | else: 111 | scale = 1 112 | mode = random.choice(['area', 'bilinear', 'bicubic']) 113 | out = F.interpolate( 114 | out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode) 115 | # noise 116 | gray_noise_prob = self.opt['gray_noise_prob2'] 117 | if np.random.uniform() < self.opt['gaussian_noise_prob2']: 118 | out = random_add_gaussian_noise_pt( 119 | out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob) 120 | else: 121 | out = random_add_poisson_noise_pt( 122 | out, 123 | scale_range=self.opt['poisson_scale_range2'], 124 | gray_prob=gray_noise_prob, 125 | clip=True, 126 | rounds=False) 127 | 128 | # JPEG compression + the final sinc filter 129 | # We also need to resize images to desired sizes. We group [resize back + sinc filter] together 130 | # as one operation. 131 | # We consider two orders: 132 | # 1. [resize back + sinc filter] + JPEG compression 133 | # 2. JPEG compression + [resize back + sinc filter] 134 | # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines. 
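            # Pick one of the two orders at random for each batch: with probability 0.5 the
            # [resize back + sinc filter] step runs before the JPEG compression, otherwise after it.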
135 | if np.random.uniform() < 0.5: 136 | # resize back + the final sinc filter 137 | mode = random.choice(['area', 'bilinear', 'bicubic']) 138 | out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode) 139 | out = filter2D(out, self.sinc_kernel) 140 | # JPEG compression 141 | jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2']) 142 | out = torch.clamp(out, 0, 1) 143 | out = self.jpeger(out, quality=jpeg_p) 144 | else: 145 | # JPEG compression 146 | jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2']) 147 | out = torch.clamp(out, 0, 1) 148 | out = self.jpeger(out, quality=jpeg_p) 149 | # resize back + the final sinc filter 150 | mode = random.choice(['area', 'bilinear', 'bicubic']) 151 | out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode) 152 | out = filter2D(out, self.sinc_kernel) 153 | 154 | # clamp and round 155 | self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255. 156 | 157 | # random crop 158 | gt_size = self.opt['gt_size'] 159 | self.gt, self.lq = paired_random_crop(self.gt, self.lq, gt_size, self.opt['scale']) 160 | 161 | # training pair pool 162 | self._dequeue_and_enqueue() 163 | else: 164 | self.lq = data['lq'].to(self.device) 165 | if 'gt' in data: 166 | self.gt = data['gt'].to(self.device) 167 | self.gt_usm = self.usm_sharpener(self.gt) 168 | 169 | def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): 170 | # do not use the synthetic process during validation 171 | self.is_train = False 172 | super(RealESRNetModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img) 173 | self.is_train = True 174 | -------------------------------------------------------------------------------- /real_esrgan/realesrgan/train.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | import os.path as osp 3 | from basicsr.train import train_pipeline 4 | 5 | import realesrgan.archs 6 | import realesrgan.data 7 | import realesrgan.models 8 | 9 | if __name__ == '__main__': 10 | root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) 11 | train_pipeline(root_path) 12 | -------------------------------------------------------------------------------- /real_esrgan/realesrgan/utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import math 3 | import numpy as np 4 | import os 5 | import torch 6 | from basicsr.archs.rrdbnet_arch import RRDBNet 7 | from torch.hub import download_url_to_file, get_dir 8 | from torch.nn import functional as F 9 | from urllib.parse import urlparse 10 | 11 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 12 | 13 | 14 | class RealESRGANer(): 15 | 16 | def __init__(self, scale, model_path, model=None, tile=0, tile_pad=10, pre_pad=10, half=False): 17 | self.scale = scale 18 | self.tile_size = tile 19 | self.tile_pad = tile_pad 20 | self.pre_pad = pre_pad 21 | self.mod_scale = None 22 | self.half = half 23 | 24 | # initialize model 25 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 26 | if model is None: 27 | model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=scale) 28 | 29 | if model_path.startswith('https://'): 30 | model_path = load_file_from_url( 31 | url=model_path, model_dir='realesrgan/weights', progress=True, file_name=None) 32 | loadnet = torch.load(model_path) 33 | if 'params_ema' in 
loadnet: 34 | keyname = 'params_ema' 35 | else: 36 | keyname = 'params' 37 | model.load_state_dict(loadnet[keyname], strict=True) 38 | model.eval() 39 | self.model = model.to(self.device) 40 | if self.half: 41 | self.model = self.model.half() 42 | 43 | def pre_process(self, img): 44 | img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float() 45 | self.img = img.unsqueeze(0).to(self.device) 46 | if self.half: 47 | self.img = self.img.half() 48 | 49 | # pre_pad 50 | if self.pre_pad != 0: 51 | self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect') 52 | # mod pad 53 | if self.scale == 2: 54 | self.mod_scale = 2 55 | elif self.scale == 1: 56 | self.mod_scale = 4 57 | if self.mod_scale is not None: 58 | self.mod_pad_h, self.mod_pad_w = 0, 0 59 | _, _, h, w = self.img.size() 60 | if (h % self.mod_scale != 0): 61 | self.mod_pad_h = (self.mod_scale - h % self.mod_scale) 62 | if (w % self.mod_scale != 0): 63 | self.mod_pad_w = (self.mod_scale - w % self.mod_scale) 64 | self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect') 65 | 66 | def process(self): 67 | import IPython; IPython.embed() 68 | self.output = self.model(self.img) 69 | 70 | def tile_process(self): 71 | """Modified from: https://github.com/ata4/esrgan-launcher 72 | """ 73 | batch, channel, height, width = self.img.shape 74 | output_height = height * self.scale 75 | output_width = width * self.scale 76 | output_shape = (batch, channel, output_height, output_width) 77 | 78 | # start with black image 79 | self.output = self.img.new_zeros(output_shape) 80 | tiles_x = math.ceil(width / self.tile_size) 81 | tiles_y = math.ceil(height / self.tile_size) 82 | 83 | # loop over all tiles 84 | for y in range(tiles_y): 85 | for x in range(tiles_x): 86 | # extract tile from input image 87 | ofs_x = x * self.tile_size 88 | ofs_y = y * self.tile_size 89 | # input tile area on total image 90 | input_start_x = ofs_x 91 | input_end_x = min(ofs_x + self.tile_size, width) 92 | input_start_y = ofs_y 93 | input_end_y = min(ofs_y + self.tile_size, height) 94 | 95 | # input tile area on total image with padding 96 | input_start_x_pad = max(input_start_x - self.tile_pad, 0) 97 | input_end_x_pad = min(input_end_x + self.tile_pad, width) 98 | input_start_y_pad = max(input_start_y - self.tile_pad, 0) 99 | input_end_y_pad = min(input_end_y + self.tile_pad, height) 100 | 101 | # input tile dimensions 102 | input_tile_width = input_end_x - input_start_x 103 | input_tile_height = input_end_y - input_start_y 104 | tile_idx = y * tiles_x + x + 1 105 | input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad] 106 | 107 | # upscale tile 108 | try: 109 | with torch.no_grad(): 110 | output_tile = self.model(input_tile) 111 | except Exception as error: 112 | print('Error', error) 113 | print(f'\tTile {tile_idx}/{tiles_x * tiles_y}') 114 | 115 | # output tile area on total image 116 | output_start_x = input_start_x * self.scale 117 | output_end_x = input_end_x * self.scale 118 | output_start_y = input_start_y * self.scale 119 | output_end_y = input_end_y * self.scale 120 | 121 | # output tile area without padding 122 | output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale 123 | output_end_x_tile = output_start_x_tile + input_tile_width * self.scale 124 | output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale 125 | output_end_y_tile = output_start_y_tile + input_tile_height * self.scale 126 | 127 | # put tile into output image 128 | self.output[:, :, 
output_start_y:output_end_y, 129 | output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile, 130 | output_start_x_tile:output_end_x_tile] 131 | 132 | def post_process(self): 133 | # remove extra pad 134 | if self.mod_scale is not None: 135 | _, _, h, w = self.output.size() 136 | self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale] 137 | # remove prepad 138 | if self.pre_pad != 0: 139 | _, _, h, w = self.output.size() 140 | self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale] 141 | return self.output 142 | 143 | @torch.no_grad() 144 | def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'): 145 | h_input, w_input = img.shape[0:2] 146 | # img: numpy 147 | img = img.astype(np.float32) 148 | if np.max(img) > 256: # 16-bit image 149 | max_range = 65535 150 | print('\tInput is a 16-bit image') 151 | else: 152 | max_range = 255 153 | img = img / max_range 154 | if len(img.shape) == 2: # gray image 155 | img_mode = 'L' 156 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) 157 | elif img.shape[2] == 4: # RGBA image with alpha channel 158 | img_mode = 'RGBA' 159 | alpha = img[:, :, 3] 160 | img = img[:, :, 0:3] 161 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 162 | if alpha_upsampler == 'realesrgan': 163 | alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB) 164 | else: 165 | img_mode = 'RGB' 166 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 167 | 168 | # ------------------- process image (without the alpha channel) ------------------- # 169 | self.pre_process(img) 170 | if self.tile_size > 0: 171 | self.tile_process() 172 | else: 173 | self.process() 174 | output_img = self.post_process() 175 | output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy() 176 | output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0)) 177 | if img_mode == 'L': 178 | output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY) 179 | 180 | # ------------------- process the alpha channel if necessary ------------------- # 181 | if img_mode == 'RGBA': 182 | if alpha_upsampler == 'realesrgan': 183 | self.pre_process(alpha) 184 | if self.tile_size > 0: 185 | self.tile_process() 186 | else: 187 | self.process() 188 | output_alpha = self.post_process() 189 | output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy() 190 | output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0)) 191 | output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY) 192 | else: 193 | h, w = alpha.shape[0:2] 194 | output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR) 195 | 196 | # merge the alpha channel 197 | output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA) 198 | output_img[:, :, 3] = output_alpha 199 | 200 | # ------------------------------ return ------------------------------ # 201 | if max_range == 65535: # 16-bit image 202 | output = (output_img * 65535.0).round().astype(np.uint16) 203 | else: 204 | output = (output_img * 255.0).round().astype(np.uint8) 205 | 206 | if outscale is not None and outscale != float(self.scale): 207 | output = cv2.resize( 208 | output, ( 209 | int(w_input * outscale), 210 | int(h_input * outscale), 211 | ), interpolation=cv2.INTER_LANCZOS4) 212 | 213 | return output, img_mode 214 | 215 | 216 | def load_file_from_url(url, model_dir=None, progress=True, file_name=None): 217 | """Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py 218 | """ 219 | if model_dir 
is None: 220 | hub_dir = get_dir() 221 | model_dir = os.path.join(hub_dir, 'checkpoints') 222 | 223 | os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True) 224 | 225 | parts = urlparse(url) 226 | filename = os.path.basename(parts.path) 227 | if file_name is not None: 228 | filename = file_name 229 | cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename)) 230 | if not os.path.exists(cached_file): 231 | print(f'Downloading: "{url}" to {cached_file}\n') 232 | download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) 233 | return cached_file 234 | -------------------------------------------------------------------------------- /real_esrgan/realesrgan/version.py: -------------------------------------------------------------------------------- 1 | # GENERATED VERSION FILE 2 | # TIME: Sun Oct 31 15:00:17 2021 3 | __version__ = '0.2.2.5' 4 | __gitsha__ = '3338b31' 5 | version_info = (0, 2, 2, 5) 6 | -------------------------------------------------------------------------------- /real_esrgan/realesrgan/weights/README.md: -------------------------------------------------------------------------------- 1 | # Weights 2 | 3 | Put the downloaded weights to this folder. 4 | -------------------------------------------------------------------------------- /real_esrgan/requirements.txt: -------------------------------------------------------------------------------- 1 | onnxruntime-gpu 2 | Pillow 3 | numpy 4 | opencv-python==4.5.2.54 5 | basicsr 6 | click 7 | gdown -------------------------------------------------------------------------------- /real_esrgan/upscale_image.py: -------------------------------------------------------------------------------- 1 | from realesrgan import RealESRGANer 2 | from basicsr.archs.rrdbnet_arch import RRDBNet 3 | import PIL 4 | import numpy as np 5 | 6 | class RealESRGANUpscaler: 7 | def __init__(self): 8 | model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4) 9 | self.upsampler = RealESRGANer( 10 | scale=4, 11 | model_path="RealESRGAN_x4plus_anime_6B.pth", 12 | model=model, 13 | tile=0, 14 | tile_pad=10, 15 | pre_pad=0, 16 | half=False 17 | ) 18 | 19 | def upscale(self, pil_image): 20 | output, _ = self.upsampler.enhance(np.array(pil_image), outscale=4) 21 | pil_image = PIL.Image.fromarray(output, mode='RGBA') 22 | 23 | return pil_image 24 | 25 | upscaler = RealESRGANUpscaler() 26 | image = PIL.Image.open("test.png") 27 | pil_image = upscaler.upscale(image) 28 | pil_image.show() -------------------------------------------------------------------------------- /real_esrgan/upscale_image_rgba.py: -------------------------------------------------------------------------------- 1 | import onnxruntime 2 | import torch 3 | from PIL import Image 4 | import numpy as np 5 | import cv2 6 | import time 7 | from basicsr.archs.rrdbnet_arch import RRDBNet 8 | import click 9 | 10 | class SimpleRealESRGAN: 11 | def __init__(self, onnx=False): 12 | self.onnx = onnx 13 | if onnx: 14 | self.session = onnxruntime.InferenceSession( 15 | "RealESRGAN_x4plus_anime_6B.onnx", providers=["CUDAExecutionProvider"]) 16 | self.device = torch.device('cpu') 17 | else: 18 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 19 | model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4) 20 | loadnet = torch.load("RealESRGAN_x4plus_anime_6B.pth") 21 | if 'params_ema' in loadnet: 22 | keyname = 'params_ema' 23 | else: 24 | keyname = 'params' 25 | 
model.load_state_dict(loadnet[keyname], strict=True) 26 | model.eval() 27 | self.model = model.to(device) 28 | self.device = device 29 | 30 | def upscale_image(self, np_image_rgb): 31 | np_image_rgb = cv2.cvtColor(np_image_rgb, cv2.COLOR_RGB2BGR) 32 | image_rgb_tensor = torch.tensor(np_image_rgb.astype(np.float32)).to(self.device) 33 | image_rgb_tensor /= 255 34 | image_rgb_tensor = image_rgb_tensor.permute(2, 0, 1) 35 | image_rgb_tensor = image_rgb_tensor.unsqueeze(0) 36 | 37 | if self.onnx: 38 | output_img = torch.tensor(self.session.run([], {"image.1": image_rgb_tensor.cpu().numpy()})[0]) 39 | else: 40 | output_img = self.model(image_rgb_tensor) 41 | 42 | output_img = output_img.data.squeeze().float().clamp_(0, 1) 43 | output_img = output_img.permute((1, 2, 0)) 44 | output = (output_img * 255.0).round().cpu().numpy().astype(np.uint8) 45 | output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB) 46 | return output 47 | 48 | @torch.no_grad() 49 | def upscale_rgba_image(self, np_image_rgba): 50 | s = time.time() 51 | upscaled_rgb = self.upscale_image(np_image_rgba[:, :, 0:3]) 52 | upscaled_alpha = np.expand_dims( 53 | cv2.cvtColor( 54 | self.upscale_image( 55 | cv2.cvtColor(np_image_rgba[:, :, 3], cv2.COLOR_GRAY2RGB)), cv2.COLOR_RGB2GRAY), axis=2) 56 | output = np.concatenate((upscaled_rgb, upscaled_alpha), axis=2) 57 | pil_image = Image.fromarray(output, mode="RGBA") 58 | print(time.time() - s) 59 | return pil_image 60 | 61 | @click.command() 62 | @click.option("--onnx", is_flag=True) 63 | def main(onnx): 64 | upscaler = SimpleRealESRGAN(onnx) 65 | np_image = np.array(Image.open("test.png")) 66 | pil_image = upscaler.upscale_rgba_image(np_image) 67 | pil_image.show() 68 | 69 | if __name__ == "__main__": 70 | main() -------------------------------------------------------------------------------- /realtime_srgan_anime/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2021 xiong-jie-y 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /realtime_srgan_anime/README.md: -------------------------------------------------------------------------------- 1 | ## Demo 2 | ![](./demo.png) 3 | 4 | ## How to run it? 5 | ``` 6 | pip install -r requirements.txt 7 | 8 | python app.py 9 | 10 | # Access shown URL. 11 | ``` 12 | 13 | The running speed is slow when running on gradio. 14 | Please try running it in your own program. 
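For example, a minimal sketch of calling the model from your own code, mirroring the `SimpleRealUpscaler` class in `app.py` below (the OpenCV-based file I/O and the file names are assumptions):

```
import cv2
import numpy as np
import torch
from basicsr.archs.srresnet_arch import MSRResNet
from huggingface_hub import hf_hub_download

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MSRResNet(num_in_ch=3, num_out_ch=3, num_feat=32, num_block=6, upscale=4)
path = hf_hub_download("xiongjie/realtime-SRGAN-for-anime", filename="SRGAN_x4plus_anime.pth")
loadnet = torch.load(path)
model.load_state_dict(loadnet.get('params_ema', loadnet.get('params')), strict=True)
model = model.eval().to(device)

# Like app.py, the network consumes and produces BGR images scaled to [0, 1].
img = cv2.imread("input.png").astype(np.float32) / 255.0
tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).to(device)
with torch.no_grad():
    out = model(tensor).squeeze().clamp_(0, 1).permute(1, 2, 0).cpu().numpy()
cv2.imwrite("upscaled.png", (out * 255.0).round().astype(np.uint8))
```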
15 | 16 | ## License 17 | The code and the model used in this folder are licensed under the [MIT license](LICENSE). 18 | You can get the model directly from [Hugging Face](https://huggingface.co/xiongjie/realtime-SRGAN-for-anime). -------------------------------------------------------------------------------- /realtime_srgan_anime/app.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import gradio 4 | from huggingface_hub import hf_hub_download 5 | from basicsr.archs.srresnet_arch import MSRResNet 6 | 7 | class SimpleRealUpscaler: 8 | def __init__(self): 9 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 10 | model = MSRResNet(num_in_ch=3, num_out_ch=3, num_feat=32, num_block=6, upscale=4) 11 | path = hf_hub_download("xiongjie/realtime-SRGAN-for-anime", filename="SRGAN_x4plus_anime.pth") 12 | loadnet = torch.load(path) 13 | if 'params_ema' in loadnet: 14 | keyname = 'params_ema' 15 | else: 16 | keyname = 'params' 17 | model.load_state_dict(loadnet[keyname], strict=True) 18 | model.eval() 19 | self.model = model.to(self.device) 20 | 21 | def upscale(self, np_image_rgb): 22 | image_rgb_tensor = torch.tensor(np_image_rgb[:,:,::-1].astype(np.float32)).to(self.device)  # flip RGB -> BGR before feeding the model 23 | image_rgb_tensor /= 255 24 | image_rgb_tensor = image_rgb_tensor.permute(2, 0, 1) 25 | image_rgb_tensor = image_rgb_tensor.unsqueeze(0) 26 | output_img = self.model(image_rgb_tensor) 27 | output_img = output_img.data.squeeze().float().clamp_(0, 1) 28 | output_img = output_img.permute((1, 2, 0)) 29 | output = (output_img * 255.0).round().cpu().numpy().astype(np.uint8) 30 | return output[:, :, ::-1]  # flip back BGR -> RGB 31 | 32 | 33 | upscaler = SimpleRealUpscaler() 34 | def upscale(np_image_rgb): 35 | return upscaler.upscale(np_image_rgb) 36 | 37 | css = ".output_image {height: 100% !important; width: 100% !important;}" 38 | inputs = gradio.inputs.Image() 39 | outputs = gradio.outputs.Image() 40 | gradio.Interface(fn=upscale, inputs=inputs, outputs=outputs, css=css).launch() -------------------------------------------------------------------------------- /realtime_srgan_anime/app_camera.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import gradio 4 | from huggingface_hub import hf_hub_download 5 | from basicsr.archs.srresnet_arch import MSRResNet 6 | 7 | class SimpleRealUpscaler: 8 | def __init__(self): 9 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 10 | model = MSRResNet(num_in_ch=3, num_out_ch=3, num_feat=32, num_block=6, upscale=4) 11 | path = hf_hub_download("xiongjie/realtime-SRGAN-for-anime", filename="SRGAN_x4plus_anime.pth") 12 | loadnet = torch.load(path) 13 | if 'params_ema' in loadnet: 14 | keyname = 'params_ema' 15 | else: 16 | keyname = 'params' 17 | model.load_state_dict(loadnet[keyname], strict=True) 18 | model.eval() 19 | self.model = model.to(self.device) 20 | 21 | def upscale(self, np_image_rgb): 22 | image_rgb_tensor = torch.tensor(np_image_rgb[:,:,::-1].astype(np.float32)).to(self.device) 23 | image_rgb_tensor /= 255 24 | image_rgb_tensor = image_rgb_tensor.permute(2, 0, 1) 25 | image_rgb_tensor = image_rgb_tensor.unsqueeze(0) 26 | output_img = self.model(image_rgb_tensor) 27 | output_img = output_img.data.squeeze().float().clamp_(0, 1) 28 | output_img = output_img.permute((1, 2, 0)) 29 | output = (output_img * 255.0).round().cpu().numpy().astype(np.uint8) 30 | return output[:, :, ::-1] 31 | 32 | 33 |
34 | import cv2 as cv 35 | 36 | 37 | cap = cv.VideoCapture(0) 38 | upscaler = SimpleRealUpscaler() 39 | 40 | if not cap.isOpened(): 41 | print("Cannot open camera") 42 | exit() 43 | while True: 44 | ret, frame = cap.read() 45 | if not ret: 46 | print("Can't receive frame (stream end?). Exiting ...") 47 | break 48 | 49 | # Halve the camera frame first so the 4x model produces a 2x-sized output. 50 | frame = cv.resize(frame, (frame.shape[1] // 2, frame.shape[0] // 2)) 51 | frame = upscaler.upscale(frame) 52 | 53 | cv.imshow('frame', frame) 54 | if cv.waitKey(1) == ord('q'): 55 | break 56 | -------------------------------------------------------------------------------- /realtime_srgan_anime/demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiong-jie-y/ml-examples/947e1fb8690a75f8a34db38aee4635d0319044d7/realtime_srgan_anime/demo.png -------------------------------------------------------------------------------- /realtime_srgan_anime/requirements.txt: -------------------------------------------------------------------------------- 1 | gradio 2 | huggingface-hub 3 | basicsr 4 | Pillow 5 | numpy 6 | -------------------------------------------------------------------------------- /yamnet/README.md: -------------------------------------------------------------------------------- 1 | ## How to run it? 2 | ``` 3 | pip install -r requirements.txt 4 | 5 | python human_voice_example.py Speech 6 | 7 | # Speak something. The class argument (e.g. Speech) must match a YAMNet class name exactly. 8 | ``` 9 | 10 | ## License 11 | The code in this folder is licensed under [MIT license](LICENSE). 12 | The model used in this script follows [this license](https://tfhub.dev/google/yamnet/1). -------------------------------------------------------------------------------- /yamnet/human_voice_example.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import datetime 3 | import enum 4 | import os 5 | import time 6 | from collections import deque 7 | from dataclasses import dataclass 8 | 9 | import click 10 | import gdown 11 | import librosa 12 | import numpy as np 13 | import onnxruntime 14 | import pyaudio 15 | import scipy 16 | from scipy.io.wavfile import write 17 | 18 | MODEL_PATH_ROOT_ = "models" 19 | 20 | def get_model_file_from_gdrive(name, url): 21 | filepath = os.path.join(MODEL_PATH_ROOT_, name) 22 | if not os.path.exists(filepath): 23 | os.makedirs(MODEL_PATH_ROOT_, exist_ok=True) 24 | gdown.download(url, filepath, quiet=False) 25 | 26 | return filepath 27 | 28 | def ensure_sample_rate(original_sample_rate, waveform, 29 | desired_sample_rate=16000): 30 | """Resample waveform if required.""" 31 | if original_sample_rate != desired_sample_rate: 32 | desired_length = int(round(float(len(waveform)) / 33 | original_sample_rate * desired_sample_rate)) 34 | waveform = scipy.signal.resample(waveform, desired_length) 35 | return desired_sample_rate, waveform 36 | 37 | class HumanVoiceDetector: 38 | def __init__(self): 39 | self.class_names = ['Speech', 'Child speech, kid speaking', 'Conversation', 'Narration, monologue', 'Babbling', 'Speech synthesizer', 'Shout', 'Bellow', 'Whoop', 'Yell', 'Children shouting', 'Screaming', 'Whispering', 'Laughter', 'Baby laughter', 'Giggle', 'Snicker', 'Belly laugh', 'Chuckle, chortle', 'Crying, sobbing', 'Baby cry, infant cry', 'Whimper', 'Wail, moan', 'Sigh', 'Singing', 'Choir', 'Yodeling', 'Chant', 'Mantra', 'Child singing', 'Synthetic singing', 'Rapping', 'Humming', 'Groan', 'Grunt', 'Whistling', 'Breathing', 'Wheeze', 'Snoring', 'Gasp', 'Pant', 'Snort', 'Cough', 'Throat
clearing', 'Sneeze', 'Sniff', 'Run', 'Shuffle', 'Walk, footsteps', 'Chewing, mastication', 'Biting', 'Gargling', 'Stomach rumble', 'Burping, eructation', 'Hiccup', 'Fart', 'Hands', 'Finger snapping', 'Clapping', 'Heart sounds, heartbeat', 'Heart murmur', 'Cheering', 'Applause', 'Chatter', 'Crowd', 'Hubbub, speech noise, speech babble', 'Children playing', 'Animal', 'Domestic animals, pets', 'Dog', 'Bark', 'Yip', 'Howl', 'Bow-wow', 'Growling', 'Whimper (dog)', 'Cat', 'Purr', 'Meow', 'Hiss', 'Caterwaul', 'Livestock, farm animals, working animals', 'Horse', 'Clip-clop', 'Neigh, whinny', 'Cattle, bovinae', 'Moo', 'Cowbell', 'Pig', 'Oink', 'Goat', 'Bleat', 'Sheep', 'Fowl', 'Chicken, rooster', 'Cluck', 'Crowing, cock-a-doodle-doo', 'Turkey', 'Gobble', 'Duck', 'Quack', 'Goose', 'Honk', 'Wild animals', 'Roaring cats (lions, tigers)', 'Roar', 'Bird', 'Bird vocalization, bird call, bird song', 'Chirp, tweet', 'Squawk', 'Pigeon, dove', 'Coo', 'Crow', 'Caw', 'Owl', 'Hoot', 'Bird flight, flapping wings', 'Canidae, dogs, wolves', 'Rodents, rats, mice', 'Mouse', 'Patter', 'Insect', 'Cricket', 'Mosquito', 'Fly, housefly', 'Buzz', 'Bee, wasp, etc.', 'Frog', 'Croak', 'Snake', 'Rattle', 'Whale vocalization', 'Music', 'Musical instrument', 'Plucked string instrument', 'Guitar', 'Electric guitar', 'Bass guitar', 'Acoustic guitar', 'Steel guitar, slide guitar', 'Tapping (guitar technique)', 'Strum', 'Banjo', 'Sitar', 'Mandolin', 'Zither', 'Ukulele', 'Keyboard (musical)', 'Piano', 'Electric piano', 'Organ', 'Electronic organ', 'Hammond organ', 'Synthesizer', 'Sampler', 'Harpsichord', 'Percussion', 'Drum kit', 'Drum machine', 'Drum', 'Snare drum', 'Rimshot', 'Drum roll', 'Bass drum', 'Timpani', 'Tabla', 'Cymbal', 'Hi-hat', 'Wood block', 'Tambourine', 'Rattle (instrument)', 'Maraca', 'Gong', 'Tubular bells', 'Mallet percussion', 'Marimba, xylophone', 'Glockenspiel', 'Vibraphone', 'Steelpan', 'Orchestra', 'Brass instrument', 'French horn', 'Trumpet', 'Trombone', 'Bowed string instrument', 'String section', 'Violin, fiddle', 'Pizzicato', 'Cello', 'Double bass', 'Wind instrument, woodwind instrument', 'Flute', 'Saxophone', 'Clarinet', 'Harp', 'Bell', 'Church bell', 'Jingle bell', 'Bicycle bell', 'Tuning fork', 'Chime', 'Wind chime', 'Change ringing (campanology)', 'Harmonica', 'Accordion', 'Bagpipes', 'Didgeridoo', 'Shofar', 'Theremin', 'Singing bowl', 'Scratching (performance technique)', 'Pop music', 'Hip hop music', 'Beatboxing', 'Rock music', 'Heavy metal', 'Punk rock', 'Grunge', 'Progressive rock', 'Rock and roll', 'Psychedelic rock', 'Rhythm and blues', 'Soul music', 'Reggae', 'Country', 'Swing music', 'Bluegrass', 'Funk', 'Folk music', 'Middle Eastern music', 'Jazz', 'Disco', 'Classical music', 'Opera', 'Electronic music', 'House music', 'Techno', 'Dubstep', 'Drum and bass', 'Electronica', 'Electronic dance music', 'Ambient music', 'Trance music', 'Music of Latin America', 'Salsa music', 'Flamenco', 'Blues', 'Music for children', 'New-age music', 'Vocal music', 'A capella', 'Music of Africa', 'Afrobeat', 'Christian music', 'Gospel music', 'Music of Asia', 'Carnatic music', 'Music of Bollywood', 'Ska', 'Traditional music', 'Independent music', 'Song', 'Background music', 'Theme music', 'Jingle (music)', 'Soundtrack music', 'Lullaby', 'Video game music', 'Christmas music', 'Dance music', 'Wedding music', 'Happy music', 'Sad music', 'Tender music', 'Exciting music', 'Angry music', 'Scary music', 'Wind', 'Rustling leaves', 'Wind noise (microphone)', 'Thunderstorm', 'Thunder', 'Water', 'Rain', 'Raindrop', 'Rain on 
surface', 'Stream', 'Waterfall', 'Ocean', 'Waves, surf', 'Steam', 'Gurgling', 'Fire', 'Crackle', 'Vehicle', 'Boat, Water vehicle', 'Sailboat, sailing ship', 'Rowboat, canoe, kayak', 'Motorboat, speedboat', 'Ship', 'Motor vehicle (road)', 'Car', 'Vehicle horn, car horn, honking', 'Toot', 'Car alarm', 'Power windows, electric windows', 'Skidding', 'Tire squeal', 'Car passing by', 'Race car, auto racing', 'Truck', 'Air brake', 'Air horn, truck horn', 'Reversing beeps', 'Ice cream truck, ice cream van', 'Bus', 'Emergency vehicle', 'Police car (siren)', 'Ambulance (siren)', 'Fire engine, fire truck (siren)', 'Motorcycle', 'Traffic noise, roadway noise', 'Rail transport', 'Train', 'Train whistle', 'Train horn', 'Railroad car, train wagon', 'Train wheels squealing', 'Subway, metro, underground', 'Aircraft', 'Aircraft engine', 'Jet engine', 'Propeller, airscrew', 'Helicopter', 'Fixed-wing aircraft, airplane', 'Bicycle', 'Skateboard', 'Engine', 'Light engine (high frequency)', "Dental drill, dentist's drill", 'Lawn mower', 'Chainsaw', 'Medium engine (mid frequency)', 'Heavy engine (low frequency)', 'Engine knocking', 'Engine starting', 'Idling', 'Accelerating, revving, vroom', 'Door', 'Doorbell', 'Ding-dong', 'Sliding door', 'Slam', 'Knock', 'Tap', 'Squeak', 'Cupboard open or close', 'Drawer open or close', 'Dishes, pots, and pans', 'Cutlery, silverware', 'Chopping (food)', 'Frying (food)', 'Microwave oven', 'Blender', 'Water tap, faucet', 'Sink (filling or washing)', 'Bathtub (filling or washing)', 'Hair dryer', 'Toilet flush', 'Toothbrush', 'Electric toothbrush', 'Vacuum cleaner', 'Zipper (clothing)', 'Keys jangling', 'Coin (dropping)', 'Scissors', 'Electric shaver, electric razor', 'Shuffling cards', 'Typing', 'Typewriter', 'Computer keyboard', 'Writing', 'Alarm', 'Telephone', 'Telephone bell ringing', 'Ringtone', 'Telephone dialing, DTMF', 'Dial tone', 'Busy signal', 'Alarm clock', 'Siren', 'Civil defense siren', 'Buzzer', 'Smoke detector, smoke alarm', 'Fire alarm', 'Foghorn', 'Whistle', 'Steam whistle', 'Mechanisms', 'Ratchet, pawl', 'Clock', 'Tick', 'Tick-tock', 'Gears', 'Pulleys', 'Sewing machine', 'Mechanical fan', 'Air conditioning', 'Cash register', 'Printer', 'Camera', 'Single-lens reflex camera', 'Tools', 'Hammer', 'Jackhammer', 'Sawing', 'Filing (rasp)', 'Sanding', 'Power tool', 'Drill', 'Explosion', 'Gunshot, gunfire', 'Machine gun', 'Fusillade', 'Artillery fire', 'Cap gun', 'Fireworks', 'Firecracker', 'Burst, pop', 'Eruption', 'Boom', 'Wood', 'Chop', 'Splinter', 'Crack', 'Glass', 'Chink, clink', 'Shatter', 'Liquid', 'Splash, splatter', 'Slosh', 'Squish', 'Drip', 'Pour', 'Trickle, dribble', 'Gush', 'Fill (with liquid)', 'Spray', 'Pump (liquid)', 'Stir', 'Boiling', 'Sonar', 'Arrow', 'Whoosh, swoosh, swish', 'Thump, thud', 'Thunk', 'Electronic tuner', 'Effects unit', 'Chorus effect', 'Basketball bounce', 'Bang', 'Slap, smack', 'Whack, thwack', 'Smash, crash', 'Breaking', 'Bouncing', 'Whip', 'Flap', 'Scratch', 'Scrape', 'Rub', 'Roll', 'Crushing', 'Crumpling, crinkling', 'Tearing', 'Beep, bleep', 'Ping', 'Ding', 'Clang', 'Squeal', 'Creak', 'Rustle', 'Whir', 'Clatter', 'Sizzle', 'Clicking', 'Clickety-clack', 'Rumble', 'Plop', 'Jingle, tinkle', 'Hum', 'Zing', 'Boing', 'Crunch', 'Silence', 'Sine wave', 'Harmonic', 'Chirp tone', 'Sound effect', 'Pulse', 'Inside, small room', 'Inside, large room or hall', 'Inside, public space', 'Outside, urban or manmade', 'Outside, rural or natural', 'Reverberation', 'Echo', 'Noise', 'Environmental noise', 'Static', 'Mains hum', 'Distortion', 'Sidetone', 
'Cacophony', 'White noise', 'Pink noise', 'Throbbing', 'Vibration', 'Television', 'Radio', 'Field recording'] 40 | 41 | def wait_for_human_voice(self, detect_class): 42 | frame_len = int(16000 * 0.96 * 0.1)  # 0.096 s of 16 kHz audio; ten chunks make up one ~0.96 s analysis window 43 | 44 | p = pyaudio.PyAudio() 45 | stream = p.open(format=pyaudio.paInt16, 46 | channels=1, 47 | rate=16000, 48 | input=True, 49 | frames_per_buffer=frame_len) 50 | 51 | buffers = deque() 52 | left_list = [] 53 | 54 | providers = ['CPUExecutionProvider'] 55 | 56 | session = onnxruntime.InferenceSession( 57 | get_model_file_from_gdrive("yamnet.onnx", "https://drive.google.com/uc?id=1u7V15wRp3_gcUdXPzm9WtJy51ENpCqEC"), 58 | providers=providers) 59 | 60 | found_speech = False 61 | 62 | while True: 63 | data = stream.read(frame_len, exception_on_overflow=False) 64 | frame_data = librosa.util.buf_to_float(data, n_bytes=2, dtype=np.int16) 65 | 66 | buffers.append(frame_data) 67 | if len(buffers) > 10:  # keep a ten-chunk sliding window; after detection starts, retain the chunks that slide out so the whole utterance can be saved 68 | if found_speech: 69 | left_list.append(buffers.popleft()) 70 | else: 71 | buffers.popleft() 72 | 73 | this_frame_data = np.concatenate(buffers) 74 | 75 | input_name = session.get_inputs()[0].name 76 | input_wave = this_frame_data.astype(np.float32) 77 | 78 | outputs = session.run([], {input_name: input_wave})[0][0] 79 | class_name = self.class_names[np.argmax(outputs)] 80 | 81 | print(class_name) 82 | 83 | if class_name == detect_class: 84 | found_speech = True 85 | 86 | elif found_speech and class_name != detect_class: 87 | sound = np.concatenate(left_list + list(buffers)) 88 | write('keyword.wav', 16000, sound)  # save the detected segment as a 16 kHz WAV file 89 | break 90 | 91 | stream.stop_stream() 92 | stream.close() 93 | p.terminate() 94 | 95 | @click.command() 96 | @click.argument("detect_class") 97 | def main(detect_class): 98 | voice_detector = HumanVoiceDetector() 99 | voice_detector.wait_for_human_voice(detect_class) 100 | 101 | 102 | if __name__ == "__main__": 103 | main() -------------------------------------------------------------------------------- /yamnet/requirements.txt: -------------------------------------------------------------------------------- 1 | PyAudio 2 | onnxruntime 3 | scipy 4 | gdown 5 | librosa 6 | numpy 7 | click --------------------------------------------------------------------------------