├── .gitignore ├── requirements.txt ├── utils.py ├── video_gen.py ├── text_gen.py ├── gen_cli.py ├── README.md └── image_gen.py /.gitignore: -------------------------------------------------------------------------------- 1 | output/ -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | deep-daze==0.10.2 2 | moviepy==1.0.3 3 | opencv-python==4.5.1.48 4 | torch==1.8.1 5 | transformers==4.5.1 6 | tqdm==4.59.0 -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from string import ascii_letters, digits 2 | 3 | FOLDER_NAME_ALLOW_LIST = set(ascii_letters + digits + '_-') 4 | FILE_NAME_ALLOW_LIST = set(ascii_letters + digits + '_-.') 5 | 6 | def sanitize_folder_name(name): 7 | clean_name = '' 8 | for c in name: 9 | if c in FOLDER_NAME_ALLOW_LIST: 10 | clean_name += c 11 | else: 12 | clean_name += '_' 13 | return clean_name 14 | 15 | def sanitize_file_name(name): 16 | clean_name = '' 17 | for c in name: 18 | if c in FILE_NAME_ALLOW_LIST: 19 | clean_name += c 20 | else: 21 | clean_name += '_' 22 | return clean_name -------------------------------------------------------------------------------- /video_gen.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | 4 | import cv2 5 | import moviepy.editor as mp 6 | 7 | VIDEO_FPS = 20 8 | 9 | def convert_to_gif(vid_path, gif_path): 10 | clip = mp.VideoFileClip(vid_path) 11 | clip.write_gif(gif_path) 12 | 13 | def gen_video(directory, del_imgs=False): 14 | imgs = [] 15 | 16 | file_paths = glob.glob(os.path.join(directory, '*.jpg')) 17 | root_file_path = sorted(file_paths, key=len)[0] 18 | file_paths.remove(root_file_path) 19 | 20 | for file_path in file_paths: 21 | img = cv2.imread(file_path) 22 | height, width, layers = img.shape 23 | size = (width, height) 24 | imgs.append(img) 25 | imgs.extend([img] * VIDEO_FPS * 2) 26 | 27 | root_file_name = os.path.basename(root_file_path) 28 | instance_name = root_file_name[:root_file_name.find('.jpg')] 29 | vid_path = instance_name + '.mp4' 30 | gif_path = instance_name + '.gif' 31 | 32 | out = cv2.VideoWriter(vid_path, cv2.VideoWriter_fourcc(*'mp4v'), VIDEO_FPS, size) 33 | 34 | for i in range(len(imgs)): 35 | out.write(imgs[i]) 36 | out.release() 37 | 38 | convert_to_gif(vid_path, gif_path) 39 | 40 | if del_imgs: 41 | for file_path in file_paths: 42 | os.remove(file_path) 43 | 44 | if __name__ == '__main__': 45 | gen_video('test_0', False) 46 | -------------------------------------------------------------------------------- /text_gen.py: -------------------------------------------------------------------------------- 1 | from transformers import pipeline 2 | 3 | generator = pipeline('text-generation', model='gpt2') 4 | 5 | # Source: https://stackoverflow.com/questions/1883980/find-the-nth-occurrence-of-substring-in-a-string 6 | def find_nth(haystack, needle, n): 7 | parts = haystack.split(needle, n + 1) 8 | if len(parts) <= n + 1: 9 | return -1 10 | return len(haystack) - len(parts[-1]) - len(needle) 11 | 12 | def clean_title(title): 13 | new_title = title.replace('\\', '') 14 | new_title = ' '.join(new_title.split()) 15 | new_title = new_title.strip() 16 | return new_title 17 | 18 | def gen_titles(n=1, genre=None): 19 | if not genre: 20 | genre = '' 21 | else: 22 | genre = genre.replace('"', '').strip() + ' ' 23 | prompt = f'This {genre} artwork is called "' 24 | 25 | output = generator( 26 | prompt, 27 | max_length = 25 + len(genre.split()), 28 | num_return_sequences = n, 29 | num_beams = 10, 30 | temperature = 10.0, 31 | top_k = 100 32 | ) 33 | 34 | output_texts = [] 35 | for item in output: 36 | full_text = item['generated_text'] 37 | start_idx = find_nth(full_text, '"', 0) + 1 38 | end_idx = find_nth(full_text, '"', 1) 39 | 40 | if end_idx == -1: 41 | target_text = full_text[start_idx:] 42 | else: 43 | target_text = full_text[start_idx:end_idx] 44 | 45 | cleaned_title = clean_title(target_text) 46 | # Replace any empty titles 47 | if cleaned_title == '': 48 | cleaned_title = gen_titles(n, genre)[0] 49 | 50 | output_texts.append(cleaned_title) 51 | 52 | return output_texts 53 | 54 | if __name__ == '__main__': 55 | print(gen_titles(3)) -------------------------------------------------------------------------------- /gen_cli.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from image_gen import * 4 | from text_gen import gen_titles 5 | 6 | parser = argparse.ArgumentParser(description='Generate art.') 7 | parser.add_argument('--text', '-t', type=str, 8 | help='text to generate an image') 9 | parser.add_argument('--text_file', '-f', type=str, 10 | help='path to images to generate, one per line') 11 | parser.add_argument('--preset', '-p', default='med', 12 | help='preset for image quality/gen speed') 13 | parser.add_argument('--genre', '-g', type=str, 14 | help='genre of the title to be generated') 15 | parser.add_argument('--count', '-n', type=int, default=1, 16 | help='number of images to generate') 17 | 18 | PRESET_MAP = { 19 | '0': TEST_PRESET, 20 | '1': VERY_LOW_QUALITY_PRESET, 21 | '2': LOW_QUALITY_PRESET, 22 | '3': MED_QUALITY_PRESET, 23 | '4': HIGH_QUALITY_PRESET, 24 | 'test': TEST_PRESET, 25 | 'vlow': VERY_LOW_QUALITY_PRESET, 26 | 'low': LOW_QUALITY_PRESET, 27 | 'med': MED_QUALITY_PRESET, 28 | 'high': HIGH_QUALITY_PRESET 29 | } 30 | 31 | if __name__ == '__main__': 32 | args = parser.parse_args() 33 | preset = PRESET_MAP[args.preset.lower()] 34 | 35 | if args.text is not None: 36 | texts = [args.text] * args.count 37 | elif args.text_file is not None: 38 | with open(args.text_file, 'r') as f: 39 | lines = f.readlines() 40 | texts = [] 41 | for line in lines: 42 | for _ in range(args.count): 43 | texts.append(line) 44 | else: 45 | print('No titles given, so generating a random title...') 46 | texts = gen_titles(n=args.count, genre=args.genre) 47 | print('Generated titles:') 48 | for text in texts: 49 | print(text) 50 | 51 | for text in texts: 52 | generate_image(text.strip(), preset, True) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AI Art Generator 2 | *by Edan Meyer* 3 | 4 | This project provides a wraper over [DeepDaze](https://github.com/lucidrains/deep-daze) for generating artwork. It focuses on simplifying the process even further and adding more automation. It adds 3 things in addition to the base ability to generate images from text: 5 | - Predefined preset for varying levels of quality 6 | - Generation of videos and gifs from the images 7 | - Automatic generation of text prompts for end-to-end image generation 8 | 9 | ## Install 10 | $ pip install -r requirements.txt 11 | 12 | ## Examples 13 | **Generate a random piece of art** 14 | 15 | $ python gen_cli.py 16 | 17 | **Generate many, low-quality random pieces of art** 18 | 19 | $ python gen_cli.py -p low -n 10 20 | 21 | **Generate a specific piece of art** 22 | 23 | $ python gen_cli.py -t 'A doge in a coin going to the moon' 24 | 25 | ## Usage 26 | DESCRIPTION 27 | Generates art varying based on provided parameters. 28 | 29 | FLAGS 30 | --text / -t 31 | Default: None 32 | Prompt used to generate the image. 33 | --text_file / -f 34 | Default: None 35 | Path to file with a list of prompts delimited by new lines, 36 | and image for each prompt will be generated sequentially. 37 | --preset / -p 38 | Default: 'med' 39 | Preset for the level of quality of the image generated, valid 40 | presets are: ['test', 'vlow', 'low', 'med', 'high']. 41 | --genre / -g 42 | Default: None 43 | A string giving the genre of art you want to be generated. 44 | Leaving it as None will result in no specific genre. 45 | --count / -n 46 | Default: 1 47 | The number of pieces to produce. If --text is passed in, n 48 | variations of the same piece are produced, if --text_file is 49 | passed in, n variations of each prompt are produced. 50 | If neither is specified, n random pieces are produced. 51 | -------------------------------------------------------------------------------- /image_gen.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import os 3 | 4 | import deep_daze 5 | from deep_daze import Imagine 6 | import torch 7 | from tqdm import trange 8 | 9 | from utils import sanitize_file_name, sanitize_folder_name 10 | from video_gen import gen_video 11 | 12 | HIGH_QUALITY_PRESET = { 13 | 'num_layers': 24, 14 | 'image_width': 512, 15 | 'epochs': 4, 16 | 'iterations': 1050, 17 | 'lr': 1e-5, 18 | 'batch_size': 4, 19 | 'gradient_accumulate_every': 6, 20 | 'save_every': 10 21 | } 22 | 23 | MED_QUALITY_PRESET = { 24 | 'num_layers': 16, 25 | 'image_width': 512, 26 | 'epochs': 3, 27 | 'iterations': 900, 28 | 'lr': 1e-5, 29 | 'batch_size': 4, 30 | 'gradient_accumulate_every': 4, 31 | 'save_every': 10 32 | } 33 | 34 | LOW_QUALITY_PRESET = { 35 | 'num_layers': 16, 36 | 'image_width': 256, 37 | 'epochs': 3, 38 | 'iterations': 512, 39 | 'lr': 2e-5, 40 | 'batch_size': 4, 41 | 'gradient_accumulate_every': 4, 42 | 'save_every': 10 43 | } 44 | 45 | VERY_LOW_QUALITY_PRESET = { 46 | 'num_layers': 16, 47 | 'image_width': 128, 48 | 'epochs': 1, 49 | 'iterations': 128, 50 | 'lr': 2e-5, 51 | 'batch_size': 4, 52 | 'gradient_accumulate_every': 2, 53 | 'save_every': 10 54 | } 55 | 56 | TEST_PRESET = { 57 | 'num_layers': 16, 58 | 'image_width': 64, 59 | 'epochs': 2, 60 | 'iterations': 10, 61 | 'lr': 2e-5, 62 | 'batch_size': 4, 63 | 'gradient_accumulate_every': 1, 64 | 'save_every': 2 65 | } 66 | 67 | def generate_image(text, preset, save_video=False): 68 | i = 0 69 | output_dir = sanitize_folder_name( 70 | '{}_{}'.format(text.replace(' ', '_'), i)) 71 | 72 | while os.path.exists(output_dir) and os.path.isdir(output_dir): 73 | i += 1 74 | output_dir = sanitize_folder_name( 75 | '{}_{}'.format(text.replace(' ', '_'), i)) 76 | 77 | os.mkdir(output_dir) 78 | os.chdir(output_dir) 79 | 80 | model = Imagine( 81 | text = text, 82 | num_layers = preset['num_layers'], 83 | save_every = preset['save_every'], 84 | image_width = preset['image_width'], 85 | lr = preset['lr'], 86 | iterations = preset['iterations'], 87 | save_progress = True, 88 | save_video = save_video, 89 | save_gif = False 90 | ) 91 | 92 | # Write text used to generate the piece 93 | with open('name.txt', 'w+') as f: 94 | f.write(text) 95 | # Overwrite the output path of images 96 | model.textpath = sanitize_file_name(text) 97 | 98 | for epoch in trange(preset['epochs'], desc='epoch'): 99 | for i in trange(preset['iterations'], desc='iteration'): 100 | model.train_step(epoch, i) 101 | 102 | del model 103 | gc.collect() 104 | torch.cuda.empty_cache() 105 | 106 | if save_video: 107 | gen_video('./', del_imgs=True) 108 | 109 | os.chdir('../') 110 | 111 | if __name__ == '__main__': 112 | generate_image('red_roses', TEST_PRESET, True) --------------------------------------------------------------------------------