├── tests ├── __init__.py ├── test_show.py ├── test_super_resolution.py ├── test_emojich_unet.py ├── test_tokenizer.py ├── test_image_prompts.py ├── test_dalle.py ├── conftest.py └── test_vae.py ├── pics ├── cat-ru.png ├── man-0.png ├── man-1.png ├── man-2.png ├── man-3.png ├── man-4.png ├── man-5.png ├── man-ru.png ├── cat-gan.png ├── man-gan.png ├── woman-gan.png ├── woman-ru.png ├── avocado-gan.png ├── avocado-ru.png ├── cat-diffusion.png ├── cathedral-gan.png ├── cathedral-ru.png ├── emojich │ ├── examples.png │ ├── emoji-Donald.png │ ├── emojich_rgba.png │ ├── emojich-stickers.png │ └── emojich_rgba_100.png ├── woman-diffusion.png ├── avocado-diffusion.png ├── cathedral-diffusion.png ├── malevich │ ├── rainbow-full.png │ ├── rainbow-cherry-pick.png │ ├── rainbow-super-resolution.png │ ├── anime-girl-super-resolution.png │ └── russian-temple-image-prompt.png ├── habr_eng.svg └── habr.svg ├── requirements-test.txt ├── .coveragerc ├── requirements.txt ├── setup.cfg ├── .gitlab-ci.yml ├── .pre-commit-config.yaml ├── rudalle ├── __init__.py ├── vae │ ├── vqgan.gumbelf8-sber.config.yml │ ├── __init__.py │ ├── decoder_dwt.py │ ├── model.py │ └── pytorch_wavelets_utils.py ├── ruclip │ ├── __init__.py │ └── processor.py ├── utils.py ├── realesrgan │ ├── __init__.py │ ├── model.py │ ├── utils.py │ ├── rrdbnet_arch.py │ └── arch_util.py ├── emojich_unet │ └── __init__.py ├── dalle │ ├── utils.py │ ├── image_attention.py │ ├── fp16.py │ ├── __init__.py │ ├── model.py │ └── transformer.py ├── tokenizer.py ├── image_prompts.py └── pipelines.py ├── setup.py ├── .gitignore ├── README.md ├── Emojich.md └── LICENSE.txt /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pics/cat-ru.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/cat-ru.png 
-------------------------------------------------------------------------------- /pics/man-0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/man-0.png -------------------------------------------------------------------------------- /pics/man-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/man-1.png -------------------------------------------------------------------------------- /pics/man-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/man-2.png -------------------------------------------------------------------------------- /pics/man-3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/man-3.png -------------------------------------------------------------------------------- /pics/man-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/man-4.png -------------------------------------------------------------------------------- /pics/man-5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/man-5.png -------------------------------------------------------------------------------- /pics/man-ru.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/man-ru.png -------------------------------------------------------------------------------- /pics/cat-gan.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/cat-gan.png -------------------------------------------------------------------------------- /pics/man-gan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/man-gan.png -------------------------------------------------------------------------------- /pics/woman-gan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/woman-gan.png -------------------------------------------------------------------------------- /pics/woman-ru.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/woman-ru.png -------------------------------------------------------------------------------- /requirements-test.txt: -------------------------------------------------------------------------------- 1 | -r requirements.txt 2 | pytest 3 | pytest-cov 4 | pre-commit 5 | -------------------------------------------------------------------------------- /pics/avocado-gan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/avocado-gan.png -------------------------------------------------------------------------------- /pics/avocado-ru.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/avocado-ru.png -------------------------------------------------------------------------------- /pics/cat-diffusion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/cat-diffusion.png -------------------------------------------------------------------------------- /pics/cathedral-gan.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/cathedral-gan.png -------------------------------------------------------------------------------- /pics/cathedral-ru.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/cathedral-ru.png -------------------------------------------------------------------------------- /pics/emojich/examples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/emojich/examples.png -------------------------------------------------------------------------------- /pics/woman-diffusion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/woman-diffusion.png -------------------------------------------------------------------------------- /pics/avocado-diffusion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/avocado-diffusion.png -------------------------------------------------------------------------------- /pics/cathedral-diffusion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/cathedral-diffusion.png -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | omit = 3 | # omit this single file 4 | rudalle/vae/pytorch_wavelets_utils.py 5 | -------------------------------------------------------------------------------- /pics/emojich/emoji-Donald.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/emojich/emoji-Donald.png -------------------------------------------------------------------------------- /pics/emojich/emojich_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/emojich/emojich_rgba.png -------------------------------------------------------------------------------- /pics/malevich/rainbow-full.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/malevich/rainbow-full.png -------------------------------------------------------------------------------- /pics/emojich/emojich-stickers.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/emojich/emojich-stickers.png -------------------------------------------------------------------------------- /pics/emojich/emojich_rgba_100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/emojich/emojich_rgba_100.png -------------------------------------------------------------------------------- /pics/malevich/rainbow-cherry-pick.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/malevich/rainbow-cherry-pick.png -------------------------------------------------------------------------------- /pics/malevich/rainbow-super-resolution.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/malevich/rainbow-super-resolution.png -------------------------------------------------------------------------------- /pics/malevich/anime-girl-super-resolution.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/malevich/anime-girl-super-resolution.png -------------------------------------------------------------------------------- /pics/malevich/russian-temple-image-prompt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/malevich/russian-temple-image-prompt.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | taming-transformers==0.0.1 2 | more_itertools~=8.10.0 3 | transformers~=4.10.2 4 | youtokentome~=1.0.6 5 | omegaconf>=2.0.0 6 | einops~=0.3.2 7 | PyWavelets==1.1.1 8 | segmentation-models-pytorch==0.1.3 9 | opencv-python==4.5.4.60 10 | torch 11 | torchvision 12 | matplotlib 13 | -------------------------------------------------------------------------------- /tests/test_show.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from rudalle.pipelines import show 3 | 4 | 5 | def test_show(sample_image): 6 | img = sample_image.copy() 7 | img = img.resize((256, 256)) 8 | pil_images = [img]*5 9 | show(pil_images, nrow=2, save_dir='/tmp/pics', show=False) 10 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [pep8] 2 | max-line-length = 120 3 | exclude = .tox,*migrations*,.json 4 | 5 | [flake8] 6 | max-line-length = 120 7 | exclude = .tox,*migrations*,.json 8 | 9 | [autopep8-wrapper] 10 | exclude = .tox,*migrations*,.json 11 | 12 | [check-docstring-first] 13 | exclude = .tox,*migrations*,.json 14 | -------------------------------------------------------------------------------- /tests/test_super_resolution.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from rudalle.pipelines import super_resolution 3 | 4 | 5 | def test_super_resolution(sample_image, realesrgan): 6 | img = sample_image.copy() 7 | img = img.resize((32, 32)) 8 | sr_img = super_resolution([img], realesrgan)[0] 9 | assert sr_img.size[0] == 32*2 10 | assert sr_img.size[1] == 32*2 11 | -------------------------------------------------------------------------------- /.gitlab-ci.yml: -------------------------------------------------------------------------------- 1 | stages: 2 | - test 3 | 4 | all_branch_test: 5 | stage: test 6 | tags: 7 | - docker 8 | image: python:3.9 9 | script: 10 | - apt-get update ##[edited] 11 | - apt-get install ffmpeg libsm6 libxext6 -y 12 | - pip install cython 13 | - pip install -r requirements-test.txt --no-cache-dir 14 | - pip install timm==0.4.12 15 | - pip install codecov 16 | - pytest --cov=rudalle tests/ 17 | - bash <(curl -s https://codecov.io/bash) -t $CODECOV_TOKEN 18 | except: 19 | - tags 20 | -------------------------------------------------------------------------------- /tests/test_emojich_unet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | 4 | from rudalle.pipelines import convert_emoji_to_rgba 5 | 6 | 7 | def test_convert_emoji_to_rgba(sample_image, emojich_unet): 8 | img = sample_image.copy() 9 | img = img.resize((512, 512)) 10 | rgba_images, runs = convert_emoji_to_rgba([img], emojich_unet, score_thr=0.99) 11 | assert len(runs) == len(rgba_images) 12 | rgba_img = rgba_images[0] 13 | assert rgba_img.size[0] == 512 14 | assert rgba_img.size[1] == 512 15 | assert np.array(rgba_img).shape[-1] == 4 16 | assert runs[0] in ['unet', 'classic'] 17 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: 
-------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.0.1 4 | hooks: 5 | - id: check-docstring-first 6 | - id: check-merge-conflict 7 | stages: 8 | - push 9 | - id: double-quote-string-fixer 10 | - id: end-of-file-fixer 11 | - id: fix-encoding-pragma 12 | - id: mixed-line-ending 13 | - id: trailing-whitespace 14 | - repo: https://github.com/pycqa/flake8 15 | rev: "4.0.1" 16 | hooks: 17 | - id: flake8 18 | args: ['--config=setup.cfg'] 19 | - repo: https://github.com/pre-commit/mirrors-autopep8 20 | rev: v1.5.7 21 | hooks: 22 | - id: autopep8 23 | -------------------------------------------------------------------------------- /rudalle/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from .vae import get_vae 3 | from .dalle import get_rudalle_model 4 | from .tokenizer import get_tokenizer 5 | from .realesrgan import get_realesrgan 6 | from .ruclip import get_ruclip 7 | from .emojich_unet import get_emojich_unet 8 | from . import vae, dalle, tokenizer, realesrgan, pipelines, ruclip, image_prompts 9 | 10 | 11 | __all__ = [ 12 | 'get_vae', 13 | 'get_rudalle_model', 14 | 'get_tokenizer', 15 | 'get_realesrgan', 16 | 'get_ruclip', 17 | 'get_emojich_unet', 18 | 'vae', 19 | 'dalle', 20 | 'ruclip', 21 | 'tokenizer', 22 | 'realesrgan', 23 | 'pipelines', 24 | 'image_prompts', 25 | ] 26 | 27 | __version__ = '0.4.0' 28 | -------------------------------------------------------------------------------- /tests/test_tokenizer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import pytest 3 | 4 | 5 | @pytest.mark.parametrize('text, text_seq_length, bpe_dropout', [ 6 | ('hello, how are you?', 128, 0.1), 7 | ('hello, how are you?', 128, 0.5), 8 | ('hello, how are you?', 128, 1.0), 9 | ('hello ... 
how are you ?', 256, 1.0), 10 | ('a person standing at a table with bottles of win', 64, 0.5), 11 | ('привет как дела???', 76, 0.0), 12 | ('клип на русском языке :)', 76, 0.1), 13 | ]) 14 | def test_encode_decode_text_yttm(yttm_tokenizer, text, text_seq_length, bpe_dropout): 15 | tokens = yttm_tokenizer.encode_text(text, text_seq_length=text_seq_length, bpe_dropout=bpe_dropout) 16 | decoded_text = yttm_tokenizer.decode_text(tokens) 17 | assert text == decoded_text 18 | -------------------------------------------------------------------------------- /tests/test_image_prompts.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import pytest 3 | 4 | from rudalle.image_prompts import ImagePrompts 5 | 6 | 7 | @pytest.mark.parametrize('borders, crop_first', [ 8 | ({'up': 4, 'right': 0, 'left': 0, 'down': 0}, False), 9 | ({'up': 4, 'right': 0, 'left': 0, 'down': 0}, True), 10 | ({'up': 4, 'right': 3, 'left': 3, 'down': 3}, False) 11 | ]) 12 | def test_image_prompts(sample_image, vae, borders, crop_first): 13 | img = sample_image.copy() 14 | img = img.resize((256, 256)) 15 | image_prompt = ImagePrompts(img, borders, vae, crop_first=crop_first) 16 | assert image_prompt.image_prompts.shape[1] == 32 * 32 17 | assert len(image_prompt.image_prompts_idx) == (borders['up'] + borders['down']) * 32 \ 18 | + (borders['left'] + borders['right']) * (32 - borders['up'] - borders['down']) 19 | -------------------------------------------------------------------------------- /rudalle/vae/vqgan.gumbelf8-sber.config.yml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: taming.models.vqgan.GumbelVQ 4 | params: 5 | kl_weight: 1.0e-08 6 | embed_dim: 256 7 | n_embed: 8192 8 | monitor: val/rec_loss 9 | temperature_scheduler_config: 10 | target: taming.lr_scheduler.LambdaWarmUpCosineScheduler 11 | params: 12 | warm_up_steps: 0 13 | 
max_decay_steps: 1000001 14 | lr_start: 0.9 15 | lr_max: 0.9 16 | lr_min: 1.0e-06 17 | ddconfig: 18 | double_z: false 19 | z_channels: 256 20 | resolution: 256 21 | in_channels: 3 22 | out_ch: 3 23 | ch: 128 24 | ch_mult: 25 | - 1 26 | - 1 27 | - 2 28 | - 4 29 | num_res_blocks: 2 30 | attn_resolutions: 31 | - 32 32 | dropout: 0.0 33 | lossconfig: 34 | target: taming.modules.losses.vqperceptual.DummyLoss 35 | -------------------------------------------------------------------------------- /rudalle/ruclip/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | 4 | from transformers import CLIPModel 5 | from huggingface_hub import hf_hub_url, cached_download 6 | 7 | from .processor import RuCLIPProcessor 8 | 9 | MODELS = { 10 | 'ruclip-vit-base-patch32-v5': dict( 11 | repo_id='sberbank-ai/ru-clip', 12 | filenames=[ 13 | 'bpe.model', 'config.json', 'pytorch_model.bin' 14 | ] 15 | ), 16 | } 17 | 18 | 19 | def get_ruclip(name, cache_dir='/tmp/rudalle'): 20 | assert name in MODELS 21 | config = MODELS[name] 22 | repo_id = config['repo_id'] 23 | cache_dir = os.path.join(cache_dir, name) 24 | for filename in config['filenames']: 25 | config_file_url = hf_hub_url(repo_id=repo_id, filename=f'{name}/{filename}') 26 | cached_download(config_file_url, cache_dir=cache_dir, force_filename=filename) 27 | ruclip = CLIPModel.from_pretrained(cache_dir) 28 | ruclip_processor = RuCLIPProcessor.from_pretrained(cache_dir) 29 | print('ruclip --> ready') 30 | return ruclip, ruclip_processor 31 | -------------------------------------------------------------------------------- /rudalle/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import random 4 | 5 | import torch 6 | import torchvision 7 | import numpy as np 8 | 9 | 10 | def seed_everything(seed): 11 | random.seed(seed) 12 | os.environ['PYTHONHASHSEED'] = str(seed) 13 
| np.random.seed(seed) 14 | torch.manual_seed(seed) 15 | torch.cuda.manual_seed(seed) 16 | torch.backends.cudnn.deterministic = True 17 | torch.backends.cudnn.benchmark = True 18 | 19 | 20 | def torch_tensors_to_pil_list(input_images): 21 | out_images = [] 22 | for in_image in input_images: 23 | in_image = in_image.cpu().detach() 24 | out_image = torchvision.transforms.functional.to_pil_image(in_image).convert('RGB') 25 | out_images.append(out_image) 26 | return out_images 27 | 28 | 29 | def pil_list_to_torch_tensors(pil_images): 30 | result = [] 31 | for pil_image in pil_images: 32 | image = np.array(pil_image, dtype=np.uint8) 33 | image = torch.from_numpy(image) 34 | image = image.permute(2, 0, 1).unsqueeze(0) 35 | result.append(image) 36 | return torch.cat(result, dim=0) 37 | -------------------------------------------------------------------------------- /rudalle/realesrgan/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | 4 | from huggingface_hub import hf_hub_url, cached_download 5 | 6 | from .model import RealESRGAN 7 | 8 | 9 | MODELS = { 10 | 'x2': dict( 11 | scale=2, 12 | repo_id='shonenkov/rudalle-utils', 13 | filename='RealESRGAN_x2.pth', 14 | ), 15 | 'x4': dict( 16 | scale=4, 17 | repo_id='shonenkov/rudalle-utils', 18 | filename='RealESRGAN_x4.pth', 19 | ), 20 | 'x8': dict( 21 | scale=8, 22 | repo_id='shonenkov/rudalle-utils', 23 | filename='RealESRGAN_x8.pth', 24 | ), 25 | } 26 | 27 | 28 | def get_realesrgan(name, device='cpu', fp16=False, cache_dir='/tmp/rudalle'): 29 | assert name in MODELS 30 | config = MODELS[name] 31 | model = RealESRGAN(device, config['scale'], fp16=fp16) 32 | cache_dir = os.path.join(cache_dir, name) 33 | config_file_url = hf_hub_url(repo_id=config['repo_id'], filename=config['filename']) 34 | cached_download(config_file_url, cache_dir=cache_dir, force_filename=config['filename']) 35 | model.load_weights(os.path.join(cache_dir, 
config['filename'])) 36 | print(f'{name} --> ready') 37 | return model 38 | -------------------------------------------------------------------------------- /rudalle/vae/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from os.path import dirname, abspath, join 3 | 4 | import torch 5 | from huggingface_hub import hf_hub_url, cached_download 6 | from omegaconf import OmegaConf 7 | 8 | from .model import VQGanGumbelVAE 9 | 10 | 11 | def get_vae(pretrained=True, dwt=False, cache_dir='/tmp/rudalle'): 12 | # TODO 13 | config = OmegaConf.load(join(dirname(abspath(__file__)), 'vqgan.gumbelf8-sber.config.yml')) 14 | vae = VQGanGumbelVAE(config, dwt=dwt) 15 | if pretrained: 16 | repo_id = 'shonenkov/rudalle-utils' 17 | if dwt: 18 | filename = 'vqgan.gumbelf8-sber-dwt.model.ckpt' 19 | else: 20 | filename = 'vqgan.gumbelf8-sber.model.ckpt' 21 | cache_dir = join(cache_dir, 'vae') 22 | config_file_url = hf_hub_url(repo_id=repo_id, filename=filename) 23 | cached_download(config_file_url, cache_dir=cache_dir, force_filename=filename) 24 | checkpoint = torch.load(join(cache_dir, filename), map_location='cpu') 25 | if dwt: 26 | vae.load_state_dict(checkpoint['state_dict']) 27 | else: 28 | vae.model.load_state_dict(checkpoint['state_dict'], strict=False) 29 | print('vae --> ready') 30 | return vae 31 | -------------------------------------------------------------------------------- /tests/test_dalle.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import pytest 4 | 5 | from .test_vae import preprocess 6 | 7 | 8 | @pytest.mark.parametrize('text', [ 9 | 'мальчик играет с оленем', 10 | ]) 11 | def test_forward_step_and_criterion(text, sample_image, yttm_tokenizer, vae, small_dalle): 12 | bs = 4 13 | text_seq_length = small_dalle.get_param('text_seq_length') 14 | total_seq_length = small_dalle.get_param('total_seq_length') 15 | 
device = small_dalle.get_param('device') 16 | 17 | img = sample_image.copy() 18 | img = preprocess(img, target_image_size=256) 19 | images = img.repeat(bs, 1, 1, 1).to(device) 20 | 21 | text = text.lower().strip() 22 | text_input_ids = yttm_tokenizer.encode_text(text, text_seq_length=text_seq_length) 23 | text_input_ids = text_input_ids.unsqueeze(0).repeat(bs, 1).to(device) 24 | 25 | attention_mask = torch.tril(torch.ones((bs, 1, total_seq_length, total_seq_length), device=device)) 26 | with torch.no_grad(): 27 | image_input_ids = vae.get_codebook_indices(images) 28 | input_ids = torch.cat((text_input_ids, image_input_ids), dim=1) 29 | loss, loss_values = small_dalle.forward(input_ids, attention_mask, return_loss=True) 30 | assert type(loss.data.detach().item()) == float 31 | assert type(loss_values) == dict 32 | -------------------------------------------------------------------------------- /pics/habr_eng.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | -------------------------------------------------------------------------------- /rudalle/emojich_unet/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | 4 | import torch 5 | from huggingface_hub import hf_hub_url, cached_download 6 | 7 | 8 | MODELS = { 9 | 'unet_effnetb5': dict( 10 | encoder_name='efficientnet-b5', 11 | repo_id='sberbank-ai/rudalle-Emojich', 12 | filename='pytorch_model_v2.bin', 13 | classes=2, 14 | ), 15 | } 16 | 17 | 18 | def get_emojich_unet(name, cache_dir='/tmp/rudalle'): 19 | assert name in MODELS 20 | config = MODELS[name] 21 | try: 22 | import segmentation_models_pytorch as smp 23 | except ImportError: 24 | import logging 25 | logging.warning('If you would like to use emojich_unet, you should reinstall timm package:' 26 | '"pip install timm==0.4.12"') 27 | return 28 | model = smp.Unet( 29 | encoder_name=config['encoder_name'], 30 | 
encoder_weights=None, 31 | in_channels=3, 32 | classes=config['classes'], 33 | ) 34 | cache_dir = os.path.join(cache_dir, name) 35 | filename = config['filename'] 36 | config_file_url = hf_hub_url(repo_id=config['repo_id'], filename=f'{name}/{filename}') 37 | cached_download(config_file_url, cache_dir=cache_dir, force_filename=filename) 38 | checkpoint = torch.load(os.path.join(cache_dir, config['filename']), map_location='cpu') 39 | model.load_state_dict(checkpoint) 40 | print(f'{name} --> ready') 41 | return model 42 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import io 3 | from os.path import abspath, dirname 4 | 5 | import PIL 6 | import pytest 7 | import requests 8 | 9 | from rudalle import get_tokenizer, get_rudalle_model, get_vae, get_realesrgan, get_emojich_unet 10 | 11 | 12 | TEST_ROOT = dirname(abspath(__file__)) 13 | 14 | 15 | @pytest.fixture(scope='module') 16 | def realesrgan(): 17 | realesrgan = get_realesrgan('x2', device='cpu') 18 | yield realesrgan 19 | 20 | 21 | @pytest.fixture(scope='module') 22 | def vae(): 23 | vae = get_vae(pretrained=False) 24 | yield vae 25 | 26 | 27 | @pytest.fixture(scope='module') 28 | def dwt_vae(): 29 | vae = get_vae(pretrained=False, dwt=True) 30 | yield vae 31 | 32 | 33 | @pytest.fixture(scope='module') 34 | def yttm_tokenizer(): 35 | tokenizer = get_tokenizer() 36 | yield tokenizer 37 | 38 | 39 | @pytest.fixture(scope='module') 40 | def sample_image(): 41 | url = 'https://cdn.kqed.org/wp-content/uploads/sites/12/2013/12/rudolph.png' 42 | resp = requests.get(url) 43 | resp.raise_for_status() 44 | image = PIL.Image.open(io.BytesIO(resp.content)) 45 | yield image 46 | 47 | 48 | @pytest.fixture(scope='module') 49 | def small_dalle(): 50 | model = get_rudalle_model('small', pretrained=False, fp16=False, device='cpu') 51 | yield model 52 | 53 | 54 | 
@pytest.fixture(scope='module') 55 | def emojich_unet(): 56 | model = get_emojich_unet('unet_effnetb5') 57 | yield model 58 | -------------------------------------------------------------------------------- /pics/habr.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | -------------------------------------------------------------------------------- /rudalle/dalle/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | 4 | 5 | def exists(val): 6 | return val is not None 7 | 8 | 9 | def is_empty(t): 10 | return t.nelement() == 0 11 | 12 | 13 | def ensure_divisibility(numerator, denominator): 14 | """Ensure that numerator is divisible by the denominator.""" 15 | assert numerator % denominator == 0, '{} is not divisible by {}'.format( 16 | numerator, denominator) 17 | 18 | 19 | def divide(numerator, denominator): 20 | """Ensure that numerator is divisible by the denominator and return 21 | the division value.""" 22 | ensure_divisibility(numerator, denominator) 23 | return numerator // denominator 24 | 25 | 26 | def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False): 27 | """ 28 | Split a tensor along its last dimension. 29 | Arguments: 30 | tensor: input tensor. 31 | num_partitions: number of partitions to split the tensor 32 | contiguous_split_chunks: If True, make each chunk contiguous 33 | in memory. 34 | """ 35 | # Get the size and dimension. 36 | last_dim = tensor.dim() - 1 37 | last_dim_size = divide(tensor.size()[last_dim], num_partitions) 38 | # Split. 39 | tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) 40 | # Note: torch.split does not create contiguous tensors by default. 
41 | if contiguous_split_chunks: 42 | return tuple(chunk.contiguous() for chunk in tensor_list) 43 | return tensor_list 44 | 45 | 46 | def init_method_normal(std=0.02): 47 | """Init method based on normal distribution. 48 | 49 | This is only used for embeddings. The transformer has its 50 | own initializer. 51 | """ 52 | def init_(tensor): 53 | return torch.nn.init.normal_(tensor, mean=0.0, std=std) 54 | return init_ 55 | -------------------------------------------------------------------------------- /rudalle/dalle/image_attention.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | 5 | 6 | def _init_mask(text_tokens, image_tokens_per_dim, is_bool_mask=False): 7 | attn_size = text_tokens + image_tokens_per_dim**2 8 | mask = torch.tril(torch.ones(attn_size, attn_size, dtype=torch.bool if is_bool_mask else torch.float32)) 9 | return mask 10 | 11 | 12 | def get_row_mask(text_tokens=256, image_tokens_per_dim=32, is_bool_mask=False): 13 | mask = _init_mask(text_tokens, image_tokens_per_dim, is_bool_mask=is_bool_mask) 14 | step = image_tokens_per_dim + 1 15 | for col in range(text_tokens, mask.size(1)): 16 | mask[col + step:, col] = False if is_bool_mask else 0.0 17 | return mask 18 | 19 | 20 | def get_col_mask(text_tokens=256, image_tokens_per_dim=32, is_bool_mask=False): 21 | mask = _init_mask(text_tokens, image_tokens_per_dim, is_bool_mask=is_bool_mask) 22 | step = image_tokens_per_dim - 1 23 | for col in range(text_tokens, mask.size(1)): 24 | for i in range(1, mask.size(0), step+1): 25 | mask[col + i: col + i + step, col] = False if is_bool_mask else 0.0 26 | return mask 27 | 28 | 29 | def get_conv_mask(text_tokens=256, image_tokens_per_dim=32, kernel=11, is_bool_mask=False): 30 | mask = _init_mask(text_tokens, image_tokens_per_dim, is_bool_mask=is_bool_mask) 31 | shift = kernel // 2 32 | for pos in range(text_tokens, mask.size(1)): 33 | mask[pos+1:, pos] = False if is_bool_mask 
else 0.0 34 | img = torch.zeros(image_tokens_per_dim, image_tokens_per_dim) 35 | pixel_id = pos - text_tokens 36 | row = pixel_id // image_tokens_per_dim 37 | col = pixel_id % image_tokens_per_dim 38 | for r in range(-shift, shift+1): 39 | for c in range(-shift, shift+1): 40 | c_abs = (c + col) % image_tokens_per_dim 41 | r_abs = (r + row) % image_tokens_per_dim 42 | img[r_abs, c_abs] = 0.2 43 | cell_id = r_abs * image_tokens_per_dim + c_abs 44 | if text_tokens + cell_id > pos: 45 | mask[text_tokens + cell_id, pos] = True if is_bool_mask else 1.0 46 | 47 | img[row, col] = 1.0 48 | return mask 49 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import re 4 | from setuptools import setup 5 | 6 | 7 | def read(filename): 8 | with open(os.path.join(os.path.dirname(__file__), filename)) as f: 9 | file_content = f.read() 10 | return file_content 11 | 12 | 13 | def get_requirements(): 14 | requirements = [] 15 | for requirement in read('requirements.txt').splitlines(): 16 | if requirement.startswith('git+') or requirement.startswith('svn+') or requirement.startswith('hg+'): 17 | parsed_requires = re.findall(r'#egg=([\w\d\.]+)-([\d\.]+)$', requirement) 18 | if parsed_requires: 19 | package, version = parsed_requires[0] 20 | requirements.append(f'{package}=={version}') 21 | else: 22 | print('WARNING! For correct matching dependency links need to specify package name and version' 23 | 'such as #egg=-') 24 | else: 25 | requirements.append(requirement) 26 | return requirements 27 | 28 | 29 | def get_links(): 30 | return [ 31 | requirement for requirement in read('requirements.txt').splitlines() 32 | if requirement.startswith('git+') or requirement.startswith('svn+') or requirement.startswith('hg+') 33 | ] 34 | 35 | 36 | def get_version(): 37 | """ Get version from the package without actually importing it. 
""" 38 | init = read('rudalle/__init__.py') 39 | for line in init.split('\n'): 40 | if line.startswith('__version__'): 41 | return eval(line.split('=')[1]) 42 | 43 | 44 | setup( 45 | name='rudalle', 46 | version=get_version(), 47 | author='SberAI, SberDevices', 48 | author_email='shonenkov@phystech.edu', 49 | description='ruDALL-E generate images from texts in Russian language', 50 | packages=['rudalle', 'rudalle/dalle', 'rudalle/realesrgan', 'rudalle/ruclip', 'rudalle/vae', 51 | 'rudalle/emojich_unet'], 52 | package_data={'rudalle/vae': ['*.yml']}, 53 | install_requires=get_requirements(), 54 | dependency_links=get_links(), 55 | long_description=read('README.md'), 56 | long_description_content_type='text/markdown', 57 | ) 58 | -------------------------------------------------------------------------------- /tests/test_vae.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import PIL 3 | import pytest 4 | import torch 5 | import torchvision.transforms as T 6 | import torchvision.transforms.functional as TF 7 | 8 | 9 | @pytest.mark.parametrize('target_image_size', [128, 192, 256]) 10 | def test_decode_vae(vae, sample_image, target_image_size): 11 | img = sample_image.copy() 12 | img = preprocess(img, target_image_size=target_image_size) 13 | with torch.no_grad(): 14 | img_seq = vae.get_codebook_indices(img) 15 | out_img = vae.decode(img_seq) 16 | assert out_img.shape == (1, 3, target_image_size, target_image_size) 17 | 18 | 19 | @pytest.mark.parametrize('target_image_size', [128, 192, 256]) 20 | def test_reconstruct_vae(vae, sample_image, target_image_size): 21 | img = sample_image.copy() 22 | with torch.no_grad(): 23 | x_vqgan = preprocess(img, target_image_size=target_image_size) 24 | output = reconstruct_with_vqgan(preprocess_vqgan(x_vqgan), vae.model) 25 | assert output.shape == (1, 3, target_image_size, target_image_size) 26 | 27 | 28 | @pytest.mark.parametrize('target_image_size', [256]) 29 | 
def test_reconstruct_dwt_vae(dwt_vae, sample_image, target_image_size): 30 | img = sample_image.copy() 31 | with torch.no_grad(): 32 | x_vqgan = preprocess(img, target_image_size=target_image_size) 33 | output = reconstruct_with_vqgan(preprocess_vqgan(x_vqgan), dwt_vae.model) 34 | assert output.shape == (1, 3, target_image_size*2, target_image_size*2) 35 | 36 | 37 | def preprocess(img, target_image_size=256): 38 | s = min(img.size) 39 | if s < target_image_size: 40 | raise ValueError(f'min dim for image {s} < {target_image_size}') 41 | r = target_image_size / s 42 | s = (round(r * img.size[1]), round(r * img.size[0])) 43 | img = TF.resize(img, s, interpolation=PIL.Image.LANCZOS) 44 | img = TF.center_crop(img, output_size=2 * [target_image_size]) 45 | img = torch.unsqueeze(T.ToTensor()(img), 0) 46 | return img 47 | 48 | 49 | def preprocess_vqgan(x): 50 | x = 2.*x - 1. 51 | return x 52 | 53 | 54 | def reconstruct_with_vqgan(x, model): 55 | z, _, [_, _, _] = model.encode(x) 56 | print(f'VQGAN --- {model.__class__.__name__}: latent shape: {z.shape[2:]}') 57 | xrec = model.decode(z) 58 | return xrec 59 | -------------------------------------------------------------------------------- /rudalle/tokenizer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from os.path import join 3 | 4 | import torch 5 | import numpy as np 6 | import youtokentome as yttm 7 | from huggingface_hub import hf_hub_url, cached_download 8 | 9 | 10 | def get_tokenizer(path=None, cache_dir='/tmp/rudalle'): 11 | # TODO docstring 12 | if path is None: 13 | repo_id = 'shonenkov/rudalle-utils' 14 | filename = 'bpe.model' 15 | cache_dir = join(cache_dir, 'tokenizer') 16 | config_file_url = hf_hub_url(repo_id=repo_id, filename=filename) 17 | cached_download(config_file_url, cache_dir=cache_dir, force_filename=filename) 18 | path = join(cache_dir, filename) 19 | tokenizer = YTTMTokenizerWrapper(yttm.BPE(model=path)) 20 | print('tokenizer 
--> ready') 21 | return tokenizer 22 | 23 | 24 | class YTTMTokenizerWrapper: 25 | eos_id = 3 26 | bos_id = 2 27 | unk_id = 1 28 | pad_id = 0 29 | 30 | def __init__(self, tokenizer): 31 | self.tokenizer = tokenizer 32 | 33 | def __len__(self): 34 | return self.vocab_size() 35 | 36 | def get_pad_token_id(self): 37 | # TODO docstring 38 | return self.tokenizer.subword_to_id('') 39 | 40 | def vocab_size(self): 41 | # TODO docstring 42 | return self.tokenizer.vocab_size() 43 | 44 | def encode_text(self, text, text_seq_length, bpe_dropout=0.0): 45 | # TODO docstring 46 | tokens = self.tokenizer.encode([text], output_type=yttm.OutputType.ID, dropout_prob=bpe_dropout)[0] 47 | tokens = [self.bos_id] + tokens + [self.eos_id] 48 | return self.prepare_tokens(tokens, text_seq_length) 49 | 50 | def decode_text(self, encoded): 51 | # TODO docstring 52 | return self.tokenizer.decode(encoded.cpu().numpy().tolist(), ignore_ids=[ 53 | self.eos_id, self.bos_id, self.unk_id, self.pad_id 54 | ])[0] 55 | 56 | @staticmethod 57 | def prepare_tokens(tokens, text_seq_length): 58 | # TODO docstring 59 | empty_positions = text_seq_length - len(tokens) 60 | if empty_positions > 0: 61 | tokens = np.hstack((tokens, np.zeros(empty_positions))) # position tokens after text 62 | if len(tokens) > text_seq_length: 63 | tokens = tokens[:text_seq_length] 64 | return torch.tensor(tokens).long() 65 | -------------------------------------------------------------------------------- /rudalle/dalle/fp16.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | from torch import nn 4 | from torch.autograd import Variable 5 | from torch.nn.parameter import Parameter 6 | 7 | FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor) 8 | HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor) 9 | 10 | 11 | def conversion_helper(val, conversion): 12 | """Apply conversion to val. 
Recursively apply conversion if `val` is a nested tuple/list structure.""" 13 | if not isinstance(val, (tuple, list)): 14 | return conversion(val) 15 | rtn = [conversion_helper(v, conversion) for v in val] 16 | if isinstance(val, tuple): 17 | rtn = tuple(rtn) 18 | return rtn 19 | 20 | 21 | def fp32_to_fp16(val): 22 | """Convert fp32 `val` to fp16""" 23 | def half_conversion(val): 24 | val_typecheck = val 25 | if isinstance(val_typecheck, (Parameter, Variable)): 26 | val_typecheck = val.data 27 | if isinstance(val_typecheck, FLOAT_TYPES): 28 | val = val.half() 29 | return val 30 | return conversion_helper(val, half_conversion) 31 | 32 | 33 | def fp16_to_fp32(val): 34 | """Convert fp16 `val` to fp32""" 35 | def float_conversion(val): 36 | val_typecheck = val 37 | if isinstance(val_typecheck, (Parameter, Variable)): 38 | val_typecheck = val.data 39 | if isinstance(val_typecheck, HALF_TYPES): 40 | val = val.float() 41 | return val 42 | return conversion_helper(val, float_conversion) 43 | 44 | 45 | class FP16Module(nn.Module): 46 | def __init__(self, module): 47 | super(FP16Module, self).__init__() 48 | self.add_module('module', module.half()) 49 | 50 | def forward(self, *inputs, **kwargs): 51 | return fp16_to_fp32(self.module(*(fp32_to_fp16(inputs)), **kwargs)) 52 | 53 | def state_dict(self, destination=None, prefix='', keep_vars=False): 54 | return self.module.state_dict(destination, prefix, keep_vars) 55 | 56 | def load_state_dict(self, state_dict, strict=True): 57 | self.module.load_state_dict(state_dict, strict=strict) 58 | 59 | def get_param(self, item): 60 | return self.module.get_param(item) 61 | 62 | def to(self, device, *args, **kwargs): 63 | self.module.to(device) 64 | return super().to(device, *args, **kwargs) 65 | -------------------------------------------------------------------------------- /rudalle/realesrgan/model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Source: 
https://github.com/boomb0om/Real-ESRGAN-colab 3 | 4 | import torch 5 | import numpy as np 6 | from PIL import Image 7 | 8 | from .rrdbnet_arch import RRDBNet 9 | from .utils import pad_reflect, split_image_into_overlapping_patches, stich_together, unpad_image 10 | from rudalle.dalle.fp16 import FP16Module 11 | 12 | 13 | class RealESRGAN: 14 | def __init__(self, device, scale=4, fp16=False): 15 | self.device = device 16 | self.scale = scale 17 | self.model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=scale) 18 | self.fp16 = fp16 19 | 20 | def load_weights(self, model_path): 21 | loadnet = torch.load(model_path, map_location='cpu')  # load on CPU; the model is moved to self.device below 22 | if 'params' in loadnet: 23 | self.model.load_state_dict(loadnet['params'], strict=True) 24 | elif 'params_ema' in loadnet: 25 | self.model.load_state_dict(loadnet['params_ema'], strict=True) 26 | else: 27 | self.model.load_state_dict(loadnet, strict=True) 28 | self.model.eval() 29 | if self.fp16: 30 | self.model = FP16Module(self.model) 31 | self.model.to(self.device) 32 | 33 | def predict(self, lr_image, batch_size=4, patches_size=192, 34 | padding=24, pad_size=15): 35 | scale = self.scale 36 | device = self.device 37 | lr_image = np.array(lr_image) 38 | lr_image = pad_reflect(lr_image, pad_size) 39 | 40 | patches, p_shape = split_image_into_overlapping_patches(lr_image, patch_size=patches_size, 41 | padding_size=padding) 42 | if self.fp16: 43 | img = torch.HalfTensor(patches / 255).permute((0, 3, 1, 2)).to(device).detach() 44 | else: 45 | img = torch.FloatTensor(patches / 255).permute((0, 3, 1, 2)).to(device).detach() 46 | 47 | with torch.no_grad(): 48 | res = torch.cat([self.model(img[i:i + batch_size])  # run patches in batches, concatenate once 49 | for i in range(0, img.shape[0], batch_size)], 0) 50 | 51 | 52 | sr_image = res.permute((0, 2, 3, 1)).cpu().clamp_(0, 1) 53 | np_sr_image = sr_image.numpy() 54 | 55 | padded_size_scaled = tuple(np.multiply(p_shape[0:2], scale)) + (3,) 56 | scaled_image_shape
= tuple(np.multiply(lr_image.shape[0:2], scale)) + (3,) 57 | np_sr_image = stich_together(np_sr_image, padded_image_shape=padded_size_scaled, 58 | target_shape=scaled_image_shape, padding_size=padding * scale) 59 | sr_img = (np_sr_image * 255).astype(np.uint8) 60 | sr_img = unpad_image(sr_img, pad_size * scale) 61 | sr_img = Image.fromarray(sr_img) 62 | 63 | return sr_img 64 | -------------------------------------------------------------------------------- /rudalle/ruclip/processor.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import json 4 | import torch 5 | import youtokentome as yttm 6 | import torchvision.transforms as T 7 | from torch.nn.utils.rnn import pad_sequence 8 | 9 | 10 | class RuCLIPProcessor: 11 | eos_id = 3 12 | bos_id = 2 13 | unk_id = 1 14 | pad_id = 0 15 | 16 | def __init__(self, tokenizer_path, image_size=224, text_seq_length=76, mean=None, std=None): 17 | 18 | self.tokenizer = yttm.BPE(tokenizer_path) 19 | self.mean = mean or [0.485, 0.456, 0.406] 20 | self.std = std or [0.229, 0.224, 0.225] 21 | self.image_transform = T.Compose([ 22 | T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), 23 | T.RandomResizedCrop(image_size, scale=(1., 1.), ratio=(1., 1.)), 24 | T.ToTensor(), 25 | T.Normalize(mean=self.mean, std=self.std) 26 | ]) 27 | self.text_seq_length = text_seq_length 28 | self.image_size = image_size 29 | 30 | def encode_text(self, text): 31 | text = text.lower() 32 | tokens = self.tokenizer.encode([text], output_type=yttm.OutputType.ID, dropout_prob=0.0)[0] 33 | tokens = [self.bos_id] + tokens + [self.eos_id] 34 | tokens = tokens[:self.text_seq_length] 35 | mask = [1] * len(tokens) 36 | return torch.tensor(tokens).long(), torch.tensor(mask).long() 37 | 38 | def decode_text(self, encoded): 39 | return self.tokenizer.decode(encoded.cpu().numpy().tolist(), ignore_ids=[ 40 | self.eos_id, self.bos_id, self.unk_id, self.pad_id 41 | ])[0] 42 | 
43 | def __call__(self, text=None, images=None, **kwargs): 44 | inputs = {} 45 | if text is not None: 46 | input_ids, masks = [], [] 47 | texts = [text] if isinstance(text, str) else text 48 | for sample in texts: 49 | tokens, mask = self.encode_text(sample) 50 | input_ids.append(tokens) 51 | masks.append(mask) 52 | inputs['input_ids'] = pad_sequence(input_ids, batch_first=True) 53 | inputs['attention_mask'] = pad_sequence(masks, batch_first=True) 54 | if images is not None: 55 | pixel_values = [] 56 | for image in images: 57 | pixel_values.append(self.image_transform(image)) 58 | inputs['pixel_values'] = pad_sequence(pixel_values, batch_first=True) 59 | return inputs 60 | 61 | @classmethod 62 | def from_pretrained(cls, folder): 63 | tokenizer_path = os.path.join(folder, 'bpe.model') 64 | with open(os.path.join(folder, 'config.json')) as f: config = json.load(f)  # close the file handle deterministically 65 | image_size = config['vision_config']['image_size'] 66 | text_seq_length = config['text_config']['max_position_embeddings'] - 1 67 | mean, std = config.get('mean'), config.get('std') 68 | return cls(tokenizer_path, image_size=image_size, text_seq_length=text_seq_length, mean=mean, std=std) 69 | -------------------------------------------------------------------------------- /rudalle/image_prompts.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import numpy as np 4 | 5 | 6 | class ImagePrompts: 7 | 8 | def __init__(self, pil_image, borders, vae, device='cpu', crop_first=False): 9 | """ 10 | Args: 11 | pil_image (PIL.Image): image in PIL format 12 | borders (dict[str] | int): borders that we cropped from pil_image 13 | example: {'up': 4, 'right': 0, 'left': 0, 'down': 0} (1 int eq 8 pixels) 14 | vae (VQGanGumbelVAE): VQGAN model for image encoding 15 | device (str): cpu or cuda 16 | crop_first (bool): if True, crop the image before VQGAN encoding 17 | """ 18 | self.device = device 19 | img = self._preprocess_img(pil_image) 20 |
| self.image_prompts_idx, self.image_prompts = self._get_image_prompts(img, borders, vae, crop_first) 21 | 22 | def _preprocess_img(self, pil_img): 23 | img = torch.tensor(np.array(pil_img.convert('RGB')).transpose(2, 0, 1)) / 255. 24 | img = img.unsqueeze(0).to(self.device, dtype=torch.float32) 25 | img = (2 * img) - 1 26 | return img 27 | 28 | def _get_image_prompts(self, img, borders, vae, crop_first): 29 | if crop_first: 30 | bs, _, img_w, img_h = img.shape 31 | vqg_img_w, vqg_img_h = img_w // 8, img_h // 8 32 | vqg_img = torch.zeros((bs, vqg_img_w, vqg_img_h), dtype=torch.int32, device=img.device) 33 | if borders['down'] != 0: 34 | down_border = borders['down'] * 8 35 | _, _, [_, _, down_vqg_img] = vae.model.encode(img[:, :, -down_border:, :]) 36 | vqg_img[:, -borders['down']:, :] = down_vqg_img 37 | if borders['right'] != 0: 38 | right_border = borders['right'] * 8 39 | _, _, [_, _, right_vqg_img] = vae.model.encode(img[:, :, :, -right_border:]) 40 | vqg_img[:, :, -borders['right']:] = right_vqg_img 41 | if borders['left'] != 0: 42 | left_border = borders['left'] * 8 43 | _, _, [_, _, left_vqg_img] = vae.model.encode(img[:, :, :, :left_border]) 44 | vqg_img[:, :, :borders['left']] = left_vqg_img 45 | if borders['up'] != 0: 46 | up_border = borders['up'] * 8 47 | _, _, [_, _, up_vqg_img] = vae.model.encode(img[:, :, :up_border, :]) 48 | vqg_img[:, :borders['up'], :] = up_vqg_img 49 | else: 50 | _, _, [_, _, vqg_img] = vae.model.encode(img) 51 | 52 | bs, vqg_img_w, vqg_img_h = vqg_img.shape 53 | mask = torch.zeros(vqg_img_w, vqg_img_h) 54 | if borders['up'] != 0: 55 | mask[:borders['up'], :] = 1. 56 | if borders['down'] != 0: 57 | mask[-borders['down']:, :] = 1. 58 | if borders['right'] != 0: 59 | mask[:, -borders['right']:] = 1. 60 | if borders['left'] != 0: 61 | mask[:, :borders['left']] = 1. 
62 | mask = mask.reshape(-1).bool() 63 | 64 | image_prompts = vqg_img.reshape((bs, -1)) 65 | image_prompts_idx = np.arange(vqg_img_w * vqg_img_h) 66 | image_prompts_idx = set(image_prompts_idx[mask]) 67 | 68 | return image_prompts_idx, image_prompts 69 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### JetBrains template 3 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm 4 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 5 | 6 | settings/local.py 7 | logs/*.log 8 | 9 | # User-specific stuff: 10 | .idea/ 11 | 12 | # Sensitive or high-churn files: 13 | .idea/**/dataSources/ 14 | .idea/**/dataSources.ids 15 | .idea/**/dataSources.xml 16 | .idea/**/dataSources.local.xml 17 | .idea/**/sqlDataSources.xml 18 | .idea/**/dynamic.xml 19 | .idea/**/uiDesigner.xml 20 | 21 | # Gradle: 22 | .idea/**/gradle.xml 23 | .idea/**/libraries 24 | 25 | # CMake 26 | cmake-build-debug/ 27 | 28 | # Mongo Explorer plugin: 29 | .idea/**/mongoSettings.xml 30 | 31 | ## File-based project format: 32 | *.iws 33 | 34 | ## Plugin-specific files: 35 | 36 | # IntelliJ 37 | out/ 38 | 39 | # mpeltonen/sbt-idea plugin 40 | .idea_modules/ 41 | 42 | # JIRA plugin 43 | atlassian-ide-plugin.xml 44 | 45 | # Cursive Clojure plugin 46 | .idea/replstate.xml 47 | 48 | # Crashlytics plugin (for Android Studio and IntelliJ) 49 | com_crashlytics_export_strings.xml 50 | crashlytics.properties 51 | crashlytics-build.properties 52 | fabric.properties 53 | ### Python template 54 | # Byte-compiled / optimized / DLL files 55 | __pycache__/ 56 | *.py[cod] 57 | *$py.class 58 | 59 | # C extensions 60 | *.so 61 | 62 | # Distribution / packaging 63 | .Python 64 | build/ 65 | develop-eggs/ 66 | dist/ 67 | downloads/ 68 | eggs/ 69 | .eggs/ 70 | lib/ 71 | 
lib64/ 72 | parts/ 73 | sdist/ 74 | var/ 75 | wheels/ 76 | *.egg-info/ 77 | .installed.cfg 78 | *.egg 79 | 80 | # PyInstaller 81 | # Usually these files are written by a python script from a template 82 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 83 | *.manifest 84 | *.spec 85 | 86 | # Installer logs 87 | pip-log.txt 88 | pip-delete-this-directory.txt 89 | 90 | # Unit test / coverage reports 91 | htmlcov/ 92 | .tox/ 93 | .coverage 94 | .coverage.* 95 | .cache 96 | nosetests.xml 97 | coverage.xml 98 | *.cover 99 | .hypothesis/ 100 | 101 | # Translations 102 | *.mo 103 | *.pot 104 | 105 | # Django stuff: 106 | *.log 107 | local_settings.py 108 | 109 | # Flask stuff: 110 | instance/ 111 | .webassets-cache 112 | 113 | # Scrapy stuff: 114 | .scrapy 115 | 116 | # Sphinx documentation 117 | docs/_build/ 118 | 119 | # PyBuilder 120 | target/ 121 | 122 | # Jupyter Notebook 123 | .ipynb_checkpoints 124 | 125 | # pyenv 126 | .python-version 127 | 128 | # celery beat schedule file 129 | celerybeat-schedule 130 | 131 | # SageMath parsed files 132 | *.sage.py 133 | 134 | # Environments 135 | .env 136 | .venv 137 | env/ 138 | venv/ 139 | ENV/ 140 | 141 | # Spyder project settings 142 | .spyderproject 143 | .spyproject 144 | 145 | # Rope project settings 146 | .ropeproject 147 | 148 | # mkdocs documentation 149 | /site 150 | 151 | # mypy 152 | .mypy_cache/ 153 | /tests/load_tests/logs/* 154 | /tests/.pytest_cache/ 155 | ws_test.py 156 | /.vscode/ 157 | 158 | .s3_cache/ 159 | mlruns 160 | *.pyc 161 | *.swp 162 | *.pt 163 | *.bin 164 | .vscode/ 165 | runs/ 166 | jupyters/custom_* 167 | 168 | *logs/ 169 | .DS_store 170 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ruDALL-E diffusion 2 | 3 | try it out on colab 4 | 5 | ru-dalle-diffusion colab link 6 | 7 | 8 | ruDALL-E diffusion is ruDALL-E with a diffusion 
decoder, similar to [dall-3](https://github.com/Jack000/DALLE-pytorch/) 9 | 10 | Decoding VQ embeddings with a DDPM model can produce much more realistic fine-grain details than VQVAE and VQGAN. 11 | 12 | the only code change to ruDALL-E is to return the image tokens in generate_images() - the actual diffusion model is here: https://github.com/Jack000/guided-diffusion 13 | 14 | # Samples 15 | | ruDALL-E + real-ESRGAN | diffusion | 16 | | --- | --- | 17 | | | | 18 | | | | 19 | | | | 20 | | | | 21 | 22 | note that the results depend a lot on the seed value 23 | 24 | base image: 25 | 26 | 27 | diffusion-generated samples (different seeds): 28 | |   |   |   | 29 | | --- | --- | --- | 30 | | | | | 31 | | | | | 32 | 33 | ### generation by ruDALLE: 34 | ```python 35 | from rudalle.pipelines import generate_images, show, super_resolution, cherry_pick_by_clip 36 | from rudalle import get_rudalle_model, get_tokenizer, get_vae, get_realesrgan, get_ruclip 37 | from rudalle.utils import seed_everything 38 | import numpy as np 39 | 40 | # prepare models: 41 | device = 'cuda' 42 | dalle = get_rudalle_model('Malevich', pretrained=True, fp16=True, device=device) 43 | tokenizer = get_tokenizer() 44 | vae = get_vae(dwt=False).to(device) # Make sure to set dwt to False! 
45 | 46 | text = 'изображение радуги на фоне ночного города' 47 | 48 | seed_everything(42) 49 | pil_images = [] 50 | scores = [] 51 | codes = [] 52 | 53 | for top_k, top_p, images_num in [ 54 | (2048, 0.995, 3), 55 | (1536, 0.99, 3), 56 | (1024, 0.99, 3), 57 | (1024, 0.98, 3), 58 | (512, 0.97, 3), 59 | (384, 0.96, 3), 60 | (256, 0.95, 3), 61 | (128, 0.95, 3), 62 | ]: 63 | _pil_images, _scores, _codes = generate_images(text, tokenizer, dalle, vae, top_k=top_k, images_num=images_num, top_p=top_p, return_codes=True) 64 | pil_images += _pil_images 65 | scores += _scores 66 | codes += _codes 67 | 68 | sr_images = super_resolution(pil_images, realesrgan) 69 | 70 | for i, im in enumerate(pil_images): 71 | im.save(str(i)+'.png') 72 | sr_images[i].save(str(i)+'sr.png') 73 | with open(str(i)+'.npy', 'wb') as f: 74 | np.save(f, codes[i]) 75 | 76 | # afterward, pass .npy file to diffusion model https://github.com/Jack000/guided-diffusion 77 | ``` 78 | -------------------------------------------------------------------------------- /rudalle/dalle/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | 4 | import torch 5 | from huggingface_hub import hf_hub_url, cached_download 6 | 7 | from .model import DalleModel 8 | from .fp16 import FP16Module 9 | 10 | 11 | MODELS = { 12 | 'Malevich': dict( 13 | description='◼️ Malevich is 1.3 billion params model from the family GPT3-like, ' 14 | 'that uses Russian language and text+image multi-modality.', 15 | model_params=dict( 16 | num_layers=24, 17 | hidden_size=2048, 18 | num_attention_heads=16, 19 | embedding_dropout_prob=0.1, 20 | output_dropout_prob=0.1, 21 | attention_dropout_prob=0.1, 22 | image_tokens_per_dim=32, 23 | text_seq_length=128, 24 | cogview_sandwich_layernorm=True, 25 | cogview_pb_relax=True, 26 | vocab_size=16384+128, 27 | image_vocab_size=8192, 28 | ), 29 | repo_id='sberbank-ai/rudalle-Malevich', 30 | filename='pytorch_model_v2.bin', 
31 | authors='SberAI, SberDevices', 32 | full_description='', # TODO 33 | ), 34 | 'Emojich': dict( 35 | description='😋 Emojich is a 1.3 billion params model from the family GPT3-like, ' 36 | 'it generates emoji-style images with the brain of ◾ Malevich.', 37 | model_params=dict( 38 | num_layers=24, 39 | hidden_size=2048, 40 | num_attention_heads=16, 41 | embedding_dropout_prob=0.1, 42 | output_dropout_prob=0.1, 43 | attention_dropout_prob=0.1, 44 | image_tokens_per_dim=32, 45 | text_seq_length=128, 46 | cogview_sandwich_layernorm=True, 47 | cogview_pb_relax=True, 48 | vocab_size=16384 + 128, 49 | image_vocab_size=8192, 50 | ), 51 | repo_id='sberbank-ai/rudalle-Emojich', 52 | filename='pytorch_model.bin', 53 | authors='SberAI', 54 | full_description='', # TODO 55 | ), 56 | 'small': dict( 57 | description='', 58 | model_params=dict( 59 | num_layers=12, 60 | hidden_size=768, 61 | num_attention_heads=12, 62 | embedding_dropout_prob=0.1, 63 | output_dropout_prob=0.1, 64 | attention_dropout_prob=0.1, 65 | image_tokens_per_dim=32, 66 | text_seq_length=128, 67 | cogview_sandwich_layernorm=True, 68 | cogview_pb_relax=True, 69 | vocab_size=16384+128, 70 | image_vocab_size=8192, 71 | ), 72 | repo_id='', 73 | filename='', 74 | full_description='', # TODO 75 | ), 76 | } 77 | 78 | 79 | def get_rudalle_model(name, pretrained=True, fp16=False, device='cpu', cache_dir='/tmp/rudalle', **model_kwargs): 80 | # TODO docstring 81 | assert name in MODELS 82 | 83 | if fp16 and device == 'cpu': 84 | print('Warning! Using both fp16 and cpu doesnt support. 
You can use cuda device or turn off fp16.') 85 | 86 | config = MODELS[name].copy() 87 | config['model_params'] = {**config['model_params'], **model_kwargs}  # build a new dict: .copy() is shallow, so .update() would mutate the global MODELS entry 88 | model = DalleModel(device=device, **config['model_params']) 89 | if pretrained: 90 | cache_dir = os.path.join(cache_dir, name) 91 | config_file_url = hf_hub_url(repo_id=config['repo_id'], filename=config['filename']) 92 | cached_download(config_file_url, cache_dir=cache_dir, force_filename=config['filename']) 93 | checkpoint = torch.load(os.path.join(cache_dir, config['filename']), map_location='cpu') 94 | model.load_state_dict(checkpoint) 95 | if fp16: 96 | model = FP16Module(model) 97 | model.eval() 98 | model = model.to(device) 99 | if config['description'] and pretrained: 100 | print(config['description']) 101 | return model 102 | -------------------------------------------------------------------------------- /rudalle/vae/decoder_dwt.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import pywt 3 | import torch 4 | import torch.nn as nn 5 | from taming.modules.diffusionmodules.model import Decoder 6 | 7 | from .pytorch_wavelets_utils import SFB2D, _SFB2D, prep_filt_sfb2d, mode_to_int 8 | 9 | 10 | class DecoderDWT(nn.Module): 11 | def __init__(self, ddconfig, embed_dim): 12 | super().__init__() 13 | if ddconfig.out_ch != 12: 14 | ddconfig.out_ch = 12 15 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig['z_channels'], 1) 16 | self.decoder = Decoder(**ddconfig) 17 | self.idwt = DWTInverse(mode='zero', wave='db1') 18 | 19 | def forward(self, x): 20 | # x = self.post_quant_conv(x) 21 | freq = self.decoder(x) 22 | img = self.dwt_to_img(freq) 23 | return img 24 | 25 | def dwt_to_img(self, img): 26 | b, c, h, w = img.size() 27 | low = img[:, :3, :, :] 28 | high = img[:, 3:, :, :].view(b, 3, 3, h, w) 29 | return self.idwt((low, [high])) 30 | 31 | 32 | class DWTInverse(nn.Module): 33 | """ Performs a 2d DWT Inverse reconstruction of an image 34 | 35 | Args:
36 | wave (str or pywt.Wavelet): Which wavelet to use 37 | C: deprecated, will be removed in future 38 | """ 39 | 40 | def __init__(self, wave='db1', mode='zero', trace_model=False): 41 | super().__init__() 42 | if isinstance(wave, str): 43 | wave = pywt.Wavelet(wave) 44 | if isinstance(wave, pywt.Wavelet): 45 | g0_col, g1_col = wave.rec_lo, wave.rec_hi 46 | g0_row, g1_row = g0_col, g1_col 47 | else: 48 | if len(wave) == 2: 49 | g0_col, g1_col = wave[0], wave[1] 50 | g0_row, g1_row = g0_col, g1_col 51 | elif len(wave) == 4: 52 | g0_col, g1_col = wave[0], wave[1] 53 | g0_row, g1_row = wave[2], wave[3] 54 | # Prepare the filters 55 | filts = prep_filt_sfb2d(g0_col, g1_col, g0_row, g1_row) 56 | self.register_buffer('g0_col', filts[0]) 57 | self.register_buffer('g1_col', filts[1]) 58 | self.register_buffer('g0_row', filts[2]) 59 | self.register_buffer('g1_row', filts[3]) 60 | self.mode = mode 61 | self.trace_model = trace_model 62 | 63 | def forward(self, coeffs): 64 | """ 65 | Args: 66 | coeffs (yl, yh): tuple of lowpass and bandpass coefficients, where: 67 | yl is a lowpass tensor of shape :math:`(N, C_{in}, H_{in}', 68 | W_{in}')` and yh is a list of bandpass tensors of shape 69 | :math:`list(N, C_{in}, 3, H_{in}'', W_{in}'')`. I.e. should match 70 | the format returned by DWTForward 71 | 72 | Returns: 73 | Reconstructed input of shape :math:`(N, C_{in}, H_{in}, W_{in})` 74 | 75 | Note: 76 | :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` denote the correctly 77 | downsampled shapes of the DWT pyramid. 78 | 79 | Note: 80 | Can have None for any of the highpass scales and will treat the 81 | values as zeros (not in an efficient way though). 
82 | """ 83 | yl, yh = coeffs 84 | ll = yl 85 | mode = mode_to_int(self.mode) 86 | 87 | # Do a multilevel inverse transform 88 | for h in yh[::-1]: 89 | if h is None: 90 | h = torch.zeros(ll.shape[0], ll.shape[1], 3, ll.shape[-2], 91 | ll.shape[-1], device=ll.device) 92 | 93 | # 'Unpad' added dimensions 94 | if ll.shape[-2] > h.shape[-2]: 95 | ll = ll[..., :-1, :] 96 | if ll.shape[-1] > h.shape[-1]: 97 | ll = ll[..., :-1] 98 | if not self.trace_model: 99 | ll = SFB2D.apply(ll, h, self.g0_col, self.g1_col, self.g0_row, self.g1_row, mode) 100 | else: 101 | ll = _SFB2D(ll, h, self.g0_col, self.g1_col, self.g0_row, self.g1_row, mode) 102 | return ll 103 | -------------------------------------------------------------------------------- /rudalle/vae/model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from math import sqrt, log 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch import einsum 8 | from einops import rearrange 9 | from taming.modules.diffusionmodules.model import Encoder, Decoder 10 | 11 | from .decoder_dwt import DecoderDWT 12 | 13 | 14 | class VQGanGumbelVAE(torch.nn.Module): 15 | 16 | def __init__(self, config, dwt=False): 17 | super().__init__() 18 | model = GumbelVQ( 19 | ddconfig=config.model.params.ddconfig, 20 | n_embed=config.model.params.n_embed, 21 | embed_dim=config.model.params.embed_dim, 22 | kl_weight=config.model.params.kl_weight, 23 | dwt=dwt, 24 | ) 25 | self.model = model 26 | self.num_layers = int(log(config.model.params.ddconfig.attn_resolutions[0]) / log(2)) 27 | self.image_size = 256 28 | self.num_tokens = config.model.params.n_embed 29 | 30 | @torch.no_grad() 31 | def get_codebook_indices(self, img): 32 | img = (2 * img) - 1 33 | _, _, [_, _, indices] = self.model.encode(img) 34 | return rearrange(indices, 'b h w -> b (h w)') 35 | 36 | def decode(self, img_seq): 37 | b, n = img_seq.shape 38 | one_hot_indices = 
torch.nn.functional.one_hot(img_seq, num_classes=self.num_tokens).float() 39 | z = (one_hot_indices @ self.model.quantize.embed.weight) 40 | z = rearrange(z, 'b (h w) c -> b c h w', h=int(sqrt(n))) 41 | img = self.model.decode(z) 42 | img = (img.clamp(-1., 1.) + 1) * 0.5 43 | return img 44 | 45 | 46 | class GumbelQuantize(nn.Module): 47 | """ 48 | credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!) 49 | Gumbel Softmax trick quantizer 50 | Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016 51 | https://arxiv.org/abs/1611.01144 52 | """ 53 | 54 | def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True, 55 | kl_weight=5e-4, temp_init=1.0, use_vqinterface=True): 56 | super().__init__() 57 | self.embedding_dim = embedding_dim 58 | self.n_embed = n_embed 59 | self.straight_through = straight_through 60 | self.temperature = temp_init 61 | self.kl_weight = kl_weight 62 | self.proj = nn.Conv2d(num_hiddens, n_embed, 1) 63 | self.embed = nn.Embedding(self.n_embed, self.embedding_dim) 64 | self.use_vqinterface = use_vqinterface 65 | 66 | def forward(self, z, temp=None, return_logits=False): 67 | hard = self.straight_through if self.training else True 68 | temp = self.temperature if temp is None else temp 69 | logits = self.proj(z) 70 | soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard) 71 | z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight) 72 | # + kl divergence to the prior loss 73 | qy = F.softmax(logits, dim=1) 74 | diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean() 75 | ind = soft_one_hot.argmax(dim=1) 76 | if self.use_vqinterface: 77 | if return_logits: 78 | return z_q, diff, (None, None, ind), logits 79 | return z_q, diff, (None, None, ind) 80 | return z_q, diff, ind 81 | 82 | 83 | class GumbelVQ(nn.Module): 84 | 85 | def __init__(self, ddconfig, n_embed, embed_dim, dwt=False, kl_weight=1e-8): 86 | 
super().__init__() 87 | z_channels = ddconfig['z_channels'] 88 | self.dwt = dwt 89 | self.encoder = Encoder(**ddconfig) 90 | self.decoder = DecoderDWT(ddconfig, embed_dim) if dwt else Decoder(**ddconfig) 91 | self.quantize = GumbelQuantize(z_channels, embed_dim, n_embed=n_embed, kl_weight=kl_weight, temp_init=1.0) 92 | self.quant_conv = torch.nn.Conv2d(ddconfig['z_channels'], embed_dim, 1) 93 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig['z_channels'], 1) 94 | 95 | def encode(self, x): 96 | h = self.encoder(x) 97 | h = self.quant_conv(h) 98 | quant, emb_loss, info = self.quantize(h) 99 | return quant, emb_loss, info 100 | 101 | def decode(self, quant): 102 | if self.dwt: 103 | quant = self.decoder.post_quant_conv(quant) 104 | else: 105 | quant = self.post_quant_conv(quant) 106 | dec = self.decoder(quant) 107 | return dec 108 | -------------------------------------------------------------------------------- /rudalle/realesrgan/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | 4 | 5 | def pad_reflect(image, pad_size): 6 | imsize = image.shape 7 | height, width = imsize[:2] 8 | new_img = np.zeros([height + pad_size * 2, width + pad_size * 2, imsize[2]]).astype(np.uint8) 9 | new_img[pad_size:-pad_size, pad_size:-pad_size, :] = image 10 | new_img[0:pad_size, pad_size:-pad_size, :] = np.flip(image[0:pad_size, :, :], axis=0) # top 11 | new_img[-pad_size:, pad_size:-pad_size, :] = np.flip(image[-pad_size:, :, :], axis=0) # bottom 12 | new_img[:, 0:pad_size, :] = np.flip(new_img[:, pad_size:pad_size * 2, :], axis=1) # left 13 | new_img[:, -pad_size:, :] = np.flip(new_img[:, -pad_size * 2:-pad_size, :], axis=1) # right 14 | return new_img 15 | 16 | 17 | def unpad_image(image, pad_size): 18 | return image[pad_size:-pad_size, pad_size:-pad_size, :] 19 | 20 | 21 | def pad_patch(image_patch, padding_size, channel_last=True): 22 | """ Pads image_patch with with 
padding_size edge values. """ 23 | if channel_last: 24 | return np.pad( 25 | image_patch, 26 | ((padding_size, padding_size), (padding_size, padding_size), (0, 0)), 27 | 'edge', 28 | ) 29 | else: 30 | return np.pad( 31 | image_patch, 32 | ((0, 0), (padding_size, padding_size), (padding_size, padding_size)), 33 | 'edge', 34 | ) 35 | 36 | 37 | def unpad_patches(image_patches, padding_size): 38 | return image_patches[:, padding_size:-padding_size, padding_size:-padding_size, :] 39 | 40 | 41 | def split_image_into_overlapping_patches(image_array, patch_size, padding_size=2): 42 | """ Splits the image into partially overlapping patches. 43 | The patches overlap by padding_size pixels. 44 | Pads the image twice: 45 | - first to have a size multiple of the patch size, 46 | - then to have equal padding at the borders. 47 | Args: 48 | image_array: numpy array of the input image. 49 | patch_size: size of the patches from the original image (without padding). 50 | padding_size: size of the overlapping area. 
51 | """ 52 | xmax, ymax, _ = image_array.shape 53 | x_remainder = xmax % patch_size 54 | y_remainder = ymax % patch_size 55 | 56 | # modulo here is to avoid extending of patch_size instead of 0 57 | x_extend = (patch_size - x_remainder) % patch_size 58 | y_extend = (patch_size - y_remainder) % patch_size 59 | 60 | # make sure the image is divisible into regular patches 61 | extended_image = np.pad(image_array, ((0, x_extend), (0, y_extend), (0, 0)), 'edge') 62 | 63 | # add padding around the image to simplify computations 64 | padded_image = pad_patch(extended_image, padding_size, channel_last=True) 65 | 66 | xmax, ymax, _ = padded_image.shape 67 | patches = [] 68 | 69 | x_lefts = range(padding_size, xmax - padding_size, patch_size) 70 | y_tops = range(padding_size, ymax - padding_size, patch_size) 71 | 72 | for x in x_lefts: 73 | for y in y_tops: 74 | x_left = x - padding_size 75 | y_top = y - padding_size 76 | x_right = x + patch_size + padding_size 77 | y_bottom = y + patch_size + padding_size 78 | patch = padded_image[x_left:x_right, y_top:y_bottom, :] 79 | patches.append(patch) 80 | 81 | return np.array(patches), padded_image.shape 82 | 83 | 84 | def stich_together(patches, padded_image_shape, target_shape, padding_size=4): 85 | """ Reconstruct the image from overlapping patches. 86 | After scaling, shapes and padding should be scaled too. 87 | Args: 88 | patches: patches obtained with split_image_into_overlapping_patches 89 | padded_image_shape: shape of the padded image contructed in split_image_into_overlapping_patches 90 | target_shape: shape of the final image 91 | padding_size: size of the overlapping area. 
92 | """ 93 | 94 | xmax, ymax, _ = padded_image_shape 95 | patches = unpad_patches(patches, padding_size) 96 | patch_size = patches.shape[1] 97 | n_patches_per_row = ymax // patch_size 98 | 99 | complete_image = np.zeros((xmax, ymax, 3)) 100 | 101 | row = -1 102 | col = 0 103 | for i in range(len(patches)): 104 | if i % n_patches_per_row == 0: 105 | row += 1 106 | col = 0 107 | complete_image[ 108 | row * patch_size: (row + 1) * patch_size, col * patch_size: (col + 1) * patch_size, : 109 | ] = patches[i] 110 | col += 1 111 | return complete_image[0: target_shape[0], 0: target_shape[1], :] 112 | -------------------------------------------------------------------------------- /Emojich.md: -------------------------------------------------------------------------------- 1 | [[Paper]](https://arxiv.org/abs/2112.02448) [[Хабр]](https://habr.com/ru/company/sberbank/blog/593893/) [[Model Card]](https://huggingface.co/sberbank-ai/rudalle-Emojich) [[Kaggle]](https://www.kaggle.com/shonenkov/emojich-rudall-e) [[Dataset]](https://www.kaggle.com/shonenkov/russian-emoji) 2 | # Emojich 3 | ![](./pics/emojich/emojich_rgba_100.png) 4 | ### generate emojis from text 5 | 6 | Model was trained by [Sber AI](https://github.com/sberbank-ai) 7 | * Task: `text2image generation` 8 | * Num Parameters: `1.3 B` 9 | * Training Data Volume: `120 million text-image pairs` & [`2749 text-emoji pairs`](https://www.kaggle.com/shonenkov/russian-emoji) 10 | 11 | [![Telegram](https://img.shields.io/badge/Telegram-Stickers-blue?style=for-the-badge&logo=)](https://telegram.me/addstickers/SberAI_ruDALLE) 12 | 13 | ### Model Description 14 | 😋 Emojich is a 1.3 billion params model from the family GPT3-like, it generates emoji-style images with the brain of ◾ Malevich. 15 | 16 | 17 | ### Fine-tuning stage: 18 | 19 | The main goal of fine-tuning is trying to keep the generalization of [ruDALL-E Malevich (XL)](https://huggingface.co/sberbank-ai/rudalle-Malevich) 20 | model on text to emoji tasks. 
ruDALL-E Malevich is a large multimodal pretrained transformer that works with both images and texts. 21 | Freezing the feedforward and self-attention layers of a pretrained transformer has been shown to work well when adapting it to a different modality. 22 | However, the model is also prone to overfitting the text modality and losing its generalization. 23 | To deal with this problem, the weight of the image-codebook part of the weighted cross-entropy loss was increased to 10^3. 24 | 25 | 26 | The full version of the training code is available on Kaggle: [![Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://www.kaggle.com/shonenkov/emojich-rudall-e) 27 | 28 | ### Usage: 29 | 30 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1YbEduCe8jH0DXMXKxnb8ulmT8jscJ54i?usp=sharing) 31 | 32 | ```python 33 | from rudalle.pipelines import generate_images, show 34 | from rudalle import get_rudalle_model, get_tokenizer, get_vae 35 | from rudalle.utils import seed_everything 36 | 37 | device = 'cuda' 38 | dalle = get_rudalle_model('Emojich', pretrained=True, fp16=True, device=device) 39 | tokenizer = get_tokenizer() 40 | vae = get_vae(dwt=True).to(device) 41 | 42 | text = 'Дональд Трамп из лего' # Donald Trump made of LEGO 43 | 44 | seed_everything(42) 45 | pil_images = [] 46 | for top_k, top_p, images_num in [ 47 | (2048, 0.995, 16), 48 | ]: 49 | pil_images += generate_images(text, tokenizer, dalle, vae, top_k=top_k, images_num=images_num, top_p=top_p, bs=8)[0] 50 | 51 | show(pil_images, 4) 52 | ``` 53 | ![](./pics/emojich/emoji-Donald.png) 54 | 55 | ### Super Resolution: 56 | ```python 57 | from rudalle.pipelines import super_resolution 58 | from rudalle import get_realesrgan 59 | 60 | device = 'cuda' 61 | realesrgan = get_realesrgan('x4', device=device) 62 | sr_images = super_resolution(pil_images, realesrgan) 63 | ``` 64 | 65 | ### Converting to Telegram Stickers format (512x512 RGBA) 66 | ```python 67 | from
rudalle.pipelines import convert_emoji_to_rgba, show_rgba 68 | from rudalle import get_emojich_unet 69 | 70 | device = 'cuda' 71 | emojich_unet = get_emojich_unet('unet_effnetb7').to(device) 72 | rgba_images, _ = convert_emoji_to_rgba(sr_images, emojich_unet, device=device) 73 | for rgba_image in rgba_images: 74 | show_rgba(rgba_image); 75 | ``` 76 | ![](./pics/emojich/emojich-stickers.png) 77 | 78 | ### Examples of generated emojis 79 | 80 | All examples are generated automatically (without manual cherry-picking) with hyper-parameters: 81 | seed 42, batch size 16, top-k 2048, top-p 0.995, temperature 1.0, GPU A100. 82 | For making better generative emojis should use more attempts (~512) and select the best one manually. 83 | 84 | *Remember, the great art makers became "great" after creating just only one masterpiece.* 85 | 86 | ![](./pics/emojich/examples.png) 87 | 88 | 89 | ### Citation 90 | Feel free to cite our work in your research if it is helpful for you 91 | ``` 92 | @misc{shonenkov2021emojich, 93 | title={Emojich -- zero-shot emoji generation using Russian language: a technical report}, 94 | author={Alex Shonenkov and Daria Bakshandaeva and Denis Dimitrov and Aleksandr Nikolich}, 95 | year={2021}, 96 | eprint={2112.02448}, 97 | archivePrefix={arXiv}, 98 | primaryClass={cs.CL} 99 | } 100 | ``` 101 | -------------------------------------------------------------------------------- /rudalle/realesrgan/rrdbnet_arch.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | from torch import nn as nn 4 | from torch.nn import functional as F 5 | 6 | from .arch_util import default_init_weights, make_layer, pixel_unshuffle 7 | 8 | 9 | class ResidualDenseBlock(nn.Module): 10 | """Residual Dense Block. 11 | Used in RRDB block in ESRGAN. 12 | Args: 13 | num_feat (int): Channel number of intermediate features. 14 | num_grow_ch (int): Channels for each growth. 
15 | """ 16 | 17 | def __init__(self, num_feat=64, num_grow_ch=32): 18 | super(ResidualDenseBlock, self).__init__() 19 | self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1) 20 | self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1) 21 | self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1) 22 | self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1) 23 | self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1) 24 | 25 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 26 | 27 | # initialization 28 | default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) 29 | 30 | def forward(self, x): 31 | x1 = self.lrelu(self.conv1(x)) 32 | x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) 33 | x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) 34 | x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) 35 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) 36 | # Emperically, we use 0.2 to scale the residual for better performance 37 | return x5 * 0.2 + x 38 | 39 | 40 | class RRDB(nn.Module): 41 | """Residual in Residual Dense Block. 42 | Used in RRDB-Net in ESRGAN. 43 | Args: 44 | num_feat (int): Channel number of intermediate features. 45 | num_grow_ch (int): Channels for each growth. 46 | """ 47 | 48 | def __init__(self, num_feat, num_grow_ch=32): 49 | super(RRDB, self).__init__() 50 | self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch) 51 | self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch) 52 | self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch) 53 | 54 | def forward(self, x): 55 | out = self.rdb1(x) 56 | out = self.rdb2(out) 57 | out = self.rdb3(out) 58 | # Emperically, we use 0.2 to scale the residual for better performance 59 | return out * 0.2 + x 60 | 61 | 62 | class RRDBNet(nn.Module): 63 | """Networks consisting of Residual in Residual Dense Block, which is used 64 | in ESRGAN. 
65 | ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks. 66 | We extend ESRGAN for scale x2 and scale x1. 67 | Note: This is one option for scale 1, scale 2 in RRDBNet. 68 | We first employ the pixel-unshuffle (an inverse operation of pixelshuffle to reduce the spatial size 69 | and enlarge the channel size before feeding inputs into the main ESRGAN architecture. 70 | Args: 71 | num_in_ch (int): Channel number of inputs. 72 | num_out_ch (int): Channel number of outputs. 73 | num_feat (int): Channel number of intermediate features. 74 | Default: 64 75 | num_block (int): Block number in the trunk network. Defaults: 23 76 | num_grow_ch (int): Channels for each growth. Default: 32. 77 | """ 78 | 79 | def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32): 80 | super(RRDBNet, self).__init__() 81 | self.scale = scale 82 | if scale == 2: 83 | num_in_ch = num_in_ch * 4 84 | elif scale == 1: 85 | num_in_ch = num_in_ch * 16 86 | self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) 87 | self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch) 88 | self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 89 | # upsample 90 | self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 91 | self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 92 | if scale == 8: 93 | self.conv_up3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 94 | self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 95 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 96 | 97 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 98 | 99 | def forward(self, x): 100 | if self.scale == 2: 101 | feat = pixel_unshuffle(x, scale=2) 102 | elif self.scale == 1: 103 | feat = pixel_unshuffle(x, scale=4) 104 | else: 105 | feat = x 106 | feat = self.conv_first(feat) 107 | body_feat = self.conv_body(self.body(feat)) 108 | feat = feat + body_feat 109 | # upsample 110 | feat = self.lrelu(self.conv_up1(F.interpolate(feat, 
scale_factor=2, mode='nearest'))) 111 | feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest'))) 112 | if self.scale == 8: 113 | feat = self.lrelu(self.conv_up3(F.interpolate(feat, scale_factor=2, mode='nearest'))) 114 | out = self.conv_last(self.lrelu(self.conv_hr(feat))) 115 | return out 116 | -------------------------------------------------------------------------------- /rudalle/dalle/model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import torch.nn.functional as F 4 | from einops import rearrange 5 | 6 | from .utils import exists, is_empty, init_method_normal 7 | 8 | from .transformer import DalleTransformer 9 | 10 | 11 | class DalleModel(torch.nn.Module): 12 | def __init__(self, 13 | device, 14 | num_layers, 15 | vocab_size, 16 | hidden_size, 17 | num_attention_heads, 18 | embedding_dropout_prob, 19 | attention_dropout_prob, 20 | output_dropout_prob, 21 | text_seq_length=128, 22 | image_tokens_per_dim=32, 23 | image_vocab_size=16384, 24 | loss_img_weight=7, 25 | cogview_sandwich_layernorm=False, 26 | cogview_pb_relax=False, 27 | is_bool_mask=True, 28 | mlp_activation='gelu_jit'): 29 | super(DalleModel, self).__init__() 30 | self.device = device 31 | self.image_tokens_per_dim = image_tokens_per_dim 32 | self.image_seq_length = image_tokens_per_dim ** 2 33 | self.text_seq_length = text_seq_length 34 | self.total_seq_length = self.text_seq_length + self.image_seq_length 35 | self.total_vocab_size = vocab_size + image_vocab_size 36 | self.vocab_size = vocab_size 37 | self.loss_img_weight = loss_img_weight 38 | 39 | init_method = init_method_normal(std=0.02) 40 | 41 | self.text_embeddings = torch.nn.Embedding(vocab_size, hidden_size) 42 | self.image_embeddings = torch.nn.Embedding(image_vocab_size, hidden_size) 43 | 44 | # Position embedding (serial). 
45 | self.text_pos_embeddings = torch.nn.Embedding(text_seq_length + 1, hidden_size) 46 | self.image_row_embeddings = torch.nn.Embedding(image_tokens_per_dim, hidden_size) 47 | self.image_col_embeddings = torch.nn.Embedding(image_tokens_per_dim, hidden_size) 48 | init_method(self.text_pos_embeddings.weight) 49 | init_method(self.image_row_embeddings.weight) 50 | init_method(self.image_col_embeddings.weight) 51 | 52 | self.to_logits = torch.nn.Sequential( 53 | torch.nn.LayerNorm(hidden_size), 54 | torch.nn.Linear(hidden_size, self.total_vocab_size), 55 | ) 56 | 57 | # Embeddings dropout 58 | self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob) 59 | 60 | # Transformer 61 | self.transformer = DalleTransformer( 62 | num_layers, 63 | hidden_size, 64 | num_attention_heads, 65 | attention_dropout_prob, 66 | output_dropout_prob, 67 | text_seq_length=text_seq_length, 68 | image_tokens_per_dim=image_tokens_per_dim, 69 | cogview_sandwich_layernorm=cogview_sandwich_layernorm, 70 | cogview_pb_relax=cogview_pb_relax, 71 | mlp_activation=mlp_activation, 72 | is_bool_mask=is_bool_mask, 73 | ) 74 | 75 | def get_param(self, item): 76 | return getattr(self, item) 77 | 78 | def get_image_pos_embeddings(self, image_input_ids, past_length=0): 79 | input_shape = image_input_ids.size() 80 | row_ids = torch.arange(past_length, input_shape[-1] + past_length, 81 | dtype=torch.long, device=self.device) // self.image_tokens_per_dim 82 | row_ids = row_ids.unsqueeze(0).view(-1, input_shape[-1]) 83 | col_ids = torch.arange(past_length, input_shape[-1] + past_length, 84 | dtype=torch.long, device=self.device) % self.image_tokens_per_dim 85 | col_ids = col_ids.unsqueeze(0).view(-1, input_shape[-1]) 86 | return self.image_row_embeddings(row_ids) + self.image_col_embeddings(col_ids) 87 | 88 | def forward( 89 | self, 90 | input_ids, 91 | attention_mask, 92 | return_loss=False, 93 | has_cache=False, 94 | use_cache=False, 95 | ): 96 | text = input_ids[:, :self.text_seq_length] 97 | 
text_range = torch.arange(self.text_seq_length) 98 | text_range += (self.vocab_size - self.text_seq_length) 99 | text_range = text_range.to(self.device) 100 | text = torch.where(text == 0, text_range, text) 101 | # some hardcode :) 102 | text = F.pad(text, (1, 0), value=2) 103 | text_embeddings = self.text_embeddings(text) + \ 104 | self.text_pos_embeddings(torch.arange(text.shape[1], device=self.device)) 105 | 106 | image_input_ids = input_ids[:, self.text_seq_length:] 107 | 108 | if exists(image_input_ids) and not is_empty(image_input_ids): 109 | image_embeddings = self.image_embeddings(image_input_ids) + \ 110 | self.get_image_pos_embeddings(image_input_ids, past_length=0) 111 | embeddings = torch.cat((text_embeddings, image_embeddings), dim=1) 112 | else: 113 | embeddings = text_embeddings 114 | # some hardcode :) 115 | if embeddings.shape[1] > self.total_seq_length: 116 | embeddings = embeddings[:, :-1] 117 | 118 | alpha = 0.1 119 | embeddings = embeddings * alpha + embeddings.detach() * (1-alpha) 120 | 121 | attention_mask = attention_mask[:, :, :embeddings.shape[1], :embeddings.shape[1]] 122 | transformer_output, present_has_cache = self.transformer( 123 | embeddings, attention_mask, has_cache=has_cache, use_cache=use_cache) 124 | 125 | logits = self.to_logits(transformer_output) 126 | if return_loss is False: 127 | return logits, present_has_cache 128 | 129 | labels = torch.cat((text[:, 1:], image_input_ids), dim=1).contiguous().long() 130 | logits = rearrange(logits, 'b n c -> b c n') 131 | 132 | text_logits = logits[:, :self.vocab_size, :self.text_seq_length].contiguous().float() 133 | image_logits = logits[:, self.vocab_size:, self.text_seq_length:].contiguous().float() 134 | 135 | loss_text = F.cross_entropy( 136 | text_logits, 137 | labels[:, :self.text_seq_length]) 138 | loss_img = F.cross_entropy( 139 | image_logits, 140 | labels[:, self.text_seq_length:]) 141 | 142 | loss = (loss_text + self.loss_img_weight * loss_img) / (self.loss_img_weight + 1) 
143 | return loss, {'text': loss_text.data.detach().float(), 'image': loss_img.data.detach().float()} 144 | 145 | def to(self, device, *args, **kwargs): 146 | self.device = device 147 | return super().to(device, *args, **kwargs) 148 | -------------------------------------------------------------------------------- /rudalle/realesrgan/arch_util.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import math 3 | import torch 4 | from torch import nn as nn 5 | from torch.nn import functional as F 6 | from torch.nn import init as init 7 | from torch.nn.modules.batchnorm import _BatchNorm 8 | 9 | 10 | @torch.no_grad() 11 | def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs): 12 | """Initialize network weights. 13 | Args: 14 | module_list (list[nn.Module] | nn.Module): Modules to be initialized. 15 | scale (float): Scale initialized weights, especially for residual 16 | blocks. Default: 1. 17 | bias_fill (float): The value to fill bias. Default: 0 18 | kwargs (dict): Other arguments for initialization function. 19 | """ 20 | if not isinstance(module_list, list): 21 | module_list = [module_list] 22 | for module in module_list: 23 | for m in module.modules(): 24 | if isinstance(m, nn.Conv2d): 25 | init.kaiming_normal_(m.weight, **kwargs) 26 | m.weight.data *= scale 27 | if m.bias is not None: 28 | m.bias.data.fill_(bias_fill) 29 | elif isinstance(m, nn.Linear): 30 | init.kaiming_normal_(m.weight, **kwargs) 31 | m.weight.data *= scale 32 | if m.bias is not None: 33 | m.bias.data.fill_(bias_fill) 34 | elif isinstance(m, _BatchNorm): 35 | init.constant_(m.weight, 1) 36 | if m.bias is not None: 37 | m.bias.data.fill_(bias_fill) 38 | 39 | 40 | def make_layer(basic_block, num_basic_block, **kwarg): 41 | """Make layers by stacking the same blocks. 42 | Args: 43 | basic_block (nn.module): nn.module class for basic block. 44 | num_basic_block (int): number of blocks. 
45 | Returns: 46 | nn.Sequential: Stacked blocks in nn.Sequential. 47 | """ 48 | layers = [] 49 | for _ in range(num_basic_block): 50 | layers.append(basic_block(**kwarg)) 51 | return nn.Sequential(*layers) 52 | 53 | 54 | class ResidualBlockNoBN(nn.Module): 55 | """Residual block without BN. 56 | It has a style of: 57 | ---Conv-ReLU-Conv-+- 58 | |________________| 59 | Args: 60 | num_feat (int): Channel number of intermediate features. 61 | Default: 64. 62 | res_scale (float): Residual scale. Default: 1. 63 | pytorch_init (bool): If set to True, use pytorch default init, 64 | otherwise, use default_init_weights. Default: False. 65 | """ 66 | 67 | def __init__(self, num_feat=64, res_scale=1, pytorch_init=False): 68 | super(ResidualBlockNoBN, self).__init__() 69 | self.res_scale = res_scale 70 | self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) 71 | self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) 72 | self.relu = nn.ReLU(inplace=True) 73 | 74 | if not pytorch_init: 75 | default_init_weights([self.conv1, self.conv2], 0.1) 76 | 77 | def forward(self, x): 78 | identity = x 79 | out = self.conv2(self.relu(self.conv1(x))) 80 | return identity + out * self.res_scale 81 | 82 | 83 | class Upsample(nn.Sequential): 84 | """Upsample module. 85 | Args: 86 | scale (int): Scale factor. Supported scales: 2^n and 3. 87 | num_feat (int): Channel number of intermediate features. 88 | """ 89 | 90 | def __init__(self, scale, num_feat): 91 | m = [] 92 | if (scale & (scale - 1)) == 0: # scale = 2^n 93 | for _ in range(int(math.log(scale, 2))): 94 | m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) 95 | m.append(nn.PixelShuffle(2)) 96 | elif scale == 3: 97 | m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) 98 | m.append(nn.PixelShuffle(3)) 99 | else: 100 | raise ValueError(f'scale {scale} is not supported. 
' 'Supported scales: 2^n and 3.') 101 | super(Upsample, self).__init__(*m) 102 | 103 | 104 | def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True): 105 | """Warp an image or feature map with optical flow. 106 | Args: 107 | x (Tensor): Tensor with size (n, c, h, w). 108 | flow (Tensor): Tensor with size (n, h, w, 2), normal value. 109 | interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'. 110 | padding_mode (str): 'zeros' or 'border' or 'reflection'. 111 | Default: 'zeros'. 112 | align_corners (bool): Before pytorch 1.3, the default value is 113 | align_corners=True. After pytorch 1.3, the default value is 114 | align_corners=False. Here, we use the True as default. 115 | Returns: 116 | Tensor: Warped image or feature map. 117 | """ 118 | assert x.size()[-2:] == flow.size()[1:3] 119 | _, _, h, w = x.size() 120 | # create mesh grid 121 | grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x)) 122 | grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 123 | grid.requires_grad = False 124 | 125 | vgrid = grid + flow 126 | # scale grid to [-1,1] 127 | vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0 128 | vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0 129 | vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) 130 | output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners) 131 | 132 | # TODO, what if align_corners=False 133 | return output 134 | 135 | 136 | def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False): 137 | """Resize a flow according to ratio or shape. 138 | Args: 139 | flow (Tensor): Precomputed flow. shape [N, 2, H, W]. 140 | size_type (str): 'ratio' or 'shape'. 141 | sizes (list[int | float]): the ratio for resizing or the final output 142 | shape. 143 | 1) The order of ratio should be [ratio_h, ratio_w]. 
For 144 | downsampling, the ratio should be smaller than 1.0 (i.e., ratio 145 | < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e., 146 | ratio > 1.0). 147 | 2) The order of output_size should be [out_h, out_w]. 148 | interp_mode (str): The mode of interpolation for resizing. 149 | Default: 'bilinear'. 150 | align_corners (bool): Whether align corners. Default: False. 151 | Returns: 152 | Tensor: Resized flow. 153 | """ 154 | _, _, flow_h, flow_w = flow.size() 155 | if size_type == 'ratio': 156 | output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1]) 157 | elif size_type == 'shape': 158 | output_h, output_w = sizes[0], sizes[1] 159 | else: 160 | raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.') 161 | 162 | input_flow = flow.clone() 163 | ratio_h = output_h / flow_h 164 | ratio_w = output_w / flow_w 165 | input_flow[:, 0, :, :] *= ratio_w 166 | input_flow[:, 1, :, :] *= ratio_h 167 | resized_flow = F.interpolate( 168 | input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners) 169 | return resized_flow 170 | 171 | 172 | # TODO: may write a cpp file 173 | def pixel_unshuffle(x, scale): 174 | """ Pixel unshuffle. 175 | Args: 176 | x (Tensor): Input feature with shape (b, c, hh, hw). 177 | scale (int): Downsample ratio. 178 | Returns: 179 | Tensor: the pixel unshuffled feature. 
180 | """ 181 | b, c, hh, hw = x.size() 182 | out_channel = c * (scale**2) 183 | assert hh % scale == 0 and hw % scale == 0 184 | h = hh // scale 185 | w = hw // scale 186 | x_view = x.view(b, c, h, scale, w, scale) 187 | return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w) 188 | -------------------------------------------------------------------------------- /rudalle/pipelines.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | from glob import glob 4 | from os.path import join 5 | 6 | import cv2 7 | import torch 8 | import torchvision 9 | import transformers 10 | import more_itertools 11 | import numpy as np 12 | import matplotlib.pyplot as plt 13 | from tqdm.auto import tqdm 14 | from PIL import Image 15 | 16 | from . import utils 17 | 18 | 19 | def generate_images(text, tokenizer, dalle, vae, top_k, top_p, images_num, image_prompts=None, temperature=1.0, bs=4, 20 | seed=None, use_cache=True, return_codes=False): 21 | # TODO docstring 22 | if seed is not None: 23 | utils.seed_everything(seed) 24 | 25 | vocab_size = dalle.get_param('vocab_size') 26 | text_seq_length = dalle.get_param('text_seq_length') 27 | image_seq_length = dalle.get_param('image_seq_length') 28 | total_seq_length = dalle.get_param('total_seq_length') 29 | device = dalle.get_param('device') 30 | 31 | text = text.lower().strip() 32 | input_ids = tokenizer.encode_text(text, text_seq_length=text_seq_length) 33 | pil_images, scores, codes = [], [], [] 34 | for chunk in more_itertools.chunked(range(images_num), bs): 35 | chunk_bs = len(chunk) 36 | with torch.no_grad(): 37 | attention_mask = torch.tril(torch.ones((chunk_bs, 1, total_seq_length, total_seq_length), device=device)) 38 | out = input_ids.unsqueeze(0).repeat(chunk_bs, 1).to(device) 39 | has_cache = False 40 | sample_scores = [] 41 | if image_prompts is not None: 42 | prompts_idx, prompts = image_prompts.image_prompts_idx, 
image_prompts.image_prompts 43 | prompts = prompts.repeat(chunk_bs, 1) 44 | for idx in tqdm(range(out.shape[1], total_seq_length)): 45 | idx -= text_seq_length 46 | if image_prompts is not None and idx in prompts_idx: 47 | out = torch.cat((out, prompts[:, idx].unsqueeze(1)), dim=-1) 48 | else: 49 | logits, has_cache = dalle(out, attention_mask, 50 | has_cache=has_cache, use_cache=use_cache, return_loss=False) 51 | logits = logits[:, -1, vocab_size:] 52 | logits /= temperature 53 | filtered_logits = transformers.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p) 54 | probs = torch.nn.functional.softmax(filtered_logits, dim=-1) 55 | sample = torch.multinomial(probs, 1) 56 | sample_scores.append(probs[torch.arange(probs.size(0)), sample.transpose(0, 1)]) 57 | out = torch.cat((out, sample), dim=-1) 58 | codebooks = out[:, -image_seq_length:] 59 | images = vae.decode(codebooks) 60 | pil_images += utils.torch_tensors_to_pil_list(images) 61 | scores += torch.cat(sample_scores).sum(0).detach().cpu().numpy().tolist() 62 | for j in range(codebooks.shape[0]): 63 | codes.append(codebooks[j].detach().cpu().numpy()) 64 | 65 | if return_codes: 66 | return pil_images, scores, codes 67 | return pil_images, scores 68 | 69 | def super_resolution(pil_images, realesrgan, batch_size=4): 70 | result = [] 71 | for pil_image in pil_images: 72 | with torch.no_grad(): 73 | sr_image = realesrgan.predict(np.array(pil_image), batch_size=batch_size) 74 | result.append(sr_image) 75 | return result 76 | 77 | 78 | def cherry_pick_by_clip(pil_images, text, ruclip, ruclip_processor, device='cpu', count=4): 79 | with torch.no_grad(): 80 | inputs = ruclip_processor(text=text, images=pil_images) 81 | for key in inputs.keys(): 82 | inputs[key] = inputs[key].to(device) 83 | outputs = ruclip(**inputs) 84 | sims = outputs.logits_per_image.view(-1).softmax(dim=0) 85 | items = [] 86 | for index, sim in enumerate(sims.cpu().numpy()): 87 | items.append({'img_index': index, 'cosine': sim}) 88 | items = 
sorted(items, key=lambda x: x['cosine'], reverse=True)[:count] 89 | top_pil_images = [pil_images[x['img_index']] for x in items] 90 | top_scores = [x['cosine'] for x in items] 91 | return top_pil_images, top_scores 92 | 93 | 94 | def show(pil_images, nrow=4, size=14, save_dir=None, show=True): 95 | """ 96 | :param pil_images: list of images in PIL 97 | :param nrow: number of rows 98 | :param size: size of the images 99 | :param save_dir: dir for separately saving of images, example: save_dir='./pics' 100 | """ 101 | if save_dir is not None: 102 | os.makedirs(save_dir, exist_ok=True) 103 | count = len(glob(join(save_dir, 'img_*.png'))) 104 | for i, pil_image in enumerate(pil_images): 105 | pil_image.save(join(save_dir, f'img_{count+i}.png')) 106 | 107 | pil_images = [pil_image.convert('RGB') for pil_image in pil_images] 108 | imgs = torchvision.utils.make_grid(utils.pil_list_to_torch_tensors(pil_images), nrow=nrow) 109 | if not isinstance(imgs, list): 110 | imgs = [imgs.cpu()] 111 | fix, axs = plt.subplots(ncols=len(imgs), squeeze=False, figsize=(size, size)) 112 | for i, img in enumerate(imgs): 113 | img = img.detach() 114 | img = torchvision.transforms.functional.to_pil_image(img) 115 | if save_dir is not None: 116 | count = len(glob(join(save_dir, 'group_*.png'))) 117 | img.save(join(save_dir, f'group_{count+i}.png')) 118 | if show: 119 | axs[0, i].imshow(np.asarray(img)) 120 | axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[]) 121 | if show: 122 | fix.show() 123 | plt.show() 124 | 125 | 126 | def classic_convert_emoji_to_rgba(np_image, lower_thr=240, upper_thr=255, width=2): 127 | img = np_image[:, :, :3].copy() 128 | lower = np.array([lower_thr, lower_thr, lower_thr], dtype='uint8') 129 | upper = np.array([upper_thr, upper_thr, upper_thr], dtype='uint8') 130 | mask = cv2.inRange(img, lower, upper) 131 | ret, thresh = cv2.threshold(mask, 0, 255, 0) 132 | contours, hierarchy = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE) 133 | 
a_channel = np.ones((512, 512), dtype=np.uint8)*255 134 | if len(contours) != 0: 135 | contours = sorted(contours, key=lambda x: x.shape[0])[-7:] 136 | cv2.fillPoly(a_channel, contours, (0, 0, 0)) 137 | cv2.drawContours(a_channel, contours, -1, (0, 0, 0), width) 138 | img = cv2.cvtColor(img, cv2.COLOR_RGB2RGBA) 139 | img[:, :, 3] = a_channel 140 | return img 141 | 142 | 143 | def convert_emoji_to_rgba(pil_images, emojich_unet, device='cpu', bs=1, score_thr=0.99): 144 | final_images, runs = [], [] 145 | with torch.no_grad(): 146 | for chunk in more_itertools.chunked(pil_images, bs): 147 | images = [] 148 | for pil_image in chunk: 149 | image = np.array(pil_image.resize((512, 512)))[:, :, :3] 150 | image = image.astype(np.float32) / 255.0 151 | image = torch.from_numpy(image).permute(2, 0, 1) 152 | images.append(image) 153 | images = torch.nn.utils.rnn.pad_sequence(images, batch_first=True) 154 | pred_masks = emojich_unet(images.to(device)) 155 | pred_masks = torch.softmax(pred_masks, 1) 156 | scores, pred_masks = torch.max(pred_masks, 1) 157 | pred_masks = pred_masks.int().cpu().numpy() 158 | pred_masks = (pred_masks * 255).astype(np.uint8) 159 | for pil_image, pred_mask, score in zip(chunk, pred_masks, scores): 160 | score = score.mean().item() 161 | final_image = np.zeros((512, 512, 4), np.uint8) 162 | final_image[:, :, :3] = np.array(pil_image.resize((512, 512)))[:, :, :3] 163 | if score > score_thr: 164 | run = 'unet' 165 | final_image[:, :, -1] = pred_mask 166 | else: 167 | run = 'classic' 168 | final_image = classic_convert_emoji_to_rgba(final_image) 169 | final_image = Image.fromarray(final_image) 170 | final_images.append(final_image) 171 | runs.append(run) 172 | return final_images, runs 173 | 174 | 175 | def show_rgba(rgba_pil_image): 176 | img = np.array(rgba_pil_image) 177 | fig, ax = plt.subplots(1, 3, figsize=(10, 10), dpi=100) 178 | ax[0].imshow(img[:, :, :3]) 179 | ax[1].imshow(img[:, :, -1]) 180 | mask = np.repeat(np.expand_dims(img[:, :, -1] < 128, 
-1), 3, axis=-1) 181 | img = img[:, :, :3] 182 | img[mask[:, :, 0], 0] = 64 183 | img[mask[:, :, 0], 1] = 255 184 | img[mask[:, :, 0], 2] = 64 185 | ax[2].imshow(img) 186 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [2020] [sberbank-ai] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /rudalle/vae/pytorch_wavelets_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Useful utilities for testing the 2-D DTCWT with synthetic images 4 | License: https://github.com/fbcotter/pytorch_wavelets/blob/master/LICENSE 5 | Source: https://github.com/fbcotter/pytorch_wavelets/blob/31d6ac1b51b08f811a6a70eb7b3440f106009da0/pytorch_wavelets/dwt/lowlevel.py # noqa 6 | """ 7 | 8 | import pywt 9 | import torch 10 | import numpy as np 11 | import torch.nn.functional as F 12 | from torch.autograd import Function 13 | 14 | 15 | def sfb1d(lo, hi, g0, g1, mode='zero', dim=-1): 16 | """ 1D synthesis filter bank of an image tensor 17 | """ 18 | C = lo.shape[1] 19 | d = dim % 4 20 | # If g0, g1 are not tensors, make them. 
If they are, then assume that they 21 | # are in the right order 22 | if not isinstance(g0, torch.Tensor): 23 | g0 = torch.tensor(np.copy(np.array(g0).ravel()), 24 | dtype=torch.float, device=lo.device) 25 | if not isinstance(g1, torch.Tensor): 26 | g1 = torch.tensor(np.copy(np.array(g1).ravel()), 27 | dtype=torch.float, device=lo.device) 28 | L = g0.numel() 29 | shape = [1, 1, 1, 1] 30 | shape[d] = L 31 | N = 2*lo.shape[d] 32 | # If g aren't in the right shape, make them so 33 | if g0.shape != tuple(shape): 34 | g0 = g0.reshape(*shape) 35 | if g1.shape != tuple(shape): 36 | g1 = g1.reshape(*shape) 37 | 38 | s = (2, 1) if d == 2 else (1, 2) 39 | g0 = torch.cat([g0]*C, dim=0) 40 | g1 = torch.cat([g1]*C, dim=0) 41 | if mode == 'per' or mode == 'periodization': 42 | y = F.conv_transpose2d(lo, g0, stride=s, groups=C) + \ 43 | F.conv_transpose2d(hi, g1, stride=s, groups=C) 44 | if d == 2: 45 | y[:, :, :L-2] = y[:, :, :L-2] + y[:, :, N:N+L-2] 46 | y = y[:, :, :N] 47 | else: 48 | y[:, :, :, :L-2] = y[:, :, :, :L-2] + y[:, :, :, N:N+L-2] 49 | y = y[:, :, :, :N] 50 | y = roll(y, 1-L//2, dim=dim) 51 | else: 52 | if mode == 'zero' or mode == 'symmetric' or mode == 'reflect' or \ 53 | mode == 'periodic': 54 | pad = (L-2, 0) if d == 2 else (0, L-2) 55 | y = F.conv_transpose2d(lo, g0, stride=s, padding=pad, groups=C) + \ 56 | F.conv_transpose2d(hi, g1, stride=s, padding=pad, groups=C) 57 | else: 58 | raise ValueError('Unkown pad type: {}'.format(mode)) 59 | 60 | return y 61 | 62 | 63 | def _SFB2D(low, highs, g0_row, g1_row, g0_col, g1_col, mode): 64 | mode = int_to_mode(mode) 65 | 66 | lh, hl, hh = torch.unbind(highs, dim=2) 67 | lo = sfb1d(low, lh, g0_col, g1_col, mode=mode, dim=2) 68 | hi = sfb1d(hl, hh, g0_col, g1_col, mode=mode, dim=2) 69 | y = sfb1d(lo, hi, g0_row, g1_row, mode=mode, dim=3) 70 | 71 | return y 72 | 73 | 74 | def roll(x, n, dim, make_even=False): 75 | if n < 0: 76 | n = x.shape[dim] + n 77 | 78 | if make_even and x.shape[dim] % 2 == 1: 79 | end = 1 80 | else: 
81 | end = 0 82 | 83 | if dim == 0: 84 | return torch.cat((x[-n:], x[:-n+end]), dim=0) 85 | elif dim == 1: 86 | return torch.cat((x[:, -n:], x[:, :-n+end]), dim=1) 87 | elif dim == 2 or dim == -2: 88 | return torch.cat((x[:, :, -n:], x[:, :, :-n+end]), dim=2) 89 | elif dim == 3 or dim == -1: 90 | return torch.cat((x[:, :, :, -n:], x[:, :, :, :-n+end]), dim=3) 91 | 92 | 93 | def int_to_mode(mode): 94 | if mode == 0: 95 | return 'zero' 96 | elif mode == 1: 97 | return 'symmetric' 98 | elif mode == 2: 99 | return 'periodization' 100 | elif mode == 3: 101 | return 'constant' 102 | elif mode == 4: 103 | return 'reflect' 104 | elif mode == 5: 105 | return 'replicate' 106 | elif mode == 6: 107 | return 'periodic' 108 | else: 109 | raise ValueError('Unkown pad type: {}'.format(mode)) 110 | 111 | 112 | def prep_filt_sfb2d(g0_col, g1_col, g0_row=None, g1_row=None, device=None): 113 | """ 114 | Prepares the filters to be of the right form for the sfb2d function. In 115 | particular, makes the tensors the right shape. It does not mirror image them 116 | as as sfb2d uses conv2d_transpose which acts like normal convolution. 117 | Inputs: 118 | g0_col (array-like): low pass column filter bank 119 | g1_col (array-like): high pass column filter bank 120 | g0_row (array-like): low pass row filter bank. If none, will assume the 121 | same as column filter 122 | g1_row (array-like): high pass row filter bank. 
If none, will assume the 123 | same as column filter 124 | device: which device to put the tensors on to 125 | Returns: 126 | (g0_col, g1_col, g0_row, g1_row) 127 | """ 128 | g0_col, g1_col = prep_filt_sfb1d(g0_col, g1_col, device) 129 | if g0_row is None: 130 | g0_row, g1_row = g0_col, g1_col 131 | else: 132 | g0_row, g1_row = prep_filt_sfb1d(g0_row, g1_row, device) 133 | 134 | g0_col = g0_col.reshape((1, 1, -1, 1)) 135 | g1_col = g1_col.reshape((1, 1, -1, 1)) 136 | g0_row = g0_row.reshape((1, 1, 1, -1)) 137 | g1_row = g1_row.reshape((1, 1, 1, -1)) 138 | 139 | return g0_col, g1_col, g0_row, g1_row 140 | 141 | 142 | def prep_filt_sfb1d(g0, g1, device=None): 143 | """ 144 | Prepares the filters to be of the right form for the sfb1d function. In 145 | particular, makes the tensors the right shape. It does not mirror image them 146 | as as sfb2d uses conv2d_transpose which acts like normal convolution. 147 | Inputs: 148 | g0 (array-like): low pass filter bank 149 | g1 (array-like): high pass filter bank 150 | device: which device to put the tensors on to 151 | Returns: 152 | (g0, g1) 153 | """ 154 | g0 = np.array(g0).ravel() 155 | g1 = np.array(g1).ravel() 156 | t = torch.get_default_dtype() 157 | g0 = torch.tensor(g0, device=device, dtype=t).reshape((1, 1, -1)) 158 | g1 = torch.tensor(g1, device=device, dtype=t).reshape((1, 1, -1)) 159 | 160 | return g0, g1 161 | 162 | 163 | def mode_to_int(mode): 164 | if mode == 'zero': 165 | return 0 166 | elif mode == 'symmetric': 167 | return 1 168 | elif mode == 'per' or mode == 'periodization': 169 | return 2 170 | elif mode == 'constant': 171 | return 3 172 | elif mode == 'reflect': 173 | return 4 174 | elif mode == 'replicate': 175 | return 5 176 | elif mode == 'periodic': 177 | return 6 178 | else: 179 | raise ValueError('Unkown pad type: {}'.format(mode)) 180 | 181 | 182 | def afb1d(x, h0, h1, mode='zero', dim=-1): 183 | """ 1D analysis filter bank (along one dimension only) of an image 184 | Inputs: 185 | x (tensor): 4D 
input with the last two dimensions the spatial input 186 | h0 (tensor): 4D input for the lowpass filter. Should have shape (1, 1, 187 | h, 1) or (1, 1, 1, w) 188 | h1 (tensor): 4D input for the highpass filter. Should have shape (1, 1, 189 | h, 1) or (1, 1, 1, w) 190 | mode (str): padding method 191 | dim (int) - dimension of filtering. d=2 is for a vertical filter (called 192 | column filtering but filters across the rows). d=3 is for a 193 | horizontal filter, (called row filtering but filters across the 194 | columns). 195 | Returns: 196 | lohi: lowpass and highpass subbands concatenated along the channel 197 | dimension 198 | """ 199 | C = x.shape[1] 200 | # Convert the dim to positive 201 | d = dim % 4 202 | s = (2, 1) if d == 2 else (1, 2) 203 | N = x.shape[d] 204 | # If h0, h1 are not tensors, make them. If they are, then assume that they 205 | # are in the right order 206 | if not isinstance(h0, torch.Tensor): 207 | h0 = torch.tensor(np.copy(np.array(h0).ravel()[::-1]), 208 | dtype=torch.float, device=x.device) 209 | if not isinstance(h1, torch.Tensor): 210 | h1 = torch.tensor(np.copy(np.array(h1).ravel()[::-1]), 211 | dtype=torch.float, device=x.device) 212 | L = h0.numel() 213 | L2 = L // 2 214 | shape = [1, 1, 1, 1] 215 | shape[d] = L 216 | # If h aren't in the right shape, make them so 217 | if h0.shape != tuple(shape): 218 | h0 = h0.reshape(*shape) 219 | if h1.shape != tuple(shape): 220 | h1 = h1.reshape(*shape) 221 | h = torch.cat([h0, h1] * C, dim=0) 222 | 223 | if mode == 'per' or mode == 'periodization': 224 | if x.shape[dim] % 2 == 1: 225 | if d == 2: 226 | x = torch.cat((x, x[:, :, -1:]), dim=2) 227 | else: 228 | x = torch.cat((x, x[:, :, :, -1:]), dim=3) 229 | N += 1 230 | x = roll(x, -L2, dim=d) 231 | pad = (L-1, 0) if d == 2 else (0, L-1) 232 | lohi = F.conv2d(x, h, padding=pad, stride=s, groups=C) 233 | N2 = N//2 234 | if d == 2: 235 | lohi[:, :, :L2] = lohi[:, :, :L2] + lohi[:, :, N2:N2+L2] 236 | lohi = lohi[:, :, :N2] 237 | else: 238 | 
lohi[:, :, :, :L2] = lohi[:, :, :, :L2] + lohi[:, :, :, N2:N2+L2] 239 | lohi = lohi[:, :, :, :N2] 240 | else: 241 | # Calculate the pad size 242 | outsize = pywt.dwt_coeff_len(N, L, mode=mode) 243 | p = 2 * (outsize - 1) - N + L 244 | if mode == 'zero': 245 | # Sadly, pytorch only allows for same padding before and after, if 246 | # we need to do more padding after for odd length signals, have to 247 | # prepad 248 | if p % 2 == 1: 249 | pad = (0, 0, 0, 1) if d == 2 else (0, 1, 0, 0) 250 | x = F.pad(x, pad) 251 | pad = (p//2, 0) if d == 2 else (0, p//2) 252 | # Calculate the high and lowpass 253 | lohi = F.conv2d(x, h, padding=pad, stride=s, groups=C) 254 | elif mode == 'symmetric' or mode == 'reflect' or mode == 'periodic': 255 | pad = (0, 0, p//2, (p+1)//2) if d == 2 else (p//2, (p+1)//2, 0, 0) 256 | x = mypad(x, pad=pad, mode=mode) 257 | lohi = F.conv2d(x, h, stride=s, groups=C) 258 | else: 259 | raise ValueError('Unkown pad type: {}'.format(mode)) 260 | 261 | return lohi 262 | 263 | 264 | def mypad(x, pad, mode='constant', value=0): 265 | """ Function to do numpy like padding on tensors. Only works for 2-D 266 | padding. 267 | Inputs: 268 | x (tensor): tensor to pad 269 | pad (tuple): tuple of (left, right, top, bottom) pad sizes 270 | mode (str): 'symmetric', 'wrap', 'constant, 'reflect', 'replicate', or 271 | 'zero'. The padding technique. 
272 | """ 273 | if mode == 'symmetric': 274 | # Vertical only 275 | if pad[0] == 0 and pad[1] == 0: 276 | m1, m2 = pad[2], pad[3] 277 | l = x.shape[-2] # noqa 278 | xe = reflect(np.arange(-m1, l+m2, dtype='int32'), -0.5, l-0.5) 279 | return x[:, :, xe] 280 | # horizontal only 281 | elif pad[2] == 0 and pad[3] == 0: 282 | m1, m2 = pad[0], pad[1] 283 | l = x.shape[-1] # noqa 284 | xe = reflect(np.arange(-m1, l+m2, dtype='int32'), -0.5, l-0.5) 285 | return x[:, :, :, xe] 286 | # Both 287 | else: 288 | m1, m2 = pad[0], pad[1] 289 | l1 = x.shape[-1] 290 | xe_row = reflect(np.arange(-m1, l1+m2, dtype='int32'), -0.5, l1-0.5) 291 | m1, m2 = pad[2], pad[3] 292 | l2 = x.shape[-2] 293 | xe_col = reflect(np.arange(-m1, l2+m2, dtype='int32'), -0.5, l2-0.5) 294 | i = np.outer(xe_col, np.ones(xe_row.shape[0])) 295 | j = np.outer(np.ones(xe_col.shape[0]), xe_row) 296 | return x[:, :, i, j] 297 | elif mode == 'periodic': 298 | # Vertical only 299 | if pad[0] == 0 and pad[1] == 0: 300 | xe = np.arange(x.shape[-2]) 301 | xe = np.pad(xe, (pad[2], pad[3]), mode='wrap') 302 | return x[:, :, xe] 303 | # Horizontal only 304 | elif pad[2] == 0 and pad[3] == 0: 305 | xe = np.arange(x.shape[-1]) 306 | xe = np.pad(xe, (pad[0], pad[1]), mode='wrap') 307 | return x[:, :, :, xe] 308 | # Both 309 | else: 310 | xe_col = np.arange(x.shape[-2]) 311 | xe_col = np.pad(xe_col, (pad[2], pad[3]), mode='wrap') 312 | xe_row = np.arange(x.shape[-1]) 313 | xe_row = np.pad(xe_row, (pad[0], pad[1]), mode='wrap') 314 | i = np.outer(xe_col, np.ones(xe_row.shape[0])) 315 | j = np.outer(np.ones(xe_col.shape[0]), xe_row) 316 | return x[:, :, i, j] 317 | 318 | elif mode == 'constant' or mode == 'reflect' or mode == 'replicate': 319 | return F.pad(x, pad, mode, value) 320 | elif mode == 'zero': 321 | return F.pad(x, pad) 322 | else: 323 | raise ValueError('Unkown pad type: {}'.format(mode)) 324 | 325 | 326 | def reflect(x, minx, maxx): 327 | """Reflect the values in matrix *x* about the scalar values *minx* and 328 | 
*maxx*. Hence a vector *x* containing a long linearly increasing series is 329 | converted into a waveform which ramps linearly up and down between *minx* 330 | and *maxx*. If *x* contains integers and *minx* and *maxx* are (integers + 331 | 0.5), the ramps will have repeated max and min samples. 332 | .. codeauthor:: Rich Wareham , Aug 2013 333 | .. codeauthor:: Nick Kingsbury, Cambridge University, January 1999. 334 | """ 335 | x = np.asanyarray(x) 336 | rng = maxx - minx 337 | rng_by_2 = 2 * rng 338 | mod = np.fmod(x - minx, rng_by_2) 339 | normed_mod = np.where(mod < 0, mod + rng_by_2, mod) 340 | out = np.where(normed_mod >= rng, rng_by_2 - normed_mod, normed_mod) + minx 341 | return np.array(out, dtype=x.dtype) 342 | 343 | 344 | class SFB2D(Function): 345 | """ Does a single level 2d inverse wavelet transform (reconstruction) from its subbands. Does separate 346 | row and column filtering by three calls to 347 | :py:func:`pytorch_wavelets.dwt.lowlevel.sfb1d` 348 | Needs to have the tensors in the right form. Because this function defines 349 | its own backward pass, saves on memory by not having to save the input 350 | tensors. 351 | Inputs: 352 | low (torch.Tensor): lowpass subband; highs (torch.Tensor): highpass subbands, unbound along dim 2 into (lh, hl, hh) 353 | g0_row: row lowpass synthesis filter 354 | g1_row: row highpass synthesis filter 355 | g0_col: col lowpass synthesis filter 356 | g1_col: col highpass synthesis filter 357 | mode (int): use mode_to_int to get the int code here 358 | We encode the mode as an integer rather than a string as gradcheck causes an 359 | error when a string is provided.
360 | Returns: 361 | y: reconstructed Tensor of shape (N, C, H', W'), where C = low.shape[1] 362 | """ 363 | @staticmethod 364 | def forward(ctx, low, highs, g0_row, g1_row, g0_col, g1_col, mode): 365 | mode = int_to_mode(mode) 366 | ctx.mode = mode 367 | ctx.save_for_backward(g0_row, g1_row, g0_col, g1_col) 368 | 369 | lh, hl, hh = torch.unbind(highs, dim=2) 370 | lo = sfb1d(low, lh, g0_col, g1_col, mode=mode, dim=2) 371 | hi = sfb1d(hl, hh, g0_col, g1_col, mode=mode, dim=2) 372 | y = sfb1d(lo, hi, g0_row, g1_row, mode=mode, dim=3) 373 | return y 374 | 375 | @staticmethod 376 | def backward(ctx, dy): 377 | dlow, dhigh = None, None 378 | if ctx.needs_input_grad[0]: # NOTE(review): gradients are only produced when `low` requires grad; dlow and dhigh stay None otherwise 379 | mode = ctx.mode 380 | g0_row, g1_row, g0_col, g1_col = ctx.saved_tensors 381 | dx = afb1d(dy, g0_row, g1_row, mode=mode, dim=3) 382 | dx = afb1d(dx, g0_col, g1_col, mode=mode, dim=2) 383 | s = dx.shape 384 | dx = dx.reshape(s[0], -1, 4, s[-2], s[-1]) 385 | dlow = dx[:, :, 0].contiguous() 386 | dhigh = dx[:, :, 1:].contiguous() 387 | return dlow, dhigh, None, None, None, None, None 388 | -------------------------------------------------------------------------------- /rudalle/dalle/transformer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import math 3 | 4 | import torch 5 | from torch.nn import LayerNorm 6 | 7 | from .utils import divide, split_tensor_along_last_dim 8 | from .image_attention import get_conv_mask, get_row_mask, get_col_mask 9 | 10 | 11 | def gelu(x): 12 | return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x))) 13 | 14 | 15 | @torch.jit.script 16 | def gelu_jit(x): 17 | """OpenAI's gelu implementation.""" 18 | return gelu(x) 19 | 20 | 21 | class DalleTransformer(torch.nn.Module): 22 | """ 23 | This module takes input from embedding layer and its output can 24 | be used directly by a logit layer.
It consists of L (num-layers) 25 | blocks of: 26 | layer norm 27 | self attention 28 | residual connection 29 | layer norm 30 | mlp 31 | residual connection 32 | followed by a final layer norm. 33 | 34 | Arguments: 35 | num_layers: Number of transformer layers. 36 | hidden_size: The hidden size of the self attention. 37 | num_attention_heads: number of attention head in the self 38 | attention. 39 | attention_dropout_prob: dropout probability of the attention 40 | score in self attention. 41 | output_dropout_prob: dropout probability for the outputs 42 | after self attention and final output. 43 | layernorm_epsilon: epsilon used in layernorm to avoid 44 | division by zero. 45 | """ 46 | _mask_map = [] 47 | 48 | def __init__(self, num_layers, hidden_size, num_attention_heads, attention_dropout_prob, output_dropout_prob, 49 | text_seq_length, image_tokens_per_dim, layernorm_epsilon=1.0e-5, 50 | cogview_sandwich_layernorm=False, cogview_pb_relax=False, mlp_activation='gelu_jit', 51 | is_bool_mask=False): 52 | super(DalleTransformer, self).__init__() 53 | 54 | self.num_layers = num_layers 55 | # CogView stabilization of training features, see chapter 2.4 https://arxiv.org/pdf/2105.13290.pdf 56 | self.cogview_pb_relax = cogview_pb_relax 57 | 58 | # Transformer layers. 
59 | self.layers = torch.nn.ModuleList([ 60 | DalleTransformerLayer( 61 | hidden_size, 62 | num_attention_heads, 63 | attention_dropout_prob, 64 | output_dropout_prob, 65 | layernorm_epsilon, 66 | cogview_sandwich_layernorm=cogview_sandwich_layernorm, 67 | cogview_pb_relax=cogview_pb_relax, 68 | mlp_activation=mlp_activation, 69 | ) for _ in range(num_layers) 70 | ]) 71 | 72 | row_mask = get_row_mask(text_seq_length, image_tokens_per_dim, is_bool_mask=is_bool_mask) 73 | col_mask = get_col_mask(text_seq_length, image_tokens_per_dim, is_bool_mask=is_bool_mask) 74 | conv_mask = get_conv_mask(text_seq_length, image_tokens_per_dim, is_bool_mask=is_bool_mask) 75 | self.register_buffer('row_mask', row_mask) 76 | self.register_buffer('col_mask', col_mask) 77 | self.register_buffer('conv_mask', conv_mask) 78 | 79 | # Final layer norm before output. 80 | self.final_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon) 81 | 82 | def _get_layer_mask(self, layer_id): 83 | if ((layer_id - 1) % 4 == 0): 84 | layer_mask = self.col_mask 85 | elif layer_id != self.num_layers - 1: 86 | layer_mask = self.row_mask 87 | else: 88 | layer_mask = self.conv_mask 89 | return layer_mask 90 | 91 | def forward(self, hidden_states, attention_mask, has_cache, use_cache): 92 | for i, layer in enumerate(self.layers): 93 | mask = attention_mask 94 | layer_mask = self._get_layer_mask(i)[:mask.size(2), :mask.size(3)] 95 | mask = torch.mul(attention_mask, layer_mask) 96 | hidden_states, present_has_cache = layer(hidden_states, mask, has_cache=has_cache, use_cache=use_cache) 97 | output = self.final_layernorm(hidden_states) 98 | return output, present_has_cache 99 | 100 | 101 | class DalleTransformerLayer(torch.nn.Module): 102 | """ 103 | A single layer transformer. 
104 | 105 | We use the following notation: 106 | h: hidden size 107 | n: number of attention heads 108 | b: batch size 109 | s: sequence length 110 | Transformer layer takes input with size [b, s, h] and returns an 111 | output of the same size. 112 | 113 | Arguments: 114 | hidden_size: The hidden size of the self attention. 115 | num_attention_heads: number of attention heads in the self 116 | attention. 117 | attention_dropout_prob: dropout probability of the attention 118 | score in self attention. 119 | output_dropout_prob: dropout probability for the outputs 120 | after self attention and final output. 121 | layernorm_epsilon: epsilon used in layernorm to avoid 122 | division by zero. 123 | """ 124 | 125 | def __init__(self, 126 | hidden_size, 127 | num_attention_heads, 128 | attention_dropout_prob, 129 | output_dropout_prob, 130 | layernorm_epsilon, 131 | cogview_sandwich_layernorm=False, 132 | cogview_pb_relax=False, 133 | mlp_activation='gelu_jit'): 134 | super(DalleTransformerLayer, self).__init__() 135 | # Extra options: cogview_sandwich_layernorm adds a LayerNorm on each sublayer output before its residual addition; cogview_pb_relax is passed on to DalleSelfAttention; mlp_activation is passed on to DalleMLP. 136 | # CogView stabilization of training features, see chapter 2.4 https://arxiv.org/pdf/2105.13290.pdf 137 | self.cogview_sandwich_layernorm = cogview_sandwich_layernorm 138 | self.cogview_pb_relax = cogview_pb_relax 139 | 140 | # Layernorm on the input data. 141 | self.input_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon) 142 | 143 | if self.cogview_sandwich_layernorm: 144 | self.before_first_addition_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon) 145 | self.before_second_addition_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon) 146 | 147 | # Self attention. 148 | self.attention = DalleSelfAttention( 149 | hidden_size, 150 | num_attention_heads, 151 | attention_dropout_prob, 152 | output_dropout_prob, 153 | cogview_pb_relax=cogview_pb_relax 154 | ) 155 | 156 | # Layernorm on the sum of the layer input and the attention output (this feeds the MLP).
157 | self.post_attention_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon) 158 | 159 | # MLP 160 | self.mlp = DalleMLP(hidden_size, output_dropout_prob, activation=mlp_activation) 161 | 162 | def forward(self, hidden_states, ltor_mask, has_cache, use_cache): 163 | # hidden_states: [b, s, h] 164 | # ltor_mask: [1, 1, s, s] 165 | # Returns (output, flag); the flag is True only if both the attention and the MLP sublayer return True. 166 | # Layer norm at the beginning of the transformer layer. 167 | layernorm_output = self.input_layernorm(hidden_states) 168 | 169 | # Self attention. 170 | attention_output, att_has_cache = self.attention( 171 | layernorm_output, ltor_mask, has_cache=has_cache, use_cache=use_cache) 172 | 173 | if self.cogview_sandwich_layernorm: 174 | attention_output = self.before_first_addition_layernorm(attention_output) 175 | 176 | # Residual connection. 177 | layernorm_input = hidden_states + attention_output 178 | 179 | # Layer norm after the first residual addition. 180 | layernorm_output = self.post_attention_layernorm(layernorm_input) 181 | 182 | # MLP. 183 | mlp_output, mlp_has_cache = self.mlp( 184 | layernorm_output, has_cache=has_cache, use_cache=use_cache) 185 | 186 | if self.cogview_sandwich_layernorm: 187 | mlp_output = self.before_second_addition_layernorm(mlp_output) 188 | 189 | # Second residual connection. 190 | output = layernorm_input + mlp_output 191 | 192 | return output, att_has_cache and mlp_has_cache 193 | 194 | 195 | class DalleSelfAttention(torch.nn.Module): 196 | """ 197 | Self-attention layer takes input with size [b, s, h] where b is 198 | the batch size, s is the sequence length, and h is the hidden size 199 | and creates output of the same size. 200 | Arguments: 201 | hidden_size: total hidden size of the layer (h). 202 | num_attention_heads: number of attention heads (n). We 203 | require hidden size to be divisible by n; each 204 | head gets divide(hidden_size, n) features. Heads 205 | are not partitioned across GPUs in this class. 206 | attention_dropout_prob: dropout probability for the attention scores.
207 | output_dropout_prob: dropout probability for the output. 208 | We use the following notation: 209 | h: hidden_size 210 | n: num_attention_heads 211 | p: number of partitions (1 here, so np == n and hp == h) 212 | np: n/p 213 | hp: h/p 214 | hn: h/n 215 | b: batch size 216 | s: sequence length 217 | """ 218 | 219 | def __init__(self, hidden_size, num_attention_heads, 220 | attention_dropout_prob, output_dropout_prob, cogview_pb_relax=False): 221 | super(DalleSelfAttention, self).__init__() 222 | # cogview_pb_relax: if True, _calculate_attention_scores uses CogView PB-relax (pre-scaled query, per-head max subtraction). 223 | # CogView stabilization of training features, see chapter 2.4 https://arxiv.org/pdf/2105.13290.pdf 224 | self.cogview_pb_relax = cogview_pb_relax 225 | 226 | self.hidden_size = hidden_size 227 | self.num_attention_heads = num_attention_heads 228 | self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads) 229 | # Fused projection; forward() splits its output into query, key and value. 230 | self.query_key_value = torch.nn.Linear(hidden_size, 3 * hidden_size) 231 | self.attention_dropout = torch.nn.Dropout(attention_dropout_prob) 232 | 233 | # Output. 234 | self.dense = torch.nn.Linear(hidden_size, hidden_size) 235 | self.output_dropout = torch.nn.Dropout(output_dropout_prob) 236 | 237 | # Cache kept between forward() calls: filled when use_cache=True, reset to None when use_cache=False. 238 | self.past_key = None 239 | self.past_value = None 240 | self.past_output = None 241 | 242 | def _transpose_for_scores(self, tensor): 243 | """ Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with size [b, np, s, hn].
""" 244 | new_tensor_shape = tensor.size()[:-1] + (self.num_attention_heads, self.hidden_size_per_attention_head) 245 | tensor = tensor.view(*new_tensor_shape) 246 | return tensor.permute(0, 2, 1, 3) 247 | 248 | def _calculate_attention_scores(self, query_layer, key_layer, ltor_mask): 249 | key_t = key_layer.transpose(-1, -2) 250 | if self.cogview_pb_relax: 251 | attention_scores = torch.matmul( 252 | query_layer / math.sqrt(self.hidden_size_per_attention_head), 253 | key_t 254 | ) 255 | else: 256 | attention_scores = torch.matmul(query_layer, key_t) / math.sqrt(self.hidden_size_per_attention_head) 257 | ltor_mask = ltor_mask[:, :, -attention_scores.shape[-2]:] # keep the last sq mask rows, matching the current query positions 258 | attention_scores = torch.mul(attention_scores, ltor_mask) - 10000.0 * (1.0 - ltor_mask) # NOTE(review): assumes a 0/1 mask; blocked (0) positions get a score of -10000 259 | if self.cogview_pb_relax: 260 | # PB-relax: subtract the per-(sample, head) max of scores / alpha (detached from autograd), then multiply back by alpha. A constant shift leaves the softmax unchanged up to rounding, and the largest score becomes 0. 261 | alpha = 32 262 | attention_scores_scaled = attention_scores / alpha 263 | attention_scores_scaled_maxes, _ = attention_scores_scaled.detach().view( 264 | [attention_scores.size(0), attention_scores.size(1), -1] 265 | ).max(dim=-1) # max per head per sample 266 | attention_scores_scaled_maxes = attention_scores_scaled_maxes.unsqueeze(-1).unsqueeze(-1).expand( 267 | [-1, -1, attention_scores.size(2), attention_scores.size(3)] 268 | ) # expand to [b, np, s, s] 269 | attention_scores = (attention_scores_scaled - attention_scores_scaled_maxes) * alpha 270 | return attention_scores 271 | 272 | def forward(self, hidden_states, ltor_mask, has_cache=False, use_cache=False, ): 273 | # hidden_states: [b, s, h] 274 | # ltor_mask: [1, 1, s, s]; returns (output [b, s, h], has_cache) 275 | # Attention heads.
[b, s, hp] 276 | if has_cache and use_cache: # with a cache, project only the positions after the cached prefix 277 | mixed_x_layer = self.query_key_value(hidden_states[:, self.past_key.shape[-2]:, :]) 278 | else: 279 | mixed_x_layer = self.query_key_value(hidden_states) 280 | 281 | (mixed_query_layer, 282 | mixed_key_layer, 283 | mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) 284 | # NOTE(review): split_tensor_along_last_dim is defined outside this view; presumably it returns 3 equal chunks along the last dim. 285 | query_layer = self._transpose_for_scores(mixed_query_layer) 286 | key_layer = self._transpose_for_scores(mixed_key_layer) 287 | value_layer = self._transpose_for_scores(mixed_value_layer) 288 | 289 | # With a cache, prepend the cached keys/values so the new queries attend over the whole prefix. The two branches differ only in that step and are kept separate for readability. 290 | if use_cache and has_cache: 291 | key_layer = torch.cat((self.past_key, key_layer), dim=-2) 292 | value_layer = torch.cat((self.past_value, value_layer), dim=-2) 293 | attention_scores = self._calculate_attention_scores( 294 | query_layer=query_layer, key_layer=key_layer, ltor_mask=ltor_mask 295 | ) 296 | else: 297 | attention_scores = self._calculate_attention_scores( 298 | query_layer=query_layer, key_layer=key_layer, ltor_mask=ltor_mask 299 | ) 300 | # NOTE(review): attention_scores already has exactly extra_cache_size rows (query_layer comes from the same slice of hidden_states), so this slice looks like a no-op. It must stay before the cache update below, which overwrites self.past_key. 301 | if use_cache and has_cache: 302 | extra_cache_size = hidden_states.shape[-2] - self.past_key.shape[-2] 303 | attention_scores = attention_scores[..., -extra_cache_size:, :] 304 | # Store the extended keys/values, or clear every cache of this layer when use_cache is False. 305 | if use_cache: 306 | self.past_key = key_layer 307 | self.past_value = value_layer 308 | else: 309 | self.past_key = None 310 | self.past_value = None 311 | self.past_output = None 312 | has_cache = False 313 | 314 | # Attention probabilities. [b, np, sq, sk]; sq <= sk, and sq covers only the new positions when a cache is reused. 315 | attention_probs = torch.nn.Softmax(dim=-1)(attention_scores) 316 | 317 | # This is actually dropping out entire tokens to attend to, which might 318 | # seem a bit unusual, but is taken from the original Transformer paper. 319 | attention_probs = self.attention_dropout(attention_probs) 320 | 321 | # Context layer.
322 | # [b, np, s, hn] 323 | context_layer = torch.matmul(attention_probs, value_layer) 324 | 325 | # [b, s, np, hn] 326 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 327 | 328 | new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,) 329 | # [b, s, hp] 330 | context_layer = context_layer.view(*new_context_layer_shape) 331 | 332 | # Output. [b, s, h] 333 | output = self.dense(context_layer) 334 | 335 | if use_cache: 336 | # With a cache, prepend the cached output of earlier positions so the result spans the full sequence; the pre-dropout result becomes the new cache. 337 | if has_cache: 338 | output = torch.cat((self.past_output, output), dim=-2) 339 | self.past_output = output 340 | else: 341 | self.past_output = output 342 | has_cache = True 343 | # Dropout is applied to the whole returned output, cached prefix included. 344 | output = self.output_dropout(output) 345 | return output, has_cache 346 | 347 | 348 | class DalleMLP(torch.nn.Module): 349 | """ 350 | MLP will take the input with h hidden state, project it to 4*h 351 | hidden dimension, perform gelu transformation, and project the 352 | state back into h hidden dimension. At the end, dropout is also 353 | applied. 354 | Arguments: 355 | hidden_size: The hidden size of the self attention. 356 | output_dropout_prob: dropout probability for the outputs 357 | after self attention and final output. 358 | """ 359 | 360 | def __init__(self, hidden_size, output_dropout_prob, activation='gelu_jit'): 361 | super(DalleMLP, self).__init__() 362 | self.activation = activation # 'gelu_jit' or 'gelu'; any other value makes forward() raise NotImplementedError 363 | # Project to 4h. 364 | self.dense_h_to_4h = torch.nn.Linear(hidden_size, 4 * hidden_size) 365 | # Project back to h.
366 | self.dense_4h_to_h = torch.nn.Linear(4 * hidden_size, hidden_size) 367 | self.dropout = torch.nn.Dropout(output_dropout_prob) 368 | # MLP cache of previous pre-dropout outputs; filled when use_cache=True, reset to None otherwise. 369 | self.past_x = None 370 | 371 | def forward(self, hidden_states, has_cache=False, use_cache=False): 372 | if has_cache and use_cache: # with a cache, recompute only the positions after the cached prefix 373 | hidden_states = hidden_states[:, self.past_x.shape[-2]:] 374 | 375 | # [b, s, 4hp] 376 | x = self.dense_h_to_4h(hidden_states) 377 | if self.activation == 'gelu_jit': 378 | x = gelu_jit(x) 379 | elif self.activation == 'gelu': 380 | x = gelu(x) 381 | else: 382 | raise NotImplementedError('Used MLP activation is not implemented.') 383 | # [b, s, h] 384 | x = self.dense_4h_to_h(x) 385 | if use_cache: 386 | # With a cache, prepend the cached rows so the output spans the full sequence, then store the result as the new cache. 387 | if has_cache: 388 | x = torch.cat((self.past_x, x), dim=-2) 389 | self.past_x = x 390 | else: 391 | self.past_x = x 392 | 393 | has_cache = True 394 | else: 395 | self.past_x = None 396 | has_cache = False 397 | output = self.dropout(x) 398 | 399 | return output, has_cache 400 | --------------------------------------------------------------------------------