├── outer ├── __init__.py ├── models │ ├── __init__.py │ ├── generator_utils.py │ ├── attn_decoder_rnn.py │ ├── cover_classes.py │ ├── base_generator.py │ ├── my_discriminator.py │ ├── discriminator.py │ ├── generator.py │ └── my_generator_fixed_multi_circle.py ├── svg_tools │ └── __init__.py ├── metadata_extractor.py ├── emotions.py ├── test_blur.py ├── colors_tools.py ├── dataset.py ├── represent.py ├── audio_extractor.py └── SVGContainer.py ├── utils ├── __init__.py ├── filenames.py ├── edges.py ├── glyphs_font_checker.py ├── noise.py ├── image_clustering.py ├── color_extractor.py ├── color_contrast.py ├── checkpoint.py ├── text_fitter.py ├── dataset_utils.py ├── plotting.py └── bboxes.py ├── captions ├── __init__.py ├── models │ ├── __init__.py │ └── captioner.py ├── get_caption_borders.py └── train.py ├── colorer ├── __init__.py ├── models │ ├── __init__.py │ ├── colorer.py │ ├── colorer_dropout.py │ └── gan_colorer.py ├── check_music_features.py ├── colors_transforms.py ├── test_all_cielabs.py ├── draw_palettes.py ├── get_main_colors.py ├── transfer_style.py ├── train.py ├── test_model.py └── music_palette_dataset.py ├── docker_build.sh ├── examples ├── gen1_capt1 │ ├── full_img.PNG │ ├── Boney M. - Rasputin.PNG │ ├── Burak Yeter - Body Talks.png │ ├── Charlie Puth - Girlfriend.PNG │ └── Imagine Dragons - Follow You.png ├── gen2_capt2 │ ├── BOYO - Dance Alone.png │ └── Nawms - Sugar Cane (feat. Emilia Ali).png ├── gen1_capt2 │ ├── Imagine Dragons - Believer.png │ └── Imagine Dragons - Believer.svg └── tracks_with_generated_covers_gen1_capt1 │ ├── demo_track_1.mp3 │ ├── demo_track_2.mp3 │ ├── demo_track_3.mp3 │ ├── demo_track_4.mp3 │ ├── demo_track_1.mp3-3.svg │ ├── demo_track_4.mp3-1.svg │ ├── demo_track_1.mp3-5.svg │ ├── demo_track_4.mp3-2.svg │ ├── demo_track_4.mp3-3.svg │ ├── demo_track_1.mp3-1.svg │ ├── demo_track_2.mp3-1.svg │ ├── demo_track_2.mp3-2.svg │ ├── demo_track_3.mp3-2.svg │ ├── demo_track_1.mp3-2.svg │ ├── demo_track_1.mp3-4.svg │ ├── demo_track_2.mp3-4.svg │ ├── demo_track_2.mp3-5.svg │ ├── demo_track_3.mp3-3.svg │ ├── demo_track_4.mp3-4.svg │ ├── demo_track_3.mp3-4.svg │ ├── demo_track_3.mp3-1.svg │ ├── demo_track_4.mp3-5.svg │ ├── demo_track_2.mp3-3.svg │ └── demo_track_3.mp3-5.svg ├── __init__.py ├── docker_run.sh ├── requirements.txt ├── docs ├── captioner_train_help.txt ├── eval_help.txt └── covergan_train_help.txt ├── covergan_training_command.sh ├── Dockerfile ├── captioner_train.py ├── .gitignore ├── eval.py ├── colorer_train.py ├── README.md └── covergan_train.py /outer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /captions/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /colorer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /captions/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /colorer/models/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /outer/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /outer/svg_tools/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docker_build.sh: -------------------------------------------------------------------------------- 1 | docker build -t covergan_training . -------------------------------------------------------------------------------- /examples/gen1_capt1/full_img.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IzhanVarsky/CoverGAN/HEAD/examples/gen1_capt1/full_img.PNG -------------------------------------------------------------------------------- /examples/gen1_capt1/Boney M. - Rasputin.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IzhanVarsky/CoverGAN/HEAD/examples/gen1_capt1/Boney M. - Rasputin.PNG -------------------------------------------------------------------------------- /examples/gen2_capt2/BOYO - Dance Alone.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IzhanVarsky/CoverGAN/HEAD/examples/gen2_capt2/BOYO - Dance Alone.png -------------------------------------------------------------------------------- /examples/gen1_capt1/Burak Yeter - Body Talks.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IzhanVarsky/CoverGAN/HEAD/examples/gen1_capt1/Burak Yeter - Body Talks.png -------------------------------------------------------------------------------- /examples/gen1_capt1/Charlie Puth - Girlfriend.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IzhanVarsky/CoverGAN/HEAD/examples/gen1_capt1/Charlie Puth - Girlfriend.PNG -------------------------------------------------------------------------------- /examples/gen1_capt2/Imagine Dragons - Believer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IzhanVarsky/CoverGAN/HEAD/examples/gen1_capt2/Imagine Dragons - Believer.png -------------------------------------------------------------------------------- /examples/gen1_capt1/Imagine Dragons - Follow You.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IzhanVarsky/CoverGAN/HEAD/examples/gen1_capt1/Imagine Dragons - Follow You.png -------------------------------------------------------------------------------- /colorer/check_music_features.py: -------------------------------------------------------------------------------- 1 | from utils.dataset_utils import create_audio_tensors_for_folder 2 | 3 | create_audio_tensors_for_folder('check_volume', 'check_volume_ckpts') 4 | -------------------------------------------------------------------------------- /examples/gen2_capt2/Nawms - Sugar Cane (feat. 
Emilia Ali).png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IzhanVarsky/CoverGAN/HEAD/examples/gen2_capt2/Nawms - Sugar Cane (feat. Emilia Ali).png -------------------------------------------------------------------------------- /examples/tracks_with_generated_covers_gen1_capt1/demo_track_1.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IzhanVarsky/CoverGAN/HEAD/examples/tracks_with_generated_covers_gen1_capt1/demo_track_1.mp3 -------------------------------------------------------------------------------- /examples/tracks_with_generated_covers_gen1_capt1/demo_track_2.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IzhanVarsky/CoverGAN/HEAD/examples/tracks_with_generated_covers_gen1_capt1/demo_track_2.mp3 -------------------------------------------------------------------------------- /examples/tracks_with_generated_covers_gen1_capt1/demo_track_3.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IzhanVarsky/CoverGAN/HEAD/examples/tracks_with_generated_covers_gen1_capt1/demo_track_3.mp3 -------------------------------------------------------------------------------- /examples/tracks_with_generated_covers_gen1_capt1/demo_track_4.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IzhanVarsky/CoverGAN/HEAD/examples/tracks_with_generated_covers_gen1_capt1/demo_track_4.mp3 -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ["captions", "outer", "colorer", "utils"] 2 | 3 | from covergan import captions 4 | from covergan import outer 5 | from covergan import colorer 6 | from covergan import utils 7 | -------------------------------------------------------------------------------- /utils/filenames.py: -------------------------------------------------------------------------------- 1 | from unicodedata import normalize 2 | 3 | 4 | # This is a workaround for Colab's Google Drive filename processing 5 | def normalize_filename(filename: str): 6 | return normalize("NFC", filename) 7 | -------------------------------------------------------------------------------- /utils/edges.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from skimage.feature import canny 4 | from kornia.color import rgb_to_grayscale 5 | 6 | 7 | def detect_edges(img: torch.Tensor) -> torch.Tensor: 8 | # The input image is expected to be CxHxW 9 | return torch.tensor(canny(rgb_to_grayscale(img).squeeze(0).cpu().numpy())) 10 | -------------------------------------------------------------------------------- /docker_run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | rest=$@ 6 | 7 | IMAGE=covergan_training:latest 8 | 9 | CONTAINER_ID=$(docker inspect --format="{{.Id}}" ${IMAGE} 2> /dev/null) 10 | if [[ "${CONTAINER_ID}" ]]; then 11 | docker run --runtime=nvidia --shm-size=2g --gpus all --rm -v `pwd`:/scratch --user $(id -u):$(id -g) \ 12 | --workdir=/scratch -e HOME=/scratch $IMAGE $@ 13 | else 14 | echo "Unknown container image: ${IMAGE}" 15 | exit 1 16 | fi 17 | 
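-------------------------------------------------------------------------------- usage sketch (not a repository file): /utils/edges.py --------------------------------------------------------------------------------
A minimal sketch of how detect_edges above can be called. It assumes torchvision, kornia and scikit-image are installed; "some_cover.jpg" is a placeholder file name, not part of the repository:

from PIL import Image
from torchvision.transforms.functional import to_tensor

from utils.edges import detect_edges

# Load any RGB cover as a CxHxW float tensor with values in [0, 1].
img = to_tensor(Image.open("some_cover.jpg").convert("RGB"))
edges = detect_edges(img)  # boolean HxW tensor from skimage's Canny detector
print(edges.shape, edges.dtype)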
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # PyDiffVG Python dependencies 2 | cssutils 3 | scikit-learn 4 | scikit-image 5 | svgwrite 6 | svgpathtools 7 | 8 | # Audio extraction dependencies 9 | essentia 10 | eyed3 11 | mutagen 12 | 13 | # CoverGAN dependencies 14 | grpcio 15 | grpcio-tools 16 | Pillow>=8.2.0 17 | kornia 18 | matplotlib 19 | scipy 20 | cairosvg 21 | colorthief 22 | colormath 23 | imageio 24 | opencv-python-headless 25 | fonttools 26 | googletrans 27 | lyricsgenius 28 | svglib 29 | reportlab 30 | lxml 31 | wand -------------------------------------------------------------------------------- /utils/glyphs_font_checker.py: -------------------------------------------------------------------------------- 1 | from fontTools.ttLib import TTFont 2 | 3 | 4 | def font_supports_all_glyphs(phrase_words, font_path): 5 | font = TTFont(font_path) 6 | for word in phrase_words: 7 | for c in word: 8 | if not has_glyph(font, c): 9 | return False 10 | return True 11 | 12 | 13 | def has_glyph(font, glyph): 14 | for table in font['cmap'].tables: 15 | if ord(glyph) in table.cmap.keys(): 16 | return True 17 | return False 18 | -------------------------------------------------------------------------------- /utils/noise.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def get_noise(n_samples, input_dim, device): 5 | """ 6 | Function for creating noise vectors: Given the dimensions (n_samples, input_dim) 7 | creates a tensor of that shape filled with random numbers from the normal distribution. 8 | Parameters: 9 | n_samples: the number of samples to generate, a scalar 10 | input_dim: the dimension of the input vector, a scalar 11 | device: the device type 12 | """ 13 | return torch.randn(n_samples, input_dim, device=device) 14 | 15 | 16 | def combine_noise_and_one_hot(x, y): 17 | assert len(x) == len(y) 18 | return torch.cat((x.float(), y.float()), dim=1) 19 | -------------------------------------------------------------------------------- /colorer/colors_transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from colormath.color_objects import sRGBColor, LabColor 3 | from colormath.color_conversions import convert_color 4 | 5 | 6 | def rgb_to_cielab(a): 7 | a = np.array(a) 8 | a1, a2, a3 = a / 255 9 | color1_rgb = sRGBColor(a1, a2, a3) 10 | color1_lab = convert_color(color1_rgb, LabColor) 11 | return np.array([color1_lab.lab_l, color1_lab.lab_a, color1_lab.lab_b]) 12 | 13 | 14 | def cielab_rgb_to(a): 15 | a = np.array(a) 16 | a1, a2, a3 = a 17 | lab = LabColor(a1, a2, a3) 18 | color = convert_color(lab, sRGBColor) 19 | return np.array([color.rgb_r, color.rgb_g, color.rgb_b]) 20 | 21 | 22 | def rgb_lab_rgb(a): 23 | return list((cielab_rgb_to(rgb_to_cielab(a)) * 255).astype(int)) 24 | -------------------------------------------------------------------------------- /colorer/test_all_cielabs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from colorer.colors_transforms import rgb_to_cielab 4 | 5 | arange = np.arange(0, 256, 1) 6 | 7 | lab_l = [] 8 | lab_a = [] 9 | lab_b = [] 10 | step = 1 11 | for r in arange: 12 | for g in arange: 13 | for b in arange: 14 | print(r, g, b) 15 | cielab = rgb_to_cielab([r, g, b]) 16 | lab_l.append(cielab[0]) 17 | 
lab_a.append(cielab[1]) 18 | lab_b.append(cielab[2]) 19 | sorted_lab_l = sorted(lab_l) 20 | sorted_lab_a = sorted(lab_a) 21 | sorted_lab_b = sorted(lab_b) 22 | print(sorted_lab_l[0], sorted_lab_l[-1]) 23 | print(sorted_lab_a[0], sorted_lab_a[-1]) 24 | print(sorted_lab_b[0], sorted_lab_b[-1]) 25 | # 0.0 99.99998453333127 26 | # -86.1829494051608 98.23532017664644 27 | # -107.86546414496824 94.47731817969378 -------------------------------------------------------------------------------- /utils/image_clustering.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | def cluster(image, k=4, with_labels_centers=False, only_labels_centers=False): 6 | # Reshaping the image into a 2D array of pixels and 3 color values (RGB) 7 | pixel_vals = image.reshape((-1, 3)) 8 | # Convert to float type only for supporting cv2.kmean 9 | pixel_vals = np.float32(pixel_vals) 10 | criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 100, 0.85) 11 | retval, labels, centers = cv2.kmeans(pixel_vals, k, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS) 12 | centers = np.uint8(centers) 13 | if only_labels_centers: 14 | return labels, centers 15 | segmented_data = centers[labels.flatten()] # Mapping labels to center points (RGB Value) 16 | res = segmented_data.reshape(image.shape) 17 | if with_labels_centers: 18 | return res, labels, centers 19 | return res 20 | -------------------------------------------------------------------------------- /docs/captioner_train_help.txt: -------------------------------------------------------------------------------- 1 | Train the Captioner Network. 2 | 3 | Usage: captioner_train.py [OPTIONS] 4 | 5 | Options: 6 | --original_covers DIR Directory with the original cover images [default: "./original_covers"] 7 | --clean_covers DIR Directory with the cover images with captions removed 8 | [default: "./clean_covers"] 9 | --checkpoint_root DIR Checkpoint location [default: "./checkpoint"] 10 | 11 | --lr FLOAT Learning rate [default: 0.001] 12 | --epochs INT Number of epochs to train for [default: 138] 13 | --batch_size INT Batch size [default: 64] 14 | --canvas_size INT Image canvas size for learning [default: 256] 15 | --display_steps INT How often to plot the samples [default: 10] 16 | 17 | --plot_grad Whether to plot the gradients [default: False] 18 | 19 | -h, --help Show help message and exit 20 | -------------------------------------------------------------------------------- /utils/color_extractor.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from scipy.cluster.vq import kmeans, vq 7 | 8 | 9 | def extract_primary_color(img: torch.Tensor, count: int) -> Optional[Tuple[int, int, int]]: 10 | # Requires HWC RGBA tensors 11 | shape = img.shape 12 | img = torch.reshape(img, (shape[0] * shape[1], shape[2])).numpy() 13 | 14 | codes, dist_ = kmeans(img, k_or_guess=count) 15 | img = img[:, :3] 16 | codes = codes[codes[:, 3] > 0.5][:, :3] # Filter out blank background (by alpha channel) 17 | 18 | if not len(codes): 19 | return None 20 | 21 | vecs, dist_ = vq(img, codes) 22 | counts, bin_edges_ = np.histogram(vecs, bins=len(codes)) # Count occurrences 23 | 24 | idx = np.argmax(counts) # Find most frequent 25 | 26 | result = np.round(codes[idx] * 255).astype(int) # [0, 1] -> [0, 255] 27 | 28 | return tuple(map(int, result)) 29 | 
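-------------------------------------------------------------------------------- usage sketch (not a repository file): /utils/color_extractor.py --------------------------------------------------------------------------------
extract_primary_color above expects an HxWx4 RGBA tensor with values in [0, 1]; the alpha channel is what lets it drop blank background pixels. A minimal sketch, where "shape_render.png" is a placeholder for any RGBA image:

from PIL import Image
from torchvision.transforms.functional import to_tensor

from utils.color_extractor import extract_primary_color

img = to_tensor(Image.open("shape_render.png").convert("RGBA"))  # 4xHxW in [0, 1]
img = img.permute(1, 2, 0)  # -> HxWx4, the layout the function expects
primary = extract_primary_color(img, count=5)  # (r, g, b) in [0, 255], or None if everything is transparent
print(primary)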
-------------------------------------------------------------------------------- /utils/color_contrast.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | 4 | def luminance(rgb: Tuple[int, int, int]): 5 | # https://www.w3.org/TR/2008/REC-WCAG20-20081211/#relativeluminancedef 6 | def f(v): 7 | v /= 255 8 | return v / 12.92 if v <= 0.03928 else ((v + 0.055) / 1.055) ** 2.4 9 | 10 | r, g, b = rgb 11 | return f(r) * 0.2126 + f(g) * 0.7152 + f(b) * 0.0722 12 | 13 | 14 | def contrast(rgb1: Tuple[int, int, int], rgb2: Tuple[int, int, int]) -> float: 15 | # https://www.w3.org/TR/2008/REC-WCAG20-20081211/#contrast-ratiodef 16 | lum1 = luminance(rgb1) 17 | lum2 = luminance(rgb2) 18 | brightest = max(lum1, lum2) 19 | darkest = min(lum1, lum2) 20 | return (brightest + 0.05) / (darkest + 0.05) 21 | 22 | 23 | def sufficient_contrast(rgb1: Tuple[int, int, int], 24 | rgb2: Tuple[int, int, int]) -> bool: 25 | # https://www.w3.org/TR/2008/REC-WCAG20-20081211/#visual-audio-contrast-contrast 26 | return contrast(rgb1, rgb2) >= 3 # Consider 4.5 for smaller text 27 | -------------------------------------------------------------------------------- /covergan_training_command.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | pwd 3 | nvidia-smi 4 | echo "FINE!" 5 | #python3 ./covergan_train.py --emotions ./emotions.json --covers ./clean_covers/ --audio ./audio --epochs 50000 6 | #python3 ./covergan_train.py --emotions ./small_emotions.json --covers ./small_clean_covers/ --audio ./small_audio --checkpoint_root ./small_checkpoint --epochs 50000 7 | #python3 ./covergan_train.py --train_dir ./dataset_demo_4 --emotions ./demo_emotions.json --epochs 50000 --display_steps 300 8 | #python3 ./covergan_train.py --train_dir ./dataset_emoji_4 --emotions ./emotions.json --epochs 50000 --display_steps 100 9 | #python3 ./covergan_train.py --train_dir ./dataset_emoji_52 --emotions ./emotions.json --epochs 50000 --display_steps 300 10 | python3 ./covergan_train.py --train_dir ./dataset_full_covers --emotions ./emotions.json --epochs 50000 --display_steps 100 --backup_epochs 20 --gen_lr 0.0001 --disc_lr 0.0004 11 | #python3 ./colorer_train.py --train_dir ./dataset_full_covers --emotions ./emotions.json --epochs 50000 --display_steps 500 --backup_epochs 50 -------------------------------------------------------------------------------- /colorer/draw_palettes.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageDraw 2 | 3 | palettes = [ 4 | [(185.83932795535875, 245.76842360647782, 158.07709104142657), 5 | (65.17575993764615, 65.35268901013251, 69.9871395167576), 6 | (208.78896882494004, 247.26139088729016, 192.2170263788969), 7 | (60.50174825174825, 184.7937062937063, 186.37179487179486), 8 | (242.23333333333332, 211.1, 184.26666666666668), 9 | (213.14456391875746, 250.02867383512546, 195.69534050179212)], 10 | [(190, 244, 159), (61, 112, 115), (100, 99, 103), 11 | (114, 170, 157), (196, 244, 217), (94, 132, 84)] 12 | ] 13 | imsize = 1024 14 | for palette in palettes: 15 | palette = [tuple(map(int, c)) for c in palette] 16 | im = Image.new('RGB', (imsize, imsize)) 17 | draw = ImageDraw.Draw(im) 18 | cur_h = 0 19 | width = imsize / len(palette) 20 | for i, p in enumerate(palette): 21 | print(i, p) 22 | colorval = "#%02x%02x%02x" % p 23 | draw.rectangle((0, cur_h, imsize, cur_h + width), fill=colorval) 24 | cur_h += width 25 | im.show() 26 | 
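-------------------------------------------------------------------------------- usage sketch (not a repository file): /utils/color_contrast.py --------------------------------------------------------------------------------
To make the WCAG arithmetic in color_contrast.py above concrete: the ratio runs from 1 (identical colours) to 21 (black on white), and sufficient_contrast applies the >= 3 threshold meant for large text. A small sketch that picks the more readable of near-black and near-white text for an arbitrary background colour:

from utils.color_contrast import contrast, sufficient_contrast

background = (208, 247, 192)  # arbitrary light green
candidates = [(10, 10, 10), (230, 230, 230)]  # near-black and near-white, as in outer/colors_tools.py

best = max(candidates, key=lambda c: contrast(background, c))
print(best, round(contrast(background, best), 2), sufficient_contrast(background, best))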
-------------------------------------------------------------------------------- /outer/models/generator_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from outer.colors_tools import palette_to_triad_palette 4 | 5 | 6 | def colorize(paths, colors, use_triad=False, need_stroke=False): 7 | if use_triad: 8 | colors = colors.reshape(colors.shape[0], -1, 3) 9 | colors = palette_to_triad_palette(colors) 10 | background_color = colors[0] 11 | color_ind = 1 12 | for path in paths: 13 | if path["fill_color"] is not None: 14 | # was alpha channel 15 | path["fill_color"] = torch.cat((colors[color_ind], path["fill_color"]), dim=0) 16 | else: 17 | path["fill_color"] = colors[color_ind] 18 | color_ind += 1 19 | if need_stroke: 20 | if path["stroke_color"] is not None: 21 | # was alpha channel 22 | path["stroke_color"] = torch.cat((colors[color_ind], path["stroke_color"]), dim=0) 23 | else: 24 | path["stroke_color"] = colors[color_ind] 25 | color_ind += 1 26 | else: 27 | path["stroke_color"] = path["fill_color"] 28 | return background_color -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/cuda:11.4.2-cudnn8-devel-ubuntu20.04 2 | 3 | ENV PYTHONDONTWRITEBYTECODE 1 4 | ENV PYTHONUNBUFFERED 1 5 | 6 | RUN apt-get update -y 7 | RUN apt-get update 8 | RUN DEBIAN_FRONTEND=noninteractive apt-get install -y python3-pip cmake 9 | RUN DEBIAN_FRONTEND=noninteractive apt upgrade -y cmake 10 | 11 | RUN pip3 install torch==1.10.0+cu113 torchvision==0.11.1+cu113 torchaudio==0.10.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html 12 | 13 | COPY ./requirements.txt ./requirements.txt 14 | COPY ./diffvg ./diffvg 15 | WORKDIR . 16 | 17 | RUN pip3 install -r ./requirements.txt 18 | RUN cd diffvg && python3 ./setup.py install && cd .. 19 | # If problems with installing or using diffvg, check this issue: 20 | # https://github.com/BachiLi/diffvg/issues/29#issuecomment-994807865 21 | 22 | # Unset TORCH_CUDA_ARCH_LIST and exec. This makes pytorch run-time 23 | # extension builds significantly faster as we only compile for the 24 | # currently active GPU configuration. 25 | #RUN (printf '#!/bin/bash\nunset TORCH_CUDA_ARCH_LIST\nexec \"$@\"\n' >> /entry.sh) && chmod a+x /entry.sh 26 | #ENTRYPOINT ["/entry.sh"] 27 | 28 | ENTRYPOINT ["./covergan_training_command.sh"] -------------------------------------------------------------------------------- /docs/eval_help.txt: -------------------------------------------------------------------------------- 1 | Test the CoverGAN. 
2 | 
3 | Usage: eval.py [OPTIONS]
4 | 
5 | Options:
6 |   --gan_weights DIR         Model weights for CoverGAN [default: "./weights/covergan.pt"]
7 |   --captioner_weights DIR   Captioner weights [default: "./weights/captioner.pt"]
8 |   --protosvg_address STR    ProtoSVG rendering server [default: "localhost:50051"]
9 |   --font_dir DIR            Directory with font files [default: "./fonts"]
10 |   --output_dir DIR          Directory where to save SVG covers [default: "./gen_samples"]
11 |   --num_samples INT         Number of samples to generate [default: 5]
12 | 
13 |   --audio_file STR          Path to the audio file to process [required]
14 |   --emotions STR            Emotion of the audio file [required]
15 |   --track_artist STR        Track artist [required]
16 |   --track_name STR          Track name [required]
17 | 
18 |   --filter                  Overlay filter to apply to the final image [default: False]
19 |   --rasterize               Whether to rasterize the generated cover [default: False]
20 |   --watermark               Whether to add watermark [default: False]
21 |   --deterministic           Whether to disable random noise [default: False]
22 |   --debug                   Whether to enable debug features [default: False]
23 | 
24 |   -h, --help                Show help message and exit
25 | 
-------------------------------------------------------------------------------- /colorer/get_main_colors.py: --------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | 
3 | import binascii
4 | 
5 | import numpy as np
6 | import scipy.cluster
7 | from PIL import Image
8 | 
9 | NUM_CLUSTERS = 15
10 | 
11 | print('reading image')
12 | im = Image.open('A S T R O - Change.jpg')
13 | im = im.resize((150, 150))  # optional, to reduce time
14 | ar = np.asarray(im)
15 | shape = ar.shape
16 | ar = ar.reshape(np.product(shape[:2]), shape[2]).astype(float)
17 | 
18 | print('finding clusters')
19 | codes, dist = scipy.cluster.vq.kmeans(ar, NUM_CLUSTERS)
20 | print('cluster centres:\n', codes)
21 | 
22 | vecs, dist = scipy.cluster.vq.vq(ar, codes)  # assign codes
23 | counts, bins = np.histogram(vecs, len(codes))  # count occurrences
24 | 
25 | index_max = np.argmax(counts)  # find most frequent
26 | peak = codes[index_max]
27 | colour = binascii.hexlify(bytearray(int(c) for c in peak)).decode('ascii')
28 | print('most frequent is %s (#%s)' % (peak, colour))
29 | 
30 | # bonus: save image using only the N most common colours
31 | import imageio
32 | 
33 | print("====")
34 | print([tuple(c) for c in codes])
35 | print("====")
36 | print(len(vecs))
37 | print(vecs)
38 | print("====")
39 | 
40 | c = ar.copy()
41 | for i, code in enumerate(codes):
42 |     c[np.r_[np.where(vecs == i)], :] = code
43 | imageio.imwrite('clusters.png', c.reshape(*shape).astype(np.uint8))
44 | print('saved clustered image')
45 | 
-------------------------------------------------------------------------------- /examples/gen1_capt2/Imagine Dragons - Believer.svg and the 20 generated covers under /examples/tracks_with_generated_covers_gen1_capt1/ (demo_track_1.mp3-1.svg through demo_track_4.mp3-5.svg, all listed in the file tree above): --------------------------------------------------------------------------------
[The SVG markup of these example covers was stripped during text extraction and is not recoverable here. Only the caption text nodes survive: "Imagine Dragons – Believer" for the gen1_capt2 cover, and "Cool Band" / "New Song" for each demo-track cover.]
-------------------------------------------------------------------------------- /outer/models/attn_decoder_rnn.py: --------------------------------------------------------------------------------
1 | import torch
2 | from pydiffvg import device
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | 
6 | 
7 | class AttnDecoderRNN(nn.Module):
8 |     def __init__(self, hidden_size, output_size, dropout_p=0.1):
9 |         super(AttnDecoderRNN, self).__init__()
10 |         max_length = 5
11 |         self.hidden_size = hidden_size
12 |         self.output_size = output_size
13 |         self.dropout_p = dropout_p
14 |         self.max_length = max_length
15 | 
16 |         self.embedding = nn.Linear(self.output_size, self.hidden_size)
17 |         self.attn = nn.Linear(self.hidden_size * 2, self.max_length)
18 |         self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
19 |         self.dropout = nn.Dropout(self.dropout_p)
20 |         self.gru = nn.GRU(self.hidden_size, self.hidden_size)
21 |         self.out = nn.Linear(self.hidden_size, self.output_size)
22 | 
23 |     def forward(self, input, hidden, encoder_outputs):
24 |         embedded = self.embedding(input).view(1, 1, -1)
25 |         embedded = self.dropout(embedded)
26 | 
27 |         attn_weights = F.softmax(
28 |             self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1)
29 |         attn_applied = torch.bmm(attn_weights.unsqueeze(0),
30 |                                  encoder_outputs.unsqueeze(0))
31 | 
32 |         output = torch.cat((embedded[0], attn_applied[0]), 1)
33 |         output = self.attn_combine(output).unsqueeze(0)
34 | 
35 |         output = F.relu(output)
36 |         output, hidden = self.gru(output, hidden)
37 | 
38 |         output = F.log_softmax(self.out(output[0]), dim=1)
39 |         return output, hidden, attn_weights
40 | 
41 |     def initHidden(self):
42 |         return torch.zeros(1, 1, self.hidden_size, device=device)
43 | 
-------------------------------------------------------------------------------- /outer/metadata_extractor.py: --------------------------------------------------------------------------------
1 | import logging
2 | from io import BytesIO
3 | 
4 | import eyed3
5 | import eyed3.utils.art
6 | # Because one tag parser is not enough
7 | import mutagen
8 | from PIL import Image
9 | from eyed3 import id3
10 | 
11 | logger = logging.getLogger("metadata_extractor")
12 | logger.addHandler(logging.StreamHandler())
13 | logger.setLevel(logging.INFO)
14 | 
15 | 
16 | def get_tag_map(filename: str):
17 |     logger.info(f'Extracting metadata for {filename}')
18 | 
19 |     eyed3_tag = id3.Tag()
20 |     if not eyed3_tag.parse(filename):
21 |         return None
22 | 
23 |     try:
24 |         mutagen_tag = mutagen.File(filename)
25 |     except mutagen.MutagenError:
26 |         return None
27 | 
28 |     result = {
29 |         'filename': eyed3_tag.file_info.name,
30 |         'artist': eyed3_tag.artist,
31 |         'title': eyed3_tag.title,
32 |         'release_year': eyed3_tag.getBestDate().year,
33 |         'genre': eyed3_tag.genre
34 |     }
35 | 
36 |     # eyeD3
37 |     covers = eyed3.utils.art.getArtFromTag(eyed3_tag, id3.frames.ImageFrame.FRONT_COVER)
38 |     if covers:
39 |         result['cover'] = Image.open(BytesIO(covers[0].image_data))
40 | 
41 |     # mutagen
42 |     result['country'] = mutagen_tag.get('releasecountry')
43 |     result['country'] = None if result['country'] is None else result['country'][0]
44 |     result['release_date'] = mutagen_tag.get('date')  # less reliable than eyeD3's year?
45 | 
46 |     # mutagen + MusicBrainz
47 |     result['artist_id'] = mutagen_tag.get('musicbrainz_artistid')
48 |     result['album_id'] = mutagen_tag.get('musicbrainz_albumid')
49 |     result['track_id'] = mutagen_tag.get('musicbrainz_trackid')
50 |     result['release_track_id'] = mutagen_tag.get('musicbrainz_releasetrackid')
51 |     result['release_group_id'] = mutagen_tag.get('musicbrainz_releasegroupid')
52 | 
53 |     return result
54 | 
-------------------------------------------------------------------------------- /docs/covergan_train_help.txt: --------------------------------------------------------------------------------
1 | Train the CoverGAN Network.
2 | 
3 | Usage: covergan_train.py [OPTIONS]
4 | 
5 | Example of emotions.json file:
6 | [
7 |   ["track1.mp3", ["emotion1", "emotion2"]],
8 |   ["track2.mp3", ["emotion1", "emotion2"]],
9 |   ["track3.mp3", ["emotion1", "emotion2", "emotion3"]]
10 | ]
11 | 
12 | Options:
13 |   --audio DIR                   Directory with the music files [default: "./audio"]
14 |   --covers DIR                  Directory with the cover images [default: "./clean_covers"]
15 |   --emotions PATH_TO_JSON       File with emotion markup for train dataset.
16 |                                 In this file for each music track its emotions are indicated.
17 |   --checkpoint_root DIR         Checkpoint location [default: "./checkpoint"]
18 | 
19 |   --test_set DIR                Directory with test music files.
20 |                                 If the metadata extractor cannot find covers
21 |                                 on the Internet, add covers of such tracks to the same folder.
22 |   --test_emotions PATH_TO_JSON  File with emotion markup for test dataset.
23 | In this file for each music track its emotions are indicated. 24 | 25 | --lr FLOAT Learning rate [default: 0.0005] 26 | --disc_repeats INT Discriminator runs per iteration [default: 5] 27 | --epochs INT Number of epochs to train for [default: 8000] 28 | --batch_size INT Batch size [default: 64] 29 | --canvas_size INT Image canvas size for learning [default: 128] 30 | --display_steps INT How often to plot the samples [default: 500] 31 | --backup_epochs INT How often to backup checkpoints [default: 600] 32 | 33 | --augment_dataset Whether to augment the dataset [default: False] 34 | --plot_grad Whether to plot the gradients [default: False] 35 | 36 | -h, --help Show help message and exit 37 | -------------------------------------------------------------------------------- /outer/emotions.py: -------------------------------------------------------------------------------- 1 | from enum import IntEnum 2 | import json 3 | from typing import Optional 4 | 5 | import torch 6 | 7 | 8 | # IntEnum allows to compare enum values to ints directly 9 | class Emotion(IntEnum): 10 | ANGER = 0 11 | COMFORTABLE = 1 12 | FEAR = 2 13 | FUNNY = 3 14 | HAPPY = 4 15 | INSPIRATIONAL = 5 16 | JOY = 6 17 | LONELY = 7 18 | NOSTALGIC = 8 19 | PASSIONATE = 9 20 | QUIET = 10 21 | RELAXED = 11 22 | ROMANTIC = 12 23 | SADNESS = 13 24 | SERIOUS = 14 25 | SOULFUL = 15 26 | SURPRISE = 16 27 | SWEET = 17 28 | WARY = 18 29 | 30 | def __str__(self) -> str: 31 | return self.name.lower() 32 | 33 | 34 | def emotion_from_str(emotion_str: str) -> Optional[Emotion]: 35 | try: 36 | return Emotion[emotion_str.upper()] 37 | except KeyError: 38 | print(f"Unknown emotion: {emotion_str}") 39 | return None 40 | 41 | 42 | def read_emotion_file(emotions_filename: str): 43 | with open(emotions_filename, 'r', encoding="utf-8") as f: 44 | emotion_list = json.load(f) 45 | for entry in emotion_list: 46 | if len(entry) != 2: 47 | print(f"Malformed entry in emotion file: {entry}") 48 | return None 49 | result = [] 50 | for filename, emotion_strs in emotion_list: 51 | emotions = [emotion_from_str(x) for x in emotion_strs] 52 | if None in emotions: 53 | print(f"Unknown emotion in emotions list for dataset file '{filename}': {emotions}") 54 | return None 55 | if not (2 <= len(emotions) <= 3): 56 | print(f"Invalid emotion count for dataset file '{filename}'") 57 | return None 58 | result.append((filename, emotions)) 59 | 60 | print(f"Successfully parsed emotion file with {len(result)} entries.") 61 | return result 62 | 63 | 64 | def emotions_one_hot(emotions_list: [Emotion]) -> torch.Tensor: 65 | emotions_int_list = [int(x) for x in emotions_list] 66 | result = torch.zeros(len(Emotion)) 67 | result[emotions_int_list] = 1 68 | return result 69 | -------------------------------------------------------------------------------- /utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from typing import Union 4 | 5 | import torch 6 | 7 | logger = logging.getLogger("checkpoint") 8 | logger.addHandler(logging.StreamHandler()) 9 | logger.setLevel(logging.INFO) 10 | 11 | 12 | def get_checkpoint_filename(checkpoint_root: str, checkpoint_name: str, epoch: int = None) -> str: 13 | suffix = f"-{epoch}" if epoch is not None else "" 14 | return f"{checkpoint_root}/{checkpoint_name}{suffix}.pt" 15 | 16 | 17 | def save_checkpoint(checkpoint_root: str, checkpoint_name: str, epochs_done: int, backup_epochs: int, 18 | models: [Union[torch.nn.Module, torch.optim.Optimizer]]): 19 | 
checkpoint_dict = {} 20 | for i, model in enumerate(models): 21 | checkpoint_dict[f"{i}_state_dict"] = model.state_dict() 22 | checkpoint_dict[f"epochs_done"] = epochs_done 23 | 24 | if not backup_epochs: 25 | # Unconditional save 26 | filename = get_checkpoint_filename(checkpoint_root, checkpoint_name) 27 | torch.save(checkpoint_dict, filename) 28 | logger.info(f"{filename} saved") 29 | if backup_epochs and epochs_done and epochs_done % backup_epochs == 0: 30 | # Regular backup 31 | filename = get_checkpoint_filename(checkpoint_root, checkpoint_name, epochs_done) 32 | torch.save(checkpoint_dict, filename) 33 | logger.info(f"Backup {filename} saved") 34 | 35 | 36 | def load_checkpoint(checkpoint_root: str, checkpoint_name: str, 37 | models: [Union[torch.nn.Module, torch.optim.Optimizer]]) -> int: 38 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 39 | filename = get_checkpoint_filename(checkpoint_root, checkpoint_name) 40 | 41 | if os.path.isfile(filename): 42 | logger.info(f"Found {filename}, loading") 43 | checkpoint = torch.load(filename, map_location=device) 44 | for i, model in enumerate(models): 45 | model.load_state_dict(checkpoint[f"{i}_state_dict"]) 46 | epochs_done = checkpoint[f"epochs_done"] 47 | logger.info(f"{filename} loaded") 48 | return epochs_done 49 | else: 50 | return 0 51 | -------------------------------------------------------------------------------- /utils/text_fitter.py: -------------------------------------------------------------------------------- 1 | from PIL import ImageFont, features 2 | 3 | from .bboxes import BBox 4 | 5 | 6 | # assert features.check('raqm') 7 | 8 | 9 | def binary_search(predicate, lower_bound, upper_bound, default_value=None): 10 | ret = default_value 11 | while lower_bound + 1 < upper_bound: 12 | mid = lower_bound + (upper_bound - lower_bound) // 2 13 | if predicate(mid): 14 | ret = lower_bound = mid 15 | else: 16 | upper_bound = mid 17 | return ret 18 | 19 | 20 | def get_text_wh(font: ImageFont.FreeTypeFont, text: str, direction: str): 21 | left, top, right, bottom = font.getbbox(text, direction, anchor='lt') 22 | width = right - left 23 | height = bottom - top 24 | if direction == 'ttb': 25 | width, height = height, width 26 | return width, height 27 | 28 | 29 | def fit_text(text: str, boundary: BBox, font_filename: str) -> (int, str, int, int, float): 30 | boundary_width, boundary_height = boundary.width(), boundary.height() 31 | pos_wh_ratio = boundary.wh_ratio() 32 | direction = 'ttb' if pos_wh_ratio < 1 else 'ltr' 33 | 34 | def fits(fs: int): 35 | f = ImageFont.truetype(font_filename, int(fs)) 36 | text_w, text_h = get_text_wh(f, text, direction) 37 | return text_w <= boundary_width and text_h <= boundary_height 38 | 39 | cannot_fit = False 40 | font_size = binary_search(fits, 10, 80) 41 | if font_size is None: 42 | cannot_fit = True 43 | # The predicted area is too small, so the text won't fit anyway; make it readable. 
44 | font_size = 20 45 | 46 | font = ImageFont.truetype(font_filename, font_size) 47 | text_width, text_height = get_text_wh(font, text, direction) 48 | 49 | x_shift = (boundary_width - text_width) // 2 50 | y_shift = (boundary_height - text_height) // 2 51 | x = boundary.x1 + x_shift 52 | y = boundary.y1 + y_shift 53 | if direction == 'ttb': 54 | x += text_width // 2 55 | else: 56 | y += text_height 57 | 58 | if cannot_fit: 59 | score = 0 60 | else: 61 | score = text_width * text_height / boundary.area() 62 | 63 | return font_size, direction, x, y, score 64 | -------------------------------------------------------------------------------- /colorer/models/colorer.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List 2 | 3 | import numpy as np 4 | import torch 5 | from numpy import ndarray 6 | from torch import nn 7 | 8 | from outer.emotions import Emotion 9 | 10 | 11 | class Colorer(nn.Module): 12 | def __init__(self, z_dim: int, audio_embedding_dim: int, has_emotions: bool, num_layers: int, 13 | colors_count: int = 6): 14 | super(Colorer, self).__init__() 15 | colors_count = 12 16 | self.colors_count = colors_count 17 | in_features = z_dim + audio_embedding_dim 18 | if has_emotions: 19 | in_features += len(Emotion) 20 | out_features = colors_count 21 | 22 | feature_step = (in_features - out_features) // num_layers 23 | 24 | layers = [] 25 | for i in range(num_layers - 1): 26 | out_features = in_features - feature_step 27 | layers += [ 28 | torch.nn.Linear(in_features=in_features, out_features=out_features), 29 | # torch.nn.Dropout(0.2), 30 | torch.nn.BatchNorm1d(num_features=out_features), 31 | torch.nn.LeakyReLU(0.2) 32 | ] 33 | in_features = out_features 34 | layers += [ 35 | torch.nn.Linear(in_features=in_features, out_features=colors_count * 3), # 3 = RGB 36 | torch.nn.Sigmoid() 37 | ] 38 | 39 | self.model_ = torch.nn.Sequential(*layers) 40 | 41 | def forward(self, noise: torch.Tensor, audio_embedding: torch.Tensor, 42 | emotions: Optional[torch.Tensor]) -> List[torch.Tensor]: 43 | if emotions is not None: 44 | inp = torch.cat((noise, audio_embedding, emotions), dim=1) 45 | else: 46 | inp = torch.cat((noise, audio_embedding), dim=1) 47 | 48 | return self.model_(inp) 49 | 50 | def predict(self, noise: torch.Tensor, audio_embedding: torch.Tensor, 51 | emotions: Optional[torch.Tensor]) -> ndarray: 52 | palette = self.forward(noise, audio_embedding, emotions) 53 | palette = palette[0].detach().cpu().numpy() * 255 54 | palette = palette.reshape(-1, 3) 55 | palette = np.around(palette, 0).astype(int) 56 | return palette 57 | -------------------------------------------------------------------------------- /colorer/models/colorer_dropout.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List 2 | 3 | import numpy as np 4 | import torch 5 | from numpy import ndarray 6 | from torch import nn 7 | 8 | from outer.emotions import Emotion 9 | 10 | 11 | class Colorer2(nn.Module): 12 | def __init__(self, z_dim: int, audio_embedding_dim: int, has_emotions: bool, num_layers: int, 13 | colors_count: int = 6): 14 | super(Colorer2, self).__init__() 15 | colors_count = 12 16 | self.colors_count = colors_count 17 | in_features = z_dim + audio_embedding_dim 18 | if has_emotions: 19 | in_features += len(Emotion) 20 | out_features = colors_count 21 | 22 | feature_step = (in_features - out_features) // num_layers 23 | 24 | layers = [] 25 | for i in range(num_layers - 1): 26 | 
out_features = in_features - feature_step 27 | layers += [ 28 | torch.nn.Linear(in_features=in_features, out_features=out_features), 29 | torch.nn.Dropout(0.2), 30 | torch.nn.BatchNorm1d(num_features=out_features), 31 | torch.nn.LeakyReLU(0.2) 32 | ] 33 | in_features = out_features 34 | layers += [ 35 | torch.nn.Linear(in_features=in_features, out_features=colors_count * 3), # 3 = RGB 36 | torch.nn.Sigmoid() 37 | ] 38 | 39 | self.model_ = torch.nn.Sequential(*layers) 40 | 41 | def forward(self, noise: torch.Tensor, audio_embedding: torch.Tensor, 42 | emotions: Optional[torch.Tensor]) -> List[torch.Tensor]: 43 | if emotions is not None: 44 | inp = torch.cat((noise, audio_embedding, emotions), dim=1) 45 | else: 46 | inp = torch.cat((noise, audio_embedding), dim=1) 47 | 48 | return self.model_(inp) 49 | 50 | def predict(self, noise: torch.Tensor, audio_embedding: torch.Tensor, 51 | emotions: Optional[torch.Tensor]) -> ndarray: 52 | palette = self.forward(noise, audio_embedding, emotions) 53 | palette = palette[0].detach().cpu().numpy() * 255 54 | palette = palette.reshape(-1, 3) 55 | palette = np.around(palette, 0).astype(int) 56 | return palette 57 | -------------------------------------------------------------------------------- /outer/models/cover_classes.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | 5 | from outer.colors_tools import palette_to_triad_palette 6 | 7 | 8 | class CoverFigure: 9 | def __init__(self): 10 | self.points = None 11 | self.fill_color = None 12 | self.stroke_width = None 13 | self.stroke_color = None 14 | 15 | self.center_point = None 16 | self.radius = None 17 | self.deformation_points = None 18 | self.angle = None 19 | 20 | def to_dict(self): 21 | return {"points": self.points, "fill_color": self.fill_color, 22 | "stroke_width": self.stroke_width, "stroke_color": self.stroke_color} 23 | 24 | 25 | class Cover: 26 | def __init__(self): 27 | self.background_color = None 28 | self.canvas_size = None 29 | self.figures: List[CoverFigure] = [] 30 | 31 | def add_figure(self, fig: CoverFigure): 32 | self.figures.append(fig) 33 | 34 | def colorize_cover(self, colors, use_triad=False, need_stroke=False): 35 | if use_triad: 36 | device = colors.device 37 | colors = colors.detach().cpu().numpy()[None, :] 38 | colors = palette_to_triad_palette(colors)[0] 39 | colors = torch.from_numpy(colors).to(device) 40 | self.background_color = colors[0] 41 | color_ind = 1 42 | for path in self.figures: 43 | if path.fill_color is not None: 44 | # was alpha channel 45 | path.fill_color = torch.cat((colors[color_ind], path.fill_color[-1:]), dim=0) 46 | else: 47 | path.fill_color = colors[color_ind] 48 | color_ind += 1 49 | if need_stroke: 50 | if path.stroke_color is not None: 51 | # was alpha channel 52 | path.stroke_color = torch.cat((colors[color_ind], path.stroke_color[-1:]), dim=0) 53 | else: 54 | path.stroke_color = colors[color_ind] 55 | color_ind += 1 56 | else: 57 | path.stroke_color = path.fill_color 58 | 59 | def to_background_and_paths(self): 60 | return self.background_color, list(map(CoverFigure.to_dict, self.figures)) 61 | -------------------------------------------------------------------------------- /captions/get_caption_borders.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch.cuda 4 | from PIL import Image 5 | from torchvision.transforms.functional import to_tensor 6 | 7 | from captions.dataset 
import plot_img_with_bboxes, image_file_to_tensor 8 | from captions.models.captioner import Captioner 9 | from utils.bboxes import BBox 10 | 11 | device = "cuda" if torch.cuda.is_available() else "cpu" 12 | captioner_weights_path = "../weights/captioner.pt" 13 | 14 | cover_paths = ["0_s_1_g2.png", "0_s_6_g2.png"] 15 | artist_name = "" 16 | track_name = "" 17 | 18 | 19 | def png_data_to_pil_image(f_name, canvas_size: Optional[int] = None) -> Image: 20 | result = Image.open(f_name).convert('RGB') 21 | if canvas_size is not None: 22 | result = result.resize((canvas_size, canvas_size)) 23 | return result 24 | 25 | 26 | def main(): 27 | captioner_canvas_size_ = 256 28 | num_captioner_conv_layers = 3 29 | num_captioner_linear_layers = 2 30 | captioner = Captioner( 31 | canvas_size=captioner_canvas_size_, 32 | num_conv_layers=num_captioner_conv_layers, 33 | num_linear_layers=num_captioner_linear_layers 34 | ).to(device) 35 | captioner_weights = torch.load(captioner_weights_path, map_location=device) 36 | captioner.load_state_dict(captioner_weights["0_state_dict"]) 37 | captioner.eval() 38 | with torch.no_grad(): 39 | covs = [] 40 | for cover_path in cover_paths: 41 | covs.append(to_tensor(png_data_to_pil_image(cover_path, captioner_canvas_size_))) 42 | pos_pred, color_pred = captioner(torch.stack(covs)) 43 | pos_preds = torch.round(pos_pred * captioner_canvas_size_).to(int).numpy() 44 | color_preds = torch.round(color_pred * 255).to(int).numpy() 45 | for i, cover_path in enumerate(cover_paths): 46 | x_pos_pred = pos_preds[i] 47 | x_color_pred = color_preds[i] 48 | print(x_pos_pred) 49 | print(x_color_pred) 50 | bbox1 = BBox(*x_pos_pred[:4]) 51 | col1 = x_color_pred[:3] 52 | bbox2 = BBox(*x_pos_pred[4:]) 53 | col2 = x_color_pred[3:] 54 | original_cover = image_file_to_tensor(cover_path, captioner_canvas_size_) 55 | plot_img_with_bboxes(original_cover, [(bbox1, col1), (bbox2, col2)]) 56 | 57 | 58 | if __name__ == '__main__': 59 | main() 60 | -------------------------------------------------------------------------------- /outer/test_blur.py: -------------------------------------------------------------------------------- 1 | def blr(): 2 | from PIL import Image 3 | from torchvision.transforms import transforms 4 | from torchvision.transforms.functional import gaussian_blur 5 | 6 | im = Image.open("../dataset_emoji_4/clean_covers/demo_track_1.jpg") 7 | convert_tensor = transforms.ToTensor() 8 | tens = convert_tensor(im) 9 | blur = gaussian_blur(tens, kernel_size=29) 10 | to_pil = transforms.ToPILImage() 11 | pil = to_pil(blur) 12 | pil.save('out2.png') 13 | 14 | 15 | def wand_rend_test(): 16 | from outer.SVGContainer import wand_rendering 17 | f_name = "../generated_covers_tests_rand2/Sound Quelle - Volga-1.svg" 18 | f_name = "kek3.svg" 19 | with open(f_name, mode="r", encoding="utf-8") as f: 20 | img = wand_rendering(f.read()) 21 | with open("test.png", mode="wb") as f: 22 | f.write(img) 23 | 24 | 25 | def svglib_rend_test(): 26 | from outer.SVGContainer import svglib_rendering, svglib_rendering_from_file 27 | # f_name = "../generated_covers_svgcont2/&me - The Rapture &&%$Pt.II-1.svg" 28 | f_name = "kek2.svg" 29 | with open(f_name, mode="r", encoding="utf-8") as f: 30 | img = svglib_rendering_from_file(f.name) 31 | with open("test.png", mode="wb") as f: 32 | f.write(img) 33 | 34 | 35 | def load_test(): 36 | from outer.SVGContainer import SVGContainer 37 | data = "mega_kek.svg" 38 | svg_cont = SVGContainer.load_svg(open(data).read()) 39 | svg_cont.save_png("kek3.png", renderer_type="wand") 40 | 
svg_cont.save_svg("kek3_.svg") 41 | 42 | 43 | def add_text_test(): 44 | from outer.SVGContainer import SVGContainer 45 | from service_utils import paste_caption 46 | data = "kek2_notext.svg" 47 | svg_cont = SVGContainer.load_svg(open(data).read()) 48 | pil = svg_cont.to_PIL(renderer_type="wand").convert("RGB") 49 | paste_caption(svg_cont, pil, "Nyash", "Myash", "../fonts") 50 | pil.save("kek2_with_text.png") 51 | # svg_cont.save_png("kek2_with_text.png", renderer_type="wand") 52 | svg_cont.save_svg("kek2_with_text.svg") 53 | 54 | 55 | def check_to_obj(): 56 | from outer.SVGContainer import SVGContainer 57 | data = "mega_kek.svg" 58 | svg_cont = SVGContainer.load_svg(open(data).read()) 59 | obj = svg_cont.to_obj() 60 | import json 61 | res = json.dumps(obj, indent=4) 62 | print(res) 63 | 64 | 65 | # wand_rend_test() 66 | # load_test() 67 | # add_text_test() 68 | check_to_obj() 69 | -------------------------------------------------------------------------------- /outer/colors_tools.py: -------------------------------------------------------------------------------- 1 | import colorsys 2 | import numpy as np 3 | 4 | 5 | def to_int_color(color): 6 | return tuple(int(x) for x in color) 7 | 8 | 9 | def rgb_to_hsv(r, g, b): 10 | return list(colorsys.rgb_to_hsv(r, g, b)) 11 | 12 | 13 | def hsv_to_rgb(h, s, v): 14 | return list(colorsys.hsv_to_rgb(h, s, v)) 15 | 16 | 17 | def find_median_rgb(x1, x2): 18 | hsv_x1 = rgb_to_hsv(x1[0], x1[1], x1[2]) 19 | hsv_x2 = rgb_to_hsv(x2[0], x2[1], x2[2]) 20 | mean = (hsv_x1[0] + hsv_x2[1]) / 2 21 | res1 = hsv_x1.copy() 22 | res1[0] = mean 23 | res2 = hsv_x2.copy() 24 | res2[0] = mean 25 | res3 = hsv_x1.copy() 26 | res3[0] = 1 - mean 27 | res4 = hsv_x2.copy() 28 | res4[0] = 1 - mean 29 | return vals_to_rgb([res1, res2, res3, res4]) 30 | 31 | 32 | def vals_to_rgb(lst): 33 | return [hsv_to_rgb(x[0], x[1], x[2]) for x in lst] 34 | 35 | 36 | def palette_to_triad_palette(predicted_palette, base_colors_num=3): 37 | predicted_palette = predicted_palette[:, :base_colors_num] 38 | btch_new_palette = [] 39 | for btch_ind, pal in enumerate(predicted_palette): 40 | new_colors = [] 41 | new_colors.extend(pal.copy()) 42 | for i, x in enumerate(pal): 43 | for j in range(i + 1, len(pal)): 44 | c = find_median_rgb(x, pal[j]) 45 | new_colors.extend(c) 46 | btch_new_palette.append(new_colors) 47 | return np.array(btch_new_palette) 48 | 49 | 50 | def contrast_color_old(r, g, b): 51 | h, s, v = rgb_to_hsv(r, g, b) 52 | new_h = 0.5 + h if h < 0.5 else h - 0.5 53 | new_v = 50 + v if v < 50 else v - 50 54 | # new_h = 1 - h 55 | return hsv_to_rgb(new_h, 0, new_v) 56 | 57 | 58 | def contrast_color(r, g, b): 59 | from utils.color_contrast import sufficient_contrast, contrast 60 | 61 | background_color = (r, g, b) 62 | i_r, i_g, i_b = caption_color = invert_color(r, g, b) 63 | if sufficient_contrast(caption_color, background_color): 64 | return caption_color 65 | # logger.warning("CAPT: Insufficient contrast, fixing") 66 | black = (10, 10, 10) 67 | white = (230, 230, 230) 68 | black_contrast = contrast(background_color, black) 69 | white_contrast = contrast(background_color, white) 70 | return black if black_contrast > white_contrast else white 71 | 72 | 73 | def invert_color(r, g, b): 74 | return 255 - r, 255 - g, 255 - b 75 | 76 | 77 | if __name__ == '__main__': 78 | print(find_median_rgb((0, 255 / 255, 69 / 255), (0.1, 0.5, 0.3))) 79 | -------------------------------------------------------------------------------- /colorer/transfer_style.py: 
-------------------------------------------------------------------------------- 1 | import random 2 | from enum import Enum 3 | 4 | from lxml import etree as ET 5 | 6 | from colorer.music_palette_dataset import get_main_rgb_palette, get_main_rgb_palette2 7 | 8 | 9 | class TransferAlgoType(Enum): 10 | ColorThiefUnsorted = 1 11 | UnsortedClustering = 2 12 | SortedClustering = 3 13 | 14 | 15 | def transfer(root, png_path, algo_type, use_random): 16 | nodes = root.findall("*[@fill]") 17 | if algo_type == TransferAlgoType.ColorThiefUnsorted: 18 | palette = get_main_rgb_palette(png_path, color_count=len(nodes)) 19 | elif algo_type == TransferAlgoType.UnsortedClustering: 20 | palette = get_main_rgb_palette2(png_path, color_count=len(nodes), sort_colors=False) 21 | else: 22 | palette = get_main_rgb_palette2(png_path, color_count=len(nodes), sort_colors=True) 23 | if len(nodes) < len(palette): 24 | palette = palette[:len(nodes)]  # keep one palette color per fill node 25 | if use_random: 26 | random.shuffle(palette) 27 | for ind, n in enumerate(nodes): 28 | r, g, b = palette[ind % len(palette)] 29 | n.set("fill", f"rgb({r}, {g}, {b})") 30 | 31 | 32 | def transfer_style_str(png_path, svg_str, 33 | algo_type: TransferAlgoType = TransferAlgoType.SortedClustering, 34 | use_random=False): 35 | print(f"Transferring style from `{png_path}` to svg file.") 36 | root = ET.fromstring(svg_str) 37 | transfer(root, png_path, algo_type, use_random) 38 | return ET.tostring(root, encoding="utf-8").decode() 39 | 40 | 41 | def transfer_style(svg_path, png_path, svg_out_path, algo_type: TransferAlgoType = TransferAlgoType.SortedClustering, 42 | use_random=True): 43 | print(f"Transferring style from `{png_path}` to svg file `{svg_path}`. Out path: `{svg_out_path}`") 44 | tree = ET.parse(svg_path) 45 | root = tree.getroot() 46 | transfer(root, png_path, algo_type, use_random) 47 | tree.write(svg_out_path, encoding="utf-8") 48 | 49 | 50 | if __name__ == '__main__': 51 | # transfer_style("test_svg.svg", "A S T R O - Change.jpg", "out.svg") 52 | transfer_style("ABBA - I Do, I Do, I Do, I Do, I Do-1.svg", "img.png", "out1.svg", 53 | algo_type=TransferAlgoType.ColorThiefUnsorted, use_random=True) 54 | # transfer_style("psvg_Marsicans - Wake Up Freya.mp3-1.svg", "A S T R O - Change.jpg", "out2.svg", 55 | # algo_type=TransferAlgoType.ColorThiefUnsorted, use_random=True) 56 | # transfer_style("test_svg2.svg", "img.png", "out2.svg") 57 | # transfer_style("img3.svg", "img.png", "out3.svg") 58 | -------------------------------------------------------------------------------- /captioner_train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding: utf-8 3 | import argparse 4 | import logging 5 | 6 | import torch 7 | from torch.utils.data.dataloader import DataLoader 8 | 9 | from captions.dataset import CaptionDataset 10 | 11 | from captions.train import make_models, train 12 | 13 | logger = logging.getLogger("captioner_main") 14 | logger.addHandler(logging.StreamHandler()) 15 | logger.setLevel(logging.INFO) 16 | 17 | 18 | def get_train_data(checkpoint_dir: str, original_cover_dir: str, clean_cover_dir: str, 19 | batch_size: int, canvas_size: int) -> DataLoader: 20 | dataset = CaptionDataset(checkpoint_dir, original_cover_dir, clean_cover_dir, canvas_size) 21 | dataloader = DataLoader( 22 | dataset, 23 | batch_size=batch_size, 24 | shuffle=True 25 | ) 26 | 27 | return dataloader 28 | 29 | 30 | def main(): 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument("--original_covers", help="Directory
with the original cover images", 33 | type=str, default="./original_covers") 34 | parser.add_argument("--clean_covers", help="Directory with the cover images with captions removed", 35 | type=str, default="./clean_covers") 36 | parser.add_argument("--checkpoint_root", help="Checkpoint location", type=str, default="./checkpoint") 37 | parser.add_argument("--lr", help="Learning rate", type=float, default=0.001) 38 | parser.add_argument("--epochs", help="Number of epochs to train for", type=int, default=138) 39 | parser.add_argument("--batch_size", help="Batch size", type=int, default=64) 40 | parser.add_argument("--canvas_size", help="Image canvas size for learning", type=int, default=256) 41 | parser.add_argument("--display_steps", help="How often to plot the samples", type=int, default=10) 42 | parser.add_argument("--plot_grad", help="Whether to plot the gradients", default=False, action="store_true") 43 | args = parser.parse_args() 44 | print(args) 45 | 46 | # Network properties 47 | num_conv_layers = 3 48 | num_linear_layers = 2 49 | 50 | # Plot properties 51 | bin_steps = 20 # How many steps to aggregate with mean for each plot point 52 | 53 | logger.info("--- Starting captioner_main ---") 54 | 55 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 56 | dataloader = get_train_data( 57 | args.checkpoint_root, args.original_covers, args.clean_covers, 58 | args.batch_size, args.canvas_size 59 | ) 60 | 61 | logger.info("--- Captioner training ---") 62 | captioner = make_models( 63 | canvas_size=args.canvas_size, 64 | num_conv_layers=num_conv_layers, 65 | num_linear_layers=num_linear_layers, 66 | device=device 67 | ) 68 | train(dataloader, captioner, device, { 69 | # Common 70 | "display_steps": args.display_steps, 71 | "bin_steps": bin_steps, 72 | "checkpoint_root": args.checkpoint_root, 73 | "n_epochs": args.epochs, 74 | "lr": args.lr, 75 | "plot_grad": args.plot_grad, 76 | }) 77 | 78 | 79 | if __name__ == '__main__': 80 | main() 81 | -------------------------------------------------------------------------------- /outer/models/base_generator.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Union 2 | 3 | import torch 4 | 5 | from .cover_classes import Cover 6 | from ..SVGContainer import SVGContainer 7 | from ..emotions import Emotion 8 | from ..represent import as_diffvg_render, as_SVGCont, as_SVGCont2 9 | 10 | 11 | class BaseGenerator(torch.nn.Module): 12 | def fwd_cover(self, noise: torch.Tensor, audio_embedding: torch.Tensor, 13 | emotions: Optional[torch.Tensor]) -> Cover: 14 | pass 15 | 16 | def fwd_svgcont(self, noise: torch.Tensor, audio_embedding: torch.Tensor, 17 | emotions: Optional[torch.Tensor]) -> SVGContainer: 18 | pass 19 | 20 | def fwd_images(self, noise: torch.Tensor, audio_embedding: torch.Tensor, 21 | emotions: Optional[torch.Tensor]) -> torch.Tensor: 22 | pass 23 | 24 | def forward(self, noise: torch.Tensor, audio_embedding: torch.Tensor, 25 | emotions: Optional[torch.Tensor], return_psvg=False): 26 | if emotions is not None: 27 | inp = torch.cat((noise, audio_embedding, emotions), dim=1) 28 | else: 29 | inp = torch.cat((noise, audio_embedding), dim=1) 30 | print(f"inp: {inp.shape}") 31 | all_shape_params = self.model_(inp) 32 | print(f"all_shape_params: {all_shape_params.shape}") 33 | assert not torch.any(torch.isnan(all_shape_params)) 34 | 35 | action = as_SVGCont2 if return_psvg else as_diffvg_render 36 | 37 | result = [] 38 | for shape_params in all_shape_params: 39 | 
index = 0 40 | 41 | inc = 3 # RGB, no transparency for the background 42 | background_color = shape_params[index: index + inc] 43 | index += inc 44 | 45 | paths = [] 46 | for _ in range(self.path_count_): 47 | path = {} 48 | 49 | inc = (self.path_segment_count_ * 3 + 1) * 2 50 | path["points"] = (shape_params[index: index + inc].view(-1, 2) * 2 - 0.5) * self.canvas_size_ 51 | index += inc 52 | 53 | path["stroke_width"] = shape_params[index] * self.max_stroke_width_ * self.canvas_size_ 54 | index += 1 55 | 56 | # Colors 57 | inc = 4 # RGBA 58 | path["stroke_color"] = shape_params[index: index + inc] 59 | index += inc 60 | path["fill_color"] = shape_params[index: index + inc] 61 | index += inc 62 | 63 | paths.append(path) 64 | 65 | assert len(paths) == self.path_count_ 66 | image = action( 67 | paths=paths, 68 | background_color=background_color, 69 | canvas_size=self.canvas_size_ 70 | ) 71 | result.append(image) 72 | 73 | if not return_psvg: 74 | result = torch.stack(result) 75 | batch_size = audio_embedding.shape[0] 76 | result_channels = 3 # RGB 77 | assert result.shape == (batch_size, result_channels, self.canvas_size_, self.canvas_size_) 78 | 79 | return result 80 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Project ignores 2 | captions/captioned_images 3 | colorer/check_volume 4 | colorer/check_volume_ckpts 5 | colorer/*.jpg 6 | colorer/*.svg 7 | colorer/*.png 8 | diffvg_tmp_old 9 | diffvg_tmp 10 | fonts 11 | fonts_small 12 | fonts_big 13 | gen_samples_demo 14 | generated_covers_* 15 | not_found_audio 16 | OLD 17 | old_generated 18 | plots 19 | test_audio 20 | vae_check 21 | vae_train.py 22 | utils/sequences.py 23 | utils/tensors.py 24 | utils/weight_init.py 25 | utils/weight_clipper.py 26 | scripts 27 | outer/models/former 28 | outer/models/raster_generator.py 29 | outer/models/rnn_generator.py 30 | outer/*.svg 31 | outer/*.png 32 | repair_all_bash_scripts_to_linux.sh 33 | old_README.md 34 | dataset_demo_4 35 | dataset_emoji_4 36 | dataset_emoji_52 37 | dataset_full_covers 38 | dataset_small_20 39 | diploma_test 40 | weights 41 | 42 | .idea 43 | # Byte-compiled / optimized / DLL files 44 | __pycache__/ 45 | *.py[cod] 46 | *$py.class 47 | 48 | # C extensions 49 | *.so 50 | 51 | # Distribution / packaging 52 | .Python 53 | build/ 54 | develop-eggs/ 55 | dist/ 56 | downloads/ 57 | eggs/ 58 | .eggs/ 59 | lib/ 60 | lib64/ 61 | parts/ 62 | sdist/ 63 | var/ 64 | wheels/ 65 | pip-wheel-metadata/ 66 | share/python-wheels/ 67 | *.egg-info/ 68 | .installed.cfg 69 | *.egg 70 | MANIFEST 71 | 72 | # PyInstaller 73 | # Usually these files are written by a python script from a template 74 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
75 | *.manifest 76 | *.spec 77 | 78 | # Installer logs 79 | pip-log.txt 80 | pip-delete-this-directory.txt 81 | 82 | # Unit test / coverage reports 83 | htmlcov/ 84 | .tox/ 85 | .nox/ 86 | .coverage 87 | .coverage.* 88 | .cache 89 | nosetests.xml 90 | coverage.xml 91 | *.cover 92 | *.py,cover 93 | .hypothesis/ 94 | .pytest_cache/ 95 | 96 | # Translations 97 | *.mo 98 | *.pot 99 | 100 | # Django stuff: 101 | *.log 102 | local_settings.py 103 | db.sqlite3 104 | db.sqlite3-journal 105 | 106 | # Flask stuff: 107 | instance/ 108 | .webassets-cache 109 | 110 | # Scrapy stuff: 111 | .scrapy 112 | 113 | # Sphinx documentation 114 | docs/_build/ 115 | 116 | # PyBuilder 117 | target/ 118 | 119 | # Jupyter Notebook 120 | .ipynb_checkpoints 121 | 122 | # IPython 123 | profile_default/ 124 | ipython_config.py 125 | 126 | # pyenv 127 | .python-version 128 | 129 | # pipenv 130 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 131 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 132 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 133 | # install all needed dependencies. 134 | #Pipfile.lock 135 | 136 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 137 | __pypackages__/ 138 | 139 | # Celery stuff 140 | celerybeat-schedule 141 | celerybeat.pid 142 | 143 | # SageMath parsed files 144 | *.sage.py 145 | 146 | # Environments 147 | .env 148 | .venv 149 | env/ 150 | venv/ 151 | ENV/ 152 | env.bak/ 153 | venv.bak/ 154 | 155 | # Spyder project settings 156 | .spyderproject 157 | .spyproject 158 | 159 | # Rope project settings 160 | .ropeproject 161 | 162 | # mkdocs documentation 163 | /site 164 | 165 | # mypy 166 | .mypy_cache/ 167 | .dmypy.json 168 | dmypy.json 169 | 170 | # Pyre type checker 171 | .pyre/ 172 | -------------------------------------------------------------------------------- /outer/models/my_discriminator.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from torchvision.transforms.functional import gaussian_blur 5 | 6 | from ..emotions import Emotion 7 | 8 | 9 | def transform_img_for_disc(img_tensor: torch.Tensor) -> torch.Tensor: 10 | return gaussian_blur(img_tensor, kernel_size=29) 11 | 12 | 13 | class MyDiscriminator(torch.nn.Module): 14 | def __init__(self, canvas_size: int, audio_embedding_dim: int, has_emotions: bool, 15 | num_conv_layers: int, num_linear_layers: int): 16 | super(MyDiscriminator, self).__init__() 17 | 18 | layers = [] 19 | in_channels = 3 # RGB 20 | out_channels = in_channels * 8 21 | double_channels = True 22 | conv_kernel_size = 6 23 | conv_stride = 4 24 | conv_padding = 1 25 | ds_size = canvas_size 26 | 27 | for i in range(num_conv_layers): 28 | # Dimensions of the downsampled image 29 | ds_size = (ds_size + 2 * conv_padding - conv_kernel_size) // conv_stride + 1 30 | 31 | layers += [ 32 | torch.nn.Conv2d( 33 | in_channels=in_channels, out_channels=out_channels, 34 | kernel_size=conv_kernel_size, 35 | stride=conv_stride, 36 | padding=conv_padding, 37 | bias=False 38 | ), 39 | torch.nn.LayerNorm([ds_size, ds_size]), 40 | torch.nn.LeakyReLU(0.1) 41 | ] 42 | 43 | in_channels = out_channels 44 | if double_channels: 45 | out_channels *= 2 46 | double_channels = False 47 | else: 48 | double_channels = True 49 | conv_kernel_size = max(conv_kernel_size - 1, 3) 50 | conv_stride = max(conv_stride - 1, 2) 51 | 52 | self.model = 
torch.nn.Sequential(*layers) 53 | 54 | out_channels //= 2 # Output channels of the last Conv2d 55 | img_dim = out_channels * (ds_size ** 2) 56 | 57 | in_features = img_dim + audio_embedding_dim 58 | if has_emotions: 59 | in_features += len(Emotion) 60 | layers = [] 61 | for i in range(num_linear_layers - 1): 62 | out_features = in_features // 64 63 | layers += [ 64 | torch.nn.Linear(in_features=in_features, out_features=out_features), 65 | torch.nn.LeakyReLU(0.2), 66 | torch.nn.Dropout2d(0.2) 67 | ] 68 | in_features = out_features 69 | layers += [ 70 | torch.nn.Linear(in_features=in_features, out_features=1) 71 | ] 72 | self.adv_layer = torch.nn.Sequential(*layers) 73 | 74 | def forward(self, img: torch.Tensor, audio_embedding: torch.Tensor, 75 | emotions: Optional[torch.Tensor]) -> torch.Tensor: 76 | transformed_img = transform_img_for_disc(img) 77 | output = self.model(transformed_img) 78 | output = output.reshape(output.shape[0], -1) # Flatten elements in the batch 79 | if emotions is None: 80 | validity = self.adv_layer(torch.cat((output, audio_embedding), dim=1)) 81 | else: 82 | validity = self.adv_layer(torch.cat((output, audio_embedding, emotions), dim=1)) 83 | assert not torch.any(torch.isnan(validity)) 84 | return validity 85 | -------------------------------------------------------------------------------- /outer/models/discriminator.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from torchvision.transforms.functional import gaussian_blur 5 | 6 | from ..emotions import Emotion 7 | 8 | 9 | def transform_img_for_disc(img_tensor: torch.Tensor) -> torch.Tensor: 10 | return gaussian_blur(img_tensor, kernel_size=29) 11 | 12 | 13 | class Discriminator(torch.nn.Module): 14 | 15 | def __init__(self, canvas_size: int, audio_embedding_dim: int, has_emotions: bool, 16 | num_conv_layers: int, num_linear_layers: int): 17 | super(Discriminator, self).__init__() 18 | 19 | layers = [] 20 | in_channels = 3 # RGB 21 | out_channels = in_channels * 8 22 | double_channels = True 23 | conv_kernel_size = 6 24 | conv_stride = 4 25 | conv_padding = 1 26 | ds_size = canvas_size 27 | 28 | for i in range(num_conv_layers): 29 | # Dimensions of the downsampled image 30 | ds_size = (ds_size + 2 * conv_padding - conv_kernel_size) // conv_stride + 1 31 | 32 | layers += [ 33 | torch.nn.Conv2d( 34 | in_channels=in_channels, out_channels=out_channels, 35 | kernel_size=conv_kernel_size, 36 | stride=conv_stride, 37 | padding=conv_padding, 38 | bias=False 39 | ), 40 | torch.nn.LayerNorm([ds_size, ds_size]), 41 | # torch.nn.LeakyReLU(0.1), 42 | torch.nn.ELU(), 43 | ] 44 | 45 | in_channels = out_channels 46 | if double_channels: 47 | out_channels *= 2 48 | double_channels = False 49 | else: 50 | double_channels = True 51 | conv_kernel_size = max(conv_kernel_size - 1, 3) 52 | conv_stride = max(conv_stride - 1, 2) 53 | 54 | self.model = torch.nn.Sequential(*layers) 55 | 56 | out_channels //= 2 # Output channels of the last Conv2d 57 | img_dim = out_channels * (ds_size ** 2) 58 | 59 | in_features = img_dim + audio_embedding_dim 60 | if has_emotions: 61 | in_features += len(Emotion) 62 | layers = [] 63 | for i in range(num_linear_layers - 1): 64 | out_features = in_features // 64 65 | layers += [ 66 | torch.nn.Linear(in_features=in_features, out_features=out_features), 67 | # torch.nn.LeakyReLU(0.2), 68 | torch.nn.ELU(), 69 | # torch.nn.Dropout2d(0.2) 70 | ] 71 | in_features = out_features 72 | layers += [ 73 | 
torch.nn.Linear(in_features=in_features, out_features=1) 74 | ] 75 | self.adv_layer = torch.nn.Sequential(*layers) 76 | 77 | def forward(self, img: torch.Tensor, audio_embedding: torch.Tensor, 78 | emotions: Optional[torch.Tensor]) -> torch.Tensor: 79 | transformed_img = transform_img_for_disc(img) 80 | output = self.model(transformed_img) 81 | output = output.reshape(output.shape[0], -1) # Flatten elements in the batch 82 | if emotions is None: 83 | inp = torch.cat((output, audio_embedding), dim=1) 84 | else: 85 | inp = torch.cat((output, audio_embedding, emotions), dim=1) 86 | validity = self.adv_layer(inp) 87 | assert not torch.any(torch.isnan(validity)) 88 | return validity 89 | -------------------------------------------------------------------------------- /utils/dataset_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from itertools import repeat 4 | import torch 5 | from multiprocessing import Pool 6 | from pathlib import Path 7 | from outer.metadata_extractor import get_tag_map 8 | 9 | logger = logging.getLogger("dataset_utils") 10 | logger.addHandler(logging.StreamHandler()) 11 | # logger.setLevel(logging.INFO) 12 | logger.setLevel(logging.WARNING) 13 | 14 | MUSIC_EXTENSIONS = ['flac', 'mp3'] 15 | IMAGE_EXTENSION = '.jpg' 16 | 17 | 18 | def filename_extension(file_path: str) -> str: 19 | ext_with_dot = os.path.splitext(file_path)[1] 20 | return ext_with_dot[1:] 21 | 22 | 23 | def replace_extension(file_path: str, new_extension: str) -> str: 24 | last_dot = file_path.rfind('.') 25 | return file_path[:last_dot] + new_extension 26 | 27 | 28 | def get_tensor_file(checkpoint_root: str, f: str): 29 | return f'{checkpoint_root}/{f}.pt' 30 | 31 | 32 | def get_cover_from_tags(f): 33 | return get_tag_map(f)['cover'] 34 | 35 | 36 | def process_music_file_to_tensor(f: str, checkpoint_root: str, audio_dir: str, cover_dir: str, num: int = None): 37 | music_tensor_f = get_tensor_file(checkpoint_root, f) 38 | if not os.path.isfile(music_tensor_f): 39 | create_and_save_music_tensor(audio_dir, f, music_tensor_f, num) 40 | 41 | cover_file = f'{cover_dir}/{replace_extension(f, IMAGE_EXTENSION)}' 42 | if not os.path.isfile(cover_file): 43 | logger.info(f'No cover file for {f}, attempting extraction...') 44 | music_file = f'{audio_dir}/{f}' 45 | image = get_cover_from_tags(music_file) 46 | assert image is not None, f'Failed to extract cover for {f}, aborting!' 
47 | image.save(cover_file) 48 | logger.info(f'Cover image for {f} extracted and saved') 49 | 50 | 51 | def create_and_save_music_tensor(audio_dir, f, music_tensor_f, num=None): 52 | logger.info(f'No tensor for {f}, generating...') 53 | music_file = f'{audio_dir}/{f}' 54 | create_and_save_music_file_to_tensor_inner(music_file, music_tensor_f, num) 55 | 56 | 57 | def create_audio_tensors_for_folder(folder_path, out_folder): 58 | files = os.listdir(folder_path) 59 | os.makedirs(out_folder, exist_ok=True) 60 | for f in files: 61 | create_and_save_music_file_to_tensor_inner(f"{folder_path}/{f}", f"{out_folder}/{f}.pt") 62 | 63 | 64 | def create_and_save_music_file_to_tensor_inner(input_file, out_file, num=None): 65 | from outer.audio_extractor import audio_to_embedding 66 | embedding = torch.from_numpy(audio_to_embedding(input_file, num)) 67 | torch.save(embedding, out_file) 68 | 69 | 70 | def create_music_tensor_files(checkpoint_root_, audio_dir, cover_dir): 71 | completion_marker = f'{checkpoint_root_}/COMPLETE' 72 | if os.path.isfile(completion_marker): 73 | dataset_files = sorted([ 74 | f[:-len('.pt')] for f in os.listdir(checkpoint_root_) 75 | if f.endswith('.pt') 76 | ]) 77 | for f in dataset_files: 78 | cover_file = f'{cover_dir}/{replace_extension(f, IMAGE_EXTENSION)}' 79 | assert os.path.isfile(cover_file), f'No cover for {f}' 80 | logger.info(f'Dataset considered complete with {len(dataset_files)} tracks and covers.') 81 | else: 82 | logger.info('Building the dataset based on music') 83 | dataset_files = sorted([ 84 | f for f in os.listdir(audio_dir) 85 | if os.path.isfile(f'{audio_dir}/{f}') and filename_extension(f) in MUSIC_EXTENSIONS 86 | ]) 87 | 88 | Path(checkpoint_root_).mkdir(exist_ok=True) 89 | with Pool(maxtasksperchild=50) as pool: 90 | pool.starmap( 91 | process_music_file_to_tensor, 92 | zip(dataset_files, repeat(checkpoint_root_), repeat(audio_dir), repeat(cover_dir), 93 | [i for i in range(len(dataset_files))]), 94 | chunksize=100 95 | ) 96 | logger.info('Marking the dataset complete.') 97 | Path(completion_marker).touch(exist_ok=False) 98 | return dataset_files 99 | -------------------------------------------------------------------------------- /captions/models/captioner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from utils.edges import detect_edges 4 | 5 | 6 | def make_linear_layer(in_features: int, out_features: int, num_layers: int): 7 | layers = [] 8 | out_features_final = out_features 9 | for i in range(num_layers - 1): 10 | out_features = in_features // 4 11 | layers += [ 12 | torch.nn.Linear(in_features=in_features, out_features=out_features), 13 | torch.nn.LeakyReLU(), 14 | ] 15 | if i != num_layers - 2: 16 | layers.append(torch.nn.Dropout2d(0.2)) 17 | in_features = out_features 18 | layers += [ 19 | torch.nn.Linear(in_features=in_features, out_features=out_features_final), 20 | torch.nn.Sigmoid() 21 | ] 22 | return torch.nn.Sequential(*layers) 23 | 24 | 25 | def make_conv(in_channels: int, out_channels: int, num_layers: int, canvas_size: int) -> (torch.nn.Module, int, int): 26 | layers = [] 27 | ds_size = canvas_size 28 | 29 | conv_kernel_size = 3 30 | conv_stride = 2 31 | conv_padding = 1 32 | 33 | for i in range(num_layers): 34 | layers += [ 35 | torch.nn.Conv2d( 36 | in_channels=in_channels, out_channels=out_channels, 37 | kernel_size=conv_kernel_size, 38 | stride=conv_stride, 39 | padding=conv_padding 40 | ), 41 | torch.nn.LeakyReLU(), 42 | torch.nn.Conv2d( 43 | in_channels=out_channels, 
out_channels=out_channels, 44 | kernel_size=conv_kernel_size, 45 | stride=conv_stride, 46 | padding=conv_padding 47 | ), 48 | torch.nn.LeakyReLU(), 49 | torch.nn.BatchNorm2d(out_channels), 50 | torch.nn.Dropout2d(0.2) 51 | ] 52 | 53 | in_channels = out_channels 54 | out_channels *= 2 55 | # 2 convolutions 56 | for _ in range(2): 57 | ds_size = (ds_size + 2 * conv_padding - conv_kernel_size) // conv_stride + 1 58 | conv_kernel_size = min(conv_kernel_size + 2, 5) 59 | conv_stride = max(conv_stride - 1, 2) 60 | 61 | out_channels //= 2 62 | # The height and width of downsampled image 63 | assert ds_size > 0 64 | img_dim = out_channels * (ds_size ** 2) 65 | 66 | return torch.nn.Sequential(*layers), img_dim 67 | 68 | 69 | class Captioner(torch.nn.Module): 70 | 71 | def __init__(self, canvas_size: int, num_conv_layers: int, num_linear_layers: int): 72 | super(Captioner, self).__init__() 73 | 74 | in_channels = 4 # RGB + grayscale edges 75 | out_channels = in_channels * 16 76 | 77 | self.conv_, img_dim = make_conv( 78 | in_channels=in_channels, 79 | out_channels=out_channels, 80 | num_layers=num_conv_layers, 81 | canvas_size=canvas_size 82 | ) 83 | 84 | self.pos_predictor_ = make_linear_layer( 85 | in_features=img_dim, 86 | out_features=2 * 4, # x, y, width, height 87 | num_layers=num_linear_layers 88 | ) 89 | 90 | self.color_predictor_ = make_linear_layer( 91 | in_features=img_dim, 92 | out_features=2 * 3, # 2 RGB colors 93 | num_layers=num_linear_layers 94 | ) 95 | 96 | def forward(self, img: torch.Tensor, edges: torch.Tensor = None) -> (torch.Tensor, torch.Tensor): 97 | if edges is None: 98 | edges = torch.stack([detect_edges(x) for x in img]).to(img.device) 99 | edges = edges.unsqueeze(dim=1) 100 | inp = torch.cat((img, edges), dim=1) 101 | output = self.conv_(inp) 102 | output = output.reshape(output.shape[0], -1) # Flatten elements in the batch 103 | 104 | pos_pred = self.pos_predictor_(output) 105 | color_pred = self.color_predictor_(output) 106 | assert not torch.any(torch.isnan(pos_pred)) 107 | assert not torch.any(torch.isnan(color_pred)) 108 | 109 | return pos_pred, color_pred 110 | -------------------------------------------------------------------------------- /outer/models/generator.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Union 2 | 3 | import torch 4 | 5 | from ..emotions import Emotion 6 | from ..represent import as_diffvg_render, as_SVGCont, as_SVGCont2 7 | 8 | 9 | def calc_param_dim(path_count: int, path_segment_count: int): 10 | # For 1 path: 11 | # (segment_count * 3 + 1) (x,y) points 12 | # 1 stroke width 13 | # 2 RGBA (stroke color, fill color) 14 | path_params = ((path_segment_count * 3 + 1) * 2 + 1 + 4 * 2) * path_count 15 | # For the canvas: 16 | # 1 RGB (background color) 17 | canvas_params = 3 18 | return path_params + canvas_params 19 | 20 | 21 | class Generator(torch.nn.Module): 22 | def __init__(self, z_dim: int, audio_embedding_dim: int, has_emotions: bool, num_layers: int, canvas_size: int, 23 | path_count: int, path_segment_count: int, max_stroke_width: float): 24 | super(Generator, self).__init__() 25 | 26 | param_dim = calc_param_dim(path_count=path_count, path_segment_count=path_segment_count) 27 | in_features = z_dim + audio_embedding_dim 28 | if has_emotions: 29 | in_features += len(Emotion) 30 | out_features = param_dim 31 | feature_step = (in_features - out_features) // num_layers 32 | 33 | layers = [] 34 | for i in range(num_layers - 1): 35 | out_features = in_features - 
feature_step 36 | layers += [ 37 | torch.nn.Linear(in_features=in_features, out_features=out_features), 38 | torch.nn.BatchNorm1d(num_features=out_features), 39 | torch.nn.LeakyReLU(0.2) 40 | ] 41 | in_features = out_features 42 | layers += [ 43 | torch.nn.Linear(in_features=in_features, out_features=param_dim), 44 | torch.nn.Sigmoid() 45 | ] 46 | 47 | self.model_ = torch.nn.Sequential(*layers) 48 | self.canvas_size_ = canvas_size 49 | self.path_count_ = path_count 50 | self.path_segment_count_ = path_segment_count 51 | self.max_stroke_width_ = max_stroke_width 52 | 53 | def forward(self, noise: torch.Tensor, audio_embedding: torch.Tensor, 54 | emotions: Optional[torch.Tensor], return_psvg=False): 55 | if emotions is not None: 56 | inp = torch.cat((noise, audio_embedding, emotions), dim=1) 57 | else: 58 | inp = torch.cat((noise, audio_embedding), dim=1) 59 | print(f"inp: {inp.shape}") 60 | all_shape_params = self.model_(inp) 61 | print(f"all_shape_params: {all_shape_params.shape}") 62 | assert not torch.any(torch.isnan(all_shape_params)) 63 | 64 | action = as_SVGCont2 if return_psvg else as_diffvg_render 65 | 66 | result = [] 67 | for shape_params in all_shape_params: 68 | index = 0 69 | 70 | inc = 3 # RGB, no transparency for the background 71 | background_color = shape_params[index: index + inc] 72 | index += inc 73 | 74 | paths = [] 75 | for _ in range(self.path_count_): 76 | path = {} 77 | 78 | inc = (self.path_segment_count_ * 3 + 1) * 2 79 | path["points"] = (shape_params[index: index + inc].view(-1, 2) * 2 - 0.5) * self.canvas_size_ 80 | index += inc 81 | 82 | path["stroke_width"] = shape_params[index] * self.max_stroke_width_ * self.canvas_size_ 83 | index += 1 84 | 85 | # Colors 86 | inc = 4 # RGBA 87 | path["stroke_color"] = shape_params[index: index + inc] 88 | index += inc 89 | path["fill_color"] = shape_params[index: index + inc] 90 | index += inc 91 | 92 | paths.append(path) 93 | 94 | assert len(paths) == self.path_count_ 95 | image = action( 96 | paths=paths, 97 | background_color=background_color, 98 | canvas_size=self.canvas_size_ 99 | ) 100 | result.append(image) 101 | 102 | if not return_psvg: 103 | result = torch.stack(result) 104 | batch_size = audio_embedding.shape[0] 105 | result_channels = 3 # RGB 106 | assert result.shape == (batch_size, result_channels, self.canvas_size_, self.canvas_size_) 107 | 108 | return result 109 | -------------------------------------------------------------------------------- /colorer/train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | from torch.utils.data.dataloader import DataLoader 5 | 6 | from utils.checkpoint import save_checkpoint, load_checkpoint 7 | from utils.noise import get_noise 8 | from .models.colorer import Colorer 9 | 10 | logger = logging.getLogger("trainer") 11 | logger.addHandler(logging.StreamHandler()) 12 | logger.setLevel(logging.INFO) 13 | 14 | 15 | def weighted_mse_loss(input, target, weight=None): 16 | if weight is None: 17 | max_weight = input.size()[1] 18 | weight = torch.tensor([(max_weight - i // 3) // 3 for i in range(max_weight)]).to(input.device) 19 | weight = weight.repeat((len(input), 1)) 20 | return (weight * (input - target) ** 2).mean() 21 | 22 | 23 | def train(train_dataloader: DataLoader, gen: Colorer, device: torch.device, training_params: dict, 24 | test_dataloader: DataLoader): 25 | n_epochs = training_params["n_epochs"] 26 | lr = training_params["lr"] 27 | z_dim = training_params["z_dim"] 28 | disc_slices = 
training_params["disc_slices"] 29 | checkpoint_root = training_params["checkpoint_root"] 30 | backup_epochs = training_params["backup_epochs"] 31 | 32 | gen_opt = torch.optim.Adam(gen.parameters(), lr=lr) 33 | criterion = weighted_mse_loss 34 | 35 | # model_name = f'colorer_{gen.color_type}_{gen.colors_count}_colors' 36 | model_name = f'colorer_{gen.colors_count}_colors_{train_dataloader.dataset.sorted_color}' 37 | print("Trying to load checkpoint.") 38 | epochs_done = load_checkpoint(checkpoint_root, model_name, [gen, gen_opt]) 39 | if epochs_done: 40 | logger.info(f"Loaded a checkpoint with {epochs_done} epochs done") 41 | 42 | log_interval = 1 43 | for epoch in range(epochs_done + 1, n_epochs + epochs_done + 1): 44 | gen.train() 45 | running_train_loss = 0.0 46 | for batch_idx, batch in enumerate(train_dataloader): 47 | torch.cuda.empty_cache() 48 | if len(batch) == 3: 49 | audio_embedding, real_palette, emotions = batch 50 | real_palette = real_palette.to(device) 51 | emotions = emotions.to(device) 52 | else: 53 | audio_embedding, real_palette = batch 54 | real_palette = real_palette.to(device) 55 | emotions = None 56 | cur_batch_size = len(audio_embedding) 57 | audio_embedding = audio_embedding.float().to(device) 58 | audio_embedding_disc = audio_embedding[:, :disc_slices].reshape(cur_batch_size, -1) 59 | 60 | z = get_noise(cur_batch_size, z_dim, device=device) 61 | 62 | gen_opt.zero_grad() 63 | net_out = gen(z, audio_embedding_disc, emotions) 64 | loss = criterion(net_out * 255, real_palette) 65 | loss.backward() 66 | gen_opt.step() 67 | running_train_loss += loss.item() 68 | # if (batch_idx + 1) % log_interval == 0: 69 | # print('Train Epoch: {} [({:.0f}%)] Loss: {:.6f}'.format( 70 | # epoch, 100. * batch_idx / len(dataloader), loss.data.item())) 71 | # save_checkpoint(checkpoint_root, model_name, epoch, backup_epochs, [gen, gen_opt]) 72 | if epoch == epochs_done + 1 or epoch % log_interval == 0: 73 | print('Train Epoch: {}. 
Loss: {:.6f}'.format(epoch, loss.item())) 74 | 75 | save_checkpoint(checkpoint_root, model_name, epoch, backup_epochs, [gen, gen_opt]) 76 | 77 | avg_train_loss = running_train_loss / (batch_idx + 1) 78 | 79 | if epoch == epochs_done + 1 or epoch % log_interval == 0: 80 | gen.eval() 81 | running_test_loss = 0.0 82 | for batch_idx, batch in enumerate(test_dataloader): 83 | torch.cuda.empty_cache() 84 | if len(batch) == 3: 85 | audio_embedding, real_palette, emotions = batch 86 | real_palette = real_palette.to(device) 87 | emotions = emotions.to(device) 88 | else: 89 | audio_embedding, real_palette = batch 90 | real_palette = real_palette.to(device) 91 | emotions = None 92 | cur_batch_size = len(audio_embedding) 93 | audio_embedding = audio_embedding.float().to(device) 94 | audio_embedding_disc = audio_embedding[:, :disc_slices].reshape(cur_batch_size, -1) 95 | 96 | z = get_noise(cur_batch_size, z_dim, device=device) 97 | net_out = gen(z, audio_embedding_disc, emotions) 98 | # loss = criterion(net_out * 255, real_palette * 255) 99 | loss = criterion(net_out * 255, real_palette) 100 | running_test_loss += loss 101 | 102 | avg_test_loss = running_test_loss / (batch_idx + 1) 103 | print('LOSS: train {}; valid {}'.format(avg_train_loss, avg_test_loss)) 104 | -------------------------------------------------------------------------------- /captions/train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | from torch.utils.data.dataloader import DataLoader 5 | import torchvision.ops.boxes as bops 6 | 7 | from tqdm.auto import tqdm 8 | 9 | from .models.captioner import Captioner 10 | 11 | from utils.checkpoint import save_checkpoint, load_checkpoint 12 | from utils.plotting import plot_losses, plot_grad_flow 13 | 14 | logger = logging.getLogger("trainer") 15 | logger.addHandler(logging.StreamHandler()) 16 | logger.setLevel(logging.INFO) 17 | 18 | 19 | def unstack(t: torch.Tensor, dim: int) -> torch.Tensor: 20 | # a x (dim * b) -> (a * b) x dim 21 | width = t.shape[1] 22 | if width == dim: 23 | return t 24 | 25 | assert width % dim == 0 26 | result = [] 27 | for i in range(width // dim): 28 | result.append(t[:, i * dim: (i + 1) * dim]) 29 | return torch.cat(result) 30 | 31 | 32 | def calc_iou(a: torch.Tensor, b: torch.Tensor): 33 | pos_dim = 4 34 | a = unstack(a, pos_dim) 35 | b = unstack(b, pos_dim) 36 | a = bops.box_convert(a, in_fmt='xywh', out_fmt='xyxy') 37 | b = bops.box_convert(b, in_fmt='xywh', out_fmt='xyxy') 38 | iou = bops.generalized_box_iou(a, b).diagonal() # We only need GIoU of corresponding boxes 39 | return iou 40 | 41 | 42 | def make_models(canvas_size: int, num_conv_layers: int, num_linear_layers: int, 43 | device: torch.device) -> Captioner: 44 | captioner = Captioner( 45 | canvas_size=canvas_size, 46 | num_conv_layers=num_conv_layers, 47 | num_linear_layers=num_linear_layers 48 | ).to(device) 49 | 50 | return captioner 51 | 52 | 53 | def train(dataloader: DataLoader, captioner: Captioner, device: torch.device, training_params: dict): 54 | logger.info(captioner) 55 | 56 | n_epochs = training_params["n_epochs"] 57 | lr = training_params["lr"] 58 | checkpoint_root = training_params["checkpoint_root"] 59 | display_steps = training_params["display_steps"] 60 | bin_steps = training_params["bin_steps"] 61 | plot_grad = training_params["plot_grad"] 62 | 63 | opt = torch.optim.Adam(captioner.parameters(), lr=lr, betas=(0.5, 0.999)) # momentum=0 64 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, 
mode="max", patience=5, verbose=True) 65 | criterion = torch.nn.MSELoss() 66 | 67 | captioner_name = 'captioner' 68 | epochs_done = load_checkpoint(checkpoint_root, captioner_name, [captioner, opt]) 69 | if epochs_done: 70 | logger.info(f"Loaded a checkpoint with {epochs_done} epochs done") 71 | 72 | cur_step = 0 73 | pos_losses, color_losses = [], [] 74 | pos_val_metrics, color_val_losses = [], [] 75 | 76 | for epoch in range(epochs_done + 1, epochs_done + n_epochs + 1): 77 | for cover, edges, pos_truth, color_truth in tqdm(dataloader): 78 | cover = cover.to(device) 79 | edges = edges.to(device) 80 | pos_truth = pos_truth.to(device) 81 | color_truth = color_truth.to(device) 82 | 83 | opt.zero_grad() 84 | 85 | pos_pred, color_pred = captioner(cover, edges) 86 | assert len(pos_pred) == len(pos_truth) 87 | assert len(color_pred) == len(color_truth) 88 | 89 | pos_loss = criterion(pos_pred, pos_truth) 90 | pos_loss.backward(retain_graph=True) 91 | color_loss = criterion(color_pred, color_truth) 92 | color_loss.backward() 93 | if plot_grad and cur_step % display_steps == 0: 94 | plot_grad_flow(captioner.named_parameters(), "Captioner") 95 | 96 | opt.step() 97 | 98 | pos_losses.append(pos_loss.item()) 99 | color_losses.append(color_loss.item()) 100 | # plot_losses(epoch, cur_step, display_steps, bin_steps, 101 | # [("Positioning", pos_losses), ("Coloring", color_losses)]) 102 | cur_step += 1 103 | 104 | captioner.eval() 105 | pos_val_metric, color_val_loss = 0.0, 0.0 106 | for cover, edges, pos_truth, color_truth in tqdm(dataloader): 107 | cover = cover.to(device) 108 | edges = edges.to(device) 109 | pos_truth = pos_truth.to(device) 110 | color_truth = color_truth.to(device) 111 | 112 | pos_pred, color_pred = captioner(cover, edges) 113 | batch_iou = calc_iou(pos_pred, pos_truth).mean() 114 | batch_criterion = criterion(color_pred, color_truth).mean() 115 | pos_val_metric += batch_iou.item() 116 | color_val_loss += batch_criterion.item() 117 | pos_val_metric /= len(dataloader) 118 | color_val_loss /= len(dataloader) 119 | pos_val_metrics.append(pos_val_metric) 120 | color_val_losses.append(color_val_loss) 121 | plot_losses(epoch, 1, 1, 1, [("Pos IOU Metric", pos_val_metrics)]) 122 | plot_losses(epoch, 1, 1, 1, [("Coloring", color_val_losses)]) 123 | captioner.train() 124 | 125 | scheduler.step(pos_val_metric - color_val_loss) # max-mode 126 | 127 | save_checkpoint(checkpoint_root, captioner_name, epoch, 0, [captioner, opt]) 128 | -------------------------------------------------------------------------------- /outer/dataset.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from itertools import repeat 3 | from multiprocessing import Pool 4 | from pathlib import Path 5 | from typing import Optional, Tuple 6 | 7 | import torchvision.transforms as transforms 8 | from PIL import Image 9 | from torch.utils.data.dataset import Dataset 10 | 11 | from utils.dataset_utils import * 12 | from utils.filenames import normalize_filename 13 | from .emotions import emotions_one_hot, read_emotion_file 14 | 15 | logger = logging.getLogger("dataset") 16 | logger.addHandler(logging.StreamHandler()) 17 | logger.setLevel(logging.WARNING) 18 | 19 | 20 | class AugmentTransform(Enum): 21 | NONE = 0 22 | FLIP_H = 1 23 | FLIP_V = 2 24 | 25 | 26 | def image_file_to_tensor(file_path: str, canvas_size: int, transform: AugmentTransform) -> torch.Tensor: 27 | logger.info(f'Reading image file {file_path} to a tensor') 28 | # Convert to RGB as input can sometimes 
be grayscale 29 | image = Image.open(file_path).resize((canvas_size, canvas_size)).convert('RGB') 30 | if transform == AugmentTransform.FLIP_H: 31 | image = image.transpose(Image.FLIP_LEFT_RIGHT) 32 | elif transform == AugmentTransform.FLIP_V: 33 | image = image.transpose(Image.FLIP_TOP_BOTTOM) 34 | return transforms.ToTensor()(image) # CxHxW 35 | 36 | 37 | def read_music_tensor_for_file(f: str, checkpoint_root: str): 38 | music_tensor_f = get_tensor_file(checkpoint_root, f) 39 | assert os.path.isfile(music_tensor_f), f'Music tensor file missing for {f}' 40 | music_tensor = torch.load(music_tensor_f) 41 | 42 | return music_tensor 43 | 44 | 45 | def read_cover_tensor_for_file(f: str, cover_dir: str, canvas_size: int, image_transform: AugmentTransform): 46 | cover_file = f'{cover_dir}/{replace_extension(f, IMAGE_EXTENSION)}' 47 | assert os.path.isfile(cover_file), f'No cover image on disk for {f}, aborting!' 48 | cover_tensor = image_file_to_tensor(cover_file, canvas_size, image_transform) 49 | 50 | return cover_tensor 51 | 52 | 53 | class MusicDataset(Dataset[Tuple[torch.Tensor, torch.Tensor]]): 54 | def __init__(self, name: str, checkpoint_dir: str, audio_dir: str, cover_dir: str, emotion_file: Optional[str], 55 | canvas_size: int, augment: bool = False, should_cache: bool = True): 56 | self.checkpoint_root_ = f'{checkpoint_dir}/{name}' 57 | self.augment_ = augment 58 | self.cache_ = {} if should_cache else None 59 | 60 | dataset_files = create_music_tensor_files(self.checkpoint_root_, audio_dir, cover_dir) 61 | 62 | self.emotions_dict_ = None 63 | if emotion_file is not None: 64 | if not os.path.isfile(emotion_file): 65 | print(f"WARNING: Emotion file '{emotion_file}' does not exist") 66 | else: 67 | emotions_list = read_emotion_file(emotion_file) 68 | emotions_dict = dict(emotions_list) 69 | self.emotions_dict_ = emotions_dict 70 | for filename in dataset_files: 71 | filename = normalize_filename(filename) 72 | if filename not in emotions_dict: 73 | print(f"Emotions were not provided for dataset file {filename}") 74 | self.emotions_dict_ = None 75 | if self.emotions_dict_ is None: 76 | print("WARNING: Ignoring emotion data, see reasons above.") 77 | else: 78 | for filename, emotions in self.emotions_dict_.items(): 79 | self.emotions_dict_[filename] = emotions_one_hot(emotions) 80 | 81 | self.dataset_files_ = dataset_files 82 | self.cover_dir_ = cover_dir 83 | self.canvas_size_ = canvas_size 84 | 85 | def has_emotions(self): 86 | return self.emotions_dict_ is not None 87 | 88 | def __getitem__(self, index) -> Tuple[torch.Tensor, torch.Tensor]: 89 | if self.cache_ is not None and index in self.cache_: 90 | return self.cache_[index] 91 | 92 | if self.augment_: 93 | track_index = index // len(AugmentTransform) 94 | image_transform = AugmentTransform(index % len(AugmentTransform)) 95 | else: 96 | track_index = index 97 | image_transform = AugmentTransform.NONE 98 | f = self.dataset_files_[track_index] 99 | 100 | music_tensor = read_music_tensor_for_file(f, self.checkpoint_root_) 101 | cover_tensor = read_cover_tensor_for_file(f, self.cover_dir_, self.canvas_size_, image_transform) 102 | emotions = self.emotions_dict_[normalize_filename(f)] if self.emotions_dict_ is not None else None 103 | 104 | target_count = 24 # 2m = 120s, 120/5 105 | if len(music_tensor) < target_count: 106 | music_tensor = music_tensor.repeat(target_count // len(music_tensor) + 1, 1) 107 | music_tensor = music_tensor[:target_count] 108 | 109 | if emotions is not None: 110 | result = music_tensor, cover_tensor, emotions 
111 | # result = music_tensor, cover_tensor, emotions, f 112 | else: 113 | result = music_tensor, cover_tensor 114 | 115 | if self.cache_ is not None: 116 | self.cache_[index] = result 117 | 118 | return result 119 | 120 | def __len__(self) -> int: 121 | result = len(self.dataset_files_) 122 | if self.augment_: 123 | result *= len(AugmentTransform) 124 | return result 125 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding: utf-8 3 | import argparse 4 | import logging 5 | import os 6 | from typing import Optional 7 | 8 | from outer.emotions import Emotion, emotion_from_str 9 | from service import CoverService, OverlayFilter 10 | 11 | logger = logging.getLogger("eval") 12 | logger.addHandler(logging.StreamHandler()) 13 | logger.setLevel(logging.INFO) 14 | 15 | 16 | def filter_from_str(filter_str: str) -> Optional[OverlayFilter]: 17 | try: 18 | return OverlayFilter[filter_str.upper()] 19 | except KeyError: 20 | print(f"Unknown filter: {filter_str}") 21 | return None 22 | 23 | 24 | def main(): 25 | parser = argparse.ArgumentParser() 26 | # Service config 27 | parser.add_argument("--gan1_weights", help="Model weights for CoverGAN", 28 | type=str, default="./weights/covergan_ilya.pt") 29 | parser.add_argument("--gan2_weights", help="Model weights for CoverGAN", 30 | type=str, default="./weights/checkpoint_6figs_5depth_512noise.pt") 31 | parser.add_argument("--captioner_weights", help="Captioner weights", 32 | type=str, default="./weights/captioner.pt") 33 | parser.add_argument("--protosvg_address", help="ProtoSVG rendering server", 34 | type=str, default="localhost:50051") 35 | parser.add_argument("--font_dir", help="Directory with font files", 36 | type=str, default="./fonts") 37 | parser.add_argument("--output_dir", help="Directory where to save SVG covers", 38 | type=str, default="./gen_samples") 39 | parser.add_argument("--num_samples", help="Number of samples to generate", type=int, default=5) 40 | # Input data 41 | parser.add_argument("--audio_file", help="Path to the audio file to process", 42 | type=str, default=None, required=True) 43 | parser.add_argument("--emotions", help="Emotion of the audio file", 44 | type=str, default=None, required=True) 45 | parser.add_argument("--track_artist", help="Track artist", type=str, default=None, required=True) 46 | parser.add_argument("--track_name", help="Track name", type=str, default=None, required=True) 47 | parser.add_argument("--gen_type", help="Type of generator to use (1 or 2)", type=str, default="2", required=True) 48 | parser.add_argument("--captioning_type", help="Type of captioning algo to use (1 or 2)", type=str, 49 | default="2", required=True) 50 | # Other options 51 | parser.add_argument("--rasterize", help="Whether to rasterize the generated cover", default=False, 52 | action="store_true") 53 | parser.add_argument("--filter", help="Overlay filter to apply to the final image", 54 | default=False, action="store_true") 55 | parser.add_argument("--watermark", help="Whether to add watermark", 56 | default=False, action="store_true") 57 | parser.add_argument("--debug", help="Whether to enable debug features", 58 | default=False, action="store_true") 59 | parser.add_argument("--deterministic", help="Whether to disable random noise", 60 | default=False, action="store_true") 61 | 62 | args = parser.parse_args() 63 | print(args) 64 | 65 | logger.info("--- Starting evaluator ---") 66 
| 67 | # Validate the input 68 | audio_file_name = args.audio_file 69 | output_dir = args.output_dir 70 | os.makedirs(output_dir, exist_ok=True) 71 | 72 | emotions: [Emotion] = [emotion_from_str(x) for x in args.emotions.split(',')] 73 | 74 | if audio_file_name is None or None in emotions: 75 | print("ERROR: Missing audio/emotion, exiting.") 76 | return 77 | if args.track_artist is None or args.track_name is None: 78 | print("ERROR: Unspecified track authorship properties.") 79 | return 80 | if not os.path.isfile(audio_file_name): 81 | print("ERROR: The specified audio file does not exist.") 82 | return 83 | 84 | track_artist = args.track_artist 85 | track_name = args.track_name 86 | 87 | # Start the service 88 | service = CoverService( 89 | args.gan1_weights, 90 | args.captioner_weights, 91 | args.gan2_weights, 92 | args.font_dir, 93 | log_level=logging.INFO, 94 | debug=args.debug, deterministic=args.deterministic 95 | ) 96 | 97 | # Generate covers 98 | result = service.generate( 99 | audio_file_name, track_artist, track_name, emotions, 100 | num_samples=args.num_samples, generatorType=args.gen_type, use_captioner=args.captioning_type, 101 | apply_filters=args.filter, rasterize=args.rasterize, watermark=args.watermark 102 | ) 103 | 104 | basename = os.path.basename(audio_file_name) 105 | for i, res in enumerate(result): 106 | if args.rasterize: 107 | (svg_xml, png_data) = res 108 | svg_cover_filename = f"{output_dir}/{basename}-{i + 1}.svg" 109 | png_cover_filename = f"{output_dir}/{basename}-{i + 1}.png" 110 | with open(svg_cover_filename, 'w') as f: 111 | f.write(svg_xml) 112 | with open(png_cover_filename, 'wb') as f: 113 | f.write(png_data) 114 | else: 115 | svg_cover_filename = f"{output_dir}/{basename}-{i + 1}.svg" 116 | with open(svg_cover_filename, 'w') as f: 117 | f.write(res) 118 | 119 | 120 | if __name__ == '__main__': 121 | main() 122 | -------------------------------------------------------------------------------- /utils/plotting.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import torch 7 | from matplotlib.lines import Line2D 8 | from torchvision.transforms.functional import to_pil_image 9 | 10 | logger = logging.getLogger("plotting") 11 | logger.addHandler(logging.StreamHandler()) 12 | logger.setLevel(logging.INFO) 13 | os.makedirs("./plots", exist_ok=True) 14 | 15 | 16 | def plot_losses(epoch: int, cur_step: int, display_steps: int, bin_steps: int, losses: [(str, [float])]): 17 | if cur_step % display_steps == 0 and cur_step > 0: 18 | loss_stats = [] 19 | 20 | for loss_name, loss_values in losses: 21 | loss_mean = sum(loss_values[-display_steps:]) / display_steps 22 | if "loss" not in loss_name.lower() and "metric" not in loss_name.lower(): 23 | loss_name += " Loss" 24 | loss_stats.append(f"{loss_name}: {loss_mean}") 25 | 26 | num_examples = (len(loss_values) // bin_steps) * bin_steps 27 | plt.plot( 28 | range(num_examples // bin_steps), 29 | torch.Tensor(loss_values[:num_examples]).view(-1, bin_steps).mean(1), 30 | label=loss_name 31 | ) 32 | 33 | logger.info(f"Epoch {epoch} (step {cur_step}): " + ", ".join(loss_stats)) 34 | plt.legend() 35 | print("Saving losses to png file") 36 | import covergan_train 37 | plt.savefig(f"{covergan_train.logger.plots_dir}/losses-{epoch}-{cur_step}.png") 38 | plt.show() 39 | plt.close() 40 | elif cur_step == 0: 41 | logger.info("The training is working") 42 | 43 | 44 | def 
plot_grad_flow(named_parameters, model_name: str, epoch=None, cur_step=None): 45 | """Plots the gradients flowing through different layers in the network during training. 46 | Can be used for checking for possible gradient vanishing/exploding problems. 47 | 48 | Usage: Plug this function in Trainer class after loss.backwards() as 49 | `plot_grad_flow(self.model.named_parameters())` to visualize the gradient flow 50 | """ 51 | 52 | # Calculate the stats 53 | avg_grads = [] 54 | max_grads = [] 55 | layers = [] 56 | for n, p in named_parameters: 57 | if p.requires_grad and ("bias" not in n) and p.grad is not None: 58 | layers.append(n) 59 | avg_grads.append(p.grad.abs().mean()) 60 | max_grads.append(p.grad.abs().max()) 61 | 62 | # Initialize plot canvas 63 | fig = plt.figure() 64 | fig.set_size_inches(6, 6) 65 | 66 | # Plot 67 | plt.bar(np.arange(len(max_grads)), max_grads, alpha=0.2, lw=1, color="c") # Max gradients 68 | plt.bar(np.arange(len(max_grads)), avg_grads, alpha=0.2, lw=1, color="b") # Mean gradients 69 | plt.hlines(0, 0, len(avg_grads) + 1, lw=2, color="k") # Zero gradient line 70 | plt.xticks(range(0, len(avg_grads), 1), layers, rotation="vertical") 71 | 72 | # Set display options 73 | plt.xlim(left=0, right=len(avg_grads)) 74 | plt.ylim(bottom=-0.001, top=0.02) # Zoom in on the lower gradient regions 75 | plt.xlabel('Layers') 76 | plt.ylabel('average gradient') 77 | plt.title(f'Gradient flow in {model_name}') 78 | plt.grid(True) 79 | plt.legend( 80 | [ 81 | Line2D([0], [0], color="c", lw=4), 82 | Line2D([0], [0], color="b", lw=4), 83 | Line2D([0], [0], color="k", lw=4) 84 | ], 85 | [ 86 | 'max-gradient', 87 | 'mean-gradient', 88 | 'zero-gradient' 89 | ] 90 | ) 91 | 92 | fig.tight_layout() 93 | plt.plot() 94 | if epoch is not None: 95 | print("Saving grad flow to png file") 96 | import covergan_train 97 | plt.savefig(f"{covergan_train.logger.plots_dir}/grad-flow-{epoch}-{cur_step}.png") 98 | plt.show() 99 | plt.close() 100 | 101 | 102 | def plot_real_fake_covers(real_cover_tensor: torch.Tensor, fake_cover_tensor: torch.Tensor, 103 | disc_real_pred: torch.Tensor = None, disc_fake_pred: torch.Tensor = None, 104 | epoch=None, cur_step=None, plot_saving_dir="./plots"): 105 | sample_count = 5 # max covers to draw 106 | 107 | real_cover_tensor = real_cover_tensor[:sample_count] 108 | fake_cover_tensor = fake_cover_tensor[:sample_count] 109 | 110 | rows = min(sample_count, len(real_cover_tensor)) 111 | cols = 2 112 | fig = plt.figure() 113 | fig.set_size_inches(6 * cols, 6 * rows) 114 | for i in range(rows): 115 | real = real_cover_tensor[i] 116 | fake = fake_cover_tensor[i] 117 | 118 | real_pil = to_pil_image(real) 119 | fake_pil = to_pil_image(fake) 120 | 121 | real_score = disc_real_pred[i].item() if disc_real_pred is not None else None 122 | fake_score = disc_fake_pred[i].item() if disc_fake_pred is not None else None 123 | 124 | for (j, (pil, score)) in enumerate([(real_pil, real_score), (fake_pil, fake_score)]): 125 | plt.subplot(rows, cols, i * cols + j + 1) 126 | plt.imshow(pil) 127 | if score is not None: 128 | plt.text(10, 10, f'{score:.3f}', backgroundcolor='w', fontsize=40.0) 129 | plt.xticks([]) 130 | plt.yticks([]) 131 | 132 | fig.tight_layout() 133 | plt.plot() 134 | if epoch is not None: 135 | print("Saving covers to png file") 136 | import covergan_train 137 | # print(covergan_train.logger.plots_dir) 138 | plt.savefig(f"{covergan_train.logger.plots_dir}/covers-{epoch}-{cur_step}.png") 139 | plt.show() 140 | plt.close() 141 | 
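142 | # Hypothetical usage sketch (illustration only; the toy model, its layer sizes and the "ToyModel" name are assumptions, not part of the original training code). 143 | # It exercises `plot_grad_flow` right after a backward pass, as the docstring above suggests doing inside a trainer after loss.backward(). 144 | if __name__ == '__main__': 145 | toy_model = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU(), torch.nn.Linear(16, 1)) 146 | loss = toy_model(torch.randn(4, 8)).mean() 147 | loss.backward() 148 | # With epoch=None the figure is only shown, not saved to the plots directory. 149 | plot_grad_flow(toy_model.named_parameters(), "ToyModel") 150 |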
-------------------------------------------------------------------------------- /outer/represent.py: -------------------------------------------------------------------------------- 1 | from pydiffvg import * 2 | 3 | from outer.SVGContainer import SVGContainer, RectTag, PathTag 4 | from outer.models.cover_classes import Cover, CoverFigure 5 | 6 | 7 | def tensor_color_to_int(t: torch.Tensor, a: float = None): 8 | t = (t * 255).round().to(int) 9 | r = t[0].item() 10 | g = t[1].item() 11 | b = t[2].item() 12 | if a is None: 13 | a = t[3].item() 14 | else: 15 | a = round(255 * a) 16 | return [r, g, b, a] 17 | 18 | 19 | def color_to_rgba_attr(color): 20 | return f"rgba({color[0]}, {color[1]}, {color[2]}, {round(color[3] / 255.0, 2)})" 21 | 22 | 23 | def to_diffvg_svg_params(paths: [dict], background_color: torch.Tensor, canvas_size: int): 24 | shapes = [] 25 | shape_groups = [] 26 | 27 | # No transparency for the background 28 | background_color = torch.cat((background_color, torch.ones(1, device=background_color.device))) 29 | background_square = pydiffvg.Rect( 30 | p_min=torch.Tensor([0.0, 0.0]), 31 | p_max=torch.Tensor([1.0, 1.0]) * canvas_size, 32 | stroke_width=torch.Tensor([0.0]) 33 | ) 34 | shapes.append(background_square) 35 | background_group = pydiffvg.ShapeGroup( 36 | shape_ids=torch.tensor([0]), 37 | fill_color=background_color 38 | ) 39 | shape_groups.append(background_group) 40 | 41 | # Paths 42 | for p in paths: 43 | if "is_circle" in p: 44 | radius = p["radius"] # radius=torch.tensor(40.0), 45 | center = p["center"] # center=torch.tensor([128.0, 128.0]) 46 | stroke_color = p["stroke_color"] 47 | fill_color = p["fill_color"] 48 | path = pydiffvg.Circle(radius=radius, center=center) 49 | else: 50 | points = p["points"] # For `segment_count` segments 51 | stroke_width = p["stroke_width"] 52 | stroke_color = p["stroke_color"] 53 | fill_color = p["fill_color"] 54 | 55 | # Avoid overlapping points 56 | # Note: this is taken from diffvg GAN example 57 | eps = 1e-4 58 | points = points + eps * torch.randn_like(points) 59 | 60 | control_per_segment = 2 # 3 points = 1 base + 2 control 61 | segment_count = (len(points) - 1) // 3 62 | num_control_points = torch.full((segment_count,), control_per_segment, dtype=torch.int32) 63 | path = pydiffvg.Path(num_control_points=num_control_points, 64 | points=points, 65 | stroke_width=stroke_width, 66 | is_closed=True) 67 | shapes.append(path) 68 | 69 | path_group = pydiffvg.ShapeGroup(shape_ids=torch.tensor([len(shapes) - 1]), 70 | fill_color=fill_color, 71 | stroke_color=stroke_color) 72 | shape_groups.append(path_group) 73 | return canvas_size, canvas_size, shapes, shape_groups 74 | 75 | 76 | def as_diffvg_render(paths: [dict], background_color: torch.Tensor, canvas_size: int) -> torch.Tensor: 77 | params = to_diffvg_svg_params(paths, background_color, canvas_size) 78 | scene_args = pydiffvg.RenderFunction.serialize_scene(*params) 79 | render = pydiffvg.RenderFunction.apply 80 | img = render(canvas_size, # width 81 | canvas_size, # height 82 | 2, # num_samples_x 83 | 2, # num_samples_y 84 | 0, # seed 85 | None, # background_image 86 | *scene_args) 87 | img = img[:, :, :3] # RGBA -> RGB 88 | img = img.permute(2, 0, 1) # HWC -> CHW 89 | 90 | return img 91 | 92 | 93 | def as_SVGCont(cover: Cover, canvas_size: int): 94 | image = SVGContainer(width=canvas_size, height=canvas_size) 95 | backgroundColor = tensor_color_to_int(cover.background_color, 1.0) 96 | rect = RectTag(attrs_dict={"width": canvas_size, "height": canvas_size, 97 | "fill": 
color_to_rgba_attr(backgroundColor)}) 98 | image.add_inner_node(rect) 99 | for p in cover.figures: 100 | points = p.points.round().to(int) # For `self.path_segment_count_` segments 101 | stroke_width = p.stroke_width.round().to(int).item() 102 | stroke_color = tensor_color_to_int(p.stroke_color) 103 | fill_color = tensor_color_to_int(p.fill_color) 104 | 105 | path = PathTag() 106 | 107 | path.move_to(points[0][0].item(), points[0][1].item()) 108 | 109 | segment_count = (len(points) - 1) // 3 110 | for j in range(segment_count): 111 | segment_points = points[1 + j * 3: 1 + (j + 1) * 3] 112 | path.cubic_to(segment_points[0][0].item(), 113 | segment_points[0][1].item(), 114 | segment_points[1][0].item(), 115 | segment_points[1][1].item(), 116 | segment_points[2][0].item(), 117 | segment_points[2][1].item()) 118 | path.close_path() 119 | path.add_attrs({"fill": color_to_rgba_attr(fill_color), 120 | "stroke": color_to_rgba_attr(stroke_color), 121 | "stroke-width": stroke_width, 122 | }) 123 | image.add_inner_node(path) 124 | return image 125 | 126 | 127 | def as_SVGCont2(paths: [dict], background_color: torch.Tensor, canvas_size: int): 128 | cover = Cover() 129 | cover.background_color = background_color 130 | for p in paths: 131 | fig = CoverFigure() 132 | fig.points = p["points"] 133 | fig.stroke_width = p["stroke_width"] 134 | fig.stroke_color = p["stroke_color"] 135 | fig.fill_color = p["fill_color"] 136 | cover.add_figure(fig) 137 | return as_SVGCont(cover, canvas_size) 138 | -------------------------------------------------------------------------------- /colorer_train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding: utf-8 3 | import argparse 4 | import logging 5 | import os 6 | 7 | import torch 8 | from torch.utils.data.dataloader import DataLoader 9 | 10 | from colorer.models.colorer import Colorer 11 | from colorer.models.colorer_dropout import Colorer2 12 | from colorer.music_palette_dataset import MusicPaletteDataset 13 | 14 | logger = logging.getLogger("colorer_train") 15 | logger.addHandler(logging.StreamHandler()) 16 | logger.setLevel(logging.INFO) 17 | 18 | colorer_type = Colorer 19 | colorer_type = Colorer2 20 | 21 | 22 | def get_train_data(checkpoint_dir: str, audio_dir: str, cover_dir: str, emotion_file: str, 23 | batch_size: int, is_for_train: bool = True) -> \ 24 | (DataLoader, int, (int, int, int), bool): 25 | dataset = MusicPaletteDataset("cgan_out_dataset", checkpoint_dir, 26 | audio_dir, cover_dir, emotion_file, is_for_train=is_for_train) 27 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) 28 | music_tensor, palette_tensor = dataset[0][:2] 29 | audio_embedding_dim = music_tensor.shape[1] 30 | palette_shape = palette_tensor.shape 31 | has_emotions = dataset.has_emotions() 32 | 33 | return dataloader, audio_embedding_dim, palette_shape, has_emotions 34 | 35 | 36 | def file_in_folder(dir, file): 37 | if file is None: 38 | return None 39 | return f"{dir}/{file}" 40 | 41 | 42 | def main(): 43 | parser = argparse.ArgumentParser() 44 | parser.add_argument("--train_dir", help="Directory with all folders for training", type=str, default=".") 45 | parser.add_argument("--colors_count", help="Count of colors to predict", type=int, default=6) 46 | parser.add_argument("--plots", help="Directory where save plots while training", type=str, default="plots") 47 | parser.add_argument("--audio", help="Directory with the music files", type=str, default="audio") 48 | 
parser.add_argument("--covers", help="Directory with the cover images", type=str, default="clean_covers") 49 | parser.add_argument("--emotions", help="File with emotion markup for train dataset", type=str, default=None) 50 | parser.add_argument("--test_set", help="Directory with test music files", type=str, default=None) 51 | parser.add_argument("--test_emotions", help="File with emotion markup for test dataset", type=str, default=None) 52 | parser.add_argument("--checkpoint_root", help="Checkpoint location", type=str, default="checkpoint") 53 | parser.add_argument("--lr", help="Learning rate", type=float, default=0.0005) 54 | parser.add_argument("--disc_repeats", help="Discriminator runs per iteration", type=int, default=5) 55 | parser.add_argument("--epochs", help="Number of epochs to train for", type=int, default=8000) 56 | parser.add_argument("--batch_size", help="Batch size", type=int, default=64) 57 | parser.add_argument("--canvas_size", help="Image canvas size for learning", type=int, default=128) 58 | parser.add_argument("--display_steps", help="How often to plot the samples", type=int, default=500) 59 | parser.add_argument("--backup_epochs", help="How often to backup checkpoints", type=int, default=600) 60 | parser.add_argument("--plot_grad", help="Whether to plot the gradients", default=False, action="store_true") 61 | args = parser.parse_args() 62 | print(args) 63 | 64 | # Network properties 65 | num_gen_layers = 5 66 | z_dim = 32 # Dimension of the noise vector 67 | 68 | disc_slices = 6 69 | 70 | # Plot properties 71 | bin_steps = 20 # How many steps to aggregate with mean for each plot point 72 | 73 | logger.info("--- Starting out_main ---") 74 | 75 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 76 | 77 | os.makedirs(file_in_folder(args.train_dir, args.checkpoint_root), exist_ok=True) 78 | 79 | train_dataloader, audio_embedding_dim, img_shape, has_emotions = get_train_data( 80 | file_in_folder(args.train_dir, args.checkpoint_root), 81 | file_in_folder(args.train_dir, args.audio), 82 | file_in_folder(args.train_dir, args.covers), 83 | file_in_folder(args.train_dir, args.emotions), 84 | args.batch_size, is_for_train=True 85 | ) 86 | 87 | test_dataloader, audio_embedding_dim, img_shape, has_emotions = get_train_data( 88 | file_in_folder(args.train_dir, args.checkpoint_root), 89 | file_in_folder(args.train_dir, args.audio), 90 | file_in_folder(args.train_dir, args.covers), 91 | file_in_folder(args.train_dir, args.emotions), 92 | args.batch_size, is_for_train=False 93 | ) 94 | 95 | logger.info("--- Colorer training ---") 96 | gen = colorer_type( 97 | z_dim=z_dim, 98 | audio_embedding_dim=audio_embedding_dim * disc_slices, 99 | has_emotions=has_emotions, 100 | num_layers=num_gen_layers, 101 | colors_count=args.colors_count, 102 | ).to(device) 103 | GAN_MODEL = True 104 | training_params = { 105 | # Common 106 | "display_steps": args.display_steps, 107 | "backup_epochs": args.backup_epochs, 108 | "bin_steps": bin_steps, 109 | "z_dim": z_dim, 110 | "disc_slices": disc_slices, 111 | "checkpoint_root": file_in_folder(args.train_dir, args.checkpoint_root), 112 | # (W)GAN-specific 113 | "n_epochs": args.epochs, 114 | "lr": args.lr, 115 | "disc_repeats": args.disc_repeats, 116 | "plot_grad": args.plot_grad, 117 | } 118 | if not GAN_MODEL: 119 | from colorer.train import train 120 | train(train_dataloader, gen, device, training_params, test_dataloader) 121 | else: 122 | from colorer.models.gan_colorer import ColorerDiscriminator, train 123 | num_disc_layers = 2 124 
| disc = ColorerDiscriminator(audio_embedding_dim=audio_embedding_dim * disc_slices, 125 | has_emotions=has_emotions, 126 | num_layers=num_disc_layers).to(device) 127 | train(train_dataloader, test_dataloader, gen, disc, device, training_params) 128 | 129 | 130 | if __name__ == '__main__': 131 | main() 132 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CoverGAN 2 | 3 |
4 | img1 5 | img2 6 | img3 7 |
8 | 9 | **CoverGAN** is a set of tools and machine learning models designed to generate good-looking album covers based on users' 10 | audio tracks and emotions. The resulting covers are generated in vector graphics format (SVG). 11 | 12 | Available emotions: 13 | 14 | * Anger 15 | * Comfortable 16 | * Fear 17 | * Funny 18 | * Happy 19 | * Inspirational 20 | * Joy 21 | * Lonely 22 | * Nostalgic 23 | * Passionate 24 | * Quiet 25 | * Relaxed 26 | * Romantic 27 | * Sadness 28 | * Serious 29 | * Soulful 30 | * Surprise 31 | * Sweet 32 | * Wary 33 | 34 | The service is available at http://81.3.154.178:5001/covergan. 35 | 36 | ## Service functionality 37 | 38 | * Generation of music covers by analyzing music and emotions 39 | * Several GAN models 40 | * SVG format 41 | * Optional rasterization 42 | * Insertion of readable captions 43 | * A large number of different fonts 44 | * Application of different color filters 45 | * SVG editor 46 | * Convenient color editing 47 | * Style transfer from a provided image 48 | * Saving images at any resolution 49 | 50 | ## Weights 51 | 52 | * The pretrained weights can be downloaded 53 | from [here](https://drive.google.com/file/d/1ArU0TziLBOxhphG4KBshUxPBBECErxu1/view?usp=sharing) 54 | * These weights should be placed into the `./weights` folder 55 | 56 | ## Training 57 | 58 | * See [this README](./README.md) for training details. 59 | 60 | ## Testing using Docker 61 | 62 | Two types of generators are available in this service: 63 | 64 | * The first one creates covers with abstract lines 65 | * The second one draws closed shapes. 66 | 67 | It is also possible to use one of two algorithms for placing captions on the cover: 68 | 69 | * The first algorithm uses the captioner model 70 | * The second is a deterministic algorithm that searches for a suitable location 71 | 72 | The service uses pretrained weights. See [this](README.md#Weights) section. 73 | 74 | ### Building 75 | 76 | * Specify the PyTorch version to install in [`Dockerfile`](./Dockerfile). 77 | 78 | * Build the image by running the `docker_build_covergan_service.sh` script 79 | 80 | ### Running 81 | 82 | * Start the container by running the `docker_run_covergan_service.sh` script 83 | 84 | ### Testing 85 | 86 | Go to `http://localhost:5001` in the browser and enjoy! 87 | 88 | ## Local testing 89 | 90 | ### Install dependencies 91 | 92 | * Install a suitable PyTorch version: `pip install torch torchvision torchaudio` 93 | * Install [DiffVG](https://github.com/BachiLi/diffvg) 94 | * Install the dependencies from [this](./requirements.txt) file 95 | 96 | ### Running 97 | 98 | * Run 99 | 100 | ```sh 101 | python3 ./eval.py \ 102 | --audio_file="test.mp3" \ 103 | --emotions=joy,relaxed \ 104 | --track_artist="Cool Band" \ 105 | --track_name="New Song" 106 | ``` 107 | 108 | * By default, the resulting `.svg` covers will be saved to the [`./gen_samples`](./covergan/gen_samples) folder. 109 | 110 | ## Examples of generated covers 111 | 112 | See the [examples](./examples) folder. 113 | 114 | ## Contents 115 | 116 | * `captions/`: a network that predicts aesthetically matching colors and positions for the captions (artist and track 117 | names). 118 | * `colorer/`: a network that predicts palettes for music covers. 119 | * `docs/`: folder with instructions on how to start training or testing the models. 120 | * `examples/`: folder with sample music tracks, their generated covers, and examples of the original and clean 121 | datasets. 122 | * `fonts/`: folder with fonts downloaded from [Google Fonts](https://fonts.google.com/). 
123 | * `outer/`: the primary GAN that generates vector graphics descriptions from audio files and user-specified emotions. 124 | * `utils/`: parts of the code implementing various independent functionality, separated for convenient reuse. 125 | * `weights/`: folder where the best models are saved. 126 | * `captioner_train.py`: an entry point to trigger the Captioner network training. 127 | * `covergan_train.py`: an entry point to trigger the CoverGAN training. 128 | * `eval.py`: an entry point to trigger the primary flow as a command line tool. 129 | * `service.py`: the primary code flow for album cover generation. 130 | 131 | ## Default structure of the dataset folder 132 | 133 | * `audio/`: default folder with music tracks (`.flac` or `.mp3` format) for CoverGAN training. 134 | * `checkpoint/`: default folder where checkpoints and other intermediate files will be stored while training the CoverGAN and Captioner 135 | networks. 136 | * `clean_covers/`: default folder with covers from which the captions have been removed. 137 | * `original_covers/`: default folder with the original covers. 138 | * `plots/`: the folder where intermediate plots will be saved during training. 139 | * `emotions.json`: file with emotion markup for the train dataset. 140 | 141 | ## Dependencies 142 | 143 | * The machine learning models rely on the popular [PyTorch](https://pytorch.org) framework. 144 | * Differentiable vector graphics rendering is provided by [diffvg](https://github.com/BachiLi/diffvg), which needs to be 145 | built from source. 146 | * Audio feature extraction is based on [Essentia](https://github.com/MTG/essentia); prebuilt pip packages are available. 147 | * Other Python library dependencies include Pillow, Matplotlib, SciPy, and [Kornia](https://kornia.github.io). 148 | 149 | ## Dataset 150 | 151 | The full dataset consists of: 152 | 153 | * Audio tracks 154 | * Original covers 155 | * Cleaned covers 156 | * Fonts 157 | * Marked-up emotions 158 | * Marked-up rectangles for captioner model training 159 | 160 | The dataset can be downloaded from [here](https://drive.google.com/file/d/1_NKlS79y29_he9P3xTLd7SgYbOstCkmO/view?usp=sharing). 161 | 162 | ## Training using Docker with GPU 163 | 164 | * Build the image by running `docker_build.sh` 165 | * See [these](/docs) docs for more details about the options available when training the networks. 166 | * Specify the training command in `covergan_training_command.sh` 167 | * Start the container by running `docker_run.sh` 168 | 169 | ## License 170 | 171 | Shield: [![CC BY-NC-SA 4.0][cc-by-nc-sa-shield]][cc-by-nc-sa] 172 | 173 | This work is licensed under a 174 | [Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License][cc-by-nc-sa]. 
175 | 176 | [![CC BY-NC-SA 4.0][cc-by-nc-sa-image]][cc-by-nc-sa] 177 | 178 | [cc-by-nc-sa]: http://creativecommons.org/licenses/by-nc-sa/4.0/ 179 | 180 | [cc-by-nc-sa-image]: https://licensebuttons.net/l/by-nc-sa/4.0/88x31.png 181 | 182 | [cc-by-nc-sa-shield]: https://img.shields.io/badge/License-CC%20BY--NC--SA%204.0-lightgrey.svg -------------------------------------------------------------------------------- /outer/audio_extractor.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy as np 3 | 4 | # from essentia import Pool 5 | import essentia.standard as es 6 | 7 | logger = logging.getLogger("audio_extractor") 8 | logger.addHandler(logging.StreamHandler()) 9 | logger.setLevel(logging.INFO) 10 | 11 | 12 | # def print_pool(pool: Pool): 13 | # np.set_printoptions(suppress=True) 14 | # for key in sorted(pool.descriptorNames()): 15 | # val = pool[key] 16 | # r = str(pool[key]) 17 | # if isinstance(val, np.ndarray): 18 | # r = f'Array of shape {val.shape}: ' + r 19 | # print(f'* {key}:\n{r}') 20 | 21 | 22 | KEY_NUM_MAP = {'C': 0, 'C#': 1, 'D': 2, 23 | 'Eb': 3, 'E': 4, 'F': 5, 24 | 'F#': 6, 'G': 7, 'Ab': 8, 25 | 'A': 9, 'Bb': 10, 'B': 11} 26 | 27 | SAMPLE_RATE = 44100 # MonoLoader resamples to 44.1 KHz 28 | CHUNK_SECONDS = 10 29 | SLICE_SIZE = SAMPLE_RATE * CHUNK_SECONDS 30 | 31 | 32 | class FeatureExtractor: 33 | def __init__(self): 34 | # --- Spectral: 35 | self.windowing_hann_algo_ = es.Windowing(type='hann') 36 | self.windowing_bh_algo_ = es.Windowing(type='blackmanharris92') 37 | self.spectrum_algo_ = es.Spectrum() # FFT() returns complex FFT, here we want just the magnitude spectrum 38 | self.mfcc_algo_ = es.MFCC(inputSize=SLICE_SIZE // 2 + 1) 39 | self.log_norm_ = es.UnaryOperator(type='log') 40 | self.spectral_contrast_algo_ = es.SpectralContrast(frameSize=SLICE_SIZE) 41 | self.spectral_peaks_algo_ = es.SpectralPeaks() 42 | self.spectrum_whitening_algo_ = es.SpectralWhitening() 43 | # --- Dynamics: 44 | self.loudness_algo_ = es.Loudness() 45 | # --- Rhythm: 46 | self.danceability_algo_ = es.Danceability() 47 | self.rhythm_algo_ = es.RhythmExtractor2013(method="multifeature") 48 | self.onset_algo_ = es.OnsetRate() 49 | # --- Tonal: 50 | self.chromagram_algo_ = es.Chromagram() 51 | self.hpcp_algo_ = es.HPCP() 52 | self.key_algo_ = es.Key(profileType='edma') 53 | 54 | def mfcc(self, s): 55 | mfcc_bands, mfcc_coefficients = self.mfcc_algo_(self.spectrum_algo_(self.windowing_hann_algo_(s))) 56 | melband_log = self.log_norm_(mfcc_bands) 57 | return mfcc_bands, mfcc_coefficients, melband_log 58 | 59 | def spectral(self, s): 60 | spectrum = self.spectrum_algo_(self.windowing_bh_algo_(s)) 61 | 62 | spectral_contrast, spectral_valley = self.spectral_contrast_algo_(spectrum) 63 | frequencies, magnitudes = self.spectral_peaks_algo_(spectrum) 64 | magnitudes = self.spectrum_whitening_algo_(spectrum, frequencies, magnitudes) 65 | 66 | return spectral_contrast, spectral_valley, frequencies, magnitudes 67 | 68 | def dynamics(self, s): 69 | loudness = self.loudness_algo_(s) 70 | return loudness 71 | 72 | def rhythm(self, s): 73 | danceability = self.danceability_algo_(s)[0] 74 | bpm, ticks, confidence, estimates, bpm_intervals = self.rhythm_algo_(s) 75 | # New instance, parametrized by ticks 76 | beats_loudness_algo = es.BeatsLoudness(sampleRate=SAMPLE_RATE, beats=ticks) 77 | mean_beats_loudness = beats_loudness_algo(s)[0].mean() 78 | 79 | self.onset_algo_.reset() # Has to be reset between slices 80 | onsets, onset_rate = 
self.onset_algo_(s) 81 | 82 | return danceability, bpm, mean_beats_loudness, onset_rate 83 | 84 | def chromagram(self, s): 85 | chromagram = [] 86 | const_q_frame_size = 32768 87 | for frame in es.FrameGenerator(s, frameSize=const_q_frame_size, 88 | hopSize=const_q_frame_size, startFromZero=True): 89 | chromagram.append(self.chromagram_algo_(frame)) 90 | chromagram = np.concatenate(chromagram) 91 | 92 | return chromagram 93 | 94 | def tonal(self, frequencies, magnitudes): 95 | hpcp = self.hpcp_algo_(frequencies, magnitudes) 96 | key, key_scale, key_strength = self.key_algo_(hpcp)[:3] 97 | is_major = int(key_scale == "major") 98 | key_num = KEY_NUM_MAP[key] 99 | 100 | return key_num, is_major, key_strength 101 | 102 | def end_track(self): 103 | self.rhythm_algo_.reset() 104 | 105 | 106 | ext = FeatureExtractor() 107 | 108 | 109 | def audio_to_embedding(file_path: str, f_num=None) -> np.array: 110 | if f_num is not None: 111 | logger.info(f'Extracting #{f_num} audio embeddings for {file_path}') 112 | else: 113 | logger.info(f'Extracting audio embeddings for {file_path}') 114 | metadata_reader = es.MetadataReader(filename=file_path, failOnError=True) 115 | metadata = metadata_reader() 116 | pool_meta, meta_duration, meta_bitrate, meta_sample_rate, meta_channels = metadata[7:] 117 | 118 | loader = es.MonoLoader(filename=file_path) 119 | audio = loader() 120 | 121 | # Signal duration 122 | duration_algo = es.Duration() 123 | duration = duration_algo(audio) 124 | 125 | assert abs(meta_duration - duration) < 1.0, f'Incomplete file "{file_path}": meta {meta_duration}, read {duration}' 126 | 127 | # Slice the track in 10s chunks 128 | slices = es.FrameGenerator(audio, frameSize=SLICE_SIZE, hopSize=SLICE_SIZE // 2, startFromZero=True) 129 | 130 | # Compute the audio features on slices 131 | embeddings = [] 132 | for s in slices: 133 | # Spectral 134 | mfcc_bands, mfcc_coefficients, melband_log = ext.mfcc(s) 135 | spectral_contrast, spectral_valley, frequencies, magnitudes = ext.spectral(s) 136 | # Dynamics 137 | loudness = ext.dynamics(s) 138 | # Rhythm 139 | danceability, bpm, mean_beats_loudness, onset_rate = ext.rhythm(s) 140 | # Tonal 141 | chromagram = ext.chromagram(s) 142 | key_num, is_major, key_strength = ext.tonal(frequencies, magnitudes) 143 | 144 | # Normalization 145 | mfcc_bands /= 0.003 # [0, 0.003) 146 | mfcc_coefficients = (mfcc_coefficients + 1500) / 1800 # (-1500, 300) 147 | melband_log = (melband_log + 70) / 70 # (-70, 0] 148 | # chromagram is already normalized 149 | spectral_contrast += 1 # [-1, 0] 150 | spectral_valley = (spectral_valley + 70) / 70 # (-70, 0] 151 | bpm = (bpm - 50) / 150 # (50, 200) 152 | loudness /= 4000 # [0, 4000) 153 | mean_beats_loudness /= 3 # [0, 3) 154 | danceability /= 12 # (0, 12) 155 | onset_rate /= 10 # (0, 10) 156 | key_num /= (len(KEY_NUM_MAP) - 1) # [0, 11] 157 | # is_major is already normalized 158 | key_strength = (key_strength + 1) / 2 # [-1, 1) 159 | 160 | embed = np.concatenate([ 161 | mfcc_bands, mfcc_coefficients, melband_log, chromagram, 162 | spectral_contrast, spectral_valley, 163 | np.array([bpm, loudness, mean_beats_loudness, danceability, 164 | onset_rate, key_num, is_major, key_strength]) 165 | ]) 166 | embeddings.append(embed) 167 | 168 | ext.end_track() 169 | 170 | result = np.stack(embeddings) 171 | 172 | return result 173 | -------------------------------------------------------------------------------- /outer/models/my_generator_fixed_multi_circle.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from typing import * 4 | from colorer.test_model import get_palette_predictor 5 | from .former.modules import TransformerBlock 6 | from .my_generator_circle_paths import create_circle_control_points 7 | from ..emotions import Emotion 8 | from ..represent import * 9 | 10 | 11 | class MyGeneratorFixedMultiCircle(nn.Module): 12 | def __init__(self, z_dim: int, audio_embedding_dim: int, has_emotions: bool, num_layers: int, canvas_size: int, 13 | path_count: int, path_segment_count: int, max_stroke_width: float): 14 | super(MyGeneratorFixedMultiCircle, self).__init__() 15 | self.path_count_in_row = 15 16 | path_count = self.path_count_in_row ** 2 17 | path_segment_count = 2 18 | 19 | in_features = z_dim + audio_embedding_dim 20 | if has_emotions: 21 | in_features += len(Emotion) 22 | 23 | self.circle_center_count = 2 # (x_0, y_0) 24 | self.circle_radius_count = 1 25 | self.background_color_count = 3 # RGB, no transparency for the background 26 | self.all_paths_count = path_count 27 | 28 | self.WITH_DEFORMATION = False 29 | if self.WITH_DEFORMATION: 30 | self.one_path_points_count = (path_segment_count * 3) * 2 31 | self.all_points_count_for_path = self.circle_radius_count + self.one_path_points_count 32 | else: 33 | self.all_points_count_for_path = self.circle_radius_count 34 | 35 | self.WITH_OPACITY = False 36 | if self.WITH_OPACITY: 37 | # 4 = RGBA 38 | self.fill_color = 4 39 | else: 40 | self.fill_color = 3 41 | self.all_points_count_for_path += self.fill_color 42 | out_dim = self.background_color_count + self.all_paths_count * self.all_points_count_for_path 43 | 44 | out_features = out_dim 45 | feature_step = (in_features - out_features) // num_layers 46 | 47 | layers = [] 48 | for i in range(num_layers - 1): 49 | out_features = in_features - feature_step 50 | layers += [ 51 | torch.nn.Linear(in_features=in_features, out_features=out_features), 52 | torch.nn.BatchNorm1d(num_features=out_features), 53 | torch.nn.LeakyReLU(0.2) 54 | ] 55 | in_features = out_features 56 | layers += [ 57 | torch.nn.Linear(in_features=in_features, out_features=out_dim), 58 | torch.nn.Sigmoid() 59 | ] 60 | 61 | # self.transformer_block = TransformerBlock(1, 10, False) 62 | self.model_ = torch.nn.Sequential(*layers) 63 | 64 | self.canvas_size_ = canvas_size 65 | self.path_segment_count_ = path_segment_count 66 | self.max_stroke_width_ = max_stroke_width 67 | 68 | def forward(self, noise: torch.Tensor, audio_embedding: torch.Tensor, 69 | emotions: Optional[torch.Tensor], return_psvg=False) \ 70 | -> Union[torch.Tensor, List[psvg.ProtoSVG]]: 71 | forward_fun = self.my_mega_forward 72 | return forward_fun(noise, audio_embedding, emotions, return_psvg) 73 | 74 | def my_mega_forward(self, noise: torch.Tensor, audio_embedding: torch.Tensor, 75 | emotions: Optional[torch.Tensor], return_psvg=False, return_diffvg_svg_params=False) \ 76 | -> Union[torch.Tensor, List[psvg.ProtoSVG]]: 77 | if emotions is not None: 78 | inp = torch.cat((noise, audio_embedding, emotions), dim=1) 79 | else: 80 | inp = torch.cat((noise, audio_embedding), dim=1) 81 | 82 | # inp = self.transformer_block(inp.view(inp.shape[0], -1, 1)) 83 | # inp = inp.view(inp.shape[0], -1) 84 | all_shape_params = self.model_(inp) 85 | assert not torch.any(torch.isnan(all_shape_params)) 86 | 87 | action = as_protosvg if return_psvg else as_diffvg_render 88 | 89 | result = [] 90 | result_svg_params = [] 91 | for b_idx, shape_params in 
enumerate(all_shape_params): 92 | index = 0 93 | 94 | inc = self.background_color_count 95 | background_color = shape_params[index: index + inc] 96 | index += inc 97 | 98 | paths = [] 99 | for path_idx in range(self.all_paths_count): 100 | path = {} 101 | base_radius = self.canvas_size_ / (self.path_count_in_row * 2) 102 | 103 | inc = self.circle_radius_count 104 | radius = shape_params[index: index + inc] * base_radius * 2.5 # 1.5 105 | index += inc 106 | 107 | centerX = (path_idx % self.path_count_in_row) * (2 * base_radius) + base_radius 108 | centerY = (path_idx // self.path_count_in_row) * (2 * base_radius) + base_radius 109 | center_point = torch.tensor([centerX, centerY]).to(radius.device) 110 | circle_points = create_circle_control_points(center_point, radius, self.path_segment_count_) 111 | 112 | if self.WITH_DEFORMATION: 113 | inc = self.one_path_points_count 114 | # deformation_points = (shape_params[index: index + inc] - 0.5) * 2 * self.canvas_size_ * 0.1 115 | deformation_points = (shape_params[index: index + inc] - 0.5) * 2 * radius * 0.2 116 | index += inc 117 | deformated_path = circle_points + deformation_points 118 | else: 119 | deformated_path = circle_points 120 | deformated_closed_path = torch.cat((deformated_path, deformated_path[:2]), dim=-1) 121 | path["points"] = deformated_closed_path.view(-1, 2) 122 | 123 | inc = self.fill_color 124 | if self.WITH_OPACITY: 125 | path["fill_color"] = shape_params[index: index + inc] 126 | else: 127 | path["fill_color"] = torch.cat((shape_params[index: index + inc], 128 | torch.tensor([1.0]).to(radius.device)), dim=-1) 129 | index += inc 130 | 131 | path["stroke_width"] = torch.tensor(0.0).to(noise.device) 132 | path["stroke_color"] = path["fill_color"] 133 | paths.append(path) 134 | 135 | if return_diffvg_svg_params: 136 | svg_params = to_diffvg_svg_params(paths=paths, 137 | background_color=background_color, 138 | canvas_size=self.canvas_size_) 139 | result_svg_params.append(svg_params) 140 | else: 141 | image = action( 142 | paths=paths, 143 | background_color=background_color, 144 | segment_count=self.path_segment_count_, 145 | canvas_size=self.canvas_size_ 146 | ) 147 | result.append(image) 148 | 149 | if not return_psvg: 150 | result = torch.stack(result) 151 | batch_size = audio_embedding.shape[0] 152 | result_channels = 3 # RGB 153 | assert result.shape == (batch_size, result_channels, self.canvas_size_, self.canvas_size_) 154 | 155 | if return_diffvg_svg_params: 156 | return result_svg_params 157 | 158 | return result 159 | -------------------------------------------------------------------------------- /colorer/models/gan_colorer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Optional 3 | 4 | import torch 5 | from torch.utils.data import DataLoader 6 | 7 | from outer.emotions import Emotion 8 | from utils.checkpoint import load_checkpoint, save_checkpoint 9 | from utils.noise import get_noise 10 | 11 | logger = logging.getLogger("colorer gan trainer") 12 | logger.addHandler(logging.StreamHandler()) 13 | logger.setLevel(logging.INFO) 14 | 15 | 16 | class ColorerDiscriminator(torch.nn.Module): 17 | def __init__(self, audio_embedding_dim: int, has_emotions: bool, num_layers: int): 18 | super(ColorerDiscriminator, self).__init__() 19 | colors_count = 12 20 | self.colors_count = colors_count 21 | in_features = audio_embedding_dim + colors_count * 3 22 | if has_emotions: 23 | in_features += len(Emotion) 24 | out_features = 1 25 | 26 | feature_step = 
(in_features - out_features) // num_layers 27 | 28 | layers = [] 29 | for i in range(num_layers - 1): 30 | out_features = in_features - feature_step 31 | layers += [ 32 | torch.nn.Linear(in_features=in_features, out_features=out_features), 33 | torch.nn.LeakyReLU(0.2), 34 | torch.nn.Dropout2d(0.2) 35 | ] 36 | in_features = out_features 37 | layers += [ 38 | torch.nn.Linear(in_features=in_features, out_features=colors_count), 39 | torch.nn.Sigmoid() 40 | ] 41 | self.adv_layer = torch.nn.Sequential(*layers) 42 | 43 | def forward(self, audio_embedding: torch.Tensor, emotions: Optional[torch.Tensor], 44 | colors: torch.Tensor) -> torch.Tensor: 45 | if emotions is None: 46 | cat = torch.cat((audio_embedding, colors), dim=1) 47 | else: 48 | cat = torch.cat((audio_embedding, emotions, colors), dim=1) 49 | validity = self.adv_layer(cat) 50 | assert not torch.any(torch.isnan(validity)) 51 | return validity 52 | 53 | 54 | def train(train_dataloader: DataLoader, 55 | test_dataloader: DataLoader, 56 | gen, disc: ColorerDiscriminator, 57 | device: torch.device, 58 | training_params: dict): 59 | n_epochs = training_params["n_epochs"] 60 | lr = training_params["lr"] 61 | z_dim = training_params["z_dim"] 62 | disc_slices = training_params["disc_slices"] 63 | checkpoint_root = training_params["checkpoint_root"] 64 | backup_epochs = training_params["backup_epochs"] 65 | 66 | gen_opt = torch.optim.Adam(gen.parameters(), lr=lr) 67 | disc_opt = torch.optim.Adam(disc.parameters(), lr=lr) 68 | model_name = f'GAN_colorer_{gen.colors_count}_colors_{train_dataloader.dataset.sorted_color}' 69 | print("Trying to load checkpoint.") 70 | epochs_done = load_checkpoint(checkpoint_root, model_name, [gen, disc, gen_opt, disc_opt]) 71 | if epochs_done: 72 | logger.info(f"Loaded a checkpoint with {epochs_done} epochs done") 73 | 74 | criterion = torch.nn.MSELoss() 75 | log_interval = 1 76 | colors_count = gen.colors_count 77 | disc_getting_colors_count = 1 78 | for epoch in range(epochs_done + 1, n_epochs + epochs_done + 1): 79 | gen.train() 80 | disc.train() 81 | for batch_idx, batch in enumerate(train_dataloader): 82 | torch.cuda.empty_cache() 83 | if len(batch) == 3: 84 | audio_embedding, real_palette, emotions = batch 85 | real_palette = real_palette.to(device) 86 | emotions = emotions.to(device) 87 | else: 88 | audio_embedding, real_palette = batch 89 | real_palette = real_palette.to(device) 90 | emotions = None 91 | cur_batch_size = len(audio_embedding) 92 | audio_embedding = audio_embedding.float().to(device) 93 | audio_embedding_disc = audio_embedding[:, :disc_slices].reshape(cur_batch_size, -1) 94 | 95 | # Train discriminator 96 | real_outputs = disc(audio_embedding_disc, emotions, real_palette) 97 | real_label = torch.ones(real_palette.shape[0], colors_count).to(device) 98 | z = get_noise(cur_batch_size, z_dim, device=device) 99 | fake_inputs = gen(z, audio_embedding_disc, emotions) 100 | fake_outputs = disc(audio_embedding_disc, emotions, fake_inputs) 101 | fake_label = torch.zeros(fake_inputs.shape[0], colors_count).to(device) 102 | outputs = torch.cat((real_outputs, fake_outputs), dim=0) 103 | targets = torch.cat((real_label, fake_label), dim=0) 104 | D_loss = criterion(outputs, targets) 105 | disc_opt.zero_grad() 106 | D_loss.backward() 107 | disc_opt.step() 108 | # Train generator 109 | z = get_noise(cur_batch_size, z_dim, device=device) 110 | fake_inputs = gen(z, audio_embedding_disc, emotions) 111 | fake_outputs = disc(audio_embedding_disc, emotions, fake_inputs) 112 | fake_targets = 
torch.ones(fake_inputs.shape[0], colors_count).to(device) 113 | G_loss = criterion(fake_outputs, fake_targets) 114 | gen_opt.zero_grad() 115 | G_loss.backward() 116 | gen_opt.step() 117 | if batch_idx % 10 == 0 or batch_idx == len(train_dataloader): 118 | print('Epoch {} Iteration {}: discriminator_loss {:.3f} generator_loss {:.3f}' 119 | .format(epoch, batch_idx, D_loss.item(), G_loss.item())) 120 | save_checkpoint(checkpoint_root, model_name, epoch, backup_epochs, [gen, disc, gen_opt, disc_opt]) 121 | 122 | if epoch == epochs_done + 1 or epoch % log_interval == 0: 123 | gen.eval() 124 | disc.eval() 125 | running_G_test_loss = 0.0 126 | running_D_test_loss = 0.0 127 | for batch_idx, batch in enumerate(test_dataloader): 128 | torch.cuda.empty_cache() 129 | if len(batch) == 3: 130 | audio_embedding, real_palette, emotions = batch 131 | real_palette = real_palette.to(device) 132 | emotions = emotions.to(device) 133 | else: 134 | audio_embedding, real_palette = batch 135 | real_palette = real_palette.to(device) 136 | emotions = None 137 | cur_batch_size = len(audio_embedding) 138 | audio_embedding = audio_embedding.float().to(device) 139 | audio_embedding_disc = audio_embedding[:, :disc_slices].reshape(cur_batch_size, -1) 140 | 141 | z = get_noise(cur_batch_size, z_dim, device=device) 142 | net_out = gen(z, audio_embedding_disc, emotions) 143 | loss = criterion(net_out * 255, real_palette) 144 | running_G_test_loss += loss 145 | 146 | real_outputs = disc(audio_embedding_disc, emotions, real_palette) 147 | real_label = torch.ones(real_palette.shape[0], colors_count).to(device) 148 | fake_inputs = net_out 149 | fake_outputs = disc(audio_embedding_disc, emotions, fake_inputs) 150 | fake_label = torch.zeros(fake_inputs.shape[0], colors_count).to(device) 151 | outputs = torch.cat((real_outputs, fake_outputs), dim=0) 152 | targets = torch.cat((real_label, fake_label), dim=0) 153 | D_loss = criterion(outputs, targets) 154 | running_D_test_loss += D_loss 155 | 156 | avg_G_test_loss = running_G_test_loss / (batch_idx + 1) 157 | avg_D_test_loss = running_D_test_loss / (batch_idx + 1) 158 | print('Test LOSS: gen {}, disc {}'.format(avg_G_test_loss, avg_D_test_loss)) 159 | -------------------------------------------------------------------------------- /utils/bboxes.py: -------------------------------------------------------------------------------- 1 | from math import sqrt 2 | import operator 3 | from typing import Callable 4 | 5 | import scipy.ndimage as ndimage 6 | import scipy.spatial as spatial 7 | 8 | 9 | class BBox(object): 10 | def __init__(self, x1: int, y1: int, x2: int, y2: int): 11 | """ 12 | (x1, y1) is the upper left corner. 13 | (x2, y2) is the lower right corner. 
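        Coordinates are normalized with min/max in the constructor, so the stored
        (x1, y1) is always the upper left corner even if the arguments are passed
        in swapped order, e.g. BBox(10, 10, 2, 2).ul() == (2, 2).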
14 | """ 15 | self.x1 = min(x1, x2) 16 | self.x2 = max(x1, x2) 17 | self.y1 = min(y1, y2) 18 | self.y2 = max(y1, y2) 19 | 20 | def ul(self) -> (int, int): 21 | return self.x1, self.y1 22 | 23 | def lr(self) -> (int, int): 24 | return self.x2, self.y2 25 | 26 | def width(self) -> int: 27 | return self.x2 - self.x1 28 | 29 | def height(self) -> int: 30 | return self.y2 - self.y1 31 | 32 | def wh_ratio(self) -> float: 33 | return self.width() / self.height() 34 | 35 | def area(self) -> int: 36 | return self.width() * self.height() 37 | 38 | def taxicab_diagonal(self) -> int: 39 | # Taxicab distance from (x1, y1) to (x2, y2) 40 | return self.width() + self.height() 41 | 42 | def min_coord_dist(self, other) -> float: 43 | dx = min(abs(self.x1 - other.x1), abs(self.x1 - other.x2), abs(self.x2 - other.x1), abs(self.x2 - other.x2)) 44 | dy = min(abs(self.y1 - other.y1), abs(self.y1 - other.y2), abs(self.y2 - other.y1), abs(self.y2 - other.y2)) 45 | return sqrt(dx ** 2 + dy ** 2) 46 | 47 | def overlaps(self, other) -> bool: 48 | return not ((self.x1 > other.x2) 49 | or (self.x2 < other.x1) 50 | or (self.y1 > other.y2) 51 | or (self.y2 < other.y1)) 52 | 53 | def to_pos(self, canvas_size: int) -> (float, float, float, float): 54 | return ( 55 | self.x1 / canvas_size, 56 | self.y1 / canvas_size, 57 | self.width() / canvas_size, 58 | self.height() / canvas_size 59 | ) 60 | 61 | def split_horizontal(self, ratio: float): 62 | w1 = round(self.width() * ratio) 63 | return BBox(self.x1, self.y1, self.x1 + w1, self.y2), BBox(self.x1 + w1 + 1, self.y1, self.x2, self.y2) 64 | 65 | def split_vertical(self, ratio: float): 66 | h1 = round(self.height() * ratio) 67 | return BBox(self.x1, self.y1, self.x2, self.y1 + h1), BBox(self.x1, self.y1 + h1 + 1, self.x2, self.y2) 68 | 69 | def recanvas(self, cur_canvas: int, new_canvas: int): 70 | """ 71 | Rescale coordinates to a different canvas size 72 | """ 73 | r = new_canvas / cur_canvas 74 | return BBox(int(self.x1 * r), int(self.y1 * r), int(self.x2 * r), int(self.y2 * r)) 75 | 76 | def __str__(self): 77 | return f"BBox ({self.x1}, {self.y1})-({self.x2}, {self.y2}): w={self.width()}, h={self.height()}" 78 | 79 | def __eq__(self, other) -> bool: 80 | return (self.x1 == other.x1 81 | and self.y1 == other.y1 82 | and self.x2 == other.x2 83 | and self.y2 == other.y2) 84 | 85 | def __hash__(self): 86 | return hash((self.x1, self.y1, self.x2, self.y2)) 87 | 88 | 89 | def find_bboxes(img): 90 | filled = ndimage.morphology.binary_fill_holes(img) 91 | coded_blobs, num_blobs = ndimage.label(filled) 92 | data_slices = ndimage.find_objects(coded_blobs) 93 | 94 | bounding_boxes = [] 95 | for s in data_slices: 96 | dy, dx = s[:2] 97 | bounding_boxes.append(BBox(dx.start, dy.start, dx.stop, dy.stop)) 98 | 99 | return bounding_boxes 100 | 101 | 102 | def remove_overlaps(bboxes: [BBox]) -> [BBox]: 103 | """ 104 | Replace overlapping bboxes with the minimal BBox that contains both. 
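    Implementation note: boxes are matched through a KDTree over their corners
    using taxicab (p=1) distance; each overlapping pair is expanded in place to
    their common bounding box, and the returned list contains each resulting
    box exactly once.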
105 | """ 106 | if not bboxes: 107 | return [] 108 | 109 | corners = [] 110 | ul_corners = [b.ul() for b in bboxes] 111 | bbox_map = {} # Corners -> bboxes 112 | 113 | for bbox in bboxes: 114 | for c in (bbox.ul(), bbox.lr()): 115 | corners.append(c) 116 | bbox_map[c] = bbox 117 | 118 | tree = spatial.KDTree(corners) # Quick nearest-neighbor lookup 119 | for c in ul_corners: 120 | bbox = bbox_map[c] 121 | # Find all points within a taxicab distance of the corner 122 | indices = tree.query_ball_point(c, bbox_map[c].taxicab_diagonal(), p=1) 123 | for near_corner in tree.data[indices]: 124 | near_bbox = bbox_map[tuple(near_corner)] 125 | if bbox != near_bbox and bbox.overlaps(near_bbox): 126 | # Expand both bboxes 127 | bbox.x1 = near_bbox.x1 = min(bbox.x1, near_bbox.x1) 128 | bbox.y1 = near_bbox.y1 = min(bbox.y1, near_bbox.y1) 129 | bbox.x2 = near_bbox.x2 = max(bbox.x2, near_bbox.x2) 130 | bbox.y2 = near_bbox.y2 = max(bbox.y2, near_bbox.y2) 131 | return list(set(bbox_map.values())) 132 | 133 | 134 | def filter_bboxes_by_size(bboxes, threshold): 135 | return list(filter(lambda x: x.width() > threshold and x.height() > threshold, bboxes)) 136 | 137 | 138 | def merge_bboxes(bboxes: [BBox]) -> BBox: 139 | assert bboxes 140 | if len(bboxes) == 1: 141 | return bboxes[0] 142 | 143 | x1 = min(b.x1 for b in bboxes) 144 | y1 = min(b.y1 for b in bboxes) 145 | x2 = max(b.x2 for b in bboxes) 146 | y2 = max(b.y2 for b in bboxes) 147 | 148 | return BBox(x1 - 1, y1 - 1, x2, y2) 149 | 150 | 151 | def merge_aligned_bboxes(bboxes: [BBox], canvas_size: int) -> [BBox]: 152 | if not bboxes: 153 | return [] 154 | 155 | shift_threshold = 1 / 4 156 | dim_threshold = 5 / 12 157 | canvas_threshold = 1 / 5 158 | 159 | params: [(str, Callable[[BBox], int])] = [ 160 | ("y1", "x1", BBox.height, BBox.width), # top 161 | ("y2", "x1", BBox.height, BBox.width), # bottom 162 | ("x1", "y1", BBox.width, BBox.height), # left 163 | ("x2", "y1", BBox.width, BBox.height), # right 164 | ] 165 | 166 | for align_param, sort_param, dim_param_f, dist_param_f in params: 167 | bboxes = sorted(bboxes, key=operator.attrgetter(align_param)) 168 | result = [] 169 | 170 | def split_group(cur): 171 | cur = sorted(cur, key=operator.attrgetter(sort_param)) 172 | subgroup = [cur[0]] 173 | for x in cur[1:]: 174 | if x.min_coord_dist(subgroup[-1]) < canvas_size * canvas_threshold: 175 | subgroup.append(x) 176 | else: 177 | result.append(merge_bboxes(subgroup)) 178 | subgroup = [x] 179 | if subgroup: 180 | result.append(merge_bboxes(subgroup)) 181 | 182 | cur_group = [] 183 | prev_pos = getattr(bboxes[0], align_param) 184 | prev_dim = dim_param_f(bboxes[0]) 185 | for b in bboxes: 186 | b_dim = dim_param_f(b) 187 | max_dim = max(b_dim, prev_dim) 188 | valid_shift = abs(getattr(b, align_param) - prev_pos) < max_dim * shift_threshold 189 | valid_dim_change = abs(b_dim - prev_dim) < max_dim * dim_threshold 190 | valid_canvas_ratio = (not cur_group) or b_dim < canvas_size * canvas_threshold 191 | if valid_shift and valid_dim_change and valid_canvas_ratio: 192 | cur_group.append(b) 193 | else: 194 | split_group(cur_group) 195 | cur_group = [b] 196 | prev_pos = getattr(b, align_param) 197 | prev_dim = b_dim 198 | if cur_group: 199 | split_group(cur_group) 200 | bboxes = result 201 | 202 | return bboxes 203 | 204 | 205 | def crop_bboxes_by_canvas(bboxes: [BBox], canvas_size: int): 206 | for b in bboxes: 207 | b.x1 = max(b.x1, 0) 208 | b.y1 = max(b.y1, 0) 209 | b.x2 = min(b.x2, canvas_size - 1) 210 | b.y2 = min(b.y2, canvas_size - 1) 211 | 212 | 213 | def 
sort_bboxes_by_area(bboxes: [BBox]) -> [BBox]: 214 | return sorted(bboxes, key=lambda x: x.area()) 215 | -------------------------------------------------------------------------------- /colorer/test_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding: utf-8 3 | import os 4 | import random 5 | 6 | import torch 7 | from PIL import Image, ImageDraw 8 | 9 | from colorer.colors_transforms import rgb_lab_rgb, rgb_to_cielab 10 | from colorer.models.colorer import Colorer 11 | from colorer.models.colorer_dropout import Colorer2 12 | from colorer.music_palette_dataset import image_file_to_palette 13 | from outer.emotions import emotions_one_hot, Emotion 14 | from utils.noise import get_noise 15 | 16 | 17 | def compare_for_file(f_name, colors_count): 18 | print(f_name) 19 | device = "cuda" if torch.cuda.is_available() else "cpu" 20 | disc_slices = 6 21 | num_samples = 1 22 | emotions = [random.choice(list(Emotion))] 23 | deterministic = False 24 | z_dim = 32 25 | f_name = f_name.replace(".pt", "") 26 | cover_file_name = f"../dataset_full_covers/clean_covers/{f_name.replace('.mp3', '.jpg')}" 27 | audio_checkpoint_fname = f"../dataset_full_covers/checkpoint/cgan_out_dataset/{f_name}.pt" 28 | # music_tensor = torch.from_numpy(audio_to_embedding(audio_file_name)) 29 | music_tensor = torch.load(audio_checkpoint_fname) 30 | target_count = 24 # 2m = 120s, 120/5 31 | if len(music_tensor) < target_count: 32 | music_tensor = music_tensor.repeat(target_count // len(music_tensor) + 1, 1) 33 | music_tensor = music_tensor[:target_count].float().to(device) 34 | music_tensor = music_tensor[:disc_slices].flatten() 35 | music_tensor = music_tensor.unsqueeze(dim=0).repeat((num_samples, 1)) 36 | emotions_tensor = emotions_one_hot(emotions).to(device).unsqueeze(dim=0).repeat((num_samples, 1)) 37 | if deterministic: 38 | noise = torch.zeros((num_samples, z_dim), device=device) 39 | else: 40 | noise = get_noise(num_samples, z_dim, device=device) 41 | generator = get_palette_predictor() 42 | palette = generator.predict(noise, music_tensor, emotions_tensor) 43 | imsize = 512 44 | 45 | palette = [tuple(map(int, c)) for c in palette] 46 | im = Image.new('RGB', (imsize * 3, imsize)) 47 | im.paste(Image.open(cover_file_name).resize((imsize, imsize))) 48 | draw = ImageDraw.Draw(im) 49 | 50 | real_palette = image_file_to_palette(f_name, "../dataset_full_covers/clean_covers", colors_count) 51 | print("=" * 20) 52 | print("real palette: ", real_palette) 53 | rgb_lab_real_pal = [rgb_lab_rgb(p) for p in real_palette] 54 | print("rgb lab real palette: ", rgb_lab_real_pal) 55 | print("lab real palette: ", [list(rgb_to_cielab(p)) for p in real_palette]) 56 | print("predicted: ", [list(p) for p in palette]) 57 | rgb_lab_pred_pal = [rgb_lab_rgb(p) for p in palette] 58 | print("rgb lab predicted palette:", rgb_lab_pred_pal) 59 | print("lab predicted palette: ", [list(rgb_to_cielab(p)) for p in palette]) 60 | 61 | # palettes_to_paint = [real_palette, rgb_lab_real_pal, palette, rgb_lab_pred_pal] 62 | palettes_to_paint = [real_palette, palette] 63 | for pal_i, pal in enumerate(palettes_to_paint): 64 | cur_h = 0 65 | width = imsize / (colors_count - 1) 66 | for i, p in enumerate(pal): 67 | p = tuple(p) 68 | colorval = "#%02x%02x%02x" % p 69 | draw.rectangle((imsize * (pal_i + 1), cur_h, imsize * (pal_i + 2) - 10, cur_h + width), fill=colorval) 70 | cur_h += width 71 | im.show() 72 | 73 | 74 | def get_palette_predictor(device=None, color_predictor_weights=None, 
model_type="2"): 75 | import os 76 | # print("cur path:", os.path.abspath(os.getcwd())) 77 | # print(os.listdir(os.getcwd())) 78 | if device is None: 79 | device = "cuda" if torch.cuda.is_available() else "cpu" 80 | colors_count = 12 81 | disc_slices = 6 82 | z_dim = 32 83 | num_gen_layers = 5 84 | audio_embedding_dim = 281 85 | if color_predictor_weights is None or color_predictor_weights == "": 86 | # model_path = f"dataset_full_covers/checkpoint/colorer_{colors_count}_colors-28800.pt" 87 | # model_path = f"dataset_full_covers/checkpoint/colorer_{colors_count}_colors-1200.pt" 88 | model_path = f"dataset_full_covers/checkpoint/colorer_{colors_count}_colors-1800.pt" 89 | # model_path = f"dataset_full_covers/checkpoint/colorer_{colors_count}_colors_sorted-100.pt" 90 | # model_path = f"dataset_full_covers/checkpoint/GAN_colorer_{colors_count}_colors_sorted-200.pt" 91 | if os.getcwd().endswith("covergan") or os.getcwd().endswith("scratch"): 92 | gan_weights = f"./{model_path}" 93 | else: 94 | gan_weights = f"../{model_path}" 95 | else: 96 | gan_weights = color_predictor_weights 97 | 98 | if model_type == "1": 99 | gen_type = Colorer 100 | else: 101 | gen_type = Colorer2 102 | generator = gen_type( 103 | z_dim=z_dim, 104 | audio_embedding_dim=audio_embedding_dim * disc_slices, 105 | has_emotions=True, 106 | num_layers=num_gen_layers, 107 | colors_count=colors_count) 108 | generator.eval() 109 | gan_weights = torch.load(gan_weights, map_location=device) 110 | generator.load_state_dict(gan_weights["0_state_dict"]) 111 | generator.to(device) 112 | return generator 113 | 114 | 115 | def cmp(): 116 | f_names = [ 117 | "&me - The Rapture Pt.II.mp3", 118 | "Zomboy - Lone Wolf.mp3", 119 | "Zero 7 - Home.mp3", 120 | "ZAYSTIN - Without You.mp3", 121 | "Yu Jae Seok - Dancing King.mp3", 122 | "Vargas & Lagola - Selfish.mp3", 123 | "voisart - Like Glass.mp3", 124 | "Кино - Группа крови.mp3", 125 | # "Кино - Звезда по имени Солнце.mp3", 126 | # "Young the Giant - Mirrorball.mp3", 127 | # "Younger Hunger - Elmer.mp3", 128 | # "Younger Hunger - Straight Face.mp3", 129 | # "Zack Martino - Hold On To Me.mp3", 130 | # "Zara Larsson - WOW (feat. 
Sabrina Carpenter).mp3", 131 | # "ZAYSTIN - Rather Go Blind.mp3", 132 | # "ZAYSTIN - Without You.mp3", 133 | # "Zoé - SKR.mp3", 134 | # "Zoé - Azul.mp3.pt", 135 | # "АИГЕЛ - Офигенно.mp3.pt", 136 | # "АИГЕЛ - Ул.mp3.pt", 137 | # "РОУКС - Bread.mp3.pt", 138 | # "Саша Ролекс - Плакали.mp3.pt", 139 | # "Саша Ролекс - Рана.mp3.pt", 140 | ] 141 | for f_name in f_names: 142 | compare_for_file(f_name, 12) 143 | 144 | 145 | def xxx(): 146 | ckpts = 'check_volume_ckpts' 147 | for ckpt in os.listdir(ckpts): 148 | audio_checkpoint_fname = f"{ckpts}/{ckpt}" 149 | run_ckpt(audio_checkpoint_fname) 150 | 151 | 152 | def run_ckpt(audio_checkpoint_fname): 153 | device = "cuda" if torch.cuda.is_available() else "cpu" 154 | colors_count = 12 155 | disc_slices = 6 156 | num_samples = 1 157 | deterministic = False 158 | z_dim = 32 159 | print(audio_checkpoint_fname) 160 | emotions = [random.choice(list(Emotion))] 161 | music_tensor = torch.load(audio_checkpoint_fname) 162 | target_count = 24 # 2m = 120s, 120/5 163 | if len(music_tensor) < target_count: 164 | music_tensor = music_tensor.repeat(target_count // len(music_tensor) + 1, 1) 165 | music_tensor = music_tensor[:target_count].float().to(device) 166 | music_tensor = music_tensor[:disc_slices].flatten() 167 | music_tensor = music_tensor.unsqueeze(dim=0).repeat((num_samples, 1)) 168 | emotions_tensor = emotions_one_hot(emotions).to(device).unsqueeze(dim=0).repeat((num_samples, 1)) 169 | if deterministic: 170 | noise = torch.zeros((num_samples, z_dim), device=device) 171 | else: 172 | noise = get_noise(num_samples, z_dim, device=device) 173 | generator = get_palette_predictor() 174 | palette = generator.predict(noise, music_tensor, emotions_tensor) 175 | imsize = 512 176 | palette = [tuple(map(int, c)) for c in palette] 177 | im = Image.new('RGB', (imsize * 3, imsize)) 178 | draw = ImageDraw.Draw(im) 179 | palettes_to_paint = [palette] 180 | for pal_i, pal in enumerate(palettes_to_paint): 181 | cur_h = 0 182 | width = imsize / (colors_count - 1) 183 | for i, p in enumerate(pal): 184 | p = tuple(p) 185 | colorval = "#%02x%02x%02x" % p 186 | draw.rectangle((imsize * (pal_i + 1), cur_h, imsize * (pal_i + 2) - 10, cur_h + width), fill=colorval) 187 | cur_h += width 188 | im.show() 189 | 190 | 191 | if __name__ == '__main__': 192 | # cmp() 193 | # xxx() 194 | run_ckpt("../diploma_test/test_random_music/music_ckpts/George Michael - Carles Whisper.mp3.pt") 195 | -------------------------------------------------------------------------------- /colorer/music_palette_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import numpy as np 4 | from torch.utils.data.dataset import Dataset 5 | 6 | from outer.emotions import emotions_one_hot, read_emotion_file 7 | from utils.dataset_utils import * 8 | from utils.filenames import normalize_filename 9 | from utils.image_clustering import cluster 10 | 11 | logger = logging.getLogger("dataset") 12 | logger.addHandler(logging.StreamHandler()) 13 | # logger.setLevel(logging.INFO) 14 | logger.setLevel(logging.WARNING) 15 | 16 | 17 | def get_main_rgb_palette(f_path: str, color_count: int, quality=5): 18 | from colorthief import ColorThief 19 | color_thief = ColorThief(f_path) 20 | return color_thief.get_palette(color_count=color_count + 1, quality=quality) 21 | 22 | 23 | def get_main_rgb_palette2(f_path: str, color_count: int, sort_colors=True): 24 | import PIL.Image 25 | pil_image = PIL.Image.open(f_path).convert(mode='RGB') 26 | labels, 
centers = cluster(np.asarray(pil_image), k=color_count, only_labels_centers=True) 27 | if not sort_colors: 28 | return [list(x) for x in centers] 29 | 30 | from collections import Counter 31 | dict = Counter(labels.flatten()) 32 | sorted_labels = sorted(dict.items(), key=lambda x: x[1], reverse=True) 33 | palette = [list(centers[label[0]]) for label in sorted_labels] 34 | return palette 35 | 36 | 37 | def image_file_to_palette(f: str, cover_dir: str, color_count: int, color_type: str = 'rgb', 38 | sorted_colors: bool = True): 39 | cover_file = f'{cover_dir}/{replace_extension(f, IMAGE_EXTENSION)}' 40 | if sorted_colors: 41 | palette = get_main_rgb_palette2(cover_file, color_count, sort_colors=True) 42 | else: 43 | palette = get_main_rgb_palette(cover_file, color_count) 44 | palette = [list(t) for t in palette] 45 | palette = palette + palette[:color_count - len(palette)] 46 | assert len(palette) == color_count 47 | 48 | if color_type == 'rgb': 49 | return palette 50 | if color_type == 'lab': 51 | from colorer.colors_transforms import rgb_to_cielab, cielab_rgb_to 52 | print("before:", palette) 53 | palette_very_before = palette.copy() 54 | palette = [rgb_to_cielab(np.array(p)) for p in palette] 55 | print("after:", palette) 56 | palette_old_check = [cielab_rgb_to(np.array(p)) for p in palette] 57 | palette_old_check2 = 255 * np.array(palette_old_check) 58 | print("after:", palette_old_check) 59 | print("after:", palette_old_check2) 60 | return palette 61 | 62 | 63 | def process_cover_to_palette(f: str, checkpoint_root: str, cover_dir: str, palette_count: int, 64 | color_type: str = 'rgb', sorted_colors: bool = True, num: int = None): 65 | palette_tensor_f = get_tensor_file(checkpoint_root, replace_extension(f, "")) 66 | if not os.path.isfile(palette_tensor_f): 67 | logger.info(f'No palette tensor for file #{num}: {f}, generating...') 68 | palette = image_file_to_palette(f, cover_dir, palette_count, 69 | color_type=color_type, sorted_colors=sorted_colors) 70 | palette = np.concatenate(palette) # / 255 71 | palette_tensor = torch.from_numpy(palette).float() 72 | torch.save(palette_tensor, palette_tensor_f) 73 | 74 | 75 | def read_tensor_from_file(f: str, checkpoint_root: str): 76 | tensor_f = get_tensor_file(checkpoint_root, f) 77 | assert os.path.isfile(tensor_f), f'Tensor file missing for {tensor_f}' 78 | return torch.load(tensor_f) 79 | 80 | 81 | class MusicPaletteDataset(Dataset[Tuple[torch.Tensor, torch.Tensor]]): 82 | def __init__(self, name: str, checkpoint_dir: str, audio_dir: str, cover_dir: str, emotion_file: Optional[str], 83 | sort_colors: bool = True, 84 | should_cache: bool = True, 85 | is_for_train: bool = True, train_test_split_coef: float = 0.9): 86 | self.color_type = 'rgb' 87 | # self.color_type = 'lab' 88 | self.sorted_color = 'sorted' if sort_colors else 'unsorted' 89 | self.palette_count = 12 90 | self.palette_name = 'palette_dataset' 91 | self.checkpoint_root_ = f'{checkpoint_dir}/{name}' 92 | self.palette_checkpoint_root_ = f'{checkpoint_dir}/{self.palette_name}/' \ 93 | f'palette_{self.color_type}_count_{self.palette_count}_{self.sorted_color}' 94 | self.cache_ = {} if should_cache else None 95 | 96 | self.is_for_train = is_for_train 97 | self.train_test_split_coef = train_test_split_coef 98 | 99 | self.create_palette_tensor_files(cover_dir) 100 | create_music_tensor_files(self.checkpoint_root_, audio_dir, cover_dir) 101 | 102 | self.dataset_files_ = self.get_dataset_files() 103 | 104 | self.emotions_dict_ = None 105 | if emotion_file is not None: 106 | if not 
os.path.isfile(emotion_file): 107 | print(f"WARNING: Emotion file '{emotion_file}' does not exist") 108 | else: 109 | emotions_list = read_emotion_file(emotion_file) 110 | emotions_dict = dict(emotions_list) 111 | self.emotions_dict_ = emotions_dict 112 | for f in self.dataset_files_: 113 | f = normalize_filename(f) 114 | if f not in emotions_dict: 115 | print(f"Emotions were not provided for dataset file {f}") 116 | self.emotions_dict_ = None 117 | if self.emotions_dict_ is None: 118 | print("WARNING: Ignoring emotion data, see reasons above.") 119 | else: 120 | for f, emotions in self.emotions_dict_.items(): 121 | self.emotions_dict_[f] = emotions_one_hot(emotions) 122 | 123 | self.cover_dir_ = cover_dir 124 | 125 | def get_dataset_files(self): 126 | dataset_files = sorted([ 127 | f[:-len('.pt')] for f in os.listdir(self.checkpoint_root_) 128 | if f.endswith('.pt') 129 | ]) 130 | for_train_files_count = int(len(dataset_files) * self.train_test_split_coef) 131 | print(f"--- {for_train_files_count} files considered for training") 132 | print(f"--- {len(dataset_files) - for_train_files_count} files considered for testing") 133 | if self.is_for_train: 134 | dataset_files = dataset_files[:for_train_files_count] 135 | else: 136 | dataset_files = dataset_files[for_train_files_count:] 137 | return dataset_files 138 | 139 | def create_palette_tensor_files(self, cover_dir: str): 140 | completion_marker = f'{self.palette_checkpoint_root_}/COMPLETE' 141 | if not os.path.isfile(completion_marker): 142 | logger.info('Building the palette dataset based on covers') 143 | dataset_files = sorted([ 144 | f for f in os.listdir(cover_dir) 145 | if os.path.isfile(f'{cover_dir}/{f}') and filename_extension(f) in IMAGE_EXTENSION 146 | ]) 147 | os.makedirs(self.palette_checkpoint_root_, exist_ok=True) 148 | with Pool(maxtasksperchild=50) as pool: 149 | pool.starmap( 150 | process_cover_to_palette, 151 | zip(dataset_files, repeat(self.palette_checkpoint_root_), repeat(cover_dir), 152 | repeat(self.palette_count), 153 | repeat(self.color_type), 154 | repeat(self.sorted_color), 155 | [i for i in range(len(dataset_files))]), 156 | chunksize=100 157 | ) 158 | logger.info('Marking the palette dataset complete.') 159 | Path(completion_marker).touch(exist_ok=False) 160 | else: 161 | dataset_files = sorted([ 162 | f[:-len('.pt')] for f in os.listdir(self.palette_checkpoint_root_) 163 | if f.endswith('.pt') 164 | ]) 165 | for f in dataset_files: 166 | cover_file = f'{cover_dir}/{f}{IMAGE_EXTENSION}' 167 | assert os.path.isfile(cover_file), f'No cover for {f}' 168 | logger.info(f'Palette dataset considered complete with {len(dataset_files)} covers.') 169 | 170 | def has_emotions(self): 171 | return self.emotions_dict_ is not None 172 | 173 | def __getitem__(self, index) -> Tuple[torch.Tensor, torch.Tensor]: 174 | if self.cache_ is not None and index in self.cache_: 175 | return self.cache_[index] 176 | 177 | track_index = index 178 | f = self.dataset_files_[track_index] 179 | 180 | music_tensor = read_tensor_from_file(f, self.checkpoint_root_) 181 | palette_tensor = read_tensor_from_file(replace_extension(f, ""), self.palette_checkpoint_root_) 182 | emotions = self.emotions_dict_[normalize_filename(f)] if self.emotions_dict_ is not None else None 183 | 184 | target_count = 24 # 2m = 120s, 120/5 185 | if len(music_tensor) < target_count: 186 | music_tensor = music_tensor.repeat(target_count // len(music_tensor) + 1, 1) 187 | music_tensor = music_tensor[:target_count] 188 | 189 | if emotions is not None: 190 | result = 
/covergan_train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # coding: utf-8
3 | import argparse
4 | import logging
5 | import os
6 |
7 | import torch
8 | from torch.utils.data.dataloader import DataLoader
9 |
10 | from colorer.test_model import get_palette_predictor
11 | from outer.dataset import MusicDataset
12 |
13 | from outer.train import make_gan_models, train
14 |
15 | from utils.noise import get_noise
16 | from utils.plotting import plot_real_fake_covers
17 |
18 | logger = logging.getLogger("out_main")
19 | logger.addHandler(logging.StreamHandler())
20 | logger.setLevel(logging.INFO)
21 |
22 |
23 | def get_train_data(checkpoint_dir: str, audio_dir: str, cover_dir: str, emotion_file: str,
24 |                    batch_size: int, canvas_size: int,
25 |                    augment_dataset: bool) -> (DataLoader, int, (int, int, int), bool):
26 |     dataset = MusicDataset("cgan_out_dataset", checkpoint_dir,
27 |                            audio_dir, cover_dir, emotion_file,
28 |                            canvas_size, augment_dataset)
29 |     dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
30 |     music_tensor, cover_tensor = dataset[0][:2]
31 |     audio_embedding_dim = music_tensor.shape[1]
32 |     img_shape = cover_tensor.shape
33 |     has_emotions = dataset.has_emotions()
34 |
35 |     return dataloader, audio_embedding_dim, img_shape, has_emotions
36 |
37 |
38 | def get_test_data(checkpoint_dir: str, test_set_dir: str, test_emotion_file: str,
39 |                   batch_size: int, canvas_size: int) -> DataLoader:
40 |     test_dataset = MusicDataset("cgan_out_test_dataset", checkpoint_dir,
41 |                                 test_set_dir, test_set_dir,
42 |                                 test_emotion_file, canvas_size)
43 |     test_dataloader = DataLoader(
44 |         test_dataset,
45 |         batch_size=batch_size,
46 |         shuffle=True
47 |     )
48 |     return test_dataloader
49 |
50 |
51 | def demo_samples(gen, dataloader: DataLoader, z_dim: int, disc_slices: int, device: torch.device,
52 |                  palette_generator=None):
53 |     def generate(z, audio_embedding_disc, emotions):
54 |         if palette_generator is None:
55 |             return gen(z, audio_embedding_disc, emotions)
56 |         return gen(z, audio_embedding_disc, emotions, palette_generator=palette_generator)
57 |
58 |     gen.eval()
59 |
60 |     sample_count = 5  # max covers to draw
61 |
62 |     with torch.no_grad():
63 |         for batch in dataloader:
64 |             if len(batch) == 2:
65 |                 audio_embedding, real_cover_tensor = batch
66 |                 emotions = None
67 |             else:
68 |                 audio_embedding, real_cover_tensor, emotions = batch
69 |                 emotions = emotions[:sample_count].to(device)
70 |             sample_count = min(sample_count, len(audio_embedding))
71 |             audio_embedding = audio_embedding[:sample_count].float().to(device)
72 |             audio_embedding = audio_embedding[:, :disc_slices].reshape(sample_count, -1)
73 |             real_cover_tensor = real_cover_tensor[:sample_count].to(device)
74 |
75 |             noise = get_noise(sample_count, z_dim, device=device)
76 |             fake_cover_tensor = generate(noise, audio_embedding, emotions)
77 |
78 |             plot_real_fake_covers(real_cover_tensor, fake_cover_tensor)
79 |             break  # we only want one batch
80 |
81 |     gen.train()
82 |
83 |
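# Shape bookkeeping for demo_samples above (inferred from the slicing logic, so
# treat the exact dimensions as an assumption rather than a documented contract):
#   - each batch yields audio embeddings of shape (batch, num_slices, embedding_dim);
#     only the first `disc_slices` slices are kept and flattened to
#     (sample_count, disc_slices * embedding_dim) before being fed to the generator;
#   - get_noise(sample_count, z_dim, device=device) supplies the latent vectors;
#   - the generated covers are compared against the first `sample_count` real
#     covers of the batch via plot_real_fake_covers.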
84 | def display_dataset_objects(dataloader: DataLoader, disc_slices: int):
85 |     from sklearn.decomposition import PCA
86 |     import matplotlib.pyplot as plt
87 |
88 |     all_audio_embeddings = []
89 |     for batch in dataloader:
90 |         audio_embedding = batch[0]
91 |         batch_size = len(audio_embedding)
92 |         audio_embedding = audio_embedding[:, :disc_slices].reshape(batch_size, -1)
93 |         all_audio_embeddings.append(audio_embedding.float())
94 |     all_audio_embeddings = torch.cat(all_audio_embeddings, dim=0).cpu().numpy()
95 |
96 |     pca = PCA(n_components=2)
97 |     xy = pca.fit_transform(all_audio_embeddings)
98 |     x = list(xy[:, 0])
99 |     y = list(xy[:, 1])
100 |
101 |     plt.scatter(x, y, s=1)
102 |     plt.title("Training tracks")
103 |     plt.show()
104 |
105 |
106 | def file_in_folder(dir, file):
107 |     if file is None:
108 |         return None
109 |     return f"{dir}/{file}"
110 |
111 |
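# A hypothetical command line for this script (paths and file names are
# placeholders; the flags mirror the argparse options defined in main() below,
# and the repository's covergan_training_command.sh / docs/covergan_train_help.txt
# remain the authoritative references):
#
#     python3 covergan_train.py \
#         --train_dir ./training \
#         --audio audio --covers clean_covers --emotions emotions.json \
#         --checkpoint_root checkpoint --plots plots \
#         --epochs 8000 --batch_size 64 --canvas_size 128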
112 | def main():
113 |     parser = argparse.ArgumentParser()
114 |     parser.add_argument("--train_dir", help="Directory with all folders for training", type=str, default=".")
115 |     parser.add_argument("--plots", help="Directory where plots are saved during training", type=str, default="plots")
116 |     parser.add_argument("--audio", help="Directory with the music files", type=str, default="audio")
117 |     parser.add_argument("--covers", help="Directory with the cover images", type=str, default="clean_covers")
118 |     parser.add_argument("--emotions", help="File with emotion markup for the training dataset", type=str, default=None)
119 |     parser.add_argument("--test_set", help="Directory with test music files", type=str, default=None)
120 |     parser.add_argument("--test_emotions", help="File with emotion markup for the test dataset", type=str, default=None)
121 |     parser.add_argument("--checkpoint_root", help="Checkpoint location", type=str, default="checkpoint")
122 |     parser.add_argument("--augment_dataset", help="Whether to augment the dataset", default=False, action="store_true")
123 |     parser.add_argument("--gen_lr", help="Generator learning rate", type=float, default=0.0005)
124 |     parser.add_argument("--disc_lr", help="Discriminator learning rate", type=float, default=0.0005)
125 |     parser.add_argument("--disc_repeats", help="Discriminator runs per iteration", type=int, default=5)
126 |     parser.add_argument("--epochs", help="Number of epochs to train for", type=int, default=8000)
127 |     parser.add_argument("--batch_size", help="Batch size", type=int, default=64)
128 |     parser.add_argument("--canvas_size", help="Image canvas size for learning", type=int, default=128)
129 |     parser.add_argument("--display_steps", help="How often (in steps) to plot the samples", type=int, default=500)
130 |     parser.add_argument("--backup_epochs", help="How often (in epochs) to back up checkpoints", type=int, default=600)
131 |     parser.add_argument("--plot_grad", help="Whether to plot the gradients", default=False, action="store_true")
132 |     args = parser.parse_args()
133 |     print(args)
134 |
135 |     # Network properties
136 |     num_gen_layers = 5
137 |     num_disc_conv_layers = 3
138 |     num_disc_linear_layers = 2
139 |     z_dim = 32  # Dimension of the noise vector
140 |     # z_dim = 512  # Dimension of the noise vector
141 |
142 |     # Painter properties
143 |     path_count = 3
144 |     path_segment_count = 4
145 |     disc_slices = 6
146 |     max_stroke_width = 0.01  # relative to the canvas size
147 |
148 |     # Plot properties
149 |     bin_steps = 20  # How many steps to aggregate with mean for each plot point
150 |
151 |     logger.info("--- Starting out_main ---")
152 |
153 |     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
154 |
155 |     os.makedirs(file_in_folder(args.train_dir, args.checkpoint_root), exist_ok=True)
156 |
157 |     dataloader, audio_embedding_dim, img_shape, has_emotions = get_train_data(
158 |         file_in_folder(args.train_dir, args.checkpoint_root),
159 |         file_in_folder(args.train_dir, args.audio),
160 |         file_in_folder(args.train_dir, args.covers),
161 |         file_in_folder(args.train_dir, args.emotions),
162 |         args.batch_size, args.canvas_size, args.augment_dataset
163 |     )
164 |     logger.plots_dir = file_in_folder(args.train_dir, args.plots)
165 |     os.makedirs(logger.plots_dir, exist_ok=True)
166 |     # display_dataset_objects(dataloader, disc_slices)
167 |     if args.test_set is None:
168 |         test_dataloader = None
169 |     else:
170 |         test_dataloader = get_test_data(
171 |             file_in_folder(args.train_dir, args.checkpoint_root),
172 |             file_in_folder(args.train_dir, args.test_set),
173 |             file_in_folder(args.train_dir, args.test_emotions),
174 |             args.batch_size, args.canvas_size
175 |         )
176 |
177 |     logger.info("--- CoverGAN training ---")
178 |     gen, disc = make_gan_models(
179 |         z_dim=z_dim,
180 |         audio_embedding_dim=audio_embedding_dim,
181 |         img_shape=img_shape,
182 |         has_emotions=has_emotions,
183 |         num_gen_layers=num_gen_layers,
184 |         num_disc_conv_layers=num_disc_conv_layers,
185 |         num_disc_linear_layers=num_disc_linear_layers,
186 |         path_count=path_count,
187 |         path_segment_count=path_segment_count,
188 |         max_stroke_width=max_stroke_width,
189 |         disc_slices=disc_slices,
190 |         device=device
191 |     )
192 |     palette_generator = get_palette_predictor(device)
193 |
194 |     demo_samples(gen, dataloader, z_dim, disc_slices, device, palette_generator=palette_generator)
195 |
196 |     train(dataloader, test_dataloader, gen, disc, device, {
197 |         # Common
198 |         "display_steps": args.display_steps,
199 |         "backup_epochs": args.backup_epochs,
200 |         "bin_steps": bin_steps,
201 |         "z_dim": z_dim,
202 |         "disc_slices": disc_slices,
203 |         "checkpoint_root": file_in_folder(args.train_dir, args.checkpoint_root),
204 |         # (W)GAN-specific
205 |         "n_epochs": args.epochs,
206 |         "gen_lr": args.gen_lr,
207 |         "disc_lr": args.disc_lr,
208 |         "disc_repeats": args.disc_repeats,
209 |         "plot_grad": args.plot_grad,
210 |     }, cgan_out_name="cgan_6figs_32noise_separated_palette_tanh_betas", palette_generator=palette_generator,
211 |        USE_SHUFFLING=True)
212 |
213 |     logger.info("--- CoverGAN sample demo ---")
214 |     demo_samples(gen, dataloader, z_dim, disc_slices, device, palette_generator=palette_generator)
215 |
216 |
217 | if __name__ == '__main__':
218 |     main()
219 |
--------------------------------------------------------------------------------
/outer/SVGContainer.py:
--------------------------------------------------------------------------------
1 | import html
2 | from typing import List
3 |
4 |
5 | def wand_rendering(svg_str):
6 |     # WARNING! Works only on Windows.
7 |     import wand.image
8 |     with wand.image.Image(blob=svg_str.encode(), format="svg") as image:
9 |         png_image = image.make_blob("png")
10 |     return png_image
11 |
12 |
13 | def cairo_rendering(svg_str):
14 |     # WARNING! Works only on Linux.
15 |     from cairosvg import svg2png
16 |     return svg2png(bytestring=svg_str)
17 |
18 |
19 | def svglib_rendering(svg_str):
20 |     # WARNING!