├── README.md ├── Stable_Diffusion.ipynb ├── Wav2Lip.ipynb ├── LDM_TXT2IM.ipynb ├── articulated_animation.ipynb ├── projected_g.ipynb ├── dlfs.ipynb ├── pi_GAN.ipynb ├── video_matting.ipynb ├── Stable_Diffusion2.ipynb ├── FacialCartoonization.ipynb ├── autoencoder.ipynb ├── Optimized_LDM_TXT2IM.ipynb ├── DemoSegmenter.ipynb ├── mttr_interactive_demo.ipynb ├── ArcaneGAN.ipynb ├── ArcaneGAN_latest.ipynb ├── RIS_demo.ipynb ├── ReStyle_animations.ipynb ├── AnimeGANV2_for_face.ipynb └── stylegan_nada.ipynb /README.md: -------------------------------------------------------------------------------- 1 | # others2 -------------------------------------------------------------------------------- /Stable_Diffusion.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "Stable_Diffusion", 7 | "provenance": [], 8 | "authorship_tag": "ABX9TyOe9iaMhUOS+00Llkwxyw8t", 9 | "include_colab_link": true 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "language_info": { 16 | "name": "python" 17 | }, 18 | "accelerator": "GPU", 19 | "gpuClass": "standard" 20 | }, 21 | "cells": [ 22 | { 23 | "cell_type": "markdown", 24 | "metadata": { 25 | "id": "view-in-github", 26 | "colab_type": "text" 27 | }, 28 | "source": [ 29 | "\"Open" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "source": [ 35 | "#@title **セットアップ**\n", 36 | "\n", 37 | "# ライブラリのインストール\n", 38 | "!pip install diffusers==0.8.0 transformers scipy ftfy\n", 39 | "\n", 40 | "# アクセス・トークン設定\n", 41 | "Access_Token=\"\"#@param {type:\"string\"}\n", 42 | "\n", 43 | "# パイプライン構築\n", 44 | "from diffusers import StableDiffusionPipeline\n", 45 | "pipe = StableDiffusionPipeline.from_pretrained(\"CompVis/stable-diffusion-v1-4\", use_auth_token=Access_Token)\n", 46 | "pipe.to(\"cuda\")" 47 | ], 48 | "metadata": { 49 | "id": "RrqY6TqCNWNo" 50 | }, 51 | "execution_count": null, 52 | "outputs": [] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "source": [ 57 | "#@title **画像生成**\n", 58 | "\n", 59 | "# 生成\n", 60 | "prompt = \"An astronaut riding a horse in a photorealistic style\" #@param {type:\"string\"}\n", 61 | "image = pipe(prompt)[\"images\"][0]\n", 62 | "\n", 63 | "# 保存\n", 64 | "sentence = prompt.replace(' ','_')\n", 65 | "out_path = sentence+'.png'\n", 66 | "image.save(out_path)\n", 67 | "\n", 68 | "# 表示\n", 69 | "from IPython.display import Image,display\n", 70 | "display(Image(out_path))" 71 | ], 72 | "metadata": { 73 | "id": "ppaduE-0O9gT" 74 | }, 75 | "execution_count": null, 76 | "outputs": [] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "source": [ 81 | "#@title **画像のダウンロード**\n", 82 | "from google.colab import files\n", 83 | "files.download(out_path)" 84 | ], 85 | "metadata": { 86 | "id": "QrXp6dv8Twyx" 87 | }, 88 | "execution_count": null, 89 | "outputs": [] 90 | } 91 | ] 92 | } 93 | -------------------------------------------------------------------------------- /Wav2Lip.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "view-in-github", 7 | "colab_type": "text" 8 | }, 9 | "source": [ 10 | "\"Open" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": { 17 | "id": "L1SFU1qm8KKQ" 18 | }, 19 | "outputs": [], 20 | "source": [ 21 | "#@title セットアップ\n", 22 | "# ml4aインストール\n", 23 | "!pip install tensorflow==1.15.0\n", 24 | "!pip 
install imageio==2.4.1\n", 25 | "!pip3 install --quiet ml4a\n", 26 | "\n", 27 | "# ライブラリ・インポート\n", 28 | "from ml4a import audio\n", 29 | "from ml4a import image\n", 30 | "from ml4a.models import wav2lip\n", 31 | "\n", 32 | "# サンプルデータ・ダウンロード\n", 33 | "! pip install --upgrade gdown\n", 34 | "import gdown\n", 35 | "gdown.download('https://drive.google.com/uc?id=1JxpWvO7ssUbO3O_4Wvqf-tTa23GiP-1Y', './data.zip', quiet=False)\n", 36 | "! unzip data.zip\n", 37 | "\n", 38 | "# サンプリング・レート取得関数\n", 39 | "def get_rate(file_path):\n", 40 | " import wave\n", 41 | " wf = wave.open(file_path, \"r\")\n", 42 | " fs = wf.getframerate()\n", 43 | " return fs" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "source": [ 49 | "#@title wav2lip変換\n", 50 | "image = './image/03.jpg' #@param {type:\"string\"}\n", 51 | "audio ='./audio/coffe.wav' #@param {type:\"string\"}\n", 52 | "rate = get_rate(audio)\n", 53 | "\n", 54 | "wav2lip.run(image, \n", 55 | " audio,\n", 56 | " sampling_rate = rate,\n", 57 | " output_video = 'output.mp4', \n", 58 | " pads = [0, 10, 0, 0],\n", 59 | " resize_factor = 2, \n", 60 | " crop = None, \n", 61 | " box = None)" 62 | ], 63 | "metadata": { 64 | "id": "El9tQ1V_Ux96" 65 | }, 66 | "execution_count": null, 67 | "outputs": [] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "source": [ 72 | "#@title 動画の再生\n", 73 | "from IPython.display import HTML\n", 74 | "from base64 import b64encode\n", 75 | "\n", 76 | "mp4 = open('output.mp4', 'rb').read()\n", 77 | "data_url = 'data:video/mp4;base64,' + b64encode(mp4).decode()\n", 78 | "HTML(f\"\"\"\n", 79 | "\"\"\")" 82 | ], 83 | "metadata": { 84 | "id": "iTt7jMPI-HYg" 85 | }, 86 | "execution_count": null, 87 | "outputs": [] 88 | } 89 | ], 90 | "metadata": { 91 | "kernelspec": { 92 | "display_name": "Python 3", 93 | "language": "python", 94 | "name": "python3" 95 | }, 96 | "language_info": { 97 | "codemirror_mode": { 98 | "name": "ipython", 99 | "version": 3 100 | }, 101 | "file_extension": ".py", 102 | "mimetype": "text/x-python", 103 | "name": "python", 104 | "nbconvert_exporter": "python", 105 | "pygments_lexer": "ipython3", 106 | "version": "3.6.9" 107 | }, 108 | "colab": { 109 | "name": "Wav2Lip", 110 | "provenance": [], 111 | "include_colab_link": true 112 | }, 113 | "accelerator": "GPU" 114 | }, 115 | "nbformat": 4, 116 | "nbformat_minor": 0 117 | } 118 | -------------------------------------------------------------------------------- /LDM_TXT2IM.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "LDM-TXT2IM.ipynb", 7 | "private_outputs": true, 8 | "provenance": [], 9 | "machine_shape": "hm", 10 | "include_colab_link": true 11 | }, 12 | "kernelspec": { 13 | "name": "python3", 14 | "display_name": "Python 3" 15 | }, 16 | "language_info": { 17 | "name": "python" 18 | }, 19 | "accelerator": "GPU" 20 | }, 21 | "cells": [ 22 | { 23 | "cell_type": "markdown", 24 | "metadata": { 25 | "id": "view-in-github", 26 | "colab_type": "text" 27 | }, 28 | "source": [ 29 | "\"Open" 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "source": [ 35 | "# Latent Diffusion Models Text2Image\n", 36 | "\n", 37 | "### https://arxiv.org/abs/2112.10752\n", 38 | "\n", 39 | "### Original repo: https://github.com/CompVis/latent-diffusion\n", 40 | "\n", 41 | "### Enhanced repo by [@RiversHaveWings](https://twitter.com/RiversHaveWings): https://github.com/crowsonkb/latent-diffusion\n", 42 | "\n", 43 | "### Colab optimizations taken from 
[@multimodalart](https://twitter.com/multimodalart): https://github.com/multimodalart/latent-diffusion-notebook\n", 44 | "\n", 45 | "Shortcut to this notebook: [bit.ly/txt2im](https://bit.ly/txt2im)\n", 46 | "\n", 47 | "Notebook by: [Eyal Gruss](https://eyalgruss.com) \\([@eyaler](https://twitter.com/eyaler)\\)\n", 48 | "\n", 49 | "A curated list of online generative tools: [j.mp/generativetools](https://j.mp/generativetools)" 50 | ], 51 | "metadata": { 52 | "id": "Bmvx0uTbF6Iw" 53 | } 54 | }, 55 | { 56 | "cell_type": "code", 57 | "source": [ 58 | "#@title Setup\n", 59 | "%cd /content\n", 60 | "!nvidia-smi\n", 61 | "!git clone https://github.com/eyaler/latent-diffusion\n", 62 | "!git clone https://github.com/CompVis/taming-transformers\n", 63 | "!pip -q install -e ./taming-transformers\n", 64 | "!pip -q install omegaconf pytorch-lightning torch-fidelity einops transformers\n", 65 | "%cd latent-diffusion\n", 66 | "!mkdir -p models/ldm/text2img-large/\n", 67 | "!wget -nc -O models/ldm/text2img-large/model.ckpt https://ommer-lab.com/files/latent-diffusion/nitro/txt2img-f8-large/model.ckpt\n" 68 | ], 69 | "metadata": { 70 | "cellView": "form", 71 | "id": "2iLdwkKD5l8a" 72 | }, 73 | "execution_count": null, 74 | "outputs": [] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "metadata": { 80 | "cellView": "form", 81 | "id": "g0_Gb52UwMHQ" 82 | }, 83 | "outputs": [], 84 | "source": [ 85 | "#@title Generate\n", 86 | "#@markdown Note: An error probably indicates that you either:\n", 87 | "#@markdown 1. Skipped running the above setup stage or waited too long and the Colab disconnected - run it again.\n", 88 | "#@markdown 2. Ran out of RAM - which may be solved by Runtime->Mangage sessions->TERMINATE and starting over in hope of getting a stronger (non K80) machine, or upgrading to Colab Pro...\n", 89 | "prompt = 'Putin riding a zebra shirtless and waving the Ukrainian flag' #@param {type: 'string'}\n", 90 | "plms = False #@param {type: 'boolean'}\n", 91 | "ddim_eta = 0 #@param {type: 'number'}\n", 92 | "n_samples = 4 #@param {type: 'integer'}\n", 93 | "n_iter = 4 #@param {type: 'integer'}\n", 94 | "scale = 5#@param {type: 'number'}\n", 95 | "ddim_steps = 50#@param {type: 'integer'}\n", 96 | "W = 256 #@param {type: 'integer'}\n", 97 | "H = 256 #@param {type: 'integer'}\n", 98 | "outdir = 'outputs' #@param {type: 'string'}\n", 99 | "from google.colab.patches import cv2_imshow\n", 100 | "import cv2\n", 101 | "from time import time\n", 102 | "start = time()\n", 103 | "plms_arg = ''\n", 104 | "if plms:\n", 105 | " plms_arg = '--plms'\n", 106 | "!python scripts/txt2img.py --prompt \"$prompt\" --ddim_eta $ddim_eta --n_samples $n_samples --n_iter $n_iter --scale $scale --ddim_steps $ddim_steps --H $H --W $W --outdir $outdir $plms_arg\n", 107 | "print(f'Took {time()-start:.0f} secs.')\n", 108 | "filename = f'{outdir}/{prompt.replace(\" \", \"-\")}.png'\n", 109 | "print(filename)\n", 110 | "im = cv2.imread(filename)\n", 111 | "cv2_imshow(im)" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "source": [ 117 | "#@title Download images\n", 118 | "!zip -jrqFS ldm.zip \"$outdir\"\n", 119 | "from google.colab import files\n", 120 | "files.download('ldm.zip')" 121 | ], 122 | "metadata": { 123 | "cellView": "form", 124 | "id": "S3PKmI74DENO" 125 | }, 126 | "execution_count": null, 127 | "outputs": [] 128 | } 129 | ] 130 | } -------------------------------------------------------------------------------- /articulated_animation.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "articulated_animation", 7 | "provenance": [], 8 | "machine_shape": "hm", 9 | "include_colab_link": true 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "language_info": { 16 | "name": "python" 17 | }, 18 | "accelerator": "GPU" 19 | }, 20 | "cells": [ 21 | { 22 | "cell_type": "markdown", 23 | "metadata": { 24 | "id": "view-in-github", 25 | "colab_type": "text" 26 | }, 27 | "source": [ 28 | "\"Open" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": { 34 | "id": "-FPGt7_2z7fn" 35 | }, 36 | "source": [ 37 | "\n", 38 | "# セットアップ" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "metadata": { 44 | "id": "2LoHDHBHyzzZ" 45 | }, 46 | "source": [ 47 | "# 1.git-lfsインストール\n", 48 | "!curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash\n", 49 | "!sudo apt-get install git-lfs\n", 50 | "!git lfs install\n", 51 | "\n", 52 | "# 2.githubからコードをコピー\n", 53 | "!git clone https://github.com/snap-research/articulated-animation.git\n", 54 | "%cd articulated-animation\n", 55 | "\n", 56 | "# --- パッチ(2021.9.1追加) ---\n", 57 | "import gdown\n", 58 | "gdown.download('https://drive.google.com/uc?id=1YHhLDf7QGhVUyAIMRsJBmCm0VDDcQdVV', './pt.zip', quiet=False)\n", 59 | "! unzip pt.zip -d checkpoints\n", 60 | "\n", 61 | "# 3.学習済みパラメータをダウンロード\n", 62 | "from demo import load_checkpoints\n", 63 | "generator, region_predictor, avd_network = load_checkpoints(config_path='config/ted384.yaml',\n", 64 | " checkpoint_path='checkpoints/ted384.pth')\n", 65 | "\n", 66 | "# 4.サンプル画像をダウンロード\n", 67 | "! pip install --upgrade gdown\n", 68 | "import gdown\n", 69 | "gdown.download('https://drive.google.com/uc?id=1ZF6fuBTjfVYKeSpX-0R_nMgYTCjSkxHB', './sample.zip', quiet=False)\n", 70 | "! 
unzip sample.zip -d sup-mat" 71 | ], 72 | "execution_count": null, 73 | "outputs": [] 74 | }, 75 | { 76 | "cell_type": "markdown", 77 | "metadata": { 78 | "id": "1MBVZKC9Lcsc" 79 | }, 80 | "source": [ 81 | "# 静止画と動画の読み込み" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "metadata": { 87 | "id": "CRVOnlASzrhI" 88 | }, 89 | "source": [ 90 | "# --- 静止画と動画の読み込み ---\n", 91 | "import imageio\n", 92 | "import numpy as np\n", 93 | "import matplotlib.pyplot as plt\n", 94 | "import matplotlib.animation as animation\n", 95 | "from skimage.transform import resize\n", 96 | "from IPython.display import HTML\n", 97 | "import warnings\n", 98 | "warnings.filterwarnings(\"ignore\")\n", 99 | "\n", 100 | "source_image = imageio.imread('sup-mat/001.png')\n", 101 | "driving_video = imageio.mimread('sup-mat/driving.mp4')\n", 102 | "\n", 103 | "\n", 104 | "source_image = resize(source_image, (384, 384))[..., :3]\n", 105 | "driving_video = [resize(frame, (384, 384))[..., :3] for frame in driving_video]\n", 106 | "\n", 107 | "def display(source, driving, generated=None):\n", 108 | " fig = plt.figure(figsize=(8 + 4 * (generated is not None), 6))\n", 109 | "\n", 110 | " ims = []\n", 111 | " for i in range(len(driving)):\n", 112 | " cols = [source]\n", 113 | " cols.append(driving[i])\n", 114 | " if generated is not None:\n", 115 | " cols.append(generated[i])\n", 116 | " im = plt.imshow(np.concatenate(cols, axis=1), animated=True)\n", 117 | " plt.axis('off')\n", 118 | " ims.append([im])\n", 119 | "\n", 120 | " ani = animation.ArtistAnimation(fig, ims, interval=50, repeat_delay=1000)\n", 121 | " plt.close()\n", 122 | " return ani\n", 123 | " \n", 124 | "HTML(display(source_image, driving_video).to_html5_video())" 125 | ], 126 | "execution_count": null, 127 | "outputs": [] 128 | }, 129 | { 130 | "cell_type": "markdown", 131 | "metadata": { 132 | "id": "bfyQMZ57ICJD" 133 | }, 134 | "source": [ 135 | "# 推論・アニメーションの作成" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "metadata": { 141 | "id": "Lcv6u7Bcz27V" 142 | }, 143 | "source": [ 144 | "# --- 推論・アニメーションの作成 ---\n", 145 | "from demo import make_animation\n", 146 | "from skimage import img_as_ubyte\n", 147 | "\n", 148 | "predictions = make_animation(source_image, driving_video, generator, \n", 149 | " region_predictor, avd_network, animation_mode='avd')\n", 150 | "\n", 151 | "#save resulting video\n", 152 | "imageio.mimsave('generated.mp4', [img_as_ubyte(frame) for frame in predictions])\n", 153 | "\n", 154 | "HTML(display(source_image, driving_video, predictions).to_html5_video())" 155 | ], 156 | "execution_count": null, 157 | "outputs": [] 158 | } 159 | ] 160 | } 161 | -------------------------------------------------------------------------------- /projected_g.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "view-in-github", 7 | "colab_type": "text" 8 | }, 9 | "source": [ 10 | "\"Open" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": { 17 | "id": "VZ3pwUJSoOdO" 18 | }, 19 | "outputs": [], 20 | "source": [ 21 | "# --- セットアップ ---\n", 22 | "\n", 23 | "# Pytorch バージョン変更\n", 24 | "! pip install torch==1.10.1+cu111 torchvision==0.11.2+cu111 torchaudio==0.10.1 -f https://download.pytorch.org/whl/torch_stable.html\n", 25 | "\n", 26 | "# githubからコードを取得\n", 27 | "! git clone https://github.com/autonomousvision/projected_gan\n", 28 | "! 
pip install timm dill\n", 29 | "%cd projected_gan\n", 30 | "\n", 31 | "# 学習済みパラメータのダウンロード\n", 32 | "! wget https://s3.eu-central-1.amazonaws.com/avg-projects/projected_gan/models/art_painting.pkl\n", 33 | "! wget https://s3.eu-central-1.amazonaws.com/avg-projects/projected_gan/models/church.pkl\n", 34 | "#! wget https://s3.eu-central-1.amazonaws.com/avg-projects/projected_gan/models/cityscapes.pkl\n", 35 | "#! wget https://s3.eu-central-1.amazonaws.com/avg-projects/projected_gan/models/clevr.pkl\n", 36 | "#! wget https://s3.eu-central-1.amazonaws.com/avg-projects/projected_gan/models/ffhq.pkl\n", 37 | "#! wget https://s3.eu-central-1.amazonaws.com/avg-projects/projected_gan/models/flowers.pkl\n", 38 | "#! wget https://s3.eu-central-1.amazonaws.com/avg-projects/projected_gan/models/landscape.pkl\n", 39 | "! wget https://s3.eu-central-1.amazonaws.com/avg-projects/projected_gan/models/pokemon.pkl\n", 40 | "\n", 41 | "\n", 42 | "# 画像表示\n", 43 | "import matplotlib.pyplot as plt\n", 44 | "from PIL import Image\n", 45 | "import os\n", 46 | "import numpy as np\n", 47 | "\n", 48 | "def display_pic(folder):\n", 49 | " fig = plt.figure(figsize=(30, 60))\n", 50 | " files = os.listdir(folder)\n", 51 | " files.sort()\n", 52 | " for i, file in enumerate(files):\n", 53 | " img = Image.open(folder+'/'+file) \n", 54 | " images = np.asarray(img)\n", 55 | " ax = fig.add_subplot(10, 5, i+1, xticks=[], yticks=[])\n", 56 | " image_plt = np.array(images)\n", 57 | " ax.imshow(image_plt)\n", 58 | " ax.set_xlabel(file, fontsize=25) \n", 59 | " plt.show()\n", 60 | " plt.close()\n", 61 | "\n", 62 | "# リセットフォルダ\n", 63 | "import shutil\n", 64 | "\n", 65 | "def reset_folder(path):\n", 66 | " if os.path.isdir(path):\n", 67 | " shutil.rmtree(path)\n", 68 | " os.makedirs(path,exist_ok=True)\n", 69 | "\n", 70 | "# 動画再生\n", 71 | "from IPython.display import display, HTML\n", 72 | "from IPython.display import HTML\n", 73 | "\n", 74 | "def display_mp4(path):\n", 75 | " print('prepere to play movie...')\n", 76 | " from base64 import b64encode\n", 77 | " mp4 = open(path,'rb').read()\n", 78 | " data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n", 79 | " display(HTML(\"\"\"\n", 80 | " \n", 83 | " \"\"\" % data_url))" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "metadata": { 90 | "id": "chHmHK9wN-Im" 91 | }, 92 | "outputs": [], 93 | "source": [ 94 | "# サンプル画像生成\n", 95 | "reset_folder('out')\n", 96 | "! python gen_images.py --outdir=out\\\n", 97 | " --trunc=1.0\\\n", 98 | " --seeds=20-29 \\\n", 99 | " --network=pokemon.pkl" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": null, 105 | "metadata": { 106 | "id": "8ydp2CBJPOW6" 107 | }, 108 | "outputs": [], 109 | "source": [ 110 | "# 画像の表示\n", 111 | "display_pic('out')" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "metadata": { 118 | "id": "WSLg-NKqOBeh" 119 | }, 120 | "outputs": [], 121 | "source": [ 122 | "# 補完動画の作成\n", 123 | "! 
python gen_video.py --output=lerp.mp4\\\n", 124 | " --trunc=1.0\\\n", 125 | " --seeds=20-49\\\n", 126 | " --grid=3x2 \\\n", 127 | " --network=pokemon.pkl" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "source": [ 133 | "# 動画の再生\n", 134 | "display_mp4('lerp.mp4')" 135 | ], 136 | "metadata": { 137 | "id": "OtOuWkoPjdrZ" 138 | }, 139 | "execution_count": null, 140 | "outputs": [] 141 | } 142 | ], 143 | "metadata": { 144 | "accelerator": "GPU", 145 | "colab": { 146 | "collapsed_sections": [], 147 | "name": "projected_g", 148 | "provenance": [], 149 | "include_colab_link": true 150 | }, 151 | "kernelspec": { 152 | "display_name": "Python 3", 153 | "name": "python3" 154 | }, 155 | "language_info": { 156 | "name": "python" 157 | } 158 | }, 159 | "nbformat": 4, 160 | "nbformat_minor": 0 161 | } -------------------------------------------------------------------------------- /dlfs.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "dlfs", 7 | "private_outputs": true, 8 | "provenance": [], 9 | "collapsed_sections": [], 10 | "include_colab_link": true 11 | }, 12 | "kernelspec": { 13 | "name": "python3", 14 | "display_name": "Python 3" 15 | }, 16 | "language_info": { 17 | "name": "python" 18 | }, 19 | "accelerator": "GPU" 20 | }, 21 | "cells": [ 22 | { 23 | "cell_type": "markdown", 24 | "metadata": { 25 | "id": "view-in-github", 26 | "colab_type": "text" 27 | }, 28 | "source": [ 29 | "\"Open" 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "metadata": { 35 | "id": "aqo9-hddTZrh" 36 | }, 37 | "source": [ 38 | "# セットアップ\n" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "metadata": { 44 | "id": "IOSPzfVRTspe" 45 | }, 46 | "source": [ 47 | "# githubのコードをコピー\n", 48 | "!git clone https://github.com/SenHe/DLFS.git\n", 49 | "%cd DLFS/\n", 50 | "\n", 51 | "# ライブラリーのインストール\n", 52 | "!pip3 install -r requirements.txt\n", 53 | "\n", 54 | "# 補助モデルのダウンロード\n", 55 | "!python download_models.py\n", 56 | "\n", 57 | "# 学習済みモデルのダウンロード\n", 58 | "! 
pip install --upgrade gdown\n", 59 | "import gdown\n", 60 | "!mkdir checkpoints\n", 61 | "%cd checkpoints\n", 62 | "\n", 63 | "gdown.download('https://drive.google.com/u/0/uc?id=1pB4mufFtzbJSxxv_2iFrBPD3vp_Ef-n3&export=download', 'males_model.zip', quiet=False)\n", 64 | "!unzip males_model.zip\n", 65 | "gdown.download('https://drive.google.com/u/0/uc?id=1z0s_j3Khs7-352bMvz8RSnrR53vvdbiI&export=download', 'females_model.zip', quiet=False)\n", 66 | "!unzip females_model.zip\n", 67 | "%cd ..\n", 68 | "\n", 69 | "# サンプル画像ダウンロード\n", 70 | "gdown.download('https://drive.google.com/uc?id=1ruwDizjnzd3scR1QvpXGWLXywY8_W0yj', './images.zip', quiet=False)\n", 71 | "!unzip images.zip" 72 | ], 73 | "execution_count": null, 74 | "outputs": [] 75 | }, 76 | { 77 | "cell_type": "markdown", 78 | "metadata": { 79 | "id": "dHIRttMXIjWm" 80 | }, 81 | "source": [ 82 | "# インポート&初期設定" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "metadata": { 88 | "id": "SqxbhoRrUU-m" 89 | }, 90 | "source": [ 91 | "# インポート&初期設定\n", 92 | "import os\n", 93 | "from collections import OrderedDict\n", 94 | "from options.test_options import TestOptions\n", 95 | "from data.data_loader import CreateDataLoader\n", 96 | "from models_distan.models import create_model\n", 97 | "import util.util as util\n", 98 | "from util.visualizer import Visualizer\n", 99 | "\n", 100 | "opt = TestOptions().parse(save=False)\n", 101 | "opt.display_id = 0 # do not launch visdom\n", 102 | "opt.nThreads = 1 # test code only supports nThreads = 1\n", 103 | "opt.batchSize = 1 # test code only supports batchSize = 1\n", 104 | "opt.serial_batches = True # no shuffle\n", 105 | "opt.no_flip = True # no flip\n", 106 | "opt.in_the_wild = True # This triggers preprocessing of in the wild images in the dataloader\n", 107 | "opt.traverse = True # This tells the model to traverse the latent space between anchor classes\n", 108 | "opt.interp_step = 0.05 # this controls the number of images to interpolate between anchor classes\n", 109 | "\n", 110 | "data_loader = CreateDataLoader(opt)\n", 111 | "dataset = data_loader.load_data()\n", 112 | "visualizer = Visualizer(opt)\n" 113 | ], 114 | "execution_count": null, 115 | "outputs": [] 116 | }, 117 | { 118 | "cell_type": "markdown", 119 | "metadata": { 120 | "id": "TJmA52UUX5oM" 121 | }, 122 | "source": [ 123 | "# 年齢による顔アニメーション" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "metadata": { 129 | "id": "-Eapax9tUliy" 130 | }, 131 | "source": [ 132 | "# 年齢による顔アニメーション作成\n", 133 | "opt.name = 'females_model' # females_model'あるいは'males_model'を選択する\n", 134 | "model = create_model(opt)\n", 135 | "model.eval()\n", 136 | "\n", 137 | "img_dir ='images' # フォルダー指定\n", 138 | "img_file ='04.jpg' # ファイル名\n", 139 | "img_path = img_dir+'/'+img_file\n", 140 | "data = dataset.dataset.get_item_from_path(img_path)\n", 141 | "visuals = model.inference(data)\n", 142 | "\n", 143 | "os.makedirs('results', exist_ok=True)\n", 144 | "out_path ='results/'+img_file[:-4]+'.mp4'\n", 145 | "visualizer.make_video(visuals, out_path)\n", 146 | "\n", 147 | "# コーデック変換\n", 148 | "import os\n", 149 | "import shutil\n", 150 | "#shutil.copy('./results/'+img_path[:-4]+'.mp4', './results/out.mp4')\n", 151 | "shutil.copy(out_path, './results/out.mp4')\n", 152 | "if os.path.exists('./output.mp4'):\n", 153 | " os.remove('./output.mp4')\n", 154 | "! 
ffmpeg -i ./results/out.mp4 -vcodec h264 -pix_fmt yuv420p output.mp4" 155 | ], 156 | "execution_count": null, 157 | "outputs": [] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "metadata": { 162 | "id": "GPiFfIJbXgJa" 163 | }, 164 | "source": [ 165 | "# mp4動画の再生\n", 166 | "from IPython.display import HTML\n", 167 | "from base64 import b64encode\n", 168 | "\n", 169 | "mp4 = open('./output.mp4', 'rb').read()\n", 170 | "data_url = 'data:video/mp4;base64,' + b64encode(mp4).decode()\n", 171 | "HTML(f\"\"\"\n", 172 | "\"\"\")" 175 | ], 176 | "execution_count": null, 177 | "outputs": [] 178 | } 179 | ] 180 | } 181 | -------------------------------------------------------------------------------- /pi_GAN.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "pi-GAN", 7 | "provenance": [], 8 | "authorship_tag": "ABX9TyPpACWPJCJDnhEB9y9/MtpJ", 9 | "include_colab_link": true 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "language_info": { 16 | "name": "python" 17 | }, 18 | "accelerator": "GPU" 19 | }, 20 | "cells": [ 21 | { 22 | "cell_type": "markdown", 23 | "metadata": { 24 | "id": "view-in-github", 25 | "colab_type": "text" 26 | }, 27 | "source": [ 28 | "\"Open" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": { 34 | "id": "NcHH1BydC1pM" 35 | }, 36 | "source": [ 37 | "### セットアップ" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "metadata": { 43 | "id": "74w2kUu7vnfS" 44 | }, 45 | "source": [ 46 | "# gpu チェックチェック\n", 47 | "! nvidia-smi -L\n", 48 | "\n", 49 | "# githubからコードを取得\n", 50 | "! git clone https://github.com/marcoamonteiro/pi-GAN.git\n", 51 | "%cd pi-GAN\n", 52 | "\n", 53 | "# ライブラリーのインストール\n", 54 | "! pip install -r requirements.txt\n", 55 | "\n", 56 | "# 学習済みモデルのダウンロード\n", 57 | "! pip install --upgrade gdown\n", 58 | "import gdown\n", 59 | "gdown.download('https://drive.google.com/uc?id=1bRB4-KxQplJryJvqyEa8Ixkf_BVm4Nn6', './CelebA.zip', quiet=False)\n", 60 | "! unzip CelebA.zip\n", 61 | "gdown.download('https://drive.google.com/uc?id=1WBA-WI8DA7FqXn7__0TdBO0eO08C_EhG', './Cats.zip', quiet=False)\n", 62 | "! unzip Cats.zip\n", 63 | "gdown.download('https://drive.google.com/uc?id=1n4eXijbSD48oJVAbAV4hgdcTbT3Yv4xO', './CARLA.zip', quiet=False)\n", 64 | "! unzip CARLA" 65 | ], 66 | "execution_count": null, 67 | "outputs": [] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "metadata": { 72 | "id": "OmRMbVFkDlsh" 73 | }, 74 | "source": [ 75 | "### コード本体" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "metadata": { 81 | "id": "UHWSZbgh2vbG" 82 | }, 83 | "source": [ 84 | "# マルチビュー画像を作成(CelebA)\n", 85 | "! python render_multiview_images.py CelebA/generator.pth --curriculum CelebA --seeds 0 1 2 3\n", 86 | "\n", 87 | "# 画像を表示\n", 88 | "from IPython.display import Image,display_png\n", 89 | "import glob\n", 90 | "files = glob.glob('imgs/*.png')\n", 91 | "files.sort()\n", 92 | "for file in files:\n", 93 | " display_png(Image(file))" 94 | ], 95 | "execution_count": null, 96 | "outputs": [] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "metadata": { 101 | "id": "RDw9yiWozBxe" 102 | }, 103 | "source": [ 104 | "# ビデオを作成(CelebA)\n", 105 | "! 
python render_video.py CelebA/generator.pth --curriculum CelebA --seeds 0 1 2 3" 106 | ], 107 | "execution_count": null, 108 | "outputs": [] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "metadata": { 113 | "id": "fSxngVvtMxME" 114 | }, 115 | "source": [ 116 | "# mp4動画の再生\n", 117 | "video = '2.mp4' #@param {type:\"string\"}\n", 118 | "video_file = 'vids/'+video\n", 119 | "\n", 120 | "from IPython.display import HTML\n", 121 | "from base64 import b64encode\n", 122 | "mp4 = open(video_file, 'rb').read()\n", 123 | "data_url = 'data:video/mp4;base64,' + b64encode(mp4).decode()\n", 124 | "HTML(f\"\"\"\n", 125 | "\"\"\")" 128 | ], 129 | "execution_count": null, 130 | "outputs": [] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "metadata": { 135 | "id": "SUbhxsdkOaHE" 136 | }, 137 | "source": [ 138 | "# 補間ビデオを作成(CelebA)\n", 139 | "! python render_video_interpolation.py CelebA/generator.pth --curriculum CelebA --seeds 5 0 4 16 2 19 5\n", 140 | "\n", 141 | "# mp4動画の再生\n", 142 | "from IPython.display import HTML\n", 143 | "from base64 import b64encode\n", 144 | "mp4 = open('vids/interp.mp4', 'rb').read()\n", 145 | "data_url = 'data:video/mp4;base64,' + b64encode(mp4).decode()\n", 146 | "HTML(f\"\"\"\n", 147 | "\"\"\")" 150 | ], 151 | "execution_count": null, 152 | "outputs": [] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "metadata": { 157 | "id": "uztOZnBNLdCb" 158 | }, 159 | "source": [ 160 | "# 補間ビデオを作成(Cats)\n", 161 | "! python render_video_interpolation.py Cats/generator.pth --curriculum CelebA --seeds 0 4 8 5 9 1 0\n", 162 | "\n", 163 | "# mp4動画の再生\n", 164 | "from IPython.display import HTML\n", 165 | "from base64 import b64encode\n", 166 | "mp4 = open('vids/interp.mp4', 'rb').read()\n", 167 | "data_url = 'data:video/mp4;base64,' + b64encode(mp4).decode()\n", 168 | "HTML(f\"\"\"\n", 169 | "\"\"\")" 172 | ], 173 | "execution_count": null, 174 | "outputs": [] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "metadata": { 179 | "id": "pSF1l25fZEqN" 180 | }, 181 | "source": [ 182 | "# 補間ビデオを作成(CARLA)\n", 183 | "! 
python render_video_interpolation.py CARLA/generator.pth --curriculum CARLA --seeds 1 2 3 5 6 8 1\n", 184 | "\n", 185 | "# mp4動画の再生\n", 186 | "from IPython.display import HTML\n", 187 | "from base64 import b64encode\n", 188 | "mp4 = open('vids/interp.mp4', 'rb').read()\n", 189 | "data_url = 'data:video/mp4;base64,' + b64encode(mp4).decode()\n", 190 | "HTML(f\"\"\"\n", 191 | "\"\"\")" 194 | ], 195 | "execution_count": null, 196 | "outputs": [] 197 | } 198 | ] 199 | } 200 | -------------------------------------------------------------------------------- /video_matting.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "video_matting", 7 | "private_outputs": true, 8 | "provenance": [], 9 | "collapsed_sections": [], 10 | "include_colab_link": true 11 | }, 12 | "kernelspec": { 13 | "name": "python3", 14 | "display_name": "Python 3" 15 | }, 16 | "language_info": { 17 | "name": "python" 18 | }, 19 | "accelerator": "GPU" 20 | }, 21 | "cells": [ 22 | { 23 | "cell_type": "markdown", 24 | "metadata": { 25 | "id": "view-in-github", 26 | "colab_type": "text" 27 | }, 28 | "source": [ 29 | "\"Open" 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "metadata": { 35 | "id": "hJnxd6hdDymV" 36 | }, 37 | "source": [ 38 | "# Robust High-Resolution Video Matting with Temporal Guidance.\n", 39 | "\n", 40 | "![Teaser](https://raw.githubusercontent.com/PeterL1n/RobustVideoMatting/master/documentation/image/teaser.gif)\n", 41 | "\n", 42 | "[Project Site](https://peterl1n.github.io/RobustVideoMatting) | [GitHub](https://github.com/PeterL1n/RobustVideoMatting) | [Paper](https://arxiv.org/abs/2108.11515)\n", 43 | "\n", 44 | "\n", 45 | "## セットアップ(3分程度かかります)" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "metadata": { 51 | "id": "Q9Lpmpm4IuII" 52 | }, 53 | "source": [ 54 | "# pytorchバージョン変更\n", 55 | "! pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html\n", 56 | "\n", 57 | "# ライブラリー・インストール\n", 58 | "! pip install --quiet av pims\n", 59 | "! pip install --upgrade gdown\n", 60 | "\n", 61 | "# モデル構築\n", 62 | "import torch\n", 63 | "model = torch.hub.load(\"PeterL1n/RobustVideoMatting\", \"mobilenetv3\").cuda() # or \"resnet50\"\n", 64 | "convert_video = torch.hub.load(\"PeterL1n/RobustVideoMatting\", \"converter\")" 65 | ], 66 | "execution_count": null, 67 | "outputs": [] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "metadata": { 72 | "id": "9qD01TEQEg1p" 73 | }, 74 | "source": [ 75 | "### オプション 1: 自分の動画のアップロード\n", 76 | "\n", 77 | "セルを実行して自分のPCにあるファイルを選択 (mp4ファイルのみ)\\\n", 78 | "*アップロードしたファイルは自動的にinput.mp4に書き換えられます" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "metadata": { 84 | "id": "4cGycwzuEgF_" 85 | }, 86 | "source": [ 87 | "import os\n", 88 | "from google.colab import files\n", 89 | "\n", 90 | "uploaded = files.upload() # Use colab upload dialog.\n", 91 | "uploaded = list(uploaded.keys()) # Get uploaded filenames.\n", 92 | "assert len(uploaded) == 1 # Make sure only uploaded one file.\n", 93 | "os.rename(uploaded[0], 'input.mp4') # Rename file to \"input.mp4\"." 
94 | ], 95 | "execution_count": null, 96 | "outputs": [] 97 | }, 98 | { 99 | "cell_type": "markdown", 100 | "metadata": { 101 | "id": "0SZzWWvVIpts" 102 | }, 103 | "source": [ 104 | "### オプション 2: サンプルビデオを使用(自分のビデオを使う場合はパス)\n", 105 | "\n", 106 | "セルを実行してサンプルビデオをダウンロードする\\\n", 107 | "*ダウンロードしたファイルは自動的にinput.mp4に書き換えられます" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "metadata": { 113 | "id": "VLonjeynFONz" 114 | }, 115 | "source": [ 116 | "import gdown\n", 117 | "gdown.download('https://drive.google.com/uc?id=1LASEXzU015e30y4YQqv6mq4KId07mVof', './input.mp4', quiet=False)" 118 | ], 119 | "execution_count": null, 120 | "outputs": [] 121 | }, 122 | { 123 | "cell_type": "markdown", 124 | "metadata": { 125 | "id": "Jy5AetvnHYyO" 126 | }, 127 | "source": [ 128 | "### 変換の実行" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "metadata": { 134 | "id": "XQS1RNu3IEl2" 135 | }, 136 | "source": [ 137 | "convert_video(\n", 138 | " model, # The loaded model, can be on any device (cpu or cuda).\n", 139 | " input_source='input.mp4', # A video file or an image sequence directory.\n", 140 | " downsample_ratio=None, # [Optional] If None, make downsampled max size be 512px.\n", 141 | " output_type='video', # Choose \"video\" or \"png_sequence\"\n", 142 | " output_composition='com.mp4', # File path if video; directory path if png sequence.\n", 143 | " output_alpha=\"pha.mp4\", # [Optional] Output the raw alpha prediction.\n", 144 | " output_foreground=\"fgr.mp4\", # [Optional] Output the raw foreground prediction.\n", 145 | " output_video_mbps=4, # Output video mbps. Not needed for png sequence.\n", 146 | " seq_chunk=12, # Process n frames at once for better parallelism.\n", 147 | " num_workers=1, # Only for image sequence input. Reader threads.\n", 148 | " progress=True # Print conversion progress.\n", 149 | ")" 150 | ], 151 | "execution_count": null, 152 | "outputs": [] 153 | }, 154 | { 155 | "cell_type": "markdown", 156 | "metadata": { 157 | "id": "s3p8jD1qntso" 158 | }, 159 | "source": [ 160 | "### 動画の再生" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "metadata": { 166 | "id": "0hSiwEoNnxn3" 167 | }, 168 | "source": [ 169 | "# mp4動画の再生\n", 170 | "from IPython.display import HTML\n", 171 | "from base64 import b64encode\n", 172 | "\n", 173 | "mp4 = open('./com.mp4', 'rb').read()\n", 174 | "data_url = 'data:video/mp4;base64,' + b64encode(mp4).decode()\n", 175 | "HTML(f\"\"\"\n", 176 | "\"\"\")" 179 | ], 180 | "execution_count": null, 181 | "outputs": [] 182 | }, 183 | { 184 | "cell_type": "markdown", 185 | "metadata": { 186 | "id": "-iJwFwqUI9Az" 187 | }, 188 | "source": [ 189 | "### 実行結果\n", 190 | "\n", 191 | "変換を実行すると下記のファイルが作成されます。必要に応じてダウンロードして下さい。\n", 192 | "\n", 193 | "* `com.mp4`: 【背景をグリーンスクリーン化した動画】\n", 194 | "* `pha.mp4`: αチャンネルを操作した動画.\n", 195 | "* `fgr.mp4`: マスク動画" 196 | ] 197 | } 198 | ] 199 | } 200 | -------------------------------------------------------------------------------- /Stable_Diffusion2.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "view-in-github", 7 | "colab_type": "text" 8 | }, 9 | "source": [ 10 | "\"Open" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "metadata": { 16 | "id": "w-0mHhP4tpTT" 17 | }, 18 | "source": [ 19 | "#**New: support added for image2image**" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": null, 25 | "metadata": { 26 | "id": "9jUMAVmhVedQ" 27 | }, 28 | "outputs": [], 29 | 
"source": [ 30 | "#@title **セットアップ**\n", 31 | "\n", 32 | "# ライブラリ。インストール\n", 33 | "! pip install transformers gradio scipy ftfy \"ipywidgets>=7,<8\" datasets\n", 34 | "\n", 35 | "# githubからコードをコピーしインストール\n", 36 | "! git clone https://github.com/huggingface/diffusers.git\n", 37 | "! pip install git+https://github.com/huggingface/diffusers.git\n", 38 | "%cd diffusers\n", 39 | "\n", 40 | "# 関数定義(追加)\n", 41 | "import PIL\n", 42 | "from PIL import Image\n", 43 | "import numpy as np\n", 44 | "\n", 45 | "def preprocess(image):\n", 46 | " w, h = image.size\n", 47 | " w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32\n", 48 | " image = image.resize((w, h), resample=PIL.Image.LANCZOS)\n", 49 | " image = np.array(image).astype(np.float32) / 255.0\n", 50 | " image = image[None].transpose(0, 3, 1, 2)\n", 51 | " image = torch.from_numpy(image)\n", 52 | " return 2.*image - 1." 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "metadata": { 59 | "id": "q4IQmzJOzuFz" 60 | }, 61 | "outputs": [], 62 | "source": [ 63 | "#@title **Hugging Faceへログイン**\n", 64 | "#@markdown ・事前にHagging Faceでアクセス・トークンを取得しておいて下さい\n", 65 | "\n", 66 | "from huggingface_hub import notebook_login\n", 67 | "\n", 68 | "# ログイン\n", 69 | "notebook_login()" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "metadata": { 76 | "id": "nySTNxLrWXSe" 77 | }, 78 | "outputs": [], 79 | "source": [ 80 | "#@title **本体プログラム**\n", 81 | "import gradio as gr\n", 82 | "import torch\n", 83 | "from torch import autocast\n", 84 | "from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler\n", 85 | "import requests\n", 86 | "from PIL import Image\n", 87 | "from io import BytesIO\n", 88 | "from IPython.display import clear_output ###\n", 89 | "import warnings ###\n", 90 | "warnings.filterwarnings('ignore') ###\n", 91 | "\n", 92 | "#from examples.inference.image_to_image import StableDiffusionImg2ImgPipeline, preprocess\n", 93 | "from diffusers import StableDiffusionImg2ImgPipeline\n", 94 | "\n", 95 | "lms = LMSDiscreteScheduler(\n", 96 | " beta_start=0.00085, \n", 97 | " beta_end=0.012, \n", 98 | " beta_schedule=\"scaled_linear\"\n", 99 | ")\n", 100 | "\n", 101 | "pipe = StableDiffusionPipeline.from_pretrained(\n", 102 | " \"CompVis/stable-diffusion-v1-4\", \n", 103 | " scheduler=lms,\n", 104 | " revision=\"fp16\", \n", 105 | " use_auth_token=True\n", 106 | ").to(\"cuda\")\n", 107 | "\n", 108 | "pipeimg = StableDiffusionImg2ImgPipeline.from_pretrained(\n", 109 | " \"CompVis/stable-diffusion-v1-4\",\n", 110 | " revision=\"fp16\", \n", 111 | " torch_dtype=torch.float16,\n", 112 | " use_auth_token=True\n", 113 | ").to(\"cuda\")\n", 114 | "\n", 115 | "\n", 116 | "\n", 117 | "\n", 118 | "block = gr.Blocks(css=\".container { max-width: 800px; margin: auto; }\")\n", 119 | "\n", 120 | "num_samples = 2\n", 121 | "\n", 122 | "def infer(prompt, init_image, strength):\n", 123 | " if init_image != None:\n", 124 | " init_image = init_image.resize((512, 512))\n", 125 | " init_image = preprocess(init_image)\n", 126 | " with autocast(\"cuda\"):\n", 127 | " images = pipeimg(prompt=[prompt] * num_samples, image=init_image, strength=strength, guidance_scale=7.5)[0]\n", 128 | " else: \n", 129 | " with autocast(\"cuda\"):\n", 130 | " images = pipe(prompt=[prompt] * num_samples, guidance_scale=7.5)[0]\n", 131 | "\n", 132 | " return images\n", 133 | "\n", 134 | "\n", 135 | "with block as demo:\n", 136 | " gr.Markdown(\"
Stable Diffusion
\")\n", 137 | " gr.Markdown(\n", 138 | " \"Stable Diffusion is an AI model that generates images from any prompt you give!\"\n", 139 | " )\n", 140 | " with gr.Group():\n", 141 | " with gr.Box():\n", 142 | " with gr.Row().style(mobile_collapse=False, equal_height=True):\n", 143 | "\n", 144 | " text = gr.Textbox(\n", 145 | " label=\"Enter your prompt\", show_label=False, max_lines=1\n", 146 | " ).style(\n", 147 | " border=(True, False, True, True),\n", 148 | " rounded=(True, False, False, True),\n", 149 | " container=False,\n", 150 | " )\n", 151 | " btn = gr.Button(\"Run\").style(\n", 152 | " margin=False,\n", 153 | " rounded=(False, True, True, False),\n", 154 | " )\n", 155 | " strength_slider = gr.Slider(\n", 156 | " label=\"Strength\",\n", 157 | " maximum = 1,\n", 158 | " value = 0.75 \n", 159 | " )\n", 160 | " image = gr.Image(\n", 161 | " label=\"Intial Image\",\n", 162 | " type=\"pil\"\n", 163 | " )\n", 164 | " \n", 165 | " gallery = gr.Gallery(label=\"Generated images\", show_label=False).style(\n", 166 | " grid=[2], height=\"auto\"\n", 167 | " )\n", 168 | " text.submit(infer, inputs=[text,image,strength_slider], outputs=gallery)\n", 169 | " btn.click(infer, inputs=[text,image,strength_slider], outputs=gallery)\n", 170 | "\n", 171 | " gr.Markdown(\n", 172 | " \"\"\"___\n", 173 | "
\n", 174 | " Created by CompVis and Stability AI\n", 175 | "
\n", 176 | "
\"\"\"\n", 177 | " )\n", 178 | "\n", 179 | "clear_output() ###\n", 180 | "demo.launch(debug=True)" 181 | ] 182 | } 183 | ], 184 | "metadata": { 185 | "accelerator": "GPU", 186 | "colab": { 187 | "collapsed_sections": [], 188 | "machine_shape": "hm", 189 | "name": "Stable_Diffusion2", 190 | "provenance": [], 191 | "include_colab_link": true 192 | }, 193 | "gpuClass": "standard", 194 | "kernelspec": { 195 | "display_name": "Python 3", 196 | "name": "python3" 197 | }, 198 | "language_info": { 199 | "name": "python" 200 | } 201 | }, 202 | "nbformat": 4, 203 | "nbformat_minor": 0 204 | } 205 | -------------------------------------------------------------------------------- /FacialCartoonization.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "FacialCartoonization", 7 | "provenance": [], 8 | "authorship_tag": "ABX9TyOCRn3VQJSWKsu32au3IE6R", 9 | "include_colab_link": true 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "accelerator": "GPU" 16 | }, 17 | "cells": [ 18 | { 19 | "cell_type": "markdown", 20 | "metadata": { 21 | "id": "view-in-github", 22 | "colab_type": "text" 23 | }, 24 | "source": [ 25 | "\"Open" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": { 31 | "id": "IEhLcJxD7AyK" 32 | }, 33 | "source": [ 34 | "# セットアップ" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "metadata": { 40 | "id": "YDt_YpcLpENX" 41 | }, 42 | "source": [ 43 | "# githubのコードを取得\n", 44 | "!git clone https://github.com/SystemErrorWang/FacialCartoonization.git\n", 45 | "%cd FacialCartoonization/\n", 46 | "\n", 47 | "# サンプル動画をダウンロード\n", 48 | "import gdown\n", 49 | "gdown.download('https://drive.google.com/uc?id=1CiyNbvntSLTL04WlVh6OiwnpdBd8lvOb', 'sample.mp4', quiet=False)" 50 | ], 51 | "execution_count": null, 52 | "outputs": [] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "metadata": { 57 | "id": "cIhcKKbiE30w" 58 | }, 59 | "source": [ 60 | "# フォルダー内表示関数定義\n", 61 | "import matplotlib.pyplot as plt\n", 62 | "from PIL import Image\n", 63 | "import os\n", 64 | "import shutil\n", 65 | "import numpy as np\n", 66 | "%matplotlib inline\n", 67 | "\n", 68 | "def display_pic(folder):\n", 69 | " fig = plt.figure(figsize=(30, 40))\n", 70 | " files = os.listdir(folder)\n", 71 | " files.sort()\n", 72 | " for i, file in enumerate(files):\n", 73 | " if file=='.ipynb_checkpoints':\n", 74 | " continue\n", 75 | " if file=='.DS_Store':\n", 76 | " continue\n", 77 | " img = Image.open(folder+'/'+file) \n", 78 | " images = np.asarray(img)\n", 79 | " ax = fig.add_subplot(10, 10, i+1, xticks=[], yticks=[])\n", 80 | " image_plt = np.array(images)\n", 81 | " ax.imshow(image_plt)\n", 82 | " ax.set_xlabel(file, fontsize=15) \n", 83 | " plt.show()\n", 84 | " plt.close()" 85 | ], 86 | "execution_count": null, 87 | "outputs": [] 88 | }, 89 | { 90 | "cell_type": "markdown", 91 | "metadata": { 92 | "id": "W0da3qE_B6Fv" 93 | }, 94 | "source": [ 95 | "# 静止画をアニメに変換\n", 96 | "・自分で用意した画像を使う場合は、imagesフォルダーにアップロードして下さい。\\\n", 97 | "・静止画はjpgで正方形である必要があります\n" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "metadata": { 103 | "id": "MuWx803FDv0d" 104 | }, 105 | "source": [ 106 | "# サンプル画像表示\n", 107 | "display_pic('images')" 108 | ], 109 | "execution_count": null, 110 | "outputs": [] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "metadata": { 115 | "id": "d1IBNr9ACIdP" 116 | }, 117 | "source": [ 118 | "# result フォルダーリセット\n", 119 | "if 
os.path.isdir('results'):\n", 120 | " shutil.rmtree('results')\n", 121 | "\n", 122 | "# アニメ化の実行\n", 123 | "!python inference.py\n", 124 | "\n", 125 | "# 結果表示\n", 126 | "display_pic('results')" 127 | ], 128 | "execution_count": null, 129 | "outputs": [] 130 | }, 131 | { 132 | "cell_type": "markdown", 133 | "metadata": { 134 | "id": "pSzdrowsMJEz" 135 | }, 136 | "source": [ 137 | "・results フォルダー内のファイル名を00000.jpgからの連番にし、動画に変換します" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "metadata": { 143 | "id": "XHo99LXihRkC" 144 | }, 145 | "source": [ 146 | "# ファイル名を00000.jpgからの連番に変更\n", 147 | "import glob\n", 148 | "files = glob.glob('./results/*.jpg')\n", 149 | "files.sort()\n", 150 | "\n", 151 | "for i, file in enumerate(files):\n", 152 | " img = Image.open(file)\n", 153 | " img.save('./results/'+str(i).zfill(5)+'.jpg')\n", 154 | " os.remove(file)\n", 155 | "\n", 156 | "# 実写+アニメ顔をmp4に変換\n", 157 | "!ffmpeg -r 1 -i results/%05d.jpg -vcodec libx264 -pix_fmt yuv420p out.mp4" 158 | ], 159 | "execution_count": null, 160 | "outputs": [] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "metadata": { 165 | "id": "QGVxWs2AMH0P" 166 | }, 167 | "source": [ 168 | "# mp4動画の再生\n", 169 | "from IPython.display import HTML\n", 170 | "from base64 import b64encode\n", 171 | "\n", 172 | "mp4 = open('./out.mp4', 'rb').read()\n", 173 | "data_url = 'data:video/mp4;base64,' + b64encode(mp4).decode()\n", 174 | "HTML(f\"\"\"\n", 175 | "\"\"\")" 178 | ], 179 | "execution_count": null, 180 | "outputs": [] 181 | }, 182 | { 183 | "cell_type": "markdown", 184 | "metadata": { 185 | "id": "Iz0qtpKqvH4b" 186 | }, 187 | "source": [ 188 | "# mp4動画をアニメに変換\n", 189 | "・自分で用意したmp4動画を使う場合は、カレントディレクトリにアップロードして下さい。\\\n", 190 | "・動画は正方形の形状にして下さい。\\\n", 191 | "・11行目のビデオ指定を保存したファイル名に変更して下さい。" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "metadata": { 197 | "id": "Mme3Ug9wvF2x" 198 | }, 199 | "source": [ 200 | "# ビデオを静止画に変換\n", 201 | "import os\n", 202 | "import shutil\n", 203 | "import cv2\n", 204 | "\n", 205 | "# imagesフォルダーリセット\n", 206 | "if os.path.isdir('images'):\n", 207 | " shutil.rmtree('images')\n", 208 | "os.makedirs('images', exist_ok=True)\n", 209 | " \n", 210 | "def video_2_images(video_file= './sample.mp4', # ビデオ指定\n", 211 | " image_dir='./images/', \n", 212 | " image_file='%s.jpg'): \n", 213 | " \n", 214 | " # Initial setting\n", 215 | " i = 0\n", 216 | " interval = 3\n", 217 | " length = 600 # 最大フレーム数\n", 218 | " \n", 219 | " cap = cv2.VideoCapture(video_file)\n", 220 | " while(cap.isOpened()):\n", 221 | " flag, frame = cap.read() \n", 222 | " if flag == False: \n", 223 | " break\n", 224 | " if i == length*interval:\n", 225 | " break\n", 226 | " if i % interval == 0:\n", 227 | " frame = cv2.resize(frame, (256, 256)) # 256×256にリサイズ \n", 228 | " cv2.imwrite(image_dir+image_file % str(int(i/interval)).zfill(6), frame)\n", 229 | " i += 1 \n", 230 | " cap.release() \n", 231 | " \n", 232 | "def main():\n", 233 | " video_2_images()\n", 234 | " \n", 235 | "if __name__ == '__main__':\n", 236 | " main()" 237 | ], 238 | "execution_count": null, 239 | "outputs": [] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "metadata": { 244 | "id": "FfcvgccXyOwh" 245 | }, 246 | "source": [ 247 | "# resultsフォルダーリセット\n", 248 | "if os.path.isdir('results'):\n", 249 | " shutil.rmtree('results')\n", 250 | "\n", 251 | "# アニメ化の実行 \n", 252 | "!python inference.py\n", 253 | "\n", 254 | "# output.mp4をリセット\n", 255 | "if os.path.exists('./output.mp4'):\n", 256 | " os.remove('./output.mp4')\n", 257 | "\n", 258 | "# アニメ顔をmp4に変換\n", 259 | 
"!ffmpeg -r 10 -i results/%06d.jpg -vcodec libx264 -pix_fmt yuv420p output.mp4" 260 | ], 261 | "execution_count": null, 262 | "outputs": [] 263 | }, 264 | { 265 | "cell_type": "code", 266 | "metadata": { 267 | "id": "MFEUXsjrI9l_" 268 | }, 269 | "source": [ 270 | "# mp4動画の再生\n", 271 | "from IPython.display import HTML\n", 272 | "from base64 import b64encode\n", 273 | "\n", 274 | "mp4 = open('./output.mp4', 'rb').read()\n", 275 | "data_url = 'data:video/mp4;base64,' + b64encode(mp4).decode()\n", 276 | "HTML(f\"\"\"\n", 277 | "\"\"\")" 280 | ], 281 | "execution_count": null, 282 | "outputs": [] 283 | } 284 | ] 285 | } -------------------------------------------------------------------------------- /autoencoder.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "autoencoder", 7 | "provenance": [], 8 | "authorship_tag": "ABX9TyMOIYDZesQBY/JrLFT+B26E", 9 | "include_colab_link": true 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "language_info": { 16 | "name": "python" 17 | }, 18 | "accelerator": "GPU" 19 | }, 20 | "cells": [ 21 | { 22 | "cell_type": "markdown", 23 | "metadata": { 24 | "id": "view-in-github", 25 | "colab_type": "text" 26 | }, 27 | "source": [ 28 | "\"Open" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": { 34 | "id": "u9tFQpan0D0B" 35 | }, 36 | "source": [ 37 | "# Keras AutoEncoder で異常検知をやってみる\n" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "metadata": { 43 | "id": "rB--3kVqzFuF" 44 | }, 45 | "source": [ 46 | "# tensolflowバージョン1.x を選択\n", 47 | "%tensorflow_version 1.x" 48 | ], 49 | "execution_count": null, 50 | "outputs": [] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "metadata": { 55 | "id": "fP-s1A9tvU2x" 56 | }, 57 | "source": [ 58 | "from keras.layers import Input, Dense\n", 59 | "from keras.models import Model\n", 60 | "from keras.datasets import mnist\n", 61 | "from sklearn.model_selection import train_test_split\n", 62 | "import numpy as np\n", 63 | "import matplotlib.pyplot as plt\n", 64 | " \n", 65 | "# AutoEncoder ネットワーク構築\n", 66 | "encoding_dim = 32\n", 67 | "input_img = Input(shape=(784,))\n", 68 | "encoded = Dense(encoding_dim, activation='relu')(input_img)\n", 69 | "decoded = Dense(784, activation='sigmoid')(encoded)\n", 70 | "autoencoder = Model(input=input_img, output=decoded)\n", 71 | "autoencoder.compile(optimizer='adadelta', loss='binary_crossentropy')\n", 72 | " \n", 73 | "# MNIST データ読み込み\n", 74 | "(x_train, y_train), (x_test, y_test) = mnist.load_data()\n", 75 | " \n", 76 | "# データの前準備\n", 77 | "x_train, x_valid = train_test_split(x_train, test_size=0.175)\n", 78 | "x_train = x_train.astype('float32')/255.\n", 79 | "x_valid = x_valid.astype('float32')/255.\n", 80 | "x_test = x_test.astype('float32')/255.\n", 81 | "x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))\n", 82 | "x_valid = x_valid.reshape((len(x_valid), np.prod(x_valid.shape[1:])))\n", 83 | "x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:])))\n", 84 | " \n", 85 | "# 学習\n", 86 | "autoencoder.fit(x_train, x_train,\n", 87 | " nb_epoch=50,\n", 88 | " batch_size=256,\n", 89 | " shuffle=True,\n", 90 | " validation_data=(x_valid, x_valid))\n", 91 | " \n", 92 | "# 出力画像の取得\n", 93 | "decoded_imgs = autoencoder.predict(x_test)\n", 94 | " \n", 95 | "# サンプル画像表示\n", 96 | "n = 6\n", 97 | "plt.figure(figsize=(12, 4))\n", 98 | "for i in range(n):\n", 99 | " # テスト画像を表示\n", 
100 | " ax = plt.subplot(2, n, i+1)\n", 101 | " plt.imshow(x_test[i].reshape(28, 28))\n", 102 | " plt.gray()\n", 103 | " ax.get_xaxis().set_visible(False)\n", 104 | " ax.get_yaxis().set_visible(False)\n", 105 | " \n", 106 | " # 出力画像を表示\n", 107 | " ax = plt.subplot(2, n, i+1+n)\n", 108 | " plt.imshow(decoded_imgs[i].reshape(28, 28))\n", 109 | " plt.gray()\n", 110 | " ax.get_xaxis().set_visible(False)\n", 111 | " ax.get_yaxis().set_visible(False)\n", 112 | "plt.savefig(\"result.png\")\n", 113 | "plt.show()" 114 | ], 115 | "execution_count": null, 116 | "outputs": [] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "metadata": { 121 | "id": "3ffw3WSYxI1f" 122 | }, 123 | "source": [ 124 | "# 学習データを「1」のみにする\n", 125 | "x1 =[]\n", 126 | "for i in range(len(x_train)):\n", 127 | " if y_train[i] == 1 :\n", 128 | " x1.append(x_train[i])\n", 129 | "x_train = np.array(x1)\n", 130 | " \n", 131 | "# テストデータを「1」と「9」にする\n", 132 | "x2, y = [],[]\n", 133 | "for i in range(len(x_test)):\n", 134 | " if y_test[i] == 1 or y_test[i] == 9 :\n", 135 | " x2.append(x_test[i])\n", 136 | " y.append(y_test[i])\n", 137 | "x_test = np.array(x2)\n", 138 | "y = np.array(y)" 139 | ], 140 | "execution_count": null, 141 | "outputs": [] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "metadata": { 146 | "id": "MUywk8x_x72H" 147 | }, 148 | "source": [ 149 | "# 学習\n", 150 | "autoencoder.fit(x_train, x_train,\n", 151 | " nb_epoch=300,\n", 152 | " batch_size=256,\n", 153 | " shuffle=True,\n", 154 | " validation_data=(x_valid, x_valid))\n", 155 | " \n", 156 | "# 出力画像の取得\n", 157 | "decoded_imgs = autoencoder.predict(x_test)\n", 158 | " \n", 159 | "# サンプル画像表示\n", 160 | "n = 6\n", 161 | "plt.figure(figsize=(12, 4))\n", 162 | "for i in range(n):\n", 163 | " # テスト画像を表示\n", 164 | " ax = plt.subplot(2, n, i+1)\n", 165 | " plt.imshow(x_test[i].reshape(28, 28))\n", 166 | " plt.gray()\n", 167 | " ax.get_xaxis().set_visible(False)\n", 168 | " ax.get_yaxis().set_visible(False)\n", 169 | " \n", 170 | " # 出力画像を表示\n", 171 | " ax = plt.subplot(2, n, i+1+n)\n", 172 | " plt.imshow(decoded_imgs[i].reshape(28, 28))\n", 173 | " plt.gray()\n", 174 | " ax.get_xaxis().set_visible(False)\n", 175 | " ax.get_yaxis().set_visible(False)\n", 176 | "plt.savefig(\"result.png\")\n", 177 | "plt.show()" 178 | ], 179 | "execution_count": null, 180 | "outputs": [] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "metadata": { 185 | "id": "DNTiwMJWyqT4" 186 | }, 187 | "source": [ 188 | "# サンプル画像表示\n", 189 | "n = 6\n", 190 | "plt.figure(figsize=(12, 6))\n", 191 | "for i in range(n):\n", 192 | " # テスト画像を表示\n", 193 | " ax = plt.subplot(3, n, i+1)\n", 194 | " plt.imshow(x_test[i].reshape(28, 28))\n", 195 | " plt.gray()\n", 196 | " ax.get_xaxis().set_visible(False)\n", 197 | " ax.get_yaxis().set_visible(False)\n", 198 | " \n", 199 | " # 出力画像を表示\n", 200 | " ax = plt.subplot(3, n, i+1+n)\n", 201 | " plt.imshow(decoded_imgs[i].reshape(28, 28))\n", 202 | " plt.gray()\n", 203 | " ax.get_xaxis().set_visible(False)\n", 204 | " ax.get_yaxis().set_visible(False)\n", 205 | " \n", 206 | " # 入出力の差分画像を計算\n", 207 | " diff_img = x_test[i] - decoded_imgs[i]\n", 208 | " \n", 209 | " # 入出力の差分数値を計算\n", 210 | " diff = np.sum(np.abs(x_test[i]-decoded_imgs[i]))\n", 211 | " \n", 212 | " # 差分画像と差分数値の表示\n", 213 | " ax = plt.subplot(3, n, i+1+n*2)\n", 214 | " plt.imshow(diff_img.reshape(28, 28))\n", 215 | " plt.gray()\n", 216 | " ax.get_xaxis().set_visible(True)\n", 217 | " ax.get_yaxis().set_visible(True) \n", 218 | " ax.set_xlabel('score = '+str(diff)) \n", 219 | " \n", 220 | 
"plt.savefig(\"result.png\")\n", 221 | "plt.show()\n", 222 | "plt.close()" 223 | ], 224 | "execution_count": null, 225 | "outputs": [] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "metadata": { 230 | "id": "i3JiXcuqy3KA" 231 | }, 232 | "source": [ 233 | "# score を記録したファイルがあれば一端クリア\n", 234 | "import os\n", 235 | "if os.path.exists('scores_1.txt'):\n", 236 | " os.remove('scores_1.txt')\n", 237 | "if os.path.exists('scores_9.txt'):\n", 238 | " os.remove('scores_9.txt')\n", 239 | " \n", 240 | "# score の計算、結果のファイル保存\n", 241 | "for i in range(100):\n", 242 | " score = np.sum(np.abs(x_test[i]-decoded_imgs[i]))\n", 243 | " \n", 244 | " if y[i] == 1:\n", 245 | " with open('scores_1.txt','a') as f:\n", 246 | " f.write(str(score)+'\\n') \n", 247 | " else:\n", 248 | " with open('scores_9.txt','a') as f:\n", 249 | " f.write(str(score)+'\\n')\n", 250 | " \n", 251 | "# ファイルを元にヒストグラムの表示\n", 252 | "import matplotlib.pyplot as plt\n", 253 | "import csv\n", 254 | " \n", 255 | "x =[]\n", 256 | "with open('scores_1.txt', 'r') as f:\n", 257 | " reader = csv.reader(f)\n", 258 | " for row in reader:\n", 259 | " row = int(float(row[0]))\n", 260 | " x.append(row)\n", 261 | "y =[]\n", 262 | "with open('scores_9.txt', 'r') as f:\n", 263 | " reader = csv.reader(f)\n", 264 | " for row in reader:\n", 265 | " row = int(float(row[0]))\n", 266 | " y.append(row)\n", 267 | " \n", 268 | "plt.title(\"Score Histgram\")\n", 269 | "plt.xlabel(\"Score\")\n", 270 | "plt.ylabel(\"freq\")\n", 271 | "plt.hist(x, bins=10, alpha=0.3, histtype='stepfilled', color='r', label=\"1\")\n", 272 | "plt.hist(y, bins=40, alpha=0.3, histtype='stepfilled', color='b', label='9')\n", 273 | "plt.legend(loc=1)\n", 274 | "plt.savefig(\"histgram.png\")\n", 275 | "plt.show()\n", 276 | "plt.close()" 277 | ], 278 | "execution_count": null, 279 | "outputs": [] 280 | } 281 | ] 282 | } -------------------------------------------------------------------------------- /Optimized_LDM_TXT2IM.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "Optimized LDM-TXT2IM", 7 | "private_outputs": true, 8 | "provenance": [], 9 | "machine_shape": "hm", 10 | "collapsed_sections": [], 11 | "include_colab_link": true 12 | }, 13 | "kernelspec": { 14 | "name": "python3", 15 | "display_name": "Python 3" 16 | }, 17 | "language_info": { 18 | "name": "python" 19 | }, 20 | "accelerator": "GPU" 21 | }, 22 | "cells": [ 23 | { 24 | "cell_type": "markdown", 25 | "metadata": { 26 | "id": "view-in-github", 27 | "colab_type": "text" 28 | }, 29 | "source": [ 30 | "\"Open" 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "source": [ 36 | "# Latent Diffusion Models Text2Image\n", 37 | "\n", 38 | "### https://arxiv.org/abs/2112.10752\n", 39 | "\n", 40 | "### https://github.com/CompVis/latent-diffusion\n", 41 | "\n", 42 | "Original Notebook by: [Eyal Gruss](https://eyalgruss.com) \\([@eyaler](https://twitter.com/eyaler)\\)\n", 43 | "\n", 44 | "Optimizations by: [Aaron Gokaslan](https://twitter.com/SkyLi0n) and faster sampling by [RiverHasWings](https://twitter.com/rivershavewings)" 45 | ], 46 | "metadata": { 47 | "id": "Bmvx0uTbF6Iw" 48 | } 49 | }, 50 | { 51 | "cell_type": "markdown", 52 | "source": [ 53 | "# Note you need a GPU with 16GB of VRAM. If you get a K80, try again." 
54 | ], 55 | "metadata": { 56 | "id": "g1QAMXzKCUO8" 57 | } 58 | }, 59 | { 60 | "cell_type": "code", 61 | "source": [ 62 | "!nvidia-smi" 63 | ], 64 | "metadata": { 65 | "id": "X8L4kLA9Ad2q" 66 | }, 67 | "execution_count": null, 68 | "outputs": [] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "source": [ 73 | "#@title Setup\n", 74 | "%cd /content\n", 75 | "!git clone https://github.com/crowsonkb/latent-diffusion --depth 1\n", 76 | "!git clone https://github.com/CompVis/taming-transformers --depth 1\n", 77 | "!pip -q install -e ./taming-transformers\n", 78 | "!pip -q install omegaconf pytorch-lightning torch-fidelity einops transformers\n", 79 | "%cd latent-diffusion\n", 80 | "!cp scripts/txt2img.py .\n", 81 | "!mkdir -p models/ldm/text2img-large/\n", 82 | "!wget -nc -O models/ldm/text2img-large/model.ckpt https://ommer-lab.com/files/latent-diffusion/nitro/txt2img-f8-large/model.ckpt\n" 83 | ], 84 | "metadata": { 85 | "id": "2iLdwkKD5l8a" 86 | }, 87 | "execution_count": null, 88 | "outputs": [] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "source": [ 93 | "%%writefile txt2img.py\n", 94 | "import argparse, os, sys, glob\n", 95 | "import torch\n", 96 | "import numpy as np\n", 97 | "from omegaconf import OmegaConf\n", 98 | "from PIL import Image\n", 99 | "from tqdm.auto import tqdm, trange\n", 100 | "from einops import rearrange\n", 101 | "from torchvision.utils import make_grid\n", 102 | "\n", 103 | "from ldm.util import instantiate_from_config\n", 104 | "from ldm.models.diffusion.ddim import DDIMSampler\n", 105 | "from ldm.models.diffusion.plms import PLMSSampler\n", 106 | "\n", 107 | "\n", 108 | "def load_model_from_config(config, ckpt, verbose=False):\n", 109 | " print(f\"Loading model from {ckpt}\")\n", 110 | " pl_sd = torch.load(ckpt, map_location=\"cuda\")\n", 111 | " sd = pl_sd[\"state_dict\"]\n", 112 | " model = instantiate_from_config(config.model)\n", 113 | " m, u = model.load_state_dict(sd, strict=False)\n", 114 | " if len(m) > 0 and verbose:\n", 115 | " print(\"missing keys:\")\n", 116 | " print(m)\n", 117 | " if len(u) > 0 and verbose:\n", 118 | " print(\"unexpected keys:\")\n", 119 | " print(u)\n", 120 | "\n", 121 | " model.cuda()\n", 122 | " model.eval()\n", 123 | " return model\n", 124 | "\n", 125 | "\n", 126 | "if __name__ == \"__main__\":\n", 127 | " parser = argparse.ArgumentParser()\n", 128 | "\n", 129 | " parser.add_argument(\n", 130 | " \"--prompt\",\n", 131 | " type=str,\n", 132 | " nargs=\"?\",\n", 133 | " default=\"a painting of a virus monster playing guitar\",\n", 134 | " help=\"the prompt to render\"\n", 135 | " )\n", 136 | "\n", 137 | " parser.add_argument(\n", 138 | " \"--outdir\",\n", 139 | " type=str,\n", 140 | " nargs=\"?\",\n", 141 | " help=\"dir to write results to\",\n", 142 | " default=\"outputs/txt2img-samples\"\n", 143 | " )\n", 144 | " parser.add_argument(\n", 145 | " \"--ddim_steps\",\n", 146 | " type=int,\n", 147 | " default=200,\n", 148 | " help=\"number of ddim sampling steps\",\n", 149 | " )\n", 150 | "\n", 151 | " parser.add_argument(\n", 152 | " \"--plms\",\n", 153 | " action='store_true',\n", 154 | " help=\"use plms sampling\",\n", 155 | " )\n", 156 | "\n", 157 | " parser.add_argument(\n", 158 | " \"--ddim_eta\",\n", 159 | " type=float,\n", 160 | " default=0.0,\n", 161 | " help=\"ddim eta (eta=0.0 corresponds to deterministic sampling\",\n", 162 | " )\n", 163 | " parser.add_argument(\n", 164 | " \"--n_iter\",\n", 165 | " type=int,\n", 166 | " default=1,\n", 167 | " help=\"sample this often\",\n", 168 | " )\n", 169 | "\n", 170 | " 
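# The sampler flags above trade speed for quality: "--ddim_steps" sets the
# number of denoising steps, "--plms" switches from the default DDIM sampler to
# the PLMS sampler (usually comparable results with fewer steps), and
# "--ddim_eta" controls stochasticity (eta=0.0 is deterministic DDIM sampling).
# The flags below set the output resolution, the batch size and the
# classifier-free guidance scale.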
parser.add_argument(\n", 171 | " \"--H\",\n", 172 | " type=int,\n", 173 | " default=256,\n", 174 | " help=\"image height, in pixel space\",\n", 175 | " )\n", 176 | "\n", 177 | " parser.add_argument(\n", 178 | " \"--W\",\n", 179 | " type=int,\n", 180 | " default=256,\n", 181 | " help=\"image width, in pixel space\",\n", 182 | " )\n", 183 | "\n", 184 | " parser.add_argument(\n", 185 | " \"--n_samples\",\n", 186 | " type=int,\n", 187 | " default=4,\n", 188 | " help=\"how many samples to produce for the given prompt\",\n", 189 | " )\n", 190 | "\n", 191 | " parser.add_argument(\n", 192 | " \"--scale\",\n", 193 | " type=float,\n", 194 | " default=5.0,\n", 195 | " help=\"unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))\",\n", 196 | " )\n", 197 | " opt = parser.parse_args()\n", 198 | "\n", 199 | "\n", 200 | " config = OmegaConf.load(\"configs/latent-diffusion/txt2img-1p4B-eval.yaml\") # TODO: Optionally download from same location as ckpt and chnage this logic\n", 201 | " model = load_model_from_config(config, \"models/ldm/text2img-large/model.ckpt\") # TODO: check path\n", 202 | "\n", 203 | " device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", 204 | " model = model.to(device)\n", 205 | "\n", 206 | " if opt.plms:\n", 207 | " sampler = PLMSSampler(model)\n", 208 | " else:\n", 209 | " sampler = DDIMSampler(model)\n", 210 | "\n", 211 | " os.makedirs(opt.outdir, exist_ok=True)\n", 212 | " outpath = opt.outdir\n", 213 | "\n", 214 | " prompt = opt.prompt\n", 215 | "\n", 216 | "\n", 217 | " sample_path = os.path.join(outpath, \"samples\")\n", 218 | " os.makedirs(sample_path, exist_ok=True)\n", 219 | " base_count = len(os.listdir(sample_path))\n", 220 | "\n", 221 | " all_samples=list()\n", 222 | " with torch.no_grad():\n", 223 | " with model.ema_scope():\n", 224 | " uc = None\n", 225 | " if opt.scale != 1.0:\n", 226 | " uc = model.get_learned_conditioning(opt.n_samples * [\"\"])\n", 227 | " for n in trange(opt.n_iter, desc=\"Sampling\"):\n", 228 | " c = model.get_learned_conditioning(opt.n_samples * [prompt])\n", 229 | " shape = [4, opt.H//8, opt.W//8]\n", 230 | " samples_ddim, _ = sampler.sample(S=opt.ddim_steps,\n", 231 | " conditioning=c,\n", 232 | " batch_size=opt.n_samples,\n", 233 | " shape=shape,\n", 234 | " verbose=False,\n", 235 | " unconditional_guidance_scale=opt.scale,\n", 236 | " unconditional_conditioning=uc,\n", 237 | " eta=opt.ddim_eta)\n", 238 | "\n", 239 | " x_samples_ddim = model.decode_first_stage(samples_ddim)\n", 240 | " x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, min=0.0, max=1.0)\n", 241 | "\n", 242 | " for x_sample in x_samples_ddim:\n", 243 | " x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')\n", 244 | " Image.fromarray(x_sample.astype(np.uint8)).save(os.path.join(sample_path, f\"{base_count:04}.png\"))\n", 245 | " base_count += 1\n", 246 | " all_samples.append(x_samples_ddim)\n", 247 | "\n", 248 | "\n", 249 | " # additionally, save as grid\n", 250 | " grid = torch.stack(all_samples, 0)\n", 251 | " grid = rearrange(grid, 'n b c h w -> (n b) c h w')\n", 252 | " grid = make_grid(grid, nrow=opt.n_samples)\n", 253 | "\n", 254 | " # to image\n", 255 | " grid = 255. 
* rearrange(grid, 'c h w -> h w c').cpu().numpy()\n", 256 | " Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'{prompt.replace(\" \", \"-\")}.png'))\n", 257 | "\n", 258 | " print(f\"Your samples are ready and waiting four you here: \\n{outpath} \\nEnjoy.\")\n" 259 | ], 260 | "metadata": { 261 | "id": "TgSjU4lJZhgD" 262 | }, 263 | "execution_count": null, 264 | "outputs": [] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "execution_count": null, 269 | "metadata": { 270 | "id": "g0_Gb52UwMHQ" 271 | }, 272 | "outputs": [], 273 | "source": [ 274 | "%cd /content/latent-diffusion\n", 275 | "\n", 276 | "#@title Generate\n", 277 | "prompt = 'A sticker of Albert Einstein riding a hors' #@param {type: 'string'}\n", 278 | "ddim_eta = 0 #@param {type: 'number'}\n", 279 | "n_samples = 2 #@param {type: 'integer'}\n", 280 | "n_iter = 4 #@param {type: 'integer'}\n", 281 | "scale = 5 #@param {type: 'number'}\n", 282 | "ddim_steps = 50#@param {type: 'integer'}\n", 283 | "W = 256 #@param {type: 'integer'}\n", 284 | "H = 256 #@param {type: 'integer'}\n", 285 | "outdir = 'outputs' #@param {type: 'string'}\n", 286 | "!mkdir -p $outdir\n", 287 | "from google.colab.patches import cv2_imshow\n", 288 | "import cv2\n", 289 | "!python txt2img.py --prompt \"$prompt\" --ddim_eta $ddim_eta --n_samples $n_samples --n_iter $n_iter --scale $scale --ddim_steps $ddim_steps --H $H --W $W --outdir $outdir --plms\n", 290 | "filename = f'{outdir}/{prompt.replace(\" \", \"-\")}.png'\n", 291 | "im = cv2.imread(filename)\n", 292 | "cv2_imshow(im)" 293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "source": [ 298 | "#@title Download images\n", 299 | "!zip -jrqFS ldm.zip \"$outdir\"\n", 300 | "from google.colab import files\n", 301 | "files.download('ldm.zip')" 302 | ], 303 | "metadata": { 304 | "id": "S3PKmI74DENO" 305 | }, 306 | "execution_count": null, 307 | "outputs": [] 308 | } 309 | ] 310 | } -------------------------------------------------------------------------------- /DemoSegmenter.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "accelerator": "GPU", 6 | "colab": { 7 | "name": "DemoSegmenter", 8 | "provenance": [], 9 | "collapsed_sections": [], 10 | "include_colab_link": true 11 | }, 12 | "kernelspec": { 13 | "display_name": "Python 3", 14 | "language": "python", 15 | "name": "python3" 16 | }, 17 | "language_info": { 18 | "codemirror_mode": { 19 | "name": "ipython", 20 | "version": 3 21 | }, 22 | "file_extension": ".py", 23 | "mimetype": "text/x-python", 24 | "name": "python", 25 | "nbconvert_exporter": "python", 26 | "pygments_lexer": "ipython3", 27 | "version": "3.6.7" 28 | } 29 | }, 30 | "cells": [ 31 | { 32 | "cell_type": "markdown", 33 | "metadata": { 34 | "id": "view-in-github", 35 | "colab_type": "text" 36 | }, 37 | "source": [ 38 | "\"Open" 39 | ] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "metadata": { 44 | "id": "Z5mW_XiR-x8a" 45 | }, 46 | "source": [ 47 | "# Setup" 48 | ] 49 | }, 50 | { 51 | "cell_type": "markdown", 52 | "metadata": { 53 | "id": "Uv4sZ0v9_W6H" 54 | }, 55 | "source": [ 56 | "**Environment Setup**" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "metadata": { 62 | "id": "wJwDpfnR0Tyl" 63 | }, 64 | "source": [ 65 | "%%bash\n", 66 | "# Colab-specific setup\n", 67 | "!(stat -t /usr/local/lib/*/dist-packages/google/colab > /dev/null 2>&1) && exit \n", 68 | "pip install yacs 2>&1 >> install.log\n", 69 | "git init 2>&1 >> install.log\n", 70 | "git 
remote add origin https://github.com/CSAILVision/semantic-segmentation-pytorch.git 2>> install.log\n", 71 | "git pull origin master 2>&1 >> install.log\n", 72 | "DOWNLOAD_ONLY=1 ./demo_test.sh 2>> install.log" 73 | ], 74 | "execution_count": null, 75 | "outputs": [] 76 | }, 77 | { 78 | "cell_type": "markdown", 79 | "metadata": { 80 | "id": "j0n-UmiA_G-R" 81 | }, 82 | "source": [ 83 | "**Imports and utility functions**" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "metadata": { 89 | "id": "v-Lj9g_E0Tym" 90 | }, 91 | "source": [ 92 | "# System libs\n", 93 | "import os, csv, torch, numpy, scipy.io, PIL.Image, torchvision.transforms\n", 94 | "# Our libs\n", 95 | "from mit_semseg.models import ModelBuilder, SegmentationModule\n", 96 | "from mit_semseg.utils import colorEncode\n", 97 | "\n", 98 | "colors = scipy.io.loadmat('data/color150.mat')['colors']\n", 99 | "names = {}\n", 100 | "with open('data/object150_info.csv') as f:\n", 101 | " reader = csv.reader(f)\n", 102 | " next(reader)\n", 103 | " for row in reader:\n", 104 | " names[int(row[0])] = row[5].split(\";\")[0]\n", 105 | "\n", 106 | "def visualize_result(img, pred, index=None):\n", 107 | " # filter prediction class if requested\n", 108 | " if index is not None:\n", 109 | " pred = pred.copy()\n", 110 | " pred[pred != index] = -1\n", 111 | " print(f'{names[index+1]}:')\n", 112 | " \n", 113 | " # colorize prediction\n", 114 | " pred_color = colorEncode(pred, colors).astype(numpy.uint8)\n", 115 | "\n", 116 | " # aggregate images and save\n", 117 | " im_vis = numpy.concatenate((img, pred_color), axis=1)\n", 118 | " display(PIL.Image.fromarray(im_vis))" 119 | ], 120 | "execution_count": null, 121 | "outputs": [] 122 | }, 123 | { 124 | "cell_type": "markdown", 125 | "metadata": { 126 | "id": "eql3xHZt0Tyn" 127 | }, 128 | "source": [ 129 | "**Loading the segmentation model**\n" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "metadata": { 135 | "id": "zrJ91vrG0Tyn" 136 | }, 137 | "source": [ 138 | "# Network Builders\n", 139 | "net_encoder = ModelBuilder.build_encoder(\n", 140 | " arch='resnet50dilated',\n", 141 | " fc_dim=2048,\n", 142 | " weights='ckpt/ade20k-resnet50dilated-ppm_deepsup/encoder_epoch_20.pth')\n", 143 | "net_decoder = ModelBuilder.build_decoder(\n", 144 | " arch='ppm_deepsup',\n", 145 | " fc_dim=2048,\n", 146 | " num_class=150,\n", 147 | " weights='ckpt/ade20k-resnet50dilated-ppm_deepsup/decoder_epoch_20.pth',\n", 148 | " use_softmax=True)\n", 149 | "\n", 150 | "crit = torch.nn.NLLLoss(ignore_index=-1)\n", 151 | "segmentation_module = SegmentationModule(net_encoder, net_decoder, crit)\n", 152 | "segmentation_module.eval()\n", 153 | "segmentation_module.cuda()" 154 | ], 155 | "execution_count": null, 156 | "outputs": [] 157 | }, 158 | { 159 | "cell_type": "markdown", 160 | "metadata": { 161 | "id": "XWnCw0If_gGj" 162 | }, 163 | "source": [ 164 | "# For image" 165 | ] 166 | }, 167 | { 168 | "cell_type": "markdown", 169 | "metadata": { 170 | "id": "sqZz74PF0Tyo" 171 | }, 172 | "source": [ 173 | "**Load test data**\n" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "metadata": { 179 | "id": "wkBlYcpC0Typ" 180 | }, 181 | "source": [ 182 | "# Load and normalize one image as a singleton tensor batch\n", 183 | "pil_to_tensor = torchvision.transforms.Compose([\n", 184 | " torchvision.transforms.ToTensor(),\n", 185 | " torchvision.transforms.Normalize(\n", 186 | " mean=[0.485, 0.456, 0.406], # These are RGB mean+std values\n", 187 | " std=[0.229, 0.224, 0.225]) # across a large photo dataset.\n", 188 | 
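# These mean/std values are the standard ImageNet normalization statistics;
# the "For movie" section below applies the same transform to every extracted
# video frame, so frames are preprocessed exactly like this single test image.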
"])\n", 189 | "pil_image = PIL.Image.open('ADE_val_00001519.jpg').convert('RGB')\n", 190 | "img_original = numpy.array(pil_image)\n", 191 | "img_data = pil_to_tensor(pil_image)\n", 192 | "singleton_batch = {'img_data': img_data[None].cuda()}\n", 193 | "output_size = img_data.shape[1:]" 194 | ], 195 | "execution_count": null, 196 | "outputs": [] 197 | }, 198 | { 199 | "cell_type": "markdown", 200 | "metadata": { 201 | "id": "kdAmRgQt0Typ" 202 | }, 203 | "source": [ 204 | "**Run the Model**" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "metadata": { 210 | "id": "w0SZaJfQ0Typ", 211 | "scrolled": false 212 | }, 213 | "source": [ 214 | "# Run the segmentation at the highest resolution.\n", 215 | "with torch.no_grad():\n", 216 | " scores = segmentation_module(singleton_batch, segSize=output_size)\n", 217 | " \n", 218 | "# Get the predicted scores for each pixel\n", 219 | "_, pred = torch.max(scores, dim=1)\n", 220 | "pred = pred.cpu()[0].numpy()\n", 221 | "visualize_result(img_original, pred)" 222 | ], 223 | "execution_count": null, 224 | "outputs": [] 225 | }, 226 | { 227 | "cell_type": "markdown", 228 | "metadata": { 229 | "id": "aEyXy3o-0Tyq" 230 | }, 231 | "source": [ 232 | "**Showing classes individually**" 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "metadata": { 238 | "id": "RnYY5f680Tyq" 239 | }, 240 | "source": [ 241 | "# Top classes in answer\n", 242 | "predicted_classes = numpy.bincount(pred.flatten()).argsort()[::-1]\n", 243 | "for c in predicted_classes[:15]:\n", 244 | " visualize_result(img_original, pred, c)" 245 | ], 246 | "execution_count": null, 247 | "outputs": [] 248 | }, 249 | { 250 | "cell_type": "markdown", 251 | "metadata": { 252 | "id": "BPSQ-6yK-ISM" 253 | }, 254 | "source": [ 255 | "# For movie" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "metadata": { 261 | "id": "CfHynVnK14ok" 262 | }, 263 | "source": [ 264 | "# サンプルビデオをダウンロード\n", 265 | "import gdown\n", 266 | "gdown.download('https://drive.google.com/uc?id=1cfa4R-0Zwd2Te5-qBWe9oNRKQ_pEUr0z', 'road.mp4', quiet=False)" 267 | ], 268 | "execution_count": null, 269 | "outputs": [] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "metadata": { 274 | "id": "kR8a_wmLv7__" 275 | }, 276 | "source": [ 277 | "# サンプルビデオを静止画に変換\n", 278 | "import os\n", 279 | "import shutil\n", 280 | "import cv2\n", 281 | " \n", 282 | "def video_2_images(video_file= './road.mp4', # ビデオ指定\n", 283 | " image_dir='./images/', \n", 284 | " image_file='%s.jpg'): \n", 285 | " \n", 286 | " # Initial setting\n", 287 | " i = 0\n", 288 | " interval = 3\n", 289 | " length = 600 # 最大フレーム数\n", 290 | " \n", 291 | " cap = cv2.VideoCapture(video_file)\n", 292 | " while(cap.isOpened()):\n", 293 | " flag, frame = cap.read() \n", 294 | " if flag == False: \n", 295 | " break\n", 296 | " if i == length*interval:\n", 297 | " break\n", 298 | " if i % interval == 0: \n", 299 | " cv2.imwrite(image_dir+image_file % str(int(i/interval)).zfill(6), frame)\n", 300 | " i += 1 \n", 301 | " cap.release() \n", 302 | "\n", 303 | "# imagesフォルダーリセット\n", 304 | "if os.path.isdir('images'):\n", 305 | " shutil.rmtree('images')\n", 306 | "os.makedirs('images', exist_ok=True)\n", 307 | "\n", 308 | "# ビデオを静止画に変換\n", 309 | "video_2_images()" 310 | ], 311 | "execution_count": null, 312 | "outputs": [] 313 | }, 314 | { 315 | "cell_type": "code", 316 | "metadata": { 317 | "id": "rRRsELU4udx-" 318 | }, 319 | "source": [ 320 | "# 静止画をセグメンテーションに変換\n", 321 | "\n", 322 | "# 正規化データをロード\n", 323 | "pil_to_tensor = torchvision.transforms.Compose([\n", 324 | " 
torchvision.transforms.ToTensor(),\n", 325 | " torchvision.transforms.Normalize(\n", 326 | " mean=[0.485, 0.456, 0.406], # These are RGB mean+std values\n", 327 | " std=[0.229, 0.224, 0.225]) # across a large photo dataset.\n", 328 | "])\n", 329 | "\n", 330 | "# imagesフォルダーの静止画を1枚づつ処理\n", 331 | "from tqdm import tqdm\n", 332 | "import glob\n", 333 | "files = glob.glob('./images/*.jpg')\n", 334 | "files.sort()\n", 335 | "\n", 336 | "for file in tqdm(files):\n", 337 | " pil_image = PIL.Image.open(file).convert('RGB')\n", 338 | " img_original = numpy.array(pil_image)\n", 339 | " img_data = pil_to_tensor(pil_image)\n", 340 | " singleton_batch = {'img_data': img_data[None].cuda()}\n", 341 | " output_size = img_data.shape[1:]\n", 342 | "\n", 343 | " # セグメンテーションの実行\n", 344 | " with torch.no_grad():\n", 345 | " scores = segmentation_module(singleton_batch, segSize=output_size)\n", 346 | " \n", 347 | " # 予測結果の処理\n", 348 | " _, pred = torch.max(scores, dim=1)\n", 349 | " pred = pred.cpu()[0].numpy()\n", 350 | " pred_color = colorEncode(pred, colors).astype(numpy.uint8)\n", 351 | " im_vis = numpy.concatenate((img_original, pred_color), axis=1) # オリジナルと横連結\n", 352 | " #im_vis = numpy.concatenate((pred_color, img_original), axis=0) # オリジナルと縦連結\n", 353 | " PIL.Image.fromarray(im_vis).save(file) " 354 | ], 355 | "execution_count": null, 356 | "outputs": [] 357 | }, 358 | { 359 | "cell_type": "code", 360 | "metadata": { 361 | "id": "MpJ-104U4Fzq" 362 | }, 363 | "source": [ 364 | "# output.mp4をリセット\n", 365 | "if os.path.exists('./output.mp4'):\n", 366 | " os.remove('./output.mp4')\n", 367 | "\n", 368 | "# 実写+セグメンテーションをmp4動画に変換\n", 369 | "!ffmpeg -r 10 -i images/%06d.jpg -vcodec libx264 -pix_fmt yuv420p output.mp4" 370 | ], 371 | "execution_count": null, 372 | "outputs": [] 373 | }, 374 | { 375 | "cell_type": "code", 376 | "metadata": { 377 | "id": "uk_KCWp041gA" 378 | }, 379 | "source": [ 380 | "# mp4動画の再生\n", 381 | "from IPython.display import HTML\n", 382 | "from base64 import b64encode\n", 383 | "\n", 384 | "mp4 = open('./output.mp4', 'rb').read()\n", 385 | "data_url = 'data:video/mp4;base64,' + b64encode(mp4).decode()\n", 386 | "HTML(f\"\"\"\n", 387 | "\"\"\")" 390 | ], 391 | "execution_count": null, 392 | "outputs": [] 393 | } 394 | ] 395 | } -------------------------------------------------------------------------------- /mttr_interactive_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "mttr_interactive_demo", 7 | "private_outputs": true, 8 | "provenance": [], 9 | "collapsed_sections": [], 10 | "include_colab_link": true 11 | }, 12 | "kernelspec": { 13 | "name": "python3", 14 | "display_name": "Python 3" 15 | }, 16 | "language_info": { 17 | "name": "python" 18 | }, 19 | "accelerator": "GPU" 20 | }, 21 | "cells": [ 22 | { 23 | "cell_type": "markdown", 24 | "metadata": { 25 | "id": "view-in-github", 26 | "colab_type": "text" 27 | }, 28 | "source": [ 29 | "\"Open" 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "metadata": { 35 | "id": "c7ZB0W8mIOzx" 36 | }, 37 | "source": [ 38 | "**セットアップ**" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "metadata": { 44 | "id": "hbdn_2C-RO5Y" 45 | }, 46 | "source": [ 47 | "# moviepyインストール\n", 48 | "# %%capture\n", 49 | "!pip install av moviepy yt-dlp ruamel.yaml einops timm transformers\n", 50 | "\n", 51 | "# ライブラリのインポート\n", 52 | "import torch\n", 53 | "import torchvision\n", 54 | "import 
torchvision.transforms.functional as F\n", 55 | "from einops import rearrange\n", 56 | "import numpy as np\n", 57 | "from PIL import Image, ImageDraw, ImageOps, ImageFont\n", 58 | "from yt_dlp import YoutubeDL\n", 59 | "from moviepy.editor import VideoFileClip, AudioFileClip, ImageSequenceClip\n", 60 | "from moviepy.video.io.ffmpeg_tools import ffmpeg_extract_subclip\n", 61 | "from IPython.display import HTML\n", 62 | "from base64 import b64encode\n", 63 | "from tqdm.notebook import trange, tqdm\n", 64 | "from transformers import logging\n", 65 | "# logging.set_verbosity_error()\n", 66 | "\n", 67 | "# MTTRモデル初期化\n", 68 | "model, postprocessor = torch.hub.load('mttr2021/MTTR:main','mttr_refer_youtube_vos', force_reload=True)\n", 69 | "model = model.cuda()\n", 70 | "\n", 71 | "# 関数定義\n", 72 | "class NestedTensor(object):\n", 73 | " def __init__(self, tensors, mask):\n", 74 | " self.tensors = tensors\n", 75 | " self.mask = mask\n", 76 | "\n", 77 | "def nested_tensor_from_videos_list(videos_list):\n", 78 | " def _max_by_axis(the_list):\n", 79 | " maxes = the_list[0]\n", 80 | " for sublist in the_list[1:]:\n", 81 | " for index, item in enumerate(sublist):\n", 82 | " maxes[index] = max(maxes[index], item)\n", 83 | " return maxes\n", 84 | "\n", 85 | " max_size = _max_by_axis([list(img.shape) for img in videos_list])\n", 86 | " padded_batch_shape = [len(videos_list)] + max_size\n", 87 | " b, t, c, h, w = padded_batch_shape\n", 88 | " dtype = videos_list[0].dtype\n", 89 | " device = videos_list[0].device\n", 90 | " padded_videos = torch.zeros(padded_batch_shape, dtype=dtype, device=device)\n", 91 | " videos_pad_masks = torch.ones((b, t, h, w), dtype=torch.bool, device=device)\n", 92 | " for vid_frames, pad_vid_frames, vid_pad_m in zip(videos_list, padded_videos, videos_pad_masks):\n", 93 | " pad_vid_frames[:vid_frames.shape[0], :, :vid_frames.shape[2], :vid_frames.shape[3]].copy_(vid_frames)\n", 94 | " vid_pad_m[:vid_frames.shape[0], :vid_frames.shape[2], :vid_frames.shape[3]] = False\n", 95 | " return NestedTensor(padded_videos.transpose(0, 1), videos_pad_masks.transpose(0, 1))\n", 96 | "\n", 97 | "def apply_mask(image, mask, color, transparency=0.7):\n", 98 | " mask = mask[..., np.newaxis].repeat(repeats=3, axis=2)\n", 99 | " mask = mask * transparency\n", 100 | " color_matrix = np.ones(image.shape, dtype=np.float) * color\n", 101 | " out_image = color_matrix * mask + image * (1.0 - mask)\n", 102 | " return out_image\n" 103 | ], 104 | "execution_count": null, 105 | "outputs": [] 106 | }, 107 | { 108 | "cell_type": "markdown", 109 | "source": [ 110 | "**動画選択とクエリ指定**" 111 | ], 112 | "metadata": { 113 | "id": "I6g9MPbj8P2l" 114 | } 115 | }, 116 | { 117 | "cell_type": "code", 118 | "metadata": { 119 | "id": "XbAe8Djicw20" 120 | }, 121 | "source": [ 122 | "# Choose (by un-commenting) one of the following:\n", 123 | "# video_url, (start_pt, end_pt), text_queries = 'https://www.youtube.com/watch?v=YThX7_8I3m0', (233, 243), ['guy in black performing tricks on a bike', 'a black bike used to perform tricks']\n", 124 | "# video_url, (start_pt, end_pt), text_queries = 'https://www.youtube.com/watch?v=hwLo7aU1Aas', (1144, 1152), ['a man riding a surfboard', 'a black and white surfboard']\n", 125 | "# video_url, (start_pt, end_pt), text_queries = 'https://www.youtube.com/watch?v=yvJDHbrumak', (48, 55), ['a red ball thrown in the air', 'a black horse playing with a person']\n", 126 | "# video_url, (start_pt, end_pt), text_queries = 'https://www.youtube.com/watch?v=L-Wd4A8ESyk', (289, 297), ['a guy performing 
tricks on a skateboard', 'a black skateboard']\n", 127 | "# video_url, (start_pt, end_pt), text_queries = 'https://www.youtube.com/watch?v=4iTiRvk4FHY', (24, 34), ['man in red shirt playing tennis', 'white tennis racket held by a man in a red shirt']\n", 128 | "# video_url, (start_pt, end_pt), text_queries = 'https://www.youtube.com/watch?v=ZHwlmvuW4NY', (115, 125), ['white dog playing', 'brown and black dog playing']\n", 129 | "video_url, (start_pt, end_pt), text_queries = 'https://www.youtube.com/watch?v=YThX7_8I3m0', (67, 77), ['guy in white shirt performing tricks on a bike', 'a black bike used to perform tricks']\n", 130 | "# video_url, (start_pt, end_pt), text_queries = 'https://www.youtube.com/watch?v=C7TCH927--g', (3, 13), ['a dog to the right', 'a cat to the left']\n", 131 | "# video_url, (start_pt, end_pt), text_queries = 'https://www.youtube.com/watch?v=0Z_WAF1GKfk', (143.5, 147.5), ['a dog to the left playing with a toy', 'a dog to the right playing with a toy']\n", 132 | "# video_url, (start_pt, end_pt), text_queries = 'https://www.youtube.com/watch?v=aEJJmebTLEs', (70, 80), ['a person hugging a dog', 'a white dog sitting']\n", 133 | "# video_url, (start_pt, end_pt), text_queries = 'https://www.youtube.com/watch?v=8sDF8lflCTs' ,(15, 23), ['person in blue riding a bike']\n", 134 | "# video_url, (start_pt, end_pt), text_queries = 'https://www.youtube.com/watch?v=dQw4w9WgXcQ', (2.5, 7.5), ['a person dancing']\n", 135 | "\n", 136 | "\n", 137 | "#OR - try your using own input in the following format: (but keep in mind that performance may be limited!)\n", 138 | "# video_url, (start_pt, end_pt), text_queries = f'https://www.youtube.com/watch?v=???????' ,(start_pnt, end_pnt), ['text query 1', 'text query 2']\n", 139 | "\n", 140 | "assert 0 < end_pt - start_pt <= 10, 'error - the subclip length must be 0-10 seconds long'\n", 141 | "assert 1 <= len(text_queries) <= 2, 'error - 1-2 input text queries are expected'" 142 | ], 143 | "execution_count": null, 144 | "outputs": [] 145 | }, 146 | { 147 | "cell_type": "markdown", 148 | "metadata": { 149 | "id": "HqVBgST0N7NQ" 150 | }, 151 | "source": [ 152 | "**動画のダウンロード**" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "metadata": { 158 | "id": "BKkfRZZHdzJg" 159 | }, 160 | "source": [ 161 | "download_resolution = 360\n", 162 | "full_video_path = 'full_video.mp4'\n", 163 | "input_clip_path = 'input_clip.mp4'\n", 164 | "\n", 165 | "# download parameters:\n", 166 | "ydl_opts = {'format': f'best[height<={download_resolution}]', 'overwrites': True, 'outtmpl': full_video_path}\n", 167 | "# download the whole video:\n", 168 | "with YoutubeDL(ydl_opts) as ydl:\n", 169 | " ydl.download([video_url])\n", 170 | "\n", 171 | "# extract the relevant subclip:\n", 172 | "with VideoFileClip(full_video_path) as video:\n", 173 | " subclip = video.subclip(start_pt, end_pt)\n", 174 | " subclip.write_videofile(input_clip_path)\n", 175 | " \n", 176 | "# visualize the input clip:\n", 177 | "input_clip = open(input_clip_path,'rb').read()\n", 178 | "data_url = \"data:video/mp4;base64,\" + b64encode(input_clip).decode()\n", 179 | "HTML(\"\"\"\"\"\" % data_url)" 180 | ], 181 | "execution_count": null, 182 | "outputs": [] 183 | }, 184 | { 185 | "cell_type": "markdown", 186 | "metadata": { 187 | "id": "vtF7RoBRojcR" 188 | }, 189 | "source": [ 190 | "**インスタンスマスクの生成**" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "metadata": { 196 | "id": "WS4FhVRnwtur" 197 | }, 198 | "source": [ 199 | "window_length = 24 # length of window during inference\n", 200 | 
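# Inference runs on overlapping windows of frames rather than on the whole clip
# at once, which keeps GPU memory bounded for clips up to the 10-second limit
# enforced above: each window of `window_length` frames starts
# `window_length - window_overlap` frames after the previous one, and masks
# predicted for later windows overwrite those of the overlapping frames.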
"window_overlap = 6 # overlap (in frames) between consecutive windows\n", 201 | "\n", 202 | "with torch.inference_mode():\n", 203 | " # read and preprocess the video clip:\n", 204 | " video, audio, meta = torchvision.io.read_video(filename=input_clip_path)\n", 205 | " video = rearrange(video, 't h w c -> t c h w')\n", 206 | " input_video = F.resize(video, size=360, max_size=640).cuda()\n", 207 | " input_video = input_video.to(torch.float).div_(255)\n", 208 | " input_video = F.normalize(input_video, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n", 209 | " video_metadata = {'resized_frame_size': input_video.shape[-2:], 'original_frame_size': video.shape[-2:]}\n", 210 | " \n", 211 | " # partition the clip into overlapping windows of frames:\n", 212 | " windows = [input_video[i:i+window_length] for i in range(0, len(input_video), window_length - window_overlap)]\n", 213 | " # clean up the text queries:\n", 214 | " text_queries = [\" \".join(q.lower().split()) for q in text_queries]\n", 215 | "\n", 216 | " pred_masks_per_query = []\n", 217 | " t, _, h, w = video.shape\n", 218 | " for text_query in tqdm(text_queries, desc='text queries'):\n", 219 | " pred_masks = torch.zeros(size=(t, 1, h, w))\n", 220 | " for i, window in enumerate(tqdm(windows, desc='windows')):\n", 221 | " window = nested_tensor_from_videos_list([window])\n", 222 | " valid_indices = torch.arange(len(window.tensors)).cuda()\n", 223 | " outputs = model(window, valid_indices, [text_query])\n", 224 | " window_masks = postprocessor(outputs, [video_metadata], window.tensors.shape[-2:])[0]['pred_masks']\n", 225 | " win_start_idx = i*(window_length-window_overlap)\n", 226 | " pred_masks[win_start_idx:win_start_idx + window_length] = window_masks\n", 227 | " pred_masks_per_query.append(pred_masks)" 228 | ], 229 | "execution_count": null, 230 | "outputs": [] 231 | }, 232 | { 233 | "cell_type": "markdown", 234 | "metadata": { 235 | "id": "Th4tQP0Wo9W6" 236 | }, 237 | "source": [ 238 | "**動画にインスタンス・マスクとクエリを適用**" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "metadata": { 244 | "id": "deCqcgXiXUKL" 245 | }, 246 | "source": [ 247 | "# RGB colors for instance masks:\n", 248 | "light_blue = (41, 171, 226)\n", 249 | "purple = (237, 30, 121)\n", 250 | "dark_green = (35, 161, 90)\n", 251 | "orange = (255, 148, 59)\n", 252 | "colors = np.array([light_blue, purple, dark_green, orange])\n", 253 | "\n", 254 | "# width (in pixels) of the black strip above the video on which the text queries will be displayed:\n", 255 | "text_border_height_per_query = 35\n", 256 | "\n", 257 | "video_np = rearrange(video, 't c h w -> t h w c').numpy() / 255.0\n", 258 | "# del video\n", 259 | "pred_masks_per_frame = rearrange(torch.stack(pred_masks_per_query), 'q t 1 h w -> t q h w').numpy()\n", 260 | "masked_video = []\n", 261 | "for vid_frame, frame_masks in tqdm(zip(video_np, pred_masks_per_frame), total=len(video_np), desc='applying masks...'):\n", 262 | " # apply the masks:\n", 263 | " for inst_mask, color in zip(frame_masks, colors):\n", 264 | " vid_frame = apply_mask(vid_frame, inst_mask, color / 255.0)\n", 265 | " vid_frame = Image.fromarray((vid_frame * 255).astype(np.uint8))\n", 266 | " # visualize the text queries:\n", 267 | " vid_frame = ImageOps.expand(vid_frame, border=(0, len(text_queries)*text_border_height_per_query, 0, 0))\n", 268 | " W, H = vid_frame.size\n", 269 | " draw = ImageDraw.Draw(vid_frame)\n", 270 | " font = ImageFont.truetype(font='LiberationSans-Regular.ttf', size=30)\n", 271 | " for i, (text_query, color) in 
enumerate(zip(text_queries, colors), start=1):\n", 272 | " w, h = draw.textsize(text_query, font=font)\n", 273 | " draw.text(((W - w) / 2, (text_border_height_per_query * i) - h - 3),\n", 274 | " text_query, fill=tuple(color) + (255,), font=font)\n", 275 | " masked_video.append(np.array(vid_frame))\n", 276 | "\n", 277 | "# generate and save the output clip:\n", 278 | "output_clip_path = 'output_clip.mp4'\n", 279 | "clip = ImageSequenceClip(sequence=masked_video, fps=meta['video_fps'])\n", 280 | "clip = clip.set_audio(AudioFileClip(input_clip_path))\n", 281 | "clip.write_videofile(output_clip_path, fps=meta['video_fps'], audio=True)\n", 282 | "del masked_video\n", 283 | "\n", 284 | "# visualize the output clip:\n", 285 | "output_clip = open(output_clip_path,'rb').read()\n", 286 | "data_url = \"data:video/mp4;base64,\" + b64encode(output_clip).decode()\n", 287 | "HTML(\"\"\"\"\"\" % data_url)" 288 | ], 289 | "execution_count": null, 290 | "outputs": [] 291 | } 292 | ] 293 | } -------------------------------------------------------------------------------- /ArcaneGAN.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "ArcaneGAN", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "include_colab_link": true 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "language_info": { 16 | "name": "python" 17 | }, 18 | "accelerator": "GPU" 19 | }, 20 | "cells": [ 21 | { 22 | "cell_type": "markdown", 23 | "metadata": { 24 | "id": "view-in-github", 25 | "colab_type": "text" 26 | }, 27 | "source": [ 28 | "\"Open" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "metadata": { 34 | "id": "GXqfcKRpS5Bi", 35 | "cellView": "form" 36 | }, 37 | "source": [ 38 | "#@title インストール\n", 39 | "#release v0.2\n", 40 | "!wget https://github.com/Sxela/ArcaneGAN/releases/download/v0.1/ArcaneGANv0.1.jit\n", 41 | "!wget https://github.com/Sxela/ArcaneGAN/releases/download/v0.2/ArcaneGANv0.2.jit\n", 42 | "!wget https://github.com/Sxela/ArcaneGAN/releases/download/v0.3/ArcaneGANv0.3.jit\n", 43 | "!pip -qq install facenet_pytorch\n", 44 | "\n", 45 | "# サンプル動画ダウンロード\n", 46 | "! pip install --upgrade gdown\n", 47 | "import gdown\n", 48 | "gdown.download('https://drive.google.com/uc?id=16ei31SsXRqjDM1h6FNeQJALKnbb_huyS', './movies.zip', quiet=False)\n", 49 | "! 
unzip movies.zip" 50 | ], 51 | "execution_count": null, 52 | "outputs": [] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "metadata": { 57 | "id": "Mm7x7XgxUUwv", 58 | "cellView": "form" 59 | }, 60 | "source": [ 61 | "#@title 初期設定\n", 62 | "#@markdown Select model version\n", 63 | "version = '0.3' #@param ['0.1','0.2','0.3']\n", 64 | "out_x_size = '1280' #@param {type:\"string\"}\n", 65 | "out_y_size = '720' #@param {type:\"string\"}\n", 66 | "x_size = int(out_x_size)\n", 67 | "y_size = int(out_y_size)\n", 68 | "\n", 69 | "from facenet_pytorch import MTCNN\n", 70 | "from torchvision import transforms\n", 71 | "import torch, PIL\n", 72 | "\n", 73 | "from tqdm.notebook import tqdm\n", 74 | "\n", 75 | "mtcnn = MTCNN(image_size=256, margin=80)\n", 76 | "\n", 77 | "# simplest ye olde trustworthy MTCNN for face detection with landmarks\n", 78 | "def detect(img):\n", 79 | " \n", 80 | " # Detect faces\n", 81 | " batch_boxes, batch_probs, batch_points = mtcnn.detect(img, landmarks=True)\n", 82 | " # Select faces\n", 83 | " if not mtcnn.keep_all:\n", 84 | " batch_boxes, batch_probs, batch_points = mtcnn.select_boxes(\n", 85 | " batch_boxes, batch_probs, batch_points, img, method=mtcnn.selection_method\n", 86 | " )\n", 87 | " \n", 88 | " return batch_boxes, batch_points\n", 89 | "\n", 90 | "# my version of isOdd, should make a separate repo for it :D\n", 91 | "def makeEven(_x):\n", 92 | " return _x if (_x % 2 == 0) else _x+1\n", 93 | "\n", 94 | "# the actual scaler function\n", 95 | "def scale(boxes, _img, max_res=1_500_000, target_face=256, fixed_ratio=0, max_upscale=2, VERBOSE=False):\n", 96 | " \n", 97 | " x, y = _img.size\n", 98 | " \n", 99 | " ratio = 2 #initial ratio\n", 100 | " \n", 101 | " #scale to desired face size\n", 102 | " if (boxes is not None):\n", 103 | " if len(boxes)>0:\n", 104 | " ratio = target_face/max(boxes[0][2:]-boxes[0][:2]); \n", 105 | " ratio = min(ratio, max_upscale)\n", 106 | " if VERBOSE: print('up by', ratio)\n", 107 | "\n", 108 | " if fixed_ratio>0:\n", 109 | " if VERBOSE: print('fixed ratio')\n", 110 | " ratio = fixed_ratio\n", 111 | " \n", 112 | " x*=ratio\n", 113 | " y*=ratio\n", 114 | " \n", 115 | " #downscale to fit into max res \n", 116 | " res = x*y\n", 117 | " if res > max_res:\n", 118 | " ratio = pow(res/max_res,1/2); \n", 119 | " if VERBOSE: print(ratio)\n", 120 | " x=int(x/ratio)\n", 121 | " y=int(y/ratio)\n", 122 | " \n", 123 | " #make dimensions even, because usually NNs fail on uneven dimensions due skip connection size mismatch\n", 124 | " x = makeEven(int(x))\n", 125 | " y = makeEven(int(y))\n", 126 | " \n", 127 | " size = (x, y)\n", 128 | "\n", 129 | " return _img.resize(size)\n", 130 | "\n", 131 | "\"\"\" \n", 132 | " A useful scaler algorithm, based on face detection.\n", 133 | " Takes PIL.Image, returns a uniformly scaled PIL.Image\n", 134 | " boxes: a list of detected bboxes\n", 135 | " _img: PIL.Image\n", 136 | " max_res: maximum pixel area to fit into. Use to stay below the VRAM limits of your GPU.\n", 137 | " target_face: desired face size. Upscale or downscale the whole image to fit the detected face into that dimension.\n", 138 | " fixed_ratio: fixed scale. Ignores the face size, but doesn't ignore the max_res limit.\n", 139 | " max_upscale: maximum upscale ratio. 
Prevents from scaling images with tiny faces to a blurry mess.\n", 140 | "\"\"\"\n", 141 | "\n", 142 | "def scale_by_face_size(_img, max_res=1_500_000, target_face=256, fix_ratio=0, max_upscale=2, VERBOSE=False):\n", 143 | " boxes = None\n", 144 | " boxes, _ = detect(_img)\n", 145 | " if VERBOSE: print('boxes',boxes)\n", 146 | " img_resized = scale(boxes, _img, max_res, target_face, fix_ratio, max_upscale, VERBOSE)\n", 147 | " return img_resized.resize((x_size, y_size))\n", 148 | "\n", 149 | "\n", 150 | "size = 256\n", 151 | "\n", 152 | "means = [0.485, 0.456, 0.406]\n", 153 | "stds = [0.229, 0.224, 0.225]\n", 154 | "\n", 155 | "t_stds = torch.tensor(stds).cuda().half()[:,None,None]\n", 156 | "t_means = torch.tensor(means).cuda().half()[:,None,None]\n", 157 | "\n", 158 | "def makeEven(_x):\n", 159 | " return int(_x) if (_x % 2 == 0) else int(_x+1)\n", 160 | "\n", 161 | "img_transforms = transforms.Compose([ \n", 162 | " transforms.ToTensor(),\n", 163 | " transforms.Normalize(means,stds)])\n", 164 | " \n", 165 | "def tensor2im(var):\n", 166 | " return var.mul(t_stds).add(t_means).mul(255.).clamp(0,255).permute(1,2,0)\n", 167 | "\n", 168 | "def proc_pil_img(input_image, model):\n", 169 | " transformed_image = img_transforms(input_image)[None,...].cuda().half()\n", 170 | " \n", 171 | " with torch.no_grad():\n", 172 | " result_image = model(transformed_image)[0]; print(result_image.shape)\n", 173 | " output_image = tensor2im(result_image)\n", 174 | " output_image = output_image.detach().cpu().numpy().astype('uint8')\n", 175 | " output_image = PIL.Image.fromarray(output_image)\n", 176 | " return output_image\n", 177 | "\n", 178 | "#load model\n", 179 | "model_path = f'/content/ArcaneGANv{version}.jit' \n", 180 | "in_dir = '/content/in'\n", 181 | "out_dir = f\"/content/{model_path.split('/')[-1][:-4]}_out\"\n", 182 | "\n", 183 | "model = torch.jit.load(model_path).eval().cuda().half()\n", 184 | "\n", 185 | "#setup colab interface\n", 186 | "\n", 187 | "from google.colab import files\n", 188 | "import ipywidgets as widgets\n", 189 | "from IPython.display import clear_output \n", 190 | "from IPython.display import display\n", 191 | "import os\n", 192 | "from glob import glob\n", 193 | "\n", 194 | "def reset(p):\n", 195 | " with output_reset:\n", 196 | " clear_output()\n", 197 | " clear_output()\n", 198 | " process()\n", 199 | " \n", 200 | "button_reset = widgets.Button(description=\"Upload\")\n", 201 | "output_reset = widgets.Output()\n", 202 | "button_reset.on_click(reset)\n", 203 | "\n", 204 | "def fit(img,maxsize=512):\n", 205 | " maxdim = max(*img.size)\n", 206 | " if maxdim>maxsize:\n", 207 | " ratio = maxsize/maxdim\n", 208 | " x,y = img.size\n", 209 | " size = (int(x*ratio),int(y*ratio)) \n", 210 | " img = img.resize(size)\n", 211 | " return img\n", 212 | " \n", 213 | "def show_img(f, size=1024):\n", 214 | " display(fit(PIL.Image.open(f),size))\n", 215 | "\n", 216 | "def process(upload=False):\n", 217 | " os.makedirs(in_dir, exist_ok=True)\n", 218 | " %cd {in_dir}/\n", 219 | " !rm -rf {out_dir}/*\n", 220 | " os.makedirs(out_dir, exist_ok=True)\n", 221 | " in_files = sorted(glob(f'{in_dir}/*'))\n", 222 | " if (len(in_files)==0) | (upload):\n", 223 | " !rm -rf {in_dir}/*\n", 224 | " uploaded = files.upload()\n", 225 | " if len(uploaded.keys())<=0: \n", 226 | " print('\\nNo files were uploaded. 
Try again..\\n')\n", 227 | " return\n", 228 | " \n", 229 | " in_files = sorted(glob(f'{in_dir}/*'))\n", 230 | " for img in tqdm(in_files):\n", 231 | " out = f\"{out_dir}/{img.split('/')[-1].split('.')[0]}.jpg\"\n", 232 | " im = PIL.Image.open(img)\n", 233 | " im = scale_by_face_size(im, target_face=300, max_res=1_500_000, max_upscale=2)\n", 234 | " res = proc_pil_img(im, model)\n", 235 | " #res = res.resize((1280, 720)) ###resize\n", 236 | " res.save(out)\n", 237 | "\n", 238 | " #out_zip = f\"{out_dir}.zip\"\n", 239 | " #!zip {out_zip} {out_dir}/*\n", 240 | " \n", 241 | " processed = sorted(glob(f'{out_dir}/*'))[:3]\n", 242 | " for f in processed: \n", 243 | " show_img(f, 256)\n" 244 | ], 245 | "execution_count": null, 246 | "outputs": [] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "source": [ 251 | "#@title 動画を静止画にバラす\n", 252 | "movie = '01.mp4' #@param {type:\"string\"}\n", 253 | "video_file = '/content/'+movie\n", 254 | "\n", 255 | "import os\n", 256 | "import shutil\n", 257 | "import cv2\n", 258 | "\n", 259 | "# flamesフォルダーリセット\n", 260 | "if os.path.isdir('/content/in'):\n", 261 | " shutil.rmtree('/content/in')\n", 262 | "os.makedirs('/content/in', exist_ok=True)\n", 263 | " \n", 264 | "def video_2_images(video_file= video_file, # ビデオの指定\n", 265 | " image_dir='/content/in/', \n", 266 | " image_file='%s.jpg'): \n", 267 | "\n", 268 | " # Initial setting\n", 269 | " i = 0\n", 270 | " interval = 1\n", 271 | " length = 3000 # 最大フレーム数\n", 272 | " \n", 273 | " cap = cv2.VideoCapture(video_file)\n", 274 | " fps = cap.get(cv2.CAP_PROP_FPS) # fps取得\n", 275 | "\n", 276 | " while(cap.isOpened()):\n", 277 | " flag, frame = cap.read() \n", 278 | " if flag == False: \n", 279 | " break\n", 280 | " if i == length*interval:\n", 281 | " break\n", 282 | " if i % interval == 0: \n", 283 | " cv2.imwrite(image_dir+image_file % str(int(i/interval)).zfill(6), frame)\n", 284 | " i += 1 \n", 285 | " cap.release()\n", 286 | " return fps, i, interval\n", 287 | " \n", 288 | "fps, i, interval = video_2_images()\n", 289 | "print('fps = ', fps)\n", 290 | "print('flames = ', i)\n", 291 | "print('interval = ', interval)" 292 | ], 293 | "metadata": { 294 | "id": "ox5OqfitQiby", 295 | "cellView": "form" 296 | }, 297 | "execution_count": null, 298 | "outputs": [] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "metadata": { 303 | "id": "tdePnlXFX7x8", 304 | "cellView": "form" 305 | }, 306 | "source": [ 307 | "#@title 静止画をアニメに変換\n", 308 | "process()\n", 309 | "%cd ..\n", 310 | "\n", 311 | "# コード内でカレントディレクトリを/content/inに移しているので、最後に/contentに戻す\n", 312 | "# そうしないと、動画から静止画をバラすときに/content/inを一旦削除するためカレントディレクトリを見失うため" 313 | ], 314 | "execution_count": null, 315 | "outputs": [] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "source": [ 320 | "#@title アニメから動画を作成\n", 321 | "\n", 322 | "# リセットファイル\n", 323 | "if os.path.exists('/content/output.mp4'):\n", 324 | " os.remove('/content/output.mp4')\n", 325 | "\n", 326 | "if version == '0.1':\n", 327 | " ! ffmpeg -r $fps -i /content/ArcaneGANv0.1_out/%06d.jpg -vcodec libx264 -pix_fmt yuv420p /content/output.mp4\n", 328 | "if version == '0.2':\n", 329 | " ! ffmpeg -r $fps -i /content/ArcaneGANv0.2_out/%06d.jpg -vcodec libx264 -pix_fmt yuv420p /content/output.mp4\n", 330 | "if version == '0.3':\n", 331 | " ! 
ffmpeg -r $fps -i /content/ArcaneGANv0.3_out/%06d.jpg -vcodec libx264 -pix_fmt yuv420p /content/output.mp4" 332 | ], 333 | "metadata": { 334 | "cellView": "form", 335 | "id": "j61Ga1xRFzbd" 336 | }, 337 | "execution_count": null, 338 | "outputs": [] 339 | }, 340 | { 341 | "cell_type": "code", 342 | "source": [ 343 | "#@title 動画の再生\n", 344 | "from IPython.display import HTML\n", 345 | "from base64 import b64encode\n", 346 | "\n", 347 | "mp4 = open('/content/output.mp4', 'rb').read()\n", 348 | "data_url = 'data:video/mp4;base64,' + b64encode(mp4).decode()\n", 349 | "HTML(f\"\"\"\n", 350 | "\"\"\")" 353 | ], 354 | "metadata": { 355 | "cellView": "form", 356 | "id": "jF3YY8DOAghg" 357 | }, 358 | "execution_count": null, 359 | "outputs": [] 360 | } 361 | ] 362 | } 363 | -------------------------------------------------------------------------------- /ArcaneGAN_latest.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "ArcaneGAN_latest", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "include_colab_link": true 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "language_info": { 16 | "name": "python" 17 | }, 18 | "accelerator": "GPU" 19 | }, 20 | "cells": [ 21 | { 22 | "cell_type": "markdown", 23 | "metadata": { 24 | "id": "view-in-github", 25 | "colab_type": "text" 26 | }, 27 | "source": [ 28 | "\"Open" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "metadata": { 34 | "id": "GXqfcKRpS5Bi", 35 | "cellView": "form" 36 | }, 37 | "source": [ 38 | "#@title インストール\n", 39 | "#release v0.2\n", 40 | "!wget https://github.com/Sxela/ArcaneGAN/releases/download/v0.1/ArcaneGANv0.1.jit\n", 41 | "!wget https://github.com/Sxela/ArcaneGAN/releases/download/v0.2/ArcaneGANv0.2.jit\n", 42 | "!wget https://github.com/Sxela/ArcaneGAN/releases/download/v0.3/ArcaneGANv0.3.jit\n", 43 | "!wget https://github.com/Sxela/ArcaneGAN/releases/download/v0.4/ArcaneGANv0.4.jit\n", 44 | "!pip -qq install facenet_pytorch\n", 45 | "\n", 46 | "# サンプル動画ダウンロード\n", 47 | "import gdown\n", 48 | "gdown.download('https://drive.google.com/uc?id=16ei31SsXRqjDM1h6FNeQJALKnbb_huyS', './movies.zip', quiet=False)\n", 49 | "! 
unzip movies.zip" 50 | ], 51 | "execution_count": null, 52 | "outputs": [] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "metadata": { 57 | "id": "Mm7x7XgxUUwv", 58 | "cellView": "form" 59 | }, 60 | "source": [ 61 | "#@title 初期設定\n", 62 | "#@markdown Select model version\n", 63 | "version = '0.4' #@param ['0.1','0.2','0.3','0.4']\n", 64 | "out_x_size = '1280' #@param {type:\"string\"}\n", 65 | "out_y_size = '720' #@param {type:\"string\"}\n", 66 | "x_size = int(out_x_size)\n", 67 | "y_size = int(out_y_size)\n", 68 | "\n", 69 | "from facenet_pytorch import MTCNN\n", 70 | "from torchvision import transforms\n", 71 | "import torch, PIL\n", 72 | "\n", 73 | "from tqdm.notebook import tqdm\n", 74 | "\n", 75 | "mtcnn = MTCNN(image_size=256, margin=80)\n", 76 | "\n", 77 | "# simplest ye olde trustworthy MTCNN for face detection with landmarks\n", 78 | "def detect(img):\n", 79 | " \n", 80 | " # Detect faces\n", 81 | " batch_boxes, batch_probs, batch_points = mtcnn.detect(img, landmarks=True)\n", 82 | " # Select faces\n", 83 | " if not mtcnn.keep_all:\n", 84 | " batch_boxes, batch_probs, batch_points = mtcnn.select_boxes(\n", 85 | " batch_boxes, batch_probs, batch_points, img, method=mtcnn.selection_method\n", 86 | " )\n", 87 | " \n", 88 | " return batch_boxes, batch_points\n", 89 | "\n", 90 | "# my version of isOdd, should make a separate repo for it :D\n", 91 | "def makeEven(_x):\n", 92 | " return _x if (_x % 2 == 0) else _x+1\n", 93 | "\n", 94 | "# the actual scaler function\n", 95 | "def scale(boxes, _img, max_res=1_500_000, target_face=256, fixed_ratio=0, max_upscale=2, VERBOSE=False):\n", 96 | " \n", 97 | " x, y = _img.size\n", 98 | " \n", 99 | " ratio = 2 #initial ratio\n", 100 | " \n", 101 | " #scale to desired face size\n", 102 | " if (boxes is not None):\n", 103 | " if len(boxes)>0:\n", 104 | " ratio = target_face/max(boxes[0][2:]-boxes[0][:2]); \n", 105 | " ratio = min(ratio, max_upscale)\n", 106 | " if VERBOSE: print('up by', ratio)\n", 107 | "\n", 108 | " if fixed_ratio>0:\n", 109 | " if VERBOSE: print('fixed ratio')\n", 110 | " ratio = fixed_ratio\n", 111 | " \n", 112 | " x*=ratio\n", 113 | " y*=ratio\n", 114 | " \n", 115 | " #downscale to fit into max res \n", 116 | " res = x*y\n", 117 | " if res > max_res:\n", 118 | " ratio = pow(res/max_res,1/2); \n", 119 | " if VERBOSE: print(ratio)\n", 120 | " x=int(x/ratio)\n", 121 | " y=int(y/ratio)\n", 122 | " \n", 123 | " #make dimensions even, because usually NNs fail on uneven dimensions due skip connection size mismatch\n", 124 | " x = makeEven(int(x))\n", 125 | " y = makeEven(int(y))\n", 126 | " \n", 127 | " size = (x, y)\n", 128 | "\n", 129 | " return _img.resize(size)\n", 130 | "\n", 131 | "\"\"\" \n", 132 | " A useful scaler algorithm, based on face detection.\n", 133 | " Takes PIL.Image, returns a uniformly scaled PIL.Image\n", 134 | " boxes: a list of detected bboxes\n", 135 | " _img: PIL.Image\n", 136 | " max_res: maximum pixel area to fit into. Use to stay below the VRAM limits of your GPU.\n", 137 | " target_face: desired face size. Upscale or downscale the whole image to fit the detected face into that dimension.\n", 138 | " fixed_ratio: fixed scale. Ignores the face size, but doesn't ignore the max_res limit.\n", 139 | " max_upscale: maximum upscale ratio. 
Prevents from scaling images with tiny faces to a blurry mess.\n", 140 | "\"\"\"\n", 141 | "\n", 142 | "def scale_by_face_size(_img, max_res=1_500_000, target_face=256, fix_ratio=0, max_upscale=2, VERBOSE=False):\n", 143 | " boxes = None\n", 144 | " boxes, _ = detect(_img)\n", 145 | " if VERBOSE: print('boxes',boxes)\n", 146 | " img_resized = scale(boxes, _img, max_res, target_face, fix_ratio, max_upscale, VERBOSE)\n", 147 | " return img_resized.resize((x_size, y_size))\n", 148 | "\n", 149 | "\n", 150 | "size = 256\n", 151 | "\n", 152 | "means = [0.485, 0.456, 0.406]\n", 153 | "stds = [0.229, 0.224, 0.225]\n", 154 | "\n", 155 | "t_stds = torch.tensor(stds).cuda().half()[:,None,None]\n", 156 | "t_means = torch.tensor(means).cuda().half()[:,None,None]\n", 157 | "\n", 158 | "def makeEven(_x):\n", 159 | " return int(_x) if (_x % 2 == 0) else int(_x+1)\n", 160 | "\n", 161 | "img_transforms = transforms.Compose([ \n", 162 | " transforms.ToTensor(),\n", 163 | " transforms.Normalize(means,stds)])\n", 164 | " \n", 165 | "def tensor2im(var):\n", 166 | " return var.mul(t_stds).add(t_means).mul(255.).clamp(0,255).permute(1,2,0)\n", 167 | "\n", 168 | "def proc_pil_img(input_image, model):\n", 169 | " transformed_image = img_transforms(input_image)[None,...].cuda().half()\n", 170 | " \n", 171 | " with torch.no_grad():\n", 172 | " result_image = model(transformed_image)[0]; print(result_image.shape)\n", 173 | " output_image = tensor2im(result_image)\n", 174 | " output_image = output_image.detach().cpu().numpy().astype('uint8')\n", 175 | " output_image = PIL.Image.fromarray(output_image)\n", 176 | " return output_image\n", 177 | "\n", 178 | "#load model\n", 179 | "model_path = f'/content/ArcaneGANv{version}.jit' \n", 180 | "in_dir = '/content/in'\n", 181 | "out_dir = f\"/content/{model_path.split('/')[-1][:-4]}_out\"\n", 182 | "\n", 183 | "model = torch.jit.load(model_path).eval().cuda().half()\n", 184 | "\n", 185 | "#setup colab interface\n", 186 | "\n", 187 | "from google.colab import files\n", 188 | "import ipywidgets as widgets\n", 189 | "from IPython.display import clear_output \n", 190 | "from IPython.display import display\n", 191 | "import os\n", 192 | "from glob import glob\n", 193 | "\n", 194 | "def reset(p):\n", 195 | " with output_reset:\n", 196 | " clear_output()\n", 197 | " clear_output()\n", 198 | " process()\n", 199 | " \n", 200 | "button_reset = widgets.Button(description=\"Upload\")\n", 201 | "output_reset = widgets.Output()\n", 202 | "button_reset.on_click(reset)\n", 203 | "\n", 204 | "def fit(img,maxsize=512):\n", 205 | " maxdim = max(*img.size)\n", 206 | " if maxdim>maxsize:\n", 207 | " ratio = maxsize/maxdim\n", 208 | " x,y = img.size\n", 209 | " size = (int(x*ratio),int(y*ratio)) \n", 210 | " img = img.resize(size)\n", 211 | " return img\n", 212 | " \n", 213 | "def show_img(f, size=1024):\n", 214 | " display(fit(PIL.Image.open(f),size))\n", 215 | "\n", 216 | "def process(upload=False):\n", 217 | " os.makedirs(in_dir, exist_ok=True)\n", 218 | " %cd {in_dir}/\n", 219 | " !rm -rf {out_dir}/*\n", 220 | " os.makedirs(out_dir, exist_ok=True)\n", 221 | " in_files = sorted(glob(f'{in_dir}/*'))\n", 222 | " if (len(in_files)==0) | (upload):\n", 223 | " !rm -rf {in_dir}/*\n", 224 | " uploaded = files.upload()\n", 225 | " if len(uploaded.keys())<=0: \n", 226 | " print('\\nNo files were uploaded. 
Try again..\\n')\n", 227 | " return\n", 228 | " \n", 229 | " in_files = sorted(glob(f'{in_dir}/*'))\n", 230 | " for img in tqdm(in_files):\n", 231 | " out = f\"{out_dir}/{img.split('/')[-1].split('.')[0]}.jpg\"\n", 232 | " im = PIL.Image.open(img)\n", 233 | " im = scale_by_face_size(im, target_face=300, max_res=1_500_000, max_upscale=2)\n", 234 | " res = proc_pil_img(im, model)\n", 235 | " #res = res.resize((1280, 720)) ###resize\n", 236 | " res.save(out)\n", 237 | "\n", 238 | " #out_zip = f\"{out_dir}.zip\"\n", 239 | " #!zip {out_zip} {out_dir}/*\n", 240 | " \n", 241 | " processed = sorted(glob(f'{out_dir}/*'))[:3]\n", 242 | " for f in processed: \n", 243 | " show_img(f, 256)\n" 244 | ], 245 | "execution_count": null, 246 | "outputs": [] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "source": [ 251 | "#@title 動画を静止画にバラす\n", 252 | "movie = '01.mp4' #@param {type:\"string\"}\n", 253 | "video_file = '/content/'+movie\n", 254 | "\n", 255 | "import os\n", 256 | "import shutil\n", 257 | "import cv2\n", 258 | "\n", 259 | "# flamesフォルダーリセット\n", 260 | "if os.path.isdir('/content/in'):\n", 261 | " shutil.rmtree('/content/in')\n", 262 | "os.makedirs('/content/in', exist_ok=True)\n", 263 | " \n", 264 | "def video_2_images(video_file= video_file, # ビデオの指定\n", 265 | " image_dir='/content/in/', \n", 266 | " image_file='%s.jpg'): \n", 267 | "\n", 268 | " # Initial setting\n", 269 | " i = 0\n", 270 | " interval = 1\n", 271 | " length = 3000 # 最大フレーム数\n", 272 | " \n", 273 | " cap = cv2.VideoCapture(video_file)\n", 274 | " fps = cap.get(cv2.CAP_PROP_FPS) # fps取得\n", 275 | "\n", 276 | " while(cap.isOpened()):\n", 277 | " flag, frame = cap.read() \n", 278 | " if flag == False: \n", 279 | " break\n", 280 | " if i == length*interval:\n", 281 | " break\n", 282 | " if i % interval == 0: \n", 283 | " cv2.imwrite(image_dir+image_file % str(int(i/interval)).zfill(6), frame)\n", 284 | " i += 1 \n", 285 | " cap.release()\n", 286 | " return fps, i, interval\n", 287 | " \n", 288 | "fps, i, interval = video_2_images()\n", 289 | "print('fps = ', fps)\n", 290 | "print('flames = ', i)\n", 291 | "print('interval = ', interval)" 292 | ], 293 | "metadata": { 294 | "id": "ox5OqfitQiby", 295 | "cellView": "form" 296 | }, 297 | "execution_count": null, 298 | "outputs": [] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "metadata": { 303 | "id": "tdePnlXFX7x8", 304 | "cellView": "form" 305 | }, 306 | "source": [ 307 | "#@title 静止画をアニメに変換\n", 308 | "process()\n", 309 | "%cd ..\n", 310 | "\n", 311 | "# コード内でカレントディレクトリを/content/inに移しているので、最後に/contentに戻す\n", 312 | "# そうしないと、動画から静止画をバラすときに/content/inを一旦削除するためカレントディレクトリを見失うため" 313 | ], 314 | "execution_count": null, 315 | "outputs": [] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "source": [ 320 | "#@title アニメから動画を作成\n", 321 | "\n", 322 | "# リセットファイル\n", 323 | "if os.path.exists('/content/output.mp4'):\n", 324 | " os.remove('/content/output.mp4')\n", 325 | "\n", 326 | "if version == '0.1':\n", 327 | " ! ffmpeg -r $fps -i /content/ArcaneGANv0.1_out/%06d.jpg -vcodec libx264 -pix_fmt yuv420p /content/output.mp4\n", 328 | "if version == '0.2':\n", 329 | " ! ffmpeg -r $fps -i /content/ArcaneGANv0.2_out/%06d.jpg -vcodec libx264 -pix_fmt yuv420p /content/output.mp4\n", 330 | "if version == '0.3':\n", 331 | " ! ffmpeg -r $fps -i /content/ArcaneGANv0.3_out/%06d.jpg -vcodec libx264 -pix_fmt yuv420p /content/output.mp4\n", 332 | "if version == '0.4':\n", 333 | " ! 
ffmpeg -r $fps -i /content/ArcaneGANv0.4_out/%06d.jpg -vcodec libx264 -pix_fmt yuv420p /content/output.mp4\n" 334 | ], 335 | "metadata": { 336 | "cellView": "form", 337 | "id": "j61Ga1xRFzbd" 338 | }, 339 | "execution_count": null, 340 | "outputs": [] 341 | }, 342 | { 343 | "cell_type": "code", 344 | "source": [ 345 | "#@title 動画の再生\n", 346 | "from IPython.display import HTML\n", 347 | "from base64 import b64encode\n", 348 | "\n", 349 | "mp4 = open('/content/output.mp4', 'rb').read()\n", 350 | "data_url = 'data:video/mp4;base64,' + b64encode(mp4).decode()\n", 351 | "HTML(f\"\"\"\n", 352 | "\"\"\")" 355 | ], 356 | "metadata": { 357 | "cellView": "form", 358 | "id": "jF3YY8DOAghg" 359 | }, 360 | "execution_count": null, 361 | "outputs": [] 362 | } 363 | ] 364 | } -------------------------------------------------------------------------------- /RIS_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "accelerator": "GPU", 6 | "colab": { 7 | "name": "RIS_demo", 8 | "provenance": [], 9 | "collapsed_sections": [], 10 | "include_colab_link": true 11 | }, 12 | "kernelspec": { 13 | "display_name": "Python 3", 14 | "language": "python", 15 | "name": "python3" 16 | }, 17 | "language_info": { 18 | "codemirror_mode": { 19 | "name": "ipython", 20 | "version": 3 21 | }, 22 | "file_extension": ".py", 23 | "mimetype": "text/x-python", 24 | "name": "python", 25 | "nbconvert_exporter": "python", 26 | "pygments_lexer": "ipython3", 27 | "version": "3.7.10" 28 | } 29 | }, 30 | "cells": [ 31 | { 32 | "cell_type": "markdown", 33 | "metadata": { 34 | "id": "view-in-github", 35 | "colab_type": "text" 36 | }, 37 | "source": [ 38 | "\"Open" 39 | ] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "metadata": { 44 | "id": "GRFbyn-ay_Ei" 45 | }, 46 | "source": [ 47 | "# セットアップ" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "metadata": { 53 | "id": "dGqP06Rtm9TE" 54 | }, 55 | "source": [ 56 | "# githubからコードをコピー\n", 57 | "!git clone https://github.com/mchong6/RetrieveInStyle.git\n", 58 | "%cd RetrieveInStyle\n", 59 | "\n", 60 | "# ライブラリーのインストール\n", 61 | "!pip install tqdm gdown scikit-learn scipy lpips dlib opencv-python\n", 62 | "\n", 63 | "# ライブラリーのインポート\n", 64 | "import torch\n", 65 | "from torch import nn\n", 66 | "import numpy as np\n", 67 | "import torch.backends.cudnn as cudnn\n", 68 | "cudnn.benchmark = True\n", 69 | "import matplotlib.pyplot as plt\n", 70 | "import torch.nn.functional as F\n", 71 | "from model import *\n", 72 | "from spherical_kmeans import MiniBatchSphericalKMeans as sKmeans\n", 73 | "from tqdm import tqdm as tqdm\n", 74 | "import pickle\n", 75 | "import warnings\n", 76 | "warnings.filterwarnings(\"ignore\", category=UserWarning) # get rid of interpolation warning\n", 77 | "from util import *\n", 78 | "from google.colab import files\n", 79 | "from util import align_face\n", 80 | "import os\n", 81 | "from e4e_projection import projection\n", 82 | "%matplotlib inline\n", 83 | "\n", 84 | "# 学習済みモデルのロード\n", 85 | "device = 'cuda' # if GPU memory is low, use cpu instead\n", 86 | "generator = Generator(1024, 512, 8, channel_multiplier=2).to(device).eval()\n", 87 | "ensure_checkpoint_exists('stylegan2-ffhq-config-f.pt')\n", 88 | "ckpt = torch.load('stylegan2-ffhq-config-f.pt', map_location=lambda storage, loc: storage)\n", 89 | "generator.load_state_dict(ckpt[\"g_ema\"], strict=False)\n", 90 | "with torch.no_grad():\n", 91 | " mean_latent = generator.mean_latent(50000)\n", 92 | "\n", 
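# `generator.mean_latent(50000)` estimates the average W-space latent from 50k
# random z vectors; together with `truncation` below it implements the usual
# StyleGAN truncation trick, w' = w_mean + truncation * (w - w_mean), so
# truncation = 0.5 pulls sampled faces toward the average face (more fidelity,
# less diversity).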
93 | "# カタログのロード\n", 94 | "truncation = 0.5\n", 95 | "stop_idx = 11 # choose 32x32 layer to do kmeans clustering\n", 96 | "n_clusters = 18 # Number of Kmeans cluster\n", 97 | "clusterer = pickle.load(open(\"catalog.pkl\", \"rb\"))\n" 98 | ], 99 | "execution_count": null, 100 | "outputs": [] 101 | }, 102 | { 103 | "cell_type": "markdown", 104 | "metadata": { 105 | "id": "uCxun63Nm9TJ" 106 | }, 107 | "source": [ 108 | "# クラスタリングの視覚化" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "metadata": { 114 | "id": "87Pm31p6m9TK" 115 | }, 116 | "source": [ 117 | "plt.rcParams['figure.dpi'] = 150\n", 118 | "\n", 119 | "with torch.no_grad():\n", 120 | " sample_z = torch.randn([1, 512]).to(device)\n", 121 | " sample_w = generator.get_latent(sample_z, truncation=truncation, mean_latent=mean_latent)\n", 122 | " sample, outputs = generator(sample_w, is_cluster=1) # [b, c, h, w]\n", 123 | "\n", 124 | "# obtain 32x32 activations and classify using kmeans\n", 125 | "act = flatten_act(outputs[stop_idx][0])\n", 126 | "b,c,h,w = outputs[stop_idx][0].size()\n", 127 | "\n", 128 | "alpha = 0.5\n", 129 | "seg_mask = clusterer.predict(act)\n", 130 | "seg_mask = torch.from_numpy(seg_mask).view(1,h,w)\n", 131 | "seg_out = decode_segmap(seg_mask)\n", 132 | "\n", 133 | "sample_d = F.interpolate(sample, size=(256,256), mode='bilinear').cpu()\n", 134 | "seg_out_d = F.interpolate(seg_out, size=(256,256), mode='nearest')\n", 135 | "out = alpha*seg_out_d + (1-alpha)*sample_d\n", 136 | "\n", 137 | "display_image(out)" 138 | ], 139 | "execution_count": null, 140 | "outputs": [] 141 | }, 142 | { 143 | "cell_type": "markdown", 144 | "metadata": { 145 | "id": "ISOQPp-bm9TK" 146 | }, 147 | "source": [ 148 | "# 顔の特徴のラベル付けと関数定義\n", 149 | "\n", 150 | "\n" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "metadata": { 156 | "id": "a2IPbbs4m9TM" 157 | }, 158 | "source": [ 159 | "#Gives an index to each feature we care about\n", 160 | "labels2idx = {\n", 161 | " 'nose': 0,\n", 162 | " 'eyes': 1,\n", 163 | " 'mouth':2,\n", 164 | " 'hair': 3,\n", 165 | " 'background': 4,\n", 166 | " 'cheeks': 5,\n", 167 | " 'neck': 6,\n", 168 | " 'clothes': 7,\n", 169 | "}\n", 170 | "\n", 171 | "# Assign to each feature the cluster index from segmentation\n", 172 | "labels_map = {\n", 173 | " 0: torch.tensor([7]),\n", 174 | " 1: torch.tensor([1,6]),\n", 175 | " 2: torch.tensor([4]),\n", 176 | " 3: torch.tensor([0,3,5,8,10,15,16]),\n", 177 | " 4: torch.tensor([11,13,14]),\n", 178 | " 5: torch.tensor([9]),\n", 179 | " 6: torch.tensor([17]),\n", 180 | " 7: torch.tensor([2,12]),\n", 181 | "}\n", 182 | "\n", 183 | "idx2labels = dict((v,k) for k,v in labels2idx.items())\n", 184 | "n_class = len(labels2idx)\n", 185 | "\n", 186 | "\n", 187 | "# compute M given a style code.\n", 188 | "@torch.no_grad()\n", 189 | "def compute_M(w, device='cuda'):\n", 190 | " M = []\n", 191 | " \n", 192 | " # get segmentation\n", 193 | " _, outputs = generator(w, is_cluster=1)\n", 194 | " cluster_layer = outputs[stop_idx][0]\n", 195 | " activation = flatten_act(cluster_layer)\n", 196 | " seg_mask = clusterer.predict(activation)\n", 197 | " b,c,h,w = cluster_layer.size()\n", 198 | "\n", 199 | " # create masks for each feature\n", 200 | " all_seg_mask = []\n", 201 | " seg_mask = torch.from_numpy(seg_mask).view(b,1,h,w,1).to(device)\n", 202 | " \n", 203 | " for key in range(n_class):\n", 204 | " # combine masks for all indices for a particular segmentation class\n", 205 | " indices = labels_map[key].view(1,1,1,1,-1) \n", 206 | " key_mask = (seg_mask == 
indices.to(device)).any(-1) #[b,1,h,w]\n", 207 | " all_seg_mask.append(key_mask)\n", 208 | " \n", 209 | " all_seg_mask = torch.stack(all_seg_mask, 1)\n", 210 | "\n", 211 | " # go through each activation layer and compute M\n", 212 | " for layer_idx in range(len(outputs)):\n", 213 | " layer = outputs[layer_idx][1].to(device)\n", 214 | " b,c,h,w = layer.size()\n", 215 | " layer = F.instance_norm(layer)\n", 216 | " layer = layer.pow(2)\n", 217 | " \n", 218 | " # resize the segmentation masks to current activations' resolution\n", 219 | " layer_seg_mask = F.interpolate(all_seg_mask.flatten(0,1).float(), align_corners=False, \n", 220 | " size=(h,w), mode='bilinear').view(b,-1,1,h,w)\n", 221 | " \n", 222 | " masked_layer = layer.unsqueeze(1) * layer_seg_mask # [b,k,c,h,w]\n", 223 | " masked_layer = (masked_layer.sum([3,4])/ (h*w))#[b,k,c]\n", 224 | "\n", 225 | " M.append(masked_layer.to(device))\n", 226 | "\n", 227 | " M = torch.cat(M, -1) #[b, k, c]\n", 228 | " \n", 229 | " # softmax to assign each channel to a particular segmentation class\n", 230 | " M = F.softmax(M/.1, 1)\n", 231 | " # simple thresholding\n", 232 | " M = (M>.8).float()\n", 233 | " \n", 234 | " # zero out torgb transfers, from https://arxiv.org/abs/2011.12799\n", 235 | " for i in range(n_class):\n", 236 | " part_M = style2list(M[:, i])\n", 237 | " for j in range(len(part_M)):\n", 238 | " if j in rgb_layer_idx:\n", 239 | " part_M[j].zero_()\n", 240 | " part_M = list2style(part_M)\n", 241 | " M[:, i] = part_M\n", 242 | "\n", 243 | " return M" 244 | ], 245 | "execution_count": null, 246 | "outputs": [] 247 | }, 248 | { 249 | "cell_type": "markdown", 250 | "metadata": { 251 | "id": "ojscea08whhf" 252 | }, 253 | "source": [ 254 | "# 【オプション】画像から潜在変数を求める\n", 255 | "*自分で用意した画像を使わない場合はこのブロックの実行をパスして下さい。\\\n", 256 | "*自分で用意した画像(jpg)から顔部分を切り取って潜在変数化したい場合は、PCからその画像をドラッグ&ドロップで RetrieveInStyle/images へアップロードして(複数OK)から、下記を実行して下さい。" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "metadata": { 262 | "id": "9uPnUk4-vnC2" 263 | }, 264 | "source": [ 265 | "import glob\n", 266 | "files = glob.glob('images/*.jpg')\n", 267 | "for file in files:\n", 268 | " filename = file[7:-4]\n", 269 | " cropped_face = align_face(file) # 顔部分の切り取り\n", 270 | " projection(cropped_face, filename, generator, device) # 潜在変数の取得" 271 | ], 272 | "execution_count": null, 273 | "outputs": [] 274 | }, 275 | { 276 | "cell_type": "markdown", 277 | "metadata": { 278 | "id": "a7Mxkkgwm9TN" 279 | }, 280 | "source": [ 281 | "# 顔の特徴の転送" 282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "metadata": { 287 | "id": "azXsVoBAm9TN" 288 | }, 289 | "source": [ 290 | "# ソース画像と参照画像の設定\n", 291 | "plt.rcParams['figure.dpi'] = 75\n", 292 | "\n", 293 | "# load codes from inverted real images using our projection code\n", 294 | "with torch.no_grad():\n", 295 | " '''\n", 296 | " if you gan inverted in the previous cell, you can call it here with variable filename\n", 297 | " otherwise, you can randomly generate or call a pre-inverted image\n", 298 | " '''\n", 299 | " # source = load_source([filename], generator, device)\n", 300 | " source = load_source(['brad_pitt'], generator, device)\n", 301 | " source_im, _ = generator(source)\n", 302 | " display_image(source_im, size=256)\n", 303 | " \n", 304 | " ref = load_source(['emma_watson', 'emma_stone', 'jennie'], generator, device)\n", 305 | " ref_im, _ = generator(ref)\n", 306 | " ref_im = downsample(ref_im)\n", 307 | " \n", 308 | " show(normalize_im(ref_im).permute(0,2,3,1).cpu(), title='References')" 309 | ], 310 | 
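A side note on `compute_M` defined above: the step that turns per-channel scores into the binary relevance matrix M is a low-temperature softmax over the part axis followed by thresholding at 0.8. The toy example below is self-contained and uses invented numbers; only the operations mirror the function above.

import torch
import torch.nn.functional as F

# scores[b, k, c]: one score per (facial part k, style channel c); invented values.
scores = torch.tensor([[[2.0, 0.1, 0.1],    # part 0, e.g. "nose"
                        [0.4, 1.0, 0.9]]])  # part 1, e.g. "hair"

M = F.softmax(scores / 0.1, dim=1)  # temperature 0.1 makes the assignment nearly one-hot
M = (M > 0.8).float()               # keep only confidently assigned channels
print(M)                            # channel 0 -> part 0, channels 1 and 2 -> part 1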
"execution_count": null, 311 | "outputs": [] 312 | }, 313 | { 314 | "cell_type": "code", 315 | "metadata": { 316 | "id": "SvF4GLTum9TN", 317 | "scrolled": false 318 | }, 319 | "source": [ 320 | "# 顔の特徴の転送\n", 321 | "# Compute M for both source and reference images use cpu here to save memory\n", 322 | "source_M = compute_M(source, device='cpu')\n", 323 | "ref_M = compute_M(ref, device='cpu')\n", 324 | "\n", 325 | "# Find relevant channels for source and reference by taking max over their individual M\n", 326 | "max_M = torch.max(source_M.expand_as(ref_M), ref_M)\n", 327 | "max_M = add_pose(max_M, labels2idx)\n", 328 | "\n", 329 | "all_im = {}\n", 330 | "\n", 331 | "with torch.no_grad(): \n", 332 | " # features we are interest in transferring\n", 333 | " parts = ('eyes', 'nose', 'mouth', 'hair','pose')\n", 334 | " for label in parts:\n", 335 | " if label == 'pose':\n", 336 | " idx = -1\n", 337 | " else:\n", 338 | " idx = labels2idx[label]\n", 339 | " \n", 340 | " part_M = max_M[:,idx].to(device)\n", 341 | " blend = style2list(add_direction(source, ref, part_M, 1.3))\n", 342 | " \n", 343 | " blend_im, _ = generator(blend)\n", 344 | " blend_im = downsample(blend_im).cpu()\n", 345 | " all_im[label] = normalize_im(blend_im)\n", 346 | " \n", 347 | "part_grid(normalize_im(source_im.detach()), normalize_im(ref_im.detach()), all_im);" 348 | ], 349 | "execution_count": null, 350 | "outputs": [] 351 | }, 352 | { 353 | "cell_type": "markdown", 354 | "metadata": { 355 | "id": "z-qa0pfQm9TO" 356 | }, 357 | "source": [ 358 | "# 顔の特徴の転送(度合いのコントロール)" 359 | ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "metadata": { 364 | "id": "lpbNadBpm9TO", 365 | "scrolled": false 366 | }, 367 | "source": [ 368 | "# ソース画像と参照画像の設定\n", 369 | "plt.rcParams['figure.dpi'] = 75\n", 370 | "torch.manual_seed(3913)\n", 371 | " \n", 372 | "with torch.no_grad():\n", 373 | " source = load_source(['emma_stone'], generator, device)\n", 374 | " source_im, _ = generator(source)\n", 375 | " display_image(source_im, size=256)\n", 376 | " \n", 377 | " ref = load_source(['brad_pitt'], generator, device)\n", 378 | " ref_im, _ = generator(ref)\n", 379 | " ref_im = downsample(ref_im)\n", 380 | " display_image(ref_im, title='reference')" 381 | ], 382 | "execution_count": null, 383 | "outputs": [] 384 | }, 385 | { 386 | "cell_type": "code", 387 | "metadata": { 388 | "id": "Ixo9fpeRm9TO" 389 | }, 390 | "source": [ 391 | "# 顔の特徴の転送(度合いのコントロール)\n", 392 | "plt.rcParams['figure.dpi'] = 200 \n", 393 | "\n", 394 | "source_M = compute_M(source, device='cpu')\n", 395 | "ref_M = compute_M(ref, device='cpu')\n", 396 | "\n", 397 | "max_M = torch.max(source_M.expand_as(ref_M), ref_M)\n", 398 | "max_M = add_pose(max_M, labels2idx)\n", 399 | "\n", 400 | "labels = ('eyes', 'hair') # choose what feature to interpolate {eyes/nose/mouth/hair/pose}\n", 401 | "max_alpha = 1.5 # max range to interpolate\n", 402 | "\n", 403 | "all_im = []\n", 404 | "with torch.no_grad(): \n", 405 | " for label in labels:\n", 406 | " row = []\n", 407 | " \n", 408 | " if label == 'pose':\n", 409 | " idx = -1\n", 410 | " else:\n", 411 | " idx = labels2idx[label]\n", 412 | "\n", 413 | " for alpha in np.linspace(-max_alpha, max_alpha, 5):\n", 414 | " part_M = max_M[:,idx].to(device) \n", 415 | " blend = style2list(add_direction(source, ref, part_M, alpha))\n", 416 | " blend_im, _ = generator(blend)\n", 417 | " blend_im = downsample(blend_im).cpu()\n", 418 | " row.append(blend_im)\n", 419 | "\n", 420 | " row.append(ref_im.cpu())\n", 421 | " row = torch.cat(row, -1)\n", 422 | " 
all_im.append(row)\n", 423 | " \n", 424 | " all_im = torch.cat(all_im, 2)\n", 425 | " display_image(all_im, size=None)" 426 | ], 427 | "execution_count": null, 428 | "outputs": [] 429 | }, 430 | { 431 | "cell_type": "markdown", 432 | "metadata": { 433 | "id": "8UDXOQ55m9TP" 434 | }, 435 | "source": [ 436 | "# 顔の特徴検索\n", 437 | "GANで5000個の顔のデータベースを作成し、顔の特徴が似た人を抽出する" 438 | ] 439 | }, 440 | { 441 | "cell_type": "code", 442 | "metadata": { 443 | "id": "TN-O34QCm9TP" 444 | }, 445 | "source": [ 446 | "# 顔データベースの作成\n", 447 | "torch.manual_seed(12390)\n", 448 | "num_data = 5000\n", 449 | "dataset = torch.randn([num_data, 512]).to(device)\n", 450 | "with torch.no_grad():\n", 451 | " dataset_w = generator.get_latent(dataset, truncation=truncation, mean_latent=mean_latent)\n", 452 | " dataset_M = []\n", 453 | " for i in tqdm(range(num_data)):\n", 454 | " # have to use cuda for this or it will be very slow\n", 455 | " dataset_M.append(compute_M(index_layers(dataset_w, i), device='cuda'))\n", 456 | "\n", 457 | " dataset_M = remove_2048(torch.cat(dataset_M, 0), labels2idx).to(device) #[N, K, C]" 458 | ], 459 | "execution_count": null, 460 | "outputs": [] 461 | }, 462 | { 463 | "cell_type": "code", 464 | "metadata": { 465 | "id": "RUXmlLzam9TQ" 466 | }, 467 | "source": [ 468 | "# 検索対象の表示\n", 469 | "plt.rcParams['figure.dpi'] = 75 \n", 470 | "\n", 471 | "with torch.no_grad():\n", 472 | " query_w = load_source(['tom_hiddleston'], generator, device)\n", 473 | " \n", 474 | " query_im, _ = generator(query_w)\n", 475 | " display_image(query_im)" 476 | ], 477 | "execution_count": null, 478 | "outputs": [] 479 | }, 480 | { 481 | "cell_type": "code", 482 | "metadata": { 483 | "id": "7tmBV9eYm9TQ", 484 | "scrolled": false 485 | }, 486 | "source": [ 487 | "# 検索実行と結果表示\n", 488 | "plt.rcParams['figure.dpi'] = 300\n", 489 | " \n", 490 | "num_nn = 6\n", 491 | "all_im = []\n", 492 | "query_M = remove_2048(compute_M(query_w, device=device), labels2idx).to(device)\n", 493 | "\n", 494 | "r_query_w = list2style(query_w)\n", 495 | "r_dataset_w = list2style(dataset_w)\n", 496 | "\n", 497 | "# normalize each style dimension\n", 498 | "largest = r_dataset_w.abs().max(0, keepdim=True)[0] + 1e-8\n", 499 | "norm_query_w = r_query_w/largest\n", 500 | "norm_target_w = r_dataset_w/largest\n", 501 | "\n", 502 | "# choose what features to perform retrieval on\n", 503 | "# parts = ('eyes', 'nose', 'mouth', 'hair')\n", 504 | "parts = ('eyes', 'mouth', 'hair',)\n", 505 | "\n", 506 | "# perform cosine similarity w.r.t a given feature\n", 507 | "with torch.no_grad():\n", 508 | " for part in parts:\n", 509 | " idx = labels2idx[part]\n", 510 | " \n", 511 | " source_part = norm_query_w * query_M[:,idx].to(device)\n", 512 | " target_part = norm_target_w * dataset_M[:,idx].to(device)\n", 513 | " \n", 514 | " distance = cos_dist(target_part, source_part) \n", 515 | " nearest_neighbors = torch.sort(distance)[1][:num_nn]\n", 516 | "\n", 517 | " row = [query_im.cpu()]\n", 518 | " for idx in nearest_neighbors:\n", 519 | " nn_w = index_layers(dataset_w, int(idx))\n", 520 | " nn_image, _ = generator(nn_w)\n", 521 | " row.append(nn_image.cpu())\n", 522 | " row = [downsample(a) for a in row]\n", 523 | " row = torch.cat(row, -1)\n", 524 | " all_im.append(row)\n", 525 | " \n", 526 | " all_im = torch.cat(all_im,-2)\n", 527 | " display_image(all_im, size=None)" 528 | ], 529 | "execution_count": null, 530 | "outputs": [] 531 | } 532 | ] 533 | } -------------------------------------------------------------------------------- /ReStyle_animations.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "accelerator": "GPU", 6 | "colab": { 7 | "name": "ReStyle_animations", 8 | "provenance": [], 9 | "collapsed_sections": [], 10 | "include_colab_link": true 11 | }, 12 | "kernelspec": { 13 | "display_name": "Python 3", 14 | "name": "python3" 15 | }, 16 | "language_info": { 17 | "name": "python" 18 | } 19 | }, 20 | "cells": [ 21 | { 22 | "cell_type": "markdown", 23 | "metadata": { 24 | "id": "view-in-github", 25 | "colab_type": "text" 26 | }, 27 | "source": [ 28 | "\"Open" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": { 34 | "id": "pRIb9Xqjnmxn" 35 | }, 36 | "source": [ 37 | "# セットアップ" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "metadata": { 43 | "id": "VcNK15ganhUH" 44 | }, 45 | "source": [ 46 | "# githubからコードをコピー\n", 47 | "import os\n", 48 | "os.chdir('/content')\n", 49 | "CODE_DIR = 'restyle-encoder'\n", 50 | "!git clone https://github.com/yuval-alaluf/restyle-encoder.git $CODE_DIR\n", 51 | "\n", 52 | "# ninjaシステムインストール\n", 53 | "!wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip\n", 54 | "!sudo unzip ninja-linux.zip -d /usr/local/bin/\n", 55 | "!sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force\n", 56 | "os.chdir(f'./{CODE_DIR}')\n", 57 | "\n", 58 | "# ライブラリーのインポート\n", 59 | "from argparse import Namespace\n", 60 | "import time\n", 61 | "import os\n", 62 | "import sys\n", 63 | "import pprint\n", 64 | "from tqdm import tqdm\n", 65 | "import numpy as np\n", 66 | "from PIL import Image\n", 67 | "import torch\n", 68 | "import torchvision.transforms as transforms\n", 69 | "import imageio\n", 70 | "import matplotlib\n", 71 | "from IPython.display import HTML\n", 72 | "from base64 import b64encode\n", 73 | "\n", 74 | "sys.path.append(\".\")\n", 75 | "sys.path.append(\"..\")\n", 76 | "from utils.common import tensor2im\n", 77 | "from utils.inference_utils import run_on_batch\n", 78 | "from models.psp import pSp\n", 79 | "from models.e4e import e4e\n", 80 | "\n", 81 | "%load_ext autoreload\n", 82 | "%autoreload 2\n", 83 | "\n", 84 | "# サンプル画像のダウンロード\n", 85 | "! pip install --upgrade gdown\n", 86 | "import gdown\n", 87 | "gdown.download('https://drive.google.com/uc?id=1EvinsyeqFSU982133ehKCC50IYR1109t', './notebooks/pic.zip', quiet=False)\n", 88 | "! unzip -d notebooks notebooks/pic.zip" 89 | ], 90 | "execution_count": null, 91 | "outputs": [] 92 | }, 93 | { 94 | "cell_type": "markdown", 95 | "metadata": { 96 | "id": "Ba4ovOESo1Su" 97 | }, 98 | "source": [ 99 | "# モデルのダウンロード\n" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "metadata": { 105 | "id": "_wO0FrBNo07X", 106 | "cellView": "form" 107 | }, 108 | "source": [ 109 | "#@title モデルの指定\n", 110 | "experiment_type = 'ffhq_encode' #@param ['ffhq_encode', 'cars_encode', 'church_encode', 'horse_encode', 'afhq_wild_encode', 'toonify']" 111 | ], 112 | "execution_count": null, 113 | "outputs": [] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "metadata": { 118 | "id": "KSnjlBZOkTJ0" 119 | }, 120 | "source": [ 121 | "# ダウンロード命令の作成\n", 122 | "def get_download_model_command(file_id, file_name):\n", 123 | " \"\"\" Get wget download command for downloading the desired model and save to directory ../pretrained_models. 
\"\"\"\n", 124 | " current_directory = os.getcwd()\n", 125 | " save_path = os.path.join(os.path.dirname(current_directory), CODE_DIR, \"pretrained_models\")\n", 126 | " if not os.path.exists(save_path):\n", 127 | " os.makedirs(save_path)\n", 128 | " url = r\"\"\"wget --load-cookies /tmp/cookies.txt \"https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id={FILE_ID}' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\\1\\n/p')&id={FILE_ID}\" -O {SAVE_PATH}/{FILE_NAME} && rm -rf /tmp/cookies.txt\"\"\".format(FILE_ID=file_id, FILE_NAME=file_name, SAVE_PATH=save_path)\n", 129 | " return url \n", 130 | "\n", 131 | "MODEL_PATHS = {\n", 132 | " \"ffhq_encode\": {\"id\": \"1sw6I2lRIB0MpuJkpc8F5BJiSZrc0hjfE\", \"name\": \"restyle_psp_ffhq_encode.pt\"},\n", 133 | " \"cars_encode\": {\"id\": \"1zJHqHRQ8NOnVohVVCGbeYMMr6PDhRpPR\", \"name\": \"restyle_psp_cars_encode.pt\"},\n", 134 | " \"church_encode\": {\"id\": \"1bcxx7mw-1z7dzbJI_z7oGpWG1oQAvMaD\", \"name\": \"restyle_psp_church_encode.pt\"},\n", 135 | " \"horse_encode\": {\"id\": \"19_sUpTYtJmhSAolKLm3VgI-ptYqd-hgY\", \"name\": \"restyle_e4e_horse_encode.pt\"},\n", 136 | " \"afhq_wild_encode\": {\"id\": \"1GyFXVTNDUw3IIGHmGS71ChhJ1Rmslhk7\", \"name\": \"restyle_psp_afhq_wild_encode.pt\"},\n", 137 | " \"toonify\": {\"id\": \"1GtudVDig59d4HJ_8bGEniz5huaTSGO_0\", \"name\": \"restyle_psp_toonify.pt\"}\n", 138 | "}\n", 139 | "\n", 140 | "path = MODEL_PATHS[experiment_type]\n", 141 | "download_command = get_download_model_command(file_id=path[\"id\"], file_name=path[\"name\"]) \n", 142 | "\n", 143 | "\n", 144 | "# パラメータの設定\n", 145 | "EXPERIMENT_DATA_ARGS = {\n", 146 | " \"ffhq_encode\": {\n", 147 | " \"model_path\": \"pretrained_models/restyle_psp_ffhq_encode.pt\",\n", 148 | " \"image_path\": \"notebooks/images/face_img.jpg\",\n", 149 | " \"transform\": transforms.Compose([\n", 150 | " transforms.Resize((256, 256)),\n", 151 | " transforms.ToTensor(),\n", 152 | " transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])\n", 153 | " },\n", 154 | " \"cars_encode\": {\n", 155 | " \"model_path\": \"pretrained_models/restyle_psp_cars_encode.pt\",\n", 156 | " \"image_path\": \"notebooks/images/car_img.jpg\",\n", 157 | " \"transform\": transforms.Compose([\n", 158 | " transforms.Resize((192, 256)),\n", 159 | " transforms.ToTensor(),\n", 160 | " transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])\n", 161 | " },\n", 162 | " \"church_encode\": {\n", 163 | " \"model_path\": \"pretrained_models/restyle_psp_church_encode.pt\",\n", 164 | " \"image_path\": \"notebooks/images/church_img.jpg\",\n", 165 | " \"transform\": transforms.Compose([\n", 166 | " transforms.Resize((256, 256)),\n", 167 | " transforms.ToTensor(),\n", 168 | " transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])\n", 169 | " },\n", 170 | " \"horse_encode\": {\n", 171 | " \"model_path\": \"pretrained_models/restyle_e4e_horse_encode.pt\",\n", 172 | " \"image_path\": \"notebooks/images/horse_img.jpg\",\n", 173 | " \"transform\": transforms.Compose([\n", 174 | " transforms.Resize((256, 256)),\n", 175 | " transforms.ToTensor(),\n", 176 | " transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])\n", 177 | " },\n", 178 | " \"afhq_wild_encode\": {\n", 179 | " \"model_path\": \"pretrained_models/restyle_psp_afhq_wild_encode.pt\",\n", 180 | " \"image_path\": \"notebooks/images/afhq_wild_img.jpg\",\n", 181 | " \"transform\": transforms.Compose([\n", 182 | " 
transforms.Resize((256, 256)),\n", 183 | " transforms.ToTensor(),\n", 184 | " transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])\n", 185 | " },\n", 186 | " \"toonify\": {\n", 187 | " \"model_path\": \"pretrained_models/restyle_psp_toonify.pt\",\n", 188 | " \"image_path\": \"notebooks/images/toonify_img.jpg\",\n", 189 | " \"transform\": transforms.Compose([\n", 190 | " transforms.Resize((256, 256)),\n", 191 | " transforms.ToTensor(),\n", 192 | " transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])\n", 193 | " },\n", 194 | "}\n", 195 | "\n", 196 | "# モデルのダウンロード\n", 197 | "EXPERIMENT_ARGS = EXPERIMENT_DATA_ARGS[experiment_type]\n", 198 | "\n", 199 | "if not os.path.exists(EXPERIMENT_ARGS['model_path']) or os.path.getsize(EXPERIMENT_ARGS['model_path']) < 1000000:\n", 200 | " print(f'Downloading ReStyle model for {experiment_type}...')\n", 201 | " os.system(f\"wget {download_command}\")\n", 202 | " # if google drive receives too many requests, we'll reach the quota limit and be unable to download the model\n", 203 | " if os.path.getsize(EXPERIMENT_ARGS['model_path']) < 1000000:\n", 204 | " raise ValueError(\"Pretrained model was unable to be downloaded correctly!\")\n", 205 | " else:\n", 206 | " print('Done.')\n", 207 | "else:\n", 208 | " print(f'ReStyle model for {experiment_type} already exists!')\n", 209 | "\n", 210 | "\n", 211 | "# モデルのロード\n", 212 | "model_path = EXPERIMENT_ARGS['model_path']\n", 213 | "ckpt = torch.load(model_path, map_location='cpu')\n", 214 | "opts = ckpt['opts']\n", 215 | "opts['checkpoint_path'] = model_path\n", 216 | "opts = Namespace(**opts)\n", 217 | "\n", 218 | "if experiment_type == 'horse_encode': \n", 219 | " net = e4e(opts)\n", 220 | "else:\n", 221 | " net = pSp(opts)\n", 222 | " \n", 223 | "net.eval()\n", 224 | "net.cuda()\n", 225 | "print('Model successfully loaded!')" 226 | ], 227 | "execution_count": null, 228 | "outputs": [] 229 | }, 230 | { 231 | "cell_type": "markdown", 232 | "metadata": { 233 | "id": "mdQ0fXM7ppSr" 234 | }, 235 | "source": [ 236 | "## 関数定義と設定" 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "metadata": { 242 | "id": "CJbNCfLaplKu" 243 | }, 244 | "source": [ 245 | "# 関数定義\n", 246 | "def generate_mp4(out_name, images, kwargs):\n", 247 | " writer = imageio.get_writer(out_name + '.mp4', **kwargs)\n", 248 | " for image in images:\n", 249 | " writer.append_data(image)\n", 250 | " writer.close()\n", 251 | "\n", 252 | "\n", 253 | "def run_on_batch_to_vecs(inputs, net, opts):\n", 254 | " opts.resize_outputs = False\n", 255 | " opts.n_iters_per_batch = 5\n", 256 | " with torch.no_grad():\n", 257 | " _, result_batch = run_on_batch(inputs.to(\"cuda\").float(), net, opts, avg_image)\n", 258 | " return result_batch[0][-1]\n", 259 | "\n", 260 | "\n", 261 | "def get_result_from_vecs(vectors_a, vectors_b, alpha):\n", 262 | " results = []\n", 263 | " for i in range(len(vectors_a)):\n", 264 | " with torch.no_grad():\n", 265 | " cur_vec = vectors_b[i] * alpha + vectors_a[i] * (1 - alpha)\n", 266 | " res = net(torch.from_numpy(cur_vec).cuda().unsqueeze(0), randomize_noise=False,\n", 267 | " input_code=True, input_is_full=True, resize=False)\n", 268 | " results.append(res[0])\n", 269 | " return results\n", 270 | "\n", 271 | "def show_mp4(filename, width):\n", 272 | " mp4 = open(filename + '.mp4', 'rb').read()\n", 273 | " data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n", 274 | " display(HTML(\"\"\"\n", 275 | " \n", 278 | " \"\"\" % (width, data_url)))\n", 279 | "\n", 280 | "\n", 281 | "# 潜在変数データの平均値を取得\n", 282 | 
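# Added commentary (not in the original notebook): ReStyle inverts an image
# iteratively. run_on_batch starts every input from the generator's average
# latent code and the synthesized "average image" computed just below, then
# predicts residual latent offsets for opts.n_iters_per_batch = 5 refinement
# steps, feeding the current reconstruction back in alongside the photo at
# each step. That is why this average image is computed once before encoding.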
"avg_image = net(net.latent_avg.unsqueeze(0),\n", 283 | " input_code=True,\n", 284 | " randomize_noise=False,\n", 285 | " return_latents=False,\n", 286 | " average_code=True)[0]\n", 287 | "avg_image = avg_image.to('cuda').float().detach()\n", 288 | "if opts.dataset_type == \"cars_encode\":\n", 289 | " avg_image = avg_image[:, 32:224, :]\n", 290 | "\n", 291 | "\n", 292 | "# 設定\n", 293 | "SEED = 42\n", 294 | "np.random.seed(SEED)\n", 295 | "img_transforms = EXPERIMENT_ARGS['transform']\n", 296 | "root_dir = \"notebooks/images/\"\n", 297 | "image_names = ['', '', '', '', '']\n", 298 | "image_paths = [os.path.join(root_dir, image) + '.jpg' for image in image_names]\n", 299 | "\n", 300 | "\n", 301 | "# imagesフォルダーをリセット\n", 302 | "import os\n", 303 | "import shutil\n", 304 | "if os.path.isdir('notebooks/images'):\n", 305 | " shutil.rmtree('notebooks/images')\n", 306 | "os.makedirs('notebooks/images', exist_ok=True)" 307 | ], 308 | "execution_count": null, 309 | "outputs": [] 310 | }, 311 | { 312 | "cell_type": "markdown", 313 | "metadata": { 314 | "id": "pvsWbbOytp4a" 315 | }, 316 | "source": [ 317 | "## Align\n", 318 | "picフォルダーにあるサンプル画像をAlignし、imagesフォルダーに保存します。\\\n", 319 | " *ffhq_encoder, toonify モデルを以外を指定した場合や、align済みの画像がある場合は、このブロックをスキップして、imagesフォルダーに画像(jpg)をアップロードして下さい。" 320 | ] 321 | }, 322 | { 323 | "cell_type": "code", 324 | "metadata": { 325 | "id": "57mIkFsbuMaa" 326 | }, 327 | "source": [ 328 | "def run_alignment(image_path):\n", 329 | " import dlib\n", 330 | " from scripts.align_faces_parallel import align_face\n", 331 | " if not os.path.exists(\"shape_predictor_68_face_landmarks.dat\"):\n", 332 | " print('Downloading files for aligning face image...')\n", 333 | " os.system('wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2')\n", 334 | " os.system('bzip2 -dk shape_predictor_68_face_landmarks.dat.bz2')\n", 335 | " print('Done.')\n", 336 | " predictor = dlib.shape_predictor(\"shape_predictor_68_face_landmarks.dat\")\n", 337 | " aligned_image = align_face(filepath=image_path, \n", 338 | " predictor=predictor,\n", 339 | " output_size=256, \n", 340 | " transform_size=256) \n", 341 | " print(\"Aligned image has shape: {}\".format(aligned_image.size))\n", 342 | " return aligned_image \n", 343 | "\n", 344 | "\n", 345 | "ALIGN_IMAGES = True\n", 346 | "import glob\n", 347 | "import os\n", 348 | "image_paths = glob.glob('./notebooks/pic/*.jpg')\n", 349 | "image_names = os.listdir('./notebooks/pic')\n", 350 | "image_paths.sort()\n", 351 | "image_names.sort()\n", 352 | "\n", 353 | "# ffhq_encoderかtoonifyのときのみalignを実行\n", 354 | "if ALIGN_IMAGES and experiment_type in [\"ffhq_encode\", \"toonify\"]: \n", 355 | " aligned_image_paths = []\n", 356 | " for image_name, image_path in zip(image_names, image_paths): \n", 357 | " print(f'Aligning {image_name}...')\n", 358 | " aligned_image = run_alignment(image_path)\n", 359 | " aligned_path = os.path.join(root_dir, f'{image_name}_aligned.jpg')\n", 360 | " # save the aligned image\n", 361 | " aligned_image.save(aligned_path)\n", 362 | " aligned_image_paths.append(aligned_path)\n", 363 | " # use the save aligned images as our input image paths\n", 364 | " image_paths = aligned_image_paths" 365 | ], 366 | "execution_count": null, 367 | "outputs": [] 368 | }, 369 | { 370 | "cell_type": "markdown", 371 | "metadata": { 372 | "id": "nA8pC10Yry6U" 373 | }, 374 | "source": [ 375 | "## 画像から潜在変数を求める" 376 | ] 377 | }, 378 | { 379 | "cell_type": "code", 380 | "metadata": { 381 | "id": "CFQyaD1Qqe3d" 382 | }, 383 | "source": [ 384 | "import 
glob\n", 385 | "image_paths = glob.glob('notebooks/images/*.jpg')\n", 386 | "image_paths.sort()\n", 387 | "\n", 388 | "\n", 389 | "in_images = []\n", 390 | "all_vecs = []\n", 391 | "\n", 392 | "if experiment_type == \"cars_encode\":\n", 393 | " resize_amount = (512, 384)\n", 394 | "else:\n", 395 | " resize_amount = (opts.output_size, opts.output_size)\n", 396 | "\n", 397 | "for image_path in image_paths:\n", 398 | " print(f'Working on {os.path.basename(image_path)}...')\n", 399 | " original_image = Image.open(image_path)\n", 400 | " original_image = original_image.convert(\"RGB\")\n", 401 | " input_image = img_transforms(original_image)\n", 402 | " with torch.no_grad():\n", 403 | " result_vec = run_on_batch_to_vecs(input_image.unsqueeze(0), net, opts)\n", 404 | " all_vecs.append([result_vec])\n", 405 | " in_images.append(original_image.resize(resize_amount))" 406 | ], 407 | "execution_count": null, 408 | "outputs": [] 409 | }, 410 | { 411 | "cell_type": "markdown", 412 | "metadata": { 413 | "id": "wV7yaSXar88q" 414 | }, 415 | "source": [ 416 | "## 補完画像の生成" 417 | ] 418 | }, 419 | { 420 | "cell_type": "code", 421 | "metadata": { 422 | "id": "TsA3azDuqtWd" 423 | }, 424 | "source": [ 425 | "n_transition = 25\n", 426 | "if experiment_type == \"cars_encode\":\n", 427 | " SIZE = 384\n", 428 | "else:\n", 429 | " SIZE = opts.output_size\n", 430 | "\n", 431 | "images = []\n", 432 | "image_paths.append(image_paths[0])\n", 433 | "all_vecs.append(all_vecs[0])\n", 434 | "in_images.append(in_images[0])\n", 435 | "\n", 436 | "for i in range(1, len(image_paths)):\n", 437 | " if i == 0:\n", 438 | " alpha_vals = [0] * 10 + np.linspace(0, 1, n_transition).tolist() + [1] * 5\n", 439 | " else:\n", 440 | " alpha_vals = [0] * 5 + np.linspace(0, 1, n_transition).tolist() + [1] * 5\n", 441 | "\n", 442 | " for alpha in tqdm(alpha_vals):\n", 443 | " image_a = np.array(in_images[i - 1])\n", 444 | " image_b = np.array(in_images[i])\n", 445 | " image_joint = np.zeros_like(image_a)\n", 446 | " up_to_row = int((SIZE - 1) * alpha)\n", 447 | " if up_to_row > 0:\n", 448 | " image_joint[:(up_to_row + 1), :, :] = image_b[((SIZE - 1) - up_to_row):, :, :]\n", 449 | " if up_to_row < (SIZE - 1):\n", 450 | " image_joint[up_to_row:, :, :] = image_a[:(SIZE - up_to_row), :, :]\n", 451 | "\n", 452 | " result_image = get_result_from_vecs(all_vecs[i - 1], all_vecs[i], alpha)[0]\n", 453 | " if experiment_type == \"cars_encode\":\n", 454 | " result_image = result_image[:, 64:448, :]\n", 455 | "\n", 456 | " output_im = tensor2im(result_image)\n", 457 | " res = np.concatenate([image_joint, np.array(output_im)], axis=1)\n", 458 | " images.append(res)" 459 | ], 460 | "execution_count": null, 461 | "outputs": [] 462 | }, 463 | { 464 | "cell_type": "markdown", 465 | "metadata": { 466 | "id": "0a4dTg5Ly4P0" 467 | }, 468 | "source": [ 469 | "## mp4作成" 470 | ] 471 | }, 472 | { 473 | "cell_type": "code", 474 | "metadata": { 475 | "colab": { 476 | "background_save": true 477 | }, 478 | "id": "YqvG0oJtsUWt" 479 | }, 480 | "source": [ 481 | "kwargs = {'fps': 15}\n", 482 | "save_path = \"notebooks/animations\"\n", 483 | "os.makedirs(save_path, exist_ok=True)\n", 484 | "\n", 485 | "gif_path = os.path.join(save_path, f\"{experiment_type}_gif\")\n", 486 | "generate_mp4(gif_path, images, kwargs)\n", 487 | "show_mp4(gif_path, width=opts.output_size)" 488 | ], 489 | "execution_count": null, 490 | "outputs": [] 491 | } 492 | ] 493 | } 494 | -------------------------------------------------------------------------------- /AnimeGANV2_for_face.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 5, 4 | "metadata": { 5 | "kernelspec": { 6 | "display_name": "Python 3 (ipykernel)", 7 | "language": "python", 8 | "name": "python3" 9 | }, 10 | "language_info": { 11 | "codemirror_mode": { 12 | "name": "ipython", 13 | "version": 3 14 | }, 15 | "file_extension": ".py", 16 | "mimetype": "text/x-python", 17 | "name": "python", 18 | "nbconvert_exporter": "python", 19 | "pygments_lexer": "ipython3", 20 | "version": "3.8.10" 21 | }, 22 | "colab": { 23 | "name": "AnimeGANV2_for_face", 24 | "provenance": [], 25 | "include_colab_link": true 26 | }, 27 | "accelerator": "GPU" 28 | }, 29 | "cells": [ 30 | { 31 | "cell_type": "markdown", 32 | "metadata": { 33 | "id": "view-in-github", 34 | "colab_type": "text" 35 | }, 36 | "source": [ 37 | "\"Open" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": { 43 | "id": "BWEN4yiB5gjz" 44 | }, 45 | "source": [ 46 | "# AnimeGANV2 for face" 47 | ], 48 | "id": "BWEN4yiB5gjz" 49 | }, 50 | { 51 | "cell_type": "code", 52 | "metadata": { 53 | "id": "3f40d528", 54 | "cellView": "form" 55 | }, 56 | "source": [ 57 | "#@title セットアップ\n", 58 | "\n", 59 | "# load Face2Paint model\n", 60 | "import torch \n", 61 | "from PIL import Image\n", 62 | "\n", 63 | "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", 64 | "model = torch.hub.load(\"bryandlee/animegan2-pytorch:main\", \"generator\", device=device).eval()\n", 65 | "face2paint = torch.hub.load(\"bryandlee/animegan2-pytorch:main\", \"face2paint\", device=device, side_by_side=True)\n", 66 | "\n", 67 | "\n", 68 | "# Face Detector & FFHQ-style Alignment\n", 69 | "# https://github.com/woctezuma/stylegan2-projecting-images\n", 70 | "import os\n", 71 | "import dlib\n", 72 | "import collections\n", 73 | "from typing import Union, List\n", 74 | "import numpy as np\n", 75 | "from PIL import Image\n", 76 | "import matplotlib.pyplot as plt\n", 77 | "\n", 78 | "\n", 79 | "def get_dlib_face_detector(predictor_path: str = \"shape_predictor_68_face_landmarks.dat\"):\n", 80 | "\n", 81 | " if not os.path.isfile(predictor_path):\n", 82 | " model_file = \"shape_predictor_68_face_landmarks.dat.bz2\"\n", 83 | " os.system(f\"wget http://dlib.net/files/{model_file}\")\n", 84 | " os.system(f\"bzip2 -dk {model_file}\")\n", 85 | "\n", 86 | " detector = dlib.get_frontal_face_detector()\n", 87 | " shape_predictor = dlib.shape_predictor(predictor_path)\n", 88 | "\n", 89 | " def detect_face_landmarks(img: Union[Image.Image, np.ndarray]):\n", 90 | " if isinstance(img, Image.Image):\n", 91 | " img = np.array(img)\n", 92 | " faces = []\n", 93 | " dets = detector(img)\n", 94 | " for d in dets:\n", 95 | " shape = shape_predictor(img, d)\n", 96 | " faces.append(np.array([[v.x, v.y] for v in shape.parts()]))\n", 97 | " return faces\n", 98 | " \n", 99 | " return detect_face_landmarks\n", 100 | "\n", 101 | "\n", 102 | "def display_facial_landmarks(\n", 103 | " img: Image, \n", 104 | " landmarks: List[np.ndarray],\n", 105 | " fig_size=[15, 15]\n", 106 | "):\n", 107 | " plot_style = dict(\n", 108 | " marker='o',\n", 109 | " markersize=4,\n", 110 | " linestyle='-',\n", 111 | " lw=2\n", 112 | " )\n", 113 | " pred_type = collections.namedtuple('prediction_type', ['slice', 'color'])\n", 114 | " pred_types = {\n", 115 | " 'face': pred_type(slice(0, 17), (0.682, 0.780, 0.909, 0.5)),\n", 116 | " 'eyebrow1': pred_type(slice(17, 22), (1.0, 0.498, 0.055, 0.4)),\n", 117 | " 'eyebrow2': pred_type(slice(22, 27), (1.0, 
0.498, 0.055, 0.4)),\n", 118 | " 'nose': pred_type(slice(27, 31), (0.345, 0.239, 0.443, 0.4)),\n", 119 | " 'nostril': pred_type(slice(31, 36), (0.345, 0.239, 0.443, 0.4)),\n", 120 | " 'eye1': pred_type(slice(36, 42), (0.596, 0.875, 0.541, 0.3)),\n", 121 | " 'eye2': pred_type(slice(42, 48), (0.596, 0.875, 0.541, 0.3)),\n", 122 | " 'lips': pred_type(slice(48, 60), (0.596, 0.875, 0.541, 0.3)),\n", 123 | " 'teeth': pred_type(slice(60, 68), (0.596, 0.875, 0.541, 0.4))\n", 124 | " }\n", 125 | "\n", 126 | " fig = plt.figure(figsize=fig_size)\n", 127 | " ax = fig.add_subplot(1, 1, 1)\n", 128 | " ax.imshow(img)\n", 129 | " ax.axis('off')\n", 130 | "\n", 131 | " for face in landmarks:\n", 132 | " for pred_type in pred_types.values():\n", 133 | " ax.plot(\n", 134 | " face[pred_type.slice, 0],\n", 135 | " face[pred_type.slice, 1],\n", 136 | " color=pred_type.color, **plot_style\n", 137 | " )\n", 138 | " plt.show()\n", 139 | "\n", 140 | "\n", 141 | "\n", 142 | "# https://github.com/NVlabs/ffhq-dataset/blob/master/download_ffhq.py\n", 143 | "\n", 144 | "import PIL.Image\n", 145 | "import PIL.ImageFile\n", 146 | "import numpy as np\n", 147 | "import scipy.ndimage\n", 148 | "\n", 149 | "\n", 150 | "def align_and_crop_face(\n", 151 | " img: Image.Image,\n", 152 | " landmarks: np.ndarray,\n", 153 | " expand: float = 1.0,\n", 154 | " output_size: int = 1024, \n", 155 | " transform_size: int = 4096,\n", 156 | " enable_padding: bool = True,\n", 157 | "):\n", 158 | " # Parse landmarks.\n", 159 | " # pylint: disable=unused-variable\n", 160 | " lm = landmarks\n", 161 | " lm_chin = lm[0 : 17] # left-right\n", 162 | " lm_eyebrow_left = lm[17 : 22] # left-right\n", 163 | " lm_eyebrow_right = lm[22 : 27] # left-right\n", 164 | " lm_nose = lm[27 : 31] # top-down\n", 165 | " lm_nostrils = lm[31 : 36] # top-down\n", 166 | " lm_eye_left = lm[36 : 42] # left-clockwise\n", 167 | " lm_eye_right = lm[42 : 48] # left-clockwise\n", 168 | " lm_mouth_outer = lm[48 : 60] # left-clockwise\n", 169 | " lm_mouth_inner = lm[60 : 68] # left-clockwise\n", 170 | "\n", 171 | " # Calculate auxiliary vectors.\n", 172 | " eye_left = np.mean(lm_eye_left, axis=0)\n", 173 | " eye_right = np.mean(lm_eye_right, axis=0)\n", 174 | " eye_avg = (eye_left + eye_right) * 0.5\n", 175 | " eye_to_eye = eye_right - eye_left\n", 176 | " mouth_left = lm_mouth_outer[0]\n", 177 | " mouth_right = lm_mouth_outer[6]\n", 178 | " mouth_avg = (mouth_left + mouth_right) * 0.5\n", 179 | " eye_to_mouth = mouth_avg - eye_avg\n", 180 | "\n", 181 | " # Choose oriented crop rectangle.\n", 182 | " x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]\n", 183 | " x /= np.hypot(*x)\n", 184 | " x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)\n", 185 | " x *= expand\n", 186 | " y = np.flipud(x) * [-1, 1]\n", 187 | " c = eye_avg + eye_to_mouth * 0.1\n", 188 | " quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])\n", 189 | " qsize = np.hypot(*x) * 2\n", 190 | "\n", 191 | " # Shrink.\n", 192 | " shrink = int(np.floor(qsize / output_size * 0.5))\n", 193 | " if shrink > 1:\n", 194 | " rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))\n", 195 | " img = img.resize(rsize, PIL.Image.ANTIALIAS)\n", 196 | " quad /= shrink\n", 197 | " qsize /= shrink\n", 198 | "\n", 199 | " # Crop.\n", 200 | " border = max(int(np.rint(qsize * 0.1)), 3)\n", 201 | " crop = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))\n", 202 | " crop = (max(crop[0] - border, 
0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), min(crop[3] + border, img.size[1]))\n", 203 | " if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:\n", 204 | " img = img.crop(crop)\n", 205 | " quad -= crop[0:2]\n", 206 | "\n", 207 | " # Pad.\n", 208 | " pad = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))\n", 209 | " pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), max(pad[3] - img.size[1] + border, 0))\n", 210 | " if enable_padding and max(pad) > border - 4:\n", 211 | " pad = np.maximum(pad, int(np.rint(qsize * 0.3)))\n", 212 | " img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')\n", 213 | " h, w, _ = img.shape\n", 214 | " y, x, _ = np.ogrid[:h, :w, :1]\n", 215 | " mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w-1-x) / pad[2]), 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h-1-y) / pad[3]))\n", 216 | " blur = qsize * 0.02\n", 217 | " img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)\n", 218 | " img += (np.median(img, axis=(0,1)) - img) * np.clip(mask, 0.0, 1.0)\n", 219 | " img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')\n", 220 | " quad += pad[:2]\n", 221 | "\n", 222 | " # Transform.\n", 223 | " img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR)\n", 224 | " if output_size < transform_size:\n", 225 | " img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS)\n", 226 | "\n", 227 | " return img\n", 228 | "\n", 229 | "\n", 230 | "# define display function\n", 231 | "import matplotlib.pyplot as plt\n", 232 | "from PIL import Image\n", 233 | "import os\n", 234 | "import numpy as np\n", 235 | "\n", 236 | "def display_pic(folder):\n", 237 | " fig = plt.figure(figsize=(30, 40))\n", 238 | " files = os.listdir(folder)\n", 239 | " files.sort()\n", 240 | " for i, file in enumerate(files):\n", 241 | " if file == '.ipynb_checkpoints':\n", 242 | " continue \n", 243 | " img = Image.open(folder+'/'+file) \n", 244 | " images = np.asarray(img)\n", 245 | " ax = fig.add_subplot(10, 10, i+1, xticks=[], yticks=[])\n", 246 | " image_plt = np.array(images)\n", 247 | " ax.imshow(image_plt)\n", 248 | " ax.set_xlabel(file, fontsize=15) \n", 249 | " plt.show()\n", 250 | " plt.close()\n", 251 | "\n", 252 | "# サンプルデータをダウンロード\n", 253 | "! pip install --upgrade gdown\n", 254 | "import gdown\n", 255 | "gdown.download('https://drive.google.com/uc?id=1CDSfi5jZ_uqOYFZqe8n_CGkZc-XhyoJ6', 'sample.zip', quiet=False)\n", 256 | "! 
unzip sample.zip" 257 | ], 258 | "id": "3f40d528", 259 | "execution_count": null, 260 | "outputs": [] 261 | }, 262 | { 263 | "cell_type": "markdown", 264 | "metadata": { 265 | "id": "3gjnXo2V6e1-" 266 | }, 267 | "source": [ 268 | "**自分の画像を使う場合はpicフォルダーにアップロード**" 269 | ], 270 | "id": "3gjnXo2V6e1-" 271 | }, 272 | { 273 | "cell_type": "code", 274 | "metadata": { 275 | "cellView": "form", 276 | "id": "u0AP5LJ_X4Iz" 277 | }, 278 | "source": [ 279 | "#@title サンプル画像表示\n", 280 | "display_pic('pic')" 281 | ], 282 | "id": "u0AP5LJ_X4Iz", 283 | "execution_count": null, 284 | "outputs": [] 285 | }, 286 | { 287 | "cell_type": "markdown", 288 | "metadata": { 289 | "id": "whGJAsWY84FY" 290 | }, 291 | "source": [ 292 | "**side_by_sideのチェックを外すとアニメ単体になる**" 293 | ], 294 | "id": "whGJAsWY84FY" 295 | }, 296 | { 297 | "cell_type": "code", 298 | "metadata": { 299 | "id": "039c54ca", 300 | "cellView": "form" 301 | }, 302 | "source": [ 303 | "#@title 画像をアニメへ変換\n", 304 | "import cv2\n", 305 | "input = '004.jpg' #@param {type:\"string\"}\n", 306 | "side_by_side = True #@param {type:\"boolean\"}\n", 307 | "img = Image.open('pic/'+input).convert(\"RGB\")\n", 308 | "\n", 309 | "face_detector = get_dlib_face_detector()\n", 310 | "landmarks = face_detector(img)\n", 311 | "\n", 312 | "face = align_and_crop_face(img, landmarks[0], expand=1.3)\n", 313 | "output = face2paint(model=model, img=face, size=512, side_by_side=side_by_side)\n", 314 | "output.save('output.jpg')\n", 315 | "display(output)\n" 316 | ], 317 | "id": "039c54ca", 318 | "execution_count": null, 319 | "outputs": [] 320 | }, 321 | { 322 | "cell_type": "code", 323 | "metadata": { 324 | "cellView": "form", 325 | "id": "KJmn6B7aflqc" 326 | }, 327 | "source": [ 328 | "#@title アニメのダウンロード\n", 329 | "from google.colab import files\n", 330 | "files.download('output.jpg')" 331 | ], 332 | "id": "KJmn6B7aflqc", 333 | "execution_count": null, 334 | "outputs": [] 335 | }, 336 | { 337 | "cell_type": "markdown", 338 | "metadata": { 339 | "id": "P9nQAyna8xKz" 340 | }, 341 | "source": [ 342 | "#-------------------------------------------------#" 343 | ], 344 | "id": "P9nQAyna8xKz" 345 | }, 346 | { 347 | "cell_type": "markdown", 348 | "metadata": { 349 | "id": "F94-QDtJ6Snn" 350 | }, 351 | "source": [ 352 | "**自分の動画を使う場合はvideoフォルダーにアップロード**\\\n", 353 | "**動画はRVM処理するのがおすすめ(下記リンク参照)**\\\n", 354 | "http://cedro3.com/ai/rvm/" 355 | ], 356 | "id": "F94-QDtJ6Snn" 357 | }, 358 | { 359 | "cell_type": "code", 360 | "metadata": { 361 | "cellView": "form", 362 | "id": "Ej0JC8Hp32n-" 363 | }, 364 | "source": [ 365 | "#@title サンプル動画の再生\n", 366 | "from IPython.display import HTML\n", 367 | "from base64 import b64encode\n", 368 | "\n", 369 | "mp4 = open('./video/mark.mp4', 'rb').read()\n", 370 | "data_url = 'data:video/mp4;base64,' + b64encode(mp4).decode()\n", 371 | "HTML(f\"\"\"\n", 372 | "\"\"\")" 375 | ], 376 | "id": "Ej0JC8Hp32n-", 377 | "execution_count": null, 378 | "outputs": [] 379 | }, 380 | { 381 | "cell_type": "code", 382 | "metadata": { 383 | "id": "6l9YE6EKJSsS", 384 | "cellView": "form" 385 | }, 386 | "source": [ 387 | "#@title 動画をフレームにバラす\n", 388 | "video_name = 'mark.mp4' #@param {type:\"string\"}\n", 389 | "video_file = 'video/'+video_name\n", 390 | "\n", 391 | "import os\n", 392 | "import shutil\n", 393 | "import cv2\n", 394 | "\n", 395 | "# input.mp4にコピー\n", 396 | "shutil.copy(video_file, 'input.mp4')\n", 397 | "\n", 398 | "# flamesフォルダーリセット\n", 399 | "if os.path.isdir('flames'):\n", 400 | " shutil.rmtree('flames')\n", 401 | "os.makedirs('flames', exist_ok=True)\n", 402 
| " \n", 403 | "def video_2_images(video_file= video_file, # ビデオの指定\n", 404 | " image_dir='./flames/', \n", 405 | " image_file='%s.jpg'): \n", 406 | " \n", 407 | " shutil.copy(video_file, 'input.mp4') ####\n", 408 | "\n", 409 | " # Initial setting\n", 410 | " i = 0\n", 411 | " interval = 3\n", 412 | " length = 100 # 最大フレーム数\n", 413 | " \n", 414 | " cap = cv2.VideoCapture(video_file)\n", 415 | " fps = cap.get(cv2.CAP_PROP_FPS) # fps取得\n", 416 | "\n", 417 | " while(cap.isOpened()):\n", 418 | " flag, frame = cap.read() \n", 419 | " if flag == False: \n", 420 | " break\n", 421 | " if i == length*interval:\n", 422 | " break\n", 423 | " if i % interval == 0: \n", 424 | " cv2.imwrite(image_dir+image_file % str(int(i/interval)).zfill(6), frame)\n", 425 | " i += 1 \n", 426 | " cap.release()\n", 427 | " return fps, i, interval\n", 428 | " \n", 429 | "fps, i, interval = video_2_images()\n", 430 | "print('fps = ', fps)\n", 431 | "print('flames = ', i)\n", 432 | "print('interval = ', interval)\n", 433 | " " 434 | ], 435 | "id": "6l9YE6EKJSsS", 436 | "execution_count": null, 437 | "outputs": [] 438 | }, 439 | { 440 | "cell_type": "code", 441 | "metadata": { 442 | "id": "6q8O2V-WPFuI", 443 | "cellView": "form" 444 | }, 445 | "source": [ 446 | "#@title フレームをアニメに変換\n", 447 | "import os\n", 448 | "import shutil\n", 449 | "import cv2\n", 450 | "\n", 451 | "# images folder reset\n", 452 | "if os.path.isdir('out'):\n", 453 | " shutil.rmtree('out')\n", 454 | "os.makedirs('out', exist_ok=True)\n", 455 | "\n", 456 | "import glob\n", 457 | "files = glob.glob('flames/*.jpg')\n", 458 | "files.sort()\n", 459 | "\n", 460 | "face_detector = get_dlib_face_detector()\n", 461 | "\n", 462 | "from tqdm import tqdm\n", 463 | "for i, file in enumerate(tqdm(files)):\n", 464 | " img = Image.open(file).convert(\"RGB\")\n", 465 | " #face_detector = get_dlib_face_detector()\n", 466 | " landmarks = face_detector(img)\n", 467 | "\n", 468 | " face = align_and_crop_face(img, landmarks[0], expand=1.3)\n", 469 | " output = face2paint(model=model, img=face, size=512)\n", 470 | " output.save('out/'+str(i).zfill(6)+'.jpg')" 471 | ], 472 | "id": "6q8O2V-WPFuI", 473 | "execution_count": null, 474 | "outputs": [] 475 | }, 476 | { 477 | "cell_type": "code", 478 | "metadata": { 479 | "id": "rczOn-f5SIUv", 480 | "cellView": "form" 481 | }, 482 | "source": [ 483 | "#@title アニメから動画を作成\n", 484 | "# リセットファイル\n", 485 | "if os.path.exists('./output.mp4'):\n", 486 | " os.remove('./output.mp4')\n", 487 | "\n", 488 | "speed = fps/interval\n", 489 | "\n", 490 | "# アニメ画をmp4動画(output.mp4)に変換する\n", 491 | "! 
ffmpeg -r $speed -i out/%06d.jpg -vcodec libx264 -pix_fmt yuv420p -loglevel error output.mp4" 492 | ], 493 | "id": "rczOn-f5SIUv", 494 | "execution_count": null, 495 | "outputs": [] 496 | }, 497 | { 498 | "cell_type": "code", 499 | "metadata": { 500 | "cellView": "form", 501 | "id": "4egpzmhh4fiB" 502 | }, 503 | "source": [ 504 | "#@title 動画の再生\n", 505 | "from IPython.display import HTML\n", 506 | "from base64 import b64encode\n", 507 | "\n", 508 | "mp4 = open('./output.mp4', 'rb').read()\n", 509 | "data_url = 'data:video/mp4;base64,' + b64encode(mp4).decode()\n", 510 | "HTML(f\"\"\"\n", 511 | "\"\"\")" 514 | ], 515 | "id": "4egpzmhh4fiB", 516 | "execution_count": null, 517 | "outputs": [] 518 | }, 519 | { 520 | "cell_type": "code", 521 | "metadata": { 522 | "cellView": "form", 523 | "id": "szYJa76p48az" 524 | }, 525 | "source": [ 526 | "#@title 動画のダウンロード\n", 527 | "from google.colab import files\n", 528 | "files.download('output.mp4')" 529 | ], 530 | "id": "szYJa76p48az", 531 | "execution_count": null, 532 | "outputs": [] 533 | } 534 | ] 535 | } 536 | -------------------------------------------------------------------------------- /stylegan_nada.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "accelerator": "GPU", 6 | "colab": { 7 | "name": "stylegan_nada", 8 | "provenance": [], 9 | "collapsed_sections": [], 10 | "include_colab_link": true 11 | }, 12 | "interpreter": { 13 | "hash": "fd69f43f58546b570e94fd7eba7b65e6bcc7a5bbc4eab0408017d18902915d69" 14 | }, 15 | "kernelspec": { 16 | "display_name": "Python 3.7.5 64-bit", 17 | "name": "python3" 18 | }, 19 | "language_info": { 20 | "name": "python", 21 | "version": "" 22 | } 23 | }, 24 | "cells": [ 25 | { 26 | "cell_type": "markdown", 27 | "metadata": { 28 | "id": "view-in-github", 29 | "colab_type": "text" 30 | }, 31 | "source": [ 32 | "\"Open" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "metadata": { 38 | "id": "bYsd0_RFXb04" 39 | }, 40 | "source": [ 41 | "# Welcome to StyleGAN-NADA: CLIP-Guided Domain Adaptation of Image Generators!" 42 | ] 43 | }, 44 | { 45 | "cell_type": "markdown", 46 | "metadata": { 47 | "id": "QTHeOO8qFw_e" 48 | }, 49 | "source": [ 50 | "# Step 1: Setup required libraries and models. \n", 51 | "This may take a few minutes.\n", 52 | "\n", 53 | "You may optionally enable downloads with pydrive in order to authenticate and avoid drive download limits when fetching pre-trained ReStyle and StyleGAN2 models." 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "metadata": { 59 | "id": "ph3R7lbl_arQ", 60 | "cellView": "form" 61 | }, 62 | "source": [ 63 | "#@title Setup\n", 64 | "%tensorflow_version 1.x\n", 65 | "! 
pip install --upgrade gdown\n", 66 | "\n", 67 | "import os\n", 68 | "\n", 69 | "from pydrive.auth import GoogleAuth\n", 70 | "from pydrive.drive import GoogleDrive\n", 71 | "from google.colab import auth\n", 72 | "from oauth2client.client import GoogleCredentials\n", 73 | "\n", 74 | "pretrained_model_dir = os.path.join(\"/content\", \"models\")\n", 75 | "os.makedirs(pretrained_model_dir, exist_ok=True)\n", 76 | "\n", 77 | "restyle_dir = os.path.join(\"/content\", \"restyle\")\n", 78 | "stylegan_ada_dir = os.path.join(\"/content\", \"stylegan_ada\")\n", 79 | "stylegan_nada_dir = os.path.join(\"/content\", \"stylegan_nada\")\n", 80 | "\n", 81 | "output_dir = os.path.join(\"/content\", \"output\")\n", 82 | "\n", 83 | "output_model_dir = os.path.join(output_dir, \"models\")\n", 84 | "output_image_dir = os.path.join(output_dir, \"images\")\n", 85 | "\n", 86 | "download_with_pydrive = False #@param {type:\"boolean\"} \n", 87 | " \n", 88 | "class Downloader(object):\n", 89 | " def __init__(self, use_pydrive):\n", 90 | " self.use_pydrive = use_pydrive\n", 91 | "\n", 92 | " if self.use_pydrive:\n", 93 | " self.authenticate()\n", 94 | " \n", 95 | " def authenticate(self):\n", 96 | " auth.authenticate_user()\n", 97 | " gauth = GoogleAuth()\n", 98 | " gauth.credentials = GoogleCredentials.get_application_default()\n", 99 | " self.drive = GoogleDrive(gauth)\n", 100 | " \n", 101 | " def download_file(self, file_id, file_dst):\n", 102 | " if self.use_pydrive:\n", 103 | " downloaded = self.drive.CreateFile({'id':file_id})\n", 104 | " downloaded.FetchMetadata(fetch_all=True)\n", 105 | " downloaded.GetContentFile(file_dst)\n", 106 | " else:\n", 107 | " !gdown --id $file_id -O $file_dst\n", 108 | "\n", 109 | "downloader = Downloader(download_with_pydrive)\n", 110 | "\n", 111 | "# install requirements\n", 112 | "!git clone https://github.com/yuval-alaluf/restyle-encoder.git $restyle_dir\n", 113 | "\n", 114 | "!wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip\n", 115 | "!sudo unzip ninja-linux.zip -d /usr/local/bin/\n", 116 | "!sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force\n", 117 | "\n", 118 | "!pip install ftfy regex tqdm \n", 119 | "!pip install git+https://github.com/openai/CLIP.git\n", 120 | "\n", 121 | "!git clone https://github.com/NVlabs/stylegan2-ada/ $stylegan_ada_dir\n", 122 | "!git clone https://github.com/rinongal/stylegan-nada.git $stylegan_nada_dir\n", 123 | "\n", 124 | "from argparse import Namespace\n", 125 | "\n", 126 | "import sys\n", 127 | "import numpy as np\n", 128 | "\n", 129 | "from PIL import Image\n", 130 | "\n", 131 | "import torch\n", 132 | "import torchvision.transforms as transforms\n", 133 | "\n", 134 | "sys.path.append(restyle_dir)\n", 135 | "sys.path.append(stylegan_nada_dir)\n", 136 | "sys.path.append(os.path.join(stylegan_nada_dir, \"ZSSGAN\"))\n", 137 | "\n", 138 | "device = 'cuda'\n", 139 | "\n", 140 | "%load_ext autoreload\n", 141 | "%autoreload 2\n", 142 | "\n", 143 | "import gdown\n", 144 | "gdown.download('https://drive.google.com/uc?id=1DiC2AZRt7GDSnLsE--FOqOTnMWwV6mTr', './sample.zip', quiet=False)\n", 145 | "! 
unzip sample.zip" 146 | ], 147 | "execution_count": null, 148 | "outputs": [] 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "metadata": { 153 | "id": "kSL166pfGRWF" 154 | }, 155 | "source": [ 156 | "# Step 2: Choose a model type.\n", 157 | "Model will be downloaded and converted to a pytorch compatible version.\n", 158 | "\n", 159 | "Re-runs of the cell with the same model will re-use the previously downloaded version. Feel free to experiment and come back to previous models :)" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "metadata": { 165 | "id": "J4ATNsC1k28g", 166 | "cellView": "form" 167 | }, 168 | "source": [ 169 | "source_model_type = 'ffhq' #@param['ffhq', 'cat', 'dog', 'church', 'horse', 'car']\n", 170 | "\n", 171 | "source_model_download_path = {\"ffhq\": \"https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/ffhq.pkl\",\n", 172 | " \"cat\": \"https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/afhqcat.pkl\",\n", 173 | " \"dog\": \"https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/afhqdog.pkl\",\n", 174 | " \"church\": \"1iDo5cUgbwsJEt2uwfgDy_iPlaT-lLZmi\",\n", 175 | " \"car\": \"1i-39ztut-VdUVUiFuUrwdsItR--HF81w\",\n", 176 | " \"horse\": \"1irwWI291DolZhnQeW-ZyNWqZBjlWyJUn\"}\n", 177 | "\n", 178 | "model_names = {\"ffhq\": \"ffhq.pkl\",\n", 179 | " \"cat\": \"afhqcat.pkl\",\n", 180 | " \"dog\": \"afhqdog.pkl\",\n", 181 | " \"church\": \"stylegan2-church-config-f.pkl\",\n", 182 | " \"car\": \"stylegan2-car-config-f.pkl\",\n", 183 | " \"horse\": \"stylegan2-horse-config-f.pkl\"}\n", 184 | "\n", 185 | "download_string = source_model_download_path[source_model_type]\n", 186 | "file_name = model_names[source_model_type]\n", 187 | "pt_file_name = file_name.split(\".\")[0] + \".pt\"\n", 188 | "\n", 189 | "dataset_sizes = {\n", 190 | " \"ffhq\": 1024,\n", 191 | " \"cat\": 512,\n", 192 | " \"dog\": 512,\n", 193 | " \"church\": 256,\n", 194 | " \"horse\": 256,\n", 195 | " \"car\": 512,\n", 196 | "}\n", 197 | "\n", 198 | "if not os.path.isfile(os.path.join(pretrained_model_dir, file_name)):\n", 199 | " print(\"Downloading chosen model...\")\n", 200 | "\n", 201 | " if download_string.endswith(\".pkl\"):\n", 202 | " !wget $download_string -O $pretrained_model_dir/$file_name\n", 203 | " else:\n", 204 | " downloader.download_file(download_string, os.path.join(pretrained_model_dir, file_name))\n", 205 | " \n", 206 | "if not os.path.isfile(os.path.join(pretrained_model_dir, pt_file_name)):\n", 207 | " print(\"Converting sg2 model. This may take a few minutes...\")\n", 208 | " \n", 209 | " tf_path = next(filter(lambda x: \"tensorflow\" in x, sys.path), None)\n", 210 | " py_path = tf_path + f\":{stylegan_nada_dir}/ZSSGAN\"\n", 211 | " convert_script = os.path.join(stylegan_nada_dir, \"convert_weight.py\")\n", 212 | " !PYTHONPATH=$py_path python $convert_script --repo $stylegan_ada_dir --gen $pretrained_model_dir/$file_name" 213 | ], 214 | "execution_count": null, 215 | "outputs": [] 216 | }, 217 | { 218 | "cell_type": "markdown", 219 | "metadata": { 220 | "id": "DAri8ULOG2VE" 221 | }, 222 | "source": [ 223 | "# Step 3: Train the model.\n", 224 | "Describe your source and target class. These describe the direction of change you're trying to apply (e.g. \"photo\" to \"sketch\", \"dog\" to \"the joker\" or \"dog\" to \"avocado dog\").\n", 225 | "\n", 226 | "Alternatively, upload a directory with a small (~3) set of target style images (there is no need to preprocess them in any way) and set `style_image_dir` to point at them. 
This will use the images as a target rather than the source/class texts.\n", 227 | "\n", 228 | "We recommend leaving the 'improve shape' button unticked at first, as it increases running time and is often not needed.\n", 229 | "For more drastic changes, turn it on and increase the number of iterations.\n", 230 | "\n", 231 | "As a rule of thumb:\n", 232 | "- Style and minor domain changes ('photo' -> 'sketch') require ~200-400 iterations.\n", 233 | "- Identity changes ('person' -> 'taylor swift') require ~150-200 iterations.\n", 234 | "- Simple in-domain changes ('face' -> 'smiling face') may require as few as 50.\n", 235 | "- The `style_image_dir` option often requires ~400-600 iterations.\n", 236 | "\n", 237 | "> Updates:
\n", 238 | "> 03/10 - Added support for style image targets.
\n", 239 | "> 03/08 - Added support for saving model checkpoints. If you want to save, set save_interval > 0.\n" 240 | ] 241 | }, 242 | { 243 | "cell_type": "code", 244 | "metadata": { 245 | "id": "8YrtPb7KF8m-", 246 | "cellView": "form" 247 | }, 248 | "source": [ 249 | "from ZSSGAN.model.ZSSGAN import ZSSGAN\n", 250 | "\n", 251 | "import numpy as np\n", 252 | "\n", 253 | "import torch\n", 254 | "\n", 255 | "from tqdm import notebook\n", 256 | "\n", 257 | "from ZSSGAN.utils.file_utils import save_images, get_dir_img_list\n", 258 | "from ZSSGAN.utils.training_utils import mixing_noise\n", 259 | "\n", 260 | "from IPython.display import display\n", 261 | "\n", 262 | "source_class = \"Photo\" #@param {\"type\": \"string\"}\n", 263 | "target_class = \"Sketch\" #@param {\"type\": \"string\"}\n", 264 | "\n", 265 | "style_image_dir = \"\" #@param {'type': 'string'}\n", 266 | "\n", 267 | "target_img_list = get_dir_img_list(style_image_dir) if style_image_dir else None\n", 268 | "\n", 269 | "improve_shape = False #@param{type:\"boolean\"}\n", 270 | "\n", 271 | "model_choice = [\"ViT-B/32\", \"ViT-B/16\"]\n", 272 | "model_weights = [1.0, 0.0]\n", 273 | "\n", 274 | "if improve_shape or style_image_dir:\n", 275 | " model_weights[1] = 1.0\n", 276 | " \n", 277 | "mixing = 0.9 if improve_shape else 0.0\n", 278 | "\n", 279 | "auto_layers_k = int(2 * (2 * np.log2(dataset_sizes[source_model_type]) - 2) / 3) if improve_shape else 0\n", 280 | "auto_layer_iters = 1 if improve_shape else 0\n", 281 | "\n", 282 | "training_iterations = 151 #@param {type: \"integer\"}\n", 283 | "output_interval = 50 #@param {type: \"integer\"}\n", 284 | "save_interval = 0 #@param {type: \"integer\"}\n", 285 | "\n", 286 | "training_args = {\n", 287 | " \"size\": dataset_sizes[source_model_type],\n", 288 | " \"batch\": 2,\n", 289 | " \"n_sample\": 4,\n", 290 | " \"output_dir\": output_dir,\n", 291 | " \"lr\": 0.002,\n", 292 | " \"frozen_gen_ckpt\": os.path.join(pretrained_model_dir, pt_file_name),\n", 293 | " \"train_gen_ckpt\": os.path.join(pretrained_model_dir, pt_file_name),\n", 294 | " \"iter\": training_iterations,\n", 295 | " \"source_class\": source_class,\n", 296 | " \"target_class\": target_class,\n", 297 | " \"lambda_direction\": 1.0,\n", 298 | " \"lambda_patch\": 0.0,\n", 299 | " \"lambda_global\": 0.0,\n", 300 | " \"lambda_texture\": 0.0,\n", 301 | " \"lambda_manifold\": 0.0,\n", 302 | " \"auto_layer_k\": auto_layers_k,\n", 303 | " \"auto_layer_iters\": auto_layer_iters,\n", 304 | " \"auto_layer_batch\": 8,\n", 305 | " \"output_interval\": 50,\n", 306 | " \"clip_models\": model_choice,\n", 307 | " \"clip_model_weights\": model_weights,\n", 308 | " \"mixing\": mixing,\n", 309 | " \"phase\": None,\n", 310 | " \"sample_truncation\": 0.7,\n", 311 | " \"save_interval\": save_interval,\n", 312 | " \"target_img_list\": target_img_list,\n", 313 | " \"img2img_batch\": 16,\n", 314 | "}\n", 315 | "\n", 316 | "args = Namespace(**training_args)\n", 317 | "\n", 318 | "print(\"Loading base models...\")\n", 319 | "net = ZSSGAN(args)\n", 320 | "print(\"Models loaded! 
Starting training...\")\n", 321 | "\n", 322 | "g_reg_ratio = 4 / 5\n", 323 | "\n", 324 | "g_optim = torch.optim.Adam(\n", 325 | " net.generator_trainable.parameters(),\n", 326 | " lr=args.lr * g_reg_ratio,\n", 327 | " betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio),\n", 328 | ")\n", 329 | "\n", 330 | "# Set up output directories.\n", 331 | "sample_dir = os.path.join(args.output_dir, \"sample\")\n", 332 | "ckpt_dir = os.path.join(args.output_dir, \"checkpoint\")\n", 333 | "\n", 334 | "os.makedirs(sample_dir, exist_ok=True)\n", 335 | "os.makedirs(ckpt_dir, exist_ok=True)\n", 336 | "\n", 337 | "seed = 3 #@param {\"type\": \"integer\"}\n", 338 | "\n", 339 | "torch.manual_seed(seed)\n", 340 | "np.random.seed(seed)\n", 341 | "\n", 342 | "# Training loop\n", 343 | "fixed_z = torch.randn(args.n_sample, 512, device=device)\n", 344 | "\n", 345 | "for i in notebook.tqdm(range(args.iter)):\n", 346 | " net.train()\n", 347 | " \n", 348 | " sample_z = mixing_noise(args.batch, 512, args.mixing, device)\n", 349 | "\n", 350 | " [sampled_src, sampled_dst], clip_loss = net(sample_z)\n", 351 | "\n", 352 | " net.zero_grad()\n", 353 | " clip_loss.backward()\n", 354 | "\n", 355 | " g_optim.step()\n", 356 | "\n", 357 | " if i % output_interval == 0:\n", 358 | " net.eval()\n", 359 | "\n", 360 | " with torch.no_grad():\n", 361 | " [sampled_src, sampled_dst], loss = net([fixed_z], truncation=args.sample_truncation)\n", 362 | "\n", 363 | " if source_model_type == 'car':\n", 364 | " sampled_dst = sampled_dst[:, :, 64:448, :]\n", 365 | "\n", 366 | " grid_rows = 4\n", 367 | "\n", 368 | " save_images(sampled_dst, sample_dir, \"dst\", grid_rows, i)\n", 369 | "\n", 370 | " img = Image.open(os.path.join(sample_dir, f\"dst_{str(i).zfill(6)}.jpg\")).resize((1024, 256))\n", 371 | " display(img)\n", 372 | " \n", 373 | " if (args.save_interval > 0) and (i > 0) and (i % args.save_interval == 0):\n", 374 | " torch.save(\n", 375 | " {\n", 376 | " \"g_ema\": net.generator_trainable.generator.state_dict(),\n", 377 | " \"g_optim\": g_optim.state_dict(),\n", 378 | " },\n", 379 | " f\"{ckpt_dir}/{str(i).zfill(6)}.pt\",\n", 380 | " )" 381 | ], 382 | "execution_count": null, 383 | "outputs": [] 384 | }, 385 | { 386 | "cell_type": "markdown", 387 | "metadata": { 388 | "id": "9ZZk6yZQvxGY" 389 | }, 390 | "source": [ 391 | "# Step 4: Generate samples with the new model" 392 | ] 393 | }, 394 | { 395 | "cell_type": "code", 396 | "metadata": { 397 | "id": "dLinyTgev5Qk", 398 | "cellView": "form" 399 | }, 400 | "source": [ 401 | "truncation = 0.7 #@param {type:\"slider\", min:0, max:1, step:0.05}\n", 402 | "\n", 403 | "samples = 9\n", 404 | "\n", 405 | "with torch.no_grad():\n", 406 | " net.eval()\n", 407 | " sample_z = torch.randn(samples, 512, device=device)\n", 408 | "\n", 409 | " [sampled_src, sampled_dst], loss = net([sample_z], truncation=truncation)\n", 410 | "\n", 411 | " if source_model_type == 'car':\n", 412 | " sampled_dst = sampled_dst[:, :, 64:448, :]\n", 413 | "\n", 414 | " grid_rows = int(samples ** 0.5)\n", 415 | "\n", 416 | " save_images(sampled_dst, sample_dir, \"sampled\", grid_rows, 0)\n", 417 | "\n", 418 | " display(Image.open(os.path.join(sample_dir, f\"sampled_{str(0).zfill(6)}.jpg\")).resize((768, 768)))" 419 | ], 420 | "execution_count": null, 421 | "outputs": [] 422 | }, 423 | { 424 | "cell_type": "markdown", 425 | "metadata": { 426 | "id": "e4hVHBrlGxzo" 427 | }, 428 | "source": [ 429 | "## Editing a real image with Re-Style inversion (currently only FFHQ inversion is supported):" 430 | ] 431 | }, 432 | { 433 | 
"cell_type": "markdown", 434 | "metadata": { 435 | "id": "-He6svz1qami" 436 | }, 437 | "source": [ 438 | "Step 1: Set up Re-Style.\n", 439 | "\n", 440 | "This may take a few minutes" 441 | ] 442 | }, 443 | { 444 | "cell_type": "code", 445 | "metadata": { 446 | "cellView": "form", 447 | "id": "R9IMgvRwqcGb" 448 | }, 449 | "source": [ 450 | "#@title Set up Re-Style\n", 451 | "from restyle.utils.common import tensor2im\n", 452 | "from restyle.models.psp import pSp\n", 453 | "from restyle.models.e4e import e4e\n", 454 | "\n", 455 | "downloader.download_file(\"1sw6I2lRIB0MpuJkpc8F5BJiSZrc0hjfE\", os.path.join(pretrained_model_dir, \"restyle_psp_ffhq_encode.pt\"))\n", 456 | "downloader.download_file(\"1e2oXVeBPXMQoUoC_4TNwAWpOPpSEhE_e\", os.path.join(pretrained_model_dir, \"restyle_e4e_ffhq_encode.pt\"))" 457 | ], 458 | "execution_count": null, 459 | "outputs": [] 460 | }, 461 | { 462 | "cell_type": "markdown", 463 | "metadata": { 464 | "id": "azLoQ61JqkyH" 465 | }, 466 | "source": [ 467 | "Step 2: Choose a re-style model\n", 468 | "\n", 469 | "We reccomend choosing the e4e model as it performs better under domain translations. Choose pSp for better reconstructions on minor domain changes (typically those that require less than 150 training steps)." 470 | ] 471 | }, 472 | { 473 | "cell_type": "code", 474 | "metadata": { 475 | "cellView": "form", 476 | "id": "YEUaiEL2qn9g" 477 | }, 478 | "source": [ 479 | "encoder_type = 'psp' #@param['psp', 'e4e']\n", 480 | "\n", 481 | "restyle_experiment_args = {\n", 482 | " \"model_path\": os.path.join(pretrained_model_dir, f\"restyle_{encoder_type}_ffhq_encode.pt\"),\n", 483 | " \"transform\": transforms.Compose([\n", 484 | " transforms.Resize((256, 256)),\n", 485 | " transforms.ToTensor(),\n", 486 | " transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])\n", 487 | "}\n", 488 | "\n", 489 | "model_path = restyle_experiment_args['model_path']\n", 490 | "ckpt = torch.load(model_path, map_location='cpu')\n", 491 | "\n", 492 | "opts = ckpt['opts']\n", 493 | "\n", 494 | "opts['checkpoint_path'] = model_path\n", 495 | "opts = Namespace(**opts)\n", 496 | "\n", 497 | "restyle_net = (pSp if encoder_type == 'psp' else e4e)(opts)\n", 498 | "\n", 499 | "restyle_net.eval()\n", 500 | "restyle_net.cuda()\n", 501 | "print('Model successfully loaded!')" 502 | ], 503 | "execution_count": null, 504 | "outputs": [] 505 | }, 506 | { 507 | "cell_type": "markdown", 508 | "metadata": { 509 | "id": "HfB-jTnZgn0D" 510 | }, 511 | "source": [ 512 | "Step 3: Align and invert an image" 513 | ] 514 | }, 515 | { 516 | "cell_type": "code", 517 | "metadata": { 518 | "id": "2tMd5WBvE0Ol", 519 | "cellView": "form" 520 | }, 521 | "source": [ 522 | "def run_alignment(image_path):\n", 523 | " import dlib\n", 524 | " from scripts.align_faces_parallel import align_face\n", 525 | " if not os.path.exists(\"shape_predictor_68_face_landmarks.dat\"):\n", 526 | " print('Downloading files for aligning face image...')\n", 527 | " os.system('wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2')\n", 528 | " os.system('bzip2 -dk shape_predictor_68_face_landmarks.dat.bz2')\n", 529 | " print('Done.')\n", 530 | " predictor = dlib.shape_predictor(\"shape_predictor_68_face_landmarks.dat\")\n", 531 | " aligned_image = align_face(filepath=image_path, predictor=predictor) \n", 532 | " print(\"Aligned image has shape: {}\".format(aligned_image.size))\n", 533 | " return aligned_image \n", 534 | "\n", 535 | "image_path = \"/content/sample/001.jpg\" #@param {'type': 'string'}\n", 536 | "original_image = 
Image.open(image_path).convert(\"RGB\")\n", 537 | "\n", 538 | "input_image = run_alignment(image_path)\n", 539 | "\n", 540 | "display(input_image)\n", 541 | "\n", 542 | "img_transforms = restyle_experiment_args['transform']\n", 543 | "transformed_image = img_transforms(input_image)\n", 544 | "\n", 545 | "def get_avg_image(net):\n", 546 | " avg_image = net(net.latent_avg.unsqueeze(0),\n", 547 | " input_code=True,\n", 548 | " randomize_noise=False,\n", 549 | " return_latents=False,\n", 550 | " average_code=True)[0]\n", 551 | " avg_image = avg_image.to('cuda').float().detach()\n", 552 | " return avg_image\n", 553 | "\n", 554 | "opts.n_iters_per_batch = 5\n", 555 | "opts.resize_outputs = False # generate outputs at full resolution\n", 556 | "\n", 557 | "from restyle.utils.inference_utils import run_on_batch\n", 558 | "\n", 559 | "with torch.no_grad():\n", 560 | " avg_image = get_avg_image(restyle_net)\n", 561 | " result_batch, result_latents = run_on_batch(transformed_image.unsqueeze(0).cuda(), restyle_net, opts, avg_image)" 562 | ], 563 | "execution_count": null, 564 | "outputs": [] 565 | }, 566 | { 567 | "cell_type": "markdown", 568 | "metadata": { 569 | "id": "XOiIZcJUgsQS" 570 | }, 571 | "source": [ 572 | "Step 4: Convert the image to the new domain" 573 | ] 574 | }, 575 | { 576 | "cell_type": "code", 577 | "metadata": { 578 | "id": "u5JqEOMnEA_m", 579 | "cellView": "form" 580 | }, 581 | "source": [ 582 | "#@title Convert inverted image.\n", 583 | "inverted_latent = torch.Tensor(result_latents[0][4]).cuda().unsqueeze(0).unsqueeze(1)\n", 584 | "\n", 585 | "with torch.no_grad():\n", 586 | " net.eval()\n", 587 | " \n", 588 | " [sampled_src, sampled_dst] = net(inverted_latent, input_is_latent=True)[0]\n", 589 | " \n", 590 | " joined_img = torch.cat([sampled_src, sampled_dst], dim=0)\n", 591 | " save_images(joined_img, sample_dir, \"joined\", 2, 0)\n", 592 | " display(Image.open(os.path.join(sample_dir, f\"joined_{str(0).zfill(6)}.jpg\")).resize((512, 256)))" 593 | ], 594 | "execution_count": null, 595 | "outputs": [] 596 | } 597 | ] 598 | } 599 | --------------------------------------------------------------------------------