├── Dockerfile ├── LICENSE ├── README.md ├── environment.yml ├── imgs └── slash_icon.png ├── requirements.txt ├── setup.sh └── src ├── README.md ├── SLASH ├── __init__.py ├── mvpp.py └── slash.py ├── __init__.py ├── datasets.py ├── einsum_wrapper.py ├── experiments ├── __init__.py ├── baseline_slot_attention │ ├── dataGen.py │ ├── set_utils.py │ ├── train.py │ └── train.sh ├── mnist_top_k │ ├── __init__.py │ ├── dataGen.py │ ├── data_generation.ipynb │ ├── network_nn.py │ ├── train.py │ ├── train_baseline.sh │ ├── train_same.sh │ └── train_top_k.sh ├── slash_attention │ ├── ap_utils.py │ ├── clevr │ │ ├── __init__.py │ │ ├── auxiliary.py │ │ ├── dataGen.py │ │ ├── slash_attention_clevr.py │ │ └── train.py │ ├── clevr_cogent │ │ ├── __init__.py │ │ ├── auxiliary.py │ │ ├── dataGen.py │ │ ├── slash_attention_clevr.py │ │ └── train.py │ ├── cogent │ │ ├── __init__.py │ │ ├── dataGen.py │ │ ├── slash_attention_cogent.py │ │ └── train.py │ └── shapeworld4 │ │ ├── __init__.py │ │ ├── dataGen.py │ │ ├── slash_attention_shapeworld4.py │ │ └── train.py └── vqa │ ├── __init__.py │ ├── cmd_args2.py │ ├── dataGen.py │ ├── knowledge_graph.py │ ├── models.py │ ├── network_nn.py │ ├── preprocess.py │ ├── query_lib.py │ ├── sg_model.py │ ├── test.py │ ├── test.sh │ ├── train.py │ ├── train.sh │ ├── trainer.py │ ├── vqa_utils.py │ └── word_idx_translator.py ├── slot_attention_module.py └── utils.py /Dockerfile: -------------------------------------------------------------------------------- 1 | # Select the base image 2 | #old 21.06-py3 3 | #FROM nvcr.io/nvidia/pytorch:21.12-py3 4 | FROM nvcr.io/nvidia/pytorch:21.06-py3 5 | 6 | 7 | # Select the working directory 8 | WORKDIR /SLASH 9 | 10 | # Install PyTorch 11 | #RUN pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html 12 | 13 | # Install Python requirements 14 | COPY ./requirements.txt ./requirements.txt 15 | RUN pip install --ignore-installed -r requirements.txt 16 | # RUN conda install -c potassco clingo=5.5.0 17 | # RUN pip install git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git ; pip install rtpt torchsummary ; python -m pip install -U scikit-image 18 | # RUN conda install -c conda-forge tqdm 19 | # RUN conda install -c anaconda seaborn 20 | # RUN conda install -c conda-forge scikit-learn 21 | # RUN conda install -c conda-forge tensorboard 22 | 23 | # Remove fontlist file so MPL can find the serif fonts 24 | RUN rm -rf /.matplotlib/fontlist-v330.json 25 | 26 | # Setup mpl config dir since SciencePlots install .mplstyle files into this dir 27 | RUN mkdir -p /.matplotlib/stylelib 28 | RUN chmod a+rwx -R /.matplotlib 29 | ENV MPLCONFIGDIR=/.matplotlib 30 | 31 | # Add fonts for serif rendering in MPL plots 32 | RUN apt-get update 33 | RUN echo ttf-mscorefonts-installer msttcorefonts/accepted-mscorefonts-eula select true | debconf-set-selections 34 | RUN apt-get install --yes ttf-mscorefonts-installer 35 | RUN ln -snf /usr/share/zoneinfo/Etc/UTC /etc/localtime 36 | RUN apt-get install dvipng cm-super fonts-cmu --yes 37 | RUN apt-get install fonts-dejavu-core --yes 38 | RUN apt install -y tmux 39 | # RUN fc-cache -f -v 40 | 41 | RUN python3 -m pip install setuptools==59.5.0 42 | RUN pip install pandas==1.3.0 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 ml-research@TUDarmstadt 4 | 5 | Permission is 
hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Scalable Neural-Probabilistic Answer Set Programming 2 | 3 | Arseny Skryagin, Wolfgang Stammer, Daniel Ochs, Devendra Singh Dhami , Kristian Kersting 4 | 5 |

6 | 7 |

8 | 9 | [![MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 10 | 11 | # Abstract 12 | The goal of combining the robustness of neural networks and the expressiveness of symbolic 13 | methods has rekindled the interest in Neuro-Symbolic AI. Deep Probabilistic Programming 14 | Languages (DPPLs) have been developed for probabilistic logic programming to be carried 15 | out via the probability estimations of deep neural networks. However, recent SOTA DPPL 16 | approaches allow only for limited conditional probabilistic queries and do not offer the power 17 | of true joint probability estimation. In our work, we propose an easy integration of tractable 18 | probabilistic inference within a DPPL. To this end, we introduce SLASH, a novel DPPL 19 | that consists of Neural-Probabilistic Predicates (NPPs) and a logic program, united via 20 | answer set programming (ASP). NPPs are a novel design principle allowing for combining 21 | all deep model types and combinations thereof to be represented as a single probabilistic 22 | predicate. In this context, we introduce a novel +/− notation for answering various types 23 | of probabilistic queries by adjusting the atom notations of a predicate. To scale well, we 24 | show how to prune the stochastically insignificant parts of the (ground) program, speeding 25 | up reasoning without sacrificing the predictive performance. We evaluate SLASH on a 26 | variety of different tasks, including the benchmark task of MNIST addition and Visual 27 | Question Answering (VQA). 28 | 29 | 30 | 31 | 32 | ## Introduction 33 | This is the repository for SLASH, the deep declarative probabilistic programming language introduced in **Neural-Probabilistic Answer Set Programming** [![KR](https://img.shields.io/badge/Conference-KR2022-blue)](https://kr2022.cs.tu-dortmund.de/index.php) and **Scalable Neural-Probabilistic Answer Set Programming** [![JAIR](https://img.shields.io/badge/Conference-JAIR-blue.svg)](https://jair.org/index.php/jair/article/view/15027). [Link](https://jair.org/index.php/jair/article/view/15027) to the paper. 34 | 35 | 36 | 37 | ## 1. Prerequisites 38 | ### 1.1 Cloning SLASH and submodules 39 | To clone SLASH including all submodules, use: 40 | ``` 41 | git clone --recurse-submodules -j8 https://github.com/ml-research/SLASH 42 | ``` 43 | 44 | ### 1.2 Anaconda 45 | The `environment.yml` file provides all packages needed to run a SLASH program using the Anaconda package manager. To create a new conda environment, install Anaconda ([installation page](https://docs.anaconda.com/anaconda/install/)) and then run the following commands: 46 | ``` 47 | conda env create -n slash -f environment.yml 48 | conda activate slash 49 | pip install git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git 50 | ``` 51 | 52 | The following packages are installed: 53 | - pytorch 54 | - clingo # version 5.5.1 55 | - scikit-image 56 | - scikit-learn 57 | - seaborn 58 | - tqdm 59 | - rtpt 60 | - tensorboard # needed for standalone slot attention 61 | - torchsummary 62 | - GradualWarmupScheduler 63 | 64 | ### 1.3 Virtual environment 65 | To use virtualenv, run the following commands. This creates a new virtual environment and installs all required packages. Tested with Python 3.6. To use CUDA `10.x` with PyTorch, remove the `+cu113` version suffix in the `requirements.txt` file; otherwise CUDA `11.3` will be used. 
66 | ``` 67 | python3 -m venv slash_env 68 | source slash_env/bin/activate 69 | pip install --upgrade pip 70 | python3 -m pip install -r requirements.txt -f https://download.pytorch.org/whl/torch_stable.html 71 | ``` 72 | 73 | ### 1.4 Docker 74 | Alternatively, you can run SLASH using Docker. For this, you first need to build an image from the provided `Dockerfile`. 75 | 76 | To start, build an image named `slash` and then run a container using that image. In the SLASH base folder execute: 77 | ``` 78 | docker build . -t slash:1.0 79 | docker run --runtime=nvidia -e NVIDIA_VISIBLE_DEVICES= --ipc=host -it --rm -v /$(pwd):/slash slash:1.0 80 | ``` 81 | 82 | 83 | ## 2. Project Structure 84 | ``` 85 | --data: contains datasets after downloading 86 | --src: source files, which have their own README. Go there for more details. 87 | ``` 88 | 89 | 90 | ## 3. Getting started 91 | Visit the `src/experiments/mnist_top_k` folder to get familiar with SLASH and run one of the training scripts; a hedged quick-start sketch is also given at the end of this README. 92 | 93 | 94 | ## Citation 95 | 96 | If you use this code for your research, please cite our journal paper at [JAIR](https://jair.org/index.php/jair/article/view/15027) or the conference paper at [KR22](https://proceedings.kr.org/2022/48/): 97 | ``` 98 | @article{skryagin23JAIR, 99 | year = { 2023 }, 100 | crossref = { https://github.com/ml-research/SLASH }, 101 | title = { Scalable Neural-Probabilistic Answer Set Programming }, 102 | pages = { 579--617 }, 103 | volume = { 78 }, 104 | journal = { Journal of Artificial Intelligence Research (JAIR) }, 105 | author = {Skryagin, Arseny and Ochs, Daniel and Dhami, Devendra Singh and Kersting, Kristian}, 106 | } 107 | ``` 108 | ``` 109 | @inproceedings{skryagin2022KR, 110 | title={Neural-Probabilistic Answer Set Programming}, 111 | author={Arseny Skryagin and Wolfgang Stammer and Daniel Ochs and Devendra Singh Dhami and Kristian Kersting}, 112 | booktitle={Proceedings of the 19th International Conference on Principles of Knowledge Representation and Reasoning (KR)}, 113 | year={2022} 114 | } 115 | ``` 116 | 117 | ## Acknowledgments 118 | This work was partly supported by the Federal Ministry of Education and Research (BMBF) and the Hessian Ministry of Science and the Arts (HMWK) within the National Research Center for Applied Cybersecurity ATHENE, the ICT-48 Network of AI Research Excellence Center "TAILOR" (EU Horizon 2020, GA No 952215), and the Collaboration Lab with Nexplore "AI in Construction" (AICO). It also benefited from the BMBF AI lighthouse project, the Hessian research priority programme LOEWE within the project WhiteBox, the HMWK cluster projects "The Third Wave of AI" and "The Adaptive Mind", and the German Research Center for Artificial Intelligence (DFKI) project "SAINT". 
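## Quick-start sketch (illustrative)

The snippet below is a minimal, hypothetical quick-start for the MNIST top-k experiment from Section 3. It assumes the conda environment from Section 1.2 and assumes that `train_top_k.sh` follows the same `DEVICE SEED CREDENTIALS` argument convention used by `src/experiments/baseline_slot_attention/train.sh`; check the script itself for the exact arguments before running.
```
# Illustrative sketch only; verify the script's actual interface first.
conda activate slash                 # environment created in Section 1.2
cd src/experiments/mnist_top_k
# assumed arguments: <GPU id> <seed> <rtpt credentials/initials>
bash train_top_k.sh 0 0 AB
```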
119 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: env_slash 2 | channels: 3 | - anaconda 4 | - pytorch 5 | - nvidia 6 | - potassco 7 | - conda-forge 8 | - defaults 9 | dependencies: 10 | - _libgcc_mutex=0.1=main 11 | - absl-py=0.12.0=py36h06a4308_0 12 | - aiohttp=3.7.4=py36h27cfd23_1 13 | - anyio=3.1.0=py36h5fab9bb_0 14 | - argon2-cffi=20.1.0=py36h1d69622_2 15 | - async-timeout=3.0.1=py36h06a4308_0 16 | - async_generator=1.10=py_0 17 | - attrs=21.2.0=pyhd3eb1b0_0 18 | - babel=2.9.1=pyh44b312d_0 19 | - backports=1.0=py_2 20 | - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0 21 | - blas=1.0=mkl 22 | - bleach=3.3.0=pyh44b312d_0 23 | - blinker=1.4=py36h06a4308_0 24 | - blosc=1.21.0=h8c45485_0 25 | - brotli=1.0.9=he6710b0_2 26 | - brotlipy=0.7.0=py36h27cfd23_1003 27 | - bzip2=1.0.8=h7b6447c_0 28 | - c-ares=1.17.1=h27cfd23_0 29 | - ca-certificates=2021.5.30=ha878542_0 30 | - cachetools=4.2.2=pyhd3eb1b0_0 31 | - certifi=2021.5.30=py36h5fab9bb_0 32 | - cffi=1.14.5=py36h261ae71_0 33 | - chardet=3.0.4=py36h06a4308_1003 34 | - charls=2.1.0=he6710b0_2 35 | - click=8.0.1=pyhd3eb1b0_0 36 | - clingo=5.4.0=py36h3fd9d12_1 37 | - cloudpickle=1.6.0=py_0 38 | - contextvars=2.4=py_0 39 | - coverage=5.5=py36h27cfd23_2 40 | - cryptography=3.4.7=py36hd23ed53_0 41 | - cudatoolkit=11.1.74=h6bb024c_0 42 | - cycler=0.10.0=py36_0 43 | - cython=0.29.23=py36h2531618_0 44 | - cytoolz=0.11.0=py36h7b6447c_0 45 | - dask-core=2021.3.0=pyhd3eb1b0_0 46 | - dataclasses=0.8=pyh4f3eec9_6 47 | - dbus=1.13.18=hb2f20db_0 48 | - decorator=5.0.9=pyhd3eb1b0_0 49 | - defusedxml=0.7.1=pyhd8ed1ab_0 50 | - entrypoints=0.3=pyhd8ed1ab_1003 51 | - expat=2.4.1=h2531618_2 52 | - ffmpeg=4.2.2=h20bf706_0 53 | - fontconfig=2.13.1=h6c09931_0 54 | - freetype=2.10.4=h5ab3b9f_0 55 | - giflib=5.1.4=h14c3975_1 56 | - glib=2.68.2=h36276a3_0 57 | - gmp=6.2.1=h2531618_2 58 | - gnutls=3.6.15=he1e5248_0 59 | - google-auth=1.30.1=pyhd3eb1b0_0 60 | - google-auth-oauthlib=0.4.1=py_2 61 | - grpcio=1.36.1=py36h2157cd5_1 62 | - gst-plugins-base=1.14.0=h8213a91_2 63 | - gstreamer=1.14.0=h28cd5cc_2 64 | - icu=58.2=he6710b0_3 65 | - idna=2.10=pyhd3eb1b0_0 66 | - idna_ssl=1.1.0=py36h06a4308_0 67 | - imagecodecs=2020.5.30=py36hfa7d478_2 68 | - imageio=2.9.0=pyhd3eb1b0_0 69 | - immutables=0.15=py36h27cfd23_0 70 | - importlib-metadata=3.10.0=py36h06a4308_0 71 | - intel-openmp=2021.2.0=h06a4308_610 72 | - ipykernel=5.5.5=py36hcb3619a_0 73 | - ipython=5.8.0=py36_1 74 | - ipython_genutils=0.2.0=py_1 75 | - jinja2=2.11.3=pyh44b312d_0 76 | - joblib=1.0.1=pyhd3eb1b0_0 77 | - jpeg=9b=h024ee3a_2 78 | - json5=0.9.5=pyh9f0ad1d_0 79 | - jsonschema=3.2.0=pyhd8ed1ab_3 80 | - jupyter_client=6.1.12=pyhd8ed1ab_0 81 | - jupyter_core=4.7.1=py36h5fab9bb_0 82 | - jupyter_server=1.8.0=pyhd8ed1ab_0 83 | - jupyterlab=3.0.16=pyhd8ed1ab_0 84 | - jupyterlab_pygments=0.1.2=pyh9f0ad1d_0 85 | - jupyterlab_server=2.6.0=pyhd8ed1ab_0 86 | - jxrlib=1.1=h7b6447c_2 87 | - kiwisolver=1.3.1=py36h2531618_0 88 | - lame=3.100=h7b6447c_0 89 | - lcms2=2.12=h3be6417_0 90 | - ld_impl_linux-64=2.33.1=h53a641e_7 91 | - libaec=1.0.4=he6710b0_1 92 | - libffi=3.3=he6710b0_2 93 | - libgcc-ng=9.1.0=hdf63c60_0 94 | - libgfortran-ng=7.3.0=hdf63c60_0 95 | - libidn2=2.3.1=h27cfd23_0 96 | - libopus=1.3.1=h7b6447c_0 97 | - libpng=1.6.37=hbc83047_0 98 | - libprotobuf=3.14.0=h8c45485_0 99 | - libsodium=1.0.18=h36c2ea0_1 100 | - libstdcxx-ng=9.1.0=hdf63c60_0 101 | - 
libtasn1=4.16.0=h27cfd23_0 102 | - libtiff=4.1.0=h2733197_1 103 | - libunistring=0.9.10=h27cfd23_0 104 | - libuuid=1.0.3=h1bed415_2 105 | - libuv=1.40.0=h7b6447c_0 106 | - libvpx=1.7.0=h439df22_0 107 | - libwebp=1.0.1=h8e7db2f_0 108 | - libxcb=1.14=h7b6447c_0 109 | - libxml2=2.9.10=hb55368b_3 110 | - libzopfli=1.0.3=he6710b0_0 111 | - lz4-c=1.9.3=h2531618_0 112 | - markdown=3.3.4=py36h06a4308_0 113 | - markupsafe=1.1.1=py36he6145b8_2 114 | - matplotlib=3.3.4=py36h06a4308_0 115 | - matplotlib-base=3.3.4=py36h62a2d02_0 116 | - mistune=0.8.4=py36h1d69622_1002 117 | - mkl=2020.2=256 118 | - mkl-service=2.3.0=py36he8ac12f_0 119 | - mkl_fft=1.3.0=py36h54f3939_0 120 | - mkl_random=1.1.1=py36h0573a6f_0 121 | - multidict=5.1.0=py36h27cfd23_2 122 | - nbclassic=0.3.1=pyhd8ed1ab_1 123 | - nbclient=0.5.3=pyhd8ed1ab_0 124 | - nbconvert=6.0.7=py36h5fab9bb_3 125 | - nbformat=5.1.3=pyhd8ed1ab_0 126 | - ncurses=6.2=he6710b0_1 127 | - nest-asyncio=1.5.1=pyhd8ed1ab_0 128 | - nettle=3.7.2=hbbd107a_1 129 | - networkx=2.5=py_0 130 | - ninja=1.10.2=hff7bd54_1 131 | - notebook=6.3.0=py36h5fab9bb_0 132 | - numpy=1.19.2=py36h54aff64_0 133 | - numpy-base=1.19.2=py36hfa32c7d_0 134 | - oauthlib=3.1.0=py_0 135 | - olefile=0.46=py36_0 136 | - openh264=2.1.0=hd408876_0 137 | - openjpeg=2.3.0=h05c96fa_1 138 | - openssl=1.1.1k=h27cfd23_0 139 | - packaging=20.9=pyh44b312d_0 140 | - pandas=1.1.5=py36ha9443f7_0 141 | - pandoc=2.14.0.1=h7f98852_0 142 | - pandocfilters=1.4.2=py_1 143 | - pcre=8.44=he6710b0_0 144 | - pexpect=4.8.0=pyh9f0ad1d_2 145 | - pickleshare=0.7.5=py_1003 146 | - pillow=8.2.0=py36he98fc37_0 147 | - prometheus_client=0.11.0=pyhd8ed1ab_0 148 | - prompt_toolkit=1.0.15=py_1 149 | - protobuf=3.14.0=py36h2531618_1 150 | - ptyprocess=0.7.0=pyhd3deb0d_0 151 | - pyasn1=0.4.8=py_0 152 | - pyasn1-modules=0.2.8=py_0 153 | - pycparser=2.20=py_2 154 | - pygments=2.9.0=pyhd8ed1ab_0 155 | - pyjwt=1.7.1=py36_0 156 | - pyopenssl=20.0.1=pyhd3eb1b0_1 157 | - pyparsing=2.4.7=pyhd3eb1b0_0 158 | - pyqt=5.9.2=py36h05f1152_2 159 | - pyrsistent=0.17.3=py36h1d69622_1 160 | - pysocks=1.7.1=py36h06a4308_0 161 | - python=3.6.13=hdb3f193_0 162 | - python-dateutil=2.8.1=pyhd3eb1b0_0 163 | - python_abi=3.6=1_cp36m 164 | - pytorch=1.8.1=py3.6_cuda11.1_cudnn8.0.5_0 165 | - pytz=2021.1=pyhd3eb1b0_0 166 | - pywavelets=1.1.1=py36h7b6447c_2 167 | - pyyaml=5.4.1=py36h27cfd23_1 168 | - pyzmq=19.0.2=py36h9947dbf_2 169 | - qt=5.9.7=h5867ecd_1 170 | - readline=8.1=h27cfd23_0 171 | - requests=2.25.1=pyhd3eb1b0_0 172 | - requests-oauthlib=1.3.0=py_0 173 | - rsa=4.7.2=pyhd3eb1b0_1 174 | - scikit-image=0.17.2=py36hdf5156a_0 175 | - scikit-learn=0.24.2=py36ha9443f7_0 176 | - scipy=1.5.2=py36h0b6359f_0 177 | - seaborn=0.11.0=py_0 178 | - send2trash=1.5.0=py_0 179 | - setuptools=52.0.0=py36h06a4308_0 180 | - simplegeneric=0.8.1=py_1 181 | - sip=4.19.8=py36hf484d3e_0 182 | - six=1.15.0=py36h06a4308_0 183 | - snappy=1.1.8=he6710b0_0 184 | - sniffio=1.2.0=py36h5fab9bb_1 185 | - sqlite=3.35.4=hdfb4753_0 186 | - tensorboard=2.4.1=pyhd8ed1ab_0 187 | - tensorboard-plugin-wit=1.6.0=py_0 188 | - terminado=0.10.0=py36h5fab9bb_0 189 | - testpath=0.5.0=pyhd8ed1ab_0 190 | - threadpoolctl=2.1.0=pyh5ca1d4c_0 191 | - tifffile=2021.3.17=pyhd3eb1b0_1 192 | - tk=8.6.10=hbc83047_0 193 | - toolz=0.11.1=pyhd3eb1b0_0 194 | - torchaudio=0.8.1=py36 195 | - torchvision=0.9.1=py36_cu111 196 | - tornado=6.1=py36h27cfd23_0 197 | - tqdm=4.61.0=pyhd8ed1ab_0 198 | - traitlets=4.3.3=py36h9f0ad1d_1 199 | - typing-extensions=3.7.4.3=hd3eb1b0_0 200 | - typing_extensions=3.7.4.3=pyh06a4308_0 
201 | - urllib3=1.26.4=pyhd3eb1b0_0 202 | - wcwidth=0.2.5=pyh9f0ad1d_2 203 | - webencodings=0.5.1=py_1 204 | - websocket-client=0.57.0=py36h5fab9bb_4 205 | - werkzeug=1.0.1=pyhd3eb1b0_0 206 | - wheel=0.36.2=pyhd3eb1b0_0 207 | - x264=1!157.20191217=h7b6447c_0 208 | - xz=5.2.5=h7b6447c_0 209 | - yaml=0.2.5=h7b6447c_0 210 | - yarl=1.6.3=py36h27cfd23_0 211 | - zeromq=4.3.4=h2531618_0 212 | - zipp=3.4.1=pyhd3eb1b0_0 213 | - zlib=1.2.11=h7b6447c_3 214 | - zstd=1.4.9=haebb681_0 215 | - pip: 216 | - argparse==1.4.0 217 | - blessings==1.7 218 | - gpustat==0.6.0 219 | - nvidia-ml-py3==7.352.0 220 | - pip==21.1.2 221 | - psutil==5.8.0 222 | - rtpt==0.0.4 223 | - setproctitle==1.2.2 224 | - torchsummary==1.5.1 225 | - pathos==0.2.8 226 | prefix: /home/dochs/anaconda3/envs/env_slash 227 | -------------------------------------------------------------------------------- /imgs/slash_icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/SLASH/79d158dbd670d56d8bf9043640635c24a18a9c57/imgs/slash_icon.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.12.0 2 | blessings==1.7 3 | cachetools==4.2.2 4 | certifi==2021.5.30 5 | cffi==1.14.5 6 | chardet==4.0.0 7 | clingo==5.5.0.post3 8 | cycler==0.10.0 9 | dataclasses==0.6 10 | decorator==4.4.2 11 | distro==1.5.0 12 | google-auth==1.30.2 13 | google-auth-oauthlib==0.4.4 14 | gpustat==0.6.0 15 | grpcio==1.38.0 16 | idna==2.10 17 | imageio==2.9.0 18 | importlib-metadata==4.5.0 19 | install==1.3.4 20 | joblib==1.0.1 21 | kiwisolver==1.3.1 22 | Markdown==3.3.4 23 | matplotlib==3.3.4 24 | networkx==2.5.1 25 | numpy==1.19.5 26 | nvidia-ml-py3==7.352.0 27 | oauthlib==3.1.1 28 | packaging==20.9 29 | pandas==1.1.5 30 | Pillow==8.2.0 31 | protobuf==3.17.3 32 | psutil==5.8.0 33 | pyasn1==0.4.8 34 | pyasn1-modules==0.2.8 35 | pycparser==2.20 36 | pyparsing==2.4.7 37 | python-dateutil==2.8.1 38 | pytz==2021.1 39 | PyWavelets==1.1.1 40 | requests==2.25.1 41 | requests-oauthlib==1.3.0 42 | rsa==4.7.2 43 | rtpt==0.0.4 44 | scikit-image==0.17.2 45 | scikit-learn==0.24.2 46 | scipy==1.5.4 47 | seaborn==0.11.1 48 | setproctitle==1.2.2 49 | six==1.16.0 50 | tensorboard==2.5.0 51 | tensorboard-data-server==0.6.1 52 | tensorboard-plugin-wit==1.8.0 53 | threadpoolctl==2.1.0 54 | tifffile==2020.9.3 55 | #torch==1.8.1+cu111 56 | #torchaudio==0.8.1 57 | torchsummary==1.5.1 58 | #torchvision==0.9.1+cu111 59 | tqdm==4.61.0 60 | typing-extensions==3.10.0.0 61 | urllib3==1.26.5 62 | warmup-scheduler @ git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git@6b5e8953a80aef5b324104dc0c2e9b8c34d622bd 63 | Werkzeug==2.0.1 64 | zipp==3.4.1 65 | pathos==0.2.8 66 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source ~/.bashrc 4 | pyenv uninstall venv_slash 5 | pyenv virtualenv venv_slash 6 | pyenv activate venv_slash 7 | pip install --upgrade pip 8 | pip install -r requirements.txt 9 | -------------------------------------------------------------------------------- /src/README.md: -------------------------------------------------------------------------------- 1 | # Source folder quick overview: 2 | 3 | 1. 
`EinsumNetwork` contains code which was cloned from the [git repository EinNets](https://github.com/cambridge-mlg/EinsumNetworks). 4 | 2. `SLASH` contains all modifications made to enable probabilisitc circuits in the program, VQA and SLASH-attention, optimized gradient computation and the SAME technique. 5 | * `mvpp.py` contains all functions to compute stable models given a logic program and to compute the gradients using the output probabilites 6 | * `slash.py` contains the SLASH class which brings together the symbolic program with neural nets or probabilistic circuits 7 | 3. `experiments` contains the VQA, mnist addtion and slash attention experiments -------------------------------------------------------------------------------- /src/SLASH/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/SLASH/79d158dbd670d56d8bf9043640635c24a18a9c57/src/SLASH/__init__.py -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/SLASH/79d158dbd670d56d8bf9043640635c24a18a9c57/src/__init__.py -------------------------------------------------------------------------------- /src/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | import urllib.request 4 | import shutil 5 | 6 | from zipfile import ZipFile 7 | import gzip 8 | import utils 9 | 10 | def maybe_download(directory, url_base, filename, suffix='.zip'): 11 | ''' 12 | Downloads the specified dataset and extracts it 13 | 14 | @param directory: 15 | @param url_base: URL where to find the file 16 | @param filename: name of the file to be downloaded 17 | @param suffix: suffix of the file 18 | 19 | :returns: true if nothing went wrong downloading 20 | ''' 21 | 22 | filepath = os.path.join(directory, filename) 23 | if os.path.isfile(filepath): 24 | return False 25 | 26 | if not os.path.isdir(directory): 27 | utils.mkdir_p(directory) 28 | 29 | url = url_base +filename 30 | 31 | _, zipped_filepath = tempfile.mkstemp(suffix=suffix) 32 | 33 | print('Downloading {} to {}'.format(url, zipped_filepath)) 34 | 35 | urllib.request.urlretrieve(url, zipped_filepath) 36 | print('{} Bytes'.format(os.path.getsize(zipped_filepath))) 37 | 38 | print('Move to {}'.format(filepath)) 39 | shutil.move(zipped_filepath, filepath) 40 | return True 41 | 42 | 43 | def extract_dataset(directory, filepath, filepath_extracted): 44 | if not os.path.isdir(filepath_extracted): 45 | print('unzip ',filepath, " to", filepath_extracted) 46 | with ZipFile(filepath, 'r') as zipObj: 47 | # Extract all the contents of zip file in current directory 48 | zipObj.extractall(directory) 49 | 50 | 51 | def maybe_download_shapeworld4(): 52 | ''' 53 | Downloads the shapeworld4 dataset if it is not downloaded yet 54 | ''' 55 | 56 | directory = "../../data/" 57 | file_name= "shapeworld4.zip" 58 | maybe_download(directory, "https://hessenbox.tu-darmstadt.de/dl/fiEE3hftM4n1gBGn4HJLKUkU/", file_name) 59 | 60 | filepath = os.path.join(directory, file_name) 61 | filepath_extracted = os.path.join(directory,"shapeworld4") 62 | 63 | extract_dataset(directory, filepath, filepath_extracted) 64 | 65 | 66 | def maybe_download_shapeworld_cogent(): 67 | ''' 68 | Downloads the shapeworld4 cogent dataset if it is not downloaded yet 69 | ''' 70 | 71 | directory = "../../data/" 72 | 
file_name= "shapeworld_cogent.zip" 73 | maybe_download(directory, "https://hessenbox.tu-darmstadt.de/dl/fi3CDjPRsYgAvotHcC8GPaWj/", file_name) 74 | 75 | filepath = os.path.join(directory, file_name) 76 | filepath_extracted = os.path.join(directory,"shapeworld_cogent") 77 | 78 | extract_dataset(directory, filepath, filepath_extracted) 79 | -------------------------------------------------------------------------------- /src/einsum_wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from EinsumNetwork import EinsumNetwork, Graph 5 | 6 | device = torch.device('cuda:0') 7 | 8 | #wrapper class to create an Einsum Network given a specific structure and parameters 9 | class EiNet(EinsumNetwork.EinsumNetwork): 10 | def __init__(self , 11 | use_em, 12 | structure = 'poon-domingos', 13 | pd_num_pieces = [4], 14 | depth = 8, 15 | num_repetitions = 20, 16 | num_var = 784, 17 | class_count = 3, 18 | K = 10, 19 | num_sums = 10, 20 | pd_height = 28, 21 | pd_width = 28, 22 | learn_prior = True 23 | ): 24 | 25 | 26 | # Structure 27 | self.structure = structure 28 | self.class_count = class_count 29 | classes = np.arange(class_count) # [0,1,2,..,n-1] 30 | 31 | # Define the prior, i.e. P(C) and make it learnable. 32 | self.learnable_prior = learn_prior 33 | # P(C) is needed to apply the Bayes' theorem and to retrive 34 | # P(C|X) = P(X|C)*(P(C) / P(X) 35 | if self.class_count == 4: 36 | self.prior = torch.tensor([(1/3)*(2/3), (1/3)*(2/3), (1/3)*(2/3), (1/3)], dtype=torch.float, requires_grad=True, device=device).log() 37 | else: 38 | self.prior = torch.ones(class_count, device=device, dtype=torch.float) 39 | self.prior.fill_(1 / class_count) 40 | self.prior.log_() 41 | if self.learnable_prior: 42 | print("P(C) is learnable.") 43 | self.prior.requires_grad_() 44 | 45 | self.K = K 46 | self.num_sums = num_sums 47 | 48 | # 'poon-domingos' 49 | self.pd_num_pieces = pd_num_pieces # [10, 28],[4],[7] 50 | self.pd_height = pd_height 51 | self.pd_width = pd_width 52 | 53 | 54 | # 'binary-trees' 55 | self.depth = depth 56 | self.num_repetitions = num_repetitions 57 | self.num_var = num_var 58 | 59 | # drop-out rate 60 | # self.drop_out = drop_out 61 | # print("The drop-out rate:", self.drop_out) 62 | 63 | # EM-settings 64 | self.use_em = use_em 65 | online_em_frequency = 1 66 | online_em_stepsize = 0.05 # 0.05 67 | print("train SPN with EM:",self.use_em) 68 | 69 | 70 | # exponential_family = EinsumNetwork.BinomialArray 71 | # exponential_family = EinsumNetwork.CategoricalArray 72 | exponential_family = EinsumNetwork.NormalArray 73 | 74 | exponential_family_args = None 75 | if exponential_family == EinsumNetwork.BinomialArray: 76 | exponential_family_args = {'N': 255} 77 | if exponential_family == EinsumNetwork.CategoricalArray: 78 | exponential_family_args = {'K': 1366120} 79 | if exponential_family == EinsumNetwork.NormalArray: 80 | exponential_family_args = {'min_var': 1e-6, 'max_var': 0.1} 81 | 82 | # Make EinsumNetwork 83 | if self.structure == 'poon-domingos': 84 | pd_delta = [[self.pd_height / d, self.pd_width / d] for d in self.pd_num_pieces] 85 | graph = Graph.poon_domingos_structure(shape=(self.pd_height, self.pd_width), delta=pd_delta) 86 | elif self.structure == 'binary-trees': 87 | graph = Graph.random_binary_trees(num_var=self.num_var, depth=self.depth, num_repetitions=self.num_repetitions) 88 | else: 89 | raise AssertionError("Unknown Structure") 90 | 91 | 92 | args = EinsumNetwork.Args( 93 | 
num_var=self.num_var, 94 | num_dims=1, 95 | num_classes=self.class_count, 96 | num_sums=self.num_sums, 97 | num_input_distributions=self.K, 98 | exponential_family=exponential_family, 99 | exponential_family_args=exponential_family_args, 100 | use_em=self.use_em, 101 | online_em_frequency=online_em_frequency, 102 | online_em_stepsize=online_em_stepsize) 103 | 104 | super().__init__(graph, args) 105 | super().initialize() 106 | 107 | def get_log_likelihoods(self, x): 108 | log_likelihood = super().forward(x) 109 | return log_likelihood 110 | 111 | def forward(self, x, marg_idx=None, type=1): 112 | 113 | # PRIOR 114 | if type == 4: 115 | expanded_prior = self.prior.expand(x.shape[0], self.prior.shape[0]) 116 | return expanded_prior 117 | 118 | else: 119 | # Obtain P(X|C) in log domain 120 | if marg_idx: # If marginalisation mask is passed 121 | self.set_marginalization_idx(marg_idx) 122 | log_likelihood = super().forward(x) 123 | self.set_marginalization_idx(None) 124 | likelihood = torch.nn.functional.softmax(log_likelihood, dim=1) 125 | else: 126 | log_likelihood = super().forward(x) 127 | 128 | #LIKELIHOOD 129 | if type == 2: 130 | likelihood = torch.nn.functional.softmax(log_likelihood, dim=1) 131 | # Sanity check for NaN-values 132 | if torch.isnan(log_likelihood).sum() > 0: 133 | print("likelihood nan") 134 | 135 | return likelihood 136 | else: 137 | # Apply Bayes' Theorem to obtain P(C|X) instead of P(X|C) 138 | # as it is provided by the EiNet 139 | # 1. Computation of the prior, i.e. P(C), is already being 140 | # dealt with at the initialisation of the EiNet. 141 | # 2. Compute the normalization constant P(X) 142 | z = torch.logsumexp(log_likelihood + self.prior, dim=1) 143 | # 3. Compute the posterior, i.e. P(C|X) = (P(X|C) * P(C)) / P(X) 144 | posterior_log = (log_likelihood + self.prior - z[:, None]) # log domain 145 | #posterior = posterior_log.exp() # decimal domain 146 | 147 | 148 | #POSTERIOR 149 | if type == 1: 150 | posterior = torch.nn.functional.softmax(posterior_log, dim=1) 151 | 152 | # Sanity check for NaN-values 153 | if torch.isnan(z).sum() > 0: 154 | print("z nan") 155 | if torch.isnan(posterior).sum() > 0: 156 | print("posterior nan") 157 | return posterior 158 | 159 | #JOINT 160 | elif type == 3: 161 | #compute the joint P(X|C) * P(C) 162 | joint = torch.nn.functional.softmax(log_likelihood + self.prior, dim=1) 163 | 164 | return joint 165 | -------------------------------------------------------------------------------- /src/experiments/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/SLASH/79d158dbd670d56d8bf9043640635c24a18a9c57/src/experiments/__init__.py -------------------------------------------------------------------------------- /src/experiments/baseline_slot_attention/dataGen.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | 4 | from torch.utils.data import Dataset 5 | from torchvision.transforms import transforms 6 | 7 | from torch.utils.data import Dataset 8 | from torchvision import transforms 9 | from skimage import io 10 | import os 11 | import numpy as np 12 | import random 13 | import torch 14 | import matplotlib.pyplot as plt 15 | from PIL import Image, ImageFile 16 | ImageFile.LOAD_TRUNCATED_IMAGES = True 17 | import json 18 | import datasets as datasets 19 | 20 | 21 | def get_loader(dataset, batch_size, num_workers=8, shuffle=True): 22 | ''' 23 | Returns and iterable dataset 
with specified batchsize and shuffling. 24 | ''' 25 | return torch.utils.data.DataLoader( 26 | dataset, 27 | shuffle=shuffle, 28 | batch_size=batch_size, 29 | num_workers=num_workers 30 | ) 31 | 32 | 33 | def get_encoding_shapeworld(color, shape, shade, size): 34 | 35 | if color == 'red': 36 | col_enc = [1,0,0,0,0,0,0,0] 37 | elif color == 'blue': 38 | col_enc = [0,1,0,0,0,0,0,0] 39 | elif color == 'green': 40 | col_enc = [0,0,1,0,0,0,0,0] 41 | elif color == 'gray': 42 | col_enc = [0,0,0,1,0,0,0,0] 43 | elif color == 'brown': 44 | col_enc = [0,0,0,0,1,0,0,0] 45 | elif color == 'magenta': 46 | col_enc = [0,0,0,0,0,1,0,0] 47 | elif color == 'cyan': 48 | col_enc = [0,0,0,0,0,0,1,0] 49 | elif color == 'yellow': 50 | col_enc = [0,0,0,0,0,0,0,1] 51 | 52 | if shape == 'circle': 53 | shape_enc = [1,0,0] 54 | elif shape == 'triangle': 55 | shape_enc = [0,1,0] 56 | elif shape == 'square': 57 | shape_enc = [0,0,1] 58 | 59 | if shade == 'bright': 60 | shade_enc = [1,0] 61 | elif shade =='dark': 62 | shade_enc = [0,1] 63 | 64 | 65 | if size == 'small': 66 | size_enc = [1,0] 67 | elif size == 'big': 68 | size_enc = [0,1] 69 | 70 | return np.array([1]+ col_enc + shape_enc + shade_enc + size_enc) 71 | 72 | 73 | class SHAPEWORLD4(Dataset): 74 | def __init__(self, root, mode, learn_concept='default', bg_encoded=True): 75 | 76 | datasets.maybe_download_shapeworld4() 77 | 78 | self.root = root 79 | self.mode = mode 80 | assert os.path.exists(root), 'Path {} does not exist'.format(root) 81 | 82 | #dictionary of the form {'image_idx':'img_path'} 83 | self.img_paths = {} 84 | 85 | 86 | for file in os.scandir(os.path.join(root, 'images', mode)): 87 | img_path = file.path 88 | 89 | img_path_idx = img_path.split("/") 90 | img_path_idx = img_path_idx[-1] 91 | img_path_idx = img_path_idx[:-4][6:] 92 | try: 93 | img_path_idx = int(img_path_idx) 94 | self.img_paths[img_path_idx] = img_path 95 | except: 96 | print("path:",img_path_idx) 97 | 98 | 99 | 100 | count = 0 101 | 102 | #target maps of the form {'target:idx': observation string} or {'target:idx': obj encoding} 103 | self.obj_map = {} 104 | 105 | with open(os.path.join(root, 'labels', mode,"world_model.json")) as f: 106 | worlds = json.load(f) 107 | 108 | 109 | 110 | #iterate over all objects 111 | for world in worlds: 112 | num_objects = 0 113 | target_obs = "" 114 | obj_enc = [] 115 | for entity in world['entities']: 116 | 117 | color = entity['color']['name'] 118 | shape = entity['shape']['name'] 119 | 120 | shade_val = entity['color']['shade'] 121 | if shade_val == 0.0: 122 | shade = 'bright' 123 | else: 124 | shade = 'dark' 125 | 126 | size_val = entity['shape']['size']['x'] 127 | if size_val == 0.075: 128 | size = 'small' 129 | elif size_val == 0.15: 130 | size = 'big' 131 | 132 | name = 'o' + str(num_objects+1) 133 | obj_enc.append(get_encoding_shapeworld(color, shape, shade, size)) 134 | num_objects += 1 135 | 136 | #bg encodings 137 | for i in range(num_objects, 4): 138 | name = 'o' + str(num_objects+1) 139 | obj_enc.append(np.array([0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0])) 140 | num_objects += 1 141 | 142 | self.obj_map[count] = torch.Tensor(obj_enc) 143 | count+=1 144 | 145 | def __getitem__(self, index): 146 | 147 | #get the image 148 | img_path = self.img_paths[index] 149 | img = io.imread(img_path)[:, :, :3] 150 | 151 | transform = transforms.Compose([ 152 | transforms.ToPILImage(), 153 | #transforms.CenterCrop(250), 154 | #transforms.Resize((32, 32)), 155 | transforms.ToTensor(), 156 | ]) 157 | img = transform(img) 158 | img = (img - 0.5) * 2.0 # Rescale 
to [-1, 1]. 159 | 160 | return img, self.obj_map[index]#, mask 161 | 162 | def __len__(self): 163 | return len(self.img_paths) 164 | 165 | def get_encoding_clevr(size, material, shape, color ): 166 | 167 | #size (small, large, bg) 168 | if size == "small": 169 | size_enc = [1,0] 170 | elif size == "large": 171 | size_enc = [0,1] 172 | 173 | #material (rubber, metal, bg) 174 | if material == "rubber": 175 | material_enc = [1,0] 176 | elif material == "metal": 177 | material_enc = [0,1] 178 | 179 | #shape (cube, sphere, cylinder, bg) 180 | if shape == "cube": 181 | shape_enc = [1,0,0] 182 | elif shape == "sphere": 183 | shape_enc = [0,1,0] 184 | elif shape == "cylinder": 185 | shape_enc = [0,0,1] 186 | 187 | 188 | #color (gray, red, blue, green, brown, purple, cyan, yellow, bg) 189 | if color == "gray": 190 | color_enc = [1,0,0,0,0,0,0,0] 191 | elif color == "red": 192 | color_enc = [0,1,0,0,0,0,0,0] 193 | elif color == "blue": 194 | color_enc = [0,0,1,0,0,0,0,0] 195 | elif color == "green": 196 | color_enc = [0,0,0,1,0,0,0,0] 197 | elif color == "brown": 198 | color_enc = [0,0,0,0,1,0,0,0] 199 | elif color == "purple": 200 | color_enc = [0,0,0,0,0,1,0,0] 201 | elif color == "cyan": 202 | color_enc = [0,0,0,0,0,0,1,0] 203 | elif color == "yellow": 204 | color_enc = [0,0,0,0,0,0,0,1] 205 | 206 | 207 | return np.array([1] + size_enc + material_enc + shape_enc + color_enc ) 208 | 209 | 210 | class CLEVR(Dataset): 211 | def __init__(self, root, mode, img_paths=None, files_names=None, obj_num=None): 212 | self.root = root # The root folder of the dataset 213 | self.mode = mode # The mode of 'train' or 'val' 214 | self.files_names = files_names # The list of the files names with correct nuber of objects 215 | if obj_num is not None: 216 | self.obj_num = obj_num # The upper limit of number of objects 217 | else: 218 | self.obj_num = 10 219 | 220 | assert os.path.exists(root), 'Path {} does not exist'.format(root) 221 | 222 | #list of sorted image paths 223 | self.img_paths = [] 224 | if img_paths: 225 | self.img_paths = img_paths 226 | else: 227 | #open directory and save all image paths 228 | for file in os.scandir(os.path.join(root, 'images', mode)): 229 | img_path = file.path 230 | if '.png' in img_path: 231 | self.img_paths.append(img_path) 232 | 233 | self.img_paths.sort() 234 | self.img_paths = np.array(self.img_paths, dtype=str) 235 | 236 | count = 0 237 | 238 | #target maps of the form {'target:idx': query string} or {'target:idx': obj encoding} 239 | #self.obj_map = {} 240 | 241 | 242 | count = 0 243 | #We have up to 10 objects in the image, load the json file 244 | with open(os.path.join(root, 'scenes','CLEVR_'+ mode+"_scenes.json")) as f: 245 | data = json.load(f) 246 | 247 | self.obj_map = np.empty((len(data['scenes']),10,16), dtype=np.float32) 248 | 249 | #iterate over each scene and create the query string and obj encoding 250 | print("parsing scences") 251 | for scene in data['scenes']: 252 | obj_encoding_list = [] 253 | 254 | if self.files_names: 255 | if any(scene['image_filename'] in file_name for file_name in files_names): 256 | num_objects = 0 257 | for idx, obj in enumerate(scene['objects']): 258 | obj_encoding_list.append(get_encoding_clevr(obj['size'], obj['material'], obj['shape'], obj['color'])) 259 | num_objects = idx+1 #store the num of objects 260 | #fill in background objects 261 | for idx in range(num_objects, self.obj_num): 262 | obj_encoding_list.append([0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]) 263 | self.obj_map[count] = np.array(obj_encoding_list) 264 | count += 1 265 | 
else: 266 | num_objects=0 267 | for idx, obj in enumerate(scene['objects']): 268 | obj_encoding_list.append(get_encoding_clevr(obj['size'], obj['material'], obj['shape'], obj['color'])) 269 | num_objects = idx+1 #store the num of objects 270 | #fill in background objects 271 | for idx in range(num_objects, 10): 272 | obj_encoding_list.append([0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]) 273 | self.obj_map[scene['image_index']] = np.array(obj_encoding_list, dtype=np.float32) 274 | 275 | print("done") 276 | if self.files_names: 277 | print(f'Correctly found images {count} out of {len(files_names)}') 278 | 279 | def __getitem__(self, index): 280 | 281 | #get the image 282 | img_path = self.img_paths[index] 283 | img = io.imread(img_path)[:, :, :3] 284 | img = Image.fromarray(img).resize((128,128)) #using transforms to resize gets us a shrared-memory leak :( 285 | 286 | transform = transforms.Compose([ 287 | #transforms.ToPILImage(), 288 | #transforms.CenterCrop((29, 221,64, 256)), #why do we need to crop? 289 | #transforms.Resize((128, 128)), 290 | transforms.ToTensor(), 291 | ]) 292 | img = transform(img) 293 | img = (img - 0.5) * 2.0 # Rescale to [-1, 1]. 294 | 295 | return img, self.obj_map[index]#, mask 296 | 297 | def __len__(self): 298 | return self.img_paths.shape[0] -------------------------------------------------------------------------------- /src/experiments/baseline_slot_attention/set_utils.py: -------------------------------------------------------------------------------- 1 | import scipy.optimize 2 | import torch 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import os 6 | 7 | 8 | def save_args(args, writer): 9 | # store args as txt file 10 | with open(os.path.join(writer.log_dir, 'args.txt'), 'w') as f: 11 | for arg in vars(args): 12 | f.write(f"\n{arg}: {getattr(args, arg)}") 13 | 14 | 15 | def hungarian_matching(attrs, preds_attrs, verbose=0): 16 | """ 17 | Receives unordered predicted set and orders this to match the nearest GT set. 18 | :param attrs: 19 | :param preds_attrs: 20 | :param verbose: 21 | :return: 22 | """ 23 | assert attrs.shape[1] == preds_attrs.shape[1] 24 | assert attrs.shape == preds_attrs.shape 25 | from scipy.optimize import linear_sum_assignment 26 | matched_preds_attrs = preds_attrs.clone() 27 | for sample_id in range(attrs.shape[0]): 28 | # using euclidean distance 29 | cost_matrix = torch.cdist(attrs[sample_id], preds_attrs[sample_id]).detach().cpu() 30 | 31 | idx_mapping = linear_sum_assignment(cost_matrix) 32 | # convert to tuples of [(row_id, col_id)] of the cost matrix 33 | idx_mapping = [(idx_mapping[0][i], idx_mapping[1][i]) for i in range(len(idx_mapping[0]))] 34 | 35 | for i, (row_id, col_id) in enumerate(idx_mapping): 36 | matched_preds_attrs[sample_id, row_id, :] = preds_attrs[sample_id, col_id, :] 37 | if verbose: 38 | print('GT: {}'.format(attrs[sample_id])) 39 | print('Pred: {}'.format(preds_attrs[sample_id])) 40 | print('Cost Matrix: {}'.format(cost_matrix)) 41 | print('idx mapping: {}'.format(idx_mapping)) 42 | print('Matched Pred: {}'.format(matched_preds_attrs[sample_id])) 43 | print('\n') 44 | # exit() 45 | 46 | return matched_preds_attrs 47 | 48 | 49 | def average_precision_shapeworld(pred, attributes, distance_threshold, dataset): 50 | """Computes the average precision for CLEVR. 51 | This function computes the average precision of the predictions specifically 52 | for the CLEVR dataset. First, we sort the predictions of the model by 53 | confidence (highest confidence first). 
Then, for each prediction we check 54 | whether there was a corresponding object in the input image. A prediction is 55 | considered a true positive if the discrete features are predicted correctly 56 | and the predicted position is within a certain distance from the ground truth 57 | object. 58 | Args: 59 | pred: Tensor of shape [batch_size, num_elements, dimension] containing 60 | predictions. The last dimension is expected to be the confidence of the 61 | prediction. 62 | attributes: Tensor of shape [batch_size, num_elements, dimension] containing 63 | predictions. 64 | distance_threshold: Threshold to accept match. -1 indicates no threshold. 65 | Returns: 66 | Average precision of the predictions. 67 | """ 68 | 69 | [batch_size, _, element_size] = attributes.shape 70 | [_, predicted_elements, _] = pred.shape 71 | 72 | def unsorted_id_to_image(detection_id, predicted_elements): 73 | """Find the index of the image from the unsorted detection index.""" 74 | return int(detection_id // predicted_elements) 75 | 76 | flat_size = batch_size * predicted_elements 77 | flat_pred = np.reshape(pred, [flat_size, element_size]) 78 | # sort_idx = np.argsort(flat_pred[:, -1], axis=0)[::-1] # Reverse order. 79 | sort_idx = np.argsort(flat_pred[:, 0], axis=0)[::-1] # Reverse order. 80 | 81 | sorted_predictions = np.take_along_axis( 82 | flat_pred, np.expand_dims(sort_idx, axis=1), axis=0) 83 | idx_sorted_to_unsorted = np.take_along_axis( 84 | np.arange(flat_size), sort_idx, axis=0) 85 | 86 | 87 | def process_targets_shapeworld4(target): 88 | """Unpacks the target into the CLEVR properties.""" 89 | #col_enc + shape_enc + shade_enc + size_enc 90 | real_obj = target[0] 91 | color = np.argmax(target[1:9]) 92 | shape = np.argmax(target[9:12]) 93 | shade = np.argmax(target[12:14]) 94 | size = np.argmax(target[14:16]) 95 | return np.array([0,0,0]), size, shade, shape, color, real_obj 96 | 97 | def process_targets_clevr(target): 98 | """Unpacks the target into the CLEVR properties.""" 99 | #col_enc + shape_enc + shade_enc + size_enc 100 | 101 | 102 | real_obj = target[0] 103 | size = np.argmax(target[1:3]) 104 | material = np.argmax(target[3:5]) 105 | shape = np.argmax(target[5:8]) 106 | color = np.argmax(target[8:16]) 107 | 108 | return np.array([0,0,0]), size, material, shape, color, real_obj 109 | 110 | 111 | def process_targets(target): 112 | if dataset == "shapeworld4": 113 | return process_targets_shapeworld4(target) 114 | elif dataset == "clevr": 115 | return process_targets_clevr(target) 116 | 117 | 118 | true_positives = np.zeros(sorted_predictions.shape[0]) 119 | false_positives = np.zeros(sorted_predictions.shape[0]) 120 | 121 | detection_set = set() 122 | 123 | for detection_id in range(sorted_predictions.shape[0]): 124 | # Extract the current prediction. 125 | current_pred = sorted_predictions[detection_id, :] 126 | # Find which image the prediction belongs to. Get the unsorted index from 127 | # the sorted one and then apply to unsorted_id_to_image function that undoes 128 | # the reshape. 129 | original_image_idx = unsorted_id_to_image( 130 | idx_sorted_to_unsorted[detection_id], predicted_elements) 131 | # Get the ground truth image. 132 | gt_image = attributes[original_image_idx, :, :] 133 | 134 | # Initialize the maximum distance and the id of the groud-truth object that 135 | # was found. 136 | best_distance = 10000 137 | best_id = None 138 | 139 | # Unpack the prediction by taking the argmax on the discrete 140 | # attributes. 
141 | (pred_coords, pred_object_size, pred_material, pred_shape, pred_color, 142 | _) = process_targets(current_pred) 143 | 144 | # Loop through all objects in the ground-truth image to check for hits. 145 | for target_object_id in range(gt_image.shape[0]): 146 | target_object = gt_image[target_object_id, :] 147 | # Unpack the targets taking the argmax on the discrete attributes. 148 | (target_coords, target_object_size, target_material, target_shape, 149 | target_color, target_real_obj) = process_targets(target_object) 150 | # Only consider real objects as matches. 151 | if target_real_obj: 152 | # For the match to be valid all attributes need to be correctly 153 | # predicted. 154 | pred_attr = [ 155 | pred_object_size, 156 | pred_material, 157 | pred_shape, 158 | pred_color] 159 | target_attr = [ 160 | target_object_size, 161 | target_material, 162 | target_shape, 163 | target_color] 164 | match = pred_attr == target_attr 165 | if match: 166 | # If a match was found, we check if the distance is below the 167 | # specified threshold. Recall that we have rescaled the coordinates 168 | # in the dataset from [-3, 3] to [0, 1], both for `target_coords` and 169 | # `pred_coords`. To compare in the original scale, we thus need to 170 | # multiply the distance values by 6 before applying the 171 | # norm. 172 | distance = np.linalg.norm( 173 | (target_coords - pred_coords) * 6.) 174 | 175 | # If this is the best match we've found so far we remember 176 | # it. 177 | if distance < best_distance: 178 | best_distance = distance 179 | best_id = target_object_id 180 | if best_distance < distance_threshold or distance_threshold == -1: 181 | # We have detected an object correctly within the distance confidence. 182 | # If this object was not detected before it's a true positive. 183 | if best_id is not None: 184 | if (original_image_idx, best_id) not in detection_set: 185 | true_positives[detection_id] = 1 186 | detection_set.add((original_image_idx, best_id)) 187 | else: 188 | false_positives[detection_id] = 1 189 | else: 190 | false_positives[detection_id] = 1 191 | else: 192 | false_positives[detection_id] = 1 193 | 194 | accumulated_fp = np.cumsum(false_positives) 195 | accumulated_tp = np.cumsum(true_positives) 196 | 197 | recall_array = accumulated_tp / np.sum(attributes[:, :, 0]) 198 | precision_array = np.divide( 199 | accumulated_tp, 200 | (accumulated_fp + accumulated_tp)) 201 | 202 | return compute_average_precision( 203 | np.array(precision_array, dtype=np.float32), 204 | np.array(recall_array, dtype=np.float32)) 205 | 206 | 207 | def compute_average_precision(precision, recall): 208 | """Computation of the average precision from precision and recall arrays.""" 209 | recall = recall.tolist() 210 | precision = precision.tolist() 211 | recall = [0] + recall + [1] 212 | precision = [0] + precision + [0] 213 | 214 | for i in range(len(precision) - 1, -0, -1): 215 | precision[i - 1] = max(precision[i - 1], precision[i]) 216 | 217 | indices_recall = [ 218 | i for i in range(len(recall) - 1) if recall[1:][i] != recall[:-1][i] 219 | ] 220 | 221 | average_precision = 0. 
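    # The loop below integrates the interpolated precision-recall curve:
    # at each index where recall changes, the precision at the following point
    # is weighted by the recall increment and accumulated into the average precision.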
222 | for i in indices_recall: 223 | average_precision += precision[i + 1] * (recall[i + 1] - recall[i]) 224 | return average_precision 225 | 226 | 227 | def hungarian_loss(predictions, targets): 228 | # permute dimensions for pairwise distance computation between all slots 229 | predictions = predictions.permute(0, 2, 1) 230 | targets = targets.permute(0, 2, 1) 231 | 232 | # predictions and targets shape :: (n, c, s) 233 | predictions, targets = outer(predictions, targets) 234 | # squared_error shape :: (n, s, s) 235 | squared_error = F.smooth_l1_loss(predictions, targets.expand_as(predictions), reduction="none").mean(1) 236 | 237 | squared_error_np = squared_error.detach().cpu().numpy() 238 | indices = map(hungarian_loss_per_sample, squared_error_np) 239 | losses = [ 240 | sample[row_idx, col_idx].mean() 241 | for sample, (row_idx, col_idx) in zip(squared_error, indices) 242 | ] 243 | total_loss = torch.mean(torch.stack(list(losses))) 244 | return total_loss 245 | 246 | 247 | 248 | def hungarian_loss_per_sample(sample_np): 249 | return scipy.optimize.linear_sum_assignment(sample_np) 250 | 251 | 252 | def scatter_masked(tensor, mask, binned=False, threshold=None): 253 | s = tensor[0].detach().cpu() 254 | mask = mask[0].detach().clamp(min=0, max=1).cpu() 255 | if binned: 256 | s = s * 128 257 | s = s.view(-1, s.size(-1)) 258 | mask = mask.view(-1) 259 | if threshold is not None: 260 | keep = mask.view(-1) > threshold 261 | s = s[:, keep] 262 | mask = mask[keep] 263 | return s, mask 264 | 265 | 266 | def outer(a, b=None): 267 | """ Compute outer product between a and b (or a and a if b is not specified). """ 268 | if b is None: 269 | b = a 270 | size_a = tuple(a.size()) + (b.size()[-1],) 271 | size_b = tuple(b.size()) + (a.size()[-1],) 272 | a = a.unsqueeze(dim=-1).expand(*size_a) 273 | b = b.unsqueeze(dim=-2).expand(*size_b) 274 | return a, b -------------------------------------------------------------------------------- /src/experiments/baseline_slot_attention/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from datetime import datetime 4 | import time 5 | import sys 6 | sys.path.append('../../') 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | import matplotlib 11 | matplotlib.use("Agg") 12 | 13 | import torchvision.transforms as transforms 14 | import torchvision.datasets as datasets 15 | import torch.multiprocessing as mp 16 | 17 | import scipy.optimize 18 | import numpy as np 19 | #from tqdm import tqdm 20 | from rtpt import RTPT 21 | 22 | import matplotlib.pyplot as plt 23 | from torch.utils.tensorboard import SummaryWriter 24 | 25 | from dataGen import SHAPEWORLD4,CLEVR, get_loader 26 | import slot_attention_module as model 27 | import set_utils as set_utils 28 | 29 | import utils as misc_utils 30 | 31 | 32 | def get_args(): 33 | parser = argparse.ArgumentParser() 34 | # generic params 35 | parser.add_argument( 36 | "--name", 37 | default=datetime.now().strftime("%Y-%m-%d_%H:%M:%S"), 38 | help="Name to store the log file as", 39 | ) 40 | parser.add_argument("--resume", help="Path to log file to resume from") 41 | 42 | parser.add_argument( 43 | "--seed", type=int, default=10, help="Random generator seed for all frameworks" 44 | ) 45 | parser.add_argument( 46 | "--epochs", type=int, default=10, help="Number of epochs to train with" 47 | ) 48 | parser.add_argument( 49 | "--ap-log", type=int, default=10, help="Number of epochs before logging AP" 50 | ) 51 | parser.add_argument( 52 | "--lr", 
type=float, default=1e-2, help="Outer learning rate of model" 53 | ) 54 | parser.add_argument( 55 | "--warmup-epochs", type=int, default=10, help="Number of steps fpr learning rate warm up" 56 | ) 57 | parser.add_argument( 58 | "--decay-epochs", type=int, default=10, help="Number of steps fpr learning rate decay" 59 | ) 60 | parser.add_argument( 61 | "--batch-size", type=int, default=32, help="Batch size to train with" 62 | ) 63 | parser.add_argument( 64 | "--num-workers", type=int, default=6, help="Number of threads for data loader" 65 | ) 66 | parser.add_argument( 67 | "--dataset", 68 | choices=["shapeworld4", "clevr"], 69 | help="Use shapeworld4 dataset", 70 | ) 71 | parser.add_argument( 72 | "--cogent", action='store_true', 73 | help="Evaluate on the CoGenT test of the dataset", 74 | ) 75 | parser.add_argument( 76 | "--no-cuda", 77 | action="store_true", 78 | help="Run on CPU instead of GPU (not recommended)", 79 | ) 80 | parser.add_argument( 81 | "--train-only", action="store_true", help="Only run training, no evaluation" 82 | ) 83 | parser.add_argument( 84 | "--eval-only", action="store_true", help="Only run evaluation, no training" 85 | ) 86 | parser.add_argument("--multi-gpu", action="store_true", help="Use multiple GPUs") 87 | parser.add_argument("--credentials", type=str, help="Credentials for rtpt") 88 | parser.add_argument("--export-dir", type=str, help="Directory to output samples to") 89 | parser.add_argument("--data-dir", type=str, help="Directory to data") 90 | 91 | # Slot attention params 92 | parser.add_argument('--n-slots', default=10, type=int, 93 | help='number of slots for slot attention module') 94 | parser.add_argument('--n-iters-slot-att', default=3, type=int, 95 | help='number of iterations in slot attention module') 96 | parser.add_argument('--n-attr', default=18, type=int, 97 | help='number of attributes per object') 98 | 99 | args = parser.parse_args() 100 | 101 | if args.no_cuda: 102 | args.device = 'cpu' 103 | else: 104 | args.device = 'cuda:0' 105 | 106 | misc_utils.set_manual_seed(args.seed) 107 | 108 | args.name += f'-{args.seed}' 109 | 110 | return args 111 | 112 | 113 | def run(net, loader, optimizer, criterion, writer, args, test_cond = None, train=False, epoch=0): 114 | if train: 115 | net.train() 116 | prefix = "train" 117 | torch.set_grad_enabled(True) 118 | else: 119 | net.eval() 120 | prefix = "test" 121 | torch.set_grad_enabled(False) 122 | 123 | preds_all = torch.zeros(0, args.n_slots, args.n_attr) 124 | target_all = torch.zeros(0, args.n_slots, args.n_attr) 125 | 126 | iters_per_epoch = len(loader) 127 | 128 | 129 | for i, sample in enumerate(loader, start=epoch * iters_per_epoch): 130 | 131 | start = time.time() 132 | 133 | # input is either a set or an image 134 | if 'cuda' in args.device: 135 | imgs, target_set = map(lambda x: x.cuda(), sample) 136 | else: 137 | imgs, target_set = sample 138 | 139 | load = time.time() 140 | #print("\nload", load-start) 141 | 142 | output = net.forward(imgs) 143 | 144 | loss = set_utils.hungarian_loss(output, target_set) 145 | 146 | forward = time.time() 147 | #print("forward", forward-load) 148 | 149 | if train: 150 | 151 | # apply lr schedulers 152 | if epoch < args.warmup_epochs: 153 | lr = args.lr * ((epoch + 1) / args.warmup_epochs) 154 | else: 155 | lr = args.lr 156 | lr = lr * 0.5 ** ((epoch + 1) / args.decay_epochs) 157 | optimizer.param_groups[0]['lr'] = lr 158 | 159 | optimizer.zero_grad() 160 | loss.backward(retain_graph=True) 161 | optimizer.step() 162 | backward = time.time() 163 | 
#print("backward", backward-forward) 164 | 165 | 166 | writer.add_scalar("train/loss_baseline", loss.item(), global_step=i) 167 | log = time.time() 168 | #print("log", log-backward) 169 | 170 | #writer.add_scalar("lr/", optimizer.param_groups[0]["lr"], global_step=i) 171 | else: 172 | if i % iters_per_epoch == 0: 173 | 174 | preds_all = torch.cat((preds_all, output.detach().cpu()), 0) 175 | target_all = torch.cat((target_all, target_set.detach().cpu()), 0) 176 | 177 | ap = [ 178 | set_utils.average_precision_shapeworld(preds_all.detach().cpu().numpy(), 179 | target_all.detach().cpu().numpy(), d, args.dataset) 180 | for d in [-1] # since we are not using 3D coords #[-1., 1., 0.5, 0.25, 0.125] 181 | ] 182 | 183 | print(f"\nCurrent AP: ", ap[0], " %\n") 184 | if test_cond == "a": 185 | writer.add_scalar("test/ap_cond_a", ap[0], global_step=i) 186 | elif test_cond == "b": 187 | writer.add_scalar("test/ap_cond_b", ap[0], global_step=i) 188 | else: 189 | writer.add_scalar("test/ap", ap[0], global_step=i) 190 | return ap 191 | 192 | if train: 193 | print(f"Epoch {epoch} Train Loss: {loss.item()}") 194 | 195 | 196 | def main(): 197 | args = get_args() 198 | 199 | writer = SummaryWriter(os.path.join("runs", args.name), purge_step=0) 200 | 201 | if args.cogent: 202 | if args.dataset == "clevr": 203 | dataset_train = CLEVR(args.data_dir, "trainA") 204 | dataset_test_a = CLEVR(args.data_dir, "valA") 205 | dataset_test_b = CLEVR(args.data_dir, "valB") 206 | elif args.dataset == "shapeworld4": 207 | dataset_train = SHAPEWORLD4(args.data_dir, "train_a") 208 | dataset_test_a = SHAPEWORLD4(args.data_dir, "val_a") 209 | dataset_test_b = SHAPEWORLD4(args.data_dir, "val_b") 210 | else: 211 | if args.dataset == "clevr": 212 | dataset_train = CLEVR(args.data_dir, "train") 213 | dataset_test = CLEVR(args.data_dir, "val") 214 | elif args.dataset == "shapeworld4": 215 | dataset_train = SHAPEWORLD4(args.data_dir, "train") 216 | dataset_test = SHAPEWORLD4(args.data_dir, "val") 217 | 218 | print('data loaded') 219 | 220 | if not args.eval_only: 221 | train_loader = get_loader( 222 | dataset_train, batch_size=args.batch_size, num_workers=args.num_workers 223 | ) 224 | if not args.train_only: 225 | if args.cogent: 226 | test_loader_a = get_loader( 227 | dataset_test_a, 228 | batch_size=5000, 229 | num_workers=args.num_workers, 230 | shuffle=False) 231 | test_loader_b = get_loader( 232 | dataset_test_b, 233 | batch_size=5000, 234 | num_workers=args.num_workers, 235 | shuffle=False) 236 | else: 237 | test_loader = get_loader( 238 | dataset_test, 239 | batch_size=5000, 240 | num_workers=args.num_workers, 241 | shuffle=False) 242 | 243 | 244 | # print(torch.cuda.is_available()) 245 | if args.dataset == "shapeworld4": 246 | net = model.SlotAttention_model(n_slots=4, n_iters=3, n_attr=15, 247 | encoder_hidden_channels=32, 248 | attention_hidden_channels=64, 249 | mlp_prediction=True, 250 | device=args.device) 251 | elif args.dataset == "clevr": 252 | net = model.SlotAttention_model(n_slots=10, n_iters=3, n_attr=15, 253 | encoder_hidden_channels=64, 254 | attention_hidden_channels=128, 255 | mlp_prediction=True, 256 | device=args.device, 257 | clevr_encoding=True) 258 | 259 | args.n_attr = net.n_attr 260 | 261 | 262 | if not args.no_cuda: 263 | net = net.cuda() 264 | 265 | optimizer = torch.optim.Adam(net.parameters(), lr=args.lr) 266 | 267 | criterion = torch.nn.SmoothL1Loss() 268 | # Create RTPT object 269 | rtpt = RTPT(name_initials=args.credentials, experiment_name=f"Set prediction baseline ", 270 | 
max_iterations=args.epochs) 271 | 272 | # store args as txt file 273 | set_utils.save_args(args, writer) 274 | 275 | 276 | 277 | ap_list = [] 278 | ap_list_a = [] # for cogent 279 | ap_list_b = [] # for cogent 280 | total_train_time = 0 281 | total_test_time = 0 282 | time_list =[] 283 | time_start = time.time() 284 | 285 | start_epoch = 0 286 | if args.resume: 287 | print("Loading ckpt ...") 288 | log = torch.load(args.resume) 289 | weights = log["weights"] 290 | net.load_state_dict(weights, strict=True) 291 | start_epoch = log["args"]["epochs_trained"] + 1 292 | 293 | ap_list = log["ap_list"] 294 | ap_list_a = log["ap_list_a"] 295 | ap_list_b = log["ap_list_b"] 296 | time_list =log["time_list"] 297 | print("continue with epoch", start_epoch) 298 | 299 | 300 | 301 | # Start the RTPT tracking 302 | rtpt.start() 303 | 304 | 305 | for epoch in np.arange(start_epoch, args.epochs): 306 | print("Epoch:", epoch) 307 | if not args.eval_only: 308 | train_start = time.time() 309 | run(net, train_loader, optimizer, criterion, writer, args, train=True, epoch=epoch) 310 | time_train = time.time() - train_start 311 | rtpt.step() 312 | if not args.train_only: 313 | test_start = time.time() 314 | if args.cogent: 315 | ap_a = run(net, test_loader_a, None, criterion, writer, args, test_cond="a", train=False, epoch=epoch) 316 | ap_list_a.append(ap_a) 317 | ap_b = run(net, test_loader_b, None, criterion, writer, args, test_cond="b", train=False, epoch=epoch) 318 | ap_list_b.append(ap_b) 319 | else: 320 | ap = run(net, test_loader, None, criterion, writer, args, train=False, epoch=epoch) 321 | ap_list.append(ap) 322 | 323 | 324 | time_test = time.time() - test_start 325 | if args.eval_only: 326 | exit() 327 | torch.cuda.empty_cache() 328 | 329 | args.epochs_trained = epoch 330 | 331 | total_test_time += time_test 332 | total_train_time += time_train 333 | time_list.append([total_train_time, total_test_time, time_train, time_test]) 334 | 335 | 336 | results = { 337 | "name": args.name, 338 | "weights": net.state_dict(), 339 | "args": vars(args), 340 | "ap_list": ap_list, #empty for cogent 341 | "ap_list_a": ap_list_a, 342 | "ap_list_b": ap_list_b, 343 | "time": time_list} 344 | 345 | torch.save(results, os.path.join("runs", args.name, args.name)) 346 | if args.eval_only: 347 | break 348 | 349 | print("total time", misc_utils.time_delta_now(time_start)) 350 | 351 | 352 | if __name__ == "__main__": 353 | main() 354 | -------------------------------------------------------------------------------- /src/experiments/baseline_slot_attention/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #MODEL="slot-attention-set-pred-shapeworld4" 4 | #MODEL="slot-attention-set-pred-shapeworld4-cogent" 5 | MODEL="slot-attention-set-pred-clevr" 6 | #MODEL="slot-attention-set-pred-clevr-cogent" 7 | 8 | 9 | #DATA="../data/shapeworld4" 10 | #DATA="../data/shapeworld_cogent" 11 | DATA="../../../data/CLEVR_v1.0" 12 | #DATA="../../../data/CLEVR_CoGenT_v1.0" 13 | 14 | 15 | #DATASET=shapeworld4 16 | DATASET=clevr 17 | 18 | DEVICE=$1 19 | SEED=$2 # 0, 1, 2, 3, 4 20 | CREDENTIALS=$3 21 | #-------------------------------------------------------------------------------# 22 | 23 | CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \ 24 | --data-dir $DATA --dataset $DATASET --epochs 1000 \ 25 | --name $MODEL --lr 0.0004 --batch-size 512 --n-slots 10 --n-iters-slot-att 3 --n-attr 15 --ap-log 1 --seed $SEED \ 26 | --warmup-epochs 8 --decay-epochs 360 --num-workers 8 --credentials 
$CREDENTIALS 27 | 28 | 29 | 30 | # # CLEVR 31 | # for S in 0 1 2 3 4 32 | # do 33 | # CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \ 34 | # --data-dir $DATA --dataset $DATASET --epochs 1000 \ 35 | # --name $MODEL --lr 0.0004 --batch-size 512 --n-slots 10 --n-iters-slot-att 3 --n-attr 15 --ap-log 1 --seed $S \ 36 | # --warmup-epochs 8 --decay-epochs 360 --num-workers 8 --credentials $CREDENTIALS 37 | # done 38 | 39 | # # CLEVR COGENT 40 | # for S in 0 1 2 3 4 41 | # do 42 | # CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \ 43 | # --data-dir $DATA --dataset $DATASET --epochs 1000 \ 44 | # --name $MODEL --lr 0.0004 --batch-size 512 --n-slots 10 --n-iters-slot-att 3 --n-attr 15 --ap-log 1 --seed $S \ 45 | # --warmup-epochs 8 --decay-epochs 360 --num-workers 8 --credentials $CREDENTIALS 46 | # done 47 | 48 | # # SHAPEWORLD4 49 | # for S in 0 1 2 3 4 50 | # do 51 | 52 | # CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \ 53 | # --data-dir $DATA --dataset $DATASET --epochs 1000 \ 54 | # --name $MODEL --lr 0.0004 --batch-size 512 --n-slots 4 --n-iters-slot-att 3 --n-attr 15 --ap-log 1 --seed $S \ 55 | # --warmup-epochs 8 --decay-epochs 360 --num-workers 8 --credentials $CREDENTIALS 56 | # done 57 | 58 | # SHAPEWORLD4 COGENT 59 | # for S in 0 1 2 3 4 60 | # do 61 | 62 | # CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \ 63 | # --data-dir $DATA --dataset $DATASET --epochs 1000 \ 64 | # --name $MODEL --lr 0.0004 --batch-size 512 --n-slots 4 --n-iters-slot-att 3 --n-attr 15 --ap-log 1 --seed $S \ 65 | # --warmup-epochs 8 --decay-epochs 360 --num-workers 8 --credentialpns $CREDENTIALS --cogent 66 | # done 67 | 68 | -------------------------------------------------------------------------------- /src/experiments/mnist_top_k/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/SLASH/79d158dbd670d56d8bf9043640635c24a18a9c57/src/experiments/mnist_top_k/__init__.py -------------------------------------------------------------------------------- /src/experiments/mnist_top_k/dataGen.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import numpy as np 4 | 5 | class MNIST_Addition(Dataset): 6 | 7 | def __init__(self, dataset, examples, num_i, flat_for_pc): 8 | self.data = list() 9 | self.dataset = dataset 10 | self.num_i = num_i 11 | self.flat_for_pc = flat_for_pc 12 | 13 | with open(examples) as f: 14 | for line in f: 15 | line = line.strip().split(' ') 16 | self.data.append(tuple([int(i) for i in line])) 17 | 18 | 19 | def __getitem__(self, index): 20 | if self.num_i == 2: 21 | i1, i2, l = self.data[index] 22 | l = ':- not addition(i1, i2, {}).'.format(l) 23 | if self.flat_for_pc: 24 | return {'i1': self.dataset[i1][0].flatten(), 'i2': self.dataset[i2][0].flatten()}, l 25 | else: 26 | return {'i1': self.dataset[i1][0], 'i2': self.dataset[i2][0]}, l 27 | 28 | elif self.num_i == 3: 29 | i1, i2, i3, l = self.data[index] 30 | l = ':- not addition(i1, i2, i3, {}).'.format(l) 31 | if self.flat_for_pc: 32 | return {'i1': self.dataset[i1][0].flatten(), 'i2': self.dataset[i2][0].flatten(), 'i3': self.dataset[i3][0].flatten()}, l 33 | else: 34 | return {'i1': self.dataset[i1][0], 'i2': self.dataset[i2][0], 'i3': self.dataset[i3][0]}, l 35 | 36 | elif self.num_i == 4: 37 | i1, i2, i3, i4, l = self.data[index] 38 | l = ':- not addition(i1, i2, i3, i4, {}).'.format(l) 39 | if self.flat_for_pc: 40 | return {'i1': self.dataset[i1][0].flatten(), 
'i2': self.dataset[i2][0].flatten(), 'i3': self.dataset[i3][0].flatten(), 'i4': self.dataset[i4][0].flatten()}, l 41 |             else: 42 |                 return {'i1': self.dataset[i1][0], 'i2': self.dataset[i2][0], 'i3': self.dataset[i3][0], 'i4': self.dataset[i4][0]}, l 43 |  44 |         elif self.num_i == 6: 45 |             i1, i2, i3, i4, i5,i6, l = self.data[index] 46 |             l = ':- not addition(i1, i2, i3, i4, i5, i6, {}).'.format(l) 47 |             if self.flat_for_pc: 48 |                 return {'i1': self.dataset[i1][0].flatten(), 'i2': self.dataset[i2][0].flatten(), 'i3': self.dataset[i3][0].flatten(), 'i4': self.dataset[i4][0].flatten(), 'i5': self.dataset[i5][0].flatten(), 'i6': self.dataset[i6][0].flatten()}, l 49 |             else: 50 |                 return {'i1': self.dataset[i1][0], 'i2': self.dataset[i2][0], 'i3': self.dataset[i3][0], 'i4': self.dataset[i4][0], 'i5': self.dataset[i5][0], 'i6': self.dataset[i6][0]}, l 51 |  52 |  53 |     def __len__(self): 54 |         return len(self.data) -------------------------------------------------------------------------------- /src/experiments/mnist_top_k/network_nn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class Net_nn(nn.Module): 5 |     def __init__(self): 6 |         super(Net_nn, self).__init__() 7 |         self.encoder = nn.Sequential( 8 |             nn.Conv2d(1, 6, 5),  # 6 is the output channel size; 5 is the kernel size; 1 (channel) 28 28 -> 6 24 24 9 |             nn.MaxPool2d(2, 2),  # kernel size 2; stride size 2; 6 24 24 -> 6 12 12 10 |             nn.ReLU(True),  # inplace=True means that it will modify the input directly and thus save memory 11 |             nn.Conv2d(6, 16, 5),  # 6 12 12 -> 16 8 8 12 |             nn.MaxPool2d(2, 2),  # 16 8 8 -> 16 4 4 13 |             nn.ReLU(True) 14 |         ) 15 |         self.classifier = nn.Sequential( 16 |             nn.Linear(16 * 4 * 4, 120), 17 |             nn.ReLU(), 18 |             nn.Linear(120, 84), 19 |             nn.ReLU(), 20 |             nn.Linear(84, 10), 21 |             nn.Softmax(1) 22 |         ) 23 | 24 |     def forward(self, x, marg_idx=None, type=1): 25 | 26 |         assert type == 1, "only posterior computations are available for this network" 27 | 28 |         # If the list of the pixel numbers to be marginalised is given, 29 |         # then generate a marginalisation mask from it and apply it to the 30 |         # tensor 'x' 31 |         if marg_idx: 32 |             batch_size = x.shape[0] 33 |             with torch.no_grad(): 34 |                 marg_mask = torch.ones_like(x, device=x.device).reshape(batch_size, 1, -1) 35 |                 marg_mask[:, :, marg_idx] = 0 36 |                 marg_mask = marg_mask.reshape_as(x) 37 |                 marg_mask.requires_grad_(False) 38 |             x = torch.einsum('ijkl,ijkl->ijkl', x, marg_mask) 39 |         x = self.encoder(x) 40 |         x = x.view(-1, 16 * 4 * 4) 41 |         x = self.classifier(x) 42 |         return x 43 | -------------------------------------------------------------------------------- /src/experiments/mnist_top_k/train.py: -------------------------------------------------------------------------------- 1 | print("start importing...") 2 | 3 | import time 4 | import sys 5 | import argparse 6 | import datetime 7 | 8 | sys.path.append('../../') 9 | sys.path.append('../../SLASH/') 10 | sys.path.append('../../EinsumNetworks/src/') 11 | 12 | 13 | #torch, numpy, ...
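# --- Minimal usage sketch for Net_nn (illustrative, standalone; not part of the original train.py) ---
# The CNN defined in network_nn.py above maps a batch of 1x28x28 images to a 10-way softmax over
# digits; `marg_idx` is an optional list of flattened pixel indices that are zeroed out
# (marginalised) before the forward pass. The shapes below follow from the layer definitions above.
import torch
from network_nn import Net_nn

net = Net_nn()
x = torch.rand(4, 1, 28, 28)                     # four dummy MNIST-sized images
probs = net(x)                                   # -> shape (4, 10), each row sums to ~1
probs_marg = net(x, marg_idx=list(range(392)))   # zero out the upper half of every image
print(probs.shape, probs.sum(dim=1), probs_marg.shape)
# --- end of sketch ---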
14 | import torch 15 | from torch.utils.tensorboard import SummaryWriter 16 | from torchvision.transforms import transforms 17 | import torchvision 18 | 19 | import numpy as np 20 | 21 | #own modules 22 | from dataGen import MNIST_Addition 23 | from einsum_wrapper import EiNet 24 | from network_nn import Net_nn 25 | 26 | #import slash 27 | from slash import SLASH 28 | 29 | import utils 30 | from utils import set_manual_seed 31 | from pathlib import Path 32 | from rtpt import RTPT 33 | 34 | print("...done") 35 | 36 | 37 | def get_args(): 38 | parser = argparse.ArgumentParser() 39 | 40 | parser.add_argument( 41 | "--seed", type=int, default=10, help="Random generator seed for all frameworks" 42 | ) 43 | parser.add_argument( 44 | "--epochs", type=int, default=10, help="Number of epochs to train with" 45 | ) 46 | parser.add_argument( 47 | "--lr", type=float, default=0.01, help="Learning rate of model" 48 | ) 49 | parser.add_argument( 50 | "--network-type", 51 | choices=["nn","pc"], 52 | help="The type of external to be used e.g. neural net or probabilistic circuit", 53 | ) 54 | parser.add_argument( 55 | "--pc-structure", 56 | choices=["poon-domingos","binary-trees"], 57 | help="The type of external to be used e.g. neural net or probabilistic circuit", 58 | ) 59 | 60 | parser.add_argument( 61 | "--method", 62 | choices=["exact","top_k","same"], 63 | help="How many images should be used in the addition", 64 | ) 65 | parser.add_argument( 66 | "--k", type=int, default=0, help="Maximum number of stable model to be used" 67 | ) 68 | parser.add_argument( 69 | "--images-per-addition", 70 | choices=["2","3","4","6"], 71 | help="How many images should be used in the addition", 72 | ) 73 | parser.add_argument( 74 | "--batch-size", type=int, default=100, help="Batch size to train with" 75 | ) 76 | parser.add_argument( 77 | "--num-workers", type=int, default=6, help="Number of threads for data loader" 78 | ) 79 | 80 | parser.add_argument( 81 | "--p-num", type=int, default=8, help="Number of processes to devide the batch for parallel processing" 82 | ) 83 | 84 | parser.add_argument("--credentials", type=str, help="Credentials for rtpt") 85 | 86 | 87 | args = parser.parse_args() 88 | 89 | if args.network_type == 'pc': 90 | args.use_pc = True 91 | else: 92 | args.use_pc = False 93 | 94 | return args 95 | 96 | 97 | def slash_mnist_addition(): 98 | 99 | args = get_args() 100 | print(args) 101 | 102 | 103 | # Set the seeds for PRNG 104 | set_manual_seed(args.seed) 105 | 106 | # Create RTPT object 107 | rtpt = RTPT(name_initials=args.credentials, experiment_name='SLASH MNIST pick-k', max_iterations=args.epochs) 108 | 109 | # Start the RTPT tracking 110 | rtpt.start() 111 | 112 | i_num = int(args.images_per_addition) 113 | if i_num == 2: 114 | program = ''' 115 | img(i1). img(i2). 116 | addition(A,B,N):- digit(0,+A,-N1), digit(0,+B,-N2), N=N1+N2, A!=B. 117 | npp(digit(1,X), [0,1,2,3,4,5,6,7,8,9]) :- img(X). 118 | ''' 119 | 120 | elif i_num == 3: 121 | program = ''' 122 | img(i1). img(i2). img(i3). 123 | addition(A,B,C,N):- digit(0,+A,-N1), digit(0,+B,-N2), digit(0,+C,-N3), N=N1+N2+N3, A!=B, A!=C, B!=C. 124 | npp(digit(1,X), [0,1,2,3,4,5,6,7,8,9]) :- img(X). 125 | ''' 126 | 127 | elif i_num == 4: 128 | program = ''' 129 | img(i1). img(i2). img(i3). img(i4). 130 | addition(A,B,C,D,N):- digit(0,+A,-N1), digit(0,+B,-N2), digit(0,+C,-N3), digit(0,+D,-N4), N=N1+N2+N3+N4, A != B, A!=C, A!=D, B!= C, B!= D, C!=D. 131 | npp(digit(1,X), [0,1,2,3,4,5,6,7,8,9]) :- img(X). 
132 | ''' 133 | 134 | elif i_num == 6: 135 | program = ''' 136 | img(i1). img(i2). img(i3). img(i4). img(i5). img(i6). 137 | addition(i1,i2,i3,i4,i5,i6,N):- digit(0,+A,-N1), digit(0,+B,-N2), digit(0,+C,-N3), digit(0,+D,-N4), digit(0,+E,-N5), digit(0,+F,-N6), N=N1+N2+N3+N4+N5+N6, A != B, A!=C, A!=D, A!=E, A!=F, B!= C, B!= D, B!=E, B!=F, C!=D, C!=E, C!=F, D!=E, D!=F, E!=F. 138 | npp(digit(1,X), [0,1,2,3,4,5,6,7,8,9]) :- img(X). 139 | ''' 140 | 141 | exp_name= str(args.method)+"/" +args.network_type+"_i"+str(i_num)+"_k"+ str(args.k) 142 | 143 | saveModelPath = 'data/'+exp_name+'/slash_digit_addition_models_seed'+str(args.seed)+'.pt' 144 | Path("data/"+exp_name+"/").mkdir(parents=True, exist_ok=True) 145 | 146 | 147 | #use neural net or probabilisitc circuit 148 | if args.network_type == 'pc': 149 | 150 | #setup new SLASH program given the network parameters 151 | if args.pc_structure == 'binary-trees': 152 | m = EiNet(structure = 'binary-trees', 153 | depth = 3, 154 | num_repetitions = 20, 155 | use_em = False, 156 | num_var = 784, 157 | class_count = 10, 158 | learn_prior = True) 159 | elif args.pc_structure == 'poon-domingos': 160 | m = EiNet(structure = 'poon-domingos', 161 | pd_num_pieces = [4,7,28], 162 | use_em = False, 163 | num_var = 784, 164 | class_count = 10, 165 | pd_width = 28, 166 | pd_height = 28, 167 | learn_prior = True) 168 | else: 169 | print("pc structure learner unknown") 170 | 171 | else: 172 | m = Net_nn() 173 | 174 | 175 | #trainable paramas 176 | num_trainable_params = sum(p.numel() for p in m.parameters() if p.requires_grad) 177 | num_params = sum(p.numel() for p in m.parameters()) 178 | print("training with {} trainable params and {} params in total".format(num_trainable_params,num_params)) 179 | 180 | 181 | #create the SLASH Program 182 | nnMapping = {'digit': m} 183 | optimizers = {'digit': torch.optim.Adam(m.parameters(), lr=args.lr, eps=1e-7)} 184 | SLASHobj = SLASH(program, nnMapping, optimizers) 185 | SLASHobj.grad_comp_device ='cpu' #set gradient computation to cpu 186 | 187 | 188 | #metric lists 189 | train_accuracy_list = [] 190 | test_accuracy_list = [] 191 | confusion_matrix_list = [] 192 | loss_list = [] 193 | startTime = time.time() 194 | 195 | forward_time_list = [] 196 | asp_time_list = [] 197 | backward_time_list = [] 198 | sm_per_batch_list = [] 199 | train_test_times = [] 200 | 201 | #load data 202 | #if we are using spns we need to flatten the data(Tensor has form [bs, 784]) 203 | if args.use_pc: 204 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081, )), transforms.Lambda(lambda x: torch.flatten(x))]) 205 | #if not we can keep the dimensions(Tensor has form [bs,28,28]) 206 | else: 207 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081, ))]) 208 | data_path = 'data/labels/train_data_s'+str(i_num)+'.txt' 209 | 210 | mnist_addition_dataset = MNIST_Addition(torchvision.datasets.MNIST(root='./data/', train=True, download=True, transform=transform), data_path, i_num, args.use_pc) 211 | train_dataset_loader = torch.utils.data.DataLoader(mnist_addition_dataset, shuffle=True,batch_size=args.batch_size,pin_memory=True, num_workers=8) 212 | 213 | test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('./data/', train=False, transform=transform), batch_size=100, shuffle=True) 214 | train_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('./data/', train=True, transform=transform), batch_size=100, shuffle=True) 215 | 216 | 217 | # Evaluate 
the performanve directly after initialisation 218 | time_test = time.time() 219 | test_acc, _, confusion_matrix = SLASHobj.testNetwork('digit', test_loader, ret_confusion=True) 220 | train_acc, _ = SLASHobj.testNetwork('digit', train_loader) 221 | confusion_matrix_list.append(confusion_matrix) 222 | train_accuracy_list.append([train_acc,0]) 223 | test_accuracy_list.append([test_acc, 0]) 224 | timestamp_test = utils.time_delta_now(time_test, simple_format=True) 225 | timestamp_total = utils.time_delta_now(startTime, simple_format=True) 226 | 227 | train_test_times.append([0.0, timestamp_test, timestamp_total]) 228 | 229 | # Save and print statistics 230 | print('Train Acc: {:0.2f}%, Test Acc: {:0.2f}%'.format(train_acc, test_acc)) 231 | print('--- train time: ---', 0) 232 | print('--- test time: ---' , timestamp_test) 233 | print('--- total time from beginning: ---', timestamp_total) 234 | 235 | # Export results and networks 236 | print('Storing the trained model into {}'.format(saveModelPath)) 237 | torch.save({"addition_net": m.state_dict(), 238 | "test_accuracy_list": test_accuracy_list, 239 | "train_accuracy_list":train_accuracy_list, 240 | "confusion_matrix_list":confusion_matrix_list, 241 | "num_params": num_trainable_params, 242 | "args":args, 243 | "exp_name":exp_name, 244 | "train_test_times": train_test_times, 245 | "program":program}, saveModelPath) 246 | 247 | start_e= 0 248 | 249 | # Train and evaluate the performance 250 | for e in range(start_e, args.epochs): 251 | print('Epoch {}...'.format(e+1)) 252 | 253 | #one epoch of training 254 | time_train= time.time() 255 | loss, forward_time, asp_time, backward_time, sm_per_batch, model_computation_time, gradient_computation_time = SLASHobj.learn(dataset_loader = train_dataset_loader, 256 | epoch=e, method=args.method, p_num=args.p_num, k_num = args.k , same_threshold=0.99) 257 | timestamp_train = utils.time_delta_now(time_train, simple_format=True) 258 | 259 | #store detailed timesteps per batch 260 | forward_time_list.append(forward_time) 261 | asp_time_list.append(asp_time) 262 | backward_time_list.append(backward_time) 263 | sm_per_batch_list.append(sm_per_batch) 264 | 265 | 266 | time_test = time.time() 267 | test_acc, _, confusion_matrix = SLASHobj.testNetwork('digit', test_loader, ret_confusion=True) 268 | confusion_matrix_list.append(confusion_matrix) 269 | train_acc, _ = SLASHobj.testNetwork('digit', train_loader) 270 | train_accuracy_list.append([train_acc,e]) 271 | test_accuracy_list.append([test_acc, e]) 272 | timestamp_test = utils.time_delta_now(time_test, simple_format=True) 273 | timestamp_total = utils.time_delta_now(startTime, simple_format=True) 274 | loss_list.append(loss) 275 | train_test_times.append([timestamp_train, timestamp_test, timestamp_total]) 276 | 277 | # Save and print statistics 278 | print('Train Acc: {:0.2f}%, Test Acc: {:0.2f}%'.format(train_acc, test_acc)) 279 | print('--- train time: ---', timestamp_train) 280 | print('--- test time: ---' , timestamp_test) 281 | print('--- total time from beginning: ---', timestamp_total) 282 | 283 | # Export results and networks 284 | print('Storing the trained model into {}'.format(saveModelPath)) 285 | torch.save({"addition_net": m.state_dict(), 286 | "resume": { 287 | "optimizer_digit":optimizers['digit'].state_dict(), 288 | "epoch":e 289 | }, 290 | "test_accuracy_list": test_accuracy_list, 291 | "train_accuracy_list":train_accuracy_list, 292 | "confusion_matrix_list":confusion_matrix_list, 293 | "num_params": num_trainable_params, 294 | "args":args, 295 | 
"exp_name":exp_name, 296 | "train_test_times": train_test_times, 297 | "forward_time_list":forward_time_list, 298 | "asp_time_list":asp_time_list, 299 | "backward_time_list":backward_time_list, 300 | "sm_per_batch_list":sm_per_batch_list, 301 | "loss": loss_list, 302 | "program":program}, saveModelPath) 303 | 304 | # Update the RTPT 305 | rtpt.step(subtitle=f"accuracy={test_acc:2.2f}") 306 | 307 | 308 | 309 | 310 | if __name__ == "__main__": 311 | slash_mnist_addition() -------------------------------------------------------------------------------- /src/experiments/mnist_top_k/train_baseline.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DEVICE=$1 4 | SEED=$2 # 0, 1, 2, 3, 4 5 | CREDENTIALS=$3 6 | 7 | 8 | METHOD=exact # ibu, top_k, exact 9 | K=0 #1,3,5,10 10 | 11 | #-------------------------------------------------------------------------------# 12 | # Train on CLEVR_v1 with cnn model 13 | 14 | CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \ 15 | --epochs 30 \ 16 | --batch-size 100 --seed $SEED --method=$METHOD --images-per-addition=2 --k=$K \ 17 | --network-type nn --lr 0.005 \ 18 | --num-workers 0 --p-num 8 --credentials $CREDENTIALS 19 | 20 | 21 | #NN 22 | # for S in 0 1 2 3 4 23 | # do 24 | 25 | # CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \ 26 | # --epochs 30 \ 27 | # --batch-size 100 --seed $S --method=$METHOD --images-per-addition=2 --k=$K \ 28 | # --network-type nn --lr 0.005 \ 29 | # --num-workers 0 --p-num 8 --credentials $CREDENTIALS 30 | 31 | 32 | # CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \ 33 | # --epochs 30 \ 34 | # --batch-size 100 --seed $S --method=$METHOD --images-per-addition=3 --k=$K \ 35 | # --network-type nn --lr 0.005 \ 36 | # --num-workers 0 --p-num 8 --credentials $CREDENTIALS 37 | 38 | 39 | # CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \ 40 | # --epochs 30 \ 41 | # --batch-size 100 --seed $S --method=$METHOD --images-per-addition=4 --k=$K \ 42 | # --network-type nn --lr 0.005 \ 43 | # --num-workers 0 --p-num 8 --credentials $CREDENTIALS 44 | # done 45 | 46 | 47 | #PC 48 | # for S in 0 1 2 3 4 49 | # do 50 | # CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \ 51 | # --epochs 30 \ 52 | # --batch-size 100 --seed $S --method=$METHOD --images-per-addition=2 --k=$K \ 53 | # --network-type pc --pc-structure poon-domingos --lr 0.01 \ 54 | # --num-workers 0 --p-num 8 --credentials $CREDENTIALS 55 | 56 | 57 | # CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \ 58 | # --epochs 30 \ 59 | # --batch-size 100 --seed $S --method=$METHOD --images-per-addition=3 --k=$K \ 60 | # --network-type pc --pc-structure poon-domingos --lr 0.01 \ 61 | # --num-workers 0 --p-num 8 --credentials $CREDENTIALS 62 | 63 | # CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \ 64 | # --epochs 30 \ 65 | # --batch-size 100 --seed $S --method=$METHOD --images-per-addition=4 --k=$K \ 66 | # --network-type pc --pc-structure poon-domingos --lr 0.01 \ 67 | # --num-workers 0 --p-num 8 --credentials $CREDENTIALS 68 | # done -------------------------------------------------------------------------------- /src/experiments/mnist_top_k/train_same.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DEVICE=$1 4 | SEED=$2 # 0, 1, 2, 3, 4 5 | CREDENTIALS=$3 6 | 7 | METHOD=same # same, top_k, exact 8 | K=0 #1,3,5,10 9 | 10 | #-------------------------------------------------------------------------------# 11 | # Train MNIST with SAME 12 | 13 | 14 | CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \ 15 
| --epochs 30 \ 16 | --batch-size 100 --seed 42 --method=$METHOD --images-per-addition=2 --k=$K \ 17 | --network-type nn --lr 0.005 \ 18 | --num-workers 0 --p-num 8 --credentials $CREDENTIALS 19 | 20 | 21 | 22 | #NN 23 | # for S in 0 1 2 3 4 24 | # do 25 | 26 | # CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \ 27 | # --epochs 30 \ 28 | # --batch-size 100 --seed $S --method=$METHOD --images-per-addition=2 --k=$K \ 29 | # --network-type nn --lr 0.005 \ 30 | # --num-workers 0 --p-num 8 --credentials $CREDENTIALS 31 | 32 | 33 | # CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \ 34 | # --epochs 30 \ 35 | # --batch-size 100 --seed $S --method=$METHOD --images-per-addition=3 --k=$K \ 36 | # --network-type nn --lr 0.005 \ 37 | # --num-workers 0 --p-num 8 --credentials $CREDENTIALS 38 | 39 | 40 | # CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \ 41 | # --epochs 30 \ 42 | # --batch-size 100 --seed $S --method=$METHOD --images-per-addition=4 --k=$K \ 43 | # --network-type nn --lr 0.005 \ 44 | # --num-workers 0 --p-num 8 --credentials $CREDENTIALS 45 | # done 46 | 47 | 48 | #PC 49 | # for S in 0 1 2 3 4 50 | # do 51 | # CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \ 52 | # --epochs 30 \ 53 | # --batch-size 100 --seed $S --method=$METHOD --images-per-addition=2 --k=$K \ 54 | # --network-type pc --pc-structure poon-domingos --lr 0.01 \ 55 | # --num-workers 0 --p-num 8 --credentials $CREDENTIALS 56 | 57 | 58 | # CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \ 59 | # --epochs 30 \ 60 | # --batch-size 100 --seed $S --method=$METHOD --images-per-addition=3 --k=$K \ 61 | # --network-type pc --pc-structure poon-domingos --lr 0.01 \ 62 | # --num-workers 0 --p-num 8 --credentials $CREDENTIALS 63 | 64 | # CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \ 65 | # --epochs 30 \ 66 | # --batch-size 100 --seed $S --method=$METHOD --images-per-addition=4 --k=$K \ 67 | # --network-type pc --pc-structure poon-domingos --lr 0.01 \ 68 | # --num-workers 0 --p-num 8 --credentials $CREDENTIALS 69 | # done -------------------------------------------------------------------------------- /src/experiments/mnist_top_k/train_top_k.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DEVICE=$1 4 | SEED=$2 # 0, 1, 2, 3, 4 5 | CREDENTIALS=$3 6 | 7 | METHOD=top_k # ibu, top-k, exact 8 | K=10 # 1, 3, 5, 10 9 | 10 | #-------------------------------------------------------------------------------# 11 | # Train on CLEVR_v1 with cnn model 12 | 13 | CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \ 14 | --epochs 30 \ 15 | --batch-size 100 --seed 7 --method=$METHOD --images-per-addition=2 --k=1 \ 16 | --network-type nn --lr 0.005 \ 17 | --num-workers 0 --p-num 8 --credentials $CREDENTIALS 18 | 19 | 20 | 21 | 22 | #NN 23 | # for S in 0 1 2 3 4 24 | # do 25 | 26 | # CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \ 27 | # --epochs 30 \ 28 | # --batch-size 100 --seed $S --method=$METHOD --images-per-addition=2 --k=$K \ 29 | # --network-type nn --lr 0.005 \ 30 | # --num-workers 0 --p-num 8 --credentials $CREDENTIALS 31 | 32 | 33 | # CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \ 34 | # --epochs 30 \ 35 | # --batch-size 100 --seed $S --method=$METHOD --images-per-addition=3 --k=$K \ 36 | # --network-type nn --lr 0.005 \ 37 | # --num-workers 0 --p-num 8 --credentials $CREDENTIALS 38 | 39 | 40 | # CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \ 41 | # --epochs 30 \ 42 | # --batch-size 100 --seed $S --method=$METHOD --images-per-addition=4 --k=$K \ 43 | # --network-type nn --lr 0.005 \ 44 | # 
--num-workers 0 --p-num 8 --credentials $CREDENTIALS 45 | # done 46 | 47 | 48 | #PC 49 | # for S in 0 1 2 3 4 50 | # do 51 | # CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \ 52 | # --epochs 30 \ 53 | # --batch-size 100 --seed $S --method=$METHOD --images-per-addition=2 --k=$K \ 54 | # --network-type pc --pc-structure poon-domingos --lr 0.01 \ 55 | # --num-workers 0 --p-num 8 --credentials $CREDENTIALS 56 | 57 | 58 | # CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \ 59 | # --epochs 30 \ 60 | # --batch-size 100 --seed $S --method=$METHOD --images-per-addition=3 --k=$K \ 61 | # --network-type pc --pc-structure poon-domingos --lr 0.01 \ 62 | # --num-workers 0 --p-num 8 --credentials $CREDENTIALS 63 | 64 | # CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \ 65 | # --epochs 30 \ 66 | # --batch-size 100 --seed $S --method=$METHOD --images-per-addition=4 --k=$K \ 67 | # --network-type pc --pc-structure poon-domingos --lr 0.01 \ 68 | # --num-workers 0 --p-num 8 --credentials $CREDENTIALS 69 | # done -------------------------------------------------------------------------------- /src/experiments/slash_attention/clevr/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/SLASH/79d158dbd670d56d8bf9043640635c24a18a9c57/src/experiments/slash_attention/clevr/__init__.py -------------------------------------------------------------------------------- /src/experiments/slash_attention/clevr/auxiliary.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from pathlib import Path 4 | 5 | 6 | def get_files_names_and_paths(root:str='./data/CLEVR_v1.0/', mode:str='val', obj_num:int=4): 7 | data_file = Path(os.path.join(root, 'scenes','CLEVR_'+ mode+"_scenes.json")) 8 | data_file.parent.mkdir(parents=True, exist_ok=True) 9 | if data_file.exists(): 10 | print('File exists. Parsing file...') 11 | else: 12 | print(f'The JSON file {data_file} does not exist!') 13 | quit() 14 | img_paths = [] 15 | files_names = [] 16 | with open(data_file, 'r') as json_file: 17 | json_data = json.load(json_file) 18 | 19 | for scene in json_data['scenes']: 20 | if len(scene['objects']) <= obj_num: 21 | img_paths.append(Path(os.path.join(root,'images/'+mode+'/'+scene['image_filename']))) 22 | files_names.append(scene['image_filename']) 23 | 24 | print("...done ") 25 | return img_paths, files_names 26 | 27 | 28 | def get_slash_program(obj_num:int=4): 29 | program = '' 30 | if obj_num == 10: 31 | program =''' 32 | slot(s1). 33 | slot(s2). 34 | slot(s3). 35 | slot(s4). 36 | slot(s5). 37 | slot(s6). 38 | slot(s7). 39 | slot(s8). 40 | slot(s9). 41 | slot(s10). 42 | 43 | name(o1). 44 | name(o2). 45 | name(o3). 46 | name(o4). 47 | name(o5). 48 | name(o6). 49 | name(o7). 50 | name(o8). 51 | name(o9). 52 | name(o10). 53 | ''' 54 | elif obj_num ==4: 55 | program =''' 56 | slot(s1). 57 | slot(s2). 58 | slot(s3). 59 | slot(s4). 60 | 61 | name(o1). 62 | name(o2). 63 | name(o3). 64 | name(o4). 65 | ''' 66 | elif obj_num ==6: 67 | program =''' 68 | slot(s1). 69 | slot(s2). 70 | slot(s3). 71 | slot(s4). 72 | slot(s5). 73 | slot(s6). 74 | 75 | name(o1). 76 | name(o2). 77 | name(o3). 78 | name(o4). 79 | name(o5). 80 | name(o6). 81 | ''' 82 | else: 83 | print(f'The number of objects {obj_num} is wrong!') 84 | quit() 85 | program +=''' 86 | %assign each name a slot 87 | %{slot_name_comb(N,X): slot(X)}=1 :- name(N). 
%problem we have dublicated slots 88 | 89 | %remove each model which has multiple slots asigned to the same name 90 | %:- slot_name_comb(N1,X1), slot_name_comb(N2,X2), X1 == X2, N1 != N2. 91 | 92 | %build the object ontop of the slot assignment 93 | object(N, S, M, P, C) :- size(0, +X, -S), material(0, +X, -M), shape(0, +X, -P), color(0, +X, -C), slot(X), name(N), slot_name_comb(N,X). 94 | 95 | %define the SPNs 96 | npp(size(1,X),[small, large, bg]) :- slot(X). 97 | npp(material(1,X),[rubber, metal, bg]) :- slot(X). 98 | npp(shape(1,X),[cube, sphere, cylinder, bg]) :- slot(X). 99 | npp(color(1,X),[gray, red, blue, green, brown, purple, cyan, yellow, bg]) :- slot(X). 100 | 101 | ''' 102 | return program -------------------------------------------------------------------------------- /src/experiments/slash_attention/clevr/dataGen.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from torch.utils.data import Dataset 4 | from torchvision.transforms import transforms 5 | 6 | # from torch.utils.data import Dataset 7 | # from torchvision import transforms 8 | from skimage import io 9 | import os 10 | import numpy as np 11 | import torch 12 | import matplotlib.pyplot as plt 13 | from PIL import ImageFile 14 | ImageFile.LOAD_TRUNCATED_IMAGES = True 15 | import json 16 | 17 | from tqdm import tqdm 18 | 19 | 20 | def object_encoding(size, material, shape, color ): 21 | 22 | #size (small, large, bg) 23 | if size == "small": 24 | size_enc = [1,0,0] 25 | elif size == "large": 26 | size_enc = [0,1,0] 27 | elif size == "bg": 28 | size_enc = [0,0,1] 29 | 30 | #material (rubber, metal, bg) 31 | if material == "rubber": 32 | material_enc = [1,0,0] 33 | elif material == "metal": 34 | material_enc = [0,1,0] 35 | elif material == "bg": 36 | material_enc = [0,0,1] 37 | 38 | #shape (cube, sphere, cylinder, bg) 39 | if shape == "cube": 40 | shape_enc = [1,0,0,0] 41 | elif shape == "sphere": 42 | shape_enc = [0,1,0,0] 43 | elif shape == "cylinder": 44 | shape_enc = [0,0,1,0] 45 | elif shape == "bg": 46 | shape_enc = [0,0,0,1] 47 | 48 | #color (gray, red, blue, green, brown, purple, cyan, yellow, bg) 49 | #color (gray, red, blue, green, brown, purple, cyan, yellow, bg) 50 | #color 1 2 3 4 5 6 7 8 9 51 | if color == "gray": 52 | color_enc = [1,0,0,0,0,0,0,0,0] 53 | elif color == "red": 54 | color_enc = [0,1,0,0,0,0,0,0,0] 55 | elif color == "blue": 56 | color_enc = [0,0,1,0,0,0,0,0,0] 57 | elif color == "green": 58 | color_enc = [0,0,0,1,0,0,0,0,0] 59 | elif color == "brown": 60 | color_enc = [0,0,0,0,1,0,0,0,0] 61 | elif color == "purple": 62 | color_enc = [0,0,0,0,0,1,0,0,0] 63 | elif color == "cyan": 64 | color_enc = [0,0,0,0,0,0,1,0,0] 65 | elif color == "yellow": 66 | color_enc = [0,0,0,0,0,0,0,1,0] 67 | elif color == "bg": 68 | color_enc = [0,0,0,0,0,0,0,0,1] 69 | 70 | return size_enc + material_enc + shape_enc + color_enc +[1] 71 | 72 | 73 | 74 | 75 | 76 | class CLEVR(Dataset): 77 | def __init__(self, root, mode, img_paths=None, files_names=None, obj_num=None): 78 | self.root = root # The root folder of the dataset 79 | self.mode = mode # The mode of 'train' or 'val' 80 | self.files_names = files_names # The list of the files names with correct nuber of objects 81 | if obj_num is not None: 82 | self.obj_num = obj_num # The upper limit of number of objects 83 | else: 84 | self.obj_num = 10 85 | 86 | assert os.path.exists(root), 'Path {} does not exist'.format(root) 87 | 88 | #list of sorted image paths 89 | self.img_paths = [] 90 | 
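        # If a pre-filtered list of image paths was handed in (e.g. from
        # auxiliary.get_files_names_and_paths, which keeps only scenes with at most
        # obj_num objects), it is used directly; otherwise all *.png files under
        # images/<mode>/ are collected and sorted below.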
if img_paths: 91 | self.img_paths = img_paths 92 | else: 93 | #open directory and save all image paths 94 | for file in os.scandir(os.path.join(root, 'images', mode)): 95 | img_path = file.path 96 | if '.png' in img_path: 97 | self.img_paths.append(img_path) 98 | 99 | self.img_paths.sort() 100 | count = 0 101 | 102 | #target maps of the form {'target:idx': query string} or {'target:idx': obj encoding} 103 | self.query_map = {} 104 | self.obj_map = {} 105 | 106 | count = 0 107 | #We have up to 10 objects in the image, load the json file 108 | with open(os.path.join(root, 'scenes','CLEVR_'+ mode+"_scenes.json")) as f: 109 | data = json.load(f) 110 | 111 | #iterate over each scene and create the query string and obj encoding 112 | print("parsing scences") 113 | for scene in tqdm(data['scenes']): 114 | target_query = "" 115 | obj_encoding_list = [] 116 | 117 | if self.files_names: 118 | if any(scene['image_filename'] in file_name for file_name in files_names): 119 | num_objects = 0 120 | for idx, obj in enumerate(scene['objects']): 121 | target_query += " :- not object(o{}, {}, {}, {}, {}).".format(idx+1, obj['size'], obj['material'], obj['shape'], obj['color']) 122 | obj_encoding_list.append(object_encoding(obj['size'], obj['material'], obj['shape'], obj['color'])) 123 | num_objects = idx+1 #store the num of objects 124 | #fill in background objects 125 | for idx in range(num_objects, self.obj_num): 126 | target_query += " :- not object(o{}, bg, bg, bg, bg).".format(idx+1) 127 | obj_encoding_list.append([0,0,1, 0,0,1, 0,0,0,1, 0,0,0,0,0,0,0,0,1, 1]) 128 | self.query_map[count] = target_query 129 | self.obj_map[count] = np.array(obj_encoding_list) 130 | count += 1 131 | else: 132 | num_objects=0 133 | for idx, obj in enumerate(scene['objects']): 134 | target_query += " :- not object(o{}, {}, {}, {}, {}).".format(idx+1, obj['size'], obj['material'], obj['shape'], obj['color']) 135 | obj_encoding_list.append(object_encoding(obj['size'], obj['material'], obj['shape'], obj['color'])) 136 | num_objects = idx+1 #store the num of objects 137 | #fill in background objects 138 | for idx in range(num_objects, 10): 139 | target_query += " :- not object(o{}, bg, bg, bg, bg).".format(idx+1) 140 | obj_encoding_list.append([0,0,1, 0,0,1, 0,0,0,1, 0,0,0,0,0,0,0,0,1, 1]) 141 | self.query_map[scene['image_index']] = target_query 142 | self.obj_map[scene['image_index']] = np.array(obj_encoding_list) 143 | 144 | print("done") 145 | if self.files_names: 146 | print(f'Correctly found images {count} out of {len(files_names)}') 147 | 148 | 149 | #print(np.array(list(self.obj_map.values()))[0:20]) 150 | def __getitem__(self, index): 151 | #get the image 152 | img_path = self.img_paths[index] 153 | img = io.imread(img_path)[:, :, :3] 154 | 155 | transform = transforms.Compose([ 156 | transforms.ToPILImage(), 157 | #transforms.CenterCrop((29, 221,64, 256)), #why do we need to crop? 158 | transforms.Resize((128, 128)), 159 | transforms.ToTensor(), 160 | ]) 161 | 162 | img = transform(img) 163 | img = (img - 0.5) * 2.0 # Rescale to [-1, 1]. 
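        # Each sample is returned below as ({'im': image}, query, obj_encoding): `query` is the
        # ASP constraint string built in __init__, and `obj_encoding` has one 20-dim row per object
        # slot, laid out as size(3) + material(3) + shape(4) + color(9) + a trailing constant 1.
        # For example, object_encoding('small', 'rubber', 'cube', 'red') from above gives
        #   [1,0,0, 1,0,0, 1,0,0,0, 0,1,0,0,0,0,0,0,0, 1].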
164 | 165 | return {'im':img}, self.query_map[index] ,self.obj_map[index] 166 | 167 | 168 | 169 | def __len__(self): 170 | return len(self.img_paths) 171 | 172 | -------------------------------------------------------------------------------- /src/experiments/slash_attention/clevr/slash_attention_clevr.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | 5 | 6 | import train 7 | import datetime 8 | 9 | 10 | seed = 0 11 | obj_num = 10 12 | date_string = datetime.datetime.today().strftime('%d-%m-%Y') 13 | experiments = {f'CLEVR{obj_num}': {'structure':'poon-domingos', 'pd_num_pieces':[4], 'learn_prior':True, 14 | 'lr': 0.01, 'bs':512, 'epochs':1000, 15 | 'lr_warmup_steps':8, 'lr_decay_steps':360, 'use_em':False, 'resume':False, 16 | 'method':'most_prob', 17 | 'start_date':date_string, 'credentials':'DO', 'p_num':16, 'seed':seed, 'obj_num':obj_num 18 | }} 19 | 20 | 21 | for exp_name in experiments: 22 | print(exp_name) 23 | train.slash_slot_attention(exp_name, experiments[exp_name]) 24 | 25 | 26 | 27 | -------------------------------------------------------------------------------- /src/experiments/slash_attention/clevr_cogent/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/SLASH/79d158dbd670d56d8bf9043640635c24a18a9c57/src/experiments/slash_attention/clevr_cogent/__init__.py -------------------------------------------------------------------------------- /src/experiments/slash_attention/clevr_cogent/auxiliary.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from pathlib import Path 4 | 5 | 6 | def get_files_names_and_paths(root:str='./data/CLEVR_v1.0/', mode:str='val', obj_num:int=4): 7 | data_file = Path(os.path.join(root, 'scenes','CLEVR_'+ mode+"_scenes.json")) 8 | data_file.parent.mkdir(parents=True, exist_ok=True) 9 | if data_file.exists(): 10 | print('File exists. Parsing file...') 11 | else: 12 | print(f'The JSON file {data_file} does not exist!') 13 | quit() 14 | img_paths = [] 15 | files_names = [] 16 | with open(data_file, 'r') as json_file: 17 | json_data = json.load(json_file) 18 | 19 | for scene in json_data['scenes']: 20 | if len(scene['objects']) <= obj_num: 21 | img_paths.append(Path(os.path.join(root,'images/'+mode+'/'+scene['image_filename']))) 22 | files_names.append(scene['image_filename']) 23 | 24 | print("...done ") 25 | return img_paths, files_names 26 | 27 | 28 | def get_slash_program(obj_num:int=4): 29 | program = '' 30 | if obj_num == 10: 31 | program =''' 32 | slot(s1). 33 | slot(s2). 34 | slot(s3). 35 | slot(s4). 36 | slot(s5). 37 | slot(s6). 38 | slot(s7). 39 | slot(s8). 40 | slot(s9). 41 | slot(s10). 42 | 43 | name(o1). 44 | name(o2). 45 | name(o3). 46 | name(o4). 47 | name(o5). 48 | name(o6). 49 | name(o7). 50 | name(o8). 51 | name(o9). 52 | name(o10). 53 | ''' 54 | elif obj_num ==4: 55 | program =''' 56 | slot(s1). 57 | slot(s2). 58 | slot(s3). 59 | slot(s4). 60 | 61 | name(o1). 62 | name(o2). 63 | name(o3). 64 | name(o4). 65 | ''' 66 | elif obj_num ==6: 67 | program =''' 68 | slot(s1). 69 | slot(s2). 70 | slot(s3). 71 | slot(s4). 72 | slot(s5). 73 | slot(s6). 74 | 75 | name(o1). 76 | name(o2). 77 | name(o3). 78 | name(o4). 79 | name(o5). 80 | name(o6). 
81 |     ''' 82 |     else: 83 |         print(f'The number of objects {obj_num} is wrong!') 84 |         quit() 85 |     program +=''' 86 |     %assign each name a slot 87 |     %{slot_name_comb(N,X): slot(X)}=1 :- name(N). %problem: we have duplicated slots 88 |  89 |     %remove each model which has multiple slots assigned to the same name 90 |     %:- slot_name_comb(N1,X1), slot_name_comb(N2,X2), X1 == X2, N1 != N2. 91 |  92 |     %build the object on top of the slot assignment 93 |     object(N, S, M, P, C) :- size(0, +X, -S), material(0, +X, -M), shape(0, +X, -P), color(0, +X, -C), slot(X), name(N), slot_name_comb(N,X). 94 |  95 |     %define the SPNs 96 |     npp(size(1,X),[small, large, bg]) :- slot(X). 97 |     npp(material(1,X),[rubber, metal, bg]) :- slot(X). 98 |     npp(shape(1,X),[cube, sphere, cylinder, bg]) :- slot(X). 99 |     npp(color(1,X),[gray, red, blue, green, brown, purple, cyan, yellow, bg]) :- slot(X). 100 |  101 |     ''' 102 |     return program -------------------------------------------------------------------------------- /src/experiments/slash_attention/clevr_cogent/dataGen.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from torch.utils.data import Dataset 4 | from torchvision.transforms import transforms 5 | 6 | # from torch.utils.data import Dataset 7 | # from torchvision import transforms 8 | from skimage import io 9 | import os 10 | import numpy as np 11 | import torch 12 | import matplotlib.pyplot as plt 13 | from PIL import ImageFile 14 | ImageFile.LOAD_TRUNCATED_IMAGES = True 15 | import json 16 | 17 | from tqdm import tqdm 18 | 19 | 20 | def class_lookup(idx): 21 |     ''' 22 |     Shapeworld dataset generation generates a class id for each class. Returns color, shape and object encoding given the class id. 23 |     The class order was generated by ShapeWorld.
24 | ''' 25 | 26 | #obj encoding ['red','blue','green','black', 'circle' , 'triangle','square', 'bg', 'object_confidence(1)'] 27 | 28 | if idx == 0: 29 | return "red", "circle", [1,0,0,0 ,1,0,0,0 , 1] 30 | elif idx ==1: 31 | return "green", "circle" ,[0,0,1,0 ,1,0,0,0 , 1] 32 | elif idx ==2: 33 | return "blue", "circle", [0,1,0,0 ,1,0,0,0 , 1] 34 | elif idx ==3: 35 | return "red", "square", [1,0,0,0 ,0,0,1,0 , 1] 36 | elif idx ==4: 37 | return "green", "square", [0,0,1,0 ,0,0,1,0 , 1] 38 | elif idx ==5: 39 | return "blue", "square" , [0,1,0,0 ,0,0,1,0 , 1] 40 | elif idx ==6: 41 | return "red", "triangle", [1,0,0,0 ,0,1,0,0 , 1] 42 | elif idx ==7: 43 | return "green", "triangle", [0,0,1,0 ,0,1,0,0 , 1] 44 | elif idx ==8: 45 | return "blue", "triangle", [0,1,0,0 ,0,1,0,0 , 1] 46 | 47 | def object_encoding(size, material, shape, color ): 48 | 49 | #size (small, large, bg) 50 | if size == "small": 51 | size_enc = [1,0,0] 52 | elif size == "large": 53 | size_enc = [0,1,0] 54 | elif size == "bg": 55 | size_enc = [0,0,1] 56 | 57 | #material (rubber, metal, bg) 58 | if material == "rubber": 59 | material_enc = [1,0,0] 60 | elif material == "metal": 61 | material_enc = [0,1,0] 62 | elif material == "bg": 63 | material_enc = [0,0,1] 64 | 65 | #shape (cube, sphere, cylinder, bg) 66 | if shape == "cube": 67 | shape_enc = [1,0,0,0] 68 | elif shape == "sphere": 69 | shape_enc = [0,1,0,0] 70 | elif shape == "cylinder": 71 | shape_enc = [0,0,1,0] 72 | elif shape == "bg": 73 | shape_enc = [0,0,0,1] 74 | 75 | #color (gray, red, blue, green, brown, purple, cyan, yellow, bg) 76 | #color (gray, red, blue, green, brown, purple, cyan, yellow, bg) 77 | #color 1 2 3 4 5 6 7 8 9 78 | if color == "gray": 79 | color_enc = [1,0,0,0,0,0,0,0,0] 80 | elif color == "red": 81 | color_enc = [0,1,0,0,0,0,0,0,0] 82 | elif color == "blue": 83 | color_enc = [0,0,1,0,0,0,0,0,0] 84 | elif color == "green": 85 | color_enc = [0,0,0,1,0,0,0,0,0] 86 | elif color == "brown": 87 | color_enc = [0,0,0,0,1,0,0,0,0] 88 | elif color == "purple": 89 | color_enc = [0,0,0,0,0,1,0,0,0] 90 | elif color == "cyan": 91 | color_enc = [0,0,0,0,0,0,1,0,0] 92 | elif color == "yellow": 93 | color_enc = [0,0,0,0,0,0,0,1,0] 94 | elif color == "bg": 95 | color_enc = [0,0,0,0,0,0,0,0,1] 96 | 97 | return size_enc + material_enc + shape_enc + color_enc +[1] 98 | 99 | 100 | 101 | 102 | 103 | class CLEVR(Dataset): 104 | def __init__(self, root, mode, img_paths=None, files_names=None, obj_num=None): 105 | self.root = root # The root folder of the dataset 106 | self.mode = mode # The mode of 'train' or 'val' 107 | self.files_names = files_names # The list of the files names with correct nuber of objects 108 | if obj_num is not None: 109 | self.obj_num = obj_num # The upper limit of number of objects 110 | else: 111 | self.obj_num = 10 112 | 113 | assert os.path.exists(root), 'Path {} does not exist'.format(root) 114 | 115 | #list of sorted image paths 116 | self.img_paths = [] 117 | if img_paths: 118 | self.img_paths = img_paths 119 | else: 120 | #open directory and save all image paths 121 | for file in os.scandir(os.path.join(root, 'images', mode)): 122 | img_path = file.path 123 | if '.png' in img_path: 124 | self.img_paths.append(img_path) 125 | 126 | self.img_paths.sort() 127 | count = 0 128 | 129 | #target maps of the form {'target:idx': query string} or {'target:idx': obj encoding} 130 | self.query_map = {} 131 | self.obj_map = {} 132 | 133 | count = 0 134 | #We have up to 10 objects in the image, load the json file 135 | with open(os.path.join(root, 
'scenes','CLEVR_'+ mode+"_scenes.json")) as f: 136 | data = json.load(f) 137 | 138 | #iterate over each scene and create the query string and obj encoding 139 | print("parsing scences") 140 | for scene in tqdm(data['scenes']): 141 | target_query = "" 142 | obj_encoding_list = [] 143 | 144 | if self.files_names: 145 | if any(scene['image_filename'] in file_name for file_name in files_names): 146 | num_objects = 0 147 | for idx, obj in enumerate(scene['objects']): 148 | target_query += " :- not object(o{}, {}, {}, {}, {}).".format(idx+1, obj['size'], obj['material'], obj['shape'], obj['color']) 149 | obj_encoding_list.append(object_encoding(obj['size'], obj['material'], obj['shape'], obj['color'])) 150 | num_objects = idx+1 #store the num of objects 151 | #fill in background objects 152 | for idx in range(num_objects, self.obj_num): 153 | target_query += " :- not object(o{}, bg, bg, bg, bg).".format(idx+1) 154 | obj_encoding_list.append([0,0,1, 0,0,1, 0,0,0,1, 0,0,0,0,0,0,0,0,1, 1]) 155 | self.query_map[count] = target_query 156 | self.obj_map[count] = np.array(obj_encoding_list) 157 | count += 1 158 | else: 159 | num_objects=0 160 | for idx, obj in enumerate(scene['objects']): 161 | target_query += " :- not object(o{}, {}, {}, {}, {}).".format(idx+1, obj['size'], obj['material'], obj['shape'], obj['color']) 162 | obj_encoding_list.append(object_encoding(obj['size'], obj['material'], obj['shape'], obj['color'])) 163 | num_objects = idx+1 #store the num of objects 164 | #fill in background objects 165 | for idx in range(num_objects, 10): 166 | target_query += " :- not object(o{}, bg, bg, bg, bg).".format(idx+1) 167 | obj_encoding_list.append([0,0,1, 0,0,1, 0,0,0,1, 0,0,0,0,0,0,0,0,1, 1]) 168 | self.query_map[scene['image_index']] = target_query 169 | self.obj_map[scene['image_index']] = np.array(obj_encoding_list) 170 | 171 | print("done") 172 | if self.files_names: 173 | print(f'Correctly found images {count} out of {len(files_names)}') 174 | 175 | 176 | #print(np.array(list(self.obj_map.values()))[0:20]) 177 | def __getitem__(self, index): 178 | #get the image 179 | img_path = self.img_paths[index] 180 | img = io.imread(img_path)[:, :, :3] 181 | 182 | transform = transforms.Compose([ 183 | transforms.ToPILImage(), 184 | #transforms.CenterCrop((29, 221,64, 256)), #why do we need to crop? 185 | transforms.Resize((128, 128)), 186 | transforms.ToTensor(), 187 | ]) 188 | 189 | img = transform(img) 190 | img = (img - 0.5) * 2.0 # Rescale to [-1, 1]. 
191 | 192 | return {'im':img}, self.query_map[index] ,self.obj_map[index] 193 | 194 | 195 | 196 | def __len__(self): 197 | return len(self.img_paths) 198 | 199 | -------------------------------------------------------------------------------- /src/experiments/slash_attention/clevr_cogent/slash_attention_clevr.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | 5 | 6 | import train 7 | 8 | import datetime 9 | 10 | 11 | 12 | 13 | 14 | seed = 4 15 | obj_num = 10 16 | date_string = datetime.datetime.today().strftime('%d-%m-%Y') 17 | experiments = {f'CLEVR{obj_num}_seed_{seed}': {'structure':'poon-domingos', 'pd_num_pieces':[4], 'learn_prior':True, 18 | 'lr': 0.01, 'bs':512, 'epochs':1000, 19 | 'lr_warmup_steps':8, 'lr_decay_steps':360, 'use_em':False, 'resume':False, 20 | 'method':'most_prob', 21 | 'start_date':date_string, 'credentials':'DO', 'p_num':16, 'seed':seed, 'obj_num':obj_num 22 | }} 23 | 24 | 25 | for exp_name in experiments: 26 | print(exp_name) 27 | train.slash_slot_attention(exp_name, experiments[exp_name]) 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /src/experiments/slash_attention/cogent/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/SLASH/79d158dbd670d56d8bf9043640635c24a18a9c57/src/experiments/slash_attention/cogent/__init__.py -------------------------------------------------------------------------------- /src/experiments/slash_attention/cogent/dataGen.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from torch.utils.data import Dataset 4 | from torchvision.transforms import transforms 5 | 6 | from torch.utils.data import Dataset 7 | from torchvision import transforms 8 | from skimage import io 9 | import os 10 | import numpy as np 11 | import torch 12 | from PIL import ImageFile 13 | ImageFile.LOAD_TRUNCATED_IMAGES = True 14 | import json 15 | import datasets 16 | 17 | 18 | def get_encoding(color, shape, shade, size): 19 | 20 | if color == 'red': 21 | col_enc = [1,0,0,0,0,0,0,0,0] 22 | elif color == 'blue': 23 | col_enc = [0,1,0,0,0,0,0,0,0] 24 | elif color == 'green': 25 | col_enc = [0,0,1,0,0,0,0,0,0] 26 | elif color == 'gray': 27 | col_enc = [0,0,0,1,0,0,0,0,0] 28 | elif color == 'brown': 29 | col_enc = [0,0,0,0,1,0,0,0,0] 30 | elif color == 'magenta': 31 | col_enc = [0,0,0,0,0,1,0,0,0] 32 | elif color == 'cyan': 33 | col_enc = [0,0,0,0,0,0,1,0,0] 34 | elif color == 'yellow': 35 | col_enc = [0,0,0,0,0,0,0,1,0] 36 | elif color == 'black': 37 | col_enc = [0,0,0,0,0,0,0,0,1] 38 | 39 | 40 | if shape == 'circle': 41 | shape_enc = [1,0,0,0] 42 | elif shape == 'triangle': 43 | shape_enc = [0,1,0,0] 44 | elif shape == 'square': 45 | shape_enc = [0,0,1,0] 46 | elif shape == 'bg': 47 | shape_enc = [0,0,0,1] 48 | 49 | if shade == 'bright': 50 | shade_enc = [1,0,0] 51 | elif shade =='dark': 52 | shade_enc = [0,1,0] 53 | elif shade == 'bg': 54 | shade_enc = [0,0,1] 55 | 56 | 57 | if size == 'small': 58 | size_enc = [1,0,0] 59 | elif size == 'big': 60 | size_enc = [0,1,0] 61 | elif size == 'bg': 62 | size_enc = [0,0,1] 63 | 64 | return col_enc + shape_enc + shade_enc + size_enc + [1] 65 | 66 | 67 | class SHAPEWORLD_COGENT(Dataset): 68 | def __init__(self, root, mode, ret_obj_encoding=False): 69 | 70 | datasets.maybe_download_shapeworld_cogent() 71 | 72 | 
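        # The loader below expects the ShapeWorld CoGenT data to live under
        # root/images/<mode>/*.png and root/labels/<mode>/world_model.json; the JSON world
        # models are parsed into per-image ASP queries (query_map) and ground-truth object
        # encodings (obj_map).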
self.ret_obj_encoding= ret_obj_encoding 73 | self.root = root 74 | self.mode = mode 75 | assert os.path.exists(root), 'Path {} does not exist'.format(root) 76 | 77 | #dictionary of the form {'image_idx':'img_path'} 78 | self.img_paths = {} 79 | 80 | 81 | for file in os.scandir(os.path.join(root, 'images', mode)): 82 | img_path = file.path 83 | 84 | img_path_idx = img_path.split("/") 85 | img_path_idx = img_path_idx[-1] 86 | img_path_idx = img_path_idx[:-4][6:] 87 | try: 88 | img_path_idx = int(img_path_idx) 89 | self.img_paths[img_path_idx] = img_path 90 | except: 91 | print("path:",img_path_idx) 92 | 93 | 94 | 95 | count = 0 96 | 97 | #target maps of the form {'target:idx': query string} or {'target:idx': obj encoding} 98 | self.query_map = {} 99 | self.obj_map = {} 100 | 101 | with open(os.path.join(root, 'labels', mode,"world_model.json")) as f: 102 | worlds = json.load(f) 103 | 104 | objects = 0 105 | bgs = 0 106 | #iterate over all objects 107 | for world in worlds: 108 | num_objects = 0 109 | target_query = "" 110 | obj_enc = [] 111 | for entity in world['entities']: 112 | 113 | color = entity['color']['name'] 114 | shape = entity['shape']['name'] 115 | 116 | shade_val = entity['color']['shade'] 117 | if shade_val == 0.0: 118 | shade = 'bright' 119 | else: 120 | shade = 'dark' 121 | 122 | size_val = entity['shape']['size']['x'] 123 | if size_val == 0.075: 124 | size = 'small' 125 | elif size_val == 0.15: 126 | size = 'big' 127 | 128 | name = 'o' + str(num_objects+1) 129 | target_query = target_query+ ":- not object({},{},{},{},{}). ".format(name, color, shape, shade, size) 130 | obj_enc.append(get_encoding(color, shape, shade, size)) 131 | num_objects += 1 132 | objects +=1 133 | 134 | #bg encodings 135 | for i in range(num_objects, 4): 136 | name = 'o' + str(num_objects+1) 137 | target_query = target_query+ ":- not object({},black,bg, bg, bg). ".format(name) 138 | obj_enc.append(get_encoding("black","bg","bg","bg")) 139 | num_objects += 1 140 | bgs +=1 141 | 142 | self.query_map[count] = target_query 143 | self.obj_map[count] = np.array(obj_enc) 144 | count+=1 145 | print("num objects",objects) 146 | print("num bgs",bgs) 147 | 148 | 149 | 150 | def __getitem__(self, index): 151 | 152 | #get the image 153 | img_path = self.img_paths[index] 154 | img = io.imread(img_path)[:, :, :3] 155 | 156 | transform = transforms.Compose([ 157 | transforms.ToPILImage(), 158 | transforms.ToTensor(), 159 | ]) 160 | img = transform(img) 161 | img = (img - 0.5) * 2.0 # Rescale to [-1, 1]. 
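        # The rows of obj_map built above (via get_encoding) are 20-dimensional:
        # color(9) + shape(4) + shade(3) + size(3) + a trailing constant 1, e.g.
        # get_encoding('red', 'circle', 'bright', 'small') ->
        #   [1,0,0,0,0,0,0,0,0, 1,0,0,0, 1,0,0, 1,0,0, 1].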
162 | 163 | if self.ret_obj_encoding: 164 | return {'im':img}, self.query_map[index] ,self.obj_map[index] 165 | else: 166 | return {'im':img}, self.query_map[index] 167 | def __len__(self): 168 | return len(self.img_paths) 169 | 170 | -------------------------------------------------------------------------------- /src/experiments/slash_attention/cogent/slash_attention_cogent.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | 5 | import train 6 | import numpy as np 7 | import torch 8 | import torchvision 9 | import datetime 10 | 11 | date_string = datetime.datetime.today().strftime('%d-%m-%Y') 12 | 13 | #Python script to start the shapeworld4 slot attention experiment 14 | #Define your experiment(s) parameters as a hashmap having the following parameters 15 | example_structure = {'experiment_name': 16 | {'structure': 'poon-domingos', 17 | 'pd_num_pieces': [4], 18 | 'lr': 0.01, #the learning rate to train the SPNs with, the slot attention module has a fixed lr=0.0004 19 | 'bs':512, #the batchsize 20 | 'epochs':1000, #number of epochs to train 21 | 'lr_warmup_steps':25, #number of epochs to warm up the slot attention module, warmup does not apply to the SPNs 22 | 'lr_decay_steps':100, #number of epochs it takes to decay to 50% of the specified lr 23 | 'start_date':"01-01-0001", #current date 24 | 'resume':False, #you can stop the experiment and set this parameter to true to load the last state and continue learning 25 | 'credentials':'AS', #your credentials for the rtpt class 26 | 'explanation': """Training on Condtion A, Testing on Condtion A and B to evaluate generalization of the model."""}} 27 | 28 | 29 | 30 | experiments ={'shapeworld4_cogent_hung': 31 | {'structure': 'poon-domingos', 'pd_num_pieces': [4], 32 | 'lr': 0.01, 'bs':512, 'epochs':1000, 33 | 'lr_warmup_steps':8, 'lr_decay_steps':360, 34 | 'start_date':date_string, 'resume':False, 35 | 'credentials':'DO', 'seed':3, 'learn_prior':True, 36 | 'p_num':16, 'hungarian_matching':True, 'method':'probabilistic_grounding_top_k', 37 | 'explanation': """Training on Condtion A, Testing on Condtion A and B to evaluate generalization of the model."""}} 38 | 39 | 40 | 41 | #train the network 42 | for exp_name in experiments: 43 | print(exp_name) 44 | train.slash_slot_attention(exp_name, experiments[exp_name]) 45 | 46 | 47 | 48 | 49 | 50 | -------------------------------------------------------------------------------- /src/experiments/slash_attention/shapeworld4/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/SLASH/79d158dbd670d56d8bf9043640635c24a18a9c57/src/experiments/slash_attention/shapeworld4/__init__.py -------------------------------------------------------------------------------- /src/experiments/slash_attention/shapeworld4/dataGen.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from torch.utils.data import Dataset 4 | from torchvision.transforms import transforms 5 | 6 | from torch.utils.data import Dataset 7 | from torchvision import transforms 8 | from skimage import io 9 | import os 10 | import numpy as np 11 | import torch 12 | from PIL import ImageFile 13 | ImageFile.LOAD_TRUNCATED_IMAGES = True 14 | import json 15 | import datasets 16 | 17 | 18 | def get_encoding(color, shape, shade, size): 19 | 20 | if color == 'red': 21 | col_enc = [1,0,0,0,0,0,0,0,0] 22 | elif color 
== 'blue': 23 | col_enc = [0,1,0,0,0,0,0,0,0] 24 | elif color == 'green': 25 | col_enc = [0,0,1,0,0,0,0,0,0] 26 | elif color == 'gray': 27 | col_enc = [0,0,0,1,0,0,0,0,0] 28 | elif color == 'brown': 29 | col_enc = [0,0,0,0,1,0,0,0,0] 30 | elif color == 'magenta': 31 | col_enc = [0,0,0,0,0,1,0,0,0] 32 | elif color == 'cyan': 33 | col_enc = [0,0,0,0,0,0,1,0,0] 34 | elif color == 'yellow': 35 | col_enc = [0,0,0,0,0,0,0,1,0] 36 | elif color == 'black': 37 | col_enc = [0,0,0,0,0,0,0,0,1] 38 | 39 | 40 | if shape == 'circle': 41 | shape_enc = [1,0,0,0] 42 | elif shape == 'triangle': 43 | shape_enc = [0,1,0,0] 44 | elif shape == 'square': 45 | shape_enc = [0,0,1,0] 46 | elif shape == 'bg': 47 | shape_enc = [0,0,0,1] 48 | 49 | if shade == 'bright': 50 | shade_enc = [1,0,0] 51 | elif shade =='dark': 52 | shade_enc = [0,1,0] 53 | elif shade == 'bg': 54 | shade_enc = [0,0,1] 55 | 56 | 57 | if size == 'small': 58 | size_enc = [1,0,0] 59 | elif size == 'big': 60 | size_enc = [0,1,0] 61 | elif size == 'bg': 62 | size_enc = [0,0,1] 63 | 64 | return col_enc + shape_enc + shade_enc + size_enc + [1] 65 | 66 | 67 | class SHAPEWORLD4(Dataset): 68 | def __init__(self, root, mode, ret_obj_encoding=False): 69 | 70 | datasets.maybe_download_shapeworld4() 71 | 72 | self.ret_obj_encoding = ret_obj_encoding 73 | self.root = root 74 | self.mode = mode 75 | assert os.path.exists(root), 'Path {} does not exist'.format(root) 76 | 77 | #dictionary of the form {'image_idx':'img_path'} 78 | self.img_paths = {} 79 | 80 | 81 | for file in os.scandir(os.path.join(root, 'images', mode)): 82 | img_path = file.path 83 | 84 | img_path_idx = img_path.split("/") 85 | img_path_idx = img_path_idx[-1] 86 | img_path_idx = img_path_idx[:-4][6:] 87 | try: 88 | img_path_idx = int(img_path_idx) 89 | self.img_paths[img_path_idx] = img_path 90 | except: 91 | print("path:",img_path_idx) 92 | 93 | 94 | count = 0 95 | 96 | #target maps of the form {'target:idx': query string} or {'target:idx': obj encoding} 97 | self.query_map = {} 98 | self.obj_map = {} 99 | 100 | with open(os.path.join(root, 'labels', mode,"world_model.json")) as f: 101 | worlds = json.load(f) 102 | 103 | #iterate over all objects 104 | for world in worlds: 105 | num_objects = 0 106 | target_query = "" 107 | obj_enc = [] 108 | for entity in world['entities']: 109 | 110 | color = entity['color']['name'] 111 | shape = entity['shape']['name'] 112 | 113 | shade_val = entity['color']['shade'] 114 | if shade_val == 0.0: 115 | shade = 'bright' 116 | else: 117 | shade = 'dark' 118 | 119 | size_val = entity['shape']['size']['x'] 120 | if size_val == 0.075: 121 | size = 'small' 122 | elif size_val == 0.15: 123 | size = 'big' 124 | 125 | name = 'o' + str(num_objects+1) 126 | target_query = target_query+ ":- not object({},{},{},{},{}). ".format(name, color, shape, shade, size) 127 | obj_enc.append(get_encoding(color, shape, shade, size)) 128 | num_objects += 1 129 | 130 | #bg encodings 131 | for i in range(num_objects, 4): 132 | name = 'o' + str(num_objects+1) 133 | target_query = target_query+ ":- not object({},black,bg, bg, bg). 
".format(name) 134 | obj_enc.append(get_encoding("black","bg","bg","bg")) 135 | num_objects += 1 136 | 137 | 138 | self.query_map[count] = target_query 139 | self.obj_map[count] = np.array(obj_enc) 140 | count+=1 141 | 142 | 143 | 144 | def __getitem__(self, index): 145 | 146 | #get the image 147 | img_path = self.img_paths[index] 148 | img = io.imread(img_path)[:, :, :3] 149 | 150 | transform = transforms.Compose([ 151 | transforms.ToPILImage(), 152 | transforms.ToTensor(), 153 | ]) 154 | img = transform(img) 155 | img = (img - 0.5) * 2.0 # Rescale to [-1, 1]. 156 | 157 | if self.ret_obj_encoding: 158 | return {'im':img}, self.query_map[index] ,self.obj_map[index] 159 | else: 160 | return {'im':img}, self.query_map[index] 161 | 162 | def __len__(self): 163 | return len(self.img_paths) 164 | -------------------------------------------------------------------------------- /src/experiments/slash_attention/shapeworld4/slash_attention_shapeworld4.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | 5 | import train 6 | import datetime 7 | 8 | 9 | #Python script to start the shapeworld4 slot attention experiment 10 | #Define your experiment(s) parameters as a hashmap having the following parameters 11 | example_structure = {'experiment_name': 12 | {'structure': 'poon-domingos', 13 | 'pd_num_pieces': [4], 14 | 'lr': 0.01, #the learning rate to train the SPNs with, the slot attention module has a fixed lr=0.0004 15 | 'bs':50, #the batchsize 16 | 'epochs':1000, #number of epochs to train 17 | 'lr_warmup':True, #boolean indicating the use of learning rate warm up 18 | 'lr_warmup_steps':25, #number of epochs to warm up the slot attention module, warmup does not apply to the SPNs 19 | 'start_date':"01-01-0001", #current date 20 | 'resume':False, #you can stop the experiment and set this parameter to true to load the last state and continue learning 21 | 'credentials':'DO', #your credentials for the rtpt class 22 | 'hungarian_matching': True, 23 | 'explanation': """Running the whole SlotAttention+Slash pipeline using poon-domingos as SPN structure learner."""}} 24 | 25 | 26 | 27 | 28 | #EXPERIMENTS 29 | date_string = datetime.datetime.today().strftime('%d-%m-%Y') 30 | 31 | 32 | for seed in [0,1,2,3,4]: 33 | experiments = {'shapeworld4': 34 | {'structure': 'poon-domingos', 'pd_num_pieces': [4], 35 | 'lr': 0.01, 'bs':512, 'epochs':1000, 36 | 'lr_warmup_steps':8, 'lr_decay_steps':360, 37 | 'start_date':date_string, 'resume':False, 'credentials':'DO','seed':seed, 38 | 'p_num':16, 'method':'same_top_k', 'hungarian_matching': False, 39 | 'explanation': """Running the whole SlotAttention+Slash pipeline using poon-domingos as SPN structure learner."""} 40 | } 41 | 42 | 43 | print("shapeworld4") 44 | train.slash_slot_attention("shapeworld4", experiments["shapeworld4"]) 45 | 46 | 47 | 48 | 49 | 50 | -------------------------------------------------------------------------------- /src/experiments/vqa/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/SLASH/79d158dbd670d56d8bf9043640635c24a18a9c57/src/experiments/vqa/__init__.py -------------------------------------------------------------------------------- /src/experiments/vqa/cmd_args2.py: -------------------------------------------------------------------------------- 1 | """ 2 | The source code is based on: 3 | Scallop: From Probabilistic Deductive Databases to Scalable Differentiable 
Reasoning 4 | Jiani Huang, Ziyang Li, Binghong Chen, Karan Samel, Mayur Naik, Le Song, Xujie Si 5 | Advances in Neural Information Processing Systems 34 (NeurIPS 2021) 6 | https://proceedings.neurips.cc/paper/2021/hash/d367eef13f90793bd8121e2f675f0dc2-Abstract.html 7 | """ 8 | 9 | 10 | import argparse 11 | import sys 12 | import random 13 | import numpy as np 14 | import torch 15 | import os 16 | import logging 17 | 18 | # Utility function for take in yes/no and convert it to boolean 19 | def convert_str_to_bool(cmd_args): 20 | for key, val in vars(cmd_args).items(): 21 | if val == "yes": 22 | setattr(cmd_args, key, True) 23 | elif val == "no": 24 | setattr(cmd_args, key, False) 25 | 26 | def str2bool(v): 27 | if isinstance(v, bool): 28 | return v 29 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 30 | return True 31 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 32 | return False 33 | else: 34 | raise argparse.ArgumentTypeError('Boolean value expected.') 35 | 36 | class LearningSetting(object): 37 | 38 | def __init__(self): 39 | data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../data")) 40 | dataset_dir = os.path.join(data_dir, "dataset") 41 | knowledge_base_dir = os.path.join(data_dir, "knowledge_base") 42 | 43 | self.parser = argparse.ArgumentParser(description='Argparser', allow_abbrev=True) 44 | # Logistics 45 | self.parser.add_argument("--seed", default=1234, type=int, help="set random seed") 46 | self.parser.add_argument("--gpu", default=0, type=int, help="GPU id") 47 | self.parser.add_argument('--timeout', default=600, type=int, help="Execution timeout") 48 | 49 | # Learning settings 50 | self.parser.add_argument("--name_threshold", default=0, type=float) 51 | self.parser.add_argument("--attr_threshold", default=0, type=float) 52 | self.parser.add_argument("--rela_threshold", default=0, type=float) 53 | # self.parser.add_argument("--name_topk_n", default=-1, type=int) 54 | # self.parser.add_argument("--attr_topk_n", default=-1, type=int) 55 | # self.parser.add_argument("--rela_topk_n", default=80, type=int) 56 | self.parser.add_argument("--topk", default=3, type=int) 57 | 58 | # training settings 59 | self.parser.add_argument('--feat_dim', type=int, default=2048) 60 | self.parser.add_argument('--n_epochs', type=int, default=20) 61 | self.parser.add_argument('--batch_size', type=int, default=1) 62 | self.parser.add_argument('--max_workers', type=int, default=1) 63 | self.parser.add_argument('--axiom_update_size', type=int, default=4) 64 | self.parser.add_argument('--name_lr', type=float, default=0.0001) 65 | self.parser.add_argument('--attr_lr', type=float, default=0.0001) 66 | self.parser.add_argument('--rela_lr', type=float, default=0.0001) 67 | self.parser.add_argument('--reinforce', type=str2bool, nargs='?', const=True, default=False) # reinforce only support single thread 68 | self.parser.add_argument('--replays', type=int, default=5) 69 | 70 | self.parser.add_argument('--model_dir', default=data_dir+'/model_ckpts_sg') 71 | # self.parser.add_argument('--model_dir', default=None) 72 | self.parser.add_argument('--log_name', default='model.log') 73 | self.parser.add_argument('--feat_f', default=data_dir+'/features.npy') 74 | self.parser.add_argument('--train_f', default=dataset_dir+'/task_list/train_tasks_c2_10.pkl') 75 | self.parser.add_argument('--val_f', default=dataset_dir+'/task_list/val_tasks.pkl') 76 | self.parser.add_argument('--test_f', default=dataset_dir+'/task_list/test_tasks_c2_1000.pkl') 77 | self.parser.add_argument('--cul_prov', 
type=bool, default=False) 78 | 79 | self.parser.add_argument('--meta_f', default=data_dir+'/gqa_info.json') 80 | self.parser.add_argument('--scene_f', default=data_dir+'/gqa_formatted_scene_graph.pkl') 81 | self.parser.add_argument('--image_data_f', default=data_dir+'/image_data.json') 82 | self.parser.add_argument('--dataset_type', default='name') # name, relation, attr: 83 | 84 | self.parser.add_argument('--function', default=None) #KG_Find / Hypernym_Find / Find_Name / Find_Attr / Relate / Relate_Reverse / And / Or 85 | self.parser.add_argument('--knowledge_base_dir', default=knowledge_base_dir) 86 | # self.parser.add_argument('--interp_size', type=int, default=2) 87 | 88 | self.parser.add_argument('--save_dir', default=data_dir+"/problog_data") 89 | self.args = self.parser.parse_args(sys.argv[1:]) 90 | 91 | ls = LearningSetting() 92 | cmd_args = ls.args 93 | # print(cmd_args) 94 | 95 | # Fix random seed for debugging purpose 96 | if (ls.args.seed != None): 97 | random.seed(ls.args.seed) 98 | np.random.seed(ls.args.seed) 99 | torch.manual_seed(ls.args.seed) 100 | 101 | if not type(cmd_args.gpu) == None: 102 | os.environ['CUDA_VISIBLE_DEVICES'] = str(cmd_args.gpu) 103 | torch.set_default_tensor_type(torch.cuda.FloatTensor) 104 | 105 | if cmd_args.model_dir is not None: 106 | if not os.path.exists(cmd_args.model_dir): 107 | os.makedirs(cmd_args.model_dir) 108 | 109 | log_path = os.path.join(cmd_args.model_dir, cmd_args.log_name) 110 | logging.basicConfig(filename=log_path, filemode='w', level=logging.INFO, format='%(name)s - %(levelname)s - %(message)s') 111 | logging.info(cmd_args) 112 | logging.info("start!") 113 | -------------------------------------------------------------------------------- /src/experiments/vqa/models.py: -------------------------------------------------------------------------------- 1 | """ 2 | The source code is based on: 3 | Scallop: From Probabilistic Deductive Databases to Scalable Differentiable Reasoning 4 | Jiani Huang, Ziyang Li, Binghong Chen, Karan Samel, Mayur Naik, Le Song, Xujie Si 5 | Advances in Neural Information Processing Systems 34 (NeurIPS 2021) 6 | https://proceedings.neurips.cc/paper/2021/hash/d367eef13f90793bd8121e2f675f0dc2-Abstract.html 7 | """ 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import os 14 | import sys 15 | torch.autograd.set_detect_anomaly(True) 16 | 17 | sys.path.append('../../EinsumNetworks/src/') 18 | from einsum_wrapper import EiNet 19 | 20 | 21 | class MLPClassifier(nn.Module): 22 | def __init__(self, input_dim, latent_dim, output_dim, n_layers, dropout_rate, softmax): 23 | super(MLPClassifier, self).__init__() 24 | 25 | self.softmax = softmax 26 | self.output_dim = output_dim 27 | 28 | layers = [] 29 | layers.append(nn.Linear(input_dim, latent_dim)) 30 | layers.append(nn.ReLU()) 31 | layers.append(nn.BatchNorm1d(latent_dim)) 32 | #layers.append(nn.InstanceNorm1d(latent_dim)) 33 | layers.append(nn.Dropout(dropout_rate)) 34 | for _ in range(n_layers - 1): 35 | layers.append(nn.Linear(latent_dim, latent_dim)) 36 | layers.append(nn.ReLU()) 37 | layers.append(nn.BatchNorm1d(latent_dim)) 38 | #layers.append(nn.InstanceNorm1d(latent_dim)) 39 | layers.append(nn.Dropout(dropout_rate)) 40 | layers.append(nn.Linear(latent_dim, output_dim)) 41 | 42 | 43 | self.net = nn.Sequential(*layers) 44 | 45 | def forward(self, x, marg_idx=None, type=None): 46 | if x.sum() == 0: 47 | return torch.ones([x.shape[0], self.output_dim], device='cuda') 48 | 49 | idx = 
x.sum(dim=1)!=0 # get idx of true objects 50 | logits = torch.zeros(x.shape[0], self.output_dim, device='cuda') 51 | 52 | logits[idx] = self.net(x[idx]) 53 | 54 | 55 | if self.softmax: 56 | probs = F.softmax(logits, dim=1) 57 | else: 58 | probs = torch.sigmoid(logits) 59 | 60 | return probs 61 | 62 | # FasterCNN object feature size 63 | feature_dim = 2048 64 | 65 | name_clf = MLPClassifier( 66 | input_dim=feature_dim, 67 | output_dim=500, 68 | latent_dim=1024, 69 | n_layers=2, 70 | dropout_rate=0.3, 71 | softmax=True 72 | ) 73 | 74 | rela_clf = MLPClassifier( 75 | input_dim=(feature_dim+4)*2, 76 | output_dim=229, 77 | latent_dim=1024, 78 | n_layers=1, 79 | dropout_rate=0.5, 80 | softmax=True 81 | ) 82 | 83 | 84 | attr_clf = MLPClassifier( 85 | input_dim=feature_dim, 86 | output_dim=609, 87 | latent_dim=1024, 88 | n_layers=1, 89 | dropout_rate=0.3, 90 | softmax=False 91 | ) 92 | 93 | 94 | 95 | name_einet = EiNet(structure = 'poon-domingos', 96 | pd_num_pieces = [4], 97 | use_em = False, 98 | num_var = 2048, 99 | class_count = 500, 100 | pd_width = 32, 101 | pd_height = 64, 102 | learn_prior = True) 103 | rela_einet = EiNet(structure = 'poon-domingos', 104 | pd_num_pieces = [4], 105 | use_em = False, 106 | num_var = 4104, 107 | class_count = 229, 108 | pd_width = 72, 109 | pd_height = 57, 110 | learn_prior = True) 111 | 112 | attr_einet = EiNet(structure = 'poon-domingos', 113 | pd_num_pieces = [4], 114 | use_em = False, 115 | num_var = 2048, 116 | class_count = 609, 117 | pd_width = 32, 118 | pd_height = 64, 119 | learn_prior = True) -------------------------------------------------------------------------------- /src/experiments/vqa/network_nn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class Net_nn(nn.Module): 5 | def __init__(self): 6 | super(Net_nn, self).__init__() 7 | self.encoder = nn.Sequential( 8 | nn.Conv2d(1, 6, 5), # 6 is the output chanel size; 5 is the kernal size; 1 (chanel) 28 28 -> 6 24 24 9 | nn.MaxPool2d(2, 2), # kernal size 2; stride size 2; 6 24 24 -> 6 12 12 10 | nn.ReLU(True), # inplace=True means that it will modify the input directly thus save memory 11 | nn.Conv2d(6, 16, 5), # 6 12 12 -> 16 8 8 12 | nn.MaxPool2d(2, 2), # 16 8 8 -> 16 4 4 13 | nn.ReLU(True) 14 | ) 15 | self.classifier = nn.Sequential( 16 | nn.Linear(16 * 4 * 4, 120), 17 | nn.ReLU(), 18 | nn.Linear(120, 84), 19 | nn.ReLU(), 20 | nn.Linear(84, 10), 21 | nn.Softmax(1) 22 | ) 23 | 24 | def forward(self, x, marg_idx=None, type=1): 25 | 26 | assert type == 1, "only posterior computations are available for this network" 27 | 28 | # If the list of the pixel numbers to be marginalised is given, 29 | # then genarate a marginalisation mask from it and apply to the 30 | # tensor 'x' 31 | if marg_idx: 32 | batch_size = x.shape[0] 33 | with torch.no_grad(): 34 | marg_mask = torch.ones_like(x, device=x.device).reshape(batch_size, 1, -1) 35 | marg_mask[:, :, marg_idx] = 0 36 | marg_mask = marg_mask.reshape_as(x) 37 | marg_mask.requires_grad_(False) 38 | x = torch.einsum('ijkl,ijkl->ijkl', x, marg_mask) 39 | x = self.encoder(x) 40 | x = x.view(-1, 16 * 4 * 4) 41 | x = self.classifier(x) 42 | return x 43 | -------------------------------------------------------------------------------- /src/experiments/vqa/preprocess.py: -------------------------------------------------------------------------------- 1 | """ 2 | The source code is based on: 3 | Scallop: From Probabilistic Deductive Databases to Scalable Differentiable 
Reasoning 4 | Jiani Huang, Ziyang Li, Binghong Chen, Karan Samel, Mayur Naik, Le Song, Xujie Si 5 | Advances in Neural Information Processing Systems 34 (NeurIPS 2021) 6 | https://proceedings.neurips.cc/paper/2021/hash/d367eef13f90793bd8121e2f675f0dc2-Abstract.html 7 | """ 8 | 9 | # Prog = Conj | Logic 10 | # Logic = biOp Prog Prog 11 | # Conj = and Des Conj | Des 12 | # Des = rela Relaname O1 O2 | attr Attr O 13 | 14 | class Variable(): 15 | def __init__(self, id): 16 | self.var_id = f"O{id}" 17 | self.name_id = f"N{id}" 18 | self.name = [] 19 | self.hypernyms = [] 20 | self.attrs = [] 21 | self.kgs = [] 22 | # The relations where this object functions as a subject 23 | self.sub_relas = [] 24 | self.obj_relas = [] 25 | 26 | def has_rela(self): 27 | if len(self.sub_relas) == 0 and len(self.obj_relas) == 0: 28 | return False 29 | return True 30 | 31 | def get_name_id(self): 32 | # if (not len(self.hypernyms) == 0) and len(self.name) == 0: 33 | # return True, self.name_id 34 | # if (not len(self.kgs) == 0) and len(self.name) == 0: 35 | # return True, self.name_id 36 | return False, self.name_id 37 | 38 | def set_name(self, name): 39 | if name in self.name: 40 | return 41 | 42 | self.name.append(name) 43 | 44 | def set_kg(self, kg): 45 | if kg not in self.kgs: 46 | self.kgs.append(kg) 47 | 48 | def set_hypernym(self, hypernym): 49 | if hypernym not in self.hypernyms: 50 | self.hypernyms.append(hypernym) 51 | 52 | def set_attr(self, attr): 53 | if attr not in self.attrs: 54 | self.attrs.append(attr) 55 | 56 | def set_obj_relas(self, obj_rela): 57 | if obj_rela not in self.obj_relas: 58 | self.obj_relas.append(obj_rela) 59 | 60 | def set_sub_relas(self, sub_rela): 61 | if sub_rela not in self.sub_relas: 62 | self.sub_relas.append(sub_rela) 63 | 64 | def get_neighbor(self): 65 | neighbors = [] 66 | for rela in self.sub_relas: 67 | neighbors.append(rela.obj) 68 | for rela in self.obj_relas: 69 | neighbors.append(rela.sub) 70 | return neighbors 71 | 72 | def update(self, other): 73 | 74 | self.hypernyms = list(set(self.name + other.name)) 75 | self.hypernyms = list(set(self.hypernyms + other.hypernyms)) 76 | self.attrs = list(set(self.attrs + other.attrs)) 77 | self.kgs = list(set(self.kgs + other.kgs)) 78 | 79 | 80 | def to_datalog(self, with_name=False, with_rela=True): 81 | 82 | name_query = [] 83 | 84 | if (len(self.name) == 0) and with_name: 85 | name_query.append("name({}, {})".format(self.name_id.replace(" ","_").replace(".","_"), self.var_id.replace(" ","_").replace(".","_"))) 86 | 87 | if (not len(self.name) == 0): 88 | for n in self.name: 89 | #name_query.append(f"name(\"{n}\", {self.var_id})") 90 | n = n.replace(" ","_").replace(".","_") 91 | #name_query.append(f"name(0,1,{self.var_id},{n})") 92 | name_query.append(f"name({self.var_id},{n})") 93 | 94 | 95 | #attr_query = [ f"attr(\"{attr}\", {self.var_id})" for attr in self.attrs] 96 | #attr_query = [ "attr(0,1,{}, {})".format(self.var_id.replace(" ","_"),attr.replace(" ","_")) for attr in self.attrs] 97 | attr_query = [ "attr({}, {})".format(self.var_id.replace(" ","_"),attr.replace(" ","_")) for attr in self.attrs] 98 | 99 | #hypernym_query = [f"name(\"{hypernym}\", {self.var_id})" for hypernym in self.hypernyms] 100 | #hypernym_query = ["name(0,1,{}, {})".format(self.var_id.replace(" ","_").replace(".","_"),hypernym.replace(" ","_").replace(".","_")) for hypernym in self.hypernyms] 101 | hypernym_query = ["name({}, {})".format(self.var_id.replace(" ","_").replace(".","_"),hypernym.replace(" ","_").replace(".","_")) for hypernym in 
self.hypernyms] 102 | 103 | kg_query = [] 104 | 105 | for kg in self.kgs: 106 | restriction = list(filter(lambda x: not x == 'BLANK' and not x == '', kg)) 107 | assert (len(restriction) == 2) 108 | rel = restriction[0].replace(" ","_") 109 | usage = restriction[1].replace(" ","_") 110 | #kg_query += [f"name({self.name_id}, {self.var_id}), oa_rel({rel}, {self.name_id}, {usage})"] 111 | #kg_query += ["name({}, {}), oa_rel({}, {}, {})".format(self.name_id.replace(" ","_").replace(".","_") ,self.var_id.replace(" ","_").replace(".","_") ,rel.replace(" ","_").replace(".","_"), self.name_id.replace(" ","_").replace(".","_"), usage.replace(" ","_").replace(".","_"))] 112 | #kg_query += ["name(0,1,{}, {}), oa_rel({}, {}, {})".format(self.var_id.replace(" ","_").replace(".","_") ,self.name_id.replace(" ","_").replace(".","_"), rel.replace(" ","_").replace(".","_"), self.name_id.replace(" ","_").replace(".","_"), usage.replace(" ","_").replace(".","_"))] 113 | kg_query += ["name({}, {}), oa_rel({}, {}, {})".format(self.var_id.replace(" ","_").replace(".","_") ,self.name_id.replace(" ","_").replace(".","_"), rel.replace(" ","_").replace(".","_"), self.name_id.replace(" ","_").replace(".","_"), usage.replace(" ","_").replace(".","_"))] 114 | 115 | if with_rela: 116 | rela_query = [rela.to_datalog() for rela in self.sub_relas] 117 | else: 118 | rela_query = [] 119 | 120 | program = name_query + attr_query + hypernym_query + kg_query + rela_query 121 | 122 | #print(program) 123 | return program 124 | 125 | class Relation(): 126 | def __init__(self, rela_name, sub, obj): 127 | self.rela_name = rela_name 128 | self.sub = sub 129 | self.obj = obj 130 | self.sub.set_sub_relas(self) 131 | self.obj.set_obj_relas(self) 132 | 133 | def substitute(self, v1, v2): 134 | if self.sub == v1: 135 | self.sub = v2 136 | if self.obj == v1: 137 | self.obj = v2 138 | 139 | def to_datalog(self): 140 | #rela_query = f"relation(\"{self.rela_name}\", {self.sub.var_id}, {self.obj.var_id})" 141 | #rela_query = "relation(0,1,{}, {}, {})".format( self.sub.var_id.replace(" ","_"), self.obj.var_id.replace(" ","_"),self.rela_name.replace(" ","_")) 142 | rela_query = "relation({}, {}, {})".format( self.sub.var_id.replace(" ","_"), self.obj.var_id.replace(" ","_"),self.rela_name.replace(" ","_")) 143 | 144 | return rela_query 145 | 146 | 147 | # This is for binary operations on variables 148 | class BiOp(): 149 | def __init__(self, op_name, v1, v2): 150 | self.op_name = op_name 151 | self.v1 = v1 152 | self.v2 = v2 153 | 154 | def to_datalog(self): 155 | raise NotImplementedError 156 | 157 | class Or(BiOp): 158 | def __init__(self, v1, v2): 159 | super().__init__('or', v1, v2) 160 | 161 | def to_datalog(self): 162 | pass 163 | 164 | class And(BiOp): 165 | def __init__(self, v1, v2): 166 | super().__init__('and', v1, v2) 167 | 168 | def to_datalog(self): 169 | pass 170 | 171 | class Query(): 172 | 173 | def __init__(self, query): 174 | self.vars = [] 175 | self.relations = [] 176 | self.operations = [] 177 | self.stack = [] 178 | self.preprocess(query) 179 | 180 | 181 | def get_target(self): 182 | pass 183 | 184 | def get_new_var(self): 185 | self.vars.append(Variable(len(self.vars))) 186 | 187 | def preprocess(self, query): 188 | 189 | # for clause in question["program"]: 190 | for clause in query: 191 | 192 | if clause['function'] == "Initial": 193 | if not len(self.vars) == 0: 194 | self.stack.append(self.vars[-1]) 195 | self.get_new_var() 196 | self.root = self.vars[-1] 197 | 198 | # logic operations 199 | elif clause['function'] == 
"And": 200 | v = self.stack.pop() 201 | self.operations.append(And(v, self.vars[-1])) 202 | self.root = self.operations[-1] 203 | 204 | elif clause['function'] == "Or": 205 | v = self.stack.pop() 206 | self.operations.append(Or(v, self.vars[-1])) 207 | self.root = self.operations[-1] 208 | 209 | # find operations 210 | elif clause['function'] == "KG_Find": 211 | self.vars[-1].set_kg(clause['text_input']) 212 | 213 | elif clause['function'] == "Hypernym_Find": 214 | self.vars[-1].set_hypernym(clause['text_input']) 215 | 216 | elif clause['function'] == "Find_Name": 217 | self.vars[-1].set_name(clause['text_input']) 218 | 219 | elif clause['function'] == "Find_Attr": 220 | self.vars[-1].set_attr(clause['text_input']) 221 | 222 | elif clause['function'] == "Relate_Reverse": 223 | self.get_new_var() 224 | self.root = self.vars[-1] 225 | obj = self.vars[-2] 226 | sub = self.vars[-1] 227 | rela_name = clause['text_input'] 228 | relation = Relation(rela_name, sub, obj) 229 | self.relations.append(relation) 230 | 231 | elif clause['function'] == "Relate": 232 | self.get_new_var() 233 | self.root = self.vars[-1] 234 | sub = self.vars[-2] 235 | obj = self.vars[-1] 236 | rela_name = clause['text_input'] 237 | relation = Relation(rela_name, sub, obj) 238 | self.relations.append(relation) 239 | 240 | else: 241 | raise Exception(f"Not handled function: {clause['function']}") 242 | 243 | 244 | # Optimizers for optimization 245 | class QueryOptimizer(): 246 | 247 | def __init__(self, name): 248 | self.name = name 249 | 250 | def optimize(self, query): 251 | raise NotImplementedError 252 | 253 | # This only works for one and operation at the end 254 | # This is waited for update 255 | # class AndQueryOptimizer(QueryOptimizer): 256 | 257 | # def __init__(self): 258 | # super().__init__("AndQueryOptimizer") 259 | 260 | # # For any and operation, this can be rewritten as a single object 261 | # def optimize(self, query): 262 | 263 | # if len(query.operations) == 0: 264 | # return query 265 | 266 | # assert(len(query.operations) == 1) 267 | 268 | # operation = query.operations[0] 269 | # # merge every subtree into one 270 | # if operation.name == "and": 271 | # v1 = operation.v1 272 | # v2 = operation.v2 273 | # v1.merge(v2) 274 | 275 | # for relation in query.relations: 276 | # relation.substitute(v2, v1) 277 | 278 | # query.vars.remove(v2) 279 | 280 | # if query.root == operation: 281 | # query.root = v1 282 | 283 | # return query 284 | 285 | 286 | class HypernymOptimizer(QueryOptimizer): 287 | 288 | def __init__(self): 289 | super().__init__("HypernymOptimizer") 290 | 291 | def optimize(self, query): 292 | 293 | if (query.name is not None and not query.hypernyms == []): 294 | query.hypernyms = [] 295 | 296 | return query 297 | 298 | 299 | class KGOptimizer(QueryOptimizer): 300 | 301 | def __init__(self): 302 | super().__init__("HypernymOptimizer") 303 | 304 | def optimize(self, query): 305 | 306 | if (query.name is not None and not query.kgs == []): 307 | query.kgs = [] 308 | 309 | return query 310 | -------------------------------------------------------------------------------- /src/experiments/vqa/query_lib.py: -------------------------------------------------------------------------------- 1 | """ 2 | The source code is based on: 3 | Scallop: From Probabilistic Deductive Databases to Scalable Differentiable Reasoning 4 | Jiani Huang, Ziyang Li, Binghong Chen, Karan Samel, Mayur Naik, Le Song, Xujie Si 5 | Advances in Neural Information Processing Systems 34 (NeurIPS 2021) 6 | 
https://proceedings.neurips.cc/paper/2021/hash/d367eef13f90793bd8121e2f675f0dc2-Abstract.html 7 | """ 8 | 9 | import pickle 10 | import os 11 | import subprocess 12 | 13 | from transformer import DetailTransformer, SimpleTransformer 14 | from preprocess import Query 15 | from knowledge_graph import KG, RULES 16 | 17 | # animals = ["giraffe", "cat", "kitten", "dog", "puppy", "poodle", "bull", "cow", "cattle", "bison", "calf", "pig", "ape", "monkey", "gorilla", "rat", "squirrel", "hamster", "deer", "moose", "alpaca", "elephant", "goat", "sheep", "lamb", "antelope", "rhino", "hippo", "zebra", "horse", "pony", "donkey", "camel", "panda", "panda bear", "bear", "polar bear", "seal", "fox", "raccoon", "tiger", "wolf", "lion", "leopard", "cheetah", "badger", "rabbit", "bunny", "beaver", "kangaroo", "dinosaur", "dragon", "fish", "whale", "dolphin", "crab", "shark", "octopus", "lobster", "oyster", "butterfly", "bee", "fly", "ant", "firefly", "snail", "spider", "bird", "penguin", "pigeon", "seagull", "finch", "robin", "ostrich", "goose", "owl", "duck", "hawk", "eagle", "swan", "chicken", "hen", "hummingbird", "parrot", "crow", "flamingo", "peacock", "bald eagle", "dove", "snake", "lizard", "alligator", "turtle", "frog", "animal"] 18 | class QueryManager(): 19 | def __init__(self, save_dir): 20 | self.save_dir = save_dir 21 | self.transformer = SimpleTransformer() 22 | 23 | def save_file(self, file_name, content): 24 | save_path = os.path.join (self.save_dir, file_name) 25 | with open (save_path, "w") as save_file: 26 | save_file.write(content) 27 | 28 | def delete_file(self, file_name): 29 | save_path = os.path.join (self.save_dir, file_name) 30 | if os.path.exists(save_path): 31 | os.remove(save_path) 32 | 33 | def fact_prob_to_file(self, fact_tps, fact_probs): 34 | scene_tps = [] 35 | 36 | (name_tps, attr_tps, rela_tps) = fact_tps 37 | (name_probs, attr_probs, rela_probs) = fact_probs 38 | 39 | cluster_ntp = {} 40 | for (oid, name), prob in zip(name_tps, name_probs): 41 | if not oid in cluster_ntp: 42 | cluster_ntp[oid] = [(name, prob)] 43 | else: 44 | cluster_ntp[oid].append((name, prob)) 45 | 46 | 47 | for oid, name_list in cluster_ntp.items(): 48 | name_tps = [] 49 | for (name, prob) in name_list: 50 | # if not name in animals[:5]: 51 | # continue 52 | name_tps.append(f'{prob}::name("{name}", {int(oid)})') 53 | name_content = ";\n".join(name_tps) + "." 
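            # Illustrative sketch, not from the original source: for an object with
            # id 3 and two candidate names, the string built above is a ProbLog
            # annotated disjunction of the form
            #   0.72::name("dog", 3);
            #   0.21::name("cat", 3).
            # i.e. one probability-weighted fact per candidate name, with alternatives
            # separated by ';' and the group terminated by '.'; the attribute and
            # relation facts appended below follow the same prob::fact(...) pattern.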
54 | scene_tps.append(name_content) 55 | 56 | for attr_tp, prob in zip(attr_tps, attr_probs): 57 | # if not attr_tp[1] == "tall": 58 | # continue 59 | scene_tps.append(f'{prob}::attr("{attr_tp[1]}", {int(attr_tp[0])}).') 60 | 61 | for rela_tp, prob in zip(rela_tps, rela_probs): 62 | # if not rela_tp[0] == "left": 63 | # continue 64 | scene_tps.append(f'{prob}::relation("{rela_tp[0]}", {int(rela_tp[1])}, {int(rela_tp[2])}).') 65 | 66 | return "\n".join(scene_tps) 67 | 68 | def process_result(self, result): 69 | output = result.stdout.decode() 70 | lines = output.split("\n") 71 | targets = {} 72 | for line in lines: 73 | if line == '': 74 | continue 75 | if not '\t' in line: 76 | continue 77 | info = line.split('\t') 78 | # No target found 79 | if 'X' in info[0]: 80 | break 81 | target_name = int(info[0][7:-2]) 82 | target_prob = float(info[1]) 83 | targets[target_name] = target_prob 84 | return targets 85 | 86 | def get_result(self, task, fact_tps, fact_probs): 87 | timeout = False 88 | question = task["question"]["clauses"] 89 | file_name = f"{task['question']['question_id']}.pl" 90 | save_path = os.path.join (self.save_dir, file_name) 91 | query = Query(question) 92 | query_content = self.transformer.transform(query) 93 | scene_content = self.fact_prob_to_file(fact_tps, fact_probs) 94 | 95 | content = KG+ "\n" + RULES + "\n" + scene_content + "\n" + query_content 96 | self.save_file(file_name, content) 97 | try: 98 | result = subprocess.run(["problog", save_path], capture_output=True, timeout=10) 99 | targets = self.process_result(result) 100 | except: 101 | # time out here 102 | timeout = True 103 | targets = {} 104 | 105 | # self.delete_file(file_name) 106 | return targets, timeout 107 | 108 | def get_relas(self, query): 109 | relations = [] 110 | for clause in query: 111 | if 'Relate' in clause['function']: 112 | relations.append(clause['text_input']) 113 | return relations 114 | 115 | 116 | 117 | -------------------------------------------------------------------------------- /src/experiments/vqa/sg_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | The source code is based on: 3 | Scallop: From Probabilistic Deductive Databases to Scalable Differentiable Reasoning 4 | Jiani Huang, Ziyang Li, Binghong Chen, Karan Samel, Mayur Naik, Le Song, Xujie Si 5 | Advances in Neural Information Processing Systems 34 (NeurIPS 2021) 6 | https://proceedings.neurips.cc/paper/2021/hash/d367eef13f90793bd8121e2f675f0dc2-Abstract.html 7 | """ 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import os 14 | torch.autograd.set_detect_anomaly(True) 15 | 16 | def load_model(model, model_f, device): 17 | print('loading model from %s' % model_f) 18 | model.load_state_dict(torch.load(model_f, map_location=device)) 19 | model.eval() 20 | 21 | class MLPClassifier(nn.Module): 22 | def __init__(self, input_dim, latent_dim, output_dim, n_layers, dropout_rate): 23 | super(MLPClassifier, self).__init__() 24 | 25 | layers = [] 26 | layers.append(nn.Linear(input_dim, latent_dim)) 27 | layers.append(nn.ReLU()) 28 | layers.append(nn.BatchNorm1d(latent_dim)) 29 | layers.append(nn.Dropout(dropout_rate)) 30 | for _ in range(n_layers - 1): 31 | layers.append(nn.Linear(latent_dim, latent_dim)) 32 | layers.append(nn.ReLU()) 33 | layers.append(nn.BatchNorm1d(latent_dim)) 34 | layers.append(nn.Dropout(dropout_rate)) 35 | layers.append(nn.Linear(latent_dim, output_dim)) 36 | 37 | self.net = nn.Sequential(*layers) 38 
| 39 | def forward(self, x): 40 | logits = self.net(x) 41 | return logits 42 | 43 | class SceneGraphModel: 44 | def __init__(self, feat_dim, n_names, n_attrs, n_rels, device, model_dir=None): 45 | self.feat_dim = feat_dim 46 | self.n_names = n_names 47 | self.n_attrs = n_attrs 48 | self.n_rels = n_rels 49 | self.device = device 50 | 51 | self._init_models() 52 | if model_dir is not None: 53 | self._load_models(model_dir) 54 | 55 | def _load_models(self, model_dir): 56 | for type in ['name', 'relation', 'attribute']: 57 | load_model( 58 | model=self.models[type], 59 | model_f=model_dir+'/%s_best_epoch.pt' % type, 60 | device=self.device 61 | ) 62 | 63 | def _init_models(self): 64 | name_clf = MLPClassifier( 65 | input_dim=self.feat_dim, 66 | output_dim=self.n_names, 67 | latent_dim=1024, 68 | n_layers=2, 69 | dropout_rate=0.3 70 | ) 71 | 72 | rela_clf = MLPClassifier( 73 | input_dim=(self.feat_dim+4)*2, # 4: bbox 74 | output_dim=self.n_rels+1, # 1: None 75 | latent_dim=1024, 76 | n_layers=1, 77 | dropout_rate=0.5 78 | ) 79 | 80 | attr_clf = MLPClassifier( 81 | input_dim=self.feat_dim, 82 | output_dim=self.n_attrs, 83 | latent_dim=1024, 84 | n_layers=1, 85 | dropout_rate=0.3 86 | ) 87 | 88 | self.models = { 89 | 'name': name_clf, 90 | 'attribute': attr_clf, 91 | 'relation': rela_clf 92 | } 93 | 94 | def predict(self, type, inputs): 95 | # type == 'name', inputs == (obj_feat_np_array) 96 | # type == 'relation', inputs == (sub_feat_np_array, obj_feat_np_array, sub_bbox_np_array, obj_bbox_np_array) 97 | # type == 'attribute', inputs == (obj_feat_np_array) 98 | 99 | model = self.models[type].to(self.device) 100 | inputs = torch.cat([torch.from_numpy(x).float() for x in inputs]).reshape(len(inputs), -1).to(self.device) 101 | logits = model(inputs) 102 | 103 | if type == 'attribute': 104 | probs = torch.sigmoid(logits) 105 | else: 106 | probs = F.softmax(logits, dim=1) 107 | 108 | return logits, probs 109 | 110 | def batch_predict(self, type, inputs, batch_split): 111 | 112 | model = self.models[type].to(self.device) 113 | inputs = torch.cat([torch.from_numpy(x).float() for x in inputs]).reshape(len(inputs), -1).to(self.device) 114 | logits = model(inputs) 115 | 116 | if type == 'attribute': 117 | probs = torch.sigmoid(logits) 118 | else: 119 | current_split = 0 120 | probs = [] 121 | for split in batch_split: 122 | current_logits = logits[current_split:split] 123 | # batched_logits = logits.reshape(batch_shape[0], batch_shape[1], -1) 124 | current_probs = F.softmax(current_logits, dim=1) 125 | # probs = probs.reshape(inputs.shape[0], -1) 126 | probs.append(current_probs) 127 | current_split = split 128 | 129 | probs = torch.cat(probs).reshape(inputs.shape[0], -1) 130 | return logits, probs 131 | -------------------------------------------------------------------------------- /src/experiments/vqa/test.py: -------------------------------------------------------------------------------- 1 | print("start importing...") 2 | 3 | import time 4 | import sys 5 | import argparse 6 | import datetime 7 | 8 | sys.path.append('../../') 9 | sys.path.append('../../SLASH/') 10 | sys.path.append('../../EinsumNetworks/src/') 11 | 12 | 13 | #torch, numpy, ... 
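# Note: the project-internal imports further below (e.g. 'from slash import SLASH',
# 'from einsum_wrapper import EiNet') resolve through the relative sys.path.append(...)
# calls above, which put the SLASH source directories on the module search path.
# Since these paths are relative, this assumes the script is launched from its own
# directory, as the accompanying test.sh invocation suggests.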
14 | import torch 15 | from torch.utils.tensorboard import SummaryWriter 16 | from torchvision.transforms import transforms 17 | import torchvision 18 | 19 | import numpy as np 20 | 21 | import json 22 | 23 | #own modules 24 | from dataGen import VQA 25 | from einsum_wrapper import EiNet 26 | from network_nn import Net_nn 27 | 28 | from tqdm import tqdm 29 | 30 | #import slash 31 | from slash import SLASH 32 | import os 33 | 34 | 35 | 36 | 37 | import utils 38 | from utils import set_manual_seed 39 | from pathlib import Path 40 | from rtpt import RTPT 41 | import pickle 42 | 43 | 44 | from knowledge_graph import RULES, KG 45 | from dataGen import name_npp, relation_npp, attribute_npp 46 | from models import name_clf, rela_clf, attr_clf 47 | 48 | print("...done") 49 | 50 | 51 | 52 | 53 | def get_args(): 54 | parser = argparse.ArgumentParser() 55 | 56 | parser.add_argument( 57 | "--seed", type=int, default=10, help="Random generator seed for all frameworks" 58 | ) 59 | 60 | parser.add_argument( 61 | "--network-type", 62 | choices=["nn","pc"], 63 | help="The type of external to be used e.g. neural net or probabilistic circuit", 64 | ) 65 | parser.add_argument( 66 | "--pc-structure", 67 | choices=["poon-domingos","binary-trees"], 68 | help="The type of external to be used e.g. neural net or probabilistic circuit", 69 | ) 70 | parser.add_argument( 71 | "--batch-size", type=int, default=100, help="Batch size to train with" 72 | ) 73 | parser.add_argument( 74 | "--num-workers", type=int, default=6, help="Number of threads for data loader" 75 | ) 76 | 77 | parser.add_argument( 78 | "--p-num", type=int, default=8, help="Number of processes to devide the batch for parallel processing" 79 | ) 80 | 81 | parser.add_argument("--credentials", type=str, help="Credentials for rtpt") 82 | 83 | 84 | args = parser.parse_args() 85 | 86 | if args.network_type == 'pc': 87 | args.use_pc = True 88 | else: 89 | args.use_pc = False 90 | 91 | return args 92 | 93 | 94 | def determine_max_objects(task_file): 95 | 96 | with open (task_file, 'rb') as tf: 97 | tasks = pickle.load(tf) 98 | print("taskfile len",len(tasks)) 99 | 100 | #get the biggest number of objects in the image 101 | max_objects = 0 102 | for tidx, task in enumerate(tasks): 103 | all_oid = task['question']['input'] 104 | len_all_oid = len(all_oid) 105 | 106 | #store the biggest number of objects in an image 107 | if len_all_oid > max_objects: 108 | max_objects = len_all_oid 109 | return max_objects 110 | 111 | def slash_vqa(): 112 | 113 | args = get_args() 114 | print(args) 115 | 116 | 117 | # Set the seeds for PRNG 118 | set_manual_seed(args.seed) 119 | 120 | # Create RTPT object 121 | rtpt = RTPT(name_initials=args.credentials, experiment_name='SLASH VQA', max_iterations=1) 122 | 123 | # Start the RTPT tracking 124 | rtpt.start() 125 | #writer = SummaryWriter(os.path.join("runs","vqa", str(args.seed)), purge_step=0) 126 | 127 | #exp_name = 'vqa3' 128 | #Path("data/"+exp_name+"/").mkdir(parents=True, exist_ok=True) 129 | #saveModelPath = 'data/'+exp_name+'/slash_vqa_models_seed'+str(args.seed)+'.pt' 130 | 131 | 132 | #TODO workaround that adds +- notation 133 | program_example = """ 134 | %scallop conversion rules 135 | name(O,N) :- name(0,+O,-N). 136 | attr(O,A) :- attr(0, +O, -A). 137 | relation(O1,O2,N) :- relation(0, +(O1,O2), -N). 
138 | """ 139 | 140 | #test_f = "dataset/task_list/test_tasks_c3_1000.pkl" # Test datset 141 | 142 | test_f = {"c2":"dataset/task_list/test_tasks_c2_1000.pkl", 143 | "c3":"dataset/task_list/test_tasks_c3_1000.pkl", 144 | "c4":"dataset/task_list/test_tasks_c4_1000.pkl", 145 | "c5":"dataset/task_list/test_tasks_c5_1000.pkl", 146 | "c6":"dataset/task_list/test_tasks_c6_1000.pkl" 147 | } 148 | 149 | 150 | num_obj = [] 151 | if type(test_f) == str: 152 | num_obj.append(determine_max_objects(test_f)) 153 | #if we have multiple test files 154 | elif type(test_f) == dict: 155 | for key in test_f: 156 | num_obj.append(determine_max_objects(test_f[key])) 157 | 158 | 159 | NUM_OBJECTS = np.max(num_obj) 160 | NUM_OBJECTS = 70 161 | 162 | 163 | vqa_params = {"l":200, 164 | "l_split":100, 165 | "num_names":500, 166 | "max_models":10000, 167 | "asp_timeout": 60} 168 | 169 | 170 | 171 | #load models #data/vqa18_10/slash_vqa_models_seed0_epoch_0.pt 172 | #src/experiments/vqa/ 173 | #saved_models = torch.load("data/test/slash_vqa_models_seed42_epoch_9.pt") 174 | saved_models = torch.load("data/vqa_debug_relations_17_04_2023/slash_vqa_models_seed0_epoch_2.pt") 175 | 176 | print(saved_models.keys()) 177 | rela_clf.load_state_dict(saved_models['relation_clf']) 178 | name_clf.load_state_dict(saved_models['name_clf']) 179 | attr_clf.load_state_dict(saved_models['attr_clf']) 180 | 181 | #create the SLASH Program , , 182 | nnMapping = {'relation': rela_clf, 'name':name_clf , "attr":attr_clf} 183 | optimizers = {'relation': torch.optim.Adam(rela_clf.parameters(), lr=0.001, eps=1e-7), 184 | 'name': torch.optim.Adam(name_clf.parameters(), lr=0.001, eps=1e-7), 185 | 'attr': torch.optim.Adam(attr_clf.parameters(), lr=0.001, eps=1e-7)} 186 | 187 | 188 | 189 | all_oid = np.arange(0,NUM_OBJECTS) 190 | object_string = "".join([ f"object({oid1},{oid2}). " for oid1 in all_oid for oid2 in all_oid if oid1 != oid2]) 191 | object_string = "".join(["".join([f"object({oid}). 
" for oid in all_oid]), object_string]) 192 | 193 | #parse the SLASH program 194 | print("create SLASH program") 195 | program = "".join([KG, RULES, object_string, name_npp, relation_npp, attribute_npp, program_example]) 196 | SLASHobj = SLASH(program, nnMapping, optimizers) 197 | 198 | #load the data 199 | if type(test_f) == str: 200 | test_data = VQA("test", test_f, NUM_OBJECTS) 201 | test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False, num_workers=4) 202 | #if we have multiple test files 203 | elif type(test_f) == dict: 204 | for key in test_f: 205 | test_data = VQA("test", test_f[key], NUM_OBJECTS) 206 | test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False, num_workers=4) 207 | test_f[key] = test_loader 208 | 209 | 210 | 211 | print("---TEST---") 212 | if type(test_f) == str: 213 | recall_5_test, test_time = SLASHobj.testVQA(test_loader, args.p_num, vqa_params=vqa_params) 214 | print("test-recall@5", recall_5_test) 215 | 216 | elif type(test_f) == dict: 217 | test_time = 0 218 | recalls = [] 219 | for key in test_f: 220 | recall_5_test, tt = SLASHobj.testVQA(test_f[key], args.p_num, vqa_params=vqa_params) 221 | test_time += tt 222 | recalls.append(recall_5_test) 223 | print("test-recall@5_{}".format(key), recall_5_test, ", test_time:", tt ) 224 | print("test-recall@5_c_all", np.mean(recalls), ", test_time:", test_time) 225 | 226 | 227 | if __name__ == "__main__": 228 | slash_vqa() 229 | -------------------------------------------------------------------------------- /src/experiments/vqa/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DEVICE=$1 4 | SEED=$2 # 0, 1, 2, 3, 4 5 | CREDENTIALS=$3 6 | 7 | 8 | 9 | #-------------------------------------------------------------------------------# 10 | # Train on CLEVR_v1 with cnn model 11 | 12 | CUDA_VISIBLE_DEVICES=$DEVICE python3 test.py \ 13 | --seed $SEED \ 14 | --network-type nn --batch-size 100 \ 15 | --num-workers 0 --p-num 16 --credentials $CREDENTIALS 16 | 17 | -------------------------------------------------------------------------------- /src/experiments/vqa/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DEVICE=$1 4 | SEED=$2 # 0, 1, 2, 3, 4 5 | CREDENTIALS=$3 6 | 7 | 8 | # 0.00001, 0.0001, 0.001, 0.01 9 | LR=0.001 10 | 11 | #-------------------------------------------------------------------------------# 12 | 13 | CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \ 14 | --epochs 100 \ 15 | --batch-size 100 --seed $SEED \ 16 | --network-type nn --lr $LR \ 17 | --num-workers 0 --p-num 20 --credentials $CREDENTIALS \ 18 | --exp-name vqa_c2 19 | -------------------------------------------------------------------------------- /src/experiments/vqa/trainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | The source code is based on: 3 | Scallop: From Probabilistic Deductive Databases to Scalable Differentiable Reasoning 4 | Jiani Huang, Ziyang Li, Binghong Chen, Karan Samel, Mayur Naik, Le Song, Xujie Si 5 | Advances in Neural Information Processing Systems 34 (NeurIPS 2021) 6 | https://proceedings.neurips.cc/paper/2021/hash/d367eef13f90793bd8121e2f675f0dc2-Abstract.html 7 | """ 8 | 9 | import os 10 | import json 11 | import numpy as np 12 | import sys 13 | from tqdm import tqdm 14 | import torch 15 | import torch.optim as optim 16 | from torch.optim.lr_scheduler import StepLR 17 | 
from torch.nn import BCELoss 18 | import time 19 | import math 20 | import statistics 21 | # import concurrent.futures 22 | 23 | supervised_learning_path = os.path.abspath(os.path.join( 24 | os.path.abspath(__file__), "../../supervised_learning")) 25 | sys.path.insert(0, supervised_learning_path) 26 | 27 | common_path = os.path.abspath(os.path.join(os.path.abspath(__file__), "../..")) 28 | sys.path.insert(0, common_path) 29 | 30 | from query_lib import QueryManager 31 | #from src.experiments.vqa.cmd_args2 import cmd_args 32 | from word_idx_translator import Idx2Word 33 | from sg_model import SceneGraphModel 34 | from learning import get_fact_probs 35 | from vqa_utils import auc_score, get_recall 36 | 37 | def prog_analysis(datapoint): 38 | knowledge_op = ['Hypernym_Find', 'KG_Find'] 39 | rela_op = ['Relate', 'Relate_Reverse'] 40 | clauses = datapoint['question']['clauses'] 41 | ops = [ clause['function'] for clause in clauses ] 42 | kg_num = sum([1 if op in knowledge_op else 0 for op in ops]) 43 | rela_num = sum([1 if op in rela_op else 0 for op in ops]) 44 | return (kg_num, rela_num) 45 | 46 | class ClauseNTrainer(): 47 | 48 | def __init__(self, 49 | train_data_loader, 50 | val_data_loader, 51 | test_data_loader=None, 52 | #model_dir=cmd_args.model_dir, n_epochs=cmd_args.n_epochs, 53 | #save_dir=cmd_args.save_dir, 54 | #meta_f=cmd_args.meta_f, knowledge_base_dir=cmd_args.knowledge_base_dir, 55 | #axiom_update_size=cmd_args.axiom_update_size): 56 | model_dir="data_model/", n_epochs=2, 57 | save_dir="data_save/", 58 | meta_f="dataset/gqa_info.json", knowledge_base_dir="", 59 | axiom_update_size=""): 60 | 61 | 62 | self.model_dir = model_dir 63 | model_exists = self._model_exists(model_dir) 64 | 65 | if not model_exists: 66 | load_model_dir = None 67 | else: 68 | load_model_dir = model_dir 69 | 70 | if train_data_loader is not None: 71 | self.train_data = train_data_loader 72 | if val_data_loader is not None: 73 | self.val_data = val_data_loader 74 | if test_data_loader is not None: 75 | self.val_data = test_data_loader 76 | 77 | self.is_train = test_data_loader is None 78 | meta_info = json.load(open(meta_f, 'r')) 79 | self.idx2word = Idx2Word(meta_info) 80 | self.n_epochs = n_epochs 81 | self.query_manager = QueryManager(save_dir) 82 | self.axiom_update_size = axiom_update_size 83 | self.wmc_funcs = {} 84 | 85 | # load dictionary from previous training results 86 | self.sg_model = SceneGraphModel( 87 | feat_dim=64, 88 | n_names=meta_info['name']['num'], 89 | n_attrs=meta_info['attr']['num'], 90 | n_rels=meta_info['rel']['num'], 91 | device=torch.device('cuda'), 92 | model_dir=load_model_dir 93 | ) 94 | 95 | self.sg_model_dict = self.sg_model.models 96 | 97 | self.loss_func = BCELoss() 98 | 99 | if self.is_train: 100 | self.optimizers = {} 101 | self.schedulers = {} 102 | 103 | for model_type, model_info in self.sg_model_dict.items(): 104 | if model_type == 'name': 105 | self.optimizers['name'] = optim.Adam( 106 | model_info.parameters(), lr=0.01) 107 | self.schedulers['name'] = StepLR( 108 | self.optimizers['name'], step_size=10, gamma=0.1) 109 | # self.loss_func['name'] = F.cross_entropy 110 | if model_type == 'relation': 111 | self.optimizers['rel'] = optim.Adam( 112 | model_info.parameters(), lr=0.01) 113 | self.schedulers['rel'] = StepLR( 114 | self.optimizers['rel'], step_size=10, gamma=0.1) 115 | if model_type == 'attribute': 116 | self.optimizers['attr'] = optim.Adam( 117 | model_info.parameters(), lr=0.01) 118 | self.schedulers['attr'] = StepLR( 119 | self.optimizers['attr'], 
step_size=10, gamma=0.1) 120 | 121 | # self.pool = mp.Pool(cmd_args.max_workers) 122 | # self.batch = cmd_args.trainer_batch 123 | 124 | def _model_exists(self, model_dir): 125 | name_f = os.path.join(self.model_dir, 'name_best_epoch.pt') 126 | rela_f = os.path.join(self.model_dir, 'relation_best_epoch.pt') 127 | attr_f = os.path.join(self.model_dir, 'attribute_best_epoch.pt') 128 | 129 | if not os.path.exists(name_f): 130 | return False 131 | if not os.path.exists(rela_f): 132 | return False 133 | if not os.path.exists(attr_f): 134 | return False 135 | return True 136 | 137 | def _get_optimizer(self, data_type): 138 | optimizer = self.optimizers[data_type] 139 | return optimizer 140 | 141 | def _step_all(self): 142 | for optim_type, optim_info in self.optimizers.items(): 143 | optim_info.step() 144 | 145 | def _step_scheduler(self): 146 | for scheduler_type, scheduler_info in self.schedulers.items(): 147 | scheduler_info.step() 148 | 149 | def _zero_all(self): 150 | for optim_type, optim_info in self.optimizers.items(): 151 | optim_info.zero_grad() 152 | 153 | def _train_all(self): 154 | for model_type, model_info in self.sg_model_dict.items(): 155 | if model_type == 'name': 156 | model_info.train() 157 | model_info.share_memory() 158 | if model_type == 'relation': 159 | model_info.train() 160 | model_info.share_memory() 161 | if model_type == 'attribute': 162 | model_info.train() 163 | model_info.share_memory() 164 | 165 | def _eval_all(self): 166 | for model_type, model_info in self.sg_model_dict.items(): 167 | if model_type == 'name': 168 | model_info.eval() 169 | model_info.share_memory() 170 | if model_type == 'relation': 171 | model_info.eval() 172 | model_info.share_memory() 173 | if model_type == 'attribute': 174 | model_info.eval() 175 | model_info.share_memory() 176 | 177 | def _save_all(self): 178 | for model_type, model_info in self.sg_model_dict.items(): 179 | if model_type == 'name': 180 | save_f = os.path.join(self.model_dir, 'name_best_epoch.pt') 181 | torch.save(model_info.state_dict(), save_f) 182 | if model_type == 'relation': 183 | save_f = os.path.join(self.model_dir, 'relation_best_epoch.pt') 184 | torch.save(model_info.state_dict(), save_f) 185 | if model_type == 'attribute': 186 | save_f = os.path.join(self.model_dir, 'attribute_best_epoch.pt') 187 | torch.save(model_info.state_dict(), save_f) 188 | 189 | def loss_acc(self, targets, correct, all_oids, is_train=True): 190 | 191 | pred = [] 192 | for oid in all_oids: 193 | if oid in targets: 194 | pred.append(targets[oid]) 195 | else: 196 | pred.append(0) 197 | 198 | labels = [1 if obj in correct else 0 for obj in all_oids] 199 | 200 | labels_tensor = torch.tensor(labels, dtype=torch.float32) 201 | pred_tensor = torch.tensor(pred, dtype=torch.float32) 202 | 203 | pred_tensor = pred_tensor.reshape(1, -1) 204 | labels_tensor = labels_tensor.reshape(1, -1) 205 | 206 | loss = self.loss_func(pred_tensor, labels_tensor) 207 | recall = get_recall(labels_tensor, pred_tensor) 208 | 209 | if math.isnan(recall): 210 | recall = -1 211 | 212 | return loss.item(), recall 213 | 214 | def _pass(self, datapoint, is_train=True): 215 | 216 | correct = datapoint['question']['output'] 217 | all_oid = datapoint['question']['input'] 218 | fact_tps, fact_probs = get_fact_probs(self.sg_model, datapoint, self.idx2word, self.query_manager) 219 | result, timeout = self.query_manager.get_result(datapoint, fact_tps, fact_probs) 220 | if not timeout: 221 | loss, acc = self.loss_acc(result, correct, all_oid) 222 | else: 223 | loss = -1 224 | acc = 
-1 225 | 226 | return acc, loss, timeout 227 | 228 | def _train_epoch(self, ct): 229 | 230 | self._train_all() 231 | aucs = [] 232 | losses = [] 233 | timeouts = 0 234 | pbar = tqdm(self.train_data) 235 | 236 | for datapoint in pbar: 237 | auc, loss, timeout = self._pass(datapoint, is_train=True) 238 | if not timeout: 239 | if auc >= 0: 240 | aucs.append(auc) 241 | losses.append(loss) 242 | else: 243 | timeouts += 1 244 | 245 | pbar.set_description( 246 | f'[loss: {np.array(losses).mean()}, auc: {np.array(aucs).mean()}, timeouts: {timeouts}]') 247 | torch.cuda.empty_cache() 248 | 249 | self._step_all() 250 | self._zero_all() 251 | 252 | return np.mean(losses), np.mean(aucs) 253 | 254 | def _val_epoch(self): 255 | self._eval_all() 256 | 257 | timeouts = 0 258 | aucs = [] 259 | losses = [] 260 | time_out_prog_kg = {} 261 | time_out_prog_rela = {} 262 | success_prog_kg = {} 263 | success_prog_rela = {} 264 | 265 | pbar = tqdm(self.val_data) 266 | with torch.no_grad(): 267 | for datapoint in pbar: 268 | kg_num, rela_num = prog_analysis(datapoint) 269 | 270 | auc, loss, timeout = self._pass(datapoint, is_train=False) 271 | if not timeout: 272 | aucs.append(auc) 273 | losses.append(loss) 274 | if not kg_num in success_prog_kg: 275 | success_prog_kg[kg_num] = 0 276 | if not rela_num in success_prog_rela: 277 | success_prog_rela[rela_num] = 0 278 | success_prog_kg[kg_num] += 1 279 | success_prog_rela[rela_num] += 1 280 | 281 | else: 282 | timeouts += 1 283 | if not kg_num in time_out_prog_kg: 284 | time_out_prog_kg[kg_num] = 0 285 | if not rela_num in time_out_prog_rela: 286 | time_out_prog_rela[rela_num] = 0 287 | time_out_prog_kg[kg_num] += 1 288 | time_out_prog_rela[rela_num] += 1 289 | 290 | if not len(aucs) == 0: 291 | pbar.set_description( 292 | f'[loss: {np.array(losses).mean()}, auc: {np.array(aucs).mean()}, timeouts: {timeouts}]') 293 | 294 | print(f"succ kg: {success_prog_kg}, succ rela: {success_prog_rela}") 295 | print(f"timeout kg: {time_out_prog_kg}, timeout rela: {time_out_prog_rela}") 296 | return np.mean(losses), np.mean(aucs) 297 | 298 | 299 | def train(self): 300 | assert self.is_train 301 | best_val_loss = np.inf 302 | 303 | for epoch in range(self.n_epochs): 304 | train_loss, train_acc = self._train_epoch(epoch) 305 | val_loss, val_acc = self._val_epoch() 306 | self._step_scheduler() 307 | 308 | print( 309 | '[Epoch %d/%d] [training loss: %.2f, auc: %.2f] [validation loss: %.2f, auc: %.2f]' % 310 | (epoch, self.n_epochs, train_loss, train_acc, val_loss, val_acc) 311 | ) 312 | 313 | if val_loss < best_val_loss: 314 | best_val_loss = val_loss 315 | print('saving best models') 316 | self._save_all() 317 | 318 | def test(self): 319 | assert not self.is_train 320 | test_loss, test_acc = self._val_epoch() 321 | print('[test loss: %.2f, acc: %.2f]' % (test_loss, test_acc)) 322 | -------------------------------------------------------------------------------- /src/experiments/vqa/vqa_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | The source code is based on: 3 | Scallop: From Probabilistic Deductive Databases to Scalable Differentiable Reasoning 4 | Jiani Huang, Ziyang Li, Binghong Chen, Karan Samel, Mayur Naik, Le Song, Xujie Si 5 | Advances in Neural Information Processing Systems 34 (NeurIPS 2021) 6 | https://proceedings.neurips.cc/paper/2021/hash/d367eef13f90793bd8121e2f675f0dc2-Abstract.html 7 | """ 8 | 9 | import os 10 | import json 11 | import argparse 12 | import numpy as np 13 | import torch 14 | import random 15 | 
import pickle 16 | from sklearn import metrics 17 | 18 | TIME_OUT = 120 19 | 20 | ########################################################################### 21 | # basic utilities 22 | 23 | 24 | def to_image_id_dict(dict_list): 25 | ii_dict = {} 26 | for dc in dict_list: 27 | image_id = dc['image_id'] 28 | ii_dict[image_id] = dc 29 | return ii_dict 30 | 31 | 32 | articles = ['a', 'an', 'the', 'some', 'it'] 33 | 34 | 35 | def remove_article(st): 36 | st = st.split() 37 | st = [word for word in st if word not in articles] 38 | return " ".join(st) 39 | 40 | ################################################################################################################ 41 | # AUC calculation 42 | 43 | def get_recall(labels, logits, topk=2): 44 | #a = torch.tensor([[0, 1, 1], [1, 1, 0], [0, 0, 1],[1, 1, 1]]) 45 | #b = torch.tensor([[0, 0, 0], [0, 1, 0], [0, 0, 1],[0, 1, 1]]) 46 | 47 | # Calculate the recall 48 | _, pred = logits.topk(topk, 1, True, True) #get idx of the biggest k values in the tensor 49 | print(pred) 50 | 51 | pred = pred.t() #transpose 52 | 53 | #gather gets the elements from a given indices tensor. Here we get the elements at the same positions from our (top5) predictions 54 | #we then sum all entries up along the axis and therefore count the top-5 entries in the labels tensors at the prediction position indices 55 | correct = torch.sum(labels.gather(1, pred.t()), dim=1) 56 | print("correct", correct) 57 | 58 | #now we some up all true labels. We clamp them to be maximum top5 59 | correct_label = torch.clamp(torch.sum(labels, dim = 1), 0, topk) 60 | 61 | #we now can compare if the number of correctly found top-5 labels on the predictions vector is the same as on the same positions as the GT vector 62 | print("correct label", correct_label) 63 | 64 | accuracy = torch.mean(correct / correct_label).item() 65 | 66 | return accuracy 67 | 68 | def single_auc_score(label, pred): 69 | label = label 70 | pred = [p.item() if not (type(p) == float or type(p) == int) 71 | else p for p in pred] 72 | if len(set(label)) == 1: 73 | auc = -1 74 | else: 75 | fpr, tpr, thresholds = metrics.roc_curve(label, pred, pos_label=1) 76 | auc = metrics.auc(fpr, tpr) 77 | return auc 78 | 79 | 80 | def auc_score(labels, preds): 81 | if type(labels) == torch.Tensor: 82 | labels = labels.long().cpu().detach().numpy() 83 | if type(preds) == torch.Tensor: 84 | preds = preds.cpu().detach().numpy() 85 | 86 | if (type(labels) == torch.Tensor or type(labels) == np.ndarray) and len(labels.shape) == 2: 87 | aucs = [] 88 | for label, pred in zip(labels, preds): 89 | auc_single = single_auc_score(label, pred) 90 | if not auc_single < 0: 91 | aucs.append(auc_single) 92 | auc = np.array(aucs).mean() 93 | else: 94 | auc = single_auc_score(labels, preds) 95 | 96 | return auc 97 | 98 | 99 | def to_binary_labels(labels, attr_num): 100 | if labels.shape[-1] == attr_num: 101 | return labels 102 | 103 | binary_labels = torch.zeros((labels.shape[0], attr_num)) 104 | for ct, label in enumerate(labels): 105 | binary_labels[ct][label] = 1 106 | return binary_labels 107 | 108 | 109 | ############################################################################## 110 | # model loading 111 | def get_default_args(): 112 | 113 | DATA_ROOT = os.path.abspath(os.path.join( 114 | os.path.abspath(__file__), "../../data")) 115 | 116 | parser = argparse.ArgumentParser() 117 | parser.add_argument('--gpu', type=int, default=2) 118 | parser.add_argument('--seed', type=int, default=1234) 119 | parser.add_argument('--feat_dim', type=int, 
default=2048) 120 | # parser.add_argument('--type', default='name') 121 | parser.add_argument('--model_dir', default=DATA_ROOT + '/model_ckpts') 122 | parser.add_argument('--meta_f', default=DATA_ROOT + '/preprocessing/gqa_info.json') 123 | args = parser.parse_args() 124 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) 125 | 126 | np.random.seed(args.seed) 127 | torch.manual_seed(args.seed) 128 | random.seed(args.seed) 129 | 130 | return args 131 | 132 | 133 | def gen_task(formatted_scene_graph_path, vg_image_data_path, questions_path, features_path, n=100, q_length=None): 134 | 135 | with open(questions_path, 'rb') as questions_file: 136 | questions = pickle.load(questions_file) 137 | 138 | with open(formatted_scene_graph_path, 'rb') as vg_scene_graphs_file: 139 | scene_graphs = pickle.load(vg_scene_graphs_file) 140 | 141 | with open(vg_image_data_path, 'r') as vg_image_data_file: 142 | image_data = json.load(vg_image_data_file) 143 | 144 | features = np.load(features_path, allow_pickle=True) 145 | features = features.item() 146 | 147 | image_data = to_image_id_dict(image_data) 148 | 149 | for question in questions[:n]: 150 | 151 | if q_length is not None: 152 | current_len = len(question['clauses']) 153 | if not current_len == q_length: 154 | continue 155 | 156 | # functions = question["clauses"] 157 | image_id = question["image_id"] 158 | scene_graph = scene_graphs[image_id] 159 | cur_image_data = image_data[image_id] 160 | 161 | if not scene_graph['image_id'] == image_id or not cur_image_data['image_id'] == image_id: 162 | raise Exception("Mismatched image id") 163 | 164 | info = {} 165 | info['image_id'] = image_id 166 | info['scene_graph'] = scene_graph 167 | info['url'] = cur_image_data['url'] 168 | info['question'] = question 169 | info['object_feature'] = [] 170 | info['object_ids'] = [] 171 | 172 | for obj_id in scene_graph['names'].keys(): 173 | info['object_ids'].append(obj_id) 174 | info['object_feature'].append(features.get(obj_id)) 175 | 176 | yield(info) 177 | 178 | -------------------------------------------------------------------------------- /src/experiments/vqa/word_idx_translator.py: -------------------------------------------------------------------------------- 1 | """ 2 | The source code is based on: 3 | Scallop: From Probabilistic Deductive Databases to Scalable Differentiable Reasoning 4 | Jiani Huang, Ziyang Li, Binghong Chen, Karan Samel, Mayur Naik, Le Song, Xujie Si 5 | Advances in Neural Information Processing Systems 34 (NeurIPS 2021) 6 | https://proceedings.neurips.cc/paper/2021/hash/d367eef13f90793bd8121e2f675f0dc2-Abstract.html 7 | """ 8 | import json 9 | 10 | 11 | class Idx2Word(): 12 | 13 | def __init__(self, meta_info, use_canon=False): 14 | self.setup(meta_info) 15 | self.attr_canon = meta_info['attr']['canon'] 16 | self.name_canon = meta_info['name']['canon'] 17 | self.rela_canon = meta_info['rel']['canon'] 18 | 19 | self.attr_alias = meta_info['attr']['alias'] 20 | self.rela_alias = meta_info['rel']['alias'] 21 | 22 | self.attr_to_idx_dict = meta_info['attr']['idx'] 23 | self.rela_to_idx_dict = meta_info['rel']['idx'] 24 | self.name_to_idx_dict = meta_info['name']['idx'] 25 | self.use_canon = use_canon 26 | # print("here") 27 | 28 | def setup(self, meta_info): 29 | 30 | attr_to_idx = meta_info['attr']['idx'] 31 | rela_to_idx = meta_info['rel']['idx'] 32 | name_to_idx = meta_info['name']['idx'] 33 | 34 | attr_freq = meta_info['attr']['freq'] 35 | rela_freq = meta_info['rel']['freq'] 36 | name_freq = meta_info['name']['freq'] 37 | 38 | # 
attr_group = meta_info['attr']['group'] 39 | 40 | def setup_single(to_idx, freq, group=None): 41 | idx_to_name = {} 42 | for name in freq: 43 | if name not in to_idx: 44 | continue 45 | idx = to_idx[name] 46 | if type(idx) == list: 47 | if not idx[0] in idx_to_name.keys(): 48 | idx_to_name[idx[0]] = {} 49 | idx_to_name[idx[0]][idx[1]] = name 50 | else: 51 | idx_to_name[idx] = name 52 | return idx_to_name 53 | 54 | self.idx_to_name_dict = setup_single(name_to_idx, name_freq) 55 | self.idx_to_rela_dict = setup_single(rela_to_idx, rela_freq) 56 | self.idx_to_attr_dict = setup_single(attr_to_idx, attr_freq) 57 | # self.idx_to_attr_dict = setup_single(attr_to_idx, attr_freq, attr_group) 58 | 59 | def get_name_ct(self): 60 | return len(self.idx_to_name_dict) 61 | 62 | def get_rela_ct(self): 63 | return len(self.idx_to_rela_dict) 64 | 65 | def get_attr_ct(self): 66 | return len(self.idx_to_attr_dict) 67 | 68 | def get_names(self): 69 | return list(self.idx_to_name_dict.values()) 70 | 71 | def idx_to_name(self, idx): 72 | if idx is None: 73 | return None 74 | if type(idx) == str: 75 | return idx 76 | if len(self.idx_to_name_dict) == idx: 77 | return None 78 | if idx == -1: 79 | return None 80 | return self.idx_to_name_dict[idx] 81 | 82 | def idx_to_rela(self, idx): 83 | if idx is None: 84 | return None 85 | if idx == -1: 86 | return None 87 | if type(idx) == str: 88 | return idx 89 | if len(self.idx_to_rela_dict) == idx: 90 | return None 91 | return self.idx_to_rela_dict[idx] 92 | 93 | def idx_to_attr(self, idx): 94 | if idx is None: 95 | return None 96 | if type(idx) == str: 97 | return idx 98 | if len(self.idx_to_attr_dict) == idx: 99 | return None 100 | if idx == -1: 101 | return None 102 | # return self.idx_to_attr_dict[idx[0]][idx[1]] 103 | return self.idx_to_attr_dict[idx] 104 | 105 | def attr_to_idx(self, attr): 106 | if attr is None: 107 | return attr 108 | 109 | if self.use_canon: 110 | if attr in self.attr_canon.keys(): 111 | attr = self.attr_canon[attr] 112 | 113 | if attr in self.attr_alias.keys(): 114 | attr = self.attr_alias[attr] 115 | 116 | if attr not in self.attr_to_idx_dict.keys(): 117 | return None 118 | 119 | return self.attr_to_idx_dict[attr] 120 | 121 | def name_to_idx(self, name): 122 | 123 | if name is None: 124 | return name 125 | 126 | if self.use_canon: 127 | if name in self.name_canon.keys(): 128 | name = self.name_canon[name] 129 | 130 | if name not in self.name_to_idx_dict.keys(): 131 | return None 132 | 133 | return self.name_to_idx_dict[name] 134 | 135 | def rela_to_idx(self, rela): 136 | if rela is None: 137 | return rela 138 | 139 | if self.use_canon: 140 | if rela in self.rela_canon.keys(): 141 | rela = self.rela_canon[rela] 142 | 143 | if rela in self.rela_alias.keys(): 144 | rela = self.rela_alias[rela] 145 | 146 | if rela not in self.rela_to_idx_dict.keys(): 147 | return None 148 | 149 | return self.rela_to_idx_dict[rela] 150 | 151 | 152 | def process_program(program, meta_info): 153 | 154 | new_program = [] 155 | 156 | for clause in program: 157 | new_clause = {} 158 | new_clause['function'] = clause['function'] 159 | 160 | if 'output' in clause.keys(): 161 | new_clause['output'] = clause['output'] 162 | 163 | if clause['function'] == "Hypernym_Find": 164 | name = clause['text_input'][0] 165 | attr = clause['text_input'][1] 166 | 167 | new_clause['text_input'] = [name, attr] + clause['text_input'][2:] 168 | 169 | elif clause['function'] == "Find": 170 | name = clause['text_input'][0] 171 | new_clause['text_input'] = [name] + clause['text_input'][1:] 172 | 
173 |         elif clause['function'] == "Relate_Reverse":
174 |             relation = clause['text_input']
175 |             new_clause['text_input'] = relation
176 | 
177 |         elif clause['function'] == "Relate":
178 |             relation = clause['text_input']
179 |             new_clause['text_input'] = relation
180 | 
181 |         else:
182 |             if 'text_input' in clause.keys():
183 |                 new_clause['text_input'] = clause['text_input']
184 | 
185 |         new_program.append(new_clause)
186 | 
187 |     return new_program
188 | 
189 | 
190 | def process_questions(questions_path, new_question_path, meta_info):
191 |     new_questions = {}
192 | 
193 |     with open(questions_path, 'r') as questions_file:
194 |         questions = json.load(questions_file)
195 | 
196 |     # process questions
197 |     for question in questions:
198 | 
199 |         image_id = question["image_id"]
200 | 
201 |         # group the processed questions by image id
202 |         if image_id not in new_questions.keys():
203 |             new_questions[image_id] = {}
204 |         new_question = new_questions[image_id]
205 | 
206 |         new_question['question_id'] = question['question_id']
207 |         new_question['program'] = process_program(
208 |             question["program"], meta_info)
209 | 
210 |         program = question['program']
211 |         new_question['target'] = program[-2]["output"]
212 |         new_question['question'] = question['question']
213 |         new_question['answer'] = question['answer']
214 | 
215 |     with open(new_question_path, 'w') as new_question_file:
216 |         json.dump(new_questions, new_question_file)
217 | 
218 | 
--------------------------------------------------------------------------------
/src/slot_attention_module.py:
--------------------------------------------------------------------------------
1 | """
2 | Slot attention model based on the code of tkipf and the corresponding paper Locatello et al. 2020
3 | """
4 | from torch import nn
5 | import torch
6 | import torch.nn.functional as F
7 | import torchvision.models as models
8 | import numpy as np
9 | from torchsummary import summary
10 | 
11 | 
12 | def build_grid(resolution):
13 |     ranges = [np.linspace(0., 1., num=res) for res in resolution]
14 |     grid = np.meshgrid(*ranges, sparse=False, indexing="ij")
15 |     grid = np.stack(grid, axis=-1)
16 |     grid = np.reshape(grid, [resolution[0], resolution[1], -1])
17 |     grid = np.expand_dims(grid, axis=0)
18 |     grid = grid.astype(np.float32)
19 |     return np.concatenate([grid, 1.0 - grid], axis=-1)
20 | 
21 | 
22 | def spatial_broadcast(slots, resolution):
23 |     """Broadcast slot features to a 2D grid and collapse slot dimension."""
24 |     # `slots` has shape: [batch_size, num_slots, slot_size].
25 |     slots = torch.reshape(slots, [slots.shape[0] * slots.shape[1], 1, 1, slots.shape[2]])
26 | 
27 |     grid = slots.repeat(1, resolution[0], resolution[1], 1)  # repeat expands the data along different dimensions
28 |     # `grid` has shape: [batch_size*num_slots, width, height, slot_size].
29 |     return grid
30 | 
31 | 
32 | def unstack_and_split(x, batch_size, n_slots, num_channels=3):
33 |     """Unstack batch dimension and split into channels and alpha mask."""
34 |     # unstacked = torch.reshape(x, [batch_size, -1] + list(x.shape[1:]))
35 |     # channels, masks = torch.split(unstacked, [num_channels, 1], dim=-1)
36 |     unstacked = torch.reshape(x, [batch_size, n_slots] + list(x.shape[1:]))
37 |     channels, masks = torch.split(unstacked, [num_channels, 1], dim=2)
38 |     return channels, masks
39 | 
40 | 
41 | class SlotAttention(nn.Module):
42 |     def __init__(self, num_slots, dim, iters=3, eps=1e-8, hidden_dim=128):
43 |         super().__init__()
44 |         self.num_slots = num_slots
45 |         self.iters = iters
46 |         self.eps = eps
47 |         self.scale = dim ** -0.5  # 1/sqrt(D) scaling from the paper
48 | 
49 |         self.slots_mu = nn.Parameter(torch.randn(1, 1, dim))  # randomly initialize mu and sigma
50 |         self.slots_log_sigma = nn.Parameter(torch.randn(1, 1, dim)).abs().to(device='cuda')
51 |         #self.slots_mu = nn.Parameter(torch.nn.init.xavier_uniform_(torch.empty(1,1,dim), gain=1.0)) #randomly initialize sigma and mu
52 |         #self.slots_log_sigma = nn.Parameter(torch.nn.init.xavier_uniform_(torch.empty(1,1,dim), gain=1.0))
53 | 
54 |         self.project_q = nn.Linear(dim, dim, bias=True)  # query projection
55 |         self.project_k = nn.Linear(dim, dim, bias=True)  # key projection
56 |         self.project_v = nn.Linear(dim, dim, bias=True)  # value projection
57 | 
58 |         self.gru = nn.GRUCell(dim, dim)
59 | 
60 |         hidden_dim = max(dim, hidden_dim)
61 | 
62 |         self.mlp = nn.Sequential(
63 |             nn.Linear(dim, hidden_dim),
64 |             nn.ReLU(inplace=True),
65 |             nn.Linear(hidden_dim, dim)
66 |         )
67 | 
68 |         self.norm_inputs = nn.LayerNorm(dim, eps=1e-05)
69 |         self.norm_slots = nn.LayerNorm(dim, eps=1e-05)
70 |         self.norm_mlp = nn.LayerNorm(dim, eps=1e-05)
71 | 
72 |         self.attn = 0
73 | 
74 |     def forward(self, inputs, num_slots=None):
75 |         b, n, d = inputs.shape  # b is the batch size, n is the number of input feature vectors, d is their dimensionality (e.g. [15, 1024, 32])
76 |         n_s = num_slots if num_slots is not None else self.num_slots
77 | 
78 |         mu = self.slots_mu.expand(b, n_s, -1)  # mu and sigma are shared by all slots
79 |         sigma = self.slots_log_sigma.expand(b, n_s, -1)
80 |         slots = torch.normal(mu, sigma)  # sample slots from mu and sigma
81 |         #slots = torch.normal(mu, sigma.exp()) #sample slots from mu and sigma
82 | 
83 | 
84 |         inputs = self.norm_inputs(inputs)  # layer normalization of inputs
85 |         k, v = self.project_k(inputs), self.project_v(inputs)  # *self.scale
86 | 
87 | 
88 |         for _ in range(self.iters):
89 |             slots_prev = slots  # store old slots
90 | 
91 |             slots = self.norm_slots(slots)  # layer norm of slots
92 |             q = self.project_q(slots)  # emit a query for each slot
93 | 
94 |             dots = torch.einsum('bid,bjd->bij', q, k) * self.scale  # attention logits M from the paper, shape [batch, num_slots, num_inputs] (e.g. 7 x 1024)
95 |             attn = dots.softmax(dim=1) + self.eps  # calculate the softmax over the slot dimension, so the slots compete for the inputs
96 |             attn = attn / attn.sum(dim=-1, keepdim=True)  # normalize over the inputs to obtain weighted-mean coefficients
97 | 
98 |             updates = torch.einsum('bjd,bij->bid', v, attn)
99 | 
100 |             # recurrently update the slots with the slot updates and the previous slots
101 |             slots = self.gru(
102 |                 updates.reshape(-1, d),
103 |                 slots_prev.reshape(-1, d)
104 |             )
105 | 
106 |             # apply a 2-layer ReLU MLP to the GRU output
107 |             slots = slots.reshape(b, -1, d)
108 |             slots = slots + self.mlp(self.norm_mlp(slots))
109 | 
110 |         self.attn = attn
111 | 
112 |         return slots
113 | 
114 | 
115 | class SlotAttention_encoder(nn.Module):
116 |     def __init__(self, in_channels, hidden_channels,
                 clevr_encoding):
117 |         super(SlotAttention_encoder, self).__init__()
118 | 
119 |         if clevr_encoding:
120 |             self.network = nn.Sequential(
121 |                 nn.Conv2d(in_channels, hidden_channels, (5, 5), stride=(1, 1), padding=2),
122 |                 nn.ReLU(inplace=True),
123 |                 nn.Conv2d(hidden_channels, hidden_channels, (5, 5), stride=(2, 2), padding=2),
124 |                 nn.ReLU(inplace=True),
125 |                 nn.Conv2d(hidden_channels, hidden_channels, (5, 5), stride=(2, 2), padding=2),
126 |                 nn.ReLU(inplace=True),
127 |                 nn.Conv2d(hidden_channels, hidden_channels, (5, 5), stride=(1, 1), padding=2),
128 |                 nn.ReLU(inplace=True))
129 |         else:
130 |             self.network = nn.Sequential(
131 |                 nn.Conv2d(in_channels, hidden_channels, (5, 5), stride=(1, 1), padding=2),
132 |                 nn.ReLU(inplace=True),
133 |                 nn.Conv2d(hidden_channels, hidden_channels, (5, 5), stride=(1, 1), padding=2),
134 |                 nn.ReLU(inplace=True),
135 |                 nn.Conv2d(hidden_channels, hidden_channels, (5, 5), stride=(1, 1), padding=2),
136 |                 nn.ReLU(inplace=True),
137 |                 nn.Conv2d(hidden_channels, hidden_channels, (5, 5), stride=(1, 1), padding=2),
138 |                 nn.ReLU(inplace=True))
139 | 
140 | 
141 | 
142 | 
143 |     def forward(self, x):
144 |         return self.network(x)
145 | 
146 | 
147 | class MLP(nn.Module):
148 |     def __init__(self, hidden_channels):
149 |         super(MLP, self).__init__()
150 |         self.network = nn.Sequential(
151 |             nn.Linear(hidden_channels, hidden_channels),
152 |             nn.ReLU(inplace=True),
153 |             nn.Linear(hidden_channels, hidden_channels),
154 |         )
155 | 
156 |     def forward(self, x):
157 |         return self.network(x)
158 | 
159 | 
160 | class SoftPositionEmbed(nn.Module):
161 |     """Adds soft positional embedding with learnable projection."""
162 | 
163 |     def __init__(self, hidden_size, resolution, device="cuda:0"):
164 |         """Builds the soft position embedding layer.
165 |         Args:
166 |             hidden_size: Size of input feature dimension.
167 |             resolution: Tuple of integers specifying width and height of grid.
168 | """ 169 | super().__init__() 170 | self.dense = nn.Linear(4, hidden_size) 171 | # self.grid = torch.FloatTensor(build_grid(resolution)) 172 | # self.grid = self.grid.to(device) 173 | # for nn.DataParallel 174 | self.register_buffer("grid", torch.FloatTensor(build_grid(resolution))) 175 | self.resolution = resolution[0] 176 | self.hidden_size = hidden_size 177 | 178 | def forward(self, inputs): 179 | return inputs + self.dense(self.grid).view((-1, self.hidden_size, self.resolution, self.resolution)) 180 | 181 | 182 | class SlotAttention_classifier(nn.Module): 183 | def __init__(self, in_channels, out_channels): 184 | super(SlotAttention_classifier, self).__init__() 185 | self.network = nn.Sequential( 186 | nn.Linear(in_channels, in_channels), # nn.Conv1d(in_channels, in_channels, 1, stride=1, groups=in_channels) 187 | nn.ReLU(inplace=True), 188 | nn.Linear(in_channels, out_channels), 189 | nn.Sigmoid() 190 | ) 191 | 192 | def forward(self, x): 193 | return self.network(x) 194 | 195 | 196 | class SlotAttention_model(nn.Module): 197 | def __init__(self, n_slots, n_iters, n_attr, 198 | in_channels=3, 199 | encoder_hidden_channels=64, 200 | attention_hidden_channels=128, 201 | mlp_prediction = False, 202 | device="cuda", 203 | clevr_encoding=False): 204 | super(SlotAttention_model, self).__init__() 205 | self.n_slots = n_slots 206 | self.n_iters = n_iters 207 | self.n_attr = n_attr 208 | self.n_attr = n_attr + 1 # additional slot to indicate if it is a object or empty slot 209 | self.device = device 210 | 211 | self.encoder_cnn = SlotAttention_encoder(in_channels=in_channels, hidden_channels=encoder_hidden_channels , clevr_encoding=clevr_encoding) 212 | self.encoder_pos = SoftPositionEmbed(encoder_hidden_channels, (32, 32), device=device)# changed from 128* 128 213 | self.layer_norm = nn.LayerNorm(encoder_hidden_channels, eps=1e-05) 214 | self.mlp = MLP(hidden_channels=encoder_hidden_channels) 215 | self.slot_attention = SlotAttention(num_slots=n_slots, dim=encoder_hidden_channels, iters=n_iters, eps=1e-8, 216 | hidden_dim=attention_hidden_channels) 217 | 218 | #for set prediction baseline 219 | self.mlp_prediction = mlp_prediction 220 | self.mlp_classifier = SlotAttention_classifier(in_channels=encoder_hidden_channels, out_channels=self.n_attr) 221 | 222 | self.softmax = nn.Softmax(dim=1) 223 | 224 | def forward(self, img): 225 | # `x` has shape: [batch_size, width, height, num_channels]. 226 | 227 | # SLOT ATTENTION ENCODER 228 | x = self.encoder_cnn(img) 229 | x = self.encoder_pos(x) 230 | x = torch.flatten(x, start_dim=2) 231 | 232 | # permute channel dimensions 233 | x = x.permute(0, 2, 1) 234 | x = self.layer_norm(x) 235 | x = self.mlp(x) 236 | 237 | slots = self.slot_attention(x) 238 | # slots has shape: [batch_size, num_slots, slot_size]. 
239 |         if self.mlp_prediction:
240 |             x = self.mlp_classifier(slots)
241 |             return x
242 |         else:
243 |             return slots
244 | 
245 | 
246 | if __name__ == "__main__":
247 |     x = torch.rand(15, 3, 32, 32).cuda()
248 |     net = SlotAttention_model(n_slots=11, n_iters=3, n_attr=18,
249 |                               encoder_hidden_channels=32,
250 |                               attention_hidden_channels=64)
251 |     net = net.cuda()
252 |     output = net(x)
253 |     summary(net, (3, 32, 32))
254 | 
255 | 
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import torch
4 | import errno
5 | from PIL import Image
6 | import torch
7 | from torch.utils.tensorboard import SummaryWriter
8 | import time
9 | import matplotlib
10 | import matplotlib.pyplot as plt
11 | from matplotlib.ticker import MaxNLocator
12 | import seaborn as sns
13 | import random
14 | 
15 | 
16 | def mkdir_p(path):
17 |     """Linux mkdir -p"""
18 |     try:
19 |         os.makedirs(path)
20 |     except OSError as exc:  # Python >2.5
21 |         if exc.errno == errno.EEXIST and os.path.isdir(path):
22 |             pass
23 |         else:
24 |             raise
25 | 
26 | 
27 | def one_hot(x, K, dtype=torch.float):
28 |     """One hot encoding"""
29 |     with torch.no_grad():
30 |         ind = torch.zeros(x.shape + (K,), dtype=dtype, device=x.device)
31 |         ind.scatter_(-1, x.unsqueeze(-1), 1)
32 |         return ind
33 | 
34 | 
35 | def save_image_stack(samples, num_rows, num_columns, filename, margin=5, margin_gray_val=1., frame=0, frame_gray_val=0.0):
36 |     """Save image stack in a tiled image"""
37 | 
38 |     # for gray scale, convert to rgb
39 |     if len(samples.shape) == 3:
40 |         samples = np.stack((samples,) * 3, -1)
41 | 
42 |     height = samples.shape[1]
43 |     width = samples.shape[2]
44 | 
45 |     samples -= samples.min()
46 |     samples /= samples.max()
47 | 
48 |     img = margin_gray_val * np.ones((height*num_rows + (num_rows-1)*margin, width*num_columns + (num_columns-1)*margin, 3))
49 |     for h in range(num_rows):
50 |         for w in range(num_columns):
51 |             img[h*(height+margin):h*(height+margin)+height, w*(width+margin):w*(width+margin)+width, :] = samples[h*num_columns + w, :]
52 | 
53 |     framed_img = frame_gray_val * np.ones((img.shape[0] + 2*frame, img.shape[1] + 2*frame, 3))
54 |     framed_img[frame:(frame+img.shape[0]), frame:(frame+img.shape[1]), :] = img
55 | 
56 |     img = Image.fromarray(np.round(framed_img * 255.).astype(np.uint8))
57 | 
58 |     img.save(filename)
59 | 
60 | 
61 | def sample_matrix_categorical(p):
62 |     """Sample many Categorical distributions represented as rows in a matrix."""
63 |     with torch.no_grad():
64 |         cp = torch.cumsum(p[:, 0:-1], -1)
65 |         rand = torch.rand((cp.shape[0], 1), device=cp.device)
66 |         rand_idx = torch.sum(rand > cp, -1).long()
67 |         return rand_idx
68 | 
69 | 
70 | def set_manual_seed(seed: int = 1):
71 |     """Set the seed for the PRNGs."""
72 |     os.environ['PYTHONHASHSEED'] = str(seed)
73 |     random.seed(seed)
74 |     np.random.seed(seed)
75 |     torch.manual_seed(seed)
76 |     if torch.cuda.is_available():
77 |         torch.cuda.manual_seed(seed)
78 |         torch.cuda.manual_seed_all(seed)
79 |         torch.backends.cudnn.benchmark = True
80 | 
81 | 
82 | def time_delta_now(t_start: float, simple_format=False, ret_sec=False) -> str:
83 |     """
84 | 
85 |     Convert a timestamp into a human readable timestring.
86 |     Parameters
87 |     ----------
88 |     t_start : float
89 |         The timestamp describing the begin of any event.
90 |     Returns
91 |     -------
92 |     Human readable timestring.
93 | """ 94 | a = t_start 95 | b = time.time() # current epoch time 96 | c = b - a # seconds 97 | days = round(c // 86400) 98 | hours = round(c // 3600 % 24) 99 | minutes = round(c // 60 % 60) 100 | seconds = round(c % 60) 101 | millisecs = round(c % 1 * 1000) 102 | if simple_format: 103 | return f"{hours}h:{minutes}m:{seconds}s" 104 | 105 | return f"{days} days, {hours} hours, {minutes} minutes, {seconds} seconds, {millisecs} milliseconds", c 106 | 107 | def time_delta(c: float, simple_format=False,) -> str: 108 | c# seconds 109 | days = round(c // 86400) 110 | hours = round(c // 3600 % 24) 111 | minutes = round(c // 60 % 60) 112 | seconds = round(c % 60) 113 | millisecs = round(c % 1 * 1000) 114 | if simple_format: 115 | return f"{hours}h:{minutes}m:{seconds}s" 116 | return f"{days} days, {hours} hours, {minutes} minutes, {seconds} seconds, {millisecs} milliseconds", c 117 | 118 | 119 | def export_results(test_accuracy_list, train_accuracy_list, 120 | export_path, export_suffix, 121 | confusion_matrix , exp_dict): 122 | 123 | #set matplotlib styles 124 | plt.style.use(['science','grid']) 125 | matplotlib.rcParams.update( 126 | { 127 | "font.family": "serif", 128 | "text.usetex": False, 129 | "legend.fontsize": 22 130 | } 131 | ) 132 | 133 | 134 | fig, axs = plt.subplots(1, 1 , figsize=(10,21)) 135 | # fig.suptitle(exp_dict['exp_name'], fontsize=16) 136 | #axs[0] 137 | axs.plot(test_accuracy_list[:,1],test_accuracy_list[:,0], label='test accuracy') 138 | #axs[0] 139 | axs.plot(train_accuracy_list[:,1],train_accuracy_list[:,0], label='train accuracy') 140 | #axs[0]. 141 | axs.legend(loc="lower right") 142 | #axs[0]. 143 | axs.set(xlabel='epochs', ylabel='accuracy') 144 | #axs[0]. 145 | axs.xaxis.set_major_locator(MaxNLocator(integer=True)) 146 | 147 | 148 | #ax.[0, 1].set_xticklabels([0,1,2,3,4,5,6,7,8,9]) 149 | #ax.[0, 1].set_yticklabels([0,1,2,3,4,5,6,7,8,9]) 150 | 151 | #axs[0, 1].set(xlabel='target', ylabel='prediction') 152 | #axs[1] = sns.heatmap(confusion_matrix, linewidths=2, cmap="viridis") 153 | 154 | 155 | 156 | if exp_dict['structure'] == 'poon-domingos': 157 | text = "trainable parameters = {}, lr = {}, batchsize= {}, time_per_epoch= {},\n structure= {}, pd_num_pieces = {}".format( 158 | exp_dict['num_trainable_params'], exp_dict['lr'], 159 | exp_dict['bs'], exp_dict['train_time'], 160 | exp_dict['structure'], exp_dict['pd_num_pieces'] ) 161 | 162 | else: 163 | text = "trainable parameters = {}, lr = {}, batchsize= {}, time_per_epoch= {},\n structure= {}, num_repetitions = {} , depth = {}".format( 164 | exp_dict['num_trainable_params'], exp_dict['lr'], 165 | exp_dict['bs'],exp_dict['train_time'], 166 | exp_dict['structure'], exp_dict['num_repetitions'], exp_dict['depth']) 167 | 168 | #plt.gcf().text( 169 | #0.5, 170 | #0.02, 171 | #text, 172 | #ha="center", 173 | #fontsize=12, 174 | #linespacing=1.5, 175 | #bbox=dict( 176 | # facecolor="grey", alpha=0.2, edgecolor="black", boxstyle="round,pad=1" 177 | #)) 178 | 179 | fig.savefig(export_path, format="svg") 180 | 181 | plt.show() 182 | 183 | 184 | #Tensorboard outputs 185 | writer = SummaryWriter("../../results", filename_suffix=export_suffix) 186 | 187 | for train_acc_elem, test_acc_elem in zip(train_accuracy_list, test_accuracy_list): 188 | writer.add_scalar('Accuracy/train', train_acc_elem[0], train_acc_elem[1]) 189 | writer.add_scalar('Accuracy/test', test_acc_elem[0], test_acc_elem[1]) 190 | 191 | 192 | --------------------------------------------------------------------------------