├── tests
│   ├── __init__.py
│   ├── test_show.py
│   ├── test_super_resolution.py
│   ├── test_emojich_unet.py
│   ├── test_tokenizer.py
│   ├── test_image_prompts.py
│   ├── test_dalle.py
│   ├── conftest.py
│   └── test_vae.py
├── pics
│   ├── cat-ru.png
│   ├── man-0.png
│   ├── man-1.png
│   ├── man-2.png
│   ├── man-3.png
│   ├── man-4.png
│   ├── man-5.png
│   ├── man-ru.png
│   ├── cat-gan.png
│   ├── man-gan.png
│   ├── woman-gan.png
│   ├── woman-ru.png
│   ├── avocado-gan.png
│   ├── avocado-ru.png
│   ├── cat-diffusion.png
│   ├── cathedral-gan.png
│   ├── cathedral-ru.png
│   ├── emojich
│   │   ├── examples.png
│   │   ├── emoji-Donald.png
│   │   ├── emojich_rgba.png
│   │   ├── emojich-stickers.png
│   │   └── emojich_rgba_100.png
│   ├── woman-diffusion.png
│   ├── avocado-diffusion.png
│   ├── cathedral-diffusion.png
│   ├── malevich
│   │   ├── rainbow-full.png
│   │   ├── rainbow-cherry-pick.png
│   │   ├── rainbow-super-resolution.png
│   │   ├── anime-girl-super-resolution.png
│   │   └── russian-temple-image-prompt.png
│   ├── habr_eng.svg
│   └── habr.svg
├── requirements-test.txt
├── .coveragerc
├── requirements.txt
├── setup.cfg
├── .gitlab-ci.yml
├── .pre-commit-config.yaml
├── rudalle
│   ├── __init__.py
│   ├── vae
│   │   ├── vqgan.gumbelf8-sber.config.yml
│   │   ├── __init__.py
│   │   ├── decoder_dwt.py
│   │   ├── model.py
│   │   └── pytorch_wavelets_utils.py
│   ├── ruclip
│   │   ├── __init__.py
│   │   └── processor.py
│   ├── utils.py
│   ├── realesrgan
│   │   ├── __init__.py
│   │   ├── model.py
│   │   ├── utils.py
│   │   ├── rrdbnet_arch.py
│   │   └── arch_util.py
│   ├── emojich_unet
│   │   └── __init__.py
│   ├── dalle
│   │   ├── utils.py
│   │   ├── image_attention.py
│   │   ├── fp16.py
│   │   ├── __init__.py
│   │   ├── model.py
│   │   └── transformer.py
│   ├── tokenizer.py
│   ├── image_prompts.py
│   └── pipelines.py
├── setup.py
├── .gitignore
├── README.md
├── Emojich.md
└── LICENSE.txt
/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/pics/cat-ru.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/cat-ru.png
--------------------------------------------------------------------------------
/pics/man-0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/man-0.png
--------------------------------------------------------------------------------
/pics/man-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/man-1.png
--------------------------------------------------------------------------------
/pics/man-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/man-2.png
--------------------------------------------------------------------------------
/pics/man-3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/man-3.png
--------------------------------------------------------------------------------
/pics/man-4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/man-4.png
--------------------------------------------------------------------------------
/pics/man-5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/man-5.png
--------------------------------------------------------------------------------
/pics/man-ru.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/man-ru.png
--------------------------------------------------------------------------------
/pics/cat-gan.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/cat-gan.png
--------------------------------------------------------------------------------
/pics/man-gan.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/man-gan.png
--------------------------------------------------------------------------------
/pics/woman-gan.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/woman-gan.png
--------------------------------------------------------------------------------
/pics/woman-ru.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/woman-ru.png
--------------------------------------------------------------------------------
/requirements-test.txt:
--------------------------------------------------------------------------------
1 | -r requirements.txt
2 | pytest
3 | pytest-cov
4 | pre-commit
5 |
--------------------------------------------------------------------------------
/pics/avocado-gan.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/avocado-gan.png
--------------------------------------------------------------------------------
/pics/avocado-ru.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/avocado-ru.png
--------------------------------------------------------------------------------
/pics/cat-diffusion.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/cat-diffusion.png
--------------------------------------------------------------------------------
/pics/cathedral-gan.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/cathedral-gan.png
--------------------------------------------------------------------------------
/pics/cathedral-ru.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/cathedral-ru.png
--------------------------------------------------------------------------------
/pics/emojich/examples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/emojich/examples.png
--------------------------------------------------------------------------------
/pics/woman-diffusion.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/woman-diffusion.png
--------------------------------------------------------------------------------
/pics/avocado-diffusion.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/avocado-diffusion.png
--------------------------------------------------------------------------------
/pics/cathedral-diffusion.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/cathedral-diffusion.png
--------------------------------------------------------------------------------
/.coveragerc:
--------------------------------------------------------------------------------
1 | [run]
2 | omit =
3 | # omit this single file
4 | rudalle/vae/pytorch_wavelets_utils.py
5 |
--------------------------------------------------------------------------------
/pics/emojich/emoji-Donald.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/emojich/emoji-Donald.png
--------------------------------------------------------------------------------
/pics/emojich/emojich_rgba.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/emojich/emojich_rgba.png
--------------------------------------------------------------------------------
/pics/malevich/rainbow-full.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/malevich/rainbow-full.png
--------------------------------------------------------------------------------
/pics/emojich/emojich-stickers.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/emojich/emojich-stickers.png
--------------------------------------------------------------------------------
/pics/emojich/emojich_rgba_100.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/emojich/emojich_rgba_100.png
--------------------------------------------------------------------------------
/pics/malevich/rainbow-cherry-pick.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/malevich/rainbow-cherry-pick.png
--------------------------------------------------------------------------------
/pics/malevich/rainbow-super-resolution.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/malevich/rainbow-super-resolution.png
--------------------------------------------------------------------------------
/pics/malevich/anime-girl-super-resolution.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/malevich/anime-girl-super-resolution.png
--------------------------------------------------------------------------------
/pics/malevich/russian-temple-image-prompt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/ru-dalle/HEAD/pics/malevich/russian-temple-image-prompt.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | taming-transformers==0.0.1
2 | more_itertools~=8.10.0
3 | transformers~=4.10.2
4 | youtokentome~=1.0.6
5 | omegaconf>=2.0.0
6 | einops~=0.3.2
7 | PyWavelets==1.1.1
8 | segmentation-models-pytorch==0.1.3
9 | opencv-python==4.5.4.60
10 | torch
11 | torchvision
12 | matplotlib
13 |
--------------------------------------------------------------------------------
/tests/test_show.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from rudalle.pipelines import show
3 |
4 |
5 | def test_show(sample_image):
6 | img = sample_image.copy()
7 | img = img.resize((256, 256))
8 | pil_images = [img]*5
9 | show(pil_images, nrow=2, save_dir='/tmp/pics', show=False)
10 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [pep8]
2 | max-line-length = 120
3 | exclude = .tox,*migrations*,.json
4 |
5 | [flake8]
6 | max-line-length = 120
7 | exclude = .tox,*migrations*,.json
8 |
9 | [autopep8-wrapper]
10 | exclude = .tox,*migrations*,.json
11 |
12 | [check-docstring-first]
13 | exclude = .tox,*migrations*,.json
14 |
--------------------------------------------------------------------------------
/tests/test_super_resolution.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from rudalle.pipelines import super_resolution
3 |
4 |
5 | def test_super_resolution(sample_image, realesrgan):
6 | img = sample_image.copy()
7 | img = img.resize((32, 32))
8 | sr_img = super_resolution([img], realesrgan)[0]
9 | assert sr_img.size[0] == 32*2
10 | assert sr_img.size[1] == 32*2
11 |
--------------------------------------------------------------------------------
/.gitlab-ci.yml:
--------------------------------------------------------------------------------
1 | stages:
2 | - test
3 |
4 | all_branch_test:
5 | stage: test
6 | tags:
7 | - docker
8 | image: python:3.9
9 | script:
10 | - apt-get update ##[edited]
11 | - apt-get install ffmpeg libsm6 libxext6 -y
12 | - pip install cython
13 | - pip install -r requirements-test.txt --no-cache-dir
14 | - pip install timm==0.4.12
15 | - pip install codecov
16 | - pytest --cov=rudalle tests/
17 | - bash <(curl -s https://codecov.io/bash) -t $CODECOV_TOKEN
18 | except:
19 | - tags
20 |
--------------------------------------------------------------------------------
/tests/test_emojich_unet.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import numpy as np
3 |
4 | from rudalle.pipelines import convert_emoji_to_rgba
5 |
6 |
7 | def test_convert_emoji_to_rgba(sample_image, emojich_unet):
8 | img = sample_image.copy()
9 | img = img.resize((512, 512))
10 | rgba_images, runs = convert_emoji_to_rgba([img], emojich_unet, score_thr=0.99)
11 | assert len(runs) == len(rgba_images)
12 | rgba_img = rgba_images[0]
13 | assert rgba_img.size[0] == 512
14 | assert rgba_img.size[1] == 512
15 | assert np.array(rgba_img).shape[-1] == 4
16 | assert runs[0] in ['unet', 'classic']
17 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v4.0.1
4 | hooks:
5 | - id: check-docstring-first
6 | - id: check-merge-conflict
7 | stages:
8 | - push
9 | - id: double-quote-string-fixer
10 | - id: end-of-file-fixer
11 | - id: fix-encoding-pragma
12 | - id: mixed-line-ending
13 | - id: trailing-whitespace
14 | - repo: https://github.com/pycqa/flake8
15 | rev: "4.0.1"
16 | hooks:
17 | - id: flake8
18 | args: ['--config=setup.cfg']
19 | - repo: https://github.com/pre-commit/mirrors-autopep8
20 | rev: v1.5.7
21 | hooks:
22 | - id: autopep8
23 |
--------------------------------------------------------------------------------
/rudalle/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from .vae import get_vae
3 | from .dalle import get_rudalle_model
4 | from .tokenizer import get_tokenizer
5 | from .realesrgan import get_realesrgan
6 | from .ruclip import get_ruclip
7 | from .emojich_unet import get_emojich_unet
8 | from . import vae, dalle, tokenizer, realesrgan, pipelines, ruclip, image_prompts
9 |
10 |
11 | __all__ = [
12 | 'get_vae',
13 | 'get_rudalle_model',
14 | 'get_tokenizer',
15 | 'get_realesrgan',
16 | 'get_ruclip',
17 | 'get_emojich_unet',
18 | 'vae',
19 | 'dalle',
20 | 'ruclip',
21 | 'tokenizer',
22 | 'realesrgan',
23 | 'pipelines',
24 | 'image_prompts',
25 | ]
26 |
27 | __version__ = '0.4.0'
28 |
--------------------------------------------------------------------------------
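
rudalle/__init__.py above is the package's whole public surface. A minimal CPU-only loading sketch, mirroring tests/conftest.py (pretrained=False skips the weight downloads; 'small' is the test-sized DALL-E configuration):

```python
# Sketch based on tests/conftest.py; everything loads on CPU.
from rudalle import get_tokenizer, get_rudalle_model, get_vae, get_realesrgan

tokenizer = get_tokenizer()                      # downloads bpe.model to /tmp/rudalle/tokenizer
vae = get_vae(pretrained=False)                  # VQGAN architecture without checkpoint weights
dalle = get_rudalle_model('small', pretrained=False, fp16=False, device='cpu')
realesrgan = get_realesrgan('x2', device='cpu')  # downloads RealESRGAN_x2.pth
```
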
/tests/test_tokenizer.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import pytest
3 |
4 |
5 | @pytest.mark.parametrize('text, text_seq_length, bpe_dropout', [
6 | ('hello, how are you?', 128, 0.1),
7 | ('hello, how are you?', 128, 0.5),
8 | ('hello, how are you?', 128, 1.0),
9 | ('hello ... how are you ?', 256, 1.0),
10 | ('a person standing at a table with bottles of win', 64, 0.5),
11 | ('привет как дела???', 76, 0.0),
12 | ('клип на русском языке :)', 76, 0.1),
13 | ])
14 | def test_encode_decode_text_yttm(yttm_tokenizer, text, text_seq_length, bpe_dropout):
15 | tokens = yttm_tokenizer.encode_text(text, text_seq_length=text_seq_length, bpe_dropout=bpe_dropout)
16 | decoded_text = yttm_tokenizer.decode_text(tokens)
17 | assert text == decoded_text
18 |
--------------------------------------------------------------------------------
/tests/test_image_prompts.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import pytest
3 |
4 | from rudalle.image_prompts import ImagePrompts
5 |
6 |
7 | @pytest.mark.parametrize('borders, crop_first', [
8 | ({'up': 4, 'right': 0, 'left': 0, 'down': 0}, False),
9 | ({'up': 4, 'right': 0, 'left': 0, 'down': 0}, True),
10 | ({'up': 4, 'right': 3, 'left': 3, 'down': 3}, False)
11 | ])
12 | def test_image_prompts(sample_image, vae, borders, crop_first):
13 | img = sample_image.copy()
14 | img = img.resize((256, 256))
15 | image_prompt = ImagePrompts(img, borders, vae, crop_first=crop_first)
16 | assert image_prompt.image_prompts.shape[1] == 32 * 32
17 | assert len(image_prompt.image_prompts_idx) == (borders['up'] + borders['down']) * 32 \
18 | + (borders['left'] + borders['right']) * (32 - borders['up'] - borders['down'])
19 |
--------------------------------------------------------------------------------
/rudalle/vae/vqgan.gumbelf8-sber.config.yml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: taming.models.vqgan.GumbelVQ
4 | params:
5 | kl_weight: 1.0e-08
6 | embed_dim: 256
7 | n_embed: 8192
8 | monitor: val/rec_loss
9 | temperature_scheduler_config:
10 | target: taming.lr_scheduler.LambdaWarmUpCosineScheduler
11 | params:
12 | warm_up_steps: 0
13 | max_decay_steps: 1000001
14 | lr_start: 0.9
15 | lr_max: 0.9
16 | lr_min: 1.0e-06
17 | ddconfig:
18 | double_z: false
19 | z_channels: 256
20 | resolution: 256
21 | in_channels: 3
22 | out_ch: 3
23 | ch: 128
24 | ch_mult:
25 | - 1
26 | - 1
27 | - 2
28 | - 4
29 | num_res_blocks: 2
30 | attn_resolutions:
31 | - 32
32 | dropout: 0.0
33 | lossconfig:
34 | target: taming.modules.losses.vqperceptual.DummyLoss
35 |
--------------------------------------------------------------------------------
/rudalle/ruclip/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 |
4 | from transformers import CLIPModel
5 | from huggingface_hub import hf_hub_url, cached_download
6 |
7 | from .processor import RuCLIPProcessor
8 |
9 | MODELS = {
10 | 'ruclip-vit-base-patch32-v5': dict(
11 | repo_id='sberbank-ai/ru-clip',
12 | filenames=[
13 | 'bpe.model', 'config.json', 'pytorch_model.bin'
14 | ]
15 | ),
16 | }
17 |
18 |
19 | def get_ruclip(name, cache_dir='/tmp/rudalle'):
20 | assert name in MODELS
21 | config = MODELS[name]
22 | repo_id = config['repo_id']
23 | cache_dir = os.path.join(cache_dir, name)
24 | for filename in config['filenames']:
25 | config_file_url = hf_hub_url(repo_id=repo_id, filename=f'{name}/{filename}')
26 | cached_download(config_file_url, cache_dir=cache_dir, force_filename=filename)
27 | ruclip = CLIPModel.from_pretrained(cache_dir)
28 | ruclip_processor = RuCLIPProcessor.from_pretrained(cache_dir)
29 | print('ruclip --> ready')
30 | return ruclip, ruclip_processor
31 |
--------------------------------------------------------------------------------
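
A usage sketch for the loader above; it needs network access to fetch the ru-clip files from the Hugging Face hub, the image path is a placeholder, and the processor call follows rudalle/ruclip/processor.py shown further below.

```python
# Sketch: score Russian captions against an image with ruCLIP.
import PIL.Image
import torch

from rudalle import get_ruclip

ruclip, ruclip_processor = get_ruclip('ruclip-vit-base-patch32-v5')
img = PIL.Image.open('example.png')              # placeholder path
# 'a cat in the snow', 'a dog in the forest'
inputs = ruclip_processor(text=['кот на снегу', 'собака в лесу'], images=[img])
with torch.no_grad():
    outputs = ruclip(**inputs)                   # transformers CLIPModel forward
    probs = outputs.logits_per_image.softmax(dim=-1)
print(probs)                                     # similarity of the image to each caption
```
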
/rudalle/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import random
4 |
5 | import torch
6 | import torchvision
7 | import numpy as np
8 |
9 |
10 | def seed_everything(seed):
11 | random.seed(seed)
12 | os.environ['PYTHONHASHSEED'] = str(seed)
13 | np.random.seed(seed)
14 | torch.manual_seed(seed)
15 | torch.cuda.manual_seed(seed)
16 | torch.backends.cudnn.deterministic = True
17 | torch.backends.cudnn.benchmark = True
18 |
19 |
20 | def torch_tensors_to_pil_list(input_images):
21 | out_images = []
22 | for in_image in input_images:
23 | in_image = in_image.cpu().detach()
24 | out_image = torchvision.transforms.functional.to_pil_image(in_image).convert('RGB')
25 | out_images.append(out_image)
26 | return out_images
27 |
28 |
29 | def pil_list_to_torch_tensors(pil_images):
30 | result = []
31 | for pil_image in pil_images:
32 | image = np.array(pil_image, dtype=np.uint8)
33 | image = torch.from_numpy(image)
34 | image = image.permute(2, 0, 1).unsqueeze(0)
35 | result.append(image)
36 | return torch.cat(result, dim=0)
37 |
--------------------------------------------------------------------------------
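
A round-trip sketch for the helpers above (the image path is a placeholder): pil_list_to_torch_tensors stacks PIL images into a uint8 NCHW batch, and torch_tensors_to_pil_list converts such a batch back.

```python
# Sketch for rudalle/utils.py; 'example.png' is a placeholder path.
import PIL.Image

from rudalle.utils import seed_everything, pil_list_to_torch_tensors, torch_tensors_to_pil_list

seed_everything(42)
img = PIL.Image.open('example.png').convert('RGB').resize((256, 256))
batch = pil_list_to_torch_tensors([img, img])    # torch.uint8 tensor, shape (2, 3, 256, 256)
restored = torch_tensors_to_pil_list(batch)      # back to a list of RGB PIL images
assert len(restored) == 2 and restored[0].size == (256, 256)
```
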
/rudalle/realesrgan/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 |
4 | from huggingface_hub import hf_hub_url, cached_download
5 |
6 | from .model import RealESRGAN
7 |
8 |
9 | MODELS = {
10 | 'x2': dict(
11 | scale=2,
12 | repo_id='shonenkov/rudalle-utils',
13 | filename='RealESRGAN_x2.pth',
14 | ),
15 | 'x4': dict(
16 | scale=4,
17 | repo_id='shonenkov/rudalle-utils',
18 | filename='RealESRGAN_x4.pth',
19 | ),
20 | 'x8': dict(
21 | scale=8,
22 | repo_id='shonenkov/rudalle-utils',
23 | filename='RealESRGAN_x8.pth',
24 | ),
25 | }
26 |
27 |
28 | def get_realesrgan(name, device='cpu', fp16=False, cache_dir='/tmp/rudalle'):
29 | assert name in MODELS
30 | config = MODELS[name]
31 | model = RealESRGAN(device, config['scale'], fp16=fp16)
32 | cache_dir = os.path.join(cache_dir, name)
33 | config_file_url = hf_hub_url(repo_id=config['repo_id'], filename=config['filename'])
34 | cached_download(config_file_url, cache_dir=cache_dir, force_filename=config['filename'])
35 | model.load_weights(os.path.join(cache_dir, config['filename']))
36 | print(f'{name} --> ready')
37 | return model
38 |
--------------------------------------------------------------------------------
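
An upscaling sketch for the loader above, following tests/test_super_resolution.py (the image path is a placeholder):

```python
# Sketch based on tests/test_super_resolution.py.
import PIL.Image

from rudalle import get_realesrgan
from rudalle.pipelines import super_resolution

realesrgan = get_realesrgan('x2', device='cpu')          # 2x model, CPU inference
img = PIL.Image.open('example.png').convert('RGB').resize((32, 32))
sr_img = super_resolution([img], realesrgan)[0]
assert sr_img.size == (64, 64)                           # each side scaled by 2
```
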
/rudalle/vae/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from os.path import dirname, abspath, join
3 |
4 | import torch
5 | from huggingface_hub import hf_hub_url, cached_download
6 | from omegaconf import OmegaConf
7 |
8 | from .model import VQGanGumbelVAE
9 |
10 |
11 | def get_vae(pretrained=True, dwt=False, cache_dir='/tmp/rudalle'):
12 | # TODO
13 | config = OmegaConf.load(join(dirname(abspath(__file__)), 'vqgan.gumbelf8-sber.config.yml'))
14 | vae = VQGanGumbelVAE(config, dwt=dwt)
15 | if pretrained:
16 | repo_id = 'shonenkov/rudalle-utils'
17 | if dwt:
18 | filename = 'vqgan.gumbelf8-sber-dwt.model.ckpt'
19 | else:
20 | filename = 'vqgan.gumbelf8-sber.model.ckpt'
21 | cache_dir = join(cache_dir, 'vae')
22 | config_file_url = hf_hub_url(repo_id=repo_id, filename=filename)
23 | cached_download(config_file_url, cache_dir=cache_dir, force_filename=filename)
24 | checkpoint = torch.load(join(cache_dir, filename), map_location='cpu')
25 | if dwt:
26 | vae.load_state_dict(checkpoint['state_dict'])
27 | else:
28 | vae.model.load_state_dict(checkpoint['state_dict'], strict=False)
29 | print('vae --> ready')
30 | return vae
31 |
--------------------------------------------------------------------------------
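
An encode/decode sketch for get_vae, following tests/test_vae.py and tests/test_dalle.py: a 256x256 image maps to a 32x32 grid of codebook indices (1024 tokens) and decodes back to 256x256. The image path is a placeholder; pretrained=True downloads the checkpoint.

```python
# Sketch based on tests/test_vae.py and tests/test_dalle.py.
import PIL.Image
import torch
import torchvision.transforms as T

from rudalle import get_vae

vae = get_vae(pretrained=True)
img = PIL.Image.open('example.png').convert('RGB').resize((256, 256))  # placeholder path
x = T.ToTensor()(img).unsqueeze(0)               # (1, 3, 256, 256), values in [0, 1]
with torch.no_grad():
    codes = vae.get_codebook_indices(x)          # (1, 1024) discrete token ids
    recon = vae.decode(codes)                    # (1, 3, 256, 256) reconstruction
```
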
/tests/test_dalle.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 | import pytest
4 |
5 | from .test_vae import preprocess
6 |
7 |
8 | @pytest.mark.parametrize('text', [
9 | 'мальчик играет с оленем',
10 | ])
11 | def test_forward_step_and_criterion(text, sample_image, yttm_tokenizer, vae, small_dalle):
12 | bs = 4
13 | text_seq_length = small_dalle.get_param('text_seq_length')
14 | total_seq_length = small_dalle.get_param('total_seq_length')
15 | device = small_dalle.get_param('device')
16 |
17 | img = sample_image.copy()
18 | img = preprocess(img, target_image_size=256)
19 | images = img.repeat(bs, 1, 1, 1).to(device)
20 |
21 | text = text.lower().strip()
22 | text_input_ids = yttm_tokenizer.encode_text(text, text_seq_length=text_seq_length)
23 | text_input_ids = text_input_ids.unsqueeze(0).repeat(bs, 1).to(device)
24 |
25 | attention_mask = torch.tril(torch.ones((bs, 1, total_seq_length, total_seq_length), device=device))
26 | with torch.no_grad():
27 | image_input_ids = vae.get_codebook_indices(images)
28 | input_ids = torch.cat((text_input_ids, image_input_ids), dim=1)
29 | loss, loss_values = small_dalle.forward(input_ids, attention_mask, return_loss=True)
30 | assert type(loss.data.detach().item()) == float
31 | assert type(loss_values) == dict
32 |
--------------------------------------------------------------------------------
/pics/habr_eng.svg:
--------------------------------------------------------------------------------
1 |
7 |
--------------------------------------------------------------------------------
/rudalle/emojich_unet/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 |
4 | import torch
5 | from huggingface_hub import hf_hub_url, cached_download
6 |
7 |
8 | MODELS = {
9 | 'unet_effnetb5': dict(
10 | encoder_name='efficientnet-b5',
11 | repo_id='sberbank-ai/rudalle-Emojich',
12 | filename='pytorch_model_v2.bin',
13 | classes=2,
14 | ),
15 | }
16 |
17 |
18 | def get_emojich_unet(name, cache_dir='/tmp/rudalle'):
19 | assert name in MODELS
20 | config = MODELS[name]
21 | try:
22 | import segmentation_models_pytorch as smp
23 | except ImportError:
24 | import logging
25 | logging.warning('If you would like to use emojich_unet, you should reinstall timm package:'
26 | '"pip install timm==0.4.12"')
27 | return
28 | model = smp.Unet(
29 | encoder_name=config['encoder_name'],
30 | encoder_weights=None,
31 | in_channels=3,
32 | classes=config['classes'],
33 | )
34 | cache_dir = os.path.join(cache_dir, name)
35 | filename = config['filename']
36 | config_file_url = hf_hub_url(repo_id=config['repo_id'], filename=f'{name}/{filename}')
37 | cached_download(config_file_url, cache_dir=cache_dir, force_filename=filename)
38 | checkpoint = torch.load(os.path.join(cache_dir, config['filename']), map_location='cpu')
39 | model.load_state_dict(checkpoint)
40 | print(f'{name} --> ready')
41 | return model
42 |
--------------------------------------------------------------------------------
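
A background-removal sketch for get_emojich_unet, following tests/test_emojich_unet.py; segmentation_models_pytorch needs timm==0.4.12, and the emoji image path is a placeholder.

```python
# Sketch based on tests/test_emojich_unet.py.
import PIL.Image

from rudalle import get_emojich_unet
from rudalle.pipelines import convert_emoji_to_rgba

emojich_unet = get_emojich_unet('unet_effnetb5')  # returns None if segmentation_models_pytorch is missing
img = PIL.Image.open('emoji.png').convert('RGB').resize((512, 512))    # placeholder path
rgba_images, runs = convert_emoji_to_rgba([img], emojich_unet, score_thr=0.99)
print(rgba_images[0].mode, runs[0])               # 'RGBA', and either 'unet' or 'classic'
```
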
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import io
3 | from os.path import abspath, dirname
4 |
5 | import PIL
6 | import pytest
7 | import requests
8 |
9 | from rudalle import get_tokenizer, get_rudalle_model, get_vae, get_realesrgan, get_emojich_unet
10 |
11 |
12 | TEST_ROOT = dirname(abspath(__file__))
13 |
14 |
15 | @pytest.fixture(scope='module')
16 | def realesrgan():
17 | realesrgan = get_realesrgan('x2', device='cpu')
18 | yield realesrgan
19 |
20 |
21 | @pytest.fixture(scope='module')
22 | def vae():
23 | vae = get_vae(pretrained=False)
24 | yield vae
25 |
26 |
27 | @pytest.fixture(scope='module')
28 | def dwt_vae():
29 | vae = get_vae(pretrained=False, dwt=True)
30 | yield vae
31 |
32 |
33 | @pytest.fixture(scope='module')
34 | def yttm_tokenizer():
35 | tokenizer = get_tokenizer()
36 | yield tokenizer
37 |
38 |
39 | @pytest.fixture(scope='module')
40 | def sample_image():
41 | url = 'https://cdn.kqed.org/wp-content/uploads/sites/12/2013/12/rudolph.png'
42 | resp = requests.get(url)
43 | resp.raise_for_status()
44 | image = PIL.Image.open(io.BytesIO(resp.content))
45 | yield image
46 |
47 |
48 | @pytest.fixture(scope='module')
49 | def small_dalle():
50 | model = get_rudalle_model('small', pretrained=False, fp16=False, device='cpu')
51 | yield model
52 |
53 |
54 | @pytest.fixture(scope='module')
55 | def emojich_unet():
56 | model = get_emojich_unet('unet_effnetb5')
57 | yield model
58 |
--------------------------------------------------------------------------------
/pics/habr.svg:
--------------------------------------------------------------------------------
1 |
7 |
--------------------------------------------------------------------------------
/rudalle/dalle/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 |
4 |
5 | def exists(val):
6 | return val is not None
7 |
8 |
9 | def is_empty(t):
10 | return t.nelement() == 0
11 |
12 |
13 | def ensure_divisibility(numerator, denominator):
14 | """Ensure that numerator is divisible by the denominator."""
15 | assert numerator % denominator == 0, '{} is not divisible by {}'.format(
16 | numerator, denominator)
17 |
18 |
19 | def divide(numerator, denominator):
20 | """Ensure that numerator is divisible by the denominator and return
21 | the division value."""
22 | ensure_divisibility(numerator, denominator)
23 | return numerator // denominator
24 |
25 |
26 | def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
27 | """
28 | Split a tensor along its last dimension.
29 | Arguments:
30 | tensor: input tensor.
31 | num_partitions: number of partitions to split the tensor
32 | contiguous_split_chunks: If True, make each chunk contiguous
33 | in memory.
34 | """
35 | # Get the size and dimension.
36 | last_dim = tensor.dim() - 1
37 | last_dim_size = divide(tensor.size()[last_dim], num_partitions)
38 | # Split.
39 | tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
40 | # Note: torch.split does not create contiguous tensors by default.
41 | if contiguous_split_chunks:
42 | return tuple(chunk.contiguous() for chunk in tensor_list)
43 | return tensor_list
44 |
45 |
46 | def init_method_normal(std=0.02):
47 | """Init method based on normal distribution.
48 |
49 | This is only used for embeddings. The transformer has its
50 | own initializer.
51 | """
52 | def init_(tensor):
53 | return torch.nn.init.normal_(tensor, mean=0.0, std=std)
54 | return init_
55 |
--------------------------------------------------------------------------------
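
A tiny sketch for the Megatron-style helpers above: divide checks divisibility and split_tensor_along_last_dim cuts a tensor into equal chunks along its last dimension (e.g. a fused QKV projection).

```python
# Sketch for rudalle/dalle/utils.py.
import torch

from rudalle.dalle.utils import divide, split_tensor_along_last_dim

x = torch.randn(2, 8, 12)                        # last dim divisible by 3
q, k, v = split_tensor_along_last_dim(x, 3, contiguous_split_chunks=True)
assert q.shape == (2, 8, divide(12, 3))          # each chunk is (2, 8, 4)
```
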
/rudalle/dalle/image_attention.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 |
5 |
6 | def _init_mask(text_tokens, image_tokens_per_dim, is_bool_mask=False):
7 | attn_size = text_tokens + image_tokens_per_dim**2
8 | mask = torch.tril(torch.ones(attn_size, attn_size, dtype=torch.bool if is_bool_mask else torch.float32))
9 | return mask
10 |
11 |
12 | def get_row_mask(text_tokens=256, image_tokens_per_dim=32, is_bool_mask=False):
13 | mask = _init_mask(text_tokens, image_tokens_per_dim, is_bool_mask=is_bool_mask)
14 | step = image_tokens_per_dim + 1
15 | for col in range(text_tokens, mask.size(1)):
16 | mask[col + step:, col] = False if is_bool_mask else 0.0
17 | return mask
18 |
19 |
20 | def get_col_mask(text_tokens=256, image_tokens_per_dim=32, is_bool_mask=False):
21 | mask = _init_mask(text_tokens, image_tokens_per_dim, is_bool_mask=is_bool_mask)
22 | step = image_tokens_per_dim - 1
23 | for col in range(text_tokens, mask.size(1)):
24 | for i in range(1, mask.size(0), step+1):
25 | mask[col + i: col + i + step, col] = False if is_bool_mask else 0.0
26 | return mask
27 |
28 |
29 | def get_conv_mask(text_tokens=256, image_tokens_per_dim=32, kernel=11, is_bool_mask=False):
30 | mask = _init_mask(text_tokens, image_tokens_per_dim, is_bool_mask=is_bool_mask)
31 | shift = kernel // 2
32 | for pos in range(text_tokens, mask.size(1)):
33 | mask[pos+1:, pos] = False if is_bool_mask else 0.0
34 | img = torch.zeros(image_tokens_per_dim, image_tokens_per_dim)
35 | pixel_id = pos - text_tokens
36 | row = pixel_id // image_tokens_per_dim
37 | col = pixel_id % image_tokens_per_dim
38 | for r in range(-shift, shift+1):
39 | for c in range(-shift, shift+1):
40 | c_abs = (c + col) % image_tokens_per_dim
41 | r_abs = (r + row) % image_tokens_per_dim
42 | img[r_abs, c_abs] = 0.2
43 | cell_id = r_abs * image_tokens_per_dim + c_abs
44 | if text_tokens + cell_id > pos:
45 | mask[text_tokens + cell_id, pos] = True if is_bool_mask else 1.0
46 |
47 | img[row, col] = 1.0
48 | return mask
49 |
--------------------------------------------------------------------------------
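
The three builders above produce the sparse attention patterns over the concatenated text+image sequence; each mask starts from a causal lower-triangular base and is then sparsified (row/column) or locally widened (conv). With the defaults (256 text tokens, a 32x32 image grid) every mask is (1280, 1280). A quick shape check:

```python
# Sketch for rudalle/dalle/image_attention.py; kernel=11 matches the conv-mask default.
from rudalle.dalle.image_attention import get_row_mask, get_col_mask, get_conv_mask

row = get_row_mask(text_tokens=256, image_tokens_per_dim=32)
col = get_col_mask(text_tokens=256, image_tokens_per_dim=32)
conv = get_conv_mask(text_tokens=256, image_tokens_per_dim=32, kernel=11)
assert row.shape == col.shape == conv.shape == (1280, 1280)   # 256 + 32 * 32
```
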
/setup.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import re
4 | from setuptools import setup
5 |
6 |
7 | def read(filename):
8 | with open(os.path.join(os.path.dirname(__file__), filename)) as f:
9 | file_content = f.read()
10 | return file_content
11 |
12 |
13 | def get_requirements():
14 | requirements = []
15 | for requirement in read('requirements.txt').splitlines():
16 | if requirement.startswith('git+') or requirement.startswith('svn+') or requirement.startswith('hg+'):
17 | parsed_requires = re.findall(r'#egg=([\w\d\.]+)-([\d\.]+)$', requirement)
18 | if parsed_requires:
19 | package, version = parsed_requires[0]
20 | requirements.append(f'{package}=={version}')
21 | else:
22 | print('WARNING! For correct matching dependency links need to specify package name and version'
23 | 'such as #egg=-')
24 | else:
25 | requirements.append(requirement)
26 | return requirements
27 |
28 |
29 | def get_links():
30 | return [
31 | requirement for requirement in read('requirements.txt').splitlines()
32 | if requirement.startswith('git+') or requirement.startswith('svn+') or requirement.startswith('hg+')
33 | ]
34 |
35 |
36 | def get_version():
37 | """ Get version from the package without actually importing it. """
38 | init = read('rudalle/__init__.py')
39 | for line in init.split('\n'):
40 | if line.startswith('__version__'):
41 | return eval(line.split('=')[1])
42 |
43 |
44 | setup(
45 | name='rudalle',
46 | version=get_version(),
47 | author='SberAI, SberDevices',
48 | author_email='shonenkov@phystech.edu',
49 | description='ruDALL-E generate images from texts in Russian language',
50 | packages=['rudalle', 'rudalle/dalle', 'rudalle/realesrgan', 'rudalle/ruclip', 'rudalle/vae',
51 | 'rudalle/emojich_unet'],
52 | package_data={'rudalle/vae': ['*.yml']},
53 | install_requires=get_requirements(),
54 | dependency_links=get_links(),
55 | long_description=read('README.md'),
56 | long_description_content_type='text/markdown',
57 | )
58 |
--------------------------------------------------------------------------------
/tests/test_vae.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import PIL
3 | import pytest
4 | import torch
5 | import torchvision.transforms as T
6 | import torchvision.transforms.functional as TF
7 |
8 |
9 | @pytest.mark.parametrize('target_image_size', [128, 192, 256])
10 | def test_decode_vae(vae, sample_image, target_image_size):
11 | img = sample_image.copy()
12 | img = preprocess(img, target_image_size=target_image_size)
13 | with torch.no_grad():
14 | img_seq = vae.get_codebook_indices(img)
15 | out_img = vae.decode(img_seq)
16 | assert out_img.shape == (1, 3, target_image_size, target_image_size)
17 |
18 |
19 | @pytest.mark.parametrize('target_image_size', [128, 192, 256])
20 | def test_reconstruct_vae(vae, sample_image, target_image_size):
21 | img = sample_image.copy()
22 | with torch.no_grad():
23 | x_vqgan = preprocess(img, target_image_size=target_image_size)
24 | output = reconstruct_with_vqgan(preprocess_vqgan(x_vqgan), vae.model)
25 | assert output.shape == (1, 3, target_image_size, target_image_size)
26 |
27 |
28 | @pytest.mark.parametrize('target_image_size', [256])
29 | def test_reconstruct_dwt_vae(dwt_vae, sample_image, target_image_size):
30 | img = sample_image.copy()
31 | with torch.no_grad():
32 | x_vqgan = preprocess(img, target_image_size=target_image_size)
33 | output = reconstruct_with_vqgan(preprocess_vqgan(x_vqgan), dwt_vae.model)
34 | assert output.shape == (1, 3, target_image_size*2, target_image_size*2)
35 |
36 |
37 | def preprocess(img, target_image_size=256):
38 | s = min(img.size)
39 | if s < target_image_size:
40 | raise ValueError(f'min dim for image {s} < {target_image_size}')
41 | r = target_image_size / s
42 | s = (round(r * img.size[1]), round(r * img.size[0]))
43 | img = TF.resize(img, s, interpolation=PIL.Image.LANCZOS)
44 | img = TF.center_crop(img, output_size=2 * [target_image_size])
45 | img = torch.unsqueeze(T.ToTensor()(img), 0)
46 | return img
47 |
48 |
49 | def preprocess_vqgan(x):
50 | x = 2.*x - 1.
51 | return x
52 |
53 |
54 | def reconstruct_with_vqgan(x, model):
55 | z, _, [_, _, _] = model.encode(x)
56 | print(f'VQGAN --- {model.__class__.__name__}: latent shape: {z.shape[2:]}')
57 | xrec = model.decode(z)
58 | return xrec
59 |
--------------------------------------------------------------------------------
/rudalle/tokenizer.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from os.path import join
3 |
4 | import torch
5 | import numpy as np
6 | import youtokentome as yttm
7 | from huggingface_hub import hf_hub_url, cached_download
8 |
9 |
10 | def get_tokenizer(path=None, cache_dir='/tmp/rudalle'):
11 | # TODO docstring
12 | if path is None:
13 | repo_id = 'shonenkov/rudalle-utils'
14 | filename = 'bpe.model'
15 | cache_dir = join(cache_dir, 'tokenizer')
16 | config_file_url = hf_hub_url(repo_id=repo_id, filename=filename)
17 | cached_download(config_file_url, cache_dir=cache_dir, force_filename=filename)
18 | path = join(cache_dir, filename)
19 | tokenizer = YTTMTokenizerWrapper(yttm.BPE(model=path))
20 | print('tokenizer --> ready')
21 | return tokenizer
22 |
23 |
24 | class YTTMTokenizerWrapper:
25 | eos_id = 3
26 | bos_id = 2
27 | unk_id = 1
28 | pad_id = 0
29 |
30 | def __init__(self, tokenizer):
31 | self.tokenizer = tokenizer
32 |
33 | def __len__(self):
34 | return self.vocab_size()
35 |
36 | def get_pad_token_id(self):
37 | # TODO docstring
38 |         return self.tokenizer.subword_to_id('<PAD>')
39 |
40 | def vocab_size(self):
41 | # TODO docstring
42 | return self.tokenizer.vocab_size()
43 |
44 | def encode_text(self, text, text_seq_length, bpe_dropout=0.0):
45 | # TODO docstring
46 | tokens = self.tokenizer.encode([text], output_type=yttm.OutputType.ID, dropout_prob=bpe_dropout)[0]
47 | tokens = [self.bos_id] + tokens + [self.eos_id]
48 | return self.prepare_tokens(tokens, text_seq_length)
49 |
50 | def decode_text(self, encoded):
51 | # TODO docstring
52 | return self.tokenizer.decode(encoded.cpu().numpy().tolist(), ignore_ids=[
53 | self.eos_id, self.bos_id, self.unk_id, self.pad_id
54 | ])[0]
55 |
56 | @staticmethod
57 | def prepare_tokens(tokens, text_seq_length):
58 | # TODO docstring
59 | empty_positions = text_seq_length - len(tokens)
60 | if empty_positions > 0:
61 | tokens = np.hstack((tokens, np.zeros(empty_positions))) # position tokens after text
62 | if len(tokens) > text_seq_length:
63 | tokens = tokens[:text_seq_length]
64 | return torch.tensor(tokens).long()
65 |
--------------------------------------------------------------------------------
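
A round-trip sketch for the tokenizer above, following tests/test_tokenizer.py: encode_text adds BOS/EOS and pads or truncates to text_seq_length, and with bpe_dropout=0.0 decode_text recovers the input exactly.

```python
# Sketch based on tests/test_tokenizer.py; get_tokenizer downloads bpe.model on first use.
from rudalle import get_tokenizer

tokenizer = get_tokenizer()
tokens = tokenizer.encode_text('привет как дела?', text_seq_length=128, bpe_dropout=0.0)
print(tokens.shape)                              # torch.Size([128])
print(tokenizer.decode_text(tokens))             # 'привет как дела?' ('hi, how are you?')
```
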
/rudalle/dalle/fp16.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 | from torch import nn
4 | from torch.autograd import Variable
5 | from torch.nn.parameter import Parameter
6 |
7 | FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
8 | HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
9 |
10 |
11 | def conversion_helper(val, conversion):
12 | """Apply conversion to val. Recursively apply conversion if `val` is a nested tuple/list structure."""
13 | if not isinstance(val, (tuple, list)):
14 | return conversion(val)
15 | rtn = [conversion_helper(v, conversion) for v in val]
16 | if isinstance(val, tuple):
17 | rtn = tuple(rtn)
18 | return rtn
19 |
20 |
21 | def fp32_to_fp16(val):
22 | """Convert fp32 `val` to fp16"""
23 | def half_conversion(val):
24 | val_typecheck = val
25 | if isinstance(val_typecheck, (Parameter, Variable)):
26 | val_typecheck = val.data
27 | if isinstance(val_typecheck, FLOAT_TYPES):
28 | val = val.half()
29 | return val
30 | return conversion_helper(val, half_conversion)
31 |
32 |
33 | def fp16_to_fp32(val):
34 | """Convert fp16 `val` to fp32"""
35 | def float_conversion(val):
36 | val_typecheck = val
37 | if isinstance(val_typecheck, (Parameter, Variable)):
38 | val_typecheck = val.data
39 | if isinstance(val_typecheck, HALF_TYPES):
40 | val = val.float()
41 | return val
42 | return conversion_helper(val, float_conversion)
43 |
44 |
45 | class FP16Module(nn.Module):
46 | def __init__(self, module):
47 | super(FP16Module, self).__init__()
48 | self.add_module('module', module.half())
49 |
50 | def forward(self, *inputs, **kwargs):
51 | return fp16_to_fp32(self.module(*(fp32_to_fp16(inputs)), **kwargs))
52 |
53 | def state_dict(self, destination=None, prefix='', keep_vars=False):
54 | return self.module.state_dict(destination, prefix, keep_vars)
55 |
56 | def load_state_dict(self, state_dict, strict=True):
57 | self.module.load_state_dict(state_dict, strict=strict)
58 |
59 | def get_param(self, item):
60 | return self.module.get_param(item)
61 |
62 | def to(self, device, *args, **kwargs):
63 | self.module.to(device)
64 | return super().to(device, *args, **kwargs)
65 |
--------------------------------------------------------------------------------
/rudalle/realesrgan/model.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Source: https://github.com/boomb0om/Real-ESRGAN-colab
3 |
4 | import torch
5 | import numpy as np
6 | from PIL import Image
7 |
8 | from .rrdbnet_arch import RRDBNet
9 | from .utils import pad_reflect, split_image_into_overlapping_patches, stich_together, unpad_image
10 | from rudalle.dalle.fp16 import FP16Module
11 |
12 |
13 | class RealESRGAN:
14 | def __init__(self, device, scale=4, fp16=False):
15 | self.device = device
16 | self.scale = scale
17 | self.model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=scale)
18 | self.fp16 = fp16
19 |
20 | def load_weights(self, model_path):
21 | loadnet = torch.load(model_path)
22 | if 'params' in loadnet:
23 | self.model.load_state_dict(loadnet['params'], strict=True)
24 | elif 'params_ema' in loadnet:
25 | self.model.load_state_dict(loadnet['params_ema'], strict=True)
26 | else:
27 | self.model.load_state_dict(loadnet, strict=True)
28 | self.model.eval()
29 | if self.fp16:
30 | self.model = FP16Module(self.model)
31 | self.model.to(self.device)
32 |
33 | def predict(self, lr_image, batch_size=4, patches_size=192,
34 | padding=24, pad_size=15):
35 | scale = self.scale
36 | device = self.device
37 | lr_image = np.array(lr_image)
38 | lr_image = pad_reflect(lr_image, pad_size)
39 |
40 | patches, p_shape = split_image_into_overlapping_patches(lr_image, patch_size=patches_size,
41 | padding_size=padding)
42 | if self.fp16:
43 | img = torch.HalfTensor(patches / 255).permute((0, 3, 1, 2)).to(device).detach()
44 | else:
45 | img = torch.FloatTensor(patches / 255).permute((0, 3, 1, 2)).to(device).detach()
46 |
47 | with torch.no_grad():
48 | res = self.model(img[0:batch_size])
49 | for i in range(batch_size, img.shape[0], batch_size):
50 | res = torch.cat((res, self.model(img[i:i + batch_size])), 0)
51 |
52 | sr_image = res.permute((0, 2, 3, 1)).cpu().clamp_(0, 1)
53 | np_sr_image = sr_image.numpy()
54 |
55 | padded_size_scaled = tuple(np.multiply(p_shape[0:2], scale)) + (3,)
56 | scaled_image_shape = tuple(np.multiply(lr_image.shape[0:2], scale)) + (3,)
57 | np_sr_image = stich_together(np_sr_image, padded_image_shape=padded_size_scaled,
58 | target_shape=scaled_image_shape, padding_size=padding * scale)
59 | sr_img = (np_sr_image * 255).astype(np.uint8)
60 | sr_img = unpad_image(sr_img, pad_size * scale)
61 | sr_img = Image.fromarray(sr_img)
62 |
63 | return sr_img
64 |
--------------------------------------------------------------------------------
/rudalle/ruclip/processor.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import json
4 | import torch
5 | import youtokentome as yttm
6 | import torchvision.transforms as T
7 | from torch.nn.utils.rnn import pad_sequence
8 |
9 |
10 | class RuCLIPProcessor:
11 | eos_id = 3
12 | bos_id = 2
13 | unk_id = 1
14 | pad_id = 0
15 |
16 | def __init__(self, tokenizer_path, image_size=224, text_seq_length=76, mean=None, std=None):
17 |
18 | self.tokenizer = yttm.BPE(tokenizer_path)
19 | self.mean = mean or [0.485, 0.456, 0.406]
20 | self.std = std or [0.229, 0.224, 0.225]
21 | self.image_transform = T.Compose([
22 | T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
23 | T.RandomResizedCrop(image_size, scale=(1., 1.), ratio=(1., 1.)),
24 | T.ToTensor(),
25 | T.Normalize(mean=self.mean, std=self.std)
26 | ])
27 | self.text_seq_length = text_seq_length
28 | self.image_size = image_size
29 |
30 | def encode_text(self, text):
31 | text = text.lower()
32 | tokens = self.tokenizer.encode([text], output_type=yttm.OutputType.ID, dropout_prob=0.0)[0]
33 | tokens = [self.bos_id] + tokens + [self.eos_id]
34 | tokens = tokens[:self.text_seq_length]
35 | mask = [1] * len(tokens)
36 | return torch.tensor(tokens).long(), torch.tensor(mask).long()
37 |
38 | def decode_text(self, encoded):
39 | return self.tokenizer.decode(encoded.cpu().numpy().tolist(), ignore_ids=[
40 | self.eos_id, self.bos_id, self.unk_id, self.pad_id
41 | ])[0]
42 |
43 | def __call__(self, text=None, images=None, **kwargs):
44 | inputs = {}
45 | if text is not None:
46 | input_ids, masks = [], []
47 | texts = [text] if isinstance(text, str) else text
48 | for text in texts:
49 | tokens, mask = self.encode_text(text)
50 | input_ids.append(tokens)
51 | masks.append(mask)
52 | inputs['input_ids'] = pad_sequence(input_ids, batch_first=True)
53 | inputs['attention_mask'] = pad_sequence(masks, batch_first=True)
54 | if images is not None:
55 | pixel_values = []
56 | for i, image in enumerate(images):
57 | pixel_values.append(self.image_transform(image))
58 | inputs['pixel_values'] = pad_sequence(pixel_values, batch_first=True)
59 | return inputs
60 |
61 | @classmethod
62 | def from_pretrained(cls, folder):
63 | tokenizer_path = os.path.join(folder, 'bpe.model')
64 | config = json.load(open(os.path.join(folder, 'config.json')))
65 | image_size = config['vision_config']['image_size']
66 | text_seq_length = config['text_config']['max_position_embeddings'] - 1
67 | mean, std = config.get('mean'), config.get('std')
68 | return cls(tokenizer_path, image_size=image_size, text_seq_length=text_seq_length, mean=mean, std=std)
69 |
--------------------------------------------------------------------------------
/rudalle/image_prompts.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 | import numpy as np
4 |
5 |
6 | class ImagePrompts:
7 |
8 | def __init__(self, pil_image, borders, vae, device='cpu', crop_first=False):
9 | """
10 | Args:
11 | pil_image (PIL.Image): image in PIL format
12 |             borders (dict[str, int]): borders that are cropped from pil_image and used as image prompts,
13 |                 example: {'up': 4, 'right': 0, 'left': 0, 'down': 0} (one unit equals 8 pixels)
14 | vae (VQGanGumbelVAE): VQGAN model for image encoding
15 | device (str): cpu or cuda
16 |             crop_first (bool): if True, crop the image to the borders before VQGAN encoding
17 | """
18 | self.device = device
19 | img = self._preprocess_img(pil_image)
20 | self.image_prompts_idx, self.image_prompts = self._get_image_prompts(img, borders, vae, crop_first)
21 |
22 | def _preprocess_img(self, pil_img):
23 | img = torch.tensor(np.array(pil_img.convert('RGB')).transpose(2, 0, 1)) / 255.
24 | img = img.unsqueeze(0).to(self.device, dtype=torch.float32)
25 | img = (2 * img) - 1
26 | return img
27 |
28 | def _get_image_prompts(self, img, borders, vae, crop_first):
29 | if crop_first:
30 | bs, _, img_w, img_h = img.shape
31 | vqg_img_w, vqg_img_h = img_w // 8, img_h // 8
32 | vqg_img = torch.zeros((bs, vqg_img_w, vqg_img_h), dtype=torch.int32, device=img.device)
33 | if borders['down'] != 0:
34 | down_border = borders['down'] * 8
35 | _, _, [_, _, down_vqg_img] = vae.model.encode(img[:, :, -down_border:, :])
36 | vqg_img[:, -borders['down']:, :] = down_vqg_img
37 | if borders['right'] != 0:
38 | right_border = borders['right'] * 8
39 | _, _, [_, _, right_vqg_img] = vae.model.encode(img[:, :, :, -right_border:])
40 | vqg_img[:, :, -borders['right']:] = right_vqg_img
41 | if borders['left'] != 0:
42 | left_border = borders['left'] * 8
43 | _, _, [_, _, left_vqg_img] = vae.model.encode(img[:, :, :, :left_border])
44 | vqg_img[:, :, :borders['left']] = left_vqg_img
45 | if borders['up'] != 0:
46 | up_border = borders['up'] * 8
47 | _, _, [_, _, up_vqg_img] = vae.model.encode(img[:, :, :up_border, :])
48 | vqg_img[:, :borders['up'], :] = up_vqg_img
49 | else:
50 | _, _, [_, _, vqg_img] = vae.model.encode(img)
51 |
52 | bs, vqg_img_w, vqg_img_h = vqg_img.shape
53 | mask = torch.zeros(vqg_img_w, vqg_img_h)
54 | if borders['up'] != 0:
55 | mask[:borders['up'], :] = 1.
56 | if borders['down'] != 0:
57 | mask[-borders['down']:, :] = 1.
58 | if borders['right'] != 0:
59 | mask[:, -borders['right']:] = 1.
60 | if borders['left'] != 0:
61 | mask[:, :borders['left']] = 1.
62 | mask = mask.reshape(-1).bool()
63 |
64 | image_prompts = vqg_img.reshape((bs, -1))
65 | image_prompts_idx = np.arange(vqg_img_w * vqg_img_h)
66 | image_prompts_idx = set(image_prompts_idx[mask])
67 |
68 | return image_prompts_idx, image_prompts
69 |
--------------------------------------------------------------------------------
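
A usage sketch for ImagePrompts above, following tests/test_image_prompts.py: with a 256x256 image the VQGAN grid is 32x32, so an 'up' border of 4 pins the top 4 rows (128 token positions). The image path is a placeholder; pretrained weights are only needed for real generation.

```python
# Sketch based on tests/test_image_prompts.py; 'example.png' is a placeholder path.
import PIL.Image

from rudalle import get_vae
from rudalle.image_prompts import ImagePrompts

vae = get_vae(pretrained=False)
img = PIL.Image.open('example.png').resize((256, 256))
borders = {'up': 4, 'right': 0, 'left': 0, 'down': 0}
prompt = ImagePrompts(img, borders, vae, device='cpu', crop_first=True)
print(prompt.image_prompts.shape)                # torch.Size([1, 1024]) == (batch, 32 * 32)
print(len(prompt.image_prompts_idx))             # 128 == 4 * 32 pinned positions
```
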
/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by .ignore support plugin (hsz.mobi)
2 | ### JetBrains template
3 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm
4 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
5 |
6 | settings/local.py
7 | logs/*.log
8 |
9 | # User-specific stuff:
10 | .idea/
11 |
12 | # Sensitive or high-churn files:
13 | .idea/**/dataSources/
14 | .idea/**/dataSources.ids
15 | .idea/**/dataSources.xml
16 | .idea/**/dataSources.local.xml
17 | .idea/**/sqlDataSources.xml
18 | .idea/**/dynamic.xml
19 | .idea/**/uiDesigner.xml
20 |
21 | # Gradle:
22 | .idea/**/gradle.xml
23 | .idea/**/libraries
24 |
25 | # CMake
26 | cmake-build-debug/
27 |
28 | # Mongo Explorer plugin:
29 | .idea/**/mongoSettings.xml
30 |
31 | ## File-based project format:
32 | *.iws
33 |
34 | ## Plugin-specific files:
35 |
36 | # IntelliJ
37 | out/
38 |
39 | # mpeltonen/sbt-idea plugin
40 | .idea_modules/
41 |
42 | # JIRA plugin
43 | atlassian-ide-plugin.xml
44 |
45 | # Cursive Clojure plugin
46 | .idea/replstate.xml
47 |
48 | # Crashlytics plugin (for Android Studio and IntelliJ)
49 | com_crashlytics_export_strings.xml
50 | crashlytics.properties
51 | crashlytics-build.properties
52 | fabric.properties
53 | ### Python template
54 | # Byte-compiled / optimized / DLL files
55 | __pycache__/
56 | *.py[cod]
57 | *$py.class
58 |
59 | # C extensions
60 | *.so
61 |
62 | # Distribution / packaging
63 | .Python
64 | build/
65 | develop-eggs/
66 | dist/
67 | downloads/
68 | eggs/
69 | .eggs/
70 | lib/
71 | lib64/
72 | parts/
73 | sdist/
74 | var/
75 | wheels/
76 | *.egg-info/
77 | .installed.cfg
78 | *.egg
79 |
80 | # PyInstaller
81 | # Usually these files are written by a python script from a template
82 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
83 | *.manifest
84 | *.spec
85 |
86 | # Installer logs
87 | pip-log.txt
88 | pip-delete-this-directory.txt
89 |
90 | # Unit test / coverage reports
91 | htmlcov/
92 | .tox/
93 | .coverage
94 | .coverage.*
95 | .cache
96 | nosetests.xml
97 | coverage.xml
98 | *.cover
99 | .hypothesis/
100 |
101 | # Translations
102 | *.mo
103 | *.pot
104 |
105 | # Django stuff:
106 | *.log
107 | local_settings.py
108 |
109 | # Flask stuff:
110 | instance/
111 | .webassets-cache
112 |
113 | # Scrapy stuff:
114 | .scrapy
115 |
116 | # Sphinx documentation
117 | docs/_build/
118 |
119 | # PyBuilder
120 | target/
121 |
122 | # Jupyter Notebook
123 | .ipynb_checkpoints
124 |
125 | # pyenv
126 | .python-version
127 |
128 | # celery beat schedule file
129 | celerybeat-schedule
130 |
131 | # SageMath parsed files
132 | *.sage.py
133 |
134 | # Environments
135 | .env
136 | .venv
137 | env/
138 | venv/
139 | ENV/
140 |
141 | # Spyder project settings
142 | .spyderproject
143 | .spyproject
144 |
145 | # Rope project settings
146 | .ropeproject
147 |
148 | # mkdocs documentation
149 | /site
150 |
151 | # mypy
152 | .mypy_cache/
153 | /tests/load_tests/logs/*
154 | /tests/.pytest_cache/
155 | ws_test.py
156 | /.vscode/
157 |
158 | .s3_cache/
159 | mlruns
160 | *.pyc
161 | *.swp
162 | *.pt
163 | *.bin
164 | .vscode/
165 | runs/
166 | jupyters/custom_*
167 |
168 | *logs/
169 | .DS_store
170 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ruDALL-E diffusion
2 |
3 | Try it out on Colab
4 |
5 |
6 |
7 |
8 | ruDALL-E diffusion is ruDALL-E with a diffusion decoder, similar to [dall-3](https://github.com/Jack000/DALLE-pytorch/).
9 |
10 | Decoding VQ embeddings with a DDPM model can produce much more realistic fine-grained detail than a VQVAE or VQGAN decoder.
11 |
12 | The only code change to ruDALL-E is returning the image tokens from generate_images(); the actual diffusion model lives at https://github.com/Jack000/guided-diffusion.
13 |
14 | # Samples
15 | | ruDALL-E + real-ESRGAN | diffusion |
16 | | --- | --- |
17 | | *(sample image)* | *(sample image)* |
18 | | *(sample image)* | *(sample image)* |
19 | | *(sample image)* | *(sample image)* |
20 | | *(sample image)* | *(sample image)* |
21 |
22 | Note that the results depend heavily on the seed value.
23 |
24 | base image:
25 |
26 |
27 | diffusion-generated samples (different seeds):
28 | | | | |
29 | | --- | --- | --- |
30 | | *(sample)* | *(sample)* | *(sample)* |
31 | | *(sample)* | *(sample)* | *(sample)* |
32 |
33 | ### Generation with ruDALL-E
34 | ```python
35 | from rudalle.pipelines import generate_images, show, super_resolution, cherry_pick_by_clip
36 | from rudalle import get_rudalle_model, get_tokenizer, get_vae, get_realesrgan, get_ruclip
37 | from rudalle.utils import seed_everything
38 | import numpy as np
39 |
40 | # prepare models:
41 | device = 'cuda'
42 | dalle = get_rudalle_model('Malevich', pretrained=True, fp16=True, device=device)
43 | tokenizer = get_tokenizer()
44 | vae = get_vae(dwt=False).to(device) # Make sure to set dwt to False!
45 | realesrgan = get_realesrgan('x2', device=device)  # needed by super_resolution() below
46 | text = 'изображение радуги на фоне ночного города'
47 |
48 | seed_everything(42)
49 | pil_images = []
50 | scores = []
51 | codes = []
52 |
53 | for top_k, top_p, images_num in [
54 | (2048, 0.995, 3),
55 | (1536, 0.99, 3),
56 | (1024, 0.99, 3),
57 | (1024, 0.98, 3),
58 | (512, 0.97, 3),
59 | (384, 0.96, 3),
60 | (256, 0.95, 3),
61 | (128, 0.95, 3),
62 | ]:
63 | _pil_images, _scores, _codes = generate_images(text, tokenizer, dalle, vae, top_k=top_k, images_num=images_num, top_p=top_p, return_codes=True)
64 | pil_images += _pil_images
65 | scores += _scores
66 | codes += _codes
67 |
68 | sr_images = super_resolution(pil_images, realesrgan)
69 |
70 | for i, im in enumerate(pil_images):
71 | im.save(str(i)+'.png')
72 | sr_images[i].save(str(i)+'sr.png')
73 | with open(str(i)+'.npy', 'wb') as f:
74 | np.save(f, codes[i])
75 |
76 | # afterward, pass the .npy files to the diffusion model: https://github.com/Jack000/guided-diffusion
77 | ```
78 |
--------------------------------------------------------------------------------
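
The README snippet imports cherry_pick_by_clip and get_ruclip without calling them. A hedged continuation that re-ranks the generated images with ruCLIP is sketched below; pipelines.py is not included in this dump, so the cherry_pick_by_clip signature is an assumption taken from upstream ruDALL-E examples.

```python
# Hedged continuation of the README snippet (pil_images and text come from it).
# Assumed signature: cherry_pick_by_clip(pil_images, text, ruclip, ruclip_processor, device, count).
ruclip, ruclip_processor = get_ruclip('ruclip-vit-base-patch32-v5')
top_images, clip_scores = cherry_pick_by_clip(
    pil_images, text, ruclip, ruclip_processor, device='cpu', count=6)
show(top_images, 3, save_dir='/tmp/pics', show=False)
```
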
/rudalle/dalle/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 |
4 | import torch
5 | from huggingface_hub import hf_hub_url, cached_download
6 |
7 | from .model import DalleModel
8 | from .fp16 import FP16Module
9 |
10 |
11 | MODELS = {
12 | 'Malevich': dict(
13 |         description='◼️ Malevich is a 1.3 billion parameter GPT-3-like model '
14 |                     'that uses the Russian language and text+image multi-modality.',
15 | model_params=dict(
16 | num_layers=24,
17 | hidden_size=2048,
18 | num_attention_heads=16,
19 | embedding_dropout_prob=0.1,
20 | output_dropout_prob=0.1,
21 | attention_dropout_prob=0.1,
22 | image_tokens_per_dim=32,
23 | text_seq_length=128,
24 | cogview_sandwich_layernorm=True,
25 | cogview_pb_relax=True,
26 | vocab_size=16384+128,
27 | image_vocab_size=8192,
28 | ),
29 | repo_id='sberbank-ai/rudalle-Malevich',
30 | filename='pytorch_model_v2.bin',
31 | authors='SberAI, SberDevices',
32 | full_description='', # TODO
33 | ),
34 | 'Emojich': dict(
35 |         description='😋 Emojich is a 1.3 billion parameter GPT-3-like model; '
36 | 'it generates emoji-style images with the brain of ◾ Malevich.',
37 | model_params=dict(
38 | num_layers=24,
39 | hidden_size=2048,
40 | num_attention_heads=16,
41 | embedding_dropout_prob=0.1,
42 | output_dropout_prob=0.1,
43 | attention_dropout_prob=0.1,
44 | image_tokens_per_dim=32,
45 | text_seq_length=128,
46 | cogview_sandwich_layernorm=True,
47 | cogview_pb_relax=True,
48 | vocab_size=16384 + 128,
49 | image_vocab_size=8192,
50 | ),
51 | repo_id='sberbank-ai/rudalle-Emojich',
52 | filename='pytorch_model.bin',
53 | authors='SberAI',
54 | full_description='', # TODO
55 | ),
56 | 'small': dict(
57 | description='',
58 | model_params=dict(
59 | num_layers=12,
60 | hidden_size=768,
61 | num_attention_heads=12,
62 | embedding_dropout_prob=0.1,
63 | output_dropout_prob=0.1,
64 | attention_dropout_prob=0.1,
65 | image_tokens_per_dim=32,
66 | text_seq_length=128,
67 | cogview_sandwich_layernorm=True,
68 | cogview_pb_relax=True,
69 | vocab_size=16384+128,
70 | image_vocab_size=8192,
71 | ),
72 | repo_id='',
73 | filename='',
74 | full_description='', # TODO
75 | ),
76 | }
77 |
78 |
79 | def get_rudalle_model(name, pretrained=True, fp16=False, device='cpu', cache_dir='/tmp/rudalle', **model_kwargs):
80 | # TODO docstring
81 | assert name in MODELS
82 |
83 | if fp16 and device == 'cpu':
84 |         print('Warning! fp16 is not supported on CPU. Use a CUDA device or turn off fp16.')
85 |
86 | config = MODELS[name].copy()
87 | config['model_params'].update(model_kwargs)
88 | model = DalleModel(device=device, **config['model_params'])
89 | if pretrained:
90 | cache_dir = os.path.join(cache_dir, name)
91 | config_file_url = hf_hub_url(repo_id=config['repo_id'], filename=config['filename'])
92 | cached_download(config_file_url, cache_dir=cache_dir, force_filename=config['filename'])
93 | checkpoint = torch.load(os.path.join(cache_dir, config['filename']), map_location='cpu')
94 | model.load_state_dict(checkpoint)
95 | if fp16:
96 | model = FP16Module(model)
97 | model.eval()
98 | model = model.to(device)
99 | if config['description'] and pretrained:
100 | print(config['description'])
101 | return model
102 |
--------------------------------------------------------------------------------
/rudalle/vae/decoder_dwt.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import pywt
3 | import torch
4 | import torch.nn as nn
5 | from taming.modules.diffusionmodules.model import Decoder
6 |
7 | from .pytorch_wavelets_utils import SFB2D, _SFB2D, prep_filt_sfb2d, mode_to_int
8 |
9 |
10 | class DecoderDWT(nn.Module):
11 | def __init__(self, ddconfig, embed_dim):
12 | super().__init__()
13 | if ddconfig.out_ch != 12:
14 | ddconfig.out_ch = 12
15 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig['z_channels'], 1)
16 | self.decoder = Decoder(**ddconfig)
17 | self.idwt = DWTInverse(mode='zero', wave='db1')
18 |
19 | def forward(self, x):
20 | # x = self.post_quant_conv(x)
21 | freq = self.decoder(x)
22 | img = self.dwt_to_img(freq)
23 | return img
24 |
25 | def dwt_to_img(self, img):
26 | b, c, h, w = img.size()
27 | low = img[:, :3, :, :]
28 | high = img[:, 3:, :, :].view(b, 3, 3, h, w)
29 | return self.idwt((low, [high]))
30 |
31 |
32 | class DWTInverse(nn.Module):
33 | """ Performs a 2d DWT Inverse reconstruction of an image
34 |
35 | Args:
36 | wave (str or pywt.Wavelet): Which wavelet to use
37 |         mode (str): padding scheme used for the inverse transform, e.g. 'zero'
38 | """
39 |
40 | def __init__(self, wave='db1', mode='zero', trace_model=False):
41 | super().__init__()
42 | if isinstance(wave, str):
43 | wave = pywt.Wavelet(wave)
44 | if isinstance(wave, pywt.Wavelet):
45 | g0_col, g1_col = wave.rec_lo, wave.rec_hi
46 | g0_row, g1_row = g0_col, g1_col
47 | else:
48 | if len(wave) == 2:
49 | g0_col, g1_col = wave[0], wave[1]
50 | g0_row, g1_row = g0_col, g1_col
51 | elif len(wave) == 4:
52 | g0_col, g1_col = wave[0], wave[1]
53 | g0_row, g1_row = wave[2], wave[3]
54 | # Prepare the filters
55 | filts = prep_filt_sfb2d(g0_col, g1_col, g0_row, g1_row)
56 | self.register_buffer('g0_col', filts[0])
57 | self.register_buffer('g1_col', filts[1])
58 | self.register_buffer('g0_row', filts[2])
59 | self.register_buffer('g1_row', filts[3])
60 | self.mode = mode
61 | self.trace_model = trace_model
62 |
63 | def forward(self, coeffs):
64 | """
65 | Args:
66 | coeffs (yl, yh): tuple of lowpass and bandpass coefficients, where:
67 | yl is a lowpass tensor of shape :math:`(N, C_{in}, H_{in}',
68 | W_{in}')` and yh is a list of bandpass tensors of shape
69 | :math:`list(N, C_{in}, 3, H_{in}'', W_{in}'')`. I.e. should match
70 | the format returned by DWTForward
71 |
72 | Returns:
73 | Reconstructed input of shape :math:`(N, C_{in}, H_{in}, W_{in})`
74 |
75 | Note:
76 | :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` denote the correctly
77 | downsampled shapes of the DWT pyramid.
78 |
79 | Note:
80 | Can have None for any of the highpass scales and will treat the
81 | values as zeros (not in an efficient way though).
82 | """
83 | yl, yh = coeffs
84 | ll = yl
85 | mode = mode_to_int(self.mode)
86 |
87 | # Do a multilevel inverse transform
88 | for h in yh[::-1]:
89 | if h is None:
90 | h = torch.zeros(ll.shape[0], ll.shape[1], 3, ll.shape[-2],
91 | ll.shape[-1], device=ll.device)
92 |
93 | # 'Unpad' added dimensions
94 | if ll.shape[-2] > h.shape[-2]:
95 | ll = ll[..., :-1, :]
96 | if ll.shape[-1] > h.shape[-1]:
97 | ll = ll[..., :-1]
98 | if not self.trace_model:
99 | ll = SFB2D.apply(ll, h, self.g0_col, self.g1_col, self.g0_row, self.g1_row, mode)
100 | else:
101 | ll = _SFB2D(ll, h, self.g0_col, self.g1_col, self.g0_row, self.g1_row, mode)
102 | return ll
103 |
--------------------------------------------------------------------------------
/rudalle/vae/model.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from math import sqrt, log
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from torch import einsum
8 | from einops import rearrange
9 | from taming.modules.diffusionmodules.model import Encoder, Decoder
10 |
11 | from .decoder_dwt import DecoderDWT
12 |
13 |
14 | class VQGanGumbelVAE(torch.nn.Module):
15 |
16 | def __init__(self, config, dwt=False):
17 | super().__init__()
18 | model = GumbelVQ(
19 | ddconfig=config.model.params.ddconfig,
20 | n_embed=config.model.params.n_embed,
21 | embed_dim=config.model.params.embed_dim,
22 | kl_weight=config.model.params.kl_weight,
23 | dwt=dwt,
24 | )
25 | self.model = model
26 | self.num_layers = int(log(config.model.params.ddconfig.attn_resolutions[0]) / log(2))
27 | self.image_size = 256
28 | self.num_tokens = config.model.params.n_embed
29 |
30 | @torch.no_grad()
31 | def get_codebook_indices(self, img):
32 | img = (2 * img) - 1
33 | _, _, [_, _, indices] = self.model.encode(img)
34 | return rearrange(indices, 'b h w -> b (h w)')
35 |
36 | def decode(self, img_seq):
37 | b, n = img_seq.shape
38 | one_hot_indices = torch.nn.functional.one_hot(img_seq, num_classes=self.num_tokens).float()
39 | z = (one_hot_indices @ self.model.quantize.embed.weight)
40 | z = rearrange(z, 'b (h w) c -> b c h w', h=int(sqrt(n)))
41 | img = self.model.decode(z)
42 | img = (img.clamp(-1., 1.) + 1) * 0.5
43 | return img
44 |
45 |
46 | class GumbelQuantize(nn.Module):
47 | """
48 | credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
49 | Gumbel Softmax trick quantizer
50 | Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
51 | https://arxiv.org/abs/1611.01144
52 | """
53 |
54 | def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True,
55 | kl_weight=5e-4, temp_init=1.0, use_vqinterface=True):
56 | super().__init__()
57 | self.embedding_dim = embedding_dim
58 | self.n_embed = n_embed
59 | self.straight_through = straight_through
60 | self.temperature = temp_init
61 | self.kl_weight = kl_weight
62 | self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
63 | self.embed = nn.Embedding(self.n_embed, self.embedding_dim)
64 | self.use_vqinterface = use_vqinterface
65 |
66 | def forward(self, z, temp=None, return_logits=False):
67 | hard = self.straight_through if self.training else True
68 | temp = self.temperature if temp is None else temp
69 | logits = self.proj(z)
70 | soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
71 | z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight)
72 | # + kl divergence to the prior loss
73 | qy = F.softmax(logits, dim=1)
74 | diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()
75 | ind = soft_one_hot.argmax(dim=1)
76 | if self.use_vqinterface:
77 | if return_logits:
78 | return z_q, diff, (None, None, ind), logits
79 | return z_q, diff, (None, None, ind)
80 | return z_q, diff, ind
81 |
82 |
83 | class GumbelVQ(nn.Module):
84 |
85 | def __init__(self, ddconfig, n_embed, embed_dim, dwt=False, kl_weight=1e-8):
86 | super().__init__()
87 | z_channels = ddconfig['z_channels']
88 | self.dwt = dwt
89 | self.encoder = Encoder(**ddconfig)
90 | self.decoder = DecoderDWT(ddconfig, embed_dim) if dwt else Decoder(**ddconfig)
91 | self.quantize = GumbelQuantize(z_channels, embed_dim, n_embed=n_embed, kl_weight=kl_weight, temp_init=1.0)
92 | self.quant_conv = torch.nn.Conv2d(ddconfig['z_channels'], embed_dim, 1)
93 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig['z_channels'], 1)
94 |
95 | def encode(self, x):
96 | h = self.encoder(x)
97 | h = self.quant_conv(h)
98 | quant, emb_loss, info = self.quantize(h)
99 | return quant, emb_loss, info
100 |
101 | def decode(self, quant):
102 | if self.dwt:
103 | quant = self.decoder.post_quant_conv(quant)
104 | else:
105 | quant = self.post_quant_conv(quant)
106 | dec = self.decoder(quant)
107 | return dec
108 |
--------------------------------------------------------------------------------
/rudalle/realesrgan/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import numpy as np
3 |
4 |
5 | def pad_reflect(image, pad_size):
6 | imsize = image.shape
7 | height, width = imsize[:2]
8 | new_img = np.zeros([height + pad_size * 2, width + pad_size * 2, imsize[2]]).astype(np.uint8)
9 | new_img[pad_size:-pad_size, pad_size:-pad_size, :] = image
10 | new_img[0:pad_size, pad_size:-pad_size, :] = np.flip(image[0:pad_size, :, :], axis=0) # top
11 | new_img[-pad_size:, pad_size:-pad_size, :] = np.flip(image[-pad_size:, :, :], axis=0) # bottom
12 | new_img[:, 0:pad_size, :] = np.flip(new_img[:, pad_size:pad_size * 2, :], axis=1) # left
13 | new_img[:, -pad_size:, :] = np.flip(new_img[:, -pad_size * 2:-pad_size, :], axis=1) # right
14 | return new_img
15 |
16 |
17 | def unpad_image(image, pad_size):
18 | return image[pad_size:-pad_size, pad_size:-pad_size, :]
19 |
20 |
21 | def pad_patch(image_patch, padding_size, channel_last=True):
22 |     """ Pads image_patch with padding_size edge values. """
23 | if channel_last:
24 | return np.pad(
25 | image_patch,
26 | ((padding_size, padding_size), (padding_size, padding_size), (0, 0)),
27 | 'edge',
28 | )
29 | else:
30 | return np.pad(
31 | image_patch,
32 | ((0, 0), (padding_size, padding_size), (padding_size, padding_size)),
33 | 'edge',
34 | )
35 |
36 |
37 | def unpad_patches(image_patches, padding_size):
38 | return image_patches[:, padding_size:-padding_size, padding_size:-padding_size, :]
39 |
40 |
41 | def split_image_into_overlapping_patches(image_array, patch_size, padding_size=2):
42 | """ Splits the image into partially overlapping patches.
43 | The patches overlap by padding_size pixels.
44 | Pads the image twice:
45 | - first to have a size multiple of the patch size,
46 | - then to have equal padding at the borders.
47 | Args:
48 | image_array: numpy array of the input image.
49 | patch_size: size of the patches from the original image (without padding).
50 | padding_size: size of the overlapping area.
51 | """
52 | xmax, ymax, _ = image_array.shape
53 | x_remainder = xmax % patch_size
54 | y_remainder = ymax % patch_size
55 |
56 | # modulo here is to avoid extending of patch_size instead of 0
57 | x_extend = (patch_size - x_remainder) % patch_size
58 | y_extend = (patch_size - y_remainder) % patch_size
59 |
60 | # make sure the image is divisible into regular patches
61 | extended_image = np.pad(image_array, ((0, x_extend), (0, y_extend), (0, 0)), 'edge')
62 |
63 | # add padding around the image to simplify computations
64 | padded_image = pad_patch(extended_image, padding_size, channel_last=True)
65 |
66 | xmax, ymax, _ = padded_image.shape
67 | patches = []
68 |
69 | x_lefts = range(padding_size, xmax - padding_size, patch_size)
70 | y_tops = range(padding_size, ymax - padding_size, patch_size)
71 |
72 | for x in x_lefts:
73 | for y in y_tops:
74 | x_left = x - padding_size
75 | y_top = y - padding_size
76 | x_right = x + patch_size + padding_size
77 | y_bottom = y + patch_size + padding_size
78 | patch = padded_image[x_left:x_right, y_top:y_bottom, :]
79 | patches.append(patch)
80 |
81 | return np.array(patches), padded_image.shape
82 |
83 |
84 | def stich_together(patches, padded_image_shape, target_shape, padding_size=4):
85 | """ Reconstruct the image from overlapping patches.
86 | After scaling, shapes and padding should be scaled too.
87 | Args:
88 | patches: patches obtained with split_image_into_overlapping_patches
89 |         padded_image_shape: shape of the padded image constructed in split_image_into_overlapping_patches
90 | target_shape: shape of the final image
91 | padding_size: size of the overlapping area.
92 | """
93 |
94 | xmax, ymax, _ = padded_image_shape
95 | patches = unpad_patches(patches, padding_size)
96 | patch_size = patches.shape[1]
97 | n_patches_per_row = ymax // patch_size
98 |
99 | complete_image = np.zeros((xmax, ymax, 3))
100 |
101 | row = -1
102 | col = 0
103 | for i in range(len(patches)):
104 | if i % n_patches_per_row == 0:
105 | row += 1
106 | col = 0
107 | complete_image[
108 | row * patch_size: (row + 1) * patch_size, col * patch_size: (col + 1) * patch_size, :
109 | ] = patches[i]
110 | col += 1
111 | return complete_image[0: target_shape[0], 0: target_shape[1], :]
112 |
--------------------------------------------------------------------------------
/Emojich.md:
--------------------------------------------------------------------------------
1 | [[Paper]](https://arxiv.org/abs/2112.02448) [[Хабр]](https://habr.com/ru/company/sberbank/blog/593893/) [[Model Card]](https://huggingface.co/sberbank-ai/rudalle-Emojich) [[Kaggle]](https://www.kaggle.com/shonenkov/emojich-rudall-e) [[Dataset]](https://www.kaggle.com/shonenkov/russian-emoji)
2 | # Emojich
3 | 
4 | ### generate emojis from text
5 |
6 | Model was trained by [Sber AI](https://github.com/sberbank-ai)
7 | * Task: `text2image generation`
8 | * Num Parameters: `1.3 B`
9 | * Training Data Volume: `120 million text-image pairs` & [`2749 text-emoji pairs`](https://www.kaggle.com/shonenkov/russian-emoji)
10 |
11 | [](https://telegram.me/addstickers/SberAI_ruDALLE)
12 |
13 | ### Model Description
14 | 😋 Emojich is a 1.3-billion-parameter GPT-3-like model that generates emoji-style images with the brain of ◾ Malevich.
15 |
16 |
17 | ### Fine-tuning stage:
18 |
19 | The main goal of fine-tuning is to keep the generalization ability of the [ruDALL-E Malevich (XL)](https://huggingface.co/sberbank-ai/rudalle-Malevich)
20 | model on the text-to-emoji task. ruDALL-E Malevich is a large multi-modal pretrained transformer that works with both images and texts.
21 | Freezing the feedforward and self-attention layers of a pretrained transformer has been shown to work well when switching to a different modality.
22 | Even so, the model can easily over-fit to the text modality and lose generalization.
23 | To counter this, the coefficient of the image-codebook part of the weighted cross-entropy loss is increased by a factor of 10^3 (a minimal sketch of this weighting is shown below).
24 |
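A minimal sketch of the two ideas above (freezing the attention/feedforward blocks and up-weighting the image-codebook loss) is shown below. The function names and the substring matching are hypothetical and only illustrate the scheme; they are not the actual training code from the Kaggle notebook.

```python
import torch.nn.functional as F


def freeze_attention_and_ff(dalle):
    # Freeze self-attention and feedforward weights of the pretrained transformer;
    # the name matching here is illustrative, not tied to the real module names.
    for name, param in dalle.named_parameters():
        if 'attention' in name or 'mlp' in name:
            param.requires_grad = False


def emojich_loss(text_logits, text_labels, image_logits, image_labels, image_coef=1e3):
    # Weighted cross-entropy: the image-codebook term is up-weighted by ~10^3
    # so the image modality dominates the fine-tuning signal.
    loss_text = F.cross_entropy(text_logits, text_labels)
    loss_image = F.cross_entropy(image_logits, image_labels)
    return (loss_text + image_coef * loss_image) / (1 + image_coef)
```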
25 |
26 | Full version of training code is available on Kaggle: [](https://www.kaggle.com/shonenkov/emojich-rudall-e)
27 |
28 | ### Usage:
29 |
30 | [](https://colab.research.google.com/drive/1YbEduCe8jH0DXMXKxnb8ulmT8jscJ54i?usp=sharing)
31 |
32 | ```python
33 | from rudalle.pipelines import generate_images, show
34 | from rudalle import get_rudalle_model, get_tokenizer, get_vae
35 | from rudalle.utils import seed_everything
36 |
37 | device = 'cuda'
38 | dalle = get_rudalle_model('Emojich', pretrained=True, fp16=True, device=device)
39 | tokenizer = get_tokenizer()
40 | vae = get_vae(dwt=True).to(device)
41 |
42 | text = 'Дональд Трамп из лего' # Donald Trump made of LEGO
43 |
44 | seed_everything(42)
45 | pil_images = []
46 | for top_k, top_p, images_num in [
47 | (2048, 0.995, 16),
48 | ]:
49 | pil_images += generate_images(text, tokenizer, dalle, vae, top_k=top_k, images_num=images_num, top_p=top_p, bs=8)[0]
50 |
51 | show(pil_images, 4)
52 | ```
53 | 
54 |
55 | ### Super Resolution:
56 | ```python
57 | from rudalle.pipelines import super_resolution
58 | from rudalle import get_realesrgan
59 |
60 | device = 'cuda'
61 | realesrgan = get_realesrgan('x4', device=device)
62 | sr_images = super_resolution(pil_images, realesrgan)
63 | ```
64 |
65 | ### Converting to Telegram Stickers format (512x512 RGBA)
66 | ```python
67 | from rudalle.pipelines import convert_emoji_to_rgba, show_rgba
68 | from rudalle import get_emojich_unet
69 |
70 | device = 'cuda'
71 | emojich_unet = get_emojich_unet('unet_effnetb7').to(device)
72 | rgba_images, _ = convert_emoji_to_rgba(sr_images, emojich_unet, device=device)
73 | for rgba_image in rgba_images:
74 | show_rgba(rgba_image);
75 | ```
76 | 
77 |
78 | ### Examples of generated emojis
79 |
80 | All examples are generated automatically (without manual cherry-picking) with the following hyper-parameters:
81 | seed 42, batch size 16, top-k 2048, top-p 0.995, temperature 1.0, on an A100 GPU.
82 | To get better emojis, run more attempts (~512) and select the best ones manually (a sketch of this workflow follows the examples below).
83 |
84 | *Remember, even the great masters became "great" after creating just one masterpiece.*
85 |
86 | 
87 |
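For completeness, here is a hedged sketch of the "many attempts, then pick" workflow mentioned above. It reuses `generate_images` and `cherry_pick_by_clip` from this repository, while the `get_ruclip` call and the checkpoint name are assumptions; adjust them to whichever ruCLIP setup you actually use.

```python
from rudalle.pipelines import generate_images, cherry_pick_by_clip
from rudalle import get_rudalle_model, get_tokenizer, get_vae, get_ruclip
from rudalle.utils import seed_everything

device = 'cuda'
dalle = get_rudalle_model('Emojich', pretrained=True, fp16=True, device=device)
tokenizer = get_tokenizer()
vae = get_vae(dwt=True).to(device)
# assumed to return a (model, processor) pair; adjust to your ruCLIP checkpoint
ruclip, ruclip_processor = get_ruclip('ruclip-vit-base-patch32-v5')
ruclip = ruclip.to(device)

text = 'Дональд Трамп из лего'  # Donald Trump made of LEGO

seed_everything(42)
pil_images = []
for _ in range(512 // 16):  # ~512 attempts, generated in chunks of 16
    pil_images += generate_images(text, tokenizer, dalle, vae,
                                  top_k=2048, top_p=0.995, images_num=16, bs=8)[0]

# keep the candidates that ruCLIP ranks as closest to the text prompt
top_images, clip_scores = cherry_pick_by_clip(pil_images, text, ruclip, ruclip_processor,
                                              device=device, count=16)
```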
88 |
89 | ### Citation
90 | Feel free to cite our work if it is helpful for your research.
91 | ```
92 | @misc{shonenkov2021emojich,
93 | title={Emojich -- zero-shot emoji generation using Russian language: a technical report},
94 | author={Alex Shonenkov and Daria Bakshandaeva and Denis Dimitrov and Aleksandr Nikolich},
95 | year={2021},
96 | eprint={2112.02448},
97 | archivePrefix={arXiv},
98 | primaryClass={cs.CL}
99 | }
100 | ```
101 |
--------------------------------------------------------------------------------
/rudalle/realesrgan/rrdbnet_arch.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 | from torch import nn as nn
4 | from torch.nn import functional as F
5 |
6 | from .arch_util import default_init_weights, make_layer, pixel_unshuffle
7 |
8 |
9 | class ResidualDenseBlock(nn.Module):
10 | """Residual Dense Block.
11 | Used in RRDB block in ESRGAN.
12 | Args:
13 | num_feat (int): Channel number of intermediate features.
14 | num_grow_ch (int): Channels for each growth.
15 | """
16 |
17 | def __init__(self, num_feat=64, num_grow_ch=32):
18 | super(ResidualDenseBlock, self).__init__()
19 | self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
20 | self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
21 | self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
22 | self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
23 | self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
24 |
25 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
26 |
27 | # initialization
28 | default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
29 |
30 | def forward(self, x):
31 | x1 = self.lrelu(self.conv1(x))
32 | x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
33 | x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
34 | x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
35 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
36 |         # Empirically, we use 0.2 to scale the residual for better performance
37 | return x5 * 0.2 + x
38 |
39 |
40 | class RRDB(nn.Module):
41 | """Residual in Residual Dense Block.
42 | Used in RRDB-Net in ESRGAN.
43 | Args:
44 | num_feat (int): Channel number of intermediate features.
45 | num_grow_ch (int): Channels for each growth.
46 | """
47 |
48 | def __init__(self, num_feat, num_grow_ch=32):
49 | super(RRDB, self).__init__()
50 | self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
51 | self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
52 | self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
53 |
54 | def forward(self, x):
55 | out = self.rdb1(x)
56 | out = self.rdb2(out)
57 | out = self.rdb3(out)
58 |         # Empirically, we use 0.2 to scale the residual for better performance
59 | return out * 0.2 + x
60 |
61 |
62 | class RRDBNet(nn.Module):
63 | """Networks consisting of Residual in Residual Dense Block, which is used
64 | in ESRGAN.
65 | ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.
66 | We extend ESRGAN for scale x2 and scale x1.
67 | Note: This is one option for scale 1, scale 2 in RRDBNet.
68 |     We first employ pixel-unshuffle (an inverse operation of pixel-shuffle) to reduce the spatial size
69 |     and enlarge the channel size before feeding inputs into the main ESRGAN architecture.
70 | Args:
71 | num_in_ch (int): Channel number of inputs.
72 | num_out_ch (int): Channel number of outputs.
73 | num_feat (int): Channel number of intermediate features.
74 | Default: 64
75 |         num_block (int): Block number in the trunk network. Default: 23.
76 | num_grow_ch (int): Channels for each growth. Default: 32.
77 | """
78 |
79 | def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
80 | super(RRDBNet, self).__init__()
81 | self.scale = scale
82 | if scale == 2:
83 | num_in_ch = num_in_ch * 4
84 | elif scale == 1:
85 | num_in_ch = num_in_ch * 16
86 | self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
87 | self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
88 | self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
89 | # upsample
90 | self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
91 | self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
92 | if scale == 8:
93 | self.conv_up3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
94 | self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
95 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
96 |
97 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
98 |
99 | def forward(self, x):
100 | if self.scale == 2:
101 | feat = pixel_unshuffle(x, scale=2)
102 | elif self.scale == 1:
103 | feat = pixel_unshuffle(x, scale=4)
104 | else:
105 | feat = x
106 | feat = self.conv_first(feat)
107 | body_feat = self.conv_body(self.body(feat))
108 | feat = feat + body_feat
109 | # upsample
110 | feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
111 | feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
112 | if self.scale == 8:
113 | feat = self.lrelu(self.conv_up3(F.interpolate(feat, scale_factor=2, mode='nearest')))
114 | out = self.conv_last(self.lrelu(self.conv_hr(feat)))
115 | return out
116 |
--------------------------------------------------------------------------------
/rudalle/dalle/model.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 | import torch.nn.functional as F
4 | from einops import rearrange
5 |
6 | from .utils import exists, is_empty, init_method_normal
7 |
8 | from .transformer import DalleTransformer
9 |
10 |
11 | class DalleModel(torch.nn.Module):
12 | def __init__(self,
13 | device,
14 | num_layers,
15 | vocab_size,
16 | hidden_size,
17 | num_attention_heads,
18 | embedding_dropout_prob,
19 | attention_dropout_prob,
20 | output_dropout_prob,
21 | text_seq_length=128,
22 | image_tokens_per_dim=32,
23 | image_vocab_size=16384,
24 | loss_img_weight=7,
25 | cogview_sandwich_layernorm=False,
26 | cogview_pb_relax=False,
27 | is_bool_mask=True,
28 | mlp_activation='gelu_jit'):
29 | super(DalleModel, self).__init__()
30 | self.device = device
31 | self.image_tokens_per_dim = image_tokens_per_dim
32 | self.image_seq_length = image_tokens_per_dim ** 2
33 | self.text_seq_length = text_seq_length
34 | self.total_seq_length = self.text_seq_length + self.image_seq_length
35 | self.total_vocab_size = vocab_size + image_vocab_size
36 | self.vocab_size = vocab_size
37 | self.loss_img_weight = loss_img_weight
38 |
39 | init_method = init_method_normal(std=0.02)
40 |
41 | self.text_embeddings = torch.nn.Embedding(vocab_size, hidden_size)
42 | self.image_embeddings = torch.nn.Embedding(image_vocab_size, hidden_size)
43 |
44 | # Position embedding (serial).
45 | self.text_pos_embeddings = torch.nn.Embedding(text_seq_length + 1, hidden_size)
46 | self.image_row_embeddings = torch.nn.Embedding(image_tokens_per_dim, hidden_size)
47 | self.image_col_embeddings = torch.nn.Embedding(image_tokens_per_dim, hidden_size)
48 | init_method(self.text_pos_embeddings.weight)
49 | init_method(self.image_row_embeddings.weight)
50 | init_method(self.image_col_embeddings.weight)
51 |
52 | self.to_logits = torch.nn.Sequential(
53 | torch.nn.LayerNorm(hidden_size),
54 | torch.nn.Linear(hidden_size, self.total_vocab_size),
55 | )
56 |
57 | # Embeddings dropout
58 | self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
59 |
60 | # Transformer
61 | self.transformer = DalleTransformer(
62 | num_layers,
63 | hidden_size,
64 | num_attention_heads,
65 | attention_dropout_prob,
66 | output_dropout_prob,
67 | text_seq_length=text_seq_length,
68 | image_tokens_per_dim=image_tokens_per_dim,
69 | cogview_sandwich_layernorm=cogview_sandwich_layernorm,
70 | cogview_pb_relax=cogview_pb_relax,
71 | mlp_activation=mlp_activation,
72 | is_bool_mask=is_bool_mask,
73 | )
74 |
75 | def get_param(self, item):
76 | return getattr(self, item)
77 |
78 | def get_image_pos_embeddings(self, image_input_ids, past_length=0):
79 | input_shape = image_input_ids.size()
80 | row_ids = torch.arange(past_length, input_shape[-1] + past_length,
81 | dtype=torch.long, device=self.device) // self.image_tokens_per_dim
82 | row_ids = row_ids.unsqueeze(0).view(-1, input_shape[-1])
83 | col_ids = torch.arange(past_length, input_shape[-1] + past_length,
84 | dtype=torch.long, device=self.device) % self.image_tokens_per_dim
85 | col_ids = col_ids.unsqueeze(0).view(-1, input_shape[-1])
86 | return self.image_row_embeddings(row_ids) + self.image_col_embeddings(col_ids)
87 |
88 | def forward(
89 | self,
90 | input_ids,
91 | attention_mask,
92 | return_loss=False,
93 | has_cache=False,
94 | use_cache=False,
95 | ):
96 | text = input_ids[:, :self.text_seq_length]
97 | text_range = torch.arange(self.text_seq_length)
98 | text_range += (self.vocab_size - self.text_seq_length)
99 | text_range = text_range.to(self.device)
100 | text = torch.where(text == 0, text_range, text)
101 | # some hardcode :)
102 | text = F.pad(text, (1, 0), value=2)
103 | text_embeddings = self.text_embeddings(text) + \
104 | self.text_pos_embeddings(torch.arange(text.shape[1], device=self.device))
105 |
106 | image_input_ids = input_ids[:, self.text_seq_length:]
107 |
108 | if exists(image_input_ids) and not is_empty(image_input_ids):
109 | image_embeddings = self.image_embeddings(image_input_ids) + \
110 | self.get_image_pos_embeddings(image_input_ids, past_length=0)
111 | embeddings = torch.cat((text_embeddings, image_embeddings), dim=1)
112 | else:
113 | embeddings = text_embeddings
114 | # some hardcode :)
115 | if embeddings.shape[1] > self.total_seq_length:
116 | embeddings = embeddings[:, :-1]
117 |
118 |         alpha = 0.1  # scales gradients by 0.1 while leaving the forward values unchanged
119 | embeddings = embeddings * alpha + embeddings.detach() * (1-alpha)
120 |
121 | attention_mask = attention_mask[:, :, :embeddings.shape[1], :embeddings.shape[1]]
122 | transformer_output, present_has_cache = self.transformer(
123 | embeddings, attention_mask, has_cache=has_cache, use_cache=use_cache)
124 |
125 | logits = self.to_logits(transformer_output)
126 | if return_loss is False:
127 | return logits, present_has_cache
128 |
129 | labels = torch.cat((text[:, 1:], image_input_ids), dim=1).contiguous().long()
130 | logits = rearrange(logits, 'b n c -> b c n')
131 |
132 | text_logits = logits[:, :self.vocab_size, :self.text_seq_length].contiguous().float()
133 | image_logits = logits[:, self.vocab_size:, self.text_seq_length:].contiguous().float()
134 |
135 | loss_text = F.cross_entropy(
136 | text_logits,
137 | labels[:, :self.text_seq_length])
138 | loss_img = F.cross_entropy(
139 | image_logits,
140 | labels[:, self.text_seq_length:])
141 |
142 | loss = (loss_text + self.loss_img_weight * loss_img) / (self.loss_img_weight + 1)
143 | return loss, {'text': loss_text.data.detach().float(), 'image': loss_img.data.detach().float()}
144 |
145 | def to(self, device, *args, **kwargs):
146 | self.device = device
147 | return super().to(device, *args, **kwargs)
148 |
--------------------------------------------------------------------------------
/rudalle/realesrgan/arch_util.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import math
3 | import torch
4 | from torch import nn as nn
5 | from torch.nn import functional as F
6 | from torch.nn import init as init
7 | from torch.nn.modules.batchnorm import _BatchNorm
8 |
9 |
10 | @torch.no_grad()
11 | def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
12 | """Initialize network weights.
13 | Args:
14 | module_list (list[nn.Module] | nn.Module): Modules to be initialized.
15 | scale (float): Scale initialized weights, especially for residual
16 | blocks. Default: 1.
17 | bias_fill (float): The value to fill bias. Default: 0
18 | kwargs (dict): Other arguments for initialization function.
19 | """
20 | if not isinstance(module_list, list):
21 | module_list = [module_list]
22 | for module in module_list:
23 | for m in module.modules():
24 | if isinstance(m, nn.Conv2d):
25 | init.kaiming_normal_(m.weight, **kwargs)
26 | m.weight.data *= scale
27 | if m.bias is not None:
28 | m.bias.data.fill_(bias_fill)
29 | elif isinstance(m, nn.Linear):
30 | init.kaiming_normal_(m.weight, **kwargs)
31 | m.weight.data *= scale
32 | if m.bias is not None:
33 | m.bias.data.fill_(bias_fill)
34 | elif isinstance(m, _BatchNorm):
35 | init.constant_(m.weight, 1)
36 | if m.bias is not None:
37 | m.bias.data.fill_(bias_fill)
38 |
39 |
40 | def make_layer(basic_block, num_basic_block, **kwarg):
41 | """Make layers by stacking the same blocks.
42 | Args:
43 | basic_block (nn.module): nn.module class for basic block.
44 | num_basic_block (int): number of blocks.
45 | Returns:
46 | nn.Sequential: Stacked blocks in nn.Sequential.
47 | """
48 | layers = []
49 | for _ in range(num_basic_block):
50 | layers.append(basic_block(**kwarg))
51 | return nn.Sequential(*layers)
52 |
53 |
54 | class ResidualBlockNoBN(nn.Module):
55 | """Residual block without BN.
56 | It has a style of:
57 | ---Conv-ReLU-Conv-+-
58 | |________________|
59 | Args:
60 | num_feat (int): Channel number of intermediate features.
61 | Default: 64.
62 | res_scale (float): Residual scale. Default: 1.
63 | pytorch_init (bool): If set to True, use pytorch default init,
64 | otherwise, use default_init_weights. Default: False.
65 | """
66 |
67 | def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
68 | super(ResidualBlockNoBN, self).__init__()
69 | self.res_scale = res_scale
70 | self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
71 | self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
72 | self.relu = nn.ReLU(inplace=True)
73 |
74 | if not pytorch_init:
75 | default_init_weights([self.conv1, self.conv2], 0.1)
76 |
77 | def forward(self, x):
78 | identity = x
79 | out = self.conv2(self.relu(self.conv1(x)))
80 | return identity + out * self.res_scale
81 |
82 |
83 | class Upsample(nn.Sequential):
84 | """Upsample module.
85 | Args:
86 | scale (int): Scale factor. Supported scales: 2^n and 3.
87 | num_feat (int): Channel number of intermediate features.
88 | """
89 |
90 | def __init__(self, scale, num_feat):
91 | m = []
92 | if (scale & (scale - 1)) == 0: # scale = 2^n
93 | for _ in range(int(math.log(scale, 2))):
94 | m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
95 | m.append(nn.PixelShuffle(2))
96 | elif scale == 3:
97 | m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
98 | m.append(nn.PixelShuffle(3))
99 | else:
100 | raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
101 | super(Upsample, self).__init__(*m)
102 |
103 |
104 | def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
105 | """Warp an image or feature map with optical flow.
106 | Args:
107 | x (Tensor): Tensor with size (n, c, h, w).
108 | flow (Tensor): Tensor with size (n, h, w, 2), normal value.
109 | interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
110 | padding_mode (str): 'zeros' or 'border' or 'reflection'.
111 | Default: 'zeros'.
112 | align_corners (bool): Before pytorch 1.3, the default value is
113 | align_corners=True. After pytorch 1.3, the default value is
114 | align_corners=False. Here, we use the True as default.
115 | Returns:
116 | Tensor: Warped image or feature map.
117 | """
118 | assert x.size()[-2:] == flow.size()[1:3]
119 | _, _, h, w = x.size()
120 | # create mesh grid
121 | grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x))
122 | grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2
123 | grid.requires_grad = False
124 |
125 | vgrid = grid + flow
126 | # scale grid to [-1,1]
127 | vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
128 | vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
129 | vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
130 | output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)
131 |
132 | # TODO, what if align_corners=False
133 | return output
134 |
135 |
136 | def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
137 | """Resize a flow according to ratio or shape.
138 | Args:
139 | flow (Tensor): Precomputed flow. shape [N, 2, H, W].
140 | size_type (str): 'ratio' or 'shape'.
141 | sizes (list[int | float]): the ratio for resizing or the final output
142 | shape.
143 | 1) The order of ratio should be [ratio_h, ratio_w]. For
144 | downsampling, the ratio should be smaller than 1.0 (i.e., ratio
145 | < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
146 | ratio > 1.0).
147 | 2) The order of output_size should be [out_h, out_w].
148 | interp_mode (str): The mode of interpolation for resizing.
149 | Default: 'bilinear'.
150 | align_corners (bool): Whether align corners. Default: False.
151 | Returns:
152 | Tensor: Resized flow.
153 | """
154 | _, _, flow_h, flow_w = flow.size()
155 | if size_type == 'ratio':
156 | output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
157 | elif size_type == 'shape':
158 | output_h, output_w = sizes[0], sizes[1]
159 | else:
160 | raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')
161 |
162 | input_flow = flow.clone()
163 | ratio_h = output_h / flow_h
164 | ratio_w = output_w / flow_w
165 | input_flow[:, 0, :, :] *= ratio_w
166 | input_flow[:, 1, :, :] *= ratio_h
167 | resized_flow = F.interpolate(
168 | input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners)
169 | return resized_flow
170 |
171 |
172 | # TODO: may write a cpp file
173 | def pixel_unshuffle(x, scale):
174 | """ Pixel unshuffle.
175 | Args:
176 | x (Tensor): Input feature with shape (b, c, hh, hw).
177 | scale (int): Downsample ratio.
178 | Returns:
179 | Tensor: the pixel unshuffled feature.
180 | """
181 | b, c, hh, hw = x.size()
182 | out_channel = c * (scale**2)
183 | assert hh % scale == 0 and hw % scale == 0
184 | h = hh // scale
185 | w = hw // scale
186 | x_view = x.view(b, c, h, scale, w, scale)
187 | return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
188 |
--------------------------------------------------------------------------------
/rudalle/pipelines.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | from glob import glob
4 | from os.path import join
5 |
6 | import cv2
7 | import torch
8 | import torchvision
9 | import transformers
10 | import more_itertools
11 | import numpy as np
12 | import matplotlib.pyplot as plt
13 | from tqdm.auto import tqdm
14 | from PIL import Image
15 |
16 | from . import utils
17 |
18 |
19 | def generate_images(text, tokenizer, dalle, vae, top_k, top_p, images_num, image_prompts=None, temperature=1.0, bs=4,
20 | seed=None, use_cache=True, return_codes=False):
21 | # TODO docstring
22 | if seed is not None:
23 | utils.seed_everything(seed)
24 |
25 | vocab_size = dalle.get_param('vocab_size')
26 | text_seq_length = dalle.get_param('text_seq_length')
27 | image_seq_length = dalle.get_param('image_seq_length')
28 | total_seq_length = dalle.get_param('total_seq_length')
29 | device = dalle.get_param('device')
30 |
31 | text = text.lower().strip()
32 | input_ids = tokenizer.encode_text(text, text_seq_length=text_seq_length)
33 | pil_images, scores, codes = [], [], []
34 | for chunk in more_itertools.chunked(range(images_num), bs):
35 | chunk_bs = len(chunk)
36 | with torch.no_grad():
37 | attention_mask = torch.tril(torch.ones((chunk_bs, 1, total_seq_length, total_seq_length), device=device))
38 | out = input_ids.unsqueeze(0).repeat(chunk_bs, 1).to(device)
39 | has_cache = False
40 | sample_scores = []
41 | if image_prompts is not None:
42 | prompts_idx, prompts = image_prompts.image_prompts_idx, image_prompts.image_prompts
43 | prompts = prompts.repeat(chunk_bs, 1)
44 | for idx in tqdm(range(out.shape[1], total_seq_length)):
45 | idx -= text_seq_length
46 | if image_prompts is not None and idx in prompts_idx:
47 | out = torch.cat((out, prompts[:, idx].unsqueeze(1)), dim=-1)
48 | else:
49 | logits, has_cache = dalle(out, attention_mask,
50 | has_cache=has_cache, use_cache=use_cache, return_loss=False)
51 | logits = logits[:, -1, vocab_size:]
52 | logits /= temperature
53 | filtered_logits = transformers.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
54 | probs = torch.nn.functional.softmax(filtered_logits, dim=-1)
55 | sample = torch.multinomial(probs, 1)
56 | sample_scores.append(probs[torch.arange(probs.size(0)), sample.transpose(0, 1)])
57 | out = torch.cat((out, sample), dim=-1)
58 | codebooks = out[:, -image_seq_length:]
59 | images = vae.decode(codebooks)
60 | pil_images += utils.torch_tensors_to_pil_list(images)
61 | scores += torch.cat(sample_scores).sum(0).detach().cpu().numpy().tolist()
62 | for j in range(codebooks.shape[0]):
63 | codes.append(codebooks[j].detach().cpu().numpy())
64 |
65 | if return_codes:
66 | return pil_images, scores, codes
67 | return pil_images, scores
68 |
69 | def super_resolution(pil_images, realesrgan, batch_size=4):
70 | result = []
71 | for pil_image in pil_images:
72 | with torch.no_grad():
73 | sr_image = realesrgan.predict(np.array(pil_image), batch_size=batch_size)
74 | result.append(sr_image)
75 | return result
76 |
77 |
78 | def cherry_pick_by_clip(pil_images, text, ruclip, ruclip_processor, device='cpu', count=4):
79 | with torch.no_grad():
80 | inputs = ruclip_processor(text=text, images=pil_images)
81 | for key in inputs.keys():
82 | inputs[key] = inputs[key].to(device)
83 | outputs = ruclip(**inputs)
84 | sims = outputs.logits_per_image.view(-1).softmax(dim=0)
85 | items = []
86 | for index, sim in enumerate(sims.cpu().numpy()):
87 | items.append({'img_index': index, 'cosine': sim})
88 | items = sorted(items, key=lambda x: x['cosine'], reverse=True)[:count]
89 | top_pil_images = [pil_images[x['img_index']] for x in items]
90 | top_scores = [x['cosine'] for x in items]
91 | return top_pil_images, top_scores
92 |
93 |
94 | def show(pil_images, nrow=4, size=14, save_dir=None, show=True):
95 | """
96 | :param pil_images: list of images in PIL
97 | :param nrow: number of rows
98 | :param size: size of the images
99 | :param save_dir: dir for separately saving of images, example: save_dir='./pics'
100 | """
101 | if save_dir is not None:
102 | os.makedirs(save_dir, exist_ok=True)
103 | count = len(glob(join(save_dir, 'img_*.png')))
104 | for i, pil_image in enumerate(pil_images):
105 | pil_image.save(join(save_dir, f'img_{count+i}.png'))
106 |
107 | pil_images = [pil_image.convert('RGB') for pil_image in pil_images]
108 | imgs = torchvision.utils.make_grid(utils.pil_list_to_torch_tensors(pil_images), nrow=nrow)
109 | if not isinstance(imgs, list):
110 | imgs = [imgs.cpu()]
111 | fix, axs = plt.subplots(ncols=len(imgs), squeeze=False, figsize=(size, size))
112 | for i, img in enumerate(imgs):
113 | img = img.detach()
114 | img = torchvision.transforms.functional.to_pil_image(img)
115 | if save_dir is not None:
116 | count = len(glob(join(save_dir, 'group_*.png')))
117 | img.save(join(save_dir, f'group_{count+i}.png'))
118 | if show:
119 | axs[0, i].imshow(np.asarray(img))
120 | axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
121 | if show:
122 | fix.show()
123 | plt.show()
124 |
125 |
126 | def classic_convert_emoji_to_rgba(np_image, lower_thr=240, upper_thr=255, width=2):
127 | img = np_image[:, :, :3].copy()
128 | lower = np.array([lower_thr, lower_thr, lower_thr], dtype='uint8')
129 | upper = np.array([upper_thr, upper_thr, upper_thr], dtype='uint8')
130 | mask = cv2.inRange(img, lower, upper)
131 | ret, thresh = cv2.threshold(mask, 0, 255, 0)
132 | contours, hierarchy = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
133 | a_channel = np.ones((512, 512), dtype=np.uint8)*255
134 | if len(contours) != 0:
135 | contours = sorted(contours, key=lambda x: x.shape[0])[-7:]
136 | cv2.fillPoly(a_channel, contours, (0, 0, 0))
137 | cv2.drawContours(a_channel, contours, -1, (0, 0, 0), width)
138 | img = cv2.cvtColor(img, cv2.COLOR_RGB2RGBA)
139 | img[:, :, 3] = a_channel
140 | return img
141 |
142 |
143 | def convert_emoji_to_rgba(pil_images, emojich_unet, device='cpu', bs=1, score_thr=0.99):
144 | final_images, runs = [], []
145 | with torch.no_grad():
146 | for chunk in more_itertools.chunked(pil_images, bs):
147 | images = []
148 | for pil_image in chunk:
149 | image = np.array(pil_image.resize((512, 512)))[:, :, :3]
150 | image = image.astype(np.float32) / 255.0
151 | image = torch.from_numpy(image).permute(2, 0, 1)
152 | images.append(image)
153 | images = torch.nn.utils.rnn.pad_sequence(images, batch_first=True)
154 | pred_masks = emojich_unet(images.to(device))
155 | pred_masks = torch.softmax(pred_masks, 1)
156 | scores, pred_masks = torch.max(pred_masks, 1)
157 | pred_masks = pred_masks.int().cpu().numpy()
158 | pred_masks = (pred_masks * 255).astype(np.uint8)
159 | for pil_image, pred_mask, score in zip(chunk, pred_masks, scores):
160 | score = score.mean().item()
161 | final_image = np.zeros((512, 512, 4), np.uint8)
162 | final_image[:, :, :3] = np.array(pil_image.resize((512, 512)))[:, :, :3]
163 | if score > score_thr:
164 | run = 'unet'
165 | final_image[:, :, -1] = pred_mask
166 | else:
167 | run = 'classic'
168 | final_image = classic_convert_emoji_to_rgba(final_image)
169 | final_image = Image.fromarray(final_image)
170 | final_images.append(final_image)
171 | runs.append(run)
172 | return final_images, runs
173 |
174 |
175 | def show_rgba(rgba_pil_image):
176 | img = np.array(rgba_pil_image)
177 | fig, ax = plt.subplots(1, 3, figsize=(10, 10), dpi=100)
178 | ax[0].imshow(img[:, :, :3])
179 | ax[1].imshow(img[:, :, -1])
180 | mask = np.repeat(np.expand_dims(img[:, :, -1] < 128, -1), 3, axis=-1)
181 | img = img[:, :, :3]
182 | img[mask[:, :, 0], 0] = 64
183 | img[mask[:, :, 0], 1] = 255
184 | img[mask[:, :, 0], 2] = 64
185 | ax[2].imshow(img)
186 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [2020] [sberbank-ai]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/rudalle/vae/pytorch_wavelets_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Useful utilities for testing the 2-D DTCWT with synthetic images
4 | License: https://github.com/fbcotter/pytorch_wavelets/blob/master/LICENSE
5 | Source: https://github.com/fbcotter/pytorch_wavelets/blob/31d6ac1b51b08f811a6a70eb7b3440f106009da0/pytorch_wavelets/dwt/lowlevel.py # noqa
6 | """
7 |
8 | import pywt
9 | import torch
10 | import numpy as np
11 | import torch.nn.functional as F
12 | from torch.autograd import Function
13 |
14 |
15 | def sfb1d(lo, hi, g0, g1, mode='zero', dim=-1):
16 | """ 1D synthesis filter bank of an image tensor
17 | """
18 | C = lo.shape[1]
19 | d = dim % 4
20 | # If g0, g1 are not tensors, make them. If they are, then assume that they
21 | # are in the right order
22 | if not isinstance(g0, torch.Tensor):
23 | g0 = torch.tensor(np.copy(np.array(g0).ravel()),
24 | dtype=torch.float, device=lo.device)
25 | if not isinstance(g1, torch.Tensor):
26 | g1 = torch.tensor(np.copy(np.array(g1).ravel()),
27 | dtype=torch.float, device=lo.device)
28 | L = g0.numel()
29 | shape = [1, 1, 1, 1]
30 | shape[d] = L
31 | N = 2*lo.shape[d]
32 | # If g aren't in the right shape, make them so
33 | if g0.shape != tuple(shape):
34 | g0 = g0.reshape(*shape)
35 | if g1.shape != tuple(shape):
36 | g1 = g1.reshape(*shape)
37 |
38 | s = (2, 1) if d == 2 else (1, 2)
39 | g0 = torch.cat([g0]*C, dim=0)
40 | g1 = torch.cat([g1]*C, dim=0)
41 | if mode == 'per' or mode == 'periodization':
42 | y = F.conv_transpose2d(lo, g0, stride=s, groups=C) + \
43 | F.conv_transpose2d(hi, g1, stride=s, groups=C)
44 | if d == 2:
45 | y[:, :, :L-2] = y[:, :, :L-2] + y[:, :, N:N+L-2]
46 | y = y[:, :, :N]
47 | else:
48 | y[:, :, :, :L-2] = y[:, :, :, :L-2] + y[:, :, :, N:N+L-2]
49 | y = y[:, :, :, :N]
50 | y = roll(y, 1-L//2, dim=dim)
51 | else:
52 | if mode == 'zero' or mode == 'symmetric' or mode == 'reflect' or \
53 | mode == 'periodic':
54 | pad = (L-2, 0) if d == 2 else (0, L-2)
55 | y = F.conv_transpose2d(lo, g0, stride=s, padding=pad, groups=C) + \
56 | F.conv_transpose2d(hi, g1, stride=s, padding=pad, groups=C)
57 | else:
58 |             raise ValueError('Unknown pad type: {}'.format(mode))
59 |
60 | return y
61 |
62 |
63 | def _SFB2D(low, highs, g0_row, g1_row, g0_col, g1_col, mode):
64 | mode = int_to_mode(mode)
65 |
66 | lh, hl, hh = torch.unbind(highs, dim=2)
67 | lo = sfb1d(low, lh, g0_col, g1_col, mode=mode, dim=2)
68 | hi = sfb1d(hl, hh, g0_col, g1_col, mode=mode, dim=2)
69 | y = sfb1d(lo, hi, g0_row, g1_row, mode=mode, dim=3)
70 |
71 | return y
72 |
73 |
74 | def roll(x, n, dim, make_even=False):
75 | if n < 0:
76 | n = x.shape[dim] + n
77 |
78 | if make_even and x.shape[dim] % 2 == 1:
79 | end = 1
80 | else:
81 | end = 0
82 |
83 | if dim == 0:
84 | return torch.cat((x[-n:], x[:-n+end]), dim=0)
85 | elif dim == 1:
86 | return torch.cat((x[:, -n:], x[:, :-n+end]), dim=1)
87 | elif dim == 2 or dim == -2:
88 | return torch.cat((x[:, :, -n:], x[:, :, :-n+end]), dim=2)
89 | elif dim == 3 or dim == -1:
90 | return torch.cat((x[:, :, :, -n:], x[:, :, :, :-n+end]), dim=3)
91 |
92 |
93 | def int_to_mode(mode):
94 | if mode == 0:
95 | return 'zero'
96 | elif mode == 1:
97 | return 'symmetric'
98 | elif mode == 2:
99 | return 'periodization'
100 | elif mode == 3:
101 | return 'constant'
102 | elif mode == 4:
103 | return 'reflect'
104 | elif mode == 5:
105 | return 'replicate'
106 | elif mode == 6:
107 | return 'periodic'
108 | else:
109 |         raise ValueError('Unknown pad type: {}'.format(mode))
110 |
111 |
112 | def prep_filt_sfb2d(g0_col, g1_col, g0_row=None, g1_row=None, device=None):
113 | """
114 |     Prepares the filters to be of the right form for the sfb2d function. In
115 |     particular, makes the tensors the right shape. It does not mirror image them,
116 |     as sfb2d uses conv_transpose2d which acts like normal convolution.
117 | Inputs:
118 | g0_col (array-like): low pass column filter bank
119 | g1_col (array-like): high pass column filter bank
120 | g0_row (array-like): low pass row filter bank. If none, will assume the
121 | same as column filter
122 | g1_row (array-like): high pass row filter bank. If none, will assume the
123 | same as column filter
124 | device: which device to put the tensors on to
125 | Returns:
126 | (g0_col, g1_col, g0_row, g1_row)
127 | """
128 | g0_col, g1_col = prep_filt_sfb1d(g0_col, g1_col, device)
129 | if g0_row is None:
130 | g0_row, g1_row = g0_col, g1_col
131 | else:
132 | g0_row, g1_row = prep_filt_sfb1d(g0_row, g1_row, device)
133 |
134 | g0_col = g0_col.reshape((1, 1, -1, 1))
135 | g1_col = g1_col.reshape((1, 1, -1, 1))
136 | g0_row = g0_row.reshape((1, 1, 1, -1))
137 | g1_row = g1_row.reshape((1, 1, 1, -1))
138 |
139 | return g0_col, g1_col, g0_row, g1_row
140 |
141 |
142 | def prep_filt_sfb1d(g0, g1, device=None):
143 | """
144 |     Prepares the filters to be of the right form for the sfb1d function. In
145 |     particular, makes the tensors the right shape. It does not mirror image them,
146 |     as sfb1d uses conv_transpose2d which acts like normal convolution.
147 | Inputs:
148 | g0 (array-like): low pass filter bank
149 | g1 (array-like): high pass filter bank
150 | device: which device to put the tensors on to
151 | Returns:
152 | (g0, g1)
153 | """
154 | g0 = np.array(g0).ravel()
155 | g1 = np.array(g1).ravel()
156 | t = torch.get_default_dtype()
157 | g0 = torch.tensor(g0, device=device, dtype=t).reshape((1, 1, -1))
158 | g1 = torch.tensor(g1, device=device, dtype=t).reshape((1, 1, -1))
159 |
160 | return g0, g1
161 |
162 |
163 | def mode_to_int(mode):
164 | if mode == 'zero':
165 | return 0
166 | elif mode == 'symmetric':
167 | return 1
168 | elif mode == 'per' or mode == 'periodization':
169 | return 2
170 | elif mode == 'constant':
171 | return 3
172 | elif mode == 'reflect':
173 | return 4
174 | elif mode == 'replicate':
175 | return 5
176 | elif mode == 'periodic':
177 | return 6
178 | else:
179 |         raise ValueError('Unknown pad type: {}'.format(mode))
180 |
181 |
182 | def afb1d(x, h0, h1, mode='zero', dim=-1):
183 | """ 1D analysis filter bank (along one dimension only) of an image
184 | Inputs:
185 | x (tensor): 4D input with the last two dimensions the spatial input
186 | h0 (tensor): 4D input for the lowpass filter. Should have shape (1, 1,
187 | h, 1) or (1, 1, 1, w)
188 | h1 (tensor): 4D input for the highpass filter. Should have shape (1, 1,
189 | h, 1) or (1, 1, 1, w)
190 | mode (str): padding method
191 |         dim (int): dimension of filtering. d=2 is for a vertical filter (called
192 |             column filtering but filters across the rows). d=3 is for a
193 |             horizontal filter (called row filtering but filters across the
194 | columns).
195 | Returns:
196 | lohi: lowpass and highpass subbands concatenated along the channel
197 | dimension
198 | """
199 | C = x.shape[1]
200 | # Convert the dim to positive
201 | d = dim % 4
202 | s = (2, 1) if d == 2 else (1, 2)
203 | N = x.shape[d]
204 | # If h0, h1 are not tensors, make them. If they are, then assume that they
205 | # are in the right order
206 | if not isinstance(h0, torch.Tensor):
207 | h0 = torch.tensor(np.copy(np.array(h0).ravel()[::-1]),
208 | dtype=torch.float, device=x.device)
209 | if not isinstance(h1, torch.Tensor):
210 | h1 = torch.tensor(np.copy(np.array(h1).ravel()[::-1]),
211 | dtype=torch.float, device=x.device)
212 | L = h0.numel()
213 | L2 = L // 2
214 | shape = [1, 1, 1, 1]
215 | shape[d] = L
216 | # If h aren't in the right shape, make them so
217 | if h0.shape != tuple(shape):
218 | h0 = h0.reshape(*shape)
219 | if h1.shape != tuple(shape):
220 | h1 = h1.reshape(*shape)
221 | h = torch.cat([h0, h1] * C, dim=0)
222 |
223 | if mode == 'per' or mode == 'periodization':
224 | if x.shape[dim] % 2 == 1:
225 | if d == 2:
226 | x = torch.cat((x, x[:, :, -1:]), dim=2)
227 | else:
228 | x = torch.cat((x, x[:, :, :, -1:]), dim=3)
229 | N += 1
230 | x = roll(x, -L2, dim=d)
231 | pad = (L-1, 0) if d == 2 else (0, L-1)
232 | lohi = F.conv2d(x, h, padding=pad, stride=s, groups=C)
233 | N2 = N//2
234 | if d == 2:
235 | lohi[:, :, :L2] = lohi[:, :, :L2] + lohi[:, :, N2:N2+L2]
236 | lohi = lohi[:, :, :N2]
237 | else:
238 | lohi[:, :, :, :L2] = lohi[:, :, :, :L2] + lohi[:, :, :, N2:N2+L2]
239 | lohi = lohi[:, :, :, :N2]
240 | else:
241 | # Calculate the pad size
242 | outsize = pywt.dwt_coeff_len(N, L, mode=mode)
243 | p = 2 * (outsize - 1) - N + L
244 | if mode == 'zero':
245 | # Sadly, pytorch only allows for same padding before and after, if
246 | # we need to do more padding after for odd length signals, have to
247 | # prepad
248 | if p % 2 == 1:
249 | pad = (0, 0, 0, 1) if d == 2 else (0, 1, 0, 0)
250 | x = F.pad(x, pad)
251 | pad = (p//2, 0) if d == 2 else (0, p//2)
252 | # Calculate the high and lowpass
253 | lohi = F.conv2d(x, h, padding=pad, stride=s, groups=C)
254 | elif mode == 'symmetric' or mode == 'reflect' or mode == 'periodic':
255 | pad = (0, 0, p//2, (p+1)//2) if d == 2 else (p//2, (p+1)//2, 0, 0)
256 | x = mypad(x, pad=pad, mode=mode)
257 | lohi = F.conv2d(x, h, stride=s, groups=C)
258 | else:
259 |             raise ValueError('Unknown pad type: {}'.format(mode))
260 |
261 | return lohi
262 |
263 |
264 | def mypad(x, pad, mode='constant', value=0):
265 | """ Function to do numpy like padding on tensors. Only works for 2-D
266 | padding.
267 | Inputs:
268 | x (tensor): tensor to pad
269 | pad (tuple): tuple of (left, right, top, bottom) pad sizes
270 |         mode (str): 'symmetric', 'periodic', 'constant', 'reflect', 'replicate',
271 |             or 'zero'. The padding technique.
272 | """
273 | if mode == 'symmetric':
274 | # Vertical only
275 | if pad[0] == 0 and pad[1] == 0:
276 | m1, m2 = pad[2], pad[3]
277 | l = x.shape[-2] # noqa
278 | xe = reflect(np.arange(-m1, l+m2, dtype='int32'), -0.5, l-0.5)
279 | return x[:, :, xe]
280 | # horizontal only
281 | elif pad[2] == 0 and pad[3] == 0:
282 | m1, m2 = pad[0], pad[1]
283 | l = x.shape[-1] # noqa
284 | xe = reflect(np.arange(-m1, l+m2, dtype='int32'), -0.5, l-0.5)
285 | return x[:, :, :, xe]
286 | # Both
287 | else:
288 | m1, m2 = pad[0], pad[1]
289 | l1 = x.shape[-1]
290 | xe_row = reflect(np.arange(-m1, l1+m2, dtype='int32'), -0.5, l1-0.5)
291 | m1, m2 = pad[2], pad[3]
292 | l2 = x.shape[-2]
293 | xe_col = reflect(np.arange(-m1, l2+m2, dtype='int32'), -0.5, l2-0.5)
294 | i = np.outer(xe_col, np.ones(xe_row.shape[0]))
295 | j = np.outer(np.ones(xe_col.shape[0]), xe_row)
296 | return x[:, :, i, j]
297 | elif mode == 'periodic':
298 | # Vertical only
299 | if pad[0] == 0 and pad[1] == 0:
300 | xe = np.arange(x.shape[-2])
301 | xe = np.pad(xe, (pad[2], pad[3]), mode='wrap')
302 | return x[:, :, xe]
303 | # Horizontal only
304 | elif pad[2] == 0 and pad[3] == 0:
305 | xe = np.arange(x.shape[-1])
306 | xe = np.pad(xe, (pad[0], pad[1]), mode='wrap')
307 | return x[:, :, :, xe]
308 | # Both
309 | else:
310 | xe_col = np.arange(x.shape[-2])
311 | xe_col = np.pad(xe_col, (pad[2], pad[3]), mode='wrap')
312 | xe_row = np.arange(x.shape[-1])
313 | xe_row = np.pad(xe_row, (pad[0], pad[1]), mode='wrap')
314 | i = np.outer(xe_col, np.ones(xe_row.shape[0]))
315 | j = np.outer(np.ones(xe_col.shape[0]), xe_row)
316 | return x[:, :, i, j]
317 |
318 | elif mode == 'constant' or mode == 'reflect' or mode == 'replicate':
319 | return F.pad(x, pad, mode, value)
320 | elif mode == 'zero':
321 | return F.pad(x, pad)
322 | else:
323 |         raise ValueError('Unknown pad type: {}'.format(mode))
324 |
325 |
326 | def reflect(x, minx, maxx):
327 | """Reflect the values in matrix *x* about the scalar values *minx* and
328 | *maxx*. Hence a vector *x* containing a long linearly increasing series is
329 | converted into a waveform which ramps linearly up and down between *minx*
330 | and *maxx*. If *x* contains integers and *minx* and *maxx* are (integers +
331 | 0.5), the ramps will have repeated max and min samples.
332 |     .. codeauthor:: Rich Wareham, Aug 2013
333 | .. codeauthor:: Nick Kingsbury, Cambridge University, January 1999.
334 | """
335 | x = np.asanyarray(x)
336 | rng = maxx - minx
337 | rng_by_2 = 2 * rng
338 | mod = np.fmod(x - minx, rng_by_2)
339 | normed_mod = np.where(mod < 0, mod + rng_by_2, mod)
340 | out = np.where(normed_mod >= rng, rng_by_2 - normed_mod, normed_mod) + minx
341 | return np.array(out, dtype=x.dtype)
342 |
343 |
344 | class SFB2D(Function):
345 |     """ Does a single level 2d wavelet reconstruction of wavelet coefficients. Does
346 |     separate row and column filtering by two calls to :py:func:`pytorch_wavelets.dwt.lowlevel.sfb1d`
347 |     Needs to have the tensors in the right form. Because this function defines
348 |     its own backward pass, saves on memory by not having to save the input
349 |     tensors.
350 |     Inputs:
351 |         low (torch.Tensor): lowpass coefficients
352 |         highs (torch.Tensor): highpass coefficients (LH, HL, HH) stacked along dim 2
353 |         g0_row: row lowpass synthesis filter
354 |         g1_row: row highpass synthesis filter
355 |         g0_col: col lowpass synthesis filter
356 |         g1_col: col highpass synthesis filter
357 |         mode (int): use mode_to_int to get the int code here
358 |     We encode the mode as an integer rather than a string as gradcheck causes an
359 |     error when a string is provided.
360 |     Returns:
361 |         y: reconstructed tensor, roughly of shape (N, C, 2*H, 2*W)
362 |     """
363 | @staticmethod
364 | def forward(ctx, low, highs, g0_row, g1_row, g0_col, g1_col, mode):
365 | mode = int_to_mode(mode)
366 | ctx.mode = mode
367 | ctx.save_for_backward(g0_row, g1_row, g0_col, g1_col)
368 |
369 | lh, hl, hh = torch.unbind(highs, dim=2)
370 | lo = sfb1d(low, lh, g0_col, g1_col, mode=mode, dim=2)
371 | hi = sfb1d(hl, hh, g0_col, g1_col, mode=mode, dim=2)
372 | y = sfb1d(lo, hi, g0_row, g1_row, mode=mode, dim=3)
373 | return y
374 |
375 | @staticmethod
376 | def backward(ctx, dy):
377 | dlow, dhigh = None, None
378 | if ctx.needs_input_grad[0]:
379 | mode = ctx.mode
380 | g0_row, g1_row, g0_col, g1_col = ctx.saved_tensors
381 | dx = afb1d(dy, g0_row, g1_row, mode=mode, dim=3)
382 | dx = afb1d(dx, g0_col, g1_col, mode=mode, dim=2)
383 | s = dx.shape
384 | dx = dx.reshape(s[0], -1, 4, s[-2], s[-1])
385 | dlow = dx[:, :, 0].contiguous()
386 | dhigh = dx[:, :, 1:].contiguous()
387 | return dlow, dhigh, None, None, None, None, None
388 |
--------------------------------------------------------------------------------
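A minimal sketch of how the synthesis helpers above fit together, assuming the module is imported as rudalle.vae.pytorch_wavelets_utils and using pywt's 'db1' wavelet with random coefficients purely for illustration (this snippet is not part of the repository files). SFB2D stitches the lowpass band and the stacked LH/HL/HH bands back into an image-sized tensor:

import pywt
import torch

from rudalle.vae.pytorch_wavelets_utils import SFB2D, mode_to_int, prep_filt_sfb2d

wave = pywt.Wavelet('db1')                    # illustrative wavelet choice
# Synthesis (reconstruction) filters, shaped for conv_transpose2d.
g0_col, g1_col, g0_row, g1_row = prep_filt_sfb2d(wave.rec_lo, wave.rec_hi)

low = torch.randn(1, 3, 16, 16)               # lowpass band: (N, C, H, W)
highs = torch.randn(1, 3, 3, 16, 16)          # LH, HL, HH stacked along dim 2

# SFB2D takes the padding mode as an int (see mode_to_int / int_to_mode).
y = SFB2D.apply(low, highs, g0_row, g1_row, g0_col, g1_col, mode_to_int('zero'))
print(y.shape)                                # torch.Size([1, 3, 32, 32]) for 'db1'

The mode is passed as an integer because SFB2D defines a custom backward pass and gradcheck cannot handle string arguments, as noted in its docstring.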
/rudalle/dalle/transformer.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import math
3 |
4 | import torch
5 | from torch.nn import LayerNorm
6 |
7 | from .utils import divide, split_tensor_along_last_dim
8 | from .image_attention import get_conv_mask, get_row_mask, get_col_mask
9 |
10 |
11 | def gelu(x):
12 | return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))
13 |
14 |
15 | @torch.jit.script
16 | def gelu_jit(x):
17 | """OpenAI's gelu implementation."""
18 | return gelu(x)
19 |
20 |
21 | class DalleTransformer(torch.nn.Module):
22 | """
23 |     This module takes input from the embedding layer and its output can
24 | be used directly by a logit layer. It consists of L (num-layers)
25 | blocks of:
26 | layer norm
27 | self attention
28 | residual connection
29 | layer norm
30 | mlp
31 | residual connection
32 | followed by a final layer norm.
33 |
34 | Arguments:
35 | num_layers: Number of transformer layers.
36 | hidden_size: The hidden size of the self attention.
37 |         num_attention_heads: number of attention heads in the self
38 | attention.
39 | attention_dropout_prob: dropout probability of the attention
40 | score in self attention.
41 | output_dropout_prob: dropout probability for the outputs
42 | after self attention and final output.
43 | layernorm_epsilon: epsilon used in layernorm to avoid
44 | division by zero.
45 | """
46 | _mask_map = []
47 |
48 | def __init__(self, num_layers, hidden_size, num_attention_heads, attention_dropout_prob, output_dropout_prob,
49 | text_seq_length, image_tokens_per_dim, layernorm_epsilon=1.0e-5,
50 | cogview_sandwich_layernorm=False, cogview_pb_relax=False, mlp_activation='gelu_jit',
51 | is_bool_mask=False):
52 | super(DalleTransformer, self).__init__()
53 |
54 | self.num_layers = num_layers
55 |         # CogView training stabilization features, see section 2.4 of https://arxiv.org/pdf/2105.13290.pdf
56 | self.cogview_pb_relax = cogview_pb_relax
57 |
58 | # Transformer layers.
59 | self.layers = torch.nn.ModuleList([
60 | DalleTransformerLayer(
61 | hidden_size,
62 | num_attention_heads,
63 | attention_dropout_prob,
64 | output_dropout_prob,
65 | layernorm_epsilon,
66 | cogview_sandwich_layernorm=cogview_sandwich_layernorm,
67 | cogview_pb_relax=cogview_pb_relax,
68 | mlp_activation=mlp_activation,
69 | ) for _ in range(num_layers)
70 | ])
71 |
72 | row_mask = get_row_mask(text_seq_length, image_tokens_per_dim, is_bool_mask=is_bool_mask)
73 | col_mask = get_col_mask(text_seq_length, image_tokens_per_dim, is_bool_mask=is_bool_mask)
74 | conv_mask = get_conv_mask(text_seq_length, image_tokens_per_dim, is_bool_mask=is_bool_mask)
75 | self.register_buffer('row_mask', row_mask)
76 | self.register_buffer('col_mask', col_mask)
77 | self.register_buffer('conv_mask', conv_mask)
78 |
79 | # Final layer norm before output.
80 | self.final_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
81 |
82 | def _get_layer_mask(self, layer_id):
83 |         if (layer_id - 1) % 4 == 0:
84 | layer_mask = self.col_mask
85 | elif layer_id != self.num_layers - 1:
86 | layer_mask = self.row_mask
87 | else:
88 | layer_mask = self.conv_mask
89 | return layer_mask
90 |
91 | def forward(self, hidden_states, attention_mask, has_cache, use_cache):
92 | for i, layer in enumerate(self.layers):
93 | mask = attention_mask
94 | layer_mask = self._get_layer_mask(i)[:mask.size(2), :mask.size(3)]
95 | mask = torch.mul(attention_mask, layer_mask)
96 | hidden_states, present_has_cache = layer(hidden_states, mask, has_cache=has_cache, use_cache=use_cache)
97 | output = self.final_layernorm(hidden_states)
98 | return output, present_has_cache
99 |
100 |
101 | class DalleTransformerLayer(torch.nn.Module):
102 | """
103 | A single layer transformer.
104 |
105 | We use the following notation:
106 | h: hidden size
107 | n: number of attention heads
108 | b: batch size
109 | s: sequence length
110 | Transformer layer takes input with size [b, s, h] and returns an
111 | output of the same size.
112 |
113 | Arguments:
114 | hidden_size: The hidden size of the self attention.
115 |         num_attention_heads: number of attention heads in the self
116 | attention.
117 | attention_dropout_prob: dropout probability of the attention
118 | score in self attention.
119 | output_dropout_prob: dropout probability for the outputs
120 | after self attention and final output.
121 | layernorm_epsilon: epsilon used in layernorm to avoid
122 | division by zero.
123 | """
124 |
125 | def __init__(self,
126 | hidden_size,
127 | num_attention_heads,
128 | attention_dropout_prob,
129 | output_dropout_prob,
130 | layernorm_epsilon,
131 | cogview_sandwich_layernorm=False,
132 | cogview_pb_relax=False,
133 | mlp_activation='gelu_jit'):
134 | super(DalleTransformerLayer, self).__init__()
135 |
136 |         # CogView training stabilization features, see section 2.4 of https://arxiv.org/pdf/2105.13290.pdf
137 | self.cogview_sandwich_layernorm = cogview_sandwich_layernorm
138 | self.cogview_pb_relax = cogview_pb_relax
139 |
140 | # Layernorm on the input data.
141 | self.input_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
142 |
143 | if self.cogview_sandwich_layernorm:
144 | self.before_first_addition_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
145 | self.before_second_addition_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
146 |
147 | # Self attention.
148 | self.attention = DalleSelfAttention(
149 | hidden_size,
150 | num_attention_heads,
151 | attention_dropout_prob,
152 | output_dropout_prob,
153 | cogview_pb_relax=cogview_pb_relax
154 | )
155 |
156 |         # Layernorm after the self attention.
157 | self.post_attention_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
158 |
159 | # MLP
160 | self.mlp = DalleMLP(hidden_size, output_dropout_prob, activation=mlp_activation)
161 |
162 | def forward(self, hidden_states, ltor_mask, has_cache, use_cache):
163 | # hidden_states: [b, s, h]
164 | # ltor_mask: [1, 1, s, s]
165 |
166 |         # Layer norm at the beginning of the transformer layer.
167 | layernorm_output = self.input_layernorm(hidden_states)
168 |
169 | # Self attention.
170 | attention_output, att_has_cache = self.attention(
171 | layernorm_output, ltor_mask, has_cache=has_cache, use_cache=use_cache)
172 |
173 | if self.cogview_sandwich_layernorm:
174 | attention_output = self.before_first_addition_layernorm(attention_output)
175 |
176 | # Residual connection.
177 | layernorm_input = hidden_states + attention_output
178 |
179 | # Layer norm post the self attention.
180 | layernorm_output = self.post_attention_layernorm(layernorm_input)
181 |
182 | # MLP.
183 | mlp_output, mlp_has_cache = self.mlp(
184 | layernorm_output, has_cache=has_cache, use_cache=use_cache)
185 |
186 | if self.cogview_sandwich_layernorm:
187 | mlp_output = self.before_second_addition_layernorm(mlp_output)
188 |
189 | # Second residual connection.
190 | output = layernorm_input + mlp_output
191 |
192 | return output, att_has_cache and mlp_has_cache
193 |
194 |
195 | class DalleSelfAttention(torch.nn.Module):
196 | """
197 | Self-attention layer takes input with size [b, s, h] where b is
198 | the batch size, s is the sequence length, and h is the hidden size
199 | and creates output of the same size.
200 | Arguments:
201 | hidden_size: total hidden size of the layer (h).
202 | num_attention_heads: number of attention heads (n). Note that we
203 | require n to be divisible by number of GPUs
204 | used to parallelize the model. Also, we
205 | require hidden size to be divisible by n.
206 | attention_dropout_prob: dropout probability for the attention scores.
207 | output_dropout_prob: dropout probability for the output.
208 | We use the following notation:
209 | h: hidden_size
210 | n: num_attention_heads
211 | p: number of partitions
212 | np: n/p
213 | hp: h/p
214 | hn: h/n
215 | b: batch size
216 | s: sequence length
217 | """
218 |
219 | def __init__(self, hidden_size, num_attention_heads,
220 | attention_dropout_prob, output_dropout_prob, cogview_pb_relax=False):
221 | super(DalleSelfAttention, self).__init__()
222 |
223 |         # CogView training stabilization features, see section 2.4 of https://arxiv.org/pdf/2105.13290.pdf
224 | self.cogview_pb_relax = cogview_pb_relax
225 |
226 | self.hidden_size = hidden_size
227 | self.num_attention_heads = num_attention_heads
228 | self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads)
229 |
230 | self.query_key_value = torch.nn.Linear(hidden_size, 3 * hidden_size)
231 | self.attention_dropout = torch.nn.Dropout(attention_dropout_prob)
232 |
233 | # Output.
234 | self.dense = torch.nn.Linear(hidden_size, hidden_size)
235 | self.output_dropout = torch.nn.Dropout(output_dropout_prob)
236 |
237 | # Cache
238 | self.past_key = None
239 | self.past_value = None
240 | self.past_output = None
241 |
242 | def _transpose_for_scores(self, tensor):
243 | """ Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with size [b, np, s, hn]. """
244 | new_tensor_shape = tensor.size()[:-1] + (self.num_attention_heads, self.hidden_size_per_attention_head)
245 | tensor = tensor.view(*new_tensor_shape)
246 | return tensor.permute(0, 2, 1, 3)
247 |
248 | def _calculate_attention_scores(self, query_layer, key_layer, ltor_mask):
249 | key_t = key_layer.transpose(-1, -2)
250 | if self.cogview_pb_relax:
251 | attention_scores = torch.matmul(
252 | query_layer / math.sqrt(self.hidden_size_per_attention_head),
253 | key_t
254 | )
255 | else:
256 | attention_scores = torch.matmul(query_layer, key_t) / math.sqrt(self.hidden_size_per_attention_head)
257 | ltor_mask = ltor_mask[:, :, -attention_scores.shape[-2]:]
258 | attention_scores = torch.mul(attention_scores, ltor_mask) - 10000.0 * (1.0 - ltor_mask)
259 | if self.cogview_pb_relax:
260 | # normalize attention scores. Should not affect resulting softmax value
261 | alpha = 32
262 | attention_scores_scaled = attention_scores / alpha
263 | attention_scores_scaled_maxes, _ = attention_scores_scaled.detach().view(
264 | [attention_scores.size(0), attention_scores.size(1), -1]
265 | ).max(dim=-1) # max per head per sample
266 | attention_scores_scaled_maxes = attention_scores_scaled_maxes.unsqueeze(-1).unsqueeze(-1).expand(
267 | [-1, -1, attention_scores.size(2), attention_scores.size(3)]
268 | ) # expand to [b, np, s, s]
269 | attention_scores = (attention_scores_scaled - attention_scores_scaled_maxes) * alpha
270 | return attention_scores
271 |
272 |     def forward(self, hidden_states, ltor_mask, has_cache=False, use_cache=False):
273 | # hidden_states: [b, s, h]
274 | # ltor_mask: [1, 1, s, s]
275 | # Attention heads. [b, s, hp]
276 | if has_cache and use_cache:
277 | mixed_x_layer = self.query_key_value(hidden_states[:, self.past_key.shape[-2]:, :])
278 | else:
279 | mixed_x_layer = self.query_key_value(hidden_states)
280 |
281 | (mixed_query_layer,
282 | mixed_key_layer,
283 | mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
284 |
285 | query_layer = self._transpose_for_scores(mixed_query_layer)
286 | key_layer = self._transpose_for_scores(mixed_key_layer)
287 | value_layer = self._transpose_for_scores(mixed_value_layer)
288 |
289 |         # Could be simplified, but kept explicit for readability
290 | if use_cache and has_cache:
291 | key_layer = torch.cat((self.past_key, key_layer), dim=-2)
292 | value_layer = torch.cat((self.past_value, value_layer), dim=-2)
293 | attention_scores = self._calculate_attention_scores(
294 | query_layer=query_layer, key_layer=key_layer, ltor_mask=ltor_mask
295 | )
296 | else:
297 | attention_scores = self._calculate_attention_scores(
298 | query_layer=query_layer, key_layer=key_layer, ltor_mask=ltor_mask
299 | )
300 |
301 | if use_cache and has_cache:
302 | extra_cache_size = hidden_states.shape[-2] - self.past_key.shape[-2]
303 | attention_scores = attention_scores[..., -extra_cache_size:, :]
304 |
305 | if use_cache:
306 | self.past_key = key_layer
307 | self.past_value = value_layer
308 | else:
309 | self.past_key = None
310 | self.past_value = None
311 | self.past_output = None
312 | has_cache = False
313 |
314 | # Attention probabilities. [b, np, s, s]
315 | attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)
316 |
317 | # This is actually dropping out entire tokens to attend to, which might
318 | # seem a bit unusual, but is taken from the original Transformer paper.
319 | attention_probs = self.attention_dropout(attention_probs)
320 |
321 | # Context layer.
322 | # [b, np, s, hn]
323 | context_layer = torch.matmul(attention_probs, value_layer)
324 |
325 | # [b, s, np, hn]
326 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
327 |
328 | new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,)
329 | # [b, s, hp]
330 | context_layer = context_layer.view(*new_context_layer_shape)
331 |
332 | # Output. [b, s, h]
333 | output = self.dense(context_layer)
334 |
335 | if use_cache:
336 |             # Could be simplified, but kept explicit for readability
337 | if has_cache:
338 | output = torch.cat((self.past_output, output), dim=-2)
339 | self.past_output = output
340 | else:
341 | self.past_output = output
342 | has_cache = True
343 |
344 | output = self.output_dropout(output)
345 | return output, has_cache
346 |
347 |
348 | class DalleMLP(torch.nn.Module):
349 | """
350 | MLP will take the input with h hidden state, project it to 4*h
351 | hidden dimension, perform gelu transformation, and project the
352 | state back into h hidden dimension. At the end, dropout is also
353 | applied.
354 | Arguments:
355 | hidden_size: The hidden size of the self attention.
356 | output_dropout_prob: dropout probability for the outputs
357 | after self attention and final output.
358 | """
359 |
360 | def __init__(self, hidden_size, output_dropout_prob, activation='gelu_jit'):
361 | super(DalleMLP, self).__init__()
362 | self.activation = activation
363 | # Project to 4h.
364 | self.dense_h_to_4h = torch.nn.Linear(hidden_size, 4 * hidden_size)
365 | # Project back to h.
366 | self.dense_4h_to_h = torch.nn.Linear(4 * hidden_size, hidden_size)
367 | self.dropout = torch.nn.Dropout(output_dropout_prob)
368 | # MLP cache
369 | self.past_x = None
370 |
371 | def forward(self, hidden_states, has_cache=False, use_cache=False):
372 | if has_cache and use_cache:
373 | hidden_states = hidden_states[:, self.past_x.shape[-2]:]
374 |
375 | # [b, s, 4hp]
376 | x = self.dense_h_to_4h(hidden_states)
377 | if self.activation == 'gelu_jit':
378 | x = gelu_jit(x)
379 | elif self.activation == 'gelu':
380 | x = gelu(x)
381 | else:
382 |             raise NotImplementedError('MLP activation {} is not implemented.'.format(self.activation))
383 | # [b, s, h]
384 | x = self.dense_4h_to_h(x)
385 | if use_cache:
386 |             # Could be simplified, but kept explicit for readability
387 | if has_cache:
388 | x = torch.cat((self.past_x, x), dim=-2)
389 | self.past_x = x
390 | else:
391 | self.past_x = x
392 |
393 | has_cache = True
394 | else:
395 | self.past_x = None
396 | has_cache = False
397 | output = self.dropout(x)
398 |
399 | return output, has_cache
400 |
--------------------------------------------------------------------------------
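A minimal sketch of the key/value/output caching used for incremental (autoregressive) decoding, assuming DalleSelfAttention is imported from rudalle.dalle.transformer and using toy sizes (this snippet is not part of the repository files). After the first call with use_cache=True, only the newly appended token is projected through query_key_value on the next call, while cached keys, values, and outputs are reused:

import torch

from rudalle.dalle.transformer import DalleSelfAttention

attn = DalleSelfAttention(hidden_size=64, num_attention_heads=4,
                          attention_dropout_prob=0.0, output_dropout_prob=0.0)
attn.eval()

x = torch.randn(1, 8, 64)                  # [b, s, h]
mask = torch.tril(torch.ones(1, 1, 8, 8))  # left-to-right (causal) mask

with torch.no_grad():
    # First pass primes past_key / past_value / past_output.
    out, has_cache = attn(x, mask, has_cache=False, use_cache=True)

    # Append one token; only that token goes through query_key_value again.
    x = torch.cat([x, torch.randn(1, 1, 64)], dim=1)
    mask = torch.tril(torch.ones(1, 1, 9, 9))
    out, has_cache = attn(x, mask, has_cache=has_cache, use_cache=True)

print(out.shape)  # torch.Size([1, 9, 64]): cached outputs plus the new token

Calling the layer with use_cache=False clears past_key, past_value and past_output, which is how a fresh generation would start.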