├── README.md ├── __init__.py ├── environment.yml ├── learning_rate_fn.png ├── requirements.txt ├── run_scripts.txt └── savi ├── __init__.py ├── configs ├── __init__.py └── movi │ └── savi_conditional_small.py ├── datasets ├── __init__.py └── tfds │ ├── __init__.py │ ├── tfds_dataset_wrapper.py │ ├── tfds_input_pipeline.py │ └── tfds_preprocessing.py ├── lib ├── __init__.py ├── losses.py ├── metrics.py ├── metrics_jax.py └── utils.py ├── main.py ├── modules ├── __init__.py ├── attention.py ├── convolution.py ├── decoders.py ├── evaluator.py ├── factory.py ├── initializers.py ├── misc.py └── video.py └── trainers ├── __init__.py ├── tfds_trainer.py ├── tfds_trainer_dataparallel.py └── utils ├── __init__.py ├── lr_decay.py ├── lr_sched.py └── misc.py /README.md: -------------------------------------------------------------------------------- 1 | # SAVi-pytorch 2 | 3 | ``` 4 | +-----------------------------------------------------------+-----------------+---------+-----------+--------+ 5 | | Name | Shape | Size | Mean | Std | 6 | +-----------------------------------------------------------+-----------------+---------+-----------+--------+ 7 | | initializer.embedding_transform.model.dense_mlp_0.weight | (256, 4) | 1,024 | -0.00731 | 0.506 | 8 | | initializer.embedding_transform.model.dense_mlp_0.bias | (256,) | 256 | 0.0 | 0.0 | 9 | | initializer.embedding_transform.model.dense_mlp_1.weight | (128, 256) | 32,768 | 0.000184 | 0.062 | 10 | | initializer.embedding_transform.model.dense_mlp_1.bias | (128,) | 128 | 0.0 | 0.0 | 11 | | encoder.backbone.cnn_layers.conv_0.weight | (32, 3, 5, 5) | 2,400 | -0.00121 | 0.115 | 12 | | encoder.backbone.cnn_layers.conv_0.bias | (32,) | 32 | 0.0 | 0.0 | 13 | | encoder.backbone.cnn_layers.conv_1.weight | (32, 32, 5, 5) | 25,600 | 1.36e-05 | 0.0354 | 14 | | encoder.backbone.cnn_layers.conv_1.bias | (32,) | 32 | 0.0 | 0.0 | 15 | | encoder.backbone.cnn_layers.conv_2.weight | (32, 32, 5, 5) | 25,600 | 3.66e-05 | 0.0354 | 16 | | encoder.backbone.cnn_layers.conv_2.bias | (32,) | 32 | 0.0 | 0.0 | 17 | | encoder.backbone.cnn_layers.conv_3.weight | (32, 32, 5, 5) | 25,600 | 9.38e-05 | 0.0356 | 18 | | encoder.backbone.cnn_layers.conv_3.bias | (32,) | 32 | 0.0 | 0.0 | 19 | | encoder.pos_emb.pos_embedding | (1, 64, 64, 2) | 8,192 | -1.49e-08 | 0.586 | 20 | | encoder.pos_emb.output_transform.layernorm_module.weight | (32,) | 32 | 1.0 | 0.0 | 21 | | encoder.pos_emb.output_transform.layernorm_module.bias | (32,) | 32 | 0.0 | 0.0 | 22 | | encoder.pos_emb.output_transform.model.dense_mlp_0.weight | (64, 32) | 2,048 | 0.00425 | 0.181 | 23 | | encoder.pos_emb.output_transform.model.dense_mlp_0.bias | (64,) | 64 | 0.0 | 0.0 | 24 | | encoder.pos_emb.output_transform.model.dense_mlp_1.weight | (32, 64) | 2,048 | 0.00412 | 0.125 | 25 | | encoder.pos_emb.output_transform.model.dense_mlp_1.bias | (32,) | 32 | 0.0 | 0.0 | 26 | | encoder.pos_emb.project_add_dense.weight | (32, 2) | 64 | -0.057 | 0.663 | 27 | | encoder.pos_emb.project_add_dense.bias | (32,) | 32 | 0.0 | 0.0 | 28 | | corrector.gru.dense_ir.weight | (128, 128) | 16,384 | -0.00144 | 0.0886 | 29 | | corrector.gru.dense_ir.bias | (128,) | 128 | 0.0 | 0.0 | 30 | | corrector.gru.dense_iz.weight | (128, 128) | 16,384 | 3.85e-06 | 0.0891 | 31 | | corrector.gru.dense_iz.bias | (128,) | 128 | 0.0 | 0.0 | 32 | | corrector.gru.dense_in.weight | (128, 128) | 16,384 | -0.00101 | 0.0876 | 33 | | corrector.gru.dense_in.bias | (128,) | 128 | 0.0 | 0.0 | 34 | | corrector.gru.dense_hr.weight | (128, 128) | 16,384 | 0.000347 | 0.0884 | 35 | | 
corrector.gru.dense_hz.weight | (128, 128) | 16,384 | 0.000198 | 0.0884 | 36 | | corrector.gru.dense_hn.weight | (128, 128) | 16,384 | -0.000997 | 0.0884 | 37 | | corrector.gru.dense_hn.bias | (128,) | 128 | 0.0 | 0.0 | 38 | | corrector.dense_q.weight | (128, 128) | 16,384 | -0.000674 | 0.0887 | 39 | | corrector.dense_k.weight | (128, 32) | 4,096 | -0.00172 | 0.179 | 40 | | corrector.dense_v.weight | (128, 32) | 4,096 | 0.0034 | 0.179 | 41 | | corrector.layernorm_q.weight | (128,) | 128 | 1.0 | 0.0 | 42 | | corrector.layernorm_q.bias | (128,) | 128 | 0.0 | 0.0 | 43 | | corrector.layernorm_input.weight | (32,) | 32 | 1.0 | 0.0 | 44 | | corrector.layernorm_input.bias | (32,) | 32 | 0.0 | 0.0 | 45 | | decoder.backbone.cnn_layers.conv_0.weight | (128, 64, 5, 5) | 204,800 | 2.94e-05 | 0.0251 | 46 | | decoder.backbone.cnn_layers.conv_0.bias | (64,) | 64 | 0.0 | 0.0 | 47 | | decoder.backbone.cnn_layers.conv_1.weight | (64, 64, 5, 5) | 102,400 | 4.32e-05 | 0.025 | 48 | | decoder.backbone.cnn_layers.conv_1.bias | (64,) | 64 | 0.0 | 0.0 | 49 | | decoder.backbone.cnn_layers.conv_2.weight | (64, 64, 5, 5) | 102,400 | -2.19e-05 | 0.025 | 50 | | decoder.backbone.cnn_layers.conv_2.bias | (64,) | 64 | 0.0 | 0.0 | 51 | | decoder.backbone.cnn_layers.conv_3.weight | (64, 64, 5, 5) | 102,400 | 1.49e-05 | 0.025 | 52 | | decoder.backbone.cnn_layers.conv_3.bias | (64,) | 64 | 0.0 | 0.0 | 53 | | decoder.pos_emb.pos_embedding | (1, 8, 8, 2) | 128 | 0.0 | 0.657 | 54 | | decoder.pos_emb.project_add_dense.weight | (128, 2) | 256 | 0.104 | 0.707 | 55 | | decoder.pos_emb.project_add_dense.bias | (128,) | 128 | 0.0 | 0.0 | 56 | | decoder.target_readout.readout_modules.0.weight | (3, 64) | 192 | 0.0142 | 0.128 | 57 | | decoder.target_readout.readout_modules.0.bias | (3,) | 3 | 0.0 | 0.0 | 58 | | decoder.mask_pred.weight | (1, 64) | 64 | -0.0126 | 0.114 | 59 | | decoder.mask_pred.bias | (1,) | 1 | 0.0 | nan | 60 | | predictor.mlp.model.dense_mlp_0.weight | (256, 128) | 32,768 | -0.000881 | 0.0889 | 61 | | predictor.mlp.model.dense_mlp_0.bias | (256,) | 256 | 0.0 | 0.0 | 62 | | predictor.mlp.model.dense_mlp_1.weight | (128, 256) | 32,768 | -0.000456 | 0.0624 | 63 | | predictor.mlp.model.dense_mlp_1.bias | (128,) | 128 | 0.0 | 0.0 | 64 | | predictor.layernorm_query.weight | (128,) | 128 | 1.0 | 0.0 | 65 | | predictor.layernorm_query.bias | (128,) | 128 | 0.0 | 0.0 | 66 | | predictor.layernorm_mlp.weight | (128,) | 128 | 1.0 | 0.0 | 67 | | predictor.layernorm_mlp.bias | (128,) | 128 | 0.0 | 0.0 | 68 | | predictor.dense_q.weight | (128, 128) | 16,384 | -0.00041 | 0.0875 | 69 | | predictor.dense_q.bias | (128,) | 128 | 0.0 | 0.0 | 70 | | predictor.dense_k.weight | (128, 128) | 16,384 | 0.00023 | 0.089 | 71 | | predictor.dense_k.bias | (128,) | 128 | 0.0 | 0.0 | 72 | | predictor.dense_v.weight | (128, 128) | 16,384 | 0.00036 | 0.088 | 73 | | predictor.dense_v.bias | (128,) | 128 | 0.0 | 0.0 | 74 | | predictor.dense_o.weight | (128, 128) | 16,384 | -0.000598 | 0.0878 | 75 | | predictor.dense_o.bias | (128,) | 128 | 0.0 | 0.0 | 76 | +-----------------------------------------------------------+-----------------+---------+-----------+--------+ 77 | Total: 895,268 78 | SAVi( 79 | (initializer): CoordinateEncoderStateInit( 80 | (embedding_transform): MLP( 81 | (model): ModuleList( 82 | (dense_mlp_0): Linear(in_features=4, out_features=256, bias=True) 83 | (dense_mlp_0_act): ReLU() 84 | (dense_mlp_1): Linear(in_features=256, out_features=128, bias=True) 85 | ) 86 | ) 87 | ) 88 | (encoder): FrameEncoder( 89 | (backbone): CNN2( 90 | 
(cnn_layers): ModuleList( 91 | (conv_0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)) 92 | (act_0): ReLU() 93 | (conv_1): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)) 94 | (act_1): ReLU() 95 | (conv_2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)) 96 | (act_2): ReLU() 97 | (conv_3): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)) 98 | (act_3): ReLU() 99 | ) 100 | ) 101 | (pos_emb): PositionEmbedding( 102 | (pos_transform): Identity() 103 | (output_transform): MLP( 104 | (layernorm_module): LayerNorm((32,), eps=1e-06, elementwise_affine=True) 105 | (model): ModuleList( 106 | (dense_mlp_0): Linear(in_features=32, out_features=64, bias=True) 107 | (dense_mlp_0_act): ReLU() 108 | (dense_mlp_1): Linear(in_features=64, out_features=32, bias=True) 109 | ) 110 | ) 111 | (project_add_dense): Linear(in_features=2, out_features=32, bias=True) 112 | ) 113 | (output_transform): Identity() 114 | ) 115 | (corrector): SlotAttention( 116 | (gru): myGRUCell( 117 | (dense_ir): Linear(in_features=128, out_features=128, bias=True) 118 | (dense_iz): Linear(in_features=128, out_features=128, bias=True) 119 | (dense_in): Linear(in_features=128, out_features=128, bias=True) 120 | (dense_hr): Linear(in_features=128, out_features=128, bias=False) 121 | (dense_hz): Linear(in_features=128, out_features=128, bias=False) 122 | (dense_hn): Linear(in_features=128, out_features=128, bias=True) 123 | ) 124 | (dense_q): Linear(in_features=128, out_features=128, bias=False) 125 | (dense_k): Linear(in_features=32, out_features=128, bias=False) 126 | (dense_v): Linear(in_features=32, out_features=128, bias=False) 127 | (layernorm_q): LayerNorm((128,), eps=1e-06, elementwise_affine=True) 128 | (layernorm_input): LayerNorm((32,), eps=1e-06, elementwise_affine=True) 129 | (inverted_attention): InvertedDotProductAttention( 130 | (attn_fn): GeneralizedDotProductAttention() 131 | ) 132 | ) 133 | (decoder): SpatialBroadcastDecoder( 134 | (backbone): CNN2( 135 | (cnn_layers): ModuleList( 136 | (conv_0): ConvTranspose2d(128, 64, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1)) 137 | (act_0): ReLU() 138 | (conv_1): ConvTranspose2d(64, 64, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1)) 139 | (act_1): ReLU() 140 | (conv_2): ConvTranspose2d(64, 64, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1)) 141 | (act_2): ReLU() 142 | (conv_3): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)) 143 | (act_3): ReLU() 144 | ) 145 | ) 146 | (pos_emb): PositionEmbedding( 147 | (pos_transform): Identity() 148 | (output_transform): Identity() 149 | (project_add_dense): Linear(in_features=2, out_features=128, bias=True) 150 | ) 151 | (target_readout): Readout( 152 | (readout_modules): ModuleList( 153 | (0): Linear(in_features=64, out_features=3, bias=True) 154 | ) 155 | ) 156 | (mask_pred): Linear(in_features=64, out_features=1, bias=True) 157 | ) 158 | (predictor): TransformerBlock( 159 | (attn): GeneralizedDotProductAttention() 160 | (mlp): MLP( 161 | (model): ModuleList( 162 | (dense_mlp_0): Linear(in_features=128, out_features=256, bias=True) 163 | (dense_mlp_0_act): ReLU() 164 | (dense_mlp_1): Linear(in_features=256, out_features=128, bias=True) 165 | ) 166 | ) 167 | (layernorm_query): LayerNorm((128,), eps=1e-06, elementwise_affine=True) 168 | (layernorm_mlp): LayerNorm((128,), eps=1e-06, elementwise_affine=True) 169 | (dense_q): Linear(in_features=128, out_features=128, bias=True) 
170 | (dense_k): Linear(in_features=128, out_features=128, bias=True) 171 | (dense_v): Linear(in_features=128, out_features=128, bias=True) 172 | (dense_o): Linear(in_features=128, out_features=128, bias=True) 173 | ) 174 | (processor): Processor( 175 | (corrector): SlotAttention( 176 | (gru): myGRUCell( 177 | (dense_ir): Linear(in_features=128, out_features=128, bias=True) 178 | (dense_iz): Linear(in_features=128, out_features=128, bias=True) 179 | (dense_in): Linear(in_features=128, out_features=128, bias=True) 180 | (dense_hr): Linear(in_features=128, out_features=128, bias=False) 181 | (dense_hz): Linear(in_features=128, out_features=128, bias=False) 182 | (dense_hn): Linear(in_features=128, out_features=128, bias=True) 183 | ) 184 | (dense_q): Linear(in_features=128, out_features=128, bias=False) 185 | (dense_k): Linear(in_features=32, out_features=128, bias=False) 186 | (dense_v): Linear(in_features=32, out_features=128, bias=False) 187 | (layernorm_q): LayerNorm((128,), eps=1e-06, elementwise_affine=True) 188 | (layernorm_input): LayerNorm((32,), eps=1e-06, elementwise_affine=True) 189 | (inverted_attention): InvertedDotProductAttention( 190 | (attn_fn): GeneralizedDotProductAttention() 191 | ) 192 | ) 193 | (predictor): TransformerBlock( 194 | (attn): GeneralizedDotProductAttention() 195 | (mlp): MLP( 196 | (model): ModuleList( 197 | (dense_mlp_0): Linear(in_features=128, out_features=256, bias=True) 198 | (dense_mlp_0_act): ReLU() 199 | (dense_mlp_1): Linear(in_features=256, out_features=128, bias=True) 200 | ) 201 | ) 202 | (layernorm_query): LayerNorm((128,), eps=1e-06, elementwise_affine=True) 203 | (layernorm_mlp): LayerNorm((128,), eps=1e-06, elementwise_affine=True) 204 | (dense_q): Linear(in_features=128, out_features=128, bias=True) 205 | (dense_k): Linear(in_features=128, out_features=128, bias=True) 206 | (dense_v): Linear(in_features=128, out_features=128, bias=True) 207 | (dense_o): Linear(in_features=128, out_features=128, bias=True) 208 | ) 209 | ) 210 | ) 211 | ``` -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | # FIXME -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: savi 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - _openmp_mutex=5.1=1_gnu 8 | - blas=1.0=mkl 9 | - brotlipy=0.7.0=py37h27cfd23_1003 10 | - bzip2=1.0.8=h7b6447c_0 11 | - ca-certificates=2022.4.26=h06a4308_0 12 | - certifi=2022.5.18.1=py37h06a4308_0 13 | - cffi=1.15.0=py37hd667e15_1 14 | - cryptography=37.0.1=py37h9ce1e76_0 15 | - cudatoolkit=11.3.1=h2bc3f7f_2 16 | - ffmpeg=4.3=hf484d3e_0 17 | - freetype=2.11.0=h70c0345_0 18 | - giflib=5.2.1=h7b6447c_0 19 | - gmp=6.2.1=h295c915_3 20 | - gnutls=3.6.15=he1e5248_0 21 | - idna=3.3=pyhd3eb1b0_0 22 | - intel-openmp=2021.4.0=h06a4308_3561 23 | - jpeg=9e=h7f8727e_0 24 | - lame=3.100=h7b6447c_0 25 | - lcms2=2.12=h3be6417_0 26 | - ld_impl_linux-64=2.38=h1181459_1 27 | - libffi=3.3=he6710b0_2 28 | - libgcc-ng=11.2.0=h1234567_1 29 | - libgomp=11.2.0=h1234567_1 30 | - libiconv=1.16=h7f8727e_2 31 | - libidn2=2.3.2=h7f8727e_0 32 | - libpng=1.6.37=hbc83047_0 33 | - libstdcxx-ng=11.2.0=h1234567_1 34 | - libtasn1=4.16.0=h27cfd23_0 35 | - libtiff=4.2.0=h2818925_1 36 | - libunistring=0.9.10=h27cfd23_0 37 | - libuv=1.40.0=h7b6447c_0 38 | - 
libwebp=1.2.2=h55f646e_0 39 | - libwebp-base=1.2.2=h7f8727e_0 40 | - lz4-c=1.9.3=h295c915_1 41 | - mkl=2021.4.0=h06a4308_640 42 | - mkl-service=2.4.0=py37h7f8727e_0 43 | - mkl_fft=1.3.1=py37hd3c417c_0 44 | - mkl_random=1.2.2=py37h51133e4_0 45 | - ncurses=6.3=h7f8727e_2 46 | - nettle=3.7.3=hbbd107a_1 47 | - numpy-base=1.21.5=py37ha15fc14_3 48 | - openh264=2.1.1=h4ff587b_0 49 | - openssl=1.1.1o=h7f8727e_0 50 | - pip=21.2.2=py37h06a4308_0 51 | - pycparser=2.21=pyhd3eb1b0_0 52 | - pyopenssl=22.0.0=pyhd3eb1b0_0 53 | - pysocks=1.7.1=py37_1 54 | - python=3.7.13=h12debd9_0 55 | - pytorch=1.11.0=py3.7_cuda11.3_cudnn8.2.0_0 56 | - pytorch-mutex=1.0=cuda 57 | - readline=8.1.2=h7f8727e_1 58 | - setuptools=61.2.0=py37h06a4308_0 59 | - six=1.16.0=pyhd3eb1b0_1 60 | - sqlite=3.38.3=hc218d9a_0 61 | - tk=8.6.12=h1ccaba5_0 62 | - torchaudio=0.11.0=py37_cu113 63 | - typing_extensions=4.1.1=pyh06a4308_0 64 | - urllib3=1.26.9=py37h06a4308_0 65 | - wheel=0.37.1=pyhd3eb1b0_0 66 | - xz=5.2.5=h7f8727e_1 67 | - zlib=1.2.12=h7f8727e_2 68 | - zstd=1.5.2=ha4553b6_0 69 | - pip: 70 | - absl-py==1.1.0 71 | - argon2-cffi==21.3.0 72 | - argon2-cffi-bindings==21.2.0 73 | - astunparse==1.6.3 74 | - attrs==21.4.0 75 | - backcall==0.2.0 76 | - beautifulsoup4==4.11.1 77 | - bleach==5.0.0 78 | - cached-property==1.5.2 79 | - cachetools==5.2.0 80 | - charset-normalizer==2.0.12 81 | - chex==0.1.3 82 | - click==8.1.3 83 | - clu==0.0.7 84 | - colorama==0.4.4 85 | - commonmark==0.9.1 86 | - contextlib2==21.6.0 87 | - cycler==0.11.0 88 | - debugpy==1.6.0 89 | - decorator==5.1.1 90 | - defusedxml==0.7.1 91 | - dill==0.3.5.1 92 | - dm-tree==0.1.7 93 | - docker-pycreds==0.4.0 94 | - entrypoints==0.4 95 | - etils==0.6.0 96 | - fastjsonschema==2.15.3 97 | - flatbuffers==1.12 98 | - flax==0.5.1 99 | - fonttools==4.33.3 100 | - gast==0.4.0 101 | - gitdb==4.0.9 102 | - gitpython==3.1.27 103 | - google-auth==2.8.0 104 | - google-auth-oauthlib==0.4.6 105 | - google-pasta==0.2.0 106 | - googleapis-common-protos==1.56.2 107 | - grpcio==1.46.3 108 | - h5py==3.7.0 109 | - imageio==2.19.3 110 | - importlib-metadata==4.11.4 111 | - importlib-resources==5.7.1 112 | - ipdb==0.13.9 113 | - ipykernel==6.15.0 114 | - ipython==7.34.0 115 | - ipython-genutils==0.2.0 116 | - ipywidgets==7.7.0 117 | - jax==0.3.13 118 | - jaxlib==0.3.10 119 | - jedi==0.18.1 120 | - jinja2==3.1.2 121 | - jsonschema==4.6.0 122 | - jupyter==1.0.0 123 | - jupyter-client==7.3.4 124 | - jupyter-console==6.4.3 125 | - jupyter-core==4.10.0 126 | - jupyterlab-pygments==0.2.2 127 | - jupyterlab-widgets==1.1.0 128 | - keras==2.9.0 129 | - keras-preprocessing==1.1.2 130 | - kiwisolver==1.4.3 131 | - libclang==14.0.1 132 | - markdown==3.3.7 133 | - markupsafe==2.1.1 134 | - matplotlib==3.5.2 135 | - matplotlib-inline==0.1.3 136 | - mistune==0.8.4 137 | - ml-collections==0.1.1 138 | - msgpack==1.0.4 139 | - nbclient==0.6.4 140 | - nbconvert==6.5.0 141 | - nbformat==5.4.0 142 | - nest-asyncio==1.5.5 143 | - networkx==2.6.3 144 | - notebook==6.4.12 145 | - numpy==1.21.6 146 | - oauthlib==3.2.0 147 | - opt-einsum==3.3.0 148 | - optax==0.1.2 149 | - packaging==21.3 150 | - pandocfilters==1.5.0 151 | - parso==0.8.3 152 | - pathtools==0.1.2 153 | - pexpect==4.8.0 154 | - pickleshare==0.7.5 155 | - pillow==9.1.1 156 | - prometheus-client==0.14.1 157 | - promise==2.3 158 | - prompt-toolkit==3.0.29 159 | - protobuf==3.19.4 160 | - psutil==5.9.1 161 | - ptyprocess==0.7.0 162 | - pyasn1==0.4.8 163 | - pyasn1-modules==0.2.8 164 | - pygments==2.12.0 165 | - pyparsing==3.0.9 166 | - pyrsistent==0.18.1 167 
| - python-dateutil==2.8.2 168 | - pywavelets==1.3.0 169 | - pyyaml==6.0 170 | - pyzmq==23.1.0 171 | - qtconsole==5.3.1 172 | - qtpy==2.1.0 173 | - requests==2.28.0 174 | - requests-oauthlib==1.3.1 175 | - rich==11.1.0 176 | - rsa==4.8 177 | - scikit-image==0.19.3 178 | - scipy==1.7.3 179 | - send2trash==1.8.0 180 | - sentry-sdk==1.5.12 181 | - setproctitle==1.2.3 182 | - shortuuid==1.0.9 183 | - smmap==5.0.0 184 | - soupsieve==2.3.2.post1 185 | - tensorboard==2.9.1 186 | - tensorboard-data-server==0.6.1 187 | - tensorboard-plugin-wit==1.8.1 188 | - tensorflow==2.9.1 189 | - tensorflow-cpu==2.9.1 190 | - tensorflow-datasets==4.6.0 191 | - tensorflow-estimator==2.9.0 192 | - tensorflow-io-gcs-filesystem==0.26.0 193 | - tensorflow-metadata==1.8.0 194 | - termcolor==1.1.0 195 | - terminado==0.15.0 196 | - tifffile==2021.11.2 197 | - tinycss2==1.1.1 198 | - toml==0.10.2 199 | - toolz==0.11.2 200 | - torch==1.11.0 201 | - torchvision==0.12.0 202 | - tornado==6.1 203 | - tqdm==4.64.0 204 | - traitlets==5.3.0 205 | - typing-extensions==4.2.0 206 | - wandb==0.12.18 207 | - wcwidth==0.2.5 208 | - webencodings==0.5.1 209 | - werkzeug==2.1.2 210 | - widgetsnbextension==3.6.0 211 | - wrapt==1.14.1 212 | - zipp==3.8.0 213 | prefix: /home/junkeun-yi/miniconda3/envs/savi 214 | -------------------------------------------------------------------------------- /learning_rate_fn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junkeun-yi/SAVi-pytorch/6f9a1c8987995aa7729bb4cbcf3a027a88dadf94/learning_rate_fn.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | numpy 4 | scikit-image 5 | matplotlib 6 | ml_collections 7 | clu 8 | tensorflow-cpu 9 | tensorflow-datasets 10 | wandb 11 | -------------------------------------------------------------------------------- /run_scripts.txt: -------------------------------------------------------------------------------- 1 | # most basic 2 | python -m savi.main --seed {} --wandb --group {} --exp {} --gpu {} 3 | 4 | # using movi_c 5 | python -m savi.main --tfds_name movi_c/128x128:1.0.0 --seed {} --wandb --group {} --exp {} --gpu {} 6 | 7 | # model flow 8 | python -m savi.main --seed {} --wandb --group {} --exp {} --gpu {} --model_type flow --slice_decode_inputs 9 | 10 | # for evaluation, just add --eval. 
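# example (illustrative, simply adding --eval to the basic movi_a command shown further below): python -m savi.main --seed 51 --wandb --group savi_movi_a_6.19 --exp 51 --gpu 0,1 --eval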
11 | # to resume, use --resume_from {checkpoint_path} 12 | 13 | 14 | # examples 15 | python -m savi.main --seed 51 --wandb --group savi_movi_a_6.19 --exp 51 --gpu 0,1,2,3 16 | 17 | python -m savi.main --tfds_name movi_c/128x128:1.0.0 --group savi_movi_c_myattn_loss --exp 21 --seed 21 --gpu 0,1 --wandb 18 | 19 | python -m savi.main --model_type flow --slice_decode_inputs --group flow_movi_a_6.19 --exp 21 --seed 21 --gpu 0,1 --wandb 20 | 21 | python -m savi.main --group savi_movi_a_myattn_loss --gpu 2,3,4,5,6,7,8,9 --seed 21 --exp 21 --wandb 22 | 23 | python -m savi.main --tfds_name movi_c/128x128:1.0.0 --group savi_movi_c_normal_default --gpu --exp 21 --seed 21 --wandb 24 | 25 | python -m savi.main --group savi_movi_a_test --wandb --seed 11 --exp 11_xavuni,zeros_ --gpu 26 | 27 | python -m savi.main --tfds_name movi_c/128x128:1.0.0 --group savi_med_movi_c_test --gpu 8,9 --seed 11 --wandb --exp 11_lecnor,zeros_d1 --init_weight lecun_normal --init_bias zeros 28 | 29 | python -m savi.main --tfds_name movi_c/128x128:1.0.0 --group savi_med_movi_c_test --gpu 8,9 --seed 51 --wandb --exp 51_lecnor,zeros_d1 --init_weight lecun_normal --init_bias zeros --model_size medium --batch_size 32 --accum_iter 2 30 | 31 | --data_dir /shared/junkeun-yi/kubric 32 | 33 | python -m savi.main --group savi_med_movi_a_test --gpu 2,3,4,5,6,7,8,9 --seed 61 --wandb --exp 61_xavnor,zeros_sg --init_weight xavier_normal --init_bias zeros --model_size medium 34 | 35 | python -m savi.main --tfds_name movi_c/128x128:1.0.0 --group savi_med_movi_c_gradNone_test --model_size medium --gpu 0,2,3,4 --seed 101 --exp 101_lecnor,zeros_ft --init_weight lecun_normal --init_bias zeros 36 | 37 | python -m savi.main --tfds_name movi_a/128x128:1.0.0 --group savi_movi_a_lecnor --seed 200 --exp 200_gg 38 | 39 | python -m savi.main --data_dir /shared/junkeun-yi/kubric --tfds_name movi_a/128x128:1.0.0 --group savi_movi_a_lecnor --seed 200 --exp 200_gg -------------------------------------------------------------------------------- /savi/__init__.py: -------------------------------------------------------------------------------- 1 | # FIXME -------------------------------------------------------------------------------- /savi/configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junkeun-yi/SAVi-pytorch/6f9a1c8987995aa7729bb4cbcf3a027a88dadf94/savi/configs/__init__.py -------------------------------------------------------------------------------- /savi/configs/movi/savi_conditional_small.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junkeun-yi/SAVi-pytorch/6f9a1c8987995aa7729bb4cbcf3a027a88dadf94/savi/configs/movi/savi_conditional_small.py -------------------------------------------------------------------------------- /savi/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # FIXME -------------------------------------------------------------------------------- /savi/datasets/tfds/__init__.py: -------------------------------------------------------------------------------- 1 | # FIXME -------------------------------------------------------------------------------- /savi/datasets/tfds/tfds_dataset_wrapper.py: -------------------------------------------------------------------------------- 1 | """Try to wrap TFDS dataset.""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import numpy as np 7 | import 
jax 8 | 9 | from torch.utils.data import Dataset 10 | 11 | import os 12 | 13 | from savi.datasets.tfds import tfds_input_pipeline 14 | 15 | # MoVi dataset 16 | class MOViData(Dataset): 17 | def __init__(self, tfds_dataset): 18 | self.dataset = tfds_dataset 19 | self.itr = iter(self.dataset) 20 | # TODO: check if running iter(self.dataset) always returns the same data 21 | 22 | def __len__(self): 23 | return len(self.dataset) 24 | 25 | def __getitem__(self, idx): 26 | batch = jax.tree_map(np.asarray, next(self.itr)) 27 | 28 | video = torch.from_numpy(batch['video']) # (B T H W 3) 29 | boxes = torch.from_numpy(batch['boxes']) # (B T maxN 4) 30 | flow = torch.from_numpy(batch['flow']) # (B T H W 3) 31 | padding_mask = torch.from_numpy(batch['padding_mask']) 32 | mask = torch.from_numpy(batch['mask']) if 'mask' in batch.keys() else torch.empty(0, dtype=torch.bool) 33 | segmentations = torch.from_numpy(batch['segmentations']) 34 | 35 | return video, boxes, segmentations, flow, padding_mask, mask 36 | 37 | def reset_itr(self): 38 | self.itr = iter(self.dataset) 39 | 40 | class MOViDataByRank(Dataset): 41 | def __init__(self, tfds_dataset, rank, world_size): 42 | self.dataset = tfds_dataset 43 | self.rank = rank 44 | self.world_size = world_size 45 | 46 | self.reset_itr() 47 | 48 | def __len__(self): 49 | return len(self.dataset) 50 | 51 | def __getitem__(self, idx): 52 | 53 | for _ in range(self.world_size): 54 | # move by stride 55 | next(self.itr) 56 | 57 | 58 | 59 | batch = jax.tree_map(np.asarray, next(self.itr)) 60 | 61 | video = torch.from_numpy(batch['video']) # (B T H W 3) 62 | boxes = torch.from_numpy(batch['boxes']) # (B T maxN 4) 63 | flow = torch.from_numpy(batch['flow']) # (B T H W 3) 64 | padding_mask = torch.from_numpy(batch['padding_mask']) 65 | segmentations = torch.from_numpy(batch['segmentations']) 66 | 67 | 68 | 69 | return video, boxes, flow, padding_mask, segmentations 70 | 71 | def reset_itr(self): 72 | # move itr by rank steps to return strided data 73 | self.itr = iter(self.dataset) 74 | for _ in range(self.rank): 75 | next(self.itr) -------------------------------------------------------------------------------- /savi/datasets/tfds/tfds_input_pipeline.py: -------------------------------------------------------------------------------- 1 | """Input pipeline for TFDS datasets.""" 2 | 3 | # FIXME 4 | 5 | import functools 6 | from typing import Dict, List, Tuple 7 | 8 | from clu import deterministic_data 9 | from clu import preprocess_spec 10 | 11 | import ml_collections 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | import numpy as np 17 | 18 | from savi.datasets.tfds import tfds_preprocessing as preprocessing 19 | 20 | import tensorflow as tf 21 | import tensorflow_datasets as tfds 22 | 23 | Array = torch.Tensor 24 | PRNGKey = Array 25 | 26 | def preprocess_example(features: Dict[str, tf.Tensor], 27 | preprocess_strs: List[str]) -> Dict[str, tf.Tensor]: 28 | """Process a single data example. 29 | 30 | Args: 31 | features: A dictionary containing the tensors of a single data example. 32 | preprocess_strs: List of strings, describing one preprocessing operation 33 | each, in clu.preprocess_spec format. 34 | 35 | Returns: 36 | Dictionary containing the preprocessed tensors of a single data example. 
37 | """ 38 | all_ops = preprocessing.all_ops() 39 | preprocess_fn = preprocess_spec.parse("|".join(preprocess_strs), all_ops) 40 | return preprocess_fn(features) 41 | 42 | 43 | # def get_batch_dims(gloabl_batch_size: int) -> List[int]: 44 | # """Gets the first two axis sizes for data batches. 45 | 46 | 47 | # """ 48 | 49 | def create_datasets( 50 | args, 51 | data_rng: PRNGKey) -> Tuple[tf.data.Dataset, tf.data.Dataset]: 52 | """Create datasets for training and evaluation 53 | 54 | For the same data_rng and config this will return the same datasets. 55 | The datasets only contain stateless operations. 56 | 57 | Args: 58 | args: Configuration to use. 59 | data_rng: JAX PRNGKey for dataset pipeline. 60 | 61 | Returns: 62 | A tuple with the training dataset and the evaluation dataset. 63 | """ 64 | dataset_builder = tfds.builder( 65 | args.tfds_name, data_dir=args.data_dir) 66 | 67 | batch_dims = (args.batch_size,) 68 | 69 | train_preprocess_fn = functools.partial( 70 | preprocess_example, preprocess_strs=args.preproc_train) 71 | eval_preprocess_fn = functools.partial( 72 | preprocess_example, preprocess_strs=args.preproc_eval) 73 | 74 | train_split_name = "train" # args.get("train_split", "train") 75 | eval_split_name = "validation" # args.get("validation_split", "validation") 76 | 77 | # TODO: may need to do something to only run on one host 78 | train_split = deterministic_data.get_read_instruction_for_host( 79 | train_split_name, dataset_info=dataset_builder.info) 80 | train_ds = deterministic_data.create_dataset( 81 | dataset_builder, 82 | split=train_split, 83 | rng=data_rng, 84 | preprocess_fn=train_preprocess_fn, 85 | cache=False, 86 | shuffle_buffer_size=args.shuffle_buffer_size, 87 | batch_dims=batch_dims, 88 | num_epochs=None, 89 | shuffle=True) 90 | 91 | eval_split = deterministic_data.get_read_instruction_for_host( 92 | eval_split_name, dataset_info=dataset_builder.info, drop_remainder=False) 93 | eval_ds = deterministic_data.create_dataset( 94 | dataset_builder, 95 | split=eval_split, 96 | rng=None, 97 | preprocess_fn=eval_preprocess_fn, 98 | cache=False, 99 | batch_dims=batch_dims, 100 | num_epochs=1, 101 | shuffle=False, 102 | pad_up_to_batches="auto") 103 | 104 | return train_ds, eval_ds -------------------------------------------------------------------------------- /savi/lib/__init__.py: -------------------------------------------------------------------------------- 1 | # FIXME -------------------------------------------------------------------------------- /savi/lib/losses.py: -------------------------------------------------------------------------------- 1 | """Loss functions.""" 2 | 3 | # FIXME 4 | 5 | import functools 6 | import inspect 7 | from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import numpy as np 13 | 14 | _LOSS_FUNCTIONS = {} 15 | 16 | Array = torch.Tensor 17 | ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]] 18 | ArrayDict = Dict[str, Array] 19 | DictTree = Dict[str, Union[Array, "DictTree"]] 20 | LossFn = Callable[[Dict[str, ArrayTree], Dict[str, ArrayTree]], 21 | Tuple[Array, ArrayTree]] 22 | ConfigAttr = Any 23 | MetricSpec = Dict[str, str] 24 | 25 | 26 | def standardize_loss_config( 27 | loss_config: Union[Sequence[str], Dict] 28 | ) -> Dict: 29 | """Standardize loss configs into a common Dict format. 30 | 31 | Args: 32 | loss_config: List of strings or Dict specifying loss configuration. 
33 | Valid input formats are: 34 | Option 1 (list of strings): 35 | ex) `["box", "presence"]` 36 | Option 2 (losses with weights only): 37 | ex) `{"box": 5, "presence": 2}` 38 | Option 3 (losses with weights and other parameters): 39 | ex) `{"box": {"weight": 5, "metric": "l1"}, "presence": {"weight": 2}}` 40 | 41 | Returns: 42 | Standardized Dict containing the loss configuration. 43 | 44 | Raises: 45 | ValueError: If loss_config is a list that contains non-string entries. 46 | """ 47 | 48 | if isinstance(loss_config, Sequence): # Option 1 49 | if not all(isinstance(loss_type, str) for loss_type in loss_config): 50 | raise ValueError(f"Loss types all need to be str but got {loss_config}") 51 | return {k: {} for k in loss_config} 52 | 53 | # Convert all option-2-style weights to option-3-style dictionaries. 54 | if not isinstance(loss_config, Dict): 55 | raise ValueError(f"Loss config type not Sequence or Dict; got {loss_config}") 56 | else: 57 | loss_config = { 58 | k: { 59 | "weight": v 60 | } if isinstance(v, (float, int)) else v for k, v in loss_config.items() 61 | } 62 | return loss_config 63 | 64 | 65 | def update_loss_aux(loss_aux: Dict[str, Array], update: Dict[str, Array]): 66 | existing_keys = set(update.keys()).intersection(loss_aux.keys()) 67 | if existing_keys: 68 | raise KeyError( 69 | f"Can't overwrite existing keys in loss_aux: {existing_keys}") 70 | loss_aux.update(update) 71 | 72 | 73 | def compute_full_loss( 74 | preds: Dict[str, ArrayTree], targets: Dict[str, ArrayTree], 75 | loss_config: Union[Sequence[str], Dict] 76 | ) -> Tuple[Array, ArrayTree]: 77 | """Loss function that parses and combines weighted loss terms. 78 | 79 | Args: 80 | preds: Dictionary of tensors containing model predictions. 81 | targets: Dictionary of tensors containing prediction targets. 82 | loss_config: List of strings or Dict specifying loss configuration. 83 | See @register_loss decorated functions below for valid loss names. 84 | Valid loss formats are: 85 | - Option 1 (list of strings): 86 | ex) `["box", "presence"]` 87 | - Option 2 (losses with weights only): 88 | ex) `{"box": 5, "presence": 2}` 89 | - Option 3 (losses with weights and other parameters): 90 | ex) `{"box": {"weight": 5, "metric": "l1"}, "presence": {"weight": 2}}` 91 | - Option 4 (like 3 but decoupling name and loss_type) 92 | ex) `{"recon_flow": {"loss_type": "recon", "key": "flow"}, 93 | "recon_video": {"loss_type": "recon", "key": "video"}}` 94 | 95 | Returns: 96 | A 2-tuple of the sum of all individual loss terms and a dictionary of 97 | auxiliary losses and metrics. 98 | """ 99 | 100 | loss = torch.zeros([], dtype=torch.float32) 101 | loss_aux = {} 102 | loss_config = standardize_loss_config(loss_config) 103 | for loss_name, cfg in loss_config.items(): 104 | context_kwargs = {"preds": preds, "targets": targets} 105 | weight, loss_term, loss_aux_update = compute_loss_term( 106 | loss_name=loss_name, context_kwargs=context_kwargs, config_kwargs=cfg) 107 | 108 | unweighted_loss = torch.mean(loss_term) 109 | loss += weight * unweighted_loss 110 | loss_aux_update[loss_name + "_value"] = unweighted_loss 111 | loss_aux_update[loss_name + "_weight"] = torch.ones_like(unweighted_loss) 112 | update_loss_aux(loss_aux, loss_aux_update) 113 | return loss, loss_aux 114 | 115 | 116 | def register_loss(func=None, 117 | *, 118 | name: Optional[str] = None, 119 | check_unused_kwargs: bool = True): 120 | """Decorator for registering a loss function. 
121 | 122 | Can be used without arguments: 123 | ``` 124 | @register_loss 125 | def my_loss(**_): 126 | return 0 127 | ``` 128 | or with keyword arguments: 129 | ``` 130 | @register_loss(name="my_renamed_loss") 131 | def my_loss(**_): 132 | return 0 133 | ``` 134 | 135 | Loss functions may accept 136 | - context kwargs: `preds` and `targets` 137 | - config kwargs: any argument specified in the config 138 | - the special `config_kwargs` parameter that contains the entire loss config. 139 | Loss functions also _need_ to accept a **kwarg argument to support extending 140 | the interface. 141 | They should return either: 142 | - just the computed loss (pre-reduction) 143 | - or a tuple of the computed loss and a loss_aux_updates dict 144 | 145 | Args: 146 | func: the decorated function 147 | name (str): optional name to be used for this loss in the config. 148 | Defaults to the name of the function. 149 | check_unused_kwargs (bool): By default compute_loss_term raises an error if 150 | there are any usused config kwargs. If this flag is set to False that step 151 | is skipped. This is useful if the config_kwargs should be passed onward to 152 | another function. 153 | 154 | Returns: 155 | The decorated function (or a partial of the decorator) 156 | """ 157 | # If this decorator has been called with parameters but no function, then we 158 | # return the decorator again (but with partially filled parameters). 159 | # This allows using both @register_loss and @register_loss(name="foo") 160 | if func is None: 161 | return functools.partial( 162 | register_loss, name=name, check_unused_kwargs=check_unused_kwargs) 163 | 164 | # No (further) arguments: this is the actual decorator 165 | # ensure that the loss function includes a **kwargs argument 166 | loss_name = name if name is not None else func.__name__ 167 | if not any(v.kind == inspect.Parameter.VAR_KEYWORD 168 | for k, v in inspect.signature(func).parameters.items()): 169 | raise TypeError( 170 | f"Loss function '{loss_name}' needs to include a **kwargs argument") 171 | func.name = loss_name 172 | func.check_unused_kwargs = check_unused_kwargs 173 | _LOSS_FUNCTIONS[loss_name] = func 174 | return func 175 | 176 | 177 | def compute_loss_term( 178 | loss_name: str, context_kwargs: Dict[str, Any], 179 | config_kwargs: Dict[str, Any]) -> Tuple[float, Array, Dict[str, Array]]: 180 | """Compute a loss function given its config and context parameters. 181 | 182 | Takes care of: 183 | - finding the correct loss function based on "loss_type" or name 184 | - the optional "weight" parameter 185 | - checking for typos and collisions in config parameters 186 | - adding the optional loss_aux_updates if omitted by the loss_fn 187 | 188 | Args: 189 | loss_name: Name of the loss, i.e. its key in the config.losses dict. 190 | context_kwargs: Dictionary of context variables (`preds` and `targets`) 191 | config_kwargs: The config dict for this loss 192 | 193 | Returns: 194 | 1. the loss weight (float) 195 | 2. loss term (Array) 196 | 3. loss aux updates (Dict[str, Array]) 197 | 198 | Raises: 199 | KeyError: 200 | Unknown loss_type 201 | KeyError: 202 | Unused config entries, i.e. not used by the loss function. 
203 | Not raised if using @register_loss(check_unused_kwargs=False) 204 | KeyError: Config entry with a name that conflicts with a context_kwarg 205 | ValueError: Non-numerical weight in config_kwargs 206 | """ 207 | 208 | # Make a dict copy of config_kwargs 209 | kwargs = {k: v for k, v in config_kwargs.items()} 210 | 211 | # Get the loss function 212 | loss_type = kwargs.pop("loss_type", loss_name) 213 | if loss_type not in _LOSS_FUNCTIONS: 214 | raise KeyError(f"Unknown loss_type '{loss_type}'.") 215 | loss_fn = _LOSS_FUNCTIONS[loss_type] 216 | 217 | # Take care of "weight" term 218 | weight = kwargs.pop("weight", 1.0) 219 | if not isinstance(weight, (int, float)): 220 | raise ValueError(f"Weight for loss {loss_name} should be a number, " 221 | f"but was {weight}.") 222 | 223 | # Check for unused config entries (to prevent typos etc.) 224 | config_keys = set(kwargs) 225 | if loss_fn.check_unused_kwargs: 226 | param_names = set(inspect.signature(loss_fn).parameters) 227 | unused_config_keys = config_keys - param_names 228 | if unused_config_keys: 229 | raise KeyError(f"Unrecognized config entries {unused_config_keys} " 230 | f"for loss {loss_name}.") 231 | 232 | # Check for key collisions between context and config 233 | conflicting_config_keys = config_keys.intersection(context_kwargs) 234 | if conflicting_config_keys: 235 | raise KeyError(f"The config keys {conflicting_config_keys} conflict " 236 | f"with the context parameters ({context_kwargs.keys()}) " 237 | f"for loss {loss_name}.") 238 | 239 | # Construct the arguments for the loss function 240 | kwargs.update(context_kwargs) 241 | kwargs["config_kwargs"] = config_kwargs 242 | 243 | # Call loss 244 | results = loss_fn(**kwargs) 245 | 246 | # Add empty loss_aux_updates if necessary 247 | if isinstance(results, tuple): 248 | loss, loss_aux_update = results 249 | else: 250 | loss, loss_aux_update = results, {} 251 | 252 | return weight, loss, loss_aux_update 253 | 254 | 255 | # -------- Loss functions -------- 256 | @register_loss 257 | def recon(preds: ArrayTree, 258 | targets: ArrayTree, 259 | key: str = "video", 260 | reduction_type: str = "sum", 261 | **_) -> float: 262 | """Reconstruction loss (MSE).""" 263 | inputs = preds["outputs"][key] 264 | targets = targets[key] 265 | loss = F.mse_loss(inputs, targets, reduction=reduction_type) 266 | if reduction_type == "mean": 267 | # This rescaling reflects taking the sum over feature axis & 268 | # mean over space/time axis 269 | loss *= targets.shape[-1] 270 | return torch.mean(loss) 271 | 272 | def recon_loss(preds: ArrayTree, 273 | targets: ArrayTree, 274 | reduction_type: str = "sum") -> float: 275 | """Reconstruction loss (MSE).""" 276 | inputs = preds 277 | targets = targets 278 | loss = F.mse_loss(inputs, targets, reduction=reduction_type) 279 | if reduction_type == "mean": 280 | # This rescaling reflects taking the sum over feature axis & 281 | # mean over space/time axis 282 | loss *= targets.shape[-1] 283 | return loss.mean() 284 | 285 | class Recon_Loss(nn.Module): 286 | 287 | def __init__(self): 288 | super().__init__() 289 | self.l2 = nn.MSELoss(reduction="sum") 290 | 291 | def forward(self, inputs, targets): 292 | # print('in, tar', inputs.shape, targets.shape) 293 | 294 | loss = self.l2(inputs, targets) 295 | return torch.mean(loss) 296 | 297 | # def squared_l2_norm(preds: Array, targets: Array, 298 | # reduction_type: str = "sum") -> Array: 299 | # """Squared L2 norm. 
300 | # reduction: in `["sum", "mean"]` 301 | # """ 302 | # if reduction_type =="sum" -------------------------------------------------------------------------------- /savi/lib/metrics.py: -------------------------------------------------------------------------------- 1 | """Clustering metrics.""" 2 | 3 | # TODO: 4 | 5 | from typing import Optional, Sequence, Union, Dict 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | def check_shape(x, expected_shape: Sequence[Optional[int]], name: str): 13 | """Check whether shape x is as expected. 14 | 15 | Args: 16 | x: Any data type with `shape` attribute. if `shape` sttribute is not present 17 | it is assumed to be a scalar with shape (). 18 | expected shape: The shape that is expected of x. For example, 19 | [None, None, 3] can be the `expected_shape` for a color image, 20 | [4, None, None, 3] if we know that batch size is 4. 21 | name: Name of `x` to provide informative error messages. 22 | Raises: ValueError if x's shape does not match expected_shape. Also raises 23 | ValueError if expected_shape is not a list or tuple. 24 | """ 25 | if not isinstance(expected_shape, (list, tuple)): 26 | raise ValueError( 27 | "expected_shape should be a list or tuple of ints but got " 28 | f"{expected_shape}.") 29 | 30 | # Scalars have shape () by definition. 31 | shape = getattr(x, "shape", ()) 32 | 33 | if (len(shape) != len(expected_shape) or 34 | any(j is not None and i != j for i, j in zip(shape, expected_shape))): 35 | raise ValueError( 36 | f"Input {name} had shape {shape} but {expected_shape} was expected" 37 | ) 38 | 39 | 40 | def _validate_inputs(predicted_segmentations: np.ndarray, 41 | ground_truth_segmentations: np.ndarray, 42 | padding_mask: np.ndarray, 43 | mask: Optional[np.ndarray] = None) -> None: 44 | """Checks that all inputs have the expected shapes. 45 | 46 | Args: 47 | predicted_segmentations: An array of integers of shape [bs, seq_len, H, W] 48 | containing model segmentation predictions. 49 | ground_truth_segmentations: An array of integers of shape [bs, seq_len, H, W] 50 | containing ground truth segmentations. 51 | padding_mask: An array of integers of shape [bs, seq_len, H, W] defining 52 | regions where the ground truth is meaningless, for example because this 53 | corresponds to regions which were padded during data augmentation. 54 | Value 0 corresponds to padded regions, 1 corresponds to valid regions to 55 | be used for metric calculation. 56 | mask: An optional array of boolean mask values of shape [bs]. `True` 57 | corresponds to actual batch examples whereas `False` corresponds to padding. 58 | TODO: what exactly is this ? 59 | 60 | Raises: 61 | ValueError if the inputs are not valid. 62 | """ 63 | 64 | check_shape( 65 | predicted_segmentations, [None, None, None, None], 66 | "predicted_segmentations[bs, seq_len, h, w]") 67 | check_shape( 68 | ground_truth_segmentations, [None, None, None, None], 69 | "ground_truth_segmentations [bs, seq_len, h, w]") 70 | check_shape( 71 | predicted_segmentations, ground_truth_segmentations.shape, 72 | "predicted_segmentations [should match ground_truth_segmentations]") 73 | check_shape( 74 | padding_mask, ground_truth_segmentations.shape, 75 | "padding_mask [should match ground_truth_segmentations]") 76 | 77 | if not np.issubdtype(predicted_segmentations.dtype, np.integer): 78 | raise ValueError("predicted_segmentations has to be integer-valued. 
" 79 | "Got {}".format(predicted_segmentations.dtype)) 80 | 81 | if not np.issubdtype(ground_truth_segmentations.dtype, np.integer): 82 | raise ValueError("ground_truth_segmentations has to be integer-valued. " 83 | "Got {}".format(ground_truth_segmentations.dtype)) 84 | 85 | if not np.issubdtype(padding_mask.dtype, np.integer): 86 | raise ValueError("padding_mask has to be integer_valued. " 87 | "Got {}".format(padding_mask.dtype)) 88 | 89 | if mask is not None: 90 | check_shape(mask, [None], "mask [bs]") 91 | if not np.issubdtype(mask.dtype, np.bool_): 92 | raise ValueError("mask has to be boolean. Got {}".format(mask.dtype)) 93 | 94 | 95 | def adjusted_rand_index(true_ids: np.ndarray, pred_ids: np.ndarray, 96 | num_instances_true: int, num_instances_pred: int, 97 | padding_mask: Optional[np.ndarray] = None, 98 | ignore_background: bool = False) -> np.ndarray: 99 | """Computes the adjusted Rand Index (ARI), a clustering similarity score. 100 | 101 | Args: 102 | true_ids: An integer-valued array of shape 103 | [bs, seq_len, H, W]. The true cluster assignment encoded as integer ids. 104 | pred_ids: An integer-valued array of shape 105 | [bs, seq_len, H, W]. The predicted cluster assignment encoder as integer ids. 106 | num_instances_true: An integer, the number of instances in true_ids 107 | (i.e. max(true_ids) + 1). 108 | num_instances_pred: An integer, the number of instances in true_ids 109 | (i.e. max(pred_ids) + 1). 110 | padding_mask: An array of integers of shape [bs, seq_len, H, W] defining regions 111 | where the ground truth is meaningless, for example because this corresponds to 112 | regions which were padded during data augmentation. Value 0 corresponds to 113 | padded regions, 1 corresponds to valid regions to be used for metric calculation. 114 | ignore_background: Boolean, if True, then ignore all pixels where true_ids == 0 (default: False). 115 | 116 | Returns: 117 | ARI scores as a float32 array of shape [bs]. 118 | """ 119 | 120 | true_oh = F.one_hot(torch.from_numpy(true_ids).long(), num_instances_true) 121 | pred_oh = F.one_hot(torch.from_numpy(pred_ids).long(), num_instances_pred) 122 | if padding_mask is not None: 123 | true_oh = true_oh * padding_mask[..., None] 124 | 125 | if ignore_background: 126 | true_oh = true_oh[..., 1:] # remove the background row 127 | 128 | N = torch.einsum("bthwc,bthwk->bck", true_oh, pred_oh) 129 | A = torch.sum(N, dim=-1) # row-sum (bs, c) 130 | B = torch.sum(N, dim=-2) # col-sum (bs, k) 131 | num_points = torch.sum(A, dim=1) 132 | 133 | rindex = torch.sum(N * (N - 1), dim=1).sum(dim=1) 134 | aindex = torch.sum(A * (A - 1), dim=1) 135 | bindex = torch.sum(B * (B - 1), dim=1) 136 | expected_rindex = aindex * bindex / torch.clip(num_points * (num_points-1), 1) 137 | max_rindex = (aindex + bindex) / 2 138 | denominator = max_rindex - expected_rindex 139 | ari = (rindex - expected_rindex) / denominator 140 | 141 | # There are two cases for which the denominator can be zero: 142 | # 1. If both label_pred and label_true assign all pixels to a single cluster. 143 | # (max_rindex == expected_rindex == rindex == num_points * (num_points-1)) 144 | # 2. If both label_pred and label_true assign max 1 point to each cluster. 145 | # (max_rindex == expected_rindex == rindex == 0) 146 | # In both cases, we want the ARI score to be 1.0: 147 | # return torch.where(denominator, ari, 1.0) 148 | return torch.where(denominator > 0, ari.double(), 1.0) 149 | 150 | class Ari(): 151 | """Adjusted Rand Index (ARI) computed from predictions and labels. 
152 | 153 | ARI is a similarity score to compare two clusterings. ARI returns values in 154 | the range [-1, 1], where 1 corresponds to two identical clusterings (up to 155 | permutation), i.e. a perfect match between the predicted clustering and the 156 | ground-truth clustering. A value of (close to) 0 corresponds to chance. 157 | Negative values correspond to cases where the agreement between the 158 | clusterings is less than expected from a random assignment. 159 | In this implementation, we use ARI to compare predicted instance segmentation 160 | masks (including background prediction) with ground-truth segmentation 161 | annotations. 162 | """ 163 | 164 | @staticmethod 165 | def from_model_output(predicted_segmentations: np.ndarray, 166 | ground_truth_segmentations: np.ndarray, 167 | padding_mask: np.ndarray, 168 | ground_truth_max_num_instances: int, 169 | predicted_max_num_instances: int, 170 | ignore_background: bool = False, 171 | mask: Optional[np.ndarray] = None, 172 | **_): 173 | """Computation of the ARI clustering metric. 174 | 175 | NOTE: This implementation does not currently support padding masks. 176 | Args: 177 | predicted_segmentations: An array of integers of shape 178 | [bs, seq_len, H, W] containing model segmentation predictions. 179 | ground_truth_segmentations: An array of integers of shape 180 | [bs, seq_len, H, W] containing ground truth segmentations. 181 | padding_mask: An array of integers of shape [bs, seq_len, H, W] 182 | defining regions where the ground truth is meaningless, for example 183 | because this corresponds to regions which were padded during data 184 | augmentation. Value 0 corresponds to padded regions, 1 corresponds to 185 | valid regions to be used for metric calculation. 186 | ground_truth_max_num_instances: Maximum number of instances (incl. 187 | background, which counts as the 0-th instance) possible in the dataset. 188 | predicted_max_num_instances: Maximum number of predicted instances (incl. 189 | background). 190 | ignore_background: If True, then ignore all pixels where 191 | ground_truth_segmentations == 0 (default: False). 192 | mask: An optional array of boolean mask values of shape [bs]. `True` 193 | corresponds to actual batch examples whereas `False` corresponds to 194 | padding. 195 | 196 | Returns: 197 | Object of Ari with computed intermediate values. 
198 | """ 199 | _validate_inputs( 200 | predicted_segmentations=predicted_segmentations, 201 | ground_truth_segmentations=ground_truth_segmentations, 202 | padding_mask=padding_mask, 203 | mask=mask) 204 | 205 | batch_size = predicted_segmentations.shape[0] 206 | 207 | if mask is None: 208 | mask = np.ones(batch_size, dtype=padding_mask.dtype) 209 | else: 210 | mask = np.asarray(mask, dtype=padding_mask.dtype) 211 | 212 | ari_batch = adjusted_rand_index( 213 | pred_ids=predicted_segmentations, 214 | true_ids=ground_truth_segmentations, 215 | num_instances_true=ground_truth_max_num_instances, 216 | num_instances_pred=predicted_max_num_instances, 217 | padding_mask=padding_mask, 218 | ignore_background=ignore_background) 219 | 220 | # return cls(total=torch.sum(ari_batch * mask), count=torch.sum(mask)) 221 | return {'total': torch.sum(ari_batch * mask), 'count': np.sum(mask)} 222 | 223 | class AriNoBg(Ari): 224 | """Adjusted Rand Index (ARI), ignoring the ground-truth background label.""" 225 | 226 | @classmethod 227 | def from_model_output(cls, **kwargs): 228 | """See `Ari` dostring for allowed keyword arguments.""" 229 | return super().from_model_output(**kwargs, ignore_background=True) -------------------------------------------------------------------------------- /savi/lib/metrics_jax.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Clustering metrics.""" 16 | 17 | from typing import Optional, Sequence, Union 18 | 19 | from clu import metrics 20 | import flax 21 | import jax 22 | import jax.numpy as jnp 23 | import numpy as np 24 | 25 | Ndarray = Union[np.ndarray, jnp.ndarray] 26 | 27 | 28 | def check_shape(x, expected_shape: Sequence[Optional[int]], name: str): 29 | """Check whether shape x is as expected. 30 | 31 | Args: 32 | x: Any data type with `shape` attribute. If `shape` attribute is not present 33 | it is assumed to be a scalar with shape (). 34 | expected_shape: The shape that is expected of x. For example, 35 | [None, None, 3] can be the `expected_shape` for a color image, 36 | [4, None, None, 3] if we know that batch size is 4. 37 | name: Name of `x` to provide informative error messages. 38 | 39 | Raises: ValueError if x's shape does not match expected_shape. Also raises 40 | ValueError if expected_shape is not a list or tuple. 41 | """ 42 | if not isinstance(expected_shape, (list, tuple)): 43 | raise ValueError( 44 | "expected_shape should be a list or tuple of ints but got " 45 | f"{expected_shape}.") 46 | 47 | # Scalars have shape () by definition. 
48 | shape = getattr(x, "shape", ()) 49 | 50 | if (len(shape) != len(expected_shape) or 51 | any(j is not None and i != j for i, j in zip(shape, expected_shape))): 52 | raise ValueError( 53 | f"Input {name} had shape {shape} but {expected_shape} was expected.") 54 | 55 | 56 | def _validate_inputs(predicted_segmentations: Ndarray, 57 | ground_truth_segmentations: Ndarray, 58 | padding_mask: Ndarray, 59 | mask: Optional[Ndarray] = None) -> None: 60 | """Checks that all inputs have the expected shapes. 61 | 62 | Args: 63 | predicted_segmentations: An array of integers of shape [bs, seq_len, H, W] 64 | containing model segmentation predictions. 65 | ground_truth_segmentations: An array of integers of shape [bs, seq_len, H, 66 | W] containing ground truth segmentations. 67 | padding_mask: An array of integers of shape [bs, seq_len, H, W] defining 68 | regions where the ground truth is meaningless, for example because this 69 | corresponds to regions which were padded during data augmentation. Value 0 70 | corresponds to padded regions, 1 corresponds to valid regions to be used 71 | for metric calculation. 72 | mask: An optional array of boolean mask values of shape [bs]. `True` 73 | corresponds to actual batch examples whereas `False` corresponds to 74 | padding. 75 | 76 | Raises: 77 | ValueError if the inputs are not valid. 78 | """ 79 | 80 | check_shape( 81 | predicted_segmentations, [None, None, None, None], 82 | "predicted_segmentations [bs, seq_len, h, w]") 83 | check_shape( 84 | ground_truth_segmentations, [None, None, None, None], 85 | "ground_truth_segmentations [bs, seq_len, h, w]") 86 | check_shape( 87 | predicted_segmentations, ground_truth_segmentations.shape, 88 | "predicted_segmentations [should match ground_truth_segmentations]") 89 | check_shape( 90 | padding_mask, ground_truth_segmentations.shape, 91 | "padding_mask [should match ground_truth_segmentations]") 92 | 93 | if not jnp.issubdtype(predicted_segmentations.dtype, jnp.integer): 94 | raise ValueError("predicted_segmentations has to be integer-valued. " 95 | "Got {}".format(predicted_segmentations.dtype)) 96 | 97 | if not jnp.issubdtype(ground_truth_segmentations.dtype, jnp.integer): 98 | raise ValueError("ground_truth_segmentations has to be integer-valued. " 99 | "Got {}".format(ground_truth_segmentations.dtype)) 100 | 101 | if not jnp.issubdtype(padding_mask.dtype, jnp.integer): 102 | raise ValueError("padding_mask has to be integer-valued. " 103 | "Got {}".format(padding_mask.dtype)) 104 | 105 | if mask is not None: 106 | check_shape(mask, [None], "mask [bs]") 107 | if not jnp.issubdtype(mask.dtype, jnp.bool_): 108 | raise ValueError("mask has to be boolean. Got {}".format(mask.dtype)) 109 | 110 | 111 | def adjusted_rand_index(true_ids: Ndarray, pred_ids: Ndarray, 112 | num_instances_true: int, num_instances_pred: int, 113 | padding_mask: Optional[Ndarray] = None, 114 | ignore_background: bool = False) -> Ndarray: 115 | """Computes the adjusted Rand index (ARI), a clustering similarity score. 116 | 117 | Args: 118 | true_ids: An integer-valued array of shape 119 | [batch_size, seq_len, H, W]. The true cluster assignment encoded 120 | as integer ids. 121 | pred_ids: An integer-valued array of shape 122 | [batch_size, seq_len, H, W]. The predicted cluster assignment 123 | encoded as integer ids. 124 | num_instances_true: An integer, the number of instances in true_ids 125 | (i.e. max(true_ids) + 1). 126 | num_instances_pred: An integer, the number of instances in true_ids 127 | (i.e. max(pred_ids) + 1). 
128 | padding_mask: An array of integers of shape [batch_size, seq_len, H, W] 129 | defining regions where the ground truth is meaningless, for example 130 | because this corresponds to regions which were padded during data 131 | augmentation. Value 0 corresponds to padded regions, 1 corresponds to 132 | valid regions to be used for metric calculation. 133 | ignore_background: Boolean, if True, then ignore all pixels where 134 | true_ids == 0 (default: False). 135 | 136 | Returns: 137 | ARI scores as a float32 array of shape [batch_size]. 138 | 139 | References: 140 | Lawrence Hubert, Phipps Arabie. 1985. "Comparing partitions" 141 | https://link.springer.com/article/10.1007/BF01908075 142 | Wikipedia 143 | https://en.wikipedia.org/wiki/Rand_index 144 | Scikit Learn 145 | http://scikit-learn.org/stable/modules/generated/sklearn.metrics.adjusted_rand_score.html 146 | """ 147 | 148 | true_oh = jax.nn.one_hot(true_ids, num_instances_true) 149 | pred_oh = jax.nn.one_hot(pred_ids, num_instances_pred) 150 | if padding_mask is not None: 151 | true_oh = true_oh * padding_mask[..., None] 152 | # pred_oh = pred_oh * padding_mask[..., None] # <-- not needed 153 | 154 | if ignore_background: 155 | true_oh = true_oh[..., 1:] # Remove the background row. 156 | 157 | N = jnp.einsum("bthwc,bthwk->bck", true_oh, pred_oh) 158 | A = jnp.sum(N, axis=-1) # row-sum (batch_size, c) 159 | B = jnp.sum(N, axis=-2) # col-sum (batch_size, k) 160 | num_points = jnp.sum(A, axis=1) 161 | 162 | rindex = jnp.sum(N * (N - 1), axis=[1, 2]) 163 | aindex = jnp.sum(A * (A - 1), axis=1) 164 | bindex = jnp.sum(B * (B - 1), axis=1) 165 | expected_rindex = aindex * bindex / jnp.clip(num_points * (num_points-1), 1) 166 | max_rindex = (aindex + bindex) / 2 167 | denominator = max_rindex - expected_rindex 168 | ari = (rindex - expected_rindex) / denominator 169 | 170 | # There are two cases for which the denominator can be zero: 171 | # 1. If both label_pred and label_true assign all pixels to a single cluster. 172 | # (max_rindex == expected_rindex == rindex == num_points * (num_points-1)) 173 | # 2. If both label_pred and label_true assign max 1 point to each cluster. 174 | # (max_rindex == expected_rindex == rindex == 0) 175 | # In both cases, we want the ARI score to be 1.0: 176 | return jnp.where(denominator, ari, 1.0) 177 | 178 | 179 | @flax.struct.dataclass 180 | class Ari(metrics.Average): 181 | """Adjusted Rand Index (ARI) computed from predictions and labels. 182 | 183 | ARI is a similarity score to compare two clusterings. ARI returns values in 184 | the range [-1, 1], where 1 corresponds to two identical clusterings (up to 185 | permutation), i.e. a perfect match between the predicted clustering and the 186 | ground-truth clustering. A value of (close to) 0 corresponds to chance. 187 | Negative values corresponds to cases where the agreement between the 188 | clusterings is less than expected from a random assignment. 189 | 190 | In this implementation, we use ARI to compare predicted instance segmentation 191 | masks (including background prediction) with ground-truth segmentation 192 | annotations. 193 | """ 194 | 195 | @classmethod 196 | def from_model_output(cls, 197 | predicted_segmentations: Ndarray, 198 | ground_truth_segmentations: Ndarray, 199 | padding_mask: Ndarray, 200 | ground_truth_max_num_instances: int, 201 | predicted_max_num_instances: int, 202 | ignore_background: bool = False, 203 | mask: Optional[Ndarray] = None, 204 | **_) -> metrics.Metric: 205 | """Computation of the ARI clustering metric. 
206 | 207 | NOTE: This implementation does not currently support padding masks. 208 | 209 | Args: 210 | predicted_segmentations: An array of integers of shape 211 | [bs, seq_len, H, W] containing model segmentation predictions. 212 | ground_truth_segmentations: An array of integers of shape 213 | [bs, seq_len, H, W] containing ground truth segmentations. 214 | padding_mask: An array of integers of shape [bs, seq_len, H, W] 215 | defining regions where the ground truth is meaningless, for example 216 | because this corresponds to regions which were padded during data 217 | augmentation. Value 0 corresponds to padded regions, 1 corresponds to 218 | valid regions to be used for metric calculation. 219 | ground_truth_max_num_instances: Maximum number of instances (incl. 220 | background, which counts as the 0-th instance) possible in the dataset. 221 | predicted_max_num_instances: Maximum number of predicted instances (incl. 222 | background). 223 | ignore_background: If True, then ignore all pixels where 224 | ground_truth_segmentations == 0 (default: False). 225 | mask: An optional array of boolean mask values of shape [bs]. `True` 226 | corresponds to actual batch examples whereas `False` corresponds to 227 | padding. 228 | 229 | Returns: 230 | Object of Ari with computed intermediate values. 231 | """ 232 | _validate_inputs( 233 | predicted_segmentations=predicted_segmentations, 234 | ground_truth_segmentations=ground_truth_segmentations, 235 | padding_mask=padding_mask, 236 | mask=mask) 237 | 238 | batch_size = predicted_segmentations.shape[0] 239 | 240 | if mask is None or len(mask) == 0: 241 | mask = jnp.ones(batch_size, dtype=padding_mask.dtype) 242 | else: 243 | mask = jnp.asarray(mask, dtype=padding_mask.dtype) 244 | 245 | ari_batch = adjusted_rand_index( 246 | pred_ids=predicted_segmentations, 247 | true_ids=ground_truth_segmentations, 248 | num_instances_true=ground_truth_max_num_instances, 249 | num_instances_pred=predicted_max_num_instances, 250 | padding_mask=padding_mask, 251 | ignore_background=ignore_background) 252 | # return cls(total=jnp.sum(ari_batch * mask), count=jnp.sum(mask)) # pytype: disable=wrong-keyword-args 253 | return {'total': jnp.sum(ari_batch * mask), 'count': jnp.sum(mask)} 254 | 255 | @flax.struct.dataclass 256 | class AriNoBg(Ari): 257 | """Adjusted Rand Index (ARI), ignoring the ground-truth background label.""" 258 | 259 | @classmethod 260 | def from_model_output(cls, **kwargs) -> metrics.Metric: 261 | """See `Ari` docstring for allowed keyword arguments.""" 262 | return super().from_model_output(**kwargs, ignore_background=True) 263 | -------------------------------------------------------------------------------- /savi/lib/utils.py: -------------------------------------------------------------------------------- 1 | """Common utils.""" 2 | 3 | # TODO: 4 | 5 | import functools 6 | import importlib 7 | from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Type, Union 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import numpy as np 13 | import math 14 | 15 | import matplotlib 16 | import matplotlib.pyplot as plt 17 | import skimage.transform 18 | 19 | from savi.lib import metrics 20 | 21 | Array = Union[np.ndarray, torch.Tensor] # FIXME: 22 | ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]] # pytype: disable=not-supported-yet 23 | DictTree = Dict[str, Union[Array, "DictTree"]] # pytype: disable=not-supported-yet 24 | PRNGKey = Array 25 | ConfigAttr = Any 
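# Editor's illustrative sketch (not part of the original file): the Ari metrics in
# savi/lib/metrics.py and savi/lib/metrics_jax.py above return plain
# {'total', 'count'} dicts instead of clu Average objects, so a final score can be
# reduced from a list of per-batch updates like this (hypothetical helper name).
def _reduce_ari(updates):
    """Average per-batch {'total', 'count'} ARI dicts into a single score."""
    total = sum(float(u["total"]) for u in updates)
    count = sum(float(u["count"]) for u in updates)
    return total / max(count, 1.0)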
26 | MetricSpec = Dict[str, str] 27 | 28 | class TrainState: 29 | """Data structure for checkpointing the model.""" 30 | step: int 31 | optimizer: torch.optim.Optimizer 32 | variables: torch.nn.parameter.Parameter 33 | rng: int # FIXME: seed ? 34 | 35 | 36 | # TODO: not sure what to do with this 37 | METRIC_TYPE_TO_CLS = { 38 | "loss": Any, 39 | "ari": metrics.Ari, 40 | "ari_nobg": metrics.AriNoBg 41 | } 42 | 43 | # FIXME: make metrics collection just a dictionary 44 | def make_metrics_collection(metrics_spec: Optional[MetricSpec]) -> Dict[str, Any]: 45 | metrics_dict = {} 46 | if metrics_spec: 47 | for m_name, m_type in metrics_spec.items(): 48 | metrics_dict[m_name] = METRIC_TYPE_TO_CLS[m_type] 49 | 50 | return metrics_dict 51 | 52 | 53 | def _flatten_dict(xs, is_leaf=None, sep=None): 54 | assert isinstance(xs, dict), 'expected (frozen)dict' 55 | 56 | def _key(path): 57 | if sep is None: 58 | return path 59 | return sep.join(path) 60 | 61 | def _flatten(xs, prefix): 62 | if not isinstance(xs, dict) or ( 63 | is_leaf and is_leaf(prefix, xs)): 64 | return {_key(prefix): xs} 65 | result = {} 66 | is_empty = True 67 | for key, value in xs.items(): 68 | is_empty = False 69 | path = prefix + (key,) 70 | result.update(_flatten(value, path)) 71 | return result 72 | return _flatten(xs, ()) 73 | 74 | def flatten_named_dicttree(metrics_res: DictTree, sep: str = "/"): 75 | """Flatten dictionary.""" 76 | metrics_res_flat = {} 77 | for k, v in _flatten_dict(metrics_res).items(): 78 | metrics_res_flat[(sep.join(k)).strip(sep)] = v 79 | return metrics_res_flat 80 | 81 | 82 | # def clip_grads(grad_tree: ArrayTree, max_norm: float, epsilon: float = 1e-6): 83 | # """Gradient clipping with epsilon. 84 | 85 | # """ 86 | 87 | def lecun_uniform_(tensor, gain=1.): 88 | fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(tensor) 89 | var = gain / float(fan_in) 90 | a = math.sqrt(3 * var) 91 | return nn.init._no_grad_uniform_(tensor, -a, a) 92 | 93 | 94 | def lecun_normal_(tensor, gain=1., mode="fan_in"): 95 | fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(tensor) 96 | if mode == "fan_in": 97 | scale_mode = fan_in 98 | elif mode == "fan_out": 99 | scale_mode = fan_out 100 | else: 101 | raise NotImplementedError 102 | var = gain / float(scale_mode) 103 | # constant is stddev of standard normal truncated to (-2, 2) 104 | std = math.sqrt(var) / .87962566103423978 105 | # return nn.init._no_grad_normal_(tensor, 0., std) 106 | kernel = torch.nn.init._no_grad_trunc_normal_(tensor, 0, 1, -2, 2) * std 107 | with torch.no_grad(): 108 | tensor[:] = kernel[:] 109 | return tensor 110 | 111 | def lecun_normal_fan_out_(tensor, gain=1.): 112 | return lecun_normal_(tensor, gain=gain, mode="fan_out") 113 | 114 | def lecun_normal_convtranspose_(tensor, gain=1.): 115 | # for some reason, the convtranspose weights are [in_channels, out_channels, kernel, kernel] 116 | # but the _calculate_fan_in_and_fan_out treats dim 1 as fan_in and dim 0 as fan_out. 
117 | # so, for convolution weights, have to use fan_out instead of fan_in 118 | # which is actually using fan_in instead of fan_out 119 | return lecun_normal_fan_out_(tensor, gain=gain) 120 | 121 | init_fn = { 122 | 'xavier_uniform': nn.init.xavier_uniform_, 123 | 'xavier_normal': nn.init.xavier_normal_, 124 | 'kaiming_uniform': nn.init.kaiming_uniform_, 125 | 'kaiming_normal': nn.init.kaiming_normal_, 126 | 'lecun_uniform': lecun_uniform_, 127 | 'lecun_normal': lecun_normal_, 128 | 'lecun_normal_fan_out': lecun_normal_fan_out_, 129 | 'ones': nn.init.ones_, 130 | 'zeros': nn.init.zeros_, 131 | 'default': lambda x: x} 132 | def init_param(name, gain=1.): 133 | assert name in init_fn.keys(), "not a valid init method" 134 | # return init_fn[name](tensor, gain) 135 | return functools.partial(init_fn[name], gain=gain) 136 | 137 | def spatial_broadcast(x: torch.Tensor, resolution: Sequence[int]) -> Array: 138 | """Broadcast flat inputs to a 2D grid of a given resolution.""" 139 | x = x[:, None, None, :] 140 | # return np.tile(x, [1, resolution[0], resolution[1], 1]) 141 | return torch.tile(x, [1, resolution[0], resolution[1], 1]) 142 | 143 | def broadcast_across_batch(inputs: Array, batch_size: int) -> Array: 144 | """Broadcasts inputs across a batch of examples (creates new axis).""" 145 | return torch.broadcast_to( 146 | torch.unsqueeze(0), 147 | size=(batch_size,) + inputs.shape) 148 | 149 | # def time_distributed(cls, in_axes=1, axis=1): 150 | 151 | def create_gradient_grid( 152 | samples_per_dim: Sequence[int], value_range: Sequence[float] = (-1.0, 1.0) 153 | ) -> Array: 154 | """Creates a tensor with equidistant entries from -1 to +1 in each dim 155 | 156 | Args: 157 | samples_per_dim: Number of points to have along each dimension. 158 | value_range: In each dimension, points will go from range[0] to range[1] 159 | 160 | Returns: 161 | A tensor of shape [samples_per_dim] + [len(samples_per_dim)]. 162 | """ 163 | 164 | s = [np.linspace(value_range[0], value_range[1], n) for n in samples_per_dim] 165 | pe = np.stack(np.meshgrid(*s, sparse=False, indexing="ij"), axis=-1) 166 | return np.array(pe) 167 | 168 | def convert_to_fourier_features(inputs: Array, basis_degree: int) -> Array: 169 | """Convert inputs to Fourier features, e.g. for positional encoding.""" 170 | 171 | # inputs.shape = (..., n_dims). 172 | # inputs should be in range [-pi, pi] or [0, 2pi]. 173 | n_dims = inputs.shape[-1] 174 | 175 | # Generate frequency basis 176 | freq_basis = np.concatenate( # shape = (n_dims, n_dims * basis_degree) 177 | [2**i * np.eye(n_dims) for i in range(basis_degree)], 1) 178 | 179 | # x.shape = (..., n_dims * basis_degree) 180 | x = inputs @ freq_basis # Project inputs onto frequency basis. 181 | 182 | # Obtain Fourier feaures as [sin(x), cos(x)] = [sin(x), sin(x + 0.5 * pi)]. 
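    # Editor's note (illustrative numbers, not from the original code): with
    # n_dims = 2 and basis_degree = 3, freq_basis has shape (2, 6), x has shape
    # (..., 6), and the returned encoding has shape (..., 12), i.e.
    # 2 * n_dims * basis_degree Fourier features per input point.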
183 | return np.sin(np.concatenate([x, x + 0.5 * np.pi], axis=-1)) -------------------------------------------------------------------------------- /savi/main.py: -------------------------------------------------------------------------------- 1 | # from savi.datasets.tfds.load_tfds_data import test 2 | from savi.trainers.tfds_trainer import test 3 | from savi.trainers.tfds_trainer_dataparallel import test 4 | 5 | if __name__ == "__main__": 6 | print('hi') 7 | 8 | test() -------------------------------------------------------------------------------- /savi/modules/__init__.py: -------------------------------------------------------------------------------- 1 | """Module library.""" 2 | 3 | # FIXME 4 | 5 | # Re-export commonly used modules and functions 6 | 7 | from .attention import (GeneralizedDotProductAttention, 8 | InvertedDotProductAttention, SlotAttention, 9 | TransformerBlock, TransformerBlockOld, Transformer) 10 | from .convolution import (CNN, CNN2) 11 | from .decoders import SpatialBroadcastDecoder 12 | from .initializers import (GaussianStateInit, ParamStateInit, 13 | SegmentationEncoderStateInit, 14 | CoordinateEncoderStateInit) 15 | from .misc import (MLP, PositionEmbedding, Readout) 16 | from .video import (FrameEncoder, Processor, SAVi) 17 | from .factory import build_modules as savi_build_modules -------------------------------------------------------------------------------- /savi/modules/attention.py: -------------------------------------------------------------------------------- 1 | """Attention module library.""" 2 | 3 | import functools 4 | from typing import Any, Dict, Iterable, Mapping, Optional, Tuple, Union 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import numpy as np 10 | 11 | from savi.modules import misc 12 | from savi.lib.utils import init_param, init_fn 13 | 14 | Shape = Tuple[int] 15 | 16 | DType = Any 17 | Array = torch.Tensor # np.ndarray 18 | ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]] # pytype: disable=not-supported-yet # TODO: what is this ? 19 | ProcessorState = ArrayTree 20 | PRNGKey = Array 21 | NestedDict = Dict[str, Any] 22 | 23 | 24 | class SlotAttention(nn.Module): 25 | """Slot Attention module. 26 | 27 | Note: This module uses pre-normalization by default. 28 | """ 29 | def __init__(self, 30 | input_size: int, # size of encoded inputs. # FIXME: added for submodules. 31 | slot_size: int, # fixed size. or same as qkv_size. 32 | qkv_size: int = None, # fixed size, or slot size. 
# Optional[int] = None, 33 | num_iterations: int = 1, 34 | mlp_size: Optional[int] = None, 35 | epsilon: float = 1e-8, 36 | num_heads: int = 1, 37 | weight_init: str = 'xavier_uniform' 38 | ): 39 | super().__init__() 40 | 41 | self.input_size = input_size 42 | self.slot_size = slot_size 43 | self.qkv_size = qkv_size if qkv_size is not None else slot_size 44 | self.num_iterations = num_iterations 45 | self.mlp_size = mlp_size 46 | self.epsilon = epsilon 47 | self.num_heads = num_heads 48 | self.weight_init = weight_init 49 | # other definitions 50 | self.head_dim = qkv_size // self.num_heads 51 | 52 | # shared modules 53 | ## gru 54 | self.gru = misc.myGRUCell(slot_size, slot_size, weight_init=weight_init) 55 | 56 | ## weights 57 | self.dense_q = nn.Linear(slot_size, qkv_size, bias=False) 58 | self.dense_k = nn.Linear(input_size, qkv_size, bias=False) 59 | self.dense_v = nn.Linear(input_size, qkv_size, bias=False) 60 | init_fn[weight_init['linear_w']](self.dense_q.weight) 61 | init_fn[weight_init['linear_w']](self.dense_k.weight) 62 | init_fn[weight_init['linear_w']](self.dense_v.weight) 63 | 64 | ## layernorms 65 | self.layernorm_q = nn.LayerNorm(qkv_size, eps=1e-6) 66 | self.layernorm_input = nn.LayerNorm(input_size, eps=1e-6) 67 | 68 | ## attention 69 | self.inverted_attention = InvertedDotProductAttention( 70 | input_size=qkv_size, output_size=slot_size, 71 | num_heads=self.num_heads, norm_type="mean", 72 | epsilon=epsilon, weight_init=weight_init) 73 | 74 | ## output transform 75 | if self.mlp_size is not None: 76 | self.mlp = misc.MLP( 77 | input_size=slot_size, hidden_size=self.mlp_size, 78 | output_size=slot_size, layernorm="pre", residual=True, 79 | weight_init=weight_init) 80 | 81 | def forward(self, slots: Array, inputs: Array, 82 | padding_mask: Optional[Array] = None) -> Array: 83 | """Slot Attention module forward pass.""" 84 | del padding_mask # Unused. 85 | 86 | B, O, D = slots.shape 87 | _, L, M = inputs.shape 88 | 89 | # inputs.shape = (b, n_inputs, input_size). 90 | inputs = self.layernorm_input(inputs) 91 | # k.shape = (b, n_inputs, num_heads, head_dim). 92 | k = self.dense_k(inputs).view(B, L, self.num_heads, self.head_dim) 93 | # v.shape = (b, n_inputs, num_heads, head_dim). 94 | v = self.dense_v(inputs).view(B, L, self.num_heads, self.head_dim) 95 | 96 | # Multiple rounds of attention. 97 | for _ in range(self.num_iterations): 98 | 99 | # Inverted dot-product attention. 100 | slots_n = self.layernorm_q(slots) 101 | ## q.shape = (b, num_objects, num_heads, qkv_size). 102 | q = self.dense_q(slots_n).view(B, O, self.num_heads, self.head_dim) 103 | updates, attn = self.inverted_attention(query=q, key=k, value=v) 104 | 105 | # Recurrent update. 106 | slots = self.gru( 107 | updates.reshape(-1, D), 108 | slots.reshape(-1, D)) 109 | slots = slots.reshape(B, -1, D) 110 | 111 | # Feedforward block with pre-normalization. 112 | if self.mlp_size is not None: 113 | slots = self.mlp(slots) 114 | 115 | return slots, attn 116 | 117 | def compute_attention(self, slots, inputs): 118 | """Slot Attention without GRU and iteration.""" 119 | # inputs.shape = (b, n_inputs, input_size). 
120 | B, O, D = slots.shape 121 | _, L, M = inputs.shape 122 | inputs = self.layernorm_input(inputs) 123 | slots = self.layernorm_q(slots) 124 | q = self.dense_q(slots).view(B, O, self.num_heads, self.head_dim) 125 | k = self.dense_k(inputs).view(B, L, self.num_heads, self.head_dim) 126 | v = self.dense_v(inputs).view(B, L, self.num_heads, self.head_dim) 127 | updated_slots, attn = self.inverted_attention(query=q, key=k, value=v) 128 | 129 | # updated_slots [B Q S], attn TODO: shape 130 | return updated_slots, attn 131 | 132 | class InvertedDotProductAttention(nn.Module): 133 | """Inverted version of dot-product attention (softmax over query axis).""" 134 | 135 | def __init__(self, 136 | input_size: int, # qkv_size # FIXME: added for submodules 137 | output_size: int, # FIXME: added for submodules 138 | num_heads: Optional[int] = 1, # FIXME: added for submodules 139 | norm_type: Optional[str] = "mean", # mean, layernorm, or None 140 | # multi_head: bool = False, # FIXME: can infer from num_heads. 141 | epsilon: float = 1e-8, 142 | dtype: DType = torch.float32, 143 | weight_init = None 144 | # precision # not used 145 | ): 146 | super().__init__() 147 | 148 | assert num_heads >= 1 and isinstance(num_heads, int) 149 | 150 | self.input_size = input_size 151 | self.output_size = output_size 152 | self.norm_type = norm_type 153 | self.num_heads = num_heads 154 | self.multi_head = True if num_heads > 1 else False 155 | self.epsilon = epsilon 156 | self.dtype = dtype 157 | self.weight_init = weight_init 158 | # other definitions 159 | self.head_dim = input_size // self.num_heads 160 | 161 | # submodules 162 | self.attn_fn = GeneralizedDotProductAttention( 163 | inverted_attn=True, 164 | renormalize_keys=True if self.norm_type == "mean" else False, 165 | epsilon=self.epsilon, 166 | dtype=self.dtype) 167 | if self.multi_head: 168 | self.dense_o = nn.Linear(input_size, output_size, bias=False) 169 | init_fn[weight_init['linear_w']](self.dense_o.weight) 170 | if self.norm_type == "layernorm": 171 | self.layernorm = nn.LayerNorm(output_size, eps=1e-6) 172 | 173 | def forward(self, query: Array, key: Array, value: Array) -> Array: 174 | """Computes inverted dot-product attention. 175 | 176 | Args: 177 | qk_features = [num_heads, head_dim] = qkv_dim 178 | query: Queries with shape of `[batch, q_num, qk_features]`. 179 | key: Keys with shape of `[batch, kv_num, qk_features]`. 180 | value: Values with shape of `[batch, kv_num, v_features]`. 181 | train: Indicating whether we're training or evaluating. 182 | 183 | Returns: 184 | Output of shape `[batch, n_queries, v_features]` 185 | """ 186 | B, Q = query.shape[:2] 187 | 188 | # Apply attention mechanism 189 | output, attn = self.attn_fn(query=query, key=key, value=value) 190 | 191 | if self.multi_head: 192 | # Multi-head aggregation. Equivalent to concat + dense layer. 193 | output = self.dense_o(output.view(B, Q, self.input_size)).view(B, Q, self.output_size) 194 | else: 195 | # Remove head dimension. 196 | output = output.squeeze(-2) 197 | 198 | if self.norm_type == "layernorm": 199 | output = self.layernorm(output) 200 | 201 | return output, attn 202 | 203 | 204 | class GeneralizedDotProductAttention(nn.Module): 205 | """Multi-head dot-product attention with customizable normalization axis. 206 | 207 | This module supports logging of attention weights in a variable collection. 
208 | """ 209 | 210 | def __init__(self, 211 | dtype: DType = torch.float32, 212 | # precision: Optional[] # not used 213 | epsilon: float = 1e-8, 214 | inverted_attn: bool = False, 215 | renormalize_keys: bool = False, 216 | attn_weights_only: bool = False 217 | ): 218 | super().__init__() 219 | 220 | self.dtype = dtype 221 | self.epsilon = epsilon 222 | self.inverted_attn = inverted_attn 223 | self.renormalize_keys = renormalize_keys 224 | self.attn_weights_only = attn_weights_only 225 | 226 | def forward(self, query: Array, key: Array, value: Array, 227 | train: bool = False, **kwargs) -> Array: 228 | """Computes multi-head dot-product attention given query, key, and value. 229 | 230 | Args: 231 | query: Queries with shape of `[batch..., q_num, num_heads, qk_features]`. 232 | key: Keys with shape of `[batch..., kv_num, num_heads, qk_features]`. 233 | value: Values with shape of `[batch..., kv_num, num_heads, v_features]`. 234 | train: Indicating whether we're training or evaluating. 235 | **kwargs: Additional keyword arguments are required when used as attention 236 | function in nn.MultiHeadDotPRoductAttention, but they will be ignored here. 237 | 238 | Returns: 239 | Output of shape `[batch..., q_num, num_heads, v_features]`. 240 | """ 241 | del train # Unused. 242 | 243 | assert query.ndim == key.ndim == value.ndim, ( 244 | "Queries, keys, and values must have the same rank.") 245 | assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], ( 246 | "Query, key, and value batch dimensions must match.") 247 | assert query.shape[-2] == key.shape[-2] == value.shape[-2], ( 248 | "Query, key, and value num_heads dimensions must match.") 249 | assert key.shape[-3] == value.shape[-3], ( 250 | "Key and value cardinality dimensions must match.") 251 | assert query.shape[-1] == key.shape[-1], ( 252 | "Query and key feature dimensions must match.") 253 | 254 | if kwargs.get("bias") is not None: 255 | raise NotImplementedError( 256 | "Support for masked attention is not yet implemented.") 257 | 258 | if "dropout_rate" in kwargs: 259 | if kwargs["dropout_rate"] > 0.: 260 | raise NotImplementedError("Support for dropout is not yet implemented.") 261 | 262 | # Temperature normalization. 263 | qk_features = query.shape[-1] 264 | query = query / (qk_features ** 0.5) # torch.sqrt(qk_features) 265 | 266 | # attn.shape = (batch..., num_heads, q_num, kv_num) 267 | attn = torch.matmul(query.permute(0, 2, 1, 3), key.permute(0, 2, 3, 1)) # bhqd @ bhdk -> bhqk 268 | 269 | if self.inverted_attn: 270 | attention_dim = -2 # Query dim 271 | else: 272 | attention_dim = -1 # Key dim 273 | 274 | # Softmax normalization (by default over key dim) 275 | attn = torch.softmax(attn, dim=attention_dim, dtype=self.dtype) 276 | 277 | if self.renormalize_keys: 278 | # Corresponds to value aggregation via weighted mean (as opposed to sum). 279 | normalizer = torch.sum(attn, axis=-1, keepdim=True) + self.epsilon 280 | attn_n = attn / normalizer 281 | else: 282 | attn_n = attn 283 | 284 | if self.attn_weights_only: 285 | return attn_n 286 | 287 | # Aggregate values using a weighted sum with weights provided by `attn` 288 | updates = torch.einsum("bhqk,bkhd->bqhd", attn_n, value) 289 | 290 | return updates, attn # FIXME: return attention too, as no option for intermediate storing in module in torch. 
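# --- Editor's illustrative sketch (not part of the original file) -----------------
# Shape check for the inverted-attention path used by SlotAttention above: queries
# come from the slots, keys/values from the flattened frame features. Sizes below
# are made up for illustration (e.g. 11 slots over 64*64 feature locations).
def _demo_inverted_attention():
    attn_fn = GeneralizedDotProductAttention(inverted_attn=True, renormalize_keys=True)
    B, Q, K, H, D = 2, 11, 64 * 64, 1, 128  # batch, slots, inputs, heads, head_dim
    q = torch.randn(B, Q, H, D)
    k = torch.randn(B, K, H, D)
    v = torch.randn(B, K, H, D)
    updates, attn = attn_fn(query=q, key=k, value=v)
    assert updates.shape == (B, Q, H, D)  # per-slot updates, aggregated over inputs
    assert attn.shape == (B, H, Q, K)     # softmax taken over the slot (query) axis
# -----------------------------------------------------------------------------------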
291 | 292 | 293 | class Transformer(nn.Module): 294 | """Transformer with multiple blocks.""" 295 | 296 | def __init__(self, 297 | embed_dim: int, # FIXME: added for submodules 298 | num_heads: int, 299 | qkv_size: int, 300 | mlp_size: int, 301 | num_layers: int, 302 | pre_norm: bool = False 303 | ): 304 | super().__init__() 305 | 306 | self.num_heads = num_heads 307 | self.qkv_size = qkv_size 308 | self.mlp_size = mlp_size 309 | self.num_layers = num_layers 310 | self.pre_norm = pre_norm 311 | 312 | # submodules 313 | self.model = nn.ModuleList() 314 | for lyr in range(self.num_layers): 315 | self.model.add_module( 316 | name=f"TransformerBlock_{lyr}", 317 | module=TransformerBlock( 318 | embed_dim=embed_dim, num_heads=num_heads, 319 | qkv_size=qkv_size, mlp_size=mlp_size, 320 | pre_norm=pre_norm) 321 | ) 322 | 323 | def forward(self, queries: Array, inputs: Optional[Array] = None, 324 | padding_mask: Optional[Array] = None, 325 | train: bool = False) -> Array: 326 | x = queries 327 | for layer in self.model: 328 | x = layer(x, inputs, padding_mask, train) 329 | return x 330 | 331 | 332 | class TransformerBlockOld(nn.Module): 333 | """Transformer decoder block.""" 334 | 335 | def __init__(self, 336 | embed_dim: int, # FIXME: added for submodules 337 | num_heads: int, 338 | qkv_size: int, 339 | mlp_size: int, 340 | pre_norm: bool = False, 341 | cross_attn: bool = False 342 | ): 343 | super().__init__() 344 | 345 | self.num_heads = num_heads 346 | self.qkv_size = qkv_size 347 | self.mlp_size = mlp_size 348 | self.pre_norm = pre_norm 349 | 350 | # submodules 351 | ## MHA 352 | self.attn_self = nn.MultiheadAttention( 353 | embed_dim=embed_dim, num_heads=num_heads, batch_first=True) 354 | self.attn_cross = nn.MultiheadAttention( 355 | embed_dim=embed_dim, num_heads=num_heads, batch_first=True) if cross_attn else None 356 | ## mlps 357 | self.mlp = misc.MLP( 358 | input_size=embed_dim, hidden_size=mlp_size, 359 | output_size=embed_dim) 360 | ## layernorms 361 | self.layernorm_query = nn.LayerNorm(embed_dim, eps=1e-6) 362 | self.layernorm_inputs = nn.LayerNorm(embed_dim, eps=1e-6) if cross_attn else None 363 | self.layernorm_mlp = nn.LayerNorm(embed_dim, eps=1e-6) 364 | 365 | def forward(self, queries: Array, inputs: Optional[Array] = None, 366 | padding_mask: Optional[Array] = None, 367 | train: bool = False) -> Array: 368 | del padding_mask, train # Unused. 369 | assert queries.ndim == 3 370 | 371 | if self.pre_norm: 372 | # Self-attention on queries. 373 | x = self.layernorm_query(queries) 374 | x, _ = self.attn_self(query=x, key=x, value=x) 375 | x = x + queries 376 | 377 | # Cross-attention on inputs. 378 | if inputs is not None: 379 | assert inputs.ndim == 3 380 | y = self.layernorm_inputs(x) 381 | y, _ = self.attn_cross(query=y, key=inputs, value=inputs) 382 | y = y + x 383 | else: 384 | y = x 385 | 386 | # MLP 387 | z = self.layernorm_mlp(y) 388 | z = self.mlp(z) 389 | z = z + y 390 | else: 391 | # Self-attention on queries. 392 | x = queries 393 | x, _ = self.attn_self(query=x, key=x, value=x) 394 | x = x + queries 395 | x = self.layernorm_query(x) 396 | 397 | # Cross-attention on inputs.
398 | if inputs is not None: 399 | assert inputs.ndim == 3 400 | y, _ = self.attn_cross(query=x, key=inputs, value=inputs) 401 | y = y + x 402 | y = self.layernorm_inputs(y) 403 | else: 404 | y = x 405 | 406 | # MLP 407 | z = self.mlp(y) 408 | z = z + y 409 | z = self.layernorm_mlp(z) 410 | return z 411 | 412 | 413 | class TransformerBlock(nn.Module): 414 | """Tranformer decoder block.""" 415 | 416 | def __init__(self, 417 | embed_dim: int, # FIXME: added for submodules 418 | num_heads: int, 419 | qkv_size: int, 420 | mlp_size: int, 421 | pre_norm: bool = False, 422 | weight_init = None 423 | ): 424 | super().__init__() 425 | 426 | self.embed_dim = embed_dim 427 | self.qkv_size = qkv_size 428 | self.mlp_size = mlp_size 429 | self.num_heads = num_heads 430 | self.pre_norm = pre_norm 431 | self.weight_init = weight_init 432 | 433 | assert num_heads >= 1 434 | assert qkv_size % num_heads == 0, "embed dim must be divisible by num_heads" 435 | self.head_dim = qkv_size // num_heads 436 | 437 | # submodules 438 | ## MHA # 439 | self.attn = GeneralizedDotProductAttention() 440 | ## mlps 441 | self.mlp = misc.MLP( 442 | input_size=embed_dim, hidden_size=mlp_size, 443 | output_size=embed_dim, weight_init=weight_init) 444 | ## layernorms 445 | self.layernorm_query = nn.LayerNorm(embed_dim, eps=1e-6) 446 | self.layernorm_mlp = nn.LayerNorm(embed_dim, eps=1e-6) 447 | ## weights 448 | self.dense_q = nn.Linear(embed_dim, qkv_size) 449 | self.dense_k = nn.Linear(embed_dim, qkv_size) 450 | self.dense_v = nn.Linear(embed_dim, qkv_size) 451 | init_fn[weight_init['linear_w']](self.dense_q.weight) 452 | init_fn[weight_init['linear_b']](self.dense_q.bias) 453 | init_fn[weight_init['linear_w']](self.dense_k.weight) 454 | init_fn[weight_init['linear_b']](self.dense_k.bias) 455 | init_fn[weight_init['linear_w']](self.dense_v.weight) 456 | init_fn[weight_init['linear_b']](self.dense_v.bias) 457 | if self.num_heads > 1: 458 | self.dense_o = nn.Linear(qkv_size, embed_dim) 459 | # nn.init.xavier_uniform_(self.w_o.weight) 460 | init_fn[weight_init['linear_w']](self.dense_o.weight) 461 | init_fn[weight_init['linear_b']](self.dense_o.bias) 462 | self.multi_head = True 463 | else: 464 | self.multi_head = False 465 | 466 | def forward(self, inputs: Array) -> Array: # TODO: add general attention for q, k, v, not just for x = qkv 467 | assert inputs.ndim == 3 468 | 469 | B, L, _ = inputs.shape 470 | head_dim = self.embed_dim // self.num_heads 471 | 472 | if self.pre_norm: 473 | # Self-attention. 474 | x = self.layernorm_query(inputs) 475 | q = self.dense_q(x).view(B, L, self.num_heads, head_dim) 476 | k = self.dense_k(x).view(B, L, self.num_heads, head_dim) 477 | v = self.dense_v(x).view(B, L, self.num_heads, head_dim) 478 | x, _ = self.attn(query=q, key=k, value=v) 479 | if self.multi_head: 480 | x = self.dense_o(x.reshape(B, L, self.qkv_size)).view(B, L, self.embed_dim) 481 | else: 482 | x = x.squeeze(-2) 483 | x = x + inputs 484 | 485 | y = x 486 | 487 | # MLP 488 | z = self.layernorm_mlp(y) 489 | z = self.mlp(z) 490 | z = z + y 491 | else: 492 | # Self-attention on queries. 
493 | x = inputs 494 | q = self.dense_q(x).view(B, L, self.num_heads, head_dim) 495 | k = self.dense_k(x).view(B, L, self.num_heads, head_dim) 496 | v = self.dense_v(x).view(B, L, self.num_heads, head_dim) 497 | x, _ = self.attn(query=q, key=k, value=v) 498 | if self.multi_head: 499 | x = self.dense_o(x.reshape(B, L, self.qkv_size)).view(B, L, self.embed_dim) 500 | else: 501 | x = x.squeeze(-2) 502 | x = x + inputs 503 | x = self.layernorm_query(x) 504 | 505 | y = x 506 | 507 | # MLP 508 | z = self.mlp(y) 509 | z = z + y 510 | z = self.layernorm_mlp(z) 511 | return z -------------------------------------------------------------------------------- /savi/modules/convolution.py: -------------------------------------------------------------------------------- 1 | """Convolutional module library.""" 2 | 3 | # FIXME 4 | 5 | import functools 6 | from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import numpy as np 12 | import torchvision.transforms as transforms 13 | import math 14 | 15 | from savi.lib.utils import init_fn 16 | 17 | Shape = Tuple[int] 18 | 19 | DType = Any 20 | Array = torch.Tensor 21 | ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]] # pytype: disable=not-supported-yet 22 | ProcessorState = ArrayTree 23 | PRNGKey = Array 24 | NestedDict = Dict[str, Any] 25 | 26 | class CNN(nn.Module): 27 | """Flexible CNN model with conv. and normalization layers.""" 28 | 29 | # TODO: add padding ? 30 | def __init__(self, 31 | features: Sequence[int], # FIXME: [in_channels, *out_channels] 32 | kernel_size: Sequence[Tuple[int, int]], 33 | strides: Sequence[Tuple[int, int]], 34 | layer_transpose: Sequence[bool], 35 | transpose_double: bool = True, 36 | padding: Union[Sequence[Tuple[int, int]], str] = None, 37 | activation_fn: Callable[[Array], Array] = nn.ReLU, 38 | norm_type: Optional[str] = None, 39 | axis_name: Optional[str] = None, # Over which axis to aggregate batch stats. 40 | output_size: Optional[int] = None, 41 | weight_init = None 42 | ): 43 | super().__init__() 44 | 45 | self.features = features 46 | self.kernel_size = kernel_size 47 | self.strides = strides 48 | self.layer_transpose = layer_transpose 49 | self.transpose_double = transpose_double 50 | self.padding = padding 51 | self.activation_fn = activation_fn 52 | self.norm_type = norm_type 53 | self.axis_name = axis_name 54 | self.output_size = output_size 55 | self.weight_init = weight_init 56 | 57 | # submodules 58 | num_layers = len(features) - 1 # account for input features (channels) 59 | 60 | if padding is None: 61 | padding = 0 62 | if isinstance(padding, int) or isinstance(padding, str): 63 | padding = [padding for _ in range(num_layers)] 64 | self.padding = padding 65 | 66 | assert num_layers >= 1, "Need to have at least one layer." 
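        # Editor's note (illustrative, not from the original code): for example,
        # features=[3, 32, 32] with kernel_size=[(5, 5), (5, 5)], strides=[(1, 1), (1, 1)]
        # and layer_transpose=[False, False] describes two conv layers mapping
        # 3 -> 32 -> 32 channels, so num_layers = len(features) - 1 = 2.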
67 | assert len(kernel_size) == num_layers, ( 68 | f"len(kernel_size): {len(kernel_size)} and len(features): {len(features)} must match.") 69 | assert len(strides) == num_layers, ( 70 | f"len(strides): {len(strides)} and len(features): {len(features)} must match.") 71 | assert len(layer_transpose) == num_layers, ( 72 | f"len(layer_transpose): {len(layer_transpose)} and len(features): {len(features)} must match.") 73 | 74 | if self.norm_type: 75 | assert self.norm_type in {"batch", "group", "instance", "layer"}, ( 76 | f"({self.norm_type}) is not a valid normalization type") 77 | 78 | # Whether transpose conv or regular conv 79 | conv_module = {False: nn.Conv2d, True: nn.ConvTranspose2d} 80 | 81 | if self.norm_type == "batch": 82 | norm_module = functools.partial(nn.BatchNorm2d, momentum=0.9) 83 | elif self.norm_type == "group": 84 | norm_module = lambda x: nn.GroupNorm(num_groups=32, num_channels=x) 85 | elif self.norm_type == "layer": 86 | norm_module = functools.partial(nn.LayerNorm, eps=1e-6) 87 | elif self.norm_type == "instance": 88 | norm_module = functools.partial(nn.InstanceNorm2d) 89 | 90 | # model 91 | ## Convnet Architecture. 92 | self.cnn_layers = nn.ModuleList() 93 | for i in range(num_layers): 94 | 95 | ### Convolution Layer. 96 | convname = "convtranspose" if layer_transpose[i] else "conv" 97 | pad = padding[i] 98 | if "convtranspose" == convname and isinstance(pad, str): 99 | pad = 0 100 | name = f"{convname}_{i}" 101 | module = conv_module[self.layer_transpose[i]]( 102 | in_channels=features[i], out_channels=features[i+1], 103 | kernel_size=kernel_size[i], stride=strides[i], padding=pad, 104 | bias=False if norm_type else True) 105 | self.cnn_layers.add_module(name, module) 106 | 107 | # init conv layer weights. 108 | # nn.init.xavier_uniform_(module.weight) 109 | init_fn[weight_init['conv_w']](module.weight) 110 | if not norm_type: 111 | init_fn[weight_init['conv_b']](module.bias) 112 | 113 | ### Normalization Layer. 
114 | if self.norm_type: 115 | self.cnn_layers.add_module( 116 | f"{self.norm_type}_norm_{i}", 117 | norm_module(features[i+1])) 118 | 119 | ### Activation Layer 120 | self.cnn_layers.add_module( 121 | f"activ_{i}", 122 | activation_fn()) 123 | 124 | ## Final Dense Layer 125 | if self.output_size: 126 | self.project_to_output = nn.Linear(features[-1], self.output_size, bias=True) 127 | # nn.init.xavier_uniform_(self.project_to_output.weight) 128 | init_fn[weight_init['linear_w']](self.project_to_output.weight) 129 | init_fn[weight_init['linear_b']](self.project_to_output.bias) 130 | 131 | def forward(self, inputs: Array, channels_last=False) -> Tuple[Dict[str, Array]]: 132 | if channels_last: 133 | # inputs.shape = (batch_size, height, width, n_channels) 134 | inputs = inputs.permute((0, 3, 1, 2)) 135 | # inputs.shape = (batch_size, n_channels, height, width) 136 | 137 | x = inputs 138 | for name, layer in self.cnn_layers.named_children(): 139 | layer_fn = lambda x_in: layer(x_in) 140 | if "convtranspose" in name and self.transpose_double: 141 | output_shape = (x.shape[-2]*2, x.shape[-1]*2) 142 | layer_fn = lambda x_in: layer(x_in, output_size=output_shape) 143 | x = layer_fn(x) 144 | # if inputs.get_device() == 0: 145 | # print(name, inputs.max().item(), inputs.min().item(), 146 | # x.max().item(), x.min().item()) 147 | 148 | if channels_last: 149 | # x.shape = (batch_size, n_features, h*, w*) 150 | x = x.permute((0, 2, 3, 1)) 151 | # x.shape = (batch_size, h*, w*, n_features) 152 | 153 | if self.output_size: 154 | x = self.project_to_output(x) 155 | 156 | return x 157 | 158 | class CNN2(nn.Module): 159 | """New CNN module because above wasn't too flexible in torch.""" 160 | 161 | def __init__(self, 162 | conv_modules: nn.ModuleList, 163 | transpose_modules = None, 164 | activation_fn: nn.Module = nn.ReLU, 165 | norm_type: Optional[str] = None, 166 | output_size: Optional[str] = None, 167 | weight_init = None 168 | ): 169 | super().__init__() 170 | 171 | self.transpose_modules = transpose_modules 172 | self.activation = activation_fn 173 | self.norm_type = norm_type 174 | self.output_size = output_size 175 | self.weight_init = weight_init 176 | self.features = [c.out_channels for c in conv_modules.children()] 177 | 178 | # submodules 179 | num_layers = len(conv_modules) 180 | 181 | # check if there are transposed convolutions (needed for weight init) 182 | if transpose_modules is not None: 183 | assert len(transpose_modules) == num_layers, ( 184 | "need to specify which modules are transposed convolutions.") 185 | else: 186 | transpose_modules = [False for _ in range(num_layers)] 187 | self.transpose_modules = transpose_modules 188 | init_map = {True: 'convtranspose_w', False: 'conv_w'} 189 | 190 | if self.norm_type: 191 | assert self.norm_type in {"batch", "group", "instance", "layer"}, ( 192 | f"({self.norm_type}) is not a valid normalization type") 193 | 194 | if self.norm_type == "batch": 195 | norm_module = functools.partial(nn.BatchNorm2d, momentum=0.9) 196 | elif self.norm_type == "group": 197 | norm_module = lambda x: nn.GroupNorm(num_groups=32, num_channels=x) 198 | elif self.norm_type == "layer": 199 | norm_module = functools.partial(nn.LayerNorm, eps=1e-6) 200 | elif self.norm_type == "instance": 201 | norm_module = functools.partial(nn.InstanceNorm2d) 202 | 203 | # model 204 | ## Convnet Architecture. 
205 | self.cnn_layers = nn.ModuleList() 206 | for i in range(num_layers): 207 | ### Conv 208 | name = f"conv_{i}" 209 | conv = conv_modules[i] 210 | init_fn[weight_init[init_map[transpose_modules[i]]]](conv.weight) 211 | if conv.bias is not None: 212 | init_fn[weight_init['conv_b']](conv.bias) 213 | self.cnn_layers.add_module(name, conv) 214 | 215 | ### Normalization (if exists) 216 | if self.norm_type: 217 | self.cnn_layers.add_module( 218 | f"{self.norm_type}_norm_{i}", 219 | norm_module(self.features[i])) 220 | 221 | ### Activation 222 | self.cnn_layers.add_module( 223 | f"act_{i}", 224 | activation_fn()) 225 | 226 | ## Final Dense Layer (if exists) 227 | if self.output_size: 228 | self.project_to_output = nn.Linear(self.features[-1], self.outptu_size, bias=True) 229 | init_fn[weight_init['linear_w']](self.project_to_output.weight) 230 | init_fn[weight_init['linear_b']](self.project_to_output.bias) 231 | 232 | def forward(self, inputs: Array, channels_last=False) -> Tuple[Dict[str, Array]]: 233 | if channels_last: 234 | # inputs.shape = (batch_size, height, width, n_channels) 235 | inputs = inputs.permute((0, 3, 1, 2)) 236 | # inputs.shape = (batch_size, n_channels, height, width) 237 | 238 | x = inputs 239 | for name, layer in self.cnn_layers.named_children(): 240 | x = layer(x) 241 | 242 | if channels_last: 243 | # x.shape = (batch_size, n_features, h*, w*) 244 | x = x.permute((0, 2, 3, 1)) 245 | # x.shape = (batch_size, h*, w*, n_features) 246 | 247 | if self.output_size: 248 | x = self.project_to_output(x) 249 | 250 | return x -------------------------------------------------------------------------------- /savi/modules/decoders.py: -------------------------------------------------------------------------------- 1 | """Decoder module library.""" 2 | 3 | # FIXME 4 | 5 | import functools 6 | from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union 7 | from pyparsing import alphas 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import numpy as np 13 | 14 | from savi.lib import utils 15 | from savi.lib.utils import init_fn 16 | 17 | Shape = Tuple[int] 18 | 19 | DType = Any 20 | Array = torch.Tensor 21 | ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]] # pytype: disable=not-supported-yet 22 | ProcessorState = ArrayTree 23 | PRNGKey = Array 24 | NestedDict = Dict[str, Any] 25 | 26 | 27 | 28 | class SpatialBroadcastDecoder(nn.Module): 29 | """Spatial broadcast decoder for as set of slots (per frame).""" 30 | 31 | def __init__(self, 32 | resolution: Sequence[int], 33 | backbone: nn.Module, 34 | pos_emb: nn.Module, 35 | target_readout: nn.Module = None, 36 | weight_init = None 37 | ): 38 | super().__init__() 39 | 40 | self.resolution = resolution 41 | self.backbone = backbone 42 | self.pos_emb = pos_emb 43 | self.target_readout = target_readout 44 | self.weight_init = weight_init 45 | 46 | # submodules 47 | self.mask_pred = nn.Linear(self.backbone.features[-1], 1) 48 | # nn.init.xavier_uniform_(self.mask_pred.weight) 49 | init_fn[weight_init['linear_w']](self.mask_pred.weight) 50 | init_fn[weight_init['linear_b']](self.mask_pred.bias) 51 | 52 | def forward(self, slots: Array) -> Array: 53 | 54 | batch_size, n_slots, n_features = slots.shape 55 | 56 | # Fold slot dim into batch dim. 57 | x = slots.reshape(shape=(batch_size * n_slots, n_features)) 58 | 59 | # Spatial broadcast with position embedding. 
60 | x = utils.spatial_broadcast(x, self.resolution) 61 | x = self.pos_emb(x) 62 | 63 | # bb_features.shape = (batch_size * n_slots, h, w, c) 64 | bb_features = self.backbone(x, channels_last=True) 65 | spatial_dims = bb_features.shape[-3:-1] 66 | 67 | alpha_logits = self.mask_pred( # take each feature separately 68 | bb_features.reshape(shape=(-1, bb_features.shape[-1]))) 69 | alpha_logits = alpha_logits.reshape( 70 | shape=(batch_size, n_slots, *spatial_dims, -1)) # (B O H W 1) 71 | 72 | alpha_mask = alpha_logits.softmax(dim=1) 73 | 74 | # TODO: figure out what to do with readout. 75 | targets_dict = self.target_readout(bb_features.reshape(shape=(-1, bb_features.shape[-1]))) 76 | 77 | preds_dict = dict() 78 | for target_key, channels in targets_dict.items(): 79 | 80 | # channels.shape = (batch_size, n_slots, h, w, c) 81 | channels = channels.reshape(shape=(batch_size, n_slots, *spatial_dims, -1)) 82 | 83 | # masked_channels.shape = (batch_size, n_slots, h, w, c) 84 | masked_channels = channels * alpha_mask 85 | 86 | # decoded_target.shape = (batch_size, h, w, c) 87 | decoded_target = torch.sum(masked_channels, dim=1) # Combine target 88 | preds_dict[target_key] = decoded_target 89 | 90 | if not self.training: # intermediates for logging. 91 | preds_dict[f"eval/{target_key}_slots"] = channels 92 | preds_dict[f"eval/{target_key}_masked"] = masked_channels 93 | preds_dict[f"eval/{target_key}_combined"] = decoded_target 94 | 95 | # if not self.training: # intermediates for logging. 96 | # preds_dict["eval/alpha_mask"] = alpha_mask 97 | preds_dict["alpha_mask"] = alpha_mask 98 | 99 | if not self.training: # only return for evaluation 100 | preds_dict["segmentations"] = alpha_logits.argmax(dim=1) 101 | 102 | return preds_dict -------------------------------------------------------------------------------- /savi/modules/evaluator.py: -------------------------------------------------------------------------------- 1 | """Model evaluation.""" 2 | 3 | # TODO: rename file 4 | 5 | from typing import Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Type, Union 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import numpy as np 11 | 12 | Array = torch.Tensor 13 | ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]] 14 | 15 | 16 | def get_eval_metrics( 17 | preds: Dict[str, ArrayTree], 18 | batch: Dict[str, Array], 19 | loss_fn, 20 | eval_metrics_processor, 21 | predicted_max_num_instances: int, 22 | ground_truth_max_num_instances: int, 23 | ) -> Union[None, Dict]: 24 | """Compute the metrics for the model predictions in inference mode. 25 | 26 | The metrics are averaged across *all* devices (of all hosts). 27 | 28 | Args: 29 | preds: Model predictions. 30 | batch: Inputs that should be evaluated. 31 | loss_fn: Loss function that takes model predictions and a batch of data. 32 | eval_metrics_cls: Dictionary of evaluation metrics. 33 | predicted_max_num_instances: Maximum number of instances (objects) in prediction. 34 | ground_truth_max_num_instances: Maximum number of instances in ground truth, 35 | including background (which counts as a separate instance). 36 | 37 | Returns: 38 | The evaluation metrics. 
39 | """ 40 | loss = loss_fn(preds, batch) 41 | metrics_update = eval_metrics_processor.from_model_output( 42 | 43 | ) 44 | # TODO 45 | return metrics_update 46 | 47 | def eval_first_step( 48 | model: nn.Module, 49 | batch: Tuple[Array], 50 | # conditioning_key: Optional[str] = None 51 | ) -> Dict[str, ArrayTree]: 52 | """Get the model predictions with a freshly initialized recurrent state. 53 | 54 | Args: 55 | model: Model used in eval step. 56 | state_variables: State variables for the model. 57 | params: Params for the model. 58 | batch: Inputs that should be evaluated. 59 | conditioning_key: Optional key for conditioning slots. 60 | Returns: 61 | The model's predictions. 62 | """ 63 | 64 | video, boxes, segmentations, gt_flow, padding_mask, mask = batch 65 | # TODO: delete hardcode 66 | conditioning = boxes 67 | 68 | preds = model( 69 | video=video, conditioning=conditioning, 70 | padding_mask=padding_mask 71 | ) 72 | 73 | return preds 74 | 75 | 76 | def eval_continued_step( 77 | model: nn.Module, 78 | batch: Tuple[Array], 79 | recurrent_states: Array 80 | ) -> Dict[str, ArrayTree]: 81 | """Get the model predictions, continuing from a provided recurrent state. 82 | 83 | Args: 84 | model: Model used in eval step. 85 | batch: Inputs that should be evaluated. 86 | recurrent_states: Recurrent internal model state from which to continue. 87 | i.e. slots 88 | Returns: 89 | The model's predictions. 90 | """ 91 | 92 | video, boxes, segmentations, gt_flow, padding_mask, mask = batch 93 | 94 | preds = model( 95 | video=video, conditioning=recurrent_states, 96 | continue_from_previous_state=True, padding_mask=padding_mask 97 | ) 98 | 99 | return preds 100 | 101 | def batch_slicer( 102 | batch: Tuple[Array], 103 | start_idx: int, 104 | end_idx: int, 105 | pad_value: int = 0) -> Tuple[Array]: 106 | """Slicing the batch along axis 1. (hardcoded) 107 | 108 | Pads when sequence ends before `end`. 109 | hardcoded parameters included, don't use as a general slicing fn 110 | """ 111 | assert start_idx <= end_idx 112 | video, boxes, segmentations, gt_flow, padding_mask, mask = batch 113 | 114 | seq_len = video.shape[1] 115 | # Infer end index if not provided. 116 | if end_idx == -1: 117 | end_idx = seq_len 118 | # Set padding size if end index > sequence length 119 | pad_size = 0 120 | if end_idx > seq_len: 121 | pad_size = end_idx - start_idx 122 | end_idx = seq_len 123 | 124 | sliced_batch = [] 125 | for array in (video, boxes, segmentations, gt_flow, padding_mask): 126 | if pad_size > 0: 127 | # array shape: (B, T, ...) 128 | pad_shape = list(array.shape[:1]) + [pad_size] + list(array.shape[2:]) # (B, pad, ...) 129 | padding = torch.full(pad_shape, pad_value) 130 | item = torch.cat([array[:, start_idx:end_idx], padding], dim=1) 131 | sliced_batch.append(item) 132 | else: 133 | sliced_batch.append(array[:, start_idx:end_idx]) 134 | sliced_batch.append(mask) # hardcoded. only array with shape (B,) 135 | 136 | return sliced_batch 137 | 138 | def preds_reader(model_outputs): 139 | """Hardcoded helper function for eval_step readability""" 140 | recurrent_states = model_outputs["states_pred"] # [B, T, N, S] 141 | pred_seg = model_outputs["outputs"]["segmentations"] # [B, T, H, W, 1] 142 | pred_flow = model_outputs["outputs"]["flow"] # [B, T, H, W, 3] 143 | att_t = model_outputs["attention"] # [B, T, ?] 
# TODO: figure this out 144 | 145 | return recurrent_states, pred_seg, pred_flow, att_t 146 | 147 | 148 | def eval_step( 149 | model: nn.Module, 150 | batch: Tuple[Array], 151 | slice_size: Optional[int] = None 152 | ) -> Tuple[Array]: 153 | """Compute the metrics for the given model in inference mode. 154 | 155 | The metrics are averaged across all devices. 156 | 157 | Args: 158 | model: Model used in eval step 159 | batch: inputs 160 | eval_first_step_fn: eval first step fn 161 | eval_continued_step_fn: eval continued step fn 162 | slice_size: Optional int, if provided, evaluate model on temporal 163 | slices of this size instead of full sequence length at once. 164 | Returns: 165 | Model predictions (hardcoded) 166 | pred_seg, pred_flow, att_t 167 | """ 168 | 169 | video, boxes, segmentations, flow, padding_mask, mask = batch 170 | temporal_axis = axis = 1 171 | 172 | seq_len = video.shape[axis] 173 | # Sliced evaluation (i.e. onsmaller temporal slices of the video). 174 | if slice_size is not None and slice_size < seq_len: 175 | num_slices = int(np.ceil(seq_len / slice_size)) 176 | 177 | # Get predictions for first slice (with fresh recurrrent state (i.e. slots)). 178 | batch_slice = batch_slicer(batch, 0, slice_size) 179 | preds_slice = eval_first_step( 180 | model=model, batch=batch_slice) 181 | recurrent_states, pred_seg, pred_flow, att_t = preds_reader(preds_slice) 182 | # make predictions array 183 | preds = [[item] for item in (pred_seg, pred_flow, att_t)] 184 | 185 | # Iterate over remaining slices (re-using the previous recurrent state). 186 | for slice_idx in range(1, num_slices): 187 | batch_slice = batch_slicer(batch, 188 | start_idx=slice_idx * slice_size, 189 | end_idx=(slice_idx+1) * slice_size) 190 | preds_slice = eval_continued_step( 191 | model, batch_slice, recurrent_states) 192 | recurrent_states, pred_seg, pred_flow, att_t = preds_reader(preds_slice) 193 | for i in range(len(preds)): 194 | preds[i].append((pred_seg, pred_flow, att_t)[i]) 195 | 196 | # join the predictions 197 | for i in range(len(preds)): 198 | preds[i] = torch.cat(preds[i], dim=axis) 199 | 200 | else: 201 | preds = eval_first_step(model, batch) 202 | preds = preds_reader(preds)[1:] 203 | 204 | return preds 205 | -------------------------------------------------------------------------------- /savi/modules/factory.py: -------------------------------------------------------------------------------- 1 | """Return model, loss, and eval metrics in 1 go 2 | for the SAVi model.""" 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import numpy as np 8 | from savi.lib.utils import init_fn 9 | 10 | import savi.modules as modules 11 | import savi.modules.misc as misc 12 | 13 | 14 | def build_model(args): 15 | if args.model_size == "small": 16 | slot_size = 128 17 | num_slots = args.num_slots 18 | weight_init = args.weight_init 19 | # Encoder 20 | # encoder_backbone = modules.CNN( 21 | # features=[3, 32, 32, 32, 32], 22 | # kernel_size=[(5, 5), (5, 5), (5, 5), (5, 5)], 23 | # strides=[(1, 1), (1, 1), (1, 1), (1, 1)], 24 | # # padding="same", 25 | # padding=[(2, 2), (2, 2), (2, 2), (2, 2)], 26 | # layer_transpose=[False, False, False, False], 27 | # weight_init=weight_init) 28 | encoder_backbone = modules.CNN2( 29 | conv_modules=nn.ModuleList([ 30 | nn.Conv2d(3, 32, (5, 5), (1, 1), (2, 2)), 31 | nn.Conv2d(32, 32, (5, 5), (1, 1), (2, 2)), 32 | nn.Conv2d(32, 32, (5, 5), (1, 1), (2, 2)), 33 | nn.Conv2d(32, 32, (5, 5), (1, 1), (2, 2))]), 34 | weight_init=weight_init) 35 | 
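        # Editor's note: with 64x64 video frames, the four stride-1, padding-2, 5x5
        # convs above keep the feature map at 64x64 with 32 channels, which is why the
        # position embedding below is built with input_shape=(-1, 64, 64, 32).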
encoder = modules.FrameEncoder( 36 | backbone=encoder_backbone, 37 | pos_emb=modules.PositionEmbedding( 38 | input_shape=(-1, 64, 64, 32), 39 | embedding_type="linear", 40 | update_type="project_add", 41 | output_transform=modules.MLP( 42 | input_size=32, 43 | hidden_size=64, 44 | output_size=32, 45 | layernorm="pre", 46 | weight_init=weight_init), 47 | weight_init=weight_init)) 48 | # Corrector 49 | corrector = modules.SlotAttention( 50 | input_size=32, # TODO: validate, should be backbone output size 51 | qkv_size=128, 52 | slot_size=slot_size, 53 | num_iterations=1, 54 | weight_init=weight_init) 55 | # Predictor 56 | predictor = modules.TransformerBlock( 57 | embed_dim=slot_size, 58 | num_heads=4, 59 | qkv_size=128, 60 | mlp_size=256, 61 | weight_init=weight_init) 62 | # Initializer 63 | initializer = modules.CoordinateEncoderStateInit( 64 | embedding_transform=modules.MLP( 65 | input_size=4, # bounding boxes have feature size 4 66 | hidden_size=256, 67 | output_size=slot_size, 68 | layernorm=None, 69 | weight_init=weight_init), 70 | prepend_background=True, 71 | center_of_mass=False) 72 | # Decoder 73 | readout_modules = nn.ModuleList([ 74 | nn.Linear(64, out_features) for out_features in args.targets.values()]) 75 | for module in readout_modules.children(): 76 | init_fn[weight_init['linear_w']](module.weight) 77 | init_fn[weight_init['linear_b']](module.bias) 78 | # decoder_backbone = modules.CNN( 79 | # features=[slot_size, 64, 64, 64, 64], 80 | # kernel_size=[(5, 5), (5, 5), (5, 5), (5, 5)], 81 | # strides=[(2, 2), (2, 2), (2, 2), (1, 1)], 82 | # padding=[2, 2, 2, "same"], 83 | # transpose_double=True, 84 | # layer_transpose=[True, True, True, False], 85 | # weight_init=weight_init) 86 | decoder_backbone = modules.CNN2( 87 | nn.ModuleList([ 88 | nn.ConvTranspose2d(slot_size, 64, (5, 5), (2, 2), padding=(2, 2), output_padding=(1, 1)), 89 | nn.ConvTranspose2d(64, 64, (5, 5), (2, 2), padding=(2, 2), output_padding=(1, 1)), 90 | nn.ConvTranspose2d(64, 64, (5, 5), (2, 2), padding=(2, 2), output_padding=(1, 1)), 91 | nn.Conv2d(64, 64, (5, 5), (1, 1), (2,2))]), 92 | transpose_modules=[True, True, True, False], 93 | weight_init=weight_init) 94 | decoder = modules.SpatialBroadcastDecoder( 95 | resolution=(8,8), # Update if data resolution or strides change. 
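            # Editor's note: the 8x8 broadcast grid is upsampled by the three stride-2
            # transposed convs in decoder_backbone (8 -> 16 -> 32 -> 64), so decoded
            # targets come out at the 64x64 input resolution.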
96 | backbone=decoder_backbone, 97 | pos_emb=modules.PositionEmbedding( 98 | input_shape=(-1, 8, 8, slot_size), 99 | embedding_type="linear", 100 | update_type="project_add", 101 | weight_init=weight_init), 102 | target_readout=modules.Readout( 103 | keys=list(args.targets), 104 | readout_modules=readout_modules), 105 | weight_init=weight_init) 106 | # SAVi Model 107 | model = modules.SAVi( 108 | encoder=encoder, 109 | decoder=decoder, 110 | corrector=corrector, 111 | predictor=predictor, 112 | initializer=initializer, 113 | decode_corrected=True, 114 | decode_predicted=False) 115 | elif args.model_size == "medium": 116 | slot_size = 128 117 | num_slots = args.num_slots 118 | weight_init = args.weight_init 119 | # Encoder 120 | # encoder_backbone = modules.CNN( 121 | # features=[3, 64, 64, 64, 64], 122 | # kernel_size=[(5, 5), (5, 5), (5, 5), (5, 5)], 123 | # strides=[(2, 2), (1, 1), (1, 1), (1, 1)], 124 | # padding=[(2, 2), "same", "same", "same"], 125 | # layer_transpose=[False, False, False, False], 126 | # weight_init=weight_init) 127 | encoder_backbone = modules.CNN2( 128 | conv_modules=nn.ModuleList([ 129 | nn.Conv2d(3, 64, (5, 5), (2, 2), (2, 2)), 130 | nn.Conv2d(64, 64, (5, 5), (1, 1), (2, 2)), 131 | nn.Conv2d(64, 64, (5, 5), (1, 1), (2, 2)), 132 | nn.Conv2d(64, 64, (5, 5), (1, 1), (2, 2))]), 133 | weight_init=weight_init) 134 | encoder = modules.FrameEncoder( 135 | backbone=encoder_backbone, 136 | pos_emb=modules.PositionEmbedding( 137 | input_shape=(-1, 64, 64, 64), 138 | embedding_type="linear", 139 | update_type="project_add", 140 | output_transform=modules.MLP( 141 | input_size=64, 142 | hidden_size=64, 143 | output_size=64, 144 | layernorm="pre", 145 | weight_init=weight_init), 146 | weight_init=weight_init)) 147 | # Corrector 148 | corrector = modules.SlotAttention( 149 | input_size=64, # TODO: validate, should be backbone output size 150 | qkv_size=128, 151 | slot_size=slot_size, 152 | num_iterations=1, 153 | weight_init=weight_init) 154 | # Predictor 155 | predictor = modules.TransformerBlock( 156 | embed_dim=slot_size, 157 | num_heads=4, 158 | qkv_size=128, 159 | mlp_size=256, 160 | weight_init=weight_init) 161 | # Initializer 162 | initializer = modules.CoordinateEncoderStateInit( 163 | embedding_transform=modules.MLP( 164 | input_size=4, # bounding boxes have feature size 4 165 | hidden_size=256, 166 | output_size=slot_size, 167 | layernorm=None, 168 | weight_init=weight_init), 169 | prepend_background=True, 170 | center_of_mass=False) 171 | # Decoder 172 | readout_modules = nn.ModuleList([ 173 | nn.Linear(64, out_features) for out_features in args.targets.values()]) 174 | for module in readout_modules.children(): 175 | init_fn[weight_init['linear_w']](module.weight) 176 | init_fn[weight_init['linear_b']](module.bias) 177 | # decoder_backbone = modules.CNN( 178 | # features=[slot_size, 64, 64, 64, 64], 179 | # kernel_size=[(5, 5), (5, 5), (5, 5), (5, 5)], 180 | # strides=[(2, 2), (2, 2), (2, 2), (2, 2)], 181 | # padding=[2, 2, 2, 2], 182 | # transpose_double=True, 183 | # layer_transpose=[True, True, True, True], 184 | # weight_init=weight_init) 185 | decoder_backbone = modules.CNN2( 186 | nn.ModuleList([ 187 | nn.ConvTranspose2d(slot_size, 64, (5, 5), (2, 2), padding=(2, 2), output_padding=(1, 1)), 188 | nn.ConvTranspose2d(64, 64, (5, 5), (2, 2), padding=(2, 2), output_padding=(1, 1)), 189 | nn.ConvTranspose2d(64, 64, (5, 5), (2, 2), padding=(2, 2), output_padding=(1, 1)), 190 | nn.ConvTranspose2d(64, 64, (5, 5), (2, 2), padding=(2, 2), output_padding=(1, 1))]), 191 | 
transpose_modules=[True, True, True, True], 192 | weight_init=weight_init) 193 | decoder = modules.SpatialBroadcastDecoder( 194 | resolution=(8,8), # Update if data resolution or strides change. 195 | backbone=decoder_backbone, 196 | pos_emb=modules.PositionEmbedding( 197 | input_shape=(-1, 8, 8, slot_size), 198 | embedding_type="linear", 199 | update_type="project_add", 200 | weight_init=weight_init), 201 | target_readout=modules.Readout( 202 | keys=list(args.targets), 203 | readout_modules=readout_modules), 204 | weight_init=weight_init) 205 | # SAVi Model 206 | model = modules.SAVi( 207 | encoder=encoder, 208 | decoder=decoder, 209 | corrector=corrector, 210 | predictor=predictor, 211 | initializer=initializer, 212 | decode_corrected=True, 213 | decode_predicted=False) 214 | else: 215 | raise NotImplementedError 216 | return model 217 | 218 | 219 | def build_modules(args): 220 | """Return the model and loss/eval processors.""" 221 | model = build_model(args) 222 | loss = misc.ReconLoss() 223 | metrics = misc.ARI() 224 | 225 | return model, loss, metrics 226 | -------------------------------------------------------------------------------- /savi/modules/initializers.py: -------------------------------------------------------------------------------- 1 | """Initializers module library.""" 2 | 3 | # FIXME 4 | 5 | import functools 6 | from turtle import forward 7 | from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import numpy as np 13 | 14 | from savi.lib import utils 15 | from savi.modules import misc 16 | from savi.modules import video 17 | 18 | Shape = Tuple[int] 19 | 20 | DType = Any 21 | Array = torch.Tensor 22 | ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]] # pytype: disable=not-supported-yet 23 | ProcessorState = ArrayTree 24 | PRNGKey = Array 25 | NestedDict = Dict[str, Any] 26 | 27 | 28 | class ParamStateInit(nn.Module): 29 | """Fixed, learnable state initialization. 30 | 31 | Note: This module ignores any conditional input (by design). 32 | """ 33 | 34 | def __init__(self, 35 | shape: Sequence[int], 36 | init_fn: str = "normal" 37 | ): 38 | super().__init__() 39 | 40 | if init_fn == "normal": 41 | self.init_fn = functools.partial(nn.init.normal_, std=1.) 42 | elif init_fn == "zeros": 43 | self.init_fn = nn.init.zeros_() 44 | else: 45 | raise ValueError(f"Unknown init_fn: {init_fn}") 46 | 47 | self.param = nn.Parameter(torch.empty(size=(shape))) 48 | 49 | def forward(self, inputs: Optional[Array], batch_size: int) -> Array: 50 | del inputs # Unused. 51 | self.param = self.init_fn(self.param) 52 | return utils.broadcast_across_batch(self.param, batch_size=batch_size) 53 | 54 | 55 | class GaussianStateInit(nn.Module): 56 | """Random state initialization with zero-mean, unit-variance Gaussian 57 | 58 | Note: This module does not contain any trainable parameters. 59 | This module also ignores any conditional input (by design). 60 | """ 61 | 62 | def __init__(self, 63 | shape: Sequence[int], 64 | ): 65 | super().__init__() 66 | 67 | self.shape = shape 68 | 69 | def forward(self, inputs: Optional[Array], batch_size: int) -> Array: 70 | del inputs # Unused. 71 | # TODO: Use torch generator ? 
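        # Sketch for the TODO above (an assumption, not part of the original code): a
        # torch.Generator kept on the module or passed in would make this sampling
        # reproducible, e.g.
        #   generator = torch.Generator().manual_seed(seed)  # `seed` is hypothetical here
        #   return torch.randn([batch_size] + list(self.shape), generator=generator)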
72 | return torch.normal(mean=torch.zeros([batch_size] + list(self.shape))) 73 | 74 | 75 | class SegmentationEncoderStateInit(nn.Module): 76 | """State init that encodes segmentation masks as conditional input.""" 77 | 78 | def __init__(self, 79 | max_num_slots: int, 80 | backbone: nn.Module, 81 | pos_emb: nn.Module = nn.Identity(), 82 | reduction: Optional[str] = "all_flatten", # Reduce spatial dim by default. 83 | output_transform: nn.Module = nn.Identity(), 84 | zero_background: bool = False 85 | ): 86 | super().__init__() 87 | 88 | self.max_num_slots = max_num_slots 89 | self.backbone = backbone 90 | self.pos_emb = pos_emb 91 | self.reduction = reduction 92 | self.output_transform = output_transform 93 | self.zero_background = zero_background 94 | 95 | # submodules 96 | self.encoder = video.FrameEncoder( 97 | backbone=backbone, pos_emb=pos_emb, 98 | reduction=reduction, output_transform=output_transform) 99 | 100 | def forward(self, inputs: Array, batch_size: Optional[int]) -> Array: 101 | del batch_size # Unused. 102 | 103 | # inputs.shape = (batch_size, seq_len, height, width) 104 | inputs = inputs[:, 0] # Only condition on first time step. 105 | 106 | # Convert mask index to one-hot. 107 | inputs_oh = F.one_hot(inputs, self.max_num_slots) 108 | # inputs_oh.shape = (batch_size, height, width, n_slots) 109 | # NOTE: 0th entry inputs_oh[... 0] will typically correspond to background. 110 | 111 | # Set background slot to all-zeros. 112 | if self.zero_background: 113 | inputs_oh[:, :, :, 0] = 0 114 | 115 | # Switch one-hot axis into 1st position (i.e. sequence axis). 116 | inputs_oh = inputs_oh.permute((0, 3, 1, 2)) 117 | # inputs_oh.shape = (batch_size, max_num_slots, height, width) 118 | 119 | # Append dummy feature axis. 120 | inputs_oh = torch.unsqueeze(-1) 121 | 122 | # encode slots 123 | # slots.shape = (batch_size, n_slots, n_features) 124 | slots = self.encoder(inputs_oh, None) 125 | 126 | return slots 127 | 128 | 129 | class CoordinateEncoderStateInit(nn.Module): 130 | """State init that encodes bounding box corrdinates as conditional input. 131 | 132 | Attributes: 133 | embedding_transform: A nn.Module that is applied on inputs (bounding boxes). 134 | prepend_background: Boolean flag' whether to prepend a special, zero-valued 135 | background bounding box to the input. Default: False. 136 | center_of_mass: Boolean flag; whether to convert bounding boxes to center 137 | of mass coordinates. Default: False. 138 | background_value: Default value to fill in the background. 139 | """ 140 | 141 | def __init__(self, 142 | embedding_transform: nn.Module, 143 | prepend_background: bool = False, 144 | center_of_mass: bool = False, 145 | background_value: float = 0. 146 | ): 147 | super().__init__() 148 | 149 | self.embedding_transform = embedding_transform 150 | self.prepend_background = prepend_background 151 | self.center_of_mass = center_of_mass 152 | self.background_value = background_value 153 | 154 | def forward(self, inputs: Array, batch_size: Optional[int]) -> Array: 155 | del batch_size # Unused. 156 | 157 | # inputs.shape = (batch_size, seq_len, bboxes, 4) 158 | inputs = inputs[:, 0] # Only condition on first time step. 159 | # inputs.shape = (batch_size, bboxes, 4) 160 | 161 | if self.prepend_background: 162 | # Adds a fake background box [0, 0, 0, 0] at the beginning. 163 | batch_size = inputs.shape[0] 164 | 165 | # Encode the background as specified by the background_value. 
166 | background = torch.full( 167 | (batch_size, 1, 4), self.background_value, dtype=inputs.dtype, 168 | device = inputs.get_device()) 169 | 170 | inputs = torch.cat([background, inputs], dim=1) 171 | 172 | if self.center_of_mass: 173 | y_pos = (inputs[:, :, 0] + inputs[:, :, 2]) / 2 174 | x_pos = (inputs[:, :, 1] + inputs[:, :, 3]) / 2 175 | inputs = torch.stack([y_pos, x_pos], dim=-1) 176 | 177 | slots = self.embedding_transform(inputs) 178 | 179 | return slots -------------------------------------------------------------------------------- /savi/modules/misc.py: -------------------------------------------------------------------------------- 1 | """Miscellaneous modules.""" 2 | 3 | # FIXME 4 | 5 | from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import numpy as np 11 | import math 12 | 13 | import savi.lib.metrics as metrics 14 | import savi.lib.metrics_jax as metrics_jax 15 | import savi.modules.evaluator as evaluator 16 | from savi.lib import utils 17 | from savi.lib.utils import init_fn 18 | 19 | DType = Any 20 | Array = torch.Tensor 21 | ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]] # pytype: disable=not-supported-yet 22 | ProcessorState = ArrayTree 23 | PRNGKey = Array 24 | NestedDict = Dict[str, Any] 25 | 26 | # class Identity(nn.Module): 27 | # """Module that applies the identity function, ignoring any additional args.""" 28 | 29 | # def __init__(self): 30 | # super().__init__() 31 | 32 | # def forward(self, inputs: Array, **args) -> Array: 33 | # return inputs 34 | 35 | 36 | class Readout(nn.Module): 37 | """Module for reading out multiple targets from an embedding.""" 38 | 39 | def __init__(self, 40 | keys: Sequence[str], 41 | readout_modules: nn.ModuleList, 42 | stop_gradient: Optional[Sequence[bool]] = None 43 | ): 44 | super().__init__() 45 | 46 | self.keys = keys 47 | self.readout_modules = readout_modules 48 | self.stop_gradient = stop_gradient 49 | 50 | def forward(self, inputs: Array) -> ArrayTree: 51 | num_targets = len(self.keys) 52 | assert num_targets >= 1, "Need to have at least one target." 
53 | assert len(self.readout_modules) == num_targets, ( 54 | f"len(modules):({len(self.readout_modules)}) and len(keys):({len(self.keys)}) must match.") 55 | if self.stop_gradient is not None: 56 | assert len(self.stop_gradient) == num_targets, ( 57 | f"len(stop_gradient):({len(self.stop_gradient)}) and len(keys):({len(self.keys)}) must match.") 58 | outputs = {} 59 | modules_iter = iter(self.readout_modules) 60 | for i in range(num_targets): 61 | if self.stop_gradient is not None and self.stop_gradient[i]: 62 | x = x.detach() # FIXME 63 | else: 64 | x = inputs 65 | outputs[self.keys[i]] = next(modules_iter)(x) 66 | return outputs 67 | 68 | class DummyReadout(nn.Module): 69 | 70 | def forward(self, inputs: Array) -> ArrayTree: 71 | return {} 72 | 73 | class MLP(nn.Module): 74 | """Simple MLP with one hidden layer and optional pre-/post-layernorm.""" 75 | 76 | def __init__(self, 77 | input_size: int, # FIXME: added because or else can't instantiate submodules 78 | hidden_size: int, 79 | output_size: int, # if not given, should be inputs.shape[-1] at forward 80 | num_hidden_layers: int = 1, 81 | activation_fn: nn.Module = nn.ReLU, 82 | layernorm: Optional[str] = None, 83 | activate_output: bool = False, 84 | residual: bool = False, 85 | weight_init = None 86 | ): 87 | super().__init__() 88 | 89 | self.input_size = input_size 90 | self.hidden_size = hidden_size 91 | self.output_size = output_size 92 | self.num_hidden_layers = num_hidden_layers 93 | self.activation_fn = activation_fn 94 | self.layernorm = layernorm 95 | self.activate_output = activate_output 96 | self.residual = residual 97 | self.weight_init = weight_init 98 | 99 | # submodules 100 | ## layernorm 101 | if self.layernorm == "pre": 102 | self.layernorm_module = nn.LayerNorm(input_size, eps=1e-6) 103 | elif self.layernorm == "post": 104 | self.layernorm_module = nn.LayerNorm(output_size, eps=1e-6) 105 | ## mlp 106 | self.model = nn.ModuleList() 107 | self.model.add_module("dense_mlp_0", nn.Linear(self.input_size, self.hidden_size)) 108 | self.model.add_module("dense_mlp_0_act", self.activation_fn()) 109 | for i in range(1, self.num_hidden_layers): 110 | self.model.add_module(f"den_mlp_{i}", nn.Linear(self.hidden_size, self.hidden_size)) 111 | self.model.add_module(f"dense_mlp_{i}_act", self.activation_fn()) 112 | self.model.add_module(f"dense_mlp_{self.num_hidden_layers}", nn.Linear(self.hidden_size, self.output_size)) 113 | if self.activate_output: 114 | self.model.add_module(f"dense_mlp_{self.num_hidden_layers}_act", self.activation_fn()) 115 | for name, module in self.model.named_children(): 116 | if 'act' not in name: 117 | # nn.init.xavier_uniform_(module.weight) 118 | init_fn[weight_init['linear_w']](module.weight) 119 | init_fn[weight_init['linear_b']](module.bias) 120 | 121 | def forward(self, inputs: Array, train: bool = False) -> Array: 122 | del train # Unused 123 | 124 | x = inputs 125 | if self.layernorm == "pre": 126 | x = self.layernorm_module(x) 127 | for layer in self.model: 128 | x = layer(x) 129 | if self.residual: 130 | x = x + inputs 131 | if self.layernorm == "post": 132 | x = self.layernorm_module(x) 133 | return x 134 | 135 | class myGRUCell(nn.Module): 136 | """GRU cell as nn.Module 137 | 138 | Added because nn.GRUCell doesn't match up with jax's GRUCell... 139 | This one is designed to match ! (almost; output returns only once) 140 | 141 | The mathematical definition of the cell is as follows 142 | 143 | .. 
math:: 144 | 145 | \begin{array}{ll} 146 | r = \sigma(W_{ir} x + W_{hr} h + b_{hr}) \\ 147 | z = \sigma(W_{iz} x + W_{hz} h + b_{hz}) \\ 148 | n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\ 149 | h' = (1 - z) * n + z * h \\ 150 | \end{array} 151 | """ 152 | 153 | def __init__(self, 154 | input_size: int, 155 | hidden_size: int, 156 | gate_fn = torch.sigmoid, 157 | activation_fn = torch.tanh, 158 | weight_init = None 159 | ): 160 | super().__init__() 161 | 162 | self.input_size = input_size 163 | self.hidden_size = hidden_size 164 | self.gate_fn = gate_fn 165 | self.activation_fn = activation_fn 166 | self.weight_init = weight_init 167 | 168 | # submodules 169 | self.dense_ir = nn.Linear(input_size, hidden_size) 170 | self.dense_iz = nn.Linear(input_size, hidden_size) 171 | self.dense_in = nn.Linear(input_size, hidden_size) 172 | self.dense_hr = nn.Linear(hidden_size, hidden_size, bias=False) 173 | self.dense_hz = nn.Linear(hidden_size, hidden_size, bias=False) 174 | self.dense_hn = nn.Linear(hidden_size, hidden_size) 175 | self.reset_parameters() 176 | 177 | def reset_parameters(self) -> None: 178 | recurrent_weight_init = nn.init.orthogonal_ 179 | if self.weight_init is not None: 180 | weight_init = init_fn[self.weight_init['linear_w']] 181 | bias_init = init_fn[self.weight_init['linear_b']] 182 | else: 183 | # weight init not given 184 | stdv = 1.0 / math.sqrt(self.hidden_size) if self.hidden_size > 0 else 0 185 | weight_init = bias_init = lambda weight: nn.init.uniform_(weight, -stdv, stdv) 186 | # input weights 187 | weight_init(self.dense_ir.weight) 188 | bias_init(self.dense_ir.bias) 189 | weight_init(self.dense_iz.weight) 190 | bias_init(self.dense_iz.bias) 191 | weight_init(self.dense_in.weight) 192 | bias_init(self.dense_in.bias) 193 | # hidden weights 194 | recurrent_weight_init(self.dense_hr.weight) 195 | recurrent_weight_init(self.dense_hz.weight) 196 | recurrent_weight_init(self.dense_hn.weight) 197 | bias_init(self.dense_hn.bias) 198 | 199 | def forward(self, inputs, carry): 200 | h = carry 201 | # input and recurrent layeres are summed so only one needs a bias 202 | r = self.gate_fn(self.dense_ir(inputs) + self.dense_hr(h)) 203 | z = self.gate_fn(self.dense_iz(inputs) + self.dense_hz(h)) 204 | # add bias because the linear transformations aren't directly summed 205 | n = self.activation_fn(self.dense_in(inputs) + 206 | r * self.dense_hn(h)) 207 | new_h = (1. - z) * n + z * h 208 | return new_h 209 | 210 | # class GRU(nn.Module): 211 | # """GRU cell as nn.Module.""" 212 | 213 | # def __init__(self, 214 | # input_size: int, # FIXME: added for submodules 215 | # hidden_size: int, # FIXME: added for submodules 216 | # ): 217 | # super().__init__() 218 | 219 | # # submodules 220 | # self.gru = nn.GRUCell(input_size, hidden_size) 221 | 222 | # def forward(self, carry: Array, inputs: Array, 223 | # train: bool = False) -> Array: 224 | # del train # unused 225 | 226 | # carry = self.gru(inputs, carry) 227 | # return carry 228 | 229 | 230 | # class Dense(nn.Module): 231 | # """Dense layer as nn.Module accepting "train" flag. """ 232 | 233 | # def __init__(self, 234 | # input_shape: int, # FIXME: added for submodules 235 | # features: int, 236 | # use_bias: bool = True 237 | # ): 238 | # super().__init__() 239 | 240 | # # submodules 241 | # self.dense = nn.Linear(input_shape, features, use_bias) 242 | 243 | # def forward(self, inputs: Array, train: bool = False) -> Array: 244 | # del train # Unused. 
245 | # return self.dense(inputs) 246 | 247 | 248 | class PositionEmbedding(nn.Module): 249 | """A module for applying N-dimensional position embedding. 250 | 251 | Attr: 252 | embedding_type: A string defining the type of position embedding to use. 253 | One of ["linear", "discrete_1d", "fourier", "gaussian_fourier"]. 254 | update_type: A string defining how the input is updated with the position embedding. 255 | One of ["proj_add", "concat"]. 256 | num_fourier_bases: The number of Fourier bases to use. For embedding_type == "fourier", 257 | the embedding dimensionality is 2 x number of position dimensions x num_fourier_bases. 258 | For embedding_type == "gaussian_fourier", the embedding dimensionality is 259 | 2 x num_fourier_bases. For embedding_type == "linear", this parameter is ignored. 260 | gaussian_sigma: Standard deviation of sampled Gaussians. 261 | pos_transform: Optional transform for the embedding. 262 | output_transform: Optional transform for the combined input and embedding. 263 | trainable_pos_embedding: Boolean flag for allowing gradients to flow into the position 264 | embedding, so that the optimizer can update it. 265 | """ 266 | 267 | def __init__(self, 268 | input_shape: Tuple[int], # FIXME: added for submodules. 269 | embedding_type: str, 270 | update_type: str, 271 | num_fourier_bases: int = 0, 272 | gaussian_sigma: float = 1.0, 273 | pos_transform: nn.Module = nn.Identity(), 274 | output_transform: nn.Module = nn.Identity(), 275 | trainable_pos_embedding: bool = False, 276 | weight_init = None 277 | ): 278 | super().__init__() 279 | 280 | self.input_shape = input_shape 281 | self.embedding_type = embedding_type 282 | self.update_type = update_type 283 | self.num_fourier_bases = num_fourier_bases 284 | self.gaussian_sigma = gaussian_sigma 285 | self.pos_transform = pos_transform 286 | self.output_transform = output_transform 287 | self.trainable_pos_embedding = trainable_pos_embedding 288 | self.weight_init = weight_init 289 | 290 | # submodules defined in module. 291 | self.pos_embedding = nn.Parameter(self._make_pos_embedding_tensor(input_shape), 292 | requires_grad=self.trainable_pos_embedding) 293 | if self.update_type == "project_add": 294 | self.project_add_dense = nn.Linear(self.pos_embedding.shape[-1], input_shape[-1]) 295 | # nn.init.xavier_uniform_(self.project_add_dense.weight) 296 | init_fn[weight_init['linear_w']](self.project_add_dense.weight) 297 | init_fn[weight_init['linear_b']](self.project_add_dense.bias) 298 | 299 | # TODO: validate 300 | def _make_pos_embedding_tensor(self, input_shape): 301 | if self.embedding_type == "discrete_1d": 302 | # An integer tensor in [0, input_shape[-2]-1] reflecting 303 | # 1D discrete position encoding (encode the second-to-last axis). 304 | pos_embedding = np.broadcast_to( 305 | np.arange(input_shape[-2]), input_shape[1:-1]) 306 | else: 307 | # A tensor grid in [-1, +1] for each input dimension. 308 | pos_embedding = utils.create_gradient_grid(input_shape[1:-1], [-1.0, 1.0]) 309 | 310 | if self.embedding_type == "linear": 311 | pos_embedding = torch.from_numpy(pos_embedding) 312 | elif self.embedding_type == "discrete_1d": 313 | pos_embedding = F.one_hot(torch.from_numpy(pos_embedding), input_shape[-2]) 314 | elif self.embedding_type == "fourier": 315 | # NeRF-style Fourier/sinusoidal position encoding. 
316 | pos_embedding = utils.convert_to_fourier_features( 317 | pos_embedding * np.pi, basis_degree=self.num_fourier_bases) 318 | pos_embedding = torch.from_numpy(pos_embedding) 319 | elif self.embedding_type == "gaussian_fourier": 320 | # Gaussian Fourier features. Reference: https://arxiv.org/abs/2006.10739 321 | num_dims = pos_embedding.shape[-1] 322 | projection = np.random.normal( 323 | size=[num_dims, self.num_fourier_bases]) * self.gaussian_sigma 324 | pos_embedding = np.pi * pos_embedding.dot(projection) 325 | # A slightly faster implementation of sin and cos. 326 | pos_embedding = np.sin( 327 | np.concatenate([pos_embedding, pos_embedding + 0.5 * np.pi], axis=-1)) 328 | pos_embedding = torch.from_numpy(pos_embedding) 329 | else: 330 | raise ValueError("Invalid embedding type provided.") 331 | 332 | # Add batch dimension. 333 | pos_embedding = pos_embedding.unsqueeze(0) 334 | pos_embedding = pos_embedding.float() 335 | 336 | return pos_embedding 337 | 338 | def forward(self, inputs: Array) -> Array: 339 | 340 | # Apply optional transformation on the position embedding. 341 | pos_embedding = self.pos_transform(self.pos_embedding).to(inputs.get_device()) 342 | 343 | # Apply position encoding to inputs. 344 | if self.update_type == "project_add": 345 | # Here, we project the position encodings to the same dimensionality as 346 | # the inputs and add them to the inputs (broadcast along batch dimension). 347 | # This is roughly equivalent to concatenation of position encodings to the 348 | # inputs (if followed by a Dense layer), but is slightly more efficient. 349 | x = inputs + self.project_add_dense(pos_embedding) 350 | elif self.update_type == "concat": 351 | # Repeat the position embedding along the first (batch) dimension. 352 | pos_embedding = torch.broadcast_to( 353 | pos_embedding, inputs.shape[:-1] + pos_embedding.shape[-1:]) 354 | # concatenate along the channel dimension. 355 | x = torch.concat((inputs, pos_embedding), dim=-1) 356 | else: 357 | raise ValueError("Invalid update type provided.") 358 | 359 | # Apply optional output transformation. 
360 | x = self.output_transform(x) 361 | return x 362 | 363 | 364 | ##################################################### 365 | # Losses 366 | 367 | class ReconLoss(nn.Module): 368 | """L2 loss.""" 369 | 370 | def __init__(self, l2_weight=1, reduction="none"): 371 | super().__init__() 372 | 373 | self.l2 = nn.MSELoss(reduction=reduction) 374 | self.l2_weight = l2_weight 375 | 376 | def forward(self, model_outputs, batch): 377 | if isinstance(model_outputs, dict): 378 | pred_flow = model_outputs["outputs"]["flow"] 379 | else: 380 | # TODO: need to clean all of this up 381 | pred_flow = model_outputs[1] 382 | video, boxes, segmentations, gt_flow, padding_mask, mask = batch 383 | 384 | # l2 loss between images and predicted images 385 | loss = self.l2_weight * self.l2(pred_flow, gt_flow) 386 | 387 | # sum over elements, leaving [B, -1] 388 | return loss.reshape(loss.shape[0], -1).sum(-1) 389 | 390 | 391 | ####################################################### 392 | # Eval Metrics 393 | 394 | class ARI(nn.Module): 395 | """ARI.""" 396 | 397 | def forward(self, model_outputs, batch, args): 398 | video, boxes, segmentations, flow, padding_mask, mask = batch 399 | 400 | pr_seg = model_outputs[0].squeeze(-1).int().cpu().numpy() 401 | # pr_seg = model_outputs["outputs"]["segmentations"][:, 1:].squeeze(-1).int().cpu().numpy() 402 | gt_seg = segmentations.int().cpu().numpy() 403 | input_pad = padding_mask.cpu().numpy() 404 | mask = mask.cpu().numpy() 405 | 406 | # ari_bg = metrics.Ari.from_model_output( 407 | ari_bg = metrics_jax.Ari.from_model_output( 408 | predicted_segmentations=pr_seg, ground_truth_segmentations=gt_seg, 409 | padding_mask=input_pad, 410 | ground_truth_max_num_instances=args.max_instances + 1, 411 | predicted_max_num_instances=args.num_slots, 412 | ignore_background=False, mask=mask) 413 | # ari_nobg = metrics.Ari.from_model_output( 414 | ari_nobg = metrics_jax.Ari.from_model_output( 415 | predicted_segmentations=pr_seg, ground_truth_segmentations=gt_seg, 416 | padding_mask=input_pad, 417 | ground_truth_max_num_instances=args.max_instances + 1, 418 | predicted_max_num_instances=args.num_slots, 419 | ignore_background=True, mask=mask) 420 | 421 | return ari_bg, ari_nobg 422 | -------------------------------------------------------------------------------- /savi/modules/video.py: -------------------------------------------------------------------------------- 1 | """Video module library.""" 2 | 3 | # FIXME 4 | 5 | import functools 6 | from typing import Any, Callable, Dict, Iterable, Mapping, NamedTuple, Optional, Tuple, Union 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import numpy as np 12 | 13 | from savi.lib import utils 14 | from savi.modules import misc 15 | 16 | Shape = Tuple[int] 17 | 18 | DType = Any 19 | Array = torch.Tensor 20 | ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]] # pytype: disable=not-supported-yet 21 | ProcessorState = ArrayTree 22 | PRNGKey = Array 23 | NestedDict = Dict[str, Any] 24 | 25 | 26 | class Processor(nn.Module): 27 | """Recurrent processor module. 28 | 29 | This module is scanned (applied recurrently) over the sequence dimension of 30 | the input and applies a corrector and a predictor module. The corrector is 31 | only applied if new inputs (such as new image/frame) are received and uses 32 | the new input to correct its internal state. 
33 | 34 | The predictor is equivalent to a latent transition model and produces a 35 | prediction for the state at the next time step, given teh current (corrected) 36 | state. 37 | """ 38 | 39 | def __init__(self, 40 | corrector: nn.Module, 41 | predictor: nn.Module 42 | ): 43 | super().__init__() 44 | 45 | self.corrector = corrector 46 | self.predictor = predictor 47 | 48 | def forward(self, slots: ProcessorState, inputs: Optional[Array], 49 | padding_mask: Optional[Array]) -> Tuple[Array, Array]: 50 | 51 | # Only apply corrector if we receive new inputs. 52 | if inputs is not None: 53 | # flatten spatial dims 54 | inputs = inputs.flatten(1, 2) 55 | corrected_slots, attn = self.corrector(slots, inputs, padding_mask) 56 | # Otherwise simply use previous state as input for predictor 57 | else: 58 | corrected_slots = slots 59 | 60 | # Always apply predictor (i.e. transition model). 61 | predicted_slots = self.predictor(corrected_slots) 62 | 63 | # Prepare outputs 64 | return corrected_slots, predicted_slots, attn 65 | 66 | 67 | class SAVi(nn.Module): 68 | """Video model consisting of encoder, recurrent processor, and decoder.""" 69 | 70 | def __init__(self, 71 | encoder: nn.Module, 72 | decoder: nn.Module, 73 | corrector: nn.Module, 74 | predictor: nn.Module, 75 | initializer: nn.Module, 76 | decode_corrected: bool = True, 77 | decode_predicted: bool = True 78 | ): 79 | super().__init__() 80 | 81 | self.initializer = initializer 82 | self.encoder = encoder 83 | self.corrector = corrector 84 | self.decoder = decoder 85 | self.predictor = predictor 86 | self.decode_corrected = decode_corrected 87 | self.decode_predicted = decode_predicted 88 | 89 | # submodules 90 | self.processor = Processor(corrector, predictor) 91 | 92 | def forward(self, video: Array, conditioning: Optional[Array] = None, 93 | continue_from_previous_state: bool = False, 94 | padding_mask: Optional[Array] = None, **kwargs) -> ArrayTree: 95 | """Performs a forward pass on a video. 96 | 97 | Args: 98 | video: Video of shape `[batch_size, n_frames, height, width, n_channels]`. 99 | conditioning: Optional tensor used for conditioning the initial state 100 | of the recurrent processor. 101 | continue_from_previous_state: Boolean, whether to continue from a previous 102 | state or not. If True, the conditioning variable is used directly as 103 | initial state. 104 | padding_mask: Binary mask for padding video inputs (e.g. for videos of 105 | different sizes/lengths). Zero corresponds to padding. 106 | 107 | Returns: 108 | A dictionary of model predictions. 109 | """ 110 | del kwargs # Unused. 111 | 112 | if padding_mask is None: 113 | padding_mask = torch.ones(video.shape[:-1], dtype=torch.int32) 114 | 115 | # video.shape = (batch_size, n_frames, height, width, n_channels) 116 | B, T, H, W, C = video.shape 117 | # encoded_inputs = self.encoder(video, padding_mask) 118 | # flatten over B * Time and unflatten after to get [B, T, h*, w*, F] 119 | encoded_inputs = self.encoder(video.flatten(0, 1)) 120 | encoded_inputs = encoded_inputs.reshape(shape=(B, T, *encoded_inputs.shape[-3:])) 121 | 122 | if continue_from_previous_state: 123 | assert conditioning is not None, ( 124 | "When continuing from a previous state, the state has to be passed " 125 | "via the `conditioning` variable, which cannot be `None`." 126 | ) 127 | init_slots = conditioning[:, -1] # currently, only use last state. 128 | # init_slots = conditioning # given [B, N, D], the slots of the last state 129 | else: 130 | # same as above but without encoded inputs. 
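            # In the conditional setup built in this repo (e.g. savi/modules/factory.py),
            # `conditioning` holds the first-frame bounding boxes and the initializer
            # (CoordinateEncoderStateInit) maps them to the initial slot states.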
131 | init_slots = self.initializer( 132 | conditioning, batch_size=video.shape[0]) 133 | 134 | # Scan recurrent processor over encoded inputs along sequence dimension. 135 | # # TODO: make this over t time steps. for loop ? 136 | # # corrected_st, predicted_st = self.processor( 137 | # # init_state, encoded_inputs, padding_mask) 138 | # # implementation try 1: 139 | # predicted_slots = init_slots 140 | # for t in range(T): 141 | # slots = predicted_slots 142 | # encoded_frame = encoded_inputs[:, t] 143 | # corrected_slots, predicted_slots = self.processor(slots, encoded_frame, padding_mask) 144 | 145 | # # corrected_st.shape = (batch_size, n_frames, ..., n_features) 146 | # # predicted_st.shape = (batch_size, n_frames, ..., n_features) 147 | 148 | # # Decode latent states. 149 | # outputs = self.decoder(corrected_slots) if self.decode_corrected else None 150 | # outputs_pred = self.decoder(predicted_slots) if self.decode_predicted else None 151 | 152 | # TODO: implementation try 2: 153 | # need to get the intermediate slots, not just the last. above doesn't return 154 | # all slots over all time. 155 | 156 | # TODO: do the decoding all at once instead of per-timestep like done here. 157 | outputs, outputs_pred, attn = None, None, None 158 | slots_corrected_list, slots_predicted_list, attn_list = [], [], [] 159 | predicted_slots = init_slots 160 | for t in range(T): 161 | slots = predicted_slots 162 | encoded_frame = encoded_inputs[:, t] 163 | corrected_slots, predicted_slots, attn_t = self.processor(slots, encoded_frame, padding_mask) 164 | 165 | slots_corrected_list.append(corrected_slots.unsqueeze(1)) 166 | slots_predicted_list.append(predicted_slots.unsqueeze(1)) 167 | attn_list.append(attn_t.unsqueeze(1)) 168 | 169 | corrected_slots = torch.cat(slots_corrected_list, dim=1) 170 | predicted_slots = torch.cat(slots_predicted_list, dim=1) 171 | attn = torch.cat(attn_list, dim=1) 172 | 173 | # Decode latent states 174 | outputs = self.decoder(corrected_slots.flatten(0,1)) if self.decode_corrected else None 175 | outputs_pred = self.decoder(predicted_slots.flatten(0,1)) if self.decode_predicted else None 176 | 177 | if outputs is not None: 178 | for key, value in outputs.items(): 179 | outputs[key] = value.reshape(B, T, *value.shape[1:]) 180 | if outputs_pred is not None: 181 | for key, value in outputs_pred.items(): 182 | outputs_pred[key] = value.reshape(B, T, *value.shape[1:]) 183 | 184 | return { 185 | # "states": corrected_slots, 186 | "states_pred": predicted_slots, 187 | "outputs": outputs, 188 | # "outputs_pred": outputs_pred, 189 | "attention": attn 190 | } 191 | 192 | 193 | class FrameEncoder(nn.Module): 194 | """Encoder for single video frame.""" 195 | 196 | def __init__(self, 197 | backbone: nn.Module, 198 | pos_emb: nn.Module = nn.Identity(), 199 | reduction: Optional[str] = None, # [spatial_flatten, spatial_average, all_flatten] 200 | output_transform: nn.Module = nn.Identity() 201 | ): 202 | super().__init__() 203 | 204 | self.backbone = backbone 205 | self.pos_emb = pos_emb 206 | self.reduction = reduction 207 | self.output_transform = output_transform 208 | 209 | def forward(self, inputs: Array, padding_mask: Optional[Array] = None) -> Tuple[Array, Dict[str, Array]]: 210 | del padding_mask # Unused. 
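        # (When called from SAVi.forward, the batch and time axes are flattened together,
        # so the leading "batch_size" below is really batch_size * n_frames.)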
211 | 212 | # inputs.shape = (batch_size, height, width, n_channels) 213 | x = self.backbone(inputs, channels_last=True) 214 | 215 | x = self.pos_emb(x) 216 | 217 | if self.reduction == "spatial_flatten": 218 | B, H, W, F = x.shape 219 | x = x.reshape(shape=(B, H*W, F)) 220 | elif self.reduction == "spatial_average": 221 | x = torch.mean(x, dim=(1,2)) 222 | elif self.reduction == "all_flatten": 223 | x = torch.flatten(x) 224 | elif self.reduction is not None: 225 | raise ValueError(f"Unknown reduction of type: {self.reduction}") 226 | 227 | x = self.output_transform(x) 228 | return x -------------------------------------------------------------------------------- /savi/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | # FIXME -------------------------------------------------------------------------------- /savi/trainers/tfds_trainer.py: -------------------------------------------------------------------------------- 1 | # Reference: MAE github https://github.com/facebookresearch/mae 2 | 3 | # TODO 4 | 5 | import jax 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.backends.cudnn as cudnn 11 | from torch.utils.tensorboard import SummaryWriter 12 | 13 | from typing import Iterable, Optional 14 | 15 | import random 16 | import math 17 | import os 18 | import sys 19 | import argparse 20 | import datetime 21 | import time 22 | import json 23 | from pathlib import Path 24 | 25 | from savi.datasets.tfds import tfds_input_pipeline 26 | from savi.datasets.tfds.tfds_dataset_wrapper import MOViData, MOViDataByRank 27 | import savi.modules as modules 28 | 29 | import savi.lib.losses as losses 30 | import savi.lib.metrics as metrics 31 | 32 | import savi.trainers.utils.misc as misc 33 | import savi.trainers.utils.lr_sched as lr_sched 34 | import savi.trainers.utils.lr_decay as lr_decay 35 | from savi.trainers.utils.misc import NativeScalerWithGradNormCount as NativeScaler 36 | 37 | def get_args(): 38 | parser = argparse.ArgumentParser('TFDS dataset training for SAVi.') 39 | def adrg(name, default, type=str, help=None): 40 | """ADd aRGuments to parser.""" 41 | if help: 42 | parser.add_argument(name, default=default, type=type, help=help) 43 | else: 44 | parser.add_argument(name, default=default, type=type) 45 | 46 | # Training config 47 | adrg('--seed', 42, int) 48 | adrg('--batch_size', 8, int, help='Batch size per GPU \ 49 | (effective batch size = batch_size * accum_iter * # gpus') 50 | # Try to use 8 gpus to get batch size 64, as it is the batch size used in the SAVi code. 
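    # e.g. with the defaults here: 8 (per GPU) x 1 (accum_iter) x 8 (GPUs) = 64.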
51 | adrg('--epochs', 50, int) 52 | adrg('--accum_iter', 1, int, help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') 53 | adrg('--num_train_steps', 100000, int) 54 | adrg('--device', 'cuda', help='device to use for training / testing') 55 | adrg('--num_workers', 10, int) 56 | 57 | # Resuming 58 | parser.add_argument('--resume', default='', 59 | help='resume from checkpoint') 60 | 61 | # distributed training parameters 62 | parser.add_argument('--world_size', default=1, type=int, 63 | help='number of distributed processes') 64 | parser.add_argument('--local_rank', default=-1, type=int) 65 | parser.add_argument('--dist_on_itp', action='store_true') 66 | parser.add_argument('--dist_url', default='env://', 67 | help='url used to set up distributed training') 68 | 69 | # Adam optimizer config 70 | adrg('--lr', 2e-4, float) 71 | adrg('--warmup_steps', 2500, int) 72 | adrg('--max_grad_norm', 0.05, float) 73 | 74 | # Logging and Saving config 75 | adrg('--log_loss_every_step', 50, int) 76 | adrg('--eval_every_steps', 1000, int) 77 | adrg('--checkpoint_every_steps', 5000) 78 | adrg('--output_dir', './output_dir', help="path where to save, empty for no saving.") 79 | adrg('--log_dir', './output_dir', help="path where to log tensorboard log") 80 | 81 | # Misc 82 | parser.add_argument('--pin_mem', action='store_true', 83 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 84 | 85 | 86 | # Metrics Spec 87 | adrg('--metrics', 'loss,ari,ari_nobg') 88 | 89 | # Dataset 90 | adrg('--tfds_name', "movi_a/128x128:1.0.0", help="Dataset for training/eval") 91 | adrg('--data_dir', "/home/junkeun-yi/current/datasets/kubric/") 92 | adrg('--shuffle_buffer_size', 8*8, help="should be batch_size * 8") 93 | 94 | # Model 95 | adrg('--max_instances', 10, int, help="Number of slots") # For Movi-A,B,C, only up to 10. for MOVi-D,E, up to 23. 96 | adrg('--model_size', 'small', help="How to prepare data and model architecture.") 97 | 98 | # Evaluation 99 | adrg('--eval_slice_size', 6, int) 100 | adrg('--eval_slice_keys', 'video,segmentations,flow,boxes') 101 | parser.add_argument('--eval', action='store_true', help='Perform evaluation only') 102 | parser.add_argument('--dist_eval', action='store_true', default=False, 103 | help='Enabling distributed evaluation (recommended during training for faster monitor') 104 | 105 | 106 | args = parser.parse_args() 107 | # Metrics 108 | args.train_metrics_spec = { 109 | v: v for v in args.metrics.split(',')} 110 | args.eval_metrics_spec = { 111 | f"eval_{v}": v for v in args.metrics.split(',')} 112 | # Misc 113 | args.num_slots = args.max_instances + 1 # only used for metrics 114 | args.logging_min_n_colors = args.max_instances 115 | args.eval_slice_keys = [v for v in args.eval_slice_keys.split(',')] 116 | 117 | # HARDCODED 118 | args.targets = {"flow": 3} 119 | args.losses = {f"recon_{target}": {"loss_type": "recon", "key": target} 120 | for target in args.targets} 121 | 122 | # Preprocessing 123 | if args.model_size =="small": 124 | args.preproc_train = [ 125 | "video_from_tfds", 126 | f"sparse_to_dense_annotation(max_instances={args.max_instances})", 127 | "temporal_random_strided_window(length=6)", 128 | "resize_small(64)", 129 | "flow_to_rgb()" # NOTE: This only uses the first two flow dimensions. 
130 | ] 131 | args.preproc_eval = [ 132 | "video_from_tfds", 133 | f"sparse_to_dense_annotation(max_instances={args.max_instances})", 134 | "temporal_crop_or_pad(length=24)", 135 | "resize_small(64)", 136 | "flow_to_rgb()" # NOTE: This only uses the first two flow dimensions. 137 | ] 138 | 139 | return args 140 | 141 | def build_model(args): 142 | if args.model_size == "small": 143 | slot_size = 128 144 | num_slots = args.num_slots 145 | # Encoder 146 | encoder = modules.FrameEncoder( 147 | backbone=modules.CNN( 148 | features=[3, 32, 32, 32, 32], 149 | kernel_size=[(5, 5), (5, 5), (5, 5), (5, 5)], 150 | strides=[(1, 1), (1, 1), (1, 1), (1, 1)], 151 | layer_transpose=[False, False, False, False]), 152 | pos_emb=modules.PositionEmbedding( 153 | input_shape=(args.batch_size, 4, 4, 32), # TODO: validate, should be backbone output size 154 | embedding_type="linear", 155 | update_type="project_add", 156 | output_transform=modules.MLP( 157 | input_size=32, 158 | hidden_size=64, 159 | output_size=32, 160 | layernorm="pre"))) 161 | # Corrector 162 | corrector = modules.SlotAttention( 163 | input_size=32, # TODO: validate, should be backbone output size 164 | qkv_size=128, 165 | slot_size=slot_size, 166 | num_iterations=1) 167 | # Predictor 168 | predictor = modules.TransformerBlock( 169 | embed_dim=slot_size, 170 | num_heads=4, 171 | qkv_size=128, 172 | mlp_size=256) 173 | # Initializer 174 | initializer = modules.CoordinateEncoderStateInit( 175 | embedding_transform=modules.MLP( 176 | input_size=4, # bounding boxes have feature size 4 177 | hidden_size=256, 178 | output_size=slot_size, 179 | layernorm=None), 180 | prepend_background=True, 181 | center_of_mass=False) 182 | # Decoder 183 | decoder = modules.SpatialBroadcastDecoder( 184 | resolution=(8,8), # Update if data resolution or strides change. 
185 | backbone=modules.CNN( 186 | features=[slot_size, 64, 64, 64, 64], 187 | kernel_size=[(5, 5), (5, 5), (5, 5), (5, 5)], 188 | strides=[(2, 2), (2, 2), (2, 2), (1, 1)], 189 | layer_transpose=[True, True, True, False]), 190 | pos_emb=modules.PositionEmbedding( 191 | input_shape=(args.batch_size, 8, 8, 128), 192 | embedding_type="linear", 193 | update_type="project_add"), 194 | target_readout=modules.Readout( 195 | keys=list(args.targets), 196 | readout_modules=nn.ModuleList([ 197 | nn.Linear(64, out_features) for out_features in args.targets.values()]))) 198 | # SAVi Model 199 | model = modules.SAVi( 200 | encoder=encoder, 201 | decoder=decoder, 202 | corrector=corrector, 203 | predictor=predictor, 204 | initializer=initializer, 205 | decode_corrected=True, 206 | decode_predicted=False) 207 | else: 208 | raise NotImplementedError 209 | return model 210 | 211 | def build_datasets(args): 212 | rng = jax.random.PRNGKey(args.seed) 213 | train_ds, eval_ds = tfds_input_pipeline.create_datasets(args, rng) 214 | 215 | num_tasks = misc.get_world_size() 216 | global_rank = misc.get_rank() 217 | 218 | traindata = MOViDataByRank(train_ds, global_rank, num_tasks) 219 | evaldata = MOViDataByRank(eval_ds, global_rank, num_tasks) 220 | 221 | return traindata, evaldata 222 | 223 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, 224 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 225 | device: torch.device, epoch: int, loss_scaler, global_step, max_norm: float = 0, 226 | log_writer=None, args=None): 227 | model.train(True) 228 | metric_logger = misc.MetricLogger(delimiter=" ") 229 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f')) 230 | header = 'Epoch: [{}]'.format(epoch) 231 | print_freq = args.log_loss_every_step 232 | 233 | accum_iter = args.accum_iter 234 | 235 | optimizer.zero_grad() 236 | 237 | if log_writer is not None: 238 | print('log_dir: {}'.format(log_writer.log_dir)) 239 | 240 | # TODO: only first epoch has scheduler, and does step-wise scheduling 241 | if epoch == 0: 242 | scheduler = lr_sched.get_cosine_schedule_with_warmup(optimizer, args.warmup_steps, args.num_train_steps, num_cycles=1, last_epoch=-1) 243 | else: 244 | scheduler = None 245 | 246 | for data_iter_step, (video, boxes, flow, padding_mask, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 247 | 248 | if global_step % args.eval_every_steps: 249 | # TODO: evaluate 250 | pass 251 | 252 | if global_step % args.checkpoint_every_steps: 253 | # TODO: checkpoint 254 | pass 255 | 256 | # SAVi doesn't train on epochs, just on steps. 257 | if global_step > args.num_train_steps: 258 | break 259 | 260 | # need to squeeze because of weird dataset wrapping ... 
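        # (run() builds this DataLoader with batch_size=1 while the tfds wrapper already
        # yields full batches, so every tensor arrives with an extra leading dim of size 1.)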
261 |         video = video.squeeze(0).to(device, non_blocking=True)
262 |         boxes = boxes.squeeze(0).to(device, non_blocking=True)
263 |         flow = flow.squeeze(0).to(device, non_blocking=True)
264 |         padding_mask = padding_mask.squeeze(0).to(device, non_blocking=True)
265 |         # segmentations = segmentations.squeeze(0).to(device, non_blocking=True)
266 | 
267 |         print('video', video.shape, end='\r')
268 | 
269 |         conditioning = boxes # TODO: make this not hardcoded
270 | 
271 |         with torch.cuda.amp.autocast():
272 |             outputs = model(video=video, conditioning=conditioning,
273 |                 padding_mask=padding_mask)
274 |             loss = criterion(outputs["outputs"]["flow"], flow)
275 | 
276 |         loss_value = loss.item()
277 | 
278 |         if not math.isfinite(loss_value):
279 |             print("Loss is {}, stopping training".format(loss_value))
280 |             sys.exit(1)
281 | 
282 |         loss /= accum_iter
283 |         loss_scaler(loss, optimizer, clip_grad=max_norm,
284 |             parameters=model.parameters(), create_graph=False,
285 |             update_grad=(data_iter_step + 1) % accum_iter == 0)
286 |         if (data_iter_step + 1) % accum_iter == 0:
287 |             optimizer.zero_grad()
288 |             if scheduler is not None:
289 |                 scheduler.step()
290 | 
291 |         torch.cuda.synchronize()
292 | 
293 |         metric_logger.update(loss=loss_value)
294 | 
295 |         lr = optimizer.param_groups[0]["lr"]
296 |         metric_logger.update(lr=lr)
297 | 
298 |         loss_value_reduce = misc.all_reduce_mean(loss_value)
299 |         if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
300 |             """ We use epoch_1000x as the x-axis in tensorboard.
301 |             This calibrates different curves when batch size changes.
302 |             """
303 |             epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
304 |             log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x)
305 |             log_writer.add_scalar('lr', lr, epoch_1000x)
306 | 
307 |         global_step += 1
308 | 
309 |     # gather the stats from all processes
310 |     metric_logger.synchronize_between_processes()
311 |     print("Averaged stats:", metric_logger)
312 |     return global_step, {k: meter.global_avg for k, meter in metric_logger.meters.items()}
313 | 
314 | 
315 | @torch.no_grad()
316 | def evaluate(data_loader, model, device, args):
317 |     loss_fn = losses.recon_loss
318 |     ari_fn = metrics.adjusted_rand_index
319 | 
320 | 
321 |     metric_logger = misc.MetricLogger(delimiter=" ")
322 |     header = 'Test:'
323 | 
324 |     # switch to evaluation mode
325 |     model.eval()
326 | 
327 |     for (video, boxes, flow, padding_mask, segmentations) in metric_logger.log_every(data_loader, 10, header):
328 |         video = video.squeeze(0).to(device, non_blocking=True)
329 |         boxes = boxes.squeeze(0).to(device, non_blocking=True)
330 |         flow = flow.squeeze(0).to(device, non_blocking=True)
331 |         padding_mask = padding_mask.squeeze(0).to(device, non_blocking=True)
332 |         segmentations = segmentations.squeeze(0).to(device, non_blocking=True)
333 | 
334 |         conditioning = boxes # TODO: don't hardcode
335 | 
336 |         # compute output
337 |         with torch.cuda.amp.autocast():
338 |             outputs = model(video=video, conditioning=conditioning,
339 |                 padding_mask=padding_mask)
340 |             loss = loss_fn(outputs["outputs"]["flow"], flow)
341 |             ari_bg = ari_fn(pred_ids=outputs["outputs"]["segmentations"],
342 |                 true_ids=segmentations, num_instances_pred=args.num_slots,
343 |                 num_instances_true=args.max_instances + 1, # add bg,
344 |                 padding_mask=padding_mask, ignore_background=False)
345 |             ari_nobg = ari_fn(pred_ids=outputs["outputs"]["segmentations"],
346 |                 true_ids=segmentations, num_instances_pred=args.num_slots,
347 |                 num_instances_true=args.max_instances + 1, # add bg,
348 | 
padding_mask=padding_mask, ignore_background=True) 349 | 350 | # TODO: change tensors to numpy before doing calculations. 351 | # TODO: update meters with number of items according to given by metrics fn 352 | 353 | batch_size = video.shape[0] 354 | metric_logger.update(loss=loss.item()) 355 | metric_logger.meters['ari'].update(ari_bg, n=batch_size) 356 | metric_logger.meters['ari_nobg'].update(ari_nobg, n=batch_size) 357 | 358 | # gather the stats from all processes 359 | metric_logger.synchronize_between_processes() 360 | print('* ARI {ari_bg.global_avg:.3f} ARI_NoBg {ari_nobg.global_avg:.3f} loss {losses.global_avg:.3f}' 361 | .format(ari_bg=metric_logger.ari, ari_nobg=metric_logger.ari_nobg, losses=metric_logger.loss)) 362 | 363 | # switch back to training 364 | model.train() 365 | 366 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 367 | 368 | def run(args): 369 | misc.init_distributed_mode(args) 370 | 371 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) 372 | print("{}".format(args).replace(', ', ',\n')) 373 | 374 | device = torch.device(args.device) 375 | 376 | # fix the seed for reproducibility 377 | seed = args.seed + misc.get_rank() 378 | torch.manual_seed(seed) 379 | np.random.seed(seed) 380 | random.seed(seed) 381 | 382 | cudnn.benchmark = True 383 | 384 | dataset_train, dataset_val = build_datasets(args) 385 | 386 | if True: # args.distributed: 387 | num_tasks = misc.get_world_size() 388 | global_rank = misc.get_rank() 389 | # sampler_train = torch.utils.data.DistributedSampler( 390 | # dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=False) 391 | sampler_train = torch.utils.data.SequentialSampler(dataset_train) 392 | # print("Sampler_train") 393 | if args.dist_eval: 394 | if len(dataset_val) % num_tasks != 0: 395 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' 396 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 397 | 'equal num of samples per-process.') 398 | # sampler_val = torch.utils.data.DistributedSampler( 399 | # dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) 400 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 401 | else: 402 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 403 | else: 404 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 405 | sampler_val = troch.utils.data.SequentialSampler(dataset_val) 406 | 407 | if global_rank == 0 and args.log_dir is not None: 408 | os.makedirs(args.log_dir, exist_ok=True) 409 | log_writer = SummaryWriter(log_dir=args.log_dir) 410 | else: 411 | log_writer = None 412 | 413 | data_loader_train = torch.utils.data.DataLoader( 414 | dataset_train, sampler=sampler_train, 415 | batch_size=1, # HARDCODED because doing something weird with this. 416 | num_workers=args.num_workers, 417 | pin_memory=args.pin_mem, 418 | drop_last=True 419 | ) 420 | 421 | data_loader_val = torch.utils.data.DataLoader( 422 | dataset_val, sampler=sampler_val, 423 | batch_size=1, # HARDCODED because doing something weird with this. 
424 |         num_workers=args.num_workers,
425 |         pin_memory=args.pin_mem,
426 |         drop_last=True
427 |     )
428 | 
429 |     # Model setup
430 |     model = build_model(args)
431 | 
432 |     # TODO: make checkpoint loading
433 | 
434 |     model.to(device)
435 | 
436 |     model_without_ddp = model
437 |     n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
438 | 
439 |     # print("Model = %s" % str(model_without_ddp))
440 |     print('number of params (M): %.2f' % (n_parameters / 1.e6))
441 | 
442 |     eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
443 | 
444 |     print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
445 |     print("actual lr: %.2e" % args.lr)
446 | 
447 |     print("accumulate grad iterations: %d" % args.accum_iter)
448 |     print("effective batch size: %d" % eff_batch_size)
449 | 
450 |     if args.distributed:
451 |         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
452 |         model_without_ddp = model.module
453 | 
454 |     # build optimizer
455 |     optimizer = torch.optim.Adam(model_without_ddp.parameters(), lr=args.lr)
456 |     loss_scaler = NativeScaler()
457 | 
458 |     # Loss
459 |     criterion = losses.recon_loss
460 |     print("criterion = %s" % str(criterion))
461 | 
462 |     misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
463 | 
464 |     if args.eval:
465 |         test_stats = evaluate(data_loader_val, model, device, args)
466 |         print(test_stats)
467 |         exit(0)
468 | 
469 |     print(f"Start training for {args.num_train_steps} steps.")
470 |     start_time = time.time()
471 |     max_accuracy = 0.0
472 |     global_step = 0
473 |     for epoch in range(0, args.epochs):
474 |         # if args.distributed:
475 |         #     data_loader_train.sampler.set_epoch(epoch)
476 |         step_add, train_stats = train_one_epoch(
477 |             model, criterion, data_loader_train,
478 |             optimizer, device, epoch, loss_scaler,
479 |             global_step, args.max_grad_norm,
480 |             log_writer, args
481 |         )
482 |         global_step = step_add # train_one_epoch returns the already-updated global step.
483 |         if args.output_dir:
484 |             misc.save_model(
485 |                 args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
486 |                 loss_scaler=loss_scaler, epoch=epoch)
487 | 
488 |         test_stats = evaluate(data_loader_val, model, device, args)
489 |         print(test_stats)
490 | 
491 |         # log writer stuff.
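    # One way to wrap up the logging hinted at above (a sketch, not the original code):
    # flush and close the SummaryWriter once training finishes, e.g.
    #   if log_writer is not None:
    #       log_writer.flush()
    #       log_writer.close()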
492 | 493 | def main(): 494 | args = get_args() 495 | 496 | if args.output_dir: 497 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 498 | 499 | run(args) 500 | 501 | 502 | def test(): 503 | # args = get_args() 504 | # model = build_model(args) 505 | # print(model) 506 | 507 | main() 508 | 509 | 510 | if __name__ == "__main__": 511 | test() 512 | 513 | 514 | """ 515 | 516 | PYTHONPATH=$PYTHONPATH:./ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m torch.distributed.launch --nproc_per_node=8 savi/main.py 517 | 518 | PYTHONPATH=$PYTHONPATH:./ CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m torch.distributed.launch --nproc_per_node=4 savi/main.py 519 | 520 | PYTHONPATH=$PYTHONPATH:./ CUDA_VISIBLE_DEVICES=0 python3 -m torch.distributed.launch --nproc_per_node=1 savi/main.py 521 | 522 | """ -------------------------------------------------------------------------------- /savi/trainers/tfds_trainer_dataparallel.py: -------------------------------------------------------------------------------- 1 | # Reference: MAE github https://github.com/facebookresearch/mae 2 | 3 | # TODO 4 | 5 | import jax 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torchvision 11 | 12 | from typing import Iterable, Optional 13 | 14 | import random 15 | import math 16 | import os 17 | import sys 18 | import argparse 19 | from datetime import datetime 20 | from pathlib import Path 21 | import wandb 22 | 23 | from savi.datasets.tfds import tfds_input_pipeline 24 | from savi.datasets.tfds.tfds_dataset_wrapper import MOViData 25 | import savi.modules as modules 26 | import savi.modules_flow as modules_flow 27 | import savi.modules.evaluator 28 | 29 | import savi.trainers.utils.misc as misc 30 | import savi.trainers.utils.lr_sched as lr_sched 31 | import savi.trainers.utils.lr_decay as lr_decay 32 | 33 | processors_dict = { 34 | 'savi': modules.savi_build_modules, 35 | 'flow': modules_flow.flow_build_modules 36 | } 37 | 38 | def get_args(): 39 | parser = argparse.ArgumentParser('TFDS dataset training for SAVi.') 40 | def adrg(name, default, type=str, help=None): 41 | """ADd aRGuments to parser.""" 42 | if help: 43 | parser.add_argument(name, default=default, type=type, help=help) 44 | else: 45 | parser.add_argument(name, default=default, type=type) 46 | 47 | # Training config 48 | adrg('--seed', 42, int) 49 | adrg('--epochs', 50, int) 50 | adrg('--num_train_steps', 100000, int) 51 | adrg('--batch_size', 64, int, help='Batch size') 52 | parser.add_argument('--accum_iter', default=1, type=int, help="Accumulate gradient iterations (for increasing the effective batch size under memory constraints)") 53 | parser.add_argument('--gpu', default='1', type=str, help='GPU id to use.') 54 | parser.add_argument('--slice_decode_inputs', action='store_true', help="decode in slices.") 55 | 56 | # Adam optimizer config 57 | adrg('--lr', 2e-4, float) 58 | adrg('--warmup_steps', 2500, int) 59 | adrg('--max_grad_norm', 0.05, float) 60 | 61 | # Logging and Saving config 62 | adrg('--log_loss_every_step', 50, int) 63 | adrg('--eval_every_steps', 1000, int) 64 | adrg('--checkpoint_every_steps', 5000, int) 65 | # adrg('--output_dir', './output_dir', help="path where to save, empty for no saving.") 66 | # adrg('--log_dir', './output_dir', help="path where to log tensorboard log") 67 | adrg('--exp', 'test', help="experiment name") 68 | parser.add_argument('--no_snap', action='store_true', help="don't snapshot model") 69 | parser.add_argument('--wandb', action='store_true', help="wandb 
logging") 70 | adrg('--group', 'test', help="wandb logging group") 71 | 72 | # Loading model 73 | adrg('--resume_from', None, str, help="absolute path of experiment snapshot") 74 | 75 | # Metrics Spec 76 | adrg('--metrics', 'loss,ari,ari_nobg') 77 | 78 | # Dataset 79 | adrg('--tfds_name', "movi_a/128x128:1.0.0", help="Dataset for training/eval") 80 | adrg('--data_dir', "/home/junkeun-yi/current/datasets/kubric/") 81 | # adrg('--shuffle_buffer_size', 64, help="should be batch_size") 82 | 83 | # Model 84 | adrg('--max_instances', 23, int, help="Number of slots") # For Movi-A,B,C, only up to 10. for MOVi-D,E, up to 23. 85 | adrg('--model_size', 'small', help="How to prepare data and model architecture.") 86 | adrg('--model_type', 'savi', help="model type") 87 | parser.add_argument('--init_weight', default='default', help='weight init') 88 | parser.add_argument('--init_bias', default='default', help='bias init') 89 | 90 | # Evaluation 91 | adrg('--eval_slice_size', 6, int) 92 | # adrg('--eval_slice_keys', 'video,segmentations,flow,boxes') 93 | parser.add_argument('--eval', action='store_true', help="Perform evaluation only") 94 | 95 | 96 | args = parser.parse_args() 97 | # Weights 98 | args.weight_init = { 99 | 'param': args.init_weight, 100 | 'linear_w': args.init_weight, 101 | 'linear_b': args.init_bias, 102 | 'conv_w': args.init_weight, 103 | # convtranspose kernel shape requires special handling. 104 | 'convtranspose_w': "lecun_normal_fan_out" if args.init_weight == 'lecun_normal' else args.init_weight, 105 | 'conv_b': args.init_bias} 106 | # Training 107 | args.gpu = [int(i) for i in args.gpu.split(',')] 108 | # Metrics 109 | args.train_metrics_spec = { 110 | v: v for v in args.metrics.split(',')} 111 | args.eval_metrics_spec = { 112 | f"eval_{v}": v for v in args.metrics.split(',')} 113 | # Misc 114 | args.num_slots = args.max_instances + 1 # only used for metrics 115 | args.logging_min_n_colors = args.max_instances 116 | # args.eval_slice_keys = [v for v in args.eval_slice_keys.split(',')] 117 | args.shuffle_buffer_size = args.batch_size * 8 118 | # if not args.group: 119 | # args.group = f"{args.model_type}_{args.tfds_name.split('/')[0]}" 120 | kwargs = {} 121 | kwargs['slice_decode_inputs'] = True if args.slice_decode_inputs else False 122 | args.kwargs = kwargs 123 | 124 | # HARDCODED 125 | args.targets = {"flow": 3} 126 | args.losses = {f"recon_{target}": {"loss_type": "recon", "key": target} 127 | for target in args.targets} 128 | 129 | # Preprocessing 130 | if args.model_size == "small": 131 | args.preproc_train = [ 132 | "video_from_tfds", 133 | f"sparse_to_dense_annotation(max_instances={args.max_instances})", 134 | "temporal_random_strided_window(length=6)", 135 | "resize_small(64)", 136 | "flow_to_rgb()" # NOTE: This only uses the first two flow dimensions. 137 | ] 138 | args.preproc_eval = [ 139 | "video_from_tfds", 140 | f"sparse_to_dense_annotation(max_instances={args.max_instances})", 141 | "temporal_crop_or_pad(length=24)", 142 | "resize_small(64)", 143 | "flow_to_rgb()" # NOTE: This only uses the first two flow dimensions. 144 | ] 145 | elif args.model_size == "medium": 146 | args.preproc_train = [ 147 | "video_from_tfds", 148 | f"sparse_to_dense_annotation(max_instances={args.max_instances})", 149 | "temporal_random_strided_window(length=6)", 150 | "resize_small(128)", 151 | "flow_to_rgb()" # NOTE: This only uses the first two flow dimensions. 
152 | ] 153 | args.preproc_eval = [ 154 | "video_from_tfds", 155 | f"sparse_to_dense_annotation(max_instances={args.max_instances})", 156 | "temporal_crop_or_pad(length=24)", 157 | "resize_small(128)", 158 | "flow_to_rgb()" # NOTE: This only uses the first two flow dimensions. 159 | ] 160 | 161 | return args 162 | 163 | 164 | def build_datasets(args): 165 | rng = jax.random.PRNGKey(args.seed) 166 | train_ds, eval_ds = tfds_input_pipeline.create_datasets(args, rng) 167 | 168 | traindata = MOViData(train_ds) 169 | evaldata = MOViData(eval_ds) 170 | 171 | return traindata, evaldata 172 | 173 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, 174 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 175 | device: torch.device, epoch: int, global_step, start_time, 176 | max_norm: Optional[float] = None, args=None, 177 | val_loader=None, evaluator=None): 178 | model.train(True) 179 | 180 | # TODO: this is needed ... cuz using hack tfds wrapper. 181 | dataset = data_loader.dataset 182 | dataset.reset_itr() 183 | len_data = len(dataset) 184 | data_loader = torch.utils.data.DataLoader(dataset, 1, shuffle=False) 185 | 186 | # TODO: only first epoch has scheduler, and does step-wise scheduling 187 | if epoch == 0: 188 | # scheduler = None 189 | scheduler = lr_sched.get_cosine_schedule_with_warmup(optimizer, args.warmup_steps, args.num_train_steps, num_cycles=0.5, last_epoch=-1) 190 | else: 191 | scheduler = None 192 | 193 | loss = None 194 | grad_accum = 0 195 | for data_iter_step, (video, boxes, segmentations, flow, padding_mask, mask) in enumerate(data_loader): 196 | # need to squeeze because of weird dataset wrapping ... 197 | video = video.squeeze(0).to(device, non_blocking=True) # [64, 6, 64, 64, 3] 198 | boxes = boxes.squeeze(0).to(device, non_blocking=True) 199 | flow = flow.squeeze(0).to(device, non_blocking=True) 200 | padding_mask = padding_mask.squeeze(0).to(device, non_blocking=True) 201 | mask = mask.squeeze(0).to(device, non_blocking=True) if len(mask) > 0 else None 202 | segmentations = segmentations.squeeze(0).to(device, non_blocking=True) 203 | batch = (video, boxes, segmentations, flow, padding_mask, mask) 204 | 205 | conditioning = boxes # TODO: make this not hardcoded 206 | 207 | outputs = model(video=video, conditioning=conditioning, 208 | padding_mask=padding_mask) 209 | itr_loss = criterion(outputs, batch) 210 | if loss == None: 211 | loss = itr_loss 212 | del outputs 213 | del batch 214 | 215 | grad_accum += 1 216 | if grad_accum != args.accum_iter: 217 | # accumulating gradients to reach effective batch size. 218 | # effective batch size is batch_size * accum_iter 219 | 220 | # since loss will be [loss(item) for item in batch], we can 221 | # update the loss by extending the losses (i think). need to check 222 | loss = torch.cat([loss, itr_loss], dim=0) 223 | else: 224 | loss = loss.mean() # sum over elements, mean over batch. 
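                # NOTE (descriptive comment, added for clarity): at this point grad_accum has
                # reached args.accum_iter, so `loss` holds the per-sample losses gathered since
                # the last optimizer step; averaging them before the single backward()/step()
                # below is what yields the effective batch size of batch_size * accum_iter.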
225 | 226 | loss_value = loss.item() 227 | 228 | print(f"step: {global_step+1} / {args.num_train_steps}, loss: {loss_value}, clock: {datetime.now()-start_time}", end='\r') 229 | 230 | if not math.isfinite(loss_value): 231 | print("Loss is {}, stopping training".format(loss_value)) 232 | sys.exit(1) 233 | 234 | optimizer.zero_grad(set_to_none=True) 235 | 236 | loss.backward() 237 | # clip grad norm 238 | # TODO: fix grad norm clipping, as it's making the loss NaN 239 | if max_norm is not None: 240 | torch.nn.utils.clip_grad_norm(model.parameters(), max_norm) 241 | optimizer.step() 242 | if scheduler is not None: 243 | scheduler.step() 244 | 245 | if args.wandb: 246 | wandb.log({'train/loss': loss_value}) 247 | wandb.log({'train/lr': optimizer.param_groups[0]['lr']}) 248 | 249 | # global stepper. 250 | global_step += 1 251 | # if global_step % args.log_loss_every_step == 0: 252 | # # TODO: log the loss (with tensorboard / csv) 253 | # if args.wandb: 254 | # wandb.log({'train/loss': loss_value}) 255 | # print() 256 | # print() 257 | if global_step % args.eval_every_steps == 0: 258 | print() 259 | evaluate(val_loader, model, criterion, evaluator, device, args, global_step) 260 | if not args.no_snap and global_step % args.checkpoint_every_steps == 0: 261 | misc.save_snapshot(args, model.module, optimizer, global_step, f'./experiments/{args.group}_{args.exp}/snapshots/{global_step}.pt') 262 | # SAVi doesn't train on epochs, just on steps. 263 | if global_step >= args.num_train_steps: 264 | # save before exit 265 | print('done training') 266 | misc.save_snapshot(args, model.module, optimizer, global_step, f'./experiments/{args.group}_{args.exp}/snapshots/{global_step}.pt') 267 | print('exiting') 268 | # if args.wandb: 269 | # wandb.alert( 270 | # title="End of Run", 271 | # text=f"Run {args.group}_{args.exp} ended after {datetime.now()-start_time} time") 272 | sys.exit(0) 273 | 274 | grad_accum = 0 275 | loss = None 276 | 277 | return global_step, loss 278 | 279 | 280 | @torch.no_grad() 281 | def evaluate(data_loader, model, criterion, evaluator, device, args, name="test"): 282 | 283 | # switch to evaluation mode 284 | model.eval() 285 | 286 | # TODO: this is needed ... cuz using hack tfds wrapper. 287 | dataset = data_loader.dataset 288 | dataset.reset_itr() 289 | len_data = len(dataset) 290 | data_loader = torch.utils.data.DataLoader(dataset, 1, shuffle=False) 291 | 292 | loss_value = 1e12 293 | ari_running = {'total': 0, 'count': 0} 294 | ari_nobg_running = {'total': 0, 'count': 0} 295 | for i_batch, (video, boxes, segmentations, flow, padding_mask, mask) in enumerate(data_loader): 296 | # need to squeeze because of weird dataset wrapping ... 
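        # The MOViData wrapper already yields fully batched tensors; the outer
        # DataLoader (batch_size=1) only adds a leading dimension of size 1,
        # which the squeeze(0) calls below strip off again.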
297 | video = video.squeeze(0).to(device, non_blocking=True) # [64, 6, 64, 64, 3] 298 | boxes = boxes.squeeze(0).to(device, non_blocking=True) 299 | flow = flow.squeeze(0).to(device, non_blocking=True) 300 | padding_mask = padding_mask.squeeze(0).to(device, non_blocking=True) 301 | mask = mask.squeeze(0).to(device, non_blocking=True) if len(mask) > 0 else None 302 | segmentations = segmentations.squeeze(0).to(device, non_blocking=True) 303 | batch = (video, boxes, segmentations, flow, padding_mask, mask) 304 | 305 | conditioning = boxes # TODO: don't hardcode 306 | 307 | # compute output 308 | if args.model_type == "savi": 309 | outputs = savi.modules.evaluator.eval_step(model, batch, slice_size=args.eval_slice_size) 310 | else: 311 | outputs = model(video=video, conditioning=conditioning, 312 | padding_mask=padding_mask, **args.kwargs) 313 | loss = criterion(outputs, batch) 314 | loss = loss.mean() # mean over devices 315 | loss_value = loss.item() 316 | 317 | ari_bg, ari_nobg = evaluator(outputs, batch, args) 318 | 319 | for k, v in ari_bg.items(): 320 | ari_running[k] += v.item() 321 | for k, v in ari_nobg.items(): 322 | ari_nobg_running[k] += v.item() 323 | 324 | # print(f"{i_batch+1} / {len_data}, loss: {loss_value}, running_ari_fg: {ari_nobg_running['total'] / ari_nobg_running['count']}", end='\r') 325 | print(f"{i_batch+1} / {len_data}, loss: {loss_value}, running_ari: {ari_running['total'] / ari_running['count']}, running_ari_fg: {ari_nobg_running['total'] / ari_nobg_running['count']}", end='\r') 326 | 327 | # visualize first 3 iterations 328 | if i_batch == 0: 329 | for i_sample in range(3): 330 | if args.model_type == "savi": 331 | B, T, H, W, _ = video.shape 332 | # attn = outputs['attention'][0].squeeze(1) 333 | attn = outputs[2][i_sample].squeeze(1) 334 | attn = attn.reshape(shape=(attn.shape[0], args.num_slots, int(attn.shape[-1] ** (1/2)), int(attn.shape[-1] ** (1/2)))) 335 | if attn.shape[-2:] != video.shape[-3:-1]: 336 | attn = F.interpolate(attn, size=video.shape[-3:-1], mode='bilinear', align_corners=True).view(attn.shape[0], args.num_slots, *video.shape[-3:-1]) 337 | # pr_flow = outputs['outputs']['flow'][0] 338 | pr_flow = outputs[1][i_sample] 339 | # pr_seg = outputs['outputs']['segmentations'][0].squeeze(-1) 340 | pr_seg = outputs[0][i_sample].squeeze(-1) 341 | else: 342 | pr_flow = outputs[2][i_sample] 343 | B, T, H, W, _ = video.shape 344 | pr_flow = torch.cat([torch.zeros(1,H,W,2).to(pr_flow.get_device()), pr_flow], dim=0) 345 | pr_flow = torchvision.utils.flow_to_image( 346 | pr_flow.permute(0,3,1,2)).permute(0,2,3,1).reshape(shape=(T, H, W, 3)) 347 | attn = outputs[4][i_sample] 348 | attn = attn.reshape(shape=(T, H, W, args.num_slots)).permute(0, 3, 1, 2) 349 | pr_seg = outputs[1][i_sample].squeeze(-1) 350 | pr_vid = outputs[0][i_sample] 351 | pr_vid = torch.clamp(pr_vid, 0.0, 1.0) 352 | pr_flow = torch.clamp(pr_flow, 0.0, 1.0) 353 | # visualize attention 354 | misc.viz_slots_flow(video[i_sample].cpu().numpy(), 355 | flow[i_sample].cpu().numpy(), pr_flow.cpu().numpy(), attn.cpu().numpy(), 356 | f"./experiments/{args.group}_{args.exp}/viz_slots_flow/{name}_{i_batch}.png", 357 | trunk=8, send_to_wandb=True if args.wandb else False) 358 | # visualize attention again 359 | if args.model_type == "flow": 360 | misc.viz_slots_frame_pred(video[i_sample].cpu().numpy(), 361 | pr_vid.cpu().numpy(), pr_flow.cpu().numpy(), attn.cpu().numpy(), 362 | f"./experiments/{args.group}_{args.exp}/viz_slots_frame_pred/{name}_{i_batch}.png", 363 | trunk=6, send_to_wandb=True if 
args.wandb else False) 364 | # visualize segmentation 365 | misc.viz_seg(video[i_sample].cpu().numpy(), 366 | segmentations[i_sample].int().cpu().numpy(), 367 | pr_seg.int().cpu().numpy(), 368 | f"./experiments/{args.group}_{args.exp}/viz_seg/{name}_{i_batch}.png", 369 | trunk=10, send_to_wandb=True if args.wandb else False) 370 | final_loss = loss_value 371 | final_ari = ari_running['total'] / ari_running['count'] 372 | final_ari_nobg = ari_nobg_running['total'] / ari_nobg_running['count'] 373 | 374 | print(f"{name}: loss: {final_loss}, ari_bg: {final_ari}, ari_fg: {final_ari_nobg}") 375 | # print(f"{name}: loss: {final_loss}, ari_fg: {final_ari_nobg}") 376 | 377 | # switch back to training 378 | model.train() 379 | 380 | # TODO: log (tensorboard or csv) 381 | if args.wandb: 382 | wandb.log({'eval/loss': final_loss, 'eval/ari': final_ari, 'eval/ari_fg': final_ari_nobg}) 383 | # only log foreground ari ... 384 | # wandb.log({'eval/loss': final_loss, 'eval/ari_fg': final_ari_nobg}) 385 | 386 | return final_loss, final_ari, final_ari_nobg 387 | 388 | 389 | def run(args): 390 | 391 | if args.wandb: 392 | wandb.init(project="savi_new", name=args.exp, group=args.group) 393 | # TODO: tensorboard or csv 394 | 395 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) 396 | print("{}".format(args).replace(', ', ',\n')) 397 | 398 | device = torch.device(args.gpu[0]) 399 | 400 | # fix the seed for reproducibility 401 | seed = args.seed 402 | torch.manual_seed(seed) 403 | np.random.seed(seed) 404 | random.seed(seed) 405 | 406 | dataset_train, dataset_val = build_datasets(args) 407 | 408 | # Not using DistributedDataParallel ... only DataParallel 409 | # Need to set batch size to 1 because only passing through the torch dataset interface 410 | train_loader = torch.utils.data.DataLoader(dataset_train, 1, shuffle=False) 411 | val_loader = torch.utils.data.DataLoader(dataset_val, 1, shuffle=False) 412 | 413 | # Model setup 414 | model, criterion, evaluator = processors_dict[args.model_type](args) 415 | model = model.to(device) 416 | criterion = criterion.to(device) 417 | evaluator = evaluator.to(device) 418 | 419 | # print parameter overview # TODO: log this 420 | print(misc.parameter_overview(model)) 421 | print(model) 422 | 423 | # build optimizer 424 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 425 | 426 | if args.resume_from is not None: 427 | _, resume_step = misc.load_snapshot(model, optimizer, device, args.resume_from) 428 | 429 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 430 | # print("Model = %s" % str(model_without_ddp)) 431 | print('number of params (M): %.2f' % (n_parameters / 1.e6)) 432 | print("lr: %.2e" % args.lr) 433 | print(f"effective batch size: {args.batch_size * args.accum_iter}") 434 | 435 | # Loss 436 | print("criterion = %s" % str(criterion)) 437 | 438 | # make dataparallel 439 | model = nn.DataParallel(model, device_ids=args.gpu) 440 | criterion = nn.DataParallel(criterion, device_ids=args.gpu) 441 | 442 | print(f"Start training for {args.num_train_steps} steps.") 443 | start_time = datetime.now() 444 | global_step = resume_step if args.resume_from is not None else 0 445 | 446 | # eval only 447 | if args.eval: 448 | # assert isinstance(args.resume_from, str), "no snapshot given." 
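        # Evaluation-only mode: intended to be run with --resume_from pointing at a saved
        # snapshot (see the commented assert above); it performs a single pass over the
        # validation set and then exits.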
449 | evaluate(val_loader, model, criterion, evaluator, device, args, f"eval") 450 | # evaluate(train_loader, model, criterion, evaluator, device, args, f"eval") 451 | sys.exit(1) 452 | 453 | for epoch in range(args.epochs): 454 | step_add, loss = train_one_epoch( 455 | model, criterion, train_loader, 456 | optimizer, device, epoch, 457 | global_step, start_time, 458 | args.max_grad_norm, args, 459 | val_loader, evaluator 460 | ) 461 | global_step += step_add 462 | print(f"epoch: {epoch+1}, loss: {loss}, clock: {datetime.now()-start_time}") 463 | 464 | evaluate(val_loader, model, 465 | criterion, device, args, 466 | f"epoch_{epoch+1}") 467 | 468 | if not args.no_snap: 469 | misc.save_snapshot(args, model.module, optimizer, global_step, f'./experiments/{args.exp}/snapshots/{epoch+1}.pt') 470 | 471 | # global stepper 472 | if global_step >= args.num_train_steps: 473 | break 474 | 475 | def main(): 476 | args = get_args() 477 | 478 | # if args.output_dir: 479 | # Path(args.output_dir).mkdir(parents=True, exist_ok=True) 480 | 481 | run(args) 482 | 483 | def test(): 484 | main() 485 | # import ipdb 486 | 487 | # args = get_args() 488 | 489 | # dataset_train, dataset_val = build_datasets(args) 490 | # dataloader = DataLoader(dataset_train, 1, shuffle=False) 491 | 492 | # for i, out in enumerate(dataloader): 493 | # print(i, [a.shape for a in out], end='\r') 494 | 495 | # ipdb.set_trace() 496 | 497 | """ 498 | 499 | python -m savi.main --gpu 1,2,3,4 500 | 501 | """ -------------------------------------------------------------------------------- /savi/trainers/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # FIXME -------------------------------------------------------------------------------- /savi/trainers/utils/lr_decay.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # ELECTRA https://github.com/google-research/electra 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import json 13 | 14 | 15 | def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75): 16 | """ 17 | Parameter groups for layer-wise lr decay 18 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 19 | """ 20 | param_group_names = {} 21 | param_groups = {} 22 | 23 | num_layers = len(model.blocks) + 1 24 | 25 | layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1)) 26 | 27 | for n, p in model.named_parameters(): 28 | if not p.requires_grad: 29 | continue 30 | 31 | # no decay: all 1D parameters and model specific ones 32 | if p.ndim == 1 or n in no_weight_decay_list: 33 | g_decay = "no_decay" 34 | this_decay = 0. 
35 | else: 36 | g_decay = "decay" 37 | this_decay = weight_decay 38 | 39 | layer_id = get_layer_id_for_vit(n, num_layers) 40 | group_name = "layer_%d_%s" % (layer_id, g_decay) 41 | 42 | if group_name not in param_group_names: 43 | this_scale = layer_scales[layer_id] 44 | 45 | param_group_names[group_name] = { 46 | "lr_scale": this_scale, 47 | "weight_decay": this_decay, 48 | "params": [], 49 | } 50 | param_groups[group_name] = { 51 | "lr_scale": this_scale, 52 | "weight_decay": this_decay, 53 | "params": [], 54 | } 55 | 56 | param_group_names[group_name]["params"].append(n) 57 | param_groups[group_name]["params"].append(p) 58 | 59 | # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) 60 | 61 | return list(param_groups.values()) 62 | 63 | 64 | def get_layer_id_for_vit(name, num_layers): 65 | """ 66 | Assign a parameter with its layer id 67 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 68 | """ 69 | if name in ['cls_token', 'pos_embed']: 70 | return 0 71 | elif name.startswith('patch_embed'): 72 | return 0 73 | elif name.startswith('blocks'): 74 | return int(name.split('.')[1]) + 1 75 | else: 76 | return num_layers 77 | 78 | 79 | def linear_warmup(optimizer, step, warmup_steps, peak_lr): 80 | if step > warmup_steps: 81 | return None 82 | lr = step * (peak_lr / (warmup_steps)) 83 | 84 | for param_group in optimizer.param_groups: 85 | param_group["lr"] = lr 86 | return lr -------------------------------------------------------------------------------- /savi/trainers/utils/lr_sched.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | from torch.optim.lr_scheduler import LambdaLR 7 | 8 | import math 9 | 10 | def adjust_learning_rate(optimizer, lr, step, warmup_steps): 11 | """Decay the learning rate with half-cycle cosine after warmup""" 12 | if step < warmup_steps: 13 | lr = args.lr * epoch / args.warmup_epochs 14 | else: 15 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \ 16 | (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))) 17 | for param_group in optimizer.param_groups: 18 | if "lr_scale" in param_group: 19 | param_group["lr"] = lr * param_group["lr_scale"] 20 | else: 21 | param_group["lr"] = lr 22 | return lr 23 | 24 | # TODO: 25 | def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, last_epoch=-1): 26 | """ Create a schedule with a learning rate that decreases following the 27 | values of the cosine function between 0 and `pi * cycles` after a warmup 28 | period during which it increases linearly between 0 and 1. 29 | """ 30 | def lr_lambda(current_step): 31 | if current_step < num_warmup_steps: 32 | return float(current_step) / float(max(1, num_warmup_steps)) 33 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 34 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) 35 | 36 | return LambdaLR(optimizer, lr_lambda, last_epoch) -------------------------------------------------------------------------------- /savi/trainers/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import builtins 13 | import datetime 14 | import os 15 | import time 16 | from collections import defaultdict, deque 17 | from pathlib import Path 18 | from pyparsing import line_end 19 | from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union 20 | import dataclasses 21 | 22 | import torch 23 | import torch.nn as nn 24 | import torch.distributed as dist 25 | from torch._six import inf 26 | 27 | import matplotlib.pyplot as plt 28 | import numpy as np 29 | from skimage.color import label2rgb 30 | 31 | import wandb 32 | 33 | class SmoothedValue(object): 34 | """Track a series of values and provide access to smoothed values over a 35 | window or the global series average. 36 | """ 37 | 38 | def __init__(self, window_size=20, fmt=None): 39 | if fmt is None: 40 | fmt = "{median:.4f} ({global_avg:.4f})" 41 | self.deque = deque(maxlen=window_size) 42 | self.total = 0.0 43 | self.count = 0 44 | self.fmt = fmt 45 | 46 | def update(self, value, n=1): 47 | self.deque.append(value) 48 | self.count += n 49 | self.total += value * n 50 | 51 | def synchronize_between_processes(self): 52 | """ 53 | Warning: does not synchronize the deque! 54 | """ 55 | if not is_dist_avail_and_initialized(): 56 | return 57 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 58 | dist.barrier() 59 | dist.all_reduce(t) 60 | t = t.tolist() 61 | self.count = int(t[0]) 62 | self.total = t[1] 63 | 64 | @property 65 | def median(self): 66 | d = torch.tensor(list(self.deque)) 67 | return d.median().item() 68 | 69 | @property 70 | def avg(self): 71 | d = torch.tensor(list(self.deque), dtype=torch.float32) 72 | return d.mean().item() 73 | 74 | @property 75 | def global_avg(self): 76 | return self.total / self.count 77 | 78 | @property 79 | def max(self): 80 | return max(self.deque) 81 | 82 | @property 83 | def value(self): 84 | return self.deque[-1] 85 | 86 | def __str__(self): 87 | return self.fmt.format( 88 | median=self.median, 89 | avg=self.avg, 90 | global_avg=self.global_avg, 91 | max=self.max, 92 | value=self.value) 93 | 94 | 95 | class MetricLogger(object): 96 | def __init__(self, delimiter="\t"): 97 | self.meters = defaultdict(SmoothedValue) 98 | self.delimiter = delimiter 99 | 100 | def update(self, **kwargs): 101 | for k, v in kwargs.items(): 102 | if v is None: 103 | continue 104 | if isinstance(v, torch.Tensor): 105 | v = v.item() 106 | assert isinstance(v, (float, int)) 107 | self.meters[k].update(v) 108 | 109 | def __getattr__(self, attr): 110 | if attr in self.meters: 111 | return self.meters[attr] 112 | if attr in self.__dict__: 113 | return self.__dict__[attr] 114 | raise AttributeError("'{}' object has no attribute '{}'".format( 115 | type(self).__name__, attr)) 116 | 117 | def __str__(self): 118 | loss_str = [] 119 | for name, meter in self.meters.items(): 120 | loss_str.append( 121 | "{}: {}".format(name, str(meter)) 122 | ) 123 | return self.delimiter.join(loss_str) 124 | 125 | def synchronize_between_processes(self): 126 | for meter in self.meters.values(): 127 | meter.synchronize_between_processes() 128 | 129 | def 
add_meter(self, name, meter): 130 | self.meters[name] = meter 131 | 132 | def log_every(self, iterable, print_freq, header=None): 133 | i = 0 134 | if not header: 135 | header = '' 136 | start_time = time.time() 137 | end = time.time() 138 | iter_time = SmoothedValue(fmt='{avg:.4f}') 139 | data_time = SmoothedValue(fmt='{avg:.4f}') 140 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 141 | log_msg = [ 142 | header, 143 | '[{0' + space_fmt + '}/{1}]', 144 | 'eta: {eta}', 145 | '{meters}', 146 | 'time: {time}', 147 | 'data: {data}' 148 | ] 149 | if torch.cuda.is_available(): 150 | log_msg.append('max mem: {memory:.0f}') 151 | log_msg = self.delimiter.join(log_msg) 152 | MB = 1024.0 * 1024.0 153 | for obj in iterable: 154 | data_time.update(time.time() - end) 155 | yield obj 156 | iter_time.update(time.time() - end) 157 | if i % print_freq == 0 or i == len(iterable) - 1: 158 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 159 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 160 | if torch.cuda.is_available(): 161 | print(log_msg.format( 162 | i, len(iterable), eta=eta_string, 163 | meters=str(self), 164 | time=str(iter_time), data=str(data_time), 165 | memory=torch.cuda.max_memory_allocated() / MB)) 166 | else: 167 | print(log_msg.format( 168 | i, len(iterable), eta=eta_string, 169 | meters=str(self), 170 | time=str(iter_time), data=str(data_time))) 171 | i += 1 172 | end = time.time() 173 | total_time = time.time() - start_time 174 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 175 | print('{} Total time: {} ({:.4f} s / it)'.format( 176 | header, total_time_str, total_time / len(iterable))) 177 | 178 | 179 | def setup_for_distributed(is_master): 180 | """ 181 | This function disables printing when not in master process 182 | """ 183 | builtin_print = builtins.print 184 | 185 | def print(*args, **kwargs): 186 | force = kwargs.pop('force', False) 187 | force = force or (get_world_size() > 8) 188 | if is_master or force: 189 | now = datetime.datetime.now().time() 190 | builtin_print('[{}] '.format(now), end='') # print with time stamp 191 | builtin_print(*args, **kwargs) 192 | 193 | builtins.print = print 194 | 195 | 196 | def is_dist_avail_and_initialized(): 197 | if not dist.is_available(): 198 | return False 199 | if not dist.is_initialized(): 200 | return False 201 | return True 202 | 203 | 204 | def get_world_size(): 205 | if not is_dist_avail_and_initialized(): 206 | return 1 207 | return dist.get_world_size() 208 | 209 | 210 | def get_rank(): 211 | if not is_dist_avail_and_initialized(): 212 | return 0 213 | return dist.get_rank() 214 | 215 | 216 | def is_main_process(): 217 | return get_rank() == 0 218 | 219 | 220 | def save_on_master(*args, **kwargs): 221 | if is_main_process(): 222 | torch.save(*args, **kwargs) 223 | 224 | 225 | def init_distributed_mode(args): 226 | if args.dist_on_itp: 227 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 228 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 229 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 230 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 231 | os.environ['LOCAL_RANK'] = str(args.gpu) 232 | os.environ['RANK'] = str(args.rank) 233 | os.environ['WORLD_SIZE'] = str(args.world_size) 234 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 235 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 236 | args.rank = int(os.environ["RANK"]) 237 | args.world_size = int(os.environ['WORLD_SIZE']) 
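        # LOCAL_RANK (set by the torch.distributed launcher) selects which GPU this
        # process should use on its node.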
238 | args.gpu = int(os.environ['LOCAL_RANK']) 239 | elif 'SLURM_PROCID' in os.environ: 240 | args.rank = int(os.environ['SLURM_PROCID']) 241 | args.gpu = args.rank % torch.cuda.device_count() 242 | else: 243 | print('Not using distributed mode') 244 | setup_for_distributed(is_master=True) # hack 245 | args.distributed = False 246 | return 247 | 248 | args.distributed = True 249 | 250 | torch.cuda.set_device(args.gpu) 251 | args.dist_backend = 'nccl' 252 | print('| distributed init (rank {}): {}, gpu {}'.format( 253 | args.rank, args.dist_url, args.gpu), flush=True) 254 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 255 | world_size=args.world_size, rank=args.rank) 256 | torch.distributed.barrier() 257 | setup_for_distributed(args.rank == 0) 258 | 259 | 260 | class NativeScalerWithGradNormCount: 261 | state_dict_key = "amp_scaler" 262 | 263 | def __init__(self): 264 | self._scaler = torch.cuda.amp.GradScaler() 265 | 266 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 267 | self._scaler.scale(loss).backward(create_graph=create_graph) 268 | if update_grad: 269 | if clip_grad is not None: 270 | assert parameters is not None 271 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 272 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 273 | else: 274 | self._scaler.unscale_(optimizer) 275 | norm = get_grad_norm_(parameters) 276 | self._scaler.step(optimizer) 277 | self._scaler.update() 278 | else: 279 | norm = None 280 | return norm 281 | 282 | def state_dict(self): 283 | return self._scaler.state_dict() 284 | 285 | def load_state_dict(self, state_dict): 286 | self._scaler.load_state_dict(state_dict) 287 | 288 | 289 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 290 | if isinstance(parameters, torch.Tensor): 291 | parameters = [parameters] 292 | parameters = [p for p in parameters if p.grad is not None] 293 | norm_type = float(norm_type) 294 | if len(parameters) == 0: 295 | return torch.tensor(0.) 
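    # Move each per-parameter gradient norm to a common device before combining:
    # with the inf-norm this is the max absolute gradient entry, otherwise the
    # p-norm of the stacked per-parameter norms is returned.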
296 | device = parameters[0].grad.device 297 | if norm_type == inf: 298 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 299 | else: 300 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 301 | return total_norm 302 | 303 | 304 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler): 305 | output_dir = Path(args.output_dir) 306 | epoch_name = str(epoch) 307 | if loss_scaler is not None: 308 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)] 309 | for checkpoint_path in checkpoint_paths: 310 | to_save = { 311 | 'model': model_without_ddp.state_dict(), 312 | 'optimizer': optimizer.state_dict(), 313 | 'epoch': epoch, 314 | 'scaler': loss_scaler.state_dict(), 315 | 'args': args, 316 | } 317 | 318 | save_on_master(to_save, checkpoint_path) 319 | else: 320 | client_state = {'epoch': epoch} 321 | model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state) 322 | 323 | 324 | def load_model(args, model_without_ddp, optimizer, loss_scaler): 325 | if args.resume: 326 | if args.resume.startswith('https'): 327 | checkpoint = torch.hub.load_state_dict_from_url( 328 | args.resume, map_location='cpu', check_hash=True) 329 | else: 330 | checkpoint = torch.load(args.resume, map_location='cpu') 331 | model_without_ddp.load_state_dict(checkpoint['model']) 332 | print("Resume checkpoint %s" % args.resume) 333 | if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval): 334 | optimizer.load_state_dict(checkpoint['optimizer']) 335 | args.start_epoch = checkpoint['epoch'] + 1 336 | if 'scaler' in checkpoint: 337 | loss_scaler.load_state_dict(checkpoint['scaler']) 338 | print("With optim & sched!") 339 | 340 | 341 | def all_reduce_mean(x): 342 | world_size = get_world_size() 343 | if world_size > 1: 344 | x_reduce = torch.tensor(x).cuda() 345 | dist.all_reduce(x_reduce) 346 | x_reduce /= world_size 347 | return x_reduce.item() 348 | else: 349 | return x 350 | 351 | 352 | #################################### 353 | # I added these 354 | 355 | @dataclasses.dataclass 356 | class ParamRow: 357 | name: str 358 | shape: Tuple[int] 359 | size: int 360 | 361 | 362 | @dataclasses.dataclass 363 | class ParamRowWithStats(ParamRow): 364 | mean: float 365 | std: float 366 | 367 | def _default_table_value_formatter(value): 368 | """Formats ints with "," between thousands and floats to 3 digits.""" 369 | if isinstance(value, bool): 370 | return str(value) 371 | elif isinstance(value, int): 372 | return "{:,}".format(value) 373 | elif isinstance(value, float): 374 | return "{:.3}".format(value) 375 | else: 376 | return str(value) 377 | 378 | def make_table( 379 | rows: List[Any], 380 | *, 381 | column_names: Optional[Sequence[str]] = None, 382 | value_formatter: Callable[[Any], str] = _default_table_value_formatter, 383 | max_lines: Optional[int] = None, 384 | ) -> str: 385 | """Renders a list of rows to a table. 386 | 387 | Args: 388 | rows: List of dataclass instances of a single type (e.g. `ParamRow`). 389 | column_names: List of columns that that should be included in the output. If 390 | not provided, then the columns are taken from keys of the first row. 391 | value_formatter: Callable used to format cell values. 392 | max_lines: Don't render a table longer than this. 
393 | 394 | Returns: 395 | A string representation of the table in the form: 396 | 397 | +---------+---------+ 398 | | Col1 | Col2 | 399 | +---------+---------+ 400 | | value11 | value12 | 401 | | value21 | value22 | 402 | +---------+---------+ 403 | """ 404 | 405 | if any(not dataclasses.is_dataclass(row) for row in rows): 406 | raise ValueError("Expected `rows` to be list of dataclasses") 407 | if len(set(map(type, rows))) > 1: 408 | raise ValueError("Expected elements of `rows` be of same type.") 409 | 410 | class Column: 411 | 412 | def __init__(self, name, values): 413 | self.name = name.capitalize() 414 | self.values = values 415 | self.width = max(len(v) for v in values + [name]) 416 | 417 | if column_names is None: 418 | if not rows: 419 | return "(empty table)" 420 | column_names = [field.name for field in dataclasses.fields(rows[0])] 421 | 422 | columns = [ 423 | Column(name, [value_formatter(getattr(row, name)) 424 | for row in rows]) 425 | for name in column_names 426 | ] 427 | 428 | var_line_format = "|" + "".join(f" {{: <{c.width}s}} |" for c in columns) 429 | sep_line_format = var_line_format.replace(" ", "-").replace("|", "+") 430 | header = var_line_format.replace(">", "<").format(*[c.name for c in columns]) 431 | separator = sep_line_format.format(*["" for c in columns]) 432 | 433 | lines = [separator, header, separator] 434 | for i in range(len(rows)): 435 | if max_lines and len(lines) >= max_lines - 3: 436 | lines.append("[...]") 437 | break 438 | lines.append(var_line_format.format(*[c.values[i] for c in columns])) 439 | lines.append(separator) 440 | 441 | return "\n".join(lines) 442 | 443 | def parameter_overview(model: nn.Module): 444 | rows = [] 445 | for name, value in model.named_parameters(): 446 | rows.append(ParamRowWithStats( 447 | name=name, shape=tuple(value.shape), 448 | size=int(np.prod(value.shape)), 449 | mean=float(value.mean()), 450 | std=float(value.std()))) 451 | total_weights = sum([np.prod(v.shape) for v in model.parameters()]) 452 | column_names = [field.name for field in dataclasses.fields(ParamRowWithStats)] 453 | table = make_table(rows, column_names=column_names) 454 | return table + f"\nTotal: {total_weights:,}" 455 | 456 | # TODO: make output path absolute and not assuming an experiments dir 457 | def save_snapshot(args, model, optimizer, global_step, output_fn): 458 | print('saving model.') 459 | os.makedirs(os.path.dirname(output_fn), exist_ok=True) 460 | payload = { 461 | 'model': model.state_dict(), 462 | 'optimizer': optimizer.state_dict(), 463 | 'global_step': global_step, 464 | 'args': args 465 | } 466 | torch.save(payload, output_fn) 467 | print('saved model.') 468 | 469 | 470 | def load_snapshot(model, optimizer, device, name): 471 | print('loading model.') 472 | snapshot_path = name 473 | payload = torch.load(snapshot_path, map_location=device) 474 | model.load_state_dict(payload['model']) 475 | optimizer.load_state_dict(payload['optimizer']) 476 | print('loaded model.') 477 | return payload['args'], payload['global_step'] 478 | 479 | ####################### 480 | 481 | def plot_image(ax, img, label=None): 482 | ax.imshow(img) 483 | ax.axis('off') 484 | ax.set_xticks([]) 485 | ax.set_yticks([]) 486 | if label: 487 | # ax.set_title(label, fontsize=3, y=-21) 488 | ax.set_xlabel(label, fontsize=3) 489 | ax.axis('on') 490 | 491 | 492 | def viz_seg(vid, gt_mask, pr_mask, output_fn, trunk=None, send_to_wandb=False): 493 | """ 494 | Plot the video, gt seg and pred seg masks 495 | 496 | Args: 497 | vid: (L H W C) 498 | gt_mask: (L H W C) 
499 | pred_mask: (L H W C) 500 | output_fn: save path 501 | trunk: truncate temporal dim for viz clarity 502 | """ 503 | if trunk is None: 504 | trunk = len(vid) 505 | T = min(len(vid), trunk) 506 | os.makedirs(os.path.dirname(output_fn), exist_ok=True) 507 | 508 | plt.close() 509 | fig, ax = plt.subplots(T, 3, dpi=400) 510 | 511 | for t in range(T): 512 | gt_seg = label2rgb(gt_mask[t], vid[t]) 513 | pred_seg = label2rgb(pr_mask[t], vid[t]) 514 | 515 | plot_image(ax[t, 0], vid[t], 'original') 516 | plot_image(ax[t, 1], gt_seg, 'gt_seg') 517 | plot_image(ax[t, 2], pred_seg, 'pred_seg') 518 | 519 | plt.savefig(output_fn) 520 | plt.show() 521 | 522 | if send_to_wandb: 523 | wandb.log({ 524 | "eval/seg": 525 | wandb.Image(plt.gcf()) 526 | }) 527 | 528 | 529 | def viz_slots_flow(vid, gt_flow, pr_flow, mask, output_fn, trunk=None, send_to_wandb=False): 530 | """ 531 | Plot the video and slots 532 | 533 | Args: 534 | vid, flow: (L H W C) 535 | mask: (L num_objects H W) 536 | trunk: truncate temporal dim for viz clarity 537 | """ 538 | if trunk is None: 539 | trunk = len(vid) 540 | T = min(len(vid), trunk) 541 | n_objs = mask.shape[1] 542 | os.makedirs(os.path.dirname(output_fn), exist_ok=True) 543 | 544 | slots = vid[:, np.newaxis, :, :, :] * mask[:, :, :, :, np.newaxis] 545 | 546 | plt.close() 547 | fig, ax = plt.subplots(T, n_objs+3, dpi=400) 548 | 549 | for t in range(T): 550 | 551 | plot_image(ax[t, 0], vid[t], 'frame') 552 | plot_image(ax[t, 1], gt_flow[t], 'gt_flow') 553 | plot_image(ax[t, 2], pr_flow[t], 'pred_flow') 554 | 555 | for obj in range(3, n_objs+3): 556 | plot_image(ax[t, obj], slots[t, obj-3], f'slot {obj-2}') 557 | 558 | plt.savefig(output_fn) 559 | plt.show() 560 | 561 | if send_to_wandb: 562 | wandb.log({ 563 | "eval/slots_flow": 564 | wandb.Image(plt.gcf()) 565 | }) 566 | 567 | 568 | def viz_slots_frame_pred(vid, pred_frame, pred_flow, mask, output_fn, trunk=None, send_to_wandb=False): 569 | """ 570 | Plot the video and slots 571 | 572 | Args: 573 | vid, pred_frame: (L H W C) 574 | mask: (L num_objects H W) 575 | trunk: truncate temporal dim for viz clarity 576 | """ 577 | if trunk is None: 578 | trunk = len(vid) 579 | T = min(len(vid), trunk) 580 | n_objs = mask.shape[1] 581 | os.makedirs(os.path.dirname(output_fn), exist_ok=True) 582 | 583 | slots = vid[:, np.newaxis, :, :, :] * mask[:, :, :, :, np.newaxis] 584 | 585 | plt.close() 586 | fig, ax = plt.subplots(T, n_objs+3, dpi=400) 587 | 588 | for t in range(T): 589 | 590 | plot_image(ax[t, 0], vid[t], 'frame') 591 | plot_image(ax[t, 1], pred_frame[t], 'pred_frame') 592 | plot_image(ax[t, 2], pred_flow[t], 'pred_flow') 593 | 594 | for obj in range(3, n_objs+3): 595 | plot_image(ax[t, obj], slots[t, obj-3], f'slot {obj-2}') 596 | 597 | plt.savefig(output_fn) 598 | plt.show() 599 | 600 | if send_to_wandb: 601 | wandb.log({ 602 | "eval/slots_pred_frame": 603 | wandb.Image(plt.gcf()) 604 | }) 605 | 606 | ####################### --------------------------------------------------------------------------------
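# Minimal usage sketch (added for illustration; not part of the original repo) showing how
# the snapshot and logging helpers in savi/trainers/utils/misc.py fit together. The dummy
# model, optimizer, argparse.Namespace, and output path below are assumptions for the
# example; in the trainers these come from the model factory and the parsed training args.
import argparse

import torch
import torch.nn as nn

from savi.trainers.utils import misc

# Hypothetical stand-ins for the real model/args used by the trainers.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
args = argparse.Namespace(exp="example", group="test")

# Tabular parameter summary (name / shape / size / mean / std), as printed in run().
print(misc.parameter_overview(model))

# Save and restore a training snapshot; load_snapshot returns (args, global_step).
snapshot_path = "./experiments/test_example/snapshots/0.pt"
misc.save_snapshot(args, model, optimizer, global_step=0, output_fn=snapshot_path)
loaded_args, global_step = misc.load_snapshot(model, optimizer, torch.device("cpu"), snapshot_path)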