├── .gitignore ├── README.md ├── conda_env.txt ├── docker_install.md ├── img ├── stylegan1_pt_compression.jpg ├── stylegan1_tf_compression.jpg ├── stylegan2_pt_compression.jpg ├── stylegan2_tf_compression.jpg └── waifu │ ├── anime_face_v1_1.jpg │ ├── anime_face_v1_2.jpg │ ├── anime_portrait_v1.jpg │ └── anime_portrait_v2.jpg ├── network ├── common.py ├── run_pt_stylegan.py ├── stylegan1.py └── stylegan2.py ├── packaged ├── run_pt_stylegan1.py ├── run_pt_stylegan2.py ├── run_tf_stylegan1.py └── run_tf_stylegan2.py └── waifu ├── common.py ├── converter.py ├── run_pt_stylegan.py ├── stylegan1.py └── stylegan2.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.pyc 3 | .backup* 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # stylegans-pytorch 2 | ![License CC BY-NC](https://img.shields.io/badge/license-CC_BY--NC-green.svg?style=plastic) 3 | 4 | - pure pytorch python implementation 5 | - doesn't need operation implemented by cuda 6 | 7 | **やったこと** 8 | 著者実装の学習済みStyleGAN ([v1](https://github.com/NVlabs/stylegan),[v2](https://github.com/NVlabs/stylegan2))の 9 | 重みを変換してPyTorch再現実装のモデルで同じ出力を得るまで. 10 | 11 | - 著者配布学習済み重み( tensorflow )を pytorch 実装モデルで動かせるように変換する 12 | - 変換済み著者配布重みを使用して pytorch 実装モデルで画像生成 13 | 14 | > 学習済みモデルからの重みの抽出を著者コードに依存しない形で実装しようと考えたが, 15 | 配布されている学習済みモデルpickleの内部で色々と参照されているので, 16 | 再現ができたかの確認や重みの変換には著者コード(TensorFlow実装)が必須である. 17 | 18 | ## 再現結果 19 | 20 | - 著者オリジナル実装 StyleGAN1 21 | ![tf_stylegan1結果圧縮画像](./img/stylegan1_tf_compression.jpg) 22 | - 再現実装 StyleGAN1 23 | ![pt_stylegan1結果圧縮画像](./img/stylegan1_pt_compression.jpg) 24 | - 著者オリジナル実装 StyleGAN2 25 | ![tf_stylegan2結果圧縮画像](./img/stylegan2_tf_compression.jpg) 26 | - 再現実装 StyleGAN2 27 | ![pt_stylegan2結果圧縮画像](./img/stylegan2_pt_compression.jpg) 28 | 29 | ## ディレクトリ構成 30 | 31 | ### はじめに 32 | 著者配布の学習済みモデルから重みを抽出するには 33 | 著者配布コードを利用するしかない. 34 | オリジナルのリポジトリを clone して,ここにあるコードをコピーして使ってください. 35 | 依存環境が把握しやすいように, 36 | StyleGANv1 と StyleGANv2 でコードを共通化せず 37 | 全てのコードを1ファイルにまとめてあります. 38 | 同様の形式でPyTorch再現実装版も1ファイルにまとめたものを用意しておきました. 39 | PyTorch再現実装版はコードの共通化をしたものも用意してあり, 40 | StyleGANv1,StyleGANv2の違う部分だけを確認しやすいようにしておきました. 41 | 42 | **注意** : tensorflowから重みを変換するにはGPUが必要です.(4GB以上) 43 | GPUをお持ちでない方は [StyleGAN2による画像生成をCPU環境/TensorFlow.jsで動かす](https://memo.sugyan.com/entry/2020/02/06/005441) を参考に 44 | 色々と著者コードを書き換えて対応してください. 45 | 46 | ### ディレクトリの構成 47 | ``` 48 | - workdir/ 49 | - stylegans-pytorch/ 本リポジトリ 50 | - img/ : 再現結果 51 | - network/ : StyleGANの構造 (PyTorch) 52 | - waifu/ : アニメ顔生成の学習済みモデルを動かすのに利用 53 | - packaged/ : StyleGANを動作させるコード (tf/pt) 1ファイルにまとめられている 54 | - conda_env.txt : 動作確認済み環境 55 | - docker_install.md : Dockerインストール方法について 56 | - README.md : 説明 (このファイル) 57 | - stylegan/ 著者オリジナル https://github.com/NVlabs/stylegan 58 | - stylegan2/ 著者オリジナル https://github.com/NVlabs/stylegan2 59 | - /wherever/you/want/ 60 | - karras2019stylegan-ffhq-1024x1024.pkl : 著者配布学習済みモデル for StyleGANv1 61 | - stylegan2-ffhq-config-f.pkl : 著者配布学習済みモデル for StyleGANv2 62 | ``` 63 | 64 | ### 概要 65 | 66 | 2つのステップ (``run_tf_stylegan2.py`` & ``run_pt_stylegan2.py`` ) を踏む必要がある 67 | 1. 重みの変換 ``tf`` : 著者配布 tensorflowコード & 重みを使って pytorch で使える重みを得る 68 | 2. 
画像の生成 ``pt`` : 本リポジトリのコードと pytorch用重みを使って画像を生成 69 | 70 | - ``run_tf_stylegan2.py`` 71 | - 著者配布モデル ``stylegan2-ffhq-config-f.pkl`` を numpy形式 ``stylegan2_ndarray.pkl`` に変換 72 | - 潜在ベクトルを決めて ``latent2.pkl`` として保存 73 | - 潜在ベクトルから画像を生成して ``stylegan2_tf.png`` として保存 74 | - ``run_pt_stylegan2.py`` 75 | - numpy形式 ``stylegan2_ndarray.pkl`` を pytorch形式 ``stylegan2_state_dict.pth`` に変換 76 | - 潜在ベクトルを読み込んで画像生成し ``stylegan2_pt.png`` として保存 77 | 78 | ### 入出力用ディレクトリについて 79 | 重みの変換/再現の確認の際に以下のものが入力/出力される (以下の表は StyleGANv2 のもの) 80 | 81 | ``run_tf_stylegan2.py`` , ``run_pt_stylegan2.py`` の各プログラムには 82 | オプションとして ``-w`` ( weight directory ) と ``-o`` ( output directory ) を指定する必要がある. 83 | 84 | - dir : ``w`` = weight directory, ``o`` = output directory 85 | それぞれ別のディレクトリを指定することができる. 86 | 後で説明する使い方では w と o は同じディレクトリにしている. 87 | - ``tf`` と ``pt`` はそれぞれ ``run_tf_stylegan2.py`` と ``run_pt_stylegan2.py`` の入出力ファイルを表している. 88 | 89 | | dir | tf | pt | summary | file name | detail | 90 | | --- | --- | --- | ---- | ---- | ---- | 91 | | w | IN | - | 学習済みモデル | ``stylegan2-ffhq-config-f.pkl`` | 配布されているものをダウンロード | 92 | | w | OUT | IN | 学習済みモデル | ``stylegan2_ndarray.pkl`` | ``run_tf_stylegan2.py``でnumpy形式に変換 | 93 | | o | OUT | IN | 入力潜在変数 | ``latent2.pkl`` | ``run_tf_stylegan2.py``で使用した潜在ベクトルを記録 | 94 | | o | OUT | IN | 出力結果写真 | ``stylegan2_tf.png`` | ``run_tf_stylegan2.py``で著者実装モデルから出力 | 95 | | o | - | OUT | 出力結果写真 | ``stylegan2_pt.png`` | ``run_pt_stylegan2.py``で本実装から出力 | 96 | | w | - | OUT | 学習済みモデル | ``stylegan2_state_dict.pth`` | ``run_pt_stylegan2.py``でnumpy形式から変換 | 97 | 98 | 99 | --- 100 | 101 | ## 実行方法 102 | 103 | ### 1. 入出力用ディレクトリの準備 104 | 105 | 以下のように用意 106 | ``` 107 | export STYLEGANSDIR=/wherever/you/want 108 | mkdir -p $STYLEGANSDIR 109 | ``` 110 | 111 | ### 2. 重みのダウンロード 112 | 再現実装の動作確認にはオリジナルの学習済みモデルと, 113 | 生成器の出力を保存するためのディレクトリが必要. 114 | ``` 115 | ( cd $STYLEGANSDIR && curl gdrive.sh | bash -s https://drive.google.com/open?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ ) 116 | ( cd $STYLEGANSDIR && curl gdrive.sh | bash -s https://drive.google.com/open?id=1Mgh-jglZjgksupF0XLl0KzuOqd1LXcoE ) 117 | ``` 118 | 119 | **注意** 120 | このスクリプトは通常は正常に動くはずだが, 121 | Googleドライブ側に一日あたりのダウンロード回数制限がある(?)ためか動かないことがあります. 122 | GUIのブラウザから直接アクセスするほうが良いかもしれません. 123 | 124 | 125 | ### 3. コードの用意 126 | 著者オリジナル実装と本レポジトリをダウンロード. 127 | ``` 128 | mkdir workdir 129 | cd workdir 130 | git clone https://github.com/NVlabs/stylegan.git 131 | git clone https://github.com/NVlabs/stylegan2.git 132 | git clone https://github.com/yuuho/stylegans-pytorch.git 133 | ``` 134 | 135 | ### 4. 環境構築 136 | StyleGANv1のほうは著者オリジナル実装(tensorflow)も再現実装(pytorch)もcondaのみで環境構築可能. 137 | ``` 138 | conda create -y -n stylegans numpy scipy opencv tensorflow-gpu=1.14 tensorboard lmdb requests pytorch -c pytorch 139 | conda install -y -n stylegans moviepy -c conda-forge 140 | conda activate stylegans 141 | ``` 142 | 143 | ~~実行した環境を残しておいた 144 | (``conda list --export > conda_env.txt``)ので 145 | 以下で構築しても良い. 146 | ``conda env create -n stylegans -f stylegans-pytorch/conda_env.txt``~~ 147 | 148 | StyleGANv2のほうは,著者オリジナル実装(tensorflow)は 149 | CUDAコードが含まれておりnvccによるコンパイル環境が必要なのでDockerのほうが良い. 150 | rootless環境でも動くはず.rootlessやりたい場合は [インストール方法](docker_install.md) を見る. 151 | 152 | Dockerイメージのビルドは ``workdir``で 153 | ``` 154 | docker build -t tkarras/stylegan2:latest ./stylegan2 155 | ``` 156 | 再現実装(pytorch)に関してはStyleGANv1と同じ環境で動く. 157 | 158 | 159 | ### 5. 動かす 160 | #### 5.1. 
StyleGAN (v1) tensorflow 161 | ``workdir``で 162 | ``` 163 | cp stylegans-pytorch/packaged/run_tf_stylegan1.py stylegan/ 164 | cd stylegan 165 | python run_tf_stylegan1.py -w $STYLEGANSDIR -o $STYLEGANSDIR 166 | cd - 167 | ``` 168 | 169 | #### 5.2. StyleGAN (v1) pytorch 170 | ``workdir``で 171 | ``` 172 | python stylegans-pytorch/packaged/run_pt_stylegan1.py -w $STYLEGANSDIR -o $STYLEGANSDIR 173 | ``` 174 | 175 | #### 5.3. StyleGAN (v2) tensorflow 176 | 著者Docker環境で実行する. 177 | ``workdir``で 178 | ``` 179 | cp stylegans-pytorch/packaged/run_tf_stylegan2.py stylegan2/ 180 | docker run --gpus all -v $PWD/stylegan2:/workdir \ 181 | -w /workdir -v $STYLEGANSDIR:/stylegans-pytorch \ 182 | -it --rm tkarras/stylegan2:latest 183 | CUDA_VISIBLE_DEVICES=0 python run_tf_stylegan2.py -w /stylegans-pytorch -o /stylegans-pytorch 184 | exit 185 | ``` 186 | 187 | Dockerがrootlessでないなら出力したファイルを読めるようにする必要がある 188 | ``` 189 | sudo chown -R $(whoami) $STYLEGANSDIR 190 | ``` 191 | 192 | #### 5.4. StyleGAN (v2) pytorch 193 | 194 | ``workdir``で 195 | ``` 196 | python stylegans-pytorch/packaged/run_pt_stylegan2.py -w $STYLEGANSDIR -o $STYLEGANSDIR 197 | ``` 198 | 199 | --- 200 | 201 | ## 特殊な部分/細かな違い 202 | 203 | ### 1. 画像拡大操作/ブラー 204 | StyleGANの解像度をあげるためのConvolutionについて, 205 | 基本的にはTransposedConvolutionを利用するが, 206 | 後続のBlurレイヤーとの兼ね合いもあっていくつかの実装方法が存在する. 207 | 208 | 1. ConvTransposeを3x3フィルタで pad0,stride2で行い,blurを4x4フィルタで行う方法 209 | 2. ConvTransposeを4x4フィルタで pad1,stride2で行い,blurを3x3フィルタで行う方法 210 | 3. nearest neighbour upsamplingで拡大後,Convolutionを3x3フィルタで pad1,stride1で行い,blurを3x3フィルタで行う方法 211 | 212 | StyleGANの論文で引用している論文では信号処理的な観点からblurの必要性について説明している. 213 | おそらくもっとも素直な実装は3である. 214 | 1と2は3に比べて高速である. 215 | upsampleを行ってconvolutionをすると計算量的に重くなるので, 216 | ほぼ同値な方法として3x3の畳み込みフィルタを学習させたい. 217 | 1と2はほぼ同値である. 218 | 219 | ### 2. ノイズの入れ方 220 | StyleGANでは全ピクセルに対してノイズを入れる. 221 | 222 | StyleGAN1では固定ノイズは (H,W) で保持しておいて, 223 | ノイズの重みを (C,) で保持. 224 | 225 | StyleGAN2では固定ノイズは (H,W) で保持しておいて, 226 | ノイズの重みを (1,) = スカラー で保持. 227 | 228 | ### 3. 増幅 229 | StyleGAN1とStyleGAN2で増幅処理している場所が違う. 230 | 元の実装では gain という変数に √2 などが設定されていて, 231 | convやfcの後に強制的に特徴マップを増幅していた. 232 | 233 | - StyleGAN1 mapping: linear -> gain -> bias 234 | - StyleGAN2 mapping: linear -> bias -> gain 235 | - StyleGAN1 conv : conv -> gain -> noise -> bias 236 | - StyleGAN2 conv : conv -> noise -> bias -> gain 237 | - StyleGAN1 toRGB : conv -> bias (増幅なし) 238 | - StyleGAN2 toRGB : conv -> bias (増幅なし) 239 | 240 | 241 | --- 242 | 243 | ## Waifu 244 | 245 | 著者の学習させたFFHQモデルが動作するのは確認したので第三者によって学習させたモデルが動くか試してみます. 246 | [Making Anime Faces With StyleGAN - Gwern](https://www.gwern.net/Faces) に 247 | 二次元キャラクター生成用の学習済みモデルがあるのでこれを使います. 248 | 249 | ### モデルのダウンロード・配置 250 | 4つの学習済みモデルが配布されているのを見つけたのでこれで試します. 251 | 252 | 1. Face StyleGANv1 (This Waifu Does Not Exist v1) 512px Danbooru2017 (train 218,794) 253 | [2019-02-26-stylegan-faces-network-02048-016041.pkl](https://mega.nz/#!aPRFDKaC!FDpQi_FEPK443JoRBEOEDOmlLmJSblKFlqZ1A1XPt2Y) 254 | 2. Face StyleGANv1 512px 255 | [2019-03-08-stylegan-animefaces-network-02051-021980.pkl.xz](https://mega.nz/#!vawjXISI!F7s13yRicxDA3QYqYDL2kjnc2K7Zk3DwCIYETREmBP4) 256 | 3. Portrait StyleGANv1 512px 257 | [2019-04-30-stylegan-danbooru2018-portraits-02095-066083.pkl](https://mega.nz/#!CRtiDI7S!xo4zm3n7pkq1Lsfmuio1O8QPpUwHrtFTHjNJ8_XxSJs) 258 | 4. 
Portrait StyleGANv2 (This Waifu Does Not Exist v3) 512px 259 | [2020-01-11-skylion-stylegan2-animeportraits-networksnapshot-024664.pkl.xz](https://mega.nz/#!PeIi2ayb!xoRtjTXyXuvgDxSsSMn-cOh-Zux9493zqdxwVMaAzp4) 260 | 261 | StyleGANv1では (1) Faceを学習させたもの,(2) そこから更に学習させたもの,(3) Portraitを学習させたもの,の3つがあるようです. 262 | StyleGANv2では (4) Portraitを学習させたもの,が1つあります. 263 | Faceはクローズ・アップぐらいの構図,Portraitはアップ・ショットぐらいの構図のようです. 264 | 265 | 圧縮されているものは解凍して以下のように配置します. 266 | ``` 267 | - /wherever/you/want/ 268 | - 2019-02-26-stylegan-faces-network-02048-016041.pkl 269 | - 2019-03-08-stylegan-animefaces-network-02051-021980.pkl 270 | - 2019-04-30-stylegan-danbooru2018-portraits-02095-066083.pkl 271 | - 2020-01-11-skylion-stylegan2-animeportraits-networksnapshot-024664.pkl 272 | ``` 273 | 274 | ### 重みの変換・画像の生成 275 | 276 | ``waifu`` ディレクトリに雑に書き換えたものを置いておきました. 277 | 278 | #### tensorflowで重みを変換 279 | 280 | ``workdir`` で 281 | ``` 282 | cp stylegans-pytorch/waifu/converter.py stylegan/ 283 | cd stylegan 284 | python converter.py 1 face_v1_1 -w $STYLEGANSDIR -o $STYLEGANSDIR 285 | python converter.py 1 face_v1_2 -w $STYLEGANSDIR -o $STYLEGANSDIR 286 | python converter.py 1 portrait_v1 -w $STYLEGANSDIR -o $STYLEGANSDIR 287 | cd - 288 | 289 | cp stylegans-pytorch/waifu/converter.py stylegan2/ 290 | docker run --gpus all -v $PWD/stylegan2:/workdir \ 291 | -w /workdir -v $STYLEGANSDIR:/stylegans-pytorch \ 292 | -it --rm tkarras/stylegan2:latest 293 | CUDA_VISIBLE_DEVICES=0 python converter.py 2 portrait_v2 -w /stylegans-pytorch -o /stylegans-pytorch 294 | exit 295 | sudo chown -R $(whoami) $STYLEGANSDIR 296 | ``` 297 | 298 | #### pytorchで読み込んで画像生成 299 | 300 | ``workdir`` で 301 | ``` 302 | python stylegans-pytorch/waifu/run_pt_stylegan.py 1 face_v1_1 -w $STYLEGANSDIR -o $STYLEGANSDIR 303 | python stylegans-pytorch/waifu/run_pt_stylegan.py 1 face_v1_2 -w $STYLEGANSDIR -o $STYLEGANSDIR 304 | python stylegans-pytorch/waifu/run_pt_stylegan.py 1 portrait_v1 -w $STYLEGANSDIR -o $STYLEGANSDIR 305 | python stylegans-pytorch/waifu/run_pt_stylegan.py 2 portrait_v2 -w $STYLEGANSDIR -o $STYLEGANSDIR 306 | ``` 307 | 308 | ### 結果 309 | うまくいったようです. 
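目視での比較に加えて,tf側(著者実装)とpt側(再現実装)の出力がどの程度一致しているかを数値でも確認できる.以下は最小限の確認例(FFHQで生成した ``stylegan2_tf.png`` と ``stylegan2_pt.png`` が出力ディレクトリにある前提の簡単なスケッチ.waifu各モデルの出力もファイル名を差し替えれば同様に確認できる).

```
from pathlib import Path
import numpy as np
import cv2

out_dir = Path('/wherever/you/want')   # run_*_stylegan2.py に -o で渡したディレクトリ
tf_img = cv2.imread(str(out_dir/'stylegan2_tf.png')).astype(np.float32)
pt_img = cv2.imread(str(out_dir/'stylegan2_pt.png')).astype(np.float32)

diff = np.abs(tf_img - pt_img)
print('mean abs diff:', diff.mean())   # 0 に近いほど再現できている
print('max  abs diff:', diff.max())
```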
310 | 311 | 左: オリジナル出力, 右: 再現出力 312 | 313 | - (1) anime_face_v1_1 314 | ![anime_face_v1_1](./img/waifu/anime_face_v1_1.jpg) 315 | 316 | - (2) anime_face_v1_2 317 | ![anime_face_v1_2](./img/waifu/anime_face_v1_2.jpg) 318 | 319 | - (3) anime_portrait_v1 320 | ![anime_portrait_v1](./img/waifu/anime_portrait_v1.jpg) 321 | 322 | - (4) anime_portrait_v2 323 | ![anime_portrait_v2](./img/waifu/anime_portrait_v2.jpg) 324 | 325 | ## TODO 326 | - style mixingもやる 327 | - StyleGANv2 の色味が違う原因を特定 328 | - projection 329 | - train 330 | -------------------------------------------------------------------------------- /conda_env.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: linux-64 4 | _libgcc_mutex=0.1=main 5 | _tflow_select=2.1.0=gpu 6 | absl-py=0.8.1=py37_0 7 | asn1crypto=1.2.0=py37_0 8 | astor=0.8.0=py37_0 9 | blas=1.0=mkl 10 | bzip2=1.0.8=h7b6447c_0 11 | c-ares=1.15.0=h7b6447c_1001 12 | ca-certificates=2019.11.28=hecc5488_0 13 | cairo=1.14.12=h8948797_3 14 | certifi=2019.11.28=py37_0 15 | cffi=1.13.2=py37h2e261b9_0 16 | chardet=3.0.4=py37_1003 17 | cryptography=2.8=py37h1ba5d50_0 18 | cudatoolkit=10.1.243=h6bb024c_0 19 | cudnn=7.6.4=cuda10.1_0 20 | cupti=10.1.168=0 21 | decorator=4.4.1=py_0 22 | ffmpeg=4.0=hcdf2ecd_0 23 | fontconfig=2.13.0=h9420a91_0 24 | freeglut=3.0.0=hf484d3e_5 25 | freetype=2.9.1=h8a8886c_1 26 | gast=0.3.2=py_0 27 | glib=2.63.1=h5a9c865_0 28 | google-pasta=0.1.8=py_0 29 | graphite2=1.3.13=h23475e2_0 30 | grpcio=1.16.1=py37hf8bcb03_1 31 | h5py=2.8.0=py37h989c5e5_3 32 | harfbuzz=1.8.8=hffaf4a1_0 33 | hdf5=1.10.2=hba1933b_1 34 | icu=58.2=h9c2bf20_1 35 | idna=2.8=py37_0 36 | imageio=2.6.1=py37_0 37 | intel-openmp=2019.4=243 38 | jasper=2.0.14=h07fcdf6_1 39 | jpeg=9b=h024ee3a_2 40 | keras-applications=1.0.8=py_0 41 | keras-preprocessing=1.1.0=py_1 42 | libedit=3.1.20181209=hc058e9b_0 43 | libffi=3.2.1=hd88cf55_4 44 | libgcc-ng=9.1.0=hdf63c60_0 45 | libgfortran-ng=7.3.0=hdf63c60_0 46 | libglu=9.0.0=hf484d3e_1 47 | libopencv=3.4.2=hb342d67_1 48 | libopus=1.3=h7b6447c_0 49 | libpng=1.6.37=hbc83047_0 50 | libprotobuf=3.10.1=hd408876_0 51 | libstdcxx-ng=9.1.0=hdf63c60_0 52 | libtiff=4.1.0=h2733197_0 53 | libuuid=1.0.3=h1bed415_2 54 | libvpx=1.7.0=h439df22_0 55 | libxcb=1.13=h1bed415_1 56 | libxml2=2.9.9=hea5a465_1 57 | lmdb=0.9.23=he6710b0_0 58 | markdown=3.1.1=py37_0 59 | mkl=2019.4=243 60 | mkl-service=2.3.0=py37he904b0f_0 61 | mkl_fft=1.0.15=py37ha843d7b_0 62 | mkl_random=1.1.0=py37hd6b4f25_0 63 | moviepy=0.2.3.5=py_0 64 | ncurses=6.1=he6710b0_1 65 | ninja=1.9.0=py37hfd86e86_0 66 | numpy=1.17.4=py37hc1035e2_0 67 | numpy-base=1.17.4=py37hde5b4d6_0 68 | olefile=0.46=py_0 69 | opencv=3.4.2=py37h6fd60c2_1 70 | openssl=1.1.1d=h516909a_0 71 | pcre=8.43=he6710b0_0 72 | pillow=6.2.1=py37h34e0f95_0 73 | pip=19.3.1=py37_0 74 | pixman=0.38.0=h7b6447c_0 75 | protobuf=3.10.1=py37he6710b0_0 76 | py-opencv=3.4.2=py37hb342d67_1 77 | pycparser=2.19=py37_0 78 | pyopenssl=19.1.0=py37_0 79 | pysocks=1.7.1=py37_0 80 | python=3.7.5=h0371630_0 81 | pytorch=1.3.1=py3.7_cuda10.1.243_cudnn7.6.3_0 82 | readline=7.0=h7b6447c_5 83 | requests=2.22.0=py37_1 84 | scipy=1.3.2=py37h7c811a0_0 85 | setuptools=42.0.2=py37_0 86 | six=1.13.0=py37_0 87 | sqlite=3.30.1=h7b6447c_0 88 | tensorboard=1.14.0=py37hf484d3e_0 89 | tensorflow=1.14.0=gpu_py37h74c33d7_0 90 | tensorflow-base=1.14.0=gpu_py37he45bfe2_0 91 | tensorflow-estimator=1.14.0=py_0 92 | tensorflow-gpu=1.14.0=h0d30ee6_0 
93 | termcolor=1.1.0=py37_1 94 | tk=8.6.8=hbc83047_0 95 | tqdm=4.40.1=py_0 96 | urllib3=1.25.7=py37_0 97 | werkzeug=0.16.0=py_0 98 | wheel=0.33.6=py37_0 99 | wrapt=1.11.2=py37h7b6447c_0 100 | xz=5.2.4=h14c3975_4 101 | zlib=1.2.11=h7b6447c_3 102 | zstd=1.3.7=h0b5b093_0 103 | -------------------------------------------------------------------------------- /docker_install.md: -------------------------------------------------------------------------------- 1 | # rootless docker 2 | 3 | 最新のDockerではsudo権限がなくても各ユーザーが利用できるような仕組みが導入された. 4 | 5 | インターネットで拾ってきたイメージを各ユーザーがroot権限で実行できるようになっていては危険なので 6 | 共有サーバーではrootless Dockerを利用する必要がある. 7 | 8 | # 概要 9 | rootless dockerは各ユーザーの ``~/bin/`` にインストールされ, 10 | 各ユーザーが ``systemctl --user start docker`` として自分の権限でdockerデーモンを起動して使う. 11 | 12 | gpuを使うためにはシステム権限でnvidia-container-runtimeを入れて, 13 | 各ユーザーがアクセスできるように設定しておく必要がある. 14 | 15 | # インストール方法 16 | ## sudo権限が必要なこと 17 | root権限で実行しなければならないのはnvidia-container-runtimeのインストールと設定のみである. 18 | 19 | ### nvidia-container-runtime用のGPG鍵とPPAの追加 20 | ``` 21 | curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | sudo apt-key add - 22 | distribution=$(. /etc/os-release;echo $ID$VERSION_ID) 23 | curl -s -L https://nvidia.github.io/nvidia-docker/$distribution/nvidia-docker.list | sudo tee /etc/apt/sources.list.d/nvidia-docker.list 24 | sudo apt-get update 25 | ``` 26 | nvidiaのGPG鍵は定期的にアップデートしないとapt updateが通らなくなるので注意.上記コマンドを再度行えばOK 27 | 28 | ### nvidia-container-runtimeのインストール 29 | ``` 30 | sudo apt install nvidia-container-runtime 31 | ``` 32 | 33 | ### ユーザーがGPUを使えるように設定 34 | ``sudo vim /etc/nvidia-container-runtime/config.toml`` 35 | 一行だけ書き換える 36 | ``` 37 | no-cgroups = true 38 | ``` 39 | 40 | 再起動 41 | ``` 42 | sudo reboot 43 | ``` 44 | 45 | ## 各ユーザーがやること 46 | * rootless dockerのインストール 47 | * dockerデーモンの起動 48 | 49 | ### rootless dockerのインストール 50 | これだけ 51 | ``` 52 | curl -sSL https://get.docker.com/rootless | sh 53 | ``` 54 | 55 | ``~/bin/``にdockerデーモンがインストールされるのでパスを通す必要がある. 56 | dockerはどこに接続すればいいか知りたいので ``DOCKER_HOST`` が設定されている必要がある. 57 | 58 | ### dockerデーモンの起動 59 | ちゃんとパスが通せていたら動くはず 60 | ``systemctl --user start docker`` 61 | 動かなかったらエラーメッセージをよく読んでパスを通す. 62 | LDAPなどユーザーの管理が各サーバー上に無い場合はできません... 
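参考までに,上で述べたパスと ``DOCKER_HOST`` は通常は次のような値になる(あくまで典型例.実際の値はインストールスクリプトが表示するものに従うこと).
```
export PATH=$HOME/bin:$PATH
export DOCKER_HOST=unix:///run/user/$(id -u)/docker.sock
```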
63 | 64 | 65 | ### 実行できるかテスト 66 | ``` 67 | docker run --gpus all -it --rm nvidia/cuda:8.0-cudnn6-devel-ubuntu16.04 68 | nvidia-smi 69 | nvcc -V 70 | ``` 71 | 72 | -------------------------------------------------------------------------------- /img/stylegan1_pt_compression.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuuho/stylegans-pytorch/5f88cfa8b48956e7189b361d33c82f1f2fc3c6ea/img/stylegan1_pt_compression.jpg -------------------------------------------------------------------------------- /img/stylegan1_tf_compression.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuuho/stylegans-pytorch/5f88cfa8b48956e7189b361d33c82f1f2fc3c6ea/img/stylegan1_tf_compression.jpg -------------------------------------------------------------------------------- /img/stylegan2_pt_compression.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuuho/stylegans-pytorch/5f88cfa8b48956e7189b361d33c82f1f2fc3c6ea/img/stylegan2_pt_compression.jpg -------------------------------------------------------------------------------- /img/stylegan2_tf_compression.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuuho/stylegans-pytorch/5f88cfa8b48956e7189b361d33c82f1f2fc3c6ea/img/stylegan2_tf_compression.jpg -------------------------------------------------------------------------------- /img/waifu/anime_face_v1_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuuho/stylegans-pytorch/5f88cfa8b48956e7189b361d33c82f1f2fc3c6ea/img/waifu/anime_face_v1_1.jpg -------------------------------------------------------------------------------- /img/waifu/anime_face_v1_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuuho/stylegans-pytorch/5f88cfa8b48956e7189b361d33c82f1f2fc3c6ea/img/waifu/anime_face_v1_2.jpg -------------------------------------------------------------------------------- /img/waifu/anime_portrait_v1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuuho/stylegans-pytorch/5f88cfa8b48956e7189b361d33c82f1f2fc3c6ea/img/waifu/anime_portrait_v1.jpg -------------------------------------------------------------------------------- /img/waifu/anime_portrait_v2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuuho/stylegans-pytorch/5f88cfa8b48956e7189b361d33c82f1f2fc3c6ea/img/waifu/anime_portrait_v2.jpg -------------------------------------------------------------------------------- /network/common.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | # mapping前に潜在変数を超球面上に正規化 8 | class PixelwiseNormalization(nn.Module): 9 | def __init__(self): 10 | super().__init__() 11 | def forward(self, x): 12 | return x / torch.sqrt((x**2).mean(1,keepdim=True) + 1e-8) 13 | 14 | 15 | # 移動平均を用いて潜在変数を正規化する. 
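# Truncation trick:
#   mappingの出力 (N,D) を (N,output_num,D) に複製したうえで,
#   先頭 num_target 本のスタイルだけを移動平均 avg_style に向かって threshold 倍に縮める:
#       w' = avg + (w - avg) * threshold
#   残りの (output_num - num_target) 本はそのまま通す.
#   本リポジトリの設定では StyleGAN1: num_target=8, StyleGAN2: num_target=10
#   (いずれも threshold=0.7, output_num=18).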
16 | class TruncationTrick(nn.Module): 17 | def __init__(self, num_target, threshold, output_num, style_dim): 18 | super().__init__() 19 | self.num_target = num_target 20 | self.threshold = threshold 21 | self.output_num = output_num 22 | self.register_buffer('avg_style', torch.zeros((style_dim,))) 23 | 24 | def forward(self, x): 25 | # in:(N,D) -> out:(N,O,D) 26 | N,D = x.shape 27 | O = self.output_num 28 | x = x.view(N,1,D).expand(N,O,D) 29 | rate = torch.cat([ torch.ones((N, self.num_target, D)) *self.threshold, 30 | torch.ones((N, O-self.num_target, D)) *1.0 ],1).to(x.device) 31 | avg = self.avg_style.view(1,1,D).expand(N,O,D) 32 | return avg + (x-avg)*rate 33 | 34 | 35 | # 特徴マップ信号を増幅する 36 | class Amplify(nn.Module): 37 | def __init__(self, rate): 38 | super().__init__() 39 | self.rate = rate 40 | def forward(self,x): 41 | return x * self.rate 42 | 43 | 44 | # チャンネルごとにバイアス項を足す 45 | class AddChannelwiseBias(nn.Module): 46 | def __init__(self, out_channels, lr): 47 | super().__init__() 48 | # lr = 1.0 (conv,mod,AdaIN), 0.01 (mapping) 49 | 50 | self.bias = nn.Parameter(torch.zeros(out_channels)) 51 | torch.nn.init.zeros_(self.bias.data) 52 | self.bias_scaler = lr 53 | 54 | def forward(self, x): 55 | oC,*_ = self.bias.shape 56 | shape = (1,oC) if x.ndim==2 else (1,oC,1,1) 57 | y = x + self.bias.view(*shape)*self.bias_scaler 58 | return y 59 | 60 | 61 | # 学習率を調整したFC層 62 | class EqualizedFullyConnect(nn.Module): 63 | def __init__(self, in_dim, out_dim, lr): 64 | super().__init__() 65 | # lr = 0.01 (mapping), 1.0 (mod,AdaIN) 66 | 67 | self.weight = nn.Parameter(torch.randn((out_dim,in_dim))) 68 | torch.nn.init.normal_(self.weight.data, mean=0.0, std=1.0/lr) 69 | self.weight_scaler = 1/(in_dim**0.5)*lr 70 | 71 | def forward(self, x): 72 | # x (N,D) 73 | return F.linear(x, self.weight*self.weight_scaler, None) 74 | -------------------------------------------------------------------------------- /network/run_pt_stylegan.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | from pathlib import Path 4 | import pickle 5 | 6 | import numpy as np 7 | import cv2 8 | import torch 9 | 10 | 11 | # コマンドライン引数の取得 12 | def parse_args(): 13 | parser = argparse.ArgumentParser(description='著者実装を動かしたり重みを抜き出したり') 14 | parser.add_argument('version',type=int, 15 | help='version name') 16 | parser.add_argument('-w','--weight_dir',type=str,default='/tmp/stylegans-pytorch', 17 | help='学習済みのモデルを保存する場所') 18 | parser.add_argument('-o','--output_dir',type=str,default='/tmp/stylegans-pytorch', 19 | help='生成された画像を保存する場所') 20 | parser.add_argument('--batch_size',type=int,default=1, 21 | help='バッチサイズ') 22 | parser.add_argument('--device',type=str,default='gpu',choices=['gpu','cpu'], 23 | help='デバイス') 24 | args = parser.parse_args() 25 | args.resolution = 1024 26 | return args 27 | 28 | 29 | # 変換関数 30 | ops_dict = { 31 | # 変調転置畳み込みの重み (iC,oC,kH,kW) 32 | 'mTc' : lambda weight: torch.flip(torch.from_numpy(weight.transpose((2,3,0,1))), [2, 3]), 33 | # 転置畳み込みの重み (iC,oC,kH,kW) 34 | 'Tco' : lambda weight: torch.from_numpy(weight.transpose((2,3,0,1))), 35 | # 畳み込みの重み (oC,iC,kH,kW) 36 | 'con' : lambda weight: torch.from_numpy(weight.transpose((3,2,0,1))), 37 | # 全結合層の重み (oD, iD) 38 | 'fc_' : lambda weight: torch.from_numpy(weight.transpose((1, 0))), 39 | # 全結合層のバイアス項, 固定入力, 固定ノイズ, v1ノイズの重み (無変換) 40 | 'any' : lambda weight: torch.from_numpy(weight), 41 | # Style-Mixingの値, v2ノイズの重み (scalar) 42 | 'uns' : lambda weight: torch.from_numpy(np.array(weight).reshape(1)), 43 | } 44 | 45 
| 46 | setting = { 47 | 1: { 48 | 'src_weight': 'stylegan1_ndarray.pkl', 49 | 'src_latent': 'latents1.pkl', 50 | 'dst_image' : 'stylegan1_pt.png', 51 | 'dst_weight': 'stylegan1_state_dict.pth' }, 52 | 2: { 53 | 'src_weight': 'stylegan2_ndarray.pkl', 54 | 'src_latent': 'latents2.pkl', 55 | 'dst_image' : 'stylegan2_pt.png', 56 | 'dst_weight': 'stylegan2_state_dict.pth' }, 57 | } 58 | 59 | 60 | if __name__ == '__main__': 61 | # コマンドライン引数の取得 62 | args = parse_args() 63 | 64 | # バージョンによって切り替え 65 | cfg = setting[args.version] 66 | if args.version==1: 67 | from stylegan1 import Generator, name_trans_dict 68 | elif args.version==2: 69 | from stylegan2 import Generator, name_trans_dict 70 | 71 | 72 | print('model construction...') 73 | generator = Generator() 74 | base_dict = generator.state_dict() 75 | 76 | print('model weights load...') 77 | with (Path(args.weight_dir)/cfg['src_weight']).open('rb') as f: 78 | src_dict = pickle.load(f) 79 | 80 | print('set state_dict...') 81 | new_dict = { k : ops_dict[v[0]](src_dict[v[1]]) for k,v in name_trans_dict.items()} 82 | generator.load_state_dict(new_dict) 83 | 84 | print('load latents...') 85 | with (Path(args.output_dir)/cfg['src_latent']).open('rb') as f: 86 | latents = pickle.load(f) 87 | latents = torch.from_numpy(latents.astype(np.float32)) 88 | 89 | print('network forward...') 90 | device = torch.device('cuda') if torch.cuda.is_available() and args.device=='gpu' else torch.device('cpu') 91 | with torch.no_grad(): 92 | N,_ = latents.shape 93 | generator.to(device) 94 | images = np.empty((N,args.resolution,args.resolution,3),dtype=np.uint8) 95 | 96 | for i in range(0,N,args.batch_size): 97 | j = min(i+args.batch_size,N) 98 | z = latents[i:j].to(device) 99 | img = generator(z) 100 | normalized = (img.clamp(-1,1)+1)/2*255 101 | images[i:j] = normalized.permute(0,2,3,1).cpu().numpy().astype(np.uint8) 102 | del z, img, normalized 103 | 104 | # 出力を並べる関数 105 | def make_table(imgs): 106 | # 出力する個数,解像度 107 | num_H, num_W = 4,4 108 | H = W = args.resolution 109 | num_images = num_H*num_W 110 | 111 | canvas = np.zeros((H*num_H,W*num_W,3),dtype=np.uint8) 112 | for i,p in enumerate(imgs[:num_images]): 113 | h,w = i//num_W, i%num_W 114 | canvas[H*h:H*-~h,W*w:W*-~w,:] = p[:,:,::-1] 115 | return canvas 116 | 117 | print('image output...') 118 | cv2.imwrite(str(Path(args.output_dir)/cfg['dst_image']), make_table(images)) 119 | 120 | print('weight save...') 121 | torch.save(generator.state_dict(),str(Path(args.weight_dir)/cfg['dst_weight'])) 122 | 123 | print('all done') 124 | -------------------------------------------------------------------------------- /network/stylegan1.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from common import PixelwiseNormalization, Amplify, AddChannelwiseBias, EqualizedFullyConnect, TruncationTrick 8 | 9 | 10 | # 固定ノイズ 11 | class ElementwiseNoise(nn.Module): 12 | def __init__(self, ch, size_hw): 13 | super().__init__() 14 | self.register_buffer("const_noise", torch.randn((1, 1, size_hw, size_hw))) 15 | self.noise_scaler = nn.Parameter(torch.zeros((ch,))) 16 | 17 | def forward(self, x): 18 | N,C,H,W = x.shape 19 | noise = self.const_noise.expand(N,C,H,W) 20 | scaler = self.noise_scaler.view(1,C,1,1) 21 | return x + noise * scaler 22 | 23 | 24 | # ブラー : 解像度を上げる畳み込みの後に使う 25 | class Blur3x3(nn.Module): 26 | def __init__(self): 27 | super().__init__() 28 | f = np.array( [ [1/16, 
2/16, 1/16], 29 | [2/16, 4/16, 2/16], 30 | [1/16, 2/16, 1/16]], dtype=np.float32).reshape([1, 1, 3, 3]) 31 | self.filter = torch.from_numpy(f) 32 | 33 | def forward(self, x): 34 | _N,C,_H,_W = x.shape 35 | return F.conv2d(x, self.filter.to(x.device).expand(C,1,3,3), padding=1, groups=C) 36 | 37 | 38 | # 学習率を調整した転置畳み込み (ブラーのための拡張あり) 39 | class EqualizedFusedConvTransposed2d(nn.Module): 40 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, lr): 41 | super().__init__() 42 | # lr = 1.0 43 | 44 | self.stride, self.padding = stride, padding 45 | 46 | self.weight = nn.Parameter(torch.empty(in_channels, out_channels, kernel_size, kernel_size)) 47 | torch.nn.init.normal_(self.weight.data, mean=0.0, std=1.0/lr) 48 | self.weight_scaler = 1 / ((in_channels * (kernel_size ** 2) )**0.5) * lr 49 | 50 | def forward(self, x): 51 | # 3x3 conv を 4x4 transposed conv として使う 52 | i_ch, o_ch, _kh, _kw = self.weight.shape 53 | # Padding (L,R,T,B) で4x4の四隅に3x3フィルタを寄せて和で合成 54 | weight_4x4 = torch.cat([F.pad(self.weight, pad).view(1,i_ch,o_ch,4,4) 55 | for pad in [(0,1,0,1),(1,0,0,1),(0,1,1,0),(1,0,1,0)]]).sum(dim=0) 56 | return F.conv_transpose2d(x, weight_4x4*self.weight_scaler, stride=2, padding=1) 57 | # 3x3でconvしてからpadで4隅に寄せて計算しても同じでは? 58 | # padding0にしてStyleGAN2のBlurを使っても同じでは? 59 | 60 | 61 | # 学習率を調整した畳込み 62 | class EqualizedConv2d(nn.Module): 63 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, lr): 64 | super().__init__() 65 | # lr = 1.0 66 | 67 | self.stride, self.padding = stride, padding 68 | 69 | self.weight = nn.Parameter(torch.empty(out_channels, in_channels, kernel_size, kernel_size)) 70 | torch.nn.init.normal_(self.weight.data, mean=0.0, std=1.0/lr) 71 | self.weight_scaler = 1 / ((in_channels * (kernel_size ** 2) )**0.5) * lr 72 | 73 | def forward(self, x): 74 | N,C,H,W = x.shape 75 | return F.conv2d(x, self.weight*self.weight_scaler, None, 76 | self.stride, self.padding) 77 | 78 | 79 | # 学習率を調整したAdaIN 80 | class EqualizedAdaIN(nn.Module): 81 | def __init__(self, fmap_ch, style_ch, lr): 82 | super().__init__() 83 | # lr = 1.0 84 | self.fc = EqualizedFullyConnect(style_ch, fmap_ch*2, lr) 85 | self.bias = AddChannelwiseBias(fmap_ch*2,lr) 86 | 87 | def forward(self, pack): 88 | x, style = pack 89 | #N,D = w.shape 90 | N,C,H,W = x.shape 91 | 92 | _vec = self.bias( self.fc(style) ).view(N,2*C,1,1) # (N,2C,1,1) 93 | scale, shift = _vec[:,:C,:,:], _vec[:,C:,:,:] # (N,C,1,1), (N,C,1,1) 94 | return (scale+1) * F.instance_norm(x, eps=1e-8) + shift 95 | 96 | 97 | class Generator(nn.Module): 98 | 99 | structure = { 100 | 'mapping': [['pixel_norm'], ['fc',512,512],['amp'],['b',512],['Lrelu'],['fc',512,512],['amp'],['b',512],['Lrelu'], 101 | ['fc',512,512],['amp'],['b',512],['Lrelu'],['fc',512,512],['amp'],['b',512],['Lrelu'], 102 | ['fc',512,512],['amp'],['b',512],['Lrelu'],['fc',512,512],['amp'],['b',512],['Lrelu'], 103 | ['fc',512,512],['amp'],['b',512],['Lrelu'],['fc',512,512],['amp'],['b',512],['Lrelu'],['truncation']], 104 | 'START' : [ ['noiseE',512, 4], ['bias',512], ['Lrelu'] ], 'adain4a' : [['adain',512]], 105 | 'Fconv4' : [ ['EqConv3x3',512, 512], ['amp'], ['noiseE',512, 4], ['bias',512], ['Lrelu'] ], 'adain4b' : [['adain',512]], 'toRGB_4' : [['EqConv1x1',512, 3], ['bias',3]], 106 | 'Uconv8' : [['up'], ['EqConv3x3',512, 512], ['blur3x3'], ['amp'], ['noiseE',512, 8], ['bias',512], ['Lrelu'] ], 'adain8a' : [['adain',512]], 107 | 'Fconv8' : [ ['EqConv3x3',512, 512], ['amp'], ['noiseE',512, 8], ['bias',512], ['Lrelu'] ], 'adain8b' : [['adain',512]], 
'toRGB_8' : [['EqConv1x1',512, 3], ['bias',3]], 108 | 'Uconv16' : [['up'], ['EqConv3x3',512, 512], ['blur3x3'], ['amp'], ['noiseE',512, 16], ['bias',512], ['Lrelu'] ], 'adain16a' : [['adain',512]], 109 | 'Fconv16' : [ ['EqConv3x3',512, 512], ['amp'], ['noiseE',512, 16], ['bias',512], ['Lrelu'] ], 'adain16b' : [['adain',512]], 'toRGB_16' : [['EqConv1x1',512, 3], ['bias',3]], 110 | 'Uconv32' : [['up'], ['EqConv3x3',512, 512], ['blur3x3'], ['amp'], ['noiseE',512, 32], ['bias',512], ['Lrelu'] ], 'adain32a' : [['adain',512]], 111 | 'Fconv32' : [ ['EqConv3x3',512, 512], ['amp'], ['noiseE',512, 32], ['bias',512], ['Lrelu'] ], 'adain32b' : [['adain',512]], 'toRGB_32' : [['EqConv1x1',512, 3], ['bias',3]], 112 | 'Uconv64' : [['up'], ['EqConv3x3',512, 256], ['blur3x3'], ['amp'], ['noiseE',256, 64], ['bias',256], ['Lrelu'] ], 'adain64a' : [['adain',256]], 113 | 'Fconv64' : [ ['EqConv3x3',256, 256], ['amp'], ['noiseE',256, 64], ['bias',256], ['Lrelu'] ], 'adain64b' : [['adain',256]], 'toRGB_64' : [['EqConv1x1',256, 3], ['bias',3]], 114 | 'Uconv128' : [ ['EqConvT3x3EX',256, 128], ['blur3x3'], ['amp'], ['noiseE',128, 128], ['bias',128], ['Lrelu'] ], 'adain128a' : [['adain',128]], 115 | 'Fconv128' : [ ['EqConv3x3',128, 128], ['amp'], ['noiseE',128, 128], ['bias',128], ['Lrelu'] ], 'adain128b' : [['adain',128]], 'toRGB_128' : [['EqConv1x1',128, 3], ['bias',3]], 116 | 'Uconv256' : [ ['EqConvT3x3EX',128, 64], ['blur3x3'], ['amp'], ['noiseE', 64, 256], ['bias', 64], ['Lrelu'] ], 'adain256a' : [['adain', 64]], 117 | 'Fconv256' : [ ['EqConv3x3', 64, 64], ['amp'], ['noiseE', 64, 256], ['bias', 64], ['Lrelu'] ], 'adain256b' : [['adain', 64]], 'toRGB_256' : [['EqConv1x1', 64, 3], ['bias',3]], 118 | 'Uconv512' : [ ['EqConvT3x3EX', 64, 32], ['blur3x3'], ['amp'], ['noiseE', 32, 512], ['bias', 32], ['Lrelu'] ], 'adain512a' : [['adain', 32]], 119 | 'Fconv512' : [ ['EqConv3x3', 32, 32], ['amp'], ['noiseE', 32, 512], ['bias', 32], ['Lrelu'] ], 'adain512b' : [['adain', 32]], 'toRGB_512' : [['EqConv1x1', 32, 3], ['bias',3]], 120 | 'Uconv1024': [ ['EqConvT3x3EX', 32, 16], ['blur3x3'], ['amp'], ['noiseE', 16, 1024], ['bias', 16], ['Lrelu'] ], 'adain1024a': [['adain', 16]], 121 | 'Fconv1024': [ ['EqConv3x3', 16, 16], ['amp'], ['noiseE', 16, 1024], ['bias', 16], ['Lrelu'] ], 'adain1024b': [['adain', 16]], 'toRGB_1024': [['EqConv1x1', 16, 3], ['bias',3]], 122 | } 123 | 124 | def _make_sequential(self,key): 125 | definition = { 126 | 'pixel_norm' : lambda *config: PixelwiseNormalization(), 127 | 'truncation' : lambda *config: TruncationTrick( 128 | num_target=8, threshold=0.7, output_num=18, style_dim=512 ), 129 | 'fc' : lambda *config: EqualizedFullyConnect( 130 | in_dim=config[0],out_dim=config[1], lr=0.01), 131 | 'b' : lambda *config: AddChannelwiseBias(out_channels=config[0], lr=0.01), 132 | 'bias' : lambda *config: AddChannelwiseBias(out_channels=config[0], lr=1.0), 133 | 'amp' : lambda *config: Amplify(2**0.5), 134 | 'Lrelu' : lambda *config: nn.LeakyReLU(negative_slope=0.2), 135 | 'EqConvT3x3EX' : lambda *config: EqualizedFusedConvTransposed2d( 136 | in_channels=config[0], out_channels=config[1], 137 | kernel_size=3, stride=1, padding=1, lr=1.0), 138 | 'EqConv3x3' : lambda *config: EqualizedConv2d( 139 | in_channels=config[0], out_channels=config[1], 140 | kernel_size=3, stride=1, padding=1, lr=1.0), 141 | 'EqConv1x1' : lambda *config: EqualizedConv2d( 142 | in_channels=config[0], out_channels=config[1], 143 | kernel_size=1, stride=1, padding=0, lr=1.0), 144 | 'noiseE' : lambda *config: ElementwiseNoise(ch=config[0], 
size_hw=config[1]), 145 | 'blur3x3' : lambda *config: Blur3x3(), 146 | 'up' : lambda *config: nn.Upsample( 147 | scale_factor=2,mode='nearest'), 148 | 'adain' : lambda *config: EqualizedAdaIN( 149 | fmap_ch=config[0], style_ch=512, lr=1.0), 150 | } 151 | return nn.Sequential(*[ definition[k](*cfg) for k,*cfg in self.structure[key]]) 152 | 153 | def __init__(self): 154 | super().__init__() 155 | 156 | # 固定入力値 157 | self.register_buffer('const',torch.ones((1, 512, 4, 4),dtype=torch.float32)) 158 | 159 | # 今回は使わない 160 | self.register_buffer('image_mixing_rate',torch.zeros((1,))) # 複数のtoRGBの合成比率 161 | self.register_buffer('style_mixing_rate',torch.zeros((1,))) # スタイルの合成比率 162 | 163 | # 潜在変数のマッピングネットワーク 164 | self.mapping = self._make_sequential('mapping') 165 | self.blocks = nn.ModuleList([self._make_sequential(k) for k in [ 166 | 'START', 'Fconv4', 'Uconv8', 'Fconv8', 'Uconv16', 'Fconv16', 167 | 'Uconv32', 'Fconv32', 'Uconv64', 'Fconv64', 'Uconv128', 'Fconv128', 168 | 'Uconv256', 'Fconv256', 'Uconv512', 'Fconv512', 'Uconv1024','Fconv1024' 169 | ] ]) 170 | self.adains = nn.ModuleList([self._make_sequential(k) for k in [ 171 | 'adain4a', 'adain4b', 'adain8a', 'adain8b', 172 | 'adain16a', 'adain16b', 'adain32a', 'adain32b', 173 | 'adain64a', 'adain64b', 'adain128a', 'adain128b', 174 | 'adain256a', 'adain256b', 'adain512a', 'adain512b', 175 | 'adain1024a', 'adain1024b' 176 | ] ]) 177 | self.toRGBs = nn.ModuleList([self._make_sequential(k) for k in [ 178 | 'toRGB_4', 'toRGB_8', 'toRGB_16', 'toRGB_32', 179 | 'toRGB_64', 'toRGB_128', 'toRGB_256', 'toRGB_512', 180 | 'toRGB_1024' 181 | ] ]) 182 | 183 | def forward(self, z): 184 | ''' 185 | input: z : (N,D) D=512 186 | output: img : (N,3,1024,1024) 187 | ''' 188 | N,D = z.shape 189 | 190 | styles = self.mapping(z) # (N,18,D) 191 | tmp = self.const.expand(N,512,4,4) 192 | for i, (adain, conv) in enumerate(zip(self.adains, self.blocks)): 193 | tmp = conv(tmp) 194 | tmp = adain( (tmp, styles[:,i,:]) ) 195 | img = self.toRGBs[-1](tmp) 196 | 197 | return img 198 | 199 | 200 | ########## 以下,重み変換 ######## 201 | 202 | name_trans_dict = { 203 | 'const' : ['any', 'G_synthesis/4x4/Const/const' ], 204 | 'image_mixing_rate' : ['uns', 'G_synthesis/lod' ], 205 | 'style_mixing_rate' : ['uns', 'lod' ], 206 | 'mapping.1.weight' : ['fc_', 'G_mapping/Dense0/weight' ], 207 | 'mapping.3.bias' : ['any', 'G_mapping/Dense0/bias' ], 208 | 'mapping.5.weight' : ['fc_', 'G_mapping/Dense1/weight' ], 209 | 'mapping.7.bias' : ['any', 'G_mapping/Dense1/bias' ], 210 | 'mapping.9.weight' : ['fc_', 'G_mapping/Dense2/weight' ], 211 | 'mapping.11.bias' : ['any', 'G_mapping/Dense2/bias' ], 212 | 'mapping.13.weight' : ['fc_', 'G_mapping/Dense3/weight' ], 213 | 'mapping.15.bias' : ['any', 'G_mapping/Dense3/bias' ], 214 | 'mapping.17.weight' : ['fc_', 'G_mapping/Dense4/weight' ], 215 | 'mapping.19.bias' : ['any', 'G_mapping/Dense4/bias' ], 216 | 'mapping.21.weight' : ['fc_', 'G_mapping/Dense5/weight' ], 217 | 'mapping.23.bias' : ['any', 'G_mapping/Dense5/bias' ], 218 | 'mapping.25.weight' : ['fc_', 'G_mapping/Dense6/weight' ], 219 | 'mapping.27.bias' : ['any', 'G_mapping/Dense6/bias' ], 220 | 'mapping.29.weight' : ['fc_', 'G_mapping/Dense7/weight' ], 221 | 'mapping.31.bias' : ['any', 'G_mapping/Dense7/bias' ], 222 | 'mapping.33.avg_style' : ['any', 'dlatent_avg' ], 223 | 'blocks.0.0.noise_scaler' : ['any', 'G_synthesis/4x4/Const/Noise/weight' ], 224 | 'blocks.0.0.const_noise' : ['any', 'G_synthesis/noise0' ], 225 | 'blocks.0.1.bias' : ['any', 'G_synthesis/4x4/Const/bias' ], 226 | 
'blocks.1.0.weight' : ['con', 'G_synthesis/4x4/Conv/weight' ], 227 | 'blocks.1.2.noise_scaler' : ['any', 'G_synthesis/4x4/Conv/Noise/weight' ], 228 | 'blocks.1.2.const_noise' : ['any', 'G_synthesis/noise1' ], 229 | 'blocks.1.3.bias' : ['any', 'G_synthesis/4x4/Conv/bias' ], 230 | 'blocks.2.1.weight' : ['con', 'G_synthesis/8x8/Conv0_up/weight' ], 231 | 'blocks.2.4.noise_scaler' : ['any', 'G_synthesis/8x8/Conv0_up/Noise/weight' ], 232 | 'blocks.2.4.const_noise' : ['any', 'G_synthesis/noise2' ], 233 | 'blocks.2.5.bias' : ['any', 'G_synthesis/8x8/Conv0_up/bias' ], 234 | 'blocks.3.0.weight' : ['con', 'G_synthesis/8x8/Conv1/weight' ], 235 | 'blocks.3.2.noise_scaler' : ['any', 'G_synthesis/8x8/Conv1/Noise/weight' ], 236 | 'blocks.3.2.const_noise' : ['any', 'G_synthesis/noise3' ], 237 | 'blocks.3.3.bias' : ['any', 'G_synthesis/8x8/Conv1/bias' ], 238 | 'blocks.4.1.weight' : ['con', 'G_synthesis/16x16/Conv0_up/weight' ], 239 | 'blocks.4.4.noise_scaler' : ['any', 'G_synthesis/16x16/Conv0_up/Noise/weight' ], 240 | 'blocks.4.4.const_noise' : ['any', 'G_synthesis/noise4' ], 241 | 'blocks.4.5.bias' : ['any', 'G_synthesis/16x16/Conv0_up/bias' ], 242 | 'blocks.5.0.weight' : ['con', 'G_synthesis/16x16/Conv1/weight' ], 243 | 'blocks.5.2.noise_scaler' : ['any', 'G_synthesis/16x16/Conv1/Noise/weight' ], 244 | 'blocks.5.2.const_noise' : ['any', 'G_synthesis/noise5' ], 245 | 'blocks.5.3.bias' : ['any', 'G_synthesis/16x16/Conv1/bias' ], 246 | 'blocks.6.1.weight' : ['con', 'G_synthesis/32x32/Conv0_up/weight' ], 247 | 'blocks.6.4.noise_scaler' : ['any', 'G_synthesis/32x32/Conv0_up/Noise/weight' ], 248 | 'blocks.6.4.const_noise' : ['any', 'G_synthesis/noise6' ], 249 | 'blocks.6.5.bias' : ['any', 'G_synthesis/32x32/Conv0_up/bias' ], 250 | 'blocks.7.0.weight' : ['con', 'G_synthesis/32x32/Conv1/weight' ], 251 | 'blocks.7.2.noise_scaler' : ['any', 'G_synthesis/32x32/Conv1/Noise/weight' ], 252 | 'blocks.7.2.const_noise' : ['any', 'G_synthesis/noise7' ], 253 | 'blocks.7.3.bias' : ['any', 'G_synthesis/32x32/Conv1/bias' ], 254 | 'blocks.8.1.weight' : ['con', 'G_synthesis/64x64/Conv0_up/weight' ], 255 | 'blocks.8.4.noise_scaler' : ['any', 'G_synthesis/64x64/Conv0_up/Noise/weight' ], 256 | 'blocks.8.4.const_noise' : ['any', 'G_synthesis/noise8' ], 257 | 'blocks.8.5.bias' : ['any', 'G_synthesis/64x64/Conv0_up/bias' ], 258 | 'blocks.9.0.weight' : ['con', 'G_synthesis/64x64/Conv1/weight' ], 259 | 'blocks.9.2.noise_scaler' : ['any', 'G_synthesis/64x64/Conv1/Noise/weight' ], 260 | 'blocks.9.2.const_noise' : ['any', 'G_synthesis/noise9' ], 261 | 'blocks.9.3.bias' : ['any', 'G_synthesis/64x64/Conv1/bias' ], 262 | 'blocks.10.0.weight' : ['Tco', 'G_synthesis/128x128/Conv0_up/weight' ], 263 | 'blocks.10.3.noise_scaler' : ['any', 'G_synthesis/128x128/Conv0_up/Noise/weight' ], 264 | 'blocks.10.3.const_noise' : ['any', 'G_synthesis/noise10' ], 265 | 'blocks.10.4.bias' : ['any', 'G_synthesis/128x128/Conv0_up/bias' ], 266 | 'blocks.11.0.weight' : ['con', 'G_synthesis/128x128/Conv1/weight' ], 267 | 'blocks.11.2.noise_scaler' : ['any', 'G_synthesis/128x128/Conv1/Noise/weight' ], 268 | 'blocks.11.2.const_noise' : ['any', 'G_synthesis/noise11' ], 269 | 'blocks.11.3.bias' : ['any', 'G_synthesis/128x128/Conv1/bias' ], 270 | 'blocks.12.0.weight' : ['Tco', 'G_synthesis/256x256/Conv0_up/weight' ], 271 | 'blocks.12.3.noise_scaler' : ['any', 'G_synthesis/256x256/Conv0_up/Noise/weight' ], 272 | 'blocks.12.3.const_noise' : ['any', 'G_synthesis/noise12' ], 273 | 'blocks.12.4.bias' : ['any', 'G_synthesis/256x256/Conv0_up/bias' ], 274 | 'blocks.13.0.weight' 
: ['con', 'G_synthesis/256x256/Conv1/weight' ], 275 | 'blocks.13.2.noise_scaler' : ['any', 'G_synthesis/256x256/Conv1/Noise/weight' ], 276 | 'blocks.13.2.const_noise' : ['any', 'G_synthesis/noise13' ], 277 | 'blocks.13.3.bias' : ['any', 'G_synthesis/256x256/Conv1/bias' ], 278 | 'blocks.14.0.weight' : ['Tco', 'G_synthesis/512x512/Conv0_up/weight' ], 279 | 'blocks.14.3.noise_scaler' : ['any', 'G_synthesis/512x512/Conv0_up/Noise/weight' ], 280 | 'blocks.14.3.const_noise' : ['any', 'G_synthesis/noise14' ], 281 | 'blocks.14.4.bias' : ['any', 'G_synthesis/512x512/Conv0_up/bias' ], 282 | 'blocks.15.0.weight' : ['con', 'G_synthesis/512x512/Conv1/weight' ], 283 | 'blocks.15.2.noise_scaler' : ['any', 'G_synthesis/512x512/Conv1/Noise/weight' ], 284 | 'blocks.15.2.const_noise' : ['any', 'G_synthesis/noise15' ], 285 | 'blocks.15.3.bias' : ['any', 'G_synthesis/512x512/Conv1/bias' ], 286 | 'blocks.16.0.weight' : ['Tco', 'G_synthesis/1024x1024/Conv0_up/weight' ], 287 | 'blocks.16.3.noise_scaler' : ['any', 'G_synthesis/1024x1024/Conv0_up/Noise/weight' ], 288 | 'blocks.16.3.const_noise' : ['any', 'G_synthesis/noise16' ], 289 | 'blocks.16.4.bias' : ['any', 'G_synthesis/1024x1024/Conv0_up/bias' ], 290 | 'blocks.17.0.weight' : ['con', 'G_synthesis/1024x1024/Conv1/weight' ], 291 | 'blocks.17.2.noise_scaler' : ['any', 'G_synthesis/1024x1024/Conv1/Noise/weight' ], 292 | 'blocks.17.2.const_noise' : ['any', 'G_synthesis/noise17' ], 293 | 'blocks.17.3.bias' : ['any', 'G_synthesis/1024x1024/Conv1/bias' ], 294 | 'adains.0.0.fc.weight' : ['fc_', 'G_synthesis/4x4/Const/StyleMod/weight' ], 295 | 'adains.0.0.bias.bias' : ['any', 'G_synthesis/4x4/Const/StyleMod/bias' ], 296 | 'adains.1.0.fc.weight' : ['fc_', 'G_synthesis/4x4/Conv/StyleMod/weight' ], 297 | 'adains.1.0.bias.bias' : ['any', 'G_synthesis/4x4/Conv/StyleMod/bias' ], 298 | 'adains.2.0.fc.weight' : ['fc_', 'G_synthesis/8x8/Conv0_up/StyleMod/weight' ], 299 | 'adains.2.0.bias.bias' : ['any', 'G_synthesis/8x8/Conv0_up/StyleMod/bias' ], 300 | 'adains.3.0.fc.weight' : ['fc_', 'G_synthesis/8x8/Conv1/StyleMod/weight' ], 301 | 'adains.3.0.bias.bias' : ['any', 'G_synthesis/8x8/Conv1/StyleMod/bias' ], 302 | 'adains.4.0.fc.weight' : ['fc_', 'G_synthesis/16x16/Conv0_up/StyleMod/weight' ], 303 | 'adains.4.0.bias.bias' : ['any', 'G_synthesis/16x16/Conv0_up/StyleMod/bias' ], 304 | 'adains.5.0.fc.weight' : ['fc_', 'G_synthesis/16x16/Conv1/StyleMod/weight' ], 305 | 'adains.5.0.bias.bias' : ['any', 'G_synthesis/16x16/Conv1/StyleMod/bias' ], 306 | 'adains.6.0.fc.weight' : ['fc_', 'G_synthesis/32x32/Conv0_up/StyleMod/weight' ], 307 | 'adains.6.0.bias.bias' : ['any', 'G_synthesis/32x32/Conv0_up/StyleMod/bias' ], 308 | 'adains.7.0.fc.weight' : ['fc_', 'G_synthesis/32x32/Conv1/StyleMod/weight' ], 309 | 'adains.7.0.bias.bias' : ['any', 'G_synthesis/32x32/Conv1/StyleMod/bias' ], 310 | 'adains.8.0.fc.weight' : ['fc_', 'G_synthesis/64x64/Conv0_up/StyleMod/weight' ], 311 | 'adains.8.0.bias.bias' : ['any', 'G_synthesis/64x64/Conv0_up/StyleMod/bias' ], 312 | 'adains.9.0.fc.weight' : ['fc_', 'G_synthesis/64x64/Conv1/StyleMod/weight' ], 313 | 'adains.9.0.bias.bias' : ['any', 'G_synthesis/64x64/Conv1/StyleMod/bias' ], 314 | 'adains.10.0.fc.weight' : ['fc_', 'G_synthesis/128x128/Conv0_up/StyleMod/weight' ], 315 | 'adains.10.0.bias.bias' : ['any', 'G_synthesis/128x128/Conv0_up/StyleMod/bias' ], 316 | 'adains.11.0.fc.weight' : ['fc_', 'G_synthesis/128x128/Conv1/StyleMod/weight' ], 317 | 'adains.11.0.bias.bias' : ['any', 'G_synthesis/128x128/Conv1/StyleMod/bias' ], 318 | 'adains.12.0.fc.weight' : 
['fc_', 'G_synthesis/256x256/Conv0_up/StyleMod/weight' ], 319 | 'adains.12.0.bias.bias' : ['any', 'G_synthesis/256x256/Conv0_up/StyleMod/bias' ], 320 | 'adains.13.0.fc.weight' : ['fc_', 'G_synthesis/256x256/Conv1/StyleMod/weight' ], 321 | 'adains.13.0.bias.bias' : ['any', 'G_synthesis/256x256/Conv1/StyleMod/bias' ], 322 | 'adains.14.0.fc.weight' : ['fc_', 'G_synthesis/512x512/Conv0_up/StyleMod/weight' ], 323 | 'adains.14.0.bias.bias' : ['any', 'G_synthesis/512x512/Conv0_up/StyleMod/bias' ], 324 | 'adains.15.0.fc.weight' : ['fc_', 'G_synthesis/512x512/Conv1/StyleMod/weight' ], 325 | 'adains.15.0.bias.bias' : ['any', 'G_synthesis/512x512/Conv1/StyleMod/bias' ], 326 | 'adains.16.0.fc.weight' : ['fc_', 'G_synthesis/1024x1024/Conv0_up/StyleMod/weight' ], 327 | 'adains.16.0.bias.bias' : ['any', 'G_synthesis/1024x1024/Conv0_up/StyleMod/bias' ], 328 | 'adains.17.0.fc.weight' : ['fc_', 'G_synthesis/1024x1024/Conv1/StyleMod/weight' ], 329 | 'adains.17.0.bias.bias' : ['any', 'G_synthesis/1024x1024/Conv1/StyleMod/bias' ], 330 | 'toRGBs.0.0.weight' : ['con', 'G_synthesis/ToRGB_lod8/weight' ], 331 | 'toRGBs.0.1.bias' : ['any', 'G_synthesis/ToRGB_lod8/bias' ], 332 | 'toRGBs.1.0.weight' : ['con', 'G_synthesis/ToRGB_lod7/weight' ], 333 | 'toRGBs.1.1.bias' : ['any', 'G_synthesis/ToRGB_lod7/bias' ], 334 | 'toRGBs.2.0.weight' : ['con', 'G_synthesis/ToRGB_lod6/weight' ], 335 | 'toRGBs.2.1.bias' : ['any', 'G_synthesis/ToRGB_lod6/bias' ], 336 | 'toRGBs.3.0.weight' : ['con', 'G_synthesis/ToRGB_lod5/weight' ], 337 | 'toRGBs.3.1.bias' : ['any', 'G_synthesis/ToRGB_lod5/bias' ], 338 | 'toRGBs.4.0.weight' : ['con', 'G_synthesis/ToRGB_lod4/weight' ], 339 | 'toRGBs.4.1.bias' : ['any', 'G_synthesis/ToRGB_lod4/bias' ], 340 | 'toRGBs.5.0.weight' : ['con', 'G_synthesis/ToRGB_lod3/weight' ], 341 | 'toRGBs.5.1.bias' : ['any', 'G_synthesis/ToRGB_lod3/bias' ], 342 | 'toRGBs.6.0.weight' : ['con', 'G_synthesis/ToRGB_lod2/weight' ], 343 | 'toRGBs.6.1.bias' : ['any', 'G_synthesis/ToRGB_lod2/bias' ], 344 | 'toRGBs.7.0.weight' : ['con', 'G_synthesis/ToRGB_lod1/weight' ], 345 | 'toRGBs.7.1.bias' : ['any', 'G_synthesis/ToRGB_lod1/bias' ], 346 | 'toRGBs.8.0.weight' : ['con', 'G_synthesis/ToRGB_lod0/weight' ], 347 | 'toRGBs.8.1.bias' : ['any', 'G_synthesis/ToRGB_lod0/bias' ], 348 | } 349 | -------------------------------------------------------------------------------- /network/stylegan2.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from common import PixelwiseNormalization, Amplify, AddChannelwiseBias, EqualizedFullyConnect, TruncationTrick 8 | 9 | 10 | # 固定ノイズ 11 | class PixelwiseNoise(nn.Module): 12 | def __init__(self, resolution): 13 | super().__init__() 14 | self.register_buffer("const_noise", torch.randn((1, 1, resolution, resolution))) 15 | self.noise_scaler = nn.Parameter(torch.zeros(1)) 16 | 17 | def forward(self, x): 18 | N,C,H,W = x.shape 19 | noise = self.const_noise.expand(N,C,H,W) 20 | return x + noise * self.noise_scaler 21 | 22 | 23 | # 解像度上げるときのModConvのみ 24 | class FusedBlur3x3(nn.Module): 25 | def __init__(self): 26 | super().__init__() 27 | kernel = np.array([ [1/16, 2/16, 1/16], 28 | [2/16, 4/16, 2/16], 29 | [1/16, 2/16, 1/16]],dtype=np.float32) 30 | pads = [[(0,1),(0,1)],[(0,1),(1,0)],[(1,0),(0,1)],[(1,0),(1,0)]] 31 | kernel = np.stack( [np.pad(kernel,pad,'constant') for pad in pads] ).sum(0) 32 | #kernel [ [1/16, 3/16, 3/16, 1/16,], 33 | # [3/16, 
9/16, 9/16, 3/16,], 34 | # [3/16, 9/16, 9/16, 3/16,], 35 | # [1/16, 3/16, 3/16, 1/16,] ] 36 | self.kernel = torch.from_numpy(kernel) 37 | 38 | def forward(self, feature): 39 | # featureは(N,C,H+1,W+1) 40 | kernel = self.kernel.clone().to(feature.device) 41 | _N,C,_Hp1,_Wp1 = feature.shape 42 | return F.conv2d(feature, kernel.expand(C,1,4,4), padding=1, groups=C) 43 | 44 | 45 | # 学習率を調整した変調転置畳み込み 46 | class EqualizedModulatedConvTranspose2d(nn.Module): 47 | def __init__(self, in_channels, out_channels, kernel_size, style_dim, padding, stride, demodulate=True, lr=1): 48 | super().__init__() 49 | 50 | self.padding, self.stride = padding, stride 51 | self.demodulate = demodulate 52 | 53 | self.weight = nn.Parameter( torch.randn(in_channels, out_channels, kernel_size, kernel_size)) 54 | torch.nn.init.normal_(self.weight.data, mean=0.0, std=1.0/lr) 55 | self.weight_scaler = 1 / (in_channels * kernel_size*kernel_size)**0.5 * lr 56 | 57 | self.fc = EqualizedFullyConnect(style_dim, in_channels, lr) 58 | self.bias = AddChannelwiseBias(in_channels, lr) 59 | 60 | def forward(self, pack): 61 | x, style = pack 62 | N, iC, H, W = x.shape 63 | iC, oC, kH, kW = self.weight.shape 64 | 65 | mod_rates = self.bias(self.fc(style))+1 # (N, iC) 66 | modulated_weight = self.weight_scaler*self.weight.view(1,iC,oC,kH,kW) \ 67 | * mod_rates.view(N,iC,1,1,1) # (N,iC,oC,kH,kW) 68 | 69 | if self.demodulate: 70 | demod_norm = 1 / ((modulated_weight**2).sum([1,3,4]) + 1e-8)**0.5 # (N, oC) 71 | weight = modulated_weight * demod_norm.view(N, 1, oC, 1, 1) # (N,iC,oC,kH,kW) 72 | else: 73 | weight = modulated_weight 74 | 75 | x = x.view(1, N*iC, H, W) 76 | weight = weight.view(N*iC,oC,kH,kW) 77 | out = F.conv_transpose2d(x, weight, padding=self.padding, stride=self.stride, groups=N) 78 | 79 | _, _, Hp1, Wp1 = out.shape 80 | out = out.view(N, oC, Hp1, Wp1) 81 | 82 | return out 83 | 84 | 85 | # 学習率を調整した変調畳み込み 86 | class EqualizedModulatedConv2d(nn.Module): 87 | def __init__(self, in_channels, out_channels, kernel_size, style_dim, padding, stride, demodulate=True, lr=1): 88 | super().__init__() 89 | 90 | self.padding, self.stride = padding, stride 91 | self.demodulate = demodulate 92 | 93 | self.weight = nn.Parameter( torch.randn(out_channels, in_channels, kernel_size, kernel_size)) 94 | torch.nn.init.normal_(self.weight.data, mean=0.0, std=1.0/lr) 95 | self.weight_scaler = 1 / (in_channels*kernel_size*kernel_size)**0.5 * lr 96 | 97 | self.fc = EqualizedFullyConnect(style_dim, in_channels, lr) 98 | self.bias = AddChannelwiseBias(in_channels, lr) 99 | 100 | def forward(self, pack): 101 | x, style = pack 102 | N, iC, H, W = x.shape 103 | oC, iC, kH, kW = self.weight.shape 104 | 105 | mod_rates = self.bias(self.fc(style))+1 # (N, iC) 106 | modulated_weight = self.weight_scaler*self.weight.view(1,oC,iC,kH,kW) \ 107 | * mod_rates.view(N,1,iC,1,1) # (N,oC,iC,kH,kW) 108 | 109 | if self.demodulate: 110 | demod_norm = 1 / ((modulated_weight**2).sum([2,3,4]) + 1e-8)**0.5 # (N, oC) 111 | weight = modulated_weight * demod_norm.view(N, oC, 1, 1, 1) # (N,oC,iC,kH,kW) 112 | else: # ToRGB 113 | weight = modulated_weight 114 | 115 | out = F.conv2d(x.view(1,N*iC,H,W), weight.view(N*oC,iC,kH,kW), 116 | padding=self.padding, stride=self.stride, groups=N).view(N,oC,H,W) 117 | return out 118 | 119 | 120 | class Generator(nn.Module): 121 | 122 | structure = { 123 | 'mapping': [['pixel_norm'], ['fc',512,512],['b',512],['amp'],['Lrelu'],['fc',512,512],['b',512],['amp'],['Lrelu'], 124 | 
['fc',512,512],['b',512],['amp'],['Lrelu'],['fc',512,512],['b',512],['amp'],['Lrelu'], 125 | ['fc',512,512],['b',512],['amp'],['Lrelu'],['fc',512,512],['b',512],['amp'],['Lrelu'], 126 | ['fc',512,512],['b',512],['amp'],['Lrelu'],['fc',512,512],['b',512],['amp'],['Lrelu'],['truncation']], 127 | 'Fconv4' : [['EqModConv3x3', 512, 512], ['noiseP', 4], ['bias',512], ['amp'], ['Lrelu'] ], 'toRGB_4' : [['EqModConv1x1',512, 3], ['bias',3]], 128 | 'Uconv8' : [['EqModConvT3x3', 512, 512], ['blurEX'], ['noiseP', 8], ['bias',512], ['amp'], ['Lrelu'] ], 129 | 'Fconv8' : [['EqModConv3x3', 512, 512], ['noiseP', 8], ['bias',512], ['amp'], ['Lrelu'] ], 'toRGB_8' : [['EqModConv1x1',512, 3], ['bias',3]], 130 | 'Uconv16' : [['EqModConvT3x3', 512, 512], ['blurEX'], ['noiseP', 16], ['bias',512], ['amp'], ['Lrelu'] ], 131 | 'Fconv16' : [['EqModConv3x3', 512, 512], ['noiseP', 16], ['bias',512], ['amp'], ['Lrelu'] ], 'toRGB_16' : [['EqModConv1x1',512, 3], ['bias',3]], 132 | 'Uconv32' : [['EqModConvT3x3', 512, 512], ['blurEX'], ['noiseP', 32], ['bias',512], ['amp'], ['Lrelu'] ], 133 | 'Fconv32' : [['EqModConv3x3', 512, 512], ['noiseP', 32], ['bias',512], ['amp'], ['Lrelu'] ], 'toRGB_32' : [['EqModConv1x1',512, 3], ['bias',3]], 134 | 'Uconv64' : [['EqModConvT3x3', 512, 512], ['blurEX'], ['noiseP', 64], ['bias',512], ['amp'], ['Lrelu'] ], 135 | 'Fconv64' : [['EqModConv3x3', 512, 512], ['noiseP', 64], ['bias',512], ['amp'], ['Lrelu'] ], 'toRGB_64' : [['EqModConv1x1',512, 3], ['bias',3]], 136 | 'Uconv128' : [['EqModConvT3x3', 512, 256], ['blurEX'], ['noiseP', 128], ['bias',256], ['amp'], ['Lrelu'] ], 137 | 'Fconv128' : [['EqModConv3x3', 256, 256], ['noiseP', 128], ['bias',256], ['amp'], ['Lrelu'] ], 'toRGB_128' : [['EqModConv1x1',256, 3], ['bias',3]], 138 | 'Uconv256' : [['EqModConvT3x3', 256, 128], ['blurEX'], ['noiseP', 256], ['bias',128], ['amp'], ['Lrelu'] ], 139 | 'Fconv256' : [['EqModConv3x3', 128, 128], ['noiseP', 256], ['bias',128], ['amp'], ['Lrelu'] ], 'toRGB_256' : [['EqModConv1x1',128, 3], ['bias',3]], 140 | 'Uconv512' : [['EqModConvT3x3', 128, 64], ['blurEX'], ['noiseP', 512], ['bias', 64], ['amp'], ['Lrelu'] ], 141 | 'Fconv512' : [['EqModConv3x3', 64, 64], ['noiseP', 512], ['bias', 64], ['amp'], ['Lrelu'] ], 'toRGB_512' : [['EqModConv1x1', 64, 3], ['bias',3]], 142 | 'Uconv1024': [['EqModConvT3x3', 64, 32], ['blurEX'], ['noiseP',1024], ['bias', 32], ['amp'], ['Lrelu'] ], 143 | 'Fconv1024': [['EqModConv3x3', 32, 32], ['noiseP',1024], ['bias', 32], ['amp'], ['Lrelu'] ], 'toRGB_1024': [['EqModConv1x1', 32, 3], ['bias',3]], 144 | } 145 | 146 | def _make_sequential(self,key): 147 | definition = { 148 | 'pixel_norm' : lambda *config: PixelwiseNormalization(), 149 | 'truncation' : lambda *config: TruncationTrick( 150 | num_target=10, threshold=0.7, output_num=18, style_dim=512), 151 | 'fc' : lambda *config: EqualizedFullyConnect( 152 | in_dim=config[0], out_dim=config[1], lr=0.01), 153 | 'b' : lambda *config: AddChannelwiseBias(out_channels=config[0], lr=0.01), 154 | 'bias' : lambda *config: AddChannelwiseBias(out_channels=config[0], lr=1.0), 155 | 'amp' : lambda *config: Amplify(2**0.5), 156 | 'Lrelu' : lambda *config: nn.LeakyReLU(negative_slope=0.2), 157 | 'EqModConvT3x3': lambda *config: EqualizedModulatedConvTranspose2d( 158 | in_channels=config[0], out_channels=config[1], 159 | kernel_size=3, stride=2, padding=0, 160 | demodulate=True, lr=1.0, style_dim=512), 161 | 'EqModConv3x3' : lambda *config: EqualizedModulatedConv2d( 162 | in_channels=config[0], out_channels=config[1], 163 | kernel_size=3, stride=1, 
padding=1, 164 | demodulate=True, lr=1.0, style_dim=512), 165 | 'EqModConv1x1' : lambda *config: EqualizedModulatedConv2d( 166 | in_channels=config[0], out_channels=config[1], 167 | kernel_size=1, stride=1, padding=0, 168 | demodulate=False, lr=1.0, style_dim=512), 169 | 'noiseP' : lambda *config: PixelwiseNoise(resolution=config[0]), 170 | 'blurEX' : lambda *config: FusedBlur3x3(), 171 | } 172 | return nn.Sequential(*[ definition[k](*cfg) for k,*cfg in self.structure[key]]) 173 | 174 | 175 | def __init__(self): 176 | super().__init__() 177 | 178 | self.const_input = nn.Parameter(torch.randn(1, 512, 4, 4)) 179 | self.register_buffer('style_mixing_rate',torch.zeros((1,))) # スタイルの合成比率,今回は使わない 180 | 181 | self.mapping = self._make_sequential('mapping') 182 | self.blocks = nn.ModuleList([self._make_sequential(k) for k in [ 183 | 'Fconv4', 'Uconv8', 'Fconv8', 'Uconv16', 'Fconv16', 184 | 'Uconv32', 'Fconv32', 'Uconv64', 'Fconv64', 'Uconv128', 'Fconv128', 185 | 'Uconv256', 'Fconv256', 'Uconv512', 'Fconv512', 'Uconv1024','Fconv1024' 186 | ] ]) 187 | self.toRGBs = nn.ModuleList([self._make_sequential(k) for k in [ 188 | 'toRGB_4', 'toRGB_8', 'toRGB_16', 'toRGB_32', 189 | 'toRGB_64', 'toRGB_128', 'toRGB_256', 'toRGB_512', 190 | 'toRGB_1024' 191 | ] ]) 192 | 193 | 194 | def forward(self, z): 195 | N,D = z.shape 196 | 197 | # 潜在変数からスタイルへ変換 198 | styles = self.mapping(z) # (N,18,D) 199 | styles = [styles[:,i] for i in range(18)] # list[(N,D),]x18 200 | 201 | tmp = self.const_input.repeat(N, 1, 1, 1) 202 | tmp = self.blocks[0]( (tmp,styles[0]) ) 203 | skip = self.toRGBs[0]( (tmp,styles[1]) ) 204 | 205 | for convU, convF, toRGB, styU,styF,styT in zip( \ 206 | self.blocks[1::2], self.blocks[2::2], self.toRGBs[1:], 207 | styles[1::2], styles[2::2], styles[3::2]): 208 | tmp = convU( (tmp,styU) ) 209 | tmp = convF( (tmp,styF) ) 210 | skip = toRGB( (tmp,styT) ) + F.interpolate(skip,scale_factor=2,mode='bilinear',align_corners=False) 211 | 212 | return skip 213 | 214 | 215 | # { pytorchでの名前 : [変換関数, tensorflowでの名前] } 216 | name_trans_dict = { 217 | 'const_input' : ['any', 'G_synthesis/4x4/Const/const' ], 218 | 'style_mixing_rate' : ['uns', 'lod' ], 219 | 'mapping.1.weight' : ['fc_', 'G_mapping/Dense0/weight' ], 220 | 'mapping.2.bias' : ['any', 'G_mapping/Dense0/bias' ], 221 | 'mapping.5.weight' : ['fc_', 'G_mapping/Dense1/weight' ], 222 | 'mapping.6.bias' : ['any', 'G_mapping/Dense1/bias' ], 223 | 'mapping.9.weight' : ['fc_', 'G_mapping/Dense2/weight' ], 224 | 'mapping.10.bias' : ['any', 'G_mapping/Dense2/bias' ], 225 | 'mapping.13.weight' : ['fc_', 'G_mapping/Dense3/weight' ], 226 | 'mapping.14.bias' : ['any', 'G_mapping/Dense3/bias' ], 227 | 'mapping.17.weight' : ['fc_', 'G_mapping/Dense4/weight' ], 228 | 'mapping.18.bias' : ['any', 'G_mapping/Dense4/bias' ], 229 | 'mapping.21.weight' : ['fc_', 'G_mapping/Dense5/weight' ], 230 | 'mapping.22.bias' : ['any', 'G_mapping/Dense5/bias' ], 231 | 'mapping.25.weight' : ['fc_', 'G_mapping/Dense6/weight' ], 232 | 'mapping.26.bias' : ['any', 'G_mapping/Dense6/bias' ], 233 | 'mapping.29.weight' : ['fc_', 'G_mapping/Dense7/weight' ], 234 | 'mapping.30.bias' : ['any', 'G_mapping/Dense7/bias' ], 235 | 'mapping.33.avg_style' : ['any', 'dlatent_avg' ], 236 | 'blocks.0.0.weight' : ['con', 'G_synthesis/4x4/Conv/weight' ], 237 | 'blocks.0.0.fc.weight' : ['fc_', 'G_synthesis/4x4/Conv/mod_weight' ], 238 | 'blocks.0.0.bias.bias' : ['any', 'G_synthesis/4x4/Conv/mod_bias' ], 239 | 'blocks.0.1.noise_scaler' : ['uns', 'G_synthesis/4x4/Conv/noise_strength' ], 240 | 
'blocks.0.1.const_noise' : ['any', 'G_synthesis/noise0' ], 241 | 'blocks.0.2.bias' : ['any', 'G_synthesis/4x4/Conv/bias' ], 242 | 'blocks.1.0.weight' : ['mTc', 'G_synthesis/8x8/Conv0_up/weight' ], 243 | 'blocks.1.0.fc.weight' : ['fc_', 'G_synthesis/8x8/Conv0_up/mod_weight' ], 244 | 'blocks.1.0.bias.bias' : ['any', 'G_synthesis/8x8/Conv0_up/mod_bias' ], 245 | 'blocks.1.2.noise_scaler' : ['uns', 'G_synthesis/8x8/Conv0_up/noise_strength' ], 246 | 'blocks.1.2.const_noise' : ['any', 'G_synthesis/noise1' ], 247 | 'blocks.1.3.bias' : ['any', 'G_synthesis/8x8/Conv0_up/bias' ], 248 | 'blocks.2.0.weight' : ['con', 'G_synthesis/8x8/Conv1/weight' ], 249 | 'blocks.2.0.fc.weight' : ['fc_', 'G_synthesis/8x8/Conv1/mod_weight' ], 250 | 'blocks.2.0.bias.bias' : ['any', 'G_synthesis/8x8/Conv1/mod_bias' ], 251 | 'blocks.2.1.noise_scaler' : ['uns', 'G_synthesis/8x8/Conv1/noise_strength' ], 252 | 'blocks.2.1.const_noise' : ['any', 'G_synthesis/noise2' ], 253 | 'blocks.2.2.bias' : ['any', 'G_synthesis/8x8/Conv1/bias' ], 254 | 'blocks.3.0.weight' : ['mTc', 'G_synthesis/16x16/Conv0_up/weight' ], 255 | 'blocks.3.0.fc.weight' : ['fc_', 'G_synthesis/16x16/Conv0_up/mod_weight' ], 256 | 'blocks.3.0.bias.bias' : ['any', 'G_synthesis/16x16/Conv0_up/mod_bias' ], 257 | 'blocks.3.2.noise_scaler' : ['uns', 'G_synthesis/16x16/Conv0_up/noise_strength' ], 258 | 'blocks.3.2.const_noise' : ['any', 'G_synthesis/noise3' ], 259 | 'blocks.3.3.bias' : ['any', 'G_synthesis/16x16/Conv0_up/bias' ], 260 | 'blocks.4.0.weight' : ['con', 'G_synthesis/16x16/Conv1/weight' ], 261 | 'blocks.4.0.fc.weight' : ['fc_', 'G_synthesis/16x16/Conv1/mod_weight' ], 262 | 'blocks.4.0.bias.bias' : ['any', 'G_synthesis/16x16/Conv1/mod_bias' ], 263 | 'blocks.4.1.noise_scaler' : ['uns', 'G_synthesis/16x16/Conv1/noise_strength' ], 264 | 'blocks.4.1.const_noise' : ['any', 'G_synthesis/noise4' ], 265 | 'blocks.4.2.bias' : ['any', 'G_synthesis/16x16/Conv1/bias' ], 266 | 'blocks.5.0.weight' : ['mTc', 'G_synthesis/32x32/Conv0_up/weight' ], 267 | 'blocks.5.0.fc.weight' : ['fc_', 'G_synthesis/32x32/Conv0_up/mod_weight' ], 268 | 'blocks.5.0.bias.bias' : ['any', 'G_synthesis/32x32/Conv0_up/mod_bias' ], 269 | 'blocks.5.2.noise_scaler' : ['uns', 'G_synthesis/32x32/Conv0_up/noise_strength' ], 270 | 'blocks.5.2.const_noise' : ['any', 'G_synthesis/noise5' ], 271 | 'blocks.5.3.bias' : ['any', 'G_synthesis/32x32/Conv0_up/bias' ], 272 | 'blocks.6.0.weight' : ['con', 'G_synthesis/32x32/Conv1/weight' ], 273 | 'blocks.6.0.fc.weight' : ['fc_', 'G_synthesis/32x32/Conv1/mod_weight' ], 274 | 'blocks.6.0.bias.bias' : ['any', 'G_synthesis/32x32/Conv1/mod_bias' ], 275 | 'blocks.6.1.noise_scaler' : ['uns', 'G_synthesis/32x32/Conv1/noise_strength' ], 276 | 'blocks.6.1.const_noise' : ['any', 'G_synthesis/noise6' ], 277 | 'blocks.6.2.bias' : ['any', 'G_synthesis/32x32/Conv1/bias' ], 278 | 'blocks.7.0.weight' : ['mTc', 'G_synthesis/64x64/Conv0_up/weight' ], 279 | 'blocks.7.0.fc.weight' : ['fc_', 'G_synthesis/64x64/Conv0_up/mod_weight' ], 280 | 'blocks.7.0.bias.bias' : ['any', 'G_synthesis/64x64/Conv0_up/mod_bias' ], 281 | 'blocks.7.2.noise_scaler' : ['uns', 'G_synthesis/64x64/Conv0_up/noise_strength' ], 282 | 'blocks.7.2.const_noise' : ['any', 'G_synthesis/noise7' ], 283 | 'blocks.7.3.bias' : ['any', 'G_synthesis/64x64/Conv0_up/bias' ], 284 | 'blocks.8.0.weight' : ['con', 'G_synthesis/64x64/Conv1/weight' ], 285 | 'blocks.8.0.fc.weight' : ['fc_', 'G_synthesis/64x64/Conv1/mod_weight' ], 286 | 'blocks.8.0.bias.bias' : ['any', 'G_synthesis/64x64/Conv1/mod_bias' ], 287 | 'blocks.8.1.noise_scaler' : 
['uns', 'G_synthesis/64x64/Conv1/noise_strength' ], 288 | 'blocks.8.1.const_noise' : ['any', 'G_synthesis/noise8' ], 289 | 'blocks.8.2.bias' : ['any', 'G_synthesis/64x64/Conv1/bias' ], 290 | 'blocks.9.0.weight' : ['mTc', 'G_synthesis/128x128/Conv0_up/weight' ], 291 | 'blocks.9.0.fc.weight' : ['fc_', 'G_synthesis/128x128/Conv0_up/mod_weight' ], 292 | 'blocks.9.0.bias.bias' : ['any', 'G_synthesis/128x128/Conv0_up/mod_bias' ], 293 | 'blocks.9.2.noise_scaler' : ['uns', 'G_synthesis/128x128/Conv0_up/noise_strength' ], 294 | 'blocks.9.2.const_noise' : ['any', 'G_synthesis/noise9' ], 295 | 'blocks.9.3.bias' : ['any', 'G_synthesis/128x128/Conv0_up/bias' ], 296 | 'blocks.10.0.weight' : ['con', 'G_synthesis/128x128/Conv1/weight' ], 297 | 'blocks.10.0.fc.weight' : ['fc_', 'G_synthesis/128x128/Conv1/mod_weight' ], 298 | 'blocks.10.0.bias.bias' : ['any', 'G_synthesis/128x128/Conv1/mod_bias' ], 299 | 'blocks.10.1.noise_scaler' : ['uns', 'G_synthesis/128x128/Conv1/noise_strength' ], 300 | 'blocks.10.1.const_noise' : ['any', 'G_synthesis/noise10' ], 301 | 'blocks.10.2.bias' : ['any', 'G_synthesis/128x128/Conv1/bias' ], 302 | 'blocks.11.0.weight' : ['mTc', 'G_synthesis/256x256/Conv0_up/weight' ], 303 | 'blocks.11.0.fc.weight' : ['fc_', 'G_synthesis/256x256/Conv0_up/mod_weight' ], 304 | 'blocks.11.0.bias.bias' : ['any', 'G_synthesis/256x256/Conv0_up/mod_bias' ], 305 | 'blocks.11.2.noise_scaler' : ['uns', 'G_synthesis/256x256/Conv0_up/noise_strength' ], 306 | 'blocks.11.2.const_noise' : ['any', 'G_synthesis/noise11' ], 307 | 'blocks.11.3.bias' : ['any', 'G_synthesis/256x256/Conv0_up/bias' ], 308 | 'blocks.12.0.weight' : ['con', 'G_synthesis/256x256/Conv1/weight' ], 309 | 'blocks.12.0.fc.weight' : ['fc_', 'G_synthesis/256x256/Conv1/mod_weight' ], 310 | 'blocks.12.0.bias.bias' : ['any', 'G_synthesis/256x256/Conv1/mod_bias' ], 311 | 'blocks.12.1.noise_scaler' : ['uns', 'G_synthesis/256x256/Conv1/noise_strength' ], 312 | 'blocks.12.1.const_noise' : ['any', 'G_synthesis/noise12' ], 313 | 'blocks.12.2.bias' : ['any', 'G_synthesis/256x256/Conv1/bias' ], 314 | 'blocks.13.0.weight' : ['mTc', 'G_synthesis/512x512/Conv0_up/weight' ], 315 | 'blocks.13.0.fc.weight' : ['fc_', 'G_synthesis/512x512/Conv0_up/mod_weight' ], 316 | 'blocks.13.0.bias.bias' : ['any', 'G_synthesis/512x512/Conv0_up/mod_bias' ], 317 | 'blocks.13.2.noise_scaler' : ['uns', 'G_synthesis/512x512/Conv0_up/noise_strength' ], 318 | 'blocks.13.2.const_noise' : ['any', 'G_synthesis/noise13' ], 319 | 'blocks.13.3.bias' : ['any', 'G_synthesis/512x512/Conv0_up/bias' ], 320 | 'blocks.14.0.weight' : ['con', 'G_synthesis/512x512/Conv1/weight' ], 321 | 'blocks.14.0.fc.weight' : ['fc_', 'G_synthesis/512x512/Conv1/mod_weight' ], 322 | 'blocks.14.0.bias.bias' : ['any', 'G_synthesis/512x512/Conv1/mod_bias' ], 323 | 'blocks.14.1.noise_scaler' : ['uns', 'G_synthesis/512x512/Conv1/noise_strength' ], 324 | 'blocks.14.1.const_noise' : ['any', 'G_synthesis/noise14' ], 325 | 'blocks.14.2.bias' : ['any', 'G_synthesis/512x512/Conv1/bias' ], 326 | 'blocks.15.0.weight' : ['mTc', 'G_synthesis/1024x1024/Conv0_up/weight' ], 327 | 'blocks.15.0.fc.weight' : ['fc_', 'G_synthesis/1024x1024/Conv0_up/mod_weight' ], 328 | 'blocks.15.0.bias.bias' : ['any', 'G_synthesis/1024x1024/Conv0_up/mod_bias' ], 329 | 'blocks.15.2.noise_scaler' : ['uns', 'G_synthesis/1024x1024/Conv0_up/noise_strength'], 330 | 'blocks.15.2.const_noise' : ['any', 'G_synthesis/noise15' ], 331 | 'blocks.15.3.bias' : ['any', 'G_synthesis/1024x1024/Conv0_up/bias' ], 332 | 'blocks.16.0.weight' : ['con', 
'G_synthesis/1024x1024/Conv1/weight' ], 333 | 'blocks.16.0.fc.weight' : ['fc_', 'G_synthesis/1024x1024/Conv1/mod_weight' ], 334 | 'blocks.16.0.bias.bias' : ['any', 'G_synthesis/1024x1024/Conv1/mod_bias' ], 335 | 'blocks.16.1.noise_scaler' : ['uns', 'G_synthesis/1024x1024/Conv1/noise_strength' ], 336 | 'blocks.16.1.const_noise' : ['any', 'G_synthesis/noise16' ], 337 | 'blocks.16.2.bias' : ['any', 'G_synthesis/1024x1024/Conv1/bias' ], 338 | 'toRGBs.0.0.weight' : ['con', 'G_synthesis/4x4/ToRGB/weight' ], 339 | 'toRGBs.0.0.fc.weight' : ['fc_', 'G_synthesis/4x4/ToRGB/mod_weight' ], 340 | 'toRGBs.0.0.bias.bias' : ['any', 'G_synthesis/4x4/ToRGB/mod_bias' ], 341 | 'toRGBs.0.1.bias' : ['any', 'G_synthesis/4x4/ToRGB/bias' ], 342 | 'toRGBs.1.0.weight' : ['con', 'G_synthesis/8x8/ToRGB/weight' ], 343 | 'toRGBs.1.0.fc.weight' : ['fc_', 'G_synthesis/8x8/ToRGB/mod_weight' ], 344 | 'toRGBs.1.0.bias.bias' : ['any', 'G_synthesis/8x8/ToRGB/mod_bias' ], 345 | 'toRGBs.1.1.bias' : ['any', 'G_synthesis/8x8/ToRGB/bias' ], 346 | 'toRGBs.2.0.weight' : ['con', 'G_synthesis/16x16/ToRGB/weight' ], 347 | 'toRGBs.2.0.fc.weight' : ['fc_', 'G_synthesis/16x16/ToRGB/mod_weight' ], 348 | 'toRGBs.2.0.bias.bias' : ['any', 'G_synthesis/16x16/ToRGB/mod_bias' ], 349 | 'toRGBs.2.1.bias' : ['any', 'G_synthesis/16x16/ToRGB/bias' ], 350 | 'toRGBs.3.0.weight' : ['con', 'G_synthesis/32x32/ToRGB/weight' ], 351 | 'toRGBs.3.0.fc.weight' : ['fc_', 'G_synthesis/32x32/ToRGB/mod_weight' ], 352 | 'toRGBs.3.0.bias.bias' : ['any', 'G_synthesis/32x32/ToRGB/mod_bias' ], 353 | 'toRGBs.3.1.bias' : ['any', 'G_synthesis/32x32/ToRGB/bias' ], 354 | 'toRGBs.4.0.weight' : ['con', 'G_synthesis/64x64/ToRGB/weight' ], 355 | 'toRGBs.4.0.fc.weight' : ['fc_', 'G_synthesis/64x64/ToRGB/mod_weight' ], 356 | 'toRGBs.4.0.bias.bias' : ['any', 'G_synthesis/64x64/ToRGB/mod_bias' ], 357 | 'toRGBs.4.1.bias' : ['any', 'G_synthesis/64x64/ToRGB/bias' ], 358 | 'toRGBs.5.0.weight' : ['con', 'G_synthesis/128x128/ToRGB/weight' ], 359 | 'toRGBs.5.0.fc.weight' : ['fc_', 'G_synthesis/128x128/ToRGB/mod_weight' ], 360 | 'toRGBs.5.0.bias.bias' : ['any', 'G_synthesis/128x128/ToRGB/mod_bias' ], 361 | 'toRGBs.5.1.bias' : ['any', 'G_synthesis/128x128/ToRGB/bias' ], 362 | 'toRGBs.6.0.weight' : ['con', 'G_synthesis/256x256/ToRGB/weight' ], 363 | 'toRGBs.6.0.fc.weight' : ['fc_', 'G_synthesis/256x256/ToRGB/mod_weight' ], 364 | 'toRGBs.6.0.bias.bias' : ['any', 'G_synthesis/256x256/ToRGB/mod_bias' ], 365 | 'toRGBs.6.1.bias' : ['any', 'G_synthesis/256x256/ToRGB/bias' ], 366 | 'toRGBs.7.0.weight' : ['con', 'G_synthesis/512x512/ToRGB/weight' ], 367 | 'toRGBs.7.0.fc.weight' : ['fc_', 'G_synthesis/512x512/ToRGB/mod_weight' ], 368 | 'toRGBs.7.0.bias.bias' : ['any', 'G_synthesis/512x512/ToRGB/mod_bias' ], 369 | 'toRGBs.7.1.bias' : ['any', 'G_synthesis/512x512/ToRGB/bias' ], 370 | 'toRGBs.8.0.weight' : ['con', 'G_synthesis/1024x1024/ToRGB/weight' ], 371 | 'toRGBs.8.0.fc.weight' : ['fc_', 'G_synthesis/1024x1024/ToRGB/mod_weight' ], 372 | 'toRGBs.8.0.bias.bias' : ['any', 'G_synthesis/1024x1024/ToRGB/mod_bias' ], 373 | 'toRGBs.8.1.bias' : ['any', 'G_synthesis/1024x1024/ToRGB/bias' ], 374 | } 375 | -------------------------------------------------------------------------------- /packaged/run_pt_stylegan1.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | import pickle 4 | 5 | import numpy as np 6 | import cv2 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | # コマンドライン引数の取得 13 | 
def parse_args(): 14 | parser = argparse.ArgumentParser(description='著者実装を動かしたり重みを抜き出したり') 15 | parser.add_argument('-w','--weight_dir',type=str,default='/tmp/stylegans-pytorch', 16 | help='学習済みのモデルを保存する場所') 17 | parser.add_argument('-o','--output_dir',type=str,default='/tmp/stylegans-pytorch', 18 | help='生成された画像を保存する場所') 19 | parser.add_argument('--batch_size',type=int,default=1, 20 | help='バッチサイズ') 21 | parser.add_argument('--device',type=str,default='gpu',choices=['gpu','cpu'], 22 | help='デバイス') 23 | args = parser.parse_args() 24 | args.resolution = 1024 25 | return args 26 | 27 | 28 | # mapping前に潜在変数を超球面上に正規化 29 | class PixelwiseNormalization(nn.Module): 30 | def __init__(self): 31 | super().__init__() 32 | def forward(self, x): 33 | return x / torch.sqrt((x**2).mean(1,keepdim=True) + 1e-8) 34 | 35 | 36 | # 移動平均を用いて潜在変数を正規化する. 37 | class TruncationTrick(nn.Module): 38 | def __init__(self, num_target, threshold, output_num, style_dim): 39 | super().__init__() 40 | self.num_target = num_target 41 | self.threshold = threshold 42 | self.output_num = output_num 43 | self.register_buffer('avg_style', torch.zeros((style_dim,))) 44 | 45 | def forward(self, x): 46 | # in:(N,D) -> out:(N,O,D) 47 | N,D = x.shape 48 | O = self.output_num 49 | x = x.view(N,1,D).expand(N,O,D) 50 | rate = torch.cat([ torch.ones((N, self.num_target, D)) *self.threshold, 51 | torch.ones((N, O-self.num_target, D)) *1.0 ],1).to(x.device) 52 | avg = self.avg_style.view(1,1,D).expand(N,O,D) 53 | return avg + (x-avg)*rate 54 | 55 | 56 | # 特徴マップ信号を増幅する 57 | class Amplify(nn.Module): 58 | def __init__(self, rate): 59 | super().__init__() 60 | self.rate = rate 61 | def forward(self,x): 62 | return x * self.rate 63 | 64 | 65 | # チャンネルごとにバイアス項を足す 66 | class AddChannelwiseBias(nn.Module): 67 | def __init__(self, out_channels, lr): 68 | super().__init__() 69 | # lr = 1.0 (conv,mod,AdaIN), 0.01 (mapping) 70 | 71 | self.bias = nn.Parameter(torch.zeros(out_channels)) 72 | torch.nn.init.zeros_(self.bias.data) 73 | self.bias_scaler = lr 74 | 75 | def forward(self, x): 76 | oC,*_ = self.bias.shape 77 | shape = (1,oC) if x.ndim==2 else (1,oC,1,1) 78 | y = x + self.bias.view(*shape)*self.bias_scaler 79 | return y 80 | 81 | 82 | # 学習率を調整したFC層 83 | class EqualizedFullyConnect(nn.Module): 84 | def __init__(self, in_dim, out_dim, lr): 85 | super().__init__() 86 | # lr = 0.01 (mapping), 1.0 (mod,AdaIN) 87 | 88 | self.weight = nn.Parameter(torch.randn((out_dim,in_dim))) 89 | torch.nn.init.normal_(self.weight.data, mean=0.0, std=1.0/lr) 90 | self.weight_scaler = 1/(in_dim**0.5)*lr 91 | 92 | def forward(self, x): 93 | # x (N,D) 94 | return F.linear(x, self.weight*self.weight_scaler, None) 95 | 96 | 97 | # 固定ノイズ 98 | class ElementwiseNoise(nn.Module): 99 | def __init__(self, ch, size_hw): 100 | super().__init__() 101 | self.register_buffer("const_noise", torch.randn((1, 1, size_hw, size_hw))) 102 | self.noise_scaler = nn.Parameter(torch.zeros((ch,))) 103 | 104 | def forward(self, x): 105 | N,C,H,W = x.shape 106 | noise = self.const_noise.expand(N,C,H,W) 107 | scaler = self.noise_scaler.view(1,C,1,1) 108 | return x + noise * scaler 109 | 110 | 111 | # ブラー : 解像度を上げる畳み込みの後に使う 112 | class Blur3x3(nn.Module): 113 | def __init__(self): 114 | super().__init__() 115 | f = np.array( [ [1/16, 2/16, 1/16], 116 | [2/16, 4/16, 2/16], 117 | [1/16, 2/16, 1/16]], dtype=np.float32).reshape([1, 1, 3, 3]) 118 | self.filter = torch.from_numpy(f) 119 | 120 | def forward(self, x): 121 | _N,C,_H,_W = x.shape 122 | return F.conv2d(x, 
self.filter.to(x.device).expand(C,1,3,3), padding=1, groups=C) 123 | 124 | 125 | # 学習率を調整した転置畳み込み (ブラーのための拡張あり) 126 | class EqualizedFusedConvTransposed2d(nn.Module): 127 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, lr): 128 | super().__init__() 129 | # lr = 1.0 130 | 131 | self.stride, self.padding = stride, padding 132 | 133 | self.weight = nn.Parameter(torch.empty(in_channels, out_channels, kernel_size, kernel_size)) 134 | torch.nn.init.normal_(self.weight.data, mean=0.0, std=1.0/lr) 135 | self.weight_scaler = 1 / ((in_channels * (kernel_size ** 2) )**0.5) * lr 136 | 137 | def forward(self, x): 138 | # 3x3 conv を 4x4 transposed conv として使う 139 | i_ch, o_ch, _kh, _kw = self.weight.shape 140 | # Padding (L,R,T,B) で4x4の四隅に3x3フィルタを寄せて和で合成 141 | weight_4x4 = torch.cat([F.pad(self.weight, pad).view(1,i_ch,o_ch,4,4) 142 | for pad in [(0,1,0,1),(1,0,0,1),(0,1,1,0),(1,0,1,0)]]).sum(dim=0) 143 | return F.conv_transpose2d(x, weight_4x4*self.weight_scaler, stride=2, padding=1) 144 | # 3x3でconvしてからpadで4隅に寄せて計算しても同じでは? 145 | # padding0にしてStyleGAN2のBlurを使っても同じでは? 146 | 147 | 148 | # 学習率を調整した畳込み 149 | class EqualizedConv2d(nn.Module): 150 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, lr): 151 | super().__init__() 152 | # lr = 1.0 153 | 154 | self.stride, self.padding = stride, padding 155 | 156 | self.weight = nn.Parameter(torch.empty(out_channels, in_channels, kernel_size, kernel_size)) 157 | torch.nn.init.normal_(self.weight.data, mean=0.0, std=1.0/lr) 158 | self.weight_scaler = 1 / ((in_channels * (kernel_size ** 2) )**0.5) * lr 159 | 160 | def forward(self, x): 161 | N,C,H,W = x.shape 162 | return F.conv2d(x, self.weight*self.weight_scaler, None, 163 | self.stride, self.padding) 164 | 165 | 166 | # 学習率を調整したAdaIN 167 | class EqualizedAdaIN(nn.Module): 168 | def __init__(self, fmap_ch, style_ch, lr): 169 | super().__init__() 170 | # lr = 1.0 171 | self.fc = EqualizedFullyConnect(style_ch, fmap_ch*2, lr) 172 | self.bias = AddChannelwiseBias(fmap_ch*2,lr) 173 | 174 | def forward(self, pack): 175 | x, style = pack 176 | #N,D = w.shape 177 | N,C,H,W = x.shape 178 | 179 | _vec = self.bias( self.fc(style) ).view(N,2*C,1,1) # (N,2C,1,1) 180 | scale, shift = _vec[:,:C,:,:], _vec[:,C:,:,:] # (N,C,1,1), (N,C,1,1) 181 | return (scale+1) * F.instance_norm(x, eps=1e-8) + shift 182 | 183 | 184 | class Generator(nn.Module): 185 | 186 | structure = { 187 | 'mapping': [['pixel_norm'], ['fc',512,512],['amp'],['b',512],['Lrelu'],['fc',512,512],['amp'],['b',512],['Lrelu'], 188 | ['fc',512,512],['amp'],['b',512],['Lrelu'],['fc',512,512],['amp'],['b',512],['Lrelu'], 189 | ['fc',512,512],['amp'],['b',512],['Lrelu'],['fc',512,512],['amp'],['b',512],['Lrelu'], 190 | ['fc',512,512],['amp'],['b',512],['Lrelu'],['fc',512,512],['amp'],['b',512],['Lrelu'],['truncation']], 191 | 'START' : [ ['noiseE',512, 4], ['bias',512], ['Lrelu'] ], 'adain4a' : [['adain',512]], 192 | 'Fconv4' : [ ['EqConv3x3',512, 512], ['amp'], ['noiseE',512, 4], ['bias',512], ['Lrelu'] ], 'adain4b' : [['adain',512]], 'toRGB_4' : [['EqConv1x1',512, 3], ['bias',3]], 193 | 'Uconv8' : [['up'], ['EqConv3x3',512, 512], ['blur3x3'], ['amp'], ['noiseE',512, 8], ['bias',512], ['Lrelu'] ], 'adain8a' : [['adain',512]], 194 | 'Fconv8' : [ ['EqConv3x3',512, 512], ['amp'], ['noiseE',512, 8], ['bias',512], ['Lrelu'] ], 'adain8b' : [['adain',512]], 'toRGB_8' : [['EqConv1x1',512, 3], ['bias',3]], 195 | 'Uconv16' : [['up'], ['EqConv3x3',512, 512], ['blur3x3'], ['amp'], ['noiseE',512, 16], ['bias',512], ['Lrelu'] 
], 'adain16a' : [['adain',512]], 196 | 'Fconv16' : [ ['EqConv3x3',512, 512], ['amp'], ['noiseE',512, 16], ['bias',512], ['Lrelu'] ], 'adain16b' : [['adain',512]], 'toRGB_16' : [['EqConv1x1',512, 3], ['bias',3]], 197 | 'Uconv32' : [['up'], ['EqConv3x3',512, 512], ['blur3x3'], ['amp'], ['noiseE',512, 32], ['bias',512], ['Lrelu'] ], 'adain32a' : [['adain',512]], 198 | 'Fconv32' : [ ['EqConv3x3',512, 512], ['amp'], ['noiseE',512, 32], ['bias',512], ['Lrelu'] ], 'adain32b' : [['adain',512]], 'toRGB_32' : [['EqConv1x1',512, 3], ['bias',3]], 199 | 'Uconv64' : [['up'], ['EqConv3x3',512, 256], ['blur3x3'], ['amp'], ['noiseE',256, 64], ['bias',256], ['Lrelu'] ], 'adain64a' : [['adain',256]], 200 | 'Fconv64' : [ ['EqConv3x3',256, 256], ['amp'], ['noiseE',256, 64], ['bias',256], ['Lrelu'] ], 'adain64b' : [['adain',256]], 'toRGB_64' : [['EqConv1x1',256, 3], ['bias',3]], 201 | 'Uconv128' : [ ['EqConvT3x3EX',256, 128], ['blur3x3'], ['amp'], ['noiseE',128, 128], ['bias',128], ['Lrelu'] ], 'adain128a' : [['adain',128]], 202 | 'Fconv128' : [ ['EqConv3x3',128, 128], ['amp'], ['noiseE',128, 128], ['bias',128], ['Lrelu'] ], 'adain128b' : [['adain',128]], 'toRGB_128' : [['EqConv1x1',128, 3], ['bias',3]], 203 | 'Uconv256' : [ ['EqConvT3x3EX',128, 64], ['blur3x3'], ['amp'], ['noiseE', 64, 256], ['bias', 64], ['Lrelu'] ], 'adain256a' : [['adain', 64]], 204 | 'Fconv256' : [ ['EqConv3x3', 64, 64], ['amp'], ['noiseE', 64, 256], ['bias', 64], ['Lrelu'] ], 'adain256b' : [['adain', 64]], 'toRGB_256' : [['EqConv1x1', 64, 3], ['bias',3]], 205 | 'Uconv512' : [ ['EqConvT3x3EX', 64, 32], ['blur3x3'], ['amp'], ['noiseE', 32, 512], ['bias', 32], ['Lrelu'] ], 'adain512a' : [['adain', 32]], 206 | 'Fconv512' : [ ['EqConv3x3', 32, 32], ['amp'], ['noiseE', 32, 512], ['bias', 32], ['Lrelu'] ], 'adain512b' : [['adain', 32]], 'toRGB_512' : [['EqConv1x1', 32, 3], ['bias',3]], 207 | 'Uconv1024': [ ['EqConvT3x3EX', 32, 16], ['blur3x3'], ['amp'], ['noiseE', 16, 1024], ['bias', 16], ['Lrelu'] ], 'adain1024a': [['adain', 16]], 208 | 'Fconv1024': [ ['EqConv3x3', 16, 16], ['amp'], ['noiseE', 16, 1024], ['bias', 16], ['Lrelu'] ], 'adain1024b': [['adain', 16]], 'toRGB_1024': [['EqConv1x1', 16, 3], ['bias',3]], 209 | } 210 | 211 | def _make_sequential(self,key): 212 | definition = { 213 | 'pixel_norm' : lambda *config: PixelwiseNormalization(), 214 | 'truncation' : lambda *config: TruncationTrick( 215 | num_target=8, threshold=0.7, output_num=18, style_dim=512 ), 216 | 'fc' : lambda *config: EqualizedFullyConnect( 217 | in_dim=config[0],out_dim=config[1], lr=0.01), 218 | 'b' : lambda *config: AddChannelwiseBias(out_channels=config[0], lr=0.01), 219 | 'bias' : lambda *config: AddChannelwiseBias(out_channels=config[0], lr=1.0), 220 | 'amp' : lambda *config: Amplify(2**0.5), 221 | 'Lrelu' : lambda *config: nn.LeakyReLU(negative_slope=0.2), 222 | 'EqConvT3x3EX' : lambda *config: EqualizedFusedConvTransposed2d( 223 | in_channels=config[0], out_channels=config[1], 224 | kernel_size=3, stride=1, padding=1, lr=1.0), 225 | 'EqConv3x3' : lambda *config: EqualizedConv2d( 226 | in_channels=config[0], out_channels=config[1], 227 | kernel_size=3, stride=1, padding=1, lr=1.0), 228 | 'EqConv1x1' : lambda *config: EqualizedConv2d( 229 | in_channels=config[0], out_channels=config[1], 230 | kernel_size=1, stride=1, padding=0, lr=1.0), 231 | 'noiseE' : lambda *config: ElementwiseNoise(ch=config[0], size_hw=config[1]), 232 | 'blur3x3' : lambda *config: Blur3x3(), 233 | 'up' : lambda *config: nn.Upsample( 234 | scale_factor=2,mode='nearest'), 235 | 'adain' : lambda 
*config: EqualizedAdaIN( 236 | fmap_ch=config[0], style_ch=512, lr=1.0), 237 | } 238 | return nn.Sequential(*[ definition[k](*cfg) for k,*cfg in self.structure[key]]) 239 | 240 | def __init__(self): 241 | super().__init__() 242 | 243 | # 固定入力値 244 | self.register_buffer('const',torch.ones((1, 512, 4, 4),dtype=torch.float32)) 245 | 246 | # 今回は使わない 247 | self.register_buffer('image_mixing_rate',torch.zeros((1,))) # 複数のtoRGBの合成比率 248 | self.register_buffer('style_mixing_rate',torch.zeros((1,))) # スタイルの合成比率 249 | 250 | # 潜在変数のマッピングネットワーク 251 | self.mapping = self._make_sequential('mapping') 252 | self.blocks = nn.ModuleList([self._make_sequential(k) for k in [ 253 | 'START', 'Fconv4', 'Uconv8', 'Fconv8', 'Uconv16', 'Fconv16', 254 | 'Uconv32', 'Fconv32', 'Uconv64', 'Fconv64', 'Uconv128', 'Fconv128', 255 | 'Uconv256', 'Fconv256', 'Uconv512', 'Fconv512', 'Uconv1024','Fconv1024' 256 | ] ]) 257 | self.adains = nn.ModuleList([self._make_sequential(k) for k in [ 258 | 'adain4a', 'adain4b', 'adain8a', 'adain8b', 259 | 'adain16a', 'adain16b', 'adain32a', 'adain32b', 260 | 'adain64a', 'adain64b', 'adain128a', 'adain128b', 261 | 'adain256a', 'adain256b', 'adain512a', 'adain512b', 262 | 'adain1024a', 'adain1024b' 263 | ] ]) 264 | self.toRGBs = nn.ModuleList([self._make_sequential(k) for k in [ 265 | 'toRGB_4', 'toRGB_8', 'toRGB_16', 'toRGB_32', 266 | 'toRGB_64', 'toRGB_128', 'toRGB_256', 'toRGB_512', 267 | 'toRGB_1024' 268 | ] ]) 269 | 270 | def forward(self, z): 271 | ''' 272 | input: z : (N,D) D=512 273 | output: img : (N,3,1024,1024) 274 | ''' 275 | N,D = z.shape 276 | 277 | styles = self.mapping(z) # (N,18,D) 278 | tmp = self.const.expand(N,512,4,4) 279 | for i, (adain, conv) in enumerate(zip(self.adains, self.blocks)): 280 | tmp = conv(tmp) 281 | tmp = adain( (tmp, styles[:,i,:]) ) 282 | img = self.toRGBs[-1](tmp) 283 | 284 | return img 285 | 286 | 287 | ########## 以下,重み変換 ######## 288 | 289 | name_trans_dict = { 290 | 'const' : ['any', 'G_synthesis/4x4/Const/const' ], 291 | 'image_mixing_rate' : ['uns', 'G_synthesis/lod' ], 292 | 'style_mixing_rate' : ['uns', 'lod' ], 293 | 'mapping.1.weight' : ['fc_', 'G_mapping/Dense0/weight' ], 294 | 'mapping.3.bias' : ['any', 'G_mapping/Dense0/bias' ], 295 | 'mapping.5.weight' : ['fc_', 'G_mapping/Dense1/weight' ], 296 | 'mapping.7.bias' : ['any', 'G_mapping/Dense1/bias' ], 297 | 'mapping.9.weight' : ['fc_', 'G_mapping/Dense2/weight' ], 298 | 'mapping.11.bias' : ['any', 'G_mapping/Dense2/bias' ], 299 | 'mapping.13.weight' : ['fc_', 'G_mapping/Dense3/weight' ], 300 | 'mapping.15.bias' : ['any', 'G_mapping/Dense3/bias' ], 301 | 'mapping.17.weight' : ['fc_', 'G_mapping/Dense4/weight' ], 302 | 'mapping.19.bias' : ['any', 'G_mapping/Dense4/bias' ], 303 | 'mapping.21.weight' : ['fc_', 'G_mapping/Dense5/weight' ], 304 | 'mapping.23.bias' : ['any', 'G_mapping/Dense5/bias' ], 305 | 'mapping.25.weight' : ['fc_', 'G_mapping/Dense6/weight' ], 306 | 'mapping.27.bias' : ['any', 'G_mapping/Dense6/bias' ], 307 | 'mapping.29.weight' : ['fc_', 'G_mapping/Dense7/weight' ], 308 | 'mapping.31.bias' : ['any', 'G_mapping/Dense7/bias' ], 309 | 'mapping.33.avg_style' : ['any', 'dlatent_avg' ], 310 | 'blocks.0.0.noise_scaler' : ['any', 'G_synthesis/4x4/Const/Noise/weight' ], 311 | 'blocks.0.0.const_noise' : ['any', 'G_synthesis/noise0' ], 312 | 'blocks.0.1.bias' : ['any', 'G_synthesis/4x4/Const/bias' ], 313 | 'blocks.1.0.weight' : ['con', 'G_synthesis/4x4/Conv/weight' ], 314 | 'blocks.1.2.noise_scaler' : ['any', 'G_synthesis/4x4/Conv/Noise/weight' ], 315 | 'blocks.1.2.const_noise' : 
['any', 'G_synthesis/noise1' ], 316 | 'blocks.1.3.bias' : ['any', 'G_synthesis/4x4/Conv/bias' ], 317 | 'blocks.2.1.weight' : ['con', 'G_synthesis/8x8/Conv0_up/weight' ], 318 | 'blocks.2.4.noise_scaler' : ['any', 'G_synthesis/8x8/Conv0_up/Noise/weight' ], 319 | 'blocks.2.4.const_noise' : ['any', 'G_synthesis/noise2' ], 320 | 'blocks.2.5.bias' : ['any', 'G_synthesis/8x8/Conv0_up/bias' ], 321 | 'blocks.3.0.weight' : ['con', 'G_synthesis/8x8/Conv1/weight' ], 322 | 'blocks.3.2.noise_scaler' : ['any', 'G_synthesis/8x8/Conv1/Noise/weight' ], 323 | 'blocks.3.2.const_noise' : ['any', 'G_synthesis/noise3' ], 324 | 'blocks.3.3.bias' : ['any', 'G_synthesis/8x8/Conv1/bias' ], 325 | 'blocks.4.1.weight' : ['con', 'G_synthesis/16x16/Conv0_up/weight' ], 326 | 'blocks.4.4.noise_scaler' : ['any', 'G_synthesis/16x16/Conv0_up/Noise/weight' ], 327 | 'blocks.4.4.const_noise' : ['any', 'G_synthesis/noise4' ], 328 | 'blocks.4.5.bias' : ['any', 'G_synthesis/16x16/Conv0_up/bias' ], 329 | 'blocks.5.0.weight' : ['con', 'G_synthesis/16x16/Conv1/weight' ], 330 | 'blocks.5.2.noise_scaler' : ['any', 'G_synthesis/16x16/Conv1/Noise/weight' ], 331 | 'blocks.5.2.const_noise' : ['any', 'G_synthesis/noise5' ], 332 | 'blocks.5.3.bias' : ['any', 'G_synthesis/16x16/Conv1/bias' ], 333 | 'blocks.6.1.weight' : ['con', 'G_synthesis/32x32/Conv0_up/weight' ], 334 | 'blocks.6.4.noise_scaler' : ['any', 'G_synthesis/32x32/Conv0_up/Noise/weight' ], 335 | 'blocks.6.4.const_noise' : ['any', 'G_synthesis/noise6' ], 336 | 'blocks.6.5.bias' : ['any', 'G_synthesis/32x32/Conv0_up/bias' ], 337 | 'blocks.7.0.weight' : ['con', 'G_synthesis/32x32/Conv1/weight' ], 338 | 'blocks.7.2.noise_scaler' : ['any', 'G_synthesis/32x32/Conv1/Noise/weight' ], 339 | 'blocks.7.2.const_noise' : ['any', 'G_synthesis/noise7' ], 340 | 'blocks.7.3.bias' : ['any', 'G_synthesis/32x32/Conv1/bias' ], 341 | 'blocks.8.1.weight' : ['con', 'G_synthesis/64x64/Conv0_up/weight' ], 342 | 'blocks.8.4.noise_scaler' : ['any', 'G_synthesis/64x64/Conv0_up/Noise/weight' ], 343 | 'blocks.8.4.const_noise' : ['any', 'G_synthesis/noise8' ], 344 | 'blocks.8.5.bias' : ['any', 'G_synthesis/64x64/Conv0_up/bias' ], 345 | 'blocks.9.0.weight' : ['con', 'G_synthesis/64x64/Conv1/weight' ], 346 | 'blocks.9.2.noise_scaler' : ['any', 'G_synthesis/64x64/Conv1/Noise/weight' ], 347 | 'blocks.9.2.const_noise' : ['any', 'G_synthesis/noise9' ], 348 | 'blocks.9.3.bias' : ['any', 'G_synthesis/64x64/Conv1/bias' ], 349 | 'blocks.10.0.weight' : ['Tco', 'G_synthesis/128x128/Conv0_up/weight' ], 350 | 'blocks.10.3.noise_scaler' : ['any', 'G_synthesis/128x128/Conv0_up/Noise/weight' ], 351 | 'blocks.10.3.const_noise' : ['any', 'G_synthesis/noise10' ], 352 | 'blocks.10.4.bias' : ['any', 'G_synthesis/128x128/Conv0_up/bias' ], 353 | 'blocks.11.0.weight' : ['con', 'G_synthesis/128x128/Conv1/weight' ], 354 | 'blocks.11.2.noise_scaler' : ['any', 'G_synthesis/128x128/Conv1/Noise/weight' ], 355 | 'blocks.11.2.const_noise' : ['any', 'G_synthesis/noise11' ], 356 | 'blocks.11.3.bias' : ['any', 'G_synthesis/128x128/Conv1/bias' ], 357 | 'blocks.12.0.weight' : ['Tco', 'G_synthesis/256x256/Conv0_up/weight' ], 358 | 'blocks.12.3.noise_scaler' : ['any', 'G_synthesis/256x256/Conv0_up/Noise/weight' ], 359 | 'blocks.12.3.const_noise' : ['any', 'G_synthesis/noise12' ], 360 | 'blocks.12.4.bias' : ['any', 'G_synthesis/256x256/Conv0_up/bias' ], 361 | 'blocks.13.0.weight' : ['con', 'G_synthesis/256x256/Conv1/weight' ], 362 | 'blocks.13.2.noise_scaler' : ['any', 'G_synthesis/256x256/Conv1/Noise/weight' ], 363 | 'blocks.13.2.const_noise' : ['any', 
'G_synthesis/noise13' ], 364 | 'blocks.13.3.bias' : ['any', 'G_synthesis/256x256/Conv1/bias' ], 365 | 'blocks.14.0.weight' : ['Tco', 'G_synthesis/512x512/Conv0_up/weight' ], 366 | 'blocks.14.3.noise_scaler' : ['any', 'G_synthesis/512x512/Conv0_up/Noise/weight' ], 367 | 'blocks.14.3.const_noise' : ['any', 'G_synthesis/noise14' ], 368 | 'blocks.14.4.bias' : ['any', 'G_synthesis/512x512/Conv0_up/bias' ], 369 | 'blocks.15.0.weight' : ['con', 'G_synthesis/512x512/Conv1/weight' ], 370 | 'blocks.15.2.noise_scaler' : ['any', 'G_synthesis/512x512/Conv1/Noise/weight' ], 371 | 'blocks.15.2.const_noise' : ['any', 'G_synthesis/noise15' ], 372 | 'blocks.15.3.bias' : ['any', 'G_synthesis/512x512/Conv1/bias' ], 373 | 'blocks.16.0.weight' : ['Tco', 'G_synthesis/1024x1024/Conv0_up/weight' ], 374 | 'blocks.16.3.noise_scaler' : ['any', 'G_synthesis/1024x1024/Conv0_up/Noise/weight' ], 375 | 'blocks.16.3.const_noise' : ['any', 'G_synthesis/noise16' ], 376 | 'blocks.16.4.bias' : ['any', 'G_synthesis/1024x1024/Conv0_up/bias' ], 377 | 'blocks.17.0.weight' : ['con', 'G_synthesis/1024x1024/Conv1/weight' ], 378 | 'blocks.17.2.noise_scaler' : ['any', 'G_synthesis/1024x1024/Conv1/Noise/weight' ], 379 | 'blocks.17.2.const_noise' : ['any', 'G_synthesis/noise17' ], 380 | 'blocks.17.3.bias' : ['any', 'G_synthesis/1024x1024/Conv1/bias' ], 381 | 'adains.0.0.fc.weight' : ['fc_', 'G_synthesis/4x4/Const/StyleMod/weight' ], 382 | 'adains.0.0.bias.bias' : ['any', 'G_synthesis/4x4/Const/StyleMod/bias' ], 383 | 'adains.1.0.fc.weight' : ['fc_', 'G_synthesis/4x4/Conv/StyleMod/weight' ], 384 | 'adains.1.0.bias.bias' : ['any', 'G_synthesis/4x4/Conv/StyleMod/bias' ], 385 | 'adains.2.0.fc.weight' : ['fc_', 'G_synthesis/8x8/Conv0_up/StyleMod/weight' ], 386 | 'adains.2.0.bias.bias' : ['any', 'G_synthesis/8x8/Conv0_up/StyleMod/bias' ], 387 | 'adains.3.0.fc.weight' : ['fc_', 'G_synthesis/8x8/Conv1/StyleMod/weight' ], 388 | 'adains.3.0.bias.bias' : ['any', 'G_synthesis/8x8/Conv1/StyleMod/bias' ], 389 | 'adains.4.0.fc.weight' : ['fc_', 'G_synthesis/16x16/Conv0_up/StyleMod/weight' ], 390 | 'adains.4.0.bias.bias' : ['any', 'G_synthesis/16x16/Conv0_up/StyleMod/bias' ], 391 | 'adains.5.0.fc.weight' : ['fc_', 'G_synthesis/16x16/Conv1/StyleMod/weight' ], 392 | 'adains.5.0.bias.bias' : ['any', 'G_synthesis/16x16/Conv1/StyleMod/bias' ], 393 | 'adains.6.0.fc.weight' : ['fc_', 'G_synthesis/32x32/Conv0_up/StyleMod/weight' ], 394 | 'adains.6.0.bias.bias' : ['any', 'G_synthesis/32x32/Conv0_up/StyleMod/bias' ], 395 | 'adains.7.0.fc.weight' : ['fc_', 'G_synthesis/32x32/Conv1/StyleMod/weight' ], 396 | 'adains.7.0.bias.bias' : ['any', 'G_synthesis/32x32/Conv1/StyleMod/bias' ], 397 | 'adains.8.0.fc.weight' : ['fc_', 'G_synthesis/64x64/Conv0_up/StyleMod/weight' ], 398 | 'adains.8.0.bias.bias' : ['any', 'G_synthesis/64x64/Conv0_up/StyleMod/bias' ], 399 | 'adains.9.0.fc.weight' : ['fc_', 'G_synthesis/64x64/Conv1/StyleMod/weight' ], 400 | 'adains.9.0.bias.bias' : ['any', 'G_synthesis/64x64/Conv1/StyleMod/bias' ], 401 | 'adains.10.0.fc.weight' : ['fc_', 'G_synthesis/128x128/Conv0_up/StyleMod/weight' ], 402 | 'adains.10.0.bias.bias' : ['any', 'G_synthesis/128x128/Conv0_up/StyleMod/bias' ], 403 | 'adains.11.0.fc.weight' : ['fc_', 'G_synthesis/128x128/Conv1/StyleMod/weight' ], 404 | 'adains.11.0.bias.bias' : ['any', 'G_synthesis/128x128/Conv1/StyleMod/bias' ], 405 | 'adains.12.0.fc.weight' : ['fc_', 'G_synthesis/256x256/Conv0_up/StyleMod/weight' ], 406 | 'adains.12.0.bias.bias' : ['any', 'G_synthesis/256x256/Conv0_up/StyleMod/bias' ], 407 | 'adains.13.0.fc.weight' : 
['fc_', 'G_synthesis/256x256/Conv1/StyleMod/weight' ], 408 | 'adains.13.0.bias.bias' : ['any', 'G_synthesis/256x256/Conv1/StyleMod/bias' ], 409 | 'adains.14.0.fc.weight' : ['fc_', 'G_synthesis/512x512/Conv0_up/StyleMod/weight' ], 410 | 'adains.14.0.bias.bias' : ['any', 'G_synthesis/512x512/Conv0_up/StyleMod/bias' ], 411 | 'adains.15.0.fc.weight' : ['fc_', 'G_synthesis/512x512/Conv1/StyleMod/weight' ], 412 | 'adains.15.0.bias.bias' : ['any', 'G_synthesis/512x512/Conv1/StyleMod/bias' ], 413 | 'adains.16.0.fc.weight' : ['fc_', 'G_synthesis/1024x1024/Conv0_up/StyleMod/weight' ], 414 | 'adains.16.0.bias.bias' : ['any', 'G_synthesis/1024x1024/Conv0_up/StyleMod/bias' ], 415 | 'adains.17.0.fc.weight' : ['fc_', 'G_synthesis/1024x1024/Conv1/StyleMod/weight' ], 416 | 'adains.17.0.bias.bias' : ['any', 'G_synthesis/1024x1024/Conv1/StyleMod/bias' ], 417 | 'toRGBs.0.0.weight' : ['con', 'G_synthesis/ToRGB_lod8/weight' ], 418 | 'toRGBs.0.1.bias' : ['any', 'G_synthesis/ToRGB_lod8/bias' ], 419 | 'toRGBs.1.0.weight' : ['con', 'G_synthesis/ToRGB_lod7/weight' ], 420 | 'toRGBs.1.1.bias' : ['any', 'G_synthesis/ToRGB_lod7/bias' ], 421 | 'toRGBs.2.0.weight' : ['con', 'G_synthesis/ToRGB_lod6/weight' ], 422 | 'toRGBs.2.1.bias' : ['any', 'G_synthesis/ToRGB_lod6/bias' ], 423 | 'toRGBs.3.0.weight' : ['con', 'G_synthesis/ToRGB_lod5/weight' ], 424 | 'toRGBs.3.1.bias' : ['any', 'G_synthesis/ToRGB_lod5/bias' ], 425 | 'toRGBs.4.0.weight' : ['con', 'G_synthesis/ToRGB_lod4/weight' ], 426 | 'toRGBs.4.1.bias' : ['any', 'G_synthesis/ToRGB_lod4/bias' ], 427 | 'toRGBs.5.0.weight' : ['con', 'G_synthesis/ToRGB_lod3/weight' ], 428 | 'toRGBs.5.1.bias' : ['any', 'G_synthesis/ToRGB_lod3/bias' ], 429 | 'toRGBs.6.0.weight' : ['con', 'G_synthesis/ToRGB_lod2/weight' ], 430 | 'toRGBs.6.1.bias' : ['any', 'G_synthesis/ToRGB_lod2/bias' ], 431 | 'toRGBs.7.0.weight' : ['con', 'G_synthesis/ToRGB_lod1/weight' ], 432 | 'toRGBs.7.1.bias' : ['any', 'G_synthesis/ToRGB_lod1/bias' ], 433 | 'toRGBs.8.0.weight' : ['con', 'G_synthesis/ToRGB_lod0/weight' ], 434 | 'toRGBs.8.1.bias' : ['any', 'G_synthesis/ToRGB_lod0/bias' ], 435 | } 436 | 437 | 438 | 439 | 440 | 441 | 442 | 443 | 444 | 445 | 446 | 447 | 448 | 449 | 450 | 451 | 452 | 453 | 454 | 455 | 456 | 457 | 458 | 459 | 460 | 461 | 462 | 463 | # 変換関数 464 | ops_dict = { 465 | # 変調転置畳み込みの重み (iC,oC,kH,kW) 466 | 'mTc' : lambda weight: torch.flip(torch.from_numpy(weight.transpose((2,3,0,1))), [2, 3]), 467 | # 転置畳み込みの重み (iC,oC,kH,kW) 468 | 'Tco' : lambda weight: torch.from_numpy(weight.transpose((2,3,0,1))), 469 | # 畳み込みの重み (oC,iC,kH,kW) 470 | 'con' : lambda weight: torch.from_numpy(weight.transpose((3,2,0,1))), 471 | # 全結合層の重み (oD, iD) 472 | 'fc_' : lambda weight: torch.from_numpy(weight.transpose((1, 0))), 473 | # 全結合層のバイアス項, 固定入力, 固定ノイズ, v1ノイズの重み (無変換) 474 | 'any' : lambda weight: torch.from_numpy(weight), 475 | # Style-Mixingの値, v2ノイズの重み (scalar) 476 | 'uns' : lambda weight: torch.from_numpy(np.array(weight).reshape(1)), 477 | } 478 | 479 | 480 | if __name__ == '__main__': 481 | # コマンドライン引数の取得 482 | args = parse_args() 483 | 484 | cfg = { 485 | 'src_weight': 'stylegan1_ndarray.pkl', 486 | 'src_latent': 'latents1.pkl', 487 | 'dst_image' : 'stylegan1_pt.png', 488 | 'dst_weight': 'stylegan1_state_dict.pth' 489 | } 490 | 491 | print('model construction...') 492 | generator = Generator() 493 | base_dict = generator.state_dict() 494 | 495 | print('model weights load...') 496 | with (Path(args.weight_dir)/cfg['src_weight']).open('rb') as f: 497 | src_dict = pickle.load(f) 498 | 499 | print('set state_dict...') 500 | 
new_dict = { k : ops_dict[v[0]](src_dict[v[1]]) for k,v in name_trans_dict.items()} 501 | generator.load_state_dict(new_dict) 502 | 503 | print('load latents...') 504 | with (Path(args.output_dir)/cfg['src_latent']).open('rb') as f: 505 | latents = pickle.load(f) 506 | latents = torch.from_numpy(latents.astype(np.float32)) 507 | 508 | print('network forward...') 509 | device = torch.device('cuda') if torch.cuda.is_available() and args.device=='gpu' else torch.device('cpu') 510 | with torch.no_grad(): 511 | N,_ = latents.shape 512 | generator.to(device) 513 | images = np.empty((N,args.resolution,args.resolution,3),dtype=np.uint8) 514 | 515 | for i in range(0,N,args.batch_size): 516 | j = min(i+args.batch_size,N) 517 | z = latents[i:j].to(device) 518 | img = generator(z) 519 | normalized = (img.clamp(-1,1)+1)/2*255 520 | images[i:j] = normalized.permute(0,2,3,1).cpu().numpy().astype(np.uint8) 521 | del z, img, normalized 522 | 523 | # 出力を並べる関数 524 | def make_table(imgs): 525 | # 出力する個数,解像度 526 | num_H, num_W = 4,4 527 | H = W = args.resolution 528 | num_images = num_H*num_W 529 | 530 | canvas = np.zeros((H*num_H,W*num_W,3),dtype=np.uint8) 531 | for i,p in enumerate(imgs[:num_images]): 532 | h,w = i//num_W, i%num_W 533 | canvas[H*h:H*-~h,W*w:W*-~w,:] = p[:,:,::-1] 534 | return canvas 535 | 536 | print('image output...') 537 | cv2.imwrite(str(Path(args.output_dir)/cfg['dst_image']), make_table(images)) 538 | 539 | print('weight save...') 540 | torch.save(generator.state_dict(),str(Path(args.weight_dir)/cfg['dst_weight'])) 541 | 542 | print('all done') 543 | -------------------------------------------------------------------------------- /packaged/run_pt_stylegan2.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | import pickle 4 | 5 | import numpy as np 6 | import cv2 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | # コマンドライン引数の取得 13 | def parse_args(): 14 | parser = argparse.ArgumentParser(description='著者実装を動かしたり重みを抜き出したり') 15 | parser.add_argument('-w','--weight_dir',type=str,default='/tmp/stylegans-pytorch', 16 | help='学習済みのモデルを保存する場所') 17 | parser.add_argument('-o','--output_dir',type=str,default='/tmp/stylegans-pytorch', 18 | help='生成された画像を保存する場所') 19 | parser.add_argument('--batch_size',type=int,default=1, 20 | help='バッチサイズ') 21 | parser.add_argument('--device',type=str,default='gpu',choices=['gpu','cpu'], 22 | help='デバイス') 23 | args = parser.parse_args() 24 | args.resolution = 1024 25 | return args 26 | 27 | 28 | # mapping前に潜在変数を超球面上に正規化 29 | class PixelwiseNormalization(nn.Module): 30 | def __init__(self): 31 | super().__init__() 32 | def forward(self, x): 33 | return x / torch.sqrt((x**2).mean(1,keepdim=True) + 1e-8) 34 | 35 | 36 | # 移動平均を用いて潜在変数を正規化する. 
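# (参考) 直後の TruncationTrick が行う計算の要約:
#   w' = avg_style + (w - avg_style) * rate
#   rate は先頭 num_target 層では threshold (=0.7), それ以降の層では 1.0
#   つまり浅い層のスタイルのみを平均潜在変数 avg_style に近づけ, 多様性と引き換えに生成の破綻を抑える.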
37 | class TruncationTrick(nn.Module): 38 | def __init__(self, num_target, threshold, output_num, style_dim): 39 | super().__init__() 40 | self.num_target = num_target 41 | self.threshold = threshold 42 | self.output_num = output_num 43 | self.register_buffer('avg_style', torch.zeros((style_dim,))) 44 | 45 | def forward(self, x): 46 | # in:(N,D) -> out:(N,O,D) 47 | N,D = x.shape 48 | O = self.output_num 49 | x = x.view(N,1,D).expand(N,O,D) 50 | rate = torch.cat([ torch.ones((N, self.num_target, D)) *self.threshold, 51 | torch.ones((N, O-self.num_target, D)) *1.0 ],1).to(x.device) 52 | avg = self.avg_style.view(1,1,D).expand(N,O,D) 53 | return avg + (x-avg)*rate 54 | 55 | 56 | # 特徴マップ信号を増幅する 57 | class Amplify(nn.Module): 58 | def __init__(self, rate): 59 | super().__init__() 60 | self.rate = rate 61 | def forward(self,x): 62 | return x * self.rate 63 | 64 | 65 | # チャンネルごとにバイアス項を足す 66 | class AddChannelwiseBias(nn.Module): 67 | def __init__(self, out_channels, lr): 68 | super().__init__() 69 | # lr = 1.0 (conv,mod,AdaIN), 0.01 (mapping) 70 | 71 | self.bias = nn.Parameter(torch.zeros(out_channels)) 72 | torch.nn.init.zeros_(self.bias.data) 73 | self.bias_scaler = lr 74 | 75 | def forward(self, x): 76 | oC,*_ = self.bias.shape 77 | shape = (1,oC) if x.ndim==2 else (1,oC,1,1) 78 | y = x + self.bias.view(*shape)*self.bias_scaler 79 | return y 80 | 81 | 82 | # 学習率を調整したFC層 83 | class EqualizedFullyConnect(nn.Module): 84 | def __init__(self, in_dim, out_dim, lr): 85 | super().__init__() 86 | # lr = 0.01 (mapping), 1.0 (mod,AdaIN) 87 | 88 | self.weight = nn.Parameter(torch.randn((out_dim,in_dim))) 89 | torch.nn.init.normal_(self.weight.data, mean=0.0, std=1.0/lr) 90 | self.weight_scaler = 1/(in_dim**0.5)*lr 91 | 92 | def forward(self, x): 93 | # x (N,D) 94 | return F.linear(x, self.weight*self.weight_scaler, None) 95 | 96 | 97 | # 固定ノイズ 98 | class PixelwiseNoise(nn.Module): 99 | def __init__(self, resolution): 100 | super().__init__() 101 | self.register_buffer("const_noise", torch.randn((1, 1, resolution, resolution))) 102 | self.noise_scaler = nn.Parameter(torch.zeros(1)) 103 | 104 | def forward(self, x): 105 | N,C,H,W = x.shape 106 | noise = self.const_noise.expand(N,C,H,W) 107 | return x + noise * self.noise_scaler 108 | 109 | 110 | # 解像度上げるときのModConvのみ 111 | class FusedBlur3x3(nn.Module): 112 | def __init__(self): 113 | super().__init__() 114 | kernel = np.array([ [1/16, 2/16, 1/16], 115 | [2/16, 4/16, 2/16], 116 | [1/16, 2/16, 1/16]],dtype=np.float32) 117 | pads = [[(0,1),(0,1)],[(0,1),(1,0)],[(1,0),(0,1)],[(1,0),(1,0)]] 118 | kernel = np.stack( [np.pad(kernel,pad,'constant') for pad in pads] ).sum(0) 119 | #kernel [ [1/16, 3/16, 3/16, 1/16,], 120 | # [3/16, 9/16, 9/16, 3/16,], 121 | # [3/16, 9/16, 9/16, 3/16,], 122 | # [1/16, 3/16, 3/16, 1/16,] ] 123 | self.kernel = torch.from_numpy(kernel) 124 | 125 | def forward(self, feature): 126 | # featureは(N,C,H+1,W+1) 127 | kernel = self.kernel.clone().to(feature.device) 128 | _N,C,_Hp1,_Wp1 = feature.shape 129 | return F.conv2d(feature, kernel.expand(C,1,4,4), padding=1, groups=C) 130 | 131 | 132 | # 学習率を調整した変調転置畳み込み 133 | class EqualizedModulatedConvTranspose2d(nn.Module): 134 | def __init__(self, in_channels, out_channels, kernel_size, style_dim, padding, stride, demodulate=True, lr=1): 135 | super().__init__() 136 | 137 | self.padding, self.stride = padding, stride 138 | self.demodulate = demodulate 139 | 140 | self.weight = nn.Parameter( torch.randn(in_channels, out_channels, kernel_size, kernel_size)) 141 | 
torch.nn.init.normal_(self.weight.data, mean=0.0, std=1.0/lr) 142 | self.weight_scaler = 1 / (in_channels * kernel_size*kernel_size)**0.5 * lr 143 | 144 | self.fc = EqualizedFullyConnect(style_dim, in_channels, lr) 145 | self.bias = AddChannelwiseBias(in_channels, lr) 146 | 147 | def forward(self, pack): 148 | x, style = pack 149 | N, iC, H, W = x.shape 150 | iC, oC, kH, kW = self.weight.shape 151 | 152 | mod_rates = self.bias(self.fc(style))+1 # (N, iC) 153 | modulated_weight = self.weight_scaler*self.weight.view(1,iC,oC,kH,kW) \ 154 | * mod_rates.view(N,iC,1,1,1) # (N,iC,oC,kH,kW) 155 | 156 | if self.demodulate: 157 | demod_norm = 1 / ((modulated_weight**2).sum([1,3,4]) + 1e-8)**0.5 # (N, oC) 158 | weight = modulated_weight * demod_norm.view(N, 1, oC, 1, 1) # (N,iC,oC,kH,kW) 159 | else: 160 | weight = modulated_weight 161 | 162 | x = x.view(1, N*iC, H, W) 163 | weight = weight.view(N*iC,oC,kH,kW) 164 | out = F.conv_transpose2d(x, weight, padding=self.padding, stride=self.stride, groups=N) 165 | 166 | _, _, Hp1, Wp1 = out.shape 167 | out = out.view(N, oC, Hp1, Wp1) 168 | 169 | return out 170 | 171 | 172 | # 学習率を調整した変調畳み込み 173 | class EqualizedModulatedConv2d(nn.Module): 174 | def __init__(self, in_channels, out_channels, kernel_size, style_dim, padding, stride, demodulate=True, lr=1): 175 | super().__init__() 176 | 177 | self.padding, self.stride = padding, stride 178 | self.demodulate = demodulate 179 | 180 | self.weight = nn.Parameter( torch.randn(out_channels, in_channels, kernel_size, kernel_size)) 181 | torch.nn.init.normal_(self.weight.data, mean=0.0, std=1.0/lr) 182 | self.weight_scaler = 1 / (in_channels*kernel_size*kernel_size)**0.5 * lr 183 | 184 | self.fc = EqualizedFullyConnect(style_dim, in_channels, lr) 185 | self.bias = AddChannelwiseBias(in_channels, lr) 186 | 187 | def forward(self, pack): 188 | x, style = pack 189 | N, iC, H, W = x.shape 190 | oC, iC, kH, kW = self.weight.shape 191 | 192 | mod_rates = self.bias(self.fc(style))+1 # (N, iC) 193 | modulated_weight = self.weight_scaler*self.weight.view(1,oC,iC,kH,kW) \ 194 | * mod_rates.view(N,1,iC,1,1) # (N,oC,iC,kH,kW) 195 | 196 | if self.demodulate: 197 | demod_norm = 1 / ((modulated_weight**2).sum([2,3,4]) + 1e-8)**0.5 # (N, oC) 198 | weight = modulated_weight * demod_norm.view(N, oC, 1, 1, 1) # (N,oC,iC,kH,kW) 199 | else: # ToRGB 200 | weight = modulated_weight 201 | 202 | out = F.conv2d(x.view(1,N*iC,H,W), weight.view(N*oC,iC,kH,kW), 203 | padding=self.padding, stride=self.stride, groups=N).view(N,oC,H,W) 204 | return out 205 | 206 | 207 | class Generator(nn.Module): 208 | 209 | structure = { 210 | 'mapping': [['pixel_norm'], ['fc',512,512],['b',512],['amp'],['Lrelu'],['fc',512,512],['b',512],['amp'],['Lrelu'], 211 | ['fc',512,512],['b',512],['amp'],['Lrelu'],['fc',512,512],['b',512],['amp'],['Lrelu'], 212 | ['fc',512,512],['b',512],['amp'],['Lrelu'],['fc',512,512],['b',512],['amp'],['Lrelu'], 213 | ['fc',512,512],['b',512],['amp'],['Lrelu'],['fc',512,512],['b',512],['amp'],['Lrelu'],['truncation']], 214 | 'Fconv4' : [['EqModConv3x3', 512, 512], ['noiseP', 4], ['bias',512], ['amp'], ['Lrelu'] ], 'toRGB_4' : [['EqModConv1x1',512, 3], ['bias',3]], 215 | 'Uconv8' : [['EqModConvT3x3', 512, 512], ['blurEX'], ['noiseP', 8], ['bias',512], ['amp'], ['Lrelu'] ], 216 | 'Fconv8' : [['EqModConv3x3', 512, 512], ['noiseP', 8], ['bias',512], ['amp'], ['Lrelu'] ], 'toRGB_8' : [['EqModConv1x1',512, 3], ['bias',3]], 217 | 'Uconv16' : [['EqModConvT3x3', 512, 512], ['blurEX'], ['noiseP', 16], ['bias',512], ['amp'], ['Lrelu'] ], 218 | 
'Fconv16' : [['EqModConv3x3', 512, 512], ['noiseP', 16], ['bias',512], ['amp'], ['Lrelu'] ], 'toRGB_16' : [['EqModConv1x1',512, 3], ['bias',3]], 219 | 'Uconv32' : [['EqModConvT3x3', 512, 512], ['blurEX'], ['noiseP', 32], ['bias',512], ['amp'], ['Lrelu'] ], 220 | 'Fconv32' : [['EqModConv3x3', 512, 512], ['noiseP', 32], ['bias',512], ['amp'], ['Lrelu'] ], 'toRGB_32' : [['EqModConv1x1',512, 3], ['bias',3]], 221 | 'Uconv64' : [['EqModConvT3x3', 512, 512], ['blurEX'], ['noiseP', 64], ['bias',512], ['amp'], ['Lrelu'] ], 222 | 'Fconv64' : [['EqModConv3x3', 512, 512], ['noiseP', 64], ['bias',512], ['amp'], ['Lrelu'] ], 'toRGB_64' : [['EqModConv1x1',512, 3], ['bias',3]], 223 | 'Uconv128' : [['EqModConvT3x3', 512, 256], ['blurEX'], ['noiseP', 128], ['bias',256], ['amp'], ['Lrelu'] ], 224 | 'Fconv128' : [['EqModConv3x3', 256, 256], ['noiseP', 128], ['bias',256], ['amp'], ['Lrelu'] ], 'toRGB_128' : [['EqModConv1x1',256, 3], ['bias',3]], 225 | 'Uconv256' : [['EqModConvT3x3', 256, 128], ['blurEX'], ['noiseP', 256], ['bias',128], ['amp'], ['Lrelu'] ], 226 | 'Fconv256' : [['EqModConv3x3', 128, 128], ['noiseP', 256], ['bias',128], ['amp'], ['Lrelu'] ], 'toRGB_256' : [['EqModConv1x1',128, 3], ['bias',3]], 227 | 'Uconv512' : [['EqModConvT3x3', 128, 64], ['blurEX'], ['noiseP', 512], ['bias', 64], ['amp'], ['Lrelu'] ], 228 | 'Fconv512' : [['EqModConv3x3', 64, 64], ['noiseP', 512], ['bias', 64], ['amp'], ['Lrelu'] ], 'toRGB_512' : [['EqModConv1x1', 64, 3], ['bias',3]], 229 | 'Uconv1024': [['EqModConvT3x3', 64, 32], ['blurEX'], ['noiseP',1024], ['bias', 32], ['amp'], ['Lrelu'] ], 230 | 'Fconv1024': [['EqModConv3x3', 32, 32], ['noiseP',1024], ['bias', 32], ['amp'], ['Lrelu'] ], 'toRGB_1024': [['EqModConv1x1', 32, 3], ['bias',3]], 231 | } 232 | 233 | def _make_sequential(self,key): 234 | definition = { 235 | 'pixel_norm' : lambda *config: PixelwiseNormalization(), 236 | 'truncation' : lambda *config: TruncationTrick( 237 | num_target=10, threshold=0.7, output_num=18, style_dim=512), 238 | 'fc' : lambda *config: EqualizedFullyConnect( 239 | in_dim=config[0], out_dim=config[1], lr=0.01), 240 | 'b' : lambda *config: AddChannelwiseBias(out_channels=config[0], lr=0.01), 241 | 'bias' : lambda *config: AddChannelwiseBias(out_channels=config[0], lr=1.0), 242 | 'amp' : lambda *config: Amplify(2**0.5), 243 | 'Lrelu' : lambda *config: nn.LeakyReLU(negative_slope=0.2), 244 | 'EqModConvT3x3': lambda *config: EqualizedModulatedConvTranspose2d( 245 | in_channels=config[0], out_channels=config[1], 246 | kernel_size=3, stride=2, padding=0, 247 | demodulate=True, lr=1.0, style_dim=512), 248 | 'EqModConv3x3' : lambda *config: EqualizedModulatedConv2d( 249 | in_channels=config[0], out_channels=config[1], 250 | kernel_size=3, stride=1, padding=1, 251 | demodulate=True, lr=1.0, style_dim=512), 252 | 'EqModConv1x1' : lambda *config: EqualizedModulatedConv2d( 253 | in_channels=config[0], out_channels=config[1], 254 | kernel_size=1, stride=1, padding=0, 255 | demodulate=False, lr=1.0, style_dim=512), 256 | 'noiseP' : lambda *config: PixelwiseNoise(resolution=config[0]), 257 | 'blurEX' : lambda *config: FusedBlur3x3(), 258 | } 259 | return nn.Sequential(*[ definition[k](*cfg) for k,*cfg in self.structure[key]]) 260 | 261 | 262 | def __init__(self): 263 | super().__init__() 264 | 265 | self.const_input = nn.Parameter(torch.randn(1, 512, 4, 4)) 266 | self.register_buffer('style_mixing_rate',torch.zeros((1,))) # スタイルの合成比率,今回は使わない 267 | 268 | self.mapping = self._make_sequential('mapping') 269 | self.blocks = 
nn.ModuleList([self._make_sequential(k) for k in [ 270 | 'Fconv4', 'Uconv8', 'Fconv8', 'Uconv16', 'Fconv16', 271 | 'Uconv32', 'Fconv32', 'Uconv64', 'Fconv64', 'Uconv128', 'Fconv128', 272 | 'Uconv256', 'Fconv256', 'Uconv512', 'Fconv512', 'Uconv1024','Fconv1024' 273 | ] ]) 274 | self.toRGBs = nn.ModuleList([self._make_sequential(k) for k in [ 275 | 'toRGB_4', 'toRGB_8', 'toRGB_16', 'toRGB_32', 276 | 'toRGB_64', 'toRGB_128', 'toRGB_256', 'toRGB_512', 277 | 'toRGB_1024' 278 | ] ]) 279 | 280 | 281 | def forward(self, z): 282 | N,D = z.shape 283 | 284 | # 潜在変数からスタイルへ変換 285 | styles = self.mapping(z) # (N,18,D) 286 | styles = [styles[:,i] for i in range(18)] # list[(N,D),]x18 287 | 288 | tmp = self.const_input.repeat(N, 1, 1, 1) 289 | tmp = self.blocks[0]( (tmp,styles[0]) ) 290 | skip = self.toRGBs[0]( (tmp,styles[1]) ) 291 | 292 | for convU, convF, toRGB, styU,styF,styT in zip( \ 293 | self.blocks[1::2], self.blocks[2::2], self.toRGBs[1:], 294 | styles[1::2], styles[2::2], styles[3::2]): 295 | tmp = convU( (tmp,styU) ) 296 | tmp = convF( (tmp,styF) ) 297 | skip = toRGB( (tmp,styT) ) + F.interpolate(skip,scale_factor=2,mode='bilinear',align_corners=False) 298 | 299 | return skip 300 | 301 | 302 | # { pytorchでの名前 : [変換関数, tensorflowでの名前] } 303 | name_trans_dict = { 304 | 'const_input' : ['any', 'G_synthesis/4x4/Const/const' ], 305 | 'style_mixing_rate' : ['uns', 'lod' ], 306 | 'mapping.1.weight' : ['fc_', 'G_mapping/Dense0/weight' ], 307 | 'mapping.2.bias' : ['any', 'G_mapping/Dense0/bias' ], 308 | 'mapping.5.weight' : ['fc_', 'G_mapping/Dense1/weight' ], 309 | 'mapping.6.bias' : ['any', 'G_mapping/Dense1/bias' ], 310 | 'mapping.9.weight' : ['fc_', 'G_mapping/Dense2/weight' ], 311 | 'mapping.10.bias' : ['any', 'G_mapping/Dense2/bias' ], 312 | 'mapping.13.weight' : ['fc_', 'G_mapping/Dense3/weight' ], 313 | 'mapping.14.bias' : ['any', 'G_mapping/Dense3/bias' ], 314 | 'mapping.17.weight' : ['fc_', 'G_mapping/Dense4/weight' ], 315 | 'mapping.18.bias' : ['any', 'G_mapping/Dense4/bias' ], 316 | 'mapping.21.weight' : ['fc_', 'G_mapping/Dense5/weight' ], 317 | 'mapping.22.bias' : ['any', 'G_mapping/Dense5/bias' ], 318 | 'mapping.25.weight' : ['fc_', 'G_mapping/Dense6/weight' ], 319 | 'mapping.26.bias' : ['any', 'G_mapping/Dense6/bias' ], 320 | 'mapping.29.weight' : ['fc_', 'G_mapping/Dense7/weight' ], 321 | 'mapping.30.bias' : ['any', 'G_mapping/Dense7/bias' ], 322 | 'mapping.33.avg_style' : ['any', 'dlatent_avg' ], 323 | 'blocks.0.0.weight' : ['con', 'G_synthesis/4x4/Conv/weight' ], 324 | 'blocks.0.0.fc.weight' : ['fc_', 'G_synthesis/4x4/Conv/mod_weight' ], 325 | 'blocks.0.0.bias.bias' : ['any', 'G_synthesis/4x4/Conv/mod_bias' ], 326 | 'blocks.0.1.noise_scaler' : ['uns', 'G_synthesis/4x4/Conv/noise_strength' ], 327 | 'blocks.0.1.const_noise' : ['any', 'G_synthesis/noise0' ], 328 | 'blocks.0.2.bias' : ['any', 'G_synthesis/4x4/Conv/bias' ], 329 | 'blocks.1.0.weight' : ['mTc', 'G_synthesis/8x8/Conv0_up/weight' ], 330 | 'blocks.1.0.fc.weight' : ['fc_', 'G_synthesis/8x8/Conv0_up/mod_weight' ], 331 | 'blocks.1.0.bias.bias' : ['any', 'G_synthesis/8x8/Conv0_up/mod_bias' ], 332 | 'blocks.1.2.noise_scaler' : ['uns', 'G_synthesis/8x8/Conv0_up/noise_strength' ], 333 | 'blocks.1.2.const_noise' : ['any', 'G_synthesis/noise1' ], 334 | 'blocks.1.3.bias' : ['any', 'G_synthesis/8x8/Conv0_up/bias' ], 335 | 'blocks.2.0.weight' : ['con', 'G_synthesis/8x8/Conv1/weight' ], 336 | 'blocks.2.0.fc.weight' : ['fc_', 'G_synthesis/8x8/Conv1/mod_weight' ], 337 | 'blocks.2.0.bias.bias' : ['any', 'G_synthesis/8x8/Conv1/mod_bias' ], 338 
| 'blocks.2.1.noise_scaler' : ['uns', 'G_synthesis/8x8/Conv1/noise_strength' ], 339 | 'blocks.2.1.const_noise' : ['any', 'G_synthesis/noise2' ], 340 | 'blocks.2.2.bias' : ['any', 'G_synthesis/8x8/Conv1/bias' ], 341 | 'blocks.3.0.weight' : ['mTc', 'G_synthesis/16x16/Conv0_up/weight' ], 342 | 'blocks.3.0.fc.weight' : ['fc_', 'G_synthesis/16x16/Conv0_up/mod_weight' ], 343 | 'blocks.3.0.bias.bias' : ['any', 'G_synthesis/16x16/Conv0_up/mod_bias' ], 344 | 'blocks.3.2.noise_scaler' : ['uns', 'G_synthesis/16x16/Conv0_up/noise_strength' ], 345 | 'blocks.3.2.const_noise' : ['any', 'G_synthesis/noise3' ], 346 | 'blocks.3.3.bias' : ['any', 'G_synthesis/16x16/Conv0_up/bias' ], 347 | 'blocks.4.0.weight' : ['con', 'G_synthesis/16x16/Conv1/weight' ], 348 | 'blocks.4.0.fc.weight' : ['fc_', 'G_synthesis/16x16/Conv1/mod_weight' ], 349 | 'blocks.4.0.bias.bias' : ['any', 'G_synthesis/16x16/Conv1/mod_bias' ], 350 | 'blocks.4.1.noise_scaler' : ['uns', 'G_synthesis/16x16/Conv1/noise_strength' ], 351 | 'blocks.4.1.const_noise' : ['any', 'G_synthesis/noise4' ], 352 | 'blocks.4.2.bias' : ['any', 'G_synthesis/16x16/Conv1/bias' ], 353 | 'blocks.5.0.weight' : ['mTc', 'G_synthesis/32x32/Conv0_up/weight' ], 354 | 'blocks.5.0.fc.weight' : ['fc_', 'G_synthesis/32x32/Conv0_up/mod_weight' ], 355 | 'blocks.5.0.bias.bias' : ['any', 'G_synthesis/32x32/Conv0_up/mod_bias' ], 356 | 'blocks.5.2.noise_scaler' : ['uns', 'G_synthesis/32x32/Conv0_up/noise_strength' ], 357 | 'blocks.5.2.const_noise' : ['any', 'G_synthesis/noise5' ], 358 | 'blocks.5.3.bias' : ['any', 'G_synthesis/32x32/Conv0_up/bias' ], 359 | 'blocks.6.0.weight' : ['con', 'G_synthesis/32x32/Conv1/weight' ], 360 | 'blocks.6.0.fc.weight' : ['fc_', 'G_synthesis/32x32/Conv1/mod_weight' ], 361 | 'blocks.6.0.bias.bias' : ['any', 'G_synthesis/32x32/Conv1/mod_bias' ], 362 | 'blocks.6.1.noise_scaler' : ['uns', 'G_synthesis/32x32/Conv1/noise_strength' ], 363 | 'blocks.6.1.const_noise' : ['any', 'G_synthesis/noise6' ], 364 | 'blocks.6.2.bias' : ['any', 'G_synthesis/32x32/Conv1/bias' ], 365 | 'blocks.7.0.weight' : ['mTc', 'G_synthesis/64x64/Conv0_up/weight' ], 366 | 'blocks.7.0.fc.weight' : ['fc_', 'G_synthesis/64x64/Conv0_up/mod_weight' ], 367 | 'blocks.7.0.bias.bias' : ['any', 'G_synthesis/64x64/Conv0_up/mod_bias' ], 368 | 'blocks.7.2.noise_scaler' : ['uns', 'G_synthesis/64x64/Conv0_up/noise_strength' ], 369 | 'blocks.7.2.const_noise' : ['any', 'G_synthesis/noise7' ], 370 | 'blocks.7.3.bias' : ['any', 'G_synthesis/64x64/Conv0_up/bias' ], 371 | 'blocks.8.0.weight' : ['con', 'G_synthesis/64x64/Conv1/weight' ], 372 | 'blocks.8.0.fc.weight' : ['fc_', 'G_synthesis/64x64/Conv1/mod_weight' ], 373 | 'blocks.8.0.bias.bias' : ['any', 'G_synthesis/64x64/Conv1/mod_bias' ], 374 | 'blocks.8.1.noise_scaler' : ['uns', 'G_synthesis/64x64/Conv1/noise_strength' ], 375 | 'blocks.8.1.const_noise' : ['any', 'G_synthesis/noise8' ], 376 | 'blocks.8.2.bias' : ['any', 'G_synthesis/64x64/Conv1/bias' ], 377 | 'blocks.9.0.weight' : ['mTc', 'G_synthesis/128x128/Conv0_up/weight' ], 378 | 'blocks.9.0.fc.weight' : ['fc_', 'G_synthesis/128x128/Conv0_up/mod_weight' ], 379 | 'blocks.9.0.bias.bias' : ['any', 'G_synthesis/128x128/Conv0_up/mod_bias' ], 380 | 'blocks.9.2.noise_scaler' : ['uns', 'G_synthesis/128x128/Conv0_up/noise_strength' ], 381 | 'blocks.9.2.const_noise' : ['any', 'G_synthesis/noise9' ], 382 | 'blocks.9.3.bias' : ['any', 'G_synthesis/128x128/Conv0_up/bias' ], 383 | 'blocks.10.0.weight' : ['con', 'G_synthesis/128x128/Conv1/weight' ], 384 | 'blocks.10.0.fc.weight' : ['fc_', 
'G_synthesis/128x128/Conv1/mod_weight' ], 385 | 'blocks.10.0.bias.bias' : ['any', 'G_synthesis/128x128/Conv1/mod_bias' ], 386 | 'blocks.10.1.noise_scaler' : ['uns', 'G_synthesis/128x128/Conv1/noise_strength' ], 387 | 'blocks.10.1.const_noise' : ['any', 'G_synthesis/noise10' ], 388 | 'blocks.10.2.bias' : ['any', 'G_synthesis/128x128/Conv1/bias' ], 389 | 'blocks.11.0.weight' : ['mTc', 'G_synthesis/256x256/Conv0_up/weight' ], 390 | 'blocks.11.0.fc.weight' : ['fc_', 'G_synthesis/256x256/Conv0_up/mod_weight' ], 391 | 'blocks.11.0.bias.bias' : ['any', 'G_synthesis/256x256/Conv0_up/mod_bias' ], 392 | 'blocks.11.2.noise_scaler' : ['uns', 'G_synthesis/256x256/Conv0_up/noise_strength' ], 393 | 'blocks.11.2.const_noise' : ['any', 'G_synthesis/noise11' ], 394 | 'blocks.11.3.bias' : ['any', 'G_synthesis/256x256/Conv0_up/bias' ], 395 | 'blocks.12.0.weight' : ['con', 'G_synthesis/256x256/Conv1/weight' ], 396 | 'blocks.12.0.fc.weight' : ['fc_', 'G_synthesis/256x256/Conv1/mod_weight' ], 397 | 'blocks.12.0.bias.bias' : ['any', 'G_synthesis/256x256/Conv1/mod_bias' ], 398 | 'blocks.12.1.noise_scaler' : ['uns', 'G_synthesis/256x256/Conv1/noise_strength' ], 399 | 'blocks.12.1.const_noise' : ['any', 'G_synthesis/noise12' ], 400 | 'blocks.12.2.bias' : ['any', 'G_synthesis/256x256/Conv1/bias' ], 401 | 'blocks.13.0.weight' : ['mTc', 'G_synthesis/512x512/Conv0_up/weight' ], 402 | 'blocks.13.0.fc.weight' : ['fc_', 'G_synthesis/512x512/Conv0_up/mod_weight' ], 403 | 'blocks.13.0.bias.bias' : ['any', 'G_synthesis/512x512/Conv0_up/mod_bias' ], 404 | 'blocks.13.2.noise_scaler' : ['uns', 'G_synthesis/512x512/Conv0_up/noise_strength' ], 405 | 'blocks.13.2.const_noise' : ['any', 'G_synthesis/noise13' ], 406 | 'blocks.13.3.bias' : ['any', 'G_synthesis/512x512/Conv0_up/bias' ], 407 | 'blocks.14.0.weight' : ['con', 'G_synthesis/512x512/Conv1/weight' ], 408 | 'blocks.14.0.fc.weight' : ['fc_', 'G_synthesis/512x512/Conv1/mod_weight' ], 409 | 'blocks.14.0.bias.bias' : ['any', 'G_synthesis/512x512/Conv1/mod_bias' ], 410 | 'blocks.14.1.noise_scaler' : ['uns', 'G_synthesis/512x512/Conv1/noise_strength' ], 411 | 'blocks.14.1.const_noise' : ['any', 'G_synthesis/noise14' ], 412 | 'blocks.14.2.bias' : ['any', 'G_synthesis/512x512/Conv1/bias' ], 413 | 'blocks.15.0.weight' : ['mTc', 'G_synthesis/1024x1024/Conv0_up/weight' ], 414 | 'blocks.15.0.fc.weight' : ['fc_', 'G_synthesis/1024x1024/Conv0_up/mod_weight' ], 415 | 'blocks.15.0.bias.bias' : ['any', 'G_synthesis/1024x1024/Conv0_up/mod_bias' ], 416 | 'blocks.15.2.noise_scaler' : ['uns', 'G_synthesis/1024x1024/Conv0_up/noise_strength'], 417 | 'blocks.15.2.const_noise' : ['any', 'G_synthesis/noise15' ], 418 | 'blocks.15.3.bias' : ['any', 'G_synthesis/1024x1024/Conv0_up/bias' ], 419 | 'blocks.16.0.weight' : ['con', 'G_synthesis/1024x1024/Conv1/weight' ], 420 | 'blocks.16.0.fc.weight' : ['fc_', 'G_synthesis/1024x1024/Conv1/mod_weight' ], 421 | 'blocks.16.0.bias.bias' : ['any', 'G_synthesis/1024x1024/Conv1/mod_bias' ], 422 | 'blocks.16.1.noise_scaler' : ['uns', 'G_synthesis/1024x1024/Conv1/noise_strength' ], 423 | 'blocks.16.1.const_noise' : ['any', 'G_synthesis/noise16' ], 424 | 'blocks.16.2.bias' : ['any', 'G_synthesis/1024x1024/Conv1/bias' ], 425 | 'toRGBs.0.0.weight' : ['con', 'G_synthesis/4x4/ToRGB/weight' ], 426 | 'toRGBs.0.0.fc.weight' : ['fc_', 'G_synthesis/4x4/ToRGB/mod_weight' ], 427 | 'toRGBs.0.0.bias.bias' : ['any', 'G_synthesis/4x4/ToRGB/mod_bias' ], 428 | 'toRGBs.0.1.bias' : ['any', 'G_synthesis/4x4/ToRGB/bias' ], 429 | 'toRGBs.1.0.weight' : ['con', 'G_synthesis/8x8/ToRGB/weight' ], 
430 | 'toRGBs.1.0.fc.weight' : ['fc_', 'G_synthesis/8x8/ToRGB/mod_weight' ], 431 | 'toRGBs.1.0.bias.bias' : ['any', 'G_synthesis/8x8/ToRGB/mod_bias' ], 432 | 'toRGBs.1.1.bias' : ['any', 'G_synthesis/8x8/ToRGB/bias' ], 433 | 'toRGBs.2.0.weight' : ['con', 'G_synthesis/16x16/ToRGB/weight' ], 434 | 'toRGBs.2.0.fc.weight' : ['fc_', 'G_synthesis/16x16/ToRGB/mod_weight' ], 435 | 'toRGBs.2.0.bias.bias' : ['any', 'G_synthesis/16x16/ToRGB/mod_bias' ], 436 | 'toRGBs.2.1.bias' : ['any', 'G_synthesis/16x16/ToRGB/bias' ], 437 | 'toRGBs.3.0.weight' : ['con', 'G_synthesis/32x32/ToRGB/weight' ], 438 | 'toRGBs.3.0.fc.weight' : ['fc_', 'G_synthesis/32x32/ToRGB/mod_weight' ], 439 | 'toRGBs.3.0.bias.bias' : ['any', 'G_synthesis/32x32/ToRGB/mod_bias' ], 440 | 'toRGBs.3.1.bias' : ['any', 'G_synthesis/32x32/ToRGB/bias' ], 441 | 'toRGBs.4.0.weight' : ['con', 'G_synthesis/64x64/ToRGB/weight' ], 442 | 'toRGBs.4.0.fc.weight' : ['fc_', 'G_synthesis/64x64/ToRGB/mod_weight' ], 443 | 'toRGBs.4.0.bias.bias' : ['any', 'G_synthesis/64x64/ToRGB/mod_bias' ], 444 | 'toRGBs.4.1.bias' : ['any', 'G_synthesis/64x64/ToRGB/bias' ], 445 | 'toRGBs.5.0.weight' : ['con', 'G_synthesis/128x128/ToRGB/weight' ], 446 | 'toRGBs.5.0.fc.weight' : ['fc_', 'G_synthesis/128x128/ToRGB/mod_weight' ], 447 | 'toRGBs.5.0.bias.bias' : ['any', 'G_synthesis/128x128/ToRGB/mod_bias' ], 448 | 'toRGBs.5.1.bias' : ['any', 'G_synthesis/128x128/ToRGB/bias' ], 449 | 'toRGBs.6.0.weight' : ['con', 'G_synthesis/256x256/ToRGB/weight' ], 450 | 'toRGBs.6.0.fc.weight' : ['fc_', 'G_synthesis/256x256/ToRGB/mod_weight' ], 451 | 'toRGBs.6.0.bias.bias' : ['any', 'G_synthesis/256x256/ToRGB/mod_bias' ], 452 | 'toRGBs.6.1.bias' : ['any', 'G_synthesis/256x256/ToRGB/bias' ], 453 | 'toRGBs.7.0.weight' : ['con', 'G_synthesis/512x512/ToRGB/weight' ], 454 | 'toRGBs.7.0.fc.weight' : ['fc_', 'G_synthesis/512x512/ToRGB/mod_weight' ], 455 | 'toRGBs.7.0.bias.bias' : ['any', 'G_synthesis/512x512/ToRGB/mod_bias' ], 456 | 'toRGBs.7.1.bias' : ['any', 'G_synthesis/512x512/ToRGB/bias' ], 457 | 'toRGBs.8.0.weight' : ['con', 'G_synthesis/1024x1024/ToRGB/weight' ], 458 | 'toRGBs.8.0.fc.weight' : ['fc_', 'G_synthesis/1024x1024/ToRGB/mod_weight' ], 459 | 'toRGBs.8.0.bias.bias' : ['any', 'G_synthesis/1024x1024/ToRGB/mod_bias' ], 460 | 'toRGBs.8.1.bias' : ['any', 'G_synthesis/1024x1024/ToRGB/bias' ], 461 | } 462 | 463 | 464 | # 変換関数 465 | ops_dict = { 466 | # 変調転置畳み込みの重み (iC,oC,kH,kW) 467 | 'mTc' : lambda weight: torch.flip(torch.from_numpy(weight.transpose((2,3,0,1))), [2, 3]), 468 | # 転置畳み込みの重み (iC,oC,kH,kW) 469 | 'Tco' : lambda weight: torch.from_numpy(weight.transpose((2,3,0,1))), 470 | # 畳み込みの重み (oC,iC,kH,kW) 471 | 'con' : lambda weight: torch.from_numpy(weight.transpose((3,2,0,1))), 472 | # 全結合層の重み (oD, iD) 473 | 'fc_' : lambda weight: torch.from_numpy(weight.transpose((1, 0))), 474 | # 全結合層のバイアス項, 固定入力, 固定ノイズ, v1ノイズの重み (無変換) 475 | 'any' : lambda weight: torch.from_numpy(weight), 476 | # Style-Mixingの値, v2ノイズの重み (scalar) 477 | 'uns' : lambda weight: torch.from_numpy(np.array(weight).reshape(1)), 478 | } 479 | 480 | 481 | if __name__ == '__main__': 482 | # コマンドライン引数の取得 483 | args = parse_args() 484 | 485 | cfg = { 486 | 'src_weight': 'stylegan2_ndarray.pkl', 487 | 'src_latent': 'latents2.pkl', 488 | 'dst_image' : 'stylegan2_pt.png', 489 | 'dst_weight': 'stylegan2_state_dict.pth' 490 | } 491 | 492 | print('model construction...') 493 | generator = Generator() 494 | base_dict = generator.state_dict() 495 | 496 | print('model weights load...') 497 | with 
(Path(args.weight_dir)/cfg['src_weight']).open('rb') as f: 498 | src_dict = pickle.load(f) 499 | 500 | print('set state_dict...') 501 | new_dict = { k : ops_dict[v[0]](src_dict[v[1]]) for k,v in name_trans_dict.items()} 502 | generator.load_state_dict(new_dict) 503 | 504 | print('load latents...') 505 | with (Path(args.output_dir)/cfg['src_latent']).open('rb') as f: 506 | latents = pickle.load(f) 507 | latents = torch.from_numpy(latents.astype(np.float32)) 508 | 509 | print('network forward...') 510 | device = torch.device('cuda') if torch.cuda.is_available() and args.device=='gpu' else torch.device('cpu') 511 | with torch.no_grad(): 512 | N,_ = latents.shape 513 | generator.to(device) 514 | images = np.empty((N,args.resolution,args.resolution,3),dtype=np.uint8) 515 | 516 | for i in range(0,N,args.batch_size): 517 | j = min(i+args.batch_size,N) 518 | z = latents[i:j].to(device) 519 | img = generator(z) 520 | normalized = (img.clamp(-1,1)+1)/2*255 521 | images[i:j] = normalized.permute(0,2,3,1).cpu().numpy().astype(np.uint8) 522 | del z, img, normalized 523 | 524 | # 出力を並べる関数 525 | def make_table(imgs): 526 | # 出力する個数,解像度 527 | num_H, num_W = 4,4 528 | H = W = args.resolution 529 | num_images = num_H*num_W 530 | 531 | canvas = np.zeros((H*num_H,W*num_W,3),dtype=np.uint8) 532 | for i,p in enumerate(imgs[:num_images]): 533 | h,w = i//num_W, i%num_W 534 | canvas[H*h:H*-~h,W*w:W*-~w,:] = p[:,:,::-1] 535 | return canvas 536 | 537 | print('image output...') 538 | cv2.imwrite(str(Path(args.output_dir)/cfg['dst_image']), make_table(images)) 539 | 540 | print('weight save...') 541 | torch.save(generator.state_dict(),str(Path(args.weight_dir)/cfg['dst_weight'])) 542 | 543 | print('all done') 544 | -------------------------------------------------------------------------------- /packaged/run_tf_stylegan1.py: -------------------------------------------------------------------------------- 1 | 2 | from pathlib import Path 3 | import argparse 4 | import pickle 5 | # from tempfile import TemporaryDirectory 6 | 7 | import numpy as np 8 | import PIL.Image 9 | import tensorflow as tf 10 | 11 | # import dnnlib 12 | 13 | 14 | # コマンドライン引数の取得 15 | def parse_args(): 16 | parser = argparse.ArgumentParser(description='著者実装を動かしたり重みを抜き出したり') 17 | parser.add_argument('-w','--weight_dir',type=str,default='/tmp/stylegans-pytorch', 18 | help='学習済みのモデルを保存する場所') 19 | parser.add_argument('-o','--output_dir',type=str,default='/tmp/stylegans-pytorch', 20 | help='生成された画像を保存する場所') 21 | parser.add_argument('--batch_size',type=int,default=1, 22 | help='バッチサイズ') 23 | args = parser.parse_args() 24 | args.resolution = 1024 25 | return args 26 | 27 | 28 | # tensorflowの初期化 29 | def init_tf(): 30 | tf_random_seed = np.random.randint(1 << 31) 31 | tf.set_random_seed(tf_random_seed) 32 | 33 | config_proto = tf.ConfigProto() 34 | config_proto.graph_options.place_pruned_graph = True 35 | config_proto.gpu_options.allow_growth = True 36 | 37 | session = tf.Session(config=config_proto) 38 | session._default_session = session.as_default() 39 | session._default_session.enforce_nesting = False 40 | session._default_session.__enter__() 41 | 42 | return session 43 | 44 | 45 | # tensorflowの出力を正規化 46 | def convert_images_to_uint8(images): 47 | images = tf.cast(images, tf.float32) 48 | images = tf.transpose(images, [0, 2, 3, 1]) 49 | images = (images+1.0) / 2.0 * 255.0 50 | images = tf.saturate_cast(images, tf.uint8) 51 | return images 52 | 53 | 54 | # メイン関数 55 | def generate_images(args): 56 | file_names = { 57 | 'input_weight' : 
'karras2019stylegan-ffhq-1024x1024.pkl', 58 | 'output_weight' : 'stylegan1_ndarray.pkl', 59 | 'used_latents' : 'latents1.pkl', 60 | 'output_image' : 'stylegan1_tf.png', 61 | } 62 | 63 | init_tf() 64 | 65 | # 配布されている重みの読み込み 66 | with (Path(args.weight_dir)/file_names['input_weight']).open('rb') as f: 67 | *_, Gs = pickle.load(f) 68 | 69 | # 重みをnumpy形式に変換 70 | ndarrays = {k:v.eval() for k,v in Gs.vars.items()} 71 | [print(k,v.shape) for k,v in ndarrays.items()] 72 | 73 | # 重みをnumpy形式で保存 74 | print('weight save...') 75 | with (Path(args.weight_dir)/file_names['output_weight']).open('wb') as f: 76 | pickle.dump(ndarrays,f) 77 | 78 | 79 | # 画像を出力してみる 80 | print('run network') 81 | 82 | # 出力する個数,解像度 83 | num_H, num_W = 4,4 84 | N = num_images = num_H*num_W 85 | H = W = args.resolution 86 | 87 | # 出力を並べる関数 88 | def make_table(imgs): 89 | canvas = np.zeros((H*num_H,W*num_W,3),dtype=np.uint8) 90 | for i,p in enumerate(imgs): 91 | h,w = i//num_W, i%num_W 92 | canvas[H*h:H*-~h,W*w:W*-~w,:] = p 93 | return canvas 94 | 95 | # 乱数シードを固定,潜在変数の取得・保存 96 | latents = np.random.RandomState(5).randn(N, 512) 97 | with (Path(args.output_dir)/file_names['used_latents']).open('wb') as f: 98 | pickle.dump(latents, f) 99 | 100 | images = np.empty((N,args.resolution,args.resolution,3),dtype=np.uint8) 101 | for i in range(0, N, args.batch_size): 102 | j = min(i+args.batch_size, N) 103 | z = latents[i:j] 104 | images[i:j] = Gs.run(z, None, truncation_psi=0.7, randomize_noise=False, 105 | output_transform= {'func': convert_images_to_uint8}) 106 | 107 | # 画像の保存 108 | PIL.Image.fromarray(make_table(images)).save(Path(args.output_dir)/file_names['output_image']) 109 | 110 | 111 | if __name__ == '__main__': 112 | args = parse_args() 113 | 114 | generate_images(args) 115 | -------------------------------------------------------------------------------- /packaged/run_tf_stylegan2.py: -------------------------------------------------------------------------------- 1 | 2 | from pathlib import Path 3 | import argparse 4 | import pickle 5 | from tempfile import TemporaryDirectory 6 | 7 | import numpy as np 8 | import PIL.Image 9 | import tensorflow as tf 10 | 11 | import dnnlib 12 | 13 | 14 | # コマンドライン引数の取得 15 | def parse_args(): 16 | parser = argparse.ArgumentParser(description='著者実装を動かしたり重みを抜き出したり') 17 | parser.add_argument('-w','--weight_dir',type=str,default='/tmp/stylegans-pytorch', 18 | help='学習済みのモデルを保存する場所') 19 | parser.add_argument('-o','--output_dir',type=str,default='/tmp/stylegans-pytorch', 20 | help='生成された画像を保存する場所') 21 | parser.add_argument('--batch_size',type=int,default=1, 22 | help='バッチサイズ') 23 | args = parser.parse_args() 24 | args.resolution = 1024 25 | return args 26 | 27 | 28 | # tensorflowの初期化 29 | def init_tf(): 30 | tf_random_seed = np.random.randint(1 << 31) 31 | tf.set_random_seed(tf_random_seed) 32 | 33 | config_proto = tf.ConfigProto() 34 | config_proto.graph_options.place_pruned_graph = True 35 | config_proto.gpu_options.allow_growth = True 36 | 37 | session = tf.Session(config=config_proto) 38 | session._default_session = session.as_default() 39 | session._default_session.enforce_nesting = False 40 | session._default_session.__enter__() 41 | 42 | return session 43 | 44 | 45 | # tensorflowの出力を正規化 46 | def convert_images_to_uint8(images): 47 | images = tf.cast(images, tf.float32) 48 | images = tf.transpose(images, [0, 2, 3, 1]) 49 | images = (images+1.0) / 2.0 * 255.0 50 | images = tf.saturate_cast(images, tf.uint8) 51 | return images 52 | 53 | 54 | # 画像生成関数(dnnlib.submit_runに渡すコールバック関数) 55 | def 
generate_images(args): 56 | file_names = { 57 | 'input_weight' : 'stylegan2-ffhq-config-f.pkl', 58 | 'output_weight' : 'stylegan2_ndarray.pkl', 59 | 'used_latents' : 'latents2.pkl', 60 | 'output_image' : 'stylegan2_tf.png', 61 | } 62 | 63 | init_tf() 64 | 65 | # 配布されている重みの読み込み 66 | with (Path(args.weight_dir)/file_names['input_weight']).open('rb') as f: 67 | *_, Gs = pickle.load(f) 68 | 69 | # 重みをnumpy形式に変換 70 | ndarrays = {k:v.eval() for k,v in Gs.vars.items()} 71 | [print(k,v.shape) for k,v in ndarrays.items()] 72 | 73 | # 重みをnumpy形式で保存 74 | print('weight save...') 75 | with (Path(args.weight_dir)/file_names['output_weight']).open('wb') as f: 76 | pickle.dump(ndarrays,f) 77 | 78 | 79 | # 画像を出力してみる 80 | print('run network') 81 | 82 | # 出力する個数,解像度 83 | num_H, num_W = 4,4 84 | N = num_images = num_H*num_W 85 | H = W = args.resolution 86 | 87 | # 出力を並べる関数 88 | def make_table(imgs): 89 | canvas = np.zeros((H*num_H,W*num_W,3),dtype=np.uint8) 90 | for i,p in enumerate(imgs): 91 | h,w = i//num_W, i%num_W 92 | canvas[H*h:H*-~h,W*w:W*-~w,:] = p 93 | return canvas 94 | 95 | # 乱数シードを固定,潜在変数の取得・保存 96 | latents = np.random.RandomState(5).randn(N, 512) 97 | with (Path(args.output_dir)/file_names['used_latents']).open('wb') as f: 98 | pickle.dump(latents, f) 99 | 100 | images = np.empty((N,args.resolution,args.resolution,3),dtype=np.uint8) 101 | for i in range(0, N, args.batch_size): 102 | j = min(i+args.batch_size, N) 103 | z = latents[i:j] 104 | images[i:j] = Gs.run(z, None, truncation_psi=0.7, randomize_noise=False, 105 | output_transform= {'func': convert_images_to_uint8}) 106 | 107 | # 画像の保存 108 | PIL.Image.fromarray(make_table(images)).save(Path(args.output_dir)/file_names['output_image']) 109 | 110 | print('all done') 111 | 112 | 113 | if __name__ == '__main__': 114 | args = parse_args() 115 | 116 | with TemporaryDirectory() as dir_name: 117 | # 設定に必要な情報を書き込んで実行 118 | config = dnnlib.SubmitConfig() 119 | config.local.do_not_copy_source_files = True 120 | config.run_dir_root = dir_name # 変なファイルが自動生成される場所 121 | 122 | dnnlib.submit_run(config, 123 | 'run_tf_stylegan2.generate_images', # コールバック関数 124 | args=args) # コールバック関数の引数 125 | -------------------------------------------------------------------------------- /waifu/common.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | # mapping前に潜在変数を超球面上に正規化 8 | class PixelwiseNormalization(nn.Module): 9 | def __init__(self): 10 | super().__init__() 11 | def forward(self, x): 12 | return x / torch.sqrt((x**2).mean(1,keepdim=True) + 1e-8) 13 | 14 | 15 | # 移動平均を用いて潜在変数を正規化する. 
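# (note) A sketch of what the layer below computes, using values configured in
# this repository's stylegan1.py / stylegan2.py (threshold=0.7, style_dim=512,
# output_num=18): the input style w (N,512) is copied output_num times, and for
# the first num_target copies it is pulled toward the moving-average style,
#     w' = avg_style + (w - avg_style) * 0.7,
# while the remaining copies keep rate 1.0 and stay unchanged. avg_style is the
# 'dlatent_avg' buffer loaded from the TensorFlow checkpoint, so this is the
# truncation trick applied per style layer rather than a normalization.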
16 | class TruncationTrick(nn.Module): 17 | def __init__(self, num_target, threshold, output_num, style_dim): 18 | super().__init__() 19 | self.num_target = num_target 20 | self.threshold = threshold 21 | self.output_num = output_num 22 | self.register_buffer('avg_style', torch.zeros((style_dim,))) 23 | 24 | def forward(self, x): 25 | # in:(N,D) -> out:(N,O,D) 26 | N,D = x.shape 27 | O = self.output_num 28 | x = x.view(N,1,D).expand(N,O,D) 29 | rate = torch.cat([ torch.ones((N, self.num_target, D)) *self.threshold, 30 | torch.ones((N, O-self.num_target, D)) *1.0 ],1).to(x.device) 31 | avg = self.avg_style.view(1,1,D).expand(N,O,D) 32 | return avg + (x-avg)*rate 33 | 34 | 35 | # 特徴マップ信号を増幅する 36 | class Amplify(nn.Module): 37 | def __init__(self, rate): 38 | super().__init__() 39 | self.rate = rate 40 | def forward(self,x): 41 | return x * self.rate 42 | 43 | 44 | # チャンネルごとにバイアス項を足す 45 | class AddChannelwiseBias(nn.Module): 46 | def __init__(self, out_channels, lr): 47 | super().__init__() 48 | # lr = 1.0 (conv,mod,AdaIN), 0.01 (mapping) 49 | 50 | self.bias = nn.Parameter(torch.zeros(out_channels)) 51 | torch.nn.init.zeros_(self.bias.data) 52 | self.bias_scaler = lr 53 | 54 | def forward(self, x): 55 | oC,*_ = self.bias.shape 56 | shape = (1,oC) if x.ndim==2 else (1,oC,1,1) 57 | y = x + self.bias.view(*shape)*self.bias_scaler 58 | return y 59 | 60 | 61 | # 学習率を調整したFC層 62 | class EqualizedFullyConnect(nn.Module): 63 | def __init__(self, in_dim, out_dim, lr): 64 | super().__init__() 65 | # lr = 0.01 (mapping), 1.0 (mod,AdaIN) 66 | 67 | self.weight = nn.Parameter(torch.randn((out_dim,in_dim))) 68 | torch.nn.init.normal_(self.weight.data, mean=0.0, std=1.0/lr) 69 | self.weight_scaler = 1/(in_dim**0.5)*lr 70 | 71 | def forward(self, x): 72 | # x (N,D) 73 | return F.linear(x, self.weight*self.weight_scaler, None) 74 | -------------------------------------------------------------------------------- /waifu/converter.py: -------------------------------------------------------------------------------- 1 | 2 | from pathlib import Path 3 | import argparse 4 | import pickle 5 | from tempfile import TemporaryDirectory 6 | 7 | import numpy as np 8 | import PIL.Image 9 | import tensorflow as tf 10 | 11 | import dnnlib 12 | 13 | 14 | # コマンドライン引数の取得 15 | def parse_args(): 16 | parser = argparse.ArgumentParser(description='著者実装を動かしたり重みを抜き出したり') 17 | parser.add_argument('version',type=int, 18 | help='version name') 19 | parser.add_argument('model', type=str, choices=['face_v1_1','face_v1_2','portrait_v1','portrait_v2'], 20 | help='モデル') 21 | parser.add_argument('-w','--weight_dir',type=str,default='/tmp/stylegans-pytorch', 22 | help='学習済みのモデルを保存する場所') 23 | parser.add_argument('-o','--output_dir',type=str,default='/tmp/stylegans-pytorch', 24 | help='生成された画像を保存する場所') 25 | parser.add_argument('--batch_size',type=int,default=1, 26 | help='バッチサイズ') 27 | args = parser.parse_args() 28 | args.resolution = 512 29 | return args 30 | 31 | 32 | # tensorflowの初期化 33 | def init_tf(): 34 | tf_random_seed = np.random.randint(1 << 31) 35 | tf.set_random_seed(tf_random_seed) 36 | 37 | config_proto = tf.ConfigProto() 38 | config_proto.graph_options.place_pruned_graph = True 39 | config_proto.gpu_options.allow_growth = True 40 | 41 | session = tf.Session(config=config_proto) 42 | session._default_session = session.as_default() 43 | session._default_session.enforce_nesting = False 44 | session._default_session.__enter__() 45 | 46 | return session 47 | 48 | 49 | # tensorflowの出力を正規化 50 | def convert_images_to_uint8(images): 51 | 
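# (note) The lines below map the generator output from NCHW float in [-1, 1]
# to NHWC uint8 in [0, 255]; e.g. -1.0 -> 0 and +1.0 -> 255, with
# tf.saturate_cast clamping any value that falls outside the uint8 range.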
images = tf.cast(images, tf.float32) 52 | images = tf.transpose(images, [0, 2, 3, 1]) 53 | images = (images+1.0) / 2.0 * 255.0 54 | images = tf.saturate_cast(images, tf.uint8) 55 | return images 56 | 57 | 58 | # メイン関数 59 | def generate_images(args, file_names): 60 | 61 | init_tf() 62 | 63 | # 配布されている重みの読み込み 64 | with (Path(args.weight_dir)/file_names['input_weight']).open('rb') as f: 65 | *_, Gs = pickle.load(f) 66 | 67 | # 重みをnumpy形式に変換 68 | ndarrays = {k:v.eval() for k,v in Gs.vars.items()} 69 | [print(k,v.shape) for k,v in ndarrays.items()] 70 | 71 | # 重みをnumpy形式で保存 72 | print('weight save...') 73 | with (Path(args.weight_dir)/file_names['output_weight']).open('wb') as f: 74 | pickle.dump(ndarrays,f) 75 | 76 | 77 | # 画像を出力してみる 78 | print('run network') 79 | 80 | # 出力する個数,解像度 81 | num_H, num_W = 4,4 82 | N = num_images = num_H*num_W 83 | H = W = args.resolution 84 | 85 | # 出力を並べる関数 86 | def make_table(imgs): 87 | canvas = np.zeros((H*num_H,W*num_W,3),dtype=np.uint8) 88 | for i,p in enumerate(imgs): 89 | h,w = i//num_W, i%num_W 90 | canvas[H*h:H*-~h,W*w:W*-~w,:] = p 91 | return canvas 92 | 93 | # 乱数シードを固定,潜在変数の取得・保存 94 | latents = np.random.RandomState(5).randn(N, 512) 95 | with (Path(args.output_dir)/file_names['used_latents']).open('wb') as f: 96 | pickle.dump(latents, f) 97 | 98 | images = np.empty((N,args.resolution,args.resolution,3),dtype=np.uint8) 99 | for i in range(0, N, args.batch_size): 100 | j = min(i+args.batch_size, N) 101 | z = latents[i:j] 102 | images[i:j] = Gs.run(z, None, truncation_psi=0.7, randomize_noise=False, 103 | output_transform= {'func': convert_images_to_uint8}) 104 | 105 | # 画像の保存 106 | PIL.Image.fromarray(make_table(images)).save(Path(args.output_dir)/file_names['output_image']) 107 | 108 | print('all done') 109 | 110 | 111 | if __name__ == '__main__': 112 | args = parse_args() 113 | 114 | file_names = { 115 | 'face_v1_1' : { 116 | 'input_weight' : '2019-02-26-stylegan-faces-network-02048-016041.pkl', 117 | 'output_weight' : 'anime_face_v1_1_ndarray.pkl', 118 | 'used_latents' : 'anime_face_v1_1_latents.pkl', 119 | 'output_image' : 'anime_face_v1_1_tf.png', 120 | }, 121 | 'face_v1_2' : { 122 | 'input_weight' : '2019-03-08-stylegan-animefaces-network-02051-021980.pkl', 123 | 'output_weight' : 'anime_face_v1_2_ndarray.pkl', 124 | 'used_latents' : 'anime_face_v1_2_latents.pkl', 125 | 'output_image' : 'anime_face_v1_2_tf.png', 126 | }, 127 | 'portrait_v1' : { 128 | 'input_weight' : '2019-04-30-stylegan-danbooru2018-portraits-02095-066083.pkl', 129 | 'output_weight' : 'anime_portrait_v1_ndarray.pkl', 130 | 'used_latents' : 'anime_portrait_v1_latents.pkl', 131 | 'output_image' : 'anime_portrait_v1_tf.png', 132 | }, 133 | 'portrait_v2' : { 134 | 'input_weight' : '2020-01-11-skylion-stylegan2-animeportraits-networksnapshot-024664.pkl', 135 | 'output_weight' : 'anime_portrait_v2_ndarray.pkl', 136 | 'used_latents' : 'anime_portrait_v2_latents.pkl', 137 | 'output_image' : 'anime_portrait_v2_tf.png', 138 | }, 139 | }[args.model] 140 | 141 | if args.version==1: 142 | generate_images(args, file_names) 143 | 144 | elif args.version==2: 145 | with TemporaryDirectory() as dir_name: 146 | # 設定に必要な情報を書き込んで実行 147 | config = dnnlib.SubmitConfig() 148 | config.local.do_not_copy_source_files = True 149 | config.run_dir_root = dir_name # 変なファイルが自動生成される場所 150 | 151 | dnnlib.submit_run(config, 152 | 'converter.generate_images', # コールバック関数 153 | args=args, file_names=file_names) # コールバック関数の引数 154 | 155 | 156 | 157 | 
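converter.py above only dumps the TensorFlow variables into a plain pickle of numpy arrays. A minimal sketch of inspecting that file (the path and model name are just examples, taken from the `--weight_dir` default in `parse_args` and the `portrait_v2` entry of `file_names`; the printed shapes are TensorFlow-layout, e.g. Dense weights as (in,out) and conv weights as (kH,kW,iC,oC)):
```
import pickle
from pathlib import Path

# /tmp/stylegans-pytorch is the default --weight_dir of converter.py
with (Path('/tmp/stylegans-pytorch')/'anime_portrait_v2_ndarray.pkl').open('rb') as f:
    ndarrays = pickle.load(f)   # {TensorFlow variable name: numpy.ndarray}

# e.g. 'G_mapping/Dense0/weight' (512, 512), 'G_synthesis/4x4/Conv/weight' (3, 3, 512, 512)
for name, arr in list(ndarrays.items())[:10]:
    print(name, arr.shape)
```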
-------------------------------------------------------------------------------- /waifu/run_pt_stylegan.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | from pathlib import Path 4 | import pickle 5 | 6 | import numpy as np 7 | import cv2 8 | import torch 9 | 10 | 11 | # コマンドライン引数の取得 12 | def parse_args(): 13 | parser = argparse.ArgumentParser(description='著者実装を動かしたり重みを抜き出したり') 14 | parser.add_argument('version',type=int, 15 | help='version name') 16 | parser.add_argument('model', type=str, choices=['face_v1_1','face_v1_2','portrait_v1','portrait_v2'], 17 | help='モデル') 18 | parser.add_argument('-w','--weight_dir',type=str,default='/tmp/stylegans-pytorch', 19 | help='学習済みのモデルを保存する場所') 20 | parser.add_argument('-o','--output_dir',type=str,default='/tmp/stylegans-pytorch', 21 | help='生成された画像を保存する場所') 22 | parser.add_argument('--batch_size',type=int,default=1, 23 | help='バッチサイズ') 24 | parser.add_argument('--device',type=str,default='gpu',choices=['gpu','cpu'], 25 | help='デバイス') 26 | args = parser.parse_args() 27 | args.resolution = 512 28 | return args 29 | 30 | 31 | # 変換関数 32 | ops_dict = { 33 | # 変調転置畳み込みの重み (iC,oC,kH,kW) 34 | 'mTc' : lambda weight: torch.flip(torch.from_numpy(weight.transpose((2,3,0,1))), [2, 3]), 35 | # 転置畳み込みの重み (iC,oC,kH,kW) 36 | 'Tco' : lambda weight: torch.from_numpy(weight.transpose((2,3,0,1))), 37 | # 畳み込みの重み (oC,iC,kH,kW) 38 | 'con' : lambda weight: torch.from_numpy(weight.transpose((3,2,0,1))), 39 | # 全結合層の重み (oD, iD) 40 | 'fc_' : lambda weight: torch.from_numpy(weight.transpose((1, 0))), 41 | # 全結合層のバイアス項, 固定入力, 固定ノイズ, v1ノイズの重み (無変換) 42 | 'any' : lambda weight: torch.from_numpy(weight), 43 | # Style-Mixingの値, v2ノイズの重み (scalar) 44 | 'uns' : lambda weight: torch.from_numpy(np.array(weight).reshape(1)), 45 | } 46 | 47 | 48 | setting = { 49 | 'face_v1_1' : { 50 | 'src_weight' : 'anime_face_v1_1_ndarray.pkl', 51 | 'src_latent' : 'anime_face_v1_1_latents.pkl', 52 | 'dst_image' : 'anime_face_v1_1_pt.png', 53 | 'dst_weight' : 'anime_face_v1_1_state_dict.pth' 54 | }, 55 | 'face_v1_2' : { 56 | 'src_weight' : 'anime_face_v1_2_ndarray.pkl', 57 | 'src_latent' : 'anime_face_v1_2_latents.pkl', 58 | 'dst_image' : 'anime_face_v1_2_pt.png', 59 | 'dst_weight' : 'anime_face_v1_2_state_dict.pth' 60 | }, 61 | 'portrait_v1' : { 62 | 'src_weight' : 'anime_portrait_v1_ndarray.pkl', 63 | 'src_latent' : 'anime_portrait_v1_latents.pkl', 64 | 'dst_image' : 'anime_portrait_v1_pt.png', 65 | 'dst_weight' : 'anime_portrait_v1_state_dict.pth' 66 | }, 67 | 'portrait_v2' : { 68 | 'src_weight' : 'anime_portrait_v2_ndarray.pkl', 69 | 'src_latent' : 'anime_portrait_v2_latents.pkl', 70 | 'dst_image' : 'anime_portrait_v2_pt.png', 71 | 'dst_weight' : 'anime_portrait_v2_state_dict.pth' 72 | }, 73 | } 74 | 75 | 76 | if __name__ == '__main__': 77 | # コマンドライン引数の取得 78 | args = parse_args() 79 | 80 | # バージョンによって切り替え 81 | cfg = setting[args.model] 82 | 83 | if args.version==1: 84 | from stylegan1 import Generator, name_trans_dict 85 | elif args.version==2: 86 | from stylegan2 import Generator, name_trans_dict 87 | 88 | 89 | print('model construction...') 90 | generator = Generator() 91 | base_dict = generator.state_dict() 92 | 93 | print('model weights load...') 94 | with (Path(args.weight_dir)/cfg['src_weight']).open('rb') as f: 95 | src_dict = pickle.load(f) 96 | 97 | print('set state_dict...') 98 | new_dict = { k : ops_dict[v[0]](src_dict[v[1]]) for k,v in name_trans_dict.items() if v[1] in src_dict} 99 | generator.load_state_dict(new_dict) 100 | 
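# (note) At this point every TensorFlow ndarray has been renamed through
# name_trans_dict and reshaped through ops_dict: e.g. 'con' turns a conv weight
# stored as (kH,kW,iC,oC) into PyTorch's (oC,iC,kH,kW), 'fc_' transposes Dense
# weights from (iD,oD) to (oD,iD), and 'uns' wraps scalars such as
# noise_strength into 1-element tensors.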
101 | print('load latents...') 102 | with (Path(args.output_dir)/cfg['src_latent']).open('rb') as f: 103 | latents = pickle.load(f) 104 | latents = torch.from_numpy(latents.astype(np.float32)) 105 | 106 | print('network forward...') 107 | device = torch.device('cuda') if torch.cuda.is_available() and args.device=='gpu' else torch.device('cpu') 108 | with torch.no_grad(): 109 | N,_ = latents.shape 110 | generator.to(device) 111 | images = np.empty((N,args.resolution,args.resolution,3),dtype=np.uint8) 112 | 113 | for i in range(0,N,args.batch_size): 114 | j = min(i+args.batch_size,N) 115 | z = latents[i:j].to(device) 116 | img = generator(z) 117 | normalized = (img.clamp(-1,1)+1)/2*255 118 | images[i:j] = normalized.permute(0,2,3,1).cpu().numpy().astype(np.uint8) 119 | del z, img, normalized 120 | 121 | # 出力を並べる関数 122 | def make_table(imgs): 123 | # 出力する個数,解像度 124 | num_H, num_W = 4,4 125 | H = W = args.resolution 126 | num_images = num_H*num_W 127 | 128 | canvas = np.zeros((H*num_H,W*num_W,3),dtype=np.uint8) 129 | for i,p in enumerate(imgs[:num_images]): 130 | h,w = i//num_W, i%num_W 131 | canvas[H*h:H*-~h,W*w:W*-~w,:] = p[:,:,::-1] 132 | return canvas 133 | 134 | print('image output...') 135 | cv2.imwrite(str(Path(args.output_dir)/cfg['dst_image']), make_table(images)) 136 | 137 | print('weight save...') 138 | torch.save(generator.state_dict(),str(Path(args.weight_dir)/cfg['dst_weight'])) 139 | 140 | print('all done') 141 | -------------------------------------------------------------------------------- /waifu/stylegan1.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from common import PixelwiseNormalization, Amplify, AddChannelwiseBias, EqualizedFullyConnect, TruncationTrick 8 | 9 | 10 | # 固定ノイズ 11 | class ElementwiseNoise(nn.Module): 12 | def __init__(self, ch, size_hw): 13 | super().__init__() 14 | self.register_buffer("const_noise", torch.randn((1, 1, size_hw, size_hw))) 15 | self.noise_scaler = nn.Parameter(torch.zeros((ch,))) 16 | 17 | def forward(self, x): 18 | N,C,H,W = x.shape 19 | noise = self.const_noise.expand(N,C,H,W) 20 | scaler = self.noise_scaler.view(1,C,1,1) 21 | return x + noise * scaler 22 | 23 | 24 | # ブラー : 解像度を上げる畳み込みの後に使う 25 | class Blur3x3(nn.Module): 26 | def __init__(self): 27 | super().__init__() 28 | f = np.array( [ [1/16, 2/16, 1/16], 29 | [2/16, 4/16, 2/16], 30 | [1/16, 2/16, 1/16]], dtype=np.float32).reshape([1, 1, 3, 3]) 31 | self.filter = torch.from_numpy(f) 32 | 33 | def forward(self, x): 34 | _N,C,_H,_W = x.shape 35 | return F.conv2d(x, self.filter.to(x.device).expand(C,1,3,3), padding=1, groups=C) 36 | 37 | 38 | # 学習率を調整した転置畳み込み (ブラーのための拡張あり) 39 | class EqualizedFusedConvTransposed2d(nn.Module): 40 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, lr): 41 | super().__init__() 42 | # lr = 1.0 43 | 44 | self.stride, self.padding = stride, padding 45 | 46 | self.weight = nn.Parameter(torch.empty(in_channels, out_channels, kernel_size, kernel_size)) 47 | torch.nn.init.normal_(self.weight.data, mean=0.0, std=1.0/lr) 48 | self.weight_scaler = 1 / ((in_channels * (kernel_size ** 2) )**0.5) * lr 49 | 50 | def forward(self, x): 51 | # 3x3 conv を 4x4 transposed conv として使う 52 | i_ch, o_ch, _kh, _kw = self.weight.shape 53 | # Padding (L,R,T,B) で4x4の四隅に3x3フィルタを寄せて和で合成 54 | weight_4x4 = torch.cat([F.pad(self.weight, pad).view(1,i_ch,o_ch,4,4) 55 | for pad in 
[(0,1,0,1),(1,0,0,1),(0,1,1,0),(1,0,1,0)]]).sum(dim=0) 56 | return F.conv_transpose2d(x, weight_4x4*self.weight_scaler, stride=2, padding=1) 57 | # 3x3でconvしてからpadで4隅に寄せて計算しても同じでは? 58 | # padding0にしてStyleGAN2のBlurを使っても同じでは? 59 | 60 | 61 | # 学習率を調整した畳込み 62 | class EqualizedConv2d(nn.Module): 63 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, lr): 64 | super().__init__() 65 | # lr = 1.0 66 | 67 | self.stride, self.padding = stride, padding 68 | 69 | self.weight = nn.Parameter(torch.empty(out_channels, in_channels, kernel_size, kernel_size)) 70 | torch.nn.init.normal_(self.weight.data, mean=0.0, std=1.0/lr) 71 | self.weight_scaler = 1 / ((in_channels * (kernel_size ** 2) )**0.5) * lr 72 | 73 | def forward(self, x): 74 | N,C,H,W = x.shape 75 | return F.conv2d(x, self.weight*self.weight_scaler, None, 76 | self.stride, self.padding) 77 | 78 | 79 | # 学習率を調整したAdaIN 80 | class EqualizedAdaIN(nn.Module): 81 | def __init__(self, fmap_ch, style_ch, lr): 82 | super().__init__() 83 | # lr = 1.0 84 | self.fc = EqualizedFullyConnect(style_ch, fmap_ch*2, lr) 85 | self.bias = AddChannelwiseBias(fmap_ch*2,lr) 86 | 87 | def forward(self, pack): 88 | x, style = pack 89 | #N,D = w.shape 90 | N,C,H,W = x.shape 91 | 92 | _vec = self.bias( self.fc(style) ).view(N,2*C,1,1) # (N,2C,1,1) 93 | scale, shift = _vec[:,:C,:,:], _vec[:,C:,:,:] # (N,C,1,1), (N,C,1,1) 94 | return (scale+1) * F.instance_norm(x, eps=1e-8) + shift 95 | 96 | 97 | class Generator(nn.Module): 98 | 99 | structure = { 100 | 'mapping': [['pixel_norm'], ['fc',512,512],['amp'],['b',512],['Lrelu'],['fc',512,512],['amp'],['b',512],['Lrelu'], 101 | ['fc',512,512],['amp'],['b',512],['Lrelu'],['fc',512,512],['amp'],['b',512],['Lrelu'], 102 | ['fc',512,512],['amp'],['b',512],['Lrelu'],['fc',512,512],['amp'],['b',512],['Lrelu'], 103 | ['fc',512,512],['amp'],['b',512],['Lrelu'],['fc',512,512],['amp'],['b',512],['Lrelu'],['truncation']], 104 | 'START' : [ ['noiseE',512, 4], ['bias',512], ['Lrelu'] ], 'adain4a' : [['adain',512]], 105 | 'Fconv4' : [ ['EqConv3x3',512, 512], ['amp'], ['noiseE',512, 4], ['bias',512], ['Lrelu'] ], 'adain4b' : [['adain',512]], 'toRGB_4' : [['EqConv1x1',512, 3], ['bias',3]], 106 | 'Uconv8' : [['up'], ['EqConv3x3',512, 512], ['blur3x3'], ['amp'], ['noiseE',512, 8], ['bias',512], ['Lrelu'] ], 'adain8a' : [['adain',512]], 107 | 'Fconv8' : [ ['EqConv3x3',512, 512], ['amp'], ['noiseE',512, 8], ['bias',512], ['Lrelu'] ], 'adain8b' : [['adain',512]], 'toRGB_8' : [['EqConv1x1',512, 3], ['bias',3]], 108 | 'Uconv16' : [['up'], ['EqConv3x3',512, 512], ['blur3x3'], ['amp'], ['noiseE',512, 16], ['bias',512], ['Lrelu'] ], 'adain16a' : [['adain',512]], 109 | 'Fconv16' : [ ['EqConv3x3',512, 512], ['amp'], ['noiseE',512, 16], ['bias',512], ['Lrelu'] ], 'adain16b' : [['adain',512]], 'toRGB_16' : [['EqConv1x1',512, 3], ['bias',3]], 110 | 'Uconv32' : [['up'], ['EqConv3x3',512, 512], ['blur3x3'], ['amp'], ['noiseE',512, 32], ['bias',512], ['Lrelu'] ], 'adain32a' : [['adain',512]], 111 | 'Fconv32' : [ ['EqConv3x3',512, 512], ['amp'], ['noiseE',512, 32], ['bias',512], ['Lrelu'] ], 'adain32b' : [['adain',512]], 'toRGB_32' : [['EqConv1x1',512, 3], ['bias',3]], 112 | 'Uconv64' : [['up'], ['EqConv3x3',512, 256], ['blur3x3'], ['amp'], ['noiseE',256, 64], ['bias',256], ['Lrelu'] ], 'adain64a' : [['adain',256]], 113 | 'Fconv64' : [ ['EqConv3x3',256, 256], ['amp'], ['noiseE',256, 64], ['bias',256], ['Lrelu'] ], 'adain64b' : [['adain',256]], 'toRGB_64' : [['EqConv1x1',256, 3], ['bias',3]], 114 | 'Uconv128' : [ ['EqConvT3x3EX',256, 128], 
['blur3x3'], ['amp'], ['noiseE',128, 128], ['bias',128], ['Lrelu'] ], 'adain128a' : [['adain',128]], 115 | 'Fconv128' : [ ['EqConv3x3',128, 128], ['amp'], ['noiseE',128, 128], ['bias',128], ['Lrelu'] ], 'adain128b' : [['adain',128]], 'toRGB_128' : [['EqConv1x1',128, 3], ['bias',3]], 116 | 'Uconv256' : [ ['EqConvT3x3EX',128, 64], ['blur3x3'], ['amp'], ['noiseE', 64, 256], ['bias', 64], ['Lrelu'] ], 'adain256a' : [['adain', 64]], 117 | 'Fconv256' : [ ['EqConv3x3', 64, 64], ['amp'], ['noiseE', 64, 256], ['bias', 64], ['Lrelu'] ], 'adain256b' : [['adain', 64]], 'toRGB_256' : [['EqConv1x1', 64, 3], ['bias',3]], 118 | 'Uconv512' : [ ['EqConvT3x3EX', 64, 32], ['blur3x3'], ['amp'], ['noiseE', 32, 512], ['bias', 32], ['Lrelu'] ], 'adain512a' : [['adain', 32]], 119 | 'Fconv512' : [ ['EqConv3x3', 32, 32], ['amp'], ['noiseE', 32, 512], ['bias', 32], ['Lrelu'] ], 'adain512b' : [['adain', 32]], 'toRGB_512' : [['EqConv1x1', 32, 3], ['bias',3]], 120 | #'Uconv1024': [ ['EqConvT3x3EX', 32, 16], ['blur3x3'], ['amp'], ['noiseE', 16, 1024], ['bias', 16], ['Lrelu'] ], 'adain1024a': [['adain', 16]], 121 | #'Fconv1024': [ ['EqConv3x3', 16, 16], ['amp'], ['noiseE', 16, 1024], ['bias', 16], ['Lrelu'] ], 'adain1024b': [['adain', 16]], 'toRGB_1024': [['EqConv1x1', 16, 3], ['bias',3]], 122 | } 123 | 124 | def _make_sequential(self,key): 125 | definition = { 126 | 'pixel_norm' : lambda *config: PixelwiseNormalization(), 127 | 'truncation' : lambda *config: TruncationTrick( 128 | num_target=8, threshold=0.7, output_num=18, style_dim=512 ), 129 | 'fc' : lambda *config: EqualizedFullyConnect( 130 | in_dim=config[0],out_dim=config[1], lr=0.01), 131 | 'b' : lambda *config: AddChannelwiseBias(out_channels=config[0], lr=0.01), 132 | 'bias' : lambda *config: AddChannelwiseBias(out_channels=config[0], lr=1.0), 133 | 'amp' : lambda *config: Amplify(2**0.5), 134 | 'Lrelu' : lambda *config: nn.LeakyReLU(negative_slope=0.2), 135 | 'EqConvT3x3EX' : lambda *config: EqualizedFusedConvTransposed2d( 136 | in_channels=config[0], out_channels=config[1], 137 | kernel_size=3, stride=1, padding=1, lr=1.0), 138 | 'EqConv3x3' : lambda *config: EqualizedConv2d( 139 | in_channels=config[0], out_channels=config[1], 140 | kernel_size=3, stride=1, padding=1, lr=1.0), 141 | 'EqConv1x1' : lambda *config: EqualizedConv2d( 142 | in_channels=config[0], out_channels=config[1], 143 | kernel_size=1, stride=1, padding=0, lr=1.0), 144 | 'noiseE' : lambda *config: ElementwiseNoise(ch=config[0], size_hw=config[1]), 145 | 'blur3x3' : lambda *config: Blur3x3(), 146 | 'up' : lambda *config: nn.Upsample( 147 | scale_factor=2,mode='nearest'), 148 | 'adain' : lambda *config: EqualizedAdaIN( 149 | fmap_ch=config[0], style_ch=512, lr=1.0), 150 | } 151 | return nn.Sequential(*[ definition[k](*cfg) for k,*cfg in self.structure[key]]) 152 | 153 | def __init__(self): 154 | super().__init__() 155 | 156 | # 固定入力値 157 | self.register_buffer('const',torch.ones((1, 512, 4, 4),dtype=torch.float32)) 158 | 159 | # 今回は使わない 160 | self.register_buffer('image_mixing_rate',torch.zeros((1,))) # 複数のtoRGBの合成比率 161 | self.register_buffer('style_mixing_rate',torch.zeros((1,))) # スタイルの合成比率 162 | 163 | # 潜在変数のマッピングネットワーク 164 | self.mapping = self._make_sequential('mapping') 165 | self.blocks = nn.ModuleList([self._make_sequential(k) for k in [ 166 | 'START', 'Fconv4', 'Uconv8', 'Fconv8', 'Uconv16', 'Fconv16', 167 | 'Uconv32', 'Fconv32', 'Uconv64', 'Fconv64', 'Uconv128', 'Fconv128', 168 | 'Uconv256', 'Fconv256', 'Uconv512', 'Fconv512'#, 'Uconv1024','Fconv1024' 169 | ] ]) 170 | self.adains = 
nn.ModuleList([self._make_sequential(k) for k in [ 171 | 'adain4a', 'adain4b', 'adain8a', 'adain8b', 172 | 'adain16a', 'adain16b', 'adain32a', 'adain32b', 173 | 'adain64a', 'adain64b', 'adain128a', 'adain128b', 174 | 'adain256a', 'adain256b', 'adain512a', 'adain512b', 175 | #'adain1024a', 'adain1024b' 176 | ] ]) 177 | self.toRGBs = nn.ModuleList([self._make_sequential(k) for k in [ 178 | 'toRGB_4', 'toRGB_8', 'toRGB_16', 'toRGB_32', 179 | 'toRGB_64', 'toRGB_128', 'toRGB_256', 'toRGB_512', 180 | #'toRGB_1024' 181 | ] ]) 182 | 183 | def forward(self, z): 184 | ''' 185 | input: z : (N,D) D=512 186 | output: img : (N,3,1024,1024) 187 | ''' 188 | N,D = z.shape 189 | 190 | styles = self.mapping(z) # (N,18,D) 191 | tmp = self.const.expand(N,512,4,4) 192 | for i, (adain, conv) in enumerate(zip(self.adains, self.blocks)): 193 | tmp = conv(tmp) 194 | tmp = adain( (tmp, styles[:,i,:]) ) 195 | img = self.toRGBs[-1](tmp) 196 | 197 | return img 198 | 199 | 200 | ########## 以下,重み変換 ######## 201 | 202 | name_trans_dict = { 203 | 'const' : ['any', 'G_synthesis/4x4/Const/const' ], 204 | 'image_mixing_rate' : ['uns', 'G_synthesis/lod' ], 205 | 'style_mixing_rate' : ['uns', 'lod' ], 206 | 'mapping.1.weight' : ['fc_', 'G_mapping/Dense0/weight' ], 207 | 'mapping.3.bias' : ['any', 'G_mapping/Dense0/bias' ], 208 | 'mapping.5.weight' : ['fc_', 'G_mapping/Dense1/weight' ], 209 | 'mapping.7.bias' : ['any', 'G_mapping/Dense1/bias' ], 210 | 'mapping.9.weight' : ['fc_', 'G_mapping/Dense2/weight' ], 211 | 'mapping.11.bias' : ['any', 'G_mapping/Dense2/bias' ], 212 | 'mapping.13.weight' : ['fc_', 'G_mapping/Dense3/weight' ], 213 | 'mapping.15.bias' : ['any', 'G_mapping/Dense3/bias' ], 214 | 'mapping.17.weight' : ['fc_', 'G_mapping/Dense4/weight' ], 215 | 'mapping.19.bias' : ['any', 'G_mapping/Dense4/bias' ], 216 | 'mapping.21.weight' : ['fc_', 'G_mapping/Dense5/weight' ], 217 | 'mapping.23.bias' : ['any', 'G_mapping/Dense5/bias' ], 218 | 'mapping.25.weight' : ['fc_', 'G_mapping/Dense6/weight' ], 219 | 'mapping.27.bias' : ['any', 'G_mapping/Dense6/bias' ], 220 | 'mapping.29.weight' : ['fc_', 'G_mapping/Dense7/weight' ], 221 | 'mapping.31.bias' : ['any', 'G_mapping/Dense7/bias' ], 222 | 'mapping.33.avg_style' : ['any', 'dlatent_avg' ], 223 | 'blocks.0.0.noise_scaler' : ['any', 'G_synthesis/4x4/Const/Noise/weight' ], 224 | 'blocks.0.0.const_noise' : ['any', 'G_synthesis/noise0' ], 225 | 'blocks.0.1.bias' : ['any', 'G_synthesis/4x4/Const/bias' ], 226 | 'blocks.1.0.weight' : ['con', 'G_synthesis/4x4/Conv/weight' ], 227 | 'blocks.1.2.noise_scaler' : ['any', 'G_synthesis/4x4/Conv/Noise/weight' ], 228 | 'blocks.1.2.const_noise' : ['any', 'G_synthesis/noise1' ], 229 | 'blocks.1.3.bias' : ['any', 'G_synthesis/4x4/Conv/bias' ], 230 | 'blocks.2.1.weight' : ['con', 'G_synthesis/8x8/Conv0_up/weight' ], 231 | 'blocks.2.4.noise_scaler' : ['any', 'G_synthesis/8x8/Conv0_up/Noise/weight' ], 232 | 'blocks.2.4.const_noise' : ['any', 'G_synthesis/noise2' ], 233 | 'blocks.2.5.bias' : ['any', 'G_synthesis/8x8/Conv0_up/bias' ], 234 | 'blocks.3.0.weight' : ['con', 'G_synthesis/8x8/Conv1/weight' ], 235 | 'blocks.3.2.noise_scaler' : ['any', 'G_synthesis/8x8/Conv1/Noise/weight' ], 236 | 'blocks.3.2.const_noise' : ['any', 'G_synthesis/noise3' ], 237 | 'blocks.3.3.bias' : ['any', 'G_synthesis/8x8/Conv1/bias' ], 238 | 'blocks.4.1.weight' : ['con', 'G_synthesis/16x16/Conv0_up/weight' ], 239 | 'blocks.4.4.noise_scaler' : ['any', 'G_synthesis/16x16/Conv0_up/Noise/weight' ], 240 | 'blocks.4.4.const_noise' : ['any', 'G_synthesis/noise4' ], 241 | 
'blocks.4.5.bias' : ['any', 'G_synthesis/16x16/Conv0_up/bias' ], 242 | 'blocks.5.0.weight' : ['con', 'G_synthesis/16x16/Conv1/weight' ], 243 | 'blocks.5.2.noise_scaler' : ['any', 'G_synthesis/16x16/Conv1/Noise/weight' ], 244 | 'blocks.5.2.const_noise' : ['any', 'G_synthesis/noise5' ], 245 | 'blocks.5.3.bias' : ['any', 'G_synthesis/16x16/Conv1/bias' ], 246 | 'blocks.6.1.weight' : ['con', 'G_synthesis/32x32/Conv0_up/weight' ], 247 | 'blocks.6.4.noise_scaler' : ['any', 'G_synthesis/32x32/Conv0_up/Noise/weight' ], 248 | 'blocks.6.4.const_noise' : ['any', 'G_synthesis/noise6' ], 249 | 'blocks.6.5.bias' : ['any', 'G_synthesis/32x32/Conv0_up/bias' ], 250 | 'blocks.7.0.weight' : ['con', 'G_synthesis/32x32/Conv1/weight' ], 251 | 'blocks.7.2.noise_scaler' : ['any', 'G_synthesis/32x32/Conv1/Noise/weight' ], 252 | 'blocks.7.2.const_noise' : ['any', 'G_synthesis/noise7' ], 253 | 'blocks.7.3.bias' : ['any', 'G_synthesis/32x32/Conv1/bias' ], 254 | 'blocks.8.1.weight' : ['con', 'G_synthesis/64x64/Conv0_up/weight' ], 255 | 'blocks.8.4.noise_scaler' : ['any', 'G_synthesis/64x64/Conv0_up/Noise/weight' ], 256 | 'blocks.8.4.const_noise' : ['any', 'G_synthesis/noise8' ], 257 | 'blocks.8.5.bias' : ['any', 'G_synthesis/64x64/Conv0_up/bias' ], 258 | 'blocks.9.0.weight' : ['con', 'G_synthesis/64x64/Conv1/weight' ], 259 | 'blocks.9.2.noise_scaler' : ['any', 'G_synthesis/64x64/Conv1/Noise/weight' ], 260 | 'blocks.9.2.const_noise' : ['any', 'G_synthesis/noise9' ], 261 | 'blocks.9.3.bias' : ['any', 'G_synthesis/64x64/Conv1/bias' ], 262 | 'blocks.10.0.weight' : ['Tco', 'G_synthesis/128x128/Conv0_up/weight' ], 263 | 'blocks.10.3.noise_scaler' : ['any', 'G_synthesis/128x128/Conv0_up/Noise/weight' ], 264 | 'blocks.10.3.const_noise' : ['any', 'G_synthesis/noise10' ], 265 | 'blocks.10.4.bias' : ['any', 'G_synthesis/128x128/Conv0_up/bias' ], 266 | 'blocks.11.0.weight' : ['con', 'G_synthesis/128x128/Conv1/weight' ], 267 | 'blocks.11.2.noise_scaler' : ['any', 'G_synthesis/128x128/Conv1/Noise/weight' ], 268 | 'blocks.11.2.const_noise' : ['any', 'G_synthesis/noise11' ], 269 | 'blocks.11.3.bias' : ['any', 'G_synthesis/128x128/Conv1/bias' ], 270 | 'blocks.12.0.weight' : ['Tco', 'G_synthesis/256x256/Conv0_up/weight' ], 271 | 'blocks.12.3.noise_scaler' : ['any', 'G_synthesis/256x256/Conv0_up/Noise/weight' ], 272 | 'blocks.12.3.const_noise' : ['any', 'G_synthesis/noise12' ], 273 | 'blocks.12.4.bias' : ['any', 'G_synthesis/256x256/Conv0_up/bias' ], 274 | 'blocks.13.0.weight' : ['con', 'G_synthesis/256x256/Conv1/weight' ], 275 | 'blocks.13.2.noise_scaler' : ['any', 'G_synthesis/256x256/Conv1/Noise/weight' ], 276 | 'blocks.13.2.const_noise' : ['any', 'G_synthesis/noise13' ], 277 | 'blocks.13.3.bias' : ['any', 'G_synthesis/256x256/Conv1/bias' ], 278 | 'blocks.14.0.weight' : ['Tco', 'G_synthesis/512x512/Conv0_up/weight' ], 279 | 'blocks.14.3.noise_scaler' : ['any', 'G_synthesis/512x512/Conv0_up/Noise/weight' ], 280 | 'blocks.14.3.const_noise' : ['any', 'G_synthesis/noise14' ], 281 | 'blocks.14.4.bias' : ['any', 'G_synthesis/512x512/Conv0_up/bias' ], 282 | 'blocks.15.0.weight' : ['con', 'G_synthesis/512x512/Conv1/weight' ], 283 | 'blocks.15.2.noise_scaler' : ['any', 'G_synthesis/512x512/Conv1/Noise/weight' ], 284 | 'blocks.15.2.const_noise' : ['any', 'G_synthesis/noise15' ], 285 | 'blocks.15.3.bias' : ['any', 'G_synthesis/512x512/Conv1/bias' ], 286 | #'blocks.16.0.weight' : ['Tco', 'G_synthesis/1024x1024/Conv0_up/weight' ], 287 | #'blocks.16.3.noise_scaler' : ['any', 'G_synthesis/1024x1024/Conv0_up/Noise/weight' ], 288 | 
#'blocks.16.3.const_noise' : ['any', 'G_synthesis/noise16' ], 289 | #'blocks.16.4.bias' : ['any', 'G_synthesis/1024x1024/Conv0_up/bias' ], 290 | #'blocks.17.0.weight' : ['con', 'G_synthesis/1024x1024/Conv1/weight' ], 291 | #'blocks.17.2.noise_scaler' : ['any', 'G_synthesis/1024x1024/Conv1/Noise/weight' ], 292 | #'blocks.17.2.const_noise' : ['any', 'G_synthesis/noise17' ], 293 | #'blocks.17.3.bias' : ['any', 'G_synthesis/1024x1024/Conv1/bias' ], 294 | 'adains.0.0.fc.weight' : ['fc_', 'G_synthesis/4x4/Const/StyleMod/weight' ], 295 | 'adains.0.0.bias.bias' : ['any', 'G_synthesis/4x4/Const/StyleMod/bias' ], 296 | 'adains.1.0.fc.weight' : ['fc_', 'G_synthesis/4x4/Conv/StyleMod/weight' ], 297 | 'adains.1.0.bias.bias' : ['any', 'G_synthesis/4x4/Conv/StyleMod/bias' ], 298 | 'adains.2.0.fc.weight' : ['fc_', 'G_synthesis/8x8/Conv0_up/StyleMod/weight' ], 299 | 'adains.2.0.bias.bias' : ['any', 'G_synthesis/8x8/Conv0_up/StyleMod/bias' ], 300 | 'adains.3.0.fc.weight' : ['fc_', 'G_synthesis/8x8/Conv1/StyleMod/weight' ], 301 | 'adains.3.0.bias.bias' : ['any', 'G_synthesis/8x8/Conv1/StyleMod/bias' ], 302 | 'adains.4.0.fc.weight' : ['fc_', 'G_synthesis/16x16/Conv0_up/StyleMod/weight' ], 303 | 'adains.4.0.bias.bias' : ['any', 'G_synthesis/16x16/Conv0_up/StyleMod/bias' ], 304 | 'adains.5.0.fc.weight' : ['fc_', 'G_synthesis/16x16/Conv1/StyleMod/weight' ], 305 | 'adains.5.0.bias.bias' : ['any', 'G_synthesis/16x16/Conv1/StyleMod/bias' ], 306 | 'adains.6.0.fc.weight' : ['fc_', 'G_synthesis/32x32/Conv0_up/StyleMod/weight' ], 307 | 'adains.6.0.bias.bias' : ['any', 'G_synthesis/32x32/Conv0_up/StyleMod/bias' ], 308 | 'adains.7.0.fc.weight' : ['fc_', 'G_synthesis/32x32/Conv1/StyleMod/weight' ], 309 | 'adains.7.0.bias.bias' : ['any', 'G_synthesis/32x32/Conv1/StyleMod/bias' ], 310 | 'adains.8.0.fc.weight' : ['fc_', 'G_synthesis/64x64/Conv0_up/StyleMod/weight' ], 311 | 'adains.8.0.bias.bias' : ['any', 'G_synthesis/64x64/Conv0_up/StyleMod/bias' ], 312 | 'adains.9.0.fc.weight' : ['fc_', 'G_synthesis/64x64/Conv1/StyleMod/weight' ], 313 | 'adains.9.0.bias.bias' : ['any', 'G_synthesis/64x64/Conv1/StyleMod/bias' ], 314 | 'adains.10.0.fc.weight' : ['fc_', 'G_synthesis/128x128/Conv0_up/StyleMod/weight' ], 315 | 'adains.10.0.bias.bias' : ['any', 'G_synthesis/128x128/Conv0_up/StyleMod/bias' ], 316 | 'adains.11.0.fc.weight' : ['fc_', 'G_synthesis/128x128/Conv1/StyleMod/weight' ], 317 | 'adains.11.0.bias.bias' : ['any', 'G_synthesis/128x128/Conv1/StyleMod/bias' ], 318 | 'adains.12.0.fc.weight' : ['fc_', 'G_synthesis/256x256/Conv0_up/StyleMod/weight' ], 319 | 'adains.12.0.bias.bias' : ['any', 'G_synthesis/256x256/Conv0_up/StyleMod/bias' ], 320 | 'adains.13.0.fc.weight' : ['fc_', 'G_synthesis/256x256/Conv1/StyleMod/weight' ], 321 | 'adains.13.0.bias.bias' : ['any', 'G_synthesis/256x256/Conv1/StyleMod/bias' ], 322 | 'adains.14.0.fc.weight' : ['fc_', 'G_synthesis/512x512/Conv0_up/StyleMod/weight' ], 323 | 'adains.14.0.bias.bias' : ['any', 'G_synthesis/512x512/Conv0_up/StyleMod/bias' ], 324 | 'adains.15.0.fc.weight' : ['fc_', 'G_synthesis/512x512/Conv1/StyleMod/weight' ], 325 | 'adains.15.0.bias.bias' : ['any', 'G_synthesis/512x512/Conv1/StyleMod/bias' ], 326 | #'adains.16.0.fc.weight' : ['fc_', 'G_synthesis/1024x1024/Conv0_up/StyleMod/weight' ], 327 | #'adains.16.0.bias.bias' : ['any', 'G_synthesis/1024x1024/Conv0_up/StyleMod/bias' ], 328 | #'adains.17.0.fc.weight' : ['fc_', 'G_synthesis/1024x1024/Conv1/StyleMod/weight' ], 329 | #'adains.17.0.bias.bias' : ['any', 'G_synthesis/1024x1024/Conv1/StyleMod/bias' ], 330 | 'toRGBs.0.0.weight' : 
['con', 'G_synthesis/ToRGB_lod7/weight' ],#: ['con', 'G_synthesis/ToRGB_lod8/weight' ], 331 | 'toRGBs.0.1.bias' : ['any', 'G_synthesis/ToRGB_lod7/bias' ],#: ['any', 'G_synthesis/ToRGB_lod8/bias' ], 332 | 'toRGBs.1.0.weight' : ['con', 'G_synthesis/ToRGB_lod6/weight' ],#: ['con', 'G_synthesis/ToRGB_lod7/weight' ], 333 | 'toRGBs.1.1.bias' : ['any', 'G_synthesis/ToRGB_lod6/bias' ],#: ['any', 'G_synthesis/ToRGB_lod7/bias' ], 334 | 'toRGBs.2.0.weight' : ['con', 'G_synthesis/ToRGB_lod5/weight' ],#: ['con', 'G_synthesis/ToRGB_lod6/weight' ], 335 | 'toRGBs.2.1.bias' : ['any', 'G_synthesis/ToRGB_lod5/bias' ],#: ['any', 'G_synthesis/ToRGB_lod6/bias' ], 336 | 'toRGBs.3.0.weight' : ['con', 'G_synthesis/ToRGB_lod4/weight' ],#: ['con', 'G_synthesis/ToRGB_lod5/weight' ], 337 | 'toRGBs.3.1.bias' : ['any', 'G_synthesis/ToRGB_lod4/bias' ],#: ['any', 'G_synthesis/ToRGB_lod5/bias' ], 338 | 'toRGBs.4.0.weight' : ['con', 'G_synthesis/ToRGB_lod3/weight' ],#: ['con', 'G_synthesis/ToRGB_lod4/weight' ], 339 | 'toRGBs.4.1.bias' : ['any', 'G_synthesis/ToRGB_lod3/bias' ],#: ['any', 'G_synthesis/ToRGB_lod4/bias' ], 340 | 'toRGBs.5.0.weight' : ['con', 'G_synthesis/ToRGB_lod2/weight' ],#: ['con', 'G_synthesis/ToRGB_lod3/weight' ], 341 | 'toRGBs.5.1.bias' : ['any', 'G_synthesis/ToRGB_lod2/bias' ],#: ['any', 'G_synthesis/ToRGB_lod3/bias' ], 342 | 'toRGBs.6.0.weight' : ['con', 'G_synthesis/ToRGB_lod1/weight' ],#: ['con', 'G_synthesis/ToRGB_lod2/weight' ], 343 | 'toRGBs.6.1.bias' : ['any', 'G_synthesis/ToRGB_lod1/bias' ],#: ['any', 'G_synthesis/ToRGB_lod2/bias' ], 344 | 'toRGBs.7.0.weight' : ['con', 'G_synthesis/ToRGB_lod0/weight' ],#: ['con', 'G_synthesis/ToRGB_lod1/weight' ], 345 | 'toRGBs.7.1.bias' : ['any', 'G_synthesis/ToRGB_lod0/bias' ],#: ['any', 'G_synthesis/ToRGB_lod1/bias' ], 346 | #'toRGBs.8.0.weight' : ['con', 'G_synthesis/ToRGB_lod/weight' ],#: ['con', 'G_synthesis/ToRGB_lod0/weight' ], 347 | #'toRGBs.8.1.bias' : ['any', 'G_synthesis/ToRGB_lod/bias' ],#: ['any', 'G_synthesis/ToRGB_lod0/bias' ], 348 | } 349 | -------------------------------------------------------------------------------- /waifu/stylegan2.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from common import PixelwiseNormalization, Amplify, AddChannelwiseBias, EqualizedFullyConnect, TruncationTrick 8 | 9 | 10 | # 固定ノイズ 11 | class PixelwiseNoise(nn.Module): 12 | def __init__(self, resolution): 13 | super().__init__() 14 | self.register_buffer("const_noise", torch.randn((1, 1, resolution, resolution))) 15 | self.noise_scaler = nn.Parameter(torch.zeros(1)) 16 | 17 | def forward(self, x): 18 | N,C,H,W = x.shape 19 | noise = self.const_noise.expand(N,C,H,W) 20 | return x + noise * self.noise_scaler 21 | 22 | 23 | # 解像度上げるときのModConvのみ 24 | class FusedBlur3x3(nn.Module): 25 | def __init__(self): 26 | super().__init__() 27 | kernel = np.array([ [1/16, 2/16, 1/16], 28 | [2/16, 4/16, 2/16], 29 | [1/16, 2/16, 1/16]],dtype=np.float32) 30 | pads = [[(0,1),(0,1)],[(0,1),(1,0)],[(1,0),(0,1)],[(1,0),(1,0)]] 31 | kernel = np.stack( [np.pad(kernel,pad,'constant') for pad in pads] ).sum(0) 32 | #kernel [ [1/16, 3/16, 3/16, 1/16,], 33 | # [3/16, 9/16, 9/16, 3/16,], 34 | # [3/16, 9/16, 9/16, 3/16,], 35 | # [1/16, 3/16, 3/16, 1/16,] ] 36 | self.kernel = torch.from_numpy(kernel) 37 | 38 | def forward(self, feature): 39 | # featureは(N,C,H+1,W+1) 40 | kernel = self.kernel.clone().to(feature.device) 41 | 
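# (note) The stride-2 transposed conv upstream produces one extra row/column,
# so the feature here is (N,C,H+1,W+1); the grouped conv below with this 4x4
# kernel and padding=1 yields (H+1) + 2*1 - 4 + 1 = H, i.e. the blur also trims
# the output back to the target resolution.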
_N,C,_Hp1,_Wp1 = feature.shape 42 | return F.conv2d(feature, kernel.expand(C,1,4,4), padding=1, groups=C) 43 | 44 | 45 | # 学習率を調整した変調転置畳み込み 46 | class EqualizedModulatedConvTranspose2d(nn.Module): 47 | def __init__(self, in_channels, out_channels, kernel_size, style_dim, padding, stride, demodulate=True, lr=1): 48 | super().__init__() 49 | 50 | self.padding, self.stride = padding, stride 51 | self.demodulate = demodulate 52 | 53 | self.weight = nn.Parameter( torch.randn(in_channels, out_channels, kernel_size, kernel_size)) 54 | torch.nn.init.normal_(self.weight.data, mean=0.0, std=1.0/lr) 55 | self.weight_scaler = 1 / (in_channels * kernel_size*kernel_size)**0.5 * lr 56 | 57 | self.fc = EqualizedFullyConnect(style_dim, in_channels, lr) 58 | self.bias = AddChannelwiseBias(in_channels, lr) 59 | 60 | def forward(self, pack): 61 | x, style = pack 62 | N, iC, H, W = x.shape 63 | iC, oC, kH, kW = self.weight.shape 64 | 65 | mod_rates = self.bias(self.fc(style))+1 # (N, iC) 66 | modulated_weight = self.weight_scaler*self.weight.view(1,iC,oC,kH,kW) \ 67 | * mod_rates.view(N,iC,1,1,1) # (N,iC,oC,kH,kW) 68 | 69 | if self.demodulate: 70 | demod_norm = 1 / ((modulated_weight**2).sum([1,3,4]) + 1e-8)**0.5 # (N, oC) 71 | weight = modulated_weight * demod_norm.view(N, 1, oC, 1, 1) # (N,iC,oC,kH,kW) 72 | else: 73 | weight = modulated_weight 74 | 75 | x = x.view(1, N*iC, H, W) 76 | weight = weight.view(N*iC,oC,kH,kW) 77 | out = F.conv_transpose2d(x, weight, padding=self.padding, stride=self.stride, groups=N) 78 | 79 | _, _, Hp1, Wp1 = out.shape 80 | out = out.view(N, oC, Hp1, Wp1) 81 | 82 | return out 83 | 84 | 85 | # 学習率を調整した変調畳み込み 86 | class EqualizedModulatedConv2d(nn.Module): 87 | def __init__(self, in_channels, out_channels, kernel_size, style_dim, padding, stride, demodulate=True, lr=1): 88 | super().__init__() 89 | 90 | self.padding, self.stride = padding, stride 91 | self.demodulate = demodulate 92 | 93 | self.weight = nn.Parameter( torch.randn(out_channels, in_channels, kernel_size, kernel_size)) 94 | torch.nn.init.normal_(self.weight.data, mean=0.0, std=1.0/lr) 95 | self.weight_scaler = 1 / (in_channels*kernel_size*kernel_size)**0.5 * lr 96 | 97 | self.fc = EqualizedFullyConnect(style_dim, in_channels, lr) 98 | self.bias = AddChannelwiseBias(in_channels, lr) 99 | 100 | def forward(self, pack): 101 | x, style = pack 102 | N, iC, H, W = x.shape 103 | oC, iC, kH, kW = self.weight.shape 104 | 105 | mod_rates = self.bias(self.fc(style))+1 # (N, iC) 106 | modulated_weight = self.weight_scaler*self.weight.view(1,oC,iC,kH,kW) \ 107 | * mod_rates.view(N,1,iC,1,1) # (N,oC,iC,kH,kW) 108 | 109 | if self.demodulate: 110 | demod_norm = 1 / ((modulated_weight**2).sum([2,3,4]) + 1e-8)**0.5 # (N, oC) 111 | weight = modulated_weight * demod_norm.view(N, oC, 1, 1, 1) # (N,oC,iC,kH,kW) 112 | else: # ToRGB 113 | weight = modulated_weight 114 | 115 | out = F.conv2d(x.view(1,N*iC,H,W), weight.view(N*oC,iC,kH,kW), 116 | padding=self.padding, stride=self.stride, groups=N).view(N,oC,H,W) 117 | return out 118 | 119 | 120 | class Generator(nn.Module): 121 | 122 | structure = { 123 | 'mapping': [['pixel_norm'], ['fc',512,512],['b',512],['amp'],['Lrelu'],['fc',512,512],['b',512],['amp'],['Lrelu'], 124 | ['fc',512,512],['b',512],['amp'],['Lrelu'],['fc',512,512],['b',512],['amp'],['Lrelu'], 125 | ['fc',512,512],['b',512],['amp'],['Lrelu'],['fc',512,512],['b',512],['amp'],['Lrelu'], 126 | ['fc',512,512],['b',512],['amp'],['Lrelu'],['fc',512,512],['b',512],['amp'],['Lrelu'],['truncation']], 127 | 'Fconv4' : [['EqModConv3x3', 512, 
512], ['noiseP', 4], ['bias',512], ['amp'], ['Lrelu'] ], 'toRGB_4' : [['EqModConv1x1',512, 3], ['bias',3]], 128 | 'Uconv8' : [['EqModConvT3x3', 512, 512], ['blurEX'], ['noiseP', 8], ['bias',512], ['amp'], ['Lrelu'] ], 129 | 'Fconv8' : [['EqModConv3x3', 512, 512], ['noiseP', 8], ['bias',512], ['amp'], ['Lrelu'] ], 'toRGB_8' : [['EqModConv1x1',512, 3], ['bias',3]], 130 | 'Uconv16' : [['EqModConvT3x3', 512, 512], ['blurEX'], ['noiseP', 16], ['bias',512], ['amp'], ['Lrelu'] ], 131 | 'Fconv16' : [['EqModConv3x3', 512, 512], ['noiseP', 16], ['bias',512], ['amp'], ['Lrelu'] ], 'toRGB_16' : [['EqModConv1x1',512, 3], ['bias',3]], 132 | 'Uconv32' : [['EqModConvT3x3', 512, 512], ['blurEX'], ['noiseP', 32], ['bias',512], ['amp'], ['Lrelu'] ], 133 | 'Fconv32' : [['EqModConv3x3', 512, 512], ['noiseP', 32], ['bias',512], ['amp'], ['Lrelu'] ], 'toRGB_32' : [['EqModConv1x1',512, 3], ['bias',3]], 134 | 'Uconv64' : [['EqModConvT3x3', 512, 512], ['blurEX'], ['noiseP', 64], ['bias',512], ['amp'], ['Lrelu'] ], 135 | 'Fconv64' : [['EqModConv3x3', 512, 512], ['noiseP', 64], ['bias',512], ['amp'], ['Lrelu'] ], 'toRGB_64' : [['EqModConv1x1',512, 3], ['bias',3]], 136 | 'Uconv128' : [['EqModConvT3x3', 512, 256], ['blurEX'], ['noiseP', 128], ['bias',256], ['amp'], ['Lrelu'] ], 137 | 'Fconv128' : [['EqModConv3x3', 256, 256], ['noiseP', 128], ['bias',256], ['amp'], ['Lrelu'] ], 'toRGB_128' : [['EqModConv1x1',256, 3], ['bias',3]], 138 | 'Uconv256' : [['EqModConvT3x3', 256, 128], ['blurEX'], ['noiseP', 256], ['bias',128], ['amp'], ['Lrelu'] ], 139 | 'Fconv256' : [['EqModConv3x3', 128, 128], ['noiseP', 256], ['bias',128], ['amp'], ['Lrelu'] ], 'toRGB_256' : [['EqModConv1x1',128, 3], ['bias',3]], 140 | 'Uconv512' : [['EqModConvT3x3', 128, 64], ['blurEX'], ['noiseP', 512], ['bias', 64], ['amp'], ['Lrelu'] ], 141 | 'Fconv512' : [['EqModConv3x3', 64, 64], ['noiseP', 512], ['bias', 64], ['amp'], ['Lrelu'] ], 'toRGB_512' : [['EqModConv1x1', 64, 3], ['bias',3]], 142 | #'Uconv1024': [['EqModConvT3x3', 64, 32], ['blurEX'], ['noiseP',1024], ['bias', 32], ['amp'], ['Lrelu'] ], 143 | #'Fconv1024': [['EqModConv3x3', 32, 32], ['noiseP',1024], ['bias', 32], ['amp'], ['Lrelu'] ], 'toRGB_1024': [['EqModConv1x1', 32, 3], ['bias',3]], 144 | } 145 | 146 | def _make_sequential(self,key): 147 | definition = { 148 | 'pixel_norm' : lambda *config: PixelwiseNormalization(), 149 | 'truncation' : lambda *config: TruncationTrick( 150 | num_target=10, threshold=0.7, output_num=18, style_dim=512), 151 | 'fc' : lambda *config: EqualizedFullyConnect( 152 | in_dim=config[0], out_dim=config[1], lr=0.01), 153 | 'b' : lambda *config: AddChannelwiseBias(out_channels=config[0], lr=0.01), 154 | 'bias' : lambda *config: AddChannelwiseBias(out_channels=config[0], lr=1.0), 155 | 'amp' : lambda *config: Amplify(2**0.5), 156 | 'Lrelu' : lambda *config: nn.LeakyReLU(negative_slope=0.2), 157 | 'EqModConvT3x3': lambda *config: EqualizedModulatedConvTranspose2d( 158 | in_channels=config[0], out_channels=config[1], 159 | kernel_size=3, stride=2, padding=0, 160 | demodulate=True, lr=1.0, style_dim=512), 161 | 'EqModConv3x3' : lambda *config: EqualizedModulatedConv2d( 162 | in_channels=config[0], out_channels=config[1], 163 | kernel_size=3, stride=1, padding=1, 164 | demodulate=True, lr=1.0, style_dim=512), 165 | 'EqModConv1x1' : lambda *config: EqualizedModulatedConv2d( 166 | in_channels=config[0], out_channels=config[1], 167 | kernel_size=1, stride=1, padding=0, 168 | demodulate=False, lr=1.0, style_dim=512), 169 | 'noiseP' : lambda *config: 
PixelwiseNoise(resolution=config[0]), 170 | 'blurEX' : lambda *config: FusedBlur3x3(), 171 | } 172 | return nn.Sequential(*[ definition[k](*cfg) for k,*cfg in self.structure[key]]) 173 | 174 | 175 | def __init__(self): 176 | super().__init__() 177 | 178 | self.const_input = nn.Parameter(torch.randn(1, 512, 4, 4)) 179 | self.register_buffer('style_mixing_rate',torch.zeros((1,))) # スタイルの合成比率,今回は使わない 180 | 181 | self.mapping = self._make_sequential('mapping') 182 | self.blocks = nn.ModuleList([self._make_sequential(k) for k in [ 183 | 'Fconv4', 'Uconv8', 'Fconv8', 'Uconv16', 'Fconv16', 184 | 'Uconv32', 'Fconv32', 'Uconv64', 'Fconv64', 'Uconv128', 'Fconv128', 185 | 'Uconv256', 'Fconv256', 'Uconv512', 'Fconv512'#, 'Uconv1024','Fconv1024' 186 | ] ]) 187 | self.toRGBs = nn.ModuleList([self._make_sequential(k) for k in [ 188 | 'toRGB_4', 'toRGB_8', 'toRGB_16', 'toRGB_32', 189 | 'toRGB_64', 'toRGB_128', 'toRGB_256', 'toRGB_512', 190 | #'toRGB_1024' 191 | ] ]) 192 | 193 | 194 | def forward(self, z): 195 | N,D = z.shape 196 | 197 | # 潜在変数からスタイルへ変換 198 | styles = self.mapping(z) # (N,18,D) 199 | styles = [styles[:,i] for i in range(18)] # list[(N,D),]x18 200 | 201 | tmp = self.const_input.repeat(N, 1, 1, 1) 202 | tmp = self.blocks[0]( (tmp,styles[0]) ) 203 | skip = self.toRGBs[0]( (tmp,styles[1]) ) 204 | 205 | for convU, convF, toRGB, styU,styF,styT in zip( \ 206 | self.blocks[1::2], self.blocks[2::2], self.toRGBs[1:], 207 | styles[1::2], styles[2::2], styles[3::2]): 208 | tmp = convU( (tmp,styU) ) 209 | tmp = convF( (tmp,styF) ) 210 | skip = toRGB( (tmp,styT) ) + F.interpolate(skip,scale_factor=2,mode='bilinear',align_corners=False) 211 | 212 | return skip 213 | 214 | 215 | # { pytorchでの名前 : [変換関数, tensorflowでの名前] } 216 | name_trans_dict = { 217 | 'const_input' : ['any', 'G_synthesis/4x4/Const/const' ], 218 | 'style_mixing_rate' : ['uns', 'lod' ], 219 | 'mapping.1.weight' : ['fc_', 'G_mapping/Dense0/weight' ], 220 | 'mapping.2.bias' : ['any', 'G_mapping/Dense0/bias' ], 221 | 'mapping.5.weight' : ['fc_', 'G_mapping/Dense1/weight' ], 222 | 'mapping.6.bias' : ['any', 'G_mapping/Dense1/bias' ], 223 | 'mapping.9.weight' : ['fc_', 'G_mapping/Dense2/weight' ], 224 | 'mapping.10.bias' : ['any', 'G_mapping/Dense2/bias' ], 225 | 'mapping.13.weight' : ['fc_', 'G_mapping/Dense3/weight' ], 226 | 'mapping.14.bias' : ['any', 'G_mapping/Dense3/bias' ], 227 | 'mapping.17.weight' : ['fc_', 'G_mapping/Dense4/weight' ], 228 | 'mapping.18.bias' : ['any', 'G_mapping/Dense4/bias' ], 229 | 'mapping.21.weight' : ['fc_', 'G_mapping/Dense5/weight' ], 230 | 'mapping.22.bias' : ['any', 'G_mapping/Dense5/bias' ], 231 | 'mapping.25.weight' : ['fc_', 'G_mapping/Dense6/weight' ], 232 | 'mapping.26.bias' : ['any', 'G_mapping/Dense6/bias' ], 233 | 'mapping.29.weight' : ['fc_', 'G_mapping/Dense7/weight' ], 234 | 'mapping.30.bias' : ['any', 'G_mapping/Dense7/bias' ], 235 | 'mapping.33.avg_style' : ['any', 'dlatent_avg' ], 236 | 'blocks.0.0.weight' : ['con', 'G_synthesis/4x4/Conv/weight' ], 237 | 'blocks.0.0.fc.weight' : ['fc_', 'G_synthesis/4x4/Conv/mod_weight' ], 238 | 'blocks.0.0.bias.bias' : ['any', 'G_synthesis/4x4/Conv/mod_bias' ], 239 | 'blocks.0.1.noise_scaler' : ['uns', 'G_synthesis/4x4/Conv/noise_strength' ], 240 | 'blocks.0.1.const_noise' : ['any', 'G_synthesis/noise0' ], 241 | 'blocks.0.2.bias' : ['any', 'G_synthesis/4x4/Conv/bias' ], 242 | 'blocks.1.0.weight' : ['mTc', 'G_synthesis/8x8/Conv0_up/weight' ], 243 | 'blocks.1.0.fc.weight' : ['fc_', 'G_synthesis/8x8/Conv0_up/mod_weight' ], 244 | 'blocks.1.0.bias.bias' : ['any', 
245 |     'blocks.1.2.noise_scaler'  : ['uns', 'G_synthesis/8x8/Conv0_up/noise_strength' ],
246 |     'blocks.1.2.const_noise'   : ['any', 'G_synthesis/noise1' ],
247 |     'blocks.1.3.bias'          : ['any', 'G_synthesis/8x8/Conv0_up/bias' ],
248 |     'blocks.2.0.weight'        : ['con', 'G_synthesis/8x8/Conv1/weight' ],
249 |     'blocks.2.0.fc.weight'     : ['fc_', 'G_synthesis/8x8/Conv1/mod_weight' ],
250 |     'blocks.2.0.bias.bias'     : ['any', 'G_synthesis/8x8/Conv1/mod_bias' ],
251 |     'blocks.2.1.noise_scaler'  : ['uns', 'G_synthesis/8x8/Conv1/noise_strength' ],
252 |     'blocks.2.1.const_noise'   : ['any', 'G_synthesis/noise2' ],
253 |     'blocks.2.2.bias'          : ['any', 'G_synthesis/8x8/Conv1/bias' ],
254 |     'blocks.3.0.weight'        : ['mTc', 'G_synthesis/16x16/Conv0_up/weight' ],
255 |     'blocks.3.0.fc.weight'     : ['fc_', 'G_synthesis/16x16/Conv0_up/mod_weight' ],
256 |     'blocks.3.0.bias.bias'     : ['any', 'G_synthesis/16x16/Conv0_up/mod_bias' ],
257 |     'blocks.3.2.noise_scaler'  : ['uns', 'G_synthesis/16x16/Conv0_up/noise_strength' ],
258 |     'blocks.3.2.const_noise'   : ['any', 'G_synthesis/noise3' ],
259 |     'blocks.3.3.bias'          : ['any', 'G_synthesis/16x16/Conv0_up/bias' ],
260 |     'blocks.4.0.weight'        : ['con', 'G_synthesis/16x16/Conv1/weight' ],
261 |     'blocks.4.0.fc.weight'     : ['fc_', 'G_synthesis/16x16/Conv1/mod_weight' ],
262 |     'blocks.4.0.bias.bias'     : ['any', 'G_synthesis/16x16/Conv1/mod_bias' ],
263 |     'blocks.4.1.noise_scaler'  : ['uns', 'G_synthesis/16x16/Conv1/noise_strength' ],
264 |     'blocks.4.1.const_noise'   : ['any', 'G_synthesis/noise4' ],
265 |     'blocks.4.2.bias'          : ['any', 'G_synthesis/16x16/Conv1/bias' ],
266 |     'blocks.5.0.weight'        : ['mTc', 'G_synthesis/32x32/Conv0_up/weight' ],
267 |     'blocks.5.0.fc.weight'     : ['fc_', 'G_synthesis/32x32/Conv0_up/mod_weight' ],
268 |     'blocks.5.0.bias.bias'     : ['any', 'G_synthesis/32x32/Conv0_up/mod_bias' ],
269 |     'blocks.5.2.noise_scaler'  : ['uns', 'G_synthesis/32x32/Conv0_up/noise_strength' ],
270 |     'blocks.5.2.const_noise'   : ['any', 'G_synthesis/noise5' ],
271 |     'blocks.5.3.bias'          : ['any', 'G_synthesis/32x32/Conv0_up/bias' ],
272 |     'blocks.6.0.weight'        : ['con', 'G_synthesis/32x32/Conv1/weight' ],
273 |     'blocks.6.0.fc.weight'     : ['fc_', 'G_synthesis/32x32/Conv1/mod_weight' ],
274 |     'blocks.6.0.bias.bias'     : ['any', 'G_synthesis/32x32/Conv1/mod_bias' ],
275 |     'blocks.6.1.noise_scaler'  : ['uns', 'G_synthesis/32x32/Conv1/noise_strength' ],
276 |     'blocks.6.1.const_noise'   : ['any', 'G_synthesis/noise6' ],
277 |     'blocks.6.2.bias'          : ['any', 'G_synthesis/32x32/Conv1/bias' ],
278 |     'blocks.7.0.weight'        : ['mTc', 'G_synthesis/64x64/Conv0_up/weight' ],
279 |     'blocks.7.0.fc.weight'     : ['fc_', 'G_synthesis/64x64/Conv0_up/mod_weight' ],
280 |     'blocks.7.0.bias.bias'     : ['any', 'G_synthesis/64x64/Conv0_up/mod_bias' ],
281 |     'blocks.7.2.noise_scaler'  : ['uns', 'G_synthesis/64x64/Conv0_up/noise_strength' ],
282 |     'blocks.7.2.const_noise'   : ['any', 'G_synthesis/noise7' ],
283 |     'blocks.7.3.bias'          : ['any', 'G_synthesis/64x64/Conv0_up/bias' ],
284 |     'blocks.8.0.weight'        : ['con', 'G_synthesis/64x64/Conv1/weight' ],
285 |     'blocks.8.0.fc.weight'     : ['fc_', 'G_synthesis/64x64/Conv1/mod_weight' ],
286 |     'blocks.8.0.bias.bias'     : ['any', 'G_synthesis/64x64/Conv1/mod_bias' ],
287 |     'blocks.8.1.noise_scaler'  : ['uns', 'G_synthesis/64x64/Conv1/noise_strength' ],
288 |     'blocks.8.1.const_noise'   : ['any', 'G_synthesis/noise8' ],
289 |     'blocks.8.2.bias'          : ['any', 'G_synthesis/64x64/Conv1/bias' ],
290 |     'blocks.9.0.weight'        : ['mTc', 'G_synthesis/128x128/Conv0_up/weight' ],
291 |     'blocks.9.0.fc.weight'     : ['fc_', 'G_synthesis/128x128/Conv0_up/mod_weight' ],
292 |     'blocks.9.0.bias.bias'     : ['any', 'G_synthesis/128x128/Conv0_up/mod_bias' ],
293 |     'blocks.9.2.noise_scaler'  : ['uns', 'G_synthesis/128x128/Conv0_up/noise_strength' ],
294 |     'blocks.9.2.const_noise'   : ['any', 'G_synthesis/noise9' ],
295 |     'blocks.9.3.bias'          : ['any', 'G_synthesis/128x128/Conv0_up/bias' ],
296 |     'blocks.10.0.weight'       : ['con', 'G_synthesis/128x128/Conv1/weight' ],
297 |     'blocks.10.0.fc.weight'    : ['fc_', 'G_synthesis/128x128/Conv1/mod_weight' ],
298 |     'blocks.10.0.bias.bias'    : ['any', 'G_synthesis/128x128/Conv1/mod_bias' ],
299 |     'blocks.10.1.noise_scaler' : ['uns', 'G_synthesis/128x128/Conv1/noise_strength' ],
300 |     'blocks.10.1.const_noise'  : ['any', 'G_synthesis/noise10' ],
301 |     'blocks.10.2.bias'         : ['any', 'G_synthesis/128x128/Conv1/bias' ],
302 |     'blocks.11.0.weight'       : ['mTc', 'G_synthesis/256x256/Conv0_up/weight' ],
303 |     'blocks.11.0.fc.weight'    : ['fc_', 'G_synthesis/256x256/Conv0_up/mod_weight' ],
304 |     'blocks.11.0.bias.bias'    : ['any', 'G_synthesis/256x256/Conv0_up/mod_bias' ],
305 |     'blocks.11.2.noise_scaler' : ['uns', 'G_synthesis/256x256/Conv0_up/noise_strength' ],
306 |     'blocks.11.2.const_noise'  : ['any', 'G_synthesis/noise11' ],
307 |     'blocks.11.3.bias'         : ['any', 'G_synthesis/256x256/Conv0_up/bias' ],
308 |     'blocks.12.0.weight'       : ['con', 'G_synthesis/256x256/Conv1/weight' ],
309 |     'blocks.12.0.fc.weight'    : ['fc_', 'G_synthesis/256x256/Conv1/mod_weight' ],
310 |     'blocks.12.0.bias.bias'    : ['any', 'G_synthesis/256x256/Conv1/mod_bias' ],
311 |     'blocks.12.1.noise_scaler' : ['uns', 'G_synthesis/256x256/Conv1/noise_strength' ],
312 |     'blocks.12.1.const_noise'  : ['any', 'G_synthesis/noise12' ],
313 |     'blocks.12.2.bias'         : ['any', 'G_synthesis/256x256/Conv1/bias' ],
314 |     'blocks.13.0.weight'       : ['mTc', 'G_synthesis/512x512/Conv0_up/weight' ],
315 |     'blocks.13.0.fc.weight'    : ['fc_', 'G_synthesis/512x512/Conv0_up/mod_weight' ],
316 |     'blocks.13.0.bias.bias'    : ['any', 'G_synthesis/512x512/Conv0_up/mod_bias' ],
317 |     'blocks.13.2.noise_scaler' : ['uns', 'G_synthesis/512x512/Conv0_up/noise_strength' ],
318 |     'blocks.13.2.const_noise'  : ['any', 'G_synthesis/noise13' ],
319 |     'blocks.13.3.bias'         : ['any', 'G_synthesis/512x512/Conv0_up/bias' ],
320 |     'blocks.14.0.weight'       : ['con', 'G_synthesis/512x512/Conv1/weight' ],
321 |     'blocks.14.0.fc.weight'    : ['fc_', 'G_synthesis/512x512/Conv1/mod_weight' ],
322 |     'blocks.14.0.bias.bias'    : ['any', 'G_synthesis/512x512/Conv1/mod_bias' ],
323 |     'blocks.14.1.noise_scaler' : ['uns', 'G_synthesis/512x512/Conv1/noise_strength' ],
324 |     'blocks.14.1.const_noise'  : ['any', 'G_synthesis/noise14' ],
325 |     'blocks.14.2.bias'         : ['any', 'G_synthesis/512x512/Conv1/bias' ],
326 |     #'blocks.15.0.weight'       : ['mTc', 'G_synthesis/1024x1024/Conv0_up/weight' ],
327 |     #'blocks.15.0.fc.weight'    : ['fc_', 'G_synthesis/1024x1024/Conv0_up/mod_weight' ],
328 |     #'blocks.15.0.bias.bias'    : ['any', 'G_synthesis/1024x1024/Conv0_up/mod_bias' ],
329 |     #'blocks.15.2.noise_scaler' : ['uns', 'G_synthesis/1024x1024/Conv0_up/noise_strength'],
330 |     #'blocks.15.2.const_noise'  : ['any', 'G_synthesis/noise15' ],
331 |     #'blocks.15.3.bias'         : ['any', 'G_synthesis/1024x1024/Conv0_up/bias' ],
332 |     #'blocks.16.0.weight'       : ['con', 'G_synthesis/1024x1024/Conv1/weight' ],
333 |     #'blocks.16.0.fc.weight'    : ['fc_', 'G_synthesis/1024x1024/Conv1/mod_weight' ],
334 |     #'blocks.16.0.bias.bias'    : ['any', 'G_synthesis/1024x1024/Conv1/mod_bias' ],
335 |     #'blocks.16.1.noise_scaler' : ['uns', 'G_synthesis/1024x1024/Conv1/noise_strength' ],
336 |     #'blocks.16.1.const_noise'  : ['any', 'G_synthesis/noise16' ],
337 |     #'blocks.16.2.bias'         : ['any', 'G_synthesis/1024x1024/Conv1/bias' ],
338 |     'toRGBs.0.0.weight'        : ['con', 'G_synthesis/4x4/ToRGB/weight' ],
339 |     'toRGBs.0.0.fc.weight'     : ['fc_', 'G_synthesis/4x4/ToRGB/mod_weight' ],
340 |     'toRGBs.0.0.bias.bias'     : ['any', 'G_synthesis/4x4/ToRGB/mod_bias' ],
341 |     'toRGBs.0.1.bias'          : ['any', 'G_synthesis/4x4/ToRGB/bias' ],
342 |     'toRGBs.1.0.weight'        : ['con', 'G_synthesis/8x8/ToRGB/weight' ],
343 |     'toRGBs.1.0.fc.weight'     : ['fc_', 'G_synthesis/8x8/ToRGB/mod_weight' ],
344 |     'toRGBs.1.0.bias.bias'     : ['any', 'G_synthesis/8x8/ToRGB/mod_bias' ],
345 |     'toRGBs.1.1.bias'          : ['any', 'G_synthesis/8x8/ToRGB/bias' ],
346 |     'toRGBs.2.0.weight'        : ['con', 'G_synthesis/16x16/ToRGB/weight' ],
347 |     'toRGBs.2.0.fc.weight'     : ['fc_', 'G_synthesis/16x16/ToRGB/mod_weight' ],
348 |     'toRGBs.2.0.bias.bias'     : ['any', 'G_synthesis/16x16/ToRGB/mod_bias' ],
349 |     'toRGBs.2.1.bias'          : ['any', 'G_synthesis/16x16/ToRGB/bias' ],
350 |     'toRGBs.3.0.weight'        : ['con', 'G_synthesis/32x32/ToRGB/weight' ],
351 |     'toRGBs.3.0.fc.weight'     : ['fc_', 'G_synthesis/32x32/ToRGB/mod_weight' ],
352 |     'toRGBs.3.0.bias.bias'     : ['any', 'G_synthesis/32x32/ToRGB/mod_bias' ],
353 |     'toRGBs.3.1.bias'          : ['any', 'G_synthesis/32x32/ToRGB/bias' ],
354 |     'toRGBs.4.0.weight'        : ['con', 'G_synthesis/64x64/ToRGB/weight' ],
355 |     'toRGBs.4.0.fc.weight'     : ['fc_', 'G_synthesis/64x64/ToRGB/mod_weight' ],
356 |     'toRGBs.4.0.bias.bias'     : ['any', 'G_synthesis/64x64/ToRGB/mod_bias' ],
357 |     'toRGBs.4.1.bias'          : ['any', 'G_synthesis/64x64/ToRGB/bias' ],
358 |     'toRGBs.5.0.weight'        : ['con', 'G_synthesis/128x128/ToRGB/weight' ],
359 |     'toRGBs.5.0.fc.weight'     : ['fc_', 'G_synthesis/128x128/ToRGB/mod_weight' ],
360 |     'toRGBs.5.0.bias.bias'     : ['any', 'G_synthesis/128x128/ToRGB/mod_bias' ],
361 |     'toRGBs.5.1.bias'          : ['any', 'G_synthesis/128x128/ToRGB/bias' ],
362 |     'toRGBs.6.0.weight'        : ['con', 'G_synthesis/256x256/ToRGB/weight' ],
363 |     'toRGBs.6.0.fc.weight'     : ['fc_', 'G_synthesis/256x256/ToRGB/mod_weight' ],
364 |     'toRGBs.6.0.bias.bias'     : ['any', 'G_synthesis/256x256/ToRGB/mod_bias' ],
365 |     'toRGBs.6.1.bias'          : ['any', 'G_synthesis/256x256/ToRGB/bias' ],
366 |     'toRGBs.7.0.weight'        : ['con', 'G_synthesis/512x512/ToRGB/weight' ],
367 |     'toRGBs.7.0.fc.weight'     : ['fc_', 'G_synthesis/512x512/ToRGB/mod_weight' ],
368 |     'toRGBs.7.0.bias.bias'     : ['any', 'G_synthesis/512x512/ToRGB/mod_bias' ],
369 |     'toRGBs.7.1.bias'          : ['any', 'G_synthesis/512x512/ToRGB/bias' ],
370 |     #'toRGBs.8.0.weight'        : ['con', 'G_synthesis/1024x1024/ToRGB/weight' ],
371 |     #'toRGBs.8.0.fc.weight'     : ['fc_', 'G_synthesis/1024x1024/ToRGB/mod_weight' ],
372 |     #'toRGBs.8.0.bias.bias'     : ['any', 'G_synthesis/1024x1024/ToRGB/mod_bias' ],
373 |     #'toRGBs.8.1.bias'          : ['any', 'G_synthesis/1024x1024/ToRGB/bias' ],
374 | }
375 | 
--------------------------------------------------------------------------------
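
``name_trans_dict`` above only records the correspondence between PyTorch parameter names, a conversion tag, and TensorFlow variable names. As a rough illustration, the sketch below shows how such a table could drive the construction of a PyTorch ``state_dict`` from the numpy arrays dumped on the TensorFlow side. The ``build_state_dict`` helper and the ``converters`` argument are hypothetical placeholders for this example, not the repository's API; the actual conversion is performed by ``run_pt_stylegan.py``.

```
# Sketch only (not part of the repository): build a PyTorch state_dict from the
# numpy weights, driven by a {pytorch name: [tag, tensorflow name]} table.
import numpy as np
import torch

def build_state_dict(tf_ndarrays, name_trans_dict, converters):
    # tf_ndarrays : {tensorflow name: numpy array}, e.g. loaded from stylegan2_ndarray.pkl
    # converters  : {tag ('any','uns','fc_','con','mTc'): function that reshapes the
    #                array into the layout the corresponding PyTorch parameter expects}
    state_dict = {}
    for pt_name, (tag, tf_name) in name_trans_dict.items():
        array = converters[tag](tf_ndarrays[tf_name])
        state_dict[pt_name] = torch.from_numpy(np.ascontiguousarray(array))
    return state_dict

# usage sketch: generator.load_state_dict(build_state_dict(weights, name_trans_dict, converters))
```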