├── .gitignore
├── .pre-commit-config.yaml
├── README.md
├── pics
│   ├── h_example.jpg
│   └── v_example.jpg
├── requirements.txt
├── rudalle_aspect_ratio
│   ├── __init__.py
│   ├── aspect_ratio.py
│   ├── image_prompts.py
│   └── models.py
└── setup.cfg

/.gitignore:
--------------------------------------------------------------------------------
# Created by .ignore support plugin (hsz.mobi)
### JetBrains template
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839

settings/local.py
logs/*.log

# User-specific stuff:
.idea/

# Sensitive or high-churn files:
.idea/**/dataSources/
.idea/**/dataSources.ids
.idea/**/dataSources.xml
.idea/**/dataSources.local.xml
.idea/**/sqlDataSources.xml
.idea/**/dynamic.xml
.idea/**/uiDesigner.xml

# Gradle:
.idea/**/gradle.xml
.idea/**/libraries

# CMake
cmake-build-debug/

# Mongo Explorer plugin:
.idea/**/mongoSettings.xml

## File-based project format:
*.iws

## Plugin-specific files:

# IntelliJ
out/

# mpeltonen/sbt-idea plugin
.idea_modules/

# JIRA plugin
atlassian-ide-plugin.xml

# Cursive Clojure plugin
.idea/replstate.xml

# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties
### Python template
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
/tests/load_tests/logs/*
/tests/.pytest_cache/
ws_test.py
/.vscode/

.s3_cache/
mlruns
*.pyc
*.swp
*.pt
*.bin
.vscode/
runs/
jupyters/custom_*
tb_logs

--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v2.2.3
    hooks:
      - id: check-docstring-first
        stages:
          - commit
          - push
      - id: check-merge-conflict
        stages:
          - push
      - id: double-quote-string-fixer
        stages:
          - commit
          - push
      - id: fix-encoding-pragma
        stages:
          - commit
          - push
      - id: flake8
        args: ['--config=setup.cfg']
        stages:
          - commit
          - push
  - repo: https://github.com/pre-commit/mirrors-autopep8
    rev: v1.4.4
    hooks:
      - id: autopep8
        stages:
          - commit
          - push

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
[[Paper]](https://drive.google.com/file/d/1bN1pa6h9QO_po8VKSScNxAeWzV_nl_6W/view)
[[Colab]](https://colab.research.google.com/drive/124zC1w2qHR1ijfEPQVvLccLRBLD_3duG?usp=sharing)
[[Kaggle]](https://www.kaggle.com/code/shonenkov/usage-rudalle-aspect-ratio)
[[Model Card]](https://huggingface.co/shonenkov-AI/rudalle-xl-surrealist)

ruDALLE aspect ratio images
---
Generate images with an arbitrary aspect ratio using the ruDALLE models.

### Installing

```
pip install rudalle==1.1.1
git clone https://github.com/shonenkov-AI/rudalle-aspect-ratio
```

### Quick Start

Horizontal images:
```python3
import sys
sys.path.insert(0, './rudalle-aspect-ratio')
from rudalle_aspect_ratio import RuDalleAspectRatio, get_rudalle_model
from rudalle import get_vae, get_tokenizer
from rudalle.pipelines import show

device = 'cuda'
dalle = get_rudalle_model('Surrealist_XL', fp16=True, device=device)
vae, tokenizer = get_vae().to(device), get_tokenizer()
rudalle_ar = RuDalleAspectRatio(
    dalle=dalle, vae=vae, tokenizer=tokenizer,
    aspect_ratio=32/9, bs=4, device=device
)
_, result_pil_images = rudalle_ar.generate_images('готический квартал', 768, 0.99, 4)
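# the prompt is Russian for "gothic quarter"; the positional arguments map to top_k=768,
# top_p=0.99, images_num=4, and generate_images returns a (codebooks, PIL images) pair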
show(result_pil_images, 1)
```
![](./pics/h_example.jpg)

Vertical images:
```python3
rudalle_ar = RuDalleAspectRatio(
    dalle=dalle, vae=vae, tokenizer=tokenizer,
    aspect_ratio=9/32, bs=4, device=device
)
_, result_pil_images = rudalle_ar.generate_images('голубой цветок', 768, 0.99, 4)  # "blue flower"
show(result_pil_images, 4)
```

![](./pics/v_example.jpg)

### Citation:
```
@MISC{rudalle_ar_github,
    author = {Alex Shonenkov},
    title = {Github ruDALLE aspect ratio images by shonenkovAI},
    url = {https://github.com/shonenkov-AI/rudalle-aspect-ratio},
    year = 2022,
    note = {Accessed: 13-04-2022}
}
```

--------------------------------------------------------------------------------
/pics/h_example.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shonenkov-AI/rudalle-aspect-ratio/10fb15dc8b9bc21373676f9d4e175c1894f26917/pics/h_example.jpg

--------------------------------------------------------------------------------
/pics/v_example.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shonenkov-AI/rudalle-aspect-ratio/10fb15dc8b9bc21373676f9d4e175c1894f26917/pics/v_example.jpg

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
rudalle==1.1.1

--------------------------------------------------------------------------------
/rudalle_aspect_ratio/__init__.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
from .image_prompts import BatchImagePrompts
from .aspect_ratio import RuDalleAspectRatio
from .models import get_rudalle_model


__all__ = ['BatchImagePrompts', 'RuDalleAspectRatio', 'get_rudalle_model']

--------------------------------------------------------------------------------
/rudalle_aspect_ratio/aspect_ratio.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import torch
import numpy as np
import more_itertools
import transformers
from tqdm import tqdm
from PIL import Image
from rudalle import utils
from einops import rearrange


from .image_prompts import BatchImagePrompts


class RuDalleAspectRatio:

    def __init__(self, dalle, vae, tokenizer, aspect_ratio=1.0, window=128, image_size=256, bs=4,
                 device='cuda', quite=False):
        """
        :param float aspect_ratio: desired aspect ratio w / h of the generated images
        :param int window: size (in pixels) of the sliding context window used when extending images
        :param int image_size: image size used by the underlying ruDALLE model
        :param int bs: batch size
        :param bool quite: disable tqdm progress bars
        """
        self.device = device
        self.dalle = dalle
        self.vae = vae
        self.tokenizer = tokenizer
        #
        self.vocab_size = self.dalle.get_param('vocab_size')
        self.text_seq_length = self.dalle.get_param('text_seq_length')
        self.image_seq_length = self.dalle.get_param('image_seq_length')
        self.total_seq_length = self.dalle.get_param('total_seq_length')
        self.image_tokens_per_dim = self.dalle.get_param('image_tokens_per_dim')
        #
        self.window = window
        self.image_size = image_size
        self.patch_size = image_size // self.image_tokens_per_dim
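        # with the default image_size=256 and the bundled XL config's image_tokens_per_dim=32,
        # each VQGAN token covers an 8x8 pixel patch, so the default 128-px window spans 16 token columns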
        self.bs = bs
        self.quite = quite
        if aspect_ratio <= 1:
            self.is_vertical = True
            self.w = image_size
            self.h = int(round(image_size / aspect_ratio))
        else:
            self.is_vertical = False
            self.h = image_size
            self.w = int(round(image_size * aspect_ratio))
        self.aspect_ratio = aspect_ratio

    def generate_images(self, text, top_k=1024, top_p=0.975, images_num=4, seed=None):
        if seed is not None:
            utils.seed_everything(seed)

        if self.is_vertical:
            codebooks = self.generate_h_codebooks(text, top_k=top_k, top_p=top_p, images_num=images_num)
            pil_images = self.decode_h_codebooks(codebooks)
        else:
            codebooks, pil_images = [], []
            image_prompts = None
            while (len(pil_images)+1)*self.window <= self.w:
                if pil_images:
                    image_prompts = self.prepare_w_image_prompt(pil_images[-1])
                _pil_images, _codebooks = self.generate_w_codebooks(
                    text, top_k=top_k, top_p=top_p, images_num=images_num,
                    image_prompts=image_prompts, use_cache=True,
                )
                codebooks.append(_codebooks)
                pil_images.append(_pil_images)

            pil_images = self.decode_w_codebooks(codebooks)
            codebooks = torch.cat([_codebooks for _codebooks in codebooks])

        result_images = [pil_img.crop((0, 0, self.w, self.h)) for pil_img in pil_images]
        return codebooks, result_images

    def generate_w_codebooks(self, text, top_k, top_p, images_num, image_prompts=None, temperature=1.0, use_cache=True):
        text = text.lower().strip()
        input_ids = self.tokenizer.encode_text(text, text_seq_length=self.text_seq_length)
        codebooks, pil_images = [], []
        for chunk in more_itertools.chunked(range(images_num), self.bs):
            chunk_bs = len(chunk)
            with torch.no_grad():
                attention_mask = torch.tril(
                    torch.ones((chunk_bs, 1, self.total_seq_length, self.total_seq_length), device=self.device)
                )
                out = input_ids.unsqueeze(0).repeat(chunk_bs, 1).to(self.device)
                cache = {}
                if image_prompts is not None:
                    prompts_idx, prompts = image_prompts.image_prompts_idx, image_prompts.image_prompts
                range_out = range(out.shape[1], self.total_seq_length)
                if not self.quite:
                    range_out = tqdm(range_out)
                for idx in range_out:
                    idx -= self.text_seq_length
                    if image_prompts is not None and idx in prompts_idx:
                        out = torch.cat((out, prompts[:, idx].unsqueeze(1)), dim=-1)
                    else:
                        logits, cache = self.dalle(out, attention_mask,
                                                   cache=cache, use_cache=use_cache, return_loss=False)
                        logits = logits[:, -1, self.vocab_size:]
                        logits /= temperature
                        filtered_logits = transformers.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
                        probs = torch.nn.functional.softmax(filtered_logits, dim=-1)
                        sample = torch.multinomial(probs, 1)
                        out = torch.cat((out, sample), dim=-1)

                _codebooks = out[:, -self.image_seq_length:]
                images = self.vae.decode(_codebooks)
                pil_images += utils.torch_tensors_to_pil_list(images)
                codebooks.append(_codebooks)

        codebooks = torch.cat(codebooks)
        return pil_images, codebooks

    def prepare_w_image_prompt(self, pil_images):
        changed_pil_images = []
        for pil_img in pil_images:
            np_img = np.array(pil_img)
            np_img[:, :self.window, :] = np_img[:, self.window:2*self.window, :]
            pil_img = Image.fromarray(np_img)
            changed_pil_images.append(pil_img)
        borders = {'up': 0, 'left': self.window // self.patch_size, 'right': 0, 'down': 0}
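        # the right half of each previous panel has just been copied into its left half; the leftmost
        # window // patch_size token columns of the shifted image are re-encoded by the VAE and
        # injected as fixed image-prompt tokens while the next panel is generated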
        return BatchImagePrompts(changed_pil_images, borders, self.vae, self.device, crop_first=True)

    def generate_h_codebooks(self, text, top_k, top_p, images_num, temperature=1.0, use_cache=True):
        h_out = int(round(self.image_tokens_per_dim / self.aspect_ratio))
        text = text.lower().strip()
        input_ids = self.tokenizer.encode_text(text, text_seq_length=self.text_seq_length)
        codebooks = []
        for chunk in more_itertools.chunked(range(images_num), self.bs):
            chunk_bs = len(chunk)
            with torch.no_grad():
                attention_mask = torch.tril(
                    torch.ones((chunk_bs, 1, self.total_seq_length, self.total_seq_length), device=self.device)
                )
                full_context = input_ids.unsqueeze(0).repeat(chunk_bs, 1).to(self.device)
                range_h_out = range(h_out)
                if not self.quite:
                    range_h_out = tqdm(range_h_out)
                for i in range_h_out:
                    j = (self.image_tokens_per_dim * i) // h_out
                    out = torch.cat((
                        input_ids.unsqueeze(0).repeat(chunk_bs, 1).to(self.device),
                        full_context[:, self.text_seq_length:][:, -j * self.image_tokens_per_dim:]
                    ), dim=-1)

                    cache = {}
                    for _ in range(self.image_tokens_per_dim):
                        logits, cache = self.dalle(out, attention_mask,
                                                   cache=cache, use_cache=use_cache, return_loss=False)
                        logits = logits[:, -1, self.vocab_size:]
                        logits /= temperature
                        filtered_logits = transformers.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
                        probs = torch.nn.functional.softmax(filtered_logits, dim=-1)
                        sample = torch.multinomial(probs, 1)
                        out = torch.cat((out, sample), dim=-1)
                        full_context = torch.cat((full_context, sample), dim=-1)
                codebooks.append(full_context[:, self.text_seq_length:])

        return torch.cat(codebooks)

    def decode_h_codebooks(self, codebooks):
        with torch.no_grad():
            one_hot_indices = torch.nn.functional.one_hot(codebooks, num_classes=self.vae.num_tokens).float()
            z = (one_hot_indices @ self.vae.model.quantize.embed.weight)
            z = rearrange(z, 'b (h w) c -> b c h w', w=self.image_tokens_per_dim)
            img = self.vae.model.decode(z)
            img = (img.clamp(-1., 1.) + 1) * 0.5
            return utils.torch_tensors_to_pil_list(img)

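    # decode_w_codebooks stitches the horizontal panels in latent space: every panel except the
    # last keeps only its leftmost window // patch_size token columns (its right half reappears as
    # the next panel's prompt), the latents are concatenated along width and decoded in one pass,
    # and generate_images finally crops the result to the requested width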
    def decode_w_codebooks(self, codebooks):
        with torch.no_grad():
            final_z = []
            for i, img_seq in enumerate(codebooks):
                one_hot_indices = torch.nn.functional.one_hot(img_seq, num_classes=self.vae.num_tokens).float()
                z = (one_hot_indices @ self.vae.model.quantize.embed.weight)
                z = rearrange(z, 'b (h w) c -> b c h w', h=self.image_tokens_per_dim)
                if i < len(codebooks)-1:
                    final_z.append(z[:, :, :, :self.window//self.patch_size])
                else:
                    final_z.append(z)
            z = torch.cat(final_z, -1)
            img = self.vae.model.decode(z)
            img = (img.clamp(-1., 1.) + 1) * 0.5
            return utils.torch_tensors_to_pil_list(img)

--------------------------------------------------------------------------------
/rudalle_aspect_ratio/image_prompts.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import torch
import numpy as np


class BatchImagePrompts:

    def __init__(self, pil_images, borders, vae, device='cpu', crop_first=False):
        self.device = device
        img = torch.cat([self._preprocess_img(pil_image) for pil_image in pil_images], dim=0)
        img = img.to(self.device, dtype=torch.float32)
        self.image_prompts_idx, self.image_prompts = self._get_image_prompts(img, borders, vae, crop_first)

    def _preprocess_img(self, pil_img):
        img = torch.tensor(np.array(pil_img.convert('RGB')).transpose(2, 0, 1)) / 255.
        img = img.unsqueeze(0)
        img = (2 * img) - 1
        return img

    def _get_image_prompts(self, img, borders, vae, crop_first):
        if crop_first:
            bs, _, img_h, img_w = img.shape
            vqg_img_w, vqg_img_h = img_w // 8, img_h // 8
            vqg_img = torch.zeros((bs, vqg_img_h, vqg_img_w), dtype=torch.int32, device=img.device)
            if borders['down'] != 0:
                down_border = borders['down'] * 8
                _, _, [_, _, down_vqg_img] = vae.model.encode(img[:, :, -down_border:, :], disable_gumbel_softmax=True)
                vqg_img[:, -borders['down']:, :] = down_vqg_img
            if borders['right'] != 0:
                right_border = borders['right'] * 8
                _, _, [_, _, right_vqg_img] = vae.model.encode(
                    img[:, :, :, -right_border:], disable_gumbel_softmax=True)
                vqg_img[:, :, -borders['right']:] = right_vqg_img
            if borders['left'] != 0:
                left_border = borders['left'] * 8
                _, _, [_, _, left_vqg_img] = vae.model.encode(img[:, :, :, :left_border], disable_gumbel_softmax=True)
                vqg_img[:, :, :borders['left']] = left_vqg_img
            if borders['up'] != 0:
                up_border = borders['up'] * 8
                _, _, [_, _, up_vqg_img] = vae.model.encode(img[:, :, :up_border, :], disable_gumbel_softmax=True)
                vqg_img[:, :borders['up'], :] = up_vqg_img
        else:
            _, _, [_, _, vqg_img] = vae.model.encode(img, disable_gumbel_softmax=True)

        bs, vqg_img_h, vqg_img_w = vqg_img.shape
        mask = torch.zeros(vqg_img_h, vqg_img_w)
        if borders['up'] != 0:
            mask[:borders['up'], :] = 1.
        if borders['down'] != 0:
            mask[-borders['down']:, :] = 1.
        if borders['right'] != 0:
            mask[:, -borders['right']:] = 1.
        if borders['left'] != 0:
            mask[:, :borders['left']] = 1.
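        # mask marks which positions of the flattened (row-major) token grid are fixed by the
        # prompt; generate_w_codebooks copies tokens at these indices from image_prompts instead
        # of sampling them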
        mask = mask.reshape(-1).bool()

        image_prompts = vqg_img.reshape((bs, -1))
        image_prompts_idx = np.arange(vqg_img_w * vqg_img_h)
        image_prompts_idx = set(image_prompts_idx[mask])

        return image_prompts_idx, image_prompts

--------------------------------------------------------------------------------
/rudalle_aspect_ratio/models.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import os

import torch
from huggingface_hub import hf_hub_url, cached_download
from rudalle.dalle import MODELS
from rudalle.dalle.model import DalleModel
from rudalle.dalle.fp16 import FP16Module


# register the Surrealist_XL checkpoint in ruDALLE's model registry
MODELS.update({
    'Surrealist_XL': dict(
        hf_version='v3',
        description='Surrealist is a 1.3 billion parameter model from the GPT3-like family, '
                    'that was trained on surrealism and Russian.',
        model_params=dict(
            num_layers=24,
            hidden_size=2048,
            num_attention_heads=16,
            embedding_dropout_prob=0.1,
            output_dropout_prob=0.1,
            attention_dropout_prob=0.1,
            image_tokens_per_dim=32,
            text_seq_length=128,
            cogview_sandwich_layernorm=True,
            cogview_pb_relax=True,
            vocab_size=16384 + 128,
            image_vocab_size=8192,
        ),
        repo_id='shonenkov-AI/rudalle-xl-surrealist',
        filename='pytorch_model.bin',
        authors='shonenkovAI',
        full_description='',
    )
})


def get_rudalle_model(name, pretrained=True, fp16=False, device='cpu', cache_dir='/tmp/rudalle', **model_kwargs):
    assert name in MODELS

    if fp16 and device == 'cpu':
        print('Warning! fp16 is not supported on cpu. Use a cuda device or turn off fp16.')

    config = MODELS[name].copy()
    config['model_params'].update(model_kwargs)
    model = DalleModel(device=device, **config['model_params'])
    if pretrained:
        cache_dir = os.path.join(cache_dir, name)
        config_file_url = hf_hub_url(repo_id=config['repo_id'], filename=config['filename'])
        cached_download(config_file_url, cache_dir=cache_dir, force_filename=config['filename'])
        checkpoint = torch.load(os.path.join(cache_dir, config['filename']), map_location='cpu')
        model.load_state_dict(checkpoint)
    if fp16:
        model = FP16Module(model)
    model.eval()
    model = model.to(device)
    if config['description'] and pretrained:
        print(config['description'])
    return model

--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
[pep8]
max-line-length = 120
exclude = .tox,*migrations*,.json

[flake8]
max-line-length = 120
exclude = .tox,*migrations*,.json

[autopep8-wrapper]
exclude = .tox,*migrations*,.json

[check-docstring-first]
exclude = .tox,*migrations*,.json

--------------------------------------------------------------------------------