├── Dockerfile
├── LICENSE
├── README.md
├── environment.yml
├── imgs
│   └── slash_icon.png
├── requirements.txt
├── setup.sh
└── src
    ├── README.md
    ├── SLASH
    │   ├── __init__.py
    │   ├── mvpp.py
    │   └── slash.py
    ├── __init__.py
    ├── datasets.py
    ├── einsum_wrapper.py
    ├── experiments
    │   ├── __init__.py
    │   ├── baseline_slot_attention
    │   │   ├── dataGen.py
    │   │   ├── set_utils.py
    │   │   ├── train.py
    │   │   └── train.sh
    │   ├── mnist_top_k
    │   │   ├── __init__.py
    │   │   ├── dataGen.py
    │   │   ├── data_generation.ipynb
    │   │   ├── network_nn.py
    │   │   ├── train.py
    │   │   ├── train_baseline.sh
    │   │   ├── train_same.sh
    │   │   └── train_top_k.sh
    │   ├── slash_attention
    │   │   ├── ap_utils.py
    │   │   ├── clevr
    │   │   │   ├── __init__.py
    │   │   │   ├── auxiliary.py
    │   │   │   ├── dataGen.py
    │   │   │   ├── slash_attention_clevr.py
    │   │   │   └── train.py
    │   │   ├── clevr_cogent
    │   │   │   ├── __init__.py
    │   │   │   ├── auxiliary.py
    │   │   │   ├── dataGen.py
    │   │   │   ├── slash_attention_clevr.py
    │   │   │   └── train.py
    │   │   ├── cogent
    │   │   │   ├── __init__.py
    │   │   │   ├── dataGen.py
    │   │   │   ├── slash_attention_cogent.py
    │   │   │   └── train.py
    │   │   └── shapeworld4
    │   │       ├── __init__.py
    │   │       ├── dataGen.py
    │   │       ├── slash_attention_shapeworld4.py
    │   │       └── train.py
    │   └── vqa
    │       ├── __init__.py
    │       ├── cmd_args2.py
    │       ├── dataGen.py
    │       ├── knowledge_graph.py
    │       ├── models.py
    │       ├── network_nn.py
    │       ├── preprocess.py
    │       ├── query_lib.py
    │       ├── sg_model.py
    │       ├── test.py
    │       ├── test.sh
    │       ├── train.py
    │       ├── train.sh
    │       ├── trainer.py
    │       ├── vqa_utils.py
    │       └── word_idx_translator.py
    ├── slot_attention_module.py
    └── utils.py
/Dockerfile:
--------------------------------------------------------------------------------
1 | # Select the base image
2 | #old 21.06-py3
3 | #FROM nvcr.io/nvidia/pytorch:21.12-py3
4 | FROM nvcr.io/nvidia/pytorch:21.06-py3
5 |
6 |
7 | # Select the working directory
8 | WORKDIR /SLASH
9 |
10 | # Install PyTorch
11 | #RUN pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
12 |
13 | # Install Python requirements
14 | COPY ./requirements.txt ./requirements.txt
15 | RUN pip install --ignore-installed -r requirements.txt
16 | # RUN conda install -c potassco clingo=5.5.0
17 | # RUN pip install git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git ; pip install rtpt torchsummary ; python -m pip install -U scikit-image
18 | # RUN conda install -c conda-forge tqdm
19 | # RUN conda install -c anaconda seaborn
20 | # RUN conda install -c conda-forge scikit-learn
21 | # RUN conda install -c conda-forge tensorboard
22 |
23 | # Remove fontlist file so MPL can find the serif fonts
24 | RUN rm -rf /.matplotlib/fontlist-v330.json
25 |
26 | # Setup mpl config dir since SciencePlots install .mplstyle files into this dir
27 | RUN mkdir -p /.matplotlib/stylelib
28 | RUN chmod a+rwx -R /.matplotlib
29 | ENV MPLCONFIGDIR=/.matplotlib
30 |
31 | # Add fonts for serif rendering in MPL plots
32 | RUN apt-get update
33 | RUN echo ttf-mscorefonts-installer msttcorefonts/accepted-mscorefonts-eula select true | debconf-set-selections
34 | RUN apt-get install --yes ttf-mscorefonts-installer
35 | RUN ln -snf /usr/share/zoneinfo/Etc/UTC /etc/localtime
36 | RUN apt-get install dvipng cm-super fonts-cmu --yes
37 | RUN apt-get install fonts-dejavu-core --yes
38 | RUN apt install -y tmux
39 | # RUN fc-cache -f -v
40 |
41 | RUN python3 -m pip install setuptools==59.5.0
42 | RUN pip install pandas==1.3.0
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 ml-research@TUDarmstadt
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Scalable Neural-Probabilistic Answer Set Programming
2 |
3 | Arseny Skryagin, Wolfgang Stammer, Daniel Ochs, Devendra Singh Dhami, Kristian Kersting
4 |
5 |
6 |
7 |
8 |
9 | [License: MIT](https://opensource.org/licenses/MIT)
10 |
11 | # Abstract
12 | The goal of combining the robustness of neural networks and the expressiveness of symbolic
13 | methods has rekindled the interest in Neuro-Symbolic AI. Deep Probabilistic Programming
14 | Languages (DPPLs) have been developed for probabilistic logic programming to be carried
15 | out via the probability estimations of deep neural networks. However, recent SOTA DPPL
16 | approaches allow only for limited conditional probabilistic queries and do not offer the power
17 | of true joint probability estimation. In our work, we propose an easy integration of tractable
18 | probabilistic inference within a DPPL. To this end, we introduce SLASH, a novel DPPL
19 | that consists of Neural-Probabilistic Predicates (NPPs) and a logic program, united via
20 | answer set programming (ASP). NPPs are a novel design principle allowing for combining
21 | all deep model types and combinations thereof to be represented as a single probabilistic
22 | predicate. In this context, we introduce a novel +/− notation for answering various types
23 | of probabilistic queries by adjusting the atom notations of a predicate. To scale well, we
24 | show how to prune the stochastically insignificant parts of the (ground) program, speeding
25 | up reasoning without sacrificing the predictive performance. We evaluate SLASH on a
26 | variety of different tasks, including the benchmark task of MNIST addition and Visual
27 | Question Answering (VQA).
28 |
29 |
30 |
31 |
32 | ## Introduction
33 | This is the repository for SLASH, the deep declarative probabilistic programming language introduced in **Neural-Probabilistic Answer Set Programming** ([KR 2022](https://kr2022.cs.tu-dortmund.de/index.php)) and **Scalable Neural-Probabilistic Answer Set Programming** ([JAIR](https://jair.org/index.php/jair/article/view/15027)). [Link](https://jair.org/index.php/jair/article/view/15027) to the paper.
34 |
35 |
36 |
37 | ## 1. Prerequisites
38 | ### 1.1 Cloning SLASH and submodules
39 | To clone SLASH including all submodules use:
40 | ```
41 | git clone --recurse-submodules -j8 https://github.com/ml-research/SLASH
42 | ```
43 |
44 | ### 1.2 Anaconda
45 | The `environment.yml` file provides all packages needed to run a SLASH program using the Anaconda package manager. Install Anaconda ([installation page](https://docs.anaconda.com/anaconda/install/)) and then create and activate a new environment using the following commands:
46 | ```
47 | conda env create -n slash -f environment.yml
48 | conda activate slash
49 | pip install git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git
50 | ```
51 |
52 | The following packages are installed:
53 | - pytorch
54 | - clingo # version 5.5.1
55 | - scikit-image
56 | - scikit-learn
57 | - seaborn
58 | - tqdm
59 | - rtpt
60 | - tensorboard # needed for standalone slot attention
61 | - torchsummary
62 | - GradualWarmupScheduler
63 |
64 | ### 1.3 Virtual environment
65 | To use virtualenv, run the following commands. This creates a new virtual environment and installs all required packages. Tested with Python 3.6. To use CUDA version `10.x` with PyTorch, remove the `+cu113` version suffix in the `requirements.txt` file; otherwise CUDA version `11.3` will be used.
66 | ```
67 | python3 -m venv slash_env
68 | source slash_env/bin/activate
69 | pip install --upgrade pip
70 | python3 -m pip install -r requirements.txt -f https://download.pytorch.org/whl/torch_stable.html
71 | ```
72 |
73 | ### 1.4 Docker
74 | Alternatively, you can run SLASH using Docker. For this, you first need to build an image using the provided `Dockerfile`.
75 |
76 | To start, build an image named `slash` and then run a container using that image. In the SLASH base folder, execute:
77 | ```
78 | docker build . -t slash:1.0
79 | docker run --runtime=nvidia -e NVIDIA_VISIBLE_DEVICES= --ipc=host -it --rm -v /$(pwd):/slash slash:1.0
80 | ```
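For example, to expose only the first GPU to the container you can set `NVIDIA_VISIBLE_DEVICES` to the desired device id (a usage sketch of the command above; adjust the device id and mount path to your setup):
```
docker run --runtime=nvidia -e NVIDIA_VISIBLE_DEVICES=0 --ipc=host -it --rm -v /$(pwd):/slash slash:1.0
```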
81 |
82 |
83 | ## 2. Project Structure
84 | ```
85 | --data: contains datasets after downloading
86 | --src: source files. This folder has its own README; go there for more details.
87 | ```
88 |
89 |
90 | ## 3. Getting started
91 | Visit the `src/experiments/mnist_top_k` folder to get familiar with SLASH and run one of the training scripts.
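For example (a minimal sketch using one of the shell scripts listed in the project tree; check the script for the arguments it expects before running):
```
cd src/experiments/mnist_top_k
bash train_top_k.sh
```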
92 |
93 |
94 | ## Citation
95 |
96 | If you use this code for your research, please cite our journal paper at [JAIR](https://jair.org/index.php/jair/article/view/15027) or the conference paper at [KR22](https://proceedings.kr.org/2022/48/):
97 | ```
98 | @article{skryagin23JAIR,
99 | year = { 2023 },
100 | crossref = { https://github.com/ml-research/SLASH },
101 | title = { Scalable Neural-Probabilistic Answer Set Programming },
102 | pages = { 579--617 },
103 | volume = { 78 },
104 | journal = { Journal of Artificial Intelligence Research (JAIR) },
105 | author = {Skryagin, Arseny and Ochs, Daniel and Dhami, Devendra Singh and Kersting, Kristian},
106 | }
107 | ```
108 | ```
109 | @inproceedings{skryagin2022KR,
110 | title={Neural-Probabilistic Answer Set Programming},
111 | author={Arseny Skryagin and Wolfgang Stammer and Daniel Ochs and Devendra Singh Dhami and Kristian Kersting},
112 | booktitle={Proceedings of the 19th International Conference on Principles of Knowledge Representation and Reasoning (KR)},
113 | year={2022}
114 | }
115 | ```
116 |
117 | ## Acknowledgments
118 | This work was partly supported by the Federal Ministry of Education and Research (BMBF) and the Hessian Ministry of Science and the Arts (HMWK) within the National Research Center for Applied Cybersecurity ATHENE, the ICT-48 Network of AI Research Excellence Center "TAILOR" (EU Horizon 2020, GA No 952215), and the Collaboration Lab with Nexplore "AI in Construction" (AICO). It also benefited from the BMBF AI lighthouse project, the Hessian research priority programme LOEWE within the project WhiteBox, the HMWK cluster projects "The Third Wave of AI" and "The Adaptive Mind", and the German Research Center for Artificial Intelligence (DFKI) project "SAINT".
119 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: env_slash
2 | channels:
3 | - anaconda
4 | - pytorch
5 | - nvidia
6 | - potassco
7 | - conda-forge
8 | - defaults
9 | dependencies:
10 | - _libgcc_mutex=0.1=main
11 | - absl-py=0.12.0=py36h06a4308_0
12 | - aiohttp=3.7.4=py36h27cfd23_1
13 | - anyio=3.1.0=py36h5fab9bb_0
14 | - argon2-cffi=20.1.0=py36h1d69622_2
15 | - async-timeout=3.0.1=py36h06a4308_0
16 | - async_generator=1.10=py_0
17 | - attrs=21.2.0=pyhd3eb1b0_0
18 | - babel=2.9.1=pyh44b312d_0
19 | - backports=1.0=py_2
20 | - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0
21 | - blas=1.0=mkl
22 | - bleach=3.3.0=pyh44b312d_0
23 | - blinker=1.4=py36h06a4308_0
24 | - blosc=1.21.0=h8c45485_0
25 | - brotli=1.0.9=he6710b0_2
26 | - brotlipy=0.7.0=py36h27cfd23_1003
27 | - bzip2=1.0.8=h7b6447c_0
28 | - c-ares=1.17.1=h27cfd23_0
29 | - ca-certificates=2021.5.30=ha878542_0
30 | - cachetools=4.2.2=pyhd3eb1b0_0
31 | - certifi=2021.5.30=py36h5fab9bb_0
32 | - cffi=1.14.5=py36h261ae71_0
33 | - chardet=3.0.4=py36h06a4308_1003
34 | - charls=2.1.0=he6710b0_2
35 | - click=8.0.1=pyhd3eb1b0_0
36 | - clingo=5.4.0=py36h3fd9d12_1
37 | - cloudpickle=1.6.0=py_0
38 | - contextvars=2.4=py_0
39 | - coverage=5.5=py36h27cfd23_2
40 | - cryptography=3.4.7=py36hd23ed53_0
41 | - cudatoolkit=11.1.74=h6bb024c_0
42 | - cycler=0.10.0=py36_0
43 | - cython=0.29.23=py36h2531618_0
44 | - cytoolz=0.11.0=py36h7b6447c_0
45 | - dask-core=2021.3.0=pyhd3eb1b0_0
46 | - dataclasses=0.8=pyh4f3eec9_6
47 | - dbus=1.13.18=hb2f20db_0
48 | - decorator=5.0.9=pyhd3eb1b0_0
49 | - defusedxml=0.7.1=pyhd8ed1ab_0
50 | - entrypoints=0.3=pyhd8ed1ab_1003
51 | - expat=2.4.1=h2531618_2
52 | - ffmpeg=4.2.2=h20bf706_0
53 | - fontconfig=2.13.1=h6c09931_0
54 | - freetype=2.10.4=h5ab3b9f_0
55 | - giflib=5.1.4=h14c3975_1
56 | - glib=2.68.2=h36276a3_0
57 | - gmp=6.2.1=h2531618_2
58 | - gnutls=3.6.15=he1e5248_0
59 | - google-auth=1.30.1=pyhd3eb1b0_0
60 | - google-auth-oauthlib=0.4.1=py_2
61 | - grpcio=1.36.1=py36h2157cd5_1
62 | - gst-plugins-base=1.14.0=h8213a91_2
63 | - gstreamer=1.14.0=h28cd5cc_2
64 | - icu=58.2=he6710b0_3
65 | - idna=2.10=pyhd3eb1b0_0
66 | - idna_ssl=1.1.0=py36h06a4308_0
67 | - imagecodecs=2020.5.30=py36hfa7d478_2
68 | - imageio=2.9.0=pyhd3eb1b0_0
69 | - immutables=0.15=py36h27cfd23_0
70 | - importlib-metadata=3.10.0=py36h06a4308_0
71 | - intel-openmp=2021.2.0=h06a4308_610
72 | - ipykernel=5.5.5=py36hcb3619a_0
73 | - ipython=5.8.0=py36_1
74 | - ipython_genutils=0.2.0=py_1
75 | - jinja2=2.11.3=pyh44b312d_0
76 | - joblib=1.0.1=pyhd3eb1b0_0
77 | - jpeg=9b=h024ee3a_2
78 | - json5=0.9.5=pyh9f0ad1d_0
79 | - jsonschema=3.2.0=pyhd8ed1ab_3
80 | - jupyter_client=6.1.12=pyhd8ed1ab_0
81 | - jupyter_core=4.7.1=py36h5fab9bb_0
82 | - jupyter_server=1.8.0=pyhd8ed1ab_0
83 | - jupyterlab=3.0.16=pyhd8ed1ab_0
84 | - jupyterlab_pygments=0.1.2=pyh9f0ad1d_0
85 | - jupyterlab_server=2.6.0=pyhd8ed1ab_0
86 | - jxrlib=1.1=h7b6447c_2
87 | - kiwisolver=1.3.1=py36h2531618_0
88 | - lame=3.100=h7b6447c_0
89 | - lcms2=2.12=h3be6417_0
90 | - ld_impl_linux-64=2.33.1=h53a641e_7
91 | - libaec=1.0.4=he6710b0_1
92 | - libffi=3.3=he6710b0_2
93 | - libgcc-ng=9.1.0=hdf63c60_0
94 | - libgfortran-ng=7.3.0=hdf63c60_0
95 | - libidn2=2.3.1=h27cfd23_0
96 | - libopus=1.3.1=h7b6447c_0
97 | - libpng=1.6.37=hbc83047_0
98 | - libprotobuf=3.14.0=h8c45485_0
99 | - libsodium=1.0.18=h36c2ea0_1
100 | - libstdcxx-ng=9.1.0=hdf63c60_0
101 | - libtasn1=4.16.0=h27cfd23_0
102 | - libtiff=4.1.0=h2733197_1
103 | - libunistring=0.9.10=h27cfd23_0
104 | - libuuid=1.0.3=h1bed415_2
105 | - libuv=1.40.0=h7b6447c_0
106 | - libvpx=1.7.0=h439df22_0
107 | - libwebp=1.0.1=h8e7db2f_0
108 | - libxcb=1.14=h7b6447c_0
109 | - libxml2=2.9.10=hb55368b_3
110 | - libzopfli=1.0.3=he6710b0_0
111 | - lz4-c=1.9.3=h2531618_0
112 | - markdown=3.3.4=py36h06a4308_0
113 | - markupsafe=1.1.1=py36he6145b8_2
114 | - matplotlib=3.3.4=py36h06a4308_0
115 | - matplotlib-base=3.3.4=py36h62a2d02_0
116 | - mistune=0.8.4=py36h1d69622_1002
117 | - mkl=2020.2=256
118 | - mkl-service=2.3.0=py36he8ac12f_0
119 | - mkl_fft=1.3.0=py36h54f3939_0
120 | - mkl_random=1.1.1=py36h0573a6f_0
121 | - multidict=5.1.0=py36h27cfd23_2
122 | - nbclassic=0.3.1=pyhd8ed1ab_1
123 | - nbclient=0.5.3=pyhd8ed1ab_0
124 | - nbconvert=6.0.7=py36h5fab9bb_3
125 | - nbformat=5.1.3=pyhd8ed1ab_0
126 | - ncurses=6.2=he6710b0_1
127 | - nest-asyncio=1.5.1=pyhd8ed1ab_0
128 | - nettle=3.7.2=hbbd107a_1
129 | - networkx=2.5=py_0
130 | - ninja=1.10.2=hff7bd54_1
131 | - notebook=6.3.0=py36h5fab9bb_0
132 | - numpy=1.19.2=py36h54aff64_0
133 | - numpy-base=1.19.2=py36hfa32c7d_0
134 | - oauthlib=3.1.0=py_0
135 | - olefile=0.46=py36_0
136 | - openh264=2.1.0=hd408876_0
137 | - openjpeg=2.3.0=h05c96fa_1
138 | - openssl=1.1.1k=h27cfd23_0
139 | - packaging=20.9=pyh44b312d_0
140 | - pandas=1.1.5=py36ha9443f7_0
141 | - pandoc=2.14.0.1=h7f98852_0
142 | - pandocfilters=1.4.2=py_1
143 | - pcre=8.44=he6710b0_0
144 | - pexpect=4.8.0=pyh9f0ad1d_2
145 | - pickleshare=0.7.5=py_1003
146 | - pillow=8.2.0=py36he98fc37_0
147 | - prometheus_client=0.11.0=pyhd8ed1ab_0
148 | - prompt_toolkit=1.0.15=py_1
149 | - protobuf=3.14.0=py36h2531618_1
150 | - ptyprocess=0.7.0=pyhd3deb0d_0
151 | - pyasn1=0.4.8=py_0
152 | - pyasn1-modules=0.2.8=py_0
153 | - pycparser=2.20=py_2
154 | - pygments=2.9.0=pyhd8ed1ab_0
155 | - pyjwt=1.7.1=py36_0
156 | - pyopenssl=20.0.1=pyhd3eb1b0_1
157 | - pyparsing=2.4.7=pyhd3eb1b0_0
158 | - pyqt=5.9.2=py36h05f1152_2
159 | - pyrsistent=0.17.3=py36h1d69622_1
160 | - pysocks=1.7.1=py36h06a4308_0
161 | - python=3.6.13=hdb3f193_0
162 | - python-dateutil=2.8.1=pyhd3eb1b0_0
163 | - python_abi=3.6=1_cp36m
164 | - pytorch=1.8.1=py3.6_cuda11.1_cudnn8.0.5_0
165 | - pytz=2021.1=pyhd3eb1b0_0
166 | - pywavelets=1.1.1=py36h7b6447c_2
167 | - pyyaml=5.4.1=py36h27cfd23_1
168 | - pyzmq=19.0.2=py36h9947dbf_2
169 | - qt=5.9.7=h5867ecd_1
170 | - readline=8.1=h27cfd23_0
171 | - requests=2.25.1=pyhd3eb1b0_0
172 | - requests-oauthlib=1.3.0=py_0
173 | - rsa=4.7.2=pyhd3eb1b0_1
174 | - scikit-image=0.17.2=py36hdf5156a_0
175 | - scikit-learn=0.24.2=py36ha9443f7_0
176 | - scipy=1.5.2=py36h0b6359f_0
177 | - seaborn=0.11.0=py_0
178 | - send2trash=1.5.0=py_0
179 | - setuptools=52.0.0=py36h06a4308_0
180 | - simplegeneric=0.8.1=py_1
181 | - sip=4.19.8=py36hf484d3e_0
182 | - six=1.15.0=py36h06a4308_0
183 | - snappy=1.1.8=he6710b0_0
184 | - sniffio=1.2.0=py36h5fab9bb_1
185 | - sqlite=3.35.4=hdfb4753_0
186 | - tensorboard=2.4.1=pyhd8ed1ab_0
187 | - tensorboard-plugin-wit=1.6.0=py_0
188 | - terminado=0.10.0=py36h5fab9bb_0
189 | - testpath=0.5.0=pyhd8ed1ab_0
190 | - threadpoolctl=2.1.0=pyh5ca1d4c_0
191 | - tifffile=2021.3.17=pyhd3eb1b0_1
192 | - tk=8.6.10=hbc83047_0
193 | - toolz=0.11.1=pyhd3eb1b0_0
194 | - torchaudio=0.8.1=py36
195 | - torchvision=0.9.1=py36_cu111
196 | - tornado=6.1=py36h27cfd23_0
197 | - tqdm=4.61.0=pyhd8ed1ab_0
198 | - traitlets=4.3.3=py36h9f0ad1d_1
199 | - typing-extensions=3.7.4.3=hd3eb1b0_0
200 | - typing_extensions=3.7.4.3=pyh06a4308_0
201 | - urllib3=1.26.4=pyhd3eb1b0_0
202 | - wcwidth=0.2.5=pyh9f0ad1d_2
203 | - webencodings=0.5.1=py_1
204 | - websocket-client=0.57.0=py36h5fab9bb_4
205 | - werkzeug=1.0.1=pyhd3eb1b0_0
206 | - wheel=0.36.2=pyhd3eb1b0_0
207 | - x264=1!157.20191217=h7b6447c_0
208 | - xz=5.2.5=h7b6447c_0
209 | - yaml=0.2.5=h7b6447c_0
210 | - yarl=1.6.3=py36h27cfd23_0
211 | - zeromq=4.3.4=h2531618_0
212 | - zipp=3.4.1=pyhd3eb1b0_0
213 | - zlib=1.2.11=h7b6447c_3
214 | - zstd=1.4.9=haebb681_0
215 | - pip:
216 | - argparse==1.4.0
217 | - blessings==1.7
218 | - gpustat==0.6.0
219 | - nvidia-ml-py3==7.352.0
220 | - pip==21.1.2
221 | - psutil==5.8.0
222 | - rtpt==0.0.4
223 | - setproctitle==1.2.2
224 | - torchsummary==1.5.1
225 | - pathos==0.2.8
226 | prefix: /home/dochs/anaconda3/envs/env_slash
227 |
--------------------------------------------------------------------------------
/imgs/slash_icon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml-research/SLASH/79d158dbd670d56d8bf9043640635c24a18a9c57/imgs/slash_icon.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==0.12.0
2 | blessings==1.7
3 | cachetools==4.2.2
4 | certifi==2021.5.30
5 | cffi==1.14.5
6 | chardet==4.0.0
7 | clingo==5.5.0.post3
8 | cycler==0.10.0
9 | dataclasses==0.6
10 | decorator==4.4.2
11 | distro==1.5.0
12 | google-auth==1.30.2
13 | google-auth-oauthlib==0.4.4
14 | gpustat==0.6.0
15 | grpcio==1.38.0
16 | idna==2.10
17 | imageio==2.9.0
18 | importlib-metadata==4.5.0
19 | install==1.3.4
20 | joblib==1.0.1
21 | kiwisolver==1.3.1
22 | Markdown==3.3.4
23 | matplotlib==3.3.4
24 | networkx==2.5.1
25 | numpy==1.19.5
26 | nvidia-ml-py3==7.352.0
27 | oauthlib==3.1.1
28 | packaging==20.9
29 | pandas==1.1.5
30 | Pillow==8.2.0
31 | protobuf==3.17.3
32 | psutil==5.8.0
33 | pyasn1==0.4.8
34 | pyasn1-modules==0.2.8
35 | pycparser==2.20
36 | pyparsing==2.4.7
37 | python-dateutil==2.8.1
38 | pytz==2021.1
39 | PyWavelets==1.1.1
40 | requests==2.25.1
41 | requests-oauthlib==1.3.0
42 | rsa==4.7.2
43 | rtpt==0.0.4
44 | scikit-image==0.17.2
45 | scikit-learn==0.24.2
46 | scipy==1.5.4
47 | seaborn==0.11.1
48 | setproctitle==1.2.2
49 | six==1.16.0
50 | tensorboard==2.5.0
51 | tensorboard-data-server==0.6.1
52 | tensorboard-plugin-wit==1.8.0
53 | threadpoolctl==2.1.0
54 | tifffile==2020.9.3
55 | #torch==1.8.1+cu111
56 | #torchaudio==0.8.1
57 | torchsummary==1.5.1
58 | #torchvision==0.9.1+cu111
59 | tqdm==4.61.0
60 | typing-extensions==3.10.0.0
61 | urllib3==1.26.5
62 | warmup-scheduler @ git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git@6b5e8953a80aef5b324104dc0c2e9b8c34d622bd
63 | Werkzeug==2.0.1
64 | zipp==3.4.1
65 | pathos==0.2.8
66 |
--------------------------------------------------------------------------------
/setup.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | source ~/.bashrc
4 | pyenv uninstall venv_slash
5 | pyenv virtualenv venv_slash
6 | pyenv activate venv_slash
7 | pip install --upgrade pip
8 | pip install -r requirements.txt
9 |
--------------------------------------------------------------------------------
/src/README.md:
--------------------------------------------------------------------------------
1 | # Source folder quick overview:
2 |
3 | 1. `EinsumNetwork` contains code which was cloned from the [git repository EinNets](https://github.com/cambridge-mlg/EinsumNetworks).
4 | 2. `SLASH` contains all modifications made to enable probabilistic circuits in the program, VQA and SLASH attention, optimized gradient computation, and the SAME technique.
5 | * `mvpp.py` contains all functions to compute stable models given a logic program and to compute the gradients using the output probabilities
6 | * `slash.py` contains the SLASH class which brings together the symbolic program with neural nets or probabilistic circuits
7 | 3. `experiments` contains the VQA, MNIST addition, and SLASH attention experiments
--------------------------------------------------------------------------------
/src/SLASH/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml-research/SLASH/79d158dbd670d56d8bf9043640635c24a18a9c57/src/SLASH/__init__.py
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml-research/SLASH/79d158dbd670d56d8bf9043640635c24a18a9c57/src/__init__.py
--------------------------------------------------------------------------------
/src/datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tempfile
3 | import urllib.request
4 | import shutil
5 |
6 | from zipfile import ZipFile
7 | import gzip
8 | import utils
9 |
10 | def maybe_download(directory, url_base, filename, suffix='.zip'):
11 | '''
12 | Downloads the specified file if it is not present yet
13 |
14 | @param directory: directory to store the downloaded file in
15 | @param url_base: URL where to find the file
16 | @param filename: name of the file to be downloaded
17 | @param suffix: suffix of the file
18 |
19 | :returns: True if the file was downloaded, False if it already exists
20 | '''
21 |
22 | filepath = os.path.join(directory, filename)
23 | if os.path.isfile(filepath):
24 | return False
25 |
26 | if not os.path.isdir(directory):
27 | utils.mkdir_p(directory)
28 |
29 | url = url_base +filename
30 |
31 | _, zipped_filepath = tempfile.mkstemp(suffix=suffix)
32 |
33 | print('Downloading {} to {}'.format(url, zipped_filepath))
34 |
35 | urllib.request.urlretrieve(url, zipped_filepath)
36 | print('{} Bytes'.format(os.path.getsize(zipped_filepath)))
37 |
38 | print('Move to {}'.format(filepath))
39 | shutil.move(zipped_filepath, filepath)
40 | return True
41 |
42 |
43 | def extract_dataset(directory, filepath, filepath_extracted):
44 | if not os.path.isdir(filepath_extracted):
45 | print('unzip ',filepath, " to", filepath_extracted)
46 | with ZipFile(filepath, 'r') as zipObj:
47 | # Extract all the contents of zip file in current directory
48 | zipObj.extractall(directory)
49 |
50 |
51 | def maybe_download_shapeworld4():
52 | '''
53 | Downloads the shapeworld4 dataset if it is not downloaded yet
54 | '''
55 |
56 | directory = "../../data/"
57 | file_name= "shapeworld4.zip"
58 | maybe_download(directory, "https://hessenbox.tu-darmstadt.de/dl/fiEE3hftM4n1gBGn4HJLKUkU/", file_name)
59 |
60 | filepath = os.path.join(directory, file_name)
61 | filepath_extracted = os.path.join(directory,"shapeworld4")
62 |
63 | extract_dataset(directory, filepath, filepath_extracted)
64 |
65 |
66 | def maybe_download_shapeworld_cogent():
67 | '''
68 | Downloads the shapeworld4 cogent dataset if it is not downloaded yet
69 | '''
70 |
71 | directory = "../../data/"
72 | file_name= "shapeworld_cogent.zip"
73 | maybe_download(directory, "https://hessenbox.tu-darmstadt.de/dl/fi3CDjPRsYgAvotHcC8GPaWj/", file_name)
74 |
75 | filepath = os.path.join(directory, file_name)
76 | filepath_extracted = os.path.join(directory,"shapeworld_cogent")
77 |
78 | extract_dataset(directory, filepath, filepath_extracted)
79 |
--------------------------------------------------------------------------------
/src/einsum_wrapper.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from EinsumNetwork import EinsumNetwork, Graph
5 |
6 | device = torch.device('cuda:0')
7 |
8 | #wrapper class to create an Einsum Network given a specific structure and parameters
9 | class EiNet(EinsumNetwork.EinsumNetwork):
10 | def __init__(self ,
11 | use_em,
12 | structure = 'poon-domingos',
13 | pd_num_pieces = [4],
14 | depth = 8,
15 | num_repetitions = 20,
16 | num_var = 784,
17 | class_count = 3,
18 | K = 10,
19 | num_sums = 10,
20 | pd_height = 28,
21 | pd_width = 28,
22 | learn_prior = True
23 | ):
24 |
25 |
26 | # Structure
27 | self.structure = structure
28 | self.class_count = class_count
29 | classes = np.arange(class_count) # [0,1,2,..,n-1]
30 |
31 | # Define the prior, i.e. P(C) and make it learnable.
32 | self.learnable_prior = learn_prior
33 | # P(C) is needed to apply Bayes' theorem and to retrieve
34 | # P(C|X) = P(X|C) * P(C) / P(X)
35 | if self.class_count == 4:
36 | self.prior = torch.tensor([(1/3)*(2/3), (1/3)*(2/3), (1/3)*(2/3), (1/3)], dtype=torch.float, requires_grad=True, device=device).log()
37 | else:
38 | self.prior = torch.ones(class_count, device=device, dtype=torch.float)
39 | self.prior.fill_(1 / class_count)
40 | self.prior.log_()
41 | if self.learnable_prior:
42 | print("P(C) is learnable.")
43 | self.prior.requires_grad_()
44 |
45 | self.K = K
46 | self.num_sums = num_sums
47 |
48 | # 'poon-domingos'
49 | self.pd_num_pieces = pd_num_pieces # [10, 28],[4],[7]
50 | self.pd_height = pd_height
51 | self.pd_width = pd_width
52 |
53 |
54 | # 'binary-trees'
55 | self.depth = depth
56 | self.num_repetitions = num_repetitions
57 | self.num_var = num_var
58 |
59 | # drop-out rate
60 | # self.drop_out = drop_out
61 | # print("The drop-out rate:", self.drop_out)
62 |
63 | # EM-settings
64 | self.use_em = use_em
65 | online_em_frequency = 1
66 | online_em_stepsize = 0.05 # 0.05
67 | print("train SPN with EM:",self.use_em)
68 |
69 |
70 | # exponential_family = EinsumNetwork.BinomialArray
71 | # exponential_family = EinsumNetwork.CategoricalArray
72 | exponential_family = EinsumNetwork.NormalArray
73 |
74 | exponential_family_args = None
75 | if exponential_family == EinsumNetwork.BinomialArray:
76 | exponential_family_args = {'N': 255}
77 | if exponential_family == EinsumNetwork.CategoricalArray:
78 | exponential_family_args = {'K': 1366120}
79 | if exponential_family == EinsumNetwork.NormalArray:
80 | exponential_family_args = {'min_var': 1e-6, 'max_var': 0.1}
81 |
82 | # Make EinsumNetwork
83 | if self.structure == 'poon-domingos':
84 | pd_delta = [[self.pd_height / d, self.pd_width / d] for d in self.pd_num_pieces]
85 | graph = Graph.poon_domingos_structure(shape=(self.pd_height, self.pd_width), delta=pd_delta)
86 | elif self.structure == 'binary-trees':
87 | graph = Graph.random_binary_trees(num_var=self.num_var, depth=self.depth, num_repetitions=self.num_repetitions)
88 | else:
89 | raise AssertionError("Unknown Structure")
90 |
91 |
92 | args = EinsumNetwork.Args(
93 | num_var=self.num_var,
94 | num_dims=1,
95 | num_classes=self.class_count,
96 | num_sums=self.num_sums,
97 | num_input_distributions=self.K,
98 | exponential_family=exponential_family,
99 | exponential_family_args=exponential_family_args,
100 | use_em=self.use_em,
101 | online_em_frequency=online_em_frequency,
102 | online_em_stepsize=online_em_stepsize)
103 |
104 | super().__init__(graph, args)
105 | super().initialize()
106 |
107 | def get_log_likelihoods(self, x):
108 | log_likelihood = super().forward(x)
109 | return log_likelihood
110 |
111 | def forward(self, x, marg_idx=None, type=1):
112 |
113 | # PRIOR
114 | if type == 4:
115 | expanded_prior = self.prior.expand(x.shape[0], self.prior.shape[0])
116 | return expanded_prior
117 |
118 | else:
119 | # Obtain P(X|C) in log domain
120 | if marg_idx: # If marginalisation mask is passed
121 | self.set_marginalization_idx(marg_idx)
122 | log_likelihood = super().forward(x)
123 | self.set_marginalization_idx(None)
124 | likelihood = torch.nn.functional.softmax(log_likelihood, dim=1)
125 | else:
126 | log_likelihood = super().forward(x)
127 |
128 | #LIKELIHOOD
129 | if type == 2:
130 | likelihood = torch.nn.functional.softmax(log_likelihood, dim=1)
131 | # Sanity check for NaN-values
132 | if torch.isnan(log_likelihood).sum() > 0:
133 | print("likelihood nan")
134 |
135 | return likelihood
136 | else:
137 | # Apply Bayes' Theorem to obtain P(C|X) instead of P(X|C)
138 | # as it is provided by the EiNet
139 | # 1. Computation of the prior, i.e. P(C), is already being
140 | # dealt with at the initialisation of the EiNet.
141 | # 2. Compute the normalization constant P(X)
142 | z = torch.logsumexp(log_likelihood + self.prior, dim=1)
143 | # 3. Compute the posterior, i.e. P(C|X) = (P(X|C) * P(C)) / P(X)
144 | posterior_log = (log_likelihood + self.prior - z[:, None]) # log domain
145 | #posterior = posterior_log.exp() # decimal domain
146 |
147 |
148 | #POSTERIOR
149 | if type == 1:
150 | posterior = torch.nn.functional.softmax(posterior_log, dim=1)
151 |
152 | # Sanity check for NaN-values
153 | if torch.isnan(z).sum() > 0:
154 | print("z nan")
155 | if torch.isnan(posterior).sum() > 0:
156 | print("posterior nan")
157 | return posterior
158 |
159 | #JOINT
160 | elif type == 3:
161 | #compute the joint P(X|C) * P(C)
162 | joint = torch.nn.functional.softmax(log_likelihood + self.prior, dim=1)
163 |
164 | return joint
165 |
--------------------------------------------------------------------------------
/src/experiments/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml-research/SLASH/79d158dbd670d56d8bf9043640635c24a18a9c57/src/experiments/__init__.py
--------------------------------------------------------------------------------
/src/experiments/baseline_slot_attention/dataGen.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 |
4 | from torch.utils.data import Dataset
5 | from torchvision.transforms import transforms
6 |
7 | from torch.utils.data import Dataset
8 | from torchvision import transforms
9 | from skimage import io
10 | import os
11 | import numpy as np
12 | import random
13 | import torch
14 | import matplotlib.pyplot as plt
15 | from PIL import Image, ImageFile
16 | ImageFile.LOAD_TRUNCATED_IMAGES = True
17 | import json
18 | import datasets as datasets
19 |
20 |
21 | def get_loader(dataset, batch_size, num_workers=8, shuffle=True):
22 | '''
23 | Returns an iterable DataLoader with the specified batch size and shuffling.
24 | '''
25 | return torch.utils.data.DataLoader(
26 | dataset,
27 | shuffle=shuffle,
28 | batch_size=batch_size,
29 | num_workers=num_workers
30 | )
31 |
32 |
33 | def get_encoding_shapeworld(color, shape, shade, size):
34 |
35 | if color == 'red':
36 | col_enc = [1,0,0,0,0,0,0,0]
37 | elif color == 'blue':
38 | col_enc = [0,1,0,0,0,0,0,0]
39 | elif color == 'green':
40 | col_enc = [0,0,1,0,0,0,0,0]
41 | elif color == 'gray':
42 | col_enc = [0,0,0,1,0,0,0,0]
43 | elif color == 'brown':
44 | col_enc = [0,0,0,0,1,0,0,0]
45 | elif color == 'magenta':
46 | col_enc = [0,0,0,0,0,1,0,0]
47 | elif color == 'cyan':
48 | col_enc = [0,0,0,0,0,0,1,0]
49 | elif color == 'yellow':
50 | col_enc = [0,0,0,0,0,0,0,1]
51 |
52 | if shape == 'circle':
53 | shape_enc = [1,0,0]
54 | elif shape == 'triangle':
55 | shape_enc = [0,1,0]
56 | elif shape == 'square':
57 | shape_enc = [0,0,1]
58 |
59 | if shade == 'bright':
60 | shade_enc = [1,0]
61 | elif shade =='dark':
62 | shade_enc = [0,1]
63 |
64 |
65 | if size == 'small':
66 | size_enc = [1,0]
67 | elif size == 'big':
68 | size_enc = [0,1]
69 |
70 | return np.array([1]+ col_enc + shape_enc + shade_enc + size_enc)
71 |
72 |
73 | class SHAPEWORLD4(Dataset):
74 | def __init__(self, root, mode, learn_concept='default', bg_encoded=True):
75 |
76 | datasets.maybe_download_shapeworld4()
77 |
78 | self.root = root
79 | self.mode = mode
80 | assert os.path.exists(root), 'Path {} does not exist'.format(root)
81 |
82 | #dictionary of the form {'image_idx':'img_path'}
83 | self.img_paths = {}
84 |
85 |
86 | for file in os.scandir(os.path.join(root, 'images', mode)):
87 | img_path = file.path
88 |
89 | img_path_idx = img_path.split("/")
90 | img_path_idx = img_path_idx[-1]
91 | img_path_idx = img_path_idx[:-4][6:]
92 | try:
93 | img_path_idx = int(img_path_idx)
94 | self.img_paths[img_path_idx] = img_path
95 | except:
96 | print("path:",img_path_idx)
97 |
98 |
99 |
100 | count = 0
101 |
102 | #target maps of the form {'target:idx': observation string} or {'target:idx': obj encoding}
103 | self.obj_map = {}
104 |
105 | with open(os.path.join(root, 'labels', mode,"world_model.json")) as f:
106 | worlds = json.load(f)
107 |
108 |
109 |
110 | #iterate over all objects
111 | for world in worlds:
112 | num_objects = 0
113 | target_obs = ""
114 | obj_enc = []
115 | for entity in world['entities']:
116 |
117 | color = entity['color']['name']
118 | shape = entity['shape']['name']
119 |
120 | shade_val = entity['color']['shade']
121 | if shade_val == 0.0:
122 | shade = 'bright'
123 | else:
124 | shade = 'dark'
125 |
126 | size_val = entity['shape']['size']['x']
127 | if size_val == 0.075:
128 | size = 'small'
129 | elif size_val == 0.15:
130 | size = 'big'
131 |
132 | name = 'o' + str(num_objects+1)
133 | obj_enc.append(get_encoding_shapeworld(color, shape, shade, size))
134 | num_objects += 1
135 |
136 | #bg encodings
137 | for i in range(num_objects, 4):
138 | name = 'o' + str(num_objects+1)
139 | obj_enc.append(np.array([0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]))
140 | num_objects += 1
141 |
142 | self.obj_map[count] = torch.Tensor(obj_enc)
143 | count+=1
144 |
145 | def __getitem__(self, index):
146 |
147 | #get the image
148 | img_path = self.img_paths[index]
149 | img = io.imread(img_path)[:, :, :3]
150 |
151 | transform = transforms.Compose([
152 | transforms.ToPILImage(),
153 | #transforms.CenterCrop(250),
154 | #transforms.Resize((32, 32)),
155 | transforms.ToTensor(),
156 | ])
157 | img = transform(img)
158 | img = (img - 0.5) * 2.0 # Rescale to [-1, 1].
159 |
160 | return img, self.obj_map[index]#, mask
161 |
162 | def __len__(self):
163 | return len(self.img_paths)
164 |
165 | def get_encoding_clevr(size, material, shape, color ):
166 |
167 | #size (small, large, bg)
168 | if size == "small":
169 | size_enc = [1,0]
170 | elif size == "large":
171 | size_enc = [0,1]
172 |
173 | #material (rubber, metal, bg)
174 | if material == "rubber":
175 | material_enc = [1,0]
176 | elif material == "metal":
177 | material_enc = [0,1]
178 |
179 | #shape (cube, sphere, cylinder, bg)
180 | if shape == "cube":
181 | shape_enc = [1,0,0]
182 | elif shape == "sphere":
183 | shape_enc = [0,1,0]
184 | elif shape == "cylinder":
185 | shape_enc = [0,0,1]
186 |
187 |
188 | #color (gray, red, blue, green, brown, purple, cyan, yellow, bg)
189 | if color == "gray":
190 | color_enc = [1,0,0,0,0,0,0,0]
191 | elif color == "red":
192 | color_enc = [0,1,0,0,0,0,0,0]
193 | elif color == "blue":
194 | color_enc = [0,0,1,0,0,0,0,0]
195 | elif color == "green":
196 | color_enc = [0,0,0,1,0,0,0,0]
197 | elif color == "brown":
198 | color_enc = [0,0,0,0,1,0,0,0]
199 | elif color == "purple":
200 | color_enc = [0,0,0,0,0,1,0,0]
201 | elif color == "cyan":
202 | color_enc = [0,0,0,0,0,0,1,0]
203 | elif color == "yellow":
204 | color_enc = [0,0,0,0,0,0,0,1]
205 |
206 |
207 | return np.array([1] + size_enc + material_enc + shape_enc + color_enc )
208 |
209 |
210 | class CLEVR(Dataset):
211 | def __init__(self, root, mode, img_paths=None, files_names=None, obj_num=None):
212 | self.root = root # The root folder of the dataset
213 | self.mode = mode # The mode of 'train' or 'val'
214 | self.files_names = files_names # The list of file names with the correct number of objects
215 | if obj_num is not None:
216 | self.obj_num = obj_num # The upper limit of number of objects
217 | else:
218 | self.obj_num = 10
219 |
220 | assert os.path.exists(root), 'Path {} does not exist'.format(root)
221 |
222 | #list of sorted image paths
223 | self.img_paths = []
224 | if img_paths:
225 | self.img_paths = img_paths
226 | else:
227 | #open directory and save all image paths
228 | for file in os.scandir(os.path.join(root, 'images', mode)):
229 | img_path = file.path
230 | if '.png' in img_path:
231 | self.img_paths.append(img_path)
232 |
233 | self.img_paths.sort()
234 | self.img_paths = np.array(self.img_paths, dtype=str)
235 |
236 | count = 0
237 |
238 | #target maps of the form {'target:idx': query string} or {'target:idx': obj encoding}
239 | #self.obj_map = {}
240 |
241 |
242 | count = 0
243 | #We have up to 10 objects in the image, load the json file
244 | with open(os.path.join(root, 'scenes','CLEVR_'+ mode+"_scenes.json")) as f:
245 | data = json.load(f)
246 |
247 | self.obj_map = np.empty((len(data['scenes']),10,16), dtype=np.float32)
248 |
249 | #iterate over each scene and create the query string and obj encoding
250 | print("parsing scences")
251 | for scene in data['scenes']:
252 | obj_encoding_list = []
253 |
254 | if self.files_names:
255 | if any(scene['image_filename'] in file_name for file_name in files_names):
256 | num_objects = 0
257 | for idx, obj in enumerate(scene['objects']):
258 | obj_encoding_list.append(get_encoding_clevr(obj['size'], obj['material'], obj['shape'], obj['color']))
259 | num_objects = idx+1 #store the num of objects
260 | #fill in background objects
261 | for idx in range(num_objects, self.obj_num):
262 | obj_encoding_list.append([0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0])
263 | self.obj_map[count] = np.array(obj_encoding_list)
264 | count += 1
265 | else:
266 | num_objects=0
267 | for idx, obj in enumerate(scene['objects']):
268 | obj_encoding_list.append(get_encoding_clevr(obj['size'], obj['material'], obj['shape'], obj['color']))
269 | num_objects = idx+1 #store the num of objects
270 | #fill in background objects
271 | for idx in range(num_objects, 10):
272 | obj_encoding_list.append([0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0])
273 | self.obj_map[scene['image_index']] = np.array(obj_encoding_list, dtype=np.float32)
274 |
275 | print("done")
276 | if self.files_names:
277 | print(f'Correctly found images {count} out of {len(files_names)}')
278 |
279 | def __getitem__(self, index):
280 |
281 | #get the image
282 | img_path = self.img_paths[index]
283 | img = io.imread(img_path)[:, :, :3]
284 | img = Image.fromarray(img).resize((128,128)) #using transforms to resize gets us a shared-memory leak :(
285 |
286 | transform = transforms.Compose([
287 | #transforms.ToPILImage(),
288 | #transforms.CenterCrop((29, 221,64, 256)), #why do we need to crop?
289 | #transforms.Resize((128, 128)),
290 | transforms.ToTensor(),
291 | ])
292 | img = transform(img)
293 | img = (img - 0.5) * 2.0 # Rescale to [-1, 1].
294 |
295 | return img, self.obj_map[index]#, mask
296 |
297 | def __len__(self):
298 | return self.img_paths.shape[0]
--------------------------------------------------------------------------------
/src/experiments/baseline_slot_attention/set_utils.py:
--------------------------------------------------------------------------------
1 | import scipy.optimize
2 | import torch
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import os
6 |
7 |
8 | def save_args(args, writer):
9 | # store args as txt file
10 | with open(os.path.join(writer.log_dir, 'args.txt'), 'w') as f:
11 | for arg in vars(args):
12 | f.write(f"\n{arg}: {getattr(args, arg)}")
13 |
14 |
15 | def hungarian_matching(attrs, preds_attrs, verbose=0):
16 | """
17 | Receives unordered predicted set and orders this to match the nearest GT set.
18 | :param attrs: ground-truth attribute sets
19 | :param preds_attrs: predicted (unordered) attribute sets
20 | :param verbose: if non-zero, print the matching details
21 | :return: the predictions reordered to match the ground-truth sets
22 | """
23 | assert attrs.shape[1] == preds_attrs.shape[1]
24 | assert attrs.shape == preds_attrs.shape
25 | from scipy.optimize import linear_sum_assignment
26 | matched_preds_attrs = preds_attrs.clone()
27 | for sample_id in range(attrs.shape[0]):
28 | # using euclidean distance
29 | cost_matrix = torch.cdist(attrs[sample_id], preds_attrs[sample_id]).detach().cpu()
30 |
31 | idx_mapping = linear_sum_assignment(cost_matrix)
32 | # convert to tuples of [(row_id, col_id)] of the cost matrix
33 | idx_mapping = [(idx_mapping[0][i], idx_mapping[1][i]) for i in range(len(idx_mapping[0]))]
34 |
35 | for i, (row_id, col_id) in enumerate(idx_mapping):
36 | matched_preds_attrs[sample_id, row_id, :] = preds_attrs[sample_id, col_id, :]
37 | if verbose:
38 | print('GT: {}'.format(attrs[sample_id]))
39 | print('Pred: {}'.format(preds_attrs[sample_id]))
40 | print('Cost Matrix: {}'.format(cost_matrix))
41 | print('idx mapping: {}'.format(idx_mapping))
42 | print('Matched Pred: {}'.format(matched_preds_attrs[sample_id]))
43 | print('\n')
44 | # exit()
45 |
46 | return matched_preds_attrs
47 |
48 |
49 | def average_precision_shapeworld(pred, attributes, distance_threshold, dataset):
50 | """Computes the average precision for CLEVR.
51 | This function computes the average precision of the predictions specifically
52 | for the CLEVR dataset. First, we sort the predictions of the model by
53 | confidence (highest confidence first). Then, for each prediction we check
54 | whether there was a corresponding object in the input image. A prediction is
55 | considered a true positive if the discrete features are predicted correctly
56 | and the predicted position is within a certain distance from the ground truth
57 | object.
58 | Args:
59 | pred: Tensor of shape [batch_size, num_elements, dimension] containing
60 | predictions. The last dimension is expected to be the confidence of the
61 | prediction.
62 | attributes: Tensor of shape [batch_size, num_elements, dimension] containing
63 | predictions.
64 | distance_threshold: Threshold to accept match. -1 indicates no threshold.
65 | Returns:
66 | Average precision of the predictions.
67 | """
68 |
69 | [batch_size, _, element_size] = attributes.shape
70 | [_, predicted_elements, _] = pred.shape
71 |
72 | def unsorted_id_to_image(detection_id, predicted_elements):
73 | """Find the index of the image from the unsorted detection index."""
74 | return int(detection_id // predicted_elements)
75 |
76 | flat_size = batch_size * predicted_elements
77 | flat_pred = np.reshape(pred, [flat_size, element_size])
78 | # sort_idx = np.argsort(flat_pred[:, -1], axis=0)[::-1] # Reverse order.
79 | sort_idx = np.argsort(flat_pred[:, 0], axis=0)[::-1] # Reverse order.
80 |
81 | sorted_predictions = np.take_along_axis(
82 | flat_pred, np.expand_dims(sort_idx, axis=1), axis=0)
83 | idx_sorted_to_unsorted = np.take_along_axis(
84 | np.arange(flat_size), sort_idx, axis=0)
85 |
86 |
87 | def process_targets_shapeworld4(target):
88 | """Unpacks the target into the CLEVR properties."""
89 | #col_enc + shape_enc + shade_enc + size_enc
90 | real_obj = target[0]
91 | color = np.argmax(target[1:9])
92 | shape = np.argmax(target[9:12])
93 | shade = np.argmax(target[12:14])
94 | size = np.argmax(target[14:16])
95 | return np.array([0,0,0]), size, shade, shape, color, real_obj
96 |
97 | def process_targets_clevr(target):
98 | """Unpacks the target into the CLEVR properties."""
99 | #size_enc + material_enc + shape_enc + color_enc
100 |
101 |
102 | real_obj = target[0]
103 | size = np.argmax(target[1:3])
104 | material = np.argmax(target[3:5])
105 | shape = np.argmax(target[5:8])
106 | color = np.argmax(target[8:16])
107 |
108 | return np.array([0,0,0]), size, material, shape, color, real_obj
109 |
110 |
111 | def process_targets(target):
112 | if dataset == "shapeworld4":
113 | return process_targets_shapeworld4(target)
114 | elif dataset == "clevr":
115 | return process_targets_clevr(target)
116 |
117 |
118 | true_positives = np.zeros(sorted_predictions.shape[0])
119 | false_positives = np.zeros(sorted_predictions.shape[0])
120 |
121 | detection_set = set()
122 |
123 | for detection_id in range(sorted_predictions.shape[0]):
124 | # Extract the current prediction.
125 | current_pred = sorted_predictions[detection_id, :]
126 | # Find which image the prediction belongs to. Get the unsorted index from
127 | # the sorted one and then apply to unsorted_id_to_image function that undoes
128 | # the reshape.
129 | original_image_idx = unsorted_id_to_image(
130 | idx_sorted_to_unsorted[detection_id], predicted_elements)
131 | # Get the ground truth image.
132 | gt_image = attributes[original_image_idx, :, :]
133 |
134 | # Initialize the maximum distance and the id of the ground-truth object that
135 | # was found.
136 | best_distance = 10000
137 | best_id = None
138 |
139 | # Unpack the prediction by taking the argmax on the discrete
140 | # attributes.
141 | (pred_coords, pred_object_size, pred_material, pred_shape, pred_color,
142 | _) = process_targets(current_pred)
143 |
144 | # Loop through all objects in the ground-truth image to check for hits.
145 | for target_object_id in range(gt_image.shape[0]):
146 | target_object = gt_image[target_object_id, :]
147 | # Unpack the targets taking the argmax on the discrete attributes.
148 | (target_coords, target_object_size, target_material, target_shape,
149 | target_color, target_real_obj) = process_targets(target_object)
150 | # Only consider real objects as matches.
151 | if target_real_obj:
152 | # For the match to be valid all attributes need to be correctly
153 | # predicted.
154 | pred_attr = [
155 | pred_object_size,
156 | pred_material,
157 | pred_shape,
158 | pred_color]
159 | target_attr = [
160 | target_object_size,
161 | target_material,
162 | target_shape,
163 | target_color]
164 | match = pred_attr == target_attr
165 | if match:
166 | # If a match was found, we check if the distance is below the
167 | # specified threshold. Recall that we have rescaled the coordinates
168 | # in the dataset from [-3, 3] to [0, 1], both for `target_coords` and
169 | # `pred_coords`. To compare in the original scale, we thus need to
170 | # multiply the distance values by 6 before applying the
171 | # norm.
172 | distance = np.linalg.norm(
173 | (target_coords - pred_coords) * 6.)
174 |
175 | # If this is the best match we've found so far we remember
176 | # it.
177 | if distance < best_distance:
178 | best_distance = distance
179 | best_id = target_object_id
180 | if best_distance < distance_threshold or distance_threshold == -1:
181 | # We have detected an object correctly within the distance confidence.
182 | # If this object was not detected before it's a true positive.
183 | if best_id is not None:
184 | if (original_image_idx, best_id) not in detection_set:
185 | true_positives[detection_id] = 1
186 | detection_set.add((original_image_idx, best_id))
187 | else:
188 | false_positives[detection_id] = 1
189 | else:
190 | false_positives[detection_id] = 1
191 | else:
192 | false_positives[detection_id] = 1
193 |
194 | accumulated_fp = np.cumsum(false_positives)
195 | accumulated_tp = np.cumsum(true_positives)
196 |
197 | recall_array = accumulated_tp / np.sum(attributes[:, :, 0])
198 | precision_array = np.divide(
199 | accumulated_tp,
200 | (accumulated_fp + accumulated_tp))
201 |
202 | return compute_average_precision(
203 | np.array(precision_array, dtype=np.float32),
204 | np.array(recall_array, dtype=np.float32))
205 |
206 |
207 | def compute_average_precision(precision, recall):
208 | """Computation of the average precision from precision and recall arrays."""
209 | recall = recall.tolist()
210 | precision = precision.tolist()
211 | recall = [0] + recall + [1]
212 | precision = [0] + precision + [0]
213 |
214 | for i in range(len(precision) - 1, -0, -1):
215 | precision[i - 1] = max(precision[i - 1], precision[i])
216 |
217 | indices_recall = [
218 | i for i in range(len(recall) - 1) if recall[1:][i] != recall[:-1][i]
219 | ]
220 |
221 | average_precision = 0.
222 | for i in indices_recall:
223 | average_precision += precision[i + 1] * (recall[i + 1] - recall[i])
224 | return average_precision
225 |
226 |
227 | def hungarian_loss(predictions, targets):
228 | # permute dimensions for pairwise distance computation between all slots
229 | predictions = predictions.permute(0, 2, 1)
230 | targets = targets.permute(0, 2, 1)
231 |
232 | # predictions and targets shape :: (n, c, s)
233 | predictions, targets = outer(predictions, targets)
234 | # squared_error shape :: (n, s, s)
235 | squared_error = F.smooth_l1_loss(predictions, targets.expand_as(predictions), reduction="none").mean(1)
236 |
237 | squared_error_np = squared_error.detach().cpu().numpy()
238 | indices = map(hungarian_loss_per_sample, squared_error_np)
239 | losses = [
240 | sample[row_idx, col_idx].mean()
241 | for sample, (row_idx, col_idx) in zip(squared_error, indices)
242 | ]
243 | total_loss = torch.mean(torch.stack(list(losses)))
244 | return total_loss
245 |
246 |
247 |
248 | def hungarian_loss_per_sample(sample_np):
249 | return scipy.optimize.linear_sum_assignment(sample_np)
250 |
251 |
252 | def scatter_masked(tensor, mask, binned=False, threshold=None):
253 | s = tensor[0].detach().cpu()
254 | mask = mask[0].detach().clamp(min=0, max=1).cpu()
255 | if binned:
256 | s = s * 128
257 | s = s.view(-1, s.size(-1))
258 | mask = mask.view(-1)
259 | if threshold is not None:
260 | keep = mask.view(-1) > threshold
261 | s = s[:, keep]
262 | mask = mask[keep]
263 | return s, mask
264 |
265 |
266 | def outer(a, b=None):
267 | """ Compute outer product between a and b (or a and a if b is not specified). """
268 | if b is None:
269 | b = a
270 | size_a = tuple(a.size()) + (b.size()[-1],)
271 | size_b = tuple(b.size()) + (a.size()[-1],)
272 | a = a.unsqueeze(dim=-1).expand(*size_a)
273 | b = b.unsqueeze(dim=-2).expand(*size_b)
274 | return a, b
--------------------------------------------------------------------------------
/src/experiments/baseline_slot_attention/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from datetime import datetime
4 | import time
5 | import sys
6 | sys.path.append('../../')
7 |
8 | import torch
9 | import torch.nn.functional as F
10 | import matplotlib
11 | matplotlib.use("Agg")
12 |
13 | import torchvision.transforms as transforms
14 | import torchvision.datasets as datasets
15 | import torch.multiprocessing as mp
16 |
17 | import scipy.optimize
18 | import numpy as np
19 | #from tqdm import tqdm
20 | from rtpt import RTPT
21 |
22 | import matplotlib.pyplot as plt
23 | from torch.utils.tensorboard import SummaryWriter
24 |
25 | from dataGen import SHAPEWORLD4,CLEVR, get_loader
26 | import slot_attention_module as model
27 | import set_utils as set_utils
28 |
29 | import utils as misc_utils
30 |
31 |
32 | def get_args():
33 | parser = argparse.ArgumentParser()
34 | # generic params
35 | parser.add_argument(
36 | "--name",
37 | default=datetime.now().strftime("%Y-%m-%d_%H:%M:%S"),
38 | help="Name to store the log file as",
39 | )
40 | parser.add_argument("--resume", help="Path to log file to resume from")
41 |
42 | parser.add_argument(
43 | "--seed", type=int, default=10, help="Random generator seed for all frameworks"
44 | )
45 | parser.add_argument(
46 | "--epochs", type=int, default=10, help="Number of epochs to train with"
47 | )
48 | parser.add_argument(
49 | "--ap-log", type=int, default=10, help="Number of epochs before logging AP"
50 | )
51 | parser.add_argument(
52 | "--lr", type=float, default=1e-2, help="Outer learning rate of model"
53 | )
54 | parser.add_argument(
55 | "--warmup-epochs", type=int, default=10, help="Number of steps fpr learning rate warm up"
56 | )
57 | parser.add_argument(
58 | "--decay-epochs", type=int, default=10, help="Number of steps fpr learning rate decay"
59 | )
60 | parser.add_argument(
61 | "--batch-size", type=int, default=32, help="Batch size to train with"
62 | )
63 | parser.add_argument(
64 | "--num-workers", type=int, default=6, help="Number of threads for data loader"
65 | )
66 | parser.add_argument(
67 | "--dataset",
68 | choices=["shapeworld4", "clevr"],
69 | help="Use shapeworld4 dataset",
70 | )
71 | parser.add_argument(
72 | "--cogent", action='store_true',
73 | help="Evaluate on the CoGenT test of the dataset",
74 | )
75 | parser.add_argument(
76 | "--no-cuda",
77 | action="store_true",
78 | help="Run on CPU instead of GPU (not recommended)",
79 | )
80 | parser.add_argument(
81 | "--train-only", action="store_true", help="Only run training, no evaluation"
82 | )
83 | parser.add_argument(
84 | "--eval-only", action="store_true", help="Only run evaluation, no training"
85 | )
86 | parser.add_argument("--multi-gpu", action="store_true", help="Use multiple GPUs")
87 | parser.add_argument("--credentials", type=str, help="Credentials for rtpt")
88 | parser.add_argument("--export-dir", type=str, help="Directory to output samples to")
89 | parser.add_argument("--data-dir", type=str, help="Directory to data")
90 |
91 | # Slot attention params
92 | parser.add_argument('--n-slots', default=10, type=int,
93 | help='number of slots for slot attention module')
94 | parser.add_argument('--n-iters-slot-att', default=3, type=int,
95 | help='number of iterations in slot attention module')
96 | parser.add_argument('--n-attr', default=18, type=int,
97 | help='number of attributes per object')
98 |
99 | args = parser.parse_args()
100 |
101 | if args.no_cuda:
102 | args.device = 'cpu'
103 | else:
104 | args.device = 'cuda:0'
105 |
106 | misc_utils.set_manual_seed(args.seed)
107 |
108 | args.name += f'-{args.seed}'
109 |
110 | return args
111 |
112 |
113 | def run(net, loader, optimizer, criterion, writer, args, test_cond = None, train=False, epoch=0):
114 | if train:
115 | net.train()
116 | prefix = "train"
117 | torch.set_grad_enabled(True)
118 | else:
119 | net.eval()
120 | prefix = "test"
121 | torch.set_grad_enabled(False)
122 |
123 | preds_all = torch.zeros(0, args.n_slots, args.n_attr)
124 | target_all = torch.zeros(0, args.n_slots, args.n_attr)
125 |
126 | iters_per_epoch = len(loader)
127 |
128 |
129 | for i, sample in enumerate(loader, start=epoch * iters_per_epoch):
130 |
131 | start = time.time()
132 |
133 | # input is either a set or an image
134 | if 'cuda' in args.device:
135 | imgs, target_set = map(lambda x: x.cuda(), sample)
136 | else:
137 | imgs, target_set = sample
138 |
139 | load = time.time()
140 | #print("\nload", load-start)
141 |
142 | output = net.forward(imgs)
143 |
144 | loss = set_utils.hungarian_loss(output, target_set)
145 |
146 | forward = time.time()
147 | #print("forward", forward-load)
148 |
149 | if train:
150 |
151 | # apply lr schedulers
152 | if epoch < args.warmup_epochs:
153 | lr = args.lr * ((epoch + 1) / args.warmup_epochs)
154 | else:
155 | lr = args.lr
156 | lr = lr * 0.5 ** ((epoch + 1) / args.decay_epochs)
157 | optimizer.param_groups[0]['lr'] = lr
158 |
159 | optimizer.zero_grad()
160 | loss.backward(retain_graph=True)
161 | optimizer.step()
162 | backward = time.time()
163 | #print("backward", backward-forward)
164 |
165 |
166 | writer.add_scalar("train/loss_baseline", loss.item(), global_step=i)
167 | log = time.time()
168 | #print("log", log-backward)
169 |
170 | #writer.add_scalar("lr/", optimizer.param_groups[0]["lr"], global_step=i)
171 | else:
172 | if i % iters_per_epoch == 0:
173 |
174 | preds_all = torch.cat((preds_all, output.detach().cpu()), 0)
175 | target_all = torch.cat((target_all, target_set.detach().cpu()), 0)
176 |
177 | ap = [
178 | set_utils.average_precision_shapeworld(preds_all.detach().cpu().numpy(),
179 | target_all.detach().cpu().numpy(), d, args.dataset)
180 | for d in [-1] # since we are not using 3D coords #[-1., 1., 0.5, 0.25, 0.125]
181 | ]
182 |
183 |             print(f"\nCurrent AP: {ap[0]} %\n")
184 | if test_cond == "a":
185 | writer.add_scalar("test/ap_cond_a", ap[0], global_step=i)
186 | elif test_cond == "b":
187 | writer.add_scalar("test/ap_cond_b", ap[0], global_step=i)
188 | else:
189 | writer.add_scalar("test/ap", ap[0], global_step=i)
190 | return ap
191 |
192 | if train:
193 | print(f"Epoch {epoch} Train Loss: {loss.item()}")
194 |
195 |
196 | def main():
197 | args = get_args()
198 |
199 | writer = SummaryWriter(os.path.join("runs", args.name), purge_step=0)
200 |
201 | if args.cogent:
202 | if args.dataset == "clevr":
203 | dataset_train = CLEVR(args.data_dir, "trainA")
204 | dataset_test_a = CLEVR(args.data_dir, "valA")
205 | dataset_test_b = CLEVR(args.data_dir, "valB")
206 | elif args.dataset == "shapeworld4":
207 | dataset_train = SHAPEWORLD4(args.data_dir, "train_a")
208 | dataset_test_a = SHAPEWORLD4(args.data_dir, "val_a")
209 | dataset_test_b = SHAPEWORLD4(args.data_dir, "val_b")
210 | else:
211 | if args.dataset == "clevr":
212 | dataset_train = CLEVR(args.data_dir, "train")
213 | dataset_test = CLEVR(args.data_dir, "val")
214 | elif args.dataset == "shapeworld4":
215 | dataset_train = SHAPEWORLD4(args.data_dir, "train")
216 | dataset_test = SHAPEWORLD4(args.data_dir, "val")
217 |
218 | print('data loaded')
219 |
220 | if not args.eval_only:
221 | train_loader = get_loader(
222 | dataset_train, batch_size=args.batch_size, num_workers=args.num_workers
223 | )
224 | if not args.train_only:
225 | if args.cogent:
226 | test_loader_a = get_loader(
227 | dataset_test_a,
228 | batch_size=5000,
229 | num_workers=args.num_workers,
230 | shuffle=False)
231 | test_loader_b = get_loader(
232 | dataset_test_b,
233 | batch_size=5000,
234 | num_workers=args.num_workers,
235 | shuffle=False)
236 | else:
237 | test_loader = get_loader(
238 | dataset_test,
239 | batch_size=5000,
240 | num_workers=args.num_workers,
241 | shuffle=False)
242 |
243 |
244 | # print(torch.cuda.is_available())
245 | if args.dataset == "shapeworld4":
246 | net = model.SlotAttention_model(n_slots=4, n_iters=3, n_attr=15,
247 | encoder_hidden_channels=32,
248 | attention_hidden_channels=64,
249 | mlp_prediction=True,
250 | device=args.device)
251 | elif args.dataset == "clevr":
252 | net = model.SlotAttention_model(n_slots=10, n_iters=3, n_attr=15,
253 | encoder_hidden_channels=64,
254 | attention_hidden_channels=128,
255 | mlp_prediction=True,
256 | device=args.device,
257 | clevr_encoding=True)
258 |
259 | args.n_attr = net.n_attr
260 |
261 |
262 | if not args.no_cuda:
263 | net = net.cuda()
264 |
265 | optimizer = torch.optim.Adam(net.parameters(), lr=args.lr)
266 |
267 | criterion = torch.nn.SmoothL1Loss()
268 | # Create RTPT object
269 |     rtpt = RTPT(name_initials=args.credentials, experiment_name="Set prediction baseline",
270 | max_iterations=args.epochs)
271 |
272 | # store args as txt file
273 | set_utils.save_args(args, writer)
274 |
275 |
276 |
277 | ap_list = []
278 | ap_list_a = [] # for cogent
279 | ap_list_b = [] # for cogent
280 | total_train_time = 0
281 | total_test_time = 0
282 | time_list =[]
283 | time_start = time.time()
284 |
285 | start_epoch = 0
286 | if args.resume:
287 | print("Loading ckpt ...")
288 | log = torch.load(args.resume)
289 | weights = log["weights"]
290 | net.load_state_dict(weights, strict=True)
291 | start_epoch = log["args"]["epochs_trained"] + 1
292 |
293 | ap_list = log["ap_list"]
294 | ap_list_a = log["ap_list_a"]
295 | ap_list_b = log["ap_list_b"]
296 | time_list =log["time_list"]
297 | print("continue with epoch", start_epoch)
298 |
299 |
300 |
301 | # Start the RTPT tracking
302 | rtpt.start()
303 |
304 |
305 | for epoch in np.arange(start_epoch, args.epochs):
306 |         print("Epoch:", epoch); time_train = time_test = 0  # defaults so the timing bookkeeping below also works with --train-only
307 | if not args.eval_only:
308 | train_start = time.time()
309 | run(net, train_loader, optimizer, criterion, writer, args, train=True, epoch=epoch)
310 | time_train = time.time() - train_start
311 | rtpt.step()
312 | if not args.train_only:
313 | test_start = time.time()
314 | if args.cogent:
315 | ap_a = run(net, test_loader_a, None, criterion, writer, args, test_cond="a", train=False, epoch=epoch)
316 | ap_list_a.append(ap_a)
317 | ap_b = run(net, test_loader_b, None, criterion, writer, args, test_cond="b", train=False, epoch=epoch)
318 | ap_list_b.append(ap_b)
319 | else:
320 | ap = run(net, test_loader, None, criterion, writer, args, train=False, epoch=epoch)
321 | ap_list.append(ap)
322 |
323 |
324 | time_test = time.time() - test_start
325 | if args.eval_only:
326 | exit()
327 | torch.cuda.empty_cache()
328 |
329 | args.epochs_trained = epoch
330 |
331 | total_test_time += time_test
332 | total_train_time += time_train
333 | time_list.append([total_train_time, total_test_time, time_train, time_test])
334 |
335 |
336 | results = {
337 | "name": args.name,
338 | "weights": net.state_dict(),
339 | "args": vars(args),
340 | "ap_list": ap_list, #empty for cogent
341 | "ap_list_a": ap_list_a,
342 | "ap_list_b": ap_list_b,
343 | "time": time_list}
344 |
345 | torch.save(results, os.path.join("runs", args.name, args.name))
346 | if args.eval_only:
347 | break
348 |
349 | print("total time", misc_utils.time_delta_now(time_start))
350 |
351 |
352 | if __name__ == "__main__":
353 | main()
354 |
--------------------------------------------------------------------------------
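A note on the learning-rate handling in run() above: the optimizer's lr is recomputed each iteration from the epoch index, with a linear warmup over the first warmup_epochs and a smooth halving every decay_epochs afterwards. A minimal standalone sketch of that schedule, assuming the defaults passed in train.sh below (lr 0.0004, 8 warmup epochs, 360 decay epochs):

# Illustrative restatement of the schedule in run(); not part of the repository code.
def lr_at_epoch(epoch, base_lr=0.0004, warmup_epochs=8, decay_epochs=360):
    if epoch < warmup_epochs:
        return base_lr * (epoch + 1) / warmup_epochs          # linear warmup
    return base_lr * 0.5 ** ((epoch + 1) / decay_epochs)      # exponential halving

# lr_at_epoch(0) == 5e-05, lr_at_epoch(7) == 4e-04, lr_at_epoch(359) == 2e-04
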
/src/experiments/baseline_slot_attention/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #MODEL="slot-attention-set-pred-shapeworld4"
4 | #MODEL="slot-attention-set-pred-shapeworld4-cogent"
5 | MODEL="slot-attention-set-pred-clevr"
6 | #MODEL="slot-attention-set-pred-clevr-cogent"
7 |
8 |
9 | #DATA="../data/shapeworld4"
10 | #DATA="../data/shapeworld_cogent"
11 | DATA="../../../data/CLEVR_v1.0"
12 | #DATA="../../../data/CLEVR_CoGenT_v1.0"
13 |
14 |
15 | #DATASET=shapeworld4
16 | DATASET=clevr
17 |
18 | DEVICE=$1
19 | SEED=$2 # 0, 1, 2, 3, 4
20 | CREDENTIALS=$3
21 | #-------------------------------------------------------------------------------#
22 |
23 | CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \
24 | --data-dir $DATA --dataset $DATASET --epochs 1000 \
25 | --name $MODEL --lr 0.0004 --batch-size 512 --n-slots 10 --n-iters-slot-att 3 --n-attr 15 --ap-log 1 --seed $SEED \
26 | --warmup-epochs 8 --decay-epochs 360 --num-workers 8 --credentials $CREDENTIALS
27 |
28 |
29 |
30 | # # CLEVR
31 | # for S in 0 1 2 3 4
32 | # do
33 | # CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \
34 | # --data-dir $DATA --dataset $DATASET --epochs 1000 \
35 | # --name $MODEL --lr 0.0004 --batch-size 512 --n-slots 10 --n-iters-slot-att 3 --n-attr 15 --ap-log 1 --seed $S \
36 | # --warmup-epochs 8 --decay-epochs 360 --num-workers 8 --credentials $CREDENTIALS
37 | # done
38 |
39 | # # CLEVR COGENT
40 | # for S in 0 1 2 3 4
41 | # do
42 | # CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \
43 | # --data-dir $DATA --dataset $DATASET --epochs 1000 \
44 | # --name $MODEL --lr 0.0004 --batch-size 512 --n-slots 10 --n-iters-slot-att 3 --n-attr 15 --ap-log 1 --seed $S \
45 | # --warmup-epochs 8 --decay-epochs 360 --num-workers 8 --credentials $CREDENTIALS
46 | # done
47 |
48 | # # SHAPEWORLD4
49 | # for S in 0 1 2 3 4
50 | # do
51 |
52 | # CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \
53 | # --data-dir $DATA --dataset $DATASET --epochs 1000 \
54 | # --name $MODEL --lr 0.0004 --batch-size 512 --n-slots 4 --n-iters-slot-att 3 --n-attr 15 --ap-log 1 --seed $S \
55 | # --warmup-epochs 8 --decay-epochs 360 --num-workers 8 --credentials $CREDENTIALS
56 | # done
57 |
58 | # SHAPEWORLD4 COGENT
59 | # for S in 0 1 2 3 4
60 | # do
61 |
62 | # CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \
63 | # --data-dir $DATA --dataset $DATASET --epochs 1000 \
64 | # --name $MODEL --lr 0.0004 --batch-size 512 --n-slots 4 --n-iters-slot-att 3 --n-attr 15 --ap-log 1 --seed $S \
65 | #     --warmup-epochs 8 --decay-epochs 360 --num-workers 8 --credentials $CREDENTIALS --cogent
66 | # done
67 |
68 |
--------------------------------------------------------------------------------
/src/experiments/mnist_top_k/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml-research/SLASH/79d158dbd670d56d8bf9043640635c24a18a9c57/src/experiments/mnist_top_k/__init__.py
--------------------------------------------------------------------------------
/src/experiments/mnist_top_k/dataGen.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | import numpy as np
4 |
5 | class MNIST_Addition(Dataset):
6 |
7 | def __init__(self, dataset, examples, num_i, flat_for_pc):
8 | self.data = list()
9 | self.dataset = dataset
10 | self.num_i = num_i
11 | self.flat_for_pc = flat_for_pc
12 |
13 | with open(examples) as f:
14 | for line in f:
15 | line = line.strip().split(' ')
16 | self.data.append(tuple([int(i) for i in line]))
17 |
18 |
19 | def __getitem__(self, index):
20 | if self.num_i == 2:
21 | i1, i2, l = self.data[index]
22 | l = ':- not addition(i1, i2, {}).'.format(l)
23 | if self.flat_for_pc:
24 | return {'i1': self.dataset[i1][0].flatten(), 'i2': self.dataset[i2][0].flatten()}, l
25 | else:
26 | return {'i1': self.dataset[i1][0], 'i2': self.dataset[i2][0]}, l
27 |
28 | elif self.num_i == 3:
29 | i1, i2, i3, l = self.data[index]
30 | l = ':- not addition(i1, i2, i3, {}).'.format(l)
31 | if self.flat_for_pc:
32 | return {'i1': self.dataset[i1][0].flatten(), 'i2': self.dataset[i2][0].flatten(), 'i3': self.dataset[i3][0].flatten()}, l
33 | else:
34 | return {'i1': self.dataset[i1][0], 'i2': self.dataset[i2][0], 'i3': self.dataset[i3][0]}, l
35 |
36 | elif self.num_i == 4:
37 | i1, i2, i3, i4, l = self.data[index]
38 | l = ':- not addition(i1, i2, i3, i4, {}).'.format(l)
39 | if self.flat_for_pc:
40 | return {'i1': self.dataset[i1][0].flatten(), 'i2': self.dataset[i2][0].flatten(), 'i3': self.dataset[i3][0].flatten(), 'i4': self.dataset[i4][0].flatten()}, l
41 | else:
42 | return {'i1': self.dataset[i1][0], 'i2': self.dataset[i2][0], 'i3': self.dataset[i3][0], 'i4': self.dataset[i4][0]}, l
43 |
44 | elif self.num_i == 6:
45 | i1, i2, i3, i4, i5,i6, l = self.data[index]
46 | l = ':- not addition(i1, i2, i3, i4, i5, i6, {}).'.format(l)
47 | if self.flat_for_pc:
48 | return {'i1': self.dataset[i1][0].flatten(), 'i2': self.dataset[i2][0].flatten(), 'i3': self.dataset[i3][0].flatten(), 'i4': self.dataset[i4][0].flatten(), 'i5': self.dataset[i5][0].flatten(), 'i6': self.dataset[i6][0].flatten()}, l
49 | else:
50 | return {'i1': self.dataset[i1][0], 'i2': self.dataset[i2][0], 'i3': self.dataset[i3][0], 'i4': self.dataset[i4][0], 'i5': self.dataset[i5][0], 'i6': self.dataset[i6][0]}, l
51 |
52 |
53 | def __len__(self):
54 | return len(self.data)
--------------------------------------------------------------------------------
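A usage sketch for the MNIST_Addition dataset above, mirroring how train.py wires it up (the label file path is the one generated for two-image addition and is assumed to exist):

import torch
import torchvision
from torchvision.transforms import transforms
from dataGen import MNIST_Addition

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))])
mnist = torchvision.datasets.MNIST(root='./data/', train=True, download=True,
                                   transform=transform)
dataset = MNIST_Addition(mnist, 'data/labels/train_data_s2.txt',
                         num_i=2, flat_for_pc=False)
inputs, query = dataset[0]
# inputs['i1'] and inputs['i2'] are 1x28x28 tensors;
# query is an ASP constraint such as ':- not addition(i1, i2, 9).'
loader = torch.utils.data.DataLoader(dataset, batch_size=100, shuffle=True)
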
/src/experiments/mnist_top_k/network_nn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class Net_nn(nn.Module):
5 | def __init__(self):
6 | super(Net_nn, self).__init__()
7 | self.encoder = nn.Sequential(
8 |             nn.Conv2d(1, 6, 5), # 6 is the output channel size; 5 is the kernel size; 1 (channel) 28 28 -> 6 24 24
9 |             nn.MaxPool2d(2, 2), # kernel size 2; stride size 2; 6 24 24 -> 6 12 12
10 | nn.ReLU(True), # inplace=True means that it will modify the input directly thus save memory
11 | nn.Conv2d(6, 16, 5), # 6 12 12 -> 16 8 8
12 | nn.MaxPool2d(2, 2), # 16 8 8 -> 16 4 4
13 | nn.ReLU(True)
14 | )
15 | self.classifier = nn.Sequential(
16 | nn.Linear(16 * 4 * 4, 120),
17 | nn.ReLU(),
18 | nn.Linear(120, 84),
19 | nn.ReLU(),
20 | nn.Linear(84, 10),
21 | nn.Softmax(1)
22 | )
23 |
24 | def forward(self, x, marg_idx=None, type=1):
25 |
26 | assert type == 1, "only posterior computations are available for this network"
27 |
28 | # If the list of the pixel numbers to be marginalised is given,
29 |         # then generate a marginalisation mask from it and apply it to the
30 | # tensor 'x'
31 | if marg_idx:
32 | batch_size = x.shape[0]
33 | with torch.no_grad():
34 | marg_mask = torch.ones_like(x, device=x.device).reshape(batch_size, 1, -1)
35 | marg_mask[:, :, marg_idx] = 0
36 | marg_mask = marg_mask.reshape_as(x)
37 | marg_mask.requires_grad_(False)
38 | x = torch.einsum('ijkl,ijkl->ijkl', x, marg_mask)
39 | x = self.encoder(x)
40 | x = x.view(-1, 16 * 4 * 4)
41 | x = self.classifier(x)
42 | return x
43 |
--------------------------------------------------------------------------------
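A quick sanity sketch for Net_nn (illustrative only): the classifier ends in a softmax, so the forward pass returns a categorical distribution over the ten digits, and marg_idx zeroes out the listed flattened pixel positions before the encoder:

import torch
from network_nn import Net_nn

net = Net_nn()
x = torch.randn(4, 1, 28, 28)
probs = net(x)                                   # shape (4, 10), rows sum to 1
assert torch.allclose(probs.sum(dim=1), torch.ones(4), atol=1e-5)

# marginalise the first 100 flattened pixel positions before encoding
probs_marg = net(x, marg_idx=list(range(100)))
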
/src/experiments/mnist_top_k/train.py:
--------------------------------------------------------------------------------
1 | print("start importing...")
2 |
3 | import time
4 | import sys
5 | import argparse
6 | import datetime
7 |
8 | sys.path.append('../../')
9 | sys.path.append('../../SLASH/')
10 | sys.path.append('../../EinsumNetworks/src/')
11 |
12 |
13 | #torch, numpy, ...
14 | import torch
15 | from torch.utils.tensorboard import SummaryWriter
16 | from torchvision.transforms import transforms
17 | import torchvision
18 |
19 | import numpy as np
20 |
21 | #own modules
22 | from dataGen import MNIST_Addition
23 | from einsum_wrapper import EiNet
24 | from network_nn import Net_nn
25 |
26 | #import slash
27 | from slash import SLASH
28 |
29 | import utils
30 | from utils import set_manual_seed
31 | from pathlib import Path
32 | from rtpt import RTPT
33 |
34 | print("...done")
35 |
36 |
37 | def get_args():
38 | parser = argparse.ArgumentParser()
39 |
40 | parser.add_argument(
41 | "--seed", type=int, default=10, help="Random generator seed for all frameworks"
42 | )
43 | parser.add_argument(
44 | "--epochs", type=int, default=10, help="Number of epochs to train with"
45 | )
46 | parser.add_argument(
47 | "--lr", type=float, default=0.01, help="Learning rate of model"
48 | )
49 | parser.add_argument(
50 | "--network-type",
51 | choices=["nn","pc"],
52 |         help="The type of external model to be used: neural net (nn) or probabilistic circuit (pc)",
53 | )
54 | parser.add_argument(
55 | "--pc-structure",
56 | choices=["poon-domingos","binary-trees"],
57 |         help="The structure of the probabilistic circuit: poon-domingos or binary-trees",
58 | )
59 |
60 | parser.add_argument(
61 | "--method",
62 | choices=["exact","top_k","same"],
63 |         help="Inference method used to compute the stable models: exact, top_k or same",
64 | )
65 | parser.add_argument(
66 |         "--k", type=int, default=0, help="Maximum number of stable models to be used"
67 | )
68 | parser.add_argument(
69 | "--images-per-addition",
70 | choices=["2","3","4","6"],
71 | help="How many images should be used in the addition",
72 | )
73 | parser.add_argument(
74 | "--batch-size", type=int, default=100, help="Batch size to train with"
75 | )
76 | parser.add_argument(
77 | "--num-workers", type=int, default=6, help="Number of threads for data loader"
78 | )
79 |
80 | parser.add_argument(
81 |         "--p-num", type=int, default=8, help="Number of processes to divide the batch into for parallel processing"
82 | )
83 |
84 | parser.add_argument("--credentials", type=str, help="Credentials for rtpt")
85 |
86 |
87 | args = parser.parse_args()
88 |
89 | if args.network_type == 'pc':
90 | args.use_pc = True
91 | else:
92 | args.use_pc = False
93 |
94 | return args
95 |
96 |
97 | def slash_mnist_addition():
98 |
99 | args = get_args()
100 | print(args)
101 |
102 |
103 | # Set the seeds for PRNG
104 | set_manual_seed(args.seed)
105 |
106 | # Create RTPT object
107 | rtpt = RTPT(name_initials=args.credentials, experiment_name='SLASH MNIST pick-k', max_iterations=args.epochs)
108 |
109 | # Start the RTPT tracking
110 | rtpt.start()
111 |
112 | i_num = int(args.images_per_addition)
113 | if i_num == 2:
114 | program = '''
115 | img(i1). img(i2).
116 | addition(A,B,N):- digit(0,+A,-N1), digit(0,+B,-N2), N=N1+N2, A!=B.
117 | npp(digit(1,X), [0,1,2,3,4,5,6,7,8,9]) :- img(X).
118 | '''
119 |
120 | elif i_num == 3:
121 | program = '''
122 | img(i1). img(i2). img(i3).
123 | addition(A,B,C,N):- digit(0,+A,-N1), digit(0,+B,-N2), digit(0,+C,-N3), N=N1+N2+N3, A!=B, A!=C, B!=C.
124 | npp(digit(1,X), [0,1,2,3,4,5,6,7,8,9]) :- img(X).
125 | '''
126 |
127 | elif i_num == 4:
128 | program = '''
129 | img(i1). img(i2). img(i3). img(i4).
130 | addition(A,B,C,D,N):- digit(0,+A,-N1), digit(0,+B,-N2), digit(0,+C,-N3), digit(0,+D,-N4), N=N1+N2+N3+N4, A != B, A!=C, A!=D, B!= C, B!= D, C!=D.
131 | npp(digit(1,X), [0,1,2,3,4,5,6,7,8,9]) :- img(X).
132 | '''
133 |
134 | elif i_num == 6:
135 | program = '''
136 | img(i1). img(i2). img(i3). img(i4). img(i5). img(i6).
137 | addition(i1,i2,i3,i4,i5,i6,N):- digit(0,+A,-N1), digit(0,+B,-N2), digit(0,+C,-N3), digit(0,+D,-N4), digit(0,+E,-N5), digit(0,+F,-N6), N=N1+N2+N3+N4+N5+N6, A != B, A!=C, A!=D, A!=E, A!=F, B!= C, B!= D, B!=E, B!=F, C!=D, C!=E, C!=F, D!=E, D!=F, E!=F.
138 | npp(digit(1,X), [0,1,2,3,4,5,6,7,8,9]) :- img(X).
139 | '''
140 |
141 | exp_name= str(args.method)+"/" +args.network_type+"_i"+str(i_num)+"_k"+ str(args.k)
142 |
143 | saveModelPath = 'data/'+exp_name+'/slash_digit_addition_models_seed'+str(args.seed)+'.pt'
144 | Path("data/"+exp_name+"/").mkdir(parents=True, exist_ok=True)
145 |
146 |
147 |     #use neural net or probabilistic circuit
148 | if args.network_type == 'pc':
149 |
150 |         #set up the probabilistic circuit (EinsumNetwork) with the chosen structure
151 | if args.pc_structure == 'binary-trees':
152 | m = EiNet(structure = 'binary-trees',
153 | depth = 3,
154 | num_repetitions = 20,
155 | use_em = False,
156 | num_var = 784,
157 | class_count = 10,
158 | learn_prior = True)
159 | elif args.pc_structure == 'poon-domingos':
160 | m = EiNet(structure = 'poon-domingos',
161 | pd_num_pieces = [4,7,28],
162 | use_em = False,
163 | num_var = 784,
164 | class_count = 10,
165 | pd_width = 28,
166 | pd_height = 28,
167 | learn_prior = True)
168 | else:
169 |             raise ValueError("unknown pc structure: {}".format(args.pc_structure))
170 |
171 | else:
172 | m = Net_nn()
173 |
174 |
175 |     #trainable params
176 | num_trainable_params = sum(p.numel() for p in m.parameters() if p.requires_grad)
177 | num_params = sum(p.numel() for p in m.parameters())
178 | print("training with {} trainable params and {} params in total".format(num_trainable_params,num_params))
179 |
180 |
181 | #create the SLASH Program
182 | nnMapping = {'digit': m}
183 | optimizers = {'digit': torch.optim.Adam(m.parameters(), lr=args.lr, eps=1e-7)}
184 | SLASHobj = SLASH(program, nnMapping, optimizers)
185 | SLASHobj.grad_comp_device ='cpu' #set gradient computation to cpu
186 |
187 |
188 | #metric lists
189 | train_accuracy_list = []
190 | test_accuracy_list = []
191 | confusion_matrix_list = []
192 | loss_list = []
193 | startTime = time.time()
194 |
195 | forward_time_list = []
196 | asp_time_list = []
197 | backward_time_list = []
198 | sm_per_batch_list = []
199 | train_test_times = []
200 |
201 | #load data
202 |     #if we are using probabilistic circuits we need to flatten the data (tensor has shape [bs, 784])
203 | if args.use_pc:
204 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081, )), transforms.Lambda(lambda x: torch.flatten(x))])
205 |     #if not, we can keep the dimensions (tensor has shape [bs, 28, 28])
206 | else:
207 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081, ))])
208 | data_path = 'data/labels/train_data_s'+str(i_num)+'.txt'
209 |
210 | mnist_addition_dataset = MNIST_Addition(torchvision.datasets.MNIST(root='./data/', train=True, download=True, transform=transform), data_path, i_num, args.use_pc)
211 | train_dataset_loader = torch.utils.data.DataLoader(mnist_addition_dataset, shuffle=True,batch_size=args.batch_size,pin_memory=True, num_workers=8)
212 |
213 | test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('./data/', train=False, transform=transform), batch_size=100, shuffle=True)
214 | train_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('./data/', train=True, transform=transform), batch_size=100, shuffle=True)
215 |
216 |
217 |     # Evaluate the performance directly after initialisation
218 | time_test = time.time()
219 | test_acc, _, confusion_matrix = SLASHobj.testNetwork('digit', test_loader, ret_confusion=True)
220 | train_acc, _ = SLASHobj.testNetwork('digit', train_loader)
221 | confusion_matrix_list.append(confusion_matrix)
222 | train_accuracy_list.append([train_acc,0])
223 | test_accuracy_list.append([test_acc, 0])
224 | timestamp_test = utils.time_delta_now(time_test, simple_format=True)
225 | timestamp_total = utils.time_delta_now(startTime, simple_format=True)
226 |
227 | train_test_times.append([0.0, timestamp_test, timestamp_total])
228 |
229 | # Save and print statistics
230 | print('Train Acc: {:0.2f}%, Test Acc: {:0.2f}%'.format(train_acc, test_acc))
231 | print('--- train time: ---', 0)
232 | print('--- test time: ---' , timestamp_test)
233 | print('--- total time from beginning: ---', timestamp_total)
234 |
235 | # Export results and networks
236 | print('Storing the trained model into {}'.format(saveModelPath))
237 | torch.save({"addition_net": m.state_dict(),
238 | "test_accuracy_list": test_accuracy_list,
239 | "train_accuracy_list":train_accuracy_list,
240 | "confusion_matrix_list":confusion_matrix_list,
241 | "num_params": num_trainable_params,
242 | "args":args,
243 | "exp_name":exp_name,
244 | "train_test_times": train_test_times,
245 | "program":program}, saveModelPath)
246 |
247 | start_e= 0
248 |
249 | # Train and evaluate the performance
250 | for e in range(start_e, args.epochs):
251 | print('Epoch {}...'.format(e+1))
252 |
253 | #one epoch of training
254 | time_train= time.time()
255 | loss, forward_time, asp_time, backward_time, sm_per_batch, model_computation_time, gradient_computation_time = SLASHobj.learn(dataset_loader = train_dataset_loader,
256 | epoch=e, method=args.method, p_num=args.p_num, k_num = args.k , same_threshold=0.99)
257 | timestamp_train = utils.time_delta_now(time_train, simple_format=True)
258 |
259 | #store detailed timesteps per batch
260 | forward_time_list.append(forward_time)
261 | asp_time_list.append(asp_time)
262 | backward_time_list.append(backward_time)
263 | sm_per_batch_list.append(sm_per_batch)
264 |
265 |
266 | time_test = time.time()
267 | test_acc, _, confusion_matrix = SLASHobj.testNetwork('digit', test_loader, ret_confusion=True)
268 | confusion_matrix_list.append(confusion_matrix)
269 | train_acc, _ = SLASHobj.testNetwork('digit', train_loader)
270 | train_accuracy_list.append([train_acc,e])
271 | test_accuracy_list.append([test_acc, e])
272 | timestamp_test = utils.time_delta_now(time_test, simple_format=True)
273 | timestamp_total = utils.time_delta_now(startTime, simple_format=True)
274 | loss_list.append(loss)
275 | train_test_times.append([timestamp_train, timestamp_test, timestamp_total])
276 |
277 | # Save and print statistics
278 | print('Train Acc: {:0.2f}%, Test Acc: {:0.2f}%'.format(train_acc, test_acc))
279 | print('--- train time: ---', timestamp_train)
280 | print('--- test time: ---' , timestamp_test)
281 | print('--- total time from beginning: ---', timestamp_total)
282 |
283 | # Export results and networks
284 | print('Storing the trained model into {}'.format(saveModelPath))
285 | torch.save({"addition_net": m.state_dict(),
286 | "resume": {
287 | "optimizer_digit":optimizers['digit'].state_dict(),
288 | "epoch":e
289 | },
290 | "test_accuracy_list": test_accuracy_list,
291 | "train_accuracy_list":train_accuracy_list,
292 | "confusion_matrix_list":confusion_matrix_list,
293 | "num_params": num_trainable_params,
294 | "args":args,
295 | "exp_name":exp_name,
296 | "train_test_times": train_test_times,
297 | "forward_time_list":forward_time_list,
298 | "asp_time_list":asp_time_list,
299 | "backward_time_list":backward_time_list,
300 | "sm_per_batch_list":sm_per_batch_list,
301 | "loss": loss_list,
302 | "program":program}, saveModelPath)
303 |
304 | # Update the RTPT
305 | rtpt.step(subtitle=f"accuracy={test_acc:2.2f}")
306 |
307 |
308 |
309 |
310 | if __name__ == "__main__":
311 | slash_mnist_addition()
--------------------------------------------------------------------------------
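For intuition on the two-image program used above: the query ':- not addition(i1, i2, n).' constrains the sum of the two digits, so its likelihood is the total probability mass of all digit pairs summing to n under the two npp distributions. A plain-Python restatement of that count (an illustration of the semantics, not SLASH's own inference code):

import torch

def addition_likelihood(p1, p2, n):
    # p1, p2: length-10 digit distributions produced by the 'digit' npp
    return sum(p1[d1] * p2[d2]
               for d1 in range(10) for d2 in range(10) if d1 + d2 == n)

p_uniform = torch.full((10,), 0.1)
print(addition_likelihood(p_uniform, p_uniform, 9))   # 10 pairs * 0.01 = 0.10
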
/src/experiments/mnist_top_k/train_baseline.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | DEVICE=$1
4 | SEED=$2 # 0, 1, 2, 3, 4
5 | CREDENTIALS=$3
6 |
7 |
8 | METHOD=exact # exact, top_k, same
9 | K=0 #1,3,5,10
10 |
11 | #-------------------------------------------------------------------------------#
12 | # Train on CLEVR_v1 with cnn model
13 |
14 | CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \
15 | --epochs 30 \
16 | --batch-size 100 --seed $SEED --method=$METHOD --images-per-addition=2 --k=$K \
17 | --network-type nn --lr 0.005 \
18 | --num-workers 0 --p-num 8 --credentials $CREDENTIALS
19 |
20 |
21 | #NN
22 | # for S in 0 1 2 3 4
23 | # do
24 |
25 | # CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \
26 | # --epochs 30 \
27 | # --batch-size 100 --seed $S --method=$METHOD --images-per-addition=2 --k=$K \
28 | # --network-type nn --lr 0.005 \
29 | # --num-workers 0 --p-num 8 --credentials $CREDENTIALS
30 |
31 |
32 | # CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \
33 | # --epochs 30 \
34 | # --batch-size 100 --seed $S --method=$METHOD --images-per-addition=3 --k=$K \
35 | # --network-type nn --lr 0.005 \
36 | # --num-workers 0 --p-num 8 --credentials $CREDENTIALS
37 |
38 |
39 | # CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \
40 | # --epochs 30 \
41 | # --batch-size 100 --seed $S --method=$METHOD --images-per-addition=4 --k=$K \
42 | # --network-type nn --lr 0.005 \
43 | # --num-workers 0 --p-num 8 --credentials $CREDENTIALS
44 | # done
45 |
46 |
47 | #PC
48 | # for S in 0 1 2 3 4
49 | # do
50 | # CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \
51 | # --epochs 30 \
52 | # --batch-size 100 --seed $S --method=$METHOD --images-per-addition=2 --k=$K \
53 | # --network-type pc --pc-structure poon-domingos --lr 0.01 \
54 | # --num-workers 0 --p-num 8 --credentials $CREDENTIALS
55 |
56 |
57 | # CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \
58 | # --epochs 30 \
59 | # --batch-size 100 --seed $S --method=$METHOD --images-per-addition=3 --k=$K \
60 | # --network-type pc --pc-structure poon-domingos --lr 0.01 \
61 | # --num-workers 0 --p-num 8 --credentials $CREDENTIALS
62 |
63 | # CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \
64 | # --epochs 30 \
65 | # --batch-size 100 --seed $S --method=$METHOD --images-per-addition=4 --k=$K \
66 | # --network-type pc --pc-structure poon-domingos --lr 0.01 \
67 | # --num-workers 0 --p-num 8 --credentials $CREDENTIALS
68 | # done
--------------------------------------------------------------------------------
/src/experiments/mnist_top_k/train_same.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | DEVICE=$1
4 | SEED=$2 # 0, 1, 2, 3, 4
5 | CREDENTIALS=$3
6 |
7 | METHOD=same # same, top_k, exact
8 | K=0 #1,3,5,10
9 |
10 | #-------------------------------------------------------------------------------#
11 | # Train MNIST with SAME
12 |
13 |
14 | CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \
15 | --epochs 30 \
16 |     --batch-size 100 --seed $SEED --method=$METHOD --images-per-addition=2 --k=$K \
17 | --network-type nn --lr 0.005 \
18 | --num-workers 0 --p-num 8 --credentials $CREDENTIALS
19 |
20 |
21 |
22 | #NN
23 | # for S in 0 1 2 3 4
24 | # do
25 |
26 | # CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \
27 | # --epochs 30 \
28 | # --batch-size 100 --seed $S --method=$METHOD --images-per-addition=2 --k=$K \
29 | # --network-type nn --lr 0.005 \
30 | # --num-workers 0 --p-num 8 --credentials $CREDENTIALS
31 |
32 |
33 | # CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \
34 | # --epochs 30 \
35 | # --batch-size 100 --seed $S --method=$METHOD --images-per-addition=3 --k=$K \
36 | # --network-type nn --lr 0.005 \
37 | # --num-workers 0 --p-num 8 --credentials $CREDENTIALS
38 |
39 |
40 | # CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \
41 | # --epochs 30 \
42 | # --batch-size 100 --seed $S --method=$METHOD --images-per-addition=4 --k=$K \
43 | # --network-type nn --lr 0.005 \
44 | # --num-workers 0 --p-num 8 --credentials $CREDENTIALS
45 | # done
46 |
47 |
48 | #PC
49 | # for S in 0 1 2 3 4
50 | # do
51 | # CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \
52 | # --epochs 30 \
53 | # --batch-size 100 --seed $S --method=$METHOD --images-per-addition=2 --k=$K \
54 | # --network-type pc --pc-structure poon-domingos --lr 0.01 \
55 | # --num-workers 0 --p-num 8 --credentials $CREDENTIALS
56 |
57 |
58 | # CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \
59 | # --epochs 30 \
60 | # --batch-size 100 --seed $S --method=$METHOD --images-per-addition=3 --k=$K \
61 | # --network-type pc --pc-structure poon-domingos --lr 0.01 \
62 | # --num-workers 0 --p-num 8 --credentials $CREDENTIALS
63 |
64 | # CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \
65 | # --epochs 30 \
66 | # --batch-size 100 --seed $S --method=$METHOD --images-per-addition=4 --k=$K \
67 | # --network-type pc --pc-structure poon-domingos --lr 0.01 \
68 | # --num-workers 0 --p-num 8 --credentials $CREDENTIALS
69 | # done
--------------------------------------------------------------------------------
/src/experiments/mnist_top_k/train_top_k.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | DEVICE=$1
4 | SEED=$2 # 0, 1, 2, 3, 4
5 | CREDENTIALS=$3
6 |
7 | METHOD=top_k # exact, top_k, same
8 | K=10 # 1, 3, 5, 10
9 |
10 | #-------------------------------------------------------------------------------#
11 | # Train MNIST addition with the top-k method (nn model)
12 |
13 | CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \
14 | --epochs 30 \
15 |     --batch-size 100 --seed $SEED --method=$METHOD --images-per-addition=2 --k=$K \
16 | --network-type nn --lr 0.005 \
17 | --num-workers 0 --p-num 8 --credentials $CREDENTIALS
18 |
19 |
20 |
21 |
22 | #NN
23 | # for S in 0 1 2 3 4
24 | # do
25 |
26 | # CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \
27 | # --epochs 30 \
28 | # --batch-size 100 --seed $S --method=$METHOD --images-per-addition=2 --k=$K \
29 | # --network-type nn --lr 0.005 \
30 | # --num-workers 0 --p-num 8 --credentials $CREDENTIALS
31 |
32 |
33 | # CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \
34 | # --epochs 30 \
35 | # --batch-size 100 --seed $S --method=$METHOD --images-per-addition=3 --k=$K \
36 | # --network-type nn --lr 0.005 \
37 | # --num-workers 0 --p-num 8 --credentials $CREDENTIALS
38 |
39 |
40 | # CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \
41 | # --epochs 30 \
42 | # --batch-size 100 --seed $S --method=$METHOD --images-per-addition=4 --k=$K \
43 | # --network-type nn --lr 0.005 \
44 | # --num-workers 0 --p-num 8 --credentials $CREDENTIALS
45 | # done
46 |
47 |
48 | #PC
49 | # for S in 0 1 2 3 4
50 | # do
51 | # CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \
52 | # --epochs 30 \
53 | # --batch-size 100 --seed $S --method=$METHOD --images-per-addition=2 --k=$K \
54 | # --network-type pc --pc-structure poon-domingos --lr 0.01 \
55 | # --num-workers 0 --p-num 8 --credentials $CREDENTIALS
56 |
57 |
58 | # CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \
59 | # --epochs 30 \
60 | # --batch-size 100 --seed $S --method=$METHOD --images-per-addition=3 --k=$K \
61 | # --network-type pc --pc-structure poon-domingos --lr 0.01 \
62 | # --num-workers 0 --p-num 8 --credentials $CREDENTIALS
63 |
64 | # CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \
65 | # --epochs 30 \
66 | # --batch-size 100 --seed $S --method=$METHOD --images-per-addition=4 --k=$K \
67 | # --network-type pc --pc-structure poon-domingos --lr 0.01 \
68 | # --num-workers 0 --p-num 8 --credentials $CREDENTIALS
69 | # done
--------------------------------------------------------------------------------
/src/experiments/slash_attention/clevr/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml-research/SLASH/79d158dbd670d56d8bf9043640635c24a18a9c57/src/experiments/slash_attention/clevr/__init__.py
--------------------------------------------------------------------------------
/src/experiments/slash_attention/clevr/auxiliary.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from pathlib import Path
4 |
5 |
6 | def get_files_names_and_paths(root:str='./data/CLEVR_v1.0/', mode:str='val', obj_num:int=4):
7 | data_file = Path(os.path.join(root, 'scenes','CLEVR_'+ mode+"_scenes.json"))
8 | data_file.parent.mkdir(parents=True, exist_ok=True)
9 | if data_file.exists():
10 | print('File exists. Parsing file...')
11 | else:
12 | print(f'The JSON file {data_file} does not exist!')
13 | quit()
14 | img_paths = []
15 | files_names = []
16 | with open(data_file, 'r') as json_file:
17 | json_data = json.load(json_file)
18 |
19 | for scene in json_data['scenes']:
20 | if len(scene['objects']) <= obj_num:
21 | img_paths.append(Path(os.path.join(root,'images/'+mode+'/'+scene['image_filename'])))
22 | files_names.append(scene['image_filename'])
23 |
24 | print("...done ")
25 | return img_paths, files_names
26 |
27 |
28 | def get_slash_program(obj_num:int=4):
29 | program = ''
30 | if obj_num == 10:
31 | program ='''
32 | slot(s1).
33 | slot(s2).
34 | slot(s3).
35 | slot(s4).
36 | slot(s5).
37 | slot(s6).
38 | slot(s7).
39 | slot(s8).
40 | slot(s9).
41 | slot(s10).
42 |
43 | name(o1).
44 | name(o2).
45 | name(o3).
46 | name(o4).
47 | name(o5).
48 | name(o6).
49 | name(o7).
50 | name(o8).
51 | name(o9).
52 | name(o10).
53 | '''
54 | elif obj_num ==4:
55 | program ='''
56 | slot(s1).
57 | slot(s2).
58 | slot(s3).
59 | slot(s4).
60 |
61 | name(o1).
62 | name(o2).
63 | name(o3).
64 | name(o4).
65 | '''
66 | elif obj_num ==6:
67 | program ='''
68 | slot(s1).
69 | slot(s2).
70 | slot(s3).
71 | slot(s4).
72 | slot(s5).
73 | slot(s6).
74 |
75 | name(o1).
76 | name(o2).
77 | name(o3).
78 | name(o4).
79 | name(o5).
80 | name(o6).
81 | '''
82 | else:
83 | print(f'The number of objects {obj_num} is wrong!')
84 | quit()
85 | program +='''
86 | %assign each name a slot
87 | %{slot_name_comb(N,X): slot(X)}=1 :- name(N). %problem: we have duplicated slots
88 |
89 | %remove each model which has multiple slots assigned to the same name
90 | %:- slot_name_comb(N1,X1), slot_name_comb(N2,X2), X1 == X2, N1 != N2.
91 |
92 | %build the object on top of the slot assignment
93 | object(N, S, M, P, C) :- size(0, +X, -S), material(0, +X, -M), shape(0, +X, -P), color(0, +X, -C), slot(X), name(N), slot_name_comb(N,X).
94 |
95 | %define the SPNs
96 | npp(size(1,X),[small, large, bg]) :- slot(X).
97 | npp(material(1,X),[rubber, metal, bg]) :- slot(X).
98 | npp(shape(1,X),[cube, sphere, cylinder, bg]) :- slot(X).
99 | npp(color(1,X),[gray, red, blue, green, brown, purple, cyan, yellow, bg]) :- slot(X).
100 |
101 | '''
102 | return program
--------------------------------------------------------------------------------
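A usage sketch for the helpers above (assumes a local CLEVR_v1.0 copy with its scenes JSON; the root path is illustrative):

from auxiliary import get_files_names_and_paths, get_slash_program

img_paths, files_names = get_files_names_and_paths(root='./data/CLEVR_v1.0/',
                                                   mode='val', obj_num=4)
program = get_slash_program(obj_num=4)
print(len(img_paths), 'validation scenes with at most 4 objects')
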
/src/experiments/slash_attention/clevr/dataGen.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | from torch.utils.data import Dataset
4 | from torchvision.transforms import transforms
5 |
6 | # from torch.utils.data import Dataset
7 | # from torchvision import transforms
8 | from skimage import io
9 | import os
10 | import numpy as np
11 | import torch
12 | import matplotlib.pyplot as plt
13 | from PIL import ImageFile
14 | ImageFile.LOAD_TRUNCATED_IMAGES = True
15 | import json
16 |
17 | from tqdm import tqdm
18 |
19 |
20 | def object_encoding(size, material, shape, color ):
21 |
22 | #size (small, large, bg)
23 | if size == "small":
24 | size_enc = [1,0,0]
25 | elif size == "large":
26 | size_enc = [0,1,0]
27 | elif size == "bg":
28 | size_enc = [0,0,1]
29 |
30 | #material (rubber, metal, bg)
31 | if material == "rubber":
32 | material_enc = [1,0,0]
33 | elif material == "metal":
34 | material_enc = [0,1,0]
35 | elif material == "bg":
36 | material_enc = [0,0,1]
37 |
38 | #shape (cube, sphere, cylinder, bg)
39 | if shape == "cube":
40 | shape_enc = [1,0,0,0]
41 | elif shape == "sphere":
42 | shape_enc = [0,1,0,0]
43 | elif shape == "cylinder":
44 | shape_enc = [0,0,1,0]
45 | elif shape == "bg":
46 | shape_enc = [0,0,0,1]
47 |
48 | #color (gray, red, blue, green, brown, purple, cyan, yellow, bg)
49 | #color (gray, red, blue, green, brown, purple, cyan, yellow, bg)
50 | #color 1 2 3 4 5 6 7 8 9
51 | if color == "gray":
52 | color_enc = [1,0,0,0,0,0,0,0,0]
53 | elif color == "red":
54 | color_enc = [0,1,0,0,0,0,0,0,0]
55 | elif color == "blue":
56 | color_enc = [0,0,1,0,0,0,0,0,0]
57 | elif color == "green":
58 | color_enc = [0,0,0,1,0,0,0,0,0]
59 | elif color == "brown":
60 | color_enc = [0,0,0,0,1,0,0,0,0]
61 | elif color == "purple":
62 | color_enc = [0,0,0,0,0,1,0,0,0]
63 | elif color == "cyan":
64 | color_enc = [0,0,0,0,0,0,1,0,0]
65 | elif color == "yellow":
66 | color_enc = [0,0,0,0,0,0,0,1,0]
67 | elif color == "bg":
68 | color_enc = [0,0,0,0,0,0,0,0,1]
69 |
70 | return size_enc + material_enc + shape_enc + color_enc +[1]
71 |
72 |
73 |
74 |
75 |
76 | class CLEVR(Dataset):
77 | def __init__(self, root, mode, img_paths=None, files_names=None, obj_num=None):
78 | self.root = root # The root folder of the dataset
79 | self.mode = mode # The mode of 'train' or 'val'
80 |         self.files_names = files_names # The list of file names with the correct number of objects
81 | if obj_num is not None:
82 | self.obj_num = obj_num # The upper limit of number of objects
83 | else:
84 | self.obj_num = 10
85 |
86 | assert os.path.exists(root), 'Path {} does not exist'.format(root)
87 |
88 | #list of sorted image paths
89 | self.img_paths = []
90 | if img_paths:
91 | self.img_paths = img_paths
92 | else:
93 | #open directory and save all image paths
94 | for file in os.scandir(os.path.join(root, 'images', mode)):
95 | img_path = file.path
96 | if '.png' in img_path:
97 | self.img_paths.append(img_path)
98 |
99 | self.img_paths.sort()
100 | count = 0
101 |
102 | #target maps of the form {'target:idx': query string} or {'target:idx': obj encoding}
103 | self.query_map = {}
104 | self.obj_map = {}
105 |
106 | count = 0
107 | #We have up to 10 objects in the image, load the json file
108 | with open(os.path.join(root, 'scenes','CLEVR_'+ mode+"_scenes.json")) as f:
109 | data = json.load(f)
110 |
111 | #iterate over each scene and create the query string and obj encoding
112 |         print("parsing scenes")
113 | for scene in tqdm(data['scenes']):
114 | target_query = ""
115 | obj_encoding_list = []
116 |
117 | if self.files_names:
118 | if any(scene['image_filename'] in file_name for file_name in files_names):
119 | num_objects = 0
120 | for idx, obj in enumerate(scene['objects']):
121 | target_query += " :- not object(o{}, {}, {}, {}, {}).".format(idx+1, obj['size'], obj['material'], obj['shape'], obj['color'])
122 | obj_encoding_list.append(object_encoding(obj['size'], obj['material'], obj['shape'], obj['color']))
123 | num_objects = idx+1 #store the num of objects
124 | #fill in background objects
125 | for idx in range(num_objects, self.obj_num):
126 | target_query += " :- not object(o{}, bg, bg, bg, bg).".format(idx+1)
127 | obj_encoding_list.append([0,0,1, 0,0,1, 0,0,0,1, 0,0,0,0,0,0,0,0,1, 1])
128 | self.query_map[count] = target_query
129 | self.obj_map[count] = np.array(obj_encoding_list)
130 | count += 1
131 | else:
132 | num_objects=0
133 | for idx, obj in enumerate(scene['objects']):
134 | target_query += " :- not object(o{}, {}, {}, {}, {}).".format(idx+1, obj['size'], obj['material'], obj['shape'], obj['color'])
135 | obj_encoding_list.append(object_encoding(obj['size'], obj['material'], obj['shape'], obj['color']))
136 | num_objects = idx+1 #store the num of objects
137 | #fill in background objects
138 | for idx in range(num_objects, 10):
139 | target_query += " :- not object(o{}, bg, bg, bg, bg).".format(idx+1)
140 | obj_encoding_list.append([0,0,1, 0,0,1, 0,0,0,1, 0,0,0,0,0,0,0,0,1, 1])
141 | self.query_map[scene['image_index']] = target_query
142 | self.obj_map[scene['image_index']] = np.array(obj_encoding_list)
143 |
144 | print("done")
145 | if self.files_names:
146 |             print(f'Correctly found {count} out of {len(files_names)} images')
147 |
148 |
149 | #print(np.array(list(self.obj_map.values()))[0:20])
150 | def __getitem__(self, index):
151 | #get the image
152 | img_path = self.img_paths[index]
153 | img = io.imread(img_path)[:, :, :3]
154 |
155 | transform = transforms.Compose([
156 | transforms.ToPILImage(),
157 | #transforms.CenterCrop((29, 221,64, 256)), #why do we need to crop?
158 | transforms.Resize((128, 128)),
159 | transforms.ToTensor(),
160 | ])
161 |
162 | img = transform(img)
163 | img = (img - 0.5) * 2.0 # Rescale to [-1, 1].
164 |
165 | return {'im':img}, self.query_map[index] ,self.obj_map[index]
166 |
167 |
168 |
169 | def __len__(self):
170 | return len(self.img_paths)
171 |
172 |
--------------------------------------------------------------------------------
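The per-object target built by object_encoding above concatenates one-hot groups for size (3), material (3), shape (4) and color (9) plus one trailing entry that is always 1, i.e. 20 values per slot. A quick check (run from this directory):

from dataGen import object_encoding

enc = object_encoding('small', 'rubber', 'cube', 'gray')
assert len(enc) == 3 + 3 + 4 + 9 + 1                 # 20 attributes per object
assert object_encoding('bg', 'bg', 'bg', 'bg') == \
    [0,0,1, 0,0,1, 0,0,0,1, 0,0,0,0,0,0,0,0,1, 1]    # padding used for empty slots
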
/src/experiments/slash_attention/clevr/slash_attention_clevr.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 |
5 |
6 | import train
7 | import datetime
8 |
9 |
10 | seed = 0
11 | obj_num = 10
12 | date_string = datetime.datetime.today().strftime('%d-%m-%Y')
13 | experiments = {f'CLEVR{obj_num}': {'structure':'poon-domingos', 'pd_num_pieces':[4], 'learn_prior':True,
14 | 'lr': 0.01, 'bs':512, 'epochs':1000,
15 | 'lr_warmup_steps':8, 'lr_decay_steps':360, 'use_em':False, 'resume':False,
16 | 'method':'most_prob',
17 | 'start_date':date_string, 'credentials':'DO', 'p_num':16, 'seed':seed, 'obj_num':obj_num
18 | }}
19 |
20 |
21 | for exp_name in experiments:
22 | print(exp_name)
23 | train.slash_slot_attention(exp_name, experiments[exp_name])
24 |
25 |
26 |
27 |
--------------------------------------------------------------------------------
/src/experiments/slash_attention/clevr_cogent/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml-research/SLASH/79d158dbd670d56d8bf9043640635c24a18a9c57/src/experiments/slash_attention/clevr_cogent/__init__.py
--------------------------------------------------------------------------------
/src/experiments/slash_attention/clevr_cogent/auxiliary.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from pathlib import Path
4 |
5 |
6 | def get_files_names_and_paths(root:str='./data/CLEVR_v1.0/', mode:str='val', obj_num:int=4):
7 | data_file = Path(os.path.join(root, 'scenes','CLEVR_'+ mode+"_scenes.json"))
8 | data_file.parent.mkdir(parents=True, exist_ok=True)
9 | if data_file.exists():
10 | print('File exists. Parsing file...')
11 | else:
12 | print(f'The JSON file {data_file} does not exist!')
13 | quit()
14 | img_paths = []
15 | files_names = []
16 | with open(data_file, 'r') as json_file:
17 | json_data = json.load(json_file)
18 |
19 | for scene in json_data['scenes']:
20 | if len(scene['objects']) <= obj_num:
21 | img_paths.append(Path(os.path.join(root,'images/'+mode+'/'+scene['image_filename'])))
22 | files_names.append(scene['image_filename'])
23 |
24 | print("...done ")
25 | return img_paths, files_names
26 |
27 |
28 | def get_slash_program(obj_num:int=4):
29 | program = ''
30 | if obj_num == 10:
31 | program ='''
32 | slot(s1).
33 | slot(s2).
34 | slot(s3).
35 | slot(s4).
36 | slot(s5).
37 | slot(s6).
38 | slot(s7).
39 | slot(s8).
40 | slot(s9).
41 | slot(s10).
42 |
43 | name(o1).
44 | name(o2).
45 | name(o3).
46 | name(o4).
47 | name(o5).
48 | name(o6).
49 | name(o7).
50 | name(o8).
51 | name(o9).
52 | name(o10).
53 | '''
54 | elif obj_num ==4:
55 | program ='''
56 | slot(s1).
57 | slot(s2).
58 | slot(s3).
59 | slot(s4).
60 |
61 | name(o1).
62 | name(o2).
63 | name(o3).
64 | name(o4).
65 | '''
66 | elif obj_num ==6:
67 | program ='''
68 | slot(s1).
69 | slot(s2).
70 | slot(s3).
71 | slot(s4).
72 | slot(s5).
73 | slot(s6).
74 |
75 | name(o1).
76 | name(o2).
77 | name(o3).
78 | name(o4).
79 | name(o5).
80 | name(o6).
81 | '''
82 | else:
83 | print(f'The number of objects {obj_num} is wrong!')
84 | quit()
85 | program +='''
86 | %assign each name a slot
87 | %{slot_name_comb(N,X): slot(X)}=1 :- name(N). %problem: we have duplicated slots
88 |
89 | %remove each model which has multiple slots assigned to the same name
90 | %:- slot_name_comb(N1,X1), slot_name_comb(N2,X2), X1 == X2, N1 != N2.
91 |
92 | %build the object on top of the slot assignment
93 | object(N, S, M, P, C) :- size(0, +X, -S), material(0, +X, -M), shape(0, +X, -P), color(0, +X, -C), slot(X), name(N), slot_name_comb(N,X).
94 |
95 | %define the SPNs
96 | npp(size(1,X),[small, large, bg]) :- slot(X).
97 | npp(material(1,X),[rubber, metal, bg]) :- slot(X).
98 | npp(shape(1,X),[cube, sphere, cylinder, bg]) :- slot(X).
99 | npp(color(1,X),[gray, red, blue, green, brown, purple, cyan, yellow, bg]) :- slot(X).
100 |
101 | '''
102 | return program
--------------------------------------------------------------------------------
/src/experiments/slash_attention/clevr_cogent/dataGen.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | from torch.utils.data import Dataset
4 | from torchvision.transforms import transforms
5 |
6 | # from torch.utils.data import Dataset
7 | # from torchvision import transforms
8 | from skimage import io
9 | import os
10 | import numpy as np
11 | import torch
12 | import matplotlib.pyplot as plt
13 | from PIL import ImageFile
14 | ImageFile.LOAD_TRUNCATED_IMAGES = True
15 | import json
16 |
17 | from tqdm import tqdm
18 |
19 |
20 | def class_lookup(idx):
21 | '''
22 | Shapeworld dataset generation generates a class id for each class. Returns color, shape and object encoding given the class id.
23 | The class order was generated by ShapeWorld.
24 | '''
25 |
26 | #obj encoding ['red','blue','green','black', 'circle' , 'triangle','square', 'bg', 'object_confidence(1)']
27 |
28 | if idx == 0:
29 | return "red", "circle", [1,0,0,0 ,1,0,0,0 , 1]
30 | elif idx ==1:
31 | return "green", "circle" ,[0,0,1,0 ,1,0,0,0 , 1]
32 | elif idx ==2:
33 | return "blue", "circle", [0,1,0,0 ,1,0,0,0 , 1]
34 | elif idx ==3:
35 | return "red", "square", [1,0,0,0 ,0,0,1,0 , 1]
36 | elif idx ==4:
37 | return "green", "square", [0,0,1,0 ,0,0,1,0 , 1]
38 | elif idx ==5:
39 | return "blue", "square" , [0,1,0,0 ,0,0,1,0 , 1]
40 | elif idx ==6:
41 | return "red", "triangle", [1,0,0,0 ,0,1,0,0 , 1]
42 | elif idx ==7:
43 | return "green", "triangle", [0,0,1,0 ,0,1,0,0 , 1]
44 | elif idx ==8:
45 | return "blue", "triangle", [0,1,0,0 ,0,1,0,0 , 1]
46 |
47 | def object_encoding(size, material, shape, color ):
48 |
49 | #size (small, large, bg)
50 | if size == "small":
51 | size_enc = [1,0,0]
52 | elif size == "large":
53 | size_enc = [0,1,0]
54 | elif size == "bg":
55 | size_enc = [0,0,1]
56 |
57 | #material (rubber, metal, bg)
58 | if material == "rubber":
59 | material_enc = [1,0,0]
60 | elif material == "metal":
61 | material_enc = [0,1,0]
62 | elif material == "bg":
63 | material_enc = [0,0,1]
64 |
65 | #shape (cube, sphere, cylinder, bg)
66 | if shape == "cube":
67 | shape_enc = [1,0,0,0]
68 | elif shape == "sphere":
69 | shape_enc = [0,1,0,0]
70 | elif shape == "cylinder":
71 | shape_enc = [0,0,1,0]
72 | elif shape == "bg":
73 | shape_enc = [0,0,0,1]
74 |
75 | #color (gray, red, blue, green, brown, purple, cyan, yellow, bg)
76 | #color (gray, red, blue, green, brown, purple, cyan, yellow, bg)
77 | #color 1 2 3 4 5 6 7 8 9
78 | if color == "gray":
79 | color_enc = [1,0,0,0,0,0,0,0,0]
80 | elif color == "red":
81 | color_enc = [0,1,0,0,0,0,0,0,0]
82 | elif color == "blue":
83 | color_enc = [0,0,1,0,0,0,0,0,0]
84 | elif color == "green":
85 | color_enc = [0,0,0,1,0,0,0,0,0]
86 | elif color == "brown":
87 | color_enc = [0,0,0,0,1,0,0,0,0]
88 | elif color == "purple":
89 | color_enc = [0,0,0,0,0,1,0,0,0]
90 | elif color == "cyan":
91 | color_enc = [0,0,0,0,0,0,1,0,0]
92 | elif color == "yellow":
93 | color_enc = [0,0,0,0,0,0,0,1,0]
94 | elif color == "bg":
95 | color_enc = [0,0,0,0,0,0,0,0,1]
96 |
97 | return size_enc + material_enc + shape_enc + color_enc +[1]
98 |
99 |
100 |
101 |
102 |
103 | class CLEVR(Dataset):
104 | def __init__(self, root, mode, img_paths=None, files_names=None, obj_num=None):
105 | self.root = root # The root folder of the dataset
106 | self.mode = mode # The mode of 'train' or 'val'
107 |         self.files_names = files_names # The list of file names with the correct number of objects
108 | if obj_num is not None:
109 | self.obj_num = obj_num # The upper limit of number of objects
110 | else:
111 | self.obj_num = 10
112 |
113 | assert os.path.exists(root), 'Path {} does not exist'.format(root)
114 |
115 | #list of sorted image paths
116 | self.img_paths = []
117 | if img_paths:
118 | self.img_paths = img_paths
119 | else:
120 | #open directory and save all image paths
121 | for file in os.scandir(os.path.join(root, 'images', mode)):
122 | img_path = file.path
123 | if '.png' in img_path:
124 | self.img_paths.append(img_path)
125 |
126 | self.img_paths.sort()
127 | count = 0
128 |
129 | #target maps of the form {'target:idx': query string} or {'target:idx': obj encoding}
130 | self.query_map = {}
131 | self.obj_map = {}
132 |
133 | count = 0
134 | #We have up to 10 objects in the image, load the json file
135 | with open(os.path.join(root, 'scenes','CLEVR_'+ mode+"_scenes.json")) as f:
136 | data = json.load(f)
137 |
138 | #iterate over each scene and create the query string and obj encoding
139 |         print("parsing scenes")
140 | for scene in tqdm(data['scenes']):
141 | target_query = ""
142 | obj_encoding_list = []
143 |
144 | if self.files_names:
145 | if any(scene['image_filename'] in file_name for file_name in files_names):
146 | num_objects = 0
147 | for idx, obj in enumerate(scene['objects']):
148 | target_query += " :- not object(o{}, {}, {}, {}, {}).".format(idx+1, obj['size'], obj['material'], obj['shape'], obj['color'])
149 | obj_encoding_list.append(object_encoding(obj['size'], obj['material'], obj['shape'], obj['color']))
150 | num_objects = idx+1 #store the num of objects
151 | #fill in background objects
152 | for idx in range(num_objects, self.obj_num):
153 | target_query += " :- not object(o{}, bg, bg, bg, bg).".format(idx+1)
154 | obj_encoding_list.append([0,0,1, 0,0,1, 0,0,0,1, 0,0,0,0,0,0,0,0,1, 1])
155 | self.query_map[count] = target_query
156 | self.obj_map[count] = np.array(obj_encoding_list)
157 | count += 1
158 | else:
159 | num_objects=0
160 | for idx, obj in enumerate(scene['objects']):
161 | target_query += " :- not object(o{}, {}, {}, {}, {}).".format(idx+1, obj['size'], obj['material'], obj['shape'], obj['color'])
162 | obj_encoding_list.append(object_encoding(obj['size'], obj['material'], obj['shape'], obj['color']))
163 | num_objects = idx+1 #store the num of objects
164 | #fill in background objects
165 | for idx in range(num_objects, 10):
166 | target_query += " :- not object(o{}, bg, bg, bg, bg).".format(idx+1)
167 | obj_encoding_list.append([0,0,1, 0,0,1, 0,0,0,1, 0,0,0,0,0,0,0,0,1, 1])
168 | self.query_map[scene['image_index']] = target_query
169 | self.obj_map[scene['image_index']] = np.array(obj_encoding_list)
170 |
171 | print("done")
172 | if self.files_names:
173 |             print(f'Correctly found {count} out of {len(files_names)} images')
174 |
175 |
176 | #print(np.array(list(self.obj_map.values()))[0:20])
177 | def __getitem__(self, index):
178 | #get the image
179 | img_path = self.img_paths[index]
180 | img = io.imread(img_path)[:, :, :3]
181 |
182 | transform = transforms.Compose([
183 | transforms.ToPILImage(),
184 | #transforms.CenterCrop((29, 221,64, 256)), #why do we need to crop?
185 | transforms.Resize((128, 128)),
186 | transforms.ToTensor(),
187 | ])
188 |
189 | img = transform(img)
190 | img = (img - 0.5) * 2.0 # Rescale to [-1, 1].
191 |
192 | return {'im':img}, self.query_map[index] ,self.obj_map[index]
193 |
194 |
195 |
196 | def __len__(self):
197 | return len(self.img_paths)
198 |
199 |
--------------------------------------------------------------------------------
/src/experiments/slash_attention/clevr_cogent/slash_attention_clevr.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 |
5 |
6 | import train
7 |
8 | import datetime
9 |
10 |
11 |
12 |
13 |
14 | seed = 4
15 | obj_num = 10
16 | date_string = datetime.datetime.today().strftime('%d-%m-%Y')
17 | experiments = {f'CLEVR{obj_num}_seed_{seed}': {'structure':'poon-domingos', 'pd_num_pieces':[4], 'learn_prior':True,
18 | 'lr': 0.01, 'bs':512, 'epochs':1000,
19 | 'lr_warmup_steps':8, 'lr_decay_steps':360, 'use_em':False, 'resume':False,
20 | 'method':'most_prob',
21 | 'start_date':date_string, 'credentials':'DO', 'p_num':16, 'seed':seed, 'obj_num':obj_num
22 | }}
23 |
24 |
25 | for exp_name in experiments:
26 | print(exp_name)
27 | train.slash_slot_attention(exp_name, experiments[exp_name])
28 |
29 |
30 |
31 |
--------------------------------------------------------------------------------
/src/experiments/slash_attention/cogent/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml-research/SLASH/79d158dbd670d56d8bf9043640635c24a18a9c57/src/experiments/slash_attention/cogent/__init__.py
--------------------------------------------------------------------------------
/src/experiments/slash_attention/cogent/dataGen.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | from torch.utils.data import Dataset
4 | from torchvision.transforms import transforms
5 |
6 | from torch.utils.data import Dataset
7 | from torchvision import transforms
8 | from skimage import io
9 | import os
10 | import numpy as np
11 | import torch
12 | from PIL import ImageFile
13 | ImageFile.LOAD_TRUNCATED_IMAGES = True
14 | import json
15 | import datasets
16 |
17 |
18 | def get_encoding(color, shape, shade, size):
19 |
20 | if color == 'red':
21 | col_enc = [1,0,0,0,0,0,0,0,0]
22 | elif color == 'blue':
23 | col_enc = [0,1,0,0,0,0,0,0,0]
24 | elif color == 'green':
25 | col_enc = [0,0,1,0,0,0,0,0,0]
26 | elif color == 'gray':
27 | col_enc = [0,0,0,1,0,0,0,0,0]
28 | elif color == 'brown':
29 | col_enc = [0,0,0,0,1,0,0,0,0]
30 | elif color == 'magenta':
31 | col_enc = [0,0,0,0,0,1,0,0,0]
32 | elif color == 'cyan':
33 | col_enc = [0,0,0,0,0,0,1,0,0]
34 | elif color == 'yellow':
35 | col_enc = [0,0,0,0,0,0,0,1,0]
36 | elif color == 'black':
37 | col_enc = [0,0,0,0,0,0,0,0,1]
38 |
39 |
40 | if shape == 'circle':
41 | shape_enc = [1,0,0,0]
42 | elif shape == 'triangle':
43 | shape_enc = [0,1,0,0]
44 | elif shape == 'square':
45 | shape_enc = [0,0,1,0]
46 | elif shape == 'bg':
47 | shape_enc = [0,0,0,1]
48 |
49 | if shade == 'bright':
50 | shade_enc = [1,0,0]
51 | elif shade =='dark':
52 | shade_enc = [0,1,0]
53 | elif shade == 'bg':
54 | shade_enc = [0,0,1]
55 |
56 |
57 | if size == 'small':
58 | size_enc = [1,0,0]
59 | elif size == 'big':
60 | size_enc = [0,1,0]
61 | elif size == 'bg':
62 | size_enc = [0,0,1]
63 |
64 | return col_enc + shape_enc + shade_enc + size_enc + [1]
65 |
66 |
67 | class SHAPEWORLD_COGENT(Dataset):
68 | def __init__(self, root, mode, ret_obj_encoding=False):
69 |
70 | datasets.maybe_download_shapeworld_cogent()
71 |
72 | self.ret_obj_encoding= ret_obj_encoding
73 | self.root = root
74 | self.mode = mode
75 | assert os.path.exists(root), 'Path {} does not exist'.format(root)
76 |
77 | #dictionary of the form {'image_idx':'img_path'}
78 | self.img_paths = {}
79 |
80 |
81 | for file in os.scandir(os.path.join(root, 'images', mode)):
82 | img_path = file.path
83 |
84 | img_path_idx = img_path.split("/")
85 | img_path_idx = img_path_idx[-1]
86 | img_path_idx = img_path_idx[:-4][6:]
87 | try:
88 | img_path_idx = int(img_path_idx)
89 | self.img_paths[img_path_idx] = img_path
90 | except:
91 | print("path:",img_path_idx)
92 |
93 |
94 |
95 | count = 0
96 |
97 | #target maps of the form {'target:idx': query string} or {'target:idx': obj encoding}
98 | self.query_map = {}
99 | self.obj_map = {}
100 |
101 | with open(os.path.join(root, 'labels', mode,"world_model.json")) as f:
102 | worlds = json.load(f)
103 |
104 | objects = 0
105 | bgs = 0
106 | #iterate over all objects
107 | for world in worlds:
108 | num_objects = 0
109 | target_query = ""
110 | obj_enc = []
111 | for entity in world['entities']:
112 |
113 | color = entity['color']['name']
114 | shape = entity['shape']['name']
115 |
116 | shade_val = entity['color']['shade']
117 | if shade_val == 0.0:
118 | shade = 'bright'
119 | else:
120 | shade = 'dark'
121 |
122 | size_val = entity['shape']['size']['x']
123 | if size_val == 0.075:
124 | size = 'small'
125 | elif size_val == 0.15:
126 | size = 'big'
127 |
128 | name = 'o' + str(num_objects+1)
129 | target_query = target_query+ ":- not object({},{},{},{},{}). ".format(name, color, shape, shade, size)
130 | obj_enc.append(get_encoding(color, shape, shade, size))
131 | num_objects += 1
132 | objects +=1
133 |
134 | #bg encodings
135 | for i in range(num_objects, 4):
136 | name = 'o' + str(num_objects+1)
137 | target_query = target_query+ ":- not object({},black,bg, bg, bg). ".format(name)
138 | obj_enc.append(get_encoding("black","bg","bg","bg"))
139 | num_objects += 1
140 | bgs +=1
141 |
142 | self.query_map[count] = target_query
143 | self.obj_map[count] = np.array(obj_enc)
144 | count+=1
145 | print("num objects",objects)
146 | print("num bgs",bgs)
147 |
148 |
149 |
150 | def __getitem__(self, index):
151 |
152 | #get the image
153 | img_path = self.img_paths[index]
154 | img = io.imread(img_path)[:, :, :3]
155 |
156 | transform = transforms.Compose([
157 | transforms.ToPILImage(),
158 | transforms.ToTensor(),
159 | ])
160 | img = transform(img)
161 | img = (img - 0.5) * 2.0 # Rescale to [-1, 1].
162 |
163 | if self.ret_obj_encoding:
164 | return {'im':img}, self.query_map[index] ,self.obj_map[index]
165 | else:
166 | return {'im':img}, self.query_map[index]
167 | def __len__(self):
168 | return len(self.img_paths)
169 |
170 |
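To make the encoding and the query construction above concrete, here is a minimal sketch (not part of the repository) of what get_encoding returns and what a query_map entry looks like for a world containing a single small bright red circle, padded to four objects via the bg branch above:

    # Illustration only: the 20-dim encoding is color(9) + shape(4) + shade(3) + size(3) + [1].
    col = [1,0,0,0,0,0,0,0,0]   # 'red'
    shp = [1,0,0,0]             # 'circle'
    shd = [1,0,0]               # 'bright'
    siz = [1,0,0]               # 'small'
    enc = col + shp + shd + siz + [1]   # == get_encoding('red', 'circle', 'bright', 'small')
    assert len(enc) == 20

    # The matching ASP query (one integrity constraint per slot, bg padding for o2..o4):
    query = (":- not object(o1,red,circle,bright,small). "
             ":- not object(o2,black,bg, bg, bg). "
             ":- not object(o3,black,bg, bg, bg). "
             ":- not object(o4,black,bg, bg, bg). ")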
--------------------------------------------------------------------------------
/src/experiments/slash_attention/cogent/slash_attention_cogent.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 |
5 | import train
6 | import numpy as np
7 | import torch
8 | import torchvision
9 | import datetime
10 |
11 | date_string = datetime.datetime.today().strftime('%d-%m-%Y')
12 |
13 | #Python script to start the shapeworld4 CoGenT slot attention experiment
14 | #Define your experiment(s) parameters as a dictionary with the following keys
15 | example_structure = {'experiment_name':
16 | {'structure': 'poon-domingos',
17 | 'pd_num_pieces': [4],
18 | 'lr': 0.01, #the learning rate to train the SPNs with, the slot attention module has a fixed lr=0.0004
19 | 'bs':512, #the batchsize
20 | 'epochs':1000, #number of epochs to train
21 | 'lr_warmup_steps':25, #number of epochs to warm up the slot attention module, warmup does not apply to the SPNs
22 | 'lr_decay_steps':100, #number of epochs it takes to decay to 50% of the specified lr
23 | 'start_date':"01-01-0001", #current date
24 | 'resume':False, #you can stop the experiment and set this parameter to true to load the last state and continue learning
25 | 'credentials':'AS', #your credentials for the rtpt class
26 |                          'explanation': """Training on Condition A, testing on Conditions A and B to evaluate the generalization of the model."""}}
27 |
28 |
29 |
30 | experiments ={'shapeworld4_cogent_hung':
31 | {'structure': 'poon-domingos', 'pd_num_pieces': [4],
32 | 'lr': 0.01, 'bs':512, 'epochs':1000,
33 | 'lr_warmup_steps':8, 'lr_decay_steps':360,
34 | 'start_date':date_string, 'resume':False,
35 | 'credentials':'DO', 'seed':3, 'learn_prior':True,
36 | 'p_num':16, 'hungarian_matching':True, 'method':'probabilistic_grounding_top_k',
37 |                 'explanation': """Training on Condition A, testing on Conditions A and B to evaluate the generalization of the model."""}}
38 |
39 |
40 |
41 | #train the network
42 | for exp_name in experiments:
43 | print(exp_name)
44 | train.slash_slot_attention(exp_name, experiments[exp_name])
45 |
46 |
47 |
48 |
49 |
50 |
--------------------------------------------------------------------------------
/src/experiments/slash_attention/shapeworld4/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml-research/SLASH/79d158dbd670d56d8bf9043640635c24a18a9c57/src/experiments/slash_attention/shapeworld4/__init__.py
--------------------------------------------------------------------------------
/src/experiments/slash_attention/shapeworld4/dataGen.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | 
4 | import numpy as np
5 | import torch
6 | import torchvision
7 | from torch.utils.data import Dataset
8 | from torchvision import transforms
9 | from skimage import io
10 | 
11 | from PIL import ImageFile
12 | ImageFile.LOAD_TRUNCATED_IMAGES = True
13 | 
14 | import datasets
15 | 
16 | 
17 |
18 | def get_encoding(color, shape, shade, size):
19 |
20 | if color == 'red':
21 | col_enc = [1,0,0,0,0,0,0,0,0]
22 | elif color == 'blue':
23 | col_enc = [0,1,0,0,0,0,0,0,0]
24 | elif color == 'green':
25 | col_enc = [0,0,1,0,0,0,0,0,0]
26 | elif color == 'gray':
27 | col_enc = [0,0,0,1,0,0,0,0,0]
28 | elif color == 'brown':
29 | col_enc = [0,0,0,0,1,0,0,0,0]
30 | elif color == 'magenta':
31 | col_enc = [0,0,0,0,0,1,0,0,0]
32 | elif color == 'cyan':
33 | col_enc = [0,0,0,0,0,0,1,0,0]
34 | elif color == 'yellow':
35 | col_enc = [0,0,0,0,0,0,0,1,0]
36 | elif color == 'black':
37 | col_enc = [0,0,0,0,0,0,0,0,1]
38 |
39 |
40 | if shape == 'circle':
41 | shape_enc = [1,0,0,0]
42 | elif shape == 'triangle':
43 | shape_enc = [0,1,0,0]
44 | elif shape == 'square':
45 | shape_enc = [0,0,1,0]
46 | elif shape == 'bg':
47 | shape_enc = [0,0,0,1]
48 |
49 | if shade == 'bright':
50 | shade_enc = [1,0,0]
51 | elif shade =='dark':
52 | shade_enc = [0,1,0]
53 | elif shade == 'bg':
54 | shade_enc = [0,0,1]
55 |
56 |
57 | if size == 'small':
58 | size_enc = [1,0,0]
59 | elif size == 'big':
60 | size_enc = [0,1,0]
61 | elif size == 'bg':
62 | size_enc = [0,0,1]
63 |
64 | return col_enc + shape_enc + shade_enc + size_enc + [1]
65 |
66 |
67 | class SHAPEWORLD4(Dataset):
68 | def __init__(self, root, mode, ret_obj_encoding=False):
69 |
70 | datasets.maybe_download_shapeworld4()
71 |
72 | self.ret_obj_encoding = ret_obj_encoding
73 | self.root = root
74 | self.mode = mode
75 | assert os.path.exists(root), 'Path {} does not exist'.format(root)
76 |
77 | #dictionary of the form {'image_idx':'img_path'}
78 | self.img_paths = {}
79 |
80 |
81 | for file in os.scandir(os.path.join(root, 'images', mode)):
82 | img_path = file.path
83 |
84 | img_path_idx = img_path.split("/")
85 | img_path_idx = img_path_idx[-1]
86 | img_path_idx = img_path_idx[:-4][6:]
87 | try:
88 | img_path_idx = int(img_path_idx)
89 | self.img_paths[img_path_idx] = img_path
90 |             except ValueError:
91 |                 print("skipping file with a non-integer index:", img_path_idx)
92 |
93 |
94 | count = 0
95 |
96 | #target maps of the form {'target:idx': query string} or {'target:idx': obj encoding}
97 | self.query_map = {}
98 | self.obj_map = {}
99 |
100 | with open(os.path.join(root, 'labels', mode,"world_model.json")) as f:
101 | worlds = json.load(f)
102 |
103 | #iterate over all objects
104 | for world in worlds:
105 | num_objects = 0
106 | target_query = ""
107 | obj_enc = []
108 | for entity in world['entities']:
109 |
110 | color = entity['color']['name']
111 | shape = entity['shape']['name']
112 |
113 | shade_val = entity['color']['shade']
114 | if shade_val == 0.0:
115 | shade = 'bright'
116 | else:
117 | shade = 'dark'
118 |
119 | size_val = entity['shape']['size']['x']
120 | if size_val == 0.075:
121 | size = 'small'
122 | elif size_val == 0.15:
123 | size = 'big'
124 |
125 | name = 'o' + str(num_objects+1)
126 | target_query = target_query+ ":- not object({},{},{},{},{}). ".format(name, color, shape, shade, size)
127 | obj_enc.append(get_encoding(color, shape, shade, size))
128 | num_objects += 1
129 |
130 | #bg encodings
131 | for i in range(num_objects, 4):
132 | name = 'o' + str(num_objects+1)
133 | target_query = target_query+ ":- not object({},black,bg, bg, bg). ".format(name)
134 | obj_enc.append(get_encoding("black","bg","bg","bg"))
135 | num_objects += 1
136 |
137 |
138 | self.query_map[count] = target_query
139 | self.obj_map[count] = np.array(obj_enc)
140 | count+=1
141 |
142 |
143 |
144 | def __getitem__(self, index):
145 |
146 | #get the image
147 | img_path = self.img_paths[index]
148 | img = io.imread(img_path)[:, :, :3]
149 |
150 | transform = transforms.Compose([
151 | transforms.ToPILImage(),
152 | transforms.ToTensor(),
153 | ])
154 | img = transform(img)
155 | img = (img - 0.5) * 2.0 # Rescale to [-1, 1].
156 |
157 | if self.ret_obj_encoding:
158 | return {'im':img}, self.query_map[index] ,self.obj_map[index]
159 | else:
160 | return {'im':img}, self.query_map[index]
161 |
162 | def __len__(self):
163 | return len(self.img_paths)
164 |
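A minimal usage sketch of the dataset class above (the root path and batch size are assumptions, and the file is assumed to be importable as dataGen). Each sample is a dict holding the image tensor plus the query string, and, with ret_obj_encoding=True, the 4x20 object-encoding matrix:

    from torch.utils.data import DataLoader
    from dataGen import SHAPEWORLD4   # this file

    dataset = SHAPEWORLD4(root='../../data/shapeworld4', mode='train', ret_obj_encoding=True)
    loader = DataLoader(dataset, batch_size=512, shuffle=True, num_workers=4)

    for inputs, queries, obj_encodings in loader:
        print(inputs['im'].shape)    # [B, 3, H, W], rescaled to [-1, 1]
        print(queries[0])            # ASP constraints for the first image in the batch
        print(obj_encodings.shape)   # [B, 4, 20]
        break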
--------------------------------------------------------------------------------
/src/experiments/slash_attention/shapeworld4/slash_attention_shapeworld4.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 |
5 | import train
6 | import datetime
7 |
8 |
9 | #Python script to start the shapeworld4 slot attention experiment
10 | #Define your experiment(s) parameters as a dictionary with the following keys
11 | example_structure = {'experiment_name':
12 | {'structure': 'poon-domingos',
13 | 'pd_num_pieces': [4],
14 | 'lr': 0.01, #the learning rate to train the SPNs with, the slot attention module has a fixed lr=0.0004
15 | 'bs':50, #the batchsize
16 | 'epochs':1000, #number of epochs to train
17 | 'lr_warmup':True, #boolean indicating the use of learning rate warm up
18 | 'lr_warmup_steps':25, #number of epochs to warm up the slot attention module, warmup does not apply to the SPNs
19 | 'start_date':"01-01-0001", #current date
20 | 'resume':False, #you can stop the experiment and set this parameter to true to load the last state and continue learning
21 | 'credentials':'DO', #your credentials for the rtpt class
22 | 'hungarian_matching': True,
23 | 'explanation': """Running the whole SlotAttention+Slash pipeline using poon-domingos as SPN structure learner."""}}
24 |
25 |
26 |
27 |
28 | #EXPERIMENTS
29 | date_string = datetime.datetime.today().strftime('%d-%m-%Y')
30 |
31 |
32 | for seed in [0,1,2,3,4]:
33 | experiments = {'shapeworld4':
34 | {'structure': 'poon-domingos', 'pd_num_pieces': [4],
35 | 'lr': 0.01, 'bs':512, 'epochs':1000,
36 | 'lr_warmup_steps':8, 'lr_decay_steps':360,
37 | 'start_date':date_string, 'resume':False, 'credentials':'DO','seed':seed,
38 | 'p_num':16, 'method':'same_top_k', 'hungarian_matching': False,
39 | 'explanation': """Running the whole SlotAttention+Slash pipeline using poon-domingos as SPN structure learner."""}
40 | }
41 |
42 |
43 | print("shapeworld4")
44 | train.slash_slot_attention("shapeworld4", experiments["shapeworld4"])
45 |
46 |
47 |
48 |
49 |
50 |
--------------------------------------------------------------------------------
/src/experiments/vqa/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml-research/SLASH/79d158dbd670d56d8bf9043640635c24a18a9c57/src/experiments/vqa/__init__.py
--------------------------------------------------------------------------------
/src/experiments/vqa/cmd_args2.py:
--------------------------------------------------------------------------------
1 | """
2 | The source code is based on:
3 | Scallop: From Probabilistic Deductive Databases to Scalable Differentiable Reasoning
4 | Jiani Huang, Ziyang Li, Binghong Chen, Karan Samel, Mayur Naik, Le Song, Xujie Si
5 | Advances in Neural Information Processing Systems 34 (NeurIPS 2021)
6 | https://proceedings.neurips.cc/paper/2021/hash/d367eef13f90793bd8121e2f675f0dc2-Abstract.html
7 | """
8 |
9 |
10 | import argparse
11 | import sys
12 | import random
13 | import numpy as np
14 | import torch
15 | import os
16 | import logging
17 |
18 | # Utility function to take in yes/no and convert it to a boolean
19 | def convert_str_to_bool(cmd_args):
20 | for key, val in vars(cmd_args).items():
21 | if val == "yes":
22 | setattr(cmd_args, key, True)
23 | elif val == "no":
24 | setattr(cmd_args, key, False)
25 |
26 | def str2bool(v):
27 | if isinstance(v, bool):
28 | return v
29 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
30 | return True
31 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
32 | return False
33 | else:
34 | raise argparse.ArgumentTypeError('Boolean value expected.')
35 |
36 | class LearningSetting(object):
37 |
38 | def __init__(self):
39 | data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../data"))
40 | dataset_dir = os.path.join(data_dir, "dataset")
41 | knowledge_base_dir = os.path.join(data_dir, "knowledge_base")
42 |
43 | self.parser = argparse.ArgumentParser(description='Argparser', allow_abbrev=True)
44 | # Logistics
45 | self.parser.add_argument("--seed", default=1234, type=int, help="set random seed")
46 | self.parser.add_argument("--gpu", default=0, type=int, help="GPU id")
47 | self.parser.add_argument('--timeout', default=600, type=int, help="Execution timeout")
48 |
49 | # Learning settings
50 | self.parser.add_argument("--name_threshold", default=0, type=float)
51 | self.parser.add_argument("--attr_threshold", default=0, type=float)
52 | self.parser.add_argument("--rela_threshold", default=0, type=float)
53 | # self.parser.add_argument("--name_topk_n", default=-1, type=int)
54 | # self.parser.add_argument("--attr_topk_n", default=-1, type=int)
55 | # self.parser.add_argument("--rela_topk_n", default=80, type=int)
56 | self.parser.add_argument("--topk", default=3, type=int)
57 |
58 | # training settings
59 | self.parser.add_argument('--feat_dim', type=int, default=2048)
60 | self.parser.add_argument('--n_epochs', type=int, default=20)
61 | self.parser.add_argument('--batch_size', type=int, default=1)
62 | self.parser.add_argument('--max_workers', type=int, default=1)
63 | self.parser.add_argument('--axiom_update_size', type=int, default=4)
64 | self.parser.add_argument('--name_lr', type=float, default=0.0001)
65 | self.parser.add_argument('--attr_lr', type=float, default=0.0001)
66 | self.parser.add_argument('--rela_lr', type=float, default=0.0001)
67 | self.parser.add_argument('--reinforce', type=str2bool, nargs='?', const=True, default=False) # reinforce only support single thread
68 | self.parser.add_argument('--replays', type=int, default=5)
69 |
70 | self.parser.add_argument('--model_dir', default=data_dir+'/model_ckpts_sg')
71 | # self.parser.add_argument('--model_dir', default=None)
72 | self.parser.add_argument('--log_name', default='model.log')
73 | self.parser.add_argument('--feat_f', default=data_dir+'/features.npy')
74 | self.parser.add_argument('--train_f', default=dataset_dir+'/task_list/train_tasks_c2_10.pkl')
75 | self.parser.add_argument('--val_f', default=dataset_dir+'/task_list/val_tasks.pkl')
76 | self.parser.add_argument('--test_f', default=dataset_dir+'/task_list/test_tasks_c2_1000.pkl')
77 | self.parser.add_argument('--cul_prov', type=bool, default=False)
78 |
79 | self.parser.add_argument('--meta_f', default=data_dir+'/gqa_info.json')
80 | self.parser.add_argument('--scene_f', default=data_dir+'/gqa_formatted_scene_graph.pkl')
81 | self.parser.add_argument('--image_data_f', default=data_dir+'/image_data.json')
82 | self.parser.add_argument('--dataset_type', default='name') # name, relation, attr:
83 |
84 | self.parser.add_argument('--function', default=None) #KG_Find / Hypernym_Find / Find_Name / Find_Attr / Relate / Relate_Reverse / And / Or
85 | self.parser.add_argument('--knowledge_base_dir', default=knowledge_base_dir)
86 | # self.parser.add_argument('--interp_size', type=int, default=2)
87 |
88 | self.parser.add_argument('--save_dir', default=data_dir+"/problog_data")
89 | self.args = self.parser.parse_args(sys.argv[1:])
90 |
91 | ls = LearningSetting()
92 | cmd_args = ls.args
93 | # print(cmd_args)
94 |
95 | # Fix random seed for debugging purposes
96 | if ls.args.seed is not None:
97 | random.seed(ls.args.seed)
98 | np.random.seed(ls.args.seed)
99 | torch.manual_seed(ls.args.seed)
100 |
101 | if cmd_args.gpu is not None:
102 | os.environ['CUDA_VISIBLE_DEVICES'] = str(cmd_args.gpu)
103 | torch.set_default_tensor_type(torch.cuda.FloatTensor)
104 |
105 | if cmd_args.model_dir is not None:
106 | if not os.path.exists(cmd_args.model_dir):
107 | os.makedirs(cmd_args.model_dir)
108 |
109 | log_path = os.path.join(cmd_args.model_dir, cmd_args.log_name)
110 | logging.basicConfig(filename=log_path, filemode='w', level=logging.INFO, format='%(name)s - %(levelname)s - %(message)s')
111 | logging.info(cmd_args)
112 | logging.info("start!")
113 |
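cmd_args is built from sys.argv at import time, so any script importing this module is launched with the flags defined above. A hypothetical usage sketch (flag values are illustrative; importing the module assumes a CUDA-capable machine because it sets the default tensor type to cuda):

    # e.g. invoked as: python some_script.py --seed 1234 --topk 5 --reinforce no
    from cmd_args2 import cmd_args

    print(cmd_args.seed, cmd_args.topk)   # parsed values (defaults: 1234, 3)
    print(cmd_args.reinforce)             # str2bool maps yes/no to True/False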
--------------------------------------------------------------------------------
/src/experiments/vqa/models.py:
--------------------------------------------------------------------------------
1 | """
2 | The source code is based on:
3 | Scallop: From Probabilistic Deductive Databases to Scalable Differentiable Reasoning
4 | Jiani Huang, Ziyang Li, Binghong Chen, Karan Samel, Mayur Naik, Le Song, Xujie Si
5 | Advances in Neural Information Processing Systems 34 (NeurIPS 2021)
6 | https://proceedings.neurips.cc/paper/2021/hash/d367eef13f90793bd8121e2f675f0dc2-Abstract.html
7 | """
8 |
9 | import numpy as np
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | import os
14 | import sys
15 | torch.autograd.set_detect_anomaly(True)
16 |
17 | sys.path.append('../../EinsumNetworks/src/')
18 | from einsum_wrapper import EiNet
19 |
20 |
21 | class MLPClassifier(nn.Module):
22 | def __init__(self, input_dim, latent_dim, output_dim, n_layers, dropout_rate, softmax):
23 | super(MLPClassifier, self).__init__()
24 |
25 | self.softmax = softmax
26 | self.output_dim = output_dim
27 |
28 | layers = []
29 | layers.append(nn.Linear(input_dim, latent_dim))
30 | layers.append(nn.ReLU())
31 | layers.append(nn.BatchNorm1d(latent_dim))
32 | #layers.append(nn.InstanceNorm1d(latent_dim))
33 | layers.append(nn.Dropout(dropout_rate))
34 | for _ in range(n_layers - 1):
35 | layers.append(nn.Linear(latent_dim, latent_dim))
36 | layers.append(nn.ReLU())
37 | layers.append(nn.BatchNorm1d(latent_dim))
38 | #layers.append(nn.InstanceNorm1d(latent_dim))
39 | layers.append(nn.Dropout(dropout_rate))
40 | layers.append(nn.Linear(latent_dim, output_dim))
41 |
42 |
43 | self.net = nn.Sequential(*layers)
44 |
45 | def forward(self, x, marg_idx=None, type=None):
46 | if x.sum() == 0:
47 | return torch.ones([x.shape[0], self.output_dim], device='cuda')
48 |
49 | idx = x.sum(dim=1)!=0 # get idx of true objects
50 | logits = torch.zeros(x.shape[0], self.output_dim, device='cuda')
51 |
52 | logits[idx] = self.net(x[idx])
53 |
54 |
55 | if self.softmax:
56 | probs = F.softmax(logits, dim=1)
57 | else:
58 | probs = torch.sigmoid(logits)
59 |
60 | return probs
61 |
62 | # Faster R-CNN object feature size
63 | feature_dim = 2048
64 |
65 | name_clf = MLPClassifier(
66 | input_dim=feature_dim,
67 | output_dim=500,
68 | latent_dim=1024,
69 | n_layers=2,
70 | dropout_rate=0.3,
71 | softmax=True
72 | )
73 |
74 | rela_clf = MLPClassifier(
75 | input_dim=(feature_dim+4)*2,
76 | output_dim=229,
77 | latent_dim=1024,
78 | n_layers=1,
79 | dropout_rate=0.5,
80 | softmax=True
81 | )
82 |
83 |
84 | attr_clf = MLPClassifier(
85 | input_dim=feature_dim,
86 | output_dim=609,
87 | latent_dim=1024,
88 | n_layers=1,
89 | dropout_rate=0.3,
90 | softmax=False
91 | )
92 |
93 |
94 |
95 | name_einet = EiNet(structure = 'poon-domingos',
96 | pd_num_pieces = [4],
97 | use_em = False,
98 | num_var = 2048,
99 | class_count = 500,
100 | pd_width = 32,
101 | pd_height = 64,
102 | learn_prior = True)
103 | rela_einet = EiNet(structure = 'poon-domingos',
104 | pd_num_pieces = [4],
105 | use_em = False,
106 | num_var = 4104,
107 | class_count = 229,
108 | pd_width = 72,
109 | pd_height = 57,
110 | learn_prior = True)
111 |
112 | attr_einet = EiNet(structure = 'poon-domingos',
113 | pd_num_pieces = [4],
114 | use_em = False,
115 | num_var = 2048,
116 | class_count = 609,
117 | pd_width = 32,
118 | pd_height = 64,
119 | learn_prior = True)
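A note on the dimensions chosen above (an observation added for clarity, not part of the original file): the relation models consume two concatenated object features plus their 4-dim bounding boxes, giving (2048 + 4) * 2 = 4104 inputs, which matches the relation EiNet grid of 72 * 57; the name and attribute models work on a single 2048-dim feature, matching the 32 * 64 grid.

    # Quick consistency check of the dimensions used above.
    feature_dim = 2048
    assert (feature_dim + 4) * 2 == 4104 == 72 * 57   # relation input / relation EiNet grid
    assert feature_dim == 32 * 64                      # name and attribute EiNet grids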
--------------------------------------------------------------------------------
/src/experiments/vqa/network_nn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class Net_nn(nn.Module):
5 | def __init__(self):
6 | super(Net_nn, self).__init__()
7 | self.encoder = nn.Sequential(
8 |             nn.Conv2d(1, 6, 5),  # 6 is the output channel size; 5 is the kernel size; 1 (channel) 28 28 -> 6 24 24
9 |             nn.MaxPool2d(2, 2),  # kernel size 2; stride size 2; 6 24 24 -> 6 12 12
10 | nn.ReLU(True), # inplace=True means that it will modify the input directly thus save memory
11 | nn.Conv2d(6, 16, 5), # 6 12 12 -> 16 8 8
12 | nn.MaxPool2d(2, 2), # 16 8 8 -> 16 4 4
13 | nn.ReLU(True)
14 | )
15 | self.classifier = nn.Sequential(
16 | nn.Linear(16 * 4 * 4, 120),
17 | nn.ReLU(),
18 | nn.Linear(120, 84),
19 | nn.ReLU(),
20 | nn.Linear(84, 10),
21 | nn.Softmax(1)
22 | )
23 |
24 | def forward(self, x, marg_idx=None, type=1):
25 |
26 | assert type == 1, "only posterior computations are available for this network"
27 |
28 | # If the list of the pixel numbers to be marginalised is given,
29 |         # then generate a marginalisation mask from it and apply it to the
30 | # tensor 'x'
31 | if marg_idx:
32 | batch_size = x.shape[0]
33 | with torch.no_grad():
34 | marg_mask = torch.ones_like(x, device=x.device).reshape(batch_size, 1, -1)
35 | marg_mask[:, :, marg_idx] = 0
36 | marg_mask = marg_mask.reshape_as(x)
37 | marg_mask.requires_grad_(False)
38 | x = torch.einsum('ijkl,ijkl->ijkl', x, marg_mask)
39 | x = self.encoder(x)
40 | x = x.view(-1, 16 * 4 * 4)
41 | x = self.classifier(x)
42 | return x
43 |
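The einsum 'ijkl,ijkl->ijkl' used in the forward pass above is just an element-wise product, i.e. the listed pixels are zeroed out before the image is encoded. A standalone sketch of the same masking step (illustrative shapes only):

    import torch

    x = torch.rand(2, 1, 28, 28)   # a batch of two MNIST-sized images
    marg_idx = [0, 1, 2]           # pixel indices to marginalise out
    mask = torch.ones(2, 1, 28 * 28)
    mask[:, :, marg_idx] = 0
    mask = mask.reshape_as(x)

    x_masked = x * mask            # identical to torch.einsum('ijkl,ijkl->ijkl', x, mask)
    assert torch.allclose(x_masked, torch.einsum('ijkl,ijkl->ijkl', x, mask))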
--------------------------------------------------------------------------------
/src/experiments/vqa/preprocess.py:
--------------------------------------------------------------------------------
1 | """
2 | The source code is based on:
3 | Scallop: From Probabilistic Deductive Databases to Scalable Differentiable Reasoning
4 | Jiani Huang, Ziyang Li, Binghong Chen, Karan Samel, Mayur Naik, Le Song, Xujie Si
5 | Advances in Neural Information Processing Systems 34 (NeurIPS 2021)
6 | https://proceedings.neurips.cc/paper/2021/hash/d367eef13f90793bd8121e2f675f0dc2-Abstract.html
7 | """
8 |
9 | # Prog = Conj | Logic
10 | # Logic = biOp Prog Prog
11 | # Conj = and Des Conj | Des
12 | # Des = rela Relaname O1 O2 | attr Attr O
13 |
14 | class Variable():
15 | def __init__(self, id):
16 | self.var_id = f"O{id}"
17 | self.name_id = f"N{id}"
18 | self.name = []
19 | self.hypernyms = []
20 | self.attrs = []
21 | self.kgs = []
22 | # The relations where this object functions as a subject
23 | self.sub_relas = []
24 | self.obj_relas = []
25 |
26 | def has_rela(self):
27 | if len(self.sub_relas) == 0 and len(self.obj_relas) == 0:
28 | return False
29 | return True
30 |
31 | def get_name_id(self):
32 | # if (not len(self.hypernyms) == 0) and len(self.name) == 0:
33 | # return True, self.name_id
34 | # if (not len(self.kgs) == 0) and len(self.name) == 0:
35 | # return True, self.name_id
36 | return False, self.name_id
37 |
38 | def set_name(self, name):
39 | if name in self.name:
40 | return
41 |
42 | self.name.append(name)
43 |
44 | def set_kg(self, kg):
45 | if kg not in self.kgs:
46 | self.kgs.append(kg)
47 |
48 | def set_hypernym(self, hypernym):
49 | if hypernym not in self.hypernyms:
50 | self.hypernyms.append(hypernym)
51 |
52 | def set_attr(self, attr):
53 | if attr not in self.attrs:
54 | self.attrs.append(attr)
55 |
56 | def set_obj_relas(self, obj_rela):
57 | if obj_rela not in self.obj_relas:
58 | self.obj_relas.append(obj_rela)
59 |
60 | def set_sub_relas(self, sub_rela):
61 | if sub_rela not in self.sub_relas:
62 | self.sub_relas.append(sub_rela)
63 |
64 | def get_neighbor(self):
65 | neighbors = []
66 | for rela in self.sub_relas:
67 | neighbors.append(rela.obj)
68 | for rela in self.obj_relas:
69 | neighbors.append(rela.sub)
70 | return neighbors
71 |
72 | def update(self, other):
73 |
74 | self.hypernyms = list(set(self.name + other.name))
75 | self.hypernyms = list(set(self.hypernyms + other.hypernyms))
76 | self.attrs = list(set(self.attrs + other.attrs))
77 | self.kgs = list(set(self.kgs + other.kgs))
78 |
79 |
80 | def to_datalog(self, with_name=False, with_rela=True):
81 |
82 | name_query = []
83 |
84 | if (len(self.name) == 0) and with_name:
85 | name_query.append("name({}, {})".format(self.name_id.replace(" ","_").replace(".","_"), self.var_id.replace(" ","_").replace(".","_")))
86 |
87 | if (not len(self.name) == 0):
88 | for n in self.name:
89 | #name_query.append(f"name(\"{n}\", {self.var_id})")
90 | n = n.replace(" ","_").replace(".","_")
91 | #name_query.append(f"name(0,1,{self.var_id},{n})")
92 | name_query.append(f"name({self.var_id},{n})")
93 |
94 |
95 | #attr_query = [ f"attr(\"{attr}\", {self.var_id})" for attr in self.attrs]
96 | #attr_query = [ "attr(0,1,{}, {})".format(self.var_id.replace(" ","_"),attr.replace(" ","_")) for attr in self.attrs]
97 | attr_query = [ "attr({}, {})".format(self.var_id.replace(" ","_"),attr.replace(" ","_")) for attr in self.attrs]
98 |
99 | #hypernym_query = [f"name(\"{hypernym}\", {self.var_id})" for hypernym in self.hypernyms]
100 | #hypernym_query = ["name(0,1,{}, {})".format(self.var_id.replace(" ","_").replace(".","_"),hypernym.replace(" ","_").replace(".","_")) for hypernym in self.hypernyms]
101 | hypernym_query = ["name({}, {})".format(self.var_id.replace(" ","_").replace(".","_"),hypernym.replace(" ","_").replace(".","_")) for hypernym in self.hypernyms]
102 |
103 | kg_query = []
104 |
105 | for kg in self.kgs:
106 | restriction = list(filter(lambda x: not x == 'BLANK' and not x == '', kg))
107 | assert (len(restriction) == 2)
108 | rel = restriction[0].replace(" ","_")
109 | usage = restriction[1].replace(" ","_")
110 | #kg_query += [f"name({self.name_id}, {self.var_id}), oa_rel({rel}, {self.name_id}, {usage})"]
111 | #kg_query += ["name({}, {}), oa_rel({}, {}, {})".format(self.name_id.replace(" ","_").replace(".","_") ,self.var_id.replace(" ","_").replace(".","_") ,rel.replace(" ","_").replace(".","_"), self.name_id.replace(" ","_").replace(".","_"), usage.replace(" ","_").replace(".","_"))]
112 | #kg_query += ["name(0,1,{}, {}), oa_rel({}, {}, {})".format(self.var_id.replace(" ","_").replace(".","_") ,self.name_id.replace(" ","_").replace(".","_"), rel.replace(" ","_").replace(".","_"), self.name_id.replace(" ","_").replace(".","_"), usage.replace(" ","_").replace(".","_"))]
113 | kg_query += ["name({}, {}), oa_rel({}, {}, {})".format(self.var_id.replace(" ","_").replace(".","_") ,self.name_id.replace(" ","_").replace(".","_"), rel.replace(" ","_").replace(".","_"), self.name_id.replace(" ","_").replace(".","_"), usage.replace(" ","_").replace(".","_"))]
114 |
115 | if with_rela:
116 | rela_query = [rela.to_datalog() for rela in self.sub_relas]
117 | else:
118 | rela_query = []
119 |
120 | program = name_query + attr_query + hypernym_query + kg_query + rela_query
121 |
122 | #print(program)
123 | return program
124 |
125 | class Relation():
126 | def __init__(self, rela_name, sub, obj):
127 | self.rela_name = rela_name
128 | self.sub = sub
129 | self.obj = obj
130 | self.sub.set_sub_relas(self)
131 | self.obj.set_obj_relas(self)
132 |
133 | def substitute(self, v1, v2):
134 | if self.sub == v1:
135 | self.sub = v2
136 | if self.obj == v1:
137 | self.obj = v2
138 |
139 | def to_datalog(self):
140 | #rela_query = f"relation(\"{self.rela_name}\", {self.sub.var_id}, {self.obj.var_id})"
141 | #rela_query = "relation(0,1,{}, {}, {})".format( self.sub.var_id.replace(" ","_"), self.obj.var_id.replace(" ","_"),self.rela_name.replace(" ","_"))
142 | rela_query = "relation({}, {}, {})".format( self.sub.var_id.replace(" ","_"), self.obj.var_id.replace(" ","_"),self.rela_name.replace(" ","_"))
143 |
144 | return rela_query
145 |
146 |
147 | # This is for binary operations on variables
148 | class BiOp():
149 | def __init__(self, op_name, v1, v2):
150 | self.op_name = op_name
151 | self.v1 = v1
152 | self.v2 = v2
153 |
154 | def to_datalog(self):
155 | raise NotImplementedError
156 |
157 | class Or(BiOp):
158 | def __init__(self, v1, v2):
159 | super().__init__('or', v1, v2)
160 |
161 | def to_datalog(self):
162 | pass
163 |
164 | class And(BiOp):
165 | def __init__(self, v1, v2):
166 | super().__init__('and', v1, v2)
167 |
168 | def to_datalog(self):
169 | pass
170 |
171 | class Query():
172 |
173 | def __init__(self, query):
174 | self.vars = []
175 | self.relations = []
176 | self.operations = []
177 | self.stack = []
178 | self.preprocess(query)
179 |
180 |
181 | def get_target(self):
182 | pass
183 |
184 | def get_new_var(self):
185 | self.vars.append(Variable(len(self.vars)))
186 |
187 | def preprocess(self, query):
188 |
189 | # for clause in question["program"]:
190 | for clause in query:
191 |
192 | if clause['function'] == "Initial":
193 | if not len(self.vars) == 0:
194 | self.stack.append(self.vars[-1])
195 | self.get_new_var()
196 | self.root = self.vars[-1]
197 |
198 | # logic operations
199 | elif clause['function'] == "And":
200 | v = self.stack.pop()
201 | self.operations.append(And(v, self.vars[-1]))
202 | self.root = self.operations[-1]
203 |
204 | elif clause['function'] == "Or":
205 | v = self.stack.pop()
206 | self.operations.append(Or(v, self.vars[-1]))
207 | self.root = self.operations[-1]
208 |
209 | # find operations
210 | elif clause['function'] == "KG_Find":
211 | self.vars[-1].set_kg(clause['text_input'])
212 |
213 | elif clause['function'] == "Hypernym_Find":
214 | self.vars[-1].set_hypernym(clause['text_input'])
215 |
216 | elif clause['function'] == "Find_Name":
217 | self.vars[-1].set_name(clause['text_input'])
218 |
219 | elif clause['function'] == "Find_Attr":
220 | self.vars[-1].set_attr(clause['text_input'])
221 |
222 | elif clause['function'] == "Relate_Reverse":
223 | self.get_new_var()
224 | self.root = self.vars[-1]
225 | obj = self.vars[-2]
226 | sub = self.vars[-1]
227 | rela_name = clause['text_input']
228 | relation = Relation(rela_name, sub, obj)
229 | self.relations.append(relation)
230 |
231 | elif clause['function'] == "Relate":
232 | self.get_new_var()
233 | self.root = self.vars[-1]
234 | sub = self.vars[-2]
235 | obj = self.vars[-1]
236 | rela_name = clause['text_input']
237 | relation = Relation(rela_name, sub, obj)
238 | self.relations.append(relation)
239 |
240 | else:
241 | raise Exception(f"Not handled function: {clause['function']}")
242 |
243 |
244 | # Optimizers that simplify queries before translation
245 | class QueryOptimizer():
246 |
247 | def __init__(self, name):
248 | self.name = name
249 |
250 | def optimize(self, query):
251 | raise NotImplementedError
252 |
253 | # This only works for a single And operation at the end
254 | # This is awaiting an update
255 | # class AndQueryOptimizer(QueryOptimizer):
256 |
257 | # def __init__(self):
258 | # super().__init__("AndQueryOptimizer")
259 |
260 | # # For any and operation, this can be rewritten as a single object
261 | # def optimize(self, query):
262 |
263 | # if len(query.operations) == 0:
264 | # return query
265 |
266 | # assert(len(query.operations) == 1)
267 |
268 | # operation = query.operations[0]
269 | # # merge every subtree into one
270 | # if operation.name == "and":
271 | # v1 = operation.v1
272 | # v2 = operation.v2
273 | # v1.merge(v2)
274 |
275 | # for relation in query.relations:
276 | # relation.substitute(v2, v1)
277 |
278 | # query.vars.remove(v2)
279 |
280 | # if query.root == operation:
281 | # query.root = v1
282 |
283 | # return query
284 |
285 |
286 | class HypernymOptimizer(QueryOptimizer):
287 |
288 | def __init__(self):
289 | super().__init__("HypernymOptimizer")
290 |
291 | def optimize(self, query):
292 |
293 | if (query.name is not None and not query.hypernyms == []):
294 | query.hypernyms = []
295 |
296 | return query
297 |
298 |
299 | class KGOptimizer(QueryOptimizer):
300 |
301 | def __init__(self):
302 |         super().__init__("KGOptimizer")
303 |
304 | def optimize(self, query):
305 |
306 | if (query.name is not None and not query.kgs == []):
307 | query.kgs = []
308 |
309 | return query
310 |
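A small illustration of the classes above (clause values are made up, and the module is assumed to be importable as preprocess): an Initial clause allocates a fresh variable, the Find_* clauses annotate it, and to_datalog() emits the corresponding atoms:

    from preprocess import Query

    clauses = [
        {'function': 'Initial'},
        {'function': 'Find_Name', 'text_input': 'dog'},
        {'function': 'Find_Attr', 'text_input': 'black'},
    ]
    q = Query(clauses)
    print(q.root.var_id)        # 'O0'
    print(q.root.to_datalog())  # ['name(O0,dog)', 'attr(O0, black)']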
--------------------------------------------------------------------------------
/src/experiments/vqa/query_lib.py:
--------------------------------------------------------------------------------
1 | """
2 | The source code is based on:
3 | Scallop: From Probabilistic Deductive Databases to Scalable Differentiable Reasoning
4 | Jiani Huang, Ziyang Li, Binghong Chen, Karan Samel, Mayur Naik, Le Song, Xujie Si
5 | Advances in Neural Information Processing Systems 34 (NeurIPS 2021)
6 | https://proceedings.neurips.cc/paper/2021/hash/d367eef13f90793bd8121e2f675f0dc2-Abstract.html
7 | """
8 |
9 | import pickle
10 | import os
11 | import subprocess
12 |
13 | from transformer import DetailTransformer, SimpleTransformer
14 | from preprocess import Query
15 | from knowledge_graph import KG, RULES
16 |
17 | # animals = ["giraffe", "cat", "kitten", "dog", "puppy", "poodle", "bull", "cow", "cattle", "bison", "calf", "pig", "ape", "monkey", "gorilla", "rat", "squirrel", "hamster", "deer", "moose", "alpaca", "elephant", "goat", "sheep", "lamb", "antelope", "rhino", "hippo", "zebra", "horse", "pony", "donkey", "camel", "panda", "panda bear", "bear", "polar bear", "seal", "fox", "raccoon", "tiger", "wolf", "lion", "leopard", "cheetah", "badger", "rabbit", "bunny", "beaver", "kangaroo", "dinosaur", "dragon", "fish", "whale", "dolphin", "crab", "shark", "octopus", "lobster", "oyster", "butterfly", "bee", "fly", "ant", "firefly", "snail", "spider", "bird", "penguin", "pigeon", "seagull", "finch", "robin", "ostrich", "goose", "owl", "duck", "hawk", "eagle", "swan", "chicken", "hen", "hummingbird", "parrot", "crow", "flamingo", "peacock", "bald eagle", "dove", "snake", "lizard", "alligator", "turtle", "frog", "animal"]
18 | class QueryManager():
19 | def __init__(self, save_dir):
20 | self.save_dir = save_dir
21 | self.transformer = SimpleTransformer()
22 |
23 | def save_file(self, file_name, content):
24 | save_path = os.path.join (self.save_dir, file_name)
25 | with open (save_path, "w") as save_file:
26 | save_file.write(content)
27 |
28 | def delete_file(self, file_name):
29 | save_path = os.path.join (self.save_dir, file_name)
30 | if os.path.exists(save_path):
31 | os.remove(save_path)
32 |
33 | def fact_prob_to_file(self, fact_tps, fact_probs):
34 | scene_tps = []
35 |
36 | (name_tps, attr_tps, rela_tps) = fact_tps
37 | (name_probs, attr_probs, rela_probs) = fact_probs
38 |
39 | cluster_ntp = {}
40 | for (oid, name), prob in zip(name_tps, name_probs):
41 | if not oid in cluster_ntp:
42 | cluster_ntp[oid] = [(name, prob)]
43 | else:
44 | cluster_ntp[oid].append((name, prob))
45 |
46 |
47 | for oid, name_list in cluster_ntp.items():
48 | name_tps = []
49 | for (name, prob) in name_list:
50 | # if not name in animals[:5]:
51 | # continue
52 | name_tps.append(f'{prob}::name("{name}", {int(oid)})')
53 | name_content = ";\n".join(name_tps) + "."
54 | scene_tps.append(name_content)
55 |
56 | for attr_tp, prob in zip(attr_tps, attr_probs):
57 | # if not attr_tp[1] == "tall":
58 | # continue
59 | scene_tps.append(f'{prob}::attr("{attr_tp[1]}", {int(attr_tp[0])}).')
60 |
61 | for rela_tp, prob in zip(rela_tps, rela_probs):
62 | # if not rela_tp[0] == "left":
63 | # continue
64 | scene_tps.append(f'{prob}::relation("{rela_tp[0]}", {int(rela_tp[1])}, {int(rela_tp[2])}).')
65 |
66 | return "\n".join(scene_tps)
67 |
68 | def process_result(self, result):
69 | output = result.stdout.decode()
70 | lines = output.split("\n")
71 | targets = {}
72 | for line in lines:
73 | if line == '':
74 | continue
75 | if not '\t' in line:
76 | continue
77 | info = line.split('\t')
78 | # No target found
79 | if 'X' in info[0]:
80 | break
81 | target_name = int(info[0][7:-2])
82 | target_prob = float(info[1])
83 | targets[target_name] = target_prob
84 | return targets
85 |
86 | def get_result(self, task, fact_tps, fact_probs):
87 | timeout = False
88 | question = task["question"]["clauses"]
89 | file_name = f"{task['question']['question_id']}.pl"
90 | save_path = os.path.join (self.save_dir, file_name)
91 | query = Query(question)
92 | query_content = self.transformer.transform(query)
93 | scene_content = self.fact_prob_to_file(fact_tps, fact_probs)
94 |
95 | content = KG+ "\n" + RULES + "\n" + scene_content + "\n" + query_content
96 | self.save_file(file_name, content)
97 | try:
98 | result = subprocess.run(["problog", save_path], capture_output=True, timeout=10)
99 | targets = self.process_result(result)
100 |         except Exception:
101 |             # ProbLog timed out or its output could not be parsed
102 | timeout = True
103 | targets = {}
104 |
105 | # self.delete_file(file_name)
106 | return targets, timeout
107 |
108 | def get_relas(self, query):
109 | relations = []
110 | for clause in query:
111 | if 'Relate' in clause['function']:
112 | relations.append(clause['text_input'])
113 | return relations
114 |
115 |
116 |
117 |
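For illustration (object ids and probabilities are invented; constructing the QueryManager assumes the repository's transformer module is importable): fact_prob_to_file turns classifier outputs into ProbLog facts, grouping the name candidates of one object into an annotated disjunction while attribute and relation facts are emitted individually:

    from query_lib import QueryManager

    qm = QueryManager(save_dir='/tmp')   # save_dir is only used when writing .pl files

    fact_tps   = ([(0, 'dog'), (0, 'cat')], [(0, 'black')], [('left', 0, 1)])
    fact_probs = ([0.9, 0.1], [0.8], [0.7])

    print(qm.fact_prob_to_file(fact_tps, fact_probs))
    # 0.9::name("dog", 0);
    # 0.1::name("cat", 0).
    # 0.8::attr("black", 0).
    # 0.7::relation("left", 0, 1).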
--------------------------------------------------------------------------------
/src/experiments/vqa/sg_model.py:
--------------------------------------------------------------------------------
1 | """
2 | The source code is based on:
3 | Scallop: From Probabilistic Deductive Databases to Scalable Differentiable Reasoning
4 | Jiani Huang, Ziyang Li, Binghong Chen, Karan Samel, Mayur Naik, Le Song, Xujie Si
5 | Advances in Neural Information Processing Systems 34 (NeurIPS 2021)
6 | https://proceedings.neurips.cc/paper/2021/hash/d367eef13f90793bd8121e2f675f0dc2-Abstract.html
7 | """
8 |
9 | import numpy as np
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | import os
14 | torch.autograd.set_detect_anomaly(True)
15 |
16 | def load_model(model, model_f, device):
17 | print('loading model from %s' % model_f)
18 | model.load_state_dict(torch.load(model_f, map_location=device))
19 | model.eval()
20 |
21 | class MLPClassifier(nn.Module):
22 | def __init__(self, input_dim, latent_dim, output_dim, n_layers, dropout_rate):
23 | super(MLPClassifier, self).__init__()
24 |
25 | layers = []
26 | layers.append(nn.Linear(input_dim, latent_dim))
27 | layers.append(nn.ReLU())
28 | layers.append(nn.BatchNorm1d(latent_dim))
29 | layers.append(nn.Dropout(dropout_rate))
30 | for _ in range(n_layers - 1):
31 | layers.append(nn.Linear(latent_dim, latent_dim))
32 | layers.append(nn.ReLU())
33 | layers.append(nn.BatchNorm1d(latent_dim))
34 | layers.append(nn.Dropout(dropout_rate))
35 | layers.append(nn.Linear(latent_dim, output_dim))
36 |
37 | self.net = nn.Sequential(*layers)
38 |
39 | def forward(self, x):
40 | logits = self.net(x)
41 | return logits
42 |
43 | class SceneGraphModel:
44 | def __init__(self, feat_dim, n_names, n_attrs, n_rels, device, model_dir=None):
45 | self.feat_dim = feat_dim
46 | self.n_names = n_names
47 | self.n_attrs = n_attrs
48 | self.n_rels = n_rels
49 | self.device = device
50 |
51 | self._init_models()
52 | if model_dir is not None:
53 | self._load_models(model_dir)
54 |
55 | def _load_models(self, model_dir):
56 | for type in ['name', 'relation', 'attribute']:
57 | load_model(
58 | model=self.models[type],
59 | model_f=model_dir+'/%s_best_epoch.pt' % type,
60 | device=self.device
61 | )
62 |
63 | def _init_models(self):
64 | name_clf = MLPClassifier(
65 | input_dim=self.feat_dim,
66 | output_dim=self.n_names,
67 | latent_dim=1024,
68 | n_layers=2,
69 | dropout_rate=0.3
70 | )
71 |
72 | rela_clf = MLPClassifier(
73 | input_dim=(self.feat_dim+4)*2, # 4: bbox
74 | output_dim=self.n_rels+1, # 1: None
75 | latent_dim=1024,
76 | n_layers=1,
77 | dropout_rate=0.5
78 | )
79 |
80 | attr_clf = MLPClassifier(
81 | input_dim=self.feat_dim,
82 | output_dim=self.n_attrs,
83 | latent_dim=1024,
84 | n_layers=1,
85 | dropout_rate=0.3
86 | )
87 |
88 | self.models = {
89 | 'name': name_clf,
90 | 'attribute': attr_clf,
91 | 'relation': rela_clf
92 | }
93 |
94 | def predict(self, type, inputs):
95 | # type == 'name', inputs == (obj_feat_np_array)
96 | # type == 'relation', inputs == (sub_feat_np_array, obj_feat_np_array, sub_bbox_np_array, obj_bbox_np_array)
97 | # type == 'attribute', inputs == (obj_feat_np_array)
98 |
99 | model = self.models[type].to(self.device)
100 | inputs = torch.cat([torch.from_numpy(x).float() for x in inputs]).reshape(len(inputs), -1).to(self.device)
101 | logits = model(inputs)
102 |
103 | if type == 'attribute':
104 | probs = torch.sigmoid(logits)
105 | else:
106 | probs = F.softmax(logits, dim=1)
107 |
108 | return logits, probs
109 |
110 | def batch_predict(self, type, inputs, batch_split):
111 |
112 | model = self.models[type].to(self.device)
113 | inputs = torch.cat([torch.from_numpy(x).float() for x in inputs]).reshape(len(inputs), -1).to(self.device)
114 | logits = model(inputs)
115 |
116 | if type == 'attribute':
117 | probs = torch.sigmoid(logits)
118 | else:
119 | current_split = 0
120 | probs = []
121 | for split in batch_split:
122 | current_logits = logits[current_split:split]
123 | # batched_logits = logits.reshape(batch_shape[0], batch_shape[1], -1)
124 | current_probs = F.softmax(current_logits, dim=1)
125 | # probs = probs.reshape(inputs.shape[0], -1)
126 | probs.append(current_probs)
127 | current_split = split
128 |
129 | probs = torch.cat(probs).reshape(inputs.shape[0], -1)
130 | return logits, probs
131 |
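A hedged usage sketch of the scene-graph model above (feature vectors are random placeholders; the class counts mirror models.py, i.e. 500 names, 609 attributes and 228 relations plus one None class, and are normally read from gqa_info.json):

    import numpy as np
    import torch
    from sg_model import SceneGraphModel

    sg = SceneGraphModel(feat_dim=2048, n_names=500, n_attrs=609, n_rels=228,
                         device=torch.device('cpu'))   # no model_dir: randomly initialised

    obj_feat = np.random.rand(2048).astype(np.float32)
    logits, probs = sg.predict('name', (obj_feat, obj_feat))   # two objects in one call
    print(probs.shape)   # torch.Size([2, 500])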
--------------------------------------------------------------------------------
/src/experiments/vqa/test.py:
--------------------------------------------------------------------------------
1 | print("start importing...")
2 |
3 | import time
4 | import sys
5 | import argparse
6 | import datetime
7 |
8 | sys.path.append('../../')
9 | sys.path.append('../../SLASH/')
10 | sys.path.append('../../EinsumNetworks/src/')
11 |
12 |
13 | #torch, numpy, ...
14 | import torch
15 | from torch.utils.tensorboard import SummaryWriter
16 | from torchvision.transforms import transforms
17 | import torchvision
18 |
19 | import numpy as np
20 |
21 | import json
22 |
23 | #own modules
24 | from dataGen import VQA
25 | from einsum_wrapper import EiNet
26 | from network_nn import Net_nn
27 |
28 | from tqdm import tqdm
29 |
30 | #import slash
31 | from slash import SLASH
32 | import os
33 |
34 |
35 |
36 |
37 | import utils
38 | from utils import set_manual_seed
39 | from pathlib import Path
40 | from rtpt import RTPT
41 | import pickle
42 |
43 |
44 | from knowledge_graph import RULES, KG
45 | from dataGen import name_npp, relation_npp, attribute_npp
46 | from models import name_clf, rela_clf, attr_clf
47 |
48 | print("...done")
49 |
50 |
51 |
52 |
53 | def get_args():
54 | parser = argparse.ArgumentParser()
55 |
56 | parser.add_argument(
57 | "--seed", type=int, default=10, help="Random generator seed for all frameworks"
58 | )
59 |
60 | parser.add_argument(
61 | "--network-type",
62 | choices=["nn","pc"],
63 |         help="The type of model used for the neural-probabilistic predicates, e.g. neural net or probabilistic circuit",
64 | )
65 | parser.add_argument(
66 | "--pc-structure",
67 | choices=["poon-domingos","binary-trees"],
68 |         help="The structure of the probabilistic circuit, e.g. poon-domingos or binary-trees",
69 | )
70 | parser.add_argument(
71 | "--batch-size", type=int, default=100, help="Batch size to train with"
72 | )
73 | parser.add_argument(
74 | "--num-workers", type=int, default=6, help="Number of threads for data loader"
75 | )
76 |
77 | parser.add_argument(
78 | "--p-num", type=int, default=8, help="Number of processes to devide the batch for parallel processing"
79 | )
80 |
81 | parser.add_argument("--credentials", type=str, help="Credentials for rtpt")
82 |
83 |
84 | args = parser.parse_args()
85 |
86 | if args.network_type == 'pc':
87 | args.use_pc = True
88 | else:
89 | args.use_pc = False
90 |
91 | return args
92 |
93 |
94 | def determine_max_objects(task_file):
95 |
96 | with open (task_file, 'rb') as tf:
97 | tasks = pickle.load(tf)
98 | print("taskfile len",len(tasks))
99 |
100 | #get the biggest number of objects in the image
101 | max_objects = 0
102 | for tidx, task in enumerate(tasks):
103 | all_oid = task['question']['input']
104 | len_all_oid = len(all_oid)
105 |
106 | #store the biggest number of objects in an image
107 | if len_all_oid > max_objects:
108 | max_objects = len_all_oid
109 | return max_objects
110 |
111 | def slash_vqa():
112 |
113 | args = get_args()
114 | print(args)
115 |
116 |
117 | # Set the seeds for PRNG
118 | set_manual_seed(args.seed)
119 |
120 | # Create RTPT object
121 | rtpt = RTPT(name_initials=args.credentials, experiment_name='SLASH VQA', max_iterations=1)
122 |
123 | # Start the RTPT tracking
124 | rtpt.start()
125 | #writer = SummaryWriter(os.path.join("runs","vqa", str(args.seed)), purge_step=0)
126 |
127 | #exp_name = 'vqa3'
128 | #Path("data/"+exp_name+"/").mkdir(parents=True, exist_ok=True)
129 | #saveModelPath = 'data/'+exp_name+'/slash_vqa_models_seed'+str(args.seed)+'.pt'
130 |
131 |
132 | #TODO workaround that adds +- notation
133 | program_example = """
134 | %scallop conversion rules
135 | name(O,N) :- name(0,+O,-N).
136 | attr(O,A) :- attr(0, +O, -A).
137 | relation(O1,O2,N) :- relation(0, +(O1,O2), -N).
138 | """
139 |
140 | #test_f = "dataset/task_list/test_tasks_c3_1000.pkl" # Test datset
141 |
142 | test_f = {"c2":"dataset/task_list/test_tasks_c2_1000.pkl",
143 | "c3":"dataset/task_list/test_tasks_c3_1000.pkl",
144 | "c4":"dataset/task_list/test_tasks_c4_1000.pkl",
145 | "c5":"dataset/task_list/test_tasks_c5_1000.pkl",
146 | "c6":"dataset/task_list/test_tasks_c6_1000.pkl"
147 | }
148 |
149 |
150 | num_obj = []
151 | if type(test_f) == str:
152 | num_obj.append(determine_max_objects(test_f))
153 | #if we have multiple test files
154 | elif type(test_f) == dict:
155 | for key in test_f:
156 | num_obj.append(determine_max_objects(test_f[key]))
157 |
158 |
159 | NUM_OBJECTS = np.max(num_obj)
160 |     NUM_OBJECTS = 70 #override with a fixed upper bound on the number of objects
161 |
162 |
163 | vqa_params = {"l":200,
164 | "l_split":100,
165 | "num_names":500,
166 | "max_models":10000,
167 | "asp_timeout": 60}
168 |
169 |
170 |
171 | #load models #data/vqa18_10/slash_vqa_models_seed0_epoch_0.pt
172 | #src/experiments/vqa/
173 | #saved_models = torch.load("data/test/slash_vqa_models_seed42_epoch_9.pt")
174 | saved_models = torch.load("data/vqa_debug_relations_17_04_2023/slash_vqa_models_seed0_epoch_2.pt")
175 |
176 | print(saved_models.keys())
177 | rela_clf.load_state_dict(saved_models['relation_clf'])
178 | name_clf.load_state_dict(saved_models['name_clf'])
179 | attr_clf.load_state_dict(saved_models['attr_clf'])
180 |
181 | #create the SLASH Program , ,
182 | nnMapping = {'relation': rela_clf, 'name':name_clf , "attr":attr_clf}
183 | optimizers = {'relation': torch.optim.Adam(rela_clf.parameters(), lr=0.001, eps=1e-7),
184 | 'name': torch.optim.Adam(name_clf.parameters(), lr=0.001, eps=1e-7),
185 | 'attr': torch.optim.Adam(attr_clf.parameters(), lr=0.001, eps=1e-7)}
186 |
187 |
188 |
189 | all_oid = np.arange(0,NUM_OBJECTS)
190 | object_string = "".join([ f"object({oid1},{oid2}). " for oid1 in all_oid for oid2 in all_oid if oid1 != oid2])
191 | object_string = "".join(["".join([f"object({oid}). " for oid in all_oid]), object_string])
192 |
193 | #parse the SLASH program
194 | print("create SLASH program")
195 | program = "".join([KG, RULES, object_string, name_npp, relation_npp, attribute_npp, program_example])
196 | SLASHobj = SLASH(program, nnMapping, optimizers)
197 |
198 | #load the data
199 | if type(test_f) == str:
200 | test_data = VQA("test", test_f, NUM_OBJECTS)
201 | test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False, num_workers=4)
202 | #if we have multiple test files
203 | elif type(test_f) == dict:
204 | for key in test_f:
205 | test_data = VQA("test", test_f[key], NUM_OBJECTS)
206 | test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False, num_workers=4)
207 | test_f[key] = test_loader
208 |
209 |
210 |
211 | print("---TEST---")
212 | if type(test_f) == str:
213 | recall_5_test, test_time = SLASHobj.testVQA(test_loader, args.p_num, vqa_params=vqa_params)
214 | print("test-recall@5", recall_5_test)
215 |
216 | elif type(test_f) == dict:
217 | test_time = 0
218 | recalls = []
219 | for key in test_f:
220 | recall_5_test, tt = SLASHobj.testVQA(test_f[key], args.p_num, vqa_params=vqa_params)
221 | test_time += tt
222 | recalls.append(recall_5_test)
223 | print("test-recall@5_{}".format(key), recall_5_test, ", test_time:", tt )
224 | print("test-recall@5_c_all", np.mean(recalls), ", test_time:", test_time)
225 |
226 |
227 | if __name__ == "__main__":
228 | slash_vqa()
229 |
--------------------------------------------------------------------------------
/src/experiments/vqa/test.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | DEVICE=$1
4 | SEED=$2 # 0, 1, 2, 3, 4
5 | CREDENTIALS=$3
6 |
7 |
8 |
9 | #-------------------------------------------------------------------------------#
10 | # Evaluate the trained VQA model with the nn (MLP) networks
11 |
12 | CUDA_VISIBLE_DEVICES=$DEVICE python3 test.py \
13 | --seed $SEED \
14 | --network-type nn --batch-size 100 \
15 | --num-workers 0 --p-num 16 --credentials $CREDENTIALS
16 |
17 |
--------------------------------------------------------------------------------
/src/experiments/vqa/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | DEVICE=$1
4 | SEED=$2 # 0, 1, 2, 3, 4
5 | CREDENTIALS=$3
6 |
7 |
8 | # 0.00001, 0.0001, 0.001, 0.01
9 | LR=0.001
10 |
11 | #-------------------------------------------------------------------------------#
12 |
13 | CUDA_VISIBLE_DEVICES=$DEVICE python3 train.py \
14 | --epochs 100 \
15 | --batch-size 100 --seed $SEED \
16 | --network-type nn --lr $LR \
17 | --num-workers 0 --p-num 20 --credentials $CREDENTIALS \
18 | --exp-name vqa_c2
19 |
--------------------------------------------------------------------------------
/src/experiments/vqa/trainer.py:
--------------------------------------------------------------------------------
1 | """
2 | The source code is based on:
3 | Scallop: From Probabilistic Deductive Databases to Scalable Differentiable Reasoning
4 | Jiani Huang, Ziyang Li, Binghong Chen, Karan Samel, Mayur Naik, Le Song, Xujie Si
5 | Advances in Neural Information Processing Systems 34 (NeurIPS 2021)
6 | https://proceedings.neurips.cc/paper/2021/hash/d367eef13f90793bd8121e2f675f0dc2-Abstract.html
7 | """
8 |
9 | import os
10 | import json
11 | import numpy as np
12 | import sys
13 | from tqdm import tqdm
14 | import torch
15 | import torch.optim as optim
16 | from torch.optim.lr_scheduler import StepLR
17 | from torch.nn import BCELoss
18 | import time
19 | import math
20 | import statistics
21 | # import concurrent.futures
22 |
23 | supervised_learning_path = os.path.abspath(os.path.join(
24 | os.path.abspath(__file__), "../../supervised_learning"))
25 | sys.path.insert(0, supervised_learning_path)
26 |
27 | common_path = os.path.abspath(os.path.join(os.path.abspath(__file__), "../.."))
28 | sys.path.insert(0, common_path)
29 |
30 | from query_lib import QueryManager
31 | #from src.experiments.vqa.cmd_args2 import cmd_args
32 | from word_idx_translator import Idx2Word
33 | from sg_model import SceneGraphModel
34 | from learning import get_fact_probs
35 | from vqa_utils import auc_score, get_recall
36 |
37 | def prog_analysis(datapoint):
38 | knowledge_op = ['Hypernym_Find', 'KG_Find']
39 | rela_op = ['Relate', 'Relate_Reverse']
40 | clauses = datapoint['question']['clauses']
41 | ops = [ clause['function'] for clause in clauses ]
42 | kg_num = sum([1 if op in knowledge_op else 0 for op in ops])
43 | rela_num = sum([1 if op in rela_op else 0 for op in ops])
44 | return (kg_num, rela_num)
45 |
46 | class ClauseNTrainer():
47 |
48 | def __init__(self,
49 | train_data_loader,
50 | val_data_loader,
51 | test_data_loader=None,
52 | #model_dir=cmd_args.model_dir, n_epochs=cmd_args.n_epochs,
53 | #save_dir=cmd_args.save_dir,
54 | #meta_f=cmd_args.meta_f, knowledge_base_dir=cmd_args.knowledge_base_dir,
55 | #axiom_update_size=cmd_args.axiom_update_size):
56 | model_dir="data_model/", n_epochs=2,
57 | save_dir="data_save/",
58 | meta_f="dataset/gqa_info.json", knowledge_base_dir="",
59 | axiom_update_size=""):
60 |
61 |
62 | self.model_dir = model_dir
63 | model_exists = self._model_exists(model_dir)
64 |
65 | if not model_exists:
66 | load_model_dir = None
67 | else:
68 | load_model_dir = model_dir
69 |
70 | if train_data_loader is not None:
71 | self.train_data = train_data_loader
72 | if val_data_loader is not None:
73 | self.val_data = val_data_loader
74 | if test_data_loader is not None:
75 | self.val_data = test_data_loader
76 |
77 | self.is_train = test_data_loader is None
78 | meta_info = json.load(open(meta_f, 'r'))
79 | self.idx2word = Idx2Word(meta_info)
80 | self.n_epochs = n_epochs
81 | self.query_manager = QueryManager(save_dir)
82 | self.axiom_update_size = axiom_update_size
83 | self.wmc_funcs = {}
84 |
85 | # load dictionary from previous training results
86 | self.sg_model = SceneGraphModel(
87 | feat_dim=64,
88 | n_names=meta_info['name']['num'],
89 | n_attrs=meta_info['attr']['num'],
90 | n_rels=meta_info['rel']['num'],
91 | device=torch.device('cuda'),
92 | model_dir=load_model_dir
93 | )
94 |
95 | self.sg_model_dict = self.sg_model.models
96 |
97 | self.loss_func = BCELoss()
98 |
99 | if self.is_train:
100 | self.optimizers = {}
101 | self.schedulers = {}
102 |
103 | for model_type, model_info in self.sg_model_dict.items():
104 | if model_type == 'name':
105 | self.optimizers['name'] = optim.Adam(
106 | model_info.parameters(), lr=0.01)
107 | self.schedulers['name'] = StepLR(
108 | self.optimizers['name'], step_size=10, gamma=0.1)
109 | # self.loss_func['name'] = F.cross_entropy
110 | if model_type == 'relation':
111 | self.optimizers['rel'] = optim.Adam(
112 | model_info.parameters(), lr=0.01)
113 | self.schedulers['rel'] = StepLR(
114 | self.optimizers['rel'], step_size=10, gamma=0.1)
115 | if model_type == 'attribute':
116 | self.optimizers['attr'] = optim.Adam(
117 | model_info.parameters(), lr=0.01)
118 | self.schedulers['attr'] = StepLR(
119 | self.optimizers['attr'], step_size=10, gamma=0.1)
120 |
121 | # self.pool = mp.Pool(cmd_args.max_workers)
122 | # self.batch = cmd_args.trainer_batch
123 |
124 | def _model_exists(self, model_dir):
125 | name_f = os.path.join(self.model_dir, 'name_best_epoch.pt')
126 | rela_f = os.path.join(self.model_dir, 'relation_best_epoch.pt')
127 | attr_f = os.path.join(self.model_dir, 'attribute_best_epoch.pt')
128 |
129 | if not os.path.exists(name_f):
130 | return False
131 | if not os.path.exists(rela_f):
132 | return False
133 | if not os.path.exists(attr_f):
134 | return False
135 | return True
136 |
137 | def _get_optimizer(self, data_type):
138 | optimizer = self.optimizers[data_type]
139 | return optimizer
140 |
141 | def _step_all(self):
142 | for optim_type, optim_info in self.optimizers.items():
143 | optim_info.step()
144 |
145 | def _step_scheduler(self):
146 | for scheduler_type, scheduler_info in self.schedulers.items():
147 | scheduler_info.step()
148 |
149 | def _zero_all(self):
150 | for optim_type, optim_info in self.optimizers.items():
151 | optim_info.zero_grad()
152 |
153 | def _train_all(self):
154 | for model_type, model_info in self.sg_model_dict.items():
155 | if model_type == 'name':
156 | model_info.train()
157 | model_info.share_memory()
158 | if model_type == 'relation':
159 | model_info.train()
160 | model_info.share_memory()
161 | if model_type == 'attribute':
162 | model_info.train()
163 | model_info.share_memory()
164 |
165 | def _eval_all(self):
166 | for model_type, model_info in self.sg_model_dict.items():
167 | if model_type == 'name':
168 | model_info.eval()
169 | model_info.share_memory()
170 | if model_type == 'relation':
171 | model_info.eval()
172 | model_info.share_memory()
173 | if model_type == 'attribute':
174 | model_info.eval()
175 | model_info.share_memory()
176 |
177 | def _save_all(self):
178 | for model_type, model_info in self.sg_model_dict.items():
179 | if model_type == 'name':
180 | save_f = os.path.join(self.model_dir, 'name_best_epoch.pt')
181 | torch.save(model_info.state_dict(), save_f)
182 | if model_type == 'relation':
183 | save_f = os.path.join(self.model_dir, 'relation_best_epoch.pt')
184 | torch.save(model_info.state_dict(), save_f)
185 | if model_type == 'attribute':
186 | save_f = os.path.join(self.model_dir, 'attribute_best_epoch.pt')
187 | torch.save(model_info.state_dict(), save_f)
188 |
189 | def loss_acc(self, targets, correct, all_oids, is_train=True):
190 |
191 | pred = []
192 | for oid in all_oids:
193 | if oid in targets:
194 | pred.append(targets[oid])
195 | else:
196 | pred.append(0)
197 |
198 | labels = [1 if obj in correct else 0 for obj in all_oids]
199 |
200 | labels_tensor = torch.tensor(labels, dtype=torch.float32)
201 | pred_tensor = torch.tensor(pred, dtype=torch.float32)
202 |
203 | pred_tensor = pred_tensor.reshape(1, -1)
204 | labels_tensor = labels_tensor.reshape(1, -1)
205 |
206 | loss = self.loss_func(pred_tensor, labels_tensor)
207 | recall = get_recall(labels_tensor, pred_tensor)
208 |
209 | if math.isnan(recall):
210 | recall = -1
211 |
212 | return loss.item(), recall
213 |
214 | def _pass(self, datapoint, is_train=True):
215 |
216 | correct = datapoint['question']['output']
217 | all_oid = datapoint['question']['input']
218 | fact_tps, fact_probs = get_fact_probs(self.sg_model, datapoint, self.idx2word, self.query_manager)
219 | result, timeout = self.query_manager.get_result(datapoint, fact_tps, fact_probs)
220 | if not timeout:
221 | loss, acc = self.loss_acc(result, correct, all_oid)
222 | else:
223 | loss = -1
224 | acc = -1
225 |
226 | return acc, loss, timeout
227 |
228 | def _train_epoch(self, ct):
229 |
230 | self._train_all()
231 | aucs = []
232 | losses = []
233 | timeouts = 0
234 | pbar = tqdm(self.train_data)
235 |
236 | for datapoint in pbar:
237 | auc, loss, timeout = self._pass(datapoint, is_train=True)
238 | if not timeout:
239 | if auc >= 0:
240 | aucs.append(auc)
241 | losses.append(loss)
242 | else:
243 | timeouts += 1
244 |
245 | pbar.set_description(
246 | f'[loss: {np.array(losses).mean()}, auc: {np.array(aucs).mean()}, timeouts: {timeouts}]')
247 | torch.cuda.empty_cache()
248 |
249 | self._step_all()
250 | self._zero_all()
251 |
252 | return np.mean(losses), np.mean(aucs)
253 |
254 | def _val_epoch(self):
255 | self._eval_all()
256 |
257 | timeouts = 0
258 | aucs = []
259 | losses = []
260 | time_out_prog_kg = {}
261 | time_out_prog_rela = {}
262 | success_prog_kg = {}
263 | success_prog_rela = {}
264 |
265 | pbar = tqdm(self.val_data)
266 | with torch.no_grad():
267 | for datapoint in pbar:
268 | kg_num, rela_num = prog_analysis(datapoint)
269 |
270 | auc, loss, timeout = self._pass(datapoint, is_train=False)
271 | if not timeout:
272 | aucs.append(auc)
273 | losses.append(loss)
274 | if not kg_num in success_prog_kg:
275 | success_prog_kg[kg_num] = 0
276 | if not rela_num in success_prog_rela:
277 | success_prog_rela[rela_num] = 0
278 | success_prog_kg[kg_num] += 1
279 | success_prog_rela[rela_num] += 1
280 |
281 | else:
282 | timeouts += 1
283 | if not kg_num in time_out_prog_kg:
284 | time_out_prog_kg[kg_num] = 0
285 | if not rela_num in time_out_prog_rela:
286 | time_out_prog_rela[rela_num] = 0
287 | time_out_prog_kg[kg_num] += 1
288 | time_out_prog_rela[rela_num] += 1
289 |
290 | if not len(aucs) == 0:
291 | pbar.set_description(
292 | f'[loss: {np.array(losses).mean()}, auc: {np.array(aucs).mean()}, timeouts: {timeouts}]')
293 |
294 | print(f"succ kg: {success_prog_kg}, succ rela: {success_prog_rela}")
295 | print(f"timeout kg: {time_out_prog_kg}, timeout rela: {time_out_prog_rela}")
296 | return np.mean(losses), np.mean(aucs)
297 |
298 |
299 | def train(self):
300 | assert self.is_train
301 | best_val_loss = np.inf
302 |
303 | for epoch in range(self.n_epochs):
304 | train_loss, train_acc = self._train_epoch(epoch)
305 | val_loss, val_acc = self._val_epoch()
306 | self._step_scheduler()
307 |
308 | print(
309 | '[Epoch %d/%d] [training loss: %.2f, auc: %.2f] [validation loss: %.2f, auc: %.2f]' %
310 | (epoch, self.n_epochs, train_loss, train_acc, val_loss, val_acc)
311 | )
312 |
313 | if val_loss < best_val_loss:
314 | best_val_loss = val_loss
315 | print('saving best models')
316 | self._save_all()
317 |
318 | def test(self):
319 | assert not self.is_train
320 | test_loss, test_acc = self._val_epoch()
321 | print('[test loss: %.2f, acc: %.2f]' % (test_loss, test_acc))
322 |
--------------------------------------------------------------------------------
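Note: a minimal sketch of the label/score pairing performed by loss_acc in the trainer above; the object ids and probabilities are made up for illustration, and only the BCE step is reproduced.

    import torch
    from torch.nn import BCELoss

    # hypothetical query result: object id -> predicted probability of being an answer
    targets = {"obj1": 0.9, "obj3": 0.2}
    correct = ["obj1"]                     # ground-truth answer objects for the question
    all_oids = ["obj1", "obj2", "obj3"]    # all candidate objects of the question

    pred = torch.tensor([[targets.get(o, 0.0) for o in all_oids]])             # [[0.9, 0.0, 0.2]]
    labels = torch.tensor([[1.0 if o in correct else 0.0 for o in all_oids]])  # [[1.0, 0.0, 0.0]]

    loss = BCELoss()(pred, labels)  # binary cross-entropy over all candidate objects
    print(loss.item())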
/src/experiments/vqa/vqa_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | The source code is based on:
3 | Scallop: From Probabilistic Deductive Databases to Scalable Differentiable Reasoning
4 | Jiani Huang, Ziyang Li, Binghong Chen, Karan Samel, Mayur Naik, Le Song, Xujie Si
5 | Advances in Neural Information Processing Systems 34 (NeurIPS 2021)
6 | https://proceedings.neurips.cc/paper/2021/hash/d367eef13f90793bd8121e2f675f0dc2-Abstract.html
7 | """
8 |
9 | import os
10 | import json
11 | import argparse
12 | import numpy as np
13 | import torch
14 | import random
15 | import pickle
16 | from sklearn import metrics
17 |
18 | TIME_OUT = 120
19 |
20 | ###########################################################################
21 | # basic utilities
22 |
23 |
24 | def to_image_id_dict(dict_list):
25 | ii_dict = {}
26 | for dc in dict_list:
27 | image_id = dc['image_id']
28 | ii_dict[image_id] = dc
29 | return ii_dict
30 |
31 |
32 | articles = ['a', 'an', 'the', 'some', 'it']
33 |
34 |
35 | def remove_article(st):
36 | st = st.split()
37 | st = [word for word in st if word not in articles]
38 | return " ".join(st)
39 |
40 | ################################################################################################################
41 | # AUC calculation
42 |
43 | def get_recall(labels, logits, topk=2):
44 | #a = torch.tensor([[0, 1, 1], [1, 1, 0], [0, 0, 1],[1, 1, 1]])
45 | #b = torch.tensor([[0, 0, 0], [0, 1, 0], [0, 0, 1],[0, 1, 1]])
46 |
47 |     # Calculate recall@k (a toy example follows after this file)
48 |     _, pred = logits.topk(topk, 1, True, True) # indices of the k highest-scoring classes per row
49 | print(pred)
50 |
51 | pred = pred.t() #transpose
52 |
53 |     # gather reads the label entries at the positions of our top-k predictions;
54 |     # summing along the row then counts how many predicted top-k classes are true positives
55 | correct = torch.sum(labels.gather(1, pred.t()), dim=1)
56 | print("correct", correct)
57 |
58 |     # sum up all true labels per row and clamp to at most top-k
59 | correct_label = torch.clamp(torch.sum(labels, dim = 1), 0, topk)
60 |
61 |     # recall per row: correctly retrieved positives divided by the (clamped) number of positives
62 | print("correct label", correct_label)
63 |
64 | accuracy = torch.mean(correct / correct_label).item()
65 |
66 | return accuracy
67 |
68 | def single_auc_score(label, pred):
69 | label = label
70 | pred = [p.item() if not (type(p) == float or type(p) == int)
71 | else p for p in pred]
72 | if len(set(label)) == 1:
73 | auc = -1
74 | else:
75 | fpr, tpr, thresholds = metrics.roc_curve(label, pred, pos_label=1)
76 | auc = metrics.auc(fpr, tpr)
77 | return auc
78 |
79 |
80 | def auc_score(labels, preds):
81 | if type(labels) == torch.Tensor:
82 | labels = labels.long().cpu().detach().numpy()
83 | if type(preds) == torch.Tensor:
84 | preds = preds.cpu().detach().numpy()
85 |
86 | if (type(labels) == torch.Tensor or type(labels) == np.ndarray) and len(labels.shape) == 2:
87 | aucs = []
88 | for label, pred in zip(labels, preds):
89 | auc_single = single_auc_score(label, pred)
90 | if not auc_single < 0:
91 | aucs.append(auc_single)
92 | auc = np.array(aucs).mean()
93 | else:
94 | auc = single_auc_score(labels, preds)
95 |
96 | return auc
97 |
98 |
99 | def to_binary_labels(labels, attr_num):
100 | if labels.shape[-1] == attr_num:
101 | return labels
102 |
103 | binary_labels = torch.zeros((labels.shape[0], attr_num))
104 | for ct, label in enumerate(labels):
105 | binary_labels[ct][label] = 1
106 | return binary_labels
107 |
108 |
109 | ##############################################################################
110 | # model loading
111 | def get_default_args():
112 |
113 | DATA_ROOT = os.path.abspath(os.path.join(
114 | os.path.abspath(__file__), "../../data"))
115 |
116 | parser = argparse.ArgumentParser()
117 | parser.add_argument('--gpu', type=int, default=2)
118 | parser.add_argument('--seed', type=int, default=1234)
119 | parser.add_argument('--feat_dim', type=int, default=2048)
120 | # parser.add_argument('--type', default='name')
121 | parser.add_argument('--model_dir', default=DATA_ROOT + '/model_ckpts')
122 | parser.add_argument('--meta_f', default=DATA_ROOT + '/preprocessing/gqa_info.json')
123 | args = parser.parse_args()
124 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
125 |
126 | np.random.seed(args.seed)
127 | torch.manual_seed(args.seed)
128 | random.seed(args.seed)
129 |
130 | return args
131 |
132 |
133 | def gen_task(formatted_scene_graph_path, vg_image_data_path, questions_path, features_path, n=100, q_length=None):
134 |
135 | with open(questions_path, 'rb') as questions_file:
136 | questions = pickle.load(questions_file)
137 |
138 | with open(formatted_scene_graph_path, 'rb') as vg_scene_graphs_file:
139 | scene_graphs = pickle.load(vg_scene_graphs_file)
140 |
141 | with open(vg_image_data_path, 'r') as vg_image_data_file:
142 | image_data = json.load(vg_image_data_file)
143 |
144 | features = np.load(features_path, allow_pickle=True)
145 | features = features.item()
146 |
147 | image_data = to_image_id_dict(image_data)
148 |
149 | for question in questions[:n]:
150 |
151 | if q_length is not None:
152 | current_len = len(question['clauses'])
153 | if not current_len == q_length:
154 | continue
155 |
156 | # functions = question["clauses"]
157 | image_id = question["image_id"]
158 | scene_graph = scene_graphs[image_id]
159 | cur_image_data = image_data[image_id]
160 |
161 | if not scene_graph['image_id'] == image_id or not cur_image_data['image_id'] == image_id:
162 | raise Exception("Mismatched image id")
163 |
164 | info = {}
165 | info['image_id'] = image_id
166 | info['scene_graph'] = scene_graph
167 | info['url'] = cur_image_data['url']
168 | info['question'] = question
169 | info['object_feature'] = []
170 | info['object_ids'] = []
171 |
172 | for obj_id in scene_graph['names'].keys():
173 | info['object_ids'].append(obj_id)
174 | info['object_feature'].append(features.get(obj_id))
175 |
176 |         yield info
177 |
178 |
--------------------------------------------------------------------------------
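Note: a toy check of get_recall from vqa_utils.py above; the tensors are illustrative and the import assumes the script is run from this directory.

    import torch
    from vqa_utils import get_recall  # assumes this directory is on the path

    labels = torch.tensor([[0., 1., 1., 0.],
                           [1., 0., 0., 1.]])
    logits = torch.tensor([[0.1, 0.9, 0.8, 0.2],
                           [0.7, 0.6, 0.1, 0.3]])

    # row 0: both positives are in the top-2 (2/2); row 1: only one of two (1/2)
    print(get_recall(labels, logits, topk=2))  # 0.75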
/src/experiments/vqa/word_idx_translator.py:
--------------------------------------------------------------------------------
1 | """
2 | The source code is based on:
3 | Scallop: From Probabilistic Deductive Databases to Scalable Differentiable Reasoning
4 | Jiani Huang, Ziyang Li, Binghong Chen, Karan Samel, Mayur Naik, Le Song, Xujie Si
5 | Advances in Neural Information Processing Systems 34 (NeurIPS 2021)
6 | https://proceedings.neurips.cc/paper/2021/hash/d367eef13f90793bd8121e2f675f0dc2-Abstract.html
7 | """
8 | import json
9 |
10 |
11 | class Idx2Word():
12 |
13 | def __init__(self, meta_info, use_canon=False):
14 | self.setup(meta_info)
15 | self.attr_canon = meta_info['attr']['canon']
16 | self.name_canon = meta_info['name']['canon']
17 | self.rela_canon = meta_info['rel']['canon']
18 |
19 | self.attr_alias = meta_info['attr']['alias']
20 | self.rela_alias = meta_info['rel']['alias']
21 |
22 | self.attr_to_idx_dict = meta_info['attr']['idx']
23 | self.rela_to_idx_dict = meta_info['rel']['idx']
24 | self.name_to_idx_dict = meta_info['name']['idx']
25 | self.use_canon = use_canon
26 | # print("here")
27 |
28 | def setup(self, meta_info):
29 |
30 | attr_to_idx = meta_info['attr']['idx']
31 | rela_to_idx = meta_info['rel']['idx']
32 | name_to_idx = meta_info['name']['idx']
33 |
34 | attr_freq = meta_info['attr']['freq']
35 | rela_freq = meta_info['rel']['freq']
36 | name_freq = meta_info['name']['freq']
37 |
38 | # attr_group = meta_info['attr']['group']
39 |
40 | def setup_single(to_idx, freq, group=None):
41 | idx_to_name = {}
42 | for name in freq:
43 | if name not in to_idx:
44 | continue
45 | idx = to_idx[name]
46 | if type(idx) == list:
47 | if not idx[0] in idx_to_name.keys():
48 | idx_to_name[idx[0]] = {}
49 | idx_to_name[idx[0]][idx[1]] = name
50 | else:
51 | idx_to_name[idx] = name
52 | return idx_to_name
53 |
54 | self.idx_to_name_dict = setup_single(name_to_idx, name_freq)
55 | self.idx_to_rela_dict = setup_single(rela_to_idx, rela_freq)
56 | self.idx_to_attr_dict = setup_single(attr_to_idx, attr_freq)
57 | # self.idx_to_attr_dict = setup_single(attr_to_idx, attr_freq, attr_group)
58 |
59 | def get_name_ct(self):
60 | return len(self.idx_to_name_dict)
61 |
62 | def get_rela_ct(self):
63 | return len(self.idx_to_rela_dict)
64 |
65 | def get_attr_ct(self):
66 | return len(self.idx_to_attr_dict)
67 |
68 | def get_names(self):
69 | return list(self.idx_to_name_dict.values())
70 |
71 | def idx_to_name(self, idx):
72 | if idx is None:
73 | return None
74 | if type(idx) == str:
75 | return idx
76 | if len(self.idx_to_name_dict) == idx:
77 | return None
78 | if idx == -1:
79 | return None
80 | return self.idx_to_name_dict[idx]
81 |
82 | def idx_to_rela(self, idx):
83 | if idx is None:
84 | return None
85 | if idx == -1:
86 | return None
87 | if type(idx) == str:
88 | return idx
89 | if len(self.idx_to_rela_dict) == idx:
90 | return None
91 | return self.idx_to_rela_dict[idx]
92 |
93 | def idx_to_attr(self, idx):
94 | if idx is None:
95 | return None
96 | if type(idx) == str:
97 | return idx
98 | if len(self.idx_to_attr_dict) == idx:
99 | return None
100 | if idx == -1:
101 | return None
102 | # return self.idx_to_attr_dict[idx[0]][idx[1]]
103 | return self.idx_to_attr_dict[idx]
104 |
105 | def attr_to_idx(self, attr):
106 | if attr is None:
107 | return attr
108 |
109 | if self.use_canon:
110 | if attr in self.attr_canon.keys():
111 | attr = self.attr_canon[attr]
112 |
113 | if attr in self.attr_alias.keys():
114 | attr = self.attr_alias[attr]
115 |
116 | if attr not in self.attr_to_idx_dict.keys():
117 | return None
118 |
119 | return self.attr_to_idx_dict[attr]
120 |
121 | def name_to_idx(self, name):
122 |
123 | if name is None:
124 | return name
125 |
126 | if self.use_canon:
127 | if name in self.name_canon.keys():
128 | name = self.name_canon[name]
129 |
130 | if name not in self.name_to_idx_dict.keys():
131 | return None
132 |
133 | return self.name_to_idx_dict[name]
134 |
135 | def rela_to_idx(self, rela):
136 | if rela is None:
137 | return rela
138 |
139 | if self.use_canon:
140 | if rela in self.rela_canon.keys():
141 | rela = self.rela_canon[rela]
142 |
143 | if rela in self.rela_alias.keys():
144 | rela = self.rela_alias[rela]
145 |
146 | if rela not in self.rela_to_idx_dict.keys():
147 | return None
148 |
149 | return self.rela_to_idx_dict[rela]
150 |
151 |
152 | def process_program(program, meta_info):
153 |
154 | new_program = []
155 |
156 | for clause in program:
157 | new_clause = {}
158 | new_clause['function'] = clause['function']
159 |
160 | if 'output' in clause.keys():
161 | new_clause['output'] = clause['output']
162 |
163 | if clause['function'] == "Hypernym_Find":
164 | name = clause['text_input'][0]
165 | attr = clause['text_input'][1]
166 |
167 | new_clause['text_input'] = [name, attr] + clause['text_input'][2:]
168 |
169 | elif clause['function'] == "Find":
170 | name = clause['text_input'][0]
171 | new_clause['text_input'] = [name] + clause['text_input'][1:]
172 |
173 | elif clause['function'] == "Relate_Reverse":
174 | relation = clause['text_input']
175 | new_clause['text_input'] = relation
176 |
177 | elif clause['function'] == "Relate":
178 | relation = clause['text_input']
179 | new_clause['text_input'] = relation
180 |
181 | else:
182 | if 'text_input' in clause.keys():
183 | new_clause['text_input'] = clause['text_input']
184 |
185 | new_program.append(new_clause)
186 |
187 | return new_program
188 |
189 |
190 | def process_questions(questions_path, new_question_path, meta_info):
191 | new_questions = {}
192 |
193 | with open(questions_path, 'r') as questions_file:
194 | questions = json.load(questions_file)
195 |
196 | # process questions
197 | for question in questions:
198 |
199 | image_id = question["image_id"]
200 |
201 | # process questions
202 | if image_id not in new_questions.keys():
203 | new_questions[image_id] = {}
204 | new_question = new_questions[image_id]
205 |
206 | new_question['question_id'] = question['question_id']
207 | new_question['program'] = process_program(
208 | question["program"], meta_info)
209 |
210 | program = question['program']
211 | new_question['target'] = program[-2]["output"]
212 | new_question['question'] = question['question']
213 | new_question['answer'] = question['answer']
214 |
215 | with open(new_question_path, 'w') as new_question_file:
216 | json.dump(new_questions, new_question_file)
217 |
218 |
--------------------------------------------------------------------------------
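Note: a hypothetical, minimal meta_info showing only the keys Idx2Word reads; the real gqa_info.json is far larger, and this structure is merely inferred from the accesses in the class above.

    from word_idx_translator import Idx2Word  # assumes this directory is on the path

    meta_info = {
        'name': {'idx': {'dog': 0, 'cat': 1}, 'freq': {'dog': 10, 'cat': 5}, 'canon': {}},
        'attr': {'idx': {'brown': 0}, 'freq': {'brown': 7}, 'canon': {}, 'alias': {}},
        'rel':  {'idx': {'left of': 0}, 'freq': {'left of': 3}, 'canon': {}, 'alias': {}},
    }

    translator = Idx2Word(meta_info)
    print(translator.name_to_idx('dog'))  # 0
    print(translator.idx_to_name(1))      # 'cat'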
/src/slot_attention_module.py:
--------------------------------------------------------------------------------
1 | """
2 | Slot attention model based on code of tkipf and the corresponding paper Locatello et al. 2020
3 | """
4 | from torch import nn
5 | import torch
6 | import torch.nn.functional as F
7 | import torchvision.models as models
8 | import numpy as np
9 | from torchsummary import summary
10 |
11 |
12 | def build_grid(resolution):
13 | ranges = [np.linspace(0., 1., num=res) for res in resolution]
14 | grid = np.meshgrid(*ranges, sparse=False, indexing="ij")
15 | grid = np.stack(grid, axis=-1)
16 | grid = np.reshape(grid, [resolution[0], resolution[1], -1])
17 | grid = np.expand_dims(grid, axis=0)
18 | grid = grid.astype(np.float32)
19 | return np.concatenate([grid, 1.0 - grid], axis=-1)
20 |
21 |
22 | def spatial_broadcast(slots, resolution):
23 | """Broadcast slot features to a 2D grid and collapse slot dimension."""
24 | # `slots` has shape: [batch_size, num_slots, slot_size].
25 | slots = torch.reshape(slots, [slots.shape[0] * slots.shape[1], 1, 1, slots.shape[2]])
26 |
27 |     grid = slots.repeat(1, resolution[0], resolution[1], 1) # repeat tiles the slot features across the spatial grid
28 | # `grid` has shape: [batch_size*num_slots, width, height, slot_size].
29 | return grid
30 |
31 |
32 | def unstack_and_split(x, batch_size, n_slots, num_channels=3):
33 | """Unstack batch dimension and split into channels and alpha mask."""
34 | # unstacked = torch.reshape(x, [batch_size, -1] + list(x.shape[1:]))
35 | # channels, masks = torch.split(unstacked, [num_channels, 1], dim=-1)
36 | unstacked = torch.reshape(x, [batch_size, n_slots] + list(x.shape[1:]))
37 | channels, masks = torch.split(unstacked, [num_channels, 1], dim=2)
38 | return channels, masks
39 |
40 |
41 | class SlotAttention(nn.Module):
42 | def __init__(self, num_slots, dim, iters=3, eps=1e-8, hidden_dim=128):
43 | super().__init__()
44 | self.num_slots = num_slots
45 | self.iters = iters
46 | self.eps = eps
47 |         self.scale = dim ** -0.5 # 1/sqrt(D) in the paper, with D the slot/feature dimension
48 |
49 | self.slots_mu = nn.Parameter(torch.randn(1, 1, dim)) #randomly initialize sigma and mu
50 | self.slots_log_sigma = nn.Parameter(torch.randn(1, 1, dim)).abs().to(device='cuda')
51 | #self.slots_mu = nn.Parameter(torch.nn.init.xavier_uniform_(torch.empty(1,1,dim), gain=1.0)) #randomly initialize sigma and mu
52 | #self.slots_log_sigma = nn.Parameter(torch.nn.init.xavier_uniform_(torch.empty(1,1,dim), gain=1.0))
53 |
54 | self.project_q = nn.Linear(dim, dim, bias=True) #query projection
55 | self.project_k = nn.Linear(dim, dim, bias=True) #
56 | self.project_v = nn.Linear(dim, dim, bias=True) #feature key projection
57 |
58 | self.gru = nn.GRUCell(dim, dim)
59 |
60 | hidden_dim = max(dim, hidden_dim)
61 |
62 | self.mlp = nn.Sequential(
63 | nn.Linear(dim, hidden_dim),
64 | nn.ReLU(inplace=True),
65 | nn.Linear(hidden_dim, dim)
66 | )
67 |
68 | self.norm_inputs = nn.LayerNorm(dim, eps=1e-05)
69 | self.norm_slots = nn.LayerNorm(dim, eps=1e-05)
70 | self.norm_mlp = nn.LayerNorm(dim, eps=1e-05)
71 |
72 | self.attn = 0
73 |
74 | def forward(self, inputs, num_slots=None):
75 |         b, n, d = inputs.shape # b is the batch size, n the number of input positions, d the feature dimension (e.g. [15, 1024, 32])
76 | n_s = num_slots if num_slots is not None else self.num_slots
77 |
78 | mu = self.slots_mu.expand(b, n_s, -1) #mu and sigma are shared by all slots
79 | sigma = self.slots_log_sigma.expand(b, n_s, -1)
80 | slots = torch.normal(mu, sigma) #sample slots from mu and sigma
81 | #slots = torch.normal(mu, sigma.exp()) #sample slots from mu and sigma
82 |
83 |
84 | inputs = self.norm_inputs(inputs) #layer normalization of inputs
85 | k, v = self.project_k(inputs), self.project_v(inputs) #*self.scale
86 |
87 |
88 | for _ in range(self.iters):
89 | slots_prev = slots #store old slots
90 |
91 | slots = self.norm_slots(slots) #layer norm of slots
92 | q = self.project_q(slots) #emit a query for all slots
93 |
94 |             dots = torch.einsum('bid,bjd->bij', q, k) * self.scale # attention logits (M in the paper), shape [batch, num_slots, num_inputs]
95 |             attn = dots.softmax(dim=1) + self.eps # softmax over the slot dimension, so slots compete for every input (see the walk-through after this file)
96 | attn = attn / attn.sum(dim=-1, keepdim=True) #weighted mean
97 |
98 | updates = torch.einsum('bjd,bij->bid', v, attn)
99 |
100 | #recurrently update the slots with the slot updates and the previous slots
101 | slots = self.gru(
102 | updates.reshape(-1, d),
103 | slots_prev.reshape(-1, d)
104 | )
105 |
106 | #apply 2 layer relu mlp to GRU output
107 | slots = slots.reshape(b, -1, d)
108 | slots = slots + self.mlp(self.norm_mlp(slots))
109 |
110 | self.attn = attn
111 |
112 | return slots
113 |
114 |
115 | class SlotAttention_encoder(nn.Module):
116 | def __init__(self, in_channels, hidden_channels, clevr_encoding):
117 | super(SlotAttention_encoder, self).__init__()
118 |
119 | if clevr_encoding:
120 | self.network = nn.Sequential(
121 | nn.Conv2d(in_channels, hidden_channels, (5, 5), stride=(1, 1), padding=2),
122 | nn.ReLU(inplace=True),
123 | nn.Conv2d(hidden_channels, hidden_channels, (5, 5), stride=(2, 2), padding=2),
124 | nn.ReLU(inplace=True),
125 | nn.Conv2d(hidden_channels, hidden_channels, (5, 5), stride=(2, 2), padding=2),
126 | nn.ReLU(inplace=True),
127 | nn.Conv2d(hidden_channels, hidden_channels, (5, 5), stride=(1, 1), padding=2),
128 | nn.ReLU(inplace=True))
129 | else:
130 | self.network = nn.Sequential(
131 | nn.Conv2d(in_channels, hidden_channels, (5, 5), stride=(1, 1), padding=2),
132 | nn.ReLU(inplace=True),
133 | nn.Conv2d(hidden_channels, hidden_channels, (5, 5), stride=(1, 1), padding=2),
134 | nn.ReLU(inplace=True),
135 | nn.Conv2d(hidden_channels, hidden_channels, (5, 5), stride=(1, 1), padding=2),
136 | nn.ReLU(inplace=True),
137 | nn.Conv2d(hidden_channels, hidden_channels, (5, 5), stride=(1, 1), padding=2),
138 | nn.ReLU(inplace=True))
139 |
140 |
141 |
142 |
143 | def forward(self, x):
144 | return self.network(x)
145 |
146 |
147 | class MLP(nn.Module):
148 | def __init__(self, hidden_channels):
149 | super(MLP, self).__init__()
150 | self.network = nn.Sequential(
151 | nn.Linear(hidden_channels, hidden_channels),
152 | nn.ReLU(inplace=True),
153 | nn.Linear(hidden_channels, hidden_channels),
154 | )
155 |
156 | def forward(self, x):
157 | return self.network(x)
158 |
159 |
160 | class SoftPositionEmbed(nn.Module):
161 | """Adds soft positional embedding with learnable projection."""
162 |
163 | def __init__(self, hidden_size, resolution, device="cuda:0"):
164 | """Builds the soft position embedding layer.
165 | Args:
166 | hidden_size: Size of input feature dimension.
167 | resolution: Tuple of integers specifying width and height of grid.
168 | """
169 | super().__init__()
170 | self.dense = nn.Linear(4, hidden_size)
171 | # self.grid = torch.FloatTensor(build_grid(resolution))
172 | # self.grid = self.grid.to(device)
173 | # for nn.DataParallel
174 | self.register_buffer("grid", torch.FloatTensor(build_grid(resolution)))
175 | self.resolution = resolution[0]
176 | self.hidden_size = hidden_size
177 |
178 | def forward(self, inputs):
179 | return inputs + self.dense(self.grid).view((-1, self.hidden_size, self.resolution, self.resolution))
180 |
181 |
182 | class SlotAttention_classifier(nn.Module):
183 | def __init__(self, in_channels, out_channels):
184 | super(SlotAttention_classifier, self).__init__()
185 | self.network = nn.Sequential(
186 | nn.Linear(in_channels, in_channels), # nn.Conv1d(in_channels, in_channels, 1, stride=1, groups=in_channels)
187 | nn.ReLU(inplace=True),
188 | nn.Linear(in_channels, out_channels),
189 | nn.Sigmoid()
190 | )
191 |
192 | def forward(self, x):
193 | return self.network(x)
194 |
195 |
196 | class SlotAttention_model(nn.Module):
197 | def __init__(self, n_slots, n_iters, n_attr,
198 | in_channels=3,
199 | encoder_hidden_channels=64,
200 | attention_hidden_channels=128,
201 | mlp_prediction = False,
202 | device="cuda",
203 | clevr_encoding=False):
204 | super(SlotAttention_model, self).__init__()
205 | self.n_slots = n_slots
206 | self.n_iters = n_iters
207 | self.n_attr = n_attr
208 |         self.n_attr = n_attr + 1 # one extra attribute indicating whether a slot contains an object or is empty
209 | self.device = device
210 |
211 | self.encoder_cnn = SlotAttention_encoder(in_channels=in_channels, hidden_channels=encoder_hidden_channels , clevr_encoding=clevr_encoding)
212 | self.encoder_pos = SoftPositionEmbed(encoder_hidden_channels, (32, 32), device=device)# changed from 128* 128
213 | self.layer_norm = nn.LayerNorm(encoder_hidden_channels, eps=1e-05)
214 | self.mlp = MLP(hidden_channels=encoder_hidden_channels)
215 | self.slot_attention = SlotAttention(num_slots=n_slots, dim=encoder_hidden_channels, iters=n_iters, eps=1e-8,
216 | hidden_dim=attention_hidden_channels)
217 |
218 | #for set prediction baseline
219 | self.mlp_prediction = mlp_prediction
220 | self.mlp_classifier = SlotAttention_classifier(in_channels=encoder_hidden_channels, out_channels=self.n_attr)
221 |
222 | self.softmax = nn.Softmax(dim=1)
223 |
224 | def forward(self, img):
225 |         # `img` has shape: [batch_size, num_channels, height, width].
226 |
227 | # SLOT ATTENTION ENCODER
228 | x = self.encoder_cnn(img)
229 | x = self.encoder_pos(x)
230 | x = torch.flatten(x, start_dim=2)
231 |
232 | # permute channel dimensions
233 | x = x.permute(0, 2, 1)
234 | x = self.layer_norm(x)
235 | x = self.mlp(x)
236 |
237 | slots = self.slot_attention(x)
238 | # slots has shape: [batch_size, num_slots, slot_size].
239 | if self.mlp_prediction:
240 | x = self.mlp_classifier(slots)
241 | return x
242 | else:
243 | return slots
244 |
245 |
246 | if __name__ == "__main__":
247 | x = torch.rand(15, 3, 32, 32).cuda()
248 |     net = SlotAttention_model(n_slots=11, n_iters=3, n_attr=18,
249 |                               encoder_hidden_channels=32,
250 |                               attention_hidden_channels=64)
251 | net = net.cuda()
252 | output = net(x)
253 | summary(net, (3, 32, 32))
254 |
255 |
--------------------------------------------------------------------------------
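Note: a standalone shape walk-through of one slot-attention update, mirroring the einsum/softmax steps of SlotAttention.forward above; the toy sizes are arbitrary.

    import torch

    # toy sizes: batch 2, 16 input positions, feature dim 8, 4 slots
    b, n, d, n_s = 2, 16, 8, 4
    q = torch.randn(b, n_s, d)                    # one query per slot
    k = torch.randn(b, n, d)                      # keys from the flattened feature map
    v = torch.randn(b, n, d)                      # values from the flattened feature map

    dots = torch.einsum('bid,bjd->bij', q, k) * d ** -0.5  # [b, n_s, n] attention logits
    attn = dots.softmax(dim=1)                    # slots compete for every input position
    attn = attn / attn.sum(dim=-1, keepdim=True)  # normalize per slot -> weighted-mean weights
    updates = torch.einsum('bjd,bij->bid', v, attn)         # [b, n_s, d] slot updates
    print(updates.shape)                          # torch.Size([2, 4, 8])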
/src/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import torch
4 | import errno
5 | from PIL import Image
6 | import torch
7 | from torch.utils.tensorboard import SummaryWriter
8 | import time
9 | import matplotlib
10 | import matplotlib.pyplot as plt
11 | from matplotlib.ticker import MaxNLocator
12 | import seaborn as sns
13 | import random
14 |
15 |
16 | def mkdir_p(path):
17 | """Linux mkdir -p"""
18 | try:
19 | os.makedirs(path)
20 | except OSError as exc: # Python >2.5
21 | if exc.errno == errno.EEXIST and os.path.isdir(path):
22 | pass
23 | else:
24 | raise
25 |
26 |
27 | def one_hot(x, K, dtype=torch.float):
28 | """One hot encoding"""
29 | with torch.no_grad():
30 | ind = torch.zeros(x.shape + (K,), dtype=dtype, device=x.device)
31 | ind.scatter_(-1, x.unsqueeze(-1), 1)
32 | return ind
33 |
34 |
35 | def save_image_stack(samples, num_rows, num_columns, filename, margin=5, margin_gray_val=1., frame=0, frame_gray_val=0.0):
36 | """Save image stack in a tiled image"""
37 |
38 | # for gray scale, convert to rgb
39 | if len(samples.shape) == 3:
40 | samples = np.stack((samples,) * 3, -1)
41 |
42 | height = samples.shape[1]
43 | width = samples.shape[2]
44 |
45 | samples -= samples.min()
46 | samples /= samples.max()
47 |
48 | img = margin_gray_val * np.ones((height*num_rows + (num_rows-1)*margin, width*num_columns + (num_columns-1)*margin, 3))
49 | for h in range(num_rows):
50 | for w in range(num_columns):
51 | img[h*(height+margin):h*(height+margin)+height, w*(width+margin):w*(width+margin)+width, :] = samples[h*num_columns + w, :]
52 |
53 | framed_img = frame_gray_val * np.ones((img.shape[0] + 2*frame, img.shape[1] + 2*frame, 3))
54 | framed_img[frame:(frame+img.shape[0]), frame:(frame+img.shape[1]), :] = img
55 |
56 | img = Image.fromarray(np.round(framed_img * 255.).astype(np.uint8))
57 |
58 | img.save(filename)
59 |
60 |
61 | def sample_matrix_categorical(p):
62 | """Sample many Categorical distributions represented as rows in a matrix."""
63 | with torch.no_grad():
64 | cp = torch.cumsum(p[:, 0:-1], -1)
65 | rand = torch.rand((cp.shape[0], 1), device=cp.device)
66 | rand_idx = torch.sum(rand > cp, -1).long()
67 | return rand_idx
68 |
69 |
70 | def set_manual_seed(seed:int=1):
71 | """Set the seed for the PRNGs."""
72 |     os.environ['PYTHONHASHSEED'] = str(seed)
73 | random.seed(seed)
74 | np.random.seed(seed)
75 | torch.manual_seed(seed)
76 | if torch.cuda.is_available():
77 | torch.cuda.manual_seed(seed)
78 | torch.cuda.manual_seed_all(seed)
79 |         torch.backends.cudnn.benchmark = True  # torch.cuda has no 'benchmark' flag; the cuDNN flag is presumably what was meant
80 |
81 |
82 | def time_delta_now(t_start: float, simple_format=False, ret_sec=False) -> str:
83 |     """
84 |     Convert a start timestamp into a human-readable time delta.
85 | 
86 |     Parameters
87 |     ----------
88 |     t_start : float
89 |         The timestamp marking the beginning of the event.
90 |     Returns
91 |     -------
92 |     Human-readable time string (plus the elapsed seconds, unless simple_format is set).
93 |     """
94 | a = t_start
95 | b = time.time() # current epoch time
96 | c = b - a # seconds
97 | days = round(c // 86400)
98 | hours = round(c // 3600 % 24)
99 | minutes = round(c // 60 % 60)
100 | seconds = round(c % 60)
101 | millisecs = round(c % 1 * 1000)
102 | if simple_format:
103 | return f"{hours}h:{minutes}m:{seconds}s"
104 |
105 | return f"{days} days, {hours} hours, {minutes} minutes, {seconds} seconds, {millisecs} milliseconds", c
106 |
107 | def time_delta(c: float, simple_format=False) -> str:
108 |     # c: elapsed time in seconds
109 | days = round(c // 86400)
110 | hours = round(c // 3600 % 24)
111 | minutes = round(c // 60 % 60)
112 | seconds = round(c % 60)
113 | millisecs = round(c % 1 * 1000)
114 | if simple_format:
115 | return f"{hours}h:{minutes}m:{seconds}s"
116 | return f"{days} days, {hours} hours, {minutes} minutes, {seconds} seconds, {millisecs} milliseconds", c
117 |
118 |
119 | def export_results(test_accuracy_list, train_accuracy_list,
120 | export_path, export_suffix,
121 | confusion_matrix , exp_dict):
122 |
123 | #set matplotlib styles
124 | plt.style.use(['science','grid'])
125 | matplotlib.rcParams.update(
126 | {
127 | "font.family": "serif",
128 | "text.usetex": False,
129 | "legend.fontsize": 22
130 | }
131 | )
132 |
133 |
134 | fig, axs = plt.subplots(1, 1 , figsize=(10,21))
135 | # fig.suptitle(exp_dict['exp_name'], fontsize=16)
136 | #axs[0]
137 | axs.plot(test_accuracy_list[:,1],test_accuracy_list[:,0], label='test accuracy')
138 | #axs[0]
139 | axs.plot(train_accuracy_list[:,1],train_accuracy_list[:,0], label='train accuracy')
140 | #axs[0].
141 | axs.legend(loc="lower right")
142 | #axs[0].
143 | axs.set(xlabel='epochs', ylabel='accuracy')
144 | #axs[0].
145 | axs.xaxis.set_major_locator(MaxNLocator(integer=True))
146 |
147 |
148 | #ax.[0, 1].set_xticklabels([0,1,2,3,4,5,6,7,8,9])
149 | #ax.[0, 1].set_yticklabels([0,1,2,3,4,5,6,7,8,9])
150 |
151 | #axs[0, 1].set(xlabel='target', ylabel='prediction')
152 | #axs[1] = sns.heatmap(confusion_matrix, linewidths=2, cmap="viridis")
153 |
154 |
155 |
156 | if exp_dict['structure'] == 'poon-domingos':
157 | text = "trainable parameters = {}, lr = {}, batchsize= {}, time_per_epoch= {},\n structure= {}, pd_num_pieces = {}".format(
158 | exp_dict['num_trainable_params'], exp_dict['lr'],
159 | exp_dict['bs'], exp_dict['train_time'],
160 | exp_dict['structure'], exp_dict['pd_num_pieces'] )
161 |
162 | else:
163 | text = "trainable parameters = {}, lr = {}, batchsize= {}, time_per_epoch= {},\n structure= {}, num_repetitions = {} , depth = {}".format(
164 | exp_dict['num_trainable_params'], exp_dict['lr'],
165 | exp_dict['bs'],exp_dict['train_time'],
166 | exp_dict['structure'], exp_dict['num_repetitions'], exp_dict['depth'])
167 |
168 | #plt.gcf().text(
169 | #0.5,
170 | #0.02,
171 | #text,
172 | #ha="center",
173 | #fontsize=12,
174 | #linespacing=1.5,
175 | #bbox=dict(
176 | # facecolor="grey", alpha=0.2, edgecolor="black", boxstyle="round,pad=1"
177 | #))
178 |
179 | fig.savefig(export_path, format="svg")
180 |
181 | plt.show()
182 |
183 |
184 | #Tensorboard outputs
185 | writer = SummaryWriter("../../results", filename_suffix=export_suffix)
186 |
187 | for train_acc_elem, test_acc_elem in zip(train_accuracy_list, test_accuracy_list):
188 | writer.add_scalar('Accuracy/train', train_acc_elem[0], train_acc_elem[1])
189 | writer.add_scalar('Accuracy/test', test_acc_elem[0], test_acc_elem[1])
190 |
191 |
192 |
--------------------------------------------------------------------------------
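Note: a quick check of one_hot from utils.py above; the import assumes src/ is on the path and the indices are illustrative.

    import torch
    from utils import one_hot  # assumes src/ is on the path

    x = torch.tensor([0, 2, 1])
    print(one_hot(x, K=3))
    # tensor([[1., 0., 0.],
    #         [0., 0., 1.],
    #         [0., 1., 0.]])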